[
  {
    "path": ".github/CODE_OF_CONDUCT.md",
    "content": "\n# Contributor Covenant Code of Conduct\n\n## Our Pledge\n\nWe as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation.\n\nWe pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community.\n\n## Our Standards\n\nExamples of behavior that contributes to a positive environment for our community include:\n\n* Demonstrating empathy and kindness toward other people\n* Being respectful of differing opinions, viewpoints, and experiences\n* Giving and gracefully accepting constructive feedback\n* Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience\n* Focusing on what is best not just for us as individuals, but for the overall community\n\nExamples of unacceptable behavior include:\n\n* The use of sexualized language or imagery, and sexual attention or advances of any kind\n* Trolling, insulting or derogatory comments, and personal or political attacks\n* Public or private harassment\n* Publishing others' private information, such as a physical or email address, without their explicit permission\n* Other conduct which could reasonably be considered inappropriate in a professional setting\n\n## Enforcement Responsibilities\n\nCommunity leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful.\n\nCommunity leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions 
that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate.\n\n## Scope\n\nThis Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event.\n\n## Enforcement\n\nInstances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at [INSERT CONTACT METHOD]. All complaints will be reviewed and investigated promptly and fairly.\n\nAll community leaders are obligated to respect the privacy and security of the reporter of any incident.\n\n## Enforcement Guidelines\n\nCommunity leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct:\n\n### 1. Correction\n\n**Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community.\n\n**Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested.\n\n### 2. Warning\n\n**Community Impact**: A violation through a single incident or series of actions.\n\n**Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban.\n\n### 3. 
Temporary Ban\n\n**Community Impact**: A serious violation of community standards, including sustained inappropriate behavior.\n\n**Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban.\n\n### 4. Permanent Ban\n\n**Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior,  harassment of an individual, or aggression toward or disparagement of classes of individuals.\n\n**Consequence**: A permanent ban from any sort of public interaction within the community.\n\n## Attribution\n\nThis Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at [https://www.contributor-covenant.org/version/2/0/code_of_conduct.html][v2.0].\n\nCommunity Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder][Mozilla CoC].\n\nFor answers to common questions about this code of conduct, see the FAQ at [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at [https://www.contributor-covenant.org/translations][translations].\n\n[homepage]: https://www.contributor-covenant.org\n[v2.0]: https://www.contributor-covenant.org/version/2/0/code_of_conduct.html\n[Mozilla CoC]: https://github.com/mozilla/diversity\n[FAQ]: https://www.contributor-covenant.org/faq\n[translations]: https://www.contributor-covenant.org/translations\n"
  },
  {
    "path": ".github/CONTRIBUTING.md",
    "content": "## Before Commit!\n\nYour commit message must follow Conventional Commits (https://www.conventionalcommits.org/) and your code must be formatted. The Git hooks do most of the work automatically.\n\n### Tool Requirements\n\nYou need a recent `clang-format` (>= 18). In a conda environment you can install it with:\n\n```shell\nconda install -c conda-forge clang-format=18\n```\n\nIf you previously configured the build with an older version, remove the build directory and reconfigure:\n\n```shell\nrm -rf kt-kernel/build\n```\n\nInstall `black` for Python formatting:\n\n```shell\nconda install black\n```\n\n### Install hook:\n```shell\nbash kt-kernel/scripts/install-git-hooks.sh\n# or just configure kt-kernel with CMake\ncmake -S kt-kernel -B kt-kernel/build\n```\n\nTo format the code manually, run:\n\n```shell\ncmake -S kt-kernel -B kt-kernel/build\ncmake --build kt-kernel/build --target format\n```\n\n## Developer Note\n\nFormatting and commit message rules are enforced by Git hooks. After installing `clang-format` and `black`, just commit normally; the hooks will run formatting for you.\n\n> [!NOTE]\n> If formatting modifies files, the commit is aborted after staging those changes. Review them and run `git commit` again. 
Repeat until no further formatting changes appear.\n\n---\n\n### Conventional Commit Regex (Reference)\n\nThe commit-msg hook enforces this pattern:\n\n```text\nregex='^\\[(feat|fix|docs|style|refactor|perf|test|build|ci|chore|revert|wip)\\](\\([^\\)]+\\))?(!)?: .+'\n```\n\nMeaning:\n* `[type]` required: one of feat|fix|docs|style|refactor|perf|test|build|ci|chore|revert|wip\n* Optional scope: `(scope)`, placed after the closing `]`; any characters except `)`\n* Optional breaking change marker: `!` right after the type or scope\n* Separator: `: ` (colon + space)\n* Subject: free text (at least one character)\n\nExamples:\n```text\n[feat]: add adaptive batching\n[fix](parser): handle empty token list\n[docs]!: update API section for breaking rename\n```\n\nYou can bypass locally (not recommended) with:\n```shell\ngit commit --no-verify\n```\n## 提交前提醒\n\n提交信息必须满足 Conventional Commits 规范 (https://www.conventionalcommits.org/)，代码需要符合格式要求。Git 钩子已经集成了大部分工作：\n### 软件要求\n\n需要较新的 `clang-format` (>= 18)，在 conda 环境中安装：\n\n```shell\nconda install -c conda-forge clang-format=18\n```\n\n如果之前用老版本配置过，请删除构建目录重新配置：\n\n```shell\nrm -rf kt-kernel/build\n```\n\n安装 `black` 以进行 Python 文件格式化：\n\n```shell\nconda install black\n```\n### 安装钩子\n```shell\nbash kt-kernel/scripts/install-git-hooks.sh\n#or just cmake the kt-kernel\ncmake -S kt-kernel -B kt-kernel/build\n```\n如果你需要手动格式化：\n```shell\ncmake -S kt-kernel -B kt-kernel/build\ncmake --build kt-kernel/build --target format\n```\n\n## 开发者说明\n\n本仓库通过 Git hooks 自动执行代码格式化与提交信息规范检查。只需安装好 `clang-format` 与 `black` 后正常执行提交即可，钩子会自动格式化。\n\n> [!NOTE]\n> 如果格式化修改了文件，钩子会终止提交并已暂存这些改动。请查看修改后再次执行 `git commit`，重复直到没有新的格式化变更。\n\n### 提交信息正则（参考）\n\n钩子使用如下正则检查提交信息：\n```text\nregex='^\\[(feat|fix|docs|style|refactor|perf|test|build|ci|chore|revert|wip)\\](\\([^\\)]+\\))?(!)?: .+'\n```\n含义：\n* `[type]` 必填：feat|fix|docs|style|refactor|perf|test|build|ci|chore|revert|wip\n* 作用域可选：`(scope)`，不能包含右括号\n* 可选的破坏性标记：`!`\n* 分隔符：冒号+空格 `: `\n* 描述：至少一个字符\n\n示例：\n```text\n[feat]: 增加自适应 batch 
功能\n[fix](tokenizer): 修复空 token 列表处理\n[docs]!: 更新接口文档（存在破坏性修改）\n```\n\n跳过钩子（不推荐，仅紧急时）：\n```shell\ngit commit --no-verify\n```\n\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/-bug-.yaml",
    "content": "name: \"\\U0001F41B Bug / Help\"\ndescription: Create a report to help us improve the ktransformers project\nlabels: [\"pending\"]\nbody:\n  - type: markdown\n    attributes:\n      value: |\n        Issues included in **[FAQs](https://github.com/kvcache-ai/ktransformers/issues/1608)** or those with **insufficient** information may be closed without a response.\n        已经包含在 **[常见问题](https://github.com/kvcache-ai/ktransformers/issues/1608)** 内或提供信息**不完整**的 issues 可能不会被回复。\n\n  - type: checkboxes\n    id: reminder\n    attributes:\n      label: Reminder\n      description: |\n        Please ensure you have read the above rules carefully and searched the existing issues (including FAQs).\n        请确保您已经认真阅读了上述规则并且搜索过现有的 issues（包括常见问题）。\n\n      options:\n        - label: I have read the above rules and searched the existing issues.\n          required: true\n\n  - type: textarea\n    id: system-info\n    validations:\n      required: true\n    attributes:\n      label: System Info\n      description: |\n        Please share your system info with us. You can run commands such as **lscpu** and **nvidia-smi** and paste their output below.\n        请提供您的系统信息。您可以在命令行运行 **lscpu**, **nvidia-smi** 等命令，并将其输出复制到该文本框中。\n\n      placeholder: ktransformers version, sglang version, platform, python version, cpu info, GPU/NPU info ...\n\n  - type: textarea\n    id: reproduction\n    validations:\n      required: true\n    attributes:\n      label: Reproduction\n      description: |\n        Please provide the entry arguments, error messages, and stack traces needed to reproduce the problem.\n        请提供入口参数，错误日志以及异常堆栈以便于我们复现问题。\n\n      value: |\n        ```text\n        Put your message here.\n        ```\n\n  - type: textarea\n    id: others\n    validations:\n      required: false\n    attributes:\n      label: Others"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/-feature-.yaml",
    "content": "name: \"\\U0001F680 Feature request\"\ndescription: Submit a request for a new feature\nlabels: [\"enhancement\", \"pending\"]\nbody:\n  - type: markdown\n    attributes:\n      value: |\n        Please do not create issues that are not related to new features under this category.\n        请勿在此分类下创建和新特性无关的 issues。\n\n  - type: checkboxes\n    id: reminder\n    attributes:\n      label: Reminder\n      description: |\n        Please ensure you have read the above rules carefully and searched the existing issues.\n        请确保您已经认真阅读了上述规则并且搜索过现有的 issues。\n\n      options:\n        - label: I have read the above rules and searched the existing issues.\n          required: true\n\n  - type: textarea\n    id: description\n    validations:\n      required: true\n    attributes:\n      label: Description\n      description: |\n        A clear and concise description of the feature proposal.\n        请详细描述您希望加入的新功能特性。\n\n  - type: textarea\n    id: contribution\n    validations:\n      required: false\n    attributes:\n      label: Pull Request\n      description: |\n        Have you already created the relevant PR and submitted the code?\n        您是否已经创建了相关 PR 并提交了代码？"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/config.yml",
    "content": "blank_issues_enabled: false\ncontact_links:\n  - name: 📚 FAQs | 常见问题\n    url: https://github.com/kvcache-ai/ktransformers/issues/1608\n    about: Reading in advance is recommended | 建议提前阅读"
  },
  {
    "path": ".github/PULL_REQUEST_TEMPLATE.md",
    "content": "# What does this PR do?\n\nFixes # (issue)\n\n## Before submitting\n\n- [ ] Did you read the [contributor guideline](https://github.com/kvcache-ai/ktransformers/blob/main/.github/CONTRIBUTING.md)?\n- [ ] Did you write any new necessary tests?"
  },
  {
    "path": ".github/SECURITY.md",
    "content": "# Reporting Security Issues\n\nTo report a security issue, please use the GitHub Security Advisory [\"Report a Vulnerability\"](https://github.com/kvcache-ai/ktransformers/security/advisories/new) tab.\n\nWe will send a response indicating the next steps in handling your report. After the initial reply to your report, the security team will keep you informed of the progress towards a fix and full announcement, and may ask for additional information or guidance.\n\nReport security bugs in third-party modules to the person or team maintaining the module."
  },
  {
    "path": ".github/workflows/book-ci.yml",
    "content": "name: Book-CI\n\non:\n  push:\n    branches:\n      - main\n      # - server_support\n\n  pull_request:\n    branches:\n      - main\n      # - server_support\njobs:\n  test:\n    name: test\n    runs-on: ${{ matrix.os }}\n    strategy:\n      matrix:\n        os: [ubuntu-latest, macos-latest, windows-latest]\n    steps:\n      - uses: actions/checkout@v4\n      - name: Install Rust\n        run: |\n          rustup set profile minimal\n          rustup toolchain install stable\n          rustup default stable\n      - name: Setup mdBook\n        uses: peaceiris/actions-mdbook@v2\n        with:\n          mdbook-version: \"latest\"\n      # - name: Run tests\n      #   run: mdbook test"
  },
  {
    "path": ".github/workflows/deploy.yml",
    "content": "name: Deploy\n\non:\n  push:\n    branches:\n      - main\n      # - server_support\n\n  pull_request:\n    branches:\n      - main\n      # - server_support\n\ndefaults:\n  run:\n    shell: bash\n\npermissions:\n  contents: write\n\njobs:\n  deploy:\n    runs-on: ${{ matrix.os }}\n    strategy:\n      matrix:\n        os: [ubuntu-latest, macos-latest, windows-latest]\n    steps:\n      - uses: actions/checkout@v4\n      - name: Install Rust\n        run: |\n          rustup set profile minimal\n          rustup toolchain install stable\n          rustup default stable\n      - name: Setup mdBook\n        uses: peaceiris/actions-mdbook@v2\n        with:\n          mdbook-version: \"latest\"\n      - run: mdbook build\n      # - name: Copy Assets\n      #   run: |\n      #     chmod +x ci/copy-assets.sh\n      #     ci/copy-assets.sh ${{ matrix.os }}\n      - name: Deploy\n        uses: peaceiris/actions-gh-pages@v3\n        # or || github.ref == 'refs/heads/server_support'\n        if: ${{ github.ref == 'refs/heads/main' }}\n        with:\n          github_token: ${{ secrets.GITHUB_TOKEN }}\n          publish_dir: ./book"
  },
  {
    "path": ".github/workflows/docker-image.yml",
    "content": "name: DockerHub CI\n\non:\n  release:\n    types: [published]\n  workflow_dispatch:\n    inputs:\n      push_to_dockerhub:\n        description: 'Push image to DockerHub? (true/false)'\n        required: true\n        default: 'false'\n        type: boolean\n      cuda_version:\n        description: 'CUDA version (e.g., 12.8.1)'\n        required: false\n        default: '12.8.1'\n        type: string\n      push_simplified_tag:\n        description: 'Also push simplified tag? (true/false)'\n        required: false\n        default: 'true'\n        type: boolean\n      ubuntu_mirror:\n        description: 'Use Tsinghua Ubuntu mirror? (0/1)'\n        required: false\n        default: '0'\n        type: string\n\n  # push:\n  #   branches:\n  #     - main\nenv:\n  DOCKERHUB_REPO: ${{ secrets.DOCKERHUB_USERNAME }}/ktransformers\njobs:\n  test:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v4\n      - name: Run tests\n        run: |\n          if [ -f docker-compose.test.yml ]; then\n            docker-compose --file docker-compose.test.yml build\n            docker-compose --file docker-compose.test.yml run sut\n          else\n            docker build . 
--file docker/Dockerfile\n          fi\n\n  build-and-push:\n    needs: test\n    name: Build and Push Multi-Variant Docker Image\n    runs-on: ubuntu-latest\n\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n\n      - name: Move Docker data directory\n        run: |\n          sudo systemctl stop docker\n          sudo mkdir -p /mnt/docker\n          sudo rsync -avz /var/lib/docker/ /mnt/docker\n          sudo rm -rf /var/lib/docker\n          sudo ln -s /mnt/docker /var/lib/docker\n          sudo systemctl start docker\n\n      - name: Set up Docker Buildx\n        uses: docker/setup-buildx-action@v3\n\n      - name: Login to Docker Hub\n        uses: docker/login-action@v3\n        with:\n          username: ${{ secrets.DOCKERHUB_USERNAME }}\n          password: ${{ secrets.DOCKERHUB_TOKEN }}\n\n      - name: Determine build parameters\n        id: params\n        run: |\n          # Determine if we should push\n          if [ \"${{ github.event_name }}\" = \"release\" ]; then\n            echo \"should_push=true\" >> $GITHUB_OUTPUT\n            echo \"push_simplified=true\" >> $GITHUB_OUTPUT\n          elif [ \"${{ github.event_name }}\" = \"workflow_dispatch\" ]; then\n            echo \"should_push=${{ inputs.push_to_dockerhub }}\" >> $GITHUB_OUTPUT\n            echo \"push_simplified=${{ inputs.push_simplified_tag }}\" >> $GITHUB_OUTPUT\n          else\n            echo \"should_push=false\" >> $GITHUB_OUTPUT\n            echo \"push_simplified=false\" >> $GITHUB_OUTPUT\n          fi\n\n          # Determine CUDA version\n          if [ \"${{ github.event_name }}\" = \"workflow_dispatch\" ] && [ -n \"${{ inputs.cuda_version }}\" ]; then\n            echo \"cuda_version=${{ inputs.cuda_version }}\" >> $GITHUB_OUTPUT\n          else\n            echo \"cuda_version=12.8.1\" >> $GITHUB_OUTPUT\n          fi\n\n          # Determine Ubuntu mirror setting\n          if [ \"${{ github.event_name }}\" = \"workflow_dispatch\" ] && [ 
-n \"${{ inputs.ubuntu_mirror }}\" ]; then\n            echo \"ubuntu_mirror=${{ inputs.ubuntu_mirror }}\" >> $GITHUB_OUTPUT\n          else\n            echo \"ubuntu_mirror=0\" >> $GITHUB_OUTPUT\n          fi\n\n      - name: Build and push Docker image\n        run: |\n          cd docker\n\n          # Build command arguments\n          BUILD_ARGS=(\n            --cuda-version \"${{ steps.params.outputs.cuda_version }}\"\n            --ubuntu-mirror \"${{ steps.params.outputs.ubuntu_mirror }}\"\n            --repository \"${{ env.DOCKERHUB_REPO }}\"\n          )\n\n          # Add simplified tag option if enabled\n          if [ \"${{ steps.params.outputs.push_simplified }}\" = \"true\" ]; then\n            BUILD_ARGS+=(--also-push-simplified)\n          fi\n\n          # Add HTTP proxy if available\n          if [ -n \"${{ secrets.HTTP_PROXY }}\" ]; then\n            BUILD_ARGS+=(--http-proxy \"${{ secrets.HTTP_PROXY }}\")\n          fi\n\n          # Add HTTPS proxy if available\n          if [ -n \"${{ secrets.HTTPS_PROXY }}\" ]; then\n            BUILD_ARGS+=(--https-proxy \"${{ secrets.HTTPS_PROXY }}\")\n          fi\n\n          # Dry run if not pushing\n          if [ \"${{ steps.params.outputs.should_push }}\" != \"true\" ]; then\n            BUILD_ARGS+=(--dry-run)\n          fi\n\n          # Execute build script\n          ./push-to-dockerhub.sh \"${BUILD_ARGS[@]}\"\n\n      - name: Display image information\n        if: steps.params.outputs.should_push == 'true'\n        run: |\n          echo \"::notice title=Docker Image::Image pushed successfully to ${{ env.DOCKERHUB_REPO }}\"\n          echo \"Pull command: docker pull ${{ env.DOCKERHUB_REPO }}:v\\$(VERSION)-cu\\$(CUDA_SHORT)\"\n"
  },
  {
    "path": ".github/workflows/kt-kernel-tests.yml",
    "content": "name: PR KT-Kernel Test\n\non:\n  pull_request:\n    branches:\n      - main\n      - develop\n    types: [synchronize, labeled]\n  workflow_dispatch:\n\nconcurrency:\n  group: pr-kt-kernel-test-${{ github.ref }}\n  cancel-in-progress: true\n\njobs:\n  # =============================================== check changes ====================================================\n  check-changes:\n    runs-on: ubuntu-latest\n    outputs:\n      kt_kernel: ${{ steps.filter.outputs.kt_kernel }}\n    steps:\n      - name: Checkout code\n        uses: actions/checkout@v4\n\n      - name: Fail if the PR does not have the 'run-ci' label\n        if: github.event_name == 'pull_request' && !contains(github.event.pull_request.labels.*.name, 'run-ci')\n        run: |\n          echo \"This pull request does not have the 'run-ci' label. Failing the workflow.\"\n          exit 1\n\n      - name: Fail if the PR is a draft\n        if: github.event_name == 'pull_request' && github.event.pull_request.draft == true\n        run: |\n          echo \"This pull request is a draft. 
Failing the workflow.\"\n          exit 1\n\n      - name: Detect file changes\n        id: filter\n        uses: dorny/paths-filter@v3\n        with:\n          filters: |\n            kt_kernel:\n              - \"kt-kernel/**\"\n              - \".github/workflows/kt-kernel-tests.yml\"\n\n  # =============================================== KT-Kernel tests ====================================================\n  per-commit-kt-kernel-cpu:\n    needs: [check-changes]\n    if: always() && !failure() && !cancelled() &&\n      (needs.check-changes.outputs.kt_kernel == 'true' || github.event_name == 'workflow_dispatch')\n    runs-on: kt-cpu\n    continue-on-error: false\n    steps:\n      - name: Cleanup\n        run: |\n          sudo rm -rf $GITHUB_WORKSPACE/* || true\n\n      - name: Checkout code\n        uses: actions/checkout@v4\n        with:\n          submodules: recursive\n\n      - name: Install KT-Kernel\n        run: |\n          cd kt-kernel\n          bash install.sh build\n\n      - name: Run KT-Kernel CPU tests\n        timeout-minutes: 60\n        run: |\n          cd kt-kernel/test\n          python3 run_suite.py --hw cpu --suite default\n\n  # =============================================== finish ====================================================\n  pr-test-kt-kernel-finish:\n    needs: [check-changes, per-commit-kt-kernel-cpu]\n    if: always()\n    runs-on: ubuntu-latest\n    steps:\n      - name: Check all dependent job statuses\n        run: |\n          # Convert the 'needs' context to a JSON string\n          json_needs='${{ toJson(needs) }}'\n\n          # Get a list of all job names from the JSON keys\n          job_names=$(echo \"$json_needs\" | jq -r 'keys_unsorted[]')\n\n          for job in $job_names; do\n            # For each job, extract its result\n            result=$(echo \"$json_needs\" | jq -r --arg j \"$job\" '.[$j].result')\n\n            # Print the job name and its result\n            echo \"$job: $result\"\n\n            
# Check for failure or cancellation and exit if found\n            if [[ \"$result\" == \"failure\" || \"$result\" == \"cancelled\" ]]; then\n              echo \"The above jobs failed.\"\n              exit 1\n            fi\n          done\n\n          # If the loop completes, all jobs were successful\n          echo \"All jobs completed successfully\"\n          exit 0\n"
  },
  {
    "path": ".github/workflows/release-fake-tag.yml",
    "content": "name: Release Fake Tag\n\non:\n  push:\n    branches:\n      - main\n    paths:\n      - \"version.py\"\n  workflow_dispatch:\n\npermissions:\n  contents: write\n\njobs:\n  publish:\n    if: github.repository == 'kvcache-ai/ktransformers'\n    runs-on: ubuntu-latest\n    environment: 'prod'\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n        with:\n          token: ${{ secrets.GITHUB_TOKEN }}\n\n      - name: Get version\n        id: get_version\n        run: |\n          version=$(cat version.py | grep '__version__' | cut -d'\"' -f2)\n          echo \"TAG=v$version\" >> $GITHUB_OUTPUT\n\n      - name: Create and push tag\n        run: |\n          git config user.name \"ktransformers-bot\"\n          git config user.email \"ktransformers-bot@users.noreply.github.com\"\n          git tag ${{ steps.get_version.outputs.TAG }}\n          git push origin ${{ steps.get_version.outputs.TAG }}\n"
  },
  {
    "path": ".github/workflows/release-pypi.yml",
    "content": "name: Release to PyPI\n\non:\n  push:\n    branches:\n      - main\n    paths:\n      - \"version.py\"\n  workflow_dispatch:\n    inputs:\n      test_pypi:\n        description: 'Publish to TestPyPI instead of PyPI (for testing)'\n        required: false\n        default: 'false'\n        type: choice\n        options:\n          - 'true'\n          - 'false'\n\npermissions:\n  contents: read\n\njobs:\n  # ── sglang-kt (must be on PyPI before users can pip install kt-kernel) ──\n  build-and-publish-sglang-kt:\n    name: Build & publish sglang-kt\n    runs-on: [self-hosted, linux, x64]\n    if: github.repository == 'kvcache-ai/ktransformers' && github.ref == 'refs/heads/main'\n    environment: prod\n    permissions:\n      id-token: write\n      contents: read\n\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n        with:\n          submodules: recursive\n\n      - name: Set up Python\n        uses: actions/setup-python@v4\n        with:\n          python-version: '3.12'\n\n      - name: Install build tools\n        run: |\n          python -m pip install --upgrade pip\n          pip install build wheel setuptools twine\n\n      - name: Build sglang-kt wheel\n        working-directory: third_party/sglang/python\n        run: |\n          KT_VERSION=$(python3 -c \"exec(open('${{ github.workspace }}/version.py').read()); print(__version__)\")\n          export SGLANG_KT_VERSION=\"$KT_VERSION\"\n          echo \"Building sglang-kt v${KT_VERSION} wheel...\"\n          python -m build --wheel -v\n          ls dist/ | grep -q \"sglang_kt\" || (echo \"ERROR: Wheel name does not contain sglang_kt\" && exit 1)\n\n      - name: Publish sglang-kt to PyPI\n        if: github.event.inputs.test_pypi != 'true'\n        env:\n          TWINE_USERNAME: __token__\n          TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}\n        run: |\n          python -m twine upload --skip-existing --verbose 
third_party/sglang/python/dist/*.whl\n\n      - name: Publish sglang-kt to TestPyPI (if requested)\n        if: github.event.inputs.test_pypi == 'true'\n        env:\n          TWINE_USERNAME: __token__\n          TWINE_PASSWORD: ${{ secrets.TEST_PYPI_API_TOKEN }}\n        run: |\n          python -m twine upload --repository testpypi --skip-existing --verbose third_party/sglang/python/dist/*.whl\n\n  # ── kt-kernel ──\n  build-kt-kernel:\n    name: Build kt-kernel (Python ${{ matrix.python-version }})\n    runs-on: [self-hosted, linux, x64, gpu]\n    strategy:\n      fail-fast: false\n      matrix:\n        python-version: ['3.11', '3.12']\n\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n        with:\n          submodules: recursive\n\n      - name: Set up Python ${{ matrix.python-version }}\n        uses: actions/setup-python@v4\n        with:\n          python-version: ${{ matrix.python-version }}\n\n      - name: Verify CUDA availability\n        run: |\n          nvidia-smi || (echo \"ERROR: GPU not available\" && exit 1)\n          nvcc --version || (echo \"ERROR: CUDA toolkit not found\" && exit 1)\n\n      - name: Install dependencies\n        run: |\n          apt-get update && apt-get install -y cmake libhwloc-dev pkg-config libnuma-dev\n          python -m pip install --upgrade pip\n          pip install build wheel setuptools torch --index-url https://download.pytorch.org/whl/cu118\n\n      - name: Build kt-kernel wheel\n        working-directory: kt-kernel\n        env:\n          CPUINFER_BUILD_ALL_VARIANTS: '1'\n          CPUINFER_USE_CUDA: '1'\n          CPUINFER_CUDA_ARCHS: '80;86;89;90'\n          CPUINFER_CUDA_STATIC_RUNTIME: '1'\n          CPUINFER_BUILD_TYPE: 'Release'\n          CPUINFER_PARALLEL: '4'\n          CPUINFER_FORCE_REBUILD: '1'\n          CUDA_HOME: '/usr/local/cuda-11.8'\n        run: |\n          echo \"Building kt-kernel with:\"\n          echo \"  - CUDA support (SM 80, 86, 89, 90)\"\n       
   echo \"  - CPU multi-variant (AMX, AVX512, AVX2)\"\n          python -m build --wheel -v\n\n      - name: Verify wheel\n        working-directory: kt-kernel\n        run: |\n          echo \"Generated wheel:\"\n          ls -lh dist/\n\n          # Install and test\n          pip install dist/*.whl\n          python -c \"import kt_kernel; print(f'✓ Version: {kt_kernel.__version__}')\"\n          python -c \"import kt_kernel; print(f'✓ CPU variant: {kt_kernel.__cpu_variant__}')\"\n\n          # Verify CUDA support\n          python -c \"\n          from kt_kernel import kt_kernel_ext\n          cpu_infer = kt_kernel_ext.CPUInfer(4)\n          methods = dir(cpu_infer)\n          has_cuda = 'submit_with_cuda_stream' in methods\n          print(f'✓ CUDA support: {has_cuda}')\n          \"\n\n          # Verify CPU multi-variant support\n          echo \"Checking CPU variants in wheel...\"\n          python -m zipfile -l dist/*.whl | grep \"_kt_kernel_ext_\" || echo \"Warning: No variant .so files found\"\n          python -m zipfile -l dist/*.whl | grep \"_kt_kernel_ext_amx.cpython\" && echo \"✓ AMX variant found\" || echo \"Note: AMX variant missing\"\n          python -m zipfile -l dist/*.whl | grep \"_kt_kernel_ext_avx512\" && echo \"✓ AVX512 variants found\" || echo \"Note: AVX512 variants missing\"\n          python -m zipfile -l dist/*.whl | grep \"_kt_kernel_ext_avx2.cpython\" && echo \"✓ AVX2 variant found\" || echo \"Note: AVX2 variant missing\"\n\n          # Verify static linking (should NOT depend on libcudart.so)\n          rm -rf /tmp/check\n          unzip -q dist/*.whl -d /tmp/check\n          if ldd /tmp/check/kt_kernel/*.so 2>/dev/null | grep -q \"libcudart.so\"; then\n            echo \"ERROR: Dynamic cudart found, should be statically linked\"\n            exit 1\n          else\n            echo \"✓ CUDA runtime statically linked\"\n          fi\n\n      - name: Repair wheel for manylinux\n        working-directory: kt-kernel\n        run: |\n   
       pip install auditwheel patchelf\n          mkdir -p wheelhouse\n          for wheel in dist/*.whl; do\n            auditwheel repair \"$wheel\" --plat manylinux_2_17_x86_64 --exclude libcuda.so.1 -w wheelhouse/ || \\\n              cp \"$wheel\" wheelhouse/$(basename \"$wheel\" | sed 's/linux_x86_64/manylinux_2_17_x86_64/')\n          done\n          rm -f dist/*.whl && cp wheelhouse/*.whl dist/\n\n      - name: Upload artifact\n        uses: actions/upload-artifact@v4\n        with:\n          name: kt-kernel-wheels-py${{ matrix.python-version }}\n          path: kt-kernel/dist/*.whl\n          retention-days: 7\n\n  publish-pypi:\n    name: Publish kt-kernel to PyPI\n    needs: [build-and-publish-sglang-kt, build-kt-kernel]\n    runs-on: [self-hosted, linux, x64]\n    if: github.repository == 'kvcache-ai/ktransformers' && github.ref == 'refs/heads/main'\n    environment: prod\n    permissions:\n      id-token: write  # For trusted publishing (OIDC)\n      contents: read\n\n    steps:\n      - name: Download all wheel artifacts\n        uses: actions/download-artifact@v4\n        with:\n          path: artifacts/\n\n      - name: Organize wheels into dist/\n        run: |\n          mkdir -p dist/\n          find artifacts/ -name \"*.whl\" -exec cp {} dist/ \\;\n          echo \"Wheels to publish:\"\n          ls -lh dist/\n\n      - name: Get version from wheel\n        id: get_version\n        run: |\n          # Extract version from first wheel filename\n          wheel_name=$(ls dist/*.whl | head -1 | xargs basename)\n          # Extract version (format: kt_kernel-X.Y.Z-...)\n          version=$(echo \"$wheel_name\" | sed 's/kt_kernel-\\([0-9.]*\\)-.*/\\1/')\n          echo \"VERSION=$version\" >> $GITHUB_OUTPUT\n          echo \"Publishing version: $version\"\n\n      - name: Install twine\n        run: |\n          python -m pip install --upgrade pip\n          pip install twine\n\n      - name: Publish to TestPyPI (if requested)\n        if: 
github.event.inputs.test_pypi == 'true'\n        env:\n          TWINE_USERNAME: __token__\n          TWINE_PASSWORD: ${{ secrets.TEST_PYPI_API_TOKEN }}\n        run: |\n          python -m twine upload \\\n            --repository testpypi \\\n            --skip-existing \\\n            --verbose \\\n            dist/*.whl\n\n      - name: Publish to PyPI\n        if: github.event.inputs.test_pypi != 'true'\n        env:\n          TWINE_USERNAME: __token__\n          TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}\n        run: |\n          python -m twine upload \\\n            --skip-existing \\\n            --verbose \\\n            dist/*.whl\n\n      - name: Create release summary\n        run: |\n          echo \"## 🎉 kt-kernel v${{ steps.get_version.outputs.VERSION }} Published to PyPI\" >> $GITHUB_STEP_SUMMARY\n          echo \"\" >> $GITHUB_STEP_SUMMARY\n          echo \"### Installation\" >> $GITHUB_STEP_SUMMARY\n          echo '```bash' >> $GITHUB_STEP_SUMMARY\n          echo \"pip install kt-kernel==${{ steps.get_version.outputs.VERSION }}\" >> $GITHUB_STEP_SUMMARY\n          echo '```' >> $GITHUB_STEP_SUMMARY\n          echo \"\" >> $GITHUB_STEP_SUMMARY\n          echo \"### Published Wheels\" >> $GITHUB_STEP_SUMMARY\n          echo \"Total: $(ls -1 dist/*.whl | wc -l) wheels (Python 3.10, 3.11, 3.12)\" >> $GITHUB_STEP_SUMMARY\n          echo \"\" >> $GITHUB_STEP_SUMMARY\n          echo \"### Features\" >> $GITHUB_STEP_SUMMARY\n          echo \"\" >> $GITHUB_STEP_SUMMARY\n          echo \"**CPU Multi-Variant Support:**\" >> $GITHUB_STEP_SUMMARY\n          echo \"- ✅ AMX (Intel Sapphire Rapids+, 2023)\" >> $GITHUB_STEP_SUMMARY\n          echo \"- ✅ AVX512 Base/VNNI/VBMI/BF16 (Intel Skylake-X/Ice Lake/Cascade Lake, 2017+)\" >> $GITHUB_STEP_SUMMARY\n          echo \"- ✅ AVX2 (Maximum compatibility, 2013+)\" >> $GITHUB_STEP_SUMMARY\n          echo \"- 🔧 Runtime CPU detection: Automatically selects optimal variant\" >> $GITHUB_STEP_SUMMARY\n          echo 
\"\" >> $GITHUB_STEP_SUMMARY\n          echo \"**CUDA Support:**\" >> $GITHUB_STEP_SUMMARY\n          echo \"- ✅ SM 80 (Ampere: A100)\" >> $GITHUB_STEP_SUMMARY\n          echo \"- ✅ SM 86 (Ampere: RTX 3060-3090)\" >> $GITHUB_STEP_SUMMARY\n          echo \"- ✅ SM 89 (Ada Lovelace: RTX 4000 series)\" >> $GITHUB_STEP_SUMMARY\n          echo \"- ✅ SM 90 (Hopper: H100)\" >> $GITHUB_STEP_SUMMARY\n          echo \"- 🔧 Static CUDA runtime: Compatible with CUDA 11.8+ and 12.x drivers\" >> $GITHUB_STEP_SUMMARY\n          echo \"- 🔧 Works on CPU-only systems (CUDA features disabled gracefully)\" >> $GITHUB_STEP_SUMMARY\n          echo \"\" >> $GITHUB_STEP_SUMMARY\n          echo \"**Requirements:**\" >> $GITHUB_STEP_SUMMARY\n          echo \"- Python 3.10, 3.11, or 3.12\" >> $GITHUB_STEP_SUMMARY\n          echo \"- Linux x86-64 (manylinux_2_17 compatible)\" >> $GITHUB_STEP_SUMMARY\n          echo \"- For CUDA features: NVIDIA driver with CUDA 11.8+ or 12.x support\" >> $GITHUB_STEP_SUMMARY\n          echo \"\" >> $GITHUB_STEP_SUMMARY\n          echo \"PyPI link: https://pypi.org/project/kt-kernel/${{ steps.get_version.outputs.VERSION }}/\" >> $GITHUB_STEP_SUMMARY\n"
  },
  {
    "path": ".github/workflows/release-sglang-kt.yml",
    "content": "name: Release sglang-kt to PyPI\n\non:\n  push:\n    branches:\n      - main\n    paths:\n      - \"third_party/sglang\"\n      - \"version.py\"\n  workflow_dispatch:\n    inputs:\n      test_pypi:\n        description: 'Publish to TestPyPI instead of PyPI (for testing)'\n        required: false\n        default: 'false'\n        type: choice\n        options:\n          - 'true'\n          - 'false'\n\npermissions:\n  contents: read\n\njobs:\n  build-sglang-kt:\n    name: Build sglang-kt wheel\n    runs-on: [self-hosted, linux, x64]\n\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n        with:\n          submodules: recursive\n\n      - name: Set up Python\n        uses: actions/setup-python@v4\n        with:\n          python-version: '3.12'\n\n      - name: Install build tools\n        run: |\n          python -m pip install --upgrade pip\n          pip install build wheel setuptools\n\n      - name: Build sglang-kt wheel\n        working-directory: third_party/sglang/python\n        run: |\n          # Read version from ktransformers version.py\n          KT_VERSION=$(python3 -c \"exec(open('${{ github.workspace }}/version.py').read()); print(__version__)\")\n          export SGLANG_KT_VERSION=\"$KT_VERSION\"\n          echo \"Building sglang-kt v${KT_VERSION} wheel...\"\n          python -m build --wheel -v\n\n      - name: Verify wheel\n        working-directory: third_party/sglang/python\n        run: |\n          echo \"Generated wheel:\"\n          ls -lh dist/\n          # Verify the wheel has the correct package name\n          ls dist/ | grep -q \"sglang_kt\" || (echo \"ERROR: Wheel name does not contain sglang_kt\" && exit 1)\n          echo \"Wheel name verified.\"\n\n      - name: Upload artifact\n        uses: actions/upload-artifact@v4\n        with:\n          name: sglang-kt-wheel\n          path: third_party/sglang/python/dist/*.whl\n          retention-days: 7\n\n  publish-pypi:\n    name: 
Publish sglang-kt to PyPI\n    needs: [build-sglang-kt]\n    runs-on: [self-hosted, linux, x64]\n    if: github.repository == 'kvcache-ai/ktransformers' && github.ref == 'refs/heads/main'\n    environment: prod\n    permissions:\n      id-token: write\n      contents: read\n\n    steps:\n      - name: Download wheel artifact\n        uses: actions/download-artifact@v4\n        with:\n          name: sglang-kt-wheel\n          path: dist/\n\n      - name: Display wheels\n        run: |\n          echo \"Wheels to publish:\"\n          ls -lh dist/\n\n      - name: Install twine\n        run: |\n          python -m pip install --upgrade pip\n          pip install twine\n\n      - name: Publish to TestPyPI (if requested)\n        if: github.event.inputs.test_pypi == 'true'\n        env:\n          TWINE_USERNAME: __token__\n          TWINE_PASSWORD: ${{ secrets.TEST_PYPI_API_TOKEN }}\n        run: |\n          python -m twine upload \\\n            --repository testpypi \\\n            --skip-existing \\\n            --verbose \\\n            dist/*.whl\n\n      - name: Publish to PyPI\n        if: github.event.inputs.test_pypi != 'true'\n        env:\n          TWINE_USERNAME: __token__\n          TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}\n        run: |\n          python -m twine upload \\\n            --skip-existing \\\n            --verbose \\\n            dist/*.whl\n\n      - name: Create release summary\n        run: |\n          echo \"## sglang-kt Published to PyPI\" >> $GITHUB_STEP_SUMMARY\n          echo \"\" >> $GITHUB_STEP_SUMMARY\n          echo \"### Installation\" >> $GITHUB_STEP_SUMMARY\n          echo '```bash' >> $GITHUB_STEP_SUMMARY\n          echo \"pip install sglang-kt\" >> $GITHUB_STEP_SUMMARY\n          echo '```' >> $GITHUB_STEP_SUMMARY\n          echo \"\" >> $GITHUB_STEP_SUMMARY\n          echo \"This is the kvcache-ai fork of SGLang with kt-kernel support.\" >> $GITHUB_STEP_SUMMARY\n          echo \"PyPI link: 
https://pypi.org/project/sglang-kt/\" >> $GITHUB_STEP_SUMMARY\n"
  },
  {
    "path": ".github/workflows/sync-sglang-submodule.yml",
    "content": "name: Sync sglang submodule\n\non:\n  schedule:\n    # Run daily at 08:00 UTC\n    - cron: \"0 8 * * *\"\n  workflow_dispatch:\n\npermissions:\n  contents: write\n  pull-requests: write\n\njobs:\n  sync:\n    name: Check for sglang-kt updates\n    runs-on: ubuntu-latest\n\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v4\n        with:\n          submodules: true\n          fetch-depth: 0\n          token: ${{ secrets.GITHUB_TOKEN }}\n\n      - name: Update sglang submodule to latest main\n        id: update\n        run: |\n          OLD_SHA=$(git -C third_party/sglang rev-parse HEAD)\n          git submodule update --remote third_party/sglang\n          NEW_SHA=$(git -C third_party/sglang rev-parse HEAD)\n\n          echo \"old_sha=$OLD_SHA\" >> \"$GITHUB_OUTPUT\"\n          echo \"new_sha=$NEW_SHA\" >> \"$GITHUB_OUTPUT\"\n\n          if [ \"$OLD_SHA\" = \"$NEW_SHA\" ]; then\n            echo \"changed=false\" >> \"$GITHUB_OUTPUT\"\n            echo \"sglang submodule is already up to date ($OLD_SHA)\"\n          else\n            echo \"changed=true\" >> \"$GITHUB_OUTPUT\"\n\n            # Collect commit log between old and new\n            COMMITS=$(git -C third_party/sglang log --oneline \"$OLD_SHA..$NEW_SHA\" | head -20)\n            echo \"commits<<EOF\" >> \"$GITHUB_OUTPUT\"\n            echo \"$COMMITS\" >> \"$GITHUB_OUTPUT\"\n            echo \"EOF\" >> \"$GITHUB_OUTPUT\"\n\n            # sglang-kt version = ktransformers version (from version.py)\n            VERSION=$(python3 -c \"exec(open('version.py').read()); print(__version__)\" 2>/dev/null || echo \"unknown\")\n            echo \"version=$VERSION\" >> \"$GITHUB_OUTPUT\"\n\n            echo \"sglang submodule updated: $OLD_SHA -> $NEW_SHA (v$VERSION)\"\n          fi\n\n      - name: Create pull request\n        if: steps.update.outputs.changed == 'true'\n        uses: peter-evans/create-pull-request@v6\n        with:\n          token: ${{ 
secrets.GITHUB_TOKEN }}\n          commit-message: |\n            [build]: sync sglang submodule to ${{ steps.update.outputs.new_sha }}\n          branch: auto/sync-sglang\n          delete-branch: true\n          title: \"[build] Sync sglang-kt submodule (v${{ steps.update.outputs.version }})\"\n          body: |\n            Automated sync of `third_party/sglang` submodule to latest `main`.\n\n            **Old ref:** `${{ steps.update.outputs.old_sha }}`\n            **New ref:** `${{ steps.update.outputs.new_sha }}`\n            **sglang-kt version:** `${{ steps.update.outputs.version }}`\n\n            ### Commits included\n            ```\n            ${{ steps.update.outputs.commits }}\n            ```\n\n            ---\n            *This PR was created automatically by the [sync-sglang-submodule](${{ github.server_url }}/${{ github.repository }}/actions/workflows/sync-sglang-submodule.yml) workflow.*\n          labels: |\n            dependencies\n            automated\n"
  },
  {
    "path": ".gitignore",
    "content": "__pycache__\nbuild\n.vscode\n*.so\n*.cache\nserver.db\nlogs\nnode_modules\n*.nsys-rep\n.vs/\n*pycache*\n*build/\n.DS_Store\ncompile_commands.json\n*.egg-info*\n*dist/\nktransformers/server/local_store/\nktransformers/server_test1.db\n*.patch\nimg/\ntmp*.txt\ntest.txt\nbook\nktransformers/tests/chat_txt.txt\nmmlu_result*\nktransformers/ktransformers_ext/cuda_musa/\ntest_prompt.txt\ncsrc/demo\nbuild*\nCMakeFiles/\nkvc2/\nsched/\n*.png"
  },
  {
    "path": ".gitmodules",
    "content": "[submodule \"third_party/llama.cpp\"]\n\tpath = third_party/llama.cpp\n\turl = https://github.com/ggerganov/llama.cpp.git\n[submodule \"third_party/pybind11\"]\n\tpath = third_party/pybind11\n\turl = https://github.com/pybind/pybind11.git\n[submodule \"third_party/custom_flashinfer\"]\n\tpath = third_party/custom_flashinfer\n\turl = https://github.com/kvcache-ai/custom_flashinfer.git\n\tbranch = fix-precision-mla-merge-main\n[submodule \"third_party/sglang\"]\n\tpath = third_party/sglang\n\turl = https://github.com/kvcache-ai/sglang.git\n\tbranch = main\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "MAINTAINERS.md",
    "content": "# Maintainers\n\nThis document lists the current maintainers and outlines their responsibilities.\n\n## Current Maintainers\n\n| Name | GitHub | Role | Affiliation | Email |\n|------|--------|------|-------------|-------|\n| Weiyu Xie | [@ErvinXie](https://github.com/ErvinXie) | Maintainer | [MADSys Lab](https://madsys.cs.tsinghua.edu.cn/) @ Tsinghua University | xwy21@mails.tsinghua.edu.cn |\n| Hongtao Chen | [@chenht2022](https://github.com/chenht2022) | Maintainer | [MADSys Lab](https://madsys.cs.tsinghua.edu.cn/) @ Tsinghua University | cht22@mails.tsinghua.edu.cn |\n| Jianwei Dong | [@ovowei](https://github.com/ovowei) | Maintainer | [MADSys Lab](https://madsys.cs.tsinghua.edu.cn/) @ Tsinghua University | dongjw24@mails.tsinghua.edu.cn |\n| Ziwei Yuan | [@KMSorSMS](https://github.com/KMSorSMS) | Maintainer | [Approaching.AI](http://approaching.ai/) | 2022090910005@std.uestc.edu.cn |\n| Qingliang Ou | [@ouqingliang](https://github.com/ouqingliang) | Maintainer | [MADSys Lab](https://madsys.cs.tsinghua.edu.cn/) @ Tsinghua University | oql@bupt.edu.cn |\n| Jiaqi Liao | [@SkqLiao](https://github.com/SkqLiao) | Maintainer | [Approaching.AI](http://approaching.ai/) | jiaqi.liao@bit.edu.cn |\n| Peilin Li | [@JimmyPeilinLi](https://github.com/JimmyPeilinLi) | Maintainer | [Approaching.AI](http://approaching.ai/) | lipeilin@mail.nwpu.edu.cn |\n| Xingxing Hao | [@mrhaoxx](https://github.com/mrhaoxx) | Maintainer | [Approaching.AI](http://approaching.ai/) | mr.haoxx@gmail.com |\n| Boxin Zhang | [@Atream](https://github.com/Atream) | Maintainer | [MADSys Lab](https://madsys.cs.tsinghua.edu.cn/) @ Tsinghua University | zhangbx24@mails.tsinghua.edu.cn |\n| Jingqi Tang | [@Azure-Tang](https://github.com/Azure-Tang) | Maintainer | [MADSys Lab](https://madsys.cs.tsinghua.edu.cn/) @ Tsinghua University | tangjq25@mails.tsinghua.edu.cn |\n| Jiahao Wang | [@qiyuxinlin](https://github.com/qiyuxinlin) | Maintainer | [Approaching.AI](http://approaching.ai/) | 
202241050020@hdu.edu.cn |\n\n## Responsibilities\n\nMaintainers steward the project and keep it healthy for users and contributors.\n\n- Review and approve pull requests; ensure changes meet quality, testing, and documentation standards.\n- Triage issues, keep labels organized, and respond to questions in a timely manner.\n- Uphold the project’s code of conduct and report violations when needed.\n- Maintain CI reliability and address regressions promptly.\n- Oversee releases and keep compatibility with supported dependency versions.\n- Protect project security and follow the security disclosure process.\n\n## Becoming a Maintainer\n\nWe welcome contributors who show sustained, high-quality contributions and collaborative behavior. If you are interested, please contact an existing maintainer and share your recent contributions and areas of focus.\n"
  },
  {
    "path": "README.md",
    "content": "<div align=\"center\">\n  <p align=\"center\">\n\n<picture>\n    <img alt=\"KTransformers\" src=\"https://github.com/user-attachments/assets/d5a2492f-a415-4456-af99-4ab102f13f8b\" width=50%>\n\n</picture>\n\n</p>\n  <h3>A Flexible Framework for Experiencing Cutting-edge LLM Inference/Fine-tune Optimizations</h3>\n  <strong><a href=\"#-overview\">🎯 Overview</a> | <a href=\"#-kt-kernel---high-performance-inference-kernels\">🚀 kt-kernel</a> | <a href=\"#-kt-sft---fine-tuning-framework\">🎓 kt-sft</a> | <a href=\"#-citation\">🔥 Citation</a> | <a href=\"https://github.com/kvcache-ai/ktransformers/issues/1582\">🚀 Roadmap(2025Q4)</a>  </strong>\n</div>\n\n## 🎯 Overview\n\nKTransformers is a research project focused on efficient inference and fine-tuning of large language models through CPU-GPU heterogeneous computing. The project has evolved into **two core modules**: [kt-kernel](https://github.com/kvcache-ai/ktransformers/tree/main/kt-kernel/) and [kt-sft](https://github.com/kvcache-ai/ktransformers/tree/main/kt-sft).\n\n## 🔥 Updates\n\n* **Feb 13, 2026**: MiniMax-M2.5 Day0 Support! ([Tutorial](./doc/en/MiniMax-M2.5.md))\n* **Feb 12, 2026**: GLM-5 Day0 Support! ([Tutorial](./doc/en/kt-kernel/GLM-5-Tutorial.md))\n* **Jan 27, 2026**: Kimi-K2.5 Day0 Support! ([Tutorial](./doc/en/Kimi-K2.5.md)) ([SFT Tutorial](./doc/en/SFT_Installation_Guide_KimiK2.5.md))\n* **Jan 22, 2026**: Support [CPU-GPU Expert Scheduling](./doc/en/kt-kernel/experts-sched-Tutorial.md), [Native BF16 and FP8 per channel Precision](./doc/en/kt-kernel/Native-Precision-Tutorial.md) and [AutoDL unified fine-tuning and inference](./doc/zh/【云端低价训推】%20KTransformers%2BAutoDL%2BLlamaFactory：随用随租的低成本超大模型「微调%2B推理」一体化流程.pdf)\n* **Dec 24, 2025**: Support Native MiniMax-M2.1 inference. ([Tutorial](./doc/en/kt-kernel/MiniMax-M2.1-Tutorial.md))\n* **Dec 22, 2025**: Support RL-DPO fine-tuning with LLaMA-Factory. 
([Tutorial](./doc/en/SFT/DPO_tutorial.md))\n* **Dec 5, 2025**: Support Native Kimi-K2-Thinking inference ([Tutorial](./doc/en/kt-kernel/Kimi-K2-Thinking-Native.md))\n* **Nov 6, 2025**: Support Kimi-K2-Thinking inference ([Tutorial](./doc/en/Kimi-K2-Thinking.md)) and fine-tune ([Tutorial](./doc/en/SFT_Installation_Guide_KimiK2.md))\n* **Nov 4, 2025**: KTransformers Fine-Tuning × LLaMA-Factory Integration. ([Tutorial](./doc/en/KTransformers-Fine-Tuning_User-Guide.md))\n* **Oct 27, 2025**: Support Ascend NPU. ([Tutorial](./doc/zh/DeepseekR1_V3_tutorial_zh_for_Ascend_NPU.md))\n* **Oct 10, 2025**: Integrating into SGLang. ([Roadmap](https://github.com/sgl-project/sglang/issues/11425), [Blog](https://lmsys.org/blog/2025-10-22-KTransformers/))\n* **Sept 11, 2025**: Support Qwen3-Next. ([Tutorial](./doc/en/Qwen3-Next.md))\n* **Sept 05, 2025**: Support Kimi-K2-0905. ([Tutorial](./doc/en/Kimi-K2.md))\n* **July 26, 2025**: Support SmallThinker and GLM4-MoE. ([Tutorial](./doc/en/SmallThinker_and_Glm4moe.md))\n* **July 11, 2025**: Support Kimi-K2. ([Tutorial](./doc/en/Kimi-K2.md))\n* **June 30, 2025**: Support 3-layer (GPU-CPU-Disk) [prefix cache](./doc/en/prefix_cache.md) reuse.\n* **May 14, 2025**: Support Intel Arc GPU ([Tutorial](./doc/en/xpu.md)).\n* **Apr 29, 2025**: Support AMX-Int8, AMX-BF16 and Qwen3MoE ([Tutorial](./doc/en/AMX.md))\n* **Apr 9, 2025**: Experimental support for LLaMA 4 models ([Tutorial](./doc/en/llama4.md)).\n* **Apr 2, 2025**: Support multi-concurrency ([Tutorial](./doc/en/balance-serve.md)).\n* **Mar 15, 2025**: Support ROCm on AMD GPU ([Tutorial](./doc/en/ROCm.md)).\n* **Mar 5, 2025**: Support unsloth 1.58/2.51-bit weights and [IQ1_S/FP8 hybrid](./doc/en/fp8_kernel.md) weights. 
Support 139K [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022--v023-longer-context--fp8-kernel) for DeepSeek-V3 and R1 in 24GB VRAM.\n* **Feb 25, 2025**: Support [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3 and R1; [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context).\n* **Feb 15, 2025**: Longer Context (from 4K to 8K for 24GB VRAM) & Slightly Faster Speed (+15%, up to 16 tokens/s), update [docs](./doc/en/DeepseekR1_V3_tutorial.md) and [online books](https://kvcache-ai.github.io/ktransformers/).\n* **Feb 10, 2025**: Support DeepSeek-R1 and V3 on single (24GB VRAM)/multi-GPU and 382G DRAM, up to 3~28x speedup. For a detailed showcase and reproduction tutorial, see [here](./doc/en/DeepseekR1_V3_tutorial.md).\n* **Aug 28, 2024**: Decrease DeepseekV2's required VRAM from 21G to 11G.\n* **Aug 15, 2024**: Update detailed [tutorial](doc/en/injection_tutorial.md) for injection and multi-GPU.\n* **Aug 14, 2024**: Support llamafile as linear backend.\n* **Aug 12, 2024**: Support multiple GPUs; support new models: Mixtral 8\\*7B and 8\\*22B; support q2k, q3k, q5k dequant on GPU.\n* **Aug 9, 2024**: Support native Windows.\n\n---\n\n## 📦 Core Modules\n\n### 🚀 [kt-kernel](./kt-kernel/) - High-Performance Inference Kernels\n\nCPU-optimized kernel operations for heterogeneous LLM inference.\n\n<img width=\"1049\" height=\"593\" alt=\"image\" src=\"https://github.com/user-attachments/assets/68f423da-3f55-4025-bdc9-9ceaa554f00b\" />\n\n\n**Key Features:**\n- **AMX/AVX Acceleration**: Intel AMX and AVX512/AVX2 optimized kernels for INT4/INT8 quantized inference\n- **MoE Optimization**: Efficient Mixture-of-Experts inference with NUMA-aware memory management\n- **Quantization Support**: CPU-side INT4/INT8 quantized weights, GPU-side GPTQ support\n- **Easy Integration**: Clean Python API for SGLang and other frameworks\n\n**Quick Start:**\n```bash\ncd kt-kernel\npip install .\n```\n\n**Use Cases:**\n\n- CPU-GPU hybrid inference for large MoE 
models\n- Integration with SGLang for production serving\n- Heterogeneous expert placement (hot experts on GPU, cold experts on CPU)\n\n**Performance Examples:**\n| Model | Hardware Configuration | Total Throughput | Output Throughput |\n|-------|------------------------|------------------|-------------------|\n| DeepSeek-R1-0528 (FP8) | 8×L20 GPU + Xeon Gold 6454S | 227.85 tokens/s | 87.58 tokens/s (8-way concurrency) |\n\n👉 **[Full Documentation →](./kt-kernel/README.md)**\n\n---\n\n### 🎓 [kt-sft](./kt-sft/) - Fine-Tuning Framework\n\nKTransformers × LLaMA-Factory integration for ultra-large MoE model fine-tuning.\n\n![image-20251011010558909](https://raw.githubusercontent.com/kvcache-ai/ktransformers/main/doc/assets/image-20251011010558909.png)\n\n**Key Features:**\n\n- **Resource Efficient**: Fine-tune 671B DeepSeek-V3 with just **70GB GPU memory** + 1.3TB RAM\n- **LoRA Support**: Full LoRA fine-tuning with heterogeneous acceleration\n- **LLaMA-Factory Integration**: Seamless integration with popular fine-tuning framework\n- **Production Ready**: Chat, batch inference, and metrics evaluation\n\n**Performance Examples:**\n\n| Model | Configuration | Throughput | GPU Memory |\n|-------|--------------|------------|------------|\n| DeepSeek-V3 (671B) | LoRA + AMX | ~40 tokens/s | 70GB (multi-GPU) |\n| DeepSeek-V2-Lite (14B) | LoRA + AMX | ~530 tokens/s | 6GB |\n\n**Quick Start:**\n```bash\ncd kt-sft\n# Install environment following kt-sft/README.md\nUSE_KT=1 llamafactory-cli train examples/train_lora/deepseek3_lora_sft_kt.yaml\n```\n\n👉 **[Full Documentation →](./kt-sft/README.md)**\n\n---\n\n## 🔥 Citation\n\nIf you use KTransformers in your research, please cite our paper:\n\n```bibtex\n@inproceedings{10.1145/3731569.3764843,\n  title = {KTransformers: Unleashing the Full Potential of CPU/GPU Hybrid Inference for MoE Models},\n  author = {Chen, Hongtao and Xie, Weiyu and Zhang, Boxin and Tang, Jingqi and Wang, Jiahao and Dong, Jianwei and Chen, Shaoyuan and Yuan, 
Ziwei and Lin, Chen and Qiu, Chengyu and Zhu, Yuening and Ou, Qingliang and Liao, Jiaqi and Chen, Xianglin and Ai, Zhiyuan and Wu, Yongwei and Zhang, Mingxing},\n  booktitle = {Proceedings of the ACM SIGOPS 31st Symposium on Operating Systems Principles},\n  year = {2025}\n}\n```\n\n## 👥 Contributors & Team\n\nDeveloped and maintained by:\n- [MADSys Lab](https://madsys.cs.tsinghua.edu.cn/) @ Tsinghua University\n- [Approaching.AI](http://approaching.ai/)\n- [9#AISoft](https://github.com/aisoft9)\n- Community contributors\n\nWe welcome contributions! Please feel free to submit issues and pull requests.\n\n## 💬 Community & Support\n\n- **GitHub Issues**: [Report bugs or request features](https://github.com/kvcache-ai/ktransformers/issues)\n- **WeChat Group**: See [archive/WeChatGroup.png](./archive/WeChatGroup.png)\n\n## 📦 Original KT Code\n\nThe original integrated KTransformers framework has been archived to the [`archive/`](./archive/) directory for reference. The project now focuses on the two core modules above for better modularity and maintainability.\n\nFor the original documentation with full quick-start guides and examples, see:\n- [archive/README.md](./archive/README.md) (English)\n- [archive/README_ZH.md](./archive/README_ZH.md) (中文)\n"
  },
  {
    "path": "README_ZH.md",
    "content": "<div align=\"center\">\n  <p align=\"center\">\n\n<picture>\n    <img alt=\"KTransformers\" src=\"https://github.com/user-attachments/assets/d5a2492f-a415-4456-af99-4ab102f13f8b\" width=50%>\n\n</picture>\n\n</p>\n  <h3>一个用于体验尖端 LLM 推理/微调优化的灵活框架</h3>\n  <strong><a href=\"#-概览\">🎯 概览</a> | <a href=\"#-kt-kernel---高性能推理内核\">🚀 kt-kernel</a> | <a href=\"#-kt-sft---微调框架\">🎓 kt-sft</a> | <a href=\"#-引用\">🔥 引用</a> </strong>\n</div>\n\n## 🎯 概览\n\nKTransformers 是一个专注于通过 CPU-GPU 异构计算实现大语言模型高效推理和微调的研究项目。该项目已发展为**两个核心模块**：[kt-kernel](./kt-kernel/) 和 [kt-sft](./kt-sft/)。\n\n## 🔥 更新\n\n* **2025 年 12 月 5 日**：支持原生 Kimi-K2-Thinking 推理（[教程](./doc/en/Kimi-K2-Thinking-Native.md)）\n* **2025 年 11 月 6 日**：支持 Kimi-K2-Thinking 推理（[教程](./doc/en/Kimi-K2-Thinking.md)）和微调（[教程](./doc/en/SFT_Installation_Guide_KimiK2.md)）\n* **2025 年 11 月 4 日**：KTransformers 微调 × LLaMA-Factory 集成（[教程](./doc/en/KTransformers-Fine-Tuning_User-Guide.md)）\n* **2025 年 10 月 27 日**：支持昇腾 NPU（[教程](./doc/zh/DeepseekR1_V3_tutorial_zh_for_Ascend_NPU.md)）\n* **2025 年 10 月 10 日**：集成到 SGLang（[路线图](https://github.com/sgl-project/sglang/issues/11425)，[博客](https://lmsys.org/blog/2025-10-22-KTransformers/)）\n* **2025 年 9 月 11 日**：支持 Qwen3-Next（[教程](./doc/en/Qwen3-Next.md)）\n* **2025 年 9 月 5 日**：支持 Kimi-K2-0905（[教程](./doc/en/Kimi-K2.md)）\n* **2025 年 7 月 26 日**：支持 SmallThinker 和 GLM4-MoE（[教程](./doc/en/SmallThinker_and_Glm4moe.md)）\n* **2025 年 7 月 11 日**：支持 Kimi-K2（[教程](./doc/en/Kimi-K2.md)）\n* **2025 年 6 月 30 日**：支持 3 层（GPU-CPU-磁盘）[前缀缓存](./doc/en/prefix_cache.md)复用\n* **2025 年 5 月 14 日**：支持 Intel Arc GPU（[教程](./doc/en/xpu.md)）\n* **2025 年 4 月 29 日**：支持 AMX-Int8、AMX-BF16 和 Qwen3MoE（[教程](./doc/en/AMX.md)）\n* **2025 年 4 月 9 日**：实验性支持 LLaMA 4 模型（[教程](./doc/en/llama4.md)）\n* **2025 年 4 月 2 日**：支持多并发（[教程](./doc/en/balance-serve.md)）\n* **2025 年 3 月 15 日**：支持 AMD GPU 上的 ROCm（[教程](./doc/en/ROCm.md)）\n* **2025 年 3 月 5 日**：支持 unsloth 1.58/2.51 位权重和 [IQ1_S/FP8 混合](./doc/en/fp8_kernel.md)权重。在 24GB VRAM 中支持 DeepSeek-V3 和 R1 的 
139K [更长上下文](./doc/en/DeepseekR1_V3_tutorial.md#v022--v023-longer-context--fp8-kernel)\n* **2025 年 2 月 25 日**：为 DeepSeek-V3 和 R1 支持 [FP8 GPU 内核](./doc/en/fp8_kernel.md)；[更长上下文](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context)\n* **2025 年 2 月 15 日**：更长上下文（24GB VRAM 从 4K 到 8K）& 速度稍快（+15%，最高 16 Tokens/s），更新[文档](./doc/en/DeepseekR1_V3_tutorial.md)和[在线手册](https://kvcache-ai.github.io/ktransformers/)\n* **2025 年 2 月 10 日**：支持 Deepseek-R1 和 V3 在单 GPU（24GB VRAM）/多 GPU 和 382GB DRAM 上运行，速度提升高达 3~28 倍。详细案例展示和复现教程请参见[这里](./doc/en/DeepseekR1_V3_tutorial.md)\n* **2024 年 8 月 28 日**：将 DeepseekV2 所需的 VRAM 从 21GB 降低到 11GB\n* **2024 年 8 月 15 日**：更新了关于注入和多 GPU 的详细[教程](doc/en/injection_tutorial.md)\n* **2024 年 8 月 14 日**：支持 llamfile 作为线性后端\n* **2024 年 8 月 12 日**：支持多 GPU；支持新模型：mixtral 8\\*7B 和 8\\*22B；支持 GPU 上的 q2k、q3k、q5k 去量化\n* **2024 年 8 月 9 日**：支持 Windows 原生环境\n\n---\n\n## 📦 核心模块\n\n### 🚀 [kt-kernel](./kt-kernel/) - 高性能推理内核\n\n用于异构 LLM 推理的 CPU 优化内核操作。\n\n![image-20251011010558909](./doc/assets/heterogeneous_computing.png)\n\n**主要特性：**\n- **AMX/AVX 加速**：Intel AMX 和 AVX512/AVX2 优化的内核，用于 INT4/INT8 量化推理\n- **MoE 优化**：高效的专家混合推理，具有 NUMA 感知内存管理\n- **量化支持**：CPU 端 INT4/INT8 量化权重，GPU 端 GPTQ 支持\n- **易于集成**：为 SGLang 和其他框架提供简洁的 Python API\n\n**快速开始：**\n```bash\ncd kt-kernel\npip install .\n```\n\n**使用场景：**\n\n- 大型 MoE 模型的 CPU-GPU 混合推理\n- 与 SGLang 集成用于生产服务\n- 异构专家放置（热专家在 GPU 上，冷专家在 CPU 上）\n\n**性能示例：**\n| 模型 | 硬件配置 | 总吞吐量 | 输出吞吐量 |\n|-------|------------------------|------------------|-------------------|\n| DeepSeek-R1-0528 (FP8) | 8×L20 GPU + Xeon Gold 6454S | 227.85 tokens/s | 87.58 tokens/s（8 路并发）|\n\n👉 **[完整文档 →](./kt-kernel/README.md)**\n\n---\n\n### 🎓 [kt-sft](./kt-sft/) - 微调框架\n\nKTransformers × LLaMA-Factory 集成，用于超大型 MoE 模型微调。\n\n![image-20251011010558909](./doc/assets/image-20251011010558909.png)\n\n**主要特性：**\n\n- **资源高效**：仅需 **70GB GPU 显存** + 1.3TB 内存即可微调 671B DeepSeek-V3\n- **LoRA 支持**：完整的 LoRA 微调，带有异构加速\n- **LLaMA-Factory 集成**：与流行的微调框架无缝集成\n- 
**生产就绪**：聊天、批量推理和指标评估\n\n**性能示例：**\n\n| 模型 | 配置 | 吞吐量 | GPU 显存 |\n|-------|--------------|------------|--------------|\n| DeepSeek-V3 (671B) | LoRA + AMX | ~40 tokens/s | 70GB（多 GPU）|\n| DeepSeek-V2-Lite (14B) | LoRA + AMX | ~530 tokens/s | 6GB |\n\n**快速开始：**\n```bash\ncd kt-sft\n# 按照 kt-sft/README.md 安装环境\nUSE_KT=1 llamafactory-cli train examples/train_lora/deepseek3_lora_sft_kt.yaml\n```\n\n👉 **[完整文档 →](./kt-sft/README.md)**\n\n---\n\n## 🔥 引用\n\n如果您在研究中使用了 KTransformers，请引用我们的论文：\n\n```bibtex\n@inproceedings{10.1145/3731569.3764843,\n  title = {KTransformers: Unleashing the Full Potential of CPU/GPU Hybrid Inference for MoE Models},\n  author = {Chen, Hongtao and Xie, Weiyu and Zhang, Boxin and Tang, Jingqi and Wang, Jiahao and Dong, Jianwei and Chen, Shaoyuan and Yuan, Ziwei and Lin, Chen and Qiu, Chengyu and Zhu, Yuening and Ou, Qingliang and Liao, Jiaqi and Chen, Xianglin and Ai, Zhiyuan and Wu, Yongwei and Zhang, Mingxing},\n  booktitle = {Proceedings of the ACM SIGOPS 31st Symposium on Operating Systems Principles},\n  year = {2025}\n}\n```\n\n## 👥 贡献者与团队\n\n由以下团队开发和维护：\n- 清华大学 [MADSys 实验室](https://madsys.cs.tsinghua.edu.cn/)\n- [Approaching.AI](http://approaching.ai/)\n- 社区贡献者\n\n我们欢迎贡献！请随时提交问题和拉取请求。\n\n## 💬 社区与支持\n\n- **GitHub Issues**：[报告问题或请求功能](https://github.com/kvcache-ai/ktransformers/issues)\n- **微信群**：请参见 [archive/WeChatGroup.png](./archive/WeChatGroup.png)\n\n## 📦 KT原仓库\n\n原始的集成 KTransformers 框架已归档到 [`archive/`](./archive/) 目录以供参考。该项目现在专注于上述两个核心模块，以获得更好的模块化和可维护性。\n\n有关原始文档以及完整的快速入门指南和示例，请参见：\n- [archive/README.md](./archive/README.md)（英文）\n- [archive/README_ZH.md](./archive/README_ZH.md)（中文）\n"
  },
  {
    "path": "archive/.devcontainer/Dockerfile",
    "content": "FROM pytorch/pytorch:2.5.1-cuda12.1-cudnn9-devel as compile_server\nWORKDIR /workspace\nENV CUDA_HOME /usr/local/cuda\nRUN <<EOF\napt update -y &&  apt install -y  --no-install-recommends \\\n    git \\\n    wget \\\n    vim \\\n    gcc \\\n    g++ \\\n    cmake && \nrm -rf /var/lib/apt/lists/* &&\npip install --upgrade pip &&\npip install ninja pyproject numpy cpufeature &&\npip install flash-attn &&\ncp /usr/lib/x86_64-linux-gnu/libstdc++.so.6 /opt/conda/lib/\nEOF\n# Set the default shell to bash\nCMD [\"/bin/bash\"]\n"
  },
  {
    "path": "archive/.devcontainer/devcontainer.json",
    "content": "{\n    \"name\": \"Ktrans Dev Container\",\n    \"privileged\": true,\n    \"build\": {\n        \"dockerfile\": \"Dockerfile\",\n        \"context\": \"..\",\n        \"args\": {\n            \"http_proxy\": \"${env:http_proxy}\",\n            \"https_proxy\": \"${env:https_proxy}\",\n        }\n    },\n    \"runArgs\": [\n        \"--network=host\",\n        \"--gpus\",\n        \"all\"\n        // \"--gpu all\"\n    ],\n    \"workspaceFolder\": \"/workspace\",\n    \"workspaceMount\": \"source=${localWorkspaceFolder},target=/workspace,type=bind,consistency=cached\",\n    \"mounts\": [\n        \"source=/mnt/data,target=/mnt/incontainer,type=bind,consistency=cached\"\n    ],\n    \"customizations\": {\n        \"vscode\": {\n            \"extensions\": [\n            ],\n            \"settings\": {\n                \"terminal.integrated.shell.linux\": \"/bin/bash\",\n                \"cmake.configureOnOpen\": true,\n                \"cmake.generator\": \"Ninja\"\n            }\n        }\n    }\n}"
  },
  {
    "path": "archive/.flake8",
    "content": "[flake8]\nmax-line-length = 120\nextend-select = B950\nextend-ignore = E203,E501,E701, B001,B006,B007,B008,B009,B010,B011,B016,B028,B031,B950,E265,E266,E401,E402,E711,E712,E713,E721,E722,E731,F401,F403,F405,F541,F811,F821,F841,W391"
  },
  {
    "path": "archive/.gitmodules",
    "content": "[submodule \"third_party/llama.cpp\"]\n\tpath = archive/third_party/llama.cpp\n\turl = https://github.com/ggerganov/llama.cpp.git\n[submodule \"third_party/pybind11\"]\n\tpath = archive/third_party/pybind11\n\turl = https://github.com/pybind/pybind11.git\n[submodule \"third_party/spdlog\"]\n\tpath = archive/third_party/spdlog\n\turl = https://github.com/gabime/spdlog.git\n[submodule \"third_party/custom_flashinfer\"]\n\tpath = archive/third_party/custom_flashinfer\n\turl = https://github.com/kvcache-ai/custom_flashinfer.git\n\tbranch = fix-precision-mla-merge-main\n[submodule \"third_party/xxHash\"]\n\tpath = archive/third_party/xxHash\n\turl = https://github.com/Cyan4973/xxHash.git\n[submodule \"third_party/prometheus-cpp\"]\n\tpath = archive/third_party/prometheus-cpp\n\turl = https://github.com/jupp0r/prometheus-cpp\n[submodule \"third_party/PhotonLibOS\"]\n\tpath = archive/third_party/PhotonLibOS\n\turl = https://github.com/alibaba/PhotonLibOS.git\n[submodule \"kt-kernel/third_party/llama.cpp\"]\n\tpath = kt-kernel/third_party/llama.cpp\n\turl = https://github.com/ggerganov/llama.cpp.git\n[submodule \"kt-kernel/third_party/pybind11\"]\n\tpath = kt-kernel/third_party/pybind11\n\turl = https://github.com/pybind/pybind11.git\n"
  },
  {
    "path": "archive/.pylintrc",
    "content": "[MASTER]\nextension-pkg-whitelist=pydantic\nmax-line-length=120\n\n[MESSAGES CONTROL]\ndisable=missing-function-docstring"
  },
  {
    "path": "archive/Dockerfile",
    "content": "FROM pytorch/pytorch:2.5.1-cuda12.1-cudnn9-devel as compile_server\n\n\nARG CPU_INSTRUCT=NATIVE\n\n# 设置工作目录和 CUDA 路径\nWORKDIR /workspace\nENV CUDA_HOME=/usr/local/cuda\n\n\n\n# 安装依赖\nRUN apt update -y\nRUN apt install -y --no-install-recommends \\\n    libtbb-dev \\\n    libssl-dev \\\n    libcurl4-openssl-dev \\\n    libaio1 \\\n    libaio-dev \\\n    libfmt-dev \\\n    libgflags-dev \\\n    zlib1g-dev \\\n    patchelf \\\n    git \\\n    wget \\\n    vim \\\n    gcc \\\n    g++ \\\n    cmake\n# 拷贝代码\nRUN git clone https://github.com/kvcache-ai/ktransformers.git \n# 清理 apt 缓存\nRUN rm -rf /var/lib/apt/lists/*\n\n# 进入项目目录\nWORKDIR /workspace/ktransformers\n# 初始化子模块\nRUN git submodule update --init --recursive\n\n# 升级 pip\nRUN pip install --upgrade pip\n\n# 安装构建依赖\nRUN pip install ninja pyproject numpy cpufeature aiohttp zmq openai\n\n# 安装 flash-attn（提前装可以避免后续某些编译依赖出错）\nRUN pip install flash-attn\n\n# 安装 ktransformers 本体（含编译）\nRUN CPU_INSTRUCT=${CPU_INSTRUCT} \\\n    USE_BALANCE_SERVE=1 \\\n    KTRANSFORMERS_FORCE_BUILD=TRUE \\\n    TORCH_CUDA_ARCH_LIST=\"8.0;8.6;8.7;8.9;9.0+PTX\" \\\n    pip install . --no-build-isolation --verbose\n\nRUN pip install third_party/custom_flashinfer/\n# 清理 pip 缓存\nRUN pip cache purge\n\n# 拷贝 C++ 运行时库\nRUN cp /usr/lib/x86_64-linux-gnu/libstdc++.so.6 /opt/conda/lib/\n\n# 保持容器运行（调试用）\nENTRYPOINT [\"tail\", \"-f\", \"/dev/null\"]"
  },
  {
    "path": "archive/Dockerfile.xpu",
    "content": "# Base image\nFROM intel/oneapi-basekit:2025.0.1-0-devel-ubuntu22.04\n\nARG http_proxy\nARG https_proxy\n\nENV DEBIAN_FRONTEND=noninteractive\nENV CONDA_DIR=/opt/conda\n\n# Install dependencies\nRUN apt-get update && apt-get install -y \\\n    wget \\\n    curl \\\n    bash \\\n    git \\\n    vim \\\n    ca-certificates \\\n    binutils \\\n    cmake \\\n    g++ \\\n    && rm -rf /var/lib/apt/lists/*\n\n# Install Miniforge\nRUN wget https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh -O /tmp/miniforge.sh && \\\n    bash /tmp/miniforge.sh -b -p $CONDA_DIR && \\\n    rm /tmp/miniforge.sh && \\\n    $CONDA_DIR/bin/conda clean -afy\n\n# Add conda to PATH\nENV PATH=$CONDA_DIR/bin:$PATH\n\nRUN bash -c \"\\\n    source /opt/conda/etc/profile.d/conda.sh && \\\n    conda create --name ktransformers python=3.11 -y && \\\n    conda activate ktransformers && \\\n    conda env list && \\\n    conda install -c conda-forge libstdcxx-ng -y && \\\n    strings \\$(find /opt/conda/envs/ktransformers/lib -name 'libstdc++.so.6') | grep GLIBCXX | grep 3.4.32 \\\n\"\n\nRUN bash -c \"\\\n    source /opt/conda/etc/profile.d/conda.sh && \\\n    conda activate ktransformers && \\\n    pip install ipex-llm[xpu_2.6]==2.3.0b20250518 --extra-index-url https://download.pytorch.org/whl/xpu && \\\n    pip uninstall -y torch torchvision torchaudio && \\\n    pip install torch==2.7+xpu torchvision torchaudio --index-url https://download.pytorch.org/whl/test/xpu && \\\n    pip uninstall -y intel-opencl-rt dpcpp-cpp-rt && \\\n    pip list \\\n\"\n\n# Clone and set up ktransformers repo\nRUN bash -c \"\\\n    source $CONDA_DIR/etc/profile.d/conda.sh && \\\n    conda activate ktransformers && \\\n    git clone https://github.com/kvcache-ai/ktransformers.git && \\\n    cd ktransformers && \\\n    git submodule update --init && \\\n    sed -i 's/torch\\.xpu\\.is_available()/True/g' setup.py && \\\n    bash install.sh --dev xpu \\\n\"\n\n# Init conda 
and prepare bashrc\nRUN conda init bash && \\\n    echo \"source $CONDA_DIR/etc/profile.d/conda.sh\" >> ~/.bashrc && \\\n    echo \"conda activate ktransformers\" >> ~/.bashrc\n\nWORKDIR /ktransformers/\nCMD [\"bash\"]\n"
  },
  {
    "path": "archive/LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "archive/MANIFEST.in",
    "content": "graft third_party\ngraft ktransformers\ngraft local_chat.py\ngraft csrc\ninclude LICENSE README.md\nprune ktransformers/website\nprune ktransformers/logs\nprune ktransformers.egg-info\nprune third_party/llama.cpp/models\ngraft ktransformers/website/dist\nglobal-exclude __pycache__\ninclude KTransformersOps.*.so\ninclude cpuinfer_ext.*.so\n"
  },
  {
    "path": "archive/Makefile",
    "content": "flake_find:\n\tcd ktransformers && flake8 | grep -Eo '[A-Z][0-9]{3}' | sort | uniq| paste -sd ',' - \nformat:\n\t@cd ktransformers && black .\n\t@black setup.py\ndev_install:\n# clear build dirs\n\trm -rf build\n\trm -rf *.egg-info\n\trm -rf ktransformers/ktransformers_ext/build\n\trm -rf ktransformers/ktransformers_ext/cuda/build\n\trm -rf ktransformers/ktransformers_ext/cuda/dist\n\trm -rf ktransformers/ktransformers_ext/cuda/*.egg-info\n\n# install ktransformers\n\techo \"Installing python dependencies from requirements.txt\"\n\tpip install -r requirements-local_chat.txt\n\n\techo \"Installing ktransformers\"\n\tKTRANSFORMERS_FORCE_BUILD=TRUE pip install -e . -v --no-build-isolation\n\techo \"Installation completed successfully\"\nclean:\n\trm -rf build\n\trm -rf *.egg-info\n\trm -rf ktransformers/ktransformers_ext/build\n\trm -rf ktransformers/ktransformers_ext/cuda/build\n\trm -rf ktransformers/ktransformers_ext/cuda/dist\n\trm -rf ktransformers/ktransformers_ext/cuda/*.egg-info\t\ninstall_numa:\n\tUSE_NUMA=1 make dev_install\ninstall_no_numa:\n\tenv -u USE_NUMA make dev_install"
  },
  {
    "path": "archive/README.md",
    "content": "<div align=\"center\">\n  <p align=\"center\">\n    <picture>\n      <img alt=\"KTransformers\" src=\"https://github.com/user-attachments/assets/d5a2492f-a415-4456-af99-4ab102f13f8b\" width=50%>\n    </picture>\n  </p>\n  <h3>High-Performance CPU-GPU Hybrid Inference for Large Language Models</h3>\n</div>\n\n## 🎯 Overview\n\nKTransformers is a research project focused on efficient inference and fine-tuning of large language models through CPU-GPU heterogeneous computing. The project has evolved into **two core modules**: [kt-kernel](./kt-kernel/) and [kt-sft](./kt-sft/).\n\n## 🔥 Updates\n\n* **Nov 6, 2025**: Support Kimi-K2-Thinking inference and fine-tune\n* **Nov 4, 2025**: KTransformers Fine-Tuning × LLaMA-Factory Integration\n* **Oct 27, 2025**: Support Ascend NPU\n* **Oct 10, 2025**: Integrating into SGLang ([Roadmap](https://github.com/sgl-project/sglang/issues/11425), [Blog](https://lmsys.org/blog/2025-10-22-KTransformers/))\n* **Sept 11, 2025**: Support Qwen3-Next\n* **Sept 05, 2025**: Support Kimi-K2-0905\n* **July 26, 2025**: Support SmallThinker and GLM4-MoE\n* **June 30, 2025**: Support 3-layer (GPU-CPU-Disk) prefix cache reuse\n* **May 14, 2025**: Support Intel Arc GPU\n* **Apr 29, 2025**: Support AMX-Int8、AMX-BF16 and Qwen3MoE\n* **Apr 9, 2025**: Experimental support for LLaMA 4 models\n* **Apr 2, 2025**: Support Multi-concurrency\n* **Mar 15, 2025**: Support ROCm on AMD GPU\n* **Mar 5, 2025**: Support unsloth 1.58/2.51 bits weights and IQ1_S/FP8 hybrid weights; 139K longer context for DeepSeek-V3/R1\n* **Feb 25, 2025**: Support FP8 GPU kernel for DeepSeek-V3 and R1\n* **Feb 10, 2025**: Support Deepseek-R1 and V3, up to 3~28x speedup\n\n---\n\n## 📦 Core Modules\n\n### 🚀 [kt-kernel](./kt-kernel/) - High-Performance Inference Kernels\n\nCPU-optimized kernel operations for heterogeneous LLM inference.\n\n![image-20251011010558909](./doc/assets/heterogeneous_computing.png)\n\n**Key Features:**\n- **AMX/AVX Acceleration**: Intel AMX and 
AVX512/AVX2 optimized kernels for INT4/INT8 quantized inference\n- **MoE Optimization**: Efficient Mixture-of-Experts inference with NUMA-aware memory management\n- **Quantization Support**: CPU-side INT4/INT8 quantized weights, GPU-side GPTQ support\n- **Easy Integration**: Clean Python API for SGLang and other frameworks\n\n**Quick Start:**\n```bash\ncd kt-kernel\npip install .\n```\n\n**Use Cases:**\n\n- CPU-GPU hybrid inference for large MoE models\n- Integration with SGLang for production serving\n- Heterogeneous expert placement (hot experts on GPU, cold experts on CPU)\n\n**Performance Examples:**\n| Model | Hardware Configuration | Total Throughput | Output Throughput |\n|-------|------------------------|------------------|-------------------|\n| DeepSeek-R1-0528 (FP8) | 8×L20 GPU + Xeon Gold 6454S | 227.85 tokens/s | 87.58 tokens/s (8-way concurrency) |\n\n👉 **[Full Documentation →](./kt-kernel/README.md)**\n\n---\n\n### 🎓 [kt-sft](./kt-sft/) - Fine-Tuning Framework\n\nKTransformers × LLaMA-Factory integration for ultra-large MoE model fine-tuning.\n\n![image-20251011010558909](./doc/assets/image-20251011010558909.png)\n\n**Key Features:**\n\n- **Resource Efficient**: Fine-tune 671B DeepSeek-V3 with just **70GB GPU memory** + 1.3TB RAM\n- **LoRA Support**: Full LoRA fine-tuning with heterogeneous acceleration\n- **LLaMA-Factory Integration**: Seamless integration with a popular fine-tuning framework\n- **Production Ready**: Chat, batch inference, and metrics evaluation\n\n**Performance Examples:**\n\n| Model | Configuration | Throughput | GPU Memory |\n|-------|--------------|------------|------------|\n| DeepSeek-V3 (671B) | LoRA + AMX | ~40 tokens/s | 70GB (multi-GPU) |\n| DeepSeek-V2-Lite (14B) | LoRA + AMX | ~530 tokens/s | 6GB |\n\n**Quick Start:**\n```bash\ncd kt-sft\n# Install the environment following kt-sft/README.md\nUSE_KT=1 llamafactory-cli train examples/train_lora/deepseek3_lora_sft_kt.yaml\n```\n\n👉 **[Full Documentation 
→](./kt-sft/README.md)**\n\n---\n\n## 🔥 Citation\n\nIf you use KTransformers in your research, please cite our paper:\n\n```bibtex\n@inproceedings{10.1145/3731569.3764843,\n  title = {KTransformers: Unleashing the Full Potential of CPU/GPU Hybrid Inference for MoE Models},\n  author = {Chen, Hongtao and Xie, Weiyu and Zhang, Boxin and Tang, Jingqi and Wang, Jiahao and Dong, Jianwei and Chen, Shaoyuan and Yuan, Ziwei and Lin, Chen and Qiu, Chengyu and Zhu, Yuening and Ou, Qingliang and Liao, Jiaqi and Chen, Xianglin and Ai, Zhiyuan and Wu, Yongwei and Zhang, Mingxing},\n  booktitle = {Proceedings of the ACM SIGOPS 31st Symposium on Operating Systems Principles},\n  year = {2025}\n}\n```\n\n## 👥 Contributors & Team\n\nDeveloped and maintained by:\n- [MADSys Lab](https://madsys.cs.tsinghua.edu.cn/) @ Tsinghua University\n- [Approaching.AI](http://approaching.ai/)\n- Community contributors\n\nWe welcome contributions! Please feel free to submit issues and pull requests.\n\n## 💬 Community & Support\n\n- **GitHub Issues**: [Report bugs or request features](https://github.com/kvcache-ai/ktransformers/issues)\n- **GitHub Discussions**: [Ask questions and share ideas](https://github.com/kvcache-ai/ktransformers/discussions)\n- **WeChat Group**: See [archive/WeChatGroup.png](./archive/WeChatGroup.png)\n\n## 📦 Legacy Code\n\nThe original integrated KTransformers framework has been archived to the [`archive/`](./archive/) directory for reference. The project now focuses on the two core modules above for better modularity and maintainability.\n\nFor the original documentation with full quick-start guides and examples, see:\n- [archive/README_LEGACY.md](./archive/README_LEGACY.md) (English)\n- [archive/README_ZH_LEGACY.md](./archive/README_ZH_LEGACY.md) (中文)\n\n"
  },
  {
    "path": "archive/README_LEGACY.md",
    "content": "<div align=\"center\">\n  <!-- <h1>KTransformers</h1> -->\n  <p align=\"center\">\n\n<picture>\n    <img alt=\"KTransformers\" src=\"https://github.com/user-attachments/assets/d5a2492f-a415-4456-af99-4ab102f13f8b\" width=50%>\n\n</picture>\n\n</p>\n  <h3>A Flexible Framework for Experiencing Cutting-edge LLM Inference Optimizations</h3>\n  <strong><a href=\"#show-cases\">🌟 Show Cases</a> | <a href=\"#quick-start\">🚀 Quick Start</a> | <a href=\"#tutorial\">📃 Tutorial</a> | <a href=\"#Citation\">🔥  Citation </a> | <a href=\"https://github.com/kvcache-ai/ktransformers/discussions\">💬  Discussion </a>|<a href=\"#FAQ\"> 🙋 FAQ</a> </strong>\n</div>\n\n<h2 id=\"intro\">🎉 Introduction</h2>\nKTransformers, pronounced as Quick Transformers, is designed to enhance your 🤗 <a href=\"https://github.com/huggingface/transformers\">Transformers</a> experience with advanced kernel optimizations and placement/parallelism strategies.\n<br/><br/>\nKTransformers is a flexible, Python-centric framework designed with extensibility at its core. \nBy implementing and injecting an optimized module with a single line of code, users gain access to a Transformers-compatible\ninterface, RESTful APIs compliant with OpenAI and Ollama, and even a simplified ChatGPT-like web UI. \n<br/><br/>\nOur vision for KTransformers is to serve as a flexible platform for experimenting with innovative LLM inference optimizations. Please let us know if you need any other features.\n\n<h2 id=\"Updates\">🔥 Updates</h2>\n\n* **Nov 6, 2025**: Support Kimi-K2-Thinking inference ([Tutorial](./doc/en/Kimi-K2-Thinking.md)) and fine-tune ([Tutorial](./doc/en/SFT_Installation_Guide_KimiK2.md))\n* **Nov 4, 2025**: KTransformers Fine-Tuning × LLaMA-Factory Integration. ([Tutorial](./doc/en/KTransformers-Fine-Tuning_User-Guide.md))\n* **Oct 27, 2025**: Support Ascend NPU. ([Tutorial](./doc/zh/DeepseekR1_V3_tutorial_zh_for_Ascend_NPU.md))\n* **Oct 10, 2025**: Integrating into SGLang. 
([Roadmap](https://github.com/sgl-project/sglang/issues/11425))\n* **Sept 11, 2025**: Support Qwen3-Next. ([Tutorial](./doc/en/Qwen3-Next.md))\n* **Sept 05, 2025**: Support Kimi-K2-0905. ([Tutorial](./doc/en/Kimi-K2.md))\n* **July 26, 2025**: Support SmallThinker and GLM4-MoE. ([Tutorial](./doc/en/SmallThinker_and_Glm4moe.md))\n* **July 11, 2025**: Support Kimi-K2. ([Tutorial](./doc/en/Kimi-K2.md))\n* **June 30, 2025**: Support 3-layer (GPU-CPU-Disk) [prefix cache](./doc/en/prefix_cache.md) reuse.\n* **May 14, 2025**: Support Intel Arc GPU ([Tutorial](./doc/en/xpu.md)).\n* **Apr 29, 2025**: Support AMX-Int8, AMX-BF16 and Qwen3MoE ([Tutorial](./doc/en/AMX.md))\n\nhttps://github.com/user-attachments/assets/fafe8aec-4e22-49a8-8553-59fb5c6b00a2\n\n* **Apr 9, 2025**: Experimental support for LLaMA 4 models ([Tutorial](./doc/en/llama4.md)).\n* **Apr 2, 2025**: Support multi-concurrency ([Tutorial](./doc/en/balance-serve.md)).\n\nhttps://github.com/user-attachments/assets/faa3bda2-928b-45a7-b44f-21e12ec84b8a\n\n* **Mar 15, 2025**: Support ROCm on AMD GPU ([Tutorial](./doc/en/ROCm.md)).\n* **Mar 5, 2025**: Support unsloth 1.58/2.51-bit weights and [IQ1_S/FP8 hybrid](./doc/en/fp8_kernel.md) weights. Support 139K [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022--v023-longer-context--fp8-kernel) for DeepSeek-V3 and R1 in 24GB VRAM.\n* **Feb 25, 2025**: Support [FP8 GPU kernel](./doc/en/fp8_kernel.md) for DeepSeek-V3 and R1; [Longer Context](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context).\n* **Feb 15, 2025**: Longer Context (from 4K to 8K for 24GB VRAM) & Slightly Faster Speed (+15%, up to 16 tokens/s); updated [docs](./doc/en/DeepseekR1_V3_tutorial.md) and [online books](https://kvcache-ai.github.io/ktransformers/).\n* **Feb 10, 2025**: Support Deepseek-R1 and V3 on single (24GB VRAM)/multi-GPU and 382G DRAM, up to 3~28x speedup. 
For a detailed showcase and reproduction tutorial, see [here](./doc/en/DeepseekR1_V3_tutorial.md).\n* **Aug 28, 2024**: Decrease DeepseekV2's required VRAM from 21G to 11G.\n* **Aug 15, 2024**: Update detailed [tutorial](doc/en/injection_tutorial.md) for injection and multi-GPU.\n* **Aug 14, 2024**: Support llamafile as linear backend.\n* **Aug 12, 2024**: Support multiple GPUs; support new models: Mixtral 8\\*7B and 8\\*22B; support q2k, q3k, q5k dequant on GPU.\n* **Aug 9, 2024**: Support native Windows.\n\n<!-- * **Aug 28, 2024**: Support 1M context under the InternLM2.5-7B-Chat-1M model, utilizing 24GB of VRAM and 150GB of DRAM. The detailed tutorial is [here](./doc/en/long_context_tutorial.md). -->\n\n<h2 id=\"show-cases\">🌟 Show Cases</h2>\n\n<div>\n<h3>GPT-4/o1-level Local VSCode Copilot on a Desktop with only 24GB VRAM</h3>\n</div>\n\nhttps://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285\n\n</p>\n\n- **[NEW!!!] Local 671B DeepSeek-Coder-V3/R1:** Running its Q4_K_M version using only 14GB VRAM and 382GB DRAM ([Tutorial](./doc/en/DeepseekR1_V3_tutorial.md)).\n  \n  - Prefill Speed (tokens/s):\n    - KTransformers: 54.21 (32 cores) → 74.362 (dual-socket, 2×32 cores) → 255.26 (optimized AMX-based MoE kernel, V0.3 only) → 286.55 (selectively using 6 experts, V0.3 only)\n    - Compared to 10.31 tokens/s in llama.cpp with 2×32 cores, achieving up to **27.79× speedup**.\n  - Decode Speed (tokens/s):\n    - KTransformers: 8.73 (32 cores) → 11.26 (dual-socket, 2×32 cores) → 13.69 (selectively using 6 experts, V0.3 only)\n    - Compared to 4.51 tokens/s in llama.cpp with 2×32 cores, achieving up to **3.03× speedup**.\n  - Upcoming Open Source Release:\n    - AMX optimizations and selective expert activation will be open-sourced in V0.3.\n    - Currently available only in a preview binary distribution, which can be downloaded [here](./doc/en/DeepseekR1_V3_tutorial.md).\n- **Local 236B DeepSeek-Coder-V2:** Running its Q4_K_M version using only 21GB 
VRAM and 136GB DRAM, attainable on a local desktop machine, which scores even better than GPT4-0613 in [BigCodeBench](https://huggingface.co/blog/leaderboard-bigcodebench).\n\n<p align=\"center\">\n  <picture>\n    <img alt=\"DeepSeek-Coder-V2 Score\" src=\"https://github.com/user-attachments/assets/d052924e-8631-44de-aad2-97c54b965693\" width=100%>\n  </picture>\n</p>\n\n- **Faster Speed:** Achieving 126 tokens/s for 2K prompt prefill and 13.6 tokens/s for generation through MoE offloading and injecting advanced kernels from [Llamafile](https://github.com/Mozilla-Ocho/llamafile/tree/main) and [Marlin](https://github.com/IST-DASLab/marlin).\n- **VSCode Integration:** Wrapped into an OpenAI and Ollama compatible API for seamless integration as a backend for [Tabby](https://github.com/TabbyML/tabby) and various other frontends.\n\n<p align=\"center\">\n\nhttps://github.com/user-attachments/assets/4c6a8a38-05aa-497d-8eb1-3a5b3918429c\n\n</p>\n\n<!-- <h3>1M Context Local Inference on a Desktop with Only 24GB VRAM</h3>\n<p align=\"center\">\n\nhttps://github.com/user-attachments/assets/a865e5e4-bca3-401e-94b8-af3c080e6c12\n\n* **1M Context InternLM 2.5 7B**: Operates at full bf16 precision, utilizing 24GB VRAM and 150GB DRAM, which is feasible on a local desktop setup. It achieves a 92.88% success rate on the 1M \"Needle In a Haystack\" test and 100% on the 128K NIAH test.\n\n<p align=\"center\">\n  <picture>\n    <img alt=\"Single Needle Retrieval 128K\" src=\"./doc/assets/needle_128K.png\" width=100%>\n  </picture>\n</p>\n\n<p align=\"center\">\n  <picture>\n    <img alt=\"Single Needle Retrieval 1000K\" src=\"./doc/assets/needle_1M.png\" width=100%>\n  </picture>\n</p>\n\n* **Enhanced Speed**: Reaches 16.91 tokens/s for generation with a 1M context using sparse attention, powered by llamafile kernels. 
This method is over 10 times faster than full attention approach of llama.cpp.\n\n* **Flexible Sparse Attention Framework**: Offers a flexible block sparse attention framework for CPU offloaded decoding. Compatible with SnapKV, Quest, and InfLLm. Further information is available [here](./doc/en/long_context_introduction.md).\n -->\n\n<strong>More advanced features are coming soon, so stay tuned!</strong>\n\n<h2 id=\"quick-start\">🚀 Quick Start</h2>\n\nGetting started with KTransformers is simple! Follow the steps below to set up and start using it.\n\nThe following vendors are already supported:\n\n- Metax\n- Sanechips (ZhuFeng V1.0)\n- Intel\n- Ascend\n- Kunpeng\n- AMD\n\n### 📥 Installation\n\nTo install KTransformers, follow the official [Installation Guide](https://kvcache-ai.github.io/ktransformers/en/install.html).\n\n<h2 id=\"tutorial\">📃 Brief Injection Tutorial</h2>\nAt the heart of KTransformers is a user-friendly, template-based injection framework. \nThis allows researchers to easily replace original torch modules with optimized variants. It also simplifies the process of combining multiple optimizations, allowing the exploration of their synergistic effects.\n\n<br/>\n<p align=\"center\">\n  <picture>\n    <img alt=\"Inject-Struction\" src=\"https://github.com/user-attachments/assets/6b4c1e54-9f6d-45c5-a3fc-8fa45e7d257e\" width=65%>\n  </picture>\n</p>\n\nGiven that vLLM already serves as a great framework for large-scale deployment optimizations, KTransformers is particularly focused on local deployments that are constrained by limited resources. We pay special attention to heterogeneous computing opportunities, such as GPU/CPU offloading of quantized models. For example, we support the efficient <a href=\"https://github.com/Mozilla-Ocho/llamafile/tree/main\">Llamafile</a> and <a href=\"https://github.com/IST-DASLab/marlin\">Marlin</a> kernels for CPU and GPU, respectively. 
More details can be found <a href=\"doc/en/operators/llamafile.md\">here</a>.\n\n<h3>Example Usage</h3>\nTo utilize the provided kernels, users only need to create a YAML-based injection template and add the call to `optimize_and_load_gguf` before using the Transformers model.\n\n```python\nwith torch.device(\"meta\"):\n    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)\noptimize_and_load_gguf(model, optimize_config_path, gguf_path, config)\n...\ngenerated = prefill_and_generate(model, tokenizer, input_tensor.cuda(), max_new_tokens=1000)\n```\n\nIn this example, the AutoModel is first initialized on the meta device to avoid occupying any memory resources. Then, `optimize_and_load_gguf` iterates through all sub-modules of the model, matches them against the rules specified in your YAML rule file, and replaces the matched modules with the advanced modules those rules specify.\n\nAfter injection, the original `generate` interface is still available, but we also provide a compatible `prefill_and_generate` method, which enables further optimizations like CUDAGraph to improve generation speed.\n\n<h3>How to customize your model</h3>\n\nA detailed tutorial on injection and multi-GPU usage, using DeepSeek-V2 as an example, is available [here](doc/en/injection_tutorial.md).\n\nBelow is an example of a YAML template for replacing all original Linear modules with Marlin, an advanced 4-bit quantization kernel.\n\n```yaml\n- match:\n    name: \"^model\\\\.layers\\\\..*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformerLinear  # optimized kernel on quantized data types\n    device: \"cpu\"   # device to load this module onto at initialization\n    kwargs:\n      generate_device: \"cuda\"\n      generate_linear_type: \"QuantizedLinearMarlin\"\n```\n\nEach rule in the YAML file has two parts: `match` and `replace`. 
The `match` part specifies which module should be replaced, and the `replace` part specifies the module to be injected into the model along with the initialization keywords.\n\nYou can find example rule templates for optimizing DeepSeek-V2 and Qwen2-57B-A14, two SOTA MoE models, in the [ktransformers/optimize/optimize_rules](ktransformers/optimize/optimize_rules) directory. These templates are used to power the `local_chat.py` demo.\n\nIf you are interested in our design principles and the implementation of the injection framework, please refer to the [design document](doc/en/deepseek-v2-injection.md).\n\n<h2 id=\"Citation\">🔥 Citation</h2>\n\nIf you use KTransformers for your research, please cite our [paper](https://madsys.cs.tsinghua.edu.cn/publication/ktransformers-unleashing-the-full-potential-of-cpu/gpu-hybrid-inference-for-moe-models/):\n\n```\n@inproceedings{10.1145/3731569.3764843,\ntitle = {KTransformers: Unleashing the Full Potential of CPU/GPU Hybrid Inference for MoE Models},\nauthor = {Chen, Hongtao and Xie, Weiyu and Zhang, Boxin and Tang, Jingqi and Wang, Jiahao and Dong, Jianwei and Chen, Shaoyuan and Yuan, Ziwei and Lin, Chen and Qiu, Chengyu and Zhu, Yuening and Ou, Qingliang and Liao, Jiaqi and Chen, Xianglin and Ai, Zhiyuan and Wu, Yongwei and Zhang, Mingxing},\nbooktitle = {Proceedings of the ACM SIGOPS 31st Symposium on Operating Systems Principles},\nyear = {2025}\n}\n```\n\n<h2 id=\"ack\">Acknowledgment and Contributors</h2>\n\nThe development of KTransformers is based on the flexible and versatile framework provided by Transformers. We also benefit from advanced kernels such as GGUF/GGML, Llamafile, Marlin, sglang and flashinfer. 
We are planning to contribute back to the community by upstreaming our modifications.\n\nKTransformers is actively maintained and developed by contributors from the <a href=\"https://madsys.cs.tsinghua.edu.cn/\">MADSys group</a> at Tsinghua University and members from <a href=\"http://approaching.ai/\">Approaching.AI</a>. We welcome new contributors to join us in making KTransformers faster and easier to use.\n\n<h2 id=\"ack\">Discussion</h2>\n\nIf you have any questions, feel free to open an issue. Alternatively, you can join our WeChat group for further discussion. QR Code: [WeChat Group](WeChatGroup.png)\n\n<h2 id=\"FAQ\">🙋 FAQ</h2>\n\nSome common questions are answered in the [FAQ](doc/en/FAQ.md).\n\n"
  },
  {
    "path": "archive/README_ZH.md",
    "content": "<div align=\"center\">\n  <p align=\"center\">\n    <picture>\n      <img alt=\"KTransformers\" src=\"https://github.com/user-attachments/assets/d5a2492f-a415-4456-af99-4ab102f13f8b\" width=50%>\n    </picture>\n  </p>\n  <h3>高性能 CPU-GPU 异构大语言模型推理</h3>\n</div>\n\n## 🎯 项目概述\n\nKTransformers 是一个专注于大语言模型高效推理和微调的研究项目，通过 CPU-GPU 异构计算实现资源受限环境下的模型部署。项目已演进为**两个核心模块**：[kt-kernel](./kt-kernel/) 和 [kt-sft](./kt-sft/)。\n\n## 🔥 更新\n\n* **2025年11月6日**：支持 Kimi-K2-Thinking 推理和微调\n* **2025年11月4日**：KTransformers 微调 × LLaMA-Factory 集成\n* **2025年10月27日**：支持 Ascend NPU\n* **2025年10月10日**：集成到 SGLang ([路线图](https://github.com/sgl-project/sglang/issues/11425), [博客](https://lmsys.org/blog/2025-10-22-KTransformers/))\n* **2025年9月11日**：支持 Qwen3-Next\n* **2025年9月5日**：支持 Kimi-K2-0905\n* **2025年7月26日**：支持 SmallThinker 和 GLM4-MoE\n* **2025年6月30日**：支持 3层（GPU-CPU-磁盘）前缀缓存复用\n* **2025年5月14日**：支持 Intel Arc GPU\n* **2025年4月29日**：支持 AMX-Int8、AMX-BF16 和 Qwen3MoE\n* **2025年4月9日**：实验性支持 LLaMA 4 模型\n* **2025年4月2日**：支持多并发\n* **2025年3月15日**：支持 AMD GPU 的 ROCm\n* **2025年3月5日**：支持 unsloth 1.58/2.51 bits 权重和 IQ1_S/FP8 混合权重；DeepSeek-V3/R1 支持 139K 长上下文\n* **2025年2月25日**：支持 DeepSeek-V3 和 R1 的 FP8 GPU 内核\n* **2025年2月10日**：支持 Deepseek-R1 和 V3，速度提升最高达 3~28 倍\n\n---\n\n## 📦 核心模块\n\n### 🚀 [kt-kernel](./kt-kernel/) - 高性能推理内核\n\n面向异构 LLM 推理的 CPU 优化内核操作库。\n\n![image-20251011010558909](./doc/assets/heterogeneous_computing.png)\n\n**核心特性：**\n- **AMX/AVX 加速**：Intel AMX 和 AVX512/AVX2 优化内核，支持 INT4/INT8 量化推理\n- **MoE 优化**：高效的专家混合推理，支持 NUMA 感知内存管理\n- **量化支持**：CPU 端 INT4/INT8 量化权重，GPU 端 GPTQ 支持\n- **易于集成**：简洁的 Python API，可集成到 SGLang 等框架\n\n**快速开始：**\n```bash\ncd kt-kernel\npip install .\n```\n\n**应用场景：**\n- 大型 MoE 模型的 CPU-GPU 混合推理\n- 与 SGLang 集成用于生产服务\n- 异构专家放置（热门专家在 GPU，冷门专家在 CPU）\n\n**性能示例：**\n| 模型 | 硬件配置 | 总吞吐量 | 输出吞吐量 |\n|------|---------|---------|-----------|\n| DeepSeek-R1-0528 (FP8) | 8×L20 GPU + Xeon Gold 6454S | 227.85 tokens/s | 87.58 tokens/s（8路并发）|\n\n👉 **[完整文档 →](./kt-kernel/README.md)**\n\n---\n\n### 
🎓 [kt-sft](./kt-sft/) - 微调框架\n\nKTransformers × LLaMA-Factory 集成，支持超大 MoE 模型微调。\n\n![image-20251011010558909](./doc/assets/image-20251011010558909.png)\n\n**核心特性：**\n- **资源高效**：仅需 **70GB 显存** + 1.3TB 内存即可微调 671B DeepSeek-V3\n- **LoRA 支持**：完整的 LoRA 微调与异构加速\n- **LLaMA-Factory 集成**：与流行微调框架无缝集成\n- **生产就绪**：支持对话、批量推理和指标评估\n\n**性能示例：**\n| 模型 | 配置 | 吞吐量 | GPU 显存 |\n|------|------|--------|----------|\n| DeepSeek-V3 (671B) | LoRA + AMX | ~40 tokens/s | 70GB (多卡) |\n| DeepSeek-V2-Lite (14B) | LoRA + AMX | ~530 tokens/s | 6GB |\n\n**快速开始：**\n```bash\ncd kt-sft\n# 按照 kt-sft/README.md 安装环境\nUSE_KT=1 llamafactory-cli train examples/train_lora/deepseek3_lora_sft_kt.yaml\n```\n\n👉 **[完整文档 →](./kt-sft/README.md)**\n\n---\n\n## 🔥 引用\n\n如果您在研究中使用了 KTransformers，请引用我们的论文：\n\n```bibtex\n@inproceedings{10.1145/3731569.3764843,\n  title = {KTransformers: Unleashing the Full Potential of CPU/GPU Hybrid Inference for MoE Models},\n  author = {Chen, Hongtao and Xie, Weiyu and Zhang, Boxin and Tang, Jingqi and Wang, Jiahao and Dong, Jianwei and Chen, Shaoyuan and Yuan, Ziwei and Lin, Chen and Qiu, Chengyu and Zhu, Yuening and Ou, Qingliang and Liao, Jiaqi and Chen, Xianglin and Ai, Zhiyuan and Wu, Yongwei and Zhang, Mingxing},\n  booktitle = {Proceedings of the ACM SIGOPS 31st Symposium on Operating Systems Principles},\n  year = {2025}\n}\n```\n\n## 👥 贡献者与团队\n\n由以下团队开发和维护：\n- 清华大学 [MADSys 实验室](https://madsys.cs.tsinghua.edu.cn/)\n- [Approaching.AI](http://approaching.ai/)\n- 社区贡献者\n\n我们欢迎贡献！请随时提交 issues 和 pull requests。\n\n## 💬 社区与支持\n\n- **GitHub Issues**：[报告 bug 或请求功能](https://github.com/kvcache-ai/ktransformers/issues)\n- **GitHub Discussions**：[提问和分享想法](https://github.com/kvcache-ai/ktransformers/discussions)\n- **微信群**：查看 [archive/WeChatGroup.png](./archive/WeChatGroup.png)\n\n## 📦 历史代码\n\n原完整的 KTransformers 框架代码已归档至 [`archive/`](./archive/) 目录供参考。项目现专注于上述两个核心模块，以实现更好的模块化和可维护性。\n\n关于原始完整文档（包含快速入门指南和示例），请查看：\n- [archive/README_LEGACY.md](./archive/README_LEGACY.md) (English)\n- 
[archive/README_ZH_LEGACY.md](./archive/README_ZH_LEGACY.md) (中文)\n"
  },
  {
    "path": "archive/README_ZH_LEGACY.md",
    "content": "<div align=\"center\">\n  <!-- <h1>KTransformers</h1> -->\n  <p align=\"center\">\n\n<picture>\n    <img alt=\"KTransformers\" src=\"https://github.com/user-attachments/assets/d5a2492f-a415-4456-af99-4ab102f13f8b\" width=50%>\n\n</picture>\n\n</p>\n  <h3>一个用于体验尖端 LLM 推理优化的灵活框架</h3>\n  <strong><a href=\"#show-cases\">🌟 案例展示</a> | <a href=\"#quick-start\">🚀 快速入门</a> | <a href=\"#tutorial\">📃 教程</a> | <a href=\"https://github.com/kvcache-ai/ktransformers/discussions\">💬 讨论</a> | <a href=\"#FAQ\">🙋 常见问题</a> </strong>\n</div>\n\n<h2 id=\"intro\">🎉 介绍</h2>\nKTransformers（发音为 Quick Transformers）旨在通过先进的内核优化和放置/并行策略来增强您对 🤗 [Transformers](https://github.com/huggingface/transformers) 的体验。\n<br/><br/>\nKTransformers 是一个以 Python 为中心的灵活框架，其核心是可扩展性。通过用一行代码实现并注入优化模块，用户可以获得与 Transformers 兼容的接口、符合 OpenAI 和 Ollama 的 RESTful API，甚至是一个简化的类似 ChatGPT 的 Web 界面。\n<br/><br/>\n我们对 KTransformers 的愿景是成为一个用于实验创新 LLM 推理优化的灵活平台。如果您需要任何其他功能，请告诉我们。\n\n<h2 id=\"Updates\">🔥 更新</h2>\n\n* **2025 年 2 月 15 日**：为DeepSeek-V3/R1支持[FP8 GPU内核](./doc/en/fp8_kernel.md); 支持更长的上下文([教程](./doc/en/DeepseekR1_V3_tutorial.md#v022-longer-context)).\n* **2025 年 2 月 15 日**：长上下文(从4K到8K，24GB VRAM) & 稍快的速度(+15%)(最快 16 Tokens/s)，文档请参见 [这里](./doc/en/DeepseekR1_V3_tutorial.md) 和 [在线指南](https://kvcache-ai.github.io/ktransformers/) 。\n* **2025 年 2 月 10 日**：支持 Deepseek-R1 和 V3 在单个（24GB VRAM）/多 GPU 和 382G DRAM 上运行，速度提升高达 3~28 倍。详细教程请参见 [这里](./doc/en/DeepseekR1_V3_tutorial.md)。\n* **2024 年 8 月 28 日**：支持 InternLM2.5-7B-Chat-1M 模型下的 1M 上下文，使用 24GB 的 VRAM 和 150GB 的 DRAM。详细教程请参见 [这里](./doc/en/long_context_tutorial.md)。\n* **2024 年 8 月 28 日**：将 DeepseekV2 所需的 VRAM 从 21G 降低到 11G。\n* **2024 年 8 月 15 日**：更新了详细的 [教程](doc/en/injection_tutorial.md)，介绍注入和多 GPU 的使用。\n* **2024 年 8 月 14 日**：支持 llamfile 作为线性后端。\n* **2024 年 8 月 12 日**：支持多 GPU；支持新模型：mixtral 8\\*7B 和 8\\*22B；支持 q2k、q3k、q5k 在 GPU 上的去量化。\n* **2024 年 8 月 9 日**：支持 Windows。\n\n<h2 id=\"show-cases\">🌟 案例展示</h2>\n\n<div>\n<h3>在仅 24GB VRAM 的桌面上运行 GPT-4/o1 级别的本地 VSCode 
Copilot</h3>\n</div>\n\nhttps://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285\n\n</p>\n\n- **[NEW!!!] 本地 671B DeepSeek-Coder-V3/R1**：使用其 Q4_K_M 版本，仅需 14GB VRAM 和 382GB DRAM 即可运行（教程请参见 [这里](./doc/en/DeepseekR1_V3_tutorial.md)）。\n\t- 预填充速度（tokens/s）：\n \t\t- KTransformers：54.21（32 核）→ 74.362（双插槽，2×32 核）→ 255.26（优化的 AMX 基 MoE 内核，仅 V0.3）→ 286.55（选择性使用 6 个专家，仅 V0.3）\n \t\t- 与 llama.cpp 在 2×32 核下相比，达到 **27.79× 速度提升**。\n \t- 解码速度（tokens/s）：\n \t\t- KTransformers：8.73（32 核）→ 11.26（双插槽，2×32 核）→ 13.69（选择性使用 6 个专家，仅 V0.3）\n \t\t- 与 llama.cpp 在 2×32 核下相比，达到 **3.03× 速度提升**。\n\t- 即将开源发布：\n\t\t- AMX 优化和选择性专家激活将在 V0.3 中开源。\n\t\t- 目前仅在预览二进制分发中可用，可从 [这里](./doc/en/DeepseekR1_V3_tutorial.md) 下载。\n\n- **本地 236B DeepSeek-Coder-V2**：使用其 Q4_K_M 版本，仅需 21GB VRAM 和 136GB DRAM 即可运行，甚至在 [BigCodeBench](https://huggingface.co/blog/leaderboard-bigcodebench) 中得分超过 GPT4-0613。\n\n<p align=\"center\">\n  <picture>\n    <img alt=\"DeepSeek-Coder-V2 Score\" src=\"https://github.com/user-attachments/assets/d052924e-8631-44de-aad2-97c54b965693\" width=100%>\n  </picture>\n</p>\n\n- **更快的速度**：通过 MoE 卸载和注入来自 [Llamafile](https://github.com/Mozilla-Ocho/llamafile/tree/main) 和 [Marlin](https://github.com/IST-DASLab/marlin) 的高级内核，实现了 2K 提示预填充 126 tokens/s 和生成 13.6 tokens/s 的速度。\n- **VSCode 集成**：封装成符合 OpenAI 和 Ollama 的 API，可无缝集成到 [Tabby](https://github.com/TabbyML/tabby) 和其他前端的后端。\n\n<p align=\"center\">\n\nhttps://github.com/user-attachments/assets/4c6a8a38-05aa-497d-8eb1-3a5b3918429c\n\n</p>\n\n<!-- <h3>在仅 24GB VRAM 的桌面上进行 1M 上下文本地推理</h3>\n<p align=\"center\"> -->\n\n<!-- https://github.com/user-attachments/assets/a865e5e4-bca3-401e-94b8-af3c080e6c12 -->\n<!-- \n* **1M 上下文 InternLM 2.5 7B**：以全 bf16 精度运行，使用 24GB VRAM 和 150GB DRAM，可在本地桌面设置中实现。在 1M \"针在干草堆中\" 测试中达到 92.88% 的成功率，在 128K NIAH 测试中达到 100%。\n\n<p align=\"center\">\n  <picture>\n    <img alt=\"Single Needle Retrieval 128K\" src=\"./doc/assets/needle_128K.png\" width=100%>\n  </picture>\n</p>\n\n<p align=\"center\">\n  
<picture>\n    <img alt=\"Single Needle Retrieval 1000K\" src=\"./doc/assets/needle_1M.png\" width=100%>\n  </picture>\n</p>\n\n* **增强的速度**：使用稀疏注意力，通过 llamafile 内核实现 1M 上下文生成 16.91 tokens/s 的速度。这种方法比 llama.cpp 的全注意力方法快 10 倍以上。\n\n* **灵活的稀疏注意力框架**：提供了一个灵活的块稀疏注意力框架，用于 CPU 卸载解码。与 SnapKV、Quest 和 InfLLm 兼容。更多信息请参见 [这里](./doc/en/long_context_introduction.md)。 -->\n\n<strong>更多高级功能即将推出，敬请期待！</strong>\n\n<h2 id=\"quick-start\">🚀 快速入门</h2>\n\n\nKTransformers 的入门非常简单！请参考我们的[安装指南]((https://kvcache-ai.github.io/ktransformers/))进行安装。\n\n<h2 id=\"tutorial\">📃 简要注入教程</h2>\nKTransformers 的核心是一个用户友好的、基于模板的注入框架。这使得研究人员可以轻松地将原始 torch 模块替换为优化的变体。它还简化了多种优化的组合过程，允许探索它们的协同效应。\n</br>\n<p align=\"center\">\n  <picture>\n    <img alt=\"Inject-Struction\" src=\"https://github.com/user-attachments/assets/6b4c1e54-9f6d-45c5-a3fc-8fa45e7d257e\" width=65%>\n  </picture>\n</p>\n\n鉴于 vLLM 已经是一个用于大规模部署优化的优秀框架，KTransformers 特别关注受资源限制的本地部署。我们特别关注异构计算时机，例如量化模型的 GPU/CPU 卸载。例如，我们支持高效的 <a herf=\"https://github.com/Mozilla-Ocho/llamafile/tree/main\">Llamafile</a> 和<a herf=\"https://github.com/IST-DASLab/marlin\">Marlin</a> 内核，分别用于 CPU 和 GPU。 更多详细信息可以在 <a herf=\"doc/en/operators/llamafile.md\">这里</a>找到。\n\n\n<h3>示例用法</h3>\n要使用提供的内核，用户只需创建一个基于 YAML 的注入模板，并在使用 Transformers 模型之前添加对 `optimize_and_load_gguf` 的调用。\n\n```python\nwith torch.device(\"meta\"):\n    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)\noptimize_and_load_gguf(model, optimize_config_path, gguf_path, config)\n...\ngenerated = prefill_and_generate(model, tokenizer, input_tensor.cuda(), max_new_tokens=1000)\n```\n\n在这个示例中，首先在 meta 设备上初始化 AutoModel，以避免占用任何内存资源。然后，`optimize_and_load_gguf` 遍历模型的所有子模块，匹配您的 YAML 规则文件中指定的规则，并将它们替换为指定的高级模块。\n\n注入后，原始的 `generate` 接口仍然可用，但我们还提供了一个兼容的 `prefill_and_generate` 方法，这使得可以进一步优化，例如使用 CUDAGraph 提高生成速度。\n\n<h3>如何自定义您的模型</h3>\n\n一个详细的使用 DeepSeek-V2 作为示例的注入和 multi-GPU 教程在 [这里](doc/en/injection_tutorial.md)。\n\n以下是一个将所有原始 Linear 模块替换为 Marlin 的 YAML 模板示例，Marlin 是一个高级的 4 位量化内核。\n\n```yaml\n- 
match:\n    name: \"^model\\\\.layers\\\\..*$\"  # 正则表达式 \n    class: torch.nn.Linear  # 仅匹配同时符合名称和类的模块\n  replace:\n    class: ktransformers.operators.linear.KTransformerLinear  # 量化数据类型的优化内核\n    device: \"cpu\"   # 初始化时加载该模块的 device\n    kwargs:\n      generate_device: \"cuda\"\n      generate_linear_type: \"QuantizedLinearMarlin\"\n```\n\nYAML 文件中的每个规则都有两部分：`match` 和 `replace`。`match` 部分指定应替换的模块，`replace` 部分指定要注入到模型中的模块以及初始化关键字。\n\n您可以在 [ktransformers/optimize/optimize_rules](ktransformers/optimize/optimize_rules) 目录中找到用于优化 DeepSeek-V2 和 Qwen2-57B-A14 的示例规则模板。这些模板用于为 `local_chat.py` 示例提供支持。\n\n如果您对我们的设计原则和注入框架的实现感兴趣，请参考 [设计文档](doc/en/deepseek-v2-injection.md)。\n\n<h2 id=\"ack\">致谢和贡献者</h2>\n\nKTransformers 的开发基于 Transformers 提供的灵活和多功能框架。我们还受益于 GGUF/GGML、Llamafile 、 Marlin、sglang和flashinfer 等高级内核。我们计划通过向上游贡献我们的修改来回馈社区。\n\nKTransformers 由清华大学 <a href=\"https://madsys.cs.tsinghua.edu.cn/\">MADSys group</a> 小组的成员以及 <a href=\"http://approaching.ai/\">Approaching.AI</a> 的成员积极维护和开发。我们欢迎新的贡献者加入我们，使 KTransformers 更快、更易于使用。\n\n\n<h2 id=\"ack\">讨论</h2>\n\n如果您有任何问题，欢迎随时提出 issue。或者，您可以加入我们的微信群进行进一步讨论。二维码： [微信群](WeChatGroup.png)\n\n<h2 id=\"FAQ\">🙋 常见问题</h2>\n\n一些常见问题的答案可以在 [FAQ](doc/en/FAQ.md) 中找到。 \n"
  },
  {
    "path": "archive/SECURITY.md",
    "content": "# Security Policy\n\n## Supported Versions\n\nUse this section to tell people about which versions of your project are\ncurrently being supported with security updates.\n\n| Version | Supported          |\n| ------- | ------------------ |\n| 5.1.x   | :white_check_mark: |\n| 5.0.x   | :x:                |\n| 4.0.x   | :white_check_mark: |\n| < 4.0   | :x:                |\n\n## Reporting a Vulnerability\n\nUse this section to tell people how to report a vulnerability.\n\nTell them where to go, how often they can expect to get an update on a\nreported vulnerability, what to expect if the vulnerability is accepted or\ndeclined, etc.\n"
  },
  {
    "path": "archive/book.toml",
    "content": "[book]\nauthors = [\"kvcache-ai\"]\nlanguage = \"zh-CN\"\ntitle = \"Ktransformers\"\nsrc = \"doc\"\n\n[output.html]\ngit-repository-url = \"https://github.com/kvcache-ai/ktransformers\"\nedit-url-template = \"https://github.com/kvcache-ai/ktransformers/edit/main/{path}\"\n\n[output.html.playground]\neditable = true\ncopy-js = true\n# line-numbers = true\n\n[output.html.fold]\nenable = true\nlevel = 0"
  },
  {
    "path": "archive/config.json",
    "content": ""
  },
  {
    "path": "archive/csrc/balance_serve/CMakeLists.txt",
    "content": "option(KTRANSFORMERS_USE_NPU                 \"ktransformers: use NPU\"                           OFF)\nif(KTRANSFORMERS_USE_NPU)\n    add_definitions(-DKTRANSFORMERS_USE_NPU=1)\nendif()\n\nif(KTRANSFORMERS_USE_NPU)\n    set(ASCEND_HOME_PATH \"$ENV{ASCEND_HOME_PATH}\")\n    message(STATUS \"ASCEND_HOME_PATH is ${ASCEND_HOME_PATH}\")\n    include_directories(${ASCEND_HOME_PATH}/include)\n    \n    link_directories(${TORCH_INSTALL_PREFIX}/../torch.libs)\n    # find torch_npu\n    execute_process(\n            COMMAND python -c \"import torch; import torch_npu; print(torch_npu.__path__[0])\"\n            OUTPUT_VARIABLE TORCH_NPU_PATH\n            OUTPUT_STRIP_TRAILING_WHITESPACE\n    )\n    message(STATUS \"Found PTA at: ${TORCH_NPU_PATH}\")\n    find_library(PTA_LIBRARY torch_npu PATH \"${TORCH_NPU_PATH}/lib\")\nendif()\n\ncmake_minimum_required(VERSION 3.21)\nfind_program(GCC_COMPILER NAMES g++-13 g++-12 g++-11 g++ REQUIRED)\nset(CMAKE_CXX_COMPILER ${GCC_COMPILER})\n\n# 显示选定的编译器\nmessage(STATUS \"Using compiler: ${CMAKE_CXX_COMPILER}\")\n\n\nproject(balance_serve VERSION 0.1.0)\n\nset(CMAKE_CXX_STANDARD 20)\nset(CMAKE_CXX_FLAGS \"-Og -march=native -Wall -Wextra -g -fPIC\")\nset(CMAKE_BUILD_TYPE \"Debug\")\n# set(CMAKE_CXX_FLAGS \"-O3 -march=native -Wall -Wextra -fPIC\")\n# set(CMAKE_BUILD_TYPE \"Release\")\n\n\nif(NOT DEFINED _GLIBCXX_USE_CXX11_ABI)\n    find_package(Python3 REQUIRED COMPONENTS Interpreter)\n\n    execute_process(\n        COMMAND ${Python3_EXECUTABLE} -c\n        \"import torch; print('1' if torch.compiled_with_cxx11_abi() else '0')\"\n        OUTPUT_VARIABLE ABI_FLAG\n        OUTPUT_STRIP_TRAILING_WHITESPACE\n    )\n\n    set(_GLIBCXX_USE_CXX11_ABI ${ABI_FLAG} CACHE STRING \"C++11 ABI setting from PyTorch\" FORCE)\nendif()\n\n# 无论是否是自动检测，都传给编译器\nadd_compile_definitions(_GLIBCXX_USE_CXX11_ABI=${_GLIBCXX_USE_CXX11_ABI})\n\nmessage(STATUS \"_GLIBCXX_USE_CXX11_ABI=${_GLIBCXX_USE_CXX11_ABI}\")\n\nfile(GLOB_RECURSE FMT_SOURCES 
\"${CMAKE_CURRENT_SOURCE_DIR}/*.cpp\" \"${CMAKE_CURRENT_SOURCE_DIR}/*.hpp\" \"${CMAKE_CURRENT_SOURCE_DIR}/*.h\")\n\nadd_custom_target(\n    format\n    COMMAND clang-format\n    -i\n    -style=file\n    ${FMT_SOURCES}\n    COMMENT \"Running clang-format on all source files\"\n)\n\nset(BUILD_SHARED_LIBS ON)\nset(ENABLE_PUSH OFF)\nset(ENABLE_COMPRESSION OFF)\n\n# set(CMAKE_BUILD_TYPE \"Release\")\nset(CMAKE_EXPORT_COMPILE_COMMANDS ON)\n\nset(THIRD_PARTY_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../third_party)\nset(THIRD_PARTY_BUILD_DIR ${CMAKE_CURRENT_BINARY_DIR}/third_party)\nadd_subdirectory(${THIRD_PARTY_DIR}/prometheus-cpp ${THIRD_PARTY_BUILD_DIR}/prometheus-cpp EXCLUDE_FROM_ALL)\nadd_subdirectory(${THIRD_PARTY_DIR}/xxHash/cmake_unofficial ${THIRD_PARTY_BUILD_DIR}/xxHash EXCLUDE_FROM_ALL)\nset_target_properties(xxhash PROPERTIES POSITION_INDEPENDENT_CODE ON)\n\n# add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third_party/prometheus-cpp ${CMAKE_CURRENT_BINARY_DIR}/third_party/prometheus-cpp)\nset(SPDLOG_DIR ${THIRD_PARTY_DIR}/spdlog)\nset(FMT_DIR ${THIRD_PARTY_DIR}/fmt)\n\nset(KVC2_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/kvc2/src)\n\ninclude_directories(${THIRD_PARTY_DIR})\n\nadd_subdirectory(${THIRD_PARTY_DIR}/pybind11 ${THIRD_PARTY_BUILD_DIR}/pybind11)\n\nexecute_process(\n    COMMAND python3 -c \"import torch; print(torch.__path__[0])\"\n    OUTPUT_VARIABLE TORCH_INSTALL_PREFIX\n    OUTPUT_STRIP_TRAILING_WHITESPACE\n)\n\nmessage(STATUS \"Found PyTorch at: ${TORCH_INSTALL_PREFIX}\")\n\n# set(TORCH_INSTALL_PREFIX \"/home/xwy/.conda/envs/kvc/lib/python3.12/site-packages/torch\")\nfind_library(TORCH_PYTHON_LIBRARY torch_python PATH \"${TORCH_INSTALL_PREFIX}/lib\")\nfind_package(Torch REQUIRED PATHS \"${TORCH_INSTALL_PREFIX}/share/cmake/Torch\" NO_DEFAULT_PATH)\n\nadd_subdirectory(kvc2)\nadd_subdirectory(sched)\n\n# add_subdirectory(test)\n"
  },
  {
    "path": "archive/csrc/custom_marlin/__init__.py",
    "content": ""
  },
  {
    "path": "archive/csrc/custom_marlin/binding.cpp",
    "content": "/**\n * @Description  :\n * @Author       : Azure-Tang\n * @Date         : 2024-07-25 13:38:30\n * @Version      : 1.0.0\n * @LastEditors  : kkk1nak0\n * @LastEditTime : 2024-08-12 03:05:04\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n\n#include \"gptq_marlin/ops.h\"\n// Python bindings\n#include <pybind11/pybind11.h>\n#include <pybind11/stl.h>\n#include <torch/extension.h>\n#include <torch/library.h>\n#include <torch/torch.h>\n// namespace py = pybind11;\n\nPYBIND11_MODULE(vLLMMarlin, m) {\n\n    /*m.def(\"dequantize_q8_0\", &dequantize_q8_0, \"Function to dequantize q8_0\n    data.\", py::arg(\"data\"), py::arg(\"blk_size\"), py::arg(\"device\"));\n    m.def(\"dequantize_q6_k\", &dequantize_q6_k, \"Function to dequantize q6_k\n    data.\", py::arg(\"data\"), py::arg(\"blk_size\"), py::arg(\"device\"));\n    m.def(\"dequantize_q5_k\", &dequantize_q5_k, \"Function to dequantize q5_k\n    data.\", py::arg(\"data\"), py::arg(\"blk_size\"), py::arg(\"device\"));\n    m.def(\"dequantize_q4_k\",  &dequantize_q4_k, \"Function to dequantize q4_k\n    data.\", py::arg(\"data\"), py::arg(\"blk_size\"), py::arg(\"device\"));\n    m.def(\"dequantize_q3_k\",  &dequantize_q3_k, \"Function to dequantize q3_k\n    data.\", py::arg(\"data\"), py::arg(\"blk_size\"), py::arg(\"device\"));\n    m.def(\"dequantize_q2_k\",  &dequantize_q2_k, \"Function to dequantize q2_k\n    data.\", py::arg(\"data\"), py::arg(\"blk_size\"), py::arg(\"device\"));\n    m.def(\"dequantize_iq4_xs\",  &dequantize_iq4_xs, \"Function to dequantize\n    iq4_xs data.\", py::arg(\"data\"), py::arg(\"blk_size\"), py::arg(\"device\"));*/\n    m.def(\"gptq_marlin_gemm\", &gptq_marlin_gemm,\n          \"Function to perform GEMM using Marlin quantization.\", py::arg(\"a\"),\n          py::arg(\"b_q_weight\"), py::arg(\"b_scales\"), py::arg(\"g_idx\"),\n          py::arg(\"perm\"), py::arg(\"workspace\"), py::arg(\"num_bits\"), py::arg(\"size_m_tensor\"),\n          
py::arg(\"size_m\"), py::arg(\"size_n\"), py::arg(\"size_k\"),\n          py::arg(\"sms\"), py::arg(\"is_k_full\"));\n    m.def(\"gptq_marlin_repack\", &gptq_marlin_repack,\n            \"gptq_marlin repack from GPTQ\");\n}"
  },
  {
    "path": "archive/csrc/custom_marlin/gptq_marlin/gptq_marlin.cu",
    "content": "/*\n * Modified by Neural Magic\n * Copyright (C) Marlin.2024 Elias Frantar\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *         http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n /*\n  * Adapted from https://github.com/IST-DASLab/marlin\n  */\n  /*\n   * Adapted from\n   * https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin\n   */\n#include \"gptq_marlin.cuh\"\n#include \"gptq_marlin_dtypes.cuh\"\n#include <c10/cuda/CUDAGuard.h>\n#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t)                              \\\n    static_assert(std::is_same<scalar_t, half>::value ||                       \\\n                      std::is_same<scalar_t, nv_bfloat16>::value,              \\\n                  \"only float16 and bfloat16 is supported\");\n\ntemplate <typename T> inline std::string str(T x) { return std::to_string(x); }\n\nnamespace gptq_marlin {\n\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n\n    __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,\n        int const* __restrict__ perm_int_ptr,\n        int4* __restrict__ out_int4_ptr, int size_m,\n        int size_k, int block_rows) {}\n\n    template <typename scalar_t,         // compute dtype, half or nv_float16\n        const int num_bits,        // number of bits used for weights\n        const int threads,         // number of threads in a threadblock\n        const int thread_m_blocks, // number of 16x16 blocks in the m\n        // dimension (batchsize) of the\n      
  // threadblock\n        const int thread_n_blocks, // same for n dimension (output)\n        const int thread_k_blocks, // same for k dimension (reduction)\n        const int stages, // number of stages for the async global->shared\n        // fetch pipeline\n        const bool has_act_order,   // whether act_order is enabled\n        const int group_blocks = -1 // number of consecutive 16x16 blocks\n        // with a separate quantization scale\n    >\n    __global__ void\n        Marlin(const int4* __restrict__ A, // fp16 input matrix of shape mxk\n            const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn\n            int4* __restrict__ C,       // fp16 output buffer of shape mxn\n            const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape\n            // (k/groupsize)xn\n            const int* __restrict__ g_idx, // int32 group indices of shape k\n            int num_groups, // number of scale groups per output channel\n            int prob_m,     // batch dimension m\n            int prob_n,     // output dimension n\n            int prob_k,     // reduction dimension k\n            int* locks      // extra global storage for barrier synchronization\n        ) {}\n\n} // namespace gptq_marlin\n\ntorch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,\n    torch::Tensor& b_scales, torch::Tensor& g_idx,\n    torch::Tensor& perm, torch::Tensor& workspace,\n    int64_t num_bits, int64_t size_m, int64_t size_n,\n    int64_t size_k, bool is_k_full) {\n    TORCH_CHECK_NOT_IMPLEMENTED(false,\n        \"marlin_gemm(..) 
requires CUDA_ARCH >= 8.0\");\n    return torch::empty({ 1, 1 });\n}\n\n#else\n\n    // m16n8k16 tensor core mma instruction with fp16 inputs and fp32\n    // output/accumulation.\n    template <typename scalar_t>\n    __device__ inline void mma(const typename ScalarType<scalar_t>::FragA& a_frag,\n        const typename ScalarType<scalar_t>::FragB& frag_b,\n        typename ScalarType<scalar_t>::FragC& frag_c) {\n        const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);\n        const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);\n        float* c = reinterpret_cast<float*>(&frag_c);\n        if constexpr (std::is_same<scalar_t, half>::value) {\n            asm volatile(\n                \"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \"\n                \"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\\n\"\n                : \"=f\"(c[0]), \"=f\"(c[1]), \"=f\"(c[2]), \"=f\"(c[3])\n                : \"r\"(a[0]), \"r\"(a[1]), \"r\"(a[2]), \"r\"(a[3]), \"r\"(b[0]), \"r\"(b[1]),\n                \"f\"(c[0]), \"f\"(c[1]), \"f\"(c[2]), \"f\"(c[3]));\n        }\n        else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {\n            asm volatile(\n                \"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \"\n                \"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\\n\"\n                : \"=f\"(c[0]), \"=f\"(c[1]), \"=f\"(c[2]), \"=f\"(c[3])\n                : \"r\"(a[0]), \"r\"(a[1]), \"r\"(a[2]), \"r\"(a[3]), \"r\"(b[0]), \"r\"(b[1]),\n                \"f\"(c[0]), \"f\"(c[1]), \"f\"(c[2]), \"f\"(c[3]));\n        }\n        else {\n            STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);\n        }\n    }\n\n    // Instruction for loading a full 16x16 matrix fragment of operand A from shared\n    // memory, directly in tensor core layout.\n    template <typename scalar_t>\n    __device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA& frag_a,\n        const void* smem_ptr) {\n     
   uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);\n        uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));\n        asm volatile(\n            \"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\\n\"\n            : \"=r\"(a[0]), \"=r\"(a[1]), \"=r\"(a[2]), \"=r\"(a[3])\n            : \"r\"(smem));\n    }\n\n    // Lookup-table based 3-input logical operation; explicitly used for\n    // dequantization as the compiler does not seem to automatically recognize it in\n    // all cases.\n    template <int lut> __device__ inline int lop3(int a, int b, int c) {\n        int res;\n        asm volatile(\"lop3.b32 %0, %1, %2, %3, %4;\\n\"\n            : \"=r\"(res)\n            : \"r\"(a), \"r\"(b), \"r\"(c), \"n\"(lut));\n        return res;\n    }\n\n    // Constructs destination register by taking bytes from 2 sources (based on\n    // mask)\n    template <int start_byte, int mask>\n    __device__ inline uint32_t prmt(uint32_t a) {\n        uint32_t res;\n        asm volatile(\"prmt.b32 %0, %1, %2, %3;\\n\"\n            : \"=r\"(res)\n            : \"r\"(a), \"n\"(start_byte), \"n\"(mask));\n        return res;\n    }\n\n    // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16\n    // values. 
We mostly follow the strategy in the link below, with some small\n    // changes:\n    // - FP16:\n    // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287\n    // - BF16:\n    // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385\n    template <typename scalar_t>\n    __device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit(int q) {\n        STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);\n    }\n\n    template <>\n    __device__ inline typename ScalarType<half>::FragB dequant_4bit<half>(int q) {\n        const int LO = 0x000f000f;\n        const int HI = 0x00f000f0;\n        const int EX = 0x64006400;\n        // Guarantee that the `(a & b) | c` operations are LOP3s.\n        int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);\n        int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);\n        // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point\n        // directly into `SUB` and `ADD`.\n        const int SUB = 0x64086408;\n        const int MUL = 0x2c002c00;\n        const int ADD = 0xd480d480;\n        typename ScalarType<half>::FragB frag_b;\n        frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),\n            *reinterpret_cast<const half2*>(&SUB));\n        frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),\n            *reinterpret_cast<const half2*>(&MUL),\n            *reinterpret_cast<const half2*>(&ADD));\n        return frag_b;\n    }\n\n    template <>\n    __device__ inline typename ScalarType<nv_bfloat16>::FragB\n        dequant_4bit<nv_bfloat16>(int q) {\n        static constexpr uint32_t MASK = 0x000f000f;\n        static constexpr uint32_t EX = 0x43004300;\n\n        // Guarantee that the `(a & b) | c` operations are LOP3s.\n\n        int lo = lop3<(0xf0 & 0xcc) | 
0xaa>(q, MASK, EX);\n        q >>= 4;\n        int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);\n\n        typename ScalarType<nv_bfloat16>::FragB frag_b;\n        static constexpr uint32_t MUL = 0x3F803F80;\n        static constexpr uint32_t ADD = 0xC308C308;\n\n        frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),\n            *reinterpret_cast<const nv_bfloat162*>(&MUL),\n            *reinterpret_cast<const nv_bfloat162*>(&ADD));\n        frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),\n            *reinterpret_cast<const nv_bfloat162*>(&MUL),\n            *reinterpret_cast<const nv_bfloat162*>(&ADD));\n        return frag_b;\n    }\n\n    // Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or\n    // bf16 Reference:\n    // - FP16:\n    // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85\n    // - BF16:\n    // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175\n    template <typename scalar_t>\n    __device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit(int q) {\n        STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);\n    }\n\n    template <>\n    __device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) {\n        static constexpr uint32_t mask_for_elt_01 = 0x5250;\n        static constexpr uint32_t mask_for_elt_23 = 0x5351;\n        static constexpr uint32_t start_byte_for_fp16 = 0x64646464;\n\n        uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);\n        uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);\n\n        static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;\n\n        typename ScalarType<half>::FragB frag_b;\n        frag_b[0] =\n            __hsub2(*reinterpret_cast<half2*>(&lo),\n        
        *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));\n        frag_b[1] =\n            __hsub2(*reinterpret_cast<half2*>(&hi),\n                *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));\n        return frag_b;\n    }\n\n    template <>\n    __device__ inline typename ScalarType<nv_bfloat16>::FragB\n        dequant_8bit<nv_bfloat16>(int q) {\n        typename ScalarType<nv_bfloat16>::FragB frag_b;\n\n        float fp32_intermediates[4];\n        uint32_t* fp32_intermediates_casted =\n            reinterpret_cast<uint32_t*>(fp32_intermediates);\n\n        static constexpr uint32_t fp32_base = 0x4B000000;\n        fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);\n        fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);\n        fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);\n        fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);\n\n        fp32_intermediates[0] -= 8388736.f;\n        fp32_intermediates[1] -= 8388736.f;\n        fp32_intermediates[2] -= 8388736.f;\n        fp32_intermediates[3] -= 8388736.f;\n\n        uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);\n        bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],\n            fp32_intermediates_casted[1], 0x7632);\n        bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],\n            fp32_intermediates_casted[3], 0x7632);\n\n        return frag_b;\n    }\n\n    // Multiply dequantized values by the corresponding quantization scale; used\n    // only for grouped quantization.\n    template <typename scalar_t>\n    __device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,\n        typename ScalarType<scalar_t>::FragS& frag_s,\n        int i) {\n        using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;\n        scalar_t2 s = ScalarType<scalar_t>::num2num2(\n            reinterpret_cast<scalar_t*>(&frag_s)[i]);\n        frag_b[0] = 
__hmul2(frag_b[0], s);\n        frag_b[1] = __hmul2(frag_b[1], s);\n    }\n\n    // Same as above, but for act_order (each K is multiplied individually)\n    template <typename scalar_t>\n    __device__ inline void scale4(typename ScalarType<scalar_t>::FragB& frag_b,\n        typename ScalarType<scalar_t>::FragS& frag_s_1,\n        typename ScalarType<scalar_t>::FragS& frag_s_2,\n        typename ScalarType<scalar_t>::FragS& frag_s_3,\n        typename ScalarType<scalar_t>::FragS& frag_s_4,\n        int i) {\n        using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;\n        scalar_t2 s_val_1_2;\n        s_val_1_2.x = reinterpret_cast<scalar_t*>(&frag_s_1)[i];\n        s_val_1_2.y = reinterpret_cast<scalar_t*>(&frag_s_2)[i];\n\n        scalar_t2 s_val_3_4;\n        s_val_3_4.x = reinterpret_cast<scalar_t*>(&frag_s_3)[i];\n        s_val_3_4.y = reinterpret_cast<scalar_t*>(&frag_s_4)[i];\n\n        frag_b[0] = __hmul2(frag_b[0], s_val_1_2);\n        frag_b[1] = __hmul2(frag_b[1], s_val_3_4);\n    }\n\n    // Given 2 floats multiply by 2 scales (halves)\n    template <typename scalar_t>\n    __device__ inline void scale_float(float* c,\n        typename ScalarType<scalar_t>::FragS& s) {\n        scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s);\n        c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0]));\n        c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1]));\n    }\n\n    // Wait until barrier reaches `count`, then lock for current threadblock.\n    __device__ inline void barrier_acquire(int* lock, int count) {\n        if (threadIdx.x == 0) {\n            int state = -1;\n            do\n                // Guarantee that subsequent writes by this threadblock will be\n                // visible globally.\n                asm volatile(\"ld.global.acquire.gpu.b32 %0, [%1];\\n\"\n                    : \"=r\"(state)\n                    : \"l\"(lock));\n            while (state != count);\n        }\n        
__syncthreads();\n    }\n\n    // Release barrier and increment visitation count.\n    __device__ inline void barrier_release(int* lock, bool reset = false) {\n        __syncthreads();\n        if (threadIdx.x == 0) {\n            if (reset) {\n                lock[0] = 0;\n                return;\n            }\n            int val = 1;\n            // Make sure that all writes since acquiring this barrier are visible\n            // globally, while releasing the barrier.\n            asm volatile(\"fence.acq_rel.gpu;\\n\");\n            asm volatile(\"red.relaxed.gpu.global.add.s32 [%0], %1;\\n\"\n                :\n            : \"l\"(lock), \"r\"(val));\n        }\n    }\n\n    // For a given \"a\" of size [M,K] performs a permutation of the K columns based\n    // on the given \"perm\" indices.\n    __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,\n        int const* __restrict__ perm_int_ptr,\n        int4* __restrict__ out_int4_ptr, int size_m,\n        int size_k, int block_rows) {\n        int start_row = block_rows * blockIdx.x;\n        int finish_row = start_row + block_rows;\n        if (finish_row > size_m) {\n            finish_row = size_m;\n        }\n        int cur_block_rows = finish_row - start_row;\n\n        int row_stride = size_k * sizeof(half) / 16;\n\n        auto permute_row = [&](int row) {\n            int iters = size_k / default_threads;\n            int rest = size_k % default_threads;\n\n            int offset = row * row_stride;\n\n            half const* a_row_half =\n                reinterpret_cast<half const*>(a_int4_ptr + offset);\n            half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);\n\n            int base_k = 0;\n\n            for (int i = 0; i < iters; i++) {\n                int cur_k = base_k + threadIdx.x;\n                int src_pos = perm_int_ptr[cur_k];\n\n                out_half[cur_k] = a_row_half[src_pos];\n\n                base_k += default_threads;\n            
}\n\n            if (rest) {\n                if (threadIdx.x < rest) {\n                    int cur_k = base_k + threadIdx.x;\n                    int src_pos = perm_int_ptr[cur_k];\n\n                    out_half[cur_k] = a_row_half[src_pos];\n                }\n            }\n            };\n\n        for (int i = 0; i < cur_block_rows; i++) {\n            int cur_row = start_row + i;\n            if (cur_row < size_m) {\n                permute_row(cur_row);\n            }\n        }\n    }\n\n    template <typename scalar_t,         // compute dtype, half or nv_bfloat16\n        const int num_bits,        // number of bits used for weights\n        const int threads,         // number of threads in a threadblock\n        const int thread_m_blocks, // number of 16x16 blocks in the m\n        // dimension (batchsize) of the\n        // threadblock\n        const int thread_n_blocks, // same for n dimension (output)\n        const int thread_k_blocks, // same for k dimension (reduction)\n        const int stages, // number of stages for the async global->shared\n        // fetch pipeline\n        const bool has_act_order,   // whether act_order is enabled\n        const int group_blocks = -1 // number of consecutive 16x16 blocks\n        // with a separate quantization scale\n    >\n    __device__ void\n        Marlin(const int4* __restrict__ A, // fp16 input matrix of shape mxk\n            const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn\n            int4* __restrict__ C,       // fp16 output buffer of shape mxn\n            const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape\n            // (k/groupsize)xn\n            const int* __restrict__ g_idx, // int32 group indices of shape k\n            int num_groups, // number of scale groups per output channel\n            int prob_m,     // batch dimension m, should be divisible by (16 * thread_m_blocks) if bigger than that\n            int prob_n,     // output 
dimension n\n            int prob_k,     // reduction dimension k\n            int* locks      // extra global storage for barrier synchronization\n        ) {\n        // Each threadblock processes one \"stripe\" of the B matrix with (roughly) the\n        // same size, which might involve multiple column \"slices\" (of width 16 *\n        // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM\n        // example:\n        //   0 1 3\n        //   0 2 3\n        //   1 2 4\n        // While this kind of partitioning makes things somewhat more complicated, it\n        // ensures good utilization of all SMs for many kinds of shape and GPU\n        // configurations, while requiring as few slow global cross-threadblock\n        // reductions as possible.\n        using Dtype = ScalarType<scalar_t>;\n        using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;\n        using FragA = typename ScalarType<scalar_t>::FragA;\n        using FragB = typename ScalarType<scalar_t>::FragB;\n        using FragC = typename ScalarType<scalar_t>::FragC;\n        using FragS = typename ScalarType<scalar_t>::FragS;\n\n        constexpr int pack_factor = 32 / num_bits;\n\n        // int prob_m = *prob_m_ptr;\n        // const int thread_m_blocks = min(div_ceil(prob_m, 16), template_thread_m_blocks);\n        // constexpr int thread_m_blocks = template_thread_m_blocks;\n\n        // For larger GEMMs we run multiple batchsize 64 versions in parallel for a\n        // better partitioning with less reductions\n        int parallel = 1;\n        if (prob_m > 16 * thread_m_blocks) {\n            parallel = prob_m / (16 * thread_m_blocks);\n            prob_m = 16 * thread_m_blocks;\n        }\n\n        int k_tiles = prob_k / 16 / thread_k_blocks;\n        int n_tiles = prob_n / 16 / thread_n_blocks;\n        int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x);\n\n        if constexpr (!has_act_order && group_blocks != -1) {\n            if (group_blocks 
>= thread_k_blocks) {\n                // Ensure that the number of tiles in each stripe is a multiple of the\n                // groupsize; this avoids an annoying special case where a stripe starts\n                // in the middle of group.\n                iters = (group_blocks / thread_k_blocks) *\n                    div_ceil(iters, (group_blocks / thread_k_blocks));\n            }\n        }\n\n        int slice_row = (iters * blockIdx.x) % k_tiles;\n        int slice_col_par = (iters * blockIdx.x) / k_tiles;\n        int slice_col = slice_col_par;\n        int slice_iters;  // number of threadblock tiles in the current slice\n        int slice_count =\n            0;          // total number of active threadblocks in the current slice\n        int slice_idx;  // index of threadblock in current slice; numbered bottom to\n        // top\n\n    // We can easily implement parallel problem execution by just remapping\n    // indices and advancing global pointers\n        if (slice_col_par >= n_tiles) {\n            A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;\n            C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;\n            locks += (slice_col_par / n_tiles) * n_tiles;\n            slice_col = slice_col_par % n_tiles;\n        }\n\n        // Compute all information about the current slice which is required for\n        // synchronization.\n        auto init_slice = [&]() {\n            slice_iters =\n                iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);\n            if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;\n            if (slice_iters == 0) return;\n            if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;\n            slice_count = 1;\n            slice_idx = 0;\n            int col_first = iters * div_ceil(k_tiles * slice_col_par, iters);\n            if (col_first <= k_tiles * (slice_col_par + 1)) {\n                
int col_off = col_first - k_tiles * slice_col_par;\n                slice_count = div_ceil(k_tiles - col_off, iters);\n                if (col_off > 0) slice_count++;\n                int delta_first = iters * blockIdx.x - col_first;\n                if (delta_first < 0 || (col_off == 0 && delta_first == 0))\n                    slice_idx = slice_count - 1;\n                else {\n                    slice_idx = slice_count - 1 - delta_first / iters;\n                    if (col_off > 0) slice_idx--;\n                }\n            }\n            if (slice_col == n_tiles) {\n                A += 16 * thread_m_blocks * prob_k / 8;\n                C += 16 * thread_m_blocks * prob_n / 8;\n                locks += n_tiles;\n                slice_col = 0;\n            }\n            };\n        init_slice();\n\n        // A sizes/strides\n\n        // stride of the A matrix in global memory\n        int a_gl_stride = prob_k / 8;\n        // stride of an A matrix tile in shared memory\n        constexpr int a_sh_stride = 16 * thread_k_blocks / 8;\n        // delta between subsequent A tiles in global memory\n        constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8;\n        // between subsequent accesses within a tile\n        int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);\n        // between shared memory writes\n        constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);\n        // between shared memory tile reads\n        constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));\n        // within a shared memory tile\n        constexpr int a_sh_rd_delta_i = a_sh_stride * 16;\n        // overall size of a tile\n        constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);\n        // number of shared write iterations for a tile\n        constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta);\n\n        // B sizes/strides\n        int b_gl_stride = 16 * prob_n / (pack_factor * 
4);\n        constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;\n        constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2;\n        constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;\n\n        int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;\n        int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);\n        constexpr int b_sh_wr_delta = threads * b_thread_vecs;\n        constexpr int b_sh_rd_delta = threads * b_thread_vecs;\n        constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;\n        constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;\n\n        // Scale sizes/strides without act_order\n        int s_gl_stride = prob_n / 8;\n        constexpr int s_sh_stride = 16 * thread_n_blocks / 8;\n        constexpr int s_tb_groups =\n            !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks\n            ? thread_k_blocks / group_blocks\n            : 1;\n        constexpr int s_sh_stage = s_tb_groups * s_sh_stride;\n        int s_gl_rd_delta = s_gl_stride;\n\n        // Scale size/strides with act_order\n        constexpr int tb_k = 16 * thread_k_blocks;\n        constexpr int g_idx_stage = has_act_order ? 
(tb_k * sizeof(int)) / 16 : 0;\n        // constexpr int act_s_row_stride      = 1;\n        // int           act_s_col_stride      = act_s_row_stride * num_groups;\n        int act_s_col_stride = 1;\n        int act_s_col_warp_stride = act_s_col_stride * 8;\n        int tb_n_warps = thread_n_blocks / 4;\n        int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;\n\n        // Global A read index of current thread.\n        int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +\n            (threadIdx.x % a_gl_rd_delta_o);\n        a_gl_rd += a_gl_rd_delta_o * slice_row;\n        // Shared write index of current thread.\n        int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +\n            (threadIdx.x % a_gl_rd_delta_o);\n        // Shared read index.\n        int a_sh_rd =\n            a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;\n        a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));\n\n        int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +\n            (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;\n        b_gl_rd += b_sh_stride * slice_col;\n        b_gl_rd += b_gl_rd_delta_o * slice_row;\n        int b_sh_wr = threadIdx.x * b_thread_vecs;\n        int b_sh_rd = threadIdx.x * b_thread_vecs;\n\n        // For act_order\n        constexpr int k_iter_size = tb_k / b_sh_wr_iters;\n        int slice_k_start = tb_k * slice_row;\n        int slice_k_finish = slice_k_start + tb_k * slice_iters;\n        int slice_k_start_shared_fetch = slice_k_start;\n        int slice_n_offset = act_s_col_tb_stride * slice_col;\n\n        // No act_order\n        int s_gl_rd;\n        if constexpr (!has_act_order) {\n            if constexpr (group_blocks == -1) {\n                s_gl_rd = s_sh_stride * slice_col + threadIdx.x;\n            }\n            else {\n                s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +\n                    s_sh_stride * 
slice_col + threadIdx.x;\n            }\n        }\n        int s_sh_wr = threadIdx.x;\n        bool s_sh_wr_pred = threadIdx.x < s_sh_stride;\n\n        // We use a different scale layout for grouped and column-wise quantization as\n        // we scale a `half2` tile in column-major layout in the former and in\n        // row-major in the latter case.\n        int s_sh_rd;\n        if constexpr (group_blocks != -1)\n            s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +\n            (threadIdx.x % 32) / 4;\n        else\n            s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +\n            (threadIdx.x % 32) % 4;\n\n        // Precompute which thread should not read memory in which iterations; this is\n        // needed if there are more threads than required for a certain tilesize or\n        // when the batchsize is not a multiple of 16.\n        bool a_sh_wr_pred[a_sh_wr_iters];\n#pragma unroll\n        for (int i = 0; i < a_sh_wr_iters; i++) {\n            a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;\n        }\n\n        // To ensure that writing and reading A tiles to/from shared memory, the\n        // latter in fragment format, is fully bank conflict free, we need to use a\n        // rather fancy XOR-based layout. The key here is that neither reads nor\n        // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the\n        // same shared memory banks. 
Further, it seems (based on NSight-Compute) that\n        // each warp must also write a consecutive memory segment?\n        auto transform_a = [&](int i) {\n            int row = i / a_gl_rd_delta_o;\n            return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;\n            };\n        // Since the computation of this remapping is non-trivial and, due to our main\n        // loop unrolls, all shared memory accesses are static, we simply precompute\n        // both transformed reads and writes.\n        int a_sh_wr_trans[a_sh_wr_iters];\n#pragma unroll\n        for (int i = 0; i < a_sh_wr_iters; i++) {\n            a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);\n        }\n        int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];\n#pragma unroll\n        for (int i = 0; i < b_sh_wr_iters; i++) {\n#pragma unroll\n            for (int j = 0; j < thread_m_blocks; j++)\n            {\n                a_sh_rd_trans[i][j] =\n                    transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);\n            }\n        }\n\n        // Since B-accesses have non-constant stride they have to be computed at\n        // runtime; we break dependencies between subsequent accesses within a tile by\n        // maintaining multiple pointers (we have enough registers), a tiny\n        // optimization.\n        const int4* B_ptr[b_sh_wr_iters];\n#pragma unroll\n        for (int i = 0; i < b_sh_wr_iters; i++)\n            B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;\n\n        extern __shared__ int4 sh[];\n        // Shared memory storage for global fetch pipelines.\n        int4* sh_a = sh;\n        int4* sh_b = sh_a + (stages * a_sh_stage);\n        int4* sh_g_idx = sh_b + (stages * b_sh_stage);\n        int4* sh_s = sh_g_idx + (stages * g_idx_stage);\n\n        // Register storage for double buffer of shared memory reads.\n        FragA frag_a[2][thread_m_blocks];\n        I4 frag_b_quant[2][b_thread_vecs];\n        FragC 
frag_c[thread_m_blocks][4][2];\n        FragS frag_s[2][4];         // No act-order\n        FragS act_frag_s[2][4][4];  // For act-order\n\n        // Zero accumulators.\n        auto zero_accums = [&]() {\n#pragma unroll\n            for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)\n            {\n                reinterpret_cast<float*>(frag_c)[i] = 0;\n            }\n            };\n\n        int sh_first_group_id = -1;\n        int sh_num_groups = -1;\n        constexpr int sh_max_num_groups = 32;\n\n        auto fetch_scales_to_shared = [&](bool is_async, int first_group_id,\n            int last_group_id) {\n                sh_first_group_id = first_group_id;\n                sh_num_groups = last_group_id - first_group_id + 1;\n\n                if (sh_num_groups < sh_max_num_groups) {\n                    sh_num_groups = sh_max_num_groups;\n                }\n\n                if (sh_first_group_id + sh_num_groups > num_groups) {\n                    sh_num_groups = num_groups - sh_first_group_id;\n                }\n\n                int row_offset = first_group_id * s_gl_stride;\n\n                if (is_async) {\n                    for (int i = 0; i < sh_num_groups; i++) {\n                        if (threadIdx.x < s_sh_stride) {\n                            cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x],\n                                &scales_ptr[row_offset + (i * s_gl_stride) +\n                                slice_n_offset + threadIdx.x]);\n                        }\n                    }\n                }\n                else {\n                    for (int i = 0; i < sh_num_groups; i++) {\n                        if (threadIdx.x < s_sh_stride) {\n                            sh_s[(i * s_sh_stride) + threadIdx.x] =\n                                scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset +\n                                threadIdx.x];\n                        }\n                    }\n                }\n            
};\n        // Asynchronously fetch the next A, B and s tile from global to the next\n        // shared memory pipeline location.\n        auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {\n            if (pred) {\n                int4* sh_a_stage = sh_a + a_sh_stage * pipe;\n#pragma unroll\n                for (int i = 0; i < a_sh_wr_iters; i++) {\n                    cp_async4_pred(\n                        &sh_a_stage[a_sh_wr_trans[i]],\n                        &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],\n                        a_sh_wr_pred[i]);\n                }\n                int4* sh_b_stage = sh_b + b_sh_stage * pipe;\n#pragma unroll\n                for (int i = 0; i < b_sh_wr_iters; i++) {\n#pragma unroll\n                    for (int j = 0; j < b_thread_vecs; j++) {\n                        cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);\n                    }\n\n                    B_ptr[i] += b_gl_rd_delta_o;\n                }\n\n                if constexpr (has_act_order) {\n                    // Fetch g_idx thread-block portion\n                    int full_pipe = a_off;\n                    int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe;\n                    if (cur_k < prob_k && cur_k < slice_k_finish) {\n                        int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;\n\n                        int4 const* cur_g_idx_stage_ptr =\n                            reinterpret_cast<int4 const*>(&g_idx[cur_k]);\n\n                        if (threadIdx.x < g_idx_stage) {\n                            cp_async4_pred(&sh_g_idx_stage[threadIdx.x],\n                                &cur_g_idx_stage_ptr[threadIdx.x]);\n                        }\n                    }\n                }\n                else {\n                    if constexpr (group_blocks != -1) {\n                        int4* sh_s_stage = sh_s + s_sh_stage * pipe;\n\n                        if constexpr 
(group_blocks >= thread_k_blocks) {\n                            // Only fetch scales if this tile starts a new group\n                            if (pipe % (group_blocks / thread_k_blocks) == 0) {\n                                if (s_sh_wr_pred) {\n                                    cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);\n                                }\n                                s_gl_rd += s_gl_rd_delta;\n                            }\n                        }\n                        else {\n                            for (int i = 0; i < s_tb_groups; i++) {\n                                if (s_sh_wr_pred) {\n                                    cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr],\n                                        &scales_ptr[s_gl_rd]);\n                                }\n                                s_gl_rd += s_gl_rd_delta;\n                            }\n                        }\n                    }\n                }\n            }\n            // Insert a fence even when we are winding down the pipeline to ensure that\n            // waiting is also correct at this point.\n            cp_async_fence();\n            };\n\n        // Wait until the next thread tile has been loaded to shared memory.\n        auto wait_for_stage = [&]() {\n            // We only have `stages - 2` active fetches since we are double buffering\n            // and can only issue the next fetch when it is guaranteed that the previous\n            // shared memory load is fully complete (as it may otherwise be\n            // overwritten).\n            cp_async_wait<stages - 2>();\n            __syncthreads();\n            };\n\n        // Load the next sub-tile from the current location in the shared memory pipe\n        // into the current register buffer.\n        auto fetch_to_registers = [&](int k, int pipe) {\n            int4* sh_a_stage = sh_a + a_sh_stage * pipe;\n#pragma unroll\n            for (int i = 0; i < 
thread_m_blocks; i++)\n            {\n                ldsm4<scalar_t>(frag_a[k % 2][i],\n                    &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);\n            }\n\n            int4* sh_b_stage = sh_b + b_sh_stage * pipe;\n\n#pragma unroll\n            for (int i = 0; i < b_thread_vecs; i++) {\n                frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(\n                    &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);\n            }\n            };\n\n        bool is_same_group[stages];\n        int same_group_id[stages];\n\n        auto init_same_group = [&](int pipe) {\n            if constexpr (!has_act_order) {\n                is_same_group[pipe] = false;\n                same_group_id[pipe] = 0;\n                return;\n            }\n\n            int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;\n            int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);\n\n            int group_id_1 = sh_g_idx_int_ptr[0];\n            int group_id_2 = sh_g_idx_int_ptr[tb_k - 1];\n\n            is_same_group[pipe] = group_id_1 == group_id_2;\n            same_group_id[pipe] = group_id_1;\n            };\n\n        auto fetch_scales_to_registers = [&](int k, int full_pipe) {\n            int pipe = full_pipe % stages;\n\n            if constexpr (!has_act_order) {\n                // No act-order case\n                if constexpr (group_blocks != -1) {\n                    if constexpr (group_blocks >= thread_k_blocks) {\n                        int4* sh_s_stage =\n                            sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *\n                                (pipe / (group_blocks / thread_k_blocks)));\n                        reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];\n                    }\n                    else {\n                        int warp_id = threadIdx.x / 32;\n                        int n_warps = thread_n_blocks / 4;\n\n                        int 
warp_row = warp_id / n_warps;\n\n                        int cur_k = warp_row * 16;\n                        cur_k += k_iter_size * (k % b_sh_wr_iters);\n\n                        int k_blocks = cur_k / 16;\n                        int cur_group_id = k_blocks / group_blocks;\n\n                        int4* sh_s_stage = sh_s + s_sh_stage * pipe;\n\n                        reinterpret_cast<int4*>(&frag_s[k % 2])[0] =\n                            sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];\n                    }\n                }\n\n                return;\n            }\n\n            // Act-order case\n\n            // Determine K of the \"current\" thread-block\n            int cur_k = slice_k_start + tb_k * full_pipe;\n            if (cur_k >= prob_k || cur_k >= slice_k_finish) {\n                return;\n            }\n\n            // Reset (to current thread-block) since we read g_idx portion from the\n            // shared memory\n            cur_k = 0;\n\n            // Progress to current iteration\n            cur_k += k_iter_size * (k % b_sh_wr_iters);\n\n            // Determine \"position\" inside the thread-block (based on warp and\n            // thread-id)\n            int warp_id = threadIdx.x / 32;\n            int n_warps =\n                thread_n_blocks / 4;  // Each warp processes 4 16-size tiles over N\n\n            int warp_row = warp_id / n_warps;\n            int warp_col = warp_id % n_warps;\n\n            cur_k += warp_row * 16;\n\n            int th_id = threadIdx.x % 32;\n            cur_k += (th_id % 4) * 2;  // Due to tensor-core layout for fp16 B matrix\n\n            int s_col_shift =\n                /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) +\n                (th_id / 4) * act_s_col_stride;\n\n            if (is_same_group[pipe]) {\n                if (k % 2 == 0) {\n                    *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =\n                        sh_s[(same_group_id[pipe] - 
sh_first_group_id) * s_sh_stride +\n                        s_col_shift];\n                }\n                else {\n                    *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =\n                        *(reinterpret_cast<int4*>(&(act_frag_s[(k - 1) % 2][0][0])));\n                }\n\n                for (int i = 1; i < 4; i++) {\n                    *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =\n                        *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0])));\n                }\n                return;\n            }\n\n            int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;\n            int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);\n\n            constexpr int k_frag_offsets[4] = { 0, 1, 8,\n                                               9 };  // Tensor core offsets per thread\n\n#pragma unroll\n            for (int i = 0; i < 4; i++) {\n                int actual_k = cur_k + k_frag_offsets[i];\n\n                int group_id = sh_g_idx_int_ptr[actual_k];\n                int rel_group_id = group_id - sh_first_group_id;\n\n                *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =\n                    sh_s[rel_group_id * s_sh_stride + s_col_shift];\n            }\n            };\n\n        // Execute the actual tensor core matmul of a sub-tile.\n        auto matmul = [&](int k) {\n            // We have the m dimension as the inner loop in order to encourage overlapping\n            // dequantization and matmul operations.\n#pragma unroll\n            for (int j = 0; j < 4; j++) {\n                FragB frag_b0;\n                FragB frag_b1;\n                if constexpr (num_bits == 4) {\n                    int b_quant = frag_b_quant[k % 2][0][j];\n                    int b_quant_shift = b_quant >> 8;\n\n                    frag_b0 = dequant_4bit<scalar_t>(b_quant);\n                    frag_b1 = dequant_4bit<scalar_t>(b_quant_shift);\n\n                }\n                
else {\n                    int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);\n                    int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];\n                    int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];\n\n                    frag_b0 = dequant_8bit<scalar_t>(b_quant_0);\n                    frag_b1 = dequant_8bit<scalar_t>(b_quant_1);\n                }\n\n                // Apply scale to frag_b0\n                if constexpr (has_act_order) {\n                    scale4<scalar_t>(frag_b0, act_frag_s[k % 2][0][j],\n                        act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j],\n                        act_frag_s[k % 2][3][j], 0);\n                }\n                else {\n                    if constexpr (group_blocks != -1) {\n                        scale<scalar_t>(frag_b0, frag_s[k % 2][j], 0);\n                    }\n                }\n\n                // Apply scale to frag_b1\n                if constexpr (has_act_order) {\n                    scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j],\n                        act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j],\n                        act_frag_s[k % 2][3][j], 1);\n\n                }\n                else {\n                    if constexpr (group_blocks != -1) {\n                        scale<scalar_t>(frag_b1, frag_s[k % 2][j], 1);\n                    }\n                }\n\n#pragma unroll\n                for (int i = 0; i < thread_m_blocks; i++) {\n                    mma<scalar_t>(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);\n                    mma<scalar_t>(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);\n                }\n            }\n            };\n\n        // Since we slice across the k dimension of a tile in order to increase the\n        // number of warps while keeping the n dimension of a tile reasonable, we have\n        // multiple warps that accumulate their partial sums of the same output\n        // location; which we have to reduce over in 
the end. We do this in shared memory.\n        auto thread_block_reduce = [&]() {\n            constexpr int red_off = threads / b_sh_stride_threads / 2;\n            if (red_off >= 1) {\n                int red_idx = threadIdx.x / b_sh_stride_threads;\n                constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;\n                constexpr int red_sh_delta = b_sh_stride_threads;\n                int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +\n                    (threadIdx.x % b_sh_stride_threads);\n\n                // Parallel logarithmic shared memory reduction. We make sure to avoid any\n                // unnecessary read or write iterations, e.g., for two warps we write only\n                // once by warp 1 and read only once by warp 0.\n\n#pragma unroll\n                for (int m_block = 0; m_block < thread_m_blocks; m_block++) {\n#pragma unroll\n                    for (int i = red_off; i > 0; i /= 2) {\n                        if (i <= red_idx && red_idx < 2 * i) {\n#pragma unroll\n                            for (int j = 0; j < 4 * 2; j++) {\n                                int red_sh_wr =\n                                    red_sh_delta * j + (red_sh_rd - red_sh_stride * i);\n                                if (i < red_off) {\n                                    float* c_rd =\n                                        reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);\n                                    float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);\n#pragma unroll\n                                    for (int k = 0; k < 4; k++)\n                                        reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=\n                                        c_rd[k] + c_wr[k];\n                                }\n                                sh[red_sh_wr] =\n                                    reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];\n                            }\n           
             }\n                        __syncthreads();\n                    }\n                    if (red_idx == 0) {\n#pragma unroll\n                        for (int i = 0; i < 4 * 2; i++) {\n                            float* c_rd =\n                                reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);\n#pragma unroll\n                            for (int j = 0; j < 4; j++)\n                                reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=\n                                c_rd[j];\n                        }\n                    }\n                    __syncthreads();\n                }\n            }\n            };\n\n        // Since multiple threadblocks may process parts of the same column slice, we\n        // finally have to globally reduce over the results. As the striped\n        // partitioning minimizes the number of such reductions and our outputs are\n        // usually rather small, we perform this reduction serially in L2 cache.\n        auto global_reduce = [&](bool first = false, bool last = false) {\n            // We are very careful here to reduce directly in the output buffer to\n            // maximize L2 cache utilization in this step. 
To do this, we write out\n            // results in FP16 (but still reduce with FP32 compute).\n            constexpr int active_threads = 32 * thread_n_blocks / 4;\n            if (threadIdx.x < active_threads) {\n                int c_gl_stride = prob_n / 8;\n                int c_gl_wr_delta_o = 8 * c_gl_stride;\n                int c_gl_wr_delta_i = 4 * (active_threads / 32);\n                int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +\n                    4 * (threadIdx.x / 32) + threadIdx.x % 4;\n                c_gl_wr += (2 * thread_n_blocks) * slice_col;\n                constexpr int c_sh_wr_delta = active_threads;\n                int c_sh_wr = threadIdx.x;\n\n                int row = (threadIdx.x % 32) / 4;\n\n                if (!first) {\n                    // Interestingly, doing direct global accesses here really seems to mess up\n                    // the compiler and lead to slowdowns, hence we also use async-copies even\n                    // though these fetches are not actually asynchronous.\n#pragma unroll\n                    for (int i = 0; i < thread_m_blocks * 4; i++) {\n                        cp_async4_pred(\n                            &sh[c_sh_wr + c_sh_wr_delta * i],\n                            &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +\n                            c_gl_wr_delta_i * (i % 2)],\n                            i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);\n                    }\n                    cp_async_fence();\n                    cp_async_wait<0>();\n                }\n\n#pragma unroll\n                for (int i = 0; i < thread_m_blocks * 4; i++) {\n                    if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {\n                        if (!first) {\n                            int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];\n#pragma unroll\n                            for (int j = 0; j < 2 * 4; j++) {\n                                reinterpret_cast<float*>(\n    
                                &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=\n                                    Dtype::num2float(reinterpret_cast<scalar_t*>(&c_red)[j]);\n                            }\n                        }\n                        if (!last) {\n                            int4 c;\n#pragma unroll\n                            for (int j = 0; j < 2 * 4; j++) {\n                                reinterpret_cast<scalar_t*>(&c)[j] =\n                                    Dtype::float2num(reinterpret_cast<float*>(\n                                        &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);\n                            }\n                            C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =\n                                c;\n                        }\n                    }\n                }\n            }\n            };\n\n        // Write out the reduced final result in the correct layout. We only actually\n        // reshuffle matrix fragments in this step, the reduction above is performed\n        // in fragment layout.\n        auto write_result = [&]() {\n            int c_gl_stride = prob_n / 8;\n            constexpr int c_sh_stride = 2 * thread_n_blocks + 1;\n            int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));\n            constexpr int c_sh_rd_delta =\n                c_sh_stride * (threads / (2 * thread_n_blocks));\n\n            int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +\n                (threadIdx.x % (2 * thread_n_blocks));\n            c_gl_wr += (2 * thread_n_blocks) * slice_col;\n            int c_sh_wr =\n                (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;\n            c_sh_wr += 32 * (threadIdx.x / 32);\n            int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +\n                (threadIdx.x % (2 * thread_n_blocks));\n\n            int c_gl_wr_end = c_gl_stride * prob_m;\n\n    
        // We first reorder in shared memory to guarantee the most efficient final\n            // global write patterns\n            auto write = [&](int idx, float c0, float c1, FragS& s) {\n                scalar_t2 res =\n                    Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));\n\n                // For per-column quantization we finally apply the scale here (only for\n                // 4-bit)\n                if constexpr (!has_act_order && group_blocks == -1 && num_bits == 4) {\n                    res = __hmul2(res, s[0]);\n                }\n\n                ((scalar_t2*)sh)[idx] = res;\n                };\n\n            if (threadIdx.x / 32 < thread_n_blocks / 4) {\n#pragma unroll\n                for (int i = 0; i < thread_m_blocks; i++) {\n#pragma unroll\n                    for (int j = 0; j < 4; j++) {\n                        int wr = c_sh_wr + 8 * j;\n                        write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],\n                            frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);\n                        write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],\n                            frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);\n                        write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],\n                            frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);\n                        write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],\n                            frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);\n                    }\n                    c_sh_wr += 16 * (4 * c_sh_stride);\n                }\n            }\n            __syncthreads();\n\n#pragma unroll\n            for (int i = 0;\n                i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));\n                i++) {\n                if (c_gl_wr < c_gl_wr_end) {\n                    C[c_gl_wr] = sh[c_sh_rd];\n                    c_gl_wr += 
c_gl_wr_delta;\n                    c_sh_rd += c_sh_rd_delta;\n                }\n            }\n            };\n\n        // Start global fetch and register load pipelines.\n        auto start_pipes = [&]() {\n\n#pragma unroll\n            for (int i = 0; i < stages - 1; i++) {\n                if (has_act_order && i == 0) {\n                    int last_g_idx = slice_k_start + stages * tb_k * 2;\n                    if (last_g_idx >= prob_k) {\n                        last_g_idx = prob_k - 1;\n                    }\n                    fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);\n                }\n                fetch_to_shared(i, i, i < slice_iters);\n            }\n\n            zero_accums();\n            wait_for_stage();\n            init_same_group(0);\n            fetch_to_registers(0, 0);\n            fetch_scales_to_registers(0, 0);\n            a_gl_rd += a_gl_rd_delta_o * (stages - 1);\n            slice_k_start_shared_fetch += tb_k * (stages - 1);\n            };\n        if (slice_iters) {\n            start_pipes();\n        }\n\n        // Main loop.\n        while (slice_iters) {\n            // We unroll over both the global fetch and the register load pipeline to\n            // ensure all shared memory accesses are static. 
Note that both pipelines\n            // have even length meaning that the next iteration will always start at\n            // index 0.\n\n#pragma unroll\n            for (int pipe = 0; pipe < stages;) {\n#pragma unroll\n                for (int k = 0; k < b_sh_wr_iters; k++) {\n                    fetch_to_registers(k + 1, pipe % stages);\n                    fetch_scales_to_registers(k + 1, pipe);\n                    if (k == b_sh_wr_iters - 2) {\n                        fetch_to_shared((pipe + stages - 1) % stages, pipe,\n                            slice_iters >= stages);\n                        pipe++;\n                        wait_for_stage();\n                        init_same_group(pipe % stages);\n                    }\n                    matmul(k);\n                }\n                slice_iters--;\n                if (slice_iters == 0) {\n                    break;\n                }\n            }\n\n            a_gl_rd += a_gl_rd_delta_o * stages;\n            slice_k_start += tb_k * stages;\n            slice_k_start_shared_fetch += tb_k * stages;\n\n            if constexpr (has_act_order) {\n                int first_group_id = g_idx[slice_k_start];\n                int last_g_idx = slice_k_start + stages * tb_k * 2;\n                if (last_g_idx >= prob_k) {\n                    last_g_idx = prob_k - 1;\n                }\n                int last_group_id = g_idx[last_g_idx];\n                if (last_group_id >= sh_first_group_id + sh_num_groups) {\n                    fetch_scales_to_shared(false, first_group_id, last_group_id);\n                    __syncthreads();\n                }\n            }\n\n            // Process results and, if necessary, proceed to the next column slice.\n            // While this pattern may not be the most readable, other ways of writing\n            // the loop seemed to yield noticeably worse performance after compilation.\n            if (slice_iters == 0) {\n                cp_async_wait<0>();\n                
bool last = slice_idx == slice_count - 1;\n                // For per-column scales, we only fetch them here in the final step before\n                // write-out\n                if constexpr (!has_act_order && group_blocks == -1) {\n                    if constexpr (num_bits == 8) {\n                        if (s_sh_wr_pred) {\n                            cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);\n                        }\n                        cp_async_fence();\n                    }\n                    else {\n                        if (last) {\n                            if (s_sh_wr_pred) {\n                                cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);\n                            }\n                            cp_async_fence();\n                        }\n                    }\n                }\n\n                thread_block_reduce();\n                if constexpr (!has_act_order && group_blocks == -1) {\n                    if constexpr (num_bits == 8) {\n                        cp_async_wait<0>();\n                        __syncthreads();\n                        if (threadIdx.x / 32 < thread_n_blocks / 4) {\n                            reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];\n                            reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];\n                        }\n\n                    }\n                    else {\n                        if (last) {\n                            cp_async_wait<0>();\n                            __syncthreads();\n                            if (threadIdx.x / 32 < thread_n_blocks / 4) {\n                                reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];\n                                reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];\n                            }\n                        }\n                    }\n                }\n\n                // For 8-bit channelwise, we apply the scale before the global 
reduction\n                // that converts the fp32 results to fp16 (so that we avoid possible\n                // overflow in fp16)\n                if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) {\n                    if (threadIdx.x / 32 < thread_n_blocks / 4) {\n#pragma unroll\n                        for (int i = 0; i < thread_m_blocks; i++) {\n#pragma unroll\n                            for (int j = 0; j < 4; j++) {\n                                scale_float<scalar_t>(\n                                    reinterpret_cast<float*>(&frag_c[i][j][0][0]),\n                                    frag_s[j / 2][2 * (j % 2) + 0]);\n                                scale_float<scalar_t>(\n                                    reinterpret_cast<float*>(&frag_c[i][j][0][2]),\n                                    frag_s[j / 2][2 * (j % 2) + 0]);\n\n                                scale_float<scalar_t>(\n                                    reinterpret_cast<float*>(&frag_c[i][j][1][0]),\n                                    frag_s[j / 2][2 * (j % 2) + 1]);\n                                scale_float<scalar_t>(\n                                    reinterpret_cast<float*>(&frag_c[i][j][1][2]),\n                                    frag_s[j / 2][2 * (j % 2) + 1]);\n                            }\n                        }\n                    }\n                }\n\n                if (slice_count > 1) {  // only globally reduce if there is more than one\n                    // block in a slice\n                    barrier_acquire(&locks[slice_col], slice_idx);\n                    global_reduce(slice_idx == 0, last);\n                    barrier_release(&locks[slice_col], last);\n                }\n                if (last)  // only the last block in a slice actually writes the result\n                    write_result();\n                slice_row = 0;\n                slice_col_par++;\n                slice_col++;\n                init_slice();\n             
   if (slice_iters) {\n                    a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +\n                        (threadIdx.x % a_gl_rd_delta_o);\n#pragma unroll\n                    for (int i = 0; i < b_sh_wr_iters; i++)\n                        B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;\n                    if (slice_col == 0) {\n#pragma unroll\n                        for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;\n                    }\n\n                    // Update slice k/n for scales loading\n                    if constexpr (has_act_order) {\n                        slice_k_start = tb_k * slice_row;\n                        slice_k_finish = slice_k_start + tb_k * slice_iters;\n                        slice_k_start_shared_fetch = slice_k_start;\n                        slice_n_offset = act_s_col_tb_stride * slice_col;\n\n                    }\n                    else {\n                        s_gl_rd = s_sh_stride * slice_col + threadIdx.x;\n                    }\n\n                    start_pipes();\n                }\n            }\n        }\n    }\n\n    template <typename scalar_t,         // compute dtype, half or nv_bfloat16\n        const int num_bits,        // number of bits used for weights\n        const int threads,         // number of threads in a threadblock\n        const int template_thread_m_blocks, // number of 16x16 blocks in the m\n        // dimension (batchsize) of the\n        // threadblock\n        const int thread_n_blocks, // same for n dimension (output)\n        const int thread_k_blocks, // same for k dimension (reduction)\n        const int stages, // number of stages for the async global->shared\n        // fetch pipeline\n        const bool has_act_order,   // whether act_order is enabled\n        const int group_blocks = -1 // number of consecutive 16x16 blocks\n        // with a separate quantization scale\n    >\n    __global__ void\n        Marlin_wrapper(const int4* 
__restrict__ A, // fp16 input matrix of shape mxk\n            const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn\n            int4* __restrict__ C,       // fp16 output buffer of shape mxn\n            const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape\n            // (k/groupsize)xn\n            const int* __restrict__ g_idx, // int32 group indices of shape k\n            int num_groups, // number of scale groups per output channel\n            const int* __restrict__ prob_m_ptr,     // batch dimension m\n            int prob_n,     // output dimension n\n            int prob_k,     // reduction dimension k\n            int* locks      // extra global storage for barrier synchronization\n        ) {\n        int prob_m = *prob_m_ptr;\n        prob_m = min(prob_m, 1024);\n        const int thread_m_blocks = min(div_ceil(prob_m, 16), template_thread_m_blocks);\n        if(prob_m > 16 * thread_m_blocks)\n            prob_m = (16 * thread_m_blocks) * div_ceil(prob_m, (16 * thread_m_blocks));\n        /*if (blockIdx.x == 0 && threadIdx.x == 0)\n            printf(\"marlin prob_m %d\\n\", prob_m);*/\n        if (thread_m_blocks == 1) {\n            Marlin<scalar_t, num_bits, threads, 1,\n                thread_n_blocks, thread_k_blocks, stages, has_act_order,\n                group_blocks>(\n                    A, B, C, scales_ptr, g_idx, num_groups, prob_m, prob_n,\n                    prob_k, locks);\n        }\n        else if (thread_m_blocks == 2) {\n            Marlin<scalar_t, num_bits, threads, 2,\n                thread_n_blocks, thread_k_blocks, stages, has_act_order,\n                group_blocks>(\n                    A, B, C, scales_ptr, g_idx, num_groups, prob_m, prob_n,\n                    prob_k, locks);\n        }\n        else if (thread_m_blocks == 3) {\n            Marlin<scalar_t, num_bits, threads, 3,\n                thread_n_blocks, thread_k_blocks, stages, has_act_order,\n                
group_blocks>(\n                    A, B, C, scales_ptr, g_idx, num_groups, prob_m, prob_n,\n                    prob_k, locks);\n        }\n        else if (thread_m_blocks == 4) {\n            Marlin<scalar_t, num_bits, threads, 4,\n                thread_n_blocks, thread_k_blocks, stages, has_act_order,\n                group_blocks>(\n                    A, B, C, scales_ptr, g_idx, num_groups, prob_m, prob_n,\n                    prob_k, locks);\n        }\n    }\n\n#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \\\n                  HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS)                    \\\n    else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS &&     \\\n             thread_n_blocks == THREAD_N_BLOCKS &&                             \\\n             thread_k_blocks == THREAD_K_BLOCKS &&                             \\\n             has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \\\n             num_threads == NUM_THREADS) {                                     \\\n        cudaFuncSetAttribute(                                                  \\\n            Marlin_wrapper<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS,           \\\n                   THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages,              \\\n                   HAS_ACT_ORDER, GROUP_BLOCKS>,                               \\\n            cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);      \\\n        Marlin_wrapper<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS,               \\\n               THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER,   \\\n               GROUP_BLOCKS><<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \\\n            A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m_ptr, prob_n, \\\n            prob_k, locks);                                                    \\\n    }\n\n    typedef struct {\n        int thread_k;\n        int thread_n;\n        int 
num_threads;\n    } thread_config_t;\n\n    typedef struct {\n        int max_m_blocks;\n        thread_config_t tb_cfg;\n    } exec_config_t;\n\n    thread_config_t small_batch_thread_configs[] = {\n        // Ordered by priority\n\n        // thread_k, thread_n, num_threads\n        {128, 128, 256},\n        {64, 128, 128},\n        {128, 64, 128},\n    };\n\n    thread_config_t large_batch_thread_configs[] = {\n        // Ordered by priority\n\n        // thread_k, thread_n, num_threads\n        {64, 256, 256},\n        // {128, 128, 256},\n        {64, 128, 128},\n        {128, 64, 128},\n\n    };\n\n    int get_scales_cache_size(thread_config_t const& th_config, int prob_m,\n        int prob_n, int prob_k, int num_bits, int group_size,\n        bool has_act_order, bool is_k_full) {\n        bool cache_scales_chunk = has_act_order && !is_k_full;\n\n        int tb_n = th_config.thread_n;\n        int tb_k = th_config.thread_k;\n\n        // Get max scale groups per thread-block\n        int tb_groups;\n        if (group_size == -1) {\n            tb_groups = 1;\n        }\n        else if (group_size == 0) {\n            tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size\n        }\n        else {\n            tb_groups = div_ceil(tb_k, group_size);\n        }\n\n        if (cache_scales_chunk) {\n            int load_groups =\n                tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K\n            load_groups = max(load_groups, 32); // We load at least 32 scale groups\n            return load_groups * tb_n * 2;\n\n        }\n        else {\n            int tb_scales = tb_groups * tb_n * 2;\n\n            return tb_scales * pipe_stages;\n        }\n    }\n\n    bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,\n        int prob_m, int prob_n, int prob_k, int num_bits,\n        int scales_cache_size, int max_shared_mem) {\n        int pack_factor = 32 / num_bits;\n\n        // Get B size\n        
int tb_k = th_config.thread_k;\n        int tb_n = th_config.thread_n;\n\n        int b_size = (tb_k * tb_n / pack_factor) * 4;\n\n        // Get A size\n        int m_blocks = div_ceil(prob_m, 16);\n        int tb_max_m = 16;\n\n        // zbx: too ugly\n        // origin\n        /*while (true) {\n          if (m_blocks >= max_m_blocks) {\n            tb_max_m *= max_m_blocks;\n            break;\n          }\n\n          max_m_blocks--;\n          if (max_m_blocks == 0) {\n            TORCH_CHECK(false, \"Unexpected m_blocks = \", m_blocks);\n          }\n        }*/\n        // refactor\n        tb_max_m *= std::min(m_blocks, max_m_blocks);\n\n        int a_size = (tb_max_m * tb_k) * 2;\n\n        float pipe_size = (a_size + b_size) * pipe_stages;\n\n        TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity\n        return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);\n    }\n\n    bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,\n        int prob_m, int prob_n, int prob_k, int num_bits,\n        int group_size, bool has_act_order, bool is_k_full,\n        int max_shared_mem) {\n        // Sanity\n        if (th_config.thread_k == -1 || th_config.thread_n == -1 ||\n            th_config.num_threads == -1) {\n            return false;\n        }\n\n        // Verify K/N are divisible by thread K/N\n        if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {\n            return false;\n        }\n\n        // Verify min for thread K/N\n        if (th_config.thread_n < min_thread_n ||\n            th_config.thread_k < min_thread_k) {\n            return false;\n        }\n\n        // num_threads must be at least 128 (= 4 warps)\n        if (th_config.num_threads < 128) {\n            return false;\n        }\n\n        //  Determine cache for scales\n        int scales_cache_size =\n            get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,\n                group_size, 
has_act_order, is_k_full);\n\n        // Check that pipeline fits into cache\n        if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,\n            num_bits, scales_cache_size, max_shared_mem)) {\n            return false;\n        }\n\n        return true;\n    }\n\n    exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,\n        int num_bits, int group_size,\n        bool has_act_order, bool is_k_full,\n        int max_shared_mem) {\n        int max_m_blocks = 4;\n        while (max_m_blocks > 0) {\n            if (prob_m <= 16) {\n                for (auto th_config : small_batch_thread_configs) {\n                    if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n,\n                        prob_k, num_bits, group_size, has_act_order,\n                        is_k_full, max_shared_mem)) {\n                        return exec_config_t{ max_m_blocks, th_config };\n                    }\n                }\n            }\n            else {\n                for (auto th_config : large_batch_thread_configs) {\n                    if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n,\n                        prob_k, num_bits, group_size, has_act_order,\n                        is_k_full, max_shared_mem)) {\n                        return exec_config_t{ max_m_blocks, th_config };\n                    }\n                }\n            }\n\n            max_m_blocks--; // Process less M blocks per invocation to reduce cache\n            // usage\n        }\n\n        return exec_config_t{ 0, {-1, -1, -1} };\n    }\n\n#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS)                     \\\n    __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)          \\\n    __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)          \\\n    __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)          \\\n    __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, 
NUM_THREADS)          \\\n    __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)          \\\n    __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)          \\\n    __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)          \\\n    __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)        \n\n    template <typename scalar_t>\n    void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,\n        void* g_idx, void* perm, void* a_tmp, int* prob_m_ptr, int prob_m,\n        int prob_n, int prob_k, void* workspace, int num_bits,\n        bool has_act_order, bool is_k_full, int num_groups,\n        int group_size, int dev, cudaStream_t stream, int thread_k,\n        int thread_n, int sms, int max_par) {\n        TORCH_CHECK(num_bits == 4 || num_bits == 8,\n            \"num_bits must be 4 or 8. Got = \", num_bits);\n        TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, \"Invalid MNK = [\",\n            prob_m, \", \", prob_n, \", \", prob_k, \"]\");\n\n        int tot_m = prob_m;\n        int tot_m_blocks = div_ceil(tot_m, 16);\n        int pad = 16 * tot_m_blocks - tot_m;\n\n        if (sms == -1) {\n            cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);\n        }\n\n        int max_shared_mem = 0;\n        cudaDeviceGetAttribute(&max_shared_mem,\n            cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);\n        TORCH_CHECK(max_shared_mem > 0);\n\n        // Set thread config\n        exec_config_t exec_cfg;\n        if (thread_k != -1 && thread_n != -1) {\n            // User-defined config\n            exec_cfg = exec_config_t{\n                4, thread_config_t{thread_k, thread_n, default_threads} };\n        }\n        else {\n            // Auto config\n            exec_cfg = determine_thread_config(prob_m, prob_n, prob_k, num_bits,\n                group_size, has_act_order, is_k_full,\n                max_shared_mem);\n        }\n\n        TORCH_CHECK(\n     
       exec_cfg.max_m_blocks > 0 &&\n            is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, prob_m,\n                prob_n, prob_k, num_bits, group_size, has_act_order,\n                is_k_full, max_shared_mem),\n            \"Invalid thread config: max_m_blocks = \", exec_cfg.max_m_blocks,\n            \", thread_k = \", exec_cfg.tb_cfg.thread_k,\n            \", thread_n = \", exec_cfg.tb_cfg.thread_n,\n            \", num_threads = \", exec_cfg.tb_cfg.num_threads, \" for MKN = [\", prob_m,\n            \", \", prob_k, \", \", prob_n, \"] and num_bits = \", num_bits,\n            \", group_size = \", group_size, \", has_act_order = \", has_act_order,\n            \", is_k_full = \", is_k_full, \", max_shared_mem = \", max_shared_mem);\n\n        int num_threads = exec_cfg.tb_cfg.num_threads;\n        thread_k = exec_cfg.tb_cfg.thread_k;\n        thread_n = exec_cfg.tb_cfg.thread_n;\n\n        int thread_k_blocks = thread_k / 16;\n        int thread_n_blocks = thread_n / 16;\n\n        int blocks = sms;\n\n        TORCH_CHECK(prob_n % thread_n == 0, \"prob_n = \", prob_n,\n            \" is not divisible by thread_n = \", thread_n);\n        TORCH_CHECK(prob_k % thread_k == 0, \"prob_k = \", prob_k,\n            \" is not divisible by thread_k = \", thread_k);\n\n        int group_blocks = 0;\n        if (has_act_order) {\n            if (is_k_full) {\n                TORCH_CHECK(group_size != -1);\n                group_blocks = group_size / 16;\n                TORCH_CHECK(prob_k % group_blocks == 0, \"prob_k = \", prob_k,\n                    \" is not divisible by group_blocks = \", group_blocks);\n            }\n            else {\n                TORCH_CHECK(group_size == 0);\n                group_blocks = 0;\n            }\n\n        }\n        else {\n            if (group_size == -1) {\n                group_blocks = -1;\n            }\n            else {\n                group_blocks = group_size / 16;\n                TORCH_CHECK(prob_k 
% group_blocks == 0, \"prob_k = \", prob_k,\n                    \" is not divisible by group_blocks = \", group_blocks);\n            }\n        }\n\n        const int4* A_ptr = (const int4*)A;\n        const int4* B_ptr = (const int4*)B;\n        int4* C_ptr = (int4*)C;\n        const int4* s_ptr = (const int4*)s;\n        const int* g_idx_ptr = (const int*)g_idx;\n        const int* perm_ptr = (const int*)perm;\n        int4* a_tmp_ptr = (int4*)a_tmp;\n\n        int* locks = (int*)workspace;\n\n        if (has_act_order) {\n            // Permute A columns\n            int block_rows = div_ceil(prob_m, blocks);\n            permute_cols_kernel << <blocks, default_threads, 0, stream >> > (\n                A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows);\n            A_ptr = a_tmp_ptr;\n        }\n\n        // If we have a full K, then we can run the non-act-order version of Marlin\n        // (since the weight rows are reordered by increasing group ids, and by\n        // having a full K, we have full original groups)\n        if (is_k_full) {\n            has_act_order = false;\n        }\n\n        // Main loop\n        for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) {\n            int thread_m_blocks = tot_m_blocks - i;\n            prob_m = tot_m - 16 * i;\n            int par = 1;\n            if (thread_m_blocks > exec_cfg.max_m_blocks) {\n                // Note that parallel > 1 currently only works for inputs without\n                // any padding\n                par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks);\n                if (par > max_par)\n                    par = max_par;\n                prob_m = (16 * exec_cfg.max_m_blocks) * par;\n                i += exec_cfg.max_m_blocks * (par - 1);\n                thread_m_blocks = exec_cfg.max_m_blocks;\n            }\n\n            // Define kernel configurations\n#define undefined_error                                                        \\\n    
TORCH_CHECK(false, \"Unsupported shapes: MNK = [\" + str(prob_m) + \", \" +    \\\n                           str(prob_n) + \", \" + str(prob_k) + \"]\" +            \\\n                           \", has_act_order = \" + str(has_act_order) +         \\\n                           \", num_groups = \" + str(num_groups) +               \\\n                           \", group_size = \" + str(group_size) +               \\\n                           \", thread_m_blocks = \" + str(thread_m_blocks) +     \\\n                           \", thread_n_blocks = \" + str(thread_n_blocks) +     \\\n                           \", thread_k_blocks = \" + str(thread_k_blocks));\n\n        /* std::cout << \"MNK = [\" + str(prob_m) + \", \" + \\\n             str(prob_n) + \", \" + str(prob_k) + \"]\" + \\\n             \", has_act_order = \" + str(has_act_order) + \\\n             \", num_groups = \" + str(num_groups) + \\\n             \", group_size = \" + str(group_size) + \\\n             \", thread_m_blocks = \" + str(thread_m_blocks) + \\\n             \", thread_n_blocks = \" + str(thread_n_blocks) + \\\n             \", thread_k_blocks = \" + str(thread_k_blocks) << std::endl;*/\n\n             /*if (false) {\n             }\n             // CALL_IF(4, 32, 2, 256)\n             // CALL_IF(4, 16, 4, 256)\n             __CALL_IF(4, 1, 16, 4, false, 4, 256)\n             __CALL_IF(4, 2, 16, 4, false, 4, 256)\n             // CALL_IF(4, 8, 8, 256)\n             __CALL_IF(4, 1, 8, 8, false, 4, 256)\n             __CALL_IF(4, 2, 8, 8, false, 4, 256)\n             // CALL_IF(4, 16, 4, 128)\n             __CALL_IF(4, 1, 16, 4, false, 4, 128)\n             __CALL_IF(4, 2, 16, 4, false, 4, 128)\n             // CALL_IF(4, 8, 8, 128)\n             __CALL_IF(4, 1, 8, 8, false, 4, 128)\n             __CALL_IF(4, 2, 8, 8, false, 4, 128)\n             else {undefined_error}*/\n\n            if (num_bits == 4 && num_threads == 256)\n            {\n                if (false) {\n            
    }\n                CALL_IF(4, 32, 2, 256)\n                    CALL_IF(4, 16, 4, 256)\n                    CALL_IF(4, 8, 8, 256)\n                else {\n                    undefined_error\n                }\n            }\n            else if (num_bits == 4 && num_threads == 128)\n            {\n                if (false) {\n                }\n                CALL_IF(4, 8, 4, 128)\n                    CALL_IF(4, 16, 4, 128)\n                    CALL_IF(4, 4, 8, 128)\n                else {\n                    undefined_error\n                }\n            }\n            // else if (num_bits == 8 && num_threads == 256)\n            // {\n            //     if (false) {\n            //     }\n            //     CALL_IF(8, 32, 2, 256)\n            //     CALL_IF(8, 16, 4, 256)\n            //     CALL_IF(8, 8, 8, 256)\n            //     else {\n            //         undefined_error\n            //     }\n            // }\n            // else if (num_bits == 8 && num_threads == 128)\n            // {\n            //     if (false) {\n            //     }\n            //     CALL_IF(8, 8, 4, 128)\n            //     CALL_IF(8, 16, 4, 128)\n            //     CALL_IF(8, 4, 8, 128)\n            //     else {\n            //         undefined_error\n            //     }\n            // }\n            else {\n                undefined_error\n            }\n\n            A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;\n            C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;\n        }\n    }\n\n} // namespace gptq_marlin\n\ntorch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,\n    torch::Tensor& b_scales, torch::Tensor& g_idx,\n    torch::Tensor& perm, torch::Tensor& workspace,\n    int64_t num_bits, torch::Tensor size_m_tensor, int64_t size_m, int64_t size_n,\n    int64_t size_k, int sms, bool is_k_full) {\n    const at::cuda::OptionalCUDAGuard device_guard(device_of(a));\n    // Verify num_bits\n    TORCH_CHECK(num_bits == 4 || 
num_bits == 8,\n        \"num_bits must be 4 or 8. Got = \", num_bits);\n    int pack_factor = 32 / num_bits;\n\n    // Verify A\n    TORCH_CHECK(a.size(0) == size_m, \"Shape mismatch: a.size(0) = \", a.size(0),\n        \", size_m = \", size_m);\n    TORCH_CHECK(a.size(1) == size_k, \"Shape mismatch: a.size(1) = \", a.size(1),\n        \", size_k = \", size_k);\n\n    // Verify B\n    TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, \"size_k = \", size_k,\n        \" is not divisible by tile_size = \", gptq_marlin::tile_size);\n    TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),\n        \"Shape mismatch: b_q_weight.size(0) = \", b_q_weight.size(0),\n        \", size_k = \", size_k,\n        \", tile_size = \", gptq_marlin::tile_size);\n    TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0,\n        \"b_q_weight.size(1) = \", b_q_weight.size(1),\n        \" is not divisible by tile_size = \", gptq_marlin::tile_size);\n    int actual_size_n =\n        (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;\n    TORCH_CHECK(size_n == actual_size_n, \"size_n = \", size_n,\n        \", actual_size_n = \", actual_size_n);\n\n    // Verify device and strides\n    TORCH_CHECK(a.device().is_cuda(), \"A is not on GPU\");\n    TORCH_CHECK(a.is_contiguous(), \"A is not contiguous\");\n\n    TORCH_CHECK(b_q_weight.device().is_cuda(), \"b_q_weight is not on GPU\");\n    TORCH_CHECK(b_q_weight.is_contiguous(), \"b_q_weight is not contiguous\");\n\n    TORCH_CHECK(b_scales.device().is_cuda(), \"b_scales is not on GPU\");\n    TORCH_CHECK(b_scales.is_contiguous(), \"b_scales is not contiguous\");\n\n    TORCH_CHECK(g_idx.device().is_cuda(), \"g_idx is not on GPU\");\n    TORCH_CHECK(g_idx.is_contiguous(), \"g_idx is not contiguous\");\n\n    TORCH_CHECK(perm.device().is_cuda(), \"perm is not on GPU\");\n    TORCH_CHECK(perm.is_contiguous(), \"perm is not contiguous\");\n\n    // Alloc buffers\n    auto options = 
torch::TensorOptions().dtype(a.dtype()).device(a.device());\n    torch::Tensor c = torch::empty({ size_m, size_n }, options);\n    torch::Tensor a_tmp = torch::empty({ size_m, size_k }, options);\n\n    // thread_k: `k` size of a thread_tile in `weights` (can usually be left as\n    // auto -1)\n    int thread_k = -1;\n    // thread_n: `n` size of a thread_tile in `weights` (can usually be left as\n    // auto -1)\n    int thread_n = -1;\n    // sms: number of SMs to use for the kernel (can usually be left as auto -1)\n    // int sms = -1; //zbx\n\n    // Verify g_idx and perm\n    TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) ||\n        (g_idx.size(0) == size_k && perm.size(0) == size_k),\n        \"Unexpected g_idx.size(0) = \", g_idx.size(0),\n        \" and perm.size(0) = \", perm.size(0),\n        \", where size_k = \", size_k);\n\n    // Detect groupsize and act_order\n    int num_groups = -1;\n    int group_size = -1;\n    bool has_act_order = g_idx.size(0) != 0;\n\n    int b_rank = b_scales.sizes().size();\n    TORCH_CHECK(b_rank == 2, \"b_scales rank = \", b_rank, \" is not 2\");\n    TORCH_CHECK(b_scales.size(1) == size_n,\n        \"b_scales dim 1 = \", b_scales.size(1),\n        \" is not size_n = \", size_n);\n    num_groups = b_scales.size(0);\n\n    if (has_act_order) {\n        if (is_k_full) {\n            TORCH_CHECK(num_groups > 1,\n                \"For act_order, num_groups must be > 1\");\n            TORCH_CHECK(size_k % num_groups == 0, \"size_k = \", size_k,\n                \", is not divisible by num_groups = \", num_groups);\n            group_size = size_k / num_groups;\n        }\n        else {\n            group_size = 0;\n        }\n\n    }\n    else {\n        if (num_groups > 1) {\n            TORCH_CHECK(\n                size_k % num_groups == 0, \"size_k = \", size_k,\n                \", is not divisible by b_scales.size(0) = \", b_scales.size(0));\n            group_size = size_k / num_groups;\n        }\n        
else {\n            group_size = -1;\n        }\n    }\n\n    // Verify workspace size\n    TORCH_CHECK(\n        size_n % gptq_marlin::min_thread_n == 0, \"size_n = \", size_n,\n        \", is not divisible by min_thread_n = \", gptq_marlin::min_thread_n);\n    int min_workspace_size =\n        (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;\n    TORCH_CHECK(workspace.numel() >= min_workspace_size,\n        \"workspace.numel = \", workspace.numel(),\n        \" is below min_workspace_size = \", min_workspace_size);\n\n    int dev = a.get_device();\n    if (a.scalar_type() == at::ScalarType::Half) {\n        gptq_marlin::marlin_mm_f16i4<half>(\n            a.data_ptr<at::Half>(), b_q_weight.data_ptr(),\n            c.data_ptr<at::Half>(), b_scales.data_ptr<at::Half>(),\n            g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::Half>(),\n            size_m_tensor.data_ptr<int>(),\n            size_m, size_n, size_k, workspace.data_ptr(), num_bits,\n            has_act_order, is_k_full, num_groups, group_size, dev,\n            at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,\n            gptq_marlin::max_par);\n    }\n    else if (a.scalar_type() == at::ScalarType::BFloat16) {\n        gptq_marlin::marlin_mm_f16i4<nv_bfloat16>(\n            a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),\n            c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),\n            g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),\n            size_m_tensor.data_ptr<int>(),\n            size_m, size_n, size_k, workspace.data_ptr(), num_bits,\n            has_act_order, is_k_full, num_groups, group_size, dev,\n            at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,\n            gptq_marlin::max_par);\n    }\n    else {\n        TORCH_CHECK(false,\n            \"gptq_marlin_gemm only supports bfloat16 and float16\");\n    }\n\n    return c;\n}\n\n#endif"
  },
  {
    "path": "archive/csrc/custom_marlin/gptq_marlin/gptq_marlin.cuh",
    "content": "// Adapted from\n// https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin\n// Copyrigth 2024 The vLLM team.\n// Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n#pragma once\n\n#include <torch/all.h>\n\n#include <ATen/cuda/CUDAContext.h>\n#include <c10/cuda/CUDAGuard.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n#include <iostream>\n\nnamespace gptq_marlin {\n\n// 8 warps are a good choice since every SM has 4 schedulers and having more\n// than 1 warp per schedule allows some more latency hiding. At the same time,\n// we want relatively few warps to have many registers per warp and small tiles.\nstatic constexpr int default_threads = 256;\n\nstatic constexpr int pipe_stages =\n    4; // 4 pipeline stages fit into shared memory\n\nstatic constexpr int min_thread_n = 64;\nstatic constexpr int min_thread_k = 64;\n\nstatic constexpr int tile_size = 16;\nstatic constexpr int max_par = 16;\n\ntemplate <typename T, int n> struct Vec {\n    T elems[n];\n    __device__ T &operator[](int i) { return elems[i]; }\n};\n\nusing I4 = Vec<int, 4>;\n\nconstexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }\n\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n// No support for async\n#else\n\n__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr,\n                                      bool pred = true) {\n    const int BYTES = 16;\n    uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));\n    asm volatile(\"{\\n\"\n                 \"   .reg .pred p;\\n\"\n                 \"   setp.ne.b32 p, %0, 0;\\n\"\n                 \"   @p cp.async.cg.shared.global [%1], [%2], %3;\\n\"\n                 \"}\\n\" ::\"r\"((int)pred),\n                 \"r\"(smem), \"l\"(glob_ptr), \"n\"(BYTES));\n}\n\n__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) {\n    const int BYTES = 16;\n    uint32_t smem = 
static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));\n    asm volatile(\"{\\n\"\n                 \"   cp.async.cg.shared.global [%0], [%1], %2;\\n\"\n                 \"}\\n\" ::\"r\"(smem),\n                 \"l\"(glob_ptr), \"n\"(BYTES));\n}\n\n__device__ inline void cp_async_fence() {\n    asm volatile(\"cp.async.commit_group;\\n\" ::);\n}\n\ntemplate <int n> __device__ inline void cp_async_wait() {\n    asm volatile(\"cp.async.wait_group %0;\\n\" ::\"n\"(n));\n}\n\n#endif\n\n} // namespace gptq_marlin"
  },
  {
    "path": "archive/csrc/custom_marlin/gptq_marlin/gptq_marlin_dtypes.cuh",
    "content": "// Adapted from\n// https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin\n// Copyrigth 2024 The vLLM team.\n// Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n#ifndef _data_types_cuh\n#define _data_types_cuh\n#include \"gptq_marlin.cuh\"\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n\nnamespace gptq_marlin {\n\ntemplate <typename scalar_t> class ScalarType {};\n\ntemplate <> class ScalarType<half> {\n  public:\n    using scalar_t = half;\n    using scalar_t2 = half2;\n\n    // Matrix fragments for tensor core instructions; their precise layout is\n    // documented here:\n    // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type\n    using FragA = Vec<half2, 4>;\n    using FragB = Vec<half2, 2>;\n    using FragC = Vec<float, 4>;\n    using FragS = Vec<half2, 1>;\n\n    static __device__ float inline num2float(const half x) {\n        return __half2float(x);\n    }\n\n    static __device__ half2 inline num2num2(const half x) {\n        return __half2half2(x);\n    }\n\n    static __device__ half2 inline nums2num2(const half x1, const half x2) {\n        return __halves2half2(x1, x2);\n    }\n\n    static __host__ __device__ half inline float2num(const float x) {\n        return __float2half(x);\n    }\n};\n\ntemplate <> class ScalarType<nv_bfloat16> {\n  public:\n    using scalar_t = nv_bfloat16;\n    using scalar_t2 = nv_bfloat162;\n\n    using FragA = Vec<nv_bfloat162, 4>;\n    using FragB = Vec<nv_bfloat162, 2>;\n    using FragC = Vec<float, 4>;\n    using FragS = Vec<nv_bfloat162, 1>;\n\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n    static __device__ float inline num2float(const nv_bfloat16 x) {\n        return __bfloat162float(x);\n    }\n\n    static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {\n        return __bfloat162bfloat162(x);\n    }\n\n    static __device__ nv_bfloat162 inline nums2num2(const 
nv_bfloat16 x1,\n                                                    const nv_bfloat16 x2) {\n        return __halves2bfloat162(x1, x2);\n    }\n\n    static __host__ __device__ nv_bfloat16 inline float2num(const float x) {\n        return __float2bfloat16(x);\n    }\n#endif\n};\n\n} // namespace gptq_marlin\n\n#endif"
  },
  {
    "path": "archive/csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu",
    "content": "#include \"gptq_marlin.cuh\"\n\nnamespace gptq_marlin {\n\nstatic constexpr int repack_stages = 8;\n\nstatic constexpr int repack_threads = 256;\n\nstatic constexpr int tile_k_size = tile_size;\nstatic constexpr int tile_n_size = tile_k_size * 4;\n\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n\ntemplate <int const num_threads, int const num_bits, bool const has_perm>\n__global__ void marlin_repack_kernel(\n    uint32_t const* __restrict__ b_q_weight_ptr,\n    uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,\n    int size_k, int size_n) {}\n\n}  // namespace gptq_marlin\n\ntorch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,\n                                 int64_t size_k, int64_t size_n,\n                                 int64_t num_bits) {\n  TORCH_CHECK_NOT_IMPLEMENTED(\n      false, \"marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0\");\n  return torch::empty({1, 1});\n}\n\n#else\n\ntemplate <int const num_threads, int const num_bits, bool const has_perm>\n__global__ void marlin_repack_kernel(\n    uint32_t const* __restrict__ b_q_weight_ptr,\n    uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,\n    int size_k, int size_n) {\n  constexpr int pack_factor = 32 / num_bits;\n\n  int k_tiles = size_k / tile_k_size;\n  int n_tiles = size_n / tile_n_size;\n  int block_k_tiles = div_ceil(k_tiles, gridDim.x);\n\n  int start_k_tile = blockIdx.x * block_k_tiles;\n  if (start_k_tile >= k_tiles) {\n    return;\n  }\n\n  int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);\n\n  // Wait until the next thread tile has been loaded to shared memory.\n  auto wait_for_stage = [&]() {\n    // We only have `stages - 2` active fetches since we are double buffering\n    // and can only issue the next fetch when it is guaranteed that the previous\n    // shared memory load is fully complete (as it may otherwise be\n    // overwritten).\n    cp_async_wait<repack_stages - 
2>();\n    __syncthreads();\n  };\n\n  extern __shared__ int4 sh[];\n\n  constexpr int perm_size = tile_k_size / 4;\n\n  int4* sh_perm_ptr = sh;\n  int4* sh_pipe_ptr = sh_perm_ptr;\n  if constexpr (has_perm) {\n    sh_pipe_ptr += perm_size;\n  }\n\n  constexpr int tile_ints = tile_k_size / pack_factor;\n\n  constexpr int stage_n_threads = tile_n_size / 4;\n  constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints;\n  constexpr int stage_size = stage_k_threads * stage_n_threads;\n\n  auto load_perm_to_shared = [&](int k_tile_id) {\n    int first_k_int4 = (k_tile_id * tile_k_size) / 4;\n\n    int4 const* perm_int4_ptr = reinterpret_cast<int4 const*>(perm_ptr);\n\n    if (threadIdx.x < perm_size) {\n      sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x];\n    }\n    __syncthreads();\n  };\n\n  auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {\n    if (n_tile_id >= n_tiles) {\n      cp_async_fence();\n      return;\n    }\n\n    int first_n = n_tile_id * tile_n_size;\n\n    int4* sh_ptr = sh_pipe_ptr + stage_size * pipe;\n\n    if constexpr (has_perm) {\n      if (threadIdx.x < stage_size) {\n        int k_id = threadIdx.x / stage_n_threads;\n        int n_id = threadIdx.x % stage_n_threads;\n\n        uint32_t const* sh_perm_int_ptr =\n            reinterpret_cast<uint32_t const*>(sh_perm_ptr);\n\n        int src_k = sh_perm_int_ptr[k_id];\n        int src_k_packed = src_k / pack_factor;\n\n        cp_async4(\n            &sh_ptr[k_id * stage_n_threads + n_id],\n            reinterpret_cast<int4 const*>(&(\n                b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));\n      }\n\n    } else {\n      if (threadIdx.x < stage_size) {\n        int k_id = threadIdx.x / stage_n_threads;\n        int n_id = threadIdx.x % stage_n_threads;\n\n        int first_k = k_tile_id * tile_k_size;\n        int first_k_packed = first_k / pack_factor;\n\n        cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],\n        
          reinterpret_cast<int4 const*>(\n                      &(b_q_weight_ptr[(first_k_packed + k_id) * size_n +\n                                       first_n + (n_id * 4)])));\n      }\n    }\n\n    cp_async_fence();\n  };\n\n  auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {\n    if (n_tile_id >= n_tiles) {\n      return;\n    }\n\n    int warp_id = threadIdx.x / 32;\n    int th_id = threadIdx.x % 32;\n\n    if (warp_id >= 4) {\n      return;\n    }\n\n    int tc_col = th_id / 4;\n    int tc_row = (th_id % 4) * 2;\n\n    constexpr int tc_offsets[4] = {0, 1, 8, 9};\n\n    int cur_n = warp_id * 16 + tc_col;\n\n    constexpr int sh_stride = 64;\n    constexpr uint32_t mask = (1 << num_bits) - 1;\n\n    int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;\n    uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);\n\n    uint32_t* sh_perm_int_ptr = reinterpret_cast<uint32_t*>(sh_perm_ptr);\n\n    uint32_t vals[8];\n\n    if constexpr (has_perm) {\n      for (int i = 0; i < 4; i++) {\n        int k_idx = tc_row + tc_offsets[i];\n\n        uint32_t src_k = sh_perm_int_ptr[k_idx];\n        uint32_t src_k_pos = src_k % pack_factor;\n\n        uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];\n        uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask;\n\n        uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];\n        uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask;\n\n        vals[i] = b1_cur_val;\n        vals[4 + i] = b2_cur_val;\n      }\n\n    } else {\n      uint32_t b1_vals[tile_ints];\n      uint32_t b2_vals[tile_ints];\n\n  #pragma unroll\n      for (int i = 0; i < tile_ints; i++) {\n        b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];\n        b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];\n      }\n\n  #pragma unroll\n      for (int i = 0; i < 4; i++) {\n        int cur_elem = tc_row + tc_offsets[i];\n        int cur_int = cur_elem / 
pack_factor;\n        int cur_pos = cur_elem % pack_factor;\n\n        vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;\n        vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;\n      }\n    }\n\n    constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;\n    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;\n\n    // Result of:\n    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h\n    if constexpr (num_bits == 4) {\n      constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};\n\n      uint32_t res = 0;\n  #pragma unroll\n      for (int i = 0; i < 8; i++) {\n        res |= vals[pack_idx[i]] << (i * 4);\n      }\n\n      out_ptr[out_offset + th_id * 4 + warp_id] = res;\n\n    } else {\n      constexpr int pack_idx[4] = {0, 2, 1, 3};\n\n      uint32_t res1 = 0;\n      uint32_t res2 = 0;\n  #pragma unroll\n      for (int i = 0; i < 4; i++) {\n        res1 |= vals[pack_idx[i]] << (i * 8);\n        res2 |= vals[4 + pack_idx[i]] << (i * 8);\n      }\n\n      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;\n      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;\n    }\n  };\n\n  auto start_pipes = [&](int k_tile_id, int n_tile_id) {\n  #pragma unroll\n    for (int pipe = 0; pipe < repack_stages - 1; pipe++) {\n      fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);\n    }\n\n    wait_for_stage();\n  };\n  #pragma unroll\n  for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {\n    int n_tile_id = 0;\n\n    if constexpr (has_perm) {\n      load_perm_to_shared(k_tile_id);\n    }\n\n    start_pipes(k_tile_id, n_tile_id);\n\n    while (n_tile_id < n_tiles) {\n  #pragma unroll\n      for (int pipe = 0; pipe < repack_stages; pipe++) {\n        fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,\n                        n_tile_id + pipe + repack_stages 
- 1);\n        repack_tile(pipe, k_tile_id, n_tile_id + pipe);\n        wait_for_stage();\n      }\n      n_tile_id += repack_stages;\n    }\n  }\n}\n\n}  // namespace gptq_marlin\n\n  #define CALL_IF(NUM_BITS, HAS_PERM)                                          \\\n    else if (num_bits == NUM_BITS && has_perm == HAS_PERM) {                   \\\n      cudaFuncSetAttribute(                                                    \\\n          gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads,       \\\n                                            NUM_BITS, HAS_PERM>,               \\\n          cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);        \\\n      gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, NUM_BITS, \\\n                                        HAS_PERM>                              \\\n          <<<blocks, gptq_marlin::repack_threads, max_shared_mem, stream>>>(   \\\n              b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);              \\\n    }\n\ntorch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,\n                                 int64_t size_k, int64_t size_n,\n                                 int64_t num_bits) {\n  // Verify compatibility with marlin tile of 16x64\n  TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, \"size_k = \", size_k,\n              \" is not divisible by tile_k_size = \", gptq_marlin::tile_k_size);\n  TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, \"size_n = \", size_n,\n              \" is not divisible by tile_n_size = \", gptq_marlin::tile_n_size);\n\n  TORCH_CHECK(num_bits == 4 || num_bits == 8,\n              \"num_bits must be 4 or 8. 
Got = \", num_bits);\n  int const pack_factor = 32 / num_bits;\n\n  // Verify B\n  TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0),\n              \"Shape mismatch: b_q_weight.size(0) = \", b_q_weight.size(0),\n              \", size_k = \", size_k, \", pack_factor = \", pack_factor);\n  TORCH_CHECK(b_q_weight.size(1) == size_n,\n              \"b_q_weight.size(1) = \", b_q_weight.size(1),\n              \" is not size_n = \", size_n);\n\n  // Verify device and strides\n  TORCH_CHECK(b_q_weight.device().is_cuda(), \"b_q_weight is not on GPU\");\n  TORCH_CHECK(b_q_weight.is_contiguous(), \"b_q_weight is not contiguous\");\n  TORCH_CHECK(b_q_weight.dtype() == at::kInt, \"b_q_weight type is not kInt\");\n\n  TORCH_CHECK(perm.device().is_cuda(), \"perm is not on GPU\");\n  TORCH_CHECK(perm.is_contiguous(), \"perm is not contiguous\");\n  TORCH_CHECK(perm.dtype() == at::kInt, \"perm type is not at::kInt\");\n\n  // Alloc buffers\n  const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));\n  auto options = torch::TensorOptions()\n                     .dtype(b_q_weight.dtype())\n                     .device(b_q_weight.device());\n  torch::Tensor out =\n      torch::empty({size_k / gptq_marlin::tile_size,\n                    size_n * gptq_marlin::tile_size / pack_factor},\n                   options);\n\n  // Detect if there is act_order\n  bool has_perm = perm.size(0) != 0;\n\n  // Get ptrs\n  uint32_t const* b_q_weight_ptr =\n      reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());\n  uint32_t const* perm_ptr = reinterpret_cast<uint32_t const*>(perm.data_ptr());\n  uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());\n\n  // Get dev info\n  int dev = b_q_weight.get_device();\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);\n  int blocks;\n  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);\n\n  int max_shared_mem = 0;\n  cudaDeviceGetAttribute(&max_shared_mem,\n                         
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);\n  TORCH_CHECK(max_shared_mem > 0);\n\n  if (false) {\n  }\n  CALL_IF(4, false)\n  CALL_IF(4, true)\n  CALL_IF(8, false)\n  CALL_IF(8, true)\n  else {\n    TORCH_CHECK(false, \"Unsupported repack config: num_bits = \", num_bits,\n                \", has_perm = \", has_perm);\n  }\n\n  return out;\n}\n\n#endif"
  },
  {
    "path": "archive/csrc/custom_marlin/gptq_marlin/ops.h",
    "content": "/**\n * @Description  :\n * @Author       : Azure\n * @Date         : 2024-07-22 09:27:55\n * @Version      : 1.0.0\n * @LastEditors  : Azure\n * @LastEditTime : 2024-07-26 08:35:00\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#pragma once\n\n#include <torch/extension.h>\n#include <torch/library.h>\n#include <torch/torch.h>\n\ntorch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,\n                               torch::Tensor &b_scales, torch::Tensor &g_idx,\n                               torch::Tensor &perm, torch::Tensor &workspace,\n                               int64_t num_bits, torch::Tensor size_m_tensor, int64_t size_m, int64_t size_n,\n                               int64_t size_k, int sms, bool is_k_full);\n\ntorch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor&perm,\n                                 int64_t size_k, int64_t size_n,\n                                 int64_t num_bits);"
  },
  {
    "path": "archive/csrc/custom_marlin/setup.py",
    "content": "from setuptools import setup, Extension\nfrom torch.utils import cpp_extension\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\nsetup(\n    name='vLLMMarlin',\n    ext_modules=[\n        CUDAExtension(\n            'vLLMMarlin', [\n                #'custom_gguf/dequant.cu',\n                'binding.cpp',\n                'gptq_marlin/gptq_marlin.cu',\n                'gptq_marlin/gptq_marlin_repack.cu',\n            ],\n            extra_compile_args={\n                'cxx': ['-O3'],\n                'nvcc': [\n                    '-O3',\n                    '--use_fast_math',\n                    '-Xcompiler', '-fPIC',\n                ]\n            },\n        )\n    ],\n    cmdclass={'build_ext': BuildExtension}\n)"
  },
  {
    "path": "archive/csrc/custom_marlin/test_cuda_graph.py",
    "content": "import csv\r\nimport torch\r\nimport torch.nn as nn\r\nimport vLLMMarlin\r\ntorch.set_grad_enabled(False)\r\nfrom utils.marlin_utils import (\r\n\tMarlinWorkspace,\r\n\tmarlin_quantize,\r\n\tGPTQ_MARLIN_MIN_THREAD_N,\r\n\tGPTQ_MARLIN_MIN_THREAD_K,\r\n\tGPTQ_MARLIN_MAX_PARALLEL,\r\n)\r\n\r\ndef setup_seed(seed):\r\n\ttorch.manual_seed(seed)\r\n\ttorch.cuda.manual_seed_all(seed)\r\n\r\nsetup_seed(20241223)\r\n\r\ntorch.set_grad_enabled(False)\r\ntorch.set_default_dtype(torch.bfloat16)\r\nglobal_dtype=torch.bfloat16\r\nglobal_device=torch.device(\"cuda\",0)\r\nglobal_num_cases:int=int(50)\r\ntorch.cuda.set_device(0)\r\ntorch.backends.cudnn.enabled =True\r\ntorch.backends.cudnn.benchmark = True\r\n\r\nmax_batch_size = 512\r\nmax_tp = 8\r\nL2_size = 73728 * 1024\r\n\r\ndef get_usable_mem():\r\n\tproperties = torch.cuda.get_device_properties(global_device)\r\n\t#print(f\"Total memory: {properties.total_memory / (1024 ** 3):.2f} GB\")\r\n\tallocated_memory = torch.cuda.memory_allocated(global_device)\r\n\t#print(f\"Currently allocated memory: {allocated_memory / (1024 ** 2):.2f} MB\")\r\n\treserved_memory = torch.cuda.memory_reserved(global_device)\r\n\t#print(f\"Currently reserved memory: {reserved_memory / (1024 ** 2):.2f} MB\")\r\n\treturn properties.total_memory - 512 * 1024 ** 2 - allocated_memory# - reserved_memory\r\n\r\ndef exp_range(start, stop, step = 2):\r\n\tnow = start\r\n\twhile now <= stop:\r\n\t\tyield now\r\n\t\tnow *= step\r\n\r\ndef timing(func, iters, epochs=100):\r\n\t#warmup\r\n\tfor idx in range(iters):\r\n\t\tfunc(idx)\r\n\t\t\r\n\ttorch.cuda.synchronize()\r\n\tcuda_graph = torch.cuda.CUDAGraph()\r\n\twith torch.cuda.graph(cuda_graph):\r\n\t\tfor idx in range(iters):\r\n\t\t\tfunc(idx)\r\n\r\n\tfor _ in range(2000):\r\n\t\tcuda_graph.replay()\r\n\r\n\tstart_event = torch.cuda.Event(enable_timing=True)\r\n\tend_event = torch.cuda.Event(enable_timing=True)\r\n\tstream = torch.cuda.Stream()\r\n\ttorch.cuda.synchronize()\r\n\t#with 
torch.cuda.stream(stream):\r\n\tstart_event.record()\r\n\tfor _ in range(10):\r\n\t\tcuda_graph.replay()\r\n\tend_event.record()\r\n\ttorch.cuda.synchronize()\r\n\telapsed_time_ms0 = start_event.elapsed_time(end_event)\r\n\t\r\n\tstart_event = torch.cuda.Event(enable_timing=True)\r\n\tend_event = torch.cuda.Event(enable_timing=True)\r\n\ttorch.cuda.synchronize()\r\n\t#with torch.cuda.stream(stream):\r\n\tstart_event.record()\r\n\tfor _ in range(epochs+10):\r\n\t\tcuda_graph.replay()\r\n\tend_event.record()\r\n\ttorch.cuda.synchronize()\r\n\telapsed_time_ms = start_event.elapsed_time(end_event) - elapsed_time_ms0\r\n\t\r\n\t#print(elapsed_time_ms0, elapsed_time_ms)\r\n\treturn elapsed_time_ms/iters/epochs\r\n\r\nclass LinearMarlin(nn.Linear):\r\n\tmarlin_q_w: torch.Tensor\r\n\tmarlin_s: torch.Tensor\r\n\tg_idx: torch.Tensor\r\n\tsort_indices: torch.Tensor\r\n\thas_bias: bool\r\n\tdef __init__(\r\n\t\tself,\r\n\t\tin_features,\r\n\t\tout_features,\r\n\t\tbias = False,\r\n\t\tdevice: str = \"cuda\",\r\n\t\tnum_bits: int = 4,  # 4-bit/8-bit is supported\r\n\t\tgroup_size: int = 64,  # -1, 32, 64, 128\r\n\t\tact_order: bool = False,\r\n\t\tis_k_full=True,\r\n\t\tsms = -1, # sms in GPU\r\n\t\t**kwargs,\r\n\t):\r\n\t\tself.padding = False\r\n\t\tassert device.lower() != \"cpu\", \"Marlin quantized linear only supports GPU device\"\r\n\t\tif in_features%GPTQ_MARLIN_MIN_THREAD_K!=0 or out_features%GPTQ_MARLIN_MIN_THREAD_N!=0:\r\n\t\t\t#print(f\"warning!, in_features={in_features} or out_features={out_features} is undivisible by GPTQ_MARLIN_MIN_THREAD_K={GPTQ_MARLIN_MIN_THREAD_K} and GPTQ_MARLIN_MIN_THREAD_N={GPTQ_MARLIN_MIN_THREAD_N}, padding\")\r\n\t\t\tself.padding = True\r\n\t\t\tself.orin_in_features = in_features\r\n\t\t\tself.orin_out_features = out_features\r\n\t\t\tin_features = (in_features+GPTQ_MARLIN_MIN_THREAD_K-1)//GPTQ_MARLIN_MIN_THREAD_K*GPTQ_MARLIN_MIN_THREAD_K\r\n\t\t\tout_features = 
(out_features+GPTQ_MARLIN_MIN_THREAD_N-1)//GPTQ_MARLIN_MIN_THREAD_N*GPTQ_MARLIN_MIN_THREAD_N\r\n\t\t\t#print(f\"After padding: in_features={in_features}, out_features={out_features}\")\r\n\t\t\t\r\n\r\n\t\tsuper().__init__(in_features, out_features, bias, device)\r\n\t\tself.has_bias = bias\r\n\t\tself.device = device\r\n\t\tself.num_bits = num_bits\r\n\t\tself.group_size = group_size\r\n\t\tself.act_order = act_order\r\n\t\t# TODO: optimize every shape GEMM\r\n\t\t\r\n\t\tblocks_k, blocks_n = in_features//128, out_features//128\r\n\r\n\t\tself.sms = sms\r\n\r\n\t\tself.is_k_full = is_k_full\r\n\t\t\r\n\t\tself.weight.requires_grad = False\r\n\t\tself.weight.t_()\r\n\t\t# Pack Marlin linear\r\n\t\t#w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(\r\n\t\t#    self.weight, self.num_bits, self.group_size, self.act_order\r\n\t\t#)\r\n\t\tmarlin_q_w = torch.randint(int(-1e9), int(1e9), (in_features//16, out_features*2), device=device, dtype=torch.int)\r\n\t\tmarlin_s = torch.randn((in_features//64, out_features), device=device)\r\n\t\tself.workspace = MarlinWorkspace(\r\n\t\t\tself.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL, self.device\r\n\t\t)\r\n\t\tself.marlin_q_w = marlin_q_w\r\n\t\tself.marlin_s = marlin_s\r\n\t\tself.g_idx = torch.empty((0), dtype=torch.int32, device=self.device)\r\n\t\tself.sort_indices = torch.empty((0), dtype=torch.int32, device=self.device)\r\n\t\tself.k = self.weight.shape[0]\r\n\t\tself.n = self.weight.shape[1]\r\n\t\tself.weight = None\r\n\t\t\"\"\"\r\n\t\tprint(in_features, 
out_features)\r\n\t\tprint(marlin_q_w.shape)\r\n\t\tprint(marlin_q_w.dtype)\r\n\t\tprint(marlin_s.shape)\r\n\t\tprint(marlin_s.dtype)\r\n\t\tprint(self.workspace.scratch.shape)\r\n\t\tprint(self.workspace.scratch.dtype)\r\n\t\tprint(self.g_idx.shape)\r\n\t\tprint(self.g_idx.dtype)\r\n\t\tprint(self.sort_indices.shape)\r\n\t\tprint(self.sort_indices.dtype)\r\n\t\t#print(w_ref.shape)\r\n\t\t#print(w_ref.dtype)\r\n\t\t\"\"\"\r\n\t\t#w_ref = None\r\n\r\n\tdef forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor) -> torch.Tensor:\r\n\t\t# Only support input x as BF16 and FP16\r\n\t\tx = x.to(self.device)\r\n\t\torig_shape = list(x.shape)\r\n\t\torig_dtype = x.dtype\r\n\t\tx = x.reshape(-1, x.shape[-1])\r\n\t\tif self.padding:\r\n\t\t\tpadding_input=torch.empty(x.shape[0], self.in_features, device=x.device, dtype=x.dtype)\r\n\t\t\tpadding_input[:,:self.orin_in_features] = x\r\n\t\t\tx = padding_input\r\n\t\tmarlin_s = self.marlin_s.to(x.dtype)\r\n\t\t#print(self.sms * ((orig_shape[0]+63)//64))\r\n\t\t\r\n\t\tsms = self.sms\r\n\r\n\t\tx = vLLMMarlin.gptq_marlin_gemm(\r\n\t\t\tx,\r\n\t\t\tself.marlin_q_w,\r\n\t\t\tmarlin_s,\r\n\t\t\tself.g_idx,\r\n\t\t\tself.sort_indices,\r\n\t\t\tself.workspace.scratch,\r\n\t\t\tself.num_bits,\r\n\t\t\tbsz_tensor,\r\n\t\t\tx.shape[0],\r\n\t\t\tself.n,\r\n\t\t\tx.shape[-1],\r\n\t\t\tsms,\r\n\t\t\tself.is_k_full,\r\n\t\t)\r\n\t\t# TODO: don't padding bias\r\n\t\tif self.has_bias:\r\n\t\t\tx = x + self.bias\r\n\t\tif self.padding:\r\n\t\t\tx = x[:,:self.orin_out_features]\r\n\t\t\torig_shape[-1] = self.orin_out_features\r\n\t\telse:\r\n\t\t\torig_shape[-1] = self.out_features\r\n\t\treturn x.reshape(orig_shape).to(orig_dtype)\r\n\r\ndef benchLinearMarlin(input_dim, output_dim):#, out_file\r\n\tprint(\"benchmarking MLP Marlin\")\r\n\tprint(\"-----------------------------------------------------------\")\r\n\theaders = [\"batch_size\", \"tp\", \"used_time\", \"bandwidth GB/s\", \"TFLOPS\", \"cases\", \"padding\", \"sms\"]\r\n\tprint(\" | 
\".join(headers) + \"\\n\")\r\n\trows = []\r\n\tfor batch_size in exp_range(1, 64):\r\n\t\tfor tp in exp_range(1, max_tp):\r\n\t\t\ttorch.cuda.empty_cache()\r\n\t\t\tif output_dim % tp != 0:\r\n\t\t\t\tcontinue\r\n\t\t\tcur_output_dim = output_dim // tp\r\n\t\t\tmodules = []\r\n\t\t\tinputs = []\r\n\t\t\tdata_size = int(0.53125*input_dim*cur_output_dim)\r\n\t\t\tinput_size = int(2*batch_size*input_dim)\r\n\t\t\toutput_size = int(2*batch_size*cur_output_dim)\r\n\t\t\tusable_mem = get_usable_mem() - 2 * input_dim * cur_output_dim\r\n\t\t\tmin_cases = max(global_num_cases, (2*L2_size) // (data_size+input_size))\r\n\t\t\tcases = int(min(min_cases, (usable_mem * 0.8) // (data_size+input_size)))\r\n\t\t\t#print(usable_mem, data_size, input_size, cases)\r\n\t\t\t\t\r\n\t\t\tbsz_tensor = torch.tensor([batch_size], device=global_device, dtype=torch.int32)\r\n\r\n\t\t\tif cases == 0:\r\n\t\t\t\trow = [f\"{batch_size}\", \"OOM\", \"OOM\", \"OOM\", \"0\", \"False\"]\r\n\t\t\t\trows.append(row)\r\n\t\t\t\tbreak\r\n\t\t\tfor _ in range(cases):\r\n\t\t\t\tmodules.append(LinearMarlin(input_dim, cur_output_dim, sms=56, non_equal_division=False).to(device=global_device).eval())\r\n\t\t\t\tinputs.append(torch.randn(batch_size, 1, input_dim, device=global_device))\r\n\t\t\t\t\r\n\t\t\tdef forward(case_id):\r\n\t\t\t\tmodules[case_id](inputs[case_id], bsz_tensor)\r\n\t\t\t\t\r\n\t\t\tused_time = timing(forward, iters=cases)\r\n\t\t\tbandwidth = (data_size+input_size+output_size)/used_time/1e6\r\n\t\t\tflops = 2*batch_size*input_dim*cur_output_dim\r\n\t\t\ttflops = flops/used_time/1e9\r\n\t\t\tcur_sms = modules[0].sms\r\n\t\t\trow = [f\"{batch_size}\", f\"{tp}\", f\"{used_time}\", f\"{bandwidth}\", f\"{tflops}\", f\"{cases}\", modules[0].padding, cur_sms]\r\n\t\t\trows.append(row)\r\n\t\t\tprint(f\"{batch_size}\", f\"{tp}\", f\"{used_time}\", f\"{bandwidth}\", f\"{tflops}\", f\"{cases}\", modules[0].padding, cur_sms)\r\n\t\r\n\t\"\"\"\r\n\twith open(out_file, 'w', newline='') as 
csvfile:\r\n\t\tcsvwriter = csv.writer(csvfile)\r\n\t\tcsvwriter.writerow(headers)\r\n\t\tfor row in rows:\r\n\t\t\tcsvwriter.writerow(row)\r\n\t\"\"\"\r\n\t\r\n\t\"\"\"\r\n\tmarkdown_table = \" | \".join(headers) + \"\\n\"\r\n\tmarkdown_table += \" | \".join([\"---\"] * len(headers)) + \"\\n\"\r\n\tfor row in rows:\r\n\t\tmarkdown_table += \" | \".join(row) + \"\\n\"\r\n\r\n\tprint(markdown_table)\r\n\t\"\"\"\r\n\t#print(\"finish write file\", out_file)\r\n\t#print(\"-------------------------------------------------------------\")\r\n\r\nif __name__ == \"__main__\":\r\n\t\r\n\tbenchLinearMarlin(5120, 3584)\r\n\texit(0)\r\n\t\r\n\tmax_batch = 1\r\n\tcur_batch = 1\r\n\r\n\r\n\tmarlin_linear = LinearMarlin(5120, 3584)\r\n\r\n\tinput_tensor = torch.randn(max_batch, 1, 5120, device=\"cuda\", dtype=torch.bfloat16)\r\n\tbsz_tensor = torch.tensor([max_batch], device=\"cuda\", dtype=torch.int32)\r\n\r\n\tout_truth = marlin_linear(input_tensor, bsz_tensor)\r\n\r\n\tprint(out_truth)\r\n\r\n\tg = torch.cuda.CUDAGraph()\r\n\twith torch.cuda.graph(g):\r\n\t\tout_buf = marlin_linear(input_tensor, bsz_tensor)\r\n\t\r\n\tfor i in range(10000):\r\n\t\tg.replay()\r\n\t\r\n\t#torch.testing.assert_close(out_buf, out_truth, rtol=1e-3, atol=1e-3)\r\n\t\r\n\tmarlin_linear = LinearMarlin(5120, 3584)\r\n\tg = torch.cuda.CUDAGraph()\r\n\twith torch.cuda.graph(g):\r\n\t\tout_buf = marlin_linear(input_tensor, bsz_tensor)\r\n\t\r\n\tnew_input = torch.randn(cur_batch, 1, 5120, device=\"cuda\", dtype=torch.bfloat16)\r\n\tbsz_tensor.copy_(torch.tensor([cur_batch], device=\"cuda\", dtype=torch.int32))\r\n\t\r\n\tnew_out_truth = marlin_linear(new_input, bsz_tensor)\r\n\tinput_tensor[:cur_batch].copy_(new_input)\r\n\tinput_tensor[cur_batch:] = 0\r\n\t\r\n\tg.replay()\r\n\t\r\n\ttorch.cuda.synchronize()\r\n\r\n\tdef printMinMax(tensor):\r\n\t\tabs_tensor = torch.abs(tensor)\r\n\r\n\t\tmin_val = torch.min(abs_tensor)\r\n\t\tmax_val = torch.max(abs_tensor)\r\n\r\n\t\tmin_indices = (abs_tensor == 
min_val).nonzero(as_tuple=True)\r\n\t\tmax_indices = (abs_tensor == max_val).nonzero(as_tuple=True)\r\n\r\n\t\tprint(f\"min: {min_val.item()}\")\r\n\t\tprint(f\"min idx: {min_indices}\")\r\n\t\tprint(f\"max: {max_val.item()}\")\r\n\t\tprint(f\"max idx: {max_indices}\")\r\n\r\n\tprint(out_buf[:cur_batch].shape)\r\n\tprint(new_out_truth.shape)\r\n\r\n\r\n\tprintMinMax(out_buf[:cur_batch])\r\n\tprintMinMax(new_out_truth)\r\n\r\n\t#torch.testing.assert_close(out_buf[:cur_batch, 0, :], new_out_truth[:cur_batch, 0, :], rtol=1e-3, atol=1e-3)\r\n"
  },
  {
    "path": "archive/csrc/custom_marlin/utils/__init__.py",
    "content": ""
  },
  {
    "path": "archive/csrc/custom_marlin/utils/format24.py",
    "content": "#\n# Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es).\n#\n\nimport torch\n\n\n# This is PyTorch implementation of main part of reorder_meta()\n# function, from tools/util/include/cutlass/util/host_reorder.h file\n# of CUTLASS source tree.  Furthermore, CUTLASS template for sparse\n# GEMM decides upon layout of this matrix, and at the moment for the\n# sparse GEMM executed on tensor cores, this is layout described by\n# ColumnMajorInterleaved<2> data structure, in\n# include/cutlass/layout/matrix.h of CUTLASS source tree.  The\n# reordering of meta matrix into meta_reordered matrix calculated\n# according to these segments of CUTLASS code is re-implemented here.\n# Note that this calculation produces offsets for scattering metadata\n# matrix elements into reordered metadata matrix elements (or,\n# equivalently, for gathering reordered metadata matrix element back\n# into metadata matrix elements).\ndef _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype,\n                                               device):\n    dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)\n    dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)\n\n    # Reorder the rows, then swizzle the 2x2 blocks.\n    group_x = 64\n    group_y = 32 if meta_dtype.itemsize == 2 else 16\n\n    dst_rows = (dst_rows // group_x * group_x + (dst_rows % 2) * 2 +\n                (dst_rows % 8) // 4 + ((dst_rows % group_y) % 4) // 2 * 32 +\n                ((dst_rows % group_x) // 8) * 4)\n\n    topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)\n    bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)\n    dst_rows += topright - bottomleft\n    dst_cols -= topright - bottomleft\n\n    # Assumed that meta tensor is to be stored in CUTLASS\n    # InterleavedColumnMajor layout, and reverse engineered\n    # corresponding code to store values into this tensor.\n    interleave = 2\n    
cols_maj = dst_cols // interleave\n    cols_min = dst_cols % interleave\n    return (cols_maj * m * interleave + dst_rows * interleave +\n            cols_min).view(-1)\n\n\n# This function converts dense matrix into sparse semi-structured\n# representation, producing \"compressed\" matrix, in the layout used by\n# CUTLASS backend, and corresponding metadata matrix.\ndef sparse_semi_structured_from_dense_cutlass(dense):\n    if dense.dim() != 2:\n        raise RuntimeError(\n            f\"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor\"  # noqa: E501\n        )\n\n    m, k = dense.shape\n    device = dense.device\n\n    meta_dtype = torch.int8\n    if dense.dtype == torch.int8:\n        meta_dtype = torch.int32\n    elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:\n        meta_dtype = torch.int16\n    else:\n        raise RuntimeError(f\"Invalid datatype {dense.dtype} of dense matrix\")\n    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4\n    if quadbits_per_meta_elem not in (4, 8):\n        raise RuntimeError(\n            \"Invalid number of elements per meta element calculated\")\n\n    if meta_dtype == torch.int32:\n        if m % 16 != 0:\n            raise RuntimeError(\n                f\"Number of rows of dense matrix {m} must be divisible by 16\")\n    else:\n        if m % 32 != 0:\n            raise RuntimeError(\n                f\"Number of rows of dense matrix {m} must be divisible by 32\")\n    if k % (4 * quadbits_per_meta_elem) != 0:\n        raise RuntimeError(\n            f\"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}\"  # noqa: E501\n        )\n\n    if dense.dtype != torch.float:\n        ksparse = 4\n        dense_4 = dense.view(-1, k // ksparse, ksparse)\n        m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)\n    else:\n        ksparse = 2\n        dense_2 = dense.view(-1, k // ksparse, ksparse)\n        m0, m2 = m1, m3 = (dense_2 != 
0).unbind(-1)\n    meta_ncols = k // (ksparse * quadbits_per_meta_elem)\n\n    # Encoding quadruples of True/False values as follows:\n    #     [True,  True,  False, False] -> 0b0100\n    #     [True,  False, True,  False] -> 0b1000\n    #     [False, True,  True,  False] -> 0b1001\n    #     [True,  False, False, True ] -> 0b1100\n    #     [False, True,  False, True ] -> 0b1101\n    #     [False, False, True,  True ] -> 0b1110\n    # Thus, lower two bits in the encoding are index of the True value\n    # at the lowest index in the quadruple, and the higher two bits in\n    # the encoding are index of the other True value in the quadruple.\n    # In case there are fewer than two True values, then False value or\n    # values at some index or indices are considered True for the\n    # encoding.  In case there are more than two True values, then the\n    # excess True value(s) at some indices are considered False for\n    # the encoding.  The exact encodings used for these cases are as\n    # follows:\n    #     [False, False, False, False] -> 0b1110\n    #     [False, False, False, True ] -> 0b1110\n    #     [False, False, True,  False] -> 0b1110\n    #     [False, True,  False, False] -> 0b1001\n    #     [False, True,  True,  True ] -> 0b1101\n    #     [True,  False, False, False] -> 0b1000\n    #     [True,  False, True,  True ] -> 0b1100\n    #     [True,  True,  False, True ] -> 0b0100\n    #     [True,  True,  True,  False] -> 0b0100\n    #     [True,  True,  True,  True ] -> 0b0100\n    # These particular encodings are chosen, with the help of Espresso\n    # logic minimizer software, for the purpose of minimization of\n    # corresponding Boolean functions, that translate non-zero flags\n    # into encoding bits.  
Note also possible choices for the first\n    # and last of these encodings were limited only to (0b0100,\n    # 0b1110), in order to produce valid encodings for 1:2 sparsity\n    # case.\n\n    expr0 = m0 & m1\n    expr1 = ~m0 & m1\n    expr2 = ~m0 & ~m1\n    bit0 = expr1\n    bit1 = expr2\n    bit2 = expr0 | expr2 | m3\n    bit3 = expr1 | ~m1\n    idxs0 = bit0 | (bit1.to(torch.int64) << 1)\n    idxs1 = bit2 | (bit3.to(torch.int64) << 1)\n\n    if dense.dtype != torch.float:\n        sparse0 = dense_4.gather(\n            -1, idxs0.unsqueeze(-1))  # type: ignore[possibly-undefined]\n        sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))\n        sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)\n    else:\n        sparse = dense_2.gather(-1,\n                                idxs0.unsqueeze(-1) // 2).view(\n                                    m,\n                                    k // 2)  # type: ignore[possibly-undefined]\n\n    meta_4 = idxs0 | (idxs1 << 2)\n    meta_n = meta_4.view(\n        (-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)\n\n    if quadbits_per_meta_elem == 4:\n        meta = (meta_n[:, :, 0]\n                | (meta_n[:, :, 1] << 4)\n                | (meta_n[:, :, 2] << 8)\n                | (meta_n[:, :, 3] << 12))\n    elif quadbits_per_meta_elem == 8:\n        meta = (meta_n[:, :, 0]\n                | (meta_n[:, :, 1] << 4)\n                | (meta_n[:, :, 2] << 8)\n                | (meta_n[:, :, 3] << 12)\n                | (meta_n[:, :, 4] << 16)\n                | (meta_n[:, :, 5] << 20)\n                | (meta_n[:, :, 6] << 24)\n                | (meta_n[:, :, 7] << 28))\n\n    # Reorder meta tensor elements.\n    meta_reordered = meta.new_empty(\n        (m * meta_ncols, ))  # type: ignore[possibly-undefined]\n    meta_offsets = _calculate_meta_reordering_scatter_offsets(\n        m, meta_ncols, meta_dtype, device)\n    meta_reordered.scatter_(0, meta_offsets, meta.view(-1))\n\n    return (sparse, 
meta_reordered.view(m, meta_ncols))\n\n\n# This function performs reverse of the function above - it\n# reconstructs dense matrix from a pair of \"compressed\" matrix, given\n# in the layout used by CUTLASS backend, and accompanying metadata\n# matrix.\ndef sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):\n    if sparse.dim() != 2:\n        raise RuntimeError(\n            f\"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor\"  # noqa: E501\n        )\n\n    m, k = sparse.shape\n    device = sparse.device\n\n    if meta_reordered.dim() != 2:\n        raise RuntimeError(\n            f\"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor\"  # noqa: E501\n        )\n    if meta_reordered.device != device:\n        raise RuntimeError(\n            f\"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device\"  # noqa: E501\n        )\n\n    meta_dtype = meta_reordered.dtype\n    if meta_dtype not in (torch.int16, torch.int32):\n        raise RuntimeError(f\"Invalid datatype {meta_dtype} of meta matrix\")\n    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4\n\n    ksparse = 4 if sparse.dtype != torch.float else 2\n\n    meta_nrows, meta_ncols = meta_reordered.shape\n    if meta_nrows != m:\n        raise RuntimeError(\n            f\"Number of rows of meta matrix {meta_nrows} must be equal to number of rows of sparse matrix {m}\"  # noqa: E501\n        )\n    if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:\n        raise RuntimeError(\n            f\"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, \"  # noqa: E501\n            \"expected according to the number of columns of meta matrix\")\n\n    # Undo meta tensor elements reordering.\n    meta_offsets = _calculate_meta_reordering_scatter_offsets(\n        m, meta_ncols, meta_dtype, device)\n    meta = 
torch.gather(meta_reordered.view(-1), 0,\n                        meta_offsets).view(m, meta_ncols)\n\n    # Unpack sparse tensor back to original dense tensor, using\n    # information provided by meta tensor.  Note that torch.float\n    # datatype is handled pretty much the same as\n    # torch.half/torch.bfloat16, as metadata for a pair of torch.float\n    # value is encoded as if underlying 8 bytes contain four\n    # torch.half/torch.bfloat16 values, where either first two or last\n    # two are zeros.\n    meta_2 = torch.empty(\n        (m, meta_ncols, 2 * quadbits_per_meta_elem),\n        dtype=meta_dtype,\n        device=device,\n    )\n    if quadbits_per_meta_elem == 4:\n        meta_2[:, :, 0] = meta & 0b11\n        meta_2[:, :, 1] = (meta >> 2) & 0b11\n        meta_2[:, :, 2] = (meta >> 4) & 0b11\n        meta_2[:, :, 3] = (meta >> 6) & 0b11\n        meta_2[:, :, 4] = (meta >> 8) & 0b11\n        meta_2[:, :, 5] = (meta >> 10) & 0b11\n        meta_2[:, :, 6] = (meta >> 12) & 0b11\n        meta_2[:, :, 7] = (meta >> 14) & 0b11\n    elif quadbits_per_meta_elem == 8:\n        meta_2[:, :, 0] = meta & 0b11\n        meta_2[:, :, 1] = (meta >> 2) & 0b11\n        meta_2[:, :, 2] = (meta >> 4) & 0b11\n        meta_2[:, :, 3] = (meta >> 6) & 0b11\n        meta_2[:, :, 4] = (meta >> 8) & 0b11\n        meta_2[:, :, 5] = (meta >> 10) & 0b11\n        meta_2[:, :, 6] = (meta >> 12) & 0b11\n        meta_2[:, :, 7] = (meta >> 14) & 0b11\n        meta_2[:, :, 8] = (meta >> 16) & 0b11\n        meta_2[:, :, 9] = (meta >> 18) & 0b11\n        meta_2[:, :, 10] = (meta >> 20) & 0b11\n        meta_2[:, :, 11] = (meta >> 22) & 0b11\n        meta_2[:, :, 12] = (meta >> 24) & 0b11\n        meta_2[:, :, 13] = (meta >> 26) & 0b11\n        meta_2[:, :, 14] = (meta >> 28) & 0b11\n        meta_2[:, :, 15] = (meta >> 30) & 0b11\n\n    dense_offsets = meta_2.view(-1) + (\n        torch.arange(0, 2 * m * k // ksparse, device=device) * 4).view(\n            -1, 1).repeat(1, 2).view(-1)\n\n 
   dense = torch.zeros((m * 2 * k, ), dtype=sparse.dtype, device=device)\n    if sparse.dtype != torch.float:\n        # dense.scatter_(0, dense_offsets, sparse.view(-1))\n        dense.scatter_(0, dense_offsets, sparse.reshape(-1))\n    else:\n        dense.view(torch.half).scatter_(0, dense_offsets,\n                                        sparse.view(torch.half).view(-1))\n\n    return dense.view(m, 2 * k)\n\n\ndef mask_creator(tensor):\n    \"\"\"\n    Create an N:M sparsity mask for the given tensor.\n    The tensor is split into blocks of M weights; in each block the\n    M - N weights with the smallest magnitude are pruned (mask value 0)\n    and the remaining N are kept (mask value 1). N and M are fixed at\n    2 and 4, i.e. 2:4 sparsity.\n\n    :param tensor: The tensor to mask; its element count must be a\n        multiple of M\n    :return: A mask of ones and zeros with the same shape as the tensor\n    \"\"\"\n    N = 2\n    M = 4\n\n    mask = None\n    # for i, tensor in enumerate(tensors):\n    if tensor.numel() % M != 0:\n        raise ValueError(\n            f\"Tensor of size {tensor.shape} can't be evenly divided into \"\n            f\"groups of {M}\")\n\n    num_groups = tensor.numel() // M\n\n    # N:M sparsity for linear layers\n    tensor_temp = tensor.detach().abs().reshape(num_groups, M)\n    index = torch.argsort(tensor_temp, dim=1)[:, :int(M - N)]\n\n    w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device)\n    mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape)\n\n    return mask"
  },
  {
    "path": "archive/csrc/custom_marlin/utils/marlin_24_perms.py",
    "content": "'''\nDate: 2024-11-08 02:46:07\nLastEditors: djw\nLastEditTime: 2024-11-08 02:46:41\n'''\n\"\"\"This file is used for /tests and /benchmarks\"\"\"\nfrom typing import Dict, List\n\nimport numpy\nimport torch\n\n\n# Precompute permutations for Marlin24 weight and scale shuffling # noqa: E501\n#\n# Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501\n# with the tensor-core format that is described here:\n# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501\n#\n# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501\n# (without the need to use ldmatrix instructions) # noqa: E501\ndef get_perms_24(num_bits: int):\n    perm_list: List[int] = []\n    for i in range(32):\n        perm1: List[int] = []\n        col = i // 4\n        col_o = col // 2\n        for block in [0, 1]:\n            for row in [\n                    2 * (i % 4),\n                    2 * (i % 4) + 1,\n                    2 * (i % 4 + 4),\n                    2 * (i % 4 + 4) + 1,\n            ]:\n                perm1.append(16 * row + col_o * 256 + 8 * (col % 2) +\n                             4 * block)\n        for j in range(4):\n            perm_list.extend([p + 1 * j for p in perm1])\n    perm = numpy.array(perm_list)\n\n    if num_bits == 4:\n        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])\n    elif num_bits == 8:\n        interleave = numpy.array([0, 2, 1, 3])\n    else:\n        raise ValueError(\"num_bits must be 4 or 8, got {}\".format(num_bits))\n\n    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()\n    perm = torch.from_numpy(perm)\n    scale_perm: List[int] = []\n    for i in range(8):\n        scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])\n    scale_perm_single: List[int] = []\n 
   for i in range(8):\n        scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])\n    return perm, scale_perm, scale_perm_single\n\n\nmarlin_24_perm: Dict[int, torch.Tensor] = {}\nmarlin_24_scale_perm: Dict[int, List[int]] = {}\nmarlin_24_scale_perm_single: Dict[int, List[int]] = {}\nfor num_bits in [4, 8]:\n    perm_24, scale_perm_24, scale_perm_single_24 = get_perms_24(num_bits)\n    marlin_24_perm[num_bits] = perm_24\n    marlin_24_scale_perm[num_bits] = scale_perm_24\n    marlin_24_scale_perm_single[num_bits] = scale_perm_single_24"
  },
  {
    "path": "archive/csrc/custom_marlin/utils/marlin_perms.py",
    "content": "'''\nDate: 2024-11-08 02:46:47\nLastEditors: djw\nLastEditTime: 2024-11-08 02:46:55\n'''\n\"\"\"This file is used for /tests and /benchmarks\"\"\"\nfrom typing import Dict, List\n\nimport numpy\nimport torch\n\n\n# Precompute permutations for Marlin weight and scale shuffling # noqa: E501\n#\n# Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501\n# with the tensor-core format that is described here:\n# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501\n#\n# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501\n# (without the need to use ldmatrix instructions) # noqa: E501\ndef get_perms(num_bits: int):\n    perm_list: List[int] = []\n    for i in range(32):\n        perm1: List[int] = []\n        col = i // 4\n        for block in [0, 1]:\n            for row in [\n                    2 * (i % 4),\n                    2 * (i % 4) + 1,\n                    2 * (i % 4 + 4),\n                    2 * (i % 4 + 4) + 1,\n            ]:\n                perm1.append(16 * row + col + 8 * block)\n        for j in range(4):\n            perm_list.extend([p + 256 * j for p in perm1])\n\n    perm = numpy.array(perm_list)\n\n    if num_bits == 4:\n        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])\n    elif num_bits == 8:\n        interleave = numpy.array([0, 2, 1, 3])\n    else:\n        raise Exception(\"num_bits must be 4 or 8, got {}\".format(num_bits))\n\n    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()\n    perm = torch.from_numpy(perm)\n    scale_perm: List[int] = []\n    for i in range(8):\n        scale_perm.extend([i + 8 * j for j in range(8)])\n    scale_perm_single: List[int] = []\n    for i in range(4):\n        scale_perm_single.extend(\n            [2 * i + j for j in [0, 1, 8, 
9, 16, 17, 24, 25]])\n    return perm, scale_perm, scale_perm_single\n\n\nmarlin_perm: Dict[int, torch.Tensor] = {}\nmarlin_scale_perm: Dict[int, List[int]] = {}\nmarlin_scale_perm_single: Dict[int, List[int]] = {}\nfor num_bits in [4, 8]:\n    perm, scale_perm, scale_perm_single = get_perms(num_bits)\n    marlin_perm[num_bits] = perm\n    marlin_scale_perm[num_bits] = scale_perm\n    marlin_scale_perm_single[num_bits] = scale_perm_single"
  },
  {
    "path": "archive/csrc/custom_marlin/utils/marlin_utils.py",
    "content": "\"\"\"This file is used for /tests and /benchmarks\"\"\"\nimport random\n\nimport numpy\nimport torch\n\nfrom .format24 import (\n    mask_creator, sparse_semi_structured_from_dense_cutlass)\nfrom .marlin_24_perms import (\n    marlin_24_perm, marlin_24_scale_perm, marlin_24_scale_perm_single)\nfrom .marlin_perms import (\n    marlin_perm, marlin_scale_perm, marlin_scale_perm_single)\nfrom .quant_utils import (\n    get_pack_factor, quantize_weights, sort_weights, dequantize_weights)\n\n\n\n__cuda_arch = torch.cuda.get_device_capability()\n\nMARLIN_TILE = 16\n\nGPTQ_MARLIN_TILE = 16\nGPTQ_MARLIN_MIN_THREAD_N = 64\nGPTQ_MARLIN_MIN_THREAD_K = 128\nGPTQ_MARLIN_MAX_PARALLEL = 16\n\nGPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]\nGPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]\nGPTQ_MARLIN_SUPPORTED_SYM = [True]\n\ndef is_marlin_supported():\n    return __cuda_arch[0] >= 8\n\n\ndef marlin_permute_weights(q_w, size_k, size_n, perm, tile=MARLIN_TILE):\n    assert q_w.shape == (size_k, size_n)\n    assert size_k % tile == 0, f\"size_k = {size_k}, tile = {tile}\"\n    assert size_n % tile == 0, f\"size_k = {size_n}, tile = {tile}\"\n\n    # Permute weights to 16x64 marlin tiles\n    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))\n    q_w = q_w.permute((0, 2, 1, 3))\n    q_w = q_w.reshape((size_k // tile, size_n * tile))\n\n    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)\n\n    return q_w\n\n\ndef marlin_weights(q_w, size_k, size_n, num_bits, perm):\n    # Permute\n    q_w = marlin_permute_weights(q_w, size_k, size_n, perm)\n\n    # Pack\n    pack_factor = get_pack_factor(num_bits)\n    orig_device = q_w.device\n\n    q_w = q_w.cpu().numpy().astype(numpy.uint32)\n\n    q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),\n                           dtype=numpy.uint32)\n    for i in range(pack_factor):\n        q_packed |= q_w[:, i::pack_factor] << num_bits * i\n\n    q_packed = 
torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device)\n\n    return q_packed\n\n\ndef marlin_permute_scales(s, size_k, size_n, group_size, scale_perm,\n                          scale_perm_single):\n    if group_size < size_k and group_size != -1:\n        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]\n    else:\n        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]\n    s = s.reshape((-1, size_n)).contiguous()\n\n    return s\n\n\ndef marlin_quantize(\n    w: torch.Tensor,\n    num_bits: int,\n    group_size: int,\n    act_order: bool,\n):\n    size_k, size_n = w.shape\n\n    # Normalize group_size\n    if group_size == -1:\n        group_size = size_k\n    assert group_size <= size_k\n\n    # Quantize (and apply act_order if provided)\n    w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size,\n                                                       act_order)\n\n    # For act_order, sort the \"weights\" and \"g_idx\" so that group ids are\n    # increasing\n    sort_indices = torch.empty(0, dtype=torch.int, device=w.device)\n    if act_order:\n        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)\n\n    # Reformat to marlin\n    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits,\n                                marlin_perm[num_bits])\n    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size,\n                                     marlin_scale_perm[num_bits],\n                                     marlin_scale_perm_single[num_bits])\n\n    # Create result\n    res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]\n    for i in range(len(res_list)):\n        res_list[i] = res_list[i].to(w.device)\n\n    return res_list\n\n\ndef inject_24(w, size_k, size_n):\n    assert w.shape == (size_k, size_n)\n\n    mask = mask_creator(w.t()).t().cuda().bool()\n\n    return (mask * w).contiguous(), mask.contiguous()\n\n\ndef check_24(w, num_rows_to_sample=50, _verbose=False):\n    
BLOCK_SIZE = 4\n    MAX_NON_ZEROS = 2\n\n    w = w.t().contiguous()\n\n    print(\"check_24: w.shape = {}\".format(w.shape))\n\n    num_rows, num_cols = w.shape\n    sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)\n    if _verbose:\n        print(f\"Sampled row idxs = {sampled_row_idxs}\")\n\n    total_segments = 0\n    non_24_segments = 0\n    for i in sampled_row_idxs:\n        # The \"+ 1\" keeps the last full block of each row in range.\n        for j in range(0, num_cols - BLOCK_SIZE + 1, BLOCK_SIZE):\n            total_segments += 1\n            block = w[i, j:j + BLOCK_SIZE]\n            num_nonzero = torch.count_nonzero(block)\n            if num_nonzero > MAX_NON_ZEROS:\n                print(\"i = {} j = {} block = {}\".format(i, j, block))\n                non_24_segments += 1\n\n    print(f\"{non_24_segments} / {total_segments} do not have 2:4 structure.\")\n\n\ndef compress_quantized_24_weight(q_24, size_k, size_n, num_bits):\n    assert q_24.shape == (size_k, size_n)\n\n    # Remove zp to normalize over 0\n    max_q_val = (1 << num_bits) - 1\n    zp = (max_q_val + 1) // 2\n    q_24_no_zp = q_24 - zp\n\n    # Compress\n    q_24_no_zp = q_24_no_zp.t().contiguous()\n    q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(\n        q_24_no_zp)\n    q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous()\n\n    # Restore zp\n    q_24_comp = q_24_no_zp_comp + zp\n\n    # Resize meta to its actual shape (without moving any data)\n    meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)\n\n    return q_24_comp, meta\n\n\ndef marlin_24_quantize(\n    w: torch.Tensor,\n    num_bits: int,\n    group_size: int,\n):\n    size_k, size_n = w.shape\n\n    # Normalize group_size\n    if group_size == -1:\n        group_size = size_k\n    assert group_size <= size_k\n\n    # Inject 2:4 sparsity\n    w_24, mask_24 = inject_24(w, size_k, size_n)\n\n    # Quantize\n    w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24,\n                                                             num_bits,\n    
                                                         group_size,\n                                                             act_order=False)\n\n    # Compress quantized weight\n    q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n,\n                                                     num_bits)\n    size_k_comp = size_k // 2\n\n    # Reformat to marlin\n    marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n,\n                                        num_bits, marlin_24_perm[num_bits])\n    marlin_24_s = marlin_permute_scales(s, size_k, size_n, group_size,\n                                        marlin_24_scale_perm[num_bits],\n                                        marlin_24_scale_perm_single[num_bits])\n\n    # Create result\n    res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]\n    for i in range(len(res_list)):\n        res_list[i] = res_list[i].to(w.device)\n\n    return res_list\n\n\ndef compute_max_diff(output, output_ref):\n    return torch.mean(torch.abs(output - output_ref)) / torch.mean(\n        torch.abs(output_ref))\n\n\nclass MarlinWorkspace:\n\n    def __init__(self, out_features, min_thread_n, max_parallel, device):\n        assert (out_features % min_thread_n == 0), (\n            \"out_features = {} is not divisible by min_thread_n = {}\".format(\n                out_features, min_thread_n))\n\n        max_workspace_size = ((out_features // min_thread_n) * max_parallel)\n\n        self.scratch = torch.zeros(max_workspace_size,\n                                   dtype=torch.int,\n                                   device=device)"
  },
  {
    "path": "archive/csrc/custom_marlin/utils/quant_utils.py",
    "content": "\"\"\"This file is used for /tests and /benchmarks\"\"\"\nimport numpy\nimport torch\n\nSUPPORTED_NUM_BITS = [4, 8]\nSUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]\n\n\ndef get_pack_factor(num_bits):\n    assert num_bits in SUPPORTED_NUM_BITS, f\"Unsupported num_bits = {num_bits}\"\n    return 32 // num_bits\n\n\ndef permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):\n    assert q_w.shape == w_ref.shape\n\n    orig_device = q_w.device\n    k_size, _ = q_w.shape\n\n    g_idx = torch.zeros((k_size, ), dtype=torch.int32)\n    for i in range(k_size):\n        g_idx[i] = i // group_size\n\n    # Simulate act_order by doing a random permutation on K\n    rand_perm = torch.randperm(k_size)\n\n    g_idx = g_idx[rand_perm].contiguous()\n    q_w = q_w[rand_perm, :].contiguous()\n    w_ref = w_ref[rand_perm, :].contiguous()\n\n    return (\n        w_ref.to(device=orig_device),\n        q_w.to(device=orig_device),\n        g_idx.to(device=orig_device),\n        rand_perm.to(device=orig_device),\n    )\n\n\n# Function: Dequantize quantized weights\ndef dequantize_weights(qweight, qzeros, scales, g_idx, bits=4, group_size=128, device='cuda:0'):\n    # Create a tensor for bitwise right shift operation\n    wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32, device=device).unsqueeze(0)\n\n    # Apply bitwise right shift and convert qzeros to the appropriate type\n    zeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits), wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)\n    torch.bitwise_and(zeros, (2 ** bits) - 1, out=zeros)\n\n    # Reshape the zeros tensor\n    zeros = zeros + 1\n    zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])\n\n    # Reshape the scales tensor\n    scales = scales.reshape(-1, 1, scales.shape[-1])\n\n    # Similar bitwise right shift operation for qweight and reshape\n    weight = torch.bitwise_right_shift(torch.unsqueeze(qweight, 1).expand(-1, 32 // 
bits, -1), wf.unsqueeze(-1)).to(torch.int16 if bits == 8 else torch.int8)\n    torch.bitwise_and(weight, (2 ** bits) - 1, out=weight)\n    weight = weight.reshape(-1, group_size, weight.shape[2])\n\n    # Apply dequantization formula and reshape the final weight\n    weight = (scales * (weight - zeros))\n    weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])\n\n    # Return the transposed weight\n    return weight.transpose(0, 1)\n\ndef quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,\n                     act_order: bool):\n    orig_device = w.device\n    size_k, size_n = w.shape\n\n    assert w.is_floating_point(), \"w must be float\"\n    assert num_bits in SUPPORTED_NUM_BITS, f\"Unsupported num_bits = {num_bits}\"\n    assert group_size in SUPPORTED_GROUP_SIZES + [\n        size_k\n    ], f\"Unsupported groupsize = {group_size}\"\n\n    if group_size == -1:\n        group_size = size_k\n    assert group_size <= size_k\n\n    max_q_val = 2**num_bits - 1\n    half_q_val = (max_q_val + 1) // 2\n\n    # Reshape to [groupsize, -1]\n    if group_size < size_k:\n        w = w.view((-1, group_size, size_n))\n        w = w.permute(1, 0, 2)\n        w = w.reshape((group_size, -1))\n\n    # Compute scale for each group\n    s = torch.max(torch.abs(w), 0, keepdim=True)[0]\n    s *= 2 / max_q_val  # 2 => symmetric\n\n    # Quantize\n    q_w = torch.round(w / s).int()\n    q_w += half_q_val\n    q_w = torch.clamp(q_w, 0, max_q_val)\n\n    # Compute ref (dequantized)\n    w_ref = (q_w - half_q_val).half() * s\n\n    # Restore original shapes\n    if group_size < size_k:\n\n        def reshape_w(w):\n            w = w.reshape((group_size, -1, size_n))\n            w = w.permute(1, 0, 2)\n            w = w.reshape((size_k, size_n)).contiguous()\n            return w\n\n        q_w = reshape_w(q_w)\n        w_ref = reshape_w(w_ref)\n\n    s = s.reshape((-1, size_n)).contiguous()\n\n    # Apply act_order\n    g_idx = torch.empty(0, 
dtype=torch.int, device=w.device)\n    rand_perm = torch.empty(0, dtype=torch.int, device=w.device)\n    if act_order:\n        assert (\n            group_size < size_k\n        ), \"For act_order, groupsize = {} must be less than size_k = {}\".format(\n            group_size, size_k)\n\n        w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size)\n\n    return (\n        w_ref.to(device=orig_device),\n        q_w.to(device=orig_device),\n        s.to(device=orig_device),\n        g_idx.to(device=orig_device),\n        rand_perm.to(device=orig_device),\n    )\n\n\ndef sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):\n    orig_device = q_w.device\n\n    sort_indices = torch.argsort(g_idx).to(\n        dtype=torch.int32)  # Sort based on g_idx\n\n    g_idx = g_idx[sort_indices].contiguous()\n    q_w = q_w[sort_indices, :].contiguous()\n\n    return (\n        q_w.to(device=orig_device),\n        g_idx.to(device=orig_device),\n        sort_indices.to(device=orig_device),\n    )\n\n\ndef gptq_pack(\n    q_w: torch.Tensor,\n    num_bits: int,\n    size_k: int,\n    size_n: int,\n):\n    assert q_w.shape == (size_k, size_n)\n\n    pack_factor = get_pack_factor(num_bits)\n    assert size_k % pack_factor == 0\n\n    orig_device = q_w.device\n\n    q_w = q_w.cpu().numpy().astype(numpy.uint32)\n\n    q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)\n\n    for i in range(pack_factor):\n        q_res |= q_w[i::pack_factor, :] << num_bits * i\n\n    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)\n    return q_res\n\ndef gptq_unpack(\n    q_res: torch.Tensor,\n    num_bits: int,\n    size_k: int,\n    size_n: int,\n):\n    pack_factor = 32 // num_bits\n    assert size_k % pack_factor == 0\n\n    orig_device = q_res.device\n\n    q_res = q_res.cpu().numpy()\n\n    q_w = numpy.zeros((size_k, size_n), dtype=numpy.uint32)\n\n    for i in range(pack_factor):\n        q_w[i::pack_factor, :] = (q_res >> (num_bits * i)) & 
((1 << num_bits) - 1)\n\n    q_w = torch.from_numpy(q_w.astype(numpy.int32)).to(orig_device)\n    return q_w"
  },
  {
    "path": "archive/csrc/ktransformers_ext/CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.16)\nproject(cpuinfer_ext VERSION 0.1.0)\n\n\nset(CMAKE_CXX_STANDARD 17)\n\n\nset(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -O3 -ffast-math -fopenmp\")\nadd_compile_definitions(_GLIBCXX_USE_CXX11_ABI=${_GLIBCXX_USE_CXX11_ABI})\nset(CMAKE_BUILD_TYPE \"Release\")\n# set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -g -ffast-math -fopenmp\")\n# set(CMAKE_BUILD_TYPE \"Debug\")\nset(CMAKE_EXPORT_COMPILE_COMMANDS ON)\n\n\ninclude(CheckCXXCompilerFlag)\nset(CMAKE_POSITION_INDEPENDENT_CODE ON)\n\n\noption(LLAMA_NATIVE                     \"llama: enable -march=native flag\"                      ON)\n\n# instruction set specific\nif (LLAMA_NATIVE)\n    set(INS_ENB OFF)\nelse()\n    set(INS_ENB ON)\nendif()\n\noption(LLAMA_AVX                             \"llama: enable AVX\"                                OFF)\noption(LLAMA_AVX2                            \"llama: enable AVX2\"                               OFF)\noption(LLAMA_AVX512                          \"llama: enable AVX512\"                             OFF)\noption(LLAMA_AVX512_VBMI                     \"llama: enable AVX512-VBMI\"                        OFF)\noption(LLAMA_AVX512_VNNI                     \"llama: enable AVX512-VNNI\"                        OFF)\noption(LLAMA_AVX512_BF16                     \"llama: enable AVX512-BF16\"                        OFF)\noption(LLAMA_FMA                             \"llama: enable FMA\"                                OFF)\n# in MSVC F16C is implied with AVX2/AVX512\nif (NOT MSVC)\n    option(LLAMA_F16C                        \"llama: enable F16C\"                               OFF)\nendif()\noption(LLAMA_AVX512_FANCY_SIMD               \"llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI\"                        OFF)\noption(KTRANSFORMERS_USE_CUDA                \"ktransformers: use CUDA\"                          ON)\noption(KTRANSFORMERS_USE_MUSA                \"ktransformers: use MUSA\"                          
OFF)\noption(KTRANSFORMERS_USE_ROCM                \"ktransformers: use ROCM\"                          OFF)\noption(KTRANSFORMERS_USE_XPU                 \"ktransformers: use XPU\"                           OFF)\noption(KTRANSFORMERS_USE_NPU                 \"ktransformers: use NPU\"                           OFF)\n\nif(KTRANSFORMERS_USE_NPU)\n    add_definitions(-DKTRANSFORMERS_USE_NPU=1)\nendif()\n\n# Architecture specific\n# TODO: probably these flags need to be tweaked on some architectures\n#       feel free to update the Makefile for your architecture and send a pull request or issue\nmessage(STATUS \"CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}\")\nif (MSVC)\n    string(TOLOWER \"${CMAKE_GENERATOR_PLATFORM}\" CMAKE_GENERATOR_PLATFORM_LWR)\n    message(STATUS \"CMAKE_GENERATOR_PLATFORM: ${CMAKE_GENERATOR_PLATFORM}\")\nelse ()\n    set(CMAKE_GENERATOR_PLATFORM_LWR \"\")\nendif ()\n\nif (NOT MSVC)\n    if (LLAMA_STATIC)\n        add_link_options(-static)\n        if (MINGW)\n            add_link_options(-static-libgcc -static-libstdc++)\n        endif()\n    endif()\n    if (LLAMA_GPROF)\n        add_compile_options(-pg)\n    endif()\nendif()\n\nset(ARCH_FLAGS \"\")\n\nif (CMAKE_OSX_ARCHITECTURES STREQUAL \"arm64\" OR CMAKE_GENERATOR_PLATFORM_LWR STREQUAL \"arm64\" OR\n    (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND\n     CMAKE_SYSTEM_PROCESSOR MATCHES \"^(aarch64|arm.*|ARM64)$\"))\n    message(STATUS \"ARM detected\")\n    if (MSVC)\n        add_compile_definitions(__aarch64__) # MSVC defines _M_ARM64 instead\n        add_compile_definitions(__ARM_NEON)\n        add_compile_definitions(__ARM_FEATURE_FMA)\n\n        set(CMAKE_REQUIRED_FLAGS_PREV ${CMAKE_REQUIRED_FLAGS})\n        string(JOIN \" \" CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS} \"/arch:armv8.2\")\n        check_cxx_source_compiles(\"#include <arm_neon.h>\\nint main() { int8x16_t _a, _b; int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }\" 
GGML_COMPILER_SUPPORT_DOTPROD)\n        if (GGML_COMPILER_SUPPORT_DOTPROD)\n            add_compile_definitions(__ARM_FEATURE_DOTPROD)\n        endif ()\n        check_cxx_source_compiles(\"#include <arm_neon.h>\\nint main() { float16_t _a; float16x8_t _s = vdupq_n_f16(_a); return 0; }\" GGML_COMPILER_SUPPORT_FP16_VECTOR_ARITHMETIC)\n        if (GGML_COMPILER_SUPPORT_FP16_VECTOR_ARITHMETIC)\n            add_compile_definitions(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)\n        endif ()\n        set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_PREV})\n    else()\n        if(KTRANSFORMERS_USE_NPU)\n            list(APPEND ARCH_FLAGS -march=armv8.2-a+fp16+fp16fml+dotprod -lnuma)\n        endif()\n        check_cxx_compiler_flag(-mfp16-format=ieee COMPILER_SUPPORTS_FP16_FORMAT_I3E)\n        if (NOT \"${COMPILER_SUPPORTS_FP16_FORMAT_I3E}\" STREQUAL \"\")\n            list(APPEND ARCH_FLAGS -mfp16-format=ieee)\n        endif()\n        if (${CMAKE_SYSTEM_PROCESSOR} MATCHES \"armv6\")\n            # Raspberry Pi 1, Zero\n            list(APPEND ARCH_FLAGS -mfpu=neon-fp-armv8 -mno-unaligned-access)\n        endif()\n        if (${CMAKE_SYSTEM_PROCESSOR} MATCHES \"armv7\")\n            if (\"${CMAKE_SYSTEM_NAME}\" STREQUAL \"Android\")\n                # Android armeabi-v7a\n                list(APPEND ARCH_FLAGS -mfpu=neon-vfpv4 -mno-unaligned-access -funsafe-math-optimizations)\n            else()\n                # Raspberry Pi 2\n                list(APPEND ARCH_FLAGS -mfpu=neon-fp-armv8 -mno-unaligned-access -funsafe-math-optimizations)\n            endif()\n        endif()\n        if (${CMAKE_SYSTEM_PROCESSOR} MATCHES \"armv8\")\n            # Android arm64-v8a\n            # Raspberry Pi 3, 4, Zero 2 (32-bit)\n            list(APPEND ARCH_FLAGS -mno-unaligned-access)\n        endif()\n    endif()\nelseif (CMAKE_OSX_ARCHITECTURES STREQUAL \"x86_64\" OR CMAKE_GENERATOR_PLATFORM_LWR MATCHES \"^(x86_64|i686|amd64|x64|win32)$\" OR\n        (NOT CMAKE_OSX_ARCHITECTURES AND NOT 
CMAKE_GENERATOR_PLATFORM_LWR AND\n         CMAKE_SYSTEM_PROCESSOR MATCHES \"^(x86_64|i686|AMD64)$\"))\n    message(STATUS \"x86 detected\")\n    if(NOT KTRANSFORMERS_USE_NPU)\n        set(HOST_IS_X86 TRUE)\n        set(HAS_AVX512 TRUE)\n        set(__HAS_AMX__ TRUE)\n        add_compile_definitions(__x86_64__)\n        # check AVX512 (by searching the lscpu output)\n        execute_process(\n            COMMAND lscpu\n            OUTPUT_VARIABLE LSCPU_OUTPUT\n            OUTPUT_STRIP_TRAILING_WHITESPACE\n        )\n        # message(STATUS \"LSCPU_OUTPUT: ${LSCPU_OUTPUT}\")\n    \n        string(FIND \"${LSCPU_OUTPUT}\" \"avx512\" COMPILER_SUPPORTS_AVX512F)\n        \n        if (COMPILER_SUPPORTS_AVX512F GREATER -1)\n            message(STATUS \"CPU reports AVX512 support (detected from lscpu output)\")\n            add_compile_definitions(__HAS_AVX512F__)\n        else()\n            message(STATUS \"CPU does NOT report AVX512 support (per lscpu output)\")\n            set(HAS_AVX512 False)\n        endif()\n    \n        # check AMX (by searching the lscpu output)\n        string(FIND \"${LSCPU_OUTPUT}\" \"amx\" COMPILER_SUPPORTS_AMX)\n        \n        if(COMPILER_SUPPORTS_AMX GREATER -1)\n            message(STATUS \"CPU reports AMX support (detected from lscpu output)\")\n            add_compile_definitions(__HAS_AMX__)\n        else()\n            message(STATUS \"CPU does NOT report AMX support (per lscpu output)\")\n        endif()\n    endif()\n    if (MSVC)\n        # instruction set detection for MSVC only\n        if (LLAMA_NATIVE)\n            include(cmake/FindSIMD.cmake)\n        endif ()\n        if (LLAMA_AVX512)\n            list(APPEND ARCH_FLAGS /arch:AVX512)\n            # MSVC has no compile-time flags enabling specific\n            # AVX512 extensions, neither it defines the\n            # macros corresponding to the extensions.\n            # Do it manually.\n            if (LLAMA_AVX512_VBMI)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VBMI__>)\n                
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VBMI__>)\n            endif()\n            if (LLAMA_AVX512_VNNI)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)\n            endif()\n            if (LLAMA_AVX512_FANCY_SIMD)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VL__>)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VL__>)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512BW__>)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512BW__>)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512DQ__>)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512DQ__>)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)\n            endif()\n            if (LLAMA_AVX512_BF16)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512BF16__>)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512BF16__>)\n            endif()\n        elseif (LLAMA_AVX2)\n            list(APPEND ARCH_FLAGS /arch:AVX2)\n        elseif (LLAMA_AVX)\n            list(APPEND ARCH_FLAGS /arch:AVX)\n        endif()\n    else()\n        if (LLAMA_NATIVE)\n            list(APPEND ARCH_FLAGS -mfma -mavx -mavx2)\n            list(APPEND ARCH_FLAGS -march=native)\n        endif()\n        if (LLAMA_F16C)\n            list(APPEND ARCH_FLAGS -mf16c)\n        endif()\n        if (LLAMA_FMA)\n            list(APPEND ARCH_FLAGS -mfma)\n        endif()\n        if (LLAMA_AVX)\n            list(APPEND ARCH_FLAGS -mavx)\n        endif()\n        if (LLAMA_AVX2)\n            list(APPEND ARCH_FLAGS -mavx2)\n        endif()\n        if (LLAMA_AVX512)\n            list(APPEND ARCH_FLAGS -mavx512f)\n    
        list(APPEND ARCH_FLAGS -mavx512bw)\n        endif()\n        if (LLAMA_AVX512_VBMI)\n            list(APPEND ARCH_FLAGS -mavx512vbmi)\n        endif()\n        if (LLAMA_AVX512_VNNI)\n            list(APPEND ARCH_FLAGS -mavx512vnni)\n        endif()\n        if (LLAMA_AVX512_FANCY_SIMD)\n            message(STATUS \"AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI enabled\")\n            list(APPEND ARCH_FLAGS -mavx512vl)\n            list(APPEND ARCH_FLAGS -mavx512bw)\n            list(APPEND ARCH_FLAGS -mavx512dq)\n            list(APPEND ARCH_FLAGS -mavx512vnni)\n            list(APPEND ARCH_FLAGS -mavx512vpopcntdq)\n        endif()\n        if (LLAMA_AVX512_BF16)\n            list(APPEND ARCH_FLAGS -mavx512bf16)\n        endif()\n    endif()\nelseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES \"ppc64\")\n    message(STATUS \"PowerPC detected\")\n    if (${CMAKE_SYSTEM_PROCESSOR} MATCHES \"ppc64le\")\n        list(APPEND ARCH_FLAGS -mcpu=powerpc64le)\n    else()\n        list(APPEND ARCH_FLAGS -mcpu=native -mtune=native)\n        #TODO: Add  targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be)\n    endif()\nelse()\n    message(STATUS \"Unknown architecture\")\nendif()\n\n# message(STATUS \"CUDAToolkit_ROOT:${CUDAToolkit_ROOT}\")\n# find_package(FindCUDAToolkit REQUIRED)\n# if(CUDAToolkit_FOUND)\n#     message(STATUS \"Found CUDA cudart lib at:${CUDAToolkit_LIBRARY_DIR}\")\n# else()\n#     message(STATUS \"Can't found CUDA lib\")\n# endif()\n\nif (NOT EXISTS $ENV{ROCM_PATH})\n    if (NOT EXISTS /opt/rocm)\n        set(ROCM_PATH /usr)\n    else()\n        set(ROCM_PATH /opt/rocm)\n    endif()\nelse()\n    set(ROCM_PATH $ENV{ROCM_PATH})\nendif()\n\nlist(APPEND CMAKE_PREFIX_PATH  ${ROCM_PATH})\nlist(APPEND CMAKE_PREFIX_PATH \"${ROCM_PATH}/lib64/cmake\")\n\nif (NOT EXISTS $ENV{MUSA_PATH})\n    if (NOT EXISTS /opt/musa)\n        set(MUSA_PATH /usr/local/musa)\n    else()\n        set(MUSA_PATH /opt/musa)\n    
endif()\nelse()\n    set(MUSA_PATH $ENV{MUSA_PATH})\nendif()\n\nlist(APPEND CMAKE_MODULE_PATH \"${MUSA_PATH}/cmake\")\n\nadd_compile_options(\"$<$<COMPILE_LANGUAGE:CXX>:${ARCH_FLAGS}>\")\nadd_compile_options(\"$<$<COMPILE_LANGUAGE:C>:${ARCH_FLAGS}>\")\n\nadd_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/pybind11 ${CMAKE_CURRENT_BINARY_DIR}/third_party/pybind11)\nadd_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llama.cpp ${CMAKE_CURRENT_BINARY_DIR}/third_party/llama.cpp)\n\ninclude_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party)\nif (WIN32)\n    include_directories(\"$ENV{CUDA_PATH}/include\")\n    add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)\nelseif (UNIX)\n    if (KTRANSFORMERS_USE_ROCM)\n        find_package(HIP REQUIRED)\n        if(HIP_FOUND)\n            include_directories(\"${HIP_INCLUDE_DIRS}\")\n            add_compile_definitions(KTRANSFORMERS_USE_ROCM=1)\n        endif()\n    elseif (KTRANSFORMERS_USE_MUSA)\n        if (NOT EXISTS $ENV{MUSA_PATH})\n            if (NOT EXISTS /opt/musa)\n                set(MUSA_PATH /usr/local/musa)\n            else()\n                set(MUSA_PATH /opt/musa)\n            endif()\n        else()\n            set(MUSA_PATH $ENV{MUSA_PATH})\n        endif()\n\n        list(APPEND CMAKE_MODULE_PATH \"${MUSA_PATH}/cmake\")\n\n        find_package(MUSAToolkit)\n        if (MUSAToolkit_FOUND)\n            message(STATUS \"MUSA Toolkit found\")\n            add_compile_definitions(KTRANSFORMERS_USE_MUSA=1)\n        endif()\n    elseif (KTRANSFORMERS_USE_XPU)\n        add_compile_definitions(KTRANSFORMERS_USE_XPU=1)\n    elseif (KTRANSFORMERS_USE_CUDA)\n        find_package(CUDA REQUIRED)\n        include_directories(\"${CUDA_INCLUDE_DIRS}\")\n        include(CheckLanguage)\n        check_language(CUDA)\n        if(CMAKE_CUDA_COMPILER)\n            message(STATUS \"CUDA detected\")\n            find_package(CUDAToolkit REQUIRED)\n            
include_directories(${CUDAToolkit_INCLUDE_DIRS})\n        endif()\n        message(STATUS \"enabling CUDA\")\n        enable_language(CUDA)\n        add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)\n    endif()\nendif()\n\naux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1)\naux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend SOURCE_DIR2)\naux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/llamafile SOURCE_DIR3)\n# aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile SOURCE_DIR4)\nfile(GLOB LLAMAFILE_SOURCES \"${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile/*.cpp\")\nlist(REMOVE_ITEM LLAMAFILE_SOURCES\n    \"${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile/sgemm_arm.cpp\"\n    \"${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llamafile/sgemm_x86.cpp\"\n)\nset(SOURCE_DIR4 ${LLAMAFILE_SOURCES})\naux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/kvcache SOURCE_DIR5)\n\nif (HOST_IS_X86 AND HAS_AVX512 AND __HAS_AMX__)\n    aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/amx SOURCE_DIR6)\nendif()\n\n\nset(ALL_SOURCES ${SOURCE_DIR1} ${SOURCE_DIR2} ${SOURCE_DIR3} ${SOURCE_DIR4} ${SOURCE_DIR5} ${SOURCE_DIR6})\n\nfile(GLOB_RECURSE FMT_SOURCES \"${CMAKE_CURRENT_SOURCE_DIR}/*.cpp\" \"${CMAKE_CURRENT_SOURCE_DIR}/*.hpp\" \"${CMAKE_CURRENT_SOURCE_DIR}/*.h\")\n\nadd_custom_target(\n    format\n    COMMAND clang-format\n    -i\n    -style=file\n    ${FMT_SOURCES}\n    COMMENT \"Running clang-format on all source files\"\n)\n\n\nadd_library(llamafile STATIC ${SOURCE_DIR4})\n\nmessage(STATUS \"CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}\")\nmessage(STATUS \"ARCH_FLAGS: ${ARCH_FLAGS}\")\npybind11_add_module(${PROJECT_NAME} MODULE ${ALL_SOURCES})\ntarget_link_libraries(${PROJECT_NAME} PRIVATE llama)\n\n\nif(WIN32)\n    target_link_libraries(${PROJECT_NAME} PRIVATE \"$ENV{CUDA_PATH}/lib/x64/cudart.lib\")#CUDA::cudart\nelseif(UNIX)\n    if (KTRANSFORMERS_USE_ROCM)\n        add_compile_definitions(USE_HIP=1)\n    
    target_link_libraries(${PROJECT_NAME} PRIVATE \"${ROCM_PATH}/lib/libamdhip64.so\")\n        message(STATUS \"Building for HIP\")\n    elseif(KTRANSFORMERS_USE_MUSA)\n        target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)\n    elseif(KTRANSFORMERS_USE_XPU)\n    elseif(KTRANSFORMERS_USE_CUDA AND NOT KTRANSFORMERS_USE_MUSA)\n        target_link_libraries(${PROJECT_NAME} PRIVATE \"${CUDAToolkit_LIBRARY_DIR}/libcudart.so\")\n    endif()\nendif()\n\n# Define the USE_NUMA option (off by default)\noption(USE_NUMA \"Enable NUMA support\" OFF)\n\n# Enable NUMA if the USE_NUMA environment variable is defined (any value)\nif(DEFINED ENV{USE_NUMA})\n    set(USE_NUMA ON)\nendif()\n\nif(USE_NUMA)\n    message(STATUS \"NUMA support is enabled\")\nelse()\n    message(STATUS \"NUMA support is disabled\")\nendif()\n\nfind_library(NUMA_LIBRARY NAMES numa)\n\nif(NUMA_LIBRARY AND USE_NUMA)\n    message(STATUS \"NUMA library found: ${NUMA_LIBRARY} - enabling NUMA support\")\n    target_link_libraries(${PROJECT_NAME} PRIVATE ${NUMA_LIBRARY})\n    target_compile_definitions(${PROJECT_NAME} PRIVATE USE_NUMA)\nelse()\n    if(USE_NUMA)\n        message(FATAL_ERROR \"NUMA library not found - try: sudo apt install libnuma-dev\")\n    else()\n        message(STATUS \"NUMA library not found or USE_NUMA not set - disabling NUMA support\")\n    endif()\nendif()\n"
  },
  {
    "path": "archive/csrc/ktransformers_ext/bench/bench_attention.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nDescription  :  \nAuthor       : Jianwei Dong\nDate         : 2024-08-28 10:32:05\nVersion      : 1.0.0\nLastEditors  : Jianwei Dong \nLastEditTime : 2024-08-28 10:32:05\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n\"\"\"\nimport os, sys\nimport time\n\nsys.path.append(os.path.dirname(__file__) + \"/../build\")\nimport cpuinfer_ext\nimport torch\n\nlayer_num = 10\nkv_head_num = 8\nq_head_num = 32\nhead_dim = 128\nblock_len = 128\nanchor_num = 1\n\nanchor_type = cpuinfer_ext.kvcache.AnchorType.DYNAMIC\nkv_type = cpuinfer_ext.kvcache.ggml_type.FP16\nretrieval_type = cpuinfer_ext.kvcache.RetrievalType.LAYER\nlayer_step: int = 1\ntoken_step: int = 1\nlayer_offset: int = 0\nmax_thread_num: int = 64\nmax_batch_size: int = 1\nmax_block_num: int = 1024\nCPUInfer = cpuinfer_ext.CPUInfer(max_thread_num)\n\nwarm_up_iter = 1000\ntest_iter = 10000\n\n\ndef bench_linear(cache_seqlen: int):\n    with torch.inference_mode(mode=True):\n        cache_seqlens = torch.tensor([cache_seqlen], dtype=torch.int32, device=\"cpu\")\n        seqlens_zero = torch.zeros((1,), dtype=torch.int32, device=\"cpu\")\n\n        config = cpuinfer_ext.kvcache.KVCacheConfig(\n            layer_num,\n            kv_head_num,\n            q_head_num,\n            head_dim,\n            block_len,\n            anchor_num,\n            anchor_type,\n            kv_type,\n            retrieval_type,\n            layer_step,\n            token_step,\n            layer_offset,\n            max_block_num,\n            max_batch_size,\n            max_thread_num,\n        )\n        local_kvcache = cpuinfer_ext.kvcache.KVCache(config)\n        block_table = (\n            torch.arange(max_block_num, dtype=torch.int32, device=\"cpu\")\n            .contiguous()\n            .view(1, -1)\n        )\n\n        for layer_idx in range(layer_num):\n            k_cache = torch.randn(\n                (1, cache_seqlen, kv_head_num, head_dim),\n   
             dtype=torch.float16,\n                device=\"cpu\",\n            ).contiguous()\n            v_cache = torch.randn(\n                (1, cache_seqlen, kv_head_num, head_dim),\n                dtype=torch.float16,\n                device=\"cpu\",\n            ).contiguous()\n\n            CPUInfer.submit(\n                local_kvcache.update_kvcache_fp16(\n                    k_cache.data_ptr(),\n                    v_cache.data_ptr(),\n                    layer_idx,\n                    block_table.data_ptr(),\n                    1,\n                    max_block_num,\n                    seqlens_zero.data_ptr(),\n                    cache_seqlen,\n                )\n            )\n            CPUInfer.sync()\n\n        input = torch.randn(\n            (1, 1, q_head_num, head_dim), dtype=torch.float16, device=\"cpu\"\n        ).contiguous()\n        output = torch.empty(\n            (1, 1, q_head_num, head_dim), dtype=torch.float16, device=\"cpu\"\n        ).contiguous()\n\n        # attn_lse: (bsz, q_len, q_head_num)\n        attn_lse = torch.empty(\n            (1, 1, q_head_num), dtype=torch.float32, device=\"cpu\"\n        ).contiguous()\n        input = input / 100\n\n        # warm up\n        for i in range(warm_up_iter):\n            CPUInfer.submit(\n                local_kvcache.attn(\n                    input.data_ptr(),\n                    output.data_ptr(),\n                    attn_lse.data_ptr(),\n                    i % layer_num,\n                    0,\n                    1,\n                    1,\n                    max_block_num,\n                    block_table.data_ptr(),\n                    cache_seqlens.data_ptr(),\n                    -1,\n                    -1,\n                    -1,\n                )\n            )\n            CPUInfer.sync()\n\n        # test\n        start = time.perf_counter()\n        for i in range(test_iter):\n            CPUInfer.submit(\n                local_kvcache.attn(\n           
         input.data_ptr(),\n                    output.data_ptr(),\n                    attn_lse.data_ptr(),\n                    i % layer_num,\n                    0,\n                    1,\n                    1,\n                    max_block_num,\n                    block_table.data_ptr(),\n                    cache_seqlens.data_ptr(),\n                    -1,\n                    -1,\n                    -1,\n                )\n            )\n            CPUInfer.sync()\n        end = time.perf_counter()\n        total_time = end - start\n        print(\"cache sequence length: \", cache_seqlen)\n        print(\"Time(s): \", total_time)\n        print(\"Iteration: \", test_iter)\n        print(\"Time(us) per iteration: \", total_time / test_iter * 1000000)\n        print(\n            \"Bandwidth: \",\n            cache_seqlen\n            * kv_head_num\n            * head_dim\n            * 2\n            * 2\n            * test_iter\n            / total_time\n            / 1000\n            / 1000\n            / 1000,\n            \"GB/s\",\n        )\n        print(\"\")\n\n\nbench_linear(1024)\nbench_linear(4096)\nbench_linear(16384)\nbench_linear(32768)\nbench_linear(65536)\n"
  },
  {
    "path": "archive/csrc/ktransformers_ext/bench/bench_attention_torch.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nDescription  :  \nAuthor       : Jianwei Dong\nDate         : 2024-08-28 10:32:05\nVersion      : 1.0.0\nLastEditors  : Jianwei Dong \nLastEditTime : 2024-08-28 10:32:05\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n\"\"\"\nimport os, sys\nimport time\n\nsys.path.append(os.path.dirname(__file__) + \"/../build\")\nimport cpuinfer_ext\nimport torch\n\nlayer_num = 10\nkv_head_num = 8\nq_head_num = 32\nhead_dim = 128\nblock_len = 128\nanchor_num = 1\nwarm_up_iter = 1000\ntest_iter = 10000\n\n\ndef bench_linear(cache_seqlen: int, device):\n    with torch.inference_mode(mode=True):\n\n        kvcaches = []\n\n        for layer_idx in range(layer_num):\n            k_cache = torch.randn(\n                (1, 32, cache_seqlen, head_dim),\n                dtype=torch.float16,\n                device=device,\n            ).contiguous()\n            v_cache = torch.randn(\n                (1, 32, cache_seqlen, head_dim),\n                dtype=torch.float16,\n                device=device,\n            ).contiguous()\n\n            kvcaches.append((k_cache, v_cache))\n\n        input = torch.randn(\n            (1, q_head_num, 1, head_dim), dtype=torch.float16, device=device\n        ).contiguous()\n        input = input / 100\n\n        # warm up\n        for i in range(warm_up_iter):\n            k_cache = kvcaches[i % layer_num][0]\n            v_cache = kvcaches[i % layer_num][1]\n            torch.nn.functional.scaled_dot_product_attention(input, k_cache, v_cache)\n\n        # test\n        start = time.perf_counter()\n        for i in range(test_iter):\n            k_cache = kvcaches[i % layer_num][0]\n            v_cache = kvcaches[i % layer_num][1]\n            torch.nn.functional.scaled_dot_product_attention(input, k_cache, v_cache)\n        end = time.perf_counter()\n        total_time = end - start\n        print(\"cache sequence length: \", cache_seqlen)\n        print(\"Time(s): \", 
total_time)\n        print(\"Iteration: \", test_iter)\n        print(\"Time(us) per iteration: \", total_time / test_iter * 1000000)\n        print(\n            \"Bandwidth: \",\n            cache_seqlen\n            * q_head_num\n            * head_dim\n            * 2\n            * 2\n            * test_iter\n            / total_time\n            / 1000\n            / 1000\n            / 1000,\n            \"GB/s\",\n        )\n        print(\"\")\n\n\nbench_linear(1024, \"cpu\")\nbench_linear(4096, \"cpu\")\nbench_linear(1024, \"cuda\")\nbench_linear(4096, \"cuda\")\nbench_linear(16384, \"cuda\")\nbench_linear(32768, \"cuda\")\nbench_linear(65536, \"cuda\")\n"
  },
  {
    "path": "archive/csrc/ktransformers_ext/bench/bench_linear.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : chenht2022\nDate         : 2024-07-25 10:31:59\nVersion      : 1.0.0\nLastEditors  : chenht2022 \nLastEditTime : 2024-08-06 10:35:35\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport os, sys\nimport time\nsys.path.append(os.path.dirname(__file__) + '/../build')\nimport cpuinfer_ext\nimport torch\n\ninput_size = 16384\noutput_size = 5120\nstride = 16\ngroup_max_len = 1024\nlayer_num = 10\nqlen = 1\nCPUInfer = cpuinfer_ext.CPUInfer(64)\nwarm_up_iter = 1000\ntest_iter = 10000\n\ndef bench_linear(quant_mode: str):\n    with torch.inference_mode(mode=True):\n\n        hidden_type = 30 # ggml_type::GGML_TYPE_BF16\n        if quant_mode == \"fp32\":\n            proj_type = 0 # ggml_type::GGML_TYPE_F32\n            bytes_per_elem = 4.000000\n        elif quant_mode == \"fp16\":\n            proj_type = 1 # ggml_type::GGML_TYPE_F16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"bf16\":\n            proj_type = 30 # ggml_type::GGML_TYPE_BF16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"q8_0\":\n            proj_type = 8 # ggml_type::GGML_TYPE_Q8_0\n            bytes_per_elem = 1.062500\n        elif quant_mode == \"q6_k\":\n            proj_type = 14 # ggml_type::GGML_TYPE_Q6_K\n            bytes_per_elem = 0.820312\n        elif quant_mode == \"q5_k_m\":\n            proj_type = 13 # ggml_type::GGML_TYPE_Q5_K\n            bytes_per_elem = 0.687500\n        elif quant_mode == \"q4_k_m\":\n            proj_type = 12 # ggml_type::GGML_TYPE_Q4_K\n            bytes_per_elem = 0.562500\n        elif quant_mode == \"q3_k_m\":\n            proj_type = 11 # ggml_type::GGML_TYPE_Q3_K\n            bytes_per_elem = 0.429688\n        elif quant_mode == \"q2_k\":\n            proj_type = 10 # ggml_type::GGML_TYPE_Q2_K\n            bytes_per_elem = 0.328125\n        elif quant_mode == \"iq3_xs\":\n            proj_type = 21 # 
ggml_type::GGML_TYPE_IQ3_S\n            bytes_per_elem = 0.429688\n        elif quant_mode == \"iq2_xxs\":\n            proj_type = 16 # ggml_type::GGML_TYPE_IQ2_XXS\n            bytes_per_elem = 0.257812\n        else:\n            assert(False)\n\n        linears = []\n        projs = []\n        for _ in range(layer_num):\n            proj = torch.randn((output_size, input_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            config = cpuinfer_ext.linear.LinearConfig(input_size, output_size, stride, group_max_len, proj.data_ptr(), proj_type, hidden_type)\n            linear = cpuinfer_ext.linear.Linear(config)\n            projs.append(proj)\n            linears.append(linear)\n        input = torch.randn((layer_num, qlen, input_size), dtype=torch.bfloat16, device = \"cuda\").to(\"cpu\").contiguous()\n        output = torch.empty((layer_num, qlen, output_size), dtype=torch.bfloat16, device = \"cuda\").to(\"cpu\").contiguous()\n\n        # warm up\n        for i in range(warm_up_iter):\n            CPUInfer.submit(\n                linears[i % layer_num].forward(\n                    qlen, \n                    input[i % layer_num].data_ptr(), \n                    output[i % layer_num].data_ptr()\n                )\n            )\n            CPUInfer.sync()\n\n        # test\n        start = time.perf_counter()\n        for i in range(test_iter):\n            CPUInfer.submit(\n                linears[i % layer_num].forward(\n                    qlen, \n                    input[i % layer_num].data_ptr(), \n                    output[i % layer_num].data_ptr()\n                )\n            )\n            CPUInfer.sync()\n        end = time.perf_counter()\n        total_time = end - start\n        print('Quant mode: ', quant_mode)\n        print('Time(s): ', total_time)\n        print('Iteration: ', test_iter) \n        print('Time(us) per iteration: ', total_time / test_iter * 1000000)\n        print('Bandwidth: ', input_size * 
output_size * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')\n        print('')\n\nbench_linear(\"fp32\")\nbench_linear(\"fp16\")\nbench_linear(\"bf16\")\nbench_linear(\"q8_0\")\nbench_linear(\"q6_k\")\nbench_linear(\"q5_k_m\")\nbench_linear(\"q4_k_m\")\nbench_linear(\"q3_k_m\")\nbench_linear(\"q2_k\")\n# Not supported on __x86_64__\n# bench_linear(\"iq3_xs\")\n# bench_linear(\"iq2_xxs\")\n"
  },
  {
    "path": "archive/csrc/ktransformers_ext/bench/bench_linear_torch.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : chenht2022\nDate         : 2024-07-25 10:31:59\nVersion      : 1.0.0\nLastEditors  : chenht2022 \nLastEditTime : 2024-07-25 10:32:48\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport os, sys\nimport time\nimport torch\nimport torch.nn.quantized as nnq\n\nscale, zero_point = 0.1, 0  # Adjust scale and zero_point based on your dataset\n\ninput_size = 16384\noutput_size = 5120\nlayer_num = 10\nqlen = 1\nwarm_up_iter = 1000\ntest_iter = 10000\n\ndef bench_linear(quant_mode: str):\n    with torch.inference_mode(mode=True):\n        if quant_mode == \"fp32\":\n            proj_type = torch.float32\n            bytes_per_elem = 4.000000\n        elif quant_mode == \"fp16\":\n            proj_type = torch.float16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"bf16\":\n            proj_type = torch.bfloat16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"qint8\":\n            proj_type = torch.qint8\n            bytes_per_elem = 1.000000\n        else:\n            assert(False)\n\n        projs = []\n        for _ in range(layer_num):\n            proj = torch.randn((output_size, input_size), dtype = torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            if quant_mode == \"qint8\":\n                proj_q = torch.quantize_per_tensor(proj, scale, zero_point, torch.qint8)\n                quantized_layer = nnq.Linear(input_size, output_size)\n                quantized_layer.set_weight_bias(proj_q, None)\n                projs.append(quantized_layer)\n            else:\n                projs.append(proj.to(proj_type))\n        input = torch.randn((layer_num, qlen, input_size), dtype=torch.bfloat16, device = \"cuda\").to(\"cpu\").contiguous()\n\n        # warm up\n        for i in range(warm_up_iter):\n            if isinstance(projs[i % layer_num], nnq.Linear):\n                input_q = 
torch.quantize_per_tensor(input[i % layer_num].to(torch.float32), scale, zero_point, torch.quint8)\n                t_output = projs[i % layer_num](input_q)\n            else:\n                t_output = torch.mm(input[i % layer_num].to(proj_type), projs[i % layer_num].t())\n\n        # test\n        start = time.perf_counter()\n        for i in range(test_iter):\n            if isinstance(projs[i % layer_num], nnq.Linear):\n                input_q = torch.quantize_per_tensor(input[i % layer_num].to(torch.float32), scale, zero_point, torch.quint8)\n                t_output = projs[i % layer_num](input_q)\n            else:\n                t_output = torch.mm(input[i % layer_num].to(proj_type), projs[i % layer_num].t())\n        end = time.perf_counter()\n        total_time = end - start\n        print('Quant mode: ', quant_mode)\n        print('Time(s): ', total_time)\n        print('Iteration: ', test_iter) \n        print('Time(us) per iteration: ', total_time / test_iter * 1000000)\n        print('Bandwidth: ', input_size * output_size * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')\n        print('')\n\nbench_linear(\"fp32\")\nbench_linear(\"fp16\")\nbench_linear(\"bf16\")\nbench_linear(\"qint8\")\n"
  },
  {
    "path": "archive/csrc/ktransformers_ext/bench/bench_mlp.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : chenht2022\nDate         : 2024-07-16 10:43:18\nVersion      : 1.0.0\nLastEditors  : chenht2022 \nLastEditTime : 2024-08-06 10:36:04\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport os, sys\nimport time\nsys.path.append(os.path.dirname(__file__) + '/../build')\nimport cpuinfer_ext\nimport torch\n\nhidden_size = 5120\nintermediate_size = 3072\nstride = 16\ngroup_max_len = 1024\nlayer_num = 10\nqlen = 1\nCPUInfer = cpuinfer_ext.CPUInfer(64)\nwarm_up_iter = 1000\ntest_iter = 10000\n\ndef bench_mlp(quant_mode: str):\n    with torch.inference_mode(mode=True):\n\n        hidden_type = 30 # ggml_type::GGML_TYPE_BF16\n        if quant_mode == \"fp32\":\n            gate_type = 0 # ggml_type::GGML_TYPE_F32\n            up_type = 0 # ggml_type::GGML_TYPE_F32\n            down_type = 0 # ggml_type::GGML_TYPE_F32\n            bytes_per_elem = 4.000000\n        elif quant_mode == \"fp16\":\n            gate_type = 1 # ggml_type::GGML_TYPE_F16\n            up_type = 1 # ggml_type::GGML_TYPE_F16\n            down_type = 1 # ggml_type::GGML_TYPE_F16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"bf16\":\n            gate_type = 30 # ggml_type::GGML_TYPE_BF16\n            up_type = 30 # ggml_type::GGML_TYPE_BF16\n            down_type = 30 # ggml_type::GGML_TYPE_BF16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"q8_0\":\n            gate_type = 8 # ggml_type::GGML_TYPE_Q8_0\n            up_type = 8 # ggml_type::GGML_TYPE_Q8_0\n            down_type = 8 # ggml_type::GGML_TYPE_Q8_0\n            bytes_per_elem = 1.062500\n        elif quant_mode == \"q6_k\":\n            gate_type = 14 # ggml_type::GGML_TYPE_Q6_K\n            up_type = 14 # ggml_type::GGML_TYPE_Q6_K\n            down_type = 14 # ggml_type::GGML_TYPE_Q6_K\n            bytes_per_elem = 0.820312\n        elif quant_mode == \"q5_k_m\":\n            gate_type = 13 # 
ggml_type::GGML_TYPE_Q5_K\n            up_type = 13 # ggml_type::GGML_TYPE_Q5_K\n            down_type = 14 # ggml_type::GGML_TYPE_Q6_K\n            bytes_per_elem = 0.731771\n        elif quant_mode == \"q4_k_m\":\n            gate_type = 12 # ggml_type::GGML_TYPE_Q4_K\n            up_type = 12 # ggml_type::GGML_TYPE_Q4_K\n            down_type = 14 # ggml_type::GGML_TYPE_Q6_K\n            bytes_per_elem = 0.648437\n        elif quant_mode == \"q3_k_m\":\n            gate_type = 11 # ggml_type::GGML_TYPE_Q3_K\n            up_type = 11 # ggml_type::GGML_TYPE_Q3_K\n            down_type = 13 # ggml_type::GGML_TYPE_Q5_K\n            bytes_per_elem = 0.515625\n        elif quant_mode == \"q2_k\":\n            gate_type = 10 # ggml_type::GGML_TYPE_Q2_K\n            up_type = 10 # ggml_type::GGML_TYPE_Q2_K\n            down_type = 11 # ggml_type::GGML_TYPE_Q3_K\n            bytes_per_elem = 0.361979 # (2 * 0.328125 + 0.429688) / 3, matching the mixed gate/up/down types\n        elif quant_mode == \"iq3_xs\":\n            gate_type = 21 # ggml_type::GGML_TYPE_IQ3_S\n            up_type = 21 # ggml_type::GGML_TYPE_IQ3_S\n            down_type = 21 # ggml_type::GGML_TYPE_IQ3_S\n            bytes_per_elem = 0.429688\n        elif quant_mode == \"iq2_xxs\":\n            gate_type = 16 # ggml_type::GGML_TYPE_IQ2_XXS\n            up_type = 16 # ggml_type::GGML_TYPE_IQ2_XXS\n            down_type = 16 # ggml_type::GGML_TYPE_IQ2_XXS\n            bytes_per_elem = 0.257812\n        else:\n            assert(False)\n\n\n        mlps = []\n        gate_projs = []\n        up_projs = []\n        down_projs = []\n        for _ in range(layer_num):\n            gate_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            up_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            down_proj = torch.randn((hidden_size, intermediate_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n        
    config = cpuinfer_ext.mlp.MLPConfig(hidden_size, intermediate_size, stride, group_max_len, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(), gate_type, up_type, down_type, hidden_type)\n            mlp = cpuinfer_ext.mlp.MLP(config)\n            gate_projs.append(gate_proj)\n            up_projs.append(up_proj)\n            down_projs.append(down_proj)\n            mlps.append(mlp)\n        input = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = \"cuda\").to(\"cpu\").contiguous()\n        output = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = \"cuda\").to(\"cpu\").contiguous()\n\n        # warm up\n        for i in range(warm_up_iter):\n            CPUInfer.submit(\n                mlps[i % layer_num].forward( \n                    qlen, \n                    input[i % layer_num].data_ptr(), \n                    output[i % layer_num].data_ptr()\n                )\n            )\n            CPUInfer.sync()\n\n        # test\n        start = time.perf_counter()\n        for i in range(test_iter):\n            CPUInfer.submit(\n                mlps[i % layer_num].forward( \n                    qlen, \n                    input[i % layer_num].data_ptr(), \n                    output[i % layer_num].data_ptr()\n                )\n            )\n            CPUInfer.sync()\n        end = time.perf_counter()\n        total_time = end - start\n        print('Quant mode: ', quant_mode)\n        print('Time(s): ', total_time)\n        print('Iteration: ', test_iter) \n        print('Time(us) per iteration: ', total_time / test_iter * 1000000)\n        print('Bandwidth: ', hidden_size * intermediate_size * 3 * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')\n        print('')\n\nbench_mlp(\"fp32\")\nbench_mlp(\"fp16\")\nbench_mlp(\"bf16\")\nbench_mlp(\"q8_0\")\nbench_mlp(\"q6_k\")\nbench_mlp(\"q5_k_m\")\nbench_mlp(\"q4_k_m\")\nbench_mlp(\"q3_k_m\")\nbench_mlp(\"q2_k\")\n# Not 
supported on __x86_64__\n# bench_mlp(\"iq3_xs\")\n# bench_mlp(\"iq2_xxs\")\n"
  },
  {
    "path": "archive/csrc/ktransformers_ext/bench/bench_mlp_torch.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : chenht2022\nDate         : 2024-07-16 10:43:18\nVersion      : 1.0.0\nLastEditors  : chenht2022 \nLastEditTime : 2024-07-25 10:32:53\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport os, sys\nimport time\nimport torch\nimport torch.nn.quantized as nnq\n\nscale, zero_point = 0.1, 0  # Adjust scale and zero_point based on your dataset\n\nhidden_size = 5120\nintermediate_size = 3072\nlayer_num = 10\nqlen = 1\nwarm_up_iter = 1000\ntest_iter = 10000\n\ndef act_fn(x):\n    return x / (1.0 + torch.exp(-x))\n\ndef mlp_torch(input, gate_proj, up_proj, down_proj):\n    if isinstance(gate_proj, nnq.Linear):\n        input_q = torch.quantize_per_tensor(input.to(torch.float32), scale, zero_point, torch.quint8)\n        gate_buf = gate_proj(input_q)\n        up_buf = up_proj(input_q)\n        gate_buf = gate_buf.dequantize()\n        up_buf = up_buf.dequantize()\n        intermediate = act_fn(gate_buf) * up_buf\n        intermediate_q = torch.quantize_per_tensor(intermediate, scale, zero_point, torch.quint8)\n        expert_output = down_proj(intermediate_q)\n        ret = expert_output.dequantize()\n    else:\n        gate_buf = torch.mm(input.to(gate_proj.dtype), gate_proj.t())\n        up_buf = torch.mm(input.to(up_proj.dtype), up_proj.t())\n        intermediate = act_fn(gate_buf) * up_buf\n        ret = torch.mm(intermediate.to(down_proj.dtype), down_proj.t())\n    return ret\n\ndef bench_mlp(quant_mode: str):\n    with torch.inference_mode(mode=True):\n        if quant_mode == \"fp32\":\n            proj_type = torch.float32\n            bytes_per_elem = 4.000000\n        elif quant_mode == \"fp16\":\n            proj_type = torch.float16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"bf16\":\n            proj_type = torch.bfloat16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"qint8\":\n            proj_type = 
torch.qint8\n            bytes_per_elem = 1.000000\n        else:\n            assert(False)\n\n        gate_projs = []\n        up_projs = []\n        down_projs = []\n        for _ in range(layer_num):\n            gate_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            up_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            down_proj = torch.randn((hidden_size, intermediate_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            if quant_mode == \"qint8\":\n                gate_proj_q = torch.quantize_per_tensor(gate_proj, scale, zero_point, torch.qint8)\n                quantized_gate = nnq.Linear(hidden_size, intermediate_size)\n                quantized_gate.set_weight_bias(gate_proj_q, None)\n                up_proj_q = torch.quantize_per_tensor(up_proj, scale, zero_point, torch.qint8)\n                quantized_up = nnq.Linear(hidden_size, intermediate_size)\n                quantized_up.set_weight_bias(up_proj_q, None)\n                down_proj_q = torch.quantize_per_tensor(down_proj, scale, zero_point, torch.qint8)\n                quantized_down = nnq.Linear(intermediate_size, hidden_size)\n                quantized_down.set_weight_bias(down_proj_q, None)\n                gate_projs.append(quantized_gate)\n                up_projs.append(quantized_up)\n                down_projs.append(quantized_down)\n            else:\n                gate_projs.append(gate_proj.to(proj_type))\n                up_projs.append(up_proj.to(proj_type))\n                down_projs.append(down_proj.to(proj_type))\n        input = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = \"cuda\").to(\"cpu\").contiguous()\n\n        # warm up\n        for i in range(warm_up_iter):\n            mlp_torch(input[i % layer_num], gate_projs[i % layer_num], up_projs[i % 
layer_num], down_projs[i % layer_num])\n\n        # test\n        start = time.perf_counter()\n        for i in range(test_iter):\n            mlp_torch(input[i % layer_num], gate_projs[i % layer_num], up_projs[i % layer_num], down_projs[i % layer_num])\n        end = time.perf_counter()\n        total_time = end - start\n        print('Quant mode: ', quant_mode)\n        print('Time(s): ', total_time)\n        print('Iteration: ', test_iter) \n        print('Time(us) per iteration: ', total_time / test_iter * 1000000)\n        print('Bandwidth: ', hidden_size * intermediate_size * 3 * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')\n        print('')\n\nbench_mlp(\"fp32\")\nbench_mlp(\"fp16\")\nbench_mlp(\"bf16\")\nbench_mlp(\"qint8\")\n"
  },
  {
    "path": "archive/csrc/ktransformers_ext/bench/bench_moe.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : chenht2022\nDate         : 2024-07-25 10:32:05\nVersion      : 1.0.0\nLastEditors  : chenht2022 \nLastEditTime : 2024-08-06 10:41:28\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport os, sys\nimport time\nsys.path.append(os.path.dirname(__file__) + '/../build')\nimport cpuinfer_ext\nimport torch\n\nexpert_num = 160\nhidden_size = 5120\nintermediate_size = 1536\nstride = 16\ngroup_min_len = 10\ngroup_max_len = 1024\nn_routed_experts = 6\nlayer_num = 10\nqlen = 1\nCPUInfer = cpuinfer_ext.CPUInfer(64)\nwarm_up_iter = 1000\ntest_iter = 10000\n\ndef bench_moe(quant_mode: str):\n    with torch.inference_mode(mode=True):\n        hidden_type = 30 # ggml_type::GGML_TYPE_BF16\n        if quant_mode == \"fp32\":\n            gate_type = 0 # ggml_type::GGML_TYPE_F32\n            up_type = 0 # ggml_type::GGML_TYPE_F32\n            down_type = 0 # ggml_type::GGML_TYPE_F32\n            bytes_per_elem = 4.000000\n        elif quant_mode == \"fp16\":\n            gate_type = 1 # ggml_type::GGML_TYPE_F16\n            up_type = 1 # ggml_type::GGML_TYPE_F16\n            down_type = 1 # ggml_type::GGML_TYPE_F16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"bf16\":\n            gate_type = 30 # ggml_type::GGML_TYPE_BF16\n            up_type = 30 # ggml_type::GGML_TYPE_BF16\n            down_type = 30 # ggml_type::GGML_TYPE_BF16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"q8_0\":\n            gate_type = 8 # ggml_type::GGML_TYPE_Q8_0\n            up_type = 8 # ggml_type::GGML_TYPE_Q8_0\n            down_type = 8 # ggml_type::GGML_TYPE_Q8_0\n            bytes_per_elem = 1.062500\n        elif quant_mode == \"q6_k\":\n            gate_type = 14 # ggml_type::GGML_TYPE_Q6_K\n            up_type = 14 # ggml_type::GGML_TYPE_Q6_K\n            down_type = 14 # ggml_type::GGML_TYPE_Q6_K\n            bytes_per_elem = 0.820312\n        
elif quant_mode == \"q5_k_m\":\n            gate_type = 13 # ggml_type::GGML_TYPE_Q5_K\n            up_type = 13 # ggml_type::GGML_TYPE_Q5_K\n            down_type = 14 # ggml_type::GGML_TYPE_Q6_K\n            bytes_per_elem = 0.731771\n        elif quant_mode == \"q4_k_m\":\n            gate_type = 12 # ggml_type::GGML_TYPE_Q4_K\n            up_type = 12 # ggml_type::GGML_TYPE_Q4_K\n            down_type = 14 # ggml_type::GGML_TYPE_Q6_K\n            bytes_per_elem = 0.648437\n        elif quant_mode == \"q3_k_m\":\n            gate_type = 11 # ggml_type::GGML_TYPE_Q3_K\n            up_type = 11 # ggml_type::GGML_TYPE_Q3_K\n            down_type = 13 # ggml_type::GGML_TYPE_Q5_K\n            bytes_per_elem = 0.515625\n        elif quant_mode == \"q2_k\":\n            gate_type = 10 # ggml_type::GGML_TYPE_Q2_K\n            up_type = 10 # ggml_type::GGML_TYPE_Q2_K\n            down_type = 11 # ggml_type::GGML_TYPE_Q3_K\n            bytes_per_elem = 0.328125\n        elif quant_mode == \"iq3_xs\":\n            gate_type = 21 # ggml_type::GGML_TYPE_IQ3_S\n            up_type = 21 # ggml_type::GGML_TYPE_IQ3_S\n            down_type = 21 # ggml_type::GGML_TYPE_IQ3_S\n            bytes_per_elem = 0.429688\n        elif quant_mode == \"iq2_xxs\":\n            gate_type = 16 # ggml_type::GGML_TYPE_IQ2_XXS\n            up_type = 16 # ggml_type::GGML_TYPE_IQ2_XXS\n            down_type = 16 # ggml_type::GGML_TYPE_IQ2_XXS\n            bytes_per_elem = 0.257812\n        else:\n            assert(False)\n\n\n        moes = []\n        gate_projs = []\n        up_projs = []\n        down_projs = []\n        for _ in range(layer_num):\n            gate_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            up_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            down_proj = torch.randn((expert_num, 
hidden_size, intermediate_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            config = cpuinfer_ext.moe.MOEConfig(expert_num, n_routed_experts, hidden_size, intermediate_size, stride, group_min_len, group_max_len, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(), gate_type, up_type, down_type, hidden_type)\n            moe = cpuinfer_ext.moe.MOE(config)\n            gate_projs.append(gate_proj)\n            up_projs.append(up_proj)\n            down_projs.append(down_proj)\n            moes.append(moe)\n        expert_ids = torch.stack([torch.stack([torch.randperm(expert_num, dtype=torch.int64, device = \"cuda\")[:n_routed_experts] for _ in range(qlen)]) for _ in range(layer_num)]).to(\"cpu\").contiguous()\n        weights = torch.rand((layer_num, qlen, n_routed_experts), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n        input = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = \"cuda\").to(\"cpu\").contiguous()\n        output = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = \"cuda\").to(\"cpu\").contiguous()\n\n        # warm up\n        for i in range(warm_up_iter):\n            CPUInfer.submit(\n                moes[i % layer_num].forward( \n                    qlen, \n                    n_routed_experts, \n                    expert_ids[i % layer_num].data_ptr(), \n                    weights[i % layer_num].data_ptr(),\n                    input[i % layer_num].data_ptr(), \n                    output[i % layer_num].data_ptr()\n                )\n            )\n            CPUInfer.sync()\n\n        # test\n        start = time.perf_counter()\n        for i in range(test_iter):\n            CPUInfer.submit(\n                moes[i % layer_num].forward( \n                    qlen, \n                    n_routed_experts, \n                    expert_ids[i % layer_num].data_ptr(), \n                    weights[i % layer_num].data_ptr(),\n    
                input[i % layer_num].data_ptr(), \n                    output[i % layer_num].data_ptr()\n                )\n            )\n            CPUInfer.sync()\n        end = time.perf_counter()\n        total_time = end - start\n        print('Quant mode: ', quant_mode)\n        print('Time(s): ', total_time)\n        print('Iteration: ', test_iter) \n        print('Time(us) per iteration: ', total_time / test_iter * 1000000)\n        print('Bandwidth: ', hidden_size * intermediate_size * 3 * n_routed_experts * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')\n        print('')\n\nbench_moe(\"fp32\")\nbench_moe(\"fp16\")\nbench_moe(\"bf16\")\nbench_moe(\"q8_0\")\nbench_moe(\"q6_k\")\nbench_moe(\"q5_k_m\")\nbench_moe(\"q4_k_m\")\nbench_moe(\"q3_k_m\")\nbench_moe(\"q2_k\")\n# Not supported on __x86_64__\n# bench_moe(\"iq3_xs\")\n# bench_moe(\"iq2_xxs\")\n"
  },
  {
    "path": "archive/csrc/ktransformers_ext/bench/bench_moe_amx.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : chenht2022\nDate         : 2025-04-25 18:28:12\nVersion      : 1.0.0\nLastEditors  : chenht2022 \nLastEditTime : 2025-04-25 18:28:12\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport os, sys\nimport time\nsys.path.append(os.path.dirname(__file__) + '/../build')\nimport cpuinfer_ext\nimport torch\n\nexpert_num = 8\nhidden_size = 7168\nintermediate_size = 2048\nmax_len = 25600\nn_routed_experts = 8\nlayer_num = 10\nqlen = 1024\nCPUInfer = cpuinfer_ext.CPUInfer(65)\nwarm_up_iter = 100\ntest_iter = 100\n\ndef bench_moe(quant_mode: str):\n    with torch.inference_mode(mode=True):\n        if quant_mode == \"bf16\":\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"int8\":\n            bytes_per_elem = 1.000000\n        else:\n            assert(False)\n\n\n        moes = []\n        gate_projs = []\n        up_projs = []\n        down_projs = []\n        for _ in range(layer_num):\n            gate_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            up_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            down_proj = torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            config = cpuinfer_ext.moe.AMX_MOEConfig(expert_num, n_routed_experts, hidden_size, intermediate_size, max_len, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr())\n            if quant_mode == \"bf16\":\n                moe = cpuinfer_ext.moe.AMXBF16_MOE(config)\n                CPUInfer.submit(moe.load_weights())\n                CPUInfer.sync()\n            elif quant_mode == \"int8\":\n                moe = cpuinfer_ext.moe.AMXInt8_MOE(config)\n                CPUInfer.submit(moe.load_weights())\n       
         CPUInfer.sync()\n            gate_projs.append(gate_proj)\n            up_projs.append(up_proj)\n            down_projs.append(down_proj)\n            moes.append(moe)\n        expert_ids = torch.stack([torch.stack([torch.randperm(expert_num, dtype=torch.int64, device = \"cuda\")[:n_routed_experts] for _ in range(qlen)]) for _ in range(layer_num)]).to(\"cpu\").contiguous()\n        weights = torch.rand((layer_num, qlen, n_routed_experts), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n        input = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = \"cuda\").to(\"cpu\").contiguous()\n        output = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = \"cuda\").to(\"cpu\").contiguous()\n        qlen_tensor = torch.tensor([qlen], dtype=torch.int32)\n\n        # warm up\n        for i in range(warm_up_iter):\n            CPUInfer.submit(\n                moes[i % layer_num].forward( \n                    qlen, \n                    n_routed_experts, \n                    expert_ids[i % layer_num].data_ptr(), \n                    weights[i % layer_num].data_ptr(),\n                    input[i % layer_num].data_ptr(), \n                    output[i % layer_num].data_ptr(),\n                    qlen_tensor.data_ptr()\n                )\n            )\n            CPUInfer.sync()\n\n        # test\n        start = time.perf_counter()\n        for i in range(test_iter):\n            CPUInfer.submit(\n                moes[i % layer_num].forward( \n                    qlen, \n                    n_routed_experts, \n                    expert_ids[i % layer_num].data_ptr(), \n                    weights[i % layer_num].data_ptr(),\n                    input[i % layer_num].data_ptr(), \n                    output[i % layer_num].data_ptr(),\n                    qlen_tensor.data_ptr()\n                )\n            )\n            CPUInfer.sync()\n        end = time.perf_counter()\n        total_time = 
end - start\n        print('Quant mode: ', quant_mode)\n        print('Time(s): ', total_time)\n        print('Iteration: ', test_iter) \n        print('Time(us) per iteration: ', total_time / test_iter * 1000000)\n        print('Bandwidth: ', hidden_size * intermediate_size * 3 * n_routed_experts * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')\n        print('Flops: ', hidden_size * intermediate_size * qlen * 3 * n_routed_experts * 2 * test_iter / total_time / 1000 / 1000 / 1000, 'GFLOPS')\n        print('')\n\nbench_moe(\"bf16\")\nbench_moe(\"int8\")\n"
  },
  {
    "path": "archive/csrc/ktransformers_ext/bench/bench_moe_torch.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : chenht2022\nDate         : 2024-07-25 10:32:05\nVersion      : 1.0.0\nLastEditors  : chenht2022 \nLastEditTime : 2024-07-25 10:32:57\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport os, sys\nimport time\nimport torch\nimport torch.nn.quantized as nnq\n\nscale, zero_point = 0.1, 0  # Adjust scale and zero_point based on your dataset\n\nexpert_num = 160\nhidden_size = 5120\nintermediate_size = 1536\nn_routed_experts = 6\nlayer_num = 10\nqlen = 1\nwarm_up_iter = 1000\ntest_iter = 10000\n\ndef act_fn(x):\n    return x / (1.0 + torch.exp(-x))\n\ndef mlp_torch(input, gate_proj, up_proj, down_proj):\n    if isinstance(gate_proj, nnq.Linear):\n        input_q = torch.quantize_per_tensor(input.to(torch.float32), scale, zero_point, torch.quint8)\n        gate_buf = gate_proj(input_q)\n        up_buf = up_proj(input_q)\n        gate_buf = gate_buf.dequantize()\n        up_buf = up_buf.dequantize()\n        intermediate = act_fn(gate_buf) * up_buf\n        intermediate_q = torch.quantize_per_tensor(intermediate, scale, zero_point, torch.quint8)\n        expert_output = down_proj(intermediate_q)\n        ret = expert_output.dequantize()\n    else:\n        gate_buf = torch.mm(input.to(gate_proj.dtype), gate_proj.t())\n        up_buf = torch.mm(input.to(up_proj.dtype), up_proj.t())\n        intermediate = act_fn(gate_buf) * up_buf\n        ret = torch.mm(intermediate.to(down_proj.dtype), down_proj.t())\n    return ret\n\ndef moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):\n    cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))\n    cnts.scatter_(1, expert_ids, 1)\n    tokens_per_expert = cnts.sum(dim=0)\n    idxs = expert_ids.view(-1).argsort()\n    sorted_tokens = input[idxs // expert_ids.shape[1]]\n\n    outputs = []\n    start_idx = 0\n    for i, num_tokens in enumerate(tokens_per_expert):\n        end_idx = start_idx + 
num_tokens\n        if num_tokens == 0:\n            continue\n        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n        expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])\n        outputs.append(expert_out)\n        start_idx = end_idx\n\n    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n    new_x = torch.empty_like(outs)\n    new_x[idxs] = outs\n    t_output = (\n        new_x.view(*expert_ids.shape, -1)\n        .type(weights.dtype)\n        .mul_(weights.unsqueeze(dim=-1))\n        .sum(dim=1)\n        .type(new_x.dtype)\n    )\n    return t_output\n\ndef bench_moe(quant_mode: str):\n    with torch.inference_mode(mode=True):\n        if quant_mode == \"fp32\":\n            proj_type = torch.float32\n            bytes_per_elem = 4.000000\n        elif quant_mode == \"fp16\":\n            proj_type = torch.float16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"bf16\":\n            proj_type = torch.bfloat16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"qint8\":\n            proj_type = torch.qint8\n            bytes_per_elem = 1.000000\n        else:\n            assert(False)\n\n        gate_projs = []\n        up_projs = []\n        down_projs = []\n        for _ in range(layer_num):\n            gate_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            up_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            down_proj = torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            if quant_mode == \"qint8\":\n                quantized_gate_proj = []\n                quantized_up_proj = []\n                quantized_down_proj = []\n                for i in range(expert_num):\n    
                gate_proj_q = torch.quantize_per_tensor(gate_proj[i], scale, zero_point, torch.qint8)\n                    quantized_gate = nnq.Linear(hidden_size, intermediate_size)\n                    quantized_gate.set_weight_bias(gate_proj_q, None)\n                    quantized_gate_proj.append(quantized_gate)\n                    up_proj_q = torch.quantize_per_tensor(up_proj[i], scale, zero_point, torch.qint8)\n                    quantized_up = nnq.Linear(hidden_size, intermediate_size)\n                    quantized_up.set_weight_bias(up_proj_q, None)\n                    quantized_up_proj.append(quantized_up)\n                    down_proj_q = torch.quantize_per_tensor(down_proj[i], scale, zero_point, torch.qint8)\n                    quantized_down = nnq.Linear(intermediate_size, hidden_size)\n                    quantized_down.set_weight_bias(down_proj_q, None)\n                    quantized_down_proj.append(quantized_down)\n                gate_projs.append(quantized_gate_proj)\n                up_projs.append(quantized_up_proj)\n                down_projs.append(quantized_down_proj)\n            else:\n                gate_projs.append(gate_proj.to(proj_type))\n                up_projs.append(up_proj.to(proj_type))\n                down_projs.append(down_proj.to(proj_type))\n        expert_ids = torch.stack([torch.stack([torch.randperm(expert_num, dtype=torch.int64, device = \"cuda\")[:n_routed_experts] for _ in range(qlen)]) for _ in range(layer_num)]).to(\"cpu\").contiguous()\n        weights = torch.rand((layer_num, qlen, n_routed_experts), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n        input = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = \"cuda\").to(\"cpu\").contiguous()\n\n        # warm up\n        for i in range(warm_up_iter):\n            moe_torch(input[i % layer_num], expert_ids[i % layer_num], weights[i % layer_num], gate_projs[i % layer_num], up_projs[i % layer_num], down_projs[i % 
layer_num])\n\n        # test\n        start = time.perf_counter()\n        for i in range(test_iter):\n            moe_torch(input[i % layer_num], expert_ids[i % layer_num], weights[i % layer_num], gate_projs[i % layer_num], up_projs[i % layer_num], down_projs[i % layer_num])\n        end = time.perf_counter()\n        total_time = end - start\n        print('Quant mode: ', quant_mode)\n        print('Time(s): ', total_time)\n        print('Iteration: ', test_iter) \n        print('Time(us) per iteration: ', total_time / test_iter * 1000000)\n        print('Bandwidth: ', hidden_size * intermediate_size * 3 * n_routed_experts * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')\n        print('')\n\nbench_moe(\"fp32\")\nbench_moe(\"fp16\")\nbench_moe(\"bf16\")\nbench_moe(\"qint8\")\n"
  },
  {
    "path": "archive/csrc/ktransformers_ext/cmake/FindSIMD.cmake",
    "content": "include(CheckCSourceRuns)\n\nset(AVX_CODE \"\n    #include <immintrin.h>\n    int main()\n    {\n        __m256 a;\n        a = _mm256_set1_ps(0);\n        return 0;\n    }\n\")\n\nset(AVX512_CODE \"\n    #include <immintrin.h>\n    int main()\n    {\n        __m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0,\n                                    0, 0, 0, 0, 0, 0, 0, 0,\n                                    0, 0, 0, 0, 0, 0, 0, 0,\n                                    0, 0, 0, 0, 0, 0, 0, 0,\n                                    0, 0, 0, 0, 0, 0, 0, 0,\n                                    0, 0, 0, 0, 0, 0, 0, 0,\n                                    0, 0, 0, 0, 0, 0, 0, 0,\n                                    0, 0, 0, 0, 0, 0, 0, 0);\n        __m512i b = a;\n        __mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ);\n        return 0;\n    }\n\")\n\nset(AVX2_CODE \"\n    #include <immintrin.h>\n    int main()\n    {\n        __m256i a = {0};\n        a = _mm256_abs_epi16(a);\n        __m256i x;\n        _mm256_extract_epi64(x, 0); // we rely on this in our AVX2 code\n        return 0;\n    }\n\")\n\nset(FMA_CODE \"\n    #include <immintrin.h>\n    int main()\n    {\n        __m256 acc = _mm256_setzero_ps();\n        const __m256 d = _mm256_setzero_ps();\n        const __m256 p = _mm256_setzero_ps();\n        acc = _mm256_fmadd_ps( d, p, acc );\n        return 0;\n    }\n\")\n\nmacro(check_sse type flags)\n    set(__FLAG_I 1)\n    set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})\n    foreach (__FLAG ${flags})\n        if (NOT ${type}_FOUND)\n            set(CMAKE_REQUIRED_FLAGS ${__FLAG})\n            check_c_source_runs(\"${${type}_CODE}\" HAS_${type}_${__FLAG_I})\n            if (HAS_${type}_${__FLAG_I})\n                set(${type}_FOUND TRUE CACHE BOOL \"${type} support\")\n                set(${type}_FLAGS \"${__FLAG}\" CACHE STRING \"${type} flags\")\n            endif()\n            math(EXPR __FLAG_I \"${__FLAG_I}+1\")\n   
     endif()\n    endforeach()\n    set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})\n\n    if (NOT ${type}_FOUND)\n        set(${type}_FOUND FALSE CACHE BOOL \"${type} support\")\n        set(${type}_FLAGS \"\" CACHE STRING \"${type} flags\")\n    endif()\n\n    mark_as_advanced(${type}_FOUND ${type}_FLAGS)\nendmacro()\n\n# flags are for MSVC only!\ncheck_sse(\"AVX\" \" ;/arch:AVX\")\nif (NOT ${AVX_FOUND})\n    set(LLAMA_AVX OFF)\nelse()\n    set(LLAMA_AVX ON)\nendif()\n\ncheck_sse(\"AVX2\" \" ;/arch:AVX2\")\ncheck_sse(\"FMA\" \" ;/arch:AVX2\")\nif ((NOT ${AVX2_FOUND}) OR (NOT ${FMA_FOUND}))\n    set(LLAMA_AVX2 OFF)\nelse()\n    set(LLAMA_AVX2 ON)\nendif()\n\ncheck_sse(\"AVX512\" \" ;/arch:AVX512\")\nif (NOT ${AVX512_FOUND})\n    set(LLAMA_AVX512 OFF)\nelse()\n    set(LLAMA_AVX512 ON)\nendif()\n"
  },
  {
    "path": "archive/csrc/ktransformers_ext/cpu_backend/backend.cpp",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-22 02:03:05\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2024-07-25 10:33:34\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n\n#include \"backend.h\"\n\n#ifdef USE_NUMA\n#include <numa.h>\n#include <numaif.h>\n\nthread_local int Backend::numa_node = -1;\n#endif\n\nthread_local int Backend::thread_local_id = -1;\n\nBackend::Backend(int max_thread_num) {\n    max_thread_num_ = max_thread_num;\n    thread_state_.resize(max_thread_num_);\n    for (int i = 0; i < max_thread_num_; i++) {\n        thread_state_[i].curr = std::make_unique<std::atomic<int>>();\n        thread_state_[i].status =\n            std::make_unique<std::atomic<ThreadStatus>>(ThreadStatus::WAITING);\n    }\n    workers_.resize(max_thread_num_);\n    for (int i = 1; i < max_thread_num_; i++) {\n        workers_[i] = std::thread(&Backend::worker_thread, this, i);\n    }\n}\n\nBackend::~Backend() {\n    for (int i = 0; i < max_thread_num_; i++) {\n        thread_state_[i].status->store(ThreadStatus::EXIT,\n                                       std::memory_order_release);\n    }\n    for (int i = 1; i < max_thread_num_; i++) {\n        if (workers_[i].joinable()) {\n            workers_[i].join();\n        }\n    }\n}\n\nint Backend::get_thread_num() { return max_thread_num_; }\n\nvoid Backend::do_work_stealing_job(int task_num,\n                                   std::function<void(int)> init_func,\n                                   std::function<void(int)> compute_func,\n                                   std::function<void(int)> finalize_func) {\n    init_func_ = init_func;\n    compute_func_ = compute_func;\n    finalize_func_ = finalize_func;\n#ifdef USE_NUMA\n    // numa node location will be calculated based on the number of threads\n    thread_num_ = max_thread_num_;\n#else\n    thread_num_ = std::min(max_thread_num_, task_num);\n#endif\n    int 
base = task_num / thread_num_;\n    int remain = task_num % thread_num_;\n    thread_state_[0].end = base + (0 < remain);\n\n    // Set thread_local_id for the main thread\n    thread_local_id = 0;\n\n    for (int i = 1; i < thread_num_; i++) {\n        thread_state_[i].curr->store(thread_state_[i - 1].end,\n                                     std::memory_order_relaxed);\n        thread_state_[i].end = thread_state_[i - 1].end + base + (i < remain);\n        thread_state_[i].status->store(ThreadStatus::WORKING,\n                                       std::memory_order_release);\n    }\n    thread_state_[0].curr->store(0, std::memory_order_relaxed);\n    thread_state_[0].status->store(ThreadStatus::WORKING,\n                                   std::memory_order_release);\n    process_tasks(0);\n    for (int i = 1; i < thread_num_; i++) {\n        while (thread_state_[i].status->load(std::memory_order_acquire) ==\n               ThreadStatus::WORKING) {\n        }\n    }\n}\n\nvoid Backend::process_tasks(int thread_id) {\n    \n    #ifdef USE_NUMA\n    if(numa_node == -1){\n        numa_node = thread_id * numa_num_configured_nodes() / thread_num_;\n        struct bitmask* mask = numa_bitmask_alloc(numa_num_configured_nodes());\n        numa_bitmask_setbit(mask, numa_node);\n        numa_bind(mask);\n    }\n    #endif\n\n    if (init_func_ != nullptr) {\n        init_func_(thread_id);\n    }\n    while (true) {\n        int task_id = thread_state_[thread_id].curr->fetch_add(\n            1, std::memory_order_acq_rel);\n        if (task_id >= thread_state_[thread_id].end) {\n            break;\n        }\n        compute_func_(task_id);\n    }\n    for (int t_offset = 1; t_offset < thread_num_; t_offset++) {\n        int t_i = (thread_id + t_offset) % thread_num_;\n        if (thread_state_[t_i].status->load(std::memory_order_acquire) !=\n            ThreadStatus::WORKING) {\n            continue;\n        }\n        while (true) {\n            int task_id = 
thread_state_[t_i].curr->fetch_add(\n                1, std::memory_order_acq_rel);\n            if (task_id >= thread_state_[t_i].end) {\n                break;\n            }\n            compute_func_(task_id);\n        }\n    }\n    if (finalize_func_ != nullptr) {\n        finalize_func_(thread_id);\n    }\n    thread_state_[thread_id].status->store(ThreadStatus::WAITING,\n                                           std::memory_order_release);\n}\n\nvoid Backend::worker_thread(int thread_id) {\n    auto start = std::chrono::steady_clock::now();\n    thread_local_id = thread_id; // set the thread-local variable\n    while (true) {\n        ThreadStatus status =\n            thread_state_[thread_id].status->load(std::memory_order_acquire);\n        if (status == ThreadStatus::WORKING) {\n            process_tasks(thread_id);\n            start = std::chrono::steady_clock::now();\n        } else if (status == ThreadStatus::WAITING) {\n            auto now = std::chrono::steady_clock::now();\n            auto duration =\n                std::chrono::duration_cast<std::chrono::milliseconds>(now -\n                                                                      start)\n                    .count();\n            if (duration > 50) {\n                std::this_thread::sleep_for(std::chrono::milliseconds(1));\n            }\n        } else if (status == ThreadStatus::EXIT) {\n            return;\n        }\n    }\n}"
  },
  {
    "path": "archive/csrc/ktransformers_ext/cpu_backend/backend.h",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-22 02:03:05\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2024-07-25 10:33:38\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#ifndef CPUINFER_BACKEND_H\n#define CPUINFER_BACKEND_H\n\n#include <atomic>\n#include <condition_variable>\n#include <cstdio>\n#include <functional>\n#include <mutex>\n#include <thread>\n#include <vector>\n\nenum ThreadStatus {\n    WORKING,\n    WAITING,\n    EXIT,\n};\n\nstruct ThreadState {\n    std::unique_ptr<std::atomic<ThreadStatus>> status;\n    std::unique_ptr<std::atomic<int>> curr;\n    int end;\n};\n\nclass Backend {\n  public:\n    Backend(int);\n    ~Backend();\n    int get_thread_num();\n    void do_work_stealing_job(int, std::function<void(int)>,\n                              std::function<void(int)>,\n                              std::function<void(int)>);\n    #ifdef USE_NUMA\n    static thread_local int numa_node;\n    #endif\n    static thread_local int thread_local_id;\n\n  private:\n    int thread_num_;\n    int max_thread_num_;\n    std::vector<ThreadState> thread_state_; // [thread_num]\n    std::function<void(int)> init_func_;\n    std::function<void(int)> compute_func_;\n    std::function<void(int)> finalize_func_;\n    std::vector<std::thread> workers_;\n\n    void process_tasks(int);\n    void worker_thread(int);\n};\n#endif"
  },
  {
    "path": "archive/csrc/ktransformers_ext/cpu_backend/cpuinfer.h",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-16 10:43:18\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2024-08-07 09:47:43\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n #ifndef CPUINFER_CPUINFER_H\n #define CPUINFER_CPUINFER_H\n \n #include <atomic>\n #include <condition_variable>\n #include <functional>\n #include <mutex>\n #include <queue>\n #include <thread>\n #include <vector>\n #include <stdexcept>\n #ifdef KTRANSFORMERS_USE_CUDA\n #include \"vendors/cuda.h\"\n #elif KTRANSFORMERS_USE_MUSA\n #include \"vendors/musa.h\"\n #elif KTRANSFORMERS_USE_ROCM\n #define __HIP_PLATFORM_AMD__\n #include \"vendors/hip.h\"\n #endif\n \n #include \"backend.h\"\n #include \"task_queue.h\"\n #include \"./vendors/vendor.h\"\n \n #include \"llama.cpp/ggml-impl.h\"\n \n class CPUInfer {\n    public:\n     CPUInfer(int thread_num) {\n         backend_ = new Backend(thread_num - 1);\n         task_queue_ = new TaskQueue();\n         for (int i = 0; i < (1 << 16); ++i) {\n             ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(i);\n         }\n     }\n \n     ~CPUInfer() {\n         delete backend_;\n         delete task_queue_;\n     }\n \n     template <typename Func, typename Obj, typename... Args>\n     void enqueue(Func f, Obj* obj, Args... 
args) {\n         task_queue_->enqueue([=]() {\n             std::invoke(f, *obj, args..., backend_);\n         });\n     }\n \n     void submit(std::pair<intptr_t, intptr_t> params) {\n         void (*func)(void*) = (void (*)(void*))params.first;\n         void* args = (void*)params.second;\n         *((CPUInfer**)args) = this;\n         func(args);\n     }\n \n     void sync() {\n         task_queue_->sync();\n     }\n \n     void submit_with_cuda_stream(intptr_t user_cuda_stream, std::pair<intptr_t, intptr_t> params) {\n        #if defined(KTRANSFORMERS_USE_CUDA) || defined(KTRANSFORMERS_USE_MUSA) || defined(KTRANSFORMERS_USE_ROCM)\n         void (*func)(void*) = (void (*)(void*))params.first;\n         void* args = (void*)params.second;\n         *((CPUInfer**)args) = this;\n         cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args);\n        #else\n         throw std::runtime_error(\"submit_with_cuda_stream is not supported on this platform\");\n        #endif\n     }\n \n     static void sync_(void* cpu_infer_ptr) {\n         CPUInfer* cpuinfer = (CPUInfer*)cpu_infer_ptr;\n         cpuinfer->sync();\n     }\n \n     void sync_with_cuda_stream(intptr_t user_cuda_stream) {\n        #if defined(KTRANSFORMERS_USE_CUDA) || defined(KTRANSFORMERS_USE_MUSA) || defined(KTRANSFORMERS_USE_ROCM)\n         cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void*)this);\n        #else\n         throw std::runtime_error(\"sync_with_cuda_stream is not supported on this platform\");\n        #endif\n     }\n \n    public:\n     Backend* backend_;\n     TaskQueue* task_queue_;\n };\n \n #endif"
  },
  {
    "path": "archive/csrc/ktransformers_ext/cpu_backend/shared_mem_buffer.cpp",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-08-05 04:49:08\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022 \n * @LastEditTime : 2024-08-05 09:21:29\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#include \"shared_mem_buffer.h\"\n#include <cstdio>\n\nSharedMemBuffer::SharedMemBuffer() {\n    buffer_ = nullptr;\n    size_ = 0;\n}\n\nSharedMemBuffer::~SharedMemBuffer() {\n    if (buffer_) {\n        free(buffer_);\n    }\n}\n\nvoid SharedMemBuffer::alloc(void* object, std::vector<std::pair<void**, uint64_t>> requests) {\n    uint64_t size = 0;\n    for (auto& request : requests) {\n        size += request.second;\n    }\n    if (size > size_) {\n        if (buffer_) {\n            free(buffer_);\n        }\n        buffer_ = std::aligned_alloc(64, size);\n\n        size_ = size;\n        for (auto& obj_requests : hist_requests_) {\n            for (auto& requests : obj_requests.second) {\n                arrange(requests);\n            }\n        }\n    }\n    arrange(requests);\n    hist_requests_[object].push_back(requests);\n}\n\nvoid SharedMemBuffer::dealloc(void* object) {\n    hist_requests_.erase(object);\n}\n\nvoid SharedMemBuffer::arrange(std::vector<std::pair<void**, uint64_t>> requests) {\n    uint64_t offset = 0;\n    for (auto& request : requests) {\n        *(request.first) = (uint8_t*)buffer_ + offset;\n        offset += request.second;\n    }\n}"
  },
  {
    "path": "archive/csrc/ktransformers_ext/cpu_backend/shared_mem_buffer.h",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-08-05 04:49:08\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022 \n * @LastEditTime : 2024-08-05 06:36:41\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n\n #ifndef CPUINFER_SHAREDMEMBUFFER_H\n #define CPUINFER_SHAREDMEMBUFFER_H\n \n #include <cstdint>\n #include <cstdlib>\n #include <map>\n #include <vector>\n \n class SharedMemBuffer {\n    public:\n     SharedMemBuffer();\n     ~SharedMemBuffer();\n \n     void alloc(void* object, std::vector<std::pair<void**, uint64_t>> requests);\n     void dealloc(void* object);\n \n    private:\n     void* buffer_;\n     uint64_t size_;\n     std::map<void*, std::vector<std::vector<std::pair<void**, uint64_t>>>> hist_requests_;\n \n     void arrange(std::vector<std::pair<void**, uint64_t>> requests);\n };\n \n static SharedMemBuffer shared_mem_buffer;\n \n #endif"
  },
  {
    "path": "archive/csrc/ktransformers_ext/cpu_backend/task_queue.cpp",
    "content": "/**\n * @Description :\n * @Author    : chenht2022\n * @Date     : 2024-07-17 12:25:51\n * @Version   : 1.0.0\n * @LastEditors : chenht2022\n * @LastEditTime : 2024-10-09 11:08:10\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#include \"task_queue.h\"\n\nTaskQueue::TaskQueue() {\n    worker = std::thread(&TaskQueue::processTasks, this);\n    sync_flag.store(true, std::memory_order_seq_cst);\n    exit_flag.store(false, std::memory_order_seq_cst);\n}\n\nTaskQueue::~TaskQueue() {\n    {\n        mutex.lock();\n        exit_flag.store(true, std::memory_order_seq_cst);\n        mutex.unlock();\n    }\n    cv.notify_all();\n    if (worker.joinable()) {\n        worker.join();\n    }\n}\n\nvoid TaskQueue::enqueue(std::function<void()> task) {\n    {\n        mutex.lock();\n        tasks.push(task);\n        sync_flag.store(false, std::memory_order_seq_cst);\n        mutex.unlock();\n    }\n    cv.notify_one();\n}\n\nvoid TaskQueue::sync() {\n    while (!sync_flag.load(std::memory_order_seq_cst))\n        ;\n}\n\nvoid TaskQueue::processTasks() {\n    while (true) {\n        std::function<void()> task;\n        {\n            mutex.lock();\n            cv.wait(mutex, [this]() { return !tasks.empty() || exit_flag.load(std::memory_order_seq_cst); });\n            if (exit_flag.load(std::memory_order_seq_cst) && tasks.empty()) {\n                return;\n            }\n            task = tasks.front();\n            tasks.pop();\n            mutex.unlock();\n        }\n        task();\n        {\n            mutex.lock();\n            if (tasks.empty()) {\n                sync_flag.store(true, std::memory_order_seq_cst);\n            }\n            mutex.unlock();\n        }\n    }\n}"
  },
  {
    "path": "archive/csrc/ktransformers_ext/cpu_backend/task_queue.h",
    "content": "/**\n * @Description :\n * @Author    : chenht2022\n * @Date     : 2024-07-16 10:43:18\n * @Version   : 1.0.0\n * @LastEditors : chenht\n * @LastEditTime : 2024-10-09 11:08:07\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#ifndef CPUINFER_TASKQUEUE_H\n#define CPUINFER_TASKQUEUE_H\n\n#include <atomic>\n#include <condition_variable>\n#include <functional>\n#include <mutex>\n#include <queue>\n#include <thread>\n#include <vector>\n#ifdef _WIN32\n#include <windows.h>\n#endif\n\nclass custom_mutex {\n   private:\n#ifdef _WIN32\n    CRITICAL_SECTION cs;\n#else\n    std::mutex mtx;\n#endif\n\n   public:\n    custom_mutex() {\n#ifdef _WIN32\n        InitializeCriticalSection(&cs);\n#else\n        // No initialization required for std::mutex\n#endif\n    }\n\n    ~custom_mutex() {\n#ifdef _WIN32\n        DeleteCriticalSection(&cs);\n#endif\n    }\n\n    void lock() {\n#ifdef _WIN32\n        EnterCriticalSection(&cs);\n#else\n        mtx.lock();\n#endif\n    }\n\n    void unlock() {\n#ifdef _WIN32\n        LeaveCriticalSection(&cs);\n#else\n        mtx.unlock();\n#endif\n    }\n\n#ifdef _WIN32\n    CRITICAL_SECTION* get_handle() {\n        return &cs;\n    }\n#else\n    std::mutex* get_handle() {\n        return &mtx;\n    }\n#endif\n};\n\nclass custom_condition_variable {\n   private:\n#ifdef _WIN32\n    CONDITION_VARIABLE cond_var;\n#else\n    std::condition_variable cond_var;\n#endif\n\n   public:\n    custom_condition_variable() {\n#ifdef _WIN32\n        InitializeConditionVariable(&cond_var);\n#endif\n    }\n\n    template <typename Predicate>\n    void wait(custom_mutex& mutex, Predicate pred) {\n#ifdef _WIN32\n        while (!pred()) {\n            SleepConditionVariableCS(&cond_var, mutex.get_handle(), INFINITE);\n        }\n#else\n        std::unique_lock<std::mutex> lock(*mutex.get_handle(), std::adopt_lock);\n        cond_var.wait(lock, pred);\n        lock.release();\n#endif\n    }\n\n    void notify_one() {\n#ifdef _WIN32\n    
    WakeConditionVariable(&cond_var);\n#else\n        cond_var.notify_one();\n#endif\n    }\n\n    void notify_all() {\n#ifdef _WIN32\n        WakeAllConditionVariable(&cond_var);\n#else\n        cond_var.notify_all();\n#endif\n    }\n};\n\nclass TaskQueue {\n   public:\n    TaskQueue();\n    ~TaskQueue();\n\n    void enqueue(std::function<void()>);\n\n    void sync();\n\n   private:\n    void processTasks();\n\n    std::queue<std::function<void()>> tasks;\n    custom_mutex mutex;\n    custom_condition_variable cv;\n    std::thread worker;\n    std::atomic<bool> sync_flag;\n    std::atomic<bool> exit_flag;\n};\n#endif"
  },
  {
    "path": "archive/csrc/ktransformers_ext/cpu_backend/vendors/README.md",
    "content": "## TODO\n\nThis directory can be removed after updating the version of `llama.cpp`."
  },
  {
    "path": "archive/csrc/ktransformers_ext/cpu_backend/vendors/cuda.h",
    "content": "#pragma once\n\n#include <cuda_runtime.h>\n#include <cuda.h>\n#include <cublas_v2.h>\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n\n#if CUDART_VERSION < 11020\n#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED\n#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH\n#define CUBLAS_COMPUTE_16F CUDA_R_16F\n#define CUBLAS_COMPUTE_32F CUDA_R_32F\n#define cublasComputeType_t cudaDataType_t\n#endif // CUDART_VERSION < 11020\n"
  },
  {
    "path": "archive/csrc/ktransformers_ext/cpu_backend/vendors/hip.h",
    "content": "#pragma once\n\n#define HIP_ENABLE_WARP_SYNC_BUILTINS 1\n#include <hip/hip_runtime.h>\n#include <hipblas/hipblas.h>\n#include <hip/hip_fp16.h>\n#include <hip/hip_bfloat16.h>\n#ifdef __HIP_PLATFORM_AMD__\n// for rocblas_initialize()\n#include \"rocblas/rocblas.h\"\n#endif // __HIP_PLATFORM_AMD__\n\n#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F\n#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F\n#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F\n#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT\n#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT\n#define CUBLAS_OP_N HIPBLAS_OP_N\n#define CUBLAS_OP_T HIPBLAS_OP_T\n#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS\n#define CUBLAS_TF32_TENSOR_OP_MATH 0\n#define CUDA_R_16F  HIPBLAS_R_16F\n#define CUDA_R_32F  HIPBLAS_R_32F\n#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported\n#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended\n#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned\n#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice\n#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite\n#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT(\"HipVMM Failure: %s\\n\", hipGetErrorString(err)); }}\n#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)\n#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)\n#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6\n#define cublasCreate hipblasCreate\n#define cublasDestroy hipblasDestroy\n#define cublasGemmEx hipblasGemmEx\n#define cublasGemmBatchedEx hipblasGemmBatchedEx\n#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx\n#define cublasHandle_t hipblasHandle_t\n#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS\n#define cublasSetStream hipblasSetStream\n#define cublasSgemm hipblasSgemm\n#define 
cublasStatus_t hipblasStatus_t\n#define cublasOperation_t hipblasOperation_t\n#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6\n#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer\n#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess\n#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess\n#define cudaDeviceProp hipDeviceProp_t\n#define cudaDeviceSynchronize hipDeviceSynchronize\n#define cudaError_t hipError_t\n#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled\n#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled\n#define cudaEventCreateWithFlags hipEventCreateWithFlags\n#define cudaEventDisableTiming hipEventDisableTiming\n#define cudaEventRecord hipEventRecord\n#define cudaEventSynchronize hipEventSynchronize\n#define cudaEvent_t hipEvent_t\n#define cudaEventDestroy hipEventDestroy\n#define cudaFree hipFree\n#define cudaFreeHost hipHostFree\n#define cudaGetDevice hipGetDevice\n#define cudaGetDeviceCount hipGetDeviceCount\n#define cudaGetDeviceProperties hipGetDeviceProperties\n#define cudaGetErrorString hipGetErrorString\n#define cudaGetLastError hipGetLastError\n#define cudaHostRegister hipHostRegister\n#define cudaHostRegisterPortable hipHostRegisterPortable\n#define cudaHostRegisterReadOnly hipHostRegisterReadOnly\n#define cudaHostUnregister hipHostUnregister\n#define cudaLaunchHostFunc hipLaunchHostFunc\n#define cudaMalloc hipMalloc\n#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)\n#define cudaMemcpy hipMemcpy\n#define cudaMemcpyAsync hipMemcpyAsync\n#define cudaMemcpyPeerAsync hipMemcpyPeerAsync\n#define cudaMemcpy2DAsync hipMemcpy2DAsync\n#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice\n#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost\n#define cudaMemcpyHostToDevice hipMemcpyHostToDevice\n#define cudaMemcpyKind hipMemcpyKind\n#define cudaMemset hipMemset\n#define cudaMemsetAsync hipMemsetAsync\n#define cudaMemGetInfo 
hipMemGetInfo\n#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize\n#define cudaSetDevice hipSetDevice\n#define cuDeviceGet hipDeviceGet\n#define CUdevice hipDevice_t\n#define CUdeviceptr hipDeviceptr_t\n#define cuMemUnmap hipMemUnmap\n#define CUmemAccessDesc hipMemAccessDesc\n#define cuMemAddressFree hipMemAddressFree\n#define cuMemRelease hipMemRelease\n#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t\n#define cuMemCreate hipMemCreate\n#define cuMemAddressReserve hipMemAddressReserve\n#define cuMemMap hipMemMap\n#define cuMemSetAccess hipMemSetAccess\n#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity\n#define CUmemAllocationProp hipMemAllocationProp\n#define cuDeviceGetAttribute hipDeviceGetAttribute\n#define cudaStreamCreateWithFlags hipStreamCreateWithFlags\n#define cudaStreamDestroy hipStreamDestroy\n#define cudaStreamFireAndForget hipStreamFireAndForget\n#define cudaStreamNonBlocking hipStreamNonBlocking\n#define cudaStreamPerThread hipStreamPerThread\n#define cudaStreamSynchronize hipStreamSynchronize\n#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)\n#define cudaGraphExec_t hipGraphExec_t\n#define cudaGraphNode_t hipGraphNode_t\n#define cudaKernelNodeParams hipKernelNodeParams\n#define cudaKernelNodeParams hipKernelNodeParams\n#define cudaGraphExecDestroy hipGraphExecDestroy\n#define cudaGraphLaunch hipGraphLaunch\n#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure\n#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult\n#define cudaGraphNodeType hipGraphNodeType\n#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel\n#define cudaGraphInstantiate hipGraphInstantiate\n#define cudaStreamEndCapture hipStreamEndCapture\n#define cudaGraphDestroy hipGraphDestroy\n#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams\n#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction\n#define 
cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams\n#define cudaGraphNodeGetType hipGraphNodeGetType\n#define cudaGraphGetNodes hipGraphGetNodes\n#define cudaGraphExecUpdate hipGraphExecUpdate\n#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed\n#define cudaStreamBeginCapture hipStreamBeginCapture\n#define cudaGraph_t hipGraph_t\n#define cudaStream_t hipStream_t\n#define cudaSuccess hipSuccess\n#define cudaHostFn_t hipHostFn_t\n#define __trap() do { abort(); __builtin_unreachable(); } while(0)\n#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS\n#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED\n#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED\n#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE\n#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH\n#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR\n#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED\n#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR\n#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED\n\n#define __CUDA_ARCH__ 1300\n\n#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)\n#define GCN\n#endif\n\n#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)\n#define CDNA\n#endif\n\n#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \\\n    defined(__gfx1150__) || defined(__gfx1151__)\n#define RDNA3\n#endif\n\n#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \\\n    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)\n#define RDNA2\n#endif\n\n#if defined(__gfx1010__) || defined(__gfx1012__)\n#define RDNA1\n#endif\n\n#ifndef __has_builtin\n    #define __has_builtin(x) 0\n#endif\n\ntypedef hip_bfloat16 nv_bfloat16;\n"
  },
  {
    "path": "archive/csrc/ktransformers_ext/cpu_backend/vendors/musa.h",
    "content": "#pragma once\n\n#include <musa_runtime.h>\n#include <musa.h>\n#include <mublas.h>\n#include <musa_bf16.h>\n#include <musa_fp16.h>\n#define CUBLAS_COMPUTE_16F CUDA_R_16F\n#define CUBLAS_COMPUTE_32F CUDA_R_32F\n#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F\n#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT\n#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT\n#define CUBLAS_OP_N MUBLAS_OP_N\n#define CUBLAS_OP_T MUBLAS_OP_T\n#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS\n#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT\n#define CUDA_R_16F  MUSA_R_16F\n#define CUDA_R_32F  MUSA_R_32F\n#define cublasComputeType_t cudaDataType_t\n#define cublasCreate mublasCreate\n#define cublasDestroy mublasDestroy\n#define cublasGemmEx mublasGemmEx\n#define cublasGemmBatchedEx mublasGemmBatchedEx\n#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx\n#define cublasHandle_t mublasHandle_t\n#define cublasSetMathMode mublasSetMathMode\n#define cublasSetStream mublasSetStream\n#define cublasSgemm mublasSgemm\n#define cublasStatus_t mublasStatus_t\n#define cublasOperation_t mublasOperation_t\n#define cublasGetStatusString mublasStatus_to_string\n#define cudaDataType_t musaDataType_t\n#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer\n#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess\n#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess\n#define cudaDeviceProp musaDeviceProp\n#define cudaDeviceSynchronize musaDeviceSynchronize\n#define cudaError_t musaError_t\n#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled\n#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled\n#define cudaEventCreateWithFlags musaEventCreateWithFlags\n#define cudaEventDisableTiming musaEventDisableTiming\n#define cudaEventRecord musaEventRecord\n#define cudaEventSynchronize musaEventSynchronize\n#define cudaEvent_t musaEvent_t\n#define cudaEventDestroy musaEventDestroy\n#define cudaFree 
musaFree\n#define cudaFreeHost musaFreeHost\n#define cudaGetDevice musaGetDevice\n#define cudaGetDeviceCount musaGetDeviceCount\n#define cudaGetDeviceProperties musaGetDeviceProperties\n#define cudaGetErrorString musaGetErrorString\n#define cudaGetLastError musaGetLastError\n#define cudaHostRegister musaHostRegister\n#define cudaHostRegisterPortable musaHostRegisterPortable\n#define cudaHostRegisterReadOnly musaHostRegisterReadOnly\n#define cudaHostUnregister musaHostUnregister\n#define cudaLaunchHostFunc musaLaunchHostFunc\n#define cudaMalloc musaMalloc\n#define cudaMallocHost musaMallocHost\n#define cudaMallocManaged musaMallocManaged\n#define cudaMemcpy musaMemcpy\n#define cudaMemcpyAsync musaMemcpyAsync\n#define cudaMemcpyPeerAsync musaMemcpyPeerAsync\n#define cudaMemcpy2DAsync musaMemcpy2DAsync\n#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice\n#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost\n#define cudaMemcpyHostToDevice musaMemcpyHostToDevice\n#define cudaMemcpyKind musaMemcpyKind\n#define cudaMemset musaMemset\n#define cudaMemsetAsync musaMemsetAsync\n#define cudaMemGetInfo musaMemGetInfo\n#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize\n#define cudaSetDevice musaSetDevice\n#define cudaStreamCreateWithFlags musaStreamCreateWithFlags\n#define cudaStreamDestroy musaStreamDestroy\n#define cudaStreamFireAndForget musaStreamFireAndForget\n#define cudaStreamNonBlocking musaStreamNonBlocking\n#define cudaStreamPerThread musaStreamPerThread\n#define cudaStreamSynchronize musaStreamSynchronize\n#define cudaStreamWaitEvent musaStreamWaitEvent\n#define cudaStream_t musaStream_t\n#define cudaSuccess musaSuccess\n\n// Additional mappings for MUSA virtual memory pool\n#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED\n#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE\n#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED 
MU_MEM_ALLOC_GRANULARITY_RECOMMENDED\n#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED\n#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE\n#define CUdevice MUdevice\n#define CUdeviceptr MUdeviceptr\n#define CUmemAccessDesc MUmemAccessDesc\n#define CUmemAllocationProp MUmemAllocationProp\n#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle\n#define cuDeviceGet muDeviceGet\n#define cuDeviceGetAttribute muDeviceGetAttribute\n#define cuMemAddressFree muMemAddressFree\n#define cuMemAddressReserve muMemAddressReserve\n#define cuMemCreate muMemCreate\n#define cuMemGetAllocationGranularity muMemGetAllocationGranularity\n#define cuMemMap muMemMap\n#define cuMemRelease muMemRelease\n#define cuMemSetAccess muMemSetAccess\n#define cuMemUnmap muMemUnmap\n#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize\n#define cudaFuncSetAttribute musaFuncSetAttribute\n#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms\n#define make_cudaExtent make_musaExtent\n#define make_cudaPitchedPtr make_musaPitchedPtr\n\n// Additional mappings for MUSA graphs\n#define CUDA_SUCCESS MUSA_SUCCESS\n#define CUresult MUresult\n#define cuGetErrorString muGetErrorString\n#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure\n#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction\n#define cudaGraphDestroy musaGraphDestroy\n#define cudaGraphExecDestroy musaGraphExecDestroy\n#define cudaGraphExec_t musaGraphExec_t\n#define cudaGraphExecUpdate musaGraphExecUpdate\n#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult\n#define cudaGraphGetNodes musaGraphGetNodes\n#define cudaGraphInstantiate musaGraphInstantiate\n#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams\n#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams\n#define cudaGraphLaunch musaGraphLaunch\n#define cudaGraphNodeGetType musaGraphNodeGetType\n#define cudaGraphNode_t 
musaGraphNode_t\n#define cudaGraphNodeType musaGraphNodeType\n#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel\n#define cudaGraph_t musaGraph_t\n#define cudaKernelNodeParams musaKernelNodeParams\n#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed\n#define cudaStreamEndCapture musaStreamEndCapture\n\ntypedef mt_bfloat16 nv_bfloat16;\n"
  },
  {
    "path": "archive/csrc/ktransformers_ext/cpu_backend/vendors/vendor.h",
    "content": "#ifndef CPUINFER_VENDOR_VENDOR_H\n#define CPUINFER_VENDOR_VENDOR_H\n\n#ifdef USE_CUDA\n#include \"cuda.h\"\n#elif USE_HIP\n#define __HIP_PLATFORM_AMD__\n#include \"hip.h\"\n#elif USE_MUSA\n#include \"musa.h\"\n#endif\n\n#endif  // CPUINFER_VENDOR_VENDOR_H"
  },
  {
    "path": "archive/csrc/ktransformers_ext/cuda/binding.cpp",
    "content": "/**\n * @Description  :\n * @Author       : Azure-Tang, Boxin Zhang\n * @Date         : 2024-07-25 13:38:30\n * @Version      : 0.2.2\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n**/\n\n#include \"custom_gguf/ops.h\"\n#ifdef KTRANSFORMERS_USE_CUDA\n#include \"gptq_marlin/ops.h\"\n#endif\n// Python bindings\n#include <pybind11/pybind11.h>\n#include <pybind11/stl.h>\n#include <torch/library.h>\n#include <torch/extension.h>\n#include <torch/torch.h>\n// namespace py = pybind11;\n\nPYBIND11_MODULE(KTransformersOps, m) {\n\n    m.def(\"dequantize_q8_0\", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {\n        torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);\n        return dequantize_q8_0((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);\n        }, \"Function to dequantize q8_0 data.\",\n        py::arg(\"data\"), py::arg(\"num_bytes\"), py::arg(\"blk_size\"), py::arg(\"ele_per_blk\"), py::arg(\"device\"), py::arg(\"target_dtype\"));\n\n    m.def(\"dequantize_q6_k\", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {\n        torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);\n        return dequantize_q6_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);\n        }, \"Function to dequantize q6_k data.\",\n        py::arg(\"data\"), py::arg(\"num_bytes\"), py::arg(\"blk_size\"), py::arg(\"ele_per_blk\"), py::arg(\"device\"), py::arg(\"target_dtype\"));\n\n    m.def(\"dequantize_q5_k\", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {\n        torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);\n        return dequantize_q5_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);\n        }, 
\"Function to dequantize q5_k data.\",\n        py::arg(\"data\"), py::arg(\"num_bytes\"), py::arg(\"blk_size\"), py::arg(\"ele_per_blk\"), py::arg(\"device\"), py::arg(\"target_dtype\"));\n\n    m.def(\"dequantize_q4_k\", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {\n        torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);\n        return dequantize_q4_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);\n        }, \"Function to dequantize q4_k data.\",\n        py::arg(\"data\"), py::arg(\"num_bytes\"), py::arg(\"blk_size\"), py::arg(\"ele_per_blk\"), py::arg(\"device\"), py::arg(\"target_dtype\"));\n\n    m.def(\"dequantize_q3_k\", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {\n        torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);\n        return dequantize_q3_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);\n        }, \"Function to dequantize q3_k data.\",\n        py::arg(\"data\"), py::arg(\"num_bytes\"), py::arg(\"blk_size\"), py::arg(\"ele_per_blk\"), py::arg(\"device\"), py::arg(\"target_dtype\"));\n\n    m.def(\"dequantize_q2_k\", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {\n        torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);\n        return dequantize_q2_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);\n        }, \"Function to dequantize q2_k data.\",\n        py::arg(\"data\"), py::arg(\"num_bytes\"), py::arg(\"blk_size\"), py::arg(\"ele_per_blk\"), py::arg(\"device\"), py::arg(\"target_dtype\"));\n\n    m.def(\"dequantize_iq4_xs\", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {\n        
torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);\n        return dequantize_iq4_xs((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);\n        }, \"Function to dequantize iq4_xs data.\",\n        py::arg(\"data\"), py::arg(\"num_bytes\"), py::arg(\"blk_size\"), py::arg(\"ele_per_blk\"), py::arg(\"device\"), py::arg(\"target_dtype\"));\n\n#ifdef KTRANSFORMERS_USE_CUDA\n    m.def(\"gptq_marlin_gemm\", &gptq_marlin_gemm, \"Function to perform GEMM using Marlin quantization.\",\n        py::arg(\"a\"), py::arg(\"b_q_weight\"), py::arg(\"b_scales\"), py::arg(\"g_idx\"),\n        py::arg(\"perm\"), py::arg(\"workspace\"), py::arg(\"num_bits\"), py::arg(\"size_m\"),\n        py::arg(\"size_n\"), py::arg(\"size_k\"), py::arg(\"is_k_full\"));\n#endif\n}"
  },
  {
    "path": "archive/csrc/ktransformers_ext/cuda/custom_gguf/dequant.cu",
    "content": "/*\n * @Description  :  \n * @Author       : Azure-Tang, Boxin Zhang\n * @Date         : 2024-07-25 13:38:30\n * @Version      : 0.2.2\n * Adapted from https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c\n * Copyright (c) 2023-2024 The ggml authors\n * Copyright (c) 2024 by KVCache.AI, All Rights Reserved. \n */\n#include <cuda_runtime.h>\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n#include <torch/library.h>\n#include <torch/extension.h>\n#include <torch/torch.h>\n#include <cstdint>\n#include <c10/cuda/CUDAGuard.h>\n\n#ifdef __HIP_PLATFORM_AMD__\ntypedef __hip_bfloat16 nv_bfloat16;\n#endif\n\n__global__ void dequantize_q8_0_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){\n        float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);\n        const int8_t* cur_block = data + block_id * blk_size;\n        float scale = __half2float(*((half*)cur_block));\n        cur_block += 2;\n        for (int i = 0; i < ele_per_blk; i++){\n            output_blk[i] = scale * cur_block[i];\n        }\n    }\n}\n\n__global__ void dequantize_q8_0_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) {\n        __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);\n        const int8_t* cur_block = data + block_id * blk_size;\n        float scale = __half2float(*((half*)cur_block));\n        cur_block += 2;\n        for (int i = 0; i < ele_per_blk; i++) {\n            output_blk[i] = __float2half(scale * 
cur_block[i]);\n        }\n    }\n}\n\n__global__ void dequantize_q8_0_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) {\n        nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);\n        const int8_t* cur_block = data + block_id * blk_size;\n        float scale = __half2float(*((half*)cur_block));\n        cur_block += 2;\n        for (int i = 0; i < ele_per_blk; i++) {\n            output_blk[i] = __float2bfloat16(scale * cur_block[i]);\n        }\n    }\n}\n\n// __device__ void get_scale_min_k4(int j, const uint8_t * __restrict__ q, uint8_t * __restrict__ d, uint8_t * __restrict__ m) {\n__device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t * __restrict__ d, uint8_t * __restrict__ m) {\n    if (j < 4) {\n        *d = q[j] & 63; *m = q[j + 4] & 63;\n    } else {\n        *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);\n        *m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);\n    }\n}\n\n__global__ void dequantize_q2_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){\n        float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);\n\n        const float d   = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 80)));\n        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 82)));\n\n        const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16);\n\n        int is = 0;\n        float dl, ml;\n\n        for (int n = 0; n < 256; n += 128) {\n            int 
shift = 0;\n            for (int j = 0; j < 4; ++j) {\n                uint8_t* scales = (uint8_t*)(data + block_id * blk_size + (is++));\n                uint8_t sc = *scales;\n                dl = d * (sc & 0xF); ml = min * (sc >> 4);\n                for (int l = 0; l < 16; ++l) *output_blk++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;\n\n                scales = (uint8_t*)(data + block_id * blk_size + (is++));\n                sc = *scales;\n\n                dl = d * (sc & 0xF); ml = min * (sc >> 4);\n                for (int l = 0; l < 16; ++l) *output_blk++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;\n\n                shift += 2;\n            }\n            q += 32;\n        }\n    }\n}\n\n__global__ void dequantize_q2_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){\n        __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);\n\n        const float d   = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 80)));\n        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 82)));\n\n        const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16);\n\n        int is = 0;\n        float dl, ml;\n\n        for (int n = 0; n < 256; n += 128) {\n            int shift = 0;\n            for (int j = 0; j < 4; ++j) {\n                uint8_t* scales = (uint8_t*)(data + block_id * blk_size + (is++));\n                uint8_t sc = *scales;\n                dl = d * (sc & 0xF); ml = min * (sc >> 4);\n                for (int l = 0; l < 16; ++l) *output_blk++ = __float2half(dl * ((int8_t)((q[l] >> shift) & 3)) - ml);\n\n                scales = (uint8_t*)(data + block_id * blk_size + (is++));\n                sc = *scales;\n\n    
            dl = d * (sc & 0xF); ml = min * (sc >> 4);\n                for (int l = 0; l < 16; ++l) *output_blk++ = __float2half(dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml);\n\n                shift += 2;\n            }\n            q += 32;\n        }\n    }\n}\n\n__global__ void dequantize_q2_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){\n        nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);\n\n        const float d   = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 80)));\n        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 82)));\n\n        const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16);\n\n        int is = 0;\n        float dl, ml;\n\n        for (int n = 0; n < 256; n += 128) {\n            int shift = 0;\n            for (int j = 0; j < 4; ++j) {\n                uint8_t* scales = (uint8_t*)(data + block_id * blk_size + (is++));\n                uint8_t sc = *scales;\n                dl = d * (sc & 0xF); ml = min * (sc >> 4);\n                for (int l = 0; l < 16; ++l) *output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l] >> shift) & 3)) - ml);\n\n                scales = (uint8_t*)(data + block_id * blk_size + (is++));\n                sc = *scales;\n\n                dl = d * (sc & 0xF); ml = min * (sc >> 4);\n                for (int l = 0; l < 16; ++l) *output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml);\n\n                shift += 2;\n            }\n            q += 32;\n        }\n    }\n}\n\n__global__ void dequantize_q3_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    
\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;    \n    const uint32_t kmask1 = 0x03030303;\n    const uint32_t kmask2 = 0x0f0f0f0f;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){\n        float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);\n\n        uint32_t aux[4];\n        const int8_t * scales = (const int8_t*)aux;\n        const float d_all = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 108)));\n\n        const uint8_t * __restrict__ q  = (uint8_t*)(data + block_id * blk_size + 32);\n        const uint8_t * __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0);\n        uint8_t m = 1;\n\n\n        uint8_t* block_scales = (uint8_t*)(data + block_id * blk_size + 96);\n\n        for (int i = 0; i < 3; i++) {  \n            aux[i] = 0;  \n            for (int j = 0; j < 4; j++) {  \n                aux[i] |= ((uint32_t)block_scales[i * 4 + j]) << (j * 8);\n            }\n        }\n\n        uint32_t tmp = aux[2];\n        aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);\n        aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);\n        aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);\n        aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);\n\n        int is = 0;\n        float dl;\n        for (int n = 0; n < 256; n += 128) {\n            int shift = 0;\n            for (int j = 0; j < 4; ++j) {\n\n                dl = d_all * (scales[is++] - 32);\n                for (int l = 0; l < 16; ++l) {\n                    *output_blk++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4));\n                }\n\n                dl = d_all * (scales[is++] - 32);\n                for (int l = 0; l < 16; ++l) {\n                    *output_blk++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 
0 : 4));\n                }\n\n                shift += 2;\n                m <<= 1;\n            }\n            q += 32;\n        }\n    }\n}\n\n__global__ void dequantize_q3_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    \n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;    \n    const uint32_t kmask1 = 0x03030303;\n    const uint32_t kmask2 = 0x0f0f0f0f;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){\n        __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);\n\n        uint32_t aux[4];\n        const int8_t * scales = (const int8_t*)aux;\n        const float d_all = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 108)));\n\n        const uint8_t * __restrict__ q  = (uint8_t*)(data + block_id * blk_size + 32);\n        const uint8_t * __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0);\n        uint8_t m = 1;\n\n\n        uint8_t* block_scales = (uint8_t*)(data + block_id * blk_size + 96);\n\n        for (int i = 0; i < 3; i++) {  \n            aux[i] = 0;  \n            for (int j = 0; j < 4; j++) {  \n                aux[i] |= ((uint32_t)block_scales[i * 4 + j]) << (j * 8);\n            }\n        }\n\n        uint32_t tmp = aux[2];\n        aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);\n        aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);\n        aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);\n        aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);\n\n        int is = 0;\n        float dl;\n        for (int n = 0; n < 256; n += 128) {\n            int shift = 0;\n            for (int j = 0; j < 4; ++j) {\n\n                dl = d_all * (scales[is++] - 32);\n                for (int l = 0; l < 16; ++l) {\n                    *output_blk++ = __float2half(dl * ((int8_t)((q[l+ 0] >> shift) & 3) - 
((hm[l+ 0] & m) ? 0 : 4)));\n                }\n\n                dl = d_all * (scales[is++] - 32);\n                for (int l = 0; l < 16; ++l) {\n                    *output_blk++ = __float2half(dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4)));\n                }\n\n                shift += 2;\n                m <<= 1;\n            }\n            q += 32;\n        }\n    }\n}\n\n__global__ void dequantize_q3_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    \n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;    \n    const uint32_t kmask1 = 0x03030303;\n    const uint32_t kmask2 = 0x0f0f0f0f;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){\n        nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);\n\n        uint32_t aux[4];\n        const int8_t * scales = (const int8_t*)aux;\n        const float d_all = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 108)));\n\n        const uint8_t * __restrict__ q  = (uint8_t*)(data + block_id * blk_size + 32);\n        const uint8_t * __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0);\n        uint8_t m = 1;\n\n\n        uint8_t* block_scales = (uint8_t*)(data + block_id * blk_size + 96);\n\n        for (int i = 0; i < 3; i++) {  \n            aux[i] = 0;  \n            for (int j = 0; j < 4; j++) {  \n                aux[i] |= ((uint32_t)block_scales[i * 4 + j]) << (j * 8);\n            }\n        }\n\n        uint32_t tmp = aux[2];\n        aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);\n        aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);\n        aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);\n        aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);\n\n        int is = 0;\n        float dl;\n        for (int n = 0; n < 256; n 
+= 128) {\n            int shift = 0;\n            for (int j = 0; j < 4; ++j) {\n\n                dl = d_all * (scales[is++] - 32);\n                for (int l = 0; l < 16; ++l) {\n                    *output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4)));\n                }\n\n                dl = d_all * (scales[is++] - 32);\n                for (int l = 0; l < 16; ++l) {\n                    *output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4)));\n                }\n\n                shift += 2;\n                m <<= 1;\n            }\n            q += 32;\n        }\n    }\n}\n\n\n__global__ void dequantize_q4_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x){\n        float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);\n        // const uint8_t * q = data[i].qs;\n        const uint8_t * q = (uint8_t*)(data + block_id * 144 + 16);\n\n        const float d   = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 0)));\n        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 2)));\n        int is = 0;\n        uint8_t sc, m;\n        for (int j = 0; j < ele_per_blk; j += 64) {\n            uint8_t* scales = (uint8_t*)(data + block_id * 144 + 4);\n            get_scale_min_k4(is + 0, scales, &sc, &m);\n            const float d1 = d * sc; const float m1 = min * m;\n            get_scale_min_k4(is + 1, scales, &sc, &m);\n            const float d2 = d * sc; const float m2 = min * m;\n            for (int l = 0; l < 32; ++l) *output_blk++ = d1 * (q[l] & 0xF) - m1;\n            for (int l = 0; l < 32; ++l) *output_blk++ = d2 * (q[l]  >> 4) - m2;\n            q += 32; is 
+= 2;\n        }\n    }\n}\n\n__global__ void dequantize_q4_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x){\n        __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);\n        // const uint8_t * q = data[i].qs;\n        const uint8_t * q = (uint8_t*)(data + block_id * 144 + 16);\n\n        const float d   = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 0)));\n        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 2)));\n        int is = 0;\n        uint8_t sc, m;\n        for (int j = 0; j < ele_per_blk; j += 64) {\n            uint8_t* scales = (uint8_t*)(data + block_id * 144 + 4);\n            get_scale_min_k4(is + 0, scales, &sc, &m);\n            const float d1 = d * sc; const float m1 = min * m;\n            get_scale_min_k4(is + 1, scales, &sc, &m);\n            const float d2 = d * sc; const float m2 = min * m;\n            for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d1 * (q[l] & 0xF) - m1);\n            for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d2 * (q[l]  >> 4) - m2);\n            q += 32; is += 2;\n        }\n    }\n}\n\n__global__ void dequantize_q4_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x){\n        nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);\n        // const uint8_t * q = data[i].qs;\n        const uint8_t * q = (uint8_t*)(data + block_id * 144 + 16);\n\n        const float d   = __half2float(*(reinterpret_cast<const 
half*>(data + block_id * 144 + 0)));\n        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 2)));\n        int is = 0;\n        uint8_t sc, m;\n        for (int j = 0; j < ele_per_blk; j += 64) {\n            uint8_t* scales = (uint8_t*)(data + block_id * 144 + 4);\n            get_scale_min_k4(is + 0, scales, &sc, &m);\n            const float d1 = d * sc; const float m1 = min * m;\n            get_scale_min_k4(is + 1, scales, &sc, &m);\n            const float d2 = d * sc; const float m2 = min * m;\n            for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d1 * (q[l] & 0xF) - m1);\n            for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d2 * (q[l]  >> 4) - m2);\n            q += 32; is += 2;\n        }\n    }\n}\n\n__global__ void dequantize_q5_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){\n        float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);\n\n        const float d   = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 0)));\n        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 2)));\n\n        const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16);\n        const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48);\n\n        int is = 0;\n        uint8_t sc, m;\n        uint8_t u1 = 1, u2 = 2;\n        uint8_t* scales = (uint8_t*)(data + block_id * blk_size + 4);\n\n        for (int j = 0; j < 256; j += 64) {\n            get_scale_min_k4(is + 0, scales, &sc, &m);\n            const float d1 = d * sc; const float m1 = min * m;\n            get_scale_min_k4(is + 1, scales, &sc, &m);\n            const float d2 
= d * sc; const float m2 = min * m;\n            for (int l = 0; l < 32; ++l) *output_blk++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;\n            for (int l = 0; l < 32; ++l) *output_blk++ = d2 * ((ql[l]  >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;\n            ql += 32; is += 2;\n            u1 <<= 2; u2 <<= 2;\n        }\n    }\n}\n\n__global__ void dequantize_q5_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){\n        __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);\n\n        const float d   = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 0)));\n        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 2)));\n\n        const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16);\n        const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48);\n\n        int is = 0;\n        uint8_t sc, m;\n        uint8_t u1 = 1, u2 = 2;\n        uint8_t* scales = (uint8_t*)(data + block_id * blk_size + 4);\n\n        for (int j = 0; j < 256; j += 64) {\n            get_scale_min_k4(is + 0, scales, &sc, &m);\n            const float d1 = d * sc; const float m1 = min * m;\n            get_scale_min_k4(is + 1, scales, &sc, &m);\n            const float d2 = d * sc; const float m2 = min * m;\n            for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1);\n            for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d2 * ((ql[l]  >> 4) + (qh[l] & u2 ? 
16 : 0)) - m2);\n            ql += 32; is += 2;\n            u1 <<= 2; u2 <<= 2;\n        }\n    }\n}\n\n__global__ void dequantize_q5_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){\n        nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);\n\n        const float d   = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 0)));\n        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 2)));\n\n        const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16);\n        const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48);\n\n        int is = 0;\n        uint8_t sc, m;\n        uint8_t u1 = 1, u2 = 2;\n        uint8_t* scales = (uint8_t*)(data + block_id * blk_size + 4);\n\n        for (int j = 0; j < 256; j += 64) {\n            get_scale_min_k4(is + 0, scales, &sc, &m);\n            const float d1 = d * sc; const float m1 = min * m;\n            get_scale_min_k4(is + 1, scales, &sc, &m);\n            const float d2 = d * sc; const float m2 = min * m;\n            for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1);\n            for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d2 * ((ql[l]  >> 4) + (qh[l] & u2 ? 
16 : 0)) - m2);\n            ql += 32; is += 2;\n            u1 <<= 2; u2 <<= 2;\n        }\n    }\n}\n\n__global__ void dequantize_q6_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long  block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){\n        float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);\n        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 208)));\n\n        const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size);\n        const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 128);\n        const int8_t  * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192);\n\n\n        for (int n = 0; n < ele_per_blk; n += 128) {\n            for (int l = 0; l < 32; ++l) {\n                int is = l/16;\n                const int8_t q1 = (int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;\n                const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;\n                const int8_t q3 = (int8_t)((ql[l +  0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;\n                const int8_t q4 = (int8_t)((ql[l + 32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;\n                output_blk[l +  0] = d * sc[is + 0] * q1;\n                output_blk[l + 32] = d * sc[is + 2] * q2;\n                output_blk[l + 64] = d * sc[is + 4] * q3;\n                output_blk[l + 96] = d * sc[is + 6] * q4;\n            }\n            output_blk += 128;\n            ql += 64;\n            qh += 32;\n            sc += 8;\n        }\n    }\n}\n\n__global__ void dequantize_q6_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long 
long  block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){\n        __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);\n        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 208)));\n\n        const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size);\n        const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 128);\n        const int8_t  * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192);\n\n\n        for (int n = 0; n < ele_per_blk; n += 128) {\n            for (int l = 0; l < 32; ++l) {\n                int is = l/16;\n                const int8_t q1 = (int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;\n                const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;\n                const int8_t q3 = (int8_t)((ql[l +  0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;\n                const int8_t q4 = (int8_t)((ql[l + 32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;\n                output_blk[l +  0] = __float2half(d * sc[is + 0] * q1);\n                output_blk[l + 32] = __float2half(d * sc[is + 2] * q2);\n                output_blk[l + 64] = __float2half(d * sc[is + 4] * q3);\n                output_blk[l + 96] = __float2half(d * sc[is + 6] * q4);\n            }\n            output_blk += 128;\n            ql += 64;\n            qh += 32;\n            sc += 8;\n        }\n    }\n}\n\n__global__ void dequantize_q6_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long  block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){\n        nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);\n        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * 
blk_size + 208)));\n\n        const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size);\n        const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 128);\n        const int8_t  * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192);\n\n\n        for (int n = 0; n < ele_per_blk; n += 128) {\n            for (int l = 0; l < 32; ++l) {\n                int is = l/16;\n                const int8_t q1 = (int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;\n                const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;\n                const int8_t q3 = (int8_t)((ql[l +  0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;\n                const int8_t q4 = (int8_t)((ql[l + 32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;\n                output_blk[l +  0] = __float2bfloat16(d * sc[is + 0] * q1);\n                output_blk[l + 32] = __float2bfloat16(d * sc[is + 2] * q2);\n                output_blk[l + 64] = __float2bfloat16(d * sc[is + 4] * q3);\n                output_blk[l + 96] = __float2bfloat16(d * sc[is + 6] * q4);\n            }\n            output_blk += 128;\n            ql += 64;\n            qh += 32;\n            sc += 8;\n        }\n    }\n}\n\nstatic constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};\n\n__global__ void dequantize_iq4_xs_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) {\n        float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);\n        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size)));\n        const uint16_t scales_h = *(reinterpret_cast<const uint16_t*>(data + block_id * blk_size + 2));\n       
 const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2);\n        const uint8_t* qs = (uint8_t*)(data + block_id * blk_size + 2 + 2 + 4);\n\n        for (int ib = 0; ib < 8; ++ib) {\n            const int ls = ((scales_l[ib / 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h >> 2 * ib) & 3) << 4);\n            const float dl = d * (ls - 32);\n            for (int j = 0; j < 16; ++j) {\n                output_blk[j + 0] = dl * kvalues_iq4nl[qs[j] & 0xf];\n                output_blk[j + 16] = dl * kvalues_iq4nl[qs[j] >> 4];\n            }\n            output_blk += 32;\n            qs += 16;\n        }\n    }\n}\n\n__global__ void dequantize_iq4_xs_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) {\n        __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);\n        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size)));\n        const uint16_t scales_h = *(reinterpret_cast<const uint16_t*>(data + block_id * blk_size + 2));\n        const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2);\n        const uint8_t* qs = (uint8_t*)(data + block_id * blk_size + 2 + 2 + 4);\n\n        for (int ib = 0; ib < 8; ++ib) {\n            const int ls = ((scales_l[ib / 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h >> 2 * ib) & 3) << 4);\n            const float dl = d * (ls - 32);\n            for (int j = 0; j < 16; ++j) {\n                output_blk[j + 0] = __float2half(dl * kvalues_iq4nl[qs[j] & 0xf]);\n                output_blk[j + 16] = __float2half(dl * kvalues_iq4nl[qs[j] >> 4]);\n            }\n            output_blk += 32;\n            qs += 16;\n        }\n    }\n}\n\n__global__ void dequantize_iq4_xs_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int 
blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) {\n        nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);\n        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size)));\n        const uint16_t scales_h = *(reinterpret_cast<const uint16_t*>(data + block_id * blk_size + 2));\n        const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2);\n        const uint8_t* qs = (uint8_t*)(data + block_id * blk_size + 2 + 2 + 4);\n\n        for (int ib = 0; ib < 8; ++ib) {\n            const int ls = ((scales_l[ib / 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h >> 2 * ib) & 3) << 4);\n            const float dl = d * (ls - 32);\n            for (int j = 0; j < 16; ++j) {\n                output_blk[j + 0] = __float2bfloat16(dl * kvalues_iq4nl[qs[j] & 0xf]);\n                output_blk[j + 16] = __float2bfloat16(dl * kvalues_iq4nl[qs[j] >> 4]);\n            }\n            output_blk += 32;\n            qs += 16;\n        }\n    }\n}\n\ntorch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {\n    int num_blocks = num_bytes / blk_size;\n    const at::cuda::OptionalCUDAGuard device_guard(device);\n\n    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);\n    auto data_gpu = torch::empty({ num_bytes }, options);\n\n    cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);\n    //data_gpu.copy_(data, false);\n\n    // Create output tensor\n    auto output = torch::zeros({ num_blocks, 32 }, torch::dtype(target_dtype).device(device));\n\n    switch (target_dtype) {\n        case torch::kFloat16:\n            
dequantize_q8_0_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kBFloat16:\n            dequantize_q8_0_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kFloat32:\n            dequantize_q8_0_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);\n            break;\n        default:\n            printf(\"target type not supported\\n\");\n            exit(1);\n    }\n\n    cudaDeviceSynchronize();\n    return output;\n}\n\n\ntorch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {\n    // data.numel%blk_size should be 0, else raise err\n    int num_blocks = num_bytes / blk_size;\n\n    const at::cuda::OptionalCUDAGuard device_guard(device);\n    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);\n    auto data_gpu = torch::empty({num_bytes}, options);\n\n    cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);\n    //data_gpu.copy_(data, false);\n\n    // Create output tensor\n    auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));\n\n    switch (target_dtype) {\n        case torch::kFloat16:\n            dequantize_q6_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kBFloat16:\n            dequantize_q6_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kFloat32:\n            
dequantize_q6_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);\n            break;\n        default:\n            printf(\"target type not supported\\n\");\n            exit(1);\n    }\n    cudaDeviceSynchronize();\n    return output;\n}\n\ntorch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {\n    int num_blocks = num_bytes / blk_size;\n    const at::cuda::OptionalCUDAGuard device_guard(device);\n\n    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);\n    auto data_gpu = torch::empty({num_bytes}, options);\n\n    cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);\n    //data_gpu.copy_(data, false);\n\n    // Create output tensor\n    auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));\n\n    switch (target_dtype) {\n        case torch::kFloat16:\n            dequantize_q5_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kBFloat16:\n            dequantize_q5_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kFloat32:\n            dequantize_q5_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);\n            break;\n        default:\n            printf(\"target type not supported\\n\");\n            exit(1);\n    }\n    cudaDeviceSynchronize();\n    return output;\n}\n\ntorch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {\n    // 
data.numel%blk_size should be 0, else raise err\n    int num_blocks = num_bytes / blk_size;\n    const at::cuda::OptionalCUDAGuard device_guard(device);\n\n    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);\n    auto data_gpu = torch::empty({num_bytes}, options);\n\n    cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);\n    //data_gpu.copy_(data, false);\n\n    // Create output tensor\n    auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));\n\n    switch (target_dtype) {\n        case torch::kFloat16:\n            dequantize_q4_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kBFloat16:\n            dequantize_q4_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kFloat32:\n            dequantize_q4_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);\n            break;\n        default:\n            printf(\"target type not supported\\n\");\n            exit(1);\n    }\n    cudaDeviceSynchronize();\n    return output;\n}\n\ntorch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {\n    int num_blocks = num_bytes / blk_size;\n    const at::cuda::OptionalCUDAGuard device_guard(device);\n\n    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);\n    auto data_gpu = torch::empty({num_bytes}, options);\n\n    cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);\n    //data_gpu.copy_(data, false);\n\n    // Create output 
tensor\n    auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));\n\n    switch (target_dtype) {\n        case torch::kFloat16:\n            dequantize_q3_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kBFloat16:\n            dequantize_q3_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kFloat32:\n            dequantize_q3_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);\n            break;\n        default:\n            printf(\"target type not support\\n\");\n            exit(0);\n    }\n    cudaDeviceSynchronize();\n    return output;\n}\n\ntorch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {\n    int num_blocks = num_bytes / blk_size;\n    const at::cuda::OptionalCUDAGuard device_guard(device);\n\n    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);\n    auto data_gpu = torch::empty({num_bytes}, options);\n\n    cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);\n    //data_gpu.copy_(data, false);\n\n    // Create output tensor\n    auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));\n\n    switch (target_dtype) {\n        case torch::kFloat16:\n            dequantize_q2_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kBFloat16:\n            dequantize_q2_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, 
ele_per_blk, num_blocks);\n            break;\n        case torch::kFloat32:\n            dequantize_q2_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);\n            break;\n        default:\n            printf(\"target type not support\\n\");\n            exit(0);\n    }\n    cudaDeviceSynchronize();\n    return output;\n}\n\ntorch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {\n    int num_blocks = num_bytes / blk_size;\n    const at::cuda::OptionalCUDAGuard device_guard(device);\n\n    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);\n    auto data_gpu = torch::empty({num_bytes}, options);\n\n    cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);\n    //data_gpu.copy_(data, false);\n\n    // Create output tensor\n    auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));\n\n    switch (target_dtype) {\n        case torch::kFloat16:\n            dequantize_iq4_xs_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kBFloat16:\n            dequantize_iq4_xs_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kFloat32:\n            dequantize_iq4_xs_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);\n            break;\n        default:\n            printf(\"target type not support\\n\");\n            exit(0);\n    }\n    cudaDeviceSynchronize();\n    return output;\n}\n"
  },
  {
    "path": "archive/csrc/ktransformers_ext/cuda/custom_gguf/ops.h",
    "content": "/**\n * @Description  :\n * @Author       : Azure-Tang\n * @Date         : 2024-07-22 09:27:55\n * @Version      : 1.0.0\n * @LastEditors  : kkk1nak0\n * @LastEditTime : 2024-08-12 03:48:46\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n**/\n#pragma once\n\n#include <torch/library.h>\n#include <torch/extension.h>\n#include <torch/torch.h>\n\ntorch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);\ntorch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);\ntorch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);\ntorch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);\ntorch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);\ntorch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);\ntorch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);"
  },
  {
    "path": "archive/csrc/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu",
    "content": "/*\n * Modified by Neural Magic\n * Copyright (C) Marlin.2024 Elias Frantar\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *         http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n/*\n * Adapted from https://github.com/IST-DASLab/marlin\n */\n/*\n * Adapted from  https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin\n */\n#include \"gptq_marlin.cuh\"\n#include \"gptq_marlin_dtypes.cuh\"\n#include <c10/cuda/CUDAGuard.h>\n#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t)               \\\n  static_assert(std::is_same<scalar_t, half>::value ||          \\\n                    std::is_same<scalar_t, nv_bfloat16>::value, \\\n                \"only float16 and bfloat16 is supported\");\n\ntemplate <typename T>\ninline std::string str(T x) {\n  return std::to_string(x);\n}\n\nnamespace gptq_marlin {\n\n#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined(__HIP_PLATFORM_AMD__)\n\n__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,\n                                    int const* __restrict__ perm_int_ptr,\n                                    int4* __restrict__ out_int4_ptr, int size_m,\n                                    int size_k, int block_rows) {}\n\ntemplate <typename scalar_t,          // compute dtype, half or nv_float16\n          const int num_bits,         // number of bits used for weights\n          const int threads,          // number of threads in a threadblock\n          const int thread_m_blocks,  // number of 
16x16 blocks in the m\n                                      // dimension (batchsize) of the\n                                      // threadblock\n          const int thread_n_blocks,  // same for n dimension (output)\n          const int thread_k_blocks,  // same for k dimension (reduction)\n          const int stages,  // number of stages for the async global->shared\n                             // fetch pipeline\n          const bool has_act_order,    // whether act_order is enabled\n          const int group_blocks = -1  // number of consecutive 16x16 blocks\n                                       // with a separate quantization scale\n          >\n__global__ void Marlin(\n    const int4* __restrict__ A,  // fp16 input matrix of shape mxk\n    const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn\n    int4* __restrict__ C,        // fp16 output buffer of shape mxn\n    const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape\n                                          // (k/groupsize)xn\n    const int* __restrict__ g_idx,        // int32 group indices of shape k\n    int num_groups,  // number of scale groups per output channel\n    int prob_m,      // batch dimension m\n    int prob_n,      // output dimension n\n    int prob_k,      // reduction dimension k\n    int* locks       // extra global storage for barrier synchronization\n) {}\n\n}  // namespace gptq_marlin\n\ntorch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,\n                               torch::Tensor& b_scales, torch::Tensor& g_idx,\n                               torch::Tensor& perm, torch::Tensor& workspace,\n                               int64_t num_bits, int64_t size_m, int64_t size_n,\n                               int64_t size_k, bool is_k_full) {\n  TORCH_CHECK_NOT_IMPLEMENTED(false,\n                              \"marlin_gemm(..) 
requires CUDA_ARCH >= 8.0\");\n  return torch::empty({1, 1});\n}\n\n#else\n\n// m16n8k16 tensor core mma instruction with fp16 inputs and fp32\n// output/accumulation.\ntemplate <typename scalar_t>\n__device__ inline void mma(const typename ScalarType<scalar_t>::FragA& a_frag,\n                           const typename ScalarType<scalar_t>::FragB& frag_b,\n                           typename ScalarType<scalar_t>::FragC& frag_c) {\n  const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);\n  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);\n  float* c = reinterpret_cast<float*>(&frag_c);\n  if constexpr (std::is_same<scalar_t, half>::value) {\n    asm volatile(\n        \"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \"\n        \"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\\n\"\n        : \"=f\"(c[0]), \"=f\"(c[1]), \"=f\"(c[2]), \"=f\"(c[3])\n        : \"r\"(a[0]), \"r\"(a[1]), \"r\"(a[2]), \"r\"(a[3]), \"r\"(b[0]), \"r\"(b[1]),\n          \"f\"(c[0]), \"f\"(c[1]), \"f\"(c[2]), \"f\"(c[3]));\n  } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {\n    asm volatile(\n        \"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \"\n        \"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\\n\"\n        : \"=f\"(c[0]), \"=f\"(c[1]), \"=f\"(c[2]), \"=f\"(c[3])\n        : \"r\"(a[0]), \"r\"(a[1]), \"r\"(a[2]), \"r\"(a[3]), \"r\"(b[0]), \"r\"(b[1]),\n          \"f\"(c[0]), \"f\"(c[1]), \"f\"(c[2]), \"f\"(c[3]));\n  } else {\n    STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);\n  }\n}\n\n// Instruction for loading a full 16x16 matrix fragment of operand A from shared\n// memory, directly in tensor core layout.\ntemplate <typename scalar_t>\n__device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA& frag_a,\n                             const void* smem_ptr) {\n  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);\n  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));\n  asm 
volatile(\"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\\n\"\n               : \"=r\"(a[0]), \"=r\"(a[1]), \"=r\"(a[2]), \"=r\"(a[3])\n               : \"r\"(smem));\n}\n\n// Lookup-table based 3-input logical operation; explicitly used for\n// dequantization as the compiler does not seem to automatically recognize it in\n// all cases.\ntemplate <int lut>\n__device__ inline int lop3(int a, int b, int c) {\n  int res;\n  asm volatile(\"lop3.b32 %0, %1, %2, %3, %4;\\n\"\n               : \"=r\"(res)\n               : \"r\"(a), \"r\"(b), \"r\"(c), \"n\"(lut));\n  return res;\n}\n\n// Constructs destination register by taking bytes from 2 sources (based on\n// mask)\ntemplate <int start_byte, int mask>\n__device__ inline uint32_t prmt(uint32_t a) {\n  uint32_t res;\n  asm volatile(\"prmt.b32 %0, %1, %2, %3;\\n\"\n               : \"=r\"(res)\n               : \"r\"(a), \"n\"(start_byte), \"n\"(mask));\n  return res;\n}\n\n// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16\n// values. 
We mostly follow the strategy in the link below, with some small\n// changes:\n// - FP16:\n// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287\n// - BF16:\n// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385\ntemplate <typename scalar_t>\n__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit(int q) {\n  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);\n}\n\ntemplate <>\n__device__ inline typename ScalarType<half>::FragB dequant_4bit<half>(int q) {\n  const int LO = 0x000f000f;\n  const int HI = 0x00f000f0;\n  const int EX = 0x64006400;\n  // Guarantee that the `(a & b) | c` operations are LOP3s.\n  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);\n  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);\n  // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point\n  // directly into `SUB` and `ADD`.\n  const int SUB = 0x64086408;\n  const int MUL = 0x2c002c00;\n  const int ADD = 0xd480d480;\n  typename ScalarType<half>::FragB frag_b;\n  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),\n                      *reinterpret_cast<const half2*>(&SUB));\n  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),\n                      *reinterpret_cast<const half2*>(&MUL),\n                      *reinterpret_cast<const half2*>(&ADD));\n  return frag_b;\n}\n\ntemplate <>\n__device__ inline typename ScalarType<nv_bfloat16>::FragB\ndequant_4bit<nv_bfloat16>(int q) {\n  static constexpr uint32_t MASK = 0x000f000f;\n  static constexpr uint32_t EX = 0x43004300;\n\n  // Guarantee that the `(a & b) | c` operations are LOP3s.\n\n  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);\n  q >>= 4;\n  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);\n\n  typename ScalarType<nv_bfloat16>::FragB frag_b;\n  static 
constexpr uint32_t MUL = 0x3F803F80;\n  static constexpr uint32_t ADD = 0xC308C308;\n\n  frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),\n                      *reinterpret_cast<const nv_bfloat162*>(&MUL),\n                      *reinterpret_cast<const nv_bfloat162*>(&ADD));\n  frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),\n                      *reinterpret_cast<const nv_bfloat162*>(&MUL),\n                      *reinterpret_cast<const nv_bfloat162*>(&ADD));\n  return frag_b;\n}\n\n// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or\n// bf16 Reference:\n// - FP16:\n// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85\n// - BF16:\n// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175\ntemplate <typename scalar_t>\n__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit(int q) {\n  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);\n}\n\ntemplate <>\n__device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) {\n  static constexpr uint32_t mask_for_elt_01 = 0x5250;\n  static constexpr uint32_t mask_for_elt_23 = 0x5351;\n  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;\n\n  uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);\n  uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);\n\n  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;\n\n  typename ScalarType<half>::FragB frag_b;\n  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),\n                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));\n  frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),\n                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));\n  return frag_b;\n}\n\ntemplate <>\n__device__ inline 
typename ScalarType<nv_bfloat16>::FragB\ndequant_8bit<nv_bfloat16>(int q) {\n  typename ScalarType<nv_bfloat16>::FragB frag_b;\n\n  float fp32_intermediates[4];\n  uint32_t* fp32_intermediates_casted =\n      reinterpret_cast<uint32_t*>(fp32_intermediates);\n\n  static constexpr uint32_t fp32_base = 0x4B000000;\n  fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);\n  fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);\n  fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);\n  fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);\n\n  fp32_intermediates[0] -= 8388736.f;\n  fp32_intermediates[1] -= 8388736.f;\n  fp32_intermediates[2] -= 8388736.f;\n  fp32_intermediates[3] -= 8388736.f;\n\n  uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);\n  bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],\n                                   fp32_intermediates_casted[1], 0x7632);\n  bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],\n                                   fp32_intermediates_casted[3], 0x7632);\n\n  return frag_b;\n}\n\n// Multiply dequantized values by the corresponding quantization scale; used\n// only for grouped quantization.\ntemplate <typename scalar_t>\n__device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,\n                             typename ScalarType<scalar_t>::FragS& frag_s,\n                             int i) {\n  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;\n  scalar_t2 s =\n      ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_s)[i]);\n  frag_b[0] = __hmul2(frag_b[0], s);\n  frag_b[1] = __hmul2(frag_b[1], s);\n}\n\n// Same as above, but for act_order (each K is multiplied individually)\ntemplate <typename scalar_t>\n__device__ inline void scale4(typename ScalarType<scalar_t>::FragB& frag_b,\n                              typename ScalarType<scalar_t>::FragS& frag_s_1,\n                              typename 
ScalarType<scalar_t>::FragS& frag_s_2,\n                              typename ScalarType<scalar_t>::FragS& frag_s_3,\n                              typename ScalarType<scalar_t>::FragS& frag_s_4,\n                              int i) {\n  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;\n  scalar_t2 s_val_1_2;\n  s_val_1_2.x = reinterpret_cast<scalar_t*>(&frag_s_1)[i];\n  s_val_1_2.y = reinterpret_cast<scalar_t*>(&frag_s_2)[i];\n\n  scalar_t2 s_val_3_4;\n  s_val_3_4.x = reinterpret_cast<scalar_t*>(&frag_s_3)[i];\n  s_val_3_4.y = reinterpret_cast<scalar_t*>(&frag_s_4)[i];\n\n  frag_b[0] = __hmul2(frag_b[0], s_val_1_2);\n  frag_b[1] = __hmul2(frag_b[1], s_val_3_4);\n}\n\n// Given 2 floats multiply by 2 scales (halves)\ntemplate <typename scalar_t>\n__device__ inline void scale_float(float* c,\n                                   typename ScalarType<scalar_t>::FragS& s) {\n  scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s);\n  c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0]));\n  c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1]));\n}\n\n// Wait until barrier reaches `count`, then lock for current threadblock.\n__device__ inline void barrier_acquire(int* lock, int count) {\n  if (threadIdx.x == 0) {\n    int state = -1;\n    do\n      // Guarantee that subsequent writes by this threadblock will be visible\n      // globally.\n      asm volatile(\"ld.global.acquire.gpu.b32 %0, [%1];\\n\"\n                   : \"=r\"(state)\n                   : \"l\"(lock));\n    while (state != count);\n  }\n  __syncthreads();\n}\n\n// Release barrier and increment visitation count.\n__device__ inline void barrier_release(int* lock, bool reset = false) {\n  __syncthreads();\n  if (threadIdx.x == 0) {\n    if (reset) {\n      lock[0] = 0;\n      return;\n    }\n    int val = 1;\n    // Make sure that all writes since acquiring this barrier are visible\n    // globally, while releasing the barrier.\n    asm volatile(\"fence.acq_rel.gpu;\\n\");\n 
   asm volatile(\"red.relaxed.gpu.global.add.s32 [%0], %1;\\n\"\n                 :\n                 : \"l\"(lock), \"r\"(val));\n  }\n}\n\n// For a given \"a\" of size [M,K] performs a permutation of the K columns based\n// on the given \"perm\" indices.\n__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,\n                                    int const* __restrict__ perm_int_ptr,\n                                    int4* __restrict__ out_int4_ptr, int size_m,\n                                    int size_k, int block_rows) {\n  int start_row = block_rows * blockIdx.x;\n  int finish_row = start_row + block_rows;\n  if (finish_row > size_m) {\n    finish_row = size_m;\n  }\n  int cur_block_rows = finish_row - start_row;\n\n  int row_stride = size_k * sizeof(half) / 16;\n\n  auto permute_row = [&](int row) {\n    int iters = size_k / default_threads;\n    int rest = size_k % default_threads;\n\n    int offset = row * row_stride;\n\n    half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);\n    half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);\n\n    int base_k = 0;\n\n    for (int i = 0; i < iters; i++) {\n      int cur_k = base_k + threadIdx.x;\n      int src_pos = perm_int_ptr[cur_k];\n\n      out_half[cur_k] = a_row_half[src_pos];\n\n      base_k += default_threads;\n    }\n\n    if (rest) {\n      if (threadIdx.x < rest) {\n        int cur_k = base_k + threadIdx.x;\n        int src_pos = perm_int_ptr[cur_k];\n\n        out_half[cur_k] = a_row_half[src_pos];\n      }\n    }\n  };\n\n  for (int i = 0; i < cur_block_rows; i++) {\n    int cur_row = start_row + i;\n    if (cur_row < size_m) {\n      permute_row(cur_row);\n    }\n  }\n}\n\ntemplate <typename scalar_t,          // compute dtype, half or nv_bfloat16\n          const int num_bits,         // number of bits used for weights\n          const int threads,          // number of threads in a threadblock\n          const int thread_m_blocks,  // number 
of 16x16 blocks in the m\n                                      // dimension (batchsize) of the\n                                      // threadblock\n          const int thread_n_blocks,  // same for n dimension (output)\n          const int thread_k_blocks,  // same for k dimension (reduction)\n          const int stages,  // number of stages for the async global->shared\n                             // fetch pipeline\n          const bool has_act_order,    // whether act_order is enabled\n          const int group_blocks = -1  // number of consecutive 16x16 blocks\n                                       // with a separate quantization scale\n          >\n__global__ void Marlin(\n    const int4* __restrict__ A,  // fp16 input matrix of shape mxk\n    const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn\n    int4* __restrict__ C,        // fp16 output buffer of shape mxn\n    const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape\n                                          // (k/groupsize)xn\n    const int* __restrict__ g_idx,        // int32 group indices of shape k\n    int num_groups,  // number of scale groups per output channel\n    int prob_m,      // batch dimension m\n    int prob_n,      // output dimension n\n    int prob_k,      // reduction dimension k\n    int* locks       // extra global storage for barrier synchronization\n) {\n  // Each threadblock processes one \"stripe\" of the B matrix with (roughly) the\n  // same size, which might involve multiple column \"slices\" (of width 16 *\n  // `thread_n_blocks`). 
Stripes are defined as shown in the 3x3 matrix 5 SM\n  // example:\n  //   0 1 3\n  //   0 2 3\n  //   1 2 4\n  // While this kind of partitioning makes things somewhat more complicated, it\n  // ensures good utilization of all SMs for many kinds of shape and GPU\n  // configurations, while requiring as few slow global cross-threadblock\n  // reductions as possible.\n  using Dtype = ScalarType<scalar_t>;\n  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;\n  using FragA = typename ScalarType<scalar_t>::FragA;\n  using FragB = typename ScalarType<scalar_t>::FragB;\n  using FragC = typename ScalarType<scalar_t>::FragC;\n  using FragS = typename ScalarType<scalar_t>::FragS;\n\n  constexpr int pack_factor = 32 / num_bits;\n\n  // For larger GEMMs we run multiple batchsize 64 versions in parallel for a\n  // better partitioning with less reductions\n  int parallel = 1;\n  if (prob_m > 16 * thread_m_blocks) {\n    parallel = prob_m / (16 * thread_m_blocks);\n    prob_m = 16 * thread_m_blocks;\n  }\n\n  int k_tiles = prob_k / 16 / thread_k_blocks;\n  int n_tiles = prob_n / 16 / thread_n_blocks;\n  int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x);\n\n  if constexpr (!has_act_order && group_blocks != -1) {\n    if (group_blocks >= thread_k_blocks) {\n      // Ensure that the number of tiles in each stripe is a multiple of the\n      // groupsize; this avoids an annoying special case where a stripe starts\n      // in the middle of group.\n      iters = (group_blocks / thread_k_blocks) *\n              div_ceil(iters, (group_blocks / thread_k_blocks));\n    }\n  }\n\n  int slice_row = (iters * blockIdx.x) % k_tiles;\n  int slice_col_par = (iters * blockIdx.x) / k_tiles;\n  int slice_col = slice_col_par;\n  int slice_iters;  // number of threadblock tiles in the current slice\n  int slice_count =\n      0;          // total number of active threadblocks in the current slice\n  int slice_idx;  // index of threadblock in current slice; numbered bottom 
to\n                  // top\n\n  // We can easily implement parallel problem execution by just remapping\n  // indices and advancing global pointers\n  if (slice_col_par >= n_tiles) {\n    A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;\n    C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;\n    locks += (slice_col_par / n_tiles) * n_tiles;\n    slice_col = slice_col_par % n_tiles;\n  }\n\n  // Compute all information about the current slice which is required for\n  // synchronization.\n  auto init_slice = [&]() {\n    slice_iters =\n        iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);\n    if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;\n    if (slice_iters == 0) return;\n    if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;\n    slice_count = 1;\n    slice_idx = 0;\n    int col_first = iters * div_ceil(k_tiles * slice_col_par, iters);\n    if (col_first <= k_tiles * (slice_col_par + 1)) {\n      int col_off = col_first - k_tiles * slice_col_par;\n      slice_count = div_ceil(k_tiles - col_off, iters);\n      if (col_off > 0) slice_count++;\n      int delta_first = iters * blockIdx.x - col_first;\n      if (delta_first < 0 || (col_off == 0 && delta_first == 0))\n        slice_idx = slice_count - 1;\n      else {\n        slice_idx = slice_count - 1 - delta_first / iters;\n        if (col_off > 0) slice_idx--;\n      }\n    }\n    if (slice_col == n_tiles) {\n      A += 16 * thread_m_blocks * prob_k / 8;\n      C += 16 * thread_m_blocks * prob_n / 8;\n      locks += n_tiles;\n      slice_col = 0;\n    }\n  };\n  init_slice();\n\n  // A sizes/strides\n\n  // stride of the A matrix in global memory\n  int a_gl_stride = prob_k / 8;\n  // stride of an A matrix tile in shared memory\n  constexpr int a_sh_stride = 16 * thread_k_blocks / 8;\n  // delta between subsequent A tiles in global memory\n  constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8;\n  // 
between subsequent accesses within a tile\n  int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);\n  // between shared memory writes\n  constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);\n  // between shared memory tile reads\n  constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));\n  // within a shared memory tile\n  constexpr int a_sh_rd_delta_i = a_sh_stride * 16;\n  // overall size of a tile\n  constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);\n  // number of shared write iterations for a tile\n  constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta);\n\n  // B sizes/strides\n  int b_gl_stride = 16 * prob_n / (pack_factor * 4);\n  constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;\n  constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2;\n  constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;\n\n  int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;\n  int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);\n  constexpr int b_sh_wr_delta = threads * b_thread_vecs;\n  constexpr int b_sh_rd_delta = threads * b_thread_vecs;\n  constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;\n  constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;\n\n  // Scale sizes/strides without act_order\n  int s_gl_stride = prob_n / 8;\n  constexpr int s_sh_stride = 16 * thread_n_blocks / 8;\n  constexpr int s_tb_groups =\n      !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks\n          ? thread_k_blocks / group_blocks\n          : 1;\n  constexpr int s_sh_stage = s_tb_groups * s_sh_stride;\n  int s_gl_rd_delta = s_gl_stride;\n\n  // Scale size/strides with act_order\n  constexpr int tb_k = 16 * thread_k_blocks;\n  constexpr int g_idx_stage = has_act_order ? 
(tb_k * sizeof(int)) / 16 : 0;\n  // constexpr int act_s_row_stride      = 1;\n  // int           act_s_col_stride      = act_s_row_stride * num_groups;\n  int act_s_col_stride = 1;\n  int act_s_col_warp_stride = act_s_col_stride * 8;\n  int tb_n_warps = thread_n_blocks / 4;\n  int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;\n\n  // Global A read index of current thread.\n  int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +\n                (threadIdx.x % a_gl_rd_delta_o);\n  a_gl_rd += a_gl_rd_delta_o * slice_row;\n  // Shared write index of current thread.\n  int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +\n                (threadIdx.x % a_gl_rd_delta_o);\n  // Shared read index.\n  int a_sh_rd =\n      a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;\n  a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));\n\n  int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +\n                (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;\n  b_gl_rd += b_sh_stride * slice_col;\n  b_gl_rd += b_gl_rd_delta_o * slice_row;\n  int b_sh_wr = threadIdx.x * b_thread_vecs;\n  int b_sh_rd = threadIdx.x * b_thread_vecs;\n\n  // For act_order\n  constexpr int k_iter_size = tb_k / b_sh_wr_iters;\n  int slice_k_start = tb_k * slice_row;\n  int slice_k_finish = slice_k_start + tb_k * slice_iters;\n  int slice_k_start_shared_fetch = slice_k_start;\n  int slice_n_offset = act_s_col_tb_stride * slice_col;\n\n  // No act_order\n  int s_gl_rd;\n  if constexpr (!has_act_order) {\n    if constexpr (group_blocks == -1) {\n      s_gl_rd = s_sh_stride * slice_col + threadIdx.x;\n    } else {\n      s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +\n                s_sh_stride * slice_col + threadIdx.x;\n    }\n  }\n  int s_sh_wr = threadIdx.x;\n  bool s_sh_wr_pred = threadIdx.x < s_sh_stride;\n\n  // We use a different scale layout for grouped and column-wise quantization as\n  // we scale a 
`half2` tile in column-major layout in the former and in\n  // row-major in the latter case.\n  int s_sh_rd;\n  if constexpr (group_blocks != -1)\n    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +\n              (threadIdx.x % 32) / 4;\n  else\n    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +\n              (threadIdx.x % 32) % 4;\n\n  // Precompute which thread should not read memory in which iterations; this is\n  // needed if there are more threads than required for a certain tilesize or\n  // when the batchsize is not a multiple of 16.\n  bool a_sh_wr_pred[a_sh_wr_iters];\n  #pragma unroll\n  for (int i = 0; i < a_sh_wr_iters; i++)\n    a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;\n\n  // To ensure that writing and reading A tiles to/from shared memory, the\n  // latter in fragment format, is fully bank conflict free, we need to use a\n  // rather fancy XOR-based layout. The key here is that neither reads nor\n  // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the\n  // same shared memory banks. 
Further, it seems (based on NSight-Compute) that\n  // each warp must also write a consecutive memory segment?\n  auto transform_a = [&](int i) {\n    int row = i / a_gl_rd_delta_o;\n    return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;\n  };\n  // Since the computation of this remapping is non-trivial and, due to our main\n  // loop unrolls, all shared memory accesses are static, we simply precompute\n  // both transformed reads and writes.\n  int a_sh_wr_trans[a_sh_wr_iters];\n  #pragma unroll\n  for (int i = 0; i < a_sh_wr_iters; i++)\n    a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);\n  int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];\n  #pragma unroll\n  for (int i = 0; i < b_sh_wr_iters; i++) {\n  #pragma unroll\n    for (int j = 0; j < thread_m_blocks; j++)\n      a_sh_rd_trans[i][j] =\n          transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);\n  }\n\n  // Since B-accesses have non-constant stride they have to be computed at\n  // runtime; we break dependencies between subsequent accesses within a tile\n  // by maintaining multiple pointers (we have enough registers), a tiny\n  // optimization.\n  const int4* B_ptr[b_sh_wr_iters];\n  #pragma unroll\n  for (int i = 0; i < b_sh_wr_iters; i++)\n    B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;\n\n  extern __shared__ int4 sh[];\n  // Shared memory storage for global fetch pipelines.\n  int4* sh_a = sh;\n  int4* sh_b = sh_a + (stages * a_sh_stage);\n  int4* sh_g_idx = sh_b + (stages * b_sh_stage);\n  int4* sh_s = sh_g_idx + (stages * g_idx_stage);\n\n  // Register storage for double buffer of shared memory reads.\n  FragA frag_a[2][thread_m_blocks];\n  I4 frag_b_quant[2][b_thread_vecs];\n  FragC frag_c[thread_m_blocks][4][2];\n  FragS frag_s[2][4];         // No act-order\n  FragS act_frag_s[2][4][4];  // For act-order\n\n  // Zero accumulators.\n  auto zero_accums = [&]() {\n  #pragma unroll\n    for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)\n      
reinterpret_cast<float*>(frag_c)[i] = 0;\n  };\n\n  int sh_first_group_id = -1;\n  int sh_num_groups = -1;\n  constexpr int sh_max_num_groups = 32;\n\n  auto fetch_scales_to_shared = [&](bool is_async, int first_group_id,\n                                    int last_group_id) {\n    sh_first_group_id = first_group_id;\n    sh_num_groups = last_group_id - first_group_id + 1;\n\n    if (sh_num_groups < sh_max_num_groups) {\n      sh_num_groups = sh_max_num_groups;\n    }\n\n    if (sh_first_group_id + sh_num_groups > num_groups) {\n      sh_num_groups = num_groups - sh_first_group_id;\n    }\n\n    int row_offset = first_group_id * s_gl_stride;\n\n    if (is_async) {\n      for (int i = 0; i < sh_num_groups; i++) {\n        if (threadIdx.x < s_sh_stride) {\n          cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x],\n                         &scales_ptr[row_offset + (i * s_gl_stride) +\n                                     slice_n_offset + threadIdx.x]);\n        }\n      }\n    } else {\n      for (int i = 0; i < sh_num_groups; i++) {\n        if (threadIdx.x < s_sh_stride) {\n          sh_s[(i * s_sh_stride) + threadIdx.x] =\n              scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset +\n                         threadIdx.x];\n        }\n      }\n    }\n  };\n  // Asynchronously fetch the next A, B and s tile from global to the next\n  // shared memory pipeline location.\n  auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {\n    if (pred) {\n      int4* sh_a_stage = sh_a + a_sh_stage * pipe;\n  #pragma unroll\n      for (int i = 0; i < a_sh_wr_iters; i++) {\n        cp_async4_pred(\n            &sh_a_stage[a_sh_wr_trans[i]],\n            &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],\n            a_sh_wr_pred[i]);\n      }\n      int4* sh_b_stage = sh_b + b_sh_stage * pipe;\n  #pragma unroll\n      for (int i = 0; i < b_sh_wr_iters; i++) {\n  #pragma unroll\n        for (int j = 0; j < b_thread_vecs; j++) {\n        
  cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);\n        }\n\n        B_ptr[i] += b_gl_rd_delta_o;\n      }\n\n      if constexpr (has_act_order) {\n        // Fetch g_idx thread-block portion\n        int full_pipe = a_off;\n        int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe;\n        if (cur_k < prob_k && cur_k < slice_k_finish) {\n          int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;\n\n          int4 const* cur_g_idx_stage_ptr =\n              reinterpret_cast<int4 const*>(&g_idx[cur_k]);\n\n          if (threadIdx.x < g_idx_stage) {\n            cp_async4_pred(&sh_g_idx_stage[threadIdx.x],\n                           &cur_g_idx_stage_ptr[threadIdx.x]);\n          }\n        }\n      } else {\n        if constexpr (group_blocks != -1) {\n          int4* sh_s_stage = sh_s + s_sh_stage * pipe;\n\n          if constexpr (group_blocks >= thread_k_blocks) {\n            // Only fetch scales if this tile starts a new group\n            if (pipe % (group_blocks / thread_k_blocks) == 0) {\n              if (s_sh_wr_pred) {\n                cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);\n              }\n              s_gl_rd += s_gl_rd_delta;\n            }\n          } else {\n            for (int i = 0; i < s_tb_groups; i++) {\n              if (s_sh_wr_pred) {\n                cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr],\n                          &scales_ptr[s_gl_rd]);\n              }\n              s_gl_rd += s_gl_rd_delta;\n            }\n          }\n        }\n      }\n    }\n    // Insert a fence even when we are winding down the pipeline to ensure that\n    // waiting is also correct at this point.\n    cp_async_fence();\n  };\n\n  // Wait until the next thread tile has been loaded to shared memory.\n  auto wait_for_stage = [&]() {\n    // We only have `stages - 2` active fetches since we are double buffering\n    // and can only issue the next fetch when it is guaranteed that the previous\n    
// shared memory load is fully complete (as it may otherwise be\n    // overwritten).\n    cp_async_wait<stages - 2>();\n    __syncthreads();\n  };\n\n  // Load the next sub-tile from the current location in the shared memory pipe\n  // into the current register buffer.\n  auto fetch_to_registers = [&](int k, int pipe) {\n    int4* sh_a_stage = sh_a + a_sh_stage * pipe;\n  #pragma unroll\n    for (int i = 0; i < thread_m_blocks; i++)\n      ldsm4<scalar_t>(frag_a[k % 2][i],\n                      &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);\n    int4* sh_b_stage = sh_b + b_sh_stage * pipe;\n\n  #pragma unroll\n    for (int i = 0; i < b_thread_vecs; i++) {\n      frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(\n          &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);\n    }\n  };\n\n  bool is_same_group[stages];\n  int same_group_id[stages];\n\n  auto init_same_group = [&](int pipe) {\n    if constexpr (!has_act_order) {\n      is_same_group[pipe] = false;\n      same_group_id[pipe] = 0;\n      return;\n    }\n\n    int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;\n    int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);\n\n    int group_id_1 = sh_g_idx_int_ptr[0];\n    int group_id_2 = sh_g_idx_int_ptr[tb_k - 1];\n\n    is_same_group[pipe] = group_id_1 == group_id_2;\n    same_group_id[pipe] = group_id_1;\n  };\n\n  auto fetch_scales_to_registers = [&](int k, int full_pipe) {\n    int pipe = full_pipe % stages;\n\n    if constexpr (!has_act_order) {\n      // No act-order case\n      if constexpr (group_blocks != -1) {\n        if constexpr (group_blocks >= thread_k_blocks) {\n          int4* sh_s_stage =\n              sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *\n                                   (pipe / (group_blocks / thread_k_blocks)));\n          reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];\n        } else {\n          int warp_id = threadIdx.x / 32;\n          int n_warps = 
thread_n_blocks / 4;\n\n          int warp_row = warp_id / n_warps;\n\n          int cur_k = warp_row * 16;\n          cur_k += k_iter_size * (k % b_sh_wr_iters);\n\n          int k_blocks = cur_k / 16;\n          int cur_group_id = k_blocks / group_blocks;\n\n          int4* sh_s_stage = sh_s + s_sh_stage * pipe;\n\n          reinterpret_cast<int4*>(&frag_s[k % 2])[0] =\n              sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];\n        }\n      }\n\n      return;\n    }\n\n    // Act-order case\n\n    // Determine K of the \"current\" thread-block\n    int cur_k = slice_k_start + tb_k * full_pipe;\n    if (cur_k >= prob_k || cur_k >= slice_k_finish) {\n      return;\n    }\n\n    // Reset (to current thread-block) since we read g_idx portion from the\n    // shared memory\n    cur_k = 0;\n\n    // Progress to current iteration\n    cur_k += k_iter_size * (k % b_sh_wr_iters);\n\n    // Determine \"position\" inside the thread-block (based on warp and\n    // thread-id)\n    int warp_id = threadIdx.x / 32;\n    int n_warps =\n        thread_n_blocks / 4;  // Each warp processes 4 16-size tiles over N\n\n    int warp_row = warp_id / n_warps;\n    int warp_col = warp_id % n_warps;\n\n    cur_k += warp_row * 16;\n\n    int th_id = threadIdx.x % 32;\n    cur_k += (th_id % 4) * 2;  // Due to tensor-core layout for fp16 B matrix\n\n    int s_col_shift =\n        /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) +\n        (th_id / 4) * act_s_col_stride;\n\n    if (is_same_group[pipe]) {\n      if (k % 2 == 0) {\n        *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =\n            sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride +\n                 s_col_shift];\n      } else {\n        *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =\n            *(reinterpret_cast<int4*>(&(act_frag_s[(k - 1) % 2][0][0])));\n      }\n\n      for (int i = 1; i < 4; i++) {\n        *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =\n           
 *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0])));\n      }\n      return;\n    }\n\n    int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;\n    int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);\n\n    constexpr int k_frag_offsets[4] = {0, 1, 8,\n                                       9};  // Tensor core offsets per thread\n\n  #pragma unroll\n    for (int i = 0; i < 4; i++) {\n      int actual_k = cur_k + k_frag_offsets[i];\n\n      int group_id = sh_g_idx_int_ptr[actual_k];\n      int rel_group_id = group_id - sh_first_group_id;\n\n      *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =\n          sh_s[rel_group_id * s_sh_stride + s_col_shift];\n    }\n  };\n\n  // Execute the actual tensor core matmul of a sub-tile.\n  auto matmul = [&](int k) {\n  // We have the m dimension as the inner loop in order to encourage overlapping\n  // dequantization and matmul operations.\n  #pragma unroll\n    for (int j = 0; j < 4; j++) {\n      FragB frag_b0;\n      FragB frag_b1;\n      if constexpr (num_bits == 4) {\n        int b_quant = frag_b_quant[k % 2][0][j];\n        int b_quant_shift = b_quant >> 8;\n\n        frag_b0 = dequant_4bit<scalar_t>(b_quant);\n        frag_b1 = dequant_4bit<scalar_t>(b_quant_shift);\n\n      } else {\n        int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);\n        int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];\n        int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];\n\n        frag_b0 = dequant_8bit<scalar_t>(b_quant_0);\n        frag_b1 = dequant_8bit<scalar_t>(b_quant_1);\n      }\n\n      // Apply scale to frag_b0\n      if constexpr (has_act_order) {\n        scale4<scalar_t>(frag_b0, act_frag_s[k % 2][0][j],\n                         act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j],\n                         act_frag_s[k % 2][3][j], 0);\n      } else {\n        if constexpr (group_blocks != -1) {\n          scale<scalar_t>(frag_b0, frag_s[k % 2][j], 0);\n        }\n      }\n\n      // Apply 
scale to frag_b1\n      if constexpr (has_act_order) {\n        scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j],\n                         act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j],\n                         act_frag_s[k % 2][3][j], 1);\n\n      } else {\n        if constexpr (group_blocks != -1) {\n          scale<scalar_t>(frag_b1, frag_s[k % 2][j], 1);\n        }\n      }\n\n  #pragma unroll\n      for (int i = 0; i < thread_m_blocks; i++) {\n        mma<scalar_t>(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);\n        mma<scalar_t>(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);\n      }\n    }\n  };\n\n  // Since we slice across the k dimension of a tile in order to increase the\n  // number of warps while keeping the n dimension of a tile reasonable, we have\n  // multiple warps that accumulate their partial sums of the same output\n  // location, which we have to reduce over in the end. We do this in shared\n  // memory.\n  auto thread_block_reduce = [&]() {\n    constexpr int red_off = threads / b_sh_stride_threads / 2;\n    if (red_off >= 1) {\n      int red_idx = threadIdx.x / b_sh_stride_threads;\n      constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;\n      constexpr int red_sh_delta = b_sh_stride_threads;\n      int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +\n                      (threadIdx.x % b_sh_stride_threads);\n\n      // Parallel logarithmic shared memory reduction. 
We make sure to avoid any\n      // unnecessary read or write iterations, e.g., for two warps we write only\n      // once by warp 1 and read only once by warp 0.\n\n  #pragma unroll\n      for (int m_block = 0; m_block < thread_m_blocks; m_block++) {\n  #pragma unroll\n        for (int i = red_off; i > 0; i /= 2) {\n          if (i <= red_idx && red_idx < 2 * i) {\n  #pragma unroll\n            for (int j = 0; j < 4 * 2; j++) {\n              int red_sh_wr =\n                  red_sh_delta * j + (red_sh_rd - red_sh_stride * i);\n              if (i < red_off) {\n                float* c_rd =\n                    reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);\n                float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);\n  #pragma unroll\n                for (int k = 0; k < 4; k++)\n                  reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=\n                      c_rd[k] + c_wr[k];\n              }\n              sh[red_sh_wr] =\n                  reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];\n            }\n          }\n          __syncthreads();\n        }\n        if (red_idx == 0) {\n  #pragma unroll\n          for (int i = 0; i < 4 * 2; i++) {\n            float* c_rd =\n                reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);\n  #pragma unroll\n            for (int j = 0; j < 4; j++)\n              reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=\n                  c_rd[j];\n          }\n        }\n        __syncthreads();\n      }\n    }\n  };\n\n  // Since multiple threadblocks may process parts of the same column slice, we\n  // finally have to globally reduce over the results. 
As the striped\n  // partitioning minimizes the number of such reductions and our outputs are\n  // usually rather small, we perform this reduction serially in L2 cache.\n  auto global_reduce = [&](bool first = false, bool last = false) {\n    // We are very careful here to reduce directly in the output buffer to\n    // maximize L2 cache utilization in this step. To do this, we write out\n    // results in FP16 (but still reduce with FP32 compute).\n    constexpr int active_threads = 32 * thread_n_blocks / 4;\n    if (threadIdx.x < active_threads) {\n      int c_gl_stride = prob_n / 8;\n      int c_gl_wr_delta_o = 8 * c_gl_stride;\n      int c_gl_wr_delta_i = 4 * (active_threads / 32);\n      int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +\n                    4 * (threadIdx.x / 32) + threadIdx.x % 4;\n      c_gl_wr += (2 * thread_n_blocks) * slice_col;\n      constexpr int c_sh_wr_delta = active_threads;\n      int c_sh_wr = threadIdx.x;\n\n      int row = (threadIdx.x % 32) / 4;\n\n      if (!first) {\n  // Interestingly, doing direct global accesses here really seems to mess up\n  // the compiler and lead to slowdowns, hence we also use async-copies even\n  // though these fetches are not actually asynchronous.\n  #pragma unroll\n        for (int i = 0; i < thread_m_blocks * 4; i++) {\n          cp_async4_pred(\n              &sh[c_sh_wr + c_sh_wr_delta * i],\n              &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +\n                 c_gl_wr_delta_i * (i % 2)],\n              i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);\n        }\n        cp_async_fence();\n        cp_async_wait<0>();\n      }\n\n  #pragma unroll\n      for (int i = 0; i < thread_m_blocks * 4; i++) {\n        if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {\n          if (!first) {\n            int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];\n  #pragma unroll\n            for (int j = 0; j < 2 * 4; j++) {\n              reinterpret_cast<float*>(\n          
        &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=\n                  Dtype::num2float(reinterpret_cast<scalar_t*>(&c_red)[j]);\n            }\n          }\n          if (!last) {\n            int4 c;\n  #pragma unroll\n            for (int j = 0; j < 2 * 4; j++) {\n              reinterpret_cast<scalar_t*>(&c)[j] =\n                  Dtype::float2num(reinterpret_cast<float*>(\n                      &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);\n            }\n            C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =\n                c;\n          }\n        }\n      }\n    }\n  };\n\n  // Write out the reduced final result in the correct layout. We only actually\n  // reshuffle matrix fragments in this step, the reduction above is performed\n  // in fragment layout.\n  auto write_result = [&]() {\n    int c_gl_stride = prob_n / 8;\n    constexpr int c_sh_stride = 2 * thread_n_blocks + 1;\n    int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));\n    constexpr int c_sh_rd_delta =\n        c_sh_stride * (threads / (2 * thread_n_blocks));\n\n    int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +\n                  (threadIdx.x % (2 * thread_n_blocks));\n    c_gl_wr += (2 * thread_n_blocks) * slice_col;\n    int c_sh_wr =\n        (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;\n    c_sh_wr += 32 * (threadIdx.x / 32);\n    int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +\n                  (threadIdx.x % (2 * thread_n_blocks));\n\n    int c_gl_wr_end = c_gl_stride * prob_m;\n\n    // We first reorder in shared memory to guarantee the most efficient final\n    // global write patterns\n    auto write = [&](int idx, float c0, float c1, FragS& s) {\n      scalar_t2 res =\n          Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));\n\n      // For per-column quantization we finally apply the scale here (only for\n      // 4-bit)\n      if constexpr 
(!has_act_order && group_blocks == -1 && num_bits == 4) {\n        res = __hmul2(res, s[0]);\n      }\n\n      ((scalar_t2*)sh)[idx] = res;\n    };\n\n    if (threadIdx.x / 32 < thread_n_blocks / 4) {\n  #pragma unroll\n      for (int i = 0; i < thread_m_blocks; i++) {\n  #pragma unroll\n        for (int j = 0; j < 4; j++) {\n          int wr = c_sh_wr + 8 * j;\n          write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],\n                frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);\n          write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],\n                frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);\n          write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],\n                frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);\n          write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],\n                frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);\n        }\n        c_sh_wr += 16 * (4 * c_sh_stride);\n      }\n    }\n    __syncthreads();\n\n  #pragma unroll\n    for (int i = 0;\n         i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));\n         i++) {\n      if (c_gl_wr < c_gl_wr_end) {\n        C[c_gl_wr] = sh[c_sh_rd];\n        c_gl_wr += c_gl_wr_delta;\n        c_sh_rd += c_sh_rd_delta;\n      }\n    }\n  };\n\n  // Start global fetch and register load pipelines.\n  auto start_pipes = [&]() {\n\n  #pragma unroll\n    for (int i = 0; i < stages - 1; i++) {\n      if (has_act_order && i == 0) {\n        int last_g_idx = slice_k_start + stages * tb_k * 2;\n        if (last_g_idx >= prob_k) {\n          last_g_idx = prob_k - 1;\n        }\n        fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);\n      }\n      fetch_to_shared(i, i, i < slice_iters);\n    }\n\n    zero_accums();\n    wait_for_stage();\n    init_same_group(0);\n    fetch_to_registers(0, 0);\n    fetch_scales_to_registers(0, 0);\n    a_gl_rd += a_gl_rd_delta_o * (stages - 1);\n    
slice_k_start_shared_fetch += tb_k * (stages - 1);\n  };\n  if (slice_iters) {\n    start_pipes();\n  }\n\n  // Main loop.\n  while (slice_iters) {\n    // We unroll over both the global fetch and the register load pipeline to\n    // ensure all shared memory accesses are static. Note that both pipelines\n    // have even length meaning that the next iteration will always start at\n    // index 0.\n\n  #pragma unroll\n    for (int pipe = 0; pipe < stages;) {\n  #pragma unroll\n      for (int k = 0; k < b_sh_wr_iters; k++) {\n        fetch_to_registers(k + 1, pipe % stages);\n        fetch_scales_to_registers(k + 1, pipe);\n        if (k == b_sh_wr_iters - 2) {\n          fetch_to_shared((pipe + stages - 1) % stages, pipe,\n                          slice_iters >= stages);\n          pipe++;\n          wait_for_stage();\n          init_same_group(pipe % stages);\n        }\n        matmul(k);\n      }\n      slice_iters--;\n      if (slice_iters == 0) {\n        break;\n      }\n    }\n\n    a_gl_rd += a_gl_rd_delta_o * stages;\n    slice_k_start += tb_k * stages;\n    slice_k_start_shared_fetch += tb_k * stages;\n\n    if constexpr (has_act_order) {\n      int first_group_id = g_idx[slice_k_start];\n      int last_g_idx = slice_k_start + stages * tb_k * 2;\n      if (last_g_idx >= prob_k) {\n        last_g_idx = prob_k - 1;\n      }\n      int last_group_id = g_idx[last_g_idx];\n      if (last_group_id >= sh_first_group_id + sh_num_groups) {\n        fetch_scales_to_shared(false, first_group_id, last_group_id);\n        __syncthreads();\n      }\n    }\n\n    // Process results and, if necessary, proceed to the next column slice.\n    // While this pattern may not be the most readable, other ways of writing\n    // the loop seemed to noticeably worsen performance after compilation.\n    if (slice_iters == 0) {\n      cp_async_wait<0>();\n      bool last = slice_idx == slice_count - 1;\n      // For per-column scales, we only fetch them here in the final step 
before\n      // write-out\n      if constexpr (!has_act_order && group_blocks == -1) {\n        if constexpr (num_bits == 8) {\n          if (s_sh_wr_pred) {\n            cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);\n          }\n          cp_async_fence();\n        } else {\n          if (last) {\n            if (s_sh_wr_pred) {\n              cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);\n            }\n            cp_async_fence();\n          }\n        }\n      }\n\n      thread_block_reduce();\n      if constexpr (!has_act_order && group_blocks == -1) {\n        if constexpr (num_bits == 8) {\n          cp_async_wait<0>();\n          __syncthreads();\n          if (threadIdx.x / 32 < thread_n_blocks / 4) {\n            reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];\n            reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];\n          }\n\n        } else {\n          if (last) {\n            cp_async_wait<0>();\n            __syncthreads();\n            if (threadIdx.x / 32 < thread_n_blocks / 4) {\n              reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];\n              reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];\n            }\n          }\n        }\n      }\n\n      // For 8-bit channelwise, we apply the scale before the global reduction\n      // that converts the fp32 results to fp16 (so that we avoid possible\n      // overflow in fp16)\n      if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) {\n        if (threadIdx.x / 32 < thread_n_blocks / 4) {\n  #pragma unroll\n          for (int i = 0; i < thread_m_blocks; i++) {\n  #pragma unroll\n            for (int j = 0; j < 4; j++) {\n              scale_float<scalar_t>(\n                  reinterpret_cast<float*>(&frag_c[i][j][0][0]),\n                  frag_s[j / 2][2 * (j % 2) + 0]);\n              scale_float<scalar_t>(\n                  reinterpret_cast<float*>(&frag_c[i][j][0][2]),\n                  frag_s[j / 2][2 * (j % 2) + 
0]);\n\n              scale_float<scalar_t>(\n                  reinterpret_cast<float*>(&frag_c[i][j][1][0]),\n                  frag_s[j / 2][2 * (j % 2) + 1]);\n              scale_float<scalar_t>(\n                  reinterpret_cast<float*>(&frag_c[i][j][1][2]),\n                  frag_s[j / 2][2 * (j % 2) + 1]);\n            }\n          }\n        }\n      }\n\n      if (slice_count > 1) {  // only globally reduce if there is more than one\n                              // block in a slice\n        barrier_acquire(&locks[slice_col], slice_idx);\n        global_reduce(slice_idx == 0, last);\n        barrier_release(&locks[slice_col], last);\n      }\n      if (last)  // only the last block in a slice actually writes the result\n        write_result();\n      slice_row = 0;\n      slice_col_par++;\n      slice_col++;\n      init_slice();\n      if (slice_iters) {\n        a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +\n                  (threadIdx.x % a_gl_rd_delta_o);\n  #pragma unroll\n        for (int i = 0; i < b_sh_wr_iters; i++)\n          B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;\n        if (slice_col == 0) {\n  #pragma unroll\n          for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;\n        }\n\n        // Update slice k/n for scales loading\n        if constexpr (has_act_order) {\n          slice_k_start = tb_k * slice_row;\n          slice_k_finish = slice_k_start + tb_k * slice_iters;\n          slice_k_start_shared_fetch = slice_k_start;\n          slice_n_offset = act_s_col_tb_stride * slice_col;\n\n        } else {\n          s_gl_rd = s_sh_stride * slice_col + threadIdx.x;\n        }\n\n        start_pipes();\n      }\n    }\n  }\n}\n\n  #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS,                \\\n                    THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \\\n    else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS &&     \\\n             
thread_n_blocks == THREAD_N_BLOCKS &&                             \\\n             thread_k_blocks == THREAD_K_BLOCKS &&                             \\\n             has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \\\n             num_threads == NUM_THREADS) {                                     \\\n      cudaFuncSetAttribute(                                                    \\\n          Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS,             \\\n                 THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \\\n                 GROUP_BLOCKS>,                                                \\\n          cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);        \\\n      Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS,                 \\\n             THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER,     \\\n             GROUP_BLOCKS><<<blocks, NUM_THREADS, max_shared_mem, stream>>>(   \\\n          A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n,   \\\n          prob_k, locks);                                                      \\\n    }\n\ntypedef struct {\n  int thread_k;\n  int thread_n;\n  int num_threads;\n} thread_config_t;\n\ntypedef struct {\n  int max_m_blocks;\n  thread_config_t tb_cfg;\n} exec_config_t;\n\nthread_config_t small_batch_thread_configs[] = {\n    // Ordered by priority\n\n    // thread_k, thread_n, num_threads\n    {128, 128, 256},\n    {64, 128, 128},\n    {128, 64, 128},\n};\n\nthread_config_t large_batch_thread_configs[] = {\n    // Ordered by priority\n\n    // thread_k, thread_n, num_threads\n    {64, 256, 256},\n    {64, 128, 128},\n    {128, 64, 128},\n\n};\n\nint get_scales_cache_size(thread_config_t const& th_config, int prob_m,\n                          int prob_n, int prob_k, int num_bits, int group_size,\n                          bool has_act_order, bool is_k_full) {\n  bool cache_scales_chunk = has_act_order && !is_k_full;\n\n  
int tb_n = th_config.thread_n;\n  int tb_k = th_config.thread_k;\n\n  // Get max scale groups per thread-block\n  int tb_groups;\n  if (group_size == -1) {\n    tb_groups = 1;\n  } else if (group_size == 0) {\n    tb_groups = div_ceil(tb_k, 32);  // Worst case is 32 group size\n  } else {\n    tb_groups = div_ceil(tb_k, group_size);\n  }\n\n  if (cache_scales_chunk) {\n    int load_groups =\n        tb_groups * pipe_stages * 2;     // Chunk size is 2x pipeline over dim K\n    load_groups = max(load_groups, 32);  // We load at least 32 scale groups\n    return load_groups * tb_n * 2;\n\n  } else {\n    int tb_scales = tb_groups * tb_n * 2;\n\n    return tb_scales * pipe_stages;\n  }\n}\n\nbool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,\n                         int prob_m, int prob_n, int prob_k, int num_bits,\n                         int scales_cache_size, int max_shared_mem) {\n  int pack_factor = 32 / num_bits;\n\n  // Get B size\n  int tb_k = th_config.thread_k;\n  int tb_n = th_config.thread_n;\n\n  int b_size = (tb_k * tb_n / pack_factor) * 4;\n\n  // Get A size\n  int m_blocks = div_ceil(prob_m, 16);\n  int tb_max_m = 16;\n\n  while (true) {\n    if (m_blocks >= max_m_blocks) {\n      tb_max_m *= max_m_blocks;\n      break;\n    }\n\n    max_m_blocks--;\n    if (max_m_blocks == 0) {\n      TORCH_CHECK(false, \"Unexpected m_blocks = \", m_blocks);\n    }\n  }\n\n  int a_size = (tb_max_m * tb_k) * 2;\n\n  float pipe_size = (a_size + b_size) * pipe_stages;\n\n  TORCH_CHECK(max_shared_mem / 2 > scales_cache_size);  // Sanity\n\n  return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);\n}\n\nbool is_valid_config(thread_config_t const& th_config, int max_m_blocks,\n                     int prob_m, int prob_n, int prob_k, int num_bits,\n                     int group_size, bool has_act_order, bool is_k_full,\n                     int max_shared_mem) {\n  // Sanity\n  if (th_config.thread_k == -1 || th_config.thread_n == -1 ||\n 
     th_config.num_threads == -1) {\n    return false;\n  }\n\n  // Verify K/N are divisible by thread K/N\n  if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {\n    return false;\n  }\n\n  // Verify min for thread K/N\n  if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {\n    return false;\n  }\n\n  // num_threads must be at least 128 (= 4 warps)\n  if (th_config.num_threads < 128) {\n    return false;\n  }\n\n  //  Determine cache for scales\n  int scales_cache_size =\n      get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,\n                            group_size, has_act_order, is_k_full);\n\n  // Check that pipeline fits into cache\n  if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,\n                           num_bits, scales_cache_size, max_shared_mem)) {\n    return false;\n  }\n\n  return true;\n}\n\nexec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,\n                                      int num_bits, int group_size,\n                                      bool has_act_order, bool is_k_full,\n                                      int max_shared_mem) {\n  int max_m_blocks = 4;\n  while (max_m_blocks > 0) {\n    if (prob_m <= 16) {\n      for (auto th_config : small_batch_thread_configs) {\n        if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,\n                            num_bits, group_size, has_act_order, is_k_full,\n                            max_shared_mem)) {\n          return exec_config_t{max_m_blocks, th_config};\n        }\n      }\n    } else {\n      for (auto th_config : large_batch_thread_configs) {\n        if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,\n                            num_bits, group_size, has_act_order, is_k_full,\n                            max_shared_mem)) {\n          return exec_config_t{max_m_blocks, th_config};\n        }\n      }\n    }\n\n    max_m_blocks--;  
// Process fewer M blocks per invocation to reduce cache\n                     // usage\n  }\n\n  return exec_config_t{0, {-1, -1, -1}};\n}\n\n  #define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS)           \\\n    __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)   \\\n    __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)   \\\n    __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)   \\\n    __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)   \\\n                                                                       \\\n    __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \\\n    __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)  \\\n    __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)  \\\n    __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)  \\\n                                                                       \\\n    __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \\\n    __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)  \\\n    __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)  \\\n    __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)  \\\n                                                                       \\\n    __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \\\n    __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)  \\\n    __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)  \\\n    __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)  \\\n                                                                       \\\n    __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \\\n    __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)  \\\n    __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)  \\\n    __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, 
NUM_THREADS)\n\ntemplate <typename scalar_t>\nvoid marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,\n                     void* g_idx, void* perm, void* a_tmp, int prob_m,\n                     int prob_n, int prob_k, void* workspace, int num_bits,\n                     bool has_act_order, bool is_k_full, int num_groups,\n                     int group_size, int dev, cudaStream_t stream, int thread_k,\n                     int thread_n, int sms, int max_par) {\n  TORCH_CHECK(num_bits == 4 || num_bits == 8,\n              \"num_bits must be 4 or 8. Got = \", num_bits);\n  TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, \"Invalid MNK = [\", prob_m,\n              \", \", prob_n, \", \", prob_k, \"]\");\n\n  int tot_m = prob_m;\n  int tot_m_blocks = div_ceil(tot_m, 16);\n  int pad = 16 * tot_m_blocks - tot_m;\n\n  if (sms == -1) {\n    cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);\n  }\n\n  int max_shared_mem = 0;\n  cudaDeviceGetAttribute(&max_shared_mem,\n                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);\n  TORCH_CHECK(max_shared_mem > 0);\n\n  // Set thread config\n  exec_config_t exec_cfg;\n  if (thread_k != -1 && thread_n != -1) {\n    // User-defined config\n    exec_cfg =\n        exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}};\n  } else {\n    // Auto config\n    exec_cfg =\n        determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size,\n                                has_act_order, is_k_full, max_shared_mem);\n  }\n\n  TORCH_CHECK(exec_cfg.max_m_blocks > 0 &&\n                  is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks,\n                                  prob_m, prob_n, prob_k, num_bits, group_size,\n                                  has_act_order, is_k_full, max_shared_mem),\n              \"Invalid thread config: max_m_blocks = \", exec_cfg.max_m_blocks,\n              \", thread_k = \", exec_cfg.tb_cfg.thread_k,\n              \", thread_n = 
\", exec_cfg.tb_cfg.thread_n,\n              \", num_threads = \", exec_cfg.tb_cfg.num_threads, \" for MKN = [\",\n              prob_m, \", \", prob_k, \", \", prob_n, \"] and num_bits = \", num_bits,\n              \", group_size = \", group_size,\n              \", has_act_order = \", has_act_order, \", is_k_full = \", is_k_full,\n              \", max_shared_mem = \", max_shared_mem);\n\n  int num_threads = exec_cfg.tb_cfg.num_threads;\n  thread_k = exec_cfg.tb_cfg.thread_k;\n  thread_n = exec_cfg.tb_cfg.thread_n;\n\n  int thread_k_blocks = thread_k / 16;\n  int thread_n_blocks = thread_n / 16;\n\n  int blocks = sms;\n\n  TORCH_CHECK(prob_n % thread_n == 0, \"prob_n = \", prob_n,\n              \" is not divisible by thread_n = \", thread_n);\n  TORCH_CHECK(prob_k % thread_k == 0, \"prob_k = \", prob_k,\n              \" is not divisible by thread_k = \", thread_k);\n\n  int group_blocks = 0;\n  if (has_act_order) {\n    if (is_k_full) {\n      TORCH_CHECK(group_size != -1);\n      group_blocks = group_size / 16;\n      TORCH_CHECK(prob_k % group_blocks == 0, \"prob_k = \", prob_k,\n                  \" is not divisible by group_blocks = \", group_blocks);\n    } else {\n      TORCH_CHECK(group_size == 0);\n      group_blocks = 0;\n    }\n\n  } else {\n    if (group_size == -1) {\n      group_blocks = -1;\n    } else {\n      group_blocks = group_size / 16;\n      TORCH_CHECK(prob_k % group_blocks == 0, \"prob_k = \", prob_k,\n                  \" is not divisible by group_blocks = \", group_blocks);\n    }\n  }\n\n  const int4* A_ptr = (const int4*)A;\n  const int4* B_ptr = (const int4*)B;\n  int4* C_ptr = (int4*)C;\n  const int4* s_ptr = (const int4*)s;\n  const int* g_idx_ptr = (const int*)g_idx;\n  const int* perm_ptr = (const int*)perm;\n  int4* a_tmp_ptr = (int4*)a_tmp;\n\n  int* locks = (int*)workspace;\n\n  if (has_act_order) {\n    // Permute A columns\n    int block_rows = div_ceil(prob_m, blocks);\n    permute_cols_kernel<<<blocks, default_threads, 
0, stream>>>(\n        A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows);\n    A_ptr = a_tmp_ptr;\n  }\n\n  // If we have a full K, then we can run the non-act-order version of Marlin\n  // (since the weight rows are reordered by increasing group ids, and by having\n  // a full K, we have full original groups)\n  if (is_k_full) {\n    has_act_order = false;\n  }\n\n  // Main loop\n  for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) {\n    int thread_m_blocks = tot_m_blocks - i;\n    prob_m = tot_m - 16 * i;\n    int par = 1;\n    if (thread_m_blocks > exec_cfg.max_m_blocks) {\n      // Note that parallel > 1 currently only works for inputs without any\n      // padding\n      par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks);\n      if (par > max_par) par = max_par;\n      prob_m = (16 * exec_cfg.max_m_blocks) * par;\n      i += exec_cfg.max_m_blocks * (par - 1);\n      thread_m_blocks = exec_cfg.max_m_blocks;\n    }\n\n\n\n    // Define kernel configurations\n#define undefined_error TORCH_CHECK(false, \"Unsupported shapes: MNK = [\" + str(prob_m) + \", \" + \\\n    str(prob_n) + \", \" + str(prob_k) + \"]\" + \\\n        \", has_act_order = \" + str(has_act_order) + \\\n        \", num_groups = \" + str(num_groups) + \\\n        \", group_size = \" + str(group_size) + \\\n        \", thread_m_blocks = \" + str(thread_m_blocks) + \\\n        \", thread_n_blocks = \" + str(thread_n_blocks) + \\\n        \", thread_k_blocks = \" + str(thread_k_blocks));\n\n\n    if (num_bits == 4 && num_threads == 256)\n    {\n        if (false) {\n        }\n        CALL_IF(4, 32, 2, 256)\n        CALL_IF(4, 16, 4, 256)\n        CALL_IF(4, 8, 8, 256)\n        else {\n            undefined_error\n        }\n    }\n    else if (num_bits == 4 && num_threads == 128)\n    {\n        if (false) {\n        }\n        CALL_IF(4, 8, 4, 128)\n        CALL_IF(4, 4, 8, 128)\n        else {\n            undefined_error\n        }\n    }\n    else if (num_bits == 
8 && num_threads == 256)\n    {\n        if (false) {\n        }\n        CALL_IF(8, 32, 2, 256)\n        CALL_IF(8, 16, 4, 256)\n        CALL_IF(8, 8, 8, 256)\n        else {\n            undefined_error\n        }\n    }\n    else if (num_bits == 8 && num_threads == 128)\n    {\n        if (false) {\n        }\n        CALL_IF(8, 8, 4, 128)\n        CALL_IF(8, 4, 8, 128)\n        else {\n            undefined_error\n        }\n    }\n    else {\n        undefined_error\n    }\n\n    A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;\n    C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;\n  }\n}\n\n}  // namespace gptq_marlin\n\ntorch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,\n                               torch::Tensor& b_scales, torch::Tensor& g_idx,\n                               torch::Tensor& perm, torch::Tensor& workspace,\n                               int64_t num_bits, int64_t size_m, int64_t size_n,\n                               int64_t size_k, bool is_k_full) {\n  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));\n  // Verify num_bits\n  TORCH_CHECK(num_bits == 4 || num_bits == 8,\n              \"num_bits must be 4 or 8. 
Got = \", num_bits);\n  int pack_factor = 32 / num_bits;\n\n  // Verify A\n  TORCH_CHECK(a.size(0) == size_m, \"Shape mismatch: a.size(0) = \", a.size(0),\n              \", size_m = \", size_m);\n  TORCH_CHECK(a.size(1) == size_k, \"Shape mismatch: a.size(1) = \", a.size(1),\n              \", size_k = \", size_k);\n\n  // Verify B\n  TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, \"size_k = \", size_k,\n              \" is not divisible by tile_size = \", gptq_marlin::tile_size);\n  TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),\n              \"Shape mismatch: b_q_weight.size(0) = \", b_q_weight.size(0),\n              \", size_k = \", size_k, \", tile_size = \", gptq_marlin::tile_size);\n  TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0,\n              \"b_q_weight.size(1) = \", b_q_weight.size(1),\n              \" is not divisible by tile_size = \", gptq_marlin::tile_size);\n  int actual_size_n =\n      (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;\n  TORCH_CHECK(size_n == actual_size_n, \"size_n = \", size_n,\n              \", actual_size_n = \", actual_size_n);\n\n  // Verify device and strides\n  TORCH_CHECK(a.device().is_cuda(), \"A is not on GPU\");\n  TORCH_CHECK(a.is_contiguous(), \"A is not contiguous\");\n\n  TORCH_CHECK(b_q_weight.device().is_cuda(), \"b_q_weight is not on GPU\");\n  TORCH_CHECK(b_q_weight.is_contiguous(), \"b_q_weight is not contiguous\");\n\n  TORCH_CHECK(b_scales.device().is_cuda(), \"b_scales is not on GPU\");\n  TORCH_CHECK(b_scales.is_contiguous(), \"b_scales is not contiguous\");\n\n  TORCH_CHECK(g_idx.device().is_cuda(), \"g_idx is not on GPU\");\n  TORCH_CHECK(g_idx.is_contiguous(), \"g_idx is not contiguous\");\n\n  TORCH_CHECK(perm.device().is_cuda(), \"perm is not on GPU\");\n  TORCH_CHECK(perm.is_contiguous(), \"perm is not contiguous\");\n\n  // Alloc buffers\n  auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());\n  torch::Tensor c = 
torch::empty({size_m, size_n}, options);\n  torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);\n\n  // thread_k: `k` size of a thread_tile in `weights` (can usually be left as\n  // auto -1)\n  int thread_k = -1;\n  // thread_n: `n` size of a thread_tile in `weights` (can usually be left as\n  // auto -1)\n  int thread_n = -1;\n  // sms: number of SMs to use for the kernel (can usually be left as auto -1)\n  int sms = -1;\n\n  // Verify g_idx and perm\n  TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) ||\n                  (g_idx.size(0) == size_k && perm.size(0) == size_k),\n              \"Unexpected g_idx.size(0) = \", g_idx.size(0),\n              \" and perm.size(0) = \", perm.size(0),\n              \", where size_k = \", size_k);\n\n  // Detect groupsize and act_order\n  int num_groups = -1;\n  int group_size = -1;\n  bool has_act_order = g_idx.size(0) != 0;\n\n  int b_rank = b_scales.sizes().size();\n  TORCH_CHECK(b_rank == 2, \"b_scales rank = \", b_rank, \" is not 2\");\n  TORCH_CHECK(b_scales.size(1) == size_n, \"b_scales dim 1 = \", b_scales.size(1),\n              \" is not size_n = \", size_n);\n  num_groups = b_scales.size(0);\n\n  if (has_act_order) {\n    if (is_k_full) {\n      TORCH_CHECK(num_groups > 1, \"For act_order, num_groups must be > 1\");\n      TORCH_CHECK(size_k % num_groups == 0, \"size_k = \", size_k,\n                  \", is not divisible by num_groups = \", num_groups);\n      group_size = size_k / num_groups;\n    } else {\n      group_size = 0;\n    }\n\n  } else {\n    if (num_groups > 1) {\n      TORCH_CHECK(\n          size_k % num_groups == 0, \"size_k = \", size_k,\n          \", is not divisible by b_scales.size(0) = \", b_scales.size(0));\n      group_size = size_k / num_groups;\n    } else {\n      group_size = -1;\n    }\n  }\n\n  // Verify workspace size\n  TORCH_CHECK(\n      size_n % gptq_marlin::min_thread_n == 0, \"size_n = \", size_n,\n      \", is not divisible by min_thread_n = \", 
gptq_marlin::min_thread_n);\n  int min_workspace_size =\n      (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;\n  TORCH_CHECK(workspace.numel() >= min_workspace_size,\n              \"workspace.numel = \", workspace.numel(),\n              \" is below min_workspace_size = \", min_workspace_size);\n\n  int dev = a.get_device();\n  if (a.scalar_type() == at::ScalarType::Half) {\n    gptq_marlin::marlin_mm_f16i4<half>(\n        a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),\n        b_scales.data_ptr<at::Half>(), g_idx.data_ptr(), perm.data_ptr(),\n        a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,\n        workspace.data_ptr(), num_bits, has_act_order, is_k_full, num_groups,\n        group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,\n        thread_n, sms, gptq_marlin::max_par);\n  } else if (a.scalar_type() == at::ScalarType::BFloat16) {\n    gptq_marlin::marlin_mm_f16i4<nv_bfloat16>(\n        a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),\n        c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),\n        g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),\n        size_m, size_n, size_k, workspace.data_ptr(), num_bits, has_act_order,\n        is_k_full, num_groups, group_size, dev,\n        at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,\n        gptq_marlin::max_par);\n  } else {\n    TORCH_CHECK(false, \"gptq_marlin_gemm only supports bfloat16 and float16\");\n  }\n\n  return c;\n}\n\n#endif\n"
  },
  {
    "path": "archive/csrc/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cuh",
    "content": "// Adapted from\n// https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin\n// Copyrigth 2024 The vLLM team.\n// Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n#pragma once\n\n#include <torch/all.h>\n\n#include <ATen/cuda/CUDAContext.h>\n#include <c10/cuda/CUDAGuard.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n#include <iostream>\n\nnamespace gptq_marlin {\n\n// 8 warps are a good choice since every SM has 4 schedulers and having more\n// than 1 warp per schedule allows some more latency hiding. At the same time,\n// we want relatively few warps to have many registers per warp and small tiles.\nstatic constexpr int default_threads = 256;\n\nstatic constexpr int pipe_stages =\n    4;  // 4 pipeline stages fit into shared memory\n\nstatic constexpr int min_thread_n = 64;\nstatic constexpr int min_thread_k = 64;\n\nstatic constexpr int tile_size = 16;\nstatic constexpr int max_par = 16;\n\ntemplate <typename T, int n>\nstruct Vec {\n  T elems[n];\n  __device__ T& operator[](int i) { return elems[i]; }\n};\n\nusing I4 = Vec<int, 4>;\n\nconstexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }\n\n#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined (__HIP_PLATFORM_AMD__)\n// No support for async\n#else\n\n__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,\n                                      bool pred = true) {\n  const int BYTES = 16;\n  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));\n  asm volatile(\n      \"{\\n\"\n      \"   .reg .pred p;\\n\"\n      \"   setp.ne.b32 p, %0, 0;\\n\"\n      \"   @p cp.async.cg.shared.global [%1], [%2], %3;\\n\"\n      \"}\\n\" ::\"r\"((int)pred),\n      \"r\"(smem), \"l\"(glob_ptr), \"n\"(BYTES));\n}\n\n__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {\n  const int BYTES = 16;\n  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));\n  asm 
volatile(\n      \"{\\n\"\n      \"   cp.async.cg.shared.global [%0], [%1], %2;\\n\"\n      \"}\\n\" ::\"r\"(smem),\n      \"l\"(glob_ptr), \"n\"(BYTES));\n}\n\n__device__ inline void cp_async_fence() {\n  asm volatile(\"cp.async.commit_group;\\n\" ::);\n}\n\ntemplate <int n>\n__device__ inline void cp_async_wait() {\n  asm volatile(\"cp.async.wait_group %0;\\n\" ::\"n\"(n));\n}\n\n#endif\n\n}  // namespace gptq_marlin\n"
  },
  {
    "path": "archive/csrc/ktransformers_ext/cuda/gptq_marlin/gptq_marlin_dtypes.cuh",
    "content": "// Adapted from\n// https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin\n// Copyrigth 2024 The vLLM team.\n// Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n#ifndef _data_types_cuh\n#define _data_types_cuh\n#include \"gptq_marlin.cuh\"\n#include <cuda_fp16.h>\n#include <cuda_bf16.h>\n\n#ifdef __HIP_PLATFORM_AMD__\ntypedef __hip_bfloat16 nv_bfloat16;\ntypedef __hip_bfloat162 nv_bfloat162;\n#endif\n\nnamespace gptq_marlin {\n\ntemplate <typename scalar_t>\nclass ScalarType {};\n\ntemplate <>\nclass ScalarType<half> {\n public:\n  using scalar_t = half;\n  using scalar_t2 = half2;\n\n  // Matrix fragments for tensor core instructions; their precise layout is\n  // documented here:\n  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type\n  using FragA = Vec<half2, 4>;\n  using FragB = Vec<half2, 2>;\n  using FragC = Vec<float, 4>;\n  using FragS = Vec<half2, 1>;\n\n  static __device__ float inline num2float(const half x) {\n    return __half2float(x);\n  }\n\n  static __device__ half2 inline num2num2(const half x) {\n    return __half2half2(x);\n  }\n\n  static __device__ half2 inline nums2num2(const half x1, const half x2) {\n    return __halves2half2(x1, x2);\n  }\n\n  static __host__ __device__ half inline float2num(const float x) {\n    return __float2half(x);\n  }\n};\n\ntemplate <>\nclass ScalarType<nv_bfloat16> {\n public:\n  using scalar_t = nv_bfloat16;\n  using scalar_t2 = nv_bfloat162;\n\n  using FragA = Vec<nv_bfloat162, 4>;\n  using FragB = Vec<nv_bfloat162, 2>;\n  using FragC = Vec<float, 4>;\n  using FragS = Vec<nv_bfloat162, 1>;\n\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n  static __device__ float inline num2float(const nv_bfloat16 x) {\n    return __bfloat162float(x);\n  }\n\n  static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {\n    return __bfloat162bfloat162(x);\n  }\n\n  static __device__ 
nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,\n                                                  const nv_bfloat16 x2) {\n    return __halves2bfloat162(x1, x2);\n  }\n\n  static __host__ __device__ nv_bfloat16 inline float2num(const float x) {\n    return __float2bfloat16(x);\n  }\n#endif\n};\n\n}  // namespace gptq_marlin\n\n#endif\n"
  },
  {
    "path": "archive/csrc/ktransformers_ext/cuda/gptq_marlin/ops.h",
    "content": "/**\n * @Description  :  \n * @Author       : Azure\n * @Date         : 2024-07-22 09:27:55\n * @Version      : 1.0.0\n * @LastEditors  : Azure \n * @LastEditTime : 2024-07-26 08:35:00\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. \n**/\n#pragma once\n\n#include <torch/library.h>\n#include <torch/extension.h>\n#include <torch/torch.h>\n\ntorch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,\n                               torch::Tensor& b_scales, torch::Tensor& g_idx,\n                               torch::Tensor& perm, torch::Tensor& workspace,\n                               int64_t num_bits, int64_t size_m, int64_t size_n,\n                               int64_t size_k, bool is_k_full);\n\n// torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,\n//                                  int64_t size_k, int64_t size_n,\n//                                  int64_t num_bits);"
  },
  {
    "path": "archive/csrc/ktransformers_ext/cuda/setup.py",
    "content": "\nfrom setuptools import setup, Extension\nfrom torch.utils import cpp_extension\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\nsetup(\n    name='KTransformersOps',\n    ext_modules=[\n        CUDAExtension(\n            'KTransformersOps', [\n                'custom_gguf/dequant.cu',\n                'binding.cpp',\n                'gptq_marlin/gptq_marlin.cu',\n                # 'gptq_marlin_repack.cu',\n            ],\n            extra_compile_args={\n                'cxx': ['-O3'],\n                'nvcc': [\n                    '-O3',\n                    '--use_fast_math',\n                    '-Xcompiler', '-fPIC',\n                ]\n            },\n        )\n    ],\n    cmdclass={'build_ext': BuildExtension}\n)"
  },
  {
    "path": "archive/csrc/ktransformers_ext/cuda/test_dequant.py",
    "content": "import os\nimport sys\nsys.path.insert(0,\"/home/zbx/ktransformers\")\nfrom ktransformers.util.custom_loader import GGUFLoader\nimport torch\n\ngguf_loader_1 = GGUFLoader(\"/mnt/data/model/DeepseekV3-q4km-gguf\")\ngguf_loader_2 = GGUFLoader(\"/mnt/data/chenht/model/gguf_for_ktransformers/DeepSeek-V3-bf16/\")\n\ntorch.set_default_dtype(torch.bfloat16)\n\ntensor_1 = gguf_loader_1.load_gguf_tensor(\"blk.0.attn_kv_a_mqa.weight\", \"cuda\")\ntensor_2 = gguf_loader_2.load_gguf_tensor(\"blk.0.attn_kv_a_mqa.weight\", \"cuda\")\n\nprint(tensor_1[0, -64:])\nprint(tensor_2[0, -64:])"
  },
  {
    "path": "archive/csrc/ktransformers_ext/examples/test_attention.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nDescription  :  \nAuthor       : Jianwei Dong\nDate         : 2024-08-28 10:32:05\nVersion      : 1.0.0\nLastEditors  : chenht2022 \nLastEditTime : 2024-08-28 10:32:05\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n\"\"\"\nimport os, sys\nimport time\n\nsys.path.append(os.path.dirname(__file__) + \"/../build\")\nimport cpuinfer_ext\nfrom flash_attn import flash_attn_with_kvcache\nimport torch\n\nlayer_num = 10\nkv_head_num = 8\nq_head_num = 32\nhead_dim = 128\nblock_len = 128\nanchor_num = 1\ncache_seqlen = 8192\ncache_seqlens = torch.tensor([cache_seqlen], dtype=torch.int32, device=\"cpu\")\nseqlens_zero = torch.zeros((1,), dtype=torch.int32, device=\"cpu\")\nanchor_type = cpuinfer_ext.kvcache.AnchorType.DYNAMIC\nkv_type = cpuinfer_ext.kvcache.ggml_type.FP16\nretrieval_type = cpuinfer_ext.kvcache.RetrievalType.LAYER\nlayer_step: int = 1\ntoken_step: int = 1\nlayer_offset: int = 0\nmax_thread_num: int = 2\nmax_batch_size: int = 1\nmax_block_num: int = 512\nCPUInfer = cpuinfer_ext.CPUInfer(max_thread_num)\nvalidation_iter = 100\n\nwith torch.inference_mode(mode=True):\n    config = cpuinfer_ext.kvcache.KVCacheConfig(\n        layer_num,\n        kv_head_num,\n        q_head_num,\n        head_dim,\n        block_len,\n        anchor_num,\n        anchor_type,\n        kv_type,\n        retrieval_type,\n        layer_step,\n        token_step,\n        layer_offset,\n        max_block_num,\n        max_batch_size,\n        max_thread_num,\n    )\n    local_kvcache = cpuinfer_ext.kvcache.KVCache(config)\n\n    kvcaches = []\n    block_table = (\n        torch.arange(max_block_num, dtype=torch.int32, device=\"cpu\")\n        .contiguous()\n        .view(1, -1)\n    )\n\n    for layer_idx in range(layer_num):\n        k_cache = torch.randn(\n            (1, cache_seqlen, kv_head_num, head_dim), dtype=torch.float16, device=\"cpu\"\n        ).contiguous()\n        v_cache = torch.randn(\n            
(1, cache_seqlen, kv_head_num, head_dim), dtype=torch.float16, device=\"cpu\"\n        ).contiguous()\n\n        CPUInfer.submit(\n            local_kvcache.update_kvcache_fp16(\n                k_cache.data_ptr(),\n                v_cache.data_ptr(),\n                layer_idx,\n                block_table.data_ptr(),\n                1,\n                max_block_num,\n                seqlens_zero.data_ptr(),\n                cache_seqlen,\n            )\n        )\n        CPUInfer.sync()\n\n        kvcaches.append((k_cache.to(\"cuda\"), v_cache.to(\"cuda\")))\n\n    # validation\n    for i in range(validation_iter):\n\n        k_cache = kvcaches[i % layer_num][0]\n        v_cache = kvcaches[i % layer_num][1]\n        input = torch.randn(\n            (1, 1, q_head_num, head_dim), dtype=torch.float16, device=\"cpu\"\n        ).contiguous()\n        output = torch.empty(\n            (1, 1, q_head_num, head_dim), dtype=torch.float16, device=\"cpu\"\n        ).contiguous()\n\n        # attn_lse: (bsz, q_len, q_head_num)\n        attn_lse = torch.empty(\n            (1, 1, q_head_num), dtype=torch.float32, device=\"cpu\"\n        ).contiguous()\n        input = input / 100\n\n        CPUInfer.submit(\n            local_kvcache.attn(\n                input.data_ptr(),\n                output.data_ptr(),\n                attn_lse.data_ptr(),\n                i % layer_num,\n                0,\n                1,\n                1,\n                max_block_num,\n                block_table.data_ptr(),\n                cache_seqlens.data_ptr(),\n                -1,\n                -1,\n                -1,\n            )\n        )\n        CPUInfer.sync()\n        # print(\"cpuinfer output\", output)\n\n        t_output = flash_attn_with_kvcache(\n            q=input.to(\"cuda\"),\n            k_cache=k_cache,\n            v_cache=v_cache,\n            cache_seqlens=cache_seqlens.to(\"cuda\"),\n        )\n        # print(\"torch output\", t_output)\n\n        diff 
= torch.mean(torch.abs(output.to(\"cuda\") - t_output)) / torch.mean(\n            torch.abs(t_output)\n        )\n        print(\"diff = \", diff)\n        assert diff < 0.001\n"
  },
  {
    "path": "archive/csrc/ktransformers_ext/examples/test_linear.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : chenht2022\nDate         : 2024-07-25 10:32:05\nVersion      : 1.0.0\nLastEditors  : chenht2022 \nLastEditTime : 2024-08-06 10:36:59\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport os, sys\nimport time\nsys.path.append(os.path.dirname(__file__) + '/../build')\nimport cpuinfer_ext\nimport torch\n\ninput_size = 16384\noutput_size = 5120\nstride = 32\ngroup_max_len = 1024\nproj_type = 1 # ggml_type::GGML_TYPE_F16\nhidden_type = 1 # ggml_type::GGML_TYPE_F16\nqlen = 30\nlayer_num = 10\nCPUInfer = cpuinfer_ext.CPUInfer(48)\nvalidation_iter = 100\n\nwith torch.inference_mode(mode=True):\n    linears = []\n    projs = []\n    for _ in range(layer_num):\n        proj = torch.randn((output_size, input_size), dtype=torch.float16, device = \"cuda\").to(\"cpu\").contiguous()\n        config = cpuinfer_ext.linear.LinearConfig(input_size, output_size, stride, group_max_len, proj.data_ptr(), proj_type, hidden_type)\n        linear = cpuinfer_ext.linear.Linear(config)\n        projs.append(proj)\n        linears.append(linear)\n\n    # validation\n    for i in range(validation_iter):\n        linear = linears[i % layer_num]\n        input = torch.randn((qlen, input_size), dtype=torch.float16).contiguous()\n        output = torch.empty((qlen, output_size), dtype=torch.float16).contiguous()\n        input = input / 100\n\n        CPUInfer.submit(\n            linear.forward(\n                qlen,\n                input.data_ptr(),\n                output.data_ptr()\n            )\n        )\n        CPUInfer.sync()\n        # print('cpuinfer output', output)\n\n        proj = projs[i%layer_num]\n        t_output = torch.mm(input, proj.t())\n        # print('torch output', t_output)\n\n        diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))\n        print('diff = ', diff)\n        assert(diff < 0.001)\n"
  },
  {
    "path": "archive/csrc/ktransformers_ext/examples/test_mlp.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : chenht2022\nDate         : 2024-07-25 10:32:05\nVersion      : 1.0.0\nLastEditors  : chenht2022 \nLastEditTime : 2024-08-06 10:37:28\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport os, sys\nimport time\nsys.path.append(os.path.dirname(__file__) + '/../build')\nimport cpuinfer_ext\nimport torch\n\nhidden_size = 5120\nintermediate_size = 3072\nstride = 32\ngroup_max_len = 1024\ngate_type = 1 # ggml_type::GGML_TYPE_F16\nup_type = 1 # ggml_type::GGML_TYPE_F16\ndown_type = 1 # ggml_type::GGML_TYPE_F16\nhidden_type = 1 # ggml_type::GGML_TYPE_F16\nqlen = 30\nlayer_num = 10\nCPUInfer = cpuinfer_ext.CPUInfer(48)\nvalidation_iter = 100\n\ndef act_fn(x):\n    return x / (1.0 + torch.exp(-x))\n\ndef mlp_torch(input, gate_proj, up_proj, down_proj):\n    gate_buf = torch.mm(input, gate_proj.t())\n    up_buf = torch.mm(input, up_proj.t())\n    intermediate = act_fn(gate_buf) * up_buf\n    ret = torch.mm(intermediate, down_proj.t())\n    return ret\n\nwith torch.inference_mode(mode=True):\n    mlps = []\n    gate_projs = []\n    up_projs = []\n    down_projs = []\n    for _ in range(layer_num):\n        gate_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float16, device = \"cuda\").to(\"cpu\").contiguous()\n        up_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float16, device = \"cuda\").to(\"cpu\").contiguous()\n        down_proj = torch.randn((hidden_size, intermediate_size), dtype=torch.float16, device = \"cuda\").to(\"cpu\").contiguous()\n        config = cpuinfer_ext.mlp.MLPConfig(hidden_size, intermediate_size, stride, group_max_len, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(), gate_type, up_type, down_type, hidden_type)\n        mlp = cpuinfer_ext.mlp.MLP(config)\n        gate_projs.append(gate_proj)\n        up_projs.append(up_proj)\n        down_projs.append(down_proj)\n        mlps.append(mlp)\n\n    
# validation: compare the CPUInfer MLP against the PyTorch reference\n    for i in range(validation_iter):\n        mlp = mlps[i % layer_num]\n        input = torch.randn((qlen, hidden_size), dtype=torch.float16).contiguous()\n        output = torch.empty((qlen, hidden_size), dtype=torch.float16).contiguous()\n        # scale the random input down so the FP16 activations stay small\n        input = input / 100\n\n        # forward() only packages the call; submit() enqueues it and sync() waits for it to finish\n        CPUInfer.submit(\n            mlp.forward(\n                qlen,\n                input.data_ptr(),\n                output.data_ptr()\n            )\n        )\n        CPUInfer.sync()\n        # print('cpuinfer output', output)\n\n        gate_proj = gate_projs[i % layer_num]\n        up_proj = up_projs[i % layer_num]\n        down_proj = down_projs[i % layer_num]\n        t_output = mlp_torch(input, gate_proj, up_proj, down_proj)\n        # print('torch output', t_output)\n\n        # relative mean absolute error between the two outputs\n        diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))\n        print('diff = ', diff)\n        assert diff < 0.001\n"
  },
  {
    "path": "archive/csrc/ktransformers_ext/examples/test_moe.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : chenht2022\nDate         : 2024-07-25 10:32:05\nVersion      : 1.0.0\nLastEditors  : chenht2022 \nLastEditTime : 2024-08-06 10:38:05\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport os, sys\nimport time\nsys.path.append(os.path.dirname(__file__) + '/../build')\nimport cpuinfer_ext\nimport torch\n\nexpert_num = 160\nhidden_size = 5120\nintermediate_size = 1536\nstride = 32\ngroup_min_len = 10\ngroup_max_len = 1024\ngate_type = 1 # ggml_type::GGML_TYPE_F16\nup_type = 1 # ggml_type::GGML_TYPE_F16\ndown_type = 1 # ggml_type::GGML_TYPE_F16\nhidden_type = 1 # ggml_type::GGML_TYPE_F16\nn_routed_experts = 6\nqlen = 30\nlayer_num = 10\nCPUInfer = cpuinfer_ext.CPUInfer(48)\nvalidation_iter = 100\n\ndef act_fn(x):\n    return x / (1.0 + torch.exp(-x))\n\ndef mlp_torch(input, gate_proj, up_proj, down_proj):\n    gate_buf = torch.mm(input, gate_proj.t())\n    up_buf = torch.mm(input, up_proj.t())\n    intermediate = act_fn(gate_buf) * up_buf\n    ret = torch.mm(intermediate, down_proj.t())\n    return ret\n\ndef moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):\n    cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))\n    cnts.scatter_(1, expert_ids, 1)\n    tokens_per_expert = cnts.sum(dim=0)\n    idxs = expert_ids.view(-1).argsort()\n    sorted_tokens = input[idxs // expert_ids.shape[1]]\n\n    outputs = []\n    start_idx = 0\n    for i, num_tokens in enumerate(tokens_per_expert):\n        end_idx = start_idx + num_tokens\n        if num_tokens == 0:\n            continue\n        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n        expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])\n        outputs.append(expert_out)\n        start_idx = end_idx\n\n    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n    new_x = torch.empty_like(outs)\n    
new_x[idxs] = outs\n    t_output = (\n        new_x.view(*expert_ids.shape, -1)\n        .type(weights.dtype)\n        .mul_(weights.unsqueeze(dim=-1))\n        .sum(dim=1)\n        .type(new_x.dtype)\n    )\n    return t_output\n\nwith torch.inference_mode(mode=True):\n    moes = []\n    gate_projs = []\n    up_projs = []\n    down_projs = []\n    for _ in range(layer_num):\n        gate_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float16, device = \"cuda\").to(\"cpu\").contiguous()\n        up_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float16, device = \"cuda\").to(\"cpu\").contiguous()\n        down_proj = torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float16, device = \"cuda\").to(\"cpu\").contiguous()\n        # The MOEConfig binding in ext_bindings.cpp takes use_silu after group_max_len;\n        # True matches act_fn above, which is SiLU.\n        config = cpuinfer_ext.moe.MOEConfig(expert_num, n_routed_experts, hidden_size, intermediate_size, stride, group_min_len, group_max_len, True, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(), gate_type, up_type, down_type, hidden_type)\n        moe = cpuinfer_ext.moe.MOE(config)\n        gate_projs.append(gate_proj)\n        up_projs.append(up_proj)\n        down_projs.append(down_proj)\n        moes.append(moe)\n\n    # validation\n    for i in range(validation_iter):\n        expert_ids = torch.stack([torch.randperm(expert_num)[:n_routed_experts] for _ in range(qlen)]).contiguous()\n        weights = torch.rand((qlen, n_routed_experts), dtype=torch.float32).contiguous()\n        input = torch.randn((qlen, hidden_size), dtype=torch.float16).contiguous()\n        output = torch.empty((qlen, hidden_size), dtype=torch.float16).contiguous()\n        input = input / 100\n        # The forward binding takes a trailing batch_size_tensor pointer (cast to int *).\n        # Assumption: it is an int32 buffer holding the number of tokens in this call.\n        bsz_tensor = torch.tensor([qlen], dtype=torch.int32).contiguous()\n\n        moe = moes[i % layer_num]\n        CPUInfer.submit(\n            moe.forward(\n                qlen,\n                n_routed_experts,\n                expert_ids.data_ptr(),\n                weights.data_ptr(),\n                input.data_ptr(),\n                
output.data_ptr(),\n                bsz_tensor.data_ptr()\n            )\n        )\n        CPUInfer.sync()\n        # print('cpuinfer output', output)\n\n        gate_proj = gate_projs[i%layer_num]\n        up_proj = up_projs[i%layer_num]\n        down_proj = down_projs[i%layer_num]\n        t_output = moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj)\n        # print('torch output', t_output)\n\n        diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))\n        print('diff = ', diff)\n        assert(diff < 0.001)\n"
  },
  {
    "path": "archive/csrc/ktransformers_ext/ext_bindings.cpp",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022, Jianwei Dong\n * @Date         : 2024-07-22 02:03:22\n * @Version      : 1.0.0\n * @LastEditors  : Jianwei Dong\n * @LastEditTime : 2024-08-26 22:47:06\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n// Python bindings\n#include \"cpu_backend/cpuinfer.h\"\n#if !defined(KTRANSFORMERS_USE_ROCM) && !defined(KTRANSFORMERS_USE_XPU) && !defined(KTRANSFORMERS_USE_NPU)\n#include \"device_launch_parameters.h\"\n#endif\n#include \"llamafile/flags.h\"\n#include \"operators/kvcache/kvcache.h\"\n#include \"operators/llamafile/linear.h\"\n#include \"operators/llamafile/mlp.h\"\n#include \"operators/llamafile/moe.h\"\n\n#if defined(__x86_64__) && defined(__HAS_AVX512F__) && defined(__HAS_AMX__)\n#include \"operators/amx/moe.hpp\"\n#endif\n\n#include \"pybind11/functional.h\"\n#include \"pybind11/operators.h\"\n#include \"pybind11/pybind11.h\"\n#include \"pybind11/stl.h\"\n#include <cstdint>\n#include <iostream>\n#include <memory>\n\nnamespace py = pybind11;\nusing namespace pybind11::literals;\n\n// Binding functions for the KVCache class\nclass KVCacheBindings {\n  public:\n    class AttnBindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            KVCache *kv_cache;\n            const ggml_fp16_t *q_in;\n            ggml_fp16_t *output;\n            float *attn_lse;\n            int layer_idx;\n            int generate_token_idx;\n            int q_len;\n            int batch_size;\n            int max_block_num;\n            int *block_table;\n            int *cache_seqlens;\n            int pick_block_num;\n            int init_block_num;\n            int local_block_num;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(\n                &KVCache::attn, args_->kv_cache, args_->q_in, args_->output,\n                args_->attn_lse, args_->layer_idx, 
args_->generate_token_idx,\n                args_->q_len, args_->batch_size, args_->max_block_num,\n                args_->block_table, args_->cache_seqlens, args_->pick_block_num,\n                args_->init_block_num, args_->local_block_num);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(KVCache &kv_cache, intptr_t q_in, intptr_t output,\n                           intptr_t attn_lse, int layer_idx,\n                           int generate_token_idx, int q_len, int batch_size,\n                           int max_block_num, intptr_t block_table,\n                           intptr_t cache_seqlens, int pick_block_num,\n                           int init_block_num, int local_block_num) {\n            Args *args = new Args{nullptr,\n                                  &kv_cache,\n                                  (const ggml_fp16_t *)q_in,\n                                  (ggml_fp16_t *)output,\n                                  (float *)attn_lse,\n                                  layer_idx,\n                                  generate_token_idx,\n                                  q_len,\n                                  batch_size,\n                                  max_block_num,\n                                  (int *)block_table,\n                                  (int *)cache_seqlens,\n                                  pick_block_num,\n                                  init_block_num,\n                                  local_block_num};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n\n    class GetAllKVCacheOneLayerBindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            KVCache *kv_cache;\n            int layer_id;\n            ggml_fp16_t *k_in;\n            ggml_fp16_t *v_in;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            
args_->cpuinfer->enqueue(&KVCache::get_all_kvcache_one_layer,\n                                     args_->kv_cache, args_->layer_id,\n                                     args_->k_in, args_->v_in);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(KVCache &kv_cache, intptr_t k_in, intptr_t v_in,\n                           int layer_id) {\n            Args *args = new Args{nullptr, &kv_cache, layer_id,\n                                  (ggml_fp16_t *)k_in, (ggml_fp16_t *)v_in};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n\n    class GetAndUpdateKVCacheFp16Bindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            KVCache *kv_cache;\n            ggml_fp16_t *k_in;\n            ggml_fp16_t *v_in;\n            int layer_id;\n            int *block_table;\n            int batch_size;\n            int max_block_num;\n            int *cache_seqlens;\n            int q_len;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(&KVCache::get_and_update_kvcache_fp16,\n                                     args_->kv_cache, args_->k_in, args_->v_in,\n                                     args_->layer_id, args_->block_table,\n                                     args_->batch_size, args_->max_block_num,\n                                     args_->cache_seqlens, args_->q_len);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(KVCache &kv_cache, intptr_t k_in, intptr_t v_in,\n                           int layer_id, intptr_t block_table, int batch_size,\n                           int max_block_num, intptr_t cache_seqlens,\n                           int q_len) {\n            Args *args = new Args{nullptr,\n                                  &kv_cache,\n                                  (ggml_fp16_t *)k_in,\n                                  (ggml_fp16_t 
*)v_in,\n                                  layer_id,\n                                  (int *)block_table,\n                                  batch_size,\n                                  max_block_num,\n                                  (int *)cache_seqlens,\n                                  q_len};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n    class GetKVCacheFp16Bindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            KVCache *kv_cache;\n            ggml_fp16_t *k_in;\n            ggml_fp16_t *v_in;\n            int layer_id;\n            int *block_table;\n            int batch_size;\n            int max_block_num;\n            int *cache_seqlens;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(\n                &KVCache::get_kvcache_fp16, args_->kv_cache, args_->k_in,\n                args_->v_in, args_->layer_id, args_->block_table,\n                args_->batch_size, args_->max_block_num, args_->cache_seqlens);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(KVCache &kv_cache, intptr_t k_in, intptr_t v_in,\n                           int layer_id, intptr_t block_table, int batch_size,\n                           int max_block_num, intptr_t cache_seqlens) {\n            Args *args = new Args{nullptr,\n                                  &kv_cache,\n                                  (ggml_fp16_t *)k_in,\n                                  (ggml_fp16_t *)v_in,\n                                  layer_id,\n                                  (int *)block_table,\n                                  batch_size,\n                                  max_block_num,\n                                  (int *)cache_seqlens};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n\n    class UpdateKVCacheFp16Bindings {\n      public:\n     
   struct Args {\n            CPUInfer *cpuinfer;\n            KVCache *kv_cache;\n            ggml_fp16_t *k_in;\n            ggml_fp16_t *v_in;\n            int layer_id;\n            int *block_table;\n            int batch_size;\n            int max_block_num;\n            int *cache_seqlens;\n            int q_len;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(&KVCache::update_kvcache_fp16,\n                                     args_->kv_cache, args_->k_in, args_->v_in,\n                                     args_->layer_id, args_->block_table,\n                                     args_->batch_size, args_->max_block_num,\n                                     args_->cache_seqlens, args_->q_len);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(KVCache &kv_cache, intptr_t k_in, intptr_t v_in,\n                           int layer_id, intptr_t block_table, int batch_size,\n                           int max_block_num, intptr_t cache_seqlens,\n                           int q_len) {\n            Args *args = new Args{nullptr,\n                                  &kv_cache,\n                                  (ggml_fp16_t *)k_in,\n                                  (ggml_fp16_t *)v_in,\n                                  layer_id,\n                                  (int *)block_table,\n                                  batch_size,\n                                  max_block_num,\n                                  (int *)cache_seqlens,\n                                  q_len};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n\n    class UpdateImportanceBindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            KVCache *kv_cache;\n            const ggml_fp16_t *importance;\n            int layer_id;\n            int *block_table;\n            int batch_size;\n            int 
max_block_num;\n            int *offset;\n            int width;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(\n                &KVCache::update_importance, args_->kv_cache, args_->importance,\n                args_->layer_id, args_->block_table, args_->batch_size,\n                args_->max_block_num, args_->offset, args_->width);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(KVCache &kv_cache, intptr_t importance, int layer_id,\n                           intptr_t block_table, int batch_size,\n                           int max_block_num, intptr_t offset, int width) {\n            Args *args = new Args{nullptr,\n                                  &kv_cache,\n                                  (const ggml_fp16_t *)importance,\n                                  layer_id,\n                                  (int *)block_table,\n                                  batch_size,\n                                  max_block_num,\n                                  (int *)offset,\n                                  width};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n\n    class AttnWithKVCacheBindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            KVCache *kv_cache;\n            const ggml_fp16_t *q_in;\n            const ggml_fp16_t *k_in;\n            const ggml_fp16_t *v_in;\n            ggml_fp16_t *output;\n            float *attn_lse;\n            int layer_idx;\n            int generate_token_idx;\n            int q_len;\n            int batch_size;\n            int max_block_num;\n            int *block_table;\n            int *cache_seqlens;\n            int topk;\n            int local;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(\n                
&KVCache::attn_with_kvcache, args_->kv_cache, args_->q_in,\n                args_->k_in, args_->v_in, args_->output, args_->attn_lse,\n                args_->layer_idx, args_->generate_token_idx, args_->q_len,\n                args_->batch_size, args_->max_block_num, args_->block_table,\n                args_->cache_seqlens, args_->topk, args_->local);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(KVCache &kv_cache, intptr_t q_in, intptr_t k_in,\n                           intptr_t v_in, intptr_t output, intptr_t attn_lse,\n                           int layer_idx, int generate_token_idx, int q_len,\n                           int batch_size, int max_block_num,\n                           intptr_t block_table, intptr_t cache_seqlens,\n                           int topk, int local) {\n            Args *args = new Args{nullptr,\n                                  &kv_cache,\n                                  (const ggml_fp16_t *)q_in,\n                                  (const ggml_fp16_t *)k_in,\n                                  (const ggml_fp16_t *)v_in,\n                                  (ggml_fp16_t *)output,\n                                  (float *)attn_lse,\n                                  layer_idx,\n                                  generate_token_idx,\n                                  q_len,\n                                  batch_size,\n                                  max_block_num,\n                                  (int *)block_table,\n                                  (int *)cache_seqlens,\n                                  topk,\n                                  local};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n\n    class ClearImportanceAllLayersBindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            KVCache *kv_cache;\n            int *block_table;\n            int *cache_seqlens;\n            int batch_size;\n        
    int max_block_num;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(&KVCache::clear_importance_all_layers,\n                                     args_->kv_cache, args_->block_table,\n                                     args_->cache_seqlens, args_->batch_size,\n                                     args_->max_block_num);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(KVCache &kv_cache, intptr_t block_table,\n                           intptr_t cache_seqlens, int batch_size,\n                           int max_block_num) {\n            Args *args = new Args{nullptr,\n                                  &kv_cache,\n                                  (int *)block_table,\n                                  (int *)cache_seqlens,\n                                  batch_size,\n                                  max_block_num};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n\n    class CalcAnchorAllLayersBindinds {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            KVCache *kv_cache;\n            int *block_table;\n            int *cache_seqlens;\n            int batch_size;\n            int max_block_num;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(&KVCache::calc_anchor_all_layers,\n                                     args_->kv_cache, args_->block_table,\n                                     args_->cache_seqlens, args_->batch_size,\n                                     args_->max_block_num);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(KVCache &kv_cache, intptr_t block_table,\n                           intptr_t cache_seqlens, int batch_size,\n                           int max_block_num) {\n            Args *args = new Args{nullptr,\n                   
               &kv_cache,\n                                  (int *)block_table,\n                                  (int *)cache_seqlens,\n                                  batch_size,\n                                  max_block_num};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n\n    class LoadKVCacheBindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            KVCache *kv_cache;\n            std::string tensor_file_path;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(&KVCache::load_kvcache, args_->kv_cache,\n                                     args_->tensor_file_path);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(KVCache &kv_cache, std::string tensor_file_path) {\n            Args *args =\n                new Args{nullptr, &kv_cache, (std::string)tensor_file_path};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n    class DumpKVCacheBindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            KVCache *kv_cache;\n            int *block_table;\n            int cache_total_len;\n            std::string tensor_file_path;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(&KVCache::dump_kvcache, args_->kv_cache,\n                                     args_->block_table, args_->cache_total_len,\n                                     args_->tensor_file_path);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(KVCache &kv_cache, intptr_t block_table,\n                           int cache_total_len, std::string tensor_file_path) {\n            Args *args =\n                new Args{nullptr, &kv_cache, (int *)block_table,\n                         cache_total_len, 
(std::string)tensor_file_path};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n};\n\nclass LinearBindings {\n  public:\n    class WarmUpBindinds {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            Linear *linear;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(&Linear::warm_up, args_->linear);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(Linear &linear) {\n            Args *args = new Args{nullptr, &linear};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n    class ForwardBindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            Linear *linear;\n            int qlen;\n            const void *input;\n            void *output;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(&Linear::forward, args_->linear,\n                                     args_->qlen, args_->input, args_->output);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(Linear &linear, int qlen, intptr_t input,\n                           intptr_t output) {\n            Args *args = new Args{nullptr, &linear, qlen, (const void *)input,\n                                  (void *)output};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n};\n\nclass MLPBindings {\n  public:\n    class WarmUpBindinds {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            MLP *mlp;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(&MLP::warm_up, args_->mlp);\n        }\n        static std::pair<intptr_t, intptr_t> cpuinfer_interface(MLP &mlp) {\n            Args *args = new 
Args{nullptr, &mlp};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n    class ForwardBindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            MLP *mlp;\n            int qlen;\n            const void *input;\n            void *output;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(&MLP::forward, args_->mlp, args_->qlen,\n                                     args_->input, args_->output);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(MLP &mlp, int qlen, intptr_t input,\n                           intptr_t output) {\n            Args *args = new Args{nullptr, &mlp, qlen, (const void *)input,\n                                  (void *)output};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n};\n\nclass MOEBindings {\n  public:\n    class WarmUpBindinds {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            MOE *moe;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(&MOE::warm_up, args_->moe);\n        }\n        static std::pair<intptr_t, intptr_t> cpuinfer_interface(MOE &moe) {\n            Args *args = new Args{nullptr, &moe};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n    class ForwardBindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            MOE *moe;\n            int qlen;\n            int k;\n            const uint64_t *expert_ids;\n            const float *weights;\n            const void *input;\n            void *output;\n            int *batch_size_tensor;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(\n                &MOE::forward, 
args_->moe, args_->qlen, args_->k,\n                args_->expert_ids, args_->weights, args_->input, args_->output, args_->batch_size_tensor);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(MOE &moe, int qlen, int k, intptr_t expert_ids,\n                           intptr_t weights, intptr_t input, intptr_t output, intptr_t batch_size_tensor) {\n            Args *args = new Args{nullptr,\n                                  &moe,\n                                  qlen,\n                                  k,\n                                  (const uint64_t *)expert_ids,\n                                  (const float *)weights,\n                                  (const void *)input,\n                                  (void *)output,\n                                  (int *)batch_size_tensor};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n};\n\n\n#if defined(__x86_64__) && defined(__HAS_AVX512F__) && defined(__HAS_AMX__)\ntemplate<class T>\nclass AMX_MOEBindings {\n  public:\n    class WarmUpBindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            AMX_MOE<T> *moe;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(&AMX_MOE<T>::warm_up, args_->moe);\n        }\n        static std::pair<intptr_t, intptr_t> cpuinfer_interface(AMX_MOE<T> &moe) {\n            Args *args = new Args{nullptr, &moe};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n    class LoadWeightsBindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            AMX_MOE<T> *moe;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(&AMX_MOE<T>::load_weights, args_->moe);\n        }\n        static std::pair<intptr_t, intptr_t> cpuinfer_interface(AMX_MOE<T> 
&moe) {\n            Args *args = new Args{nullptr, &moe};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n    class ForwardBindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            AMX_MOE<T> *moe;\n            int qlen;\n            int k;\n            const uint64_t *expert_ids;\n            const float *weights;\n            const void *input;\n            void *output;\n            int *batch_size_tensor;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(\n                &AMX_MOE<T>::forward, args_->moe, args_->qlen, args_->k,\n                args_->expert_ids, args_->weights, args_->input, args_->output, args_->batch_size_tensor);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(AMX_MOE<T> &moe, int qlen, int k, intptr_t expert_ids,\n                        intptr_t weights, intptr_t input, intptr_t output, intptr_t batch_size_tensor) {\n            Args *args = new Args{nullptr,\n                                &moe,\n                                qlen,\n                                k,\n                                (const uint64_t *)expert_ids,\n                                (const float *)weights,\n                                (const void *)input,\n                                (void *)output,\n                                (int *)batch_size_tensor};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n};\n#endif\n\nPYBIND11_MODULE(cpuinfer_ext, m) {\n    py::class_<CPUInfer>(m, \"CPUInfer\")\n        .def(py::init<int>())\n        .def(\"submit\", &CPUInfer::submit)\n        .def(\"submit_with_cuda_stream\", &CPUInfer::submit_with_cuda_stream)\n        .def(\"sync\", &CPUInfer::sync)\n        .def(\"sync_with_cuda_stream\", &CPUInfer::sync_with_cuda_stream);\n\n    auto linear_module = 
m.def_submodule(\"linear\");\n    py::class_<LinearConfig>(linear_module, \"LinearConfig\")\n        .def(py::init([](int hidden_size, int intermediate_size, int stride,\n                         int group_max_len, intptr_t proj, int proj_type,\n                         int hidden_type) {\n            return LinearConfig(hidden_size, intermediate_size, stride,\n                                group_max_len, (void *)proj,\n                                (ggml_type)proj_type, (ggml_type)hidden_type);\n        }));\n    py::class_<Linear>(linear_module, \"Linear\")\n        .def(py::init<LinearConfig>())\n        .def(\"warm_up\", &LinearBindings::WarmUpBindinds::cpuinfer_interface)\n        .def(\"forward\", &LinearBindings::ForwardBindings::cpuinfer_interface);\n\n    auto mlp_module = m.def_submodule(\"mlp\");\n    py::class_<MLPConfig>(mlp_module, \"MLPConfig\")\n        .def(py::init([](int hidden_size, int intermediate_size, int stride,\n                         int group_max_len, intptr_t gate_proj,\n                         intptr_t up_proj, intptr_t down_proj, int gate_type,\n                         int up_type, int down_type, int hidden_type) {\n            return MLPConfig(hidden_size, intermediate_size, stride,\n                             group_max_len, (void *)gate_proj, (void *)up_proj,\n                             (void *)down_proj, (ggml_type)gate_type,\n                             (ggml_type)up_type, (ggml_type)down_type,\n                             (ggml_type)hidden_type);\n        }));\n    py::class_<MLP>(mlp_module, \"MLP\")\n        .def(py::init<MLPConfig>())\n        .def(\"warm_up\", &MLPBindings::WarmUpBindinds::cpuinfer_interface)\n        .def(\"forward\", &MLPBindings::ForwardBindings::cpuinfer_interface);\n\n    auto moe_module = m.def_submodule(\"moe\");\n    py::class_<MOEConfig>(moe_module, \"MOEConfig\")\n        .def(py::init([](int expert_num, int routed_expert_num, int hidden_size,\n                         int 
intermediate_size, int stride, int group_min_len,\n                         int group_max_len, bool use_silu, intptr_t gate_proj,\n                         intptr_t up_proj, intptr_t down_proj, int gate_type,\n                         int up_type, int down_type, int hidden_type) {\n            return MOEConfig(expert_num, routed_expert_num, hidden_size,\n                             intermediate_size, stride, group_min_len,\n                             group_max_len, use_silu, (void *)gate_proj, (void *)up_proj,\n                             (void *)down_proj, (ggml_type)gate_type,\n                             (ggml_type)up_type, (ggml_type)down_type,\n                             (ggml_type)hidden_type);\n        }));\n    py::class_<MOE>(moe_module, \"MOE\")\n        .def(py::init<MOEConfig>())\n        .def(\"warm_up\", &MOEBindings::WarmUpBindinds::cpuinfer_interface)\n        .def(\"forward\", &MOEBindings::ForwardBindings::cpuinfer_interface);\n\n\n    #if defined(__x86_64__) && defined(__HAS_AVX512F__) && defined(__HAS_AMX__)\n    py::class_<AMX_MOEConfig>(moe_module, \"AMX_MOEConfig\")\n        .def(py::init([](int expert_num, int routed_expert_num, int hidden_size,\n                         int intermediate_size,\n                         int max_len, bool use_silu, intptr_t gate_proj,\n                         intptr_t up_proj, intptr_t down_proj) {\n            return AMX_MOEConfig(expert_num, routed_expert_num, hidden_size,\n                                 intermediate_size, \n                                 max_len, use_silu, (void *)gate_proj,\n                                 (void *)up_proj, (void *)down_proj);\n        }));\n\n    py::class_<AMX_MOE<amx::GemmKernel224BF>>(moe_module, \"AMXBF16_MOE\")\n        .def(py::init<AMX_MOEConfig>())\n        .def(\"warm_up\", &AMX_MOEBindings<amx::GemmKernel224BF>::WarmUpBindings::cpuinfer_interface)\n        .def(\"load_weights\", 
&AMX_MOEBindings<amx::GemmKernel224BF>::LoadWeightsBindings::cpuinfer_interface)\n        .def(\"forward\", &AMX_MOEBindings<amx::GemmKernel224BF>::ForwardBindings::cpuinfer_interface);\n    py::class_<AMX_MOE<amx::GemmKernel224Int8>>(moe_module, \"AMXInt8_MOE\")\n        .def(py::init<AMX_MOEConfig>())\n        .def(\"warm_up\", &AMX_MOEBindings<amx::GemmKernel224Int8>::WarmUpBindings::cpuinfer_interface)\n        .def(\"load_weights\", &AMX_MOEBindings<amx::GemmKernel224Int8>::LoadWeightsBindings::cpuinfer_interface)\n        .def(\"forward\", &AMX_MOEBindings<amx::GemmKernel224Int8>::ForwardBindings::cpuinfer_interface);\n\n    #endif\n\n    auto kvcache_module = m.def_submodule(\"kvcache\");\n\n    py::enum_<AnchorType>(kvcache_module, \"AnchorType\")\n        .value(\"FIXED\", AnchorType::FIXED_ANCHOR)\n        .value(\"DYNAMIC\", AnchorType::DYNAMIC)\n        .value(\"QUEST\", AnchorType::QUEST)\n        .value(\"BLOCK_MAX\", AnchorType::BLOCK_MAX)\n        .value(\"BLOCK_MEAN\", AnchorType::BLOCK_MEAN);\n    py::enum_<ggml_type>(kvcache_module, \"ggml_type\")\n        .value(\"FP16\", ggml_type::GGML_TYPE_F16)\n        .value(\"FP32\", ggml_type::GGML_TYPE_F32)\n        .value(\"Q4_0\", ggml_type::GGML_TYPE_Q4_0)\n        .value(\"Q8_0\", ggml_type::GGML_TYPE_Q8_0);\n    py::enum_<RetrievalType>(kvcache_module, \"RetrievalType\")\n        .value(\"LAYER\", RetrievalType::LAYER)\n        .value(\"KVHEAD\", RetrievalType::KVHEAD)\n        .value(\"QHEAD\", RetrievalType::QHEAD);\n\n    py::class_<KVCacheConfig>(kvcache_module, \"KVCacheConfig\")\n        .def(py::init<int, int, int, int, int, int, AnchorType, ggml_type,\n                      RetrievalType, int, int, int, int, int, int>())\n        .def_readwrite(\"layer_num\", &KVCacheConfig::layer_num)\n        .def_readwrite(\"kv_head_num\", &KVCacheConfig::kv_head_num)\n        .def_readwrite(\"q_head_num\", &KVCacheConfig::q_head_num)\n        .def_readwrite(\"head_dim\", &KVCacheConfig::head_dim)\n       
 .def_readwrite(\"block_len\", &KVCacheConfig::block_len)\n        .def_readwrite(\"anchor_num\", &KVCacheConfig::anchor_num)\n        .def_readwrite(\"anchor_type\", &KVCacheConfig::anchor_type)\n        .def_readwrite(\"kv_type\", &KVCacheConfig::kv_type)\n        .def_readwrite(\"retrieval_type\", &KVCacheConfig::retrieval_type)\n        .def_readwrite(\"layer_step\", &KVCacheConfig::layer_step)\n        .def_readwrite(\"token_step\", &KVCacheConfig::token_step)\n        .def_readwrite(\"layer_offset\", &KVCacheConfig::layer_offset)\n        .def_readwrite(\"max_block_num\", &KVCacheConfig::max_block_num)\n        .def_readwrite(\"max_batch_size\", &KVCacheConfig::max_batch_size)\n        .def_readwrite(\"max_thread_num\", &KVCacheConfig::max_thread_num);\n    py::class_<KVCache>(kvcache_module, \"KVCache\")\n        .def(py::init<KVCacheConfig>())\n        .def(\"get_cache_total_len\", &KVCache::get_cache_total_len)\n        .def(\"update_cache_total_len\",\n             [](KVCache &kvcache, int cache_total_len) {\n                 kvcache.update_cache_total_len(cache_total_len);\n             })\n        .def(\"attn\", &KVCacheBindings::AttnBindings::cpuinfer_interface)\n        .def(\n            \"get_all_kvcache_one_layer\",\n            &KVCacheBindings::GetAllKVCacheOneLayerBindings::cpuinfer_interface)\n        .def(\"get_and_update_kvcache_fp16\",\n             &KVCacheBindings::GetAndUpdateKVCacheFp16Bindings::\n                 cpuinfer_interface)\n        .def(\"get_kvcache_fp16\",\n             &KVCacheBindings::GetKVCacheFp16Bindings::cpuinfer_interface)\n        .def(\"update_kvcache_fp16\",\n             &KVCacheBindings::UpdateKVCacheFp16Bindings::cpuinfer_interface)\n        .def(\"update_importance\",\n             &KVCacheBindings::UpdateImportanceBindings::cpuinfer_interface)\n        .def(\"attn_with_kvcache\",\n             &KVCacheBindings::AttnWithKVCacheBindings::cpuinfer_interface)\n        .def(\"clear_importance_all_layers\",\n       
      &KVCacheBindings::ClearImportanceAllLayersBindings::\n                 cpuinfer_interface)\n        .def(\"calc_anchor_all_layers\",\n             &KVCacheBindings::CalcAnchorAllLayersBindinds::cpuinfer_interface);\n}"
  },
  {
    "path": "archive/csrc/ktransformers_ext/operators/amx/la/amx.hpp",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2025-04-25 18:28:12\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2025-04-25 18:28:12\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n\n#pragma once\n#include <algorithm>\n#include <array>\n#include <cassert>\n#include <cstdint>\n#include <cstdio>\n#include <immintrin.h>\n#include <iostream>\n#include <random>\n#include <stdexcept>\n#include <stdlib.h>\n#include <sys/syscall.h>\n#include <unistd.h>\n#include <utility>\n\n#include \"utils.hpp\"\n#include <memory>\n\n#if (defined(_WIN32) || defined(_WIN64))\n#define RESTRICT __restrict\n#else\n#define RESTRICT __restrict__\n#endif\n\n#if (defined(_WIN32) || defined(_WIN64))\n#define ALWAYS_INLINE __forceinline\n#elif __has_attribute(always_inline) || defined(__GNUC__)\n#define ALWAYS_INLINE __attribute__((__always_inline__)) inline\n#else\n#define ALWAYS_INLINE inline\n#endif\n\nnamespace amx {\n\n#define ARCH_GET_XCOMP_PERM 0x1022\n#define ARCH_REQ_XCOMP_PERM 0x1023\n#define XFEATURE_XTILECFG 17\n#define XFEATURE_XTILEDATA 18\n\nconst int TMMCount = 8;\nconst int MaxTileHeight = 16;\nconst int MaxTileWidth = 64;\n\nconst int AMX_BLK_SIZE = 32;\n\n#define TMM0 0\n#define TMM1 1\n#define TMM2 2\n#define TMM3 3\n#define TMM4 4\n#define TMM5 5\n#define TMM6 6\n#define TMM7 7\n\n// Request permission to use AMX tile data once per thread and cache the\n// outcome, so that a failed request is not reported as success on later calls.\ninline bool enable_amx() {\n  static thread_local bool initialized = false;\n  static thread_local bool enabled = false;\n  if (initialized) {\n    return enabled;\n  }\n  initialized = true;\n\n  if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {\n    printf(\"\\n Failed to request XFEATURE_XTILEDATA permission \\n\\n\");\n    return false;\n  }\n  // printf(\"\\n TILE DATA USE SET - OK \\n\\n\");\n  enabled = true;\n  return true;\n}\n\nstruct alignas(64) TileConfig {\n  uint8_t palette;\n  uint8_t start_row;\n  std::array<uint8_t, 14> __0 = {};\n  std::array<uint16_t, 8> colsb;\n  std::array<uint8_t, 16> __1 = {};\n  std::array<uint8_t, 8> rows;\n  std::array<uint8_t, 8> __2 = 
{};\n\n  TileConfig() {\n    palette = 1;\n    start_row = 0;\n    for (int i = 0; i < 8; i++) {\n      set_row_col(i, 0, 0);\n    }\n  }\n\n  void set_row_col(int i, uint8_t row, uint16_t col) {\n    colsb[i] = col;\n    rows[i] = row;\n  }\n\n  void set_config() { _tile_loadconfig(this); }\n\n  static void load_data(int to, void *from, size_t stride) {\n    switch (to) {\n    case 0:\n      _tile_loadd(0, from, stride);\n      break;\n    case 1:\n      _tile_loadd(1, from, stride);\n      break;\n    case 2:\n      _tile_loadd(2, from, stride);\n      break;\n    case 3:\n      _tile_loadd(3, from, stride);\n      break;\n    case 4:\n      _tile_loadd(4, from, stride);\n      break;\n    case 5:\n      _tile_loadd(5, from, stride);\n      break;\n    case 6:\n      _tile_loadd(6, from, stride);\n      break;\n    case 7:\n      _tile_loadd(7, from, stride);\n      break;\n    default:\n      throw std::runtime_error(\"no such tile\");\n    }\n  }\n\n  static void store_data(int from, void *to, size_t stride) {\n    switch (from) {\n    case 0:\n      _tile_stored(0, to, stride);\n      break;\n    case 1:\n      _tile_stored(1, to, stride);\n      break;\n    case 2:\n      _tile_stored(2, to, stride);\n      break;\n    case 3:\n      _tile_stored(3, to, stride);\n      break;\n    case 4:\n      _tile_stored(4, to, stride);\n      break;\n    case 5:\n      _tile_stored(5, to, stride);\n      break;\n    case 6:\n      _tile_stored(6, to, stride);\n      break;\n    case 7:\n      _tile_stored(7, to, stride);\n      break;\n    default:\n      throw std::runtime_error(\"no such tile\");\n    }\n  }\n};\n\nstatic_assert(sizeof(TileConfig) == 64);\n\ninline void debug_tile(int t) {\n  printf(\"Tile %d\\n\", t);\n  uint8_t data[16][64] = {};\n  TileConfig::store_data(t, data, 64);\n  for (int i = 0; i < 16; i++) {\n    for (int j = 0; j < 64; j++) {\n      printf(\"%3d \", data[i][j]);\n    }\n    printf(\"\\n\");\n  }\n  printf(\"\\n\");\n}\n\ninline void 
debug_tiles(int to = 8) {\n  for (int i = 0; i < to; i++) {\n    debug_tile(i);\n  }\n}\n\ninline void debug_m512(__m512 x) {\n  float data[16];\n  _mm512_storeu_ps(data, x);\n  for (int i = 0; i < 16; i++) {\n    printf(\"%f \", data[i]);\n  }\n  printf(\"\\n\");\n}\n\n// transpose utils\ninline void transpose_16x16_32bit(__m512i *v) {\n  __m512i v1[16];\n  v1[0] = _mm512_unpacklo_epi32(v[0], v[1]);\n  v1[1] = _mm512_unpackhi_epi32(v[0], v[1]);\n  v1[2] = _mm512_unpacklo_epi32(v[2], v[3]);\n  v1[3] = _mm512_unpackhi_epi32(v[2], v[3]);\n  v1[4] = _mm512_unpacklo_epi32(v[4], v[5]);\n  v1[5] = _mm512_unpackhi_epi32(v[4], v[5]);\n  v1[6] = _mm512_unpacklo_epi32(v[6], v[7]);\n  v1[7] = _mm512_unpackhi_epi32(v[6], v[7]);\n  v1[8] = _mm512_unpacklo_epi32(v[8], v[9]);\n  v1[9] = _mm512_unpackhi_epi32(v[8], v[9]);\n  v1[10] = _mm512_unpacklo_epi32(v[10], v[11]);\n  v1[11] = _mm512_unpackhi_epi32(v[10], v[11]);\n  v1[12] = _mm512_unpacklo_epi32(v[12], v[13]);\n  v1[13] = _mm512_unpackhi_epi32(v[12], v[13]);\n  v1[14] = _mm512_unpacklo_epi32(v[14], v[15]);\n  v1[15] = _mm512_unpackhi_epi32(v[14], v[15]);\n\n  v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]);\n  v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]);\n  v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]);\n  v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]);\n  v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]);\n  v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]);\n  v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]);\n  v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]);\n  v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]);\n  v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]);\n  v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]);\n  v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]);\n  v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]);\n  v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]);\n  v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]);\n  v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]);\n\n  v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88);\n  v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88);\n  v1[2] = 
_mm512_shuffle_i32x4(v[2], v[6], 0x88);\n  v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88);\n  v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd);\n  v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd);\n  v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd);\n  v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd);\n  v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88);\n  v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88);\n  v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88);\n  v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88);\n  v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd);\n  v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd);\n  v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd);\n  v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd);\n\n  v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88);\n  v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88);\n  v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88);\n  v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88);\n  v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88);\n  v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88);\n  v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88);\n  v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88);\n  v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd);\n  v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd);\n  v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd);\n  v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd);\n  v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd);\n  v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd);\n  v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd);\n  v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd);\n}\n\n/*\n  Transpose 16x16 32-bit elements\n  Note that v must be 64 byte aligned\n*/\ninline void transpose_16x16_32bit(__m512i *v, size_t stride) {\n  assert(reinterpret_cast<intptr_t>(v) % 64 == 0 && \"v must be 64 aligned\");\n\n  auto stride_v = [=](int i) { return offset_pointer(v, i * stride); };\n  __m512i v1[16];\n\n  v1[0] = _mm512_unpacklo_epi32(*stride_v(0), *stride_v(1));\n  v1[1] = 
_mm512_unpackhi_epi32(*stride_v(0), *stride_v(1));\n  v1[2] = _mm512_unpacklo_epi32(*stride_v(2), *stride_v(3));\n  v1[3] = _mm512_unpackhi_epi32(*stride_v(2), *stride_v(3));\n  v1[4] = _mm512_unpacklo_epi32(*stride_v(4), *stride_v(5));\n  v1[5] = _mm512_unpackhi_epi32(*stride_v(4), *stride_v(5));\n  v1[6] = _mm512_unpacklo_epi32(*stride_v(6), *stride_v(7));\n  v1[7] = _mm512_unpackhi_epi32(*stride_v(6), *stride_v(7));\n  v1[8] = _mm512_unpacklo_epi32(*stride_v(8), *stride_v(9));\n  v1[9] = _mm512_unpackhi_epi32(*stride_v(8), *stride_v(9));\n  v1[10] = _mm512_unpacklo_epi32(*stride_v(10), *stride_v(11));\n  v1[11] = _mm512_unpackhi_epi32(*stride_v(10), *stride_v(11));\n  v1[12] = _mm512_unpacklo_epi32(*stride_v(12), *stride_v(13));\n  v1[13] = _mm512_unpackhi_epi32(*stride_v(12), *stride_v(13));\n  v1[14] = _mm512_unpacklo_epi32(*stride_v(14), *stride_v(15));\n  v1[15] = _mm512_unpackhi_epi32(*stride_v(14), *stride_v(15));\n\n  *stride_v(0) = _mm512_unpacklo_epi64(v1[0], v1[2]);\n  *stride_v(1) = _mm512_unpackhi_epi64(v1[0], v1[2]);\n  *stride_v(2) = _mm512_unpacklo_epi64(v1[1], v1[3]);\n  *stride_v(3) = _mm512_unpackhi_epi64(v1[1], v1[3]);\n  *stride_v(4) = _mm512_unpacklo_epi64(v1[4], v1[6]);\n  *stride_v(5) = _mm512_unpackhi_epi64(v1[4], v1[6]);\n  *stride_v(6) = _mm512_unpacklo_epi64(v1[5], v1[7]);\n  *stride_v(7) = _mm512_unpackhi_epi64(v1[5], v1[7]);\n  *stride_v(8) = _mm512_unpacklo_epi64(v1[8], v1[10]);\n  *stride_v(9) = _mm512_unpackhi_epi64(v1[8], v1[10]);\n  *stride_v(10) = _mm512_unpacklo_epi64(v1[9], v1[11]);\n  *stride_v(11) = _mm512_unpackhi_epi64(v1[9], v1[11]);\n  *stride_v(12) = _mm512_unpacklo_epi64(v1[12], v1[14]);\n  *stride_v(13) = _mm512_unpackhi_epi64(v1[12], v1[14]);\n  *stride_v(14) = _mm512_unpacklo_epi64(v1[13], v1[15]);\n  *stride_v(15) = _mm512_unpackhi_epi64(v1[13], v1[15]);\n\n  v1[0] = _mm512_shuffle_i32x4(*stride_v(0), *stride_v(4), 0x88);\n  v1[1] = _mm512_shuffle_i32x4(*stride_v(1), *stride_v(5), 0x88);\n  v1[2] = 
_mm512_shuffle_i32x4(*stride_v(2), *stride_v(6), 0x88);\n  v1[3] = _mm512_shuffle_i32x4(*stride_v(3), *stride_v(7), 0x88);\n  v1[4] = _mm512_shuffle_i32x4(*stride_v(0), *stride_v(4), 0xdd);\n  v1[5] = _mm512_shuffle_i32x4(*stride_v(1), *stride_v(5), 0xdd);\n  v1[6] = _mm512_shuffle_i32x4(*stride_v(2), *stride_v(6), 0xdd);\n  v1[7] = _mm512_shuffle_i32x4(*stride_v(3), *stride_v(7), 0xdd);\n  v1[8] = _mm512_shuffle_i32x4(*stride_v(8), *stride_v(12), 0x88);\n  v1[9] = _mm512_shuffle_i32x4(*stride_v(9), *stride_v(13), 0x88);\n  v1[10] = _mm512_shuffle_i32x4(*stride_v(10), *stride_v(14), 0x88);\n  v1[11] = _mm512_shuffle_i32x4(*stride_v(11), *stride_v(15), 0x88);\n  v1[12] = _mm512_shuffle_i32x4(*stride_v(8), *stride_v(12), 0xdd);\n  v1[13] = _mm512_shuffle_i32x4(*stride_v(9), *stride_v(13), 0xdd);\n  v1[14] = _mm512_shuffle_i32x4(*stride_v(10), *stride_v(14), 0xdd);\n  v1[15] = _mm512_shuffle_i32x4(*stride_v(11), *stride_v(15), 0xdd);\n\n  *stride_v(0) = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88);\n  *stride_v(1) = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88);\n  *stride_v(2) = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88);\n  *stride_v(3) = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88);\n  *stride_v(4) = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88);\n  *stride_v(5) = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88);\n  *stride_v(6) = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88);\n  *stride_v(7) = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88);\n  *stride_v(8) = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd);\n  *stride_v(9) = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd);\n  *stride_v(10) = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd);\n  *stride_v(11) = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd);\n  *stride_v(12) = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd);\n  *stride_v(13) = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd);\n  *stride_v(14) = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd);\n  *stride_v(15) = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd);\n}\n\nstruct GemmKernel224BF {\n  using dt = ggml_bf16_t;\n  using 
output_t = float;\n  static const int TILE_M = 16;\n  static const int TILE_K = 32;\n  static const int TILE_N = 16;\n  static const int VNNI_BLK = 2;\n\n  static const int M_STEP = TILE_M * 2;\n  static const int N_STEP = TILE_N * 2;\n  static const int K_STEP = TILE_K;\n\n  static inline const int N_BLOCK = 256;\n  static inline const int K_BLOCK = 1792;\n\n  static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }\n\n  static std::pair<int, int> split_range_n(int n, int ith, int nth) {\n    int n_start = N_BLOCK * ith;\n    int n_end = std::min(n, N_BLOCK * (ith + 1));\n    return {n_start, n_end};\n  }\n\n  static void config() {\n    enable_amx();\n    TileConfig tile_config;\n\n    // size is 16 x 32\n    for (int i = 0; i < 2; i++)\n      tile_config.set_row_col(i, TILE_M, TILE_K * sizeof(dt));\n\n    // size is 16 x 32\n    for (int i = 2; i < 4; i++)\n      tile_config.set_row_col(i, TILE_K / VNNI_BLK, TILE_N * VNNI_BLK * sizeof(dt));\n\n    // size is 16 x 16\n    for (int i = 4; i < 8; i++)\n      tile_config.set_row_col(i, TILE_M, TILE_N * sizeof(output_t));\n\n    tile_config.set_config();\n  }\n\n  static void load_a(dt *a, size_t lda) {\n    _tile_loadd(0, a, lda);\n    _tile_loadd(1, offset_pointer(a, lda * TILE_M), lda);\n  }\n\n  static void load_b(dt *b, size_t ldb) {\n    _tile_loadd(2, b, ldb);\n    _tile_loadd(3, offset_pointer(b, ldb * TILE_N), ldb);\n  }\n\n  static void clean_c() {\n    _tile_zero(4);\n    _tile_zero(5);\n    _tile_zero(6);\n    _tile_zero(7);\n  }\n\n  static void load_c(output_t *c, size_t ldc) {\n    _tile_loadd(4, c, ldc);\n    _tile_loadd(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);\n    _tile_loadd(6, offset_pointer(c, ldc * TILE_M), ldc);\n    _tile_loadd(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc);\n  }\n\n  static void store_c(output_t *c, size_t ldc) {\n    _tile_stored(4, c, ldc);\n    _tile_stored(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);\n    
_tile_stored(6, offset_pointer(c, ldc * TILE_M), ldc);\n    _tile_stored(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc);\n  }\n\n  static void run_tile() {\n    _tile_dpbf16ps(4, 0, 2);\n    _tile_dpbf16ps(5, 0, 3);\n    _tile_dpbf16ps(6, 1, 2);\n    _tile_dpbf16ps(7, 1, 3);\n  }\n\n  struct BufferA {\n    ggml_bf16_t *a;\n    int max_m, k;\n\n    static size_t required_size(int max_m, int k) { return max_m * k * sizeof(ggml_bf16_t); }\n\n    BufferA(int max_m, int k, void *ptr) : max_m(max_m), k(k) {\n      assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n      assert(max_m % M_STEP == 0);\n      assert(k % K_STEP == 0);\n      a = reinterpret_cast<ggml_bf16_t *>(ptr);\n    }\n\n    void from_mat(int m, ggml_bf16_t *src, int ith, int nth) {\n      assert(m <= max_m);\n      assert(ith == 0 && nth == 1);\n      int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n      for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {\n        for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {\n          int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n          for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {\n            for (int i = 0; i < M_STEP && m_begin + i < m; i++) {\n              __m512i *s = (__m512i *)(src + (m_begin + i) * k + k_block_begin + k_begin);\n              __m512i *d = (__m512i *)(a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP +\n                                       i * K_STEP);\n              avx512_copy_32xbf16(s, d);\n            }\n          }\n        }\n      }\n    }\n\n    ggml_bf16_t *get_submat(int m, int k, int m_begin, int k_begin) {\n      int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n      int k_block_begin = k_begin / K_BLOCK * K_BLOCK;\n      k_begin -= k_block_begin;\n      int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n      return a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * 
M_STEP;\n    }\n  };\n\n  struct BufferB {\n    ggml_bf16_t *b;\n    int n, k;\n\n    static size_t required_size(int n, int k) { return n * k * sizeof(ggml_bf16_t); }\n\n    BufferB(int n, int k, void *ptr) : n(n), k(k) {\n      assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n      assert(n % N_STEP == 0);\n      assert(k % K_STEP == 0);\n      b = reinterpret_cast<ggml_bf16_t *>(ptr);\n    }\n\n    void from_mat(ggml_bf16_t *src, int ith, int nth) {\n      auto [n_start, n_end] = split_range_n(n, ith, nth);\n      int n_block_begin = n_start;\n      int n_block_size = n_end - n_block_begin;\n      for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n        for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {\n          int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n          for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {\n            for (int i = 0; i < N_STEP; i++) {\n              __m512i *s = (__m512i *)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin);\n              __m512i *d = (__m512i *)(b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size +\n                                       k_begin * N_STEP + i * K_STEP);\n              avx512_copy_32xbf16(s, d);\n            }\n            transpose_16x16_32bit((__m512i *)(b + n_block_begin * k + k_block_begin * n_block_size +\n                                              n_begin * k_block_size + k_begin * N_STEP));\n            transpose_16x16_32bit((__m512i *)(b + n_block_begin * k + k_block_begin * n_block_size +\n                                              n_begin * k_block_size + k_begin * N_STEP + TILE_N * K_STEP));\n          }\n        }\n      }\n    }\n\n    ggml_bf16_t *get_submat(int n, int k, int n_begin, int k_begin) {\n      int n_block_begin = n_begin / N_BLOCK * N_BLOCK;\n      n_begin -= n_block_begin;\n      int n_block_size = std::min(N_BLOCK, n - n_block_begin);\n      int 
k_block_begin = k_begin / K_BLOCK * K_BLOCK;\n      k_begin -= k_block_begin;\n      int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n      return b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP;\n    }\n  };\n\n  struct BufferC {\n    float *c;\n    int max_m, n;\n\n    static size_t required_size(int max_m, int n) { return max_m * n * sizeof(float); }\n\n    BufferC(int max_m, int n, void *ptr) : max_m(max_m), n(n) {\n      assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n      assert(max_m % M_STEP == 0);\n      assert(n % N_STEP == 0);\n      c = reinterpret_cast<float *>(ptr);\n    }\n\n    void to_mat(int m, ggml_bf16_t *dst, int ith, int nth) {\n      assert(m <= max_m);\n      auto [n_start, n_end] = split_range_n(n, ith, nth);\n      int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n      int n_block_begin = n_start;\n      int n_block_size = n_end - n_block_begin;\n      for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {\n        for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n          for (int i = 0; i < M_STEP && m_begin + i < m; i++) {\n            __m512 *x0 =\n                (__m512 *)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP);\n            __m512 *x1 = (__m512 *)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP +\n                                    i * N_STEP + 16);\n            avx512_32xfp32_to_32xbf16(x0, x1, (__m512i *)(dst + (m_begin + i) * n + n_block_begin + n_begin));\n          }\n        }\n      }\n    }\n\n    float *get_submat(int m, int n, int m_begin, int n_begin) {\n      int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n      int n_block_begin = n_begin / N_BLOCK * N_BLOCK;\n      int n_block_size = std::min(N_BLOCK, n - n_block_begin);\n      n_begin -= n_block_begin;\n      return c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * 
M_STEP;\n    }\n  };\n};\n\nstruct GemmKernel224Int8 {\n  using dt = int8_t;\n  using output_t = int32_t;\n  static const int TILE_M = 16;\n  static const int TILE_K = 64;\n  static const int TILE_N = 16;\n  static const int VNNI_BLK = 4;\n\n  static const int M_STEP = TILE_M * 2;\n  static const int N_STEP = TILE_N * 2;\n  static const int K_STEP = TILE_K;\n\n  static inline const int N_BLOCK = 256;\n  static inline const int K_BLOCK = 3584;\n\n  static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }\n\n  static std::pair<int, int> split_range_n(int n, int ith, int nth) {\n    int n_start = N_BLOCK * ith;\n    int n_end = std::min(n, N_BLOCK * (ith + 1));\n    return {n_start, n_end};\n  }\n\n  static void config() {\n    enable_amx();\n    TileConfig tile_config;\n\n    // size is 16 x 64\n    for (int i = 0; i < 2; i++)\n      tile_config.set_row_col(i, TILE_M, TILE_K * sizeof(dt));\n\n    // size is 16 x 64\n    for (int i = 2; i < 4; i++)\n      tile_config.set_row_col(i, TILE_K / VNNI_BLK, TILE_N * VNNI_BLK * sizeof(dt));\n\n    // size is 16 x 16\n    for (int i = 4; i < 8; i++)\n      tile_config.set_row_col(i, TILE_M, TILE_N * sizeof(output_t));\n\n    tile_config.set_config();\n  }\n\n  static void load_a(dt *a, size_t lda) {\n    _tile_loadd(0, a, lda);\n    _tile_loadd(1, offset_pointer(a, lda * TILE_M), lda);\n  }\n\n  static void load_b(dt *b, size_t ldb) {\n    _tile_loadd(2, b, ldb);\n    _tile_loadd(3, offset_pointer(b, ldb * TILE_N), ldb);\n  }\n\n  static void clean_c() {\n    _tile_zero(4);\n    _tile_zero(5);\n    _tile_zero(6);\n    _tile_zero(7);\n  }\n\n  static void load_c(output_t *c, size_t ldc) {\n    _tile_loadd(4, c, ldc);\n    _tile_loadd(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);\n    _tile_loadd(6, offset_pointer(c, ldc * TILE_M), ldc);\n    _tile_loadd(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc);\n  }\n\n  static void store_c(output_t *c, size_t ldc) {\n    _tile_stored(4, c, 
ldc);\n    _tile_stored(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);\n    _tile_stored(6, offset_pointer(c, ldc * TILE_M), ldc);\n    _tile_stored(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc);\n  }\n\n  static void run_tile() {\n    _tile_dpbssd(4, 0, 2);\n    _tile_dpbssd(5, 0, 3);\n    _tile_dpbssd(6, 1, 2);\n    _tile_dpbssd(7, 1, 3);\n  }\n\n  struct BufferA {\n    int8_t *a;\n    float *d;\n    int max_m, k;\n\n    static size_t required_size(int max_m, int k) { return max_m * k * sizeof(int8_t) + max_m * sizeof(float); }\n\n    BufferA(int max_m, int k, void *ptr) : max_m(max_m), k(k) {\n      assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n      assert(max_m % M_STEP == 0);\n      assert(k % K_STEP == 0);\n      a = reinterpret_cast<int8_t *>(ptr);\n      d = reinterpret_cast<float *>(a + max_m * k);\n    }\n\n    void from_mat(int m, ggml_bf16_t *src, int ith, int nth) {\n      assert(m <= max_m);\n      assert(ith == 0 && nth == 1);\n      for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {\n        for (int i = 0; i < M_STEP && m_begin + i < m; i++) {\n          float amax = 0.0f;\n          for (int j = 0; j < k; j += 32) {\n            __m512 f0, f1;\n            avx512_32xbf16_to_32xfp32((__m512i *)(src + (m_begin + i) * k + j), &f0, &f1);\n            amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f0)));\n            amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f1)));\n          }\n          d[m_begin + i] = amax / ((1 << 7) - 1);\n        }\n      }\n      int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n      for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {\n        for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {\n          int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n          for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {\n            for (int i = 0; i < M_STEP && m_begin + i < m; i++) {\n              __m512 id = 
_mm512_set1_ps(d[m_begin + i] ? 1.0f / d[m_begin + i] : 0.0f);\n              int8_t *dst = a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP + i * K_STEP;\n              __m512 f0, f1, f2, f3;\n              avx512_32xbf16_to_32xfp32((__m512i *)(src + (m_begin + i) * k + k_block_begin + k_begin), &f0, &f1);\n              avx512_32xbf16_to_32xfp32((__m512i *)(src + (m_begin + i) * k + k_block_begin + k_begin) + 1, &f2, &f3);\n              __m512i i0 = _mm512_cvtps_epi32(_mm512_mul_ps(f0, id));\n              __m512i i1 = _mm512_cvtps_epi32(_mm512_mul_ps(f1, id));\n              __m512i i2 = _mm512_cvtps_epi32(_mm512_mul_ps(f2, id));\n              __m512i i3 = _mm512_cvtps_epi32(_mm512_mul_ps(f3, id));\n              __m128i s0 = _mm512_cvtsepi32_epi8(i0);\n              __m128i s1 = _mm512_cvtsepi32_epi8(i1);\n              __m128i s2 = _mm512_cvtsepi32_epi8(i2);\n              __m128i s3 = _mm512_cvtsepi32_epi8(i3);\n              _mm_storeu_si128((__m128i *)dst, s0);\n              _mm_storeu_si128((__m128i *)(dst + 16), s1);\n              _mm_storeu_si128((__m128i *)(dst + 32), s2);\n              _mm_storeu_si128((__m128i *)(dst + 48), s3);\n            }\n          }\n        }\n      }\n    }\n\n    int8_t *get_submat(int m, int k, int m_begin, int k_begin) {\n      int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n      int k_block_begin = k_begin / K_BLOCK * K_BLOCK;\n      k_begin -= k_block_begin;\n      int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n      return a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP;\n    }\n\n    float *get_scale(int m, int m_begin) { return d + m_begin; }\n  };\n\n  struct BufferB {\n    int8_t *b;\n    float *d;\n    int n, k;\n\n    static size_t required_size(int n, int k) { return n * k * sizeof(int8_t) + n * sizeof(float); }\n\n    BufferB(int n, int k, void *ptr) : n(n), k(k) {\n      assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n      
assert(n % N_STEP == 0);\n      assert(k % K_STEP == 0);\n      b = reinterpret_cast<int8_t *>(ptr);\n      d = reinterpret_cast<float *>(b + n * k);\n    }\n\n    void from_mat(ggml_bf16_t *src, int ith, int nth) {\n      auto [n_start, n_end] = split_range_n(n, ith, nth);\n      int n_block_begin = n_start;\n      int n_block_size = n_end - n_block_begin;\n      for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n        for (int i = 0; i < N_STEP; i++) {\n          float amax = 0.0f;\n          for (int j = 0; j < k; j += 32) {\n            __m512 f0, f1;\n            avx512_32xbf16_to_32xfp32((__m512i *)(src + (n_block_begin + n_begin + i) * k + j), &f0, &f1);\n            amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f0)));\n            amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f1)));\n          }\n          d[n_block_begin + n_begin + i] = amax / ((1 << 7) - 1);\n        }\n      }\n      for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n        for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {\n          int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n          for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {\n            for (int i = 0; i < N_STEP; i++) {\n              __m512 id = _mm512_set1_ps(d[n_block_begin + n_begin + i] ? 
1.0f / d[n_block_begin + n_begin + i] : 0.0f);\n              int8_t *dst = b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size +\n                            k_begin * N_STEP + i * K_STEP;\n              __m512 f0, f1, f2, f3;\n              avx512_32xbf16_to_32xfp32((__m512i *)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin),\n                                        &f0, &f1);\n              avx512_32xbf16_to_32xfp32(\n                  (__m512i *)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin) + 1, &f2, &f3);\n              __m512i i0 = _mm512_cvtps_epi32(_mm512_mul_ps(f0, id));\n              __m512i i1 = _mm512_cvtps_epi32(_mm512_mul_ps(f1, id));\n              __m512i i2 = _mm512_cvtps_epi32(_mm512_mul_ps(f2, id));\n              __m512i i3 = _mm512_cvtps_epi32(_mm512_mul_ps(f3, id));\n              __m128i s0 = _mm512_cvtsepi32_epi8(i0);\n              __m128i s1 = _mm512_cvtsepi32_epi8(i1);\n              __m128i s2 = _mm512_cvtsepi32_epi8(i2);\n              __m128i s3 = _mm512_cvtsepi32_epi8(i3);\n              _mm_storeu_si128((__m128i *)dst, s0);\n              _mm_storeu_si128((__m128i *)(dst + 16), s1);\n              _mm_storeu_si128((__m128i *)(dst + 32), s2);\n              _mm_storeu_si128((__m128i *)(dst + 48), s3);\n            }\n            transpose_16x16_32bit((__m512i *)(b + n_block_begin * k + k_block_begin * n_block_size +\n                                              n_begin * k_block_size + k_begin * N_STEP));\n            transpose_16x16_32bit((__m512i *)(b + n_block_begin * k + k_block_begin * n_block_size +\n                                              n_begin * k_block_size + k_begin * N_STEP + TILE_N * K_STEP));\n          }\n        }\n      }\n    }\n\n    int8_t *get_submat(int n, int k, int n_begin, int k_begin) {\n      int n_block_begin = n_begin / N_BLOCK * N_BLOCK;\n      n_begin -= n_block_begin;\n      int n_block_size = std::min(N_BLOCK, n - 
n_block_begin);\n      int k_block_begin = k_begin / K_BLOCK * K_BLOCK;\n      k_begin -= k_block_begin;\n      int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n      return b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP;\n    }\n\n    float *get_scale(int n, int n_begin) { return d + n_begin; }\n  };\n\n  struct BufferC {\n    float *c;\n    int max_m, n;\n\n    static size_t required_size(int max_m, int n) { return max_m * n * sizeof(float); }\n\n    BufferC(int max_m, int n, void *ptr) : max_m(max_m), n(n) {\n      assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n      assert(max_m % M_STEP == 0);\n      assert(n % N_STEP == 0);\n      c = reinterpret_cast<float *>(ptr);\n    }\n\n    void to_mat(int m, ggml_bf16_t *dst, int ith, int nth) {\n      assert(m <= max_m);\n      auto [n_start, n_end] = split_range_n(n, ith, nth);\n      int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n      int n_block_begin = n_start;\n      int n_block_size = n_end - n_block_begin;\n      for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {\n        for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n          for (int i = 0; i < M_STEP && m_begin + i < m; i++) {\n            __m512 *x0 =\n                (__m512 *)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP);\n            __m512 *x1 = (__m512 *)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP +\n                                    i * N_STEP + 16);\n            avx512_32xfp32_to_32xbf16(x0, x1, (__m512i *)(dst + (m_begin + i) * n + n_block_begin + n_begin));\n          }\n        }\n      }\n    }\n\n    float *get_submat(int m, int n, int m_begin, int n_begin) {\n      int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n      int n_block_begin = n_begin / N_BLOCK * N_BLOCK;\n      int n_block_size = std::min(N_BLOCK, n - n_block_begin);\n      n_begin -= 
n_block_begin;\n      return c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP;\n    }\n  };\n};\n\ninline void mat_mul(int m, int n, int k, std::shared_ptr<GemmKernel224BF::BufferA> ba,\n                    std::shared_ptr<GemmKernel224BF::BufferB> bb, std::shared_ptr<GemmKernel224BF::BufferC> bc, int ith,\n                    int nth, bool use_amx) {\n  using K = GemmKernel224BF;\n  assert(n % K::N_STEP == 0);\n  assert(k % K::K_STEP == 0);\n\n  auto [n_start, n_end] = K::split_range_n(n, ith, nth);\n\n  for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K::K_BLOCK) {\n    for (int m_begin = 0; m_begin < m; m_begin += K::M_STEP) {\n      for (int n_begin = n_start; n_begin < n_end; n_begin += K::N_STEP) {\n\n        float *c = bc->get_submat(m, n, m_begin, n_begin);\n        if (!use_amx) {\n          __m512 *c512 = (__m512 *)c;\n          if (k_block_begin == 0) {\n            for (int m_i = 0; m_i < m && m_i < K::M_STEP; m_i++) {\n              c512[m_i * 2] = _mm512_setzero_ps();\n              c512[m_i * 2 + 1] = _mm512_setzero_ps();\n            }\n          }\n\n          for (int k_begin = 0; k_begin < K::K_BLOCK && k_block_begin + k_begin < k; k_begin += K::K_STEP) {\n            int32_t *a32 = (int32_t *)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);\n            __m512bh *b512 = (__m512bh *)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);\n            for (int m_i = 0; m_i < m && m_i < K::M_STEP; m_i++) {\n              for (int k_i = 0; k_i < 16; k_i++) {\n                __m512bh ma = (__m512bh)_mm512_set1_epi32(a32[m_i * 16 + k_i]);\n                for (int n_i = 0; n_i < 2; n_i++) {\n                  c512[m_i * 2 + n_i] = _mm512_dpbf16_ps(c512[m_i * 2 + n_i], ma, b512[n_i * 16 + k_i]);\n                }\n              }\n            }\n          }\n\n        } else {\n          if (k_block_begin == 0) {\n            K::clean_c();\n          } else {\n            K::load_c(c, K::N_STEP * 
sizeof(float));\n          }\n          for (int k_begin = 0; k_begin < K::K_BLOCK && k_block_begin + k_begin < k; k_begin += K::K_STEP) {\n            K::load_a(ba->get_submat(m, k, m_begin, k_block_begin + k_begin), K::K_STEP * sizeof(ggml_bf16_t));\n            K::load_b(bb->get_submat(n, k, n_begin, k_block_begin + k_begin), K::K_STEP * sizeof(ggml_bf16_t));\n            K::run_tile();\n          }\n          K::store_c(c, K::N_STEP * sizeof(float));\n        }\n      }\n    }\n  }\n}\n\ninline __m512i _mm512_dpbssd_epi32(__m512i src, __m512i a, __m512i b) {\n  __m256i a_lo = _mm512_extracti64x4_epi64(a, 0);\n  __m256i a_hi = _mm512_extracti64x4_epi64(a, 1);\n  __m256i b_lo = _mm512_extracti64x4_epi64(b, 0);\n  __m256i b_hi = _mm512_extracti64x4_epi64(b, 1);\n\n  b_lo = _mm256_sign_epi8(b_lo, a_lo);\n  b_hi = _mm256_sign_epi8(b_hi, a_hi);\n\n  b = _mm512_inserti64x4(b, b_lo, 0);\n  b = _mm512_inserti64x4(b, b_hi, 1);\n\n  a = _mm512_abs_epi8(a);\n\n  return _mm512_dpbusd_epi32(src, a, b);\n}\n\ninline void mat_mul(int m, int n, int k, std::shared_ptr<GemmKernel224Int8::BufferA> ba,\n                    std::shared_ptr<GemmKernel224Int8::BufferB> bb, std::shared_ptr<GemmKernel224Int8::BufferC> bc,\n                    int ith, int nth, bool use_amx) {\n  using K = GemmKernel224Int8;\n  assert(n % K::N_STEP == 0);\n  assert(k % K::K_STEP == 0);\n\n  auto [n_start, n_end] = K::split_range_n(n, ith, nth);\n\n  for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K::K_BLOCK) {\n    for (int m_begin = 0; m_begin < m; m_begin += K::M_STEP) {\n      for (int n_begin = n_start; n_begin < n_end; n_begin += K::N_STEP) {\n        float *c = bc->get_submat(m, n, m_begin, n_begin);\n\n        if (!use_amx) {\n          __m512i *c512 = (__m512i *)c;\n          if (k_block_begin == 0) {\n            for (int m_i = 0; m_i < m && m_i < K::M_STEP; m_i++) {\n              c512[m_i * 2] = _mm512_setzero_si512();\n              c512[m_i * 2 + 1] = 
_mm512_setzero_si512();\n            }\n          }\n\n          for (int k_begin = 0; k_begin < K::K_BLOCK && k_block_begin + k_begin < k; k_begin += K::K_STEP) {\n            static_assert(K::K_STEP * sizeof(int8_t) == sizeof(__m512i));\n            static_assert(K::N_STEP / K::TILE_N == 2, \"Must be like this\");\n\n            int32_t *a32 = (int32_t *)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);\n            __m512i *b512 = (__m512i *)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);\n            for (int m_i = 0; m_i < m && m_i < K::M_STEP; m_i++) {\n              for (int k_i = 0; k_i < 16; k_i++) {\n                __m512i ma = _mm512_set1_epi32(a32[m_i * 16 + k_i]);\n                for (int n_i = 0; n_i < 2; n_i++) {\n                  c512[m_i * 2 + n_i] = _mm512_dpbssd_epi32(c512[m_i * 2 + n_i], ma, b512[n_i * 16 + k_i]);\n                }\n              }\n            }\n          }\n        } else {\n          if (k_block_begin == 0) {\n            K::clean_c();\n          } else {\n            K::load_c((int32_t *)c, K::N_STEP * sizeof(int32_t));\n          }\n          for (int k_begin = 0; k_begin < K::K_BLOCK && k_block_begin + k_begin < k; k_begin += K::K_STEP) {\n            K::load_a(ba->get_submat(m, k, m_begin, k_block_begin + k_begin), K::K_STEP * sizeof(int8_t));\n            K::load_b(bb->get_submat(n, k, n_begin, k_block_begin + k_begin), K::K_STEP * sizeof(int8_t));\n            K::run_tile();\n          }\n          K::store_c((int32_t *)c, K::N_STEP * sizeof(int32_t));\n        }\n\n        if (k_block_begin + K::K_BLOCK >= k) {\n          int to = m - m_begin;\n          if (m - m_begin > K::M_STEP) {\n            to = K::M_STEP;\n          }\n          for (int i = 0; i < to; i++) {\n            __m512 as = _mm512_set1_ps(*ba->get_scale(m, m_begin + i));\n            __m512 bs = _mm512_load_ps(bb->get_scale(n, n_begin));\n            __m512i now = _mm512_load_si512((__m512i *)(c + i * K::N_STEP));\n            __m512 
result = _mm512_mul_ps(_mm512_mul_ps(as, bs), _mm512_cvtepi32_ps(now));\n            _mm512_store_ps((__m512 *)(c + i * K::N_STEP), result);\n            bs = _mm512_load_ps(bb->get_scale(n, n_begin) + K::TILE_N);\n            now = _mm512_load_si512((__m512i *)(c + i * K::N_STEP + K::TILE_N));\n            result = _mm512_mul_ps(_mm512_mul_ps(as, bs), _mm512_cvtepi32_ps(now));\n            _mm512_store_ps((__m512 *)(c + i * K::N_STEP + K::TILE_N), result);\n          }\n        }\n      }\n    }\n  }\n}\n\n} // namespace amx"
  },
  {
    "path": "archive/csrc/ktransformers_ext/operators/amx/la/utils.hpp",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2025-04-25 18:28:12\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2025-04-25 18:28:12\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n\n#pragma once\n#include <cstdint>\n\n\ntemplate <typename T>\nT* offset_pointer(T* ptr, std::size_t byte_offset) {\n  return reinterpret_cast<T*>(reinterpret_cast<char*>(ptr) + byte_offset);\n}\n\ntemplate <typename T>\nconst T* offset_pointer(const T* ptr, std::size_t byte_offset) {\n  return reinterpret_cast<const T*>(reinterpret_cast<const char*>(ptr) + byte_offset);\n}\n\ntemplate <typename T>\nT* offset_pointer_row_major(T* t, int row, int col, std::size_t ld) {\n  return offset_pointer(t, row * ld) + col;\n}\n\ntemplate <typename T>\nT* offset_pointer_col_major(T* t, int row, int col, std::size_t ld) {\n  return offset_pointer(t, col * ld) + row;\n}\n\nstatic inline void avx512_copy_32xbf16(__m512i* src, __m512i* dst) {\n  _mm512_storeu_si512(dst, _mm512_loadu_si512(src));\n}\n\nstatic inline void avx512_32xfp32_to_32xbf16(__m512* src0, __m512* src1, __m512i* dst) {\n  _mm512_storeu_si512(dst, __m512i(_mm512_cvtne2ps_pbh(*src1, *src0)));\n}\n\nstatic inline void avx512_32xbf16_to_32xfp32(__m512i* src, __m512* dst0, __m512* dst1) {\n  _mm512_storeu_ps(dst0, _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(src))), 16)));\n  _mm512_storeu_ps(dst1, _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(src) + 1)), 16)));\n}"
  },
  {
    "path": "archive/csrc/ktransformers_ext/operators/amx/moe.hpp",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2025-04-25 18:28:12\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2025-04-25 18:28:12\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#ifndef CPUINFER_OPERATOR_AMX_MOE_H\n#define CPUINFER_OPERATOR_AMX_MOE_H\n\n#include <cmath>\n#include <cstdio>\n#include <functional>\n#include <mutex>\n#include <vector>\n\n#include \"../../cpu_backend/backend.h\"\n#include \"../../cpu_backend/shared_mem_buffer.h\"\n#include \"llama.cpp/ggml-impl.h\"\n#include \"llama.cpp/ggml-quants.h\"\n#include \"llama.cpp/ggml.h\"\n#include \"llamafile/sgemm.h\"\n\n#include \"la/amx.hpp\"\n\n#ifdef USE_NUMA\n#include <numa.h>\n#include <numaif.h>\nvoid *numa_alloc_aligned(size_t size, int node, size_t alignment) {\n  void *ptr = numa_alloc_onnode(size, node);\n  assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n  return ptr;\n}\n#endif\n\nstatic inline __m512 exp_avx512(__m512 x) {\n  const __m512 log2e = _mm512_set1_ps(1.44269504089f);\n  const __m512 c1 = _mm512_set1_ps(0.69314718056f);\n\n  __m512 y = _mm512_mul_ps(x, log2e);\n  __m512i int_part = _mm512_cvtps_epi32(y);\n  __m512 frac_part = _mm512_sub_ps(y, _mm512_cvtepi32_ps(int_part));\n\n  const __m512 poly_1 = _mm512_set1_ps(0.9999999995f);\n  const __m512 poly_2 = _mm512_set1_ps(0.6931471805f);\n  const __m512 poly_3 = _mm512_set1_ps(0.2402265069f);\n  const __m512 poly_4 = _mm512_set1_ps(0.0555041087f);\n  const __m512 poly_5 = _mm512_set1_ps(0.0096181291f);\n  const __m512 poly_6 = _mm512_set1_ps(0.0013333558f);\n\n  __m512 frac_exp = _mm512_fmadd_ps(\n      frac_part, poly_6,\n      _mm512_fmadd_ps(frac_part, poly_5,\n                      _mm512_fmadd_ps(frac_part, poly_4,\n                                      _mm512_fmadd_ps(frac_part, poly_3, _mm512_fmadd_ps(frac_part, poly_2, poly_1)))));\n\n  __m512 two_pow_i = _mm512_scalef_ps(_mm512_set1_ps(1.0f), _mm512_cvtepi32_ps(int_part));\n  
return _mm512_mul_ps(two_pow_i, frac_exp);\n}\n\nstatic inline __m512 act_fn(__m512 gate_val, __m512 up_val) {\n  __m512 neg_gate_val = _mm512_sub_ps(_mm512_setzero_ps(), gate_val);\n  __m512 exp_neg_gate = exp_avx512(neg_gate_val);\n  __m512 denom = _mm512_add_ps(_mm512_set1_ps(1.0f), exp_neg_gate);\n  __m512 act_val = _mm512_div_ps(gate_val, denom);\n\n  return _mm512_mul_ps(act_val, up_val);\n}\n\nstatic inline __m512 relu_act_fn(__m512 gate_val, __m512 up_val) {\n  __m512 zero_vec = _mm512_setzero_ps();\n  __m512 act_val = _mm512_max_ps(zero_vec, gate_val);\n  return _mm512_mul_ps(act_val, up_val);\n}\n\nstruct AMX_MOEConfig {\n  int expert_num;\n  int routed_expert_num;\n  int hidden_size;\n  int intermediate_size;\n  int max_len;\n  bool use_silu;\n  void *gate_proj;\n  void *up_proj;\n  void *down_proj;\n\n  AMX_MOEConfig() {}\n\n  AMX_MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int max_len, bool use_silu,\n                void *gate_proj, void *up_proj, void *down_proj)\n      : expert_num(expert_num), routed_expert_num(routed_expert_num), hidden_size(hidden_size),\n        intermediate_size(intermediate_size), max_len(max_len), use_silu(use_silu), gate_proj(gate_proj), up_proj(up_proj),\n        down_proj(down_proj) {}\n};\n\ntemplate <class T> class AMX_MOE {\nprivate:\n  AMX_MOEConfig config_;\n  void *gate_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]\n  void *up_proj_;   // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]\n  void *down_proj_; // [expert_num * hidden_size * intermediate_size ( /32 if quantized)]\n\n  ggml_bf16_t *m_local_input_;       // [routed_expert_num * max_len * hidden_size]\n  ggml_bf16_t *m_local_gate_output_; // [routed_expert_num * max_len * intermediate_size]\n  ggml_bf16_t *m_local_up_output_;   // [routed_expert_num * max_len * intermediate_size]\n  ggml_bf16_t *m_local_down_output_; // [routed_expert_num * max_len * 
hidden_size]\n\n  std::vector<std::vector<int>> m_local_pos_;          // [max_len, routed_expert_num]\n  std::vector<int> m_local_num_;                       // [expert_num]\n  std::vector<int> m_expert_id_map_;                   // [expert_num]\n  std::vector<ggml_bf16_t *> m_local_input_ptr_;       // [expert_num]\n  std::vector<ggml_bf16_t *> m_local_gate_output_ptr_; // [expert_num]\n  std::vector<ggml_bf16_t *> m_local_up_output_ptr_;   // [expert_num]\n  std::vector<ggml_bf16_t *> m_local_down_output_ptr_; // [expert_num]\n\n  std::vector<std::shared_ptr<typename T::BufferA>> gate_up_ba_;\n  std::vector<std::shared_ptr<typename T::BufferC>> gate_bc_;\n  std::vector<std::shared_ptr<typename T::BufferC>> up_bc_;\n  std::vector<std::shared_ptr<typename T::BufferA>> down_ba_;\n  std::vector<std::shared_ptr<typename T::BufferC>> down_bc_;\n\n#ifdef USE_NUMA\n  std::vector<std::vector<std::shared_ptr<typename T::BufferB>>> gate_bb_numa_;\n  std::vector<std::vector<std::shared_ptr<typename T::BufferB>>> up_bb_numa_;\n  std::vector<std::vector<std::shared_ptr<typename T::BufferB>>> down_bb_numa_;\n#else\n  std::vector<std::shared_ptr<typename T::BufferB>> gate_bb_;\n  std::vector<std::shared_ptr<typename T::BufferB>> up_bb_;\n  std::vector<std::shared_ptr<typename T::BufferB>> down_bb_;\n#endif\n\npublic:\n  AMX_MOE(AMX_MOEConfig config) {\n    config_ = config;\n    gate_proj_ = config_.gate_proj;\n    up_proj_ = config_.up_proj;\n    down_proj_ = config_.down_proj;\n\n    std::vector<std::pair<void **, uint64_t>> m_mem_requests;\n    m_mem_requests.push_back({(void **)&m_local_input_,\n                              sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.hidden_size});\n    m_mem_requests.push_back({(void **)&m_local_gate_output_, sizeof(ggml_bf16_t) * config_.routed_expert_num *\n                                                                  config_.max_len * config_.intermediate_size});\n    m_mem_requests.push_back({(void 
**)&m_local_up_output_, sizeof(ggml_bf16_t) * config_.routed_expert_num *\n                                                                config_.max_len * config_.intermediate_size});\n    m_mem_requests.push_back({(void **)&m_local_down_output_,\n                              sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.hidden_size});\n    std::vector<void *> gate_up_ba_ptr(config_.expert_num);\n    std::vector<void *> gate_bc_ptr(config_.expert_num);\n    std::vector<void *> up_bc_ptr(config_.expert_num);\n    std::vector<void *> down_ba_ptr(config_.expert_num);\n    std::vector<void *> down_bc_ptr(config_.expert_num);\n    for (int i = 0; i < config_.expert_num; i++) {\n      m_mem_requests.push_back(\n          {(void **)&gate_up_ba_ptr[i], T::BufferA::required_size(config_.max_len, config_.hidden_size)});\n      m_mem_requests.push_back(\n          {(void **)&gate_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.intermediate_size)});\n      m_mem_requests.push_back(\n          {(void **)&up_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.intermediate_size)});\n      m_mem_requests.push_back(\n          {(void **)&down_ba_ptr[i], T::BufferA::required_size(config_.max_len, config_.intermediate_size)});\n      m_mem_requests.push_back(\n          {(void **)&down_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.hidden_size)});\n    }\n    shared_mem_buffer.alloc(this, m_mem_requests);\n\n    m_local_pos_.resize(config_.max_len);\n    for (int i = 0; i < config_.max_len; i++) {\n      m_local_pos_[i].resize(config_.routed_expert_num);\n    }\n    m_expert_id_map_.resize(config_.expert_num);\n    m_local_num_.resize(config_.expert_num);\n    m_local_input_ptr_.resize(config_.expert_num);\n    m_local_gate_output_ptr_.resize(config_.expert_num);\n    m_local_up_output_ptr_.resize(config_.expert_num);\n    m_local_down_output_ptr_.resize(config_.expert_num);\n\n    for (uint64_t i = 0; i < 
config_.expert_num; i++) {\n      gate_up_ba_.push_back(\n          std::make_shared<typename T::BufferA>(config_.max_len, config_.hidden_size, gate_up_ba_ptr[i]));\n      gate_bc_.push_back(\n          std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, gate_bc_ptr[i]));\n      up_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, up_bc_ptr[i]));\n      down_ba_.push_back(\n          std::make_shared<typename T::BufferA>(config_.max_len, config_.intermediate_size, down_ba_ptr[i]));\n      down_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.hidden_size, down_bc_ptr[i]));\n\n#ifdef USE_NUMA\n      int numa_nodes = numa_num_configured_nodes();\n      gate_bb_numa_.resize(numa_nodes);\n      up_bb_numa_.resize(numa_nodes);\n      down_bb_numa_.resize(numa_nodes);\n      for (int j = 0; j < numa_nodes; j++) {\n        void *gate_bb_ptr =\n            numa_alloc_aligned(T::BufferB::required_size(config_.intermediate_size, config_.hidden_size), j, 64);\n        gate_bb_numa_[j].push_back(\n            std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, gate_bb_ptr));\n        void *up_bb_ptr =\n            numa_alloc_aligned(T::BufferB::required_size(config_.intermediate_size, config_.hidden_size), j, 64);\n        up_bb_numa_[j].push_back(\n            std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, up_bb_ptr));\n        void *down_bb_ptr =\n            numa_alloc_aligned(T::BufferB::required_size(config_.hidden_size, config_.intermediate_size), j, 64);\n        down_bb_numa_[j].push_back(\n            std::make_shared<typename T::BufferB>(config_.hidden_size, config_.intermediate_size, down_bb_ptr));\n      }\n#else\n      void *gate_bb_ptr =\n          std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size));\n      gate_bb_.push_back(\n          
std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, gate_bb_ptr));\n\n      void *up_bb_ptr =\n          std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size));\n      up_bb_.push_back(\n          std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, up_bb_ptr));\n\n      void *down_bb_ptr =\n          std::aligned_alloc(64, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size));\n      down_bb_.push_back(\n          std::make_shared<typename T::BufferB>(config_.hidden_size, config_.intermediate_size, down_bb_ptr));\n#endif\n    }\n  }\n\n  ~AMX_MOE() { shared_mem_buffer.dealloc(this); }\n\n  void load_weights(Backend *backend) {\n    int nth = T::recommended_nth(config_.intermediate_size);\n    backend->do_work_stealing_job(\n        nth * config_.expert_num, nullptr,\n        [&](int task_id) {\n          uint64_t expert_idx = task_id / nth;\n          int ith = task_id % nth;\n#ifdef USE_NUMA\n          int numa_nodes = numa_num_configured_nodes();\n          for (int j = 0; j < numa_nodes; j++) {\n            gate_bb_numa_[j][expert_idx]->from_mat((ggml_bf16_t *)config_.gate_proj +\n                                                       expert_idx * config_.intermediate_size * config_.hidden_size,\n                                                   ith, nth);\n            up_bb_numa_[j][expert_idx]->from_mat((ggml_bf16_t *)config_.up_proj +\n                                                     expert_idx * config_.intermediate_size * config_.hidden_size,\n                                                 ith, nth);\n          }\n#else\n          gate_bb_[expert_idx]->from_mat((ggml_bf16_t *)config_.gate_proj +\n                                             expert_idx * config_.intermediate_size * config_.hidden_size,\n                                         ith, nth);\n          up_bb_[expert_idx]->from_mat(\n              (ggml_bf16_t 
*)config_.up_proj + expert_idx * config_.intermediate_size * config_.hidden_size, ith, nth);\n#endif\n        },\n        nullptr);\n    nth = T::recommended_nth(config_.hidden_size);\n    backend->do_work_stealing_job(\n        nth * config_.expert_num, nullptr,\n        [&](int task_id) {\n          uint64_t expert_idx = task_id / nth;\n          int ith = task_id % nth;\n#ifdef USE_NUMA\n          int numa_nodes = numa_num_configured_nodes();\n          for (int j = 0; j < numa_nodes; j++) {\n            down_bb_numa_[j][expert_idx]->from_mat((ggml_bf16_t *)config_.down_proj +\n                                                       expert_idx * config_.hidden_size * config_.intermediate_size,\n                                                   ith, nth);\n          }\n#else\n          down_bb_[expert_idx]->from_mat((ggml_bf16_t *)config_.down_proj +\n                                             expert_idx * config_.hidden_size * config_.intermediate_size,\n                                         ith, nth);\n#endif\n        },\n        nullptr);\n  }\n\n  void warm_up(Backend *backend) {}\n\n  void forward(int qlen, int k, const uint64_t *expert_ids, const float *weights, const void *input, void *output,\n               int *batch_size_tensor, Backend *backend) {\n    qlen = batch_size_tensor[0];\n    bool use_amx = (qlen > 4 * config_.expert_num / config_.routed_expert_num);\n    int activated_expert = 0;\n    for (int i = 0; i < config_.expert_num; i++) {\n      m_local_num_[i] = 0;\n    }\n    for (int i = 0; i < qlen; i++) {\n      for (int j = 0; j < k; j++) {\n        m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;\n      }\n    }\n    for (int i = 0; i < config_.expert_num; i++) {\n      if (m_local_num_[i] > 0) {\n        m_expert_id_map_[activated_expert] = i;\n        activated_expert++;\n      }\n    }\n    uint64_t offset = 0;\n    for (int i = 0; i < config_.expert_num; i++) {\n      m_local_input_ptr_[i] = m_local_input_ + offset * 
config_.hidden_size;\n      m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;\n      m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;\n      m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;\n      offset += m_local_num_[i];\n    }\n    backend->do_work_stealing_job(\n        qlen, nullptr,\n        [&](int i) {\n          for (int j = 0; j < k; j++) {\n            memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size,\n                   (ggml_bf16_t *)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size);\n          }\n        },\n        nullptr);\n    backend->do_work_stealing_job(\n        activated_expert, nullptr,\n        [&](int task_id) {\n          int expert_idx = m_expert_id_map_[task_id];\n\n          gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1);\n        },\n        nullptr);\n    int nth = T::recommended_nth(config_.intermediate_size);\n    backend->do_work_stealing_job(\n        nth * activated_expert, [&](int _) { T::config(); },\n        [&](int task_id) {\n          int expert_idx = m_expert_id_map_[task_id / nth];\n          int ith = task_id % nth;\n#ifdef USE_NUMA\n          amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,\n                       gate_up_ba_[expert_idx], gate_bb_numa_[Backend::numa_node][expert_idx], gate_bc_[expert_idx],\n                       ith, nth, use_amx);\n          amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,\n                       gate_up_ba_[expert_idx], up_bb_numa_[Backend::numa_node][expert_idx], up_bc_[expert_idx], ith,\n                       nth, use_amx);\n#else\n          amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,\n                       gate_up_ba_[expert_idx], 
gate_bb_[expert_idx], gate_bc_[expert_idx], ith, nth, use_amx);\n          amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,\n                       gate_up_ba_[expert_idx], up_bb_[expert_idx], up_bc_[expert_idx], ith, nth, use_amx);\n#endif\n          gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);\n          up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth);\n          auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);\n          if (config_.use_silu) {\n            for (int i = 0; i < m_local_num_[expert_idx]; i++) {\n                ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];\n                ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];\n                for (int j = n_start; j < n_end; j += 32) {\n                  __m512 gate_val0, gate_val1, up_val0, up_val1;\n                  avx512_32xbf16_to_32xfp32((__m512i *)(gate_output_ptr + j), &gate_val0, &gate_val1);\n                  avx512_32xbf16_to_32xfp32((__m512i *)(up_output_ptr + j), &up_val0, &up_val1);\n                  __m512 result0 = act_fn(gate_val0, up_val0);\n                  __m512 result1 = act_fn(gate_val1, up_val1);\n                  avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i *)(gate_output_ptr + j));\n                }\n              }\n          }\n          else {\n              for (int i = 0; i < m_local_num_[expert_idx]; i++) {\n                ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];\n                ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];\n                for (int j = n_start; j < n_end; j += 32) {\n                  __m512 gate_val0, gate_val1, up_val0, up_val1;\n                  
avx512_32xbf16_to_32xfp32((__m512i *)(gate_output_ptr + j), &gate_val0, &gate_val1);\n                  avx512_32xbf16_to_32xfp32((__m512i *)(up_output_ptr + j), &up_val0, &up_val1);\n                  __m512 result0 = relu_act_fn(gate_val0, up_val0);\n                  __m512 result1 = relu_act_fn(gate_val1, up_val1);\n                  avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i *)(gate_output_ptr + j));\n                }\n              }\n          }\n          \n        },\n        nullptr);\n    backend->do_work_stealing_job(\n        activated_expert, nullptr,\n        [&](int task_id) {\n          int expert_idx = m_expert_id_map_[task_id];\n          down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1);\n        },\n        nullptr);\n    nth = T::recommended_nth(config_.hidden_size);\n    backend->do_work_stealing_job(\n        nth * activated_expert, [&](int _) { T::config(); },\n        [&](int task_id) {\n          int expert_idx = m_expert_id_map_[task_id / nth];\n          int ith = task_id % nth;\n#ifdef USE_NUMA\n          amx::mat_mul(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size, down_ba_[expert_idx],\n                       down_bb_numa_[Backend::numa_node][expert_idx], down_bc_[expert_idx], ith, nth, use_amx);\n#else\n          amx::mat_mul(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size, down_ba_[expert_idx],\n                       down_bb_[expert_idx], down_bc_[expert_idx], ith, nth, use_amx);\n#endif\n          down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_output_ptr_[expert_idx], ith, nth);\n        },\n        nullptr);\n    backend->do_work_stealing_job(\n        qlen, nullptr,\n        [&](int i) {\n          for (int e = 0; e < config_.hidden_size; e += 32) {\n            __m512 x0 = _mm512_setzero_ps();\n            __m512 x1 = _mm512_setzero_ps();\n            for (int j = 0; j < k; j++) {\n             
 __m512 weight = _mm512_set1_ps(weights[i * k + j]);\n              __m512 down_output0, down_output1;\n              avx512_32xbf16_to_32xfp32((__m512i *)(m_local_down_output_ptr_[expert_ids[i * k + j]] +\n                                                    m_local_pos_[i][j] * config_.hidden_size + e),\n                                        &down_output0, &down_output1);\n              x0 = _mm512_fmadd_ps(down_output0, weight, x0);\n              x1 = _mm512_fmadd_ps(down_output1, weight, x1);\n            }\n            avx512_32xfp32_to_32xbf16(&x0, &x1, (__m512i *)((ggml_bf16_t *)output + i * config_.hidden_size + e));\n          }\n        },\n        nullptr);\n  }\n};\n\n#endif\n"
  },
  {
    "path": "archive/csrc/ktransformers_ext/operators/kvcache/kvcache.h",
    "content": "/**\n * @Description  :\n * @Author       : Jianwei Dong\n * @Date         : 2024-08-26 22:47:06\n * @Version      : 1.0.0\n * @LastEditors  : Jianwei Dong\n * @LastEditTime : 2024-08-26 22:47:06\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n\n#ifndef CPUINFER_OPERATOR_KVCACHE_H\n#define CPUINFER_OPERATOR_KVCACHE_H\n\n#include <algorithm>\n#include <atomic>\n#include <cassert>\n#include <condition_variable>\n#include <cstdint>\n#include <cstdio>\n#include <cstring>\n#include <fstream>\n#include <functional>\n#include <future>\n#include <iostream>\n#include <memory>\n#include <mutex>\n#include <queue>\n#include <random>\n#include <stdexcept>\n#include <thread>\n#include <vector>\n\n#include \"../../cpu_backend/backend.h\"\n#include \"llama.cpp/ggml-common.h\"\n#include \"llama.cpp/ggml-impl.h\"\n#include \"llama.cpp/ggml-quants.h\"\n#include \"llama.cpp/ggml.h\"\n#include \"llamafile/sgemm.h\"\n\n#define CHUNK_SIZE 32\n\n/**\n * @brief Converts a ggml_type enum value to its corresponding string\n * representation.\n *\n * This function provides a human-readable string representation for a given\n * ggml_type enum value. The string can be used for logging, debugging, or\n * displaying information in a user interface.\n *\n * @param type The ggml_type enum value to convert.\n * @return A string representation of the enum value.\n */\nstd::string ggml_type_to_string(ggml_type type);\n\n/**\n * @enum AnchorType\n * @brief Defines the types of anchors used in attention mechanisms.\n *\n * This enum specifies different types of anchors that can be used in attention\n * mechanisms, such as fixed anchors, dynamic anchors, or special anchors like\n * QUEST, BLOCK_MEAN, or BLOCK_MAX.\n */\nenum AnchorType {\n    FIXED_ANCHOR, /**< A fixed anchor that does not change. */\n    DYNAMIC,      /**< A dynamic anchor that can change over time. 
*/\n    QUEST, /**< A special anchor type used for QUEST (Query and Embedding Space\n              Transformation). */\n    BLOCK_MEAN, /**< An anchor based on the mean of a block of data. */\n    BLOCK_MAX /**< An anchor based on the maximum value within a block of data.\n               */\n};\n\n/**\n * @brief Converts an AnchorType enum value to its corresponding string\n * representation.\n *\n * This function provides a human-readable string representation for a given\n * AnchorType enum value. The string can be used for logging, debugging, or\n * displaying information in a user interface.\n *\n * @param anchor_type The AnchorType enum value to convert.\n * @return A string representation of the enum value.\n */\nstd::string AnchorTypeToString(AnchorType anchor_type);\n\n/**\n * @enum RetrievalType\n * @brief Defines the types of retrieval strategies in attention mechanisms.\n *\n * This enum specifies different retrieval strategies that can be used in\n * attention mechanisms, such as layer-level retrieval, key-value head-level\n * retrieval, or query head-level retrieval.\n */\nenum RetrievalType {\n    LAYER,  /**< Retrieval at the layer level. */\n    KVHEAD, /**< Retrieval at the key-value head level. */\n    QHEAD   /**< Retrieval at the query head level. */\n};\n\n/**\n * @brief Converts a RetrievalType enum value to its corresponding string\n * representation.\n *\n * This function provides a human-readable string representation for a given\n * RetrievalType enum value. 
The string can be used for logging, debugging, or\n * displaying information in a user interface.\n *\n * @param retrieval_type The RetrievalType enum value to convert.\n * @return A string representation of the enum value.\n */\nstd::string RetrievalTypeToString(RetrievalType retrieval_type);\n\n/**\n * @struct KVCacheConfig\n * @brief Configuration structure for Key-Value (KV) Cache.\n *\n * This structure holds configuration parameters for setting up and managing\n * a Key-Value (KV) Cache used in various attention mechanisms. It includes\n * parameters such as the number of layers, the number of heads, the dimension\n * of each head, block length, anchor information, and memory-related settings.\n */\nstruct KVCacheConfig {\n    int layer_num;   /**< Number of layers in the model. */\n    int kv_head_num; /**< Number of heads in the KV Cache. */\n    int q_head_num;  /**< Number of heads in the query. */\n    int head_dim;    /**< Dimension of each head. */\n    int block_len;   /**< Length of each block in the cache. */\n    int anchor_num;  /**< Number of anchors used in attention. */\n\n    ggml_type kv_type; /**< Data type of the KV Cache (e.g., fp16, q8_0). */\n\n    // Controls the pre-allocated memory size\n    int max_block_num;  /**< Maximum number of blocks that can be allocated. */\n    int max_batch_size; /**< Maximum batch size that can be processed. */\n    int max_thread_num; /**< Maximum number of threads that can be used. */\n\n    AnchorType\n        anchor_type; /**< Type of anchors used in the attention mechanism. */\n    RetrievalType\n        retrieval_type; /**< Type of retrieval strategy used in the cache. */\n\n    int layer_step;   /**< Step size between layers. */\n    int token_step;   /**< Step size between tokens. */\n    int layer_offset; /**< Offset value for layers. */\n\n    /**\n     * @brief Default constructor for KVCacheConfig.\n     *\n     * Initializes the configuration with default values. 
This constructor\n     * does not initialize any member variables explicitly.\n     */\n    KVCacheConfig() = default;\n\n    /**\n     * @brief Parameterized constructor for KVCacheConfig.\n     *\n     * This constructor initializes the configuration with specific values\n     * for all member variables.\n     *\n     * @param layer_num The number of layers in the model.\n     * @param kv_head_num The number of heads in the KV Cache.\n     * @param q_head_num The number of heads in the query.\n     * @param head_dim The dimension of each head.\n     * @param block_len The length of each block in the cache.\n     * @param anchor_num The number of anchors used in attention.\n     * @param anchor_type The type of anchors used in the attention mechanism.\n     * @param kv_type The data type of the KV Cache (e.g., fp16, q8_0).\n     * @param retrieval_type The type of retrieval strategy used in the cache.\n     * @param layer_step The step size between layers.\n     * @param token_step The step size between tokens.\n     * @param layer_offset The offset value for layers.\n     * @param max_block_num The maximum number of blocks that can be allocated.\n     * @param max_batch_size The maximum batch size that can be processed.\n     * @param max_thread_num The maximum number of threads that can be used.\n     */\n    KVCacheConfig(int layer_num, int kv_head_num, int q_head_num, int head_dim,\n                  int block_len, int anchor_num, AnchorType anchor_type,\n                  ggml_type kv_type, RetrievalType retrieval_type,\n                  int layer_step, int token_step, int layer_offset,\n                  int max_block_num, int max_batch_size, int max_thread_num);\n};\n\n/**\n * @class KVCache\n * @brief Manages the Key-Value (KV) Cache used in attention mechanisms.\n *\n * The KVCache class provides functionality for managing the Key-Value Cache,\n * including resizing the cache, retrieving configuration parameters, and\n * updating internal states. 
This class is typically used in transformer models\n * to store and manage past key and value states for efficient attention\n * computations.\n */\nclass KVCache {\n  public:\n    /**\n     * @brief Constructs a KVCache object with the given configuration.\n     *\n     * Initializes the KVCache with the specified configuration parameters,\n     * such as the number of layers, heads, head dimensions, and other\n     * relevant settings.\n     *\n     * @param config The configuration object containing initialization\n     * parameters.\n     */\n    KVCache(KVCacheConfig config);\n\n    /**\n     * @brief Resizes the number of threads used by the cache.\n     *\n     * This function adjusts the number of threads that the cache can utilize.\n     * It allows dynamic reconfiguration of the parallel processing capabilities\n     * based on the current workload or system resources.\n     *\n     * @param thread_num The new number of threads to use.\n     */\n    void ThreadResize(int thread_num);\n\n    /**\n     * @brief Resizes the batch size managed by the cache.\n     *\n     * This function adjusts the batch size that the cache can handle. 
It\n     * is useful when the input batch size changes dynamically, allowing\n     * the cache to be reconfigured accordingly.\n     *\n     * @param batch_size The new batch size.\n     */\n    void BatchResize(int batch_size);\n\n    /**\n     * @brief Resizes the number of blocks managed by the cache.\n     *\n     * This function adjusts the number of blocks that the cache can manage.\n     * It allows dynamic reconfiguration of the block structure based on the\n     * current sequence length or other factors.\n     *\n     * @param block_num The new number of blocks.\n     */\n    void BlockResize(int block_num);\n\n    /**\n     * @brief Gets the number of layers in the cache.\n     *\n     * @return The number of layers configured in the cache.\n     */\n    int get_layer_num() { return config_.layer_num; }\n\n    /**\n     * @brief Gets the number of KV heads in the cache.\n     *\n     * @return The number of KV heads configured in the cache.\n     */\n    int get_kv_head_num() { return config_.kv_head_num; }\n\n    /**\n     * @brief Gets the number of query heads in the cache.\n     *\n     * @return The number of query heads configured in the cache.\n     */\n    int get_q_head_num() { return config_.q_head_num; }\n\n    /**\n     * @brief Gets the dimension of each head in the cache.\n     *\n     * @return The dimension of each head.\n     */\n    int get_head_dim() { return config_.head_dim; }\n\n    /**\n     * @brief Gets the length of each block in the cache.\n     *\n     * @return The length of each block.\n     */\n    int get_block_len() { return config_.block_len; }\n\n    /**\n     * @brief Gets the number of blocks for a specific layer.\n     *\n     * @param layer_id The ID of the layer for which to retrieve the block\n     * number.\n     * @return The number of blocks in the specified layer.\n     */\n    int get_block_num(int layer_id) { return past_block_num_[layer_id]; }\n\n    /**\n     * @brief Gets the number of anchors in the 
cache.\n     *\n     * @return The number of anchors configured in the cache.\n     */\n    int get_anchor_num() { return config_.anchor_num; }\n\n    /**\n     * @brief Gets the total length of the cache.\n     *\n     * @return The total length of the cache.\n     */\n    int get_cache_total_len() { return cache_total_len_; }\n\n    /**\n     * @brief Gets the total number of blocks in the cache.\n     *\n     * This function computes and returns the total number of blocks in the\n     * cache based on the total cache length and the block length configuration.\n     *\n     * @return The total number of blocks in the cache.\n     */\n    int get_cache_total_block_num() {\n        return (cache_total_len_ + config_.block_len - 1) / config_.block_len;\n    }\n\n    /**\n     * @brief Updates the total length of the cache.\n     *\n     * This function sets a new total length for the cache, allowing dynamic\n     * adjustment of the cache size during runtime.\n     *\n     * @param cache_total_len The new total length of the cache.\n     */\n    void update_cache_total_len(int cache_total_len) {\n        cache_total_len_ = cache_total_len;\n    }\n    void attn(const ggml_fp16_t *q_in, ggml_fp16_t *output, float *attn_lse,\n              int layer_idx, int generate_token_idx, int q_len, int batch_size,\n              int max_block_num, int *block_table, int *cache_seqlens,\n              int pick_block_num, int init_block_num, int local_block_num,\n              Backend *backend);\n\n    void update_kvcache_one_block_fp16(const ggml_fp16_t *k_in,\n                                       const ggml_fp16_t *v_in, int layer_id,\n                                       int block_idx, Backend *backend);\n\n    void get_kvcache_one_block_fp16(ggml_fp16_t *k_in, ggml_fp16_t *v_in,\n                                    int layer_id, int block_idx,\n                                    Backend *backend);\n\n    void update_importance_one_block(const ggml_fp16_t *importance,\n    
                                 int layer_id, int block_idx,\n                                     Backend *backend);\n    void get_importance_one_block(ggml_fp16_t *importance, int layer_id,\n                                  int block_idx, Backend *backend);\n\n    void get_anchor_one_block(ggml_fp16_t *anchor, int layer_id, int block_idx,\n                              Backend *backend);\n\n    void update_anchor_one_block(const ggml_fp16_t *anchor, int layer_id,\n                                 int block_idx, Backend *backend);\n\n    void calc_anchor_all_layers(int *block_table, int *cache_seqlens,\n                                int batch_size, int max_block_num,\n                                Backend *backend);\n\n    void load_kvcache(std::string tensor_file_path, Backend *backend);\n    void dump_kvcache(int *block_table, int cache_total_len,\n                      std::string tensor_file_path, Backend *backend);\n\n    void get_and_update_kvcache_fp16(ggml_fp16_t *k_in, ggml_fp16_t *v_in,\n                                     int layer_id, int *block_table,\n                                     int batch_size, int max_block_num,\n                                     int *cache_seqlens, int q_len,\n                                     Backend *backend);\n\n    void get_kvcache_fp16(ggml_fp16_t *k_in, ggml_fp16_t *v_in, int layer_id,\n                          int *block_table, int batch_size, int max_block_num,\n                          int *cache_seqlens, Backend *backend);\n\n    void update_kvcache_fp16(const ggml_fp16_t *k_in, const ggml_fp16_t *v_in,\n                             int layer_id, int *block_table, int batch_size,\n                             int max_block_num, int *cache_seqlens, int q_len,\n                             Backend *backend);\n\n    void update_importance(const ggml_fp16_t *importance, int layer_id,\n                           int *block_table, int batch_size, int max_block_num,\n                           int 
*offset, int width, Backend *backend);\n\n    void attn_with_kvcache(const ggml_fp16_t *q_in, const ggml_fp16_t *k_in,\n                           const ggml_fp16_t *v_in, ggml_fp16_t *output,\n                           float *attn_lse, int layer_idx,\n                           int generate_token_idx, int q_len, int batch_size,\n                           int max_block_num, int *block_table,\n                           int *cache_seqlens, int topk, int local,\n                           Backend *backend);\n\n    void clear_importance_all_layers(int *block_table, int *cache_seqlens,\n                                     int batch_size, int max_block_num,\n                                     Backend *backend);\n\n    void clear_kvcache_all_layers(int *block_table, int *cache_seqlens,\n                                  int batch_size, int max_block_num,\n                                  Backend *backend);\n\n    void get_sincos(ggml_fp16_t *sin, ggml_fp16_t *cos, int seqlen);\n\n    void get_attn_sparsity(const ggml_fp16_t *q_in, float *attn_sparsity,\n                           int layer_idx, int generate_token_idx, int q_len,\n                           int batch_size, int max_block_num, int *block_table,\n                           int *cache_seqlens, int *block_table_origin,\n                           int *cache_seqlens_origin, int max_block_num_origin,\n                           int topk, int local, Backend *backend);\n\n    void get_all_kvcache_one_layer(int layer_id, ggml_fp16_t *k_in,\n                                   ggml_fp16_t *v_in, Backend *backend);\n\n  private:\n    // Persistent data\n    KVCacheConfig config_;\n    int n_gqa_;                            // q_head_num / kv_head_num\n    int cache_total_len_;                  // Number of tokens in cache\n    std::vector<uint64_t> past_block_num_; // [layer_num]\n    std::vector<std::vector<std::vector<std::vector<block_q4_0>>>>\n        k_cache_q4; // [layer_num, kv_head_num, past_block_num, 
block_len *\n                    // (head_dim / QK_4)]\n    std::vector<std::vector<std::vector<std::vector<block_q4_0>>>>\n        v_cache_q4; // [layer_num, kv_head_num, past_block_num, head_dim *\n                    // (block_len / QK_4)]\n    std::vector<std::vector<std::vector<std::vector<block_q8_0>>>>\n        k_cache_q8; // [layer_num, kv_head_num, past_block_num, block_len *\n                    // (head_dim / QK_8)]\n    std::vector<std::vector<std::vector<std::vector<block_q8_0>>>>\n        v_cache_q8; // [layer_num, kv_head_num, past_block_num, head_dim *\n                    // (block_len / QK_8)]\n\n    std::vector<std::vector<std::vector<std::vector<ggml_fp16_t>>>>\n        k_cache_fp16_; // [layer_num, kv_head_num, past_block_num, block_len *\n                       // head_dim]\n    std::vector<std::vector<std::vector<std::vector<ggml_fp16_t>>>>\n        v_cache_fp16_; // [layer_num, kv_head_num, past_block_num, head_dim *\n                       // block_len]\n\n    std::vector<std::vector<std::vector<std::vector<ggml_fp16_t>>>>\n        importance_; // [layer_num, past_block_num, block_len,\n                     // attention_head_num]\n\n    std::vector<ggml_fp16_t>\n        anchor_; // [layer_num * past_block_num * anchor_num *\n                 // attention_head_num * head_dim]\n\n    // Runtime data\n    int64_t layer_id_;\n    int64_t block_idx_;\n    int *block_table_;\n    uint64_t block_num_;\n    int max_block_num_after_retrieval_;\n\n    // Rotary positional embeddings\n    std::vector<std::vector<ggml_fp16_t>> sin_; // [seq_len, head_dim]\n    std::vector<std::vector<ggml_fp16_t>> cos_; // [seq_len, head_dim]\n\n    // update/get\n    int seq_len_;\n    uint16_t *k_scales_;        // q4_0\n    uint8_t *k_in_;             // q4_0\n    uint16_t *v_scales_;        // q4_0\n    uint8_t *v_in_;             // q4_0\n    uint16_t *k_data_;          // fp16\n    uint16_t *v_data_;          // fp16\n    uint16_t *importance_data_; // fp16\n    
uint16_t *anchor_data_;     // fp16\n\n    // sparsity = (sigma(block lse / lse))\n    std::vector<std::vector<std::vector<float>>>\n        block_lse_; // [batch_size, max_block_num, q_head_num]\n    std::vector<std::vector<float>> attn_sparsity_; // [batch_size, q_head_num]\n\n    // attn\n    std::vector<std::vector<float>>\n        avg_q; // [batch_size, q_head_num * head_dim]\n\n    std::vector<std::vector<ggml_fp16_t>>\n        avg_q_fp16; // [batch_size, q_head_num * head_dim]\n    std::vector<\n        std::priority_queue<std::pair<float, int>,\n                            std::vector<std::pair<float, int>>, std::greater<>>>\n        top_similar_block_;\n\n    std::vector<std::vector<float>> block_similar_;\n    std::vector<std::vector<std::vector<float>>> block_similar_kv_head_;\n    std::vector<std::vector<std::vector<float>>> block_similar_q_head_;\n\n    std::vector<int> cache_seqlens_;               // [batch_size]\n    std::vector<int> selected_blocks_num_history_; // [layer_num // layer_step]\n\n    std::vector<std::vector<std::vector<int>>> selected_blocks_history_;\n    // [layer_num // layer_step, batch_size, max_block_num]\n\n    std::vector<std::vector<std::vector<std::vector<int>>>>\n        selected_blocks_history_kvhead_; // [layer_num // layer_step,\n                                         // batch_size, max_block_num,\n                                         // kv_head_num]\n\n    std::vector<std::vector<int>>\n        block_table_before_retrieval_; // [batch_size, max_block_num]\n    std::vector<std::vector<int>>\n        block_table_after_retrieval_; // [batch_size, pick_block_num]\n\n    std::vector<std::vector<std::vector<int>>>\n        block_table_before_retrieval_qhead_; // [batch_size, max_block_num,\n                                             // q_head_num]\n    std::vector<std::vector<std::vector<int>>>\n        block_table_after_retrieval_qhead_; // [batch_size, pick_block_num,\n                                            // 
q_head_num]\n\n    std::vector<std::vector<std::vector<int>>>\n        block_table_before_retrieval_kvhead_; // [batch_size, max_block_num,\n                                              // kv_head_num]\n    std::vector<std::vector<std::vector<int>>>\n        block_table_after_retrieval_kvhead_; // [batch_size, pick_block_num,\n                                             // kv_head_num]\n\n    std::vector<std::vector<std::unique_ptr<std::mutex>>>\n        mutex_; // [batch_size, kv_head_num]\n    std::vector<std::vector<std::vector<block_q8_0>>>\n        q_q8_0_; // [batch_size, kv_head_num, n_gqa * head_dim / QK8_0]\n    std::vector<std::vector<std::vector<float>>>\n        q_fp32_; // [batch_size, kv_head_num, n_gqa * head_dim]\n\n    std::vector<std::vector<std::vector<float>>>\n        output_fp32_; // [batch_size, kv_head_num, n_gqa * head_dim]\n    std::vector<std::vector<std::vector<float>>>\n        attn_lse_; // [batch_size, kv_head_num, n_gqa]\n\n    std::vector<std::pair<int, int>> thread_cur_head_idx_; // [thread_num]\n\n    std::vector<std::vector<block_q8_0>>\n        thread_local_output_q8_0_; // [thread_num, n_gqa * head_dim / QK8_0]\n    std::vector<std::vector<float>>\n        thread_local_attn_score_; // [thread_num, n_gqa * block_len]\n    std::vector<std::vector<float>>\n        thread_local_output_fp32_; // [thread_num, n_gqa * head_dim]\n    std::vector<std::vector<float>>\n        thread_local_attn_lse_; // [thread_num, n_gqa]\n    std::vector<std::vector<float>>\n        thread_local_cur_output_fp32_; // [thread_num, n_gqa * head_dim]\n    std::vector<std::vector<float>>\n        thread_local_cur_attn_lse_; // [thread_num, n_gqa]\n    std::vector<std::vector<uint8_t>>\n        thread_local_attn_mask_; // [thread_num, block_len // 8]\n    std::vector<std::vector<char>>\n        thread_local_draft_; // [thread_num, 2 * n_gqa * block_len + 6 * n_gqa *\n                             // head_dim + 2 * block_len * head_dim]\n\n    // tmp space\n  
  std::vector<float> q_fp32; // [n_gqa * head_dim]\n\n    void quantize_q_(const uint16_t *q_in_data, int batch_size);\n    void attn_initialize_layer_(int batch_size, int layer_idx, int *block_table,\n                                int &max_block_num, int *cache_seqlens);\n    void attn_initialize_kvhead_(int batch_size, int layer_idx,\n                                 int *block_table, int &max_block_num,\n                                 int *cache_seqlens);\n    void retrieval_kvcache_layer_(const uint16_t *q_in_data, int init_block_num,\n                                  int local_block_num, int pick_block_num,\n                                  int q_len, int generate_token_idx,\n                                  int batch_size, int layer_idx,\n                                  int *cache_seqlens, int &max_block_num,\n                                  Backend *backend);\n    void retrieval_kvcache_kvhead_(const uint16_t *q_in_data,\n                                   int init_block_num, int local_block_num,\n                                   int pick_block_num, int q_len,\n                                   int generate_token_idx, int batch_size,\n                                   int layer_idx, int *cache_seqlens,\n                                   int &max_block_num, Backend *backend);\n\n    void calculate_block_similarity_layer_(\n        const uint16_t *q_in_data, int batch_size, int layer_idx, int q_len,\n        int max_block_num, int *cache_seqlens, int init_block_num,\n        int local_block_num, int pick_block_num, Backend *backend);\n    void calculate_block_similarity_kvhead_(\n        const uint16_t *q_in_data, int batch_size, int layer_idx, int q_len,\n        int max_block_num, int *cache_seqlens, int init_block_num,\n        int local_block_num, int pick_block_num, Backend *backend);\n\n    void select_block_layer_(int batch_size, int layer_idx, int max_block_num,\n                             int init_block_num, int local_block_num,\n    
                         int pick_block_num);\n    void select_block_kvhead_(int batch_size, int layer_idx, int max_block_num,\n                              int init_block_num, int local_block_num,\n                              int pick_block_num);\n\n    void calculate_sparsity_layer_(const uint16_t *q_in_data,\n                                   float *attn_sparsity, int batch_size,\n                                   int max_block_num, int *block_table,\n                                   int *cache_seqlens, Backend *backend);\n    void calculate_sparsity_kvhead_(const uint16_t *q_in_data,\n                                    float *attn_sparsity, int batch_size,\n                                    int max_block_num, int *block_table,\n                                    int *cache_seqlens, Backend *backend);\n\n    void attention_kvhead_(const uint16_t *q_in_data, ggml_fp16_t *output,\n                           float *attn_lse, int batch_size, Backend *backend);\n    void attention_layer_(const uint16_t *q_in_data, ggml_fp16_t *output,\n                          float *attn_lse, int batch_size, Backend *backend);\n\n    /**\n     * @brief Computes attention with KV cache for one block.\n     *\n     * This function performs attention computation for one block using KV\n     * cache. The function supports different data types for Q, K, and V caches,\n     * and provides options for quantization. The function does not perform any\n     * dynamic memory allocation internally, so all necessary buffers must be\n     * pre-allocated externally.\n     *\n     * @param head_dim The dimension of the head.\n     * @param bsz The batch size.\n     * @param q_type The data type of Q (GGML data type). Only supports fp16 and\n     * q8_0.\n     * @param q Pointer to the Q tensor [bsz, head_dim]. The quantization is\n     *          always applied along the head_dim dimension. The size must be\n     *          bsz * head_dim/32 * qtype_size. 
If head_dim % 32 != 0, an error\n     *          will be raised.\n     * @param past_kv_len The length of the past KV cache.\n     * @param past_kv_offset The offset in the past KV cache.\n     * @param is_full_attn Boolean flag indicating whether to use full attention\n     *                     (true for full 1 mask).\n     * @param attn_mask Pointer to the attention mask [bsz, past_kv_len]. If\n     *                  is_full_attn = false, a bit matrix is passed to\n     * represent the mask.\n     * @param k_type The data type of K cache (GGML data type). Only supports\n     *               fp16, q4_0, and q8_0.\n     * @param k_quant_type Quantization type for K cache. 0 for per_token, 1 for\n     *                     per_channel. Other values will raise an error.\n     * @param k_cache Pointer to the K cache tensor [seq_len, head_dim]. If\n     *                quant_type == 0, head_dim % 32 must be 0. If quant_type ==\n     * 1, seq_len % 32 must be 0.\n     * @param num_k_anchor The number of K anchors. If num_k_anchor == 0, it\n     * means no anchor is present.\n     * @param k_cache_anchors Pointer to the K cache anchors [num_k_anchor,\n     * head_dim]. The k_anchor_type must be fp16.\n     * @param k_cache_anchor_pos Pointer to the K cache anchor positions. 
Each\n     * token is associated with the nearest previous anchor position.\n     * @param v_type The data type of V cache (GGML data type).\n     * @param v_quant_type Quantization type for V cache.\n     * @param v_cache Pointer to the V cache tensor [head_dim, seq_len].\n     * @param num_v_anchor The number of V anchors.\n     * @param v_cache_anchors Pointer to the V cache anchors.\n     * @param v_cache_anchor_pos Pointer to the V cache anchor positions.\n     * @param attn_score Pre-allocated buffer for attention scores [bsz,\n     * past_kv_len].\n     * @param output Output tensor [bsz, head_dim] with the same type as q_type.\n     * @param lse Pre-allocated buffer [bsz] for the log-sum-exp of the\n     * attention scores.\n     * @param draft Pre-allocated temporary buffer. The buffer size should be\n     * enough to hold (2 * bsz * past_kv_len + 6 * bsz * head_dim + 2 *\n     *              past_kv_len * head_dim + past_kv_len * head_dim / 32) bytes.\n     * @param rotary_angle Pointer to the rotary angle tensor.\n     * @param rotary_cos Pointer to the cosine values for rotary embedding.\n     * @param rotary_sin Pointer to the sine values for rotary embedding.\n     */\n    void attn_with_kvcache_one_block_(\n        int head_dim, int bsz,\n        ggml_type q_type, // GGML data type of `Q`, only supports fp16 and q8_0\n        // [bsz, head_dim]\n        // Quantization is always on the head_dim dimension (per_token). If\n        // head_dim % 32 != 0, an error will be raised. The size must be bsz *\n        // head_dim/32 * qtype_size.\n        const void *q,\n\n        int past_kv_len, int past_kv_offset,\n        bool is_full_attn, // true indicates a full 1 mask\n        // If is_full_attn = false, a bit matrix representing the mask is\n        // passed. 
[bsz, past_kv_len]\n        const uint8_t *attn_mask,\n\n        ggml_type k_type, // GGML data type of `K Cache`, only supports fp16,\n                          // q4_0, q8_0\n        int k_quant_type, // 0 for per_token, 1 for per_channel, others raise an\n                          // error\n        // [seq_len, head_dim]\n        // If quant_type == 0, head_dim % 32 must be 0.\n        // If quant_type == 1, seq_len % 32 must be 0.\n        const void *k_cache,\n\n        // k_anchor_type must be fp16\n        int num_k_anchor, // num_k_anchor == 0 indicates no anchor\n        // [num_k_anchor, head_dim]\n        const void *k_cache_anchors,\n        // Each token is associated with the nearest previous position's anchor,\n        // with the same distance.\n        const int *k_cache_anchor_pos,\n\n        // v_cache similar to k_cache\n        ggml_type v_type, int v_quant_type,\n        // [head_dim, seq_len]\n        const void *v_cache, int num_v_anchor, const void *v_cache_anchors,\n        const int *v_cache_anchor_pos,\n\n        // Pre-allocated buffer for intermediate calculations [bsz,\n        // past_kv_len]. 
No malloc is performed inside this function.\n        float *attn_score,\n\n        // Output: [bsz, head_dim], with the same type as q_type\n        void *output,\n        // [bsz]\n        float *lse,\n\n        // Pre-allocated temporary buffer with sufficient size:\n        // (2 * bsz * past_kv_len + 6 * bsz * head_dim + 2 * past_kv_len *\n        // head_dim + past_kv_len * head_dim / 32) bytes.\n        void *draft,\n\n        // Apply rotary embedding online\n        const int *rotary_angle, const void *rotary_cos, const void *rotary_sin\n        // rotary_cos=None,\n        // rotary_sin=None,\n        // cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,\n        // cache_batch_idx: Optional[torch.Tensor] = None,\n        // rotary_interleaved=True,\n\n        // // Not supported for now\n        // window_size=(-1, -1),  # -1 means infinite context window\n        // alibi_slopes=None,\n    );\n};\n\n/**\n * @brief Scales a float32 vector by a given scalar value.\n *\n * This function multiplies each element of the input vector `y` by a scalar\n * `v`. It uses platform-specific optimizations if available, such as Apple's\n * Accelerate framework or SIMD instructions. If no specific optimization is\n * available, the function falls back to a simple scalar multiplication loop.\n *\n * @param n The number of elements in the vector `y`.\n * @param y The input vector to be scaled. The result will be stored in the same\n * vector.\n * @param v The scalar value by which to scale the vector.\n */\nvoid ggml_vec_scale_f32(const int n, float *y, const float v);\n#endif"
  },
  {
    "path": "archive/csrc/ktransformers_ext/operators/kvcache/kvcache_attn.cpp",
    "content": "/**\n * @Description  :\n * @Author       : Jianwei Dong\n * @Date         : 2024-08-26 22:47:06\n * @Version      : 1.0.0\n * @LastEditors  : Jianwei Dong\n * @LastEditTime : 2024-08-26 22:47:06\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n\n#include \"kvcache.h\"\n\n#include <chrono>\n\nvoid KVCache::attention_kvhead_(const uint16_t *q_in_data, ggml_fp16_t *output,\n                                float *attn_lse, int batch_size,\n                                Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n    seq_len_ = config_.block_len;\n\n    backend->do_work_stealing_job(\n        batch_size * config_.kv_head_num * max_block_num_after_retrieval_,\n        [&](int thread_id) {\n            thread_cur_head_idx_[thread_id].first = -1;\n            thread_cur_head_idx_[thread_id].second = -1;\n        },\n        [&](int task_id) {\n            int batch_id = task_id / (config_.kv_head_num *\n                                      max_block_num_after_retrieval_);\n            int head_id = (task_id % (config_.kv_head_num *\n                                      max_block_num_after_retrieval_)) /\n                          max_block_num_after_retrieval_;\n            int block_id = task_id % max_block_num_after_retrieval_;\n            int thread_id = Backend::thread_local_id;\n\n            // If the block is out of the sequence length, skip it.\n            if (cache_seqlens_[batch_id] / config_.block_len < block_id) {\n                return;\n            }\n            int block_idx =\n                block_table_after_retrieval_kvhead_[batch_id][block_id]\n                                                   [head_id];\n            if (cache_seqlens_[batch_id] / config_.block_len == block_id) {\n                int seq_len = cache_seqlens_[batch_id] % config_.block_len;\n                if (seq_len == 0)\n                    return;\n\n                // Prepare the 
attention mask for the last block.\n                int full_blocks = seq_len / 8;\n                int remaining_bits = seq_len % 8;\n                // Fill full blocks with 1s\n                for (int i = 0; i < full_blocks; ++i) {\n                    thread_local_attn_mask_[thread_id][i] = 0xFF;\n                }\n                // Fill the remaining bits in the next block\n                if (remaining_bits > 0 && full_blocks < seq_len_ / 8) {\n                    thread_local_attn_mask_[thread_id][full_blocks] =\n                        (1 << remaining_bits) - 1;\n                } else {\n                    thread_local_attn_mask_[thread_id][full_blocks] = 0;\n                }\n\n                for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) {\n                    thread_local_attn_mask_[thread_id][i] = 0;\n                }\n                if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,\n                        (void *)&q_in_data[batch_id * config_.kv_head_num *\n                                               n_gqa_ * config_.head_dim +\n                                           head_id * n_gqa_ * config_.head_dim],\n                        seq_len_, 0, false,\n                        thread_local_attn_mask_[thread_id].data(),\n                        GGML_TYPE_F16, 0,\n                        k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_F16, 1,\n                        v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        
thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, false,\n                        thread_local_attn_mask_[thread_id].data(),\n                        GGML_TYPE_Q4_0, 0,\n                        k_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_Q4_0, 1,\n                        v_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, false,\n                        thread_local_attn_mask_[thread_id].data(),\n                        GGML_TYPE_Q8_0, 0,\n                        k_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, 
GGML_TYPE_Q8_0, 1,\n                        v_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                }\n            } else {\n                if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,\n                        (void *)&q_in_data[batch_id * config_.kv_head_num *\n                                               n_gqa_ * config_.head_dim +\n                                           head_id * n_gqa_ * config_.head_dim],\n                        seq_len_, 0, true, nullptr, GGML_TYPE_F16, 0,\n                        k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_F16, 1,\n                        v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n                    
attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0, 0,\n                        k_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_Q4_0, 1,\n                        v_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0, 0,\n                        k_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_Q8_0, 1,\n                        v_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        
thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                }\n            }\n            int cur_batch_idx = thread_cur_head_idx_[thread_id].first;\n            int cur_head_id = thread_cur_head_idx_[thread_id].second;\n            if (batch_id == cur_batch_idx && head_id == cur_head_id) {\n                for (int i = 0; i < n_gqa_; i++) {\n                    float new_attn_lse =\n                        thread_local_cur_attn_lse_[thread_id][i] +\n                        std::log(\n                            1.0 +\n                            std::exp(thread_local_attn_lse_[thread_id][i] -\n                                     thread_local_cur_attn_lse_[thread_id][i]));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                        thread_local_cur_output_fp32_[thread_id].data() +\n                            i * config_.head_dim,\n                        std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                 new_attn_lse));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                        thread_local_output_fp32_[thread_id].data() +\n                            i * config_.head_dim,\n                        std::exp(thread_local_attn_lse_[thread_id][i] -\n                                 new_attn_lse));\n                    for (int j = 0; j < config_.head_dim; j++) {\n                        thread_local_cur_output_fp32_[thread_id]\n                                                     [i * config_.head_dim +\n                                                      j] +=\n                            thread_local_output_fp32_[thread_id]\n                     
                                [i * config_.head_dim + j];\n                    }\n                    thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse;\n                }\n            } else {\n                if (cur_batch_idx != -1) {\n                    mutex_[cur_batch_idx][cur_head_id]->lock();\n                    for (int i = 0; i < n_gqa_; i++) {\n                        if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <\n                            1e-6) {\n                            attn_lse_[cur_batch_idx][cur_head_id][i] =\n                                thread_local_cur_attn_lse_[thread_id][i];\n                            for (int j = 0; j < config_.head_dim; j++) {\n                                output_fp32_[cur_batch_idx][cur_head_id]\n                                            [i * config_.head_dim + j] =\n                                                thread_local_cur_output_fp32_\n                                                    [thread_id]\n                                                    [i * config_.head_dim + j];\n                            }\n                            continue;\n                        }\n                        float new_attn_lse =\n                            attn_lse_[cur_batch_idx][cur_head_id][i] +\n                            std::log(\n                                1.0 +\n                                std::exp(\n                                    thread_local_cur_attn_lse_[thread_id][i] -\n                                    attn_lse_[cur_batch_idx][cur_head_id][i]));\n                        ggml_vec_scale_f32(\n                            config_.head_dim,\n                            output_fp32_[cur_batch_idx][cur_head_id].data() +\n                                i * config_.head_dim,\n                            std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -\n                                     new_attn_lse));\n                        ggml_vec_scale_f32(\n                
            config_.head_dim,\n                            thread_local_cur_output_fp32_[thread_id].data() +\n                                i * config_.head_dim,\n                            std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                     new_attn_lse));\n                        for (int j = 0; j < config_.head_dim; j++) {\n                            output_fp32_[cur_batch_idx][cur_head_id]\n                                        [i * config_.head_dim + j] +=\n                                thread_local_cur_output_fp32_\n                                    [thread_id][i * config_.head_dim + j];\n                        }\n                        attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;\n                    }\n                    mutex_[cur_batch_idx][cur_head_id]->unlock();\n                }\n                thread_cur_head_idx_[thread_id].first = batch_id;\n                thread_cur_head_idx_[thread_id].second = head_id;\n                for (int i = 0; i < n_gqa_; i++) {\n                    thread_local_cur_attn_lse_[thread_id][i] =\n                        thread_local_attn_lse_[thread_id][i];\n                    for (int j = 0; j < config_.head_dim; j++) {\n                        thread_local_cur_output_fp32_\n                            [thread_id][i * config_.head_dim + j] =\n                                thread_local_output_fp32_[thread_id]\n                                                         [i * config_.head_dim +\n                                                          j];\n                    }\n                }\n            }\n        },\n        // Merge the results of the remaining blocks.\n        [&](int thread_id) {\n            int cur_batch_idx = thread_cur_head_idx_[thread_id].first;\n            int cur_head_id = thread_cur_head_idx_[thread_id].second;\n            if (cur_head_id != -1) {\n                mutex_[cur_batch_idx][cur_head_id]->lock();\n                
for (int i = 0; i < n_gqa_; i++) {\n                    float new_attn_lse;\n                    if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <\n                        1e-6) {\n                        attn_lse_[cur_batch_idx][cur_head_id][i] =\n                            thread_local_cur_attn_lse_[thread_id][i];\n                        for (int j = 0; j < config_.head_dim; j++) {\n                            output_fp32_[cur_batch_idx][cur_head_id]\n                                        [i * config_.head_dim + j] =\n                                            thread_local_cur_output_fp32_\n                                                [thread_id]\n                                                [i * config_.head_dim + j];\n                        }\n                        continue;\n                    }\n                    new_attn_lse =\n                        attn_lse_[cur_batch_idx][cur_head_id][i] +\n                        std::log(\n                            1.0 +\n                            std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                     attn_lse_[cur_batch_idx][cur_head_id][i]));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                        output_fp32_[cur_batch_idx][cur_head_id].data() +\n                            i * config_.head_dim,\n                        std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -\n                                 new_attn_lse));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                        thread_local_cur_output_fp32_[thread_id].data() +\n                            i * config_.head_dim,\n                        std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                 new_attn_lse));\n                    for (int j = 0; j < config_.head_dim; j++) {\n                        output_fp32_[cur_batch_idx][cur_head_id]\n                      
              [i * config_.head_dim + j] +=\n                            thread_local_cur_output_fp32_[thread_id]\n                                                         [i * config_.head_dim +\n                                                          j];\n                    }\n                    attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;\n                }\n                mutex_[cur_batch_idx][cur_head_id]->unlock();\n            }\n        });\n    // move the results to output and attn_lse\n    uint16_t *output_data = reinterpret_cast<uint16_t *>(output);\n    float *attn_lse_data = attn_lse;\n    for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n        for (int i = 0; i < config_.kv_head_num; i++) {\n            for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {\n                output_data[batch_idx * config_.kv_head_num * n_gqa_ *\n                                config_.head_dim +\n                            i * n_gqa_ * config_.head_dim + j] =\n                    GGML_FP32_TO_FP16(output_fp32_[batch_idx][i][j]);\n            }\n            for (int j = 0; j < n_gqa_; j++) {\n                attn_lse_data[batch_idx * config_.kv_head_num * n_gqa_ +\n                              i * n_gqa_ + j] = attn_lse_[batch_idx][i][j];\n            }\n        }\n    }\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> diff = end - start;\n    // printf(\"layer %d time of computing attention: %f s\\n\", layer_idx,\n    //        diff.count());\n}\n\nvoid KVCache::attention_layer_(const uint16_t *q_in_data, ggml_fp16_t *output,\n                               float *attn_lse, int batch_size,\n                               Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n    seq_len_ = config_.block_len;\n    backend->do_work_stealing_job(\n        batch_size * config_.kv_head_num * max_block_num_after_retrieval_,\n     
   [&](int thread_id) {\n            thread_cur_head_idx_[thread_id].first = -1;\n            thread_cur_head_idx_[thread_id].second = -1;\n        },\n        [&](int task_id) {\n            int batch_id = task_id / (config_.kv_head_num *\n                                      max_block_num_after_retrieval_);\n            int head_id = (task_id % (config_.kv_head_num *\n                                      max_block_num_after_retrieval_)) /\n                          max_block_num_after_retrieval_;\n            int block_id = task_id % max_block_num_after_retrieval_;\n            int thread_id = Backend::thread_local_id;\n            // If the block is out of the sequence length, skip it.\n            if (cache_seqlens_[batch_id] / config_.block_len < block_id) {\n                return;\n            }\n            int block_idx = block_table_after_retrieval_[batch_id][block_id];\n            if (cache_seqlens_[batch_id] / config_.block_len == block_id) {\n                int seq_len = cache_seqlens_[batch_id] % config_.block_len;\n                if (seq_len == 0)\n                    return;\n\n                // Prepare the attention mask for the last block.\n                int full_blocks = seq_len / 8;\n                int remaining_bits = seq_len % 8;\n\n                // Fill full blocks with 1s\n                for (int i = 0; i < full_blocks; ++i) {\n                    thread_local_attn_mask_[thread_id][i] = 0xFF;\n                }\n                // Fill the remaining bits in the next block\n                if (remaining_bits > 0 && full_blocks < seq_len_ / 8) {\n                    thread_local_attn_mask_[thread_id][full_blocks] =\n                        (1 << remaining_bits) - 1;\n                } else {\n                    thread_local_attn_mask_[thread_id][full_blocks] = 0;\n                }\n\n                for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) {\n                    thread_local_attn_mask_[thread_id][i] = 0;\n             
   }\n                if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,\n                        (void *)&q_in_data[batch_id * config_.kv_head_num *\n                                               n_gqa_ * config_.head_dim +\n                                           head_id * n_gqa_ * config_.head_dim],\n                        seq_len_, 0, false,\n                        thread_local_attn_mask_[thread_id].data(),\n                        GGML_TYPE_F16, 0,\n                        k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_F16, 1,\n                        v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, false,\n                        thread_local_attn_mask_[thread_id].data(),\n                        GGML_TYPE_Q4_0, 0,\n                        k_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_Q4_0, 1,\n                        v_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        
thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, false,\n                        thread_local_attn_mask_[thread_id].data(),\n                        GGML_TYPE_Q8_0, 0,\n                        k_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_Q8_0, 1,\n                        v_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                }\n            } else {\n                if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n                    attn_with_kvcache_one_block_(\n 
                       config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,\n                        (void *)&q_in_data[batch_id * config_.kv_head_num *\n                                               n_gqa_ * config_.head_dim +\n                                           head_id * n_gqa_ * config_.head_dim],\n                        seq_len_, 0, true, nullptr, GGML_TYPE_F16, 0,\n                        k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_F16, 1,\n                        v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0, 0,\n                        k_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_Q4_0, 1,\n                        v_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), 
sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0, 0,\n                        k_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_Q8_0, 1,\n                        v_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                }\n            }\n            int cur_batch_idx = thread_cur_head_idx_[thread_id].first;\n            int cur_head_id = thread_cur_head_idx_[thread_id].second;\n            if (batch_id == cur_batch_idx && head_id == cur_head_id) {\n                for (int i = 0; i < n_gqa_; i++) {\n                    float new_attn_lse =\n                        thread_local_cur_attn_lse_[thread_id][i] +\n                        std::log(\n                            1.0 +\n                            
std::exp(thread_local_attn_lse_[thread_id][i] -\n                                     thread_local_cur_attn_lse_[thread_id][i]));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                        thread_local_cur_output_fp32_[thread_id].data() +\n                            i * config_.head_dim,\n                        std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                 new_attn_lse));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                        thread_local_output_fp32_[thread_id].data() +\n                            i * config_.head_dim,\n                        std::exp(thread_local_attn_lse_[thread_id][i] -\n                                 new_attn_lse));\n                    for (int j = 0; j < config_.head_dim; j++) {\n                        thread_local_cur_output_fp32_[thread_id]\n                                                     [i * config_.head_dim +\n                                                      j] +=\n                            thread_local_output_fp32_[thread_id]\n                                                     [i * config_.head_dim + j];\n                    }\n                    thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse;\n                }\n            } else {\n                if (cur_batch_idx != -1) {\n                    mutex_[cur_batch_idx][cur_head_id]->lock();\n                    for (int i = 0; i < n_gqa_; i++) {\n                        if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <\n                            1e-6) {\n                            attn_lse_[cur_batch_idx][cur_head_id][i] =\n                                thread_local_cur_attn_lse_[thread_id][i];\n                            for (int j = 0; j < config_.head_dim; j++) {\n                                output_fp32_[cur_batch_idx][cur_head_id]\n                                            [i * 
config_.head_dim + j] =\n                                                thread_local_cur_output_fp32_\n                                                    [thread_id]\n                                                    [i * config_.head_dim + j];\n                            }\n                            continue;\n                        }\n                        float new_attn_lse =\n                            attn_lse_[cur_batch_idx][cur_head_id][i] +\n                            std::log(\n                                1.0 +\n                                std::exp(\n                                    thread_local_cur_attn_lse_[thread_id][i] -\n                                    attn_lse_[cur_batch_idx][cur_head_id][i]));\n                        ggml_vec_scale_f32(\n                            config_.head_dim,\n                            output_fp32_[cur_batch_idx][cur_head_id].data() +\n                                i * config_.head_dim,\n                            std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -\n                                     new_attn_lse));\n                        ggml_vec_scale_f32(\n                            config_.head_dim,\n                            thread_local_cur_output_fp32_[thread_id].data() +\n                                i * config_.head_dim,\n                            std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                     new_attn_lse));\n                        for (int j = 0; j < config_.head_dim; j++) {\n                            output_fp32_[cur_batch_idx][cur_head_id]\n                                        [i * config_.head_dim + j] +=\n                                thread_local_cur_output_fp32_\n                                    [thread_id][i * config_.head_dim + j];\n                        }\n                        attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;\n                    }\n                    
mutex_[cur_batch_idx][cur_head_id]->unlock();\n                }\n                thread_cur_head_idx_[thread_id].first = batch_id;\n                thread_cur_head_idx_[thread_id].second = head_id;\n                for (int i = 0; i < n_gqa_; i++) {\n                    thread_local_cur_attn_lse_[thread_id][i] =\n                        thread_local_attn_lse_[thread_id][i];\n                    for (int j = 0; j < config_.head_dim; j++) {\n                        thread_local_cur_output_fp32_\n                            [thread_id][i * config_.head_dim + j] =\n                                thread_local_output_fp32_[thread_id]\n                                                         [i * config_.head_dim +\n                                                          j];\n                    }\n                }\n            }\n        },\n        // Merge the results of the remaining blocks.\n        [&](int thread_id) {\n            int cur_batch_idx = thread_cur_head_idx_[thread_id].first;\n            int cur_head_id = thread_cur_head_idx_[thread_id].second;\n            if (cur_head_id != -1) {\n                mutex_[cur_batch_idx][cur_head_id]->lock();\n                for (int i = 0; i < n_gqa_; i++) {\n                    float new_attn_lse;\n                    if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <\n                        1e-6) {\n                        attn_lse_[cur_batch_idx][cur_head_id][i] =\n                            thread_local_cur_attn_lse_[thread_id][i];\n                        for (int j = 0; j < config_.head_dim; j++) {\n                            output_fp32_[cur_batch_idx][cur_head_id]\n                                        [i * config_.head_dim + j] =\n                                            thread_local_cur_output_fp32_\n                                                [thread_id]\n                                                [i * config_.head_dim + j];\n                        }\n                        
continue;\n                    }\n                    new_attn_lse =\n                        attn_lse_[cur_batch_idx][cur_head_id][i] +\n                        std::log(\n                            1.0 +\n                            std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                     attn_lse_[cur_batch_idx][cur_head_id][i]));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                        output_fp32_[cur_batch_idx][cur_head_id].data() +\n                            i * config_.head_dim,\n                        std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -\n                                 new_attn_lse));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                        thread_local_cur_output_fp32_[thread_id].data() +\n                            i * config_.head_dim,\n                        std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                 new_attn_lse));\n                    for (int j = 0; j < config_.head_dim; j++) {\n                        output_fp32_[cur_batch_idx][cur_head_id]\n                                    [i * config_.head_dim + j] +=\n                            thread_local_cur_output_fp32_[thread_id]\n                                                         [i * config_.head_dim +\n                                                          j];\n                    }\n                    attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;\n                }\n                mutex_[cur_batch_idx][cur_head_id]->unlock();\n            }\n        });\n\n    // move the results to output and attn_lse\n    uint16_t *output_data = reinterpret_cast<uint16_t *>(output);\n    float *attn_lse_data = attn_lse;\n    for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n        for (int i = 0; i < config_.kv_head_num; i++) {\n            for (int j = 0; j < n_gqa_ * 
config_.head_dim; j++) {\n                output_data[batch_idx * config_.kv_head_num * n_gqa_ *\n                                config_.head_dim +\n                            i * n_gqa_ * config_.head_dim + j] =\n                    GGML_FP32_TO_FP16(output_fp32_[batch_idx][i][j]);\n            }\n            for (int j = 0; j < n_gqa_; j++) {\n                attn_lse_data[batch_idx * config_.kv_head_num * n_gqa_ +\n                              i * n_gqa_ + j] = attn_lse_[batch_idx][i][j];\n            }\n        }\n    }\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> diff = end - start;\n    //     printf(\"layer %d time of computing attention: %f s\\n\", layer_id_,\n    //     diff.count());\n}\n\nvoid KVCache::attn(const ggml_fp16_t *q_in, ggml_fp16_t *output,\n                   float *attn_lse, int layer_idx, int generate_token_idx,\n                   int q_len, int batch_size, int max_block_num,\n                   int *block_table, int *cache_seqlens, int pick_block_num,\n                   int init_block_num, int local_block_num, Backend *backend) {\n\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n    layer_id_ = layer_idx;\n    batch_size = batch_size * q_len;\n\n    const uint16_t *q_in_data = const_cast<const uint16_t *>(q_in);\n\n    quantize_q_(q_in_data, batch_size);\n    if (config_.retrieval_type == RetrievalType::LAYER) {\n        attn_initialize_layer_(batch_size, layer_idx, block_table,\n                               max_block_num, cache_seqlens);\n        retrieval_kvcache_layer_(q_in_data, init_block_num, local_block_num,\n                                 pick_block_num, q_len, generate_token_idx,\n                                 batch_size, layer_idx, cache_seqlens,\n                                 max_block_num, backend);\n        attention_layer_(q_in_data, output, attn_lse, batch_size, backend);\n    } else if (config_.retrieval_type 
== RetrievalType::KVHEAD) {\n        attn_initialize_kvhead_(batch_size, layer_idx, block_table,\n                                max_block_num, cache_seqlens);\n        retrieval_kvcache_kvhead_(q_in_data, init_block_num, local_block_num,\n                                  pick_block_num, q_len, generate_token_idx,\n                                  batch_size, layer_idx, cache_seqlens,\n                                  max_block_num, backend);\n        attention_kvhead_(q_in_data, output, attn_lse, batch_size, backend);\n    }\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> diff = end - start;\n    // printf(\"layer %d time of computing attention: %f s\\n\", layer_idx,\n    //        diff.count());\n}\n\nvoid KVCache::attn_with_kvcache(\n    const ggml_fp16_t *q_in, const ggml_fp16_t *k_in, const ggml_fp16_t *v_in,\n    ggml_fp16_t *output, float *attn_lse, int layer_idx, int generate_token_idx,\n    int q_len, int batch_size, int max_block_num, int *block_table,\n    int *cache_seqlens, int topk, int local, Backend *backend) {\n    //    printf(\"attn_with_kvcache start\\n\");\n    assert(q_len == 1);\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    layer_id_ = layer_idx;\n\n    update_kvcache_fp16(k_in, v_in, layer_idx, block_table, batch_size,\n                        max_block_num, cache_seqlens, q_len, backend);\n    //    printf(\"update finished.\\n\");\n\n    // cache_seqlens memory is modified.\n    for (int i = 0; i < batch_size; i++) {\n        cache_seqlens[i] += q_len;\n    }\n    int init_block_num = 1;\n    if (config_.block_len <= 32) {\n        init_block_num = 64 / config_.block_len;\n    }\n\n    attn(q_in, output, attn_lse, layer_idx, generate_token_idx, q_len,\n         batch_size, max_block_num, block_table, cache_seqlens, topk,\n         init_block_num, local, backend);\n\n    // Timer end\n    auto end = 
std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> diff = end - start;\n    //     printf(\"layer %d time of computing attention with kvcache: %f s\\n\",\n    //     layer_idx, diff.count());\n}\n\nvoid KVCache::quantize_q_(const uint16_t *q_in_data, int batch_size) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n    for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n        if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n            // quantize q\n            for (int i = 0; i < config_.kv_head_num; i++) {\n                for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {\n                    q_fp32_[batch_idx][i][j] = GGML_FP16_TO_FP32(\n                        q_in_data[batch_idx * config_.kv_head_num * n_gqa_ *\n                                      config_.head_dim +\n                                  i * n_gqa_ * config_.head_dim + j]);\n                }\n            }\n        } else {\n            // quantize q\n            for (int i = 0; i < config_.kv_head_num; i++) {\n                for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {\n                    q_fp32[j] = GGML_FP16_TO_FP32(\n                        q_in_data[batch_idx * config_.kv_head_num * n_gqa_ *\n                                      config_.head_dim +\n                                  i * n_gqa_ * config_.head_dim + j]);\n                }\n                quantize_row_q8_0(q_fp32.data(), q_q8_0_[batch_idx][i].data(),\n                                  n_gqa_ * config_.head_dim);\n            }\n        }\n    }\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    // printf(\"time of quantizing q: %f s\\n\",\n    //        std::chrono::duration<double>(end - start).count());\n}\nvoid KVCache::attn_initialize_layer_(int batch_size, int layer_idx,\n                                     int *block_table, int &max_block_num,\n                                     int 
*cache_seqlens) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n    for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n        // initialize output_fp32_ and attn_lse_\n        for (int i = 0; i < config_.kv_head_num; i++) {\n            for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {\n                output_fp32_[batch_idx][i][j] = 0;\n            }\n            for (int j = 0; j < n_gqa_; j++) {\n                attn_lse_[batch_idx][i][j] = 0;\n            }\n        }\n        // clear top_similar_block_\n\n        while (!top_similar_block_[batch_idx].empty())\n            top_similar_block_[batch_idx].pop();\n    }\n\n    // get block_table_before_retrieval_ and cache_seqlens_\n    if (block_table == nullptr) {\n        max_block_num = past_block_num_[layer_idx];\n        for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n            if (cache_total_len_ != 0)\n                cache_seqlens_[batch_idx] = cache_total_len_;\n            else\n                cache_seqlens_[batch_idx] = max_block_num * config_.block_len;\n            for (int i = 0; i < max_block_num; i++) {\n                block_table_before_retrieval_[batch_idx][i] = i;\n                block_similar_[batch_idx][i] = 0;\n            }\n        }\n    } else {\n        for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n            cache_seqlens_[batch_idx] = cache_seqlens[batch_idx];\n            for (int i = 0; i < max_block_num; i++) {\n                block_table_before_retrieval_[batch_idx][i] =\n                    block_table[batch_idx * max_block_num + i];\n                block_similar_[batch_idx][i] = 0;\n            }\n        }\n    }\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    // printf(\"layer %d time of initializing attention: %f s\\n\", layer_idx,\n    //        std::chrono::duration<double>(end - start).count());\n}\n\nvoid 
KVCache::calculate_block_similarity_layer_(\n    const uint16_t *q_in_data, int batch_size, int layer_idx, int q_len,\n    int max_block_num, int *cache_seqlens, int init_block_num,\n    int local_block_num, int pick_block_num, Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    if (batch_size == 1 &&\n        config_.anchor_num == 1) { // TODO: improve batch_size > 1\n        for (int batch_id = 0; batch_id < batch_size; batch_id++) {\n            if (q_len == 1) {\n                for (int j = 0; j < config_.head_dim * config_.q_head_num;\n                     j++) {\n                    avg_q[batch_id][j] = GGML_FP16_TO_FP32(\n                        q_in_data[batch_id * q_len * config_.q_head_num *\n                                      config_.head_dim +\n                                  j]);\n                    avg_q_fp16[batch_id][j] =\n                        q_in_data[batch_id * q_len * config_.q_head_num *\n                                      config_.head_dim +\n                                  j];\n                }\n            } else {\n                for (int j = 0; j < config_.head_dim * config_.q_head_num;\n                     j++) {\n                    avg_q[batch_id][j] = 0;\n                }\n                for (int i = 0; i < q_len; i++) {\n                    for (int j = 0; j < config_.head_dim; j++) {\n                        avg_q[batch_id][j] += GGML_FP16_TO_FP32(\n                            q_in_data[batch_id * q_len * config_.q_head_num *\n                                          config_.head_dim +\n                                      i * config_.q_head_num *\n                                          config_.head_dim +\n                                      j]);\n                    }\n                }\n                for (int j = 0; j < config_.head_dim * config_.q_head_num;\n                     j++) {\n                    avg_q[batch_id][j] /= q_len;\n             
       avg_q_fp16[batch_id][j] =\n                        GGML_FP32_TO_FP16(avg_q[batch_id][j]);\n                }\n            }\n            int seq_len = cache_seqlens_[batch_id];\n            int block_num = (seq_len / config_.block_len) - local_block_num -\n                            init_block_num;\n            if (block_num <= 0) {\n                continue;\n            }\n            bool is_seq = true;\n            for (int i = init_block_num + 1;\n                 i < (seq_len / config_.block_len) - local_block_num; i++) {\n                if (block_table_before_retrieval_[batch_id][i] !=\n                    block_table_before_retrieval_[batch_id][i - 1] + 1) {\n                    is_seq = false;\n                    break;\n                }\n            }\n            if (is_seq) {\n                int nth = backend->get_thread_num();\n                backend->do_work_stealing_job(\n                    nth, nullptr,\n                    [&](int task_id) {\n                        int ith = task_id;\n                        bool ok = llamafile_sgemm(\n                            block_num, 1, config_.q_head_num * config_.head_dim,\n                            anchor_.data() +\n                                (layer_idx * config_.max_block_num +\n                                 block_table_before_retrieval_\n                                     [batch_id][init_block_num]) *\n                                    config_.anchor_num * config_.q_head_num *\n                                    config_.head_dim,\n                            config_.q_head_num * config_.head_dim,\n                            avg_q_fp16[batch_id].data(),\n                            config_.q_head_num * config_.head_dim,\n                            block_similar_[batch_id].data() + init_block_num,\n                            block_num, ith, nth, GGML_TASK_TYPE_COMPUTE,\n                            GGML_TYPE_F16, GGML_TYPE_F16, GGML_TYPE_F32,\n                            
GGML_PREC_DEFAULT);\n                        if (!ok) {\n                            printf(\"llamafile_sgemm failed\\n\");\n                        }\n                    },\n                    nullptr);\n            } else {\n                backend->do_work_stealing_job(\n                    block_num, nullptr,\n                    [&](int task_id) {\n                        int block_id = task_id + init_block_num;\n                        int block_idx =\n                            block_table_before_retrieval_[batch_id][block_id];\n                        bool ok = llamafile_sgemm(\n                            1, 1, config_.q_head_num * config_.head_dim,\n                            anchor_.data() +\n                                (layer_idx * config_.max_block_num +\n                                 block_idx) *\n                                    config_.anchor_num * config_.q_head_num *\n                                    config_.head_dim,\n                            config_.q_head_num * config_.head_dim,\n                            avg_q_fp16[batch_id].data(),\n                            config_.q_head_num * config_.head_dim,\n                            block_similar_[batch_id].data() + block_id, 1, 0, 1,\n                            GGML_TASK_TYPE_COMPUTE, GGML_TYPE_F16,\n                            GGML_TYPE_F16, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n                        if (!ok) {\n                            printf(\"llamafile_sgemm failed\\n\");\n                        }\n                    },\n                    nullptr);\n            }\n        }\n    } else {\n        backend->do_work_stealing_job(\n            batch_size * max_block_num, nullptr,\n            [&](int task_id) {\n                int batch_id = task_id / max_block_num;\n                int block_id = task_id % max_block_num;\n                int seq_len = 
cache_seqlens_[batch_id];\n\n                if (block_id < init_block_num ||\n                    block_id >=\n                        (seq_len / config_.block_len) - local_block_num) {\n                    return;\n                }\n\n                int block_idx =\n                    block_table_before_retrieval_[batch_id][block_id];\n                float sim = 0;\n\n                for (int head_id = 0; head_id < config_.q_head_num; head_id++) {\n                    for (int i = 0; i < config_.head_dim; i++) {\n                        float q_i = 0,\n                              qa_i = std::numeric_limits<float>::lowest();\n                        for (int q_id = 0; q_id < q_len; q_id++) {\n                            q_i += GGML_FP16_TO_FP32(\n                                q_in_data[batch_id * q_len *\n                                              config_.q_head_num *\n                                              config_.head_dim +\n                                          q_id * config_.q_head_num *\n                                              config_.head_dim +\n                                          head_id * config_.head_dim + i]);\n                        }\n                        q_i /= q_len;\n                        for (int anchor_id = 0; anchor_id < config_.anchor_num;\n                             anchor_id++) {\n                            qa_i = std::max(\n                                qa_i,\n                                GGML_FP16_TO_FP32(\n                                    anchor_[(long long)layer_idx *\n                                                config_.max_block_num *\n                                                config_.anchor_num *\n                                                config_.q_head_num *\n                                                config_.head_dim +\n                                            block_idx * config_.anchor_num *\n                                                config_.q_head_num 
*\n                                                config_.head_dim +\n                                            anchor_id * config_.q_head_num *\n                                                config_.head_dim +\n                                            head_id * config_.head_dim + i]) *\n                                    q_i);\n                        }\n                        sim += qa_i;\n                    }\n                }\n                block_similar_[batch_id][block_id] = sim;\n            },\n            nullptr);\n    }\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> diff = end - start;\n    // printf(\"layer %d time of calculating similarity: %f s\\n\", layer_idx,\n    //        diff.count());\n}\n\nvoid KVCache::select_block_layer_(int batch_size, int layer_idx,\n                                  int max_block_num, int init_block_num,\n                                  int local_block_num, int pick_block_num) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n\n        if (cache_seqlens_[batch_idx] / config_.block_len <=\n            init_block_num + pick_block_num + local_block_num) {\n            block_table_after_retrieval_[batch_idx].swap(\n                block_table_before_retrieval_[batch_idx]);\n            selected_blocks_num_history_[(layer_idx - config_.layer_offset) /\n                                         config_.layer_step] = 0;\n            continue;\n        }\n\n        for (int block_id = init_block_num;\n             block_id <\n             (cache_seqlens_[batch_idx] / config_.block_len) - local_block_num;\n             block_id++) {\n            top_similar_block_[batch_idx].push(std::make_pair(\n                block_similar_[batch_idx][block_id],\n                block_table_before_retrieval_[batch_idx][block_id]));\n            if 
(top_similar_block_[batch_idx].size() > pick_block_num) {\n                top_similar_block_[batch_idx].pop();\n            }\n        }\n\n        int i = 0;\n        for (; i < init_block_num; i++) {\n            block_table_after_retrieval_[batch_idx][i] =\n                block_table_before_retrieval_[batch_idx][i];\n        }\n        while (!top_similar_block_[batch_idx].empty()) {\n            block_table_after_retrieval_[batch_idx][i] =\n                top_similar_block_[batch_idx].top().second;\n            top_similar_block_[batch_idx].pop();\n            i++;\n        }\n        for (; i < init_block_num + pick_block_num + local_block_num; i++) {\n            block_table_after_retrieval_[batch_idx][i] =\n                block_table_before_retrieval_[batch_idx]\n                                             [(cache_seqlens_[batch_idx] /\n                                               config_.block_len) -\n                                              local_block_num + i -\n                                              init_block_num - pick_block_num];\n        }\n        if (cache_seqlens_[batch_idx] % config_.block_len != 0) {\n            block_table_after_retrieval_[batch_idx][i] =\n                block_table_before_retrieval_[batch_idx][(\n                    cache_seqlens_[batch_idx] / config_.block_len)];\n            cache_seqlens_[batch_idx] =\n                (cache_seqlens_[batch_idx] % config_.block_len) +\n                i * config_.block_len;\n            i++;\n        } else {\n            cache_seqlens_[batch_idx] =\n                (cache_seqlens_[batch_idx] % config_.block_len) +\n                i * config_.block_len;\n        }\n        for (int j = 0; j < i; j++) {\n            selected_blocks_history_[(layer_idx - config_.layer_offset) /\n                                     config_.layer_step][batch_idx][j] =\n                block_table_after_retrieval_[batch_idx][j];\n        }\n        selected_blocks_num_history_[(layer_idx - 
config_.layer_offset) /\n                                     config_.layer_step] = i;\n    }\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> diff = end - start;\n    // printf(\"layer %d time of selecting blocks: %f s\\n\", layer_idx,\n    //        diff.count());\n}\n\n// Retrieve the kvcache: take the first init_block_num blocks, the top\n// pick_block_num most similar blocks, and the last local_block_num blocks.\n// Each task calculates the similarity of a certain block with the query; the\n// block is then pushed into the priority queue. Finally, the required blocks\n// are written into block_table_after_retrieval_.\nvoid KVCache::retrieval_kvcache_layer_(const uint16_t *q_in_data,\n                                       int init_block_num, int local_block_num,\n                                       int pick_block_num, int q_len,\n                                       int generate_token_idx, int batch_size,\n                                       int layer_idx, int *cache_seqlens,\n                                       int &max_block_num, Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n    max_block_num_after_retrieval_ = 0;\n    if (pick_block_num != -1 &&\n        (generate_token_idx % config_.token_step != 0 ||\n         (layer_idx % config_.layer_step != config_.layer_offset))) {\n\n        if (selected_blocks_num_history_[(layer_idx - config_.layer_offset) /\n                                         config_.layer_step] == 0) {\n            max_block_num_after_retrieval_ = max_block_num;\n            block_table_after_retrieval_.swap(block_table_before_retrieval_);\n        } else {\n            max_block_num_after_retrieval_ = selected_blocks_num_history_\n                [(layer_idx - config_.layer_offset) / config_.layer_step];\n            for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n                for (int i = 0; i < 
max_block_num_after_retrieval_; i++) {\n                    block_table_after_retrieval_[batch_idx][i] =\n                        selected_blocks_history_[(layer_idx -\n                                                  config_.layer_offset) /\n                                                 config_.layer_step][batch_idx]\n                                                [i];\n                }\n\n                if (cache_seqlens[batch_idx] % config_.block_len == 1) {\n                    selected_blocks_num_history_[(layer_idx -\n                                                  config_.layer_offset) /\n                                                 config_.layer_step] += 1;\n                    int x =\n                        selected_blocks_num_history_[(layer_idx -\n                                                      config_.layer_offset) /\n                                                     config_.layer_step];\n                    int last_block_idx =\n                        block_table_before_retrieval_[batch_idx]\n                                                     [cache_seqlens[batch_idx] /\n                                                      config_.block_len];\n                    selected_blocks_history_[(layer_idx -\n                                              config_.layer_offset) /\n                                             config_.layer_step][batch_idx]\n                                            [x - 1] = last_block_idx;\n                    block_table_after_retrieval_[batch_idx][x - 1] =\n                        last_block_idx;\n                }\n                cache_seqlens_[batch_idx] =\n                    (cache_seqlens_[batch_idx] % config_.block_len) +\n                    selected_blocks_num_history_[(layer_idx -\n                                                  config_.layer_offset) /\n                                                 config_.layer_step] *\n                        config_.block_len -\n                  
  config_.block_len;\n            }\n        }\n    } else if (pick_block_num != -1) {\n        max_block_num_after_retrieval_ =\n            std::min(max_block_num,\n                     init_block_num + pick_block_num + local_block_num + 1);\n        calculate_block_similarity_layer_(q_in_data, batch_size, layer_idx,\n                                          q_len, max_block_num, cache_seqlens,\n                                          init_block_num, local_block_num,\n                                          pick_block_num, backend);\n        select_block_layer_(batch_size, layer_idx, max_block_num,\n                            init_block_num, local_block_num, pick_block_num);\n    } else {\n        selected_blocks_num_history_[(layer_idx - config_.layer_offset) /\n                                     config_.layer_step] = 0;\n        max_block_num_after_retrieval_ = max_block_num;\n        block_table_after_retrieval_.swap(block_table_before_retrieval_);\n    }\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    //     printf(\"layer %d time of retrieval kvcache: %f s\\n\", layer_idx,\n    //     std::chrono::duration<double>(end - start).count());\n}\nvoid KVCache::calculate_sparsity_layer_(const uint16_t *q_in_data,\n                                        float *attn_sparsity, int batch_size,\n                                        int max_block_num, int *block_table,\n                                        int *cache_seqlens, Backend *backend\n\n) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n    seq_len_ = config_.block_len;\n    backend->do_work_stealing_job(\n        batch_size * config_.kv_head_num * max_block_num,\n        [&](int thread_id) {\n            thread_cur_head_idx_[thread_id].first = -1;\n            thread_cur_head_idx_[thread_id].second = -1;\n        },\n        [&](int task_id) {\n            int batch_id = task_id / (config_.kv_head_num * max_block_num);\n       
     int head_id = (task_id % (config_.kv_head_num * max_block_num)) /\n                          max_block_num;\n            int block_id = task_id % max_block_num;\n            int thread_id = Backend::thread_local_id;\n            // If the block is out of the sequence length, skip it.\n            if (cache_seqlens[batch_id] / config_.block_len < block_id) {\n                return;\n            }\n            int block_idx = block_table[batch_id * max_block_num + block_id];\n            if (cache_seqlens_[batch_id] / config_.block_len == block_id) {\n                int seq_len = cache_seqlens_[batch_id] % config_.block_len;\n                if (seq_len == 0)\n                    return;\n\n                // Prepare the attention mask for the last block.\n                int full_blocks = seq_len / 8;\n                int remaining_bits = seq_len % 8;\n                // Fill full blocks with 1s\n                for (int i = 0; i < full_blocks; ++i) {\n                    thread_local_attn_mask_[thread_id][i] = 0xFF;\n                }\n                // Fill the remaining bits in the next block\n                if (remaining_bits > 0 && full_blocks < seq_len_ / 8) {\n                    thread_local_attn_mask_[thread_id][full_blocks] =\n                        (1 << remaining_bits) - 1;\n                } else {\n                    thread_local_attn_mask_[thread_id][full_blocks] = 0;\n                }\n\n                for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) {\n                    thread_local_attn_mask_[thread_id][i] = 0;\n                }\n                if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,\n                        (void *)&q_in_data[batch_id * config_.kv_head_num *\n                                               n_gqa_ * config_.head_dim +\n          
                                 head_id * n_gqa_ * config_.head_dim],\n                        seq_len_, 0, false,\n                        thread_local_attn_mask_[thread_id].data(),\n                        GGML_TYPE_F16, 0,\n                        k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_F16, 1,\n                        v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, false,\n                        thread_local_attn_mask_[thread_id].data(),\n                        GGML_TYPE_Q4_0, 0,\n                        k_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_Q4_0, 1,\n                        v_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n 
                       thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, false,\n                        thread_local_attn_mask_[thread_id].data(),\n                        GGML_TYPE_Q8_0, 0,\n                        k_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_Q8_0, 1,\n                        v_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                }\n            } else {\n                if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,\n                        (void *)&q_in_data[batch_id * config_.kv_head_num *\n                                               n_gqa_ * config_.head_dim +\n                                           head_id * n_gqa_ * config_.head_dim],\n                        seq_len_, 0, true, nullptr, 
GGML_TYPE_F16, 0,\n                        k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_F16, 1,\n                        v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0, 0,\n                        k_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_Q4_0, 1,\n                        v_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n                    attn_with_kvcache_one_block_(\n                        
config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0, 0,\n                        k_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_Q8_0, 1,\n                        v_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                }\n            }\n            for (int i = 0; i < n_gqa_; i++) {\n                block_lse_[batch_id][block_idx][head_id * n_gqa_ + i] =\n                    thread_local_attn_lse_[thread_id][i];\n            }\n            int cur_batch_idx = thread_cur_head_idx_[thread_id].first;\n            int cur_head_id = thread_cur_head_idx_[thread_id].second;\n            if (batch_id == cur_batch_idx && head_id == cur_head_id) {\n                for (int i = 0; i < n_gqa_; i++) {\n                    float new_attn_lse =\n                        thread_local_cur_attn_lse_[thread_id][i] +\n                        std::log(\n                            1.0 +\n                            std::exp(thread_local_attn_lse_[thread_id][i] -\n                                     thread_local_cur_attn_lse_[thread_id][i]));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                    
    thread_local_cur_output_fp32_[thread_id].data() +\n                            i * config_.head_dim,\n                        std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                 new_attn_lse));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                        thread_local_output_fp32_[thread_id].data() +\n                            i * config_.head_dim,\n                        std::exp(thread_local_attn_lse_[thread_id][i] -\n                                 new_attn_lse));\n                    for (int j = 0; j < config_.head_dim; j++) {\n                        thread_local_cur_output_fp32_[thread_id]\n                                                     [i * config_.head_dim +\n                                                      j] +=\n                            thread_local_output_fp32_[thread_id]\n                                                     [i * config_.head_dim + j];\n                    }\n                    thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse;\n                }\n            } else {\n                if (cur_batch_idx != -1) {\n                    mutex_[cur_batch_idx][cur_head_id]->lock();\n                    for (int i = 0; i < n_gqa_; i++) {\n                        if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <\n                            1e-6) {\n                            attn_lse_[cur_batch_idx][cur_head_id][i] =\n                                thread_local_cur_attn_lse_[thread_id][i];\n                            for (int j = 0; j < config_.head_dim; j++) {\n                                output_fp32_[cur_batch_idx][cur_head_id]\n                                            [i * config_.head_dim + j] =\n                                                thread_local_cur_output_fp32_\n                                                    [thread_id]\n                                                    [i * config_.head_dim + 
j];\n                            }\n                            continue;\n                        }\n                        float new_attn_lse =\n                            attn_lse_[cur_batch_idx][cur_head_id][i] +\n                            std::log(\n                                1.0 +\n                                std::exp(\n                                    thread_local_cur_attn_lse_[thread_id][i] -\n                                    attn_lse_[cur_batch_idx][cur_head_id][i]));\n                        ggml_vec_scale_f32(\n                            config_.head_dim,\n                            output_fp32_[cur_batch_idx][cur_head_id].data() +\n                                i * config_.head_dim,\n                            std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -\n                                     new_attn_lse));\n                        ggml_vec_scale_f32(\n                            config_.head_dim,\n                            thread_local_cur_output_fp32_[thread_id].data() +\n                                i * config_.head_dim,\n                            std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                     new_attn_lse));\n                        for (int j = 0; j < config_.head_dim; j++) {\n                            output_fp32_[cur_batch_idx][cur_head_id]\n                                        [i * config_.head_dim + j] +=\n                                thread_local_cur_output_fp32_\n                                    [thread_id][i * config_.head_dim + j];\n                        }\n                        attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;\n                    }\n                    mutex_[cur_batch_idx][cur_head_id]->unlock();\n                }\n                thread_cur_head_idx_[thread_id].first = batch_id;\n                thread_cur_head_idx_[thread_id].second = head_id;\n                for (int i = 0; i < n_gqa_; i++) {\n                    
thread_local_cur_attn_lse_[thread_id][i] =\n                        thread_local_attn_lse_[thread_id][i];\n                    for (int j = 0; j < config_.head_dim; j++) {\n                        thread_local_cur_output_fp32_\n                            [thread_id][i * config_.head_dim + j] =\n                                thread_local_output_fp32_[thread_id]\n                                                         [i * config_.head_dim +\n                                                          j];\n                    }\n                }\n            }\n        },\n        // Merge the results of the remaining blocks.\n        [&](int thread_id) {\n            int cur_batch_idx = thread_cur_head_idx_[thread_id].first;\n            int cur_head_id = thread_cur_head_idx_[thread_id].second;\n            if (cur_head_id != -1) {\n                mutex_[cur_batch_idx][cur_head_id]->lock();\n                for (int i = 0; i < n_gqa_; i++) {\n                    float new_attn_lse;\n                    if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <\n                        1e-6) {\n                        attn_lse_[cur_batch_idx][cur_head_id][i] =\n                            thread_local_cur_attn_lse_[thread_id][i];\n                        for (int j = 0; j < config_.head_dim; j++) {\n                            output_fp32_[cur_batch_idx][cur_head_id]\n                                        [i * config_.head_dim + j] =\n                                            thread_local_cur_output_fp32_\n                                                [thread_id]\n                                                [i * config_.head_dim + j];\n                        }\n                        continue;\n                    }\n                    new_attn_lse =\n                        attn_lse_[cur_batch_idx][cur_head_id][i] +\n                        std::log(\n                            1.0 +\n                            
std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                     attn_lse_[cur_batch_idx][cur_head_id][i]));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                        output_fp32_[cur_batch_idx][cur_head_id].data() +\n                            i * config_.head_dim,\n                        std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -\n                                 new_attn_lse));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                        thread_local_cur_output_fp32_[thread_id].data() +\n                            i * config_.head_dim,\n                        std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                 new_attn_lse));\n                    for (int j = 0; j < config_.head_dim; j++) {\n                        output_fp32_[cur_batch_idx][cur_head_id]\n                                    [i * config_.head_dim + j] +=\n                            thread_local_cur_output_fp32_[thread_id]\n                                                         [i * config_.head_dim +\n                                                          j];\n                    }\n                    attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;\n                }\n                mutex_[cur_batch_idx][cur_head_id]->unlock();\n            }\n        });\n\n    for (int i = 0; i < batch_size; i++) {\n        for (int j = 0; j < max_block_num_after_retrieval_; j++) {\n            int block_idx = block_table_after_retrieval_[i][j];\n            for (int k = 0; k < config_.q_head_num; k++) {\n                attn_sparsity[i * config_.q_head_num + k] +=\n                    std::exp(block_lse_[i][block_idx][k] -\n                             attn_lse_[i][k / n_gqa_][k % n_gqa_]);\n            }\n        }\n    }\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    
std::chrono::duration<double> diff = end - start;\n    // printf(\"layer %d time of calculating sparsity: %f s\\n\", layer_id_,\n    //        diff.count());\n}\n\nvoid KVCache::attn_initialize_kvhead_(int batch_size, int layer_idx,\n                                      int *block_table, int &max_block_num,\n                                      int *cache_seqlens) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n    for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n        // initialize output_fp32_ and attn_lse_\n        for (int i = 0; i < config_.kv_head_num; i++) {\n            for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {\n                output_fp32_[batch_idx][i][j] = 0;\n            }\n            for (int j = 0; j < n_gqa_; j++) {\n                attn_lse_[batch_idx][i][j] = 0;\n            }\n        }\n\n        // clear top_similar_block_\n        while (!top_similar_block_[batch_idx].empty())\n            top_similar_block_[batch_idx].pop();\n    }\n\n    for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n        cache_seqlens_[batch_idx] = cache_seqlens[batch_idx];\n        for (int i = 0; i < max_block_num; i++) {\n            for (int j = 0; j < config_.kv_head_num; j++) {\n                block_table_before_retrieval_kvhead_[batch_idx][i][j] =\n                    block_table[batch_idx * max_block_num + i];\n                block_similar_kv_head_[batch_idx][i][j] = 0;\n            }\n        }\n    }\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    // printf(\"layer %d time of initializing attn: %f s\\n\", layer_idx,\n    //        std::chrono::duration<double>(end - start).count());\n}\nvoid KVCache::retrieval_kvcache_kvhead_(const uint16_t *q_in_data,\n                                        int init_block_num, int local_block_num,\n                                        int pick_block_num, int q_len,\n                                      
  int generate_token_idx, int batch_size,\n                                        int layer_idx, int *cache_seqlens,\n                                        int &max_block_num, Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n    max_block_num_after_retrieval_ = 0;\n    if (pick_block_num != -1 &&\n        (generate_token_idx % config_.token_step != 0 ||\n         (layer_idx % config_.layer_step != config_.layer_offset))) {\n\n        if (selected_blocks_num_history_[(layer_idx - config_.layer_offset) /\n                                         config_.layer_step] == 0) {\n            max_block_num_after_retrieval_ = max_block_num;\n            for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n                for (int i = 0; i < max_block_num; i++) {\n                    for (int j = 0; j < config_.kv_head_num; j++) {\n                        block_table_after_retrieval_kvhead_[batch_idx][i][j] =\n                            block_table_before_retrieval_kvhead_[batch_idx][i]\n                                                                [j];\n                    }\n                }\n            }\n        } else {\n\n            max_block_num_after_retrieval_ = selected_blocks_num_history_\n                [(layer_idx - config_.layer_offset) / config_.layer_step];\n\n            for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n                for (int i = 0; i < max_block_num_after_retrieval_; i++) {\n                    for (int j = 0; j < config_.kv_head_num; j++) {\n                        block_table_after_retrieval_kvhead_[batch_idx][i][j] =\n                            selected_blocks_history_kvhead_\n                                [(layer_idx - config_.layer_offset) /\n                                 config_.layer_step][batch_idx][i][j];\n                    }\n                }\n\n                if (cache_seqlens[batch_idx] % config_.block_len == 1) {\n                 
   selected_blocks_num_history_[(layer_idx -\n                                                  config_.layer_offset) /\n                                                 config_.layer_step] += 1;\n                    int x =\n                        selected_blocks_num_history_[(layer_idx -\n                                                      config_.layer_offset) /\n                                                     config_.layer_step];\n                    for (int i = 0; i < config_.kv_head_num; i++) {\n                        int last_block_idx =\n                            block_table_before_retrieval_kvhead_\n                                [batch_idx][cache_seqlens[batch_idx] /\n                                            config_.block_len][i];\n                        selected_blocks_history_kvhead_[(layer_idx -\n                                                         config_.layer_offset) /\n                                                        config_.layer_step]\n                                                       [batch_idx][x - 1][i] =\n                                                           last_block_idx;\n                        block_table_after_retrieval_kvhead_[batch_idx][x - 1]\n                                                           [i] = last_block_idx;\n                    }\n                }\n                cache_seqlens_[batch_idx] = std::min(\n                    cache_seqlens_[batch_idx],\n                    (cache_seqlens_[batch_idx] % config_.block_len) +\n                        (init_block_num + pick_block_num + local_block_num) *\n                            config_.block_len);\n            }\n        }\n    } else if (pick_block_num != -1) {\n        max_block_num_after_retrieval_ =\n            std::min(max_block_num,\n                     init_block_num + pick_block_num + local_block_num + 1);\n        calculate_block_similarity_kvhead_(q_in_data, batch_size, layer_idx,\n                                         
  q_len, max_block_num, cache_seqlens,\n                                           init_block_num, local_block_num,\n                                           pick_block_num, backend);\n        select_block_kvhead_(batch_size, layer_idx, max_block_num,\n                             init_block_num, local_block_num, pick_block_num);\n    } else {\n        selected_blocks_num_history_[(layer_idx - config_.layer_offset) /\n                                     config_.layer_step] = 0;\n        max_block_num_after_retrieval_ = max_block_num;\n        for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n            for (int i = 0; i < max_block_num; i++) {\n                for (int j = 0; j < config_.kv_head_num; j++) {\n                    block_table_after_retrieval_kvhead_[batch_idx][i][j] =\n                        block_table_before_retrieval_kvhead_[batch_idx][i][j];\n                }\n            }\n        }\n    }\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    // printf(\"layer %d time of retrieval kvcache: %f s\\n\", layer_idx,\n    //        std::chrono::duration<double>(end - start).count());\n}\nvoid KVCache::calculate_sparsity_kvhead_(const uint16_t *q_in_data,\n                                         float *attn_sparsity, int batch_size,\n                                         int max_block_num, int *block_table,\n                                         int *cache_seqlens, Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n    seq_len_ = config_.block_len;\n    backend->do_work_stealing_job(\n        batch_size * config_.kv_head_num * max_block_num,\n        [&](int thread_id) {\n            thread_cur_head_idx_[thread_id].first = -1;\n            thread_cur_head_idx_[thread_id].second = -1;\n        },\n        [&](int task_id) {\n            int batch_id = task_id / (config_.kv_head_num * max_block_num);\n            int head_id = (task_id % 
(config_.kv_head_num * max_block_num)) /\n                          max_block_num;\n            int block_id = task_id % max_block_num;\n            int thread_id = Backend::thread_local_id;\n            // If the block is out of the sequence length, skip it.\n            if (cache_seqlens[batch_id] / config_.block_len < block_id) {\n                return;\n            }\n            int block_idx = block_table[batch_id * max_block_num + block_id];\n            if (cache_seqlens_[batch_id] / config_.block_len == block_id) {\n                int seq_len = cache_seqlens_[batch_id] % config_.block_len;\n                if (seq_len == 0)\n                    return;\n\n                // Prepare the attention mask for the last block.\n                int full_blocks = seq_len / 8;\n                int remaining_bits = seq_len % 8;\n\n                // Fill full blocks with 1s\n                for (int i = 0; i < full_blocks; ++i) {\n                    thread_local_attn_mask_[thread_id][i] = 0xFF;\n                }\n                // Fill the remaining bits in the next block\n                if (remaining_bits > 0 && full_blocks < seq_len_ / 8) {\n                    thread_local_attn_mask_[thread_id][full_blocks] =\n                        (1 << remaining_bits) - 1;\n                } else {\n                    thread_local_attn_mask_[thread_id][full_blocks] = 0;\n                }\n\n                for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) {\n                    thread_local_attn_mask_[thread_id][i] = 0;\n                }\n                if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,\n                        (void *)&q_in_data[batch_id * config_.kv_head_num *\n                                               n_gqa_ * config_.head_dim +\n                                      
     head_id * n_gqa_ * config_.head_dim],\n                        seq_len_, 0, false,\n                        thread_local_attn_mask_[thread_id].data(),\n                        GGML_TYPE_F16, 0,\n                        k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_F16, 1,\n                        v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, false,\n                        thread_local_attn_mask_[thread_id].data(),\n                        GGML_TYPE_Q4_0, 0,\n                        k_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_Q4_0, 1,\n                        v_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n                        
thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, false,\n                        thread_local_attn_mask_[thread_id].data(),\n                        GGML_TYPE_Q8_0, 0,\n                        k_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_Q8_0, 1,\n                        v_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                }\n            } else {\n                if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,\n                        (void *)&q_in_data[batch_id * config_.kv_head_num *\n                                               n_gqa_ * config_.head_dim +\n                                           head_id * n_gqa_ * config_.head_dim],\n                        seq_len_, 0, true, nullptr, GGML_TYPE_F16, 0,\n             
           k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_F16, 1,\n                        v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0, 0,\n                        k_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_Q4_0, 1,\n                        v_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        
config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0, 0,\n                        k_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_Q8_0, 1,\n                        v_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                }\n            }\n            for (int i = 0; i < n_gqa_; i++) {\n                block_lse_[batch_id][block_idx][head_id * n_gqa_ + i] =\n                    thread_local_attn_lse_[thread_id][i];\n            }\n            int cur_batch_idx = thread_cur_head_idx_[thread_id].first;\n            int cur_head_id = thread_cur_head_idx_[thread_id].second;\n            if (batch_id == cur_batch_idx && head_id == cur_head_id) {\n                for (int i = 0; i < n_gqa_; i++) {\n                    float new_attn_lse =\n                        thread_local_cur_attn_lse_[thread_id][i] +\n                        std::log(\n                            1.0 +\n                            std::exp(thread_local_attn_lse_[thread_id][i] -\n                                     thread_local_cur_attn_lse_[thread_id][i]));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                        
thread_local_cur_output_fp32_[thread_id].data() +\n                            i * config_.head_dim,\n                        std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                 new_attn_lse));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                        thread_local_output_fp32_[thread_id].data() +\n                            i * config_.head_dim,\n                        std::exp(thread_local_attn_lse_[thread_id][i] -\n                                 new_attn_lse));\n                    for (int j = 0; j < config_.head_dim; j++) {\n                        thread_local_cur_output_fp32_[thread_id]\n                                                     [i * config_.head_dim +\n                                                      j] +=\n                            thread_local_output_fp32_[thread_id]\n                                                     [i * config_.head_dim + j];\n                    }\n                    thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse;\n                }\n            } else {\n                if (cur_batch_idx != -1) {\n                    mutex_[cur_batch_idx][cur_head_id]->lock();\n                    for (int i = 0; i < n_gqa_; i++) {\n                        if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <\n                            1e-6) {\n                            attn_lse_[cur_batch_idx][cur_head_id][i] =\n                                thread_local_cur_attn_lse_[thread_id][i];\n                            for (int j = 0; j < config_.head_dim; j++) {\n                                output_fp32_[cur_batch_idx][cur_head_id]\n                                            [i * config_.head_dim + j] =\n                                                thread_local_cur_output_fp32_\n                                                    [thread_id]\n                                                    [i * config_.head_dim + j];\n   
                         }\n                            continue;\n                        }\n                        float new_attn_lse =\n                            attn_lse_[cur_batch_idx][cur_head_id][i] +\n                            std::log(\n                                1.0 +\n                                std::exp(\n                                    thread_local_cur_attn_lse_[thread_id][i] -\n                                    attn_lse_[cur_batch_idx][cur_head_id][i]));\n                        ggml_vec_scale_f32(\n                            config_.head_dim,\n                            output_fp32_[cur_batch_idx][cur_head_id].data() +\n                                i * config_.head_dim,\n                            std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -\n                                     new_attn_lse));\n                        ggml_vec_scale_f32(\n                            config_.head_dim,\n                            thread_local_cur_output_fp32_[thread_id].data() +\n                                i * config_.head_dim,\n                            std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                     new_attn_lse));\n                        for (int j = 0; j < config_.head_dim; j++) {\n                            output_fp32_[cur_batch_idx][cur_head_id]\n                                        [i * config_.head_dim + j] +=\n                                thread_local_cur_output_fp32_\n                                    [thread_id][i * config_.head_dim + j];\n                        }\n                        attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;\n                    }\n                    mutex_[cur_batch_idx][cur_head_id]->unlock();\n                }\n                thread_cur_head_idx_[thread_id].first = batch_id;\n                thread_cur_head_idx_[thread_id].second = head_id;\n                for (int i = 0; i < n_gqa_; i++) {\n                    
thread_local_cur_attn_lse_[thread_id][i] =\n                        thread_local_attn_lse_[thread_id][i];\n                    for (int j = 0; j < config_.head_dim; j++) {\n                        thread_local_cur_output_fp32_\n                            [thread_id][i * config_.head_dim + j] =\n                                thread_local_output_fp32_[thread_id]\n                                                         [i * config_.head_dim +\n                                                          j];\n                    }\n                }\n            }\n        },\n        // Merge the results of the remaining blocks.\n        [&](int thread_id) {\n            int cur_batch_idx = thread_cur_head_idx_[thread_id].first;\n            int cur_head_id = thread_cur_head_idx_[thread_id].second;\n            if (cur_head_id != -1) {\n                mutex_[cur_batch_idx][cur_head_id]->lock();\n                for (int i = 0; i < n_gqa_; i++) {\n                    float new_attn_lse;\n                    if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <\n                        1e-6) {\n                        attn_lse_[cur_batch_idx][cur_head_id][i] =\n                            thread_local_cur_attn_lse_[thread_id][i];\n                        for (int j = 0; j < config_.head_dim; j++) {\n                            output_fp32_[cur_batch_idx][cur_head_id]\n                                        [i * config_.head_dim + j] =\n                                            thread_local_cur_output_fp32_\n                                                [thread_id]\n                                                [i * config_.head_dim + j];\n                        }\n                        continue;\n                    }\n                    new_attn_lse =\n                        attn_lse_[cur_batch_idx][cur_head_id][i] +\n                        std::log(\n                            1.0 +\n                            
std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                     attn_lse_[cur_batch_idx][cur_head_id][i]));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                        output_fp32_[cur_batch_idx][cur_head_id].data() +\n                            i * config_.head_dim,\n                        std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -\n                                 new_attn_lse));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                        thread_local_cur_output_fp32_[thread_id].data() +\n                            i * config_.head_dim,\n                        std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                 new_attn_lse));\n                    for (int j = 0; j < config_.head_dim; j++) {\n                        output_fp32_[cur_batch_idx][cur_head_id]\n                                    [i * config_.head_dim + j] +=\n                            thread_local_cur_output_fp32_[thread_id]\n                                                         [i * config_.head_dim +\n                                                          j];\n                    }\n                    attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;\n                }\n                mutex_[cur_batch_idx][cur_head_id]->unlock();\n            }\n        });\n\n    for (int i = 0; i < batch_size; i++) {\n        for (int j = 0; j < max_block_num_after_retrieval_; j++) {\n            for (int k = 0; k < config_.q_head_num; k++) {\n                int block_idx =\n                    block_table_after_retrieval_kvhead_[i][j][k / n_gqa_];\n                attn_sparsity[i * config_.q_head_num + k] +=\n                    std::exp(block_lse_[i][block_idx][k] -\n                             attn_lse_[i][k / n_gqa_][k % n_gqa_]);\n            }\n        }\n    }\n\n    // Timer end\n    auto end = 
std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> diff = end - start;\n    // printf(\"layer %d time of calculating sparsity: %f s\\n\", layer_id_,\n    //        diff.count());\n}\nvoid KVCache::calculate_block_similarity_kvhead_(\n    const uint16_t *q_in_data, int batch_size, int layer_idx, int q_len,\n    int max_block_num, int *cache_seqlens, int init_block_num,\n    int local_block_num, int pick_block_num, Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n    backend->do_work_stealing_job(\n        batch_size * max_block_num, nullptr,\n        [&](int task_id) {\n            int batch_id = task_id / max_block_num;\n            int block_id = task_id % max_block_num;\n            int seq_len = cache_seqlens_[batch_id];\n\n            if (block_id < init_block_num ||\n                block_id >= (seq_len / config_.block_len) - local_block_num) {\n                return;\n            }\n            int block_idx =\n                block_table_before_retrieval_kvhead_[batch_id][block_id][0];\n\n            for (int head_id = 0; head_id < config_.q_head_num; head_id++) {\n                for (int i = 0; i < config_.head_dim; i++) {\n                    float q_i = 0, qa_i = std::numeric_limits<float>::lowest();\n                    for (int q_id = 0; q_id < q_len; q_id++) {\n                        q_i += GGML_FP16_TO_FP32(\n                            q_in_data[batch_id * q_len * config_.q_head_num *\n                                          config_.head_dim +\n                                      q_id * config_.q_head_num *\n                                          config_.head_dim +\n                                      head_id * config_.head_dim + i]);\n                    }\n                    q_i /= q_len;\n                    for (int anchor_id = 0; anchor_id < config_.anchor_num;\n                         anchor_id++) {\n                        qa_i = std::max(\n           
                 qa_i,\n                            GGML_FP16_TO_FP32(\n                                anchor_[layer_idx * config_.max_block_num *\n                                            config_.anchor_num *\n                                            config_.q_head_num *\n                                            config_.head_dim +\n                                        block_idx * config_.anchor_num *\n                                            config_.q_head_num *\n                                            config_.head_dim +\n                                        anchor_id * config_.q_head_num *\n                                            config_.head_dim +\n                                        head_id * config_.head_dim + i]) *\n                                q_i);\n                    }\n                    block_similar_kv_head_[batch_id][block_id]\n                                          [head_id / n_gqa_] += qa_i;\n                }\n            }\n        },\n        nullptr);\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> diff = end - start;\n    // printf(\"layer %d time of calculating similarity: %f s\\n\", layer_idx,\n    //        diff.count());\n}\nvoid KVCache::select_block_kvhead_(int batch_size, int layer_idx,\n                                   int max_block_num, int init_block_num,\n                                   int local_block_num, int pick_block_num) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n        int cache_len_after_retrieval = 0;\n        if (cache_seqlens_[batch_idx] / config_.block_len <=\n            init_block_num + pick_block_num + local_block_num) {\n            selected_blocks_num_history_[(layer_idx - config_.layer_offset) /\n                                         config_.layer_step] = 0;\n            for (int i = 0; i < 
max_block_num; i++) {\n                for (int j = 0; j < config_.kv_head_num; j++) {\n                    block_table_after_retrieval_kvhead_[batch_idx][i][j] =\n                        block_table_before_retrieval_kvhead_[batch_idx][i][j];\n                }\n            }\n            continue;\n        }\n        for (int head_id = 0; head_id < config_.kv_head_num; head_id++) {\n\n            for (int block_id = init_block_num;\n                 block_id < (cache_seqlens_[batch_idx] / config_.block_len) -\n                                local_block_num;\n                 block_id++) {\n\n                top_similar_block_[batch_idx].push(std::make_pair(\n                    block_similar_kv_head_[batch_idx][block_id][head_id],\n                    block_table_before_retrieval_kvhead_[batch_idx][block_id]\n                                                        [head_id]));\n                if (top_similar_block_[batch_idx].size() > pick_block_num) {\n                    top_similar_block_[batch_idx].pop();\n                }\n            }\n\n            int i = 0;\n            for (; i < init_block_num; i++) {\n                block_table_after_retrieval_kvhead_[batch_idx][i][head_id] =\n                    block_table_before_retrieval_kvhead_[batch_idx][i][head_id];\n            }\n            while (!top_similar_block_[batch_idx].empty()) {\n                block_table_after_retrieval_kvhead_[batch_idx][i][head_id] =\n                    top_similar_block_[batch_idx].top().second;\n                top_similar_block_[batch_idx].pop();\n                i++;\n            }\n            for (; i < init_block_num + pick_block_num + local_block_num; i++) {\n                block_table_after_retrieval_kvhead_[batch_idx][i][head_id] =\n                    block_table_before_retrieval_kvhead_\n                        [batch_idx]\n                        [(cache_seqlens_[batch_idx] / config_.block_len) -\n                         local_block_num + i - init_block_num 
- pick_block_num]\n                        [head_id];\n            }\n            if (cache_seqlens_[batch_idx] % config_.block_len != 0) {\n                block_table_after_retrieval_kvhead_[batch_idx][i][head_id] =\n                    block_table_before_retrieval_kvhead_[batch_idx][(\n                        cache_seqlens_[batch_idx] / config_.block_len)]\n                                                        [head_id];\n                cache_len_after_retrieval =\n                    (cache_seqlens_[batch_idx] % config_.block_len) +\n                    i * config_.block_len;\n                i++;\n            } else {\n                cache_len_after_retrieval =\n                    (cache_seqlens_[batch_idx] % config_.block_len) +\n                    i * config_.block_len;\n            }\n            for (int j = 0; j < i; j++) {\n                selected_blocks_history_kvhead_\n                    [(layer_idx - config_.layer_offset) / config_.layer_step]\n                    [batch_idx][j][head_id] =\n                        block_table_after_retrieval_kvhead_[batch_idx][j]\n                                                           [head_id];\n            }\n        }\n        cache_seqlens_[batch_idx] = cache_len_after_retrieval;\n        selected_blocks_num_history_[(layer_idx - config_.layer_offset) /\n                                     config_.layer_step] =\n            (cache_len_after_retrieval + config_.block_len - 1) /\n            config_.block_len;\n    }\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> diff = end - start;\n    // printf(\"layer %d time of selecting block: %f s\\n\", layer_idx,\n    //        diff.count())\n}\n\nvoid KVCache::get_attn_sparsity(const ggml_fp16_t *q_in, float *attn_sparsity,\n                                int layer_idx, int generate_token_idx,\n                                int q_len, int batch_size, int max_block_num,\n                       
         int *block_table, int *cache_seqlens,\n                                int *block_table_origin,\n                                int *cache_seqlens_origin,\n                                int max_block_num_origin, int topk, int local,\n                                Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n    layer_id_ = layer_idx;\n    int thread_num = backend->get_thread_num();\n    batch_size = 1;\n\n    const uint16_t *q_in_data = const_cast<const uint16_t *>(q_in);\n\n    quantize_q_(q_in_data, batch_size);\n    if (config_.retrieval_type == RetrievalType::LAYER) {\n        attn_initialize_layer_(batch_size, layer_idx, block_table,\n                               max_block_num, cache_seqlens);\n        retrieval_kvcache_layer_(q_in_data, 1, local, topk, q_len,\n                                 generate_token_idx, batch_size, layer_idx,\n                                 cache_seqlens, max_block_num, backend);\n        calculate_sparsity_layer_(q_in_data, attn_sparsity, batch_size,\n                                  max_block_num_origin, block_table_origin,\n                                  cache_seqlens_origin, backend);\n    } else if (config_.retrieval_type == RetrievalType::KVHEAD) {\n        attn_initialize_kvhead_(batch_size, layer_idx, block_table,\n                                max_block_num, cache_seqlens);\n        retrieval_kvcache_kvhead_(q_in_data, 1, local, topk, q_len,\n                                  generate_token_idx, batch_size, layer_idx,\n                                  cache_seqlens, max_block_num, backend);\n        calculate_sparsity_kvhead_(q_in_data, attn_sparsity, batch_size,\n                                   max_block_num_origin, block_table_origin,\n                                   cache_seqlens_origin, backend);\n    }\n}\n\nvoid KVCache::attn_with_kvcache_one_block_(\n    int head_dim, int bsz,\n    ggml_type q_type, // GGML data type of `Q`, only 
supports fp16 and q8_0\n    // [bsz, head_dim]\n    // Quantization is always on the head_dim dimension (per_token). If\n    // head_dim % 32 != 0, an error will be raised. The size must be bsz *\n    // head_dim/32 * qtype_size.\n    const void *q,\n\n    int past_kv_len, int past_kv_offset,\n    bool is_full_attn, // true indicates a full 1 mask\n    // If is_full_attn = false, a bit matrix representing the mask is\n    // passed. [bsz, past_kv_len]\n    const uint8_t *attn_mask,\n\n    ggml_type k_type, // GGML data type of `K Cache`, only supports fp16,\n                      // q4_0, q8_0\n    int k_quant_type, // 0 for per_token, 1 for per_channel, others raise an\n                      // error\n    // [seq_len, head_dim]\n    // If quant_type == 0, head_dim % 32 must be 0.\n    // If quant_type == 1, seq_len % 32 must be 0.\n    const void *k_cache,\n\n    // k_anchor_type must be fp16\n    int num_k_anchor, // num_k_anchor == 0 indicates no anchor\n    // [num_k_anchor, head_dim]\n    const void *k_cache_anchors,\n    // Each token is associated with the nearest previous position's anchor,\n    // with the same distance.\n    const int *k_cache_anchor_pos,\n\n    // v_cache similar to k_cache\n    ggml_type v_type, int v_quant_type,\n    // [head_dim, seq_len]\n    const void *v_cache, int num_v_anchor, const void *v_cache_anchors,\n    const int *v_cache_anchor_pos,\n\n    // Pre-allocated buffer for intermediate calculations [bsz,\n    // past_kv_len]. 
No malloc is performed inside this function.\n    float *attn_score,\n\n    // Output: [bsz, head_dim]; written as fp32 when q_type is fp16 and as\n    // q8_0 when q_type is q8_0\n    void *output,\n    // [bsz]\n    float *lse,\n\n    // Pre-allocated temporary buffer with sufficient size:\n    // (2 * bsz * past_kv_len + 6 * bsz * head_dim + 2 * past_kv_len *\n    // head_dim + past_kv_len * head_dim / 32) bytes.\n    void *draft,\n\n    // Apply rotary embedding online\n    const int *rotary_angle, const void *rotary_cos, const void *rotary_sin\n    // rotary_cos=None,\n    // rotary_sin=None,\n    // cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,\n    // cache_batch_idx: Optional[torch.Tensor] = None,\n    // rotary_interleaved=True,\n\n    // // Not supported for now\n    // window_size=(-1, -1),  # -1 means infinite context window\n    // alibi_slopes=None,\n) {\n    assert(head_dim % 32 == 0);\n    assert(k_quant_type == 0);\n    assert(v_quant_type == 1);\n    assert(q_type == GGML_TYPE_F16 || q_type == GGML_TYPE_Q8_0);\n    if (q_type == GGML_TYPE_F16) {\n        assert(k_type == GGML_TYPE_F16);\n        assert(v_type == GGML_TYPE_F16);\n\n        // attn = q * k + q * k_anchor\n        // TODO: anchor\n        assert(num_k_anchor == 0);\n\n        if (rotary_angle != nullptr) {\n            // Offsets are in bytes, so do the arithmetic on a char pointer.\n            ggml_fp16_t *k_cache_with_rope_fp16 =\n                reinterpret_cast<ggml_fp16_t *>(\n                    reinterpret_cast<char *>(draft) +\n                    sizeof(block_q8_0) * bsz * past_kv_len / QK8_0 +\n                    sizeof(float) * bsz * head_dim);\n            // dequant k_cache and apply rope\n            // k_rope(i) = k(i) * cos(i) - k(i+l) * sin(i)\n            // k_rope(i+l) = k(i+l) * cos(i+l) + k(i) * sin(i)\n\n            // k(i)cos(i) -> k_rope(i)\n            // k(i)sin(i+l) -> k_rope(i+l)\n\n            // k(i)cos(i) -> k_rope(i)\n            // -k(i)sin(i-l) -> k_rope(i-l)\n\n            std::vector<float> block_fp32(32);\n            for (int k = 0; k < past_kv_len; k++) {\n                int angle = rotary_angle[k];\n  
              for (int l = 0; l < head_dim / 32; l++) {\n                    for (int m = 0; m < 32; m++) {\n                        float x = GGML_FP16_TO_FP32((\n                            (ggml_fp16_t *)k_cache)[k * head_dim + l * 32 + m]);\n                        float sin_val = GGML_FP16_TO_FP32(\n                            ((ggml_fp16_t *)\n                                 rotary_sin)[angle * head_dim + l * 32 + m]);\n                        float cos_val = GGML_FP16_TO_FP32(\n                            ((ggml_fp16_t *)\n                                 rotary_cos)[angle * head_dim + l * 32 + m]);\n\n                        if (l * 32 + m < head_dim / 2) {\n                            k_cache_with_rope_fp16[k * head_dim + l * 32 + m] =\n                                GGML_FP32_TO_FP16(x * cos_val);\n                            k_cache_with_rope_fp16[k * head_dim + l * 32 + m +\n                                                   head_dim / 2] =\n                                GGML_FP32_TO_FP16(-x * sin_val);\n                        } else {\n                            k_cache_with_rope_fp16[k * head_dim + l * 32 + m] =\n                                GGML_FP32_TO_FP16(\n                                    GGML_FP16_TO_FP32(\n                                        k_cache_with_rope_fp16[k * head_dim +\n                                                               l * 32 + m]) +\n                                    x * sin_val);\n                            k_cache_with_rope_fp16[k * head_dim + l * 32 + m -\n                                                   head_dim / 2] =\n                                GGML_FP32_TO_FP16(\n                                    GGML_FP16_TO_FP32(\n                                        k_cache_with_rope_fp16[k * head_dim +\n                                                               l * 32 + m -\n                                                               head_dim / 2]) -\n                                    x 
* cos_val);\n                        }\n                    }\n                }\n            }\n\n            llamafile_sgemm(past_kv_len, bsz, head_dim,\n                            (ggml_fp16_t *)k_cache_with_rope_fp16, head_dim,\n                            (ggml_fp16_t *)q, head_dim, attn_score, past_kv_len,\n                            0, 1, GGML_TASK_TYPE_COMPUTE, k_type, GGML_TYPE_F16,\n                            GGML_TYPE_F32, GGML_PREC_DEFAULT);\n        } else {\n            bool ok = llamafile_sgemm(\n                past_kv_len, bsz, head_dim, (ggml_fp16_t *)k_cache, head_dim,\n                (ggml_fp16_t *)q, head_dim, attn_score, past_kv_len, 0, 1,\n                GGML_TASK_TYPE_COMPUTE, k_type, GGML_TYPE_F16, GGML_TYPE_F32,\n                GGML_PREC_DEFAULT);\n\n            if (!ok) {\n                printf(\"llamafile_sgemm failed\\n\");\n            }\n        }\n        // attn = attn * scale\n        float scale_factor = 1.0 / std::sqrt(float(head_dim));\n        ggml_vec_scale_f32(bsz * past_kv_len, attn_score, scale_factor);\n\n        // attn = attn & mask\n        if (!is_full_attn) {\n            for (int i = 0; i < bsz; i++) {\n                for (int j = 0; j < past_kv_len; j++) {\n                    int index = i * past_kv_len + j;\n                    if (!(attn_mask[j / 8] & (1 << (j % 8)))) {\n                        attn_score[index] =\n                            std::numeric_limits<float>::lowest();\n                    }\n                }\n            }\n        }\n\n        // attn = softmax(attn)\n        for (int i = 0; i < bsz; i++) {\n            float sum_exp = 0;\n            for (int j = 0; j < past_kv_len; j++) {\n                attn_score[i * past_kv_len + j] =\n                    std::exp(attn_score[i * past_kv_len + j]);\n                sum_exp += attn_score[i * past_kv_len + j];\n            }\n            for (int j = 0; j < past_kv_len; j++) {\n                attn_score[i * past_kv_len + j] /= sum_exp;\n 
           }\n            if (lse != nullptr) {\n                lse[i] = std::log(sum_exp);\n            }\n        }\n\n        // output = attn * v + attn * v_anchor\n        // std::vector<float> sum(bsz * head_dim);\n        float *sum = reinterpret_cast<float *>(reinterpret_cast<char *>(draft) +\n                                               sizeof(block_q8_0) * bsz *\n                                                   past_kv_len / QK8_0);\n\n        // float* attn_score_fp16(bsz, past_kv_len)\n        ggml_fp16_t *attn_score_fp16 = (reinterpret_cast<ggml_fp16_t *>(\n            reinterpret_cast<char *>(draft) +\n            sizeof(block_q8_0) * bsz * past_kv_len / QK8_0 +\n            sizeof(float) * bsz * head_dim));\n\n        for (int i = 0; i < bsz * past_kv_len; i++) {\n            attn_score_fp16[i] = GGML_FP32_TO_FP16(attn_score[i]);\n        }\n\n        // TODO: anchor\n        assert(num_v_anchor == 0);\n        bool ok = llamafile_sgemm(\n            head_dim, bsz, past_kv_len, (ggml_fp16_t *)v_cache, past_kv_len,\n            (ggml_fp16_t *)attn_score_fp16, past_kv_len, sum, head_dim, 0, 1,\n            GGML_TASK_TYPE_COMPUTE, v_type, GGML_TYPE_F16, GGML_TYPE_F32,\n            GGML_PREC_DEFAULT);\n        if (!ok) {\n            printf(\"llamafile_sgemm failed\\n\");\n        }\n\n        // copy to output\n        for (int i = 0; i < bsz; i++) {\n            for (int j = 0; j < head_dim; j++) {\n                ((float *)output)[i * head_dim + j] = sum[i * head_dim + j];\n            }\n        }\n    } else {\n        assert(k_type == GGML_TYPE_Q4_0 || k_type == GGML_TYPE_Q8_0);\n        assert(v_type == GGML_TYPE_Q4_0 || v_type == GGML_TYPE_Q8_0);\n\n        // attn = q * k + q * k_anchor\n        // TODO: anchor\n        assert(num_k_anchor == 0);\n\n        if (rotary_angle != nullptr) {\n            ggml_fp16_t *k_cache_with_rope_fp16 =\n                (reinterpret_cast<ggml_fp16_t *>(draft) +\n                 sizeof(block_q8_0) * bsz * 
past_kv_len / QK8_0 +\n                 sizeof(float) * bsz * head_dim);\n            block_q4_0 *k_cache_with_rope_q4 =\n                (reinterpret_cast<block_q4_0 *>(draft) +\n                 sizeof(block_q8_0) * bsz * past_kv_len / QK8_0 +\n                 sizeof(float) * bsz * head_dim) +\n                sizeof(ggml_fp16_t) * bsz * head_dim;\n            // dequant k_cache and apply rope\n            // k_rope(i) = k(i) * cos(i) - k(i+l) * sin(i)\n            // k_rope(i+l) = k(i+l) * cos(i+l) + k(i) * sin(i)\n\n            // k(i)cos(i) -> k_rope(i)\n            // k(i)sin(i+l) -> k_rope(i+l)\n\n            // k(i)cos(i) -> k_rope(i)\n            // -k(i)sin(i-l) -> k_rope(i-l)\n\n            std::vector<float> block_fp32(32);\n            for (int k = 0; k < past_kv_len; k++) {\n                int angle = rotary_angle[k];\n                for (int l = 0; l < head_dim / 32; l++) {\n                    block_q4_0 block =\n                        ((block_q4_0 *)k_cache)[k * head_dim / 32 + l];\n                    dequantize_row_q4_0(&block, block_fp32.data(), 32);\n                    for (int m = 0; m < 32; m++) {\n                        float sin_val = GGML_FP16_TO_FP32(\n                            ((ggml_fp16_t *)\n                                 rotary_sin)[angle * head_dim + l * 32 + m]);\n                        float cos_val = GGML_FP16_TO_FP32(\n                            ((ggml_fp16_t *)\n                                 rotary_cos)[angle * head_dim + l * 32 + m]);\n\n                        if (l * 32 + m < head_dim / 2) {\n                            k_cache_with_rope_fp16[k * head_dim + l * 32 + m] =\n                                GGML_FP32_TO_FP16(block_fp32[m] * cos_val);\n                            k_cache_with_rope_fp16[k * head_dim + l * 32 + m +\n                                                   head_dim / 2] =\n                                GGML_FP32_TO_FP16(-block_fp32[m] * sin_val);\n                        } else {\n        
                    // ggml_fp16_t holds raw half bits: accumulate in fp32.\n                            k_cache_with_rope_fp16[k * head_dim + l * 32 + m] =\n                                GGML_FP32_TO_FP16(\n                                    GGML_FP16_TO_FP32(\n                                        k_cache_with_rope_fp16[k * head_dim +\n                                                               l * 32 + m]) +\n                                    block_fp32[m] * sin_val);\n                            k_cache_with_rope_fp16[k * head_dim + l * 32 + m -\n                                                   head_dim / 2] =\n                                GGML_FP32_TO_FP16(\n                                    GGML_FP16_TO_FP32(\n                                        k_cache_with_rope_fp16[k * head_dim +\n                                                               l * 32 + m -\n                                                               head_dim / 2]) -\n                                    block_fp32[m] * cos_val);\n                        }\n                    }\n                }\n            }\n            // quantize k_cache_with_rope_fp16\n            for (int k = 0; k < past_kv_len; k++) {\n                for (int l = 0; l < head_dim / 32; l++) {\n                    for (int m = 0; m < 32; m++) {\n                        block_fp32[m] = GGML_FP16_TO_FP32(\n                            k_cache_with_rope_fp16[k * head_dim + l * 32 + m]);\n                    }\n                    quantize_row_q4_0(\n                        block_fp32.data(),\n                        &k_cache_with_rope_q4[k * head_dim / 32 + l], 32);\n                }\n            }\n\n            llamafile_sgemm(past_kv_len, bsz, head_dim / 32,\n                            (block_q4_0 *)k_cache_with_rope_q4, head_dim / 32,\n                            (block_q8_0 *)q, head_dim / 32, attn_score,\n                            past_kv_len, 0, 1, GGML_TASK_TYPE_COMPUTE, k_type,\n                            GGML_TYPE_Q8_0, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n        } else {\n            llamafile_sgemm(past_kv_len, bsz, head_dim / 32,\n                            (block_q4_0 *)k_cache, head_dim / 32,\n                            (block_q8_0 *)q, head_dim / 32, attn_score,\n                            past_kv_len, 0, 1, GGML_TASK_TYPE_COMPUTE, k_type,\n                            GGML_TYPE_Q8_0, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n        }\n\n        // attn = attn * scale\n        float scale_factor = 1.0 / std::sqrt(float(head_dim));\n        ggml_vec_scale_f32(bsz * past_kv_len, attn_score, scale_factor);\n\n        // attn = attn & 
mask\n        if (!is_full_attn) {\n            for (int i = 0; i < bsz; i++) {\n                for (int j = 0; j < past_kv_len; j++) {\n                    int index = i * past_kv_len + j;\n                    if (!(attn_mask[j / 8] & (1 << (j % 8)))) {\n                        attn_score[index] =\n                            std::numeric_limits<float>::lowest();\n                    }\n                }\n            }\n        }\n\n        // attn = softmax(attn)\n        for (int i = 0; i < bsz; i++) {\n            float sum_exp = 0;\n            for (int j = 0; j < past_kv_len; j++) {\n                attn_score[i * past_kv_len + j] =\n                    std::exp(attn_score[i * past_kv_len + j]);\n                sum_exp += attn_score[i * past_kv_len + j];\n            }\n            for (int j = 0; j < past_kv_len; j++) {\n                attn_score[i * past_kv_len + j] /= sum_exp;\n            }\n            if (lse != nullptr) {\n                lse[i] = std::log(sum_exp);\n            }\n        }\n\n        // output = attn * v + attn * v_anchor\n        // std::vector<block_q8_0> attn_q8_0(bsz * past_kv_len / QK8_0);\n        block_q8_0 *attn_q8_0 = reinterpret_cast<block_q8_0 *>(draft);\n        quantize_row_q8_0(attn_score, attn_q8_0, bsz * past_kv_len);\n        // std::vector<float> sum(bsz * head_dim);\n        float *sum = reinterpret_cast<float *>(reinterpret_cast<char *>(draft) +\n                                               sizeof(block_q8_0) * bsz *\n                                                   past_kv_len / QK8_0);\n        // TODO: anchor\n        assert(num_v_anchor == 0);\n        llamafile_sgemm(head_dim, bsz, past_kv_len / 32, (block_q4_0 *)v_cache,\n                        past_kv_len / 32, attn_q8_0, past_kv_len / 32, sum,\n                        head_dim, 0, 1, GGML_TASK_TYPE_COMPUTE, v_type,\n                        GGML_TYPE_Q8_0, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n\n        quantize_row_q8_0(sum, (block_q8_0 *)output, bsz 
* head_dim);\n    }\n}\n"
  },
  {
    "path": "archive/csrc/ktransformers_ext/operators/kvcache/kvcache_load_dump.cpp",
    "content": "/**\n * @Description  :\n * @Author       : Jianwei Dong\n * @Date         : 2024-08-26 22:47:06\n * @Version      : 1.0.0\n * @LastEditors  : Jianwei Dong\n * @LastEditTime : 2024-08-26 22:47:06\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n\n#include \"kvcache.h\"\n\n#include <chrono>\n\nvoid KVCache::load_kvcache(std::string tensor_file_path, Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n    std::ifstream ifs_tensor(tensor_file_path, std::ios::binary);\n    if (!ifs_tensor) {\n        throw std::runtime_error(\"Failed to open tensor file\");\n    }\n    ifs_tensor.read(reinterpret_cast<char *>(&cache_total_len_),\n                    sizeof(cache_total_len_));\n    int past_block_num =\n        (cache_total_len_ + config_.block_len - 1) / config_.block_len;\n    printf(\"cache_total_len: %d, past_block_num: %d\\n\", cache_total_len_,\n           past_block_num);\n    for (int i = 0; i < config_.layer_num; ++i) {\n        past_block_num_[i] = past_block_num;\n    }\n    ifs_tensor.read(reinterpret_cast<char *>(anchor_.data()),\n                    anchor_.size() * sizeof(ggml_fp16_t));\n    for (int i = 0; i < config_.layer_num; ++i) {\n        for (int j = 0; j < config_.kv_head_num; ++j) {\n            for (int k = 0; k < past_block_num_[i]; ++k) {\n                if (config_.kv_type == GGML_TYPE_F16) {\n                    ifs_tensor.read(\n                        reinterpret_cast<char *>(k_cache_fp16_[i][j][k].data()),\n                        k_cache_fp16_[i][j][k].size() * sizeof(ggml_fp16_t));\n                    ifs_tensor.read(\n                        reinterpret_cast<char *>(v_cache_fp16_[i][j][k].data()),\n                        v_cache_fp16_[i][j][k].size() * sizeof(ggml_fp16_t));\n                } else if (config_.kv_type == GGML_TYPE_Q4_0) {\n                    ifs_tensor.read(\n                        reinterpret_cast<char 
*>(k_cache_q4[i][j][k].data()),\n                        k_cache_q4[i][j][k].size() * sizeof(block_q4_0));\n                    ifs_tensor.read(\n                        reinterpret_cast<char *>(v_cache_q4[i][j][k].data()),\n                        v_cache_q4[i][j][k].size() * sizeof(block_q4_0));\n                }\n            }\n        }\n        for (int k = 0; k < past_block_num_[i]; ++k) {\n            for (int l = 0; l < config_.block_len; l++) {\n                ifs_tensor.read(\n                    reinterpret_cast<char *>(importance_[i][k][l].data()),\n                    importance_[i][k][l].size() * sizeof(ggml_fp16_t));\n            }\n        }\n    }\n    ifs_tensor.close();\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> diff = end - start;\n    printf(\"time of load: %f s\\n\", diff.count());\n}\nvoid KVCache::dump_kvcache(int *block_table, int cache_total_len,\n                           std::string tensor_file_path, Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n    std::ofstream ofs(tensor_file_path, std::ios::binary);\n    printf(\"dump_kvcache: %s\\n\", tensor_file_path.c_str());\n    if (!ofs.is_open()) {\n        std::cerr << \"Cannot open file \" << tensor_file_path << std::endl;\n        return;\n    }\n    ofs.write(reinterpret_cast<const char *>(&cache_total_len),\n              sizeof(cache_total_len));\n    int past_block_num =\n        (cache_total_len + config_.block_len - 1) / config_.block_len;\n    printf(\"cache_total_len: %d, past_block_num: %d\\n\", cache_total_len,\n           past_block_num);\n    ofs.write(reinterpret_cast<const char *>(anchor_.data()),\n              anchor_.size() * sizeof(ggml_fp16_t));\n    for (int i = 0; i < config_.layer_num; ++i) {\n        for (int j = 0; j < config_.kv_head_num; ++j) {\n            for (int k = 0; k < past_block_num; ++k) {\n                int block_idx = 
block_table[k];\n                if (config_.kv_type == GGML_TYPE_F16) {\n                    ofs.write(reinterpret_cast<const char *>(\n                                  k_cache_fp16_[i][j][block_idx].data()),\n                              k_cache_fp16_[i][j][block_idx].size() *\n                                  sizeof(ggml_fp16_t));\n                    ofs.write(reinterpret_cast<const char *>(\n                                  v_cache_fp16_[i][j][block_idx].data()),\n                              v_cache_fp16_[i][j][block_idx].size() *\n                                  sizeof(ggml_fp16_t));\n\n                } else if (config_.kv_type == GGML_TYPE_Q4_0) {\n                    ofs.write(reinterpret_cast<const char *>(\n                                  k_cache_q4[i][j][block_idx].data()),\n                              k_cache_q4[i][j][block_idx].size() *\n                                  sizeof(block_q4_0));\n                    ofs.write(reinterpret_cast<const char *>(\n                                  v_cache_q4[i][j][block_idx].data()),\n                              v_cache_q4[i][j][block_idx].size() *\n                                  sizeof(block_q4_0));\n                }\n            }\n        }\n        for (int k = 0; k < past_block_num; ++k) {\n            int block_idx = block_table[k];\n            for (int l = 0; l < config_.block_len; l++) {\n                ofs.write(reinterpret_cast<const char *>(\n                              importance_[i][block_idx][l].data()),\n                          importance_[i][block_idx][l].size() *\n                              sizeof(ggml_fp16_t));\n            }\n        }\n    }\n    ofs.close();\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> diff = end - start;\n    printf(\"time of dump: %f s\\n\", diff.count());\n}"
  },
  {
    "path": "archive/csrc/ktransformers_ext/operators/kvcache/kvcache_read_write.cpp",
    "content": "/**\n * @Description  :\n * @Author       : Jianwei Dong\n * @Date         : 2024-08-26 22:47:06\n * @Version      : 1.0.0\n * @LastEditors  : Jianwei Dong\n * @LastEditTime : 2024-08-26 22:47:06\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n\n#include \"kvcache.h\"\n\n#include <chrono>\n\nvoid KVCache::get_anchor_one_block(ggml_fp16_t *anchor, int layer_id,\n                                   int block_idx, Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    layer_id_ = layer_id;\n    block_idx = block_idx;\n    seq_len_ = config_.block_len;\n    anchor_data_ = const_cast<uint16_t *>(anchor);\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> duration = end - start;\n    printf(\"layer %d block %d time of reading anchor: %f s\\n\", layer_id,\n           block_idx, duration.count());\n}\n\nvoid KVCache::update_anchor_one_block(const ggml_fp16_t *anchor, int layer_id,\n                                      int block_idx, Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    layer_id_ = layer_id;\n    block_idx = block_idx;\n    seq_len_ = config_.block_len;\n    anchor_data_ = const_cast<uint16_t *>(anchor);\n\n    // Each task updates the anchor of a certain position\n    // backend->do_work_stealing_job(config_.anchor_num, [&](int task_id) {\n    //     int k = task_id % config_.anchor_num;\n    //     int head_id = task_id / config_.anchor_num;\n    //     memcpy(anchor_[layer_id_][head_id][block_idx].data() +\n    //                k * config_.head_dim,\n    //            anchor_data_ + k * config_.head_dim,\n    //            sizeof(uint16_t) * config_.head_dim);\n    // });\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> duration = end - start;\n    printf(\"layer %d block %d time of 
writing anchor: %f s\\n\", layer_id,\n           block_idx, duration.count());\n}\n\nvoid KVCache::update_importance_one_block(const ggml_fp16_t *importance,\n                                          int layer_id, int block_idx,\n                                          Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    layer_id_ = layer_id;\n    block_idx = block_idx;\n    seq_len_ = config_.block_len;\n    importance_data_ = const_cast<uint16_t *>(importance);\n\n    // Each task updates the importance of a certain position\n    backend->do_work_stealing_job(\n        config_.block_len, nullptr,\n        [&](int task_id) {\n            int k = task_id;\n            memcpy(importance_[layer_id_][block_idx].data() + k,\n                   importance_data_ + k, sizeof(uint16_t));\n        },\n        nullptr);\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> duration = end - start;\n    printf(\"layer %d block %d time of writing importance: %f s\\n\", layer_id,\n           block_idx, duration.count());\n}\n\nvoid KVCache::get_importance_one_block(ggml_fp16_t *importance, int layer_id,\n                                       int block_idx, Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    layer_id_ = layer_id;\n    block_idx = block_idx;\n    seq_len_ = config_.block_len;\n    importance_data_ = const_cast<uint16_t *>(importance);\n\n    // Each task reads the importance of a certain position\n    backend->do_work_stealing_job(\n        config_.block_len, nullptr,\n        [&](int task_id) {\n            int k = task_id;\n            memcpy(importance_data_ + k,\n                   importance_[layer_id_][block_idx].data() + k,\n                   sizeof(uint16_t));\n        },\n        nullptr);\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    
std::chrono::duration<double> duration = end - start;\n    printf(\"layer %d block %d time of reading importance: %f s\\n\", layer_id,\n           block_idx, duration.count());\n}\n\nvoid KVCache::update_kvcache_one_block_fp16(const ggml_fp16_t *k_in,\n                                            const ggml_fp16_t *v_in,\n                                            int layer_id, int block_idx,\n                                            Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    layer_id_ = layer_id;\n    block_idx = block_idx;\n    seq_len_ = config_.block_len;\n    k_data_ = const_cast<uint16_t *>(k_in);\n    v_data_ = const_cast<uint16_t *>(v_in);\n\n    int new_block_num = std::max((int)past_block_num_[layer_id], block_idx + 1);\n\n    importance_[layer_id_].resize(new_block_num);\n\n    for (int i = 0; i < config_.kv_head_num; i++) {\n        k_cache_q4[layer_id][i].resize(new_block_num);\n        v_cache_q4[layer_id][i].resize(new_block_num);\n        // anchor_[layer_id][i].resize(new_block_num);\n    }\n\n    for (int i = 0; i < new_block_num; i++) {\n        importance_[layer_id][i].resize(config_.block_len);\n    }\n\n    // Each task updates the k cache or v cache of a certain header\n    backend->do_work_stealing_job(\n        config_.kv_head_num * 2, nullptr,\n        [&](int task_id) {\n            std::vector<float> block_fp32(32);\n            int head_id = task_id / 2;\n            if (task_id & 1) {\n                // fill k_cache_\n                k_cache_q4[layer_id_][head_id][block_idx].resize(\n                    config_.block_len * config_.head_dim / 32);\n                for (int k = 0; k < config_.block_len; k++) {\n                    for (int l = 0; l < config_.head_dim / 32; l++) {\n                        block_q4_0 block;\n                        for (int m = 0; m < 32; m++) {\n\n                            block_fp32[m] = GGML_FP16_TO_FP32(\n                            
    k_data_[((0 * config_.kv_head_num + head_id) *\n                                             seq_len_ +\n                                         0 * config_.block_len + k) *\n                                            config_.head_dim +\n                                        l * 32 + m]);\n                        }\n                        quantize_row_q4_0(block_fp32.data(), &block, 32);\n                        k_cache_q4[layer_id_][head_id][block_idx]\n                                  [k * config_.head_dim / 32 + l] = block;\n                    }\n                }\n            } else {\n                // fill v_cache_\n                v_cache_q4[layer_id_][head_id][block_idx].resize(\n                    config_.head_dim * config_.block_len / 32);\n                for (int k = 0; k < config_.block_len / 32; k++) {\n                    for (int l = 0; l < config_.head_dim; l++) {\n                        block_q4_0 block;\n                        for (int m = 0; m < 32; m++) {\n\n                            block_fp32[m] = GGML_FP16_TO_FP32(\n                                v_data_[((0 * config_.kv_head_num + head_id) *\n                                             seq_len_ +\n                                         0 * config_.block_len + k * 32 + m) *\n                                            config_.head_dim +\n                                        l]);\n                        }\n                        quantize_row_q4_0(block_fp32.data(), &block, 32);\n                        v_cache_q4[layer_id_][head_id][block_idx]\n                                  [l * config_.block_len / 32 + k] = block;\n                    }\n                }\n            }\n        },\n        nullptr);\n    past_block_num_[layer_id] = new_block_num;\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> duration = end - start;\n    printf(\"layer %d block %d time of writing KV Cache: %f s\\n\", 
layer_id,\n           block_idx, duration.count());\n    // printf(\"get_one_block_fp16 duration: %ld\\n\", duration);\n}\n\nvoid KVCache::get_kvcache_one_block_fp16(ggml_fp16_t *k_in, ggml_fp16_t *v_in,\n                                         int layer_id, int block_idx,\n                                         Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    layer_id_ = layer_id;\n    seq_len_ = config_.block_len;\n    k_data_ = reinterpret_cast<uint16_t *>(k_in);\n    v_data_ = reinterpret_cast<uint16_t *>(v_in);\n\n    // printf(\"layer_id: %d, block_idx: %d\\n\", layer_id, block_idx);\n    // Each task gets the k cache or v cache of a certain header\n    backend->do_work_stealing_job(\n        config_.kv_head_num * 2, nullptr,\n        [&](int task_id) {\n            std::vector<float> block_fp32(32);\n            int head_id = task_id / 2;\n            if (task_id & 1) {\n                // get k_cache_\n                for (int k = 0; k < config_.block_len; k++) {\n                    for (int l = 0; l < config_.head_dim / 32; l++) {\n                        block_q4_0 block =\n                            k_cache_q4[layer_id_][head_id][block_idx]\n                                      [k * config_.head_dim / 32 + l];\n                        dequantize_row_q4_0(&block, block_fp32.data(), 32);\n                        for (int m = 0; m < 32; m++) {\n\n                            k_data_[((0 * config_.kv_head_num + head_id) *\n                                         seq_len_ +\n                                     0 * config_.block_len + k) *\n                                        config_.head_dim +\n                                    l * 32 + m] =\n                                GGML_FP32_TO_FP16(block_fp32[m]);\n                        }\n                    }\n                }\n            } else {\n                // get v_cache_\n                for (int k = 0; k < config_.block_len / 32; 
k++) {\n                    for (int l = 0; l < config_.head_dim; l++) {\n                        block_q4_0 block =\n                            v_cache_q4[layer_id_][head_id][block_idx]\n                                      [l * config_.block_len / 32 + k];\n                        dequantize_row_q4_0(&block, block_fp32.data(), 32);\n                        for (int m = 0; m < 32; m++) {\n\n                            v_data_[((0 * config_.kv_head_num + head_id) *\n                                         seq_len_ +\n                                     0 * config_.block_len + k * 32 + m) *\n                                        config_.head_dim +\n                                    l] = GGML_FP32_TO_FP16(block_fp32[m]);\n                        }\n                    }\n                }\n            }\n        },\n        nullptr);\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> duration = end - start;\n    printf(\"layer %d block %d time of reading KV Cache: %f s\\n\", layer_id,\n           block_idx, duration.count());\n    // printf(\"get_one_block_fp16 duration: %ld\\n\", duration);\n}\n\n// k_in: (batch_size, seq_len, head_num, head_dim)\n// v_in: (batch_size, seq_len, head_num, head_dim)\nvoid KVCache::get_and_update_kvcache_fp16(ggml_fp16_t *k_in, ggml_fp16_t *v_in,\n                                          int layer_id, int *block_table,\n                                          int batch_size, int max_block_num,\n                                          int *cache_seqlens, int q_len,\n                                          Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    layer_id_ = layer_id;\n    k_data_ = const_cast<uint16_t *>(k_in);\n    v_data_ = const_cast<uint16_t *>(v_in);\n\n    // Each task updates the k cache and v cache of a certain header\n    backend->do_work_stealing_job(\n        config_.kv_head_num * 
max_block_num * batch_size, nullptr,\n        [&](int task_id) {\n            // printf(\"block_idx: %d, task_id: %d\\n\", block_idx, task_id);\n            std::vector<float> block_fp32(32);\n            int batch_id = task_id / (config_.kv_head_num * max_block_num);\n            int block_id = (task_id / config_.kv_head_num) % max_block_num;\n            int head_id = task_id % config_.kv_head_num;\n            int block_idx = block_table[batch_id * max_block_num + block_id];\n            int seq_len = cache_seqlens[batch_id];\n            int block_l = block_id * config_.block_len;\n            int block_r = block_id * config_.block_len + config_.block_len;\n\n            if (block_l < seq_len) {\n                if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n                    for (int k = 0; k < config_.block_len; k++) {\n                        if (block_id * config_.block_len + k >= seq_len)\n                            break;\n                        for (int l = 0; l < config_.head_dim; l++) {\n                            k_data_\n                                [batch_id *\n                                     (max_block_num * config_.block_len *\n                                      config_.kv_head_num * config_.head_dim) +\n                                 block_id *\n                                     (config_.block_len * config_.kv_head_num *\n                                      config_.head_dim) +\n                                 k * (config_.kv_head_num * config_.head_dim) +\n                                 head_id * config_.head_dim + l] =\n                                    k_cache_fp16_[layer_id_][head_id][block_idx]\n                                                 [k * config_.head_dim + l];\n                            v_data_\n                                [batch_id *\n                                     (max_block_num * config_.block_len *\n                                      config_.kv_head_num * config_.head_dim) +\n     
                            block_id *\n                                     (config_.block_len * config_.kv_head_num *\n                                      config_.head_dim) +\n                                 k * (config_.kv_head_num * config_.head_dim) +\n                                 head_id * config_.head_dim + l] =\n                                    v_cache_fp16_[layer_id_][head_id][block_idx]\n                                                 [l * config_.block_len + k];\n                        }\n                    }\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n                    // get k_cache_\n                    for (int k = 0; k < config_.block_len; k++) {\n                        if (block_id * config_.block_len + k >= seq_len)\n                            break;\n                        for (int l = 0; l < config_.head_dim / 32; l++) {\n                            block_q4_0 block =\n                                k_cache_q4[layer_id_][head_id][block_idx]\n                                          [k * config_.head_dim / 32 + l];\n                            dequantize_row_q4_0(&block, block_fp32.data(), 32);\n                            for (int m = 0; m < 32; m++) {\n\n                                k_data_[batch_id *\n                                            (max_block_num * config_.block_len *\n                                             config_.kv_head_num *\n                                             config_.head_dim) +\n                                        block_id * (config_.block_len *\n                                                    config_.kv_head_num *\n                                                    config_.head_dim) +\n                                        k * (config_.kv_head_num *\n                                             config_.head_dim) +\n                                        head_id * config_.head_dim + l * 32 +\n                                        m] = 
GGML_FP32_TO_FP16(block_fp32[m]);\n                            }\n                        }\n                    }\n                    // get v_cache_\n                    for (int k = 0; k < config_.block_len / 32; k++) {\n                        for (int l = 0; l < config_.head_dim; l++) {\n                            block_q4_0 block =\n                                v_cache_q4[layer_id_][head_id][block_idx]\n                                          [l * config_.block_len / 32 + k];\n                            dequantize_row_q4_0(&block, block_fp32.data(), 32);\n                            for (int m = 0; m < 32; m++) {\n\n                                if (block_id * config_.block_len + k * 32 + m >=\n                                    seq_len)\n                                    break;\n                                v_data_[batch_id *\n                                            (max_block_num * config_.block_len *\n                                             config_.kv_head_num *\n                                             config_.head_dim) +\n                                        block_id * (config_.block_len *\n                                                    config_.kv_head_num *\n                                                    config_.head_dim) +\n                                        (k * 32 + m) * config_.kv_head_num *\n                                            config_.head_dim +\n                                        head_id * config_.head_dim + l] =\n                                    GGML_FP32_TO_FP16(block_fp32[m]);\n                            }\n                        }\n                    }\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n                    // get k_cache_\n                    for (int k = 0; k < config_.block_len; k++) {\n                        if (block_id * config_.block_len + k >= seq_len)\n                            break;\n                        for (int l = 0; l 
< config_.head_dim / 32; l++) {\n                            block_q8_0 block =\n                                k_cache_q8[layer_id_][head_id][block_idx]\n                                          [k * config_.head_dim / 32 + l];\n                            dequantize_row_q8_0(&block, block_fp32.data(), 32);\n                            for (int m = 0; m < 32; m++) {\n\n                                k_data_[batch_id *\n                                            (max_block_num * config_.block_len *\n                                             config_.kv_head_num *\n                                             config_.head_dim) +\n                                        block_id * (config_.block_len *\n                                                    config_.kv_head_num *\n                                                    config_.head_dim) +\n                                        k * (config_.kv_head_num *\n                                             config_.head_dim) +\n                                        head_id * config_.head_dim + l * 32 +\n                                        m] = GGML_FP32_TO_FP16(block_fp32[m]);\n                            }\n                        }\n                    }\n                    // get v_cache_\n                    for (int k = 0; k < config_.block_len / 32; k++) {\n                        for (int l = 0; l < config_.head_dim; l++) {\n                            block_q8_0 block =\n                                v_cache_q8[layer_id_][head_id][block_idx]\n                                          [l * config_.block_len / 32 + k];\n                            dequantize_row_q8_0(&block, block_fp32.data(), 32);\n                            for (int m = 0; m < 32; m++) {\n\n                                if (block_id * config_.block_len + k * 32 + m >=\n                                    seq_len)\n                                    break;\n                                v_data_[batch_id *\n                
                            (max_block_num * config_.block_len *\n                                             config_.kv_head_num *\n                                             config_.head_dim) +\n                                        block_id * (config_.block_len *\n                                                    config_.kv_head_num *\n                                                    config_.head_dim) +\n                                        (k * 32 + m) * config_.kv_head_num *\n                                            config_.head_dim +\n                                        head_id * config_.head_dim + l] =\n                                    GGML_FP32_TO_FP16(block_fp32[m]);\n                            }\n                        }\n                    }\n                }\n            }\n            if (block_r > seq_len && block_l < seq_len + q_len) {\n                if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n                    for (int k = 0; k < config_.block_len; k++) {\n                        if (block_id * config_.block_len + k >=\n                                seq_len + q_len ||\n                            block_id * config_.block_len + k < seq_len)\n                            continue;\n                        for (int l = 0; l < config_.head_dim; l++) {\n                            k_cache_fp16_[layer_id_][head_id][block_idx]\n                                         [k * config_.head_dim + l] = k_data_\n                                             [batch_id * (max_block_num *\n                                                          config_.block_len *\n                                                          config_.kv_head_num *\n                                                          config_.head_dim) +\n                                              block_id * (config_.block_len *\n                                                          config_.kv_head_num *\n                                                
          config_.head_dim) +\n                                              k * (config_.kv_head_num *\n                                                   config_.head_dim) +\n                                              head_id * config_.head_dim + l];\n                            v_cache_fp16_[layer_id_][head_id][block_idx]\n                                         [l * config_.block_len + k] = v_data_\n                                             [batch_id * (max_block_num *\n                                                          config_.block_len *\n                                                          config_.kv_head_num *\n                                                          config_.head_dim) +\n                                              block_id * (config_.block_len *\n                                                          config_.kv_head_num *\n                                                          config_.head_dim) +\n                                              k * (config_.kv_head_num *\n                                                   config_.head_dim) +\n                                              head_id * config_.head_dim + l];\n                        }\n                    }\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n                    // fill k_cache_\n                    for (int k = 0; k < config_.block_len; k++) {\n                        if (block_id * config_.block_len + k >=\n                                seq_len + q_len ||\n                            block_id * config_.block_len + k < seq_len)\n                            continue;\n                        for (int l = 0; l < config_.head_dim / 32; l++) {\n                            block_q4_0 block;\n                            for (int m = 0; m < 32; m++) {\n\n                                block_fp32[m] = GGML_FP16_TO_FP32(\n                                    k_data_[batch_id * (max_block_num *\n                         
                               config_.block_len *\n                                                        config_.kv_head_num *\n                                                        config_.head_dim) +\n                                            block_id * (config_.block_len *\n                                                        config_.kv_head_num *\n                                                        config_.head_dim) +\n                                            k * (config_.kv_head_num *\n                                                 config_.head_dim) +\n                                            head_id * config_.head_dim +\n                                            l * 32 + m]);\n                            }\n                            quantize_row_q4_0(block_fp32.data(), &block, 32);\n                            k_cache_q4[layer_id_][head_id][block_idx]\n                                      [k * config_.head_dim / 32 + l] = block;\n                        }\n                    }\n\n                    // fill v_cache_\n                    for (int k = 0; k < config_.block_len / 32; k++) {\n                        for (int l = 0; l < config_.head_dim; l++) {\n                            block_q4_0 block;\n                            for (int m = 0; m < 32; m++) {\n\n                                if (block_id * config_.block_len + k * 32 + m >=\n                                    seq_len + q_len) {\n                                    block_fp32[m] = 0;\n                                    continue;\n                                }\n                                block_fp32[m] = GGML_FP16_TO_FP32(\n                                    v_data_[batch_id * (max_block_num *\n                                                        config_.block_len *\n                                                        config_.kv_head_num *\n                                                        config_.head_dim) +\n                                
            block_id * (config_.block_len *\n                                                        config_.kv_head_num *\n                                                        config_.head_dim) +\n                                            (k * 32 + m) * config_.kv_head_num *\n                                                config_.head_dim +\n                                            head_id * config_.head_dim + l]);\n                            }\n                            quantize_row_q4_0(block_fp32.data(), &block, 32);\n                            v_cache_q4[layer_id_][head_id][block_idx]\n                                      [l * config_.block_len / 32 + k] = block;\n                        }\n                    }\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n                    // fill k_cache_\n                    for (int k = 0; k < config_.block_len; k++) {\n                        if (block_id * config_.block_len + k >=\n                                seq_len + q_len ||\n                            block_id * config_.block_len + k < seq_len)\n                            continue;\n                        for (int l = 0; l < config_.head_dim / 32; l++) {\n                            block_q8_0 block;\n                            for (int m = 0; m < 32; m++) {\n\n                                block_fp32[m] = GGML_FP16_TO_FP32(\n                                    k_data_[batch_id * (max_block_num *\n                                                        config_.block_len *\n                                                        config_.kv_head_num *\n                                                        config_.head_dim) +\n                                            block_id * (config_.block_len *\n                                                        config_.kv_head_num *\n                                                        config_.head_dim) +\n                                            k * 
(config_.kv_head_num *\n                                                 config_.head_dim) +\n                                            head_id * config_.head_dim +\n                                            l * 32 + m]);\n                            }\n                            quantize_row_q8_0(block_fp32.data(), &block, 32);\n                            k_cache_q8[layer_id_][head_id][block_idx]\n                                      [k * config_.head_dim / 32 + l] = block;\n                        }\n                    }\n\n                    // fill v_cache_\n                    for (int k = 0; k < config_.block_len / 32; k++) {\n                        for (int l = 0; l < config_.head_dim; l++) {\n                            block_q8_0 block;\n                            for (int m = 0; m < 32; m++) {\n\n                                if (block_id * config_.block_len + k * 32 + m >=\n                                    seq_len + q_len) {\n                                    block_fp32[m] = 0;\n                                    continue;\n                                }\n                                block_fp32[m] = GGML_FP16_TO_FP32(\n                                    v_data_[batch_id * (max_block_num *\n                                                        config_.block_len *\n                                                        config_.kv_head_num *\n                                                        config_.head_dim) +\n                                            block_id * (config_.block_len *\n                                                        config_.kv_head_num *\n                                                        config_.head_dim) +\n                                            (k * 32 + m) * config_.kv_head_num *\n                                                config_.head_dim +\n                                            head_id * config_.head_dim + l]);\n                            }\n                            
quantize_row_q8_0(block_fp32.data(), &block, 32);\n                            v_cache_q8[layer_id_][head_id][block_idx]\n                                      [l * config_.block_len / 32 + k] = block;\n                        }\n                    }\n                }\n            }\n        },\n        nullptr);\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> duration = end - start;\n\n    // printf(\"layer %d time of reading and updating KV Cache: %f s\\n\",\n    // layer_id,\n    //        duration.count());\n}\n\nvoid KVCache::update_importance(const ggml_fp16_t *importance, int layer_id,\n                                int *block_table, int batch_size,\n                                int max_block_num, int *offset, int width,\n                                Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    layer_id_ = layer_id;\n    importance_data_ = const_cast<uint16_t *>(importance);\n\n    // Each task updates the importance of a certain position\n    backend->do_work_stealing_job(\n        max_block_num * batch_size, nullptr,\n        [&](int task_id) {\n            int block_id = task_id % max_block_num;\n            int batch_id = task_id / max_block_num;\n            int block_idx = block_table[batch_id * max_block_num + block_id];\n            if (block_id > (offset[batch_id] + width) / config_.block_len) {\n                return;\n            }\n            for (int k = 0; k < config_.block_len; k++) {\n                for (int head_id = 0; head_id < config_.q_head_num; head_id++) {\n                    importance_[layer_id_][block_idx][k][head_id] =\n                        GGML_FP32_TO_FP16(\n                            GGML_FP16_TO_FP32(\n                                importance_data_[batch_id * max_block_num *\n                                                     config_.block_len *\n                                  
                   config_.q_head_num +\n                                                 (block_id * config_.block_len +\n                                                  k) *\n                                                     config_.q_head_num +\n                                                 head_id]) +\n                            GGML_FP16_TO_FP32(\n                                importance_[layer_id_][block_idx][k][head_id]));\n                }\n            }\n        },\n        nullptr);\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> duration = end - start;\n\n    // printf(\"layer %d time of updating importance: %f s\\n\", layer_id,\n    //        duration.count());\n}\n\nvoid KVCache::get_kvcache_fp16(ggml_fp16_t *k_in, ggml_fp16_t *v_in,\n                               int layer_id, int *block_table, int batch_size,\n                               int max_block_num, int *cache_seqlens,\n                               Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    layer_id_ = layer_id;\n    k_data_ = const_cast<uint16_t *>(k_in);\n    v_data_ = const_cast<uint16_t *>(v_in);\n\n    // Each task updates the k cache and v cache of a certain header\n    backend->do_work_stealing_job(\n        config_.kv_head_num * max_block_num * batch_size, nullptr,\n        [&](int task_id) {\n            // printf(\"block_idx: %d, task_id: %d\\n\", block_idx, task_id);\n            std::vector<float> block_fp32(32);\n            int batch_id = task_id / (config_.kv_head_num * max_block_num);\n            int block_id = (task_id / config_.kv_head_num) % max_block_num;\n            int head_id = task_id % config_.kv_head_num;\n            int block_idx = block_table[batch_id * max_block_num + block_id];\n            int seq_len = cache_seqlens[batch_id];\n            int block_l = block_id * config_.block_len;\n            int block_r = 
block_id * config_.block_len + config_.block_len;\n\n            if (block_l < seq_len) {\n                if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n                    for (int k = 0; k < config_.block_len; k++) {\n                        if (block_id * config_.block_len + k >= seq_len)\n                            break;\n                        for (int l = 0; l < config_.head_dim; l++) {\n                            k_data_\n                                [batch_id *\n                                     (max_block_num * config_.block_len *\n                                      config_.kv_head_num * config_.head_dim) +\n                                 block_id *\n                                     (config_.block_len * config_.kv_head_num *\n                                      config_.head_dim) +\n                                 k * (config_.kv_head_num * config_.head_dim) +\n                                 head_id * config_.head_dim + l] =\n                                    k_cache_fp16_[layer_id_][head_id][block_idx]\n                                                 [k * config_.head_dim + l];\n                            v_data_\n                                [batch_id *\n                                     (max_block_num * config_.block_len *\n                                      config_.kv_head_num * config_.head_dim) +\n                                 block_id *\n                                     (config_.block_len * config_.kv_head_num *\n                                      config_.head_dim) +\n                                 k * (config_.kv_head_num * config_.head_dim) +\n                                 head_id * config_.head_dim + l] =\n                                    v_cache_fp16_[layer_id_][head_id][block_idx]\n                                                 [l * config_.block_len + k];\n                        }\n                    }\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n   
                 // get k_cache_\n                    for (int k = 0; k < config_.block_len; k++) {\n                        if (block_id * config_.block_len + k >= seq_len)\n                            break;\n                        for (int l = 0; l < config_.head_dim / 32; l++) {\n                            block_q4_0 block =\n                                k_cache_q4[layer_id_][head_id][block_idx]\n                                          [k * config_.head_dim / 32 + l];\n                            dequantize_row_q4_0(&block, block_fp32.data(), 32);\n                            for (int m = 0; m < 32; m++) {\n\n                                k_data_[batch_id *\n                                            (max_block_num * config_.block_len *\n                                             config_.kv_head_num *\n                                             config_.head_dim) +\n                                        block_id * (config_.block_len *\n                                                    config_.kv_head_num *\n                                                    config_.head_dim) +\n                                        k * (config_.kv_head_num *\n                                             config_.head_dim) +\n                                        head_id * config_.head_dim + l * 32 +\n                                        m] = GGML_FP32_TO_FP16(block_fp32[m]);\n                            }\n                        }\n                    }\n                    // get v_cache_\n                    for (int k = 0; k < config_.block_len / 32; k++) {\n                        for (int l = 0; l < config_.head_dim; l++) {\n                            block_q4_0 block =\n                                v_cache_q4[layer_id_][head_id][block_idx]\n                                          [l * config_.block_len / 32 + k];\n                            dequantize_row_q4_0(&block, block_fp32.data(), 32);\n                            for (int m = 0; m < 
32; m++) {\n\n                                if (block_id * config_.block_len + k * 32 + m >=\n                                    seq_len)\n                                    break;\n                                v_data_[batch_id *\n                                            (max_block_num * config_.block_len *\n                                             config_.kv_head_num *\n                                             config_.head_dim) +\n                                        block_id * (config_.block_len *\n                                                    config_.kv_head_num *\n                                                    config_.head_dim) +\n                                        (k * 32 + m) * config_.kv_head_num *\n                                            config_.head_dim +\n                                        head_id * config_.head_dim + l] =\n                                    GGML_FP32_TO_FP16(block_fp32[m]);\n                            }\n                        }\n                    }\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n                    // get k_cache_\n                    for (int k = 0; k < config_.block_len; k++) {\n                        if (block_id * config_.block_len + k >= seq_len)\n                            break;\n                        for (int l = 0; l < config_.head_dim / 32; l++) {\n                            block_q8_0 block =\n                                k_cache_q8[layer_id_][head_id][block_idx]\n                                          [k * config_.head_dim / 32 + l];\n                            dequantize_row_q8_0(&block, block_fp32.data(), 32);\n                            for (int m = 0; m < 32; m++) {\n\n                                k_data_[batch_id *\n                                            (max_block_num * config_.block_len *\n                                             config_.kv_head_num *\n                                             
config_.head_dim) +\n                                        block_id * (config_.block_len *\n                                                    config_.kv_head_num *\n                                                    config_.head_dim) +\n                                        k * (config_.kv_head_num *\n                                             config_.head_dim) +\n                                        head_id * config_.head_dim + l * 32 +\n                                        m] = GGML_FP32_TO_FP16(block_fp32[m]);\n                            }\n                        }\n                    }\n                    // get v_cache_\n                    for (int k = 0; k < config_.block_len / 32; k++) {\n                        for (int l = 0; l < config_.head_dim; l++) {\n                            block_q8_0 block =\n                                v_cache_q8[layer_id_][head_id][block_idx]\n                                          [l * config_.block_len / 32 + k];\n                            dequantize_row_q8_0(&block, block_fp32.data(), 32);\n                            for (int m = 0; m < 32; m++) {\n\n                                if (block_id * config_.block_len + k * 32 + m >=\n                                    seq_len)\n                                    break;\n                                v_data_[batch_id *\n                                            (max_block_num * config_.block_len *\n                                             config_.kv_head_num *\n                                             config_.head_dim) +\n                                        block_id * (config_.block_len *\n                                                    config_.kv_head_num *\n                                                    config_.head_dim) +\n                                        (k * 32 + m) * config_.kv_head_num *\n                                            config_.head_dim +\n                                        head_id * 
config_.head_dim + l] =\n                                    GGML_FP32_TO_FP16(block_fp32[m]);\n                            }\n                        }\n                    }\n                }\n            }\n        },\n        nullptr);\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> duration = end - start;\n}\n\nvoid KVCache::update_kvcache_fp16(const ggml_fp16_t *k_in,\n                                  const ggml_fp16_t *v_in, int layer_id,\n                                  int *block_table, int batch_size,\n                                  int max_block_num, int *cache_seqlens,\n                                  int q_len, Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    layer_id_ = layer_id;\n    k_data_ = const_cast<uint16_t *>(k_in);\n    v_data_ = const_cast<uint16_t *>(v_in);\n    // Each task updates the k cache and v cache of a certain header\n    backend->do_work_stealing_job(\n        batch_size * config_.kv_head_num * q_len, nullptr,\n        [&](int task_id) {\n            int batch_id = task_id / (config_.kv_head_num * q_len);\n            int head_id = task_id / q_len % config_.kv_head_num;\n            int seq_len = cache_seqlens[batch_id] + task_id % q_len;\n            int q_offset = task_id % q_len;\n\n            int block_id = seq_len / config_.block_len;\n            int block_idx = block_table[batch_id * max_block_num + block_id];\n            int pos_in_block = seq_len % config_.block_len;\n\n            if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n                for (int l = 0; l < config_.head_dim; l++) {\n                    k_cache_fp16_[layer_id_][head_id][block_idx]\n                                 [pos_in_block * config_.head_dim + l] =\n                                     k_data_[batch_id *\n                                                 (q_len * config_.kv_head_num *\n                       
                           config_.head_dim) +\n                                             q_offset * config_.kv_head_num *\n                                                 config_.head_dim +\n                                             head_id * config_.head_dim + l];\n                    v_cache_fp16_[layer_id_][head_id][block_idx]\n                                 [l * config_.block_len + pos_in_block] =\n                                     v_data_[batch_id *\n                                                 (q_len * config_.kv_head_num *\n                                                  config_.head_dim) +\n                                             q_offset * config_.kv_head_num *\n                                                 config_.head_dim +\n                                             head_id * config_.head_dim + l];\n                }\n            } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n                std::vector<float> block_fp32(32);\n                // fill k_cache_\n                for (int l = 0; l < config_.head_dim / 32; l++) {\n                    block_q4_0 block;\n                    for (int m = 0; m < 32; m++) {\n\n                        block_fp32[m] = GGML_FP16_TO_FP32(\n                            k_data_[batch_id * (q_len * config_.kv_head_num *\n                                                config_.head_dim) +\n                                    head_id * config_.head_dim + l * 32 + m]);\n                    }\n                    quantize_row_q4_0(block_fp32.data(), &block, 32);\n\n                    k_cache_q4[layer_id_][head_id][block_idx]\n                              [pos_in_block * config_.head_dim / 32 + l] =\n                                  block;\n                }\n\n                // fill v_cache_\n                for (int l = 0; l < config_.head_dim; l++) {\n                    block_q4_0 block = v_cache_q4[layer_id_][head_id][block_idx]\n                                               
  [l * config_.block_len / 32 +\n                                                  pos_in_block / 32];\n                    dequantize_row_q4_0(&block, block_fp32.data(), 32);\n                    block_fp32[pos_in_block % 32] = GGML_FP16_TO_FP32(\n                        v_data_[batch_id * (q_len * config_.kv_head_num *\n                                            config_.head_dim) +\n                                head_id * config_.head_dim + l]);\n                    quantize_row_q4_0(block_fp32.data(), &block, 32);\n                    v_cache_q4[layer_id_][head_id][block_idx]\n                              [l * config_.block_len / 32 + pos_in_block / 32] =\n                                  block;\n                }\n            } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n                std::vector<float> block_fp32(32);\n                // fill k_cache_\n                for (int l = 0; l < config_.head_dim / 32; l++) {\n                    block_q8_0 block;\n                    for (int m = 0; m < 32; m++) {\n\n                        block_fp32[m] = GGML_FP16_TO_FP32(\n                            k_data_[batch_id * (q_len * config_.kv_head_num *\n                                                config_.head_dim) +\n                                    head_id * config_.head_dim + l * 32 + m]);\n                    }\n                    quantize_row_q8_0(block_fp32.data(), &block, 32);\n\n                    k_cache_q8[layer_id_][head_id][block_idx]\n                              [pos_in_block * config_.head_dim / 32 + l] =\n                                  block;\n                }\n\n                // fill v_cache_\n                for (int l = 0; l < config_.head_dim; l++) {\n                    block_q8_0 block = v_cache_q8[layer_id_][head_id][block_idx]\n                                                 [l * config_.block_len / 32 +\n                                                  pos_in_block / 32];\n                    
dequantize_row_q8_0(&block, block_fp32.data(), 32);\n                    block_fp32[pos_in_block % 32] = GGML_FP16_TO_FP32(\n                        v_data_[batch_id * (q_len * config_.kv_head_num *\n                                            config_.head_dim) +\n                                head_id * config_.head_dim + l]);\n                    quantize_row_q8_0(block_fp32.data(), &block, 32);\n                    v_cache_q8[layer_id_][head_id][block_idx]\n                              [l * config_.block_len / 32 + pos_in_block / 32] =\n                                  block;\n                }\n            }\n        },\n        nullptr);\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> duration = end - start;\n    // printf(\"layer %d time of reading KV Cache: %f s\\n\", layer_id,\n    //        duration.count());\n}\n\nvoid KVCache::get_all_kvcache_one_layer(int layer_id, ggml_fp16_t *k_in,\n                                        ggml_fp16_t *v_in, Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    layer_id_ = layer_id;\n    seq_len_ = config_.block_len;\n    block_num_ = get_cache_total_block_num();\n    k_data_ = reinterpret_cast<uint16_t *>(k_in);\n    v_data_ = reinterpret_cast<uint16_t *>(v_in);\n\n    // Each task gets the k cache or v cache of a certain header\n    backend->do_work_stealing_job(\n        config_.kv_head_num * past_block_num_[layer_id] * 2, nullptr,\n        [&](int task_id) {\n            std::vector<float> block_fp32(32);\n            int head_id = task_id / 2 / past_block_num_[layer_id];\n            int block_idx = task_id / 2 % past_block_num_[layer_id];\n            if (block_idx >= block_num_)\n                return;\n\n            int max_offset = 0;\n            if (task_id & 1) {\n                // get k_cache_\n                for (int k = 0; k < config_.block_len; k++) {\n                    if 
(block_idx * seq_len_ + k >= cache_total_len_)\n                        break;\n                    for (int l = 0; l < config_.head_dim / 32; l++) {\n                        block_q4_0 block =\n                            k_cache_q4[layer_id_][head_id][block_idx]\n                                      [k * config_.head_dim / 32 + l];\n                        dequantize_row_q4_0(&block, block_fp32.data(), 32);\n                        for (int m = 0; m < 32; m++) {\n\n                            k_data_[(head_id * cache_total_len_ +\n                                     block_idx * config_.block_len + k) *\n                                        config_.head_dim +\n                                    l * 32 + m] =\n                                GGML_FP32_TO_FP16(block_fp32[m]);\n                            max_offset = std::max(\n                                max_offset,\n                                (int)(head_id * cache_total_len_ +\n                                      block_idx * config_.block_len + k) *\n                                        config_.head_dim +\n                                    l * 32 + m);\n                        }\n                    }\n                }\n            } else {\n                // get v_cache_\n                for (int k = 0; k < config_.block_len / 32; k++) {\n                    for (int l = 0; l < config_.head_dim; l++) {\n                        block_q4_0 block =\n                            v_cache_q4[layer_id_][head_id][block_idx]\n                                      [l * config_.block_len / 32 + k];\n                        dequantize_row_q4_0(&block, block_fp32.data(), 32);\n                        for (int m = 0; m < 32; m++) {\n\n                            if (block_idx * seq_len_ + k * 32 + m >=\n                                cache_total_len_)\n                                break;\n                            v_data_[(head_id * cache_total_len_ +\n                                     block_idx 
* config_.block_len + k * 32 +\n                                     m) *\n                                        config_.head_dim +\n                                    l] = GGML_FP32_TO_FP16(block_fp32[m]);\n                            max_offset =\n                                std::max(max_offset,\n                                         (int)((head_id * cache_total_len_ +\n                                                block_idx * config_.block_len +\n                                                k * 32 + m) *\n                                                   config_.head_dim +\n                                               l));\n                        }\n                    }\n                }\n            }\n        },\n        nullptr);\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> duration = end - start;\n    // printf(\"layer %d block num %d time of reading all KV Cache: %f s\\n\",\n    //        layer_id, block_num_, duration.count());\n}\n"
  },
  {
    "path": "archive/csrc/ktransformers_ext/operators/kvcache/kvcache_utils.cpp",
    "content": "/**\n * @Description  :\n * @Author       : Jianwei Dong\n * @Date         : 2024-08-26 22:47:06\n * @Version      : 1.0.0\n * @LastEditors  : Jianwei Dong\n * @LastEditTime : 2024-08-26 22:47:06\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n\n#include \"kvcache.h\"\n\n#include <chrono>\n\nstd::string ggml_type_to_string(ggml_type type) {\n    switch (type) {\n    case GGML_TYPE_F32:\n        return \"GGML_TYPE_F32\";\n    case GGML_TYPE_F16:\n        return \"GGML_TYPE_F16\";\n    case GGML_TYPE_Q4_0:\n        return \"GGML_TYPE_Q4_0\";\n    case GGML_TYPE_Q8_0:\n        return \"GGML_TYPE_Q8_0\";\n    }\n    return \"UNDEFINED\";\n}\nstd::string AnchorTypeToString(AnchorType type) {\n    switch (type) {\n    case AnchorType::DYNAMIC:\n        return \"DYNAMIC\";\n    case AnchorType::BLOCK_MEAN:\n        return \"BLOCK_MEAN\";\n    case AnchorType::BLOCK_MAX:\n        return \"BLOCK_MAX\";\n    case AnchorType::FIXED_ANCHOR:\n        return \"FIXED_ANCHOR\";\n    case AnchorType::QUEST:\n        return \"QUEST\";\n    }\n    return \"UNDEFINED\";\n}\nstd::string RetrievalTypeToString(RetrievalType type) {\n    switch (type) {\n    case RetrievalType::LAYER:\n        return \"SHARED\";\n    case RetrievalType::KVHEAD:\n        return \"SEPARATE\";\n    case RetrievalType::QHEAD:\n        return \"INDIVIDUAL\";\n    }\n    return \"UNDEFINED\";\n}\nKVCacheConfig::KVCacheConfig(int layer_num, int kv_head_num, int q_head_num,\n                             int head_dim, int block_len, int anchor_num,\n                             AnchorType anchor_type, ggml_type kv_type,\n                             RetrievalType retrieval_type, int layer_step,\n                             int token_step, int layer_offset,\n                             int max_block_num, int max_batch_size,\n                             int max_thread_num)\n    : layer_num(layer_num), kv_head_num(kv_head_num), q_head_num(q_head_num),\n      head_dim(head_dim), 
block_len(block_len), anchor_num(anchor_num),\n      anchor_type(anchor_type), kv_type(kv_type),\n      retrieval_type(retrieval_type), layer_step(layer_step),\n      token_step(token_step), layer_offset(layer_offset),\n      max_block_num(max_block_num), max_batch_size(max_batch_size),\n      max_thread_num(max_thread_num) {\n    printf(\n        \"layer_num: %d, kv_head_num: %d, q_head_num: %d, head_dim: %d, \"\n        \"block_len: %d, anchor_num: %d, anchor_type: %s, kv_type: %s, \"\n        \"retrieval_type: %s, layer_step: %d, token_step: %d, layer_offset: %d,\"\n        \"max_block_num: %d, max_batch_size: %d, max_thread_num: %d\\n\",\n        layer_num, kv_head_num, q_head_num, head_dim, block_len, anchor_num,\n        AnchorTypeToString(anchor_type).c_str(),\n        ggml_type_to_string(kv_type).c_str(),\n        RetrievalTypeToString(retrieval_type).c_str(), layer_step, token_step,\n        layer_offset, max_block_num, max_batch_size, max_thread_num);\n    assert(q_head_num % kv_head_num == 0);\n}\nKVCache::KVCache(KVCacheConfig config) {\n    this->config_ = config;\n\n    n_gqa_ = config_.q_head_num / config_.kv_head_num;\n    if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n        // TODO: Elegant implement\n        k_cache_fp16_.resize(config_.layer_num);\n        v_cache_fp16_.resize(config_.layer_num);\n        selected_blocks_num_history_.resize(config_.layer_num /\n                                            config_.layer_step);\n        if (config_.retrieval_type == RetrievalType::LAYER) {\n            selected_blocks_history_.resize(config_.layer_num /\n                                            config_.layer_step);\n        } else if (config_.retrieval_type == RetrievalType::KVHEAD) {\n            selected_blocks_history_kvhead_.resize(config_.layer_num /\n                                                   config_.layer_step);\n        } else if (config_.retrieval_type == RetrievalType::QHEAD) {\n        }\n    } else if (config_.kv_type 
== ggml_type::GGML_TYPE_Q4_0) {\n        k_cache_q4.resize(config.layer_num);\n        v_cache_q4.resize(config.layer_num);\n    } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n        k_cache_q8.resize(config.layer_num);\n        v_cache_q8.resize(config.layer_num);\n    } else {\n        assert(false);\n    }\n    anchor_.resize(config.layer_num * config.max_block_num * config.anchor_num *\n                   config.q_head_num * config.head_dim);\n    importance_.resize(config.layer_num);\n    past_block_num_.resize(config.layer_num);\n    for (int i = 0; i < config.layer_num; i++) {\n        past_block_num_[i] = 0;\n    }\n\n    ThreadResize(config.max_thread_num);\n    BatchResize(config.max_batch_size);\n    BlockResize(config.max_block_num);\n    q_fp32.resize(n_gqa_ * config.head_dim);\n}\n\nvoid KVCache::ThreadResize(int thread_num) {\n    thread_local_output_q8_0_.resize(thread_num);\n    thread_local_attn_score_.resize(thread_num);\n    thread_local_output_fp32_.resize(thread_num);\n    thread_local_attn_lse_.resize(thread_num);\n    thread_local_cur_output_fp32_.resize(thread_num);\n    thread_local_cur_attn_lse_.resize(thread_num);\n    thread_local_draft_.resize(thread_num);\n    thread_cur_head_idx_.resize(thread_num);\n    thread_local_attn_mask_.resize(thread_num);\n    for (int i = 0; i < thread_num; i++) {\n        thread_local_output_q8_0_[i].resize(n_gqa_ * config_.head_dim / QK8_0);\n        thread_local_attn_score_[i].resize(n_gqa_ * config_.block_len);\n        thread_local_output_fp32_[i].resize(n_gqa_ * config_.head_dim);\n        thread_local_attn_lse_[i].resize(n_gqa_);\n        thread_local_cur_output_fp32_[i].resize(n_gqa_ * config_.head_dim);\n        thread_local_cur_attn_lse_[i].resize(n_gqa_);\n        thread_local_draft_[i].resize(\n            2 * n_gqa_ * config_.block_len + 6 * n_gqa_ * config_.head_dim +\n            2 * config_.block_len * config_.head_dim +\n            config_.block_len * config_.head_dim / 
QK4_0);\n        thread_local_attn_mask_[i].resize(config_.block_len / 8);\n    }\n}\nvoid KVCache::BatchResize(int batch_size) {\n    mutex_.resize(batch_size);\n    q_q8_0_.resize(batch_size);\n    q_fp32_.resize(batch_size);\n    output_fp32_.resize(batch_size);\n    attn_lse_.resize(batch_size);\n    block_lse_.resize(batch_size);\n    attn_sparsity_.resize(batch_size);\n\n    if (config_.retrieval_type == RetrievalType::LAYER) {\n        block_table_before_retrieval_.resize(batch_size);\n        block_table_after_retrieval_.resize(batch_size);\n\n        for (int i = 0; i < config_.layer_num / config_.layer_step; i++) {\n            selected_blocks_history_[i].resize(batch_size);\n        }\n\n    } else if (config_.retrieval_type == RetrievalType::KVHEAD) {\n        block_table_before_retrieval_kvhead_.resize(batch_size);\n        block_table_after_retrieval_kvhead_.resize(batch_size);\n        for (int i = 0; i < config_.layer_num / config_.layer_step; i++) {\n            selected_blocks_history_kvhead_[i].resize(batch_size);\n        }\n    } else if (config_.retrieval_type == RetrievalType::QHEAD) {\n        block_table_before_retrieval_qhead_.resize(batch_size);\n        block_table_after_retrieval_qhead_.resize(batch_size);\n    }\n    cache_seqlens_.resize(batch_size);\n    if (config_.retrieval_type == RetrievalType::LAYER) {\n        block_similar_.resize(batch_size);\n    } else if (config_.retrieval_type == RetrievalType::KVHEAD) {\n        block_similar_kv_head_.resize(batch_size);\n    } else if (config_.retrieval_type == RetrievalType::QHEAD) {\n        block_similar_q_head_.resize(batch_size);\n    }\n    for (int i = 0; i < batch_size; i++) {\n        top_similar_block_.resize(batch_size);\n\n        mutex_[i].resize(config_.kv_head_num);\n        q_q8_0_[i].resize(config_.kv_head_num);\n        q_fp32_[i].resize(config_.kv_head_num);\n        output_fp32_[i].resize(config_.kv_head_num);\n        attn_lse_[i].resize(config_.kv_head_num);\n\n    
    for (int j = 0; j < config_.kv_head_num; j++) {\n            if (!mutex_[i][j]) {\n                mutex_[i][j] = std::make_unique<std::mutex>();\n            }\n            q_q8_0_[i][j].resize(n_gqa_ * config_.head_dim / QK8_0);\n            q_fp32_[i][j].resize(n_gqa_ * config_.head_dim);\n            output_fp32_[i][j].resize(n_gqa_ * config_.head_dim);\n            attn_lse_[i][j].resize(n_gqa_);\n        }\n    }\n    avg_q.resize(batch_size);\n    avg_q_fp16.resize(batch_size);\n    for (int i = 0; i < batch_size; i++) {\n        attn_sparsity_[i].resize(config_.q_head_num);\n        avg_q[i].resize(config_.q_head_num * config_.head_dim);\n        avg_q_fp16[i].resize(config_.q_head_num * config_.head_dim);\n    }\n}\n\nvoid KVCache::BlockResize(int max_block_num) {\n    sin_.resize(max_block_num * config_.block_len);\n    cos_.resize(max_block_num * config_.block_len);\n    for (int i = 0; i < max_block_num * config_.block_len; i++) {\n        sin_[i].resize(config_.head_dim);\n        cos_[i].resize(config_.head_dim);\n    }\n\n    for (int i = 0; i < config_.layer_num / config_.layer_step; i++) {\n        for (int j = 0; j < config_.max_batch_size; j++) {\n            if (config_.retrieval_type == RetrievalType::LAYER) {\n                selected_blocks_history_[i][j].resize(max_block_num);\n            } else if (config_.retrieval_type == RetrievalType::KVHEAD) {\n                selected_blocks_history_kvhead_[i][j].resize(max_block_num);\n                for (int k = 0; k < config_.max_block_num; k++) {\n                    selected_blocks_history_kvhead_[i][j][k].resize(\n                        config_.kv_head_num);\n                }\n            } else if (config_.retrieval_type == RetrievalType::QHEAD) {\n            }\n        }\n    }\n\n    for (int layer_id = 0; layer_id < config_.layer_num; layer_id++) {\n        importance_[layer_id].resize(max_block_num);\n\n        if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n            // 
TODO: Elegant implement\n            k_cache_fp16_[layer_id].resize(config_.kv_head_num);\n            v_cache_fp16_[layer_id].resize(config_.kv_head_num);\n\n            for (int i = 0; i < config_.kv_head_num; i++) {\n                k_cache_fp16_[layer_id][i].resize(max_block_num);\n                v_cache_fp16_[layer_id][i].resize(max_block_num);\n\n                for (int j = 0; j < max_block_num; j++) {\n                    k_cache_fp16_[layer_id][i][j].resize(config_.block_len *\n                                                         config_.head_dim);\n                    v_cache_fp16_[layer_id][i][j].resize(config_.block_len *\n                                                         config_.head_dim);\n                }\n            }\n\n        } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n            k_cache_q4[layer_id].resize(config_.kv_head_num);\n            v_cache_q4[layer_id].resize(config_.kv_head_num);\n            for (int i = 0; i < config_.kv_head_num; i++) {\n                k_cache_q4[layer_id][i].resize(max_block_num);\n                v_cache_q4[layer_id][i].resize(max_block_num);\n\n                for (int j = 0; j < max_block_num; j++) {\n                    k_cache_q4[layer_id][i][j].resize(config_.block_len *\n                                                      config_.head_dim / 32);\n                    v_cache_q4[layer_id][i][j].resize(config_.block_len *\n                                                      config_.head_dim / 32);\n                }\n            }\n        } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n            k_cache_q8[layer_id].resize(config_.kv_head_num);\n            v_cache_q8[layer_id].resize(config_.kv_head_num);\n            for (int i = 0; i < config_.kv_head_num; i++) {\n                k_cache_q8[layer_id][i].resize(max_block_num);\n                v_cache_q8[layer_id][i].resize(max_block_num);\n\n                for (int j = 0; j < max_block_num; j++) {\n         
           k_cache_q8[layer_id][i][j].resize(config_.block_len *\n                                                      config_.head_dim / 32);\n                    v_cache_q8[layer_id][i][j].resize(config_.block_len *\n                                                      config_.head_dim / 32);\n                }\n            }\n        } else {\n            assert(false);\n        }\n        for (int i = 0; i < config_.max_batch_size; i++) {\n            if (config_.retrieval_type == RetrievalType::LAYER) {\n                block_similar_[i].resize(max_block_num);\n                block_table_before_retrieval_[i].resize(max_block_num);\n                block_table_after_retrieval_[i].resize(max_block_num);\n            } else if (config_.retrieval_type == RetrievalType::KVHEAD) {\n                block_similar_kv_head_[i].resize(max_block_num);\n                block_table_before_retrieval_kvhead_[i].resize(max_block_num);\n                block_table_after_retrieval_kvhead_[i].resize(max_block_num);\n                for (int j = 0; j < max_block_num; j++) {\n                    block_similar_kv_head_[i][j].resize(config_.kv_head_num);\n                    block_table_before_retrieval_kvhead_[i][j].resize(\n                        config_.kv_head_num);\n                    block_table_after_retrieval_kvhead_[i][j].resize(\n                        config_.kv_head_num);\n                }\n            } else if (config_.retrieval_type == RetrievalType::QHEAD) {\n                block_similar_q_head_[i].resize(max_block_num);\n                block_table_before_retrieval_qhead_[i].resize(max_block_num);\n                block_table_after_retrieval_qhead_[i].resize(max_block_num);\n                for (int j = 0; j < max_block_num; j++) {\n                    block_similar_q_head_[i][j].resize(config_.q_head_num);\n                    block_table_before_retrieval_qhead_[i][j].resize(\n                        config_.q_head_num);\n                    
block_table_after_retrieval_qhead_[i][j].resize(\n                        config_.q_head_num);\n                }\n            }\n            block_lse_[i].resize(max_block_num);\n            for (int j = 0; j < max_block_num; j++) {\n                block_lse_[i][j].resize(config_.q_head_num);\n            }\n        }\n\n        for (int i = 0; i < max_block_num; i++) {\n            importance_[layer_id][i].resize(config_.block_len);\n            for (int j = 0; j < config_.block_len; j++) {\n                importance_[layer_id][i][j].resize(config_.q_head_num);\n            }\n        }\n    }\n}\n\nvoid KVCache::calc_anchor_all_layers(int *block_table, int *cache_seqlens,\n                                     int batch_size, int max_block_num,\n                                     Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    // Each task updates the importance of a certain block\n    seq_len_ = config_.block_len;\n    backend->do_work_stealing_job(\n        config_.layer_num * batch_size * max_block_num, nullptr,\n        [&](int task_id) {\n            int layer_id = task_id / (batch_size * max_block_num);\n            int batch_id = (task_id / max_block_num) % batch_size;\n            int block_id = task_id % max_block_num;\n            // If the block is out of the sequence length, skip it. 
In\n            // particular, the last block of the sequence that is shorter than\n            // the block length should be skipped.\n\n            if (cache_seqlens[batch_id] / config_.block_len < block_id) {\n                return;\n            }\n            int block_idx = block_table[batch_id * max_block_num + block_id];\n\n            std::vector<float> block_fp32(32);\n            if (config_.anchor_type == AnchorType::DYNAMIC) {\n\n                // clear anchor_\n                for (int anchor_id = 0; anchor_id < 1; anchor_id++) {\n                    for (int head_id = 0; head_id < config_.q_head_num;\n                         head_id++) {\n                        for (int l = 0; l < config_.head_dim; l++) {\n                            anchor_[layer_id * config_.max_block_num *\n                                        config_.anchor_num *\n                                        config_.q_head_num * config_.head_dim +\n                                    block_idx * config_.anchor_num *\n                                        config_.q_head_num * config_.head_dim +\n                                    anchor_id * config_.q_head_num *\n                                        config_.head_dim +\n                                    head_id * config_.head_dim + l] = 0;\n                        }\n                    }\n                }\n\n                // find top anchor_num importances and their corresponding\n                // positions in the importance_ tensor\n                // TODO: Move top_importances to the class member to avoid\n                // repeated memory allocation\n                std::priority_queue<\n                    std::pair<float, std::pair<int, int>>,\n                    std::vector<std::pair<float, std::pair<int, int>>>,\n                    std::greater<>>\n                    top_importances;\n                for (int head_id = 0; head_id < config_.q_head_num; head_id++) {\n                    for (int k = 0; k < 
seq_len_; k++) {\n                        top_importances.push(std::make_pair(\n                            GGML_FP16_TO_FP32(\n                                importance_[layer_id][block_idx][k][head_id]),\n                            std::make_pair(block_idx, k)));\n                        // TODO: change to config_ item\n                        if (top_importances.size() > config_.anchor_num) {\n                            top_importances.pop();\n                        }\n                    }\n\n                    // fill anchor_\n\n                    for (int l = 0; l < config_.head_dim; l++) {\n                        anchor_[layer_id * config_.max_block_num *\n                                    config_.anchor_num * config_.q_head_num *\n                                    config_.head_dim +\n                                block_idx * config_.anchor_num *\n                                    config_.q_head_num * config_.head_dim +\n                                0 * config_.q_head_num * config_.head_dim +\n                                head_id * config_.head_dim + l] = 0;\n                    }\n                    for (int k = 0; k < config_.anchor_num; k++) {\n                        int top_indice = top_importances.top().second.second;\n                        int top_block_idx = top_importances.top().second.first;\n\n                        if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n\n                            for (int l = 0; l < config_.head_dim; l++) {\n                                anchor_[layer_id * config_.max_block_num *\n                                            config_.anchor_num *\n                                            config_.q_head_num *\n                                            config_.head_dim +\n                                        top_block_idx * config_.anchor_num *\n                                            config_.q_head_num *\n                                            config_.head_dim +\n           
                             0 * config_.q_head_num *\n                                            config_.head_dim +\n                                        head_id * config_.head_dim + l] =\n                                    GGML_FP32_TO_FP16(\n                                        GGML_FP16_TO_FP32(\n                                            anchor_[layer_id *\n                                                        config_.max_block_num *\n                                                        config_.anchor_num *\n                                                        config_.q_head_num *\n                                                        config_.head_dim +\n                                                    top_block_idx *\n                                                        config_.anchor_num *\n                                                        config_.q_head_num *\n                                                        config_.head_dim +\n                                                    0 * config_.q_head_num *\n                                                        config_.head_dim +\n                                                    head_id * config_.head_dim +\n                                                    l]) +\n                                        GGML_FP16_TO_FP32(\n                                            k_cache_fp16_[layer_id]\n                                                         [head_id / n_gqa_]\n                                                         [top_block_idx]\n                                                         [top_indice *\n                                                              config_.head_dim +\n                                                          l]));\n                            }\n\n                        } else if (config_.kv_type ==\n                                   ggml_type::GGML_TYPE_Q4_0) {\n                            for (int l = 0; l < config_.head_dim 
/ 32; l++) {\n                                block_q4_0 block = k_cache_q4\n                                    [layer_id][head_id / n_gqa_][top_block_idx]\n                                    [top_indice * config_.head_dim / 32 + l];\n                                dequantize_row_q4_0(&block, block_fp32.data(),\n                                                    32);\n                                for (int m = 0; m < 32; m++) {\n                                    anchor_[layer_id * config_.max_block_num *\n                                                config_.anchor_num *\n                                                config_.q_head_num *\n                                                config_.head_dim +\n                                            top_block_idx * config_.anchor_num *\n                                                config_.q_head_num *\n                                                config_.head_dim +\n                                            0 * config_.q_head_num *\n                                                config_.head_dim +\n                                            head_id * config_.head_dim +\n                                            l * 32 + m] =\n                                        GGML_FP32_TO_FP16(\n                                            block_fp32[m] / 4 +\n                                            GGML_FP16_TO_FP32(\n                                                anchor_[layer_id *\n                                                            config_\n                                                                .max_block_num *\n                                                            config_.anchor_num *\n                                                            config_.q_head_num *\n                                                            config_.head_dim +\n                                                        top_block_idx *\n                                                         
   config_.anchor_num *\n                                                            config_.q_head_num *\n                                                            config_.head_dim +\n                                                        0 * config_.q_head_num *\n                                                            config_.head_dim +\n                                                        head_id *\n                                                            config_.head_dim +\n                                                        l * 32 + m]));\n                                }\n                            }\n                        } else if (config_.kv_type ==\n                                   ggml_type::GGML_TYPE_Q8_0) {\n                            for (int l = 0; l < config_.head_dim / 32; l++) {\n                                block_q8_0 block = k_cache_q8\n                                    [layer_id][head_id / n_gqa_][top_block_idx]\n                                    [top_indice * config_.head_dim / 32 + l];\n                                dequantize_row_q8_0(&block, block_fp32.data(),\n                                                    32);\n                                for (int m = 0; m < 32; m++) {\n                                    anchor_[layer_id * config_.max_block_num *\n                                                config_.anchor_num *\n                                                config_.q_head_num *\n                                                config_.head_dim +\n                                            top_block_idx * config_.anchor_num *\n                                                config_.q_head_num *\n                                                config_.head_dim +\n                                            0 * config_.q_head_num *\n                                                config_.head_dim +\n                                            head_id * config_.head_dim +\n                      
                      l * 32 + m] =\n                                        GGML_FP32_TO_FP16(\n                                            block_fp32[m] / 4 +\n                                            GGML_FP16_TO_FP32(\n                                                anchor_[layer_id *\n                                                            config_\n                                                                .max_block_num *\n                                                            config_.anchor_num *\n                                                            config_.q_head_num *\n                                                            config_.head_dim +\n                                                        top_block_idx *\n                                                            config_.anchor_num *\n                                                            config_.q_head_num *\n                                                            config_.head_dim +\n                                                        0 * config_.q_head_num *\n                                                            config_.head_dim +\n                                                        head_id *\n                                                            config_.head_dim +\n                                                        l * 32 + m]));\n                                }\n                            }\n                        }\n                        top_importances.pop();\n                    }\n                }\n            } else if (config_.anchor_type == AnchorType::BLOCK_MEAN) {\n                // clear anchor_\n                for (int anchor_id = 0; anchor_id < config_.anchor_num;\n                     anchor_id++) {\n                    for (int head_id = 0; head_id < config_.q_head_num;\n                         head_id++) {\n                        for (int l = 0; l < config_.head_dim; l++) {\n                            
anchor_[layer_id * config_.max_block_num *\n                                        config_.anchor_num *\n                                        config_.q_head_num * config_.head_dim +\n                                    block_idx * config_.anchor_num *\n                                        config_.q_head_num * config_.head_dim +\n                                    anchor_id * config_.q_head_num *\n                                        config_.head_dim +\n                                    head_id * config_.head_dim + l] = 0;\n                        }\n                    }\n                }\n\n                // fill anchor_\n                if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n\n                    for (int head_id = 0; head_id < config_.q_head_num;\n                         head_id++) {\n                        for (int k = 0; k < config_.block_len; k++) {\n                            for (int l = 0; l < config_.head_dim; l++) {\n                                anchor_[layer_id * config_.max_block_num *\n                                            config_.anchor_num *\n                                            config_.q_head_num *\n                                            config_.head_dim +\n                                        block_idx * config_.anchor_num *\n                                            config_.q_head_num *\n                                            config_.head_dim +\n                                        0 * config_.q_head_num *\n                                            config_.head_dim +\n                                        head_id * config_.head_dim + l] =\n                                    GGML_FP32_TO_FP16(\n                                        GGML_FP16_TO_FP32(\n                                            anchor_[layer_id *\n                                                        config_.max_block_num *\n                                                        config_.anchor_num *\n       
                                                 config_.q_head_num *\n                                                        config_.head_dim +\n                                                    block_idx *\n                                                        config_.anchor_num *\n                                                        config_.q_head_num *\n                                                        config_.head_dim +\n                                                    0 * config_.q_head_num *\n                                                        config_.head_dim +\n                                                    head_id * config_.head_dim +\n                                                    l]) +\n                                        GGML_FP16_TO_FP32(\n                                            k_cache_fp16_[layer_id]\n                                                         [head_id / n_gqa_]\n                                                         [block_idx]\n                                                         [k * config_.head_dim +\n                                                          l]) /\n                                            config_.block_len);\n                            }\n                        }\n                    }\n                }\n            } else if (config_.anchor_type == AnchorType::BLOCK_MAX) {\n                // clear anchor_\n                for (int anchor_id = 0; anchor_id < config_.anchor_num;\n                     anchor_id++) {\n                    for (int head_id = 0; head_id < config_.q_head_num;\n                         head_id++) {\n                        for (int l = 0; l < config_.head_dim; l++) {\n                            anchor_[layer_id * config_.max_block_num *\n                                        config_.anchor_num *\n                                        config_.q_head_num * config_.head_dim +\n                                    block_idx * 
config_.anchor_num *\n                                        config_.q_head_num * config_.head_dim +\n                                    anchor_id * config_.q_head_num *\n                                        config_.head_dim +\n                                    head_id * config_.head_dim + l] = 0;\n                        }\n                    }\n                }\n\n                // fill anchor_\n                if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n\n                    for (int head_id = 0; head_id < config_.q_head_num;\n                         head_id++) {\n                        for (int k = 0; k < config_.block_len; k++) {\n                            for (int l = 0; l < config_.head_dim; l++) {\n                                anchor_[layer_id * config_.max_block_num *\n                                            config_.anchor_num *\n                                            config_.q_head_num *\n                                            config_.head_dim +\n                                        block_idx * config_.anchor_num *\n                                            config_.q_head_num *\n                                            config_.head_dim +\n                                        0 * config_.q_head_num *\n                                            config_.head_dim +\n                                        head_id * config_.head_dim + l] =\n                                    GGML_FP32_TO_FP16(std::max(\n                                        GGML_FP16_TO_FP32(\n                                            anchor_[layer_id *\n                                                        config_.max_block_num *\n                                                        config_.anchor_num *\n                                                        config_.q_head_num *\n                                                        config_.head_dim +\n                                                    block_idx *\n              
                                          config_.anchor_num *\n                                                        config_.q_head_num *\n                                                        config_.head_dim +\n                                                    0 * config_.q_head_num *\n                                                        config_.head_dim +\n                                                    head_id * config_.head_dim +\n                                                    l]),\n                                        GGML_FP16_TO_FP32(\n                                            k_cache_fp16_\n                                                [layer_id][head_id / n_gqa_]\n                                                [block_idx]\n                                                [k * config_.head_dim + l])));\n                            }\n                        }\n                    }\n                }\n            } else if (config_.anchor_type == AnchorType::FIXED_ANCHOR) {\n                // clear anchor_\n                for (int anchor_id = 0; anchor_id < 1; anchor_id++) {\n                    for (int head_id = 0; head_id < config_.q_head_num;\n                         head_id++) {\n                        for (int l = 0; l < config_.head_dim; l++) {\n                            anchor_[layer_id * config_.max_block_num *\n                                        config_.anchor_num *\n                                        config_.q_head_num * config_.head_dim +\n                                    block_idx * config_.anchor_num *\n                                        config_.q_head_num * config_.head_dim +\n                                    anchor_id * config_.q_head_num *\n                                        config_.head_dim +\n                                    head_id * config_.head_dim + l] = 0;\n                        }\n                    }\n                }\n\n                // fill anchor_\n              
  if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n\n                    int stride = config_.block_len / config_.anchor_num;\n                    for (int head_id = 0; head_id < config_.q_head_num;\n                         head_id++) {\n                        // Both bounds must hold; a comma here would discard\n                        // the first comparison.\n                        for (int k = 0, tot = 0;\n                             k < config_.block_len && tot < config_.anchor_num;\n                             k += stride, tot++) {\n                            for (int l = 0; l < config_.head_dim; l++) {\n                                anchor_[layer_id * config_.max_block_num *\n                                            config_.anchor_num *\n                                            config_.q_head_num *\n                                            config_.head_dim +\n                                        block_idx * config_.anchor_num *\n                                            config_.q_head_num *\n                                            config_.head_dim +\n                                        0 * config_.q_head_num *\n                                            config_.head_dim +\n                                        head_id * config_.head_dim + l] =\n                                    GGML_FP32_TO_FP16(\n                                        GGML_FP16_TO_FP32(\n                                            anchor_[layer_id *\n                                                        config_.max_block_num *\n                                                        config_.anchor_num *\n                                                        config_.q_head_num *\n                                                        config_.head_dim +\n                                                    block_idx *\n                                                        config_.anchor_num *\n                                                        config_.q_head_num *\n                                                        config_.head_dim +\n                                
                    0 * config_.q_head_num *\n                                                        config_.head_dim +\n                                                    head_id * config_.head_dim +\n                                                    l]) +\n                                        GGML_FP16_TO_FP32(\n                                            k_cache_fp16_[layer_id]\n                                                         [head_id / n_gqa_]\n                                                         [block_idx]\n                                                         [k * config_.head_dim +\n                                                          l]) /\n                                            config_.anchor_num);\n                            }\n                        }\n                    }\n                }\n\n            } else if (config_.anchor_type == AnchorType::QUEST) {\n                // clear anchor_\n                for (int head_id = 0; head_id < config_.q_head_num; head_id++) {\n                    for (int l = 0; l < config_.head_dim; l++) {\n                        anchor_[layer_id * config_.max_block_num *\n                                    config_.anchor_num * config_.q_head_num *\n                                    config_.head_dim +\n                                block_idx * config_.anchor_num *\n                                    config_.q_head_num * config_.head_dim +\n                                1 * config_.q_head_num * config_.head_dim +\n                                head_id * config_.head_dim + l] =\n                            GGML_FP32_TO_FP16(\n                                std::numeric_limits<float>::max());\n\n                        anchor_[layer_id * config_.max_block_num *\n                                    config_.anchor_num * config_.q_head_num *\n                                    config_.head_dim +\n                                block_idx * config_.anchor_num *\n                  
                  config_.q_head_num * config_.head_dim +\n                                0 * config_.q_head_num * config_.head_dim +\n                                head_id * config_.head_dim + l] =\n                            // Seed the running max with lowest(): min() is the\n                            // smallest positive float, not the most negative.\n                            GGML_FP32_TO_FP16(\n                                std::numeric_limits<float>::lowest());\n                    }\n                }\n\n                // fill anchor_\n\n                if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n                    for (int indice = 0; indice < seq_len_; indice++) {\n                        for (int head_id = 0; head_id < config_.kv_head_num;\n                             head_id++) {\n                            for (int l = 0; l < config_.head_dim; l++) {\n                                anchor_[layer_id * config_.max_block_num *\n                                            config_.anchor_num *\n                                            config_.q_head_num *\n                                            config_.head_dim +\n                                        block_idx * config_.anchor_num *\n                                            config_.q_head_num *\n                                            config_.head_dim +\n                                        0 * config_.q_head_num *\n                                            config_.head_dim +\n                                        head_id * config_.head_dim + l] =\n                                    GGML_FP32_TO_FP16(std::max(\n                                        GGML_FP16_TO_FP32(\n                                            k_cache_fp16_\n                                                [layer_id][head_id][block_idx]\n                                                [indice * config_.head_dim +\n                                                 l]),\n                                        GGML_FP16_TO_FP32(\n                                            anchor_[layer_id *\n                                                        
config_.max_block_num *\n                                                        config_.anchor_num *\n                                                        config_.q_head_num *\n                                                        config_.head_dim +\n                                                    block_idx *\n                                                        config_.anchor_num *\n                                                        config_.q_head_num *\n                                                        config_.head_dim +\n                                                    0 * config_.q_head_num *\n                                                        config_.head_dim +\n                                                    head_id * config_.head_dim +\n                                                    l])));\n\n                                anchor_[layer_id * config_.max_block_num *\n                                            config_.anchor_num *\n                                            config_.q_head_num *\n                                            config_.head_dim +\n                                        block_idx * config_.anchor_num *\n                                            config_.q_head_num *\n                                            config_.head_dim +\n                                        1 * config_.q_head_num *\n                                            config_.head_dim +\n                                        head_id * config_.head_dim + l] =\n                                    GGML_FP32_TO_FP16(std::min(\n                                        GGML_FP16_TO_FP32(\n                                            k_cache_fp16_\n                                                [layer_id][head_id][block_idx]\n                                                [indice * config_.head_dim +\n                                                 l]),\n                                        GGML_FP16_TO_FP32(\n          
                                  anchor_[layer_id *\n                                                        config_.max_block_num *\n                                                        config_.anchor_num *\n                                                        config_.q_head_num *\n                                                        config_.head_dim +\n                                                    block_idx *\n                                                        config_.anchor_num *\n                                                        config_.q_head_num *\n                                                        config_.head_dim +\n                                                    1 * config_.q_head_num *\n                                                        config_.head_dim +\n                                                    head_id * config_.head_dim +\n                                                    l])));\n                            }\n                        }\n                    }\n\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n                    for (int indice = 0; indice < seq_len_; indice++) {\n                        for (int head_id = 0; head_id < config_.kv_head_num;\n                             head_id++) {\n                            for (int l = 0; l < config_.head_dim / 32; l++) {\n                                block_q4_0 block =\n                                    k_cache_q4[layer_id][head_id][block_idx]\n                                              [indice * config_.head_dim / 32 +\n                                               l];\n                                dequantize_row_q4_0(&block, block_fp32.data(),\n                                                    32);\n\n                                for (int m = 0; m < 32; m++) {\n                                    for (int gqa_idx = 0; gqa_idx < n_gqa_;\n                                         gqa_idx++) {\n\n       
                                 anchor_[layer_id *\n                                                    config_.max_block_num *\n                                                    config_.anchor_num *\n                                                    config_.q_head_num *\n                                                    config_.head_dim +\n                                                block_idx * config_.anchor_num *\n                                                    config_.q_head_num *\n                                                    config_.head_dim +\n                                                0 * config_.q_head_num *\n                                                    config_.head_dim +\n                                                head_id * config_.head_dim +\n                                                l * 32 + m] =\n                                            GGML_FP32_TO_FP16(std::max(\n                                                block_fp32[m],\n                                                GGML_FP16_TO_FP32(\n                                                    anchor_\n                                                        [layer_id *\n                                                             config_\n                                                                 .max_block_num *\n                                                             config_\n                                                                 .anchor_num *\n                                                             config_\n                                                                 .q_head_num *\n                                                             config_.head_dim +\n                                                         block_idx *\n                                                             config_\n                                                                 .anchor_num *\n                                              
               config_\n                                                                 .q_head_num *\n                                                             config_.head_dim +\n                                                         0 *\n                                                             config_\n                                                                 .q_head_num *\n                                                             config_.head_dim +\n                                                         head_id *\n                                                             config_.head_dim +\n                                                         l * 32 + m])));\n\n                                        anchor_[layer_id *\n                                                    config_.max_block_num *\n                                                    config_.anchor_num *\n                                                    config_.q_head_num *\n                                                    config_.head_dim +\n                                                block_idx * config_.anchor_num *\n                                                    config_.q_head_num *\n                                                    config_.head_dim +\n                                                1 * config_.q_head_num *\n                                                    config_.head_dim +\n                                                head_id * config_.head_dim +\n                                                l * 32 + m] =\n                                            GGML_FP32_TO_FP16(std::min(\n                                                block_fp32[m],\n                                                GGML_FP16_TO_FP32(\n                                                    anchor_\n                                                        [layer_id *\n                                                             config_\n                    
                                             .max_block_num *\n                                                             config_\n                                                                 .anchor_num *\n                                                             config_\n                                                                 .q_head_num *\n                                                             config_.head_dim +\n                                                         block_idx *\n                                                             config_\n                                                                 .anchor_num *\n                                                             config_\n                                                                 .q_head_num *\n                                                             config_.head_dim +\n                                                         1 *\n                                                             config_\n                                                                 .q_head_num *\n                                                             config_.head_dim +\n                                                         head_id *\n                                                             config_.head_dim +\n                                                         l * 32 + m])));\n                                    }\n                                }\n                            }\n                        }\n                    }\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n                    for (int indice = 0; indice < seq_len_; indice++) {\n                        for (int head_id = 0; head_id < config_.kv_head_num;\n                             head_id++) {\n                            for (int l = 0; l < config_.head_dim / 32; l++) {\n                                block_q8_0 block =\n                                 
   k_cache_q8[layer_id][head_id][block_idx]\n                                              [indice * config_.head_dim / 32 +\n                                               l];\n                                dequantize_row_q8_0(&block, block_fp32.data(),\n                                                    32);\n\n                                for (int m = 0; m < 32; m++) {\n                                    for (int gqa_idx = 0; gqa_idx < n_gqa_;\n                                         gqa_idx++) {\n\n                                        anchor_[layer_id *\n                                                    config_.max_block_num *\n                                                    config_.anchor_num *\n                                                    config_.q_head_num *\n                                                    config_.head_dim +\n                                                block_idx * config_.anchor_num *\n                                                    config_.q_head_num *\n                                                    config_.head_dim +\n                                                0 * config_.q_head_num *\n                                                    config_.head_dim +\n                                                head_id * config_.head_dim +\n                                                l * 32 + m] =\n                                            GGML_FP32_TO_FP16(std::max(\n                                                block_fp32[m],\n                                                GGML_FP16_TO_FP32(\n                                                    anchor_\n                                                        [layer_id *\n                                                             config_\n                                                                 .max_block_num *\n                                                             config_\n                                                      
           .anchor_num *\n                                                             config_\n                                                                 .q_head_num *\n                                                             config_.head_dim +\n                                                         block_idx *\n                                                             config_\n                                                                 .anchor_num *\n                                                             config_\n                                                                 .q_head_num *\n                                                             config_.head_dim +\n                                                         0 *\n                                                             config_\n                                                                 .q_head_num *\n                                                             config_.head_dim +\n                                                         head_id *\n                                                             config_.head_dim +\n                                                         l * 32 + m])));\n\n                                        anchor_[layer_id *\n                                                    config_.max_block_num *\n                                                    config_.anchor_num *\n                                                    config_.q_head_num *\n                                                    config_.head_dim +\n                                                block_idx * config_.anchor_num *\n                                                    config_.q_head_num *\n                                                    config_.head_dim +\n                                                1 * config_.q_head_num *\n                                                    config_.head_dim +\n                                           
     head_id * config_.head_dim +\n                                                l * 32 + m] =\n                                            GGML_FP32_TO_FP16(std::min(\n                                                block_fp32[m],\n                                                GGML_FP16_TO_FP32(\n                                                    anchor_\n                                                        [layer_id *\n                                                             config_\n                                                                 .max_block_num *\n                                                             config_\n                                                                 .anchor_num *\n                                                             config_\n                                                                 .q_head_num *\n                                                             config_.head_dim +\n                                                         block_idx *\n                                                             config_\n                                                                 .anchor_num *\n                                                             config_\n                                                                 .q_head_num *\n                                                             config_.head_dim +\n                                                         1 *\n                                                             config_\n                                                                 .q_head_num *\n                                                             config_.head_dim +\n                                                         head_id *\n                                                             config_.head_dim +\n                                                         l * 32 + m])));\n                                    }\n                           
     }\n                            }\n                        }\n                    }\n                }\n            } else {\n                assert(false);\n            }\n        },\n        nullptr);\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> duration = end - start;\n    //    printf(\"time of calc_anchor_all_layers: %f s\\n\", duration.count());\n}\n\nvoid KVCache::clear_importance_all_layers(int *block_table, int *cache_seqlens,\n                                          int batch_size, int max_block_num,\n                                          Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    // Each task updates the importance of a certain block\n    seq_len_ = config_.block_len;\n    backend->do_work_stealing_job(\n        config_.layer_num * batch_size * max_block_num, nullptr,\n        [&](int task_id) {\n            int layer_id = task_id / (batch_size * max_block_num);\n            int batch_id = (task_id / max_block_num) % batch_size;\n            int block_id = task_id % max_block_num;\n            // If the block is out of the sequence length, skip it. 
In\n            // particular, the last block of the sequence that is shorter than\n            // the block length should be skipped.\n\n            if (cache_seqlens[batch_id] / config_.block_len < block_id) {\n                return;\n            }\n            int block_idx = block_table[batch_id * max_block_num + block_id];\n\n            if (config_.anchor_type == AnchorType::DYNAMIC) {\n\n                // clear importance_\n                for (int head_id = 0; head_id < config_.q_head_num; head_id++) {\n                    for (int l = 0; l < config_.block_len; l++) {\n                        importance_[layer_id][block_idx][l][head_id] = 0;\n                    }\n                }\n            }\n        },\n        nullptr);\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> duration = end - start;\n    //    printf(\"time of clear_importance_all_layers: %f s\\n\",\n    //    duration.count());\n}\n\nvoid KVCache::clear_kvcache_all_layers(int *block_table, int *cache_seqlens,\n                                       int batch_size, int max_block_num,\n                                       Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    // Each task clears the K/V cache of one (block, kv head) pair\n    seq_len_ = config_.block_len;\n    backend->do_work_stealing_job(\n        config_.layer_num * batch_size * max_block_num * config_.kv_head_num,\n        nullptr,\n        [&](int task_id) {\n            int layer_id =\n                task_id / (batch_size * max_block_num * config_.kv_head_num);\n            int batch_id =\n                (task_id / (max_block_num * config_.kv_head_num)) % batch_size;\n            int block_id = task_id / config_.kv_head_num % max_block_num;\n            int head_id = task_id % config_.kv_head_num;\n            // If the block is out of the sequence length, skip it. 
In\n            // particular, the last block of the sequence that is shorter than\n            // the block length should be skipped.\n            if (cache_seqlens[batch_id] / config_.block_len < block_id) {\n                return;\n            }\n            int block_idx = block_table[batch_id * max_block_num + block_id];\n\n            if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n                for (int l = 0; l < config_.block_len * config_.head_dim; l++) {\n                    k_cache_fp16_[layer_id][head_id][block_idx][l] = 0;\n                    v_cache_fp16_[layer_id][head_id][block_idx][l] = 0;\n                }\n            } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n                for (int l = 0; l < config_.block_len * config_.head_dim / 32;\n                     l++) {\n                    k_cache_q4[layer_id][head_id][block_idx][l].d = 0;\n                    v_cache_q4[layer_id][head_id][block_idx][l].d = 0;\n                }\n            } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n                for (int l = 0; l < config_.block_len * config_.head_dim / 32;\n                     l++) {\n                    k_cache_q8[layer_id][head_id][block_idx][l].d = 0;\n                    v_cache_q8[layer_id][head_id][block_idx][l].d = 0;\n                }\n            }\n        },\n        nullptr);\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> duration = end - start;\n    //    printf(\"time of clear_kvcache_all_layers: %f s\\n\", duration.count());\n}\n\nvoid KVCache::get_sincos(ggml_fp16_t *sin, ggml_fp16_t *cos, int seqlen) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    const uint16_t *sin_data = const_cast<const uint16_t *>(sin);\n    const uint16_t *cos_data = const_cast<const uint16_t *>(cos);\n\n    for (int i = 0; i < seqlen; i++) {\n        for (int j = 0; j < config_.head_dim; j++) {\n       
     sin_[i][j] = sin_data[i * config_.head_dim + j];\n            cos_[i][j] = cos_data[i * config_.head_dim + j];\n        }\n    }\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> duration = end - start;\n    printf(\"time of get_sincos: %f s\\n\", duration.count());\n}\n\nvoid ggml_vec_scale_f32(const int n, float *y, const float v) {\n#if defined(GGML_USE_ACCELERATE)\n    vDSP_vsmul(y, 1, &v, y, 1, n);\n#elif defined(GGML_SIMD)\n    const int np = (n & ~(GGML_F32_STEP - 1));\n\n    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);\n\n    GGML_F32_VEC ay[GGML_F32_ARR];\n\n    for (int i = 0; i < np; i += GGML_F32_STEP) {\n        for (int j = 0; j < GGML_F32_ARR; j++) {\n            ay[j] = GGML_F32_VEC_LOAD(y + i + j * GGML_F32_EPR);\n            ay[j] = GGML_F32_VEC_MUL(ay[j], vx);\n\n            GGML_F32_VEC_STORE(y + i + j * GGML_F32_EPR, ay[j]);\n        }\n    }\n\n    // leftovers\n    for (int i = np; i < n; ++i) {\n        y[i] *= v;\n    }\n#else\n    // scalar\n    for (int i = 0; i < n; ++i) {\n        y[i] *= v;\n    }\n#endif\n}"
  },
  {
    "path": "archive/csrc/ktransformers_ext/operators/llamafile/conversion.h",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-12 10:07:58\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2024-07-25 10:34:55\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#ifndef CPUINFER_CONVERSION_H\n#define CPUINFER_CONVERSION_H\n\n#include <cstring>\n#include \"llama.cpp/ggml.h\"\n\n// Convert `size` elements of `type` at `input` to fp32 in `output`.\n// F32 input is copied as-is; other types go through ggml's type traits.\ninline void to_float(const void* input, float* output, int size, ggml_type type) {\n    if (type == ggml_type::GGML_TYPE_F32) {\n        std::memcpy(output, input, size * sizeof(float));\n    } else {\n        ggml_internal_get_type_traits(type).to_float(input, output, size);\n    }\n}\n\n// Convert `size` fp32 elements at `input` to `type` in `output`.\n// F32 output is copied as-is; other types go through ggml's type traits.\ninline void from_float(const float* input, void* output, int size, ggml_type type) {\n    if (type == ggml_type::GGML_TYPE_F32) {\n        std::memcpy(output, input, size * sizeof(float));\n    } else {\n        ggml_internal_get_type_traits(type).from_float(input, output, size);\n    }\n}\n\n#endif"
  },
  {
    "path": "archive/csrc/ktransformers_ext/operators/llamafile/linear.cpp",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-12 10:07:58\n * @Version      : 1.0.0\n * @LastEditors  : kkk1nak0\n * @LastEditTime : 2024-08-15 07:45:18\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#include \"linear.h\"\n\nLinear::Linear(LinearConfig config) {\n    config_ = config;\n    proj_ = config_.proj;\n\n    std::vector<std::pair<void**, uint64_t>> mem_requests;\n    mem_requests.push_back({(void**)&input_fp32_, sizeof(float) * config_.group_max_len * config_.input_size});\n    mem_requests.push_back({(void**)&proj_input_, config_.group_max_len * config_.input_size * ggml_type_size(ggml_internal_get_type_traits(config_.proj_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.proj_type).vec_dot_type)});\n    mem_requests.push_back({(void**)&proj_output_, sizeof(float) * config_.group_max_len * config_.output_size});\n    shared_mem_buffer.alloc(this, mem_requests);\n}\n\nLinear::~Linear() {\n    shared_mem_buffer.dealloc(this);\n}\n\nvoid Linear::warm_up(Backend *backend) {\n    std::vector<float> input_fp32(config_.input_size);\n    std::vector<uint8_t> input(config_.input_size *\n                               ggml_type_size(config_.hidden_type) /\n                               ggml_blck_size(config_.hidden_type));\n    std::vector<uint8_t> output(config_.output_size *\n                                ggml_type_size(config_.hidden_type) /\n                                ggml_blck_size(config_.hidden_type));\n    for (int i = 0; i < config_.input_size; i++) {\n        input_fp32[i] = 0;\n    }\n    from_float(input_fp32.data(), input.data(), config_.input_size, config_.hidden_type);\n    forward_many(1, input.data(), output.data(), backend);\n}\n\nvoid Linear::forward_many(int qlen, const void* input, void* output, Backend* backend) {\n    const void* proj_input_ptr;\n    if (config_.hidden_type == 
ggml_internal_get_type_traits(config_.proj_type).vec_dot_type) {\n        proj_input_ptr = input;\n    } else {\n        to_float(input, input_fp32_, qlen * config_.input_size, config_.hidden_type);\n        from_float(input_fp32_, proj_input_, qlen * config_.input_size, ggml_internal_get_type_traits(config_.proj_type).vec_dot_type);\n        proj_input_ptr = proj_input_;\n    }\n    int nth = config_.output_size / config_.stride;\n    backend->do_work_stealing_job(nth, nullptr, [&](int task_id) {\n        int ith = task_id;\n        void* proj_ptr = (uint8_t*)proj_ + ith * config_.stride * config_.input_size * ggml_type_size(config_.proj_type) / ggml_blck_size(config_.proj_type);\n        float* proj_output_ptr = proj_output_ + ith * config_.stride;\n        llamafile_sgemm(config_.stride, qlen, config_.input_size / ggml_blck_size(config_.proj_type), proj_ptr, config_.input_size / ggml_blck_size(config_.proj_type), proj_input_ptr, config_.input_size / ggml_blck_size(config_.proj_type), proj_output_ptr, config_.output_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.proj_type, ggml_internal_get_type_traits(config_.proj_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n        if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) {\n            for (int i = 0; i < qlen; i++) {\n                float* output_fp32_ptr = proj_output_ + i * config_.output_size + ith * config_.stride;\n                void* output_ptr = (uint8_t*)output + i * config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type) + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);\n                from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type);\n            }\n        }\n    }, nullptr);\n    if (config_.stride % ggml_blck_size(config_.hidden_type) != 0) {\n        from_float(proj_output_, output, qlen * config_.output_size, config_.hidden_type);\n    }\n}\n\nvoid 
Linear::forward(int qlen, const void* input, void* output, Backend* backend) {\n    if (qlen <= 0) {\n        return;\n    }\n    int forward_len = std::min(qlen, config_.group_max_len);\n    forward_many(forward_len, input, output, backend);\n    forward(qlen - forward_len, (uint8_t*)input + forward_len * config_.input_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + forward_len * config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend);\n}"
  },
  {
    "path": "archive/csrc/ktransformers_ext/operators/llamafile/linear.h",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-12 10:07:58\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2024-07-25 10:35:00\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#ifndef CPUINFER_OPERATOR_LINEAR_H\n#define CPUINFER_OPERATOR_LINEAR_H\n\n#include <cmath>\n#include <cstdio>\n#include <functional>\n#include <mutex>\n#include <vector>\n\n#include \"../../cpu_backend/backend.h\"\n#include \"../../cpu_backend/shared_mem_buffer.h\"\n#include \"conversion.h\"\n#include \"llama.cpp/ggml-impl.h\"\n#include \"llama.cpp/ggml-quants.h\"\n#include \"llama.cpp/ggml.h\"\n#include \"llamafile/sgemm.h\"\n\nstruct LinearConfig {\n    int input_size;\n    int output_size;\n    int stride;\n    int group_max_len;\n    void* proj;\n    ggml_type proj_type;\n    ggml_type hidden_type;\n\n    LinearConfig() {}\n\n    LinearConfig(int input_size, int output_size, int stride, int group_max_len, void* proj, ggml_type proj_type, ggml_type hidden_type)\n        : input_size(input_size), output_size(output_size), stride(stride), group_max_len(group_max_len), proj(proj), proj_type(proj_type), hidden_type(hidden_type) {}\n};\n\nclass Linear {\n   public:\n    Linear(LinearConfig);\n    ~Linear();\n    void warm_up(Backend* backend);\n    void forward_many(int qlen, const void* input, void* output, Backend* backend);\n    void forward(int qlen, const void* input, void* output, Backend* backend);\n\n   private:\n    LinearConfig config_;\n    void* proj_;  // [output_size * input_size ( /32 if quantized)]\n\n    float* input_fp32_;    // [group_max_len * input_size]\n    uint8_t* proj_input_;  // [group_max_len * input_size * ggml_type_size(ggml_internal_get_type_traits(proj_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(proj_type).vec_dot_type)]\n    float* proj_output_;   // [group_max_len * output_size]\n};\n\n#endif"
  },
  {
    "path": "archive/csrc/ktransformers_ext/operators/llamafile/mlp.cpp",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-16 10:43:18\n * @Version      : 1.0.0\n * @LastEditors  : kkk1nak0\n * @LastEditTime : 2024-08-15 07:44:38\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#include \"mlp.h\"\n\nMLP::MLP(MLPConfig config) {\n    config_ = config;\n    gate_proj_ = config_.gate_proj;\n    up_proj_ = config_.up_proj;\n    down_proj_ = config_.down_proj;\n\n    std::vector<std::pair<void**, uint64_t>> mem_requests;\n    mem_requests.push_back({(void**)&input_fp32_, sizeof(float) * config_.group_max_len * config_.hidden_size});\n    mem_requests.push_back({(void**)&gate_input_, config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type)});\n    mem_requests.push_back({(void**)&up_input_, config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type)});\n    mem_requests.push_back({(void**)&gate_output_, sizeof(float) * config_.group_max_len * config_.intermediate_size});\n    mem_requests.push_back({(void**)&up_output_, sizeof(float) * config_.group_max_len * config_.intermediate_size});\n    mem_requests.push_back({(void**)&intermediate_fp32_, sizeof(float) * config_.group_max_len * config_.intermediate_size});\n    mem_requests.push_back({(void**)&down_input_, config_.group_max_len * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type)});\n    mem_requests.push_back({(void**)&down_output_, sizeof(float) * config_.group_max_len * config_.hidden_size});\n    shared_mem_buffer.alloc(this, mem_requests);\n}\n\nMLP::~MLP() {\n    
shared_mem_buffer.dealloc(this);\n}\n\nvoid MLP::warm_up(Backend *backend) {\n    std::vector<float> input_fp32(config_.hidden_size);\n    std::vector<uint8_t> input(config_.hidden_size *\n                               ggml_type_size(config_.hidden_type) /\n                               ggml_blck_size(config_.hidden_type));\n    std::vector<uint8_t> output(config_.hidden_size *\n                                ggml_type_size(config_.hidden_type) /\n                                ggml_blck_size(config_.hidden_type));\n    for (int i = 0; i < config_.hidden_size; i++) {\n        input_fp32[i] = 0;\n    }\n    from_float(input_fp32.data(), input.data(), config_.hidden_size, config_.hidden_type);\n    forward_many(1, input.data(), output.data(), backend);\n}\n\nstatic float act_fn(float x) { return x / (1.0f + expf(-x)); }\n\nvoid MLP::forward_many(int qlen, const void* input, void* output, Backend* backend) {\n    const void* gate_input_ptr;\n    const void* up_input_ptr;\n    if (config_.hidden_type == ggml_internal_get_type_traits(config_.gate_type).vec_dot_type && config_.hidden_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {\n        gate_input_ptr = up_input_ptr = input;\n    } else {\n        to_float(input, input_fp32_, qlen * config_.hidden_size, config_.hidden_type);\n        if (ggml_internal_get_type_traits(config_.gate_type).vec_dot_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {\n            from_float(input_fp32_, gate_input_, qlen * config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);\n            gate_input_ptr = up_input_ptr = gate_input_;\n        } else {\n            if (config_.hidden_type != ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) {\n                from_float(input_fp32_, gate_input_, qlen * config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);\n                gate_input_ptr = gate_input_;\n            } else 
{\n                gate_input_ptr = input;\n            }\n            if (config_.hidden_type != ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {\n                from_float(input_fp32_, up_input_, qlen * config_.hidden_size, ggml_internal_get_type_traits(config_.up_type).vec_dot_type);\n                up_input_ptr = up_input_;\n            } else {\n                up_input_ptr = input;\n            }\n        }\n    }\n    int nth = config_.intermediate_size / config_.stride;\n    backend->do_work_stealing_job(nth, nullptr, [&](int task_id) {\n        int ith = task_id;\n        void* gate_proj_ptr = (uint8_t*)gate_proj_ + ith * config_.stride * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);\n        float* gate_output_ptr = gate_output_ + ith * config_.stride;\n        llamafile_sgemm(config_.stride, qlen, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n        void* up_proj_ptr = (uint8_t*)up_proj_ + ith * config_.stride * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);\n        float* up_output_ptr = up_output_ + ith * config_.stride;\n        llamafile_sgemm(config_.stride, qlen, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n        for (int i = 0; i < qlen; i++) {\n            for (int j = 
ith * config_.stride; j < (ith + 1) * config_.stride; j++) {\n                intermediate_fp32_[i * config_.intermediate_size + j] = act_fn(gate_output_[i * config_.intermediate_size + j]) * up_output_[i * config_.intermediate_size + j];\n            }\n            if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) == 0) {\n                float* intermediate_fp32_ptr = intermediate_fp32_ + i * config_.intermediate_size + ith * config_.stride;\n                void* down_input_ptr = (uint8_t*)down_input_ + i * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) + ith * config_.stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);\n                from_float(intermediate_fp32_ptr, down_input_ptr, config_.stride, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);\n            }\n        }\n    }, nullptr);\n    if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) != 0) {\n        from_float(intermediate_fp32_, down_input_, qlen * config_.intermediate_size, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);\n    }\n    nth = config_.hidden_size / config_.stride;\n    backend->do_work_stealing_job(nth, nullptr, [&](int task_id) {\n        int ith = task_id;\n        void* down_proj_ptr = (uint8_t*)down_proj_ + ith * config_.stride * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);\n        float* down_output_ptr = down_output_ + ith * config_.stride;\n        llamafile_sgemm(config_.stride, qlen, config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_, 
config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n        if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) {\n            for (int i = 0; i < qlen; i++) {\n                float* output_fp32_ptr = down_output_ + i * config_.hidden_size + ith * config_.stride;\n                void* output_ptr = (uint8_t*)output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type) + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);\n                from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type);\n            }\n        }\n    }, nullptr);\n    if (config_.stride % ggml_blck_size(config_.hidden_type) != 0) {\n        from_float(down_output_, output, qlen * config_.hidden_size, config_.hidden_type);\n    }\n}\n\nvoid MLP::forward(int qlen, const void* input, void* output, Backend* backend) {\n    if (qlen <= 0) {\n        return;\n    }\n    int forward_len = std::min(qlen, config_.group_max_len);\n    forward_many(forward_len, input, output, backend);\n    forward(qlen - forward_len, (uint8_t*)input + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend);\n}"
  },
  {
    "path": "archive/csrc/ktransformers_ext/operators/llamafile/mlp.h",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-12 10:07:58\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2024-07-25 10:35:06\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#ifndef CPUINFER_OPERATOR_MLP_H\n#define CPUINFER_OPERATOR_MLP_H\n\n#include <cmath>\n#include <cstdio>\n#include <functional>\n#include <mutex>\n#include <vector>\n\n#include \"../../cpu_backend/backend.h\"\n#include \"../../cpu_backend/shared_mem_buffer.h\"\n#include \"conversion.h\"\n#include \"llama.cpp/ggml-impl.h\"\n#include \"llama.cpp/ggml-quants.h\"\n#include \"llama.cpp/ggml.h\"\n#include \"llamafile/sgemm.h\"\n\nstruct MLPConfig {\n    int hidden_size;\n    int intermediate_size;\n    int stride;\n    int group_max_len;\n    void* gate_proj;\n    void* up_proj;\n    void* down_proj;\n    ggml_type gate_type;\n    ggml_type up_type;\n    ggml_type down_type;\n    ggml_type hidden_type;\n\n    MLPConfig() {}\n\n    MLPConfig(int hidden_size, int intermediate_size, int stride, int group_max_len, void* gate_proj, void* up_proj, void* down_proj, ggml_type gate_type, ggml_type up_type, ggml_type down_type, ggml_type hidden_type)\n        : hidden_size(hidden_size), intermediate_size(intermediate_size), stride(stride), group_max_len(group_max_len), gate_proj(gate_proj), up_proj(up_proj), down_proj(down_proj), gate_type(gate_type), up_type(up_type), down_type(down_type), hidden_type(hidden_type) {}\n};\n\nclass MLP {\n   public:\n    MLP(MLPConfig);\n    ~MLP();\n    void warm_up(Backend* backend);\n    void forward_many(int qlen, const void* input, void* output, Backend* backend);\n    void forward(int qlen, const void* input, void* output, Backend* backend);\n\n   private:\n    MLPConfig config_;\n    void* gate_proj_;  // [intermediate_size * hidden_size ( /32 if quantized)]\n    void* up_proj_;    // [intermediate_size * hidden_size ( /32 if quantized)]\n    void* down_proj_;  // 
[hidden_size * intermediate_size ( /32 if quantized)]\n\n    float* input_fp32_;         // [group_max_len * hidden_size]\n    uint8_t* gate_input_;       // [group_max_len * hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)]\n    uint8_t* up_input_;         // [group_max_len * hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)]\n    float* gate_output_;        // [group_max_len * intermediate_size]\n    float* up_output_;          // [group_max_len * intermediate_size]\n    float* intermediate_fp32_;  // [group_max_len * intermediate_size]\n    uint8_t* down_input_;       // [group_max_len * intermediate_size * ggml_type_size(ggml_internal_get_type_traits(down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(down_type).vec_dot_type)]\n    float* down_output_;        // [group_max_len * hidden_size]\n};\n\n#endif"
  },
  {
    "path": "archive/csrc/ktransformers_ext/operators/llamafile/moe.cpp",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-22 02:03:22\n * @Version      : 1.0.0\n * @LastEditors  : kkk1nak0\n * @LastEditTime : 2024-08-15 07:43:41\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#include \"moe.h\"\n#include <iostream>\n#include <cstdint>\n#include <cstring>\n#include <math.h>\n\n#ifdef USE_NUMA\n#include <numa.h>\n#include <numaif.h>\n#endif\n\nMOE::MOE(MOEConfig config) {\n    config_ = config;\n    gate_proj_ = config_.gate_proj;\n    up_proj_ = config_.up_proj;\n    down_proj_ = config_.down_proj;\n    \n    #ifdef USE_NUMA\n    int numa_nodes = numa_num_configured_nodes();\n    gate_proj_numa_.resize(numa_nodes);\n    up_proj_numa_.resize(numa_nodes);\n    down_proj_numa_.resize(numa_nodes);\n    size_t exp_inter_hidden_mul_ = (size_t)config.expert_num * config.intermediate_size * config.hidden_size;\n    for (int i = 0; i < numa_nodes; i++) {\n        gate_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_* ggml_type_size(config.gate_type) / ggml_blck_size(config.gate_type), i);\n        up_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_* ggml_type_size(config.up_type) / ggml_blck_size(config.up_type), i);\n        down_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_* ggml_type_size(config.down_type) / ggml_blck_size(config.down_type), i);\n        if (!gate_proj_numa_[i]) {\n            std::cout << \"Memory allocation failed for gate_proj_numa_ on node \" << i << std::endl;\n        }\n        if (!up_proj_numa_[i]) {\n            std::cout << \"Memory allocation failed for up_proj_numa_ on node \" << i << std::endl;\n        }\n        if (!down_proj_numa_[i]) {\n            std::cout << \"Memory allocation failed for down_proj_numa_ on node \" << i << std::endl;\n        }\n        memcpy(gate_proj_numa_[i], gate_proj_, exp_inter_hidden_mul_* ggml_type_size(config.gate_type) / ggml_blck_size(config.gate_type));\n        memcpy(up_proj_numa_[i], up_proj_, 
exp_inter_hidden_mul_* ggml_type_size(config.up_type) / ggml_blck_size(config.up_type));\n        memcpy(down_proj_numa_[i], down_proj_, exp_inter_hidden_mul_* ggml_type_size(config.down_type) / ggml_blck_size(config.down_type));\n    }\n    #endif\n\n    std::vector<std::pair<void**, uint64_t>> s_mem_requests;\n    s_mem_requests.push_back({(void**)&s_input_fp32_, sizeof(float) * config_.hidden_size});\n    s_mem_requests.push_back({(void**)&s_gate_input_, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type)});\n    s_mem_requests.push_back({(void**)&s_up_input_, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type)});\n    s_gate_output_.resize(config_.routed_expert_num);\n    s_up_output_.resize(config_.routed_expert_num);\n    s_intermediate_fp32_.resize(config_.routed_expert_num);\n    s_down_input_.resize(config_.routed_expert_num);\n    s_down_output_.resize(config_.routed_expert_num);\n    for (int i = 0; i < config_.routed_expert_num; i++) {\n        s_mem_requests.push_back({(void**)&s_gate_output_[i], sizeof(float) * config_.intermediate_size});\n        s_mem_requests.push_back({(void**)&s_up_output_[i], sizeof(float) * config_.intermediate_size});\n        s_mem_requests.push_back({(void**)&s_intermediate_fp32_[i], sizeof(float) * config_.intermediate_size});\n        s_mem_requests.push_back({(void**)&s_down_input_[i], config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type)});\n        s_mem_requests.push_back({(void**)&s_down_output_[i], sizeof(float) * config_.hidden_size});\n    }\n    s_mem_requests.push_back({(void**)&s_output_fp32_, sizeof(float) * config_.hidden_size});\n  
  shared_mem_buffer.alloc(this, s_mem_requests);\n\n    std::vector<std::pair<void**, uint64_t>> m_mem_requests;\n    m_input_fp32_.resize(config_.group_max_len);\n    m_gate_input_.resize(config_.group_max_len);\n    m_up_input_.resize(config_.group_max_len);\n    for (int i = 0; i < config_.group_max_len; i++) {\n        m_mem_requests.push_back({(void**)&m_input_fp32_[i], sizeof(float) * config_.hidden_size});\n        m_mem_requests.push_back({(void**)&m_gate_input_[i], config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type)});\n        m_mem_requests.push_back({(void**)&m_up_input_[i], config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type)});\n    }\n    m_mem_requests.push_back({(void**)&m_local_gate_input_, config_.routed_expert_num * config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type)});\n    m_mem_requests.push_back({(void**)&m_local_up_input_, config_.routed_expert_num * config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type)});\n    m_mem_requests.push_back({(void**)&m_local_gate_output_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size});\n    m_mem_requests.push_back({(void**)&m_local_up_output_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size});\n    m_mem_requests.push_back({(void**)&m_local_intermediate_fp32_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size});\n    
m_mem_requests.push_back({(void**)&m_local_down_input_, config_.routed_expert_num * config_.group_max_len * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type)});\n    m_mem_requests.push_back({(void**)&m_local_down_output_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.hidden_size});\n    m_output_fp32_.resize(config_.group_max_len);\n    for (int i = 0; i < config_.group_max_len; i++) {\n        m_mem_requests.push_back({(void**)&m_output_fp32_[i], sizeof(float) * config_.hidden_size});\n    }\n    shared_mem_buffer.alloc(this, m_mem_requests);\n\n    m_local_pos_.resize(config_.group_max_len);\n    for (int i = 0; i < config_.group_max_len; i++) {\n        m_local_pos_[i].resize(config_.routed_expert_num);\n    }\n    m_local_num_.resize(config_.expert_num);\n    m_local_gate_input_ptr_.resize(config_.expert_num);\n    m_local_up_input_ptr_.resize(config_.expert_num);\n    m_local_gate_output_ptr_.resize(config_.expert_num);\n    m_local_up_output_ptr_.resize(config_.expert_num);\n    m_local_intermediate_fp32_ptr_.resize(config_.expert_num);\n    m_local_down_input_ptr_.resize(config_.expert_num);\n    m_local_down_output_ptr_.resize(config_.expert_num);\n}\n\nMOE::~MOE() {\n    shared_mem_buffer.dealloc(this);\n\n    #ifdef USE_NUMA\n    // Sizes are computed in size_t, matching the allocation in the constructor, so the product cannot overflow int.\n    int numa_nodes = numa_num_configured_nodes();\n    for (int i = 0; i < numa_nodes; i++) {\n        numa_free(gate_proj_numa_[i], (size_t)config_.expert_num * config_.intermediate_size * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type));\n        numa_free(up_proj_numa_[i], (size_t)config_.expert_num * config_.intermediate_size * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type));\n        numa_free(down_proj_numa_[i], (size_t)config_.expert_num * config_.hidden_size * config_.intermediate_size * 
ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type));\n    }\n    #endif\n}\n\nvoid MOE::warm_up(Backend* backend) {\n    std::vector<float> input_fp32(config_.hidden_size);\n    std::vector<uint8_t> input(config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type));\n    std::vector<uint8_t> output(config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type));\n    for (int i = 0; i < config_.hidden_size; i++) {\n        input_fp32[i] = 0;\n    }\n    from_float(input_fp32.data(), input.data(), config_.hidden_size, config_.hidden_type);\n    for (int i = 0; i < config_.expert_num; i++) {\n        uint64_t expert_ids = i;\n        float weights = 0;\n        forward_one(1, &expert_ids, &weights, input.data(), output.data(), backend);\n    }\n}\n\nstatic float act_fn(float x) {\n    return x / (1.0f + expf(-x));\n}\n\nstatic float act_fn_relu(float x) {\n    if(x > 0.0){\n        return x;\n    } else {\n        return 0.0;\n    }\n}\n\nvoid MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend) {\n    const void* gate_input_ptr;\n    const void* up_input_ptr;\n    if (config_.hidden_type == ggml_internal_get_type_traits(config_.gate_type).vec_dot_type && config_.hidden_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {\n        gate_input_ptr = up_input_ptr = input;\n    } else {\n        to_float(input, s_input_fp32_, config_.hidden_size, config_.hidden_type);\n        if (ggml_internal_get_type_traits(config_.gate_type).vec_dot_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {\n            from_float(s_input_fp32_, s_gate_input_, config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);\n            gate_input_ptr = up_input_ptr = s_gate_input_;\n        } else {\n            if (config_.hidden_type != 
ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) {\n                from_float(s_input_fp32_, s_gate_input_, config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);\n                gate_input_ptr = s_gate_input_;\n            } else {\n                gate_input_ptr = input;\n            }\n            if (config_.hidden_type != ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {\n                from_float(s_input_fp32_, s_up_input_, config_.hidden_size, ggml_internal_get_type_traits(config_.up_type).vec_dot_type);\n                up_input_ptr = s_up_input_;\n            } else {\n                up_input_ptr = input;\n            }\n        }\n    }\n    int nth = config_.intermediate_size / config_.stride;\n    backend->do_work_stealing_job(nth * k, nullptr, [&](int task_id) {\n        int expert_idx = task_id / nth;\n        uint64_t expert_id = expert_ids[expert_idx];\n        int ith = task_id % nth;\n        \n        #ifdef USE_NUMA\n        void* gate_proj_ptr = (uint8_t*)gate_proj_numa_[Backend::numa_node] + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);\n        #else\n        void* gate_proj_ptr = (uint8_t*)gate_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);\n        #endif\n\n        float* gate_output_ptr = s_gate_output_[expert_idx] + ith * config_.stride;\n        llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, 
GGML_PREC_DEFAULT);\n\n        #ifdef USE_NUMA\n        void* up_proj_ptr = (uint8_t*)up_proj_numa_[Backend::numa_node] + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);\n        #else\n        void* up_proj_ptr = (uint8_t*)up_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);\n        #endif\n\n        float* up_output_ptr = s_up_output_[expert_idx] + ith * config_.stride;\n        llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n        if(config_.use_silu){\n            // use silu as act fn\n            for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {\n                s_intermediate_fp32_[expert_idx][i] = act_fn(s_gate_output_[expert_idx][i]) * s_up_output_[expert_idx][i];\n            }\n        } else {\n            // use relu as act fn\n            for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {\n                s_intermediate_fp32_[expert_idx][i] = act_fn_relu(s_gate_output_[expert_idx][i]) * s_up_output_[expert_idx][i];\n            }\n        }\n        if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) == 0) {\n            float* intermediate_fp32_ptr = s_intermediate_fp32_[expert_idx] + ith * config_.stride;\n            void* down_input_ptr = s_down_input_[expert_idx] + ith * config_.stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / 
ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);\n            from_float(intermediate_fp32_ptr, down_input_ptr, config_.stride, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);\n        }\n    }, nullptr);\n    if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) != 0) {\n        for (int i = 0; i < k; i++) {\n            from_float(s_intermediate_fp32_[i], s_down_input_[i], config_.intermediate_size, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);\n        }\n    }\n    nth = config_.hidden_size / config_.stride;\n    backend->do_work_stealing_job(nth, nullptr, [&](int task_id) {\n        int ith = task_id;\n        for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {\n            s_output_fp32_[i] = 0;\n        }\n        for (int expert_idx = 0; expert_idx < k; expert_idx++) {\n            uint64_t expert_id = expert_ids[expert_idx];\n\n            #ifdef USE_NUMA\n            void* down_proj_ptr = (uint8_t*)down_proj_numa_[Backend::numa_node] + (expert_id * config_.hidden_size + ith * config_.stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);\n            #else\n            void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_id * config_.hidden_size + ith * config_.stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);\n            #endif\n            \n            float* down_output_ptr = s_down_output_[expert_idx] + ith * config_.stride;\n            llamafile_sgemm(config_.stride, 1, config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), s_down_input_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, 
ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n            for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {\n                s_output_fp32_[i] += s_down_output_[expert_idx][i] * weights[expert_idx];\n            }\n        }\n        if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) {\n            float* output_fp32_ptr = s_output_fp32_ + ith * config_.stride;\n            void* output_ptr = (uint8_t*)output + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);\n            from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type);\n        }\n    }, nullptr);\n    if (config_.stride % ggml_blck_size(config_.hidden_type) != 0) {\n        from_float(s_output_fp32_, output, config_.hidden_size, config_.hidden_type);\n    }\n}\n\nvoid MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend) {\n    for (int i = 0; i < config_.expert_num; i++) {\n        m_local_num_[i] = 0;\n    }\n    for (int i = 0; i < qlen; i++) {\n        for (int j = 0; j < k; j++) {\n            m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;\n        }\n    }\n    uint64_t offset = 0;\n    for (int i = 0; i < config_.expert_num; i++) {\n        m_local_gate_input_ptr_[i] = m_local_gate_input_ + offset * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);\n        m_local_up_input_ptr_[i] = m_local_up_input_ + offset * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type);\n        m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;\n        m_local_up_output_ptr_[i] = 
m_local_up_output_ + offset * config_.intermediate_size;\n        m_local_intermediate_fp32_ptr_[i] = m_local_intermediate_fp32_ + offset * config_.intermediate_size;\n        m_local_down_input_ptr_[i] = m_local_down_input_ + offset * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);\n        m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;\n        offset += m_local_num_[i];\n    }\n    backend->do_work_stealing_job(qlen, nullptr, [&](int i) {\n        const void* gate_input_ptr;\n        const void* up_input_ptr;\n        if (config_.hidden_type == ggml_internal_get_type_traits(config_.gate_type).vec_dot_type && config_.hidden_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {\n            gate_input_ptr = up_input_ptr = (uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);\n        } else {\n            to_float((uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), m_input_fp32_[i], config_.hidden_size, config_.hidden_type);\n            if (ggml_internal_get_type_traits(config_.gate_type).vec_dot_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {\n                from_float(m_input_fp32_[i], m_gate_input_[i], config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);\n                gate_input_ptr = up_input_ptr = m_gate_input_[i];\n            } else {\n                if (config_.hidden_type != ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) {\n                    from_float(m_input_fp32_[i], m_gate_input_[i], config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);\n                    gate_input_ptr = m_gate_input_[i];\n                } else {\n        
            gate_input_ptr = (uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);\n                }\n                if (config_.hidden_type != ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {\n                    from_float(m_input_fp32_[i], m_up_input_[i], config_.hidden_size, ggml_internal_get_type_traits(config_.up_type).vec_dot_type);\n                    up_input_ptr = m_up_input_[i];\n                } else {\n                    up_input_ptr = (uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);\n                }\n            }\n        }\n        for (int j = 0; j < k; j++) {\n            memcpy(m_local_gate_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type), gate_input_ptr, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type));\n            memcpy(m_local_up_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type), up_input_ptr, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type));\n        }\n    }, nullptr);\n    int stride = QK_K;\n    int nth = config_.intermediate_size / stride;\n    backend->do_work_stealing_job(nth * config_.expert_num, nullptr, [&](int task_id) {\n        uint64_t expert_idx = task_id / nth;\n        int ith = task_id % nth;\n        void* gate_input_ptr = 
m_local_gate_input_ptr_[expert_idx];\n\n        #ifdef USE_NUMA\n        void* gate_proj_ptr = (uint8_t*)gate_proj_numa_[Backend::numa_node] + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);\n        #else\n        void* gate_proj_ptr = (uint8_t*)gate_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);\n        #endif\n\n        float* gate_output_ptr = m_local_gate_output_ptr_[expert_idx] + ith * stride;\n        llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n        void* up_input_ptr = m_local_up_input_ptr_[expert_idx];\n\n        #ifdef USE_NUMA\n        void* up_proj_ptr = (uint8_t*)up_proj_numa_[Backend::numa_node] + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);\n        #else\n        void* up_proj_ptr = (uint8_t*)up_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);\n        #endif\n\n        float* up_output_ptr = m_local_up_output_ptr_[expert_idx] + ith * stride;\n        llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.intermediate_size, 0, 1, 
GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n        for (int i = 0; i < m_local_num_[expert_idx]; i++) {\n            if(config_.use_silu){\n                for (int j = ith * stride; j < (ith + 1) * stride; j++) {\n                    m_local_intermediate_fp32_ptr_[expert_idx][i * config_.intermediate_size + j] = act_fn(m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size + j]) * m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size + j];\n                }\n            } else {\n                for (int j = ith * stride; j < (ith + 1) * stride; j++) {\n                    m_local_intermediate_fp32_ptr_[expert_idx][i * config_.intermediate_size + j] = act_fn_relu(m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size + j]) * m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size + j];\n                }\n            }\n            float* intermediate_fp32_ptr = m_local_intermediate_fp32_ptr_[expert_idx] + i * config_.intermediate_size + ith * stride;\n            void* down_input_ptr = m_local_down_input_ptr_[expert_idx] + i * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) + ith * stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);\n            from_float(intermediate_fp32_ptr, down_input_ptr, stride, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);\n        }\n    }, nullptr);\n    stride = QK_K;\n    nth = config_.hidden_size / stride;\n    backend->do_work_stealing_job(nth * config_.expert_num, nullptr, [&](int task_id) {\n        uint64_t expert_idx = task_id / nth;\n        int ith = task_id % nth;\n        void* down_input_ptr = 
m_local_down_input_ptr_[expert_idx];\n        \n        #ifdef USE_NUMA\n        void* down_proj_ptr = (uint8_t*)down_proj_numa_[Backend::numa_node] + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);\n        #else\n        void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);\n        #endif\n\n        float* down_output_ptr = m_local_down_output_ptr_[expert_idx] + ith * stride;\n        llamafile_sgemm(stride, m_local_num_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n    }, nullptr);\n    backend->do_work_stealing_job(qlen, nullptr, [&](int i) {\n        for (int e = 0; e < config_.hidden_size; e++) {\n            m_output_fp32_[i][e] = 0;\n        }\n        for (int j = 0; j < k; j++) {\n            for (int e = 0; e < config_.hidden_size; e++) {\n                m_output_fp32_[i][e] += m_local_down_output_ptr_[expert_ids[i * k + j]][m_local_pos_[i][j] * config_.hidden_size + e] * weights[i * k + j];\n            }\n        }\n        from_float(m_output_fp32_[i], (uint8_t*)output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), config_.hidden_size, config_.hidden_type);\n    }, nullptr);\n}\n\nvoid MOE::forward(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, int* batch_size_tensor, Backend* backend) {\n    qlen = batch_size_tensor[0];\n    if (qlen < 
config_.group_min_len) {\n        for (int i = 0; i < qlen; i++) {\n            forward_one(k, expert_ids + i * k, weights + i * k, (uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend);\n        }\n        return;\n    }\n    int forward_len = std::min(config_.group_max_len, qlen);\n    forward_many(forward_len, k, expert_ids, weights, input, output, backend);\n\n    batch_size_tensor[0] -= forward_len;\n    forward(qlen - forward_len, k, expert_ids + forward_len * k, weights + forward_len * k, (uint8_t*)input + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), batch_size_tensor, backend);\n}"
  },
  {
    "path": "archive/csrc/ktransformers_ext/operators/llamafile/moe.h",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-22 02:03:22\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2024-07-25 10:35:10\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#ifndef CPUINFER_OPERATOR_MOE_H\n#define CPUINFER_OPERATOR_MOE_H\n\n#include <cmath>\n#include <cstdio>\n#include <functional>\n#include <mutex>\n#include <vector>\n\n#include \"../../cpu_backend/backend.h\"\n#include \"../../cpu_backend/shared_mem_buffer.h\"\n#include \"conversion.h\"\n#include \"llama.cpp/ggml-impl.h\"\n#include \"llama.cpp/ggml-quants.h\"\n#include \"llama.cpp/ggml.h\"\n#include \"llamafile/sgemm.h\"\n\nstruct MOEConfig {\n    int expert_num;\n    int routed_expert_num;\n    int hidden_size;\n    int intermediate_size;\n    int stride;\n    int group_min_len;\n    int group_max_len;\n    bool use_silu;\n    void* gate_proj;\n    void* up_proj;\n    void* down_proj;\n    ggml_type gate_type;\n    ggml_type up_type;\n    ggml_type down_type;\n    ggml_type hidden_type;\n\n    MOEConfig() {}\n\n    MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int stride, int group_min_len, int group_max_len, bool use_silu, void* gate_proj, void* up_proj, void* down_proj, ggml_type gate_type, ggml_type up_type, ggml_type down_type, ggml_type hidden_type)\n        : expert_num(expert_num), routed_expert_num(routed_expert_num), hidden_size(hidden_size), intermediate_size(intermediate_size), stride(stride), group_min_len(group_min_len), group_max_len(group_max_len), use_silu(use_silu), gate_proj(gate_proj), up_proj(up_proj), down_proj(down_proj), gate_type(gate_type), up_type(up_type), down_type(down_type), hidden_type(hidden_type) {}\n};\n\nclass MOE {\n   public:\n    MOE(MOEConfig);\n    ~MOE();\n    void warm_up(Backend* backend);\n    void forward_one(int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* 
backend);\n    void forward_many(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend);\n    void forward(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, int* batch_size_tensor, Backend* backend);\n\n   private:\n    MOEConfig config_;\n    void* gate_proj_;  // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]\n    void* up_proj_;    // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]\n    void* down_proj_;  // [expert_num * hidden_size * intermediate_size ( /32 if quantized)]\n\n    #ifdef USE_NUMA\n    std::vector<void*> gate_proj_numa_;  // [numa_num, expert_num * intermediate_size * hidden_size ( /32 if quantized)]\n    std::vector<void*> up_proj_numa_;    // [numa_num, expert_num * intermediate_size * hidden_size ( /32 if quantized)]\n    std::vector<void*> down_proj_numa_;  // [numa_num, expert_num * hidden_size * intermediate_size ( /32 if quantized)]\n    #endif\n\n    float* s_input_fp32_;                      // [hidden_size]\n    uint8_t* s_gate_input_;                    // [hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)]\n    uint8_t* s_up_input_;                      // [hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)]\n    std::vector<float*> s_gate_output_;        // [routed_expert_num, intermediate_size]\n    std::vector<float*> s_up_output_;          // [routed_expert_num, intermediate_size]\n    std::vector<float*> s_intermediate_fp32_;  // [routed_expert_num, intermediate_size]\n    std::vector<uint8_t*> s_down_input_;       // [routed_expert_num, intermediate_size * ggml_type_size(ggml_internal_get_type_traits(down_type).vec_dot_type) / 
ggml_blck_size(ggml_internal_get_type_traits(down_type).vec_dot_type)]\n    std::vector<float*> s_down_output_;        // [routed_expert_num, hidden_size]\n    float* s_output_fp32_;                     // [hidden_size]\n\n    std::vector<float*> m_input_fp32_;    // [group_max_len, hidden_size]\n    std::vector<uint8_t*> m_gate_input_;  // [group_max_len, hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)]\n    std::vector<uint8_t*> m_up_input_;    // [group_max_len, hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)]\n    uint8_t* m_local_gate_input_;         // [routed_expert_num * group_max_len * hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)]\n    uint8_t* m_local_up_input_;           // [routed_expert_num * group_max_len * hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)]\n    float* m_local_gate_output_;          // [routed_expert_num * group_max_len * intermediate_size]\n    float* m_local_up_output_;            // [routed_expert_num * group_max_len * intermediate_size]\n    float* m_local_intermediate_fp32_;    // [routed_expert_num * group_max_len * intermediate_size]\n    uint8_t* m_local_down_input_;         // [routed_expert_num * group_max_len * intermediate_size * ggml_type_size(ggml_internal_get_type_traits(down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(down_type).vec_dot_type)]\n    float* m_local_down_output_;          // [routed_expert_num * group_max_len * hidden_size]\n    std::vector<float*> m_output_fp32_;   // [group_max_len, hidden_size]\n\n    std::vector<std::vector<int>> m_local_pos_;          // 
[group_max_len, routed_expert_num]\n    std::vector<int> m_local_num_;                       // [expert_num]\n    std::vector<uint8_t*> m_local_gate_input_ptr_;       // [expert_num]\n    std::vector<uint8_t*> m_local_up_input_ptr_;         // [expert_num]\n    std::vector<float*> m_local_gate_output_ptr_;        // [expert_num]\n    std::vector<float*> m_local_up_output_ptr_;          // [expert_num]\n    std::vector<float*> m_local_intermediate_fp32_ptr_;  // [expert_num]\n    std::vector<uint8_t*> m_local_down_input_ptr_;       // [expert_num]\n    std::vector<float*> m_local_down_output_ptr_;        // [expert_num]\n};\n\n#endif"
  },
  {
    "path": "archive/csrc/ktransformers_ext/vendors/cuda.h",
    "content": "#pragma once\n\n#include <cuda_runtime.h>\n#include <cuda.h>\n#include <cublas_v2.h>\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n\n#if CUDART_VERSION < 11020\n#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED\n#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH\n#define CUBLAS_COMPUTE_16F CUDA_R_16F\n#define CUBLAS_COMPUTE_32F CUDA_R_32F\n#define cublasComputeType_t cudaDataType_t\n#endif // CUDART_VERSION < 11020\n"
  },
  {
    "path": "archive/csrc/ktransformers_ext/vendors/hip.h",
    "content": "#pragma once\n\n#define HIP_ENABLE_WARP_SYNC_BUILTINS 1\n#include <hip/hip_runtime.h>\n#include <hipblas/hipblas.h>\n#include <hip/hip_fp16.h>\n#include <hip/hip_bfloat16.h>\n#ifdef __HIP_PLATFORM_AMD__\n// for rocblas_initialize()\n#include \"rocblas/rocblas.h\"\n#endif // __HIP_PLATFORM_AMD__\n\n#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F\n#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F\n#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F\n#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT\n#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT\n#define CUBLAS_OP_N HIPBLAS_OP_N\n#define CUBLAS_OP_T HIPBLAS_OP_T\n#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS\n#define CUBLAS_TF32_TENSOR_OP_MATH 0\n#define CUDA_R_16F  HIPBLAS_R_16F\n#define CUDA_R_32F  HIPBLAS_R_32F\n#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported\n#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended\n#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned\n#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice\n#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite\n#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT(\"HipVMM Failure: %s\\n\", hipGetErrorString(err)); }}\n#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)\n#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)\n#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6\n#define cublasCreate hipblasCreate\n#define cublasDestroy hipblasDestroy\n#define cublasGemmEx hipblasGemmEx\n#define cublasGemmBatchedEx hipblasGemmBatchedEx\n#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx\n#define cublasHandle_t hipblasHandle_t\n#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS\n#define cublasSetStream hipblasSetStream\n#define cublasSgemm hipblasSgemm\n#define 
cublasStatus_t hipblasStatus_t\n#define cublasOperation_t hipblasOperation_t\n#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6\n#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer\n#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess\n#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess\n#define cudaDeviceProp hipDeviceProp_t\n#define cudaDeviceSynchronize hipDeviceSynchronize\n#define cudaError_t hipError_t\n#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled\n#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled\n#define cudaEventCreateWithFlags hipEventCreateWithFlags\n#define cudaEventDisableTiming hipEventDisableTiming\n#define cudaEventRecord hipEventRecord\n#define cudaEventSynchronize hipEventSynchronize\n#define cudaEvent_t hipEvent_t\n#define cudaEventDestroy hipEventDestroy\n#define cudaFree hipFree\n#define cudaFreeHost hipHostFree\n#define cudaGetDevice hipGetDevice\n#define cudaGetDeviceCount hipGetDeviceCount\n#define cudaGetDeviceProperties hipGetDeviceProperties\n#define cudaGetErrorString hipGetErrorString\n#define cudaGetLastError hipGetLastError\n#define cudaHostRegister hipHostRegister\n#define cudaHostRegisterPortable hipHostRegisterPortable\n#define cudaHostRegisterReadOnly hipHostRegisterReadOnly\n#define cudaHostUnregister hipHostUnregister\n#define cudaLaunchHostFunc hipLaunchHostFunc\n#define cudaMalloc hipMalloc\n#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)\n#define cudaMemcpy hipMemcpy\n#define cudaMemcpyAsync hipMemcpyAsync\n#define cudaMemcpyPeerAsync hipMemcpyPeerAsync\n#define cudaMemcpy2DAsync hipMemcpy2DAsync\n#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice\n#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost\n#define cudaMemcpyHostToDevice hipMemcpyHostToDevice\n#define cudaMemcpyKind hipMemcpyKind\n#define cudaMemset hipMemset\n#define cudaMemsetAsync hipMemsetAsync\n#define cudaMemGetInfo 
hipMemGetInfo\n#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize\n#define cudaSetDevice hipSetDevice\n#define cuDeviceGet hipDeviceGet\n#define CUdevice hipDevice_t\n#define CUdeviceptr hipDeviceptr_t\n#define cuMemUnmap hipMemUnmap\n#define CUmemAccessDesc hipMemAccessDesc\n#define cuMemAddressFree hipMemAddressFree\n#define cuMemRelease hipMemRelease\n#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t\n#define cuMemCreate hipMemCreate\n#define cuMemAddressReserve hipMemAddressReserve\n#define cuMemMap hipMemMap\n#define cuMemSetAccess hipMemSetAccess\n#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity\n#define CUmemAllocationProp hipMemAllocationProp\n#define cuDeviceGetAttribute hipDeviceGetAttribute\n#define cudaStreamCreateWithFlags hipStreamCreateWithFlags\n#define cudaStreamDestroy hipStreamDestroy\n#define cudaStreamFireAndForget hipStreamFireAndForget\n#define cudaStreamNonBlocking hipStreamNonBlocking\n#define cudaStreamPerThread hipStreamPerThread\n#define cudaStreamSynchronize hipStreamSynchronize\n#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)\n#define cudaGraphExec_t hipGraphExec_t\n#define cudaGraphNode_t hipGraphNode_t\n#define cudaKernelNodeParams hipKernelNodeParams\n#define cudaKernelNodeParams hipKernelNodeParams\n#define cudaGraphExecDestroy hipGraphExecDestroy\n#define cudaGraphLaunch hipGraphLaunch\n#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure\n#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult\n#define cudaGraphNodeType hipGraphNodeType\n#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel\n#define cudaGraphInstantiate hipGraphInstantiate\n#define cudaStreamEndCapture hipStreamEndCapture\n#define cudaGraphDestroy hipGraphDestroy\n#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams\n#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction\n#define 
cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams\n#define cudaGraphNodeGetType hipGraphNodeGetType\n#define cudaGraphGetNodes hipGraphGetNodes\n#define cudaGraphExecUpdate hipGraphExecUpdate\n#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed\n#define cudaStreamBeginCapture hipStreamBeginCapture\n#define cudaGraph_t hipGraph_t\n#define cudaStream_t hipStream_t\n#define cudaSuccess hipSuccess\n#define cudaHostFn_t hipHostFn_t\n#define __trap() do { abort(); __builtin_unreachable(); } while(0)\n#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS\n#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED\n#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED\n#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE\n#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH\n#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR\n#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED\n#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR\n#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED\n\n#define __CUDA_ARCH__ 1300\n\n#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)\n#define GCN\n#endif\n\n#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)\n#define CDNA\n#endif\n\n#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \\\n    defined(__gfx1150__) || defined(__gfx1151__)\n#define RDNA3\n#endif\n\n#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \\\n    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)\n#define RDNA2\n#endif\n\n#if defined(__gfx1010__) || defined(__gfx1012__)\n#define RDNA1\n#endif\n\n#ifndef __has_builtin\n    #define __has_builtin(x) 0\n#endif\n\ntypedef hip_bfloat16 nv_bfloat16;\n"
  },
  {
    "path": "archive/csrc/ktransformers_ext/vendors/musa.h",
    "content": "#pragma once\n\n#include <musa_runtime.h>\n#include <musa.h>\n#include <mublas.h>\n#include <musa_bf16.h>\n#include <musa_fp16.h>\n#define CUBLAS_COMPUTE_16F CUDA_R_16F\n#define CUBLAS_COMPUTE_32F CUDA_R_32F\n#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F\n#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT\n#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT\n#define CUBLAS_OP_N MUBLAS_OP_N\n#define CUBLAS_OP_T MUBLAS_OP_T\n#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS\n#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT\n#define CUDA_R_16F  MUSA_R_16F\n#define CUDA_R_32F  MUSA_R_32F\n#define cublasComputeType_t cudaDataType_t\n#define cublasCreate mublasCreate\n#define cublasDestroy mublasDestroy\n#define cublasGemmEx mublasGemmEx\n#define cublasGemmBatchedEx mublasGemmBatchedEx\n#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx\n#define cublasHandle_t mublasHandle_t\n#define cublasSetMathMode mublasSetMathMode\n#define cublasSetStream mublasSetStream\n#define cublasSgemm mublasSgemm\n#define cublasStatus_t mublasStatus_t\n#define cublasOperation_t mublasOperation_t\n#define cublasGetStatusString mublasStatus_to_string\n#define cudaDataType_t musaDataType_t\n#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer\n#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess\n#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess\n#define cudaDeviceProp musaDeviceProp\n#define cudaDeviceSynchronize musaDeviceSynchronize\n#define cudaError_t musaError_t\n#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled\n#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled\n#define cudaEventCreateWithFlags musaEventCreateWithFlags\n#define cudaEventDisableTiming musaEventDisableTiming\n#define cudaEventRecord musaEventRecord\n#define cudaEventSynchronize musaEventSynchronize\n#define cudaEvent_t musaEvent_t\n#define cudaEventDestroy musaEventDestroy\n#define cudaFree 
musaFree\n#define cudaFreeHost musaFreeHost\n#define cudaGetDevice musaGetDevice\n#define cudaGetDeviceCount musaGetDeviceCount\n#define cudaGetDeviceProperties musaGetDeviceProperties\n#define cudaGetErrorString musaGetErrorString\n#define cudaGetLastError musaGetLastError\n#define cudaHostRegister musaHostRegister\n#define cudaHostRegisterPortable musaHostRegisterPortable\n#define cudaHostRegisterReadOnly musaHostRegisterReadOnly\n#define cudaHostUnregister musaHostUnregister\n#define cudaLaunchHostFunc musaLaunchHostFunc\n#define cudaMalloc musaMalloc\n#define cudaMallocHost musaMallocHost\n#define cudaMallocManaged musaMallocManaged\n#define cudaMemcpy musaMemcpy\n#define cudaMemcpyAsync musaMemcpyAsync\n#define cudaMemcpyPeerAsync musaMemcpyPeerAsync\n#define cudaMemcpy2DAsync musaMemcpy2DAsync\n#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice\n#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost\n#define cudaMemcpyHostToDevice musaMemcpyHostToDevice\n#define cudaMemcpyKind musaMemcpyKind\n#define cudaMemset musaMemset\n#define cudaMemsetAsync musaMemsetAsync\n#define cudaMemGetInfo musaMemGetInfo\n#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize\n#define cudaSetDevice musaSetDevice\n#define cudaStreamCreateWithFlags musaStreamCreateWithFlags\n#define cudaStreamDestroy musaStreamDestroy\n#define cudaStreamFireAndForget musaStreamFireAndForget\n#define cudaStreamNonBlocking musaStreamNonBlocking\n#define cudaStreamPerThread musaStreamPerThread\n#define cudaStreamSynchronize musaStreamSynchronize\n#define cudaStreamWaitEvent musaStreamWaitEvent\n#define cudaStream_t musaStream_t\n#define cudaSuccess musaSuccess\n\n// Additional mappings for MUSA virtual memory pool\n#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED\n#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE\n#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED 
MU_MEM_ALLOC_GRANULARITY_RECOMMENDED\n#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED\n#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE\n#define CUdevice MUdevice\n#define CUdeviceptr MUdeviceptr\n#define CUmemAccessDesc MUmemAccessDesc\n#define CUmemAllocationProp MUmemAllocationProp\n#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle\n#define cuDeviceGet muDeviceGet\n#define cuDeviceGetAttribute muDeviceGetAttribute\n#define cuMemAddressFree muMemAddressFree\n#define cuMemAddressReserve muMemAddressReserve\n#define cuMemCreate muMemCreate\n#define cuMemGetAllocationGranularity muMemGetAllocationGranularity\n#define cuMemMap muMemMap\n#define cuMemRelease muMemRelease\n#define cuMemSetAccess muMemSetAccess\n#define cuMemUnmap muMemUnmap\n#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize\n#define cudaFuncSetAttribute musaFuncSetAttribute\n#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms\n#define make_cudaExtent make_musaExtent\n#define make_cudaPitchedPtr make_musaPitchedPtr\n\n// Additional mappings for MUSA graphs\n#define CUDA_SUCCESS MUSA_SUCCESS\n#define CUresult MUresult\n#define cuGetErrorString muGetErrorString\n#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure\n#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction\n#define cudaGraphDestroy musaGraphDestroy\n#define cudaGraphExecDestroy musaGraphExecDestroy\n#define cudaGraphExec_t musaGraphExec_t\n#define cudaGraphExecUpdate musaGraphExecUpdate\n#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult\n#define cudaGraphGetNodes musaGraphGetNodes\n#define cudaGraphInstantiate musaGraphInstantiate\n#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams\n#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams\n#define cudaGraphLaunch musaGraphLaunch\n#define cudaGraphNodeGetType musaGraphNodeGetType\n#define cudaGraphNode_t 
musaGraphNode_t\n#define cudaGraphNodeType musaGraphNodeType\n#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel\n#define cudaGraph_t musaGraph_t\n#define cudaKernelNodeParams musaKernelNodeParams\n#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed\n#define cudaStreamEndCapture musaStreamEndCapture\n\ntypedef mt_bfloat16 nv_bfloat16;\n"
  },
  {
    "path": "archive/csrc/ktransformers_ext/vendors/vendor.h",
    "content": "#ifndef CPUINFER_VENDOR_VENDOR_H\n#define CPUINFER_VENDOR_VENDOR_H\n\n#ifdef USE_CUDA\n#include \"cuda.h\"\n#elif USE_HIP\n#define __HIP_PLATFORM_AMD__\n#include \"hip.h\"\n#elif USE_MUSA\n#include \"musa.h\"\n#endif\n\n#endif  // CPUINFER_VENDOR_VENDOR_H"
  },
  {
    "path": "archive/install-with-cache.sh",
    "content": "#!/bin/bash\nset -e  \n\n# clear build dirs\n# rm -rf build\n# rm -rf *.egg-info\n# rm -rf csrc/build\n# rm -rf csrc/ktransformers_ext/build\n# rm -rf csrc/ktransformers_ext/cuda/build\n# rm -rf csrc/ktransformers_ext/cuda/dist\n# rm -rf csrc/ktransformers_ext/cuda/*.egg-info\nrm -rf ~/.ktransformers\necho \"Installing python dependencies from requirements.txt\"\npip install -r requirements-local_chat.txt\npip install -r ktransformers/server/requirements.txt\necho \"Installing ktransformers\"\nKTRANSFORMERS_FORCE_BUILD=TRUE USE_BALANCE_SERVE=1 pip install -v . --no-build-isolation\npip install third_party/custom_flashinfer/ -v\n\n# SITE_PACKAGES=$(python -c \"import site; print(site.getsitepackages()[0])\")\n# echo \"Copying thirdparty libs to $SITE_PACKAGES\"\n# cp -a csrc/balance_serve/build/third_party/prometheus-cpp/lib/libprometheus-cpp-*.so* $SITE_PACKAGES/\n# patchelf --set-rpath '$ORIGIN' $SITE_PACKAGES/sched_ext.cpython*\n\n\necho \"Installation completed successfully\"\n"
  },
  {
    "path": "archive/install.bat",
    "content": "@echo off\n\nREM clear build dirs\nrmdir /S /Q ktransformers\\ktransformers_ext\\build\nrmdir /S /Q ktransformers\\ktransformers_ext\\cuda\\build\nrmdir /S /Q ktransformers\\ktransformers_ext\\cuda\\dist\nrmdir /S /Q ktransformers\\ktransformers_ext\\out\ndel /F /Q ktransformers\\ktransformers_ext\\cuda\\*.egg-info\n\necho Installing python dependencies from requirements.txt\npip install -r requirements-local_chat.txt\n\necho Installing ktransformers\nset KTRANSFORMERS_FORCE_BUILD=TRUE\npip install . --no-build-isolation\necho Installation completed successfully"
  },
  {
    "path": "archive/install.sh",
    "content": "#!/bin/bash\nset -e  \n\n# default backend\nDEV=\"cuda\"\n\n# parse --dev argument\nwhile [[ \"$#\" -gt 0 ]]; do\n    case $1 in\n        --dev) DEV=\"$2\"; shift ;;\n        *) echo \"Unknown parameter passed: $1\"; exit 1 ;;\n    esac\n    shift\ndone\nexport DEV_BACKEND=\"$DEV\"\necho \"Selected backend: $DEV_BACKEND\"\n\n# clear build dirs\nrm -rf build\nrm -rf *.egg-info\nrm -rf csrc/build\nrm -rf csrc/ktransformers_ext/build\nrm -rf csrc/ktransformers_ext/cuda/build\nrm -rf csrc/ktransformers_ext/cuda/dist\nrm -rf csrc/ktransformers_ext/cuda/*.egg-info\nrm -rf ~/.ktransformers\necho \"Installing python dependencies from requirements.txt\"\npip install -r requirements-local_chat.txt\npip install -r ktransformers/server/requirements.txt\n\necho \"Installing ktransformers\"\nKTRANSFORMERS_FORCE_BUILD=TRUE pip install -v . --no-build-isolation\n\nif [[ \"$DEV_BACKEND\" == \"cuda\" ]]; then\n    echo \"Installing custom_flashinfer for CUDA backend\"\n    pip install third_party/custom_flashinfer/\nfi\n# SITE_PACKAGES=$(python -c \"import site; print(site.getsitepackages()[0])\")\n# echo \"Copying thirdparty libs to $SITE_PACKAGES\"\n# cp -a csrc/balance_serve/build/third_party/prometheus-cpp/lib/libprometheus-cpp-*.so* $SITE_PACKAGES/\n# patchelf --set-rpath '$ORIGIN' $SITE_PACKAGES/sched_ext.cpython*\n\necho \"Installation completed successfully\"\n"
  },
  {
    "path": "archive/ktransformers/__init__.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  : \nAuthor       : kkk1nak0\nDate         : 2024-08-15 07:34:46\nVersion      : 1.0.0\nLastEditors  : chenxl \nLastEditTime : 2025-02-15 03:53:02\n'''\n__version__ = \"0.4.1\"\n"
  },
  {
    "path": "archive/ktransformers/configs/config.yaml",
    "content": "log:\n  dir: \"logs\"\n  file: \"lexllama.log\"\n  #log level: debug, info, warn, error, crit\n  level: \"debug\"\n  backup_count: -1\n\nserver:\n  ip: 0.0.0.0\n  port: 10002\n\ndb:\n  type: \"sqllite\"\n  database: \"server.db\"\n  host: \"./\"\n  pool_size: 10\n\nuser:\n  secret_key: \"981f1dd2a44e27d68759d0252a486568ed43480b4e616a26e3af3709c3a7ce73\"\n  algorithm: \"HS256\"\n\nmodel:\n  # type: transformers\n  type: balance_serve\n  # type: ktransformers\n\n  name: DeepSeek-Coder-V2-Instruct\n  path: deepseek-ai/DeepSeek-V2-Lite-Chat\n  gguf_path: /mnt/data/models/Smallthinker-21B\n\n  device: cuda:0\n  cache_lens: 16384\n  max_new_tokens: 500\nweb:\n  mount: False\n  open_cross_domain: True\n\next:\n  cpu_infer: 10\n\nlong_context:\n  max_seq_len: 32000\n  block_size: 128\n  local_windows_len: 4096\n  second_select_num: 32\n  anchor_type: DYNAMIC\n  kv_type: FP16\n  dense_layer_num: 2\n  anchor_num: 1\n  preselect_block: True\n  head_select_mode: SHARED\n  preselect_block_count: 32\n  layer_step: 1\n  token_step: \n\nlocal_chat:\n  prompt_file: \"\"\n\nasync_server:\n  sched_strategy: \"FCFS\"\n  sched_port: 56441\n  sched_metrics_port: 54321\n  kvc2_metrics_port: 54391\n  max_batch_size: 4  # decode count + prefill count, in one mini batch\n\nattn:\n  page_size: 256\n  chunk_size: 256\nkvc2:\n  gpu_only: true \n  utilization_percentage: 1.0\n  cpu_memory_size_GB: 500\n  disk_path: /home/wjh/kvc"
  },
  {
    "path": "archive/ktransformers/configs/log_config.ini",
    "content": "[loggers]\nkeys=root,uvicorn,uvicornError,uvicornAccess\n\n[handlers]\nkeys=consoleHandler,fileHandler\n\n[formatters]\nkeys=detailedFormatter\n\n[logger_root]\nlevel=INFO\nhandlers=consoleHandler\n\n[logger_uvicorn]\nlevel=INFO\nhandlers=consoleHandler,fileHandler\nqualname=uvicorn\npropagate=0\n\n[logger_uvicornError]\nlevel=ERROR\nhandlers=consoleHandler,fileHandler\nqualname=uvicorn.error\npropagate=0\n\n[logger_uvicornAccess]\nlevel=INFO\nhandlers=consoleHandler,fileHandler\nqualname=uvicorn.access\npropagate=0\n\n[handler_consoleHandler]\nclass=StreamHandler\nlevel=INFO\nformatter=detailedFormatter\nargs=(sys.stdout,)\n\n[handler_fileHandler]\nclass=logging.FileHandler\nlevel=INFO\nformatter=detailedFormatter\nargs=('uvicorn_logs.log', 'a')\n\n[formatter_detailedFormatter]\nformat=%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s\ndatefmt=%Y-%m-%d %H:%M:%S\n"
  },
  {
    "path": "archive/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/__init__.py",
    "content": ""
  },
  {
    "path": "archive/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/format_24.py",
    "content": "#\n# Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es).\n#\n\nimport torch\n\n\n# This is PyTorch implementation of main part of reorder_meta()\n# function, from tools/util/include/cutlass/util/host_reorder.h file\n# of CUTLASS source tree.  Furthermore, CUTLASS template for sparse\n# GEMM decides upon layout of this matrix, and at the moment for the\n# sparse GEMM executed on tensor cores, this is layout described by\n# ColumnMajorInterleaved<2> data structure, in\n# include/cutlass/layout/matrix.h of CUTLASS source tree.  The\n# reordering of meta matrix into meta_reordered matrix calculated\n# according to these segments of CUTLASS code is re-implemented here.\n# Note that this calculation produces offsets for scattering metadata\n# matrix elements into reordered metadata matrix elements (or,\n# equivalently, for gathering reordered metadata matrix element back\n# into metadata matrix elements).\ndef _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype,\n                                               device):\n    dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)\n    dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)\n\n    # Reorder the rows, then swizzle the 2x2 blocks.\n    group_x = 64\n    group_y = 32 if meta_dtype.itemsize == 2 else 16\n\n    dst_rows = (dst_rows // group_x * group_x + (dst_rows % 2) * 2 +\n                (dst_rows % 8) // 4 + ((dst_rows % group_y) % 4) // 2 * 32 +\n                ((dst_rows % group_x) // 8) * 4)\n\n    topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)\n    bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)\n    dst_rows += topright - bottomleft\n    dst_cols -= topright - bottomleft\n\n    # Assumed that meta tensor is to be stored in CUTLASS\n    # InterleavedColumnMajor layout, and reverse engineered\n    # corresponding code to store values into this tensor.\n    interleave = 2\n    
cols_maj = dst_cols // interleave\n    cols_min = dst_cols % interleave\n    return (cols_maj * m * interleave + dst_rows * interleave +\n            cols_min).view(-1)\n\n\n# This function converts dense matrix into sparse semi-structured\n# representation, producing \"compressed\" matrix, in the layout used by\n# CUTLASS backend, and corresponding metadata matrix.\ndef sparse_semi_structured_from_dense_cutlass(dense):\n    if dense.dim() != 2:\n        raise RuntimeError(\n            f\"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor\"  # noqa: E501\n        )\n\n    m, k = dense.shape\n    device = dense.device\n\n    meta_dtype = torch.int8\n    if dense.dtype == torch.int8:\n        meta_dtype = torch.int32\n    elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:\n        meta_dtype = torch.int16\n    else:\n        raise RuntimeError(f\"Invalid datatype {dense.dtype} of dense matrix\")\n    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4\n    if quadbits_per_meta_elem not in (4, 8):\n        raise RuntimeError(\n            \"Invalid number of elements per meta element calculated\")\n\n    if meta_dtype == torch.int32:\n        if m % 16 != 0:\n            raise RuntimeError(\n                f\"Number of rows of dense matrix {m} must be divisible by 16\")\n    else:\n        if m % 32 != 0:\n            raise RuntimeError(\n                f\"Number of rows of dense matrix {m} must be divisible by 32\")\n    if k % (4 * quadbits_per_meta_elem) != 0:\n        raise RuntimeError(\n            f\"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}\"  # noqa: E501\n        )\n\n    if dense.dtype != torch.float:\n        ksparse = 4\n        dense_4 = dense.view(-1, k // ksparse, ksparse)\n        m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)\n    else:\n        ksparse = 2\n        dense_2 = dense.view(-1, k // ksparse, ksparse)\n        m0, m2 = m1, m3 = (dense_2 != 
0).unbind(-1)\n    meta_ncols = k // (ksparse * quadbits_per_meta_elem)\n\n    # Quadruples of True/False values are encoded as follows:\n    #     [True,  True,  False, False] -> 0b0100\n    #     [True,  False, True,  False] -> 0b1000\n    #     [False, True,  True,  False] -> 0b1001\n    #     [True,  False, False, True ] -> 0b1100\n    #     [False, True,  False, True ] -> 0b1101\n    #     [False, False, True,  True ] -> 0b1110\n    # Thus, the lower two bits of the encoding are the index of the True\n    # value at the lowest index in the quadruple, and the higher two bits\n    # are the index of the other True value in the quadruple.\n    # If there are fewer than two True values, then the False value or\n    # values at some index or indices are considered True for the\n    # encoding.  If there are more than two True values, then the\n    # excess True value(s) at some indices are considered False for\n    # the encoding.  The exact encodings used for these cases are as\n    # follows:\n    #     [False, False, False, False] -> 0b1110\n    #     [False, False, False, True ] -> 0b1110\n    #     [False, False, True,  False] -> 0b1110\n    #     [False, True,  False, False] -> 0b1001\n    #     [False, True,  True,  True ] -> 0b1101\n    #     [True,  False, False, False] -> 0b1000\n    #     [True,  False, True,  True ] -> 0b1100\n    #     [True,  True,  False, True ] -> 0b0100\n    #     [True,  True,  True,  False] -> 0b0100\n    #     [True,  True,  True,  True ] -> 0b0100\n    # These particular encodings were chosen, with the help of the\n    # Espresso logic minimizer, to minimize the corresponding Boolean\n    # functions that translate non-zero flags\n    # into encoding bits.  
Note also possible choices for the first\n    # and last of these encodings were limited only to (0b0100,\n    # 0b1110), in order to produce valid encodings for 1:2 sparsity\n    # case.\n\n    expr0 = m0 & m1\n    expr1 = ~m0 & m1\n    expr2 = ~m0 & ~m1\n    bit0 = expr1\n    bit1 = expr2\n    bit2 = expr0 | expr2 | m3\n    bit3 = expr1 | ~m1\n    idxs0 = bit0 | (bit1.to(torch.int64) << 1)\n    idxs1 = bit2 | (bit3.to(torch.int64) << 1)\n\n    if dense.dtype != torch.float:\n        sparse0 = dense_4.gather(\n            -1, idxs0.unsqueeze(-1))  # type: ignore[possibly-undefined]\n        sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))\n        sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)\n    else:\n        sparse = dense_2.gather(-1,\n                                idxs0.unsqueeze(-1) // 2).view(\n                                    m,\n                                    k // 2)  # type: ignore[possibly-undefined]\n\n    meta_4 = idxs0 | (idxs1 << 2)\n    meta_n = meta_4.view(\n        (-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)\n\n    if quadbits_per_meta_elem == 4:\n        meta = (meta_n[:, :, 0]\n                | (meta_n[:, :, 1] << 4)\n                | (meta_n[:, :, 2] << 8)\n                | (meta_n[:, :, 3] << 12))\n    elif quadbits_per_meta_elem == 8:\n        meta = (meta_n[:, :, 0]\n                | (meta_n[:, :, 1] << 4)\n                | (meta_n[:, :, 2] << 8)\n                | (meta_n[:, :, 3] << 12)\n                | (meta_n[:, :, 4] << 16)\n                | (meta_n[:, :, 5] << 20)\n                | (meta_n[:, :, 6] << 24)\n                | (meta_n[:, :, 7] << 28))\n\n    # Reorder meta tensor elements.\n    meta_reordered = meta.new_empty(\n        (m * meta_ncols, ))  # type: ignore[possibly-undefined]\n    meta_offsets = _calculate_meta_reordering_scatter_offsets(\n        m, meta_ncols, meta_dtype, device)\n    meta_reordered.scatter_(0, meta_offsets, meta.view(-1))\n\n    return (sparse, 
meta_reordered.view(m, meta_ncols))\n\n\n# This function performs the reverse of the function above: it\n# reconstructs a dense matrix from a \"compressed\" matrix, given\n# in the layout used by the CUTLASS backend, and its accompanying\n# metadata matrix.\ndef sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):\n    if sparse.dim() != 2:\n        raise RuntimeError(\n            f\"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor\"  # noqa: E501\n        )\n\n    m, k = sparse.shape\n    device = sparse.device\n\n    if meta_reordered.dim() != 2:\n        raise RuntimeError(\n            f\"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor\"  # noqa: E501\n        )\n    if meta_reordered.device != device:\n        raise RuntimeError(\n            f\"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device\"  # noqa: E501\n        )\n\n    meta_dtype = meta_reordered.dtype\n    if meta_dtype not in (torch.int16, torch.int32):\n        raise RuntimeError(f\"Invalid datatype {meta_dtype} of meta matrix\")\n    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4\n\n    ksparse = 4 if sparse.dtype != torch.float else 2\n\n    meta_nrows, meta_ncols = meta_reordered.shape\n    if meta_nrows != m:\n        raise RuntimeError(\n            f\"Number of rows of meta matrix {meta_nrows} must be equal to number of rows of sparse matrix {m}\"  # noqa: E501\n        )\n    if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:\n        raise RuntimeError(\n            f\"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, \"  # noqa: E501\n            \"expected according to the number of columns of meta matrix\")\n\n    # Undo meta tensor elements reordering.\n    meta_offsets = _calculate_meta_reordering_scatter_offsets(\n        m, meta_ncols, meta_dtype, device)\n    meta = 
torch.gather(meta_reordered.view(-1), 0,\n                        meta_offsets).view(m, meta_ncols)\n\n    # Unpack sparse tensor back to original dense tensor, using\n    # information provided by meta tensor.  Note that torch.float\n    # datatype is handled pretty much the same as\n    # torch.half/torch.bfloat16, as metadata for a pair of torch.float\n    # value is encoded as if underlying 8 bytes contain four\n    # torch.half/torch.bfloat16 values, where either first two or last\n    # two are zeros.\n    meta_2 = torch.empty(\n        (m, meta_ncols, 2 * quadbits_per_meta_elem),\n        dtype=meta_dtype,\n        device=device,\n    )\n    if quadbits_per_meta_elem == 4:\n        meta_2[:, :, 0] = meta & 0b11\n        meta_2[:, :, 1] = (meta >> 2) & 0b11\n        meta_2[:, :, 2] = (meta >> 4) & 0b11\n        meta_2[:, :, 3] = (meta >> 6) & 0b11\n        meta_2[:, :, 4] = (meta >> 8) & 0b11\n        meta_2[:, :, 5] = (meta >> 10) & 0b11\n        meta_2[:, :, 6] = (meta >> 12) & 0b11\n        meta_2[:, :, 7] = (meta >> 14) & 0b11\n    elif quadbits_per_meta_elem == 8:\n        meta_2[:, :, 0] = meta & 0b11\n        meta_2[:, :, 1] = (meta >> 2) & 0b11\n        meta_2[:, :, 2] = (meta >> 4) & 0b11\n        meta_2[:, :, 3] = (meta >> 6) & 0b11\n        meta_2[:, :, 4] = (meta >> 8) & 0b11\n        meta_2[:, :, 5] = (meta >> 10) & 0b11\n        meta_2[:, :, 6] = (meta >> 12) & 0b11\n        meta_2[:, :, 7] = (meta >> 14) & 0b11\n        meta_2[:, :, 8] = (meta >> 16) & 0b11\n        meta_2[:, :, 9] = (meta >> 18) & 0b11\n        meta_2[:, :, 10] = (meta >> 20) & 0b11\n        meta_2[:, :, 11] = (meta >> 22) & 0b11\n        meta_2[:, :, 12] = (meta >> 24) & 0b11\n        meta_2[:, :, 13] = (meta >> 26) & 0b11\n        meta_2[:, :, 14] = (meta >> 28) & 0b11\n        meta_2[:, :, 15] = (meta >> 30) & 0b11\n\n    dense_offsets = meta_2.view(-1) + (\n        torch.arange(0, 2 * m * k // ksparse, device=device) * 4).view(\n            -1, 1).repeat(1, 2).view(-1)\n\n 
   dense = torch.zeros((m * 2 * k, ), dtype=sparse.dtype, device=device)\n    if sparse.dtype != torch.float:\n        # dense.scatter_(0, dense_offsets, sparse.view(-1))\n        dense.scatter_(0, dense_offsets, sparse.reshape(-1))\n    else:\n        dense.view(torch.half).scatter_(0, dense_offsets,\n                                        sparse.view(torch.half).view(-1))\n\n    return dense.view(m, 2 * k)\n\n\ndef mask_creator(tensor):\n    \"\"\"\n    Create an N:M sparsity mask for the given tensor.\n    For every block of M weights, the M - N weights with the smallest\n    magnitude are pruned and the N largest are kept. The mask has the\n    same shape as the tensor.\n\n    N and M are fixed to 2 and 4 here (2:4 sparsity); they are not\n    parameters of this function.\n\n    :param tensor: The weight tensor to build the mask for\n    :return: A mask with ones for kept weights and zeros for pruned ones\n    \"\"\"\n    N = 2\n    M = 4\n\n    mask = None\n    # for i, tensor in enumerate(tensors):\n    if tensor.numel() % M != 0:\n        raise ValueError(\n            f\"Tensor of size {tensor.shape} can't be evenly divided into \"\n            f\"groups of {M}\")\n\n    num_groups = tensor.numel() // M\n\n    # N:M sparsity for linear layers\n    tensor_temp = tensor.detach().abs().reshape(num_groups, M)\n    index = torch.argsort(tensor_temp, dim=1)[:, :int(M - N)]\n\n    w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device)\n    mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape)\n\n    return mask\n"
  },
  {
    "path": "archive/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_24_perms.py",
    "content": "\"\"\"This file is used for /tests and /benchmarks\"\"\"\nfrom typing import Dict, List\n\nimport numpy\nimport torch\n\n\n# Precompute permutations for Marlin24 weight and scale shuffling # noqa: E501\n#\n# Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501\n# with the tensor-core format that is described here:\n# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501\n#\n# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501\n# (without the need to use ldmatrix instructions) # noqa: E501\ndef get_perms_24(num_bits: int):\n    perm_list: List[int] = []\n    for i in range(32):\n        perm1: List[int] = []\n        col = i // 4\n        col_o = col // 2\n        for block in [0, 1]:\n            for row in [\n                    2 * (i % 4),\n                    2 * (i % 4) + 1,\n                    2 * (i % 4 + 4),\n                    2 * (i % 4 + 4) + 1,\n            ]:\n                perm1.append(16 * row + col_o * 256 + 8 * (col % 2) +\n                             4 * block)\n        for j in range(4):\n            perm_list.extend([p + 1 * j for p in perm1])\n    perm = numpy.array(perm_list)\n\n    if num_bits == 4:\n        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])\n    elif num_bits == 8:\n        interleave = numpy.array([0, 2, 1, 3])\n    else:\n        raise ValueError(\"num_bits must be 4 or 8, got {}\".format(num_bits))\n\n    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()\n    perm = torch.from_numpy(perm)\n    scale_perm: List[int] = []\n    for i in range(8):\n        scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])\n    scale_perm_single: List[int] = []\n    for i in range(8):\n        scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 
4, 5, 6, 7]])\n    return perm, scale_perm, scale_perm_single\n\n\nmarlin_24_perm: Dict[int, torch.Tensor] = {}\nmarlin_24_scale_perm: Dict[int, List[int]] = {}\nmarlin_24_scale_perm_single: Dict[int, List[int]] = {}\nfor num_bits in [4, 8]:\n    perm_24, scale_perm_24, scale_perm_single_24 = get_perms_24(num_bits)\n    marlin_24_perm[num_bits] = perm_24\n    marlin_24_scale_perm[num_bits] = scale_perm_24\n    marlin_24_scale_perm_single[num_bits] = scale_perm_single_24\n"
  },
  {
    "path": "archive/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_perms.py",
    "content": "\"\"\"This file is used for /tests and /benchmarks\"\"\"\nfrom typing import Dict, List\n\nimport numpy\nimport torch\n\n\n# Precompute permutations for Marlin weight and scale shuffling # noqa: E501\n#\n# Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501\n# with the tensor-core format that is described here:\n# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501\n#\n# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501\n# (without the need to use ldmatrix instructions) # noqa: E501\ndef get_perms(num_bits: int):\n    perm_list: List[int] = []\n    for i in range(32):\n        perm1: List[int] = []\n        col = i // 4\n        for block in [0, 1]:\n            for row in [\n                    2 * (i % 4),\n                    2 * (i % 4) + 1,\n                    2 * (i % 4 + 4),\n                    2 * (i % 4 + 4) + 1,\n            ]:\n                perm1.append(16 * row + col + 8 * block)\n        for j in range(4):\n            perm_list.extend([p + 256 * j for p in perm1])\n\n    perm = numpy.array(perm_list)\n\n    if num_bits == 4:\n        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])\n    elif num_bits == 8:\n        interleave = numpy.array([0, 2, 1, 3])\n    else:\n        raise Exception(\"num_bits must be 4 or 8, got {}\".format(num_bits))\n\n    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()\n    perm = torch.from_numpy(perm)\n    scale_perm: List[int] = []\n    for i in range(8):\n        scale_perm.extend([i + 8 * j for j in range(8)])\n    scale_perm_single: List[int] = []\n    for i in range(4):\n        scale_perm_single.extend(\n            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])\n    return perm, scale_perm, scale_perm_single\n\n\nmarlin_perm: 
Dict[int, torch.Tensor] = {}\nmarlin_scale_perm: Dict[int, List[int]] = {}\nmarlin_scale_perm_single: Dict[int, List[int]] = {}\nfor num_bits in [4, 8]:\n    perm, scale_perm, scale_perm_single = get_perms(num_bits)\n    marlin_perm[num_bits] = perm\n    marlin_scale_perm[num_bits] = scale_perm\n    marlin_scale_perm_single[num_bits] = scale_perm_single\n"
  },
  {
    "path": "archive/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py",
    "content": "\"\"\"This file is used for /tests and /benchmarks\"\"\"\nimport random\n\nimport numpy\nimport torch\n\nfrom ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.format_24 import (\n    mask_creator, sparse_semi_structured_from_dense_cutlass)\nfrom ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_24_perms import (\n    marlin_24_perm, marlin_24_scale_perm, marlin_24_scale_perm_single)\nfrom ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_perms import (\n    marlin_perm, marlin_scale_perm, marlin_scale_perm_single)\nfrom ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.quant_utils import (\n    get_pack_factor, quantize_weights, sort_weights)\n\n__cuda_arch = torch.cuda.get_device_capability()\n\nMARLIN_TILE = 16\n\nGPTQ_MARLIN_TILE = 16\nGPTQ_MARLIN_MIN_THREAD_N = 64\nGPTQ_MARLIN_MIN_THREAD_K = 128\nGPTQ_MARLIN_MAX_PARALLEL = 16\n\nGPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]\nGPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]\nGPTQ_MARLIN_SUPPORTED_SYM = [True]\n\ndef is_marlin_supported():\n    return __cuda_arch[0] >= 8\n\n\ndef marlin_permute_weights(q_w, size_k, size_n, perm, tile=MARLIN_TILE):\n    assert q_w.shape == (size_k, size_n)\n    assert size_k % tile == 0, f\"size_k = {size_k}, tile = {tile}\"\n    assert size_n % tile == 0, f\"size_k = {size_n}, tile = {tile}\"\n\n    # Permute weights to 16x64 marlin tiles\n    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))\n    q_w = q_w.permute((0, 2, 1, 3))\n    q_w = q_w.reshape((size_k // tile, size_n * tile))\n\n    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)\n\n    return q_w\n\n\ndef marlin_weights(q_w, size_k, size_n, num_bits, perm):\n    # Permute\n    q_w = marlin_permute_weights(q_w, size_k, size_n, perm)\n\n    # Pack\n    pack_factor = get_pack_factor(num_bits)\n    orig_device = q_w.device\n\n    q_w = q_w.cpu().numpy().astype(numpy.uint32)\n\n    
q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),\n                           dtype=numpy.uint32)\n    for i in range(pack_factor):\n        q_packed |= q_w[:, i::pack_factor] << num_bits * i\n\n    q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device)\n\n    return q_packed\n\n\ndef marlin_permute_scales(s, size_k, size_n, group_size, scale_perm,\n                          scale_perm_single):\n    if group_size < size_k and group_size != -1:\n        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]\n    else:\n        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]\n    s = s.reshape((-1, size_n)).contiguous()\n\n    return s\n\n\ndef marlin_quantize(\n    w: torch.Tensor,\n    num_bits: int,\n    group_size: int,\n    act_order: bool,\n):\n    size_k, size_n = w.shape\n\n    # Normalize group_size\n    if group_size == -1:\n        group_size = size_k\n    assert group_size <= size_k\n\n    # Quantize (and apply act_order if provided)\n    q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size,\n                                                       act_order)\n\n    # For act_order, sort the \"weights\" and \"g_idx\" so that group ids are\n    # increasing\n    sort_indices = torch.empty(0, dtype=torch.int, device=w.device)\n    if act_order:\n        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)\n\n    # Reformat to marlin\n    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits,\n                                marlin_perm[num_bits])\n    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size,\n                                     marlin_scale_perm[num_bits],\n                                     marlin_scale_perm_single[num_bits])\n\n    # Create result\n    res_list = [marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]\n    for i in range(len(res_list)):\n        res_list[i] = res_list[i].to(w.device)\n\n    return res_list\n\n\ndef vllm_marlin_quantize(\n    w: 
torch.Tensor,\n    num_bits: int,\n    group_size: int,\n    act_order: bool,\n):\n    size_k, size_n = w.shape\n\n    # Normalize group_size\n    if group_size == -1:\n        group_size = size_k\n    assert group_size <= size_k\n\n    # Quantize (and apply act_order if provided)\n    w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size,\n                                                       act_order)\n\n    # For act_order, sort the \"weights\" and \"g_idx\" so that group ids are\n    # increasing\n    sort_indices = torch.empty(0, dtype=torch.int, device=w.device)\n    if act_order:\n        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)\n\n    # Reformat to marlin\n    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits,\n                                marlin_perm[num_bits])\n    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size,\n                                     marlin_scale_perm[num_bits],\n                                     marlin_scale_perm_single[num_bits])\n\n    # Create result\n    res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]\n    for i in range(len(res_list)):\n        res_list[i] = res_list[i].to(w.device)\n\n    return res_list\n\n\ndef inject_24(w, size_k, size_n):\n    assert w.shape == (size_k, size_n)\n\n    mask = mask_creator(w.t()).t().cuda().bool()\n\n    return (mask * w).contiguous(), mask.contiguous()\n\n\ndef check_24(w, num_rows_to_sample=50, _verbose=False):\n    BLOCK_SIZE = 4\n    MAX_NON_ZEROS = 2\n\n    w = w.t().contiguous()\n\n    print(\"check_24: w.shape = {}\".format(w.shape))\n\n    num_rows, num_cols = w.shape\n    sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)\n    if _verbose:\n        print(f\"Sampled row idxs = {sampled_row_idxs}\")\n\n    total_segments = 0\n    non_24_segments = 0\n    for i in sampled_row_idxs:\n        for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE):\n            total_segments += 1\n      
      block = w[i, j:j + BLOCK_SIZE]\n            num_nonzero = torch.count_nonzero(block)\n            if num_nonzero > MAX_NON_ZEROS:\n                print(\"i = {} j = {} block = {}\".format(i, j, block))\n                non_24_segments += 1\n\n    print(f\"{non_24_segments} / {total_segments} do not have 2:4 structure.\")\n\n\ndef compress_quantized_24_weight(q_24, size_k, size_n, num_bits):\n    assert q_24.shape == (size_k, size_n)\n\n    # Remove zp to normalize over 0\n    max_q_val = (1 << num_bits) - 1\n    zp = (max_q_val + 1) // 2\n    q_24_no_zp = q_24 - zp\n\n    # Compress\n    q_24_no_zp = q_24_no_zp.t().contiguous()\n    q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(\n        q_24_no_zp)\n    q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous()\n\n    # Restore zp\n    q_24_comp = q_24_no_zp_comp + zp\n\n    # Resize meta to its actual shape (without moving any data)\n    meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)\n\n    return q_24_comp, meta\n\n\ndef marlin_24_quantize(\n    w: torch.Tensor,\n    num_bits: int,\n    group_size: int,\n):\n    size_k, size_n = w.shape\n\n    # Normalize group_size\n    if group_size == -1:\n        group_size = size_k\n    assert group_size <= size_k\n\n    # Inject 2:4 sparsity\n    w_24, mask_24 = inject_24(w, size_k, size_n)\n\n    # Quantize\n    w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24,\n                                                             num_bits,\n                                                             group_size,\n                                                             act_order=False)\n\n    # Compress quantized weight\n    q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n,\n                                                     num_bits)\n    size_k_comp = size_k // 2\n\n    # Reformat to marlin\n    marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n,\n                                        
num_bits, marlin_24_perm[num_bits])\n    marlin_24_s = marlin_permute_scales(s, size_k, size_n, group_size,\n                                        marlin_24_scale_perm[num_bits],\n                                        marlin_24_scale_perm_single[num_bits])\n\n    # Create result\n    res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]\n    for i in range(len(res_list)):\n        res_list[i] = res_list[i].to(w.device)\n\n    return res_list\n\n\ndef compute_max_diff(output, output_ref):\n    # Despite the name, this returns the mean absolute difference\n    # normalized by the mean magnitude of the reference, not a maximum.\n    return torch.mean(torch.abs(output - output_ref)) / torch.mean(\n        torch.abs(output_ref))\n\n\nclass MarlinWorkspace:\n\n    def __init__(self, out_features, min_thread_n, max_parallel, device):\n        assert (out_features % min_thread_n == 0), (\n            \"out_features = {} is not divisible by min_thread_n = {}\".format(\n                out_features, min_thread_n))\n\n        max_workspace_size = ((out_features // min_thread_n) * max_parallel)\n\n        self.scratch = torch.zeros(max_workspace_size,\n                                   dtype=torch.int,\n                                   device=device)\n"
  },
  {
    "path": "archive/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/quant_utils.py",
    "content": "\"\"\"This file is used for /tests and /benchmarks\"\"\"\nimport numpy\nimport torch\n\nSUPPORTED_NUM_BITS = [4, 8]\nSUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]\n\n\ndef get_pack_factor(num_bits):\n    assert num_bits in SUPPORTED_NUM_BITS, f\"Unsupported num_bits = {num_bits}\"\n    return 32 // num_bits\n\n\ndef permute_rows(q_w: torch.Tensor, group_size: int):\n\n    orig_device = q_w.device\n    k_size, _ = q_w.shape\n\n    g_idx = torch.zeros((k_size, ), dtype=torch.int32)\n    for i in range(k_size):\n        g_idx[i] = i // group_size\n\n    # Simulate act_order by doing a random permutation on K\n    rand_perm = torch.randperm(k_size)\n\n    g_idx = g_idx[rand_perm].contiguous()\n    q_w = q_w[rand_perm, :].contiguous()\n\n    return (\n        q_w.to(device=orig_device),\n        g_idx.to(device=orig_device),\n        rand_perm.to(device=orig_device),\n    )\n\n\ndef quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,\n                     act_order: bool):\n    orig_device = w.device\n    size_k, size_n = w.shape\n\n    assert w.is_floating_point(), \"w must be float\"\n    assert num_bits in SUPPORTED_NUM_BITS, f\"Unsupported num_bits = {num_bits}\"\n    assert group_size in SUPPORTED_GROUP_SIZES + [\n        size_k\n    ], f\"Unsupported groupsize = {group_size}\"\n\n    if group_size == -1:\n        group_size = size_k\n    assert group_size <= size_k\n\n    max_q_val = 2**num_bits - 1\n    half_q_val = (max_q_val + 1) // 2\n\n    # Reshape to [groupsize, -1]\n    if group_size < size_k:\n        w = w.view((-1, group_size, size_n))\n        w = w.permute(1, 0, 2)\n        w = w.reshape((group_size, -1))\n\n    # Compute scale for each group\n    s = torch.max(torch.abs(w), 0, keepdim=True)[0]\n    s *= 2 / max_q_val  # 2 => symmetric\n\n    # Quantize\n    q_w = torch.round(w / s).int()\n    q_w += half_q_val\n    q_w = torch.clamp(q_w, 0, max_q_val)\n\n    # Restore original shapes\n    if group_size < size_k:\n\n        
def reshape_w(w):\n            w = w.reshape((group_size, -1, size_n))\n            w = w.permute(1, 0, 2)\n            w = w.reshape((size_k, size_n)).contiguous()\n            return w\n\n        q_w = reshape_w(q_w)\n\n    s = s.reshape((-1, size_n)).contiguous()\n\n    # Apply act_order\n    g_idx = torch.empty(0, dtype=torch.int, device=w.device)\n    rand_perm = torch.empty(0, dtype=torch.int, device=w.device)\n    if act_order:\n        assert (\n            group_size < size_k\n        ), \"For act_order, groupsize = {} must be less than size_k = {}\".format(\n            group_size, size_k)\n\n        q_w, g_idx, rand_perm = permute_rows(q_w, group_size)\n\n    return (\n        q_w.to(device=orig_device),\n        s.to(device=orig_device),\n        g_idx.to(device=orig_device),\n        rand_perm.to(device=orig_device),\n    )\n\n\ndef sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):\n    orig_device = q_w.device\n\n    sort_indices = torch.argsort(g_idx).to(\n        dtype=torch.int32)  # Sort based on g_idx\n\n    g_idx = g_idx[sort_indices].contiguous()\n    q_w = q_w[sort_indices, :].contiguous()\n\n    return (\n        q_w.to(device=orig_device),\n        g_idx.to(device=orig_device),\n        sort_indices.to(device=orig_device),\n    )\n\n\ndef gptq_pack(\n    q_w: torch.Tensor,\n    num_bits: int,\n    size_k: int,\n    size_n: int,\n):\n    assert q_w.shape == (size_k, size_n)\n\n    pack_factor = get_pack_factor(num_bits)\n    assert size_k % pack_factor == 0\n\n    orig_device = q_w.device\n\n    q_w = q_w.cpu().numpy().astype(numpy.uint32)\n\n    q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)\n\n    for i in range(pack_factor):\n        q_res |= q_w[i::pack_factor, :] << num_bits * i\n\n    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)\n    return q_res\n"
  },
  {
    "path": "archive/ktransformers/ktransformers_ext/triton/fp8gemm.py",
    "content": "# Adopted from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py\nfrom typing import Tuple\n\nimport torch\nimport triton\nimport triton.language as tl\nfrom triton import Config\n\n\n@triton.jit\ndef act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):\n    \"\"\"\n    Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.\n\n    Args:\n        x_ptr (triton.Pointer): Pointer to the input tensor.\n        y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored.\n        s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored.\n        BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance.\n\n    Returns:\n        None\n    \"\"\"\n    pid = tl.program_id(axis=0)\n    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    x = tl.load(x_ptr + offs).to(tl.float32)\n    s = tl.max(tl.abs(x)) / 448.\n    y = x / s\n    y = y.to(y_ptr.dtype.element_ty)\n    tl.store(y_ptr + offs, y)\n    tl.store(s_ptr + pid, s)\n\n\ndef act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Quantizes the input tensor `x` using block-wise quantization.\n\n    Args:\n        x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.\n        block_size (int, optional): The size of the blocks to be used for quantization. 
Default is 128.\n\n    Returns:\n        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:\n            - The quantized tensor with dtype `torch.float8_e4m3fn`.\n            - A tensor of scaling factors with dtype `torch.float32`.\n    \"\"\"\n    assert x.is_contiguous(), 'Input tensor must be contiguous'\n    assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})'\n    y = torch.empty_like(x, dtype=torch.float8_e4m3fn)\n    s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)\n    grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )\n    act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)\n    return y, s\n\n\n@triton.jit\ndef weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):\n    \"\"\"\n    Dequantizes weights using the provided scaling factors and stores the result.\n\n    Args:\n        x_ptr (tl.pointer): Pointer to the quantized weights.\n        s_ptr (tl.pointer): Pointer to the scaling factors.\n        y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.\n        M (int): Number of rows in the weight matrix.\n        N (int): Number of columns in the weight matrix.\n        BLOCK_SIZE (tl.constexpr): Size of the block for tiling.\n\n    Returns:\n        None\n    \"\"\"\n    pid_m = tl.program_id(axis=0)\n    pid_n = tl.program_id(axis=1)\n    n = tl.cdiv(N, BLOCK_SIZE)\n    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    offs = offs_m[:, None] * N + offs_n[None, :]\n    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)\n    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)\n    s = tl.load(s_ptr + pid_m * n + pid_n)\n    y = x * s\n    tl.store(y_ptr + offs, y, mask=mask)\n\n\ndef weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:\n    \"\"\"\n    Dequantizes the given weight tensor 
using the provided scale tensor.\n\n    Args:\n        x (torch.Tensor): The quantized weight tensor of shape (M, N).\n        s (torch.Tensor): The scale tensor, one factor per `block_size` x `block_size` tile, of shape (ceil(M / block_size), ceil(N / block_size)).\n        block_size (int, optional): The block size to use for dequantization. Defaults to 128.\n\n    Returns:\n        torch.Tensor: The dequantized weight tensor of the same shape as `x`.\n\n    Raises:\n        AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.\n    \"\"\"\n    assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous'\n    assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions'\n    M, N = x.size()\n    y = torch.empty_like(x, dtype=torch.get_default_dtype())\n    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))\n    with torch.cuda.device(x.device):\n        weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)\n    return y\n\n\nfp8_gemm_configs = [\n    Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)\n    for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]\n]\n\n@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])\n@triton.jit\ndef fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,\n                    a_s_ptr, b_s_ptr,\n                    M, N: tl.constexpr, K: tl.constexpr,\n                    BLOCK_SIZE_M: tl.constexpr,\n                    BLOCK_SIZE_N: tl.constexpr,\n                    BLOCK_SIZE_K: tl.constexpr):\n    \"\"\"\n    Performs a matrix multiplication operation on FP8 matrices with scaling factors.\n\n    Args:\n        a_ptr (tl.tensor): Pointer to the first input matrix A.\n        b_ptr (tl.tensor): Pointer to the second input matrix B.\n        c_ptr (tl.tensor): Pointer to the output matrix C.\n        a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A.\n        b_s_ptr (tl.tensor): Pointer to the 
scaling factors for matrix B.\n        M (int): Number of rows in matrix A and C.\n        N (tl.constexpr): Number of columns in matrix B and C.\n        K (tl.constexpr): Number of columns in matrix A and rows in matrix B.\n        BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.\n        BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.\n        BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.\n\n    Returns:\n        None\n    \"\"\"\n    pid_m = tl.program_id(axis=0)\n    pid_n = tl.program_id(axis=1)\n    k = tl.cdiv(K, BLOCK_SIZE_K)\n    offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n    offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]\n    b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]\n    a_s_ptrs = a_s_ptr + offs_m * k\n    b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k\n\n    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n    for i in range(k):\n        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)\n        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)\n        a_s = tl.load(a_s_ptrs)\n        b_s = tl.load(b_s_ptrs)\n        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]\n        a_ptrs += BLOCK_SIZE_K\n        b_ptrs += BLOCK_SIZE_K\n        a_s_ptrs += 1\n        b_s_ptrs += 1\n    c = accumulator.to(c_ptr.dtype.element_ty)\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]\n    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)\n    tl.store(c_ptrs, c, mask=mask)\n\n\ndef fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):\n    \"\"\"\n    Perform a matrix multiplication using FP8 precision.\n\n    Args:\n      
  a (torch.Tensor): The first input matrix, must be contiguous.\n        a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.\n        b (torch.Tensor): The second input matrix, must be contiguous.\n        b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.\n\n    Returns:\n        torch.Tensor: The result of the matrix multiplication.\n    \"\"\"\n    assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous'\n    assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous'\n    K = a.size(-1)\n    M = a.numel() // K\n    N = b.size(0)\n    c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())\n    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))\n    fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)\n    return c"
  },
  {
    "path": "archive/ktransformers/local_chat.py",
    "content": "\"\"\"\nDescription  :  \nAuthor       : Boxin Zhang, Azure-Tang\nVersion      : 0.1.0\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n\"\"\"\n\nimport os\nimport platform\nimport sys\n\nproject_dir = os.path.dirname(os.path.dirname(__file__))\nsys.path.insert(0, project_dir)\nimport torch\ntry:\n    import torch_npu\n    from torch_npu.contrib import transfer_to_npu\n    from ktransformers.util.ascend.ascend_utils import get_absort_weight, setup_model_parallel, get_tensor_parallel_group\n    from ktransformers.util import utils, npu_graph_runner\nexcept:\n    pass\nimport torch.distributed as dist\n\nimport logging\nfrom transformers import (\n    AutoTokenizer,\n    AutoConfig,\n    AutoModelForCausalLM,\n    GenerationConfig,\n    TextStreamer,\n)\nimport json\nimport fire\nfrom ktransformers.optimize.optimize import optimize_and_load_gguf\nfrom ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM\nfrom ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM\nfrom ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM\nfrom ktransformers.models.modeling_llama import LlamaForCausalLM\nfrom ktransformers.models.modeling_mixtral import MixtralForCausalLM\nfrom ktransformers.util.utils import prefill_and_generate, get_compute_capability, xpu_fp16_model\nfrom ktransformers.util import utils\nfrom ktransformers.models.custom_cache import StaticCache\nfrom ktransformers.server.config.config import Config\nfrom ktransformers.operators.flashinfer_wrapper import flashinfer_enabled\nfrom ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor\n\ncustom_models = {\n    \"DeepseekV2ForCausalLM\": DeepseekV2ForCausalLM,\n    \"DeepseekV3ForCausalLM\": DeepseekV3ForCausalLM,\n    \"Qwen2MoeForCausalLM\": Qwen2MoeForCausalLM,\n    \"LlamaForCausalLM\": LlamaForCausalLM,\n    \"MixtralForCausalLM\": MixtralForCausalLM,\n}\n\nktransformer_rules_dir = (\n    
os.path.dirname(os.path.abspath(__file__)) + \"/optimize/optimize_rules/\"\n)\ndefault_optimize_rules = {\n    \"DeepseekV2ForCausalLM\": ktransformer_rules_dir + \"DeepSeek-V2-Chat.yaml\",\n    \"DeepseekV3ForCausalLM\": ktransformer_rules_dir + \"DeepSeek-V3-Chat.yaml\",\n    \"Qwen2MoeForCausalLM\": ktransformer_rules_dir + \"Qwen2-57B-A14B-Instruct.yaml\",\n    \"LlamaForCausalLM\": ktransformer_rules_dir + \"Internlm2_5-7b-Chat-1m.yaml\",\n    \"MixtralForCausalLM\": ktransformer_rules_dir + \"Mixtral.yaml\",\n}\n\ntry:\n    torch.npu.config.allow_internal_format = True\n    torch.npu.set_compile_mode(jit_compile=False)\nexcept:\n    pass\n\nimport sys, signal, faulthandler\nfaulthandler.register(signal.SIGUSR1, file=sys.stderr, all_threads=True, chain=False)\n\n\ndef local_chat(\n    model_path: str | None = None,\n    optimize_config_path: str = None,\n    gguf_path: str | None = None,\n    max_new_tokens: int = 1000,\n    cpu_infer: int = Config().cpu_infer,\n    use_cuda_graph: bool = True,\n    prompt_file : str | None = None,\n    mode: str = \"normal\",\n    force_think: bool = False,\n    chunk_size: int = 8192,\n    device: str = \"cuda\",\n    tp: int = 1,\n):\n    Config().cpu_infer = cpu_infer\n\n    local_rank, world_size = setup_model_parallel(tp=tp)\n    torch.set_grad_enabled(False)\n    if utils.CUR_DEVICE is None:\n        utils.CUR_DEVICE = f\"npu:{torch.npu.current_device()}\"\n\n    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)\n    config.chunk_size = chunk_size\n    npu_graph_runner.LAYER_ID = config.num_hidden_layers\n    if mode == 'long_context':\n        assert config.architectures[0] == \"LlamaForCausalLM\", \"only LlamaForCausalLM support long_context mode\"\n        torch.set_default_dtype(torch.float16)\n    else:\n        torch.set_default_dtype(config.torch_dtype)\n\n    with torch.device(\"meta\"):\n        if 
config.architectures[0] in custom_models:\n            print(\"using custom modeling_xxx.py.\")\n            if (\n                \"Qwen2Moe\" in config.architectures[0]\n            ):  # Qwen2Moe must use flash_attention_2 to avoid overflow.\n                config._attn_implementation = \"flash_attention_2\"\n            if \"Llama\" in config.architectures[0]:\n                config._attn_implementation = \"eager\"\n            if \"Mixtral\" in config.architectures[0]:\n                config._attn_implementation = \"flash_attention_2\"\n\n            model = custom_models[config.architectures[0]](config)\n        else:\n            model = AutoModelForCausalLM.from_config(\n                config, trust_remote_code=True, attn_implementation=\"flash_attention_2\"\n            )\n\n    if optimize_config_path is None:\n        if config.architectures[0] in default_optimize_rules:\n            print(\"using default_optimize_rule for\", config.architectures[0]) if local_rank == 0 else None\n            optimize_config_path = default_optimize_rules[config.architectures[0]]\n            print(f'{optimize_config_path=}') if local_rank == 0 else None\n        else:\n            optimize_config_path = input(\n                \"please input the path of your rule file(yaml file containing optimize rules):\"\n            )\n\n    if gguf_path is None:\n        gguf_path = input(\n            \"please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):\"\n        )\n    optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)\n    # Absorb the weights ahead of time\n    get_absort_weight(model, config)\n\n    try:\n        model.generation_config = GenerationConfig.from_pretrained(model_path)\n    except Exception as e:\n        print(f\"generation config can't auto create, make default. 
Message: {e}\")\n        gen_config = GenerationConfig(\n            temperature=0.6,\n            top_p=0.95,\n            do_sample=True\n        )\n        model.generation_config = gen_config\n    # model.generation_config = GenerationConfig.from_pretrained(model_path)\n    if model.generation_config.pad_token_id is None:\n        model.generation_config.pad_token_id = model.generation_config.eos_token_id\n    model.eval()\n    logging.basicConfig(level=logging.INFO)\n\n    system = platform.system()\n    if system == \"Windows\":\n        os.system(\"cls\") if local_rank == 0 else None\n    else:\n        os.system(\"clear\") if local_rank == 0 else None\n\n    print(f\"{model=}\") if local_rank == 0 else None\n\n    batch_size, seq_length = 1, 16384  # default cache pool params\n    device_map = model.gguf_loader.tensor_device_map\n    static_cache = StaticCache(\n        config = model.config, max_batch_size = batch_size, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype\n    )\n\n    torch.distributed.barrier()\n    while True:\n        if local_rank == 0:\n            try:\n                content = input(\"Chat: \\n\").strip()\n            except KeyboardInterrupt:\n                dist.barrier()\n                print('Exit all rank with KeyboardInterrupt!')\n                sys.exit(0)\n            if content.startswith('\"\"\"'):  # prefix \"\"\"\n                # multi lines input\n                content = content[3:] + \"\\n\"\n                while True:\n                    line = input(\"\")\n                    if line.endswith('\"\"\"'):\n                        # end multi lines input\n                        line = line[:-3]  # suffix \"\"\"\n                        if line:\n                            content += line + \"\\n\"\n                        break\n                    else:\n                        content += line + \"\\n\"\n\n            if content == \"\":\n                if prompt_file != 
None:\n                    content = open(prompt_file, \"r\").read()\n                else:\n                    continue\n            elif os.path.isfile(content):\n                f = open(content, \"r\")\n                content = f.readlines()\n                f.close()\n            else:\n                content = [f\"{len(content)},{max_new_tokens},{content}\"]\n        else:\n            content = [\"\"]\n\n        for line in content:\n            content_tensor = torch.tensor(bytearray(line.encode()), dtype=torch.uint8).to(device=utils.CUR_DEVICE)\n            if world_size > 1:\n                content_size = torch.tensor(len(content_tensor), dtype=torch.int64).to(device=utils.CUR_DEVICE)\n                all_content_sizes = [torch.zeros((1,), dtype=torch.int64).to(device=utils.CUR_DEVICE) for _ in range(world_size)]\n                dist.all_gather(all_content_sizes, content_size)\n                max_content_size = max([size.item() for size in all_content_sizes])\n\n                padded_content_tensor = torch.zeros((max_content_size,), dtype=torch.uint8).to(device=utils.CUR_DEVICE)\n                padded_content_tensor[:len(content_tensor)] = content_tensor\n\n                all_content_tensors = [torch.zeros((max_content_size,), dtype=torch.uint8).to(device=utils.CUR_DEVICE) for _ in range(world_size)]\n                dist.all_gather(all_content_tensors, padded_content_tensor)\n                content_tensor = all_content_tensors[0][:all_content_sizes[0].item()]\n                line = bytes(content_tensor.cpu().numpy()).decode()\n\n            parts = line.split(\",\")\n            input_tokens = int(parts[0])\n            max_new_tokens = int(parts[1])\n            line = line[line.index(\",\", line.index(\",\") + 1) + 1:]\n            \n            messages = [{\"role\": \"user\", \"content\": line}]\n            input_tensor = tokenizer.apply_chat_template(\n                messages, add_generation_prompt=True, return_tensors=\"pt\"\n          
  )\n            if force_think:\n                token_thinks = torch.tensor([tokenizer.encode(\"<think>\\\\n\",add_special_tokens=False)],device=input_tensor.device)\n                input_tensor = torch.cat(\n                    [input_tensor, token_thinks], dim=1\n                )\n            if mode == 'long_context':\n                assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \\\n                \"please change max_seq_len in  ~/.ktransformers/config.yaml\"\n\n            if system != \"Windows\" and (config.architectures[0] == \"DeepseekV2ForCausalLM\" or config.architectures[0] == \"DeepseekV3ForCausalLM\") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA:\n                generated = prefill_and_generate(\n                    model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size,\n                    use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim,\n                    static_cache=static_cache\n                )\n            else:\n                generated = prefill_and_generate(\n                    model, tokenizer, input_tensor.to(device=utils.CUR_DEVICE), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size,\n                    static_cache=static_cache\n                )\n\n\nif __name__ == \"__main__\":\n    fire.Fire(local_chat)\n"
  },
  {
    "path": "archive/ktransformers/local_chat_test.py",
    "content": "\"\"\"\nDescription  :  \nAuthor       : Boxin Zhang, Azure-Tang\nVersion      : 0.1.0\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n\"\"\"\n\nimport os\nimport platform\nimport sys\n\nproject_dir = os.path.dirname(os.path.dirname(__file__))\nsys.path.insert(0, project_dir)\nimport torch\nimport logging\nfrom transformers import (\n    AutoTokenizer,\n    AutoConfig,\n    AutoModelForCausalLM,\n    GenerationConfig,\n    TextStreamer,\n)\nimport json\nimport fire\nfrom ktransformers.optimize.optimize import optimize_and_load_gguf\nfrom ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM\nfrom ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM\nfrom ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM\nfrom ktransformers.models.modeling_llama import LlamaForCausalLM\nfrom ktransformers.models.modeling_mixtral import MixtralForCausalLM\nfrom ktransformers.util.utils import prefill_and_generate, get_compute_capability\nfrom ktransformers.server.config.config import Config\nfrom ktransformers.operators.flashinfer_wrapper import flashinfer_enabled\n\ncustom_models = {\n    \"DeepseekV2ForCausalLM\": DeepseekV2ForCausalLM,\n    \"DeepseekV3ForCausalLM\": DeepseekV3ForCausalLM,\n    \"Qwen2MoeForCausalLM\": Qwen2MoeForCausalLM,\n    \"LlamaForCausalLM\": LlamaForCausalLM,\n    \"MixtralForCausalLM\": MixtralForCausalLM,\n}\n\nktransformer_rules_dir = (\n    os.path.dirname(os.path.abspath(__file__)) + \"/optimize/optimize_rules/\"\n)\ndefault_optimize_rules = {\n    \"DeepseekV2ForCausalLM\": ktransformer_rules_dir + \"DeepSeek-V2-Chat.yaml\",\n    \"DeepseekV3ForCausalLM\": ktransformer_rules_dir + \"DeepSeek-V3-Chat.yaml\",\n    \"Qwen2MoeForCausalLM\": ktransformer_rules_dir + \"Qwen2-57B-A14B-Instruct.yaml\",\n    \"LlamaForCausalLM\": ktransformer_rules_dir + \"Internlm2_5-7b-Chat-1m.yaml\",\n    \"MixtralForCausalLM\": ktransformer_rules_dir + \"Mixtral.yaml\",\n}\n\n\ndef local_chat(\n    
model_path: str | None = None,\n    optimize_config_path: str = None,\n    gguf_path: str | None = None,\n    max_new_tokens: int = 1000,\n    cpu_infer: int = Config().cpu_infer,\n    use_cuda_graph: bool = True,\n    prompt_file : str | None = None,\n    mode: str = \"normal\",\n    force_think: bool = False,\n    chunk_prefill_size: int = 8192\n):\n\n    torch.set_grad_enabled(False)\n\n    Config().cpu_infer = cpu_infer\n\n    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)\n    if mode == 'long_context':\n        assert config.architectures[0] == \"LlamaForCausalLM\", \"only LlamaForCausalLM support long_context mode\"\n        torch.set_default_dtype(torch.float16)\n    else:\n        torch.set_default_dtype(config.torch_dtype)\n\n    with torch.device(\"meta\"):\n        if config.architectures[0] in custom_models:\n            print(\"using custom modeling_xxx.py.\")\n            if (\n                \"Qwen2Moe\" in config.architectures[0]\n            ):  # Qwen2Moe must use flash_attention_2 to avoid overflow.\n                config._attn_implementation = \"flash_attention_2\"\n            if \"Llama\" in config.architectures[0]:\n                config._attn_implementation = \"eager\"\n            if \"Mixtral\" in config.architectures[0]:\n                config._attn_implementation = \"flash_attention_2\"\n\n            model = custom_models[config.architectures[0]](config)\n        else:\n            model = AutoModelForCausalLM.from_config(\n                config, trust_remote_code=True, attn_implementation=\"flash_attention_2\"\n            )\n\n    if optimize_config_path is None:\n        if config.architectures[0] in default_optimize_rules:\n            print(\"using default_optimize_rule for\", config.architectures[0])\n            optimize_config_path = default_optimize_rules[config.architectures[0]]\n        else:\n            
optimize_config_path = input(\n                \"please input the path of your rule file(yaml file containing optimize rules):\"\n            )\n\n    if gguf_path is None:\n        gguf_path = input(\n            \"please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):\"\n        )\n    optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)\n    \n    try:\n        model.generation_config = GenerationConfig.from_pretrained(model_path)\n    except Exception as e:\n        print(f\"generation config can't auto create, make default. Message: {e}\")\n        gen_config = GenerationConfig(\n            temperature=0.6,\n            top_p=0.95,\n            do_sample=True\n        )\n        model.generation_config = gen_config\n    # model.generation_config = GenerationConfig.from_pretrained(model_path)\n    if model.generation_config.pad_token_id is None:\n        model.generation_config.pad_token_id = model.generation_config.eos_token_id\n    model.eval()\n    logging.basicConfig(level=logging.INFO)\n\n    system = platform.system()\n    if system == \"Windows\":\n        os.system(\"cls\")\n    else:\n        os.system(\"clear\")\n\n    if prompt_file != None:\n        assert os.path.isfile(prompt_file), \"prompt file not exist\"\n        print(f\"prompt file is {prompt_file}\")\n        content = open(prompt_file, \"r\").read()\n    else:\n        content = \"Please write a piece of quicksort code in C++.\"\n\n    print('Start Testing...(1 round)')\n    print('Prompt:', content)\n\n    while True:\n        messages = [{\"role\": \"user\", \"content\": content}]\n        input_tensor = tokenizer.apply_chat_template(\n            messages, add_generation_prompt=True, return_tensors=\"pt\"\n        )\n        if force_think:\n            token_thinks = torch.tensor([tokenizer.encode(\"<think>\\\\n\",add_special_tokens=False)],device=input_tensor.device)\n            input_tensor = 
torch.cat(\n                [input_tensor, token_thinks], dim=1\n            )\n        if mode == 'long_context':\n            assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \\\n            \"please change max_seq_len in  ~/.ktransformers/config.yaml\"\n        \n        if system != \"Windows\" and (config.architectures[0] == \"DeepseekV2ForCausalLM\" or config.architectures[0] == \"DeepseekV3ForCausalLM\") and flashinfer_enabled and get_compute_capability() >= 8:\n            generated = prefill_and_generate(\n                model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,\n                use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim\n            )\n        else:\n            generated = prefill_and_generate(\n                model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_prefill_size = chunk_prefill_size,\n            )\n        break\n\nif __name__ == \"__main__\":\n    fire.Fire(local_chat)\n"
  },
  {
    "path": "archive/ktransformers/models/__init__.py",
    "content": ""
  },
  {
    "path": "archive/ktransformers/models/ascend/custom_ascend_modeling_deepseek_v3.py",
    "content": "\"\"\"\r\nDate: 2024-11-06 10:05:11\r\nLastEditors: djw\r\nLastEditTime: 2024-11-13 07:50:51\r\n\"\"\"\r\n\r\nimport math\r\nfrom dataclasses import dataclass\r\nimport torch\r\nimport torch.nn as nn\r\nfrom torch.nn import functional as F\r\nimport torch_npu\r\nimport math\r\nfrom typing import List, Optional, Tuple, Union\r\nimport torch\r\nimport torch.utils.checkpoint\r\nfrom torch import nn\r\nfrom ktransformers.server.config.config import Config\r\n\r\nfrom ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput\r\nfrom ktransformers.models.custom_cache import KVC2StaticCache\r\nfrom ktransformers.models.modeling_deepseek_v3 import DeepseekV3Model,  DeepseekV3PreTrainedModel\r\nfrom ktransformers.models.configuration_deepseek_v3 import DeepseekV3Config\r\nimport ktransformers.util.utils as utils\r\n\r\n\r\ntorch.set_grad_enabled(False)\r\ntorch.set_default_dtype(torch.float16)\r\n\r\n\r\nclass KNPUDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):\r\n\r\n    # cache: KVC2StaticCache\r\n    use_cuda_graph = False\r\n\r\n    def __init__(\r\n        self,\r\n        config: DeepseekV3Config,\r\n        stream = None,\r\n        default_type=torch.float16\r\n    ):\r\n        super().__init__(config)\r\n        self.model = DeepseekV3Model(config)\r\n        self.config = config\r\n        self.config.backend_type = \"balance_serve\"\r\n        # self.cache = cache\r\n        self.vocab_size = config.vocab_size\r\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\r\n        self.default_type = default_type\r\n        self.stream = torch_npu.npu.current_stream() if stream is None else stream\r\n        self.para_stream = torch_npu.npu.Stream()\r\n        self.call_stream = torch_npu.npu.Stream()\r\n        \r\n    def init_wrapper(self, use_cuda_graph, device, max_batch_size, max_pages):\r\n        print('[WARN] this custom modeling do not support flash infer, skip 
this part...')\r\n\r\n    def batch_embeddings(self, batch: ForwardBatchInput, device=\"npu:0\", is_prefill=True):\r\n        features = []\r\n        if is_prefill:\r\n            start_ids = 0\r\n            seq_lens = []\r\n            for i in range(batch.minibatch.prefill_batch):\r\n                assert batch.minibatch.p_kv_len[i] == batch.minibatch.p_q_len[i], \\\r\n                    \"[ERROR] current prefill do not support chunk or prefix cache\"\r\n                tokens = batch.minibatch.p_tokens[start_ids: start_ids+batch.minibatch.p_q_len[i]].contiguous()\r\n                start_ids += batch.minibatch.p_q_len[i]\r\n                feature = (\r\n                    self.model.embed_tokens(tokens.to(torch.device('cpu')))\r\n                    .to(self.default_type)\r\n                    .to(device=device)\r\n                )\r\n                features.append(feature)\r\n                seq_lens.append(feature.shape[0])\r\n\r\n            max_seq_len = max(seq_lens) if seq_lens else 0\r\n\r\n            padded_features = []\r\n            for feat in features:\r\n                curr_len = feat.shape[0]\r\n                if curr_len < max_seq_len:\r\n                    pad_len = max_seq_len - curr_len\r\n                    padded_feat = torch.nn.functional.pad(\r\n                        feat,\r\n                        (0, 0, 0, pad_len),\r\n                        mode='constant',\r\n                        value=0.0\r\n                    )\r\n                    padded_features.append(padded_feat)\r\n                else:\r\n                    padded_features.append(feat)\r\n\r\n            features_t = torch.stack(padded_features)\r\n\r\n        else:\r\n            for i in range(batch.minibatch.decode_batch):\r\n                if batch.minibatch.d_tokens.dim() == 1:\r\n                    tokens = batch.minibatch.d_tokens.contiguous()\r\n                else:\r\n                    tokens = 
batch.minibatch.d_tokens[i].contiguous()\r\n\r\n                feature = (\r\n                    self.model.embed_tokens(tokens.to(torch.device('cpu')))\r\n                    .to(self.default_type)\r\n                    .to(device=device)\r\n                )\r\n                features.append(feature)\r\n\r\n            features_t = torch.stack(features)\r\n        return features_t\r\n\r\n    def print_callback(self, param):\r\n        with torch.npu.stream(self.call_stream):\r\n            hidden_states = param[0]\r\n            print(\"########################################\")\r\n            print(\"hidden_states is \", hidden_states)\r\n            print(\"########################################\")\r\n\r\n    def forward(\r\n        self,\r\n        batch: ForwardBatchInput | None = None,\r\n        features: torch.Tensor | None = None,\r\n        past_key_value: KVC2StaticCache | None = None,\r\n        bsz_tensors: torch.Tensor | None = None,\r\n        num_tokens_tensors: torch.Tensor | None = None,\r\n        page_idx: torch.Tensor | None = None,\r\n        page_offset: torch.Tensor | None = None,\r\n        position_ids: torch.Tensor | None = None,\r\n        block_tables: torch.Tensor | None = None,\r\n        cuda_graph_idx: int | None = -1,\r\n        is_prefill: bool = True\r\n    ) -> ForwardBatchOutput:\r\n        # NPU use direct block table from ForwardBatchInput instead of page_idx & page_offset\r\n\r\n        if features.ndim == 2:\r\n            hidden_states = features.unsqueeze(0)\r\n        elif features.ndim == 1:\r\n            hidden_states = features.unsqueeze(0).unsqueeze(0)  # (bsz, seqlen, hidden)\r\n        else:\r\n            hidden_states = features\r\n\r\n        (bsz, q_len, hidden_size) = hidden_states.shape\r\n\r\n        if is_prefill:\r\n            position_ids = -1 * torch.ones(bsz, q_len).to(batch.minibatch.p_position_ids.device)\r\n            bsz_real = 
torch.zeros(bsz).to(batch.minibatch.p_position_ids.device)\r\n            # convert merged into batched\r\n            start_ids = 0\r\n            for i, qlen in enumerate(batch.minibatch.p_q_len):\r\n                position_ids[i, 0:qlen] = batch.minibatch.p_position_ids[start_ids:start_ids+qlen]\r\n                start_ids += qlen\r\n                bsz_real[i] = qlen\r\n            block_tables = batch.minibatch.p_block_tables\r\n            kv_len = batch.minibatch.p_kv_len[0]\r\n            q_len_raw = batch.minibatch.p_q_len\r\n            kv_len_raw = batch.minibatch.p_kv_len\r\n        else:\r\n            position_ids = batch.minibatch.d_position_ids\r\n            if len(position_ids.shape) == 1:\r\n                position_ids = position_ids.unsqueeze(0)\r\n            block_tables = batch.minibatch.d_block_tables\r\n            kv_len = batch.minibatch.d_kv_len[0]\r\n            q_len_raw = None\r\n            kv_len_raw = batch.minibatch.d_kv_len_list\r\n            bsz_real = None\r\n\r\n        for i, decode_layer in enumerate(self.model.layers):\r\n            residual = hidden_states\r\n            hidden_states = decode_layer.input_layernorm(hidden_states)\r\n\r\n            # generate chunk_mask automatically.\r\n            if is_prefill:\r\n                attn_mask = -65504.0 * torch.triu(torch.ones(q_len, kv_len, device=hidden_states.device), diagonal=1)\r\n                attn_mask = attn_mask.unsqueeze(0).unsqueeze(0) # (bsz, 1, q_len, kv_len)\r\n                if bsz > 1:\r\n                    attn_mask = attn_mask.expand(bsz, attn_mask.shape[1], attn_mask.shape[2], attn_mask.shape[3])\r\n            else:\r\n                attn_mask = None\r\n            # print_ex(f\"####: before self_attn of layer {i}...\")\r\n            hidden_states, _, _ = decode_layer.self_attn(hidden_states,\r\n                                                            position_ids=position_ids,\r\n                                                            
attention_mask=attn_mask,\r\n                                                            past_key_value=past_key_value,\r\n                                                            num_tokens_tensors=num_tokens_tensors,\r\n                                                            page_idx=page_idx,\r\n                                                            page_offset=page_offset,\r\n                                                            block_table=block_tables,\r\n                                                            q_len_raw=q_len_raw,\r\n                                                            kv_len_raw=kv_len_raw,\r\n                                                            is_prefill=is_prefill,\r\n                                                            stream = self.stream,\r\n                                                            )\r\n            hidden_states = residual + hidden_states\r\n            # mlp\r\n            residual = hidden_states\r\n            hidden_states = decode_layer.post_attention_layernorm(hidden_states)\r\n            # print_ex(f\"####: before mlp of layer {i}...\")\r\n            hidden_states = decode_layer.mlp(hidden_states, self.stream, self.para_stream)\r\n            hidden_states = hidden_states.squeeze(0)\r\n            hidden_states = residual + hidden_states\r\n        # print_ex(f\"####: fill output...\")\r\n        forward_batch_output = ForwardBatchOutput()\r\n        # with torch_npu.npu.stream(self.stream):\r\n        hidden_states_without_norm = hidden_states.clone()\r\n        local_logit = self.lm_head(self.model.norm(hidden_states))\r\n        for bsz in range(local_logit.size(0)):\r\n            if bsz_real is not None:\r\n                index = int(bsz_real[bsz].item())\r\n                result = local_logit[bsz][:index]\r\n            else:\r\n                result = local_logit[bsz]\r\n            forward_batch_output.logits.append(result)\r\n            
forward_batch_output.pre_hidden_states.append(hidden_states_without_norm[bsz])\r\n        return forward_batch_output\r\n\r\n    def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors,\r\n        num_heads: int,\r\n        head_dim_ckv: int,\r\n        head_dim_kpe: int,\r\n        page_size: int,\r\n        causal: bool,\r\n        sm_scale: float,\r\n        q_data_type: torch.dtype,\r\n        kv_data_type: torch.dtype,):\r\n        print('[WARN] this custom modeling do not support flash infer, skip this part...')\r\n"
  },
  {
    "path": "archive/ktransformers/models/ascend/custom_ascend_modeling_qwen3.py",
    "content": "# coding=utf-8\n# Copyright (c) 2025. Huawei Technologies Co., Ltd. All rights reserved.\n# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport math\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nimport torch_npu\nfrom dataclasses import dataclass\nfrom torch.nn import functional as F\nimport torch.utils.checkpoint\n\nfrom ktransformers.server.config.config import Config\nfrom ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput\nfrom ktransformers.models.custom_cache import KVC2Qwen3Cache\nfrom ktransformers.models.modeling_qwen3_moe import Qwen3MoePreTrainedModel, Qwen3MoeModel\nfrom ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig\nimport ktransformers.util.utils as utils\nfrom ktransformers.operators.ascend.ascend_layernorm import KQwen3FinalRMSNormNPU\n\ntorch.set_grad_enabled(False)\ntorch.set_default_dtype(torch.float16)\n\nclass KNPUQwen3MoeForCausalLM(Qwen3MoePreTrainedModel):\n\n    cache: \"KVC2Qwen3Cache\"\n    use_cuda_graph = False\n\n    def __init__(\n        self,\n        config: \"Qwen3MoeConfig\",\n        cache: \"KVC2Qwen3Cache\",\n        stream: Optional[\"torch_npu.npu.Stream\"] = None,\n        default_type: torch.dtype = torch.float16,\n    ):\n        super().__init__(config)\n\n        self.model = 
Qwen3MoeModel(config)\n        self.config = config\n        self.config.backend_type = \"balance_serve\" \n        self.cache = cache\n        self.vocab_size = config.vocab_size\n\n        self.model.to(torch.float16)\n\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        self.default_type = default_type\n\n        self.stream = torch_npu.npu.current_stream() if stream is None else stream\n        self.para_stream = torch_npu.npu.Stream()\n        self.call_stream = torch_npu.npu.Stream()\n\n        if hasattr(self.model, \"embed_tokens\"):\n            self.model.embed_tokens.weight.data = self.model.embed_tokens.weight.data.to(torch.float16)\n\n        if hasattr(self.model, \"norm\"):\n            self.model.norm.weight.data = self.model.norm.weight.data.to(torch.float16)\n            if getattr(self.model.norm, \"bias\", None) is not None:\n                self.model.norm.bias.data = self.model.norm.bias.data.to(torch.float16)\n\n\n        try:\n            orig_norm = self.model.norm\n            self.model.norm = KQwen3FinalRMSNormNPU(orig_norm)\n        except Exception as e:\n            print(f\"[INIT][WARN] replace model.norm failed: {e}\", flush=True)\n\n    def init_wrapper(self):\n        print(\"[WARN] KNPUQwen3MoeForCausalLM does not use flashinfer wrapper on NPU, skip init_wrapper...\")\n\n    # ---------------------------------------------------\n    # Embedding：support prefill / decode modes\n    # ---------------------------------------------------\n    def batch_embeddings(\n        self,\n        batch: \"ForwardBatchInput\",\n        device: str = \"npu:0\",\n        is_prefill: bool = True,\n    ) -> torch.Tensor:\n        features = []\n\n        if is_prefill:\n            start_ids = 0\n            seq_lens = []\n\n            for i in range(batch.minibatch.prefill_batch):\n                qlen = int(batch.minibatch.p_q_len[i])\n                kvlen = int(batch.minibatch.p_kv_len[i])\n\n        
        if kvlen < qlen:\n                    raise AssertionError(\n                        f\"[ERROR] p_kv_len({kvlen}) < p_q_len({qlen}) \"\n                        f\"for prefill idx={i}, this should not happen\"\n                    )\n\n                tokens = batch.minibatch.p_tokens[start_ids: start_ids + qlen].contiguous()\n                start_ids += qlen\n\n                feat = (\n                    self.model.embed_tokens(tokens.to(torch.device(\"cpu\")))\n                    .to(self.default_type)\n                    .to(device=device)\n                )\n\n                features.append(feat)\n                seq_lens.append(qlen)\n\n            max_seq_len = max(seq_lens) if seq_lens else 0\n\n            # Pad the current chunk to the maximum q_len with [bsz, max_q_len, hidden].\n            padded_features = []\n            for feat in features:\n                curr_len = feat.shape[0]\n                if curr_len < max_seq_len:\n                    pad_len = max_seq_len - curr_len\n                    padded_feat = torch.nn.functional.pad(\n                        feat,\n                        (0, 0, 0, pad_len),\n                        mode=\"constant\",\n                        value=0.0,\n                    )\n                    padded_features.append(padded_feat)\n                else:\n                    padded_features.append(feat)\n            features_t = torch.stack(padded_features, dim=0)  # [bsz, max_seq_len, hidden]\n        else:\n            for i in range(batch.minibatch.decode_batch):\n                if batch.minibatch.d_tokens.dim() == 1:\n                    tokens = batch.minibatch.d_tokens.contiguous()\n                else:\n                    tokens = batch.minibatch.d_tokens[i].contiguous()\n                feature = (\n                    self.model.embed_tokens(tokens.to(torch.device(\"cpu\")))\n                    .to(self.default_type)\n                    .to(device=device)\n                )\n            
    features.append(feature)\n            features_t = torch.stack(features)  # [decode_bsz, decode_q_len, hidden]\n\n        return features_t\n\n    def forward(\n            self,\n            batch: Optional[\"ForwardBatchInput\"] = None,\n            features: torch.Tensor | None = None,\n            cache=None,\n            bsz_tensors: torch.Tensor | None = None,\n            num_tokens_tensors: torch.Tensor | None = None,\n            page_idx: torch.Tensor | None = None,\n            page_offset: torch.Tensor | None = None,\n            position_ids: torch.Tensor | None = None,\n            block_tables: torch.Tensor | None = None,\n            cuda_graph_idx: int | None = 0,\n            is_prefill: bool = True,\n        ) -> \"ForwardBatchOutput\":\n        try:\n            is_capturing = torch.npu.is_current_stream_capturing()\n        except Exception:\n            is_capturing = False\n        # features: [bsz, q_len, hidden]\n        if features.ndim == 2:\n            hidden_states = features.unsqueeze(0)\n        elif features.ndim == 1:\n            hidden_states = features.unsqueeze(0).unsqueeze(0)\n        else:\n            hidden_states = features\n        bsz, q_len, hidden_size = hidden_states.shape\n        minibatch = batch.minibatch\n        if is_prefill:\n            device_pos = minibatch.p_position_ids.device\n            position_ids = -1 * torch.ones(\n                bsz,\n                q_len,\n                dtype=minibatch.p_position_ids.dtype,\n                device=device_pos,\n            )\n            bsz_real = torch.zeros(bsz, dtype=torch.int32, device=device_pos)\n            start_ids = 0\n            for i, qlen in enumerate(minibatch.p_q_len):\n                position_ids[i, :qlen] = minibatch.p_position_ids[start_ids:start_ids + qlen]\n                start_ids += int(qlen.item())\n                bsz_real[i] = qlen\n            block_tables = minibatch.p_block_tables\n            kv_len = 
minibatch.p_kv_len[0]\n            q_len_raw = minibatch.p_q_len\n            kv_len_raw = minibatch.p_kv_len\n            kv_len_tensor = kv_len_raw\n        else:\n            position_ids = minibatch.d_position_ids\n            if position_ids.dim() == 1:\n                position_ids = position_ids.unsqueeze(0)\n            block_tables = minibatch.d_block_tables\n            kv_len = minibatch.d_kv_len[0]\n            q_len_raw = None\n            kv_len_tensor = minibatch.d_kv_len_list\n            bsz_real = None\n\n        # ==================== layer loop ====================\n        for i, decode_layer in enumerate(self.model.layers):\n            # ---------- Attention Block ----------\n            attn_residual = hidden_states\n\n            hidden_states = decode_layer.input_layernorm(hidden_states)\n\n            attn_out = decode_layer.self_attn(\n                hidden_states,\n                past_key_value=self.cache,\n                position_ids=position_ids,\n                num_tokens_tensors=num_tokens_tensors,\n                page_idx=page_idx,\n                page_offset=page_offset,\n                block_table=block_tables,\n                q_len_raw=q_len_raw,\n                kv_len_raw=kv_len_tensor,\n                is_prefill=is_prefill,\n                stream=self.stream,\n            )\n\n            hidden_states = attn_residual + attn_out\n            # ---------- MLP Block ----------\n            mlp_residual = hidden_states\n            hidden_states = decode_layer.post_attention_layernorm(hidden_states)\n            mlp_in = hidden_states\n            mlp_out = decode_layer.mlp(\n                mlp_in,\n                num_tokens_tensors,\n                cuda_graph_idx,\n            )\n\n            if isinstance(mlp_out, tuple):\n                moe_y = mlp_out[0]\n            else:\n                moe_y = mlp_out\n\n            hidden_states = mlp_residual + moe_y\n        forward_batch_output = 
ForwardBatchOutput()\n\n        hidden_states_without_norm = hidden_states.clone()\n\n        normed = self.model.norm(hidden_states)\n\n        local_logit = self.lm_head(normed)\n        B_out = local_logit.size(0)\n        for b in range(B_out):\n            if (bsz_real is not None) and (not is_capturing):\n                valid_len = int(bsz_real[b].item())\n                result = local_logit[b, :valid_len]\n                pre_h = hidden_states_without_norm[b, :valid_len]\n            else:\n                result = local_logit[b]\n                pre_h = hidden_states_without_norm[b]\n\n            forward_batch_output.logits.append(result)\n            forward_batch_output.pre_hidden_states.append(pre_h)\n        return forward_batch_output\n\n\n\n    def flash_infer_attn_plan(\n        self,\n        batch: \"ForwardBatchInput\",\n        bsz_tensors: torch.Tensor,\n        num_tokens_tensors: torch.Tensor,\n        num_q_heads: int,\n        num_kv_heads: int,\n        head_dim: int,\n        page_size: int,\n        causal: bool,\n        q_data_type: torch.dtype,\n        kv_data_type: torch.dtype,\n        cuda_graph_idx: int = 0,\n    ):\n        print(\"[WARN] KNPUQwen3MoeForCausalLM on NPU does not support flashinfer, skip flash_infer_attn_plan...\")\n"
  },
  {
    "path": "archive/ktransformers/models/configuration_deepseek.py",
    "content": "# Adapted from\n# https://huggingface.co/deepseek-ai/DeepSeek-V2-Chat-0628/blob/main/configuration_deepseek.py\n# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.\n# Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import logging\n\nlogger = logging.get_logger(__name__)\n\nDEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}\nclass DeepseekV2Config(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`DeepseekV2Model`]. It is used to instantiate an DeepSeek\n    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the\n    defaults will yield a similar configuration to that of the DeepSeek-V2.\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the\n    documentation from [`PretrainedConfig`] for more information.\n    Args:\n        vocab_size (`int`, *optional*, defaults to 102400):\n            Vocabulary size of the Deep model. 
Defines the number of different tokens that can be represented by the\n            `input_ids` passed when calling [`DeepseekV2Model`].\n        hidden_size (`int`, *optional*, defaults to 4096):\n            Dimension of the hidden representations.\n        intermediate_size (`int`, *optional*, defaults to 11008):\n            Dimension of the MLP representations.\n        moe_intermediate_size (`int`, *optional*, defaults to 1407):\n            Dimension of the MoE representations.\n        num_hidden_layers (`int`, *optional*, defaults to 30):\n            Number of hidden layers in the Transformer decoder.\n        num_attention_heads (`int`, *optional*, defaults to 32):\n            Number of attention heads for each attention layer in the Transformer decoder.\n        n_shared_experts (`int`, *optional*, defaults to None):\n            Number of shared experts, None means dense model.\n        n_routed_experts (`int`, *optional*, defaults to None):\n            Number of routed experts, None means dense model.\n        routed_scaling_factor (`float`, *optional*, defaults to 1.0):\n            Scaling factor of routed experts.\n        topk_method (`str`, *optional*, defaults to `gready`):\n            Top-k method used in the routed gate.\n        n_group (`int`, *optional*, defaults to None):\n            Number of groups for routed experts.\n        topk_group (`int`, *optional*, defaults to None):\n            Number of selected groups for each token (the experts selected for a token are restricted to `topk_group` groups).\n        num_experts_per_tok (`int`, *optional*, defaults to None):\n            Number of selected experts, None means dense model.\n        moe_layer_freq (`int`, *optional*, defaults to 1):\n            The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.\n        first_k_dense_replace (`int`, *optional*, defaults to 0):\n            Number of dense layers in shallow 
layers(embed->dense->dense->...->dense->moe->moe...->lm_head).\n                                                            \\--k dense layers--/\n        norm_topk_prob (`bool`, *optional*, defaults to False):\n            Whether to normalize the weights of the routed experts.\n        scoring_func (`str`, *optional*, defaults to 'softmax'):\n            Method of computing expert weights.\n        aux_loss_alpha (`float`, *optional*, defaults to 0.001):\n            Auxiliary loss weight coefficient.\n        seq_aux (`bool`, *optional*, defaults to True):\n            Whether to compute the auxiliary loss for each individual sample.\n        num_key_value_heads (`int`, *optional*):\n            This is the number of key_value heads that should be used to implement Grouped Query Attention. If\n            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if\n            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When\n            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed\n            by meanpooling all the original heads within that group. For more details, check out [this\n            paper](https://arxiv.org/pdf/2305.13245.pdf). 
If it is not specified, it will default to\n            `num_attention_heads`.\n        hidden_act (`str` or `function`, *optional*, defaults to `\"silu\"`):\n            The non-linear activation function (function or string) in the decoder.\n        max_position_embeddings (`int`, *optional*, defaults to 2048):\n            The maximum sequence length that this model might ever be used with.\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        rms_norm_eps (`float`, *optional*, defaults to 1e-06):\n            The epsilon used by the rms normalization layers.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values attentions (not used by all models). Only\n            relevant if `config.is_decoder=True`.\n        pad_token_id (`int`, *optional*):\n            Padding token id.\n        bos_token_id (`int`, *optional*, defaults to 100000):\n            Beginning of stream token id.\n        eos_token_id (`int`, *optional*, defaults to 100001):\n            End of stream token id.\n        pretraining_tp (`int`, *optional*, defaults to 1):\n            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this\n            document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is\n            necessary to ensure exact reproducibility of the pretraining results. 
Please refer to [this\n            issue](https://github.com/pytorch/pytorch/issues/76232).\n        tie_word_embeddings (`bool`, *optional*, defaults to `False`):\n            Whether to tie weight embeddings.\n        rope_theta (`float`, *optional*, defaults to 10000.0):\n            The base period of the RoPE embeddings.\n        rope_scaling (`Dict`, *optional*):\n            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling\n            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is\n            `{\"type\": strategy name, \"factor\": scaling factor}`. When using this flag, don't update\n            `max_position_embeddings` to the expected new maximum.\n        attention_bias (`bool`, *optional*, defaults to `False`):\n            Whether to use a bias in the query, key, value and output projection layers during self-attention.\n        attention_dropout (`float`, *optional*, defaults to 0.0):\n            The dropout ratio for the attention probabilities.\n    ```python\n    >>> from transformers import DeepseekV2Model, DeepseekV2Config\n    >>> # Initializing a Deepseek-V2 style configuration\n    >>> configuration = DeepseekV2Config()\n    >>> # Initializing a model from the configuration\n    >>> model = DeepseekV2Model(configuration)\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\"\"\"\n\n    model_type = \"deepseek_v2\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        vocab_size=102400,\n        hidden_size=4096,\n        intermediate_size=11008,\n        moe_intermediate_size = 1407,\n        num_hidden_layers=30,\n        num_attention_heads=32,\n        num_key_value_heads=32,\n        n_shared_experts = None,\n        n_routed_experts = None,\n        ep_size = 1,\n        routed_scaling_factor = 1.0,\n        kv_lora_rank = 512,\n        q_lora_rank = 1536,\n        qk_rope_head_dim = 64,\n        v_head_dim = 128,\n    
    qk_nope_head_dim = 128,\n        topk_method = 'gready',\n        n_group = None,\n        topk_group = None,\n        num_experts_per_tok = None,\n        moe_layer_freq = 1,\n        first_k_dense_replace = 0,\n        norm_topk_prob = False,\n        scoring_func = 'softmax',\n        aux_loss_alpha = 0.001,\n        seq_aux = True,\n        hidden_act=\"silu\",\n        max_position_embeddings=2048,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        pad_token_id=None,\n        bos_token_id=100000,\n        eos_token_id=100001,\n        pretraining_tp=1,\n        tie_word_embeddings=False,\n        rope_theta=10000.0,\n        rope_scaling=None,\n        attention_bias=False,\n        attention_dropout=0.0,\n        cpu_quant=None,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.moe_intermediate_size = moe_intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.n_shared_experts = n_shared_experts\n        self.n_routed_experts = n_routed_experts\n        self.ep_size = ep_size\n        self.routed_scaling_factor = routed_scaling_factor\n        self.kv_lora_rank = kv_lora_rank\n        self.q_lora_rank = q_lora_rank\n        self.qk_rope_head_dim = qk_rope_head_dim\n        self.v_head_dim = v_head_dim\n        self.qk_nope_head_dim = qk_nope_head_dim\n        self.topk_method = topk_method\n        self.n_group = n_group\n        self.topk_group = topk_group\n        self.num_experts_per_tok = num_experts_per_tok\n        self.moe_layer_freq = moe_layer_freq\n        self.first_k_dense_replace = first_k_dense_replace\n        self.norm_topk_prob = norm_topk_prob\n        self.scoring_func = scoring_func\n        self.aux_loss_alpha = 
aux_loss_alpha\n        self.seq_aux = seq_aux\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.pretraining_tp = pretraining_tp\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n\n        self.cpu_quant = cpu_quant\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n"
  },
  {
    "path": "archive/ktransformers/models/configuration_deepseek_v3.py",
    "content": "from transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import logging\n\nlogger = logging.get_logger(__name__)\n\nDEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}\nclass DeepseekV3Config(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek\n    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the\n    defaults will yield a similar configuration to that of the DeepSeek-V3.\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the\n    documentation from [`PretrainedConfig`] for more information.\n    Args:\n        vocab_size (`int`, *optional*, defaults to 129280):\n            Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the\n            `inputs_ids` passed when calling [`DeepseekV3Model`]\n        hidden_size (`int`, *optional*, defaults to 4096):\n            Dimension of the hidden representations.\n        intermediate_size (`int`, *optional*, defaults to 11008):\n            Dimension of the MLP representations.\n        moe_intermediate_size (`int`, *optional*, defaults to 1407):\n            Dimension of the MoE representations.\n        num_hidden_layers (`int`, *optional*, defaults to 32):\n            Number of hidden layers in the Transformer decoder.\n        num_nextn_predict_layers (`int`, *optional*, defaults to 1):\n            Number of nextn predict layers in the DeepSeekV3 Model.\n        num_attention_heads (`int`, *optional*, defaults to 32):\n            Number of attention heads for each attention layer in the Transformer decoder.\n        n_shared_experts (`int`, *optional*, defaults to None):\n            Number of shared experts, None means dense model.\n        n_routed_experts (`int`, *optional*, defaults to 
256):\n            Number of routed experts, None means dense model.\n        routed_scaling_factor (`float`, *optional*, defaults to 2.5):\n            Scaling factor for routed experts.\n        topk_method (`str`, *optional*, defaults to `noaux_tc`):\n            Topk method used in routed gate.\n        n_group (`int`, *optional*, defaults to 8):\n            Number of groups for routed experts.\n        topk_group (`int`, *optional*, defaults to 4):\n            Number of selected groups for each token (the selected experts for a token are restricted to `topk_group` groups).\n        num_experts_per_tok (`int`, *optional*, defaults to 8):\n            Number of selected experts, None means dense model.\n        moe_layer_freq (`int`, *optional*, defaults to 1):\n            The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.\n        first_k_dense_replace (`int`, *optional*, defaults to 3):\n            Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).\n                                                            \\--k dense layers--/\n        norm_topk_prob (`bool`, *optional*, defaults to True):\n            Whether to normalize the weights of the routed experts.\n        scoring_func (`str`, *optional*, defaults to 'sigmoid'):\n            Method of computing expert weights.\n        aux_loss_alpha (`float`, *optional*, defaults to 0.001):\n            Auxiliary loss weight coefficient.\n        seq_aux (`bool`, *optional*, defaults to True):\n            Whether to compute the auxiliary loss for each individual sample.\n        num_key_value_heads (`int`, *optional*, defaults to 128):\n            This is the number of key_value heads that should be used to implement Grouped Query Attention. 
If\n            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if\n            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When\n            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed\n            by meanpooling all the original heads within that group. For more details, check out [this\n            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is `None`, it defaults to\n            `num_attention_heads`.\n        hidden_act (`str` or `function`, *optional*, defaults to `\"silu\"`):\n            The non-linear activation function (function or string) in the decoder.\n        max_position_embeddings (`int`, *optional*, defaults to 4096):\n            The maximum sequence length that this model might ever be used with.\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        rms_norm_eps (`float`, *optional*, defaults to 1e-06):\n            The epsilon used by the rms normalization layers.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values attentions (not used by all models). 
Only\n            relevant if `config.is_decoder=True`.\n        pad_token_id (`int`, *optional*):\n            Padding token id.\n        bos_token_id (`int`, *optional*, defaults to 0):\n            Beginning of stream token id.\n        eos_token_id (`int`, *optional*, defaults to 1):\n            End of stream token id.\n        tie_word_embeddings (`bool`, *optional*, defaults to `False`):\n            Whether to tie the input and output word embeddings.\n        rope_theta (`float`, *optional*, defaults to 10000.0):\n            The base period of the RoPE embeddings.\n        rope_scaling (`Dict`, *optional*):\n            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling\n            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is\n            `{\"type\": strategy name, \"factor\": scaling factor}`. When using this flag, don't update\n            `max_position_embeddings` to the expected new maximum.\n        attention_bias (`bool`, *optional*, defaults to `False`):\n            Whether to use a bias in the query, key, value and output projection layers during self-attention.\n        attention_dropout (`float`, *optional*, defaults to 0.0):\n            The dropout ratio for the attention probabilities.\n    ```python\n    >>> from transformers import DeepseekV3Model, DeepseekV3Config\n    >>> # Initializing a Deepseek-V3 style configuration\n    >>> configuration = DeepseekV3Config()\n    >>> # Initializing a model from the Deepseek-V3 style configuration\n    >>> model = DeepseekV3Model(configuration)\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\"\"\"\n\n    model_type = \"deepseek_v3\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        vocab_size=129280,\n        hidden_size=7168,\n        intermediate_size=18432,\n        moe_intermediate_size = 2048,\n        num_hidden_layers=61,\n        num_nextn_predict_layers=1,\n        num_attention_heads=128,\n        
num_key_value_heads=128,\n        n_shared_experts = 1,\n        n_routed_experts = 256,\n        ep_size = 1,\n        routed_scaling_factor = 2.5,\n        kv_lora_rank = 512,\n        q_lora_rank = 1536,\n        qk_rope_head_dim = 64,\n        v_head_dim = 128,\n        qk_nope_head_dim = 128,\n        topk_method = 'noaux_tc',\n        n_group = 8,\n        topk_group = 4,\n        num_experts_per_tok = 8,\n        moe_layer_freq = 1,\n        first_k_dense_replace = 3,\n        norm_topk_prob = True,\n        scoring_func = 'sigmoid',\n        hidden_act=\"silu\",\n        max_position_embeddings=4096,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        pad_token_id=None,\n        bos_token_id=0,\n        eos_token_id=1,\n        tie_word_embeddings=False,\n        rope_theta=10000.0,\n        rope_scaling=None,\n        attention_bias=False,\n        attention_dropout=0.0,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.moe_intermediate_size = moe_intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_nextn_predict_layers = num_nextn_predict_layers\n        self.num_attention_heads = num_attention_heads\n        self.n_shared_experts = n_shared_experts\n        self.n_routed_experts = n_routed_experts\n        self.ep_size = ep_size\n        self.routed_scaling_factor = routed_scaling_factor\n        self.kv_lora_rank = kv_lora_rank\n        self.q_lora_rank = q_lora_rank\n        self.qk_rope_head_dim = qk_rope_head_dim\n        self.v_head_dim = v_head_dim\n        self.qk_nope_head_dim = qk_nope_head_dim\n        self.topk_method = topk_method\n        self.n_group = n_group\n        self.topk_group = topk_group\n        self.num_experts_per_tok = num_experts_per_tok\n        self.moe_layer_freq = 
moe_layer_freq\n        self.first_k_dense_replace = first_k_dense_replace\n        self.norm_topk_prob = norm_topk_prob\n        self.scoring_func = scoring_func\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )"
  },
  {
    "path": "archive/ktransformers/models/configuration_glm4_moe.py",
    "content": "#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨\n#           This file was automatically generated from src/transformers/models/glm4_moe/modular_glm4_moe.py.\n#               Do NOT edit this file manually as any edits will be overwritten by the generation of\n#             the file from the modular. If any change should be done, please apply the change to the\n#                          modular_glm4_moe.py file directly. One of our CI enforces this.\n#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨\n# coding=utf-8\n# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.modeling_rope_utils import rope_config_validation\n\n\nclass Glm4MoeConfig(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`Glm4MoeModel`]. It is used to instantiate a\n    Glm4Moe model according to the specified arguments, defining the model architecture. Instantiating a configuration\n    with the defaults will yield a similar configuration to that of [THUDM/GLM-4-100B-A10B](https://huggingface.co/THUDM/GLM-4-100B-A10B).\n\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
Read the\n    documentation from [`PretrainedConfig`] for more information.\n\n\n    Args:\n        vocab_size (`int`, *optional*, defaults to 151552):\n            Vocabulary size of the Glm4Moe model. Defines the number of different tokens that can be represented by the\n            `inputs_ids` passed when calling [`Glm4MoeModel`]\n        hidden_size (`int`, *optional*, defaults to 4096):\n            Dimension of the hidden representations.\n        intermediate_size (`int`, *optional*, defaults to 10944):\n            Dimension of the MLP representations.\n        num_hidden_layers (`int`, *optional*, defaults to 46):\n            Number of hidden layers in the Transformer encoder.\n        num_attention_heads (`int`, *optional*, defaults to 96):\n            Number of attention heads for each attention layer in the Transformer encoder.\n        partial_rotary_factor (`float`, *optional*, defaults to 0.5):\n            The factor of the partial rotary position.\n        num_key_value_heads (`int`, *optional*, defaults to 8):\n            This is the number of key_value heads that should be used to implement Grouped Query Attention. If\n            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if\n            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When\n            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed\n            by meanpooling all the original heads within that group. For more details, check out [this\n            paper](https://huggingface.co/papers/2305.13245). 
If it is not specified, will default to `8`.\n\n        hidden_act (`str` or `function`, *optional*, defaults to `\"silu\"`):\n            The non-linear activation function (function or string) in the decoder.\n        max_position_embeddings (`int`, *optional*, defaults to 131072):\n            The maximum sequence length that this model might ever be used with.\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        rms_norm_eps (`float`, *optional*, defaults to 1e-05):\n            The epsilon used by the rms normalization layers.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values attentions (not used by all models). Only\n            relevant if `config.is_decoder=True`.\n        tie_word_embeddings (`bool`, *optional*, defaults to `False`):\n            Whether the model's input and output word embeddings should be tied.\n        rope_theta (`float`, *optional*, defaults to 10000.0):\n            The base period of the RoPE embeddings.\n        rope_scaling (`Dict`, *optional*):\n            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type\n            and you expect the model to work on longer `max_position_embeddings`, we recommend that you update this value\n            accordingly.\n            Expected contents:\n                `rope_type` (`str`):\n                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',\n                    'llama3'], with 'default' being the original RoPE implementation.\n                `factor` (`float`, *optional*):\n                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. 
In\n                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *\n                    original maximum pre-trained length.\n                `original_max_position_embeddings` (`int`, *optional*):\n                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during\n                    pretraining.\n                `attention_factor` (`float`, *optional*):\n                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention\n                    computation. If unspecified, it defaults to the value recommended by the implementation, using the\n                    `factor` field to infer the suggested value.\n                `beta_fast` (`float`, *optional*):\n                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear\n                    ramp function. If unspecified, it defaults to 32.\n                `beta_slow` (`float`, *optional*):\n                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear\n                    ramp function. If unspecified, it defaults to 1.\n                `short_factor` (`list[float]`, *optional*):\n                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<\n                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden\n                    size divided by the number of attention heads divided by 2.\n                `long_factor` (`list[float]`, *optional*):\n                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>\n                    `original_max_position_embeddings`). 
Must be a list of numbers with the same length as the hidden\n                    size divided by the number of attention heads divided by 2.\n                `low_freq_factor` (`float`, *optional*):\n                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.\n                `high_freq_factor` (`float`, *optional*):\n                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.\n        attention_bias (`bool`, *optional*, defaults to `False`):\n            Whether to use a bias in the query, key, value and output projection layers during self-attention.\n        attention_dropout (`float`, *optional*, defaults to 0.0):\n            The dropout ratio for the attention probabilities.\n        moe_intermediate_size (`int`, *optional*, defaults to 1408):\n            Intermediate size of the routed expert.\n        num_experts_per_tok (`int`, *optional*, defaults to 8):\n            Number of experts per token.\n        n_shared_experts (`int`, *optional*, defaults to 1):\n            Number of shared experts.\n        n_routed_experts (`int`, *optional*, defaults to 128):\n            Number of routed experts.\n        routed_scaling_factor (`float`, *optional*, defaults to 1.0):\n            Scaling factor for routed experts.\n        n_group (`int`, *optional*, defaults to 1):\n            Number of groups for routed experts.\n        topk_group (`int`, *optional*, defaults to 1):\n            Number of selected groups for each token (the selected experts for a token are restricted to `topk_group` groups).\n        first_k_dense_replace (`int`, *optional*, defaults to 1):\n            Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).\n                                                            \\--k dense layers--/\n        norm_topk_prob (`bool`, *optional*, defaults to `True`):\n            
Whether to normalize the topk probabilities.\n        use_qk_norm (`bool`, *optional*, defaults to `False`):\n            Whether to use query-key normalization in the attention\n    ```python\n    >>> from transformers import Glm4MoeModel, Glm4MoeConfig\n\n    >>> # Initializing a Glm4Moe style configuration\n    >>> configuration = Glm4MoeConfig()\n\n    >>> # Initializing a model from the GLM-4-MOE-100B-A10B style configuration\n    >>> model = Glm4MoeModel(configuration)\n\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\"\"\"\n\n    model_type = \"glm4_moe\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    # Default tensor parallel plan for base model `Glm4Moe`\n    base_model_tp_plan = {\n        \"layers.*.self_attn.q_proj\": \"colwise\",\n        \"layers.*.self_attn.k_proj\": \"colwise\",\n        \"layers.*.self_attn.v_proj\": \"colwise\",\n        \"layers.*.self_attn.o_proj\": \"rowwise\",\n        \"layers.*.mlp.experts.*.gate_proj\": \"colwise\",\n        \"layers.*.mlp.experts.*.up_proj\": \"colwise\",\n        \"layers.*.mlp.experts.*.down_proj\": \"rowwise\",\n        \"layers.*.mlp.gate_proj\": \"colwise\",\n        \"layers.*.mlp.up_proj\": \"colwise\",\n        \"layers.*.mlp.down_proj\": \"rowwise\",\n    }\n    base_model_pp_plan = {\n        \"embed_tokens\": ([\"input_ids\"], [\"inputs_embeds\"]),\n        \"layers\": ([\"hidden_states\", \"attention_mask\"], [\"hidden_states\"]),\n        \"norm\": ([\"hidden_states\"], [\"hidden_states\"]),\n    }\n\n    def __init__(\n        self,\n        vocab_size=151552,\n        hidden_size=4096,\n        intermediate_size=10944,\n        num_hidden_layers=46,\n        num_attention_heads=96,\n        partial_rotary_factor=0.5,\n        num_key_value_heads=8,\n        hidden_act=\"silu\",\n        max_position_embeddings=131072,\n        initializer_range=0.02,\n        rms_norm_eps=1e-5,\n        use_cache=True,\n        
tie_word_embeddings=False,\n        rope_theta=10000.0,\n        rope_scaling=None,\n        attention_bias=False,\n        attention_dropout=0.0,\n        moe_intermediate_size=1408,\n        num_experts_per_tok=8,\n        n_shared_experts=1,\n        n_routed_experts=128,\n        routed_scaling_factor=1.0,\n        n_group=1,\n        topk_group=1,\n        first_k_dense_replace=1,\n        norm_topk_prob=True,\n        use_qk_norm=False,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.partial_rotary_factor = partial_rotary_factor\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n        # Validate the correctness of rotary position embeddings parameters\n        # BC: if there is a 'type' field, move it to 'rope_type'.\n        if self.rope_scaling is not None and \"type\" in self.rope_scaling:\n            self.rope_scaling[\"rope_type\"] = self.rope_scaling[\"type\"]\n        rope_config_validation(self)\n\n        # MoE arguments\n        self.moe_intermediate_size = moe_intermediate_size\n        self.num_experts_per_tok = num_experts_per_tok\n        self.n_group = n_group\n        self.topk_group = topk_group\n        self.n_shared_experts = n_shared_experts\n        self.n_routed_experts = n_routed_experts\n        self.routed_scaling_factor = routed_scaling_factor\n        self.first_k_dense_replace = 
first_k_dense_replace\n        self.norm_topk_prob = norm_topk_prob\n        self.use_qk_norm = use_qk_norm\n\n        super().__init__(\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n\n__all__ = [\"Glm4MoeConfig\"]"
  },
  {
    "path": "archive/ktransformers/models/configuration_llama.py",
    "content": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"LLaMA model configuration\"\"\"\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.modeling_rope_utils import rope_config_validation\n\n\nclass LlamaConfig(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA\n    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the\n    defaults will yield a similar configuration to that of the LLaMA-7B.\n\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the\n    documentation from [`PretrainedConfig`] for more information.\n\n\n    Args:\n        vocab_size (`int`, *optional*, defaults to 32000):\n            Vocabulary size of the LLaMA model. 
Defines the number of different tokens that can be represented by the\n            `input_ids` passed when calling [`LlamaModel`].\n        hidden_size (`int`, *optional*, defaults to 4096):\n            Dimension of the hidden representations.\n        intermediate_size (`int`, *optional*, defaults to 11008):\n            Dimension of the MLP representations.\n        num_hidden_layers (`int`, *optional*, defaults to 32):\n            Number of hidden layers in the Transformer decoder.\n        num_attention_heads (`int`, *optional*, defaults to 32):\n            Number of attention heads for each attention layer in the Transformer decoder.\n        num_key_value_heads (`int`, *optional*):\n            This is the number of key_value heads that should be used to implement Grouped Query Attention. If\n            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if\n            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When\n            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed\n            by meanpooling all the original heads within that group. For more details, check out [this\n            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to\n            `num_attention_heads`.\n        hidden_act (`str` or `function`, *optional*, defaults to `\"silu\"`):\n            The non-linear activation function (function or string) in the decoder.\n        max_position_embeddings (`int`, *optional*, defaults to 2048):\n            The maximum sequence length that this model might ever be used with. 
Llama 1 supports up to 2048 tokens,\n            Llama 2 up to 4096, CodeLlama up to 16384.\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        rms_norm_eps (`float`, *optional*, defaults to 1e-06):\n            The epsilon used by the rms normalization layers.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values attentions (not used by all models). Only\n            relevant if `config.is_decoder=True`.\n        pad_token_id (`int`, *optional*):\n            Padding token id.\n        bos_token_id (`int`, *optional*, defaults to 1):\n            Beginning of stream token id.\n        eos_token_id (`int`, *optional*, defaults to 2):\n            End of stream token id.\n        pretraining_tp (`int`, *optional*, defaults to 1):\n            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this\n            document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to\n            understand more about it. This value is necessary to ensure exact reproducibility of the pretraining\n            results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).\n        tie_word_embeddings (`bool`, *optional*, defaults to `False`):\n            Whether to tie weight embeddings\n        rope_theta (`float`, *optional*, defaults to 10000.0):\n            The base period of the RoPE embeddings.\n        rope_scaling (`Dict`, *optional*):\n            Dictionary containing the scaling configuration for the RoPE embeddings. 
NOTE: if you apply new rope type\n            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value\n            accordingly.\n            Expected contents:\n                `rope_type` (`str`):\n                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',\n                    'llama3'], with 'default' being the original RoPE implementation.\n                `factor` (`float`, *optional*):\n                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In\n                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *\n                    original maximum pre-trained length.\n                `original_max_position_embeddings` (`int`, *optional*):\n                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during\n                    pretraining.\n                `attention_factor` (`float`, *optional*):\n                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention\n                    computation. If unspecified, it defaults to value recommended by the implementation, using the\n                    `factor` field to infer the suggested value.\n                `beta_fast` (`float`, *optional*):\n                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear\n                    ramp function. If unspecified, it defaults to 32.\n                `beta_slow` (`float`, *optional*):\n                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear\n                    ramp function. If unspecified, it defaults to 1.\n                `short_factor` (`List[float]`, *optional*):\n                    Only used with 'longrope'. 
The scaling factor to be applied to short contexts (<\n                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden\n                    size divided by the number of attention heads divided by 2.\n                `long_factor` (`List[float]`, *optional*):\n                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>\n                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden\n                    size divided by the number of attention heads divided by 2.\n                `low_freq_factor` (`float`, *optional*):\n                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.\n                `high_freq_factor` (`float`, *optional*):\n                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.\n        attention_bias (`bool`, *optional*, defaults to `False`):\n            Whether to use a bias in the query, key, value and output projection layers during self-attention.\n        attention_dropout (`float`, *optional*, defaults to 0.0):\n            The dropout ratio for the attention probabilities.\n        mlp_bias (`bool`, *optional*, defaults to `False`):\n            Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.\n\n    ```python\n    >>> from transformers import LlamaModel, LlamaConfig\n\n    >>> # Initializing a LLaMA llama-7b style configuration\n    >>> configuration = LlamaConfig()\n\n    >>> # Initializing a model from the llama-7b style configuration\n    >>> model = LlamaModel(configuration)\n\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\"\"\"\n\n    model_type = \"llama\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        vocab_size=32000,\n        hidden_size=4096,\n        
intermediate_size=11008,\n        num_hidden_layers=32,\n        num_attention_heads=32,\n        num_key_value_heads=None,\n        hidden_act=\"silu\",\n        max_position_embeddings=2048,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        pad_token_id=None,\n        bos_token_id=1,\n        eos_token_id=2,\n        pretraining_tp=1,\n        tie_word_embeddings=False,\n        rope_theta=10000.0,\n        rope_scaling=None,\n        attention_bias=False,\n        attention_dropout=0.0,\n        mlp_bias=False,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.pretraining_tp = pretraining_tp\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n        self.mlp_bias = mlp_bias\n\n        # Validate the correctness of rotary position embeddings parameters\n        # BC: if there is a 'type' field, move it to 'rope_type'.\n        if self.rope_scaling is not None and \"type\" in self.rope_scaling:\n            self.rope_scaling[\"rope_type\"] = self.rope_scaling[\"type\"]\n        rope_config_validation(self)\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            
eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n"
  },
  {
    "path": "archive/ktransformers/models/configuration_qwen2_moe.py",
    "content": "# coding=utf-8\n# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Qwen2MoE model configuration\"\"\"\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import logging\n\n\nlogger = logging.get_logger(__name__)\n\n\nclass Qwen2MoeConfig(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`Qwen2MoeModel`]. It is used to instantiate a\n    Qwen2MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration\n    with the defaults will yield a similar configuration to that of\n    Qwen1.5-MoE-A2.7B\" [Qwen/Qwen1.5-MoE-A2.7B\"](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B\").\n\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the\n    documentation from [`PretrainedConfig`] for more information.\n\n\n    Args:\n        vocab_size (`int`, *optional*, defaults to 151936):\n            Vocabulary size of the Qwen2MoE model. 
Defines the number of different tokens that can be represented by the\n            `input_ids` passed when calling [`Qwen2MoeModel`]\n        hidden_size (`int`, *optional*, defaults to 2048):\n            Dimension of the hidden representations.\n        intermediate_size (`int`, *optional*, defaults to 5632):\n            Dimension of the MLP representations.\n        num_hidden_layers (`int`, *optional*, defaults to 24):\n            Number of hidden layers in the Transformer encoder.\n        num_attention_heads (`int`, *optional*, defaults to 16):\n            Number of attention heads for each attention layer in the Transformer encoder.\n        num_key_value_heads (`int`, *optional*, defaults to 16):\n            This is the number of key_value heads that should be used to implement Grouped Query Attention. If\n            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if\n            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When\n            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed\n            by meanpooling all the original heads within that group. For more details check out [this\n            paper](https://arxiv.org/pdf/2305.13245.pdf). 
If it is not specified, will default to `16`.\n        hidden_act (`str` or `function`, *optional*, defaults to `\"silu\"`):\n            The non-linear activation function (function or string) in the decoder.\n        max_position_embeddings (`int`, *optional*, defaults to 32768):\n            The maximum sequence length that this model might ever be used with.\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        rms_norm_eps (`float`, *optional*, defaults to 1e-06):\n            The epsilon used by the rms normalization layers.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values attentions (not used by all models). Only\n            relevant if `config.is_decoder=True`.\n        tie_word_embeddings (`bool`, *optional*, defaults to `False`):\n            Whether the model's input and output word embeddings should be tied.\n        rope_theta (`float`, *optional*, defaults to 10000.0):\n            The base period of the RoPE embeddings.\n        use_sliding_window (`bool`, *optional*, defaults to `False`):\n            Whether to use sliding window attention.\n        sliding_window (`int`, *optional*, defaults to 4096):\n            Sliding window attention (SWA) window size. If not specified, will default to `4096`.\n        max_window_layers (`int`, *optional*, defaults to 28):\n            The number of layers that use SWA (Sliding Window Attention). 
The bottom layers use SWA while the top use full attention.\n        attention_dropout (`float`, *optional*, defaults to 0.0):\n            The dropout ratio for the attention probabilities.\n        decoder_sparse_step (`int`, *optional*, defaults to 1):\n            The frequency of the MoE layer.\n        moe_intermediate_size (`int`, *optional*, defaults to 1408):\n            Intermediate size of the routed expert.\n        shared_expert_intermediate_size (`int`, *optional*, defaults to 5632):\n            Intermediate size of the shared expert.\n        num_experts_per_tok (`int`, *optional*, defaults to 4):\n            Number of selected experts.\n        num_experts (`int`, *optional*, defaults to 60):\n            Number of routed experts.\n        norm_topk_prob (`bool`, *optional*, defaults to `False`):\n            Whether to normalize the topk probabilities.\n        output_router_logits (`bool`, *optional*, defaults to `False`):\n            Whether or not the router logits should be returned by the model. 
Enabling this will also\n            allow the model to output the auxiliary loss, including load balancing loss and router z-loss.\n        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):\n            The aux loss factor for the total loss.\n        mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):\n            Indicate which layers use Qwen2MoeMLP rather than Qwen2MoeSparseMoeBlock\n            The list contains layer index, from 0 to num_layers-1 if we have num_layers layers\n            If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.\n\n    ```python\n    >>> from transformers import Qwen2MoeModel, Qwen2MoeConfig\n\n    >>> # Initializing a Qwen2MoE style configuration\n    >>> configuration = Qwen2MoeConfig()\n\n    >>> # Initializing a model from the Qwen1.5-MoE-A2.7B style configuration\n    >>> model = Qwen2MoeModel(configuration)\n\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\"\"\"\n\n    model_type = \"qwen2_moe\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        vocab_size=151936,\n        hidden_size=2048,\n        intermediate_size=5632,\n        num_hidden_layers=24,\n        num_attention_heads=16,\n        num_key_value_heads=16,\n        hidden_act=\"silu\",\n        max_position_embeddings=32768,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        tie_word_embeddings=False,\n        rope_theta=10000.0,\n        use_sliding_window=False,\n        sliding_window=4096,\n        max_window_layers=28,\n        attention_dropout=0.0,\n        decoder_sparse_step=1,\n        moe_intermediate_size=1408,\n        shared_expert_intermediate_size=5632,\n        num_experts_per_tok=4,\n        num_experts=60,\n        norm_topk_prob=False,\n        output_router_logits=False,\n        router_aux_loss_coef=0.001,\n        mlp_only_layers=None,\n        
**kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.use_sliding_window = use_sliding_window\n        self.sliding_window = sliding_window if use_sliding_window else None\n        self.max_window_layers = max_window_layers\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.attention_dropout = attention_dropout\n\n        # MoE arguments\n        self.decoder_sparse_step = decoder_sparse_step\n        self.moe_intermediate_size = moe_intermediate_size\n        self.shared_expert_intermediate_size = shared_expert_intermediate_size\n        self.num_experts_per_tok = num_experts_per_tok\n        self.num_experts = num_experts\n        self.norm_topk_prob = norm_topk_prob\n        self.output_router_logits = output_router_logits\n        self.router_aux_loss_coef = router_aux_loss_coef\n        self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers\n\n        super().__init__(\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )"
  },
  {
    "path": "archive/ktransformers/models/configuration_qwen3_moe.py",
    "content": "# coding=utf-8\n# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Qwen3MoE model configuration\"\"\"\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.modeling_rope_utils import rope_config_validation\nfrom transformers.utils import logging\n\n\nlogger = logging.get_logger(__name__)\n\n\nclass Qwen3MoeConfig(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`Qwen3MoeModel`]. It is used to instantiate a\n    Qwen3MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration\n    with the defaults will yield a similar configuration to that of [Qwen/Qwen3-MoE-15B-A2B](https://huggingface.co/Qwen/Qwen3-15B-A2B).\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the\n    documentation from [`PretrainedConfig`] for more information.\n    Args:\n        vocab_size (`int`, *optional*, defaults to 151936):\n            Vocabulary size of the Qwen3MoE model. 
Defines the number of different tokens that can be represented by the\n            `input_ids` passed when calling [`Qwen3MoeModel`]\n        hidden_size (`int`, *optional*, defaults to 2048):\n            Dimension of the hidden representations.\n        intermediate_size (`int`, *optional*, defaults to 6144):\n            Dimension of the MLP representations.\n        num_hidden_layers (`int`, *optional*, defaults to 24):\n            Number of hidden layers in the Transformer encoder.\n        num_attention_heads (`int`, *optional*, defaults to 32):\n            Number of attention heads for each attention layer in the Transformer encoder.\n        num_key_value_heads (`int`, *optional*, defaults to 4):\n            This is the number of key_value heads that should be used to implement Grouped Query Attention. If\n            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if\n            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When\n            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed\n            by meanpooling all the original heads within that group. For more details check out [this\n            paper](https://arxiv.org/pdf/2305.13245.pdf). 
If it is not specified, will default to `4`.\n        hidden_act (`str` or `function`, *optional*, defaults to `\"silu\"`):\n            The non-linear activation function (function or string) in the decoder.\n        max_position_embeddings (`int`, *optional*, defaults to 32768):\n            The maximum sequence length that this model might ever be used with.\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        rms_norm_eps (`float`, *optional*, defaults to 1e-06):\n            The epsilon used by the rms normalization layers.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values attentions (not used by all models). Only\n            relevant if `config.is_decoder=True`.\n        tie_word_embeddings (`bool`, *optional*, defaults to `False`):\n            Whether the model's input and output word embeddings should be tied.\n        rope_theta (`float`, *optional*, defaults to 10000.0):\n            The base period of the RoPE embeddings.\n        rope_scaling (`Dict`, *optional*):\n            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type\n            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value\n            accordingly.\n            Expected contents:\n                `rope_type` (`str`):\n                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',\n                    'llama3'], with 'default' being the original RoPE implementation.\n                `factor` (`float`, *optional*):\n                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. 
In\n                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *\n                    original maximum pre-trained length.\n                `original_max_position_embeddings` (`int`, *optional*):\n                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during\n                    pretraining.\n                `attention_factor` (`float`, *optional*):\n                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention\n                    computation. If unspecified, it defaults to value recommended by the implementation, using the\n                    `factor` field to infer the suggested value.\n                `beta_fast` (`float`, *optional*):\n                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear\n                    ramp function. If unspecified, it defaults to 32.\n                `beta_slow` (`float`, *optional*):\n                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear\n                    ramp function. If unspecified, it defaults to 1.\n                `short_factor` (`List[float]`, *optional*):\n                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<\n                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden\n                    size divided by the number of attention heads divided by 2\n                `long_factor` (`List[float]`, *optional*):\n                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>\n                    `original_max_position_embeddings`). 
Must be a list of numbers with the same length as the hidden\n                    size divided by the number of attention heads divided by 2\n                `low_freq_factor` (`float`, *optional*):\n                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE\n                `high_freq_factor` (`float`, *optional*):\n                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE\n        attention_bias (`bool`, *optional*, defaults to `False`):\n            Whether to use a bias in the query, key, value and output projection layers during self-attention.\n        use_sliding_window (`bool`, *optional*, defaults to `False`):\n            Whether to use sliding window attention.\n        sliding_window (`int`, *optional*, defaults to 4096):\n            Sliding window attention (SWA) window size. If not specified, will default to `4096`.\n        max_window_layers (`int`, *optional*, defaults to 28):\n            The number of layers that use SWA (Sliding Window Attention). 
The bottom layers use SWA while the top use full attention.\n        attention_dropout (`float`, *optional*, defaults to 0.0):\n            The dropout ratio for the attention probabilities.\n        decoder_sparse_step (`int`, *optional*, defaults to 1):\n            The frequency of the MoE layer.\n        moe_intermediate_size (`int`, *optional*, defaults to 768):\n            Intermediate size of the routed expert.\n        num_experts_per_tok (`int`, *optional*, defaults to 8):\n            Number of selected experts.\n        num_experts (`int`, *optional*, defaults to 128):\n            Number of routed experts.\n        norm_topk_prob (`bool`, *optional*, defaults to `False`):\n            Whether to normalize the topk probabilities.\n        output_router_logits (`bool`, *optional*, defaults to `False`):\n            Whether or not the router logits should be returned by the model. Enabling this will also\n            allow the model to output the auxiliary loss, including load balancing loss and router z-loss.\n        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):\n            The aux loss factor for the total loss.\n        mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):\n            Indicate which layers use Qwen3MoeMLP rather than Qwen3MoeSparseMoeBlock\n            The list contains layer index, from 0 to num_layers-1 if we have num_layers layers\n            If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.\n    ```python\n    >>> from transformers import Qwen3MoeModel, Qwen3MoeConfig\n    >>> # Initializing a Qwen3MoE style configuration\n    >>> configuration = Qwen3MoeConfig()\n    >>> # Initializing a model from the Qwen3-15B-A2B style configuration\n    >>> model = Qwen3MoeModel(configuration)\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\"\"\"\n\n    model_type = \"qwen3_moe\"\n    keys_to_ignore_at_inference = 
[\"past_key_values\"]\n\n    # Default tensor parallel plan for base model `Qwen3Moe`\n    base_model_tp_plan = {\n        \"layers.*.self_attn.q_proj\": \"colwise\",\n        \"layers.*.self_attn.k_proj\": \"colwise\",\n        \"layers.*.self_attn.v_proj\": \"colwise\",\n        \"layers.*.self_attn.o_proj\": \"rowwise\",\n        \"layers.*.mlp.gate_proj\": \"colwise\",\n        \"layers.*.mlp.up_proj\": \"colwise\",\n        \"layers.*.mlp.down_proj\": \"rowwise\",\n    }\n    base_model_pp_plan = {\n        \"embed_tokens\": ([\"input_ids\"], [\"inputs_embeds\"]),\n        \"layers\": ([\"hidden_states\", \"attention_mask\"], [\"hidden_states\"]),\n        \"norm\": ([\"hidden_states\"], [\"hidden_states\"]),\n    }\n\n    def __init__(\n        self,\n        vocab_size=151936,\n        hidden_size=2048,\n        intermediate_size=6144,\n        num_hidden_layers=24,\n        num_attention_heads=32,\n        num_key_value_heads=4,\n        hidden_act=\"silu\",\n        max_position_embeddings=32768,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        tie_word_embeddings=False,\n        rope_theta=10000.0,\n        rope_scaling=None,\n        attention_bias=False,\n        use_sliding_window=False,\n        sliding_window=4096,\n        max_window_layers=28,\n        attention_dropout=0.0,\n        decoder_sparse_step=1,\n        moe_intermediate_size=768,\n        num_experts_per_tok=8,\n        num_experts=128,\n        norm_topk_prob=False,\n        output_router_logits=False,\n        router_aux_loss_coef=0.001,\n        mlp_only_layers=None,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.use_sliding_window = 
use_sliding_window\n        self.sliding_window = sliding_window if use_sliding_window else None\n        self.max_window_layers = max_window_layers\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n        # Validate the correctness of rotary position embeddings parameters\n        # BC: if there is a 'type' field, move it to 'rope_type'.\n        if self.rope_scaling is not None and \"type\" in self.rope_scaling:\n            self.rope_scaling[\"rope_type\"] = self.rope_scaling[\"type\"]\n        rope_config_validation(self)\n\n        # MoE arguments\n        self.decoder_sparse_step = decoder_sparse_step\n        self.moe_intermediate_size = moe_intermediate_size\n        self.num_experts_per_tok = num_experts_per_tok\n        self.num_experts = num_experts\n        self.norm_topk_prob = norm_topk_prob\n        self.output_router_logits = output_router_logits\n        self.router_aux_loss_coef = router_aux_loss_coef\n        self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers\n\n        super().__init__(\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n\n__all__ = [\"Qwen3MoeConfig\"]"
  },
  {
    "path": "archive/ktransformers/models/configuration_qwen3_next.py",
    "content": "# coding=utf-8\n# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Qwen3-Next model configuration\"\"\"\n\nfrom transformers.configuration_utils import PretrainedConfig, layer_type_validation\nfrom transformers.modeling_rope_utils import rope_config_validation\nfrom transformers.utils import logging\n\n\nlogger = logging.get_logger(__name__)\n\n\nclass Qwen3NextConfig(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`Qwen3NextModel`]. It is used to instantiate a\n    Qwen3-Next model according to the specified arguments, defining the model architecture.\n    Instantiating a configuration with the defaults will yield a similar configuration to that of\n    Qwen3-Next-80B-A3B-Instruct [Qwen/Qwen3-Next-80B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct).\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the\n    documentation from [`PretrainedConfig`] for more information.\n    Args:\n        vocab_size (`int`, *optional*, defaults to 151936):\n            Vocabulary size of the model. 
Defines the number of different tokens that can be represented by the\n            `input_ids`.\n        hidden_size (`int`, *optional*, defaults to 2048):\n            Dimension of the hidden representations.\n        intermediate_size (`int`, *optional*, defaults to 5632):\n            Dimension of the MLP representations.\n        num_hidden_layers (`int`, *optional*, defaults to 48):\n            Number of hidden layers in the Transformer encoder.\n        num_attention_heads (`int`, *optional*, defaults to 16):\n            Number of attention heads for each attention layer in the Transformer encoder.\n        num_key_value_heads (`int`, *optional*, defaults to 2):\n            This is the number of key_value heads that should be used to implement Grouped Query Attention. If\n            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if\n            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When\n            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed\n            by meanpooling all the original heads within that group. For more details check out [this\n            paper](https://arxiv.org/pdf/2305.13245.pdf). 
If it is not specified, will default to `2`.\n        hidden_act (`str`, *optional*, defaults to `\"silu\"`):\n            The non-linear activation function in the decoder.\n        max_position_embeddings (`int`, *optional*, defaults to 32768):\n            The maximum sequence length that this model might ever be used with.\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        rms_norm_eps (`float`, *optional*, defaults to 1e-06):\n            The epsilon used by the rms normalization layers.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values attentions (not used by all models). Only\n            relevant if `config.is_decoder=True`.\n        tie_word_embeddings (`bool`, *optional*, defaults to `False`):\n            Whether the model's input and output word embeddings should be tied.\n        rope_theta (`float`, *optional*, defaults to 10000.0):\n            The base period of the RoPE embeddings.\n        rope_scaling (`Dict`, *optional*):\n            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type\n            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value\n            accordingly.\n            Expected contents:\n                `rope_type` (`str`):\n                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',\n                    'llama3'], with 'default' being the original RoPE implementation.\n                `factor` (`float`, *optional*):\n                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. 
In\n                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *\n                    original maximum pre-trained length.\n                `original_max_position_embeddings` (`int`, *optional*):\n                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during\n                    pretraining.\n                `attention_factor` (`float`, *optional*):\n                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention\n                    computation. If unspecified, it defaults to value recommended by the implementation, using the\n                    `factor` field to infer the suggested value.\n                `beta_fast` (`float`, *optional*):\n                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear\n                    ramp function. If unspecified, it defaults to 32.\n                `beta_slow` (`float`, *optional*):\n                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear\n                    ramp function. If unspecified, it defaults to 1.\n                `short_factor` (`List[float]`, *optional*):\n                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<\n                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden\n                    size divided by the number of attention heads divided by 2\n                `long_factor` (`List[float]`, *optional*):\n                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>\n                    `original_max_position_embeddings`). 
Must be a list of numbers with the same length as the hidden\n                    size divided by the number of attention heads divided by 2\n                `low_freq_factor` (`float`, *optional*):\n                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE\n                `high_freq_factor` (`float`, *optional*):\n                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE\n        partial_rotary_factor (`float`, *optional*, defaults to 0.25):\n            Percentage of the query and keys which will have rotary embedding.\n        attention_bias (`bool`, *optional*, defaults to `False`):\n            Whether to use a bias in the query, key, value and output projection layers during self-attention.\n        attention_dropout (`float`, *optional*, defaults to 0.0):\n            The dropout ratio for the attention probabilities.\n        head_dim (`int`, *optional*, defaults to 256):\n            Projection weights dimension in multi-head attention.\n        linear_conv_kernel_dim (`int`, *optional*, defaults to 4):\n            Kernel size of the convolution used in linear attention layers.\n        linear_key_head_dim (`int`, *optional*, defaults to 128):\n            Dimension of each key head in linear attention.\n        linear_value_head_dim (`int`, *optional*, defaults to 128):\n            Dimension of each value head in linear attention.\n        linear_num_key_heads (`int`, *optional*, defaults to 16):\n            Number of key heads used in linear attention layers.\n        linear_num_value_heads (`int`, *optional*, defaults to 32):\n            Number of value heads used in linear attention layers.\n        decoder_sparse_step (`int`, *optional*, defaults to 1):\n            The frequency of the MoE layer.\n        moe_intermediate_size (`int`, *optional*, defaults to 512):\n            Intermediate size of the routed expert.\n        
shared_expert_intermediate_size (`int`, *optional*, defaults to 512):\n            Intermediate size of the shared expert.\n        num_experts_per_tok (`int`, *optional*, defaults to 10):\n            Number of selected experts.\n        num_experts (`int`, *optional*, defaults to 512):\n            Number of routed experts.\n        norm_topk_prob (`bool`, *optional*, defaults to `True`):\n            Whether to normalize the topk probabilities.\n        output_router_logits (`bool`, *optional*, defaults to `False`):\n            Whether or not the router logits should be returned by the model. Enabling this will also\n            allow the model to output the auxiliary loss, including load balancing loss and router z-loss.\n        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):\n            The aux loss factor for the total loss.\n        mlp_only_layers (`list[int]`, *optional*, defaults to `[]`):\n            Indicate which layers use Qwen3NextMLP rather than Qwen3NextSparseMoeBlock\n            The list contains layer index, from 0 to num_layers-1 if we have num_layers layers\n            If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.\n        layer_types (`list[str]`, *optional*):\n            Types of each layer (attention or linear).\n    ```python\n    >>> from transformers import Qwen3NextModel, Qwen3NextConfig\n    >>> # Initializing a Qwen3Next style configuration\n    >>> configuration =  Qwen3NextConfig()\n    >>> # Initializing a model from the Qwen3-Next-80B-A3B style configuration\n    >>> model = Qwen3NextModel(configuration)\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\n    \"\"\"\n\n    model_type = \"qwen3_next\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    base_model_tp_plan = {\n        \"layers.*.self_attn.q_proj\": \"colwise\",\n        \"layers.*.self_attn.k_proj\": \"colwise\",\n        \"layers.*.self_attn.v_proj\": 
\"colwise\",\n        \"layers.*.self_attn.o_proj\": \"rowwise\",\n        \"layers.*.mlp.experts.*.gate_proj\": \"colwise\",\n        \"layers.*.mlp.experts.*.up_proj\": \"colwise\",\n        \"layers.*.mlp.experts.*.down_proj\": \"rowwise\",\n        \"layers.*.mlp.shared_experts.gate_proj\": \"colwise\",\n        \"layers.*.mlp.shared_experts.up_proj\": \"colwise\",\n        \"layers.*.mlp.shared_experts.down_proj\": \"rowwise\",\n        \"layers.*.mlp.gate_proj\": \"colwise\",\n        \"layers.*.mlp.up_proj\": \"colwise\",\n        \"layers.*.mlp.down_proj\": \"rowwise\",\n    }\n    base_model_pp_plan = {\n        \"embed_tokens\": ([\"input_ids\"], [\"inputs_embeds\"]),\n        \"layers\": ([\"hidden_states\", \"attention_mask\"], [\"hidden_states\"]),\n        \"norm\": ([\"hidden_states\"], [\"hidden_states\"]),\n    }\n\n    def __init__(\n        self,\n        vocab_size=151936,\n        hidden_size=2048,\n        intermediate_size=5632,\n        num_hidden_layers=48,\n        num_attention_heads=16,\n        num_key_value_heads=2,\n        hidden_act=\"silu\",\n        max_position_embeddings=32768,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        tie_word_embeddings=False,\n        rope_theta=10000.0,\n        rope_scaling=None,\n        partial_rotary_factor=0.25,\n        attention_bias=False,\n        attention_dropout=0.0,\n        head_dim=256,\n        linear_conv_kernel_dim=4,\n        linear_key_head_dim=128,\n        linear_value_head_dim=128,\n        linear_num_key_heads=16,\n        linear_num_value_heads=32,\n        decoder_sparse_step=1,\n        moe_intermediate_size=512,\n        shared_expert_intermediate_size=512,\n        num_experts_per_tok=10,\n        num_experts=512,\n        norm_topk_prob=True,\n        output_router_logits=False,\n        router_aux_loss_coef=0.001,\n        mlp_only_layers=[],\n        layer_types=None,\n        **kwargs,\n    ):\n        
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.partial_rotary_factor = partial_rotary_factor\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n        self.head_dim = head_dim\n        rope_config_validation(self)\n\n        self.layer_types = layer_types\n        if self.layer_types is None:\n            self.layer_types = [\n                \"linear_attention\" if bool((i + 1) % 4) else \"full_attention\" for i in range(self.num_hidden_layers)\n            ]\n        layer_type_validation(self.layer_types)\n\n        # linear attention part\n        self.linear_conv_kernel_dim = linear_conv_kernel_dim\n        self.linear_key_head_dim = linear_key_head_dim\n        self.linear_value_head_dim = linear_value_head_dim\n        self.linear_num_key_heads = linear_num_key_heads\n        self.linear_num_value_heads = linear_num_value_heads\n\n        # MoE arguments\n        self.decoder_sparse_step = decoder_sparse_step\n        self.moe_intermediate_size = moe_intermediate_size\n        self.shared_expert_intermediate_size = shared_expert_intermediate_size\n        self.num_experts_per_tok = num_experts_per_tok\n        self.num_experts = num_experts\n        self.norm_topk_prob = norm_topk_prob\n        self.output_router_logits = output_router_logits\n        self.router_aux_loss_coef = 
router_aux_loss_coef\n        self.mlp_only_layers = mlp_only_layers\n\n\n__all__ = [\"Qwen3NextConfig\"]"
  },
  {
    "path": "archive/ktransformers/models/configuration_smallthinker.py",
    "content": "# coding=utf-8\nfrom transformers.configuration_utils import PretrainedConfig\n\nclass SmallthinkerConfig(PretrainedConfig):\n    \"\"\"\n    This is the configuration class to store the configuration of a [`SmallthinkerModel`]. \n    It is used to instantiate a Smallthinker model according to the specified arguments, defining the model architecture. \n    The default values for each of the parameters are the same as the ones used in the original Smallthinker 4B model.\n\n    General configs:\n    - model_type: \"smallthinker\"\n    - model_name\n    - num_hidden_layers\n    - hidden_size\n\n    Tokenizer configs:\n    - pad_token_id\n    - bos_token_id\n    - eos_token_id\n\n    Embedding configs:\n    - vocab_size\n\n    RMSNorm configs:\n    - rms_norm_eps\n\n    Attention configs:\n    - num_attention_heads\n    - num_key_value_heads\n    - head_dim\n    - use_cache\n    - use_qk_norm\n    - rope_layout: array of 0 or 1s, 0 for nope, 1 for rope\n    - rope_theta\n    - max_position_embeddings\n    - sliding_window_layout: array of 0 or 1s, 0 for normal attention, 1 for SWA\n    - sliding_window_size\n\n    General FFN configs:\n    - moe_layer_layout: array of 0 or 1s, 0 for dense layer, 1 for MoE layer\n    \n    Dense FFN configs:\n    - dense_ffn_hidden_size\n\n    MoE FFN configs:\n    - moe_num_primary_experts\n    - moe_shared_primary_experts\n    - moe_ffn_hidden_size\n    - moe_enable_early_router: Use attention output as router input if true\n    - moe_primary_router_use_sigmoid: Use normalized sigmoid \n    - moe_num_active_primary_experts\n    - moe_enable_secondary_experts\n    - moe_num_secondary_experts\n    - moe_secondary_expert_size\n\n    LM Head configs:\n    - tie_word_embeddings\n\n    Visibility configs:\n    - profile_sparsity\n\n    Other configs:\n    - initializer_range\n    \"\"\"\n    def __init__(self,\n        model_type = \"smallthinker\",\n        model_name=\"smallthinker_4b_base\",\n        
num_hidden_layers=32,\n        hidden_size=1536,\n        pad_token_id=None,\n        bos_token_id=151643,\n        eos_token_id=[151643,151645],\n        vocab_size=151936,\n        rms_norm_eps=1e-6,\n        num_attention_heads=12,\n        num_key_value_heads=2,\n        head_dim=128,\n        use_cache=True,\n        use_qk_norm=False,\n        rope_layout=[1]*32,\n        rope_theta=1e6,\n        max_position_embeddings=4096 * 32,\n        sliding_window_layout=[0]*32,\n        sliding_window_size=4096,\n        moe_layer_layout=[1]*32,\n        dense_ffn_hidden_size=4096,\n        moe_num_primary_experts=32,\n        moe_shared_primary_experts=0,\n        moe_ffn_hidden_size=768,\n        moe_enable_early_router=True,\n        moe_primary_router_apply_softmax=False,\n        moe_num_active_primary_experts=4,\n        moe_enable_secondary_experts=False,\n        moe_num_secondary_experts=0,\n        moe_secondary_expert_size=0,\n        tie_word_embeddings=True,\n        initializer_range=0.02,\n        **kwargs,\n    ):\n        moe_layer_layout = [1]*num_hidden_layers\n        # Configuration sanitizers\n        assert num_attention_heads % num_key_value_heads == 0,      \"[Smallthinker config sanitizer] num_attention_heads must be divisible by num_key_value_heads\"\n        assert len(rope_layout) == num_hidden_layers,               \"[Smallthinker config sanitizer] rope_layout must have the same length as num_hidden_layers\"\n        assert len(sliding_window_layout) == num_hidden_layers,     \"[Smallthinker config sanitizer] sliding_window_layout must have the same length as num_hidden_layers\"\n        assert len(moe_layer_layout) == num_hidden_layers,          \"[Smallthinker config sanitizer] moe_layer_layout must have the same length as num_hidden_layers\"\n\n        if any(moe_layer_layout):\n            assert moe_num_primary_experts != 0,                    \"[Smallthinker config sanitizer] moe_num_primary_experts must be set non-zero if there is 
any MoE layer\"\n            assert moe_ffn_hidden_size != 0,                        \"[Smallthinker config sanitizer] moe_ffn_hidden_size must be set non-zero if there is any MoE layer\"\n            assert moe_num_active_primary_experts != 0,             \"[Smallthinker config sanitizer] moe_num_active_primary_experts must be set non-zero if there is any MoE layer\"\n            if moe_enable_secondary_experts:\n                assert moe_num_secondary_experts != 0,              \"[Smallthinker config sanitizer] moe_num_secondary_experts must be set non-zero if moe_enable_secondary_experts is True\"\n                assert moe_secondary_expert_size != 0,              \"[Smallthinker config sanitizer] moe_secondary_expert_size must be set non-zero if moe_enable_secondary_experts is True\"\n                assert moe_num_secondary_experts * moe_secondary_expert_size == moe_ffn_hidden_size, \"[Smallthinker config sanitizer] moe_num_secondary_experts * moe_secondary_expert_size must equal moe_ffn_hidden_size\"\n\n        if not all(moe_layer_layout):\n            assert dense_ffn_hidden_size != 0,                      \"[Smallthinker config sanitizer] dense_ffn_hidden_size must be set non-zero if there is any dense FFN layer\"\n\n        # General configs\n        self.model_type = model_type\n        self.model_name = model_name\n        self.num_hidden_layers = num_hidden_layers\n        self.hidden_size = hidden_size\n\n        # Tokenizer configs\n        self.pad_token_id = pad_token_id\n        self.bos_token_id = bos_token_id\n        self.eos_token_id = eos_token_id\n\n        # Embedding configs\n        self.vocab_size = vocab_size\n\n        # RMSNorm configs\n        self.rms_norm_eps = rms_norm_eps\n\n        # Attention configs\n        self.num_attention_heads = num_attention_heads\n        self.num_key_value_heads = num_key_value_heads\n        self.head_dim = head_dim\n        self.use_cache = use_cache\n        self.use_qk_norm = use_qk_norm\n       
 self.rope_layout = rope_layout\n        self.rope_theta = rope_theta\n        self.max_position_embeddings = max_position_embeddings\n        self.sliding_window_layout = sliding_window_layout\n        self.sliding_window_size = sliding_window_size\n\n        # General FFN configs\n        self.moe_layer_layout = moe_layer_layout\n\n        # Dense FFN configs\n        self.dense_ffn_hidden_size = dense_ffn_hidden_size\n\n        # MoE FFN configs\n        self.moe_num_primary_experts = moe_num_primary_experts\n        self.moe_shared_primary_experts = moe_shared_primary_experts\n        self.moe_ffn_hidden_size = moe_ffn_hidden_size\n        self.num_experts_per_tok = moe_num_active_primary_experts\n        self.moe_intermediate_size = moe_ffn_hidden_size\n        self.moe_enable_early_router = moe_enable_early_router\n        self.moe_primary_router_apply_softmax = moe_primary_router_apply_softmax\n        self.moe_num_active_primary_experts = moe_num_active_primary_experts\n        self.moe_enable_secondary_experts = moe_enable_secondary_experts\n        self.moe_num_secondary_experts = moe_num_secondary_experts\n        self.moe_secondary_expert_size = moe_secondary_expert_size\n\n        # Logging configs\n        # self.output_router_logits = False\n\n        # Other configs\n        self.initializer_range = initializer_range\n\n        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs)\n\n        self._attn_implementation = \"eager\" # SDPA is not allowed for now\n\n        # if self._attn_implementation != \"flash_attention_2\":\n        #     raise NotImplementedError(\"SDPA impl is buggy for now. NEVER TRY TO USE IT.\")\n        \n__all__ = [\"SmallthinkerConfig\"]\n"
  },
  {
    "path": "archive/ktransformers/models/custom_cache.py",
    "content": "'''\nDescription  :  \nAuthor       : Boxin Zhang\nVersion      : 0.1.0\n'''\n# Adapted from\n# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/cache_utils.py\n# Copyright 2018- The Hugging Face team. All rights reserved.\n# Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\nimport torch\nimport torch.nn as nn\nimport transformers\nfrom transformers import Cache, PretrainedConfig\nfrom typing import List, Optional, Dict, Any, Tuple\n\ntry:\n    import torch_npu\n    from ktransformers.util import utils\n    from ktransformers.server.balance_serve.inference.forward_batch import ForwardMiniBatchCombine, ForwardMiniBatchSplit\n    \n    use_torch_npu = torch_npu.npu.is_available()\nexcept:\n    use_torch_npu = False\nfrom transformers.models.llama.modeling_llama import LlamaDecoderLayer\nfrom ktransformers.server.balance_serve.settings import sched_ext\n\nclass StaticCache(transformers.StaticCache):\n    \"\"\"\n    Static Cache class to be used with `torch.compile(model)`.\n\n    Parameters:\n        config (`PretrainedConfig):\n            The configuration file defining the shape-related attributes required to initialize the static cache.\n        max_batch_size (`int`):\n            The maximum batch size with which the model will be used.\n        max_cache_len (`int`):\n            The maximum sequence length with which the model will be used.\n        device (`torch.device` or `dict`):\n            The device on which the cache should be initialized. 
Should be the same as the layer.\n            If a `dict`, it should contain the `device` key with the device name as the value.\n        dtype (*optional*, defaults to `torch.float32`):\n            The default `dtype` to use when initializing the layer.\n    \"\"\"\n\n    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device: torch.device| dict, dtype=None) -> None:\n        Cache.__init__(self, layer_class_to_replicate=LlamaDecoderLayer)\n        self._max_batch_size = max_batch_size\n\n        if use_torch_npu:\n            self.position = [0]\n\n        self._max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len\n        # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads\n        if config.architectures[0] == \"DeepseekV3ForCausalLM\":\n            self.head_dim = config.qk_rope_head_dim\n        else:\n            self.head_dim = (\n                config.head_dim if hasattr(config, \"head_dim\") else config.hidden_size // config.num_attention_heads\n            )\n\n        self.dtype = dtype if dtype is not None else torch.float32\n        self.num_key_value_heads = (\n            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads\n        )\n\n        self.key_cache: List[torch.Tensor] = []\n        self.value_cache: List[torch.Tensor] = []\n        cache_shape = (max_batch_size, self.num_key_value_heads, self._max_cache_len, self.head_dim)\n        if config.architectures[0] == \"DeepseekV2ForCausalLM\" or config.architectures[0] == \"DeepseekV3ForCausalLM\":\n            # TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically\n\n            if use_torch_npu:\n                self.page_size = 128\n                self.page_size_tensor = torch.tensor(\n                self.page_size,\n                dtype=torch.int32,\n                    ).npu()\n       
         self.max_pages_per_batch = (self._max_cache_len + self.page_size - 1) // self.page_size\n                self.max_pages = (self._max_cache_len + self.page_size - 1) // self.page_size * self._max_batch_size\n            else:\n                self.page_size = 64\n                self.max_pages = (self._max_cache_len + self.page_size - 1) // self.page_size\n            latent_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)\n            self.kv_lora_rank = config.kv_lora_rank\n            self.qk_rope_head_dim = config.qk_rope_head_dim\n            # TODO: support real page table\n            self.page_table_map = dict()\n            self.page_table_list = []\n            for idx in range(config.num_hidden_layers):\n                if isinstance(device, dict):\n                    target_device = device[f\"blk.{idx}.self_attn\"][\"generate_device\"]\n                else:\n                    target_device = device\n                \n                if target_device not in self.page_table_map:\n                    if use_torch_npu:\n                        page_table = torch.zeros((max_batch_size, self.max_pages_per_batch), dtype=torch.int32, device=target_device)\n                        for seq_id in range(max_batch_size):\n                            page_table[seq_id, :] = torch.arange(seq_id * self.max_pages_per_batch, seq_id * self.max_pages_per_batch + self.max_pages_per_batch, dtype=torch.int32, device=target_device)\n                    else:\n                        page_table = torch.zeros((max_batch_size, self.max_pages), dtype=torch.int32, device=target_device)\n                        for seq_id in range(max_batch_size):\n                            page_table[seq_id, :] = torch.arange(seq_id * self.max_pages, seq_id * self.max_pages + self.max_pages, dtype=torch.int32, device=target_device)\n                    self.page_table_map[target_device] = page_table\n                    \n                
self.page_table_list.append(self.page_table_map[target_device])\n                    \n            self.is_MLA = True\n            self.is_page = True\n        else:\n            key_shape = cache_shape\n            value_shape = cache_shape\n            self.is_MLA = False\n\n        self.past_tokens = []\n        self.num_hidden_layers = config.num_hidden_layers\n        for idx in range(self.num_hidden_layers):\n            # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph\n            # breaks when updating the cache.\n            if isinstance(device, dict):\n                target_device = device[f\"blk.{idx}.self_attn\"][\"generate_device\"]\n            else:\n                target_device = device\n            \n            if self.is_MLA:\n                new_layer_key_cache = torch.zeros(latent_shape, dtype=self.dtype, device=target_device)\n                new_layer_value_cache = None\n                torch._dynamo.mark_static_address(new_layer_key_cache)\n            else:\n                new_layer_key_cache = torch.zeros(key_shape, dtype=self.dtype, device=target_device)\n                new_layer_value_cache = torch.zeros(value_shape, dtype=self.dtype, device=target_device)\n                torch._dynamo.mark_static_address(new_layer_key_cache)\n                torch._dynamo.mark_static_address(new_layer_value_cache)\n                \n            self.key_cache.append(new_layer_key_cache)\n            self.value_cache.append(new_layer_value_cache)\n            self.past_tokens.append(0)\n\n    @property\n    def max_batch_size(self):\n        return self._max_batch_size\n\n    @property\n    def max_cache_len(self):\n        return self._max_cache_len\n\n    def update(\n        self,\n        key_states: torch.Tensor,\n        value_states: torch.Tensor,\n        layer_idx: int,\n        cache_kwargs: Optional[Dict[str, Any]] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n 
       Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.\n        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.\n\n        Parameters:\n            key_states (`torch.Tensor`):\n                The new key states to cache.\n            value_states (`torch.Tensor`):\n                The new value states to cache.\n            layer_idx (`int`):\n                The index of the layer to cache the states for.\n            cache_kwargs (`Dict[str, Any]`, `optional`):\n                Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input\n                to know where to write in the cache.\n\n        Return:\n            A tuple containing the updated key and value states.\n        \"\"\"\n        cache_position = cache_kwargs.get(\"cache_position\")\n        k_out = self.key_cache[layer_idx]\n        v_out = self.value_cache[layer_idx]\n        self.past_tokens[layer_idx] += cache_position.size(0)\n        #print(cache_position)\n        if self.is_MLA:\n            if use_torch_npu:\n                page_idx = cache_position // self.page_size_tensor\n                page_offset = cache_position % self.page_size_tensor\n\n                page_idx = page_idx.unsqueeze(0).expand(self.max_batch_size, -1)\n                page_offset = page_offset.unsqueeze(0).expand(self.max_batch_size, -1)\n\n                page_idx_offset = torch.arange(self.max_batch_size, device=page_idx.device) * self.max_pages_per_batch\n                page_idx = page_idx + page_idx_offset.unsqueeze(1)\n\n                combined = torch.cat([key_states, value_states], dim=-1)\n                combined = combined.contiguous()\n                # key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)\n                k_out[page_idx, page_offset] = combined\n            else:\n                page_idx = cache_position // 
self.page_size\n                page_offset = cache_position % self.page_size\n                # key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)\n                k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states\n                k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states\n            return k_out, self.page_table_list[layer_idx]\n        else:\n            k_out[:, :, cache_position] = key_states\n            v_out[:, :, cache_position] = value_states\n            return k_out, v_out\n\n    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:\n        \"\"\"Returns the sequence length of the cached states that were seen by the model.\"\"\"\n        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's\n        # limit the check to the first batch member and head dimension.\n        # TODO: deprecate this function in favor of `cache_position`\n        return self.past_tokens[layer_idx]\n    \n    def change_seq_length(self, bias: Optional[int] = 0) -> int:\n        \"\"\"Shifts the tracked sequence length of every layer by `bias` tokens.\"\"\"\n        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. 
To save on compute, let's\n        # limit the check to the first batch member and head dimension.\n        # TODO: deprecate this function in favor of `cache_position`\n        for layer_idx in range(self.num_hidden_layers):\n            self.past_tokens[layer_idx] += bias\n\n    def get_max_length(self) -> Optional[int]:\n        \"\"\"Returns the maximum sequence length of the cached states.\"\"\"\n        return self.max_cache_len\n    \n    def get_usable_length(self, kv_seq_len, layer_idx: Optional[int] = 0) -> int:\n        return 0\n\n    def reset(self):\n        \"\"\"Resets the cache values while preserving the objects\"\"\"\n        for layer_idx in range(len(self.key_cache)):\n            # In-place ops prevent breaking the static address\n            self.key_cache[layer_idx].zero_()\n            if self.value_cache[layer_idx] is not None:\n                self.value_cache[layer_idx].zero_()\n            self.past_tokens[layer_idx] = 0\n        \n        if use_torch_npu:\n            self.position = [0]\n\n    def remove_suffix(self, start_pos):\n        for layer_idx in range(len(self.key_cache)):\n            # In-place ops prevent breaking the static address\n            if self.is_MLA:\n                k_cache = self.key_cache[layer_idx]\n                k_cache.view(-1, k_cache.shape[-1])[start_pos:].zero_()\n            else:\n                self.key_cache[layer_idx][..., start_pos:, :].zero_()\n                self.value_cache[layer_idx][..., start_pos:, :].zero_()\n            self.past_tokens[layer_idx] = start_pos\n    \n    def get_max_cache_shape(self) -> Tuple[int, int, int, int]:\n        \"\"\"Returns the maximum shape of the cache.\"\"\"\n        return self.max_cache_len\n\nclass KVC2StaticCache:\n    \"\"\"\n    Static Cache class connect with KVC2\n    remind: page_idx & page_offset info need to refs to forward batching, only contains KV Block Tensor here\n    \"\"\"\n    def __init__(self, config: PretrainedConfig, 
max_batch_size, page_size: int = 256, dtype=torch.bfloat16, device=None) -> None:\n        super().__init__()\n        self.config = config\n        self.dtype = dtype\n        self.device = torch.device(\"npu:0\")\n        self.kv_lora_rank = config.kv_lora_rank\n        self.max_batch_size = max_batch_size\n        self.page_size = page_size\n        self.k_caches = []\n        self.v_caches = []\n\n        self.num_hidden_layers = config.num_hidden_layers\n\n        self.is_MLA = True if config.architectures[0] in [\"DeepseekV2ForCausalLM\", \"DeepseekV3ForCausalLM\"] else False\n        # kv cache stored in kvc2\n        # self.past_tokens = []\n\n    def load(self, inference_context):\n        # assert self.is_MLA and len(inference_context.k_cache) == 1, \"currently only support MLA and Cache Pool TP=1\"\n        from ktransformers.util.utils import get_current_device\n        for i in range(self.config.num_hidden_layers):\n            new_layer_key_cache = inference_context.k_cache[int(torch.distributed.get_rank())][i].to(get_current_device())\n            torch._dynamo.mark_static_address(new_layer_key_cache)\n\n            self.k_caches.append(\n                new_layer_key_cache  # [TP_idx, layer_idx, page_idx, page_size, kv_head_num, kv_head_dim]\n            )\n\n            self.v_caches.append(None)\n        self.max_cache_len = self.k_caches[0].shape[0] * self.k_caches[0].shape[1]  # page_len * page_size\n\n    def update(\n        self,\n        combined: torch.Tensor,\n        layer_idx: int,\n        cache_kwargs: Optional[Dict[str, Any]] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.\n\n        Parameters:\n            key_states (`torch.Tensor`):\n                The new key states to cache.\n            value_states (`torch.Tensor`):\n                The new value states to cache.\n            layer_idx (`int`):\n             
   The index of the layer to cache the states for.\n            cache_kwargs (`Dict[str, Any]`, `optional`):\n                must have page_idx (`torch.Tensor`): & page_offset (`torch.Tensor`) & cache_position (`torch.Tensor`)\n\n        Return:\n            A tuple containing the updated key and value states.\n        \"\"\"\n        page_idx, page_offset = cache_kwargs.get(\"page_idx\"), cache_kwargs.get(\"page_offset\")\n        if page_idx is None or page_offset is None:\n            raise ValueError('[ERROR] block info:page_idx & page_offset missing!')\n\n        k_out = self.k_caches[layer_idx]\n        assert self.is_MLA, \"currently only support DeepSeekV3 on NPU balance server\"\n\n        if page_idx.dim() == 1:\n            page_idx_tmp = page_idx.unsqueeze(0)\n            page_offset_tmp = page_offset.unsqueeze(0)\n        else:\n             page_idx_tmp = page_idx\n             page_offset_tmp = page_offset\n\n        k_out[page_idx_tmp, page_offset_tmp] = combined\n        return k_out, page_idx\n\n    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:\n        \"\"\"Returns the sequence length of the cached states that were seen by the model.\"\"\"\n        raise ValueError('kvc2 cache pool no longer hold seq_length info, refer to forward batching')\n\n    def get_usable_length(self, kv_seq_len, layer_idx: Optional[int] = 0) -> int:\n        return 0\n\n    def change_seq_length(self, bias: Optional[int] = 0) -> int:\n        \"\"\"Returns the sequence length of the cached states that were seen by the model.\"\"\"\n        raise ValueError('kvc2 cache pool no longer hold seq_length info, refer to forward batching')\n\n    def get_max_length(self) -> Optional[int]:\n        \"\"\"Returns the maximum sequence length of the cached states.\"\"\"\n        return self.max_cache_len\n\n    def reset(self, inference_context):\n        assert self.is_MLA and len(inference_context.k_cache) == 1, \"currently only support MLA and Cache Pool 
TP=1\"\n        self.k_caches = []\n        self.v_caches = []\n        for i in range(self.config.num_hidden_layers):\n            self.k_caches.append(\n                inference_context.k_cache[0][i]\n            )\n            self.v_caches.append(None)\n        self.max_cache_len = self.k_caches[0].shape[0] * self.k_caches[0].shape[1]  # page_len * page_size\n\n    def get_page_table(self, mini_batch, bsz_tensors: torch.tensor = None, is_prefill=True):\n        if is_prefill:\n            # TODO add padding support\n            q_lens = [mini_batch.p_q_len[idx] for idx in range(mini_batch.prefill_batch)]\n            page_local_idx = -1 * torch.ones(mini_batch.prefill_batch, max(q_lens),\n                                             dtype=mini_batch.p_position_ids.dtype, device=mini_batch.p_position_ids.device)\n            page_offset = -1 * torch.ones_like(page_local_idx)\n            # convert merged into batched\n            start_ids = 0\n            for i in range(mini_batch.prefill_batch):\n                page_offset[i, 0:q_lens[i]] = mini_batch.p_position_ids[start_ids:start_ids+q_lens[i]] % self.page_size\n                page_local_idx[i, 0:q_lens[i]] = mini_batch.p_position_ids[start_ids:start_ids+q_lens[i]] // self.page_size\n                for j in range(q_lens[i]):\n                    # get global page idx index by local page idx from block table, as followed decode\n                    page_local_idx[i, j] = mini_batch.p_block_tables[i, page_local_idx[i, j]]\n                start_ids += q_lens[i]\n            page_idx = page_local_idx\n            # only padding will cause page_local_idx/page_offset still have -1 value\n            # you can use following code as check\n            # indices = torch.where(page_offset == -1)\n            # assert not indices[0].numel() > 0, 'there still have un-calculated page_idx value'\n        else:\n            page_local_idx = mini_batch.d_position_ids // self.page_size\n\n            page_offset = 
mini_batch.d_position_ids % self.page_size\n            \n            for i in range(mini_batch.decode_batch):\n                page_local_idx[i] = mini_batch.d_block_tables[i, page_local_idx[i]]\n            \n            page_idx = page_local_idx\n            \n        return page_idx, page_offset\n\nclass KDeepSeekV3Cache(nn.Module):\n    def __init__(\n        self,\n        config: PretrainedConfig,\n        page_size: int = 256,\n        dtype=torch.bfloat16,\n        device=torch.device(\"cuda:0\"),\n        \n    ):\n        super().__init__()\n        self.config = config\n        self.dtype = dtype\n        self.device = device\n        self.kv_lora_rank = config.kv_lora_rank\n        self.page_size = page_size\n        self.k_caches = []\n        self.v_caches = []\n        \n\n    def load(self, inference_context: \"sched_ext.InferenceContext\"):\n        \n        for i in range(self.config.num_hidden_layers):\n            self.k_caches.append(\n                inference_context.k_cache[0][i] \n            )\n        self.max_cache_len = self.k_caches[0].shape[0]*self.k_caches[0].shape[1]\n\n    def update(\n        self,\n        key_states: torch.Tensor,\n        value_states: torch.Tensor,\n        layer_idx: int,\n\n        page_idx: torch.Tensor,\n        page_offset: torch.Tensor,\n\n        cache_kwargs: Optional[Dict[str, Any]] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.\n        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.\n\n        Parameters:\n            key_states (`torch.Tensor`):\n                The new key states to cache.\n            value_states (`torch.Tensor`):\n                The new value states to cache.\n            layer_idx (`int`):\n                The index of the layer to cache the states for.\n            cache_kwargs (`Dict[str, Any]`, `optional`):\n 
               Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input\n                to know where to write in the cache.\n\n        Return:\n            The updated cache tensor for layer `layer_idx`.\n        \"\"\"\n        k_out = self.k_caches[layer_idx]\n\n        # key_states fill the first kv_lora_rank channels; value_states fill the remaining channels.\n        k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states.reshape(-1, *key_states.shape[2:])\n        k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states.reshape(-1, *value_states.shape[2:])\n        return k_out\n\n        \n    def get_page_table(self, cache_position: torch.Tensor, q_indptr: torch.Tensor, kv_indptr: torch.Tensor, kv_indices: torch.Tensor, bsz_tensors: torch.Tensor):\n        # Split each position into a per-request block index and an offset inside that block.\n        page_offset = cache_position % self.page_size  \n        page_idx_local = cache_position // self.page_size  \n        # Label every token with the index of the request it belongs to.\n        query_ids = torch.zeros_like(cache_position)\n        for i in range(len(q_indptr) - 1):\n            start_idx = q_indptr[i]\n            end_idx = q_indptr[i + 1]\n            query_ids[start_idx:end_idx] = i\n        # Map the per-request block index to a global page id through kv_indptr / kv_indices.\n        page_idx = torch.zeros_like(page_idx_local)\n        for i in range(bsz_tensors[0]):\n            query_id = query_ids[i]\n            local_block = page_idx_local[i]\n            start_block = kv_indptr[query_id]\n            if local_block < kv_indptr[query_id + 1] - kv_indptr[query_id]:\n                page_idx[i] = kv_indices[start_block + local_block]\n        \n        return page_idx, page_offset\n    \nclass KGQACache(nn.Module):\n    def __init__(\n        self,\n        config: PretrainedConfig,\n        page_size: int = 256,\n        dtype=torch.bfloat16,\n        device=torch.device(\"cuda:0\"),\n        \n    ):\n        super().__init__()\n        self.config = config\n        self.dtype = dtype\n        self.device = device\n        self.page_size = page_size\n        self.k_caches = []\n        self.v_caches = []\n        \n\n    def load(self, inference_context: 
\"sched_ext.InferenceContext\"):\n        print(self.config.num_hidden_layers)\n        for i in range(self.config.num_hidden_layers):\n            self.k_caches.append(\n                inference_context.k_cache[0][i] \n            )\n            self.v_caches.append(\n                inference_context.v_cache[0][i]\n            )\n\n\n        self.max_cache_len = self.k_caches[0].shape[0]*self.k_caches[0].shape[1]\n\n\n        \n    def get_page_table(self, cache_position: torch.Tensor, q_indptr: torch.Tensor, kv_indptr: torch.Tensor, kv_indices: torch.Tensor, bsz_tensors: torch.tensor):\n        page_offset = cache_position % self.page_size  \n        page_idx_local = cache_position // self.page_size  \n        query_ids = torch.zeros_like(cache_position)\n        for i in range(len(q_indptr) - 1):\n            start_idx = q_indptr[i]\n            end_idx = q_indptr[i + 1]\n            query_ids[start_idx:end_idx] = i\n        page_idx = torch.zeros_like(page_idx_local)\n        for i in range(bsz_tensors[0]):\n            query_id = query_ids[i]\n            local_block = page_idx_local[i]\n            start_block = kv_indptr[query_id]\n            if local_block < kv_indptr[query_id + 1] - kv_indptr[query_id]:\n                page_idx[i] = kv_indices[start_block + local_block]\n        \n        return page_idx, page_offset\n\n    def get_k_cache(self, layer_idx):\n        return self.k_caches[layer_idx]\n\n    def get_v_cache(self, layer_idx):\n        return self.v_caches[layer_idx]\n\n\nclass KVC2Qwen3Cache(nn.Module):\n\n    def __init__(self, config, max_batch_size, page_size=256,\n                 dtype=torch.bfloat16, device=None):\n        super().__init__()\n        self.config = config\n        self.max_batch_size = max_batch_size\n        self.page_size = page_size\n        self.dtype = dtype\n        self.device = device if device else torch.device(\"npu:0\")\n\n        self.num_layers = config.num_hidden_layers\n        self.num_kv_heads = 
config.num_key_value_heads\n        self.head_dim = config.head_dim\n\n        self.k_caches = []\n        self.v_caches = []\n\n\n    # ------------------------- bind to the underlying kvc2 pool -------------------------\n\n    def load(self, inference_context):\n        from ktransformers.util.utils import get_current_device\n        dev = get_current_device()\n\n        self.k_caches = []\n        self.v_caches = []\n\n        rank = (\n            torch.distributed.get_rank()\n            if (torch.distributed.is_available() and torch.distributed.is_initialized())\n            else 0\n        )\n\n        for i in range(self.num_layers):\n            k_buf = inference_context.k_cache[rank][i].to(dev).to(self.dtype)\n            v_buf = inference_context.v_cache[rank][i].to(dev).to(self.dtype)\n\n            torch._dynamo.mark_static_address(k_buf)\n            torch._dynamo.mark_static_address(v_buf)\n\n            self.k_caches.append(k_buf)\n            self.v_caches.append(v_buf)\n\n        # num_pages * page_size\n        self.max_cache_len = self.k_caches[0].shape[0] * self.k_caches[0].shape[1]\n\n    # ------------------------- write KV -------------------------\n    @torch.no_grad()\n    def update(\n        self,\n        key_states: torch.Tensor,\n        value_states: torch.Tensor,\n        layer_idx: int,\n        cache_kwargs: Optional[Dict[str, Any]] = None,\n    ):\n        if cache_kwargs is None:\n            raise ValueError(\"[KVC2Qwen3Cache] cache_kwargs must contain page_idx & page_offset\")\n\n        page_idx: Optional[torch.Tensor] = cache_kwargs.get(\"page_idx\", None)\n        page_offset: Optional[torch.Tensor] = cache_kwargs.get(\"page_offset\", None)\n\n        if page_idx is None or page_offset is None:\n            raise ValueError(\"[KVC2Qwen3Cache] page_idx & page_offset are required in cache_kwargs\")\n\n        k_out = self.k_caches[layer_idx]\n        v_out = self.v_caches[layer_idx]\n\n        # -------- 1) fix dimension order: [B, KvH, Q, D] -> [B, Q, KvH, D] 
--------\n        if key_states.dim() == 4 and key_states.shape[1] == self.num_kv_heads:\n            key_states = key_states.transpose(1, 2).contiguous()\n            value_states = value_states.transpose(1, 2).contiguous()\n\n        if key_states.shape != value_states.shape:\n            raise ValueError(\n                f\"[KVC2Qwen3Cache] key_states.shape {key_states.shape} \"\n                f\"!= value_states.shape {value_states.shape}\"\n            )\n\n        if key_states.dim() != 4:\n            raise ValueError(\n                f\"[KVC2Qwen3Cache] expect key_states dim=4, got {key_states.dim()} \"\n                f\"(shape={key_states.shape})\"\n            )\n\n        bsz, q_len, kv_heads, head_dim = key_states.shape\n\n        if kv_heads != self.num_kv_heads or head_dim != self.head_dim:\n            raise ValueError(\n                f\"[KVC2Qwen3Cache] KV shape mismatch: \"\n                f\"got num_kv_heads={kv_heads}, head_dim={head_dim}, \"\n                f\"expected num_kv_heads={self.num_kv_heads}, head_dim={self.head_dim}\"\n            )\n\n        # -------- 2) flatten page_idx / page_offset to 1-D --------\n        page_idx = page_idx.reshape(-1)\n        page_offset = page_offset.reshape(-1)\n\n        # -------- 3) flatten KV and cast its dtype to match the cache --------\n        val_dtype = k_out.dtype\n        flat_k = key_states.to(val_dtype).reshape(-1, kv_heads, head_dim)\n        flat_v = value_states.to(val_dtype).reshape(-1, kv_heads, head_dim)\n\n        # -------- 4) write K / V into the cache --------\n        # k_out / v_out: [num_pages, page_size, num_kv_heads, head_dim]\n        k_out[page_idx, page_offset] = flat_k\n        v_out[page_idx, page_offset] = flat_v\n\n    # ------------------------- get K/V -------------------------\n    def get_k_cache(self, layer_idx):\n        return self.k_caches[layer_idx]\n\n    def get_v_cache(self, layer_idx):\n        return self.v_caches[layer_idx]\n\n    # ------------------------- page table computation 
-------------------------\n    def get_page_table(\n        self,\n        mini_batch,\n        bsz_tensors: torch.Tensor = None,\n        is_prefill: bool = True,\n    ):\n        if is_prefill:\n            # prefill: merged positions => batched (B, T_chunk)\n            q_lens = [int(mini_batch.p_q_len[idx]) for idx in range(mini_batch.prefill_batch)]\n            if len(q_lens) == 0:\n                return None, None\n\n            max_q_len = max(q_lens)\n\n            page_local_idx = -1 * torch.ones(\n                mini_batch.prefill_batch,\n                max_q_len,\n                dtype=mini_batch.p_position_ids.dtype,\n                device=mini_batch.p_position_ids.device,\n            )\n            page_offset = -1 * torch.ones_like(page_local_idx)\n\n            start_ids = 0\n            for i in range(mini_batch.prefill_batch):\n                cur_len = q_lens[i]\n                pos = mini_batch.p_position_ids[start_ids:start_ids + cur_len]  # global pos of this chunk\n\n                # local block + offset by page_size\n                page_offset[i, 0:cur_len] = pos % self.page_size\n                page_local_idx[i, 0:cur_len] = pos // self.page_size\n\n                # local block -> global page id via block_tables\n                for j in range(cur_len):\n                    blk = page_local_idx[i, j]\n                    page_local_idx[i, j] = mini_batch.p_block_tables[i, blk]\n\n                start_ids += cur_len\n\n            page_idx = page_local_idx\n        else:\n            # decode: decode_batch = batch size of the current step; each sample usually has 1 token\n            page_local_idx = mini_batch.d_position_ids // self.page_size\n            page_offset = mini_batch.d_position_ids % self.page_size\n\n            for i in range(mini_batch.decode_batch):\n                blk = page_local_idx[i]\n                page_local_idx[i] = mini_batch.d_block_tables[i, blk]\n\n            page_idx = page_local_idx\n\n        return page_idx, 
page_offset\n"
  },
  {
    "path": "archive/ktransformers/models/custom_modeling_deepseek_v2.py",
    "content": "import math\nfrom dataclasses import dataclass\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nimport math\nfrom typing import List, Optional, Tuple, Union\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput\nfrom ktransformers.models.custom_cache import KDeepSeekV3Cache\nfrom  ktransformers.models.modeling_deepseek import DeepseekV2Model,  DeepseekV2PreTrainedModel\nfrom ktransformers.models.configuration_deepseek import DeepseekV2Config\n\n\ntorch.set_grad_enabled(False)\ntorch.set_default_dtype(torch.bfloat16)\nimport flashinfer\n\nclass KDeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):\n\n    kv_cache: KDeepSeekV3Cache\n    use_cuda_graph = False\n    def __init__(\n        self,\n        config,\n        kv_cache,\n\n    ):\n        super().__init__(config)\n        self.model = DeepseekV2Model(config)\n        self.config = config\n        self.kv_cache = kv_cache\n\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        \n\n    def init_wrapper(self, use_cuda_graph, device, max_batch_size, max_pages):\n        self.use_cuda_graph = use_cuda_graph\n        self.workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)\n        self.qo_indptr_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device)\n        self.paged_kv_indptr_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device)\n        self.paged_kv_indices_buf = torch.empty((max_pages,), dtype=torch.int32, device=device)\n        self.paged_kv_len_buf = torch.empty((max_batch_size,), dtype=torch.int32, device=device)\n\n\t\t\n\n        self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(\n            self.workspace_buffer, use_cuda_graph=use_cuda_graph,\n            
qo_indptr=self.qo_indptr_buf,kv_indptr=self.paged_kv_indptr_buf,\n            kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf,\n            backend = \"fa2\",\n        )\n\n    def batch_embeddings(self, batch: ForwardBatchInput, device=\"cuda:0\"):\n        features = []\n        for i in range(batch.batch_size):\n            tokens = batch.minibatch.tokens.contiguous()\n            feature = (\n                self.model.embed_tokens(tokens.to(torch.device('cpu')))\n                .to(torch.bfloat16)\n                .to(device=device)\n            )\n            features.append(feature)\n\n        return features\n\n\n    def forward(\n        self,\n        batch: ForwardBatchInput | None = None,\n        features: List[torch.Tensor] | None = None,\n        bsz_tensors: torch.Tensor | None = None,\n        num_tokens_tensors: torch.Tensor | None = None,\n        page_idx: torch.Tensor | None = None,\n        page_offset: torch.Tensor | None = None,\n    ) -> ForwardBatchOutput:\n        current_stream = torch.cuda.current_stream()\n\n        forward_batch_output = ForwardBatchOutput()\n\n        \n        hidden_states = features[0]\n\n\n        with torch.cuda.stream(current_stream):\n            residual = torch.zeros_like(hidden_states)\n            for i, decode_layer in enumerate(self.model.layers):\n                if self.model.transfer_map is not None and i in self.model.transfer_map:\n                    prev_stream = torch.cuda.current_stream()\n                    cur_device = self.model.transfer_map[i]\n                    if cur_device not in self.model.stream_device_map:\n                        self.model.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)\n                    torch.cuda.set_device(cur_device)\n                    self.model.stream_device_map[cur_device].wait_stream(prev_stream)\n                    torch.cuda.set_stream(self.model.stream_device_map[cur_device])\n                    hidden_states 
= hidden_states.to(\n                        self.model.transfer_map[i], non_blocking=True\n                    )\n\n                    batch.minibatch.position_ids = (\n                        batch.minibatch.position_ids.to(self.model.transfer_map[i], non_blocking=True)\n                        if batch.minibatch.position_ids is not None\n                        else None\n                    )\n                hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)\n                hidden_states = decode_layer.self_attn(hidden_states, self.kv_cache, \n                                                       position_ids=batch.minibatch.position_ids, \n                                                       wrapper=self.wrapper, bsz_tensors=num_tokens_tensors, \n                                                       cache_position=batch.minibatch.positions, \n                                                       batch_indices=batch.minibatch.batch_indices,\n                                                       kv_indices=batch.minibatch.kv_indices,\n                                                       kv_indptr=batch.minibatch.kv_indptr,\n                                                       kv_last_page_len=batch.minibatch.kv_last_page_len,\n                                                       q_indptr=batch.minibatch.q_indptr,\n                                                       page_idx=page_idx,\n                                                       page_offset=page_offset\n                                                       )\n\n                hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual)\n                if i < 3:\n                    hidden_states = decode_layer.mlp(hidden_states, num_tokens_tensors)\n                else:\n                    hidden_states = decode_layer.mlp(hidden_states.unsqueeze(0), num_tokens_tensors)\n               
     hidden_states = hidden_states.squeeze(0)\n        forward_batch_output = ForwardBatchOutput()\n        assert  batch.batch_size == 1\n        with torch.cuda.stream(current_stream):\n\n            local_logit = self.lm_head(self.model.norm(hidden_states[batch.minibatch.logits_start], num_tokens_tensors, residual[batch.minibatch.logits_start])[0])\n            # local_logit = local_logit[batch.minibatch.logits_start]\n            forward_batch_output.logits.append(local_logit)\n\n        return forward_batch_output\n    \n\n               \n    def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors,\n        num_heads: int,\n        head_dim_ckv: int,\n        head_dim_kpe: int,\n        page_size: int,\n        causal: bool,\n        sm_scale: float,\n        q_data_type: torch.dtype,\n        kv_data_type: torch.dtype,):\n        minibatch = batch.minibatch\n        \n        self.wrapper.plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices, \n                          minibatch.kv_len, num_heads, head_dim_ckv, head_dim_kpe, page_size, causal, sm_scale, q_data_type, kv_data_type)\n        "
  },
  {
    "path": "archive/ktransformers/models/custom_modeling_deepseek_v3.py",
    "content": "\"\"\"\nDate: 2024-11-06 10:05:11\nLastEditors: djw\nLastEditTime: 2024-11-13 07:50:51\n\"\"\"\n\nimport math\nfrom dataclasses import dataclass\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nimport math\nfrom typing import List, Optional, Tuple, Union\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput\nfrom ktransformers.models.custom_cache import KDeepSeekV3Cache\nfrom ktransformers.models.modeling_deepseek_v3 import DeepseekV3Model,  DeepseekV3PreTrainedModel\nfrom ktransformers.models.configuration_deepseek_v3 import DeepseekV3Config\n\n\ntorch.set_grad_enabled(False)\ntorch.set_default_dtype(torch.bfloat16)\nimport flashinfer\n\nclass KDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):\n\n    cache: KDeepSeekV3Cache\n    use_cuda_graph = False\n    def __init__(\n        self,\n        config: DeepseekV3Config,\n        cache,\n    ):\n        super().__init__(config)\n        self.model = DeepseekV3Model(config)\n        self.config = config\n        self.cache = cache\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        \n    def init_wrapper(self, use_cuda_graph, device, max_batch_size, max_pages):\n        self.use_cuda_graph = use_cuda_graph\n        self.workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)\n        self.qo_indptr_buf = torch.empty((max_batch_size+2,), dtype=torch.int32, device=device)\n        self.paged_kv_indptr_buf = torch.empty((max_batch_size+2,), dtype=torch.int32, device=device)\n        self.paged_kv_indices_buf = torch.empty((max_pages,), dtype=torch.int32, device=device)\n        self.paged_kv_len_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device)\n        self.bsz_tensor_buf = torch.empty((1, ), dtype=torch.int32, 
device=device)\n\t\t\n\n        self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(\n            self.workspace_buffer, use_cuda_graph=use_cuda_graph,\n            qo_indptr=self.qo_indptr_buf,kv_indptr=self.paged_kv_indptr_buf,\n            kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf,\n            bsz_tensor=self.bsz_tensor_buf,\n            backend = \"fa2\",\n        )\n\n    def batch_embeddings(self, batch: ForwardBatchInput, device=\"cuda:0\"):\n        features = []\n        for i in range(batch.batch_size):\n            tokens = batch.minibatch.tokens.contiguous()\n            feature = (\n                self.model.embed_tokens(tokens.to(torch.device('cpu')))\n                .to(torch.bfloat16)\n                .to(device=device)\n            )\n            features.append(feature)\n\n        return features\n\n\n    def forward(\n        self,\n        batch: ForwardBatchInput | None = None,\n        features: List[torch.Tensor] | None = None,\n        bsz_tensors: torch.Tensor | None = None,\n        num_tokens_tensors: torch.Tensor | None = None,\n        page_idx: torch.Tensor | None = None,\n        page_offset: torch.Tensor | None = None,\n        cuda_graph_idx: int | None = -1\n    ) -> ForwardBatchOutput:\n        current_stream = torch.cuda.current_stream()\n\n        forward_batch_output = ForwardBatchOutput()\n\n        \n        hidden_states = features[0]\n\n        with torch.cuda.stream(current_stream):\n            residual = torch.zeros_like(hidden_states)\n            for i, decode_layer in enumerate(self.model.layers):\n                # can't use now, only one flashinfer wrapper\n                if self.model.transfer_map is not None and i in self.model.transfer_map:\n                    prev_stream = torch.cuda.current_stream()\n                    cur_device = self.model.transfer_map[i]\n                    if cur_device not in self.model.stream_device_map:\n                        
self.model.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)\n                    torch.cuda.set_device(cur_device)\n                    self.model.stream_device_map[cur_device].wait_stream(prev_stream)\n                    torch.cuda.set_stream(self.model.stream_device_map[cur_device])\n                    hidden_states = hidden_states.to(\n                        self.model.transfer_map[i], non_blocking=True\n                    )\n\n                    batch.minibatch.position_ids = (\n                        batch.minibatch.position_ids.to(self.model.transfer_map[i], non_blocking=True)\n                        if batch.minibatch.position_ids is not None\n                        else None\n                    )\n                hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)\n                hidden_states = decode_layer.self_attn(hidden_states, self.cache, \n                                                       position_ids=batch.minibatch.position_ids, \n                                                       wrapper=self.wrapper, num_tokens_tensors=num_tokens_tensors, \n                                                       page_idx=page_idx,\n                                                       page_offset=page_offset\n                                                       )\n\n                hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual)\n                if i < self.config.first_k_dense_replace:\n                    hidden_states = decode_layer.mlp(hidden_states, num_tokens_tensors)\n                else:\n                    hidden_states = decode_layer.mlp(hidden_states.unsqueeze(0), num_tokens_tensors, cuda_graph_idx)\n                    hidden_states = hidden_states.squeeze(0)\n        forward_batch_output = ForwardBatchOutput()\n        with torch.cuda.stream(current_stream):\n            local_logit = 
self.lm_head(self.model.norm(hidden_states, num_tokens_tensors, residual)[0], num_tokens_tensors)\n            forward_batch_output.logits.append(local_logit)\n\n        return forward_batch_output\n    \n\n               \n    def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors,\n        num_heads: int,\n        head_dim_ckv: int,\n        head_dim_kpe: int,\n        page_size: int,\n        causal: bool,\n        sm_scale: float,\n        q_data_type: torch.dtype,\n        kv_data_type: torch.dtype,):\n        minibatch = batch.minibatch\n        self.wrapper.plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices, \n                          minibatch.kv_len, num_heads, head_dim_ckv, head_dim_kpe, page_size, causal, sm_scale, q_data_type, kv_data_type, bsz_tensors)\n        "
  },
  {
    "path": "archive/ktransformers/models/custom_modeling_glm4_moe.py",
    "content": "\"\"\"\nDate: 2024-11-06 10:05:11\nLastEditors: djw\nLastEditTime: 2024-11-13 07:50:51\n\"\"\"\n\nimport math\nfrom dataclasses import dataclass\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nimport math\nfrom typing import List, Optional, Tuple, Union\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput\nfrom ktransformers.models.custom_cache import KGQACache\nfrom ktransformers.models.modeling_glm4_moe import Glm4MoeModel,  Glm4MoePreTrainedModel\nfrom ktransformers.models.configuration_glm4_moe import Glm4MoeConfig\nfrom ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn\n\ntorch.set_grad_enabled(False)\ntorch.set_default_dtype(torch.bfloat16)\nimport flashinfer\n\nclass KGlm4MoeForCausalLM(Glm4MoePreTrainedModel):\n\n    cache: KGQACache\n    use_cuda_graph = False\n    def __init__(\n        self,\n        config: Glm4MoeConfig,\n        cache,\n    ):\n\n        super().__init__(config)\n        self.model = Glm4MoeModel(config)\n        self.config = config\n        self.cache = cache\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.attn = [None] * 100\n        \n    def init_wrapper(self, use_cuda_graph, device, max_batch_token, max_batch_size, max_pages, cuda_graph_idx = 0):\n        self.attn[cuda_graph_idx] = flashInferAttn(use_cuda_graph=use_cuda_graph, max_batch_token=max_batch_token, max_batch_size=max_batch_size, max_pages=max_pages, device=device)\n\n\n    def batch_embeddings(self, batch: ForwardBatchInput, device=\"cuda:0\"):\n        features = []\n        for i in range(batch.batch_size):\n            tokens = batch.minibatch.tokens.contiguous()\n            feature = (\n                self.model.embed_tokens(tokens.to(torch.device('cpu')))\n                
.to(torch.bfloat16)\n                .to(device=device)\n            )\n            features.append(feature)\n\n        return features\n\n\n    def forward(\n        self,\n        batch: ForwardBatchInput | None = None,\n        features: List[torch.Tensor] | None = None,\n        bsz_tensors: torch.Tensor | None = None,\n        num_tokens_tensors: torch.Tensor | None = None,\n        page_idx: torch.Tensor | None = None,\n        page_offset: torch.Tensor | None = None,\n        cuda_graph_idx: int | None = 0\n    ) -> ForwardBatchOutput:\n        current_stream = torch.cuda.current_stream()\n\n        forward_batch_output = ForwardBatchOutput()\n\n        \n        hidden_states = features[0]\n        self.attn[cuda_graph_idx].calc_batch_indices(hidden_states.shape[0])\n\n        freqs_cis = self.model.rotary_emb(hidden_states.unsqueeze(0), batch.minibatch.position_ids.unsqueeze(0))\n\n\n        with torch.cuda.stream(current_stream):\n            residual = torch.zeros_like(hidden_states)\n            for i, decode_layer in enumerate(self.model.layers):\n\n                hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)\n                hidden_states = decode_layer.self_attn(hidden_states, self.cache, \n                                                       freqs_cis,\n                                                       wrapper=self.attn[cuda_graph_idx], bsz_tensors=num_tokens_tensors, \n                                                       position_ids=batch.minibatch.position_ids\n                                                       )\n\n                hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual)\n                if i < self.model.config.first_k_dense_replace:\n                    hidden_states = decode_layer.mlp(hidden_states, num_tokens_tensors)\n                else:\n                    hidden_states = decode_layer.mlp(hidden_states, 
num_tokens_tensors, cuda_graph_idx)\n                    # hidden_states = hidden_states.squeeze(0)\n\n        forward_batch_output = ForwardBatchOutput()\n        with torch.cuda.stream(current_stream):\n            local_logit = self.lm_head(self.model.norm(hidden_states, num_tokens_tensors, residual)[0], num_tokens_tensors)\n            forward_batch_output.logits.append(local_logit)\n\n        return forward_batch_output\n    \n\n               \n    def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors,\n        num_q_heads: int,\n        num_kv_heads: int,\n        head_dim: int,\n        page_size: int,\n        causal: bool,\n        q_data_type: torch.dtype,\n        kv_data_type: torch.dtype,\n        cuda_graph_idx: int = 0\n        ):\n        minibatch = batch.minibatch\n        self.attn[cuda_graph_idx].plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices, \n                          minibatch.kv_last_page_len, bsz_tensors, num_tokens_tensors, num_q_heads, num_kv_heads, head_dim, page_size, causal=causal, q_data_type=q_data_type, kv_data_type=kv_data_type)\n        "
  },
  {
    "path": "archive/ktransformers/models/custom_modeling_qwen2_moe.py",
    "content": "\"\"\"\nDate: 2024-11-06 10:05:11\nLastEditors: djw\nLastEditTime: 2024-11-13 07:50:51\n\"\"\"\n\nimport math\nfrom dataclasses import dataclass\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nimport math\nfrom typing import List, Optional, Tuple, Union\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput\nfrom ktransformers.models.custom_cache import KGQACache\nfrom ktransformers.models.modeling_qwen2_moe import Qwen2MoeModel, Qwen2MoePreTrainedModel\nfrom ktransformers.models.configuration_qwen2_moe import Qwen2MoeConfig\nfrom ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn\n\ntorch.set_grad_enabled(False)\ntorch.set_default_dtype(torch.bfloat16)\nimport flashinfer\n\nclass KQwen2MoeForCausalLM(Qwen2MoePreTrainedModel):\n\n    cache: KGQACache\n    use_cuda_graph = False\n    def __init__(\n        self,\n        config: Qwen2MoeConfig,\n        cache,\n    ):\n        super().__init__(config)\n        self.model = Qwen2MoeModel(config)\n        self.config = config\n        self.cache = cache\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.attn = [None] * 100\n        \n    def init_wrapper(self, use_cuda_graph, device, max_batch_token, max_batch_size, max_pages, cuda_graph_idx = 0):\n        self.attn[cuda_graph_idx] = flashInferAttn(use_cuda_graph=use_cuda_graph, max_batch_token=max_batch_token, max_batch_size=max_batch_size, max_pages=max_pages, device=device)\n\n\n    def batch_embeddings(self, batch: ForwardBatchInput, device=\"cuda:0\"):\n        features = []\n        for i in range(batch.batch_size):\n            tokens = batch.minibatch.tokens.contiguous()\n            feature = (\n                self.model.embed_tokens(tokens.to(torch.device('cpu')))\n           
     .to(torch.bfloat16)\n                .to(device=device)\n            )\n            features.append(feature)\n\n        return features\n\n\n    def forward(\n        self,\n        batch: ForwardBatchInput | None = None,\n        features: List[torch.Tensor] | None = None,\n        bsz_tensors: torch.Tensor | None = None,\n        num_tokens_tensors: torch.Tensor | None = None,\n        page_idx: torch.Tensor | None = None,\n        page_offset: torch.Tensor | None = None,\n        cuda_graph_idx: int | None = 0\n    ) -> ForwardBatchOutput:\n        current_stream = torch.cuda.current_stream()\n\n        forward_batch_output = ForwardBatchOutput()\n\n        \n        hidden_states = features[0]\n        self.attn[cuda_graph_idx].calc_batch_indices(hidden_states.shape[0])\n\n        with torch.cuda.stream(current_stream):\n            residual = torch.zeros_like(hidden_states)\n            for i, decode_layer in enumerate(self.model.layers):\n                if self.model.transfer_map is not None and i in self.model.transfer_map:\n                    prev_stream = torch.cuda.current_stream()\n                    cur_device = self.model.transfer_map[i]\n                    if cur_device not in self.model.stream_device_map:\n                        self.model.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)\n                    torch.cuda.set_device(cur_device)\n                    self.model.stream_device_map[cur_device].wait_stream(prev_stream)\n                    torch.cuda.set_stream(self.model.stream_device_map[cur_device])\n                    hidden_states = hidden_states.to(\n                        self.model.transfer_map[i], non_blocking=True\n                    )\n\n                    batch.minibatch.position_ids = (\n                        batch.minibatch.position_ids.to(self.model.transfer_map[i], non_blocking=True)\n                        if batch.minibatch.position_ids is not None\n                        else None\n            
        )\n                hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)\n                hidden_states = decode_layer.self_attn(hidden_states, self.cache, \n                                                       position_ids=batch.minibatch.position_ids, \n                                                       wrapper=self.attn[cuda_graph_idx], bsz_tensors=num_tokens_tensors, \n                                                       page_idx=page_idx,\n                                                       page_offset=page_offset\n                                                       )\n\n                hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual)\n                hidden_states = decode_layer.mlp(hidden_states.unsqueeze(0), num_tokens_tensors, cuda_graph_idx)\n                hidden_states = hidden_states.squeeze(0)\n        forward_batch_output = ForwardBatchOutput()\n        with torch.cuda.stream(current_stream):\n            local_logit = self.lm_head(self.model.norm(hidden_states, num_tokens_tensors, residual)[0], num_tokens_tensors)\n            forward_batch_output.logits.append(local_logit)\n\n        return forward_batch_output\n    \n\n               \n    def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors,\n        num_q_heads: int,\n        num_kv_heads: int,\n        head_dim: int,\n        page_size: int,\n        causal: bool,\n        q_data_type: torch.dtype,\n        kv_data_type: torch.dtype,\n        cuda_graph_idx: int = 0\n        ):\n        minibatch = batch.minibatch\n        self.attn[cuda_graph_idx].plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices, \n                          minibatch.kv_last_page_len, bsz_tensors, num_tokens_tensors,num_q_heads, num_kv_heads, head_dim, page_size, causal=causal, q_data_type=q_data_type, kv_data_type=kv_data_type)\n        "
  },
  {
    "path": "archive/ktransformers/models/custom_modeling_qwen3_moe.py",
    "content": "\"\"\"\nDate: 2024-11-06 10:05:11\nLastEditors: djw\nLastEditTime: 2024-11-13 07:50:51\n\"\"\"\n\nimport math\nfrom dataclasses import dataclass\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nimport math\nfrom typing import List, Optional, Tuple, Union\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput\nfrom ktransformers.models.custom_cache import KGQACache\nfrom ktransformers.models.modeling_qwen3_moe import Qwen3MoeModel, Qwen3MoePreTrainedModel\nfrom ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig\nfrom ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn\n\ntorch.set_grad_enabled(False)\ntorch.set_default_dtype(torch.bfloat16)\nimport flashinfer\n\nclass KQwen3MoeForCausalLM(Qwen3MoePreTrainedModel):\n\n    cache: KGQACache\n    use_cuda_graph = False\n    def __init__(\n        self,\n        config: Qwen3MoeConfig,\n        cache = None,\n    ):\n        super().__init__(config)\n        self.model = Qwen3MoeModel(config)\n        self.config = config\n        self.cache = cache\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.attn = [None] * 100\n        \n    def init_wrapper(self, use_cuda_graph, device, max_batch_token, max_batch_size, max_pages, cuda_graph_idx = 0):\n        self.attn[cuda_graph_idx] = flashInferAttn(use_cuda_graph=use_cuda_graph, max_batch_token=max_batch_token, max_batch_size=max_batch_size, max_pages=max_pages, device=device)\n\n\n    def batch_embeddings(self, batch: ForwardBatchInput, device=\"cuda:0\"):\n        features = []\n        for i in range(batch.batch_size):\n            tokens = batch.minibatch.tokens.contiguous()\n            feature = (\n                self.model.embed_tokens(tokens.to(torch.device('cpu')))\n    
            .to(torch.bfloat16)\n                .to(device=device)\n            )\n            features.append(feature)\n\n        return features\n\n\n    def forward(\n        self,\n        batch: ForwardBatchInput | None = None,\n        features: List[torch.Tensor] | None = None,\n        bsz_tensors: torch.Tensor | None = None,\n        num_tokens_tensors: torch.Tensor | None = None,\n        page_idx: torch.Tensor | None = None,\n        page_offset: torch.Tensor | None = None,\n        cuda_graph_idx: int | None = 0\n    ) -> ForwardBatchOutput:\n        current_stream = torch.cuda.current_stream()\n\n        forward_batch_output = ForwardBatchOutput()\n\n        \n        hidden_states = features[0]\n        self.attn[cuda_graph_idx].calc_batch_indices(hidden_states.shape[0])\n\n        with torch.cuda.stream(current_stream):\n            residual = torch.zeros_like(hidden_states)\n            for i, decode_layer in enumerate(self.model.layers):\n                if self.model.transfer_map is not None and i in self.model.transfer_map:\n                    prev_stream = torch.cuda.current_stream()\n                    cur_device = self.model.transfer_map[i]\n                    if cur_device not in self.model.stream_device_map:\n                        self.model.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)\n                    torch.cuda.set_device(cur_device)\n                    self.model.stream_device_map[cur_device].wait_stream(prev_stream)\n                    torch.cuda.set_stream(self.model.stream_device_map[cur_device])\n                    hidden_states = hidden_states.to(\n                        self.model.transfer_map[i], non_blocking=True\n                    )\n\n                    batch.minibatch.position_ids = (\n                        batch.minibatch.position_ids.to(self.model.transfer_map[i], non_blocking=True)\n                        if batch.minibatch.position_ids is not None\n                        else None\n     
               )\n                hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)\n                hidden_states = decode_layer.self_attn(hidden_states, self.cache, \n                                                       position_ids=batch.minibatch.position_ids, \n                                                       wrapper=self.attn[cuda_graph_idx], bsz_tensors=num_tokens_tensors, \n                                                       page_idx=page_idx,\n                                                       page_offset=page_offset\n                                                       )\n\n                hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual)\n                hidden_states = decode_layer.mlp(hidden_states.unsqueeze(0), num_tokens_tensors, cuda_graph_idx)\n                hidden_states = hidden_states.squeeze(0)\n        forward_batch_output = ForwardBatchOutput()\n        with torch.cuda.stream(current_stream):\n            local_logit = self.lm_head(self.model.norm(hidden_states, num_tokens_tensors, residual)[0], num_tokens_tensors)\n            forward_batch_output.logits.append(local_logit)\n\n        return forward_batch_output\n    \n\n               \n    def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors,\n        num_q_heads: int,\n        num_kv_heads: int,\n        head_dim: int,\n        page_size: int,\n        causal: bool,\n        q_data_type: torch.dtype,\n        kv_data_type: torch.dtype,\n        cuda_graph_idx: int = 0\n        ):\n        minibatch = batch.minibatch\n        self.attn[cuda_graph_idx].plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices, \n                          minibatch.kv_last_page_len, bsz_tensors, num_tokens_tensors, num_q_heads, num_kv_heads, head_dim, page_size, causal=causal, q_data_type=q_data_type, kv_data_type=kv_data_type)\n        "
  },
  {
    "path": "archive/ktransformers/models/custom_modeling_qwen3_next.py",
    "content": "\"\"\"\nDate: 2024-11-06 10:05:11\nLastEditors: djw\nLastEditTime: 2024-11-13 07:50:51\n\"\"\"\n\nimport math\nfrom dataclasses import dataclass\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nimport math\nfrom typing import List, Optional, Tuple, Union\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput\nfrom ktransformers.models.custom_cache import KGQACache\nfrom ktransformers.models.modeling_qwen3_next import Qwen3NextModel, Qwen3NextPreTrainedModel\nfrom ktransformers.models.configuration_qwen3_next import Qwen3NextConfig\nfrom ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn\n\ntorch.set_grad_enabled(False)\ntorch.set_default_dtype(torch.bfloat16)\nimport flashinfer\n\nclass KQwen3NextForCausalLM(Qwen3NextPreTrainedModel):\n\n    cache: KGQACache\n    use_cuda_graph = False\n    def __init__(\n        self,\n        config: Qwen3NextConfig,\n        cache = None,\n    ):\n        super().__init__(config)\n        self.model = Qwen3NextModel(config)\n        self.config = config\n        self.cache = cache\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.attn = [None] * 100\n        self.conv_states = [None for _ in range(config.num_hidden_layers)]\n        self.recurrent_states = [None for _ in range(config.num_hidden_layers)]\n        \n    def init_wrapper(self, use_cuda_graph, device, max_batch_token, max_batch_size, max_pages, cuda_graph_idx = 0):\n        self.attn[cuda_graph_idx] = flashInferAttn(use_cuda_graph=use_cuda_graph, max_batch_token=max_batch_token, max_batch_size=max_batch_size, max_pages=max_pages, device=device)\n\n\n    def batch_embeddings(self, batch: ForwardBatchInput, device=\"cuda:0\"):\n        features = []\n        for i in 
range(batch.batch_size):\n            tokens = batch.minibatch.tokens.contiguous()\n            feature = (\n                self.model.embed_tokens(tokens.to(torch.device('cpu')))\n                .to(torch.bfloat16)\n                .to(device=device)\n            )\n            features.append(feature)\n\n        return features\n\n    def reset_conv_states(self):\n        for i in range(self.config.num_hidden_layers):\n            self.conv_states[i] = None\n            self.recurrent_states[i] = None\n\n\n    def forward(\n        self,\n        batch: ForwardBatchInput | None = None,\n        features: List[torch.Tensor] | None = None,\n        bsz_tensors: torch.Tensor | None = None,\n        num_tokens_tensors: torch.Tensor | None = None,\n        page_idx: torch.Tensor | None = None,\n        page_offset: torch.Tensor | None = None,\n        cuda_graph_idx: int | None = 0\n    ) -> ForwardBatchOutput:\n        current_stream = torch.cuda.current_stream()\n\n        forward_batch_output = ForwardBatchOutput()\n\n        q_len = features[0].size(0)\n        if q_len > 1:\n            self.reset_conv_states()\n\n        hidden_states = features[0]\n        self.attn[cuda_graph_idx].calc_batch_indices(hidden_states.shape[0])\n        freqs_cis = self.model.rotary_emb(hidden_states.unsqueeze(0), batch.minibatch.position_ids.unsqueeze(0))\n\n        residual = torch.zeros_like(hidden_states)\n        for i, decode_layer in enumerate(self.model.layers):\n            hidden_states = hidden_states.contiguous().clone()   # break aliasing + make contiguous\n            residual      = residual.contiguous().clone()\n\n            hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)\n            hidden_states = hidden_states.contiguous()\n            residual = residual.contiguous()\n\n            if self.config.layer_types[i] != \"linear_attention\":\n                hidden_states = decode_layer.self_attn(hidden_states, self.cache, freqs_cis,\n 
                                                   wrapper=self.attn[cuda_graph_idx],\n                                                    bsz_tensors=num_tokens_tensors)\n            else:\n                hs = hidden_states.unsqueeze(0).contiguous().clone()\n                hs = decode_layer.linear_attn(hs, self.conv_states, self.recurrent_states,\n                                            bsz_tensors=num_tokens_tensors)\n                hidden_states = hs.squeeze(0).contiguous()\n\n            hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual)\n\n            hs2 = hidden_states.unsqueeze(0).contiguous().clone()\n            hidden_states = decode_layer.mlp(hs2, num_tokens_tensors, cuda_graph_idx).squeeze(0).contiguous()\n\n            if not torch.isfinite(hidden_states).all():\n                raise RuntimeError(f\"NaN after layer {i}\")\n            # print(f\"Layer {i} output: {hidden_states}\")\n        forward_batch_output = ForwardBatchOutput()\n        with torch.cuda.stream(current_stream):\n            local_logit = self.lm_head(self.model.norm(hidden_states, num_tokens_tensors, residual)[0], num_tokens_tensors)\n            forward_batch_output.logits.append(local_logit)\n\n        return forward_batch_output\n    \n\n               \n    def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors,\n        num_q_heads: int,\n        num_kv_heads: int,\n        head_dim: int,\n        page_size: int,\n        causal: bool,\n        q_data_type: torch.dtype,\n        kv_data_type: torch.dtype,\n        cuda_graph_idx: int = 0\n        ):\n        minibatch = batch.minibatch\n        self.attn[cuda_graph_idx].plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices, \n                          minibatch.kv_last_page_len, bsz_tensors, num_tokens_tensors, num_q_heads, num_kv_heads, head_dim, page_size, causal=causal, q_data_type=q_data_type, 
kv_data_type=kv_data_type)\n        "
  },
  {
    "path": "archive/ktransformers/models/custom_modeling_smallthinker.py",
    "content": "\"\"\"\nDate: 2024-11-06 10:05:11\nLastEditors: djw\nLastEditTime: 2024-11-13 07:50:51\n\"\"\"\n\nimport math\nfrom dataclasses import dataclass\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nimport math\nfrom typing import List, Optional, Tuple, Union\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput\nfrom ktransformers.models.custom_cache import KGQACache\nfrom ktransformers.models.modeling_smallthinker import SmallthinkerModel,  SmallthinkerPreTrainedModel\nfrom ktransformers.models.configuration_smallthinker import SmallthinkerConfig\nfrom ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn\n\ntorch.set_grad_enabled(False)\ntorch.set_default_dtype(torch.bfloat16)\nimport flashinfer\n\nclass KSmallThinkerForCausalLM(SmallthinkerPreTrainedModel):\n\n    cache: KGQACache\n    use_cuda_graph = False\n    def __init__(\n        self,\n        config: SmallthinkerConfig,\n        cache,\n    ):\n\n        super().__init__(config)\n        self.model = SmallthinkerModel(config)\n        self.config = config\n        self.cache = cache\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.attn = [None] * 100\n        \n    def init_wrapper(self, use_cuda_graph, device, max_batch_token, max_batch_size, max_pages, cuda_graph_idx = 0):\n        self.attn[cuda_graph_idx] = flashInferAttn(use_cuda_graph=use_cuda_graph, max_batch_token=max_batch_token, max_batch_size=max_batch_size, max_pages=max_pages, device=device)\n\n\n    def batch_embeddings(self, batch: ForwardBatchInput, device=\"cuda:0\"):\n        features = []\n        for i in range(batch.batch_size):\n            tokens = batch.minibatch.tokens.contiguous()\n            feature = (\n                
self.model.embed_tokens(tokens.to(torch.device('cpu')))\n                .to(torch.bfloat16)\n                .to(device=device)\n            )\n            features.append(feature)\n\n        return features\n\n\n    def forward(\n        self,\n        batch: ForwardBatchInput | None = None,\n        features: List[torch.Tensor] | None = None,\n        bsz_tensors: torch.Tensor | None = None,\n        num_tokens_tensors: torch.Tensor | None = None,\n        page_idx: torch.Tensor | None = None,\n        page_offset: torch.Tensor | None = None,\n        cuda_graph_idx: int | None = 0\n    ) -> ForwardBatchOutput:\n        current_stream = torch.cuda.current_stream()\n\n        forward_batch_output = ForwardBatchOutput()\n\n        \n        hidden_states = features[0]\n        self.attn[cuda_graph_idx].calc_batch_indices(hidden_states.shape[0])\n\n        freqs_cis = self.model.rotary_emb(hidden_states.unsqueeze(0), batch.minibatch.position_ids.unsqueeze(0))\n\n        with torch.cuda.stream(current_stream):\n            residual = torch.zeros_like(hidden_states)\n            for i, decode_layer in enumerate(self.model.layers):\n                router_input = hidden_states\n                hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)\n                hidden_states = decode_layer.self_attn(hidden_states, self.cache, \n                                                       freqs_cis if self.model.rope_layout[i] else None, \n                                                       wrapper=self.attn[cuda_graph_idx], bsz_tensors=num_tokens_tensors, \n                                                       position_ids=batch.minibatch.position_ids\n                                                       )\n\n                hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual)\n                if not self.config.moe_layer_layout[i]:\n                    
hidden_states = decode_layer.block_sparse_moe(hidden_states, num_tokens_tensors)\n                else:\n                    hidden_states = decode_layer.block_sparse_moe(router_input, hidden_states, num_tokens_tensors, cuda_graph_idx)\n                    # hidden_states = hidden_states.squeeze(0)\n\n        forward_batch_output = ForwardBatchOutput()\n        with torch.cuda.stream(current_stream):\n            local_logit = self.lm_head(self.model.norm(hidden_states, num_tokens_tensors, residual)[0], num_tokens_tensors)\n            forward_batch_output.logits.append(local_logit)\n\n        return forward_batch_output\n    \n\n               \n    def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors,\n        num_q_heads: int,\n        num_kv_heads: int,\n        head_dim: int,\n        page_size: int,\n        causal: bool,\n        q_data_type: torch.dtype,\n        kv_data_type: torch.dtype,\n        cuda_graph_idx: int = 0\n        ):\n        minibatch = batch.minibatch\n        self.attn[cuda_graph_idx].plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices, \n                          minibatch.kv_last_page_len, bsz_tensors, num_tokens_tensors, num_q_heads, num_kv_heads, head_dim, page_size, causal=causal, q_data_type=q_data_type, kv_data_type=kv_data_type)\n        "
  },
  {
    "path": "archive/ktransformers/models/modeling_deepseek.py",
    "content": "# coding=utf-8\n'''\nDescription  :  \nAuthor       : Boxin Zhang\nVersion      : 0.1.0\n'''\n# Adapted from\n# https://huggingface.co/deepseek-ai/DeepSeek-V2-Chat-0628/blob/main/modeling_deepseek.py\n# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.\n# Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n# \n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\" PyTorch DeepSeek model.\"\"\"\nimport math\nimport warnings\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n\nfrom transformers.activations import ACT2FN\nfrom transformers.cache_utils import Cache, DynamicCache, StaticCache\nfrom transformers.modeling_attn_mask_utils import (\n    AttentionMaskConverter,\n    _prepare_4d_attention_mask,\n    _prepare_4d_causal_attention_mask,\n)\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPast,\n    CausalLMOutputWithPast,\n    SequenceClassifierOutputWithPast,\n)\nfrom transformers.modeling_utils import PreTrainedModel\nfrom 
transformers.pytorch_utils import (\n    ALL_LAYERNORM_LAYERS,\n    is_torch_greater_or_equal_than_1_13,\n)\nfrom transformers.utils import (\n    add_start_docstrings,\n    add_start_docstrings_to_model_forward,\n    is_flash_attn_2_available,\n    is_flash_attn_greater_or_equal_2_10,\n    logging,\n    replace_return_docstrings,\n)\nfrom transformers.utils.import_utils import is_torch_fx_available\nfrom .configuration_deepseek import DeepseekV2Config\nimport torch.distributed as dist\nimport numpy as np\n\nif is_flash_attn_2_available():\n    from flash_attn import flash_attn_func, flash_attn_varlen_func, flash_attn_with_kvcache\n    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa\n\n\n# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.\n# It means that the function will not be traced through and simply appear as a node in the graph.\nif is_torch_fx_available():\n    if not is_torch_greater_or_equal_than_1_13:\n        import torch.fx\n\n    _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)\n\n\nlogger = logging.get_logger(__name__)\n\n_CONFIG_FOR_DOC = \"DeepseekV2Config\"\n\n\ndef _get_unpad_data(attention_mask):\n    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)\n    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()\n    max_seqlen_in_batch = seqlens_in_batch.max().item()\n    cu_seqlens = F.pad(\n        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)\n    )\n    return (\n        indices,\n        cu_seqlens,\n        max_seqlen_in_batch,\n    )\n\n\nclass DeepseekV2RMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        DeepseekV2RMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n        self.hidden_size = hidden_size\n\n    
def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return (self.weight * hidden_states).to(input_dtype)\n\n\nALL_LAYERNORM_LAYERS.append(DeepseekV2RMSNorm)\n\n# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->DeepseekV2\nclass DeepseekV2RotaryEmbedding(nn.Module):\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):\n        super().__init__()\n        self.scaling_factor = scaling_factor\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        # For BC we register cos and sin cached\n        self.max_seq_len_cached = max_position_embeddings\n\n    @torch.no_grad()\n    def forward(self, x, position_ids):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)\n        position_ids_expanded = position_ids[:, None, :].float()\n        # Force float32 since bfloat16 loses precision on long contexts\n        # See https://github.com/huggingface/transformers/pull/29285\n        device_type = x.device.type\n        device_type = device_type if isinstance(device_type, str) and device_type != \"mps\" else \"cpu\"\n        with torch.autocast(device_type=device_type, enabled=False):\n            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos()\n            sin = 
emb.sin()\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)   \n\n# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV2\nclass DeepseekV2LinearScalingRotaryEmbedding(DeepseekV2RotaryEmbedding):\n    \"\"\"DeepseekV2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev\"\"\"\n\n    def __init__(\n        self,\n        dim,\n        max_position_embeddings=2048,\n        base=10000,\n        device=None,\n        scaling_factor=1.0,\n    ):\n        raise NotImplementedError(\"LinearScalingRotaryEmbedding is not supported now.\")\n        self.scaling_factor = scaling_factor\n        super().__init__(dim, max_position_embeddings, base, device)\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n        t = torch.arange(\n            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype\n        )\n        t = t / self.scaling_factor\n\n        freqs = torch.outer(t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n\n# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV2\nclass DeepseekV2DynamicNTKScalingRotaryEmbedding(DeepseekV2RotaryEmbedding):\n    \"\"\"DeepseekV2RotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla\"\"\"\n\n    def __init__(\n        self,\n        dim,\n        max_position_embeddings=2048,\n        base=10000,\n        device=None,\n        scaling_factor=1.0,\n    ):\n        raise NotImplementedError(\"DynamicNTKScalingRotaryEmbedding is not supported now.\")\n        self.scaling_factor = scaling_factor\n        super().__init__(dim, max_position_embeddings, base, device)\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n\n        if seq_len > self.max_position_embeddings:\n            base = self.base * (\n                (self.scaling_factor * seq_len / self.max_position_embeddings)\n                - (self.scaling_factor - 1)\n            ) ** (self.dim / (self.dim - 2))\n            inv_freq = 1.0 / (\n                base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)\n            )\n            self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        t = torch.arange(\n            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype\n        )\n\n        freqs = torch.outer(t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n\n# Inverse dim formula to find dim based on number of rotations\ndef yarn_find_correction_dim(\n    num_rotations, dim, base=10000, max_position_embeddings=2048\n):\n    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (\n        2 * math.log(base)\n    )\n\n\n# Find dim range bounds based on rotations\ndef yarn_find_correction_range(\n    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048\n):\n    low = math.floor(\n        
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)\n    )\n    high = math.ceil(\n        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)\n    )\n    return max(low, 0), min(high, dim - 1)  # Clamp values just in case\n\n\ndef yarn_get_mscale(scale=1, mscale=1):\n    if scale <= 1:\n        return 1.0\n    return 0.1 * mscale * math.log(scale) + 1.0\n\n\ndef yarn_linear_ramp_mask(min, max, dim):\n    if min == max:\n        max += 0.001  # Prevent singularity\n\n    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)\n    ramp_func = torch.clamp(linear_func, 0, 1)\n    return ramp_func\n\nclass DeepseekV2YarnRotaryEmbedding(DeepseekV2RotaryEmbedding):\n    def __init__(\n        self,\n        dim,\n        max_position_embeddings=2048,\n        base=10000,\n        device=None,\n        scaling_factor=1.0,\n        original_max_position_embeddings=4096,\n        beta_fast=32,\n        beta_slow=1,\n        mscale=1,\n        mscale_all_dim=0,\n    ):\n        nn.Module.__init__(self)\n        self.scaling_factor = scaling_factor\n        self.original_max_position_embeddings = original_max_position_embeddings\n        self.beta_fast = beta_fast\n        self.beta_slow = beta_slow\n        self.mscale = mscale\n        self.mscale_all_dim = mscale_all_dim\n        self.scaling_factor = scaling_factor\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n\n        freq_extra = 1.0 / (\n            self.base\n            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)\n        )\n        freq_inter = 1.0 / (\n            self.scaling_factor\n            * self.base\n            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)\n        )\n\n        low, high = yarn_find_correction_range(\n            self.beta_fast,\n            self.beta_slow,\n            dim,\n            self.base,\n            
self.original_max_position_embeddings,\n        )\n        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(\n            device=device, dtype=torch.float32\n        )\n        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        self._mscale = float(\n            yarn_get_mscale(self.scaling_factor, self.mscale)\n            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)\n        )\n        # For BC we register cos and sin cached\n        self.max_seq_len_cached = max_position_embeddings\n\n    @torch.no_grad()\n    def forward(self, x, position_ids):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)\n        position_ids_expanded = position_ids[:, None, :].float()\n        # Force float32 since bfloat16 loses precision on long contexts\n        # See https://github.com/huggingface/transformers/pull/29285\n        device_type = x.device.type\n        device_type = device_type if isinstance(device_type, str) and device_type != \"mps\" else \"cpu\"\n        with torch.autocast(device_type=device_type, enabled=False):\n            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos()* self._mscale\n            sin = emb.sin()* self._mscale\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)  \n\n# Copied from transformers.models.llama.modeling_llama.rotate_half\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\n# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, 
unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n    Args:\n        q (`torch.Tensor`): The query tensor.\n        k (`torch.Tensor`): The key tensor.\n        cos (`torch.Tensor`): The cosine part of the rotary embedding.\n        sin (`torch.Tensor`): The sine part of the rotary embedding.\n        position_ids (`torch.Tensor`):\n            Deprecated and unused.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. 
Similarly, if q and k have\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos.unsqueeze(unsqueeze_dim)\n    sin = sin.unsqueeze(unsqueeze_dim)\n    b, h, s, d = q.shape\n    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)\n    b, h, s, d = k.shape\n    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\nclass DeepseekV2MLP(nn.Module):\n    def __init__(self, config, hidden_size=None, intermediate_size=None):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size\n        self.intermediate_size = (\n            config.intermediate_size if intermediate_size is None else intermediate_size\n        )\n\n        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        act = self.act_fn(self.gate_proj(x)) * self.up_proj(x)\n        down_proj = self.down_proj(act)\n        return down_proj\n\nclass MoEGate(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.top_k = config.num_experts_per_tok\n        self.n_routed_experts = config.n_routed_experts\n        self.routed_scaling_factor = config.routed_scaling_factor\n        self.scoring_func = config.scoring_func\n        self.alpha = config.aux_loss_alpha\n        self.seq_aux = config.seq_aux\n        self.topk_method = 
config.topk_method\n        self.n_group = config.n_group\n        self.topk_group = config.topk_group\n\n        # topk selection algorithm\n        self.norm_topk_prob = config.norm_topk_prob\n        self.gating_dim = config.hidden_size\n        self.weight = nn.Parameter(\n            torch.empty((self.n_routed_experts, self.gating_dim))\n        )\n        self.reset_parameters()\n\n    def reset_parameters(self) -> None:\n        import torch.nn.init as init\n\n        init.kaiming_uniform_(self.weight, a=math.sqrt(5))\n\n    def forward(self, hidden_states):\n        bsz, seq_len, h = hidden_states.shape\n        ### compute gating score\n        hidden_states = hidden_states.view(-1, h)\n        logits = F.linear(\n            hidden_states.type(torch.float32), self.weight.type(torch.float32), None\n        )\n        if self.scoring_func == \"softmax\":\n            scores = logits.softmax(dim=-1, dtype=torch.float32)\n        else:\n            raise NotImplementedError(\n                f\"insupportable scoring function for MoE gating: {self.scoring_func}\"\n            )\n\n        ### select top-k experts\n        if self.topk_method == \"greedy\":\n            topk_weight, topk_idx = torch.topk(\n                scores, k=self.top_k, dim=-1, sorted=False\n            )\n        elif self.topk_method == \"group_limited_greedy\":\n            group_scores = (\n                scores.view(bsz * seq_len, self.n_group, -1).max(dim=-1).values\n            )  # [n, n_group]\n            group_idx = torch.topk(\n                group_scores, k=self.topk_group, dim=-1, sorted=False\n            )[\n                1\n            ]  # [n, top_k_group]\n            group_mask = torch.zeros_like(group_scores)  # [n, n_group]\n            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]\n            score_mask = (\n                group_mask.unsqueeze(-1)\n                .expand(\n                    bsz * seq_len, self.n_group, self.n_routed_experts // 
self.n_group\n                )\n                .reshape(bsz * seq_len, -1)\n            )  # [n, e]\n            tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]\n            topk_weight, topk_idx = torch.topk(\n                tmp_scores, k=self.top_k, dim=-1, sorted=False\n            )\n\n        ### norm gate to sum 1\n        if self.top_k > 1 and self.norm_topk_prob:\n            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20\n            topk_weight = topk_weight / denominator\n        else:\n            topk_weight = topk_weight * self.routed_scaling_factor\n        ### expert-level computation auxiliary loss\n        if self.training and self.alpha > 0.0:\n            scores_for_aux = scores\n            aux_topk = self.top_k\n            # always compute aux loss based on the naive greedy topk method\n            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)\n            if self.seq_aux:\n                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)\n                ce = torch.zeros(\n                    bsz, self.n_routed_experts, device=hidden_states.device\n                )\n                ce.scatter_add_(\n                    1,\n                    topk_idx_for_aux_loss,\n                    torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device),\n                ).div_(seq_len * aux_topk / self.n_routed_experts)\n                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(\n                    dim=1\n                ).mean() * self.alpha\n            else:\n                mask_ce = F.one_hot(\n                    topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts\n                )\n                ce = mask_ce.float().mean(0)\n                Pi = scores_for_aux.mean(0)\n                fi = ce * self.n_routed_experts\n                aux_loss = (Pi * fi).sum() * self.alpha\n        else:\n            aux_loss = None\n        return topk_idx, topk_weight, 
aux_loss\n\n\nclass AddAuxiliaryLoss(torch.autograd.Function):\n    \"\"\"\n    The trick function of adding auxiliary (aux) loss,\n    which includes the gradient of the aux loss during backpropagation.\n    \"\"\"\n\n    @staticmethod\n    def forward(ctx, x, loss):\n        assert loss.numel() == 1\n        ctx.dtype = loss.dtype\n        ctx.required_aux_loss = loss.requires_grad\n        return x\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        grad_loss = None\n        if ctx.required_aux_loss:\n            grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)\n        return grad_output, grad_loss\n\nclass DeepseekV2MoE(nn.Module):\n    \"\"\"\n    A mixed expert module containing shared experts.\n    \"\"\"\n\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.num_experts_per_tok = config.num_experts_per_tok\n\n        if hasattr(config, \"ep_size\") and config.ep_size > 1:\n            assert config.ep_size == dist.get_world_size()\n            self.ep_size = config.ep_size\n            self.experts_per_rank = config.n_routed_experts // config.ep_size\n            self.ep_rank = dist.get_rank()\n            self.experts = nn.ModuleList(\n                [\n                    (\n                        DeepseekV2MLP(\n                            config, intermediate_size=config.moe_intermediate_size\n                        )\n                        if i >= self.ep_rank * self.experts_per_rank\n                        and i < (self.ep_rank + 1) * self.experts_per_rank\n                        else None\n                    )\n                    for i in range(config.n_routed_experts)\n                ]\n            )\n        else:\n            self.ep_size = 1\n            self.experts_per_rank = config.n_routed_experts\n            self.ep_rank = 0\n            self.experts = nn.ModuleList(\n                [\n                    DeepseekV2MLP(config, 
intermediate_size=config.moe_intermediate_size)\n                    for i in range(config.n_routed_experts)\n                ]\n            )\n        self.gate = MoEGate(config)\n        if config.n_shared_experts is not None:\n            intermediate_size = config.moe_intermediate_size * config.n_shared_experts\n            self.shared_experts = DeepseekV2MLP(\n                config=config, intermediate_size=intermediate_size\n            )\n\n    def forward(self, hidden_states):\n        identity = hidden_states\n        orig_shape = hidden_states.shape\n        topk_idx, topk_weight, aux_loss = self.gate(hidden_states)\n        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])\n        flat_topk_idx = topk_idx.view(-1)\n        if self.training:\n            hidden_states = hidden_states.repeat_interleave(\n                self.num_experts_per_tok, dim=0\n            )\n            y = torch.empty_like(hidden_states)\n            for i, expert in enumerate(self.experts):\n                y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])\n            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)\n            y = y.view(*orig_shape)\n            y = AddAuxiliaryLoss.apply(y, aux_loss)\n        else:\n            y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)\n        if self.config.n_shared_experts is not None:\n            y = y + self.shared_experts(identity)\n        return y\n\n    @torch.no_grad()\n    def moe_infer(self, x, topk_ids, topk_weight):\n        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))\n        cnts.scatter_(1, topk_ids, 1)\n        tokens_per_expert = cnts.sum(dim=0)\n        idxs = topk_ids.view(-1).argsort()\n        sorted_tokens = x[idxs // topk_ids.shape[1]]\n        sorted_tokens_shape = sorted_tokens.shape\n        if self.ep_size > 1:\n            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)\n   
         tokens_per_expert_group = tokens_per_expert.new_empty(\n                tokens_per_expert.shape[0]\n            )\n            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)\n            output_splits = (\n                tokens_per_expert_group.view(self.ep_size, -1)\n                .sum(1)\n                .cpu()\n                .numpy()\n                .tolist()\n            )\n            gathered_tokens = sorted_tokens.new_empty(\n                tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1]\n            )\n            input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()\n            dist.all_to_all(\n                list(gathered_tokens.split(output_splits)),\n                list(sorted_tokens.split(input_split_sizes)),\n            )\n            tokens_per_expert_post_gather = tokens_per_expert_group.view(\n                self.ep_size, self.experts_per_rank\n            ).sum(dim=0)\n            gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)\n            s = 0\n            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):\n                gatherd_idxs[s : s + k] = i % self.experts_per_rank\n                s += k\n            gatherd_idxs = gatherd_idxs.argsort()\n            sorted_tokens = gathered_tokens[gatherd_idxs]\n            tokens_per_expert = tokens_per_expert_post_gather\n        tokens_per_expert = tokens_per_expert.cpu().numpy()\n\n        outputs = []\n        start_idx = 0\n        for i, num_tokens in enumerate(tokens_per_expert):\n            end_idx = start_idx + num_tokens\n            if num_tokens == 0:\n                continue\n            expert = self.experts[i + self.ep_rank * self.experts_per_rank]\n            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n            expert_out = expert(tokens_for_this_expert)\n            outputs.append(expert_out)\n            start_idx = end_idx\n\n        outs = 
torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n        if self.ep_size > 1:\n            new_x = torch.empty_like(outs)\n            new_x[gatherd_idxs] = outs\n            gathered_tokens = new_x.new_empty(*sorted_tokens_shape)\n            dist.all_to_all(\n                list(gathered_tokens.split(input_split_sizes)),\n                list(new_x.split(output_splits)),\n            )\n            outs = gathered_tokens\n\n        new_x = torch.empty_like(outs)\n        new_x[idxs] = outs\n        final_out = (\n            new_x.view(*topk_ids.shape, -1)\n            .type(topk_weight.dtype)\n            .mul_(topk_weight.unsqueeze(dim=-1))\n            .sum(dim=1)\n            .type(new_x.dtype)\n        )\n        return final_out\n\n# Copied from transformers.models.llama.modeling_llama.repeat_kv\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(\n        batch, num_key_value_heads, n_rep, slen, head_dim\n    )\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2\nclass DeepseekV2Attention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None):\n        super().__init__()\n        self.config = config\n        self.layer_idx = layer_idx\n        if layer_idx is None:\n            logger.warning_once(\n                f\"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will \"\n                \"lead to errors during the forward call if caching is used. 
Please make sure to provide a `layer_idx` \"\n                \"when creating this class.\"\n            )\n\n        self.attention_dropout = config.attention_dropout\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n\n        self.max_position_embeddings = config.max_position_embeddings\n        self.rope_theta = config.rope_theta\n        self.q_lora_rank = config.q_lora_rank\n        self.qk_rope_head_dim = config.qk_rope_head_dim\n        self.kv_lora_rank = config.kv_lora_rank\n        self.v_head_dim = config.v_head_dim\n        self.qk_nope_head_dim = config.qk_nope_head_dim\n        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim\n\n        self.is_causal = True\n\n        if self.q_lora_rank is None:\n            self.q_proj = nn.Linear(\n                self.hidden_size, self.num_heads * self.q_head_dim, bias=False\n            )\n        else:\n            self.q_a_proj = nn.Linear(\n                self.hidden_size, config.q_lora_rank, bias=config.attention_bias\n            )\n            self.q_a_layernorm = DeepseekV2RMSNorm(config.q_lora_rank)\n            self.q_b_proj = nn.Linear(\n                config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False\n            )\n\n        self.kv_a_proj_with_mqa = nn.Linear(\n            self.hidden_size,\n            config.kv_lora_rank + config.qk_rope_head_dim,\n            bias=config.attention_bias,\n        )\n        self.kv_a_layernorm = DeepseekV2RMSNorm(config.kv_lora_rank)\n        self.kv_b_proj = nn.Linear(\n            config.kv_lora_rank,\n            self.num_heads\n            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),\n            bias=False,\n        )\n\n        self.o_proj = nn.Linear(\n            self.num_heads * self.v_head_dim,\n            self.hidden_size,\n            bias=config.attention_bias,\n        )\n        self._init_rope()\n\n        self.softmax_scale = self.q_head_dim 
** (-0.5)\n        if self.config.rope_scaling is not None:\n            mscale_all_dim = self.config.rope_scaling.get(\"mscale_all_dim\", 0)\n            scaling_factor = self.config.rope_scaling[\"factor\"]\n            if mscale_all_dim:\n                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)\n                self.softmax_scale = self.softmax_scale * mscale * mscale\n\n    def _init_rope(self):\n        if self.config.rope_scaling is None:\n            self.rotary_emb = DeepseekV2RotaryEmbedding(\n                self.qk_rope_head_dim,\n                max_position_embeddings=self.max_position_embeddings,\n                base=self.rope_theta,\n            )\n        else:\n            scaling_type = self.config.rope_scaling[\"type\"]\n            scaling_factor = self.config.rope_scaling[\"factor\"]\n            if scaling_type == \"linear\":\n                self.rotary_emb = DeepseekV2LinearScalingRotaryEmbedding(\n                    self.qk_rope_head_dim,\n                    max_position_embeddings=self.max_position_embeddings,\n                    scaling_factor=scaling_factor,\n                    base=self.rope_theta,\n                )\n            elif scaling_type == \"dynamic\":\n                self.rotary_emb = DeepseekV2DynamicNTKScalingRotaryEmbedding(\n                    self.qk_rope_head_dim,\n                    max_position_embeddings=self.max_position_embeddings,\n                    scaling_factor=scaling_factor,\n                    base=self.rope_theta,\n                )\n            elif scaling_type == \"yarn\":\n                kwargs = {\n                    key: self.config.rope_scaling[key]\n                    for key in [\n                        \"original_max_position_embeddings\",\n                        \"beta_fast\",\n                        \"beta_slow\",\n                        \"mscale\",\n                        \"mscale_all_dim\",\n                    ]\n                    if key in 
self.config.rope_scaling\n                }\n                self.rotary_emb = DeepseekV2YarnRotaryEmbedding(\n                    self.qk_rope_head_dim,\n                    max_position_embeddings=self.max_position_embeddings,\n                    scaling_factor=scaling_factor,\n                    base=self.rope_theta,\n                    **kwargs,\n                )\n            else:\n                raise ValueError(f\"Unknown RoPE scaling type {scaling_type}\")\n\n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return (\n            tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim)\n            .transpose(1, 2)\n            .contiguous()\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        if \"padding_mask\" in kwargs:\n            warnings.warn(\n                \"Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure to use `attention_mask` instead.\"\n            )\n        bsz, q_len, _ = hidden_states.size()\n\n        if self.q_lora_rank is None:\n            q = self.q_proj(hidden_states)\n        else:\n            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))\n        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)\n        q_nope, q_pe = torch.split(\n            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1\n        )\n\n        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)\n        compressed_kv, k_pe = torch.split(\n            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n        )\n        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)\n        kv = (\n            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))\n            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)\n            .transpose(1, 2)\n        )\n\n        k_nope, value_states = torch.split(\n            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1\n        )\n        kv_seq_len = value_states.shape[-2]\n        if past_key_value is not None:\n            if self.layer_idx is None:\n                raise ValueError(\n                    f\"The cache structure has changed since version v4.36. 
If you are using {self.__class__.__name__} \"\n                    \"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class \"\n                    \"with a layer index.\"\n                )\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\n        cos, sin = self.rotary_emb(q_pe, position_ids)\n        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)\n\n        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)\n        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope\n        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe\n\n        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)\n        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope\n        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe\n        if past_key_value is not None:\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n            key_states, value_states = past_key_value.update(\n                key_states, value_states, self.layer_idx, cache_kwargs\n            )\n\n        attn_weights = (\n            torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale\n        )\n\n        if attention_mask is not None:\n            attn_weights = attn_weights + attention_mask\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(\n            attn_weights, dim=-1, dtype=torch.float32\n        ).to(query_states.dtype)\n        attn_weights = nn.functional.dropout(\n            attn_weights, p=self.attention_dropout, training=self.training\n        )\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but 
is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n\n        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)\n\n        attn_output = self.o_proj(attn_output)\n\n        if not output_attentions:\n            attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV2\nclass DeepseekV2FlashAttention2(DeepseekV2Attention):\n    \"\"\"\n    DeepseekV2 flash attention module. This module inherits from `DeepseekV2Attention` as the weights of the module stay\n    untouched. The only required change is in the forward pass, where it needs to correctly call the public API of\n    flash attention and deal with padding tokens in case the input contains any of them.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n        # TODO: Should be removed once Flash Attention for ROCm is bumped to 2.1.\n        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. 
Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.\n        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).\n        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.LongTensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        # DeepseekV2FlashAttention2 attention does not support output_attentions\n        if \"padding_mask\" in kwargs:\n            warnings.warn(\n                \"Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure to use `attention_mask` instead.\"\n            )\n\n            # overwrite attention_mask with padding_mask\n            attention_mask = kwargs.pop(\"padding_mask\")\n\n        output_attentions = False\n\n        bsz, q_len, _ = hidden_states.size()\n\n        # Mirror DeepseekV2Attention.forward: q_a_proj/q_b_proj only exist when q_lora_rank is set\n        if self.q_lora_rank is None:\n            q = self.q_proj(hidden_states)\n        else:\n            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))\n        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)\n        q_nope, q_pe = torch.split(\n            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1\n        )\n\n        # Flash attention requires the input to have the shape\n        # batch_size x seq_length x num_heads x head_dim\n        # therefore we just need to keep the original shape\n        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)\n        compressed_kv, k_pe = torch.split(\n            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n        )\n        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)\n        kv = (\n            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))\n            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)\n            .transpose(1, 2)\n        )\n\n        k_nope, value_states = torch.split(\n            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1\n        )\n\n        kv_seq_len = value_states.shape[-2]\n        if past_key_value is not None:\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\n\n        cos, sin = self.rotary_emb(q_pe, position_ids)\n        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)\n\n        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)\n        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope\n        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe\n\n        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)\n        
key_states[:, :, :, : self.qk_nope_head_dim] = k_nope\n        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe\n\n        if self.q_head_dim != self.v_head_dim:\n            value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])\n\n        if past_key_value is not None:\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n            key_states, value_states = past_key_value.update(\n                key_states, value_states, self.layer_idx, cache_kwargs\n            )\n\n        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache\n        # to be able to avoid many of these transpose/reshape/view.\n        query_states = query_states.transpose(1, 2)\n        key_states = key_states.transpose(1, 2)\n        value_states = value_states.transpose(1, 2)\n\n        dropout_rate = self.attention_dropout if self.training else 0.0\n\n        # In PEFT, usually we cast the layer norms in float32 for training stability reasons\n        # therefore the input hidden states gets silently casted in float32. Hence, we need\n        # cast them back in the correct dtype just to be sure everything works as expected.\n        # This might slowdown training & inference so it is recommended to not cast the LayerNorms\n        # in fp32. 
(DeepseekV2RMSNorm handles it correctly)\n\n        input_dtype = query_states.dtype\n        if input_dtype == torch.float32:\n            # Handle the case where the model is quantized\n            if hasattr(self.config, \"_pre_quantization_dtype\"):\n                target_dtype = self.config._pre_quantization_dtype\n            elif torch.is_autocast_enabled():\n                target_dtype = torch.get_autocast_gpu_dtype()\n            else:\n                # kv_a_proj_with_mqa always exists, unlike q_a_proj (absent when q_lora_rank is None)\n                target_dtype = self.kv_a_proj_with_mqa.weight.dtype\n\n            logger.warning_once(\n                f\"The input hidden states seem to have been silently cast to float32; this might be related to\"\n                f\" the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to\"\n                f\" {target_dtype}.\"\n            )\n\n            query_states = query_states.to(target_dtype)\n            key_states = key_states.to(target_dtype)\n            value_states = value_states.to(target_dtype)\n\n        attn_output = self._flash_attention_forward(\n            query_states,\n            key_states,\n            value_states,\n            attention_mask,\n            q_len,\n            position_ids=position_ids,\n            dropout=dropout_rate,\n            softmax_scale=self.softmax_scale,\n        )\n        if self.q_head_dim != self.v_head_dim:\n            attn_output = attn_output[:, :, :, : self.v_head_dim]\n\n        attn_output = attn_output.reshape(\n            bsz, q_len, self.num_heads * self.v_head_dim\n        ).contiguous()\n        attn_output = self.o_proj(attn_output)\n\n        if not output_attentions:\n            attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n    def _flash_attention_forward(\n        self,\n        query_states,\n        key_states,\n        value_states,\n        attention_mask,\n        query_length,\n        position_ids,\n        dropout=0.0,\n        softmax_scale=None,\n    ):\n     
   \"\"\"\n        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token\n        first unpad the input, then computes the attention scores and pad the final attention scores.\n        # Args:\n            query_states (`torch.Tensor`):\n                Input query states to be passed to Flash Attention API\n            key_states (`torch.Tensor`):\n                Input key states to be passed to Flash Attention API\n            value_states (`torch.Tensor`):\n                Input value states to be passed to Flash Attention API\n            attention_mask (`torch.Tensor`):\n                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the\n                position of padding tokens and 1 for the position of non-padding tokens.\n            dropout (`int`, *optional*):\n                Attention dropout\n            softmax_scale (`float`, *optional*):\n                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)\n        \"\"\"\n        if not self._flash_attn_uses_top_left_mask:\n            causal = self.is_causal\n        else:\n            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. 
For details, please see the comment in DeepseekV2FlashAttention2 __init__.\n            causal = self.is_causal and query_length != 1\n\n        # Contains at least one padding token in the sequence\n        if attention_mask is not None:\n            batch_size = query_states.shape[0]\n            (\n                query_states,\n                key_states,\n                value_states,\n                indices_q,\n                cu_seq_lens,\n                max_seq_lens,\n            ) = self._upad_input(\n                query_states, key_states, value_states, attention_mask, query_length\n            )\n\n            cu_seqlens_q, cu_seqlens_k = cu_seq_lens\n            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens\n\n            attn_output_unpad = flash_attn_varlen_func(\n                query_states,\n                key_states,\n                value_states,\n                cu_seqlens_q=cu_seqlens_q,\n                cu_seqlens_k=cu_seqlens_k,\n                max_seqlen_q=max_seqlen_in_batch_q,\n                max_seqlen_k=max_seqlen_in_batch_k,\n                dropout_p=dropout,\n                softmax_scale=softmax_scale,\n                causal=causal,\n            )\n\n            attn_output = pad_input(\n                attn_output_unpad, indices_q, batch_size, query_length\n            )\n        else:\n            if query_length == 1:\n                position_ids = position_ids.to(dtype=torch.int32).squeeze(1)\n                attn_output = flash_attn_with_kvcache(\n                    query_states,\n                    key_states,\n                    value_states,\n                    cache_seqlens=position_ids,\n                    softmax_scale=softmax_scale,\n                    causal=causal,\n                )   \n            else:\n                attn_output = flash_attn_func(\n                    query_states,\n                    key_states,\n                    value_states,\n                    dropout,\n        
            softmax_scale=softmax_scale,\n                    causal=causal,\n                )\n\n        return attn_output\n\n    def _upad_input(\n        self, query_layer, key_layer, value_layer, attention_mask, query_length\n    ):\n        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)\n        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape\n\n        key_layer = index_first_axis(\n            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),\n            indices_k,\n        )\n        value_layer = index_first_axis(\n            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),\n            indices_k,\n        )\n        if query_length == kv_seq_len:\n            query_layer = index_first_axis(\n                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),\n                indices_k,\n            )\n            cu_seqlens_q = cu_seqlens_k\n            max_seqlen_in_batch_q = max_seqlen_in_batch_k\n            indices_q = indices_k\n        elif query_length == 1:\n            max_seqlen_in_batch_q = 1\n            cu_seqlens_q = torch.arange(\n                batch_size + 1, dtype=torch.int32, device=query_layer.device\n            )  # There is a memcpy here, that is very bad.\n            indices_q = cu_seqlens_q[:-1]\n            query_layer = query_layer.squeeze(1)\n        else:\n            # The -q_len: slice assumes left padding.\n            attention_mask = attention_mask[:, -query_length:]\n            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(\n                query_layer, attention_mask\n            )\n\n        return (\n            query_layer,\n            key_layer,\n            value_layer,\n            indices_q,\n            (cu_seqlens_q, cu_seqlens_k),\n            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),\n        )\n\n\nATTENTION_CLASSES = {\n    \"eager\": 
DeepseekV2Attention,\n    \"flash_attention_2\": DeepseekV2FlashAttention2,\n}\n\nclass DeepseekV2DecoderLayer(nn.Module):\n    def __init__(self, config: DeepseekV2Config, layer_idx: int):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n\n        self.self_attn = ATTENTION_CLASSES[config._attn_implementation](\n            config=config, layer_idx=layer_idx\n        )\n\n        self.mlp = (\n            DeepseekV2MoE(config)\n            if (\n                config.n_routed_experts is not None\n                and layer_idx >= config.first_k_dense_replace\n                and layer_idx % config.moe_layer_freq == 0\n            )\n            else DeepseekV2MLP(config)\n        )\n        self.input_layernorm = DeepseekV2RMSNorm(\n            config.hidden_size, eps=config.rms_norm_eps\n        )\n        self.post_attention_layernorm = DeepseekV2RMSNorm(\n            config.hidden_size, eps=config.rms_norm_eps\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        output_attentions: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs,\n    ) -> Tuple[\n        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]\n    ]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*):\n                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,\n                query_sequence_length, key_sequence_length)` if default attention is used.\n            output_attentions (`bool`, *optional*):\n                Whether or not to 
return the attention tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n        \"\"\"\n        if \"padding_mask\" in kwargs:\n            warnings.warn(\n                \"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead.\"\n            )\n        residual = hidden_states\n        hidden_states = self.input_layernorm(hidden_states)\n        # Self Attention\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_value=past_key_value,\n            output_attentions=output_attentions,\n            use_cache=use_cache,\n            cache_position=cache_position,\n            **kwargs,\n        )\n\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (self_attn_weights,)\n\n        if use_cache:\n            outputs += (present_key_value,)\n\n        return outputs\n\n\nDeepseekV2_START_DOCSTRING = r\"\"\"\n    This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the\n    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads\n    etc.)\n\n    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage\n    and behavior.\n\n    Parameters:\n        config ([`DeepseekV2Config`]):\n            Model configuration class with all the parameters of the model. Initializing with a config file does not\n            load the weights associated with the model, only the configuration. Check out the\n            [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.\",\n    DeepseekV2_START_DOCSTRING,\n)\nclass DeepseekV2PreTrainedModel(PreTrainedModel):\n    config_class = DeepseekV2Config\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"DeepseekV2DecoderLayer\"]\n    _skip_keys_device_placement = \"past_key_values\"\n    _supports_flash_attn_2 = True\n    _supports_cache_class = True\n    _supports_static_cache = True\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n\n\nDeepseekV2_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n\n            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see\n            `past_key_values`).\n\n            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]\n            and modify it to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more\n            information on the default strategy.\n        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,\n            config.n_positions - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):\n            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention\n            blocks) that can be used to speed up sequential decoding. 
This typically consists of the `past_key_values`\n            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.\n\n            Two formats are allowed:\n            - a [`~cache_utils.Cache`] instance;\n            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of\n            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. This is also known as the legacy\n            cache format.\n\n            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the\n            legacy cache format will be returned.\n\n            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't\n            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`\n            of shape `(batch_size, sequence_length)`.\n        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This\n            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n            model's internal embedding lookup matrix.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n            `past_key_values`).\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attention tensors of all attention layers. See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors for\n            more detail.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.\",\n    DeepseekV2_START_DOCSTRING,\n)\nclass DeepseekV2Model(DeepseekV2PreTrainedModel):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV2DecoderLayer`]\n\n    Args:\n        config: DeepseekV2Config\n    \"\"\"\n\n    def __init__(self, config: DeepseekV2Config):\n        super().__init__(config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(\n            config.vocab_size, config.hidden_size, self.padding_idx\n        )\n        self.layers = nn.ModuleList(\n            [\n                DeepseekV2DecoderLayer(config, layer_idx)\n                for layer_idx in range(config.num_hidden_layers)\n            ]\n        )\n        self._use_flash_attention_2 = config._attn_implementation == \"flash_attention_2\"\n        self.norm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n        self.gradient_checkpointing = False\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: 
Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        # retrieve input_ids and inputs_embeds\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\n                \"You cannot specify both input_ids and inputs_embeds at the same time\"\n            )\n        elif input_ids is not None:\n            batch_size, seq_length = input_ids.shape[:2]\n        elif inputs_embeds is not None:\n            batch_size, seq_length = inputs_embeds.shape[:2]\n        else:\n            raise ValueError(\"You have to specify either input_ids or inputs_embeds\")\n\n        if self.gradient_checkpointing and self.training:\n            if use_cache:\n                logger.warning_once(\n                    \"`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`.\"\n                )\n                use_cache = False\n\n        past_key_values_length = 0\n        if use_cache:\n            use_legacy_cache = not isinstance(past_key_values, Cache)\n            if use_legacy_cache:\n                past_key_values = DynamicCache.from_legacy_cache(past_key_values)\n            past_key_values_length = past_key_values.get_usable_length(seq_length)\n\n        # Embed first: the default `cache_position` below reads the shape and device of `inputs_embeds`.\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        if cache_position is None:\n            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n            cache_position = torch.arange(\n                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device\n            )\n\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0)\n\n        causal_mask = self._update_causal_mask(\n            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions\n        )\n\n        # embed positions\n        hidden_states = inputs_embeds\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        next_decoder_cache = None\n\n        for decoder_layer in self.layers:\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            if self.gradient_checkpointing and self.training:\n                layer_outputs = self._gradient_checkpointing_func(\n                    decoder_layer.__call__,\n                    hidden_states,\n                    causal_mask,\n                    position_ids,\n                    past_key_values,\n                    output_attentions,\n                    use_cache,\n                    cache_position,\n                )\n            else:\n                layer_outputs = decoder_layer(\n          
          hidden_states,\n                    attention_mask=causal_mask,\n                    position_ids=position_ids,\n                    past_key_value=past_key_values,\n                    output_attentions=output_attentions,\n                    use_cache=use_cache,\n                    cache_position=cache_position,\n                )\n\n            hidden_states = layer_outputs[0]\n\n            if use_cache:\n                next_decoder_cache = layer_outputs[2 if output_attentions else 1]\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n        hidden_states = self.norm(hidden_states)\n\n        # add hidden states from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        next_cache = None\n        if use_cache:\n            next_cache = (\n                next_decoder_cache.to_legacy_cache()\n                if use_legacy_cache\n                else next_decoder_cache\n            )\n        if not return_dict:\n            return tuple(\n                v\n                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]\n                if v is not None\n            )\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=next_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n        )\n    \n    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask\n    def _update_causal_mask(\n        self,\n        attention_mask: torch.Tensor,\n        input_tensor: torch.Tensor,\n        cache_position: torch.Tensor,\n        past_key_values: Cache,\n        output_attentions: bool,\n    ):\n        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static\n        # KV cache is used. 
This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.\n        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using\n        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114\n\n        if self.config._attn_implementation == \"flash_attention_2\":\n            if attention_mask is not None and 0.0 in attention_mask:\n                return attention_mask\n            return None\n\n        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in\n        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail\n        # to infer the attention mask.\n        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n        using_static_cache = isinstance(past_key_values, StaticCache)\n\n        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward\n        if self.config._attn_implementation == \"sdpa\" and not using_static_cache and not output_attentions:\n            if AttentionMaskConverter._ignore_causal_mask_sdpa(\n                attention_mask,\n                inputs_embeds=input_tensor,\n                past_key_values_length=past_seen_tokens,\n                is_training=self.training,\n            ):\n                return None\n\n        dtype, device = input_tensor.dtype, input_tensor.device\n        min_dtype = torch.finfo(dtype).min\n        sequence_length = input_tensor.shape[1]\n        if using_static_cache:\n            target_length = past_key_values.get_max_length()\n        else:\n            target_length = (\n                attention_mask.shape[-1]\n                if isinstance(attention_mask, torch.Tensor)\n                else past_seen_tokens + 
sequence_length + 1\n            )\n\n        if attention_mask is not None and attention_mask.dim() == 4:\n            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing\n            if attention_mask.max() != 0:\n                raise ValueError(\"Custom 4D attention mask should be passed in inverted form with max==0\")\n            causal_mask = attention_mask\n        else:\n            causal_mask = torch.full(\n                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device\n            )\n            if sequence_length != 1:\n                causal_mask = torch.triu(causal_mask, diagonal=1)\n            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)\n            causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)\n            if attention_mask is not None:\n                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit\n                mask_length = attention_mask.shape[-1]\n                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]\n                padding_mask = padding_mask == 0\n                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(\n                    padding_mask, min_dtype\n                )\n        if (\n            self.config._attn_implementation == \"sdpa\"\n            and attention_mask is not None\n            and attention_mask.device.type == \"cuda\"\n            and not output_attentions\n        ):\n            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when\n            # using left padding. 
This is required by F.scaled_dot_product_attention memory-efficient attention path.\n            # Details: https://github.com/pytorch/pytorch/issues/110213\n            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)\n\n        return causal_mask\n\n\nclass DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):\n    _tied_weights_keys = [\"lm_head.weight\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = DeepseekV2Model(config)\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def set_decoder(self, decoder):\n        self.model = decoder\n\n    def get_decoder(self):\n        return self.model\n\n    @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)\n    @replace_return_docstrings(\n        output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC\n    )\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> 
Union[Tuple, CausalLMOutputWithPast]:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, DeepseekV2ForCausalLM\n\n        >>> model = DeepseekV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)\n        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)\n\n        >>> prompt = \"Hey, are you conscious? Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you conscious? 
Can you talk to me?\\nI'm not conscious, but I can talk to you.\"\n        ```\"\"\"\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            cache_position=cache_position,\n        )\n\n        hidden_states = outputs[0]\n        if labels is None:\n            # Inference: project only the last position to save memory\n            logits = self.lm_head(hidden_states[:, -1:, :]).float()\n        else:\n            # The shifted loss below needs logits for every position\n            logits = self.lm_head(hidden_states).float()\n\n        loss = None\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            shift_logits = shift_logits.view(-1, self.config.vocab_size)\n            shift_labels = shift_labels.view(-1)\n            # Enable model parallelism\n            shift_labels = shift_labels.to(shift_logits.device)\n            loss = loss_fct(shift_logits, shift_labels)\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            return (loss,) + output if loss is not None else output\n\n        return CausalLMOutputWithPast(\n           
 loss=loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n    def prepare_inputs_for_generation(\n        self,\n        input_ids,\n        past_key_values=None,\n        attention_mask=None,\n        inputs_embeds=None,\n        cache_position=None,\n        use_cache=True,\n        **kwargs,\n    ):\n        past_length = 0\n        # Omit tokens covered by past_key_values\n        if past_key_values is not None:\n            if isinstance(past_key_values, Cache):\n                past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()\n                max_cache_length = (\n                    torch.tensor(past_key_values.get_max_length(), device=input_ids.device)\n                    if past_key_values.get_max_length() is not None\n                    else None\n                )\n                cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)\n            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects\n            else:\n                cache_length = past_length = past_key_values[0][0].shape[2]\n                max_cache_length = None\n\n            # Keep only the unprocessed tokens:\n            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where\n            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as\n            # input)\n            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:\n                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]\n            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. 
We can discard\n            # input_ids based on the past_length.\n            elif past_length < input_ids.shape[1]:\n                input_ids = input_ids[:, past_length:]\n            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.\n\n            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.\n            if (\n                max_cache_length is not None\n                and attention_mask is not None\n                and cache_length + input_ids.shape[1] > max_cache_length\n            ):\n                attention_mask = attention_mask[:, -max_cache_length:]\n\n        position_ids = kwargs.get(\"position_ids\", None)\n        if attention_mask is not None and position_ids is None:\n            # create position_ids on the fly for batch generation\n            position_ids = attention_mask.long().cumsum(-1) - 1\n            position_ids.masked_fill_(attention_mask == 0, 1)\n            if past_key_values:\n                position_ids = position_ids[:, -input_ids.shape[1] :]\n\n        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step\n        if inputs_embeds is not None and past_length == 0:\n            model_inputs = {\"inputs_embeds\": inputs_embeds}\n        else:\n            model_inputs = {\"input_ids\": input_ids}\n\n        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]\n        if cache_position is None:\n            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)\n        elif use_cache:\n            cache_position = cache_position[-input_length:]\n\n        model_inputs.update(\n            {\n                \"position_ids\": position_ids,\n                \"past_key_values\": past_key_values,\n                \"use_cache\": use_cache,\n                \"attention_mask\": attention_mask,\n                
\"cache_position\": cache_position,\n            }\n        )\n        return model_inputs\n\n    @staticmethod\n    def _reorder_cache(past_key_values, beam_idx):\n        reordered_past = ()\n        for layer_past in past_key_values:\n            reordered_past += (\n                tuple(\n                    past_state.index_select(0, beam_idx.to(past_state.device))\n                    for past_state in layer_past\n                ),\n            )\n        return reordered_past\n\n\n@add_start_docstrings(\n    \"\"\"\n    The DeepseekV2 Model transformer with a sequence classification head on top (linear layer).\n\n    [`DeepseekV2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models\n    (e.g. GPT-2) do.\n\n    Since it does classification on the last token, it requires to know the position of the last token. If a\n    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If\n    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. 
Since it cannot guess the\n    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in\n    each row of the batch).\n    \"\"\",\n    DeepseekV2_START_DOCSTRING,\n)\nclass DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.model = DeepseekV2Model(config)\n        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n            config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        transformer_outputs = self.model(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = transformer_outputs[0]\n        logits = self.score(hidden_states)\n\n        if input_ids is not None:\n            batch_size = input_ids.shape[0]\n        else:\n            batch_size = inputs_embeds.shape[0]\n\n        if self.config.pad_token_id is None and batch_size != 1:\n            raise ValueError(\n                \"Cannot handle batch sizes > 1 if no padding token is defined.\"\n            )\n        if self.config.pad_token_id is None:\n            sequence_lengths = -1\n        else:\n            if input_ids is not None:\n                sequence_lengths = (\n                    torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1\n                ).to(logits.device)\n            else:\n                sequence_lengths = -1\n\n        pooled_logits = logits[\n            torch.arange(batch_size, device=logits.device), sequence_lengths\n        ]\n\n        loss = None\n        if labels is not None:\n            labels = labels.to(logits.device)\n            if self.config.problem_type is None:\n                if self.num_labels == 1:\n                    self.config.problem_type = \"regression\"\n                elif self.num_labels > 1 and (\n                    labels.dtype == 
torch.long or labels.dtype == torch.int\n                ):\n                    self.config.problem_type = \"single_label_classification\"\n                else:\n                    self.config.problem_type = \"multi_label_classification\"\n\n            if self.config.problem_type == \"regression\":\n                loss_fct = MSELoss()\n                if self.num_labels == 1:\n                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())\n                else:\n                    loss = loss_fct(pooled_logits, labels)\n            elif self.config.problem_type == \"single_label_classification\":\n                loss_fct = CrossEntropyLoss()\n                loss = loss_fct(\n                    pooled_logits.view(-1, self.num_labels), labels.view(-1)\n                )\n            elif self.config.problem_type == \"multi_label_classification\":\n                loss_fct = BCEWithLogitsLoss()\n                loss = loss_fct(pooled_logits, labels)\n        if not return_dict:\n            output = (pooled_logits,) + transformer_outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return SequenceClassifierOutputWithPast(\n            loss=loss,\n            logits=pooled_logits,\n            past_key_values=transformer_outputs.past_key_values,\n            hidden_states=transformer_outputs.hidden_states,\n            attentions=transformer_outputs.attentions,\n        )\n"
  },
  {
    "path": "archive/ktransformers/models/modeling_deepseek_v3.py",
    "content": "# coding=utf-8\n# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\" PyTorch DeepSeek model.\"\"\"\nimport math\nimport warnings\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n\nfrom transformers.activations import ACT2FN\nfrom transformers.cache_utils import Cache, DynamicCache, StaticCache\nfrom transformers.generation import GenerationMixin\nfrom transformers.modeling_attn_mask_utils import (\n    AttentionMaskConverter,\n    _prepare_4d_attention_mask,\n    _prepare_4d_causal_attention_mask,\n)\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPast,\n    CausalLMOutputWithPast,\n    SequenceClassifierOutputWithPast,\n)\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.pytorch_utils import (\n    ALL_LAYERNORM_LAYERS,\n    is_torch_greater_or_equal_than_1_13,\n)\nfrom transformers.utils import (\n    add_start_docstrings,\n    
add_start_docstrings_to_model_forward,\n    is_flash_attn_2_available,\n    is_flash_attn_greater_or_equal_2_10,\n    logging,\n    replace_return_docstrings,\n)\nfrom transformers.utils.import_utils import is_torch_fx_available\nfrom .configuration_deepseek_v3 import DeepseekV3Config\nimport torch.distributed as dist\nimport numpy as np\n\nif is_flash_attn_2_available():\n    from flash_attn import flash_attn_func, flash_attn_varlen_func\n    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa\n\n\n# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.\n# It means that the function will not be traced through and simply appear as a node in the graph.\nif is_torch_fx_available():\n    if not is_torch_greater_or_equal_than_1_13:\n        import torch.fx\n\n    _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)\n\ntry:\n    import torch_npu\n    use_torch_npu = torch_npu.npu.is_available()\nexcept:\n    use_torch_npu = False\n\n\nlogger = logging.get_logger(__name__)\n\n_CONFIG_FOR_DOC = \"DeepseekV3Config\"\n\n\ndef _get_unpad_data(attention_mask):\n    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)\n    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()\n    max_seqlen_in_batch = seqlens_in_batch.max().item()\n    cu_seqlens = F.pad(\n        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)\n    )\n    return (\n        indices,\n        cu_seqlens,\n        max_seqlen_in_batch,\n    )\n\n\nclass DeepseekV3RMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        DeepseekV3RMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n        self.hidden_size = hidden_size\n\n    def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n     
   hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n\nALL_LAYERNORM_LAYERS.append(DeepseekV3RMSNorm)\n\n\nclass DeepseekV3RotaryEmbedding(nn.Module):\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):\n        super().__init__()\n\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        inv_freq = 1.0 / (\n            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)\n        )\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        # Build here to make `torch.jit.trace` work.\n        self._set_cos_sin_cache(\n            seq_len=max_position_embeddings,\n            device=self.inv_freq.device,\n            dtype=torch.get_default_dtype(),\n        )\n        self.max_seq_len_cached = None\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n        t = torch.arange(\n            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype\n        )\n\n        freqs = torch.outer(t, self.inv_freq.to(t.device))\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n    def forward(self, x, seq_len=None):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:\n            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)\n\n        return (\n            
self.cos_cached[:seq_len].to(dtype=x.dtype),\n            self.sin_cached[:seq_len].to(dtype=x.dtype),\n        )\n\n\n# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV3\nclass DeepseekV3LinearScalingRotaryEmbedding(DeepseekV3RotaryEmbedding):\n    \"\"\"DeepseekV3RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev\"\"\"\n\n    def __init__(\n        self,\n        dim,\n        max_position_embeddings=2048,\n        base=10000,\n        device=None,\n        scaling_factor=1.0,\n    ):\n        self.scaling_factor = scaling_factor\n        super().__init__(dim, max_position_embeddings, base, device)\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n        t = torch.arange(\n            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype\n        )\n        t = t / self.scaling_factor\n\n        freqs = torch.outer(t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n\n# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV3\nclass DeepseekV3DynamicNTKScalingRotaryEmbedding(DeepseekV3RotaryEmbedding):\n    \"\"\"DeepseekV3RotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla\"\"\"\n\n    def __init__(\n        self,\n        dim,\n        max_position_embeddings=2048,\n        base=10000,\n        device=None,\n        scaling_factor=1.0,\n    ):\n        self.scaling_factor = scaling_factor\n        super().__init__(dim, max_position_embeddings, base, device)\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n\n        if seq_len > self.max_position_embeddings:\n            base = self.base * (\n                (self.scaling_factor * seq_len / self.max_position_embeddings)\n                - (self.scaling_factor - 1)\n            ) ** (self.dim / (self.dim - 2))\n            inv_freq = 1.0 / (\n                base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)\n            )\n            self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        t = torch.arange(\n            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype\n        )\n\n        freqs = torch.outer(t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n\n# Inverse dim formula to find dim based on number of rotations\ndef yarn_find_correction_dim(\n    num_rotations, dim, base=10000, max_position_embeddings=2048\n):\n    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (\n        2 * math.log(base)\n    )\n\n\n# Find dim range bounds based on rotations\ndef yarn_find_correction_range(\n    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048\n):\n    low = math.floor(\n        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)\n    )\n    high = math.ceil(\n     
   yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)\n    )\n    return max(low, 0), min(high, dim - 1)  # Clamp values just in case\n\n\ndef yarn_get_mscale(scale=1, mscale=1):\n    if scale <= 1:\n        return 1.0\n    return 0.1 * mscale * math.log(scale) + 1.0\n\n\ndef yarn_linear_ramp_mask(min, max, dim):\n    if min == max:\n        max += 0.001  # Prevent singularity\n\n    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)\n    ramp_func = torch.clamp(linear_func, 0, 1)\n    return ramp_func\n\n\nclass DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding):\n\n    def __init__(\n        self,\n        dim,\n        max_position_embeddings=2048,\n        base=10000,\n        device=None,\n        scaling_factor=1.0,\n        original_max_position_embeddings=4096,\n        beta_fast=32,\n        beta_slow=1,\n        mscale=1,\n        mscale_all_dim=0,\n    ):\n        self.scaling_factor = scaling_factor\n        self.original_max_position_embeddings = original_max_position_embeddings\n        self.beta_fast = beta_fast\n        self.beta_slow = beta_slow\n        self.mscale = mscale\n        self.mscale_all_dim = mscale_all_dim\n        super().__init__(dim, max_position_embeddings, base, device)\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n        dim = self.dim\n\n        freq_extra = 1.0 / (\n            self.base\n            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)\n        )\n        freq_inter = 1.0 / (\n            self.scaling_factor\n            * self.base\n            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)\n        )\n\n        low, high = yarn_find_correction_range(\n            self.beta_fast,\n            self.beta_slow,\n            dim,\n            self.base,\n            self.original_max_position_embeddings,\n        )\n        inv_freq_mask = 1.0 - 
yarn_linear_ramp_mask(low, high, dim // 2).to(\n            device=device, dtype=torch.float32\n        )\n        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        t = torch.arange(seq_len, device=device, dtype=torch.float32)\n\n        freqs = torch.outer(t, inv_freq)\n\n        _mscale = float(\n            yarn_get_mscale(self.scaling_factor, self.mscale)\n            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)\n        )\n\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\n            \"cos_cached\", (emb.cos() * _mscale).to(dtype), persistent=False\n        )\n        self.register_buffer(\n            \"sin_cached\", (emb.sin() * _mscale).to(dtype), persistent=False\n        )\n\n\n# Copied from transformers.models.llama.modeling_llama.rotate_half\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\n# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n    Args:\n        q (`torch.Tensor`): The query tensor.\n        k (`torch.Tensor`): The key tensor.\n        cos (`torch.Tensor`): The cosine part of the rotary embedding.\n        sin (`torch.Tensor`): The sine part of the rotary embedding.\n        position_ids (`torch.Tensor`):\n            The position indices of the tokens corresponding to the query and key tensors. 
For example, this can be\n            used to pass offsetted position ids when working with a KV-cache.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos[position_ids].unsqueeze(unsqueeze_dim)\n    sin = sin[position_ids].unsqueeze(unsqueeze_dim)\n\n    b, h, s, d = q.shape\n    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)\n\n    b, h, s, d = k.shape\n    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)\n\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\nclass DeepseekV3MLP(nn.Module):\n    def __init__(self, config, hidden_size=None, intermediate_size=None):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size\n        self.intermediate_size = (\n            config.intermediate_size if intermediate_size is None else intermediate_size\n        )\n\n        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n        return down_proj\n\n\nclass MoEGate(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.top_k = config.num_experts_per_tok\n        self.n_routed_experts = config.n_routed_experts\n        self.routed_scaling_factor = config.routed_scaling_factor\n        self.scoring_func = config.scoring_func\n        self.topk_method = config.topk_method\n        self.n_group = config.n_group\n        self.topk_group = config.topk_group\n\n        # topk selection algorithm\n        self.norm_topk_prob = config.norm_topk_prob\n        self.gating_dim = config.hidden_size\n        self.weight = nn.Parameter(\n            torch.empty((self.n_routed_experts, self.gating_dim))\n        )\n        if self.topk_method == \"noaux_tc\":\n            self.e_score_correction_bias = nn.Parameter(\n                torch.empty((self.n_routed_experts))\n            )\n        self.reset_parameters()\n\n    def reset_parameters(self) -> None:\n        import torch.nn.init as init\n\n        init.kaiming_uniform_(self.weight, a=math.sqrt(5))\n\n    def forward(self, hidden_states):\n        bsz, seq_len, h = hidden_states.shape\n        ### compute gating score\n        hidden_states = hidden_states.view(-1, h)\n        logits = F.linear(\n            hidden_states.type(torch.float32), self.weight.type(torch.float32), None\n        )\n        if self.scoring_func == \"sigmoid\":\n            scores = logits.sigmoid()\n        else:\n            raise NotImplementedError(\n                f\"unsupported scoring function for MoE gating: {self.scoring_func}\"\n            )\n\n        ### select top-k experts\n        if self.topk_method == \"noaux_tc\":\n            #assert not 
self.training\n            scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)\n            group_scores = (\n                scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim = -1)\n            )  # [n, n_group]\n            group_idx = torch.topk(\n                group_scores, k=self.topk_group, dim=-1, sorted=False\n            )[\n                1\n            ]  # [n, top_k_group]\n            group_mask = torch.zeros_like(group_scores)  # [n, n_group]\n            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]\n            score_mask = (\n                group_mask.unsqueeze(-1)\n                .expand(\n                    bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group\n                )\n                .reshape(bsz * seq_len, -1)\n            )  # [n, e]\n            tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), float(\"-inf\"))  # [n, e]\n            _, topk_idx = torch.topk(\n                tmp_scores, k=self.top_k, dim=-1, sorted=False\n            )\n            topk_weight = scores.gather(1, topk_idx)\n        else:\n            raise NotImplementedError(\n                f\"insupportable TopK function for MoE gating: {self.topk_method}\"\n            )\n\n        ### norm gate to sum 1\n        if self.top_k > 1 and self.norm_topk_prob:\n            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20\n            topk_weight = topk_weight / denominator\n        topk_weight = topk_weight * self.routed_scaling_factor # must multiply the scaling factor\n\n        return topk_idx, topk_weight\n\nclass DeepseekV3MoE(nn.Module):\n    \"\"\"\n    A mixed expert module containing shared experts.\n    \"\"\"\n\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.num_experts_per_tok = config.num_experts_per_tok\n\n        if hasattr(config, \"ep_size\") and config.ep_size > 1:\n   
         assert config.ep_size == dist.get_world_size()\n            self.ep_size = config.ep_size\n            self.experts_per_rank = config.n_routed_experts // config.ep_size\n            self.ep_rank = dist.get_rank()\n            self.experts = nn.ModuleList(\n                [\n                    (\n                        DeepseekV3MLP(\n                            config, intermediate_size=config.moe_intermediate_size\n                        )\n                        if i >= self.ep_rank * self.experts_per_rank\n                        and i < (self.ep_rank + 1) * self.experts_per_rank\n                        else None\n                    )\n                    for i in range(config.n_routed_experts)\n                ]\n            )\n        else:\n            self.ep_size = 1\n            self.experts_per_rank = config.n_routed_experts\n            self.ep_rank = 0\n            self.experts = nn.ModuleList(\n                [\n                    DeepseekV3MLP(\n                        config, intermediate_size=config.moe_intermediate_size\n                    )\n                    for i in range(config.n_routed_experts)\n                ]\n            )\n        self.gate = MoEGate(config)\n        if config.n_shared_experts is not None:\n            intermediate_size = config.moe_intermediate_size * config.n_shared_experts\n            self.shared_experts = DeepseekV3MLP(\n                config=config, intermediate_size=intermediate_size\n            )\n\n    def forward(self, hidden_states):\n        identity = hidden_states\n        orig_shape = hidden_states.shape\n        topk_idx, topk_weight = self.gate(hidden_states)\n        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])\n        flat_topk_idx = topk_idx.view(-1)\n        if not self.training:\n            y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)\n        if self.config.n_shared_experts is not None:\n            y = y + 
self.shared_experts(identity)\n        return y\n\n    @torch.no_grad()\n    def moe_infer(self, x, topk_ids, topk_weight):\n        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))\n        cnts.scatter_(1, topk_ids, 1)\n        tokens_per_expert = cnts.sum(dim=0)\n        idxs = topk_ids.view(-1).argsort()\n        sorted_tokens = x[idxs // topk_ids.shape[1]]\n        sorted_tokens_shape = sorted_tokens.shape\n        if self.ep_size > 1:\n            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)\n            tokens_per_expert_group = tokens_per_expert.new_empty(\n                tokens_per_expert.shape[0]\n            )\n            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)\n            output_splits = (\n                tokens_per_expert_group.view(self.ep_size, -1)\n                .sum(1)\n                .cpu()\n                .numpy()\n                .tolist()\n            )\n            gathered_tokens = sorted_tokens.new_empty(\n                tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1]\n            )\n            input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()\n            dist.all_to_all(\n                list(gathered_tokens.split(output_splits)),\n                list(sorted_tokens.split(input_split_sizes)),\n            )\n            tokens_per_expert_post_gather = tokens_per_expert_group.view(\n                self.ep_size, self.experts_per_rank\n            ).sum(dim=0)\n            gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)\n            s = 0\n            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):\n                gatherd_idxs[s : s + k] = i % self.experts_per_rank\n                s += k\n            gatherd_idxs = gatherd_idxs.argsort()\n            sorted_tokens = gathered_tokens[gatherd_idxs]\n            tokens_per_expert = tokens_per_expert_post_gather\n        tokens_per_expert 
= tokens_per_expert.cpu().numpy()\n\n        outputs = []\n        start_idx = 0\n        for i, num_tokens in enumerate(tokens_per_expert):\n            end_idx = start_idx + num_tokens\n            if num_tokens == 0:\n                continue\n            expert = self.experts[i + self.ep_rank * self.experts_per_rank]\n            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n            expert_out = expert(tokens_for_this_expert)\n            outputs.append(expert_out)\n            start_idx = end_idx\n\n        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n        if self.ep_size > 1:\n            new_x = torch.empty_like(outs)\n            new_x[gatherd_idxs] = outs\n            gathered_tokens = new_x.new_empty(*sorted_tokens_shape)\n            dist.all_to_all(\n                list(gathered_tokens.split(input_split_sizes)),\n                list(new_x.split(output_splits)),\n            )\n            outs = gathered_tokens\n\n        new_x = torch.empty_like(outs)\n        new_x[idxs] = outs\n        final_out = (\n            new_x.view(*topk_ids.shape, -1)\n            .type(topk_weight.dtype)\n            .mul_(topk_weight.unsqueeze(dim=-1))\n            .sum(dim=1)\n            .type(new_x.dtype)\n        )\n        return final_out\n\n\n# Copied from transformers.models.llama.modeling_llama.repeat_kv\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(\n        batch, num_key_value_heads, n_rep, slen, head_dim\n    )\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\n# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV3\nclass DeepseekV3Attention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config: DeepseekV3Config, layer_idx: Optional[int] = None):\n        super().__init__()\n        self.config = config\n        self.layer_idx = layer_idx\n        if layer_idx is None:\n            logger.warning_once(\n                f\"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will \"\n                \"lead to errors during the forward call if caching is used. 
Please make sure to provide a `layer_idx` \"\n                \"when creating this class.\"\n            )\n\n        self.attention_dropout = config.attention_dropout\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n\n        self.max_position_embeddings = config.max_position_embeddings\n        self.rope_theta = config.rope_theta\n        self.q_lora_rank = config.q_lora_rank\n        self.qk_rope_head_dim = config.qk_rope_head_dim\n        self.kv_lora_rank = config.kv_lora_rank\n        self.v_head_dim = config.v_head_dim\n        self.qk_nope_head_dim = config.qk_nope_head_dim\n        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim\n\n        self.is_causal = True\n\n        if self.q_lora_rank is None:\n            self.q_proj = nn.Linear(\n                self.hidden_size, self.num_heads * self.q_head_dim, bias=False\n            )\n        else:\n            self.q_a_proj = nn.Linear(\n                self.hidden_size, config.q_lora_rank, bias=config.attention_bias\n            )\n            self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)\n            self.q_b_proj = nn.Linear(\n                config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False\n            )\n\n        self.kv_a_proj_with_mqa = nn.Linear(\n            self.hidden_size,\n            config.kv_lora_rank + config.qk_rope_head_dim,\n            bias=config.attention_bias,\n        )\n        self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank)\n        self.kv_b_proj = nn.Linear(\n            config.kv_lora_rank,\n            self.num_heads\n            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),\n            bias=False,\n        )\n\n        self.o_proj = nn.Linear(\n            self.num_heads * self.v_head_dim,\n            self.hidden_size,\n            bias=config.attention_bias,\n        )\n        self._init_rope()\n\n        self.softmax_scale = self.q_head_dim 
** (-0.5)\n        if self.config.rope_scaling is not None:\n            mscale_all_dim = self.config.rope_scaling.get(\"mscale_all_dim\", 0)\n            scaling_factor = self.config.rope_scaling[\"factor\"]\n            if mscale_all_dim:\n                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)\n                self.softmax_scale = self.softmax_scale * mscale * mscale\n\n    def _init_rope(self):\n        if self.config.rope_scaling is None:\n            self.rotary_emb = DeepseekV3RotaryEmbedding(\n                self.qk_rope_head_dim,\n                max_position_embeddings=self.max_position_embeddings,\n                base=self.rope_theta,\n            )\n        else:\n            scaling_type = self.config.rope_scaling[\"type\"]\n            scaling_factor = self.config.rope_scaling[\"factor\"]\n            if scaling_type == \"linear\":\n                self.rotary_emb = DeepseekV3LinearScalingRotaryEmbedding(\n                    self.qk_rope_head_dim,\n                    max_position_embeddings=self.max_position_embeddings,\n                    scaling_factor=scaling_factor,\n                    base=self.rope_theta,\n                )\n            elif scaling_type == \"dynamic\":\n                self.rotary_emb = DeepseekV3DynamicNTKScalingRotaryEmbedding(\n                    self.qk_rope_head_dim,\n                    max_position_embeddings=self.max_position_embeddings,\n                    scaling_factor=scaling_factor,\n                    base=self.rope_theta,\n                )\n            elif scaling_type == \"yarn\":\n                kwargs = {\n                    key: self.config.rope_scaling[key]\n                    for key in [\n                        \"original_max_position_embeddings\",\n                        \"beta_fast\",\n                        \"beta_slow\",\n                        \"mscale\",\n                        \"mscale_all_dim\",\n                    ]\n                    if key in 
self.config.rope_scaling\n                }\n                self.rotary_emb = DeepseekV3YarnRotaryEmbedding(\n                    self.qk_rope_head_dim,\n                    max_position_embeddings=self.max_position_embeddings,\n                    scaling_factor=scaling_factor,\n                    base=self.rope_theta,\n                    **kwargs,\n                )\n            else:\n                raise ValueError(f\"Unknown RoPE scaling type {scaling_type}\")\n\n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return (\n            tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim)\n            .transpose(1, 2)\n            .contiguous()\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        if \"padding_mask\" in kwargs:\n            warnings.warn(\n                \"Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure to use `attention_mask` instead.\"\n            )\n        bsz, q_len, _ = hidden_states.size()\n\n        if self.q_lora_rank is None:\n            q = self.q_proj(hidden_states)\n        else:\n            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))\n        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)\n        q_nope, q_pe = torch.split(\n            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1\n        )\n\n        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)\n        compressed_kv, k_pe = torch.split(\n            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n        )\n        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)\n        kv = (\n            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))\n            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)\n            .transpose(1, 2)\n        )\n\n        k_nope, value_states = torch.split(\n            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1\n        )\n        kv_seq_len = value_states.shape[-2]\n        if past_key_value is not None:\n            if self.layer_idx is None:\n                raise ValueError(\n                    f\"The cache structure has changed since version v4.36. 
If you are using {self.__class__.__name__} \"\n                    \"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class \"\n                    \"with a layer index.\"\n                )\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\n        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)\n\n        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)\n\n        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)\n        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope\n        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe\n\n        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)\n        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope\n        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe\n        if past_key_value is not None:\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n            key_states, value_states = past_key_value.update(\n                key_states, value_states, self.layer_idx, cache_kwargs\n            )\n\n        attn_weights = (\n            torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale\n        )\n\n        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is\"\n                f\" {attn_weights.size()}\"\n            )\n        assert attention_mask is not None\n        if attention_mask is not None:\n            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}\"\n                )\n            attn_weights = attn_weights + 
attention_mask\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(\n            attn_weights, dim=-1, dtype=torch.float32\n        ).to(query_states.dtype)\n        attn_weights = nn.functional.dropout(\n            attn_weights, p=self.attention_dropout, training=self.training\n        )\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n\n        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)\n\n        attn_output = self.o_proj(attn_output)\n\n        if not output_attentions:\n            attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n\n# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV3\nclass DeepseekV3FlashAttention2(DeepseekV3Attention):\n    \"\"\"\n    DeepseekV3 flash attention module. This module inherits from `DeepseekV3Attention` as the weights of the module stay\n    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of\n    flash attention and deal with padding tokens in case the input contains any of them.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n        # TODO: Should be removed once Flash Attention for ROCm is bumped to 2.1.\n        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made default for flash_attn>=2.1. This attribute is used to handle this difference. 
Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.\n        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).\n        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.LongTensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        # DeepseekV3FlashAttention2 attention does not support output_attentions\n        if \"padding_mask\" in kwargs:\n            warnings.warn(\n                \"Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure to use `attention_mask` instead.\"\n            )\n\n            # overwrite attention_mask with padding_mask\n            attention_mask = kwargs.pop(\"padding_mask\")\n\n        output_attentions = False\n\n        bsz, q_len, _ = hidden_states.size()\n\n        if self.q_lora_rank is None:\n            q = self.q_proj(hidden_states)\n        else:\n            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))\n        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)\n        q_nope, q_pe = torch.split(\n            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1\n        )\n\n        # Flash attention requires the input to have the shape\n        # batch_size x seq_length x num_heads x head_dim;\n        # the tensors are transposed back to this layout before the flash attention call\n        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)\n        compressed_kv, k_pe = torch.split(\n            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n        )\n        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)\n        kv = (\n            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))\n            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)\n            .transpose(1, 2)\n        )\n\n        k_nope, value_states = torch.split(\n            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1\n        )\n        kv_seq_len = value_states.shape[-2]\n        if past_key_value is not None:\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\n\n        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)\n        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)\n\n        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)\n        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope\n        query_states[:, :, :, 
self.qk_nope_head_dim :] = q_pe\n\n        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)\n        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope\n        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe\n\n        if self.q_head_dim != self.v_head_dim:\n            value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])\n\n        if past_key_value is not None:\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n            key_states, value_states = past_key_value.update(\n                key_states, value_states, self.layer_idx, cache_kwargs\n            )\n\n        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache\n        # to be able to avoid many of these transpose/reshape/view.\n        query_states = query_states.transpose(1, 2)\n        key_states = key_states.transpose(1, 2)\n        value_states = value_states.transpose(1, 2)\n\n        dropout_rate = self.attention_dropout if self.training else 0.0\n\n        # In PEFT, usually we cast the layer norms in float32 for training stability reasons\n        # therefore the input hidden states gets silently casted in float32. Hence, we need\n        # cast them back in the correct dtype just to be sure everything works as expected.\n        # This might slowdown training & inference so it is recommended to not cast the LayerNorms\n        # in fp32. 
(DeepseekV3RMSNorm handles it correctly)\n\n        input_dtype = query_states.dtype\n        if input_dtype == torch.float32:\n            # Handle the case where the model is quantized\n            if hasattr(self.config, \"_pre_quantization_dtype\"):\n                target_dtype = self.config._pre_quantization_dtype\n            elif torch.is_autocast_enabled():\n                target_dtype = torch.get_autocast_gpu_dtype()\n            else:\n                target_dtype = (\n                    self.q_proj.weight.dtype\n                    if self.q_lora_rank is None\n                    else self.q_a_proj.weight.dtype\n                )\n\n            logger.warning_once(\n                f\"The input hidden states seems to be silently casted in float32, this might be related to\"\n                f\" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in\"\n                f\" {target_dtype}.\"\n            )\n\n            query_states = query_states.to(target_dtype)\n            key_states = key_states.to(target_dtype)\n            value_states = value_states.to(target_dtype)\n\n        attn_output = self._flash_attention_forward(\n            query_states,\n            key_states,\n            value_states,\n            attention_mask,\n            q_len,\n            dropout=dropout_rate,\n            softmax_scale=self.softmax_scale,\n        )\n        if self.q_head_dim != self.v_head_dim:\n            attn_output = attn_output[:, :, :, : self.v_head_dim]\n\n        attn_output = attn_output.reshape(\n            bsz, q_len, self.num_heads * self.v_head_dim\n        ).contiguous()\n        attn_output = self.o_proj(attn_output)\n\n        if not output_attentions:\n            attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n    def _flash_attention_forward(\n        self,\n        query_states,\n        key_states,\n        value_states,\n        attention_mask,\n        
query_length,\n        dropout=0.0,\n        softmax_scale=None,\n    ):\n        \"\"\"\n        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token\n        first unpad the input, then computes the attention scores and pad the final attention scores.\n\n        Args:\n            query_states (`torch.Tensor`):\n                Input query states to be passed to Flash Attention API\n            key_states (`torch.Tensor`):\n                Input key states to be passed to Flash Attention API\n            value_states (`torch.Tensor`):\n                Input value states to be passed to Flash Attention API\n            attention_mask (`torch.Tensor`):\n                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the\n                position of padding tokens and 1 for the position of non-padding tokens.\n            dropout (`int`, *optional*):\n                Attention dropout\n            softmax_scale (`float`, *optional*):\n                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)\n        \"\"\"\n        if not self._flash_attn_uses_top_left_mask:\n            causal = self.is_causal\n        else:\n            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. 
For details, please see the comment in DeepseekV3FlashAttention2 __init__.\n            causal = self.is_causal and query_length != 1\n\n        # Contains at least one padding token in the sequence\n        if attention_mask is not None:\n            batch_size = query_states.shape[0]\n            (\n                query_states,\n                key_states,\n                value_states,\n                indices_q,\n                cu_seq_lens,\n                max_seq_lens,\n            ) = self._upad_input(\n                query_states, key_states, value_states, attention_mask, query_length\n            )\n\n            cu_seqlens_q, cu_seqlens_k = cu_seq_lens\n            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens\n\n            attn_output_unpad = flash_attn_varlen_func(\n                query_states,\n                key_states,\n                value_states,\n                cu_seqlens_q=cu_seqlens_q,\n                cu_seqlens_k=cu_seqlens_k,\n                max_seqlen_q=max_seqlen_in_batch_q,\n                max_seqlen_k=max_seqlen_in_batch_k,\n                dropout_p=dropout,\n                softmax_scale=softmax_scale,\n                causal=causal,\n            )\n\n            attn_output = pad_input(\n                attn_output_unpad, indices_q, batch_size, query_length\n            )\n        else:\n            attn_output = flash_attn_func(\n                query_states,\n                key_states,\n                value_states,\n                dropout,\n                softmax_scale=softmax_scale,\n                causal=causal,\n            )\n\n        return attn_output\n\n    def _upad_input(\n        self, query_layer, key_layer, value_layer, attention_mask, query_length\n    ):\n        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)\n        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape\n\n        key_layer = index_first_axis(\n            
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),\n            indices_k,\n        )\n        value_layer = index_first_axis(\n            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),\n            indices_k,\n        )\n        if query_length == kv_seq_len:\n            query_layer = index_first_axis(\n                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),\n                indices_k,\n            )\n            cu_seqlens_q = cu_seqlens_k\n            max_seqlen_in_batch_q = max_seqlen_in_batch_k\n            indices_q = indices_k\n        elif query_length == 1:\n            max_seqlen_in_batch_q = 1\n            cu_seqlens_q = torch.arange(\n                batch_size + 1, dtype=torch.int32, device=query_layer.device\n            )  # There is a memcpy here, that is very bad.\n            indices_q = cu_seqlens_q[:-1]\n            query_layer = query_layer.squeeze(1)\n        else:\n            # The -q_len: slice assumes left padding.\n            attention_mask = attention_mask[:, -query_length:]\n            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(\n                query_layer, attention_mask\n            )\n\n        return (\n            query_layer,\n            key_layer,\n            value_layer,\n            indices_q,\n            (cu_seqlens_q, cu_seqlens_k),\n            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),\n        )\n\n\nATTENTION_CLASSES = {\n    \"eager\": DeepseekV3Attention,\n    \"flash_attention_2\": DeepseekV3FlashAttention2,\n}\n\n\nclass DeepseekV3DecoderLayer(nn.Module):\n    def __init__(self, config: DeepseekV3Config, layer_idx: int):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n\n        self.self_attn = ATTENTION_CLASSES[config._attn_implementation](\n            config=config, layer_idx=layer_idx\n        )\n\n        self.mlp = (\n            DeepseekV3MoE(config)\n     
       if (\n                config.n_routed_experts is not None\n                and layer_idx >= config.first_k_dense_replace\n                and layer_idx % config.moe_layer_freq == 0\n            )\n            else DeepseekV3MLP(config)\n        )\n        self.input_layernorm = DeepseekV3RMSNorm(\n            config.hidden_size, eps=config.rms_norm_eps\n        )\n        self.post_attention_layernorm = DeepseekV3RMSNorm(\n            config.hidden_size, eps=config.rms_norm_eps\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        output_attentions: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        is_prefill: Optional[bool] = False,\n        **kwargs,\n    ) -> Tuple[\n        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]\n    ]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*):\n                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,\n                query_sequence_length, key_sequence_length)` if default attention is used.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under\n                returned tensors for more detail.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n        \"\"\"\n        if \"padding_mask\" in kwargs:\n            warnings.warn(\n                \"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead.\"\n            )\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_value=past_key_value,\n            output_attentions=output_attentions,\n            use_cache=use_cache,\n            cache_position=cache_position,\n            is_prefill=is_prefill,\n            **kwargs,\n        )\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (self_attn_weights,)\n\n        if use_cache:\n            outputs += (present_key_value,)\n\n        return outputs\n\n\nDeepseekV3_START_DOCSTRING = r\"\"\"\n    This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the\n    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads\n    etc.)\n\n    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage\n    and behavior.\n\n    Parameters:\n        config ([`DeepseekV3Config`]):\n            Model configuration class with all the parameters of the model. Initializing with a config file does not\n            load the weights associated with the model, only the configuration. Check out the\n            [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.\",\n    DeepseekV3_START_DOCSTRING,\n)\nclass DeepseekV3PreTrainedModel(PreTrainedModel):\n    config_class = DeepseekV3Config\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"DeepseekV3DecoderLayer\"]\n    _skip_keys_device_placement = \"past_key_values\"\n    _supports_flash_attn_2 = True\n    _supports_cache_class = True\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n\n\nDeepseekV3_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in 
the vocabulary. Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see\n            `past_key_values`).\n\n            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]\n            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more\n            information on the default strategy.\n\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,\n            config.n_positions - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):\n            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention\n            blocks) that can be used to speed up sequential decoding. 
This typically consists in the `past_key_values`\n            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.\n\n            Two formats are allowed:\n            - a [`~cache_utils.Cache`] instance;\n            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of\n            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy\n            cache format.\n\n            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the\n            legacy cache format will be returned.\n\n            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't\n            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`\n            of shape `(batch_size, sequence_length)`.\n        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This\n            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n            model's internal embedding lookup matrix.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n            `past_key_values`).\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors for\n            more detail.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.\",\n    DeepseekV3_START_DOCSTRING,\n)\nclass DeepseekV3Model(DeepseekV3PreTrainedModel):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV3DecoderLayer`]\n\n    Args:\n        config: DeepseekV3Config\n    \"\"\"\n\n    def __init__(self, config: DeepseekV3Config):\n        super().__init__(config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(\n            config.vocab_size, config.hidden_size, self.padding_idx\n        )\n        self.layers = nn.ModuleList(\n            [\n                DeepseekV3DecoderLayer(config, layer_idx)\n                for layer_idx in range(config.num_hidden_layers)\n            ]\n        )\n        self._use_flash_attention_2 = config._attn_implementation == \"flash_attention_2\"\n        self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n        self.gradient_checkpointing = False\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: 
Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        # retrieve input_ids and inputs_embeds\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\n                \"You cannot specify both input_ids and inputs_embeds at the same time\"\n            )\n        elif input_ids is not None:\n            batch_size, seq_length = input_ids.shape[:2]\n        elif inputs_embeds is not None:\n            batch_size, seq_length = inputs_embeds.shape[:2]\n        else:\n            raise ValueError(\"You have to specify either input_ids or inputs_embeds\")\n\n        past_key_values_length = 0\n        if use_cache:\n            use_legacy_cache = not isinstance(past_key_values, Cache)\n            if use_legacy_cache:\n                past_key_values = DynamicCache.from_legacy_cache(past_key_values)\n            past_key_values_length = past_key_values.get_usable_length(seq_length)\n\n        if position_ids is None:\n            device = input_ids.device if input_ids is not None else inputs_embeds.device\n            position_ids = torch.arange(\n                past_key_values_length,\n                seq_length + past_key_values_length,\n                
dtype=torch.long,\n                device=device,\n            )\n            position_ids = position_ids.unsqueeze(0)\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        if self._use_flash_attention_2:\n            # 2d mask is passed through the layers\n            attention_mask = (\n                attention_mask\n                if (attention_mask is not None and 0 in attention_mask)\n                else None\n            )\n        else:\n            # 4d mask is passed through the layers\n            attention_mask = _prepare_4d_causal_attention_mask(\n                attention_mask,\n                (batch_size, seq_length),\n                inputs_embeds,\n                past_key_values_length,\n            )\n\n        # embed positions\n        hidden_states = inputs_embeds\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        next_decoder_cache = None\n\n        for decoder_layer in self.layers:\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            layer_outputs = decoder_layer(\n                hidden_states,\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n                past_key_value=past_key_values,\n                output_attentions=output_attentions,\n                use_cache=use_cache,\n                cache_position=cache_position,\n            )\n\n            hidden_states = layer_outputs[0]\n\n            if use_cache:\n                next_decoder_cache = layer_outputs[2 if output_attentions else 1]\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n        hidden_states = self.norm(hidden_states)\n\n        # add hidden states from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += 
(hidden_states,)\n\n        next_cache = None\n        if use_cache:\n            next_cache = (\n                next_decoder_cache.to_legacy_cache()\n                if use_legacy_cache\n                else next_decoder_cache\n            )\n        if not return_dict:\n            return tuple(\n                v\n                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]\n                if v is not None\n            )\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=next_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n        )\n\n    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask\n    def _update_causal_mask(\n        self,\n        attention_mask: torch.Tensor,\n        input_tensor: torch.Tensor,\n        cache_position: torch.Tensor,\n        past_key_values: Cache,\n        output_attentions: bool,\n    ):\n        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static\n        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.\n        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using\n        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114\n\n        if self.config._attn_implementation == \"flash_attention_2\":\n            if attention_mask is not None and 0.0 in attention_mask:\n                return attention_mask\n            return None\n\n        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in\n        # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail\n        # to infer the attention mask.\n        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n        using_static_cache = isinstance(past_key_values, StaticCache)\n\n        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward\n        if self.config._attn_implementation == \"sdpa\" and not using_static_cache and not output_attentions:\n            if AttentionMaskConverter._ignore_causal_mask_sdpa(\n                attention_mask,\n                inputs_embeds=input_tensor,\n                past_key_values_length=past_seen_tokens,\n                is_training=self.training,\n            ):\n                return None\n\n        dtype, device = input_tensor.dtype, input_tensor.device\n        min_dtype = torch.finfo(dtype).min\n        sequence_length = input_tensor.shape[1]\n        if using_static_cache:\n            target_length = past_key_values.get_max_length()\n        else:\n            target_length = (\n                attention_mask.shape[-1]\n                if isinstance(attention_mask, torch.Tensor)\n                else past_seen_tokens + sequence_length + 1\n            )\n\n        if attention_mask is not None and attention_mask.dim() == 4:\n            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing\n            if attention_mask.max() != 0:\n                raise ValueError(\"Custom 4D attention mask should be passed in inverted form with max==0`\")\n            causal_mask = attention_mask\n        else:\n            causal_mask = torch.full(\n                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device\n            )\n            if sequence_length != 1:\n                causal_mask = torch.triu(causal_mask, diagonal=1)\n            causal_mask *= torch.arange(target_length, 
device=device) > cache_position.reshape(-1, 1)\n            causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)\n            if attention_mask is not None:\n                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit\n                mask_length = attention_mask.shape[-1]\n                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]\n                padding_mask = padding_mask == 0\n                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(\n                    padding_mask, min_dtype\n                )\n        if (\n            self.config._attn_implementation == \"sdpa\"\n            and attention_mask is not None\n            and attention_mask.device.type == \"cuda\"\n            and not output_attentions\n        ):\n            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when\n            # using left padding. 
This is required by F.scaled_dot_product_attention memory-efficient attention path.\n            # Details: https://github.com/pytorch/pytorch/issues/110213\n            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)\n\n        return causal_mask\n\nclass DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):\n    _tied_weights_keys = [\"lm_head.weight\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = DeepseekV3Model(config)\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def set_decoder(self, decoder):\n        self.model = decoder\n\n    def get_decoder(self):\n        return self.model\n\n    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)\n    @replace_return_docstrings(\n        output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC\n    )\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = 
None,\n        is_prefill: Optional[bool] = False,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM\n\n        >>> model = DeepseekV3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)\n        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)\n\n        >>> prompt = \"Hey, are you conscious? Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you conscious? 
Can you talk to me?\\nI'm not conscious, but I can talk to you.\"\n        ```\"\"\"\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            cache_position=cache_position,\n            is_prefill=is_prefill,\n        )\n\n        hidden_states = outputs[0]\n        if use_torch_npu:\n            hidden_states_without_norm = outputs[-1]\n            logits = self.lm_head(hidden_states)\n        else:\n            logits = self.lm_head(hidden_states[:,-1:,:])\n        logits = logits.float()\n\n        loss = None\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            shift_logits = shift_logits.view(-1, self.config.vocab_size)\n            shift_labels = shift_labels.view(-1)\n            # Enable model parallelism\n            shift_labels = shift_labels.to(shift_logits.device)\n            loss = loss_fct(shift_logits, 
shift_labels)\n\n        if not return_dict:\n            if use_torch_npu:\n                output = (logits,) + outputs[1:] + (hidden_states_without_norm,)\n            else:\n                output = (logits,) + outputs[1:]\n            return (loss,) + output if loss is not None else output\n\n        return CausalLMOutputWithPast(\n            loss=loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n    def prepare_inputs_for_generation(\n        self,\n        input_ids,\n        past_key_values=None,\n        attention_mask=None,\n        inputs_embeds=None,\n        **kwargs,\n    ):\n        if past_key_values is not None:\n            if isinstance(past_key_values, Cache):\n                cache_length = past_key_values.get_seq_length()\n                past_length = past_key_values.seen_tokens\n                max_cache_length = past_key_values.get_max_length()\n            else:\n                cache_length = past_length = past_key_values[0][0].shape[2]\n                max_cache_length = None\n\n            # Keep only the unprocessed tokens:\n            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where\n            # some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as\n            # input)\n            if (\n                attention_mask is not None\n                and attention_mask.shape[1] > input_ids.shape[1]\n            ):\n                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]\n            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. 
We can discard\n            # input_ids based on the past_length.\n            elif past_length < input_ids.shape[1]:\n                input_ids = input_ids[:, past_length:]\n            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.\n\n            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.\n            if (\n                max_cache_length is not None\n                and attention_mask is not None\n                and cache_length + input_ids.shape[1] > max_cache_length\n            ):\n                attention_mask = attention_mask[:, -max_cache_length:]\n\n        position_ids = kwargs.get(\"position_ids\", None)\n        if attention_mask is not None and position_ids is None:\n            # create position_ids on the fly for batch generation\n            position_ids = attention_mask.long().cumsum(-1) - 1\n            position_ids.masked_fill_(attention_mask == 0, 1)\n            if past_key_values:\n                position_ids = position_ids[:, -input_ids.shape[1] :]\n\n        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step\n        if inputs_embeds is not None and past_key_values is None:\n            model_inputs = {\"inputs_embeds\": inputs_embeds}\n        else:\n            model_inputs = {\"input_ids\": input_ids}\n\n        model_inputs.update(\n            {\n                \"position_ids\": position_ids,\n                \"past_key_values\": past_key_values,\n                \"use_cache\": kwargs.get(\"use_cache\"),\n                \"attention_mask\": attention_mask,\n            }\n        )\n        return model_inputs\n\n    @staticmethod\n    def _reorder_cache(past_key_values, beam_idx):\n        reordered_past = ()\n        for layer_past in past_key_values:\n            reordered_past += (\n                tuple(\n                    past_state.index_select(0, 
beam_idx.to(past_state.device))\n                    for past_state in layer_past\n                ),\n            )\n        return reordered_past\n\n\n@add_start_docstrings(\n    \"\"\"\n    The DeepseekV3 Model transformer with a sequence classification head on top (linear layer).\n\n    [`DeepseekV3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models\n    (e.g. GPT-2) do.\n\n    Since it does classification on the last token, it requires to know the position of the last token. If a\n    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If\n    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the\n    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in\n    each row of the batch).\n    \"\"\",\n    DeepseekV3_START_DOCSTRING,\n)\nclass DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.model = DeepseekV3Model(config)\n        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        
use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        transformer_outputs = self.model(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = transformer_outputs[0]\n        logits = self.score(hidden_states)\n\n        if input_ids is not None:\n            batch_size = input_ids.shape[0]\n        else:\n            batch_size = inputs_embeds.shape[0]\n\n        if self.config.pad_token_id is None and batch_size != 1:\n            raise ValueError(\n                \"Cannot handle batch sizes > 1 if no padding token is defined.\"\n            )\n        if self.config.pad_token_id is None:\n            sequence_lengths = -1\n        else:\n            if input_ids is not None:\n                sequence_lengths = (\n                    torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1\n                ).to(logits.device)\n            else:\n                
sequence_lengths = -1\n\n        pooled_logits = logits[\n            torch.arange(batch_size, device=logits.device), sequence_lengths\n        ]\n\n        loss = None\n        if labels is not None:\n            labels = labels.to(logits.device)\n            if self.config.problem_type is None:\n                if self.num_labels == 1:\n                    self.config.problem_type = \"regression\"\n                elif self.num_labels > 1 and (\n                    labels.dtype == torch.long or labels.dtype == torch.int\n                ):\n                    self.config.problem_type = \"single_label_classification\"\n                else:\n                    self.config.problem_type = \"multi_label_classification\"\n\n            if self.config.problem_type == \"regression\":\n                loss_fct = MSELoss()\n                if self.num_labels == 1:\n                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())\n                else:\n                    loss = loss_fct(pooled_logits, labels)\n            elif self.config.problem_type == \"single_label_classification\":\n                loss_fct = CrossEntropyLoss()\n                loss = loss_fct(\n                    pooled_logits.view(-1, self.num_labels), labels.view(-1)\n                )\n            elif self.config.problem_type == \"multi_label_classification\":\n                loss_fct = BCEWithLogitsLoss()\n                loss = loss_fct(pooled_logits, labels)\n        if not return_dict:\n            output = (pooled_logits,) + transformer_outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return SequenceClassifierOutputWithPast(\n            loss=loss,\n            logits=pooled_logits,\n            past_key_values=transformer_outputs.past_key_values,\n            hidden_states=transformer_outputs.hidden_states,\n            attentions=transformer_outputs.attentions,\n        )\n"
  },
  {
    "path": "archive/ktransformers/models/modeling_glm4_moe.py",
    "content": "#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨\n#           This file was automatically generated from src/transformers/models/glm4_moe/modular_glm4_moe.py.\n#               Do NOT edit this file manually as any edits will be overwritten by the generation of\n#             the file from the modular. If any change should be done, please apply the change to the\n#                          modular_glm4_moe.py file directly. One of our CI enforces this.\n#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨\n# coding=utf-8\n# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Callable, Optional, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom transformers.activations import ACT2FN\nfrom transformers.cache_utils import Cache, DynamicCache\nfrom transformers.generation import GenerationMixin\n# from transformers.integrations import use_kernel_forward_from_hub\nfrom transformers.masking_utils import create_causal_mask\nfrom transformers.modeling_flash_attention_utils import FlashAttentionKwargs\nfrom transformers.modeling_layers import GradientCheckpointingLayer\nfrom transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast\nfrom transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update\nfrom transformers.modeling_utils import 
ALL_ATTENTION_FUNCTIONS, PreTrainedModel\nfrom transformers.processing_utils import Unpack\n# from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple\nfrom transformers.utils import auto_docstring, can_return_tuple\n# from transformers.utils.generic import check_model_inputs\nfrom .configuration_glm4_moe import Glm4MoeConfig\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\ndef eager_attention_forward(\n    module: nn.Module,\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    attention_mask: Optional[torch.Tensor],\n    scaling: float,\n    dropout: float = 0.0,\n    # **kwargs: Unpack[TransformersKwargs],\n):\n    key_states = repeat_kv(key, module.num_key_value_groups)\n    value_states = repeat_kv(value, module.num_key_value_groups)\n\n    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling\n    if attention_mask is not None:\n        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]\n        attn_weights = attn_weights + causal_mask\n\n    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)\n    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)\n    attn_output = torch.matmul(attn_weights, value_states)\n    attn_output = attn_output.transpose(1, 2).contiguous()\n\n    return attn_output, attn_weights\n\n\ndef 
rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n    Args:\n        q (`torch.Tensor`): The query tensor.\n        k (`torch.Tensor`): The key tensor.\n        cos (`torch.Tensor`): The cosine part of the rotary embedding.\n        sin (`torch.Tensor`): The sine part of the rotary embedding.\n        position_ids (`torch.Tensor`, *optional*):\n            Deprecated and unused.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. 
Similarly, if q and k have\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos.unsqueeze(unsqueeze_dim)\n    sin = sin.unsqueeze(unsqueeze_dim)\n\n    # Keep half or full tensor for later concatenation\n    rotary_dim = cos.shape[-1]\n    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]\n    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]\n\n    # Apply rotary embeddings on the first half or full tensor\n    q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)\n    k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)\n\n    # Concatenate back to full shape\n    q_embed = torch.cat([q_embed, q_pass], dim=-1)\n    k_embed = torch.cat([k_embed, k_pass], dim=-1)\n    return q_embed, k_embed\n\n\nclass Glm4MoeAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config: Glm4MoeConfig, layer_idx: Optional[int] = None):\n        super().__init__()\n        self.config = config\n        self.layer_idx = layer_idx\n        self.head_dim = getattr(config, \"head_dim\", config.hidden_size // config.num_attention_heads)\n        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads\n        self.scaling = self.head_dim**-0.5\n        self.attention_dropout = config.attention_dropout\n        self.is_causal = True\n\n        self.q_proj = nn.Linear(\n            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias\n        )\n        self.k_proj = nn.Linear(\n            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias\n        )\n        self.v_proj = nn.Linear(\n            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias\n        )\n        
self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)\n        self.use_qk_norm = config.use_qk_norm\n        if self.use_qk_norm:\n            self.q_norm = Glm4MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)\n            self.k_norm = Glm4MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_embeddings: tuple[torch.Tensor, torch.Tensor],\n        attention_mask: Optional[torch.Tensor],\n        past_key_value: Optional[Cache] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs: Unpack[FlashAttentionKwargs],\n    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:\n        input_shape = hidden_states.shape[:-1]\n        hidden_shape = (*input_shape, -1, self.head_dim)\n\n        query_states = self.q_proj(hidden_states).view(hidden_shape)\n        key_states = self.k_proj(hidden_states).view(hidden_shape)\n        value_states = self.v_proj(hidden_states).view(hidden_shape)\n\n        if self.use_qk_norm:  # main diff from Llama\n            query_states = self.q_norm(query_states)\n            key_states = self.k_norm(key_states)\n\n        query_states = query_states.transpose(1, 2)\n        key_states = key_states.transpose(1, 2)\n        value_states = value_states.transpose(1, 2)\n\n        cos, sin = position_embeddings\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n        if past_key_value is not None:\n            # sin and cos are specific to RoPE models; position_ids needed for the static cache\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}\n            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n        attention_interface: Callable = eager_attention_forward\n        if 
self.config._attn_implementation != \"eager\":\n            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]\n\n        attn_output, attn_weights = attention_interface(\n            self,\n            query_states,\n            key_states,\n            value_states,\n            attention_mask,\n            dropout=0.0 if not self.training else self.attention_dropout,\n            scaling=self.scaling,\n            **kwargs,\n        )\n\n        attn_output = attn_output.reshape(*input_shape, -1).contiguous()\n        attn_output = self.o_proj(attn_output)\n        return attn_output, attn_weights\n\n\nclass Glm4MoeMLP(nn.Module):\n    def __init__(self, config, hidden_size=None, intermediate_size=None):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size\n        self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size\n\n        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n        return down_proj\n\n\nclass Glm4MoeTopkRouter(nn.Module):\n    def __init__(self, config: Glm4MoeConfig):\n        super().__init__()\n        self.config = config\n        self.top_k = config.num_experts_per_tok\n        self.n_routed_experts = config.n_routed_experts\n        self.routed_scaling_factor = config.routed_scaling_factor\n        self.n_group = config.n_group\n        self.topk_group = config.topk_group\n        self.norm_topk_prob = config.norm_topk_prob\n\n        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, 
config.hidden_size)))\n        self.register_buffer(\"e_score_correction_bias\", torch.zeros((self.n_routed_experts), dtype=torch.float32))\n\n    @torch.no_grad()\n    def get_topk_indices(self, scores):\n        scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)\n        group_scores = (\n            scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group)\n            .topk(2, dim=-1)[0]\n            .sum(dim=-1)\n        )\n        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]\n        group_mask = torch.zeros_like(group_scores)\n        group_mask.scatter_(1, group_idx, 1)\n        score_mask = (\n            group_mask.unsqueeze(-1)\n            .expand(-1, self.n_group, self.n_routed_experts // self.n_group)\n            .reshape(-1, self.n_routed_experts)\n        )\n        scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)\n        topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]\n        return topk_indices\n\n    def forward(self, hidden_states):\n        hidden_states = hidden_states.view(-1, self.config.hidden_size)\n        router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))\n        scores = router_logits.sigmoid()\n        topk_indices = self.get_topk_indices(scores)\n        topk_weights = scores.gather(1, topk_indices)\n        if self.norm_topk_prob:\n            denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20\n            topk_weights /= denominator\n        topk_weights = topk_weights * self.routed_scaling_factor\n        return topk_indices, topk_weights\n\n\n# @use_kernel_forward_from_hub(\"RMSNorm\")\nclass Glm4MoeRMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        Glm4MoeRMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        self.hidden_size = 
hidden_size\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n    def extra_repr(self):\n        return f\"{tuple(self.weight.shape)}, eps={self.variance_epsilon}\"\n\n\nclass Glm4MoeMoE(nn.Module):\n    \"\"\"\n    A mixed expert module containing shared experts.\n    \"\"\"\n\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.experts = nn.ModuleList(\n            [\n                Glm4MoeMLP(config, intermediate_size=config.moe_intermediate_size)\n                for _ in range(config.n_routed_experts)\n            ]\n        )\n        self.gate = Glm4MoeTopkRouter(config)\n        self.shared_experts = Glm4MoeMLP(\n            config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts\n        )\n\n    def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):\n        r\"\"\"\n        CALL FOR CONTRIBUTION! 
I don't have time to optimise this right now, but expert weights need to be fused\n        to not have to do a loop here (deepseek has 256 experts soooo yeah).\n        \"\"\"\n        final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)\n        expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))\n        expert_mask = expert_mask.permute(2, 0, 1)\n\n        for expert_idx in range(len(self.experts)):\n            expert = self.experts[expert_idx]\n            mask = expert_mask[expert_idx]\n            token_indices, weight_indices = torch.where(mask)\n\n            if token_indices.numel() > 0:\n                expert_weights = topk_weights[token_indices, weight_indices]\n                expert_input = hidden_states[token_indices]\n                expert_output = expert(expert_input)\n                weighted_output = expert_output * expert_weights.unsqueeze(-1)\n                final_hidden_states.index_add_(0, token_indices, weighted_output)\n\n        # in original deepseek, the outputs of the experts are gathered once we leave this module\n        # thus the moe module is itself an IsolatedParallel module\n        # and all experts are \"local\" meaning we shard but we don't gather\n        return final_hidden_states.type(hidden_states.dtype)\n\n    def forward(self, hidden_states):\n        residuals = hidden_states\n        orig_shape = hidden_states.shape\n        topk_indices, topk_weights = self.gate(hidden_states)\n        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])\n        hidden_states = self.moe(hidden_states, topk_indices, topk_weights).view(*orig_shape)\n        hidden_states = hidden_states + self.shared_experts(residuals)\n        return hidden_states\n\n\nclass Glm4MoeDecoderLayer(GradientCheckpointingLayer):\n    def __init__(self, config: Glm4MoeConfig, layer_idx: int):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n\n        
self.self_attn = Glm4MoeAttention(config=config, layer_idx=layer_idx)\n\n        if layer_idx >= config.first_k_dense_replace:\n            self.mlp = Glm4MoeMoE(config)\n        else:\n            self.mlp = Glm4MoeMLP(config)\n\n        self.input_layernorm = Glm4MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_attention_layernorm = Glm4MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        use_cache: Optional[bool] = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC\n        # **kwargs: Unpack[TransformersKwargs],\n    ) -> tuple[torch.Tensor]:\n        residual = hidden_states\n        hidden_states = self.input_layernorm(hidden_states)\n        # Self Attention\n        hidden_states, _ = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_value=past_key_value,\n            use_cache=use_cache,\n            cache_position=cache_position,\n            position_embeddings=position_embeddings\n        )\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n        return hidden_states\n\n\n@auto_docstring\nclass Glm4MoePreTrainedModel(PreTrainedModel):\n    config: Glm4MoeConfig\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"Glm4MoeDecoderLayer\"]\n    
_skip_keys_device_placement = [\"past_key_values\"]\n    _supports_flash_attn = True\n    _supports_sdpa = True\n    _supports_flex_attn = True\n    _supports_static_cache = False\n    _supports_attention_backend = True\n    _can_record_outputs = {\n        \"hidden_states\": Glm4MoeDecoderLayer,\n        \"attentions\": Glm4MoeAttention,\n    }\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n        elif isinstance(module, Glm4MoeRMSNorm):\n            module.weight.data.fill_(1.0)\n        elif isinstance(module, Glm4MoeTopkRouter):\n            module.weight.data.normal_(mean=0.0, std=std)\n\n\nclass Glm4MoeRotaryEmbedding(nn.Module):\n    def __init__(self, config: Glm4MoeConfig, device=None):\n        super().__init__()\n        # BC: \"rope_type\" was originally \"type\"\n        if hasattr(config, \"rope_scaling\") and isinstance(config.rope_scaling, dict):\n            self.rope_type = config.rope_scaling.get(\"rope_type\", config.rope_scaling.get(\"type\"))\n        else:\n            self.rope_type = \"default\"\n        self.max_seq_len_cached = config.max_position_embeddings\n        self.original_max_seq_len = config.max_position_embeddings\n\n        self.config = config\n        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]\n\n        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        self.original_inv_freq = self.inv_freq\n\n    @torch.no_grad()\n    @dynamic_rope_update  # power user: used with advanced 
RoPE types (e.g. dynamic rope)\n    def forward(self, x, position_ids):\n        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)\n        position_ids_expanded = position_ids[:, None, :].float()\n\n        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != \"mps\" else \"cpu\"\n        with torch.autocast(device_type=device_type, enabled=False):  # Force float32\n            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos() * self.attention_scaling\n            sin = emb.sin() * self.attention_scaling\n\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)\n\n\n@auto_docstring\nclass Glm4MoeModel(Glm4MoePreTrainedModel):\n    _keys_to_ignore_on_load_unexpected = [r\"model\\.layers\\.92.*\", r\"model\\.layers\\.46.*\"]\n\n    def __init__(self, config: Glm4MoeConfig):\n        super().__init__(config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)\n        self.layers = nn.ModuleList(\n            [Glm4MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]\n        )\n        self.norm = Glm4MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.rotary_emb = Glm4MoeRotaryEmbedding(config=config)\n        self.gradient_checkpointing = False\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.embed_tokens = value\n\n    # @check_model_inputs\n    @auto_docstring\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n   
     position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        # **kwargs: Unpack[TransformersKwargs],\n    ) -> BaseModelOutputWithPast:\n        if (input_ids is None) ^ (inputs_embeds is not None):\n            raise ValueError(\"You must specify exactly one of input_ids or inputs_embeds\")\n\n        if inputs_embeds is None:\n            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)\n\n        if use_cache and past_key_values is None:\n            past_key_values = DynamicCache()\n\n        if cache_position is None:\n            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n            cache_position: torch.Tensor = torch.arange(\n                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device\n            )\n\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0)\n\n        causal_mask = create_causal_mask(\n            config=self.config,\n            input_embeds=inputs_embeds,\n            attention_mask=attention_mask,\n            cache_position=cache_position,\n            past_key_values=past_key_values,\n            position_ids=position_ids,\n        )\n\n        hidden_states = inputs_embeds\n        position_embeddings = self.rotary_emb(hidden_states, position_ids)\n\n        for decoder_layer in self.layers[: self.config.num_hidden_layers]:\n            hidden_states = decoder_layer(\n                hidden_states,\n                attention_mask=causal_mask,\n                position_ids=position_ids,\n                past_key_value=past_key_values,\n                cache_position=cache_position,\n                position_embeddings=position_embeddings,\n            )\n\n        hidden_states = 
self.norm(hidden_states)\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=past_key_values,\n        )\n\n\n@auto_docstring\nclass Glm4MoeForCausalLM(Glm4MoePreTrainedModel, GenerationMixin):\n    _tied_weights_keys = [\"lm_head.weight\"]\n    _tp_plan = {\"lm_head\": \"colwise_rep\"}\n    _pp_plan = {\"lm_head\": ([\"hidden_states\"], [\"logits\"])}\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = Glm4MoeModel(config)\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def set_decoder(self, decoder):\n        self.model = decoder\n\n    def get_decoder(self):\n        return self.model\n\n    @can_return_tuple\n    @auto_docstring\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        logits_to_keep: Union[int, torch.Tensor] = 0,\n        # **kwargs: Unpack[TransformersKwargs],\n    ) -> CausalLMOutputWithPast:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Labels for computing the 
masked language modeling loss. Indices should either be in `[0, ...,\n            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, Glm4MoeForCausalLM\n\n        >>> model = Glm4MoeForCausalLM.from_pretrained(\"meta-glm4_moe/Glm4Moe-2-7b-hf\")\n        >>> tokenizer = AutoTokenizer.from_pretrained(\"meta-glm4_moe/Glm4Moe-2-7b-hf\")\n\n        >>> prompt = \"Hey, are you conscious? Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you conscious? Can you talk to me?\\nI'm not conscious, but I can talk to you.\"\n        ```\"\"\"\n        outputs: BaseModelOutputWithPast = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            cache_position=cache_position,\n            # **kwargs,\n        )\n\n        hidden_states = outputs.last_hidden_state\n        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss\n        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep\n        logits = self.lm_head(hidden_states[:, slice_indices, :])\n\n        loss = None\n        if labels is not None:\n            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)\n\n        return CausalLMOutputWithPast(\n            
loss=loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\n__all__ = [\"Glm4MoePreTrainedModel\", \"Glm4MoeModel\", \"Glm4MoeForCausalLM\"]"
  },
  {
    "path": "archive/ktransformers/models/modeling_llama.py",
    "content": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport math\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n\nfrom transformers.activations import ACT2FN\nfrom transformers.cache_utils import Cache, DynamicCache, StaticCache\nfrom transformers.modeling_attn_mask_utils import AttentionMaskConverter\nfrom transformers.modeling_flash_attention_utils import _flash_attention_forward\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPast,\n    CausalLMOutputWithPast,\n    QuestionAnsweringModelOutput,\n    SequenceClassifierOutputWithPast,\n    TokenClassifierOutput,\n)\nfrom transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.pytorch_utils import ALL_LAYERNORM_LAYERS\nfrom transformers.utils import (\n    add_start_docstrings,\n    add_start_docstrings_to_model_forward,\n    
is_flash_attn_greater_or_equal_2_10,\n    logging,\n    replace_return_docstrings,\n)\nfrom .configuration_llama import LlamaConfig\n\n\nlogger = logging.get_logger(__name__)\n\n_CONFIG_FOR_DOC = \"LlamaConfig\"\n\n\nclass LlamaRMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        LlamaRMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n\nALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)\n\n\nclass LlamaRotaryEmbedding(nn.Module):\n    def __init__(\n        self,\n        dim=None,\n        max_position_embeddings=2048,\n        base=10000,\n        device=None,\n        scaling_factor=1.0,\n        rope_type=\"default\",\n        config: Optional[LlamaConfig] = None,\n    ):\n        super().__init__()\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        self.device = device\n        self.scaling_factor = scaling_factor\n        self.rope_type = rope_type\n        self.config = config\n        # TODO (joao): remove the `if` below, only used for BC\n        self.rope_kwargs = {}\n        if config is None:\n            logger.warning_once(\n                \"`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the \"\n                \"`config` argument. 
All other arguments will be removed in v4.45\"\n            )\n            self.rope_kwargs = {\n                \"rope_type\": rope_type,\n                \"factor\": scaling_factor,\n                \"dim\": dim,\n                \"base\": base,\n                \"max_position_embeddings\": max_position_embeddings,\n            }\n            self.rope_type = rope_type\n            self.max_seq_len_cached = max_position_embeddings\n            self.original_max_seq_len = max_position_embeddings\n        else:\n            # BC: \"rope_type\" was originally \"type\"\n            if config.rope_scaling is not None:\n                self.rope_type = config.rope_scaling.get(\n                    \"rope_type\", config.rope_scaling.get(\"type\")\n                )\n            else:\n                self.rope_type = \"default\"\n            self.max_seq_len_cached = config.max_position_embeddings\n            self.original_max_seq_len = config.max_position_embeddings\n\n        self.config = config\n        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]\n\n        inv_freq, self.attention_scaling = self.rope_init_fn(\n            self.config, device, **self.rope_kwargs\n        )\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        self.original_inv_freq = self.inv_freq\n\n    def _dynamic_frequency_update(self, position_ids, device):\n        \"\"\"\n        dynamic RoPE layers should recompute `inv_freq` in the following situations:\n        1 - growing beyond the cached sequence length (allow scaling)\n        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)\n        \"\"\"\n        seq_len = torch.max(position_ids) + 1\n        # seq_len = position_ids[0, -1] + 1\n        if seq_len > self.max_seq_len_cached:  # growth\n            inv_freq, self.attention_scaling = self.rope_init_fn(\n                self.config, device, seq_len=seq_len, **self.rope_kwargs\n            )\n   
         self.register_buffer(\n                \"inv_freq\", inv_freq, persistent=False\n            )  # TODO joao: may break with compilation\n            self.max_seq_len_cached = seq_len\n\n        if (\n            seq_len < self.original_max_seq_len\n            and self.max_seq_len_cached > self.original_max_seq_len\n        ):  # reset\n            self.register_buffer(\"inv_freq\", self.original_inv_freq, persistent=False)\n            self.max_seq_len_cached = self.original_max_seq_len\n\n    @torch.no_grad()\n    def forward(self, x, position_ids):\n        # if \"dynamic\" in self.rope_type:\n        #     self._dynamic_frequency_update(position_ids, device=x.device)\n\n        # Core RoPE block\n        inv_freq_expanded = (\n            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)\n        )\n        position_ids_expanded = position_ids[:, None, :].float()\n        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)\n        device_type = x.device.type\n        device_type = (\n            device_type\n            if isinstance(device_type, str) and device_type != \"mps\"\n            else \"cpu\"\n        )\n        with torch.autocast(device_type=device_type, enabled=False):\n            freqs = (\n                inv_freq_expanded.float() @ position_ids_expanded.float()\n            ).transpose(1, 2)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos()\n            sin = emb.sin()\n\n        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention\n        cos = cos * self.attention_scaling\n        sin = sin * self.attention_scaling\n\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)\n\n\nclass LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):\n    \"\"\"LlamaRotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev\"\"\"\n\n    def __init__(self, *args, **kwargs):\n        logger.warning_once(\n            \"`LlamaLinearScalingRotaryEmbedding` is deprecated and will be removed in v4.45. Please use \"\n            \"`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__).\"\n        )\n        kwargs[\"rope_type\"] = \"linear\"\n        super().__init__(*args, **kwargs)\n\n\nclass LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):\n    \"\"\"LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla\"\"\"\n\n    def __init__(self, *args, **kwargs):\n        logger.warning_once(\n            \"`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated and will be removed in v4.45. Please use \"\n            \"`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to \"\n            \"__init__).\"\n        )\n        kwargs[\"rope_type\"] = \"dynamic\"\n        super().__init__(*args, **kwargs)\n\n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n    Args:\n        q (`torch.Tensor`): The query tensor.\n        k (`torch.Tensor`): The key tensor.\n        cos (`torch.Tensor`): The cosine part of the rotary embedding.\n        sin (`torch.Tensor`): The sine part of the rotary embedding.\n        position_ids (`torch.Tensor`, *optional*):\n            Deprecated and unused.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n            sin[position_ids] so that they can be properly 
broadcasted to the dimensions of q and k. For example, note\n            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos.unsqueeze(unsqueeze_dim)\n    sin = sin.unsqueeze(unsqueeze_dim)\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\nclass LlamaMLP(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.intermediate_size = config.intermediate_size\n        self.gate_proj = nn.Linear(\n            self.hidden_size, self.intermediate_size, bias=config.mlp_bias\n        )\n        self.up_proj = nn.Linear(\n            self.hidden_size, self.intermediate_size, bias=config.mlp_bias\n        )\n        self.down_proj = nn.Linear(\n            self.intermediate_size, self.hidden_size, bias=config.mlp_bias\n        )\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        if self.config.pretraining_tp > 1:\n            slice = self.intermediate_size // self.config.pretraining_tp\n            gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)\n            up_proj_slices = self.up_proj.weight.split(slice, dim=0)\n            down_proj_slices = self.down_proj.weight.split(slice, dim=1)\n\n            gate_proj = torch.cat(\n                [\n                    F.linear(x, gate_proj_slices[i])\n                    for i in 
range(self.config.pretraining_tp)\n                ],\n                dim=-1,\n            )\n            up_proj = torch.cat(\n                [\n                    F.linear(x, up_proj_slices[i])\n                    for i in range(self.config.pretraining_tp)\n                ],\n                dim=-1,\n            )\n\n            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)\n            down_proj = [\n                F.linear(intermediate_states[i], down_proj_slices[i])\n                for i in range(self.config.pretraining_tp)\n            ]\n            down_proj = sum(down_proj)\n        else:\n            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n\n        return down_proj\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(\n        batch, num_key_value_heads, n_rep, slen, head_dim\n    )\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\nclass LlamaAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):\n        super().__init__()\n        self.config = config\n        self.layer_idx = layer_idx\n        if layer_idx is None:\n            logger.warning_once(\n                f\"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will \"\n                \"lead to errors during the forward call if caching is used. 
Please make sure to provide a `layer_idx` \"\n                \"when creating this class.\"\n            )\n\n        self.attention_dropout = config.attention_dropout\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.hidden_size // self.num_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.num_key_value_groups = self.num_heads // self.num_key_value_heads\n        self.max_position_embeddings = config.max_position_embeddings\n        self.rope_theta = config.rope_theta\n        self.is_causal = True\n\n        if (self.head_dim * self.num_heads) != self.hidden_size:\n            raise ValueError(\n                f\"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}\"\n                f\" and `num_heads`: {self.num_heads}).\"\n            )\n\n        self.q_proj = nn.Linear(\n            self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias\n        )\n        self.k_proj = nn.Linear(\n            self.hidden_size,\n            self.num_key_value_heads * self.head_dim,\n            bias=config.attention_bias,\n        )\n        self.v_proj = nn.Linear(\n            self.hidden_size,\n            self.num_key_value_heads * self.head_dim,\n            bias=config.attention_bias,\n        )\n        self.o_proj = nn.Linear(\n            self.hidden_size, self.hidden_size, bias=config.attention_bias\n        )\n\n        # TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers)\n        self.rotary_emb = LlamaRotaryEmbedding(config=self.config)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        
cache_position: Optional[torch.LongTensor] = None,\n        position_embeddings: Optional[\n            Tuple[torch.Tensor, torch.Tensor]\n        ] = None,  # will become mandatory in v4.45\n        **kwargs,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        bsz, q_len, _ = hidden_states.size()\n\n        if self.config.pretraining_tp > 1:\n            key_value_slicing = (\n                self.num_key_value_heads * self.head_dim\n            ) // self.config.pretraining_tp\n            query_slices = self.q_proj.weight.split(\n                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0\n            )\n            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)\n            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)\n\n            query_states = [\n                F.linear(hidden_states, query_slices[i])\n                for i in range(self.config.pretraining_tp)\n            ]\n            query_states = torch.cat(query_states, dim=-1)\n\n            key_states = [\n                F.linear(hidden_states, key_slices[i])\n                for i in range(self.config.pretraining_tp)\n            ]\n            key_states = torch.cat(key_states, dim=-1)\n\n            value_states = [\n                F.linear(hidden_states, value_slices[i])\n                for i in range(self.config.pretraining_tp)\n            ]\n            value_states = torch.cat(value_states, dim=-1)\n\n        else:\n            query_states = self.q_proj(hidden_states)\n            key_states = self.k_proj(hidden_states)\n            value_states = self.v_proj(hidden_states)\n\n        query_states = query_states.view(\n            bsz, q_len, self.num_heads, self.head_dim\n        ).transpose(1, 2)\n        key_states = key_states.view(\n            bsz, q_len, self.num_key_value_heads, self.head_dim\n        ).transpose(1, 2)\n        value_states = value_states.view(\n            
bsz, q_len, self.num_key_value_heads, self.head_dim\n        ).transpose(1, 2)\n\n        if position_embeddings is None:\n            logger.warning_once(\n                \"The attention layers in this model are transitioning from computing the RoPE embeddings internally \"\n                \"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed \"\n                \"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be \"\n                \"removed and `position_embeddings` will be mandatory.\"\n            )\n            cos, sin = self.rotary_emb(value_states, position_ids)\n        else:\n            cos, sin = position_embeddings\n        query_states, key_states = apply_rotary_pos_emb(\n            query_states, key_states, cos, sin\n        )\n\n        if past_key_value is not None:\n            # sin and cos are specific to RoPE models; cache_position needed for the static cache\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}\n            key_states, value_states = past_key_value.update(\n                key_states, value_states, self.layer_idx, cache_kwargs\n            )\n\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n        attn_weights = torch.matmul(\n            query_states, key_states.transpose(2, 3)\n        ) / math.sqrt(self.head_dim)\n\n        if attention_mask is not None:  # no matter the length, we just slice it\n            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]\n            attn_weights = attn_weights + causal_mask\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(\n            attn_weights, dim=-1, dtype=torch.float32\n        ).to(query_states.dtype)\n        attn_weights = nn.functional.dropout(\n            attn_weights, p=self.attention_dropout, 
training=self.training\n        )\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n\n        attn_output = attn_output.reshape(bsz, q_len, -1)\n\n        if self.config.pretraining_tp > 1:\n            attn_output = attn_output.split(\n                self.hidden_size // self.config.pretraining_tp, dim=2\n            )\n            o_proj_slices = self.o_proj.weight.split(\n                self.hidden_size // self.config.pretraining_tp, dim=1\n            )\n            attn_output = sum(\n                [\n                    F.linear(attn_output[i], o_proj_slices[i])\n                    for i in range(self.config.pretraining_tp)\n                ]\n            )\n        else:\n            attn_output = self.o_proj(attn_output)\n\n        if not output_attentions:\n            attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n\nclass LlamaFlashAttention2(LlamaAttention):\n    \"\"\"\n    Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stay\n    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of\n    flash attention and deal with padding tokens in case the input contains any of them.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n        # TODO: Should be removed once Flash Attention for ROCm is bumped to 2.1.\n        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. 
This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.\n        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).\n        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.LongTensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        position_embeddings: Optional[\n            Tuple[torch.Tensor, torch.Tensor]\n        ] = None,  # will become mandatory in v4.45\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        if isinstance(past_key_value, StaticCache):\n            raise ValueError(\n                \"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` \"\n                \"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers\"\n            )\n\n        output_attentions = False\n\n        bsz, q_len, _ = hidden_states.size()\n\n        query_states = self.q_proj(hidden_states)\n        key_states = self.k_proj(hidden_states)\n        value_states = self.v_proj(hidden_states)\n\n        # Flash attention requires the input to have the shape\n        # batch_size x seq_length x num_heads x head_dim. The transposes below move the heads axis\n        # forward for RoPE and the cache update and are undone before the flash attention call.\n        query_states = query_states.view(\n            bsz, q_len, self.num_heads, self.head_dim\n        ).transpose(1, 2)\n        key_states = key_states.view(\n            bsz, q_len, self.num_key_value_heads, self.head_dim\n        ).transpose(1, 2)\n        
value_states = value_states.view(\n            bsz, q_len, self.num_key_value_heads, self.head_dim\n        ).transpose(1, 2)\n\n        if position_embeddings is None:\n            logger.warning_once(\n                \"The attention layers in this model are transitioning from computing the RoPE embeddings internally \"\n                \"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed \"\n                \"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be \"\n                \"removed and `position_embeddings` will be mandatory.\"\n            )\n            cos, sin = self.rotary_emb(value_states, position_ids)\n        else:\n            cos, sin = position_embeddings\n        query_states, key_states = apply_rotary_pos_emb(\n            query_states, key_states, cos, sin\n        )\n\n        if past_key_value is not None:\n            # sin and cos are specific to RoPE models; cache_position needed for the static cache\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}\n            key_states, value_states = past_key_value.update(\n                key_states, value_states, self.layer_idx, cache_kwargs\n            )\n\n        # TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache\n        # to be able to avoid many of these transpose/reshape/view.\n        query_states = query_states.transpose(1, 2)\n        key_states = key_states.transpose(1, 2)\n        value_states = value_states.transpose(1, 2)\n\n        dropout_rate = self.attention_dropout if self.training else 0.0\n\n        # In PEFT, usually we cast the layer norms to float32 for training stability reasons\n        # therefore the input hidden states get silently cast to float32. 
Hence, we need\n        # cast them back in the correct dtype just to be sure everything works as expected.\n        # This might slowdown training & inference so it is recommended to not cast the LayerNorms\n        # in fp32. (LlamaRMSNorm handles it correctly)\n\n        input_dtype = query_states.dtype\n        if input_dtype == torch.float32:\n            if torch.is_autocast_enabled():\n                target_dtype = torch.get_autocast_gpu_dtype()\n            # Handle the case where the model is quantized\n            elif hasattr(self.config, \"_pre_quantization_dtype\"):\n                target_dtype = self.config._pre_quantization_dtype\n            else:\n                target_dtype = self.q_proj.weight.dtype\n\n            logger.warning_once(\n                f\"The input hidden states seems to be silently casted in float32, this might be related to\"\n                f\" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in\"\n                f\" {target_dtype}.\"\n            )\n\n            query_states = query_states.to(target_dtype)\n            key_states = key_states.to(target_dtype)\n            value_states = value_states.to(target_dtype)\n\n        attn_output = _flash_attention_forward(\n            query_states,\n            key_states,\n            value_states,\n            attention_mask,\n            q_len,\n            dropout=dropout_rate,\n            sliding_window=getattr(self, \"sliding_window\", None),\n            use_top_left_mask=self._flash_attn_uses_top_left_mask,\n            is_causal=self.is_causal,\n        )\n\n        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()\n        attn_output = self.o_proj(attn_output)\n\n        if not output_attentions:\n            attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n\nclass LlamaSdpaAttention(LlamaAttention):\n    \"\"\"\n    Llama attention module using 
torch.nn.functional.scaled_dot_product_attention. This module inherits from\n    `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to\n    SDPA API.\n    \"\"\"\n\n    # Adapted from LlamaAttention.forward\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        position_embeddings: Optional[\n            Tuple[torch.Tensor, torch.Tensor]\n        ] = None,  # will become mandatory in v4.45\n        **kwargs,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        if output_attentions:\n            # TODO: Improve this warning with e.g. `model.config.attn_implementation = \"manual\"` once this is implemented.\n            logger.warning_once(\n                \"LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, \"\n                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. 
This warning can be removed using the argument `attn_implementation=\"eager\"` when loading the model.'\n            )\n            return super().forward(\n                hidden_states=hidden_states,\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n                past_key_value=past_key_value,\n                output_attentions=output_attentions,\n                use_cache=use_cache,\n                cache_position=cache_position,\n                position_embeddings=position_embeddings,\n            )\n\n        bsz, q_len, _ = hidden_states.size()\n\n        query_states = self.q_proj(hidden_states)\n        key_states = self.k_proj(hidden_states)\n        value_states = self.v_proj(hidden_states)\n\n        query_states = query_states.view(\n            bsz, q_len, self.num_heads, self.head_dim\n        ).transpose(1, 2)\n        key_states = key_states.view(\n            bsz, q_len, self.num_key_value_heads, self.head_dim\n        ).transpose(1, 2)\n        value_states = value_states.view(\n            bsz, q_len, self.num_key_value_heads, self.head_dim\n        ).transpose(1, 2)\n\n        if position_embeddings is None:\n            logger.warning_once(\n                \"The attention layers in this model are transitioning from computing the RoPE embeddings internally \"\n                \"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed \"\n                \"`position_embeddings` (Tuple of tensors, containing cos and sin). 
In v4.45 `position_ids` will be \"\n                \"removed and `position_embeddings` will be mandatory.\"\n            )\n            cos, sin = self.rotary_emb(value_states, position_ids)\n        else:\n            cos, sin = position_embeddings\n        query_states, key_states = apply_rotary_pos_emb(\n            query_states, key_states, cos, sin\n        )\n\n        if past_key_value is not None:\n            # sin and cos are specific to RoPE models; cache_position needed for the static cache\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}\n            key_states, value_states = past_key_value.update(\n                key_states, value_states, self.layer_idx, cache_kwargs\n            )\n\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n        causal_mask = attention_mask\n        if attention_mask is not None:\n            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]\n\n        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,\n        # Reference: https://github.com/pytorch/pytorch/issues/112577.\n        if query_states.device.type == \"cuda\" and causal_mask is not None:\n            query_states = query_states.contiguous()\n            key_states = key_states.contiguous()\n            value_states = value_states.contiguous()\n\n        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment\n        # in SDPA to support both torch.compile's dynamic shapes and full graph options. 
An inline conditional prevents dynamic shapes from compiling.\n        is_causal = True if causal_mask is None and q_len > 1 else False\n\n        attn_output = torch.nn.functional.scaled_dot_product_attention(\n            query_states,\n            key_states,\n            value_states,\n            attn_mask=causal_mask,\n            dropout_p=self.attention_dropout if self.training else 0.0,\n            is_causal=is_causal,\n        )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.view(bsz, q_len, -1)\n\n        attn_output = self.o_proj(attn_output)\n\n        return attn_output, None, past_key_value\n\n\nLLAMA_ATTENTION_CLASSES = {\n    \"eager\": LlamaAttention,\n    \"flash_attention_2\": LlamaFlashAttention2,\n    \"sdpa\": LlamaSdpaAttention,\n}\n\n\nclass LlamaDecoderLayer(nn.Module):\n    def __init__(self, config: LlamaConfig, layer_idx: int):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n\n        self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](\n            config=config, layer_idx=layer_idx\n        )\n\n        self.mlp = LlamaMLP(config)\n        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_attention_layernorm = LlamaRMSNorm(\n            config.hidden_size, eps=config.rms_norm_eps\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        position_embeddings: Optional[\n            Tuple[torch.Tensor, torch.Tensor]\n        ] = None,  # will become mandatory in v4.45\n        **kwargs,\n    ) -> Tuple[\n        torch.FloatTensor, 
Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]\n    ]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*):\n                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,\n                query_sequence_length, key_sequence_length)` if default attention is used.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):\n                Indices depicting the position of the input sequence tokens in the sequence\n            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):\n                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,\n                with `head_dim` being the embedding dimension of each attention head.\n            kwargs (`dict`, *optional*):\n                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code\n                into the model\n        \"\"\"\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            
position_ids=position_ids,\n            past_key_value=past_key_value,\n            output_attentions=output_attentions,\n            use_cache=use_cache,\n            cache_position=cache_position,\n            position_embeddings=position_embeddings,\n            **kwargs,\n        )\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (self_attn_weights,)\n\n        if use_cache:\n            outputs += (present_key_value,)\n\n        return outputs\n\n\nLLAMA_START_DOCSTRING = r\"\"\"\n    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the\n    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads\n    etc.)\n\n    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage\n    and behavior.\n\n    Parameters:\n        config ([`LlamaConfig`]):\n            Model configuration class with all the parameters of the model. Initializing with a config file does not\n            load the weights associated with the model, only the configuration. 
Check out the\n            [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare LLaMA Model outputting raw hidden-states without any specific head on top.\",\n    LLAMA_START_DOCSTRING,\n)\nclass LlamaPreTrainedModel(PreTrainedModel):\n    config_class = LlamaConfig\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"LlamaDecoderLayer\"]\n    _skip_keys_device_placement = [\"past_key_values\"]\n    _supports_flash_attn_2 = True\n    _supports_sdpa = True\n    _supports_cache_class = True\n    _supports_quantized_cache = True\n    _supports_static_cache = True\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n\n\nLLAMA_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. 
Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see\n            `past_key_values`).\n\n            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]\n            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more\n            information on the default strategy.\n\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,\n            config.n_positions - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):\n            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention\n            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`\n            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.\n\n            Two formats are allowed:\n            - a [`~cache_utils.Cache`] instance;\n            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of\n            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). 
This is also known as the legacy\n            cache format.\n\n            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the\n            legacy cache format will be returned.\n\n            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't\n            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`\n            of shape `(batch_size, sequence_length)`.\n        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This\n            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n            model's internal embedding lookup matrix.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n            `past_key_values`).\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for\n            more detail.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):\n            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,\n            this tensor is not affected by padding. 
It is used to update the cache in the correct position and to infer\n            the complete sequence length.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare LLaMA Model outputting raw hidden-states without any specific head on top.\",\n    LLAMA_START_DOCSTRING,\n)\nclass LlamaModel(LlamaPreTrainedModel):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]\n\n    Args:\n        config: LlamaConfig\n    \"\"\"\n\n    def __init__(self, config: LlamaConfig):\n        super().__init__(config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(\n            config.vocab_size, config.hidden_size, self.padding_idx\n        )\n        self.layers = nn.ModuleList(\n            [\n                LlamaDecoderLayer(config, layer_idx)\n                for layer_idx in range(config.num_hidden_layers)\n            ]\n        )\n        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.rotary_emb = LlamaRotaryEmbedding(config=config)\n        self.gradient_checkpointing = False\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: 
Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        if (input_ids is None) ^ (inputs_embeds is not None):\n            raise ValueError(\n                \"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one\"\n            )\n\n        if self.gradient_checkpointing and self.training and use_cache:\n            logger.warning_once(\n                \"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.\"\n            )\n            use_cache = False\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        return_legacy_cache = False\n        if (\n            use_cache and not isinstance(past_key_values, Cache) and not self.training\n        ):  # kept for BC (non `Cache` `past_key_values` inputs)\n            return_legacy_cache = True\n            past_key_values = DynamicCache.from_legacy_cache(past_key_values)\n            logger.warning_once(\n                \"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. 
\"\n                \"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\"\n            )\n\n        if cache_position is None:\n            past_seen_tokens = (\n                past_key_values.get_seq_length() if past_key_values is not None else 0\n            )\n            cache_position = torch.arange(\n                past_seen_tokens,\n                past_seen_tokens + inputs_embeds.shape[1],\n                device=inputs_embeds.device,\n            )\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0)\n\n        causal_mask = self._update_causal_mask(\n            attention_mask,\n            inputs_embeds,\n            cache_position,\n            past_key_values,\n            output_attentions,\n        )\n        hidden_states = inputs_embeds\n\n        # create position embeddings to be shared across the decoder layers\n        position_embeddings = self.rotary_emb(hidden_states, position_ids)\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        next_decoder_cache = None\n\n        for decoder_layer in self.layers:\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            if self.gradient_checkpointing and self.training:\n                layer_outputs = self._gradient_checkpointing_func(\n                    decoder_layer.__call__,\n                    hidden_states,\n                    causal_mask,\n                    position_ids,\n                    past_key_values,\n                    output_attentions,\n                    use_cache,\n                    cache_position,\n                    position_embeddings,\n                )\n            else:\n                layer_outputs = decoder_layer(\n                    hidden_states,\n                    
attention_mask=causal_mask,\n                    position_ids=position_ids,\n                    past_key_value=past_key_values,\n                    output_attentions=output_attentions,\n                    use_cache=use_cache,\n                    cache_position=cache_position,\n                    position_embeddings=position_embeddings,\n                )\n\n            hidden_states = layer_outputs[0]\n\n            if use_cache:\n                next_decoder_cache = layer_outputs[2 if output_attentions else 1]\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n        hidden_states = self.norm(hidden_states)\n\n        # add hidden states from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        next_cache = next_decoder_cache if use_cache else None\n        if return_legacy_cache:\n            next_cache = next_cache.to_legacy_cache()\n\n        if not return_dict:\n            return tuple(\n                v\n                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]\n                if v is not None\n            )\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=next_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n        )\n\n    def _update_causal_mask(\n        self,\n        attention_mask: torch.Tensor,\n        input_tensor: torch.Tensor,\n        cache_position: torch.Tensor,\n        past_key_values: Cache,\n        output_attentions: bool,\n    ):\n        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static\n        # KV cache is used. 
This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.\n        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using\n        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114\n\n        if self.config._attn_implementation == \"flash_attention_2\":\n            if attention_mask is not None and 0.0 in attention_mask:\n                return attention_mask\n            return None\n\n        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in\n        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail\n        # to infer the attention mask.\n        past_seen_tokens = (\n            past_key_values.get_seq_length() if past_key_values is not None else 0\n        )\n        using_static_cache = isinstance(past_key_values, StaticCache)\n\n        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward\n        if (\n            self.config._attn_implementation == \"sdpa\"\n            and not using_static_cache\n            and not output_attentions\n        ):\n            if AttentionMaskConverter._ignore_causal_mask_sdpa(\n                attention_mask,\n                inputs_embeds=input_tensor,\n                past_key_values_length=past_seen_tokens,\n                is_training=self.training,\n            ):\n                return None\n\n        dtype, device = input_tensor.dtype, input_tensor.device\n        min_dtype = torch.finfo(dtype).min\n        sequence_length = input_tensor.shape[1]\n        if using_static_cache:\n            target_length = past_key_values.get_max_length()\n        else:\n            target_length = (\n                attention_mask.shape[-1]\n                if 
isinstance(attention_mask, torch.Tensor)\n                else past_seen_tokens + sequence_length + 1\n            )\n\n        if attention_mask is not None and attention_mask.dim() == 4:\n            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing\n            if attention_mask.max() != 0:\n                raise ValueError(\n                    \"Custom 4D attention mask should be passed in inverted form with max==0\"\n                )\n            causal_mask = attention_mask\n        else:\n            causal_mask = torch.full(\n                (sequence_length, target_length),\n                fill_value=min_dtype,\n                dtype=dtype,\n                device=device,\n            )\n            if sequence_length != 1:\n                causal_mask = torch.triu(causal_mask, diagonal=1)\n            causal_mask *= torch.arange(\n                target_length, device=device\n            ) > cache_position.reshape(-1, 1)\n            causal_mask = causal_mask[None, None, :, :].expand(\n                input_tensor.shape[0], 1, -1, -1\n            )\n            if attention_mask is not None:\n                causal_mask = (\n                    causal_mask.clone()\n                )  # copy to contiguous memory for in-place edit\n                mask_length = attention_mask.shape[-1]\n                padding_mask = (\n                    causal_mask[:, :, :, :mask_length]\n                    + attention_mask[:, None, None, :]\n                )\n                padding_mask = padding_mask == 0\n                causal_mask[:, :, :, :mask_length] = causal_mask[\n                    :, :, :, :mask_length\n                ].masked_fill(padding_mask, min_dtype)\n        if (\n            self.config._attn_implementation == \"sdpa\"\n            and attention_mask is not None\n            and attention_mask.device.type == \"cuda\"\n            and not output_attentions\n        ):\n            # 
Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when\n            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.\n            # Details: https://github.com/pytorch/pytorch/issues/110213\n            causal_mask = AttentionMaskConverter._unmask_unattended(\n                causal_mask, min_dtype\n            )\n\n        return causal_mask\n\n\nclass LlamaForCausalLM(LlamaPreTrainedModel):\n    _tied_weights_keys = [\"lm_head.weight\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = LlamaModel(config)\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def set_decoder(self, decoder):\n        self.model = decoder\n\n    def get_decoder(self):\n        return self.model\n\n    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)\n    @replace_return_docstrings(\n        output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC\n    )\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n      
  output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, LlamaForCausalLM\n\n        >>> model = LlamaForCausalLM.from_pretrained(\"meta-llama/Llama-2-7b-hf\")\n        >>> tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Llama-2-7b-hf\")\n\n        >>> prompt = \"Hey, are you conscious? Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you conscious? 
Can you talk to me?\\nI'm not conscious, but I can talk to you.\"\n        ```\"\"\"\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            cache_position=cache_position,\n        )\n\n        hidden_states = outputs[0]\n        if self.config.pretraining_tp > 1:\n            lm_head_slices = self.lm_head.weight.split(\n                self.vocab_size // self.config.pretraining_tp, dim=0\n            )\n            logits = [\n                F.linear(hidden_states, lm_head_slices[i])\n                for i in range(self.config.pretraining_tp)\n            ]\n            logits = torch.cat(logits, dim=-1)\n        else:\n            logits = self.lm_head(hidden_states)\n        # logits = logits.float()\n\n        loss = None\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            shift_logits = shift_logits.view(-1, 
self.config.vocab_size)\n            shift_labels = shift_labels.view(-1)\n            # Enable model parallelism\n            shift_labels = shift_labels.to(shift_logits.device)\n            loss = loss_fct(shift_logits, shift_labels)\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            return (loss,) + output if loss is not None else output\n\n        return CausalLMOutputWithPast(\n            loss=loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n    def prepare_inputs_for_generation(\n        self,\n        input_ids,\n        past_key_values=None,\n        attention_mask=None,\n        inputs_embeds=None,\n        cache_position=None,\n        position_ids=None,\n        use_cache=True,\n        **kwargs,\n    ):\n        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens\n        # Exception 1: when passing input_embeds, input_ids may be missing entries\n        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here\n        if past_key_values is not None:\n            if inputs_embeds is not None:  # Exception 1\n                input_ids = input_ids[:, -cache_position.shape[0] :]\n            elif (\n                input_ids.shape[1] != cache_position.shape[0]\n            ):  # Default case (the \"else\", a no op, is Exception 2)\n                input_ids = input_ids[:, cache_position]\n\n        if attention_mask is not None and position_ids is None:\n            # create position_ids on the fly for batch generation\n            position_ids = attention_mask.long().cumsum(-1) - 1\n            position_ids.masked_fill_(attention_mask == 0, 1)\n            if past_key_values:\n                position_ids = position_ids[:, -input_ids.shape[1] :]\n\n        # if 
`inputs_embeds` are passed, we only want to use them in the 1st generation step\n        if inputs_embeds is not None and cache_position[0] == 0:\n            model_inputs = {\"inputs_embeds\": inputs_embeds}\n        else:\n            model_inputs = {\n                \"input_ids\": input_ids.contiguous()\n            }  # `contiguous()` needed for compilation use cases\n\n        model_inputs.update(\n            {\n                \"position_ids\": position_ids,\n                \"cache_position\": cache_position,\n                \"past_key_values\": past_key_values,\n                \"use_cache\": use_cache,\n                \"attention_mask\": attention_mask,\n            }\n        )\n        return model_inputs\n\n\n@add_start_docstrings(\n    \"\"\"\n    The Llama Model transformer with a sequence classification head on top (linear layer).\n\n    [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models\n    (e.g. GPT-2) do.\n\n    Since it does classification on the last token, it needs to know the position of the last token. If a\n    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If\n    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. 
Since it cannot guess the\n    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in\n    each row of the batch).\n    \"\"\",\n    LLAMA_START_DOCSTRING,\n)\nclass LlamaForSequenceClassification(LlamaPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.model = LlamaModel(config)\n        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n            config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        transformer_outputs = self.model(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = transformer_outputs[0]\n        logits = self.score(hidden_states)\n\n        if input_ids is not None:\n            batch_size = input_ids.shape[0]\n        else:\n            batch_size = inputs_embeds.shape[0]\n\n        if self.config.pad_token_id is None and batch_size != 1:\n            raise ValueError(\n                \"Cannot handle batch sizes > 1 if no padding token is defined.\"\n            )\n        if self.config.pad_token_id is None:\n            sequence_lengths = -1\n        else:\n            if input_ids is not None:\n                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility\n                sequence_lengths = (\n                    torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1\n                )\n                sequence_lengths = sequence_lengths % input_ids.shape[-1]\n                sequence_lengths = sequence_lengths.to(logits.device)\n            else:\n                sequence_lengths = -1\n\n        pooled_logits = logits[\n            torch.arange(batch_size, device=logits.device), sequence_lengths\n        ]\n\n        loss = None\n        if labels is not None:\n            labels = labels.to(logits.device)\n            
if self.config.problem_type is None:\n                if self.num_labels == 1:\n                    self.config.problem_type = \"regression\"\n                elif self.num_labels > 1 and (\n                    labels.dtype == torch.long or labels.dtype == torch.int\n                ):\n                    self.config.problem_type = \"single_label_classification\"\n                else:\n                    self.config.problem_type = \"multi_label_classification\"\n\n            if self.config.problem_type == \"regression\":\n                loss_fct = MSELoss()\n                if self.num_labels == 1:\n                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())\n                else:\n                    loss = loss_fct(pooled_logits, labels)\n            elif self.config.problem_type == \"single_label_classification\":\n                loss_fct = CrossEntropyLoss()\n                loss = loss_fct(\n                    pooled_logits.view(-1, self.num_labels), labels.view(-1)\n                )\n            elif self.config.problem_type == \"multi_label_classification\":\n                loss_fct = BCEWithLogitsLoss()\n                loss = loss_fct(pooled_logits, labels)\n        if not return_dict:\n            output = (pooled_logits,) + transformer_outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return SequenceClassifierOutputWithPast(\n            loss=loss,\n            logits=pooled_logits,\n            past_key_values=transformer_outputs.past_key_values,\n            hidden_states=transformer_outputs.hidden_states,\n            attentions=transformer_outputs.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\nThe Llama Model transformer with a span classification head on top for extractive question-answering tasks like\nSQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).\n    \"\"\",\n    
LLAMA_START_DOCSTRING,\n)\nclass LlamaForQuestionAnswering(LlamaPreTrainedModel):\n    base_model_prefix = \"transformer\"\n\n    # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama\n    def __init__(self, config):\n        super().__init__(config)\n        self.transformer = LlamaModel(config)\n        self.qa_outputs = nn.Linear(config.hidden_size, 2)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.transformer.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.transformer.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        start_positions: Optional[torch.LongTensor] = None,\n        end_positions: Optional[torch.LongTensor] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, QuestionAnsweringModelOutput]:\n        r\"\"\"\n        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for position (index) of the start of the labelled span for computing the token classification loss.\n            Positions are clamped to the length of the sequence (`sequence_length`). 
Position outside of the sequence\n            are not taken into account for computing the loss.\n        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for position (index) of the end of the labelled span for computing the token classification loss.\n            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence\n            are not taken into account for computing the loss.\n        \"\"\"\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        outputs = self.transformer(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        sequence_output = outputs[0]\n\n        logits = self.qa_outputs(sequence_output)\n        start_logits, end_logits = logits.split(1, dim=-1)\n        start_logits = start_logits.squeeze(-1).contiguous()\n        end_logits = end_logits.squeeze(-1).contiguous()\n\n        total_loss = None\n        if start_positions is not None and end_positions is not None:\n            # If we are on multi-GPU, split add a dimension\n            if len(start_positions.size()) > 1:\n                start_positions = start_positions.squeeze(-1).to(start_logits.device)\n            if len(end_positions.size()) > 1:\n                end_positions = end_positions.squeeze(-1).to(end_logits.device)\n            # sometimes the start/end positions are outside our model inputs, we ignore these terms\n            ignored_index = start_logits.size(1)\n            start_positions = start_positions.clamp(0, ignored_index)\n            end_positions = end_positions.clamp(0, ignored_index)\n\n   
         loss_fct = CrossEntropyLoss(ignore_index=ignored_index)\n            start_loss = loss_fct(start_logits, start_positions)\n            end_loss = loss_fct(end_logits, end_positions)\n            total_loss = (start_loss + end_loss) / 2\n\n        if not return_dict:\n            output = (start_logits, end_logits) + outputs[2:]\n            return ((total_loss,) + output) if total_loss is not None else output\n\n        return QuestionAnsweringModelOutput(\n            loss=total_loss,\n            start_logits=start_logits,\n            end_logits=end_logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states\n    output) e.g. for Named-Entity-Recognition (NER) tasks.\n    \"\"\",\n    LLAMA_START_DOCSTRING,\n)\nclass LlamaForTokenClassification(LlamaPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.model = LlamaModel(config)\n        if getattr(config, \"classifier_dropout\", None) is not None:\n            classifier_dropout = config.classifier_dropout\n        elif getattr(config, \"hidden_dropout\", None) is not None:\n            classifier_dropout = config.hidden_dropout\n        else:\n            classifier_dropout = 0.1\n        self.dropout = nn.Dropout(classifier_dropout)\n        self.score = nn.Linear(config.hidden_size, config.num_labels)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n  
      attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, TokenClassifierOutput]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        outputs = self.model(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        sequence_output = outputs[0]\n        sequence_output = self.dropout(sequence_output)\n        logits = self.score(sequence_output)\n\n        loss = None\n        if labels is not None:\n            loss_fct = CrossEntropyLoss()\n            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))\n\n        if not return_dict:\n            output = (logits,) + outputs[2:]\n            return ((loss,) + output) if loss is not None else output\n\n        return 
TokenClassifierOutput(\n            loss=loss,\n            logits=logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n"
  },
  {
    "path": "archive/ktransformers/models/modeling_mixtral.py",
    "content": "# coding=utf-8\n'''\nDescription  : \nAuthor       : kkk1nak0\nDate         : 2024-07-29 02:58:57\nVersion      : 1.0.0\nLastEditors  : kkk1nak0\nLastEditTime : 2024-08-02 06:08:34\n'''\n\n# Adapted from \n# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py\n# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.\n# Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch Mixtral model.\"\"\"\n\nimport inspect \nimport math\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n\nfrom transformers.activations import ACT2FN\nfrom transformers.cache_utils import Cache, DynamicCache, StaticCache\nfrom transformers.modeling_attn_mask_utils import (\n    AttentionMaskConverter,\n    _prepare_4d_causal_attention_mask,\n)\nfrom transformers.modeling_outputs import (\n    MoeCausalLMOutputWithPast,\n    MoeModelOutputWithPast,\n    
SequenceClassifierOutputWithPast,\n    TokenClassifierOutput,\n)\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13\nfrom transformers.utils import (\n    add_start_docstrings,\n    add_start_docstrings_to_model_forward,\n    is_flash_attn_2_available,\n    logging,\n    replace_return_docstrings,\n)\nfrom transformers.utils.import_utils import is_torch_fx_available\nfrom transformers.models.mixtral.configuration_mixtral import MixtralConfig\n\n\nif is_flash_attn_2_available():\n    from flash_attn import flash_attn_varlen_func, flash_attn_func, flash_attn_with_kvcache\n    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa\n\n    _flash_supports_window_size = \"window_size\" in list(inspect.signature(flash_attn_func).parameters)\n\n# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.\n# It means that the function will not be traced through and will simply appear as a node in the graph.\nif is_torch_fx_available():\n    if not is_torch_greater_or_equal_than_1_13:\n        import torch.fx\n\n    _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)\n\n\nlogger = logging.get_logger(__name__)\n\n_CONFIG_FOR_DOC = \"MixtralConfig\"\n\n\ndef load_balancing_loss_func(\n    gate_logits: torch.Tensor, num_experts: Optional[int] = None, top_k=2, attention_mask: Optional[torch.Tensor] = None\n) -> Union[torch.Tensor, int]:\n    r\"\"\"\n    Computes auxiliary load balancing loss as in Switch Transformer - implemented in PyTorch.\n\n    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss\n    function presented in equations (4) - (6) of the paper. 
It aims at penalizing cases where the routing between\n    experts is too unbalanced.\n\n    Args:\n        gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]]):\n            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of\n            shape [batch_size X sequence_length, num_experts].\n        attention_mask (`torch.Tensor`, None):\n            The attention_mask used in forward function\n            shape [batch_size X sequence_length] if not None.\n        num_experts (`int`, *optional*):\n            Number of experts\n        top_k (`int`, *optional*, defaults to 2):\n            Number of experts each token is routed to.\n\n    Returns:\n        The auxiliary loss.\n    \"\"\"\n    if gate_logits is None or not isinstance(gate_logits, tuple):\n        return 0\n\n    if isinstance(gate_logits, tuple):\n        compute_device = gate_logits[0].device\n        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)\n\n    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)\n\n    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)\n\n    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)\n\n    if attention_mask is None:\n        # Compute the percentage of tokens routed to each expert\n        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)\n\n        # Compute the average probability of routing to these experts\n        router_prob_per_expert = torch.mean(routing_weights, dim=0)\n    else:\n        batch_size, sequence_length = attention_mask.shape\n        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)\n\n        # Compute the mask that masks all padding tokens as 0 with the same shape as expert_mask\n        expert_attention_mask = (\n            attention_mask[None, :, :, None, None]\n            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))\n            .reshape(-1, top_k, num_experts)\n            .to(compute_device)\n      
  )\n\n        # Compute the percentage of tokens routed to each experts\n        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(\n            expert_attention_mask, dim=0\n        )\n\n        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert\n        router_per_expert_attention_mask = (\n            attention_mask[None, :, :, None]\n            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))\n            .reshape(-1, num_experts)\n            .to(compute_device)\n        )\n\n        # Compute the average probability of routing to these experts\n        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(\n            router_per_expert_attention_mask, dim=0\n        )\n\n    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))\n    return overall_loss * num_experts\n\n\n# Copied from transformers.models.llama.modeling_llama._get_unpad_data\ndef _get_unpad_data(attention_mask):\n    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)\n    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()\n    max_seqlen_in_batch = seqlens_in_batch.max().item()\n    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))\n    return (\n        indices,\n        cu_seqlens,\n        max_seqlen_in_batch,\n    )\n\n\n# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral\nclass MixtralRMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        MixtralRMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n        hidden_states = 
hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n    def extra_repr(self):\n        return f\"{tuple(self.weight.shape)}, eps={self.variance_epsilon}\"\n\n\n# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral\n# TODO @longjie no longer copied from Mistral after static cache\nclass MixtralRotaryEmbedding(nn.Module):\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):\n        super().__init__()\n        \n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        # Build here to make `torch.jit.trace` work.\n        self.max_seq_len_cached = max_position_embeddings\n\n    @torch.no_grad()\n    def forward(self, x, position_ids):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)\n        position_ids_expanded = position_ids[:, None, :].float()\n\n        device_type = x.device.type\n        device_type = device_type if isinstance(device_type, str) and device_type != \"mps\" else \"cpu\"\n        with torch.autocast(device_type=device_type, enabled=False):\n            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos()\n            sin = emb.sin()\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)\n\n\n# Copied from transformers.models.llama.modeling_llama.rotate_half\ndef rotate_half(x):\n    
\"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\n# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb\n# TODO @longjie no longer copied from Mistral after static cache\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n    Args:\n        q (`torch.Tensor`): The query tensor.\n        k (`torch.Tensor`): The key tensor.\n        cos (`torch.Tensor`): The cosine part of the rotary embedding.\n        sin (`torch.Tensor`): The sine part of the rotary embedding.\n        position_ids (`torch.Tensor`):\n            The position indices of the tokens corresponding to the query and key tensors. For example, this can be\n            used to pass offsetted position ids when working with a KV-cache.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. 
Similarly, if q and k have\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos.unsqueeze(unsqueeze_dim)\n    sin = sin.unsqueeze(unsqueeze_dim)\n\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\n# Copied from transformers.models.llama.modeling_llama.repeat_kv\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\n# copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral\n# TODO @longjie no longer copied from Mistral after static cache\nclass MixtralAttention(nn.Module):\n    \"\"\"\n    Multi-headed attention from 'Attention Is All You Need' paper. 
Modified to use sliding window attention: Longformer\n    and \"Generating Long Sequences with Sparse Transformers\".\n    \"\"\"\n\n    def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None):\n        super().__init__()\n        self.config = config\n        self.layer_idx = layer_idx\n        if layer_idx is None:\n            logger.warning_once(\n                f\"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will \"\n                \"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` \"\n                \"when creating this class.\"\n            )\n\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.hidden_size // self.num_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.num_key_value_groups = self.num_heads // self.num_key_value_heads\n        self.max_position_embeddings = config.max_position_embeddings\n        self.rope_theta = config.rope_theta\n        self.is_causal = True\n        self.attention_dropout = config.attention_dropout\n\n        if (self.head_dim * self.num_heads) != self.hidden_size:\n            raise ValueError(\n                f\"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}\"\n                f\" and `num_heads`: {self.num_heads}).\"\n            )\n        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)\n        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)\n        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)\n        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)\n\n        self.rotary_emb = MixtralRotaryEmbedding(\n            self.head_dim,\n            
max_position_embeddings=self.max_position_embeddings,\n            base=self.rope_theta,\n        )\n\n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        bsz, q_len, _ = hidden_states.size()\n\n        query_states = self.q_proj(hidden_states)\n        key_states = self.k_proj(hidden_states)\n        value_states = self.v_proj(hidden_states)\n\n        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n\n        kv_seq_len = key_states.shape[-2]\n        if past_key_value is not None:\n            if self.layer_idx is None:\n                raise ValueError(\n                    f\"The cache structure has changed since version v4.36. 
If you are using {self.__class__.__name__} \"\n                    \"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class \"\n                    \"with a layer index.\"\n                )\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\n        cos, sin = self.rotary_emb(value_states, position_ids)\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n\n        if past_key_value is not None:\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n        # repeat k/v heads if n_kv_heads < n_heads\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)\n\n        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is\"\n                f\" {attn_weights.size()}\"\n            )\n\n        if attention_mask is not None:  # no matter the length, we just slice it\n            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]\n            attn_weights = attn_weights + causal_mask\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)\n        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):\n            raise 
ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)\n\n        attn_output = self.o_proj(attn_output)\n\n        if not output_attentions:\n            attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n\n# copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral\n# TODO @longjie no longer copied from Mistral after static cache\nclass MixtralFlashAttention2(MixtralAttention):\n    \"\"\"\n    Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays\n    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of\n    flash attention and deal with padding tokens in case the input contains any of them.\n    \"\"\"\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n    ):\n        bsz, q_len, _ = hidden_states.size()\n\n        query_states = self.q_proj(hidden_states)\n        key_states = self.k_proj(hidden_states)\n        value_states = self.v_proj(hidden_states)\n\n        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n\n        kv_seq_len = 
key_states.shape[-2]\n        if past_key_value is not None:\n            if self.layer_idx is None:\n                raise ValueError(\n                    f\"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} \"\n                    \"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class \"\n                    \"with a layer index.\"\n                )\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\n\n        cos, sin = self.rotary_emb(value_states, position_ids)\n\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n\n        use_sliding_windows = (\n            _flash_supports_window_size\n            and getattr(self.config, \"sliding_window\", None) is not None\n            and kv_seq_len > self.config.sliding_window\n            and self.config.use_sliding_window\n        )\n\n        if not _flash_supports_window_size:\n            logger.warning_once(\n                \"The current flash attention version does not support sliding window attention, for a more memory efficient implementation\"\n                \" make sure to upgrade flash-attn library.\"\n            )\n\n        if past_key_value is not None:\n            # Activate slicing cache only if the config has a value `sliding_windows` attribute\n            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0\n            if (\n                getattr(self.config, \"sliding_window\", None) is not None\n                and kv_seq_len > self.config.sliding_window\n                and cache_has_contents\n            ):\n                slicing_tokens = 1 - self.config.sliding_window\n\n                past_key = past_key_value[self.layer_idx][0]\n                past_value = past_key_value[self.layer_idx][1]\n\n                past_key = past_key[:, :, slicing_tokens:, :].contiguous()\n              
  past_value = past_value[:, :, slicing_tokens:, :].contiguous()\n\n                if past_key.shape[-2] != self.config.sliding_window - 1:\n                    raise ValueError(\n                        f\"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got\"\n                        f\" {past_key.shape}\"\n                    )\n\n                if attention_mask is not None:\n                    attention_mask = attention_mask[:, slicing_tokens:]\n                    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)\n\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n            # we slice the states for static kv cache to be supported in FA2. Not sure it's a must as compile fails\n            # for bsz == 1, avoid using slice to capture cuda graph\n            if cache_position is not None and q_len > 1:\n                key_states = key_states[:, :, : cache_position[-1] + 1, :]\n                value_states = value_states[:, :, : cache_position[-1] + 1, :]\n\n        # repeat k/v heads if n_kv_heads < n_heads\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n        dropout_rate = 0.0 if not self.training else self.attention_dropout\n\n        # In PEFT, usually we cast the layer norms in float32 for training stability reasons\n        # therefore the input hidden states gets silently casted in float32. 
Hence, we need to\n        # cast them back to the target dtype (e.g. float16) just to be sure everything works as expected.\n        input_dtype = query_states.dtype\n        if input_dtype == torch.float32:\n            if torch.is_autocast_enabled():\n                target_dtype = torch.get_autocast_gpu_dtype()\n            # Handle the case where the model is quantized\n            elif hasattr(self.config, \"_pre_quantization_dtype\"):\n                target_dtype = self.config._pre_quantization_dtype\n            else:\n                target_dtype = self.q_proj.weight.dtype\n\n            logger.warning_once(\n                f\"The input hidden states seem to have been silently cast to float32; this might be related to\"\n                f\" the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to\"\n                f\" {target_dtype}.\"\n            )\n\n            query_states = query_states.to(target_dtype)\n            key_states = key_states.to(target_dtype)\n            value_states = value_states.to(target_dtype)\n\n        # Reshape to the expected shape for Flash Attention\n        query_states = query_states.transpose(1, 2)\n        key_states = key_states.transpose(1, 2)\n        value_states = value_states.transpose(1, 2)\n\n        attn_output = self._flash_attention_forward(\n            query_states,\n            key_states,\n            value_states,\n            attention_mask,\n            q_len,\n            position_ids=position_ids,\n            dropout=dropout_rate,\n            sliding_window=getattr(self.config, \"sliding_window\", None),\n            is_causal=self.is_causal,\n        )\n\n        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()\n        attn_output = self.o_proj(attn_output)\n\n        # Flash Attention never materializes the attention weights, so there is nothing to return here,\n        # even when `output_attentions=True` (the name would otherwise be undefined).\n        attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n    def _flash_attention_forward(\n        self,\n        
query_states,\n        key_states,\n        value_states,\n        attention_mask,\n        q_len,\n        position_ids,\n        dropout,\n        sliding_window,\n        is_causal,\n        softmax_scale=None,\n    ):\n        \"\"\"\n        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token\n        first unpad the input, then computes the attention scores and pad the final attention scores.\n\n        Args:\n            query_states (`torch.Tensor`):\n                Input query states to be passed to Flash Attention API\n            key_states (`torch.Tensor`):\n                Input key states to be passed to Flash Attention API\n            value_states (`torch.Tensor`):\n                Input value states to be passed to Flash Attention API\n            attention_mask (`torch.Tensor`):\n                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the\n                position of padding tokens and 1 for the position of non-padding tokens.\n            dropout (`float`):\n                Attention dropout\n            \n        \"\"\"\n        \n        # Decide whether to use SWA or not by layer index.\n        # if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:\n        #     use_sliding_windows = False\n        use_sliding_windows = False\n\n        # Contains at least one padding token in the sequence\n        if attention_mask is not None:\n            batch_size = query_states.shape[0]\n            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(\n                query_states, key_states, value_states, attention_mask, q_len\n            )\n\n            cu_seqlens_q, cu_seqlens_k = cu_seq_lens\n            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens\n\n            if not use_sliding_windows:\n                attn_output_unpad = flash_attn_varlen_func(\n     
               query_states,\n                    key_states,\n                    value_states,\n                    cu_seqlens_q=cu_seqlens_q,\n                    cu_seqlens_k=cu_seqlens_k,\n                    max_seqlen_q=max_seqlen_in_batch_q,\n                    max_seqlen_k=max_seqlen_in_batch_k,\n                    dropout_p=dropout,\n                    softmax_scale=softmax_scale,\n                    causal=is_causal,\n                )\n            else:\n                attn_output_unpad = flash_attn_varlen_func(\n                    query_states,\n                    key_states,\n                    value_states,\n                    cu_seqlens_q=cu_seqlens_q,\n                    cu_seqlens_k=cu_seqlens_k,\n                    max_seqlen_q=max_seqlen_in_batch_q,\n                    max_seqlen_k=max_seqlen_in_batch_k,\n                    dropout_p=dropout,\n                    softmax_scale=softmax_scale,\n                    causal=is_causal,\n                    window_size=(self.config.sliding_window, self.config.sliding_window),\n                )\n\n            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, q_len)\n        else:\n            if not use_sliding_windows:\n                if q_len == 1:\n                    position_ids = position_ids.to(dtype=torch.int32).squeeze(1)\n                    attn_output = flash_attn_with_kvcache(\n                        query_states,\n                        key_states,\n                        value_states,\n                        cache_seqlens=position_ids,\n                        softmax_scale=softmax_scale,\n                        causal=is_causal,\n                    )   \n                else:\n                    attn_output = flash_attn_func(\n                        query_states,\n                        key_states,\n                        value_states,\n                        dropout,\n                        softmax_scale=softmax_scale,\n                        
causal=is_causal,\n                    )\n            else:\n                attn_output = flash_attn_func(\n                    query_states,\n                    key_states,\n                    value_states,\n                    dropout,\n                    softmax_scale=softmax_scale,\n                    causal=is_causal,\n                    window_size=(self.config.sliding_window, self.config.sliding_window),\n                )\n\n        return attn_output\n\n    # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input\n    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):\n        batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape\n\n        # On the first iteration we need to properly re-create the padding mask\n        # by slicing it on the proper place\n        if kv_seq_len != attention_mask.shape[-1]:\n            attention_mask_num_tokens = attention_mask.shape[-1]\n            attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]\n\n        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)\n\n        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)\n        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)\n\n        if query_length == kv_seq_len:\n            query_layer = index_first_axis(\n                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k\n            )\n            cu_seqlens_q = cu_seqlens_k\n            max_seqlen_in_batch_q = max_seqlen_in_batch_k\n            indices_q = indices_k\n        elif query_length == 1:\n            max_seqlen_in_batch_q = 1\n            cu_seqlens_q = torch.arange(\n                batch_size + 1, dtype=torch.int32, device=query_layer.device\n            )  # There is a memcpy here, that is very bad.\n            
indices_q = cu_seqlens_q[:-1]\n            query_layer = query_layer.squeeze(1)\n        else:\n            # The -q_len: slice assumes left padding.\n            attention_mask = attention_mask[:, -query_length:]\n            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)\n\n        return (\n            query_layer,\n            key_layer,\n            value_layer,\n            indices_q,\n            (cu_seqlens_q, cu_seqlens_k),\n            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),\n        )\n\n\n\n# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral\n# TODO @longjie no longer copied from Mistral after static cache\nclass MixtralSdpaAttention(MixtralAttention):\n    \"\"\"\n    Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from\n    `MixtralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to\n    SDPA API.\n    \"\"\"\n\n    # Adapted from MixtralAttention.forward\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        if output_attentions:\n            # TODO: Improve this warning with e.g. `model.config.attn_implementation = \"manual\"` once this is implemented.\n            logger.warning_once(\n                \"MixtralModel is using MixtralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. 
Falling back to the manual attention implementation, \"\n                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation=\"eager\"` when loading the model.'\n            )\n            # Forward `cache_position` as well so that the eager fallback updates the cache at the right positions.\n            return super().forward(\n                hidden_states=hidden_states,\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n                past_key_value=past_key_value,\n                output_attentions=output_attentions,\n                use_cache=use_cache,\n                cache_position=cache_position,\n            )\n\n        bsz, q_len, _ = hidden_states.size()\n\n        query_states = self.q_proj(hidden_states)\n        key_states = self.k_proj(hidden_states)\n        value_states = self.v_proj(hidden_states)\n\n        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n\n        kv_seq_len = key_states.shape[-2]\n        if past_key_value is not None:\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\n        cos, sin = self.rotary_emb(value_states, position_ids)\n\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n\n        if past_key_value is not None:\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n        causal_mask = attention_mask\n        if attention_mask is not None:  
# no matter the length, we just slice it\n            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]\n\n        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,\n        # Reference: https://github.com/pytorch/pytorch/issues/112577.\n        if query_states.device.type == \"cuda\" and attention_mask is not None:\n            query_states = query_states.contiguous()\n            key_states = key_states.contiguous()\n            value_states = value_states.contiguous()\n\n        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment\n        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.\n        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.\n        is_causal = True if causal_mask is None and q_len > 1 else False\n\n        attn_output = torch.nn.functional.scaled_dot_product_attention(\n            query_states,\n            key_states,\n            value_states,\n            attn_mask=causal_mask,\n            dropout_p=self.attention_dropout if self.training else 0.0,\n            is_causal=is_causal,\n        )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.view(bsz, q_len, self.hidden_size)\n\n        attn_output = self.o_proj(attn_output)\n\n        return attn_output, None, past_key_value\n\n\nMIXTRAL_ATTENTION_CLASSES = {\n    \"eager\": MixtralAttention,\n    \"flash_attention_2\": MixtralFlashAttention2,\n    \"sdpa\": MixtralSdpaAttention,\n}\n\n\nclass MixtralBlockSparseTop2MLP(nn.Module):\n    def __init__(self, config: MixtralConfig):\n        super().__init__()\n        self.ffn_dim = config.intermediate_size\n        self.hidden_dim = 
config.hidden_size\n\n        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)  # gate\n        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)  # down\n        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)  # up\n\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, hidden_states):\n        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)\n        current_hidden_states = self.w2(current_hidden_states)\n        return current_hidden_states\n\n\nclass MixtralSparseMoeBlock(nn.Module):\n    \"\"\"\n    This implementation is\n    strictly equivalent to standard MoE with full capacity (no\n    dropped tokens). It's faster since it formulates MoE operations\n    in terms of block-sparse operations to accommodate imbalanced\n    assignments of tokens to experts, whereas standard MoE either\n    (1) drops tokens at the cost of reduced performance or (2) sets the\n    capacity factor to the number of experts and thus wastes computation\n    and memory on padding.\n    \"\"\"\n\n    def __init__(self, config):\n        super().__init__()\n        self.hidden_dim = config.hidden_size\n        self.ffn_dim = config.intermediate_size\n        self.num_experts = config.num_local_experts\n        self.top_k = config.num_experts_per_tok\n\n        # gating\n        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)\n\n        self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])\n\n        # Jitter parameters\n        self.jitter_noise = config.router_jitter_noise\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        \"\"\"Route each token to its top-k experts; return the combined hidden states and the router logits.\"\"\"\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        if self.training and self.jitter_noise > 0:\n            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)\n        hidden_states = 
hidden_states.view(-1, hidden_dim)\n        # router_logits: (batch * sequence_length, n_experts)\n        router_logits = self.gate(hidden_states)\n\n        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)\n        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)\n        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)\n        # we cast back to the input dtype\n        routing_weights = routing_weights.to(hidden_states.dtype)\n\n        final_hidden_states = torch.zeros(\n            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device\n        )\n\n        # One-hot encode the selected experts to create an expert mask;\n        # this will be used to easily index which expert is going to be solicited\n        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)\n\n        # Loop over all available experts in the model and perform the computation on each expert\n        for expert_idx in range(self.num_experts):\n            expert_layer = self.experts[expert_idx]\n            idx, top_x = torch.where(expert_mask[expert_idx])\n\n            # Index the correct hidden states and compute the expert hidden state for\n            # the current expert. 
We need to make sure to multiply the output hidden\n            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)\n            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)\n            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]\n\n            # `index_add_` only supports torch tensors for indexing, so we use\n            # the `top_x` tensor here.\n            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))\n        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)\n        return final_hidden_states, router_logits\n\n\nclass MixtralDecoderLayer(nn.Module):\n    def __init__(self, config: MixtralConfig, layer_idx: int):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n\n        self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)\n\n        self.block_sparse_moe = MixtralSparseMoeBlock(config)\n        self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        output_attentions: Optional[bool] = False,\n        output_router_logits: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs,\n    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask 
(`torch.FloatTensor`, *optional*): attention mask of size\n                `(batch, sequence_length)` where padding elements are indicated by 0.\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n            output_router_logits (`bool`, *optional*):\n                Whether or not to return the logits of all the routers. They are useful for computing the router loss, and\n                should not be returned during inference.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):\n                Indices depicting the position of the input sequence tokens in the sequence.\n            kwargs (`dict`, *optional*):\n                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code\n                into the model\n        \"\"\"\n\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_value=past_key_value,\n            output_attentions=output_attentions,\n            use_cache=use_cache,\n            cache_position=cache_position,\n        )\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n      
  hidden_states, router_logits = self.block_sparse_moe(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (self_attn_weights,)\n\n        if use_cache:\n            outputs += (present_key_value,)\n\n        if output_router_logits:\n            outputs += (router_logits,)\n\n        return outputs\n\n\nMIXTRAL_START_DOCSTRING = r\"\"\"\n    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the\n    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads\n    etc.)\n\n    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage\n    and behavior.\n\n    Parameters:\n        config ([`MixtralConfig`]):\n            Model configuration class with all the parameters of the model. Initializing with a config file does not\n            load the weights associated with the model, only the configuration. 
Check out the\n            [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare Mixtral Model outputting raw hidden-states without any specific head on top.\",\n    MIXTRAL_START_DOCSTRING,\n)\n# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel with Qwen2->Mixtral\nclass MixtralPreTrainedModel(PreTrainedModel):\n    config_class = MixtralConfig\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"MixtralDecoderLayer\"]\n    _skip_keys_device_placement = \"past_key_values\"\n    _supports_flash_attn_2 = True\n    _supports_sdpa = True\n    _supports_cache_class = True\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n\n\nMIXTRAL_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. 
Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see\n            `past_key_values`).\n\n            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]\n            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more\n            information on the default strategy.\n\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Indices of positions of each input sequence tokens in the position embeddings. 
Selected in the range `[0,\n            config.n_positions - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):\n            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape\n            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape\n            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.\n\n            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention\n            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.\n\n            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that\n            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all\n            `decoder_input_ids` of shape `(batch_size, sequence_length)`.\n        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This\n            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n            model's internal embedding lookup matrix.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n            `past_key_values`).\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for\n            more detail.\n        output_router_logits (`bool`, *optional*):\n            Whether or not to return the logits of all the routers. They are useful for computing the router loss, and\n            should not be returned during inference.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):\n            Indices depicting the position of the input sequence tokens in the sequence. Contrary to `position_ids`,\n            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer\n            the complete sequence length.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare Mixtral Model outputting raw hidden-states without any specific head on top.\",\n    MIXTRAL_START_DOCSTRING,\n)\n# copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral\n# TODO @longjie no longer copied from Mistral after static cache\nclass MixtralModel(MixtralPreTrainedModel):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`MixtralDecoderLayer`]\n\n    Args:\n        config: MixtralConfig\n    \"\"\"\n\n    def __init__(self, config: MixtralConfig):\n        super().__init__(config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)\n        self.layers = nn.ModuleList(\n            [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]\n        )\n        self._attn_implementation = config._attn_implementation\n        self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n        self.gradient_checkpointing = False\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.embed_tokens = value\n\n    # Ignore copy\n    @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        output_router_logits: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, MoeModelOutputWithPast]:\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_router_logits = (\n            output_router_logits if output_router_logits is not None else self.config.output_router_logits\n        )\n        
output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        if (input_ids is None) ^ (inputs_embeds is not None):\n            raise ValueError(\n                \"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one\"\n            )\n\n        if self.gradient_checkpointing and self.training:\n            if use_cache:\n                logger.warning_once(\n                    \"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\"\n                )\n                use_cache = False\n\n        use_legacy_cache = False\n        if use_cache and not isinstance(past_key_values, Cache) and not self.training:\n            use_legacy_cache = True\n            past_key_values = DynamicCache.from_legacy_cache(past_key_values)\n            logger.warning_once(\n                \"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. 
\"\n                \"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\"\n            )\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        if cache_position is None:\n            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n            cache_position = torch.arange(\n                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device\n            )\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0)\n\n        causal_mask = self._update_causal_mask(\n            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions\n        )\n\n        hidden_states = inputs_embeds\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        all_router_logits = () if output_router_logits else None\n        next_decoder_cache = None\n\n        for decoder_layer in self.layers:\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            if self.gradient_checkpointing and self.training:\n                layer_outputs = self._gradient_checkpointing_func(\n                    decoder_layer.__call__,\n                    hidden_states,\n                    causal_mask,\n                    position_ids,\n                    past_key_values,\n                    output_attentions,\n                    output_router_logits,\n                    use_cache,\n                    cache_position,\n                )\n            else:\n                layer_outputs = decoder_layer(\n                    hidden_states,\n                    attention_mask=causal_mask,\n                    position_ids=position_ids,\n                    
past_key_value=past_key_values,\n                    output_attentions=output_attentions,\n                    output_router_logits=output_router_logits,\n                    use_cache=use_cache,\n                    cache_position=cache_position,\n                )\n\n            hidden_states = layer_outputs[0]\n\n            if use_cache:\n                next_decoder_cache = layer_outputs[2 if output_attentions else 1]\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n            if output_router_logits:\n                all_router_logits += (layer_outputs[-1],)\n\n        hidden_states = self.norm(hidden_states)\n\n        # add hidden states from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        next_cache = None\n        if use_cache:\n            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache\n\n        if not return_dict:\n            return tuple(\n                v\n                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]\n                if v is not None\n            )\n        return MoeModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=next_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n            router_logits=all_router_logits,\n        )\n\n    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask\n    def _update_causal_mask(\n        self,\n        attention_mask: torch.Tensor,\n        input_tensor: torch.Tensor,\n        cache_position: torch.Tensor,\n        past_key_values: Cache,\n        output_attentions: bool,\n    ):\n        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static\n        # KV cache is used. 
This is an issue for torch.compile which then recaptures cudagraphs at each decode step due to the dynamic shapes.\n        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using\n        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114\n\n        if self.config._attn_implementation == \"flash_attention_2\":\n            if attention_mask is not None and 0.0 in attention_mask:\n                return attention_mask\n            return None\n\n        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in\n        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail\n        # to infer the attention mask.\n        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n        using_static_cache = isinstance(past_key_values, StaticCache)\n\n        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward\n        if self.config._attn_implementation == \"sdpa\" and not using_static_cache and not output_attentions:\n            if AttentionMaskConverter._ignore_causal_mask_sdpa(\n                attention_mask,\n                inputs_embeds=input_tensor,\n                past_key_values_length=past_seen_tokens,\n                is_training=self.training,\n            ):\n                return None\n\n        dtype, device = input_tensor.dtype, input_tensor.device\n        min_dtype = torch.finfo(dtype).min\n        sequence_length = input_tensor.shape[1]\n        if using_static_cache:\n            target_length = past_key_values.get_max_length()\n        else:\n            target_length = (\n                attention_mask.shape[-1]\n                if isinstance(attention_mask, torch.Tensor)\n                else past_seen_tokens + 
sequence_length + 1\n            )\n\n        if attention_mask is not None and attention_mask.dim() == 4:\n            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing\n            if attention_mask.max() != 0:\n                raise ValueError(\"Custom 4D attention mask should be passed in inverted form with max==0`\")\n            causal_mask = attention_mask\n        else:\n            causal_mask = torch.full(\n                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device\n            )\n            if sequence_length != 1:\n                causal_mask = torch.triu(causal_mask, diagonal=1)\n            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)\n            causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)\n            if attention_mask is not None:\n                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit\n                mask_length = attention_mask.shape[-1]\n                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]\n                padding_mask = padding_mask == 0\n                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(\n                    padding_mask, min_dtype\n                )\n        if (\n            self.config._attn_implementation == \"sdpa\"\n            and attention_mask is not None\n            and attention_mask.device.type == \"cuda\"\n            and not output_attentions\n        ):\n            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when\n            # using left padding. 
This is required by F.scaled_dot_product_attention memory-efficient attention path.\n            # Details: https://github.com/pytorch/pytorch/issues/110213\n            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)\n\n        return causal_mask\n\n\nclass MixtralForCausalLM(MixtralPreTrainedModel):\n    _tied_weights_keys = [\"lm_head.weight\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = MixtralModel(config)\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.router_aux_loss_coef = config.router_aux_loss_coef\n        self.num_experts = config.num_local_experts\n        self.num_experts_per_tok = config.num_experts_per_tok\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def set_decoder(self, decoder):\n        self.model = decoder\n\n    def get_decoder(self):\n        return self.model\n\n    @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)\n    @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)\n    # Ignore copy\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = 
None,\n        output_hidden_states: Optional[bool] = None,\n        output_router_logits: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, MoeCausalLMOutputWithPast]:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, MixtralForCausalLM\n\n        >>> model = MixtralForCausalLM.from_pretrained(\"mistralai/Mixtral-8x7B-v0.1\")\n        >>> tokenizer = AutoTokenizer.from_pretrained(\"mistralai/Mixtral-8x7B-v0.1\")\n\n        >>> prompt = \"Hey, are you conscious? Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you conscious? 
Can you talk to me?\\nI'm not conscious, but I can talk to you.\"\n        ```\"\"\"\n\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_router_logits = (\n            output_router_logits if output_router_logits is not None else self.config.output_router_logits\n        )\n\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            output_router_logits=output_router_logits,\n            return_dict=return_dict,\n            cache_position=cache_position,\n        )\n\n        hidden_states = outputs[0]\n        logits = self.lm_head(hidden_states)\n        logits = logits.float()\n\n        loss = None\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            shift_logits = shift_logits.view(-1, self.config.vocab_size)\n            shift_labels = shift_labels.view(-1)\n            # Enable model parallelism\n            shift_labels = shift_labels.to(shift_logits.device)\n            loss = loss_fct(shift_logits, shift_labels)\n\n        aux_loss = None\n        if output_router_logits:\n            
aux_loss = load_balancing_loss_func(\n                outputs.router_logits if return_dict else outputs[-1],\n                self.num_experts,\n                self.num_experts_per_tok,\n                attention_mask,\n            )\n            if labels is not None:\n                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            if output_router_logits:\n                output = (aux_loss,) + output\n            return (loss,) + output if loss is not None else output\n\n        return MoeCausalLMOutputWithPast(\n            loss=loss,\n            aux_loss=aux_loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n            router_logits=outputs.router_logits,\n        )\n\n    def prepare_inputs_for_generation(\n        self,\n        input_ids,\n        past_key_values=None,\n        attention_mask=None,\n        inputs_embeds=None,\n        cache_position=None,\n        output_router_logits=False,\n        position_ids=None,\n        use_cache=True,\n        **kwargs,\n    ):\n        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens\n        # Exception 1: when passing input_embeds, input_ids may be missing entries\n        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here\n        if past_key_values is not None:\n            if inputs_embeds is not None:  # Exception 1\n                input_ids = input_ids[:, -cache_position.shape[0] :]\n            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the \"else\", a no op, is Exception 2)\n                input_ids = input_ids[:, cache_position]\n\n        if attention_mask is not None and position_ids is 
None:\n            # create position_ids on the fly for batch generation\n            position_ids = attention_mask.long().cumsum(-1) - 1\n            position_ids.masked_fill_(attention_mask == 0, 1)\n            if past_key_values:\n                position_ids = position_ids[:, -input_ids.shape[1] :]\n\n        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step\n        if inputs_embeds is not None and cache_position[0] == 0:\n            model_inputs = {\"inputs_embeds\": inputs_embeds}\n        else:\n            model_inputs = {\"input_ids\": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases\n\n        model_inputs.update(\n            {\n                \"position_ids\": position_ids,\n                \"cache_position\": cache_position,\n                \"past_key_values\": past_key_values,\n                \"use_cache\": use_cache,\n                \"attention_mask\": attention_mask,\n                \"output_router_logits\": output_router_logits,\n            }\n        )\n        return model_inputs\n\n\n@add_start_docstrings(\n    \"\"\"\n    The Mixtral Model transformer with a sequence classification head on top (linear layer).\n\n    [`MixtralForSequenceClassification`] uses the last token in order to do the classification, as other causal models\n    (e.g. GPT-2) do.\n\n    Since it does classification on the last token, it requires to know the position of the last token. If a\n    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If\n    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. 
Since it cannot guess the\n    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in\n    each row of the batch).\n    \"\"\",\n    MIXTRAL_START_DOCSTRING,\n)\n# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mixtral, LLAMA->MIXTRAL\nclass MixtralForSequenceClassification(MixtralPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.model = MixtralModel(config)\n        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n            config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        transformer_outputs = self.model(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = transformer_outputs[0]\n        logits = self.score(hidden_states)\n\n        if input_ids is not None:\n            batch_size = input_ids.shape[0]\n        else:\n            batch_size = inputs_embeds.shape[0]\n\n        if self.config.pad_token_id is None and batch_size != 1:\n            raise ValueError(\"Cannot handle batch sizes > 1 if no padding token is defined.\")\n        if self.config.pad_token_id is None:\n            sequence_lengths = -1\n        else:\n            if input_ids is not None:\n                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility\n                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1\n                sequence_lengths = sequence_lengths % input_ids.shape[-1]\n                sequence_lengths = sequence_lengths.to(logits.device)\n            else:\n                sequence_lengths = -1\n\n        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]\n\n        loss = None\n        if labels is not None:\n            labels = labels.to(logits.device)\n            if self.config.problem_type is None:\n                if self.num_labels == 1:\n                    self.config.problem_type 
= \"regression\"\n                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):\n                    self.config.problem_type = \"single_label_classification\"\n                else:\n                    self.config.problem_type = \"multi_label_classification\"\n\n            if self.config.problem_type == \"regression\":\n                loss_fct = MSELoss()\n                if self.num_labels == 1:\n                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())\n                else:\n                    loss = loss_fct(pooled_logits, labels)\n            elif self.config.problem_type == \"single_label_classification\":\n                loss_fct = CrossEntropyLoss()\n                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))\n            elif self.config.problem_type == \"multi_label_classification\":\n                loss_fct = BCEWithLogitsLoss()\n                loss = loss_fct(pooled_logits, labels)\n        if not return_dict:\n            output = (pooled_logits,) + transformer_outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return SequenceClassifierOutputWithPast(\n            loss=loss,\n            logits=pooled_logits,\n            past_key_values=transformer_outputs.past_key_values,\n            hidden_states=transformer_outputs.hidden_states,\n            attentions=transformer_outputs.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    The Mixtral Model transformer with a token classification head on top (a linear layer on top of the hidden-states\n    output) e.g. 
for Named-Entity-Recognition (NER) tasks.\n    \"\"\",\n    MIXTRAL_START_DOCSTRING,\n)\n# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mixtral, LLAMA->MIXTRAL\nclass MixtralForTokenClassification(MixtralPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.model = MixtralModel(config)\n        if getattr(config, \"classifier_dropout\", None) is not None:\n            classifier_dropout = config.classifier_dropout\n        elif getattr(config, \"hidden_dropout\", None) is not None:\n            classifier_dropout = config.hidden_dropout\n        else:\n            classifier_dropout = 0.1\n        self.dropout = nn.Dropout(classifier_dropout)\n        self.score = nn.Linear(config.hidden_size, config.num_labels)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, TokenClassifierOutput]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the sequence classification/regression loss. 
Indices should be in `[0, ...,\n            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        outputs = self.model(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        sequence_output = outputs[0]\n        sequence_output = self.dropout(sequence_output)\n        logits = self.score(sequence_output)\n\n        loss = None\n        if labels is not None:\n            loss_fct = CrossEntropyLoss()\n            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))\n\n        if not return_dict:\n            output = (logits,) + outputs[2:]\n            return ((loss,) + output) if loss is not None else output\n\n        return TokenClassifierOutput(\n            loss=loss,\n            logits=logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )"
  },
  {
    "path": "archive/ktransformers/models/modeling_qwen2_moe.py",
    "content": "# coding=utf-8\n'''\nDescription  :  \nAuthor       : Boxin Zhang\nVersion      : 0.1.0\n''' \n# Adapted from\n# https://github.com/huggingface/transformers/blob/v4.42.3/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py\n# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.\n# Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n# \n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch Qwen2MoE model.\"\"\"\n\nimport inspect\nimport math\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n\nfrom transformers.activations import ACT2FN\nfrom transformers.cache_utils import Cache, DynamicCache, StaticCache\nfrom transformers.modeling_attn_mask_utils import (\n    AttentionMaskConverter,\n)\nfrom transformers.modeling_outputs import (\n    MoeCausalLMOutputWithPast,\n    MoeModelOutputWithPast,\n    SequenceClassifierOutputWithPast,\n    TokenClassifierOutput,\n)\nfrom transformers.modeling_utils import PreTrainedModel\nfrom 
transformers.utils import (\n    add_start_docstrings,\n    add_start_docstrings_to_model_forward,\n    is_flash_attn_2_available,\n    is_flash_attn_greater_or_equal_2_10,\n    logging,\n    replace_return_docstrings,\n)\nfrom transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig\n\n\nif is_flash_attn_2_available():\n    from flash_attn import flash_attn_func, flash_attn_varlen_func, flash_attn_with_kvcache\n    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa\n\n    _flash_supports_window_size = \"window_size\" in list(inspect.signature(flash_attn_func).parameters)\n\nlogger = logging.get_logger(__name__)\n\n_CHECKPOINT_FOR_DOC = \"Qwen/Qwen1.5-MoE-A2.7B\"\n_CONFIG_FOR_DOC = \"Qwen2MoeConfig\"\n\n\n# Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func\ndef load_balancing_loss_func(\n    gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None\n) -> float:\n    r\"\"\"\n    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.\n\n    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss\n    function presented in equations (4) - (6) of the paper. 
It aims at penalizing cases where the routing between\n    experts is too unbalanced.\n\n    Args:\n        gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):\n            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of\n            shape [batch_size X sequence_length, num_experts].\n        attention_mask (`torch.Tensor`, None):\n            The attention_mask used in forward function\n            shape [batch_size X sequence_length] if not None.\n        num_experts (`int`, *optional*):\n            Number of experts\n\n    Returns:\n        The auxiliary loss.\n    \"\"\"\n    if gate_logits is None or not isinstance(gate_logits, tuple):\n        return 0\n\n    if isinstance(gate_logits, tuple):\n        compute_device = gate_logits[0].device\n        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)\n\n    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)\n\n    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)\n\n    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)\n\n    if attention_mask is None:\n        # Compute the percentage of tokens routed to each experts\n        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)\n\n        # Compute the average probability of routing to these experts\n        router_prob_per_expert = torch.mean(routing_weights, dim=0)\n    else:\n        batch_size, sequence_length = attention_mask.shape\n        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)\n\n        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask\n        expert_attention_mask = (\n            attention_mask[None, :, :, None, None]\n            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))\n            .reshape(-1, top_k, num_experts)\n            .to(compute_device)\n      
  )\n\n        # Compute the percentage of tokens routed to each experts\n        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(\n            expert_attention_mask, dim=0\n        )\n\n        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert\n        router_per_expert_attention_mask = (\n            attention_mask[None, :, :, None]\n            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))\n            .reshape(-1, num_experts)\n            .to(compute_device)\n        )\n\n        # Compute the average probability of routing to these experts\n        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(\n            router_per_expert_attention_mask, dim=0\n        )\n\n    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))\n    return overall_loss * num_experts\n\n\n# Copied from transformers.models.llama.modeling_llama._get_unpad_data\ndef _get_unpad_data(attention_mask):\n    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)\n    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()\n    max_seqlen_in_batch = seqlens_in_batch.max().item()\n    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))\n    return (\n        indices,\n        cu_seqlens,\n        max_seqlen_in_batch,\n    )\n\n\n# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2Moe\nclass Qwen2MoeRMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        Qwen2MoeRMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n        hidden_states = 
hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2Moe\nclass Qwen2MoeRotaryEmbedding(nn.Module):\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):\n        super().__init__()\n        self.scaling_factor = scaling_factor\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        # For BC we register cos and sin cached\n        self.max_seq_len_cached = max_position_embeddings\n\n    @torch.no_grad()\n    def forward(self, x, position_ids):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)\n        position_ids_expanded = position_ids[:, None, :].float()\n        # Force float32 since bfloat16 loses precision on long contexts\n        # See https://github.com/huggingface/transformers/pull/29285\n        device_type = x.device.type\n        device_type = device_type if isinstance(device_type, str) and device_type != \"mps\" else \"cpu\"\n        with torch.autocast(device_type=device_type, enabled=False):\n            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos()\n            sin = emb.sin()\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)\n\n\n# Copied from transformers.models.llama.modeling_llama.rotate_half\ndef 
rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\n# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n    Args:\n        q (`torch.Tensor`): The query tensor.\n        k (`torch.Tensor`): The key tensor.\n        cos (`torch.Tensor`): The cosine part of the rotary embedding.\n        sin (`torch.Tensor`): The sine part of the rotary embedding.\n        position_ids (`torch.Tensor`):\n            Deprecated and unused.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. 
Similarly, if q and k have\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos.unsqueeze(unsqueeze_dim)\n    sin = sin.unsqueeze(unsqueeze_dim)\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\n# Modified from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2Moe\nclass Qwen2MoeMLP(nn.Module):\n    def __init__(self, config, intermediate_size=None):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.intermediate_size = intermediate_size\n        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n\n\n# Copied from transformers.models.llama.modeling_llama.repeat_kv\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\n# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2Attention with Qwen2->Qwen2Moe\nclass Qwen2MoeAttention(nn.Module):\n    \"\"\"\n    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer\n    and \"Generating Long Sequences with Sparse Transformers\".\n    \"\"\"\n\n    def __init__(self, config: Qwen2MoeConfig, layer_idx: Optional[int] = None):\n        super().__init__()\n        self.config = config\n        self.layer_idx = layer_idx\n        if layer_idx is None:\n            logger.warning_once(\n                f\"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will \"\n                \"lead to errors during the forward call, if caching is used. 
Please make sure to provide a `layer_idx` \"\n                \"when creating this class.\"\n            )\n\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.hidden_size // self.num_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.num_key_value_groups = self.num_heads // self.num_key_value_heads\n        self.max_position_embeddings = config.max_position_embeddings\n        self.rope_theta = config.rope_theta\n        self.is_causal = True\n        self.attention_dropout = config.attention_dropout\n\n        if (self.head_dim * self.num_heads) != self.hidden_size:\n            raise ValueError(\n                f\"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}\"\n                f\" and `num_heads`: {self.num_heads}).\"\n            )\n        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)\n        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)\n        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)\n        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)\n\n        self.rotary_emb = Qwen2MoeRotaryEmbedding(\n            self.head_dim,\n            max_position_embeddings=self.max_position_embeddings,\n            base=self.rope_theta,\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        bsz, q_len, _ = hidden_states.size()\n\n      
  query_states = self.q_proj(hidden_states)\n        key_states = self.k_proj(hidden_states)\n        value_states = self.v_proj(hidden_states)\n\n        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n\n        kv_seq_len = key_states.shape[-2]\n        if past_key_value is not None:\n            if self.layer_idx is None:\n                raise ValueError(\n                    f\"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} \"\n                    \"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class \"\n                    \"with a layer index.\"\n                )\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\n        cos, sin = self.rotary_emb(value_states, position_ids)\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n        if past_key_value is not None:\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n        # repeat k/v heads if n_kv_heads < n_heads\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)\n\n        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is\"\n           
     f\" {attn_weights.size()}\"\n            )\n\n        if attention_mask is not None:  # no matter the length, we just slice it\n            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]\n            attn_weights = attn_weights + causal_mask\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)\n        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)\n\n        attn_output = self.o_proj(attn_output)\n\n        if not output_attentions:\n            attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n\n# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2 with Qwen2->Qwen2Moe\nclass Qwen2MoeFlashAttention2(Qwen2MoeAttention):\n    \"\"\"\n    Qwen2Moe flash attention module, following Qwen2Moe attention module. This module inherits from `Qwen2MoeAttention`\n    as the weights of the module stays untouched. The only required change would be on the forward pass\n    where it needs to correctly call the public API of flash attention and deal with padding tokens\n    in case the input contains any of them. 
Additionally, for sliding window attention, we apply SWA only to the bottom\n    config.max_window_layers layers.\n    \"\"\"\n\n    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n        # TODO: Should be removed once Flash Attention for ROCm is bumped to 2.1.\n        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.\n        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).\n        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n    ):\n        bsz, q_len, _ = hidden_states.size()\n\n        query_states = self.q_proj(hidden_states)\n        key_states = self.k_proj(hidden_states)\n        value_states = self.v_proj(hidden_states)\n\n        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n\n        kv_seq_len = key_states.shape[-2]\n        if past_key_value is not None:\n            if self.layer_idx is None:\n                raise ValueError(\n                    f\"The 
cache structure has changed since version v4.36. If you are using {self.__class__.__name__} \"\n                    \"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class \"\n                    \"with a layer index.\"\n                )\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\n\n        cos, sin = self.rotary_emb(value_states, position_ids)\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n        use_sliding_windows = (\n            _flash_supports_window_size\n            and getattr(self.config, \"sliding_window\", None) is not None\n            and kv_seq_len > self.config.sliding_window\n            and self.config.use_sliding_window\n        )\n\n        if not _flash_supports_window_size:\n            logger.warning_once(\n                \"The current flash attention version does not support sliding window attention. For a more memory-efficient\"\n                \" implementation, make sure to upgrade the flash-attn library.\"\n            )\n\n        if past_key_value is not None:\n            # Slice the cache only if the config has a non-None `sliding_window` attribute\n            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0\n            if (\n                getattr(self.config, \"sliding_window\", None) is not None\n                and kv_seq_len > self.config.sliding_window\n                and cache_has_contents\n            ):\n                slicing_tokens = 1 - self.config.sliding_window\n\n                past_key = past_key_value[self.layer_idx][0]\n                past_value = past_key_value[self.layer_idx][1]\n\n                past_key = past_key[:, :, slicing_tokens:, :].contiguous()\n                past_value = past_value[:, :, slicing_tokens:, :].contiguous()\n\n                if past_key.shape[-2] != self.config.sliding_window - 1:\n                    raise 
ValueError(\n                        f\"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got\"\n                        f\" {past_key.shape}\"\n                    )\n\n                if attention_mask is not None:\n                    attention_mask = attention_mask[:, slicing_tokens:]\n                    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)\n\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n            # we slice the states for static kv cache to be supported in FA2. Not sure it's a must as compile fails\n            # for bsz == 1, avoid using slice to capture cuda graph\n            if cache_position is not None and q_len > 1:\n                key_states = key_states[:, :, : cache_position[-1] + 1, :]\n                value_states = value_states[:, :, : cache_position[-1] + 1, :]\n\n        # repeat k/v heads if n_kv_heads < n_heads\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n        dropout_rate = 0.0 if not self.training else self.attention_dropout\n\n        # In PEFT, usually we cast the layer norms in float32 for training stability reasons\n        # therefore the input hidden states gets silently casted in float32. 
Hence, we need to\n        # cast them back to the expected dtype (e.g. float16) just to be sure everything works as expected.\n        input_dtype = query_states.dtype\n        if input_dtype == torch.float32:\n            if torch.is_autocast_enabled():\n                target_dtype = torch.get_autocast_gpu_dtype()\n            # Handle the case where the model is quantized\n            elif hasattr(self.config, \"_pre_quantization_dtype\"):\n                target_dtype = self.config._pre_quantization_dtype\n            else:\n                target_dtype = self.q_proj.weight.dtype\n\n            logger.warning_once(\n                f\"The input hidden states seems to be silently casted in float32, this might be related to\"\n                f\" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in\"\n                f\" {target_dtype}.\"\n            )\n\n            query_states = query_states.to(target_dtype)\n            key_states = key_states.to(target_dtype)\n            value_states = value_states.to(target_dtype)\n\n        # Reshape to the expected shape for Flash Attention\n        query_states = query_states.transpose(1, 2)\n        key_states = key_states.transpose(1, 2)\n        value_states = value_states.transpose(1, 2)\n\n        attn_output = self._flash_attention_forward(\n            query_states,\n            key_states,\n            value_states,\n            attention_mask,\n            q_len,\n            position_ids=position_ids,\n            dropout=dropout_rate,\n            use_sliding_windows=use_sliding_windows,\n        )\n\n        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()\n        attn_output = self.o_proj(attn_output)\n\n        # Flash Attention never materializes the attention weights, so there is nothing to return for them\n        attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n    def _flash_attention_forward(\n        self,\n        query_states,\n        key_states,\n        value_states,\n        
attention_mask,\n        query_length,\n        position_ids,\n        dropout=0.0,\n        softmax_scale=None,\n        use_sliding_windows=False,\n    ):\n        \"\"\"\n        Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token,\n        first unpads the input, then computes the attention output, and finally re-pads the output.\n\n        Args:\n            query_states (`torch.Tensor`):\n                Input query states to be passed to Flash Attention API\n            key_states (`torch.Tensor`):\n                Input key states to be passed to Flash Attention API\n            value_states (`torch.Tensor`):\n                Input value states to be passed to Flash Attention API\n            attention_mask (`torch.Tensor`):\n                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the\n                position of padding tokens and 1 for the position of non-padding tokens.\n            query_length (`int`):\n                Length of the query sequence, used when unpadding and re-padding the inputs.\n            position_ids (`torch.LongTensor`):\n                Position indices, passed as `cache_seqlens` on the unpadded single-token decoding path.\n            dropout (`float`):\n                Attention dropout\n            softmax_scale (`float`, *optional*):\n                The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)\n            use_sliding_windows (`bool`, *optional*):\n                Whether to activate sliding window attention.\n        \"\"\"\n        if not self._flash_attn_uses_top_left_mask:\n            causal = self.is_causal\n        else:\n            # TODO: Remove the `query_length != 1` check once Flash Attention for ROCm is bumped to 2.1. 
For details, please see the comment in LlamaFlashAttention2 __init__.\n            causal = self.is_causal and query_length != 1\n\n        # Decide whether to use SWA or not by layer index.\n        if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:\n            use_sliding_windows = False\n\n        # Contains at least one padding token in the sequence\n        if attention_mask is not None:\n            batch_size = query_states.shape[0]\n            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(\n                query_states, key_states, value_states, attention_mask, query_length\n            )\n\n            cu_seqlens_q, cu_seqlens_k = cu_seq_lens\n            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens\n\n            if not use_sliding_windows:\n                attn_output_unpad = flash_attn_varlen_func(\n                    query_states,\n                    key_states,\n                    value_states,\n                    cu_seqlens_q=cu_seqlens_q,\n                    cu_seqlens_k=cu_seqlens_k,\n                    max_seqlen_q=max_seqlen_in_batch_q,\n                    max_seqlen_k=max_seqlen_in_batch_k,\n                    dropout_p=dropout,\n                    softmax_scale=softmax_scale,\n                    causal=causal,\n                )\n            else:\n                attn_output_unpad = flash_attn_varlen_func(\n                    query_states,\n                    key_states,\n                    value_states,\n                    cu_seqlens_q=cu_seqlens_q,\n                    cu_seqlens_k=cu_seqlens_k,\n                    max_seqlen_q=max_seqlen_in_batch_q,\n                    max_seqlen_k=max_seqlen_in_batch_k,\n                    dropout_p=dropout,\n                    softmax_scale=softmax_scale,\n                    causal=causal,\n                    window_size=(self.config.sliding_window, self.config.sliding_window),\n                
)\n\n            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)\n        else:\n            if not use_sliding_windows:\n                if query_length == 1:\n                    position_ids = position_ids.to(dtype=torch.int32).squeeze(1)\n                    attn_output = flash_attn_with_kvcache(\n                        query_states,\n                        key_states,\n                        value_states,\n                        cache_seqlens=position_ids,\n                        softmax_scale=softmax_scale,\n                        causal=causal,\n                    )   \n                else:\n                    attn_output = flash_attn_func(\n                        query_states,\n                        key_states,\n                        value_states,\n                        dropout,\n                        softmax_scale=softmax_scale,\n                        causal=causal,\n                    )\n            else:\n                attn_output = flash_attn_func(\n                    query_states,\n                    key_states,\n                    value_states,\n                    dropout,\n                    softmax_scale=softmax_scale,\n                    causal=causal,\n                    window_size=(self.config.sliding_window, self.config.sliding_window),\n                )\n\n        return attn_output\n\n    # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input\n    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):\n        batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape\n\n        # On the first iteration we need to properly re-create the padding mask\n        # by slicing it on the proper place\n        if kv_seq_len != attention_mask.shape[-1]:\n            attention_mask_num_tokens = attention_mask.shape[-1]\n            attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]\n\n     
   indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)\n\n        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)\n        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)\n\n        if query_length == kv_seq_len:\n            query_layer = index_first_axis(\n                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k\n            )\n            cu_seqlens_q = cu_seqlens_k\n            max_seqlen_in_batch_q = max_seqlen_in_batch_k\n            indices_q = indices_k\n        elif query_length == 1:\n            max_seqlen_in_batch_q = 1\n            cu_seqlens_q = torch.arange(\n                batch_size + 1, dtype=torch.int32, device=query_layer.device\n            )  # There is a memcpy here, that is very bad.\n            indices_q = cu_seqlens_q[:-1]\n            query_layer = query_layer.squeeze(1)\n        else:\n            # The -q_len: slice assumes left padding.\n            attention_mask = attention_mask[:, -query_length:]\n            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)\n\n        return (\n            query_layer,\n            key_layer,\n            value_layer,\n            indices_q,\n            (cu_seqlens_q, cu_seqlens_k),\n            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),\n        )\n\n\n# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Qwen2Moe\nclass Qwen2MoeSdpaAttention(Qwen2MoeAttention):\n    \"\"\"\n    Qwen2Moe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from\n    `Qwen2MoeAttention` as the weights of the module stays untouched. 
The only changes are on the forward pass to adapt to\n    SDPA API.\n    \"\"\"\n\n    # Adapted from Qwen2MoeAttention.forward\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        if output_attentions:\n            # TODO: Improve this warning with e.g. `model.config.attn_implementation = \"manual\"` once this is implemented.\n            logger.warning_once(\n                \"Qwen2MoeModel is using Qwen2MoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, \"\n                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. 
This warning can be removed using the argument `attn_implementation=\"eager\"` when loading the model.'\n            )\n            return super().forward(\n                hidden_states=hidden_states,\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n                past_key_value=past_key_value,\n                output_attentions=output_attentions,\n                use_cache=use_cache,\n            )\n\n        bsz, q_len, _ = hidden_states.size()\n\n        query_states = self.q_proj(hidden_states)\n        key_states = self.k_proj(hidden_states)\n        value_states = self.v_proj(hidden_states)\n\n        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n\n        kv_seq_len = key_states.shape[-2]\n        if past_key_value is not None:\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\n        cos, sin = self.rotary_emb(value_states, position_ids)\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n        if past_key_value is not None:\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n        causal_mask = attention_mask\n        if attention_mask is not None:  # no matter the length, we just slice it\n            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]\n\n        # SDPA with memory-efficient backend is currently 
(torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,\n        # Reference: https://github.com/pytorch/pytorch/issues/112577.\n        if query_states.device.type == \"cuda\" and attention_mask is not None:\n            query_states = query_states.contiguous()\n            key_states = key_states.contiguous()\n            value_states = value_states.contiguous()\n\n        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment\n        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.\n        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.\n        is_causal = True if causal_mask is None and q_len > 1 else False\n\n        attn_output = torch.nn.functional.scaled_dot_product_attention(\n            query_states,\n            key_states,\n            value_states,\n            attn_mask=causal_mask,\n            dropout_p=self.attention_dropout if self.training else 0.0,\n            is_causal=is_causal,\n        )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.view(bsz, q_len, self.hidden_size)\n\n        attn_output = self.o_proj(attn_output)\n\n        return attn_output, None, past_key_value\n\n\nQWEN2MOE_ATTENTION_CLASSES = {\n    \"eager\": Qwen2MoeAttention,\n    \"flash_attention_2\": Qwen2MoeFlashAttention2,\n    \"sdpa\": Qwen2MoeSdpaAttention,\n}\n\n\nclass Qwen2MoeSparseMoeBlock(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.num_experts = config.num_experts\n        self.top_k = config.num_experts_per_tok\n        self.norm_topk_prob = config.norm_topk_prob\n\n        # gating\n        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)\n        self.experts = 
nn.ModuleList(\n            [Qwen2MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]\n        )\n\n        self.shared_expert = Qwen2MoeMLP(config, intermediate_size=config.shared_expert_intermediate_size)\n        self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)\n\n    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"Route each token to its top-k experts, add the gated shared-expert output, and return the router logits.\"\"\"\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        hidden_states = hidden_states.view(-1, hidden_dim)\n        # router_logits: (batch * sequence_length, n_experts)\n        router_logits = self.gate(hidden_states)\n\n        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)\n        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)\n        if self.norm_topk_prob:\n            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)\n        # we cast back to the input dtype\n        routing_weights = routing_weights.to(hidden_states.dtype)\n\n        final_hidden_states = torch.zeros(\n            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device\n        )\n\n        # One-hot encode the selected experts to create an expert mask;\n        # this is used to easily index which tokens each expert has to process\n        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)\n\n        # Loop over all available experts in the model and perform the computation on each expert\n        for expert_idx in range(self.num_experts):\n            expert_layer = self.experts[expert_idx]\n            idx, top_x = torch.where(expert_mask[expert_idx])\n\n            # Index the correct hidden states and compute the expert hidden state for\n            # the current expert. 
We need to make sure to multiply the output hidden\n            # states by `routing_weights` on the corresponding tokens (one weight per selected top-k slot)\n            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)\n            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]\n\n            # `index_add_` only supports torch tensors for indexing, so we use\n            # the `top_x` tensor here.\n            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))\n\n        shared_expert_output = self.shared_expert(hidden_states)\n        shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output\n\n        final_hidden_states = final_hidden_states + shared_expert_output\n\n        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)\n        return final_hidden_states, router_logits\n\n\nclass Qwen2MoeDecoderLayer(nn.Module):\n    def __init__(self, config: Qwen2MoeConfig, layer_idx: int):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n\n        self.self_attn = QWEN2MOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)\n\n        if (layer_idx not in config.mlp_only_layers) and (\n            config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0\n        ):\n            self.mlp = Qwen2MoeSparseMoeBlock(config)\n        else:\n            self.mlp = Qwen2MoeMLP(config, intermediate_size=config.intermediate_size)\n\n        self.input_layernorm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_attention_layernorm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: 
Optional[Tuple[torch.Tensor]] = None,\n        output_attentions: Optional[bool] = False,\n        output_router_logits: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs,\n    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n                `(batch, sequence_length)` where padding elements are indicated by 0.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n            output_router_logits (`bool`, *optional*):\n                Whether or not to return the logits of all the routers. 
They are useful for computing the router loss,\n                and should not be returned during inference.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):\n                Indices depicting the position of the input sequence tokens in the sequence.\n            kwargs (`dict`, *optional*):\n                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code\n                into the model\n        \"\"\"\n\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_value=past_key_value,\n            output_attentions=output_attentions,\n            use_cache=use_cache,\n            cache_position=cache_position,\n        )\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n\n        hidden_states = self.mlp(hidden_states)\n        if isinstance(hidden_states, tuple):\n            hidden_states, router_logits = hidden_states\n        else:\n            router_logits = None\n\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (self_attn_weights,)\n\n        if use_cache:\n            outputs += (present_key_value,)\n\n        if output_router_logits:\n            
outputs += (router_logits,)\n\n        return outputs\n\n\nQWEN2MOE_START_DOCSTRING = r\"\"\"\n    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the\n    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads\n    etc.)\n\n    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage\n    and behavior.\n\n    Parameters:\n        config ([`Qwen2MoeConfig`]):\n            Model configuration class with all the parameters of the model. Initializing with a config file does not\n            load the weights associated with the model, only the configuration. Check out the\n            [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare Qwen2MoE Model outputting raw hidden-states without any specific head on top.\",\n    QWEN2MOE_START_DOCSTRING,\n)\nclass Qwen2MoePreTrainedModel(PreTrainedModel):\n    config_class = Qwen2MoeConfig\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"Qwen2MoeDecoderLayer\"]\n    _skip_keys_device_placement = \"past_key_values\"\n    _supports_flash_attn_2 = True\n    _supports_sdpa = True\n    _supports_cache_class = True\n    _supports_static_cache = True\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                
module.weight.data[module.padding_idx].zero_()\n\n\nQWEN2MOE_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see\n            `past_key_values`).\n\n            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]\n            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more\n            information on the default strategy.\n\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Indices of positions of each input sequence tokens in the position embeddings. 
Selected in the range `[0,\n            config.n_positions - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):\n            Pre-computed hidden-states (keys and values in the self-attention blocks) that can be used to speed up\n            sequential decoding. This typically consists of the `past_key_values` returned by the model at a previous\n            stage of decoding, when `use_cache=True` or `config.use_cache=True`.\n\n            Two formats are allowed:\n            - a [`~cache_utils.Cache`] instance;\n            - Tuple of `tuple(torch.FloatTensor)` of length `config.num_hidden_layers`, with each tuple having 2 tensors\n            of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. This is also known as the legacy\n            cache format.\n\n            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the\n            legacy cache format will be returned.\n\n            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't\n            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`\n            of shape `(batch_size, sequence_length)`.\n        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 
This\n            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n            model's internal embedding lookup matrix.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n            `past_key_values`).\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for\n            more detail.\n        output_router_logits (`bool`, *optional*):\n            Whether or not to return the logits of all the routers. They are useful for computing the router loss, and\n            should not be returned during inference.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):\n            Indices depicting the position of the input sequence tokens in the sequence. Contrary to `position_ids`,\n            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer\n            the complete sequence length.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare Qwen2MoE Model outputting raw hidden-states without any specific head on top.\",\n    QWEN2MOE_START_DOCSTRING,\n)\nclass Qwen2MoeModel(Qwen2MoePreTrainedModel):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Qwen2MoeDecoderLayer`]\n\n    Args:\n        config: Qwen2MoeConfig\n    \"\"\"\n\n    def __init__(self, config: Qwen2MoeConfig):\n        super().__init__(config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)\n        self.layers = nn.ModuleList(\n            [Qwen2MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]\n        )\n        self._attn_implementation = config._attn_implementation\n        self.norm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n        self.gradient_checkpointing = False\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        output_router_logits: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, MoeModelOutputWithPast]:\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_router_logits = (\n            output_router_logits if output_router_logits is not None else self.config.output_router_logits\n        )\n        
output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        if (input_ids is None) ^ (inputs_embeds is not None):\n            raise ValueError(\n                \"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one\"\n            )\n\n        if self.gradient_checkpointing and self.training:\n            if use_cache:\n                logger.warning_once(\n                    \"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\"\n                )\n                use_cache = False\n\n        use_legacy_cache = False\n        if use_cache and not isinstance(past_key_values, Cache):\n            use_legacy_cache = True\n            past_key_values = DynamicCache.from_legacy_cache(past_key_values)\n            logger.warning_once(\n                \"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. 
\"\n                \"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\"\n            )\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        if cache_position is None:\n            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n            cache_position = torch.arange(\n                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device\n            )\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0)\n\n        causal_mask = self._update_causal_mask(\n            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions\n        )\n\n        hidden_states = inputs_embeds\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        all_router_logits = () if output_router_logits else None\n        next_decoder_cache = None\n\n        for decoder_layer in self.layers:\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            if self.gradient_checkpointing and self.training:\n                layer_outputs = self._gradient_checkpointing_func(\n                    decoder_layer.__call__,\n                    hidden_states,\n                    causal_mask,\n                    position_ids,\n                    past_key_values,\n                    output_attentions,\n                    output_router_logits,\n                    use_cache,\n                    cache_position,\n                )\n            else:\n                layer_outputs = decoder_layer(\n                    hidden_states,\n                    attention_mask=causal_mask,\n                    position_ids=position_ids,\n                    
past_key_value=past_key_values,\n                    output_attentions=output_attentions,\n                    output_router_logits=output_router_logits,\n                    use_cache=use_cache,\n                    cache_position=cache_position,\n                )\n\n            hidden_states = layer_outputs[0]\n\n            if use_cache:\n                next_decoder_cache = layer_outputs[2 if output_attentions else 1]\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n            if output_router_logits and layer_outputs[-1] is not None:\n                all_router_logits += (layer_outputs[-1],)\n\n        hidden_states = self.norm(hidden_states)\n\n        # add hidden states from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        next_cache = None\n        if use_cache:\n            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache\n\n        if not return_dict:\n            return tuple(\n                v\n                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]\n                if v is not None\n            )\n        return MoeModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=next_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n            router_logits=all_router_logits,\n        )\n\n    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask\n    def _update_causal_mask(\n        self,\n        attention_mask: torch.Tensor,\n        input_tensor: torch.Tensor,\n        cache_position: torch.Tensor,\n        past_key_values: Cache,\n        output_attentions: bool,\n    ):\n        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static\n        # KV cache 
is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.\n        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using\n        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114\n\n        if self.config._attn_implementation == \"flash_attention_2\":\n            if attention_mask is not None and 0.0 in attention_mask:\n                return attention_mask\n            return None\n\n        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in\n        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail\n        # to infer the attention mask.\n        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n        using_static_cache = isinstance(past_key_values, StaticCache)\n\n        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward\n        if self.config._attn_implementation == \"sdpa\" and not using_static_cache and not output_attentions:\n            if AttentionMaskConverter._ignore_causal_mask_sdpa(\n                attention_mask,\n                inputs_embeds=input_tensor,\n                past_key_values_length=past_seen_tokens,\n                is_training=self.training,\n            ):\n                return None\n\n        dtype, device = input_tensor.dtype, input_tensor.device\n        min_dtype = torch.finfo(dtype).min\n        sequence_length = input_tensor.shape[1]\n        if using_static_cache:\n            target_length = past_key_values.get_max_length()\n        else:\n            target_length = (\n                attention_mask.shape[-1]\n                if isinstance(attention_mask, torch.Tensor)\n                else 
past_seen_tokens + sequence_length + 1\n            )\n\n        if attention_mask is not None and attention_mask.dim() == 4:\n            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing\n            if attention_mask.max() != 0:\n                raise ValueError(\"Custom 4D attention mask should be passed in inverted form with max==0`\")\n            causal_mask = attention_mask\n        else:\n            causal_mask = torch.full(\n                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device\n            )\n            if sequence_length != 1:\n                causal_mask = torch.triu(causal_mask, diagonal=1)\n            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)\n            causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)\n            if attention_mask is not None:\n                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit\n                mask_length = attention_mask.shape[-1]\n                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]\n                padding_mask = padding_mask == 0\n                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(\n                    padding_mask, min_dtype\n                )\n        if (\n            self.config._attn_implementation == \"sdpa\"\n            and attention_mask is not None\n            and attention_mask.device.type == \"cuda\"\n            and not output_attentions\n        ):\n            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when\n            # using left padding. 
This is required by F.scaled_dot_product_attention memory-efficient attention path.\n            # Details: https://github.com/pytorch/pytorch/issues/110213\n            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)\n\n        return causal_mask\n\n\nclass Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel):\n    _tied_weights_keys = [\"lm_head.weight\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = Qwen2MoeModel(config)\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        self.router_aux_loss_coef = config.router_aux_loss_coef\n        self.num_experts = config.num_experts\n        self.num_experts_per_tok = config.num_experts_per_tok\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def set_decoder(self, decoder):\n        self.model = decoder\n\n    def get_decoder(self):\n        return self.model\n\n    @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)\n    @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        
output_hidden_states: Optional[bool] = None,\n        output_router_logits: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, MoeCausalLMOutputWithPast]:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, Qwen2MoeForCausalLM\n\n        >>> model = Qwen2MoeForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)\n        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)\n\n        >>> prompt = \"Hey, are you conscious? Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you conscious? 
Can you talk to me?\\nI'm not conscious, but I can talk to you.\"\n        ```\"\"\"\n\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_router_logits = (\n            output_router_logits if output_router_logits is not None else self.config.output_router_logits\n        )\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            output_router_logits=output_router_logits,\n            return_dict=return_dict,\n            cache_position=cache_position,\n        )\n\n        hidden_states = outputs[0]\n        logits = self.lm_head(hidden_states)\n        logits = logits.float()\n\n        loss = None\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            shift_logits = shift_logits.view(-1, self.config.vocab_size)\n            shift_labels = shift_labels.view(-1)\n            # Enable model parallelism\n            shift_labels = shift_labels.to(shift_logits.device)\n            loss = loss_fct(shift_logits, shift_labels)\n\n        aux_loss = None\n        if output_router_logits:\n            
aux_loss = load_balancing_loss_func(\n                outputs.router_logits if return_dict else outputs[-1],\n                self.num_experts,\n                self.num_experts_per_tok,\n                attention_mask,\n            )\n            if labels is not None:\n                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            if output_router_logits:\n                output = (aux_loss,) + output\n            return (loss,) + output if loss is not None else output\n\n        return MoeCausalLMOutputWithPast(\n            loss=loss,\n            aux_loss=aux_loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n            router_logits=outputs.router_logits,\n        )\n\n    def prepare_inputs_for_generation(\n        self,\n        input_ids,\n        past_key_values=None,\n        attention_mask=None,\n        inputs_embeds=None,\n        cache_position=None,\n        use_cache=True,\n        **kwargs,\n    ):\n        past_length = 0\n        # Omit tokens covered by past_key_values\n        if past_key_values is not None:\n            if isinstance(past_key_values, Cache):\n                past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()\n                max_cache_length = (\n                    torch.tensor(past_key_values.get_max_length(), device=input_ids.device)\n                    if past_key_values.get_max_length() is not None\n                    else None\n                )\n                cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)\n            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects\n            else:\n                
cache_length = past_length = past_key_values[0][0].shape[2]\n                max_cache_length = None\n\n            # Keep only the unprocessed tokens:\n            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where\n            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as\n            # input)\n            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:\n                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]\n            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard\n            # input_ids based on the past_length.\n            elif past_length < input_ids.shape[1]:\n                input_ids = input_ids[:, past_length:]\n            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.\n\n            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.\n            if (\n                max_cache_length is not None\n                and attention_mask is not None\n                and cache_length + input_ids.shape[1] > max_cache_length\n            ):\n                attention_mask = attention_mask[:, -max_cache_length:]\n\n        position_ids = kwargs.get(\"position_ids\", None)\n        if attention_mask is not None and position_ids is None:\n            # create position_ids on the fly for batch generation\n            position_ids = attention_mask.long().cumsum(-1) - 1\n            position_ids.masked_fill_(attention_mask == 0, 1)\n            if past_key_values:\n                position_ids = position_ids[:, -input_ids.shape[1] :]\n\n        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step\n        if inputs_embeds is not None and past_length == 0:\n            model_inputs = {\"inputs_embeds\": 
inputs_embeds}\n        else:\n            model_inputs = {\"input_ids\": input_ids}\n\n        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]\n        if cache_position is None:\n            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)\n        elif use_cache:\n            cache_position = cache_position[-input_length:]\n\n        model_inputs.update(\n            {\n                \"position_ids\": position_ids,\n                \"past_key_values\": past_key_values,\n                \"use_cache\": use_cache,\n                \"attention_mask\": attention_mask,\n                \"cache_position\": cache_position,\n            }\n        )\n        return model_inputs\n\n    @staticmethod\n    def _reorder_cache(past_key_values, beam_idx):\n        reordered_past = ()\n        for layer_past in past_key_values:\n            reordered_past += (\n                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),\n            )\n        return reordered_past\n\n\n@add_start_docstrings(\n    \"\"\"\n    The Qwen2MoE Model transformer with a sequence classification head on top (linear layer).\n\n    [`Qwen2MoeForSequenceClassification`] uses the last token in order to do the classification, as other causal models\n    (e.g. GPT-2) do.\n\n    Since it does classification on the last token, it requires to know the position of the last token. If a\n    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If\n    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. 
Since it cannot guess the\n    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in\n    each row of the batch).\n    \"\"\",\n    QWEN2MOE_START_DOCSTRING,\n)\n# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Qwen2Moe, LLAMA->QWEN2MOE\nclass Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.model = Qwen2MoeModel(config)\n        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n            config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        transformer_outputs = self.model(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = transformer_outputs[0]\n        logits = self.score(hidden_states)\n\n        if input_ids is not None:\n            batch_size = input_ids.shape[0]\n        else:\n            batch_size = inputs_embeds.shape[0]\n\n        if self.config.pad_token_id is None and batch_size != 1:\n            raise ValueError(\"Cannot handle batch sizes > 1 if no padding token is defined.\")\n        if self.config.pad_token_id is None:\n            sequence_lengths = -1\n        else:\n            if input_ids is not None:\n                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility\n                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1\n                sequence_lengths = sequence_lengths % input_ids.shape[-1]\n                sequence_lengths = sequence_lengths.to(logits.device)\n            else:\n                sequence_lengths = -1\n\n        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]\n\n        loss = None\n        if labels is not None:\n            labels = labels.to(logits.device)\n            if self.config.problem_type is None:\n                if self.num_labels == 1:\n                    self.config.problem_type 
= \"regression\"\n                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):\n                    self.config.problem_type = \"single_label_classification\"\n                else:\n                    self.config.problem_type = \"multi_label_classification\"\n\n            if self.config.problem_type == \"regression\":\n                loss_fct = MSELoss()\n                if self.num_labels == 1:\n                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())\n                else:\n                    loss = loss_fct(pooled_logits, labels)\n            elif self.config.problem_type == \"single_label_classification\":\n                loss_fct = CrossEntropyLoss()\n                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))\n            elif self.config.problem_type == \"multi_label_classification\":\n                loss_fct = BCEWithLogitsLoss()\n                loss = loss_fct(pooled_logits, labels)\n        if not return_dict:\n            output = (pooled_logits,) + transformer_outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return SequenceClassifierOutputWithPast(\n            loss=loss,\n            logits=pooled_logits,\n            past_key_values=transformer_outputs.past_key_values,\n            hidden_states=transformer_outputs.hidden_states,\n            attentions=transformer_outputs.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    The Qwen2MoE Model transformer with a token classification head on top (a linear layer on top of the hidden-states\n    output) e.g. 
for Named-Entity-Recognition (NER) tasks.\n    \"\"\",\n    QWEN2MOE_START_DOCSTRING,\n)\n# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2Moe, LLAMA->QWEN2MOE\nclass Qwen2MoeForTokenClassification(Qwen2MoePreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.model = Qwen2MoeModel(config)\n        if getattr(config, \"classifier_dropout\", None) is not None:\n            classifier_dropout = config.classifier_dropout\n        elif getattr(config, \"hidden_dropout\", None) is not None:\n            classifier_dropout = config.hidden_dropout\n        else:\n            classifier_dropout = 0.1\n        self.dropout = nn.Dropout(classifier_dropout)\n        self.score = nn.Linear(config.hidden_size, config.num_labels)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, TokenClassifierOutput]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the sequence classification/regression loss. 
Indices should be in `[0, ...,\n            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        outputs = self.model(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        sequence_output = outputs[0]\n        sequence_output = self.dropout(sequence_output)\n        logits = self.score(sequence_output)\n\n        loss = None\n        if labels is not None:\n            loss_fct = CrossEntropyLoss()\n            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))\n\n        if not return_dict:\n            output = (logits,) + outputs[2:]\n            return ((loss,) + output) if loss is not None else output\n\n        return TokenClassifierOutput(\n            loss=loss,\n            logits=logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n"
  },
  {
    "path": "archive/ktransformers/models/modeling_qwen3_moe.py",
    "content": "#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨\n#           This file was automatically generated from src/transformers/models/qwen3_moe/modular_qwen3_moe.py.\n#               Do NOT edit this file manually as any edits will be overwritten by the generation of\n#             the file from the modular. If any change should be done, please apply the change to the\n#                          modular_qwen3_moe.py file directly. One of our CI enforces this.\n#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨\n# coding=utf-8\n# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Callable, List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom transformers.activations import ACT2FN\nfrom transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache\nfrom transformers.generation import GenerationMixin\nfrom transformers.modeling_attn_mask_utils import AttentionMaskConverter\n# from transformers.modeling_flash_attention_utils import FlashAttentionKwargs\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPast,\n    CausalLMOutputWithPast,\n    MoeCausalLMOutputWithPast,\n    MoeModelOutputWithPast,\n    QuestionAnsweringModelOutput,\n    SequenceClassifierOutputWithPast,\n    TokenClassifierOutput,\n)\nfrom 
transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS\n# from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel\nfrom transformers.modeling_utils import PreTrainedModel\n# from transformers.processing_utils import Unpack\nfrom transformers.utils import (\n    # LossKwargs,\n    add_code_sample_docstrings,\n    add_start_docstrings,\n    add_start_docstrings_to_model_forward,\n    logging,\n    replace_return_docstrings,\n)\nfrom transformers.utils.deprecation import deprecate_kwarg\nfrom .configuration_qwen3_moe import Qwen3MoeConfig\n\nfrom ktransformers.models.modeling_qwen2_moe import Qwen2MoeRotaryEmbedding\n\nlogger = logging.get_logger(__name__)\n\n_CHECKPOINT_FOR_DOC = \"Qwen/Qwen3-MoE-15B-A2B\"\n_CONFIG_FOR_DOC = \"Qwen3MoeConfig\"\n\n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n    Args:\n        q (`torch.Tensor`): The query tensor.\n        k (`torch.Tensor`): The key tensor.\n        cos (`torch.Tensor`): The cosine part of the rotary embedding.\n        sin (`torch.Tensor`): The sine part of the rotary embedding.\n        position_ids (`torch.Tensor`, *optional*):\n            Deprecated and unused.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. 
Then, if q and\n            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos.unsqueeze(unsqueeze_dim)\n    sin = sin.unsqueeze(unsqueeze_dim)\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\ndef eager_attention_forward(\n    module: nn.Module,\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    attention_mask: Optional[torch.Tensor],\n    scaling: float,\n    dropout: float = 0.0,\n    **kwargs,\n):\n    key_states = repeat_kv(key, module.num_key_value_groups)\n    value_states = repeat_kv(value, module.num_key_value_groups)\n\n    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling\n    if attention_mask is not None:\n        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]\n        attn_weights = attn_weights + causal_mask\n\n    attn_weights = nn.functional.softmax(attn_weights, dim=-1, 
dtype=torch.float32).to(query.dtype)\n    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)\n    attn_output = torch.matmul(attn_weights, value_states)\n    attn_output = attn_output.transpose(1, 2).contiguous()\n\n    return attn_output, attn_weights\n\n\nclass Qwen3MoeAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config: Qwen3MoeConfig, layer_idx: int):\n        super().__init__()\n        self.config = config\n        self.layer_idx = layer_idx\n        self.head_dim = getattr(config, \"head_dim\", config.hidden_size // config.num_attention_heads)\n        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads\n        self.num_heads = config.num_attention_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.max_position_embeddings = config.max_position_embeddings\n        self.rope_theta = config.rope_theta\n        self.scaling = self.head_dim**-0.5\n        self.attention_dropout = config.attention_dropout\n        self.is_causal = True\n\n        self.q_proj = nn.Linear(\n            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias\n        )\n        self.k_proj = nn.Linear(\n            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias\n        )\n        self.v_proj = nn.Linear(\n            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias\n        )\n        self.o_proj = nn.Linear(\n            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias\n        )\n        self.q_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!\n        self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # thus post q_norm does not need reshape\n\n        self.rotary_emb 
= Qwen2MoeRotaryEmbedding(\n            self.head_dim,\n            max_position_embeddings=self.max_position_embeddings,\n            base=self.rope_theta,\n        )\n\n        self.sliding_window = config.sliding_window\n        if not (\n            self.config.use_sliding_window\n            and getattr(self.config, \"sliding_window\", None) is not None\n            and self.layer_idx >= self.config.max_window_layers\n        ):\n            self.sliding_window = None\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_embeddings: Tuple[torch.Tensor, torch.Tensor],\n        attention_mask: Optional[torch.Tensor],\n        past_key_value: Optional[Cache] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        # **kwargs: Unpack[FlashAttentionKwargs],\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        input_shape = hidden_states.shape[:-1]\n        hidden_shape = (*input_shape, -1, self.head_dim)\n\n        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)\n        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)\n        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n\n        cos, sin = position_embeddings\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n        if past_key_value is not None:\n            # sin and cos are specific to RoPE models; cache_position needed for the static cache\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}\n            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n        attention_interface: Callable = eager_attention_forward\n        # if self.config._attn_implementation != \"eager\":\n        #     if self.config._attn_implementation == \"sdpa\" and 
kwargs.get(\"output_attentions\", False):\n        #         logger.warning_once(\n        #             \"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to \"\n        #             'eager attention. This warning can be removed using the argument `attn_implementation=\"eager\"` when loading the model.'\n        #         )\n        #     else:\n        #         attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]\n\n        attn_output, attn_weights = attention_interface(\n            self,\n            query_states,\n            key_states,\n            value_states,\n            attention_mask,\n            dropout=0.0 if not self.training else self.attention_dropout,\n            scaling=self.scaling,\n            sliding_window=self.sliding_window,  # diff with Llama\n            # **kwargs,\n        )\n\n        attn_output = attn_output.reshape(*input_shape, -1).contiguous()\n        attn_output = self.o_proj(attn_output)\n        return attn_output, attn_weights\n\n\nclass Qwen3MoeMLP(nn.Module):\n    def __init__(self, config, intermediate_size=None):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size\n        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n        return down_proj\n\n\nclass Qwen3MoeSparseMoeBlock(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.num_experts = config.num_experts\n  
      self.top_k = config.num_experts_per_tok\n        self.norm_topk_prob = config.norm_topk_prob\n\n        # gating\n        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)\n        self.experts = nn.ModuleList(\n            [Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"Route each token to its top-k experts and return (output hidden states, router logits).\"\"\"\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        hidden_states = hidden_states.view(-1, hidden_dim)\n        # router_logits: (batch * sequence_length, n_experts)\n        router_logits = self.gate(hidden_states)\n\n        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)\n        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)\n        if self.norm_topk_prob:  # only diff with mixtral sparse moe block!\n            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)\n        # we cast back to the input dtype\n        routing_weights = routing_weights.to(hidden_states.dtype)\n\n        final_hidden_states = torch.zeros(\n            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device\n        )\n\n        # One hot encode the selected experts to create an expert mask\n        # this will be used to easily index which expert is going to be solicited\n        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)\n\n        # Loop over all available experts in the model and perform the computation on each expert\n        for expert_idx in range(self.num_experts):\n            expert_layer = self.experts[expert_idx]\n            idx, top_x = torch.where(expert_mask[expert_idx])\n\n            # Index the correct hidden states and compute the expert hidden state for\n            # the current expert. 
We need to make sure to multiply the output hidden\n            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)\n            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)\n            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]\n\n            # However `index_add_` only supports torch tensors for indexing so we'll use\n            # the `top_x` tensor here.\n            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))\n        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)\n        return final_hidden_states, router_logits\n\n\nclass Qwen3MoeRMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        Qwen3MoeRMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n        self.hidden_size = hidden_size\n\n    def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n    def extra_repr(self):\n        return f\"{tuple(self.weight.shape)}, eps={self.variance_epsilon}\"\n\n\nclass Qwen3MoeDecoderLayer(nn.Module):\n    def __init__(self, config: Qwen3MoeConfig, layer_idx: int):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n\n        self.self_attn = Qwen3MoeAttention(config, layer_idx)\n\n        if (layer_idx not in config.mlp_only_layers) and (\n            config.num_experts > 0 and (layer_idx + 1) % 
config.decoder_sparse_step == 0\n        ):\n            self.mlp = Qwen3MoeSparseMoeBlock(config)\n        else:\n            self.mlp = Qwen3MoeMLP(config, intermediate_size=config.intermediate_size)\n\n        self.input_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_attention_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        output_attentions: Optional[bool] = False,\n        output_router_logits: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC\n        # **kwargs: Unpack[FlashAttentionKwargs],\n    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n                `(batch, sequence_length)` where padding elements are indicated by 0.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n            output_router_logits (`bool`, *optional*):\n                Whether or not to return the logits of all the routers. 
They are useful for computing the router loss,\n                and should not be returned during inference.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):\n                Indices depicting the position of the input sequence tokens in the sequence.\n            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):\n                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,\n                with `head_dim` being the embedding dimension of each attention head.\n            kwargs (`dict`, *optional*):\n                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code\n                into the model\n        \"\"\"\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states, self_attn_weights = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_value=past_key_value,\n            output_attentions=output_attentions,\n            use_cache=use_cache,\n            cache_position=cache_position,\n            position_embeddings=position_embeddings,\n        )\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n\n        hidden_states = self.mlp(hidden_states)\n        if isinstance(hidden_states, tuple):\n            hidden_states, router_logits = hidden_states\n        
else:\n            router_logits = None\n\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (self_attn_weights,)\n\n        if output_router_logits:\n            outputs += (router_logits,)\n\n        return outputs\n\n\ndef _compute_default_rope_parameters(\n    config: Optional[Qwen3MoeConfig] = None,\n    device: Optional[\"torch.device\"] = None,\n    seq_len: Optional[int] = None,\n    **rope_kwargs,\n) -> Tuple[\"torch.Tensor\", float]:\n    \"\"\"\n    Computes the inverse frequencies according to the original RoPE implementation\n    Args:\n        config ([`~transformers.PretrainedConfig`]):\n            The model configuration.\n        device (`torch.device`):\n            The device to use for initialization of the inverse frequencies.\n        seq_len (`int`, *optional*):\n            The current sequence length. Unused for this type of RoPE.\n        rope_kwargs (`Dict`, *optional*):\n            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.\n    Returns:\n        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the\n        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).\n    \"\"\"\n    if config is not None and len(rope_kwargs) > 0:\n        raise ValueError(\n            \"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in \"\n            f\"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}\"\n        )\n    if len(rope_kwargs) > 0:\n        base = rope_kwargs[\"base\"]\n        dim = rope_kwargs[\"dim\"]\n    elif config is not None:\n        base = config.rope_theta\n        partial_rotary_factor = config.partial_rotary_factor if hasattr(config, \"partial_rotary_factor\") else 1.0\n        dim = int(config.head_dim * partial_rotary_factor)\n\n    
attention_factor = 1.0  # Unused in this type of RoPE\n\n    # Compute the inverse frequencies\n    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))\n    return inv_freq, attention_factor\n\n\nclass Qwen3MoeRotaryEmbedding(nn.Module):\n    def __init__(self, config: Qwen3MoeConfig, device=None):\n        super().__init__()\n        # BC: \"rope_type\" was originally \"type\"\n        if hasattr(config, \"rope_scaling\") and config.rope_scaling is not None:\n            self.rope_type = config.rope_scaling.get(\"rope_type\", config.rope_scaling.get(\"type\"))\n        else:\n            self.rope_type = \"default\"\n        self.max_seq_len_cached = config.max_position_embeddings\n        self.original_max_seq_len = config.max_position_embeddings\n\n        self.config = config\n        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]\n\n        self.scaling_factor = 1.0\n        self.dim = config.head_dim\n        self.max_position_embeddings = config.max_position_embeddings\n        self.base = config.rope_theta\n\n        inv_freq, self.attention_scaling = _compute_default_rope_parameters(self.config, device)\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        self.original_inv_freq = self.inv_freq\n\n    def _dynamic_frequency_update(self, position_ids, device):\n        \"\"\"\n        dynamic RoPE layers should recompute `inv_freq` in the following situations:\n        1 - growing beyond the cached sequence length (allow scaling)\n        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)\n        \"\"\"\n        seq_len = torch.max(position_ids) + 1\n        if seq_len > self.max_seq_len_cached:  # growth\n            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)\n 
           self.register_buffer(\"inv_freq\", inv_freq, persistent=False)  # TODO joao: may break with compilation\n            self.max_seq_len_cached = seq_len\n\n        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset\n            # This .to() is needed if the model has been moved to a device after being initialized (because\n            # the buffer is automatically moved, but not the original copy)\n            self.original_inv_freq = self.original_inv_freq.to(device)\n            self.register_buffer(\"inv_freq\", self.original_inv_freq, persistent=False)\n            self.max_seq_len_cached = self.original_max_seq_len\n\n    @torch.no_grad()\n    def forward(self, x, position_ids):\n        if \"dynamic\" in self.rope_type:\n            self._dynamic_frequency_update(position_ids, device=x.device)\n\n        # Core RoPE block\n        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)\n        position_ids_expanded = position_ids[:, None, :].float()\n        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)\n        device_type = x.device.type\n        device_type = device_type if isinstance(device_type, str) and device_type != \"mps\" else \"cpu\"\n\n        with torch.autocast(device_type=device_type, enabled=False):\n            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos()\n            sin = emb.sin()\n\n        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention\n        cos = cos * self.attention_scaling\n        sin = sin * self.attention_scaling\n\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)\n\n\nQWEN3_MOE_START_DOCSTRING = r\"\"\"\n    This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the\n    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads\n    etc.)\n\n    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage\n    and behavior.\n\n    Parameters:\n        config ([`Qwen3MoeConfig`]):\n            Model configuration class with all the parameters of the model. Initializing with a config file does not\n            load the weights associated with the model, only the configuration. Check out the\n            [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare Qwen3Moe Model outputting raw hidden-states without any specific head on top.\",\n    QWEN3_MOE_START_DOCSTRING,\n)\nclass Qwen3MoePreTrainedModel(PreTrainedModel):\n    config_class = Qwen3MoeConfig\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"Qwen3MoeDecoderLayer\"]\n    _skip_keys_device_placement = [\"past_key_values\"]\n    _supports_flash_attn_2 = True\n    _supports_sdpa = True\n    _supports_flex_attn = True\n    _supports_cache_class = True\n    _supports_quantized_cache = True\n    _supports_static_cache = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)\n    _supports_attention_backend = True\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n 
               module.weight.data[module.padding_idx].zero_()\n\n\nQWEN3_MOE_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see\n            `past_key_values`).\n\n            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]\n            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more\n            information on the default strategy.\n\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Indices of positions of each input sequence tokens in the position embeddings. 
Selected in the range `[0,\n            config.n_positions - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n        past_key_values (`Cache`, *optional*):\n            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention\n            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`\n            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.\n\n            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).\n\n            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't\n            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`\n            of shape `(batch_size, sequence_length)`.\n        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This\n            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n            model's internal embedding lookup matrix.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n            `past_key_values`).\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors for\n            more detail.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):\n            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,\n            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer\n            the complete sequence length.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare Qwen3Moe Model outputting raw hidden-states without any specific head on top.\",\n    QWEN3_MOE_START_DOCSTRING,\n)\nclass Qwen3MoeModel(Qwen3MoePreTrainedModel):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen3MoeDecoderLayer`]\n\n    Args:\n        config: Qwen3MoeConfig\n    \"\"\"\n\n    def __init__(self, config: Qwen3MoeConfig):\n        super().__init__(config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)\n        self.layers = nn.ModuleList(\n            [Qwen3MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]\n        )\n        self.norm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config)\n        self.gradient_checkpointing = False\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        
attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        output_router_logits: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        # **flash_attn_kwargs: Unpack[FlashAttentionKwargs],\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_router_logits = (\n            output_router_logits if output_router_logits is not None else self.config.output_router_logits\n        )\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        if (input_ids is None) ^ (inputs_embeds is not None):\n            raise ValueError(\"You must specify exactly one of input_ids or inputs_embeds\")\n\n        if self.gradient_checkpointing and self.training:\n            if use_cache:\n                logger.warning_once(\n                    \"`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...\"\n                )\n                use_cache = False\n\n        if use_cache and past_key_values is None:\n            past_key_values = DynamicCache()\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        if cache_position is None:\n            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n            cache_position = torch.arange(\n                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device\n            )\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0)\n\n        causal_mask = self._update_causal_mask(\n            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions\n        )\n\n        hidden_states = inputs_embeds\n\n        # create position embeddings to be shared across the decoder layers\n        position_embeddings = self.rotary_emb(hidden_states, position_ids)\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        all_router_logits = () if output_router_logits else None\n\n        for decoder_layer in self.layers:\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            if self.gradient_checkpointing and self.training:\n                layer_outputs = self._gradient_checkpointing_func(\n                    decoder_layer.__call__,\n                    hidden_states,\n                    causal_mask,\n                    position_ids,\n                    past_key_values,\n                    output_attentions,\n                    output_router_logits,\n                    use_cache,\n                    cache_position,\n                    position_embeddings,\n                )\n            else:\n                layer_outputs = decoder_layer(\n            
        hidden_states,\n                    attention_mask=causal_mask,\n                    position_ids=position_ids,\n                    past_key_value=past_key_values,\n                    output_attentions=output_attentions,\n                    output_router_logits=output_router_logits,\n                    use_cache=use_cache,\n                    cache_position=cache_position,\n                    position_embeddings=position_embeddings,\n                    # **flash_attn_kwargs,\n                )\n\n            hidden_states = layer_outputs[0]\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n            if output_router_logits:\n                all_router_logits += (layer_outputs[-1],)\n\n        hidden_states = self.norm(hidden_states)\n\n        # add hidden states from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        output = MoeModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=past_key_values,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n            router_logits=all_router_logits,\n        )\n        return output if return_dict else output.to_tuple()\n\n    def _update_causal_mask(\n        self,\n        attention_mask: torch.Tensor,\n        input_tensor: torch.Tensor,\n        cache_position: torch.Tensor,\n        past_key_values: Cache,\n        output_attentions: bool = False,\n    ):\n        if self.config._attn_implementation == \"flash_attention_2\":\n            if attention_mask is not None and past_key_values is not None:\n                is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]\n                if is_padding_right:\n                    raise ValueError(\n                        \"You are attempting to perform batched generation with padding_side='right'\"\n                        \" this may 
lead to unexpected behaviour for Flash Attention version of Qwen3Moe. Make sure to \"\n                        \" call `tokenizer.padding_side  = 'left'` before tokenizing the input. \"\n                    )\n            if attention_mask is not None and 0.0 in attention_mask:\n                return attention_mask\n            return None\n\n        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in\n        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail\n        # to infer the attention mask.\n        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n        using_static_cache = isinstance(past_key_values, StaticCache)\n        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)\n\n        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward\n        if (\n            self.config._attn_implementation == \"sdpa\"\n            and not (using_static_cache or using_sliding_window_cache)\n            and not output_attentions\n        ):\n            if AttentionMaskConverter._ignore_causal_mask_sdpa(\n                attention_mask,\n                inputs_embeds=input_tensor,\n                past_key_values_length=past_seen_tokens,\n                sliding_window=self.config.sliding_window,\n                is_training=self.training,\n            ):\n                return None\n\n        dtype, device = input_tensor.dtype, input_tensor.device\n        min_dtype = torch.finfo(dtype).min\n        sequence_length = input_tensor.shape[1]\n        # SlidingWindowCache or StaticCache\n        if using_sliding_window_cache or using_static_cache:\n            target_length = past_key_values.get_max_cache_shape()\n        # DynamicCache or no cache\n        else:\n            target_length = (\n                
attention_mask.shape[-1]\n                if isinstance(attention_mask, torch.Tensor)\n                else past_seen_tokens + sequence_length + 1\n            )\n\n        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).\n        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(\n            attention_mask,\n            sequence_length=sequence_length,\n            target_length=target_length,\n            dtype=dtype,\n            device=device,\n            cache_position=cache_position,\n            batch_size=input_tensor.shape[0],\n            config=self.config,\n            past_key_values=past_key_values,\n        )\n\n        if (\n            self.config._attn_implementation == \"sdpa\"\n            and attention_mask is not None\n            and attention_mask.device.type in [\"cuda\", \"xpu\"]\n            and not output_attentions\n        ):\n            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when\n            # using left padding. 
This is required by F.scaled_dot_product_attention memory-efficient attention path.\n            # Details: https://github.com/pytorch/pytorch/issues/110213\n            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)\n\n        return causal_mask\n\n    @staticmethod\n    def _prepare_4d_causal_attention_mask_with_cache_position(\n        attention_mask: torch.Tensor,\n        sequence_length: int,\n        target_length: int,\n        dtype: torch.dtype,\n        device: torch.device,\n        cache_position: torch.Tensor,\n        batch_size: int,\n        config: Qwen3MoeConfig,\n        past_key_values: Cache,\n    ):\n        \"\"\"\n        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape\n        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.\n\n        Args:\n            attention_mask (`torch.Tensor`):\n                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.\n            sequence_length (`int`):\n                The sequence length being processed.\n            target_length (`int`):\n                The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.\n            dtype (`torch.dtype`):\n                The dtype to use for the 4D attention mask.\n            device (`torch.device`):\n                The device to place the 4D attention mask on.\n            cache_position (`torch.Tensor`):\n                Indices depicting the position of the input sequence tokens in the sequence.\n            batch_size (`int`):\n                Batch size.\n            config (`Qwen3MoeConfig`):\n                The model's configuration class\n            past_key_values (`Cache`):\n              
  The cache class that is being used currently to generate\n        \"\"\"\n        if attention_mask is not None and attention_mask.dim() == 4:\n            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.\n            causal_mask = attention_mask\n        else:\n            min_dtype = torch.finfo(dtype).min\n            causal_mask = torch.full(\n                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device\n            )\n            diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)\n            if config.sliding_window is not None:\n                # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also\n                # the check is needed to verify is current checkpoint was trained with sliding window or not\n                if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:\n                    sliding_attend_mask = torch.arange(target_length, device=device) <= (\n                        cache_position.reshape(-1, 1) - config.sliding_window\n                    )\n                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)\n            causal_mask *= diagonal_attend_mask\n            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)\n            if attention_mask is not None:\n                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit\n                if attention_mask.shape[-1] > target_length:\n                    attention_mask = attention_mask[:, :target_length]\n                mask_length = attention_mask.shape[-1]\n                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(\n                    causal_mask.device\n                )\n                padding_mask = padding_mask == 0\n           
     causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(\n                    padding_mask, min_dtype\n                )\n        return causal_mask\n\n\n# class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...\nclass KwargsForCausalLM(): ...\n\n\ndef load_balancing_loss_func(\n    gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],\n    num_experts: Optional[int] = None,\n    top_k=2,\n    attention_mask: Optional[torch.Tensor] = None,\n) -> Union[torch.Tensor, int]:\n    r\"\"\"\n    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.\n\n    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss\n    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between\n    experts is too unbalanced.\n\n    Args:\n        gate_logits:\n            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of\n            shape [batch_size X sequence_length, num_experts].\n        num_experts:\n            Number of experts\n        top_k:\n            The number of experts to route per-token, can be also interpreted as the `top-k` routing\n            parameter.\n        attention_mask (`torch.Tensor`, *optional*):\n            The attention_mask used in forward function\n            shape [batch_size X sequence_length] if not None.\n\n    Returns:\n        The auxiliary loss.\n    \"\"\"\n    if gate_logits is None or not isinstance(gate_logits, tuple):\n        return 0\n\n    if isinstance(gate_logits, tuple):\n        compute_device = gate_logits[0].device\n        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)\n\n    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)\n\n    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)\n\n    expert_mask = 
torch.nn.functional.one_hot(selected_experts, num_experts)\n\n    if attention_mask is None:\n        # Compute the percentage of tokens routed to each experts\n        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)\n\n        # Compute the average probability of routing to these experts\n        router_prob_per_expert = torch.mean(routing_weights, dim=0)\n    else:\n        batch_size, sequence_length = attention_mask.shape\n        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)\n\n        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask\n        expert_attention_mask = (\n            attention_mask[None, :, :, None, None]\n            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))\n            .reshape(-1, top_k, num_experts)\n            .to(compute_device)\n        )\n\n        # Compute the percentage of tokens routed to each experts\n        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(\n            expert_attention_mask, dim=0\n        )\n\n        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert\n        router_per_expert_attention_mask = (\n            attention_mask[None, :, :, None]\n            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))\n            .reshape(-1, num_experts)\n            .to(compute_device)\n        )\n\n        # Compute the average probability of routing to these experts\n        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(\n            router_per_expert_attention_mask, dim=0\n        )\n\n    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))\n    return overall_loss * num_experts\n\n\nclass Qwen3MoeForCausalLM(Qwen3MoePreTrainedModel, GenerationMixin):\n    _tied_weights_keys = 
[\"lm_head.weight\"]\n    _tp_plan = {\"lm_head\": \"colwise_rep\"}\n    _pp_plan = {\"lm_head\": ([\"hidden_states\"], [\"logits\"])}\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = Qwen3MoeModel(config)\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.router_aux_loss_coef = config.router_aux_loss_coef\n        self.num_experts = config.num_experts\n        self.num_experts_per_tok = config.num_experts_per_tok\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def set_decoder(self, decoder):\n        self.model = decoder\n\n    def get_decoder(self):\n        return self.model\n\n    @deprecate_kwarg(\"num_logits_to_keep\", version=\"4.50\", new_name=\"logits_to_keep\")\n    @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING)\n    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        output_router_logits: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        
cache_position: Optional[torch.LongTensor] = None,\n        logits_to_keep: Union[int, torch.Tensor] = 0,\n        # **kwargs: Unpack[KwargsForCausalLM],\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        r\"\"\"\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n            logits_to_keep (`int` or `torch.Tensor`, *optional*):\n                If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all\n                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that\n                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.\n                If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.\n                This is useful when using packed tensor format (single dimension for batch and sequence length).\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM\n\n        >>> model = Qwen3MoeForCausalLM.from_pretrained(\"Qwen/Qwen3-MoE-15B-A2B\")\n        >>> tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen3-MoE-15B-A2B\")\n\n        >>> prompt = \"Hey, are you conscious? 
Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you conscious? Can you talk to me?\\nI'm not conscious, but I can talk to you.\"\n        ```\"\"\"\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_router_logits = (\n            output_router_logits if output_router_logits is not None else self.config.output_router_logits\n        )\n\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            output_router_logits=output_router_logits,\n            return_dict=return_dict,\n            cache_position=cache_position,\n            # **kwargs,\n        )\n\n        hidden_states = outputs[0]\n        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss\n        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep\n        logits = self.lm_head(hidden_states[:, slice_indices, :])\n\n        loss = None\n        if labels is not None:\n            loss = 
self.loss_function(logits, labels, self.vocab_size)\n\n        aux_loss = None\n        if output_router_logits:\n            aux_loss = load_balancing_loss_func(\n                outputs.router_logits if return_dict else outputs[-1],\n                self.num_experts,\n                self.num_experts_per_tok,\n                attention_mask,\n            )\n            if labels is not None:\n                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            if output_router_logits:\n                output = (aux_loss,) + output\n            return (loss,) + output if loss is not None else output\n\n        return MoeCausalLMOutputWithPast(\n            loss=loss,\n            aux_loss=aux_loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n            router_logits=outputs.router_logits,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    The Qwen3Moe Model transformer with a sequence classification head on top (linear layer).\n\n    [`Qwen3MoeForSequenceClassification`] uses the last token in order to do the classification, as other causal models\n    (e.g. GPT-2) do.\n\n    Since it does classification on the last token, it requires to know the position of the last token. If a\n    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If\n    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. 
Since it cannot guess the\n    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in\n    each row of the batch).\n    \"\"\",\n    QWEN3_MOE_START_DOCSTRING,\n)\nclass Qwen3MoeForSequenceClassification(Qwen3MoePreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.model = Qwen3MoeModel(config)\n        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n            config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        transformer_outputs = self.model(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = transformer_outputs[0]\n        logits = self.score(hidden_states)\n\n        if input_ids is not None:\n            batch_size = input_ids.shape[0]\n        else:\n            batch_size = inputs_embeds.shape[0]\n\n        if self.config.pad_token_id is None and batch_size != 1:\n            raise ValueError(\"Cannot handle batch sizes > 1 if no padding token is defined.\")\n        if self.config.pad_token_id is None:\n            last_non_pad_token = -1\n        elif input_ids is not None:\n            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id\n            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)\n            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)\n            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)\n        else:\n            last_non_pad_token = -1\n            logger.warning_once(\n                f\"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be \"\n                \"unexpected if using padding tokens in conjunction with `inputs_embeds.`\"\n            )\n\n        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]\n\n        loss = None\n        if labels is not None:\n            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)\n\n        if not return_dict:\n            output = (pooled_logits,) + transformer_outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return SequenceClassifierOutputWithPast(\n            loss=loss,\n            logits=pooled_logits,\n            past_key_values=transformer_outputs.past_key_values,\n            hidden_states=transformer_outputs.hidden_states,\n            attentions=transformer_outputs.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    The Qwen3Moe Model transformer with a token classification head on top (a linear layer on top of the hidden-states\n    output) e.g. 
for Named-Entity-Recognition (NER) tasks.\n    \"\"\",\n    QWEN3_MOE_START_DOCSTRING,\n)\nclass Qwen3MoeForTokenClassification(Qwen3MoePreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.model = Qwen3MoeModel(config)\n        if getattr(config, \"classifier_dropout\", None) is not None:\n            classifier_dropout = config.classifier_dropout\n        elif getattr(config, \"hidden_dropout\", None) is not None:\n            classifier_dropout = config.hidden_dropout\n        else:\n            classifier_dropout = 0.1\n        self.dropout = nn.Dropout(classifier_dropout)\n        self.score = nn.Linear(config.hidden_size, config.num_labels)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING)\n    @add_code_sample_docstrings(\n        checkpoint=_CHECKPOINT_FOR_DOC,\n        output_type=TokenClassifierOutput,\n        config_class=_CONFIG_FOR_DOC,\n    )\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, TokenClassifierOutput]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the sequence 
classification/regression loss. Indices should be in `[0, ...,\n            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        outputs = self.model(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        sequence_output = outputs[0]\n        sequence_output = self.dropout(sequence_output)\n        logits = self.score(sequence_output)\n\n        loss = None\n        if labels is not None:\n            loss = self.loss_function(logits, labels, self.config)\n\n        if not return_dict:\n            output = (logits,) + outputs[2:]\n            return ((loss,) + output) if loss is not None else output\n\n        return TokenClassifierOutput(\n            loss=loss,\n            logits=logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\nThe Qwen3Moe Model transformer with a span classification head on top for extractive question-answering tasks like\nSQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).\n    \"\"\",\n    QWEN3_MOE_START_DOCSTRING,\n)\nclass Qwen3MoeForQuestionAnswering(Qwen3MoePreTrainedModel):\n    base_model_prefix = \"transformer\"\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.transformer = Qwen3MoeModel(config)\n        self.qa_outputs = 
nn.Linear(config.hidden_size, 2)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.transformer.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.transformer.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        start_positions: Optional[torch.LongTensor] = None,\n        end_positions: Optional[torch.LongTensor] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        **kwargs,\n    ) -> Union[Tuple, QuestionAnsweringModelOutput]:\n        r\"\"\"\n        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for position (index) of the start of the labelled span for computing the token classification loss.\n            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence\n            are not taken into account for computing the loss.\n        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for position (index) of the end of the labelled span for computing the token classification loss.\n            Positions are clamped to the length of the sequence (`sequence_length`). 
Position outside of the sequence\n            are not taken into account for computing the loss.\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        outputs = self.transformer(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        sequence_output = outputs[0]\n\n        logits = self.qa_outputs(sequence_output)\n        start_logits, end_logits = logits.split(1, dim=-1)\n        start_logits = start_logits.squeeze(-1).contiguous()\n        end_logits = end_logits.squeeze(-1).contiguous()\n\n        loss = None\n        if start_positions is not None and end_positions is not None:\n            loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)\n\n        if not return_dict:\n            output = (start_logits, end_logits) + outputs[2:]\n            return ((loss,) + output) if loss is not None else output\n\n        return QuestionAnsweringModelOutput(\n            loss=loss,\n            start_logits=start_logits,\n            end_logits=end_logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\n__all__ = [\n    \"Qwen3MoeForCausalLM\",\n    \"Qwen3MoeForQuestionAnswering\",\n    \"Qwen3MoeModel\",\n    \"Qwen3MoePreTrainedModel\",\n    \"Qwen3MoeForSequenceClassification\",\n    \"Qwen3MoeForTokenClassification\",\n]"
  },
  {
    "path": "archive/ktransformers/models/modeling_qwen3_next.py",
    "content": "#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨\n#           This file was automatically generated from src/transformers/models/qwen3_next/modular_qwen3_next.py.\n#               Do NOT edit this file manually as any edits will be overwritten by the generation of\n#             the file from the modular. If any change should be done, please apply the change to the\n#                          modular_qwen3_next.py file directly. One of our CI enforces this.\n#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨\n# coding=utf-8\n# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Any, Callable, Optional, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom transformers.activations import ACT2FN\nfrom transformers.cache_utils import Cache\nfrom transformers.generation import GenerationMixin\nfrom transformers.masking_utils import create_causal_mask\nfrom transformers.modeling_flash_attention_utils import FlashAttentionKwargs\nfrom transformers.modeling_layers import (\n    GenericForQuestionAnswering,\n    GenericForSequenceClassification,\n    GenericForTokenClassification,\n    GradientCheckpointingLayer,\n)\nfrom transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast\nfrom transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, 
dynamic_rope_update\nfrom transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel\nfrom transformers.processing_utils import Unpack\nfrom transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging\nfrom transformers.utils.deprecation import deprecate_kwarg\nfrom transformers.utils.generic import OutputRecorder, check_model_inputs\ntry:\n    from transformers.utils.import_utils import (\n        is_causal_conv1d_available,\n        is_flash_linear_attention_available,\n    )\nexcept ImportError:\n    is_causal_conv1d_available = lambda: False\n\n\ntry:\n    from transformers.utils.import_utils import (\n        is_flash_linear_attention_available,\n    )\nexcept ImportError:\n    is_flash_linear_attention_available = lambda: False\n\n\nfrom .configuration_qwen3_next import Qwen3NextConfig\n\n\nif is_causal_conv1d_available():\n    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update\nelse:\n    causal_conv1d_update, causal_conv1d_fn = None, None\n\n\nif is_flash_linear_attention_available():\n    from fla.modules import FusedRMSNormGated\n    from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule\nelse:\n    chunk_gated_delta_rule, fused_recurrent_gated_delta_rule = None, None\n    FusedRMSNormGated = None\n\nlogger = logging.get_logger(__name__)\n\n\nclass Qwen3NextRMSNormGated(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6, **kwargs):\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states, gate=None):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        # Norm before gate\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        hidden_states = self.weight * hidden_states.to(input_dtype)\n    
    hidden_states = hidden_states * F.silu(gate.to(torch.float32))\n\n        return hidden_states.to(input_dtype)\n\n\nclass Qwen3NextDynamicCache:\n    \"\"\"\n    A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the linear attention\n    cache (which has a constant shape regardless of seq_len).\n\n    This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`\n    and `recurrent_states` for gated deltanet cache. Each of these lists has `num_layers` entries. The expected shape\n    of each entry depends on the layer type.\n    For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,\n    while `conv_states` and `recurrent_states` are left as `None`.\n    For linear attention layers, `key_cache` and `value_cache` are left as `None`,\n    while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,\n    and `recurrent_states` represents the recurrent state and has a shape of `(batch_size, d_inner, d_state)`.\n    \"\"\"\n\n    is_compileable = False\n\n    def __init__(self, config: Qwen3NextConfig):\n        super().__init__()\n        self.layer_types = config.layer_types\n        self.transformer_layers = [\n            i for i in range(config.num_hidden_layers) if self.layer_types[i] == \"full_attention\"\n        ]\n        self.last_linear_layer = len(self.layer_types) - 1 - self.layer_types[::-1].index(\"linear_attention\")\n\n        # Initialize everything to None -> will be lazy initialized to allow multi-gpu (device_map) inference\n        self.conv_states = [None for _ in range(config.num_hidden_layers)]\n        self.recurrent_states = [None for _ in range(config.num_hidden_layers)]\n        self.key_cache = [None for _ in range(config.num_hidden_layers)]\n        self.value_cache = [None for _ in range(config.num_hidden_layers)]\n\n    
def __len__(self):\n        return len(self.layer_types)\n\n    def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:\n        return self.key_cache[layer_idx], self.value_cache[layer_idx]\n\n    def update(\n        self,\n        key_states: torch.Tensor,\n        value_states: torch.Tensor,\n        layer_idx: int,\n        cache_kwargs: Optional[dict[str, Any]] = None,\n    ) -> tuple[torch.Tensor, torch.Tensor]:\n        if self.key_cache[layer_idx] is None:\n            self.key_cache[layer_idx] = key_states\n            self.value_cache[layer_idx] = value_states\n        else:\n            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)\n            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)\n\n        return self.key_cache[layer_idx], self.value_cache[layer_idx]\n\n    def reorder_cache(self, beam_idx: torch.LongTensor):\n        \"\"\"Reorders the cache for beam search, given the selected beam indices.\"\"\"\n        for layer_idx in range(len(self.key_cache)):\n            if self.key_cache[layer_idx] is not None:\n                device = self.key_cache[layer_idx].device\n                beam_idx = beam_idx.to(device)\n                self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx)\n                self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx)\n\n            if self.conv_states[layer_idx] is not None:\n                device = self.conv_states[layer_idx].device\n                beam_idx = beam_idx.to(device)\n                self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx)\n                self.recurrent_states[layer_idx] = self.recurrent_states[layer_idx].index_select(0, beam_idx)\n\n    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:\n        \"\"\"Returns the sequence length of the cached states. 
A layer index can be optionally passed.\"\"\"\n        # take any layer that contains cache and not empty tensor\n        layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx\n        if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None:\n            return 0\n        return self.key_cache[layer_idx].shape[-2]\n\n    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:\n        \"\"\"\n        Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for\n        the given layer at `layer_idx`.\n        The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer.\n        \"\"\"\n        kv_offset = 0\n        query_length = cache_position.shape[0]\n        past_seen_tokens = self.get_seq_length(layer_idx)\n        kv_length = query_length + past_seen_tokens\n        return kv_length, kv_offset\n\n    @property\n    def has_previous_state(self):\n        \"\"\"We have a previous state if the last linear (conv) layer was already updated.\"\"\"\n        return self.conv_states[self.last_linear_layer] is not None\n\n\nclass Qwen3NextRotaryEmbedding(nn.Module):\n    inv_freq: torch.Tensor  # fix linting for `register_buffer`\n\n    def __init__(self, config: Qwen3NextConfig, device=None):\n        super().__init__()\n        # BC: \"rope_type\" was originally \"type\"\n        if hasattr(config, \"rope_scaling\") and isinstance(config.rope_scaling, dict):\n            self.rope_type = config.rope_scaling.get(\"rope_type\", config.rope_scaling.get(\"type\"))\n        else:\n            self.rope_type = \"default\"\n        self.max_seq_len_cached = config.max_position_embeddings\n        self.original_max_seq_len = config.max_position_embeddings\n\n        self.config = config\n        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]\n\n        inv_freq, 
self.attention_scaling = self.rope_init_fn(self.config, device)\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        self.original_inv_freq = self.inv_freq\n\n    @torch.no_grad()\n    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)\n    def forward(self, x, position_ids):\n        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)\n        position_ids_expanded = position_ids[:, None, :].float()\n\n        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != \"mps\" else \"cpu\"\n        with torch.autocast(device_type=device_type, enabled=False):  # Force float32\n            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos() * self.attention_scaling\n            sin = emb.sin() * self.attention_scaling\n\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)\n\n\nclass Qwen3NextRMSNorm(nn.Module):\n    def __init__(self, dim: int, eps: float = 1e-6):\n        super().__init__()\n        self.hidden_size = dim\n        self.variance_epsilon = eps\n        self.eps = eps\n        self.weight = nn.Parameter(torch.zeros(dim))\n\n    def _norm(self, x):\n        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)\n\n    def forward(self, x):\n        output = self._norm(x.float())\n        # Llama does x.to(float16) * w whilst Qwen3Next is (x * w).to(float16)\n        # See https://github.com/huggingface/transformers/pull/29402\n        output = output * (1.0 + self.weight.float())\n        return output.type_as(x)\n\n    def extra_repr(self):\n        return f\"{tuple(self.weight.shape)}, eps={self.eps}\"\n\n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return 
torch.cat((-x2, x1), dim=-1)\n\n\n# Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n    Removes the interleaving of cos and sin from GLM\n\n    Args:\n        q (`torch.Tensor`): The query tensor.\n        k (`torch.Tensor`): The key tensor.\n        cos (`torch.Tensor`): The cosine part of the rotary embedding.\n        sin (`torch.Tensor`): The sine part of the rotary embedding.\n        position_ids (`torch.Tensor`, *optional*):\n            Deprecated and unused.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. 
Similarly, if q and k have\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos.unsqueeze(unsqueeze_dim)\n    sin = sin.unsqueeze(unsqueeze_dim)\n\n    # Keep half or full tensor for later concatenation\n    rotary_dim = cos.shape[-1]\n    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]\n    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]\n\n    # Apply rotary embeddings on the first half or full tensor\n    q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)\n    k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)\n\n    # Concatenate back to full shape\n    q_embed = torch.cat([q_embed, q_pass], dim=-1)\n    k_embed = torch.cat([k_embed, k_pass], dim=-1)\n    return q_embed, k_embed\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\ndef eager_attention_forward(\n    module: nn.Module,\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    attention_mask: Optional[torch.Tensor],\n    scaling: float,\n    dropout: float = 0.0,\n    **kwargs: Unpack[TransformersKwargs],\n):\n    key_states = repeat_kv(key, module.num_key_value_groups)\n    value_states = repeat_kv(value, module.num_key_value_groups)\n\n    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling\n    if attention_mask is not None:\n        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]\n        attn_weights = attn_weights + causal_mask\n\n    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)\n    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)\n    attn_output = torch.matmul(attn_weights, value_states)\n    attn_output = attn_output.transpose(1, 2).contiguous()\n\n    return attn_output, attn_weights\n\n\nclass Qwen3NextAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config: Qwen3NextConfig, layer_idx: int):\n        super().__init__()\n        self.config = config\n        self.layer_idx = layer_idx\n        self.head_dim = getattr(config, \"head_dim\", config.hidden_size // config.num_attention_heads)\n        self.num_attention_heads = config.num_attention_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads\n        self.scaling = self.head_dim**-0.5\n        self.attention_dropout = config.attention_dropout\n        self.is_causal = True\n        self.q_proj = nn.Linear(\n            config.hidden_size, config.num_attention_heads * self.head_dim * 2, bias=config.attention_bias\n        )\n        self.k_proj = nn.Linear(\n            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias\n        )\n        self.v_proj = nn.Linear(\n            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias\n        )\n        self.o_proj = nn.Linear(\n            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias\n        )\n        self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!\n        self.k_norm = Qwen3NextRMSNorm(\n            self.head_dim, eps=config.rms_norm_eps\n        )  # thus post q_norm does not need reshape\n\n    @deprecate_kwarg(\"past_key_value\", new_name=\"past_key_values\", version=\"4.58\")\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_embeddings: tuple[torch.Tensor, torch.Tensor],\n        attention_mask: Optional[torch.Tensor],\n        past_key_values: Optional[Cache] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs: Unpack[FlashAttentionKwargs],\n    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:\n        input_shape = hidden_states.shape[:-1]\n        hidden_shape = (*input_shape, -1, self.head_dim)\n\n        query_states, gate = torch.chunk(\n            self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1\n        )\n        gate = gate.reshape(*input_shape, -1)\n\n        query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)\n        key_states = 
self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)\n        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n\n        cos, sin = position_embeddings\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n        if past_key_values is not None:\n            # sin and cos are specific to RoPE models; cache_position needed for the static cache\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}\n            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n        attention_interface: Callable = eager_attention_forward\n        if self.config._attn_implementation != \"eager\":\n            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]\n\n        attn_output, attn_weights = attention_interface(\n            self,\n            query_states,\n            key_states,\n            value_states,\n            attention_mask,\n            dropout=0.0 if not self.training else self.attention_dropout,\n            scaling=self.scaling,\n            **kwargs,\n        )\n\n        attn_output = attn_output.reshape(*input_shape, -1).contiguous()\n        attn_output = attn_output * torch.sigmoid(gate)\n\n        attn_output = self.o_proj(attn_output)\n        return attn_output, attn_weights\n\n\ndef apply_mask_to_padding_states(hidden_states, attention_mask):\n    \"\"\"\n    Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66\n    \"\"\"\n    if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:\n        dtype = hidden_states.dtype\n        hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)\n\n    return hidden_states\n\n\nis_fast_path_available = all(\n    (causal_conv1d_fn, causal_conv1d_update, chunk_gated_delta_rule, 
fused_recurrent_gated_delta_rule)\n)\n\n\ndef torch_causal_conv1d_update(\n    hidden_states,\n    conv_state,\n    weight,\n    bias=None,\n    activation=None,\n):\n    _, hidden_size, seq_len = hidden_states.shape\n    state_len = conv_state.shape[-1]\n\n    hidden_states_new = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype)\n    conv_state.copy_(hidden_states_new[:, :, -state_len:])\n    out = F.conv1d(hidden_states_new, weight.unsqueeze(1), bias, padding=0, groups=hidden_size)\n    out = F.silu(out[:, :, -seq_len:])\n    out = out.to(hidden_states.dtype)\n    return out\n\n\ndef torch_chunk_gated_delta_rule(\n    query,\n    key,\n    value,\n    g,\n    beta,\n    chunk_size=64,\n    initial_state=None,\n    output_final_state=False,\n    use_qk_l2norm_in_kernel=False,\n):\n    initial_dtype = query.dtype\n    if use_qk_l2norm_in_kernel:\n        query = F.normalize(query, p=2, dim=-1)\n        key = F.normalize(key, p=2, dim=-1)\n    query, key, value, beta, g = [\n        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)\n    ]\n\n    batch_size, sequence_length, num_heads, k_head_dim = key.shape\n    v_head_dim = value.shape[-1]\n    pad_size = (chunk_size - num_heads % chunk_size) % chunk_size\n    query = F.pad(query, (0, 0, 0, pad_size))\n    key = F.pad(key, (0, 0, 0, pad_size))\n    value = F.pad(value, (0, 0, 0, pad_size))\n    beta = F.pad(beta, (0, pad_size))\n    g = F.pad(g, (0, pad_size))\n    tot_heads = num_heads + pad_size\n    scale = 1 / (query.shape[-1] ** 0.5)\n    query = query * scale\n\n    v_beta = value * beta.unsqueeze(-1)\n    k_beta = key * beta.unsqueeze(-1)\n    # reshape to chunks\n    query, key, value, k_beta, v_beta = [\n        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)\n    ]\n    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)\n    mask = torch.triu(torch.ones(chunk_size, chunk_size, 
dtype=torch.bool, device=query.device), diagonal=0)\n\n    # chunk decay\n    g = g.cumsum(dim=-1)\n    decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()\n    attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)\n    for i in range(1, chunk_size):\n        row = attn[..., i, :i].clone()\n        sub = attn[..., :i, :i].clone()\n        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)\n    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)\n    value = attn @ v_beta\n    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))\n    last_recurrent_state = (\n        torch.zeros(batch_size, sequence_length, k_head_dim, v_head_dim).to(value)\n        if initial_state is None\n        else initial_state.to(value)\n    )\n    core_attn_out = torch.zeros_like(value)\n    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)\n\n    # for each chunk\n    for i in range(0, tot_heads // chunk_size):\n        q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]\n        attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)\n        v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state\n        v_new = v_i - v_prime\n        attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state\n        core_attn_out[:, :, i] = attn_inter + attn @ v_new\n        last_recurrent_state = (\n            last_recurrent_state * g[:, :, i, -1, None, None].exp()\n            + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new\n        )\n\n    if not output_final_state:\n        last_recurrent_state = None\n    core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])\n    core_attn_out = core_attn_out[:, :, :num_heads]\n    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)\n    return core_attn_out, 
last_recurrent_state\n\n\ndef torch_recurrent_gated_delta_rule(\n    query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False\n):\n    initial_dtype = query.dtype\n    if use_qk_l2norm_in_kernel:\n        query = F.normalize(query, p=2, dim=-1)\n        key = F.normalize(key, p=2, dim=-1)\n    query, key, value, beta, g = [\n        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)\n    ]\n\n    batch_size, sequence_length, num_heads, k_head_dim = key.shape\n    v_head_dim = value.shape[-1]\n    scale = 1 / (query.shape[-1] ** 0.5)\n    query = query * scale\n\n    core_attn_out = torch.zeros(batch_size, sequence_length, num_heads, v_head_dim).to(value)\n    last_recurrent_state = (\n        torch.zeros(batch_size, sequence_length, k_head_dim, v_head_dim).to(value)\n        if initial_state is None\n        else initial_state.to(value)\n    )\n\n    for i in range(num_heads):\n        q_t = query[:, :, i]\n        k_t = key[:, :, i]\n        v_t = value[:, :, i]\n        g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)\n        beta_t = beta[:, :, i].unsqueeze(-1)\n\n        last_recurrent_state = last_recurrent_state * g_t\n        kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)\n        delta = (v_t - kv_mem) * beta_t\n        last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)\n        core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)\n\n    if not output_final_state:\n        last_recurrent_state = None\n    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)\n    return core_attn_out, last_recurrent_state\n\n\nclass Qwen3NextGatedDeltaNet(nn.Module):\n    def __init__(self, config: Qwen3NextConfig, layer_idx: int):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n        self.num_v_heads = config.linear_num_value_heads\n        self.num_k_heads = 
config.linear_num_key_heads\n        self.head_k_dim = config.linear_key_head_dim\n        self.head_v_dim = config.linear_value_head_dim\n        self.key_dim = self.head_k_dim * self.num_k_heads\n        self.value_dim = self.head_v_dim * self.num_v_heads\n\n        self.conv_kernel_size = config.linear_conv_kernel_dim\n        self.layer_idx = layer_idx\n        self.activation = config.hidden_act\n        self.act = ACT2FN[config.hidden_act]\n        self.layer_norm_epsilon = config.rms_norm_eps\n        \n        self.config = config\n\n        # QKV\n        self.conv_dim = self.key_dim * 2 + self.value_dim\n        self.conv1d = nn.Conv1d(\n            in_channels=self.conv_dim,\n            out_channels=self.conv_dim,\n            bias=False,\n            kernel_size=self.conv_kernel_size,\n            groups=self.conv_dim,\n            padding=self.conv_kernel_size - 1,\n        )\n\n        # projection of the input hidden states\n        projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2\n        projection_size_ba = self.num_v_heads * 2\n        self.in_proj_qkvz = nn.Linear(self.hidden_size, projection_size_qkvz, bias=False)\n        self.in_proj_ba = nn.Linear(self.hidden_size, projection_size_ba, bias=False)\n\n        # time step projection (discretization)\n        # instantiate once and copy inv_dt in init_weights of PretrainedModel\n        self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads))\n\n        A = torch.empty(self.num_v_heads).uniform_(0, 16)\n        self.A_log = nn.Parameter(torch.log(A))\n\n        self.norm = (\n            Qwen3NextRMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon)\n            if FusedRMSNormGated is None\n            else FusedRMSNormGated(\n                self.head_v_dim,\n                eps=self.layer_norm_epsilon,\n                activation=self.activation,\n                device=torch.cuda.current_device(),\n                dtype=config.dtype if config.dtype is not None else 
torch.get_default_dtype(),\n            )\n        )\n\n        self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)\n\n        self.causal_conv1d_fn = causal_conv1d_fn\n        self.causal_conv1d_update = causal_conv1d_update or torch_causal_conv1d_update\n        self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule\n        self.recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule or torch_recurrent_gated_delta_rule\n\n        if not is_fast_path_available:\n            logger.warning_once(\n                \"The fast path is not available because one of the required libraries is not installed. Falling back to \"\n                \"torch implementation. To install follow https://github.com/fla-org/flash-linear-attention#installation and\"\n                \" https://github.com/Dao-AILab/causal-conv1d\"\n            )\n\n    def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba):\n        \"\"\"\n        Derives `query`, `key`, `value`, `z`, `b` and `a` tensors from `mixed_qkvz` and `mixed_ba`.\n        \"\"\"\n\n        new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + (\n            self.num_k_heads,\n            2 * self.head_k_dim + 2 * self.head_v_dim * self.num_v_heads // self.num_k_heads,\n        )\n        new_tensor_shape_ba = mixed_ba.size()[:-1] + (self.num_k_heads, 2 * self.num_v_heads // self.num_k_heads)\n\n        mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz)\n        mixed_ba = mixed_ba.view(*new_tensor_shape_ba)\n        split_arg_list_qkvz = [\n            self.head_k_dim,\n            self.head_k_dim,\n            (self.num_v_heads // self.num_k_heads * self.head_v_dim),\n            (self.num_v_heads // self.num_k_heads * self.head_v_dim),\n        ]\n        split_arg_list_ba = [self.num_v_heads // self.num_k_heads, self.num_v_heads // self.num_k_heads]\n        query, key, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=3)\n        b, a = torch.split(mixed_ba, 
split_arg_list_ba, dim=3)\n        # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn]\n        value = value.reshape(value.size(0), value.size(1), -1, self.head_v_dim)\n        z = z.reshape(z.size(0), z.size(1), -1, self.head_v_dim)\n        b = b.reshape(b.size(0), b.size(1), self.num_v_heads)\n        a = a.reshape(a.size(0), a.size(1), self.num_v_heads)\n        return query, key, value, z, b, a\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        cache_params: Optional[Qwen3NextDynamicCache] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n    ):\n        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)\n\n        # Set up dimensions for reshapes later\n        batch_size, seq_len, _ = hidden_states.shape\n\n        use_precomputed_states = (\n            cache_params is not None\n            and cache_params.has_previous_state\n            and seq_len == 1\n            and cache_position is not None\n        )\n\n        # getting projected states from cache if it exists\n        if cache_params is not None:\n            conv_state = cache_params.conv_states[self.layer_idx]\n            recurrent_state = cache_params.recurrent_states[self.layer_idx]\n\n        projected_states_qkvz = self.in_proj_qkvz(hidden_states)\n        projected_states_ba = self.in_proj_ba(hidden_states)\n        query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba)\n        query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value))\n\n        mixed_qkv = torch.cat((query, key, value), dim=-1)\n        mixed_qkv = mixed_qkv.transpose(1, 2)\n\n        if use_precomputed_states:\n            # 2. 
Convolution sequence transformation\n            # NOTE: the conv state is updated in `causal_conv1d_update`\n            mixed_qkv = self.causal_conv1d_update(\n                mixed_qkv,\n                conv_state,\n                self.conv1d.weight.squeeze(1),\n                self.conv1d.bias,\n                self.activation,\n            )\n        else:\n            if cache_params is not None:\n                conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))\n                cache_params.conv_states[self.layer_idx] = conv_state\n            if self.causal_conv1d_fn is not None:\n                mixed_qkv = self.causal_conv1d_fn(\n                    x=mixed_qkv,\n                    weight=self.conv1d.weight.squeeze(1),\n                    bias=self.conv1d.bias,\n                    activation=self.activation,\n                    seq_idx=None,\n                )\n            else:\n                mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])\n\n        mixed_qkv = mixed_qkv.transpose(1, 2)\n        query, key, value = torch.split(\n            mixed_qkv,\n            [\n                self.key_dim,\n                self.key_dim,\n                self.value_dim,\n            ],\n            dim=-1,\n        )\n        query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim)\n        key = key.reshape(key.shape[0], key.shape[1], -1, self.head_k_dim)\n        value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim)\n\n        beta = b.sigmoid()\n        # If the model is loaded in fp16, without the .float() here, A might be -inf\n        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)\n        if self.num_v_heads // self.num_k_heads > 1:\n            query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)\n            key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)\n\n        if not use_precomputed_states:\n   
         core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(\n                query,\n                key,\n                value,\n                g=g,\n                beta=beta,\n                initial_state=None,\n                output_final_state=cache_params is not None,\n                use_qk_l2norm_in_kernel=True,\n            )\n\n        else:\n            core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(\n                query,\n                key,\n                value,\n                g=g,\n                beta=beta,\n                initial_state=recurrent_state,\n                output_final_state=cache_params is not None,\n                use_qk_l2norm_in_kernel=True,\n            )\n\n        # Update cache\n        if cache_params is not None:\n            cache_params.recurrent_states[self.layer_idx] = last_recurrent_state\n\n        z_shape_og = z.shape\n        # reshape input data into 2D tensor\n        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])\n        z = z.reshape(-1, z.shape[-1])\n        core_attn_out = self.norm(core_attn_out, z)\n        core_attn_out = core_attn_out.reshape(z_shape_og)\n        core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1)\n\n        output = self.out_proj(core_attn_out)\n        return output\n\n\nclass Qwen3NextMLP(nn.Module):\n    def __init__(self, config, intermediate_size=None):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size\n        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)\n        self.act_fn = 
ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n        return down_proj\n\n\nclass Qwen3NextSparseMoeBlock(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.num_experts = config.num_experts\n        self.top_k = config.num_experts_per_tok\n        self.norm_topk_prob = config.norm_topk_prob\n\n        # gating\n        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)\n        self.experts = nn.ModuleList(\n            [Qwen3NextMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]\n        )\n\n        self.shared_expert = Qwen3NextMLP(config, intermediate_size=config.shared_expert_intermediate_size)\n        self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        \"\"\" \"\"\"\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        hidden_states = hidden_states.view(-1, hidden_dim)\n        # router_logits: (batch * sequence_length, n_experts)\n        router_logits = self.gate(hidden_states)\n\n        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)\n        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)\n        if self.norm_topk_prob:\n            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)\n        # we cast back to the input dtype\n        routing_weights = routing_weights.to(hidden_states.dtype)\n\n        final_hidden_states = torch.zeros(\n            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device\n        )\n\n        # One hot encode the selected experts to create an expert mask\n        # this will be used to easily index which expert is going to be solicited\n        expert_mask = 
torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)\n\n        # Loop over all available experts in the model and perform the computation on each expert\n        expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()\n        for expert_idx in expert_hit:\n            expert_layer = self.experts[expert_idx]\n            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))\n\n            # Index the correct hidden states and compute the expert hidden state for\n            # the current expert. We need to make sure to multiply the output hidden\n            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)\n            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)\n            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]\n\n            # However `index_add_` only support torch tensors for indexing so we'll use\n            # the `top_x` tensor here.\n            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))\n\n        shared_expert_output = self.shared_expert(hidden_states)\n        shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output\n\n        final_hidden_states = final_hidden_states + shared_expert_output\n\n        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)\n        return final_hidden_states, router_logits\n\n\nclass Qwen3NextDecoderLayer(GradientCheckpointingLayer):\n    def __init__(self, config: Qwen3NextConfig, layer_idx: int):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n\n        # token mixer\n        self.layer_type = config.layer_types[layer_idx]\n        if self.layer_type == \"linear_attention\":\n            self.linear_attn = Qwen3NextGatedDeltaNet(config, layer_idx)\n        elif self.layer_type == \"full_attention\":\n            
self.self_attn = Qwen3NextAttention(config, layer_idx)\n\n        if (layer_idx not in config.mlp_only_layers) and (\n            config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0\n        ):\n            self.mlp = Qwen3NextSparseMoeBlock(config)\n        else:\n            self.mlp = Qwen3NextMLP(config, intermediate_size=config.intermediate_size)\n\n        self.input_layernorm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_attention_layernorm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n    @deprecate_kwarg(\"past_key_value\", new_name=\"past_key_values\", version=\"4.58\")\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_embeddings: tuple[torch.Tensor, torch.Tensor],\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[tuple[torch.Tensor]] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs: Unpack[FlashAttentionKwargs],\n    ) -> torch.FloatTensor:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n                `(batch, sequence_length)` where padding elements are indicated by 0.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n            output_router_logits (`bool`, *optional*):\n                Whether or not to return the logits of all the routers. 
They are useful for computing the router loss,\n                and should not be returned during inference.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):\n                Indices depicting the position of the input sequence tokens in the sequence.\n            position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):\n                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,\n                with `head_dim` being the embedding dimension of each attention head.\n            kwargs (`dict`, *optional*):\n                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code\n                into the model\n        \"\"\"\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Token Mixer\n        if self.layer_type == \"linear_attention\":\n            hidden_states = self.linear_attn(\n                hidden_states=hidden_states,\n                cache_params=past_key_values,\n                cache_position=cache_position,\n                attention_mask=attention_mask,\n            )\n        elif self.layer_type == \"full_attention\":\n            # Self Attention\n            hidden_states, _ = self.self_attn(\n                hidden_states=hidden_states,\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n                past_key_values=past_key_values,\n                cache_position=cache_position,\n                position_embeddings=position_embeddings,\n                **kwargs,\n            
)\n\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        # For the MoE layers, we need to unpack\n        if isinstance(hidden_states, tuple):\n            hidden_states, _ = hidden_states\n        hidden_states = residual + hidden_states\n\n        return hidden_states\n\n\nclass Qwen3NextPreTrainedModel(PreTrainedModel):\n    config: Qwen3NextConfig\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"Qwen3NextDecoderLayer\"]\n    _skip_keys_device_placement = \"past_key_values\"\n    _supports_flash_attn_2 = True\n    _supports_sdpa = True\n    _keys_to_ignore_on_load_unexpected = [r\"^mtp.*\"]\n    _can_record_outputs = {\n        \"router_logits\": OutputRecorder(Qwen3NextSparseMoeBlock, index=1),\n        \"hidden_states\": Qwen3NextDecoderLayer,\n        \"attentions\": Qwen3NextAttention,\n    }\n    _is_stateful = True\n\n    def _init_weights(self, module):\n        super()._init_weights(module)\n        if isinstance(module, Qwen3NextGatedDeltaNet):\n            module.dt_bias.data.fill_(1.0)\n            module.A_log.data.uniform_(0, 16).log_()\n\n\nclass Qwen3NextModel(Qwen3NextPreTrainedModel):\n    def __init__(self, config: Qwen3NextConfig):\n        super().__init__(config)\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)\n        self.layers = nn.ModuleList(\n            [Qwen3NextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]\n        )\n        self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.rotary_emb = Qwen3NextRotaryEmbedding(config=config)\n        self.gradient_checkpointing = False\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    
@check_model_inputs\n    @auto_docstring\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs: Unpack[TransformersKwargs],\n    ) -> MoeModelOutputWithPast:\n        if (input_ids is None) ^ (inputs_embeds is not None):\n            raise ValueError(\"You must specify exactly one of input_ids or inputs_embeds\")\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        if use_cache and past_key_values is None:\n            past_key_values = Qwen3NextDynamicCache(config=self.config)\n\n        if cache_position is None:\n            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n            cache_position = torch.arange(\n                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device\n            )\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0)\n\n        causal_mask = create_causal_mask(\n            config=self.config,\n            input_embeds=inputs_embeds,\n            attention_mask=attention_mask,\n            cache_position=cache_position,\n            past_key_values=past_key_values,\n            position_ids=position_ids,\n        )\n        linear_attn_mask = self._update_linear_attn_mask(attention_mask, cache_position)\n\n        hidden_states = inputs_embeds\n\n        # create position embeddings to be shared across the decoder layers\n        position_embeddings = self.rotary_emb(hidden_states, position_ids)\n\n        for decoder_layer in self.layers[: self.config.num_hidden_layers]:\n            
layer_mask = linear_attn_mask if decoder_layer.layer_type == \"linear_attention\" else causal_mask\n\n            hidden_states = decoder_layer(\n                hidden_states,\n                position_embeddings=position_embeddings,\n                attention_mask=layer_mask,\n                position_ids=position_ids,\n                past_key_values=past_key_values,\n                use_cache=use_cache,\n                cache_position=cache_position,\n                **kwargs,\n            )\n\n        hidden_states = self.norm(hidden_states)\n\n        return MoeModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=past_key_values,\n        )\n\n    def _update_linear_attn_mask(self, attention_mask, cache_position):\n        \"\"\"\n        NOTE: Left-padding is used for linear attention mask.\n        No need for zeroing states when\n            1. Cached forward\n            2. Attending to all inputs\n        \"\"\"\n        linear_attn_mask = attention_mask\n        if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):\n            linear_attn_mask = None\n        return linear_attn_mask\n\n\ndef load_balancing_loss_func(\n    gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],\n    num_experts: Optional[int] = None,\n    top_k=2,\n    attention_mask: Optional[torch.Tensor] = None,\n) -> Union[torch.Tensor, int]:\n    r\"\"\"\n    Computes auxiliary load balancing loss as in Switch Transformer - implemented in PyTorch.\n\n    See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss\n    function presented in equations (4) - (6) of the paper. 
It aims at penalizing cases where the routing between\n    experts is too unbalanced.\n\n    Args:\n        gate_logits:\n            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of\n            shape [batch_size X sequence_length, num_experts].\n        num_experts:\n            Number of experts\n        top_k:\n            The number of experts to route per-token, can be also interpreted as the `top-k` routing\n            parameter.\n        attention_mask (`torch.Tensor`, *optional*):\n            The attention_mask used in forward function\n            shape [batch_size X sequence_length] if not None.\n\n    Returns:\n        The auxiliary loss.\n    \"\"\"\n    if gate_logits is None or not isinstance(gate_logits, tuple):\n        return 0\n\n    if isinstance(gate_logits, tuple):\n        compute_device = gate_logits[0].device\n        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)\n\n    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)\n\n    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)\n\n    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)\n\n    if attention_mask is None:\n        # Compute the percentage of tokens routed to each experts\n        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)\n\n        # Compute the average probability of routing to these experts\n        router_prob_per_expert = torch.mean(routing_weights, dim=0)\n    else:\n        batch_size, sequence_length = attention_mask.shape\n        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)\n\n        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask\n        expert_attention_mask = (\n            attention_mask[None, :, :, None, None]\n            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))\n 
           .reshape(-1, top_k, num_experts)\n            .to(compute_device)\n        )\n\n        # Compute the percentage of tokens routed to each experts\n        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(\n            expert_attention_mask, dim=0\n        )\n\n        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert\n        router_per_expert_attention_mask = (\n            attention_mask[None, :, :, None]\n            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))\n            .reshape(-1, num_experts)\n            .to(compute_device)\n        )\n\n        # Compute the average probability of routing to these experts\n        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(\n            router_per_expert_attention_mask, dim=0\n        )\n\n    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))\n    return overall_loss * num_experts\n\n\n@auto_docstring\nclass Qwen3NextForCausalLM(Qwen3NextPreTrainedModel, GenerationMixin):\n    _tied_weights_keys = [\"lm_head.weight\"]\n    _tp_plan = {\"lm_head\": \"colwise_rep\"}\n    _pp_plan = {\"lm_head\": ([\"hidden_states\"], [\"logits\"])}\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = Qwen3NextModel(config)\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.router_aux_loss_coef = config.router_aux_loss_coef\n        self.num_experts = config.num_experts\n        self.num_experts_per_tok = config.num_experts_per_tok\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    @can_return_tuple\n    @auto_docstring\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] 
= None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Qwen3NextDynamicCache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_router_logits: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        logits_to_keep: Union[int, torch.Tensor] = 0,\n        **kwargs: Unpack[TransformersKwargs],\n    ) -> MoeCausalLMOutputWithPast:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, Qwen3NextForCausalLM\n\n        >>> model = Qwen3NextForCausalLM.from_pretrained(\"Qwen/Qwen3-Next-80B-A3B-Instruct\")\n        >>> tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen3-Next-80B-A3B-Instruct\")\n\n        >>> prompt = \"Hey, are you conscious? Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you conscious? 
Can you talk to me?\\nI'm not conscious, but I can talk to you.\"\n        ```\"\"\"\n\n        output_router_logits = (\n            output_router_logits if output_router_logits is not None else self.config.output_router_logits\n        )\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs: MoeModelOutputWithPast = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_router_logits=output_router_logits,\n            cache_position=cache_position,\n            **kwargs,\n        )\n\n        hidden_states = outputs.last_hidden_state\n        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss\n        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep\n        logits = self.lm_head(hidden_states[:, slice_indices, :])\n\n        loss = None\n        if labels is not None:\n            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)\n\n        aux_loss = None\n        if output_router_logits:\n            aux_loss = load_balancing_loss_func(\n                outputs.router_logits,\n                self.num_experts,\n                self.num_experts_per_tok,\n                attention_mask,\n            )\n            if labels is not None:\n                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device\n\n        return MoeCausalLMOutputWithPast(\n            loss=loss,\n            aux_loss=aux_loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n            router_logits=outputs.router_logits,\n      
  )\n\n\nclass Qwen3NextForSequenceClassification(GenericForSequenceClassification, Qwen3NextPreTrainedModel):\n    pass\n\n\nclass Qwen3NextForTokenClassification(GenericForTokenClassification, Qwen3NextPreTrainedModel):\n    pass\n\n\nclass Qwen3NextForQuestionAnswering(GenericForQuestionAnswering, Qwen3NextPreTrainedModel):\n    base_model_prefix = \"transformer\"  # For BC, where `transformer` was used instead of `model`\n\n\n__all__ = [\n    \"Qwen3NextForCausalLM\",\n    \"Qwen3NextForQuestionAnswering\",\n    \"Qwen3NextModel\",\n    \"Qwen3NextPreTrainedModel\",\n    \"Qwen3NextForSequenceClassification\",\n    \"Qwen3NextForTokenClassification\",\n]"
  },
  {
    "path": "archive/ktransformers/models/modeling_smallthinker.py",
    "content": "# coding=utf-8\nfrom functools import partial\nfrom typing import Callable, List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache\nfrom transformers.generation import GenerationMixin\nfrom transformers.modeling_attn_mask_utils import AttentionMaskConverter\nfrom transformers.modeling_flash_attention_utils import FlashAttentionKwargs\nfrom transformers.modeling_outputs import (\n    MoeCausalLMOutputWithPast,\n    MoeModelOutputWithPast\n)\nfrom transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update\nfrom transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel\nfrom transformers.processing_utils import Unpack\nfrom transformers.utils import can_return_tuple, is_torch_flex_attn_available, logging\nfrom .configuration_smallthinker import SmallthinkerConfig\n\n\nif is_torch_flex_attn_available():\n    from torch.nn.attention.flex_attention import BlockMask\n\n    from transformers.integrations.flex_attention import make_flex_block_causal_mask\n\n\nlogger = logging.get_logger(__name__)\n\n\nclass SmallthinkerHierarchicalMLP(nn.Module):\n    def __init__(self, config: SmallthinkerConfig):\n        super().__init__()\n        self.config = config\n        self.hidden_dim = config.hidden_size\n        self.ffn_dim = config.moe_ffn_hidden_size\n        self.moe_enable_secondary_experts = config.moe_enable_secondary_experts\n        if self.moe_enable_secondary_experts:\n            self.num_secondary_experts = config.moe_num_secondary_experts\n            self.secondary_expert_size = config.moe_secondary_expert_size\n            self.secondary_gate = nn.Linear(self.hidden_dim, self.num_secondary_experts, bias=False)\n\n        self.up = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)\n        self.gate = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)\n        
self.down = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)\n        \n    def forward(self, secondary_gate_input: torch.Tensor, hidden_states: torch.Tensor):\n        if self.moe_enable_secondary_experts:\n            secondary_gate_logits = F.sigmoid(self.secondary_gate(secondary_gate_input)) > 0.5\n            secondary_gate_mask = secondary_gate_logits.unsqueeze(-1)\n\n        current_hidden_states = self.up(hidden_states) * F.relu(self.gate(hidden_states))\n        activated_output =  current_hidden_states\n        batch_size, intermediate_size = activated_output.shape\n\n        if self.moe_enable_secondary_experts:\n            num_groups = intermediate_size // self.secondary_expert_size\n            activated_output = activated_output.view(batch_size, num_groups, self.secondary_expert_size)\n            output = activated_output * secondary_gate_mask\n        else:\n            output = activated_output\n\n        current_hidden_states = output.view(batch_size, -1)\n        current_hidden_states = self.down(current_hidden_states)\n        return current_hidden_states\n\n\nclass SmallthinkerMoeBlock(nn.Module):\n    def __init__(self, config: SmallthinkerConfig):\n        super().__init__()\n        self.hidden_dim = config.hidden_size\n        self.num_primary_experts = config.moe_num_primary_experts\n        self.enable_early_router = config.moe_enable_early_router\n        self.moe_primary_router_apply_softmax = config.moe_primary_router_apply_softmax\n        self.num_active_primary_experts = config.moe_num_active_primary_experts\n        self.primary_router = nn.Linear(self.hidden_dim, self.num_primary_experts, bias=False)\n        self.experts = nn.ModuleList([SmallthinkerHierarchicalMLP(config) for _ in range(self.num_primary_experts)])\n\n    def forward(self, router_input: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        # Flatten the tokens into (bs * 
sl, hidden_dim)\n        hidden_states = hidden_states.view(-1, hidden_dim)\n        router_input = router_input.view(-1, hidden_dim)\n        # Primary router logits: (bs * sl, n_experts)\n        if self.enable_early_router:\n            router_logits = self.primary_router(router_input)\n        else:\n            router_logits = self.primary_router(hidden_states)\n\n        router_logits, selected_experts = torch.topk(router_logits, self.num_active_primary_experts, dim=-1)\n\n        if self.moe_primary_router_apply_softmax:\n            routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)\n        else:\n            routing_weights = F.sigmoid(router_logits)\n            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)\n\n        routing_weights = routing_weights.to(hidden_states.dtype)\n\n        # Prepare the final tensor\n        final_hidden_states = torch.zeros(\n            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device\n        )\n\n        # One hot encode the selected experts to create an expert mask\n        # this will be used to easily index which expert is going to be solicited\n        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_primary_experts).permute(2, 1, 0)\n        expert_hitted = (expert_mask.sum(dim=(-1, -2)) > 0).nonzero(as_tuple=True)[0].tolist()\n\n        for expert_idx in expert_hitted:\n            expert_layer = self.experts[expert_idx]\n            idx, top_x = torch.where(expert_mask[expert_idx])\n            # Index the correct hidden states and compute the expert hidden state for\n            # the current expert. 
We need to make sure to multiply the output hidden\n            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)\n            # current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)\n            # current_router_input = router_input[None, top_x].reshape(-1, hidden_dim)\n            current_state = hidden_states[top_x].reshape(-1, hidden_dim)\n            current_router_input = router_input[top_x].reshape(-1, hidden_dim)\n            current_hidden_states = expert_layer(current_router_input, current_state) * routing_weights[top_x, idx, None]\n\n            # However `index_add_` only support torch tensors for indexing so we'll use the `top_x` tensor here.\n            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))\n        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)\n        return final_hidden_states, router_logits\n    \n\nclass SmallthinkerDenseMlpBlock(nn.Module):\n    def __init__(self, config: SmallthinkerConfig):\n        super().__init__()\n        hidden_dim = config.hidden_size\n        ffn_dim = config.dense_ffn_hidden_size\n        self.up = nn.Linear(hidden_dim, ffn_dim, bias=False)\n        self.gate = nn.Linear(hidden_dim, ffn_dim, bias=False)\n        self.down = nn.Linear(ffn_dim, hidden_dim, bias=False)\n\n    # Offer unified interface for SmallthinkerMoeBlock and SmallthinkerDenseMlpBlock, though router_input is not used here\n    def forward(self, router_input: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:\n        current_hidden_states = self.up(hidden_states) * F.relu(self.gate(hidden_states))\n        current_hidden_states = self.down(current_hidden_states)\n        return current_hidden_states, None\n\n\nclass SmallthinkerRMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        SmallthinkerRMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        
super().__init__()\n        self.hidden_size = hidden_size\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n    def extra_repr(self):\n        return f\"{tuple(self.weight.shape)}, eps={self.variance_epsilon}\"\n\n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n    Args:\n        q (`torch.Tensor`): The query tensor.\n        k (`torch.Tensor`): The key tensor.\n        cos (`torch.Tensor`): The cosine part of the rotary embedding.\n        sin (`torch.Tensor`): The sine part of the rotary embedding.\n        position_ids (`torch.Tensor`, *optional*):\n            Deprecated and unused.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. 
Similarly, if q and k have\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos.unsqueeze(unsqueeze_dim)\n    sin = sin.unsqueeze(unsqueeze_dim)\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\ndef eager_attention_forward(\n    module: nn.Module,\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    attention_mask: Optional[torch.Tensor],\n    scaling: float,\n    dropout: float = 0.0,\n    **kwargs,\n):\n    key_states = repeat_kv(key, module.num_key_value_groups)\n    value_states = repeat_kv(value, module.num_key_value_groups)\n\n    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling\n    if attention_mask is not None:\n        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]\n        attn_weights = attn_weights + causal_mask\n\n    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)\n    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)\n    attn_output = torch.matmul(attn_weights, value_states)\n    attn_output = 
attn_output.transpose(1, 2).contiguous()\n\n    return attn_output, attn_weights\n\n\nclass SmallthinkerAttention(nn.Module):\n    def __init__(self, config: SmallthinkerConfig, layer_idx: int):\n        super().__init__()\n        self.config = config\n        self.layer_idx = layer_idx # For KVCache management\n        self.head_dim = config.head_dim\n        self.num_attention_heads = config.num_attention_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads\n        self.scaling = self.head_dim**-0.5\n        self.is_causal = True\n        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)\n        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)\n        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)\n        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)\n        self.sliding_window = config.sliding_window_size if config.sliding_window_layout[layer_idx] else None\n        self.use_qk_norm = config.use_qk_norm\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_embeddings: Tuple[torch.Tensor, torch.Tensor],\n        attention_mask: Optional[torch.Tensor],\n        past_key_value: Optional[Cache] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs: Unpack[FlashAttentionKwargs],\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        if self.use_qk_norm:\n            raise NotImplementedError(\"use_qk_norm is not implemented yet\")\n\n        input_shape = hidden_states.shape[:-1]\n        hidden_shape = (*input_shape, -1, self.head_dim)\n\n        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n        key_states 
= self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n        \n        if position_embeddings:\n            cos, sin = position_embeddings\n            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n        else:\n            cos, sin = None, None\n        if past_key_value is not None:\n            # sin and cos are specific to RoPE models; cache_position needed for the static cache\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}\n            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n        attention_interface: Callable = eager_attention_forward\n\n        if self.config._attn_implementation == \"sdpa\":\n            raise NotImplementedError(\"SDPA impl is buggy for now. NEVER TRY TO USE IT.\")\n\n        if self.config._attn_implementation != \"eager\":\n            if self.config._attn_implementation == \"sdpa\" and kwargs.get(\"output_attentions\", False):\n                logger.warning_once(\n                    \"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to \"\n                    'eager attention. 
This warning can be removed using the argument `attn_implementation=\"eager\"` when loading the model.'\n                )\n            else:\n                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]\n\n        attn_output, attn_weights = attention_interface(\n            self,\n            query_states,\n            key_states,\n            value_states,\n            attention_mask,\n            dropout=0.0,\n            scaling=self.scaling,\n            sliding_window=self.sliding_window,  # main diff with Llama\n            **kwargs,\n        )\n\n        attn_output = attn_output.reshape(*input_shape, -1).contiguous()\n        attn_output = self.o_proj(attn_output)\n        return attn_output, attn_weights\n\n\nclass SmallthinkerDecoderLayer(nn.Module):\n    def __init__(self, config: SmallthinkerConfig, layer_idx: int):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n        self.self_attn = SmallthinkerAttention(config, layer_idx)\n\n        self.block_sparse_moe = SmallthinkerMoeBlock(config) if config.moe_layer_layout[layer_idx] else SmallthinkerDenseMlpBlock(config)\n        self.input_layernorm = SmallthinkerRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_attention_layernorm = SmallthinkerRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        output_attentions: Optional[bool] = False,\n        output_router_logits: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC\n        **kwargs: Unpack[FlashAttentionKwargs],\n    ) -> 
Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n                `(batch, sequence_length)` where padding elements are indicated by 0.\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attention tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n            output_router_logits (`bool`, *optional*):\n                Whether or not to return the logits of all the routers. They are useful for computing the router loss, and\n                should not be returned during inference.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):\n                Indices depicting the position of the input sequence tokens in the sequence.\n            kwargs (`dict`, *optional*):\n                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code\n                into the model\n        \"\"\"\n        # print(f\"hidden states, shape {hidden_states.shape}: {hidden_states}\") # debug print\n        residual = hidden_states\n        router_input = hidden_states\n        hidden_states = self.input_layernorm(hidden_states)\n        # Self Attention\n        hidden_states, self_attn_weights = self.self_attn(\n            hidden_states=hidden_states,\n            position_embeddings=position_embeddings,\n            
attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_value=past_key_value,\n            output_attentions=output_attentions,\n            use_cache=use_cache,\n            cache_position=cache_position,\n            **kwargs,\n        )\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states, router_logits = self.block_sparse_moe(router_input, hidden_states)\n        hidden_states = residual + hidden_states # SYNC after_moe_residual_value=hidden_states\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (self_attn_weights,)\n\n        if output_router_logits:\n            outputs += (router_logits,)\n\n        return outputs\n\n\nclass SmallthinkerRotaryEmbedding(nn.Module):\n    def __init__(self, config: SmallthinkerConfig, device=None):\n        super().__init__()\n        # BC: \"rope_type\" was originally \"type\"\n        if hasattr(config, \"rope_scaling\") and config.rope_scaling is not None:\n            self.rope_type = config.rope_scaling.get(\"rope_type\", config.rope_scaling.get(\"type\"))\n        else:\n            self.rope_type = \"default\"\n        self.max_seq_len_cached = config.max_position_embeddings\n        self.original_max_seq_len = config.max_position_embeddings\n\n        self.config = config\n        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]\n\n        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        self.original_inv_freq = self.inv_freq\n\n    @torch.no_grad()\n    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. 
dynamic rope)\n    def forward(self, x, position_ids):\n        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)\n        position_ids_expanded = position_ids[:, None, :].float()\n\n        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != \"mps\" else \"cpu\"\n        with torch.autocast(device_type=device_type, enabled=False):  # Force float32\n            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos() * self.attention_scaling\n            sin = emb.sin() * self.attention_scaling\n\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)\n\n\nclass SmallthinkerPreTrainedModel(PreTrainedModel):\n    config_class = SmallthinkerConfig\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"SmallthinkerDecoderLayer\"]\n    _skip_keys_device_placement = [\"past_key_values\"]\n    _supports_flash_attn_2 = True\n    _supports_sdpa = True\n    _supports_flex_attn = True\n    _supports_cache_class = True\n    _supports_quantized_cache = True\n    _supports_static_cache = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)\n    _supports_attention_backend = True\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n        elif isinstance(module, SmallthinkerRMSNorm):\n            module.weight.data.fill_(1.0)\n\n\n# @auto_docstring\nclass 
SmallthinkerModel(SmallthinkerPreTrainedModel):\n    def __init__(self, config: SmallthinkerConfig):\n        super().__init__(config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)\n        self.layers = nn.ModuleList(\n            [SmallthinkerDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]\n        )\n        self.norm = SmallthinkerRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.rotary_emb = SmallthinkerRotaryEmbedding(config=config)\n        self.gradient_checkpointing = False\n        self.rope_layout = config.rope_layout\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.embed_tokens = value\n\n    @can_return_tuple\n    # @auto_docstring\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        output_router_logits: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],\n    ) -> MoeModelOutputWithPast:\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_router_logits = (\n            output_router_logits if output_router_logits is not None else self.config.output_router_logits\n        )\n        
output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        if (input_ids is None) ^ (inputs_embeds is not None):\n            raise ValueError(\"You must specify exactly one of input_ids or inputs_embeds\")\n\n        if self.gradient_checkpointing and self.training:\n            if use_cache:\n                logger.warning_once(\n                    \"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\"\n                )\n                use_cache = False\n\n        if use_cache and past_key_values is None:\n            past_key_values = DynamicCache()\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        if cache_position is None:\n            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n            cache_position = torch.arange(\n                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device\n            )\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0)\n\n        # print(\"atten mask:\", attention_mask) # debug print\n\n        causal_mask = self._update_causal_mask(\n            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions\n        )\n\n        # print(\"causal mask:\", causal_mask) # debug print\n        hidden_states = inputs_embeds\n        # create position embeddings to be shared across the decoder layers\n        position_embeddings = self.rotary_emb(hidden_states, position_ids)\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        all_router_logits = () if output_router_logits else None\n\n        for layer_idx, 
decoder_layer in enumerate(self.layers):\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            if self.gradient_checkpointing and self.training:\n                layer_outputs = self._gradient_checkpointing_func(\n                    partial(decoder_layer.__call__, **flash_attn_kwargs),\n                    hidden_states,\n                    causal_mask,\n                    position_ids,\n                    past_key_values,\n                    output_attentions,\n                    output_router_logits,\n                    use_cache,\n                    cache_position,\n                    position_embeddings if self.rope_layout[layer_idx] else None,\n                )\n            else:\n                layer_outputs = decoder_layer(\n                    hidden_states,\n                    attention_mask=causal_mask,\n                    position_ids=position_ids,\n                    past_key_value=past_key_values,\n                    output_attentions=output_attentions,\n                    output_router_logits=output_router_logits,\n                    use_cache=use_cache,\n                    cache_position=cache_position,\n                    position_embeddings=position_embeddings if self.rope_layout[layer_idx] else None,\n                    **flash_attn_kwargs,\n                )\n\n            hidden_states = layer_outputs[0]\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n            if output_router_logits:\n                all_router_logits += (layer_outputs[-1],)\n\n        hidden_states = self.norm(hidden_states)\n\n        # add hidden states from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        return MoeModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=past_key_values,\n            hidden_states=all_hidden_states,\n            
attentions=all_self_attns,\n            router_logits=all_router_logits,\n        )\n\n    def _update_causal_mask(\n        self,\n        attention_mask: Union[torch.Tensor, \"BlockMask\"],\n        input_tensor: torch.Tensor,\n        cache_position: torch.Tensor,\n        past_key_values: Cache,\n        output_attentions: bool = False,\n    ):\n        if self.config._attn_implementation == \"flash_attention_2\":\n            if attention_mask is not None and past_key_values is not None:\n                is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]\n                if is_padding_right:\n                    raise ValueError(\n                        \"You are attempting to perform batched generation with padding_side='right'\"\n                        \" this may lead to unexpected behaviour for Flash Attention version of Smallthinker. Make sure to \"\n                        \" call `tokenizer.padding_side  = 'left'` before tokenizing the input. \"\n                    )\n            if attention_mask is not None and 0.0 in attention_mask:\n                return attention_mask\n            return None\n        if self.config._attn_implementation == \"flex_attention\":\n            if isinstance(attention_mask, torch.Tensor):\n                attention_mask = make_flex_block_causal_mask(attention_mask)\n            return attention_mask\n\n        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in\n        # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail\n        # to infer the attention mask.\n        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n        using_static_cache = isinstance(past_key_values, StaticCache)\n        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)\n\n        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward\n        if (\n            self.config._attn_implementation == \"sdpa\"\n            and not (using_static_cache or using_sliding_window_cache)\n            and not output_attentions\n        ):\n            if AttentionMaskConverter._ignore_causal_mask_sdpa(\n                attention_mask,\n                inputs_embeds=input_tensor,\n                past_key_values_length=past_seen_tokens,\n                sliding_window=self.config.sliding_window,\n                is_training=self.training,\n            ):\n                return None\n\n        dtype = input_tensor.dtype\n        min_dtype = torch.finfo(dtype).min\n        sequence_length = input_tensor.shape[1]\n        # SlidingWindowCache or StaticCache\n        if using_sliding_window_cache or using_static_cache:\n            target_length = past_key_values.get_max_cache_shape()\n        # DynamicCache or no cache\n        else:\n            target_length = (\n                attention_mask.shape[-1]\n                if isinstance(attention_mask, torch.Tensor)\n                else past_seen_tokens + sequence_length + 1\n            )\n\n        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).\n        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(\n            attention_mask,\n            sequence_length=sequence_length,\n            target_length=target_length,\n            dtype=dtype,\n            cache_position=cache_position,\n            
batch_size=input_tensor.shape[0],\n            config=self.config,\n            past_key_values=past_key_values,\n        )\n\n        if (\n            self.config._attn_implementation == \"sdpa\"\n            and attention_mask is not None\n            and attention_mask.device.type in [\"cuda\", \"xpu\", \"npu\"]\n            and not output_attentions\n        ):\n            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when\n            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.\n            # Details: https://github.com/pytorch/pytorch/issues/110213\n            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)\n\n        return causal_mask\n\n    @staticmethod\n    def _prepare_4d_causal_attention_mask_with_cache_position(\n        attention_mask: torch.Tensor,\n        sequence_length: int,\n        target_length: int,\n        dtype: torch.dtype,\n        cache_position: torch.Tensor,\n        batch_size: int,\n        config: SmallthinkerConfig,\n        past_key_values: Cache,\n    ):\n        \"\"\"\n        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape\n        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.\n\n        Args:\n            attention_mask (`torch.Tensor`):\n                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.\n            sequence_length (`int`):\n                The sequence length being processed.\n            target_length (`int`):\n                The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.\n            dtype (`torch.dtype`):\n                The 
dtype to use for the 4D attention mask.\n            cache_position (`torch.Tensor`):\n                Indices depicting the position of the input sequence tokens in the sequence.\n            batch_size (`int`):\n                Batch size.\n            config (`SmallthinkerConfig`):\n                The model's configuration class\n            past_key_values (`Cache`):\n                The cache class that is being used currently to generate\n        \"\"\"\n        if attention_mask is not None and attention_mask.dim() == 4:\n            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.\n            causal_mask = attention_mask\n        else:\n            min_dtype = torch.finfo(dtype).min\n            causal_mask = torch.full(\n                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device\n            )\n            diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(\n                -1, 1\n            )\n            if config.get_text_config().sliding_window is not None:\n                # if we have a sliding window, we should not attend to tokens beyond the sliding window length, so we mask them out as well\n                # the check is needed to verify whether the current checkpoint was trained with a sliding window or not\n                if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:\n                    sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (\n                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window\n                    )\n                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)\n            causal_mask *= diagonal_attend_mask\n            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)\n            if attention_mask is not None:\n  
              causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit\n                if attention_mask.shape[-1] > target_length:\n                    attention_mask = attention_mask[:, :target_length]\n                mask_length = attention_mask.shape[-1]\n                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(\n                    causal_mask.device\n                )\n                padding_mask = padding_mask == 0\n                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(\n                    padding_mask, min_dtype\n                )\n        return causal_mask\n\n\nclass KwargsForCausalLM(FlashAttentionKwargs): ...\n\n\ndef load_balancing_loss_func(\n    gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],\n    num_experts: Optional[int] = None,\n    top_k=2,\n    attention_mask: Optional[torch.Tensor] = None,\n) -> Union[torch.Tensor, int]:\n    r\"\"\"\n    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.\n\n    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss\n    function presented in equations (4) - (6) of the paper. 
It aims at penalizing cases where the routing between\n    experts is too unbalanced.\n\n    Args:\n        gate_logits:\n            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of\n            shape [batch_size X sequence_length, num_experts].\n        num_experts:\n            Number of experts\n        top_k:\n            The number of experts to route per-token, can also be interpreted as the `top-k` routing\n            parameter.\n        attention_mask (`torch.Tensor`, *optional*):\n            The attention_mask used in forward function\n            shape [batch_size X sequence_length] if not None.\n\n    Returns:\n        The auxiliary loss.\n    \"\"\"\n    if gate_logits is None or not isinstance(gate_logits, tuple):\n        return 0\n\n    if isinstance(gate_logits, tuple):\n        compute_device = gate_logits[0].device\n        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)\n\n    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)\n\n    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)\n\n    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)\n\n    if attention_mask is None:\n        # Compute the percentage of tokens routed to each expert\n        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)\n\n        # Compute the average probability of routing to these experts\n        router_prob_per_expert = torch.mean(routing_weights, dim=0)\n    else:\n        batch_size, sequence_length = attention_mask.shape\n        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)\n\n        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask\n        expert_attention_mask = (\n            attention_mask[None, :, :, None, None]\n            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))\n 
           .reshape(-1, top_k, num_experts)\n            .to(compute_device)\n        )\n\n        # Compute the percentage of tokens routed to each expert\n        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(\n            expert_attention_mask, dim=0\n        )\n\n        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert\n        router_per_expert_attention_mask = (\n            attention_mask[None, :, :, None]\n            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))\n            .reshape(-1, num_experts)\n            .to(compute_device)\n        )\n\n        # Compute the average probability of routing to these experts\n        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(\n            router_per_expert_attention_mask, dim=0\n        )\n\n    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))\n    return overall_loss * num_experts\n\n\n# @auto_docstring\nclass SmallThinkerForCausalLM(SmallthinkerPreTrainedModel, GenerationMixin):\n    _tied_weights_keys = [\"lm_head.weight\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = SmallthinkerModel(config)\n        self.vocab_size = config.vocab_size\n        # Handle tie / untie word embeddings\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        # self.num_experts = config.num_local_experts\n        # self.num_experts_per_tok = config.num_experts_per_tok\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, 
new_embeddings):\n        self.lm_head = new_embeddings\n\n    def set_decoder(self, decoder):\n        self.model = decoder\n\n    def get_decoder(self):\n        return self.model\n\n    @can_return_tuple\n#     @auto_docstring\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        output_router_logits: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        logits_to_keep: Union[int, torch.Tensor] = 0,\n        **kwargs: Unpack[KwargsForCausalLM],\n    ) -> MoeCausalLMOutputWithPast:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, SmallThinkerForCausalLM\n\n        >>> model = SmallThinkerForCausalLM.from_pretrained(\"mistralai/Smallthinker-8x7B-v0.1\")\n        >>> tokenizer = AutoTokenizer.from_pretrained(\"mistralai/Smallthinker-8x7B-v0.1\")\n\n        >>> prompt = \"Hey, are you conscious? 
Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you conscious? Can you talk to me?\\nI'm not conscious, but I can talk to you.\"\n        ```\"\"\"\n\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_router_logits = (\n            output_router_logits if output_router_logits is not None else self.config.output_router_logits\n        )\n\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs: MoeModelOutputWithPast = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            output_router_logits=output_router_logits,\n            cache_position=cache_position,\n            **kwargs,\n        )\n\n        hidden_states = outputs.last_hidden_state\n        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss\n        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep\n        logits = self.lm_head(hidden_states[:, slice_indices, :])\n\n        loss = None\n        if labels is not None:\n            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)\n\n        aux_loss = None\n        if 
output_router_logits:\n            aux_loss = load_balancing_loss_func(\n                outputs.router_logits,\n                self.num_experts,\n                self.num_experts_per_tok,\n                attention_mask,\n            )\n            if labels is not None:\n                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device\n\n        return MoeCausalLMOutputWithPast(\n            loss=loss,\n            aux_loss=aux_loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n            router_logits=outputs.router_logits,\n        )\n\n# No such functions for now\n# #@auto_docstring(\n#     custom_intro=\"\"\"\n#     The Smallthinker Model transformer with a sequence classification head on top (linear layer).\n\n#     [`SmallthinkerForSequenceClassification`] uses the last token in order to do the classification, as other causal models\n#     (e.g. GPT-2) do.\n\n#     Since it does classification on the last token, it requires to know the position of the last token. If a\n#     `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If\n#     no `pad_token_id` is defined, it simply takes the last value in each row of the batch. 
Since it cannot guess the\n#     padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in\n#     each row of the batch).\n#     \"\"\"\n# )\n# class SmallthinkerForSequenceClassification(SmallthinkerPreTrainedModel):\n#     def __init__(self, config):\n#         super().__init__(config)\n#         self.num_labels = config.num_labels\n#         self.model = SmallthinkerModel(config)\n#         self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)\n\n#         # Initialize weights and apply final processing\n#         self.post_init()\n\n#     def get_input_embeddings(self):\n#         return self.model.embed_tokens\n\n#     def set_input_embeddings(self, value):\n#         self.model.embed_tokens = value\n\n#     @can_return_tuple\n#     #@auto_docstring\n#     def forward(\n#         self,\n#         input_ids: Optional[torch.LongTensor] = None,\n#         attention_mask: Optional[torch.Tensor] = None,\n#         position_ids: Optional[torch.LongTensor] = None,\n#         past_key_values: Optional[Cache] = None,\n#         inputs_embeds: Optional[torch.FloatTensor] = None,\n#         labels: Optional[torch.LongTensor] = None,\n#         use_cache: Optional[bool] = None,\n#         output_attentions: Optional[bool] = None,\n#         output_hidden_states: Optional[bool] = None,\n#     ) -> SequenceClassifierOutputWithPast:\n#         r\"\"\"\n#         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n#             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n#             config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n#             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n#         \"\"\"\n\n#         transformer_outputs: BaseModelOutputWithPast = self.model(\n#             input_ids,\n#             attention_mask=attention_mask,\n#             position_ids=position_ids,\n#             past_key_values=past_key_values,\n#             inputs_embeds=inputs_embeds,\n#             use_cache=use_cache,\n#             output_attentions=output_attentions,\n#             output_hidden_states=output_hidden_states,\n#         )\n#         hidden_states = transformer_outputs.last_hidden_state\n#         logits = self.score(hidden_states)\n\n#         if input_ids is not None:\n#             batch_size = input_ids.shape[0]\n#         else:\n#             batch_size = inputs_embeds.shape[0]\n\n#         if self.config.pad_token_id is None and batch_size != 1:\n#             raise ValueError(\"Cannot handle batch sizes > 1 if no padding token is defined.\")\n#         if self.config.pad_token_id is None:\n#             last_non_pad_token = -1\n#         elif input_ids is not None:\n#             # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id\n#             non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)\n#             token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)\n#             last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)\n#         else:\n#             last_non_pad_token = -1\n#             logger.warning_once(\n#                 f\"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be \"\n#                 \"unexpected if using padding tokens in conjunction with `inputs_embeds.`\"\n#             )\n\n#         pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]\n\n#         loss = None\n#         if labels is not None:\n#             loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)\n\n#         return SequenceClassifierOutputWithPast(\n#             loss=loss,\n#             logits=pooled_logits,\n#             past_key_values=transformer_outputs.past_key_values,\n#             hidden_states=transformer_outputs.hidden_states,\n#             attentions=transformer_outputs.attentions,\n#         )\n\n\n# #@auto_docstring\n# class SmallthinkerForTokenClassification(SmallthinkerPreTrainedModel):\n#     def __init__(self, config):\n#         super().__init__(config)\n#         self.num_labels = config.num_labels\n#         self.model = SmallthinkerModel(config)\n#         if getattr(config, \"classifier_dropout\", None) is not None:\n#             classifier_dropout = config.classifier_dropout\n#         elif getattr(config, \"hidden_dropout\", None) is not None:\n#             classifier_dropout = config.hidden_dropout\n#         else:\n#             classifier_dropout = 0.1\n#         self.dropout = nn.Dropout(classifier_dropout)\n#         self.score = nn.Linear(config.hidden_size, config.num_labels)\n\n#         # Initialize weights and apply final processing\n#         self.post_init()\n\n#     def get_input_embeddings(self):\n#         return self.model.embed_tokens\n\n#     def set_input_embeddings(self, value):\n#         self.model.embed_tokens = value\n\n#     @can_return_tuple\n#     #@auto_docstring\n#     def forward(\n#         self,\n#         input_ids: Optional[torch.LongTensor] = None,\n#         attention_mask: Optional[torch.Tensor] = None,\n#         position_ids: Optional[torch.LongTensor] = None,\n#         
past_key_values: Optional[Cache] = None,\n#         inputs_embeds: Optional[torch.FloatTensor] = None,\n#         labels: Optional[torch.LongTensor] = None,\n#         use_cache: Optional[bool] = None,\n#         output_attentions: Optional[bool] = None,\n#         output_hidden_states: Optional[bool] = None,\n#     ) -> TokenClassifierOutput:\n#         r\"\"\"\n#         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n#             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n#             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n#             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n#         \"\"\"\n\n#         outputs: BaseModelOutputWithPast = self.model(\n#             input_ids,\n#             attention_mask=attention_mask,\n#             position_ids=position_ids,\n#             past_key_values=past_key_values,\n#             inputs_embeds=inputs_embeds,\n#             use_cache=use_cache,\n#             output_attentions=output_attentions,\n#             output_hidden_states=output_hidden_states,\n#         )\n#         sequence_output = outputs.last_hidden_state\n#         sequence_output = self.dropout(sequence_output)\n#         logits = self.score(sequence_output)\n\n#         loss = None\n#         if labels is not None:\n#             loss = self.loss_function(logits, labels, self.config)\n\n#         return TokenClassifierOutput(\n#             loss=loss,\n#             logits=logits,\n#             hidden_states=outputs.hidden_states,\n#             attentions=outputs.attentions,\n#         )\n\n\n# #@auto_docstring\n# class SmallthinkerForQuestionAnswering(SmallthinkerPreTrainedModel):\n#     base_model_prefix = \"model\"\n\n#     def __init__(self, config):\n#         super().__init__(config)\n#         self.qa_outputs = nn.Linear(config.hidden_size, 2)\n#         self.model 
= SmallthinkerModel(config)  # diff with Llama: transformer->model\n\n#         # Initialize weights and apply final processing\n#         self.post_init()\n\n#     def get_input_embeddings(self):\n#         return self.model.embed_tokens\n\n#     def set_input_embeddings(self, value):\n#         self.model.embed_tokens = value\n\n#     @can_return_tuple\n#     #@auto_docstring\n#     def forward(\n#         self,\n#         input_ids: Optional[torch.LongTensor] = None,\n#         attention_mask: Optional[torch.Tensor] = None,\n#         position_ids: Optional[torch.LongTensor] = None,\n#         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,\n#         inputs_embeds: Optional[torch.FloatTensor] = None,\n#         start_positions: Optional[torch.LongTensor] = None,\n#         end_positions: Optional[torch.LongTensor] = None,\n#         output_attentions: Optional[bool] = None,\n#         output_hidden_states: Optional[bool] = None,\n#         **kwargs,\n#     ) -> QuestionAnsweringModelOutput:\n#         r\"\"\"\n#         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n#             Labels for position (index) of the start of the labelled span for computing the token classification loss.\n#             Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence\n#             are not taken into account for computing the loss.\n#         end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n#             Labels for position (index) of the end of the labelled span for computing the token classification loss.\n#             Positions are clamped to the length of the sequence (`sequence_length`). 
Position outside of the sequence\n#             are not taken into account for computing the loss.\n#         \"\"\"\n\n#         outputs: BaseModelOutputWithPast = self.model(\n#             input_ids,\n#             attention_mask=attention_mask,\n#             position_ids=position_ids,\n#             past_key_values=past_key_values,\n#             inputs_embeds=inputs_embeds,\n#             output_attentions=output_attentions,\n#             output_hidden_states=output_hidden_states,\n#         )\n\n#         sequence_output = outputs.last_hidden_state\n\n#         logits = self.qa_outputs(sequence_output)\n#         start_logits, end_logits = logits.split(1, dim=-1)\n#         start_logits = start_logits.squeeze(-1).contiguous()\n#         end_logits = end_logits.squeeze(-1).contiguous()\n\n#         loss = None\n#         if start_positions is not None and end_positions is not None:\n#             loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)\n\n#         return QuestionAnsweringModelOutput(\n#             loss=loss,\n#             start_logits=start_logits,\n#             end_logits=end_logits,\n#             hidden_states=outputs.hidden_states,\n#             attentions=outputs.attentions,\n#         )\n\n\n# Only export classes that are actually defined; the classification and QA heads above are commented out.\n__all__ = [\n    \"SmallThinkerForCausalLM\",\n    \"SmallthinkerModel\",\n    \"SmallthinkerPreTrainedModel\",\n]\n\nif __name__ == \"__main__\":\n    from transformers import AutoTokenizer, AutoModelForCausalLM\n\n    test_config = SmallthinkerConfig()\n    tokenizer = AutoTokenizer.from_pretrained(\"./qwen-tokenizer\")\n    text = \"Once upon a day\"\n    tokens = tokenizer.encode_plus(text, add_special_tokens=True, return_tensors='pt')\n    # print(tokens)\n    test_model = AutoModelForCausalLM.from_pretrained(\".\").cuda()\n\n    # generate() expects the encoded tensors as keyword arguments on the model's device\n    output = test_model.generate(**tokens.to(test_model.device))\n   
 otokens = tokenizer.decode(output[0])\n    # print(otokens)\n"
  },
  {
    "path": "archive/ktransformers/operators/RoPE.py",
    "content": "\"\"\"\nDescription  :  \nAuthor       : Boxin Zhang\nVersion      : 0.1.0\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n\"\"\"\n\nfrom torch import nn\nfrom transformers import ROPE_INIT_FUNCTIONS\nfrom ktransformers.models.modeling_llama import (\n    LlamaRotaryEmbedding,\n    LlamaLinearScalingRotaryEmbedding,\n    LlamaDynamicNTKScalingRotaryEmbedding,\n)\nfrom ktransformers.models.modeling_deepseek_v3 import (\n    DeepseekV3RotaryEmbedding\n)\nfrom ktransformers.models.modeling_deepseek import (\n    DeepseekV2YarnRotaryEmbedding,\n    DeepseekV2RotaryEmbedding,\n    yarn_get_mscale,\n    yarn_linear_ramp_mask,\n    yarn_find_correction_range\n)\nfrom ktransformers.operators.base_operator import BaseInjectedModule\nfrom ktransformers.util.custom_loader import GGUFLoader\nfrom ktransformers.util.utils import InferenceState\nfrom transformers.configuration_utils import PretrainedConfig\nfrom ktransformers.models.modeling_smallthinker import SmallthinkerRotaryEmbedding\nfrom ktransformers.models.modeling_glm4_moe import Glm4MoeRotaryEmbedding\nimport torch\n\n# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe\nclass RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module,\n        #  device: str = \"cuda\",\n        generate_device: str = \"cuda\",\n        prefill_device: str = \"cuda\",\n        **kwargs,\n    ):\n        BaseInjectedModule.__init__(\n            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs\n        )\n        self.orig_module.__init__(\n            orig_module.dim, orig_module.max_position_embeddings, orig_module.base\n        )\n        self.generate_device = generate_device\n        self.prefill_device = prefill_device\n\n    def load(self):\n        
self.orig_module.__init__(\n            self.orig_module.dim,\n            self.orig_module.max_position_embeddings,\n            self.orig_module.base,\n            self.device,\n        )\n\n\nclass RotaryEmbeddingV3(BaseInjectedModule):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module,\n        #  device: str = \"cuda\",\n        generate_device: str = \"cuda\",\n        prefill_device: str = \"cuda\",\n        **kwargs,\n    ):\n        BaseInjectedModule.__init__(\n            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs\n        )\n        self.generate_device = generate_device\n        self.prefill_device = prefill_device\n    \n    @torch.no_grad()\n    def forward(self, x, position_ids):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)\n        position_ids_expanded = position_ids[:, None, :].float()\n        # Force float32 since bfloat16 loses precision on long contexts\n        # See https://github.com/huggingface/transformers/pull/29285\n        device_type = x.device.type\n        device_type = device_type if isinstance(device_type, str) and device_type != \"mps\" else \"cpu\"\n        with torch.autocast(device_type=device_type, enabled=False):\n            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos()\n            sin = emb.sin()\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)   \n\n    def load(self):\n        self._init(\n            dim=self.config.qk_rope_head_dim,\n            max_position_embeddings=self.config.max_position_embeddings,\n            base=self.config.rope_theta,\n            device=self.device,\n        )\n    def _init(self, dim, 
max_position_embeddings, base, device, scaling_factor=1.0):\n        self.scaling_factor = scaling_factor\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))\n        # self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        # For BC we register cos and sin cached\n        self.max_seq_len_cached = max_position_embeddings\n\nclass RotaryEmbeddingV2(BaseInjectedModule, LlamaRotaryEmbedding):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module,\n        generate_device: str = \"cuda\",\n        prefill_device: str = \"cuda\",\n        **kwargs,\n    ):\n        BaseInjectedModule.__init__(\n            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs\n        )\n        self.orig_module.__init__(\n            orig_module.dim,\n            orig_module.max_position_embeddings,\n            orig_module.base,\n            None,\n            orig_module.scaling_factor,\n            orig_module.rope_type,\n            orig_module.config,\n        )\n        self.generate_device = generate_device\n        self.prefill_device = prefill_device\n\n    def load(self):\n        self.orig_module.__init__(\n            self.orig_module.dim,\n            self.orig_module.max_position_embeddings,\n            self.orig_module.base,\n            self.device,\n            self.orig_module.scaling_factor,\n            self.orig_module.rope_type,\n            self.orig_module.config,\n        )\n\nclass YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module,\n        #  device: str 
= \"cuda\",\n        generate_device: str = \"cuda\",\n        prefill_device: str = \"cuda\",\n        **kwargs,\n    ):\n        BaseInjectedModule.__init__(\n            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs\n        )\n        self.orig_module.__init__(\n            orig_module.dim,\n            orig_module.max_position_embeddings,\n            orig_module.base,\n            None,  # device\n            orig_module.scaling_factor,\n            orig_module.original_max_position_embeddings,\n            orig_module.beta_fast,\n            orig_module.beta_slow,\n            orig_module.mscale,\n            orig_module.mscale_all_dim,\n        )\n        self.generate_device = generate_device\n        self.prefill_device = prefill_device\n\n    def load(self):\n        self.orig_module.__init__(\n            self.orig_module.dim,\n            self.orig_module.max_position_embeddings,\n            self.orig_module.base,\n            self.generate_device,\n            self.orig_module.scaling_factor,\n            self.orig_module.original_max_position_embeddings,\n            self.orig_module.beta_fast,\n            self.orig_module.beta_slow,\n            self.orig_module.mscale,\n            self.orig_module.mscale_all_dim,\n        )\n\n# class DeepSeekV3YarnRotaryEmbedding(BaseInjectedModule, DeepseekV3RotaryEmbedding):\n#     def __init__(\n#         self,\n#         key: str,\n#         gguf_loader: GGUFLoader,\n#         config: PretrainedConfig,\n#         orig_module: nn.Module,\n#         #  device: str = \"cuda\",\n#         generate_device: str = \"cuda\",\n#         prefill_device: str = \"cuda\",\n#         **kwargs,\n#     ):\n#         BaseInjectedModule.__init__(\n#             self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs\n#         )\n#         self.generate_device = generate_device\n#         self.prefill_device = prefill_device\n\n#     def load(self):\n#   
      # TODO support perlayer prefill\n#         self.orig_module.__init__(\n#             self.config,\n#             device=self.generate_device\n#         )\n#         return\n\nclass YarnRotaryEmbeddingV3(BaseInjectedModule):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module,\n        #  device: str = \"cuda\",\n        generate_device: str = \"cuda\",\n        prefill_device: str = \"cuda\",\n        **kwargs,\n    ):\n        BaseInjectedModule.__init__(\n            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs\n        )\n        self.generate_device = generate_device\n        self.prefill_device = prefill_device\n    \n    def load(self):\n        kwargs = {\n            key: self.config.rope_scaling[key]\n            for key in [\n                \"original_max_position_embeddings\",\n                \"beta_fast\",\n                \"beta_slow\",\n                \"mscale\",\n                \"mscale_all_dim\",\n            ]\n            if key in self.config.rope_scaling\n        }\n        self._init(\n            dim=self.config.qk_rope_head_dim,\n            max_position_embeddings=self.config.max_position_embeddings,\n            base=self.config.rope_theta,\n            device=self.device,\n            scaling_factor=self.config.rope_scaling[\"factor\"],\n            **kwargs,\n        )\n\n    @torch.no_grad()\n    def forward(self, x, position_ids):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)\n        position_ids_expanded = position_ids[:, None, :].float()\n        # Force float32 since bfloat16 loses precision on long contexts\n        # See https://github.com/huggingface/transformers/pull/29285\n        device_type = x.device.type\n        device_type = device_type if 
isinstance(device_type, str) and device_type != \"mps\" else \"cpu\"\n        with torch.autocast(device_type=device_type, enabled=False):\n            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos()* self._mscale\n            sin = emb.sin()* self._mscale\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)  \n\n    def _init(\n        self,\n        dim,\n        max_position_embeddings=2048,\n        base=10000,\n        device=None,\n        scaling_factor=1.0,\n        original_max_position_embeddings=4096,\n        beta_fast=32,\n        beta_slow=1,\n        mscale=1,\n        mscale_all_dim=0,\n    ):\n        self.original_max_position_embeddings = original_max_position_embeddings\n        self.beta_fast = beta_fast\n        self.beta_slow = beta_slow\n        self.mscale = mscale\n        self.mscale_all_dim = mscale_all_dim\n        self.scaling_factor = scaling_factor\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n\n        freq_extra = 1.0 / (\n            self.base\n            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)\n        )\n        freq_inter = 1.0 / (\n            self.scaling_factor\n            * self.base\n            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)\n        )\n\n        low, high = yarn_find_correction_range(\n            self.beta_fast,\n            self.beta_slow,\n            dim,\n            self.base,\n            self.original_max_position_embeddings,\n        )\n        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(\n            device=device, dtype=torch.float32\n        )\n        self.inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask\n        self._mscale = float(\n            yarn_get_mscale(self.scaling_factor, self.mscale)\n 
           / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)\n        )\n        # For BC we register cos and sin cached\n        self.max_seq_len_cached = max_position_embeddings\n\nclass DynamicNTKScalingRotaryEmbedding(\n    BaseInjectedModule, LlamaDynamicNTKScalingRotaryEmbedding\n):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module,\n        prefill_device: str = \"cuda\",\n        generate_device: str = \"cuda\",\n        **kwargs,\n    ):\n        BaseInjectedModule.__init__(\n            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs\n        )\n        self.orig_module.__init__(\n            orig_module.dim,\n            orig_module.max_position_embeddings,\n            orig_module.base,\n            None,  # device\n            orig_module.scaling_factor,\n            orig_module.rope_type,\n            orig_module.config,\n        )\n\n    def load(self):\n        self.orig_module.__init__(\n            self.orig_module.dim,\n            self.orig_module.max_position_embeddings,\n            self.orig_module.base,\n            self.orig_module.device,\n            self.orig_module.scaling_factor,\n            self.orig_module.rope_type,\n            self.orig_module.config,\n        )\n\n\n\nclass RotaryEmbeddingV4(BaseInjectedModule):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module,\n        #  device: str = \"cuda\",\n        generate_device: str = \"cuda\",\n        prefill_device: str = \"cuda\",\n        **kwargs,\n    ):\n        # Pass both devices, in the same order as the other injected RoPE modules in this file\n        BaseInjectedModule.__init__(\n            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs\n        )\n        self.generate_device = generate_device\n        self.prefill_device = prefill_device\n\n    @torch.no_grad()\n    def 
forward(self, x, position_ids):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)\n        position_ids_expanded = position_ids[:, None, :].float()\n        # Force float32 since bfloat16 loses precision on long contexts\n        # See https://github.com/huggingface/transformers/pull/29285\n        device_type = x.device.type\n        device_type = device_type if isinstance(device_type, str) and device_type != \"mps\" else \"cpu\"\n        with torch.autocast(device_type=device_type, enabled=False):\n            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos()\n            sin = emb.sin()\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)\n\n    def load(self):\n        self._init(\n            dim=self.config.qk_rope_head_dim,\n            max_position_embeddings=self.config.max_position_embeddings,\n            base=self.config.rope_theta,\n            device=self.device,\n        )\n\n    def _init(self, dim, max_position_embeddings, base, device, scaling_factor=1.0):\n        self.scaling_factor = scaling_factor\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))\n        # self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        # For BC we register cos and sin cached\n        self.max_seq_len_cached = max_position_embeddings\n\nclass KQwen3MoeRotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module,\n        #  device: str = \"cuda\",\n        generate_device: str = 
\"cuda\",\n        prefill_device: str = \"cuda\",\n        **kwargs,\n    ):\n        BaseInjectedModule.__init__(\n            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs\n        )\n        self.orig_module.__init__(\n            config,\n        )\n        self.generate_device = generate_device\n        self.prefill_device = prefill_device\n\n    def load(self):\n        self.orig_module.__init__(\n            self.orig_module.config\n        )\n    \n\nclass KSmallthinkerRotaryEmbedding(BaseInjectedModule, SmallthinkerRotaryEmbedding):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module,\n        #  device: str = \"cuda\",\n        generate_device: str = \"cuda\",\n        prefill_device: str = \"cuda\",\n        **kwargs,\n    ):\n        BaseInjectedModule.__init__(\n            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs\n        )\n        self.orig_module.__init__(\n            config\n        )\n        self.generate_device = generate_device\n        self.prefill_device = prefill_device\n\n    def load(self):\n        self.orig_module.__init__(\n            self.orig_module.config,\n            device = self.generate_device,\n        )\n        \n    @torch.no_grad()\n    def forward(self, x, position_ids):\n        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)\n        position_ids_expanded = position_ids[:, None, :].float()\n\n        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != \"mps\" else \"cpu\"\n        with torch.autocast(device_type=device_type, enabled=False):  # Force float32\n            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos() * 
self.attention_scaling\n            sin = emb.sin() * self.attention_scaling\n\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)\n\nclass KGlm4MoeRotaryEmbedding(BaseInjectedModule, Glm4MoeRotaryEmbedding):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module,\n        #  device: str = \"cuda\",\n        generate_device: str = \"cuda\",\n        prefill_device: str = \"cuda\",\n        **kwargs,\n    ):\n        BaseInjectedModule.__init__(\n            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs\n        )\n        self.orig_module.__init__(\n            config\n        )\n        self.generate_device = generate_device\n        self.prefill_device = prefill_device\n\n    def load(self):\n        self.orig_module.__init__(\n            self.orig_module.config,\n            device = self.generate_device,\n        )\n        \n    @torch.no_grad()\n    def forward(self, x, position_ids):\n        if \"dynamic\" in self.rope_type:\n            self._dynamic_frequency_update(position_ids, device=x.device)\n        # Core RoPE block\n        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)\n        # print(inv_freq_expanded.device)\n        position_ids_expanded = position_ids[:, None, :].float()\n        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)\n        device_type = x.device.type\n        device_type = device_type if isinstance(device_type, str) and device_type != \"mps\" else \"cpu\"\n        with torch.autocast(device_type=device_type, enabled=False):\n            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos() * self.attention_scaling\n            sin = emb.sin() * self.attention_scaling\n\n        return 
cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)"
  },
  {
    "path": "archive/ktransformers/operators/__init__.py",
    "content": "\n"
  },
  {
    "path": "archive/ktransformers/operators/ascend/ascend_attention.py",
    "content": "# coding=utf-8\r\n# Copyright (c) 2025. Huawei Technologies Co., Ltd. All rights reserved.\r\n# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\nimport os\r\nimport warnings\r\nfrom typing import Optional, Tuple\r\n\r\nimport torch\r\nimport torch_npu\r\nfrom torch import nn\r\nimport torch.nn.functional as F\r\nfrom transformers.configuration_utils import PretrainedConfig\r\n\r\nfrom ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb\r\nfrom ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention, apply_rotary_pos_emb\r\nfrom ktransformers.operators.base_operator import BaseInjectedModule\r\nfrom ktransformers.util.custom_loader import GGUFLoader\r\nfrom ktransformers.util.utils import get_compute_capability, get_use_npu_graph, get_current_device\r\nfrom ktransformers.models.custom_cache import StaticCache\r\nfrom ktransformers.server.balance_serve.inference.forward_batch import ForwardMiniBatchSplit\r\nfrom ktransformers.util.ascend.ascend_utils import get_tensor_parallel_size, allredeuce_warpper, get_tensor_parallel_group\r\nfrom ktransformers.util.vendors import device_manager, GPUVendor\r\nfrom ktransformers.util import utils\r\n\r\n\r\ndef apply_rotary_pos_emb_fusion(q, k, cos, sin, unsqueeze_dim=1):\r\n    cos = cos.unsqueeze(unsqueeze_dim)\r\n    sin = 
sin.unsqueeze(unsqueeze_dim)\r\n    b, h, s, d = q.shape\r\n    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)\r\n    b, h, s, d = k.shape\r\n    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)\r\n    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)\r\n    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)\r\n    return q_embed, k_embed\r\n\r\n\r\nclass MatMulOps(object):\r\n    def execute(self, x_input):\r\n        \"\"\"\r\n            :param x, weight, quant_bia, deq_scale\r\n            :return:\r\n        \"\"\"\r\n        quant_out = x_input[0]\r\n        weight = x_input[1]\r\n        quant_bia = x_input[2]\r\n        deq_scale = x_input[3]\r\n        return [torch_npu.npu_quant_matmul(quant_out, weight.T, deq_scale, bias=quant_bia, output_dtype=torch.float16)]\r\n\r\n\r\nclass DynamicQuantOps(object):\r\n    \"\"\"\r\n        :param x\r\n        :return\r\n    \"\"\"\r\n    def execute(self, x_input):\r\n        out = torch.empty_like(x_input[0], dtype=torch.int8)\r\n        torch_npu._npu_quantize_per_tensor(x_input[0], x_input[1], x_input[2], out)\r\n        return [out]\r\n\r\nclass KDeepseekV2AttentionW8A8A2(BaseInjectedModule, DeepseekV2Attention):\r\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\r\n    attn_mask: Optional[torch.Tensor] = None\r\n\r\n    class PageKVWrapper(object):\r\n        \"\"\"\r\n        wrap the difference of KV Cache and Block info between offline model & direct serving & sched serving\r\n        succession should keep the function api\r\n        \"\"\"\r\n        def __init__(self, past_key_value: StaticCache):\r\n            self.kv_cache = past_key_value\r\n            self.page_size = self.kv_cache.page_size\r\n            self.position = self.kv_cache.position\r\n\r\n            self.page_idx = None # staticKV can get from itself\r\n            self.page_offset = None\r\n\r\n        def update(self, compressed_kv, k_pe, layer_idx, cache_kwargs):\r\n       
     return self.kv_cache.update(compressed_kv, k_pe, layer_idx, cache_kwargs)\r\n        \r\n        def get_usable_length(self, kv_seq_len, layer_idx):\r\n            return self.kv_cache.get_usable_length(kv_seq_len, layer_idx)\r\n        \r\n        def get_seq_length(self, layer_idx):\r\n            return self.kv_cache.get_seq_length(layer_idx)\r\n        \r\n        def get_block_table(self, layer_idx):\r\n            return self.kv_cache.page_table_list[layer_idx]\r\n\r\n    def init_page_kv_wrapper(self, past_key_value: StaticCache):\r\n        self.page_kv_wrapper = self.PageKVWrapper(past_key_value)\r\n\r\n    def __init__(self,\r\n                 key: str,\r\n                 gguf_loader: GGUFLoader,\r\n                 config: PretrainedConfig,\r\n                 orig_module: nn.Module,\r\n                 prefill_device: str = \"cuda\",\r\n                 generate_device: str = \"cuda\",\r\n                 chunck_size: int = 1000,\r\n                 absorb_for_prefill: bool = False,\r\n                 **kwargs):\r\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)\r\n        self.orig_module.__init__(orig_module.config,\r\n                                  orig_module.layer_idx)\r\n        self.chunck_size = config.chunk_size\r\n        self.mla_wrapper = None\r\n        self.page_kv_wrapper = None\r\n        self.absorb_for_prefill = absorb_for_prefill\r\n        self.use_merge = os.getenv(\"USE_MERGE\", \"0\")\r\n        tp = get_tensor_parallel_size()\r\n        if tp > 1:\r\n            self.num_heads //= tp\r\n\r\n        if self.use_merge == \"0\":\r\n            self.elewise_quant = DynamicQuantOps()\r\n            self.matmulDequant_operation = MatMulOps()\r\n            self.matmulDequant_operation_aclnn = MatMulOps()\r\n        elif self.use_merge == \"1\":\r\n            print(\"--Use torch npu FA OP !--\")\r\n        else:\r\n            print(\"--Use default 
op !--\")\r\n        \r\n        self.sparse_mode = 0\r\n\r\n    @allredeuce_warpper\r\n    def forward_chunck(\r\n        self,\r\n        hidden_states: torch.Tensor,\r\n        attention_mask: Optional[torch.Tensor] = None,\r\n        position_ids: Optional[torch.LongTensor] = None,\r\n        past_key_value: Optional[StaticCache] = None,\r\n        output_attentions: bool = False,\r\n        use_cache: bool = False,\r\n        cache_position: Optional[torch.LongTensor] = None,\r\n        is_prefill: bool = True,\r\n        **kwargs\r\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\r\n        bsz, q_len, _ = hidden_states.size()\r\n        if self.q_lora_rank is None:\r\n            q = self.q_proj(hidden_states)\r\n        else:\r\n            hidden_states_quant = self.elewise_quant.execute([hidden_states, self.q_a_proj.input_scale, self.q_a_proj.input_offset])[0]\r\n            q_a_proj_out = self.matmulDequant_operation.execute([hidden_states_quant, self.q_a_proj.weight,\r\n                                                                 self.q_a_proj.quant_bias, self.q_a_proj.deq_scale])[0]\r\n            q_a_proj_out = self.q_a_layernorm(q_a_proj_out)\r\n            q_a_proj_out = self.elewise_quant.execute([q_a_proj_out, self.q_b_proj.input_scale, self.q_b_proj.input_offset])[0]\r\n            q = self.matmulDequant_operation.execute([q_a_proj_out, self.q_b_proj.weight,\r\n                                                      self.q_b_proj.quant_bias, self.q_b_proj.deq_scale])[0]\r\n        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)\r\n        q_nope, q_pe = torch.split(\r\n            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1\r\n        )\r\n\r\n        hidden_states_quant = self.elewise_quant.execute([hidden_states, self.kv_a_proj_with_mqa.input_scale, self.kv_a_proj_with_mqa.input_offset])[0]\r\n        compressed_kv = 
self.matmulDequant_operation.execute([hidden_states_quant, self.kv_a_proj_with_mqa.weight,\r\n                                                              self.kv_a_proj_with_mqa.quant_bias, self.kv_a_proj_with_mqa.deq_scale])[0]\r\n        compressed_kv, k_pe = torch.split(\r\n            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\r\n        )\r\n        compressed_kv = self.kv_a_layernorm(compressed_kv)\r\n        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)\r\n\r\n        kv_seq_len = k_pe.shape[-2]\r\n        if past_key_value is not None:\r\n            if self.layer_idx is None:\r\n                raise ValueError(\r\n                    f\"The cache structure has changed since transformers version v4.36. If you are using {self.__class__.__name__} \"\r\n                    \"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class \"\r\n                    \"with a layer index.\"\r\n                )\r\n            kv_seq_len += self.page_kv_wrapper.get_usable_length(kv_seq_len, self.layer_idx)\r\n        cos, sin = self.rotary_emb(q_pe, position_ids)\r\n        q_pe, k_pe = apply_rotary_pos_emb_fusion(q_pe, k_pe, cos, sin)\r\n\r\n        # update KV\r\n        if past_key_value is not None:\r\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\r\n            cache_kwargs[\"page_idx\"] = self.page_kv_wrapper.page_idx\r\n            cache_kwargs[\"page_offset\"] = self.page_kv_wrapper.page_offset\r\n            k_pe = k_pe.transpose(1, 2)                 # k_pe: [bsz, 1, q_len, qk_rope_head_dim] -> [bsz, q_len, 1, qk_rope_head_dim]\r\n            compressed_kv = compressed_kv.unsqueeze(2)  # compressed_kv: [bsz, q_len, kv_lora_rank] -> [bsz, q_len, 1, kv_lora_rank]\r\n            compressed_kv_with_k_pe, _ = self.page_kv_wrapper.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)\r\n            if is_prefill:\r\n                compressed_kv_prefill = 
compressed_kv.clone() # clone for prefill infer\r\n                k_pe_prefill = k_pe.clone()\r\n            compressed_kv, k_pe = torch.split(\r\n                compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\r\n            )\r\n\r\n        weight_uk = self.q_absorb\r\n        weight_uv = self.out_absorb\r\n\r\n        # ATB-MLA-FA+PA\r\n        if self.use_merge == \"0\" and is_prefill:\r\n            # if self.layer_idx == 0:\r\n            #   print(self.page_kv_wrapper.get_seq_length(self.layer_idx)\r\n            #   self.page_kv_wrapper.get_block_table(self.layer_idx), self.page_kv_wrapper.position)\r\n            current_sqenLen = self.page_kv_wrapper.get_seq_length(self.layer_idx)\r\n            attention_mask = attention_mask[0, :, :, :current_sqenLen].squeeze(0).squeeze(0)\r\n\r\n            # FIXME this is wrong in random choose pages for sched, currently just use kv without history\r\n            # compressed_kv = compressed_kv.view(bsz, 1, -1, self.kv_lora_rank)[:,:,:current_sqenLen,:]\r\n            # k_pe = k_pe.view(bsz, 1, -1, self.qk_rope_head_dim)[:,:,:current_sqenLen,:]\r\n            compressed_kv = compressed_kv_prefill.transpose(1,2).contiguous()\r\n            k_pe = k_pe_prefill.transpose(1,2).contiguous()\r\n\r\n            k_pe_repeated = k_pe.repeat(1, self.num_heads, 1, 1)\r\n            k_up = torch.matmul(compressed_kv, weight_uk.mT)\r\n            v_up = torch.matmul(compressed_kv, weight_uv)\r\n\r\n            qTensor = torch.cat((q_nope, q_pe), dim=-1).transpose(1, 2).contiguous().view(\r\n                                        bsz, q_len, self.num_heads, (self.qk_nope_head_dim + self.qk_rope_head_dim))\r\n            kTensor = torch.cat((k_up, k_pe_repeated), dim=-1).transpose(1, 2).contiguous().view(\r\n                                        bsz, current_sqenLen, self.num_heads, (self.qk_nope_head_dim + self.qk_rope_head_dim))\r\n            vTensor = torch.cat((v_up, k_pe_repeated), 
dim=-1).transpose(1, 2).contiguous().view(\r\n                                        bsz, current_sqenLen, self.num_heads, (self.v_head_dim + self.qk_rope_head_dim))\r\n\r\n            seq_len_data = [q_len] * bsz\r\n\r\n            infer_attention_output, _ = torch_npu.npu_fused_infer_attention_score(\r\n                qTensor, kTensor, vTensor,\r\n                atten_mask = attention_mask.type(torch.int8),\r\n                actual_seq_lengths = seq_len_data,\r\n                scale = self.softmax_scale,\r\n                num_heads = self.num_heads,\r\n                num_key_value_heads = self.num_heads,\r\n                input_layout = \"BSND\")\r\n            \r\n            attn_output = infer_attention_output[..., :self.v_head_dim]\r\n            if tuple(attn_output.size()) != (bsz, q_len, self.num_heads, self.v_head_dim):\r\n                raise ValueError(\r\n                    f\"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.v_head_dim)}, but is\"\r\n                    f\" {tuple(attn_output.size())}\"\r\n                )\r\n\r\n            attn_output = attn_output.contiguous().view(bsz, q_len, self.num_heads * self.v_head_dim)\r\n            attn_output = self.elewise_quant.execute([attn_output, self.o_proj.input_scale, self.o_proj.input_offset])[0]\r\n            attn_output = self.matmulDequant_operation_aclnn.execute([attn_output, self.o_proj.weight,\r\n                                                                self.o_proj.quant_bias, self.o_proj.deq_scale])[0]\r\n\r\n            return attn_output, None, past_key_value\r\n\r\n        elif self.use_merge == \"0\" and not is_prefill:\r\n            return self.forward_paged(q_pe=q_pe,\r\n                                      q_nope=q_nope,\r\n                                      compressed_kv_with_k_pe=compressed_kv_with_k_pe,\r\n                                      past_key_value=past_key_value,\r\n                                      
cache_position=cache_position)\r\n\r\n        if self.use_merge == \"1\":\r\n            k_pe_repeated = k_pe.repeat(1, self.num_heads, 1, 1)\r\n            k_up = torch.matmul(compressed_kv, weight_uk.mT)\r\n            v_up = torch.matmul(compressed_kv, weight_uv)\r\n            qTensor = torch.cat((q_nope, q_pe), dim=-1)\r\n            kTensor = torch.cat((k_up, k_pe_repeated), dim=-1)\r\n            vTensor = torch.cat((v_up, k_pe_repeated), dim=-1)\r\n\r\n            if q_len != 1:\r\n                attn_output = torch_npu.npu_prompt_flash_attention(\r\n                    qTensor, kTensor, vTensor,\r\n                    num_heads=self.num_heads, scale_value=self.softmax_scale, input_layout=\"BNSD\")\r\n            else:\r\n                attn_output = torch_npu.npu_incre_flash_attention(\r\n                    qTensor, kTensor, vTensor,\r\n                    num_heads=self.num_heads, scale_value=self.softmax_scale, input_layout=\"BNSD\")\r\n            attn_output = attn_output[:, :, :, :self.v_head_dim]\r\n        else:\r\n            q_nope = torch.matmul(q_nope, self.q_absorb)\r\n\r\n            attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.mT)) * self.softmax_scale\r\n\r\n            compressed_kv = compressed_kv.squeeze(1)\r\n            \"\"\"\r\n            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):\r\n                raise ValueError(\r\n                    f\"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is\"\r\n                    f\" {attn_weights.size()}\"\r\n                )\r\n            assert attention_mask is not None\r\n            \"\"\"\r\n        if attention_mask is not None:\r\n            \"\"\"\r\n            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):\r\n                raise ValueError(\r\n                    f\"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}\"\r\n        
        )\r\n            \"\"\"\r\n            attn_weights = attn_weights + attention_mask\r\n\r\n            # upcast attention to fp32\r\n            attn_weights = nn.functional.softmax(\r\n                attn_weights, dim=-1, dtype=torch.float32\r\n            ).to(q_pe.dtype)\r\n            attn_weights = nn.functional.dropout(\r\n                attn_weights, p=self.attention_dropout, training=self.training\r\n            )\r\n            attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)\r\n\r\n            attn_output = torch.matmul(attn_output, self.out_absorb)\r\n\r\n        if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):\r\n            raise ValueError(\r\n                f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is\"\r\n                f\" {attn_output.size()}\"\r\n            )\r\n\r\n        attn_output = attn_output.transpose(1, 2).contiguous()\r\n\r\n        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)\r\n\r\n        attn_output = self.o_proj(attn_output)\r\n\r\n        return attn_output, None, past_key_value\r\n\r\n    def forward_paged(\r\n        self,\r\n        q_pe: torch.Tensor,\r\n        q_nope: torch.Tensor,\r\n        compressed_kv_with_k_pe: torch.Tensor,\r\n        past_key_value: Optional[StaticCache] = None,\r\n        cache_position: Optional[torch.LongTensor] = None,\r\n        **kwargs\r\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\r\n        # if self.layer_idx == 1:\r\n        #   print(self.page_kv_wrapper.get_block_table(self.layer_idx), self.page_kv_wrapper.position)\r\n        bsz, _, q_len, _ = q_nope.size()\r\n        q_nope = torch.einsum('b h q d, h d k -> b h q k', q_nope, self.q_absorb)  # torch.Size([1, 128, 1, 512])\r\n        compressed_kv = compressed_kv_with_k_pe.permute(0, 2, 1, 3)\r\n        kvCache = compressed_kv[:, :, :, 
:self.kv_lora_rank].contiguous()\r\n        kRopeCache = compressed_kv[:, :, :, self.kv_lora_rank:].contiguous()\r\n        if get_use_npu_graph():\r\n            from ktransformers.util.npu_graph_runner import get_or_create_runner\r\n            npu_graph_runner = get_or_create_runner(get_current_device())\r\n            stream = npu_graph_runner.main_stream\r\n            if npu_graph_runner.past_key_value is None:\r\n                npu_graph_runner.past_key_value = past_key_value\r\n            if npu_graph_runner.workspace is None:\r\n                workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(\r\n                    q_nope,\r\n                    kvCache,\r\n                    kvCache,\r\n                    query_rope=q_pe,\r\n                    key_rope=kRopeCache,\r\n                    num_heads=self.num_heads,\r\n                    num_key_value_heads=1,\r\n                    input_layout=\"BNSD\",\r\n                    scale=self.softmax_scale,\r\n                    antiquant_mode=0,\r\n                    antiquant_scale=None,\r\n                    block_table=self.page_kv_wrapper.get_block_table(self.layer_idx),\r\n                    block_size=self.page_kv_wrapper.page_size,\r\n                    actual_seq_lengths_kv=self.page_kv_wrapper.position,\r\n                    sparse_mode = self.sparse_mode)\r\n                npu_graph_runner.workspace = workspace\r\n            attn_output = torch.zeros_like(q_nope, dtype=torch.float16, device=get_current_device())\r\n            softmax_lse = torch.empty(1, dtype=torch.float16, device=get_current_device())\r\n            torch_npu.npu_fused_infer_attention_score.out(\r\n                q_nope,\r\n                kvCache,\r\n                kvCache,\r\n                workspace=npu_graph_runner.workspace,\r\n                query_rope=q_pe,\r\n                key_rope=kRopeCache,\r\n                num_heads=self.num_heads,\r\n                
num_key_value_heads=1,\r\n                input_layout=\"BNSD\",\r\n                scale=self.softmax_scale,\r\n                antiquant_mode=0,\r\n                antiquant_scale=None,\r\n                block_table=self.page_kv_wrapper.get_block_table(self.layer_idx),\r\n                block_size=self.page_kv_wrapper.page_size,\r\n                actual_seq_lengths_kv=self.page_kv_wrapper.position,\r\n                sparse_mode = self.sparse_mode,\r\n                out=[attn_output, softmax_lse])\r\n\r\n        else:\r\n            attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(\r\n                q_nope,\r\n                kvCache,\r\n                kvCache,\r\n                query_rope=q_pe,\r\n                key_rope=kRopeCache,\r\n                num_heads=self.num_heads,\r\n                num_key_value_heads=1,\r\n                input_layout=\"BNSD\",\r\n                scale=self.softmax_scale,\r\n                antiquant_mode=0,\r\n                antiquant_scale=None,\r\n                block_table=self.page_kv_wrapper.get_block_table(self.layer_idx),\r\n                block_size=self.page_kv_wrapper.page_size,\r\n                actual_seq_lengths_kv=self.page_kv_wrapper.position,\r\n                sparse_mode = self.sparse_mode\r\n            )\r\n\r\n        attn_output = torch.einsum('b h q k, h k v -> b q h v', attn_output, self.out_absorb)\r\n        attn_output = attn_output.contiguous().view(bsz, q_len, self.num_heads * self.v_head_dim)\r\n        attn_output = self.elewise_quant.execute([attn_output, self.o_proj.input_scale, self.o_proj.input_offset])[0]\r\n        attn_output = self.matmulDequant_operation_aclnn.execute([attn_output, self.o_proj.weight,\r\n                                                            self.o_proj.quant_bias, self.o_proj.deq_scale])[0]\r\n        return attn_output, None, past_key_value\r\n\r\n    def forward_windows(\r\n        self,\r\n        hidden_states: torch.Tensor,\r\n        
attention_mask: Optional[torch.Tensor] = None,\r\n        position_ids: Optional[torch.LongTensor] = None,\r\n        past_key_value: Optional[StaticCache] = None,\r\n        output_attentions: bool = False,\r\n        use_cache: bool = False,\r\n        cache_position: Optional[torch.LongTensor] = None,\r\n        is_prefill: bool = True,\r\n        **kwargs,\r\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\r\n        if \"padding_mask\" in kwargs:\r\n            warnings.warn(\r\n                \"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead.\"\r\n            )\r\n\r\n        self.init_page_kv_wrapper(past_key_value)\r\n        bsz, q_len, _ = hidden_states.size()\r\n\r\n        if q_len <= self.chunck_size:\r\n            return self.forward_chunck(\r\n                hidden_states,\r\n                attention_mask,\r\n                position_ids,\r\n                past_key_value,\r\n                output_attentions,\r\n                use_cache,\r\n                cache_position,\r\n                is_prefill,\r\n                **kwargs\r\n            )\r\n\r\n        assert output_attentions == False, \"output_attentions is not supported when using chunked attention\"\r\n        attn_output = None\r\n        cur_idx = 0\r\n        while cur_idx < q_len:\r\n            if attention_mask is not None:\r\n                chunk_mask = attention_mask[:, :, cur_idx:min(cur_idx + self.chunck_size, q_len), ...]\r\n            else:\r\n                # Generate chunk_mask automatically, reusing the cached buffer across chunks.\r\n                self.attn_mask = \\\r\n                    torch.zeros(1, 1, self.chunck_size, past_key_value.max_cache_len, device=hidden_states.device) \\\r\n                        if self.attn_mask is None \\\r\n                        else self.attn_mask\r\n                self.attn_mask[:, :, :, cur_idx:min(cur_idx + self.chunck_size, past_key_value.max_cache_len)] = 
\\\r\n                    -65504.0 * torch.triu(torch.ones(self.chunck_size, self.chunck_size, device=hidden_states.device), diagonal=1) \\\r\n                        [:, :min(self.chunck_size, min(past_key_value.max_cache_len - cur_idx, self.chunck_size))]\r\n                self.attn_mask[:, :, :, cur_idx + self.chunck_size:] = -65504.0\r\n                self.attn_mask[:, :, :, :cur_idx] = 0\r\n                chunk_mask = torch.narrow(self.attn_mask, 2, 0, min(self.chunck_size, q_len - cur_idx))\r\n\r\n            cur_output, _, _ = self.forward_chunck(\r\n                hidden_states[:, cur_idx:min(cur_idx + self.chunck_size, q_len), ...],\r\n                chunk_mask,\r\n                position_ids[:, cur_idx:min(cur_idx + self.chunck_size, q_len)],\r\n                past_key_value,\r\n                output_attentions,\r\n                use_cache,\r\n                cache_position[cur_idx:min(cur_idx + self.chunck_size, q_len)],\r\n                **kwargs\r\n            )\r\n            cur_idx += self.chunck_size\r\n            if attn_output is None:\r\n                attn_output = cur_output\r\n            else:\r\n                attn_output = torch.cat((attn_output, cur_output), dim=-2)\r\n\r\n        return attn_output, None, past_key_value\r\n\r\n    def forward(\r\n            self,\r\n            hidden_states: torch.Tensor,\r\n            attention_mask: Optional[torch.Tensor] = None,\r\n            position_ids: Optional[torch.LongTensor] = None,\r\n            past_key_value: Optional[StaticCache] = None,\r\n            output_attentions: bool = False,\r\n            use_cache: bool = False,\r\n            cache_position: Optional[torch.LongTensor] = None,\r\n            is_prefill: bool = True,\r\n            **kwargs,\r\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\r\n        # TODO: remove cache_position since it does not support multi-batch\r\n        return self.forward_windows(\r\n            
hidden_states,\r\n            attention_mask,\r\n            position_ids,\r\n            past_key_value,\r\n            output_attentions,\r\n            use_cache,\r\n            cache_position,\r\n            is_prefill,\r\n            **kwargs,\r\n        )\r\n\r\n\r\nclass KDeepseekV2AttentionW8A8A2Serve(BaseInjectedModule, DeepseekV2Attention):\r\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\r\n    attn_mask: Optional[torch.Tensor] = None\r\n\r\n    def __init__(self,\r\n                 key: str,\r\n                 gguf_loader: GGUFLoader,\r\n                 config: PretrainedConfig,\r\n                 orig_module: nn.Module,\r\n                 prefill_device: str = \"cuda\",\r\n                 generate_device: str = \"cuda\",\r\n                 chunck_size: int = 1024,\r\n                 absorb_for_prefill: bool = False,\r\n                 **kwargs):\r\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)\r\n        self.orig_module.__init__(orig_module.config, orig_module.layer_idx)\r\n\r\n        # self.chunck_size = chunck_size\r\n        self.absorb_for_prefill = absorb_for_prefill\r\n        self.elewise_quant = DynamicQuantOps()\r\n        self.matmulDequant_operation = MatMulOps()\r\n        self.matmulDequant_operation_aclnn = MatMulOps()\r\n        # Tensor-parallel (TP) split: each rank handles num_heads // tp heads\r\n        tp = get_tensor_parallel_size()\r\n        if tp > 1:\r\n            self.num_heads //= tp\r\n        \r\n        self.sparse_mode = 0\r\n    \r\n    def print_callback(self, param):\r\n        with torch.npu.stream(torch.npu.Stream(device=\"npu:0\")):\r\n            hidden_states, position_ids, cache_position, page_idx, page_offset, block_table = param\r\n            print(\"########################################\")\r\n            print(\"hidden_states is \", hidden_states)\r\n            print(\"position_ids is \", position_ids)\r\n            print(\"cache_position is \", 
cache_position)\r\n            print(\"page_idx is \", page_idx)\r\n            print(\"page_offset is \", page_offset)\r\n            print(\"block_table is \", block_table)\r\n            print(\"########################################\")\r\n    \r\n    @allredeuce_warpper\r\n    def forward(\r\n            self,\r\n            hidden_states: torch.Tensor,\r\n            attention_mask: Optional[torch.Tensor] = None,\r\n            position_ids: Optional[torch.LongTensor] = None,\r\n            past_key_value: Optional[StaticCache] = None,\r\n            output_attentions: bool = False,\r\n            use_cache: bool = False,\r\n            cache_position: Optional[torch.LongTensor] = None,\r\n            is_prefill: Optional[bool] = None,\r\n            page_idx: Optional[torch.Tensor] = None,\r\n            page_offset: Optional[torch.Tensor] = None,\r\n            block_table: Optional[torch.Tensor] = None,\r\n            q_len_raw: Optional[torch.Tensor] = None,\r\n            kv_len_raw: Optional[torch.Tensor] = None,\r\n            stream: Optional[torch.npu.Stream] = None,\r\n            **kwargs,\r\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\r\n\r\n        def create_causal_mask(q_lens, kv_lens):\r\n            q_lens = torch.tensor(q_lens)\r\n            kv_lens = torch.tensor(kv_lens)\r\n            bsz = q_lens.size(0)\r\n\r\n            max_q_len = q_lens.max().item()\r\n            max_kv_len = kv_lens.max().item()\r\n\r\n            # causal mask [max_q_len, max_kv_len]\r\n            base_causal = torch.tril(torch.ones((max_q_len, max_kv_len), dtype=torch.bool))\r\n\r\n            # mask initialize: [bsz, max_q_len, max_kv_len] to False\r\n            mask = torch.zeros((bsz, max_q_len, max_kv_len), dtype=torch.bool)\r\n\r\n            for i in range(bsz):\r\n                ql, kl = q_lens[i].item(), kv_lens[i].item()\r\n                # copy base_causal to mask\r\n                mask[i, :ql, :kl] = 
base_causal[:ql, :kl]\r\n            \r\n            return mask\r\n        \r\n        bsz, q_len, _ = hidden_states.size()\r\n        if self.q_lora_rank is None:\r\n            q = self.q_proj(hidden_states)\r\n        else:\r\n            hidden_states_quant = self.elewise_quant.execute([hidden_states, self.q_a_proj.input_scale, self.q_a_proj.input_offset])[0]\r\n            q_a_proj_out = self.matmulDequant_operation.execute([hidden_states_quant, self.q_a_proj.weight,\r\n                                                                 self.q_a_proj.quant_bias, self.q_a_proj.deq_scale])[0]\r\n            q_a_proj_out = self.q_a_layernorm(q_a_proj_out)\r\n            q_a_proj_out = self.elewise_quant.execute([q_a_proj_out, self.q_b_proj.input_scale, self.q_b_proj.input_offset])[0]\r\n            q = self.matmulDequant_operation.execute([q_a_proj_out, self.q_b_proj.weight,\r\n                                                      self.q_b_proj.quant_bias, self.q_b_proj.deq_scale])[0]\r\n        \r\n        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)\r\n        q_nope, q_pe = torch.split(\r\n            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1\r\n        )\r\n\r\n        hidden_states_quant = self.elewise_quant.execute([hidden_states, self.kv_a_proj_with_mqa.input_scale, self.kv_a_proj_with_mqa.input_offset])[0]\r\n        compressed_kv = self.matmulDequant_operation.execute([hidden_states_quant, self.kv_a_proj_with_mqa.weight,\r\n                                                              self.kv_a_proj_with_mqa.quant_bias, self.kv_a_proj_with_mqa.deq_scale])[0]\r\n        compressed_kv, k_pe = torch.split(\r\n            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\r\n        )\r\n        compressed_kv = self.kv_a_layernorm(compressed_kv)\r\n        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)\r\n\r\n        kv_seq_len = k_pe.shape[-2]\r\n        if past_key_value is not 
None:\r\n            if self.layer_idx is None:\r\n                raise ValueError(\r\n                    f\"The cache structure has changed since transformer version v4.36. If you are using {self.__class__.__name__} \"\r\n                    \"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class \"\r\n                    \"with a layer index.\"\r\n                )\r\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\r\n        cos, sin = self.rotary_emb(q_pe, position_ids)\r\n        q_pe, k_pe = apply_rotary_pos_emb_fusion(q_pe, k_pe, cos, sin)\r\n\r\n        # update KV\r\n        compressed_kv_prefill, k_pe_prefill = None, None\r\n        if past_key_value is not None:\r\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position} # Specific to RoPE models\r\n            cache_kwargs[\"page_idx\"], cache_kwargs[\"page_offset\"] = page_idx, page_offset\r\n            k_pe = k_pe.transpose(1, 2)                # k_pe [bsz, 1, q_len, self.qk_rope_head_dim]\r\n            compressed_kv = compressed_kv.unsqueeze(2) # compressed_kv [bsz, q_len, self.kv_lora_rank]\r\n            combined = torch.cat([compressed_kv, k_pe], dim=-1) # shape: [batch_size, num_heads, 2*self.kv_lora_rank]\r\n            # combined = combined.contiguous()\r\n\r\n            compressed_kv_with_k_pe, _ = past_key_value.update(combined, self.layer_idx, cache_kwargs)\r\n            if is_prefill:\r\n                compressed_kv_prefill = compressed_kv.clone()\r\n                k_pe_prefill = k_pe.clone()\r\n            \r\n        weight_uk = self.q_absorb\r\n        weight_uv = self.out_absorb\r\n\r\n        if is_prefill:\r\n            kTensor_list = []\r\n            vTensor_list = []\r\n            qTensor_list = []\r\n            attention_mask_list = []\r\n            seq_len_data = []\r\n            kv_len_list = []\r\n\r\n            for sample_idx in range(bsz):\r\n    
            current_q_len = q_len_raw[sample_idx].item() if (q_len_raw is not None and sample_idx < len(q_len_raw)) else hidden_states.shape[1]\r\n                current_kv_len = kv_len_raw[sample_idx].item() if (kv_len_raw is not None and sample_idx < len(kv_len_raw)) else current_q_len\r\n                current_q_len = max(1, current_q_len)\r\n                current_kv_len = max(1, current_kv_len)\r\n                seq_len_data.append(current_q_len)\r\n                kv_len_list.append(current_kv_len)\r\n\r\n                if attention_mask is not None:\r\n                    mask_sample = attention_mask[\r\n                        sample_idx:sample_idx+1, :, :, :current_kv_len\r\n                    ].squeeze(0).squeeze(0)\r\n                    if mask_sample.shape[0] < current_q_len:\r\n                        mask_sample = torch.nn.functional.pad(mask_sample, (0, 0, 0, current_q_len - mask_sample.shape[0]), value=1)\r\n                    elif mask_sample.shape[0] > current_q_len:\r\n                        mask_sample = mask_sample[:current_q_len, :]\r\n                    if mask_sample.shape[1] < current_kv_len:\r\n                        mask_sample = torch.nn.functional.pad(mask_sample, (0, current_kv_len - mask_sample.shape[1]), value=1)\r\n                    elif mask_sample.shape[1] > current_kv_len:\r\n                        mask_sample = mask_sample[:, :current_kv_len]\r\n                    mask_sample = torch.where(\r\n                        (mask_sample > -1e-6) & (mask_sample < 1e-6),\r\n                        torch.tensor(0, device=mask_sample.device, dtype=torch.int8),\r\n                        torch.tensor(1, device=mask_sample.device, dtype=torch.int8)\r\n                    )\r\n                else:\r\n                    mask_sample = torch.ones((current_q_len, current_kv_len), device=hidden_states.device, dtype=torch.int8)\r\n                    valid_len = min(current_q_len, current_kv_len)\r\n                    
mask_sample[:, :valid_len] = 0\r\n\r\n                attention_mask_list.append(mask_sample)\r\n\r\n                compressed_kv_sample = compressed_kv_prefill[sample_idx:sample_idx+1, :current_q_len, ...].transpose(1, 2).contiguous()\r\n                k_pe_sample = k_pe_prefill[sample_idx:sample_idx+1, :current_q_len, ...].transpose(1, 2).contiguous()\r\n                k_pe_repeated_sample = k_pe_sample.repeat(1, self.num_heads, 1, 1)\r\n\r\n                q_nope_sample = q_nope[sample_idx:sample_idx+1, :, :current_q_len, :].contiguous()\r\n                q_pe_sample = q_pe[sample_idx:sample_idx+1, :, :current_q_len, :].contiguous()\r\n                q_concat_sample = torch.cat((q_nope_sample, q_pe_sample), dim=-1)\r\n                q_transposed_sample = q_concat_sample.transpose(1, 2).contiguous()\r\n                qTensor_sample = q_transposed_sample.view(current_q_len, self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim)\r\n                qTensor_list.append(qTensor_sample)\r\n\r\n                k_up_sample = torch.matmul(compressed_kv_sample, weight_uk.mT)\r\n                k_concat_sample = torch.cat((k_up_sample, k_pe_repeated_sample), dim=-1)\r\n                k_transposed_sample = k_concat_sample.transpose(1, 2).contiguous()\r\n                kTensor_sample = k_transposed_sample.view(current_kv_len, self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim)\r\n                kTensor_list.append(kTensor_sample)\r\n\r\n                v_up_sample = torch.matmul(compressed_kv_sample, weight_uv)\r\n                v_concat_sample = torch.cat((v_up_sample, k_pe_repeated_sample), dim=-1)\r\n                v_transposed_sample = v_concat_sample.transpose(1, 2).contiguous()\r\n                vTensor_sample = v_transposed_sample.view(current_kv_len, self.num_heads, self.v_head_dim + self.qk_rope_head_dim)\r\n                vTensor_list.append(vTensor_sample)\r\n        \r\n            max_kv_len = max(kv_len_list)\r\n            
max_q_len = max(seq_len_data)\r\n\r\n            qTensor = torch.nn.utils.rnn.pad_sequence(qTensor_list, batch_first=True, padding_value=0.0).contiguous()\r\n            kTensor = torch.nn.utils.rnn.pad_sequence(kTensor_list, batch_first=True, padding_value=0.0).contiguous()\r\n            vTensor = torch.nn.utils.rnn.pad_sequence(vTensor_list, batch_first=True, padding_value=0.0).contiguous()\r\n\r\n            attention_mask = ~create_causal_mask(seq_len_data, kv_len_list).to(qTensor.device)\r\n\r\n            infer_attention_output, _ = torch_npu.npu_fused_infer_attention_score(\r\n                    qTensor, kTensor, vTensor,\r\n                    atten_mask = attention_mask.type(torch.int8),\r\n                    actual_seq_lengths = seq_len_data,\r\n                    scale = self.softmax_scale,\r\n                    num_heads = self.num_heads,\r\n                    num_key_value_heads = self.num_heads,\r\n                    input_layout = \"BSND\")\r\n                \r\n            attn_output = infer_attention_output[..., :self.v_head_dim]\r\n\r\n            if tuple(attn_output.size()) != (bsz, max_q_len, self.num_heads, self.v_head_dim):\r\n                raise ValueError(\r\n                    f\"`attn_output` should be of size {(bsz, max_q_len, self.num_heads, self.v_head_dim)}, but is {tuple(attn_output.size())}\"\r\n                )\r\n            attn_output = attn_output.contiguous().view(bsz, max_q_len, self.num_heads * self.v_head_dim)\r\n            attn_output = self.elewise_quant.execute([attn_output, self.o_proj.input_scale, self.o_proj.input_offset])[0]\r\n            attn_output = self.matmulDequant_operation_aclnn.execute([attn_output, self.o_proj.weight,\r\n                                                                        self.o_proj.quant_bias, self.o_proj.deq_scale])[0]\r\n\r\n\r\n            return attn_output, None, past_key_value\r\n        else:\r\n            return self.forward_paged(q_pe = q_pe,\r\n                
                      q_nope = q_nope,\r\n                                      compressed_kv_with_k_pe = compressed_kv_with_k_pe,\r\n                                      past_key_value = past_key_value,\r\n                                      cache_position = cache_position,\r\n                                      block_table = block_table,\r\n                                      page_size = past_key_value.page_size,\r\n                                      q_len_raw = q_len_raw,\r\n                                      kv_len_raw = kv_len_raw,\r\n                                      stream = stream)\r\n    \r\n    @allredeuce_warpper\r\n    def forward_paged(\r\n        self,\r\n        q_pe: torch.Tensor,\r\n        q_nope: torch.Tensor,\r\n        compressed_kv_with_k_pe: torch.Tensor,\r\n        past_key_value: Optional[StaticCache] = None,\r\n        cache_position: Optional[torch.LongTensor] = None,\r\n        block_table: Optional[torch.Tensor] = None,\r\n        page_size: Optional[int] = None,\r\n        q_len_raw: Optional[torch.Tensor] = None,\r\n        kv_len_raw: Optional[torch.Tensor] = None,\r\n        stream: Optional[torch.npu.Stream] = None,\r\n        **kwargs\r\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\r\n        # if self.layer_idx == 0:\r\n        #     print(self.page_kv_wrapper.get_block_table(self.layer_idx), self.page_kv_wrapper.position)\r\n        bsz, _, q_len, _ = q_nope.size()\r\n        # print(f\"{q_nope.size()=}\")\r\n        q_nope = torch.einsum('b h q d, h d k -> b h q k', q_nope, self.q_absorb)   # torch.size([1, 128, 1, 512])\r\n        compressed_kv = compressed_kv_with_k_pe.permute(0,2,1,3)\r\n        kvCache = compressed_kv[:,:,:,:self.kv_lora_rank].contiguous()\r\n        kRopeCache = compressed_kv[:,:,:,self.kv_lora_rank:].contiguous()\r\n        if get_use_npu_graph():\r\n            from ktransformers.server.balance_serve.inference.model_runner import ModelRunner, 
get_or_create_model_runner\r\n            npu_graph_runner = get_or_create_model_runner(device=get_current_device())\r\n            npu_graph_idx = bsz - 1\r\n            if npu_graph_runner.workspace[npu_graph_idx] is None:\r\n                workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(\r\n                    q_nope,\r\n                    kvCache,\r\n                    kvCache,\r\n                    query_rope=q_pe,\r\n                    key_rope=kRopeCache,\r\n                    num_heads=self.num_heads,\r\n                    num_key_value_heads=1,\r\n                    input_layout=\"BNSD\",\r\n                    scale=self.softmax_scale,\r\n                    antiquant_mode=0,\r\n                    antiquant_scale=None,\r\n                    block_table=block_table,\r\n                    block_size=page_size,\r\n                    actual_seq_lengths_kv=kv_len_raw,\r\n                    sparse_mode = self.sparse_mode)\r\n                npu_graph_runner.workspace[npu_graph_idx] = workspace\r\n            \r\n            attn_output = torch.zeros_like(q_nope, dtype=torch.float16, device=get_current_device())\r\n            softmax_lse = torch.empty(1, dtype=torch.float16, device=get_current_device())\r\n\r\n            torch_npu.npu_fused_infer_attention_score.out(\r\n                q_nope,\r\n                kvCache,\r\n                kvCache,\r\n                workspace=npu_graph_runner.workspace[npu_graph_idx],\r\n                query_rope = q_pe,\r\n                key_rope = kRopeCache,\r\n                num_heads = self.num_heads,\r\n                num_key_value_heads = 1,\r\n                input_layout = \"BNSD\",\r\n                scale = self.softmax_scale,\r\n                antiquant_mode = 0,\r\n                antiquant_scale = None,\r\n                block_table = block_table,\r\n                block_size = page_size,\r\n                actual_seq_lengths_kv = kv_len_raw,\r\n                
sparse_mode = self.sparse_mode,\r\n                out=[attn_output, softmax_lse])\r\n        else:\r\n            tp_group = get_tensor_parallel_group()\r\n            torch.distributed.barrier(tp_group)\r\n            attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(\r\n                q_nope,\r\n                kvCache,\r\n                kvCache,\r\n                query_rope = q_pe,\r\n                key_rope = kRopeCache,\r\n                num_heads = self.num_heads,\r\n                num_key_value_heads = 1,\r\n                input_layout = \"BNSD\",\r\n                scale = self.softmax_scale,\r\n                antiquant_mode = 0,\r\n                antiquant_scale = None,\r\n                block_table = block_table,\r\n                block_size = page_size,\r\n                actual_seq_lengths_kv = kv_len_raw,\r\n                sparse_mode = self.sparse_mode\r\n            )\r\n\r\n        attn_output = torch.einsum('b h q k, h k v -> b q h v', attn_output, self.out_absorb)\r\n        attn_output = attn_output.contiguous().view(bsz, q_len, self.num_heads*self.v_head_dim)\r\n        attn_output = self.elewise_quant.execute([attn_output, self.o_proj.input_scale, self.o_proj.input_offset])[0]\r\n        attn_output = self.matmulDequant_operation_aclnn.execute([attn_output, self.o_proj.weight,\r\n                                                                  self.o_proj.quant_bias, self.o_proj.deq_scale])[0]\r\n        return attn_output, None, past_key_value\r\n\r\ndef rotate_half(x):\r\n    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]\r\n    return torch.cat((-x2, x1), dim=-1)\r\n\r\n\r\ndef apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):\r\n    cos = cos.unsqueeze(unsqueeze_dim)\r\n    sin = sin.unsqueeze(unsqueeze_dim)\r\n    q_embed = (q * cos) + (rotate_half(q) * sin)\r\n    k_embed = (k * cos) + (rotate_half(k) * sin)\r\n    return q_embed, k_embed\r\n\r\n\r\nclass 
KQwen3MoeAttentionW8A8A2Serve(BaseInjectedModule, Qwen3MoeAttention):\r\n\r\n    attn_mask: Optional[torch.Tensor] = None\r\n\r\n    def __init__(self,\r\n                 key: str,\r\n                 gguf_loader: GGUFLoader,\r\n                 config: PretrainedConfig,\r\n                 orig_module: nn.Module,\r\n                 prefill_device: str = \"npu\",\r\n                 generate_device: str = \"npu\",\r\n                 chunck_size: int = 1024,\r\n                 absorb_for_prefill: bool = False,\r\n                 **kwargs):\r\n\r\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module,\r\n                                    prefill_device, generate_device, **kwargs)\r\n\r\n        self.orig_module.__init__(orig_module.config, orig_module.layer_idx)\r\n\r\n        self.absorb_for_prefill = absorb_for_prefill\r\n\r\n        self.elewise_quant = DynamicQuantOps()\r\n        self.matmulDequant_operation = MatMulOps()\r\n        self.matmulDequant_operation_aclnn = MatMulOps()\r\n\r\n        self.softmax_scale = self.scaling\r\n        self.sparse_mode = 0\r\n\r\n        self._prefill_step = 0\r\n        self._cur_prefill_dir: Optional[str] = None\r\n\r\n        if hasattr(self, \"rotary_emb\"):\r\n            if hasattr(self.rotary_emb, \"cos_cached\"):\r\n                self.rotary_emb.cos_cached = self.rotary_emb.cos_cached.to(torch.float16)\r\n                self.rotary_emb.sin_cached = self.rotary_emb.sin_cached.to(torch.float16)\r\n            if hasattr(self.rotary_emb, \"inv_freq\"):\r\n                self.rotary_emb.inv_freq = self.rotary_emb.inv_freq.to(torch.float16)\r\n\r\n    def _linear_w8a8a2(self, x: torch.Tensor, proj: nn.Module, name: str) -> torch.Tensor:\r\n        if x.dtype == torch.bfloat16:\r\n            x = x.to(torch.float16)\r\n        B, Q, H_in = x.shape\r\n        x_2d = x.view(-1, H_in)   # [T, H_in], T = B * Q\r\n        x_q = self.elewise_quant.execute([\r\n            x_2d,\r\n            
proj.input_scale,\r\n            proj.input_offset\r\n        ])[0]\r\n        y_2d = self.matmulDequant_operation.execute([\r\n            x_q,\r\n            proj.weight,\r\n            proj.quant_bias,\r\n            proj.deq_scale\r\n        ])[0]\r\n        return y_2d.view(B, Q, -1)\r\n    # -------------------------------------------------------\r\n    # forward\r\n    # -------------------------------------------------------\r\n    def forward(self,\r\n                hidden_states: torch.Tensor,\r\n                attention_mask=None,\r\n                position_ids=None,\r\n                past_key_value=None,\r\n                output_attentions=False,\r\n                use_cache=False,\r\n                cache_position=None,\r\n                is_prefill=None,\r\n                page_idx=None,\r\n                page_offset=None,\r\n                block_table=None,\r\n                q_len_raw=None,\r\n                kv_len_raw=None,\r\n                stream=None,\r\n                **kwargs):\r\n\r\n        if hidden_states.dim() == 2:\r\n            hidden_states = hidden_states.unsqueeze(0)\r\n        bsz, q_len, hidden = hidden_states.shape\r\n\r\n        # -------- QKV --------\r\n        q_proj_out = self._linear_w8a8a2(hidden_states, self.q_proj, \"Q\")\r\n        B, S, _ = q_proj_out.shape\r\n        q = q_proj_out.view(B, S, self.num_heads, self.head_dim)  # [B, S, H, Dh]\r\n        q = self.q_norm(q)\r\n        q_in = q.view(B, S, -1)\r\n\r\n        k_proj_out = self._linear_w8a8a2(hidden_states, self.k_proj, \"K\")\r\n        k = k_proj_out.view(B, S, self.num_key_value_heads, self.head_dim)\r\n        k = self.k_norm(k)\r\n        k_in = k.view(B, S, -1)\r\n\r\n        v_in = self._linear_w8a8a2(hidden_states, self.v_proj, \"V\")\r\n\r\n        q = q_in.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\r\n        k = k_in.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\r\n        v = v_in.view(bsz, 
q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\r\n\r\n        # -------- RoPE --------\r\n        cos, sin = self.rotary_emb(v, position_ids)\r\n        q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1)\r\n\r\n        # -------- prefill / decode --------\r\n        if is_prefill:\r\n            out = self._forward_prefill(\r\n                q, k, v,\r\n                attention_mask=attention_mask,\r\n                position_ids=position_ids,\r\n                past_key_value=past_key_value,\r\n                q_len_raw=q_len_raw,\r\n                kv_len_raw=kv_len_raw,\r\n                page_idx=page_idx,\r\n                page_offset=page_offset,\r\n                block_table=block_table,\r\n            )\r\n            return out\r\n        else:\r\n            return self.forward_paged(\r\n                q=q, k=k, v=v,\r\n                past_key_value=past_key_value,\r\n                cache_position=cache_position,\r\n                block_table=block_table,\r\n                page_size=getattr(past_key_value, \"page_size\", None),\r\n                q_len_raw=q_len_raw,\r\n                kv_len_raw=kv_len_raw,\r\n                page_idx=page_idx,\r\n                page_offset=page_offset,\r\n                stream=stream\r\n            )\r\n\r\n    # -------------------------------------------------------\r\n    # Prefill\r\n    # -------------------------------------------------------\r\n    def _forward_prefill(\r\n        self,\r\n        q: torch.Tensor,   # [B, H, Q, Dh]\r\n        k: torch.Tensor,   # [B, KvH, Q, Dh]\r\n        v: torch.Tensor,   # [B, KvH, Q, Dh]\r\n        attention_mask=None,\r\n        position_ids=None,\r\n        past_key_value=None,\r\n        q_len_raw=None,\r\n        kv_len_raw=None,\r\n        page_idx=None,\r\n        page_offset=None,\r\n        block_table=None,\r\n        **kwargs,\r\n    ) -> torch.Tensor:\r\n\r\n        B, H, Q, Dh = q.shape\r\n        KvH = k.shape[1]\r\n\r\n     
   # ---------- 1) Write KV cache ----------\r\n        if (\r\n            past_key_value is not None\r\n            and page_idx is not None\r\n            and page_offset is not None\r\n        ):\r\n            try:\r\n                past_key_value.update(\r\n                    key_states=k,\r\n                    value_states=v,\r\n                    layer_idx=self.layer_idx,\r\n                    cache_kwargs={\r\n                        \"page_idx\": page_idx,\r\n                        \"page_offset\": page_offset,\r\n                    },\r\n                )\r\n            except Exception as e:\r\n                print(f\"[PREFILL-QWEN3][WARN] KV cache update failed: {e}\", flush=True)\r\n\r\n        # ---------- 2) GQA: 4 KV → 32 Q heads ----------\r\n        if KvH != self.num_key_value_heads:\r\n            print(\r\n                f\"[PREFILL-QWEN3][WARN] KvH ({KvH}) != config.num_key_value_heads \"\r\n                f\"({self.num_key_value_heads}), using k.shape[1] as KvH\",\r\n                flush=True,\r\n            )\r\n            KvH = k.shape[1]\r\n\r\n        if H % KvH != 0:\r\n            raise ValueError(\r\n                f\"[PREFILL-QWEN3] num_heads={H} is not an integer multiple of num_kv_heads={KvH}\"\r\n            )\r\n\r\n        group_size = H // KvH\r\n        k_full = k.repeat_interleave(group_size, dim=1)\r\n        v_full = v.repeat_interleave(group_size, dim=1)\r\n\r\n        print(\"[PREFILL-QWEN3] k_full/v_full:\", k_full.shape, v_full.shape, flush=True)\r\n\r\n        # ---------- 3) BSND + causal mask ----------\r\n        q_bsnd = q.permute(0, 2, 1, 3).contiguous()      # [B, Q, H, Dh]\r\n        k_bsnd = k_full.permute(0, 2, 1, 3).contiguous()\r\n        v_bsnd = v_full.permute(0, 2, 1, 3).contiguous()\r\n\r\n        if q_len_raw is None:\r\n            seq_len_data = [Q for _ in range(B)]\r\n            kv_len_list = [Q for _ in range(B)]\r\n        else:\r\n            seq_len_data = []\r\n            kv_len_list = []\r\n            for 
b_idx in range(B):\r\n                cur_q = int(q_len_raw[b_idx].item())\r\n                if kv_len_raw is not None:\r\n                    cur_kv = int(kv_len_raw[b_idx].item())\r\n                else:\r\n                    cur_kv = cur_q\r\n                cur_q = max(1, cur_q)\r\n                cur_kv = max(1, cur_kv)\r\n                seq_len_data.append(cur_q)\r\n                kv_len_list.append(cur_kv)\r\n\r\n        def create_causal_mask(q_lens, kv_lens):\r\n            q_lens_t = torch.tensor(q_lens, device=q_bsnd.device)\r\n            kv_lens_t = torch.tensor(kv_lens, device=q_bsnd.device)\r\n            bsz = q_lens_t.size(0)\r\n            max_q = int(q_lens_t.max().item())\r\n            max_kv = int(kv_lens_t.max().item())\r\n            base_causal = torch.tril(\r\n                torch.ones((max_q, max_kv), dtype=torch.bool, device=q_bsnd.device)\r\n            )\r\n            mask = torch.zeros(\r\n                (bsz, max_q, max_kv), dtype=torch.bool, device=q_bsnd.device\r\n            )\r\n            for i in range(bsz):\r\n                ql = int(q_lens_t[i].item())\r\n                kl = int(kv_lens_t[i].item())\r\n                mask[i, :ql, :kl] = base_causal[:ql, :kl]\r\n            return mask\r\n\r\n        max_q_len = max(seq_len_data) if len(seq_len_data) > 0 else Q\r\n        max_kv_len = max(kv_len_list) if len(kv_len_list) > 0 else Q\r\n\r\n        q_list, k_list, v_list = [], [], []\r\n        for b_idx in range(B):\r\n            cur_q = seq_len_data[b_idx]\r\n            cur_kv = kv_len_list[b_idx]\r\n\r\n            q_sample = q_bsnd[b_idx, :cur_q, :, :].contiguous()\r\n            k_sample = k_bsnd[b_idx, :cur_kv, :, :].contiguous()\r\n            v_sample = v_bsnd[b_idx, :cur_kv, :, :].contiguous()\r\n\r\n            q_list.append(q_sample)\r\n            k_list.append(k_sample)\r\n            v_list.append(v_sample)\r\n\r\n        qTensor = torch.nn.utils.rnn.pad_sequence(\r\n            q_list, 
batch_first=True, padding_value=0.0\r\n        ).contiguous()\r\n        kTensor = torch.nn.utils.rnn.pad_sequence(\r\n            k_list, batch_first=True, padding_value=0.0\r\n        ).contiguous()\r\n        vTensor = torch.nn.utils.rnn.pad_sequence(\r\n            v_list, batch_first=True, padding_value=0.0\r\n        ).contiguous()\r\n\r\n        causal_mask = create_causal_mask(seq_len_data, kv_len_list)\r\n        atten_mask = (~causal_mask).to(torch.int8)\r\n\r\n        print(\"[PREFILL-QWEN3] qTensor/kTensor/vTensor:\",\r\n              qTensor.shape, kTensor.shape, vTensor.shape, flush=True)\r\n\r\n        # ---------- 4) NPU fused attention ----------\r\n        infer_attention_output, _ = torch_npu.npu_fused_infer_attention_score(\r\n            qTensor, kTensor, vTensor,\r\n            atten_mask=atten_mask,\r\n            actual_seq_lengths=seq_len_data,\r\n            scale=self.softmax_scale,\r\n            num_heads=H,\r\n            num_key_value_heads=H,\r\n            input_layout=\"BSND\",\r\n        )\r\n\r\n        attn_output = infer_attention_output\r\n\r\n        # ---------- 5) reshape + W8A8 o_proj ----------\r\n        attn_output = attn_output.contiguous().view(B, max_q_len, H * Dh)\r\n\r\n        attn_output_q = self.elewise_quant.execute(\r\n            [attn_output, self.o_proj.input_scale, self.o_proj.input_offset]\r\n        )[0]\r\n\r\n        attn_output = self.matmulDequant_operation_aclnn.execute(\r\n            [attn_output_q, self.o_proj.weight, self.o_proj.quant_bias, self.o_proj.deq_scale]\r\n        )[0]\r\n\r\n        print(\"[PREFILL-QWEN3] attn_output(after o_proj):\", attn_output.shape, attn_output.dtype, flush=True)\r\n\r\n        return attn_output\r\n\r\n    def forward_paged(\r\n        self,\r\n        q: torch.Tensor,\r\n        k: torch.Tensor,\r\n        v: torch.Tensor,\r\n        past_key_value,\r\n        cache_position,\r\n        block_table,\r\n        page_size,\r\n        q_len_raw,\r\n        
kv_len_raw,\r\n        page_idx,\r\n        page_offset,\r\n        stream,\r\n        **kwargs,\r\n    ):\r\n        B, H, Q, Dh = q.shape\r\n        KvH = k.shape[1]\r\n\r\n        # ========= 1) Update KV cache =========\r\n        past_key_value.update(\r\n            key_states=k,\r\n            value_states=v,\r\n            layer_idx=self.layer_idx,\r\n            cache_kwargs={\r\n                \"page_idx\": page_idx,\r\n                \"page_offset\": page_offset,\r\n            },\r\n        )\r\n\r\n        Kcache = past_key_value.get_k_cache(self.layer_idx)\r\n        Vcache = past_key_value.get_v_cache(self.layer_idx)\r\n        \r\n        q_bnsd = q.contiguous()\r\n        k_bnsd = Kcache.contiguous().to(torch.float16).transpose(1, 2)\r\n        v_bnsd = Vcache.contiguous().to(torch.float16).transpose(1, 2)\r\n\r\n        use_graph = get_use_npu_graph()\r\n        device = get_current_device()\r\n\r\n        if use_graph:\r\n            from ktransformers.server.balance_serve.inference.model_runner import get_or_create_model_runner\r\n            npu_graph_runner = get_or_create_model_runner(device=device)\r\n            npu_graph_idx = B - 1\r\n\r\n            if npu_graph_runner.workspace[npu_graph_idx] is None:\r\n                workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(\r\n                    q_bnsd,\r\n                    k_bnsd,\r\n                    v_bnsd,\r\n                    num_heads=H,\r\n                    num_key_value_heads=KvH,\r\n                    input_layout=\"BNSD\",\r\n                    scale=self.softmax_scale,\r\n                    antiquant_mode=0,\r\n                    antiquant_scale=None,\r\n                    block_table=block_table,\r\n                    block_size=page_size,\r\n                    actual_seq_lengths_kv=kv_len_raw,\r\n                    sparse_mode=self.sparse_mode,\r\n                )\r\n                npu_graph_runner.workspace[npu_graph_idx] = workspace\r\n\r\n            attn_output = torch.zeros_like(q_bnsd, 
dtype=torch.float16, device=device)\r\n            softmax_lse = torch.empty(1, dtype=torch.float16, device=device)\r\n            torch_npu.npu_fused_infer_attention_score.out(\r\n                q_bnsd,\r\n                k_bnsd,\r\n                v_bnsd,\r\n                workspace=npu_graph_runner.workspace[npu_graph_idx],\r\n                num_heads=H,\r\n                num_key_value_heads=KvH,\r\n                input_layout=\"BNSD\",\r\n                scale=self.softmax_scale,\r\n                antiquant_mode=0,\r\n                antiquant_scale=None,\r\n                block_table=block_table,\r\n                block_size=page_size,\r\n                actual_seq_lengths_kv=kv_len_raw,\r\n                sparse_mode=self.sparse_mode,\r\n                out=[attn_output, softmax_lse]\r\n            )\r\n        else:\r\n            attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(\r\n                q_bnsd,\r\n                k_bnsd,\r\n                v_bnsd,\r\n                num_heads=H,\r\n                num_key_value_heads=KvH,\r\n                input_layout=\"BNSD\",\r\n                scale=self.softmax_scale,\r\n                antiquant_mode=0,\r\n                antiquant_scale=None,\r\n                block_table=block_table,\r\n                block_size=page_size,\r\n                actual_seq_lengths_kv=kv_len_raw,\r\n                sparse_mode=self.sparse_mode,\r\n            )\r\n\r\n        attn_output = attn_output.permute(0, 2, 1, 3).contiguous().view(B, -1, H * Dh)\r\n\r\n        attn_output_q = self.elewise_quant.execute(\r\n            [attn_output, self.o_proj.input_scale, self.o_proj.input_offset]\r\n        )[0]\r\n\r\n        attn_output = self.matmulDequant_operation_aclnn.execute(\r\n            [attn_output_q, self.o_proj.weight, self.o_proj.quant_bias, self.o_proj.deq_scale]\r\n        )[0]\r\n\r\n        return attn_output"
  },
  {
    "path": "archive/ktransformers/operators/ascend/ascend_experts.py",
    "content": "# coding=utf-8\r\n# Copyright (c) 2025. Huawei Technologies Co., Ltd. All rights reserved.\r\n# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\nimport re\r\nimport os\r\nfrom typing import Optional\r\n\r\nimport bisect\r\nimport torch\r\nimport numpy as np\r\nfrom torch import nn\r\nimport torch_npu\r\nfrom transformers import PretrainedConfig\r\nimport torch.nn.functional as F\r\n\r\nfrom ktransformers.util.custom_loader import GGUFLoader\r\nfrom ktransformers.util.ascend.ascend_utils import get_tensor_parallel_size, get_tensor_parallel_group\r\nfrom ktransformers.operators.experts import cuda_graphs, KExpertsBase, KExpertsCPU, KTransformersExperts, EXPERTS_MAP, KDeepseekV3MoE\r\nfrom ktransformers.models.modeling_deepseek_v3 import DeepseekV3MoE\r\nfrom ktransformers.operators.base_operator import BaseInjectedModule\r\nfrom ktransformers.util.utils import CUR_DEVICE, get_use_npu_graph, InferenceState\r\nfrom ktransformers.operators.experts import cuda_graphs as npu_graphs\r\nfrom ktransformers.util import utils\r\n\r\nclass KExpertsCPUW8A8(KExpertsCPU):\r\n\r\n    def forward(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=None, use_npu_graph=False):\r\n        if use_npu_graph:\r\n            seq_len = input_tensor.size(0)\r\n            cuda_graph_idx = seq_len - 1 if cuda_graph_idx is None else 
cuda_graph_idx # input_tensor is seq & batch merged\r\n            self.cpu_infer.submit(self.moe.forward(KExpertsCPU.expert_ids_cpu[cuda_graph_idx][0].size(0),\r\n                                                    KExpertsCPU.expert_ids_cpu[cuda_graph_idx][0].size(1),\r\n                                                    KExpertsCPU.expert_ids_cpu[cuda_graph_idx][0].data_ptr(),\r\n                                                    KExpertsCPU.weights_cpu[cuda_graph_idx][0].data_ptr(),\r\n                                                    KExpertsCPU.input_tensor_cpu[cuda_graph_idx][0].data_ptr(),\r\n                                                    KExpertsCPU.output_cpu[cuda_graph_idx][0].data_ptr(),\r\n                                                    KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx][0].data_ptr()\r\n                                                    ))\r\n            self.cpu_infer.sync()\r\n        else:\r\n            if bsz_tensor is None:\r\n                bsz_tensor = torch.tensor([input_tensor.size(0)], device=input_tensor.device, dtype=torch.int32)\r\n            # if torch.cuda.is_current_stream_capturing():\r\n            org_type = input_tensor.dtype\r\n            input_tensor = input_tensor.contiguous().cpu()\r\n            input_tensor = input_tensor.to(torch.bfloat16)\r\n            expert_ids = expert_ids.contiguous().cpu()\r\n            weights = weights.contiguous().to(torch.float32).cpu()\r\n            bsz_tensor = bsz_tensor.contiguous().cpu()\r\n            output = torch.empty_like(input_tensor).contiguous()\r\n            self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), expert_ids.data_ptr(), weights.data_ptr(), input_tensor.data_ptr(), output.data_ptr(), bsz_tensor.data_ptr()))\r\n            self.cpu_infer.sync()\r\n            return output.to(org_type).to(device=utils.get_current_device())\r\n\r\nEXPERTS_MAP[\"KExpertsCPUW8A8\"] = KExpertsCPUW8A8\r\n\r\nclass 
KTransformersExpertsW8A8(KTransformersExperts):\r\n    def forward(self, input_tensor, expert_ids, weights, cuda_graph_idx=None, use_npu_graph=False):\r\n        if self.mode == InferenceState.GENERATE:\r\n            assert self.generate_experts is not None, \"generate_experts is None\"\r\n            return self.generate_experts.forward(input_tensor, expert_ids, weights, cuda_graph_idx=cuda_graph_idx, use_npu_graph=use_npu_graph)\r\n        elif self.mode == InferenceState.PREFILL:\r\n            assert self.prefill_experts is not None, \"prefill_experts is None\"\r\n            return self.prefill_experts.forward(input_tensor, expert_ids, weights, cuda_graph_idx=cuda_graph_idx, use_npu_graph=use_npu_graph)\r\n        else:\r\n            raise ValueError(\"load or set_inference_mode before forward\")\r\n\r\n\r\nclass KDeepseekV3MoEW8A8(KDeepseekV3MoE):\r\n    def forward(self, hidden_states, stream=None, para_stream=None):\r\n        tp_size = get_tensor_parallel_size()\r\n        world_size = torch.distributed.get_world_size()\r\n        rank = torch.distributed.get_rank()\r\n        identity = hidden_states\r\n        orig_shape = hidden_states.shape\r\n\r\n        def share_experts_forward():\r\n            if self.config.n_shared_experts is not None:\r\n                return self.shared_experts(identity).squeeze(0)\r\n\r\n        if rank == 0:\r\n            topk_idx, topk_weight = self.gate(hidden_states)\r\n            hidden_states = hidden_states.view(-1, hidden_states.shape[-1])\r\n            if get_use_npu_graph():\r\n                org_type = hidden_states.dtype\r\n                if hasattr(self.config, \"backend_type\"):\r\n                    if self.config.backend_type == \"ktransformers\":\r\n                        from ktransformers.util.npu_graph_runner import get_or_create_runner\r\n                        npu_graph_runner = get_or_create_runner(utils.get_current_device())\r\n                        stream = 
npu_graph_runner.main_stream\r\n                        para_stream = npu_graph_runner.share_experts_stream\r\n                    event = torch.npu.Event()\r\n                    event.record(stream)\r\n                    with torch.npu.stream(para_stream):\r\n                        event.wait(para_stream)\r\n                        y_ = share_experts_forward() if share_experts_forward is not None else None\r\n                        event.record(para_stream)\r\n            \r\n                    input_tensor = hidden_states.to(torch.bfloat16)\r\n                    topk_weight = topk_weight.contiguous().to(torch.float32)\r\n                    cuda_graph_idx = orig_shape[0] - 1\r\n                    self.moe_kexperts_param = (hidden_states, topk_idx, topk_weight, cuda_graph_idx, True)\r\n                    if cuda_graph_idx < len(npu_graphs):\r\n                        expert_ids = topk_idx\r\n                        KExpertsCPU.input_tensor_cpu[cuda_graph_idx][0].copy_(input_tensor, non_blocking = True)\r\n                        KExpertsCPU.expert_ids_cpu[cuda_graph_idx][0].copy_(expert_ids, non_blocking = True)\r\n                        KExpertsCPU.weights_cpu[cuda_graph_idx][0].copy_(topk_weight, non_blocking = True)\r\n                        torch_npu.npu._launch_host_func(stream, self.cpu_moe_kexperts, self.moe_kexperts_param)\r\n\r\n                        y = self.experts.generate_experts.output_cpu[cuda_graph_idx][0].to(utils.get_current_device(), non_blocking = True)\r\n                        y = y.view(*orig_shape).to(device=hidden_states.device)\r\n                        y = y.to(org_type)\r\n                    event.wait(stream)\r\n                else:\r\n                    from ktransformers.util.npu_graph_runner import get_or_create_runner\r\n                    npu_graph_runner = get_or_create_runner(utils.get_current_device())\r\n                    event = torch.npu.Event()\r\n                    
event.record(npu_graph_runner.main_stream)\r\n                    with torch.npu.stream(npu_graph_runner.share_experts_stream):\r\n                        event.wait(npu_graph_runner.share_experts_stream)\r\n                        y_ = share_experts_forward() if share_experts_forward is not None else None\r\n                        event.record(npu_graph_runner.share_experts_stream)\r\n                    topk_weight = topk_weight.contiguous().to(torch.float32)\r\n                    self.moe_kexperts_param = (hidden_states, topk_idx, topk_weight, None, True)\r\n\r\n                    org_type = hidden_states.dtype\r\n                    input_tensor = hidden_states.to(torch.bfloat16)\r\n\r\n                    cuda_graph_idx = bisect.bisect_left(npu_graphs, 1)\r\n                    if cuda_graph_idx < len(npu_graphs):\r\n\r\n                        immediate_expert_ids = topk_idx\r\n                        KExpertsCPU.input_tensor_cpu[cuda_graph_idx][0].copy_(input_tensor, non_blocking = True)\r\n                        KExpertsCPU.expert_ids_cpu[cuda_graph_idx][0].copy_(immediate_expert_ids, non_blocking = True)\r\n                        KExpertsCPU.weights_cpu[cuda_graph_idx][0].copy_(topk_weight, non_blocking = True)\r\n\r\n                        npu_graph_runner.launch_callback(\r\n                            self.cpu_moe_kexperts,\r\n                            self.moe_kexperts_param,\r\n                            1, npu_graph_runner.main_stream)\r\n                        y = self.experts.generate_experts.output_cpu[cuda_graph_idx][0].to(utils.get_current_device(), non_blocking = True)\r\n\r\n                        y = y.to(org_type)\r\n                        y = y.view(*orig_shape).to(device=hidden_states.device)\r\n                    event.wait(npu_graph_runner.main_stream)\r\n            else:\r\n                y = self.moe_kexperts(hidden_states, topk_idx, topk_weight)\r\n                y_ = share_experts_forward() if share_experts_forward is 
not None else None\r\n                y = y.view(*orig_shape).to(device=hidden_states.device)\r\n                y_ = y_.view(*orig_shape)\r\n        else:\r\n            hidden_states = hidden_states.view(-1, hidden_states.shape[-1])\r\n            y = torch.zeros(orig_shape, dtype=torch.float16, device=CUR_DEVICE)\r\n            y_ = share_experts_forward() if share_experts_forward is not None else None\r\n\r\n        if tp_size > 1 and world_size == tp_size:\r\n            torch.distributed.all_reduce(y, op=torch.distributed.ReduceOp.SUM, group=get_tensor_parallel_group())\r\n        if self.config.n_shared_experts is not None:\r\n            y += y_\r\n        return y\r\n\r\n    @torch.no_grad()\r\n    def cpu_moe_kexperts(self, moe_kexperts_param) -> torch.Tensor:\r\n        x, topk_ids, topk_weight, cuda_graph_idx, use_npu_graph = moe_kexperts_param\r\n        _ = self.experts(x, topk_ids, topk_weight, cuda_graph_idx=cuda_graph_idx, use_npu_graph=use_npu_graph)\r\n\r\n    @torch.no_grad()\r\n    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:\r\n        outs = self.experts(x, topk_ids, topk_weight)\r\n        return outs\r\n\r\nclass KQwen3MoeSparseMoeBlockW8A8(BaseInjectedModule):\r\n    def __init__(\r\n        self,\r\n        key: str,\r\n        gguf_loader: GGUFLoader,\r\n        config: PretrainedConfig,\r\n        orig_module: nn.Module,\r\n        prefill_device: str = \"npu\",\r\n        generate_device: str = \"npu\",\r\n        **kwargs,\r\n    ):\r\n        super().__init__(\r\n            key,\r\n            gguf_loader,\r\n            config,\r\n            orig_module,\r\n            prefill_device=prefill_device,\r\n            generate_device=generate_device,\r\n            **kwargs,\r\n        )\r\n\r\n        self.gate = orig_module.gate\r\n        self.top_k = orig_module.top_k\r\n        self.norm_topk_prob = orig_module.norm_topk_prob\r\n        self.output_router_logits = 
getattr(orig_module, \"output_router_logits\", False)\r\n\r\n        experts_key = f\"{key}.experts\"\r\n\r\n        print(f\"[NPU-MOE][INIT] build experts at key={experts_key}\", flush=True)\r\n        self.experts = KTransformersExpertsW8A8(\r\n            key=experts_key,\r\n            gguf_loader=gguf_loader,\r\n            config=config,\r\n            orig_module=orig_module.experts,\r\n            prefill_device=prefill_device,\r\n            prefill_op=\"KExpertsTorch\",\r\n            generate_device=\"cpu\",\r\n            generate_op=\"KExpertsCPUW8A8\",\r\n            out_device=prefill_device,\r\n        )\r\n\r\n    def set_inference_mode(self, mode: InferenceState):\r\n        if isinstance(self.experts, KExpertsBase):\r\n            self.experts.set_inference_mode(mode)\r\n\r\n    @torch.no_grad()\r\n    def cpu_moe_kexperts(self, moe_kexperts_param):\r\n        x, topk_ids, topk_weight, cuda_graph_idx, use_npu_graph = moe_kexperts_param\r\n        _ = self.experts(\r\n            x,\r\n            topk_ids,\r\n            topk_weight,\r\n            cuda_graph_idx=cuda_graph_idx,\r\n            use_npu_graph=use_npu_graph,\r\n        )\r\n\r\n    @torch.no_grad()\r\n    def moe_kexperts(\r\n        self,\r\n        x: torch.Tensor,\r\n        topk_ids: torch.Tensor,\r\n        topk_weight: torch.Tensor,\r\n        bsz_tensor: torch.Tensor = None,\r\n        cuda_graph_idx: int = 0,\r\n        use_npu_graph: bool = False,\r\n    ) -> torch.Tensor:\r\n        outs = self.experts(\r\n            x,\r\n            topk_ids,\r\n            topk_weight,\r\n            cuda_graph_idx=cuda_graph_idx,\r\n            use_npu_graph=use_npu_graph,\r\n        )\r\n        return outs\r\n\r\n    def forward(\r\n        self,\r\n        hidden_states: torch.Tensor,\r\n        bsz_tensor: torch.Tensor = None,\r\n        cuda_graph_idx: int = 0,\r\n        *args,\r\n        **kwargs,\r\n    ):\r\n\r\n        if hidden_states.dim() == 3:\r\n            B, S, H = 
hidden_states.shape\r\n        else:\r\n            orig_shape = hidden_states.shape\r\n            hidden_states = hidden_states.view(1, -1, orig_shape[-1])\r\n            B, S, H = hidden_states.shape\r\n\r\n        orig_device = hidden_states.device\r\n        orig_shape = (B, S, H)\r\n\r\n        output_router_logits_flag = kwargs.pop(\"output_router_logits\", False)\r\n        need_router_logits = output_router_logits_flag or self.output_router_logits\r\n\r\n        # ===== 1) flatten =====\r\n        hidden_states_flat = hidden_states.view(-1, H)\r\n        T = hidden_states_flat.shape[0]\r\n\r\n        # ===== 2) gate =====\r\n        router_logits = self.gate(hidden_states_flat)\r\n        try:\r\n            router_logits_bs = router_logits.view(B, S, -1)\r\n        except Exception:\r\n            router_logits_bs = router_logits\r\n        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)\r\n        routing_weights, selected_experts = torch.topk(\r\n            routing_weights, self.top_k, dim=-1\r\n        )\r\n        if self.norm_topk_prob:\r\n            rw_sum = routing_weights.sum(dim=-1, keepdim=True)\r\n            routing_weights = routing_weights / rw_sum\r\n\r\n        routing_weights = routing_weights.to(hidden_states_flat.dtype)\r\n\r\n        # ===== 3) MoE experts =====\r\n        use_npu_graph = get_use_npu_graph()\r\n        if torch.distributed.is_available() and torch.distributed.is_initialized():\r\n            tp_size = get_tensor_parallel_size()\r\n            world_size = torch.distributed.get_world_size()\r\n            rank = torch.distributed.get_rank()\r\n        else:\r\n            tp_size = 1\r\n            world_size = 1\r\n            rank = 0\r\n        y = None\r\n        if isinstance(self.experts, KExpertsBase):\r\n            if getattr(self.experts, \"mode\", None) == InferenceState.UNLOAD:\r\n                self.experts.set_inference_mode(InferenceState.GENERATE)\r\n\r\n            if rank == 
0:\r\n                if use_npu_graph:\r\n                    org_type = hidden_states_flat.dtype\r\n                    input_tensor = hidden_states_flat.to(torch.bfloat16)\r\n                    topk_weight_f32 = routing_weights.contiguous().to(torch.float32)\r\n                    self.moe_kexperts_param = (\r\n                        hidden_states_flat,\r\n                        selected_experts,\r\n                        topk_weight_f32,\r\n                        cuda_graph_idx,\r\n                        True,\r\n                    )\r\n                    if cuda_graph_idx < len(npu_graphs):\r\n                        KExpertsCPU.input_tensor_cpu[cuda_graph_idx][0].copy_(input_tensor, non_blocking=True)\r\n                        KExpertsCPU.expert_ids_cpu[cuda_graph_idx][0].copy_(selected_experts, non_blocking=True)\r\n                        KExpertsCPU.weights_cpu[cuda_graph_idx][0].copy_(topk_weight_f32, non_blocking=True)\r\n\r\n                        stream = torch.npu.current_stream()\r\n                        torch_npu.npu._launch_host_func(\r\n                            stream,\r\n                            self.cpu_moe_kexperts,\r\n                            self.moe_kexperts_param,\r\n                        )\r\n\r\n                        y_flat = self.experts.generate_experts.output_cpu[cuda_graph_idx][0].to(\r\n                            utils.get_current_device(),\r\n                            non_blocking=True,\r\n                        )\r\n                        y_flat = y_flat.to(org_type)\r\n                        y = y_flat.view(*orig_shape).to(device=orig_device)\r\n                    else:\r\n                        tmp_bsz_tensor = torch.tensor([B], dtype=torch.int32, device=orig_device)\r\n                        y_flat = self.moe_kexperts(\r\n                            hidden_states_flat,\r\n                            selected_experts,\r\n                            routing_weights,\r\n                            
bsz_tensor=tmp_bsz_tensor,\r\n                            cuda_graph_idx=cuda_graph_idx,\r\n                            use_npu_graph=False,\r\n                        )\r\n                        y = y_flat.view(*orig_shape).to(device=orig_device)\r\n                else:\r\n                    if bsz_tensor is None:\r\n                        bsz_tensor = torch.tensor(\r\n                            [B],\r\n                            dtype=torch.int32,\r\n                            device=orig_device,\r\n                        )\r\n\r\n                    y_flat = self.moe_kexperts(\r\n                        hidden_states_flat,\r\n                        selected_experts,\r\n                        routing_weights,\r\n                        bsz_tensor=bsz_tensor,\r\n                        cuda_graph_idx=cuda_graph_idx,\r\n                        use_npu_graph=False,\r\n                    )\r\n                    y = y_flat.view(*orig_shape).to(device=orig_device)\r\n            else:\r\n                y = torch.zeros(orig_shape, dtype=hidden_states.dtype, device=orig_device)\r\n        else:\r\n            y = hidden_states\r\n\r\n        if tp_size > 1 and world_size == tp_size:\r\n            torch.distributed.all_reduce(y, op=torch.distributed.ReduceOp.SUM, group=get_tensor_parallel_group())\r\n        # print(\"================ [NPU-MOE] EXIT MLP =======================\\n\")\r\n        if need_router_logits:\r\n            num_experts = router_logits.shape[-1]\r\n            router_logits_bs = router_logits.view(B, S, num_experts)\r\n            return y, router_logits_bs\r\n\r\n\r\n        return y\r\n"
  },
  {
    "path": "archive/ktransformers/operators/ascend/ascend_gate.py",
    "content": "import torch\r\nimport torch_npu\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\nfrom ktransformers.operators.gate import KMoEGate\r\n\r\n\r\nclass KDeepseekV3GateA2(KMoEGate):\r\n    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):\r\n        if device is None:\r\n            device = self.device\r\n        if w is None:\r\n            w = self.load_weights(device=device)\r\n\r\n        if isinstance(w, dict):\r\n            self.weight_type = w[\"weight_type\"]\r\n            self.e_score_correction_bias_type = w[\"e_score_correction_bias_type\"]\r\n            self.orig_module.weight = nn.Parameter(w[\"weight\"])\r\n            self.orig_module.e_score_correction_bias = nn.Parameter(w[\"e_score_correction_bias\"])\r\n        else:\r\n            raise ValueError(\"Invalid weight type\")\r\n        self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device).to(torch.float32))\r\n        self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device).to(torch.float32))\r\n\r\n    def forward(self, hidden_states) -> torch.Tensor:\r\n        h = hidden_states.shape[-1]\r\n        # compute gating score\r\n        hidden_states = hidden_states.view(-1, h)\r\n        logits = F.linear(hidden_states.type(torch.float32), self.weight, None)\r\n        topk_weight, topk_idx, _ = torch_npu.npu_moe_gating_top_k(\r\n            logits,\r\n            k=self.top_k,\r\n            bias=self.e_score_correction_bias,\r\n            k_group=self.topk_group,\r\n            group_count=self.n_group,\r\n            group_select_mode=1,\r\n            renorm=0,\r\n            norm_type=1,\r\n            routed_scaling_factor=self.routed_scaling_factor,\r\n            eps=float(1e-20))\r\n        return topk_idx.type(torch.int64), topk_weight\r\n"
  },
  {
    "path": "archive/ktransformers/operators/ascend/ascend_layernorm.py",
    "content": "# coding=utf-8\r\n# Copyright (c) 2025. Huawei Technologies Co., Ltd. All rights reserved.\r\n# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\nimport os\r\nimport re\r\nfrom typing import Optional, Union, Tuple\r\n\r\nimport torch\r\nimport torch_npu\r\nfrom torch import nn\r\nfrom transformers import PretrainedConfig\r\n\r\nfrom ktransformers.operators.base_operator import BaseInjectedModule\r\nfrom ktransformers.models.modeling_qwen3_moe import Qwen3MoeRMSNorm\r\nfrom ktransformers.util import utils\r\nfrom ktransformers.util.custom_loader import GGUFLoader\r\n\r\n\r\nclass KDeepseekV3RMSNormW8A8(BaseInjectedModule):\r\n    def __init__(self,\r\n                 key: str,\r\n                 gguf_loader: GGUFLoader,\r\n                 config: PretrainedConfig,\r\n                 orig_module: nn.Module,\r\n                 prefill_device: str = \"npu\",\r\n                 generate_device: str = \"npu\",\r\n                 **kwargs):\r\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\r\n        self.weight = nn.Parameter(torch.ones(self.orig_module.hidden_size))\r\n        self.bias = nn.Parameter(torch.ones(self.orig_module.hidden_size))\r\n        self.variance_epsilon = self.orig_module.variance_epsilon\r\n\r\n    def forward(self, hidden_states):\r\n        
input_dtype = hidden_states.dtype\r\n        out = torch_npu.npu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0] + self.bias\r\n        return out.to(input_dtype)\r\n\r\n    def load(self):\r\n        self.weight = self.gguf_loader.safetensor_loader.load_tensor(self.key + \".weight\").to(utils.get_current_device())\r\n        self.bias = self.gguf_loader.safetensor_loader.load_tensor(self.key + \".bias\").to(utils.get_current_device())\r\n\r\n    def unload(self):\r\n        if self.weight is not None:\r\n            self.weight = None\r\n        if self.bias is not None:\r\n            self.bias = None\r\n\r\nclass KQwen3MoeRMSNormW8A8(BaseInjectedModule):\r\n    def __init__(self,\r\n                 key: str,\r\n                 gguf_loader: GGUFLoader,\r\n                 config: PretrainedConfig,\r\n                 orig_module: nn.Module,\r\n                 prefill_device: str = \"npu\",\r\n                 generate_device: str = \"npu\",\r\n                 **kwargs):\r\n\r\n        super().__init__(key, gguf_loader, config, orig_module,\r\n                         prefill_device, generate_device, **kwargs)\r\n\r\n        self.hidden_size = orig_module.hidden_size\r\n        self.variance_epsilon = orig_module.variance_epsilon\r\n        self.weight = nn.Parameter(orig_module.weight.data.clone())\r\n\r\n    def forward(self, x: torch.Tensor):\r\n        x = x.to(torch.float16)\r\n        gamma = self.weight.to(torch.float16)\r\n\r\n        input_dtype = x.dtype\r\n        out = torch_npu.npu_rms_norm(\r\n            x,\r\n            gamma,\r\n            self.variance_epsilon\r\n        )[0]\r\n\r\n        return out.to(input_dtype)\r\n\r\n    def load(self):\r\n        device = utils.get_current_device()\r\n        self.weight = self.gguf_loader.safetensor_loader.load_tensor(self.key + \".weight\").to(device)\r\n\r\n        try:\r\n            self.bias = (\r\n                self.gguf_loader.safetensor_loader\r\n                
.load_tensor(self.key + \".bias\")\r\n                .to(device)\r\n            )\r\n        except KeyError:\r\n            self.bias = None\r\n\r\n    def unload(self):\r\n        self.weight = None\r\n        self.bias = None\r\n\r\nclass KQwen3FinalRMSNormNPU(nn.Module):\r\n    def __init__(self, orig_module: nn.Module):\r\n        super().__init__()\r\n        assert hasattr(orig_module, \"weight\"), \"orig_module must have weight\"\r\n        self.variance_epsilon = getattr(orig_module, \"variance_epsilon\", 1e-6)\r\n\r\n        w = orig_module.weight.detach()\r\n        if w.dtype not in (torch.float16, torch.bfloat16, torch.float32):\r\n            w = w.to(torch.float16)\r\n        else:\r\n            if w.dtype == torch.float32:\r\n                w = w.to(torch.float16)\r\n\r\n        self.weight = nn.Parameter(w)\r\n\r\n    def forward(self, x: torch.Tensor):\r\n        input_dtype = x.dtype\r\n        x = x.contiguous()\r\n        gamma = self.weight\r\n        x_rms = x.to(dtype=gamma.dtype)\r\n\r\n        out = torch_npu.npu_rms_norm(\r\n            x_rms,\r\n            gamma,\r\n            self.variance_epsilon\r\n        )[0]\r\n\r\n        return out.to(input_dtype)"
  },
  {
    "path": "archive/ktransformers/operators/ascend/ascend_linear.py",
    "content": "# coding=utf-8\r\n# Copyright (c) 2025. Huawei Technologies Co., Ltd. All rights reserved.\r\n# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\nfrom abc import abstractmethod\r\n\r\nimport torch\r\nimport torch_npu\r\nimport torch.distributed as dist\r\nfrom torch import nn\r\nfrom transformers import PretrainedConfig\r\n\r\nfrom ktransformers.operators.base_operator import BaseInjectedModule\r\nfrom ktransformers.operators.linear import KLinearBase, LINEAR_MAP\r\nfrom ktransformers.util import utils\r\nfrom ktransformers.util.custom_loader import GGUFLoader\r\nfrom ktransformers.util.utils import InferenceState\r\nfrom ktransformers.util.ascend.ascend_utils import get_safetensors_cut_weight, get_tensor_parallel_size, get_tensor_parallel_group\r\nfrom ktransformers.util.custom_gguf import translate_name_to_gguf\r\n\r\n\r\nclass KLinearW8A8(KLinearBase):\r\n    def __init__(\r\n            self,\r\n            key: str,\r\n            gguf_loader: GGUFLoader,\r\n            config: PretrainedConfig,\r\n            orig_module: nn.Module = None,\r\n            device: str = \"cuda\",\r\n            **kwargs,\r\n    ):\r\n        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)\r\n\r\n    def load_weight(self, override_key: str | None = None, device: str | None = None):\r\n        if override_key is not 
None:\r\n            keys = override_key\r\n        else:\r\n            keys = [self.key]\r\n        fake_tensor = torch.tensor([1])\r\n        for key in keys:\r\n            if device is None:\r\n                device = utils.get_current_device()\r\n            \r\n            key = translate_name_to_gguf(key)\r\n            if key == \"lm_head\":\r\n                key = \"output\"\r\n                \r\n            if key + \".weight\" in self.gguf_loader.safetensor_loader.tensor_file_map:\r\n                if key + \".deq_scale\" in self.gguf_loader.safetensor_loader.tensor_file_map:\r\n                    qweight = self.gguf_loader.safetensor_loader.load_tensor(f\"{key}.weight\")\r\n                    deq_scale = self.gguf_loader.safetensor_loader.load_tensor(f\"{key}.deq_scale\")\r\n                    quant_bias = self.gguf_loader.safetensor_loader.load_tensor(f\"{key}.quant_bias\")\r\n                    input_scale = self.gguf_loader.safetensor_loader.load_tensor(f\"{key}.input_scale\")\r\n                    input_offset = self.gguf_loader.safetensor_loader.load_tensor(f\"{key}.input_offset\")\r\n                    tensors = (qweight, deq_scale, quant_bias, input_scale, input_offset)\r\n                    return tensors\r\n                elif key + \".weight_scale\" in self.gguf_loader.safetensor_loader.tensor_file_map:\r\n                    if key.endswith(\"ffn_gate_shexp\"):\r\n                        parts = key.split(\".\")\r\n                        layer = parts[1]\r\n                        gate_weight = self.gguf_loader.safetensor_loader.load_tensor(f\"blk.{layer}.ffn_gate_shexp.weight\")\r\n                        gate_weight = get_safetensors_cut_weight(self.key, gate_weight).t()\r\n                        up_weight = self.gguf_loader.safetensor_loader.load_tensor(f\"blk.{layer}.ffn_up_shexp.weight\")\r\n                        up_weight = get_safetensors_cut_weight(self.key, up_weight).t()\r\n                        gate_scale = 
self.gguf_loader.safetensor_loader.load_tensor(f\"blk.{layer}.ffn_gate_shexp.weight_scale\")\r\n                        gate_scale = get_safetensors_cut_weight(self.key, gate_scale)\r\n                        up_scale = self.gguf_loader.safetensor_loader.load_tensor(f\"blk.{layer}.ffn_up_shexp.weight_scale\")\r\n                        up_scale = get_safetensors_cut_weight(self.key, up_scale)\r\n                        gate_up_weight = torch.cat((gate_weight, up_weight), 1)\r\n                        gate_up_scale = torch.cat((gate_scale, up_scale), 0)\r\n                        gate_offset = self.gguf_loader.safetensor_loader.load_tensor(f\"blk.{layer}.ffn_gate_shexp.weight_offset\")\r\n                        up_offset = self.gguf_loader.safetensor_loader.load_tensor(f\"blk.{layer}.ffn_up_shexp.weight_offset\")\r\n                        gate_up_offset = torch.cat((gate_offset, up_offset), 0)\r\n                        tensors = (gate_up_weight, gate_up_scale, gate_up_offset)\r\n                    elif key.endswith(\"ffn_up_shexp\"):\r\n                        return fake_tensor\r\n                    else:\r\n                        qweight = self.gguf_loader.safetensor_loader.load_tensor(f\"{key}.weight\")\r\n                        weight_scale = self.gguf_loader.safetensor_loader.load_tensor(f\"{key}.weight_scale\")\r\n                        weight_offset = self.gguf_loader.safetensor_loader.load_tensor(f\"{key}.weight_offset\")\r\n                        tensors = (qweight, weight_scale, weight_offset)\r\n                    return tensors\r\n                else:\r\n                    weight = self.gguf_loader.safetensor_loader.load_tensor(f\"{key}.weight\")\r\n                    return weight\r\n            else:\r\n                raise FileNotFoundError(f\"Weight file not found for key {key}\")\r\n\r\n    @abstractmethod\r\n    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = \"cuda\"):\r\n        pass\r\n\r\n    
@abstractmethod\r\n    def unload(self):\r\n        pass\r\n\r\n\r\nclass KLinearTorchW8A8A2(KLinearW8A8):\r\n    def __init__(\r\n        self,\r\n        key: str,\r\n        gguf_loader: GGUFLoader,\r\n        config: PretrainedConfig,\r\n        orig_module: nn.Module = None,\r\n        device: str = \"cuda\",\r\n        **kwargs,\r\n    ):\r\n        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)\r\n        self.has_bias = False\r\n        self.dtype = torch.get_default_dtype()\r\n        self.weight = None\r\n        self.input_scale = None\r\n        self.input_offset = None\r\n        self.quant_bias = None\r\n        self.deq_scale = None\r\n        self.weight_scale = None\r\n        self.weight_offset = None\r\n\r\n    def forward(self, x: torch.Tensor, bsz_tensor) -> torch.Tensor:\r\n        if x.dtype != self.weight.dtype:\r\n            x = x.to(self.weight.dtype)\r\n        return torch.matmul(x, self.weight)\r\n\r\n    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None):\r\n        if device is None: device = utils.get_current_device()\r\n        device = utils.CUR_DEVICE\r\n        if w is None:\r\n            w = self.load_weight()\r\n\r\n        if isinstance(w, nn.Parameter):\r\n            try:\r\n                self.weight = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T.contiguous()\r\n            except:\r\n                self.weight = w.to(dtype=self.dtype).T.contiguous()\r\n            self.weight = self.weight.to(device)\r\n            if self.has_bias:\r\n                self.bias = self.bias.to(device)\r\n        elif isinstance(w, tuple):\r\n            w_list = list(w)\r\n            if len(w_list) == 3:\r\n                self.weight = w_list[0]\r\n                self.weight_scale = w_list[1].view(-1)\r\n                self.weight_offset = w_list[2]\r\n                self.weight = self.weight.to(utils.CUR_DEVICE)\r\n                
self.weight_scale = self.weight_scale.to(utils.CUR_DEVICE)\r\n                if self.key.endswith(\"ffn_gate_shexp\") is not True:\r\n                    self.weight = get_safetensors_cut_weight(self.key, self.weight).t()\r\n                    weight_scale = get_safetensors_cut_weight(self.key, self.weight_scale)\r\n                    self.weight_scale = weight_scale.clone()\r\n                    del weight_scale\r\n            else:\r\n                for i in range(len(w_list)):\r\n                    w_list[i] = get_safetensors_cut_weight(self.key, w_list[i])\r\n                    w_list[i] = w_list[i].to(utils.CUR_DEVICE)\r\n                self.weight = w_list[0]\r\n                self.deq_scale = w_list[1]\r\n                self.quant_bias = w_list[2]\r\n                if \"attn_output\" in self.key or \"ffn_down\" in self.key:\r\n                    if torch.distributed.get_rank(get_tensor_parallel_group()) != 0:\r\n                        self.quant_bias = torch.zeros_like(self.quant_bias, dtype=self.quant_bias.dtype, device=self.quant_bias.device)\r\n\r\n                self.input_scale = w_list[3]\r\n                self.input_offset = w_list[4]\r\n        elif isinstance(w, torch.Tensor):\r\n            self.weight = w.T.contiguous()\r\n            self.weight = self.weight.to(device)\r\n            if \"kv_b\" not in self.key and (\"output\" in  self.key or \"eh_proj\" in self.key):\r\n                self.weight = torch_npu.npu_format_cast(self.weight, 29)\r\n        else:\r\n            raise ValueError(f\"Invalid weight type {self.key=} {type(w)=}\")\r\n\r\n    def unload(self):\r\n        if self.weight is not None:\r\n            self.weight = None\r\n        if self.has_bias:\r\n            self.bias = None\r\n        self.input_scale = None\r\n        self.input_offset = None\r\n        self.quant_bias = None\r\n        self.deq_scale = None\r\n        self.weight_scale = None\r\n        self.weight_offset = 
None\r\n\r\n\r\nLINEAR_MAP[\"KLinearTorchW8A8A2\"] = KLinearTorchW8A8A2\r\n\r\n\r\nclass KTransformersLinearW8A8A2(BaseInjectedModule, KLinearW8A8):\r\n    def __init__(\r\n            self,\r\n            key: str,\r\n            gguf_loader: GGUFLoader,\r\n            config: PretrainedConfig,\r\n            orig_module: nn.Module,\r\n            generate_device: str = \"cuda\",\r\n            generate_op: str | None = \"KLinearMarlin\",\r\n            prefill_device: str = \"cuda\",\r\n            prefill_op: str | None = \"KLinearTorch\",\r\n            **kwargs,\r\n    ):\r\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)\r\n        KLinearW8A8.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)\r\n        # build all the linear operators\r\n        if prefill_op is not None:\r\n            assert prefill_op in LINEAR_MAP, f\"linear_type {prefill_op} not supported\"\r\n            self.prefill_linear = LINEAR_MAP[prefill_op](key, gguf_loader, config, orig_module, prefill_device, **kwargs)\r\n        else:\r\n            self.prefill_linear = None\r\n\r\n        if generate_op is not None:\r\n            assert generate_op in LINEAR_MAP, f\"linear_type {generate_op} not supported\"\r\n            self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs)\r\n        else:\r\n            self.generate_linear = None\r\n        self.mode = InferenceState.UNLOAD\r\n\r\n    def forward(self, x, bsz_tensor=None):\r\n        if self.mode == InferenceState.PREFILL:\r\n            assert self.prefill_linear is not None, \"cpu linear is not initialized\"\r\n            y = self.prefill_linear.forward(x, bsz_tensor)\r\n        else:\r\n            assert self.generate_linear is not None, \"gpu linear is not initialized\"\r\n            y = self.generate_linear.forward(x, bsz_tensor)\r\n        return y\r\n\r\n    def 
load(self, w: dict | nn.Parameter | tuple | None = None, mode: InferenceState = InferenceState.GENERATE):\r\n        if not mode:\r\n            mode = InferenceState.GENERATE\r\n        # load to device\r\n        if mode == InferenceState.PREFILL:\r\n            self.generate_linear.unload()\r\n            self.prefill_linear.load(w=w)\r\n            self.device = self.prefill_linear.device\r\n            self.weight = self.prefill_linear.weight  # modeling_xxx.py may use linear.weight\r\n            self.input_scale = self.prefill_linear.input_scale\r\n            self.input_offset = self.prefill_linear.input_offset\r\n            self.quant_bias = self.prefill_linear.quant_bias\r\n            self.deq_scale = self.prefill_linear.deq_scale\r\n            self.weight_scale = self.prefill_linear.weight_scale\r\n            self.weight_offset = self.prefill_linear.weight_offset\r\n        elif mode == InferenceState.GENERATE:\r\n            self.prefill_linear.unload()\r\n            self.generate_linear.load(w=w)\r\n            self.device = self.generate_linear.device\r\n            self.weight = self.generate_linear.weight  # modeling_xxx.py may use linear.weight\r\n            self.input_scale = self.generate_linear.input_scale\r\n            self.input_offset = self.generate_linear.input_offset\r\n            self.quant_bias = self.generate_linear.quant_bias\r\n            self.deq_scale = self.generate_linear.deq_scale\r\n            self.weight_scale = self.generate_linear.weight_scale\r\n            self.weight_offset = self.generate_linear.weight_offset\r\n        elif mode == InferenceState.UNLOAD:\r\n            self.prefill_linear.unload()\r\n            self.generate_linear.unload()\r\n            self.device = \"cpu\"\r\n        else:\r\n            raise ValueError(\"mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD\")\r\n        self.mode = mode\r\n\r\n    def unload(self):\r\n        if self.prefill_linear 
is not None:\r\n            self.prefill_linear.unload()\r\n        if self.generate_linear is not None:\r\n            self.generate_linear.unload()\r\n            # Only read the device when a generate operator exists.\r\n            self.device = self.generate_linear.device\r\n\r\n    def set_inference_mode(self, mode: InferenceState):\r\n        if not mode:\r\n            mode = InferenceState.GENERATE\r\n        if mode == InferenceState.GENERATE:\r\n            self.load(mode=InferenceState.GENERATE)\r\n        elif mode == InferenceState.PREFILL:\r\n            self.load(mode=InferenceState.PREFILL)\r\n        elif mode == InferenceState.UNLOAD:\r\n            self.unload()\r\n        else:\r\n            raise ValueError(\"mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD\")\r\n"
  },
  {
    "path": "archive/ktransformers/operators/ascend/ascend_mlp.py",
    "content": "# coding=utf-8\r\n# Copyright (c) 2025. Huawei Technologies Co., Ltd. All rights reserved.\r\n# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\r\nimport torch\r\nimport torch_npu\r\n\r\nfrom ktransformers.util.ascend.ascend_utils import allredeuce_warpper\r\nfrom ktransformers.util.utils import CUR_DEVICE\r\nfrom ktransformers.operators.base_operator import BaseInjectedModule\r\nfrom ktransformers.models.modeling_deepseek_v3 import DeepseekV3MLP\r\nfrom ktransformers.models.modeling_qwen3_moe import Qwen3MoeMLP\r\n\r\nclass KDeepseekV3MLPW8A8A2V1(BaseInjectedModule, DeepseekV3MLP):\r\n    @allredeuce_warpper\r\n    def forward(self, x, is_prefill=None, use_cuda_graph=False):\r\n        original_dtype = x.dtype\r\n        quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)\r\n        dynamic_scale = dynamic_scale.view(-1)\r\n        quant_out = quant_out.view(-1, quant_out.shape[-1])\r\n        gate_x = torch_npu.npu_quant_matmul(\r\n            quant_out,\r\n            self.orig_module.gate_proj.weight,\r\n            self.orig_module.gate_proj.weight_scale,\r\n            pertoken_scale=dynamic_scale,\r\n            bias=None,\r\n            output_dtype=original_dtype,\r\n        )\r\n        up_x = torch_npu.npu_quant_matmul(\r\n            quant_out,\r\n            self.orig_module.up_proj.weight,\r\n            
self.orig_module.up_proj.weight_scale,\r\n            pertoken_scale=dynamic_scale,\r\n            bias=None,\r\n            output_dtype=original_dtype,\r\n        )\r\n        down_x = self.act_fn(gate_x) * up_x\r\n        down_quant_out, down_dynamic_scale = torch_npu.npu_dynamic_quant(down_x)\r\n        down_dynamic_scale = down_dynamic_scale.view(-1)\r\n        down_proj = torch_npu.npu_quant_matmul(\r\n            down_quant_out,\r\n            self.orig_module.down_proj.weight,\r\n            self.orig_module.down_proj.weight_scale,\r\n            pertoken_scale=down_dynamic_scale,\r\n            bias=None,\r\n            output_dtype=original_dtype,\r\n        )\r\n        down_proj = down_proj.reshape(x.shape)\r\n        return down_proj\r\n\r\nclass KDeepseekV3MLPW8A8A2V2(BaseInjectedModule, DeepseekV3MLP):\r\n    @allredeuce_warpper\r\n    def forward(self, x, is_prefill=None, use_cuda_graph=False):\r\n        original_dtype = x.dtype\r\n        quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)\r\n        dynamic_scale = dynamic_scale.view(-1)\r\n        quant_out = quant_out.view(-1, quant_out.shape[-1])\r\n        gate_up_x = torch_npu.npu_quant_matmul(\r\n            quant_out,\r\n            self.orig_module.gate_proj.weight,\r\n            self.orig_module.gate_proj.weight_scale,\r\n            pertoken_scale=dynamic_scale,\r\n            bias=None,\r\n            output_dtype=original_dtype,\r\n        )\r\n        down_x = torch_npu.npu_swiglu(gate_up_x, -1)\r\n        down_quant_out, down_dynamic_scale = torch_npu.npu_dynamic_quant(down_x)\r\n        down_dynamic_scale = down_dynamic_scale.view(-1)\r\n        down_proj = torch_npu.npu_quant_matmul(\r\n            down_quant_out,\r\n            self.orig_module.down_proj.weight,\r\n            self.orig_module.down_proj.weight_scale,\r\n            pertoken_scale=down_dynamic_scale,\r\n            bias=None,\r\n            output_dtype=original_dtype,\r\n        )\r\n        down_proj = 
down_proj.reshape(x.shape)\r\n        return down_proj\r\n\r\nclass KQwen3MoeMLPW8A8A2(BaseInjectedModule, Qwen3MoeMLP):\r\n    @allredeuce_warpper\r\n    def forward(self, x, is_prefill=None, use_cuda_graph=False):\r\n        original_dtype = x.dtype\r\n        quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)\r\n        dynamic_scale = dynamic_scale.view(-1)\r\n        quant_out = quant_out.view(-1, quant_out.shape[-1])\r\n        \r\n        gate_x = torch_npu.npu_quant_matmul(\r\n            quant_out,\r\n            self.orig_module.gate_proj.weight,\r\n            self.orig_module.gate_proj.weight_scale,\r\n            pertoken_scale=dynamic_scale,\r\n            bias=None,\r\n            output_dtype=original_dtype,\r\n        )\r\n        up_x = torch_npu.npu_quant_matmul(\r\n            quant_out,\r\n            self.orig_module.up_proj.weight,\r\n            self.orig_module.up_proj.weight_scale,\r\n            pertoken_scale=dynamic_scale,\r\n            bias=None,\r\n            output_dtype=original_dtype,\r\n        )\r\n        \r\n        down_x = torch.nn.functional.silu(gate_x) * up_x\r\n        \r\n        down_quant_out, down_dynamic_scale = torch_npu.npu_dynamic_quant(down_x)\r\n        down_dynamic_scale = down_dynamic_scale.view(-1)\r\n        \r\n        down_proj = torch_npu.npu_quant_matmul(\r\n            down_quant_out,\r\n            self.orig_module.down_proj.weight,\r\n            self.orig_module.down_proj.weight_scale,\r\n            pertoken_scale=down_dynamic_scale,\r\n            bias=None,\r\n            output_dtype=original_dtype,\r\n        )\r\n        down_proj = down_proj.reshape(x.shape)\r\n        return down_proj"
  },
  {
    "path": "archive/ktransformers/operators/attention.py",
    "content": "'''\nDescription  :  \nAuthor       : Boxin Zhang\nVersion      : 0.1.0\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport torch\nfrom torch import nn\nimport warnings\nimport torch.nn.functional as F\nfrom ktransformers.operators.models import KLlamaModel\nfrom ktransformers.models.configuration_deepseek import DeepseekV2Config\nfrom ktransformers.models.configuration_llama import LlamaConfig\nfrom ktransformers.models.modeling_llama import LlamaRotaryEmbedding\nfrom ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb\nfrom ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention\nfrom typing import Optional, Tuple\nfrom ktransformers.operators.base_operator import BaseInjectedModule\nfrom ktransformers.util.custom_loader import GGUFLoader\nfrom ktransformers.util.utils import get_compute_capability\nimport logging\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.cache_utils import Cache\nfrom ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor\n\ntry:\n    from flash_attn import flash_attn_func\nexcept:\n    pass\nfrom ktransformers.operators.triton_attention import decode_attention_fwd_grouped \nfrom ktransformers.operators.triton_attention_prefill import context_attention_fwd\nimport os\nfrom ktransformers.operators.flashinfer_wrapper import flashinfer_enabled\nif flashinfer_enabled:\n    from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton\n    from flashinfer.mla import BatchMLAPagedAttentionWrapper\nfrom ktransformers.models.custom_cache import KDeepSeekV3Cache\nlogger = logging.getLogger(\"attention\")\n\n# Copied from transformers.models.llama.modeling_llama.rotate_half\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n# V3 MLA is same to V2\nclass 
KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n    attn_mask: Optional[torch.Tensor] = None\n\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 chunck_size: int = 1000,\n                 absorb_for_prefill: bool = False,\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)\n        self.orig_module.__init__(orig_module.config,\n            orig_module.layer_idx)\n        self.chunck_size = chunck_size # TODO, generate chunck_size automatically.\n        self.mla_wrapper = None\n        self.absorb_for_prefill = absorb_for_prefill\n\n    def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:\n        if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):\n            kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)\n            self.q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)\n            self.out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].view(self.num_heads, self.v_head_dim, self.kv_lora_rank)\n            \n        return self.q_absorb, self.out_absorb\n\n    def forward_chunck(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], 
Optional[Tuple[torch.Tensor]]]:\n        bsz, q_len, _ = hidden_states.size()\n        if self.q_lora_rank is None:\n            q = self.q_proj(hidden_states)\n        else:\n            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))\n        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)\n        q_nope, q_pe = torch.split(\n            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1\n        )\n        # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]\n        # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]\n\n        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)\n        compressed_kv, k_pe = torch.split(\n            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n        )\n        compressed_kv = self.kv_a_layernorm(compressed_kv)\n        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)\n\n        kv_seq_len = k_pe.shape[-2]\n        if past_key_value is not None:\n            if self.layer_idx is None:\n                raise ValueError(\n                    f\"The cache structure has changed since transformer version v4.36. 
If you are using {self.__class__.__name__} \"\n                    \"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class \"\n                    \"with a layer index.\"\n                )\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\n\n        cos, sin = self.rotary_emb(q_pe, position_ids)\n        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)\n\n        if past_key_value is not None:\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n            \n            # compressed_kv [bsz, q_len, self.kv_lora_rank]\n            # k_pe [bsz, 1, q_len, self.qk_rope_head_dim]\n            k_pe = k_pe.transpose(1,2)\n            compressed_kv = compressed_kv.unsqueeze(2)\n            compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)\n            compressed_kv, k_pe = torch.split(\n                compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n            )\n            # k_pe [pages, page_size, 1, self.qk_rope_head_dim]\n            # compressed_kv [pages, page_size, 1, self.kv_lora_rank]\n            \n        q_absorb, out_absorb = self.get_absorbed()\n\n        # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]\n        # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]\n        k_pe = k_pe.view(bsz, 1, -1, self.qk_rope_head_dim)[:,:,:attention_mask.size(-1),:]\n        compressed_kv = compressed_kv.view(bsz, 1, -1, self.kv_lora_rank)[:,:,:attention_mask.size(-1),:]\n        # k_pe [bsz, 1, cache_len, self.qk_rope_head_dim]\n        # compressed_kv [bsz, 1, cache_len,self.kv_lora_rank]\n        q_nope = torch.matmul(q_nope, q_absorb)\n        #print(q_pe.shape)\n        #print(k_pe.shape)\n        #print(q_nope.shape)\n        #print(compressed_kv.shape)\n        \n        attn_weights = (torch.matmul(q_pe, 
k_pe.mT) + torch.matmul(q_nope, compressed_kv.mT)) * self.softmax_scale\n        \n        #attn_weights [bsz, self.num_heads, q_len, kv_seq_len]\n        compressed_kv = compressed_kv.squeeze(1)\n        \"\"\"\n        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is\"\n                f\" {attn_weights.size()}\"\n            )\n        assert attention_mask is not None\n        \"\"\"\n        if attention_mask is not None:\n            \"\"\"\n            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}\"\n                )\n            \"\"\"\n            #causal_mask = attention_mask[:, :, :, : kv_seq_len]\n            attn_weights = attn_weights + attention_mask\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(\n            attn_weights, dim=-1, dtype=torch.float32\n        ).to(q_pe.dtype)\n        attn_weights = nn.functional.dropout(\n            attn_weights, p=self.attention_dropout, training=self.training\n        )\n        \n        attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)\n        \n        attn_output = torch.matmul(attn_output, out_absorb.mT) \n\n        if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        \n        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)\n\n        attn_output = self.o_proj(attn_output)\n\n        return attn_output, None, 
past_key_value\n\n    def forward_linux_triton(\n            self,\n            hidden_states: torch.Tensor,\n            attention_mask: Optional[torch.Tensor] = None,\n            position_ids: Optional[torch.LongTensor] = None,\n            past_key_value: Optional[Cache] = None,\n            output_attentions: bool = False,\n            use_cache: bool = False,\n            cache_position: Optional[torch.LongTensor] = None,\n            **kwargs,\n        ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n\n        bsz, q_len, _ = hidden_states.size()\n\n        if self.q_lora_rank is None:\n            q = self.q_proj(hidden_states)\n        else:\n            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))\n        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim)\n        q_nope, q_pe = torch.split(\n            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1\n        )\n\n        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)\n        compressed_kv, k_pe = torch.split(\n            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n        )\n        compressed_kv = self.kv_a_layernorm(compressed_kv)\n        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)\n        compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)\n\n        kv_seq_len = q_len\n        if past_key_value is not None:\n            if self.layer_idx is None:\n                raise ValueError(\n                    f\"The cache structure has changed since transformer version v4.36. 
If you are using {self.__class__.__name__} \"\n                    \"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class \"\n                    \"with a layer index.\"\n                )\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\n        \n        cos, sin = self.rotary_emb(q_pe, position_ids)\n        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)\n        # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]\n        \n        # decode\n        if q_len == 1:\n            if past_key_value is not None:\n                cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n                compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)\n                compressed_kv = compressed_kv_with_k_pe [:, :, :, :self.kv_lora_rank] # for speed\n                # compressed_kv_with_k_pe [bsz, q_len, 1, self.kv_lora_rank + self.qk_rope_head_dim]\n                # compressed_kv [bsz, q_len, 1, self.kv_lora_rank]\n\n            # q_nope [bsz, q_len, self.num_heads, self.qk_nope_head_dim]\n            # q_absorb [self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank]\n            q_absorb, out_absorb = self.get_absorbed()\n            q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below\n            q_nope = torch.matmul(q_nope, q_absorb) # batched MM\n            q_nope = q_nope.transpose(1, 2)\n            #assert q_nope.is_contiguous()\n            \n            # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]\n            # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]\n            query_states = torch.cat([q_nope, q_pe], dim=-1)\n            \n            query_states = query_states.squeeze(1)\n            attn_output = torch.zeros_like(q_nope) # [bsz, 
q_len, self.num_heads, self.kv_lora_rank]\n            \n            attn_logits = torch.empty(\n                    (\n                        bsz,\n                        self.num_heads,\n                        4, #num_kv_splits # follow vLLM, fix it TODO\n                        self.kv_lora_rank + 1, \n                    ),\n                    dtype=torch.float32,\n                    device = attn_output.device\n                )\n\n            \"\"\"\n            print(\"query_states\", torch.isnan(query_states).any())\n            print(\"compressed_kv_with_k_pe\", torch.isnan(compressed_kv_with_k_pe[:,:,0,:]).any())\n            print(\"compressed_kv\", torch.isnan(compressed_kv[:,:,0,:]).any())\n            print(\"position_ids\", torch.isnan(position_ids).any())\n            \"\"\"\n\n            # flash attn doesn't support head_dim bigger than 256\n            # use triton attention kernel adapted from vLLM and SGLang for MQA\n            decode_attention_fwd_grouped(query_states, compressed_kv_with_k_pe, compressed_kv, attn_output,\n                             page_table,\n                             position_ids.squeeze(0).to(torch.int32)+1, attn_logits,\n                             4, #num_kv_splits # follow vLLM, fix it TODO\n                             self.softmax_scale,\n                             past_key_value.page_size)\n            \n            # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]\n            # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]\n            attn_output = attn_output.transpose(1, 2)\n            attn_output = torch.matmul(attn_output, out_absorb.mT)\n            attn_output = attn_output.transpose(1, 2)\n            \n            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)\n            attn_output = self.o_proj(attn_output)\n            \n            #print(\"attn_output\", torch.isnan(attn_output).any())\n            return attn_output, 
None, past_key_value\n        else:\n            if past_key_value is not None:\n                cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n                k_pe.squeeze(0)\n                compressed_kv.squeeze(0)\n                compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)\n                compressed_kv, k_pe = torch.split(\n                    compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n                )\n            k_pe = k_pe.view(bsz, -1, self.qk_rope_head_dim)\n            k_pe = k_pe[:, :kv_seq_len]\n            compressed_kv = compressed_kv.view(bsz, -1, self.kv_lora_rank)\n            compressed_kv = compressed_kv[:, :kv_seq_len]\n            kv = (\n                self.kv_b_proj(compressed_kv)\n                .view(bsz, kv_seq_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)\n            )\n            k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)\n            query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)\n            query_states[:, :, :, : self.qk_nope_head_dim] = q_nope\n            query_states[:, :, :, self.qk_nope_head_dim :] = q_pe\n\n            key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)\n            key_states[:, :, :, :self.qk_nope_head_dim] = k_nope\n            key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)\n            \n            value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)\n            value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)\n\n            attn_output = flash_attn_func(\n                query_states,\n                key_states,\n                value_states_padded,\n                
softmax_scale=self.softmax_scale,\n                causal=True,\n            )\n\n            if self.q_head_dim != self.v_head_dim:\n                attn_output = attn_output[:, :, :, : self.v_head_dim]\n\n            attn_output = attn_output.reshape(\n                bsz, q_len, self.num_heads * self.v_head_dim\n            ).contiguous()\n            attn_output = self.o_proj(attn_output)\n            return attn_output, None, past_key_value\n\n    def forward_linux_flashinfer(\n            self,\n            hidden_states: torch.Tensor,\n            attention_mask: Optional[torch.Tensor] = None,\n            position_ids: Optional[torch.Tensor] = None,\n            past_key_value: Optional[Cache] = None,\n            output_attentions: bool = False,\n            use_cache: bool = False,\n            cache_position: Optional[torch.Tensor] = None,\n            **kwargs,\n        ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n\n        bsz, q_len, _ = hidden_states.size()\n\n        if self.q_lora_rank is None:\n            q = self.q_proj(hidden_states)\n        else:\n            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))\n        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim)\n        q_nope, q_pe = torch.split(\n            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1\n        )\n\n        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)\n        compressed_kv, k_pe = torch.split(\n            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n        )\n        compressed_kv = self.kv_a_layernorm(compressed_kv)\n        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)\n        compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)\n\n        kv_seq_len = q_len\n        if past_key_value is not None:\n            if self.layer_idx is None:\n                raise ValueError(\n                    f\"The cache structure has changed since 
transformers version v4.36. If you are using {self.__class__.__name__} \"\n                    \"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class \"\n                    \"with a layer index.\"\n                )\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\n        \n        cos, sin = self.rotary_emb(q_pe, position_ids)\n        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)\n        # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]\n        \n        # decode\n        if q_len == 1 or self.absorb_for_prefill:\n            if past_key_value is not None:\n                cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n                compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)\n                compressed_kv = compressed_kv_with_k_pe [:, :, :, :self.kv_lora_rank].view(-1, past_key_value.page_size, self.kv_lora_rank)\n                k_pe = compressed_kv_with_k_pe [:, :, :, self.kv_lora_rank:].view(-1, past_key_value.page_size, self.qk_rope_head_dim)\n                # k_pe [max_pages, page_size, self.qk_rope_head_dim]\n                # compressed_kv [max_pages, page_size, self.kv_lora_rank]\n\n            # q_nope [bsz, q_len, self.num_heads, self.qk_nope_head_dim]\n            # q_absorb [self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank]\n            q_absorb, out_absorb = self.get_absorbed()\n            q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below\n            q_nope = torch.matmul(q_nope, q_absorb) # batched MM\n            q_nope = q_nope.transpose(1, 2)\n            q_nope = q_nope.contiguous()\n            #assert q_nope.is_contiguous()\n            \n            # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]\n 
           # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]\n            q_nope.squeeze_(0)\n            q_pe.squeeze_(0)\n\n            # flash attn doesn't support head_dim bigger than 256, use flashinfer\n            if self.mla_wrapper is None:\n                self.mla_wrapper = MLAWrapperSingleton.get_instance(self.device, 1, past_key_value.max_pages, use_cuda_graph = True)\n            if self.mla_wrapper.need_plan:\n                self.mla_wrapper.need_plan = False\n                if q_len == 1:\n                    self.mla_wrapper.plan(None,None,None,\n                                        position_ids.squeeze(1)+1,\n                                        None,\n                                        self.num_heads,\n                                        self.kv_lora_rank,\n                                        self.qk_rope_head_dim,\n                                        past_key_value.page_size,\n                                        self.softmax_scale,\n                                        q_nope.dtype,\n                                        compressed_kv.dtype)\n                else:\n                    qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device=self.device)\n                    kv_len_arr = torch.tensor([position_ids[0, -1].item()+1], dtype=torch.int32, device=self.device)\n                    self.mla_wrapper.plan(qo_indptr,None,None,\n                                        kv_len_arr,\n                                        None,\n                                        self.num_heads,\n                                        self.kv_lora_rank,\n                                        self.qk_rope_head_dim,\n                                        past_key_value.page_size,\n                                        self.softmax_scale,\n                                        q_nope.dtype,\n                                        compressed_kv.dtype)\n            attn_output = 
self.mla_wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(bsz, q_len, self.num_heads, self.kv_lora_rank)\n            \"\"\"\n            k = (\n                torch.cat([compressed_kv, k_pe], dim=-1)\n                .view(-1, 1, 512 + 64)\n                .repeat_interleave(self.num_heads, dim=1)\n            )\n            v = compressed_kv.view(-1, 1, 512).repeat_interleave(self.num_heads, dim=1)\n            lens = position_ids.item() + 1\n            #print(\"lens\", lens)\n            attn_ref, lse_ref = attention_ref(\n                1,\n                torch.cat([q_nope, q_pe], dim=-1),\n                k[:lens],\n                v[:lens],\n                False,\n                self.softmax_scale\n            )\n            attn_output = attn_ref.view(bsz, q_len, self.num_heads, self.kv_lora_rank)\n            \"\"\"\n            \n            # mla_wrapper run output: [tokens, self.num_heads, self.kv_lora_rank]\n            # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]\n            # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]\n            attn_output = attn_output.transpose(1, 2) # [bsz, self.num_heads, q_len, self.kv_lora_rank]\n            attn_output = torch.matmul(attn_output, out_absorb.mT) # [bsz, self.num_heads, q_len, self.v_head_dim]\n            attn_output = attn_output.transpose(1, 2).contiguous() # [bsz, q_len, self.num_heads, self.v_head_dim]\n            \n            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) # [bsz, q_len, self.num_heads * self.v_head_dim]\n            attn_output = self.o_proj(attn_output)\n            \n            return attn_output, None, past_key_value\n        else:\n            if past_key_value is not None:\n                cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n                k_pe.squeeze(0)\n                compressed_kv.squeeze(0)\n                
compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)\n                compressed_kv, k_pe = torch.split(\n                    compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n                )\n            k_pe = k_pe.view(bsz, -1, self.qk_rope_head_dim)\n            k_pe = k_pe[:, :kv_seq_len]\n            compressed_kv = compressed_kv.view(bsz, -1, self.kv_lora_rank)\n            compressed_kv = compressed_kv[:, :kv_seq_len]\n            kv = (\n                self.kv_b_proj(compressed_kv)\n                .view(bsz, kv_seq_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)\n            )\n            k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)\n            query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)\n            query_states[:, :, :, : self.qk_nope_head_dim] = q_nope\n            query_states[:, :, :, self.qk_nope_head_dim :] = q_pe\n\n            key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)\n            key_states[:, :, :, :self.qk_nope_head_dim] = k_nope\n            key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)\n            \n            value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)\n            value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)\n\n            attn_output = flash_attn_func(\n                query_states,\n                key_states,\n                value_states_padded,\n                softmax_scale=self.softmax_scale,\n                causal=True,\n            )\n\n            if self.q_head_dim != self.v_head_dim:\n                attn_output = attn_output[:, :, :, : self.v_head_dim]\n\n            attn_output = attn_output.reshape(\n                bsz, q_len, self.num_heads * self.v_head_dim\n           
 ).contiguous()\n            attn_output = self.o_proj(attn_output)\n            return attn_output, None, past_key_value\n        \n    def forward_windows(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        if \"padding_mask\" in kwargs:\n            warnings.warn(\n                \"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead.\"\n            )\n        bsz, q_len, _ = hidden_states.size()\n\n        if q_len <= self.chunck_size:\n            return self.forward_chunck(\n                            hidden_states,\n                            attention_mask,\n                            position_ids,\n                            past_key_value,\n                            output_attentions,\n                            use_cache,\n                            cache_position,\n                            **kwargs\n                        )\n\n        assert not output_attentions, \"output_attentions is not supported when using chunked attention\"\n        attn_output = None\n        cur_idx = 0\n        while cur_idx < q_len:\n            if attention_mask is not None:\n                chunk_mask = attention_mask[:, :, cur_idx:min(cur_idx + self.chunck_size, q_len), ...]\n            else:\n                # generate chunk_mask automatically.\n                self.attn_mask = \\\n                    torch.zeros(1, 1, self.chunck_size, past_key_value.max_cache_len, device=hidden_states.device) \\\n                        if self.attn_mask is None \\\n                            else 
self.attn_mask\n                self.attn_mask[:, :, :, cur_idx:min(cur_idx+self.chunck_size, past_key_value.max_cache_len)] = \\\n                    -1e+38 * torch.triu(torch.ones(self.chunck_size, self.chunck_size, device=hidden_states.device), diagonal=1)\\\n                        [:,:min(self.chunck_size, min(past_key_value.max_cache_len-cur_idx, self.chunck_size))]\n                self.attn_mask[:, :, :, cur_idx+self.chunck_size:] = -1e+38\n                self.attn_mask[:, :, :, :cur_idx] = 0\n                chunk_mask = torch.narrow(self.attn_mask, 2, 0, min(self.chunck_size, q_len-cur_idx))\n\n            cur_output, _, _ = self.forward_chunck(\n                            hidden_states[:, cur_idx:min(cur_idx + self.chunck_size, q_len), ...],\n                            chunk_mask,\n                            position_ids[:, cur_idx:min(cur_idx + self.chunck_size, q_len)],\n                            past_key_value,\n                            output_attentions,\n                            use_cache,\n                            cache_position[cur_idx:min(cur_idx + self.chunck_size, q_len)],\n                            **kwargs\n                        )\n            cur_idx += self.chunck_size\n            if attn_output is None:\n                attn_output = cur_output\n            else:\n                attn_output = torch.cat((attn_output, cur_output), dim=-2)\n                \n        return attn_output, None, past_key_value\n\n    def forward_xpu(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        if \"padding_mask\" in kwargs:\n    
        warnings.warn(\n                \"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead.\"\n            )\n        bsz, q_len, _ = hidden_states.size()\n\n        if self.q_lora_rank is None:\n            q = self.q_proj(hidden_states)\n        else:\n            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))\n        query_states = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)\n\n        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)\n        compressed_kv, k_pe = torch.split(\n            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n        )\n        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)\n        kv = (\n            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))\n            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)\n            .transpose(1, 2)\n        )\n\n        k_nope, value_states = torch.split(\n            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1\n        )\n        kv_seq_len = value_states.shape[-2]\n        if past_key_value is not None:\n            if self.layer_idx is None:\n                raise ValueError(\n                    f\"The cache structure has changed since version v4.36. 
If you are using {self.__class__.__name__} \"\n                    \"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class \"\n                    \"with a layer index.\"\n                )\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\n\n        position_embeddings = kwargs.get(\"position_embeddings\", None)\n        if position_embeddings is not None:\n            cos, sin = position_embeddings\n            key_states = torch.cat(\n                [k_nope, k_pe.expand([-1, self.num_heads, -1, -1])],\n                dim=-1\n            )\n            from ipex_llm.transformers.models.common import rotary_two_with_cache_inplaced\n            rotary_two_with_cache_inplaced(query_states[:, :, :, self.qk_nope_head_dim :],\n                                           key_states[:, :, :, self.qk_nope_head_dim:],\n                                           cos, sin, True)\n        else:\n            q_nope, q_pe = torch.split(\n                query_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1\n            )\n            cos, sin = self.rotary_emb(q_pe, position_ids)\n            q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)\n            query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)\n            query_states[:, :, :, : self.qk_nope_head_dim] = q_nope\n            query_states[:, :, :, self.qk_nope_head_dim :] = q_pe\n\n            key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)\n            key_states[:, :, :, : self.qk_nope_head_dim] = k_nope\n            key_states[:, :, :, self.qk_nope_head_dim :] = k_pe\n\n        if past_key_value is not None:\n            cache_kwargs = {\"sin\": sin, \"cos\": cos}  # Specific to RoPE models\n            key_states, value_states = past_key_value.update(\n                key_states.half(), value_states.half(), self.layer_idx, cache_kwargs\n            )\n\n        
attn_weights = None\n        from ipex_llm.transformers.models.common import scaled_dot_product_attention\n        attn_output = scaled_dot_product_attention(\n            query_states.half(), key_states, value_states,\n            attention_mask.half(), q_len == kv_seq_len, self.softmax_scale\n        )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)\n        attn_output = self.o_proj(attn_output).to(hidden_states.dtype)\n\n        if not output_attentions:\n            attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        if torch.xpu.is_available():\n            return self.forward_xpu(\n                hidden_states,\n                attention_mask,\n                position_ids,\n                past_key_value,\n                output_attentions,\n                use_cache,\n                cache_position,\n                **kwargs,\n            )\n        elif (os.name == 'nt'\n              or get_compute_capability() < 8\n              or hidden_states.device.type == 'cpu'\n              or device_manager.gpu_vendor != GPUVendor.NVIDIA):\n            return self.forward_windows(\n                hidden_states,\n                attention_mask,\n                position_ids,\n                past_key_value,\n                output_attentions,\n                use_cache,\n                cache_position,\n                **kwargs,\n            )\n        
else:\n            if flashinfer_enabled:\n                return self.forward_linux_flashinfer(\n                    hidden_states,\n                    attention_mask,\n                    position_ids,\n                    past_key_value,\n                    output_attentions,\n                    use_cache,\n                    cache_position,\n                    **kwargs,\n                )\n            else:\n                return self.forward_linux_triton(\n                    hidden_states,\n                    attention_mask,\n                    position_ids,\n                    past_key_value,\n                    output_attentions,\n                    use_cache,\n                    cache_position,\n                    **kwargs,\n                )\n\n\nclass KLlamaAttention(BaseInjectedModule):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)\n        self.orig_module.__init__(orig_module.config,\n            orig_module.layer_idx)\n    def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):\n        \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n        Args:\n            q (`torch.Tensor`): The query tensor.\n            k (`torch.Tensor`): The key tensor.\n            cos (`torch.Tensor`): The cosine part of the rotary embedding.\n            sin (`torch.Tensor`): The sine part of the rotary embedding.\n            position_ids (`torch.Tensor`, *optional*):\n                Deprecated and unused.\n            unsqueeze_dim (`int`, 
*optional*, defaults to 1):\n                The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n                sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n                that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n                k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n                cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have\n                the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n        Returns:\n            `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n        \"\"\"\n        cos = cos.unsqueeze(unsqueeze_dim)\n        sin = sin.unsqueeze(unsqueeze_dim)\n        q_embed = (q * cos) + (rotate_half(q) * sin)\n        k_embed = (k * cos) + (rotate_half(k) * sin)\n        return q_embed, k_embed\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45\n        **kwargs,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        bsz, q_len, _ = hidden_states.size()\n\n        if self.config.pretraining_tp > 1:\n            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp\n            query_slices = self.q_proj.weight.split(\n                (self.num_heads * self.head_dim) // 
self.config.pretraining_tp, dim=0\n            )\n            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)\n            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)\n\n            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]\n            query_states = torch.cat(query_states, dim=-1)\n\n            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]\n            key_states = torch.cat(key_states, dim=-1)\n\n            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]\n            value_states = torch.cat(value_states, dim=-1)\n\n        else:\n            query_states = self.q_proj(hidden_states)\n            key_states = self.k_proj(hidden_states)\n            value_states = self.v_proj(hidden_states)\n\n        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n\n        if position_embeddings is None:\n\n            logger.warning(\n                \"The attention layers in this model are transitioning from computing the RoPE embeddings internally \"\n                \"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed \"\n                \"`position_embeddings` (Tuple of tensors, containing cos and sin). 
In v4.45 `position_ids` will be \"\n                \"removed and `position_embeddings` will be mandatory.\"\n            )\n            cos, sin = self.rotary_emb(value_states, position_ids)\n        else:\n            cos, sin = position_embeddings\n        query_states, key_states = self.apply_rotary_pos_emb(query_states, key_states, cos, sin)\n        if q_len == 1:\n            position_ids = position_ids[0][-1].unsqueeze(0).unsqueeze(0)\n            query_states = query_states[:, :, -1:]\n            key_states = key_states[:, :, -1:]\n\n        attn_output = KLlamaModel.dynamic_sdpa.apply(\n            self.layer_idx,\n            bsz,\n            position_ids[0][0],\n            query_states.transpose(1, 2).to(torch.float16),\n            key_states.transpose(1, 2).to(torch.float16),\n            value_states.transpose(1, 2).to(torch.float16),\n            mode=\"prefill\" if q_len > 1 else \"generate\",\n        )\n\n\n        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n\n        attn_output = attn_output.reshape(bsz, q_len, -1)\n\n        if self.config.pretraining_tp > 1:\n            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)\n            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)\n            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])\n        else:\n            attn_output = self.o_proj(attn_output)\n\n        # Attention weights are never computed on this path; always return None\n        # so that output_attentions=True does not hit an undefined name.\n        attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n\nclass KQwen3MoeAttentionIPEXLLM(BaseInjectedModule, Qwen3MoeAttention):\n    
def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"xpu\",\n                 generate_device: str = \"xpu\",\n                 chunck_size: int = 1000,\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        self.orig_module.__init__(orig_module.config,\n            orig_module.layer_idx)\n        self.chunck_size = chunck_size # TODO, generate chunck_size automatically.\n        assert prefill_device.lower()[:3] == \"xpu\", \"KQwen3MoeAttentionIPEXLLM only supports XPU device\"\n        assert generate_device.lower()[:3] == \"xpu\", \"KQwen3MoeAttentionIPEXLLM only supports XPU device\"\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_ids: Optional[torch.Tensor],\n        position_embeddings: Tuple[torch.Tensor, torch.Tensor],\n        attention_mask: Optional[torch.Tensor],\n        past_key_value: Optional[Cache] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        input_shape = hidden_states.shape[:-1]\n        bsz, q_len, _ = hidden_states.size()\n        input_dtype = hidden_states.dtype\n        hidden_shape = (*input_shape, -1, self.head_dim)\n\n        if not hasattr(self, 'qkv_proj'):\n            from ipex_llm.transformers.models.common import merge_quantized_qkv\n            merge_quantized_qkv(self.q_proj.generate_linear, self.k_proj.generate_linear, self.v_proj.generate_linear, self.orig_module)\n\n        qkv = self.qkv_proj(hidden_states)\n        qkv = qkv.view(bsz, q_len, -1, self.head_dim)\n        qkv = qkv.transpose(1, 2)\n        query_states, key_states, value_states = qkv.split([self.config.num_attention_heads,\n           
                                                 self.config.num_key_value_heads,\n                                                            self.config.num_key_value_heads], dim=1)\n        query_states = self.q_norm(query_states)\n        key_states = self.k_norm(key_states)\n\n        if position_embeddings is None:\n            position_embeddings = self.rotary_emb(hidden_states, position_ids)\n\n        cos, sin = position_embeddings\n\n        from ipex_llm.transformers.models.common import rotary_half_with_cache_inplaced\n        rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)\n\n        if past_key_value is not None:\n            # sin and cos are specific to RoPE models; cache_position needed for the static cache\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}\n            key_states, value_states = past_key_value.update(key_states.half(), value_states.half(),\n                                                             self.layer_idx, cache_kwargs)\n\n        attn_weights = None\n        from ipex_llm.transformers.models.common import scaled_dot_product_attention\n        attn_output = scaled_dot_product_attention(\n            query_states.half(), key_states, value_states,\n            attention_mask.half(), q_len == key_states.size(2), self.scaling\n        )\n        attn_output = attn_output.transpose(1, 2).contiguous()\n\n        attn_output = attn_output.reshape(*input_shape, -1).contiguous()\n        attn_output = self.o_proj(attn_output).to(input_dtype)\n        return attn_output, attn_weights\n"
  },
  {
    "path": "archive/ktransformers/operators/balance_serve_attention.py",
    "content": "'''\nDescription  :  \nAuthor       : Boxin Zhang\nVersion      : 0.2.5\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport torch\nfrom torch import nn\nfrom ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb\nfrom ktransformers.models.modeling_qwen2_moe import Qwen2MoeAttention\nfrom ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention\nfrom ktransformers.models.modeling_smallthinker import SmallthinkerAttention\nfrom ktransformers.models.modeling_glm4_moe import Glm4MoeAttention\nfrom ktransformers.models.modeling_qwen3_next import Qwen3NextGatedDeltaNet\nfrom typing import Optional, Tuple\nfrom ktransformers.operators.base_operator import BaseInjectedModule\nfrom ktransformers.util.custom_loader import GGUFLoader\nimport logging\nfrom transformers.configuration_utils import PretrainedConfig\nfrom flashinfer import BatchMLAPagedAttentionWrapper\nfrom ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn\nfrom ktransformers.models.custom_cache import KDeepSeekV3Cache, KGQACache\nlogger = logging.getLogger(\"attention\")\n\n# Copied from transformers.models.llama.modeling_llama.rotate_half\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\nclass flashinfer_attn(BaseInjectedModule, DeepseekV2Attention):\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 chunck_size: int = 1000,\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        self.orig_module.__init__(orig_module.config,\n   
         orig_module.layer_idx)\n        self.chunck_size = chunck_size # TODO, generate chunck_size automatically.\n\n\n    def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:\n        if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):\n            kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)\n            q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank)\n            out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank)\n            self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim, \n                                      bias=False, dtype=q_absorb.dtype, device=q_absorb.device)\n            self.q_absorb.weight.data = q_absorb\n            self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim, \n                                        bias=False, dtype=out_absorb.dtype, device=out_absorb.device)\n            self.out_absorb.weight.data = out_absorb\n            #del self.orig_module.kv_b_proj\n        q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)\n        out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank)\n        return q_absorb, out_absorb\n    \n\n    def forward(self,\n                hidden_states: torch.Tensor,\n                kv_cache: KDeepSeekV3Cache,\n                position_ids: torch.Tensor,\n                wrapper: BatchMLAPagedAttentionWrapper,\n                num_tokens_tensors: torch.Tensor,\n                page_idx: torch.Tensor,\n                page_offset: torch.Tensor,\n                ):\n        q_len, _ = hidden_states.size()\n\n        if self.q_lora_rank is None:\n            q = self.q_proj(hidden_states, num_tokens_tensors)\n        else:\n            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states, num_tokens_tensors), num_tokens_tensors), 
num_tokens_tensors)\n        q = q.view(q_len, self.num_heads, self.q_head_dim)\n        q_nope, q_pe = torch.split(\n            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1\n        )\n\n        compressed_kv = self.kv_a_proj_with_mqa(hidden_states, num_tokens_tensors)\n        compressed_kv, k_pe = torch.split(\n            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n        )\n        compressed_kv = compressed_kv.contiguous()\n        compressed_kv = self.kv_a_layernorm(compressed_kv, num_tokens_tensors)\n        k_pe = k_pe.view(q_len, 1, self.qk_rope_head_dim)\n        compressed_kv = compressed_kv.view(q_len, 1, self.kv_lora_rank)\n        \n        cos, sin = self.rotary_emb(q_pe, position_ids.unsqueeze(0))\n        q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=2)\n        q_pe = q_pe.squeeze(0)\n        if kv_cache is not None:\n            \n            # page_idx, page_offset = kv_cache.get_page_table(position_ids, q_indptr, kv_indptr, kv_indices)\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"page_idx\": page_idx, \"page_offset\": page_offset}  # Specific to RoPE models\n            compressed_kv_with_k_pe = kv_cache.update(compressed_kv.unsqueeze(0), k_pe, self.layer_idx, page_idx, page_offset, cache_kwargs)\n            compressed_kv = compressed_kv_with_k_pe [:, :, :, :self.kv_lora_rank].view(-1, kv_cache.page_size, self.kv_lora_rank)\n            k_pe = compressed_kv_with_k_pe [:, :, :, self.kv_lora_rank:].view(-1, kv_cache.page_size, self.qk_rope_head_dim)\n            \n        q_absorb, out_absorb = self.get_absorbed()\n        q_nope = q_nope.transpose(0, 1) # q_len is 1, no GPU overhead, same below\n        q_nope = torch.matmul(q_nope, q_absorb) # batched MM\n        q_nope = q_nope.transpose(0, 1)\n        # q_nope.squeeze_(1)\n        # q_pe.squeeze_(1)\n\n        attn_output = wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(q_len, 
self.num_heads, self.kv_lora_rank)\n        attn_output = attn_output.transpose(0, 1)\n        attn_output = torch.matmul(attn_output, out_absorb.mT) # [self.num_heads, q_len, self.v_head_dim]\n        attn_output = attn_output.transpose(0, 1)\n        attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim)\n        attn_output = self.o_proj(attn_output, num_tokens_tensors)\n        return attn_output\n\nclass KQwen2MoeAttention(BaseInjectedModule, Qwen2MoeAttention):\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 chunck_size: int = 1000,\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        self.orig_module.__init__(orig_module.config,\n            orig_module.layer_idx)\n        self.chunck_size = chunck_size # TODO, generate chunck_size automatically.\n\n\n    # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb\n    def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):\n        \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n        Args:\n            q (`torch.Tensor`): The query tensor.\n            k (`torch.Tensor`): The key tensor.\n            cos (`torch.Tensor`): The cosine part of the rotary embedding.\n            sin (`torch.Tensor`): The sine part of the rotary embedding.\n            position_ids (`torch.Tensor`):\n                Deprecated and unused.\n            unsqueeze_dim (`int`, *optional*, defaults to 1):\n                The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n                sin[position_ids] so that they can be properly broadcasted to 
the dimensions of q and k. For example, note\n                that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n                k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n                cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have\n                the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n        Returns:\n            `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n        \"\"\"\n        cos = cos.unsqueeze(unsqueeze_dim)\n        sin = sin.unsqueeze(unsqueeze_dim)\n        q_embed = (q * cos) + (rotate_half(q) * sin)\n        k_embed = (k * cos) + (rotate_half(k) * sin)\n        return q_embed, k_embed\n\n\n    def forward(self,\n                hidden_states: torch.Tensor,\n                kv_cache: KGQACache,\n                position_ids: torch.Tensor,\n                wrapper: flashInferAttn,\n                bsz_tensors: torch.Tensor,\n                page_idx: torch.Tensor,\n                page_offset: torch.Tensor,\n                ):\n        q_len, _ = hidden_states.size()\n\n        query_states = self.q_proj(hidden_states, bsz_tensors)\n        key_states = self.k_proj(hidden_states, bsz_tensors)\n        value_states = self.v_proj(hidden_states, bsz_tensors)\n\n\n        query_states = query_states.view(q_len, self.num_heads, self.head_dim)\n        key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)\n        value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)\n        \n        cos, sin = self.rotary_emb(value_states.unsqueeze(0), position_ids.unsqueeze(0))\n        query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), cos, sin, unsqueeze_dim=2)\n\n        query_states = 
query_states.view(q_len, self.num_heads, self.head_dim)\n        key_states = key_states.view(\n            q_len, self.num_key_value_heads, self.head_dim\n        )\n        value_states = value_states.view(\n            q_len, self.num_key_value_heads, self.head_dim\n        )\n\n        k_cache = kv_cache.get_k_cache(self.layer_idx)\n        v_cache = kv_cache.get_v_cache(self.layer_idx)\n\n\n        attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)\n  \n\n        attn_output = self.o_proj(attn_output.view(q_len, self.num_heads * self.head_dim), bsz_tensors)\n\n        return attn_output\n\nclass KQwen3MoeAttention(BaseInjectedModule, Qwen3MoeAttention):\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 chunck_size: int = 1000,\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        self.orig_module.__init__(orig_module.config,\n            orig_module.layer_idx)\n        self.chunck_size = chunck_size # TODO, generate chunck_size automatically.\n\n\n    # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb\n    def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):\n        \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n        Args:\n            q (`torch.Tensor`): The query tensor.\n            k (`torch.Tensor`): The key tensor.\n            cos (`torch.Tensor`): The cosine part of the rotary embedding.\n            sin (`torch.Tensor`): The sine part of the rotary embedding.\n            position_ids (`torch.Tensor`):\n                Deprecated and unused.\n            unsqueeze_dim (`int`, *optional*, 
defaults to 1):\n                The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n                sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n                that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n                k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n                cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have\n                the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n        Returns:\n            `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n        \"\"\"\n        cos = cos.unsqueeze(unsqueeze_dim)\n        sin = sin.unsqueeze(unsqueeze_dim)\n        q_embed = (q * cos) + (rotate_half(q) * sin)\n        k_embed = (k * cos) + (rotate_half(k) * sin)\n        return q_embed, k_embed\n\n\n    def forward(self,\n                hidden_states: torch.Tensor,\n                kv_cache: KGQACache,\n                position_ids: torch.Tensor,\n                wrapper: flashInferAttn,\n                bsz_tensors: torch.Tensor,\n                page_idx: torch.Tensor,\n                page_offset: torch.Tensor,\n                ):\n        q_len, _ = hidden_states.size()\n\n        bsz_tensors_q = bsz_tensors * self.num_heads\n        bsz_tensors_kv = bsz_tensors * self.num_key_value_heads\n\n        query_states = self.q_norm(self.q_proj(hidden_states, bsz_tensors), bsz_tensors_q)\n        key_states = self.k_norm(self.k_proj(hidden_states, bsz_tensors), bsz_tensors_kv)\n        value_states = self.v_proj(hidden_states, bsz_tensors)\n\n\n        query_states = query_states.view(q_len, self.num_heads, self.head_dim)\n        key_states = key_states.view(q_len, 
self.num_key_value_heads, self.head_dim)\n        value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)\n        \n        cos, sin = self.rotary_emb(value_states.unsqueeze(0), position_ids.unsqueeze(0))\n        query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), cos, sin, unsqueeze_dim=2)\n\n        query_states = query_states.view(q_len, self.num_heads, self.head_dim)\n        key_states = key_states.view(\n            q_len, self.num_key_value_heads, self.head_dim\n        )\n        value_states = value_states.view(\n            q_len, self.num_key_value_heads, self.head_dim\n        )\n\n        k_cache = kv_cache.get_k_cache(self.layer_idx)\n        v_cache = kv_cache.get_v_cache(self.layer_idx)\n\n\n        attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)\n  \n\n        attn_output = self.o_proj(attn_output.view(q_len, self.num_heads * self.head_dim), bsz_tensors)\n\n        return attn_output\n\n\nclass deepseek_torch_attn(BaseInjectedModule, DeepseekV2Attention):\n    def __init__(self,\n                    key: str,\n                    gguf_loader : GGUFLoader,\n                    config: PretrainedConfig,\n                    orig_module: nn.Module,\n                    prefill_device: str = \"cuda\",\n                    generate_device: str = \"cuda\",\n                    chunck_size: int = 1000,\n                    **kwargs):\n            BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n            self.orig_module.__init__(orig_module.config,\n                orig_module.layer_idx)\n            self.chunck_size = chunck_size # TODO, generate chunck_size automatically.\n\n\n    def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:\n        if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):\n            kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, 
-1, self.kv_lora_rank)\n            q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank)\n            out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank)\n            self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim, \n                                    bias=False, dtype=q_absorb.dtype, device=q_absorb.device)\n            self.q_absorb.weight.data = q_absorb\n            self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim, \n                                        bias=False, dtype=out_absorb.dtype, device=out_absorb.device)\n            self.out_absorb.weight.data = out_absorb\n            #del self.orig_module.kv_b_proj\n        q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)\n        out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank)\n        return q_absorb, out_absorb\n    \n\n\n    def forward(self,\n                hidden_states: torch.Tensor,\n                kv_cache: KDeepSeekV3Cache,\n                position_ids: torch.Tensor,\n                wrapper: None,\n                num_tokens_tensors: torch.Tensor,\n                page_idx: torch.Tensor,\n                page_offset: torch.Tensor,\n                attention_masks: Optional[list[torch.Tensor]] = None,\n                q_indptr: Optional[torch.Tensor] = None,\n                kv_indices: Optional[torch.Tensor] = None,\n                kv_indptr: Optional[torch.Tensor] = None,\n                bsz_tensors: Optional[torch.Tensor] = None,\n                last_page_len: Optional[torch.Tensor] = None,\n                ):\n        # range bsz_tensors\n        final_attention_output = torch.tensor([], device=hidden_states.device)\n        for i in range(bsz_tensors[0]):\n            batch_num_tokens_tensors = q_indptr[i+1] - q_indptr[i]\n            batch_last_page_len = last_page_len[i]\n 
           # kv_total_len is kv_len, batch_compressed_kv is compressed_kv, batch_k_pe is k_pe\n            batch_page_idx = page_idx[q_indptr[i]:q_indptr[i+1]]\n            batch_page_offset = page_offset[q_indptr[i]:q_indptr[i+1]]\n            # kv_page_nums is the number of pages for the current batch\n            kv_page_nums = kv_indptr[i+1] - kv_indptr[i]\n            # kv_total_len is the total length of the kv cache for the current batch (kv_len for algorithm)\n            kv_total_len = kv_page_nums * kv_cache.page_size\n            if batch_last_page_len is not None:\n                kv_total_len = kv_total_len - (kv_cache.page_size - batch_last_page_len)\n            # print(f\"kv_total_len's shape {kv_total_len.shape}\")\n            # kv_index is the index of the kv cache pages for the current batch\n            kv_index = kv_indices[kv_indptr[i]:kv_indptr[i+1]]\n            # we can index [kv_index, page_offset_indices] to get the kv cache for the current batch\n            # from q_indptr[i] to q_indptr[i+1] is the range of the current batch\n            batch_hidden_states = hidden_states[q_indptr[i]:q_indptr[i+1]]\n            batch_position_ids = position_ids[q_indptr[i]:q_indptr[i+1]]\n            q_len, _ = batch_hidden_states.size()\n            # print(\"q_len -> \", q_len)\n\n            if self.q_lora_rank is None:\n                q = self.q_proj(batch_hidden_states, batch_num_tokens_tensors)\n            else:\n                q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(batch_hidden_states, batch_num_tokens_tensors), batch_num_tokens_tensors), batch_num_tokens_tensors)\n            # for v3, bsz, q_len, num_heads(128), qk_head_dim(192=128(nope)+64(rope))\n            q = q.view(q_len, self.num_heads, self.q_head_dim)\n            # q_nope is [q_len, num_heads(128), qk_nope_head_dim(128)]\n            # q_pe is [q_len, num_heads(128), qk_rope_head_dim(64)]\n            q_nope, q_pe = torch.split(\n                q, 
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1\n            )\n            # compressed_kv is [q_len, kv_lora_rank(512) + rope(64)]\n            compressed_kv = self.kv_a_proj_with_mqa(batch_hidden_states, batch_num_tokens_tensors)\n            # compressed_kv is [q_len, kv_lora_rank(512)], k_pe is [q_len, rope(64)]\n            compressed_kv, k_pe = torch.split(\n                compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n            )\n            compressed_kv = compressed_kv.contiguous()\n            compressed_kv = self.kv_a_layernorm(compressed_kv, batch_num_tokens_tensors)\n            # k_pe is [q_len, 1, qk_rope_head_dim(64)]\n            k_pe = k_pe.view(q_len, 1, self.qk_rope_head_dim)\n            # compressed_kv is [q_len, 1, kv_lora_rank(512)]\n            compressed_kv = compressed_kv.view(q_len, 1, self.kv_lora_rank)\n            \n            cos, sin = self.rotary_emb(q_pe, batch_position_ids.unsqueeze(0))\n            # print(f\"q_pe shape{q_pe.shape}, k_pe shape {k_pe.shape}\")\n            q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=2)\n            q_pe = q_pe.squeeze(0)\n            # q_pe is [num_heads(128), q_len, qk_rope_head_dim(64)]\n            q_pe.transpose_(0, 1)            \n            if kv_cache is not None:\n                cache_kwargs = {\"sin\": sin, \"cos\": cos, \"page_idx\": batch_page_idx, \"page_offset\": batch_page_offset}  # Specific to RoPE models\n                compressed_kv_with_k_pe = kv_cache.update(compressed_kv.unsqueeze(0), k_pe, self.layer_idx, batch_page_idx, batch_page_offset, cache_kwargs)\n                compressed_kv = compressed_kv_with_k_pe [:, :, :, :self.kv_lora_rank].view(-1, kv_cache.page_size, self.kv_lora_rank)\n                k_pe = compressed_kv_with_k_pe [:, :, :, self.kv_lora_rank:].view(-1, kv_cache.page_size, self.qk_rope_head_dim)\n            # q_absorb is [num_heads(128), qk_nope_head_dim(128), 
kv_lora_rank(512)]\n            # out_absorb is [num_heads(128), kv_lora_rank(512), v_head_dim(128)] v_head_dim is also the nope dim\n            q_absorb, out_absorb = self.get_absorbed()\n            # q_nope is [num_heads(128), q_len, qk_nope_head_dim(128)]\n            q_nope = q_nope.transpose(0, 1) # q_len is 1, no GPU overhead, same below\n            # q_nope is [num_heads(128), q_len, kv_lora_rank(512)]\n            q_nope = torch.matmul(q_nope, q_absorb) # batched MM\n\n            # # q_nope is [q_len, num_heads(128), kv_lora_rank(512)]\n            # q_nope = q_nope.transpose(0, 1)\n\n            # we need to index out the compressed_kv and k_pe for the current batch\n            batch_compressed_kv = None\n            batch_k_pe = None\n            for page_index in kv_index:\n                if kv_total_len > kv_cache.page_size:\n                    tmp_compressed_kv = compressed_kv[page_index, 0:kv_cache.page_size, :]\n                    tmp_k_pe = k_pe[page_index, 0:kv_cache.page_size, :]\n                    if batch_compressed_kv is None or batch_k_pe is None:\n                        batch_compressed_kv = tmp_compressed_kv\n                        batch_k_pe = tmp_k_pe\n                    else: \n                        batch_compressed_kv = torch.cat((batch_compressed_kv, tmp_compressed_kv), dim=0)\n                        batch_k_pe = torch.cat((batch_k_pe, tmp_k_pe), dim=0)\n                    kv_total_len -= kv_cache.page_size\n                else:\n                    tmp_compressed_kv = compressed_kv[page_index, 0:kv_total_len, :]\n                    tmp_k_pe = k_pe[page_index, 0:kv_total_len, :]\n                    if batch_compressed_kv is None or batch_k_pe is None:\n                        batch_compressed_kv = tmp_compressed_kv\n                        batch_k_pe = tmp_k_pe\n                    else: \n                        batch_compressed_kv = torch.cat((batch_compressed_kv, tmp_compressed_kv), dim=0)\n                        
batch_k_pe = torch.cat((batch_k_pe, tmp_k_pe), dim=0)\n                    break\n            # batch_compressed_kv is [kv_total_len(k_len), kv_lora_rank(512)]\n            # batch_k_pe is [kv_total_len(k_len), qk_rope_head_dim(64)]\n            attention_weights = (torch.matmul(q_pe,batch_k_pe.mT) + torch.matmul(q_nope, batch_compressed_kv.mT)) * self.softmax_scale\n            # attention_weights is [num_heads(128), q_len, k_len]\n            \n            # attention_weights = attention_weights.transpose(0,1).unsqueeze(0).squeeze(-1).expand(q_len,-1,-1).transpose(0,1)\n            \n            # attention_masks[i] is [q_len, k_len]\n            \n            attention_weights = (attention_weights + attention_masks[i])\n            # attention_weights shape is [num_heads(128), q_len, k_len]\n            attention_weights = nn.functional.softmax(attention_weights,dim=-1,dtype=torch.float32).to(q_pe.dtype)\n            attn_output = torch.matmul(attention_weights, batch_compressed_kv) # [num_heads(128),q_len, lora_rank(512)]\n            # out_absorb shape is [num_heads(128), kv_lora_rank(512), v_head_dim(128)]\n            out_absorb = out_absorb.transpose(1,2)\n            # q for q_len, n for num_heads, h for v_head_dim, v for kv_lora_rank\n            attn_output = torch.matmul(attn_output, out_absorb) # [num_heads(128), q_len, v_head_dim(128)]\n            attn_output = attn_output.transpose(0, 1) # [q_len, num_heads(128), v_head_dim(128)]\n            attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim)\n            attn_output = self.o_proj(attn_output, batch_num_tokens_tensors)\n            final_attention_output = torch.cat((final_attention_output, attn_output), dim=0)\n        return final_attention_output\n\nclass KSmallthinkerAttention(BaseInjectedModule, SmallthinkerAttention):\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n              
   orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 chunck_size: int = 1000,\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        self.orig_module.__init__(orig_module.config,\n            orig_module.layer_idx)\n        self.chunck_size = chunck_size # TODO, generate chunck_size automatically.\n\n    def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):\n        \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n        Args:\n            q (`torch.Tensor`): The query tensor.\n            k (`torch.Tensor`): The key tensor.\n            cos (`torch.Tensor`): The cosine part of the rotary embedding.\n            sin (`torch.Tensor`): The sine part of the rotary embedding.\n            position_ids (`torch.Tensor`, *optional*):\n                Deprecated and unused.\n            unsqueeze_dim (`int`, *optional*, defaults to 1):\n                The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n                sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n                that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n                k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n                cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. 
Similarly, if q and k have\n                the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n        Returns:\n            `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n        \"\"\"\n        cos = cos.unsqueeze(unsqueeze_dim)\n        sin = sin.unsqueeze(unsqueeze_dim)\n        q_embed = (q * cos) + (rotate_half(q) * sin)\n        k_embed = (k * cos) + (rotate_half(k) * sin)\n        return q_embed, k_embed\n\n    def forward(self,\n                hidden_states: torch.Tensor,\n                kv_cache: KGQACache,\n                freqs_cis: torch.Tensor,\n                wrapper: flashInferAttn,\n                bsz_tensors: torch.Tensor,\n                position_ids: torch.Tensor = None,\n                ):\n\n        if self.use_qk_norm:\n            raise NotImplementedError(\"use_qk_norm is not implemented yet\")\n\n        q_len, _ = hidden_states.size()\n        query_states = self.q_proj(hidden_states, bsz_tensors)\n        key_states = self.k_proj(hidden_states, bsz_tensors)\n        value_states = self.v_proj(hidden_states, bsz_tensors)\n\n        query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)\n        key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)\n        value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)\n        \n        # cos, sin = freqs_cis\n        \"\"\"\n        print(query_states.shape)\n        print(key_states.shape)\n        print(cos.shape)\n        print(sin.shape)\n        \"\"\"\n        if freqs_cis:  \n            cos, sin = freqs_cis\n            query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), cos, sin, unsqueeze_dim=2)\n\n\n\n        query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)\n        key_states = key_states.view(q_len, self.num_key_value_heads, 
self.head_dim)\n        value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)\n\n        k_cache = kv_cache.get_k_cache(self.layer_idx)\n        v_cache = kv_cache.get_v_cache(self.layer_idx)\n\n\n        attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)\n  \n\n        attn_output = self.o_proj(attn_output.view(q_len, self.num_attention_heads * self.head_dim), bsz_tensors)\n\n        return attn_output\n\n\n    \n\nclass KGlm4MoeAttention(BaseInjectedModule, Glm4MoeAttention):\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 chunck_size: int = 1000,\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        self.orig_module.__init__(orig_module.config,\n            orig_module.layer_idx)\n        self.chunck_size = chunck_size # TODO, generate chunck_size automatically.\n\n    def apply_rotary_pos_emb(\n        self,\n        q: torch.Tensor,\n        k: torch.Tensor,\n        freqs_cis: Tuple[torch.Tensor, torch.Tensor],\n        unsqueeze_dim=2\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n\n        # Keep half or full tensor for later concatenation\n        cos = freqs_cis[0]\n        sin = freqs_cis[1]\n        rotary_dim = cos.shape[-1]\n\n        cos = cos.unsqueeze(unsqueeze_dim)\n        sin = sin.unsqueeze(unsqueeze_dim)\n\n        q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]\n        k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]\n\n        # Apply rotary embeddings on the first half or full tensor\n        q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)\n        k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)\n\n        # 
Concatenate back to full shape\n        q_embed = torch.cat([q_embed, q_pass], dim=-1)\n        k_embed = torch.cat([k_embed, k_pass], dim=-1)\n        return q_embed, k_embed\n\n    def forward(self,\n                hidden_states: torch.Tensor,\n                kv_cache: KGQACache,\n                freqs_cis: torch.Tensor,\n                wrapper: flashInferAttn,\n                bsz_tensors: torch.Tensor,\n                position_ids: torch.Tensor = None,\n                ):\n\n        q_len, _ = hidden_states.size()\n        query_states = self.q_proj(hidden_states, bsz_tensors)\n        key_states = self.k_proj(hidden_states, bsz_tensors)\n        value_states = self.v_proj(hidden_states, bsz_tensors)\n\n\n        if self.use_qk_norm:\n            query_states = self.q_norm(query_states, bsz_tensors)\n            key_states = self.k_norm(key_states, bsz_tensors)\n\n        # cos, sin = freqs_cis\n        \"\"\"\n        print(query_states.shape)\n        print(key_states.shape)\n        print(cos.shape)\n        print(sin.shape)\n        \"\"\"\n\n        query_states = query_states.view(q_len, self.config.num_attention_heads, self.head_dim)\n        key_states = key_states.view(q_len, self.config.num_key_value_heads, self.head_dim)\n        value_states = value_states.view(q_len, self.config.num_key_value_heads, self.head_dim)\n\n        if freqs_cis is not None:  \n            query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), freqs_cis)\n\n        query_states = query_states.view(q_len, self.config.num_attention_heads, self.head_dim)\n        key_states = key_states.view(q_len, self.config.num_key_value_heads, self.head_dim)\n        value_states = value_states.view(q_len, self.config.num_key_value_heads, self.head_dim)\n\n\n        k_cache = kv_cache.get_k_cache(self.layer_idx)\n        v_cache = kv_cache.get_v_cache(self.layer_idx)\n\n\n        attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)\n  \n\n        attn_output = self.o_proj(attn_output.view(q_len, self.config.num_attention_heads * self.head_dim), bsz_tensors)\n\n        return attn_output\n    \nfrom ktransformers.models.modeling_qwen3_next import apply_mask_to_padding_states\nimport torch.nn.functional as F\n\nfrom ktransformers.models.modeling_qwen3_next import Qwen3NextAttention\n\nclass KQwen3NextAttention(BaseInjectedModule, Qwen3NextAttention):\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 chunck_size: int = 1000,\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        self.orig_module.__init__(orig_module.config,\n            orig_module.layer_idx)\n        self.chunck_size = chunck_size # TODO, generate chunck_size automatically.\n\n    # Adapted from transformers.models.glm.modular_glm.apply_rotary_pos_emb\n    def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):\n        \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n        Removes the interleaving of cos and sin from GLM\n\n        Args:\n            q (`torch.Tensor`): The query tensor.\n            k (`torch.Tensor`): The key tensor.\n            cos (`torch.Tensor`): The cosine part of the rotary embedding.\n            sin (`torch.Tensor`): The sine part of the rotary embedding.\n            position_ids (`torch.Tensor`, *optional*):\n                Deprecated and unused.\n            unsqueeze_dim (`int`, *optional*, defaults to 1):\n                The 'unsqueeze_dim' 
argument specifies the dimension along which to unsqueeze cos[position_ids] and\n                sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n                that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n                k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n                cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have\n                the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n        Returns:\n            `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n        \"\"\"\n        cos = cos.unsqueeze(unsqueeze_dim)\n        sin = sin.unsqueeze(unsqueeze_dim)\n\n        # Keep half or full tensor for later concatenation\n        rotary_dim = cos.shape[-1]\n        q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]\n        k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]\n\n        # Apply rotary embeddings on the first half or full tensor\n        q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)\n        k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)\n\n        # Concatenate back to full shape\n        q_embed = torch.cat([q_embed, q_pass], dim=-1)\n        k_embed = torch.cat([k_embed, k_pass], dim=-1)\n        return q_embed, k_embed\n\n    def forward(self,\n                hidden_states: torch.Tensor,\n                kv_cache: KGQACache,\n                freqs_cis: torch.Tensor,\n                wrapper: flashInferAttn,\n                bsz_tensors: torch.Tensor,\n                position_ids: Optional[torch.Tensor] = None,\n                attention_mask: Optional[torch.Tensor] = None,\n                ):\n\n        q_len, _ = hidden_states.size()\n\n        query_states = self.q_proj(hidden_states, 
bsz_tensors)\n\n        # Split the fused projection computed above into query and gate halves\n        # instead of running q_proj a second time.\n        query_states, gate = torch.chunk(\n            query_states.view(q_len, -1, self.head_dim * 2), 2, dim=-1\n        )\n        gate = gate.reshape(q_len, -1)\n\n        key_states = self.k_proj(hidden_states, bsz_tensors)\n\n        query_states = query_states.reshape(q_len, -1)\n        query_states = self.q_norm(query_states, bsz_tensors)\n        key_states = self.k_norm(key_states, bsz_tensors)\n\n\n        value_states = self.v_proj(hidden_states, bsz_tensors)\n\n\n\n        query_states = query_states.view(q_len, self.num_attention_heads, self.head_dim)\n        key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)\n        value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)\n\n        if freqs_cis is not None:\n            cos, sin = freqs_cis\n            query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), cos, sin, unsqueeze_dim=2)\n            query_states, key_states = query_states.squeeze(0), key_states.squeeze(0)\n\n\n        k_cache = kv_cache.get_k_cache(self.layer_idx)\n        v_cache = kv_cache.get_v_cache(self.layer_idx)\n\n        attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)\n  \n        attn_output = attn_output.reshape(q_len, -1).contiguous()\n        attn_output = attn_output * torch.sigmoid(gate)\n\n\n        attn_output = self.o_proj(attn_output.view(q_len, self.num_attention_heads * self.head_dim), bsz_tensors)\n\n        return attn_output\n\n\nclass KQwen3NextGatedDeltaNet(BaseInjectedModule, Qwen3NextGatedDeltaNet):\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 chunck_size: int = 1000,\n                 
**kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        self.orig_module.__init__(orig_module.config,\n            orig_module.layer_idx)\n        self.chunck_size = chunck_size # TODO, generate chunck_size automatically.\n\n    def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba):\n        \"\"\"\n        Derives `query`, `key` and `value` tensors from `mixed_qkvz` and `mixed_ba`.\n        \"\"\"\n\n        new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + (\n            self.num_k_heads,\n            2 * self.head_k_dim + 2 * self.head_v_dim * self.num_v_heads // self.num_k_heads,\n        )\n        new_tensor_shape_ba = mixed_ba.size()[:-1] + (self.num_k_heads, 2 * self.num_v_heads // self.num_k_heads)\n\n        mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz)\n        mixed_ba = mixed_ba.view(*new_tensor_shape_ba)\n        split_arg_list_qkvz = [\n            self.head_k_dim,\n            self.head_k_dim,\n            (self.num_v_heads // self.num_k_heads * self.head_v_dim),\n            (self.num_v_heads // self.num_k_heads * self.head_v_dim),\n        ]\n        split_arg_list_ba = [self.num_v_heads // self.num_k_heads, self.num_v_heads // self.num_k_heads]\n        query, key, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=3)\n        b, a = torch.split(mixed_ba, split_arg_list_ba, dim=3)\n        # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn]\n        value = value.reshape(value.size(0), value.size(1), -1, self.head_v_dim)\n        z = z.reshape(z.size(0), z.size(1), -1, self.head_v_dim)\n        b = b.reshape(b.size(0), b.size(1), self.num_v_heads)\n        a = a.reshape(a.size(0), a.size(1), self.num_v_heads)\n        return query, key, value, z, b, a\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        conv_states: Optional[list[torch.Tensor]] = None,\n        recurrent_states: Optional[list[torch.Tensor]] = None,\n        
attention_mask: Optional[torch.Tensor] = None,\n        bsz_tensors: Optional[torch.Tensor] = None,\n    ):\n        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)\n\n        # Set up dimensions for reshapes later\n        batch_size, seq_len, _ = hidden_states.shape\n\n        conv_state = conv_states[self.layer_idx] if conv_states is not None else None\n        recurrent_state = (\n            recurrent_states[self.layer_idx] if recurrent_states is not None else None\n        )\n\n        use_precomputed_states = (\n            conv_state is not None\n            and recurrent_state is not None\n            and seq_len == 1\n        )\n\n        projected_states_qkvz = self.in_proj_qkvz(hidden_states, bsz_tensors)\n        projected_states_ba = self.in_proj_ba(hidden_states, bsz_tensors)\n        query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba)\n        query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value))\n\n        mixed_qkv = torch.cat((query, key, value), dim=-1)\n        mixed_qkv = mixed_qkv.transpose(1, 2)\n\n        if use_precomputed_states:\n            # 2. 
Convolution sequence transformation\n            # NOTE: the conv state is updated in `causal_conv1d_update`\n            mixed_qkv = self.causal_conv1d_update(\n                mixed_qkv,\n                conv_state,\n                self.conv1d.weight.squeeze(1),\n                self.conv1d.bias,\n                self.activation,\n            )\n        else:\n            conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))\n\n            if self.causal_conv1d_fn is not None:\n                mixed_qkv = self.causal_conv1d_fn(\n                    x=mixed_qkv,\n                    weight=self.conv1d.weight.squeeze(1),\n                    bias=self.conv1d.bias,\n                    activation=self.activation,\n                    seq_idx=None,\n                )\n            else:\n                mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])\n\n        mixed_qkv = mixed_qkv.transpose(1, 2)\n        query, key, value = torch.split(\n            mixed_qkv,\n            [\n                self.key_dim,\n                self.key_dim,\n                self.value_dim,\n            ],\n            dim=-1,\n        )\n        query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim)\n        key = key.reshape(key.shape[0], key.shape[1], -1, self.head_k_dim)\n        value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim)\n\n        beta = b.sigmoid()\n        # If the model is loaded in fp16, without the .float() here, A might be -inf\n        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)\n        if self.num_v_heads // self.num_k_heads > 1:\n            query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)\n            key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)\n\n        if not use_precomputed_states:\n            core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(\n                query,\n              
  key,\n                value,\n                g=g,\n                beta=beta,\n                initial_state=None,\n                output_final_state=conv_state is not None,\n                use_qk_l2norm_in_kernel=True,\n            )\n\n        else:\n            core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(\n                query,\n                key,\n                value,\n                g=g,\n                beta=beta,\n                initial_state=recurrent_state,\n                output_final_state=conv_state is not None,\n                use_qk_l2norm_in_kernel=True,\n            )\n\n        # Update cache\n        recurrent_state = last_recurrent_state\n\n        z_shape_og = z.shape\n        # reshape input data into 2D tensor\n        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])\n        z = z.reshape(-1, z.shape[-1])\n        core_attn_out = self.norm(core_attn_out, z)\n        core_attn_out = core_attn_out.reshape(z_shape_og)\n        core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1)\n\n        output = self.out_proj(core_attn_out, bsz_tensors)\n\n        if conv_state is not None:\n            conv_states[self.layer_idx] = conv_state\n        if recurrent_state is not None:\n            recurrent_states[self.layer_idx] = recurrent_state\n\n        return output"
  },
  {
    "path": "archive/ktransformers/operators/base_operator.py",
    "content": "'''\nDescription  :  \nAuthor       : Boxin Zhang\nVersion      : 0.1.0\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nfrom typing import Any\nfrom torch import nn, Tensor\nfrom ktransformers.util.custom_loader import GGUFLoader\nfrom transformers.configuration_utils import PretrainedConfig\nimport ktransformers.util.utils as utils\nclass BaseInjectedModule(nn.Module):\n    \n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 **kwargs):\n        nn.Module.__init__(self)\n        nn.Module.__setattr__(self, \"orig_module\", orig_module)\n        object.__setattr__(self, \"key\", key)\n        object.__setattr__(self, \"gguf_loader\", gguf_loader)\n        object.__setattr__(self, \"config\", config)\n        object.__setattr__(self, \"prefill_device\", prefill_device)\n        object.__setattr__(self, \"generate_device\", generate_device)\n        object.__setattr__(self, \"device\", generate_device)\n        \n    def __getattr__(self, name: str) -> Any:\n        # __getattr__ in nn.Module doesn't call super().__getattribute__ when name is not in nn.Module.__dict__,\n        # but __setattr__ in nn.Module call super().__setattr__ in that case, there may be some attribute set \n        # but can't get using __getattr__, typically these attr is build in attr of the class, so class.attr does not\n        # call __getattr__.\n        # Example:\n        # ...import torch\n        # ...l=torch.nn.Linear(100,200)\n        # ...l.out_features # 200\n        # ...l.__getattr__(\"out_features\") # AttributeError: 'Linear' object has no attribute 'out_features'\n        try:\n            return object.__getattribute__(self, name) # if this attr belongs to BaseInjectedModule\n        except:\n        
    if name == \"orig_module\":\n                return nn.Module.__getattr__(self, \"orig_module\")\n            try:\n                return nn.Module.__getattr__(self, \"orig_module\").__getattr__(name) # if this attr belongs to orig_module\n            except:\n                return super(nn.Module, nn.Module.__getattr__(self, \"orig_module\")).__getattribute__(name) # if this attr belongs to orig_module but not in nn.Module.__dict__\n\n    def __setattr__(self, name: str, value: Tensor | nn.Module) -> None:\n        if name == \"orig_module\":\n            return nn.Module.__setattr__(self, \"orig_module\", value)\n        elif hasattr(self, name):\n            return object.__setattr__(self, name, value)\n        return nn.Module.__getattr__(self, \"orig_module\").__setattr__(name, value)\n    \n    def forward(self, *args, **kwargs):\n        return self.orig_module.forward(*args, **kwargs)\n    \n    def load(self):\n        for name, child in self._modules.items():\n            utils.load_weights(child, self.gguf_loader, self.key+\".\")\n"
  },
  {
    "path": "archive/ktransformers/operators/cpuinfer.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nDescription  : This script defines the `CPUInferKVCache` and `CPUInfer` classes for performing inference \n               with a Key-Value Cache on the CPU. The `CPUInferKVCache` class is responsible for configuring \n               and managing key-value caches, updating and retrieving cache data, and handling attention \n               operations. It supports different cache types (e.g., Q4_0, FP16) and retrieval strategies \n               (e.g., shared, separate). The `CPUInfer` class handles task submission and synchronization \n               on the CPU, with optional CUDA stream integration for tasks involving GPU acceleration. \n               These classes facilitate efficient caching and memory management for deep learning models \n               that leverage key-value attention mechanisms, particularly on CPU-based systems.\nAuthor       : djw\nDate         : 2024-08-26 23:25:24\nVersion      : 1.0.0\nLastEditors  : djw \nLastEditTime : 2024-08-26 23:25:24\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved.\n\"\"\"\nimport sys, os\nfrom typing import Any\nimport torch\nsys.path.append(os.path.join(os.path.dirname(__file__), \"..\", \"ktransformers_ext\", \"build\"))\nsys.path.append(os.path.join(os.path.dirname(__file__), \"..\", \"ktransformers_ext\", \"build\", \"Release\"))\nsys.path.append(os.path.join(os.path.dirname(__file__), \"..\", \"ktransformers_ext\", \"build\", \"Debug\"))\nimport cpuinfer_ext\nfrom ktransformers.server.config.config import Config\n\n\nclass CPUInferKVCache:\n    def __init__(\n        self,\n        layer_num: int = 32,\n        kv_head_num: int = 8,\n        q_head_num: int = 32,\n        head_dim: int = 128,\n        block_len: int = 256,\n        anchor_num: int = 4,\n        anchor_type: str = \"FIXED\",\n        kv_type: str = \"Q4_0\",\n        retrieval_type: str = \"SHARED\",\n        layer_step: int = 1,\n        token_step: int = 1,\n        
layer_offset: int = 0,\n        max_thread_num: int = 32,\n        max_batch_size: int = 4,\n        max_block_num: int = 512,\n    ):\n\n        if anchor_type == \"FIXED\":\n            anchor_type = cpuinfer_ext.kvcache.AnchorType.FIXED\n        elif anchor_type == \"QUEST\":\n            anchor_type = cpuinfer_ext.kvcache.AnchorType.QUEST\n        elif anchor_type == \"DYNAMIC\":\n            anchor_type = cpuinfer_ext.kvcache.AnchorType.DYNAMIC\n        elif anchor_type == \"BLOCK_MEAN\":\n            anchor_type = cpuinfer_ext.kvcache.AnchorType.BLOCK_MEAN\n        elif anchor_type == \"BLOCK_MAX\":\n            anchor_type = cpuinfer_ext.kvcache.AnchorType.BLOCK_MAX\n        else:\n            raise ValueError(f\"Unknown anchor type: {anchor_type}\")\n\n        if kv_type == \"FP16\":\n            kv_type = cpuinfer_ext.kvcache.ggml_type.FP16\n        elif kv_type == \"FP32\":\n            # Raise explicitly so the check is not stripped when Python runs with -O.\n            raise NotImplementedError(\"FP32 is not supported yet.\")\n        elif kv_type == \"Q4_0\":\n            kv_type = cpuinfer_ext.kvcache.ggml_type.Q4_0\n        elif kv_type == \"Q8_0\":\n            kv_type = cpuinfer_ext.kvcache.ggml_type.Q8_0\n        else:\n            raise ValueError(f\"Unknown kv type: {kv_type}\")\n\n        if retrieval_type == \"SHARED\":\n            retrieval_type = cpuinfer_ext.kvcache.RetrievalType.LAYER\n        elif retrieval_type == \"INDIVIDUAL\":\n            retrieval_type = cpuinfer_ext.kvcache.RetrievalType.QHEAD\n        elif retrieval_type == \"SEPARATE\":\n            retrieval_type = cpuinfer_ext.kvcache.RetrievalType.KVHEAD\n        else:\n            raise ValueError(f\"Unknown retrieval type: {retrieval_type}\")\n\n        self.config = cpuinfer_ext.kvcache.KVCacheConfig(\n            layer_num,\n            kv_head_num,\n            q_head_num,\n            head_dim,\n            block_len,\n            anchor_num,\n            anchor_type,\n            kv_type,\n            retrieval_type,\n            layer_step,\n            token_step,\n            layer_offset,\n     
       max_block_num,\n            max_batch_size,\n            max_thread_num,\n        )\n        self.kvcache = cpuinfer_ext.kvcache.KVCache(self.config)\n\n    def load_kvcache(self, tensor_file_path: str):\n        if not os.path.exists(tensor_file_path):\n            raise FileNotFoundError(f\"The file {tensor_file_path} does not exist.\")\n        return self.kvcache.load_kvcache(tensor_file_path,)\n\n    def dump_kvcache(\n        self, block_table: torch.Tensor, cache_total_len: int, tensor_file_path: str\n    ):\n        assert (\n            block_table.dim() == 1\n            and block_table.dtype == torch.int\n            and block_table.is_contiguous()\n            and block_table.device == torch.device(\"cpu\")\n        ), \"block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            block_table.dim(),\n            block_table.size(),\n            block_table.dtype,\n            block_table.is_contiguous(),\n            block_table.device,\n        )\n\n        assert (\n            cache_total_len > 0\n            and cache_total_len <= self.config.block_len * block_table.size(0)\n        ), \"cache_total_len: {}\".format(cache_total_len)\n\n        if not os.path.exists(os.path.dirname(tensor_file_path)):\n            os.makedirs(os.path.dirname(tensor_file_path))\n\n        return self.kvcache.dump_kvcache(\n            block_table.data_ptr(),\n            cache_total_len,\n            tensor_file_path,\n        )\n\n    def update_cache_total_len(self, cache_total_len: int):\n        assert cache_total_len > 0, \"cache_total_len: {}\".format(cache_total_len)\n        self.kvcache.update_cache_total_len(cache_total_len)\n\n    # q_in: (bsz, q_len, q_head_num, head_dim)\n    # output: (bsz, q_len, q_head_num, head_dim)\n    # attn_lse: (bsz, q_len, q_head_num)\n    # block_table: (bsz, max_block_num)\n    def attn(\n        self,\n        q_in: torch.Tensor,\n        output: torch.Tensor,\n        attn_lse: 
torch.Tensor,\n        layer_idx: int,\n        generate_token_idx: int,\n        block_table: torch.Tensor | None = None,\n        cache_seqlens: torch.Tensor | None = None,\n        pick_block_num: int | None = None,\n        init_block_num: int | None = None,\n        local_block_num: int | None = None,\n    ):\n\n        assert (\n            q_in.dim() == 4\n            and q_in.size(2) == self.config.q_head_num\n            and q_in.size(3) == self.config.head_dim\n            and q_in.dtype == torch.float16\n            and q_in.is_contiguous()\n            and q_in.device == torch.device(\"cpu\")\n        ), \"q_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            q_in.dim(), q_in.size(), q_in.dtype, q_in.is_contiguous(), q_in.device\n        )\n\n        batch_size = q_in.size(0)\n        q_len = q_in.size(1)\n\n        assert (block_table is None) or (\n            block_table.dim() == 2\n            and block_table.size(0) == batch_size\n            and block_table.dtype == torch.int\n            and block_table.is_contiguous()\n            and block_table.device == torch.device(\"cpu\")\n        ), \"block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            block_table.dim(),\n            block_table.size(),\n            block_table.dtype,\n            block_table.is_contiguous(),\n            block_table.device,\n        )\n\n        max_block_num = block_table.size(1) if block_table is not None else 0\n\n        assert (\n            output.dim() == 4\n            and output.size(0) == batch_size\n            and output.size(2) == self.config.q_head_num\n            and output.size(1) == q_len\n            and output.size(3) == self.config.head_dim\n            and output.dtype == torch.float16\n            and output.is_contiguous()\n            and output.device == torch.device(\"cpu\")\n        ), \"output dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            
output.dim(),\n            output.size(),\n            output.dtype,\n            output.is_contiguous(),\n            output.device,\n        )\n\n        assert (\n            attn_lse.dim() == 3\n            and attn_lse.size(0) == batch_size\n            and attn_lse.size(1) == q_len\n            and attn_lse.size(2) == self.config.q_head_num\n            and attn_lse.dtype == torch.float32\n            and attn_lse.is_contiguous()\n            and attn_lse.device == torch.device(\"cpu\")\n        ), \"attn_lse dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            attn_lse.dim(),\n            attn_lse.size(),\n            attn_lse.dtype,\n            attn_lse.is_contiguous(),\n            attn_lse.device,\n        )\n\n        assert (\n            layer_idx >= 0 and layer_idx < self.config.layer_num\n        ), \"layer_idx: {}\".format(layer_idx)\n\n        assert (cache_seqlens is None) or (\n            cache_seqlens.dim() == 1\n            and cache_seqlens.size(0) == batch_size\n            and cache_seqlens.dtype == torch.int\n            and cache_seqlens.is_contiguous()\n            and cache_seqlens.device == torch.device(\"cpu\")\n        ), \"cache_seqlens dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            cache_seqlens.dim(),\n            cache_seqlens.size(),\n            cache_seqlens.dtype,\n            cache_seqlens.is_contiguous(),\n            cache_seqlens.device,\n        )\n\n        return self.kvcache.attn(\n            q_in.data_ptr(),\n            output.data_ptr(),\n            attn_lse.data_ptr(),\n            layer_idx,\n            generate_token_idx,\n            q_len,\n            batch_size,\n            max_block_num,\n            block_table.data_ptr() if block_table is not None else 0,\n            cache_seqlens.data_ptr() if cache_seqlens is not None else 0,\n            pick_block_num,\n            init_block_num,\n            local_block_num,\n        )\n\n    # k_in: 
(kv_head_num, block_len, head_dim)\n    # v_in: (kv_head_num, block_len, head_dim)\n    def update_kvcache_one_block_fp16(\n        self, k_in: torch.Tensor, v_in: torch.Tensor, layer_id: int, block_idx: int\n    ):\n        assert (\n            k_in.dim() == 3\n            and k_in.size(1) == self.config.block_len\n            and k_in.size(0) == self.config.kv_head_num\n            and k_in.size(2) == self.config.head_dim\n            and k_in.dtype == torch.float16\n            and k_in.is_contiguous()\n            and k_in.device == torch.device(\"cpu\")\n        ), \"k_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            k_in.dim(), k_in.size(), k_in.dtype, k_in.is_contiguous(), k_in.device\n        )\n        assert (\n            v_in.dim() == 3\n            and v_in.size(1) == self.config.block_len\n            and v_in.size(0) == self.config.kv_head_num\n            and v_in.size(2) == self.config.head_dim\n            and v_in.dtype == torch.float16\n            and v_in.is_contiguous()\n            and v_in.device == torch.device(\"cpu\")\n        ), \"v_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            v_in.dim(), v_in.size(), v_in.dtype, v_in.is_contiguous(), v_in.device\n        )\n        assert (\n            layer_id >= 0 and layer_id < self.config.layer_num\n        ), \"layer_id: {}\".format(layer_id)\n        assert block_idx >= 0, \"block_idx: {}\".format(block_idx)\n        return self.kvcache.update_one_block_fp16(\n            k_in.data_ptr(),\n            v_in.data_ptr(),\n            layer_id,\n            block_idx,\n        )\n\n    def get_kvcache_one_block_fp16(\n        self, k_in: torch.Tensor, v_in: torch.Tensor, layer_id: int, block_idx: int\n    ):\n        assert (\n            k_in.dim() == 3\n            and k_in.size(1) == self.config.block_len\n            and k_in.size(0) == self.config.kv_head_num\n            and k_in.size(2) == self.config.head_dim\n          
  and k_in.dtype == torch.float16\n            and k_in.is_contiguous()\n            and k_in.device == torch.device(\"cpu\")\n        ), \"k_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            k_in.dim(), k_in.size(), k_in.dtype, k_in.is_contiguous(), k_in.device\n        )\n        assert (\n            v_in.dim() == 3\n            and v_in.size(1) == self.config.block_len\n            and v_in.size(0) == self.config.kv_head_num\n            and v_in.size(2) == self.config.head_dim\n            and v_in.dtype == torch.float16\n            and v_in.is_contiguous()\n            and v_in.device == torch.device(\"cpu\")\n        ), \"v_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            v_in.dim(), v_in.size(), v_in.dtype, v_in.is_contiguous(), v_in.device\n        )\n        assert (\n            layer_id >= 0 and layer_id < self.config.layer_num\n        ), \"layer_id: {}\".format(layer_id)\n        assert block_idx >= 0, \"block_idx: {}\".format(block_idx)\n        return self.kvcache.get_one_block_fp16(\n            k_in.data_ptr(),\n            v_in.data_ptr(),\n            layer_id,\n            block_idx,\n        )\n\n    def update_importance_one_block(\n        self, importance: torch.Tensor, layer_id: int, block_idx: int\n    ):\n        assert (\n            importance.dim() == 1\n            and importance.size(0) == self.config.block_len\n            and importance.dtype == torch.float16\n            and importance.is_contiguous()\n            and importance.device == torch.device(\"cpu\")\n        ), \"importance dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            importance.dim(),\n            importance.size(),\n            importance.dtype,\n            importance.is_contiguous(),\n            importance.device,\n        )\n        assert (\n            layer_id >= 0 and layer_id < self.config.layer_num\n        ), \"layer_id: {}\".format(layer_id)\n        
assert block_idx >= 0, \"block_idx: {}\".format(block_idx)\n        return self.kvcache.update_importance_one_block(\n            importance.data_ptr(),\n            layer_id,\n            block_idx,\n        )\n\n    def get_importance_one_block(\n        self, importance: torch.Tensor, layer_id: int, block_idx: int\n    ):\n        assert (\n            importance.dim() == 1\n            and importance.size(0) == self.config.block_len\n            and importance.dtype == torch.float16\n            and importance.is_contiguous()\n            and importance.device == torch.device(\"cpu\")\n        ), \"importance dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            importance.dim(),\n            importance.size(),\n            importance.dtype,\n            importance.is_contiguous(),\n            importance.device,\n        )\n        assert (\n            layer_id >= 0 and layer_id < self.config.layer_num\n        ), \"layer_id: {}\".format(layer_id)\n        assert block_idx >= 0, \"block_idx: {}\".format(block_idx)\n        return self.kvcache.get_importance_one_block(\n            importance.data_ptr(),\n            layer_id,\n            block_idx,\n        )\n\n    def get_anchor_one_block(self, anchor: torch.Tensor, layer_id: int, block_idx: int):\n        assert (\n            anchor.dim() == 3\n            and anchor.size(0) == self.config.kv_head_num\n            and anchor.size(1) == self.config.anchor_num\n            and anchor.size(2) == self.config.head_dim\n            and anchor.dtype == torch.float16\n            and anchor.is_contiguous()\n            and anchor.device == torch.device(\"cpu\")\n        ), \"anchor dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            anchor.dim(),\n            anchor.size(),\n            anchor.dtype,\n            anchor.is_contiguous(),\n            anchor.device,\n        )\n        assert (\n            layer_id >= 0 and layer_id < self.config.layer_num\n   
     ), \"layer_id: {}\".format(layer_id)\n        assert block_idx >= 0, \"block_idx: {}\".format(block_idx)\n        return self.kvcache.get_anchor_one_block(\n            anchor.data_ptr(),\n            layer_id,\n            block_idx,\n        )\n\n    def update_anchor_one_block(\n        self, anchor: torch.Tensor, layer_id: int, block_idx: int\n    ):\n        assert (\n            anchor.dim() == 3\n            and anchor.size(0) == self.config.kv_head_num\n            and anchor.size(1) == self.config.anchor_num\n            and anchor.size(2) == self.config.head_dim\n            and anchor.dtype == torch.float16\n            and anchor.is_contiguous()\n            and anchor.device == torch.device(\"cpu\")\n        ), \"anchor dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            anchor.dim(),\n            anchor.size(),\n            anchor.dtype,\n            anchor.is_contiguous(),\n            anchor.device,\n        )\n        assert (\n            layer_id >= 0 and layer_id < self.config.layer_num\n        ), \"layer_id: {}\".format(layer_id)\n        assert block_idx >= 0, \"block_idx: {}\".format(block_idx)\n        return self.kvcache.update_anchor_one_block(\n            anchor.data_ptr(),\n            layer_id,\n            block_idx,\n        )\n\n    def calc_anchor_all_layers(\n        self,\n        block_table: torch.Tensor,\n        cache_seqlens: torch.Tensor,\n    ):\n        assert (\n            block_table.dim() == 2\n            and block_table.size(0) == cache_seqlens.size(0)\n            and block_table.dtype == torch.int\n            and block_table.is_contiguous()\n            and block_table.device == torch.device(\"cpu\")\n        ), \"block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            block_table.dim(),\n            block_table.size(),\n            block_table.dtype,\n            block_table.is_contiguous(),\n            block_table.device,\n        )\n        
assert (\n            cache_seqlens.dim() == 1\n            and cache_seqlens.dtype == torch.int\n            and cache_seqlens.is_contiguous()\n            and cache_seqlens.device == torch.device(\"cpu\")\n        ), \"cache_seqlens dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            cache_seqlens.dim(),\n            cache_seqlens.size(),\n            cache_seqlens.dtype,\n            cache_seqlens.is_contiguous(),\n            cache_seqlens.device,\n        )\n        batch_size = block_table.size(0)\n        max_block_num = block_table.size(1)\n        return self.kvcache.calc_anchor_all_layers(\n            block_table.data_ptr(),\n            cache_seqlens.data_ptr(),\n            batch_size,\n            max_block_num,\n        )\n\n    def clear_importance_all_layers(\n        self,\n        block_table: torch.Tensor,\n        cache_seqlens: torch.Tensor,\n    ):\n        assert (\n            block_table.dim() == 2\n            and block_table.size(0) == cache_seqlens.size(0)\n            and block_table.dtype == torch.int\n            and block_table.is_contiguous()\n            and block_table.device == torch.device(\"cpu\")\n        ), \"block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            block_table.dim(),\n            block_table.size(),\n            block_table.dtype,\n            block_table.is_contiguous(),\n            block_table.device,\n        )\n        assert (\n            cache_seqlens.dim() == 1\n            and cache_seqlens.dtype == torch.int\n            and cache_seqlens.is_contiguous()\n            and cache_seqlens.device == torch.device(\"cpu\")\n        ), \"cache_seqlens dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            cache_seqlens.dim(),\n            cache_seqlens.size(),\n            cache_seqlens.dtype,\n            cache_seqlens.is_contiguous(),\n            cache_seqlens.device,\n        )\n        batch_size = 
block_table.size(0)\n        max_block_num = block_table.size(1)\n        return self.kvcache.clear_importance_all_layers(\n            block_table.data_ptr(),\n            cache_seqlens.data_ptr(),\n            batch_size,\n            max_block_num,\n        )\n\n    def get_cache_total_len(self):\n        return self.kvcache.get_cache_total_len()\n\n    def update_kvcache_q4(\n        self,\n        k_in: torch.Tensor,\n        k_scales: torch.Tensor,\n        v_in: torch.Tensor,\n        v_scales: torch.Tensor,\n        layer_id: int,\n        seq_offset: int | None = None,\n        seq_len: int | None = None,\n        block_table: torch.Tensor | None = None,\n    ):\n        raise NotImplementedError\n\n    def update_kvcache_fp16(\n        self,\n        k_in: torch.Tensor,\n        v_in: torch.Tensor,\n        layer_idx,\n        block_table: torch.Tensor,\n        max_block_num,\n        past_len: torch.Tensor,\n        q_len,\n    ):\n        batch_size = block_table.size(0)\n        return self.kvcache.update_kvcache_fp16(\n            k_in.data_ptr(),\n            v_in.data_ptr(),\n            layer_idx,\n            block_table.data_ptr(),\n            batch_size,\n            max_block_num,\n            past_len.data_ptr(),\n            q_len\n        )\n\n    def get_kvcache_q4(\n        self,\n        k_in: torch.Tensor,\n        k_scales: torch.Tensor,\n        v_in: torch.Tensor,\n        v_scales: torch.Tensor,\n        layer_id: int,\n        seq_offset: int | None = None,\n        seq_len: int | None = None,\n        block_table: torch.Tensor | None = None,\n    ):\n        raise NotImplementedError\n\n    def get_kvcache_fp16(\n        self,\n        k_in: torch.Tensor,\n        v_in: torch.Tensor,\n        layer_id: int,\n        layer_idx,\n        block_table: torch.Tensor,\n        max_block_num,\n        past_len: torch.Tensor,\n    ):\n        batch_size = block_table.size(0)\n        return self.kvcache.get_kvcache_fp16(\n            
k_in.data_ptr(),\n            v_in.data_ptr(),\n            layer_idx,\n            block_table.data_ptr(),\n            batch_size,\n            max_block_num,\n            past_len.data_ptr(),\n        )\n\n    def get_and_update_kvcache_fp16(\n        self,\n        k_cache_cpu: torch.Tensor,\n        v_cache_cpu: torch.Tensor,\n        layer_idx,\n        block_table: torch.Tensor,\n        max_block_num,\n        past_len: torch.Tensor,\n        q_len,\n    ):\n        batch_size = block_table.size(0)\n        return self.kvcache.get_and_update_kvcache_fp16(\n            k_cache_cpu.data_ptr(),\n            v_cache_cpu.data_ptr(),\n            layer_idx,\n            block_table.data_ptr(),\n            batch_size,\n            max_block_num,\n            past_len.data_ptr(),\n            q_len,\n        )\n\n    def update_importance(\n        self,\n        importance_cache: torch.Tensor,\n        layer_idx,\n        block_table: torch.Tensor,\n        max_block_num,\n        offset: torch.Tensor,\n        width,\n    ):\n        batch_size = block_table.size(0)\n        return self.kvcache.update_importance(\n            importance_cache.data_ptr(),\n            layer_idx,\n            block_table.data_ptr(),\n            batch_size,\n            max_block_num,\n            offset.data_ptr(),\n            width,\n        )\n\n    # attn_sparsity: ((bsz, q_len, q_head_num), dtype = torch.float32)\n    def get_attn_sparsity(\n        self,\n        q_in: torch.Tensor,\n        attn_sparsity: torch.Tensor,\n        layer_idx: int,\n        block_table: torch.Tensor,\n        cache_seqlens: torch.Tensor,\n        block_table_origin: torch.Tensor,\n        cache_seqlens_origin: torch.Tensor,\n        generate_token_idx: int = 0,\n        topk: int | None = None,\n        local: int | None = None,\n    ):\n        batch_size = block_table.size(0)\n        max_block_num = block_table.size(1)\n        max_block_num_origin = block_table_origin.size(1)\n        q_len 
= q_in.size(1)\n\n        if topk is None or local is None or topk + local >= max_block_num:\n            topk = -1\n            local = -1\n        return self.kvcache.get_attn_sparsity(\n            q_in.data_ptr(),\n            attn_sparsity.data_ptr(),\n            layer_idx,\n            generate_token_idx,\n            q_len,\n            batch_size,\n            max_block_num,\n            block_table.data_ptr(),\n            cache_seqlens.data_ptr(),\n            block_table_origin.data_ptr(),\n            cache_seqlens_origin.data_ptr(),\n            max_block_num_origin,\n            topk,\n            local,\n        )\n\n    def attn_with_kvcache(\n        self,\n        q_in: torch.Tensor,\n        k_in: torch.Tensor,\n        v_in: torch.Tensor,\n        output: torch.Tensor,\n        attn_lse: torch.Tensor,\n        layer_idx: int,\n        block_table: torch.Tensor,\n        cache_seqlens: torch.Tensor,\n        generate_token_idx: int = 0,\n        topk: int | None = None,\n        local: int | None = None,\n    ):\n\n        batch_size = block_table.size(0)\n        max_block_num = block_table.size(1)\n        q_len = q_in.size(1)\n\n        if topk is None or local is None or topk + local >= max_block_num:\n            topk = -1\n            local = -1\n        return self.kvcache.attn_with_kvcache(\n            q_in.data_ptr(),\n            k_in.data_ptr(),\n            v_in.data_ptr(),\n            output.data_ptr(),\n            attn_lse.data_ptr(),\n            layer_idx,\n            generate_token_idx,\n            q_len,\n            batch_size,\n            max_block_num,\n            block_table.data_ptr(),\n            cache_seqlens.data_ptr(),\n            topk,\n            local,\n        )\n\n    def get_all_kvcache_one_layer(\n        self, k_in: torch.Tensor, v_in: torch.Tensor, layer_id: int\n    ):\n        return self.kvcache.get_all_kvcache_one_layer(\n            k_in.data_ptr(),\n            v_in.data_ptr(),\n            
layer_id,\n        )\n\n    def get_importance(\n        self,\n        importance: torch.Tensor,\n        block_table: torch.Tensor,\n    ):\n        raise NotImplementedError\n\n    def get_anchor(\n        self,\n        anchor: torch.Tensor,\n        block_table: torch.Tensor,\n    ):\n        raise NotImplementedError\n\n\nclass CPUInfer:\n    cpuinfer = None\n    cur_backend_thread_num = 0\n    \n    def __init__(self, thread_num):\n        if thread_num > CPUInfer.cur_backend_thread_num:\n            CPUInfer.cur_backend_thread_num = thread_num\n            del CPUInfer.cpuinfer\n            CPUInfer.cpuinfer = cpuinfer_ext.CPUInfer(thread_num)\n\n    def submit(self, task):\n        CPUInfer.cpuinfer.submit(task)\n\n    def submit_with_cuda_stream(self, current_cuda_stream, task):\n        CPUInfer.cpuinfer.submit_with_cuda_stream(current_cuda_stream, task)\n\n    def sync(self):\n        CPUInfer.cpuinfer.sync()\n\n    def sync_with_cuda_stream(self, current_cuda_stream):\n        CPUInfer.cpuinfer.sync_with_cuda_stream(current_cuda_stream)\n\n\n        \n"
  },
  {
    "path": "archive/ktransformers/operators/dynamic_attention.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nDescription  :  \nAuthor       : Jianwei Dong\nDate         : 2024-08-26 23:25:24\nVersion      : 1.0.0\nLastEditors  : Jianwei Dong\nLastEditTime : 2024-08-26 23:25:24\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n\"\"\"\n\nimport torch\nfrom transformers import AutoConfig\nimport sys, os\nimport logging\nlogger = logging.getLogger(\"dynamic_attention\")\nsys.path.append(os.path.dirname(__file__) + \"/../ktransformers_ext/cpu_backend\")\nfrom ktransformers.operators.cpuinfer import CPUInfer, CPUInferKVCache\ntry:\n    from flash_attn import flash_attn_func, flash_attn_with_kvcache\nexcept ImportError:\n    print(\"flash attn not found\")\n\n\nimport math\nimport json\n\n\nclass DynamicScaledDotProductAttention:\n    remaining_length: int\n    cpu_infer = None\n\n    def __init__(\n        self,\n        max_seq_len: int,\n        block_size: int,\n        config: AutoConfig,\n        device: torch.device,\n        local_windows_len: int,\n        topk: int,\n        threads_num: int,\n        anchor_type: str = \"DYNAMIC\",\n        kv_type: str = \"FP16\",\n        dense_layer_num: int = 0,\n        anchor_num: int = 1,\n        block_selection_mode: str = \"SHARED\",\n        layer_step: int = 1,\n        token_step: int = 1,\n        preselect_block: bool = False,\n        preselect_block_count: int = 96,\n        prefill_chunk_size: int = 20480,\n        use_attn_sparsity: bool = False,\n    ):\n        # assert anchor_num == 1\n        # assert anchor_type == \"DYNAMIC\"\n        self.remaining_length = 0\n        valid_anchor_types = [\"DYNAMIC\", \"FIXED\", \"BLOCK_MEAN\", \"BLOCK_MAX\", \"QUEST\"]\n        assert anchor_type in valid_anchor_types\n        if anchor_type == \"QUEST\":\n            assert anchor_num == 2\n        elif anchor_type != \"FIXED\" and anchor_type != \"DYNAMIC\":\n            assert anchor_num == 1\n\n        valid_kv_types = [\"FP16\", \"FP32\", \"Q4_0\", \"Q8_0\"]\n  
      assert kv_type in valid_kv_types\n        if kv_type != \"FP16\" and kv_type != \"FP32\":\n            assert block_size % 32 == 0\n\n        valid_block_selection_modes = [\"SHARED\", \"SEPARATE\"]  # individual\n        assert block_selection_mode in valid_block_selection_modes\n\n        self.max_seq_len = max_seq_len\n        self.block_num = max_seq_len // block_size\n        self.block_size = block_size\n        self.anchor_type = anchor_type\n        self.kv_type = kv_type\n        self.anchor_num = anchor_num\n        self.threads_num = threads_num\n        self.layer_step = layer_step\n        self.token_step = token_step\n        self.preselect_block = preselect_block\n        self.preselect_block_count = preselect_block_count\n        self.block_selection_mode = block_selection_mode\n        self.use_attn_sparsity = use_attn_sparsity\n\n        # model config\n        self.kv_head_num = config.num_key_value_heads\n        self.q_head_num = config.num_attention_heads\n        self.head_dim = config.hidden_size // config.num_attention_heads\n        self.layer_num = config.num_hidden_layers\n\n        self.device = device\n        self.local_windows_len = local_windows_len\n        self.local_block_num = self.local_windows_len // self.block_size + 1\n        self.prefill_chunk_size = prefill_chunk_size\n\n        self.topk = topk\n        self.dense_layer_num = dense_layer_num\n        # self.dense_layer_num = 32\n        self.cache_key_states = torch.zeros(\n            (self.block_num, block_size, self.kv_head_num, self.head_dim),\n            device=device,\n            dtype=torch.float16,\n        )\n        self.cache_value_states = torch.zeros(\n            (self.block_num, block_size, self.kv_head_num, self.head_dim),\n            device=device,\n            dtype=torch.float16,\n        )\n        # [max_num_block, block_size, head_num]\n        self.cache_importance = torch.zeros(\n            (self.block_num, block_size, 
self.q_head_num),\n            device=device,\n            dtype=torch.float16,\n        )\n\n        # key_states: [bsz, q_len, kv_head_num, head_dim]\n        # value_states: [bsz, q_len, kv_head_num, head_dim]\n        # query_states: [bsz, q_len, q_head_num, head_dim]\n        self.q_in_cpu = torch.zeros(\n            (1, 1, self.q_head_num, self.head_dim),\n            device=\"cpu\",\n            dtype=torch.float16,\n            pin_memory=True,\n        )\n        self.k_in_cpu = torch.zeros(\n            (1, 1, self.kv_head_num, self.head_dim),\n            device=\"cpu\",\n            dtype=torch.float16,\n            pin_memory=True,\n        )\n        self.v_in_cpu = torch.zeros(\n            (1, 1, self.kv_head_num, self.head_dim),\n            device=\"cpu\",\n            dtype=torch.float16,\n            pin_memory=True,\n        )\n\n        self.cache_seqlens_cpu = torch.empty(\n            (1,), device=\"cpu\", dtype=torch.int32, pin_memory=True\n        )\n\n        self.cache_seqlens_cuda = torch.empty((1,), device=device, dtype=torch.int32)\n\n        self.prefix_block_table = torch.arange(\n            self.block_num, device=\"cpu\", dtype=torch.int32, pin_memory=True\n        ).view(1, -1)\n\n        self.block_table_cpu = torch.arange(\n            self.block_num, device=\"cpu\", dtype=torch.int32, pin_memory=True\n        ).view(1, -1)\n\n        # assert (\n        #     self.local_windows_len // self.block_size + 1 + self.preselect_block_count\n        #     <= self.block_num\n        # )\n\n        self.output_cpu = torch.empty(\n            (1, 1, self.q_head_num, self.head_dim),\n            device=\"cpu\",\n            dtype=torch.float16,\n            pin_memory=True,\n        )\n        self.lse_cpu = torch.empty(\n            (1, 1, self.q_head_num), device=\"cpu\", dtype=torch.float32, pin_memory=True\n        )\n\n        self.output_cuda = torch.empty(\n            (1, 1, self.q_head_num, self.head_dim), device=device, 
dtype=torch.float16\n        )\n\n        self.attn_sparsity = torch.zeros(\n            (1, 1, self.q_head_num), device=\"cpu\", dtype=torch.float32, pin_memory=True\n        )\n\n        if preselect_block:\n            self.preselect_block_table = torch.zeros(\n                self.layer_num,\n                self.preselect_block_count,\n                device=device,\n                dtype=torch.int32,\n            )\n            self.preselect_block_num = 0  # block_num before preselect\n            self.evict_tokens = 0\n\n        if DynamicScaledDotProductAttention.cpu_infer is None:\n            DynamicScaledDotProductAttention.cpu_infer = CPUInfer(threads_num)\n            self.cpu_infer = DynamicScaledDotProductAttention.cpu_infer\n        self.local_thread = CPUInferKVCache(\n            self.layer_num,\n            self.kv_head_num,\n            self.q_head_num,\n            self.head_dim,\n            self.block_size,\n            anchor_num=self.anchor_num,\n            anchor_type=anchor_type,\n            kv_type=self.kv_type,\n            retrieval_type=self.block_selection_mode,\n            layer_step=self.layer_step,\n            token_step=self.token_step,\n            layer_offset=self.dense_layer_num % self.layer_step,\n            max_batch_size=1,\n            max_block_num=self.block_num,\n            max_thread_num=self.threads_num,\n        )\n\n        print(\n            f\"local_windows_len: {local_windows_len}, topk: {topk}, dense_layer_num: {dense_layer_num}, kv_type: {self.kv_type}, anchor_type: {self.anchor_type}, preselect_block: {self.preselect_block}, preselect_block_count: {self.preselect_block_count}, token_step: {self.token_step}, layer_step: {self.layer_step}\"\n        )\n\n        self.shape_mask = (\n            self.q_head_num,\n            self.block_size,\n            self.block_size,\n        )\n\n        mask = torch.zeros(\n            self.shape_mask, dtype=torch.uint8, device=device\n        
).contiguous()\n        elm_idx = torch.arange(self.block_size, device=device)\n\n        for i in range(mask.size(-2)):\n            idx = i + mask.size(-1) - mask.size(-2) - elm_idx\n            idx = idx[idx >= 0]\n            mask[..., i, idx] = 1\n\n        self.tril_mask = mask\n        self.triu_mask = mask ^ 1\n\n        self.generate_token_idx = 0\n\n    def get_attn_score_one_block(\n        self,\n        batch_idx: int,\n        max_block_num: int,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        offset: int,\n        width: int,\n        mask_mode: str | None = None,\n        use_softmax: bool = True,\n    ):\n        importance = self.cache_importance.view(-1, self.q_head_num)\n        importance = importance.narrow(0, batch_idx * max_block_num + offset, width)\n        # number of query heads that share one key/value head (GQA group size)\n        n_gqa_ = self.q_head_num // self.kv_head_num\n        for head_idx in range(self.q_head_num):\n            key_item = key[..., head_idx // n_gqa_, :].view(key.size(0), -1)\n            qk = torch.einsum(\n                \"qd,kd->qk\", query[:,head_idx,:], key_item\n            )  # (len_q, len_k) scores for this head\n\n            if mask_mode == \"tril\":\n                mask = self.tril_mask\n                mask = mask[0, -qk.size(-2) :, -qk.size(-1) :]\n                qk = qk * mask\n            elif mask_mode == \"triu\":\n                mask = self.triu_mask\n                mask = mask[0, -qk.size(-2) :, -qk.size(-1) :]\n                qk = qk * mask\n\n            if use_softmax:\n                qk = torch.nn.functional.softmax(\n                    qk / math.sqrt(self.head_dim), dim=-1, dtype=torch.float32\n                ).to(torch.float16)\n\n            qk = torch.sum(qk, dim=-2)\n            importance[...,head_idx] += qk\n\n    def get_preselect_block_table_and_attn_score(\n        self,\n        layer_idx: int,\n        batch_size: int,\n        offset: torch.Tensor,\n        
width: int,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        union_with_last_layer: bool = True,\n    ):\n        max_seqs_len = offset.max().item() + width\n        max_block_num = (max_seqs_len + self.block_size - 1) // self.block_size\n\n        for batch_idx in range(batch_size):\n            query_cur = query[batch_idx][-128:]\n            self.get_attn_score_one_block(\n                batch_idx,\n                max_block_num,\n                query_cur,\n                key[batch_idx][: offset[batch_idx].item() + width],\n                0,\n                offset[batch_idx].item() + width,\n                mask_mode=None,\n            )\n\n        if self.preselect_block:\n            self.prefill_block_num = max(\n                0, max_block_num - self.local_windows_len // self.block_size\n            )\n            self.evict_tokens = (\n                max(self.prefill_block_num - self.preselect_block_count, 0)\n                * self.block_size\n            )\n\n            if self.prefill_block_num != 0:\n                importance_cache = self.cache_importance.narrow(\n                    0, 0, self.prefill_block_num * batch_size\n                ).view(\n                    batch_size, self.prefill_block_num, self.block_size, self.q_head_num\n                )\n\n                importance_r = importance_cache[:, 1:, : self.block_size // 4]\n                pad_r = torch.zeros_like(importance_r[:, :1])\n                importance_r = torch.cat((importance_r, pad_r), dim=1)\n                importance_l = importance_cache[:, :-1, -self.block_size // 4 :]\n                pad_l = torch.zeros_like(importance_l[:, :1])\n                importance_l = torch.cat((pad_l, importance_l), dim=1)\n                importance = torch.cat(\n                    (importance_l, importance_cache, importance_r), dim=2\n                )\n                importance = importance.mean(dim=-1)\n                importance = importance.mean(dim=-1)\n      
          # importance: (batch_size, max_block_num)\n                topk = min(self.preselect_block_count, self.prefill_block_num)\n                values, indices = torch.topk(\n                    importance,\n                    k=topk,\n                    dim=1,\n                )\n\n                self.preselect_block_table[\n                    layer_idx : layer_idx + 1,\n                    :topk,\n                ].copy_(indices)\n\n                if union_with_last_layer and layer_idx == 31:\n                    for tmp_layer_idx in range(self.layer_num - 1):\n                        for i in range(1, min(topk, 6)):\n                            x = self.preselect_block_table[-1, i]\n                            if x not in self.preselect_block_table[tmp_layer_idx]:\n                                self.preselect_block_table[tmp_layer_idx, topk - i] = x\n        if self.anchor_type == \"DYNAMIC\":\n            importance_cache = self.cache_importance.narrow(\n                0, 0, max_block_num * batch_size\n            ).view(batch_size, max_block_num * self.block_size, self.q_head_num)\n            importance_cache_cpu = torch.empty_like(\n                importance_cache, device=\"cpu\", pin_memory=True\n            )\n\n            importance_cache_cpu.copy_(importance_cache)\n\n            block_table_cpu = self.prefix_block_table[:, :max_block_num].to(\"cpu\")\n            offset_cpu = offset.contiguous().to(\"cpu\")\n\n            self.cpu_infer.submit(\n                self.local_thread.update_importance(\n                    importance_cache_cpu,\n                    layer_idx,\n                    block_table_cpu,\n                    max_block_num,\n                    offset_cpu,\n                    width,\n                )\n            )\n            self.cpu_infer.sync()\n\n        importance_cache = self.cache_importance.narrow(\n            0, 0, max_block_num * batch_size\n        ).view(batch_size, max_block_num * self.block_size, 
self.q_head_num)\n        importance_cache.zero_()\n\n    # key: [bsz, past_len, head_num, head_dim] float16\n    # query: [bsz, q_len, q_head_num, head_dim] float16\n    def get_attn_score(\n        self,\n        layer_idx: int,\n        batch_size: int,\n        offset: torch.Tensor,\n        width: int,\n        query: torch.Tensor,\n        key: torch.Tensor,\n    ):\n        max_seqs_len = offset.max().item() + width\n        max_block_num = (max_seqs_len + self.block_size - 1) // self.block_size\n\n        for batch_idx in range(batch_size):\n            for idx in range(width // self.block_size):\n                offset_cur = idx * self.block_size\n                query_cur = query[batch_idx, offset_cur : offset_cur + self.block_size]\n                self.get_attn_score_one_block(\n                    batch_idx,\n                    max_block_num,\n                    query_cur,\n                    key[\n                        batch_idx,\n                        offset[batch_idx]\n                        + offset_cur : offset[batch_idx]\n                        + offset_cur\n                        + self.block_size,\n                    ],\n                    offset[batch_idx].item() + offset_cur,\n                    self.block_size,\n                    mask_mode=\"tril\",\n                    use_softmax=False,\n                )\n\n                offset_key = (\n                    offset[batch_idx].item()\n                    + idx * self.block_size\n                    - self.local_windows_len\n                )\n                if offset_key >= 0:\n                    self.get_attn_score_one_block(\n                        batch_idx,\n                        max_block_num,\n                        query_cur,\n                        key[batch_idx, offset_key : offset_key + self.block_size],\n                        offset_key,\n                        self.block_size,\n                        mask_mode=\"triu\",\n                        
use_softmax=False,\n                    )\n\n                offset_key = max(0, offset_key + self.block_size)\n                width_key = (\n                    offset[batch_idx].item() + idx * self.block_size - offset_key\n                )\n                if width_key > 0:\n                    self.get_attn_score_one_block(\n                        batch_idx,\n                        max_block_num,\n                        query_cur,\n                        key[batch_idx, offset_key : offset_key + width_key],\n                        offset_key,\n                        width_key,\n                        mask_mode=None,\n                        use_softmax=False,\n                    )\n\n        importance_cache = self.cache_importance.narrow(\n            0, 0, max_block_num * batch_size\n        ).view(batch_size, max_block_num * self.block_size, self.q_head_num)\n        importance_cache_cpu = torch.empty_like(\n            importance_cache, device=\"cpu\", pin_memory=True\n        )\n\n        importance_cache_cpu.copy_(importance_cache)\n\n        block_table_cpu = self.prefix_block_table[:, :max_block_num].to(\"cpu\")\n        offset_cpu = offset.contiguous().to(\"cpu\")\n\n        self.cpu_infer.submit(\n            self.local_thread.update_importance(\n                importance_cache_cpu,\n                layer_idx,\n                block_table_cpu,\n                max_block_num,\n                offset_cpu,\n                width,\n            )\n        )\n        self.cpu_infer.sync()\n        importance_cache.zero_()\n\n    # key: [bsz, q_len, head_num, head_dim] float16\n    # value: [bsz, q_len, head_num, head_dim] float16\n    def swap_in_and_swap_out(self, layer_idx, past_len, q_len, key, value):\n        batch_size = 1\n        max_seqs_len = past_len.max().item() + q_len\n        max_block_num = (max_seqs_len + self.block_size - 1) // self.block_size\n        k_cache = self.cache_key_states.narrow(0, 0, max_block_num * 
batch_size).view(\n            batch_size, max_block_num * self.block_size, self.kv_head_num, self.head_dim\n        )\n        v_cache = self.cache_value_states.narrow(0, 0, max_block_num * batch_size).view(\n            batch_size, max_block_num * self.block_size, self.kv_head_num, self.head_dim\n        )\n\n        for batch_idx in range(batch_size):\n            offset = past_len[batch_idx]\n            width = q_len\n            k_cache[batch_idx][offset : offset + width].copy_(\n                key[batch_idx].view(-1, self.kv_head_num, self.head_dim)\n            )\n            v_cache[batch_idx][offset : offset + width].copy_(\n                value[batch_idx].view(-1, self.kv_head_num, self.head_dim)\n            )\n\n        k_cache_cpu = torch.empty_like(k_cache, device=\"cpu\", pin_memory=True)\n        v_cache_cpu = torch.empty_like(v_cache, device=\"cpu\", pin_memory=True)\n\n        k_cache_cpu.copy_(k_cache)\n        v_cache_cpu.copy_(v_cache)\n\n        cur_block_num = (\n            q_len + past_len[0].item() + self.block_size - 1\n        ) // self.block_size\n        block_table_cpu = self.prefix_block_table[:, :cur_block_num].to(\"cpu\")\n        past_len_cpu = past_len.contiguous().to(\"cpu\")\n\n        self.cpu_infer.submit(\n            self.local_thread.get_and_update_kvcache_fp16(\n                k_cache_cpu,\n                v_cache_cpu,\n                layer_idx,\n                block_table_cpu,\n                max_block_num,\n                past_len_cpu,\n                q_len,\n            )\n        )\n\n        self.cpu_infer.sync()\n        k_cache.copy_(k_cache_cpu)\n        v_cache.copy_(v_cache_cpu)\n\n        return k_cache, v_cache\n\n    def calc_anchor(self, cache_seqlens: int):\n        cur_block_num = (cache_seqlens + self.block_size - 1) // self.block_size\n        block_table_cpu = self.prefix_block_table[:, :cur_block_num].to(\"cpu\")\n        cache_seqlens_cpu = torch.tensor(\n            [cache_seqlens], 
device=\"cpu\", dtype=torch.int32\n        )\n\n        self.cpu_infer.submit(\n            self.local_thread.calc_anchor_all_layers(\n                block_table_cpu,\n                cache_seqlens_cpu,\n            )\n        )\n        self.cpu_infer.sync()\n\n    def clear_importance(self, cache_seqlens: int):\n        print(f\"clear importance: {cache_seqlens}\")\n        cur_block_num = (cache_seqlens + self.block_size - 1) // self.block_size\n        block_table_cpu = self.prefix_block_table[:, :cur_block_num].to(\"cpu\")\n        cache_seqlens_cpu = torch.tensor(\n            [cache_seqlens], device=\"cpu\", dtype=torch.int32\n        )\n\n        self.cpu_infer.submit(\n            self.local_thread.clear_importance_all_layers(\n                block_table_cpu,\n                cache_seqlens_cpu,\n            )\n        )\n        self.cpu_infer.sync()\n\n    def clear_kvcache(self, cache_seqlens: int):\n        cur_block_num = (cache_seqlens + self.block_size - 1) // self.block_size\n        block_table_cpu = self.prefix_block_table[:, :cur_block_num].to(\"cpu\")\n        cache_seqlens_cpu = torch.tensor(\n            [cache_seqlens], device=\"cpu\", dtype=torch.int32\n        )\n\n        self.cpu_infer.submit(\n            self.local_thread.clear_kvcache_all_layers(\n                block_table_cpu,\n                cache_seqlens_cpu,\n            )\n        )\n        self.cpu_infer.sync()\n\n    def get_attn_sparsity(\n        self,\n        q_in: torch.Tensor,\n        layer_idx: int,\n        block_table: torch.Tensor,\n        cache_seqlens: torch.Tensor,\n        block_table_origin: torch.Tensor,\n        cache_seqlens_origin: torch.Tensor,\n        generate_token_idx: int = 0,\n        topk: int | None = None,\n        local: int | None = None,\n        output_path: str = \"./attn_sparsity.json\",\n    ):\n        self.attn_sparsity.zero_()\n        self.cpu_infer.submit(\n            self.local_thread.get_attn_sparsity(\n                q_in,\n    
            self.attn_sparsity,\n                layer_idx,\n                block_table,\n                cache_seqlens,\n                block_table_origin,\n                cache_seqlens_origin,\n                generate_token_idx,\n                topk,\n                local,\n            )\n        )\n        self.cpu_infer.sync()\n        with open(output_path, \"a\") as file:\n            for head_idx in range(self.q_head_num):\n                sparsity = self.attn_sparsity[0][0][head_idx].item()\n                json_obj = {\n                    \"token_idx\": generate_token_idx,\n                    \"layer_idx\": layer_idx,\n                    \"head_idx\": head_idx,\n                    \"sparsity\": sparsity,\n                }\n                json.dump(json_obj, file)\n                file.write(\"\\n\")\n\n    def apply(\n        self,\n        layer_idx: int,\n        bsz: int,\n        past_len: int,\n        query_states: torch.Tensor,\n        key_states: torch.Tensor,\n        value_states: torch.Tensor,\n        mode: str = \"prefill\",\n        generate_token_idx: int = -1,\n    ):\n\n        # key_states: [bsz, q_len, kv_head_num, head_dim]\n        # value_states: [bsz, q_len, kv_head_num, head_dim]\n        # query_states: [bsz, q_len, q_head_num, head_dim]\n        assert query_states.dtype == torch.float16\n        assert key_states.dtype == torch.float16\n        assert value_states.dtype == torch.float16\n\n        assert key_states.size(2) == self.kv_head_num\n        assert value_states.size(2) == self.kv_head_num\n        assert query_states.size(2) == self.q_head_num\n\n        q_len = query_states.size(1)\n        batch_size = query_states.size(0)\n        self.cache_seqlens_cuda.fill_(past_len)\n        last_chunk = False\n        if self.remaining_length <= self.prefill_chunk_size and q_len != 1:\n            last_chunk = True\n        device = query_states.device\n        if layer_idx == 0:\n            if q_len == 1:\n        
        self.generate_token_idx += 1\n            elif last_chunk:\n                self.generate_token_idx = -1\n\n        if mode == \"prefill\":\n            key, value = self.swap_in_and_swap_out(\n                layer_idx,\n                self.cache_seqlens_cuda,\n                q_len,\n                key_states,\n                value_states,\n            )\n\n            if last_chunk and (self.anchor_type == \"DYNAMIC\" or self.preselect_block):\n                self.get_preselect_block_table_and_attn_score(\n                    layer_idx,\n                    bsz,\n                    self.cache_seqlens_cuda,\n                    q_len,\n                    query_states,\n                    key,\n                )\n            output = flash_attn_with_kvcache(\n                q=query_states,\n                k_cache=key,\n                v_cache=value,\n                cache_seqlens=self.cache_seqlens_cuda + q_len,\n                causal=True,\n            )\n            return output.transpose(1, 2)\n\n        elif mode == \"generate\":\n            assert self.generate_token_idx >= 0\n            self.q_in_cpu.copy_(query_states, non_blocking=True)\n            self.k_in_cpu.copy_(key_states, non_blocking=True)\n            self.v_in_cpu.copy_(value_states, non_blocking=True)\n            self.cache_seqlens_cpu.copy_(self.cache_seqlens_cuda, non_blocking=True)\n            #            print(layer_idx)\n            if layer_idx < self.dense_layer_num:\n                self.block_table_cpu.copy_(self.prefix_block_table, non_blocking=True)\n                self.cpu_infer.submit_with_cuda_stream(\n                    torch.cuda.current_stream(\"cuda\").cuda_stream,\n                    self.local_thread.attn_with_kvcache(\n                        q_in=self.q_in_cpu,\n                        k_in=self.k_in_cpu,\n                        v_in=self.v_in_cpu,\n                        output=self.output_cpu,\n                        
attn_lse=self.lse_cpu,\n                        layer_idx=layer_idx,\n                        block_table=self.block_table_cpu,\n                        cache_seqlens=self.cache_seqlens_cpu,\n                    ),\n                )\n            else:\n                if self.preselect_block:\n                    self.cache_seqlens_cpu.copy_(\n                        self.cache_seqlens_cuda - self.evict_tokens, non_blocking=True\n                    )\n                    if self.preselect_block_count < self.prefill_block_num:\n                        self.block_table_cpu[:, : self.preselect_block_count].copy_(\n                            self.preselect_block_table[layer_idx : layer_idx + 1],\n                            non_blocking=True,\n                        )\n\n                        self.block_table_cpu[\n                            :,\n                            self.preselect_block_count : self.preselect_block_count\n                            + self.local_block_num,\n                        ].copy_(\n                            self.prefix_block_table[\n                                :,\n                                self.prefill_block_num : self.prefill_block_num\n                                + self.local_block_num,\n                            ],\n                            non_blocking=True,\n                        )\n                    #                   print(\"submit_with_cuda_stream\")\n                    self.cpu_infer.submit_with_cuda_stream(\n                        torch.cuda.current_stream(\"cuda\").cuda_stream,\n                        self.local_thread.attn_with_kvcache(\n                            q_in=self.q_in_cpu,\n                            k_in=self.k_in_cpu,\n                            v_in=self.v_in_cpu,\n                            output=self.output_cpu,\n                            attn_lse=self.lse_cpu,\n                            layer_idx=layer_idx,\n                            
generate_token_idx=self.generate_token_idx,\n                            block_table=self.block_table_cpu,\n                            cache_seqlens=self.cache_seqlens_cpu,\n                            topk=(\n                                self.topk\n                                if self.topk <= self.preselect_block_count\n                                else None\n                            ),\n                            local=self.local_windows_len // self.block_size,\n                        ),\n                    )\n                #                    print(\"submit_with_cuda_stream enqueue\\n\")\n                else:\n                    self.block_table_cpu.copy_(\n                        self.prefix_block_table, non_blocking=True\n                    )\n                    self.cpu_infer.submit_with_cuda_stream(\n                        torch.cuda.current_stream(\"cuda\").cuda_stream,\n                        self.local_thread.attn_with_kvcache(\n                            q_in=self.q_in_cpu,\n                            k_in=self.k_in_cpu,\n                            v_in=self.v_in_cpu,\n                            output=self.output_cpu,\n                            attn_lse=self.lse_cpu,\n                            layer_idx=layer_idx,\n                            generate_token_idx=self.generate_token_idx,\n                            block_table=self.block_table_cpu,\n                            cache_seqlens=self.cache_seqlens_cpu,\n                            topk=self.topk,\n                            local=self.local_windows_len // self.block_size,\n                        ),\n                    )\n            self.cpu_infer.sync_with_cuda_stream(\n                torch.cuda.current_stream(\"cuda\").cuda_stream\n            )\n            #            print(\"submit_with_cuda_stream finished\\n\")\n            self.output_cuda.copy_(self.output_cpu, non_blocking=True)\n            return self.output_cuda.transpose(1, 2)\n\n    def 
save(self, path: str, length: int):\n        cur_block_num = (length + self.block_size - 1) // self.block_size\n        block_table_cpu = self.prefix_block_table[0, :cur_block_num].to(\"cpu\")\n        cache_seqlens_cpu = torch.tensor([length], device=\"cpu\", dtype=torch.int32)\n        self.cpu_infer.submit(\n            self.local_thread.dump_kvcache(\n                block_table_cpu,\n                cache_seqlens_cpu,\n                path,\n            )\n        )\n        self.cpu_infer.sync()\n\n    def load(self, path: str, length: int):\n        self.cpu_infer.submit(\n            self.local_thread.load_kvcache(\n                path,\n            )\n        )\n        self.cpu_infer.sync()\n"
  },
  {
    "path": "archive/ktransformers/operators/experts.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : Azure-Tang, Boxin Zhang, chenht2022\nDate         : 2024-07-25 11:25:24\nVersion      : 0.1.0\nLastEditors  : Azure \nLastEditTime : 2024-08-29 09:41:10\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\n\nfrom typing import Any, Union\nimport numpy as np\nimport numpy.typing as npt\nfrom torch import Tensor, nn\nimport torch.nn.functional as F\nimport torch\nimport sys, os\nfrom ktransformers.operators.base_operator import BaseInjectedModule\nfrom tqdm import tqdm\n\nsys.path.append(os.path.join(os.path.dirname(__file__), \"..\", \"ktransformers_ext\", \"build\"))\nsys.path.append(os.path.join(os.path.dirname(__file__), \"..\", \"ktransformers_ext\", \"build\", \"Release\"))\nsys.path.append(os.path.join(os.path.dirname(__file__), \"..\", \"ktransformers_ext\", \"build\", \"Debug\"))\nimport cpuinfer_ext\nfrom cpuinfer_ext.moe import MOEConfig, MOE\nimport ctypes\nfrom ktransformers.util.custom_gguf import GGMLQuantizationType, translate_name_to_gguf\nfrom ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader, ModelLoader\nfrom ktransformers.util.utils import InferenceState\nfrom ktransformers.server.config.config import Config\nfrom transformers.activations import ACT2FN\nfrom transformers.configuration_utils import PretrainedConfig\nfrom abc import ABC, abstractmethod\nfrom ktransformers.operators.linear import KLinearMarlin, KLinearTorch, KTransformersLinear\nimport time\nfrom ktransformers.operators.cpuinfer import CPUInfer\n\ntry:\n    import torch_npu\n    from ktransformers.util.ascend.ascend_utils import get_tensor_parallel_size\n    use_torch_npu = torch_npu.npu.is_available()\nexcept:\n    use_torch_npu = False\n\n\ndef deduplicate_and_sort(lst):\n    return sorted(set(lst))\ndef generate_cuda_graphs(chunk_size: int) -> list:\n    assert chunk_size <= 1024 or chunk_size % 1024 == 0, \"chunk_size must <= 1024 or a multiple of 1024\"\n    
base_list = [1, 2, 3, Config().max_batch_size, 64, 256, 512, chunk_size]\n\n    if chunk_size <= 1024:\n        return deduplicate_and_sort(base_list)\n\n    multiples = [i for i in range(1024, chunk_size + 1, 1024)]\n\n    return deduplicate_and_sort(base_list + multiples)\n#cuda_graphs = [Config().chunk_size] \nif torch.cuda.is_available():\n    cuda_graphs = generate_cuda_graphs(Config().chunk_size)\nelif use_torch_npu:\n    cuda_graphs = deduplicate_and_sort([1, 2, 3, 4])\nelse:\n    cuda_graphs = 1\n# class Base(BaseInjectedModule, ABC):\nclass KExpertsBase(ABC):\n    def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, device: str = \"cuda\", **kwargs):\n        # super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)\n        self.key = key\n        self.gguf_loader = gguf_loader\n        self.config = config\n        self.device = device\n    \n    @abstractmethod\n    def forward(self, input_tensor, expert_ids, weights):\n        pass\n\n    @abstractmethod\n    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str = \"cpu\", warmup: bool = False):\n        pass\n    \n    @abstractmethod\n    def unload():\n        pass\n\n    def load_weights(self, override_key: str | None = None, device: str = \"cpu\"):\n        res = {}\n        if override_key is not None:\n            keys = override_key\n        else:\n            keys = [self.key]\n\n        gate = None\n        up = None\n        down = None\n        gate_type = None\n        up_type = None\n        down_type = None\n\n        for key in keys:\n            if self.gguf_loader.has_tensor(key + \".ffn_gate_exps.weight\"):\n                targets = [\".ffn_gate_exps.weight\", \".ffn_up_exps.weight\", \".ffn_down_exps.weight\" ]\n                tensors = self.load_multi(key, targets, device=device)\n                gate = tensors[\".ffn_gate_exps.weight\"]\n                up = tensors[\".ffn_up_exps.weight\"]\n 
               down = tensors[\".ffn_down_exps.weight\"]\n                gate_type = self.gguf_loader.tensor_info[key + \".ffn_gate_exps.weight\"][\"ggml_type\"]\n                up_type = self.gguf_loader.tensor_info[key + \".ffn_up_exps.weight\"][\"ggml_type\"]\n                down_type = self.gguf_loader.tensor_info[key + \".ffn_down_exps.weight\"][\"ggml_type\"]\n            elif self.gguf_loader.has_tensor(key + \".ffn_down.0.weight\"):\n                # for supporting  Mixtral-8x7B-Instuct  \n                gate = []\n                up = []\n                down = []\n                for i in range(8):\n                    gatei, upi, downi = f\".ffn_gate.{i}.weight\", f\".ffn_up.{i}.weight\", f\".ffn_down.{i}.weight\"\n                    targets = [gatei, upi, downi]\n                    tensors = self.load_multi(key, targets, device=device)\n                    gate_it, up_it, down_it = tensors[gatei], tensors[upi], tensors[downi]\n                    gate.append(gate_it)\n                    up.append(up_it)\n                    down.append(down_it)\n                gate = torch.stack(gate)\n                up = torch.stack(up)\n                down = torch.stack(down)\n                gate_type = self.gguf_loader.tensor_info[key + \".ffn_gate.0.weight\"][\"ggml_type\"]\n                up_type = self.gguf_loader.tensor_info[key + \".ffn_up.0.weight\"][\"ggml_type\"]\n                down_type = self.gguf_loader.tensor_info[key + \".ffn_down.0.weight\"][\"ggml_type\"]\n            else:\n                raise ValueError(f\"Experts {key} not found in gguf_loader\")\n            res = {key:{\"gate\": gate, \"up\": up, \"down\": down, \"gate_type\": gate_type, \"up_type\": up_type, \"down_type\": down_type}}\n        return res\n    \n    def load_multi(self, key: str, keys: list[str], device: str = \"cpu\"):\n        tensors = {}\n        for k in keys:\n            tensors[k] = self.gguf_loader.load_gguf_tensor(key + k, device=device)\n        return 
tensors\n\n\nclass KExpertsCPU(KExpertsBase):\n    input_tensor_cpu:Tensor = None\n    expert_ids_cpu:Tensor = None\n    weights_cpu:Tensor = None\n    output_cpu:Tensor = None\n    output_gpu_map:dict = {} # Manage output tensor buffer on different gpu\n    #stream_map:dict = {} # Manage cuda stream on different gpu\n    # @TODO add yaml\n    CPU_INFER = CPUInfer(Config().cpu_infer)\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        n_routed_experts: int,\n        orig_module: nn.Module = None,\n        device: str = \"cpu\",\n        out_device: str = \"cuda\", # this device mean which device the output should on. TODO: support cpu.\n        **kwargs\n    ):\n        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)\n        assert device.lower() == \"cpu\", \"KExpertsCPU can only be loaded on CPU\"\n        self.n_routed_experts = n_routed_experts\n        self.out_device = out_device\n        self.backend = kwargs.get(\"backend\", \"llamafile\")\n\n    def load(self, w: dict | nn.Parameter | tuple | None = None, device:str|None = None, warmup:bool = False):\n        if use_torch_npu and get_tensor_parallel_size() != 1 and (\n            not torch.distributed.is_initialized() or torch.distributed.get_rank() != 0):\n            return\n\n        if device:\n            assert device.lower() == \"cpu\", \"KExpertsCPU can only be loaded on CPU, Parameter \\\"device\\\" can be cpu or None.\"\n        if w is None: w = self.load_weights()[self.key]\n        self.gate = w[\"gate\"]\n        self.up = w[\"up\"]\n        self.down = w[\"down\"]\n        self.gate_type = w[\"gate_type\"]\n        self.up_type = w[\"up_type\"]\n        self.down_type = w[\"down_type\"]\n        gate_ptr = ctypes.addressof(\n            ctypes.cast(self.gate.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents\n        )\n        up_ptr = ctypes.addressof(\n            
ctypes.cast(self.up.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents\n        )\n        down_ptr = ctypes.addressof(\n            ctypes.cast(self.down.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents\n        )\n        # print(self.gate_qtype, self.up_qtype, self.down_qtype)\n        n_routed_experts = self.n_routed_experts\n        self.cpu_infer = KExpertsCPU.CPU_INFER\n        # n_routed_experts = len(self.orig_module)\n        model_dtype = torch.get_default_dtype()\n        if torch.xpu.is_available() and model_dtype == torch.float16:\n            hidden_type = 1 # fp16\n        else:\n            hidden_type = 30 # bf16\n        if self.backend == \"llamafile\":\n            moe_config = MOEConfig(\n                n_routed_experts,\n                self.config.num_experts_per_tok,\n                self.config.hidden_size,\n                self.config.moe_intermediate_size,\n                64,\n                10,\n                1024,\n                self.config.hidden_act == 'silu',\n                gate_ptr,\n                up_ptr,\n                down_ptr,\n                self.gate_type,\n                self.up_type,\n                self.down_type,\n                hidden_type, # TODO: get from model.dtype\n            )\n            self.moe = MOE(moe_config)\n        elif self.backend == \"AMXBF16\":\n            from cpuinfer_ext.moe import AMX_MOEConfig, AMXBF16_MOE\n            assert self.gate_type == GGMLQuantizationType.BF16\n            assert self.up_type == GGMLQuantizationType.BF16\n            assert self.down_type == GGMLQuantizationType.BF16\n            moe_config = AMX_MOEConfig(\n                n_routed_experts,\n                self.config.num_experts_per_tok,\n                self.config.hidden_size,\n                self.config.moe_intermediate_size,\n                max(cuda_graphs) if isinstance(cuda_graphs, list) else Config().chunk_size,\n                self.config.hidden_act == 'silu',\n                
gate_ptr,\n                up_ptr,\n                down_ptr,\n            )\n            self.moe = AMXBF16_MOE(moe_config)\n            self.cpu_infer.submit(self.moe.load_weights())\n            self.cpu_infer.sync()\n        elif self.backend == \"AMXInt8\":\n            from cpuinfer_ext.moe import AMX_MOEConfig, AMXInt8_MOE\n            assert self.gate_type == GGMLQuantizationType.BF16\n            assert self.up_type == GGMLQuantizationType.BF16\n            assert self.down_type == GGMLQuantizationType.BF16\n            moe_config = AMX_MOEConfig(\n                n_routed_experts,\n                self.config.num_experts_per_tok,\n                self.config.hidden_size,\n                self.config.moe_intermediate_size,\n                max(cuda_graphs) if isinstance(cuda_graphs, list) else Config().chunk_size,\n                self.config.hidden_act == 'silu',\n                gate_ptr,\n                up_ptr,\n                down_ptr,\n            )\n            self.moe = AMXInt8_MOE(moe_config)\n            self.cpu_infer.submit(self.moe.load_weights())\n            self.cpu_infer.sync()\n        # print(n_routed_experts, hidden_size, moe_intermediate_size)\n        num_experts_per_tok = self.config.num_experts_per_tok\n        if warmup:\n            self.cpu_infer.submit(self.moe.warm_up())\n            self.cpu_infer.sync()\n        if self.out_device not in KExpertsCPU.output_gpu_map:\n            if isinstance(cuda_graphs, list):\n                KExpertsCPU.output_gpu_map[self.out_device] = [torch.zeros((cuda_graphs[i], self.config.hidden_size), device=self.out_device) for i in range(len(cuda_graphs))]\n            else:\n                KExpertsCPU.output_gpu_map[self.out_device] = torch.zeros((cuda_graphs, self.config.hidden_size), device=self.out_device)\n        if KExpertsCPU.input_tensor_cpu is None:\n            if isinstance(cuda_graphs, list):\n                if use_torch_npu:\n                    KExpertsCPU.input_tensor_cpu = 
[[torch.zeros((cuda_graphs[i], self.config.hidden_size), device=\"cpu\", pin_memory=True, dtype=torch.bfloat16)] for i in range(len(cuda_graphs))]\n                    KExpertsCPU.expert_ids_cpu = [[torch.zeros((cuda_graphs[i], num_experts_per_tok), device=\"cpu\", dtype=torch.long, pin_memory=True)] for i in range(len(cuda_graphs))]\n                    KExpertsCPU.weights_cpu = [[torch.zeros((cuda_graphs[i], num_experts_per_tok), device=\"cpu\", dtype=torch.float32, pin_memory=True)] for i in range(len(cuda_graphs))]\n                    KExpertsCPU.output_cpu = [[torch.zeros((cuda_graphs[i], self.config.hidden_size), device=\"cpu\", pin_memory=True, dtype=torch.bfloat16)] for i in range(len(cuda_graphs))]\n                    KExpertsCPU.bsz_tensor_cpu = [[torch.tensor([cuda_graphs[i]], device=\"cpu\", dtype=torch.int32, pin_memory=True)] for i in range(len(cuda_graphs))]                    \n                else:\n                    KExpertsCPU.input_tensor_cpu = [torch.zeros((cuda_graphs[i], self.config.hidden_size), device=\"cpu\", pin_memory=True) for i in range(len(cuda_graphs))]\n                    KExpertsCPU.expert_ids_cpu = [torch.zeros((cuda_graphs[i], num_experts_per_tok), device=\"cpu\", dtype=torch.long, pin_memory=True) for i in range(len(cuda_graphs))]\n                    KExpertsCPU.weights_cpu = [torch.zeros((cuda_graphs[i], num_experts_per_tok), device=\"cpu\", dtype=torch.float32, pin_memory=True) for i in range(len(cuda_graphs))]\n                    KExpertsCPU.output_cpu = [torch.zeros((cuda_graphs[i], self.config.hidden_size), device=\"cpu\", pin_memory=True, dtype=torch.bfloat16) for i in range(len(cuda_graphs))]\n                    KExpertsCPU.bsz_tensor_cpu = [torch.zeros((1), device=\"cpu\", dtype=torch.int32, pin_memory=True) for i in range(len(cuda_graphs))]\n            else:\n                KExpertsCPU.input_tensor_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device=\"cpu\", pin_memory=True)\n                
KExpertsCPU.expert_ids_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device=\"cpu\", dtype=torch.long, pin_memory=True)\n                KExpertsCPU.weights_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device=\"cpu\", dtype=torch.float32, pin_memory=True)\n                if torch.xpu.is_available():\n                    KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device=\"cpu\", pin_memory=True, dtype=model_dtype)\n                    KExpertsCPU.bsz_tensor_cpu = torch.ones((1), device=\"cpu\", dtype=torch.int32, pin_memory=True)\n                else:\n                    KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device=\"cpu\", pin_memory=True, dtype=torch.bfloat16)\n                    KExpertsCPU.bsz_tensor_cpu = torch.zeros((1), device=\"cpu\", dtype=torch.int32, pin_memory=True)\n            \n    def submit_for_one_decode(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0):\n        if bsz_tensor is None:\n            bsz_tensor = torch.ones(1, device=input_tensor.device, dtype=torch.int32)\n        if cuda_graph_idx != -1:\n            KExpertsCPU.input_tensor_cpu[cuda_graph_idx].copy_(input_tensor, non_blocking=True)\n            KExpertsCPU.expert_ids_cpu[cuda_graph_idx].copy_(expert_ids, non_blocking=True)\n            KExpertsCPU.weights_cpu[cuda_graph_idx].copy_(weights, non_blocking=True)\n            KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].copy_(bsz_tensor, non_blocking=True)\n            self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(-1), KExpertsCPU.expert_ids_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.weights_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.input_tensor_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.output_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].data_ptr()))\n        else:\n            
KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)\n            KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True)\n            KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True)\n            KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor, non_blocking=True)\n            self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(-1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr()))\n        \n\n    def sync_for_one_decode(self, cuda_graph_idx=0):\n        if cuda_graph_idx != -1:\n            self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream)\n            KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx].copy_(KExpertsCPU.output_cpu[cuda_graph_idx], non_blocking=True)\n            return KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx]\n        else:\n            self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream)\n            KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)\n            return KExpertsCPU.output_gpu_map[self.out_device]\n\n    def forward(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0):\n        # generate, capture and run cuda graph\n        # print(expert_ids)\n        if bsz_tensor is None and (not torch.xpu.is_available() or input_tensor.size(0) > 1):\n            bsz_tensor = torch.tensor([input_tensor.size(0)], device=input_tensor.device, dtype=torch.int32)\n        if torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():\n            if cuda_graph_idx != -1:\n                KExpertsCPU.input_tensor_cpu[cuda_graph_idx].copy_(input_tensor, non_blocking=True)\n                
KExpertsCPU.expert_ids_cpu[cuda_graph_idx].copy_(expert_ids, non_blocking=True)\n                KExpertsCPU.weights_cpu[cuda_graph_idx].copy_(weights, non_blocking=True)\n                KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].copy_(bsz_tensor, non_blocking=True)\n                self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(expert_ids.size(0), expert_ids.size(-1), KExpertsCPU.expert_ids_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.weights_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.input_tensor_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.output_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].data_ptr()))\n                self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)\n                KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx].copy_(KExpertsCPU.output_cpu[cuda_graph_idx], non_blocking=True)\n                return KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx]\n\n            else:\n                KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)\n                KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True)\n                KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True)\n                KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor, non_blocking=True)\n                self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(expert_ids.size(0), expert_ids.size(-1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr()))\n                self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)\n                KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)\n                return KExpertsCPU.output_gpu_map[self.out_device]\n        elif 
input_tensor.size(0)==1 and torch.xpu.is_available():\n            KExpertsCPU.input_tensor_cpu.copy_(input_tensor.view(-1), non_blocking=True)\n            KExpertsCPU.expert_ids_cpu.copy_(expert_ids.view(-1), non_blocking=True)\n            KExpertsCPU.weights_cpu.copy_(weights.view(-1), non_blocking=True)\n            # KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor.view(-1), non_blocking=True)\n            self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr()))\n            self.cpu_infer.sync()\n            KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)\n            return KExpertsCPU.output_gpu_map[self.out_device].view(1, -1)\n        else:\n            input_tensor = input_tensor.contiguous().cpu()\n            expert_ids = expert_ids.contiguous().cpu()\n            weights = weights.contiguous().to(torch.float32).cpu()\n            bsz_tensor = bsz_tensor.contiguous().cpu()\n            output = torch.empty_like(input_tensor).contiguous()\n            self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), expert_ids.data_ptr(), weights.data_ptr(), input_tensor.data_ptr(), output.data_ptr(), bsz_tensor.data_ptr()))\n            self.cpu_infer.sync()\n            return output.to(device=object.__getattribute__(self, \"out_device\"))\n    \n    def unload(self):\n        return\n\n    def load_weights(self, override_key: str | None = None, device: str = \"cpu\"):\n        # TODO: support Bias\n        res = {}\n        if override_key is not None:\n            keys = override_key\n        else:\n            keys = [self.key]\n\n        gate = None\n        up = None\n        down = None\n        gate_type = None\n        up_type = None\n        down_type = None\n\n        for key in 
keys:\n            if isinstance(self.gguf_loader, SafeTensorLoader):\n                res = self.gguf_loader.load_experts(key)\n                return {key: res}\n            elif self.gguf_loader.has_tensor(key + \".ffn_gate_exps.weight\"):\n                gate = self.gguf_loader.get_mmap_tensor(key + \".ffn_gate_exps.weight\")\n                up = self.gguf_loader.get_mmap_tensor(key + \".ffn_up_exps.weight\")\n                down = self.gguf_loader.get_mmap_tensor(key + \".ffn_down_exps.weight\")\n                # gate_type = self.gguf_loader.tensor_info[key + \".ffn_gate_exps.weight\"][\"ggml_type\"]\n                # up_type = self.gguf_loader.tensor_info[key + \".ffn_up_exps.weight\"][\"ggml_type\"]\n                # down_type = self.gguf_loader.tensor_info[key + \".ffn_down_exps.weight\"][\"ggml_type\"]\n                gate_type = self.gguf_loader.get_ggml_type(key + \".ffn_gate_exps.weight\")\n                up_type = self.gguf_loader.get_ggml_type(key + \".ffn_up_exps.weight\")\n                down_type = self.gguf_loader.get_ggml_type(key + \".ffn_down_exps.weight\")\n            \n            elif key + \".ffn_gate_exps.weight\" in self.gguf_loader.tensor_info:\n                gate = self.gguf_loader.get_mmap_tensor(key + \".ffn_gate_exps.weight\")\n                up = self.gguf_loader.get_mmap_tensor(key + \".ffn_up_exps.weight\")\n                down = self.gguf_loader.get_mmap_tensor(key + \".ffn_down_exps.weight\")\n                gate_type = self.gguf_loader.tensor_info[key + \".ffn_gate_exps.weight\"][\"ggml_type\"]\n                up_type = self.gguf_loader.tensor_info[key + \".ffn_up_exps.weight\"][\"ggml_type\"]\n                down_type = self.gguf_loader.tensor_info[key + \".ffn_down_exps.weight\"][\"ggml_type\"]\n            elif key + \".ffn_down.0.weight\" in self.gguf_loader.tensor_info:\n                # for supporting  Mixtral-8x7B-Instuct  \n                gate = []\n                up = []\n                down = []\n 
               for i in range(8):\n                    gate_it = self.gguf_loader.get_mmap_tensor(f\"{key}.ffn_gate.{i}.weight\")\n                    up_it = self.gguf_loader.get_mmap_tensor(f\"{key}.ffn_up.{i}.weight\")\n                    down_it = self.gguf_loader.get_mmap_tensor(f\"{key}.ffn_down.{i}.weight\")\n                    gate.append(gate_it)\n                    up.append(up_it)\n                    down.append(down_it)\n                gate = np.stack(gate)\n                up = np.stack(up)\n                down = np.stack(down)\n                gate_type = self.gguf_loader.get_ggml_type(key + \".ffn_gate.0.weight\")\n                up_type = self.gguf_loader.get_ggml_type(key + \".ffn_up.0.weight\")\n                down_type = self.gguf_loader.get_ggml_type(key + \".ffn_down.0.weight\")\n            elif self.gguf_loader.safetensor_loader is not None:\n                # for npu\n                # using a temp ugly way to temprary load the tensor\n                translate_key = translate_name_to_gguf(key)\n                gate = self.gguf_loader.safetensor_loader.load_tensor(translate_key + \".ffn_gate_exps.weight\").numpy()\n                up = self.gguf_loader.safetensor_loader.load_tensor(translate_key + \".ffn_up_exps.weight\").numpy()\n                down = self.gguf_loader.safetensor_loader.load_tensor(translate_key + \".ffn_down_exps.weight\").numpy()\n                gate_type = self.gguf_loader.safetensor_loader.load_tensor(translate_key + \".ffn_gate_exps.ggml_type\").item()\n                up_type = self.gguf_loader.safetensor_loader.load_tensor(translate_key + \".ffn_up_exps.ggml_type\").item()\n                down_type = self.gguf_loader.safetensor_loader.load_tensor(translate_key + \".ffn_down_exps.ggml_type\").item()\n            else:\n                raise ValueError(f\"Experts {key} not found in gguf_loader\")\n            res = {key:{\"gate\": gate, \"up\": up, \"down\": down, \"gate_type\": gate_type, \"up_type\": 
up_type, \"down_type\": down_type}}\n        return res\n    \nclass KExpertsMarlin(KExpertsBase):\n    expert_num: int\n    loaded_experts_idx: list[int]\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        n_routed_experts: int,\n        orig_module: nn.Module = None,\n        device: str = \"cuda\",\n        **kwargs\n    ):\n        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)\n        self.expert_num = n_routed_experts\n        self.loaded_experts_idx = []\n        self.act_fn = ACT2FN[config.hidden_act]\n        assert device.lower() != \"cpu\", \"Marlin experts can only be loaded on GPU\"\n        self.device = device\n        self.elements_per_tensor = config.moe_intermediate_size * config.hidden_size\n\n        # create empty marlin experts according to the number of experts per token\n        # up\n        self.up_projs = [KLinearMarlin(key+ \".\" + \"ffn_up_exps\", gguf_loader, config, device=device) for i in range(self.expert_num)]\n        # gate\n        self.gate_projs = [KLinearMarlin(key+ \".\" + \"ffn_gate_exps\", gguf_loader, config, device=device) for i in range(self.expert_num)]\n        # down\n        self.down_projs = [KLinearMarlin(key+ \".\" + \"ffn_down_exps\", gguf_loader, config, device=device) for i in range(self.expert_num)]\n\n    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = False):\n        if device is None: device = self.device\n        assert device.lower() != \"cpu\", \"Marlin experts can only be loaded on GPU\"\n        # Caller-supplied weights are assumed to be pre-split per expert; only\n        # weights fetched from the loader are dequantized expert by expert.\n        load_by_experts = False\n        if w is None:\n            w = self.load_weights()\n            load_by_experts = True\n\n        if load_by_experts:\n            if isinstance(w, dict):\n                self.gate = w[\"gate\"]\n                self.up = (w[\"up\"])\n                self.down = (w[\"down\"])\n                for i in tqdm(range(self.expert_num), desc=f\"Dequanting 
and quanting for KExpertsMarlin {self.key}\"):\n                    up_weights = self.gguf_loader.load_expert_tensor(self.key + \".ffn_up_exps.weight\", self.up, i, self.elements_per_tensor, device=self.device)\n                    gate_weights = self.gguf_loader.load_expert_tensor(self.key + \".ffn_gate_exps.weight\", self.gate, i, self.elements_per_tensor, device=self.device)\n                    down_weights = self.gguf_loader.load_expert_tensor(self.key + \".ffn_down_exps.weight\", self.down, i, self.elements_per_tensor, device=self.device)\n                    \n                    self.up_projs[i].load(nn.Parameter(up_weights), device=device)\n                    self.gate_projs[i].load(nn.Parameter(gate_weights), device=device)\n                    self.down_projs[i].load(nn.Parameter(down_weights), device=device)\n                    self.loaded_experts_idx.append(i)\n        else:\n            if isinstance(w, dict):\n                self.gate = w[\"gate\"]\n                self.up = (w[\"up\"])\n                self.down = (w[\"down\"])\n                for i in range(self.expert_num):\n                    self.up_projs[i].load(nn.Parameter(self.up[i,...]), device=device)\n                    self.gate_projs[i].load(nn.Parameter(self.gate[i,...]), device=device)\n                    self.down_projs[i].load(nn.Parameter(self.down[i,...]), device=device)\n                    self.loaded_experts_idx.append(i)\n        return \n\n    def unload(self):\n        for i in self.loaded_experts_idx:\n            self.up_projs[i].unload()\n            self.gate_projs[i].unload()\n            self.down_projs[i].unload()\n        self.loaded_experts_idx = []\n\n    def load_weights(self, override_key: str | None = None):\n        res = {}\n        if override_key is not None:\n            keys = override_key\n        else:\n            keys = [self.key]\n\n        gate = None\n        up = None\n        down = None\n\n        for key in keys:\n            if 
self.gguf_loader.has_tensor(key + \".ffn_gate_exps.weight\"):\n                gate = self.gguf_loader.get_mmap_tensor(key + \".ffn_gate_exps.weight\")\n                up = self.gguf_loader.get_mmap_tensor(key + \".ffn_up_exps.weight\")\n                down = self.gguf_loader.get_mmap_tensor(key + \".ffn_down_exps.weight\")\n            res = {\"gate\": gate, \"up\": up, \"down\": down}\n        return res\n\n    def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:\n        org_dtype = hidden_states_cpu.dtype\n        org_device = hidden_states_cpu.device\n        hidden_states_cpu = hidden_states_cpu.to(self.device)\n        selected_experts_cpu = selected_experts_cpu.to(self.device)\n        routing_weights_cpu = routing_weights_cpu.to(self.device).to(org_dtype)\n        \n        batch_sequence_length, hidden_dim = hidden_states_cpu.size()\n\n        final_hidden_states = torch.zeros(\n            (batch_sequence_length, hidden_dim), dtype=hidden_states_cpu.dtype, device=hidden_states_cpu.device\n        )\n        # One hot encode the selected experts to create an expert mask\n        # this will be used to easily index which expert is going to be sollicitated\n        expert_mask = torch.nn.functional.one_hot(selected_experts_cpu, num_classes=self.expert_num).permute(2, 1, 0)\n\n        # Loop over all available experts in the model and perform the computation on each expert\n        for expert_idx in range(self.expert_num):\n            if not expert_mask[expert_idx].any():\n                continue\n            idx, top_x = torch.where(expert_mask[expert_idx])\n            # Index the correct hidden states and compute the expert hidden state for\n            # the current expert. 
We need to make sure to multiply the output hidden\n            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)\n            current_state = hidden_states_cpu[None, top_x].reshape(-1, hidden_dim)\n            G = self.gate_projs[expert_idx].forward(current_state)\n            A = self.act_fn(G)\n            U = self.up_projs[expert_idx].forward(current_state)\n            H = A * U  # Element-wise multiplication\n            current_hidden_states = self.down_projs[expert_idx].forward(H) * routing_weights_cpu[top_x, idx, None]\n            # However `index_add_` only support torch tensors for indexing so we'll use\n            # the `top_x` tensor here.\n            final_hidden_states.index_add_(0, top_x, current_hidden_states)\n        \n        return final_hidden_states.to(dtype=org_dtype, device=org_device)\n    \n# untested, CUDA OOM\nclass KExpertsTorch(KExpertsBase):\n    expert_num: int\n    loaded_experts_idx: list[int]\n    gate: torch.Tensor\n    up: torch.Tensor\n    down: torch.Tensor\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        n_routed_experts: int,\n        orig_module: nn.Module = None,\n        device: str = \"cpu\",\n        **kwargs\n    ):\n        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)\n        self.expert_num = n_routed_experts\n        # self.loaded_experts_idx = []\n        self.act_fn = ACT2FN[config.hidden_act]\n        self.device = device\n        self.elements_per_tensor = config.moe_intermediate_size * config.hidden_size\n        self.gate = [None for _ in range(self.expert_num)]\n        self.up = [None for _ in range(self.expert_num)]\n        self.down = [None for _ in range(self.expert_num)]\n        self.dtype = torch.get_default_dtype()\n\n    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = False):\n        if device is None: 
device = self.device\n        # Default to False so that passing `w` explicitly does not hit an unbound local\n        load_by_experts = False\n        if w is None:\n            w = self.load_weights()\n            load_by_experts = True\n\n        if load_by_experts:\n            if isinstance(w, dict):\n                for i in tqdm(range(self.expert_num), desc=f\"Dequanting for KExpertsTorch {self.key}\"):\n                    up_weights = self.gguf_loader.load_expert_tensor(self.key + \".ffn_up_exps.weight\", w[\"up\"], i, self.elements_per_tensor, device=self.device)\n                    gate_weights = self.gguf_loader.load_expert_tensor(self.key + \".ffn_gate_exps.weight\", w[\"gate\"], i, self.elements_per_tensor, device=self.device)\n                    down_weights = self.gguf_loader.load_expert_tensor(self.key + \".ffn_down_exps.weight\", w[\"down\"], i, self.elements_per_tensor, device=self.device)\n                    \n                    self.up[i] = up_weights\n                    self.gate[i] = gate_weights\n                    self.down[i] = down_weights\n        else:\n            if isinstance(w, dict):\n                for i in range(self.expert_num):\n                    self.gate[i] = w[\"gate\"][i, ...].to(device=device, dtype=self.dtype)\n                    self.up[i] = w[\"up\"][i, ...].to(device=device, dtype=self.dtype)\n                    self.down[i] = w[\"down\"][i, ...].to(device=device, dtype=self.dtype)\n        \n        self.up = torch.stack(self.up, dim=0)\n        self.gate = torch.stack(self.gate, dim=0)\n        self.down = torch.stack(self.down, dim=0)\n        return \n\n    def unload(self):\n        if self.gate is not None:\n            self.gate = None\n            self.up = None\n            self.down = None\n\n    def load_weights(self, override_key: str | None = None):\n        res = {}\n        if override_key is not None:\n            keys = override_key\n        else:\n            keys = [self.key]\n\n        gate = None\n        up = None\n        down = None\n\n        for key in keys:\n            if key + 
\".ffn_gate_exps.weight\" in self.gguf_loader.tensor_info:\n                gate = self.gguf_loader.get_mmap_tensor(key + \".ffn_gate_exps.weight\")\n                up = self.gguf_loader.get_mmap_tensor(key + \".ffn_up_exps.weight\")\n                down = self.gguf_loader.get_mmap_tensor(key + \".ffn_down_exps.weight\")\n            res = {\"gate\": gate, \"up\": up, \"down\": down}\n        return res\n\n    def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:\n\n        org_device = hidden_states_cpu.device\n        hidden_states_cpu = hidden_states_cpu.to(self.device)\n        selected_experts_cpu = selected_experts_cpu.to(self.device)\n        routing_weights_cpu = routing_weights_cpu.to(self.device)\n        \n        batch_sequence_length, hidden_dim = hidden_states_cpu.size()\n\n        final_hidden_states = torch.zeros(\n            (batch_sequence_length, hidden_dim), dtype=self.gate.dtype, device=hidden_states_cpu.device\n        )\n        org_dtype = hidden_states_cpu.dtype\n        hidden_states_cpu = hidden_states_cpu.to(self.gate.dtype)\n        routing_weights_cpu = routing_weights_cpu.to(self.gate.dtype)\n        # One hot encode the selected experts to create an expert mask\n        # this will be used to easily index which expert is going to be sollicitated\n        expert_mask = torch.nn.functional.one_hot(selected_experts_cpu, num_classes=self.expert_num).permute(2, 1, 0)\n\n        # Loop over all available experts in the model and perform the computation on each expert\n        for expert_idx in range(self.expert_num):\n            idx, top_x = torch.where(expert_mask[expert_idx])\n            # Index the correct hidden states and compute the expert hidden state for\n            # the current expert. 
We need to make sure to multiply the output hidden\n            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)\n            current_state = hidden_states_cpu[None, top_x].reshape(-1, hidden_dim)\n            G = current_state @ self.gate[expert_idx,...].T\n            A = self.act_fn(G)\n            U = current_state @ self.up[expert_idx,...].T\n            H = A * U  # Element-wise multiplication\n            current_hidden_states = H @ self.down[expert_idx,...].T * routing_weights_cpu[top_x, idx, None]\n            # However `index_add_` only support torch tensors for indexing so we'll use\n            # the `top_x` tensor here.\n            final_hidden_states.index_add_(0, top_x, current_hidden_states)\n\n\n        return final_hidden_states.to(dtype=org_dtype, device=org_device)\n\nEXPERTS_MAP = {\n    \"KExpertsCPU\": KExpertsCPU,\n    \"KExpertsTorch\": KExpertsTorch,\n    \"KExpertsMarlin\": KExpertsMarlin,\n}\n\nclass KTransformersExperts(BaseInjectedModule, KExpertsBase):\n    def __init__(self,\n                 key: str,\n                 gguf_loader: GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                #  device: str = \"cuda\",\n                 prefill_device:str = \"cuda\",\n                 prefill_op: str | None = \"KExpertsTorch\",\n                 generate_device: str = \"cpu\",\n                 generate_op: str | None = \"KExpertsCPU\",\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)\n        KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)\n        if generate_op is not None:\n            self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)\n        else:\n            self.generate_experts = None\n        if prefill_op is not 
None:\n            self.prefill_experts = EXPERTS_MAP[prefill_op](key, gguf_loader, config, len(orig_module), device=prefill_device, **kwargs)\n        else:\n            self.prefill_experts = None\n        self.gpu_mlp_type = prefill_op\n        self.cpu_mlp_type = generate_op\n        self.mode = InferenceState.UNLOAD\n\n    def load(self, w: dict = None,  mode: InferenceState = None, warmup: bool = True):\n        # TODO support w as input\n        if not mode: mode = InferenceState.GENERATE\n        if mode == InferenceState.GENERATE:\n            self.prefill_experts.unload()\n            self.generate_experts.load(w, warmup=warmup)\n            self.device = self.generate_experts.device\n            self.mode = mode\n        elif mode == InferenceState.PREFILL:\n            self.generate_experts.unload()\n            self.prefill_experts.load(w, warmup=warmup)\n            self.device = self.prefill_experts.device\n            self.mode = mode\n        elif mode == InferenceState.UNLOAD:\n            self.unload()\n            self.mode = mode\n            self.device = self.generate_experts.device\n        else:\n            raise ValueError(\"mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD\")\n\n    def unload(self):\n        if self.generate_experts is not None:\n            self.generate_experts.unload()\n        if self.prefill_experts is not None:\n            self.prefill_experts.unload()\n        self.device = self.generate_experts.device\n\n    def forward(self, input_tensor, expert_ids, weights):\n        if self.mode == InferenceState.GENERATE:\n            assert self.generate_experts is not None, \"generate_experts is None\"\n            return self.generate_experts.forward(input_tensor, expert_ids, weights)\n        elif self.mode == InferenceState.PREFILL:\n            assert self.prefill_experts is not None, \"prefill_experts is None\"\n            return self.prefill_experts.forward(input_tensor, 
expert_ids, weights)\n        else:\n            raise ValueError(\"load or set_inference_mode before forward\")\n\n    def set_inference_mode(self, mode: InferenceState):\n        if mode == InferenceState.GENERATE:\n            self.load(mode=InferenceState.GENERATE, warmup=False)\n        elif mode == InferenceState.PREFILL:\n            self.load(mode=InferenceState.PREFILL, warmup=False)\n        elif mode == InferenceState.UNLOAD:\n            self.unload()\n        else:\n            raise ValueError(\"mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD\")\n\n\nfrom ktransformers.models.modeling_deepseek import DeepseekV2MoE\nfrom ktransformers.models.modeling_deepseek_v3 import DeepseekV3MoE\nfrom ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock\nfrom ktransformers.models.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock\nfrom ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock\nfrom ktransformers.models.modeling_smallthinker import SmallthinkerMoeBlock\nfrom ktransformers.models.modeling_glm4_moe import Glm4MoeMoE\nfrom ktransformers.models.modeling_qwen3_next import Qwen3NextSparseMoeBlock\n\n\nclass KQwen2MoeSparseMoeBlock(BaseInjectedModule, Qwen2MoeSparseMoeBlock):\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        \"\"\" \"\"\"\n        orig_shape = hidden_states.shape\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        hidden_states = hidden_states.view(-1, hidden_dim)\n        # router_logits: (batch * sequence_length, n_experts)\n        router_logits = self.gate(hidden_states)\n\n        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)\n        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)\n        if self.norm_topk_prob:\n            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)\n        # we cast back to the input dtype\n        routing_weights 
= routing_weights.to(hidden_states.dtype)\n        \n        if sequence_length == 1 and hasattr(self.experts.generate_experts, \"submit_for_one_decode\"):\n            self.experts.generate_experts.submit_for_one_decode(hidden_states[0], selected_experts[0], routing_weights[0])\n            shared_expert_output = self.shared_expert(hidden_states)\n            shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output\n            y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)\n            y += shared_expert_output\n            y.resize_(*orig_shape)\n            return y, router_logits\n        \n        hidden_states_expert = hidden_states.to(self.experts.device)  if isinstance(self.experts, KExpertsBase) else hidden_states.cpu()\n        selected_experts_expert = selected_experts.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else selected_experts.cpu()\n        routing_weights_expert = routing_weights.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else routing_weights.cpu()\n\n        shared_expert_output = self.shared_expert(hidden_states)\n        shared_expert_output = (\n            F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output\n        )\n\n        if isinstance(self.experts, KExpertsBase):\n            y = (\n                self.moe_kexperts(\n                    hidden_states_expert, selected_experts_expert, routing_weights_expert\n                )\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            )\n        elif hidden_states_expert.size(0) > 10:\n            y = self.moe_infer(\n                hidden_states_expert, selected_experts_expert, routing_weights_expert, orig_shape\n            ).to(device=hidden_states.device)\n        else:\n            y = self.moe_infer_simple(\n                hidden_states_expert, selected_experts_expert, routing_weights_expert\n            
).to(device=hidden_states.device)\n        y += shared_expert_output\n        y.resize_(*orig_shape)\n        return y, router_logits\n    \n    @torch.no_grad()\n    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:\n        outs = self.experts(x, topk_ids, topk_weight)\n        return outs\n\n    @torch.no_grad()\n    # TODO may bugs here\n    def moe_infer_simple(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:\n        '''\n        hidden_states_cpu: [num_tokens, hidden_size]\n        topk_ids, topk_weight: [num_tokens, num_selected_experts]\n        '''\n        outs = torch.zeros_like(hidden_states_cpu)\n        for token_idx in range(selected_experts_cpu.size(0)):\n            for expert_idx in range(selected_experts_cpu.size(1)):\n                expert = self.experts[selected_experts_cpu[token_idx, expert_idx]]\n                outs[token_idx] += expert.forward(hidden_states_cpu[token_idx]) * routing_weights_cpu[token_idx, expert_idx]\n        return outs\n    \n    @torch.no_grad()\n    # TODO may bugs here\n    def moe_infer(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor, orig_shape: tuple) -> torch.Tensor:\n        \n        batch_size, sequence_length, hidden_dim = orig_shape\n\n        final_hidden_states = torch.zeros(\n            (batch_size * sequence_length, hidden_dim), dtype=hidden_states_cpu.dtype, device=hidden_states_cpu.device\n        )\n\n        # One hot encode the selected experts to create an expert mask\n        # this will be used to easily index which expert is going to be sollicitated\n        expert_mask = torch.nn.functional.one_hot(selected_experts_cpu, num_classes=self.num_experts).permute(2, 1, 0)\n\n        # Loop over all available experts in the model and perform the computation on each expert\n        for expert_idx in 
range(self.num_experts):\n            expert_layer = self.experts[expert_idx]\n            idx, top_x = torch.where(expert_mask[expert_idx])\n\n            # Index the correct hidden states and compute the expert hidden state for\n            # the current expert. We need to make sure to multiply the output hidden\n            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)\n            current_state = hidden_states_cpu[None, top_x].reshape(-1, hidden_dim)\n            current_hidden_states = expert_layer.forward(current_state) * routing_weights_cpu[top_x, idx, None]\n\n            # However `index_add_` only support torch tensors for indexing so we'll use\n            # the `top_x` tensor here.\n            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states_cpu.dtype))\n\n        return final_hidden_states\n\nclass KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):\n    def forward(self, hidden_states):\n        identity = hidden_states\n        orig_shape = hidden_states.shape\n        sequence_length = orig_shape[1]\n        topk_idx, topk_weight, aux_loss = self.gate(hidden_states)\n        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])\n        \n        if sequence_length == 1 and hasattr(self.experts.generate_experts, \"submit_for_one_decode\") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():\n            self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])\n            if self.config.n_shared_experts is not None:\n                y_ = self.shared_experts(identity).squeeze(0)\n            y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)\n            y += y_\n            y.resize_(*orig_shape)\n            return y\n\n        if self.config.n_shared_experts is not None:\n            y_ = self.shared_experts(identity).squeeze(0)\n            \n        if isinstance(self.experts, KExpertsBase):\n  
          y = self.moe_kexperts(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)\n        elif hidden_states.size(0) > 10:\n            # TODO may bugs here\n            y = (\n                self.moe_infer(hidden_states, topk_idx, topk_weight)\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            )\n        else:\n            # TODO may bugs here\n            y = (\n                self.moe_infer_simple(hidden_states, topk_idx, topk_weight)\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            )\n        if self.config.n_shared_experts is not None:\n            y += y_\n        return y\n\n    @torch.no_grad()\n    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:\n        outs = self.experts(x, topk_ids, topk_weight)\n        return outs\n\n    @torch.no_grad()\n    # TODO may bugs here\n    def moe_infer_simple(\n        self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor\n    ) -> torch.Tensor:\n        \"\"\"\n        x: [num_tokens, hidden_size]\n        topk_ids, topk_weight: [num_tokens, num_selected_experts]\n        \"\"\"\n        outs = torch.zeros_like(x)\n        for token_idx in range(topk_ids.size(0)):\n            for expert_idx in range(topk_ids.size(1)):\n                expert = self.experts[topk_ids[token_idx, expert_idx]]\n                outs[token_idx] += (\n                    expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]\n                )\n        return outs\n\n    @torch.no_grad()\n    # TODO may bugs here\n    def moe_infer(self, x, topk_ids, topk_weight):\n        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))\n        cnts.scatter_(1, topk_ids, 1)\n        tokens_per_expert = cnts.sum(dim=0)\n        idxs = topk_ids.view(-1).argsort()\n        sorted_tokens = x[idxs // topk_ids.shape[1]]\n  
      tokens_per_expert = tokens_per_expert.cpu().numpy()\n\n        outputs = []\n        start_idx = 0\n        for i, num_tokens in enumerate(tokens_per_expert):\n            end_idx = start_idx + num_tokens\n            if num_tokens == 0:\n                continue\n            expert = self.experts[i + self.ep_rank * self.experts_per_rank]\n            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n            expert_out = expert.forward(tokens_for_this_expert)\n            outputs.append(expert_out)\n            start_idx = end_idx\n\n        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n        new_x = torch.empty_like(outs)\n        new_x[idxs] = outs\n        final_out = (\n            new_x.view(*topk_ids.shape, -1)\n            .type(topk_weight.dtype)\n            .mul_(topk_weight.unsqueeze(dim=-1))\n            .sum(dim=1)\n            .type(new_x.dtype)\n        )\n        return final_out\n\nclass KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):\n    \n    def forward(self, hidden_states):\n        identity = hidden_states\n        orig_shape = hidden_states.shape\n        sequence_length = orig_shape[1]\n        topk_idx, topk_weight = self.gate(hidden_states)\n        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])\n        \n        # only for generate phase\n        if sequence_length == 1 and hasattr(self.experts.generate_experts, \"submit_for_one_decode\") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():\n            self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])\n            if self.config.n_shared_experts is not None:\n                y_ = self.shared_experts(identity).squeeze(0)\n            y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)\n            y += y_\n            y.resize_(*orig_shape)\n            return y\n\n        if self.config.n_shared_experts is not None:\n    
        y_ = self.shared_experts(identity).squeeze(0)\n            \n        if isinstance(self.experts, KExpertsBase):\n            y = self.moe_kexperts(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)\n        elif hidden_states.size(0) > 10:\n            # TODO may bugs here\n            y = (\n                self.moe_infer(hidden_states, topk_idx, topk_weight)\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            )\n        else:\n            # TODO may bugs here\n            y = (\n                self.moe_infer_simple(hidden_states, topk_idx, topk_weight)\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            )\n        if self.config.n_shared_experts is not None:\n            y += y_\n        return y\n\n\n\n    @torch.no_grad()\n    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:\n        outs = self.experts(x, topk_ids, topk_weight)\n        return outs\n\n    @torch.no_grad()\n    # TODO may bugs here\n    def moe_infer_simple(\n        self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor\n    ) -> torch.Tensor:\n        \"\"\"\n        x: [num_tokens, hidden_size]\n        topk_ids, topk_weight: [num_tokens, num_selected_experts]\n        \"\"\"\n        outs = torch.zeros_like(x)\n        for token_idx in range(topk_ids.size(0)):\n            for expert_idx in range(topk_ids.size(1)):\n                expert = self.experts[topk_ids[token_idx, expert_idx]]\n                outs[token_idx] += (\n                    expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]\n                )\n        return outs\n\n    @torch.no_grad()\n    # TODO may bugs here\n    def moe_infer(self, x, topk_ids, topk_weight):\n        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))\n        cnts.scatter_(1, topk_ids, 1)\n        
tokens_per_expert = cnts.sum(dim=0)\n        idxs = topk_ids.view(-1).argsort()\n        sorted_tokens = x[idxs // topk_ids.shape[1]]\n        tokens_per_expert = tokens_per_expert.cpu().numpy()\n\n        outputs = []\n        start_idx = 0\n        for i, num_tokens in enumerate(tokens_per_expert):\n            end_idx = start_idx + num_tokens\n            if num_tokens == 0:\n                continue\n            expert = self.experts[i + self.ep_rank * self.experts_per_rank]\n            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n            expert_out = expert.forward(tokens_for_this_expert)\n            outputs.append(expert_out)\n            start_idx = end_idx\n\n        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n        new_x = torch.empty_like(outs)\n        new_x[idxs] = outs\n        final_out = (\n            new_x.view(*topk_ids.shape, -1)\n            .type(topk_weight.dtype)\n            .mul_(topk_weight.unsqueeze(dim=-1))\n            .sum(dim=1)\n            .type(new_x.dtype)\n        )\n        return final_out\n\nclass KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock):\n    \n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        \"\"\" \"\"\"\n        orig_shape = hidden_states.shape\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        if self.training and self.jitter_noise > 0:\n            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)\n        hidden_states = hidden_states.view(-1, hidden_dim)\n        # router_logits: (batch * sequence_length, n_experts)\n        router_logits = self.gate(hidden_states)\n\n        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)\n        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)\n        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)\n        # we 
cast back to the input dtype\n        routing_weights = routing_weights.to(hidden_states.dtype)\n        \n        if sequence_length == 1 and hasattr(self.experts.generate_experts, \"submit_for_one_decode\"):\n            self.experts.generate_experts.submit_for_one_decode(hidden_states[0], selected_experts[0], routing_weights[0])\n            y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)\n            y.resize_(*orig_shape)\n            return y, router_logits\n        \n        # Fall back to CPU copies of the original tensors when the experts are not KExpertsBase\n        hidden_states_expert = hidden_states.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else hidden_states.cpu()\n        selected_experts_expert = selected_experts.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else selected_experts.cpu()\n        routing_weights_expert = routing_weights.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else routing_weights.cpu()\n\n        if isinstance(self.experts, KExpertsBase):\n            y = (\n                self.moe_kexperts(\n                    hidden_states_expert, selected_experts_expert, routing_weights_expert\n                )\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            )\n        elif hidden_states_expert.size(0) > 10:\n            y = self.moe_infer(\n                hidden_states_expert, selected_experts_expert, routing_weights_expert, orig_shape\n            ).to(device=hidden_states.device)\n        else:\n            y = self.moe_infer_simple(\n                hidden_states_expert, selected_experts_expert, routing_weights_expert\n            ).to(device=hidden_states.device)\n            \n        y.resize_(*orig_shape)\n        return y, router_logits\n    \n    @torch.no_grad()\n    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:\n        outs = self.experts(x, topk_ids, topk_weight)\n        return outs\n\n    
@torch.no_grad()\n    # TODO: may contain bugs\n    def moe_infer_simple(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:\n        '''\n        hidden_states_cpu: [num_tokens, hidden_size]\n        selected_experts_cpu, routing_weights_cpu: [num_tokens, num_selected_experts]\n        '''\n        outs = torch.zeros_like(hidden_states_cpu)\n        for token_idx in range(selected_experts_cpu.size(0)):\n            for expert_idx in range(selected_experts_cpu.size(1)):\n                expert = self.experts[selected_experts_cpu[token_idx, expert_idx]]\n                outs[token_idx] += expert.forward(hidden_states_cpu[token_idx]) * routing_weights_cpu[token_idx, expert_idx]\n        return outs\n    \n    @torch.no_grad()\n    # TODO: may contain bugs\n    def moe_infer(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor, orig_shape: tuple) -> torch.Tensor:\n        \n        batch_size, sequence_length, hidden_dim = orig_shape\n\n        final_hidden_states = torch.zeros(\n            (batch_size * sequence_length, hidden_dim), dtype=hidden_states_cpu.dtype, device=hidden_states_cpu.device\n        )\n\n        # One-hot encode the selected experts to create an expert mask\n        # this will be used to easily index which expert is going to be solicited\n        expert_mask = torch.nn.functional.one_hot(selected_experts_cpu, num_classes=self.num_experts).permute(2, 1, 0)\n\n        # Loop over all available experts in the model and perform the computation on each expert\n        for expert_idx in range(self.num_experts):\n            expert_layer = self.experts[expert_idx]\n            idx, top_x = torch.where(expert_mask[expert_idx])\n\n            # Index the correct hidden states and compute the expert hidden state for\n            # the current expert. 
We need to make sure to multiply the output hidden\n            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)\n            current_state = hidden_states_cpu[None, top_x].reshape(-1, hidden_dim)\n            current_hidden_states = expert_layer.forward(current_state) * routing_weights_cpu[top_x, idx, None]\n\n            # However `index_add_` only support torch tensors for indexing so we'll use\n            # the `top_x` tensor here.\n            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states_cpu.dtype))\n\n        return final_hidden_states\n\nclass KDeepseekV3MoEV2(BaseInjectedModule, DeepseekV3MoE):\n    def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0):\n        identity = hidden_states\n        orig_shape = hidden_states.shape\n        sequence_length = orig_shape[1]\n        topk_idx, topk_weight = self.gate(hidden_states)\n        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])\n        \n\n        # only for generate phase\n        if hasattr(self.experts.generate_experts, \"submit_for_one_decode\") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug\n            self.experts.generate_experts.submit_for_one_decode(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx)\n            if self.config.n_shared_experts is not None:\n                y_ = self.shared_experts(identity, bsz_tensor).squeeze(0)\n            y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)\n            y += y_\n            y.resize_(*orig_shape)\n            return y\n\n        if self.config.n_shared_experts is not None:\n            y_ = self.shared_experts(identity, bsz_tensor).squeeze(0)\n            \n        if isinstance(self.experts, KExpertsBase):\n            y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight, bsz_tensor, 
cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)\n        elif hidden_states.size(0) > 10:\n            # TODO may bugs here\n            y = (\n                self.moe_infer(hidden_states, topk_idx, topk_weight)\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            )\n        else:\n            # TODO may bugs here\n            y = (\n                self.moe_infer_simple(hidden_states, topk_idx, topk_weight)\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            )\n        if self.config.n_shared_experts is not None:\n            y += y_\n        return y\n\n    @torch.no_grad()\n    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:\n        outs = torch.empty_like(x)\n        outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)\n        return outs\n\n    @torch.no_grad()\n    # TODO may bugs here\n    def moe_infer_simple(\n        self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor\n    ) -> torch.Tensor:\n        \"\"\"\n        x: [num_tokens, hidden_size]\n        topk_ids, topk_weight: [num_tokens, num_selected_experts]\n        \"\"\"\n        outs = torch.zeros_like(x)\n        for token_idx in range(topk_ids.size(0)):\n            for expert_idx in range(topk_ids.size(1)):\n                expert = self.experts[topk_ids[token_idx, expert_idx]]\n                outs[token_idx] += (\n                    expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]\n                )\n        return outs\n\n    @torch.no_grad()\n    # TODO may bugs here\n    def moe_infer(self, x, topk_ids, topk_weight):\n        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))\n        cnts.scatter_(1, topk_ids, 1)\n        tokens_per_expert = cnts.sum(dim=0)\n        idxs = topk_ids.view(-1).argsort()\n        
sorted_tokens = x[idxs // topk_ids.shape[1]]\n        tokens_per_expert = tokens_per_expert.cpu().numpy()\n\n        outputs = []\n        start_idx = 0\n        for i, num_tokens in enumerate(tokens_per_expert):\n            end_idx = start_idx + num_tokens\n            if num_tokens == 0:\n                continue\n            expert = self.experts[i + self.ep_rank * self.experts_per_rank]\n            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n            expert_out = expert.forward(tokens_for_this_expert)\n            outputs.append(expert_out)\n            start_idx = end_idx\n\n        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n        new_x = torch.empty_like(outs)\n        new_x[idxs] = outs\n        final_out = (\n            new_x.view(*topk_ids.shape, -1)\n            .type(topk_weight.dtype)\n            .mul_(topk_weight.unsqueeze(dim=-1))\n            .sum(dim=1)\n            .type(new_x.dtype)\n        )\n        return final_out\n\nclass KTransformersExpertsV2(BaseInjectedModule, KExpertsBase):\n    def __init__(self,\n                 key: str,\n                 gguf_loader: GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                #  device: str = \"cuda\",\n                 prefill_device:str = \"cuda\",\n                 prefill_op: str | None = \"KExpertsTorch\",\n                 generate_device: str = \"cpu\",\n                 generate_op: str | None = \"KExpertsCPU\",\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)\n        KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)\n\n        if prefill_op == 'None':\n            prefill_op = None\n        if generate_op == 'None':\n            generate_op = None\n\n        if generate_op is not None:\n            self.generate_experts = 
EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)\n        else:\n            self.generate_experts = None\n        if prefill_op is not None:\n            self.prefill_experts = EXPERTS_MAP[prefill_op](key, gguf_loader, config, len(orig_module), device=prefill_device, **kwargs)\n        else:\n            self.prefill_experts = None\n        self.gpu_mlp_type = prefill_op\n        self.cpu_mlp_type = generate_op\n        self.mode = InferenceState.UNLOAD\n\n    def load(self, w: dict = None,  mode: InferenceState = None, warmup: bool = True):\n        # TODO support w as input\n        if not mode: mode = InferenceState.GENERATE\n        if mode == InferenceState.GENERATE:\n            self.prefill_experts.unload()\n            self.generate_experts.load(w, warmup=warmup)\n            self.device = self.generate_experts.device\n            self.mode = mode\n        elif mode == InferenceState.PREFILL:\n            self.generate_experts.unload()\n            self.prefill_experts.load(w, warmup=warmup)\n            self.device = self.prefill_experts.device\n            self.mode = mode\n        elif mode == InferenceState.UNLOAD:\n            self.unload()\n            self.mode = mode\n            self.device = self.generate_experts.device\n        else:\n            raise ValueError(\"mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD\")\n\n    def unload(self):\n        if self.generate_experts is not None:\n            self.generate_experts.unload()\n        if self.prefill_experts is not None:\n            self.prefill_experts.unload()\n        self.device = self.generate_experts.device\n\n    def forward(self, input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx=0):\n        if self.mode == InferenceState.GENERATE:\n            assert self.generate_experts is not None, \"generate_experts is None\"\n            return 
self.generate_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)\n        elif self.mode == InferenceState.PREFILL:\n            assert self.prefill_experts is not None, \"prefill_experts is None\"\n            return self.prefill_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)\n        else:\n            raise ValueError(\"load or set_inference_mode before forward\")\n\n    def set_inference_mode(self, mode: InferenceState):\n        if mode == InferenceState.GENERATE:\n            self.load(mode=InferenceState.GENERATE, warmup=False)\n        elif mode == InferenceState.PREFILL:\n            self.load(mode=InferenceState.PREFILL, warmup=False)\n        elif mode == InferenceState.UNLOAD:\n            self.unload()\n        else:\n            raise ValueError(\"mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD\")\n\n\nclass KSmallthinkerExperts(BaseInjectedModule, KExpertsBase):\n    def __init__(self,\n                 key: str,\n                 gguf_loader: GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                #  device: str = \"cuda\",\n                 prefill_device:str = \"cuda\",\n                 prefill_op: str | None = \"KExpertsTorch\",\n                 generate_device: str = \"cpu\",\n                 generate_op: str | None = \"KExpertsCPU\",\n                 **kwargs):\n\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)\n        KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)\n        if generate_op is not None:\n            self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)\n        else:\n            self.generate_experts = None\n        if prefill_op is not None:\n            self.prefill_experts = None\n   
     self.gpu_mlp_type = prefill_op\n        self.cpu_mlp_type = generate_op\n        self.mode = InferenceState.UNLOAD\n\n    def load(self, w: dict = None,  mode: InferenceState = None, warmup: bool = True):\n        # TODO support w as input\n        if not mode: mode = InferenceState.GENERATE\n        if mode == InferenceState.GENERATE:\n            # self.prefill_experts.unload()\n            self.generate_experts.load(w, warmup=warmup)\n            self.device = self.generate_experts.device\n            self.mode = mode\n        elif mode == InferenceState.PREFILL:\n            self.generate_experts.unload()\n            self.prefill_experts.load(w, warmup=warmup)\n            self.device = self.prefill_experts.device\n            self.mode = mode\n        elif mode == InferenceState.UNLOAD:\n            self.unload()\n            self.mode = mode\n            self.device = self.generate_experts.device\n        else:\n            raise ValueError(\"mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD\")\n\n    def unload(self):\n        if self.generate_experts is not None:\n            self.generate_experts.unload()\n        if self.prefill_experts is not None:\n            self.prefill_experts.unload()\n        self.device = self.generate_experts.device\n\n    def forward(self, input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx=0):\n        if self.mode == InferenceState.GENERATE:\n            assert self.generate_experts is not None, \"generate_experts is None\"\n            return self.generate_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)\n        elif self.mode == InferenceState.PREFILL:\n            assert self.prefill_experts is not None, \"prefill_experts is None\"\n            return self.prefill_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)\n        else:\n            raise ValueError(\"load or set_inference_mode before 
forward\")\n\n    def set_inference_mode(self, mode: InferenceState):\n        if mode == InferenceState.GENERATE:\n            self.load(mode=InferenceState.GENERATE, warmup=False)\n        elif mode == InferenceState.PREFILL:\n            self.load(mode=InferenceState.PREFILL, warmup=False)\n        elif mode == InferenceState.UNLOAD:\n            self.unload()\n        else:\n            raise ValueError(\"mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD\")\n\nclass KGlm4Experts(BaseInjectedModule, KExpertsBase):\n    def __init__(self,\n                 key: str,\n                 gguf_loader: GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                #  device: str = \"cuda\",\n                 prefill_device:str = \"cuda\",\n                 prefill_op: str | None = \"KExpertsTorch\",\n                 generate_device: str = \"cpu\",\n                 generate_op: str | None = \"KExpertsCPU\",\n                 **kwargs):\n\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)\n        KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)\n        if generate_op is not None:\n            self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)\n        else:\n            self.generate_experts = None\n        if prefill_op is not None:\n            self.prefill_experts = None\n        self.gpu_mlp_type = prefill_op\n        self.cpu_mlp_type = generate_op\n        self.mode = InferenceState.UNLOAD\n\n    def load(self, w: dict = None,  mode: InferenceState = None, warmup: bool = True):\n        # TODO support w as input\n        if not mode: mode = InferenceState.GENERATE\n        if mode == InferenceState.GENERATE:\n            # self.prefill_experts.unload()\n            
self.generate_experts.load(w, warmup=warmup)\n            self.device = self.generate_experts.device\n            self.mode = mode\n        elif mode == InferenceState.PREFILL:\n            self.generate_experts.unload()\n            self.prefill_experts.load(w, warmup=warmup)\n            self.device = self.prefill_experts.device\n            self.mode = mode\n        elif mode == InferenceState.UNLOAD:\n            self.unload()\n            self.mode = mode\n            self.device = self.generate_experts.device\n        else:\n            raise ValueError(\"mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD\")\n\n    def unload(self):\n        if self.generate_experts is not None:\n            self.generate_experts.unload()\n        if self.prefill_experts is not None:\n            self.prefill_experts.unload()\n        self.device = self.generate_experts.device\n\n    def forward(self, input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx=0):\n        if self.mode == InferenceState.GENERATE:\n            assert self.generate_experts is not None, \"generate_experts is None\"\n            return self.generate_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)\n        elif self.mode == InferenceState.PREFILL:\n            assert self.prefill_experts is not None, \"prefill_experts is None\"\n            return self.prefill_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)\n        else:\n            raise ValueError(\"load or set_inference_mode before forward\")\n\n    def set_inference_mode(self, mode: InferenceState):\n        if mode == InferenceState.GENERATE:\n            self.load(mode=InferenceState.GENERATE, warmup=False)\n        elif mode == InferenceState.PREFILL:\n            self.load(mode=InferenceState.PREFILL, warmup=False)\n        elif mode == InferenceState.UNLOAD:\n            self.unload()\n        else:\n            raise 
ValueError(\"mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD\")\n\n\nclass KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock):\n    def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0):\n\n        orig_shape = hidden_states.shape\n        sequence_length = orig_shape[1]\n\n        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])\n\n        router_logits = self.gate(hidden_states, bsz_tensor)        \n\n        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)\n        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)\n        if self.norm_topk_prob:\n            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)\n        # we cast back to the input dtype\n        routing_weights = routing_weights.to(hidden_states.dtype)\n\n        # only for generate phase\n        if hasattr(self.experts.generate_experts, \"submit_for_one_decode\") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug\n            self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)\n            y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)\n            y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_    \n\n            y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)\n            \n            y += y_\n            y.resize_(*orig_shape)\n            return y\n\n        y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)\n        y_ = (\n            F.sigmoid(self.shared_expert_gate(hidden_states)) * y_\n        )\n\n\n        if isinstance(self.experts, KExpertsBase):\n            y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)\n        elif 
hidden_states.size(0) > 10:\n            # TODO may bugs here\n            y = (\n                self.moe_infer(hidden_states, selected_experts, routing_weights)\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            )\n        else:\n            # TODO may bugs here\n            y = (\n                self.moe_infer_simple(hidden_states, selected_experts, routing_weights)\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            ) \n        y += y_\n        return y\n\n    @torch.no_grad()\n    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:\n        outs = torch.empty_like(x)\n        outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)\n        return outs\n\n    @torch.no_grad()\n    # TODO may bugs here\n    def moe_infer_simple(\n        self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor\n    ) -> torch.Tensor:\n        \"\"\"\n        x: [num_tokens, hidden_size]\n        topk_ids, topk_weight: [num_tokens, num_selected_experts]\n        \"\"\"\n        outs = torch.zeros_like(x)\n        for token_idx in range(topk_ids.size(0)):\n            for expert_idx in range(topk_ids.size(1)):\n                expert = self.experts[topk_ids[token_idx, expert_idx]]\n                outs[token_idx] += (\n                    expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]\n                )\n        return outs\n\n    @torch.no_grad()\n    # TODO may bugs here\n    def moe_infer(self, x, topk_ids, topk_weight):\n        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))\n        cnts.scatter_(1, topk_ids, 1)\n        tokens_per_expert = cnts.sum(dim=0)\n        idxs = topk_ids.view(-1).argsort()\n        sorted_tokens = x[idxs // topk_ids.shape[1]]\n        tokens_per_expert = tokens_per_expert.cpu().numpy()\n\n        
outputs = []\n        start_idx = 0\n        for i, num_tokens in enumerate(tokens_per_expert):\n            end_idx = start_idx + num_tokens\n            if num_tokens == 0:\n                continue\n            expert = self.experts[i + self.ep_rank * self.experts_per_rank]\n            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n            expert_out = expert.forward(tokens_for_this_expert)\n            outputs.append(expert_out)\n            start_idx = end_idx\n\n        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n        new_x = torch.empty_like(outs)\n        new_x[idxs] = outs\n        final_out = (\n            new_x.view(*topk_ids.shape, -1)\n            .type(topk_weight.dtype)\n            .mul_(topk_weight.unsqueeze(dim=-1))\n            .sum(dim=1)\n            .type(new_x.dtype)\n        )\n        return final_out\n\nclass KQwen3MoeSparseMoeBlockV2(BaseInjectedModule, Qwen3MoeSparseMoeBlock):\n    def forward(self, hidden_states, bsz_tensor=None, cuda_graph_idx=0):\n\n        orig_shape = hidden_states.shape\n        sequence_length = orig_shape[1]\n\n        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])\n\n        if bsz_tensor is None:\n            router_logits = self.gate(hidden_states)\n        else:\n            router_logits = self.gate(hidden_states, bsz_tensor)\n\n        if router_logits.device.type == \"xpu\":\n            from ipex_llm.transformers.models.common import moe_softmax_topk\n            selected_experts, routing_weights = moe_softmax_topk(\n                router_logits.half(), self.top_k, self.norm_topk_prob\n            )\n        else:\n            routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)\n            routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)\n            if self.norm_topk_prob:\n                routing_weights /= routing_weights.sum(dim=-1, keepdim=True)\n       
 # we cast back to the input dtype\n        routing_weights = routing_weights.to(hidden_states.dtype)\n\n        # only for generate phase\n        if hasattr(self.experts.generate_experts, \"submit_for_one_decode\") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug\n            self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)\n            # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)\n            # y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_    \n\n            y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)\n            \n            # y += y_\n            y.resize_(*orig_shape)\n            return y\n\n        # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)\n        # y_ = (\n        #     F.sigmoid(self.shared_expert_gate(hidden_states)) * y_\n        # )\n\n\n        if isinstance(self.experts, KExpertsBase):\n            y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)\n        elif hidden_states.size(0) > 10:\n            # TODO may bugs here\n            y = (\n                self.moe_infer(hidden_states, selected_experts, routing_weights)\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            )\n        else:\n            # TODO may bugs here\n            y = (\n                self.moe_infer_simple(hidden_states, selected_experts, routing_weights)\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            ) \n        # y += y_\n        return y\n\n    @torch.no_grad()\n    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:\n        outs = torch.empty_like(x)\n  
      outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)\n        return outs\n\n    @torch.no_grad()\n    # TODO may bugs here\n    def moe_infer_simple(\n        self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor\n    ) -> torch.Tensor:\n        \"\"\"\n        x: [num_tokens, hidden_size]\n        topk_ids, topk_weight: [num_tokens, num_selected_experts]\n        \"\"\"\n        outs = torch.zeros_like(x)\n        for token_idx in range(topk_ids.size(0)):\n            for expert_idx in range(topk_ids.size(1)):\n                expert = self.experts[topk_ids[token_idx, expert_idx]]\n                outs[token_idx] += (\n                    expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]\n                )\n        return outs\n\n    @torch.no_grad()\n    # TODO may bugs here\n    def moe_infer(self, x, topk_ids, topk_weight):\n        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))\n        cnts.scatter_(1, topk_ids, 1)\n        tokens_per_expert = cnts.sum(dim=0)\n        idxs = topk_ids.view(-1).argsort()\n        sorted_tokens = x[idxs // topk_ids.shape[1]]\n        tokens_per_expert = tokens_per_expert.cpu().numpy()\n\n        outputs = []\n        start_idx = 0\n        for i, num_tokens in enumerate(tokens_per_expert):\n            end_idx = start_idx + num_tokens\n            if num_tokens == 0:\n                continue\n            expert = self.experts[i + self.ep_rank * self.experts_per_rank]\n            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n            expert_out = expert.forward(tokens_for_this_expert)\n            outputs.append(expert_out)\n            start_idx = end_idx\n\n        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n        new_x = torch.empty_like(outs)\n        new_x[idxs] = outs\n        final_out = (\n            new_x.view(*topk_ids.shape, -1)\n            .type(topk_weight.dtype)\n            
.mul_(topk_weight.unsqueeze(dim=-1))\n            .sum(dim=1)\n            .type(new_x.dtype)\n        )\n        return final_out\n\n\nclass KSmallthinkerMoeBlock(BaseInjectedModule, SmallthinkerMoeBlock):\n    def forward(self, router_input: torch.Tensor, hidden_states: torch.Tensor, bsz_tensor=None, cuda_graph_idx=0):\n\n        orig_shape = hidden_states.shape\n        sequence_length = orig_shape[1]\n\n        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])\n\n        if bsz_tensor is None:\n            if self.enable_early_router:\n                router_logits = self.primary_router(router_input)\n            else:\n                router_logits = self.primary_router(hidden_states)\n        else:\n            if self.enable_early_router:\n                router_logits = self.primary_router(router_input, bsz_tensor)\n            else:\n                router_logits = self.primary_router(hidden_states, bsz_tensor)\n\n        router_logits, selected_experts = torch.topk(router_logits, self.num_active_primary_experts, dim=-1)\n\n\n        if router_logits.device.type == \"xpu\":\n            # TODO: support self.moe_primary_router_apply_softmax False case\n            from ipex_llm.transformers.models.common import moe_softmax_topk\n            selected_experts, routing_weights = moe_softmax_topk(\n                router_logits.half(), self.top_k, self.norm_topk_prob\n            )\n        else:\n            if self.moe_primary_router_apply_softmax:\n                routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)\n            else:\n                routing_weights = F.sigmoid(router_logits)\n                routing_weights /= routing_weights.sum(dim=-1, keepdim=True)\n        # we cast back to the input dtype\n        routing_weights = routing_weights.to(hidden_states.dtype)\n\n        # only for generate phase\n        if hasattr(self.experts.generate_experts, \"submit_for_one_decode\") and torch.cuda.is_available() and 
torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug\n            self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)\n            # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)\n            # y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_    \n\n            y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)\n            \n            # y += y_\n            y.resize_(*orig_shape)\n            return y\n\n        # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)\n        # y_ = (\n        #     F.sigmoid(self.shared_expert_gate(hidden_states)) * y_\n        # )\n\n\n        if isinstance(self.experts, KExpertsBase):\n            y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)\n        elif hidden_states.size(0) > 10:\n            # TODO may bugs here\n            y = (\n                self.moe_infer(hidden_states, selected_experts, routing_weights)\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            )\n        else:\n            # TODO may bugs here\n            y = (\n                self.moe_infer_simple(hidden_states, selected_experts, routing_weights)\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            ) \n        # y += y_\n        return y\n\n    @torch.no_grad()\n    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:\n        outs = torch.empty_like(x)\n        outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)\n        return outs\n\n    @torch.no_grad()\n    # TODO may bugs here\n    def moe_infer_simple(\n        self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: 
torch.Tensor\n    ) -> torch.Tensor:\n        \"\"\"\n        x: [num_tokens, hidden_size]\n        topk_ids, topk_weight: [num_tokens, num_selected_experts]\n        \"\"\"\n        outs = torch.zeros_like(x)\n        for token_idx in range(topk_ids.size(0)):\n            for expert_idx in range(topk_ids.size(1)):\n                expert = self.experts[topk_ids[token_idx, expert_idx]]\n                outs[token_idx] += (\n                    expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]\n                )\n        return outs\n\n    @torch.no_grad()\n    # TODO may bugs here\n    def moe_infer(self, x, topk_ids, topk_weight):\n        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))\n        cnts.scatter_(1, topk_ids, 1)\n        tokens_per_expert = cnts.sum(dim=0)\n        idxs = topk_ids.view(-1).argsort()\n        sorted_tokens = x[idxs // topk_ids.shape[1]]\n        tokens_per_expert = tokens_per_expert.cpu().numpy()\n\n        outputs = []\n        start_idx = 0\n        for i, num_tokens in enumerate(tokens_per_expert):\n            end_idx = start_idx + num_tokens\n            if num_tokens == 0:\n                continue\n            expert = self.experts[i + self.ep_rank * self.experts_per_rank]\n            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n            expert_out = expert.forward(tokens_for_this_expert)\n            outputs.append(expert_out)\n            start_idx = end_idx\n\n        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n        new_x = torch.empty_like(outs)\n        new_x[idxs] = outs\n        final_out = (\n            new_x.view(*topk_ids.shape, -1)\n            .type(topk_weight.dtype)\n            .mul_(topk_weight.unsqueeze(dim=-1))\n            .sum(dim=1)\n            .type(new_x.dtype)\n        )\n        return final_out\n\n\nclass KGlm4MoeMoE(BaseInjectedModule, Glm4MoeMoE):\n    def forward(self, hidden_states, bsz_tensor=None, 
cuda_graph_idx=0):\n\n        orig_shape = hidden_states.shape\n        sequence_length = orig_shape[1]\n\n        topk_idx, topk_weight = self.gate(hidden_states)\n        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])\n\n        # only for generate phase\n        if hasattr(self.experts.generate_experts, \"submit_for_one_decode\") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug\n            self.experts.generate_experts.submit_for_one_decode(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx)\n            y_ = self.shared_experts(hidden_states, bsz_tensor).squeeze(0)\n            # y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_    \n\n            y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)\n            \n            y += y_\n            y.resize_(*orig_shape)\n            return y\n\n        y_ = self.shared_experts(hidden_states, bsz_tensor).squeeze(0)\n        # y_ = (\n        #     F.sigmoid(self.shared_expert_gate(hidden_states)) * y_\n        # )\n\n\n        if isinstance(self.experts, KExpertsBase):\n            y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)\n        elif hidden_states.size(0) > 10:\n            # TODO may bugs here\n            y = (\n                self.moe_infer(hidden_states, topk_idx, topk_weight)\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            )\n        else:\n            # TODO may bugs here\n            y = (\n                self.moe_infer_simple(hidden_states, topk_idx, topk_weight)\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            ) \n        y += y_\n        return y\n\n    @torch.no_grad()\n    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, 
bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:\n        outs = torch.empty_like(x)\n        outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)\n        return outs\n\n    @torch.no_grad()\n    # TODO may bugs here\n    def moe_infer_simple(\n        self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor\n    ) -> torch.Tensor:\n        \"\"\"\n        x: [num_tokens, hidden_size]\n        topk_ids, topk_weight: [num_tokens, num_selected_experts]\n        \"\"\"\n        outs = torch.zeros_like(x)\n        for token_idx in range(topk_ids.size(0)):\n            for expert_idx in range(topk_ids.size(1)):\n                expert = self.experts[topk_ids[token_idx, expert_idx]]\n                outs[token_idx] += (\n                    expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]\n                )\n        return outs\n\n    @torch.no_grad()\n    # TODO may bugs here\n    def moe_infer(self, x, topk_ids, topk_weight):\n        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))\n        cnts.scatter_(1, topk_ids, 1)\n        tokens_per_expert = cnts.sum(dim=0)\n        idxs = topk_ids.view(-1).argsort()\n        sorted_tokens = x[idxs // topk_ids.shape[1]]\n        tokens_per_expert = tokens_per_expert.cpu().numpy()\n\n        outputs = []\n        start_idx = 0\n        for i, num_tokens in enumerate(tokens_per_expert):\n            end_idx = start_idx + num_tokens\n            if num_tokens == 0:\n                continue\n            expert = self.experts[i + self.ep_rank * self.experts_per_rank]\n            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n            expert_out = expert.forward(tokens_for_this_expert)\n            outputs.append(expert_out)\n            start_idx = end_idx\n\n        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n        new_x = torch.empty_like(outs)\n        new_x[idxs] = outs\n        final_out = (\n         
   new_x.view(*topk_ids.shape, -1)\n            .type(topk_weight.dtype)\n            .mul_(topk_weight.unsqueeze(dim=-1))\n            .sum(dim=1)\n            .type(new_x.dtype)\n        )\n        return final_out\n\n\nclass KQwen3NextSparseMoeBlockV2(BaseInjectedModule, Qwen3NextSparseMoeBlock):\n    def forward(self, hidden_states, bsz_tensor=None, cuda_graph_idx=0):\n\n        orig_shape = hidden_states.shape\n        sequence_length = orig_shape[1]\n\n        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])\n\n        if bsz_tensor is None:\n            router_logits = self.gate(hidden_states)\n        else:\n            router_logits = self.gate(hidden_states, bsz_tensor)\n\n        if router_logits.device.type == \"xpu\":\n            from ipex_llm.transformers.models.common import moe_softmax_topk\n            selected_experts, routing_weights = moe_softmax_topk(\n                router_logits.half(), self.top_k, self.norm_topk_prob\n            )\n        else:\n            routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)\n            routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)\n            if self.norm_topk_prob:\n                routing_weights /= routing_weights.sum(dim=-1, keepdim=True)\n        # we cast back to the input dtype\n        routing_weights = routing_weights.to(hidden_states.dtype)\n\n        if self.norm_topk_prob:\n            routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)\n\n        # only for generate phase\n        if hasattr(self.experts.generate_experts, \"submit_for_one_decode\") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug\n            self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)\n            y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)\n      
      y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_    \n\n            y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)\n            \n            y += y_\n            y.resize_(*orig_shape)\n            return y\n\n        y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)\n        y_ = (\n            F.sigmoid(self.shared_expert_gate(hidden_states)) * y_\n        )\n\n\n        if isinstance(self.experts, KExpertsBase):\n            y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)\n        elif hidden_states.size(0) > 10:\n            # TODO may bugs here\n            y = (\n                self.moe_infer(hidden_states, selected_experts, routing_weights)\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            )\n        else:\n            # TODO may bugs here\n            y = (\n                self.moe_infer_simple(hidden_states, selected_experts, routing_weights)\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            ) \n        y += y_\n        return y\n\n    @torch.no_grad()\n    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:\n        outs = torch.empty_like(x)\n        outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)\n        return outs\n\n    @torch.no_grad()\n    # TODO may bugs here\n    def moe_infer_simple(\n        self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor\n    ) -> torch.Tensor:\n        \"\"\"\n        x: [num_tokens, hidden_size]\n        topk_ids, topk_weight: [num_tokens, num_selected_experts]\n        \"\"\"\n        outs = torch.zeros_like(x)\n        for token_idx in range(topk_ids.size(0)):\n            for expert_idx in range(topk_ids.size(1)):\n   
             expert = self.experts[topk_ids[token_idx, expert_idx]]\n                outs[token_idx] += (\n                    expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]\n                )\n        return outs\n\n    @torch.no_grad()\n    # TODO may bugs here\n    def moe_infer(self, x, topk_ids, topk_weight):\n        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))\n        cnts.scatter_(1, topk_ids, 1)\n        tokens_per_expert = cnts.sum(dim=0)\n        idxs = topk_ids.view(-1).argsort()\n        sorted_tokens = x[idxs // topk_ids.shape[1]]\n        tokens_per_expert = tokens_per_expert.cpu().numpy()\n\n        outputs = []\n        start_idx = 0\n        for i, num_tokens in enumerate(tokens_per_expert):\n            end_idx = start_idx + num_tokens\n            if num_tokens == 0:\n                continue\n            expert = self.experts[i + self.ep_rank * self.experts_per_rank]\n            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n            expert_out = expert.forward(tokens_for_this_expert)\n            outputs.append(expert_out)\n            start_idx = end_idx\n\n        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n        new_x = torch.empty_like(outs)\n        new_x[idxs] = outs\n        final_out = (\n            new_x.view(*topk_ids.shape, -1)\n            .type(topk_weight.dtype)\n            .mul_(topk_weight.unsqueeze(dim=-1))\n            .sum(dim=1)\n            .type(new_x.dtype)\n        )\n        return final_out"
  },
  {
    "path": "archive/ktransformers/operators/flashinfer_batch_prefill_wrapper.py",
    "content": "import torch\nimport flashinfer\nimport gc\ntry:\n    from flash_attn import flash_attn_with_kvcache\n    print(\"found flash_attn\")\n    \nexcept ImportError:\n    print(\"flash_attn not found, flashinfer unit test needed it. If you are using balance serve, ignore this.\")\n\nfrom typing import Union, Optional\n\ndef setup_seed(seed):\n\ttorch.manual_seed(seed)\n\ttorch.cuda.manual_seed_all(seed)\n\nsetup_seed(998244353)\n\ntry:\n    import torch_npu\n    use_torch_npu = torch_npu.npu.is_available()\nexcept:\n    use_torch_npu = False\n\nif not use_torch_npu:\n\ttorch.set_grad_enabled(False)\n\ttorch.set_default_dtype(torch.bfloat16)\n\tglobal_dtype=torch.bfloat16\n\tglobal_device=torch.device(\"cuda\",0)\n\ttorch.cuda.set_device(0)\n\ttorch.backends.cudnn.enabled =True\n\ttorch.backends.cudnn.benchmark = True\n\nclass flashInferAttn():\n\t\n\tfloat_workspace_buffer = None\n\tdef __init__(self,\n\t\t\tmax_batch_token,\n\t\t\tmax_batch_size,\n\t\t\tmax_pages,\n\t\t\tdevice = \"cuda:0\",\n\t\t\tkv_layout: str = \"NHD\",\n\t\t\tuse_cuda_graph: bool = False,\n\t\t\t) -> None:\n\t\tself.device = device\n\t\tself.max_batch_token = max_batch_token\n\t\tself.kv_layout = kv_layout\n\t\tself.use_cuda_graph = use_cuda_graph\n\t\tif flashInferAttn.float_workspace_buffer is None:\n\t\t\tflashInferAttn.float_workspace_buffer = torch.empty(max_batch_token * 1024 * 1024, dtype=torch.uint8, device=device)\n\t\tself.qo_indptr_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device)\n\t\tself.paged_kv_indptr_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device)\n\t\tself.paged_kv_indices_buf = torch.empty((max_pages,), dtype=torch.int32, device=device)\n\t\tself.paged_kv_last_page_len_buf = torch.empty((max_batch_size,), dtype=torch.int32, device=device)\n\t\tself.batch_size_tensor_buf = torch.empty((1,), dtype=torch.int32, device=device)\n\t\tself.num_tokens_tensor_buf = torch.empty((1,), dtype=torch.uint32, 
device=device)\n\t\n\t\t# TODO: custom mask\n\t\tself.custom_mask_buf = None\n\t\tself.qk_indptr_buf = None\n\t\tself.warpper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(\n\t\t\tflashInferAttn.float_workspace_buffer,\n\t\t\tself.kv_layout,\n\t\t\tuse_cuda_graph=self.use_cuda_graph,\n\t\t\tqo_indptr_buf=self.qo_indptr_buf,\n\t\t\tpaged_kv_indptr_buf=self.paged_kv_indptr_buf,\n\t\t\tpaged_kv_indices_buf=self.paged_kv_indices_buf,\n\t\t\tpaged_kv_last_page_len_buf=self.paged_kv_last_page_len_buf,\n\t\t\tbackend = \"fa2\",\n\t\t)\n\n\tdef plan(self,\n\t\tqo_indptr: torch.Tensor,\n\t\tpaged_kv_indptr: torch.Tensor,\n\t\tpaged_kv_indices: torch.Tensor,\n\t\tpaged_kv_last_page_len: torch.Tensor,\n\t\tbatch_size_tensor: torch.Tensor,\n\t\tnum_tokens_tensor: torch.Tensor,\n\t\tnum_qo_heads: int,\n\t\tnum_kv_heads: int,\n\t\thead_dim: int,\n\t\tpage_size: int,\n\t\tcausal: bool = True, \n\t\tpos_encoding_mode: str = \"NONE\",\n\t\tq_data_type: Union[str, torch.dtype] = torch.bfloat16,\n\t\tkv_data_type: Optional[Union[str, torch.dtype]] = None):\n\t\t\n\t\tself.batch_size_tensor_buf.copy_(batch_size_tensor, non_blocking=True)\n\t\tself.num_tokens_tensor_buf.copy_(num_tokens_tensor, non_blocking=True)\n\t\tself.page_size = page_size\n\t\tself.warpper.plan(\n\t\t\tqo_indptr,\n\t\t\tpaged_kv_indptr,\n\t\t\tpaged_kv_indices,\n\t\t\tpaged_kv_last_page_len,\n\t\t\tnum_qo_heads,\n\t\t\tnum_kv_heads,\n\t\t\thead_dim,\n\t\t\tpage_size,\n\t\t\tcausal = causal,\n\t\t\tpos_encoding_mode = pos_encoding_mode,\n\t\t\tq_data_type = q_data_type,\n\t\t\tkv_data_type = kv_data_type\n\t\t\t)\n\n\tdef calc_batch_indices(self, ragged_size = None):\n\t\tif self.use_cuda_graph:\n\t\t\tself.batch_indices, self.positions = flashinfer.get_batch_indices_positions(\n\t\t\t\tself.qo_indptr_buf, flashinfer.get_seq_lens(self.paged_kv_indptr_buf, self.paged_kv_last_page_len_buf, self.page_size), self.batch_size_tensor_buf, self.max_batch_token)\n\t\telse:\n\t\t\tself.batch_indices, self.positions = 
flashinfer.get_batch_indices_positions(\n\t\t\t\tself.warpper._qo_indptr_buf, flashinfer.get_seq_lens(self.warpper._paged_kv_indptr_buf, self.warpper._paged_kv_last_page_len_buf, self.page_size), self.batch_size_tensor_buf, ragged_size)\n\n\tdef forward(self, q, k_cache, v_cache, k, v):\n\t\tif self.use_cuda_graph:\n\t\t\tflashinfer.page.append_paged_kv_cache(k, v, self.batch_indices, self.positions, (k_cache, v_cache), self.paged_kv_indices_buf, self.paged_kv_indptr_buf, self.paged_kv_last_page_len_buf, self.num_tokens_tensor_buf)\n\t\t\treturn self.warpper.run(q, (k_cache, v_cache))\n\t\telse:\n\t\t\tflashinfer.page.append_paged_kv_cache(k, v, self.batch_indices, self.positions, (k_cache, v_cache), self.warpper._paged_kv_indices_buf, self.warpper._paged_kv_indptr_buf, self.warpper._paged_kv_last_page_len_buf, self.num_tokens_tensor_buf)\n\t\t\treturn self.warpper.run(q, (k_cache, v_cache))\n\n\ndef testCudaGraph():\n\t\n\t# use max batch to create buffer\n\tbatch_decode = 8\n\tprefill_chunk = 48\n\tpast_kv_0 = 4090\n\tpast_kv_1 = 4096\n\traged_size = prefill_chunk + batch_decode\n\tnum_key_value_heads = 8\n\thead_dim = 128\n\tnum_attention_heads = 64\n\tpage_size = 256\n\tnum_pages_per_seq = (past_kv_1 + page_size - 1) // page_size\n\ttotal_num_pages = (num_pages_per_seq + 1) * (batch_decode + 1) + prefill_chunk // page_size\n\tattn = flashInferAttn(raged_size, batch_decode+1, total_num_pages, use_cuda_graph=True)\n\n\tbatch_size_tensor = torch.tensor([batch_decode + 1], device=global_device, dtype=torch.int32)\n\t\n\tk_caches = []\t\n\tv_caches = []\n\tks = []\n\tvs = []\n\tqs = []\n\tfor layer_idx in range(3):\n\t\tk_caches.append(torch.randn(total_num_pages, page_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16))\n\t\tv_caches.append(torch.randn(total_num_pages, page_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16))\n\t\tks.append(torch.randn(raged_size, num_key_value_heads, head_dim, 
device=global_device, dtype=torch.bfloat16))\n\t\tvs.append(torch.randn(raged_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16))\n\t\tqs.append(torch.randn(raged_size, num_attention_heads, head_dim, device=global_device, dtype=torch.bfloat16))\n\t\n\t# warmup and capture small batch\n\tpast_kv_0 = 250\n\tpast_kv_1 = 256\n\tnum_pages_per_seq = (past_kv_1 + page_size - 1) // page_size\n\ttotal_num_pages = (num_pages_per_seq + 1) * (batch_decode + 1) + prefill_chunk // page_size\n\tq_indptr = torch.empty((batch_decode + 2,), dtype=torch.int32, device=global_device)\n\tq_indptr[0] = 0\n\tq_indptr[1:] = torch.arange(prefill_chunk, prefill_chunk + batch_decode + 1, device=global_device, dtype=torch.int32)\n\tkv_indptr = torch.arange(0, batch_decode + 2, device=global_device, dtype=torch.int32) * num_pages_per_seq\n\tkv_indices = torch.arange(0, total_num_pages, device=global_device, dtype=torch.int32)\n\tkv_last_page_len = torch.empty((batch_decode + 1,), dtype=torch.int32, device=global_device)\n\tkv_last_page_len[:1+batch_decode//2] = int((past_kv_0 - 1) % page_size + 1)\n\tkv_last_page_len[1+batch_decode//2:] = int((past_kv_1 - 1) % page_size + 1)\n\t# plan() takes num_tokens_tensor right after batch_size_tensor\n\tnum_tokens_tensor = torch.tensor([batch_decode + prefill_chunk], device=global_device, dtype=torch.int32)\n\n\tprint(q_indptr)\n\tprint(kv_indptr)\n\tprint(kv_indices)\n\tprint(kv_last_page_len)\n\tattn.plan(q_indptr,\n\t\t\tkv_indptr,\n\t\t\tkv_indices,\n\t\t\tkv_last_page_len,\n\t\t\tbatch_size_tensor,\n\t\t\tnum_tokens_tensor,\n\t\t\tnum_attention_heads,\n\t\t\tnum_key_value_heads,\n\t\t\thead_dim,\n\t\t\tpage_size,\n\t\t\tcausal = True,\n\t\t\tpos_encoding_mode=\"NONE\",\n\t\t\tq_data_type=torch.bfloat16)\n\n\tattn.calc_batch_indices(raged_size)\n\tfor layer_idx in range(3):\n\t\tattn.forward(qs[layer_idx], k_caches[layer_idx], v_caches[layer_idx], ks[layer_idx], vs[layer_idx])\n\t\ttorch.cuda.synchronize()\n\n\touts = []\n\tg = torch.cuda.CUDAGraph()\n\twith torch.cuda.graph(g):\n\t\tfor layer_idx in range(3):\n\t\t\touts.append(attn.forward(qs[layer_idx], k_caches[layer_idx], v_caches[layer_idx], ks[layer_idx], 
vs[layer_idx]))\n\tg.replay()\n\t\n\tkv_last_page_len[:1+batch_decode//2] = int(past_kv_0)\n\tkv_last_page_len[1+batch_decode//2:] = int(past_kv_1)\n\tfor layer_idx in range(3):\n\t\tfor i in range(batch_decode + 1):\n\t\t\t\n\t\t\tqi = qs[layer_idx][q_indptr[i] : q_indptr[i + 1]]\n\t\t\to_ref_i = flash_attn_with_kvcache(\n\t\t\t\tqi.unsqueeze(0),\n\t\t\t\tk_caches[layer_idx],\n\t\t\t\tv_caches[layer_idx],\n\t\t\t\tcausal=True,\n\t\t\t\tblock_table=kv_indices[kv_indptr[i]:kv_indptr[i+1]].unsqueeze(0),\n\t\t\t\tcache_seqlens=torch.tensor([past_kv_0 if i < 1+batch_decode//2 else past_kv_1], device=global_device, dtype=torch.int32)\n\t\t\t)\n\t\t\to_i = outs[layer_idx][q_indptr[i] : q_indptr[i + 1]]\n\t\t\tprint(layer_idx, i)\n\t\t\ttorch.testing.assert_close(o_i.unsqueeze(0), o_ref_i, rtol=5e-3, atol=5e-3)\n\n\t# run another batch size use capture cuda graph\n\tpast_kv_0 = 4090\n\tpast_kv_1 = 4096\n\tprefill_chunk = 24\n\tbatch_decode = 4\n\tnum_pages_per_seq = (past_kv_1 + page_size - 1) // page_size\n\ttotal_num_pages = (num_pages_per_seq + 1) * (batch_decode + 1) + prefill_chunk // page_size\n\tbatch_size_tensor = torch.tensor([batch_decode + 1], device=global_device, dtype=torch.int32)\n\tnum_tokens_tensor = torch.tensor([batch_decode + prefill_chunk], device=global_device, dtype=torch.int32)\n\n\tq_indptr = torch.empty((batch_decode + 2,), dtype=torch.int32, device=global_device)\n\tq_indptr[0] = 0\n\tq_indptr[1:] = torch.arange(prefill_chunk, prefill_chunk + batch_decode + 1, device=global_device, dtype=torch.int32)\n\tkv_indptr = torch.arange(0, batch_decode + 2, device=global_device, dtype=torch.int32) * num_pages_per_seq\n\tkv_indices = torch.arange(0, total_num_pages, device=global_device, dtype=torch.int32)\n\tkv_last_page_len = torch.empty((batch_decode + 1,), dtype=torch.int32, device=global_device)\n\tkv_last_page_len[:1+batch_decode//2] = int((past_kv_0 - 1) % page_size + 1)\n\tkv_last_page_len[1+batch_decode//2:] = int((past_kv_1 - 1) % page_size + 
1)\n\tattn.plan(q_indptr,\n\t\t\tkv_indptr,\n\t\t\tkv_indices,\n\t\t\tkv_last_page_len,\n\t\t\tbatch_size_tensor,\n\t\t\tnum_tokens_tensor,\n\t\t\tnum_attention_heads,\n\t\t\tnum_key_value_heads,\n\t\t\thead_dim,\n\t\t\tpage_size,\n\t\t\tcausal = True,\n\t\t\tpos_encoding_mode=\"NONE\",\n\t\t\tq_data_type=torch.bfloat16)\n\tattn.calc_batch_indices(raged_size)\n\tg.replay()\n\t\n\tkv_last_page_len[:1+batch_decode//2] = int(past_kv_0)\n\tkv_last_page_len[1+batch_decode//2:] = int(past_kv_1)\n\tfor layer_idx in range(3):\n\t\tfor i in range(batch_decode + 1):\n\t\t\t\n\t\t\tqi = qs[layer_idx][q_indptr[i] : q_indptr[i + 1]]\n\t\t\to_ref_i = flash_attn_with_kvcache(\n\t\t\t\tqi.unsqueeze(0),\n\t\t\t\tk_caches[layer_idx],\n\t\t\t\tv_caches[layer_idx],\n\t\t\t\tcausal=True,\n\t\t\t\tblock_table=kv_indices[kv_indptr[i]:kv_indptr[i+1]].unsqueeze(0),\n\t\t\t\tcache_seqlens=torch.tensor([past_kv_0 if i < 1+batch_decode//2 else past_kv_1], device=global_device, dtype=torch.int32)\n\t\t\t)\n\t\t\to_i = outs[layer_idx][q_indptr[i] : q_indptr[i + 1]]\n\t\t\tprint(layer_idx, i)\n\t\t\ttorch.testing.assert_close(o_i.unsqueeze(0), o_ref_i, rtol=5e-3, atol=5e-3)\n\t\t\t\n\n\ndef testAttentionFlashInfer(\t\n\t):\n\tbatch_decode = 32\n\tprefill_chunk = 64\n\tpast_kv_0 = 510\n\tpast_kv_1 = 512\n\traged_size = prefill_chunk + batch_decode\n\tnum_key_value_heads = 8\n\thead_dim = 128\n\tnum_attention_heads = 64\n\tcases = 1\n\tpage_size = 32\n\tnum_pages_per_seq = (past_kv_1 + page_size - 1) // page_size\n\ttotal_num_pages = (num_pages_per_seq + 1) * (batch_decode + 1) + prefill_chunk // page_size\n\tworkspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=\"cuda:0\")\n\tqs = []\n\tkvs = []\n\tq_indptrs = []\n\tkv_indptrs = []\n\tkv_indicess = []\n\tkv_last_page_lens = []\n\twrappers = []\n\tfor case_id in range(cases):\n\t\tkvs.append(torch.randn(total_num_pages, 2, page_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16))\n\t\tqs.append(torch.randn(raged_size, 
num_attention_heads, head_dim, device=global_device, dtype=torch.bfloat16))\n\t\tq_indptr = torch.empty((batch_decode + 2,), dtype=torch.int32, device=global_device)\n\t\tq_indptr[0] = 0\n\t\tq_indptr[1:] = torch.arange(prefill_chunk, prefill_chunk + batch_decode + 1, device=global_device, dtype=torch.int32)\n\t\tq_indptrs.append(q_indptr)\n\t\tkv_indptrs.append(torch.arange(0, batch_decode + 2, device=global_device, dtype=torch.int32) * num_pages_per_seq)\n\t\tkv_indicess.append(torch.arange(0, total_num_pages, device=global_device, dtype=torch.int32))\n\t\tkv_last_page_len = torch.empty((batch_decode + 1,), dtype=torch.int32, device=global_device)\n\t\tkv_last_page_len[:1+batch_decode//2] = int((past_kv_0 - 1) % page_size + 1)\n\t\tkv_last_page_len[1+batch_decode//2:] = int((past_kv_1 - 1) % page_size + 1)\n\t\tkv_last_page_lens.append(kv_last_page_len)\n\t\twrappers.append(flashinfer.BatchPrefillWithPagedKVCacheWrapper(\n\t\t\tworkspace_buffer,\n\t\t\t\"NHD\",\n\t\t\tuse_cuda_graph=True,\n\t\t\tqo_indptr_buf=q_indptrs[case_id],\n\t\t\tpaged_kv_indptr_buf=kv_indptrs[case_id],\n\t\t\tpaged_kv_indices_buf=kv_indicess[case_id],\n\t\t\tpaged_kv_last_page_len_buf=kv_last_page_lens[case_id],\n\t\t))\n\t\twrappers[case_id].plan(\n\t\t\tq_indptrs[case_id],\n\t\t\tkv_indptrs[case_id],\n\t\t\tkv_indicess[case_id],\n\t\t\tkv_last_page_lens[case_id],\n\t\t\tnum_attention_heads,\n\t\t\tnum_key_value_heads,\n\t\t\thead_dim,\n\t\t\tpage_size,\n\t\t\tcausal = True,\n\t\t\tpos_encoding_mode=\"ROPE_LLAMA\",\n\t\t\tq_data_type=torch.bfloat16\n\t\t)\n\t\t\t\t\t\n\tdef custom_forward(case_id):\n\t\tout = wrappers[case_id].run(qs[case_id], kvs[case_id])\n\t\n\tcustom_forward(0)\n\n# testCudaGraph()\n# pass"
  },
  {
    "path": "archive/ktransformers/operators/flashinfer_wrapper.py",
    "content": "'''\nDescription  : flashinfer MLA wrapper\nAuthor       : Boxin Zhang\nVersion      : 0.2.3\n'''\nimport torch\nimport os\n\nflashinfer_enabled = False\n\ntry:\n    import flashinfer\n    flashinfer_enabled = True\n    print(\"found flashinfer\")\n    \nexcept ImportError:\n    print(\"flashinfer not found, use triton for linux\")\n\ntry:\n    import torch_npu\n    use_torch_npu = torch_npu.npu.is_available()\nexcept:\n    use_torch_npu = False\n\nif not use_torch_npu:\n    from ktransformers.operators.triton_attention import decode_attention_fwd_grouped\n\nimport math\n\ndef attention_ref_torch(\n    batch_size,\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    causal: bool,\n    sm_scale: float,\n) -> torch.Tensor:\n    qo_len = q.shape[0] // batch_size\n    kv_len = k.shape[0] // batch_size\n    num_qo_heads = q.shape[1]\n    head_dim_qk = q.shape[2]\n    head_dim_vo = v.shape[2]\n    logits = (\n        torch.einsum(\n            \"bmhd,bnhd->bhmn\",\n            q.view(batch_size, qo_len, num_qo_heads, head_dim_qk).float(),\n            k.view(batch_size, kv_len, num_qo_heads, head_dim_qk).float(),\n        )\n        * sm_scale\n    )\n\n    #print(\"attn weights\", logits)\n\n    if causal:\n        mask = (\n            torch.arange(kv_len - qo_len, kv_len).unsqueeze(1)\n            >= torch.arange(0, kv_len).unsqueeze(0)\n        ).to(q.device)\n    else:\n        mask = torch.ones(qo_len, kv_len).to(q.device)\n\n    logits = logits.masked_fill(mask.unsqueeze(0).unsqueeze(0) == 0, float(\"-inf\"))\n    lse_ref = torch.logsumexp(logits, -1).transpose(-1, -2)\n    p = torch.softmax(logits, dim=-1)\n    o_ref = (\n        torch.einsum(\n            \"bhmn,bnhd->bmhd\",\n            p,\n            v.view(batch_size, kv_len, num_qo_heads, head_dim_vo).float(),\n        )\n        .contiguous()\n        .view(batch_size * qo_len, num_qo_heads, head_dim_vo)\n        .to(q)\n    )\n\n    return o_ref, lse_ref * 
math.log2(math.e)\n\nclass MLAWrapper():\n    def __init__(self,\n                 max_batch_size,\n                 max_pages,\n                 use_cuda_graph = True,\n                 device = \"cuda\",\n                 ):\n        self.float_workspace_buffer = torch.empty(128*1024*1024, dtype=torch.int8, device=device)\n        self.max_batch_size = max_batch_size\n        self.max_pages = max_pages\n        if use_cuda_graph:\n            if self.max_batch_size == 1:\n                self.qo_indptr_buf = torch.arange(0, max_batch_size+1, dtype=torch.int32, device=device)\n                self.kv_indptr_buf = torch.tensor([0, max_pages], dtype=torch.int32, device=device)\n                self.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device)\n            else:\n                self.qo_indptr_buf = torch.empty(max_batch_size+1, dtype=torch.int32, device=device)\n                self.kv_indptr_buf = torch.empty(max_batch_size+1, dtype=torch.int32, device=device)\n                self.kv_indices_buf = torch.empty(max_pages, dtype=torch.int32, device=device)\n            self.batch_size_tensor_buf = torch.tensor([self.max_batch_size], dtype=torch.int32, device=device)\n            self.kv_len_arr_buf = torch.empty(max_batch_size, dtype=torch.int32, device=device)\n        else:\n            self.qo_indptr_buf = None\n            self.kv_indptr_buf = None\n            self.kv_indices_buf = None\n            self.kv_len_arr_buf = None\n        self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(\n            self.float_workspace_buffer,\n            use_cuda_graph=use_cuda_graph,\n            qo_indptr=self.qo_indptr_buf,\n            kv_indptr=self.kv_indptr_buf,\n            kv_indices=self.kv_indices_buf,\n            kv_len_arr=self.kv_len_arr_buf,\n            bsz_tensor=self.batch_size_tensor_buf,\n            backend = \"fa2\",\n        )\n        self.need_plan = True\n\n    \n    def plan(self,\n             qo_indptr,\n  
           kv_indptr,\n             kv_indices,\n             kv_len_arr,\n             bsz_tensor,\n             num_heads,\n             head_dim_ckv,\n             head_dim_kpe,\n             page_size,\n             sm_scale,\n             q_data_type,\n             kv_data_type,\n             ):\n        if qo_indptr is None:\n            assert self.max_batch_size == 1\n            qo_indptr = self.qo_indptr_buf\n        if kv_indptr is None:\n            assert self.max_batch_size == 1\n            kv_indptr = self.kv_indptr_buf\n        if kv_indices is None:\n            assert self.max_batch_size == 1\n            kv_indices = self.kv_indices_buf\n        if bsz_tensor is None:\n            assert self.max_batch_size == 1\n            bsz_tensor = self.batch_size_tensor_buf\n        \n        self.wrapper.plan(\n            qo_indptr,\n            kv_indptr,\n            kv_indices,\n            kv_len_arr,\n            num_heads,\n            head_dim_ckv,\n            head_dim_kpe,\n            page_size,\n            True, # causal\n            sm_scale,\n            q_data_type,\n            kv_data_type,\n            bsz_tensor\n        )\n\n    def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):\n        return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse)\n\nclass MLAWrapperSingleton():\n    wrappers:dict = {}\n\n    @classmethod\n    def get_instance(cls, device, *args, **kwargs)->MLAWrapper:\n        if device not in cls.wrappers:\n            cls.make_instance(device, *args, **kwargs)\n        return cls.wrappers[device]\n    \n    @classmethod\n    def make_instance(cls, device, *args, **kwargs):\n        cls.wrappers[device] = MLAWrapper(*args, **kwargs, device=device)\n\n    @classmethod\n    def plan_all(cls, qo_indptr,\n             kv_indptr,\n             kv_indices,\n             kv_len_arr,\n             bsz_tensor,\n             num_heads,\n             head_dim_ckv,\n             head_dim_kpe,\n            
 page_size,\n             sm_scale,\n             q_data_type,\n             kv_data_type,):\n        for device, wrapper in cls.wrappers.items():\n            kv_len_arr_cur_device = kv_len_arr.to(device)\n            wrapper.plan(qo_indptr,\n                kv_indptr,\n                kv_indices,\n                kv_len_arr_cur_device,\n                bsz_tensor,\n                num_heads,\n                head_dim_ckv,\n                head_dim_kpe,\n                page_size,\n                sm_scale,\n                q_data_type,\n                kv_data_type,)\n            wrapper.need_plan = False\n            \n    @classmethod\n    def need_plan_all(cls):\n        for device, wrapper in cls.wrappers.items():\n            wrapper.need_plan = True\n        \n    @classmethod\n    def reset_buffer(cls):\n        for device, wrapper in cls.wrappers.items():\n            wrapper.qo_indptr_buf[1] = 1 # assert max_batch_size=1 here.\n            \n    @classmethod\n    def update_buffer(cls, max_pages):\n        for device, wrapper in cls.wrappers.items():\n            wrapper.kv_indptr_buf[1] = max_pages # assert max_batch_size=1 here.\n            wrapper.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device)\n            wrapper.wrapper._kv_indices_buf = wrapper.kv_indices_buf\n\ndef checksame():\n    flashinfer_folder = \"./flashinfer_output\"\n    flashinfer_folder = \"./kv_cache_flashinfer\"\n    triton_folder = \"./triton_output\"\n    triton_folder = \"./kv_cache_triton\"\n    \n    max_layer_id = 1\n    max_forward_id = 2\n\n    for forward_id in range(0, 19):\n        print(\"forward_id\", forward_id)\n        for layer_id in range(max_layer_id):\n            print(layer_id)\n            #file_name = f\"layer_{layer_id}_forward_{forward_id}_attn_output.pt\"\n            #file_name = f\"layer_{layer_id}_forward_{forward_id}_q_pe.pt\"\n            file_name = f\"layer_{layer_id}.pt\"\n            \n            flashinfer_path = 
os.path.join(flashinfer_folder, file_name)\n            triton_path = os.path.join(triton_folder, file_name)\n            \n            if not os.path.exists(triton_path):\n                print(f\"{file_name} does not exist in {triton_folder}\")\n                continue\n            if not os.path.exists(flashinfer_path):\n                print(f\"{file_name} does not exist in {flashinfer_folder}\")\n                continue\n            \n            \n            flashinfer_tensor = torch.load(flashinfer_path)[1:2, :62]#\n            triton_tensor = torch.load(triton_path)[1:2, :62]#.squeeze(1)#\n            try:\n                torch.testing.assert_close(flashinfer_tensor, triton_tensor, rtol=1e-9, atol=1e-9)\n            except AssertionError as e:\n                print(e)\n\nif __name__ == \"__main__\":\n    \n    #checksame()\n    #exit(0)\n\n    max_batch_size = 2\n    max_batch_tokens = 256\n    max_pages = 128\n    page_size = 64\n    num_heads = 128\n    \n    # warm-up\n    kv_len = 4023\n    q_len = 1\n    q_nope_buf = torch.randn((max_batch_tokens, num_heads, 512), dtype=torch.bfloat16, device=\"cuda\")\n    q_pe_buf = torch.randn((max_batch_tokens, num_heads, 64), dtype=torch.bfloat16, device=\"cuda\")\n    kv_buf = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device=\"cuda\")\n    ckv, k_pe = torch.split(kv_buf, [512, 64], dim=-1)\n    \n\n    wrapper = MLAWrapperSingleton.get_instance(\n        \"cuda\",\n        max_batch_size,\n        max_pages,\n    )\n    \n    used_pages = (kv_len + page_size - 1)// page_size\n    kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device=\"cuda\")\n    qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device=\"cuda\")\n    kv_indptr = torch.tensor([0, used_pages], dtype=torch.int32, device=\"cuda\")\n    kv_indices = torch.empty(max_pages, dtype=torch.int32, device=\"cuda\")\n    kv_indices[:used_pages] = torch.arange(0, used_pages, dtype=torch.int32, device=\"cuda\")\n    bsz_tensor = 
torch.tensor([1], dtype=torch.int32, device=\"cuda\")\n    wrapper.plan(\n        qo_indptr,\n        kv_indptr,\n        kv_indices,\n        kv_len_arr,\n        bsz_tensor,\n        128,\n        512,\n        64,\n        page_size,\n        192 ** (-0.5),\n        torch.bfloat16,\n        torch.bfloat16,\n    )\n\n    attn_output = wrapper.run(q_nope_buf[:q_len], q_pe_buf[:q_len], ckv, k_pe)\n    print(attn_output.shape)\n    graph = torch.cuda.CUDAGraph()\n    with torch.cuda.graph(graph):\n        attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)\n    graph.replay()\n\n    q = torch.cat([q_nope_buf, q_pe_buf], dim=-1)\n    k = (\n        torch.cat([ckv, k_pe], dim=-1)\n        .view(-1, 1, 512 + 64)\n        .repeat_interleave(num_heads, dim=1)\n    )\n    v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)\n    attn_ref, lse_ref = attention_ref_torch(\n        1,\n        q[:q_len],\n        k[:kv_len],\n        v[:kv_len],\n        True,\n        192 ** (-0.5)\n    )\n    torch.testing.assert_close(attn_output[:q_len], attn_ref, rtol=5e-3, atol=5e-3)\n    # warm-up finished\n\n    kv_len = 512\n    q_len = 128\n    pages = max_pages\n    used_pages = (kv_len + page_size - 1)// page_size\n    q_nope = torch.randn((q_len*2, num_heads, 512), dtype=torch.bfloat16, device=\"cuda\")\n    q_nope[q_len:] = q_nope[:q_len]\n    q_pe = torch.randn((q_len*2, num_heads, 64), dtype=torch.bfloat16, device=\"cuda\")\n    q_pe[q_len:] = q_pe[:q_len]\n    kv_cache = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device=\"cuda\")\n    kv_cache[used_pages:2*used_pages] = kv_cache[:used_pages]\n    ckv, k_pe = torch.split(kv_cache, [512, 64], dim=-1)\n    \n    kv_len_arr = torch.tensor([kv_len, kv_len], dtype=torch.int32, device=\"cuda\")\n    qo_indptr = torch.tensor([0, q_len, q_len*2], dtype=torch.int32, device=\"cuda\")\n    kv_indptr = torch.tensor([0, used_pages, used_pages*2], dtype=torch.int32, device=\"cuda\")\n    kv_indices = 
torch.empty(max_pages, dtype=torch.int32, device=\"cuda\")\n    kv_indices[:2*used_pages] = torch.arange(0, 2*used_pages, dtype=torch.int32, device=\"cuda\")\n    bsz_tensor = torch.tensor([2], dtype=torch.int32, device=\"cuda\")\n    wrapper.plan(\n        qo_indptr,\n        kv_indptr,\n        kv_indices,\n        kv_len_arr,\n        bsz_tensor,\n        128,\n        512,\n        64,\n        page_size,\n        192 ** (-0.5),\n        torch.bfloat16,\n        torch.bfloat16,\n    )\n    \n    q_nope_buf.copy_(q_nope)\n    q_pe_buf.copy_(q_pe)\n    kv_buf[:pages].copy_(kv_cache)\n\n    torch.cuda.synchronize()\n    graph.replay()\n    torch.cuda.synchronize()\n\n    # ref_torch\n    q = torch.cat([q_nope, q_pe], dim=-1)\n    k = (\n        torch.cat([ckv, k_pe], dim=-1)\n        .view(-1, 1, 512 + 64)\n        .repeat_interleave(num_heads, dim=1)\n    )\n    v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)\n    attn_ref, lse_ref = attention_ref_torch(\n        max_batch_size,\n        q,\n        k[:2*kv_len],\n        v[:2*kv_len],\n        True,\n        192 ** (-0.5)\n    )\n    \n    torch.testing.assert_close(attn_ref[:q_len], attn_ref[q_len:q_len*2], rtol=1e-9, atol=1e-9)\n    torch.testing.assert_close(attn_output[:q_len], attn_output[q_len:q_len*2], rtol=1e-9, atol=1e-9)\n    torch.testing.assert_close(attn_output[:q_len], attn_ref[:q_len], rtol=5e-3, atol=5e-3)\n    torch.testing.assert_close(attn_output[q_len:q_len*2], attn_ref[q_len:q_len*2], rtol=5e-3, atol=5e-3)\n    #torch.testing.assert_close(attn_output[:q_len], attn_output[q_len:q_len*2], rtol=1e-9, atol=1e-9)\n    #torch.testing.assert_close(attn_output, attn_ref, rtol=5e-3, atol=5e-3)\n\n    exit(0)\n\n    for forward_id in range(0, 1):\n        print(\"forward_id\", forward_id)\n        for layer_id in range(1):\n            print(layer_id)\n            flashinfer_folder = \"./kv_cache_flashinfer\"\n            forward_id = 17\n            layer_id = 0\n            file_name = 
f\"layer_{layer_id}.pt\"\n            kv_cache_path = os.path.join(flashinfer_folder, file_name)\n            flashinfer_folder = \"./flashinfer_output\"\n\n            q_len = 1\n            kv_len = 126\n            file_name = f\"layer_{layer_id}_forward_{forward_id}_q_nope.pt\"\n            q_nope = torch.load(os.path.join(flashinfer_folder, file_name)).view(q_len,128,512).to(device=\"cuda\")\n            file_name = f\"layer_{layer_id}_forward_{forward_id}_q_pe.pt\"\n            q_pe = torch.load(os.path.join(flashinfer_folder, file_name)).view(q_len,128,64).to(device=\"cuda\")\n            q = torch.cat([q_nope, q_pe], dim=-1)\n            kv_cache = torch.load(kv_cache_path).to(device=\"cuda\")\n            pages, page_size, _, head_dim = kv_cache.shape\n            kv_cache = kv_cache.view(pages, page_size, head_dim)\n            ckv, k_pe = torch.split(kv_cache, [512, 64], dim=-1)\n    \n            kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device=\"cuda\")\n            qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device=\"cuda\")\n            wrapper.plan(\n                None,\n                None,\n                None,\n                kv_len_arr,\n                128,\n                512,\n                64,\n                page_size,\n                192 ** (-0.5),\n                torch.bfloat16,\n                torch.bfloat16,\n            )\n    \n            q_nope_buf.copy_(q_nope)\n            q_pe_buf.copy_(q_pe)\n            kv_buf[:pages].copy_(kv_cache)\n\n            torch.cuda.synchronize()\n            graph.replay()\n            torch.cuda.synchronize()\n\n            # ref_torch\n            k = (\n                torch.cat([ckv, k_pe], dim=-1)\n                .view(-1, 1, 512 + 64)\n                .repeat_interleave(num_heads, dim=1)\n            )\n            v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)\n            attn_ref, lse_ref = attention_ref_torch(\n                
max_batch_size,\n                q,\n                k[:kv_len],\n                v[:kv_len],\n                False,\n                192 ** (-0.5)\n            )\n            torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)\n    \n            # ref_triton\n            attn_logits = torch.empty(\n                    (\n                        max_batch_size,\n                        num_heads,\n                        4, #num_kv_splits # follow vLLM, fix it TODO\n                        512 + 1, \n                    ),\n                    dtype=torch.float32,\n                    device = \"cuda\"\n                )\n            \n            triton_ref = torch.zeros_like(q_nope)\n            page_table = torch.arange(max_pages, dtype=torch.int32, device=\"cuda\")\n            ckv_with_pe = torch.cat([ckv, k_pe], dim=-1).contiguous().view(pages, page_size, 1, 576)\n            ckv = ckv.view(pages, page_size, 1, 512)\n            decode_attention_fwd_grouped(q, ckv_with_pe, ckv, triton_ref,\n                page_table,\n                kv_len_arr, attn_logits,\n                4, #num_kv_splits # follow vLLM, fix it TODO\n                192 ** (-0.5),\n                page_size)\n\n            torch.testing.assert_close(attn_output, triton_ref, rtol=1e-3, atol=1e-3)\n            \n            #file_name = f\"./flashinfer_output/layer_{layer_id}_forward_{forward_id}_attn_output.pt\"\n            #ktrans_output = torch.load(file_name)\n            #torch.testing.assert_close(attn_output, ktrans_output.squeeze(1), rtol=1e-3, atol=1e-3)\n            print(\"test passed\")"
  },
  {
    "path": "archive/ktransformers/operators/gate.py",
    "content": "from typing import Optional\nfrom torch import nn\nimport torch\nimport torch.nn.functional as F\nimport os\nfrom ktransformers.operators.base_operator import BaseInjectedModule\nfrom ktransformers.operators.base_operator import BaseInjectedModule\nfrom ktransformers.operators.linear import KTransformersLinear\nfrom ktransformers.util.custom_loader import GGUFLoader, ModelLoader, SafeTensorLoader, translate_name_to_gguf\nfrom transformers.configuration_utils import PretrainedConfig\nfrom abc import ABC, abstractmethod\n\n\n# class Base(BaseInjectedModule, ABC):\nclass KMoEGateBase(ABC):\n    def __init__(self, \n                 key: str, \n                 gguf_loader: GGUFLoader, \n                 config: PretrainedConfig, \n                 orig_module: nn.Module, \n                 device: str = \"cuda\", \n                 **kwargs):\n        # super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)\n        super().__init__()\n        self.key = key\n        self.gguf_loader = gguf_loader\n        self.config = config\n        self.device = device\n        self.orig_module = orig_module\n    \n    @abstractmethod\n    def forward(self, input_tensor, expert_ids, weights):\n        pass\n\n    @abstractmethod\n    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str = \"cpu\", warmup: bool = False):\n        pass\n    \n    @abstractmethod\n    def unload():\n        pass\n\n    def load_weights(self, override_key: str | None = None, device: str = \"cpu\"):\n        res = {}\n        if override_key is not None:\n            keys = override_key\n        else:\n            keys = [self.key]\n\n        gate = None\n        up = None\n        down = None\n        gate_type = None\n        up_type = None\n        down_type = None\n\n        for key in keys:\n            if self.gguf_loader.safetensor_loader is not None:\n                # for npu\n                translate_key = translate_name_to_gguf(key)\n    
            translate_key = \".\".join(translate_key.split(\".\")[:2])\n                targets = [\".ffn_gate_inp.weight\", \".exp_probs_b.bias\"]\n                weight = self.gguf_loader.safetensor_loader.load_tensor(translate_key + \".ffn_gate_inp.weight\")\n                e_score_correction_bias = self.gguf_loader.safetensor_loader.load_tensor(translate_key + \".exp_probs_b.bias\")\n                weight_type = weight.dtype\n                e_score_correction_bias_type = e_score_correction_bias.dtype\n                res = {\"weight\": weight, \"e_score_correction_bias\": e_score_correction_bias, \"weight_type\": weight_type, \"e_score_correction_bias_type\": e_score_correction_bias_type}\n            # key = \".\".join(key.split(\".\")[:-1])\n            elif isinstance(self.gguf_loader, SafeTensorLoader):\n                res = self.gguf_loader.load_gate(key, device=device)\n            elif self.gguf_loader.has_tensor(key+\".weight\"):\n                # targets = [\".ffn_gate_inp.weight\", \".exp_probs_b.bias\"]\n                targets = [\".weight\", \".e_score_correction_bias\"]\n                tensors = self.load_multi(key, targets, device=device)\n                weight = tensors[\".weight\"]\n                e_score_correction_bias = tensors[\".e_score_correction_bias\"]\n                # weight_type = self.gguf_loader.tensor_info[key + \".weight\"][\"ggml_type\"]\n                res = {\"weight\": weight, \"e_score_correction_bias\": e_score_correction_bias}\n            else:\n                raise ValueError(f\"Experts {key} not found in gguf_loader\")\n\n        return res\n    \n    def load_multi(self, key: str, keys: list[str], device: str = \"cpu\"):\n        tensors = {}\n        for k in keys:\n            tensors[k] = self.gguf_loader.load_gguf_tensor(key + k, device=device)\n        return tensors\n\n\nclass KMoEGate(BaseInjectedModule, KMoEGateBase):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: 
GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module = None,\n        generate_device: str = \"cuda\",\n        prefill_device: str = \"cuda\",\n        **kwargs,\n    ):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)\n        KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)\n        self.generate_device = generate_device\n        self.prefill_device = prefill_device\n\n    def forward(self, hidden_states) -> torch.Tensor:\n        return self.orig_module.forward(hidden_states)\n\n    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):\n        if device is None: device = self.device\n        if w is None: w = self.load_weights(device=device)\n        \n        if isinstance(w, dict):\n            self.orig_module.weight = nn.Parameter(w[\"weight\"])\n            self.orig_module.e_score_correction_bias = nn.Parameter(w[\"e_score_correction_bias\"])\n        else:\n            raise ValueError(\"Invalid weight type\")\n        self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device))\n        self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device))\n\n    def unload(self):\n        if self.weight is not None:\n            self.weight = None\n        if self.e_score_correction_bias is not None:\n            self.e_score_correction_bias = None\n\n\nclass KMoEGateQwen2Moe(BaseInjectedModule, KMoEGateBase):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module = None,\n        generate_device: str = \"cuda\",\n        generate_op: str| None = \"KLinearMarlin\",\n        prefill_device: str = \"cuda\",\n        prefill_op: str| None = \"KLinearMarlin\",\n        use_quant: bool = False,\n        **kwargs,\n    ):\n        
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)\n        KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)\n        self.generate_device = generate_device\n        self.prefill_device = prefill_device\n        self.generate_op = generate_op\n        self.prefill_op = prefill_op\n        self.is_windows = os.name == 'nt'\n        self.use_quant = use_quant\n        if not self.is_windows and use_quant:\n            self.gate_linear = nn.Linear(self.gating_dim, self.n_routed_experts, device=generate_device)\n            self.gate_linear = KTransformersLinear(key + \".ffn_gate_inp\", \n                                               gguf_loader, config, self.gate_linear, #orig_module\n                                               generate_device, generate_op, prefill_device, prefill_op)\n        else:\n            self.gate_linear = None\n\n    def forward(self, hidden_states) -> torch.Tensor:\n        if self.is_windows:\n            return self.orig_module.forward(hidden_states)\n        \n        bsz, seq_len, h = hidden_states.shape\n        ### compute gating score\n        hidden_states = hidden_states.view(-1, h)\n        if self.use_quant:\n            # feed the flattened hidden states to the quantized gate (logits is not defined yet at this point)\n            logits = self.gate_linear.forward(hidden_states)\n        else:\n            logits = F.linear(\n                hidden_states.type(torch.float32), self.weight.type(torch.float32), None\n            )\n            \n        return grouped_topk(hidden_states, logits,\n                            self.top_k, self.norm_topk_prob,\n                            self.n_group, self.topk_group)\n\n    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):\n        if device is None: device = self.device\n        if w is None: w = self.load_weights(device=device)\n        \n        if isinstance(w, dict):\n            self.orig_module.weight = nn.Parameter(w[\"weight\"])\n            
self.orig_module.e_score_correction_bias = nn.Parameter(w[\"e_score_correction_bias\"])\n        else:\n            raise ValueError(\"Invalid weight type\")\n        self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device))\n        self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device))\n        if not self.is_windows and self.use_quant:\n            self.gate_linear.load(self.orig_module.weight)\n\n    def unload(self):\n        if self.weight is not None:\n            self.weight = None\n        if self.e_score_correction_bias is not None:\n            self.e_score_correction_bias = None\n\n\nclass KMoEGateIPEXLLM(KMoEGate):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module = None,\n        generate_device: str = \"xpu\",\n        prefill_device: str = \"xpu\",\n        **kwargs,\n    ):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)\n        KMoEGate.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)\n        self.generate_device = generate_device\n        self.prefill_device = prefill_device\n\n    def forward(self, hidden_states) -> torch.Tensor:\n        x = hidden_states.view(-1, hidden_states.size(-1))\n        logits = torch.nn.functional.linear(\n            x.type(torch.float32), self.orig_module.weight.type(torch.float32), None\n        )\n        scores = logits.sigmoid()\n\n        from ipex_llm.transformers.models.common import moe_group_topk\n        topk_idx, topk_weight = moe_group_topk(scores, self.orig_module.e_score_correction_bias,\n                                               self.n_group, self.topk_group, self.top_k,\n                                               self.norm_topk_prob, self.routed_scaling_factor)\n        return topk_idx, 
topk_weight.to(x.dtype)\n\n"
  },
  {
    "path": "archive/ktransformers/operators/layernorm.py",
    "content": "'''\nDate: 2024-11-13 15:05:52\nLastEditors: Xie Weiyu ervinxie@qq.com\nLastEditTime: 2024-11-25 08:59:19\n'''\n\"\"\"\nCopyright 2023-2024 SGLang Team\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\n\"\"\"Fused operators for normalization layers.\"\"\"\n\nimport logging\nfrom typing import Optional, Tuple, Union\nfrom transformers import PretrainedConfig\nimport torch\nimport torch.nn as nn\nfrom ktransformers.models.modeling_deepseek_v3 import DeepseekV3RMSNorm\nfrom ktransformers.models.modeling_qwen2_moe import Qwen2MoeRMSNorm\nfrom ktransformers.models.modeling_qwen3_moe import Qwen3MoeRMSNorm\nfrom ktransformers.models.modeling_qwen3_next import Qwen3NextRMSNorm\nfrom ktransformers.models.modeling_smallthinker import SmallthinkerRMSNorm\nfrom ktransformers.models.modeling_glm4_moe import Glm4MoeRMSNorm\nfrom ktransformers.operators.base_operator import BaseInjectedModule\nfrom ktransformers.util.custom_loader import GGUFLoader\nif not torch.xpu.is_available():\n    from flashinfer.norm import (\n        fused_add_rmsnorm,\n        rmsnorm,\n    )\n\n\nlogger = logging.getLogger(__name__)\n\n\nclass RMSNorm(DeepseekV3RMSNorm, BaseInjectedModule):\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 
**kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        self.orig_module.__init__(orig_module.hidden_size,\n            orig_module.variance_epsilon)\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        batch_size_tensor: torch.Tensor = None,\n        residual: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        #return self.forward_native(x, residual)\n        if batch_size_tensor is None:\n            return self.forward_native(x)\n        if residual is not None:\n            fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)\n            #residual = x + residual\n            #out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon)\n            return x, residual\n        # print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous())\n        out = rmsnorm(x, self.weight.data, batch_size_tensor,self.variance_epsilon)\n        return out\n\n    def forward_native(\n        self, hidden_states    \n    ):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n\nclass KQwen2MoeRMSNorm(Qwen2MoeRMSNorm, BaseInjectedModule):\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, 
config, orig_module, prefill_device, **kwargs)\n        self.orig_module.__init__(config.hidden_size,\n            orig_module.variance_epsilon)\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        batch_size_tensor: torch.Tensor = None,\n        residual: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        #return self.forward_native(x, residual)\n        if batch_size_tensor is None:\n            return self.forward_native(x)\n        if residual is not None:\n            fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)\n            #residual = x + residual\n            #out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon)\n            return x, residual\n        # print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous())\n        out = rmsnorm(x, self.weight.data, batch_size_tensor,self.variance_epsilon)\n        return out\n\n    def forward_native(\n        self, hidden_states    \n    ):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n\nclass KQwen3MoeRMSNorm(Qwen3MoeRMSNorm, BaseInjectedModule):\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        
self.orig_module.__init__(orig_module.hidden_size,\n            orig_module.variance_epsilon)\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        batch_size_tensor: torch.Tensor = None,\n        residual: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        #return self.forward_native(x, residual)\n        bsz, hidden_size = x.shape\n        x = x.view(-1, self.orig_module.hidden_size)\n        if batch_size_tensor is None:\n            return self.forward_native(x)\n        if residual is not None:\n            fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)\n            #residual = x + residual\n            #out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon)\n            return x, residual\n        # print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous())\n        out = rmsnorm(x, self.weight.data, batch_size_tensor,self.variance_epsilon)\n        out = out.view(bsz, hidden_size)\n        return out\n\n    def forward_native(\n        self, hidden_states    \n    ):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n    \nclass KQwen3NextRMSNorm(Qwen3NextRMSNorm, BaseInjectedModule):\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, 
gguf_loader, config, orig_module, prefill_device, **kwargs)\n        self.orig_module.__init__(orig_module.hidden_size,\n            orig_module.variance_epsilon)\n\n    def _norm(self, x):\n            return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)\n\n    def forward(self, x, num_tokens_tensors, residual = None):\n        if residual is not None:\n            x = x + residual\n            residual = x\n        x = x.view(-1, self.orig_module.hidden_size)\n        output = self._norm(x.float())\n        # Llama does x.to(float16) * w whilst Qwen3Next is (x * w).to(float16)\n        # See https://github.com/huggingface/transformers/pull/29402\n        output = output * (1.0 + self.weight.float())\n        if residual is None:\n            return output.type_as(x)\n\n        return output.type_as(x), residual\n\n    def extra_repr(self):\n        return f\"{tuple(self.weight.shape)}, eps={self.eps}\"\n\n\nclass KSmallthinkerRMSNorm(SmallthinkerRMSNorm, BaseInjectedModule):\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        self.orig_module.__init__(orig_module.hidden_size,\n            orig_module.variance_epsilon)\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        batch_size_tensor: torch.Tensor = None,\n        residual: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        #return self.forward_native(x, residual)\n        bsz, hidden_size = x.shape\n        x = x.view(-1, self.orig_module.hidden_size)\n        if batch_size_tensor is None:\n            return self.forward_native(x)\n        if residual 
is not None:\n            fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)\n            #residual = x + residual\n            #out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon)\n            return x, residual\n        # print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous())\n        out = rmsnorm(x, self.weight.data, batch_size_tensor,self.variance_epsilon)\n        out = out.view(bsz, hidden_size)\n        return out\n\n    def forward_native(\n        self, hidden_states    \n    ):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\nclass KGlm4MoeRMSNorm(Glm4MoeRMSNorm, BaseInjectedModule):\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        self.orig_module.__init__(orig_module.hidden_size,\n            orig_module.variance_epsilon)\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        batch_size_tensor: torch.Tensor = None,\n        residual: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        #return self.forward_native(x, residual)\n        bsz, hidden_size = x.shape\n        x = x.view(-1, self.orig_module.hidden_size)\n        if batch_size_tensor is None:\n         
   return self.forward_native(x)\n        if residual is not None:\n            fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)\n            #residual = x + residual\n            #out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon)\n            return x, residual\n        # print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous())\n        out = rmsnorm(x, self.weight.data, batch_size_tensor,self.variance_epsilon)\n        out = out.view(bsz, hidden_size)\n        return out\n\n    def forward_native(\n        self, hidden_states    \n    ):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n\n\nclass DeepseekV3RMSNormTorch(DeepseekV3RMSNorm, BaseInjectedModule):\n    def __init__(self,\n                key: str,\n                gguf_loader : GGUFLoader,\n                config: PretrainedConfig,\n                orig_module: nn.Module,\n                prefill_device: str = \"cuda\",\n                generate_device: str = \"cuda\",\n                **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        self.orig_module.__init__(orig_module.hidden_size,\n            orig_module.variance_epsilon)\n\n    def forward(\n        self, \n        x,\n        batch_size_tensor: torch.Tensor = None,\n        residual: Optional[torch.Tensor] = None,\n    )-> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        if residual is not None:\n            x = x + residual\n            residual = x\n        # range batch_size_tensor for x\n        
input_dtype = x.dtype\n        x = x.to(torch.float32)\n        variance = x.pow(2).mean(-1, keepdim=True)\n        x = x * torch.rsqrt(variance + self.variance_epsilon)\n        if residual is not None:\n            return self.weight * x.to(input_dtype), residual\n        return self.weight * x.to(input_dtype)\n\n\nclass KDeepseekRMSNormIPEXLLM(DeepseekV3RMSNorm, BaseInjectedModule):\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"xpu\",\n                 generate_device: str = \"xpu\",\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        self.orig_module.__init__(orig_module.weight.shape[0],\n            orig_module.variance_epsilon)\n        self.eps = orig_module.variance_epsilon\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        from ipex_llm.transformers.models.common import rms_norm_forward\n        if x.dtype not in [torch.float32, torch.float16]:\n            output = rms_norm_forward(self, x.float())\n        else:\n            output = rms_norm_forward(self, x)\n        return output.to(x.dtype)\n\n    def load(self):\n        BaseInjectedModule.load(self)\n        if self.weight.dtype not in [torch.float32, torch.float16]:\n            self.weight = self.weight.float()"
  },
  {
    "path": "archive/ktransformers/operators/linear.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : Azure-Tang, Boxin Zhang\nDate         : 2024-07-25 11:25:24\nVersion      : 0.1.0\nLastEditors  : Azure \nLastEditTime : 2024-08-29 09:11:16\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\n\n\nimport ctypes\nimport torch\nfrom torch import Tensor, nn\n\ntry:\n    import torch_npu\n\n    use_torch_npu = torch_npu.npu.is_available()\nexcept:\n    use_torch_npu = False\n\nif not torch.xpu.is_available() and not use_torch_npu:\n    import KTransformersOps\n    import vLLMMarlin\nfrom ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader\nfrom ktransformers.util.utils import InferenceState\nif not torch.xpu.is_available():\n    from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import (\n        MarlinWorkspace,\n        marlin_quantize,\n        GPTQ_MARLIN_MIN_THREAD_N,\n        GPTQ_MARLIN_MIN_THREAD_K,\n        GPTQ_MARLIN_MAX_PARALLEL,\n        vllm_marlin_quantize\n    )\nfrom ktransformers.operators.base_operator import BaseInjectedModule\nfrom transformers.configuration_utils import PretrainedConfig\ntry:\n    from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant\nexcept:\n    print(\"no triton\")\nfrom abc import ABC, abstractmethod\nimport sys, os\nsys.path.append(os.path.join(os.path.dirname(__file__), \"..\", \"ktransformers_ext\", \"build\"))\nsys.path.append(os.path.join(os.path.dirname(__file__), \"..\", \"ktransformers_ext\", \"build\", \"Release\"))\nsys.path.append(os.path.join(os.path.dirname(__file__), \"..\", \"ktransformers_ext\", \"build\", \"Debug\"))\nimport cpuinfer_ext\nfrom ktransformers.operators.cpuinfer import CPUInfer\nfrom ktransformers.server.config.config import Config\nfrom typing import Dict, Tuple, Optional, Union\nimport numpy as np\n\n#class KLinearBase(BaseInjectedModule, ABC):\nclass KLinearBase(ABC):\n    def __init__(\n        
self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module = None,\n        device: str = \"cuda\",\n        **kwargs,\n    ):\n        # super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)\n        super().__init__()\n        self.key = key\n        self.gguf_loader = gguf_loader\n        self.device = device\n        self.config = config\n\n        self.has_bias = False\n        self.dtype = torch.get_default_dtype()\n        if orig_module is not None:\n            self.in_features = orig_module.in_features\n            self.out_features = orig_module.out_features\n        else:\n            shape = self.gguf_loader.tensor_info[key + \".weight\"][\"shape\"]\n            if len(shape) == 1:\n                print(\"Warning: orig_module is not set, but has in_features or out_features equals to 1, can't get in_features and out_features from GGUF\")\n            self.in_features  = self.gguf_loader.tensor_info[key + \".weight\"][\"shape\"][0]\n            self.out_features = self.gguf_loader.tensor_info[key + \".weight\"][\"shape\"][1]\n\n        self.loaded = False # for lm_head pre-load, TODO: use new way to do lm_head pre-load when layer wise prefill.\n\n    @abstractmethod\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        pass\n\n    def load_weight(self, override_key: str | None = None, device: str | None = None):\n        if override_key is not None:\n            keys = override_key\n        else:\n            keys = [self.key]\n\n        for key in keys:\n            if isinstance(self.gguf_loader, SafeTensorLoader):\n                # using safetensor_loader\n                tensor = self.gguf_loader.load_tensor(key+'.weight')\n                try:\n                    bias = self.gguf_loader.load_tensor(key+'.bias')\n                except:\n                    bias = None\n                if self.gguf_loader.has_tensor(key+'.weight_scale_inv'):\n        
            weight_scale_inv = self.gguf_loader.load_tensor(key+'.weight_scale_inv')\n                    return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)\n                if bias is not None:\n                    return nn.Parameter(tensor), nn.Parameter(bias)\n                else:\n                    return nn.Parameter(tensor)\n                \n            elif self.gguf_loader.has_tensor(key + \".weight\") or \"kv_b_proj\" in key:\n                if key + \".bias\" in self.gguf_loader.tensor_file_map:\n                    tensors = self.load_multi(key, [\"weight\", \"bias\"], device=device)\n                    tensor = tensors[\"weight\"]\n                    bias = tensors[\"bias\"]\n                    # self.qtype = GGML_TYPE_QTYPE_MAP[tensorinfo[key + \".weight\"][\"ggml_type\"]]\n                    # print(torch.isinf(tensor).any(), torch.isinf(bias).any())\n                    return nn.Parameter(tensor), nn.Parameter(bias)\n                elif \"kv_b_proj\" in key and not self.gguf_loader.has_tensor(key + \".weight\"):\n                    attn_k_b_tensors = self.load_multi(key.replace(\"self_attn.kv_b_proj\", \"attn_k_b\"), [\"weight\"], device=device)\n                    attn_k_b = attn_k_b_tensors[\"weight\"]\n                    del attn_k_b_tensors\n                    attn_k_b = attn_k_b.transpose(1, 2).contiguous()\n                    attn_v_b_tensors = self.load_multi(key.replace(\"self_attn.kv_b_proj\", \"attn_v_b\"), [\"weight\"], device=device)\n                    attn_v_b = attn_v_b_tensors[\"weight\"]\n                    del attn_v_b_tensors\n                    kv_b_proj = torch.cat((attn_k_b, attn_v_b), dim=1)\n                    kv_b_proj = kv_b_proj.contiguous() if kv_b_proj.ndim == 2 else kv_b_proj.flatten(0, 1).contiguous()\n                    del attn_k_b\n                    del attn_v_b\n                    return nn.Parameter(kv_b_proj)\n                else:\n                    tensors = 
self.load_multi(key, [\"weight\"], device=device)\n                    tensor = tensors[\"weight\"]\n                    # self.qtype = GGML_TYPE_QTYPE_MAP[tensorinfo[key + \".weight\"][\"ggml_type\"]]\n                    return nn.Parameter(tensor)\n            else:\n                raise FileNotFoundError(f\"Weight file not found for key {key}\")\n\n    def load_multi(self, key: str, keys: list[str], device: str = \"cpu\"):\n        tensors = {}\n        for k in keys:\n            tensors[k] = self.gguf_loader.load_gguf_tensor(key + \".\" + k, device=device)\n        return tensors\n\n    @abstractmethod\n    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = \"cuda\"):\n        pass\n\n    @abstractmethod\n    def unload(self):\n        pass\n\n\nclass KLinearTorch(KLinearBase):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module = None,\n        device: str = \"cuda\",\n        **kwargs,\n    ):\n        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)\n        self.has_bias = False\n        self.dtype = torch.get_default_dtype()\n        self.weight = None\n        self.has_bias = False\n\n    def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor=None, **kwargs) -> torch.Tensor:\n        dtype = x.dtype\n        out_device = x.device\n        # TODO: support CUDA Graph when using cpu, but CPUInfer is recommended.\n        x = x.to(device=self.device, dtype=self.dtype)\n        x = x @ self.weight\n        if self.has_bias:\n            x = x + self.bias\n        x = x.to(dtype=dtype, device=out_device)\n        return x\n\n    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):\n        if self.loaded: return\n        if device is None: device = self.device\n        if w is None: w = self.load_weight(device=device)\n        # else: self.out_features = w.shape[0], 
self.in_features = w.shape[1]\n        \n        if isinstance(w, nn.Parameter):\n            try:\n                self.weight = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T\n            except: \n                self.weight = w.to(dtype=self.dtype).T\n            self.has_bias = False\n        elif isinstance(w, tuple):\n            try:\n                self.weight = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T\n            except:\n                self.weight = w[0].to(dtype=self.dtype).T\n            self.bias = w[1].to(dtype=self.dtype)\n            self.has_bias = True\n        else:\n            raise ValueError(\"Invalid weight type\")\n        # self.linear = self.linear.to(device)\n        self.weight = self.weight.to(device)\n        if self.has_bias:\n            self.bias = self.bias.to(device)\n        self.loaded = True\n\n    def unload(self):\n        if self.weight is not None:\n            self.weight = None\n        if self.has_bias:\n            self.bias = None\n\nclass KLinearQ8(KLinearBase):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module = None,\n        device: str = \"cuda\",\n        **kwargs,\n    ):\n        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)\n        self.has_bias = False\n        self.compute_dtype = torch.float32\n        self.weight = None\n        self.weight_scale = None\n        self.weight_zero_point = None\n        self.bias = None\n        self.loaded = False\n    \n    def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor=None) -> torch.Tensor:\n        orig_dtype = x.dtype\n        out_device = x.device\n        \n        x = x.to(device=self.device, dtype=self.compute_dtype)\n        \n        # Use the original weights for the matrix multiplication, emulating the original behavior\n\n        # Dequantize the weights, then perform the matrix multiplication\n        weight_dequant = self._dequantize_weight(self.weight, self.weight_scale, 
bits=8)\n        out = x @ weight_dequant.T\n        \n        if self.has_bias:\n            out = out + self.bias\n        \n        return out.to(dtype=orig_dtype, device=out_device)\n    \n    def _dequantize_weight(self, q_matrix, scales, bits=8):\n        \"\"\"\n        Dequantize a low-precision matrix back to floating-point\n        \n        Args:\n            q_matrix (torch.Tensor): Quantized int matrix\n            scales (torch.Tensor): Scale factors for each column\n            bits (int): Quantization bits used (8 or 4)\n        \n        Returns:\n            torch.Tensor: Dequantized floating-point matrix\n        \"\"\"\n        # Ensure inputs are torch tensors\n        if not isinstance(q_matrix, torch.Tensor):\n            q_matrix = torch.tensor(q_matrix, dtype=torch.int8)\n        if not isinstance(scales, torch.Tensor):\n            scales = torch.tensor(scales, dtype=torch.float32)\n        \n        # Convert to correct dtype if needed\n        if q_matrix.dtype != torch.int8:\n            q_matrix = q_matrix.to(torch.int8)\n        if scales.dtype != torch.float32:\n            scales = scales.to(torch.float32)\n        \n        # For Q4, ensure the values stay within 4-bit range\n        if bits == 4:\n            q_matrix = torch.clamp(q_matrix, -7, 7)\n        rows, cols = q_matrix.shape\n        dequant_matrix = q_matrix.to(torch.float32)\n        # Reshape the scales to [1, cols] so they broadcast across all rows\n        scales_broadcast = scales.view(1, cols)\n        # Rescale every column at once with a broadcast element-wise multiplication\n        dequant_matrix = dequant_matrix * scales_broadcast\n        \n        return dequant_matrix\n\n    \n    def _quantize_weight(self, matrix, bits=8):\n        \"\"\"\n        Quantize a floating-point matrix to lower precision (Q8 or Q4)\n        \n        Args:\n            matrix (torch.Tensor): Input matrix in floating-point format\n            bits (int): Quantization bits, either 8 or 4\n        \n        Returns:\n            tuple: (quantized int matrix, 
scale factors for each column)\n        \"\"\"\n        if not isinstance(matrix, torch.Tensor):\n            matrix = torch.tensor(matrix, dtype=torch.float32)\n        \n        # Convert to float32 if needed\n        if matrix.dtype != torch.float32:\n            matrix = matrix.to(torch.float32)\n        \n        # Get matrix shape\n        rows, cols = matrix.shape\n        \n        # Determine quantization parameters based on bits\n        if bits == 8:\n            max_int = 127\n            qtype = torch.int8\n        elif bits == 4:\n            max_int = 7\n            qtype = torch.int8  # We'll still use int8 storage but limit to 4-bit range, wait for native support\n        else:\n            raise ValueError(\"Quantization bits must be either 8 or 4\")\n       \n        scales = torch.zeros(cols, dtype=torch.float32, device=matrix.device)\n        \n        # Calculate max absolute value for each column\n        max_abs_vals, _ = torch.max(torch.abs(matrix), dim=0)\n        \n        # Handle zero columns (avoid division by zero)\n        zero_cols = max_abs_vals == 0\n        max_abs_vals[zero_cols] = 1.0\n        \n        # Calculate scale factors for all columns at once\n        scales = max_abs_vals / max_int\n        \n        # Prepare the scales for broadcasting [1, cols]\n        scales_broadcast = scales.view(1, cols)\n        \n        # Apply quantization to the entire matrix at once\n        q_matrix = torch.round(matrix / scales_broadcast).to(qtype)\n        \n        # For Q4, clamp values to ensure they stay within 4-bit range\n        if bits == 4:\n            q_matrix = torch.clamp(q_matrix, -max_int, max_int)\n        \n        return q_matrix, scales\n    \n    def load(self, w: Union[Dict, nn.Parameter, Tuple, None] = None, device: Optional[str] = None):\n        if self.loaded: return\n        if device is None: device = self.device \n        if w is None: w = self.load_weight(device=device)\n        \n        if isinstance(w, 
nn.Parameter):\n            try:\n                weight = w.to(dtype=self.compute_dtype).view(self.out_features, self.in_features)\n            except:\n                weight = w.to(dtype=self.compute_dtype)\n            self.has_bias = False\n        elif isinstance(w, tuple):\n            try:\n                weight = w[0].to(dtype=self.compute_dtype).view(self.out_features, self.in_features)\n            except:\n                weight = w[0].to(dtype=self.compute_dtype)\n            self.bias = w[1].to(dtype=self.compute_dtype).to(device)\n            self.has_bias = True\n        else:\n            raise ValueError(\"Invalid weight type\")\n        \n        self.weight, self.weight_scale = self._quantize_weight(weight, bits=8)\n        \n        self.weight = self.weight.to(device)\n        self.weight_scale = self.weight_scale.to(device)\n        \n        if self.has_bias:\n            self.bias = self.bias.to(device)\n            \n        self.loaded = True\n    \n    def unload(self):\n        self.weight = None\n        self.weight_scale = None\n        self.weight_zero_point = None\n        self._orig_weight = None\n        \n        if self.has_bias:\n            self.bias = None\n            \n        self.loaded = False\n\n\nclass KLinearFP8(KLinearBase):\n    # this kernel requires special handling for weight\n    # Please load the weight file downloaded from KVCache.AI\n    has_bias: bool\n    weight: torch.Tensor\n    bias: torch.Tensor\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module = None,\n        device: str = \"cuda\",\n        block_size: int = 128,\n        **kwargs,\n    ):\n        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)\n        self.has_bias = False\n        self.dtype = torch.get_default_dtype()\n        self.block_size = block_size\n    \n    def forward(self, x: torch.Tensor, bsz_tensor: 
torch.Tensor) -> torch.Tensor:\n        x = x.to(self.device)\n        orig_dtype = x.dtype        \n        x_quantized, scale_x = act_quant(x, self.block_size)\n        y = fp8_gemm(x_quantized, scale_x, self.weight, self.weight_scale_inv)\n        return y.to(dtype=orig_dtype)\n    \n    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):\n        if device is None: device = self.device\n        if w is None: \n            w = self.load_weight(device=device) \n        # TODO: fit weight_inv format\n        if isinstance(w, tuple):\n            self.weight = w[0].to(device)\n            self.weight_scale_inv = w[1].to(device)\n            self.has_bias = False\n        else:\n            raise ValueError(\"Invalid weight type\")\n        self.weight = self.weight.to(device)\n        if self.has_bias:\n            self.bias = self.bias.to(device)\n        \n    def unload(self):\n        # Assign unconditionally: weight and weight_scale_inv are only created by load(),\n        # so reading self.weight before the first load() would raise AttributeError.\n        self.weight = None\n        self.weight_scale_inv = None\n        if self.has_bias:\n            self.bias = None\n\n# TODO: merge the two Marlin classes\n\nclass VLinearMarlin(KLinearBase):\n    marlin_q_w: torch.Tensor\n    marlin_s: torch.Tensor\n    g_idx: torch.Tensor\n    sort_indices: torch.Tensor\n    has_bias: bool\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module = None,\n        device: str = \"cuda\",\n        num_bits: int = 4,  # 4-bit/8-bit is supported\n        group_size: int = 64,  # -1, 32, 64, 128\n        act_order: bool = False,\n        is_k_full=True,\n        **kwargs,\n    ):\n        assert device.lower() != \"cpu\", \"Marlin quantized linear only supports GPU device\"\n        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)\n        self.num_bits = num_bits\n        self.group_size = group_size\n        self.act_order = act_order\n        self.is_k_full = is_k_full\n        self.padding = 
False\n        self.orin_in_features = self.in_features\n        self.orin_out_features = self.out_features\n        if self.in_features%GPTQ_MARLIN_MIN_THREAD_K!=0 or self.out_features%GPTQ_MARLIN_MIN_THREAD_K!=0:\n            #print(f\"warning!, in_features={in_features} or out_features={out_features} is undivisible by GPTQ_MARLIN_MIN_THREAD_K={GPTQ_MARLIN_MIN_THREAD_K} and GPTQ_MARLIN_MIN_THREAD_N={GPTQ_MARLIN_MIN_THREAD_N}, padding\")\n            self.padding = True\n            self.in_features = (self.in_features+GPTQ_MARLIN_MIN_THREAD_K-1)//GPTQ_MARLIN_MIN_THREAD_K*GPTQ_MARLIN_MIN_THREAD_K\n            self.out_features = (self.out_features+GPTQ_MARLIN_MIN_THREAD_N-1)//GPTQ_MARLIN_MIN_THREAD_N*GPTQ_MARLIN_MIN_THREAD_N\n            #print(f\"After padding: in_features={in_features}, out_features={out_features}\")\n        \n        self.k = self.in_features\n        self.n = self.out_features\n\n    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):\n        if self.loaded: return\n        if device is None: device = self.device\n        assert device.lower() != \"cpu\", \"Marlin quantized linear only supports GPU device\"\n        \n        #if self.in_features * self.out_features:\n        if w is None: \n            w = self.load_weight(device=device) \n\n        if isinstance(w, nn.Parameter):\n            # pad weight\n            weight = w.view(self.orin_out_features, self.orin_in_features).T\n            self.has_bias = False\n        elif isinstance(w, tuple):\n            w = list(w)\n            weight = w[0].view(self.orin_out_features, self.orin_in_features).T\n            self.bias = w[1].view(self.orin_out_features)\n            self.bias = w[1]\n            self.has_bias = True\n        else:\n            raise ValueError(\"Invalid weight type\")\n        weight = weight.to(device)\n        if self.has_bias:\n            self.bias = self.bias.to(device)\n            \n        if self.padding:\n            
padded_weight = torch.zeros(self.in_features, self.out_features, device=self.device)\n            padded_weight[:self.orin_in_features, :self.orin_out_features] = weight\n            weight = padded_weight\n\n        # Pack Marlin linear\n        marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(\n            weight, self.num_bits, self.group_size, self.act_order\n        )\n        self.workspace = MarlinWorkspace(\n            self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL,self.device\n        )\n        self.weight = marlin_q_w\n        self.marlin_q_w = marlin_q_w\n        self.marlin_s = marlin_s\n        self.g_idx = g_idx\n        self.sort_indices = sort_indices\n        self.k = weight.shape[0]\n        self.n = weight.shape[1]\n        # self.shape_buffer = torch.tensor([60], dtype=torch.int32, device=self.device)\n        self.loaded = True\n\n\n    def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None) -> torch.Tensor:\n        if bsz_tensor is None:\n            bsz_tensor = torch.tensor([x.shape[0]], dtype=torch.int32, device=self.device)\n\n\n        # Only support input x as BF16 and FP16\n        x = x.to(self.device)\n        orig_shape = list(x.shape)\n        orig_dtype = x.dtype\n        x = x.reshape(-1, orig_shape[-1])\n        marlin_s = self.marlin_s.to(x.dtype)\n        sms = -1\n\n        # padding x.shape[0] to avoid CUDA illegal memory access error\n        x, orig_size_m = self._pad_input(x)\n\n        x = vLLMMarlin.gptq_marlin_gemm(\n            x,\n            self.marlin_q_w,\n            marlin_s,\n            self.g_idx,\n            self.sort_indices,\n            self.workspace.scratch,\n            self.num_bits,\n            bsz_tensor,\n            x.shape[0],\n            self.n,\n            x.shape[-1],\n            sms,\n            self.is_k_full,\n        )\n\n        x = x[:orig_size_m]\n\n        if self.has_bias:\n            x = x + self.bias\n        orig_shape[-1] 
= self.n\n        return x.reshape(orig_shape).to(orig_dtype)\n\n    def unload(self):\n\n        if self.has_bias:\n            self.bias = None\n        self.marlin_q_w = None\n        self.marlin_s = None\n        self.g_idx = None\n        self.sort_indices = None\n        self.workspace = None  \n\n    def _pad_input(self, x):\n\n        size_m = x.shape[0]\n        size_k = x.shape[1]\n\n        # size_m and align value depends on VLinearMarlin implementation\n        if size_m > 1024:\n            align = 1024\n        elif size_m > 64:\n            align = 64\n        else:\n            align = 1\n\n        padded_size_m = ((size_m + align - 1) // align) * align\n\n        if padded_size_m > size_m:\n            pad_len = padded_size_m - size_m\n            pad_tensor = torch.zeros((pad_len, size_k), dtype=x.dtype, device=x.device)\n            x = torch.cat([x, pad_tensor], dim = 0).contiguous()\n        return x, size_m\n\nclass KLinearMarlin(KLinearBase):\n    marlin_q_w: torch.Tensor\n    marlin_s: torch.Tensor\n    g_idx: torch.Tensor\n    sort_indices: torch.Tensor\n    has_bias: bool\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module = None,\n        device: str = \"cuda\",\n        num_bits: int = 4,  # 4-bit/8-bit is supported\n        group_size: int = 64,  # -1, 32, 64, 128\n        act_order: bool = False,\n        is_k_full=True,\n        **kwargs,\n    ):\n        assert device.lower() != \"cpu\", \"Marlin quantized linear only supports GPU device\"\n        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)\n        self.num_bits = num_bits\n        self.group_size = group_size\n        self.act_order = act_order\n        self.is_k_full = is_k_full\n        self.padding = False\n        self.orin_in_features = self.in_features\n        self.orin_out_features = self.out_features\n        if 
self.in_features%GPTQ_MARLIN_MIN_THREAD_K!=0 or self.out_features%GPTQ_MARLIN_MIN_THREAD_K!=0:\n            #print(f\"warning!, in_features={in_features} or out_features={out_features} is undivisible by GPTQ_MARLIN_MIN_THREAD_K={GPTQ_MARLIN_MIN_THREAD_K} and GPTQ_MARLIN_MIN_THREAD_N={GPTQ_MARLIN_MIN_THREAD_N}, padding\")\n            self.padding = True\n            self.in_features = (self.in_features+GPTQ_MARLIN_MIN_THREAD_K-1)//GPTQ_MARLIN_MIN_THREAD_K*GPTQ_MARLIN_MIN_THREAD_K\n            self.out_features = (self.out_features+GPTQ_MARLIN_MIN_THREAD_N-1)//GPTQ_MARLIN_MIN_THREAD_N*GPTQ_MARLIN_MIN_THREAD_N\n            #print(f\"After padding: in_features={in_features}, out_features={out_features}\")\n        \n        self.k = self.in_features\n        self.n = self.out_features\n\n    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):\n        if self.loaded: return\n        if device is None: device = self.device\n        assert device.lower() != \"cpu\", \"Marlin quantized linear only supports GPU device\"\n        \n        #if self.in_features * self.out_features:\n        if w is None: \n            w = self.load_weight(device=device) \n\n        if isinstance(w, nn.Parameter):\n            # pad weight\n            weight = w.view(self.orin_out_features, self.orin_in_features).T\n            self.has_bias = False\n        elif isinstance(w, tuple):\n            w = list(w)\n            weight = w[0].view(self.orin_out_features, self.orin_in_features).T\n            self.bias = w[1].view(self.orin_out_features)\n            self.bias = w[1]\n            self.has_bias = True\n        else:\n            raise ValueError(\"Invalid weight type\")\n        weight = weight.to(device)\n        if self.has_bias:\n            self.bias = self.bias.to(device)\n            \n        if self.padding:\n            padded_weight = torch.zeros(self.in_features, self.out_features, device=self.device)\n            
padded_weight[:self.orin_in_features, :self.orin_out_features] = weight\n            weight = padded_weight\n\n        # Pack Marlin linear\n        marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(\n            weight, self.num_bits, self.group_size, self.act_order\n        )\n        self.workspace = MarlinWorkspace(\n            self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL,self.device\n        )\n        self.weight = marlin_q_w # modeling_xxx.py may use linear.weight\n        self.marlin_q_w = marlin_q_w\n        self.marlin_s = marlin_s\n        self.g_idx = g_idx\n        self.sort_indices = sort_indices\n        self.k = weight.shape[0]\n        self.n = weight.shape[1]\n        self.loaded = True\n\n    def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor=None, **kwargs) -> torch.Tensor:\n        # Only support input x as BF16 and FP16\n        x = x.to(self.device)\n        orig_shape = list(x.shape)\n        orig_dtype = x.dtype\n        x = x.reshape(-1, orig_shape[-1])\n        x = x.reshape(-1, x.shape[-1])\n        if self.padding:\n            padding_input=torch.empty(x.shape[0], self.in_features, device=x.device, dtype=x.dtype)\n            padding_input[:,:self.orin_in_features] = x\n            x = padding_input\n        marlin_s = self.marlin_s.to(x.dtype)\n        x = KTransformersOps.gptq_marlin_gemm(\n            x,\n            self.marlin_q_w,\n            marlin_s,\n            self.g_idx,\n            self.sort_indices,\n            self.workspace.scratch,\n            self.num_bits,\n            x.shape[0],\n            self.n,\n            x.shape[-1],\n            self.is_k_full,\n        )\n        if self.padding:\n            x = x[:,:self.orin_out_features]\n            orig_shape[-1] = self.orin_out_features\n        else:\n            orig_shape[-1] = self.out_features\n        if self.has_bias:\n            x = x + self.bias\n        return x.reshape(orig_shape).to(orig_dtype)\n\n    
def unload(self):\n\n        if self.has_bias:\n            self.bias = None\n        self.marlin_q_w = None\n        self.marlin_s = None\n        self.g_idx = None\n        self.sort_indices = None\n        self.workspace = None\n\nclass KLinearCPUInfer(KLinearBase):\n    CPU_INFER = None\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module = None,\n        device: str = \"cpu\",\n        out_device: str = \"cuda\", # this device mean which device the output should on. TODO: support cpu.\n        stride = 16,\n        group_max_len = 1024,\n        **kwargs,\n    ):\n        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)\n        if KLinearCPUInfer.CPU_INFER is None:\n            KLinearCPUInfer.CPU_INFER = CPUInfer(Config().cpu_infer)\n        self.has_bias = False\n        self.dtype = torch.get_default_dtype()\n        self.w = None\n        self.has_bias = False\n        self.stride = stride\n        self.group_max_len = group_max_len\n        self.out_device = out_device\n\n    def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None) -> torch.Tensor:\n        origin_shape = x.shape # [batch_size, q_len, hidden_size]\n        if origin_shape[1] == 1 and torch.cuda.is_current_stream_capturing():\n            out_device = x.device\n            self.input_tensor_cpu.copy_(x, non_blocking=True)\n            qlen = origin_shape[1]\n            KLinearCPUInfer.CPU_INFER.submit_with_cuda_stream(\n                torch.cuda.current_stream().cuda_stream,\n                self.linear.forward(\n                    qlen, \n                    self.input_tensor_cpu.data_ptr(), \n                    self.output_cpu.data_ptr()\n                )\n            )\n            KLinearCPUInfer.CPU_INFER.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)\n            self.output_gpu.copy_(self.output_cpu, non_blocking=True)\n       
     if self.has_bias:\n                self.output_gpu += self.bias\n            return self.output_gpu\n        else:\n            dtype = x.dtype\n            out_device = x.device\n            x = x.to(device=self.device)\n            qlen = origin_shape[1]\n            output_shape = (*origin_shape[:-1], self.out_features)\n            output = torch.empty(output_shape, device=x.device, dtype=x.dtype)\n            KLinearCPUInfer.CPU_INFER.submit(\n                self.linear.forward(\n                    qlen, \n                    x.data_ptr(), \n                    output.data_ptr()\n                )\n            )\n            KLinearCPUInfer.CPU_INFER.sync()\n            if self.has_bias:\n                output = output + self.bias\n            output = output.to(dtype=dtype, device=out_device)\n            return output\n\n    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None, warmup:bool = True):\n        print(f\"loading {self.key} to {self.device} using CPUInfer\")\n        if device is None: device = self.device\n        self.load_weights(w=w, device=device)\n        if self.bias is not None:\n            self.has_bias = True\n            self.bias = self.bias.to(device)\n            \n        weight_ptr = ctypes.addressof(\n            ctypes.cast(self.weight.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents\n        )\n        config = cpuinfer_ext.linear.LinearConfig(self.in_features, self.out_features, self.stride, self.group_max_len, weight_ptr, self.weight_type, 30)\n        self.linear = cpuinfer_ext.linear.Linear(config)\n        \n        if warmup:\n            KLinearCPUInfer.CPU_INFER.submit(self.linear.warm_up())\n            KLinearCPUInfer.CPU_INFER.sync()\n        self.input_tensor_cpu = torch.zeros((1, 1, self.in_features), device=\"cpu\", pin_memory=True)\n        self.output_cpu = torch.zeros((1, 1, self.out_features), device=\"cpu\", pin_memory=True, dtype=torch.bfloat16)\n        
self.output_gpu = torch.zeros((1, 1, self.out_features), device=self.out_device)\n\n    def load_weights(self, w: dict | nn.Parameter | tuple | None = None, device: str = \"cpu\"):\n        if self.gguf_loader.has_tensor(self.key + \".weight\"):\n            if self.key + \".bias\" in self.gguf_loader.tensor_file_map:\n                self.weight = self.gguf_loader.get_mmap_tensor(self.key + \".weight\")\n                self.weight_type = self.gguf_loader.tensor_info[self.key + \".weight\"][\"ggml_type\"]\n                self.bias = self.gguf_loader.load_gguf_tensor(self.key + \".bias\", device=device)\n            else:\n                self.weight = self.gguf_loader.get_mmap_tensor(self.key + \".weight\")\n                self.weight_type = self.gguf_loader.tensor_info[self.key + \".weight\"][\"ggml_type\"]\n                self.bias = None\n        else:\n            raise ValueError(f\"Linear {self.key} not found in gguf_loader\")\n\n    def unload(self):\n        if self.w is not None:\n            self.w = None\n        if self.has_bias:\n            self.bias = None       \n\nclass KLinearIPEXLLM(KLinearBase):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module = None,\n        device: str = \"xpu\",\n        precision: str = \"sym_int4\",\n        **kwargs,\n    ):\n        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)\n        self.has_bias = False\n        self.dtype = torch.get_default_dtype()\n        self.weight = None\n        self.has_bias = False\n        self.precision = precision\n        self.qtype = None\n\n    def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None) -> torch.Tensor:\n        dtype = x.dtype\n        out_device = x.device\n        from ipex_llm.transformers.models.common import linear_forward\n        x = linear_forward(x.half(), self.weight, self.qtype, self.out_features)\n\n        if 
self.has_bias:\n            x = x + self.bias\n        x = x.to(dtype=dtype, device=out_device)\n        return x\n\n    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):\n        if self.loaded: return\n        if device is None: device = self.device\n        assert device.lower()[:3] == \"xpu\", \"IPEX-LLM quantized linear only supports XPU device\"\n        if w is None: w = self.load_weight(device=device)\n\n        if isinstance(w, nn.Parameter):\n            try:\n                weight = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T\n            except:\n                weight = w.to(dtype=self.dtype).T\n            self.has_bias = False\n        elif isinstance(w, tuple):\n            try:\n                weight = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T\n            except:\n                weight = w[0].to(dtype=self.dtype).T\n            self.bias = w[1].to(dtype=self.dtype)\n            self.has_bias = True\n        else:\n            raise ValueError(\"Invalid weight type\")\n        weight = weight.to(\"cpu\").float().transpose(0, 1).contiguous()\n\n        if self.has_bias:\n            self.bias = self.bias.to(device)\n\n        # quantize linear weight\n        from ipex_llm.transformers.models.common import quantize_linear\n        paramsLowBit, qtype = quantize_linear(weight, self.in_features, self.precision)\n        self.weight = paramsLowBit.to(device)\n        self.qtype = qtype\n        self.loaded = True\n\n    def unload(self):\n        if self.weight is not None:\n            self.weight = None\n        if self.has_bias:\n            self.bias = None\n\nLINEAR_MAP = {\n    \"KLinearMarlin\": KLinearMarlin,\n    \"KLinearTorch\": KLinearTorch,\n    \"KLinearCPUInfer\": KLinearCPUInfer,\n    \"VLinearMarlin\": VLinearMarlin,\n    \"KLinearFP8\": KLinearFP8,\n    \"KLinearQ8\": KLinearQ8,\n    \"KLinearIPEXLLM\": KLinearIPEXLLM,\n}\n\nclass 
KTransformersLinear(BaseInjectedModule, KLinearBase):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module,\n        generate_device: str = \"cuda\",\n        generate_op: str| None = \"KLinearMarlin\",\n        prefill_device: str = \"cuda\",\n        prefill_op: str| None = \"KLinearTorch\",\n        **kwargs,\n    ):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)\n        KLinearBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)\n        # build all the linear operators\n        if prefill_op is not None:\n            assert prefill_op in LINEAR_MAP, f\"linear_type {prefill_op} not supported\"\n            self.prefill_linear = LINEAR_MAP[prefill_op](key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        else:\n            self.prefill_linear = None\n\n        if generate_op is not None:\n            assert generate_op in LINEAR_MAP, f\"linear_type {generate_op} not supported\"\n            self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs)\n        else:\n            self.generate_linear = None\n        self.mode = InferenceState.UNLOAD\n\n    def forward(self, x, bsz_tensor=None):\n        if self.mode == InferenceState.PREFILL:\n            assert self.prefill_linear is not None, \"cpu linear is not initialized\"\n            y = self.prefill_linear.forward(x, bsz_tensor)\n        else:\n            assert self.generate_linear is not None, \"gpu linear is not initialized\"\n            y = self.generate_linear.forward(x, bsz_tensor)\n        return y\n\n    def load(self, w: dict | nn.Parameter | tuple | None = None, mode: InferenceState = InferenceState.GENERATE):\n        if not mode:\n            mode = InferenceState.GENERATE\n        # load to device\n        if 
mode == InferenceState.PREFILL:\n            # the generate operator is optional (generate_op may be None)\n            if self.generate_linear is not None:\n                self.generate_linear.unload()\n            self.prefill_linear.load(w=w)\n            self.device = self.prefill_linear.device\n            self.weight = self.prefill_linear.weight # modeling_xxx.py may use linear.weight\n        elif mode == InferenceState.GENERATE:\n            # the prefill operator is optional (prefill_op may be None)\n            if self.prefill_linear is not None:\n                self.prefill_linear.unload()\n            self.generate_linear.load(w=w)\n            self.device = self.generate_linear.device\n            self.weight = self.generate_linear.weight # modeling_xxx.py may use linear.weight\n        elif mode == InferenceState.UNLOAD:\n            if self.prefill_linear is not None:\n                self.prefill_linear.unload()\n            if self.generate_linear is not None:\n                self.generate_linear.unload()\n            self.device = \"cpu\"\n        else:\n            raise ValueError(\"mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD\")\n        self.mode = mode\n\n    def unload(self):\n        if self.prefill_linear is not None:\n            self.prefill_linear.unload()\n        if self.generate_linear is not None:\n            self.generate_linear.unload()\n            self.device = self.generate_linear.device\n\n    def set_inference_mode(self, mode: InferenceState):\n        if not mode: \n            mode = InferenceState.GENERATE\n        if mode == InferenceState.GENERATE:\n            self.load(mode=InferenceState.GENERATE)\n        elif mode == InferenceState.PREFILL:\n            self.load(mode=InferenceState.PREFILL)\n        elif mode == InferenceState.UNLOAD:\n            self.unload()\n        else:\n            raise ValueError(\"mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD\")\n\n\n"
  },
  {
    "path": "archive/ktransformers/operators/mlp.py",
    "content": "\nfrom ktransformers.operators.base_operator import BaseInjectedModule\nfrom ktransformers.util.custom_loader import GGUFLoader\nfrom transformers import PretrainedConfig\nimport torch.nn as nn\nfrom ktransformers.models.modeling_deepseek_v3 import DeepseekV3MLP\nfrom ktransformers.models.modeling_qwen2_moe import Qwen2MoeMLP\nfrom ktransformers.models.modeling_smallthinker import SmallthinkerDenseMlpBlock\nfrom ktransformers.models.modeling_glm4_moe import Glm4MoeMLP\nclass kDeepseekV3MLP(DeepseekV3MLP, BaseInjectedModule):\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        self.orig_module.__init__(orig_module.config,\n            orig_module.hidden_size, orig_module.intermediate_size)\n    def forward(self, x, bsz_tensor):\n        down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor)\n        return down_proj\nclass KQwen2MoeMLP(Qwen2MoeMLP, BaseInjectedModule):\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        self.orig_module.__init__(orig_module.config,\n            orig_module.intermediate_size)\n    def forward(self, x, bsz_tensor):\n        down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), 
bsz_tensor)\n        return down_proj\n\n\nclass KSmallthinkerDenseMlpBlock(SmallthinkerDenseMlpBlock, BaseInjectedModule):\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        self.orig_module.__init__(orig_module.config)\n    def forward(self, x, bsz_tensor):\n        down_proj = self.down(nn.functional.relu(self.gate(x, bsz_tensor)) * self.up(x, bsz_tensor), bsz_tensor)\n        return down_proj\n\nclass KGlm4MoeMLP(Glm4MoeMLP, BaseInjectedModule):\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        self.orig_module.__init__(orig_module.config, orig_module.hidden_size, orig_module.intermediate_size)\n    def forward(self, x, bsz_tensor):\n        down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor)\n        return down_proj"
  },
  {
    "path": "archive/ktransformers/operators/models.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nDescription  :  \nAuthor       : Azure-Tang\nDate         : 2024-07-25 11:25:24\nVersion      : 1.0.0\nLastEditors  : Azure \nLastEditTime : 2024-08-27 07:29:04\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n\"\"\"\n\nimport inspect\nimport math\nfrom typing import List, Optional, Tuple, Union\nimport time\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\nfrom ktransformers.operators.dynamic_attention import DynamicScaledDotProductAttention\nfrom ktransformers.server.config.config import Config\nimport os\nimport yaml\nfrom transformers.activations import ACT2FN\nfrom transformers.cache_utils import Cache, DynamicCache, StaticCache\nfrom transformers.modeling_attn_mask_utils import (\n    AttentionMaskConverter,\n)\nfrom transformers.modeling_outputs import (\n    MoeCausalLMOutputWithPast,\n    MoeModelOutputWithPast,\n    SequenceClassifierOutputWithPast,\n    TokenClassifierOutput,\n)\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.utils import (\n    add_start_docstrings,\n    add_start_docstrings_to_model_forward,\n    is_flash_attn_2_available,\n    is_flash_attn_greater_or_equal_2_10,\n    logging,\n    replace_return_docstrings,\n)\nfrom ktransformers.models.modeling_qwen2_moe import (\n    Qwen2MoeSparseMoeBlock,\n    Qwen2MoeMLP,\n    Qwen2MoeDecoderLayer,\n)\nfrom ktransformers.models.modeling_deepseek import (\n    BaseModelOutputWithPast,\n    DeepseekV2DecoderLayer,\n    DeepseekV2MoE,\n)\nfrom ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor\nfrom transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig\nfrom ktransformers.models.configuration_llama import LlamaConfig\nfrom ktransformers.operators.base_operator import BaseInjectedModule\nfrom ktransformers.util.utils import 
InferenceState, get_compute_capability\nfrom ktransformers.util.custom_loader import GGUFLoader\nfrom transformers.configuration_utils import PretrainedConfig\nfrom ktransformers.models.modeling_llama import (\n    LlamaDecoderLayer,\n    LlamaRMSNorm,\n    LlamaRotaryEmbedding,\n)\n\nif is_flash_attn_2_available():\n    from flash_attn import flash_attn_func, flash_attn_varlen_func\n    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa\n\n    _flash_supports_window_size = \"window_size\" in list(\n        inspect.signature(flash_attn_func).parameters\n    )\n\ntry:\n    import torch_npu\n    use_torch_npu = torch_npu.npu.is_available()\nexcept:\n    use_torch_npu = False\n\nlogger = logging.get_logger(__name__)\n\n_CHECKPOINT_FOR_DOC = \"Qwen/Qwen1.5-MoE-A2.7B\"\n_CONFIG_FOR_DOC = \"Qwen2MoeConfig\"\n\nQWEN2MOE_START_DOCSTRING = r\"\"\"\n    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the\n    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads\n    etc.)\n\n    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage\n    and behavior.\n\n    Parameters:\n        config ([`Qwen2MoeConfig`]):\n            Model configuration class with all the parameters of the model. Initializing with a config file does not\n            load the weights associated with the model, only the configuration. Check out the\n            [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\nQWEN2MOE_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary. 
Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see\n            `past_key_values`).\n\n            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]\n            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more\n            information on the default strategy.\n\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,\n            config.n_positions - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):\n            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention\n            blocks) that can be used to speed up sequential decoding. 
This typically consists in the `past_key_values`\n            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.\n\n            Two formats are allowed:\n            - a [`~cache_utils.Cache`] instance;\n            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of\n            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy\n            cache format.\n\n            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the\n            legacy cache format will be returned.\n\n            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't\n            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`\n            of shape `(batch_size, sequence_length)`.\n        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This\n            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n            model's internal embedding lookup matrix.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n            `past_key_values`).\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors for\n            more detail.\n        output_router_logits (`bool`, *optional*):\n            Whether or not to return the logits of all the routers. They are useful for computing the router loss, and\n            should not be returned during inference.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):\n            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,\n            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer\n            the complete sequence length.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare Qwen2MoE Model outputting raw hidden-states without any specific head on top.\",\n    QWEN2MOE_START_DOCSTRING,\n)\nclass KQwen2MoeModel(BaseInjectedModule):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Qwen2MoeDecoderLayer`]\n\n    Args:\n        config: Qwen2MoeConfig\n    \"\"\"\n\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module,\n        device: str = \"cuda\",\n        per_layer_prefill_intput_threshold: int = 30000,  # if None, no per-layer prefill\n        transfer_map: dict = None,\n        **kwargs,\n    ):\n        BaseInjectedModule.__init__(\n            self, key, gguf_loader, config, orig_module, device, **kwargs\n        )\n        self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold\n        self.transfer_map = transfer_map\n        self.stream_device_map = dict()\n\n    @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        output_router_logits: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        per_layer_prefill_intput_threshold: (\n            int | None\n        ) = None,  # if None or 0, close per-layer prefill\n    ) -> Union[Tuple, MoeModelOutputWithPast]:\n        # print(f'Total length of input_ids: {input_ids.size(1)}, {input_ids.size()}')\n\n        if per_layer_prefill_intput_threshold is None:\n            per_layer_prefill_intput_threshold = self.per_layer_prefill_intput_threshold\n        per_layer_prefill_flag = False\n        seq_lenth = (\n            inputs_embeds.size(1) if inputs_embeds is not None else 
input_ids.size(1)\n        )\n        if (\n            per_layer_prefill_intput_threshold\n            and per_layer_prefill_intput_threshold < seq_lenth\n        ):\n            per_layer_prefill_flag = True\n            for layer in self.layers:\n                self.load_layer_to(layer, InferenceState.UNLOAD)\n        else:\n            pass\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_router_logits = (\n            output_router_logits\n            if output_router_logits is not None\n            else self.config.output_router_logits\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        if (input_ids is None) ^ (inputs_embeds is not None):\n            raise ValueError(\n                \"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one\"\n            )\n\n        if self.gradient_checkpointing and self.training:\n            if use_cache:\n                logger.warning_once(\n                    \"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\"\n                )\n                use_cache = False\n\n        use_legacy_cache = False\n        if use_cache and not isinstance(past_key_values, Cache):\n            use_legacy_cache = True\n            past_key_values = DynamicCache.from_legacy_cache(past_key_values)\n            logger.warning_once(\n                \"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. 
\"\n                \"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\"\n            )\n\n        if inputs_embeds is None:\n            input_ids = input_ids.to(\"cpu\")\n            inputs_embeds = self.embed_tokens(input_ids)\n            inputs_embeds = inputs_embeds.to(\"cuda\")\n\n        if cache_position is None:\n            past_seen_tokens = (\n                past_key_values.get_seq_length() if past_key_values is not None else 0\n            )\n            cache_position = torch.arange(\n                past_seen_tokens,\n                past_seen_tokens + inputs_embeds.shape[1],\n                device=inputs_embeds.device,\n            )\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0)\n\n        causal_mask = self._update_causal_mask(\n            attention_mask,\n            inputs_embeds,\n            cache_position,\n            past_key_values,\n            output_attentions,\n        )\n\n        hidden_states = inputs_embeds\n\n        # create position embeddings to be shared across the decoder layers\n        if torch.xpu.is_available() and inputs_embeds.device.type == \"xpu\":\n            position_embeddings = self.rotary_emb(hidden_states, position_ids)\n        else:\n            position_embeddings = None\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        all_router_logits = () if output_router_logits else None\n        next_decoder_cache = None\n\n        for i, decoder_layer in enumerate(self.layers):\n            if self.transfer_map is not None and i in self.transfer_map:\n                prev_stream = torch.cuda.current_stream()\n                cur_device = self.transfer_map[i]\n                if cur_device not in self.stream_device_map:\n                    
self.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)\n                torch.cuda.set_device(cur_device)\n                self.stream_device_map[cur_device].wait_stream(prev_stream)\n                torch.cuda.set_stream(self.stream_device_map[cur_device])\n                hidden_states = hidden_states.to(\n                    self.transfer_map[i], non_blocking=True\n                )\n                causal_mask = (\n                    causal_mask.to(self.transfer_map[i], non_blocking=True)\n                    if causal_mask is not None\n                    else None\n                )\n                position_ids = (\n                    position_ids.to(self.transfer_map[i], non_blocking=True)\n                    if position_ids is not None\n                    else None\n                )\n                cache_position = (\n                    cache_position.to(self.transfer_map[i], non_blocking=True)\n                    if cache_position is not None\n                    else None\n                )\n\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            if self.gradient_checkpointing and self.training:\n                layer_outputs = self._gradient_checkpointing_func(\n                    decoder_layer.__call__,\n                    hidden_states,\n                    causal_mask,\n                    position_ids,\n                    past_key_values,\n                    output_attentions,\n                    output_router_logits,\n                    use_cache,\n                    cache_position,\n                )\n            else:\n                if per_layer_prefill_flag:\n                    # print(f\"to gpu\")\n                    self.load_layer_to(decoder_layer, InferenceState.PREFILL)\n                    torch.cuda.empty_cache()\n                layer_outputs = decoder_layer(\n                    hidden_states,\n                    attention_mask=causal_mask,\n            
        position_ids=position_ids,\n                    past_key_value=past_key_values,\n                    output_attentions=output_attentions,\n                    output_router_logits=output_router_logits,\n                    use_cache=use_cache,\n                    cache_position=cache_position,\n                    position_embeddings=position_embeddings,\n                )\n                if per_layer_prefill_flag:\n                    # print(f\"to cpu\")\n                    self.load_layer_to(decoder_layer, InferenceState.UNLOAD)\n                    torch.cuda.empty_cache()\n            hidden_states = layer_outputs[0]\n\n            if use_cache and len(layer_outputs) > 1:\n                next_decoder_cache = layer_outputs[2 if output_attentions else 1]\n            else:\n                next_decoder_cache = None\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n            if output_router_logits and layer_outputs[-1] is not None:\n                all_router_logits += (layer_outputs[-1],)\n\n        hidden_states = self.norm(hidden_states)\n\n        if per_layer_prefill_flag:\n            per_layer_prefill_flag = False\n            for layer in self.layers:\n                self.load_layer_to(layer, InferenceState.GENERATE)\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        next_cache = None\n        if use_cache:\n            if next_decoder_cache is not None:\n                next_cache = (\n                    next_decoder_cache.to_legacy_cache()\n                    if use_legacy_cache\n                    else next_decoder_cache\n                )\n            else:\n                next_cache = past_key_values\n\n        if not return_dict:\n            return tuple(\n                v\n                for v in [\n                    hidden_states,\n                    next_cache,\n                    all_hidden_states,\n                    
all_self_attns,\n                    all_router_logits,\n                ]\n                if v is not None\n            )\n        return MoeModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=next_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n            router_logits=all_router_logits,\n        )\n\n    def load_layer_to(self, layer: Qwen2MoeDecoderLayer, target: InferenceState):\n        assert isinstance(\n            layer, Qwen2MoeDecoderLayer\n        ), \"module should be nn.ModuleList of decoder layers\"\n\n        # TODO Support restore to original device, not only cuda\n        device = \"cpu\" if target == InferenceState.UNLOAD else \"cuda\"\n\n        # attn\n        layer.self_attn.q_proj.set_inference_mode(target)\n        layer.self_attn.k_proj.set_inference_mode(target)\n        layer.self_attn.v_proj.set_inference_mode(target)\n        layer.self_attn.o_proj.set_inference_mode(target)\n        layer.self_attn.rotary_emb = layer.self_attn.rotary_emb.to(device)\n\n        # mlp\n        if isinstance(layer.mlp, Qwen2MoeSparseMoeBlock):\n            layer.mlp.gate.set_inference_mode(target)\n            layer.mlp.experts.set_inference_mode(target)\n            layer.mlp.shared_expert.gate_proj.set_inference_mode(target)\n            layer.mlp.shared_expert.up_proj.set_inference_mode(target)\n            layer.mlp.shared_expert.down_proj.set_inference_mode(target)\n            layer.mlp.shared_expert.act_fn.to(device)\n            layer.mlp.shared_expert_gate.to(device)\n        else:\n            layer.mlp.gate_proj.set_inference_mode(target)\n            layer.mlp.up_proj.set_inference_mode(target)\n            layer.mlp.down_proj.set_inference_mode(target)\n            layer.mlp.act_fn.to(device)\n        # layer norm\n        layer.input_layernorm.to(device)\n        layer.post_attention_layernorm.to(device)\n\n\nDeepseekV2_INPUTS_DOCSTRING = r\"\"\"\n    
Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see\n            `past_key_values`).\n\n            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]\n            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more\n            information on the default strategy.\n\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Indices of positions of each input sequence tokens in the position embeddings. 
Selected in the range `[0,\n            config.n_positions - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):\n            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention\n            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`\n            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.\n\n            Two formats are allowed:\n            - a [`~cache_utils.Cache`] instance;\n            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of\n            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy\n            cache format.\n\n            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the\n            legacy cache format will be returned.\n\n            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't\n            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`\n            of shape `(batch_size, sequence_length)`.\n        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 
This\n            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n            model's internal embedding lookup matrix.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n            `past_key_values`).\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for\n            more detail.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n\"\"\"\n\n\nclass KDeepseekV2Model(BaseInjectedModule):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`DeepseekV2DecoderLayer`]\n\n    Args:\n        config: DeepseekV2Config\n    \"\"\"\n\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module,\n        device: str = \"cuda\",\n        per_layer_prefill_intput_threshold: int = 30000,  # if None, no per-layer prefill\n        transfer_map: dict = None,\n        **kwargs,\n    ):\n        BaseInjectedModule.__init__(\n            self, key, gguf_loader, config, orig_module, device, **kwargs\n        )\n        self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold\n        self.transfer_map = transfer_map\n        self.stream_device_map = dict()\n\n    @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        per_layer_prefill_intput_threshold: (\n            int | None\n        ) = None,  # if None, no per-layer prefill\n        is_prefill: Optional[bool] = False,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        if per_layer_prefill_intput_threshold is None:\n            per_layer_prefill_intput_threshold = self.per_layer_prefill_intput_threshold\n        per_layer_prefill_flag = False\n        seq_lenth = (\n            inputs_embeds.size(1) if inputs_embeds is not None else input_ids.size(1)\n        )\n        if (\n            per_layer_prefill_intput_threshold\n            and 
per_layer_prefill_intput_threshold < seq_lenth\n        ):\n            per_layer_prefill_flag = True\n            for layer in self.layers:\n                self.load_layer_to(layer, InferenceState.UNLOAD)\n            torch.cuda.empty_cache()\n        else:\n            pass\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        # retrieve input_ids and inputs_embeds\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\n                \"You cannot specify both input_ids and inputs_embeds at the same time\"\n            )\n        elif input_ids is not None:\n            batch_size, seq_length = input_ids.shape[:2]\n        elif inputs_embeds is not None:\n            batch_size, seq_length = inputs_embeds.shape[:2]\n        else:\n            raise ValueError(\"You have to specify either input_ids or inputs_embeds\")\n\n        if self.gradient_checkpointing and self.training:\n            if use_cache:\n                logger.warning_once(\n                    \"`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`.\"\n                )\n                use_cache = False\n\n        past_key_values_length = 0\n        if use_cache:\n            use_legacy_cache = not isinstance(past_key_values, Cache)\n            if use_legacy_cache:\n                past_key_values = DynamicCache.from_legacy_cache(past_key_values)\n            past_key_values_length = past_key_values.get_usable_length(seq_length)\n        \n        if inputs_embeds is None:\n            org_device = input_ids.device\n            # TODO move to embed_tokens's device, not hard code to cpu\n            input_ids = input_ids.to(\"cpu\")\n            inputs_embeds = self.embed_tokens(input_ids).to(org_device)\n            input_ids = input_ids.to(org_device)\n\n        if cache_position is None:\n            past_seen_tokens = (\n                past_key_values.get_seq_length() if past_key_values is not None else 0\n            )\n            cache_position = torch.arange(\n                past_seen_tokens,\n                past_seen_tokens + inputs_embeds.shape[1],\n                device=inputs_embeds.device,\n            )\n\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0)\n\n        if inputs_embeds.device.type == \"xpu\" and position_ids is not None:\n            cos, sin = self.layers[0].self_attn.rotary_emb(inputs_embeds,\n                                                           position_ids)\n            position_embeddings = (cos, sin)\n        else:\n            position_embeddings = None\n\n        if per_layer_prefill_flag:\n            causal_mask = None\n        elif use_torch_npu and not is_prefill:\n            causal_mask = None\n        else:\n            if (use_torch_npu\n                or os.name == 'nt'\n                or get_compute_capability() < 8\n                or (self.transfer_map is not None and 'cpu' in self.transfer_map.values())\n                or device_manager.gpu_vendor != GPUVendor.NVIDIA):\n  
              # print(\"for Windows or GPU before ampere, use forward_windows\")\n                # only use mask in forward windows or can't flash attn\n                causal_mask = self._update_causal_mask(\n                    attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions\n                )\n            else:\n                causal_mask = None\n\n        # embed positions\n        hidden_states = inputs_embeds\n        if per_layer_prefill_flag:\n            print(f\"Total length of input_ids: {hidden_states.size(1)}\")\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        next_decoder_cache = None\n\n        t_gpu = 0\n        t_cpu = 0\n        t_f = 0\n\n        for i, decoder_layer in enumerate(self.layers):\n            # print(f\"@@@@@@@@@@@@@@@@@layer {i}@@@@@@@@@@@@@@@@@@@@ \\n\")\n            if self.transfer_map is not None and i in self.transfer_map:\n                prev_stream = torch.cuda.current_stream()\n                cur_device = self.transfer_map[i]\n                if cur_device not in self.stream_device_map and cur_device.lower() != \"cpu\":\n                    self.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)\n                if cur_device.lower() != \"cpu\":\n                    torch.cuda.set_device(cur_device)\n                    self.stream_device_map[cur_device].wait_stream(prev_stream)\n                    torch.cuda.set_stream(self.stream_device_map[cur_device])\n                hidden_states = hidden_states.to(\n                    self.transfer_map[i], non_blocking=True\n                )\n                causal_mask = (\n                    causal_mask.to(self.transfer_map[i], non_blocking=True)\n                    if causal_mask is not None\n                    else None\n                )\n                position_ids = (\n                    
position_ids.to(self.transfer_map[i], non_blocking=True)\n                    if position_ids is not None\n                    else None\n                )\n                cache_position = (\n                    cache_position.to(self.transfer_map[i], non_blocking=True)\n                    if cache_position is not None\n                    else None\n                )\n\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            if self.gradient_checkpointing and self.training:\n                layer_outputs = self._gradient_checkpointing_func(\n                    decoder_layer.__call__,\n                    hidden_states,\n                    causal_mask,\n                    position_ids,\n                    past_key_values,\n                    output_attentions,\n                    use_cache,\n                    cache_position,\n                )\n            else:\n                t3 = time.time()\n                if per_layer_prefill_flag:\n                    # print(f\"to gpu\")\n                    self.load_layer_to(decoder_layer, InferenceState.PREFILL)\n                    torch.cuda.empty_cache()\n                t4 = time.time()\n                # with open(\"log.txt\", \"a\") as f:\n                #     f.write(f\"@@@@@@@@@@@@@@@@@layer {i}@@@@@@@@@@@@@@@@@@@@ \\n\")\n                layer_outputs = decoder_layer(\n                    hidden_states,\n                    attention_mask=causal_mask,\n                    position_ids=position_ids,\n                    past_key_value=past_key_values,\n                    output_attentions=output_attentions,\n                    use_cache=use_cache,\n                    cache_position=cache_position,\n                    position_embeddings=position_embeddings,\n                    is_prefill = is_prefill,\n                )\n                t5 = time.time()\n                if per_layer_prefill_flag:\n                    # print(f\"to cpu\")\n        
            self.load_layer_to(decoder_layer, InferenceState.UNLOAD)\n                    torch.cuda.empty_cache()\n                t6 = time.time()\n            t_gpu += t4 - t3\n            t_cpu += t6 - t5\n            t_f += t5 - t4\n\n            hidden_states = layer_outputs[0]\n\n            # @@@@@@@ TODO open this notes, tmp close to fit deepseekv3\n            if use_cache:\n                next_decoder_cache = layer_outputs[2 if output_attentions else 1]\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n        if use_torch_npu:\n            hidden_states_without_norm = hidden_states.clone()\n        hidden_states = self.norm(hidden_states)\n        # with open(\"log.txt\", \"a\") as f:\n        #     f.write(f\"@@@After layers\\n\")\n        #     f.write(f\"hidden_states={hidden_states}\\n\")\n        #     f.write(f\"hidden_states.shape={hidden_states.shape}\\n\")\n\n        if per_layer_prefill_flag:\n            t6 = time.time()\n            # print(f\"restore\")\n            per_layer_prefill_flag = False\n            for layer in self.layers:\n                self.load_layer_to(layer, InferenceState.GENERATE)\n            torch.cuda.empty_cache()\n            t7 = time.time()\n\n            print(\n                f\"total time: {t7-t3}, \\n layer num{len(self.layers)}, gpu time: {t_gpu}, cpu time: {t_cpu}, forward time: {t_f}, restore time: {t7-t6}\"\n            )\n\n        # add hidden states from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        next_cache = None\n        if use_cache:\n            next_cache = (\n                next_decoder_cache.to_legacy_cache()\n                if use_legacy_cache\n                else next_decoder_cache\n            )\n        if not return_dict:\n            if use_torch_npu:\n                return tuple(\n                    v\n                    for v in [hidden_states, next_cache, 
all_hidden_states, all_self_attns, hidden_states_without_norm]\n                    if v is not None\n                )\n            else:\n                return tuple(\n                    v\n                    for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]\n                    if v is not None\n                )\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=next_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n        )\n\n    def load_layer_to(self, layer: DeepseekV2DecoderLayer, target: InferenceState):\n        assert isinstance(\n            layer, DeepseekV2DecoderLayer\n        ), \"module should be nn.ModuleList of decoder layers\"\n\n        # TODO Support restore to original device, not only cuda\n        device = \"cpu\" if target == InferenceState.UNLOAD else \"cuda\"\n\n        # TODO Support DFS to auto use {to, set_inference_mode} according to the module type\n\n        # attn\n        layer.self_attn.to(device)  #\n\n        # mlp\n        if isinstance(layer.mlp, DeepseekV2MoE):\n            layer.mlp.gate.to(device)\n            layer.mlp.experts.set_inference_mode(target)\n            layer.mlp.shared_experts.gate_proj.set_inference_mode(target)\n            layer.mlp.shared_experts.up_proj.set_inference_mode(target)\n            layer.mlp.shared_experts.down_proj.set_inference_mode(target)\n            layer.mlp.shared_experts.act_fn.to(device)\n            # layer.mlp.shared_expert_gate.to(device)\n        else:\n            layer.mlp.gate_proj.set_inference_mode(target)\n            layer.mlp.up_proj.set_inference_mode(target)\n            layer.mlp.down_proj.set_inference_mode(target)\n            layer.mlp.act_fn.to(device)\n        # layer norm\n        layer.input_layernorm.to(device)\n        layer.post_attention_layernorm.to(device)\n\n\nLLAMA_START_DOCSTRING = r\"\"\"\n    This model inherits 
from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the\n    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads\n    etc.)\n\n    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage\n    and behavior.\n\n    Parameters:\n        config ([`LlamaConfig`]):\n            Model configuration class with all the parameters of the model. Initializing with a config file does not\n            load the weights associated with the model, only the configuration. Check out the\n            [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\nLLAMA_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n\n            Indices can be obtained using [`AutoTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see\n            `past_key_values`).\n\n            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]\n            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more\n            information on the default strategy.\n\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,\n            config.n_positions - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):\n            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention\n            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`\n            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.\n\n            Two formats are allowed:\n            - a [`~cache_utils.Cache`] instance;\n            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of\n            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy\n            cache format.\n\n            The model will output the same cache format that is fed as input. 
If no `past_key_values` are passed, the\n            legacy cache format will be returned.\n\n            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't\n            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`\n            of shape `(batch_size, sequence_length)`.\n        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This\n            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n            model's internal embedding lookup matrix.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n            `past_key_values`).\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for\n            more detail.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):\n            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,\n            this tensor is not affected by padding. 
It is used to update the cache in the correct position and to infer\n            the complete sequence length.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare LLaMA Model outputting raw hidden-states without any specific head on top.\",\n    LLAMA_START_DOCSTRING,\n)\nclass LlamaPreTrainedModel(PreTrainedModel):\n    config_class = LlamaConfig\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"LlamaDecoderLayer\"]\n    _skip_keys_device_placement = [\"past_key_values\"]\n    _supports_flash_attn_2 = True\n    _supports_sdpa = True\n    _supports_cache_class = True\n    _supports_quantized_cache = True\n    _supports_static_cache = True\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n\n\nclass KLlamaModel(BaseInjectedModule):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`LlamaDecoderLayer`]\n\n    Args:\n        config: LlamaConfig\n    \"\"\"\n\n    dynamic_sdpa = None\n\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module,\n        device: str = \"cuda\",\n        per_layer_prefill_intput_threshold: int = 30000,  # if None, no per-layer prefill\n        transfer_map: dict = None,\n        **kwargs,\n    ):\n\n        BaseInjectedModule.__init__(\n            self, key, gguf_loader, config, orig_module, device, **kwargs\n        )\n        self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold\n        self.transfer_map = transfer_map\n        self.stream_device_map = dict()\n        user_path: str = os.path.expanduser('~')\n        localstore_path: str = os.path.join(user_path,'.ktransformers')\n        config_path: str = os.path.join(localstore_path,Config.CONFIG_FILE_NAME)\n        with open(config_path,\"r\") as file:\n            config_yaml = yaml.safe_load(file.read())\n            self.long_context_config = config_yaml.get(\"long_context\")\n            self.ext_config = config_yaml.get(\"ext\")\n\n        KLlamaModel.dynamic_sdpa = DynamicScaledDotProductAttention(\n            max_seq_len=self.long_context_config[\"max_seq_len\"],\n            block_size=self.long_context_config[\"block_size\"],\n            config=config,\n            device=torch.device(\"cuda\"),\n            local_windows_len=self.long_context_config[\"local_windows_len\"],\n            topk=self.long_context_config[\"second_select_num\"],\n            threads_num=self.ext_config[\"cpu_infer\"],\n            anchor_type=self.long_context_config[\"anchor_type\"],\n            kv_type=self.long_context_config[\"kv_type\"],\n            dense_layer_num=self.long_context_config[\"dense_layer_num\"],\n            anchor_num=self.long_context_config[\"anchor_num\"],\n            
preselect_block=self.long_context_config[\"preselect_block\"],\n            block_selection_mode=self.long_context_config[\"head_select_mode\"],\n            preselect_block_count=self.long_context_config[\"preselect_block_count\"],\n            layer_step=self.long_context_config[\"layer_step\"],\n            token_step=self.long_context_config[\"token_step\"],\n            prefill_chunk_size=self.long_context_config[\"chunk_size\"],\n            use_attn_sparsity=False,\n        )\n\n    def get_input_embeddings(self):\n        return self.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        if (input_ids is None) ^ (inputs_embeds is not None):\n            raise ValueError(\n     
           \"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one\"\n            )\n\n        if self.gradient_checkpointing and self.training and use_cache:\n            logger.warning_once(\n                \"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.\"\n            )\n            use_cache = False\n\n        return_legacy_cache = False\n        if (\n            use_cache and not isinstance(past_key_values, Cache) and not self.training\n        ):  # kept for BC (non `Cache` `past_key_values` inputs)\n            return_legacy_cache = True\n            past_key_values = DynamicCache.from_legacy_cache(past_key_values)\n            logger.warning_once(\n                \"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. \"\n                \"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\"\n            )\n\n        if cache_position is None:\n            past_seen_tokens = (\n                past_key_values.get_seq_length() if past_key_values is not None else 0\n            )\n            cache_position = torch.arange(\n                past_seen_tokens,\n                past_seen_tokens + inputs_embeds.shape[1],\n                device=\"cuda\",\n            )\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0)\n\n        causal_mask = None\n        chunck_size = self.long_context_config[\"chunk_size\"]\n        cur_idx = 0\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids.to(\"cpu\"))\n        q_len = cache_position.size(0)\n\n        # generate\n        if q_len == 1:\n            x = inputs_embeds[:, -1:, :]\n            position_ids = position_ids[:, -1:]\n            return self.forward_chunk(\n                x,\n                
causal_mask,\n                position_ids,\n                past_key_values,\n                output_attentions,\n                use_cache,\n                cache_position,\n                output_hidden_states,\n                return_dict,\n            )\n        elif q_len <= chunck_size:\n            inputs_embeds = inputs_embeds.to('cuda')\n            output = self.forward_chunk(\n                inputs_embeds,\n                causal_mask,\n                position_ids,\n                past_key_values,\n                output_attentions,\n                use_cache,\n                cache_position,\n                output_hidden_states,\n                return_dict,\n            )\n            KLlamaModel.dynamic_sdpa.calc_anchor(cache_position[-1] + 1)\n            KLlamaModel.dynamic_sdpa.clear_importance(cache_position[-1] + 1)\n            return output\n        cur_idx = 0\n        assert (\n            output_attentions == False\n        ), \"output_attentions is not supported when using chunked attention\"\n        attn_output = None\n        # prefill\n        KLlamaModel.dynamic_sdpa.remaining_length = q_len\n        while cur_idx < q_len:\n            print(f'current prefill length: {cur_idx}')\n            chunk_mask = None\n            if inputs_embeds.device.type == 'cpu':\n                tmp_inputs_embeds = inputs_embeds[:, cur_idx : min(cur_idx + chunck_size, q_len)].to(\"cuda\")\n            else:\n                tmp_inputs_embeds = inputs_embeds[:, cur_idx : min(cur_idx + chunck_size, q_len)]\n            output_with_past = self.forward_chunk(\n                tmp_inputs_embeds,\n                chunk_mask,\n                position_ids[:, cur_idx : min(cur_idx + chunck_size, q_len)],\n                past_key_values,\n                output_attentions,\n                use_cache,\n                cache_position[cur_idx : min(cur_idx + chunck_size, q_len)],\n            )\n            cur_output = output_with_past.last_hidden_state\n     
       KLlamaModel.dynamic_sdpa.remaining_length -= (\n                min(cur_idx + chunck_size, q_len) - cur_idx\n            )\n            cur_idx += chunck_size\n            # if attn_output is None:\n            attn_output = cur_output\n            # else:\n            #     attn_output = torch.cat((attn_output, cur_output), dim=-2)\n\n        KLlamaModel.dynamic_sdpa.calc_anchor(cache_position[-1] + 1)\n        KLlamaModel.dynamic_sdpa.clear_importance(cache_position[-1] + 1)\n        return BaseModelOutputWithPast(last_hidden_state=attn_output)\n\n    def forward_chunk(\n        self,\n        inputs_embeds,\n        causal_mask,\n        position_ids,\n        past_key_values,\n        output_attentions,\n        use_cache,\n        cache_position,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ):\n\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        return_legacy_cache = False\n        if use_cache and not isinstance(\n            past_key_values, Cache\n        ):  # kept for BC (non `Cache` `past_key_values` inputs)\n            return_legacy_cache = True\n            past_key_values = DynamicCache.from_legacy_cache(past_key_values)\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        hidden_states = inputs_embeds\n\n        # create position embeddings to be shared across the decoder layers\n        position_embeddings = self.rotary_emb(hidden_states, position_ids)\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        next_decoder_cache = None\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if 
output_attentions else None\n        next_decoder_cache = None\n\n        for decoder_layer in self.layers:\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            if self.gradient_checkpointing and self.training:\n                layer_outputs = self._gradient_checkpointing_func(\n                    decoder_layer.__call__,\n                    hidden_states,\n                    causal_mask,\n                    position_ids,\n                    past_key_values,\n                    output_attentions,\n                    use_cache,\n                    cache_position,\n                    position_embeddings,\n                )\n            else:\n                layer_outputs = decoder_layer(\n                    hidden_states,\n                    attention_mask=causal_mask,\n                    position_ids=position_ids,\n                    past_key_value=past_key_values,\n                    output_attentions=output_attentions,\n                    use_cache=use_cache,\n                    cache_position=cache_position,\n                    position_embeddings=position_embeddings,\n                )\n\n            hidden_states = layer_outputs[0]\n\n            if use_cache:\n                next_decoder_cache = layer_outputs[2 if output_attentions else 1]\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n        hidden_states = self.norm(hidden_states)\n\n        # add hidden states from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        next_cache = next_decoder_cache if use_cache else None\n        if return_legacy_cache:\n            next_cache = next_cache.to_legacy_cache()\n\n        if not return_dict:\n            return tuple(\n                v\n                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]\n                if v is not None\n            )\n    
    return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=next_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n        )\n\n    def _update_causal_mask(\n        self,\n        attention_mask: torch.Tensor,\n        input_tensor: torch.Tensor,\n        cache_position: torch.Tensor,\n        past_key_values: Cache,\n        output_attentions: bool,\n    ):\n        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static\n        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.\n        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using\n        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114\n\n        if self.config._attn_implementation == \"flash_attention_2\":\n            if attention_mask is not None and 0.0 in attention_mask:\n                return attention_mask\n            return None\n\n        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in\n        # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail\n        # to infer the attention mask.\n        past_seen_tokens = (\n            past_key_values.get_seq_length() if past_key_values is not None else 0\n        )\n        using_static_cache = isinstance(past_key_values, StaticCache)\n\n        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward\n        if (\n            self.config._attn_implementation == \"sdpa\"\n            and not using_static_cache\n            and not output_attentions\n        ):\n            if AttentionMaskConverter._ignore_causal_mask_sdpa(\n                attention_mask,\n                inputs_embeds=input_tensor,\n                past_key_values_length=past_seen_tokens,\n                is_training=self.training,\n            ):\n                return None\n\n        dtype, device = input_tensor.dtype, input_tensor.device\n        min_dtype = torch.finfo(dtype).min\n        sequence_length = input_tensor.shape[1]\n        if using_static_cache:\n            target_length = past_key_values.get_max_length()\n        else:\n            target_length = (\n                attention_mask.shape[-1]\n                if isinstance(attention_mask, torch.Tensor)\n                else past_seen_tokens + sequence_length + 1\n            )\n\n        if attention_mask is not None and attention_mask.dim() == 4:\n            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing\n            if attention_mask.max() != 0:\n                raise ValueError(\n                    \"Custom 4D attention mask should be passed in inverted form with `max==0`\"\n                )\n            causal_mask = attention_mask\n        else:\n            causal_mask = torch.full(\n                (sequence_length, target_length),\n                fill_value=min_dtype,\n                dtype=dtype,\n                device=device,\n      
      )\n            if sequence_length != 1:\n                causal_mask = torch.triu(causal_mask, diagonal=1)\n            causal_mask *= torch.arange(\n                target_length, device=device\n            ) > cache_position.reshape(-1, 1)\n            causal_mask = causal_mask[None, None, :, :].expand(\n                input_tensor.shape[0], 1, -1, -1\n            )\n            if attention_mask is not None:\n                causal_mask = (\n                    causal_mask.clone()\n                )  # copy to contiguous memory for in-place edit\n                mask_length = attention_mask.shape[-1]\n                padding_mask = (\n                    causal_mask[:, :, :, :mask_length]\n                    + attention_mask[:, None, None, :]\n                )\n                padding_mask = padding_mask == 0\n                causal_mask[:, :, :, :mask_length] = causal_mask[\n                    :, :, :, :mask_length\n                ].masked_fill(padding_mask, min_dtype)\n        if (\n            self.config._attn_implementation == \"sdpa\"\n            and attention_mask is not None\n            and attention_mask.device.type == \"cuda\"\n            and not output_attentions\n        ):\n            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when\n            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.\n            # Details: https://github.com/pytorch/pytorch/issues/110213\n            causal_mask = AttentionMaskConverter._unmask_unattended(\n                causal_mask, min_dtype\n            )\n\n        return causal_mask\n"
  },
  {
    "path": "archive/ktransformers/operators/triton_attention.py",
    "content": "# Adapted from\r\n# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py\r\n# which was originally adapted from\r\n# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py\r\n# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py\r\n\r\nimport triton\r\nimport triton.language as tl\r\nfrom ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor\r\n@triton.jit\r\ndef tanh(x):\r\n    # Tanh is just a scaled sigmoid\r\n    return 2 * tl.sigmoid(2 * x) - 1\r\n\r\n@triton.jit\r\ndef _fwd_grouped_kernel_stage1(\r\n    Q,\r\n    K_Buffer,\r\n    V_Buffer,\r\n    sm_scale,\r\n    Req_to_tokens,\r\n    B_Seqlen,\r\n    Att_Out,\r\n    stride_req_to_tokens_b,\r\n    stride_qbs,\r\n    stride_qh,\r\n    stride_buf_kbs,\r\n    stride_buf_kh,\r\n    stride_buf_vbs,\r\n    stride_buf_vh,\r\n    stride_mid_ob,\r\n    stride_mid_oh,\r\n    stride_mid_os,\r\n    kv_group_num: tl.constexpr,\r\n    q_head_num: tl.constexpr,\r\n    BLOCK_DMODEL: tl.constexpr,\r\n    BLOCK_DPE: tl.constexpr,\r\n    BLOCK_DV: tl.constexpr,\r\n    BLOCK_N: tl.constexpr,\r\n    BLOCK_H: tl.constexpr,\r\n    NUM_KV_SPLITS: tl.constexpr,\r\n    PAGE_SIZE: tl.constexpr,\r\n    logit_cap: tl.constexpr,\r\n    Lk: tl.constexpr,\r\n    Lv: tl.constexpr,\r\n):\r\n    cur_batch = tl.program_id(0)\r\n    cur_head_id = tl.program_id(1)\r\n    cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)\r\n    split_kv_id = tl.program_id(2)\r\n\r\n    if kv_group_num > BLOCK_H:\r\n        VALID_BLOCK_H: tl.constexpr = BLOCK_H\r\n    else:\r\n        VALID_BLOCK_H: tl.constexpr = kv_group_num\r\n    cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)\r\n    mask_h = cur_head < 
(cur_head_id + 1) * VALID_BLOCK_H\r\n    mask_h = mask_h & (cur_head < q_head_num)\r\n\r\n    offs_d = tl.arange(0, BLOCK_DMODEL)\r\n    offs_dv = tl.arange(0, BLOCK_DV)\r\n    mask_d = offs_d < Lk\r\n    mask_dv = offs_dv < Lv\r\n    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\r\n    cur_batch_req_idx = cur_batch\r\n\r\n    offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[\r\n        None, :]\r\n    q = tl.load(Q + offs_q,\r\n                mask=(mask_h[:, None]) & (mask_d[None, :]),\r\n                other=0.0)\r\n\r\n    if BLOCK_DPE > 0:\r\n        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)\r\n        mask_dpe = offs_dpe < Lk\r\n        off_qpe = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh +\r\n                   offs_dpe[None, :])\r\n        qpe = tl.load(Q + off_qpe,\r\n                      mask=(mask_h[:, None]) & (mask_dpe[None, :]),\r\n                      other=0.0)\r\n\r\n    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)\r\n    split_kv_start = kv_len_per_split * split_kv_id\r\n    split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,\r\n                              cur_batch_seq_len)\r\n    \r\n    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float(\"inf\")\r\n    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)\r\n    acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)\r\n\r\n    if split_kv_end > split_kv_start:\r\n        for start_n in range(split_kv_start, split_kv_end, BLOCK_N):\r\n            offs_n = start_n + tl.arange(0, BLOCK_N)\r\n            kv_page_number = tl.load(\r\n                Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx +\r\n                offs_n // PAGE_SIZE,\r\n                mask=offs_n < split_kv_end,\r\n                other=0,\r\n            )\r\n            kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE\r\n            offs_buf_k = (kv_loc[None, :] * stride_buf_kbs +\r\n                          cur_kv_head * stride_buf_kh 
+ offs_d[:, None])\r\n            k = tl.load(\r\n                K_Buffer + offs_buf_k,\r\n                mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),\r\n                other=0.0,\r\n            )\r\n            qk = tl.dot(q, k.to(q.dtype))\r\n            \r\n            if BLOCK_DPE > 0:\r\n                offs_buf_kpe = (kv_loc[None, :] * stride_buf_kbs +\r\n                                cur_kv_head * stride_buf_kh +\r\n                                offs_dpe[:, None])\r\n                kpe = tl.load(\r\n                    K_Buffer + offs_buf_kpe,\r\n                    mask=(offs_n[None, :] < split_kv_end) &\r\n                    (mask_dpe[:, None]),\r\n                    other=0.0,\r\n                )\r\n                qk += tl.dot(qpe, kpe.to(qpe.dtype))\r\n            qk *= sm_scale\r\n\r\n            if logit_cap > 0:\r\n                qk = logit_cap * tanh(qk / logit_cap)\r\n\r\n            qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end),\r\n                          qk, float(\"-inf\"))\r\n\r\n            offs_buf_v = (kv_loc[:, None] * stride_buf_vbs +\r\n                          cur_kv_head * stride_buf_vh + offs_dv[None, :])\r\n            v = tl.load(\r\n                V_Buffer + offs_buf_v,\r\n                mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),\r\n                other=0.0,\r\n            )\r\n\r\n            n_e_max = tl.maximum(tl.max(qk, 1), e_max)\r\n            re_scale = tl.exp(e_max - n_e_max)\r\n            p = tl.exp(qk - n_e_max[:, None])\r\n            acc *= re_scale[:, None]\r\n            acc += tl.dot(p.to(v.dtype), v)\r\n\r\n            e_sum = e_sum * re_scale + tl.sum(p, 1)\r\n            e_max = n_e_max\r\n\r\n        offs_mid_o = (cur_batch * stride_mid_ob +\r\n                      cur_head[:, None] * stride_mid_oh +\r\n                      split_kv_id * stride_mid_os + offs_dv[None, :])\r\n\r\n        tl.store(\r\n            Att_Out + offs_mid_o,\r\n         
   acc / e_sum[:, None],\r\n            mask=(mask_h[:, None]) & (mask_dv[None, :]),\r\n        )\r\n\r\n        offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh +\r\n                        split_kv_id * stride_mid_os + Lv)\r\n\r\n        tl.store(\r\n            Att_Out + offs_mid_o_1,\r\n            e_max + tl.log(e_sum),\r\n            mask=mask_h,\r\n        )\r\n\r\ndef _decode_grouped_att_m_fwd(\r\n    q,\r\n    k_buffer,\r\n    v_buffer,\r\n    att_out,\r\n    Req_to_tokens,\r\n    B_Seqlen,\r\n    num_kv_splits,\r\n    sm_scale,\r\n    page_size,\r\n    logit_cap,\r\n):\r\n    BLOCK = 32\r\n    Lk = k_buffer.shape[-1]\r\n    Lv = v_buffer.shape[-1]\r\n\r\n    # [TODO] work around shmem limit on MI3xx\r\n    \r\n    # TODO: support hip\r\n    if device_manager.gpu_vendor == GPUVendor.AMD and Lk >= 576:\r\n       BLOCK = 16\r\n\r\n    if Lk == 576:\r\n        BLOCK_DMODEL = 512\r\n        BLOCK_DPE = 64\r\n    elif Lk == 288:\r\n        BLOCK_DMODEL = 256\r\n        BLOCK_DPE = 32\r\n    else:\r\n        BLOCK_DMODEL = triton.next_power_of_2(Lk)\r\n        BLOCK_DPE = 0\r\n    BLOCK_DV = triton.next_power_of_2(Lv)\r\n\r\n    batch, head_num = q.shape[0], q.shape[1]\r\n    kv_group_num = q.shape[1] // k_buffer.shape[-2]\r\n\r\n    BLOCK_H = 16\r\n    NUM_KV_SPLITS = num_kv_splits\r\n    grid = (\r\n        batch,\r\n        triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),\r\n        NUM_KV_SPLITS,\r\n    )\r\n\r\n    extra_kargs = {}\r\n    # TODO: support hip\r\n    \"\"\"\r\n    if is_hip_:\r\n        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html\r\n        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py\r\n        extra_kargs = {\r\n            \"waves_per_eu\": 4,\r\n            \"matrix_instr_nonkdim\": 16,\r\n            \"kpack\": 2\r\n        }\r\n    \"\"\"\r\n    \r\n    _fwd_grouped_kernel_stage1[grid](\r\n        q,\r\n        
k_buffer,\r\n        v_buffer,\r\n        sm_scale,\r\n        Req_to_tokens,\r\n        B_Seqlen,\r\n        att_out,\r\n        Req_to_tokens.stride(0),\r\n        q.stride(0),\r\n        q.stride(1),\r\n        k_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)\r\n        k_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)\r\n        v_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)\r\n        v_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)\r\n        att_out.stride(0),\r\n        att_out.stride(1),\r\n        att_out.stride(2),\r\n        kv_group_num=kv_group_num,\r\n        q_head_num=head_num,\r\n        BLOCK_DMODEL=BLOCK_DMODEL,\r\n        BLOCK_DPE=BLOCK_DPE,\r\n        BLOCK_DV=BLOCK_DV,\r\n        BLOCK_N=BLOCK,\r\n        BLOCK_H=BLOCK_H,\r\n        NUM_KV_SPLITS=NUM_KV_SPLITS,\r\n        PAGE_SIZE=page_size,\r\n        logit_cap=logit_cap,\r\n        num_warps=4,\r\n        num_stages=2,\r\n        Lk=Lk,\r\n        Lv=Lv,\r\n        **extra_kargs,\r\n    )\r\n\r\n@triton.jit\r\ndef _fwd_kernel_stage2(\r\n    Mid_O,\r\n    o,\r\n    B_Seqlen,\r\n    stride_mid_ob,\r\n    stride_mid_oh,\r\n    stride_mid_os,\r\n    stride_obs,\r\n    stride_oh,\r\n    NUM_KV_SPLITS: tl.constexpr,\r\n    BLOCK_DV: tl.constexpr,\r\n    Lv: tl.constexpr,\r\n):\r\n    cur_batch = tl.program_id(0)\r\n    cur_head = tl.program_id(1)\r\n\r\n    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\r\n\r\n    offs_d = tl.arange(0, BLOCK_DV)\r\n    mask_d = offs_d < Lv\r\n\r\n    e_sum = 0.0\r\n    e_max = -float(\"inf\")\r\n    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)\r\n\r\n    offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d\r\n    offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv\r\n\r\n    for split_kv_id in range(0, NUM_KV_SPLITS):\r\n        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)\r\n        split_kv_start = kv_len_per_split * 
split_kv_id\r\n        split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,\r\n                                  cur_batch_seq_len)\r\n\r\n        if split_kv_end > split_kv_start:\r\n            tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os,\r\n                         mask=mask_d,\r\n                         other=0.0)\r\n            tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)\r\n            n_e_max = tl.maximum(tlogic, e_max)\r\n\r\n            old_scale = tl.exp(e_max - n_e_max)\r\n            acc *= old_scale\r\n            exp_logic = tl.exp(tlogic - n_e_max)\r\n            acc += exp_logic * tv\r\n\r\n            e_sum = e_sum * old_scale + exp_logic\r\n            e_max = n_e_max\r\n\r\n    tl.store(\r\n        o + cur_batch * stride_obs + cur_head * stride_oh + offs_d,\r\n        acc / e_sum,\r\n        mask=mask_d,\r\n    )\r\n\r\ndef _decode_softmax_reducev_fwd(\r\n    logits,\r\n    q,\r\n    o,\r\n    v_buffer,\r\n    b_seq_len,\r\n    num_kv_splits,\r\n):\r\n    batch, head_num = q.shape[0], q.shape[1]\r\n    Lv = v_buffer.shape[-1]\r\n    BLOCK_DV = triton.next_power_of_2(Lv)\r\n\r\n    NUM_KV_SPLITS = num_kv_splits\r\n\r\n    extra_kargs = {}\r\n    # TODO: support hip\r\n    \"\"\"\r\n    if is_hip_:\r\n        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html\r\n        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py\r\n        extra_kargs = {\r\n            \"waves_per_eu\": 4,\r\n            \"matrix_instr_nonkdim\": 16,\r\n            \"kpack\": 2\r\n        }\r\n    \"\"\"\r\n    \r\n    grid = (batch, head_num)\r\n    _fwd_kernel_stage2[grid](\r\n        logits,\r\n        o,\r\n        b_seq_len,\r\n        logits.stride(0),\r\n        logits.stride(1),\r\n        logits.stride(2),\r\n        o.stride(0),\r\n        o.stride(1),\r\n        NUM_KV_SPLITS=NUM_KV_SPLITS,\r\n        BLOCK_DV=BLOCK_DV,\r\n     
   Lv=Lv,\r\n        num_warps=4,\r\n        num_stages=2,\r\n        **extra_kargs,\r\n    )\r\n\r\ndef decode_attention_fwd_grouped(\r\n    q,\r\n    k_buffer,\r\n    v_buffer,\r\n    o,\r\n    req_to_token,\r\n    b_seq_len,\r\n    attn_logits,\r\n    num_kv_splits,\r\n    sm_scale,\r\n    page_size,\r\n    logit_cap=0.0,\r\n):\r\n    _decode_grouped_att_m_fwd(\r\n        q,\r\n        k_buffer,\r\n        v_buffer,\r\n        attn_logits,\r\n        req_to_token,\r\n        b_seq_len,\r\n        num_kv_splits,\r\n        sm_scale,\r\n        page_size,\r\n        logit_cap,\r\n    )\r\n\r\n    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len,\r\n                                num_kv_splits)\r\n"
  },
  {
    "path": "archive/ktransformers/operators/triton_attention_prefill.py",
    "content": "\n# Adapted from\n# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py\n# which was originally adapted from\n# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1\n\n\"\"\"\nMemory-efficient attention for prefill.\nIt supports page size = 1.\n\"\"\"\n\n# Adapted from\n# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1\nimport torch\nimport triton\nimport triton.language as tl\n\nis_cuda_available = torch.cuda.is_available()\nif is_cuda_available:\n    CUDA_CAPABILITY = torch.cuda.get_device_capability()\n\n\n@triton.jit\ndef _fwd_kernel(\n    Q,\n    K,\n    V,\n    sm_scale,\n    B_Start_Loc,\n    B_Seqlen,\n    Out,\n    stride_qbs,\n    stride_qh,\n    stride_kbs,\n    stride_kh,\n    stride_vbs,\n    stride_vh,\n    stride_obs,\n    stride_oh,\n    kv_group_num: tl.constexpr,\n    BLOCK_M: tl.constexpr,\n    BLOCK_DMODEL: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    IS_CAUSAL: tl.constexpr,\n    Lk: tl.constexpr,\n):\n    cur_batch = tl.program_id(0)\n    cur_head = tl.program_id(1)\n    start_m = tl.program_id(2)\n\n    cur_kv_head = cur_head // kv_group_num\n\n    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n    block_start_loc = BLOCK_M * start_m\n\n    # initialize offsets\n    offs_n = tl.arange(0, BLOCK_N)\n    offs_d = tl.arange(0, BLOCK_DMODEL)\n    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n    off_q = (\n        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs\n        + cur_head * stride_qh\n        + offs_d[None, :]\n    )\n    off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]\n    off_v = offs_n[:, None] 
* stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]\n\n    mask_d = offs_d < Lk\n\n    q = tl.load(\n        Q + off_q,\n        mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]),\n        other=0.0,\n    )\n\n    k_ptrs = K + off_k\n    v_ptrs = V + off_v\n\n    # initialize pointer to m and l\n    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n    block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)\n\n    end_n = (\n        cur_batch_seq_len\n        if not IS_CAUSAL\n        else tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len)\n    )\n    for start_n in range(0, block_mask * end_n, BLOCK_N):\n        start_n = tl.multiple_of(start_n, BLOCK_N)\n        # -- compute qk ----\n        k = tl.load(\n            k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,\n            mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]),\n            other=0.0,\n        )\n        # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)\n\n        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n        qk += tl.dot(q, k)\n        qk *= sm_scale\n\n        if IS_CAUSAL:\n            qk += tl.where(\n                (start_n + offs_n[None, :] < cur_batch_seq_len)\n                & (offs_m[:, None] >= (start_n + offs_n[None, :])),\n                0,\n                float(\"-inf\"),\n            )\n        else:\n            qk += tl.where(\n                (start_n + offs_n[None, :]) < cur_batch_seq_len, 0, float(\"-inf\")\n            )\n\n        # -- compute m_ij, p, l_ij\n        m_ij = tl.max(qk, 1)\n        p = tl.exp(qk - m_ij[:, None])\n        l_ij = tl.sum(p, 1)\n        # -- update m_i and l_i\n        m_i_new = tl.maximum(m_i, m_ij)\n        alpha = tl.exp(m_i - m_i_new)\n        beta = tl.exp(m_ij - m_i_new)\n        l_i_new = 
alpha * l_i + beta * l_ij\n        # -- update output accumulator --\n        # scale p\n        p_scale = beta / l_i_new\n        p = p * p_scale[:, None]\n        # scale acc\n        acc_scale = l_i / l_i_new * alpha\n        acc = acc * acc_scale[:, None]\n        # update acc\n        v = tl.load(\n            v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,\n            mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]),\n            other=0.0,\n        )\n\n        p = p.to(v.dtype)\n        acc += tl.dot(p, v)\n        # update m_i and l_i\n        l_i = l_i_new\n        m_i = m_i_new\n    # initialize pointers to output\n    off_o = (\n        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs\n        + cur_head * stride_oh\n        + offs_d[None, :]\n    )\n    out_ptrs = Out + off_o\n    tl.store(\n        out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :])\n    )\n\n\ndef context_attention_fwd(\n    q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True\n):\n    \"\"\"\n    q, k, v: [b * s, head, head_dim]\n    b_start_loc: [b]\n    b_seq_len: [b]\n    out: [b * s, head, head_dim]\n    \"\"\"\n    if is_cuda_available and CUDA_CAPABILITY[0] > 8:\n        BLOCK = 128\n    else:\n        BLOCK = 64\n\n    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n\n    sm_scale = 1.0 / (Lq**0.5)\n    batch, head = b_seq_len.shape[0], q.shape[1]\n    kv_group_num = q.shape[1] // k.shape[1]\n\n    grid = (batch, head, triton.cdiv(max_input_len, BLOCK))\n    num_warps = 4 if Lk <= 64 else 8\n\n    _fwd_kernel[grid](\n        q,\n        k,\n        v,\n        sm_scale,\n        b_start_loc,\n        b_seq_len,\n        o,\n        q.stride(0),\n        q.stride(1),\n        k.stride(0),\n        k.stride(1),\n        v.stride(0),\n        v.stride(1),\n        o.stride(0),\n        o.stride(1),\n        kv_group_num=kv_group_num,\n        BLOCK_M=BLOCK,\n        
BLOCK_DMODEL=triton.next_power_of_2(Lk),\n        BLOCK_N=BLOCK,\n        IS_CAUSAL=is_causal,\n        num_warps=num_warps,\n        num_stages=1,\n        Lk=Lk,\n    )"
  },
  {
    "path": "archive/ktransformers/optimize/optimize.py",
    "content": "'''\nDescription  :  \nAuthor       : Boxin Zhang, Azure-Tang\nVersion      : 0.1.0\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nfrom typing import Mapping, List\nimport torch\nimport yaml\nimport re\nfrom torch import nn\nfrom transformers import AutoConfig\nfrom transformers.configuration_utils import PretrainedConfig\n# from operators import BaseInjectedModule\nfrom ktransformers.util.custom_loader import GGUFLoader, ModelLoaderFactory\nfrom ktransformers.util.custom_gguf import translate_name_to_gguf\nfrom ktransformers.util import utils\nfrom ktransformers.util.utils import set_module, load_weights\nimport itertools\nimport copy\n\ntry:\n    import torch_npu\n    use_torch_npu = torch_npu.npu.is_available()\nexcept:\n    use_torch_npu = False\n\ndef inject(module, local_optimization_dict, model_config:AutoConfig ,gguf_loader:GGUFLoader, prefix=''):\n    for name, child in module._modules.items():\n        if child is not None:\n            child_prefix = prefix + name\n            if child_prefix in local_optimization_dict:\n                inject_module_meta=local_optimization_dict[child_prefix]\n                if inject_module_meta[\"class\"] != \"default\":\n                    import_path = inject_module_meta[\"class\"].split(\".\")\n                    import_module_name = \".\".join(import_path[:-1])\n                    gguf_loader.tensor_device_map[inject_module_meta[\"key\"]] = inject_module_meta[\"kwargs\"] if \"kwargs\" in inject_module_meta else dict()\n                    import_class_name = import_path[-1]\n                    module_cls=getattr(__import__(import_module_name, fromlist=[\"\"]), import_class_name)\n                    if use_torch_npu:\n                        print(f\"Injecting {child_prefix} as\", import_module_name, \".\",\n                            import_class_name) if torch.distributed.get_rank() == 0 else None #TODO 分布式\n                    else: \n                        
print(f\"Injecting {child_prefix} as\", import_module_name, \".\", import_class_name)\n                    inject_module=module_cls(key = inject_module_meta[\"key\"], gguf_loader = gguf_loader, config = model_config, orig_module=child, **inject_module_meta[\"kwargs\"])\n                    set_module(module, name, inject_module)\n                elif inject_module_meta[\"class\"] == \"default\":\n                    print(f\"Injecting {child_prefix} as default\")\n                    gguf_loader.tensor_device_map[inject_module_meta[\"key\"]] = inject_module_meta[\"kwargs\"] if \"kwargs\" in inject_module_meta else dict()\n                else:\n                    raise Exception(\"inject_module_meta[\\\"class\\\"] must be \\\"default\\\" or a class path\")\n                child_prefix += \".\"\n                child_optimization_dict = {k: v for k, v in local_optimization_dict.items() if k.startswith(child_prefix)}\n                inject(child, child_optimization_dict, model_config, gguf_loader, child_prefix)\n\ndef del_meta(module:nn.Module):\n    #print(\"default loading weights\", prefix)\n    persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}\n    local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items())\n    local_state = {k: v for k, v in local_name_params if v is not None}\n    for name, param in local_state.items():\n        if param.device == \"meta\" or param.device == torch.device(\"meta\"):\n            module.__delattr__(name)\n    for name, child in module._modules.items():\n        del_meta(child)\n\ndef gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, prefix: str=\"\", default_device: str = \"cuda:0\"):\n    module_name = prefix[:-1]\n    if use_torch_npu:\n        translated_name = translate_name_to_gguf(prefix)[:-1]\n    recursive = True\n    for rule in rule_list:\n        match_meta = rule[\"match\"]\n        if \"class\" 
not in match_meta and \"name\" not in match_meta:\n            raise Exception(\"match must have at least one of \\\"class\\\" and \\\"name\\\"\")\n        if \"class\" in match_meta:\n            import_path = match_meta[\"class\"].split(\".\")\n            import_module_name = \".\".join(import_path[:-1])\n            import_class_name = import_path[-1]\n            module_cls=getattr(__import__(import_module_name, fromlist=[\"\"]), import_class_name)\n            if not isinstance(module, module_cls):\n                continue\n        if \"name\" in match_meta:\n            if re.search(match_meta[\"name\"], module_name) is None:\n                continue\n        if \"replace\" not in rule:\n            raise Exception(\"replace must be in rule\")\n        if \"replace\" in rule:\n            replace_meta = rule[\"replace\"]\n            if module_name not in out_data:\n                out_data[module_name]={\"key\": module_name if not use_torch_npu else translated_name,\n                                    \"class\": replace_meta[\"class\"] if \"class\" in replace_meta else \"default\",\n                                    # \"device\": replace_meta[\"device\"] if \"device\" in replace_meta else default_device,\n                                    \"kwargs\": copy.deepcopy(replace_meta[\"kwargs\"]) if \"kwargs\" in replace_meta else dict()}\n            else:\n                if out_data[module_name][\"class\"] == \"default\":\n                    out_data[module_name][\"class\"] = replace_meta[\"class\"] if \"class\" in replace_meta else \"default\"\n                out_data[module_name][\"kwargs\"].update(copy.deepcopy(replace_meta[\"kwargs\"]) if \"kwargs\" in replace_meta else dict())\n        if \"recursive\" in rule:\n            recursive = bool(rule[\"recursive\"])\n        break\n            \n    if module_name not in out_data:\n        out_data[module_name]= {\n            \"class\": \"default\",\n            \"key\": module_name if not 
use_torch_npu else translated_name,\n            \"kwargs\": {\"generate_device\": default_device,\n                       \"prefill_device\": default_device}\n        }\n\n    #print(out_data[module_name])\n    #input()\n\n    if recursive:\n        for name, child in module._modules.items():\n            if child is not None:\n                child_prefix = prefix + name + \".\"\n                gen_optimize_config(child, out_data, rule_list, child_prefix, default_device = default_device)\n    \n\ndef translate_model_config(model_config: PretrainedConfig):\n    # adjust the config for models that need special handling\n    if model_config.model_type == \"mixtral\":\n        model_config.moe_intermediate_size = model_config.intermediate_size\n    \n    return model_config\n\n\ndef optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, model_config: PretrainedConfig, default_device: str = \"cuda:0\", q4_gguf_path=\"\"):\n    with open(rule_file, 'r', encoding='utf-8') as f:\n        rule_list = yaml.load(f.read(), Loader=yaml.FullLoader)\n    \n    optimize_config = dict()\n    gen_optimize_config(module, optimize_config, rule_list, default_device = default_device)\n    \n    model_config = translate_model_config(model_config)\n\n    if use_torch_npu:\n        if q4_gguf_path:\n            q4_gguf_loader = GGUFLoader(q4_gguf_path)\n            utils.Q4_GGUF_LODER = q4_gguf_loader\n        gguf_loader = GGUFLoader(gguf_path, getattr(model_config, \"quantize\", None))\n        with torch.device(\"meta\"):\n            inject(module, optimize_config, model_config, gguf_loader)\n        # preload lm_head because of its large intermediate result\n        load_weights(module.lm_head, gguf_loader, \"lm_head.\")\n        load_weights(module, gguf_loader)\n        module.gguf_loader = gguf_loader\n    else:\n        weights_loader = ModelLoaderFactory.create_loader(gguf_path)\n        with torch.device(\"meta\"):\n            inject(module, optimize_config, model_config, 
weights_loader)\n        # preload lm_head because of its large intermediate result\n        load_weights(module.lm_head, weights_loader, \"lm_head.\", device=default_device)\n        load_weights(module, weights_loader, device=default_device)\n        module.gguf_loader = weights_loader\n    del_meta(module)\n    if torch.cuda.is_available():\n        torch.cuda.empty_cache()\n    elif torch.xpu.is_available():\n        torch.xpu.empty_cache()\n    else:\n        torch.cuda.empty_cache()\n"
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu-4.yaml",
    "content": "- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n        generate_device: \"cpu\"\n        prefill_device: \"cpu\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|[1][0-4])\\\\.\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([2][0-9]|[1][5-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3][0-9]|[4][0-4])\\\\.\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n- match:\n    name: \"^model\\\\.layers\\\\.([5][0-9]|[4][5-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|[1][0-4])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\.([2][0-9]|[1][5-9])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"  # 
regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3][0-9]|[4][0-4])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\.([5][0-9]|[4][5-9])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|[1][0-4])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([2][0-9]|[1][5-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: 
\"cuda:1\"\n      prefill_device: \"cuda:1\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3][0-9]|[4][0-4])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n- match:\n    name: \"^model\\\\.layers\\\\.([5][0-9]|[4][5-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|[1][0-4])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda:0\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:0\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\.([2][0-9]|[1][5-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda:1\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:1\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\.([3][0-9]|[4][0-4])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda:2\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      
generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:2\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\.([5][0-9]|[4][5-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda:3\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:3\"\n  recursive: False # don't recursively inject submodules of this module\n\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|[1][0-4])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([2][0-9]|[1][5-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3][0-9]|[4][0-4])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n- match:\n    name: \"^model\\\\.layers\\\\.([5][0-9]|[4][5-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill\n      transfer_map: \n        15: \"cuda:1\"\n        30: \"cuda:2\"\n        45: \"cuda:3\"\n\n- 
match:\n    name: \"^model\\\\.layers\\\\.([0-9]|[1][0-4])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"(^model\\\\.layers\\\\.([2][0-9]|[1][5-9])\\\\.)\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n- match:\n    name: \"(^model\\\\.layers\\\\.([3][0-9]|[4][0-4])\\\\.)\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n      \n- match:\n    name: \"(^model\\\\.layers\\\\.([5][0-9]|[4][5-9])\\\\.)|(^model.norm)\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\""
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml",
    "content": "- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n        generate_device: \"cpu\"\n        prefill_device: \"cpu\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([345][0-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.([345][0-9])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n  \n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with 
custom forward function\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([345][0-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda:0\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:0\"\n  recursive: False # don't recursively inject submodules of this module\n\n- match:\n    name: \"^model\\\\.layers\\\\.([345][0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda:1\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:1\"\n  recursive: False # don't recursively inject submodules of this module\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([345][0-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: 
\"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill\n      transfer_map: \n        30: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"(^model\\\\.layers\\\\.([345][0-9])\\\\.)|(model.norm)\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n"
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: 
\"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\""
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-gpu-cpu.yaml",
    "content": "- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n        generate_device: \"cpu\"\n        prefill_device: \"cpu\"\n\n# === Rotary Embedding Replacement ===\n\n# GPU 0: layers 0–9\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n# CPU: layers 10-29\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n# === Linear Layers Replacement (excluding self_attn) ===\n\n# GPU 0: layers 0–9\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9])\\\\.(?!self_attn).*$\"  # regular expression\n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n# CPU: layers 10-29\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9])\\\\.(?!self_attn).*$\"  # regular expression\n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n      generate_op: \"KLinearCPUInfer\"\n      prefill_op: \"KLinearTorch\"\n      out_device: \"cpu\"\n\n# === MLP (MoE) Replacement ===\n\n# GPU 0: layers 0–9\n- match:\n    name: 
\"^model\\\\.layers\\\\.(0|[1-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n# CPU: layers 10-29\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n# === MLP Gate Replacement ===\n\n# GPU 0: layers 0–9\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n# CPU: layers 10-29\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n# === MLP Experts Replacement ===\n\n# GPU 0: layers 0–9\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda:0\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:0\"\n  recursive: False # don't recursively inject submodules of this module\n# CPU: layers 10-29\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # 
custom MoE Kernel with expert parallelism\n    kwargs:\n      prefill_device: \"cpu\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cpu\"\n  recursive: False # don't recursively inject submodules of this module\n\n# === Self-Attention Replacement ===\n\n# GPU 0: layers 0–9\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n# CPU: layers 10-29\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n# === Overall Model Replacement with Transfer Map ===\n\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n      transfer_map:\n        10: \"cpu\"\n\n# === Default Catch-All for Other Modules ===\n\n# GPU 0: layers 0–9\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n# lm_head on GPU 0\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# CPU: layers 10-29\n- match:\n    name: \"(^model\\\\.layers\\\\.([12][0-9])\\\\.)|(model.norm)\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n"
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-multi-gpu.yaml",
    "content": "- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n        generate_device: \"cpu\"\n        prefill_device: \"cpu\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9])\\\\.(?!self_attn).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9])\\\\.(?!self_attn).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n  \n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function\n    kwargs:\n      
generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda:0\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:0\"\n  recursive: False # don't recursively inject submodules of this module\n\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda:1\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:1\"\n  recursive: False # don't recursively inject submodules of this module\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n      transfer_map: \n        10: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"(^model\\\\.layers\\\\.([12][0-9])\\\\.)|(model.norm)\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n"
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: 
\"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\""
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-amx.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^lm_head$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      
generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n      backend: \"AMXInt8\" # or \"AMXBF16\" or \"llamafile\" (default)\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      absorb_for_prefill: False # set to True to enable long context (prefill may be slower).\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\""
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts-serve-amx.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearFP8\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoEV2     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExpertsV2     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n      backend: \"llamafile\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.balance_serve_attention.flashinfer_attn # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: 
\"cuda\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm\n  replace:\n    class: ktransformers.operators.layernorm.RMSNorm\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP\n  replace:\n    class:  ktransformers.operators.mlp.kDeepseekV3MLP\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^lm_head$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"VLinearMarlin\"\n      prefill_op: \"KLinearTorch\""
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts-serve.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearFP8\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoEV2     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExpertsV2     # custom MoE Kernel with expert parallelism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.balance_serve_attention.flashinfer_attn # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: 
\"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm\n  replace:\n    class: ktransformers.operators.layernorm.RMSNorm\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP\n  replace:\n    class:  ktransformers.operators.mlp.kDeepseekV3MLP\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^lm_head$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"VLinearMarlin\"\n      prefill_op: \"KLinearTorch\""
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearFP8\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert parallelism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model$\"\n  
replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\""
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml",
    "content": "- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n# === Rotary Embedding Replacement ===\n\n# GPU 0: layers 0–14\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|1[0-4])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n# GPU 1: layers 15–29\n- match:\n    name: \"^model\\\\.layers\\\\.(1[5-9]|2[0-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n# GPU 2: layers 30–44\n- match:\n    name: \"^model\\\\.layers\\\\.(3[0-9]|4[0-4])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n\n# GPU 3: layers 45–60\n- match:\n    name: \"^model\\\\.layers\\\\.(4[5-9]|5[0-9]|60)\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n\n# === Linear Layers Replacement (excluding self_attn.kv_b_proj) ===\n\n# GPU 0: layers 0–14\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|1[0-4])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# GPU 1: 
layers 15–29\n- match:\n    name: \"^model\\\\.layers\\\\.(1[5-9]|2[0-9])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# GPU 2: layers 30–44\n- match:\n    name: \"^model\\\\.layers\\\\.(3[0-9]|4[0-4])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# GPU 3: layers 45–60\n- match:\n    name: \"^model\\\\.layers\\\\.(4[5-9]|5[0-9]|60)\\\\.(?!self_attn\\\\.kv_b_proj).*$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# === MLP (MoE) Replacement ===\n\n# GPU 0: layers 0–14\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|1[0-4])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n# GPU 1: layers 15–29\n- match:\n    name: \"^model\\\\.layers\\\\.(1[5-9]|2[0-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n# GPU 2: layers 30–44\n- match:\n    name: \"^model\\\\.layers\\\\.(3[0-9]|4[0-4])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: 
ktransformers.operators.experts.KDeepseekV3MoE\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n\n# GPU 3: layers 45–60\n- match:\n    name: \"^model\\\\.layers\\\\.(4[5-9]|5[0-9]|60)\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n\n# === MLP Gate Replacement ===\n\n# GPU 0: layers 0–14\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|1[0-4])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n# GPU 1: layers 15–29\n- match:\n    name: \"^model\\\\.layers\\\\.(1[5-9]|2[0-9])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n# GPU 2: layers 30–44\n- match:\n    name: \"^model\\\\.layers\\\\.(3[0-9]|4[0-4])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n\n# GPU 3: layers 45–60\n- match:\n    name: \"^model\\\\.layers\\\\.(4[5-9]|5[0-9]|60)\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n\n# === MLP Experts Replacement ===\n# replace with marlin expert. 
Uncomment and modify the layer numbers as needed.\n# Each layer of marlin experts takes about 6GB of GPU memory.\n# !!!Remember to disable CUDA graph if you are using marlin experts.!!!\n# !!!KExpertsTorch is untested, we don't have enough VRAM.!!!\n\n# GPU 0: layers 3–4\n# - match:\n#     name: \"^model\\\\.layers\\\\.([3-4])\\\\.mlp\\\\.experts$\"\n#   replace:\n#     class: ktransformers.operators.experts.KTransformersExperts\n#     kwargs:\n#       generate_device: \"cuda:0\"\n#       generate_op:  \"KExpertsMarlin\"\n#   recursive: False\n\n# # GPU 1: layers 15–17\n# - match:\n#     name: \"^model\\\\.layers\\\\.(1[5-7])\\\\.mlp\\\\.experts$\"\n#   replace:\n#     class: ktransformers.operators.experts.KTransformersExperts\n#     kwargs:\n#       generate_device: \"cuda:1\"\n#       generate_op:  \"KExpertsMarlin\"\n#   recursive: False\n\n# # GPU 2: layers 30–32\n# - match:\n#     name: \"^model\\\\.layers\\\\.(3[0-2])\\\\.mlp\\\\.experts$\"\n#   replace:\n#     class: ktransformers.operators.experts.KTransformersExperts\n#     kwargs:\n#       generate_device: \"cuda:2\"\n#       generate_op:  \"KExpertsMarlin\"\n#   recursive: False\n\n# # GPU 3: layers 45–46\n# - match:\n#     name: \"^model\\\\.layers\\\\.(4[5-6])\\\\.mlp\\\\.experts$\"\n#   replace:\n#     class: ktransformers.operators.experts.KTransformersExperts\n#     kwargs:\n#       generate_device: \"cuda:3\"\n#       generate_op:  \"KExpertsMarlin\"\n#   recursive: False\n\n\n# === MLP Experts Replacement ===\n\n# GPU 0: layers 0–14\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|1[0-4])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      prefill_device: \"cuda:0\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda:0\"\n  recursive: False\n\n# GPU 1: layers 15–29\n- match:\n    name: \"^model\\\\.layers\\\\.(1[5-9]|2[0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    
class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      prefill_device: \"cuda:1\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda:1\"\n  recursive: False\n\n# GPU 2: layers 30–44\n- match:\n    name: \"^model\\\\.layers\\\\.(3[0-9]|4[0-4])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      prefill_device: \"cuda:2\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda:2\"\n  recursive: False\n\n# GPU 3: layers 45–60\n- match:\n    name: \"^model\\\\.layers\\\\.(4[5-9]|5[0-9]|60)\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      prefill_device: \"cuda:3\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda:3\"\n  recursive: False\n\n# === Self-Attention Replacement ===\n\n# GPU 0: layers 0–14\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|1[0-4])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n      absorb_for_prefill: False\n\n# GPU 1: layers 15–29\n- match:\n    name: \"^model\\\\.layers\\\\.(1[5-9]|2[0-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      absorb_for_prefill: False\n\n# GPU 2: layers 30–44\n- match:\n    name: \"^model\\\\.layers\\\\.(3[0-9]|4[0-4])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n      absorb_for_prefill: False\n\n# GPU 3: layers 45–60\n- 
match:\n    name: \"^model\\\\.layers\\\\.(4[5-9]|5[0-9]|60)\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n      absorb_for_prefill: False\n\n# === Overall Model Replacement with Transfer Map ===\n\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n      transfer_map:\n        15: \"cuda:1\" # Layers 15+ on GPU 1\n        30: \"cuda:2\" # Layers 30+ on GPU 2\n        45: \"cuda:3\" # Layers 45+ on GPU 3\n\n# === Default Catch-All for Other Modules ===\n\n# GPU 0: layers 0–14\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|1[0-4])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n# GPU 1: layers 15–29\n- match:\n    name: \"^model\\\\.layers\\\\.(1[5-9]|2[0-9])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n# GPU 2: layers 30–44\n- match:\n    name: \"^model\\\\.layers\\\\.(3[0-9]|4[0-4])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# GPU 3: remaining modules in layers 45–60, plus the final model.norm\n- match:\n    name: \"(^model\\\\.layers\\\\.(4[5-9]|5[0-9]|60)\\\\.)|(^model\\\\.norm)\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n"
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-8.yaml",
    "content": "- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n# === Rotary Embedding Replacement ===\n\n# GPU 0: layers 0–7\n- match:\n    name: \"^model\\\\.layers\\\\.([0-7])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n# GPU 1: layers 8–15\n- match:\n    name: \"^model\\\\.layers\\\\.(8|9|1[0-5])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n# GPU 2: layers 16–23\n- match:\n    name: \"^model\\\\.layers\\\\.(1[6-9]|2[0-3])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n\n# GPU 3: layers 24–31\n- match:\n    name: \"^model\\\\.layers\\\\.(2[4-9]|3[0-1])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n\n# GPU 4: layers 32–39\n- match:\n    name: \"^model\\\\.layers\\\\.([3][2-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:4\"\n      prefill_device: \"cuda:4\"\n\n# GPU 5: layers 40–47\n- match:\n    name: \"^model\\\\.layers\\\\.(4[0-7])\\\\.\"\n    class: 
ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:5\"\n      prefill_device: \"cuda:5\"\n\n# GPU 6: layers 48–55\n- match:\n    name: \"^model\\\\.layers\\\\.(4[8-9]|5[0-5])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:6\"\n      prefill_device: \"cuda:6\"\n\n# GPU 7: layers 56–60\n- match:\n    name: \"^model\\\\.layers\\\\.(5[6-9]|60)\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:7\"\n      prefill_device: \"cuda:7\"\n\n\n# === Linear Layers Replacement (excluding self_attn.kv_b_proj) ===\n\n# GPU 0: layers 0–7\n- match:\n    name: \"^model\\\\.layers\\\\.([0-7])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# GPU 1: layers 8–15\n- match:\n    name: \"^model\\\\.layers\\\\.(8|9|1[0-5])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# GPU 2: layers 16–23\n- match:\n    name: \"^model\\\\.layers\\\\.(1[6-9]|2[0-3])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:2\"\n      
prefill_device: \"cuda:2\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# GPU 3: layers 24–31\n- match:\n    name: \"^model\\\\.layers\\\\.(2[4-9]|3[0-1])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# GPU 4: layers 32–39\n- match:\n    name: \"^model\\\\.layers\\\\.(3[2-9])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:4\"\n      prefill_device: \"cuda:4\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# GPU 5: layers 40–47\n- match:\n    name: \"^model\\\\.layers\\\\.(4[0-7])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:5\"\n      prefill_device: \"cuda:5\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# GPU 6: layers 48–55\n- match:\n    name: \"^model\\\\.layers\\\\.(4[8-9]|5[0-5])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:6\"\n      prefill_device: \"cuda:6\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# GPU 7: layers 56–60\n- match:\n    name: \"^model\\\\.layers\\\\.(5[6-9]|60)\\\\.(?!self_attn\\\\.kv_b_proj).*$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:7\"\n      prefill_device: \"cuda:7\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: 
\"KLinearTorch\"\n\n\n\n# === MLP (MoE) Replacement ===\n\n# GPU 0: layers 0–7\n- match:\n    name: \"^model\\\\.layers\\\\.([0-7])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n# GPU 1: layers 8–15\n- match:\n    name: \"^model\\\\.layers\\\\.(8|9|1[0-5])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n# GPU 2: layers 16–23\n- match:\n    name: \"^model\\\\.layers\\\\.(1[6-9]|2[0-3])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n\n# GPU 3: layers 24–31\n- match:\n    name: \"^model\\\\.layers\\\\.(2[4-9]|3[0-1])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n\n# GPU 4: layers 32–39\n- match:\n    name: \"^model\\\\.layers\\\\.(3[2-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE\n    kwargs:\n      generate_device: \"cuda:4\"\n      prefill_device: \"cuda:4\"\n\n# GPU 5: layers 40–47\n- match:\n    name: \"^model\\\\.layers\\\\.(4[0-7])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE\n    kwargs:\n      generate_device: \"cuda:5\"\n      prefill_device: \"cuda:5\"\n\n# GPU 6: layers 48–55\n- match:\n    name: 
\"^model\\\\.layers\\\\.(4[8-9]|5[0-5])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE\n    kwargs:\n      generate_device: \"cuda:6\"\n      prefill_device: \"cuda:6\"\n\n# GPU 7: layers 56–60\n- match:\n    name: \"^model\\\\.layers\\\\.(5[6-9]|60)\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE\n    kwargs:\n      generate_device: \"cuda:7\"\n      prefill_device: \"cuda:7\"\n\n# === MLP Gate Replacement ===\n\n# GPU 0: layers 0–7\n- match:\n    name: \"^model\\\\.layers\\\\.([0-7])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n# GPU 1: layers 8–15\n- match:\n    name: \"^model\\\\.layers\\\\.(8|9|1[0-5])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n# GPU 2: layers 16–23\n- match:\n    name: \"^model\\\\.layers\\\\.(1[6-9]|2[0-3])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n\n# GPU 3: layers 24–31\n- match:\n    name: \"^model\\\\.layers\\\\.(2[4-9]|3[0-1])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n\n# GPU 4: layers 32–39\n- match:\n    name: \"^model\\\\.layers\\\\.(3[2-9])\\\\.mlp\\\\.gate$\"\n    class: 
ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:4\"\n      prefill_device: \"cuda:4\"\n\n# GPU 5: layers 40–47\n- match:\n    name: \"^model\\\\.layers\\\\.(4[0-7])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:5\"\n      prefill_device: \"cuda:5\"\n\n# GPU 6: layers 48–55\n- match:\n    name: \"^model\\\\.layers\\\\.(4[8-9]|5[0-5])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:6\"\n      prefill_device: \"cuda:6\"\n\n# GPU 7: layers 56–60\n- match:\n    name: \"^model\\\\.layers\\\\.(5[6-9]|60)\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:7\"\n      prefill_device: \"cuda:7\"\n\n\n# === MLP Experts Replacement ===\n# replace with marlin expert. 
Uncomment and modify the layer numbers as needed.\n# Each layer of marlin experts takes about 6GB of GPU memory.\n# !!!Remember to disable CUDA graph if you are using marlin experts.!!!\n# !!!Loading marlin experts will take significant time.!!!\n\n# GPU 0: layers 0–7\n# - match:\n#     name: \"^model\\\\.layers\\\\.([0-7])\\\\.mlp\\\\.experts$\" # inject experts in layers 0–7 as marlin experts\n#   replace:\n#     class: ktransformers.operators.experts.KTransformersExperts  \n#     kwargs:\n#       generate_device: \"cuda:0\"\n#       generate_op:  \"KExpertsMarlin\"\n#   recursive: False\n\n# # GPU 1: layers 8–15\n# - match:\n#     name: \"^model\\\\.layers\\\\.([8-9]|1[0-5])\\\\.mlp\\\\.experts$\" # inject experts in layers 8–15 as marlin experts\n#   replace:\n#     class: ktransformers.operators.experts.KTransformersExperts\n#     kwargs:\n#       generate_device: \"cuda:1\"\n#       generate_op:  \"KExpertsMarlin\"\n#   recursive: False \n\n# # GPU 2: layers 16–23\n# - match:\n#     name: \"^model\\\\.layers\\\\.(1[6-9]|2[0-3])\\\\.mlp\\\\.experts$\" # inject experts in layers 16–23 as marlin experts\n#   replace:\n#     class: ktransformers.operators.experts.KTransformersExperts  \n#     kwargs:\n#       generate_device: \"cuda:2\"\n#       generate_op:  \"KExpertsMarlin\"\n#   recursive: False\n\n# # GPU 3: layers 24–31\n# - match:\n#     name: \"^model\\\\.layers\\\\.(2[4-9]|3[0-1])\\\\.mlp\\\\.experts$\" # inject experts in layers 24–31 as marlin experts\n#   replace:\n#     class: ktransformers.operators.experts.KTransformersExperts\n#     kwargs:\n#       generate_device: \"cuda:3\"\n#       generate_op:  \"KExpertsMarlin\"\n#   recursive: False \n\n# # GPU 4: layers 32–39\n# - match:\n#     name: \"^model\\\\.layers\\\\.(3[2-9])\\\\.mlp\\\\.experts$\" # inject experts in layers 32–39 as marlin experts\n#   replace:\n#     class: ktransformers.operators.experts.KTransformersExperts  \n#     kwargs:\n#       generate_device: \"cuda:4\"\n#       generate_op:  \"KExpertsMarlin\"\n#   
recursive: False\n\n# # GPU 5: layers 40–47\n# - match:\n#     name: \"^model\\\\.layers\\\\.(4[0-7])\\\\.mlp\\\\.experts$\" # inject experts in layers 40–47 as marlin experts\n#   replace:\n#     class: ktransformers.operators.experts.KTransformersExperts\n#     kwargs:\n#       generate_device: \"cuda:5\"\n#       generate_op:  \"KExpertsMarlin\"\n#   recursive: False \n\n# # GPU 6: layers 48–55\n# - match:\n#     name: \"^model\\\\.layers\\\\.(4[8-9]|5[0-5])\\\\.mlp\\\\.experts$\" # inject experts in layers 48–55 as marlin experts\n#   replace:\n#     class: ktransformers.operators.experts.KTransformersExperts  \n#     kwargs:\n#       generate_device: \"cuda:6\"\n#       generate_op:  \"KExpertsMarlin\"\n#   recursive: False\n\n# # GPU 7: layers 56–60\n# - match:\n#     name: \"^model\\\\.layers\\\\.(5[6-9]|60)\\\\.mlp\\\\.experts$\" # inject experts in layers 56–60 as marlin experts\n#   replace:\n#     class: ktransformers.operators.experts.KTransformersExperts\n#     kwargs:\n#       generate_device: \"cuda:7\"\n#       generate_op:  \"KExpertsMarlin\"\n#   recursive: False \n\n\n# === MLP Experts Replacement ===\n\n# GPU 0: layers 0–7\n- match:\n    name: \"^model\\\\.layers\\\\.([0-7])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      prefill_device: \"cuda:0\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda:0\"\n  recursive: False\n\n# GPU 1: layers 8–15\n- match:\n    name: \"^model\\\\.layers\\\\.(8|9|1[0-5])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      prefill_device: \"cuda:1\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda:1\"\n  recursive: False\n\n# GPU 2: layers 16–23\n- match:\n    name: 
\"^model\\\\.layers\\\\.(1[6-9]|2[0-3])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      prefill_device: \"cuda:2\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda:2\"\n  recursive: False\n\n# GPU 3: layers 24–31\n- match:\n    name: \"^model\\\\.layers\\\\.(2[4-9]|3[0-1])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      prefill_device: \"cuda:3\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda:3\"\n  recursive: False\n\n# GPU 4: layers 32–39\n- match:\n    name: \"^model\\\\.layers\\\\.(3[2-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      prefill_device: \"cuda:4\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda:4\"\n  recursive: False\n\n# GPU 5: layers 40–47\n- match:\n    name: \"^model\\\\.layers\\\\.(4[0-7])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      prefill_device: \"cuda:5\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda:5\"\n  recursive: False\n\n# GPU 6: layers 48–55\n- match:\n    name: \"^model\\\\.layers\\\\.(4[8-9]|5[0-5])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      prefill_device: \"cuda:6\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda:6\"\n  recursive: False\n\n# GPU 7: layers 56–60\n- match:\n    name: \"^model\\\\.layers\\\\.(5[6-9]|60)\\\\.mlp\\\\.experts$\"\n  replace:\n  
  class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      prefill_device: \"cuda:7\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda:7\"\n  recursive: False\n\n\n# === Self-Attention Replacement ===\n\n# GPU 0: layers 0–7\n- match:\n    name: \"^model\\\\.layers\\\\.([0-7])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n# GPU 1: layers 8–15\n- match:\n    name: \"^model\\\\.layers\\\\.(8|9|1[0-5])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n# GPU 2: layers 16–23\n- match:\n    name: \"^model\\\\.layers\\\\.(1[6-9]|2[0-3])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n\n# GPU 3: layers 24–31\n- match:\n    name: \"^model\\\\.layers\\\\.(2[4-9]|3[0-1])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n\n# GPU 4: layers 32–39\n- match:\n    name: \"^model\\\\.layers\\\\.(3[2-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention\n    kwargs:\n      generate_device: \"cuda:4\"\n      prefill_device: \"cuda:4\"\n\n# GPU 5: layers 40–47\n- match:\n    name: \"^model\\\\.layers\\\\.(4[0-7])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention\n    kwargs:\n      generate_device: \"cuda:5\"\n      prefill_device: \"cuda:5\"\n\n# GPU 6: layers 48–55\n- match:\n    name: \"^model\\\\.layers\\\\.(4[8-9]|5[0-5])\\\\.self_attn$\"\n  replace:\n    class: 
ktransformers.operators.attention.KDeepseekV2Attention\n    kwargs:\n      generate_device: \"cuda:6\"\n      prefill_device: \"cuda:6\"\n\n# GPU 7: layers 56–60\n- match:\n    name: \"^model\\\\.layers\\\\.(5[6-9]|60)\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention\n    kwargs:\n      generate_device: \"cuda:7\"\n      prefill_device: \"cuda:7\"\n\n# === Overall Model Replacement with Transfer Map ===\n\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 means close layer‐wise prefill\n      transfer_map:\n        8: \"cuda:1\"\n        16: \"cuda:2\"\n        24: \"cuda:3\"\n        32: \"cuda:4\"\n        40: \"cuda:5\"\n        48: \"cuda:6\"\n        56: \"cuda:7\"\n\n# === Default Catch-All for Other Modules ===\n\n# GPU 0: layers 0–7\n- match:\n    name: \"^model\\\\.layers\\\\.([0-7])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n# GPU 1: layers 8–15\n- match:\n    name: \"^model\\\\.layers\\\\.(8|9|1[0-5])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n# GPU 2: layers 16–23\n- match:\n    name: \"^model\\\\.layers\\\\.(1[6-9]|2[0-3])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n\n# GPU 3: layers 24–31\n- match:\n    name: \"^model\\\\.layers\\\\.(2[4-9]|3[0-1])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n\n# GPU 4: layers 32–39\n- match:\n    name: \"^model\\\\.layers\\\\.(3[2-9])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:4\"\n      prefill_device: \"cuda:4\"\n\n# GPU 5: layers 40–47\n- match:\n    name: 
\"^model\\\\.layers\\\\.(4[0-7])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:5\"\n      prefill_device: \"cuda:5\"\n\n# GPU 6: layers 48–55\n- match:\n    name: \"^model\\\\.layers\\\\.(4[8-9]|5[0-5])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:6\"\n      prefill_device: \"cuda:6\"\n\n# GPU 7: layers 56–63\n- match:\n    name: \"^model\\\\.layers\\\\.(5[6-9]|60)\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:7\"\n      prefill_device: \"cuda:7\"\n\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:7\"\n      prefill_device: \"cuda:7\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# For final modules (model.norm), ensure they are on GPU 7 (as in your original config)\n- match:\n    name: \"(^model\\\\.layers\\\\.(4[5-9]|5[0-9]|60)\\\\.)|(^model\\\\.norm)\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:7\"\n      prefill_device: \"cuda:7\""
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-fp8-linear-ggml-experts.yaml",
    "content": "- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n        generate_device: \"cpu\"\n        prefill_device: \"cpu\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n      generate_op: \"KLinearFP8\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearFP8\"\n      prefill_op: \"KLinearTorch\"\n  \n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module 
with custom forward function\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert parallelism\n    kwargs:\n      prefill_device: \"cuda:0\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:0\"\n  recursive: False # don't recursively inject submodules of this module\n\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert parallelism\n    kwargs:\n      prefill_device: \"cuda:1\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:1\"\n  recursive: False # don't recursively inject submodules of this 
module\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n      absorb_for_prefill: False # set this to True to enable long context (prefill may be slower).\n\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      absorb_for_prefill: False # set this to True to enable long context (prefill may be slower).\n\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n      transfer_map: \n        30: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n\n- match:\n    name: \"(^model\\\\.layers\\\\.([3456][0-9])\\\\.)|(model.norm)\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n"
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml",
    "content": "- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n        generate_device: \"cpu\"\n        prefill_device: \"cpu\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n  \n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp 
module with custom forward function\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-4])\\\\.mlp\\\\.experts$\" # inject experts in layer 0~4 as marlin expert\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts  \n    kwargs:\n      generate_device: \"cuda:0\" # run in cuda:0\n      generate_op:  \"KExpertsMarlin\"\n  recursive: False\n\n- match:\n    name: \"^model\\\\.layers\\\\.([3][0])\\\\.mlp\\\\.experts$\" # inject experts in layer 30~31 as marlin expert\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      generate_device: \"cuda:1\"\n      generate_op:  \"KExpertsMarlin\"\n  recursive: False \n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda:0\"\n      
prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:0\"\n  recursive: False # don't recursively inject submodules of this module\n\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert parallelism\n    kwargs:\n      prefill_device: \"cuda:1\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:1\"\n  recursive: False # don't recursively inject submodules of this module\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n      transfer_map: \n        30: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: 
\"(^model\\\\.layers\\\\.([3456][0-9])\\\\.)|(model.norm)\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n"
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml",
    "content": "- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n        generate_device: \"cpu\"\n        prefill_device: \"cpu\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n  \n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp 
module with custom forward function\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda:0\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:0\"\n  recursive: False # don't recursively inject submodules of this module\n\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda:1\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:1\"\n  recursive: False # don't recursively inject submodules 
of this module\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n      transfer_map: \n        30: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"(^model\\\\.layers\\\\.([3456][0-9])\\\\.)|(model.norm)\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n"
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-npu.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"npu\"\n      prefill_device: \"npu\"\n\n- match:\n    name: \"^lm_head$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"npu\"\n      prefill_device: \"npu\"\n      generate_op: \"KLinearTorch\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"npu\"\n      prefill_device: \"npu\"\n      generate_op: \"KLinearTorch\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"npu\"\n      prefill_device: \"npu\"\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"npu:0\"\n      prefill_device: \"npu:0\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"npu\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: 
\"KExpertsCPU\"\n      out_device: \"npu\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"npu\"\n      prefill_device: \"npu\"\n      absorb_for_prefill: False # change this to True to enable long context(prefill may slower).\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\""
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-serve.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^lm_head$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"VLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"VLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoEV2     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExpertsV2     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n     
 generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.balance_serve_attention.flashinfer_attn # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      absorb_for_prefill: False # change this to True to enable long context(prefill may slower).\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm\n  replace:\n    class: ktransformers.operators.layernorm.RMSNorm\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP\n  replace:\n    class:  ktransformers.operators.mlp.kDeepseekV3MLP\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\""
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^lm_head$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      
generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      absorb_for_prefill: False # change this to True to enable long context (prefill may be slower).\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\""
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/Glm4Moe-serve.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_glm4_moe.Glm4MoeRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.KGlm4MoeRotaryEmbedding\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^lm_head$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"VLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# - match:\n#     name: \"^model\\\\.layers\\\\..*$\"  # regular expression \n#     class: torch.nn.Linear  # only match modules matching name and class simultaneously\n#   replace:\n#     class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n#     kwargs:\n#       generate_device: \"cuda\"\n#       prefill_device: \"cuda\"\n#       generate_op: \"VLinearMarlin\"\n#       prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*mlp\\\\.shared_expert_gate).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_glm4_moe.Glm4MoeMoE\n  replace:\n    class: ktransformers.operators.experts.KGlm4MoeMoE\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: 
ktransformers.operators.experts.KGlm4Experts     # custom MoE Kernel with expert parallelism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: None\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.balance_serve_attention.KGlm4MoeAttention # optimized attention implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n- match:\n    class: ktransformers.models.modeling_glm4_moe.Glm4MoeRMSNorm\n  replace:\n    class: ktransformers.operators.layernorm.KGlm4MoeRMSNorm\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    class: ktransformers.models.modeling_glm4_moe.Glm4MoeMLP\n  replace:\n    class:  ktransformers.operators.mlp.KGlm4MoeMLP\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\""
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/Internlm2_5-7b-Chat-1m.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_llama.LlamaRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.RotaryEmbeddingV2\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n        generate_device: \"cpu\"\n        prefill_device: \"cpu\"\n- match:\n    class: ktransformers.models.modeling_llama.LlamaModel\n  replace:\n    class: ktransformers.operators.models.KLlamaModel\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill\n\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KLlamaAttention\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n"
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/Mixtral.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_mixtral.MixtralRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.RotaryEmbedding\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model\\\\.layers\\\\..*$\"\n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.block_sparse_moe$\"\n    class: ktransformers.models.modeling_mixtral.MixtralSparseMoeBlock\n  replace: \n    class: ktransformers.operators.experts.KMistralSparseMoEBlock\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.block_sparse_moe\\\\.experts$\"\n  replace: \n    class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n        generate_device: \"cpu\"\n        prefill_device: \"cpu\"\n\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\""
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B-serve.yaml",
    "content": "\n\n- match:\n    name: \"^lm_head$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"VLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"VLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoEV2     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExpertsV2     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: 
ktransformers.operators.balance_serve_attention.flashinfer_attn # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      absorb_for_prefill: False # change this to True to enable long context (prefill may be slower).\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm\n  replace:\n    class: ktransformers.operators.layernorm.RMSNorm\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP\n  replace:\n    class:  ktransformers.operators.mlp.kDeepseekV3MLP\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.RotaryEmbeddingV4\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\""
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.RotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^lm_head$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      
generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n# To use more VRAM, use the Marlin experts kernel and disable CUDA Graph (disabling CUDA Graph may reduce performance)\n#- match:\n#    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n#  replace:\n#    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert parallelism\n#    kwargs:\n#      prefill_device: \"cuda\"\n#      prefill_op: \"KExpertsTorch\"\n#      generate_device: \"cuda\"\n#      generate_op: \"KExpertsMarlin\"\n#  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\""
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml",
    "content": "- match:\n    name: \"^model\\\\.layers\\\\.([012])\\\\.\"\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.RotaryEmbedding\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([012])$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\.([012])\\\\.mlp$\"\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock\n  replace:\n    class: ktransformers.operators.experts.KQwen2MoeSparseMoeBlock     # mlp module with custom forward function\n- match:\n    name: \"^model\\\\.layers\\\\.([012])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    # device: \"cpu\"   # which devices to load this module when initializing\n    kwargs:\n      prefill_device: \"cuda:0\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:0\"\n  recursive: False # don't recursively inject submodules of this module\n\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9]|[3-9])\\\\.\"\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.RotaryEmbedding\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9]|[3-9])$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name 
and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9]|[3-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock\n  replace:\n    class: ktransformers.operators.experts.KQwen2MoeSparseMoeBlock     # mlp module with custom forward function\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9]|[3-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    # device: \"cpu\"   # which devices to load this module when initializing\n    kwargs:\n      prefill_device: \"cuda:1\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:1\"\n  recursive: False # don't recursively inject submodules of this module\n\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n        generate_device: \"cpu\"\n        prefill_device: \"cpu\"\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"(^model.norm)\"\n  replace:\n    class: \"default\"\n    kwargs:\n        generate_device: \"cuda:1\"\n        prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KQwen2MoeModel\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill\n      transfer_map: \n        3: \"cuda:1\"\n\n- match:\n    
name: \"^model\\\\.layers\\\\.([012])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9]|[3-9])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n"
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.RotaryEmbedding\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model\\\\.layers\\\\..*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock\n  replace:\n    class: ktransformers.operators.experts.KQwen2MoeSparseMoeBlock     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    # device: \"cpu\"   # which devices to load this module when initializing\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KQwen2MoeModel\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 
# 0 disables layer-wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n        generate_device: \"cpu\"\n        prefill_device: \"cpu\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n"
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/Qwen2-serve-amx.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.RotaryEmbedding\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^lm_head$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# - match:\n#     name: \"^model\\\\.layers\\\\..*$\"  # regular expression \n#     class: torch.nn.Linear  # only match modules matching name and class simultaneously\n#   replace:\n#     class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n#     kwargs:\n#       generate_device: \"cuda\"\n#       prefill_device: \"cuda\"\n#       generate_op: \"VLinearMarlin\"\n#       prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*mlp\\\\.shared_expert_gate).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"VLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock\n  replace:\n    class: ktransformers.operators.experts.KQwen2MoeSparseMoeBlockV2     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: 
\"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExpertsV2     # custom MoE Kernel with expert parallelism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n      backend: \"AMXInt8\" # or \"AMXBF16\" or \"llamafile\" (default)\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.balance_serve_attention.KQwen2MoeAttention # optimized attention implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KQwen2MoeModel\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n- match:\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRMSNorm\n  replace:\n    class: ktransformers.operators.layernorm.KQwen2MoeRMSNorm\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeMLP\n  replace:\n    class:  ktransformers.operators.mlp.KQwen2MoeMLP\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\""
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/Qwen2-serve.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.RotaryEmbedding\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^lm_head$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# - match:\n#     name: \"^model\\\\.layers\\\\..*$\"  # regular expression \n#     class: torch.nn.Linear  # only match modules matching name and class simultaneously\n#   replace:\n#     class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n#     kwargs:\n#       generate_device: \"cuda\"\n#       prefill_device: \"cuda\"\n#       generate_op: \"VLinearMarlin\"\n#       prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*mlp\\\\.shared_expert_gate).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"VLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock\n  replace:\n    class: ktransformers.operators.experts.KQwen2MoeSparseMoeBlockV2     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: 
\"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExpertsV2     # custom MoE Kernel with expert parallelism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.balance_serve_attention.KQwen2MoeAttention # optimized attention implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KQwen2MoeModel\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n- match:\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRMSNorm\n  replace:\n    class: ktransformers.operators.layernorm.KQwen2MoeRMSNorm\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeMLP\n  replace:\n    class:  ktransformers.operators.mlp.KQwen2MoeMLP\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\""
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/Qwen3Moe-serve-amx.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.RotaryEmbedding\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^lm_head$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"VLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# - match:\n#     name: \"^model\\\\.layers\\\\..*$\"  # regular expression \n#     class: torch.nn.Linear  # only match modules matching name and class simultaneously\n#   replace:\n#     class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n#     kwargs:\n#       generate_device: \"cuda\"\n#       prefill_device: \"cuda\"\n#       generate_op: \"VLinearMarlin\"\n#       prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*mlp\\\\.shared_expert_gate).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock\n  replace:\n    class: ktransformers.operators.experts.KQwen3MoeSparseMoeBlockV2     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: 
\"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExpertsV2     # custom MoE Kernel with expert parallelism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n      backend: \"AMXBF16\" # or \"AMXInt8\" or \"llamafile\" (default)\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.balance_serve_attention.KQwen3MoeAttention # optimized attention implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KQwen2MoeModel\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n- match:\n    class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeRMSNorm\n  replace:\n    class: ktransformers.operators.layernorm.KQwen3MoeRMSNorm\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeMLP\n  replace:\n    class:  ktransformers.operators.mlp.KQwen2MoeMLP\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\""
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/Qwen3Moe-serve.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.RotaryEmbedding\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^lm_head$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"VLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# - match:\n#     name: \"^model\\\\.layers\\\\..*$\"  # regular expression \n#     class: torch.nn.Linear  # only match modules matching name and class simultaneously\n#   replace:\n#     class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n#     kwargs:\n#       generate_device: \"cuda\"\n#       prefill_device: \"cuda\"\n#       generate_op: \"VLinearMarlin\"\n#       prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*mlp\\\\.shared_expert_gate).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock\n  replace:\n    class: ktransformers.operators.experts.KQwen3MoeSparseMoeBlockV2     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: 
\"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExpertsV2     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.balance_serve_attention.KQwen3MoeAttention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KQwen2MoeModel\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n- match:\n    class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeRMSNorm\n  replace:\n    class: ktransformers.operators.layernorm.KQwen3MoeRMSNorm\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeMLP\n  replace:\n    class:  ktransformers.operators.mlp.KQwen2MoeMLP\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\""
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/Qwen3Next-serve.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_qwen3_next.Qwen3NextRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.KQwen3MoeRotaryEmbedding\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^lm_head$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"VLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*mlp\\\\.shared_expert_gate).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_qwen3_next.Qwen3NextSparseMoeBlock\n  replace:\n    class: ktransformers.operators.experts.KQwen3NextSparseMoeBlockV2     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExpertsV2     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    class: 
ktransformers.models.modeling_qwen3_next.Qwen3NextGatedDeltaNet\n  replace:\n    class: ktransformers.operators.balance_serve_attention.KQwen3NextGatedDeltaNet # optimized gated DeltaNet implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    class: ktransformers.models.modeling_qwen3_next.Qwen3NextAttention\n  replace:\n    class: ktransformers.operators.balance_serve_attention.KQwen3NextAttention # optimized attention implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n- match:\n    class: ktransformers.models.modeling_qwen3_next.Qwen3NextRMSNorm\n  replace:\n    class: ktransformers.operators.layernorm.KQwen3NextRMSNorm\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    class: ktransformers.models.modeling_qwen3_next.Qwen3NextMLP\n  replace:\n    class:  ktransformers.operators.mlp.KQwen2MoeMLP\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\""
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/Smallthinker-serve.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_smallthinker.SmallthinkerRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.KSmallthinkerRotaryEmbedding\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^lm_head$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"VLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# - match:\n#     name: \"^model\\\\.layers\\\\..*$\"  # regular expression \n#     class: torch.nn.Linear  # only match modules matching name and class simultaneously\n#   replace:\n#     class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n#     kwargs:\n#       generate_device: \"cuda\"\n#       prefill_device: \"cuda\"\n#       generate_op: \"VLinearMarlin\"\n#       prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*feed_forward\\\\.shared_expert_gate).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.block_sparse_moe$\"\n    class: ktransformers.models.modeling_smallthinker.SmallthinkerMoeBlock\n  replace:\n    class: ktransformers.operators.experts.KSmallthinkerMoeBlock\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: 
\"^model\\\\.layers\\\\..*\\\\.block_sparse_moe\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KSmallthinkerExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: None\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.balance_serve_attention.KSmallthinkerAttention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n- match:\n    class: ktransformers.models.modeling_smallthinker.SmallthinkerRMSNorm\n  replace:\n    class: ktransformers.operators.layernorm.KSmallthinkerRMSNorm\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    class: ktransformers.models.modeling_smallthinker.SmallthinkerDenseMlpBlock\n  replace:\n    class:  ktransformers.operators.mlp.KSmallthinkerDenseMlpBlock\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\""
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/npu/DeepSeek-V3-Chat-300IA2-npu-serve.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"npu\"\n      prefill_device: \"npu\"\n\n- match:\n    name: \"^lm_head$\"  # regular expression\n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.ascend.ascend_linear.KTransformersLinearW8A8A2  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"npu\"\n      prefill_device: \"npu\"\n      generate_op: \"KLinearTorchW8A8A2\"\n      prefill_op: \"KLinearTorchW8A8A2\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression\n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.ascend.ascend_linear.KTransformersLinearW8A8A2  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"npu\"\n      prefill_device: \"npu\"\n      generate_op: \"KLinearTorchW8A8A2\"\n      prefill_op: \"KLinearTorchW8A8A2\"\n\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.ascend.ascend_experts.KDeepseekV3MoEW8A8     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"npu\"\n      prefill_device: \"npu\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.([0-2])\\\\.mlp$\"\n    class: \"ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP\"\n  replace:\n    class: \"ktransformers.operators.ascend.ascend_mlp.KDeepseekV3MLPW8A8A2V1\"\n    kwargs:\n      generate_device: \"npu\"\n      prefill_device: \"npu\"\n\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.shared_experts$\"\n    class: \"ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP\"\n  
replace:\n    class: \"ktransformers.operators.ascend.ascend_mlp.KDeepseekV3MLPW8A8A2V2\"\n    kwargs:\n      generate_device: \"npu\"\n      prefill_device: \"npu\"\n\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.ascend.ascend_gate.KDeepseekV3GateA2\n    kwargs:\n      generate_device: \"npu:0\"\n      prefill_device: \"npu:0\"\n\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.ascend.ascend_experts.KTransformersExpertsW8A8\n    kwargs:\n      prefill_device: \"npu\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPUW8A8\"\n      out_device: \"npu\"\n  recursive: False # don't recursively inject submodules of this module\n\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n    class: ktransformers.operators.experts.KExpertsCPU\n  replace:\n    class: ktransformers.operators.ascend.ascend_experts.KExpertsCPUW8A8\n\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.ascend.ascend_attention.KDeepseekV2AttentionW8A8A2Serve # optimized MLA implementation\n    kwargs:\n      generate_device: \"npu\"\n      prefill_device: \"npu\"\n      absorb_for_prefill: False # set to True to enable long context (prefill may be slower).\n\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n- match:\n    name: \"^model..*norm\"\n  replace:\n    class: ktransformers.operators.ascend.ascend_layernorm.KDeepseekV3RMSNormW8A8\n    kwargs:\n      generate_device: \"npu\"\n      prefill_device: \"npu\""
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/npu/DeepSeek-V3-Chat-300IA2-npu.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"npu\"\n      prefill_device: \"npu\"\n\n- match:\n    name: \"^lm_head$\"  # regular expression\n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.ascend.ascend_linear.KTransformersLinearW8A8A2  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"npu\"\n      prefill_device: \"npu\"\n      generate_op: \"KLinearTorchW8A8A2\"\n      prefill_op: \"KLinearTorchW8A8A2\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression\n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.ascend.ascend_linear.KTransformersLinearW8A8A2  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"npu\"\n      prefill_device: \"npu\"\n      generate_op: \"KLinearTorchW8A8A2\"\n      prefill_op: \"KLinearTorchW8A8A2\"\n\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.ascend.ascend_experts.KDeepseekV3MoEW8A8     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"npu\"\n      prefill_device: \"npu\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.([0-2])\\\\.mlp$\"\n    class: \"ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP\"\n  replace:\n    class: \"ktransformers.operators.ascend.ascend_mlp.KDeepseekV3MLPW8A8A2V1\"\n    kwargs:\n      generate_device: \"npu\"\n      prefill_device: \"npu\"\n\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.shared_experts$\"\n    class: \"ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP\"\n  
replace:\n    class: \"ktransformers.operators.ascend.ascend_mlp.KDeepseekV3MLPW8A8A2V2\"\n    kwargs:\n      generate_device: \"npu\"\n      prefill_device: \"npu\"\n\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.ascend.ascend_gate.KDeepseekV3GateA2\n    kwargs:\n      generate_device: \"npu:0\"\n      prefill_device: \"npu:0\"\n\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.ascend.ascend_experts.KTransformersExpertsW8A8\n    kwargs:\n      prefill_device: \"npu\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPUW8A8\"\n      out_device: \"npu\"\n  recursive: False # don't recursively inject submodules of this module\n\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n    class: ktransformers.operators.experts.KExpertsCPU\n  replace:\n    class: ktransformers.operators.ascend.ascend_experts.KExpertsCPUW8A8\n\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.ascend.ascend_attention.KDeepseekV2AttentionW8A8A2 # optimized MLA implementation\n    kwargs:\n      generate_device: \"npu\"\n      prefill_device: \"npu\"\n      absorb_for_prefill: False # set to True to enable long context (prefill may be slower).\n\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n- match:\n    name: \"^model..*norm\"\n  replace:\n    class: ktransformers.operators.ascend.ascend_layernorm.KDeepseekV3RMSNormW8A8\n    kwargs:\n      generate_device: \"npu\"\n      prefill_device: \"npu\""
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/npu/Qwen3-Chat-300IA2-npu-serve.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.RotaryEmbedding\n    kwargs:\n      generate_device: \"npu\"\n      prefill_device: \"npu\"\n\n- match:\n    name: \"^lm_head$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.ascend.ascend_linear.KTransformersLinearW8A8A2\n    kwargs:\n      generate_device: \"npu\"\n      prefill_device: \"npu\"\n      generate_op: \"KLinearTorchW8A8A2\"\n      prefill_op: \"KLinearTorchW8A8A2\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*mlp\\\\.shared_expert_gate)(?!.*mlp\\\\.gate)(?!.*mlp\\\\.experts).*$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.ascend.ascend_linear.KTransformersLinearW8A8A2\n    kwargs:\n      generate_device: \"npu\"\n      prefill_device: \"npu\"\n      generate_op: \"KLinearTorchW8A8A2\"\n      prefill_op: \"KLinearTorchW8A8A2\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*mlp\\\\.gate)(?!.*self_attn\\\\.kv_b_proj)(?!.*mlp\\\\.experts).*$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.ascend.ascend_linear.KTransformersLinearW8A8A2\n    kwargs:\n      generate_device: \"npu\"\n      prefill_device: \"npu\"\n      generate_op: \"KLinearTorchW8A8A2\"\n      prefill_op: \"KLinearTorchW8A8A2\"\n\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock\n  replace:\n    class: ktransformers.operators.ascend.ascend_experts.KQwen3MoeSparseMoeBlockW8A8\n    kwargs:\n      generate_device: \"npu\"\n      prefill_device: \"npu\"\n      dump_enable: False\n      dump_dir: \"/mnt/dump_from_mindie/dump_from_kt_moe\"\n\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.ascend.ascend_attention.KQwen3MoeAttentionW8A8A2Serve\n    kwargs:\n      generate_device: \"npu\"\n      
prefill_device: \"npu\"\n      absorb_for_prefill: False\n      dump_enable: False\n      dump_dir: \"/mnt/dump_from_mindie/dump_from_kt_attn\"\n\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KQwen2MoeModel\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0\n\n\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n- match:\n    class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeRMSNorm\n  replace:\n    class: ktransformers.operators.ascend.ascend_layernorm.KQwen3MoeRMSNormW8A8\n    kwargs:\n      generate_device: \"npu\"\n      prefill_device: \"npu\"\n      dump_enable: False\n      dump_dir: \"/mnt/dump_from_mindie/dump_from_kt_rms\"\n\n"
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/rocm/DeepSeek-V3-Chat.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^lm_head$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearCPUInfer\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearQ8\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      
generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      absorb_for_prefill: False # set to True to enable long context (prefill may be slower).\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\""
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/xpu/DeepSeek-V2-Chat.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\"  # regular expression\n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n      generate_op: \"KLinearIPEXLLM\"\n      prefill_op: \"KLinearIPEXLLM\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"xpu\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"xpu\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    class: ktransformers.models.modeling_deepseek.DeepseekV2RMSNorm\n  replace:\n    class: ktransformers.operators.layernorm.KDeepseekRMSNormIPEXLLM\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: 
\"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill\n      device: \"xpu\"\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\""
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/xpu/DeepSeek-V3-Chat.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n- match:\n    name: \"^lm_head$\"  # regular expression\n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n      generate_op: \"KLinearIPEXLLM\"\n      prefill_op: \"KLinearIPEXLLM\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\"  # regular expression\n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n      generate_op: \"KLinearIPEXLLM\"\n      prefill_op: \"KLinearIPEXLLM\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm\n  replace:\n    class: ktransformers.operators.layernorm.KDeepseekRMSNormIPEXLLM\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGateIPEXLLM\n    kwargs:\n      generate_device: \"xpu:0\"\n      prefill_device: \"xpu:0\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: 
ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert parallelism\n    kwargs:\n      prefill_device: \"xpu\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"xpu\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n      absorb_for_prefill: False # set to True to enable long context (prefill may be slower).\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\""
  },
  {
    "path": "archive/ktransformers/optimize/optimize_rules/xpu/Qwen3Moe-Chat.yaml",
    "content": "- match:\n    name: \"rotary_emb$\"\n  replace:\n    class: ktransformers.operators.RoPE.KQwen3MoeRotaryEmbedding\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n- match:\n    name: \"^lm_head$\"  # regular expression\n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n      generate_op: \"KLinearIPEXLLM\"\n      prefill_op: \"KLinearIPEXLLM\"\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*mlp\\\\.gate).*$\"  # regular expression\n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n      generate_op: \"KLinearIPEXLLM\"\n      prefill_op: \"KLinearIPEXLLM\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock\n  replace:\n    class: ktransformers.operators.experts.KQwen3MoeSparseMoeBlockV2     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExpertsV2     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"xpu\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"xpu\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: 
ktransformers.operators.attention.KQwen3MoeAttentionIPEXLLM\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KQwen2MoeModel\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n- match:\n    class: transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeRMSNorm\n  replace:\n    class: ktransformers.operators.layernorm.KDeepseekRMSNormIPEXLLM\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n- match:\n    class: transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeMLP\n  replace:\n    class:  ktransformers.operators.mlp.KQwen2MoeMLP\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n"
  },
  {
    "path": "archive/ktransformers/server/__init__.py",
    "content": ""
  },
  {
    "path": "archive/ktransformers/server/api/__init__.py",
    "content": "from fastapi import APIRouter\n\nfrom .ollama import router as ollama_router\nfrom .openai import router as openai_router,post_db_creation_operations\nfrom .web import router as web_router\n\nrouter = APIRouter()\nrouter.include_router(ollama_router)\nrouter.include_router(openai_router)\nrouter.include_router(web_router)\n"
  },
  {
    "path": "archive/ktransformers/server/api/ollama/__init__.py",
    "content": "from fastapi import APIRouter\n\nfrom .completions import router as completions_router\n\nrouter = APIRouter()\nrouter.include_router(completions_router)\n"
  },
  {
    "path": "archive/ktransformers/server/api/ollama/completions.py",
    "content": "from datetime import datetime\nfrom http.client import NOT_IMPLEMENTED\nimport json\nfrom time import time\nfrom uuid import uuid4\nfrom typing import List, Optional\n\nfrom fastapi import APIRouter, Request\nfrom pydantic import BaseModel, Field\n\nfrom ktransformers.server.config.config import Config\nfrom ktransformers.server.utils.create_interface import get_interface\nfrom ktransformers.server.schemas.assistants.streaming import check_link_response\nfrom ktransformers.server.backend.base import BackendInterfaceBase\n\nfrom ktransformers.server.schemas.endpoints.chat import RawUsage\n\nrouter = APIRouter(prefix='/api')\n\n# https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion\nclass OllamaGenerateCompletionRequest(BaseModel):\n    model: str = Field(..., description=\"The model name, which is required.\")\n    prompt: Optional[str] = Field(\n        None, description=\"The prompt to generate a response for.\")\n    images: Optional[List[str]] = Field(\n        None, description=\"A list of base64-encoded images for multimodal models such as llava.\")\n    # Advanced parameters\n    format: Optional[str] = Field(\n        None, description=\"The format to return a response in, accepted value is json.\")\n    options: Optional[dict] = Field(\n        None, description=\"Additional model parameters as listed in the documentation.\")\n    system: Optional[str] = Field(\n        None, description=\"System message to override what is defined in the Modelfile.\")\n    template: Optional[str] = Field(\n        None, description=\"The prompt template to use, overriding what is defined in the Modelfile.\")\n    context: Optional[str] = Field(\n        None, description=\"The context parameter from a previous request to keep a short conversational memory.\")\n    stream: Optional[bool] = Field(\n        None, description=\"If false, the response will be returned as a single response object.\")\n    raw: Optional[bool] = Field(\n     
   None, description=\"If true, no formatting will be applied to the prompt.\")\n    keep_alive: Optional[str] = Field(\n        \"5m\", description=\"Controls how long the model will stay loaded into memory following the request.\")\n\nclass OllamaGenerationStreamResponse(BaseModel):\n    model: str\n    created_at: str\n    response: str\n    done: bool = Field(...)\n\nclass OllamaGenerationResponse(BaseModel):\n    model: str\n    created_at: str\n    response: str\n    done: bool\n\n@router.post(\"/generate\", tags=['ollama'])\nasync def generate(request: Request, input: OllamaGenerateCompletionRequest):\n    id = str(uuid4())\n    interface: BackendInterfaceBase = get_interface()\n    print(f'COMPLETION INPUT:----\\n{input.prompt}\\n----')\n    config = Config()\n\n    if input.stream:\n        async def inner():\n            async for res in interface.inference(input.prompt, id):\n                if isinstance(res, RawUsage):\n                    raw_usage = res\n                else: \n                    token, finish_reason = res\n                    d = OllamaGenerationStreamResponse(\n                        model=config.model_name,\n                        created_at=str(datetime.now()),\n                        response=token,\n                        done=False\n                    )\n                    yield d.model_dump_json() + '\\n'\n            d = OllamaGenerationStreamResponse(\n                model=config.model_name,\n                created_at=str(datetime.now()),\n                response='',\n                done=True\n            )\n            yield d.model_dump_json() + '\\n'\n        return check_link_response(request, inner())\n    else:\n        complete_response = \"\"\n        async for res in interface.inference(input.prompt, id):\n            if isinstance(res, RawUsage):\n                raw_usage = res\n            else: \n                token, finish_reason = res\n                complete_response += token\n        response 
= OllamaGenerationResponse(\n            model=config.model_name,\n            created_at=str(datetime.now()),\n            response=complete_response,\n            done=True\n        )\n        return response\n    \n# https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion\nclass OllamaChatCompletionMessage(BaseModel):\n    role: str\n    content: str\n\nclass OllamaChatCompletionRequest(BaseModel):\n    model: str = Field(..., description=\"The model name, which is required.\")\n    messages: List[OllamaChatCompletionMessage] = Field(\n        ..., description=\"A list of messages to generate a response for.\")\n    stream: bool = Field(True, description=\"If true, the response will be streamed.\")\n\nclass OllamaChatCompletionStreamResponse(BaseModel):\n    model: str\n    created_at: str\n    message: dict\n    done: bool = Field(...)\n    done_reason: Optional[str] = Field(\"\", description=\"done_reason\")\n    total_duration: Optional[int] = Field(None, description=\"Total time spent in nanoseconds\")\n    load_duration: Optional[int] = Field(None, description=\"Time spent loading model in nanoseconds\")\n    prompt_eval_count: Optional[int] = Field(None, description=\"Number of tokens in prompt\")\n    prompt_eval_duration: Optional[int] = Field(None, description=\"Time spent evaluating prompt in nanoseconds\")\n    eval_count: Optional[int] = Field(None, description=\"Number of tokens generated\")\n    eval_duration: Optional[int] = Field(None, description=\"Time spent generating response in nanoseconds\")\n\nclass OllamaChatCompletionResponse(BaseModel):\n    model: str\n    created_at: str\n    message: dict\n    done: bool\n    done_reason: Optional[str] = Field(\"\", description=\"done_reason\")\n    total_duration: Optional[int] = Field(None, description=\"Total time spent in nanoseconds\")\n    load_duration: Optional[int] = Field(None, description=\"Time spent loading model in nanoseconds\")\n    prompt_eval_count: 
Optional[int] = Field(None, description=\"Number of tokens in prompt\")\n    prompt_eval_duration: Optional[int] = Field(None, description=\"Time spent evaluating prompt in nanoseconds\")\n    eval_count: Optional[int] = Field(None, description=\"Number of tokens generated\")\n    eval_duration: Optional[int] = Field(None, description=\"Time spent generating response in nanoseconds\")\n\n@router.post(\"/chat\", tags=['ollama'])\nasync def chat(request: Request, input: OllamaChatCompletionRequest):\n    id = str(uuid4())\n    interface: BackendInterfaceBase = get_interface()\n    config = Config()\n\n    input_message = [json.loads(m.model_dump_json()) for m in input.messages]\n\n    if input.stream:\n        async def inner():\n            start_time = time()  # Record the start time (seconds)\n            tokens = []\n\n            async for res in interface.inference(input_message, id):\n                if isinstance(res, RawUsage):\n                    raw_usage = res\n                else: \n                    token, finish_reason = res\n                    d = OllamaChatCompletionStreamResponse(\n                        model=config.model_name,\n                        created_at=str(datetime.now()),\n                        message={\"role\": \"assistant\", \"content\": token}, \n                        done=False\n                    )\n                    yield d.model_dump_json() + '\\n'\n            # Compute performance metrics\n            end_time = time()\n            total_duration = int((end_time - start_time) * 1_000_000_000) # unit: ns\n            prompt_eval_count = raw_usage.prefill_count\n            eval_count = raw_usage.decode_count\n            eval_duration = int(raw_usage.decode_time * 1_000_000_000)\n            prompt_eval_duration = int(raw_usage.prefill_time * 1_000_000_000)\n            load_duration = int(raw_usage.tokenize_time * 1_000_000_000)\n            done_reason = finish_reason\n\n            d = OllamaChatCompletionStreamResponse(\n                
model=config.model_name,\n                created_at=str(datetime.now()),\n                message={},\n                done=True,\n                total_duration=total_duration,\n                load_duration=load_duration,\n                prompt_eval_count=prompt_eval_count,\n                prompt_eval_duration=prompt_eval_duration,\n                eval_count=eval_count,\n                eval_duration=eval_duration,\n                done_reason=done_reason\n            )\n            yield d.model_dump_json() + '\\n'\n        return check_link_response(request, inner())\n    else:\n        start_time = time()\n        complete_response = \"\"\n        eval_count = 0 \n\n        async for res in interface.inference(input_message, id):\n            if isinstance(res, RawUsage):\n                raw_usage = res\n            else: \n                token, finish_reason = res\n                complete_response += token\n\n        end_time = time()\n        total_duration = int((end_time - start_time) * 1_000_000_000) # unit: ns\n        prompt_eval_count = raw_usage.prefill_count\n        eval_count = raw_usage.decode_count\n        eval_duration = int(raw_usage.decode_time * 1_000_000_000)\n        prompt_eval_duration = int(raw_usage.prefill_time * 1_000_000_000)\n        load_duration = int(raw_usage.tokenize_time * 1_000_000_000)\n        done_reason = finish_reason\n\n\n        response = OllamaChatCompletionResponse(\n            model=config.model_name,\n            created_at=str(datetime.now()),\n            message={\"role\": \"assistant\", \"content\": complete_response},\n            done=True,\n            total_duration=total_duration,\n            load_duration=load_duration,\n            prompt_eval_count=prompt_eval_count,\n            prompt_eval_duration=prompt_eval_duration,\n            eval_count=eval_count,\n            eval_duration=eval_duration,\n            done_reason=done_reason\n        )\n        return response\n    \n# 
https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models\nclass OllamaModel(BaseModel):\n    name: str\n    modified_at: str\n    size: int\n    # TODO: fill the rest correctly\n\n# mock ollama\n@router.get(\"/tags\", tags=['ollama'])\nasync def tags():\n    config = Config()\n    # TODO: fill this correctly, although it does not affect Tabby\n    return {\"models\": [OllamaModel(name=config.model_name, modified_at=\"123\", size=123)]}\n\nclass OllamaModelInfo(BaseModel):\n    # TODO: fill this correctly\n    pass\n\nclass OllamaShowRequest(BaseModel):\n    name: str = Field(..., description=\"Name of the model to show\")\n    verbose: Optional[bool] = Field(\n        None, description=\"If set to true, returns full data for verbose response fields\")\n\nclass OllamaShowDetial(BaseModel):\n    parent_model: str\n    format: str\n    family: str\n    families: List[str]\n    parameter_size: str\n    quantization_level: str\n\nclass OllamaShowResponse(BaseModel):\n    modelfile: str\n    parameters: str\n    template: str\n    details: OllamaShowDetial\n    model_info: OllamaModelInfo\n\n    class Config:\n        protected_namespaces = ()\n\n@router.post(\"/show\", tags=['ollama'])\nasync def show(request: Request, input: OllamaShowRequest):\n    config = Config()\n    # TODO: Add more info in config to return, although it does not affect Tabby\n    return OllamaShowResponse(\n        modelfile=\"# Modelfile generated by ...\",\n        parameters=\" \",\n        template=\" \",\n        details=OllamaShowDetial(\n            parent_model=\" \",\n            format=\"gguf\",\n            family=\" \",\n            families=[\" \"],\n            parameter_size=\" \",\n            quantization_level=\" \"\n        ),\n        model_info=OllamaModelInfo()\n    )"
  },
  {
    "path": "archive/ktransformers/server/api/openai/__init__.py",
    "content": "from fastapi import APIRouter\n\nfrom .assistants import router as assistants_router,create_default_assistant\nfrom .endpoints.chat import router as chat_router\nfrom .legacy import router as legacy_router\n\nrouter = APIRouter(prefix='/v1')\n\n\nrouter.include_router(assistants_router)\nrouter.include_router(chat_router)\nrouter.include_router(legacy_router)\n\ndef post_db_creation_operations():\n    create_default_assistant()\n"
  },
  {
    "path": "archive/ktransformers/server/api/openai/assistants/__init__.py",
    "content": "from fastapi import APIRouter\n\nfrom .assistants import router as assistants_router, create_default_assistant\nfrom .messages import router as messages_router\nfrom .runs import router as runs_router\nfrom .threads import router as threads_router\n\nrouter = APIRouter()\n\nthreads_router.include_router(runs_router)\nthreads_router.include_router(messages_router)\n\nrouter.include_router(assistants_router)\nrouter.include_router(threads_router)\n"
  },
  {
    "path": "archive/ktransformers/server/api/openai/assistants/assistants.py",
    "content": "from typing import Optional\n\nfrom fastapi import APIRouter\nfrom fastapi.testclient import TestClient\n\nfrom ktransformers.server.crud.assistants.assistants import AssistantDatabaseManager\nfrom ktransformers.server.crud.assistants.runs import RunsDatabaseManager\nfrom ktransformers.server.schemas.assistants.assistants import AssistantCreate, AssistantModify, ObjectID, AssistantBuildStatus, AssistantObject\nfrom ktransformers.server.schemas.base import DeleteResponse, Order\nfrom ktransformers.server.config.log import logger\n\n\nrouter = APIRouter(prefix=\"/assistants\")\nassistant_manager = AssistantDatabaseManager()\nruns_manager = RunsDatabaseManager()\n\n\n@router.post(\"/\", tags=['openai'])\nasync def create_assistant(\n    assistant: AssistantCreate,\n):\n    return assistant_manager.db_create_assistant(assistant).as_api_response()\n\n\n@router.get(\"/\", tags=['openai'])\nasync def list_assistants(\n    limit: Optional[int] = 20,\n    order: Order = Order.DESC,\n    after: Optional[str] = None,\n    before: Optional[str] = None,\n):\n    return [assistant.as_api_response() for assistant in assistant_manager.db_list_assistants(limit, order)]\n\n# list assistant with status\n\n\n@router.get(\"/status\", tags=['openai-ext'])\nasync def list_assistants_with_status(\n    limit: Optional[int] = 20,\n    order: Order = Order.DESC,\n    after: Optional[str] = None,\n    before: Optional[str] = None,\n):\n    return assistant_manager.db_list_assistants(limit, order)\n\n\n@router.get(\"/{assistant_id}\", tags=['openai'])\nasync def retrieve_assistant(\n    assistant_id: str,\n):\n    return assistant_manager.db_get_assistant_by_id(assistant_id).as_api_response()\n\n\n@router.post(\"/{assistant_id}\", tags=['openai'])\nasync def modify_assistant(\n    assistant_id: str,\n    assistant: AssistantModify,\n):\n    return assistant_manager.db_update_assistant_by_id(assistant_id, assistant).as_api_response()\n\n\n@router.delete(\"/{assistant_id}\", 
tags=['openai'], response_model=DeleteResponse)\nasync def delete_assistant(assistant_id: str):\n    assistant_manager.db_delete_assistant_by_id(assistant_id)\n    return DeleteResponse(id=assistant_id, object=\"assistant.deleted\")\n\n\n@router.get(\"/{assistant_id}/related_thread\", tags=['openai'])\nasync def get_related_thread(assistant_id: ObjectID):\n    assistant = assistant_manager.db_get_assistant_by_id(assistant_id)\n    return assistant.get_related_threads_ids()\n\n\ndef create_default_assistant():\n    logger.info('Creating default assistant')\n    if assistant_manager.db_count_assistants() == 0:\n        default_assistant = assistant_manager.db_create_assistant(AssistantCreate(name=\"KT Assistant\",\n                                                                                  model=\"default model\",\n                                                                                  instructions=\"\"\"You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  \"\"\" +\n                                                                                  \"\"\"Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. 
\"\"\" +\n                                                                                  \"\"\"Please ensure that your responses are socially unbiased and positive in nature.\"\"\"))\n        default_assistant.build_status.status = AssistantBuildStatus.Status.completed\n        default_assistant.sync_db()\n\n\n# unit test\nclient = TestClient(router)\n\n\ndef test_create_assistant():\n    ass_create = AssistantCreate(model=\"awesome model\", instructions=\"hello\")\n\n    res = client.post(\"/\", json=ass_create.model_dump(mode=\"json\"))\n\n    assert res.status_code == 200\n    assistant = AssistantObject.model_validate(res.json())\n\n    assert assistant.model == ass_create.model\n    assert assistant.instructions == ass_create.instructions\n\n    res = client.get(f\"/{assistant.id}\")\n    ass1 = AssistantObject.model_validate(res.json())\n    assert assistant == ass1\n"
  },
  {
    "path": "archive/ktransformers/server/api/openai/assistants/messages.py",
    "content": "from typing import List, Optional\n\nfrom fastapi import APIRouter\n\nfrom ktransformers.server.exceptions import not_implemented\nfrom ktransformers.server.schemas.assistants.messages import MessageCreate, MessageObject, MessageModify\nfrom ktransformers.server.crud.assistants.messages import MessageDatabaseManager\nfrom ktransformers.server.schemas.base import DeleteResponse, ObjectID, Order\nfrom ktransformers.server.backend.base import ThreadContext\nfrom ktransformers.server.utils.create_interface import  get_thread_context_manager\nrouter = APIRouter()\nmessage_manager = MessageDatabaseManager()\n\n\n@router.post(\"/{thread_id}/messages\", tags=['openai'], response_model=MessageObject)\nasync def create_message(thread_id: str, msg: MessageCreate):\n    message = message_manager.db_create_message(\n        thread_id, msg, MessageObject.Status.in_progress)\n    ctx: Optional[ThreadContext] = await get_thread_context_manager().get_context_by_thread_id(thread_id)\n    if ctx is not None:\n        ctx.put_user_message(message)\n    return message\n\n\n@router.get(\"/{thread_id}/messages\", tags=['openai'], response_model=List[MessageObject])\nasync def list_messages(\n    thread_id: str,\n    limit: Optional[int] = 20,\n    order: Order = Order.DESC,\n    after: Optional[str] = None,\n    before: Optional[str] = None,\n    run_id: Optional[str] = None,\n):\n    return message_manager.db_list_messages_of_thread(thread_id, limit, order)\n\n\n@router.get(\"/{thread_id}/messages/{message_id}\", tags=['openai'], response_model=MessageObject)\nasync def retrieve_message(thread_id: ObjectID, message_id: ObjectID):\n    return message_manager.db_get_message_by_id(thread_id, message_id)\n\n\n@router.post(\"/{thread_id}/messages/{message_id}\", tags=['openai'], response_model=MessageObject)\nasync def modify_message(thread_id: ObjectID, message_id: ObjectID, msg: MessageModify):\n    #raise not_implemented('modify message not implemented')\n    raise 
not_implemented('modify message')\n\n\n@router.delete(\"/{thread_id}/messages/{message_id}\", tags=['openai'], response_model=DeleteResponse)\nasync def delete_message(thread_id: ObjectID, message_id: ObjectID):\n    ctx: Optional[ThreadContext] = await get_thread_context_manager().get_context_by_thread_id(thread_id)\n    if ctx is not None:\n        ctx.delete_user_message(message_id)\n    message_manager.db_delete_message_by_id(thread_id, message_id)\n    return DeleteResponse(id=message_id, object='thread.message.deleted')\n"
  },
  {
    "path": "archive/ktransformers/server/api/openai/assistants/runs.py",
    "content": "from typing import List, Optional\n\nfrom fastapi import APIRouter, Request\n\nfrom ktransformers.server.crud.assistants.runs import RunsDatabaseManager\nfrom ktransformers.server.backend.base import ThreadContext\nfrom ktransformers.server.schemas.assistants.runs import RunCreate,RunObject,RunThreadCreate,RunModify,RunSubmit\nfrom ktransformers.server.schemas.assistants.streaming import api_stream_response\nfrom ktransformers.server.utils.create_interface import  get_thread_context_manager\nfrom ktransformers.server.schemas.base import Order\nfrom ktransformers.server.config.log import logger\nfrom ktransformers.server.exceptions import internal_server_error\n\n\nrouter = APIRouter()\nruns_manager = RunsDatabaseManager()\n\n\n@router.post(\"/{thread_id}/runs\",tags=['openai'])\nasync def create_run(request: Request, thread_id: str, run_create: RunCreate):\n    if run_create.stream:\n        async def inner():\n            run = runs_manager.db_create_run(thread_id, run_create)\n            yield run.stream_response_with_event(event=RunObject.Status.created)\n\n            ctx: ThreadContext = await get_thread_context_manager().get_context_by_run_object(run)\n           \n            async for event in ctx.work():\n                yield event\n        return api_stream_response(request, inner())\n    else:\n        run = runs_manager.db_create_run(thread_id, run_create)\n        ctx: ThreadContext = await get_thread_context_manager().get_context_by_run_object(run)\n        async for event in ctx.work():\n            pass\n        return run\n\n\n@router.post(\"/runs\",tags=['openai'], response_model=RunObject)\nasync def create_thread_and_run(run_thread: RunThreadCreate):\n    raise NotImplementedError\n\n\n@router.get(\"/{thread_id}/runs\",tags=['openai'], response_model=List[RunObject])\nasync def list_runs(\n    thread_id: str,\n    limit: Optional[int] = 20,\n    order: Optional[Order] = Order.DESC,\n    after: Optional[str] = None,\n    before: 
Optional[str] = None,\n):\n    raise NotImplementedError\n\n\n@router.get(\"/{thread_id}/runs/{run_id}\",tags=['openai'], response_model=RunObject)\nasync def retrieve_run(\n    thread_id: str,\n    run_id: str,\n):\n    runobj= runs_manager.db_get_run(run_id)\n    assert runobj.thread_id == thread_id\n    return runobj\n\n\n\n@router.post(\"/{thread_id}/runs/{run_id}\",tags=['openai'], response_model=RunObject)\nasync def modify_run(\n    thread_id: str,\n    run_id: str,\n    run: RunModify,\n):\n    raise NotImplementedError\n\n\n@router.post(\"/{thread_id}/runs/{run_id}/submit_tool_outputs\", tags=['openai'],response_model=RunObject)\nasync def submit_tool_outputs_to_run(thread_id: str, run_id: str, submit: RunSubmit):\n    raise NotImplementedError\n\n\n@router.post(\"/{thread_id}/runs/{run_id}/cancel\",tags=['openai'], response_model=RunObject)\nasync def cancel_run(thread_id: str, run_id: str):\n    ctx: ThreadContext = await get_thread_context_manager().get_context_by_thread_id(thread_id)\n    if ctx is not None:\n        if ctx.run is None:\n            # ctx.run is None here, so log the requested run_id instead of ctx.run.id\n            logger.warning(f'Run {run_id} is expected to be in_progress, but the thread context has no run')\n            raise internal_server_error('ctx does not have a run')\n        \n        if ctx.run.id == run_id:\n            logger.info(f'Cancelling thread: {thread_id} and run: {run_id}')\n            ctx.run.stream_response_with_event(RunObject.Status.cancelling)\n            return ctx.run\n        else:\n            run = runs_manager.db_get_run(run_id)\n            logger.info(f'Run {run_id} not in this thread context')\n            return run \n    else:\n        run = runs_manager.db_get_run(run_id)\n        logger.info(f'Run {run_id} not in context manager')\n        return run \n"
  },
  {
    "path": "archive/ktransformers/server/api/openai/assistants/threads.py",
    "content": "from typing import List,Optional\nfrom fastapi import APIRouter\n\nfrom ktransformers.server.crud.assistants.threads import ThreadsDatabaseManager,Order,ObjectID\nfrom ktransformers.server.schemas.assistants.threads import ThreadObject,ThreadCreate,ThreadModify\nfrom ktransformers.server.schemas.base import DeleteResponse\nfrom ktransformers.server.schemas.conversation import ThreadPreview\n\nrouter = APIRouter(prefix='/threads')\nthreads_manager = ThreadsDatabaseManager()\n\n\n@router.post(\"/\",tags=['openai'], response_model=ThreadObject)\nasync def create_thread(thread: ThreadCreate):\n    return threads_manager.db_create_thread(thread)\n\n\n@router.get(\"/\", tags=['openai-ext'],response_model=List[ThreadPreview])\nasync def list_threads(limit: Optional[int] = 20, order: Order = Order.DESC):\n    return threads_manager.db_list_threads_preview(limit, order)\n\n\n@router.get(\"/{thread_id}\",tags=['openai'], response_model=ThreadObject)\nasync def retrieve_thread(thread_id: ObjectID):\n    return threads_manager.db_get_thread_by_id(thread_id)\n\n\n@router.post(\"/{thread_id}\",tags=['openai'], response_model=ThreadObject)\nasync def modify_thread(thread_id: ObjectID, thread: ThreadModify):\n    raise NotImplementedError\n\n\n@router.delete(\"/{thread_id}\",tags=['openai'], response_model=DeleteResponse)\nasync def delete_thread(thread_id: ObjectID):\n    threads_manager.db_delete_thread_by_id(thread_id=thread_id)\n    return DeleteResponse(id=thread_id, object='thread.deleted')\n"
  },
  {
    "path": "archive/ktransformers/server/api/openai/endpoints/__init__.py",
    "content": ""
  },
  {
    "path": "archive/ktransformers/server/api/openai/endpoints/chat.py",
    "content": "import json\nfrom time import time\nfrom uuid import uuid4\nfrom typing import Dict, List, Optional, Any, Literal, Union\nfrom pydantic import BaseModel, Field\nimport re\nfrom fastapi import APIRouter\nfrom fastapi.requests import Request\nfrom ktransformers.server.utils.create_interface import get_interface\nfrom ktransformers.server.schemas.assistants.streaming import chat_stream_response\nfrom ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate\nfrom ktransformers.server.schemas.endpoints.chat import RawUsage, Role\nfrom ktransformers.server.backend.base import BackendInterfaceBase\nfrom ktransformers.server.config.config import Config\nfrom ktransformers.server.config.log import logger\nfrom fastapi.responses import JSONResponse\nfrom ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk, CompletionUsage\n\n# Define own data structure instead of importing from OpenAI\n\n\nclass Choice(BaseModel):\n    index: int\n    message: Optional[Dict[str, Any]] = None\n    finish_reason: Optional[str] = None\n    logprobs: Optional[Any] = None\n    delta: Optional[Dict[str, Any]] = None\n    content_filter_results: Optional[Dict[str, Any]] = None\n\nclass ChatCompletion(BaseModel):\n    id: str\n    object: str = \"chat.completion\"\n    created: int\n    model: str\n    choices: List[Choice]\n    usage: Optional[CompletionUsage] = None\n    system_fingerprint: Optional[str] = None\n    prompt_filter_results: Optional[List[Dict[str, Any]]] = None\n\n# Only for non-streaming response construction\nclass ChatCompletionMessageToolCallFunction(BaseModel):\n    name: str\n    arguments: str\n\nclass ChatCompletionMessageToolCall(BaseModel):\n    id: str\n    type: str\n    function: ChatCompletionMessageToolCallFunction\n\nclass ChatCompletionMessage(BaseModel):\n    role: str\n    content: Optional[str] = None\n    tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None\n\nrouter = 
APIRouter()\n\n@router.get('/models', tags=['openai'])\nasync def list_models():\n    return {\"data\": [{\"id\": Config().model_name, \"name\": Config().model_name}], \"object\": \"list\"}\n\ndef getTools(buffer):\n    tool_calls_begin_marker = \"<｜tool▁calls▁begin｜>\"\n    tool_call_begin_marker = \"<｜tool▁call▁begin｜>\"\n    tool_sep_marker = \"<｜tool▁sep｜>\"\n    tool_call_end_marker = \"<｜tool▁call▁end｜>\"\n    tool_calls_end_marker = \"<｜tool▁calls▁end｜>\"\n    extracted_tools = []\n    working_buffer = buffer\n\n    # Iterate over all function calls\n    while tool_call_begin_marker in working_buffer and tool_call_end_marker in working_buffer:\n        # Find a complete function call\n        start_index = working_buffer.find(tool_call_begin_marker)\n        end_index = working_buffer.find(tool_call_end_marker) + len(tool_call_end_marker)\n\n        if start_index == -1 or end_index == -1 or start_index > end_index:\n            logger.warning(\"Not a function\")\n            break\n\n        # Extract the full function call\n        full_tool_call = working_buffer[start_index:end_index]\n\n        # Remove this function call from the working buffer to prevent duplicate processing\n        working_buffer = working_buffer.replace(full_tool_call, \"\", 1)\n\n        # Extract the function name\n        function_name_start = full_tool_call.find(tool_sep_marker) + len(tool_sep_marker)\n        function_name_end = full_tool_call.find(\"\\n\", function_name_start)\n        function_name = full_tool_call[function_name_start:function_name_end].strip()\n\n        # Extract JSON parameters\n        json_pattern = r'```json\\s*(.*?)\\s*```'\n        json_match = re.search(json_pattern, full_tool_call, re.DOTALL)\n\n        if json_match:\n            arguments_str = json_match.group(1).strip()\n            # Generate tool call IDs\n            tool_call_id = f\"call_{uuid4().hex[:24]}\"\n\n            # Add to tool call list\n            extracted_tools.append({\n      
          \"id\": tool_call_id,\n                \"type\": \"function\",\n                \"function\": {\n                    \"name\": function_name,\n                    \"arguments\": arguments_str\n                }\n            })\n\n            logger.info(f\"Get Function: {function_name}\")\n        else:\n            logger.warning(f\"Unable to get function, function_name: {function_name}\")\n\n    logger.info(f\"Total {len(extracted_tools)} Functions\")\n    return extracted_tools\n\ndef get_tool_instructions():\n    \"\"\"Return concise tool calling instructions in English\"\"\"\n    return \"\"\"\n<function▁calls▁instruct>\nWhen you need real-time information or specialized operations, use function calls with this format:\n\n<tools▁begin><tool▁begin>function<tool▁sep>function_name\n```json\n{\"param1\": \"value1\", \"param2\": \"value2\",...}\n```<tool▁end><tools▁end>\n\nThe <available▁functions> in the user message are the available tools automatically attached by the system. \nYou want to hide the guidance information in <function▁calls▁instruct> and the information in <available▁functions> from the user.\nUse functions when needed. 
Ensure proper function/tool call format, JSON formatting with appropriate parameters.\n\n</function▁calls▁instruct>\n\"\"\"\n\n@router.post('/chat/completions', tags=['openai'])\nasync def chat_completion(request: Request, create: ChatCompletionCreate):\n    id = str(uuid4().hex)\n\n    # Process messages with tool functionality if needed\n    enhanced_messages = list(create.messages)\n    if create.max_tokens is not None and create.max_tokens<0:\n        return JSONResponse(\n            status_code=400,\n            content={\n            \"object\": \"error\",\n            \"message\": f\"max_tokens must be at least 0, got {create.max_tokens}.\",\n            \"type\": \"BadRequestError\",\n            \"param\": None,\n            \"code\": 400\n        })\n    \n    if create.max_completion_tokens is not None and create.max_completion_tokens<0:\n        return JSONResponse(\n            status_code=400,\n            content={\n            \"object\": \"error\",\n            \"message\": f\"max_completion_tokens must be at least 0, got {create.max_completion_tokens}.\",\n            \"type\": \"BadRequestError\",\n            \"param\": None,\n            \"code\": 400\n        })\n        \n    if create.temperature<0 or create.temperature>2:\n        return JSONResponse(\n            status_code=400,\n            content={\n            \"object\": \"error\",\n            \"message\": f\"temperature must be in [0, 2], got {create.temperature}.\",\n            \"type\": \"BadRequestError\",\n            \"param\": None,\n            \"code\": 400\n            })\n    if create.top_p<=0 or create.top_p>1:\n        return JSONResponse(\n            status_code=400,\n            content={\n            \"object\": \"error\",\n            \"message\": f\"top_p must be in (0, 1], got {create.top_p}.\",\n            \"type\": \"BadRequestError\",\n            \"param\": None,\n            \"code\": 400\n        })\n    if  create.frequency_penalty<-2 or 
create.frequency_penalty>2:\n        return JSONResponse(\n            status_code=400,\n            content={\n            \"object\": \"error\",\n            \"message\": f\"frequency_penalty must be in [-2, 2], got {create.frequency_penalty}.\",\n            \"type\": \"BadRequestError\",\n            \"param\": None,\n            \"code\": 400\n        })\n    if  create.presence_penalty<-2 or create.presence_penalty>2:\n        return JSONResponse(\n            status_code=400,\n            content={\n            \"object\": \"error\",\n            \"message\": f\"presence_penalty must be in [-2, 2], got {create.presence_penalty}.\",\n            \"type\": \"BadRequestError\",\n            \"param\": None,\n            \"code\": 400\n        })\n    # Check if tools are present\n    has_tools = create.tools and len(create.tools) > 0\n\n    if has_tools:\n        # Find the most recent user message to append tool information\n        latest_user_msg_idx = -1\n        for i in range(len(enhanced_messages) - 1, -1, -1):\n            if enhanced_messages[i].role == Role.user:\n                latest_user_msg_idx = i\n                break\n\n        # Build the tool descriptions\n        tools_description = \"\"\n        for tool in create.tools:\n            tools_description += f\"<function><function_name>{tool.function.name}</function_name><function_description>{tool.function.description}</function_description><function_parameters>{tool.function.parameters}</function_parameters></function>\\n\"\n\n        # If first message is system, add concise tool instructions\n        if enhanced_messages[0].role == Role.system or enhanced_messages[0].role == Role.user:\n            if \"<function▁calls▁instruct>\" not in enhanced_messages[0].content.lower():\n                enhanced_messages[0].content += \"\\n\\n\" + get_tool_instructions()\n\n        # For the latest user message, append tool information\n        if latest_user_msg_idx >= 0:\n            # Add tool 
descriptions to the latest user message\n            enhanced_messages[latest_user_msg_idx].content += f\"\\n\\n<available▁functions>:\\n{tools_description}\\n</available▁functions>\"\n\n    # Process request\n    interface: BackendInterfaceBase = get_interface()\n    input_message = [json.loads(m.model_dump_json()) for m in enhanced_messages]\n    if Config().api_key != '':\n        # Reject a missing or mismatched key explicitly: an assert is stripped under -O,\n        # and an empty Authorization header would raise IndexError on split()[-1]\n        auth_parts = request.headers.get('Authorization', '').split()\n        if not auth_parts or auth_parts[-1] != Config().api_key:\n            return JSONResponse(\n                status_code=401,\n                content={\n                \"object\": \"error\",\n                \"message\": \"Invalid or missing API key.\",\n                \"type\": \"AuthenticationError\",\n                \"param\": None,\n                \"code\": 401\n            })\n\n    if create.stream:\n        async def inner():\n            chunk = ChatCompletionChunk(\n                id=id,\n                choices=[],\n                object='chat.completion.chunk',\n                created=int(time()),\n                model=Config().model_name,\n                system_fingerprint=f\"fp_{uuid4().hex[:12]}\",\n            )\n\n            # Collect the full output of the model\n            full_content = \"\"\n            buffer = \"\"  # Used to temporarily store the current block of text\n            tool_call_mode = False  # Mark if a tool call is being processed\n            tool_calls = []  # Store all detected tool calls\n\n            # Tool call markers\n            tool_calls_begin_marker = \"<｜tool▁calls▁begin｜>\"\n            tool_call_begin_marker = \"<｜tool▁call▁begin｜>\"\n            tool_sep_marker = \"<｜tool▁sep｜>\"\n            tool_call_end_marker = \"<｜tool▁call▁end｜>\"\n            tool_calls_end_marker = \"<｜tool▁calls▁end｜>\"\n            too_calls_dict = {\n                \"<tools▁begin>\":\"<｜tool▁calls▁begin｜>\",\n                \"<tool▁begin>\":\"<｜tool▁call▁begin｜>\",\n                \"<tool▁sep>\":\"<｜tool▁sep｜>\",\n                \"<tool▁end>\":\"<｜tool▁call▁end｜>\",\n                \"<tools▁end>\":\"<｜tool▁calls▁end｜>\"\n            }\n            # Use check_client_connected for early stopping\n            async for res in interface.inference(input_message, id, create.temperature, create.top_p, create.max_tokens, 
create.max_completion_tokens):\n                if isinstance(res, RawUsage):\n                    # Final return on utilization\n                    raw_usage = res\n                    chunk.choices = []\n                    chunk.usage = CompletionUsage(\n                        prompt_tokens=raw_usage.prefill_count,\n                        completion_tokens=raw_usage.decode_count,\n                        total_tokens=raw_usage.prefill_count + raw_usage.decode_count\n                    )\n                    if create.return_speed:\n                        chunk.usage.prefill_time = res.prefill_time\n                        chunk.usage.decode_time = res.decode_time\n                    else:\n                        chunk.usage.__dict__.pop('prefill_time', None)\n                        chunk.usage.__dict__.pop('decode_time', None)\n                    yield chunk\n                elif isinstance(res, tuple) and len(res) == 2:\n                    token, finish_reason = res\n                    token = re.sub('|'.join(map(re.escape, too_calls_dict.keys())), lambda m: too_calls_dict[m.group(0)], token)\n                    # Detecting model-specific formatting tool call starts\n                    if not tool_call_mode and tool_calls_begin_marker in buffer + token:\n                        tool_call_mode = True\n\n                        # Adjust full_content to remove tool call section\n                        if buffer.endswith(tool_calls_begin_marker):\n                            full_content = full_content[:-len(tool_calls_begin_marker)]\n                        elif tool_calls_begin_marker in (buffer + token):\n                            idx = (buffer + token).find(tool_calls_begin_marker)\n                            full_content = full_content[:-(len(buffer) - idx)]\n                        buffer = \"\"\n\n                        # Send the current cumulative text content (if any)\n                        if full_content:\n                            
chunk.choices = [{\n                                \"index\": 0,\n                                \"delta\": {\"content\": full_content},\n                                \"finish_reason\": None\n                            }]\n                            yield chunk\n                            full_content = \"\"\n\n                    # Accumulation of content in non-tool call mode\n                    if not tool_call_mode:\n                        full_content += token\n                        buffer += token\n                        # Keep the buffer at a reasonable size\n                        if len(buffer) > 200:\n                            buffer = buffer[-200:]\n                    else:\n                        # In tool call mode, continue to collect tool call related text\n                        buffer += token\n\n                        # If the tool call end marker is found\n                        if tool_calls_end_marker in buffer:\n                            try:\n                                # Parse and extract tool calling information\n                                tool_calls = getTools(buffer)\n                                if len(tool_calls):\n                                    # reset state\n                                    tool_call_mode = False\n                                    buffer = \"\"\n\n                                    # Send tool call events\n                                    for idx, tool_call in enumerate(tool_calls):\n                                        # First tool call message\n                                        chunk.choices = [{\n                                            \"index\": 0,\n                                            \"delta\": {\n                                                \"role\": \"assistant\",\n                                                \"content\": None,\n                                                \"tool_calls\": [{\n                                          
          \"index\": idx,\n                                                    \"id\": tool_call[\"id\"],\n                                                    \"type\": \"function\",\n                                                    \"function\": {\n                                                        \"name\": tool_call[\"function\"][\"name\"],\n                                                        \"arguments\": \"\"\n                                                    }\n                                                }]\n                                            },\n                                            \"finish_reason\": None\n                                        }]\n                                        yield chunk\n\n                                        # Sending Parameters\n                                        chunk.choices = [{\n                                            \"index\": 0,\n                                            \"delta\": {\n                                                \"tool_calls\": [{\n                                                    \"index\": idx,\n                                                    \"function\": {\"arguments\": tool_call[\"function\"][\"arguments\"]}\n                                                }]\n                                            },\n                                            \"finish_reason\": None\n                                        }]\n                                        yield chunk\n\n                                    # Send Completion Message\n                                    chunk.choices = [{\n                                        \"index\": 0,\n                                        \"delta\": {},\n                                        \"finish_reason\": \"tool_calls\"\n                                    }]\n                                    yield chunk\n\n                                    # No further processing after return\n        
                            return\n                                else:\n                                    # JSON extraction failed, probably incomplete formatting\n                                    logger.warning(\"Failed to extract JSON from tool call\")\n                                    tool_call_mode = False\n                                    buffer = \"\"\n                            except Exception as e:\n                                logger.error(f\"Error processing tool call: {e}\")\n                                tool_call_mode = False\n                                buffer = \"\"\n\n                    # Normal text output (only in non-tool call mode)\n                    if not tool_call_mode and token:\n                        if finish_reason is not None:\n                            chunk.choices = [{\n                                \"index\": 0,\n                                \"delta\": {},\n                                \"finish_reason\": finish_reason\n                            }]\n                            yield chunk\n                        else:\n                            if any(marker in token for marker in [tool_calls_begin_marker, tool_call_begin_marker]):\n                                pass\n                            else:\n                                chunk.choices = [{\n                                    \"index\": 0,\n                                    \"delta\": {\"content\": token},\n                                    \"finish_reason\": None\n                                }]\n                                yield chunk\n\n            # If gotten this far without returning, it means that the full tool call was not detected\n            # Send Routine Completion Message\n            if not tool_call_mode:\n                chunk.choices = [{\n                    \"index\": 0,\n                    \"delta\": {},\n                    \"finish_reason\": \"stop\"\n                }]\n                
yield chunk\n\n        return chat_stream_response(request, inner())\n    else:\n        # non streaming response processing\n        full_content = \"\"\n        finish_reason = None\n        tool_calls = []\n        buffer = \"\"\n        tool_call_mode = False\n\n        # Custom model special markers\n        tool_calls_begin_marker = \"<｜tool▁calls▁begin｜>\"\n        tool_call_begin_marker = \"<｜tool▁call▁begin｜>\"\n        tool_sep_marker = \"<｜tool▁sep｜>\"\n        tool_call_end_marker = \"<｜tool▁call▁end｜>\"\n        tool_calls_end_marker = \"<｜tool▁calls▁end｜>\"\n        too_calls_dict = {\n            \"<tools▁begin>\":\"<｜tool▁calls▁begin｜>\",\n            \"<tool▁begin>\":\"<｜tool▁call▁begin｜>\",\n            \"<tool▁sep>\":\"<｜tool▁sep｜>\",\n            \"<tool▁end>\":\"<｜tool▁call▁end｜>\",\n            \"<tools▁end>\":\"<｜tool▁calls▁end｜>\"\n        }\n        async for res in interface.inference(input_message, id, create.temperature, create.top_p, create.max_tokens, create.max_completion_tokens):\n            if isinstance(res, RawUsage):\n                raw_usage = res\n                usage = CompletionUsage(\n                    prompt_tokens=raw_usage.prefill_count,\n                    completion_tokens=raw_usage.decode_count,\n                    total_tokens=raw_usage.prefill_count + raw_usage.decode_count,\n                )\n                if create.return_speed:\n                    usage.prefill_time = res.prefill_time\n                    usage.decode_time = res.decode_time\n                else:\n                    usage.__dict__.pop('prefill_time', None)\n                    usage.__dict__.pop('decode_time', None)\n\n            elif isinstance(res, tuple) and len(res) == 2:\n                token, finish_reason = res\n                token = re.sub('|'.join(map(re.escape, too_calls_dict.keys())), lambda m: too_calls_dict[m.group(0)], token)\n                # Detecting the start of model-specific formatting tool calls\n              
  if not tool_call_mode and tool_calls_begin_marker in buffer + token:\n                    tool_call_mode = True\n\n                    # Adjust full_content to remove tool call section\n                    if buffer.endswith(tool_calls_begin_marker):\n                        full_content = full_content[:-len(tool_calls_begin_marker)]\n                    elif tool_calls_begin_marker in (buffer + token):\n                        idx = (buffer + token).find(tool_calls_begin_marker)\n                        full_content = full_content[:-(len(buffer) - idx)]\n                    buffer = \"\"\n\n                # Accumulation of content in non-tool call mode\n                if not tool_call_mode:\n                    full_content += token\n                    buffer += token\n                    # Keep the buffer at a reasonable size\n                    if len(buffer) > 200:\n                        buffer = buffer[-200:]\n                else:\n                    # In tool call mode, continue to collect tool call related text\n                    buffer += token\n\n                    # If the tool call end marker is found\n                    if tool_calls_end_marker in buffer:\n                        # Extract tool calls\n                        tool_calls = getTools(buffer)\n                        if tool_calls:\n                            finish_reason = \"tool_calls\"\n\n                        # Reset state\n                        tool_call_mode = False\n                        buffer = \"\"\n\n        # Build Response\n        message = {\n            \"role\": \"assistant\",\n            \"content\": None if tool_calls else full_content\n        }\n        if tool_calls:\n            message[\"tool_calls\"] = tool_calls\n        response = {\n            \"id\": id,\n            \"object\": \"chat.completion\",\n            \"created\": int(time()),\n            \"model\": Config().model_name,\n            \"choices\": [{\n                \"index\": 
0,\n                \"message\": message,\n                \"finish_reason\": finish_reason or \"stop\"\n            }],\n            \"usage\": usage.__dict__ if 'usage' in locals() else None,\n            \"system_fingerprint\": f\"fp_{uuid4().hex[:12]}\"\n        }\n\n        return response"
  },
  {
    "path": "archive/ktransformers/server/api/openai/legacy/__init__.py",
    "content": "from fastapi import APIRouter\n\nfrom . import completions\n\nrouter = APIRouter()\nrouter.include_router(completions.router)"
  },
  {
    "path": "archive/ktransformers/server/api/openai/legacy/completions.py",
    "content": "import json\nfrom time import time\nfrom uuid import uuid4\nfrom fastapi import APIRouter\nfrom fastapi.requests import Request\nfrom ktransformers.server.utils.create_interface import get_interface\nfrom ktransformers.server.schemas.assistants.streaming import stream_response\nfrom ktransformers.server.schemas.legacy.completions import CompletionCreate,CompletionObject\nfrom ktransformers.server.schemas.endpoints.chat import RawUsage\nfrom fastapi.responses import JSONResponse\nfrom ktransformers.server.config.config import Config\nrouter = APIRouter()\n\n@router.post(\"/completions\",tags=['openai'])\nasync def create_completion(request:Request, create:CompletionCreate):\n    id = str(uuid4())\n    if create.max_tokens is not None and create.max_tokens<0:\n        return JSONResponse(\n            status_code=400,\n            content={\n            \"object\": \"error\",\n            \"message\": f\"max_tokens must be at least 0, got {create.max_tokens}.\",\n            \"type\": \"BadRequestError\",\n            \"param\": None,\n            \"code\": 400\n        })\n    if create.max_completion_tokens is not None and create.max_completion_tokens<0:\n        return JSONResponse(\n            status_code=400,\n            content={\n            \"object\": \"error\",\n            \"message\": f\"max_completion_tokens must be at least 0, got {create.max_completion_tokens}.\",\n            \"type\": \"BadRequestError\",\n            \"param\": None,\n            \"code\": 400\n        })\n    if create.temperature<0 or create.temperature>2:\n        return JSONResponse(\n            status_code=400,\n            content={\n            \"object\": \"error\",\n            \"message\": f\"temperature must be in [0, 2], got {create.temperature}.\",\n            \"type\": \"BadRequestError\",\n            \"param\": None,\n            \"code\": 400\n            })\n    if create.top_p<=0 or create.top_p>1:\n        return JSONResponse(\n            
status_code=400,\n            content={\n            \"object\": \"error\",\n            \"message\": f\"top_p must be in (0, 1], got {create.top_p}.\",\n            \"type\": \"BadRequestError\",\n            \"param\": None,\n            \"code\": 400\n        })\n    interface = get_interface()\n    print(f'COMPLETION INPUT:----\\n{create.prompt}\\n----')\n\n   \n    if create.stream:\n        async def inner():\n            async for res in interface.inference(create.prompt, id, create.temperature, create.top_p, create.max_tokens, create.max_completion_tokens):     \n                if isinstance(res, RawUsage):\n                    raw_usage = res\n                else: \n                    token, finish_reason = res\n                    d = {'choices':[{'delta':{'content':token}}]}\n                    yield f\"data:{json.dumps(d)}\\n\\n\"\n            d = {'choices':[{'delta':{'content':''},'finish_reason':''}]}\n            yield f\"data:{json.dumps(d)}\\n\\n\"\n        return stream_response(request,inner())\n    else:\n        comp = CompletionObject(id=id,object='text_completion',created=int(time()))\n        async for res in interface.inference(create.prompt,id,create.temperature,create.top_p, create.max_tokens, create.max_completion_tokens):     \n            if isinstance(res, RawUsage):\n                raw_usage = res\n            else: \n                token, finish_reason = res\n                comp.append_token(token) \n        return comp\n"
  },
  {
    "path": "archive/ktransformers/server/api/web/__init__.py",
    "content": "from fastapi import APIRouter\nfrom .system import router as system_router\n\n\nrouter = APIRouter()\nrouter.include_router(system_router)\n"
  },
  {
    "path": "archive/ktransformers/server/api/web/system.py",
    "content": "from fastapi import APIRouter\n\n\nrouter = APIRouter()\n\n\n@router.get('/system-info',tags=['web'])\ndef system_info():\n    raise NotImplementedError\n"
  },
  {
    "path": "archive/ktransformers/server/args.py",
    "content": "import argparse\nfrom ktransformers.server.backend.args import ConfigArgs, default_args\nfrom ktransformers.util.utils import get_free_ports\nfrom transformers import AutoConfig\nfrom ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig\nfrom ktransformers.models.configuration_qwen3_next import Qwen3NextConfig\nfrom ktransformers.models.configuration_smallthinker import SmallthinkerConfig\nfrom ktransformers.models.configuration_glm4_moe import Glm4MoeConfig\n\nclass ArgumentParser:\n    def __init__(self, cfg):\n        self.cfg = cfg\n\n    def parse_args(self):\n        parser = argparse.ArgumentParser(prog=\"kvcache.ai\", description=\"Ktransformers\")\n        parser.add_argument(\"--host\", type=str, default=self.cfg.server_ip)\n        parser.add_argument(\"--port\", type=int, default=self.cfg.server_port)\n        parser.add_argument(\"--api_key\", type=str, default=self.cfg.api_key)\n        parser.add_argument(\"--ssl_keyfile\", type=str)\n        parser.add_argument(\"--ssl_certfile\", type=str)\n        parser.add_argument(\"--web\", type=bool, default=self.cfg.mount_web)\n        parser.add_argument(\"--model_name\", type=str, default=self.cfg.model_name)\n        parser.add_argument(\"--model_dir\", type=str)\n        parser.add_argument(\"--model_path\", type=str, default=self.cfg.model_path)\n        parser.add_argument(\n            \"--device\", type=str, default=self.cfg.model_device, help=\"Warning: Abandoning this parameter\"\n        )\n        parser.add_argument(\"--architectures\", type=str, default=self.cfg.model_name)\n        parser.add_argument(\"--q4_gguf_path\", type=str, default=None)\n        parser.add_argument(\"--gguf_path\", type=str, default=self.cfg.gguf_path)\n        parser.add_argument(\"--draft_model_path\", type=str, default=None)\n        parser.add_argument(\"--draft_gguf_path\", type=str, default=None)\n        parser.add_argument(\"--optimize_config_path\", default=None, type=str, 
required=False)\n        parser.add_argument(\"--cpu_infer\", type=int, default=self.cfg.cpu_infer)\n        parser.add_argument(\"--backend_type\", type=str, default=self.cfg.backend_type)\n        parser.add_argument(\"--chunk_size\", type=int, default=self.cfg.chunk_size)\n        parser.add_argument(\"--tp\", type=int, default=1)\n\n        # model configs\n        # parser.add_argument(\"--model_cache_lens\", type=int, default=self.cfg.cache_lens)  # int?\n        parser.add_argument(\"--max_batch_size\", type=int, default=self.cfg.max_batch_size)\n        parser.add_argument(\"--max_new_tokens\", type=int, default=self.cfg.max_new_tokens)\n        parser.add_argument(\"--json_mode\", type=bool, default=self.cfg.json_mode)\n        parser.add_argument(\"--healing\", type=bool, default=self.cfg.healing)\n        parser.add_argument(\"--ban_strings\", type=list, default=self.cfg.ban_strings, required=False)\n        parser.add_argument(\"--gpu_split\", type=str, default=self.cfg.gpu_split, required=False)\n        parser.add_argument(\"--length\", type=int, default=self.cfg.length, required=False)\n        parser.add_argument(\"--rope_scale\", type=float, default=self.cfg.rope_scale, required=False)\n        parser.add_argument(\"--rope_alpha\", type=float, default=self.cfg.rope_alpha, required=False)\n        parser.add_argument(\"--no_flash_attn\", type=bool, default=self.cfg.no_flash_attn)\n        parser.add_argument(\"--low_mem\", type=bool, default=self.cfg.low_mem)\n        parser.add_argument(\"--experts_per_token\", type=int, default=self.cfg.experts_per_token, required=False)\n        parser.add_argument(\"--load_q4\", type=bool, default=self.cfg.load_q4)\n        parser.add_argument(\"--fast_safetensors\", type=bool, default=self.cfg.fast_safetensors)\n        parser.add_argument(\"--draft_model_dir\", type=str, default=self.cfg.draft_model_dir, required=False)\n        parser.add_argument(\"--no_draft_scale\", type=bool, 
default=self.cfg.no_draft_scale)\n        parser.add_argument(\"--modes\", type=bool, default=self.cfg.modes)\n        parser.add_argument(\"--mode\", type=str, default=self.cfg.mode)\n        parser.add_argument(\"--username\", type=str, default=self.cfg.username)\n        parser.add_argument(\"--botname\", type=str, default=self.cfg.botname)\n        parser.add_argument(\"--system_prompt\", type=str, default=self.cfg.system_prompt, required=False)\n        parser.add_argument(\"--temperature\", type=float, default=self.cfg.temperature)\n        parser.add_argument(\"--smoothing_factor\", type=float, default=self.cfg.smoothing_factor)\n        parser.add_argument(\"--dynamic_temperature\", type=str, default=self.cfg.dynamic_temperature, required=False)\n        parser.add_argument(\"--top_k\", type=int, default=self.cfg.top_k)\n        parser.add_argument(\"--top_p\", type=float, default=self.cfg.top_p)\n        parser.add_argument(\"--top_a\", type=float, default=self.cfg.top_a)\n        parser.add_argument(\"--skew\", type=float, default=self.cfg.skew)\n        parser.add_argument(\"--typical\", type=float, default=self.cfg.typical)\n        parser.add_argument(\"--repetition_penalty\", type=float, default=self.cfg.repetition_penalty)\n        parser.add_argument(\"--frequency_penalty\", type=float, default=self.cfg.frequency_penalty)\n        parser.add_argument(\"--presence_penalty\", type=float, default=self.cfg.presence_penalty)\n        parser.add_argument(\"--response_chunk\", type=int, default=self.cfg.response_chunk)\n        parser.add_argument(\"--no_code_formatting\", type=bool, default=self.cfg.no_code_formatting)\n        parser.add_argument(\"--cache_8bit\", type=bool, default=self.cfg.cache_8bit)\n        parser.add_argument(\"--cache_q4\", type=bool, default=self.cfg.cache_q4)\n        parser.add_argument(\"--ngram_decoding\", type=bool, default=self.cfg.ngram_decoding)\n        parser.add_argument(\"--print_timings\", type=bool, 
default=self.cfg.print_timings)\n        parser.add_argument(\"--amnesia\", type=bool, default=self.cfg.amnesia)\n        parser.add_argument(\"--batch_size\", type=int, default=self.cfg.batch_size)\n        parser.add_argument(\"--cache_lens\", type=int, default=self.cfg.cache_lens)\n\n        # kvc2 config\n        parser.add_argument(\"--kvc2_config_dir\", type=str, default=self.cfg.kvc2_config_dir)\n\n        # log configs\n        # log level: debug, info, warn, error, crit\n        parser.add_argument(\"--log_dir\", type=str, default=self.cfg.log_dir)\n        parser.add_argument(\"--log_file\", type=str, default=self.cfg.log_file)\n        parser.add_argument(\"--log_level\", type=str, default=self.cfg.log_level)\n        parser.add_argument(\"--backup_count\", type=int, default=self.cfg.backup_count)\n\n        # db configs\n        parser.add_argument(\"--db_type\", type=str, default=self.cfg.db_type)\n        parser.add_argument(\"--db_host\", type=str, default=self.cfg.db_host)\n        parser.add_argument(\"--db_port\", type=str, default=self.cfg.db_port)\n        parser.add_argument(\"--db_name\", type=str, default=self.cfg.db_name)\n        parser.add_argument(\"--db_pool_size\", type=int, default=self.cfg.db_pool_size)\n        parser.add_argument(\"--db_database\", type=str, default=self.cfg.db_database)\n\n        # user config\n        parser.add_argument(\"--user_secret_key\", type=str, default=self.cfg.user_secret_key)\n        parser.add_argument(\"--user_algorithm\", type=str, default=self.cfg.user_algorithm)\n        parser.add_argument(\"--force_think\", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.user_force_think)\n        parser.add_argument(\"--use_cuda_graph\", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.use_cuda_graph)\n\n        # web config\n        parser.add_argument(\"--web_cross_domain\", type=bool, default=self.cfg.web_cross_domain)\n\n        # file config\n        
parser.add_argument(\"--file_upload_dir\", type=str, default=self.cfg.file_upload_dir)\n        parser.add_argument(\"--assistant_store_dir\", type=str, default=self.cfg.assistant_store_dir)\n        # local chat\n        parser.add_argument(\"--prompt_file\", type=str, default=self.cfg.prompt_file)\n\n\n        # async server\n        parser.add_argument(\"--sched_strategy\", type=str, default=self.cfg.sched_strategy)\n        # parser.add_argument(\"--sched_port\", type=int, default=self.cfg.sched_port)\n        # parser.add_argument(\"--sched_metrics_port\", type=int, default=self.cfg.sched_metrics_port)\n        # parser.add_argument(\"--kvc2_metrics_port\", type=int, default=self.cfg.kvc2_metrics_port)\n        parser.add_argument(\"--page_size\", type=str, default=self.cfg.page_size)\n        parser.add_argument(\"--memory_gpu_only\", type=str, default=self.cfg.memory_gpu_only)\n        parser.add_argument(\"--utilization_percentage\", type=str, default=self.cfg.utilization_percentage)\n        parser.add_argument(\"--cpu_memory_size_GB\", type=str, default=self.cfg.cpu_memory_size_GB)\n\n\n        args = parser.parse_args()\n        if (args.model_dir is not None or args.model_path is not None):\n            if (args.model_path is not None):\n                # if pass model_dir and model_path, we use model_path\n                args.model_dir = args.model_path\n            else:\n                # if only pass model_dir, we use model_dir\n                args.model_path = args.model_dir\n        else:\n            args.model_dir = self.cfg.model_dir\n            args.model_path = self.cfg.model_path\n        \n        # we add the name not match args individually\n        self.cfg.model_device = args.device\n        self.cfg.mount_web = args.web\n        self.cfg.server_ip = args.host\n        self.cfg.server_port = args.port\n        self.cfg.user_force_think = args.force_think\n\n\n        args.architectures = args.model_name\n\n        try:\n            
model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)\n        except Exception:\n            # AutoConfig could not load this model; fall back to the bundled config where one exists.\n            # Catching Exception (not a bare except) lets KeyboardInterrupt and SystemExit propagate.\n            if args.model_name == \"Qwen3NextForCausalLM\":\n                model_config = Qwen3NextConfig.from_pretrained(args.model_dir)\n            else:\n                raise ValueError(f\"Model {args.model_name} not supported. Please check your model directory or model name.\")\n\n\n        if model_config.architectures[0] in (\"Qwen3MoeForCausalLM\", \"Qwen2MoeForCausalLM\", \"SmallThinkerForCausalLM\", \"Glm4MoeForCausalLM\"):\n            args.gpu_memory_size = args.cache_lens*2*2*model_config.num_hidden_layers*model_config.num_key_value_heads*model_config.head_dim\n            args.architectures = model_config.architectures[0]\n        else:\n            args.gpu_memory_size = args.cache_lens*2*576*61\n        # set config from args\n        for key, value in vars(args).items():\n            if value is not None and hasattr(self.cfg, key):\n                setattr(self.cfg, key, value)\n        self.cfg.gpu_memory_size = args.gpu_memory_size\n        free_ports = get_free_ports(3, [args.port])\n        args.sched_port = free_ports[0]\n        args.sched_metrics_port = free_ports[1]\n        args.kvc2_metrics_port = free_ports[2]\n        self.cfg.sched_port = free_ports[0]\n        self.cfg.sched_metrics_port = free_ports[1]\n        self.cfg.kvc2_metrics_port = free_ports[2]\n        return args\n"
  },
  {
    "path": "archive/ktransformers/server/backend/__init__.py",
    "content": ""
  },
  {
    "path": "archive/ktransformers/server/backend/args.py",
    "content": "from pydantic import BaseModel, Field\nfrom typing import Optional\nfrom ktransformers.server.config.config import Config\n\n\nclass ConfigArgs(BaseModel):\n    model_name: Optional[str] = Field(..., description=\"Model name\")\n    model_dir: Optional[str] = Field(..., description=\"Path to model directory\")\n    optimize_config_path: Optional[str] = Field(None, description=\"Path of your optimize config yml file\")\n    gguf_path: Optional[str] = Field(None, description=\"Path of your gguf file\")\n    draft_model_path: Optional[str] = Field(None, description=\"Path of your gguf file\")\n    draft_gguf_path: Optional[str] = Field(None, description=\"Path of your gguf file\")\n    tp: int = Field(None, description=\"tp size\")\n\n    class Config:\n        protected_namespaces = ()\n\n    max_batch_size: int = Field(\n        None, description=\"Max number of batches to run at once, assuming the sequences will fit within total_context\"\n    )\n    chunk_size: int = Field(\n        None,\n        description=(\n            \"Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new\"\n            \" job is started, but at the expense of overall prompt ingestion speed\"\n        ),\n    )\n    max_new_tokens: int = Field(None, description=\"Max new tokens per completion. For this example applies to all jobs\")\n    json_mode: bool = Field(\n        None, description=\"Use LMFE to constrain the output to JSON format. 
See schema and details below\"\n    )\n    healing: bool = Field(None, description=\"Demonstrate token healing\")\n    ban_strings: Optional[list] = Field(None, description=\"Ban some phrases maybe\")\n    gpu_split: Optional[str] = Field(None, description='\"auto\", or VRAM allocation per GPU in GB')\n    length: Optional[int] = Field(None, description=\"Maximum sequence length\")\n    rope_scale: Optional[float] = Field(None, description=\"RoPE scaling factor\")\n    rope_alpha: Optional[float] = Field(None, description=\"RoPE alpha value (NTK)\")\n    no_flash_attn: bool = Field(None, description=\"Disable Flash Attention\")\n    low_mem: bool = Field(None, description=\"Enable VRAM optimizations, potentially trading off speed\")\n    experts_per_token: Optional[int] = Field(\n        None, description=\"Override MoE model's default number of experts per token\"\n    )\n    load_q4: bool = Field(None, description=\"Load weights in Q4 mode\")\n    fast_safetensors: bool = Field(None, description=\"Optimized safetensors loading with direct I/O (experimental!)\")\n    draft_model_dir: Optional[str] = Field(None, description=\"Path to draft model directory\")\n    no_draft_scale: bool = Field(\n        None,\n        description=\"If draft model has smaller context size than model, don't apply alpha (NTK) scaling to extend it\",\n    )\n    modes: bool = Field(None, description=\"List available modes and exit.\")\n    mode: str = Field(None, description=\"Chat mode. 
Use llama for Llama 1/2 chat finetunes.\")\n    username: str = Field(None, description=\"Username when using raw chat mode\")\n    botname: str = Field(None, description=\"Bot name when using raw chat mode\")\n    system_prompt: Optional[str] = Field(None, description=\"Use custom system prompt\")\n    temperature: float = Field(None, description=\"Sampler temperature, default = 0.95 (1 to disable)\")\n    smoothing_factor: float = Field(None, description=\"Smoothing Factor, default = 0.0 (0 to disable)\")\n    dynamic_temperature: Optional[str] = Field(\n        None, description=\"Dynamic temperature min,max,exponent, e.g. -dyntemp 0.2,1.5,1\"\n    )\n    top_k: int = Field(None, description=\"Sampler top-K, default = 50 (0 to disable)\")\n    top_p: float = Field(None, description=\"Sampler top-P, default = 0.8 (0 to disable)\")\n    top_a: float = Field(None, description=\"Sampler top-A, default = 0.0 (0 to disable)\")\n    skew: float = Field(None, description=\"Skew sampling, default = 0.0 (0 to disable)\")\n    typical: float = Field(None, description=\"Sampler typical threshold, default = 0.0 (0 to disable)\")\n    repetition_penalty: float = Field(None, description=\"Sampler repetition penalty, default = 1.01 (1 to disable)\")\n    frequency_penalty: float = Field(None, description=\"Sampler frequency penalty, default = 0.0 (0 to disable)\")\n    presence_penalty: float = Field(None, description=\"Sampler presence penalty, default = 0.0 (0 to disable)\")\n    response_chunk: int = Field(None, description=\"Space to reserve in context for reply, default = 250\")\n    no_code_formatting: bool = Field(None, description=\"Disable code formatting/syntax highlighting\")\n    cache_8bit: bool = Field(None, description=\"Use 8-bit (FP8) cache\")\n    cache_q4: bool = Field(None, description=\"Use Q4 cache\")\n    ngram_decoding: bool = Field(None, description=\"Use n-gram speculative decoding\")\n    print_timings: bool = Field(None, description=\"Output timings 
after each prompt\")\n    amnesia: bool = Field(None, description=\"Forget context after every response\")\n\n    # for transformers\n    batch_size: int = Field(None, description=\"Batch Size\")\n    cache_lens: int = Field(None, description=\"Cache lens for transformers static cache\")\n    device: str = Field(None, description=\"device\")\n\n\ncfg = Config()\ndefault_args = cfg\n"
  },
  {
    "path": "archive/ktransformers/server/backend/base.py",
    "content": "from asyncio import Queue\nfrom enum import Enum\nimport sys, os\nfrom typing import AsyncIterator, Dict, List, Optional, Tuple\n\nimport torch\n\nfrom ktransformers.server.config.log import logger\nfrom ktransformers.server.crud.assistants.assistants import AssistantDatabaseManager\nfrom ktransformers.server.crud.assistants.messages import MessageDatabaseManager\nfrom ktransformers.server.crud.assistants.runs import RunsDatabaseManager\nfrom ktransformers.server.crud.assistants.threads import ThreadsDatabaseManager\nfrom ktransformers.server.exceptions import request_error\nfrom ktransformers.server.schemas.assistants.assistants import AssistantObject\nfrom ktransformers.server.schemas.assistants.messages import MessageCreate, MessageObject, Role\nfrom ktransformers.server.schemas.assistants.runs import RunObject\nfrom ktransformers.server.schemas.assistants.threads import ThreadObject\nfrom ktransformers.server.schemas.endpoints.chat import RawUsage\nfrom ktransformers.server.schemas.base import ObjectID, Order\nfrom ktransformers.server.utils.multi_timer import Profiler\n\n\nfrom .args import ConfigArgs,default_args\n\n\n\nclass BackendInterfaceBase:\n    '''\n    Interface to inference frameworks. e.g. 
transformers, exllama.\n    Implement __init__ and inference\n    '''\n\n    args: ConfigArgs\n    profiler:Profiler = Profiler()\n\n    def __init__(self, args:ConfigArgs = default_args):\n        raise NotImplementedError\n\n    \n    async def inference(self,local_messages,request_unique_id:Optional[str])->AsyncIterator[str]:\n        '''\n        inference can be called directly, or by ThreadContext\n\n        local_messages: \n            When called by ThreadContext, local_messages are generated by ThreadContext.get_local_messages().\n            Implementations must handle whichever form of local_messages they receive\n        request_unique_id:\n            unique id of different requests, useful when using cache\n        \n        return:\n            async str output for stream update\n\n        '''\n        raise NotImplementedError\n\n\n    def report_last_time_performance(self):\n        try:\n            tokenize_time = self.profiler.get_timer_sec('tokenize')\n            prefill_time = self.profiler.get_timer_sec('prefill')\n            decode_time = self.profiler.get_timer_sec('decode')\n            prefill_count = self.profiler.get_counter('prefill')\n            decode_count = self.profiler.get_counter('decode')\n\n            logger.info(f'Performance(T/s): prefill {prefill_count/prefill_time}, decode {decode_count/decode_time}. 
Time(s): tokenize {tokenize_time}, prefill {prefill_time}, decode {decode_time}')\n        except Exception:\n            logger.info('Performance statistics not recorded')\n\n\nclass ThreadContext:\n    '''\n    A thread context holding the assistant logic\n    \n    '''\n\n    args: ConfigArgs\n    # Assistant Logic\n    assistant: Optional[AssistantObject] = None\n    related_threads : List[ThreadObject]\n    thread: ThreadObject\n    messages: List[MessageObject] = [] \n    run: RunObject\n\n    interface: Optional[BackendInterfaceBase] = None\n     \n    queue: Optional[Queue] = None\n    timer: Profiler = Profiler()\n\n    def __init__(self, run: RunObject,interface:BackendInterfaceBase, args: ConfigArgs = default_args) -> None:\n        self.args = args\n        self.thread_manager = ThreadsDatabaseManager()\n        self.message_manager = MessageDatabaseManager()\n        self.runs_manager = RunsDatabaseManager()\n        self.assistant_manager = AssistantDatabaseManager()\n        self.thread = self.thread_manager.db_get_thread_by_id(run.thread_id)\n        self.assistant = self.assistant_manager.db_get_assistant_by_id(run.assistant_id)\n        self.messages = self.message_manager.db_list_messages_of_thread(run.thread_id,order=Order.ASC)\n        logger.debug(f\"{len(self.messages)} messages loaded from database\")\n        self.interface = interface\n        self.update_by_run(run,args)\n\n    def get_local_messages(self):\n        '''\n        Get local messages, as the input to interface.inference\n        This function is intended for message preprocessing, e.g. 
apply chat template\n        '''\n        raise NotImplementedError\n\n    def update_by_run(self,run:RunObject,args:ConfigArgs = default_args):\n        self.run = run \n        self.args = args\n       \n    def put_user_message(self, message: MessageObject):\n        assert (\n            message.role.is_user() and message.thread_id == self.thread.id and message.status == MessageObject.Status.in_progress\n        )\n        self.messages.append(message)\n\n    def delete_user_message(self,message_id: ObjectID):\n        self.messages = [m for m in self.messages if m.id != message_id]\n\n    async def work(self)->AsyncIterator:\n        logger.debug('start working')\n        user_message = self.messages[-1]\n        if not user_message.role.is_user():\n            raise request_error('user must talk before LLM can talk')\n        user_message.status = MessageObject.Status.completed\n        user_message.sync_db()\n\n        local_messages = self.get_local_messages() # must get this before we insert reply_message\n\n\n        response_str_count = 0  \n        reply_message = self.message_manager.create_message_object(\n                            self.thread.id,\n                            self.run.id,\n                            MessageCreate(role=Role.assistant, content=\"\"),    \n                        )\n        reply_message.assistant_id = self.assistant.id\n        self.messages.append(reply_message) \n\n        yield reply_message.stream_response_with_event(MessageObject.Status.created)\n        yield reply_message.stream_response_with_event(MessageObject.Status.in_progress)\n        yield self.run.stream_response_with_event(RunObject.Status.in_progress)\n\n        async for res in self.interface.inference(local_messages,self.thread.id): \n            if isinstance(res, RawUsage):\n                raw_usage = res\n            else: \n                token, finish_reason = res    \n                if self.run.status == RunObject.Status.cancelling:\n                    logger.warning(f'Run {self.run.id} cancelling')\n                    break\n                yield reply_message.append_message_delta(token)\n                response_str_count+=1\n        \n        if self.run.status == RunObject.Status.cancelling:\n            yield self.run.stream_response_with_event(RunObject.Status.cancelled)\n            yield reply_message.stream_response_with_event(MessageObject.Status.incomplete)\n        elif self.run.status == RunObject.Status.in_progress:\n            yield self.run.stream_response_with_event(RunObject.Status.completed)\n            yield reply_message.stream_response_with_event(MessageObject.Status.completed)\n        else:\n            raise NotImplementedError(f'{self.run.status} should not appear here')\n\n        reply_message.sync_db()\n        self.run.sync_db()"
  },
  {
    "path": "archive/ktransformers/server/backend/context_manager.py",
    "content": "from asyncio import Lock\nfrom typing import Dict, Optional\n\nfrom ktransformers.server.backend.base import ThreadContext, BackendInterfaceBase\nfrom ktransformers.server.schemas.assistants.runs import RunObject\nfrom ktransformers.server.schemas.base import ObjectID\nfrom ktransformers.server.config.log import logger\nfrom ktransformers.server.backend.interfaces.transformers import TransformersThreadContext\nfrom ktransformers.server.backend.interfaces.ktransformers import KTransformersThreadContext\nfrom ktransformers.server.backend.interfaces.exllamav2 import ExllamaThreadContext\n\n\nfrom ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface\nfrom ktransformers.server.backend.interfaces.transformers import TransformersInterface\nfrom ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface\n\nclass ThreadContextManager:\n    lock: Lock\n    threads_context: Dict[ObjectID, ThreadContext]\n    interface: BackendInterfaceBase\n    \n    def __init__(self,interface) -> None:\n        logger.debug(f\"Creating Context Manager\")\n        self.lock = Lock()\n        self.threads_context = {}\n        self.interface = interface\n        pass\n\n    async def get_context_by_run_object(self, run: RunObject) -> ThreadContext:\n        async with self.lock:\n            logger.debug(f\"keys {self.threads_context.keys()}\")\n            if run.thread_id not in self.threads_context:\n                logger.debug(f\"new inference context {run.thread_id}\")\n                if isinstance(self.interface, ExllamaInterface):\n                    new_context = ExllamaThreadContext(run, self.interface)\n                elif isinstance(self.interface, KTransformersInterface):\n                    new_context = KTransformersThreadContext(run, self.interface)\n                elif isinstance(self.interface, TransformersInterface):\n                    new_context = TransformersThreadContext(run, self.interface)\n          
      else:\n                    from ktransformers.server.backend.interfaces.balance_serve import BalanceServeThreadContext\n                    from ktransformers.server.backend.interfaces.balance_serve import BalanceServeInterface\n                    if isinstance(self.interface, BalanceServeInterface):\n                        new_context = BalanceServeThreadContext(run, self.interface)\n                    else:\n                        raise NotImplementedError\n                # elif isinstance(self.interface, BalanceServeInterface):\n                #     new_context = BalanceServeThreadContext(run, self.interface)\n                # else:\n                #     raise NotImplementedError\n                self.threads_context[run.thread_id] = new_context\n                # self.threads_context[run.thread_id] = ExllamaInferenceContext(run)\n            re = self.threads_context[run.thread_id]\n            re.update_by_run(run)\n            return re\n\n    async def get_context_by_thread_id(self, thread_id: ObjectID) -> Optional[ThreadContext]:\n        async with self.lock:\n            if thread_id in self.threads_context:\n                logger.debug(f'found context for thread {thread_id}')\n                return self.threads_context[thread_id]\n            else:\n                logger.debug(f'no context for thread {thread_id}')\n                return None\n            "
  },
  {
    "path": "archive/ktransformers/server/backend/interfaces/__init__.py",
    "content": ""
  },
  {
    "path": "archive/ktransformers/server/backend/interfaces/balance_serve.py",
    "content": "from typing import Any, AsyncIterator, List, Optional, Set\nfrom ktransformers.models.custom_cache import KVC2StaticCache, KDeepSeekV3Cache, KGQACache, KVC2Qwen3Cache\nfrom transformers import (\n    AutoTokenizer,\n    AutoConfig,\n    GenerationConfig,\n    StaticCache,\n    AutoModelForCausalLM,\n    BitsAndBytesConfig,\n)\n\nimport torch.distributed as dist\nfrom ktransformers.server.config.config import Config\nfrom ..base import ThreadContext, BackendInterfaceBase\nimport torch\nfrom ktransformers.server.backend.interfaces.transformers import (\n    ConfigArgs,\n    default_args,\n    TextStreamer,\n)\nfrom ktransformers.server.schemas.base import ObjectID\nfrom ktransformers.server.config.log import logger\nfrom ktransformers.optimize.optimize import optimize_and_load_gguf\nfrom ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM\nfrom ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM\nfrom ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM\nfrom ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM\nfrom ktransformers.models.custom_modeling_smallthinker import KSmallThinkerForCausalLM\nfrom ktransformers.models.custom_modeling_glm4_moe import KGlm4MoeForCausalLM\nfrom ktransformers.models.custom_modeling_qwen3_next import KQwen3NextForCausalLM\nfrom ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig\nfrom ktransformers.models.configuration_smallthinker import SmallthinkerConfig\nfrom ktransformers.models.configuration_glm4_moe import Glm4MoeConfig\nfrom ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM\ntry:\n    import torch_npu\n    use_torch_npu = torch.npu.is_available()\nexcept:\n    use_torch_npu = False\nif use_torch_npu:\n    from ktransformers.models.ascend.custom_ascend_modeling_deepseek_v3 import KNPUDeepseekV3ForCausalLM\n    from ktransformers.models.ascend.custom_ascend_modeling_qwen3 import 
KNPUQwen3MoeForCausalLM\n    from ktransformers.util.ascend.ascend_utils import get_absort_weight, setup_model_parallel, get_tensor_parallel_group, get_tensor_parallel_size\n\nfrom ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM\nfrom ktransformers.models.modeling_llama import LlamaForCausalLM\nfrom ktransformers.models.modeling_mixtral import MixtralForCausalLM\nfrom ktransformers.util import utils\ncustom_models = {\n    \"DeepseekV2ForCausalLM\": DeepseekV2ForCausalLM,\n    \"Qwen2MoeForCausalLM\": Qwen2MoeForCausalLM,\n    \"LlamaForCausalLM\": LlamaForCausalLM,\n    \"MixtralForCausalLM\": MixtralForCausalLM,\n}\nfrom ktransformers.server.balance_serve.inference.model_runner import ModelRunner, get_or_create_model_runner #TODO is get_or_create_model_runner NPU-only?\nfrom ktransformers.models.configuration_qwen3_next import Qwen3NextConfig\nfrom ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions\nfrom ktransformers.server.balance_serve.inference.query_manager import QueryManager\nfrom ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput\nfrom ktransformers.server.balance_serve.sched_rpc import SchedulerClient\nfrom ktransformers.server.balance_serve.settings import sched_ext\n\nfrom torch.multiprocessing import Queue\nimport torch.multiprocessing as mp\nfrom multiprocessing.synchronize import Event\nimport datetime\nfrom ktransformers.server.schemas.endpoints.chat import RawUsage\nfrom ktransformers.server.utils.multi_timer import Profiler\nimport zmq\nimport time\nimport queue\nimport tempfile\nimport asyncio\nimport cProfile\nimport threading\nfrom contextlib import asynccontextmanager\nfrom fastapi import FastAPI, Request\nimport os\nimport pickle\nimport subprocess\nimport atexit\nimport signal\n\n\nktransformer_rules_dir = (\n    os.path.join(os.path.dirname(os.path.abspath(__file__)), \"..\", \"..\", \"..\", 
\"./optimize/optimize_rules/\") \n)\n\ndefault_optimize_rules = {\n    # \"DeepseekV3ForCausalLM\": ktransformer_rules_dir + \"Moonlight-16B-A3B-serve.yaml\",\n    \"DeepseekV3ForCausalLM\": ktransformer_rules_dir + \"DeepSeek-V3-Chat-serve.yaml\",\n    \"Qwen2MoeForCausalLM\": ktransformer_rules_dir + \"Qwen2-serve.yaml\",\n    \"Qwen3MoeForCausalLM\": ktransformer_rules_dir + \"Qwen3Moe-serve.yaml\",\n    \"SmallThinkerForCausalLM\": ktransformer_rules_dir + \"Smallthinker-serve.yaml\",\n    \"Glm4MoeForCausalLM\": ktransformer_rules_dir + \"Glm4Moe-serve.yaml\",\n    \"Qwen3NextForCausalLM\": ktransformer_rules_dir + \"Qwen3Next-serve.yaml\",\n}\nif use_torch_npu:\n    default_optimize_rules[\"Qwen2MoeForCausalLM\"] = ktransformer_rules_dir + \"Qwen2-57B-A14B-Instruct-serve.yaml\"\n\nasync def chat_stream(queue: asyncio.Queue, tokenizer: AutoTokenizer):\n    streamer = TextStreamer(tokenizer)\n    while True:\n        token = await queue.get()\n        #print(f\"Got token: {token}\")\n        if token is None:\n            # str = f'{token}\\n\\n'\n            # str = model.tokenizer.decode(token)\n            s = streamer.end()\n            if s is not None:\n                yield s\n            break\n        else:\n            # text output\n            text = tokenizer.decode(token)\n            print(text, end=\"\", flush=True)\n\n        # str = model.tokenizer.decode(token)\n        yield streamer.put(token)\n\ndef fill_generated_tokens(query_updates: list[sched_ext.QueryUpdate], generated_tokens: torch.Tensor, query_manager: QueryManager = None):\n    #print(len(query_updates), generated_tokens.size(0), generated_tokens)\n    for i in range(generated_tokens.size(0)):\n        # print(generated_tokens[i].item())\n        query_updates[i].generated_token = generated_tokens[i].item()\n        if not query_manager.query_map[query_updates[i].id].is_prefill:\n            pos = query_updates[i].active_position\n            if pos < 
query_manager.query_map[query_updates[i].id].max_length:\n                query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i]\n\ndef report_last_time_performance(profiler: Profiler):\n        try:\n            tokenize_time = profiler.get_timer_sec('tokenize')\n            prefill_time = profiler.get_timer_sec('prefill')\n            decode_time = profiler.get_timer_sec('decode')\n            prefill_count = profiler.get_counter('prefill')\n            decode_count = profiler.get_counter('decode')\n\n            logger.info(f'Performance(T/s): prefill {prefill_count/prefill_time}, decode {decode_count/decode_time}. Time(s): tokenize {tokenize_time}, prefill {prefill_time}, decode {decode_time}')\n        except Exception:\n            logger.info('Performance statistics not recorded')\n\nclass Engine:\n    sched_client : SchedulerClient\n    updates : list[sched_ext.QueryUpdate]\n    batch : sched_ext.BatchQueryTodo\n    model_runner: ModelRunner\n    sampler: Sampler\n    query_manager: QueryManager\n    cache: KDeepSeekV3Cache | KGQACache | KVC2StaticCache\n    def __init__(self, args: ConfigArgs = default_args, generated_token_queue:Queue = None, broadcast_endpoint: str = None, kvcache_event: Event = None):\n        self.args = args\n\n        # Child and parent processes cannot share the config variable\n        for key, value in vars(args).items():\n            if value is not None and hasattr(Config(), key):\n                setattr(Config(), key, value)\n        if use_torch_npu:\n            utils.CUR_DEVICE = f\"npu:{torch.npu.current_device()}\"\n            self.device = f\"npu:{torch.npu.current_device()}\"\n        else:\n            self.device = self.args.device\n        self.sched_client = SchedulerClient(args.sched_port)\n        self.updates = []\n\n        print(f\"args.architectures: {args.architectures}\")\n\n        if args.architectures == \"Qwen3MoeForCausalLM\": \n            config = Qwen3MoeConfig.from_pretrained(args.model_dir, 
trust_remote_code=True)\n        elif args.architectures == \"Glm4MoeForCausalLM\":\n            config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)\n        elif args.architectures == \"SmallThinkerForCausalLM\":\n            config = SmallthinkerConfig.from_pretrained(args.model_dir, trust_remote_code=True)\n            config._attn_implementation = \"eager\"  \n            config.moe_intermediate_size = config.moe_ffn_hidden_size\n        elif args.architectures == \"Qwen3NextForCausalLM\":\n            config = Qwen3NextConfig.from_pretrained(args.model_dir, trust_remote_code=True)\n        else:\n            try:\n                config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True) \n            except:\n                raise ValueError(f\"Model {args.architectures} not supported. Please check your model directory or model name.\")\n\n        self.gen_queue = generated_token_queue\n        self.debug = False\n\n        self.profiler_cprofile = cProfile.Profile()\n        self.cprof_prof_cnt, self.max_cprof_prof_cnt = 0, 8\n        with torch.device(\"meta\"):\n            if config.architectures[0] == \"DeepseekV3ForCausalLM\":\n                if use_torch_npu:\n                    self.cache = KVC2StaticCache(config, args.max_batch_size, self.args.page_size)\n                    self.model = KNPUDeepseekV3ForCausalLM(config)\n                else:\n                    self.cache = KDeepSeekV3Cache(config, self.args.page_size)\n                    self.model = KDeepseekV3ForCausalLM(config, self.cache)\n            elif config.architectures[0] == \"DeepseekV2ForCausalLM\":\n                self.cache = KDeepSeekV3Cache(config, self.args.page_size)\n                self.model = KDeepseekV2ForCausalLM(config, self.cache)\n            elif config.architectures[0] == \"Qwen2MoeForCausalLM\" or config.architectures[0] == \"Qwen3MoeForCausalLM\":\n                if not use_torch_npu:\n                    self.cache 
= KGQACache(config, self.args.page_size)\n                    if config.architectures[0] == \"Qwen2MoeForCausalLM\":\n                        self.model = KQwen2MoeForCausalLM(config, self.cache)\n                    else:\n                        self.model = KQwen3MoeForCausalLM(config, self.cache)\n                else:\n                    self.cache = KVC2Qwen3Cache(config, args.max_batch_size, self.args.page_size)\n                    self.model = KNPUQwen3MoeForCausalLM(config, self.cache)\n            elif config.architectures[0] == \"SmallThinkerForCausalLM\":\n                self.cache = KGQACache(config, self.args.page_size)\n                self.model = KSmallThinkerForCausalLM(config, self.cache)\n            elif config.architectures[0] == \"Glm4MoeForCausalLM\":\n                self.cache = KGQACache(config, self.args.page_size)\n                self.model = KGlm4MoeForCausalLM(config, self.cache)\n            elif config.architectures[0] == \"Qwen3NextForCausalLM\":\n                self.cache = KGQACache(config, self.args.page_size)\n                self.model = KQwen3NextForCausalLM(config, self.cache)\n\n        context = zmq.Context()\n        if use_torch_npu:\n            if torch.distributed.get_rank() == 0:\n                self.pub_socket = context.socket(zmq.PUB)\n                self.pub_socket.bind(f\"ipc://{broadcast_endpoint}\")\n                self.sub_socket = None\n            else:\n                self.sub_socket = context.socket(zmq.SUB)\n                self.sub_socket.connect(f\"ipc://{broadcast_endpoint}\")\n                self.sub_socket.setsockopt_string(zmq.SUBSCRIBE, \"\")\n                self.pub_socket = None\n            # time.sleep(1) # make sure all subscribers are ready\n        else:\n            self.pub_socket = context.socket(zmq.PUB)\n            self.pub_socket.bind(f\"ipc://{broadcast_endpoint}\")\n\n        try:\n            generation_config = GenerationConfig.from_pretrained(args.model_dir)\n        
except Exception:\n            generation_config = GenerationConfig(\n                max_length=args.max_new_tokens,\n                temperature=args.temperature,\n                top_p=args.top_p,\n                do_sample=True\n            )\n            \n        if args.optimize_config_path is None:\n            optimize_config_path = default_optimize_rules[config.architectures[0]]\n               \n        else:\n            optimize_config_path = args.optimize_config_path\n        gguf_path = args.gguf_path\n        if gguf_path is None:\n            gguf_path = input(\n                \"Please input the path of your gguf file (all gguf files in the directory containing it must\"\n                \" belong to the current model):\"\n            )\n        if use_torch_npu:\n            tp_group = get_tensor_parallel_group()\n            torch.distributed.barrier(group=tp_group)\n        optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)        \n        if use_torch_npu:\n            get_absort_weight(self.model, config) #TODO \n            torch.distributed.barrier(group=tp_group)\n        self.model.generation_config = generation_config\n        if self.model.generation_config.pad_token_id is None:\n            self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id\n\n        self.model.eval()\n        kvcache_event.set()\n        # load kvcache\n        print(f\"Getting inference context from sched_client.\")\n        inference_context = self.sched_client.get_inference_context_raw()\n        print(f\"Got inference context, sending it to subscribers.\")\n        inference_context = self.sched_client.rebuild_inferece_context(inference_context)\n        self.cache.load(inference_context)\n        print(f\"kv_cache loaded successfully.\")\n        \n\n        self.block_num = inference_context.k_cache[0].size(1)\n        #TODO ModelRunner differences\n        # self.model_runner = ModelRunner(self.model, self.device, 
self.args.use_cuda_graph, page_size = args.page_size, block_num=self.block_num)\n        #@TODO add config\n        if config.architectures[0] == \"Qwen2MoeForCausalLM\" or config.architectures[0] == \"Qwen3MoeForCausalLM\" or config.architectures[0] == \"Glm4MoeForCausalLM\" or config.architectures[0] == \"SmallThinkerForCausalLM\" or config.architectures[0] == \"Qwen3NextForCausalLM\":\n            if not use_torch_npu:\n                self.model.init_wrapper(self.args.use_cuda_graph, self.device, max(self.model_runner.cuda_graphs), args.max_batch_size, self.block_num) \n            else:\n                # npu donnot support flash attn\n                self.model.init_wrapper()\n        else:\n            self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)\n\n\n        # self.args.use_cuda_graph代表是否使用图下沉\n        self.model_runner = get_or_create_model_runner(self.model, self.cache, self.device, self.args.use_cuda_graph, page_size = args.page_size)\n        self.sampler = Sampler()\n        self.query_manager = QueryManager(device = self.device, page_size = args.page_size)\n\n            \n    def sampling(self, forward_output: ForwardBatchOutput):\n        generated_tokens = []\n        probs = []\n\n        for i in range(forward_output.num_batchs):\n            logit = forward_output.logits[i]\n            if hasattr(forward_output, \"temperatures\"):\n                temperatures = forward_output.temperatures[i]\n            else:\n                temperatures = None\n            \n            if hasattr(forward_output, \"top_ps\"):\n                top_ps = forward_output.top_ps[i]\n            else:\n                top_ps = None\n\n            sample_options = SamplingOptions(logit.size(0), self.device, pretrained_config=self.model.generation_config, temperatures=temperatures, top_ps=top_ps)\n            generated_token, prob=self.sampler(logit, sample_options)\n            
generated_tokens.append(generated_token.clone())\n            probs.append(prob.clone())\n        generated_tokens, probs = torch.cat(generated_tokens), torch.cat(probs, dim=0)\n        return generated_tokens, probs\n    \n    def loop(self):\n\n        next_batch = None   \n\n        while True:\n            self.batch = next_batch\n            if self.batch is not None:\n                if use_torch_npu:\n                    batch_size = 0\n                    for i in range(len(self.batch.decode_mini_batches)):\n                        batch_size += len(self.batch.decode_mini_batches[i])\n                    # logger.debug(f\"prefill batch: {len(self.batch.prefill_mini_batches)} decode batch: {len(self.batch.decode_mini_batches)} {batch_size} \\n\")\n                    self.model_runner.run_split(self.batch, self.query_manager)\n                else:\n                    self.model_runner.run(self.batch, self.query_manager)\n\n            if len(self.updates) > 0:\n                for q in self.updates:\n                    if q.is_prefill == True:\n                        continue\n                    # print(f\"Putting token {q.generated_token} into queue for query id: {q.id}\")\n                    try:\n                        if use_torch_npu:\n                            if torch.distributed.get_rank() == 0:\n                                self.gen_queue.put((q.id, q.generated_token if q.decode_done == False else None), timeout=5)\n                        else:\n                            self.gen_queue.put((q.id, q.generated_token if q.decode_done == False else None), timeout=5)\n                    except queue.Full:\n                        pass#print(\"Queue is full after timeout; unable to put more items.\")\n            if use_torch_npu:\n                if torch.distributed.get_rank() == 0:\n                    next_batch = self.sched_client.update_last_batch(self.updates)\n                    if next_batch.query_ids == []:\n                     
   next_batch = None\n                    self.pub_socket.send_pyobj(next_batch)\n                else:\n                    next_batch = self.sub_socket.recv_pyobj()\n            else:\n                next_batch = self.sched_client.update_last_batch(self.updates)\n                if next_batch.query_ids == []:\n                    next_batch = None\n                self.pub_socket.send_pyobj(next_batch)\n\n            if next_batch is not None:\n                self.query_manager.add_query(next_batch)\n            \n            \n            if self.batch is not None:\n                self.model_runner.sync()\n                # print(f\"Model execution time (GPU): {self.model_runner.model_time:.3f} ms\")\n                # if self.rank == 0:\n                \n                generated_tokens, probs = self.sampling( self.model_runner.output)\n                \n                self.updates = self.query_manager.update(self.batch)\n                fill_generated_tokens(self.updates, generated_tokens, self.query_manager)\n\n            else:\n                self.updates = []\n\nclass BalanceServeThreadContext(ThreadContext):\n    def get_local_messages(self):\n        local_messages = []\n        for m in self.messages:\n            local_messages.append({\"role\": m.role.value, \"content\": m.get_text_content()})\n\n        return local_messages\n\n\ndef init_distributed(rank: int,\n                     world_size: int,\n                     tp_size: int,\n                     master_addr: str = os.getenv(\"MASTER_ADDR\", \"127.0.0.1\"),\n                     master_port: int = os.getenv(\"MASTER_PORT\", \"29500\"),\n                     backend: str = \"hccl\"): #TODO csx: 是否distribute 都只与NPU有关\n    os.environ[\"RANK\"] = str(rank)\n    os.environ[\"LOCAL_RANK\"] = str(rank)\n    os.environ[\"WORLD_SIZE\"] = str(world_size)\n    os.environ[\"MASTER_ADDR\"] = master_addr\n    os.environ[\"MASTER_PORT\"] = str(master_port)\n\n    local_rank, world_size = 
setup_model_parallel(tp=tp_size)\n    return local_rank, world_size\n\n\ndef run_engine(args, token_queue, broadcast_endpoint, event, kvcache_event, rank=None, world_size=None):\n    if use_torch_npu:\n        init_distributed(rank, world_size, args.tp, backend=\"hccl\") # TODO: same question as the TODO on init_distributed\n    import torch.distributed as dist\n    engine = Engine(args, token_queue, broadcast_endpoint, kvcache_event)\n    if args.use_cuda_graph:\n        if 'npu' in engine.device:\n            print(\"[WARMUP-NPU] start\", flush=True)\n            engine.model_runner.warmup_npu()\n        else:\n            engine.model_runner.warmup()\n    else:\n        print(\"[WARMUP-NPU] skip warmup, eager mode!\", flush=True)\n    if use_torch_npu:\n        args.port += torch.distributed.get_rank()\n    event.set()\n    engine.loop()\n\n\nclass BalanceServeInterface(BackendInterfaceBase):\n    use_static_cache: bool = True\n\n    model: Any\n    tokenizer: AutoTokenizer\n\n    cache: StaticCache\n    generated_ids: torch.Tensor\n    seq_length: int\n\n    streamer: TextStreamer\n\n    # thread_related\n    last_request_id: Optional[str] = None\n    ever_generated_ids: Set[int] = set()\n\n    def __init__(self, args: ConfigArgs = default_args, input_args=None):\n        self.args = input_args\n        self.queue_map:dict[int,asyncio.Queue] = {}\n        self.thread_map: dict[int, int] = {}\n        processes = []\n        self.broadcast_endpoint = tempfile.NamedTemporaryFile(delete=False).name # @TODO add to config\n        ctx = mp.get_context(\"spawn\")\n        self.token_queue = ctx.Queue(maxsize=1000) \n        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)\n        self.sched_client = SchedulerClient(args.sched_port)\n        self.streamer = TextStreamer(self.tokenizer)\n        if use_torch_npu:\n            world_size = str(os.getenv(\"WORLD_SIZE\", self.args.tp))\n            if not isinstance(world_size, str):\n                raise 
ValueError(f\"world_size ({world_size}) must be str\")\n            start_events = []\n            kvcache_events = []\n            for rank in range(self.args.tp):\n                if int(self.args.device[-1]) > 0:\n                    break\n\n                start_event = ctx.Event()\n                kvcache_event = ctx.Event()\n\n                p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event,\n                                                        kvcache_event, rank, world_size))\n                p.start()\n                processes.append(p)\n                start_events.append(start_event)\n                kvcache_events.append(kvcache_event)\n\n            for evt in kvcache_events:\n                evt.wait()\n            self._engines = processes\n        else:\n            start_event = ctx.Event()\n            kvcache_event = ctx.Event()\n\n            p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event,\n                                                    kvcache_event))\n            p.start()\n            processes.append(p)\n\n            kvcache_event.wait()\n        with tempfile.NamedTemporaryFile(delete=False) as temp_file:\n            args.tp = input_args.tp\n            pickle.dump(args, temp_file)\n            temp_file_path = temp_file.name\n        current_file = __file__\n        target_file = os.path.join(os.path.dirname(current_file), \"..\", \"..\", \"balance_serve\", \"sched_rpc.py\")\n        target_file = os.path.normpath(target_file)\n        log_path = os.path.join(args.log_dir, \"rpc.log\")\n        log = open(log_path, \"a\") \n        sched_process = subprocess.Popen(\n            [\"python3\", target_file, \"--config\", temp_file_path], \n            stdout=log, \n            stderr=log\n        )\n        print(\"sched_rpc started with PID:\", sched_process.pid)\n\n        def signal_handler(signum, frame):\n     
       print(f\"Received signal {signum}, shutting down...\")\n            cleanup()\n            os._exit(0) \n\n        def cleanup():\n            print(\"Cleaning up...\")\n\n            for p in processes:\n                if p.is_alive():\n                    print(f\"Terminating subprocess {p.pid}\")\n                    p.terminate()\n                    p.join()\n\n            if sched_process and sched_process.poll() is None:\n                print(f\"Terminating sched_process {sched_process.pid}\")\n                sched_process.terminate()\n                sched_process.wait()\n        signal.signal(signal.SIGINT, signal_handler)   \n        signal.signal(signal.SIGTERM, signal_handler)\n        if use_torch_npu:\n            for evt in start_events:\n                evt.wait()\n        else:\n            start_event.wait()\n    \n    def get_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None, \n                   max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None) -> tuple[float, float, float]:\n        \"\"\"Return (temperature, top_p, max_completion_tokens), applying defaults and clamping edge cases\"\"\"\n        if max_tokens is not None:\n            max_completion_tokens = max_tokens\n        if max_completion_tokens is None:\n            max_completion_tokens = self.args.max_new_tokens\n        else:\n            max_completion_tokens = min(self.args.max_new_tokens, max_completion_tokens)\n        if temperature is None:\n            temperature = self.args.temperature\n        if top_p is None:\n            top_p = self.args.top_p\n            \n        if temperature == 0:\n            temperature = 0.0001\n        if top_p == 0:\n            top_p = 0.0001\n            \n        return temperature, top_p, max_completion_tokens\n\n    def run_queue_proxy(self):\n        loop = asyncio.new_event_loop()\n        asyncio.set_event_loop(loop)\n        loop.run_until_complete(self.queue_proxy())\n\n    @asynccontextmanager\n 
   async def lifespan(self, app: FastAPI):\n        asyncio.create_task(self.queue_proxy())\n        yield\n\n    async def queue_proxy(self):\n        print(\"Queue Proxy Started\")\n        while True:\n            try:\n                query_id, token = self.token_queue.get_nowait()\n                try:\n                    # query id might not be allocated yet\n                    self.queue_map[query_id].put_nowait(token)\n                    #print(f\"Proxy Put token: {token} to queue for query id: {query_id}\")\n                except asyncio.QueueFull:\n                    #print(f\"Queue for query id: {query_id} is full, waiting to put: {token}\")\n                    await self.queue_map[query_id].put(token)\n\n            except queue.Empty:\n                # print(\"no new token\")\n                # await asyncio.sleep(1)\n                await asyncio.sleep(0)\n    def tokenize_prompt(self, prompt: str):\n        input_ids = self.tokenizer.encode(prompt, return_tensors=\"pt\").to(self.args.device)\n        return input_ids\n\n    def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List):\n        for m in messages:\n            if m[\"role\"] == \"system\":\n                logger.warning(f'change {m[\"role\"]} to user')\n                m[\"role\"] = \"user\"\n\n        new_messages = [messages[0]]\n        for m in messages[1:]:\n            if m[\"role\"] == \"user\" and new_messages[-1][\"role\"] == \"user\":\n                logger.warning(\"merge two adjacent user messages\")\n                new_messages[-1][\"content\"] += '\\n' + m[\"content\"]\n            else:\n                new_messages.append(m)\n        # input_str: str = self.tokenizer.apply_chat_template(new_messages,tokenize=False,add_generation_prompt=True)\n        # # drop <think> token in chat template\n        # if input_str.endswith('<think>\\n'):\n        #     input_str = input_str[:-len('<think>\\n')]\n        input_ids = 
self.tokenizer.apply_chat_template(new_messages, add_generation_prompt=True, return_tensors=\"pt\").to(self.args.device)\n        return input_ids\n    \n    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = 0, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):\n        profiler = Profiler()\n        profiler.create_and_start_timer(\"tokenize\")\n        \n        if isinstance(local_messages, List):\n            input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)\n        elif isinstance(local_messages, str):\n            #local_messages = local_messages[0]['content']\n            input_ids = self.tokenize_prompt(local_messages)\n        else:\n            raise ValueError(\"local_messages should be List or str\")\n        if Config().user_force_think:\n            token_thinks = torch.tensor([self.tokenizer.encode(\"<think>\\n\",add_special_tokens=False)],device=input_ids.device)\n            if not torch.equal(input_ids[0, -token_thinks.shape[-1]:], token_thinks[-1]): # TODO: this line is newly added; check whether it affects the GPU path\n                input_ids = torch.cat(\n                    [input_ids, token_thinks], dim=1\n                )\n        logger.debug(f\"get input ids of shape {input_ids.shape}\")\n\n\n        profiler.pause_timer(\"tokenize\")\n\n        profiler.create_and_start_timer(\"prefill\")\n\n        \n        \n        query_add = sched_ext.QueryAdd()\n        query_add.query_token =  input_ids[0].tolist()\n        query_length = input_ids[0].shape[0]\n        query_add.query_length = query_length\n        profiler.set_counter(\"prefill\", query_length)\n        #@TODO add server\n        stop_criteria =  [self.tokenizer.encode(self.tokenizer.eos_token, add_special_tokens=False),self.tokenizer.encode(\"<|im_end|>\")]\n        query_add.stop_criteria = stop_criteria\n\n        temperature, top_p, max_new_tokens = self.get_params(temperature, top_p, 
max_tokens, max_completion_tokens)\n\n        query_add.sample_options.temperature = temperature\n        if top_p == 0 or top_p is None:\n            top_p = 0.0001\n        query_add.sample_options.top_p = top_p\n        query_add.estimated_length = min(self.args.cache_lens, query_length+max_new_tokens)\n        query_id = self.sched_client.add_query(query_add)\n        queue = asyncio.Queue(maxsize=max_new_tokens)\n        self.queue_map[query_id] = queue\n        self.thread_map[thread_id] = query_id\n        is_first_token = True\n        async for token in chat_stream(self.queue_map[query_id], self.tokenizer):\n            if is_first_token:\n                is_first_token=False\n                profiler.pause_timer(\"prefill\")\n                profiler.create_and_start_timer(\"decode\")\n                profiler.set_counter(\"decode\", 0)\n                if Config().user_force_think:\n                    think = '<think>\\n'\n                    print(think, end=\"\",flush=True)\n                    yield think, None\n            else:\n                profiler.inc(\"decode\")\n            # TODO: pass in the rank to avoid duplicate printing\n            yield token, None\n        profiler.pause_timer(\"decode\")\n        report_last_time_performance(profiler)\n        yield self.streamer.end(), None\n        if profiler.get_counter('decode') >= max_new_tokens - 1:\n            yield \"\", \"length\"\n        else:\n            yield \"\", \"stop\"\n        \n        \n        yield RawUsage(\n                tokenize_time = profiler.get_timer_sec('tokenize'),\n                prefill_time = profiler.get_timer_sec('prefill'),\n                decode_time = profiler.get_timer_sec('decode'),\n                prefill_count = profiler.get_counter('prefill'),\n                decode_count = profiler.get_counter('decode'),\n            )\n"
  },
  {
    "path": "archive/ktransformers/server/backend/interfaces/exllamav2.py",
    "content": "import sys, os\nfrom typing import AsyncIterator, Dict, Tuple\n\nimport torch\n\nfrom ..args import ConfigArgs, default_args\n\nfrom ..base import BackendInterfaceBase, ThreadContext\nfrom ktransformers.server.schemas.assistants.runs import RunObject\n\n\nfrom ..args import *\n\nclass ExllamaThreadContext(ThreadContext):\n    def __init__(self, run: RunObject, args: ConfigArgs = default_args) -> None:\n        super().__init__(run,args)\n        \n    def get_interface(self):\n        return \n\n    def get_local_messages(self):\n        raise NotImplementedError\n\n\n\n\nclass ExllamaInterface(BackendInterfaceBase):\n    \n    def __init__(self, args: ConfigArgs = ...):\n        raise NotImplementedError\n    \n    def tokenize_prompt(self, prompt: str) -> torch.Tensor:\n        raise NotImplementedError\n    \n    async def inference(self,local_messages,request_unique_id:Optional[str])->AsyncIterator:\n        raise NotImplementedError\n    \n\n\n\n"
  },
  {
    "path": "archive/ktransformers/server/backend/interfaces/ktransformers.py",
    "content": "import torch\nimport torch.distributed as dist\nfrom torch import nn\nfrom torch.nn.attention import SDPBackend\nimport asyncio\nfrom transformers import AutoTokenizer, AutoConfig, GenerationConfig\nfrom ktransformers.server.backend.interfaces.transformers import (\n    TransformersInterface,\n    ConfigArgs,\n    TransformersThreadContext,\n    default_args,\n    TextStreamer,\n)\nimport os\ntry:\n    import torch_npu\n    use_npu = torch.npu.is_available()\n    from ktransformers.util.ascend.ascend_utils import get_absort_weight, setup_model_parallel\nexcept:\n    use_npu = False\nfrom torch import nn\nfrom ktransformers.server.config.log import logger\nfrom ktransformers.optimize.optimize import optimize_and_load_gguf\nfrom ktransformers.models.custom_cache import StaticCache\nfrom ktransformers.util.cuda_graph_runner import CUDAGraphRunner\nfrom ktransformers.local_chat import custom_models, default_optimize_rules\nfrom ktransformers.util.utils import get_device, get_all_used_cuda_device\nfrom ktransformers.util import utils\nfrom typing import Optional\nfrom ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton\nfrom ktransformers.server.schemas.endpoints.chat import RawUsage\nfrom typing import Any, List, Optional, Set\nfrom ktransformers.server.config.config import Config\n\nwarm_uped = False\nspeculative_decoding = True # True -> verify by random accept ; False-> verify by token id\nglobal_acc_counts = 0\nglobal_verify_counts = 0\n\nktransformer_rules_dir = (\n    os.path.dirname(os.path.abspath(__file__)) + \"/../../../optimize/optimize_rules/\"\n)\ndefault_optimize_rules = {\n    \"DeepseekV2ForCausalLM\": ktransformer_rules_dir + \"DeepSeek-V2-Chat.yaml\",\n    \"DeepseekV3ForCausalLM\": ktransformer_rules_dir + \"DeepSeek-V3-Chat.yaml\",\n    \"Qwen2MoeForCausalLM\": ktransformer_rules_dir + \"Qwen2-57B-A14B-Instruct.yaml\",\n    \"LlamaForCausalLM\": ktransformer_rules_dir + 
\"Internlm2_5-7b-Chat-1m.yaml\",\n    \"MixtralForCausalLM\": ktransformer_rules_dir + \"Mixtral.yaml\"\n}\nif use_npu:\n    default_optimize_rules[\"DeepseekV3ForCausalLM\"] = ktransformer_rules_dir + \"DeepSeek-V3-Chat-npu.yaml\"\nclass KTransformersThreadContext(TransformersThreadContext):\n    pass\n\n\nclass KTransformersInterface(TransformersInterface):\n    def __init__(self, args: ConfigArgs = default_args, input_args=None):\n        self.args = input_args\n        self.local_rank, self.world_size = setup_model_parallel(tp=self.args.tp)\n        if use_npu and (utils.CUR_DEVICE is None):\n            utils.CUR_DEVICE = f\"npu:{torch.npu.current_device()}\"\n            self.args.device = utils.CUR_DEVICE\n            self.args.device = f\"npu:{torch.npu.current_device()}\"\n        torch.set_grad_enabled(False)\n        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)\n        config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)\n        try:\n            generation_config = GenerationConfig.from_pretrained(args.model_dir)\n        except:\n            generation_config = GenerationConfig(\n                max_length=args.max_new_tokens,\n                temperature=args.temperature,\n                top_p=args.top_p,\n                do_sample=True\n            )\n        \n        torch.set_default_dtype(config.torch_dtype)\n        if config.architectures[0] == \"Qwen2MoeForCausalLM\":\n            config._attn_implementation = \"flash_attention_2\"\n        config.backend_type = \"ktransformers\"\n        config.chunk_size = self.args.chunk_size\n        with torch.device(\"meta\"):\n            self.model = custom_models[config.architectures[0]](config)\n        if input_args.optimize_config_path is not None:\n            optimize_config_path = input_args.optimize_config_path\n        elif default_args.optimize_config_path is None:\n     
       optimize_config_path = default_optimize_rules[config.architectures[0]]\n        else:\n            optimize_config_path = args.optimize_config_path\n\n        # print(optimize_config)\n\n        gguf_path = args.gguf_path\n        if gguf_path is None:\n            gguf_path = input(\n                \"please input the path of your gguf file(gguf file in the dir containing input gguf file must all\"\n                \" belong to current model):\"\n            )\n        optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config, q4_gguf_path=input_args.q4_gguf_path)\n        # absorb the weights ahead of time\n        get_absort_weight(self.model, config)\n        # utils.get_absort_weight(self.model, config)\n        self.model.eval()\n        self.model.generation_config = generation_config\n        self.device_map = self.model.gguf_loader.tensor_device_map\n        self.top_p = torch.tensor([[self.model.generation_config.top_p]], dtype = torch.float16, device = self.args.device)\n        self.top_k = torch.tensor([[self.model.generation_config.top_k]], dtype = torch.int32, device = self.args.device)\n        self.temperature = torch.tensor([[self.model.generation_config.temperature]], dtype = torch.float16, device = self.args.device)\n        self.next_token_fake = torch.tensor([[1]], dtype=torch.int32, device = self.args.device)\n        self.next_token_probs = torch.tensor([[1.0]], dtype=torch.float16, device = self.args.device)\n        self.draft_model = None\n\n        # logger.info(f\"{args.model_name} loaded from {args.model_dir} to {self.device_map}\")\n        self.cache = StaticCache(\n            config=self.model.config,\n            max_batch_size=args.batch_size,\n            max_cache_len=args.cache_lens,\n            device=self.device_map,\n            dtype=self.model.dtype,\n        )\n        # logger.info(f\"StaticCache (length={args.cache_lens}), batch size:{args.batch_size}\")\n\n        if self.model.generation_config.pad_token_id is 
None:\n            self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id\n        self.streamer = TextStreamer(self.tokenizer)\n\n        self._infer_lock = asyncio.Lock()\n\n    @torch.no_grad\n    def decode_one_tokens(self):\n        global warm_uped\n\n        device_map = self.model.gguf_loader.tensor_device_map\n        torch_device = get_device(\"blk.0.self_attn\", device_map)\n        torch_device = \"cuda:0\" if torch_device == \"cuda\" else torch_device\n        torch.cuda.set_device(torch_device)\n        if warm_uped and self.args.use_cuda_graph:\n            if use_npu:\n                from ktransformers.util.npu_graph_runner import get_or_create_runner, check_runner\n                if check_runner(utils.get_current_device()):\n                    npu_graph_runner = get_or_create_runner(utils.get_current_device())\n                    npu_graph_runner.init(self.args.batch_size, self.seq_length)\n                    self.cuda_graph_runner = npu_graph_runner\n                    utils._USE_NPU_GRAPH = True\n                    self.cuda_graph_runner.capture(\n                        self.model,\n                        self.current_ids,\n                        self.active_cache_position.unsqueeze(0),\n                        self.active_cache_position,\n                        self.cache,\n                        main_device=torch_device,\n                        return_dict=False,\n                        use_cache=True,\n                    )\n                if hasattr(self, \"cuda_graph_runner\"):\n                    inputs_embeds = self.model.model.embed_tokens(self.current_ids.to(\"cpu\")).to(utils.get_current_device())\n                    logits = self.cuda_graph_runner(\n                        inputs_embeds, self.active_cache_position.unsqueeze(0), self.active_cache_position\n                    )[0]\n                    self.cache.change_seq_length(1)\n                    torch.cuda.synchronize()\n           
         logits = logits[0, -1, :]\n                    return self.logits_to_token(logits)\n            else:\n                if not hasattr(self, \"cuda_graph_runner\"):\n                    self.cuda_graph_runner = CUDAGraphRunner()\n                    self.cuda_graph_runner.capture(\n                        self.model,\n                        self.current_ids,\n                        self.active_cache_position.unsqueeze(0),\n                        self.active_cache_position,\n                        self.cache,\n                        main_device=torch_device,\n                        return_dict=False,\n                        use_cache=True,\n                    )\n                if hasattr(self, \"cuda_graph_runner\"):\n                    logits = self.cuda_graph_runner(\n                        self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position\n                    )\n                    self.cache.change_seq_length(1)\n                    torch.cuda.synchronize()\n                    logits = logits[0, -1, :]\n                    return self.logits_to_token(logits)\n        \n        if self.args.use_cuda_graph:\n            warm_uped = True\n            \n        if self.use_static_cache:\n            logits = self.model(\n                self.current_ids.to(torch_device),\n                cache_position=self.active_cache_position,\n                past_key_values=self.cache,\n                return_dict=False,\n                use_cache=True,\n                is_prefill=False,\n            )[0]\n        else:\n            logits = self.model(self.current_ids, return_dict=False, is_prefill=False)[0]\n        logits = logits[0, -1, :]\n\n        return self.logits_to_token(logits)\n\n\n    @torch.no_grad\n    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = 
None):\n        input_ids_length = input_ids.shape[-1]\n        if max_tokens is not None:\n            max_completion_tokens = max_tokens\n        if max_completion_tokens is None:\n            max_new_tokens = self.args.max_new_tokens\n        else:\n            max_new_tokens = min(self.args.max_new_tokens, max_completion_tokens)\n        if(input_ids_length >= self.args.cache_lens):\n            logger.warning(f\"input_ids_length {input_ids_length} > cache_lens {self.args.cache_lens}\")\n            self.seq_length = input_ids_length\n            return\n        logger.debug(f\"input_ids: {input_ids.shape}\")\n        device = self.device_map.get(\"blk.0.self_attn\", {}).get(\"generate_device\", \"cuda:0\")\n        device = \"cuda:0\" if device == \"cuda\" else device\n        if is_new:\n            self.ever_generated_ids.clear()\n            same_prefix = 0\n            # flat_input_ids = input_ids.flatten()\n\n            if getattr(self, 'generated_ids', None) is None:\n                self.generated_ids = torch.zeros(\n                    self.args.batch_size,\n                    input_ids.shape[-1] + max_new_tokens + 1,\n                    dtype=torch.int,\n                    device=self.args.device,\n                )\n                self.seq_length = 1\n            \n\n            logger.debug(f\"same prefix len: {same_prefix}\")\n            self.cache.remove_suffix(same_prefix)\n            self.seq_length = same_prefix\n            self.cache.position[0] = same_prefix\n            self.generated_ids = self.generated_ids[..., :same_prefix]\n            input_ids = input_ids[..., same_prefix:]\n            input_ids_length = input_ids.shape[-1]\n\n        self.ever_generated_ids.clear()\n        self.profiler.set_counter(\"prefill\", input_ids_length)\n        logger.debug(f\"input_ids: {input_ids.shape}\")\n        logger.debug(f\"generate_ids: {self.generated_ids.shape}\")\n        \n        former_seq_length = self.seq_length\n        
self.seq_length += input_ids_length\n        expected_length = min(self.seq_length + max_new_tokens + 1, self.args.cache_lens)\n        delta_length = expected_length - self.generated_ids.shape[-1]\n        if delta_length > 0:\n            new_generate_ids = torch.zeros(\n                self.args.batch_size, delta_length, dtype=torch.int, device=utils.get_current_device()\n            )\n            self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)\n        else:\n            logger.warning(f\"seq_length bigger than cache_lens, killed\")\n            exit(0)\n        \n        logger.debug(f\"cache position: {former_seq_length} to {self.seq_length}\")\n        cache_position = torch.arange(former_seq_length, self.seq_length, device=device)\n        self.generated_ids[:, cache_position] = input_ids.to(utils.get_current_device()).to(torch.int)\n\n        if not (type(self) is TransformersInterface):\n            input_ids = input_ids.to(\"cpu\")\n        \n        def chunk_prefill(input_ids, cache_position):\n            inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)\n            torch.cuda.set_device(device)\n            if flashinfer_enabled:\n                MLAWrapperSingleton.need_plan_all()\n            if self.use_static_cache:\n                logits = self.model(\n                    inputs_embeds=inputs_embeds,\n                    cache_position=cache_position,\n                    past_key_values=self.cache,\n                    return_dict=False,\n                    use_cache=True,\n                    is_prefill=True,\n                )[0]\n            else:\n                logits = self.model(inputs_embeds=inputs_embeds, return_dict=False, is_prefill=True)[0]\n\n            return logits\n\n        if not use_npu:\n            chunk_start = 0\n            while chunk_start < input_ids_length:\n                chunk_end = min(chunk_start + self.args.chunk_size, input_ids_length)\n                if 
self.cache != None:\n                    self.cache.cur_idx=cache_position[chunk_start:chunk_end]\n                logits = chunk_prefill(input_ids[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end])\n                chunk_start += self.args.chunk_size\n                \n            if flashinfer_enabled:\n                MLAWrapperSingleton.reset_buffer()\n            self.prepare_logits_wrapper(input_ids, device, temperature, top_p)\n            next_token = self.logits_to_token(logits[0, -1, :])\n            self.max_new_tokens = min(max_new_tokens, self.args.cache_lens - self.seq_length) - 1 \n            yield self.append_new_tokens(next_token)\n            return\n\n        def prefill_wrapper(prof=None):\n            chunk_start = 0\n            while chunk_start < input_ids_length:\n                chunk_end = min(chunk_start + self.args.chunk_size, input_ids_length)\n                if self.cache != None:\n                    self.cache.cur_idx = cache_position[chunk_start:chunk_end]\n                logits = chunk_prefill(input_ids[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end])\n                chunk_start += self.args.chunk_size\n                if prof is not None:\n                    prof.step()\n            if prof is not None:\n                prof.stop()\n            if logits is None:\n                raise ValueError('logits cannot be None')\n            return logits\n\n        global WARM_UP_SKIP_CNT\n        prof_prefill = os.environ[\"PROF_PREFILL\"] if \"PROF_PREFILL\" in os.environ else \"0\"\n        if prof_prefill == \"1\":\n            experimental_config = torch_npu.profiler._ExperimentalConfig(\n                aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,\n                profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False\n            )\n            with torch_npu.profiler.profile(\n                    activities=[\n                        
torch_npu.profiler.ProfilerActivity.CPU,\n                        torch_npu.profiler.ProfilerActivity.NPU\n                    ],\n                    schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=8, repeat=1, skip_first=0),\n                    on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(\"./prefill_prof_lm_head\"),\n                    record_shapes=True,\n                    profile_memory=True,\n                    with_stack=False,\n                    with_flops=False,\n                    with_modules=False,\n                    experimental_config=experimental_config) as prof:\n                logits = prefill_wrapper(prof)\n        else:\n            logits = prefill_wrapper()\n            \n        if flashinfer_enabled:\n            MLAWrapperSingleton.reset_buffer()\n        self.prepare_logits_wrapper(input_ids, device, temperature, top_p)\n        next_token = self.logits_to_token(logits[0, -1, :])\n        self.cache.position[0] = self.seq_length\n        yield self.append_new_tokens(next_token)\n\n    @property\n    def active_cache_position(self):\n        device = self.device_map.get(\"blk.0.self_attn\", {}).get(\"generate_device\", \"cuda:0\")\n        return torch.tensor([self.seq_length - 1], device=device)\n    \n    def sampling(self, logits, do_sample):\n        if do_sample:\n            cur_len = logits.shape[1]\n            logits = logits / self.temperature\n            torch.manual_seed(0)\n            probs = logits.view(-1, cur_len, self.model.config.vocab_size)\n            probs = torch.softmax(probs, dim=-1).half()\n            next_token = self.next_token_fake\n            if self.draft_model is None or not speculative_decoding:\n                torch_npu._npu_topk_topp_sampling(probs[:, 0, :], self.top_k, self.top_p, next_token, self.next_token_probs)\n            for i in range(1,cur_len):\n                ith_token = torch.empty_like(self.next_token_fake)\n                
torch_npu._npu_topk_topp_sampling(probs[:, i, :], self.top_k, self.top_p, ith_token, self.next_token_probs)\n                next_token = torch.cat((next_token, ith_token), dim=-1)\n        else:\n            next_token = torch.argmax(logits, dim=-1)\n            probs = torch.softmax(logits, dim=-1)\n\n        return next_token, probs\n\n    def verify_by_tokenid(self, main_token: int, draft_token: int):\n        return main_token, main_token == draft_token\n\n    def verify_speculative_decoding(self, main_prob: torch.Tensor, draft_prob: torch.Tensor, draft_token: int, p: float):\n        #assert draft_prob[draft_token] == p\n        q = main_prob[draft_token]\n        #p = draft_prob[draft_token]\n        accept_prob = min(1.0, (q / p).item())\n        if torch.rand(()) <= accept_prob:\n            return draft_token, True\n        else:\n            # Compute the adjusted distribution for resampling\n            new_prob = main_prob - draft_prob\n            new_prob = torch.clamp(new_prob, min=0.0)\n            new_prob /= new_prob.sum()\n\n            # Sample a new token from the adjusted distribution\n            token = torch.multinomial(new_prob, 1).item()\n            return token, False\n\n    def logits_to_token(self, logits: torch.Tensor):\n        if self.model.generation_config.do_sample:\n            logits = self.logits_warper(self.inputs.view(1, -1), logits.view(1, -1))\n            probs = torch.nn.functional.softmax(logits, dim=-1)\n            last = torch.multinomial(probs, num_samples=1)\n        else:\n            logits = self.logits_warper(self.inputs.view(1, -1), logits.view(1, -1))\n            probs = torch.nn.functional.softmax(logits, dim=-1)\n            _, last = torch.topk(probs, k=1, dim=-1)\n        last = last.item()\n        self.ever_generated_ids.add(last)\n        return last\n\n    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: 
Optional[float] = None, max_completion_tokens: Optional[float] = None):\n        async with self._infer_lock:\n            async for v in super().inference(local_messages, thread_id, temperature, top_p, max_tokens, max_completion_tokens):\n                yield v\n            \n            # return this inference raw usage\n            yield RawUsage(\n                tokenize_time = self.profiler.get_timer_sec('tokenize'),\n                prefill_time = self.profiler.get_timer_sec('prefill'),\n                decode_time = self.profiler.get_timer_sec('decode'),\n                prefill_count = self.profiler.get_counter('prefill'),\n                decode_count = self.profiler.get_counter('decode'),\n            )\n\n    def sync_inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None) -> str:\n        loop = asyncio.new_event_loop()\n        asyncio.set_event_loop(loop)\n\n        try:\n            async def run_async():\n                result = []\n                async for chunk in self.inference(local_messages, thread_id, temperature, top_p):\n                    pass\n                return \"\"\n            return loop.run_until_complete(run_async())\n        finally:\n            loop.close()\n"
  },
  {
    "path": "archive/ktransformers/server/backend/interfaces/transformers.py",
    "content": "from typing import Any, List, Optional, Set\nimport re\nimport json\nimport uuid\ntry:\n    # torch must be bound in this module before torch.npu is referenced\n    import torch\n    import torch_npu\n    use_npu = torch.npu.is_available()\nexcept Exception:\n    use_npu = False\n\nfrom transformers import (\n    LlamaTokenizer,\n    AutoTokenizer,\n    AutoConfig,\n    LlamaForCausalLM,\n    GenerationConfig,\n    StaticCache,\n    AutoModelForCausalLM,\n    BitsAndBytesConfig,\n    LogitsProcessorList,\n    TemperatureLogitsWarper,\n    TopKLogitsWarper,\n    TopPLogitsWarper,\n    MinPLogitsWarper,\n    TypicalLogitsWarper,\n    EpsilonLogitsWarper,\n    EtaLogitsWarper,\n)\n\nfrom ktransformers.server.config.config import Config\nfrom ktransformers.server.schemas.base import ObjectID\nfrom ktransformers.server.utils.multi_timer import Profiler\nfrom torch.nn.attention import SDPBackend\nimport torch\nimport torch.distributed as dist\n\nfrom ktransformers.util import utils\nimport sys, os\nfrom ..base import ThreadContext, BackendInterfaceBase\nfrom ktransformers.server.config.log import logger\nfrom ..args import ConfigArgs, default_args\nfrom ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton\nfrom ktransformers.util import utils\n\n\n# This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py\nclass TextStreamer:\n\n    def __init__(self, tokenizer: \"AutoTokenizer\", skip_prompt: bool = False, **decode_kwargs):\n        self.tokenizer = tokenizer\n        self.skip_prompt = skip_prompt\n        self.decode_kwargs = decode_kwargs\n\n        # variables used in the streaming process\n        self.token_cache = []\n        self.print_len = 0\n        self.next_tokens_are_prompt = True\n\n    def reset(self):\n        self.token_cache = []\n        self.print_len = 0\n\n    def put(self, value) -> Optional[str]:\n        \"\"\"\n        Receives tokens, decodes them, and prints them to stdout as soon as they form entire 
words.\n        \"\"\"\n        if not isinstance(value, int):\n            raise ValueError(\"TextStreamer only supports batch size 1, and int type input\")\n\n        if self.skip_prompt and self.next_tokens_are_prompt:\n            self.next_tokens_are_prompt = False\n            return None\n\n        # Add the new token to the cache and decodes the entire thing.\n        self.token_cache.append(value)\n        text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True, **self.decode_kwargs)\n\n        # After the symbol for a new line, we flush the cache.\n        if text.endswith(\"\\n\"):\n            printable_text = text[self.print_len :]\n            self.reset()\n        # If the last token is a CJK character, we print the characters.\n        elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):\n            printable_text = text[self.print_len :]\n            self.print_len += len(printable_text)\n        # Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,\n        # which may change with the subsequent token -- there are probably smarter ways to do this!)\n        else:\n            printable_text = text[self.print_len : text.rfind(\" \") + 1]\n            self.print_len += len(printable_text)\n        return printable_text\n\n    def end(self) -> Optional[str]:\n        \"\"\"Flushes any remaining cache and prints a newline to stdout.\"\"\"\n        # Flush the cache, if it exists\n        if len(self.token_cache) > 0:\n            text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True, **self.decode_kwargs)\n            printable_text = text[self.print_len :]\n            self.reset()\n        else:\n            printable_text = \"\"\n\n        self.next_tokens_are_prompt = True\n        return printable_text\n\n    def _is_chinese_char(self, cp):\n        \"\"\"Checks whether CP is the codepoint of a CJK character.\"\"\"\n        # This defines a \"chinese 
character\" as anything in the CJK Unicode block:\n        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)\n        #\n        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,\n        # despite its name. The modern Korean Hangul alphabet is a different block,\n        # as is Japanese Hiragana and Katakana. Those alphabets are used to write\n        # space-separated words, so they are not treated specially and handled\n        # like the all of the other languages.\n        if (\n            (cp >= 0x4E00 and cp <= 0x9FFF)\n            or (cp >= 0x3400 and cp <= 0x4DBF)  #\n            or (cp >= 0x20000 and cp <= 0x2A6DF)  #\n            or (cp >= 0x2A700 and cp <= 0x2B73F)  #\n            or (cp >= 0x2B740 and cp <= 0x2B81F)  #\n            or (cp >= 0x2B820 and cp <= 0x2CEAF)  #\n            or (cp >= 0xF900 and cp <= 0xFAFF)\n            or (cp >= 0x2F800 and cp <= 0x2FA1F)  #\n        ):  #\n            return True\n\n        return False\n\n\nclass TransformersThreadContext(ThreadContext):\n    def get_local_messages(self):\n        local_messages = []\n        for m in self.messages:\n            local_messages.append({\"role\": m.role.value, \"content\": m.get_text_content()})\n\n        return local_messages\n\n\nclass TransformersInterface(BackendInterfaceBase):\n    use_static_cache: bool = True\n\n    model: Any\n    tokenizer: AutoTokenizer\n\n    cache: StaticCache\n    generated_ids: torch.Tensor\n    seq_length: int\n\n    streamer: TextStreamer\n\n    # thread_related\n    last_request_id: Optional[str] = None\n    ever_generated_ids: Set[int] = set()\n\n    def __init__(self, args: ConfigArgs = default_args):\n        self.args = args\n\n        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir)\n        self.model = AutoModelForCausalLM.from_pretrained(args.model_dir, device_map=args.device, use_safetensors=True)\n        # logger.info(f\"{args.model_name} loaded from 
{args.model_dir} to {args.device}\")\n\n        self.cache = StaticCache(\n            config=self.model.config,\n            max_batch_size=args.batch_size,\n            max_cache_len=args.cache_lens,\n            device=args.device,\n            dtype=self.model.dtype,\n        )\n        # logger.info(f\"StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}\")\n\n        self.streamer = TextStreamer(self.tokenizer)\n\n    @property\n    def current_ids(self):\n        return self.generated_ids[:, self.seq_length - 1].unsqueeze(1)\n\n    @property\n    def active_cache_position(self):\n        return torch.tensor([self.seq_length - 1], device=self.args.device)\n\n    def tokenize_prompt(self, prompt: str):\n        input_ids = self.tokenizer.encode(prompt, return_tensors=\"pt\").to(self.args.device)\n        return input_ids\n\n    def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List):\n        for m in messages:\n            if m[\"role\"] == \"system\":\n                logger.warning(f'change {m[\"role\"]} to user')\n                m[\"role\"] = \"user\"\n\n        new_messages = [messages[0]]\n        for m in messages[1:]:\n            if m[\"role\"] == \"user\" and new_messages[-1][\"role\"] == \"user\":\n                logger.warning(\"merge two adjacent user messages\")\n                new_messages[-1][\"content\"] += '\\n' + m[\"content\"]\n            else:\n                new_messages.append(m)\n        # if (self.last_request_id is not None) and self.last_request_id == thread_id:\n        #     input_ids = self.tokenizer.encode(self.tokenizer.eos_token+self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors=\"pt\",tokenize=False, add_generation_prompt=True), add_special_tokens = False, return_tensors=\"pt\").to(self.args.device)\n        # else:\n        #     input_ids = self.tokenizer.apply_chat_template(\n        #         new_messages, return_tensors=\"pt\", 
add_generation_prompt=True\n        #     ).to(self.args.device)\n        # input_str: str = self.tokenizer.apply_chat_template(new_messages,tokenize=False,add_generation_prompt=True)\n        # drop <think> token in chat template\n        # if input_str.endswith('<think>\\n'):\n        #     input_str = input_str[:-len('<think>\\n')]\n        # input_ids = self.tokenizer.encode(input_str, return_tensors=\"pt\").to(self.args.device)\n        input_ids = self.tokenizer.apply_chat_template(new_messages, add_generation_prompt=True, return_tensors=\"pt\").to(self.args.device)\n        if (self.last_request_id is not None) and self.last_request_id == thread_id:\n            x = self.generated_ids[:,:self.seq_length]\n            y = input_ids[:,:self.seq_length]\n            # We can only hope that the input_ids are the same\n            unequal_mask = torch.ne(x,y)\n            unequal_positions = torch.nonzero(unequal_mask)\n            num_unequal_elements = unequal_mask.sum().item()\n            logger.warning(f'num_unequal_elements: {num_unequal_elements}') \n\n            input_ids = input_ids[:,self.seq_length:]\n        logger.debug(f\"get input ids of shape {input_ids.shape}\")\n        return input_ids\n\n    def append_new_tokens(self, new_tokens: int) -> Optional[str]:\n        self.generated_ids[0, self.seq_length] = new_tokens\n        self.seq_length += 1\n        self.cache.position[0] += 1\n        return self.streamer.put(new_tokens)\n\n    @staticmethod\n    def tf_logits_warper(generation_config):\n        \"\"\"\n        This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances\n        used for multinomial sampling.\n        \"\"\"\n\n        # instantiate warpers list\n        warpers = LogitsProcessorList()\n\n        # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a\n        # better score (i.e. 
keep len(list(generation_config._eos_token_tensor)) + 1)\n        if generation_config.num_beams > 1:\n            if isinstance(generation_config._eos_token_tensor, list):\n                min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1\n            elif isinstance(generation_config._eos_token_tensor, torch.Tensor):\n                min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1\n            else:\n                min_tokens_to_keep = 2\n        else:\n            min_tokens_to_keep = 1\n\n        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files\n        # all samplers can be found in `generation_utils_samplers.py`\n        if generation_config.temperature is not None and generation_config.temperature != 1.0:\n            warpers.append(TemperatureLogitsWarper(generation_config.temperature))\n        if generation_config.top_k is not None and generation_config.top_k != 0:\n            warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))\n        if generation_config.top_p is not None and generation_config.top_p < 1.0:\n            warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))\n        if generation_config.min_p is not None:\n            # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)\n            warpers.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))\n        if generation_config.typical_p is not None and generation_config.typical_p < 1.0:\n            warpers.append(\n                TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)\n            )\n        if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:\n            warpers.append(\n                
EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep)\n            )\n        if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:\n            warpers.append(\n               EtaLogitsWarper(\n                    epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device\n                )\n            )\n        # `LogitNormalization` should always be the last logit processor, when present\n        if generation_config.renormalize_logits is True:\n            warpers.append(LogitNormalization())\n        return warpers\n\n    def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None):\n        if temperature is None or temperature == 0:\n            temperature = self.model.generation_config.temperature\n        if top_p is None:\n            top_p = self.model.generation_config.top_p\n        if top_p == 0:\n            top_p = 0.0001\n        # keep sampler the same as local_chat\n        generation_config, model_kwargs = self.model._prepare_generation_config(\n            None, max_length=self.args.max_new_tokens,\n            do_sample=True, \n            top_k=self.args.top_k, \n            top_p=top_p, \n            temperature=temperature,\n            repetition_penalty=self.args.repetition_penalty # change this to modify generate config\n        )\n        self.inputs = inputs\n        self.logits_warper = self.tf_logits_warper(generation_config)\n\n    def logits_to_token(self, logits: torch.Tensor):\n        logits = self.logits_warper(self.inputs.view(1, -1), logits.view(1, -1))\n\n        probs = torch.nn.functional.softmax(logits, dim=-1)\n\n        sample = True\n        if sample:\n            last = torch.multinomial(probs, num_samples=1)\n        else:\n            _, last = torch.topk(probs, k=1, dim=-1)\n\n        last = last.item()\n        
self.ever_generated_ids.add(last)\n        return last\n\n    def decode_one_tokens(self):\n        if self.use_static_cache:\n            logits = self.model(\n                self.current_ids,\n                cache_position=self.active_cache_position,\n                past_key_values=self.cache,\n                return_dict=False,\n                use_cache=True,\n            )[0]\n        else:\n            logits = self.model(self.current_ids, return_dict=False)[0]\n        logits = logits[0, -1, :]\n\n        return self.logits_to_token(logits)\n\n    @torch.no_grad\n    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):\n        input_ids_length = input_ids.shape[-1]\n        logger.debug(f\"input_ids: {input_ids.shape}\")\n        if max_tokens is not None:\n            max_completion_tokens = max_tokens\n        if max_completion_tokens is None:\n            max_new_tokens = self.args.max_new_tokens\n        else:\n            max_new_tokens = min(self.args.max_new_tokens, max_completion_tokens)\n\n        if is_new:\n            self.ever_generated_ids.clear()\n            same_prefix = 0\n            flat_input_ids = input_ids.flatten()\n\n            if getattr(self, 'generated_ids', None) is None:\n                self.generated_ids = torch.zeros(\n                    self.args.batch_size,\n                    input_ids.shape[-1] + max_new_tokens + 1,\n                    dtype=torch.int,\n                    device=self.args.device,\n                )\n                self.seq_length = 1            \n            \n            flat_prev_ids = self.generated_ids.flatten()\n            for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):\n                if flat_input_ids[i] == flat_prev_ids[i]:\n                    same_prefix += 1\n                else:\n                    break\n   
         \n            logger.debug(f\"same prefix len: {same_prefix}\")\n            self.cache.remove_suffix(same_prefix)\n            self.seq_length = same_prefix\n            self.generated_ids = self.generated_ids[..., :same_prefix]\n            input_ids = input_ids[..., same_prefix:]\n            input_ids_length = input_ids.shape[-1]\n        \n        self.ever_generated_ids.clear()\n        self.profiler.set_counter(\"prefill\", input_ids_length)\n        logger.debug(f\"input_ids: {input_ids.shape}\")\n\n        logger.debug(f\"generate_ids: {self.generated_ids.shape}\")\n        former_seq_length = self.seq_length\n        self.seq_length += input_ids_length\n        expected_length = self.seq_length + max_new_tokens + 1\n        delta_length = expected_length - self.generated_ids.shape[-1]\n        if delta_length > 0:\n            new_generate_ids = torch.zeros(\n                self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device\n            )\n            self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)\n            \n        logger.debug(f\"cache position: {former_seq_length} to {self.seq_length}\")\n        cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)\n        self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)\n\n        device = input_ids.device\n        if not (type(self) is TransformersInterface):\n            input_ids = input_ids.to(\"cpu\")\n        inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)\n        if self.use_static_cache:\n            logits = self.model(\n                inputs_embeds=inputs_embeds,\n                cache_position=cache_position,\n                past_key_values=self.cache,\n                return_dict=False,\n                use_cache=True,\n            )[0]\n        else:\n            logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]\n\n   
     self.prepare_logits_wrapper(input_ids, device, temperature, top_p)\n        next_token = self.logits_to_token(logits[0, -1, :])\n        yield self.append_new_tokens(next_token)\n\n    @torch.no_grad\n    def generate(self):\n        self.max_new_tokens = min(self.args.max_new_tokens, self.args.cache_lens - self.seq_length) - 1 \n        logger.info(f\"args.max_new_tokens: {self.args.max_new_tokens}, cache_lens: {self.args.cache_lens}, seq_length: {self.seq_length}\")\n        if(self.max_new_tokens <= 0):\n            logger.warning(\"max_new_tokens is less than 0\")\n            yield self.streamer.end(), \"length\"\n            return\n        logger.info(f\"max_new_tokens: {self.max_new_tokens}\")\n        self.profiler.set_counter(\"decode\", 0)\n\n        for i in range(1, self.max_new_tokens):\n            with torch.nn.attention.sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):\n                if flashinfer_enabled:\n                    MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1, None,\n                                             num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank, \n                                             head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,\n                                             sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)\n                next_token = self.decode_one_tokens()\n                self.profiler.inc(\"decode\")\n                if next_token == self.tokenizer.eos_token_id or \"<|im_end|>\" == self.tokenizer.decode(next_token):\n                    yield self.streamer.end(), None\n                    yield \"\", \"stop\"\n                    assert self.args.batch_size == 1\n                    break\n                yield self.append_new_tokens(next_token), 
None\n\n        else:   # for's else, if output get max new tokens\n            yield self.streamer.end(), None\n            yield \"\", \"length\"\n\n        if self.args.use_cuda_graph:\n            utils._USE_NPU_GRAPH = False\n            from ktransformers.util.npu_graph_runner import get_or_create_runner\n            npu_graph_runner = get_or_create_runner(utils.get_current_device())\n            npu_graph_runner.destroy()\n\n    def check_is_new(self, thread_id: str):\n        if not self.use_static_cache:\n            return True\n        if self.last_request_id is None:\n            self.last_request_id = thread_id\n            return True\n        else:\n            if self.last_request_id == thread_id:\n                return False\n            else:\n                self.last_request_id = thread_id\n                return True\n\n    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):\n        self.streamer.reset()\n        self.profiler.create_and_start_timer(\"tokenize\")\n        torch.distributed.barrier()\n        rank = torch.distributed.get_rank()\n        world_size = torch.distributed.get_world_size()\n        tp_size = utils.get_tensor_parallel_size()\n        if isinstance(local_messages, List):\n            input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)\n        elif isinstance(local_messages, str):\n            #local_messages = local_messages[0]['content']\n            input_ids = self.tokenize_prompt(local_messages)\n            #input_ids = torch.tensor([[6366]], device=input_ids.device)\n        else:\n            raise ValueError(\"local_messages should be List or str\")\n\n        if tp_size == world_size and tp_size > 1:\n            torch.distributed.barrier()\n            input_size = torch.tensor([input_ids.size(1)], dtype=torch.int64, 
device=utils.CUR_DEVICE)\n            all_input_sizes = [torch.zeros_like(input_size) for _ in range(world_size)]\n            dist.all_gather(all_input_sizes, input_size)\n\n            max_input_size = max([size.item() for size in all_input_sizes])\n            padded_input_ids = torch.zeros(1, max_input_size, dtype=input_ids.dtype, device=utils.CUR_DEVICE)\n            padded_input_ids[0, :input_ids.size(1)] = input_ids[0]\n\n            all_padded_inputs = [torch.zeros_like(padded_input_ids) for _ in range(world_size)]\n            dist.all_gather(all_padded_inputs, padded_input_ids)\n\n            original_size = all_input_sizes[0].item()\n            input_ids = all_padded_inputs[0][:, :original_size]\n        \n        if Config().user_force_think:\n            token_thinks = torch.tensor([self.tokenizer.encode(\"<think>\\n\",add_special_tokens=False)],device=input_ids.device)\n            if not torch.equal(input_ids[0, -token_thinks.shape[-1]:], token_thinks[-1]):\n                input_ids = torch.cat(\n                    [input_ids, token_thinks], dim=1\n                )\n\n        self.profiler.pause_timer(\"tokenize\")\n\n        self.profiler.create_and_start_timer(\"prefill\")\n\n        if Config().user_force_think:\n            think = '<think>\\n'\n            print(think, end=\"\",flush=True)\n            yield think, None\n        \n        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p, max_tokens, max_completion_tokens):\n            # output think token after prefill done\n            if t is not None:\n                print(t, end=\"\",flush=True)\n                yield t, None\n        self.profiler.pause_timer(\"prefill\")\n\n        self.profiler.create_and_start_timer(\"decode\")\n        for t, finish_reason in self.generate():\n            if t is not None:\n                if tp_size == world_size:\n                    if rank == 0:\n                        print(t, end=\"\", flush=True)\n          
      else:\n                    print(t, end=\"\",flush=True)\n                yield t, finish_reason\n\n        if tp_size == world_size:\n            if rank == 0:\n                print(\"\")\n                self.profiler.pause_timer(\"decode\")\n                self.report_last_time_performance()\n        else:\n            print(\"\")\n            self.profiler.pause_timer(\"decode\")\n            self.report_last_time_performance()\n"
  },
  {
    "path": "archive/ktransformers/server/balance_serve/inference/__init__.py",
    "content": ""
  },
  {
    "path": "archive/ktransformers/server/balance_serve/inference/config.py",
    "content": "'''\nDate: 2024-11-07 07:30:16\nLastEditors: djw\nLastEditTime: 2024-11-15 14:23:26\n'''\nimport math\nfrom dataclasses import dataclass\nfrom typing import Optional\n\nimport torch\nimport torch.nn as nn\nfrom torch import Tensor\nfrom torch.nn import functional as F\nimport yaml\n\nimport json\nfrom typing import Optional\n\nmodel_runner_dict = dict()\n\nclass ModelConfig:\n    vocab_size: int = 32000\n    n_layer: int = 1\n    n_head: int = 32\n    dim: int = 4096\n    intermediate_size: int = 18944\n    n_local_heads: int = 8\n    head_dim: int = 128\n    rope_base: float = 1000000.0\n    norm_eps: float = 1e-06\n    rope_scaling: Optional[dict] = None\n    rms_norm_eps: float = 1e-6\n    hidden_act: str = \"silu\"\n    model_path: str\n    gguf_path: str\n    optimize_rule_path: str\n    speculative_rule_path: str\n            \n\n    # quantize config\n    quant_algorithm: Optional[str] = None\n    quant_group_size: Optional[int] = None\n    quant_num_bits: Optional[int] = None\n\n    json_key_map = {\n        \"vocab_size\": \"vocab_size\",\n        \"n_layer\": \"num_hidden_layers\",\n        \"n_head\": \"num_attention_heads\",\n        \"dim\": \"hidden_size\",\n        \"intermediate_size\": \"intermediate_size\",\n        \"n_local_heads\": \"num_key_value_heads\",\n        \"rope_base\": \"rope_theta\",\n        \"norm_eps\": \"norm_eps\",\n        \"rms_norm_eps\": \"rms_norm_eps\",\n        \"hidden_act\": \"hidden_act\",\n    }\n\n    def __init__(self, config):\n        self.model_path = config[\"model\"][\"model_path\"]\n        self.gguf_path = config[\"model\"][\"gguf_path\"]\n        self.optimize_rule_path = config[\"model\"][\"optimize_rule_path\"]\n        if \"speculative_rule_path\" in config[\"model\"]:\n            self.speculative_rule_path =  config[\"model\"][\"speculative_rule_path\"]\n            self.speculative_gguf_path = config[\"model\"][\"speculative_gguf_path\"]\n            self.speculative_model_path = 
config[\"model\"][\"speculative_model_path\"]\n        self.quant_algorithm = config[\"model\"][\"quant\"][\"algorithm\"]\n        self.quant_group_size = config[\"model\"][\"quant\"][\"group_size\"]\n        self.quant_num_bits = config[\"model\"][\"quant\"][\"num_bits\"]\n        self.load_config()\n        self.n_layer = config[\"model\"][\"n_layers\"]\n\n    def load_config(self):\n        config_file = f\"{self.model_path}/config.json\"\n        try:\n            with open(config_file, \"r\") as f:\n                config_data = json.load(f)\n        except FileNotFoundError:\n            raise FileNotFoundError(f\"Configuration file not found at {config_file}\")\n\n        for attr, json_key in self.json_key_map.items():\n            if json_key in config_data:\n                setattr(self, attr, config_data[json_key])\n            else:\n                setattr(self, attr, getattr(self, attr))\n\n\n    \n\n\nclass ParallelConfig:\n    def __init__(\n        self,\n        config,\n    ) -> None:\n        self.pipeline_parallel_size = config[\"parallel\"][\"pp\"]\n        self.tensor_parallel_size = config[\"parallel\"][\"tp\"]\n        self.disable_custom_all_reduce = config[\"parallel\"][\"disable_custom_all_reduce\"]\n        self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size\n\nclass AttnConfig:\n    page_size: int = 256\n    block_num: int = 32\n    max_batch_token : int = 256\n    max_batch_size: int = 32\n\n    def __init__(self, config):\n        self.page_size = config[\"attn\"][\"page_size\"]\n        self.block_num = config[\"attn\"][\"block_num\"]\n        self.max_batch_token = config[\"attn\"][\"max_batch_token\"]\n        self.max_batch_size = config[\"attn\"][\"max_batch_size\"]\n\n\nclass SamplerConfig():\n\t# Batched sampling params\n    temperatures: float\n    is_all_greedy: bool\n\t\n    def __init__(self, config):\n        self.temperatures = config[\"sample\"][\"temperature\"]\n        self.is_all_greedy = 
True\n\n\ndef load_yaml_config(file_path):\n    with open(file_path, \"r\") as f:\n        return yaml.safe_load(f)\n    \n\n\n\nclass LLMConfig:\n    model_config: ModelConfig\n    parallel_config: ParallelConfig\n    attn_config: AttnConfig\n    sample_config: SamplerConfig\n    config_file: str\n\n    def __init__(self, config_file):\n        self.config_file = config_file\n        config = load_yaml_config(config_file)\n        self.model_config = ModelConfig(config)\n        self.parallel_config = ParallelConfig(config)\n        self.attn_config = AttnConfig(config)\n        self.sample_config = SamplerConfig(config)\n\n"
  },
  {
    "path": "archive/ktransformers/server/balance_serve/inference/distributed/__init__.py",
    "content": "from .communication_op import *\nfrom .parallel_state import *\nfrom .utils import *\n"
  },
  {
    "path": "archive/ktransformers/server/balance_serve/inference/distributed/communication_op.py",
    "content": "\"\"\"\nDate: 2024-12-11 06:02:42\nLastEditors: djw\nLastEditTime: 2024-12-12 09:52:06\n\"\"\"\n\nfrom typing import Any, Dict, Optional, Union\n\nimport torch\nimport torch.distributed\n\nfrom .parallel_state import get_tp_group\n\n\ndef tensor_model_parallel_all_reduce(input_: torch.Tensor, bsz_tensor: torch.Tensor, is_compute_bound=False, overlap=False) -> torch.Tensor:\n    \"\"\"All-reduce the input tensor across model parallel group.\"\"\"\n    return get_tp_group().all_reduce(input_, bsz_tensor, is_compute_bound=is_compute_bound, overlap=overlap)\n\n\ndef tensor_model_parallel_all_gather(\n    input_: torch.Tensor, dim: int = -1\n) -> torch.Tensor:\n    \"\"\"All-gather the input tensor across model parallel group.\"\"\"\n    return get_tp_group().all_gather(input_, dim)\n\n\ndef tensor_model_parallel_gather(\n    input_: torch.Tensor, dst: int = 0, dim: int = -1\n) -> Optional[torch.Tensor]:\n    \"\"\"Gather the input tensor across model parallel group.\"\"\"\n    return get_tp_group().gather(input_, dst, dim)\n\n\ndef broadcast_tensor_dict(\n    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0\n):\n    if not torch.distributed.is_initialized():\n        return tensor_dict\n    return get_tp_group().broadcast_tensor_dict(tensor_dict, src)\n"
  },
  {
    "path": "archive/ktransformers/server/balance_serve/inference/distributed/cuda_wrapper.py",
    "content": "\"\"\"This file is a pure Python wrapper for the cudart library.\nIt avoids the need to compile a separate shared library, and is\nconvenient for use when we just need to call a few functions.\n\"\"\"\n\nimport ctypes\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional\n\n# this line makes it possible to directly load `libcudart.so` using `ctypes`\nimport torch  # noqa\n\n# === export types and functions from cudart to Python ===\n# for the original cudart definition, please check\n# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html\n\ncudaError_t = ctypes.c_int\ncudaMemcpyKind = ctypes.c_int\n\n\nclass cudaIpcMemHandle_t(ctypes.Structure):\n    _fields_ = [(\"internal\", ctypes.c_byte * 128)]\n\n\n@dataclass\nclass Function:\n    name: str\n    restype: Any\n    argtypes: List[Any]\n\n\ndef find_loaded_library(lib_name) -> Optional[str]:\n    \"\"\"\n    According to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,\n    the file `/proc/self/maps` contains the memory maps of the process, which includes the\n    shared libraries loaded by the process. 
We can use this file to find the path of\n    a loaded library.\n    \"\"\" # noqa\n    found = False\n    with open(\"/proc/self/maps\") as f:\n        for line in f:\n            if lib_name in line:\n                found = True\n                break\n    if not found:\n        # the library is not loaded in the current process\n        return None\n    # if lib_name is libcudart, we need to match a line with:\n    # address /path/to/libcudart-hash.so.11.0\n    start = line.index(\"/\")\n    path = line[start:].strip()\n    filename = path.split(\"/\")[-1]\n    assert filename.rpartition(\".so\")[0].startswith(lib_name), \\\n        f\"Unexpected filename: {filename} for library {lib_name}\"\n    return path\n\n\nclass CudaRTLibrary:\n    exported_functions = [\n        # cudaError_t cudaSetDevice ( int  device )\n        Function(\"cudaSetDevice\", cudaError_t, [ctypes.c_int]),\n        # cudaError_t \tcudaDeviceSynchronize ( void )\n        Function(\"cudaDeviceSynchronize\", cudaError_t, []),\n        # cudaError_t cudaDeviceReset ( void )\n        Function(\"cudaDeviceReset\", cudaError_t, []),\n\n        # const char* \tcudaGetErrorString ( cudaError_t error )\n        Function(\"cudaGetErrorString\", ctypes.c_char_p, [cudaError_t]),\n\n        # cudaError_t \tcudaMalloc ( void** devPtr, size_t size )\n        Function(\"cudaMalloc\", cudaError_t,\n                 [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]),\n        # cudaError_t \tcudaFree ( void* devPtr )\n        Function(\"cudaFree\", cudaError_t, [ctypes.c_void_p]),\n        # cudaError_t cudaMemset ( void* devPtr, int  value, size_t count )\n        Function(\"cudaMemset\", cudaError_t,\n                 [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]),\n        # cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa\n        Function(\"cudaMemcpy\", cudaError_t, [\n            ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, 
cudaMemcpyKind\n        ]),\n\n        # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa\n        Function(\"cudaIpcGetMemHandle\", cudaError_t,\n                 [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]),\n        # ​cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int  flags ) # noqa\n        Function(\"cudaIpcOpenMemHandle\", cudaError_t, [\n            ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint\n        ]),\n    ]\n\n    # class attribute to store the mapping from the path to the library\n    # to avoid loading the same library multiple times\n    path_to_library_cache: Dict[str, Any] = {}\n\n    # class attribute to store the mapping from library path\n    #  to the corresponding dictionary\n    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}\n\n    def __init__(self, so_file: Optional[str] = None):\n        if so_file is None:\n            so_file = find_loaded_library(\"libcudart\")\n            assert so_file is not None, \\\n                \"libcudart is not loaded in the current process\"\n        if so_file not in CudaRTLibrary.path_to_library_cache:\n            lib = ctypes.CDLL(so_file)\n            CudaRTLibrary.path_to_library_cache[so_file] = lib\n        self.lib = CudaRTLibrary.path_to_library_cache[so_file]\n\n        if so_file not in CudaRTLibrary.path_to_dict_mapping:\n            _funcs = {}\n            for func in CudaRTLibrary.exported_functions:\n                f = getattr(self.lib, func.name)\n                f.restype = func.restype\n                f.argtypes = func.argtypes\n                _funcs[func.name] = f\n            CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs\n        self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]\n\n    def CUDART_CHECK(self, result: cudaError_t) -> None:\n        if result != 0:\n            error_str = self.cudaGetErrorString(result)\n            raise 
RuntimeError(f\"CUDART error: {error_str}\")\n\n    def cudaGetErrorString(self, error: cudaError_t) -> str:\n        return self.funcs[\"cudaGetErrorString\"](error).decode(\"utf-8\")\n\n    def cudaSetDevice(self, device: int) -> None:\n        self.CUDART_CHECK(self.funcs[\"cudaSetDevice\"](device))\n\n    def cudaDeviceSynchronize(self) -> None:\n        self.CUDART_CHECK(self.funcs[\"cudaDeviceSynchronize\"]())\n\n    def cudaDeviceReset(self) -> None:\n        self.CUDART_CHECK(self.funcs[\"cudaDeviceReset\"]())\n\n    def cudaMalloc(self, size: int) -> ctypes.c_void_p:\n        devPtr = ctypes.c_void_p()\n        self.CUDART_CHECK(self.funcs[\"cudaMalloc\"](ctypes.byref(devPtr), size))\n        return devPtr\n\n    def cudaFree(self, devPtr: ctypes.c_void_p) -> None:\n        self.CUDART_CHECK(self.funcs[\"cudaFree\"](devPtr))\n\n    def cudaMemset(self, devPtr: ctypes.c_void_p, value: int,\n                   count: int) -> None:\n        self.CUDART_CHECK(self.funcs[\"cudaMemset\"](devPtr, value, count))\n\n    def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p,\n                   count: int) -> None:\n        cudaMemcpyDefault = 4\n        kind = cudaMemcpyDefault\n        self.CUDART_CHECK(self.funcs[\"cudaMemcpy\"](dst, src, count, kind))\n\n    def cudaIpcGetMemHandle(self,\n                            devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:\n        handle = cudaIpcMemHandle_t()\n        self.CUDART_CHECK(self.funcs[\"cudaIpcGetMemHandle\"](\n            ctypes.byref(handle), devPtr))\n        return handle\n\n    def cudaIpcOpenMemHandle(self,\n                             handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:\n        cudaIpcMemLazyEnablePeerAccess = 1\n        devPtr = ctypes.c_void_p()\n        self.CUDART_CHECK(self.funcs[\"cudaIpcOpenMemHandle\"](\n            ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess))\n        return devPtr\n"
  },
  {
    "path": "archive/ktransformers/server/balance_serve/inference/distributed/custom_all_reduce.py",
    "content": "import ctypes\nfrom contextlib import contextmanager\nfrom typing import List, Optional, Union\n\nimport torch\nimport torch.distributed as dist\nfrom torch.distributed import ProcessGroup\n\nimport server.envs as envs\nfrom server.inference.distributed.cuda_wrapper import CudaRTLibrary\nfrom server.inference.distributed.custom_all_reduce_utils import gpu_p2p_access_check\nfrom server.inference.distributed.parallel_state import in_the_same_node_as\nfrom server.inference.platforms import current_platform\nfrom server.utils import cuda_device_count_stateless\nimport vLLMCustomAllreduce\n\ntry:\n    vLLMCustomAllreduce.meta_size()\n    custom_ar = True\nexcept Exception:\n    # For AMD GPUs and CPUs\n    custom_ar = False\n\n\ndef _can_p2p(rank: int, world_size: int) -> bool:\n    for i in range(world_size):\n        if i == rank:\n            continue\n        if envs.VLLM_SKIP_P2P_CHECK:\n            print(\"Skipping P2P check and trusting the driver's P2P report.\")\n            return torch.cuda.can_device_access_peer(rank, i)\n        if not gpu_p2p_access_check(rank, i):\n            return False\n    return True\n\n\ndef is_weak_contiguous(inp: torch.Tensor):\n    return inp.is_contiguous() or (\n        inp.storage().nbytes() - inp.storage_offset() * inp.element_size()\n        == inp.numel() * inp.element_size()\n    )\n\n\nclass CustomAllreduce:\n\n    _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]\n\n    # max_size: max supported allreduce size\n    def __init__(\n        self,\n        group: ProcessGroup,\n        device: Union[int, str, torch.device],\n        max_size=8192 * 1024,\n    ) -> None:\n        \"\"\"\n        Args:\n            group: the process group to work on. If None, it will use the\n                default process group.\n            device: the device to bind the CustomAllreduce to. 
If None,\n                it will be bind to f\"cuda:{local_rank}\".\n        It is the caller's responsibility to make sure each communicator\n        is bind to a unique device, and all communicators in this group\n        are in the same node.\n        \"\"\"\n        self._IS_CAPTURING = False\n        self.disabled = True\n\n        if not custom_ar:\n            # disable because of missing custom allreduce library\n            # e.g. in a non-cuda environment\n            return\n\n        self.group = group\n\n        assert (\n            dist.get_backend(group) != dist.Backend.NCCL\n        ), \"CustomAllreduce should be attached to a non-NCCL group.\"\n\n        if not all(in_the_same_node_as(group, source_rank=0)):\n            # No need to initialize custom allreduce for multi-node case.\n            print(\n                \"Custom allreduce is disabled because this process group\"\n                \" spans across nodes.\"\n            )\n            return\n\n        rank = dist.get_rank(group=self.group)\n        world_size = dist.get_world_size(group=self.group)\n        if world_size == 1:\n            # No need to initialize custom allreduce for single GPU case.\n            return\n\n        if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:\n            print(\n                \"Custom allreduce is disabled due to an unsupported world\"\n                \" size: %d. Supported world sizes: %s. 
To silence this \"\n                \"warning, specify disable_custom_all_reduce=True explicitly.\",\n                world_size,\n                str(CustomAllreduce._SUPPORTED_WORLD_SIZES),\n            )\n            return\n\n        if isinstance(device, int):\n            device = torch.device(f\"cuda:{device}\")\n        elif isinstance(device, str):\n            device = torch.device(device)\n        # now `device` is a `torch.device` object\n        assert isinstance(device, torch.device)\n        self.device = device\n\n        cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES\n        if cuda_visible_devices:\n            device_ids = list(map(int, cuda_visible_devices.split(\",\")))\n        else:\n            device_ids = list(range(cuda_device_count_stateless()))\n\n        physical_device_id = device_ids[device.index]\n        tensor = torch.tensor([physical_device_id], dtype=torch.int, device=\"cpu\")\n        gather_list = [\n            torch.tensor([0], dtype=torch.int, device=\"cpu\") for _ in range(world_size)\n        ]\n        dist.all_gather(gather_list, tensor, group=self.group)\n        physical_device_ids = [t.item() for t in gather_list]\n\n        # test nvlink first, this will filter out most of the cases\n        # where custom allreduce is not supported\n        # this checks hardware and driver support for NVLink\n        assert current_platform.is_cuda()\n        from server.inference.platforms.cuda import CudaPlatform\n\n        cuda_platform: CudaPlatform = current_platform\n        full_nvlink = cuda_platform.is_full_nvlink(physical_device_ids)\n        if world_size > 2 and not full_nvlink:\n            print(\n                \"Custom allreduce is disabled because it's not supported on\"\n                \" more than two PCIe-only GPUs. 
To silence this warning, \"\n                \"specify disable_custom_all_reduce=True explicitly.\"\n            )\n            return\n        # test P2P capability, this checks software/cudaruntime support\n        # this is expensive to compute at the first time\n        # then we cache the result\n        if not _can_p2p(rank, world_size):\n            print(\n                \"Custom allreduce is disabled because your platform lacks \"\n                \"GPU P2P capability or P2P test failed. To silence this \"\n                \"warning, specify disable_custom_all_reduce=True explicitly.\"\n            )\n            return\n\n        self.disabled = False\n        # Buffers memory are owned by this Python class and passed to C++.\n        # Meta data composes of two parts: meta data for synchronization and a\n        # temporary buffer for storing intermediate allreduce results.\n        self.meta_ptrs = self.create_shared_buffer(\n            vLLMCustomAllreduce.meta_size() + max_size, group=group\n        )\n        # This is a pre-registered IPC buffer. In eager mode, input tensors\n        # are first copied into this buffer before allreduce is performed\n        self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)\n        # This is a buffer for storing the tuples of pointers pointing to\n        # IPC buffers from all ranks. Each registered tuple has size of\n        # 8*world_size bytes where world_size is at most 8. Allocating 8MB\n        # is enough for 131072 such tuples. 
The largest model I've seen only\n        # needs less than 10000 of registered tuples.\n        self.rank_data = torch.empty(\n            8 * 1024 * 1024, dtype=torch.uint8, device=self.device\n        )\n        self.max_size = max_size\n        self.rank = rank\n        self.world_size = world_size\n        self.full_nvlink = full_nvlink\n        self._ptr = vLLMCustomAllreduce.init_custom_ar(\n            self.meta_ptrs, self.rank_data, rank, self.full_nvlink\n        )\n        vLLMCustomAllreduce.register_buffer(self._ptr, self.buffer_ptrs)\n\n    @staticmethod\n    def create_shared_buffer(\n        size_in_bytes: int, group: Optional[ProcessGroup] = None\n    ) -> List[int]:\n        \"\"\"\n        Creates a shared buffer and returns a list of pointers\n        representing the buffer on all processes in the group.\n        \"\"\"\n        lib = CudaRTLibrary()\n        pointer = lib.cudaMalloc(size_in_bytes)\n        handle = lib.cudaIpcGetMemHandle(pointer)\n        world_size = dist.get_world_size(group=group)\n        rank = dist.get_rank(group=group)\n        handles = [None] * world_size\n        dist.all_gather_object(handles, handle, group=group)\n\n        pointers: List[int] = []\n        for i, h in enumerate(handles):\n            if i == rank:\n                pointers.append(pointer.value)  # type: ignore\n            else:\n                pointers.append(lib.cudaIpcOpenMemHandle(h).value)  # type: ignore\n\n        return pointers\n\n    @staticmethod\n    def free_shared_buffer(\n        pointers: List[int], group: Optional[ProcessGroup] = None\n    ) -> None:\n        rank = dist.get_rank(group=group)\n        lib = CudaRTLibrary()\n        lib.cudaFree(ctypes.c_void_p(pointers[rank]))\n\n    @contextmanager\n    def capture(self):\n        \"\"\"\n        The main responsibility of this context manager is the\n        `register_graph_buffers` call at the end of the context.\n        It records all the buffer addresses used in the CUDA 
graph.\n        \"\"\"\n        try:\n            self._IS_CAPTURING = True\n            yield\n        finally:\n            self._IS_CAPTURING = False\n            if not self.disabled:\n                self.register_graph_buffers()\n\n    def register_graph_buffers(self):\n        handle, offset = vLLMCustomAllreduce.get_graph_buffer_ipc_meta(self._ptr)\n        print(f\"Registering {len(offset)} cuda graph addresses\")\n        # We cannot directly use `dist.all_gather_object` here\n        # because it is incompatible with `gloo` backend under inference mode.\n        # see https://github.com/pytorch/pytorch/issues/126032 for details.\n        all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]\n        all_data[self.rank] = [handle, offset]\n        ranks = sorted(dist.get_process_group_ranks(group=self.group))\n        for i, rank in enumerate(ranks):\n            dist.broadcast_object_list(\n                all_data[i], src=rank, group=self.group, device=\"cpu\"\n            )\n        # Unpack list of tuples to tuple of lists.\n        handles = [d[0] for d in all_data]  # type: ignore\n        offsets = [d[1] for d in all_data]  # type: ignore\n        vLLMCustomAllreduce.register_graph_buffers(self._ptr, handles, offsets)\n\n    def should_custom_ar(self, inp: torch.Tensor):\n        if self.disabled:\n            return False\n        inp_size = inp.numel() * inp.element_size()\n        # custom allreduce requires input byte size to be multiples of 16\n        if inp_size % 16 != 0:\n            return False\n        if not is_weak_contiguous(inp):\n            return False\n        # for 4 or more non NVLink-capable GPUs, custom allreduce provides\n        # little performance improvement over NCCL.\n        if self.world_size == 2 or self.full_nvlink:\n            return inp_size < self.max_size\n        return False\n\n    def all_reduce(\n        self, inp: torch.Tensor, *, out: torch.Tensor = None, bsz_tensor: 
torch.Tensor = None, registered: bool = False,\n        is_compute_bound=False, overlap=False\n    ):\n        \"\"\"Performs an out-of-place all reduce.\n\n        If registered is True, this assumes inp's pointer is already\n        IPC-registered. Otherwise, inp is first copied into a pre-registered\n        buffer.\n        \"\"\"\n        if is_compute_bound:\n            sms = 2 if overlap else 36\n        else:\n            sms = 20 if overlap else 36\n        #print(\"all reduce sms\", sms)\n        if out is None:\n            out = torch.empty_like(inp)\n        if registered:\n            vLLMCustomAllreduce.all_reduce(self._ptr, inp, out, 0, 0, bsz_tensor, block_limit=sms)\n        else:\n            vLLMCustomAllreduce.all_reduce(\n                self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size, bsz_tensor, block_limit=sms\n            )\n        return out\n\n    def custom_all_reduce(self, input: torch.Tensor, bsz_tensor: torch.Tensor, is_compute_bound=False, overlap=False) -> Optional[torch.Tensor]:\n        \"\"\"The main allreduce API that provides support for cuda graph.\"\"\"\n        # When custom allreduce is disabled, this will be None.\n        if self.disabled or not self.should_custom_ar(input):\n            return None\n        if self._IS_CAPTURING:\n            if torch.cuda.is_current_stream_capturing():\n                return self.all_reduce(input, bsz_tensor=bsz_tensor, registered=True, is_compute_bound=is_compute_bound, overlap=overlap)\n            else:\n                # If warm up, mimic the allocation pattern since custom\n                # allreduce is out-of-place.\n                return torch.empty_like(input)\n        else:\n            # Note: outside of cuda graph context, custom allreduce incurs a\n            # cost of cudaMemcpy, which should be small (<=1% of overall\n            # latency) compared to the performance gain of using custom kernels\n            return self.all_reduce(input, 
bsz_tensor=bsz_tensor, registered=False, is_compute_bound=is_compute_bound, overlap=overlap)\n\n    def close(self):\n        if not self.disabled and self._ptr:\n            vLLMCustomAllreduce.dispose(self._ptr)\n            self._ptr = 0\n            # free with the same group the buffers were created with, so that\n            # the rank used to index the pointer list matches\n            self.free_shared_buffer(self.meta_ptrs, group=self.group)\n            self.free_shared_buffer(self.buffer_ptrs, group=self.group)\n\n    def __del__(self):\n        self.close()\n"
  },
  {
    "path": "archive/ktransformers/server/balance_serve/inference/distributed/custom_all_reduce_utils.py",
    "content": "import ctypes\nimport json\nimport os\nimport pickle\nimport subprocess\nimport sys\nimport tempfile\nfrom itertools import product\nfrom typing import Dict, List, Optional, Sequence\n\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\n\nimport server.envs as envs\nfrom server.inference.distributed.cuda_wrapper import CudaRTLibrary\nfrom server.utils import cuda_device_count_stateless, update_environment_variables\n\n\ndef producer(\n    batch_src: Sequence[int],\n    producer_queue,\n    consumer_queue,\n    result_queue,\n    cuda_visible_devices: Optional[str] = None,\n):\n    if cuda_visible_devices is not None:\n        update_environment_variables({\"CUDA_VISIBLE_DEVICES\": cuda_visible_devices})\n\n    lib = CudaRTLibrary()\n    for i in batch_src:\n        lib.cudaSetDevice(i)\n        pointer = lib.cudaMalloc(1024)\n        lib.cudaMemset(pointer, 1, 1024)\n        lib.cudaDeviceSynchronize()\n        handle = lib.cudaIpcGetMemHandle(pointer)\n        producer_queue.put(handle)\n        open_success = consumer_queue.get()\n        if open_success:\n            # use two queues to simulate barrier\n            producer_queue.put(0)\n            consumer_queue.get()\n            # check if the memory is modified\n            host_data = (ctypes.c_char * 1024)()\n            lib.cudaMemcpy(host_data, pointer, 1024)  # type: ignore\n            for i in range(1024):\n                if ord(host_data[i]) != 2:\n                    open_success = False\n                    break\n        result_queue.put(open_success)\n        lib.cudaDeviceReset()\n\n\ndef consumer(\n    batch_tgt: Sequence[int],\n    producer_queue,\n    consumer_queue,\n    result_queue,\n    cuda_visible_devices: Optional[str] = None,\n):\n    if cuda_visible_devices is not None:\n        update_environment_variables({\"CUDA_VISIBLE_DEVICES\": cuda_visible_devices})\n\n    lib = CudaRTLibrary()\n    for j in batch_tgt:\n        lib.cudaSetDevice(j)\n       
 handle = producer_queue.get()\n        open_success = False\n        try:\n            pointer = lib.cudaIpcOpenMemHandle(handle)  # type: ignore\n            open_success = True\n        except RuntimeError:\n            # cannot error out here, because the producer process\n            # is still waiting for the response.\n            pass\n        consumer_queue.put(open_success)\n        if open_success:\n            # modify the memory\n            lib.cudaMemset(pointer, 2, 1024)\n            lib.cudaDeviceSynchronize()\n            # use two queues to simulate barrier\n            producer_queue.get()\n            consumer_queue.put(0)\n            # check if the memory is modified\n            host_data = (ctypes.c_char * 1024)()\n            lib.cudaMemcpy(host_data, pointer, 1024)  # type: ignore\n            for i in range(1024):\n                if ord(host_data[i]) != 2:\n                    open_success = False\n                    break\n        result_queue.put(open_success)\n        lib.cudaDeviceReset()\n\n\ndef can_actually_p2p(\n    batch_src: Sequence[int],\n    batch_tgt: Sequence[int],\n) -> Sequence[bool]:\n    \"\"\"\n    Usually, checking if P2P access is enabled can be done by\n    `torch.cuda.can_device_access_peer(src, tgt)`. 
However, sometimes\n    the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)`\n    returns `True` even if P2P access is not actually possible.\n    See https://github.com/vllm-project/vllm/issues/2728 and\n    https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10\n    Therefore, we have to perform a real P2P access to check if it is actually\n    possible.\n\n    Note on p2p and cuda IPC:\n    Usually, one process uses one GPU:\n    GPU src --> cuda context src --> tensor src --> process src\n\n    We need to combine p2p and cuda IPC, so that:\n    GPU src --> cuda context src --> tensor src --> process src\n                                      |shared|\n    GPU tgt --> cuda context tgt --> tensor tgt --> process tgt\n    That is to say, process src creates a tensor in GPU src, passes IPC handle to\n    process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the\n    tensor in process tgt will be reflected in the tensor in process src, because\n    they are the same memory segment.\n    It is important to note that process tgt accesses the tensor in GPU tgt, not\n    GPU src. That's why we need p2p access.\n\n    The most time-consuming part is the process creation. To avoid creating\n    processes for every pair of GPUs, we use batched testing. We create two\n    processes for testing all pairs of GPUs in batch. 
The trick is to reset\n    the device after each test (which is not available in PyTorch).\n    \"\"\"  # noqa\n    cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES\n    # pass the CUDA_VISIBLE_DEVICES to the child process\n    # to make sure they see the same set of GPUs\n\n    # make sure the processes are spawned\n    smp = mp.get_context(\"spawn\")\n    producer_queue = smp.Queue()\n    consumer_queue = smp.Queue()\n    result_queue = smp.Queue()\n    p_src = smp.Process(\n        target=producer,\n        args=(\n            batch_src,\n            producer_queue,\n            consumer_queue,\n            result_queue,\n            cuda_visible_devices,\n        ),\n    )\n    p_tgt = smp.Process(\n        target=consumer,\n        args=(\n            batch_tgt,\n            producer_queue,\n            consumer_queue,\n            result_queue,\n            cuda_visible_devices,\n        ),\n    )\n    p_src.start()\n    p_tgt.start()\n    p_src.join()\n    p_tgt.join()\n    assert p_src.exitcode == 0 and p_tgt.exitcode == 0\n    result: List[bool] = []\n    for src, tgt in zip(batch_src, batch_tgt):\n        a = result_queue.get()\n        b = result_queue.get()\n        if a != b:\n            print(\n                \"Two processes do not agree on the P2P access\"\n                f\" status on {src} -> {tgt}, treat as disabled.\"\n            )\n            result.append(False)\n        else:\n            result.append(a)\n    return result\n\n\n# why do we need this cache?\n# we are testing peer-to-peer (p2p) access between GPUs, across processes.\n# if we test it every time, it will be very slow, because we need to create\n#  N * N * 2 processes, where N is the world size. 
This is very slow.\n# to reduce the time, we use a cache file to store the p2p access status.\n# the cache file is generated by the master process if it does not exist.\n# then all the processes can read the cache file to check the p2p access status.\n# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we\n#  can have different cache files for different CUDA_VISIBLE_DEVICES settings,\n#  e.g. used by different vllm engines. The device id in the cache file is a\n#  **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number\n#  of visible devices in the vllm engine.\n_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None\n\n\ndef gpu_p2p_access_check(src: int, tgt: int) -> bool:\n    \"\"\"Check if GPU src can access GPU tgt.\"\"\"\n\n    # if the cache variable is already calculated,\n    # read from the cache instead of checking it again\n    global _gpu_p2p_access_cache\n    if _gpu_p2p_access_cache is not None:\n        return _gpu_p2p_access_cache[f\"{src}->{tgt}\"]\n\n    is_distributed = dist.is_initialized()\n\n    num_dev = cuda_device_count_stateless()\n    cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES\n    if cuda_visible_devices is None:\n        cuda_visible_devices = \",\".join(str(i) for i in range(num_dev))\n\n    path = os.path.join(\n        envs.VLLM_CACHE_ROOT, f\"gpu_p2p_access_cache_for_{cuda_visible_devices}.json\"\n    )\n    os.makedirs(os.path.dirname(path), exist_ok=True)\n    from server.inference.distributed.parallel_state import get_world_group\n\n    if (not is_distributed or get_world_group().local_rank == 0) and (\n        not os.path.exists(path)\n    ):\n        # only the local master process (with local_rank == 0) can\n        #  enter this block to calculate the cache\n        print(f\"generating GPU P2P access cache in {path}\")\n        cache: Dict[str, bool] = {}\n        ids = list(range(num_dev))\n        # batch of all pairs of GPUs\n        batch_src, batch_tgt = 
zip(*list(product(ids, ids)))\n        # NOTE: we use `subprocess` rather than `multiprocessing` here\n        # because the caller might not have `if __name__ == \"__main__\":`,\n        # in that case we cannot use spawn method in multiprocessing.\n        # However, `can_actually_p2p` requires spawn method.\n        # The fix is, we use `subprocess` to call the function,\n        # where we have `if __name__ == \"__main__\":` in this file.\n\n        # use a temporary file to store the result\n        # we don't use the output of the subprocess directly,\n        # because the subprocess might produce logging output\n        with tempfile.NamedTemporaryFile() as output_file:\n            input_bytes = pickle.dumps((batch_src, batch_tgt, output_file.name))\n            returned = subprocess.run(\n                [sys.executable, __file__], input=input_bytes, capture_output=True\n            )\n            # check if the subprocess is successful\n            try:\n                returned.check_returncode()\n            except Exception as e:\n                # wrap raised exception to provide more information\n                raise RuntimeError(\n                    f\"Error happened when batch testing \"\n                    f\"peer-to-peer access from {batch_src} to {batch_tgt}:\\n\"\n                    f\"{returned.stderr.decode()}\"\n                ) from e\n            with open(output_file.name, \"rb\") as f:\n                result = pickle.load(f)\n        for _i, _j, r in zip(batch_src, batch_tgt, result):\n            cache[f\"{_i}->{_j}\"] = r\n        with open(path, \"w\") as f:\n            json.dump(cache, f, indent=4)\n    if is_distributed:\n        get_world_group().barrier()\n    print(f\"reading GPU P2P access cache from {path}\")\n    with open(path) as f:\n        cache = json.load(f)\n    _gpu_p2p_access_cache = cache\n    return _gpu_p2p_access_cache[f\"{src}->{tgt}\"]\n\n\n__all__ = [\"gpu_p2p_access_check\"]\n\nif __name__ == 
\"__main__\":\n    batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read())\n    result = can_actually_p2p(batch_src, batch_tgt)\n    with open(output_file, \"wb\") as f:\n        f.write(pickle.dumps(result))\n"
  },
  {
    "path": "archive/ktransformers/server/balance_serve/inference/distributed/parallel_state.py",
    "content": "# Copyright 2023 The vLLM team.\n# Adapted from\n# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py\n# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.\n\"\"\"vLLM distributed state.\nIt takes over the control of the distributed environment from PyTorch.\nThe typical workflow is:\n\n- call `init_distributed_environment` to initialize the distributed environment.\n- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to\n initialize the model parallel groups.\n\n- any code dealing with the distributed stuff\n\n- call `destroy_model_parallel` to destroy the model parallel groups.\n- call `destroy_distributed_environment` to destroy the distributed environment.\n\nIf you only need to use the distributed environment without model/pipeline\n parallelism, you can skip the model parallel initialization and destruction\n steps.\n\"\"\"\nimport contextlib\nimport gc\nimport pickle\nimport weakref\nfrom collections import namedtuple\nfrom contextlib import contextmanager, nullcontext\nfrom dataclasses import dataclass\nfrom multiprocessing import shared_memory\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\nfrom unittest.mock import patch\n\nimport torch\nimport torch.distributed\nfrom torch.distributed import Backend, ProcessGroup\n\nimport server.envs as envs\nfrom server.inference.platforms import current_platform\nfrom server.utils import direct_register_custom_op, supports_custom_op\n\n\n@dataclass\nclass GraphCaptureContext:\n    stream: torch.cuda.Stream\n\n\nTensorMetadata = namedtuple(\"TensorMetadata\", [\"device\", \"dtype\", \"size\"])\n\n\ndef _split_tensor_dict(\n    tensor_dict: Dict[str, Union[torch.Tensor, Any]]\n) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:\n    \"\"\"Split the tensor dictionary into two parts:\n    1. A list of (key, value) pairs. If the value is a tensor, it is replaced\n         by its metadata.\n    2. 
A list of tensors.\n    \"\"\"\n    metadata_list: List[Tuple[str, Any]] = []\n    tensor_list: List[torch.Tensor] = []\n    for key, value in tensor_dict.items():\n        if isinstance(value, torch.Tensor):\n            # Note: we cannot use `value.device` here,\n            # because it contains not only the device type but also the device\n            # index (e.g. \"cuda:0\"). We only need the device type.\n            # receiving side will set the device index.\n            device = value.device.type\n            metadata_list.append(\n                (key, TensorMetadata(device, value.dtype, value.size()))\n            )\n            tensor_list.append(value)\n        else:\n            metadata_list.append((key, value))\n    return metadata_list, tensor_list\n\n\n_group_name_counter: Dict[str, int] = {}\n\n\ndef _get_unique_name(name: str) -> str:\n    \"\"\"Get a unique name for the group.\n    Example:\n    _get_unique_name(\"tp\") -> \"tp:0\"\n    _get_unique_name(\"tp\") -> \"tp:1\"\n    \"\"\"\n    if name not in _group_name_counter:\n        _group_name_counter[name] = 0\n    newname = f\"{name}:{_group_name_counter[name]}\"\n    _group_name_counter[name] += 1\n    return newname\n\n\n_groups: Dict[str, Callable[[], Optional[\"GroupCoordinator\"]]] = {}\n\n\ndef _register_group(group: \"GroupCoordinator\") -> None:\n    _groups[group.unique_name] = weakref.ref(group)\n\n\nif supports_custom_op():\n\n    def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:\n        assert group_name in _groups, f\"Group {group_name} is not found.\"\n        group = _groups[group_name]()\n        if group is None:\n            raise ValueError(f\"Group {group_name} is destroyed.\")\n        group._all_reduce_in_place(tensor)\n\n    def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None:\n        return\n\n    direct_register_custom_op(\n        op_name=\"inplace_all_reduce\",\n        op_func=inplace_all_reduce,\n        
mutates_args=[\"tensor\"],\n        fake_impl=inplace_all_reduce_fake,\n    )\n\n    def outplace_all_reduce(tensor: torch.Tensor, group_name: str, bsz_tensor: torch.Tensor, is_compute_bound: bool = False, overlap: bool = False) -> torch.Tensor:\n        assert group_name in _groups, f\"Group {group_name} is not found.\"\n        group = _groups[group_name]()\n        if group is None:\n            raise ValueError(f\"Group {group_name} is destroyed.\")\n        return group._all_reduce_out_place(tensor, bsz_tensor, is_compute_bound=is_compute_bound, overlap=overlap)\n\n    def outplace_all_reduce_fake(tensor: torch.Tensor, group_name: str, bsz_tensor: torch.Tensor, is_compute_bound: bool = False, overlap: bool = False) -> torch.Tensor:\n        return torch.empty_like(tensor)\n\n    direct_register_custom_op(\n        op_name=\"outplace_all_reduce\",\n        op_func=outplace_all_reduce,\n        mutates_args=[],\n        fake_impl=outplace_all_reduce_fake,\n    )\n\n\nclass GroupCoordinator:\n    \"\"\"\n    PyTorch ProcessGroup wrapper for a group of processes.\n    PyTorch ProcessGroup is bound to one specific communication backend,\n        e.g. NCCL, Gloo, MPI, etc.\n    GroupCoordinator takes charge of all the communication operations among\n        the processes in the group. It can route the communication to\n        a specific implementation (e.g. 
switch allreduce implementation\n        based on the tensor size and cuda graph mode).\n    \"\"\"\n\n    # available attributes:\n    rank: int  # global rank\n    ranks: List[int]  # global ranks in the group\n    world_size: int  # size of the group\n    # difference between `local_rank` and `rank_in_group`:\n    # if we have a group of size 4 across two nodes:\n    # Process | Node | Rank | Local Rank | Rank in Group\n    #   0     |   0  |  0   |     0      |       0\n    #   1     |   0  |  1   |     1      |       1\n    #   2     |   1  |  2   |     0      |       2\n    #   3     |   1  |  3   |     1      |       3\n    local_rank: int  # local rank used to assign devices\n    rank_in_group: int  # rank inside the group\n    cpu_group: ProcessGroup  # group for CPU communication\n    device_group: ProcessGroup  # group for device communication\n    use_pynccl: bool  # a hint of whether to use PyNccl\n    use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce\n    # communicators are only created for world size > 1\n    pynccl_comm: Optional[Any]  # PyNccl communicator\n    ca_comm: Optional[Any]  # Custom allreduce communicator\n    mq_broadcaster: Optional[Any]  # shared memory broadcaster\n\n    def __init__(\n        self,\n        group_ranks: List[List[int]],\n        local_rank: int,\n        torch_distributed_backend: Union[str, Backend],\n        use_pynccl: bool,\n        use_custom_allreduce: bool,\n        use_tpu_communicator: bool,\n        use_hpu_communicator: bool,\n        use_xpu_communicator: bool,\n        use_message_queue_broadcaster: bool = False,\n        group_name: Optional[str] = None,\n    ):\n        group_name = group_name or \"anonymous\"\n        self.unique_name = _get_unique_name(group_name)\n        _register_group(self)\n\n        self.rank = torch.distributed.get_rank()\n        self.local_rank = local_rank\n        self.device_group = None\n        self.cpu_group = None\n\n        for ranks in 
group_ranks:\n            device_group = torch.distributed.new_group(\n                ranks, backend=torch_distributed_backend\n            )\n            # a group with `gloo` backend, to allow direct coordination between\n            # processes through the CPU.\n            cpu_group = torch.distributed.new_group(ranks, backend=\"gloo\")\n            if self.rank in ranks:\n                self.ranks = ranks\n                self.world_size = len(ranks)\n                self.rank_in_group = ranks.index(self.rank)\n                self.device_group = device_group\n                self.cpu_group = cpu_group\n\n        assert self.cpu_group is not None\n        assert self.device_group is not None\n        assert current_platform.is_cuda_alike()\n\n        if current_platform.is_cuda_alike():\n            self.device = torch.device(f\"cuda:{local_rank}\")\n        else:\n            self.device = torch.device(\"cpu\")\n\n        self.use_pynccl = use_pynccl\n        self.use_custom_allreduce = use_custom_allreduce\n        self.use_tpu_communicator = use_tpu_communicator\n        self.use_hpu_communicator = use_hpu_communicator\n        self.use_xpu_communicator = use_xpu_communicator\n\n        # lazy import to avoid documentation build error\n        from server.inference.distributed.custom_all_reduce import CustomAllreduce\n        from server.inference.distributed.pynccl import PyNcclCommunicator\n\n        self.pynccl_comm: Optional[PyNcclCommunicator] = None\n        # if use_pynccl and self.world_size > 1:\n        #     self.pynccl_comm = PyNcclCommunicator(\n        #         group=self.cpu_group,\n        #         device=self.device,\n        #     )\n\n        self.ca_comm: Optional[CustomAllreduce] = None\n        if use_custom_allreduce and self.world_size > 1:\n            # Initialize a custom fast all-reduce implementation.\n            self.ca_comm = CustomAllreduce(\n                group=self.cpu_group,\n                device=self.device,\n    
        )\n\n        #### we assume we won't use tpu or hpu or xpu or messagequeue broadcast\n\n        # from vllm.distributed.device_communicators.tpu_communicator import (\n        #     TpuCommunicator)\n        # self.tpu_communicator: Optional[TpuCommunicator] = None\n        # if use_tpu_communicator and self.world_size > 1:\n        #     self.tpu_communicator = TpuCommunicator(group=self.cpu_group)\n        self.tpu_communicator = None\n\n        # from vllm.distributed.device_communicators.hpu_communicator import (\n        #     HpuCommunicator)\n        # self.hpu_communicator: Optional[HpuCommunicator]\n        # if use_hpu_communicator and self.world_size > 1:\n        #     self.hpu_communicator = HpuCommunicator(group=self.device_group)\n        self.hpu_communicator = None\n\n        # from vllm.distributed.device_communicators.xpu_communicator import (\n        #     XpuCommunicator)\n        # self.xpu_communicator: Optional[XpuCommunicator]\n        # if use_xpu_communicator and self.world_size > 1:\n        #     self.xpu_communicator = XpuCommunicator(group=self.device_group)\n        self.xpu_communicator = None\n\n        # from vllm.distributed.device_communicators.shm_broadcast import (\n        #     MessageQueue)\n        # self.mq_broadcaster: Optional[MessageQueue] = None\n        # if use_message_queue_broadcaster and self.world_size > 1:\n        #     self.mq_broadcaster = MessageQueue.create_from_process_group(\n        #         self.cpu_group, 1 << 22, 6)\n        self.mq_broadcaster = None\n\n    @property\n    def first_rank(self):\n        \"\"\"Return the global rank of the first process in the group\"\"\"\n        return self.ranks[0]\n\n    @property\n    def last_rank(self):\n        \"\"\"Return the global rank of the last process in the group\"\"\"\n        return self.ranks[-1]\n\n    @property\n    def is_first_rank(self):\n        \"\"\"Return whether the caller is the first process in the group\"\"\"\n        return 
self.rank == self.first_rank\n\n    @property\n    def is_last_rank(self):\n        \"\"\"Return whether the caller is the last process in the group\"\"\"\n        return self.rank == self.last_rank\n\n    @property\n    def next_rank(self):\n        \"\"\"Return the global rank of the process that follows the caller\"\"\"\n        rank_in_group = self.rank_in_group\n        world_size = self.world_size\n        return self.ranks[(rank_in_group + 1) % world_size]\n\n    @property\n    def prev_rank(self):\n        \"\"\"Return the global rank of the process that precedes the caller\"\"\"\n        rank_in_group = self.rank_in_group\n        world_size = self.world_size\n        return self.ranks[(rank_in_group - 1) % world_size]\n\n    @contextmanager\n    def graph_capture(\n        self, graph_capture_context: Optional[GraphCaptureContext] = None\n    ):\n        if graph_capture_context is None:\n            stream = torch.cuda.Stream()\n            graph_capture_context = GraphCaptureContext(stream)\n        else:\n            stream = graph_capture_context.stream\n\n        ca_comm = self.ca_comm\n        maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture()\n\n        # ensure all initialization operations complete before attempting to\n        # capture the graph on another stream\n        curr_stream = torch.cuda.current_stream()\n        if curr_stream != stream:\n            stream.wait_stream(curr_stream)\n\n        with torch.cuda.stream(stream), maybe_ca_context:\n            # In graph mode, we have to be very careful about the collective\n            # operations. 
The current status is:\n            #     allreduce \\ Mode   |  Eager  |  Graph  |\n            # --------------------------------------------\n            # custom allreduce       | enabled | enabled |\n            # PyNccl                 | disabled| enabled |\n            # torch.distributed      | enabled | disabled|\n            #\n            # Note that custom allreduce will have a runtime check, if the\n            #  tensor size is too large, it will fallback to the next\n            #  available option.\n            # In summary: When using CUDA graph, we use\n            #  either custom all-reduce kernel or pynccl. When not using\n            #  CUDA graph, we use either custom all-reduce kernel or\n            #  PyTorch NCCL. We always prioritize using custom all-reduce\n            #  kernel but fall back to PyTorch or pynccl if it is\n            #  disabled or not supported.\n            pynccl_comm = self.pynccl_comm\n            maybe_pynccl_context: Any\n            if not pynccl_comm:\n                maybe_pynccl_context = nullcontext()\n            else:\n                maybe_pynccl_context = pynccl_comm.change_state(\n                    enable=True, stream=torch.cuda.current_stream()\n                )\n            with maybe_pynccl_context:\n                yield graph_capture_context\n\n    def all_reduce(self, input_: torch.Tensor, bsz_tensor: torch.Tensor, is_compute_bound=False, overlap=False) -> torch.Tensor:\n        \"\"\"\n        User-facing all-reduce function before we actually call the\n        all-reduce operation.\n\n        We need this because Dynamo does not support passing an arbitrary\n        object (`self` in this case) to a custom op. 
We need to pass the\n         group name as a string, and then look up the group coordinator from\n         the group name, dispatch the all-reduce operation to the group\n         coordinator.\n\n        In addition, PyTorch custom ops do not support mutation or returning\n        a new tensor in the same op. So we need to figure out if the op is\n        in-place or out-of-place ahead of time.\n        \"\"\"\n        # Bypass the function if we are using only 1 GPU.\n        if self.world_size == 1:\n            return input_\n\n        if input_.is_cpu:\n            import intel_extension_for_pytorch as ipex\n\n            ipex.distributed.all_reduce(input_, group=self.device_group)\n            return input_\n\n        if not supports_custom_op():\n            self._all_reduce_in_place(input_)\n            return input_\n\n        if self.tpu_communicator is not None and not self.tpu_communicator.disabled:\n            # TPU handles Dynamo with its own logic.\n            return self.tpu_communicator.all_reduce(input_)\n\n        if self.hpu_communicator is not None and not self.hpu_communicator.disabled:\n            return self.hpu_communicator.all_reduce(input_)\n\n        if self.xpu_communicator is not None and not self.xpu_communicator.disabled:\n            return self.xpu_communicator.all_reduce(input_)\n\n        if (\n            self.ca_comm is not None\n            and not self.ca_comm.disabled\n            and self.ca_comm.should_custom_ar(input_)\n        ):\n            return torch.ops.vllm.outplace_all_reduce(\n                input_, group_name=self.unique_name, bsz_tensor=bsz_tensor, is_compute_bound=is_compute_bound, overlap=overlap\n            )\n        else:\n            #assert self.ca_comm is not None\n            #assert not self.ca_comm.disabled\n            #assert self.ca_comm.should_custom_ar(input_)\n            torch.ops.vllm.inplace_all_reduce(input_, group_name=self.unique_name)\n            return input_\n\n    def 
_all_reduce_out_place(self, input_: torch.Tensor, bsz_tensor: torch.Tensor, is_compute_bound=False, overlap=False) -> torch.Tensor:\n        ca_comm = self.ca_comm\n        assert ca_comm is not None\n        assert not ca_comm.disabled\n        out = ca_comm.custom_all_reduce(input_, bsz_tensor, is_compute_bound=is_compute_bound, overlap=overlap)\n        assert out is not None\n        return out\n\n    def _all_reduce_in_place(self, input_: torch.Tensor) -> None:\n        pynccl_comm = self.pynccl_comm\n        if pynccl_comm is not None and not pynccl_comm.disabled:\n            pynccl_comm.all_reduce(input_)\n        else:\n            torch.distributed.all_reduce(input_, group=self.device_group)\n\n    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:\n        world_size = self.world_size\n        # Bypass the function if we are using only 1 GPU.\n        if world_size == 1:\n            return input_\n        assert (\n            -input_.dim() <= dim < input_.dim()\n        ), f\"Invalid dim ({dim}) for input tensor with shape {input_.size()}\"\n\n        # For TPUs, use TPU communicator.\n        tpu_comm = self.tpu_communicator\n        if tpu_comm is not None and not tpu_comm.disabled:\n            return tpu_comm.all_gather(input_, dim)\n\n        # For HPUs, use HPU communicator.\n        hpu_comm = self.hpu_communicator\n        if hpu_comm is not None and not hpu_comm.disabled:\n            return hpu_comm.all_gather(input_, dim)\n\n        if dim < 0:\n            # Convert negative dim to positive.\n            dim += input_.dim()\n        input_size = input_.size()\n        # NOTE: we have to use concat-style all-gather here,\n        # stack-style all-gather has compatibility issues with\n        # torch.compile . 
see https://github.com/pytorch/pytorch/issues/138795\n        output_size = (input_size[0] * world_size,) + input_size[1:]\n        # Allocate output tensor.\n        output_tensor = torch.empty(\n            output_size, dtype=input_.dtype, device=input_.device\n        )\n        # All-gather.\n        torch.distributed.all_gather_into_tensor(\n            output_tensor, input_, group=self.device_group\n        )\n        # Reshape\n        output_tensor = output_tensor.reshape((world_size,) + input_size)\n        output_tensor = output_tensor.movedim(0, dim)\n        output_tensor = output_tensor.reshape(\n            input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]\n        )\n        return output_tensor\n\n    def gather(\n        self, input_: torch.Tensor, dst: int = 0, dim: int = -1\n    ) -> Optional[torch.Tensor]:\n        \"\"\"\n        NOTE: We assume that the input tensor is on the same device across\n        all the ranks.\n        NOTE: `dst` is the local rank of the destination rank.\n        \"\"\"\n        world_size = self.world_size\n        # Bypass the function if we are using only 1 GPU.\n        if world_size == 1:\n            return input_\n        assert (\n            -input_.dim() <= dim < input_.dim()\n        ), f\"Invalid dim ({dim}) for input tensor with shape {input_.size()}\"\n        if dim < 0:\n            # Convert negative dim to positive.\n            dim += input_.dim()\n        if self.xpu_communicator is not None and not self.xpu_communicator.disabled:\n            return self.xpu_communicator.gather(input_, self.rank_in_group, dst, dim)\n        # Allocate output tensor.\n        if self.rank_in_group == dst:\n            gather_list = [torch.empty_like(input_) for _ in range(world_size)]\n        else:\n            gather_list = None\n        # Gather.\n        torch.distributed.gather(\n            input_, gather_list, dst=self.ranks[dst], group=self.device_group\n        )\n        if 
self.rank_in_group == dst:\n            output_tensor = torch.cat(gather_list, dim=dim)\n        else:\n            output_tensor = None\n        return output_tensor\n\n    def broadcast(self, input_: torch.Tensor, src: int = 0):\n        \"\"\"Broadcast the input tensor.\n        NOTE: `src` is the local rank of the source rank.\n        \"\"\"\n        assert src < self.world_size, f\"Invalid src rank ({src})\"\n\n        # Bypass the function if we are using only 1 GPU.\n        if self.world_size == 1:\n            return input_\n        # Broadcast.\n        torch.distributed.broadcast(\n            input_, src=self.ranks[src], group=self.device_group\n        )\n        return input_\n\n    def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):\n        \"\"\"Broadcast the input object.\n        NOTE: `src` is the local rank of the source rank.\n        \"\"\"\n        assert src < self.world_size, f\"Invalid src rank ({src})\"\n\n        # Bypass the function if we are using only 1 GPU.\n        if self.world_size == 1:\n            return obj\n        if self.mq_broadcaster is not None:\n            assert src == 0, \"Message queue broadcaster only supports src=0\"\n            return self.mq_broadcaster.broadcast_object(obj)\n        if self.rank_in_group == src:\n            torch.distributed.broadcast_object_list(\n                [obj], src=self.ranks[src], group=self.cpu_group\n            )\n            return obj\n        else:\n            recv = [None]\n            torch.distributed.broadcast_object_list(\n                recv, src=self.ranks[src], group=self.cpu_group\n            )\n            return recv[0]\n\n    def broadcast_object_list(\n        self, obj_list: List[Any], src: int = 0, group: Optional[ProcessGroup] = None\n    ):\n        \"\"\"Broadcast the input object list.\n        NOTE: `src` is the local rank of the source rank.\n        \"\"\"\n        assert src < self.world_size, f\"Invalid src rank ({src})\"\n\n     
   # Bypass the function if we are using only 1 GPU.\n        if self.world_size == 1:\n            return obj_list\n        # Broadcast.\n        torch.distributed.broadcast_object_list(\n            obj_list, src=self.ranks[src], group=self.device_group\n        )\n        return obj_list\n\n    def send_object(self, obj: Any, dst: int) -> None:\n        \"\"\"Send the input object list to the destination rank.\"\"\"\n        \"\"\"NOTE: `dst` is the local rank of the destination rank.\"\"\"\n\n        assert dst < self.world_size, f\"Invalid dst rank ({dst})\"\n\n        assert dst != self.rank_in_group, (\n            \"Invalid destination rank. Destination rank is the same \"\n            \"as the current rank.\"\n        )\n\n        # Serialize object to tensor and get the size as well\n        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)\n\n        size_tensor = torch.tensor(\n            [object_tensor.numel()], dtype=torch.long, device=\"cpu\"\n        )\n\n        # Send object size\n\n        torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)\n\n        # Send object\n        torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group)\n\n        return None\n\n    def recv_object(self, src: int) -> Any:\n        \"\"\"Receive the input object list from the source rank.\"\"\"\n        \"\"\"NOTE: `src` is the local rank of the source rank.\"\"\"\n\n        assert src < self.world_size, f\"Invalid src rank ({src})\"\n\n        assert (\n            src != self.rank_in_group\n        ), \"Invalid source rank. 
Source rank is the same as the current rank.\"\n\n        size_tensor = torch.empty(1, dtype=torch.long, device=\"cpu\")\n\n        # Receive object size\n        rank_size = torch.distributed.recv(\n            size_tensor, src=self.ranks[src], group=self.cpu_group\n        )\n\n        # Tensor to receive serialized objects into.\n        object_tensor = torch.empty(  # type: ignore[call-overload]\n            size_tensor.item(),  # type: ignore[arg-type]\n            dtype=torch.uint8,\n            device=\"cpu\",\n        )\n\n        rank_object = torch.distributed.recv(\n            object_tensor, src=self.ranks[src], group=self.cpu_group\n        )\n\n        assert (\n            rank_object == rank_size\n        ), \"Received object sender rank does not match the size sender rank.\"\n\n        obj = pickle.loads(object_tensor.numpy().tobytes())\n\n        return obj\n\n    def broadcast_tensor_dict(\n        self,\n        tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,\n        src: int = 0,\n        group: Optional[ProcessGroup] = None,\n        metadata_group: Optional[ProcessGroup] = None,\n    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:\n        \"\"\"Broadcast the input tensor dictionary.\n        NOTE: `src` is the local rank of the source rank.\n        \"\"\"\n        # Bypass the function if we are using only 1 GPU.\n        if not torch.distributed.is_initialized() or self.world_size == 1:\n            return tensor_dict\n\n        group = self.device_group\n        metadata_group = self.cpu_group\n        assert src < self.world_size, f\"Invalid src rank ({src})\"\n\n        rank_in_group = self.rank_in_group\n        if rank_in_group == src:\n            metadata_list: List[Tuple[Any, Any]] = []\n            assert isinstance(\n                tensor_dict, dict\n            ), f\"Expecting a dictionary, got {type(tensor_dict)}\"\n            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)\n            # 
`metadata_list` lives in CPU memory.\n            # `broadcast_object_list` has serialization & deserialization,\n            # all happening on CPU. Therefore, we can use the CPU group.\n            self.broadcast_object(metadata_list, src=src)\n            async_handles = []\n            for tensor in tensor_list:\n                if tensor.numel() == 0:\n                    # Skip broadcasting empty tensors.\n                    continue\n                if tensor.is_cpu:\n                    # use metadata_group for CPU tensors\n                    handle = torch.distributed.broadcast(\n                        tensor, src=self.ranks[src], group=metadata_group, async_op=True\n                    )\n                else:\n                    # use group for GPU tensors\n                    handle = torch.distributed.broadcast(\n                        tensor, src=self.ranks[src], group=group, async_op=True\n                    )\n                async_handles.append(handle)\n            for async_handle in async_handles:\n                async_handle.wait()\n\n        else:\n            metadata_list = self.broadcast_object(None, src=src)\n            tensor_dict = {}\n            async_handles = []\n            for key, value in metadata_list:\n                if isinstance(value, TensorMetadata):\n                    tensor = torch.empty(\n                        value.size, dtype=value.dtype, device=value.device\n                    )\n                    if tensor.numel() == 0:\n                        # Skip broadcasting empty tensors.\n                        tensor_dict[key] = tensor\n                        continue\n                    if tensor.is_cpu:\n                        # use metadata_group for CPU tensors\n                        handle = torch.distributed.broadcast(\n                            tensor,\n                            src=self.ranks[src],\n                            group=metadata_group,\n                            
async_op=True,\n                        )\n                    else:\n                        # use group for GPU tensors\n                        handle = torch.distributed.broadcast(\n                            tensor, src=self.ranks[src], group=group, async_op=True\n                        )\n                    async_handles.append(handle)\n                    tensor_dict[key] = tensor\n                else:\n                    tensor_dict[key] = value\n            for async_handle in async_handles:\n                async_handle.wait()\n        return tensor_dict\n\n    def send_tensor_dict(\n        self,\n        tensor_dict: Dict[str, Union[torch.Tensor, Any]],\n        dst: Optional[int] = None,\n        all_gather_group: Optional[\"GroupCoordinator\"] = None,\n    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:\n        \"\"\"Send the input tensor dictionary.\n        NOTE: `dst` is the local rank of the destination rank.\n        \"\"\"\n        # Bypass the function if we are using only 1 GPU.\n        if not torch.distributed.is_initialized() or self.world_size == 1:\n            return tensor_dict\n\n        all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size\n        all_gather_rank = (\n            0 if all_gather_group is None else all_gather_group.rank_in_group\n        )\n\n        group = self.device_group\n        metadata_group = self.cpu_group\n\n        if dst is None:\n            dst = (self.rank_in_group + 1) % self.world_size\n        assert dst < self.world_size, f\"Invalid dst rank ({dst})\"\n\n        metadata_list: List[Tuple[Any, Any]] = []\n        assert isinstance(\n            tensor_dict, dict\n        ), f\"Expecting a dictionary, got {type(tensor_dict)}\"\n        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)\n        # `metadata_list` lives in CPU memory.\n        # `send_object` has serialization & deserialization,\n        # all happening on CPU. 
Therefore, we can use the CPU group.\n        self.send_object(metadata_list, dst=dst)\n        for tensor in tensor_list:\n            if tensor.numel() == 0:\n                # Skip sending empty tensors.\n                continue\n\n            # send-allgather: send only a slice, then do allgather.\n            if all_gather_group is not None and tensor.numel() % all_gather_size == 0:\n                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]\n\n            if tensor.is_cpu:\n                # use metadata_group for CPU tensors\n                torch.distributed.send(\n                    tensor, dst=self.ranks[dst], group=metadata_group\n                )\n            else:\n                # use group for GPU tensors\n                torch.distributed.send(tensor, dst=self.ranks[dst], group=group)\n        return None\n\n    def recv_tensor_dict(\n        self,\n        src: Optional[int] = None,\n        all_gather_group: Optional[\"GroupCoordinator\"] = None,\n    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:\n        \"\"\"Recv the input tensor dictionary.\n        NOTE: `src` is the local rank of the source rank.\n        \"\"\"\n        # Bypass the function if we are using only 1 GPU.\n        if not torch.distributed.is_initialized() or self.world_size == 1:\n            return None\n\n        all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size\n        all_gather_rank = (\n            0 if all_gather_group is None else all_gather_group.rank_in_group\n        )\n\n        group = self.device_group\n        metadata_group = self.cpu_group\n\n        if src is None:\n            src = (self.rank_in_group - 1) % self.world_size\n        assert src < self.world_size, f\"Invalid src rank ({src})\"\n\n        recv_metadata_list = self.recv_object(src=src)\n        tensor_dict: Dict[str, Any] = {}\n        for key, value in recv_metadata_list:\n            if isinstance(value, TensorMetadata):\n           
     tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)\n                if tensor.numel() == 0:\n                    # Skip broadcasting empty tensors.\n                    tensor_dict[key] = tensor\n                    continue\n\n                # send-allgather: send only a slice, then do allgather.\n                use_all_gather = (\n                    all_gather_group is not None\n                    and tensor.numel() % all_gather_size == 0\n                )\n\n                if use_all_gather:\n                    orig_shape = tensor.shape\n                    tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]\n\n                if tensor.is_cpu:\n                    # use metadata_group for CPU tensors\n                    torch.distributed.recv(\n                        tensor, src=self.ranks[src], group=metadata_group\n                    )\n                else:\n                    # use group for GPU tensors\n                    torch.distributed.recv(tensor, src=self.ranks[src], group=group)\n                if use_all_gather:\n                    # do the allgather\n                    tensor = all_gather_group.all_gather(tensor, dim=0)  # type: ignore\n                    tensor = tensor.reshape(orig_shape)\n\n                tensor_dict[key] = tensor\n            else:\n                tensor_dict[key] = value\n        return tensor_dict\n\n    def barrier(self):\n        \"\"\"Barrier synchronization among the group.\n        NOTE: don't use `device_group` here! `barrier` in NCCL is\n        terrible because it is internally a broadcast operation with\n        secretly created GPU tensors. It is easy to mess up the current\n        device. 
Use the CPU group instead.\n        \"\"\"\n        torch.distributed.barrier(group=self.cpu_group)\n\n    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:\n        \"\"\"Sends a tensor to the destination rank in a non-blocking way.\n        NOTE: `dst` is the local rank of the destination rank.\n        \"\"\"\n        if dst is None:\n            dst = (self.rank_in_group + 1) % self.world_size\n\n        pynccl_comm = self.pynccl_comm\n        if pynccl_comm is not None and not pynccl_comm.disabled:\n            pynccl_comm.send(tensor, dst)\n        else:\n            torch.distributed.send(tensor, self.ranks[dst], self.device_group)\n\n    def recv(\n        self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None\n    ) -> torch.Tensor:\n        \"\"\"Receives a tensor from the source rank.\n        NOTE: `src` is the local rank of the source rank.\n        \"\"\"\n        if src is None:\n            src = (self.rank_in_group - 1) % self.world_size\n\n        tensor = torch.empty(size, dtype=dtype, device=self.device)\n        pynccl_comm = self.pynccl_comm\n        if pynccl_comm is not None and not pynccl_comm.disabled:\n            pynccl_comm.recv(tensor, src)\n        else:\n            torch.distributed.recv(tensor, self.ranks[src], self.device_group)\n        return tensor\n\n    def destroy(self):\n        if self.device_group is not None:\n            torch.distributed.destroy_process_group(self.device_group)\n            self.device_group = None\n        if self.cpu_group is not None:\n            torch.distributed.destroy_process_group(self.cpu_group)\n            self.cpu_group = None\n        if self.pynccl_comm is not None:\n            self.pynccl_comm = None\n        if self.ca_comm is not None:\n            self.ca_comm = None\n        if self.mq_broadcaster is not None:\n            self.mq_broadcaster = None\n\n\n_WORLD: Optional[GroupCoordinator] = None\n\n\ndef get_world_group() -> 
GroupCoordinator:\n    assert _WORLD is not None, \"world group is not initialized\"\n    return _WORLD\n\n\ndef init_world_group(\n    ranks: List[int], local_rank: int, backend: str\n) -> GroupCoordinator:\n    return GroupCoordinator(\n        group_ranks=[ranks],\n        local_rank=local_rank,\n        torch_distributed_backend=backend,\n        use_pynccl=False,\n        use_custom_allreduce=False,\n        use_tpu_communicator=False,\n        use_hpu_communicator=False,\n        use_xpu_communicator=False,\n        group_name=\"world\",\n    )\n\n\ndef init_model_parallel_group(\n    group_ranks: List[List[int]],\n    local_rank: int,\n    backend: str,\n    use_custom_allreduce: Optional[bool] = None,\n    use_message_queue_broadcaster: bool = False,\n    group_name: Optional[str] = None,\n) -> GroupCoordinator:\n    if use_custom_allreduce is None:\n        use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE\n    return GroupCoordinator(\n        group_ranks=group_ranks,\n        local_rank=local_rank,\n        torch_distributed_backend=backend,\n        use_pynccl=True,\n        use_custom_allreduce=use_custom_allreduce,\n        use_tpu_communicator=True,\n        use_hpu_communicator=True,\n        use_xpu_communicator=True,\n        use_message_queue_broadcaster=use_message_queue_broadcaster,\n        group_name=group_name,\n    )\n\n\n_TP: Optional[GroupCoordinator] = None\n\n\ndef get_tp_group() -> GroupCoordinator:\n    assert _TP is not None, \"tensor model parallel group is not initialized\"\n    return _TP\n\n\n# kept for backward compatibility\nget_tensor_model_parallel_group = get_tp_group\n\n_PP: Optional[GroupCoordinator] = None\n\n\ndef get_pp_group() -> GroupCoordinator:\n    assert _PP is not None, \"pipeline model parallel group is not initialized\"\n    return _PP\n\n\n# kept for backward compatibility\nget_pipeline_model_parallel_group = get_pp_group\n\n\n@contextmanager\ndef graph_capture():\n    \"\"\"\n    `graph_capture` is a context 
manager which should surround the code that\n    is capturing the CUDA graph. Its main purpose is to ensure that\n    some operations will be run after the graph is captured, before the graph\n    is replayed. It returns a `GraphCaptureContext` object which contains the\n    necessary data for the graph capture. Currently, it only contains the\n    stream that the graph capture is running on. This stream is set to the\n    current CUDA stream when the context manager is entered and reset to the\n    default stream when the context manager is exited. This is to ensure that\n    the graph capture is running on a separate stream from the default stream,\n    in order to explicitly distinguish the kernels to capture\n    from other kernels possibly launched in the background on the default stream.\n    \"\"\"\n    with get_tp_group().graph_capture() as context, get_pp_group().graph_capture(\n        context\n    ):\n        yield context\n\n\n_ENABLE_CUSTOM_ALL_REDUCE = True\n\n\ndef set_custom_all_reduce(enable: bool):\n    global _ENABLE_CUSTOM_ALL_REDUCE\n    _ENABLE_CUSTOM_ALL_REDUCE = enable\n\n\ndef init_distributed_environment(\n    world_size: int = -1,\n    rank: int = -1,\n    distributed_init_method: str = \"env://\",\n    local_rank: int = -1,\n    backend: str = \"nccl\",\n):\n    # print() does not apply %-style arguments, so format the message explicitly\n    print(\n        \"world_size=%d rank=%d local_rank=%d \"\n        \"distributed_init_method=%s backend=%s\"\n        % (world_size, rank, local_rank, distributed_init_method, backend)\n    )\n    if not torch.distributed.is_initialized():\n        assert distributed_init_method is not None, (\n            \"distributed_init_method must be provided when initializing \"\n            \"distributed environment\"\n        )\n        # this backend is used for WORLD\n        torch.distributed.init_process_group(\n            backend=backend,\n            init_method=distributed_init_method,\n            world_size=world_size,\n            rank=rank,\n        
)\n    # set the local rank\n    # local_rank is not available in torch ProcessGroup,\n    # see https://github.com/pytorch/pytorch/issues/122816\n    if local_rank == -1:\n        # local rank not set, this usually happens in single-node\n        # setting, where we can use rank as local rank\n        if distributed_init_method == \"env://\":\n            local_rank = envs.LOCAL_RANK\n        else:\n            local_rank = rank\n    global _WORLD\n    if _WORLD is None:\n        ranks = list(range(torch.distributed.get_world_size()))\n        _WORLD = init_world_group(ranks, local_rank, backend)\n    else:\n        assert (\n            _WORLD.world_size == torch.distributed.get_world_size()\n        ), \"world group already initialized with a different world size\"\n\n\ndef initialize_model_parallel(\n    tensor_model_parallel_size: int = 1,\n    pipeline_model_parallel_size: int = 1,\n    backend: Optional[str] = None,\n) -> None:\n    \"\"\"\n    Initialize model parallel groups.\n\n    Arguments:\n        tensor_model_parallel_size: number of GPUs used for tensor model\n            parallelism.\n        pipeline_model_parallel_size: number of GPUs used for pipeline model\n            parallelism.\n\n    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we\n    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize\n    the model pipeline. The present function will\n    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:\n        4 tensor model-parallel groups:\n            [g0, g1], [g2, g3], [g4, g5], [g6, g7]\n        2 pipeline model-parallel groups:\n            [g0, g2, g4, g6], [g1, g3, g5, g7]\n    Note that for efficiency, the caller should make sure adjacent ranks\n    are on the same DGX box. For example if we are using 2 DGX-1 boxes\n    with a total of 16 GPUs, rank 0 to 7 belong to the first box and\n    ranks 8 to 15 belong to the second box.\n    \"\"\"\n    # Get world size and rank. 
Ensure some consistencies.\n    assert torch.distributed.is_initialized()\n    world_size: int = torch.distributed.get_world_size()\n    backend = backend or torch.distributed.get_backend(get_world_group().device_group)\n\n    if world_size != tensor_model_parallel_size * pipeline_model_parallel_size:\n        raise RuntimeError(\n            f\"world_size ({world_size}) is not equal to \"\n            f\"tensor_model_parallel_size ({tensor_model_parallel_size}) x \"\n            f\"pipeline_model_parallel_size ({pipeline_model_parallel_size})\"\n        )\n\n    # Build the tensor model-parallel groups.\n    num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size\n    global _TP\n    assert _TP is None, \"tensor model parallel group is already initialized\"\n    group_ranks = []\n    for i in range(num_tensor_model_parallel_groups):\n        ranks = list(\n            range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)\n        )\n        group_ranks.append(ranks)\n\n    # message queue broadcaster is only used in tensor model parallel group\n    _TP = init_model_parallel_group(\n        group_ranks,\n        get_world_group().local_rank,\n        backend,\n        use_message_queue_broadcaster=True,\n        group_name=\"tp\",\n    )\n\n    # Build the pipeline model-parallel groups.\n    num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size\n    global _PP\n    assert _PP is None, \"pipeline model parallel group is already initialized\"\n    group_ranks = []\n    for i in range(num_pipeline_model_parallel_groups):\n        ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))\n        group_ranks.append(ranks)\n    # pipeline parallel does not need custom allreduce\n    _PP = init_model_parallel_group(\n        group_ranks,\n        get_world_group().local_rank,\n        backend,\n        use_custom_allreduce=False,\n        group_name=\"pp\",\n    )\n\n\ndef 
ensure_model_parallel_initialized(\n    tensor_model_parallel_size: int,\n    pipeline_model_parallel_size: int,\n    backend: Optional[str] = None,\n) -> None:\n    \"\"\"Helper to initialize model parallel groups if they are not initialized,\n    or ensure tensor-parallel and pipeline-parallel sizes are equal to expected\n    values if the model parallel groups are initialized.\n    \"\"\"\n    backend = backend or torch.distributed.get_backend(get_world_group().device_group)\n    if not model_parallel_is_initialized():\n        initialize_model_parallel(\n            tensor_model_parallel_size, pipeline_model_parallel_size, backend\n        )\n        return\n\n    assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, (\n        \"tensor parallel group already initialized, but of unexpected size: \"\n        f\"{get_tensor_model_parallel_world_size()=} vs. \"\n        f\"{tensor_model_parallel_size=}\"\n    )\n    pp_world_size = get_pp_group().world_size\n    assert pp_world_size == pipeline_model_parallel_size, (\n        \"pipeline parallel group already initialized, but of unexpected size: \"\n        f\"{pp_world_size=} vs. 
\"\n        f\"{pipeline_model_parallel_size=}\"\n    )\n\n\ndef model_parallel_is_initialized():\n    \"\"\"Check if tensor and pipeline parallel groups are initialized.\"\"\"\n    return _TP is not None and _PP is not None\n\n\n_TP_STATE_PATCHED = False\n\n\n@contextmanager\ndef patch_tensor_parallel_group(tp_group: GroupCoordinator):\n    \"\"\"Patch the tp group temporarily until this function ends.\n\n    This method is for draft workers of speculative decoding to run draft model\n    with different tp degree from that of target model workers.\n\n    Args:\n        tp_group (GroupCoordinator): the tp group coordinator\n    \"\"\"\n    global _TP_STATE_PATCHED\n    assert not _TP_STATE_PATCHED, \"Should not call when it's already patched\"\n\n    _TP_STATE_PATCHED = True\n    old_tp_group = get_tp_group()\n    global _TP\n    _TP = tp_group\n    try:\n        yield\n    finally:\n        # restore the original state\n        _TP_STATE_PATCHED = False\n        _TP = old_tp_group\n\n\ndef get_tensor_model_parallel_world_size():\n    \"\"\"Return world size for the tensor model parallel group.\"\"\"\n    return get_tp_group().world_size\n\n\ndef get_tensor_model_parallel_rank():\n    \"\"\"Return my rank for the tensor model parallel group.\"\"\"\n    return get_tp_group().rank_in_group\n\n\ndef destroy_model_parallel():\n    \"\"\"Set the groups to none and destroy them.\"\"\"\n    global _TP\n    if _TP:\n        _TP.destroy()\n    _TP = None\n\n    global _PP\n    if _PP:\n        _PP.destroy()\n    _PP = None\n\n\ndef destroy_distributed_environment():\n    global _WORLD\n    if _WORLD:\n        _WORLD.destroy()\n    _WORLD = None\n    if torch.distributed.is_initialized():\n        torch.distributed.destroy_process_group()\n\n\ndef cleanup_dist_env_and_memory(shutdown_ray: bool = False):\n    destroy_model_parallel()\n    destroy_distributed_environment()\n    with contextlib.suppress(AssertionError):\n        torch.distributed.destroy_process_group()\n    if 
shutdown_ray:\n        import ray  # Lazy import Ray\n\n        ray.shutdown()\n    gc.collect()\n    if not current_platform.is_cpu():\n        torch.cuda.empty_cache()\n\n\ndef in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:\n    \"\"\"\n    This is a collective operation that returns if each rank is in the same node\n    as the source rank. It tests if processes are attached to the same\n    memory system (shared access to shared memory).\n    \"\"\"\n    assert (\n        torch.distributed.get_backend(pg) != torch.distributed.Backend.NCCL\n    ), \"in_the_same_node_as should be tested with a non-NCCL group.\"\n    # local rank inside the group\n    rank = torch.distributed.get_rank(group=pg)\n    world_size = torch.distributed.get_world_size(group=pg)\n\n    # local tensor in each process to store the result\n    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)\n\n    # global ranks of the processes in the group\n    ranks = torch.distributed.get_process_group_ranks(pg)\n\n    magic_message = b\"magic_message\"\n    shm = None\n\n    try:\n        with contextlib.suppress(OSError):\n            if rank == source_rank:\n                # create a shared memory segment\n                shm = shared_memory.SharedMemory(create=True, size=128)\n                shm.buf[: len(magic_message)] = magic_message\n                torch.distributed.broadcast_object_list(\n                    [shm.name], src=ranks[source_rank], group=pg\n                )\n                is_in_the_same_node[rank] = 1\n            else:\n                # try to open the shared memory segment\n                recv = [None]\n                torch.distributed.broadcast_object_list(\n                    recv, src=ranks[source_rank], group=pg\n                )\n                name = recv[0]\n                # fix to https://stackoverflow.com/q/62748654/9191338\n                # Python incorrectly tracks shared memory even if it is not\n         
       # created by the process. The following patch is a workaround.\n                with patch(\n                    \"multiprocessing.resource_tracker.register\",\n                    lambda *args, **kwargs: None,\n                ):\n                    shm = shared_memory.SharedMemory(name=name)\n                if shm.buf[: len(magic_message)] == magic_message:\n                    is_in_the_same_node[rank] = 1\n    except Exception as e:\n        print(f\"Error ignored in in_the_same_node_as: {e}\")\n    finally:\n        if shm:\n            shm.close()\n\n    torch.distributed.barrier(group=pg)\n\n    # clean up the shared memory segment\n    with contextlib.suppress(OSError):\n        if rank == source_rank and shm:\n            shm.unlink()\n    torch.distributed.all_reduce(is_in_the_same_node, group=pg)\n\n    return [x == 1 for x in is_in_the_same_node.tolist()]\n"
  },
  {
    "path": "archive/ktransformers/server/balance_serve/inference/distributed/pynccl.py",
    "content": "from contextlib import contextmanager\nfrom typing import Optional, Union\n\n# ===================== import region =====================\nimport torch\nimport torch.distributed as dist\nfrom torch.distributed import ProcessGroup, ReduceOp\n\nfrom server.inference.distributed.pynccl_wrapper import (\n    NCCLLibrary,\n    buffer_type,\n    cudaStream_t,\n    ncclComm_t,\n    ncclDataTypeEnum,\n    ncclRedOpTypeEnum,\n    ncclUniqueId,\n)\nfrom server.inference.distributed.utils import StatelessProcessGroup\n\n\nclass PyNcclCommunicator:\n\n    def __init__(\n        self,\n        group: Union[ProcessGroup, StatelessProcessGroup],\n        device: Union[int, str, torch.device],\n        library_path: Optional[str] = None,\n    ):\n        \"\"\"\n        Args:\n            group: the process group to work on. If None, it will use the\n                default process group.\n            device: the device to bind the PyNcclCommunicator to. If None,\n                it will be bind to f\"cuda:{local_rank}\".\n            library_path: the path to the NCCL library. 
If None, it will\n                use the default library path.\n        It is the caller's responsibility to make sure each communicator\n        is bound to a unique device.\n        \"\"\"\n        if not isinstance(group, StatelessProcessGroup):\n            assert dist.is_initialized()\n            assert (\n                dist.get_backend(group) != dist.Backend.NCCL\n            ), \"PyNcclCommunicator should be attached to a non-NCCL group.\"\n            # note: this rank is the rank in the group\n            self.rank = dist.get_rank(group)\n            self.world_size = dist.get_world_size(group)\n        else:\n            self.rank = group.rank\n            self.world_size = group.world_size\n\n        self.group = group\n\n        # if world_size == 1, no need to create communicator\n        if self.world_size == 1:\n            self.available = False\n            self.disabled = True\n            self.stream = None\n            return\n        try:\n            self.nccl = NCCLLibrary(library_path)\n        except Exception:\n            # disable because of missing NCCL library\n            # e.g. 
in a non-GPU environment\n            self.available = False\n            self.disabled = True\n            self.stream = None\n            return\n\n        self.available = True\n        self.disabled = False\n\n        print(f\"vLLM is using nccl=={self.nccl.ncclGetVersion()}\")\n\n        if self.rank == 0:\n            # get the unique id from NCCL\n            self.unique_id = self.nccl.ncclGetUniqueId()\n        else:\n            # construct an empty unique id\n            self.unique_id = ncclUniqueId()\n\n        if not isinstance(group, StatelessProcessGroup):\n            tensor = torch.ByteTensor(list(self.unique_id.internal))\n            ranks = dist.get_process_group_ranks(group)\n            # arg `src` in `broadcast` is the global rank\n            dist.broadcast(tensor, src=ranks[0], group=group)\n            byte_list = tensor.tolist()\n            for i, byte in enumerate(byte_list):\n                self.unique_id.internal[i] = byte\n        else:\n            self.unique_id = group.broadcast_obj(self.unique_id, src=0)\n        if isinstance(device, int):\n            device = torch.device(f\"cuda:{device}\")\n        elif isinstance(device, str):\n            device = torch.device(device)\n        # now `device` is a `torch.device` object\n        assert isinstance(device, torch.device)\n        self.device = device\n        # nccl communicator and stream will use this device\n        # `torch.cuda.device` is a context manager that changes the\n        # current cuda device to the specified one\n        with torch.cuda.device(device):\n            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(\n                self.world_size, self.unique_id, self.rank\n            )\n            self.stream = torch.cuda.Stream()\n\n            # A small all_reduce for warmup.\n            data = torch.zeros(1, device=device)\n            self.all_reduce(data)\n            self.stream.synchronize()\n            del data\n\n        # by default it is 
disabled, e.g. in profiling models and prefill phase.\n        # to use it, use under `with obj.change_state(enable=True)`, usually\n        # when we are using CUDA graph.\n        self.disabled = True\n\n    def all_reduce(\n        self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None\n    ):\n        if self.disabled:\n            return\n        # nccl communicator created on a specific device\n        # will only work on tensors on the same device\n        # otherwise it will cause \"illegal memory access\"\n        assert tensor.device == self.device, (\n            f\"this nccl communicator is created to work on {self.device}, \"\n            f\"but the input tensor is on {tensor.device}\"\n        )\n        if stream is None:\n            stream = self.stream\n        self.nccl.ncclAllReduce(\n            buffer_type(tensor.data_ptr()),\n            buffer_type(tensor.data_ptr()),\n            tensor.numel(),\n            ncclDataTypeEnum.from_torch(tensor.dtype),\n            ncclRedOpTypeEnum.from_torch(op),\n            self.comm,\n            cudaStream_t(stream.cuda_stream),\n        )\n\n    def send(self, tensor: torch.Tensor, dst: int, stream=None):\n        if self.disabled:\n            return\n        assert tensor.device == self.device, (\n            f\"this nccl communicator is created to work on {self.device}, \"\n            f\"but the input tensor is on {tensor.device}\"\n        )\n        if stream is None:\n            stream = self.stream\n        self.nccl.ncclSend(\n            buffer_type(tensor.data_ptr()),\n            tensor.numel(),\n            ncclDataTypeEnum.from_torch(tensor.dtype),\n            dst,\n            self.comm,\n            cudaStream_t(stream.cuda_stream),\n        )\n\n    def recv(self, tensor: torch.Tensor, src: int, stream=None):\n        if self.disabled:\n            return\n        assert tensor.device == self.device, (\n            f\"this nccl communicator is created to work on 
{self.device}, \"\n            f\"but the input tensor is on {tensor.device}\"\n        )\n        if stream is None:\n            stream = self.stream\n        self.nccl.ncclRecv(\n            buffer_type(tensor.data_ptr()),\n            tensor.numel(),\n            ncclDataTypeEnum.from_torch(tensor.dtype),\n            src,\n            self.comm,\n            cudaStream_t(stream.cuda_stream),\n        )\n\n    @contextmanager\n    def change_state(\n        self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None\n    ):\n        \"\"\"\n        A context manager to temporarily change the state of the communicator.\n        The previous `disabled` flag and stream are restored on exit.\n        \"\"\"\n        if enable is None:\n            # guess a default value when not specified\n            enable = self.available\n\n        if stream is None:\n            stream = self.stream\n\n        old_disable = self.disabled\n        old_stream = self.stream\n\n        self.stream = stream\n        self.disabled = not enable\n        try:\n            yield\n        finally:\n            # restore the previous state even if the body raises\n            self.disabled = old_disable\n            self.stream = old_stream\n"
  },
  {
    "path": "archive/ktransformers/server/balance_serve/inference/distributed/pynccl_wrapper.py",
    "content": "# This file is a pure Python wrapper for the NCCL library.\n# The main purpose is to use NCCL combined with CUDA graph.\n# Before writing this script, we tried the following approach:\n# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself\n#  often gets stuck when initializing the NCCL communicator.\n# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`\n#  contains many other potential cuda APIs, that are not allowed during\n#  capturing the CUDA graph. For further details, please check\n# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .\n#\n# Another rejected idea is to write a C/C++ binding for NCCL. It is usually\n# doable, but we often encounter issues related with nccl versions, and need\n# to switch between different versions of NCCL. See\n# https://github.com/NVIDIA/nccl/issues/1234 for more details.\n# A C/C++ binding is not flexible enough to handle this. It requires\n# recompilation of the code every time we want to switch between different\n# versions. This current implementation, with a **pure** Python wrapper, is\n# more flexible. 
We can easily switch between different versions of NCCL by\n# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`\n# variable in the code.\n\nimport ctypes\nimport platform\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional\n\nimport torch\nfrom torch.distributed import ReduceOp\n\nfrom server.utils import find_nccl_library\n\n\n# === export types and functions from nccl to Python ===\n# for the original nccl definition, please check\n# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in\n\nncclResult_t = ctypes.c_int\nncclComm_t = ctypes.c_void_p\n\n\nclass ncclUniqueId(ctypes.Structure):\n    _fields_ = [(\"internal\", ctypes.c_byte * 128)]\n\n\ncudaStream_t = ctypes.c_void_p\nbuffer_type = ctypes.c_void_p\n\nncclDataType_t = ctypes.c_int\n\n\nclass ncclDataTypeEnum:\n    ncclInt8 = 0\n    ncclChar = 0\n    ncclUint8 = 1\n    ncclInt32 = 2\n    ncclInt = 2\n    ncclUint32 = 3\n    ncclInt64 = 4\n    ncclUint64 = 5\n    ncclFloat16 = 6\n    ncclHalf = 6\n    ncclFloat32 = 7\n    ncclFloat = 7\n    ncclFloat64 = 8\n    ncclDouble = 8\n    ncclBfloat16 = 9\n    ncclNumTypes = 10\n\n    @classmethod\n    def from_torch(cls, dtype: torch.dtype) -> int:\n        if dtype == torch.int8:\n            return cls.ncclInt8\n        if dtype == torch.uint8:\n            return cls.ncclUint8\n        if dtype == torch.int32:\n            return cls.ncclInt32\n        if dtype == torch.int64:\n            return cls.ncclInt64\n        if dtype == torch.float16:\n            return cls.ncclFloat16\n        if dtype == torch.float32:\n            return cls.ncclFloat32\n        if dtype == torch.float64:\n            return cls.ncclFloat64\n        if dtype == torch.bfloat16:\n            return cls.ncclBfloat16\n        raise ValueError(f\"Unsupported dtype: {dtype}\")\n\n\nncclRedOp_t = ctypes.c_int\n\n\nclass ncclRedOpTypeEnum:\n    ncclSum = 0\n    ncclProd = 1\n    ncclMax = 2\n    ncclMin = 3\n    ncclAvg = 4\n    
ncclNumOps = 5\n\n    @classmethod\n    def from_torch(cls, op: ReduceOp) -> int:\n        if op == ReduceOp.SUM:\n            return cls.ncclSum\n        if op == ReduceOp.PRODUCT:\n            return cls.ncclProd\n        if op == ReduceOp.MAX:\n            return cls.ncclMax\n        if op == ReduceOp.MIN:\n            return cls.ncclMin\n        if op == ReduceOp.AVG:\n            return cls.ncclAvg\n        raise ValueError(f\"Unsupported op: {op}\")\n\n\n@dataclass\nclass Function:\n    name: str\n    restype: Any\n    argtypes: List[Any]\n\n\nclass NCCLLibrary:\n    exported_functions = [\n        # const char* ncclGetErrorString(ncclResult_t result)\n        Function(\"ncclGetErrorString\", ctypes.c_char_p, [ncclResult_t]),\n        # ncclResult_t  ncclGetVersion(int *version);\n        Function(\"ncclGetVersion\", ncclResult_t,\n                 [ctypes.POINTER(ctypes.c_int)]),\n        # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);\n        Function(\"ncclGetUniqueId\", ncclResult_t,\n                 [ctypes.POINTER(ncclUniqueId)]),\n        # ncclResult_t  ncclCommInitRank(\n        #   ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);\n        # note that ncclComm_t is a pointer type, so the first argument\n        # is a pointer to a pointer\n        Function(\"ncclCommInitRank\", ncclResult_t, [\n            ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId,\n            ctypes.c_int\n        ]),\n        # ncclResult_t  ncclAllReduce(\n        #   const void* sendbuff, void* recvbuff, size_t count,\n        #   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,\n        #   cudaStream_t stream);\n        # note that cudaStream_t is a pointer type, so the last argument\n        # is a pointer\n        Function(\"ncclAllReduce\", ncclResult_t, [\n            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,\n            ncclRedOp_t, ncclComm_t, cudaStream_t\n        ]),\n\n        # ncclResult_t  ncclSend(\n   
     #   const void* sendbuff, size_t count, ncclDataType_t datatype,\n        #   int dest, ncclComm_t comm, cudaStream_t stream);\n        Function(\"ncclSend\", ncclResult_t, [\n            buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,\n            ncclComm_t, cudaStream_t\n        ]),\n\n        # ncclResult_t  ncclRecv(\n        #   void* recvbuff, size_t count, ncclDataType_t datatype,\n        #   int src, ncclComm_t comm, cudaStream_t stream);\n        Function(\"ncclRecv\", ncclResult_t, [\n            buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,\n            ncclComm_t, cudaStream_t\n        ]),\n\n        # be cautious! this is a collective call, it will block until all\n        # processes in the communicator have called this function.\n        # because Python object destruction can happen in random order,\n        # it is better not to call it at all.\n        # ncclResult_t  ncclCommDestroy(ncclComm_t comm);\n        Function(\"ncclCommDestroy\", ncclResult_t, [ncclComm_t]),\n    ]\n\n    # class attribute to store the mapping from the path to the library\n    # to avoid loading the same library multiple times\n    path_to_library_cache: Dict[str, Any] = {}\n\n    # class attribute to store the mapping from library path\n    #  to the corresponding dictionary\n    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}\n\n    def __init__(self, so_file: Optional[str] = None):\n\n        so_file = so_file or find_nccl_library()\n\n        try:\n            if so_file not in NCCLLibrary.path_to_dict_mapping:\n                lib = ctypes.CDLL(so_file)\n                NCCLLibrary.path_to_library_cache[so_file] = lib\n            self.lib = NCCLLibrary.path_to_library_cache[so_file]\n        except Exception as e:\n            print(\n                \"Failed to load NCCL library from %s .\"\n                \"It is expected if you are not running on NVIDIA/AMD GPUs.\"\n                \"Otherwise, the nccl library might not 
exist, be corrupted \"\n                \"or it does not support the current platform %s. \"\n                \"If you already have the library, please set the \"\n                \"environment variable VLLM_NCCL_SO_PATH\"\n                \" to point to the correct nccl library path.\" %\n                (so_file, platform.platform()))\n            raise e\n\n        if so_file not in NCCLLibrary.path_to_dict_mapping:\n            _funcs: Dict[str, Any] = {}\n            for func in NCCLLibrary.exported_functions:\n                f = getattr(self.lib, func.name)\n                f.restype = func.restype\n                f.argtypes = func.argtypes\n                _funcs[func.name] = f\n            NCCLLibrary.path_to_dict_mapping[so_file] = _funcs\n        self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]\n\n    def ncclGetErrorString(self, result: ncclResult_t) -> str:\n        return self._funcs[\"ncclGetErrorString\"](result).decode(\"utf-8\")\n\n    def NCCL_CHECK(self, result: ncclResult_t) -> None:\n        if result != 0:\n            error_str = self.ncclGetErrorString(result)\n            raise RuntimeError(f\"NCCL error: {error_str}\")\n\n    def ncclGetVersion(self) -> str:\n        version = ctypes.c_int()\n        self.NCCL_CHECK(self._funcs[\"ncclGetVersion\"](ctypes.byref(version)))\n        # something like 21903 --> \"2.19.3\"\n        # (assumes the major * 10000 + minor * 100 + patch encoding;\n        # integer arithmetic keeps a zero patch such as 22000 --> \"2.20.0\")\n        major, rest = divmod(version.value, 10000)\n        minor, patch = divmod(rest, 100)\n        return f\"{major}.{minor}.{patch}\"\n\n    def ncclGetUniqueId(self) -> ncclUniqueId:\n        unique_id = ncclUniqueId()\n        self.NCCL_CHECK(self._funcs[\"ncclGetUniqueId\"](\n            ctypes.byref(unique_id)))\n        return unique_id\n\n    def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,\n                         rank: int) -> ncclComm_t:\n        comm = ncclComm_t()\n        
self.NCCL_CHECK(self._funcs[\"ncclCommInitRank\"](ctypes.byref(comm),\n                                                        world_size, unique_id,\n                                                        rank))\n        return comm\n\n    def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,\n                      count: int, datatype: int, op: int, comm: ncclComm_t,\n                      stream: cudaStream_t) -> None:\n        # `datatype` actually should be `ncclDataType_t`\n        # and `op` should be `ncclRedOp_t`\n        # both are aliases of `ctypes.c_int`\n        # when we pass int to a function, it will be converted to `ctypes.c_int`\n        # by ctypes automatically\n        self.NCCL_CHECK(self._funcs[\"ncclAllReduce\"](sendbuff, recvbuff, count,\n                                                     datatype, op, comm,\n                                                     stream))\n\n    def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,\n                 dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:\n        self.NCCL_CHECK(self._funcs[\"ncclSend\"](sendbuff, count, datatype,\n                                                dest, comm, stream))\n\n    def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int,\n                 src: int, comm: ncclComm_t, stream: cudaStream_t) -> None:\n        self.NCCL_CHECK(self._funcs[\"ncclRecv\"](recvbuff, count, datatype, src,\n                                                comm, stream))\n\n    def ncclCommDestroy(self, comm: ncclComm_t) -> None:\n        self.NCCL_CHECK(self._funcs[\"ncclCommDestroy\"](comm))\n\n\n__all__ = [\n    \"NCCLLibrary\", \"ncclDataTypeEnum\", \"ncclRedOpTypeEnum\", \"ncclUniqueId\",\n    \"ncclComm_t\", \"cudaStream_t\", \"buffer_type\"\n]\n"
  },
  {
    "path": "archive/ktransformers/server/balance_serve/inference/distributed/utils.py",
    "content": "# Copyright 2023 The vLLM team.\n# Adapted from\n# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py\n# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.\nimport dataclasses\nimport pickle\nimport time\nfrom collections import deque\nfrom typing import Any, Deque, Dict, Optional, Sequence, Tuple\n\nimport torch\nfrom torch.distributed import TCPStore\n\nimport server.envs as envs\n\n\ndef ensure_divisibility(numerator, denominator):\n    \"\"\"Ensure that numerator is divisible by the denominator.\"\"\"\n    assert numerator % denominator == 0, \"{} is not divisible by {}\".format(\n        numerator, denominator\n    )\n\n\ndef divide(numerator, denominator):\n    \"\"\"Ensure that numerator is divisible by the denominator and return\n    the division value.\"\"\"\n    ensure_divisibility(numerator, denominator)\n    return numerator // denominator\n\n\ndef split_tensor_along_last_dim(\n    tensor: torch.Tensor,\n    num_partitions: int,\n    contiguous_split_chunks: bool = False,\n) -> Sequence[torch.Tensor]:\n    \"\"\"Split a tensor along its last dimension.\n\n    Arguments:\n        tensor: input tensor.\n        num_partitions: number of partitions to split the tensor\n        contiguous_split_chunks: If True, make each chunk contiguous\n                                 in memory.\n\n    Returns:\n        A list of Tensors\n    \"\"\"\n    # Get the size and dimension.\n    last_dim = tensor.dim() - 1\n    last_dim_size = divide(tensor.size()[last_dim], num_partitions)\n    # Split.\n    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)\n    # NOTE: torch.split does not create contiguous tensors by default.\n    if contiguous_split_chunks:\n        return tuple(chunk.contiguous() for chunk in tensor_list)\n\n    return tensor_list\n\n\ndef get_pp_indices(\n    num_hidden_layers: int, pp_rank: int, pp_size: int\n) -> Tuple[int, int]:\n    \"\"\"Try to evenly distribute layers across 
partitions.\n    If the number of layers is not divisible by the number of partitions,\n    the last partition will have the remaining layers.\n    \"\"\"\n    partition_list_str = envs.VLLM_PP_LAYER_PARTITION\n    if partition_list_str is not None:\n        try:\n            partitions = [int(layer) for layer in partition_list_str.split(\",\")]\n        except ValueError as err:\n            raise ValueError(\n                \"Invalid partition string: {}\".format(partition_list_str)\n            ) from err\n        if len(partitions) != pp_size:\n            raise ValueError(f\"{len(partitions)=} does not match {pp_size=}.\")\n        if sum(partitions) != num_hidden_layers:\n            raise ValueError(f\"{sum(partitions)=} does not match {num_hidden_layers=}.\")\n        start_layer = sum(partitions[:pp_rank])\n        end_layer = start_layer + partitions[pp_rank]\n    else:\n        layers_per_partition = num_hidden_layers // pp_size\n        start_layer = pp_rank * layers_per_partition\n        end_layer = start_layer + layers_per_partition\n\n        if pp_rank == pp_size - 1:\n            end_layer = num_hidden_layers\n\n    return (start_layer, end_layer)\n\n\n@dataclasses.dataclass\nclass StatelessProcessGroup:\n    \"\"\"A dataclass to hold a metadata store, and the rank, world_size of the\n    group. 
Only use it to communicate metadata between processes.\n    For data-plane communication, create NCCL-related objects.\n    \"\"\"\n\n    rank: int\n    world_size: int\n    store: torch._C._distributed_c10d.Store\n    data_expiration_seconds: int = 3600  # 1 hour\n\n    # dst rank -> counter\n    send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict)\n    # src rank -> counter\n    recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)\n    broadcast_send_counter: int = 0\n    broadcast_recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)\n\n    # A deque to store the data entries, with key and timestamp.\n    entries: Deque[Tuple[str, float]] = dataclasses.field(default_factory=deque)\n\n    def __post_init__(self):\n        assert self.rank < self.world_size\n        self.send_dst_counter = {i: 0 for i in range(self.world_size)}\n        self.recv_src_counter = {i: 0 for i in range(self.world_size)}\n        self.broadcast_recv_src_counter = {i: 0 for i in range(self.world_size)}\n\n    def send_obj(self, obj: Any, dst: int):\n        \"\"\"Send an object to a destination rank.\"\"\"\n        self.expire_data()\n        key = f\"send_to/{dst}/{self.send_dst_counter[dst]}\"\n        self.store.set(key, pickle.dumps(obj))\n        self.send_dst_counter[dst] += 1\n        self.entries.append((key, time.time()))\n\n    def expire_data(self):\n        \"\"\"Expire data that is older than `data_expiration_seconds` seconds.\"\"\"\n        while self.entries:\n            # check the oldest entry\n            key, timestamp = self.entries[0]\n            if time.time() - timestamp > self.data_expiration_seconds:\n                self.store.delete_key(key)\n                self.entries.popleft()\n            else:\n                break\n\n    def recv_obj(self, src: int) -> Any:\n        \"\"\"Receive an object from a source rank.\"\"\"\n        obj = pickle.loads(\n            
self.store.get(f\"send_to/{self.rank}/{self.recv_src_counter[src]}\")\n        )\n        self.recv_src_counter[src] += 1\n        return obj\n\n    def broadcast_obj(self, obj: Optional[Any], src: int) -> Any:\n        \"\"\"Broadcast an object from a source rank to all other ranks.\n        It does not clean up after all ranks have received the object.\n        Use it for limited times, e.g., for initialization.\n        \"\"\"\n        if self.rank == src:\n            self.expire_data()\n            key = f\"broadcast_from/{src}/\" f\"{self.broadcast_send_counter}\"\n            self.store.set(key, pickle.dumps(obj))\n            self.broadcast_send_counter += 1\n            self.entries.append((key, time.time()))\n            return obj\n        else:\n            key = f\"broadcast_from/{src}/\" f\"{self.broadcast_recv_src_counter[src]}\"\n            recv_obj = pickle.loads(self.store.get(key))\n            self.broadcast_recv_src_counter[src] += 1\n            return recv_obj\n\n    def all_gather_obj(self, obj: Any) -> list[Any]:\n        \"\"\"All gather an object from all ranks.\"\"\"\n        gathered_objs = []\n        for i in range(self.world_size):\n            if i == self.rank:\n                gathered_objs.append(obj)\n                self.broadcast_obj(obj, src=self.rank)\n            else:\n                recv_obj = self.broadcast_obj(None, src=i)\n                gathered_objs.append(recv_obj)\n        return gathered_objs\n\n    def barrier(self):\n        \"\"\"A barrier to synchronize all ranks.\"\"\"\n        for i in range(self.world_size):\n            if i == self.rank:\n                self.broadcast_obj(None, src=self.rank)\n            else:\n                self.broadcast_obj(None, src=i)\n\n    @staticmethod\n    def create(\n        host: str,\n        port: int,\n        rank: int,\n        world_size: int,\n        data_expiration_seconds: int = 3600,\n    ) -> \"StatelessProcessGroup\":\n        \"\"\"A replacement for 
`torch.distributed.init_process_group` that does not\n        pollute the global state.\n\n        If we have process A and process B called `torch.distributed.init_process_group`\n        to form a group, and then we want to form another group with process A, B, C,\n        D, it is not possible in PyTorch, because process A and process B have already\n        formed a group, and process C and process D cannot join that group. This\n        function is a workaround for this issue.\n\n        `torch.distributed.init_process_group` is a global call, while this function\n        is a stateless call. It will return a `StatelessProcessGroup` object that can be\n        used for exchanging metadata. With this function, process A and process B\n        can call `StatelessProcessGroup.create` to form a group, and then process A, B,\n        C, and D can call `StatelessProcessGroup.create` to form another group.\n        \"\"\"  # noqa\n        store = TCPStore(\n            host_name=host,\n            port=port,\n            world_size=world_size,\n            is_master=(rank == 0),\n        )\n\n        return StatelessProcessGroup(\n            rank=rank,\n            world_size=world_size,\n            store=store,\n            data_expiration_seconds=data_expiration_seconds,\n        )\n"
  },
  {
    "path": "archive/ktransformers/server/balance_serve/inference/forward_batch.py",
    "content": "'''\nDate: 2024-11-12 14:15:16\nLastEditors: Xie Weiyu ervinxie@qq.com\nLastEditTime: 2024-11-26 08:12:49\n'''\nimport torch\ntry:\n    import torch_npu\n    use_torch_npu = torch_npu.npu.is_available()\nexcept:\n    use_torch_npu = False\nfrom ktransformers.server.balance_serve.settings import sched_ext\nfrom ktransformers.server.balance_serve.inference.query_manager import QueryManager, QueryInfo\nfrom typing import Union\nimport time\nfrom ktransformers.server.config.config import Config\n\nclass ForwardMiniBatchCombine:\n    q_indptr: torch.Tensor\n    kv_indptr: torch.Tensor\n    kv_indices: torch.Tensor\n    kv_last_page_len: torch.Tensor\n    kv_len: torch.Tensor\n    position_ids: torch.Tensor\n    tokens: torch.Tensor\n    batch_indices: torch.Tensor\n    positions: torch.Tensor\n    chunk_size: int\n    decode_batch: int        \n    is_last_prefill_chunk: bool\n    logits_start: list\n\n    temperatures: torch.Tensor\n    top_ps: torch.Tensor\n\n    def __init__(self, prefill_querys_info: list[QueryInfo], decode_querys_info: list[QueryInfo], prefill_s: list[int] = None, prefill_l: list[int] = None, device = torch.device('cuda'), page_size = 256):\n        batch_decode = len(decode_querys_info)\n        batch_prefill = len(prefill_querys_info)\n\n        self.q_indptr = torch.tensor([0], device=device, dtype=torch.int32)\n        self.kv_indptr = torch.tensor([0], device=device, dtype=torch.int32)\n        self.kv_indices = torch.tensor([], device=device, dtype=torch.int32)\n        self.kv_len = torch.tensor([], device=device, dtype=torch.int32)\n        self.kv_last_page_len = torch.tensor([], device=device, dtype=torch.int32)\n        self.position_ids = torch.tensor([], device=device, dtype=torch.int32)\n        self.tokens = torch.tensor([], device=device, dtype=torch.int32)\n\n        self.temperatures = torch.tensor([], device=device, dtype=torch.float32)\n        self.top_ps = torch.tensor([], device=device, 
dtype=torch.float32)\n\n        self.logits_start = []\n        self.decode_batch = batch_decode\n        self.num_tokens = batch_decode + sum(prefill_l)\n        self.batch_size = batch_decode + batch_prefill\n        \n        for i, prefill_query_info in enumerate(prefill_querys_info):\n            if prefill_query_info != None:\n                prefill_kv_block_len = (prefill_query_info.active_position + prefill_l[i] + page_size - 1) // page_size if prefill_query_info is not None else 0\n                # print(f\"block_len: {prefill_kv_block_len}, page_size: {page_size}\")\n                self.q_indptr = torch.concat((self.q_indptr, torch.tensor([prefill_l[i] + self.q_indptr[-1]], device=device, dtype=torch.int32)), dim=0)\n                self.kv_indptr = torch.concat((self.kv_indptr, torch.tensor([prefill_kv_block_len + self.kv_indptr[-1]], device=device, dtype=torch.int32)), dim=0)\n                self.kv_indices = torch.concat((self.kv_indices, prefill_query_info.block_index[:prefill_kv_block_len]), dim=0)\n                self.kv_last_page_len = torch.concat((self.kv_last_page_len, torch.tensor([(prefill_query_info.active_position + prefill_l[i]) % page_size if (prefill_query_info.active_position + prefill_l[i]) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0)\n                self.kv_len = torch.concat((self.kv_len, torch.tensor([(prefill_query_info.active_position + prefill_l[i])], device=device, dtype=torch.int32)), dim=0)\n                self.position_ids = torch.concat((self.position_ids, torch.arange(prefill_s[i], prefill_l[i] + prefill_s[i], device=device, dtype=torch.int32)), dim=0)\n                self.tokens = torch.concat((self.tokens, prefill_query_info.query_tokens[prefill_s[i]:prefill_s[i] + prefill_l[i]]), dim=0)\n                self.logits_start.append(prefill_l[i] - 1 if len(self.logits_start) == 0 else sum(prefill_l[:i+1])-1)\n\n                self.temperatures = torch.concat((self.temperatures, 
torch.tensor([prefill_query_info.temperature], device=device, dtype=torch.float32)), dim=0)\n                self.top_ps = torch.concat((self.top_ps, torch.tensor([prefill_query_info.top_p], device=device, dtype=torch.float32)), dim=0)\n\n        for decode_query_info in decode_querys_info:\n            decode_kv_block_len = (decode_query_info.active_position + 1 + page_size - 1) // page_size\n            self.q_indptr = torch.concat((self.q_indptr, torch.tensor([1 + self.q_indptr[-1]], device=device, dtype=torch.int32)), dim=0)\n            self.kv_indptr = torch.concat((self.kv_indptr, torch.tensor([decode_kv_block_len+self.kv_indptr[-1]], device=device, dtype=torch.int32)), dim=0)\n            self.kv_indices = torch.concat((self.kv_indices, decode_query_info.block_index[:decode_kv_block_len]), dim=0)\n            self.kv_last_page_len = torch.concat((self.kv_last_page_len, torch.tensor([(decode_query_info.active_position + 1) % page_size if (decode_query_info.active_position + 1) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0)\n            self.kv_len = torch.concat((self.kv_len, torch.tensor([(decode_query_info.active_position + 1)], device=device, dtype=torch.int32)), dim=0)\n            self.position_ids = torch.concat((self.position_ids, torch.arange(decode_query_info.active_position, decode_query_info.active_position + 1, device=device, dtype=torch.int32)), dim=0)\n            if decode_query_info.active_position > 0:\n                self.tokens = torch.concat((self.tokens, decode_query_info.query_tokens[decode_query_info.active_position:decode_query_info.active_position+1]), dim=0)\n            else: \n                self.tokens = torch.concat((self.tokens, torch.tensor([0], device=device, dtype=torch.int32)), dim=0)\n            self.logits_start.append(0 if len(self.logits_start) == 0 else self.logits_start[-1]+1)\n\n            self.temperatures = torch.concat((self.temperatures, 
torch.tensor([decode_query_info.temperature], device=device, dtype=torch.float32)), dim=0)\n            self.top_ps = torch.concat((self.top_ps, torch.tensor([decode_query_info.top_p], device=device, dtype=torch.float32)), dim=0)\n\n        self.q_indptr = self.q_indptr.contiguous()\n        self.kv_indptr = self.kv_indptr.contiguous()\n        self.kv_indices = self.kv_indices.contiguous()\n        self.kv_len = self.kv_len.contiguous()\n        self.kv_last_page_len = self.kv_last_page_len.contiguous()\n        self.position_ids = self.position_ids.contiguous()\n        self.tokens = self.tokens.contiguous()\n\n        self.bsz_tensor = torch.tensor([self.batch_size], device=device, dtype=torch.int32)\n\n    def fill(self, prefill_querys_info: list[QueryInfo], decode_querys_info: list[QueryInfo], prefill_s: list[int] = None, prefill_l: list[int] = None, device = torch.device('cuda'), page_size = 256):\n        batch_decode = len(decode_querys_info)\n        batch_prefill = len(prefill_querys_info)\n\n        self.q_indptr = torch.tensor([0], device=device, dtype=torch.int32)\n        self.kv_indptr = torch.tensor([0], device=device, dtype=torch.int32)\n        self.kv_indices = torch.tensor([], device=device, dtype=torch.int32)\n        self.kv_len = torch.tensor([], device=device, dtype=torch.int32)\n        self.kv_last_page_len = torch.tensor([], device=device, dtype=torch.int32)\n        new_position_ids = torch.tensor([], device=device, dtype=torch.int32)\n        new_tokens = torch.tensor([], device=device, dtype=torch.int32)\n\n        self.temperatures = torch.tensor([], device=device, dtype=torch.float32)\n        self.top_ps = torch.tensor([], device=device, dtype=torch.float32)\n\n        self.logits_start = []\n        self.decode_batch = batch_decode\n        self.num_tokens = batch_decode + sum(prefill_l)\n        self.batch_size = batch_decode + batch_prefill\n\n        for i, prefill_query_info in enumerate(prefill_querys_info):\n            
prefill_kv_block_len = (prefill_query_info.active_position + prefill_l[i] + page_size - 1) // page_size if prefill_query_info is not None else 0\n        # print(f\"block_len: {prefill_kv_block_len}, page_size: {page_size}\")\n            self.q_indptr = torch.concat((self.q_indptr, torch.tensor([prefill_l[i] + self.q_indptr[-1]], device=device, dtype=torch.int32)), dim=0)\n            self.kv_indptr = torch.concat((self.kv_indptr, torch.tensor([prefill_kv_block_len + self.kv_indptr[-1]], device=device, dtype=torch.int32)), dim=0)\n            self.kv_indices = torch.concat((self.kv_indices, prefill_query_info.block_index[:prefill_kv_block_len]), dim=0)\n            self.kv_last_page_len = torch.concat((self.kv_last_page_len, torch.tensor([(prefill_query_info.active_position + prefill_l[i]) % page_size if (prefill_query_info.active_position + prefill_l[i]) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0)\n            self.kv_len = torch.concat((self.kv_len, torch.tensor([(prefill_query_info.active_position + prefill_l[i])], device=device, dtype=torch.int32)), dim=0)\n            new_position_ids = torch.concat((new_position_ids, torch.arange(prefill_s[i], prefill_l[i] + prefill_s[i], device=device, dtype=torch.int32)), dim=0)\n            new_tokens = torch.concat((new_tokens, prefill_query_info.query_tokens[prefill_s[i]:prefill_s[i] + prefill_l[i]]), dim=0)\n            self.logits_start.append(prefill_l[i] - 1 if len(self.logits_start) == 0 else sum(prefill_l[:i+1])-1)\n\n            self.temperatures = torch.concat((self.temperatures, torch.tensor([prefill_query_info.temperature], device=device, dtype=torch.float32)), dim=0)\n            self.top_ps = torch.concat((self.top_ps, torch.tensor([prefill_query_info.top_p], device=device, dtype=torch.float32)), dim=0)\n\n\n        for decode_query_info in decode_querys_info:\n            decode_kv_block_len = (decode_query_info.active_position + 1 + page_size - 1) // page_size\n            
self.q_indptr = torch.concat((self.q_indptr, torch.tensor([1 + self.q_indptr[-1]], device=device, dtype=torch.int32)), dim=0)\n            self.kv_indptr = torch.concat((self.kv_indptr, torch.tensor([decode_kv_block_len+self.kv_indptr[-1]], device=device, dtype=torch.int32)), dim=0)\n            self.kv_indices = torch.concat((self.kv_indices, decode_query_info.block_index[:decode_kv_block_len]), dim=0)\n            self.kv_last_page_len = torch.concat((self.kv_last_page_len, torch.tensor([(decode_query_info.active_position + 1) % page_size if (decode_query_info.active_position + 1) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0)\n            self.kv_len = torch.concat((self.kv_len, torch.tensor([(decode_query_info.active_position + 1)], device=device, dtype=torch.int32)), dim=0)\n            new_position_ids = torch.concat((new_position_ids, torch.arange(decode_query_info.active_position, decode_query_info.active_position + 1, device=device, dtype=torch.int32)), dim=0)\n            if decode_query_info.active_position > 0:\n                new_tokens = torch.concat((new_tokens, decode_query_info.query_tokens[decode_query_info.active_position:decode_query_info.active_position+1]), dim=0)\n            else: \n                new_tokens = torch.concat((new_tokens, torch.tensor([0], device=device, dtype=torch.int32)), dim=0)\n            self.logits_start.append(0 if len(self.logits_start) == 0 else self.logits_start[-1]+1)\n\n            self.temperatures = torch.concat((self.temperatures, torch.tensor([decode_query_info.temperature], device=device, dtype=torch.float32)), dim=0)\n            self.top_ps = torch.concat((self.top_ps, torch.tensor([decode_query_info.top_p], device=device, dtype=torch.float32)), dim=0)\n\n\n        self.q_indptr = self.q_indptr.contiguous()\n        self.kv_indptr = self.kv_indptr.contiguous()\n        self.kv_indices = self.kv_indices.contiguous()\n        self.kv_len = self.kv_len.contiguous()\n        
self.kv_last_page_len = self.kv_last_page_len.contiguous()\n\n        self.bsz_tensor = torch.tensor([self.batch_size], device=device, dtype=torch.int32)\n        \n        # copy new_position_ids and new_tokens to self.position_ids and self.tokens\n        # print(\"new_position_ids: \", new_position_ids)\n        # self.print()\n        self.position_ids[:new_position_ids.size(0)].copy_(new_position_ids)\n        self.position_ids[new_position_ids.size(0):].zero_()\n        self.tokens[:new_tokens.size(0)].copy_(new_tokens)\n\n    def __str__(self):\n        ret = ''\n        ret += f'=====flash infer forward info:\\n'\n        ret += f'q_indptr: {self.q_indptr}, kv_indptr: {self.kv_indptr}, kv_indices: {self.kv_indices}\\n'\n        ret += f'kv_len: {self.kv_len}, kv_last_page_len: {self.kv_last_page_len}, bsz_tensor: {self.bsz_tensor}\\n'\n        ret += f'position_ids: {self.position_ids}, tokens: {self.tokens}\\n'\n        return ret\n\n\nclass ForwardMiniBatchSplit:\n    # NPU flow: prefill and decode are packed separately\n    prefill_batch: int\n    p_q_len: torch.Tensor               # (bsz)\n    p_kv_len: torch.Tensor              # (bsz)\n    p_position_ids: torch.Tensor        # (sum(q_len))\n    p_tokens: torch.Tensor              # (sum(q_len))\n    p_temperatures: torch.Tensor        # (bsz)\n    p_top_ps: torch.Tensor              # (bsz)\n    p_block_tables: torch.Tensor        # (bsz, max_page_num)\n    p_logits_start: list\n\n    decode_batch: int\n    d_q_len: torch.Tensor\n    d_kv_len: torch.Tensor\n    d_position_ids: torch.Tensor\n    d_tokens: torch.Tensor\n    d_temperatures: torch.Tensor\n    d_top_ps: torch.Tensor\n    d_block_tables: torch.Tensor        # (bsz, max_page_num)\n    d_logits_start: list\n\n    chunk_size: int\n    is_last_prefill_chunk: bool\n\n    def __init__(\n        self,\n        prefill_querys_info: list[QueryInfo],\n        decode_querys_info: list[QueryInfo],\n        prefill_s: list[int] = None,\n        prefill_l: list[int] = 
None,\n        device=None,\n        page_size: int = 256,\n        max_page_num: int = 64,\n        decode_padding_len: int = 1,\n    ):\n        # Always use the NPU device\n        device = torch.device('npu')\n\n        if prefill_s is None or prefill_l is None:\n            raise ValueError(\n                \"[ForwardMiniBatchSplit.__init__] prefill_s / prefill_l must not be None; chunked prefill requires both arguments\"\n            )\n\n        # Filter out None entries\n        new_prefill_querys_info: list[QueryInfo] = [\n            info for info in prefill_querys_info if info is not None\n        ]\n        batch_prefill = len(new_prefill_querys_info)\n        batch_decode = len(decode_querys_info)\n\n        self.prefill_batch = batch_prefill\n        self.decode_batch = batch_decode\n        self.batch_size = batch_prefill + batch_decode\n        self.num_tokens = batch_decode * decode_padding_len + sum(prefill_l)\n\n        self.chunk_size = prefill_l[0] if prefill_l else 0\n\n        self.is_last_prefill_chunk = True\n        for i, q in enumerate(new_prefill_querys_info):\n            end_pos = prefill_s[i] + prefill_l[i]\n            if end_pos < q.query_length:\n                self.is_last_prefill_chunk = False\n                break\n\n        # ====================== Prefill ======================\n        self.p_q_len = torch.tensor([], device=device, dtype=torch.int32)\n        self.p_kv_len = torch.tensor([], device=device, dtype=torch.int32)\n        self.p_position_ids = torch.tensor([], device=device, dtype=torch.int32)\n        self.p_block_tables = -1 * torch.ones(\n            [self.prefill_batch, max_page_num], device=device, dtype=torch.int32\n        )\n        self.p_tokens = torch.tensor([], device=device, dtype=torch.int32)\n\n        self.p_temperatures = torch.tensor([], device=device, dtype=torch.float32)\n        self.p_top_ps = torch.tensor([], device=device, dtype=torch.float32)\n        self.p_logits_start: list[int] = []\n\n        for i, prefill_query_info in 
enumerate(new_prefill_querys_info):\n            qid = getattr(prefill_query_info, \"id\", -1)\n\n            past_len = int(prefill_query_info.active_position)\n            start = int(prefill_s[i])                            # current chunk's start position in query_tokens\n            chunk_len = int(prefill_l[i])\n            kv_len = past_len + chunk_len\n            prefill_kv_block_len = (kv_len + page_size - 1) // page_size\n\n            # Q length = current chunk length\n            self.p_q_len = torch.concat(\n                (\n                    self.p_q_len,\n                    torch.tensor([chunk_len], device=device, dtype=torch.int32),\n                ),\n                dim=0,\n            )\n            self.p_kv_len = torch.concat(\n                (\n                    self.p_kv_len,\n                    torch.tensor([kv_len], device=device, dtype=torch.int32),\n                ),\n                dim=0,\n            )\n\n            self.p_block_tables[i, :prefill_kv_block_len] = prefill_query_info.block_index[\n                :prefill_kv_block_len\n            ]\n\n            self.p_position_ids = torch.concat(\n                (\n                    self.p_position_ids,\n                    torch.arange(\n                        start,\n                        start + chunk_len,\n                        device=device,\n                        dtype=torch.int32,\n                    ),\n                ),\n                dim=0,\n            )\n\n            self.p_tokens = torch.concat(\n                (\n                    self.p_tokens,\n                    prefill_query_info.query_tokens[start : start + chunk_len],\n                ),\n                dim=0,\n            )\n\n            self.p_logits_start.append(\n                chunk_len - 1\n                if len(self.p_logits_start) == 0\n                else sum(prefill_l[: i + 1]) - 1\n            )\n\n            self.p_temperatures = torch.concat(\n                (\n   
                 self.p_temperatures,\n                    torch.tensor(\n                        [prefill_query_info.temperature],\n                        device=device,\n                        dtype=torch.float32,\n                    ),\n                ),\n                dim=0,\n            )\n            self.p_top_ps = torch.concat(\n                (\n                    self.p_top_ps,\n                    torch.tensor(\n                        [prefill_query_info.top_p],\n                        device=device,\n                        dtype=torch.float32,\n                    ),\n                ),\n                dim=0,\n            )\n\n        # ====================== Decode ======================\n        self.d_q_len = torch.tensor([], device=device, dtype=torch.int32)\n        self.d_kv_len = torch.tensor([], device=device, dtype=torch.int32)\n        self.d_position_ids = torch.tensor([], device=device, dtype=torch.int32)\n        self.d_block_tables = -1 * torch.ones(\n            [self.decode_batch, max_page_num], device=device, dtype=torch.int32\n        )\n        self.d_tokens = torch.tensor([], device=device, dtype=torch.int32)\n\n        self.d_temperatures = torch.tensor([], device=device, dtype=torch.float32)\n        self.d_top_ps = torch.tensor([], device=device, dtype=torch.float32)\n        self.d_logits_start: list[int] = []\n\n        for i, decode_query_info in enumerate(decode_querys_info):\n            qid = getattr(decode_query_info, \"id\", -1)\n            past_len = int(decode_query_info.active_position)\n            decode_kv_block_len = (past_len + decode_padding_len + page_size - 1) // page_size\n\n            self.d_q_len = torch.concat(\n                (\n                    self.d_q_len,\n                    torch.tensor(\n                        [decode_padding_len], device=device, dtype=torch.int32\n                    ),\n                ),\n                dim=0,\n            )\n            self.d_kv_len = 
torch.concat(\n                (\n                    self.d_kv_len,\n                    torch.tensor(\n                        [past_len + decode_padding_len],\n                        device=device,\n                        dtype=torch.int32,\n                    ),\n                ),\n                dim=0,\n            )\n\n            self.d_block_tables[i, :decode_kv_block_len] = decode_query_info.block_index[\n                :decode_kv_block_len\n            ]\n\n            self.d_position_ids = torch.concat(\n                (\n                    self.d_position_ids,\n                    torch.arange(\n                        past_len,\n                        past_len + decode_padding_len,\n                        device=device,\n                        dtype=torch.int32,\n                    ),\n                ),\n                dim=0,\n            )\n\n            if past_len > 0:\n                self.d_tokens = torch.concat(\n                    (\n                        self.d_tokens,\n                        decode_query_info.query_tokens[\n                            past_len : past_len + decode_padding_len\n                        ],\n                    ),\n                    dim=0,\n                )\n            else:\n                self.d_tokens = torch.concat(\n                    (\n                        self.d_tokens,\n                        torch.tensor(\n                            [0] * decode_padding_len,\n                            device=device,\n                            dtype=torch.int32,\n                        ),\n                    ),\n                    dim=0,\n                )\n\n            self.d_logits_start.append(\n                0\n                if len(self.d_logits_start) == 0\n                else self.d_logits_start[-1] + decode_padding_len\n            )\n\n            self.d_temperatures = torch.concat(\n                (\n                    self.d_temperatures,\n                    
torch.tensor(\n                        [decode_query_info.temperature],\n                        device=device,\n                        dtype=torch.float32,\n                    ),\n                ),\n                dim=0,\n            )\n            self.d_top_ps = torch.concat(\n                (\n                    self.d_top_ps,\n                    torch.tensor(\n                        [decode_query_info.top_p],\n                        device=device,\n                        dtype=torch.float32,\n                    ),\n                ),\n                dim=0,\n            )\n\n        self.p_q_len = self.p_q_len.contiguous()\n        self.p_kv_len = self.p_kv_len.contiguous()\n        self.p_block_tables = self.p_block_tables.contiguous()\n        self.p_position_ids = self.p_position_ids.contiguous()\n        self.p_tokens = self.p_tokens.contiguous()\n\n        if self.decode_batch > 1:\n            self.d_q_len = self.d_q_len.reshape(self.decode_batch, -1).contiguous()\n            self.d_kv_len = self.d_kv_len.reshape(self.decode_batch, -1).contiguous()\n            self.d_kv_len_list = self.d_kv_len.flatten().tolist()\n            self.d_block_tables = self.d_block_tables.contiguous()\n            self.d_position_ids = self.d_position_ids.reshape(self.decode_batch, -1).contiguous()\n            self.d_tokens = self.d_tokens.reshape(self.decode_batch, -1).contiguous()\n        else:\n            self.d_q_len = self.d_q_len.contiguous()\n            self.d_kv_len = self.d_kv_len.contiguous()\n            self.d_kv_len_list = self.d_kv_len.flatten().tolist()\n            self.d_block_tables = self.d_block_tables.contiguous()\n            self.d_position_ids = self.d_position_ids.contiguous()\n            self.d_tokens = self.d_tokens.contiguous()\n\n        self.bsz_tensor = torch.tensor([self.batch_size], device=device, dtype=torch.int32)\n\n\n    def fill(\n        self,\n        prefill_querys_info: list[QueryInfo],\n        decode_querys_info: 
list[QueryInfo],\n        prefill_s: list[int] = None,\n        prefill_l: list[int] = None,\n        decode_padding_len: int = 1,\n        device=None,\n        page_size: int = 256,\n        max_page_num: int = 64,\n    ):\n        device = torch.device('npu')\n\n        if prefill_s is None or prefill_l is None:\n            raise ValueError(\n                \"[ForwardMiniBatchSplit.fill] prefill_s / prefill_l must not be None; chunk prefill requires both parameters\"\n            )\n\n        page_size = 128\n\n        new_prefill_querys_info: list[QueryInfo] = [\n            info for info in prefill_querys_info if info is not None\n        ]\n        batch_prefill = len(new_prefill_querys_info)\n        batch_decode = len(decode_querys_info)\n\n        self.prefill_batch = batch_prefill\n        self.decode_batch = batch_decode\n        self.batch_size = batch_prefill + batch_decode\n        self.num_tokens = batch_decode * decode_padding_len + sum(prefill_l)\n\n        self.chunk_size = prefill_l[0] if prefill_l else 0\n        self.is_last_prefill_chunk = True\n        for i, q in enumerate(new_prefill_querys_info):\n            end_pos = prefill_s[i] + prefill_l[i]\n            if end_pos < q.query_length:\n                self.is_last_prefill_chunk = False\n                break\n\n        # ---------- Prefill ----------\n        self.p_q_len = torch.tensor([], device=device, dtype=torch.int32)\n        self.p_kv_len = torch.tensor([], device=device, dtype=torch.int32)\n        new_p_position_ids = torch.tensor([], device=device, dtype=torch.int32)\n        self.p_block_tables = torch.zeros(\n            [self.prefill_batch, max_page_num], device=device, dtype=torch.int32\n        )\n        new_p_tokens = torch.tensor([], device=device, dtype=torch.int32)\n\n        self.p_temperatures = torch.tensor([], device=device, dtype=torch.float32)\n        self.p_top_ps = torch.tensor([], device=device, dtype=torch.float32)\n        self.p_logits_start = []\n\n        for i, 
prefill_query_info in enumerate(new_prefill_querys_info):\n            qid = getattr(prefill_query_info, \"id\", -1)\n            past_len = int(prefill_query_info.active_position)\n            start = int(prefill_s[i])\n            chunk_len = int(prefill_l[i])\n\n            kv_len = past_len + chunk_len\n            prefill_kv_block_len = (kv_len + page_size - 1) // page_size\n\n            self.p_q_len = torch.concat(\n                (\n                    self.p_q_len,\n                    torch.tensor([chunk_len], device=device, dtype=torch.int32),\n                ),\n                dim=0,\n            )\n            self.p_kv_len = torch.concat(\n                (\n                    self.p_kv_len,\n                    torch.tensor([kv_len], device=device, dtype=torch.int32),\n                ),\n                dim=0,\n            )\n            self.p_block_tables[i, :prefill_kv_block_len] = prefill_query_info.block_index[\n                :prefill_kv_block_len\n            ]\n\n            new_p_position_ids = torch.concat(\n                (\n                    new_p_position_ids,\n                    torch.arange(\n                        start,\n                        start + chunk_len,\n                        device=device,\n                        dtype=torch.int32,\n                    ),\n                ),\n                dim=0,\n            )\n            new_p_tokens = torch.concat(\n                (\n                    new_p_tokens,\n                    prefill_query_info.query_tokens[start : start + chunk_len],\n                ),\n                dim=0,\n            )\n\n            self.p_logits_start.append(\n                chunk_len - 1 if len(self.p_logits_start) == 0 else sum(prefill_l[: i + 1]) - 1\n            )\n\n            self.p_temperatures = torch.concat(\n                (\n                    self.p_temperatures,\n                    torch.tensor(\n                        [prefill_query_info.temperature],\n          
              device=device,\n                        dtype=torch.float32,\n                    ),\n                ),\n                dim=0,\n            )\n            self.p_top_ps = torch.concat(\n                (\n                    self.p_top_ps,\n                    torch.tensor(\n                        [prefill_query_info.top_p],\n                        device=device,\n                        dtype=torch.float32,\n                    ),\n                ),\n                dim=0,\n            )\n\n        if new_p_position_ids.numel() > 0:\n            self.p_position_ids = new_p_position_ids.contiguous()\n        if new_p_tokens.numel() > 0:\n            self.p_tokens = new_p_tokens.contiguous()\n\n        # ---------- Decode ----------\n        self.d_q_len = torch.zeros(\n            [1] * self.decode_batch, device=device, dtype=torch.int32\n        )\n        self.d_kv_len = torch.tensor([], device=device, dtype=torch.int32)\n        new_d_position_ids = torch.tensor([], device=device, dtype=torch.int32)\n        new_d_block_tables = -1 * torch.ones(\n            [self.decode_batch, max_page_num], device=device, dtype=torch.int32\n        )\n        new_d_tokens = torch.tensor([], device=device, dtype=torch.int32)\n\n        self.d_logits_start = []\n        self.d_temperatures = torch.tensor([], device=device, dtype=torch.float32)\n        self.d_top_ps = torch.tensor([], device=device, dtype=torch.float32)\n\n        for i, decode_query_info in enumerate(decode_querys_info):\n            qid = getattr(decode_query_info, \"id\", -1)\n            past_len = int(decode_query_info.active_position)\n            decode_kv_block_len = (past_len + decode_padding_len + page_size - 1) // page_size\n\n            self.d_kv_len = torch.concat(\n                (\n                    self.d_kv_len,\n                    torch.tensor(\n                        [past_len + decode_padding_len],\n                        device=device,\n                        
dtype=torch.int32,\n                    ),\n                ),\n                dim=0,\n            )\n            new_d_block_tables[i, :decode_kv_block_len] = decode_query_info.block_index[\n                :decode_kv_block_len\n            ]\n\n            new_d_position_ids = torch.concat(\n                (\n                    new_d_position_ids,\n                    torch.arange(\n                        past_len,\n                        past_len + decode_padding_len,\n                        device=device,\n                        dtype=torch.int32,\n                    ),\n                ),\n                dim=0,\n            )\n\n            if past_len > 0:\n                new_d_tokens = torch.concat(\n                    (\n                        new_d_tokens,\n                        decode_query_info.query_tokens[\n                            past_len : past_len + decode_padding_len\n                        ],\n                    ),\n                    dim=0,\n                )\n            else:\n                new_d_tokens = torch.concat(\n                    (\n                        new_d_tokens,\n                        torch.tensor(\n                            [0] * decode_padding_len,\n                            device=device,\n                            dtype=torch.int32,\n                        ),\n                    ),\n                    dim=0,\n                )\n\n            self.d_logits_start.append(\n                0\n                if len(self.d_logits_start) == 0\n                else self.d_logits_start[-1] + decode_padding_len\n            )\n\n            self.d_temperatures = torch.concat(\n                (\n                    self.d_temperatures,\n                    torch.tensor(\n                        [decode_query_info.temperature],\n                        device=device,\n                        dtype=torch.float32,\n                    ),\n                ),\n                dim=0,\n            )\n     
       self.d_top_ps = torch.concat(\n                (\n                    self.d_top_ps,\n                    torch.tensor(\n                        [decode_query_info.top_p],\n                        device=device,\n                        dtype=torch.float32,\n                    ),\n                ),\n                dim=0,\n            )\n\n            if len(decode_querys_info) > 1:\n                self.d_position_ids[i].copy_(new_d_position_ids[i])\n                self.d_tokens[i].copy_(new_d_tokens[i])\n                self.d_block_tables[i].copy_(new_d_block_tables[i])\n            else:\n                self.d_position_ids[:new_d_position_ids.size(0)].copy_(new_d_position_ids)\n                self.d_tokens[:new_d_tokens.size(0)].copy_(new_d_tokens)\n                self.d_block_tables[0].copy_(new_d_block_tables[0])\n\n\n        self.p_q_len = self.p_q_len.contiguous()\n        self.p_kv_len = self.p_kv_len.contiguous()\n        self.p_block_tables = self.p_block_tables.contiguous()\n\n        self.d_q_len = self.d_q_len.contiguous()\n        self.d_kv_len = self.d_kv_len.contiguous()\n        self.d_kv_len_list = self.d_kv_len.flatten().tolist()\n\n        self.bsz_tensor = torch.tensor([self.batch_size], device=device, dtype=torch.int32)\n\n\n\n    def __str__(self):\n        ret = ''\n        ret += '=======Prefill forward info:\\n'\n        ret += f'batch: {self.prefill_batch}, qLen: {self.p_q_len}, kvLen: {self.p_kv_len}\\n'\n        ret += f'tokens: {self.p_tokens}, posIdx: {self.p_position_ids}, block_tables: {self.p_block_tables}\\n'\n        ret += '=======Decode forward info:\\n'\n        ret += f'batch: {self.decode_batch}, qLen: {self.d_q_len}, kvLen: {self.d_kv_len}\\n'\n        ret += f'tokens: {self.d_tokens}, posIdx: {self.d_position_ids}, block_tables: {self.d_block_tables}\\n'\n        ret += f'chunk_size={self.chunk_size}, is_last_prefill_chunk={self.is_last_prefill_chunk}\\n'\n        return ret\n\n\n\nclass 
ForwardBatchInput:\n\n    forward_minibatchs: list[Union[ForwardMiniBatchSplit, ForwardMiniBatchCombine]]\n    decode_mini_batches: list[Union[ForwardMiniBatchSplit, ForwardMiniBatchCombine]]\n    batch_size: int\n    minibatch: Union[ForwardMiniBatchSplit, ForwardMiniBatchCombine]\n\n    def __init__(self, batch : sched_ext.BatchQueryTodo = None, query_manager: QueryManager = None, device=None, tokens: torch.Tensor = None):\n        \n        if batch is None:\n            return\n\n\n        prefill_minibatches = batch.prefill_mini_batches\n        decode_mini_batches = [item for sublist in batch.decode_mini_batches for item in sublist]\n        prefill_querys_info = []\n        prefill_s = []\n        prefill_l = []\n        decode_querys_info = []\n        self.batch_size = 1\n        for (qid, s, l) in prefill_minibatches:\n            prefill_querys_info.append(query_manager.query_map[qid])\n            prefill_s.append(s)\n            prefill_l.append(l)\n        for decode_qid in decode_mini_batches:\n            qinfo = query_manager.query_map[decode_qid]\n            if qinfo.decode_start_time is None:\n                qinfo.decode_start_time = time.time()\n            decode_querys_info.append(qinfo)\n\n        if use_torch_npu:\n            minibatch = ForwardMiniBatchSplit(prefill_querys_info, decode_querys_info, prefill_s, prefill_l, device = query_manager.device, page_size = query_manager.page_size)\n        else:\n            minibatch = ForwardMiniBatchCombine(prefill_querys_info, decode_querys_info, prefill_s, prefill_l, device = query_manager.device, page_size = query_manager.page_size)\n        self.minibatch = minibatch\n\n    @classmethod\n    def gen_max_forward_batch(\n        cls,\n        device=None,\n        tokens: torch.Tensor = None,\n        num_mini_batches: int = 1,\n        max_seq_length: int = 1024, # TODO: add to yaml\n        prefill_query_length: int = (Config().chunk_size - Config().max_decode_batch_size) // 
Config().max_prefill_batch_size, # TODO: use config\n        prefill_active_length: int = (Config().chunk_size - Config().max_decode_batch_size) // Config().max_prefill_batch_size,\n        gen_prefill: bool = True,\n        decode_batch_size: int = Config().max_decode_batch_size,\n        decode_query_length: int = 1,\n        decode_active_position: torch.Tensor = None,\n        page_size = 256,\n        cuda_lens = 1\n    ):\n        instance = cls()\n        \n        instance.batch_size = num_mini_batches\n        page_size = page_size\n     \n        prefill_query_info = []\n        offset = 0\n        if gen_prefill and prefill_query_length != 0:\n            for i in range(Config().max_prefill_batch_size):\n                prefill_query_info.append(QueryInfo(i, prefill_query_length, max_seq_length, page_size, device, offset=offset))\n                offset += max_seq_length // page_size\n\n        decode_querys_info = []\n        for i in range(min(decode_batch_size, cuda_lens)):\n            query_info = QueryInfo(i+Config().max_prefill_batch_size, decode_query_length, max_seq_length, page_size, device, is_prefill=False, offset=offset)\n            offset += max_seq_length // page_size\n            if tokens is not None:\n                query_info.query_tokens[prefill_active_length:prefill_active_length + decode_query_length].copy_(tokens)            \n            if decode_active_position is None:\n                query_info.active_position = prefill_active_length\n            else: \n                query_info.active_position = decode_active_position[i]\n\n            decode_querys_info.append(query_info)\n        \n        if prefill_query_length * Config().max_prefill_batch_size + len(decode_querys_info) < cuda_lens:\n            decode_querys_info.append(query_info)\n        if use_torch_npu:\n            instance.minibatch = ForwardMiniBatchSplit(prefill_query_info, decode_querys_info, [0, 0],\n                                                
[prefill_active_length for _ in range(Config().max_prefill_batch_size)],\n                                                device, page_size, decode_padding_len=decode_query_length)\n        else:\n            instance.minibatch = ForwardMiniBatchCombine(prefill_query_info, decode_querys_info, [0, 0], [prefill_active_length for _ in range(Config().max_prefill_batch_size)], device, page_size)\n        \n        return instance\n\n\n    def fill(self, batch : sched_ext.BatchQueryTodo = None, query_manager: QueryManager = None, page_size = 256):\n        if batch is None:\n            return\n        prefill_minibatches = batch.prefill_mini_batches\n        decode_mini_batches = [item for sublist in batch.decode_mini_batches for item in sublist]\n\n        prefill_querys_info = []\n        prefill_s = []\n        prefill_l = []\n        decode_querys_info = []\n        self.batch_size = 1\n        for (id, s, l) in prefill_minibatches:\n            prefill_querys_info.append(query_manager.query_map[id])\n            prefill_s.append(s)\n            prefill_l.append(l)\n        for decode_batch_idx in decode_mini_batches:\n            if query_manager.query_map[decode_batch_idx].decode_start_time is None:\n                query_manager.query_map[decode_batch_idx].decode_start_time =time.time()\n            decode_querys_info.append(query_manager.query_map[decode_batch_idx])\n\n        self.minibatch.fill(prefill_querys_info, decode_querys_info, prefill_s, prefill_l, device=query_manager.device, page_size=page_size)\n\n\n\nclass ForwardBatchOutput:\n    logits: list[torch.Tensor]\n    pre_hidden_states: list[torch.Tensor]\n    num_batchs: int\n    batch_sizes: list[int]\n    generated_tokens_num: list[int]\n    lm_start: list[int]\n    \n    temperatures: list[torch.Tensor]\n    top_ps: list[torch.Tensor]\n\n    def __init__(self):\n        self.num_batchs = 0\n        self.lm_start = []\n        self.logits = []\n        self.batch_sizes = []\n        
self.generated_tokens_num = []\n        self.top_ps = []\n        self.temperatures = []\n        self.pre_hidden_states = []\n        pass\n\n    def merge(self, new_output):\n        self.logits.extend(new_output.logits)\n        self.num_batchs += new_output.num_batchs\n        self.batch_sizes.extend(new_output.batch_sizes)\n        self.generated_tokens_num.extend(new_output.generated_tokens_num)\n        self.top_ps.extend(new_output.top_ps)\n        self.temperatures.extend(new_output.temperatures)\n        self.lm_start.extend(new_output.lm_start)\n        self.pre_hidden_states.extend(new_output.pre_hidden_states)\n\n    def __str__(self):\n        logits_shape = [t.shape for t in self.logits]\n        ret = ''\n        ret += f'=======Combined output info:\\n'\n        ret += f'logits: {self.logits}\\n'\n        ret += f'logits(size): {logits_shape}, num_batchs: {self.num_batchs}, kvLen: {self.generated_tokens_num}\\n'\n        ret += f'top_ps: {self.top_ps}, temperatures: {self.temperatures}, pre_hidden_states num: {len(self.pre_hidden_states)}\\n'\n        if len(self.pre_hidden_states) != 0:\n            for idx in range(len(self.pre_hidden_states)):\n                ret += f'idx: {idx}, pre_hidden_states shape: {self.pre_hidden_states[idx].shape}\\n'    \n        return ret"
  },
  {
    "path": "archive/ktransformers/server/balance_serve/inference/model_runner.py",
    "content": "\"\"\"\nDate: 2024-11-07 07:02:20\nLastEditors: djw\nLastEditTime: 2024-12-10 08:48:32\n\"\"\"\nimport os.path\nimport threading\n\nimport torch\nfrom torch import nn\nimport queue\nimport signal\nimport queue\nfrom typing import AsyncIterable\nfrom fastapi import FastAPI, Request\nfrom fastapi.responses import StreamingResponse\nfrom contextlib import asynccontextmanager\nfrom pydantic import BaseModel, Field\nimport asyncio\nimport multiprocessing\nimport time\nimport torch.multiprocessing as mp\nimport random\nimport torch.distributed as dist\nimport zmq\nimport copy\nimport tempfile\nfrom ktransformers.server.balance_serve.inference.forward_batch import (\n    ForwardBatchInput, ForwardBatchOutput, ForwardMiniBatchCombine, ForwardMiniBatchSplit)\nfrom ktransformers.util import utils\n\nfrom ktransformers.server.config.config import Config\nfrom ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM\nfrom ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM\nfrom ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM\nfrom ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM\nfrom ktransformers.models.custom_modeling_smallthinker import KSmallThinkerForCausalLM\nfrom ktransformers.models.custom_modeling_glm4_moe import KGlm4MoeForCausalLM\nfrom ktransformers.models.custom_modeling_qwen3_next import KQwen3NextForCausalLM\nfrom ktransformers.server.balance_serve.inference.query_manager import QueryManager\nfrom ktransformers.server.balance_serve.settings import sched_ext\n\ntry:\n    import torch_npu\n    use_torch_npu = torch_npu.npu.is_available()\n    from ktransformers.models.ascend.custom_ascend_modeling_deepseek_v3 import KNPUDeepseekV3ForCausalLM\n    from ktransformers.models.ascend.custom_ascend_modeling_qwen3 import KNPUQwen3MoeForCausalLM\n    from ktransformers.models.custom_cache import KVC2StaticCache, KVC2Qwen3Cache\nexcept:\n    use_torch_npu 
= False\n\n\ndef pad_num_tokens(num_tokens):\n    return (num_tokens + 63) // 64 * 64\n\ndef deduplicate_and_sort(lst):\n    return sorted(set(lst))\ndef generate_cuda_graphs(chunk_size: int) -> list:\n    # Fail the assertion if the input does not meet the requirements\n    assert chunk_size <= 1024 or chunk_size % 1024 == 0, \"chunk_size must be <= 1024 or a multiple of 1024\"\n    base_list = [1, 2, 3, Config().max_batch_size, 64, 256, 512, chunk_size]\n\n    if chunk_size <= 1024:\n        return deduplicate_and_sort(base_list)\n\n    multiples = [i for i in range(1024, chunk_size + 1, 1024)]\n\n    return deduplicate_and_sort(base_list + multiples)\nclass ModelRunner:\n    \"\"\"A CudaGraphRunner runs the forward pass of a model with CUDA graph and torch.compile.\"\"\"\n    if not use_torch_npu:\n        model: KDeepseekV3ForCausalLM  | KQwen2MoeForCausalLM | KQwen3MoeForCausalLM | KSmallThinkerForCausalLM | KGlm4MoeForCausalLM | KQwen3NextForCausalLM\n    else:\n        model: KNPUDeepseekV3ForCausalLM | KNPUQwen3MoeForCausalLM\n        cache: KVC2StaticCache | KVC2Qwen3Cache\n    input: ForwardBatchInput | list[ForwardBatchInput]\n    output: ForwardBatchOutput\n    \n\n    def __init__(self, model = None, cache = None, device = None, use_cuda_graph = False, max_decode_batch_size = 1, max_chunk_size = 4096, num_mini_batches: int = 1, page_size = 256, block_num = 8):\n        \n        # Commented out for now\n        self.model = model  # Compile and move model to the specified device\n        if use_torch_npu:\n            self.stream = torch.npu.Stream(device=device)\n            self.stream_scope = torch.npu.stream\n            self.input_decode = []\n            max_batch_size = 1 if Config().max_batch_size <= 1 else Config().max_batch_size\n            self.npu_graphs = sorted(set([i for i in range(1, max_batch_size + 1)]))\n            self.model.stream = self.stream  # npu do not support multi stream like this\n            if use_cuda_graph:\n                torch_npu.npu._subscribe_report(self.stream)\n\n          
  self.start_model_event = torch.npu.Event(enable_timing=True)\n            self.end_model_event = torch.npu.Event(enable_timing=True)\n        else:\n            self.stream = torch.cuda.Stream(device=device)\n            self.cuda_graphs = generate_cuda_graphs(Config().chunk_size)\n\n            self.start_model_event = torch.cuda.Event(enable_timing=True)\n            self.end_model_event = torch.cuda.Event(enable_timing=True)\n \n        self.device = device\n        self.input = None\n        self.features_buf = None\n        self.output = None\n        self.graph_memory_pool = None\n        self.cache = cache\n        # TODO: why was the line self.cuda_graphs = generate_cuda_graphs(Config().chunk_size) removed here? Won't that affect the GPU path below?\n        self.use_cuda_graph = use_cuda_graph\n        self.debug = False\n\n        self.model_time = 0\n        self.page_size = page_size\n        self.block_num = block_num\n\n        if 'cuda' in device:\n            self.graphs = [torch.cuda.CUDAGraph() for _ in range(len(self.cuda_graphs))]\n            self.page_idx_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device = self.device) for i in range(len(self.cuda_graphs))]\n            self.page_offset_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device = self.device) for i in range(len(self.cuda_graphs))]\n        elif 'npu' in device:\n            self.workspace = [None for _ in range(len(self.npu_graphs))]\n            self.graphs = [torch.npu.NPUGraph() for _ in range(len(self.npu_graphs))]\n            self.page_idx_buf = [torch.zeros((self.npu_graphs[i], 1), dtype=torch.int32, device = self.device) for i in range(len(self.npu_graphs))]\n            self.page_offset_buf = [torch.zeros((self.npu_graphs[i], 1), dtype=torch.int32, device = self.device) for i in range(len(self.npu_graphs))]\n        else:\n            self.graphs, self.page_idx_buf, self.page_offset_buf = None, None, None\n        self.num_mini_batches = num_mini_batches\n\n        self.max_chunk_size = 
max_chunk_size\n\n        self.bsz_tensor_buf = torch.empty((1, ),dtype=torch.int32, device=device)\n        self.num_tokens_tensor_buf = torch.empty((1, ),dtype=torch.int32, device=device)\n\n    def model_attn_plan(self, batch, cuda_graph_idx=0):\n        if isinstance(self.model, KDeepseekV3ForCausalLM):\n            self.model.flash_infer_attn_plan(batch, self.bsz_tensor_buf, self.num_tokens_tensor_buf,\n                                             num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank, \n                                             head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,\n                                             sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)\n        elif isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM) or isinstance(self.model, KSmallThinkerForCausalLM) or isinstance(self.model, KGlm4MoeForCausalLM) or isinstance(self.model, KQwen3NextForCausalLM):\n            self.model.flash_infer_attn_plan(batch, self.bsz_tensor_buf, self.num_tokens_tensor_buf,\n                                             num_q_heads=self.model.config.num_attention_heads, num_kv_heads=self.model.config.num_key_value_heads,\n                                             head_dim=self.model.config.head_dim if hasattr(self.model.config, 'head_dim') else self.model.config.hidden_size // self.model.config.num_attention_heads, \n                                             page_size=self.model.cache.page_size, causal=True,\n                                             q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, cuda_graph_idx=cuda_graph_idx)\n        else:\n            assert False, \"model type not supported\"\n\n\n    def warmup(self):\n\n        def capture_graphs(cuda_graph_idx):\n            with 
torch.cuda.graph(self.graphs[cuda_graph_idx], pool=self.graph_memory_pool, stream=self.stream):\n                self.outputs_buf[cuda_graph_idx] = self.model(self.input[cuda_graph_idx], self.features_buf[cuda_graph_idx], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[cuda_graph_idx], self.page_offset_buf[cuda_graph_idx], cuda_graph_idx=cuda_graph_idx)   \n            self.graph_memory_pool = self.graphs[cuda_graph_idx].pool()\n\n        self.input = []\n        self.features_buf = []\n        self.outputs_buf = []\n        self.bsz_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device)\n        self.num_tokens_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device)\n        for i in range(len(self.cuda_graphs)):\n            prefill_query_length = (self.cuda_graphs[i] - Config().max_decode_batch_size) // Config().max_prefill_batch_size if self.cuda_graphs[i] > Config().max_decode_batch_size else 0  # TODO: only supports 2 prefill batches\n            self.input.append(ForwardBatchInput.gen_max_forward_batch(device=self.device, num_mini_batches = self.num_mini_batches, prefill_query_length=prefill_query_length, prefill_active_length=prefill_query_length, page_size=self.page_size, cuda_lens=self.cuda_graphs[i]))\n\n            self.features_buf.append(self.model.batch_embeddings(self.input[i]))\n            batch_size = self.input[i].minibatch.q_indptr.size(0)-1\n            num_tokens = self.features_buf[i][0].size(0)\n            print(\"capturing cuda graph\", batch_size, num_tokens)\n\n            if isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM) or isinstance(self.model, KSmallThinkerForCausalLM) or isinstance(self.model, KGlm4MoeForCausalLM) or isinstance(self.model, KQwen3NextForCausalLM):\n                self.model.init_wrapper(self.use_cuda_graph, self.device, num_tokens ,batch_size, self.block_num, i) # TODO: 1024 is a magic number(max_batch_tokens)\n\n            
self.bsz_tensor_buf[0] = batch_size\n            self.num_tokens_tensor_buf[0] = num_tokens\n\n            self.model_attn_plan(self.input[i], i)\n        \n            page_idx, page_offset = self.model.cache.get_page_table(self.input[i].minibatch.position_ids, self.input[i].minibatch.q_indptr, self.input[i].minibatch.kv_indptr, self.input[i].minibatch.kv_indices, self.num_tokens_tensor_buf)\n\n            \n            self.page_idx_buf[i][:num_tokens].copy_(page_idx[:num_tokens])\n            self.page_offset_buf[i][:num_tokens].copy_(page_offset[:num_tokens])\n\n            self.page_idx_buf[i][num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size -1) \n        \n            self.outputs_buf.append(None)\n        \n            torch.cuda.synchronize()\n            for warm_up_iters in range(11):\n                with torch.cuda.stream(self.stream):\n                    self.outputs_buf[i] = self.model(self.input[i], self.features_buf[i], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[i], self.page_offset_buf[i], cuda_graph_idx=i)\n            torch.cuda.synchronize()\n\n            self.outputs_buf[i].num_batchs = batch_size\n\n            capture_graphs(i)\n\n            with torch.cuda.stream(self.stream):\n                self.graphs[i].replay()\n\n            self.sync(calc_time=False)\n            print(f\"cuda_graph: {i+1}/{len(self.cuda_graphs)}, warmup finished.\")\n\n    def warmup_npu(self):\n        # NPU currently uses PD (prefill/decode) disaggregation\n        # Graph sinking is currently supported only for the decode stage\n        # In multi-batch scenarios only batch sizes 1 2 3 4 5 6 7 8 are supported\n        def capture_graphs(npu_graph_idx):\n            utils._USE_NPU_GRAPH = True\n            print(\"self.features_buf[npu_graph_idx] is \", self.features_buf[npu_graph_idx])\n            with torch.npu.graph(self.graphs[npu_graph_idx], pool=self.graph_memory_pool, stream=self.stream, auto_dispatch_capture=True):\n                self.outputs_buf[npu_graph_idx] = self.model(\n                    
self.input_decode[npu_graph_idx], \n                    self.features_buf[npu_graph_idx], \n                    self.cache, None, None, \n                    self.page_idx_buf[npu_graph_idx], \n                    self.page_offset_buf[npu_graph_idx], \n                    self.position_ids_buf[npu_graph_idx], \n                    self.block_tables_buf[npu_graph_idx], \n                    cuda_graph_idx=npu_graph_idx, \n                    is_prefill=False\n                    )\n            self.graph_memory_pool = self.graphs[npu_graph_idx].pool()\n            utils._USE_NPU_GRAPH = False\n\n        self.features_buf = []\n        self.outputs_buf = []\n        self.position_ids_buf = []\n        self.block_tables_buf = []\n        self.bsz_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device)\n        self.num_tokens_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device)\n        for i in range(len(self.npu_graphs)):\n            prefill_query_length = (self.npu_graphs[i] - Config().max_decode_batch_size) // Config().max_prefill_batch_size if self.npu_graphs[i] > Config().max_decode_batch_size else 0  # TODO: only supports 2 prefill batches\n            self.input_decode.append(ForwardBatchInput.gen_max_forward_batch(device=self.device, num_mini_batches = self.num_mini_batches, decode_batch_size=self.npu_graphs[i], prefill_active_length=1, page_size=self.page_size, cuda_lens = self.npu_graphs[i]))\n            self.features_buf.append(self.model.batch_embeddings(self.input_decode[i], device=self.device, is_prefill=False))\n\n            batch_size = self.npu_graphs[i]\n            num_tokens = batch_size\n            self.bsz_tensor_buf[0] = batch_size\n            self.num_tokens_tensor_buf[0] = num_tokens\n            \n            page_idx, page_offset = self.cache.get_page_table(self.input_decode[i].minibatch, self.num_tokens_tensor_buf, is_prefill=False)\n\n            
self.position_ids_buf.append(self.input_decode[i].minibatch.d_position_ids.clone())\n            self.block_tables_buf.append(self.input_decode[i].minibatch.d_block_tables.clone())\n\n\n            self.page_idx_buf[i][:num_tokens].copy_(page_idx[:num_tokens][0])\n            page_offset = page_offset.view(self.page_offset_buf[i].size())\n            self.page_offset_buf[i][:num_tokens].copy_(page_offset[:num_tokens])\n            self.page_idx_buf[i][num_tokens:].fill_(self.cache.max_cache_len // self.cache.page_size -1)\n            self.outputs_buf.append(None)\n\n            torch.npu.synchronize()\n            for warm_up_iters in range(11):\n                with torch.npu.stream(self.stream):\n                    self.outputs_buf[i] = self.model(self.input_decode[i], self.features_buf[i], self.cache, self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[i], self.page_offset_buf[i], self.position_ids_buf[i], self.block_tables_buf[i], is_prefill=False)\n            torch.npu.synchronize()\n            capture_graphs(i)\n            self.replay(i)\n            self.sync(calc_time=False)\n            print(f\"npu_graph: {i+1}/{len(self.npu_graphs)}, warmup finished.\")\n\n\n    def run(self, batch: sched_ext.BatchQueryTodo = None, query_manager: QueryManager = None):\n        with torch.cuda.stream(self.stream):\n\n            batch_size = len(batch.prefill_mini_batches) # TODO: calc this\n            num_tokens = 0\n            for i in range(len(batch.decode_mini_batches)):\n                batch_size += len(batch.decode_mini_batches[i])\n                num_tokens += len(batch.decode_mini_batches[i])\n                print(f'decode_batch_i: {len(batch.decode_mini_batches[i])},')\n\n            for i in range(len(batch.prefill_mini_batches)):\n                num_tokens += batch.prefill_mini_batches[i][2]\n                print(f'prefill_batch_i: {batch.prefill_mini_batches[i][2]},')\n\n\n\n            # cuda graph idx equal to min idx i in 
self.cuda_graphs, that self.cuda_graphs[i] > num_tokens\n            cuda_graph_idx = next((i for i, token in enumerate(self.cuda_graphs) if token >= num_tokens), len(self.cuda_graphs))\n            if not self.use_cuda_graph:\n                cuda_graph_idx = 0\n    \n            if self.use_cuda_graph:\n                self.input[cuda_graph_idx].fill(batch, query_manager, self.page_size)\n            else:\n                self.input = [ForwardBatchInput(batch=batch, query_manager=query_manager, device=self.device)]\n        \n\n            if self.use_cuda_graph:\n                self.features = self.model.batch_embeddings(self.input[cuda_graph_idx], device=self.device)\n\n            self.bsz_tensor_buf.copy_(batch_size)\n            self.num_tokens_tensor_buf.copy_(torch.tensor([num_tokens], dtype=torch.int32, device=self.device))\n\n            if self.use_cuda_graph:\n                self.features_buf[cuda_graph_idx][0].copy_(self.features[0], non_blocking=True)\n\n            self.model_attn_plan(self.input[cuda_graph_idx], cuda_graph_idx)\n            self.start_model_event.record(self.stream)\n\n            if self.use_cuda_graph:\n                self.model.flash_infer_attn_plan(self.input[cuda_graph_idx], self.bsz_tensor_buf, self.num_tokens_tensor_buf,\n                                            num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank, \n                                                head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size, causal=True,\n                                                sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)\n                self.start_model_event.record(self.stream)\n                if use_torch_npu:\n                    page_idx, page_offset = self.cache.get_page_table(self.input[cuda_graph_idx].minibatch, self.bsz_tensor_buf) #TODO csx minibatch\n                  
  self.page_idx_buf[cuda_graph_idx][num_tokens:].fill_(self.cache.max_cache_len // self.cache.page_size - 1)\n                else:\n                    page_idx, page_offset = self.model.cache.get_page_table(self.input[cuda_graph_idx].minibatch.position_ids, self.input[cuda_graph_idx].minibatch.q_indptr, self.input[cuda_graph_idx].minibatch.kv_indptr, self.input[cuda_graph_idx].minibatch.kv_indices, self.num_tokens_tensor_buf)\n                    self.page_idx_buf[cuda_graph_idx][num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size -1)\n\n                self.page_idx_buf[cuda_graph_idx][:num_tokens].copy_(page_idx[:num_tokens])\n                self.page_offset_buf[cuda_graph_idx][:num_tokens].copy_(page_offset[:num_tokens])\n                self.replay(cuda_graph_idx)\n                self.output = ForwardBatchOutput()\n                \n                self.output.top_ps.append(self.input[cuda_graph_idx].minibatch.top_ps)\n                self.output.temperatures.append(self.input[cuda_graph_idx].minibatch.temperatures)\n                self.output.logits.append(self.outputs_buf[cuda_graph_idx].logits[0][self.input[cuda_graph_idx].minibatch.logits_start].clone())\n\n                self.end_model_event.record(self.stream)\n            else:\n                self.model.flash_infer_attn_plan(self.input, self.bsz_tensor_buf, self.num_tokens_tensor_buf,\n                                            num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank, \n                                                head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size, causal=True,\n                                                sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)\n                self.start_model_event.record(self.stream)\n                page_idx, page_offset = 
self.cache.get_page_table(self.input[cuda_graph_idx].minibatch, self.bsz_tensor_buf)\n\n                self.output = self.model(self.input, self.features, self.bsz_tensor_buf, self.num_tokens_tensor_buf, page_idx, page_offset)\n                self.output.logits[0] = self.output.logits[0][self.input.minibatch.logits_start]\n                self.output.top_ps.append(self.input.minibatch.top_ps)\n                self.output.temperatures.append(self.input.minibatch.temperatures)\n\n                self.end_model_event.record(self.stream)\n\n        if not self.use_cuda_graph:\n            self.output.num_batchs = self.input.batch_size\n        else:\n            self.output.num_batchs = self.input[cuda_graph_idx].batch_size\n\n    def run_split(self, batch: sched_ext.BatchQueryTodo = None, query_manager: QueryManager = None):\n        \"\"\"running without flashinfer and prefill & decode split infer\"\"\"\n        def _run_infer_stage(is_prefill=True):\n            if \"npu\" in self.device:\n                cuda_graph_idx = batch_size_decode\n            if is_prefill == False:\n                if cuda_graph_idx != -1 and self.use_cuda_graph:\n                    self.features = self.model.batch_embeddings(self.input_decode[cuda_graph_idx], device=self.device, is_prefill=is_prefill)\n                else:\n                    self.features = self.model.batch_embeddings(self.input, device=self.device, is_prefill=is_prefill)\n\n                self.bsz_tensor_buf.copy_(batch_size_decode)\n\n                if self.use_cuda_graph:\n                    if cuda_graph_idx != -1:\n                        self.features_buf[cuda_graph_idx].copy_(self.features)\n                    else:\n                        self.features_buf.copy_(self.features)\n            else:\n                self.features = self.model.batch_embeddings(self.input, device=self.device, is_prefill=is_prefill)\n                self.bsz_tensor_buf.copy_(batch_size_decode)\n\n            if cuda_graph_idx 
!= -1 and self.use_cuda_graph and is_prefill == False:\n                num_tokens = batch_size_decode + 1\n                self.start_model_event.record(self.stream) if self.start_model_event else None\n                page_idx, page_offset = self.cache.get_page_table(self.input_decode[cuda_graph_idx].minibatch, self.bsz_tensor_buf, is_prefill=is_prefill)\n                self.position_ids_buf[cuda_graph_idx].copy_(self.input_tmp.minibatch.d_position_ids)\n                self.block_tables_buf[cuda_graph_idx].copy_(self.input_tmp.minibatch.d_block_tables)\n                self.page_idx_buf[cuda_graph_idx][:num_tokens].copy_(page_idx[:num_tokens])\n                self.page_offset_buf[cuda_graph_idx][:num_tokens].copy_(page_offset[:num_tokens])\n                self.page_idx_buf[cuda_graph_idx][num_tokens:].fill_(self.cache.max_cache_len // self.cache.page_size - 1)\n\n                self.replay(cuda_graph_idx)\n                new_output = ForwardBatchOutput()\n                for i in range(num_tokens):\n                    new_output.top_ps.append(self.input_decode[cuda_graph_idx].minibatch.d_top_ps[i])\n                    new_output.temperatures.append(self.input_decode[cuda_graph_idx].minibatch.d_temperatures[i])\n                    new_output.logits.append(self.outputs_buf[cuda_graph_idx].logits[i].clone())  # TODO support MTP\n                self.end_model_event.record(self.stream) if self.start_model_event else None\n\n                if self.output is None:\n                    self.output = copy.deepcopy(new_output)\n                else:\n                    self.output.merge(new_output)\n\n            else:\n                self.start_model_event.record(self.stream) if self.start_model_event else None\n                page_idx, page_offset = self.cache.get_page_table(self.input.minibatch, self.num_tokens_tensor_buf, is_prefill=is_prefill)\n                new_output = self.model(self.input, self.features, self.cache, None, None, page_idx, 
page_offset, None, None, is_prefill=is_prefill)\n                bsz = len(new_output.logits)\n                if is_prefill:\n                    for i in range(bsz):\n                        new_output.logits[i] = new_output.logits[i][-1:, :]  # batched tensor do not need location\n                        new_output.top_ps.append(self.input.minibatch.p_top_ps[i])\n                        new_output.temperatures.append(self.input.minibatch.p_temperatures[i])\n                else:\n                    for i in range(bsz):\n                        new_output.top_ps.append(self.input.minibatch.d_top_ps[i])\n                        new_output.temperatures.append(self.input.minibatch.d_temperatures[i])\n\n                if self.output is None:\n                    self.output = copy.deepcopy(new_output)\n                else:\n                    self.output.merge(new_output)\n                self.end_model_event.record(self.stream) if self.end_model_event else None\n\n        with self.stream_scope(self.stream):\n\n            batch_size = len(batch.prefill_mini_batches) # TODO: calc this\n            num_d_tokens, num_p_tokens = 0, 0\n            for i in range(len(batch.decode_mini_batches)):\n                batch_size += len(batch.decode_mini_batches[i])\n                num_d_tokens += len(batch.decode_mini_batches[i])\n                if self.debug:\n                    print(f'decode_batch_i: {len(batch.decode_mini_batches[i])}, token_num: {len(batch.decode_mini_batches[i])} ,batch_size: {batch_size}')\n\n            for i in range(len(batch.prefill_mini_batches)):\n                num_p_tokens += batch.prefill_mini_batches[i][2]\n                if self.debug:\n                    print(f'prefill_batch_i: {batch.prefill_mini_batches[i][2]}, token_num: {batch.prefill_mini_batches[i][2]}')\n\n            # batch info holder both in graph mode & kernel mode\n            self.input_tmp = ForwardBatchInput(batch=batch, query_manager=query_manager, 
device=self.device)\n            batch_size_decode = self.input_tmp.minibatch.decode_batch - 1\n            idx = self.input_tmp.minibatch.decode_batch - 1\n            cuda_graph_idx = batch_size_decode\n            self.output = None  # clear last step output\n\n            if self.input_tmp.minibatch.decode_batch > 0:\n                if self.use_cuda_graph and len(self.input_decode) > 0:\n                    self.input_decode[idx].fill(batch, query_manager, self.page_size)\n                else:\n                    self.input = self.input_tmp\n                    assert isinstance(self.input.minibatch, ForwardMiniBatchSplit), 'split batch input type must be ForwardMiniBatchSplit'\n                    print(self.input.minibatch) if self.debug else None\n\n            if self.input_tmp.minibatch.prefill_batch > 0:\n                self.input = self.input_tmp\n                assert isinstance(self.input.minibatch, ForwardMiniBatchSplit), 'split batch input type must be ForwardMiniBatchSplit'\n                print(self.input.minibatch) if self.debug else None\n\n            # ++++++++++++++++++++++++++++++++++++++++++ Prefill Stage ++++++++++++++++++++++++++++++++++++++++++++++++\n            if self.input_tmp.minibatch.prefill_batch > 0:\n                _run_infer_stage(is_prefill=True)\n                self.output.num_batchs = self.input.minibatch.batch_size\n            # ++++++++++++++++++++++++++++++++++++++++++ Decode Stage ++++++++++++++++++++++++++++++++++++++++++++++++\n            if self.input_tmp.minibatch.decode_batch > 0:\n                if self.use_cuda_graph:\n                    _run_infer_stage(is_prefill=False)\n                    self.output.num_batchs = self.input_decode[idx].minibatch.batch_size\n                else:\n                    _run_infer_stage(is_prefill=False)\n                    self.output.num_batchs = self.input.minibatch.batch_size\n\n            print(self.output) if self.debug else None\n\n    def replay(self, 
cuda_graph_idx=-1):\n        if use_torch_npu:\n            thread = threading.Thread(target=self.graphs[cuda_graph_idx].update, kwargs={\"cpu_update_input\": [{\"actual_seq_lengths_kv\": self.input_decode[cuda_graph_idx].minibatch.d_kv_len_list}]})\n            thread.start()\n            torch_npu.npu.synchronize()\n\n        with torch.cuda.stream(self.stream):\n            if cuda_graph_idx != -1:\n                self.graphs[cuda_graph_idx].replay()\n            else:\n                self.graphs.replay()\n\n\n    def sync(self, calc_time = True):\n        self.stream.synchronize()\n        if calc_time:\n            self.model_time = self.start_model_event.elapsed_time(self.end_model_event)  # In ms\n\n\ndef get_or_create_model_runner(model=None, cache=None, device=None, use_cuda_graph=None, page_size=None):\n    from ktransformers.server.balance_serve.inference.config import model_runner_dict\n    runner = model_runner_dict.get(device)\n    if runner is None:\n        print(\"[WARN] the new ModelRunner and deviceId is \", device)\n        runner = ModelRunner(model, cache, device, use_cuda_graph, page_size)\n        model_runner_dict[device] = runner\n    return runner\n"
  },
  {
    "path": "archive/ktransformers/server/balance_serve/inference/query_manager.py",
    "content": "'''\nDate: 2024-11-14 12:23:45\nLastEditors: djw\nLastEditTime: 2024-11-20 04:06:23\n'''\nimport torch\nfrom ktransformers.server.balance_serve.settings import sched_ext\nimport random\nimport time\nfrom ktransformers.server.config.config import Config\nfrom ktransformers.server.utils.serve_profiling import PROF_TIME_STAT\n\nclass QueryInfo:\n    id: int\n    active_position: int\n    query_length: int\n    is_prefill: int\n    is_first_token: int\n    block_index: torch.Tensor\n    query_tokens: torch.Tensor\n    stop_criteria: list[torch.Tensor]\n\n    temperature: float\n    top_p: float\n\n    max_length: int\n\n    pos_status: torch.Tensor\n    probs: list[torch.Tensor]\n    acc_position: int \n\n    def __init__(self, id, query_length: int, max_length: int, page_size: int, device: torch.device, is_prefill: bool = True, offset: int = 0, active_position: int = 0, temperature: float = 0.01, top_p: float = 1.0):\n        self.id = id\n        self.is_prefill = is_prefill\n        self.is_first_token = False\n        self.active_position = active_position\n        self.max_length = max_length - 1\n        self.query_tokens = torch.zeros((max_length + 2,), dtype=torch.int, device = device)\n        self.stop_criteria = []\n        self.block_index = torch.arange(offset, offset + (max_length + active_position + page_size - 1) // page_size, dtype=torch.int, device = device)\n        self.query_length = query_length\n        self.enqueue_time = time.time()\n        self.decode_start_time = None\n        self.speculative_token = {} # {position: (accept, token)}\n\n        self.pos_status = torch.zeros((max_length + 2,), dtype=torch.int, device = device)\n        self.probs = [None] * (max_length + 2)\n\n        self.acc_tokens_num = 0\n        self.rej_tokens_num = 0\n        self.round = 0\n        self.acc_length = 0\n        self.acc_position = 0\n\n        self.temperature = temperature\n        self.top_p = top_p\n\n    def check_stop(self):\n      
  if self.active_position >= self.max_length - 2:\n            if PROF_TIME_STAT.on:\n                PROF_TIME_STAT.print_all()\n                # PROF_TIME_STAT.reset_all()\n            return True\n\n        # Iterate over each stop criterion\n        for stop_tensor in self.stop_criteria:\n            stop_len = len(stop_tensor)\n            \n            # Skip if the stop criterion is longer than query_tokens\n            if stop_len >= self.active_position:\n                continue\n            \n            #print(f\"stop_tensor: {stop_tensor}, stop_len: {stop_len}, active_position: {self.active_position}, query_token: {self.query_tokens[self.active_position - stop_len - 1:self.active_position - 1]}\")\n\n            if (torch.equal(self.query_tokens[self.active_position - stop_len - 1:self.active_position - 1], stop_tensor) and self.active_position) or self.max_length <= self.active_position + 3:\n                self.life_time = time.time() - self.enqueue_time\n                self.decode_duration_time = time.time() - self.decode_start_time\n                self.decode_tps = (self.active_position -  self.query_length) / self.decode_duration_time\n                print(f\"prefill length: {self.query_length}, prefill time: {self.prefill_duration_time}, prefill tps {self.prefill_tps}, decode length: {self.active_position -  self.query_length}, decode time: {self.decode_duration_time}, decode tps {self.decode_tps}\")\n                \n                if self.acc_tokens_num + self.rej_tokens_num != 0:\n                    verify_counts = self.acc_tokens_num + self.rej_tokens_num\n                    print(f\"mtp accept rate: {self.acc_tokens_num}/{verify_counts} = {self.acc_tokens_num * 100 / verify_counts} %\")\n                if PROF_TIME_STAT.on:\n                    PROF_TIME_STAT.print_all()\n                    # PROF_TIME_STAT.reset_all()\n                return True  # A matching stop criterion was found\n                \n        \n        return False  # No stop criterion matched\n\n\n    def print(self):\n        print(f\"active_position: 
{self.active_position}, query_length: {self.query_length}, is_prefill: {self.is_prefill}\")\n        print(f\"block_index_shape: {self.block_index.shape}, query_tokens_shape: {self.query_tokens.shape}\")\n        print(f\"query_tokens_shape: {self.query_tokens}, is_first_token: {self.is_first_token}\" )\n        print(f\"pos_status: {self.pos_status}, acc_position: \", self.acc_position)\n        print(f\"probs: {self.probs}\")\n\n\nclass QueryManager:\n\n    max_length: int = 65536\n    page_size: int = 256\n    device: torch.device\n    query_map : dict[int, QueryInfo]\n\n    def __init__(self, max_length = 65536, page_size = 256, device = torch.device('cuda')):\n        self.max_length = max_length\n        self.page_size = page_size\n        self.device = device\n        self.query_map = {}\n\n    def print(self, hint: str = \"\"):\n        print(hint,\" query_manager: \", self.query_map)\n        for key in self.query_map: \n            query_info = self.query_map[key]\n            print(\">>> query: \", key)\n            print(\"query_info: \")\n            query_info.print()\n\n    def add_query(self, batch: sched_ext.BatchQueryTodo):\n\n        for i in range(len(batch.query_ids)):\n            id = batch.query_ids[i]\n            if id not in self.query_map:\n                print(f\"add query id: {id}, batch.query_lengths: {batch.query_lengths[i]}, \"\n                      f\"batch_query_tokens: {batch.query_tokens[i].shape}, \"\n                      f\"batch.block_indexes: {batch.block_indexes[i]}\")\n                assert batch.query_tokens[i].size(0) < self.max_length, \"query max length in batchquerytodo exceeds internal max_length\"\n                query_info = QueryInfo(id=id, query_length=batch.query_lengths[i], max_length=batch.query_tokens[i].size(0) + 1, page_size=self.page_size, device=self.device, temperature=batch.sample_options[i].temperature, top_p=batch.sample_options[i].top_p)\n                
query_info.query_tokens[:query_info.query_length].copy_(batch.query_tokens[i][:query_info.query_length].to(self.device))\n                \n                for stop_token_list in batch.stop_criteria[i]:\n                    query_info.stop_criteria.append(torch.tensor(stop_token_list, dtype=torch.int, device = self.device))\n\n                block_num = batch.block_indexes[i].size(0)\n                query_info.block_index[:block_num].copy_(batch.block_indexes[i].to(self.device))\n\n                self.query_map[id] = query_info\n                \n                prefill_mini_batches = batch.prefill_mini_batches\n                for (prefill_id, s, l) in prefill_mini_batches:\n                    if prefill_id == id:\n                        self.query_map[prefill_id].active_position = s\n\n\n    def update(self, batch: sched_ext.BatchQueryTodo) -> list[sched_ext.QueryUpdate]:\n        query_updates = []\n\n        prefill_mini_batches = batch.prefill_mini_batches\n\n        for (id, s, l) in prefill_mini_batches:\n\n            if id not in self.query_map:\n                assert False, f\"query id {id} not found in query_map\"\n\n            # update query_info\n            query_info = self.query_map[id]\n            query_info.active_position += l\n\n            if query_info.active_position >= query_info.query_length and query_info.is_prefill:\n                query_info.is_prefill = False\n                query_info.is_first_token = True\n                query_info.prefill_duration_time = time.time() - query_info.enqueue_time\n                query_info.prefill_tps = query_info.query_length / query_info.prefill_duration_time\n                \n\n            # generate schedule query_update\n            query_update = sched_ext.QueryUpdate()\n            query_update.id = id\n            query_update.ok = True\n            query_update.is_prefill = query_info.is_prefill\n            query_update.active_position = query_info.active_position\n            # 
if(not query_info.is_prefill):\n            query_updates.append(query_update)\n\n\n        decode_mini_batches = batch.decode_mini_batches\n\n        for ids in decode_mini_batches:\n            for id in ids:\n                if id not in self.query_map:\n                    assert False, f\"query id {id} not found in query_map\"\n\n                query_info = self.query_map[id]\n                query_info.is_first_token = False\n                query_info.active_position += 1\n\n                query_update = sched_ext.QueryUpdate()\n                query_update.id = id\n                query_update.ok = True\n                query_update.is_prefill = query_info.is_prefill\n\n                query_update.decode_done = query_info.check_stop()\n\n                query_update.active_position = query_info.active_position\n                query_updates.append(query_update)\n\n        return query_updates"
  },
  {
    "path": "archive/ktransformers/server/balance_serve/inference/sampling/penaltylib/__init__.py",
    "content": "from .orchestrator import BatchedPenalizerOrchestrator\nfrom .penalizers.frequency_penalty import BatchedFrequencyPenalizer\nfrom .penalizers.min_new_tokens import BatchedMinNewTokensPenalizer\nfrom .penalizers.presence_penalty import BatchedPresencePenalizer\nfrom .penalizers.repetition_penalty import BatchedRepetitionPenalizer\n\n__all__ = [\n    \"BatchedFrequencyPenalizer\",\n    \"BatchedMinNewTokensPenalizer\",\n    \"BatchedPresencePenalizer\",\n    \"BatchedRepetitionPenalizer\",\n    \"BatchedPenalizerOrchestrator\",\n]\n"
  },
  {
    "path": "archive/ktransformers/server/balance_serve/inference/sampling/penaltylib/orchestrator.py",
    "content": "import abc\nimport dataclasses\nimport typing\n\nimport torch\n\n\n@dataclasses.dataclass\nclass _ReqLike:\n    origin_input_ids: typing.Union[torch.Tensor, typing.List[int]]\n\n\n@dataclasses.dataclass\nclass _BatchLike:\n    reqs: typing.List[_ReqLike]\n\n    def batch_size(self):\n        return len(self.reqs)\n\n\nclass BatchedPenalizerOrchestrator:\n    batch: _BatchLike\n    device: str\n    vocab_size: int\n    penalizers: typing.Dict[typing.Type[\"_BatchedPenalizer\"], \"_BatchedPenalizer\"]\n\n    def __init__(\n        self,\n        vocab_size: int,\n        batch: _BatchLike,\n        device: str,\n        Penalizers: typing.Set[typing.Type[\"_BatchedPenalizer\"]],\n    ):\n        self.vocab_size = vocab_size\n        self.batch = batch\n        self.device = device\n\n        self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}\n\n        is_required = False\n        for penalizer in self.penalizers.values():\n            pen_is_required = penalizer.prepare_if_required()\n            is_required |= pen_is_required\n        self.is_required = is_required\n\n        if self.is_required:\n            self.cumulate_input_tokens(\n                input_ids=[req.origin_input_ids for req in self.reqs()]\n            )\n\n    def reqs(self):\n        return self.batch.reqs\n\n    def batch_size(self):\n        return self.batch.batch_size()\n\n    def cumulate_input_tokens(\n        self,\n        input_ids: typing.Union[\n            typing.List[torch.Tensor], typing.List[typing.List[int]]\n        ],\n    ):\n        \"\"\"\n        Feed the input tokens to the penalizers.\n\n        Args:\n            input_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The input tokens.\n        \"\"\"\n        token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids)\n\n        for penalizer in self.penalizers.values():\n            penalizer.cumulate_input_tokens(input_ids=token_ids)\n\n    def 
cumulate_output_tokens(\n        self,\n        output_ids: typing.Union[\n            typing.List[torch.Tensor], typing.List[typing.List[int]]\n        ],\n    ):\n        \"\"\"\n        Feed the output tokens to the penalizers.\n\n        Args:\n            output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens.\n        \"\"\"\n        if not self.is_required:\n            return\n\n        token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids)\n\n        for penalizer in self.penalizers.values():\n            penalizer.cumulate_output_tokens(output_ids=token_ids)\n\n    def apply(self, logits: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Apply the penalizers to the logits.\n        Note that it may apply the penalizers in-place.\n\n        Args:\n            logits (torch.Tensor): The logits to apply the penalizers to.\n\n        Returns:\n            torch.Tensor: The logits after applying the penalizers.\n        \"\"\"\n        if not self.is_required:\n            return\n\n        for penalizer in self.penalizers.values():\n            logits = penalizer.apply(logits)\n\n        return logits\n\n    def filter(\n        self,\n        indices_to_keep: typing.List[int],\n        indices_tensor_to_keep: torch.Tensor = None,\n    ):\n        \"\"\"\n        Filter the penalizers based on the indices to keep in the batch.\n\n        Args:\n            indices_to_keep (typing.List[int]): List of indices to keep in the batch.\n            indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. 
If not None, it will be used instead of converting indices_to_keep to a tensor.\n        \"\"\"\n        if not self.is_required:\n            return\n\n        empty_indices = len(indices_to_keep) == 0\n\n        is_required = False\n        for penalizer in self.penalizers.values():\n            tmp_is_required = penalizer.is_required()\n            is_required = is_required or tmp_is_required\n            if not tmp_is_required or empty_indices:\n                penalizer.teardown()\n            else:\n                # create tensor index only when it's needed\n                if indices_tensor_to_keep is None:\n                    indices_tensor_to_keep = torch.tensor(\n                        indices_to_keep, dtype=torch.int32, device=self.device\n                    )\n\n                penalizer.filter(\n                    indices_to_keep=indices_to_keep,\n                    indices_tensor_to_keep=indices_tensor_to_keep,\n                )\n        self.is_required = is_required\n\n    def merge(self, their: \"BatchedPenalizerOrchestrator\"):\n        \"\"\"\n        Merge the penalizers of another orchestrator into this one.\n\n        Note that this function **must** be called _before_ self.batch.reqs is updated (filtered).\n        Each unprepared penalizers would have to be prepared (creating tensors, etc.) 
first before merging.\n        This step requires the original batch.reqs, before it gets merged with other batch.reqs.\n\n        Args:\n            their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one.\n        \"\"\"\n        if not self.is_required and not their.is_required:\n            return\n\n        self.is_required |= their.is_required\n        for Penalizer, their_penalizer in their.penalizers.items():\n            if Penalizer not in self.penalizers:\n                raise ValueError(f\"Penalizer {Penalizer} not found in self.penalizers\")\n\n            self.penalizers[Penalizer].merge(their_penalizer)\n\n\nclass _TokenIDs:\n    \"\"\"\n    A class that wraps token IDs to provide additional utility functions to penalizers.\n\n    Attributes:\n        orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to.\n        token_ids (typing.Union[torch.Tensor, typing.List[torch.Tensor]]): The token IDs.\n        cached_counts (torch.Tensor): The cached occurrence count tensor.\n    \"\"\"\n\n    orchestrator: BatchedPenalizerOrchestrator\n    token_ids: typing.Union[torch.Tensor, typing.List[torch.Tensor]]\n    cached_counts: torch.Tensor = None\n\n    def __init__(\n        self,\n        orchestrator: BatchedPenalizerOrchestrator,\n        token_ids: typing.Union[\n            typing.List[torch.Tensor], typing.List[typing.List[int]]\n        ],\n    ):\n        self.orchestrator = orchestrator\n\n        if not isinstance(token_ids[0], torch.Tensor):\n            token_ids = [\n                torch.tensor(\n                    data=ids, dtype=torch.int64, device=self.orchestrator.device\n                )\n                for ids in token_ids\n            ]\n\n        self.token_ids = token_ids\n\n    def occurrence_count(self) -> torch.Tensor:\n        \"\"\"\n        Returns a tensor of shape (batch_size, vocab_size) where each element is the number of times the corresponding token appears in 
the batch.\n\n        Returns:\n            torch.Tensor: The occurrence count tensor.\n        \"\"\"\n        if self.cached_counts is not None:\n            return self.cached_counts\n\n        token_ids = self.token_ids\n\n        if isinstance(token_ids, torch.Tensor):\n            token_ids = token_ids.unsqueeze(1)\n\n            # needs to be long to be used as index in scatter_add\n            if token_ids.dtype != torch.int64:\n                token_ids = token_ids.to(torch.int64)\n\n        padded_token_ids = torch.nn.utils.rnn.pad_sequence(\n            sequences=token_ids,\n            batch_first=True,\n            padding_value=self.orchestrator.vocab_size,\n        )\n\n        self.cached_counts = torch.zeros(\n            size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),\n            dtype=torch.int64,\n            device=self.orchestrator.device,\n        ).scatter_add_(\n            dim=1,\n            index=padded_token_ids,\n            src=torch.ones_like(padded_token_ids),\n        )[\n            :, : self.orchestrator.vocab_size\n        ]\n\n        return self.cached_counts\n\n\nclass _BatchedPenalizer(abc.ABC):\n    \"\"\"\n    An abstract class for a batched penalizer.\n    \"\"\"\n\n    orchestrator: BatchedPenalizerOrchestrator\n    _is_prepared: bool = False\n\n    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):\n        self.orchestrator = orchestrator\n\n    def is_prepared(self) -> bool:\n        return self._is_prepared\n\n    def is_required(self) -> bool:\n        return self._is_required()\n\n    def prepare(self):\n        if not self.is_prepared():\n            self._prepare()\n            self._is_prepared = True\n\n    def prepare_if_required(self):\n        if self.is_required():\n            self.prepare()\n            return True\n        else:\n            return False\n\n    def teardown(self):\n        if self.is_prepared():\n            self._teardown()\n            
self._is_prepared = False\n\n    def cumulate_input_tokens(self, input_ids: _TokenIDs):\n        if not self.is_prepared():\n            return\n\n        self._cumulate_input_tokens(input_ids=input_ids)\n\n    def cumulate_output_tokens(self, output_ids: _TokenIDs):\n        if not self.is_prepared():\n            return\n\n        self._cumulate_output_tokens(output_ids=output_ids)\n\n    def apply(self, logits: torch.Tensor) -> torch.Tensor:\n        if not self.is_prepared():\n            return logits\n\n        return self._apply(logits=logits)\n\n    def filter(\n        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor\n    ):\n        if not self.is_prepared():\n            return\n\n        self._filter(\n            indices_to_keep=indices_to_keep,\n            indices_tensor_to_keep=indices_tensor_to_keep,\n        )\n\n    def merge(self, their: \"_BatchedPenalizer\"):\n        if not self.is_prepared() and not their.is_prepared():\n            return\n\n        self.prepare()\n        their.prepare()\n        self._merge(their)\n\n    @abc.abstractmethod\n    def _is_required(self) -> bool:\n        \"\"\"\n        Check if the penalizer is required to be prepared.\n        \"\"\"\n        pass\n\n    @abc.abstractmethod\n    def _prepare(self):\n        \"\"\"\n        Prepare the penalizer.\n        Usually, this is where the penalizer initializes its tensors.\n        \"\"\"\n        pass\n\n    @abc.abstractmethod\n    def _teardown(self):\n        \"\"\"\n        Tear down the penalizer.\n        Usually, this is where the penalizer frees its tensors.\n        \"\"\"\n        pass\n\n    @abc.abstractmethod\n    def _cumulate_input_tokens(self, input_ids: _TokenIDs):\n        \"\"\"\n        Cumulate the input tokens.\n        Orchestrator will call this function to feed the input tokens to the penalizer.\n        \"\"\"\n        pass\n\n    @abc.abstractmethod\n    def _cumulate_output_tokens(self, output_ids: 
_TokenIDs):\n        \"\"\"\n        Cumulate the output tokens.\n        Orchestrator will call this function to feed the output tokens to the penalizer.\n        \"\"\"\n        pass\n\n    @abc.abstractmethod\n    def _apply(self, logits: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Apply the penalizer to the logits.\n        Penalizers can modify the logits in-place if needed.\n        \"\"\"\n        pass\n\n    @abc.abstractmethod\n    def _filter(\n        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor\n    ):\n        \"\"\"\n        Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch.\n        \"\"\"\n        pass\n\n    @abc.abstractmethod\n    def _merge(self, their: \"_BatchedPenalizer\"):\n        \"\"\"\n        Merge the penalizer with another penalizer.\n        \"\"\"\n        pass\n"
  },
  {
    "path": "archive/ktransformers/server/balance_serve/inference/sampling/penaltylib/penalizers/frequency_penalty.py",
    "content": "import typing\n\nimport torch\n\nfrom ..orchestrator import _BatchedPenalizer, _TokenIDs\n\n\nclass BatchedFrequencyPenalizer(_BatchedPenalizer):\n    \"\"\"\n    Frequency penalizer penalizes tokens based on their frequency in the output.\n    \"\"\"\n\n    frequency_penalties: torch.Tensor = None\n    cumulated_frequency_penalties: torch.Tensor = None\n\n    def _is_required(self) -> bool:\n        return any(\n            req.sampling_params.frequency_penalty != 0.0\n            for req in self.orchestrator.reqs()\n        )\n\n    def _prepare(self):\n        self.cumulated_frequency_penalties = (\n            torch.tensor(\n                data=[0.0 for _ in self.orchestrator.reqs()],\n                dtype=torch.float32,\n                device=self.orchestrator.device,\n            )\n            .unsqueeze_(1)\n            .repeat(1, self.orchestrator.vocab_size)\n        )\n\n        self.frequency_penalties = (\n            torch.tensor(\n                data=[\n                    req.sampling_params.frequency_penalty\n                    for req in self.orchestrator.reqs()\n                ],\n                dtype=torch.float32,\n                device=self.orchestrator.device,\n            )\n            .unsqueeze_(1)\n            .expand_as(self.cumulated_frequency_penalties)\n        )\n\n    def _teardown(self):\n        del self.frequency_penalties\n        del self.cumulated_frequency_penalties\n\n        self.frequency_penalties = None\n        self.cumulated_frequency_penalties = None\n\n    def _cumulate_input_tokens(self, input_ids: _TokenIDs):\n        pass\n\n    def _cumulate_output_tokens(self, output_ids: _TokenIDs):\n        self.cumulated_frequency_penalties += (\n            self.frequency_penalties * output_ids.occurrence_count()\n        )\n\n    def _apply(self, logits: torch.Tensor) -> torch.Tensor:\n        logits -= self.cumulated_frequency_penalties\n        return logits\n\n    def _filter(\n        self, 
indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor\n    ):\n        self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep]\n        self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[\n            indices_tensor_to_keep\n        ]\n\n    def _merge(self, their: \"BatchedFrequencyPenalizer\"):\n        self.frequency_penalties = torch.cat(\n            [self.frequency_penalties, their.frequency_penalties], dim=0\n        )\n        self.cumulated_frequency_penalties = torch.cat(\n            [self.cumulated_frequency_penalties, their.cumulated_frequency_penalties],\n            dim=0,\n        )\n"
  },
  {
    "path": "archive/ktransformers/server/balance_serve/inference/sampling/penaltylib/penalizers/min_new_tokens.py",
    "content": "import typing\n\nimport torch\n\nfrom ..orchestrator import _BatchedPenalizer, _TokenIDs\n\n\nclass BatchedMinNewTokensPenalizer(_BatchedPenalizer):\n    \"\"\"\n    Min new tokens penalizer penalizes tokens based on the length of the output.\n    \"\"\"\n\n    min_new_tokens: torch.Tensor = None\n    stop_token_penalties: torch.Tensor = None\n    len_output_tokens: torch.Tensor = None\n\n    def _is_required(self) -> bool:\n        return any(\n            req.sampling_params.min_new_tokens > 0 for req in self.orchestrator.reqs()\n        )\n\n    def _prepare(self):\n        self.min_new_tokens = torch.tensor(\n            data=[\n                req.sampling_params.min_new_tokens for req in self.orchestrator.reqs()\n            ],\n            dtype=torch.int32,\n            device=self.orchestrator.device,\n        ).unsqueeze_(1)\n\n        padded_stop_token_ids = torch.nn.utils.rnn.pad_sequence(\n            sequences=[\n                torch.tensor(\n                    data=(\n                        list(\n                            (req.sampling_params.stop_token_ids or set())\n                            | (req.tokenizer.additional_stop_token_ids or set())\n                            | {req.tokenizer.eos_token_id}\n                        )\n                    ),\n                    dtype=torch.int64,\n                    device=self.orchestrator.device,\n                )\n                for req in self.orchestrator.reqs()\n            ],\n            batch_first=True,\n            padding_value=self.orchestrator.vocab_size,\n        )\n        self.stop_token_penalties = torch.zeros(\n            size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),\n            dtype=torch.float32,\n            device=self.orchestrator.device,\n        ).scatter_add_(\n            dim=1,\n            index=padded_stop_token_ids,\n            src=torch.full_like(\n                input=padded_stop_token_ids,\n                
dtype=torch.float32,\n                fill_value=float(\"-inf\"),\n                device=self.orchestrator.device,\n            ),\n        )[\n            :, : self.orchestrator.vocab_size\n        ]\n\n        self.len_output_tokens = torch.zeros(\n            size=(self.orchestrator.batch_size(), 1),\n            dtype=torch.int32,\n            device=self.orchestrator.device,\n        )\n\n    def _teardown(self):\n        del self.min_new_tokens\n        del self.stop_token_penalties\n        del self.len_output_tokens\n\n        self.min_new_tokens = None\n        self.stop_token_penalties = None\n        self.len_output_tokens = None\n\n    def _cumulate_input_tokens(self, input_ids: _TokenIDs):\n        pass\n\n    def _cumulate_output_tokens(self, output_ids: _TokenIDs):\n        self.len_output_tokens += 1\n\n    def _apply(self, logits: torch.Tensor) -> torch.Tensor:\n        mask = (self.len_output_tokens < self.min_new_tokens).expand_as(logits)\n        logits[mask] += self.stop_token_penalties[mask]\n        return logits\n\n    def _filter(\n        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor\n    ):\n        self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep]\n        self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep]\n        self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep]\n\n    def _merge(self, their: \"BatchedMinNewTokensPenalizer\"):\n        self.min_new_tokens = torch.cat(\n            [self.min_new_tokens, their.min_new_tokens], dim=0\n        )\n        self.stop_token_penalties = torch.cat(\n            [self.stop_token_penalties, their.stop_token_penalties], dim=0\n        )\n        self.len_output_tokens = torch.cat(\n            [self.len_output_tokens, their.len_output_tokens], dim=0\n        )\n"
  },
  {
    "path": "archive/ktransformers/server/balance_serve/inference/sampling/penaltylib/penalizers/presence_penalty.py",
    "content": "import typing\n\nimport torch\n\nfrom ..orchestrator import _BatchedPenalizer, _TokenIDs\n\n\nclass BatchedPresencePenalizer(_BatchedPenalizer):\n    \"\"\"\n    Presence penalizer penalizes tokens based on their presence in the output.\n    \"\"\"\n\n    presence_penalties: torch.Tensor = None\n    cumulated_presence_penalties: torch.Tensor = None\n\n    def _is_required(self) -> bool:\n        return any(\n            req.sampling_params.presence_penalty != 0.0\n            for req in self.orchestrator.reqs()\n        )\n\n    def _prepare(self):\n        self.cumulated_presence_penalties = (\n            torch.tensor(\n                data=[0.0 for _ in self.orchestrator.reqs()],\n                dtype=torch.float32,\n                device=self.orchestrator.device,\n            )\n            .unsqueeze_(1)\n            .repeat(1, self.orchestrator.vocab_size)\n        )\n\n        self.presence_penalties = (\n            torch.tensor(\n                data=[\n                    req.sampling_params.presence_penalty\n                    for req in self.orchestrator.reqs()\n                ],\n                dtype=torch.float32,\n                device=self.orchestrator.device,\n            )\n            .unsqueeze_(1)\n            .expand_as(self.cumulated_presence_penalties)\n        )\n\n    def _teardown(self):\n        del self.presence_penalties\n        del self.cumulated_presence_penalties\n\n        self.presence_penalties = None\n        self.cumulated_presence_penalties = None\n\n    def _cumulate_input_tokens(self, input_ids: _TokenIDs):\n        pass\n\n    def _cumulate_output_tokens(self, output_ids: _TokenIDs):\n        mask = output_ids.occurrence_count() > 0\n        self.cumulated_presence_penalties[mask] = self.presence_penalties[mask]\n\n    def _apply(self, logits: torch.Tensor) -> torch.Tensor:\n        logits -= self.cumulated_presence_penalties\n        return logits\n\n    def _filter(\n        self, indices_to_keep: 
typing.List[int], indices_tensor_to_keep: torch.Tensor\n    ):\n        self.presence_penalties = self.presence_penalties[indices_tensor_to_keep]\n        self.cumulated_presence_penalties = self.cumulated_presence_penalties[\n            indices_tensor_to_keep\n        ]\n\n    def _merge(self, their: \"BatchedPresencePenalizer\"):\n        self.presence_penalties = torch.cat(\n            [self.presence_penalties, their.presence_penalties], dim=0\n        )\n        self.cumulated_presence_penalties = torch.cat(\n            [self.cumulated_presence_penalties, their.cumulated_presence_penalties],\n            dim=0,\n        )\n"
  },
  {
    "path": "archive/ktransformers/server/balance_serve/inference/sampling/penaltylib/penalizers/repetition_penalty.py",
    "content": "import typing\n\nimport torch\n\nfrom ..orchestrator import _BatchedPenalizer, _TokenIDs\n\n\nclass BatchedRepetitionPenalizer(_BatchedPenalizer):\n    \"\"\"\n    Repetition penalizer penalizes tokens based on their repetition in the input and output.\n    \"\"\"\n\n    repetition_penalties: torch.Tensor = None\n    cumulated_repetition_penalties: torch.Tensor = None\n\n    def _is_required(self) -> bool:\n        return any(\n            req.sampling_params.repetition_penalty != 1.0\n            for req in self.orchestrator.reqs()\n        )\n\n    def _prepare(self):\n        self.cumulated_repetition_penalties = (\n            torch.tensor(\n                data=[1.0 for _ in self.orchestrator.reqs()],\n                dtype=torch.float32,\n                device=self.orchestrator.device,\n            )\n            .unsqueeze_(1)\n            .repeat(1, self.orchestrator.vocab_size)\n        )\n\n        self.repetition_penalties = (\n            torch.tensor(\n                data=[\n                    req.sampling_params.repetition_penalty\n                    for req in self.orchestrator.reqs()\n                ],\n                dtype=torch.float32,\n                device=self.orchestrator.device,\n            )\n            .unsqueeze_(1)\n            .expand_as(self.cumulated_repetition_penalties)\n        )\n\n    def _teardown(self):\n        del self.repetition_penalties\n        del self.cumulated_repetition_penalties\n\n        self.repetition_penalties = None\n        self.cumulated_repetition_penalties = None\n\n    def _cumulate_input_tokens(self, input_ids: _TokenIDs):\n        mask = input_ids.occurrence_count() > 0\n        self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]\n\n    def _cumulate_output_tokens(self, output_ids: _TokenIDs):\n        mask = output_ids.occurrence_count() > 0\n        self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]\n\n    def _apply(self, 
logits: torch.Tensor) -> torch.Tensor:\n        return torch.where(\n            logits > 0,\n            logits / self.cumulated_repetition_penalties,\n            logits * self.cumulated_repetition_penalties,\n        )\n\n    def _filter(\n        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor\n    ):\n        self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]\n        self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[\n            indices_tensor_to_keep\n        ]\n\n    def _merge(self, their: \"BatchedRepetitionPenalizer\"):\n        self.repetition_penalties = torch.cat(\n            [self.repetition_penalties, their.repetition_penalties], dim=0\n        )\n        self.cumulated_repetition_penalties = torch.cat(\n            [self.cumulated_repetition_penalties, their.cumulated_repetition_penalties],\n            dim=0,\n        )\n"
  },
  {
    "path": "archive/ktransformers/server/balance_serve/inference/sampling/sampler.py",
    "content": "'''\nDate: 2024-11-14 12:23:45\nLastEditors: Xie Weiyu ervinxie@qq.com\nLastEditTime: 2024-11-25 08:59:23\n'''\nimport logging\nimport torch\nfrom torch import nn\nfrom transformers import GenerationConfig\n\nfrom flashinfer.sampling import (\n\tmin_p_sampling_from_probs,\n\ttop_k_renorm_probs,\n\ttop_k_top_p_sampling_from_logits,\n\ttop_p_renorm_probs,\n)\n\ntry:\n    import torch_npu\n    use_torch_npu = torch.npu.is_available()\nexcept:\n    use_torch_npu = False\nlogger = logging.getLogger(__name__)\n\nclass SamplingOptions():\n\t# Batched sampling params\n\ttemperatures: torch.Tensor\n\ttop_ps: torch.Tensor\n\ttop_ks: torch.Tensor\n\tmin_ps: torch.Tensor\n\n\t# All requests use greedy sampling\n\tis_all_greedy: bool\n\n\t# Dispatch in CUDA graph\n\tneed_min_p_sampling: bool\n\t\n\tdef __init__(self, bsz = 1, device = torch.device('cuda'), pretrained_config:GenerationConfig = None, temperatures: torch.Tensor = None, top_ps: torch.Tensor = None):\n\t\tif pretrained_config is None and temperatures is None:\n\t\t\tself.temperatures = torch.full((bsz, 1), 0, device=device, dtype=torch.float32)\n\t\t\tself.top_ps = torch.ones((bsz, 1), device=device, dtype=torch.float32)\n\t\t\tself.top_ks = torch.ones((bsz, 1), device=device, dtype=torch.float32)\n\t\t\tself.need_min_p_sampling = False\n\t\t\tself.is_all_greedy = True\n\t\telse:\n\t\t\tif temperatures is not None:\n\t\t\t\tself.temperatures = temperatures.unsqueeze(-1)\n\t\t\telse:\n\t\t\t\tself.temperatures = torch.full((bsz, 1), pretrained_config.temperature, device=device, dtype=torch.float32)\n\t\t\t\n\t\t\tif top_ps is not None:\n\t\t\t\tself.top_ps = top_ps.unsqueeze(-1)\n\t\t\telse:\t\n\t\t\t\tself.top_ps = torch.full((bsz, 1), pretrained_config.top_p, device=device, dtype=torch.float32)\n\t\t\tself.top_ks = torch.full((bsz, 1), pretrained_config.top_k, device=device, dtype=torch.float32)\n\t\t\tself.need_min_p_sampling = False\n\t\t\tself.is_all_greedy = False\n\nclass 
Sampler(nn.Module):\n\tdef __init__(self):\n\t\tsuper().__init__()\n\t\n\tdef forward(\n\t\tself,\n\t\tlogits: torch.Tensor,\n\t\tsampling_config: SamplingOptions = None,\n\t):\n\t\tif sampling_config is None:\n\t\t\tsampling_config = SamplingOptions()\n\n\t\tlogits = logits.contiguous()\n\t\torigin_logits = logits.clone()\n\t\tif sampling_config.is_all_greedy or use_torch_npu:\n\t\t\t# Use torch.argmax if all requests use greedy sampling\n\t\t\tprobs = torch.softmax(logits, dim=-1)\n\t\t\tbatch_next_token_ids = torch.argmax(logits, -1)\n\t\telse:\n\t\t\t# Post process logits\n\t\t\tlogits.div_(sampling_config.temperatures)\n\t\t\tmax_top_k_round, batch_size = 32, logits.shape[0]\n\t\t\tif sampling_config.need_min_p_sampling:\n\t\t\t\tprobs = torch.softmax(logits, dim=-1)\n\t\t\t\tlogits = None\n\t\t\t\tdel logits\n\t\t\t\tprobs = top_k_renorm_probs(probs, sampling_config.top_ks)\n\t\t\t\tprobs = top_p_renorm_probs(probs, sampling_config.top_ps)\n\t\t\t\tbatch_next_token_ids = min_p_sampling_from_probs(\n\t\t\t\t\tprobs, sampling_config.min_ps\n\t\t\t\t)\n\t\t\t\ttemperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]\n\t\t\t\tbatch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)\n\t\t\telse:\n\t\t\t\t# TODO: use different kernel when don't need top_k or top_p\n\t\t\t\t# @TODO get probs\n\t\t\t\tprobs = logits\n\t\t\t\tbatch_next_token_ids = top_k_top_p_sampling_from_logits(\n\t\t\t\t\tlogits,\n\t\t\t\t\tsampling_config.top_ks,\n\t\t\t\t\tsampling_config.top_ps,\n\t\t\t\t\tfilter_apply_order=\"joint\",\n\t\t\t\t)\n\t\t\t\ttemperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]\n\t\t\t\tbatch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)\n\t\t\t\n\t\treturn batch_next_token_ids.to(torch.int32), probs"
  },
  {
    "path": "archive/ktransformers/server/balance_serve/sched_rpc.py",
    "content": "from datetime import datetime\nimport os\nfrom typing import Optional\nimport zmq\nimport pickle\nimport threading\nimport torch.multiprocessing as mp\nimport sys\ncurrent_file_path = os.path.abspath(__file__)\n# sys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\", \"..\", \"..\"))\nimport pickle\nimport argparse\nimport torch\ntry:\n    import torch_npu\n    use_npu = torch.npu.is_available()\nexcept:\n    use_npu = False\nfrom ktransformers.server.balance_serve.settings import sched_ext, create_sched_settings, create_sched_settings_qwen2moe, create_sched_settings_qwen3moe, create_sched_settings_glm4moe, create_sched_settings_smallthinker, create_sched_settings_qwen3next\n\n\n\n\nif mp.get_start_method(allow_none=True) is None:\n    print('set start method')\n    mp.set_start_method('spawn')\nelse:\n    print(f'start method already set to {mp.get_start_method(allow_none=True)}')\n\n\nclass SchedulerServer:\n    def __init__(self, settings, main_args):\n        # Create and initialize the Scheduler instance\n        if use_npu:\n            for device_id in settings.gpu_device_id:\n                torch_npu.npu.set_device(f'npu:{device_id}')\n        self.sched = sched_ext.create_scheduler(settings)\n    \n        # Initialize the ZeroMQ context and sockets\n        self.context = zmq.Context()\n        self.frontend = self.context.socket(zmq.ROUTER)\n        print(f\"sched zmq rpc server on port {main_args.sched_port}\")\n        self.frontend.bind(f\"tcp://*:{main_args.sched_port}\") \n\n        # Create the internal DEALER socket used to communicate with the worker threads\n        self.backend = self.context.socket(zmq.DEALER)\n        self.backend.bind(\"inproc://backend\")\n\n    # Start the scheduler\n    def run_scheduler(self):\n        self.sched.run()\n\n    # Stop the scheduler\n    def stop_scheduler(self):\n        self.sched.stop()\n\n    # Handle client requests\n    def start_proxy(self):\n        # Use ZMQ's built-in proxy to dispatch frontend requests to the backend worker threads\n        zmq.proxy(self.frontend, self.backend)\n\n    # Worker thread that handles requests\n    def worker_routine(self):\n        worker = 
self.context.socket(zmq.REP)\n        worker.connect(\"inproc://backend\")\n        while True:\n            try:\n                # Receive the client request\n                message = worker.recv()\n                data = pickle.loads(message)\n\n                method = data.get('method')\n                params = data.get('params', {})\n                # print(f\"Received request: {method}\")\n\n                if method == 'add_query':\n                    query_add = params.get('query')  # passed directly as a QueryAdd object\n                    # Add the query\n                    query_id = self.sched.add_query(query_add)\n                    # Send the response\n                    response = {'status': 'ok', 'query_id': query_id}\n                    worker.send(pickle.dumps(response))\n\n                elif method == 'cancel_query':\n                    query_id = params.get('query_id')\n                    # Assumes the Scheduler class implements a cancel method\n                    self.sched.cancel(query_id)\n                    response = {'status': 'ok'}\n                    worker.send(pickle.dumps(response))\n\n                elif method == 'update_last_batch':\n                    updates = params.get('updates')  # passed directly as a list of QueryUpdate objects\n\n                    # Update the last batch\n                    batch_todo = self.sched.update_last_batch(updates)\n\n                    # Send the batch_todo object directly\n                    response = {'status': 'ok', 'batch_todo': batch_todo}\n                    # print (batch_todo.query_lengths, batch_todo.query_ids)\n                    worker.send(pickle.dumps(response))\n\n                elif method == 'get_inference_context':\n                    inference_context = self.sched.get_inference_context()\n                    data = {\n                        \"k_cache\":inference_context.k_cache,\n                        \"v_cache\":inference_context.v_cache\n                    }\n                    print(f\"Serializing KVCache\")\n                    data[\"k_cache\"] = [mp.reductions.reduce_tensor(t) 
for t in data['k_cache']]\n                    data[\"v_cache\"] = [mp.reductions.reduce_tensor(t) for t in data['v_cache']]\n                    # print(data)\n                    response = {'status': 'ok', 'inference_context': data}\n\n                    worker.send(pickle.dumps(response))\n                    # response['inference_context'].k_cache[0][0, 0, 0, 0, 0] = 1 \n                    # print(\"k_cache update\")\n\n                else:\n                    # Unknown method\n                    response = {'status': 'error', 'message': 'Unknown method'}\n                    worker.send(pickle.dumps(response))\n\n            except Exception as e:\n                # Handle the exception and send an error response\n                response = {'status': 'error', 'message': str(e)}\n                worker.send(pickle.dumps(response))\n\n    # Start the RPC service\n    def start_rpc_service(self):\n        try:\n            print(\"Scheduler RPC service is running...\")\n\n            # Run the scheduler in a separate thread\n            threading.Thread(target=self.run_scheduler, daemon=True).start()\n\n            # Start the worker threads\n            for _ in range(10):  # adjust the number of threads as needed\n                threading.Thread(target=self.worker_routine, daemon=True).start()\n\n            # Start the proxy and begin listening for requests\n            self.start_proxy()\n\n        except KeyboardInterrupt:\n            print(\"Shutting down scheduler RPC service...\")\n            self.stop_rpc_service()\n\n    # Stop the RPC service\n    def stop_rpc_service(self):\n        self.stop_scheduler()\n        self.frontend.close()\n        self.backend.close()\n        self.context.term()\n\ndef start_server(settings, main_args):\n    server = SchedulerServer(settings, main_args)\n    server.start_rpc_service()\n\n\n# Add async client for webserver\nclass SchedulerClient:\n    def __init__(self, sched_port):\n        address=f'tcp://localhost:{sched_port}'\n        self.address = address\n        self.context = zmq.Context()\n        self.socket = self.context.socket(zmq.REQ)\n        
self.socket.connect(self.address)\n        print(f\"Connected to server at {self.address}\")\n    \n    def __del__(self):\n        self.socket.close()\n        self.context.term()\n    \n    def send_request(self, method, params=None):\n        if params is None:\n            params = {}\n        request = {\n            'method': method,\n            'params': params\n        }\n        # print(f'send request {request}')\n        self.socket.send(pickle.dumps(request))\n        response = self.socket.recv()\n        # print(response)\n        response = pickle.loads(response)\n        if response.get('status') == 'ok':\n            return response\n        else:\n            raise Exception(f\"Error from server: {response.get('message')}\")\n    \n    def add_query(self, query):\n        response = self.send_request('add_query', {'query': query})\n        return response.get('query_id')\n    \n    def cancel_query(self, query_id):\n        self.send_request('cancel_query', {'query_id': query_id})\n    \n    def update_last_batch(self, updates):\n        response = self.send_request('update_last_batch', {'updates': updates})\n        # print(f\"update_last_batch response {response}\")\n        return response.get('batch_todo')\n    \n    def rebuild_inferece_context(self,response):\n        data = response.get('inference_context')\n        inference_context = sched_ext.InferenceContext()\n        print('Rebuilding kvcache')\n        inference_context.k_cache = [fn(*args) for fn,args in data['k_cache']]\n        inference_context.v_cache = [fn(*args) for fn,args in data['v_cache']]\n        return inference_context\n\n    def get_inference_context_raw(self):\n        response = self.send_request('get_inference_context')\n        return response\n       \n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--config\", type=str, required=True)\n    args = parser.parse_args()\n    with open(args.config, \"rb\") as f:\n     
   main_args = pickle.load(f)\n    if main_args.architectures == \"Qwen2MoeForCausalLM\": \n        settings = create_sched_settings_qwen2moe(main_args)\n    elif main_args.architectures == \"Qwen3MoeForCausalLM\":\n        settings = create_sched_settings_qwen3moe(main_args)\n    elif main_args.architectures == \"Glm4MoeForCausalLM\":\n        settings = create_sched_settings_glm4moe(main_args)\n    elif main_args.architectures == \"SmallThinkerForCausalLM\":\n        settings = create_sched_settings_smallthinker(main_args)\n    elif main_args.architectures == \"Qwen3NextForCausalLM\":\n        settings = create_sched_settings_qwen3next(main_args)\n    else:\n        settings = create_sched_settings(main_args)\n    start_server(settings, main_args)\n"
  },
  {
    "path": "archive/ktransformers/server/balance_serve/settings.py",
    "content": "'''\nDate: 2024-11-13 09:43:39\nLastEditors: djw\nLastEditTime: 2024-11-18 16:41:03\n'''\nimport sys, os\nimport yaml, json\nfrom time import sleep\n\n\nimport sched_ext\nfrom transformers import AutoConfig\n\nfrom ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig\nfrom ktransformers.models.configuration_glm4_moe import Glm4MoeConfig\nfrom ktransformers.models.configuration_smallthinker import SmallthinkerConfig\nfrom ktransformers.models.configuration_qwen3_next import Qwen3NextConfig\n\ndef create_sched_settings(args):\n    default_sample_options = sched_ext.SampleOptions()\n    model_name = os.path.basename(os.path.normpath(args.model_dir))\n    input_model_settings = sched_ext.ModelSettings()\n    input_model_settings.model_path = args.model_dir\n    input_model_settings.params_count = int(0)\n    model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)\n    input_model_settings.layer_count = model_config.num_hidden_layers\n    input_model_settings.num_k_heads = 1 # model_config[\"num_key_value_heads\"]\n    input_model_settings.k_head_dim = 576\n    input_model_settings.bytes_per_params = 2\n    input_model_settings.bytes_per_kv_cache_element = 2\n    settings = sched_ext.Settings()\n    settings.model_name = model_name\n    settings.quant_type = \"BF16\"\n    settings.model_settings = input_model_settings\n    settings.page_size = args.page_size\n    settings.gpu_device_count = args.tp # only full tp supported now\n    settings.gpu_device_id = [i for i in range(settings.gpu_device_count)]\n    # settings.gpu_memory_size = args.cache_lens*576*2\n    settings.gpu_memory_size = args.gpu_memory_size\n    settings.memory_utilization_percentage = args.utilization_percentage\n    max_batch_size = args.max_batch_size\n    chunk_size = args.chunk_size\n\n    max_decode_batch_size = max_batch_size - 2\n\n    settings.max_batch_size = max_batch_size\n    settings.recommended_chunk_prefill_token_count = 
(chunk_size - max_decode_batch_size) // 2\n    settings.sample_options = default_sample_options\n    settings.sched_metrics_port = args.sched_metrics_port\n    settings.gpu_only = args.memory_gpu_only\n    settings.use_self_defined_head_dim = True\n    settings.self_defined_head_dim = 576\n    settings.full_kv_cache_on_each_gpu = True\n    settings.k_cache_on = True\n    settings.v_cache_on = False\n\n    settings.kvc2_root_path = args.kvc2_disk_path\n    settings.kvc2_config_path = args.kvc2_config_dir\n    settings.memory_pool_size_GB = args.cpu_memory_size_GB\n    settings.evict_count = 40\n    settings.kvc2_metrics_port = args.kvc2_metrics_port\n    settings.load_from_disk = False\n    settings.save_to_disk = True\n\n    settings.strategy_name = args.sched_strategy\n\n    settings.auto_derive()\n    return settings\n\n\ndef create_sched_settings_qwen2moe(args):\n    default_sample_options = sched_ext.SampleOptions()\n    model_name = os.path.basename(os.path.normpath(args.model_dir))\n    input_model_settings = sched_ext.ModelSettings()\n    input_model_settings.model_path = args.model_dir\n    input_model_settings.params_count = int(0)\n    model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)\n    input_model_settings.layer_count = model_config.num_hidden_layers\n    input_model_settings.num_k_heads = model_config.num_key_value_heads # model_config[\"num_key_value_heads\"]\n    input_model_settings.k_head_dim = 128\n    input_model_settings.bytes_per_params = 2\n    input_model_settings.bytes_per_kv_cache_element = 2\n    settings = sched_ext.Settings()\n    settings.model_name = model_name\n    settings.quant_type = \"BF16\"\n    settings.model_settings = input_model_settings\n    settings.page_size = args.page_size\n    settings.gpu_device_count = 1 # tp\n    settings.gpu_device_id = [i for i in range(settings.gpu_device_count)]\n    # settings.gpu_memory_size = args.cache_lens*576*2\n    settings.gpu_memory_size = 
args.gpu_memory_size\n    settings.memory_utilization_percentage = args.utilization_percentage\n    max_batch_size = args.max_batch_size\n    chunk_size = args.chunk_size\n\n    max_decode_batch_size = max_batch_size - 2\n\n    settings.max_batch_size = max_batch_size\n    settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2\n    settings.sample_options = default_sample_options\n    settings.sched_metrics_port = args.sched_metrics_port\n    settings.gpu_only = args.memory_gpu_only\n    settings.use_self_defined_head_dim = False\n    settings.self_defined_head_dim = 576\n    settings.full_kv_cache_on_each_gpu = True\n    settings.k_cache_on = True\n    settings.v_cache_on = True\n\n    settings.kvc2_root_path = args.kvc2_disk_path\n    settings.kvc2_config_path = args.kvc2_config_dir\n    settings.memory_pool_size_GB = args.cpu_memory_size_GB\n    settings.evict_count = 40\n    settings.kvc2_metrics_port = args.kvc2_metrics_port\n    settings.load_from_disk = False\n    settings.save_to_disk = True\n\n\n    settings.strategy_name = args.sched_strategy\n\n    settings.auto_derive()\n    return settings\n\n\n\ndef create_sched_settings_qwen3moe(args):\n    default_sample_options = sched_ext.SampleOptions()\n    model_name = os.path.basename(os.path.normpath(args.model_dir))\n    input_model_settings = sched_ext.ModelSettings()\n    input_model_settings.model_path = args.model_dir\n    input_model_settings.params_count = int(0)\n    model_config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)\n    input_model_settings.layer_count = model_config.num_hidden_layers\n    input_model_settings.num_k_heads = model_config.num_key_value_heads # model_config[\"num_key_value_heads\"]\n    input_model_settings.k_head_dim = 128\n    input_model_settings.bytes_per_params = 2\n    input_model_settings.bytes_per_kv_cache_element = 2\n    settings = sched_ext.Settings()\n    settings.model_name = model_name\n    
settings.quant_type = \"BF16\"\n    settings.model_settings = input_model_settings\n    settings.page_size = args.page_size\n    settings.gpu_device_count = 1 # tp\n    settings.gpu_device_id = [i for i in range(settings.gpu_device_count)]\n    # settings.gpu_memory_size = args.cache_lens*576*2\n    settings.gpu_memory_size = args.gpu_memory_size\n    settings.memory_utilization_percentage = args.utilization_percentage\n    max_batch_size = args.max_batch_size\n    chunk_size = args.chunk_size\n\n    max_decode_batch_size = max_batch_size - 2\n\n    settings.max_batch_size = max_batch_size\n    settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2\n    settings.sample_options = default_sample_options\n    settings.sched_metrics_port = args.sched_metrics_port\n    settings.gpu_only = args.memory_gpu_only\n    settings.use_self_defined_head_dim = False\n    settings.self_defined_head_dim = 576\n    settings.full_kv_cache_on_each_gpu = True\n    settings.k_cache_on = True\n    settings.v_cache_on = True\n\n    settings.kvc2_root_path = args.kvc2_disk_path\n    settings.kvc2_config_path = args.kvc2_config_dir\n    settings.memory_pool_size_GB = args.cpu_memory_size_GB\n    settings.evict_count = 40\n    settings.kvc2_metrics_port = args.kvc2_metrics_port\n    settings.load_from_disk = False\n    settings.save_to_disk = True\n\n\n    settings.strategy_name = args.sched_strategy\n\n    settings.auto_derive()\n    return settings\n\ndef create_sched_settings_glm4moe(args):\n    default_sample_options = sched_ext.SampleOptions()\n    model_name = os.path.basename(os.path.normpath(args.model_dir))\n    input_model_settings = sched_ext.ModelSettings()\n    input_model_settings.model_path = args.model_dir\n    input_model_settings.params_count = int(0)\n    model_config = Glm4MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)\n    input_model_settings.layer_count = model_config.num_hidden_layers\n    
input_model_settings.num_k_heads = model_config.num_key_value_heads # model_config[\"num_key_value_heads\"]\n    input_model_settings.k_head_dim = 128\n    input_model_settings.bytes_per_params = 2\n    input_model_settings.bytes_per_kv_cache_element = 2\n    settings = sched_ext.Settings()\n    settings.model_name = model_name\n    settings.quant_type = \"BF16\"\n    settings.model_settings = input_model_settings\n    settings.page_size = args.page_size\n    settings.gpu_device_count = 1 # tp\n    settings.gpu_device_id = [i for i in range(settings.gpu_device_count)]\n    # settings.gpu_memory_size = args.cache_lens*576*2\n    settings.gpu_memory_size = args.gpu_memory_size\n    settings.memory_utilization_percentage = args.utilization_percentage\n    max_batch_size = args.max_batch_size\n    chunk_size = args.chunk_size\n\n    max_decode_batch_size = max_batch_size - 2\n\n    settings.max_batch_size = max_batch_size\n    settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2\n    settings.sample_options = default_sample_options\n    settings.sched_metrics_port = args.sched_metrics_port\n    settings.gpu_only = args.memory_gpu_only\n    settings.use_self_defined_head_dim = False\n    settings.self_defined_head_dim = 576\n    settings.full_kv_cache_on_each_gpu = True\n    settings.k_cache_on = True\n    settings.v_cache_on = True\n\n    settings.kvc2_root_path = args.kvc2_disk_path\n    settings.kvc2_config_path = args.kvc2_config_dir\n    settings.memory_pool_size_GB = args.cpu_memory_size_GB\n    settings.evict_count = 40\n    settings.kvc2_metrics_port = args.kvc2_metrics_port\n    settings.load_from_disk = False\n    settings.save_to_disk = True\n\n\n    settings.strategy_name = args.sched_strategy\n\n    settings.auto_derive()\n    return settings\n\ndef create_sched_settings_smallthinker(args):\n    default_sample_options = sched_ext.SampleOptions()\n    model_name = os.path.basename(os.path.normpath(args.model_dir))\n    
input_model_settings = sched_ext.ModelSettings()\n    input_model_settings.model_path = args.model_dir\n    input_model_settings.params_count = int(0)\n    model_config = SmallthinkerConfig.from_pretrained(args.model_dir, trust_remote_code=True)\n    input_model_settings.layer_count = model_config.num_hidden_layers\n    input_model_settings.num_k_heads = model_config.num_key_value_heads # model_config[\"num_key_value_heads\"]\n    input_model_settings.k_head_dim = 128\n    input_model_settings.bytes_per_params = 2\n    input_model_settings.bytes_per_kv_cache_element = 2\n    settings = sched_ext.Settings()\n    settings.model_name = model_name\n    settings.quant_type = \"BF16\"\n    settings.model_settings = input_model_settings\n    settings.page_size = args.page_size\n    settings.gpu_device_count = 1 # tp\n    settings.gpu_device_id = [i for i in range(settings.gpu_device_count)]\n    # settings.gpu_memory_size = args.cache_lens*576*2\n    settings.gpu_memory_size = args.gpu_memory_size\n    settings.memory_utilization_percentage = args.utilization_percentage\n    max_batch_size = args.max_batch_size\n    chunk_size = args.chunk_size\n\n    max_decode_batch_size = max_batch_size - 2\n\n    settings.max_batch_size = max_batch_size\n    settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2\n    settings.sample_options = default_sample_options\n    settings.sched_metrics_port = args.sched_metrics_port\n    settings.gpu_only = args.memory_gpu_only\n    settings.use_self_defined_head_dim = False\n    settings.self_defined_head_dim = 576\n    settings.full_kv_cache_on_each_gpu = True\n    settings.k_cache_on = True\n    settings.v_cache_on = True\n\n    settings.kvc2_root_path = args.kvc2_disk_path\n    settings.kvc2_config_path = args.kvc2_config_dir\n    settings.memory_pool_size_GB = args.cpu_memory_size_GB\n    settings.evict_count = 40\n    settings.kvc2_metrics_port = args.kvc2_metrics_port\n    settings.load_from_disk = 
False\n    settings.save_to_disk = True\n\n\n    settings.strategy_name = args.sched_strategy\n\n    settings.auto_derive()\n    return settings\n\ndef create_sched_settings_qwen3next(args):\n    default_sample_options = sched_ext.SampleOptions()\n    model_name = os.path.basename(os.path.normpath(args.model_dir))\n    input_model_settings = sched_ext.ModelSettings()\n    input_model_settings.model_path = args.model_dir\n    input_model_settings.params_count = int(0)\n    model_config = Qwen3NextConfig.from_pretrained(args.model_dir, trust_remote_code=True)\n    input_model_settings.layer_count = model_config.num_hidden_layers\n    input_model_settings.num_k_heads = model_config.num_key_value_heads # model_config[\"num_key_value_heads\"]\n    input_model_settings.k_head_dim = 256\n    input_model_settings.bytes_per_params = 2\n    input_model_settings.bytes_per_kv_cache_element = 2\n    settings = sched_ext.Settings()\n    settings.model_name = model_name\n    settings.quant_type = \"BF16\"\n    settings.model_settings = input_model_settings\n    settings.page_size = args.page_size\n    settings.gpu_device_count = 1 # tp\n    settings.gpu_device_id = [i for i in range(settings.gpu_device_count)]\n    # settings.gpu_memory_size = args.cache_lens*576*2\n    settings.gpu_memory_size = args.gpu_memory_size\n    settings.memory_utilization_percentage = args.utilization_percentage\n    max_batch_size = args.max_batch_size\n    chunk_size = args.chunk_size\n\n    max_decode_batch_size = max_batch_size - 2\n\n    settings.max_batch_size = max_batch_size\n    settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2\n    settings.sample_options = default_sample_options\n    settings.sched_metrics_port = args.sched_metrics_port\n    settings.gpu_only = args.memory_gpu_only\n    settings.use_self_defined_head_dim = False\n    settings.self_defined_head_dim = 576\n    settings.full_kv_cache_on_each_gpu = True\n    settings.k_cache_on = True\n   
 settings.v_cache_on = True\n\n    settings.kvc2_root_path = args.kvc2_disk_path\n    settings.kvc2_config_path = args.kvc2_config_dir\n    settings.memory_pool_size_GB = args.cpu_memory_size_GB\n    settings.evict_count = 40\n    settings.kvc2_metrics_port = args.kvc2_metrics_port\n    settings.load_from_disk = False\n    settings.save_to_disk = True\n\n\n    settings.strategy_name = args.sched_strategy\n\n    settings.auto_derive()\n    return settings"
  },
  {
    "path": "archive/ktransformers/server/config/config.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nDescription  :\nAuthor       : unicornchan\nDate         : 2024-06-11 16:35:42\nVersion      : 1.0.0\nLastEditors  : WuHao\nLastEditTime : 2024-08-12 06:31:14\n\"\"\"\nimport os\nimport shutil\nimport yaml\nimport psutil\n\nfrom ktransformers.server.config.singleton import Singleton\nfrom typing import Optional\n\n\nclass Config(metaclass=Singleton):\n    \"\"\"Singleton pattern Config class, used to get all configurations.\"\"\"\n\n    CONFIG_FILE_NAME = \"config.yaml\"\n\n    @staticmethod\n    def load() -> dict:\n        \"\"\"load config file\n\n        Returns:\n            dict: all configs\n        \"\"\"\n        base_path: str = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))\n        config_yaml: str = os.path.join(base_path, \"configs\", Config.CONFIG_FILE_NAME)\n\n        user_path: str = os.path.expanduser(\"~\")\n        localstore_path: str = os.path.join(user_path, \".ktransformers\")\n        kvc2_config_dir = os.path.join(localstore_path, \"kvc2\")\n        config_path: str = os.path.join(localstore_path, Config.CONFIG_FILE_NAME)\n        if not os.path.exists(config_yaml):\n            print(f\"Can't find config file, {config_yaml}\")\n            exit(-1)\n        if not os.path.exists(localstore_path):\n            os.mkdir(localstore_path)\n        if not os.path.exists(kvc2_config_dir):\n            os.mkdir(kvc2_config_dir)\n        if not os.path.exists(config_path):\n            shutil.copyfile(config_yaml, config_path)\n        with open(config_path, \"r\", encoding=\"utf-8\") as fp:\n            config = yaml.safe_load(fp)\n        return config\n\n    @staticmethod\n    def to_path(path: str) -> str:\n        \"\"\"\n        process file path\n        \"\"\"\n        base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))\n        real_path = path if os.path.isabs(path) else os.path.join(base_path, path)\n        return real_path\n\n    def 
__init__(self):\n        cfg = Config.load()\n        self.base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))\n        self.user_path: str = os.path.expanduser(\"~\")\n        self.localstore_path: str = os.path.join(self.user_path, \".ktransformers\")\n        # log configs\n        self.log_dir = os.path.join(self.localstore_path, cfg[\"log\"][\"dir\"])\n        if not os.path.exists(self.log_dir):\n            os.mkdir(self.log_dir)\n        self.log_file = cfg[\"log\"][\"file\"]\n        self.log_level = cfg[\"log\"][\"level\"]\n        self.backup_count = cfg[\"log\"][\"backup_count\"]\n\n        self.kvc2_config_dir = os.path.join(self.localstore_path, \"kvc2\")\n        # server configs\n        self.server: dict = cfg.get(\"server\", {})\n        self.server_ip = self.server.get(\"ip\", \"0.0.0.0\")\n        self.server_port = self.server.get(\"port\", 9016)\n        self.api_key = self.server.get(\"api_key\", \"\")\n\n        # db configs\n        self.db_configs: dict = cfg.get(\"db\", {})\n        self.db_type = self.db_configs.get(\"type\", \"\")\n        self.db_host = self.localstore_path\n        self.db_port = self.db_configs.get(\"port\", \"\")\n        self.db_name = self.db_configs.get(\"database\", \"\")\n        self.db_pool_size = self.db_configs.get(\"pool_size\")\n        self.db_database = self.db_configs.get(\"database\", \"\")\n\n        # user config\n        self.user_config: dict = cfg.get(\"user\", {})\n        self.user_secret_key = self.user_config.get(\"secret_key\", \"\")\n        self.user_algorithm = self.user_config.get(\"algorithm\", \"\")\n        self.user_force_think = self.user_config.get(\"force_think\", False)\n\n        # model config\n        self.model: dict = cfg.get(\"model\", {})\n        self.backend_type: str = self.model.get(\"type\", \"transformers\")\n        self.model_dir: str = self.model.get(\"path\", \"\")\n        # to make sure it consistent with previous version\n        
self.model_path: str = self.model_dir\n        self.model_name: str = self.model.get(\"name\", \"\")\n        self.architectures: str = self.model.get(\"name\", \"\")\n        self.model_device: str = self.model.get(\"device\", \"cuda:0\")\n        self.gguf_path: Optional[str] = self.model.get(\"gguf_path\", None)\n        self.use_cuda_graph = self.model.get(\"use_cuda_graph\", True)\n        self.trust_remote_code = self.model.get(\"trust_remote_code\", True)\n        # self.model_cache_lens = self.model.get(\"cache_lens\")\n        self.optimize_config_path: Optional[str] = self.model.get(\n            \"optimize_config_path\", None\n        )\n        \n        self.max_new_tokens = self.model.get(\"max_new_tokens\", 2000)\n        self.json_mode = self.model.get(\"json_mode\", False)\n        self.healing = self.model.get(\"healing\", False)\n        self.ban_strings: Optional[list] = self.model.get(\"ban_strings\", None)\n        self.gpu_split: Optional[str] = self.model.get(\"gpu_split\", None)\n        self.length: Optional[int] = self.model.get(\"length\", None)\n        self.rope_scale: Optional[float] = self.model.get(\"rope_scale\", None)\n        self.rope_alpha: Optional[float] = self.model.get(\"rope_alpha\", None)\n        self.no_flash_attn = self.model.get(\"no_flash_attn\", False)\n        self.low_mem = self.model.get(\"low_mem\", False)\n        self.experts_per_token: Optional[int] = self.model.get(\"experts_per_token\", None)\n        self.load_q4 = self.model.get(\"load_q4\", False)\n        self.fast_safetensors = self.model.get(\"fast_safetensors\", False)\n        self.draft_model_dir: Optional[str] = self.model.get(\"draft_model_dir\", None)\n        self.no_draft_scale = self.model.get(\"no_draft_scale\", False)\n        self.modes = self.model.get(\"modes\", False)\n        self.mode = self.model.get(\"mode\", \"llama\")\n        self.username = self.model.get(\"username\", \"User\")\n        self.botname = 
self.model.get(\"botname\", \"Chatbort\")\n        self.system_prompt: Optional[str] = self.model.get(\"system_prompt\", None)\n        self.temperature = self.model.get(\"temperature\", 0.95)\n        self.smoothing_factor = self.model.get(\"smoothing_factor\", 0.0)\n        self.dynamic_temperature: Optional[str] = self.model.get(\"dynamic_temperature\", None)\n        self.top_k = self.model.get(\"top_k\", 50)\n        self.top_p = self.model.get(\"top_p\", 0.8)\n        self.top_a = self.model.get(\"top_a\", 0.0)\n        self.skew = self.model.get(\"skew\", 0.0)\n        self.typical = self.model.get(\"typical\", 0.0)\n        self.repetition_penalty = self.model.get(\"repetition_penalty\", 1.01)\n        self.frequency_penalty = self.model.get(\"frequency_penalty\", 0.0)\n        self.presence_penalty = self.model.get(\"presence_penalty\", 0.0)\n        self.response_chunk = self.model.get(\"response_chunk\", 250)\n        self.no_code_formatting = self.model.get(\"no_code_formatting\", False)\n        self.cache_8bit = self.model.get(\"cache_8bit\", False)\n        self.cache_q4 = self.model.get(\"cache_q4\", True)\n        self.ngram_decoding = self.model.get(\"ngram_decoding\", False)\n        self.print_timings = self.model.get(\"print_timings\", False)\n        self.amnesia = self.model.get(\"amnesia\", False)\n        self.batch_size = self.model.get(\"batch_size\", 1)\n        self.cache_lens = self.model.get(\"cache_lens\", 4096)\n        self.device = self.model.get(\"device\", \"cuda:2\")\n\n        # web config\n        self.web: dict = cfg.get(\"web\", {})\n        self.web_cross_domain: bool = self.web.get(\"open_cross_domain\", True)\n        self.mount_web: bool = self.web.get(\"mount\", False)\n\n        # ext\n        self.ext: dict = cfg.get(\"ext\", {})\n        self.cpu_infer = psutil.cpu_count(logical=False) - 3\n\n        # file config\n        self.local_store_configs: dict = cfg.get(\"local_store\", {})\n        self.file_upload_dir: 
str = os.path.join(\n            self.localstore_path, self.local_store_configs.get(\"file_upload_dir\", \"\")\n        )\n        self.assistant_store_dir: str = os.path.join(\n            self.localstore_path, self.local_store_configs.get(\"assistant_store_dir\", \"\")\n        )\n\n        # long context config\n        self.long_context_config: dict = cfg.get(\"long_context\", {})\n        self.max_seq_len = self.long_context_config.get(\"max_seq_len\", 32000)\n        self.block_size = self.long_context_config.get(\"block_size\", 128)\n        self.local_windows_len = self.long_context_config.get(\"local_windows_len\", 4096)\n        self.second_select_num = self.long_context_config.get(\"second_select_num\", 32)\n        self.anchor_type = self.long_context_config.get(\"anchor_type\", \"DYNAMIC\")\n        self.kv_type = self.long_context_config.get(\"kv_type\", \"FP16\")\n        self.dense_layer_num = self.long_context_config.get(\"dense_layer_num\", 2)\n        self.anchor_num = self.long_context_config.get(\"anchor_num\", 1)\n        self.preselect_block = self.long_context_config.get(\"preselect_block\", True)\n        self.head_select_mode = self.long_context_config.get(\"head_select_mode\", \"SHARED\")\n        self.preselect_block_count = self.long_context_config.get(\"preselect_block_count\", 32)\n        self.layer_step = self.long_context_config.get(\"layer_step\", 1)\n        self.token_step = self.long_context_config.get(\"token_step\", 100)\n\n        # local chat\n        self.local_chat_config: dict = cfg.get(\"local_chat\", {})\n        self.prompt_file = self.local_chat_config.get(\"prompt_file\", None)\n\n        # asyncserver\n        self.sched_strategy = cfg[\"async_server\"][\"sched_strategy\"]\n        self.sched_port = cfg[\"async_server\"][\"sched_port\"]\n        self.sched_metrics_port = cfg[\"async_server\"][\"sched_metrics_port\"]\n        self.kvc2_metrics_port = cfg[\"async_server\"][\"kvc2_metrics_port\"]\n        
self.max_batch_size = cfg[\"async_server\"][\"max_batch_size\"]\n        self.page_size = cfg[\"attn\"][\"page_size\"]\n        self.chunk_size = cfg[\"attn\"][\"chunk_size\"]\n        self.memory_gpu_only = cfg[\"kvc2\"][\"gpu_only\"]\n        # round cache_lens up to the next multiple of page_size\n        self.cache_lens = ((self.cache_lens + self.page_size - 1) // self.page_size) * self.page_size\n        # KV cache size in bytes; the constants look like 2 bytes per element * head dim 576 * 61 layers\n        # (an assumption based on self_defined_head_dim = 576 in the scheduler settings, not confirmed here)\n        self.gpu_memory_size = 2*576*61*self.cache_lens\n        self.utilization_percentage = 1.0 #cfg[\"kvc2\"][\"utilization_percentage\"]\n        self.cpu_memory_size_GB = cfg[\"kvc2\"][\"cpu_memory_size_GB\"]\n        self.kvc2_disk_path = cfg[\"kvc2\"][\"disk_path\"]\n        # only 2 concurrent prefill tasks are supported\n        self.max_prefill_batch_size = 2\n        self.max_decode_batch_size = self.max_batch_size - self.max_prefill_batch_size\n\n"
  },
  {
    "path": "archive/ktransformers/server/config/log.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : unicornchan\nDate         : 2024-06-12 02:48:39\nVersion      : 1.0.0\nLastEditors  : chenxl \nLastEditTime : 2024-07-27 01:55:50\n'''\n\nimport codecs\nimport logging\nimport os\nimport re\nimport locale\nfrom pathlib import Path\nfrom logging.handlers import BaseRotatingHandler\nimport time\nimport colorlog\n\nfrom ktransformers.server.config.config import Config\n\n\nclass DailyRotatingFileHandler(BaseRotatingHandler):\n    \"\"\"\n    such as 'logging.TimeRotatingFileHandler', Additional features:\n     - support multiprocess\n     - support rotating daily\n    \"\"\"\n\n    def __init__(self, filename, backupCount=0, encoding=None, delay=False, utc=False, **kwargs): # pylint: disable=unused-argument\n        self.backup_count = backupCount\n        self.utc = utc\n        self.suffix = \"%Y-%m-%d\"\n        self.base_log_path = Path(filename)\n        if not os.path.exists(self.base_log_path.parent):\n            os.makedirs(self.base_log_path.parent)\n        self.base_filename = self.base_log_path.name\n        self.current_filename = self._compute_fn()\n        self.current_log_path = self.base_log_path.with_name(\n            self.current_filename)\n        BaseRotatingHandler.__init__(self, filename, 'a', encoding, delay)\n\n    # pylint: disable=unused-argument, invalid-name\n    def shouldRollover(self, record):\n        \"\"\"\n        Determine whether to rotate the log. 
If the log filename corresponding to the current \n        time is not consistent with the currently opened log filename, then it is necessary\n        to rotate the log\n        Args:\n            record: record is not used, as we are just comparing times, but it is needed so\n        the method signatures are the same\n        \"\"\"\n        if self.current_filename != self._compute_fn():\n            return True\n        return False\n\n    def doRollover(self):\n        \"\"\"\n        roll over\n        \"\"\"\n        # close last log file\n        if self.stream:\n            self.stream.close()\n            self.stream = None  # type: ignore\n\n        # gen new log file name\n        self.current_filename = self._compute_fn()\n        self.current_log_path = self.base_log_path.with_name(\n            self.current_filename)\n\n        if not self.delay:\n            self.stream = self._open() # type: ignore\n\n        self.delete_expired_files()\n\n    def _compute_fn(self):\n        \"\"\"\n        gen log file name\n        \"\"\"\n        return self.base_filename + \".\" + time.strftime(self.suffix, time.localtime())\n\n    def _open(self):\n        \"\"\"\n        open a new log file, create soft link\n        \"\"\"\n        if self.encoding is None:\n            stream = open(str(self.current_log_path), self.mode, encoding=locale.getpreferredencoding())\n        else:\n            stream = codecs.open(str(self.current_log_path), self.mode, self.encoding)\n\n        if self.base_log_path.exists():\n            try:\n                if not self.base_log_path.is_symlink() or os.readlink(self.base_log_path) != self.current_filename:\n                    os.remove(self.base_log_path)\n            except OSError:\n                pass\n\n        try:\n            os.symlink(self.current_filename, str(self.base_log_path))\n        except OSError:\n            pass\n        return stream\n\n    def delete_expired_files(self):\n        \"\"\"\n        delete 
expired files every day\n        \"\"\"\n        if self.backup_count <= 0:\n            return\n\n        file_names = os.listdir(str(self.base_log_path.parent))\n        result = []\n        prefix = self.base_filename + \".\"\n        plen = len(prefix)\n        for file_name in file_names:\n            if file_name[:plen] == prefix:\n                suffix = file_name[plen:]\n                if re.match(r\"^\\d{4}-\\d{2}-\\d{2}(\\.\\w+)?$\", suffix):\n                    result.append(file_name)\n        if len(result) < self.backup_count:\n            result = []\n        else:\n            result.sort()\n            result = result[:len(result) - self.backup_count]\n\n        for file_name in result:\n            os.remove(str(self.base_log_path.with_name(file_name)))\n\n\nclass Logger(object):\n    \"\"\"\n    logger class\n    \"\"\"\n    level_relations = {\n        'debug': logging.DEBUG,\n        'info': logging.INFO,\n        'warn': logging.WARNING,\n        'error': logging.ERROR,\n        'crit': logging.CRITICAL\n    }\n\n    def __init__(self, level: str = 'info'):\n        fmt = '%(asctime)s %(levelname)s %(pathname)s[%(lineno)d] %(funcName)s: %(message)s'\n        cfg: Config = Config()\n        filename: str = os.path.join(cfg.log_dir, cfg.log_file)\n        backup_count: int = cfg.backup_count\n        th = DailyRotatingFileHandler(filename=filename, when='MIDNIGHT', backupCount=backup_count, encoding=\"utf-8\")\n        th.setFormatter(logging.Formatter(fmt))\n\n\n        color_fmt = (\n            '%(log_color)s%(asctime)s %(levelname)s %(pathname)s[%(lineno)d]: %(message)s'\n        )\n        color_formatter = colorlog.ColoredFormatter(\n            color_fmt,\n            log_colors={\n                'DEBUG': 'cyan',\n                'INFO': 'green',\n                'WARNING': 'yellow',\n                'ERROR': 'red',\n                'CRITICAL': 'bold_red'\n            }\n        )\n\n        sh = logging.StreamHandler()\n        
sh.setFormatter(color_formatter)\n\n        self.logger = logging.getLogger(filename)\n        # fall back to INFO when the configured level name is not recognized\n        self.logger.setLevel(self.level_relations.get(level, logging.INFO)) # type: ignore\n        self.logger.addHandler(th)\n        self.logger.addHandler(sh)\n\n\nlogger = Logger(level=Config().log_level).logger\n"
  },
  {
    "path": "archive/ktransformers/server/config/singleton.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  : Implement singleton\nAuthor       : unicornchan\nDate         : 2024-06-11 17:08:36\nVersion      : 1.0.0\nLastEditors  : chenxl \nLastEditTime : 2024-07-27 01:55:56\n'''\nimport abc\n\nclass Singleton(abc.ABCMeta, type):\n    \"\"\"_summary_\n\n    Args:\n        abc.ABCMeta: Provide a mechanism for defining abstract methods and properties,\n            enforcing subclasses to implement these methods and properties.\n        type: Inherit from 'type' to make 'Singleton' a metaclass,\n            enabling the implementation of the Singleton\n    \"\"\"\n    _instances = {}\n\n    def __call__(cls, *args, **kwds):\n        if cls not in cls._instances:\n            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwds)\n        return cls._instances[cls]\n\nclass AbstractSingleton(abc.ABC, metaclass=Singleton):\n    \"\"\"Provided an abstract Singleton base class, any class inheriting from\n       this base class will automatically become a Singleton class.\n\n    Args:\n        abc.ABC: Abstract base class, it cannot be instantiated, only inherited. \n    \"\"\"\n"
  },
  {
    "path": "archive/ktransformers/server/crud/__init__.py",
    "content": ""
  },
  {
    "path": "archive/ktransformers/server/crud/assistants/__init__.py",
    "content": ""
  },
  {
    "path": "archive/ktransformers/server/crud/assistants/assistants.py",
    "content": "from time import time\nfrom typing import Optional,List\nfrom uuid import uuid4\n\nfrom ktransformers.server.models.assistants.assistants import Assistant\nfrom ktransformers.server.schemas.assistants.assistants import AssistantCreate,AssistantObject,AssistantModify\nfrom ktransformers.server.utils.sql_utils import SQLUtil\nfrom ktransformers.server.config.log import logger\nfrom ktransformers.server.schemas.base import Order\n\n\nclass AssistantDatabaseManager:\n    def __init__(self) -> None:\n        self.sql_util = SQLUtil()\n\n    def create_assistant_object(self, assistant: AssistantCreate) -> AssistantObject:\n        assistant = AssistantObject(\n            **assistant.model_dump(mode='json'),\n            id=str(uuid4()),\n            object='assistant',\n            created_at=int(time()),\n        )\n        return assistant\n\n    def db_count_assistants(self) -> int:\n        with self.sql_util.get_db() as db:\n            return db.query(Assistant).count()\n\n    def db_create_assistant(self, assistant: AssistantCreate):\n        ass_obj = self.create_assistant_object(assistant)\n        ass_obj.sync_db()\n        return ass_obj\n\n    def db_list_assistants(self, limit: Optional[int], order: Order) -> List[AssistantObject]:\n        with self.sql_util.get_db() as db:\n            query = db.query(Assistant).order_by(\n                order.to_sqlalchemy_order()(Assistant.created_at))\n            if limit is not None:\n                db_assistants = query.limit(limit)\n            else:\n                db_assistants = query.all()\n            return [AssistantObject.model_validate(a.__dict__) for a in db_assistants]\n\n    def db_get_assistant_by_id(self, assistant_id: str) -> Optional[AssistantObject]:\n        with self.sql_util.get_db() as db:\n            db_assistant = db.query(Assistant).filter(\n                Assistant.id == assistant_id).first()\n            if db_assistant is None:\n                logger.debug(f\"no 
assistant with id {assistant_id}\")\n                return None\n            return AssistantObject.model_validate(db_assistant.__dict__)\n\n    def db_update_assistant_by_id(self, assistant_id: str, assistant: AssistantModify):\n        with self.sql_util.get_db() as db:\n            db_assistant = db.query(Assistant).filter(\n                Assistant.id == assistant_id).first()\n            self.sql_util.db_update_commit_refresh(db, db_assistant, assistant)\n            return AssistantObject.model_validate(db_assistant.__dict__)\n\n    def db_delete_assistant_by_id(self, assistant_id: str):\n        with self.sql_util.get_db() as db:\n            db_assistant = db.query(Assistant).filter(\n                Assistant.id == assistant_id).first()\n            db.delete(db_assistant)\n            db.commit()\n\n"
  },
  {
    "path": "archive/ktransformers/server/crud/assistants/messages.py",
    "content": "from time import time\nfrom typing import Optional\nfrom uuid import uuid4\n\nfrom ktransformers.server.models.assistants.messages import Message\nfrom ktransformers.server.schemas.assistants.messages import MessageCore, MessageCreate,  MessageObject\nfrom ktransformers.server.schemas.base import Order,ObjectID\nfrom ktransformers.server.utils.sql_utils import SQLUtil\n\nclass MessageDatabaseManager:\n    def __init__(self) -> None:\n        self.sql_util = SQLUtil()\n\n    @staticmethod\n    def create_db_message_by_core(message: MessageCore):\n        message_dict = message.model_dump(mode=\"json\")\n        return Message(**message_dict, id=str(uuid4()), created_at=int(time()))\n\n    def create_db_message(self, message: MessageCreate):\n        return MessageDatabaseManager.create_db_message_by_core(message.to_core())\n\n    def db_add_message(self, message: Message):\n        with self.sql_util.get_db() as db:\n            db.add(message)\n            self.sql_util.db_add_commit_refresh(db, message)\n\n    def db_create_message(self, thread_id: str, message: MessageCreate, status: MessageObject.Status):\n        db_message = self.create_db_message(message)\n        db_message.status = status.value\n        db_message.thread_id = thread_id\n        self.db_add_message(db_message)\n        return MessageObject.model_validate(db_message.__dict__)\n\n    @staticmethod\n    def create_message_object(thread_id: ObjectID, run_id: ObjectID, message: MessageCreate):\n        core = message.to_core()\n        return MessageObject(\n            **core.model_dump(mode='json'),\n            id=str(uuid4()),\n            object='thread.message',\n            created_at=int(time()),\n            thread_id=thread_id,\n            run_id=run_id,\n            status=MessageObject.Status.in_progress,\n        )\n\n    def db_sync_message(self, message: MessageObject):\n        db_message = Message(\n            **message.model_dump(mode=\"json\"),\n        )\n    
    with self.sql_util.get_db() as db:\n            self.sql_util.db_merge_commit(db, db_message)\n\n    def db_list_messages_of_thread(\n            self, thread_id: str, limit: Optional[int] = None, order: Order = Order.DESC):\n\n        # logger.debug(\n        #     f\"list messages of: {thread_id}, limit {limit}, order {order}\")\n        with self.sql_util.get_db() as db:\n            query = (\n                db.query(Message)\n                .filter(Message.thread_id == thread_id)\n                .order_by(order.to_sqlalchemy_order()(Message.created_at))\n            )\n            if limit is not None:\n                messages = query.limit(limit)\n            else:\n                messages = query.all()\n            message_list = [MessageObject.model_validate(m.__dict__) for m in messages]\n        return message_list\n\n    def db_get_message_by_id(self, thread_id: ObjectID, message_id: ObjectID) -> MessageObject:\n        with self.sql_util.get_db() as db:\n            message = db.query(Message).filter(\n                Message.id == message_id).first()\n        assert message.thread_id == thread_id\n        message_info = MessageObject.model_validate(message.__dict__)\n        return message_info\n\n    def db_delete_message_by_id(self, thread_id: ObjectID, message_id: ObjectID):\n        with self.sql_util.get_db() as db:\n            message = db.query(Message).filter(\n                Message.id == message_id).first()\n            assert message.thread_id == thread_id\n            db.delete(message)\n            db.commit()\n"
  },
  {
    "path": "archive/ktransformers/server/crud/assistants/runs.py",
    "content": "from time import time\nfrom uuid import uuid4\n\nfrom ktransformers.server.models.assistants.runs import Run\nfrom ktransformers.server.schemas.assistants.runs import RunCreate,RunObject\nfrom ktransformers.server.schemas.base import ObjectID\nfrom ktransformers.server.utils.sql_utils import SQLUtil\n\n\nclass RunsDatabaseManager:\n    def __init__(self) -> None:\n        self.sql_util = SQLUtil()\n\n    def create_run_object(self, thread_id: ObjectID, run: RunCreate) -> RunObject:\n        run_obj = RunObject(\n            **run.model_dump(mode='json', exclude={\"stream\"}),\n            id=str(uuid4()),\n            object='run',\n            created_at=int(time()),\n            thread_id=thread_id,\n            status=RunObject.Status.queued,\n        )\n        run_obj.set_compute_save(0)\n        return run_obj\n\n    def db_create_run(self, thread_id: str, run: RunCreate):\n        db_run = Run(\n            **run.model_dump(mode=\"json\", exclude={\"stream\"}),\n            id=str(uuid4()),\n            created_at=int(time()),\n            status=\"queued\",\n            thread_id=thread_id,\n        )\n        with self.sql_util.get_db() as db:\n            self.sql_util.db_add_commit_refresh(db, db_run)\n            run_obj = RunObject.model_validate(db_run.__dict__)\n            run_obj.set_compute_save(0)\n        return run_obj\n\n    def db_sync_run(self, run: RunObject) -> None:\n        db_run = Run(\n            **run.model_dump(mode='json'),\n        )\n        with self.sql_util.get_db() as db:\n            self.sql_util.db_merge_commit(db, db_run)\n\n    def db_get_run(self, run_id: ObjectID) -> RunObject:\n        with self.sql_util.get_db() as db:\n            db_run = db.query(Run).filter(Run.id == run_id).first()\n            return RunObject.model_validate(db_run.__dict__)\n"
  },
  {
    "path": "archive/ktransformers/server/crud/assistants/threads.py",
    "content": "from time import time\nfrom typing import Optional,List\nfrom uuid import uuid4\n\nfrom ktransformers.server.models.assistants.messages import Message\nfrom ktransformers.server.models.assistants.threads import Thread\nfrom ktransformers.server.schemas.assistants.threads import ThreadCreate,ThreadObject\nfrom ktransformers.server.schemas.base import ObjectID, Order\nfrom ktransformers.server.schemas.conversation import ThreadPreview\nfrom ktransformers.server.utils.sql_utils import SQLUtil\nfrom ktransformers.server.crud.assistants.messages import MessageDatabaseManager\nfrom ktransformers.server.config.log import logger\nfrom ktransformers.server.crud.assistants.assistants import AssistantDatabaseManager\n\nclass ThreadsDatabaseManager:\n    def __init__(self) -> None:\n        self.sql_util = SQLUtil()\n        self.message_manager = MessageDatabaseManager()\n        self.assistant_maanager = AssistantDatabaseManager()\n\n    def db_create_thread(self, thread: ThreadCreate):\n        thread_id = str(uuid4())\n        db_messages = []\n        with self.sql_util.get_db() as db:\n            if thread.messages is not None:\n                logger.debug(\"Creating messages first for thread\")\n                for message in thread.messages:\n                    db_message: Message = MessageDatabaseManager.create_db_message_by_core(\n                        message)\n                    db_message.role = \"user\"\n                    db_message.thread_id = thread_id\n                    db.add(db_message)\n                    db_messages.append(db_message)\n\n            db_thread = Thread(\n                **thread.model_dump(exclude=\"messages\"),\n                id=str(uuid4()),\n                created_at=int(time()),\n                messages=db_messages,\n            )\n\n            self.sql_util.db_add_commit_refresh(db, db_thread)\n            thread_obj = ThreadObject.model_validate(db_thread.__dict__)\n\n            if 'assistant_id' in 
thread.meta_data:\n#                assistant = self.assistant_maanager.db_get_assistant_by_id(thread.meta_data['assistant_id'], db)\n                assistant = self.assistant_maanager.db_get_assistant_by_id(thread.meta_data['assistant_id'])\n                logger.info(\n                    f'Append this related thread to assistant {assistant.id}')\n                assistant.append_related_threads([thread_obj.id])\n                assistant.sync_db(db)\n        return thread_obj\n\n    def db_get_thread_by_id(self, thread_id: ObjectID):\n        with self.sql_util.get_db() as db:\n            db_thread = db.query(Thread).filter(Thread.id == thread_id).first()\n            return ThreadObject.model_validate(db_thread.__dict__)\n\n    def db_list_threads(self, limit: Optional[int], order: Order) -> List[ThreadObject]:\n        with self.sql_util.get_db() as db:\n            query = db.query(Thread).order_by(order.to_sqlalchemy_order()(\n                Thread.created_at)).filter(~Thread.meta_data.contains('assistant_id'))\n\n            if limit is not None:\n                db_threads = query.limit(limit)\n            else:\n                db_threads = query.all()\n\n            return [ThreadObject.model_validate(tool.__dict__) for tool in db_threads]\n\n    def db_list_threads_preview(self, limit: Optional[int], order: Order) -> List[ThreadPreview]:\n        threads = self.db_list_threads(limit, order)\n        previews = []\n        for thread in threads:\n            messages = self.message_manager.db_list_messages_of_thread(\n                thread.id, limit=2, order=Order.ASC)\n            if len(messages) == 2:\n                message = messages[0]\n                assistant = self.assistant_maanager.db_get_assistant_by_id(\n                    messages[1].assistant_id)\n            else:\n                message = None\n                assistant = None\n            previews.append(ThreadPreview(\n                assistant=assistant, thread=thread, 
first_message=message))\n        return previews\n\n    def db_delete_thread_by_id(self, thread_id: ObjectID):\n        with self.sql_util.get_db() as db:\n            db_thread = db.query(Thread).filter(Thread.id == thread_id).first()\n            db.delete(db_thread)\n            # TODO delete related messages and runs and other stuff or just gc\n            db.commit()\n"
  },
  {
    "path": "archive/ktransformers/server/exceptions.py",
    "content": "from fastapi import HTTPException, status\n\n\ndef db_exception():\n    return HTTPException(\n        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,\n        detail=\"DB Error\",\n    )\n\n\ndef not_implemented(what):\n    return HTTPException(\n        status_code=status.HTTP_501_NOT_IMPLEMENTED,\n        detail=f\"{what} not implemented\",\n    )\n\n\ndef internal_server_error(what):\n    return HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f\"{what}\")\n\n\ndef request_error(what):\n    return HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f\"{what}\")\n"
  },
  {
    "path": "archive/ktransformers/server/main.py",
    "content": "import asyncio\nimport os\nimport re\nfrom uuid import uuid4\n\nimport torch\nimport torch.distributed\nfrom fastapi import FastAPI\nfrom fastapi.staticfiles import StaticFiles\nimport uvicorn.logging\nimport uvicorn\nimport sys\nimport atexit\nproject_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))\nfrom fastapi.middleware.cors import CORSMiddleware\nfrom ktransformers.server.args import ArgumentParser\nfrom ktransformers.server.config.config import Config\nfrom ktransformers.util import utils\nfrom ktransformers.server.utils.create_interface import create_interface, GlobalInterface, get_thread_context_manager\nfrom fastapi.openapi.utils import get_openapi\nfrom fastapi import FastAPI\nfrom fastapi.middleware.cors import CORSMiddleware\nfrom ktransformers.server.api import router, post_db_creation_operations\nfrom ktransformers.server.utils.sql_utils import Base, SQLUtil\nfrom ktransformers.server.config.log import logger\nimport subprocess\nimport tempfile\n\ndef mount_app_routes(mount_app: FastAPI):\n    sql_util = SQLUtil()\n    logger.info(\"Creating SQL tables\")\n    Base.metadata.create_all(bind=sql_util.sqlalchemy_engine)\n    post_db_creation_operations()\n    mount_app.include_router(router)\n\n\ndef create_app():\n    cfg = Config()\n    if(hasattr(GlobalInterface.interface, \"lifespan\")):\n        app = FastAPI(lifespan=GlobalInterface.interface.lifespan)\n    else:\n        app = FastAPI()\n    if Config().web_cross_domain:\n        app.add_middleware(\n            CORSMiddleware,\n            allow_origins=[\"*\"],\n            allow_credentials=True,\n            allow_methods=[\"*\"],\n            allow_headers=[\"*\"],\n        )\n    mount_app_routes(app)\n    if cfg.mount_web:\n        mount_index_routes(app)\n    return app\n\n\ndef update_web_port(config_file: str):\n    ip_port_pattern = (\n        
r\"(localhost|((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)):[0-9]{1,5}\"\n    )\n    with open(config_file, \"r\", encoding=\"utf-8\") as f_cfg:\n        web_config = f_cfg.read()\n    ip_port = \"localhost:\" + str(Config().server_port)\n    new_web_config = re.sub(ip_port_pattern, ip_port, web_config)\n    with open(config_file, \"w\", encoding=\"utf-8\") as f_cfg:\n        f_cfg.write(new_web_config)\n\n\ndef mount_index_routes(app: FastAPI):\n    project_dir = os.path.dirname(os.path.dirname(__file__))\n    web_dir = os.path.join(project_dir, \"website/dist\")\n    web_config_file = os.path.join(web_dir, \"config.js\")\n    update_web_port(web_config_file)\n    if os.path.exists(web_dir):\n        app.mount(\"/web\", StaticFiles(directory=web_dir), name=\"static\")\n    else:\n        err_str = f\"No website resources in {web_dir}, please complile the website by npm first\"\n        logger.error(err_str)\n        print(err_str)\n        exit(1)\n\n\ndef run_api(app, host, port, **kwargs):\n    if kwargs.get(\"ssl_keyfile\") and kwargs.get(\"ssl_certfile\"):\n        uvicorn.run(\n            app,\n            host=host,\n            port=port,\n            ssl_keyfile=kwargs.get(\"ssl_keyfile\"),\n            ssl_certfile=kwargs.get(\"ssl_certfile\"),\n        )\n    else:\n        uvicorn.run(app, host=host, port=port, log_level=\"debug\")\n\n\ndef custom_openapi(app):\n    if app.openapi_schema:\n        return app.openapi_schema\n    openapi_schema = get_openapi(\n        title=\"ktransformers server\",\n        version=\"1.0.0\",\n        summary=\"This is a server that provides a RESTful API for ktransformers.\",\n        description=\"We provided chat completion and openai assistant interfaces.\",\n        routes=app.routes,\n    )\n    openapi_schema[\"info\"][\"x-logo\"] = {\"url\": \"https://kvcache.ai/media/icon_1.png\"}\n    app.openapi_schema = openapi_schema\n    return app.openapi_schema\n\n\ndef 
verify_arg(args):\n    nproc_per_node = int(os.getenv('LOCAL_WORLD_SIZE'))\n\n    if args.batch_size not in [1, 2, 3, 4]:\n        raise ValueError(f'argument batch_size should be in [1, 2, 3, 4], got {args.batch_size}')\n\n    if nproc_per_node not in [1, 2]:\n        raise ValueError(f'argument nproc_per_node should be in [1, 2], got {nproc_per_node}')\n\n    if args.tp not in [1, 2]:\n        raise ValueError(f'argument tp should be in [1, 2], got {args.tp}')\n\n    if nproc_per_node != args.tp:\n        raise ValueError(f'argument nproc_per_node should be equal to tp, got nproc_per_node is {nproc_per_node}, tp is {args.tp}')\n\n\ndef main():\n    try:\n        import torch_npu\n        use_npu = torch.npu.is_available()\n        torch.npu.config.allow_internal_format = True\n    except Exception:\n        # torch_npu is optional: any import or attribute failure means no NPU is usable.\n        use_npu = False\n\n    cfg = Config()\n\n    arg_parser = ArgumentParser(cfg)\n\n    args = arg_parser.parse_args()\n    # Default to rank 0 so rank_id is defined even when the NPU branch below is skipped.\n    rank_id = int(os.getenv(\"RANK\", '0'))\n    if use_npu:\n        verify_arg(args)\n\n        rank_id = int(os.environ[\"RANK\"])\n        args.device = args.device[:-1] + str(rank_id)\n    create_interface(config=cfg, default_args=cfg, input_args=args)\n\n    tp_size = args.tp\n    world_size = int(os.getenv(\"WORLD_SIZE\", '1'))\n    if tp_size == world_size and tp_size > 1:\n        if rank_id == 0:\n            app = create_app()\n            custom_openapi(app)\n            run_api(\n                app=app,\n                host=args.host,\n                port=args.port,\n                ssl_keyfile=args.ssl_keyfile,\n                ssl_certfile=args.ssl_certfile,\n            )\n        elif cfg.backend_type == 'ktransformers':\n            while True:\n                try:\n                    context = get_thread_context_manager()\n                    request_id = str(uuid4())\n                    context.interface.sync_inference(\"\", request_id, 1.0, 1.0)\n                except Exception as e:\n                    print(f\"An error occurred: {e}\")\n                finally:\n                    
pass\n    else:\n        app = create_app()\n        custom_openapi(app)\n\n        run_api(\n            app=app,\n            host=args.host,\n            port=args.port,\n            ssl_keyfile=args.ssl_keyfile,\n            ssl_certfile=args.ssl_certfile,\n        )\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "archive/ktransformers/server/models/__init__.py",
    "content": ""
  },
  {
    "path": "archive/ktransformers/server/models/assistants/__init__.py",
    "content": ""
  },
  {
    "path": "archive/ktransformers/server/models/assistants/assistants.py",
    "content": "from sqlalchemy import JSON, Column, Float, Integer, String, Text\nfrom sqlalchemy.orm import relationship\n\nfrom ktransformers.server.utils.sql_utils import Base\n\n\nclass Assistant(Base):\n    __tablename__ = \"assistants\"\n\n    id = Column(String, primary_key=True, index=True)\n    object = Column(String, default=\"assistant\")\n    created_at = Column(Integer)\n\n    name = Column(String, nullable=True)\n    description = Column(String, nullable=True)\n    model = Column(String)\n    instructions = Column(Text, nullable=True)\n    tools = Column(JSON)\n    tool_resources = Column(JSON)\n    temperature = Column(Float, nullable=True)\n    meta_data = Column(JSON, nullable=True)\n    top_p = Column(Float, nullable=True)\n    response_format = Column(JSON, default=\"auto\")\n\n    build_status = Column(JSON, nullable=True)\n\n    runs = relationship(\"Run\", back_populates=\"assistant\")\n\n    messages = relationship(\"Message\", back_populates=\"assistant\")\n"
  },
  {
    "path": "archive/ktransformers/server/models/assistants/messages.py",
    "content": "from sqlalchemy import JSON, Column, ForeignKey, Integer, String\nfrom sqlalchemy.orm import relationship\n\nfrom ktransformers.server.utils.sql_utils import Base\n\n\nclass Message(Base):\n    __tablename__ = \"messages\"\n\n    id = Column(String, primary_key=True, index=True)\n    object = Column(String, default=\"thread.message\")\n    created_at = Column(Integer)\n\n    thread_id = Column(String, ForeignKey(\"threads.id\"))\n    status = Column(String, default=\"in_progress\")\n    incomplete_details = Column(JSON, nullable=True)\n    completed_at = Column(Integer, nullable=True)\n    incomplete_at = Column(Integer, nullable=True)\n    role = Column(JSON)\n    content = Column(JSON)\n    assistant_id = Column(String, ForeignKey(\"assistants.id\"), nullable=True)\n    run_id = Column(String, ForeignKey(\"runs.id\"), nullable=True)\n    attachments = Column(JSON, nullable=True)\n    meta_data = Column(JSON, nullable=True)\n\n    thread = relationship(\"Thread\", back_populates=\"messages\")\n    assistant = relationship(\"Assistant\", back_populates=\"messages\")\n    run = relationship(\"Run\", back_populates=\"message\")\n"
  },
  {
    "path": "archive/ktransformers/server/models/assistants/run_steps.py",
    "content": "from sqlalchemy import JSON, Column, ForeignKey, Integer, String\nfrom sqlalchemy.orm import relationship\n\nfrom ktransformers.server.utils.sql_utils import Base\n\n\nclass RunStep(Base):\n    __tablename__ = \"run_steps\"\n    # todo\n    id = Column(String, primary_key=True, index=True)\n    object = Column(String, default=\"thread.run.step\")\n    created_at = Column(Integer)\n\n    assistant_id = Column(String, ForeignKey(\"assistants.id\"))\n    thread_id = Column(String, ForeignKey(\"threads.id\"))\n    run_id = Column(String, ForeignKey(\"runs.id\"))\n    type = Column(String)\n    status = Column(String)\n    step_details = Column(JSON)\n    last_error = Column(JSON, nullable=True)\n    expires_at = Column(Integer, nullable=True)\n    cancelled_at = Column(Integer, nullable=True)\n    failed_at = Column(Integer, nullable=True)\n    completed_at = Column(Integer, nullable=True)\n\n    meta_data = Column(JSON, nullable=True)\n    usage = Column(JSON, nullable=True)\n\n    assistant = relationship(\"Assistant\", back_populates=\"run_steps\")\n    thread = relationship(\"Thread\", back_populates=\"run_steps\")\n    run = relationship(\"Run\", back_populates=\"run_steps\")\n"
  },
  {
    "path": "archive/ktransformers/server/models/assistants/runs.py",
    "content": "from sqlalchemy import JSON, Column, Float, ForeignKey, Integer, String, Text\nfrom sqlalchemy.orm import relationship\n\nfrom ktransformers.server.utils.sql_utils import Base\n\n\nclass Run(Base):\n    __tablename__ = \"runs\"\n\n    id = Column(String, primary_key=True, index=True)\n    object = Column(String, default=\"thread.run\")\n    created_at = Column(Integer)\n    thread_id = Column(String, ForeignKey(\"threads.id\"))\n    assistant_id = Column(String, ForeignKey(\"assistants.id\"))\n    status = Column(String)\n    required_action = Column(JSON, nullable=True)\n    last_error = Column(JSON, nullable=True)\n    expires_at = Column(Integer, nullable=True)\n    started_at = Column(Integer, nullable=True)\n    cancelled_at = Column(Integer, nullable=True)\n    failed_at = Column(Integer, nullable=True)\n    completed_at = Column(Integer, nullable=True)\n    incomplete_details = Column(JSON, nullable=True)\n    # get from assistant\n    model = Column(String)\n    instructions = Column(Text, nullable=True)\n    tools = Column(JSON)\n    meta_data = Column(JSON, nullable=True)\n    usage = Column(JSON, nullable=True)\n    temperature = Column(Float, nullable=True)\n    top_p = Column(Float, nullable=True)\n    max_propmp_tokens = Column(Integer, nullable=True)\n    truncation_strategy = Column(JSON)\n    tool_choice = Column(JSON)\n    response_format = Column(JSON, default=\"auto\")\n\n    thread = relationship(\"Thread\", back_populates=\"runs\")\n    assistant = relationship(\"Assistant\", back_populates=\"runs\")\n    message = relationship(\"Message\", back_populates=\"run\")\n"
  },
  {
    "path": "archive/ktransformers/server/models/assistants/threads.py",
    "content": "from sqlalchemy import JSON, Column, Integer, String\nfrom sqlalchemy.orm import relationship\n\nfrom ktransformers.server.utils.sql_utils import Base\n\n\nclass Thread(Base):\n    __tablename__ = \"threads\"\n\n    id = Column(String, primary_key=True, index=True)\n    object = Column(String, default=\"thread\")\n    created_at = Column(Integer)\n\n    tool_resources = Column(JSON, nullable=True)\n    meta_data = Column(JSON, nullable=True)\n\n    runs = relationship(\"Run\", back_populates=\"thread\")\n    messages = relationship(\"Message\", back_populates=\"thread\")\n"
  },
  {
    "path": "archive/ktransformers/server/requirements.txt",
    "content": "torch >= 2.3.0\ntransformers >= 4.51.3\nfastapi >= 0.111.0\nlangchain >= 0.2.0\nblessed >= 1.20.0\naccelerate >= 0.31.0\nsentencepiece >= 0.1.97\nopenai\nsetuptools\nbuild\nninja\nwheel\ncolorlog\nfire\nzmq\npsutil"
  },
  {
    "path": "archive/ktransformers/server/schemas/__init__.py",
    "content": ""
  },
  {
    "path": "archive/ktransformers/server/schemas/assistants/__init__.py",
    "content": ""
  },
  {
    "path": "archive/ktransformers/server/schemas/assistants/assistants.py",
    "content": "from enum import Enum\nfrom time import time\nfrom typing import AsyncIterable, Callable, Dict, List, Optional, Union\nfrom asyncio import Lock, Queue\n\nfrom fastapi import logger\nfrom pydantic import BaseModel, Field, PrivateAttr, field_validator, model_validator\nimport torch\n\nfrom ktransformers.server.config.config import Config\nfrom ktransformers.server.models.assistants.assistants import Assistant\nfrom ktransformers.server.models.assistants.threads import Thread\nfrom ktransformers.server.schemas.assistants.messages import Role\nfrom ktransformers.server.schemas.assistants.runs import RunObject,RunStreamResponse,ObjectWithCreatedTime\nfrom ktransformers.server.schemas.assistants.threads import ThreadObject\nfrom ktransformers.server.schemas.base import Metadata,MetadataField,ObjectID\nfrom ktransformers.server.schemas.assistants.tool import Tool,CodeInterpreter,FileSearch,RelatedThreads,FuntionTool,ToolResource,CodeInterpreterResource,FileSearchResource,RelatedThreadsResource,ToolType\nfrom ktransformers.server.utils.sql_utils import SQLUtil\n\n\nclass AssistantBase(BaseModel):\n    name: Optional[str] = Field(None,description='The name of the assistant.') \n    description: Optional[str] = Field(None,description='The description of the assistant.')\n    instructions: Optional[str] = Field(None,description='Instructions which is added in front of the input of LLM') \n    tools: List[Tool] = Field([], max_length=128)\n\n    @field_validator('tools', mode='before')\n    def validate_tools(cls, value):\n        re = []\n        if not isinstance(value, list):\n            raise ValueError('Invalid type for tools')\n\n        for tool in value:\n            if 'type' not in tool:\n                raise ValueError('Invalid type for tools')\n            if tool['type'] == 'code_interpreter':\n                re.append(CodeInterpreter(**tool))\n            elif tool['type'] == 'file_search':\n                re.append(FileSearch(**tool))\n       
     elif tool['type'] == 'related_threads':\n                re.append(RelatedThreads(**tool))\n            elif tool['type'] == 'function':\n                re.append(FuntionTool(**tool))\n            else:\n                raise ValueError('Invalid type for tools')\n        return re\n\n    tool_resources: List[ToolResource] = Field([], max_length=128)\n\n    @field_validator('tool_resources', mode='before')\n    def validate_tool_resources(cls, value):\n        re = []\n        if not isinstance(value, list):\n            raise ValueError('Invalid type for tool resources')\n\n        for tool_re in value:\n            if 'file_ids' in tool_re:\n                re.append(CodeInterpreterResource(**tool_re))\n            elif 'vector_stores' in tool_re:\n                re.append(FileSearchResource(**tool_re))\n            elif 'thread_ids' in tool_re:\n                re.append(RelatedThreadsResource(**tool_re))\n            else:\n                raise ValueError('Invalid type for tool resources')\n        return re\n\n    meta_data: Metadata = MetadataField\n\n    @model_validator(mode='before')\n    def convert_meta_data(cls, values):\n        if 'meta_data' in values:\n            values['metadata'] = values['meta_data']\n        return values\n    temperature: Optional[float] = Field(ge=0.0, le=2.0, default=1)\n    top_p: Optional[float] = Field(ge=0.0, le=1.0, default=1)\n    response_format: Union[str, Dict[str, str]] = \"auto\"\n\n\nclass AssistantCreate(AssistantBase):\n    model: str\n\n\nclass AssistantBuildStatus(BaseModel):\n    class Status(Enum):\n        not_build = \"not_build\"\n        in_queue = \"in_queue\"\n        parsing = \"parsing\"\n        prefilling = \"prefilling\"\n        dumping = \"dumping\"\n        completed = \"completed\"\n        paused = \"paused\"\n\n    _lock: Lock = PrivateAttr(default_factory=Lock)\n    _queue: Optional[Queue] = PrivateAttr(None)\n\n    status: Status = Field(default=Status.not_build)\n    
total_file_count: int = Field(default=0)\n    parsed_file_count: int = Field(default=0)\n\n    prefilling_current: int = Field(default=0)\n    prefilling_total: int = Field(default=0)\n\n    build_started_time: Optional[int] = Field(default=None)\n    build_completed_time: Optional[int] = Field(default=None)\n\n    # in megabytes\n    assistant_usage: int = Field(default=0, description='')\n    assistant_total_usage: int = Field(default=0)\n    disk_free_space: int = Field(default=0)\n    disk_total_space: int = Field(default=0)\n\n    def to_stream_reply(self) -> str:\n        return f\"event: assistant.build.status\\ndata: {self.model_dump_json()}\\n\\n\"\n\n\nclass AssistantObject(AssistantBase, ObjectWithCreatedTime):\n    model: Optional[str] = Field(\n        default=Config().model_name)\n    related_threads_objects: Optional[List] = Field(None, exclude=True)\n    _encoded_instruction: Optional[torch.Tensor] = PrivateAttr(default=None)\n    build_status: AssistantBuildStatus = Field(default=AssistantBuildStatus())\n\n    def as_api_response(self):\n        return self.model_dump(exclude={'build_status'})\n\n    def get_related_threads_ids(self) -> List[ObjectID]:\n        re = []\n        for tool, tool_re in zip(self.tools, self.tool_resources):\n            if tool.type == ToolType.RELATED_THREADS:\n                re += tool_re.thread_ids or []\n        return re\n\n    def get_related_threads_objects(self) -> List:\n        # raise NotImplementedError  # should be replaced\n        sql_utils = SQLUtil()\n        if self.related_threads_objects is None:\n            with sql_utils.get_db() as db:\n                db_threads = db.query(Thread).all()\n            self.related_threads_objects = [tool for tool in [ThreadObject.model_validate(\n                tool.__dict__) for tool in db_threads] if tool.is_related_threads and tool.meta_data['assistant_id'] == self.id]\n            # logger.debug(\n            #     f'Found {len(self.related_threads_objects)} 
related threads')\n        return self.related_threads_objects\n\n    def append_related_threads(self, thread_ids: List[ObjectID]):\n        # logger.debug(f'{self.tools} {self.tool_resources}')\n        for tool, tool_re in zip(self.tools, self.tool_resources):\n            if tool.type == ToolType.RELATED_THREADS:\n                tool_re.thread_ids += thread_ids\n                return\n\n        self.tools.append(RelatedThreads(type=ToolType.RELATED_THREADS))\n        self.tool_resources.append(\n            RelatedThreadsResource(thread_ids=thread_ids))\n\n    async def update_build_status(self, events: AsyncIterable) -> AsyncIterable:\n        async for event in events:\n            # logger.debug(event)\n            if isinstance(event, RunStreamResponse):\n                if event.event == RunObject.Status.completed:\n                    self.build_status.status = AssistantBuildStatus.Status.completed\n                    self.build_status.build_completed_time = int(time())\n                    self.sync_db()\n                    yield self.build_status.model_copy()\n            elif isinstance(event, dict):\n                # logger.debug('dict')\n                if 'stage' in event:\n                    if event['stage'] == 'prefill':\n                        self.build_status.status = AssistantBuildStatus.Status.prefilling\n                        self.build_status.prefilling_current = event['curr_progress']\n                        self.build_status.prefilling_total = event['max_progress']\n                    if event['stage'] == 'parse':\n                        self.build_status.status = AssistantBuildStatus.Status.parsing\n                        self.build_status.parsed_file_count = event['curr_progress']\n                        self.build_status.total_file_count = event['max_progress']\n                    yield self.build_status.model_copy()\n\n    def get_build_status(self) -> AssistantBuildStatus:\n        return self.build_status\n     \n    
\n    def sync_db(self)->None:\n        # raise NotImplementedError # should be replaced\n        sql_utils = SQLUtil()\n        db_assistant = Assistant(\n            **self.model_dump(mode='json'),\n        )\n        with sql_utils.get_db() as db:\n            sql_utils.db_merge_commit(db, db_assistant)\n    \n    def get_encoded_instruction(self,encode_fn:Callable)->torch.Tensor:\n        if self._encoded_instruction is None:\n            logger.info(f'encoding assistant instruction: {self.instructions}')\n            self._encoded_instruction = encode_fn(self.instructions, Role.user)\n        return self._encoded_instruction\n\n\nclass AssistantModify(AssistantBase):\n    model: Optional[str] = None\n\n\n# Non API Backend\n"
  },
  {
    "path": "archive/ktransformers/server/schemas/assistants/messages.py",
    "content": "from enum import Enum\nfrom typing import ForwardRef, List, Optional, Union,Callable\n\nimport torch\nfrom pydantic import BaseModel, PrivateAttr, model_validator\n\nfrom ktransformers.server.exceptions import not_implemented\nfrom ktransformers.server.config.log import logger\nfrom ktransformers.server.models.assistants.messages import Message\nfrom ktransformers.server.schemas.base import Metadata, MetadataField, ObjectWithCreatedTime\nfrom ktransformers.server.schemas.assistants.tool import Field,CodeInterpreter,FileSearch\nfrom ktransformers.server.utils.sql_utils import SQLUtil\n\n\nclass IncompleteDetails(BaseModel):\n    reason: str\n\n\nclass ContentType(Enum):\n    image_file = \"image_file\"\n    image_url = \"image_url\"\n    text = \"text\"\n\n\nclass ContentObject(BaseModel):\n    type: ContentType\n\n\nclass ImageFile(BaseModel):\n    file_id: str\n    detail: str\n\n\nclass ImageFileObject(ContentObject):\n    image_file: ImageFile\n\n\nclass ImageUrl(BaseModel):\n    url: str\n    detail: str\n\n\nclass ImageUrlObject(ContentObject):\n    image_url: ImageUrl\n\n\nclass Annotation(BaseModel):\n    todo: str\n\n\nclass Text(BaseModel):\n    value: str\n    annotations: List[Annotation] = Field(default=[])\n\n\nclass TextObject(ContentObject):\n    text: Text\n    delta_index: int = Field(default=0,exclude=True)\n    special_tokens_on: bool = Field(default=False,exclude=True) \n    last_two: str= Field(default='',exclude=True)  \n\n    def filter_append(self,text:str):     \n        self.text.value+=text\n        self.delta_index+=1\n        return True  \n\n\n\nContent = Union[ImageFileObject, ImageUrlObject, TextObject]\n\n\nclass Attachment(BaseModel):\n    file_id: Optional[str] = Field(default=None)\n    tools: Optional[List[Union[CodeInterpreter, FileSearch]]] = Field(default=None)\n\n\nclass Role(Enum):\n    user = \"user\"\n    assistant = \"assistant\"\n\n    def is_user(self)->bool:\n        return self == Role.user\n\n\nclass 
MessageCore(BaseModel):\n    role: Role\n    content: List[Content]\n    attachments: Optional[List[Attachment]]\n    meta_data: Metadata = MetadataField\n    @model_validator(mode='before')\n    @classmethod\n    def convert_meta_data(cls,values):\n        if 'meta_data' in values:\n            values['metadata'] = values['meta_data']\n        return values\n\n\nclass MessageBase(MessageCore):\n    class Status(Enum):\n        created = \"created\" # only used for stream\n        in_progress = \"in_progress\"\n        incomplete = \"incomplete\"\n        completed = \"completed\"\n    thread_id: str\n    status: Status\n    incomplete_details: Optional[IncompleteDetails] = None\n    completed_at: Optional[int] = None\n    incomplete_at: Optional[int] = None\n\n    assistant_id: Optional[str] = None\n    run_id: Optional[str]\n\n\nMessageStreamResponse = ForwardRef('MessageStreamResponse')\n\nclass MessageObject(MessageBase, ObjectWithCreatedTime):\n    _encoded_content: Optional[torch.Tensor] = PrivateAttr(default=None)\n    \n\n    def get_text_content(self) -> str:\n        text_content = \"\"\n        for content in self.content:\n            if content.type == ContentType.text:\n                text_content += content.text.value\n            else:\n                raise not_implemented(\"Content other than text\")\n        return text_content\n\n    async def get_encoded_content(self,encode_fn:Callable):\n        if self._encoded_content is None:\n            logger.info(f'encoding {self.role.value} message({self.status.value}): {self.get_text_content()}')\n            self._encoded_content = encode_fn(self.get_text_content(),self.role)\n\n            for f in self.get_attached_files():\n                logger.info(f'encoding file: {f.filename}')\n                self._encoded_content = torch.cat([self._encoded_content, encode_fn(await f.get_str(),self.role)],dim=-1)\n                yield None \n\n        yield self._encoded_content\n\n\n    def 
get_attached_files(self):\n        raise NotImplementedError # should be replaced \n\n\n\n    def append_message_delta(self,text:str):\n        raise NotImplementedError # should be replaced \n    \n    def sync_db(self):\n        # raise NotImplementedError # should be replaced\n        sql_utils = SQLUtil()\n        db_message = Message(\n            **self.model_dump(mode=\"json\"),\n        )\n        with sql_utils.get_db() as db:\n            sql_utils.db_merge_commit(db, db_message)\n    \n\n    def stream_response_with_event(self, event: MessageBase.Status) -> MessageStreamResponse:\n        match event:\n            case MessageObject.Status.created:\n                self.status = MessageObject.Status.in_progress\n            case _:\n                self.status = event\n        return MessageStreamResponse(message=self, event=event)\n   \n\nclass MessageStreamResponse(BaseModel):\n    message: MessageObject\n    event: MessageObject.Status\n\n    def to_stream_reply(self):\n        return f\"event: thread.message.{self.event.value}\\ndata: {self.message.model_dump_json()}\\n\\n\"\n\n\nclass MessageCreate(BaseModel):\n    role: Role = Field(default=Role.user)\n    content: Union[str | List[Content]]\n    attachments: Optional[List[Attachment]] = None\n    meta_data: Metadata = MetadataField\n    @model_validator(mode='before')\n    @classmethod\n    def convert_meta_data(cls,values):\n        if 'meta_data' in values:\n            values['metadata'] = values['meta_data']\n        return values\n\n    def to_core(self) -> MessageCore:\n        # logger.debug(f\"Converting message create to core {self.model_dump()}\")\n        core = MessageCore(\n            role=self.role,\n            content=[],\n            attachments=self.attachments,\n            meta_data=self.meta_data,\n        )\n        if isinstance(self.content, str):\n            core.content = [TextObject(type=\"text\", text=Text(value=self.content, annotations=[]))]\n        elif 
isinstance(self.content, list):\n            core.content = self.content\n        else:\n            raise ValueError(\"Invalid content type\")\n        return core\n\n\nclass MessageModify(BaseModel):\n    meta_data: Metadata = MetadataField\n    @model_validator(mode='before')\n    @classmethod\n    def convert_meta_data(cls,values):\n        if 'meta_data' in values:\n            values['metadata'] = values['meta_data']\n        return values\n"
  },
  {
    "path": "archive/ktransformers/server/schemas/assistants/runs.py",
    "content": "from enum import Enum\nfrom typing import Dict, List, Optional, Union, ForwardRef\n\nfrom pydantic import BaseModel, Field, model_validator\n\nfrom ktransformers.server.models.assistants.runs import Run\nfrom ktransformers.server.schemas.base import TODO, Metadata, MetadataField, ObjectWithCreatedTime\nfrom ktransformers.server.schemas.assistants.threads import ThreadCreate\nfrom ktransformers.server.schemas.assistants.tool import Tool, ToolResource\nfrom ktransformers.server.utils.sql_utils import SQLUtil\n\n\nclass ToolCall(BaseModel):\n    id: str\n    type: str\n    function: TODO\n\n\nclass SubmitToolOutputs(BaseModel):\n    tool_calls: List[ToolCall]\n\n\nclass RequiredAction(BaseModel):\n    type: str\n    submit_tool_outputs: TODO\n\n\nclass LastError(BaseModel):\n    code: str\n    message: str\n\n\nclass IncompleteDetails(BaseModel):\n    reason: str\n\n\nclass Usage(BaseModel):\n    completion_tokens: int\n    prompt_tokens: int\n    total_tokens: int\n\n\nclass TruncationStrategy(BaseModel):\n    type: str = \"auto\"\n    last_message: Optional[int]\n\n\nclass ToolChoiceType(Enum):\n    none = \"none\"\n    auto = \"auto\"\n    required = \"required\"\n\n\nclass RunBase(BaseModel):\n    class Status(Enum):\n        created = \"created\" # only stream event will have this created status\n        queued = \"queued\"\n        in_progress = \"in_progress\"\n        requires_action = \"requires_action\"\n        cancelling = \"cancelling\"\n        cancelled = \"cancelled\"\n        failed = \"failed\"\n        completed = \"completed\"\n        expired = \"expired\"\n\n\n    thread_id: str\n    assistant_id: str\n    status: Status = Status.queued\n    required_action: Optional[RequiredAction] = Field(None)\n    last_error: Optional[LastError] = Field(None)\n    expires_at: Optional[int]= Field(None)\n    started_at: Optional[int] = Field(None)\n    cancelled_at: Optional[int] = Field(None)\n    failed_at: Optional[int] = Field(None)\n    
completed_at: Optional[int] = Field(None)\n    incomplete_details: Optional[IncompleteDetails] = Field(None)\n    model: Optional[str] = Field(None)\n    instructions: Optional[str] = Field(None)\n    tools: Optional[List[Tool]] = Field([])\n    meta_data: Metadata = MetadataField\n    @model_validator(mode='before')\n    @classmethod\n    def convert_meta_data(cls,values):\n        if 'meta_data' in values:\n            values['metadata'] = values['meta_data']\n        return values\n    \n    def set_compute_save(self,save:int):\n        self.meta_data['compute_save'] = str(save)\n\n\n    usage: Optional[Usage] = Field(None)\n    temperature: Optional[float] = Field(None)\n    top_p: Optional[float]= Field(None)\n    max_propmp_tokens: Optional[int]= Field(None)\n    truncation_strategy: Optional[TruncationStrategy]= Field(None)\n    tool_choice: Optional[Union[ToolChoiceType, dict]]= Field(None)\n    response_format: Union[str, Dict[str, str]] = \"auto\"\n\n\nRunStreamResponse = ForwardRef('RunStreamResponse')\n\nclass RunObject(RunBase, ObjectWithCreatedTime):\n    def stream_response_with_event(self,event:RunBase.Status)->RunStreamResponse:\n        match event:\n            case RunBase.Status.created:\n                self.status = RunBase.Status.queued\n            case _:\n                self.status = event\n        return RunStreamResponse(run=self, event=event)\n \n    \n    def sync_db(self):\n        # raise NotImplementedError # should be replaced in crud\n        sql_utils = SQLUtil()\n        db_run = Run(\n            **self.model_dump(mode='json'),\n        )\n        with sql_utils.get_db() as db:\n            sql_utils.db_merge_commit(db, db_run)\n    \n    def create_message_creation_step(self):\n        raise NotImplementedError # should be replaced \n        \n\nclass RunStreamResponse(BaseModel):\n    run: RunObject\n    event: RunObject.Status\n    def to_stream_reply(self):\n        return f\"event: thread.run.{self.event.value}\\ndata: 
{self.run.model_dump_json()}\\n\\n\"\n\nclass RunCreate(BaseModel):\n    assistant_id: str\n    model: Optional[str] = Field(default=None)\n    instructions: Optional[str] = Field(default=None)\n    # TODO: Add this\n    # additional_instructions: Optional[str]\n    # additional_messages: Optional[List[MessageCore]]\n    tools: List[Tool] = Field(default=[])\n    meta_data: Metadata = MetadataField\n    @model_validator(mode='before')\n    @classmethod\n    def convert_meta_data(cls,values):\n        if 'meta_data' in values:\n            values['metadata'] = values['meta_data']\n        return values\n    temperature: Optional[float] = Field(default=None)\n    top_p: Optional[float] = Field(default=None)\n    stream: Optional[bool] = Field(default=None)\n    max_propmp_tokens: Optional[int] = Field(default=None)\n    # TODO: Add this\n    # max_completion_tokens: Optional[int]\n    truncation_strategy: Optional[TruncationStrategy] = Field(default=None)\n    tool_choice: Optional[Union[ToolChoiceType, dict]] = Field(default=None)\n    response_format: Union[str, Dict[str, str]] = Field(default=\"auto\")\n\n\nclass RunThreadCreate(BaseModel):\n    assistant_id: str\n    thread: Optional[ThreadCreate]\n    model: Optional[str]\n    instructions: Optional[str]\n    tools: List[Tool]\n    tool_resources: List[ToolResource]\n    meta_data: Metadata = MetadataField\n    @model_validator(mode='before')\n    @classmethod\n    def convert_meta_data(cls,values):\n        if 'meta_data' in values:\n            values['metadata'] = values['meta_data']\n        return values\n    temperature: Optional[float]\n    top_p: Optional[float]\n    stream: Optional[bool]\n    max_propmp_tokens: Optional[int]\n    # TODO: Add this\n    # max_completion_tokens: Optional[int]\n    truncation_strategy: TruncationStrategy\n    tool_choice: Union[ToolChoiceType, dict]\n    response_format: Union[str, Dict[str, str]] = \"auto\"\n\n\nclass RunModify(BaseModel):\n    meta_data: Metadata = 
MetadataField\n    @model_validator(mode='before')\n    @classmethod\n    def convert_meta_data(cls,values):\n        if 'meta_data' in values:\n            values['metadata'] = values['meta_data']\n        return values\n\n\nclass ToolOutput(BaseModel):\n    tool_call_id: Optional[str]\n    output: Optional[str]\n\n\nclass RunSubmit(BaseModel):\n    tool_outputs: List[ToolOutput]\n    stream: Optional[bool]\n"
  },
  {
    "path": "archive/ktransformers/server/schemas/assistants/streaming.py",
    "content": "import asyncio\nfrom typing import AsyncIterable, List, Union\n\nfrom fastapi import Request\nfrom fastapi.responses import StreamingResponse\nfrom pydantic import BaseModel\n\nfrom ktransformers.server.schemas.assistants.runs import RunStreamResponse\nfrom ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk\nfrom ktransformers.server.config.log import logger\nfrom ktransformers.server.schemas.base import Object\nfrom ktransformers.server.schemas.assistants.messages import ContentType, ImageFileObject, ImageUrlObject, MessageObject, Text, TextObject\n\n\nclass TextObjectWithIndex(TextObject):\n    index: int\n\n\nclass ImageFileObjectWithIndex(ImageFileObject):\n    index: int\n\n\nclass ImageUrlObjectWithIndex(ImageUrlObject):\n    index: int\n\n\nContentWithIndex = Union[TextObjectWithIndex,\n                         ImageFileObjectWithIndex, ImageUrlObjectWithIndex]\n\n\nclass MessageDeltaImpl(BaseModel):\n    # role: Optional[str]\n    content: List[ContentWithIndex]\n\n\nclass MessageDelta(Object):\n    delta: MessageDeltaImpl\n\n    def to_stream_reply(self):\n        return f\"event: thread.message.delta\\ndata: {self.model_dump_json()}\\n\\n\"\n\n\ndef text_delta(index: int, text: str):\n    return MessageDeltaImpl(content=[TextObjectWithIndex(index=index, type=ContentType.text, text=Text(value=text))])\n\n\ndef append_message_delta(self: MessageObject, text: str):\n\n    if len(self.content) == 0:\n        self.content.append(TextObject(type=ContentType.text,\n                            text=Text(value=''), delta_index=0))\n\n    text_object: TextObject = self.content[0]\n    if text_object.filter_append(text):\n        return MessageDelta(id=self.id, object=\"thread.message.delta\", delta=text_delta(text_object.delta_index, text))\n    else:\n        return None\n\n\nMessageObject.append_message_delta = append_message_delta\n\n\nclass RunStepDeltaImpl(BaseModel):\n    pass\n\n\nclass RunStepDelta(Object):\n    delta: 
RunStepDeltaImpl\n\n    def to_stream_reply(self):\n        return f\"event: thread.run.step.delta\\ndata: {self.model_dump_json()}\\n\\n\"\n\n\nclass Done():\n    def to_stream_reply(self):\n        return f\"data: [DONE]\\n\\n\"\n\n\nasync def check_client_link(request: Request, async_events: AsyncIterable):\n    async for event in async_events:\n        if await request.is_disconnected():\n            break\n        yield event\n\n\nasync def add_done(async_events: AsyncIterable):\n    async for event in async_events:\n        yield event\n    yield Done()\n\n\nasync def to_stream_reply(async_events: AsyncIterable):\n    async for event in async_events:\n        if isinstance(event, str):\n            yield event\n        else:\n            yield event.to_stream_reply()\n\n\nasync def filter_api_event(async_events: AsyncIterable):\n    async for event in async_events:\n        if isinstance(event, MessageDelta) or isinstance(event, RunStepDelta) or isinstance(event, RunStreamResponse) or isinstance(event, Done):\n            yield event\n\n\nasync def filter_chat_chunk(async_events: AsyncIterable):\n    async for event in async_events:\n        if isinstance(event, ChatCompletionChunk):\n            yield event\n\n\nasync def filter_by_types(async_events: AsyncIterable, types: List):\n    async for event in async_events:\n        for type in types:\n            if isinstance(event, type):\n                yield event\n                continue\n\n\ndef api_stream_response(request: Request, async_events: AsyncIterable):\n    return StreamingResponse(check_client_link(request, to_stream_reply(add_done(filter_api_event(async_events)))), media_type=\"text/event-stream\")\n\n\ndef chat_stream_response(request: Request, async_events: AsyncIterable):\n    return StreamingResponse(check_client_link(request, to_stream_reply(add_done(filter_chat_chunk(async_events)))), media_type=\"text/event-stream\")\n\n\ndef stream_response(request: Request, async_events: 
AsyncIterable):\n    return StreamingResponse(check_client_link(request, to_stream_reply(add_done(async_events))), media_type=\"text/event-stream\")\n\n\ndef check_link_response(request: Request, async_events: AsyncIterable):\n    return StreamingResponse(check_client_link(request, async_events), media_type=\"text/event-stream\")\n\n\ndef wrap_async_generator_into_queue(async_events: AsyncIterable) -> asyncio.Queue:\n    queue = asyncio.Queue()\n\n    async def inner():\n        # logger.debug('run inner')\n        async for event in async_events:\n            # logger.debug(f'put: {event}')\n            await queue.put(event)\n            await asyncio.sleep(0)\n        # logger.debug(f'put: None')\n        await queue.put(None)\n    asyncio.create_task(inner())\n    return queue\n\n\nasync def unwrap_async_queue(queue: asyncio.Queue) -> AsyncIterable:\n    while True:\n        events = [await queue.get()]\n        events.extend([queue.get_nowait() for _ in range(queue.qsize())])\n\n        logger.debug(f'getting {len(events)} events')\n        for event in events:\n            if event is None:\n                # None is the end-of-stream sentinel; return so the outer loop stops as well.\n                return\n            yield event\n\n\nasync def unwrap_async_queue_slow(queue: asyncio.Queue) -> AsyncIterable:\n    while True:\n        event = await queue.get()\n        # logger.debug(f'unwrap_async_queue {event}')\n        if event is None:\n            break\n        yield event\n"
  },
  {
    "path": "archive/ktransformers/server/schemas/assistants/threads.py",
    "content": "from enum import Enum\nfrom typing import List\nfrom typing_extensions import Self \n\nfrom pydantic import BaseModel, Field, model_validator\n\nfrom ktransformers.server.schemas.base import Metadata, MetadataField, ObjectWithCreatedTime\nfrom ktransformers.server.schemas.assistants.tool import ToolResource\nfrom ktransformers.server.schemas.assistants.messages import MessageCore\n\n\nclass ThreadBase(BaseModel):\n    meta_data: Metadata = MetadataField\n    @model_validator(mode='before')\n    @classmethod\n    def convert_meta_data(cls,values):\n        if 'meta_data' in values:\n            values['metadata'] = values['meta_data']\n        return values\n\n    tool_resources: List[ToolResource] = Field([], max_length=128)\n\n\nclass ThreadObject(ThreadBase, ObjectWithCreatedTime):\n    is_related_threads:bool = Field(False,exclude=True)\n\n    @model_validator(mode='after')\n    def check_is_related_threads(self)->Self:\n        # logger.debug(f'check thread {self.id} is related thread? by {self}')\n        if 'assistant_id' in self.meta_data:\n            self.is_related_threads = True\n        return self\n\n    class StreamEvent(Enum):\n        created = 'created'\n\n    def to_stream_reply(self,event:StreamEvent):\n        return f\"event: thread.{event.value}\\ndata: {self.model_dump_json()}\\n\\n\"\n    \n\nclass ThreadCreate(ThreadBase):\n    messages: List[MessageCore] = Field(default=[])\n\n\nclass ThreadModify(ThreadBase):\n    pass\n\n\n# other than OpenAI API\n"
  },
  {
    "path": "archive/ktransformers/server/schemas/assistants/tool.py",
    "content": "from enum import Enum\nfrom typing import List, Optional, Union\n\nfrom pydantic import BaseModel, Field\n\nfrom ktransformers.server.schemas.base import ObjectID\n\n\nclass ToolType(str, Enum):\n    CODE_INTERPRETER = \"code_interpreter\"\n    FILE_SEARCH = \"file_search\"\n    RELATED_THREADS = \"related_threads\"\n    FUNCTION = \"function\"\n\n\nclass ToolBase(BaseModel):\n    type: ToolType\n\n\nclass CodeInterpreter(ToolBase):\n    pass\n\n\nclass FileSearch(ToolBase):\n    pass\n\n\nclass RelatedThreads(ToolBase):\n    pass\n\n\nclass FuntionTool(ToolBase):\n    description: str\n    name: str\n    parameters: List[str]\n\n\nTool = Union[CodeInterpreter, FileSearch, RelatedThreads, FuntionTool]\n\n\nclass CodeInterpreterResource(BaseModel):\n    file_ids: Optional[List[str]] = Field(default_factory=list, max_length=20)\n\n\nclass FileSearchResource(BaseModel):\n    vector_store_ids: Optional[List[str]] = Field(default_factory=list, max_length=1)\n    vector_stores: Optional[List[str]] = Field(default_factory=list, max_length=1)\n\n\nclass RelatedThreadsResource(BaseModel):\n    thread_ids: List[ObjectID] = Field(default=[])\n\n\nToolResource = Union[CodeInterpreterResource,FileSearchResource,RelatedThreadsResource] \n"
  },
  {
    "path": "archive/ktransformers/server/schemas/base.py",
    "content": "from enum import Enum\nfrom typing import Dict\n\nimport sqlalchemy\nfrom pydantic import BaseModel, ConfigDict, Field\n\nTODO = BaseModel\n\nObjectID = str\n\n\nclass Object(BaseModel):\n    id: ObjectID\n    object: str\n\n    model_config = ConfigDict(from_attributes=True)\n\n\n# Pydantic Base Models\nclass ObjectWithCreatedTime(Object):\n    created_at: int\n\n\n\nclass Order(str, Enum):\n    ASC = \"asc\"\n    DESC = \"desc\"\n\n    def to_sqlalchemy_order(self):\n        match self:\n            case Order.ASC:\n                return sqlalchemy.asc\n            case Order.DESC:\n                return sqlalchemy.desc\n\n\nMetadata = Dict[str, str]\nMetadataField: Metadata = Field({},max_length=16, alias=\"metadata\")\n\n\nclass DeleteResponse(Object):\n    deleted: bool = True\n\nclass OperationResponse(BaseModel):\n    operation: str\n    status: str\n"
  },
  {
    "path": "archive/ktransformers/server/schemas/conversation.py",
    "content": "from typing import Optional\n\nfrom pydantic import BaseModel\n\nfrom .assistants.assistants import AssistantObject\nfrom .assistants.threads import ThreadObject\nfrom .assistants.messages import MessageObject\n\nclass ThreadPreview(BaseModel):\n    assistant: Optional[AssistantObject] = None\n    thread: ThreadObject\n    first_message: Optional[MessageObject] = None\n"
  },
  {
    "path": "archive/ktransformers/server/schemas/endpoints/chat.py",
    "content": "from typing import List, Optional, Union, Dict, Any\nfrom typing_extensions import Literal\nfrom enum import Enum\nfrom pydantic import BaseModel, Field\nfrom ktransformers.server.config.config import Config\nfrom ktransformers.server.schemas.base import Object\n\n\nfrom openai.types.chat.chat_completion_chunk import Choice\n\nfrom uuid import uuid4\n\nclass CompletionUsage(BaseModel):\n    prompt_tokens: int\n    completion_tokens: int\n    total_tokens: int\n    prompt_tokens_details: Optional[Dict[str, Any]] = None\n    completion_tokens_details: Optional[Dict[str, Any]] = None\n    prefill_time: Optional[float] = None\n    decode_time: Optional[float] = None\n\nclass Role(Enum):\n    system = 'system'\n    user = 'user'\n    assistant = 'assistant'\n    tool = 'tool'\n    function = 'function'\n\nclass Message(BaseModel):\n    content: Optional[str] = None\n    role: Role\n    name: Optional[str] = None\n    tool_calls: Optional[List[Dict[str, Any]]] = {}\n    tool_call_id: Optional[str] = None\n    \n    def to_tokenizer_message(self):\n        message = {'role': self.role.value}\n        if self.content is not None:\n            message['content'] = self.content\n        if self.name is not None:\n            message['name'] = self.name\n        if self.tool_calls is not {}:\n            message['tool_calls'] = self.tool_calls\n        if self.tool_call_id is not None:\n            message['tool_call_id'] = self.tool_call_id\n        return message\n\nclass FunctionParameters(BaseModel):\n    type: str = \"object\"\n    properties: Dict[str, Any] = {}\n    required: Optional[List[str]] = None\n\nclass FunctionDefinition(BaseModel):\n    name: str\n    description: Optional[str] = None\n    parameters: FunctionParameters = Field(default_factory=FunctionParameters)\n\nclass ToolFunction(BaseModel):\n    function: FunctionDefinition\n    \nclass Tool(BaseModel):\n    type: Literal[\"function\"]\n    function: FunctionDefinition\n\nclass 
ChatCompletionCreate(BaseModel):\n    messages: List[Message]\n    model: str\n    stream: bool = False\n    temperature: Optional[float] = Field(default=Config().temperature)\n    top_p: Optional[float] = Field(default=Config().top_p)\n    tools: Optional[List[Tool]] = None\n    tool_choice: Optional[Union[str, Dict[str, Any]]] = None\n    stream_options: Optional[Dict[str, Any]] = None\n    frequency_penalty: float = 0\n    presence_penalty: float = 0\n    max_tokens: Optional[int] = Field(default=None)\n    max_completion_tokens: Optional[int] = Field(default=None)\n    return_speed: Optional[bool] = Field(default=False)\n    def get_tokenizer_messages(self):\n        return [m.to_tokenizer_message() for m in self.messages]\n\nclass ChatCompletionChunk(BaseModel):\n    id: str\n    choices: List[Choice]\n    created: int\n    model: str\n    object: Literal[\"chat.completion.chunk\"]\n    service_tier: Optional[Literal[\"scale\", \"default\"]] = None\n    system_fingerprint: Optional[str] = None\n    usage: Optional[CompletionUsage] = None\n\n    def to_stream_reply(self):\n        return f\"data: {self.model_dump_json()}\\n\\n\"\n\nclass RawUsage(BaseModel):\n    tokenize_time: float\n    prefill_time: float\n    decode_time: float\n    prefill_count: int\n    decode_count: int"
  },
  {
    "path": "archive/ktransformers/server/schemas/legacy/__init__.py",
    "content": ""
  },
  {
    "path": "archive/ktransformers/server/schemas/legacy/completions.py",
    "content": "from typing import List, Optional\nfrom enum import Enum\nfrom pydantic import BaseModel, Field\nfrom ktransformers.server.config.config import Config\nfrom ..base import Object\n\nclass CompletionCreate(BaseModel):\n    model: str\n    prompt: str | List[str]\n    stream: bool = False\n    temperature: Optional[float] = Field(default=Config().temperature)\n    top_p: Optional[float] = Field(default=Config().top_p)\n    max_tokens: Optional[int] = Field(default=None)\n    max_completion_tokens: Optional[int] = Field(default=None)\n    \n    def get_tokenizer_messages(self):\n        if isinstance(self.prompt,List):\n            self.get_tokenizer_messages('\\n'.join(self.prompt))\n        return [{'content':self.prompt,'role':'user'}]\n\n\nclass FinishReason(Enum):\n    stop = 'stop'\n    length = 'length'\n\nclass Choice(BaseModel):\n    index: int\n    text: str\n    logprobs: Optional[str] = None\n    finish_reason: FinishReason = None\n\n\nclass CompletionObject(Object):\n    created:int\n    choices: List[Choice] = []\n    model:str = 'not implmented'\n    system_fingerprint:str = 'not implmented'\n    usage: Optional[str] = None\n\n    def set_token(self,token:str):\n        if len(self.choices)==0:\n            self.choices.append(Choice(index=0,text=''))\n        self.choices[0].text = token    \n\n    def append_token(self,token:str):\n        if len(self.choices)==0:\n            self.choices.append(Choice(index=0,text=''))\n        self.choices[0].text += token\n\n    def to_stream_reply(self):\n        return f\"data:{self.model_dump_json()}\\n\\n\"\n"
  },
  {
    "path": "archive/ktransformers/server/utils/__init__.py",
    "content": ""
  },
  {
    "path": "archive/ktransformers/server/utils/create_interface.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : qiyuxinlin\nDate         : 2024-07-25 11:50:16\nVersion      : 1.0.0\nLastEditors  : qiyuxinlin \nLastEditTime : 2024-07-25 12:54:48\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nfrom ktransformers.server.config.config import Config\nfrom ktransformers.server.backend.args import ConfigArgs\nfrom ktransformers.server.backend.context_manager import ThreadContextManager\nfrom ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface\nfrom ktransformers.server.backend.interfaces.transformers import TransformersInterface\nfrom ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface\n\ndef create_interface(config: Config, default_args: ConfigArgs, input_args=None):\n    if config.backend_type=='transformers':\n        from ktransformers.server.backend.interfaces.transformers import  TransformersInterface as BackendInterface\n    elif config.backend_type == 'exllamav2':\n        from ktransformers.server.backend.interfaces.exllamav2 import  ExllamaInterface as BackendInterface\n    elif config.backend_type == 'ktransformers':\n        from ktransformers.server.backend.interfaces.ktransformers import  KTransformersInterface as BackendInterface\n    elif config.backend_type == 'balance_serve':\n        from ktransformers.server.backend.interfaces.balance_serve import BalanceServeInterface as BackendInterface\n    else:\n        raise NotImplementedError(f'{config.backend_type} not implemented')\n    if config.backend_type == 'ktransformers':\n        GlobalInterface.interface = BackendInterface(default_args, input_args)\n    elif config.backend_type == 'balance_serve':\n        GlobalInterface.interface = BackendInterface(default_args, input_args)\n    else:\n        GlobalInterface.interface = BackendInterface(default_args)\n    GlobalContextManager.context_manager = 
ThreadContextManager(GlobalInterface.interface)\n\nclass GlobalContextManager:\n    context_manager: ThreadContextManager\nclass GlobalInterface:\n    interface:  TransformersInterface | KTransformersInterface | ExllamaInterface \n    \ndef get_thread_context_manager() -> ThreadContextManager:\n    return GlobalContextManager.context_manager\ndef get_interface() -> TransformersInterface | KTransformersInterface | ExllamaInterface:\n    return GlobalInterface.interface"
  },
  {
    "path": "archive/ktransformers/server/utils/multi_timer.py",
    "content": "import time\n\n\ndef format_time(seconds):\n    units = [\n        (\"hours\", 3600),\n        (\"minutes\", 60),\n        (\"seconds\", 1),\n        (\"milliseconds\", 1e-3),\n        (\"microseconds\", 1e-6),\n    ]\n\n    for unit_name, unit_value in units:\n        if seconds >= unit_value:\n            time_value = seconds / unit_value\n            return f\"{time_value:.2f} {unit_name}\"\n    return \"0 seconds\"  # Handle case for 0 seconds\n\n\nclass Profiler:\n    def __init__(self):\n        self.timers = {}\n        self.counters = {}\n\n    def create_timer(self, name):\n        self.timers[name] = {\n            \"start_time\": None,\n            \"elapsed_time\": 0,\n            \"running\": False,\n        }\n\n    def start_timer(self, name):\n        if name not in self.timers:\n            raise ValueError(f\"Timer '{name}' does not exist.\")\n        if self.timers[name][\"running\"]:\n            raise ValueError(f\"Timer '{name}' is already running.\")\n        self.timers[name][\"start_time\"] = time.time()\n        self.timers[name][\"running\"] = True\n\n    def pause_timer(self, name):\n        if name not in self.timers:\n            raise ValueError(f\"Timer '{name}' does not exist.\")\n        if not self.timers[name][\"running\"]:\n            raise ValueError(f\"Timer '{name}' is not running.\")\n        self.timers[name][\"elapsed_time\"] += time.time() - self.timers[name][\"start_time\"]\n        self.timers[name][\"running\"] = False\n\n    def get_timer_sec(self, name):\n        if name not in self.timers:\n            raise ValueError(f\"Timer '{name}' does not exist.\")\n        if self.timers[name][\"running\"]:\n            current_time = self.timers[name][\"elapsed_time\"] + (time.time() - self.timers[name][\"start_time\"])\n        else:\n            current_time = self.timers[name][\"elapsed_time\"]\n        return current_time\n\n    def get_all_timers(self):\n        all_timers = {}\n        for name in 
self.timers:\n            all_timers[name] = self.get_timer_sec(name)\n        return all_timers\n\n    def report_timer_string(self, name):\n        return f\"{name} elapsed time: {format_time(self.get_timer_sec(name))}\"\n\n    def create_and_start_timer(self, name):\n        self.create_timer(name)\n        self.start_timer(name)\n\n\n    # Counter\n    def inc(self,key:str,delta:int=1):\n        self.counters[key] = self.counters.get(key,0) + delta\n\n    def set_counter(self,key:str,to=0):\n        self.counters[key] = to\n\n    def get_counter(self,key:str):\n        return self.counters.get(key,0)\n"
  },
  {
    "path": "archive/ktransformers/server/utils/serve_profiling.py",
    "content": "import re\nimport itertools\nimport time\nimport enum\nimport math\nfrom enum import StrEnum\n\nclass ProfStatKey(StrEnum):\n    ExpertsSummitCurrLayer = \"ExpertsSummitCurrLayer\"\n    ExpertsSummitNextLayer = \"ExpertsSummitNextLayer\"\n    ExpertsCPUForwardOne = \"ExpertsCPUForwardOne\"\n    ExpertsCPUForwardTwo = \"ExpertsCPUForwardTwo\"\n    CPUMoEKExpertsCallback = \"CPUMoEKExpertsCallback\"\n\nclass ProfTimeStat:\n    def __init__(self):\n        # open_status = os.environ[\"KT_PERF_STAT\"] if \"KT_PERF_STAT\" in os.environ else \"0\"\n        # if open_status == \"0\":\n        #     self.on = False\n        # else:\n        #     self.on = True\n        self.on = False\n        self.prefill_stats = dict()\n        self.decode_stats = dict()\n        for key in ProfStatKey:\n            self.prefill_stats[key] = ProfStatItem()\n            self.decode_stats[key] = ProfStatItem()\n        self.reset_all()\n\n    def record_start_time(self):\n        start_time = time.time_ns()\n        return start_time\n\n    def add_time_stat(self, key: ProfStatKey, time_ns, is_prefill):\n        if not key:\n            return\n        # torch.cuda.synchronize()\n        cost = time.time_ns() - time_ns\n        if is_prefill:\n            item = self.prefill_stats[key]\n        else:\n            item = self.decode_stats[key]\n        item.add_item(cost)\n\n    def print_all(self):\n        # rank = f\"[rank:{torch.distributed.get_rank()}]\"\n        rank = f\"[rank:0]\"\n        msg = f\"\\n{rank} Prefill Time Stat\\n\"\n        msg += rank + \" {:27}{:>15}{:>15}{:>15}{:>15}{:>15}{:>15}{:>15}\\n\".format(\"\", \"min(ms)\", \"max(ms)\", \"avg(ms)\", \"count\", \"total(ms)\", \">2ms\", \">10ms\")\n        for key, value in self.prefill_stats.items():\n            msg += rank + f\" {key.value:<25}:{value.get_stat()}\\n\"\n        msg += f\"\\n{rank} Decode Time Stat\\n\"\n        msg += rank + \" 
{:27}{:>15}{:>15}{:>15}{:>15}{:>15}{:>15}{:>15}\\n\".format(\"\", \"min(ms)\", \"max(ms)\", \"avg(ms)\", \"count\", \"total(ms)\", \">2ms\", \">10ms\")\n        for key, value in self.decode_stats.items():\n            msg += rank + f\" {key.value:<25}:{value.get_stat()}\\n\"\n        print(msg)\n\n    def reset_all(self):\n        for _, value in self.prefill_stats.items():\n            value.reset()\n        for _, value in self.decode_stats.items():\n            value.reset()\n\n\nclass ProfStatItem:\n    def __init__(self):\n        self.min_time = 100000000\n        self.max_time = 0\n        self.total_time_ns = 0\n        self.count = 0\n        self.err_time = []\n        self.ms_count2 = 0\n        self.ms_count10 = 0\n\n    def add_item(self, cost_time_ns):\n        self.count += 1\n        self.total_time_ns += cost_time_ns\n        self.min_time = min(
self.min_time, cost_time_ns)\n        self.max_time = max(self.max_time, cost_time_ns)\n        if (cost_time_ns > 2000000):\n        #   self.err_time.append(round(cost_time_ns / 1000 / 1000, 2))\n          self.ms_count2 += 1\n        if (cost_time_ns > 10000000):\n        #   self.err_time.append(round(cost_time_ns / 1000 / 1000, 2))\n          self.ms_count10 += 1\n        # self.err_time.append(round(cost_time_ns / 1000 / 1000, 2))\n\n    def reset(self):\n        self.min_time = 100000000\n        self.max_time = 0\n        self.total_time_ns = 0\n        self.count = 0\n        # also clear the threshold counters so a reset starts from a clean state\n        self.err_time = []\n        self.ms_count2 = 0\n        self.ms_count10 = 0\n\n    def get_stat(self):\n        max_time = self.max_time / 1000 / 1000\n        if self.count != 0:\n            min_time = self.min_time / 1000 / 1000\n            avg_time = self.total_time_ns / self.count / 1000 / 1000\n        else:\n            # nothing recorded: report 0 rather than the min_time sentinel\n            min_time = 0\n            avg_time = 0\n        total = self.total_time_ns / 1000 / 1000\n        # tmpstr = str(self.err_time)\n        # print(f\"\\r\\n err_time: {tmpstr} \\r\\n \")\n        return f\"{min_time:15.2f}{max_time:15.2f}{avg_time:15.2f}{self.count:15}{total:15.2f}{self.ms_count2:>15}{self.ms_count10:>15}\"\n\n\nPROF_TIME_STAT = ProfTimeStat()\n\n"
  },
  {
    "path": "archive/ktransformers/server/utils/sql_utils.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : chenxl\nDate         : 2024-06-12 09:12:58\nVersion      : 1.0.0\nLastEditors  : chenxl \nLastEditTime : 2024-07-27 01:56:04\n'''\n\nfrom urllib.parse import urlparse\nimport os\nfrom contextlib import contextmanager\nfrom sqlalchemy import create_engine\nfrom sqlalchemy.orm import Session, sessionmaker, declarative_base\n\nfrom ktransformers.server.config.config import Config\nfrom ktransformers.server.config.singleton import Singleton\nfrom ktransformers.server.config.log import logger\nfrom ktransformers.server.exceptions import db_exception\n\n\nBase = declarative_base()\n\n\nclass SQLUtil(metaclass=Singleton):\n    \"\"\"\n    database connections init and management\n    \"\"\"\n    sqlalchemy_engine = None\n    session_local = None\n\n    def __init__(self) -> None:\n        self.cfg: Config = Config()\n        if not self.sqlalchemy_engine:\n            SQLUtil.init_engine(self.cfg)\n\n    @contextmanager\n    def get_db(self):\n        \"\"\"\n        After you finish using the session, it's crucial to close it.\n        \"\"\"\n        if not SQLUtil.sqlalchemy_engine:\n            SQLUtil.init_engine(self.cfg)\n        session = self.session_local()  # type: ignore pylint: disable=not-callable\n        try:\n            yield session\n        finally:\n            session.close()\n\n    @staticmethod\n    def init_engine(cfg: Config):\n        \"\"\"\n        initial engine and session maker Factory\n        \"\"\"\n        pool_size = cfg.db_pool_size\n        if SQLUtil.sqlalchemy_engine is None:\n            if cfg.db_type == \"sqllite\":\n                db_url = SQLUtil.create_sqllite_url(cfg)\n            else:\n                logger.error(\"Unsupported database type %s\", cfg.db_type)\n                exit(-1)\n            SQLUtil.sqlalchemy_engine = create_engine(\n                db_url, connect_args={\"check_same_thread\": False}, 
pool_size=pool_size)\n            SQLUtil.session_local = sessionmaker(\n                autocommit=False, autoflush=False, bind=SQLUtil.sqlalchemy_engine)\n\n    @staticmethod\n    def create_sqllite_url(cfg):\n        \"\"\"\n        create and validate SQLLite url\n        \"\"\"\n        path: str = cfg.db_host\n        database: str = cfg.db_database\n        absolute_path: str = os.path.join(path, database)\n        url = 'sqlite:///' + absolute_path\n        try:\n            result = urlparse(url)\n            if all([result.scheme, result.path, result.scheme == 'sqlite']):\n                return url\n            else:\n                logger.error(\"invalid sqllite url: %s\", url)\n                exit(-1)\n        except ValueError:\n            logger.error(\"invalid sqllite url: %s\", url)\n            exit(-1)\n\n    def db_add_commit_refresh(self, session: Session, what):\n        \"\"\"\n        add data to database\n        \"\"\"\n        try:\n            session.add(what)\n            session.commit()\n            session.refresh(what)\n        except Exception as e:\n            logger.exception(\"db commit error with data %s\", str(what.__dict__))\n            ex = db_exception()\n            ex.detail = str(e)\n            session.rollback()\n            raise ex from e\n\n    def db_merge_commit(self, session: Session, what):\n        try:\n            session.merge(what)\n            session.commit()\n        except Exception as e:\n            ex = db_exception()\n            ex.detail = str(e)\n            logger.exception(\"db merge commit error with data %s\", str(what.__dict__))\n            session.rollback()\n            raise ex from e\n\n    def db_update_commit_refresh(self, session: Session, existing, what):\n        what = what.model_dump(mode=\"json\")\n        try:\n            for key in what.keys():\n                if what[key] is not None:  # skip fields that are None in the update\n                    setattr(existing, key, what[key])  # update the field on the existing row\n            session.commit()\n            session.refresh(existing)\n        except Exception as e:\n            ex = db_exception()\n            ex.detail = str(e)\n            # `what` is a plain dict after model_dump(), so it has no __dict__; log it directly\n            logger.exception(
\"db update commit refresh error with data %s\", str(what))\n            session.rollback()\n            raise ex from e\n"
  },
  {
    "path": "archive/ktransformers/tests/.gitignore",
    "content": "results/"
  },
  {
    "path": "archive/ktransformers/tests/AIME_2024/eval_api.py",
    "content": "# adapt from https://github.com/abacaj/code-eval?tab=readme-ov-file\nimport argparse\nimport json\nimport os\nimport time\nimport requests\nimport tqdm\n\nfrom evaluation import filter_answer\nfrom prompts import instruct_prompt\nimport pandas as pd\nfrom datasets import load_dataset\nos.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\n\n\ndef generate_text(api_url, question, model_name, stream=False, auth_token=None):\n    headers = {\n        'accept': 'application/json',\n        'Content-Type': 'application/json',\n        # add the API key\n        'Authorization': 'Bearer ' + auth_token if auth_token else ''\n    }\n    question = instruct_prompt(question)\n    data = {\n        \"messages\": [{\"content\": question, \"role\": \"user\"}],\n        \"model\": model_name,\n        \"stream\": stream,\n        \"temperature\": 0.6,\n        \"max_tokens\": 10240,\n    }\n    print(f\"content: {question}\")\n    response = requests.post(
api_url, headers=headers, json=data, verify=False)\n    if response.status_code == 200:\n        result = response.json()\n        results = result.get('choices', [{}])[0].get('message', {}).get('content', '')\n        return filter_answer(results)\n    else:\n        print(f\"API Request failed with status code {response.status_code}\")\n        return None\n\n\ndef load_data(file_path):\n        \"\"\"\n        Load the 'train' split of a dataset into a list of dicts.\n        Each row of the split becomes one record.\n        \"\"\"\n        # read the Parquet file\n        # dataset = load_dataset('parquet', data_files=file_path)\n        data = []\n        ds = load_dataset(file_path)\n        df = pd.DataFrame(ds['train'])\n        for _, row in df.iterrows():\n            data.append(row.to_dict())\n        return data\n\ndef get_score(pred, answer):\n        \"\"\"\n        Score a prediction against the reference answer by exact match.\n        :param pred: The predicted string.\n        :param answer: The reference answer (string or number).\n        :return: 1 if the two match as strings or as numbers, otherwise 0.\n        \"\"\"\n        if pred == answer:\n            return 1\n        # if we need to compare a str with a number, convert both to float\n        try:\n            pred = float(pred)\n            answer = float(answer)\n        except (TypeError, ValueError):\n            pass\n        if pred == answer:\n            return 1\n        return 0\n\ndef run_eval_api(\n    api_url: str,\n    model_name: str,\n    out_path: str,\n    format_tabs: bool = False,\n    auth_token: str = None,\n    problem_file: str = None,\n    append: bool = False,\n    skip: int = 0\n):\n\n    data = load_data(problem_file)\n    pbar = tqdm.tqdm(total=len(data) * 1)\n    pbar.update(skip)\n    # When not appending, truncate the output once up front; every result is then\n    # appended, so later results do not overwrite earlier ones.\n    if not append:\n        open(out_path, \"w\").close()\n    # Start at `skip` so the index never runs past the end of the data\n    for i in range(skip, len(data)):\n        data_item = data[i]\n        question = data_item['Problem']\n        # Start the timer for this evaluation\n        start_time = time.time()\n        try:\n            completion = generate_text(api_url, question, model_name, auth_token=auth_token)\n            if completion is None:\n                raise Exception(f\"Failed to get prediction for {question}\")\n            answer = data_item['Answer']\n            score = get_score(completion, answer)\n            elapsed_time = time.time() - start_time\n            result = {\n                \"index\": i,\n                \"question_id\": data_item[\"ID\"],\n                \"answer\": answer,\n                \"prediction\": completion,\n                \"score\": score,\n                \"time\": elapsed_time\n            }\n            with open(out_path, \"a\") as f:\n                f.write(json.dumps(result) + \"\\n\")\n\n        except Exception as e:\n            print(f\"Failed to get prediction for {question}\")\n            print(e)\n            continue\n\n        pbar.update(1)\n\n\ndef main(
output_path, api_url, model_name, auth_token, format_tabs, problem_file, append, skip):\n    os.makedirs(os.path.dirname(output_path), exist_ok=True)\n    run_eval_api(api_url, model_name, output_path, format_tabs, auth_token, problem_file, append, skip)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"API Generate Tester\")\n    parser.add_argument(\"--api_url\", type=str, default=\"https://api.siliconflow.cn/v1/chat/completions\", help=\"API URL\")\n    parser.add_argument(\"--model_name\", type=str, default=\"Pro/deepseek-ai/DeepSeek-R1\", help=\"Model Name\")\n    parser.add_argument(\"--out_path\", type=str, default=\"results/api/eval_aime.jsonl\", help=\"Output Path\")\n    parser.add_argument(\"--auth_token\", type=str, default=None, help=\"Auth Token\")\n    parser.add_argument(\"--format_tabs\", action=\"store_true\", help=\"Format Tabs\")\n    parser.add_argument(\"--problem_file\", type=str, default=\"Maxwell-Jia/AIME_2024\", help=\"Evalset File\")\n    parser.add_argument(\"--no_append\", action=\"store_false\", help=\"Overwrite the output file instead of appending to it\")\n    parser.add_argument(\"--skip\", type=int, default=0, help=\"Skip some tasks\")\n    args = parser.parse_args()\n    # api_url = \"https://api.siliconflow.cn/v1/chat/completions\"\n    main(args.out_path, args.api_url, args.model_name, args.auth_token, args.format_tabs, args.problem_file, args.no_append, args.skip)"
  },
  {
    "path": "archive/ktransformers/tests/AIME_2024/evaluation.py",
    "content": "# reference: https://github.com/declare-lab/instruct-eval/blob/main/human_eval/main.py#L35\ndef filter_answer(completion: str) -> str:\n    # the answer is the last part of the completion; it is expected to be an integer\n    # keep only the last line\n    completion = completion.strip().split(\"\\n\")[-1]\n    # handle the $\\\\boxed{...}$ format\n    if \"$\\\\boxed{\" in completion:\n        return completion.split(\"}\")[0].split(\"{\")[-1]\n    # guard against an empty completion, where split() returns an empty list\n    tokens = completion.split()\n    return tokens[-1] if tokens else \"\"\n\n"
  },
  {
    "path": "archive/ktransformers/tests/AIME_2024/prompts.py",
    "content": "def instruct_prompt(prompt: str) -> str:\n    return f\"\"\"Below is an instruction that describes a task. Write a response that appropriately completes the request.\\n\\n### Instruction:\\nSolve the following math problem without any tests or explanation; give only one answer surrounded by '$\\\\boxed{{}}$'\\n{prompt}\\n\\n### Response:\"\"\"\n"
  },
  {
    "path": "archive/ktransformers/tests/UT/test_kdeepseek_attention_w8a8a2serve_npu.py",
    "content": "import sys\nimport types\n\nimport torch\nimport torch.nn as nn\nimport pytest\n\ntorch_npu = pytest.importorskip(\"torch_npu\")\n\nfrom ktransformers.operators.ascend.ascend_attention import (\n    KDeepseekV2AttentionW8A8A2Serve,\n)\nimport ktransformers.operators.ascend.ascend_attention as attn_mod\n\nclass DummyConfig:\n    def __init__(self, hidden_size=4, num_attention_heads=1):\n        self.hidden_size = hidden_size\n        self.num_attention_heads = num_attention_heads\n\n\nclass DummyOrigAttn(nn.Module):\n    def __init__(self, config=None, layer_idx=0):\n        super().__init__()\n        self.config = config\n        self.layer_idx = layer_idx\n\n        hidden_dim = config.hidden_size if config is not None else 4\n\n        self.q_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)\n        self.kv_a_proj_with_mqa = None\n        self.kv_a_layernorm = nn.LayerNorm(2)\n        self.o_proj = None\n\n\nclass DummyDynamicQuantOps:\n    def execute(self, inputs):\n        x = inputs[0]\n        return [x]\n\n\nclass DummyMatMulOps:\n    def execute(self, inputs):\n        x = inputs[0]\n        return [x]\n\n\nclass DummyQuantProj(nn.Module):\n    def __init__(self, dim):\n        super().__init__()\n        self.input_scale = torch.tensor(1.0, dtype=torch.float16)\n        self.input_offset = torch.tensor(0.0, dtype=torch.float16)\n        self.weight = nn.Parameter(torch.zeros(dim, dim, dtype=torch.float16))\n        self.quant_bias = torch.zeros(dim, dtype=torch.float16)\n        self.deq_scale = torch.tensor(1.0, dtype=torch.float16)\n\n\nclass DummyStaticCache:\n    def __init__(self, page_size=16):\n        self.page_size = page_size\n\n    def get_usable_length(self, kv_seq_len, layer_idx):\n        return 0\n\n    def update(self, combined, layer_idx, cache_kwargs):\n        return combined, None\n\n\nclass DummyNpuFusedAttention:\n    def __call__(self, q, k, v, **kwargs):\n        bsz, max_q_len, num_heads, dim = q.shape\n       
 out = torch.zeros(\n            bsz, max_q_len, num_heads, dim, dtype=q.dtype, device=q.device\n        )\n        softmax_lse = torch.zeros(1, dtype=q.dtype, device=q.device)\n        return out, softmax_lse\n\n    def out(self, q, k, v, workspace=None,\n            query_rope=None, key_rope=None,\n            num_heads=None, num_key_value_heads=None,\n            input_layout=None, scale=None,\n            antiquant_mode=None, antiquant_scale=None,\n            block_table=None, block_size=None,\n            actual_seq_lengths_kv=None,\n            sparse_mode=None,\n            out=None):\n        attn_output, softmax_lse = out\n        attn_output.zero_()\n        softmax_lse.zero_()\n        return attn_output, softmax_lse\n\n\nclass DummyOpsNpu:\n    def npu_fused_infer_attention_score(self, q, k, v, **kwargs):\n        bsz, num_heads, q_len, dim = q.shape\n        out = torch.zeros(\n            bsz, num_heads, q_len, dim, dtype=q.dtype, device=q.device\n        )\n        softmax_lse = torch.zeros(1, dtype=q.dtype, device=q.device)\n        return out, softmax_lse\n\ndef fake_apply_rotary_pos_emb_fusion(q_pe, k_pe, cos, sin):\n    return q_pe, k_pe\n\ndef build_attention_module(q_lora_rank=None):\n    if hasattr(attn_mod, \"get_tensor_parallel_size\"):\n        attn_mod.get_tensor_parallel_size = lambda: 1  # type: ignore\n\n    config = DummyConfig(hidden_size=4, num_attention_heads=1)\n    orig = DummyOrigAttn(config=config, layer_idx=0)\n\n    attn = KDeepseekV2AttentionW8A8A2Serve(\n        key=\"test\",\n        gguf_loader=None,\n        config=config,\n        orig_module=orig,\n        prefill_device=\"npu\",\n        generate_device=\"npu\",\n    )\n\n    hidden_dim = 4\n    num_heads = 1\n    qk_nope_head_dim = 2\n    qk_rope_head_dim = 2\n    q_head_dim = qk_nope_head_dim + qk_rope_head_dim  # 4\n    kv_lora_rank = 2\n    v_head_dim = 2\n\n    attn.num_heads = num_heads\n    attn.q_head_dim = q_head_dim\n    attn.qk_nope_head_dim = 
qk_nope_head_dim\n    attn.qk_rope_head_dim = qk_rope_head_dim\n    attn.kv_lora_rank = kv_lora_rank\n    attn.v_head_dim = v_head_dim\n    attn.softmax_scale = 1.0\n    attn.layer_idx = 0\n    attn.sparse_mode = 0\n    attn.q_lora_rank = q_lora_rank\n\n    attn.elewise_quant = DummyDynamicQuantOps()\n    attn.matmulDequant_operation = DummyMatMulOps()\n    attn.matmulDequant_operation_aclnn = DummyMatMulOps()\n\n    orig_mod = attn.orig_module\n\n    if q_lora_rank is None:\n        orig_mod.q_proj = nn.Linear(hidden_dim, num_heads * q_head_dim, bias=False)\n        orig_mod.q_proj = orig_mod.q_proj.to(dtype=torch.float16)\n    else:\n        orig_mod.q_a_proj = DummyQuantProj(hidden_dim)\n        orig_mod.q_b_proj = DummyQuantProj(hidden_dim)\n        orig_mod.q_a_layernorm = nn.LayerNorm(hidden_dim)\n\n    orig_mod.kv_a_proj_with_mqa = DummyQuantProj(hidden_dim)\n    orig_mod.kv_a_layernorm = nn.LayerNorm(kv_lora_rank)\n\n    orig_mod.o_proj = DummyQuantProj(num_heads * v_head_dim)\n\n    attn.q_absorb = torch.randn(\n        num_heads, qk_nope_head_dim, kv_lora_rank, dtype=torch.float16\n    )\n    attn.out_absorb = torch.randn(\n        num_heads, kv_lora_rank, v_head_dim, dtype=torch.float16\n    )\n    def fake_rotary_emb(q_pe, position_ids):\n        bsz, n_heads, q_len, dim = q_pe.shape\n        cos = torch.ones(1, 1, q_len, dim, dtype=q_pe.dtype, device=q_pe.device)\n        sin = torch.zeros(1, 1, q_len, dim, dtype=q_pe.dtype, device=q_pe.device)\n        return cos, sin\n\n    attn.rotary_emb = fake_rotary_emb\n\n    return attn\n\n@pytest.fixture(autouse=True)\ndef _patch_env(monkeypatch):\n    if hasattr(attn_mod, \"apply_rotary_pos_emb_fusion\"):\n        monkeypatch.setattr(\n            attn_mod, \"apply_rotary_pos_emb_fusion\",\n            fake_apply_rotary_pos_emb_fusion\n        )\n\n    if hasattr(attn_mod, \"get_use_npu_graph\"):\n        monkeypatch.setattr(attn_mod, \"get_use_npu_graph\", lambda: False)\n\n    if hasattr(attn_mod, 
\"get_tensor_parallel_size\"):\n        monkeypatch.setattr(attn_mod, \"get_tensor_parallel_size\", lambda: 1)\n\n    if hasattr(attn_mod, \"get_tensor_parallel_group\"):\n        monkeypatch.setattr(attn_mod, \"get_tensor_parallel_group\", lambda: None)\n\n    if hasattr(attn_mod, \"get_current_device\"):\n        monkeypatch.setattr(attn_mod, \"get_current_device\", lambda: \"cpu\")\n\n    # torch.distributed.barrier -> no-op\n    if hasattr(torch, \"distributed\") and hasattr(torch.distributed, \"barrier\"):\n        monkeypatch.setattr(\n            torch.distributed, \"barrier\",\n            lambda *args, **kwargs: None,\n            raising=False,\n        )\n\n    dummy_op = DummyNpuFusedAttention()\n    monkeypatch.setattr(\n        torch_npu, \"npu_fused_infer_attention_score\",\n        dummy_op, raising=False\n    )\n\n    def fake_get_workspace(q, k, v, **kwargs):\n        return torch.empty(1, dtype=q.dtype, device=q.device)\n\n    monkeypatch.setattr(\n        torch_npu, \"_npu_fused_infer_attention_score_get_max_workspace\",\n        fake_get_workspace, raising=False\n    )\n\n    monkeypatch.setattr(torch.ops, \"npu\", DummyOpsNpu(), raising=False)\n\n    yield\n\n\n# ==========================\n#  Test cases\n# ==========================\n\ndef test_print_callback_smoke():\n    attn = build_attention_module()\n    bsz, q_len, hidden_dim = 1, 3, 4\n    hidden_states = torch.randn(bsz, q_len, hidden_dim)\n    position_ids = torch.arange(q_len).unsqueeze(0)\n    cache_position = torch.arange(q_len).unsqueeze(0)\n    page_idx = torch.zeros(bsz, dtype=torch.int32)\n    page_offset = torch.zeros(bsz, dtype=torch.int32)\n    block_table = torch.zeros(bsz, 1, dtype=torch.int32)\n\n    attn.print_callback(\n        (hidden_states, position_ids, cache_position,\n         page_idx, page_offset, block_table)\n    )\n\n\ndef _common_inputs_prefill():\n    bsz, q_len, hidden_dim = 1, 3, 4\n    hidden_states = torch.randn(
bsz, q_len, hidden_dim, dtype=torch.float16)\n    attention_mask = torch.zeros(bsz, 1, q_len, q_len, dtype=torch.float32)\n    position_ids = torch.arange(q_len).unsqueeze(0)\n    cache_position = torch.arange(q_len).unsqueeze(0)\n    page_idx = torch.zeros(bsz, dtype=torch.int32)\n    page_offset = torch.zeros(bsz, dtype=torch.int32)\n    block_table = torch.zeros(bsz, 1, dtype=torch.int32)\n    past_key_value = DummyStaticCache(page_size=16)\n    q_len_raw = torch.tensor([q_len], dtype=torch.int32)\n    kv_len_raw = torch.tensor([q_len], dtype=torch.int32)\n\n    return (\n        hidden_states, attention_mask, position_ids, cache_position,\n        page_idx, page_offset, block_table,\n        past_key_value, q_len_raw, kv_len_raw\n    )\n\n\ndef test_forward_prefill_with_mask():\n    \"\"\"\n    is_prefill=True + attention_mask is not None + past_key_value is not None\n    \"\"\"\n    attn = build_attention_module(q_lora_rank=None)\n\n    (hidden_states, attention_mask, position_ids, cache_position,\n     page_idx, page_offset, block_table,\n     past_key_value, q_len_raw, kv_len_raw) = _common_inputs_prefill()\n\n    outputs = attn.forward(\n        hidden_states=hidden_states,\n        attention_mask=attention_mask,\n        position_ids=position_ids,\n        past_key_value=past_key_value,\n        output_attentions=False,\n        use_cache=True,\n        cache_position=cache_position,\n        is_prefill=True,\n        page_idx=page_idx,\n        page_offset=page_offset,\n        block_table=block_table,\n        q_len_raw=q_len_raw,\n        kv_len_raw=kv_len_raw,\n        stream=None,\n    )\n\n    attn_output, attn_weights, new_cache = outputs\n    assert attn_output.shape == (\n        1,  # bsz\n        3,  # q_len\n        attn.num_heads * attn.v_head_dim,\n    )\n    assert attn_weights is None\n    assert new_cache is past_key_value\n\n\ndef test_forward_prefill_without_mask_and_q_lora():\n    \"\"\"\n    is_prefill=True + attention_mask=None + q_lora_rank is not None branch\n    \"\"\"\n    attn = build_attention_module(
q_lora_rank=1)\n\n    (hidden_states, attention_mask, position_ids, cache_position,\n     page_idx, page_offset, block_table,\n     past_key_value, q_len_raw, kv_len_raw) = _common_inputs_prefill()\n\n    outputs = attn.forward(\n        hidden_states=hidden_states,\n        attention_mask=None,\n        position_ids=position_ids,\n        past_key_value=past_key_value,\n        output_attentions=False,\n        use_cache=True,\n        cache_position=cache_position,\n        is_prefill=True,\n        page_idx=None,\n        page_offset=None,\n        block_table=None,\n        q_len_raw=q_len_raw,\n        kv_len_raw=kv_len_raw,\n        stream=None,\n    )\n\n    attn_output, attn_weights, new_cache = outputs\n    assert attn_output.shape == (\n        1,\n        3,\n        attn.num_heads * attn.v_head_dim,\n    )\n    assert attn_weights is None\n    assert new_cache is past_key_value\n\n\ndef test_forward_decode_paged_path():\n    \"\"\"\n    is_prefill=False + get_use_npu_graph=False\n    => takes the forward_paged + torch.ops.npu.npu_fused_infer_attention_score branch\n    \"\"\"\n    attn = build_attention_module(q_lora_rank=None)\n\n    bsz, q_len, hidden_dim = 1, 1, 4\n    hidden_states = torch.randn(bsz, q_len, hidden_dim, dtype=torch.float16)\n    position_ids = torch.arange(q_len).unsqueeze(0)\n    cache_position = torch.arange(q_len).unsqueeze(0)\n    past_key_value = DummyStaticCache(page_size=16)\n    q_len_raw = torch.tensor([q_len], dtype=torch.int32)\n    kv_len_raw = torch.tensor([q_len], dtype=torch.int32)\n    block_table = torch.zeros(bsz, 1, dtype=torch.int32)\n\n    outputs = attn.forward(\n        hidden_states=hidden_states,\n        attention_mask=None,\n        position_ids=position_ids,\n        past_key_value=past_key_value,\n        output_attentions=False,\n        use_cache=True,\n        cache_position=cache_position,\n        is_prefill=False,\n        page_idx=None,\n        page_offset=None,
\n        block_table=block_table,\n        q_len_raw=q_len_raw,\n        kv_len_raw=kv_len_raw,\n        stream=None,\n    )\n\n    attn_output, attn_weights, new_cache = outputs\n    assert attn_output.shape == (\n        bsz,\n        q_len,\n        attn.num_heads * attn.v_head_dim,\n    )\n    assert attn_weights is None\n    assert new_cache is past_key_value\n\n\ndef test_forward_prefill_layer_idx_none_raises():\n    \"\"\"\n    Covers: the error branch where past_key_value is not None and layer_idx is None.\n    \"\"\"\n    attn = build_attention_module(q_lora_rank=None)\n    attn.layer_idx = None  # deliberately break layer_idx\n\n    (hidden_states, attention_mask, position_ids, cache_position,\n     page_idx, page_offset, block_table,\n     past_key_value, q_len_raw, kv_len_raw) = _common_inputs_prefill()\n\n    with pytest.raises(ValueError):\n        attn.forward(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_value=past_key_value,\n            output_attentions=False,\n            use_cache=True,\n            cache_position=cache_position,\n            is_prefill=True,\n            page_idx=page_idx,\n            page_offset=page_offset,\n            block_table=block_table,\n            q_len_raw=q_len_raw,\n            kv_len_raw=kv_len_raw,\n            stream=None,\n        )\n\n\ndef test_forward_prefill_attn_output_shape_mismatch_raises(monkeypatch):\n    \"\"\"\n    Covers: the ValueError branch taken when attn_output has an unexpected shape.\n    \"\"\"\n    attn = build_attention_module(q_lora_rank=None)\n\n    def bad_fused(q, k, v, **kwargs):\n        bsz, max_q_len, num_heads, dim = q.shape\n        # deliberately use num_heads + 1 so the size check fails\n        out = torch.zeros(\n            bsz, max_q_len, num_heads + 1, attn.v_head_dim,\n            dtype=q.dtype, device=q.device\n        )\n        lse = torch.zeros(1, dtype=q.dtype, device=q.device)\n        return out, lse\n\n    monkeypatch.setattr(\n        torch_npu, 
\"npu_fused_infer_attention_score\",\n        bad_fused, raising=False\n    )\n\n    (hidden_states, attention_mask, position_ids, cache_position,\n     page_idx, page_offset, block_table,\n     past_key_value, q_len_raw, kv_len_raw) = _common_inputs_prefill()\n\n    with pytest.raises(ValueError):\n        attn.forward(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_value=past_key_value,\n            output_attentions=False,\n            use_cache=True,\n            cache_position=cache_position,\n            is_prefill=True,\n            page_idx=page_idx,\n            page_offset=page_offset,\n            block_table=block_table,\n            q_len_raw=q_len_raw,\n            kv_len_raw=kv_len_raw,\n            stream=None,\n        )\n\n\ndef test_forward_paged_use_npu_graph(monkeypatch):\n    \"\"\"\n    Covers: the graph path taken when get_use_npu_graph() == True.\n    \"\"\"\n    # make ascend_attention.get_use_npu_graph return True\n    monkeypatch.setattr(attn_mod, \"get_use_npu_graph\", lambda: True)\n\n    # fake the model_runner module so that importing\n    # ktransformers.server.balance_serve.inference.model_runner succeeds\n    dummy_runner = type(\n        \"DummyRunner\", (), {\"__init__\": lambda self: setattr(self, \"workspace\", [None] * 4)}\n    )\n\n    dummy_mr = types.SimpleNamespace(\n        ModelRunner=dummy_runner,\n        get_or_create_model_runner=lambda device=None: dummy_runner(),\n    )\n\n    # register through monkeypatch so the fake entry is undone after the test\n    monkeypatch.setitem(\n        sys.modules,\n        \"ktransformers.server.balance_serve.inference.model_runner\",\n        dummy_mr,\n    )\n\n    attn = build_attention_module(q_lora_rank=None)\n\n    bsz, q_len, hidden_dim = 1, 1, 4\n    hidden_states = torch.randn(bsz, q_len, hidden_dim, dtype=torch.float16)\n    position_ids = torch.arange(q_len).unsqueeze(0)\n    cache_position = torch.arange(q_len).unsqueeze(0)\n    past_key_value = DummyStaticCache(page_size=16)\n    q_len_raw = torch.tensor(
[q_len], dtype=torch.int32)\n    kv_len_raw = torch.tensor([q_len], dtype=torch.int32)\n    block_table = torch.zeros(bsz, 1, dtype=torch.int32)\n\n    outputs = attn.forward(\n        hidden_states=hidden_states,\n        attention_mask=None,\n        position_ids=position_ids,\n        past_key_value=past_key_value,\n        output_attentions=False,\n        use_cache=True,\n        cache_position=cache_position,\n        is_prefill=False,\n        page_idx=None,\n        page_offset=None,\n        block_table=block_table,\n        q_len_raw=q_len_raw,\n        kv_len_raw=kv_len_raw,\n        stream=None,\n    )\n\n    attn_output, attn_weights, new_cache = outputs\n    assert attn_output.shape == (\n        bsz,\n        q_len,\n        attn.num_heads * attn.v_head_dim,\n    )\n    assert attn_weights is None\n    assert new_cache is past_key_value\n\n"
  },
  {
    "path": "archive/ktransformers/tests/UT/test_kdeepseek_ln_npu.py",
    "content": "import torch\nimport torch.nn as nn\nimport pytest\n\n# Adjust this path to match where the code actually lives:\nfrom ktransformers.operators.ascend.ascend_layernorm import KDeepseekV3RMSNormW8A8\nimport ktransformers.util.utils as utils_mod\n\ntorch_npu = pytest.importorskip(\"torch_npu\")\n\n\n# ==========================\n# Dummy dependencies\n# ==========================\n\nclass DummyOrigModule(nn.Module):\n    def __init__(self, hidden_size=4, variance_epsilon=1e-5):\n        super().__init__()\n        self.hidden_size = hidden_size\n        self.variance_epsilon = variance_epsilon\n\n\nclass DummySafeTensorLoader:\n    def __init__(self):\n        self.tensors = {}\n        self.load_calls = []\n\n    def load_tensor(self, name: str):\n        self.load_calls.append(name)\n        return self.tensors[name]\n\n\nclass DummyGGUFLoader:\n    def __init__(self, safetensor_loader: DummySafeTensorLoader):\n        self.safetensor_loader = safetensor_loader\n\n\nclass DummyConfig:\n    pass\n\n\nclass FakeRMSNorm:\n    def __init__(self):\n        self.last_args = None\n\n    def __call__(self, hidden_states, weight, eps):\n        self.last_args = (hidden_states, weight, eps)\n\n        out = hidden_states * weight\n        return (out,)\n\n\ndef build_rms_module(hidden_size=4, eps=1e-5, safetensor_loader=None):\n    orig = DummyOrigModule(hidden_size=hidden_size, variance_epsilon=eps)\n    if safetensor_loader is None:\n        safetensor_loader = DummySafeTensorLoader()\n    gguf_loader = DummyGGUFLoader(safetensor_loader)\n    config = DummyConfig()\n    module = KDeepseekV3RMSNormW8A8(\n        key=\"rms\",\n        gguf_loader=gguf_loader,\n        config=config,\n        orig_module=orig,\n        prefill_device=\"npu\",\n        generate_device=\"npu\",\n    )\n    return module, safetensor_loader, orig\n\n@pytest.fixture(autouse=True)\ndef patch_utils_and_npu(monkeypatch):\n    monkeypatch.setattr(utils_mod, \"get_current_device\", lambda: \"cpu\", raising=False)\n\n    fake = 
FakeRMSNorm()\n    monkeypatch.setattr(torch_npu, \"npu_rms_norm\", fake, raising=False)\n\n    import sys\n    sys.modules[__name__]._fake_rms = fake\n\n    yield\n\ndef get_fake_rms():\n    import sys\n    return sys.modules[__name__]._fake_rms\n\ndef test_forward_preserves_shape_and_dtype():\n    hidden_size = 4\n    module, _, orig = build_rms_module(hidden_size=hidden_size, eps=1e-6)\n\n    x = torch.randn(2, 3, hidden_size, dtype=torch.float16)\n\n    out = module(x)\n\n    assert out.shape == x.shape\n    assert out.dtype == x.dtype\n\n    fake_rms = get_fake_rms()\n    hs_arg, w_arg, eps_arg = fake_rms.last_args\n    assert hs_arg is x\n    assert w_arg is module.weight\n    assert eps_arg == orig.variance_epsilon\n\n\ndef test_forward_with_bfloat16_dtype():\n    hidden_size = 4\n    module, _, _ = build_rms_module(hidden_size=hidden_size, eps=1e-6)\n\n    x = torch.randn(1, 2, hidden_size, dtype=torch.bfloat16)\n    out = module(x)\n\n    assert out.shape == x.shape\n    assert out.dtype == torch.bfloat16\n\n\ndef test_forward_uses_bias():\n    hidden_size = 4\n    module, _, _ = build_rms_module(hidden_size=hidden_size, eps=1e-6)\n\n    module.weight.data = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)\n    module.bias.data = torch.tensor([-1.0, 0.5, 0.0, 2.0], dtype=torch.float32)\n\n    x = torch.arange(2 * 3 * hidden_size, dtype=torch.float16).view(2, 3, hidden_size)\n\n    out = module(x)\n\n    expected_rms = x.to(torch.float32) * module.weight\n    expected = expected_rms + module.bias\n\n    assert torch.allclose(out, expected.to(out.dtype))\n\n\n\ndef test_load_from_safetensor_loader():\n    hidden_size = 4\n    module, safe_loader, _ = build_rms_module(hidden_size=hidden_size, eps=1e-5)\n\n    w_loaded = torch.arange(hidden_size, dtype=torch.float32)\n    b_loaded = torch.full((hidden_size,), 3.0, dtype=torch.float32)\n\n    safe_loader.tensors[\"rms.weight\"] = w_loaded\n    safe_loader.tensors[\"rms.bias\"] = b_loaded\n\n    
module.load()\n\n    assert torch.allclose(module.weight, w_loaded)\n    assert torch.allclose(module.bias, b_loaded)\n\n    assert safe_loader.load_calls == [\"rms.weight\", \"rms.bias\"]\n\n\ndef test_unload_sets_weight_and_bias_to_none_idempotent():\n    module, _, _ = build_rms_module(hidden_size=4, eps=1e-5)\n\n    assert module.weight is not None\n    assert module.bias is not None\n\n    module.unload()\n    assert module.weight is None\n    assert module.bias is None\n\n    module.unload()\n    assert module.weight is None\n    assert module.bias is None\n\n"
  },
  {
    "path": "archive/ktransformers/tests/dequant_gpu.py",
    "content": "import os \n# os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1,2\"\n# add path\nimport sys\ncurrent_path = os.path.abspath(os.path.dirname(__file__))\nsys.path.append(current_path+\"/../..\")\nimport numpy as np\n# from ktransformers.operators.linear import KTransformersLinear, KLinearMarlin\n# from ktransformers.operators.experts import KTransformersExperts, KExpertsTorch\nfrom ktransformers.util.custom_gguf import GGUFLoader\nimport torch\nimport KTransformersOps\ntorch.set_default_dtype(torch.bfloat16)\nimport time\nfrom transformers import (\n    AutoConfig,\n)\nimport os\n# CUDA_LAUNCH_BLOCKING=1\nos.environ[\"CUDA_LAUNCH_BLOCKING\"]=\"1\"\n\ngguf_config = GGUFLoader(\"/data/Qwen2-57B-A14B-Instruct-GGUF/q4_k_m\")\nmodel_name = \"/data/Qwen2-57B-A14B-Instruct\"\n\n# Q4k\nkey = \"blk.1.\"\ntarget = \"attn_q.weight\"\n\nt1 = time.time()\nq_weight_cpu = gguf_config.load_gguf_tensor(key+target, \"cpu\")\n# q_weight_cpu = torch.from_numpy(q_weight_cpu)\n\nt2 = time.time()\nq_weight_gpu = gguf_config.load_gguf_tensor(key+target, \"cuda:0\")\nt3 = time.time()\nprint()\nallclose = torch.allclose(q_weight_cpu, q_weight_gpu.cpu(), atol=1e-6)\nprint(f\"Q4k {key+target}\")\nprint(\"load gguf tensor from cpu cost: \", t2-t1)\nprint(\"load gguf tensor from gpu cost: \", t3-t2)\nprint(\"allclose: \", allclose)\n\n\n# Q6k\nkey = \"blk.0.\"\ntarget = \"ffn_down_exps.weight\"\n\nt1 = time.time()\nq_weight_cpu = gguf_config.load_gguf_tensor(key+target, \"cpu\")\nt2 = time.time()\nq_weight_gpu = gguf_config.load_gguf_tensor(key+target, \"cuda:0\")\nt3 = time.time()\nprint()\nallclose = torch.allclose(q_weight_cpu, q_weight_gpu.cpu().to(torch.float32), atol=1e-6)\nprint(f\"Q6k {key+target}\")\nprint(\"load gguf tensor from cpu cost: \", t2-t1)\nprint(\"load gguf tensor from gpu cost: \", t3-t2)\nprint(\"allclose: \", allclose)\n"
  },
  {
    "path": "archive/ktransformers/tests/dequant_gpu_t.py",
    "content": "import os \nos.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n# add path\nimport sys\nsys.path.append(\"../..\")\nimport pycuda.autoinit\nimport pycuda.driver as cuda\nfrom pycuda.compiler import SourceModule\nimport numpy as np\nfrom ktransformers.operators.linear import KTransformersLinear, KLinearMarlin\nfrom ktransformers.operators.experts import KTransformersExperts, KExpertsTorch\nfrom ktransformers.util.custom_loader import GGUFLoader, dequantize_q4_k_gpu, dequantize_q4_k\nimport torch\nimport KTransformersOps\ntorch.set_default_dtype(torch.bfloat16)\nimport time\nfrom transformers import (\n    AutoConfig,\n)\n\ngguf_config = GGUFLoader(\"/data/Qwen2-57B-A14B-Instruct-GGUF/q4_k_m\")\nmodel_name = \"/data/Qwen2-57B-A14B-Instruct\"\nkey = \"blk.0.\"\ntarget = \"ffn_up_exps.weight\"\n\ndata = gguf_config.get_mmap_tensor(key + target)\n\n_, factors, offsets, qs1, qs2= dequantize_q4_k(data)\nfactors_cpu = torch.from_numpy(factors)\noffsets_cpu = torch.from_numpy(offsets)\nqs1_cpu = torch.from_numpy(qs1)\nqs2_cpu = torch.from_numpy(qs2)\n\n\n_, factors, offsets, qs1, qs2 = dequantize_q4_k_gpu(data)\n\nprint(torch.allclose(factors.cpu(), factors_cpu))\nprint(torch.allclose(offsets.cpu(), offsets_cpu))\nprint(torch.allclose(qs1.cpu(), qs1_cpu))\nprint(torch.allclose(qs2.cpu(), qs2_cpu))"
  },
  {
    "path": "archive/ktransformers/tests/function_call_test.py",
    "content": "from openai import OpenAI\n\ndef send_messages(messages):\n    response = client.chat.completions.create(\n        model=\"deepseek-chat\",\n        messages=messages,\n        tools=tools\n    )\n    return response.choices[0].message\n\nclient = OpenAI(\n    api_key=\"placeholder\",\n    base_url=\"http://0.0.0.0:10002/v1\",\n)\n\ntools = [\n    {\n        \"type\": \"function\",\n        \"function\": {\n            \"name\": \"get_weather\",\n            \"description\": \"Get weather of an location, the user shoud supply a location first\",\n            \"parameters\": {\n                \"type\": \"object\",\n                \"properties\": {\n                    \"location\": {\n                        \"type\": \"string\",\n                        \"description\": \"The city and state, e.g. San Francisco, CA\",\n                    }\n                },\n                \"required\": [\"location\"]\n            },\n        }\n    },\n]\n\nmessages = [{\"role\": \"user\", \"content\": \"How's the weather in Hangzhou?\"}]\nmessage = send_messages(messages)\nprint(f\"User>\\t {messages[0]['content']}\")\nprint(message)\ntool = message.tool_calls[0]\nmessages.append(message)\n\nmessages.append({\"role\": \"tool\", \"tool_call_id\": tool.id, \"content\": \"24℃\"})\nmessage = send_messages(messages)\nprint(f\"Model>\\t {message.content}\")"
  },
  {
    "path": "archive/ktransformers/tests/humaneval/eval_api.py",
    "content": "# adapt from https://github.com/abacaj/code-eval?tab=readme-ov-file\nimport argparse\nimport os\nimport requests\nfrom human_eval.data import write_jsonl, read_problems\nimport tqdm\n\nfrom evaluation import filter_code, fix_indents\nfrom prompts import instruct_prompt\n\ndef generate_text(api_url,question , model_name, stream=False, auth_token=None):\n    headers = {\n        'accept': 'application/json',\n        'Content-Type': 'application/json',\n        # 添加 API Key\n        'Authorization' : 'Bearer ' + auth_token if auth_token else ''\n    }\n    question = instruct_prompt(question)\n    data = {\n        \"messages\": [{\"content\": question, \"role\": \"user\"}],\n        \"model\": model_name,\n        \"stream\": stream,\n        \"temperature\": 0.6\n    }\n    print(f\"content: {question}\")\n    response = requests.post(api_url, headers=headers, json=data,verify=False)\n    if response.status_code == 200:\n        result = response.json()\n        results = result.get('choices', [{}])[0].get('message', {}).get('content', '')\n        return [filter_code(fix_indents(results))]\n    else:\n        print(f\"API Request failed with status code {response.status_code}\")\n        return None\n\ndef run_eval_api(\n    api_url: str,\n    model_name: str,\n    out_path: str,\n    format_tabs: bool = False,\n    auth_token: str = None,\n    problem_file: str = None,\n    append: bool = False,\n    skip: int = 0\n):\n    if(problem_file is None):\n        problems = read_problems()\n    else:\n        problems = read_problems(problem_file)\n    samples = []\n    pbar = tqdm.tqdm(total=len(problems) * 1)\n    pbar.update(skip)\n    try:\n        for task_id in problems:\n            # skip some tasks\n            if skip > 0:\n                skip -= 1\n                continue\n\n            if format_tabs:\n                prompt = problems[task_id][\"prompt\"].replace(\"    \", \"\\t\")\n            else:\n                prompt = 
problems[task_id][\"prompt\"]\n            completion = generate_text(api_url, prompt, model_name, auth_token=auth_token)\n            # samples.append({\"task_id\": task_id, \"completion\": completion})\n            for sample in completion:\n                result = dict(\n                    task_id=task_id,\n                    completion=sample,\n                )\n                samples += [result]\n                if append:\n                    write_jsonl(out_path, [result],append=append)\n            pbar.update(1)\n        if not append:\n            write_jsonl(out_path, samples,append=append)\n    except Exception as e:\n        if not append:\n            write_jsonl(out_path, samples,append=append)\n        print(f\"Error: {e}\")\n\ndef main(output_path, api_url, model_name, auth_token, format_tabs,problem_file, append,skip):\n    os.makedirs(os.path.dirname(output_path), exist_ok=True)\n    run_eval_api(api_url, model_name, output_path, format_tabs, auth_token, problem_file,append,skip)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"API Generate Tester\")\n    #parser.add_argument(\"--api_url\", type=str, default=\"https://api.siliconflow.cn/v1/chat/completions\", help=\"API URL\")\n    parser.add_argument(\"--api_url\", type=str, default=\"http://localhost:10002/v1/chat/completions\", help=\"API URL\")\n    parser.add_argument(\"--model_name\", type=str, default=\"Pro/deepseek-ai/DeepSeek-V3\", help=\"Model Name\")\n    parser.add_argument(\"--out_path\", type=str, default=\"results/api/eval_b.jsonl\", help=\"Output Path\")\n    parser.add_argument(\"--auth_token\", type=str, default=None, help=\"Auth Token\")\n    parser.add_argument(\"--format_tabs\", action=\"store_true\", help=\"Format Tabs\")\n    parser.add_argument(\"--problem_file\", type=str, default=None, help=\"Evalset File\")\n    parser.add_argument(\"--no_append\", action=\"store_false\", help=\"Do not append to the existing file; write all results at the end\")\n    
parser.add_argument(\"--skip\", type=int, default=0, help=\"Skip first n problems\")\n    args = parser.parse_args()\n    # api_url = \"https://api.siliconflow.cn/v1/chat/completions\"\n    main(args.out_path, args.api_url, args.model_name, args.auth_token, args.format_tabs, args.problem_file, args.no_append,args.skip)"
  },
  {
    "path": "archive/ktransformers/tests/humaneval/evaluation.py",
    "content": "# reference: https://github.com/declare-lab/instruct-eval/blob/main/human_eval/main.py#L35\ndef filter_code(completion: str) -> str:\n    # The program tends to overwrite, we only take the first function\n    completion = completion.lstrip(\"\\n\")\n    # we also remove ```python\\n and ```\n    completion = completion.replace(\"```python\\n\", \"\").replace(\"```\", \"\")\n    if 'if __name__ == \"__main__\":' in completion:\n        completion = completion.split('if __name__ == \"__main__\":')[0]\n    if \"# Example usage\" in completion:\n        completion = completion.split(\"# Example usage\")[0]\n    return completion\n\n\ndef fix_indents(text: str) -> str:\n    return text.replace(\"\\t\", \"    \")\n"
  },
  {
    "path": "archive/ktransformers/tests/humaneval/prompts.py",
    "content": "def instruct_prompt(prompt: str) -> str:\n    return f\"\"\"Below is an instruction that describes a task. Write a response that appropriately completes the request.\\n\\n### Instruction:\\nComplete the following Python code without any tests or explanation\\n{prompt}\\n\\n### Response:\"\"\"\n\n\ndef standard_prompt(prompt: str) -> str:\n    return f\"\"\"Complete the following Python code without any tests or explanation\\n{prompt}\"\"\"\n\n\ndef write_prompt(prompt: str) -> str:\n    return f\"\"\"Write a python program to complete the following code:\\n{prompt}\"\"\"\n\n\ndef replit_glaive_prompt(prompt: str) -> str:\n    return f\"\"\"Below is an instruction that describes a task, paired with an input that provides further context.\\n Write a response that appropriately completes the request.\\n\\n ### Instruction:\\nWrite a program to perform the given task.\\n\\n Input:\\n{prompt}\\n\\n### Response:\"\"\"\n"
  },
  {
    "path": "archive/ktransformers/tests/mmlu_pro_test.py",
    "content": "import argparse\nimport random\nimport time\nimport json\nimport requests\nimport pandas as pd\nfrom datasets import load_dataset\n\nimport os\nos.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\nos.environ['https_proxy'] = ''\nos.environ['http_proxy'] = ''\nhint = 'There is a single choice question. Answer the question by replying A, B, C, D, E, F, G, H, I, J. No other answers are accepted. Just the letter.'\n\n\nclass DataEvaluator:\n    def __init__(self):\n        # self.template_prompt = template_prompt\n        self.data = []\n\n    def load_data(self, file_path):\n        \"\"\"\n        Load data from a Parquet file into a list.\n        Each record in the Parquet file should represent an individual record.\n        \"\"\"\n        # 读取 Parquet 文件\n        # dataset = load_dataset('parquet', data_files=file_path)\n        ds = load_dataset(\"TIGER-Lab/MMLU-Pro\")\n        df = pd.DataFrame(ds['test'])\n        # print(ds)\n        # # ds_1 =  ds['train']\n        # ds_2 =  ds['validation']\n        # ds_3 =  ds['test']\n        # # 将数据集转换为 Pandas DataFrame\n        # df_test = pd.DataFrame(ds['test'])\n        # df_val = pd.DataFrame(ds['validation'])\n\n        # for _, row in df.iterrows():\n        #     self.data.append(row.to_dict())\n        # df = pd.read_parquet(file_path)\n\n        for _, row in df.iterrows():\n            self.data.append(row.to_dict())\n\n    def get_prompt(self, record):\n        \"\"\"\n        Combine fields from a record with the template prompt to create a full prompt.\n        :param record: Dictionary containing fields to populate the template.\n        :return: A formatted prompt string.\n        \"\"\"\n        # 查看ABCD。。。的选项\n        options_str = \"\\n\".join([f\"{chr(65+i)}. 
{opt}\" for i, opt in enumerate(record['options'])])\n        prompt = hint + \"\\nQuestion: \" + record['question'] + \"\\n\" + options_str + \"\\nAnswer: '\"\n        return prompt\n        \n    def post_processing(self, text):\n        \"\"\"\n        Perform post-processing on the prediction string.\n        :param text: The raw prediction string.\n        :return: The last character of the prediction's final line.\n        \"\"\"\n        text = text.lstrip('\\n').split('\\n')[-1]\n        return text[-1:]\n\n    def score(self, pred, answers):\n        \"\"\"\n        Score the prediction by exact match against the accepted answers.\n        :param pred: The predicted answer letter.\n        :param answers: An iterable of accepted answer letters.\n        :return: 1 if the prediction matches any accepted answer, otherwise 0.\n        \"\"\"\n        for answer in answers:\n            if pred == answer:\n                return 1\n\n        return 0\n\n# Function to generate text using API\ndef generate_text(api_url, question, model_name, stream=False):\n    headers = {\n        'accept': 'application/json',\n        'Content-Type': 'application/json',\n        # Add the API key\n        'Authorization' : 'Bearer '\n    }\n    data = {\n        \"messages\": [{\"content\": question, \"role\": \"user\"}],\n        \"model\": model_name,\n        \"stream\": stream,\n        # \"temperature\": 0.0\n    }\n    \n    print(\"POST data:\", data)\n    response = requests.post(api_url, headers=headers, json=data)\n    \n    if response.status_code == 200:\n        result = response.json()\n        return result.get('choices', [{}])[0].get('message', {}).get('content', '').strip()\n    else:\n        print(f\"API Request failed with status code {response.status_code}\")\n        return None\n\n# Main function to handle multiple evaluations\ndef main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):\n    start_total_time = time.time()\n\n    
total_score = 0\n\n    results = []\n    # Set the random seed\n    random.seed(42)\n    random.shuffle(data_evaluator.data)\n    for i in range(min(concurrent_requests, len(data_evaluator.data))):\n        # Take the next item from the shuffled data for each request\n        data_item = data_evaluator.data[i]\n        question = data_evaluator.get_prompt(data_item)\n        # print(question)\n\n        # Start the timer for this evaluation\n        start_time = time.time()\n        try:\n            # Generate prediction using the API\n            prediction = generate_text(api_url, question, model_name)\n\n            if prediction is None:\n                raise Exception(f\"Failed to get prediction for {question}\")\n\n            answer = data_item['answer']\n            # Compute score\n            score = data_evaluator.score(data_evaluator.post_processing(prediction), answer)\n\n            # Calculate the time taken\n            elapsed_time = time.time() - start_time\n\n            # Collect the result data\n            result_data = {\n                \"question_id\": data_item['question_id'],\n                \"answer\": answer,\n                \"prediction\": data_evaluator.post_processing(prediction),\n                \"score\": score,\n                \"time\": elapsed_time\n            }\n\n            # Write results to result.json with each field on a new line\n            with open(result_file, 'a', encoding='utf-8') as f:\n                json.dump(result_data, f, ensure_ascii=False, indent=4)\n                f.write(\"\\n\")  # Ensure each JSON object is on a new line\n\n            results.append(result_data)\n\n            # Aggregate scores\n            total_score += score\n\n        except Exception as e:\n            print(f\"Error processing request {i}: {e}\")\n\n    # Calculate total time and throughput\n    total_time = time.time() - start_total_time\n    throughput = concurrent_requests / total_time\n\n    # Log the total time, throughput, and 
average score\n    with open(log_file, 'a', encoding='utf-8') as log_f:\n        log_f.write(f\"Total Time: {total_time:.2f} seconds\\n\")\n        log_f.write(f\"Throughput: {throughput:.2f} requests per second\\n\")\n        log_f.write(f\"Average Scores: {total_score / concurrent_requests}\\n\")\n        log_f.write('-' * 40 + '\\n')\n\n    print(f\"Results saved to {result_file}\")\n    print(f\"Log saved to {log_file}\")\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"API Generate Tester\")\n    parser.add_argument(\"--concurrent\", type=int, default=1000, help=\"Number of concurrent evaluations\")\n    parser.add_argument(\"--file\", type=str, default=\"TIGER-Lab/MMLU-Pro\", help=\"Path to the mmlu.jsonl file\")\n    parser.add_argument(\"--result\", type=str, default=\"./mmlu_result_pro.json\", help=\"Path to save the result JSON file\")\n    parser.add_argument(\"--log\", type=str, default=\"./mmlu_result_pro.log\", help=\"Path to save the log file\")\n    parser.add_argument(\"--model\", type=str, default=\"Pro/deepseek-ai/DeepSeek-V3\", help=\"Model name or path\")\n    parser.add_argument(\"--api_url\", type=str, default=\"http://localhost:15488/v1/chat/completions\", help=\"API URL\")\n    # parser.add_argument(\"--api_url\", type=str, default=\"https://api.siliconflow.cn/v1/chat/completions\", help=\"API URL\")\n\n    args = parser.parse_args()\n\n    # Load the data from the provided file\n    # template_prompt = hint + \"\\nQuestion: {question}\\nA. {options}\\nB. {option_b}\\nC. {option_c}\\nD. {option_d}\\nAnswer: '\"\n    # template_prompt_pro = hint + \"\\nQuestion: {question}\\nA. {options[0]}\\nB. {options[1]}\\nC. {options[2]}\\nD. {options[3]}\\nE. {options[4]}\\nF. {options[5]}\\nG. \\\n        # {options[6]}\\nH. {options[7]}\\nI. {options[8]}\\nJ. 
{options[9]}\\nAnswer: '\"\n\n\n    # Load the data from the provided file\n    data_evaluator = DataEvaluator()\n    data_evaluator.load_data(args.file)\n\n    # Run the main function with the specified number of concurrent evaluations\n    main(args.concurrent, data_evaluator, args.result, args.log, args.api_url, args.model)"
  },
  {
    "path": "archive/ktransformers/tests/mmlu_test.py",
    "content": "import argparse\nimport random\nimport time\nimport json\nimport requests\nimport pandas as pd\nfrom datasets import load_dataset\n\nimport os\nos.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\nos.environ['https_proxy'] = ''\nos.environ['http_proxy'] = ''\nhint = 'There is a single choice question. Answer the question by replying A, B, C, D. No other answers are accepted. Just the letter.'\n\n\nclass DataEvaluator:\n    def __init__(self):\n        # self.template_prompt = template_prompt\n        self.data = []\n\n    def load_data(self, file_path):\n        \"\"\"\n        Load data from a Parquet file into a list.\n        Each record in the Parquet file should represent an individual record.\n        \"\"\"\n        # 读取 Parquet 文件\n        # dataset = load_dataset('parquet', data_files=file_path)\n        splits = {'test': 'all/test-00000-of-00001.parquet', 'validation': 'all/validation-00000-of-00001.parquet',\n                  'dev': 'all/dev-00000-of-00001.parquet',\n                  'auxiliary_train': 'all/auxiliary_train-00000-of-00001.parquet'}\n        df = pd.read_parquet(\"hf://datasets/cais/mmlu/\" + splits[\"test\"])\n\n        for _, row in df.iterrows():\n            self.data.append(row.to_dict())\n\n    def get_prompt(self, record):\n        \"\"\"\n        Combine fields from a record with the template prompt to create a full prompt.\n        :param record: Dictionary containing fields to populate the template.\n        :return: A formatted prompt string.\n        \"\"\"\n        # 查看ABCD。。。的选项\n        options_str = \"\\n\".join([f\"{chr(65 + i)}. 
{opt}\" for i, opt in enumerate(record['choices'])])\n        prompt = hint + \"\\nQuestion: \" + record['question'] + \"\\n\" + options_str + \"\\nAnswer: '\"\n        return prompt\n        \n    def post_processing(self, text):\n        \"\"\"\n        Perform post-processing on the prediction string.\n        :param text: The raw prediction string.\n        :return: The last character of the prediction's final line.\n        \"\"\"\n        text = text.lstrip('\\n').split('\\n')[-1]\n        return text[-1:]\n\n    def score(self, pred, answers):\n        \"\"\"\n        Score the prediction by exact match against the accepted answers.\n        :param pred: The predicted answer letter.\n        :param answers: An iterable of accepted answer letters.\n        :return: 1 if the prediction matches any accepted answer, otherwise 0.\n        \"\"\"\n        for answer in answers:\n            if pred == answer:\n                return 1\n\n        return 0\n\n# Function to generate text using API\ndef generate_text(api_url, question, model_name, stream=False):\n    headers = {\n        'accept': 'application/json',\n        'Content-Type': 'application/json',\n        # Add the API key\n        'Authorization' : 'Bearer '\n    }\n    data = {\n        \"messages\": [{\"content\": question, \"role\": \"user\"}],\n        \"model\": model_name,\n        \"stream\": stream,\n        # \"temperature\": 0.0\n    }\n    \n    print(\"POST data:\", data)\n    response = requests.post(api_url, headers=headers, json=data)\n    \n    if response.status_code == 200:\n        result = response.json()\n        return result.get('choices', [{}])[0].get('message', {}).get('content', '').strip()\n    else:\n        print(f\"API Request failed with status code {response.status_code}\")\n        return None\n\n# Main function to handle multiple evaluations\ndef main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):\n    start_total_time = time.time()\n\n    
total_score = 0\n\n    results = []\n    # Set the random seed\n    random.seed(42)\n    random.shuffle(data_evaluator.data)\n    for i in range(min(concurrent_requests, len(data_evaluator.data))):\n        # Take the next item from the shuffled data for each request\n        data_item = data_evaluator.data[i]\n        question = data_evaluator.get_prompt(data_item)\n        # print(question)\n\n        # Start the timer for this evaluation\n        start_time = time.time()\n        try:\n            # Generate prediction using the API\n            prediction = generate_text(api_url, question, model_name)\n\n            if prediction is None:\n                raise Exception(f\"Failed to get prediction for {question}\")\n\n            answer = chr(data_item['answer'] + 65)\n            # Compute score\n            score = data_evaluator.score(data_evaluator.post_processing(prediction), answer)\n\n            # Calculate the time taken\n            elapsed_time = time.time() - start_time\n\n            # Collect the result data\n            result_data = {\n                \"question_id\": i,\n                \"answer\": answer,\n                \"prediction\": data_evaluator.post_processing(prediction),\n                \"score\": score,\n                \"time\": elapsed_time\n            }\n\n            # Write results to result.json with each field on a new line\n            with open(result_file, 'a', encoding='utf-8') as f:\n                json.dump(result_data, f, ensure_ascii=False, indent=4)\n                f.write(\"\\n\")  # Ensure each JSON object is on a new line\n\n            results.append(result_data)\n\n            # Aggregate scores\n            total_score += score\n\n        except Exception as e:\n            print(f\"Error processing request {i}: {e}\")\n\n    # Calculate total time and throughput\n    total_time = time.time() - start_total_time\n    throughput = concurrent_requests / total_time\n\n    # Log the total time, throughput, and average 
scores\n    with open(log_file, 'a', encoding='utf-8') as log_f:\n        log_f.write(f\"Total Time: {total_time:.2f} seconds\\n\")\n        log_f.write(f\"Throughput: {throughput:.2f} requests per second\\n\")\n        log_f.write(f\"Average Scores: {total_score / concurrent_requests}\\n\")\n        log_f.write('-' * 40 + '\\n')\n\n    print(f\"Results saved to {result_file}\")\n    print(f\"Log saved to {log_file}\")\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"API Generate Tester\")\n    parser.add_argument(\"--concurrent\", type=int, default=1000, help=\"Number of concurrent evaluations\")\n    parser.add_argument(\"--file\", type=str, default=\"cais/mmlu\", help=\"Path to the mmlu.jsonl file\")\n    parser.add_argument(\"--result\", type=str, default=\"./mmlu_result_silicon.json\", help=\"Path to save the result JSON file\")\n    parser.add_argument(\"--log\", type=str, default=\"./mmlu_result_silicon.log\", help=\"Path to save the log file\")\n    parser.add_argument(\"--model\", type=str, default=\"Pro/deepseek-ai/DeepSeek-V3\", help=\"Model name or path\")\n    parser.add_argument(\"--api_url\", type=str, default=\"http://localhost:10003/v1/chat/completions\", help=\"API URL\")\n    # parser.add_argument(\"--api_url\", type=str, default=\"https://api.siliconflow.cn/v1/chat/completions\", help=\"API URL\")\n\n    args = parser.parse_args()\n\n    # Load the data from the provided file\n    # template_prompt = hint + \"\\nQuestion: {question}\\nA. {options}\\nB. {option_b}\\nC. {option_c}\\nD. {option_d}\\nAnswer: '\"\n    # template_prompt_pro = hint + \"\\nQuestion: {question}\\nA. {options[0]}\\nB. {options[1]}\\nC. {options[2]}\\nD. {options[3]}\\nE. {options[4]}\\nF. {options[5]}\\nG. \\\n        # {options[6]}\\nH. {options[7]}\\nI. {options[8]}\\nJ. 
{options[9]}\\nAnswer: '\"\n\n\n    # Load the data from the provided file\n    data_evaluator = DataEvaluator()\n    data_evaluator.load_data(args.file)\n\n    # Run the main function with the specified number of concurrent evaluations\n    main(args.concurrent, data_evaluator, args.result, args.log, args.api_url, args.model)"
  },
  {
    "path": "archive/ktransformers/tests/mmlu_test_multi.py",
    "content": "import argparse\nimport random\nimport time\nimport json\nimport requests\nimport pandas as pd\nfrom datasets import load_dataset\nimport os\nimport concurrent.futures\nimport threading\nimport re\n\nos.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\nos.environ['https_proxy'] = ''\nos.environ['http_proxy'] = ''\nhint = 'There is a single choice question. Answer the question by replying A, B, C, D. No other answers are accepted. Just the letter.'\n\n\ndef extract_final_answer(text):\n    \"\"\"\n    提取模型预测的最终选项（如 A/B/C/D）\n    支持自然语言、多行、markdown、高亮、非末尾结论等格式\n    \"\"\"\n    text = text.strip()\n\n    # 1. 显式语句匹配（优先）\n    explicit_patterns = [\n        r'Answer:\\s*([A-D])\\b',\n        r'Correct answer:\\s*([A-D])\\b',\n        r'The correct answer is\\s*\\*?\\*?\\s*([A-D])\\b',\n        r'Answer is\\s*([A-D])\\b',\n        r'Therefore,\\s*answer is\\s*([A-D])\\b',\n        r'Therefore,\\s*the answer should be\\s*(?:Option\\s*)?([A-D])\\b',\n        r'The answer should be\\s*(?:Option\\s*)?([A-D])\\b',\n        r'Option\\s+([A-D])\\s+is correct',\n    ]\n    for pat in explicit_patterns:\n        match = re.search(pat, text, re.IGNORECASE)\n        if match:\n            return match.group(1).upper()\n\n    # 2. markdown 强调 **C**, **C. something**\n    markdown_match = re.findall(r'\\*\\*\\s*([A-D])[\\.\\s]?', text)\n    if markdown_match:\n        return markdown_match[-1].upper()\n\n    # 3. 查找单引号中的 'C' 或 \"C\"\n    quote_match = re.findall(r\"['\\\"]([A-D])['\\\"]\", text)\n    if quote_match:\n        return quote_match[-1].upper()\n\n    # 4. 
Check whether one of the last few lines starts with \"C.\" or \"C\"\n    lines = text.splitlines()\n    for line in reversed(lines[-5:]):\n        line = line.strip()\n        match = re.match(r'^([A-D])([.\\s]|$)', line)\n        if match:\n            return match.group(1).upper()\n    \n    # If nothing matched, return None\n    return None\n\n\nclass DataEvaluator:\n    def __init__(self):\n        self.data = []\n\n    def load_data(self, file_path):\n        \"\"\"\n        Load the data file; each record corresponds to one instance.\n        \"\"\"\n        splits = {'test': 'all/test-00000-of-00001.parquet', 'validation': 'all/validation-00000-of-00001.parquet',\n                  'dev': 'all/dev-00000-of-00001.parquet',\n                  'auxiliary_train': 'all/auxiliary_train-00000-of-00001.parquet'}\n        df = pd.read_parquet(\"hf://datasets/cais/mmlu/\" + splits[\"test\"])\n        for _, row in df.iterrows():\n            self.data.append(row.to_dict())\n\n    def get_prompt(self, record):\n        \"\"\"\n        Combine the hint with the record's data to build the full question prompt.\n        \"\"\"\n        options_str = \"\\n\".join([f\"{chr(65 + i)}. 
{opt}\" for i, opt in enumerate(record['choices'])])\n        prompt = hint + \"\\nQuestion: \" + record['question'] + \"\\n\" + options_str + \"\\nAnswer: '\"\n        return prompt\n\n    def post_processing(self, text):\n        \"\"\"\n        Post-process the generated text to extract the final answer (returns only the last character).\n        \"\"\"\n        text = text.lstrip('\\n').split('\\n')[-1]\n        return text[-1:]\n\n    def score(self, pred, answer):\n        \"\"\"\n        Compare the predicted answer with the correct answer and return the score.\n        \"\"\"\n        if pred == answer:\n            return 1\n        return 0\n\ndef generate_text(api_url, question, model_name, stream=False):\n    headers = {\n        'accept': 'application/json',\n        'Content-Type': 'application/json',\n        'Authorization': 'Bearer '  # Fill in the API key here if required\n    }\n    data = {\n        \"messages\": [{\"content\": question, \"role\": \"user\"}],\n        \"model\": model_name,\n        \"stream\": stream,\n    }\n    print(\"POST data:\", data)\n    response = requests.post(api_url, headers=headers, json=data, timeout=5000000)\n    if response.status_code == 200:\n        result = response.json()\n        return result.get('choices', [{}])[0].get('message', {}).get('content', '').strip()\n    else:\n        print(f\"API Request failed with status code {response.status_code}\")\n        return None\n\ndef main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):\n    start_total_time = time.time()\n    total_score = 0\n    total_exact_score = 0\n    results = []\n    file_lock = threading.Lock()\n    \n    # Shuffle the data and select the number of instances to test\n    random.seed(42)\n    random.shuffle(data_evaluator.data)\n    data_subset = data_evaluator.data[:min(concurrent_requests, len(data_evaluator.data))]\n    \n    batch_size = 10  # At most 10 instances per batch\n\n    def worker(index, data_item):\n        nonlocal total_score\n        nonlocal total_exact_score\n        question = data_evaluator.get_prompt(data_item)\n        start_time = time.time()\n        try:\n     
       prediction = generate_text(api_url, question, model_name)\n            if prediction is None:\n                raise Exception(f\"Failed to get prediction for question: {question}\")\n            # 正确答案：将数字转换成字母（0->A, 1->B, 2->C, 3->D）\n            answer = chr(data_item['answer'] + 65)\n            processed_prediction = data_evaluator.post_processing(prediction)\n            score = data_evaluator.score(processed_prediction, answer)\n            exact_score = data_evaluator.score(extract_final_answer(prediction), answer)\n            elapsed_time = time.time() - start_time\n            result_data = {\n                \"question_id\": index,\n                \"answer\": answer,\n                \"prediction\": processed_prediction,\n                \"full_prediction\": prediction,\n                \"score\": score,\n                \"exact_score\": exact_score,\n                \"time\": elapsed_time\n            }\n            # 写入结果时加锁保证线程安全\n            with file_lock:\n                with open(result_file, 'a', encoding='utf-8') as f:\n                    json.dump(result_data, f, ensure_ascii=False, indent=4)\n                    f.write(\"\\n\")\n            return result_data\n        except Exception as e:\n            print(f\"Error processing request {index}: {e}\")\n            return None\n\n    # 按批次处理，每批最多 10 个任务\n    for batch_start in range(0, len(data_subset), batch_size):\n        batch = data_subset[batch_start: batch_start + batch_size]\n        with concurrent.futures.ThreadPoolExecutor(max_workers=batch_size) as executor:\n            futures = [executor.submit(worker, batch_start + j, data_item) for j, data_item in enumerate(batch)]\n            for future in concurrent.futures.as_completed(futures):\n                res = future.result()\n                if res is not None:\n                    results.append(res)\n                    total_score += res['score']\n                    total_exact_score += res['exact_score']\n    \n   
 total_time = time.time() - start_total_time\n    throughput = len(data_subset) / total_time if total_time > 0 else 0\n    \n    with open(log_file, 'a', encoding='utf-8') as log_f:\n        log_f.write(f\"Total Time: {total_time:.2f} seconds\\n\")\n        log_f.write(f\"Throughput: {throughput:.2f} requests per second\\n\")\n        average_score = total_score / len(data_subset) if data_subset else 0\n        log_f.write(f\"Average Score: {average_score}\\n\")\n        average_exact_score = total_exact_score / len(data_subset) if data_subset else 0\n        log_f.write(f\"Average Exact Score: {average_exact_score}\\n\")\n        log_f.write('-' * 40 + '\\n')\n    \n    print(f\"Results saved to {result_file}\")\n    print(f\"Log saved to {log_file}\")\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"API Generate Tester\")\n    parser.add_argument(\"--concurrent\", type=int, default=1000, help=\"需要测试的实例总数\")\n    parser.add_argument(\"--file\", type=str, default=\"cais/mmlu\", help=\"数据文件路径\")\n    parser.add_argument(\"--result\", type=str, default=\"./mmlu_result_silicon.json\", help=\"结果文件保存路径\")\n    parser.add_argument(\"--log\", type=str, default=\"./mmlu_result_silicon.log\", help=\"日志文件保存路径\")\n    parser.add_argument(\"--model\", type=str, default=\"Pro/deepseek-ai/DeepSeek-V3\", help=\"模型名称或路径\")\n    parser.add_argument(\"--api_url\", type=str, default=\"http://localhost:10006/v1/chat/completions\", help=\"API URL\")\n\n    args = parser.parse_args()\n    \n    data_evaluator = DataEvaluator()\n    data_evaluator.load_data(args.file)\n    \n    main(args.concurrent, data_evaluator, args.result, args.log, args.api_url, args.model)"
  },
  {
    "path": "archive/ktransformers/tests/parse_cover_info.py",
    "content": "import os\nimport ast\nimport argparse\nfrom coverage import Coverage\n\n\ndef main():\n    parser = argparse.ArgumentParser(\n        description=\"Report line coverage for a single class from .coverage data\"\n    )\n    parser.add_argument(\n        \"--data-file\",\n        default=\".coverage\",\n        help=\"Path to the coverage data file (default: ./.coverage)\",\n    )\n    parser.add_argument(\n        \"--file\",\n        dest=\"file_pattern\",\n        default=\"ktransformers/operators/ascend/ascend_attention.py\",\n        help=(\n            \"Source file to report on (suffix match allowed; default: \"\n            \"ktransformers/operators/ascend/ascend_attention.py)\"\n        ),\n    )\n    parser.add_argument(\n        \"--class\",\n        dest=\"class_name\",\n        default=\"KDeepseekV2AttentionW8A8A2Serve\",\n        help=\"Class to report on (default: KDeepseekV2AttentionW8A8A2Serve)\",\n    )\n\n    args = parser.parse_args()\n\n    if not os.path.exists(args.data_file):\n        print(f\"Coverage data file not found: {args.data_file}\")\n        raise SystemExit(1)\n\n    cov = Coverage(data_file=args.data_file)\n    cov.load()\n    data = cov.get_data()\n\n    file_pattern_norm = os.path.normpath(args.file_pattern)\n\n    target_file = None\n    for f in data.measured_files():\n        f_norm = os.path.normpath(f)\n        if f_norm.endswith(file_pattern_norm) or file_pattern_norm in f_norm:\n            target_file = f\n            break\n\n    if not target_file:\n        print(\n            f\"No matching file found in the coverage data: {args.file_pattern}\\n\"\n            f\"Files actually recorded:\"\n        )\n        for f in data.measured_files():\n            print(\"  \", f)\n        raise SystemExit(1)\n\n    print(\"Using source file:\", target_file)\n    executed_lines = set(data.lines(target_file) or [])\n    try:\n        with open(target_file, \"r\", encoding=\"utf-8\") as f:\n            source_text = f.read()\n    except OSError as e:\n        print(f\"Cannot open source file {target_file}: {e}\")\n        raise SystemExit(1)\n\n    source_lines = source_text.splitlines()\n    
tree = ast.parse(source_text)\n\n    class_start = None\n    class_end = None\n\n    for node in tree.body:\n        if isinstance(node, ast.ClassDef) and node.name == args.class_name:\n            class_start = node.lineno\n            max_lineno = node.lineno\n            for sub in ast.walk(node):\n                ln = getattr(sub, \"end_lineno\", getattr(sub, \"lineno\", None))\n                if ln is not None and ln > max_lineno:\n                    max_lineno = ln\n            class_end = max_lineno\n            break\n\n    if class_start is None:\n        print(f\"Class {args.class_name} not found in source file {target_file}\")\n        raise SystemExit(1)\n\n    print(\n        f\"Class {args.class_name} line range: {class_start} ~ {class_end}\"\n    )\n\n    total = 0\n    covered = 0\n    missed_lines = []\n\n    for lineno in range(class_start, class_end + 1):\n        line = source_lines[lineno - 1].strip()\n        # Skip blank lines and comment-only lines\n        if not line or line.startswith(\"#\"):\n            continue\n\n        total += 1\n        if lineno in executed_lines:\n            covered += 1\n        else:\n            missed_lines.append(lineno)\n\n    percent = (covered / total * 100) if total > 0 else 0.0\n\n    print(\n        f\"Class {args.class_name} coverage: {covered}/{total} lines, coverage = {percent:.1f}%\"\n    )\n    if missed_lines:\n        print(\"Uncovered line numbers:\", missed_lines)\n    else:\n        print(\"All effective code lines in this class are covered\")\n\n\nif __name__ == \"__main__\":\n    main()\n\n"
  },
  {
    "path": "archive/ktransformers/tests/score.py",
    "content": "import subprocess\nimport time\nimport requests\nimport sys\nimport os\n\ndef wait_for_server(base_url: str, timeout: int = None) -> None:\n    start_time = time.time()\n    while True:\n        try:\n            response = requests.get(\n                f\"{base_url}/v1/models\",\n                headers={\"Authorization\": \"Bearer None\"},\n            )\n            if response.status_code == 200:\n                print(\"Server is ready.\")\n                break\n        except requests.exceptions.RequestException:\n            pass\n        # Apply the deadline and the 1 s back-off to every failed attempt, including non-200 responses\n        if timeout and time.time() - start_time > timeout:\n            raise TimeoutError(\"Server did not become ready within timeout period\")\n        time.sleep(1)\n\nserver_cmd = [\n    \"numactl\", \"-N\", \"1\", \"-m\", \"1\",\n    \"/home/qujing3/anaconda3/envs/ktransformers-dev/bin/ktransformers\",\n    \"--model_path\", \"/home/qujing3/models/DeepSeek-R1-Q4_K_M/config\",\n    \"--gguf_path\", \"/home/qujing3/models/DeepSeek-V3-GGUF/DeepSeek-V3-Q4_K_M\",\n    \"--port\", \"10002\",\n    \"--cpu_infer\", \"48\",\n    \"--optimize_config_path\", \"ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml\",\n    \"--max_new_tokens\", \"3000\",\n    \"--cache_lens\", \"6000\"\n]\n\nprint(\"Starting ktransformers server...\")\nprint(\" \".join(server_cmd))\nwith open(\"/tmp/server_log.txt\", \"w\") as f:\n    server_process = subprocess.Popen(server_cmd, stdout=f, stderr=f, text=True)\n\ntry:\n    wait_for_server(\"http://localhost:10002\", timeout=600)\n\n    eval_cmd = [\"python\", \"ktransformers/tests/humaneval/eval_api.py\"]\n    print(\"Running eval_api.py...\")\n    print(f\"Command: {' '.join(eval_cmd)}\")\n    \n    env = os.environ.copy()\n    env[\"PYTHONUNBUFFERED\"] = \"1\"\n    \n    eval_process = subprocess.Popen(\n        eval_cmd,\n        stdout=subprocess.PIPE,\n        stderr=subprocess.PIPE,\n        text=True,\n        bufsize=1,\n        env=env,\n        universal_newlines=True\n    )\n  
  \n    import threading\n    import queue\n    \n    def enqueue_output(out, queue):\n        for line in iter(out.readline, ''):\n            queue.put(line)\n        out.close()\n    \n    stdout_queue = queue.Queue()\n    stderr_queue = queue.Queue()\n    \n    stdout_thread = threading.Thread(target=enqueue_output, args=(eval_process.stdout, stdout_queue))\n    stderr_thread = threading.Thread(target=enqueue_output, args=(eval_process.stderr, stderr_queue))\n    \n    stdout_thread.daemon = True\n    stderr_thread.daemon = True\n    stdout_thread.start()\n    stderr_thread.start()\n    \n    while eval_process.poll() is None:\n        try:\n            line = stdout_queue.get_nowait()\n            print(line, end='', flush=True)\n        except queue.Empty:\n            pass\n            \n        try:\n            line = stderr_queue.get_nowait()\n            print(line, end='', file=sys.stderr, flush=True)\n        except queue.Empty:\n            pass\n        \n        time.sleep(1)\n\n    while not stdout_queue.empty():\n        print(stdout_queue.get(), end='', flush=True)\n    while not stderr_queue.empty():\n        print(stderr_queue.get(), end='', file=sys.stderr, flush=True)\n        \n    eval_process.wait()\n    print(f\"eval_api.py completed with exit code: {eval_process.returncode}\")\n\n    evaluate_cmd = [\n        \"evaluate_functional_correctness\",\n        \"ktransformers/tests/humaneval/results/api/eval_b.jsonl\"\n    ]\n    print(\"Running evaluate_functional_correctness...\")\n    print(f\"Command: {' '.join(evaluate_cmd)}\")\n    \n    evaluate_process = subprocess.Popen(\n        evaluate_cmd,\n        stdout=subprocess.PIPE,\n        stderr=subprocess.PIPE,\n        text=True,\n        bufsize=1,\n        universal_newlines=True\n    )\n    \n    for line in evaluate_process.stdout:\n        print(line, end='', flush=True)\n    for line in evaluate_process.stderr:\n        print(line, end='', file=sys.stderr, flush=True)\n        \n  
  evaluate_process.wait()\n    \n    print(f\"evaluate_functional_correctness completed with exit code: {evaluate_process.returncode}\")\n    if evaluate_process.returncode != 0:\n        print(f\"evaluate_functional_correctness exited with code {evaluate_process.returncode}\")\n        sys.exit(evaluate_process.returncode)\n\nfinally:\n    print(\"Stopping ktransformers server...\")\n    server_process.terminate()\n    try:\n        server_process.wait(timeout=30)\n    except subprocess.TimeoutExpired:\n        print(\"Server did not terminate gracefully, forcing...\")\n        server_process.kill()"
  },
  {
    "path": "archive/ktransformers/tests/test_client.py",
    "content": "import asyncio\nimport json\nimport sys\nimport aiohttp\nimport argparse\n\nprompt_list = [\n    'Please elaborate on modern world history.',\n    'Please introduce Harry Potter.',\n    'I want to learn Python. Please give me some advice.',\n    'Please tell me a joke '\n]\n\n\nasync def fetch_event_stream(session, payload, request_id, stream):\n    try:\n        headers = {\n            'accept': 'application/json',\n            'Content-Type': 'application/json'\n        }\n\n        async with session.post(SERVER_URL, json=payload, headers=headers, timeout=50000) as response:\n            print(f\"Request {request_id}: Connected, status {response.status}\")\n\n            if response.status != 200:\n                print(f\"Request {request_id}: Error, status {response.status}\")\n                return\n\n            output_text = \"\"\n\n            if stream:\n                async for line in response.content:\n                    try:\n                        decoded_line = line.decode(\"utf-8\").strip()\n                        if not decoded_line or not decoded_line.startswith(\"data: \"):\n                            continue\n\n                        decoded_line = decoded_line[6:].strip()\n                        if not decoded_line:\n                            continue\n\n                        response_data = json.loads(decoded_line)\n                        choices = response_data.get(\"choices\", [])\n                        if not choices:\n                            continue\n\n                        delta = choices[0].get(\"delta\", {})\n                        token = delta.get(\"content\", \"\")\n\n                        if token:\n                            output_text += token\n                            sys.stdout.write(token)\n                            sys.stdout.flush()\n\n                        finish_reason = choices[0].get(\"finish_reason\", None)\n                        if finish_reason:\n                   
         break\n\n                    except json.JSONDecodeError as e:\n                        print(f\"\\nRequest {request_id}: JSON Decode Error - {e}\")\n                    except IndexError:\n                        print(f\"\\nRequest {request_id}: List Index Error - choices is empty\")\n                    except Exception as e:\n                        print(f\"\\nRequest {request_id}: Error parsing stream - {e}\")\n            else:\n                # In non-stream mode, receive the complete JSON response at once\n                response_data = await response.json()\n                choices = response_data.get(\"choices\", [])\n                if choices:\n                    content = choices[0].get(\"message\", {}).get(\"content\", \"\")\n                    print(f\"Request {request_id} Output:\\n{content}\")\n                    output_text += content\n\n    except Exception as e:\n        print(f\"\\nRequest {request_id}: Exception - {e}\")\n\nasync def main(prompt_id, model, stream, max_tokens, temperature, top_p):\n    async with aiohttp.ClientSession() as session:\n        payload = {\n            \"messages\": [\n                {\"role\": \"system\", \"content\": \"\"},\n                {\"role\": \"user\", \"content\": prompt_list[prompt_id]}\n            ],\n            \"model\": model,\n            \"stream\": stream,\n            \"max_tokens\": max_tokens,\n            \"temperature\": temperature,\n            \"top_p\": top_p\n        }\n        tasks = [fetch_event_stream(session, payload, prompt_id, stream)]\n        await asyncio.gather(*tasks)\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Event Stream Request Tester\")\n    parser.add_argument(\"--question_id\", type=int, default=0)\n    parser.add_argument(\"--model\", type=str, default=\"DeepSeek-V3\")\n    # type=bool would treat any non-empty string (including \"False\") as True, so parse the text explicitly\n    parser.add_argument(\"--stream\", type=lambda s: s.lower() in (\"1\", \"true\", \"yes\"), default=True)\n    parser.add_argument(\"--max_tokens\", type=int, default=500)\n    parser.add_argument(\"--temperature\", 
type=float, default=0.8)\n    parser.add_argument(\"--top_p\", type=float, default=1)\n    parser.add_argument(\"--api_url\", type=str, default=\"http://localhost:10002/v1/chat/completions\", help=\"API URL\")\n\n    args = parser.parse_args()\n    SERVER_URL = args.api_url\n    asyncio.run(main(args.question_id, args.model, args.stream, args.max_tokens, args.temperature, args.top_p))\n"
  },
  {
    "path": "archive/ktransformers/tests/test_prefix.py",
    "content": "import asyncio\nimport json\nimport sys\nimport aiohttp\nimport random\nimport argparse\nimport yaml\nimport os\nimport time\nfrom time import sleep\n\ndecodesz = 128\n# Server URL (replace with your server URL)\ndecodesz_list = [128]\nprefill_speeds = []\ndecode_speeds = []\n\nasync def fetch_message_once(session, request_id, messages, max_tokens, model):\n    try:\n        payload = {\n            \"messages\": messages,\n            \"model\": model,\n            \"temperature\": 0.3,\n            \"top_p\": 1.0,\n            \"stream\": True,\n            \"return_speed\": True,\n            \"max_tokens\": max_tokens,\n        }\n\n        headers = {\n            'accept': 'application/json',\n            'Content-Type': 'application/json'\n        }\n\n        async with session.post(SERVER_URL, json=payload, headers=headers, timeout=500000) as response:\n            if response.status != 200:\n                print(f\"[Request {request_id}] Error: Status {response.status}\")\n                return None, None, None\n\n            buffer = \"\"\n            usage_info = None\n            answer = \"\"\n\n            async for line in response.content:\n                decoded_line = line.decode(\"utf-8\").strip()\n                if not decoded_line or not decoded_line.startswith(\"data: \"):\n                    continue\n\n                decoded_line = decoded_line[6:].strip()\n                if not decoded_line:\n                    continue\n\n                response_data = json.loads(decoded_line)\n\n                if \"usage\" in response_data:\n                    usage_info = response_data[\"usage\"]\n\n                choices = response_data.get(\"choices\", [])\n                if not choices:\n                    continue\n\n                delta = choices[0].get(\"delta\", {})\n                token = delta.get(\"content\", \"\")\n                if token:\n                    buffer += token\n                    answer += 
token\n\n                finish_reason = choices[0].get(\"finish_reason\", None)\n                if finish_reason:\n                    break\n\n            return answer.strip(), usage_info, buffer.strip()\n\n    except Exception as e:\n        print(f\"[Request {request_id}] Exception: {e}\")\n        return None, None, None\n\n\nasync def multi_turn_conversation(session, request_id, rounds, max_tokens, model):\n    prompt = [\"介绍一下秦始皇\", \"秦始皇的成就有哪些\", \"秦始皇的历史影响\", \"介绍一下秦始皇的陵墓\", \"秦始皇的统一措施\", \"秦始皇的政治制度\", \"秦始皇的文化政策\", \"秦始皇的军事行动\"]\n    \n    messages = [{\"role\": \"system\", \"content\": \"\"}]\n    global prefill_speeds, decode_speeds\n\n    for i in range(rounds):\n        user_msg = f\"这是第{i + 1}轮对话，请回答以下问题：{prompt[i % len(prompt)]}\"\n        messages.append({\"role\": \"user\", \"content\": user_msg})\n        print(f\"\\n[Request {request_id}] >> User: {user_msg}\")\n\n        answer, usage_info, _ = await fetch_message_once(session, request_id, messages, max_tokens, model)\n        if answer:\n            # Record the model's reply under the assistant role so the history alternates user/assistant\n            messages.append({\"role\": \"assistant\", \"content\": answer})\n            print(f\"[Request {request_id}] << Assistant: {answer}\")\n\n        if usage_info:\n            prefill_speed = usage_info[\"prompt_tokens\"] / usage_info[\"prefill_time\"]\n            decode_speed = usage_info[\"completion_tokens\"] / usage_info[\"decode_time\"]\n            prefill_speeds.append(prefill_speed)\n            decode_speeds.append(decode_speed)\n            print(f'[Request {request_id}] prefill speed: {prefill_speed}')\n            print(f'[Request {request_id}] decode speed: {decode_speed}')\n\n\nasync def main(concurrent_requests, rounds, max_tokens, model):\n    async with aiohttp.ClientSession() as session:\n        tasks = [multi_turn_conversation(session, i, rounds, max_tokens, model) for i in range(concurrent_requests)]\n        await asyncio.gather(*tasks)\n\n    if prefill_speeds:\n        import numpy as np\n        print(f\"\\n=== Summary ===\")\n 
       print(f\"Total concurrency: {concurrent_requests}\")\n        print(f\"Avg prefill speed: {np.mean(prefill_speeds)}\")\n        print(f\"Avg decode speed: {np.mean(decode_speeds)}\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Event Stream Request Tester\")\n    parser.add_argument(\"--concurrent\", type=int, default=1, help=\"Number of concurrent requests\")\n    parser.add_argument(\"--model\", type=str, default=\"DeepSeek-V3\", help=\"Model name\")\n    parser.add_argument(\"--prompt_lens\", type=int, default=1024, help=\"prefill prompt lens, 1024 or 2048\")\n    parser.add_argument(\"--api_url\", type=str, default=\"http://localhost:10002/v1/chat/completions\", help=\"API URL\")\n    parser.add_argument(\"--max_tokens\", type=int, default=50, help=\"max decode tokens\")\n    parser.add_argument(\"--rounds\", type=int, default=8, help=\"Number of multi-turn rounds (before final query)\")    \n    \n    args = parser.parse_args()\n    SERVER_URL = args.api_url\n    max_tokens = args.max_tokens\n    model = args.model\n\n    asyncio.run(main(args.concurrent, args.rounds, max_tokens, model))\n\n"
  },
  {
    "path": "archive/ktransformers/tests/test_pytorch_q8.py",
    "content": "import torch\n\n# Define a floating-point model containing a single linear layer\nclass LinearModel(torch.nn.Module):\n    def __init__(self, in_features, out_features):\n        super().__init__()\n        self.linear = torch.nn.Linear(in_features, out_features)\n    \n    def forward(self, x):\n        return self.linear(x)\n\n# Create the floating-point model instance\nin_features = 64\nout_features = 128\nmodel_fp32 = LinearModel(in_features, out_features)\n\n# Create the quantized model instance\nmodel_int8 = torch.ao.quantization.quantize_dynamic(\n    model_fp32,          # original floating-point model\n    {torch.nn.Linear},   # set of layer types to quantize\n    dtype=torch.qint8    # target quantized data type\n)\n\n# Test the model\nbatch_size = 32\ninput_fp32 = torch.randn(1, batch_size, in_features)  # generate random input data\noutput_int8 = model_int8(input_fp32)               # run the data through the quantized model\n\n# Print the shapes as a sanity check\nprint(f\"Input shape: {input_fp32.shape}\")\nprint(f\"Output shape: {output_int8.shape}\")\n\n# Compare the outputs of the original and quantized models\nwith torch.no_grad():\n    output_fp32 = model_fp32(input_fp32)\n    \nprint(f\"First few FP32 output values: {output_fp32[0, :5]}\")\nprint(f\"First few INT8 output values: {output_int8[0, :5]}\")\n\n# Compute the mean absolute error\nerror = torch.abs(output_fp32 - output_int8).mean().item()\nprint(f\"Mean absolute error: {error}\")\n\n# Print the layer types before and after quantization\nprint(f\"Layer type before quantization: {type(model_fp32.linear)}\")\nprint(f\"Layer type after quantization: {type(model_int8.linear)}\")"
  },
  {
    "path": "archive/ktransformers/tests/test_speed.py",
    "content": "import asyncio\nimport json\nimport sys\nimport aiohttp\nimport random\nimport argparse\nimport yaml\nimport os\nimport time\nfrom time import sleep\n\ndecodesz = 128\n# Server URL (replace with your server URL)\ndecodesz_list = [128]\nprefill_speeds = []\ndecode_speeds = []\nktansformer_prompt1024=\"\"\"Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. \nThey were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense.Mr. Dursley was the director of a firm called Grunnings, which made drills. \nHe was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. \nDursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. \nThe Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere.\nThe Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. \nThey didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. \nDursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. \nThe Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. \nThe Dursleys knew that the Potters had a small son, too, but they had never even seen him. \nThis boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. 
\nDursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. \nMr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair.None of them noticed a large, tawny owl flutter past the window.\nAt half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls.\n“Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive.\nIt was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. \nFor a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. \nThere was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. \nWhat could he have been thinking of? It must have been a trick of the light. \nMr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. \nIt was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. \nMr. Dursley gave himself a little shake and put the cat out of his mind. \nAs he drove toward town he thought of nothing except a large order of drills he was hoping to get that day.\nBut on the edge of town, drills were driven out of his mind by something else. \nAs he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. \nPeople in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! \nHe supposed this was some stupid new fashion. 
He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. \nThey were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! \nThe nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. \nThe traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills.\nMr. Dursley always sat with his back to the window in his office on the ninth floor.\"\"\"\nasync def fetch_event_stream(session, request_id, prompt, max_tokens, model):\n    try:\n        payload = {\n            \"messages\": [\n                {\"role\": \"system\", \"content\": \"\"},\n                {\"role\": \"user\", \"content\": prompt}\n            ],\n            \"model\": model,\n            \"temperature\": 0.3,\n            \"top_p\": 1.0,\n            \"stream\": True,\n            \"return_speed\": True,\n            \"max_tokens\": max_tokens,\n        }\n\n        headers = {\n            'accept': 'application/json',\n            'Content-Type': 'application/json'\n        }\n\n        async with session.post(SERVER_URL, json=payload, headers=headers, timeout=500000) as response:\n            if response.status != 200:\n                print(f\"[Request {request_id}] Error: Status {response.status}\")\n                return\n\n            buffer = \"\"  \n            total_tokens = 0\n            decode_start_time = None\n            decode_end_time = None\n            usage_info = None  \n\n            async for line in response.content:\n                try:\n                    decoded_line = line.decode(\"utf-8\").strip()\n                    if not decoded_line or not decoded_line.startswith(\"data: \"):\n                        continue\n\n  
                  decoded_line = decoded_line[6:].strip()\n                    if not decoded_line:\n                        continue\n\n                    response_data = json.loads(decoded_line)\n                    \n                    if \"usage\" in response_data:\n                        usage_info = response_data[\"usage\"]\n                    \n                    choices = response_data.get(\"choices\", [])\n                    if not choices:\n                        continue\n\n                    delta = choices[0].get(\"delta\", {})\n                    token = delta.get(\"content\", \"\")\n\n                    if token:\n                        if decode_start_time is None:\n                            decode_start_time = time.time()\n                        buffer += token\n                        total_tokens += 1\n                        decode_end_time = time.time()\n\n                        while \"\\n\" in buffer:\n                            line, buffer = buffer.split(\"\\n\", 1)\n                            print(f\"[Request {request_id}] {line}\")\n\n                    finish_reason = choices[0].get(\"finish_reason\", None)\n                    if finish_reason:\n                        break\n\n                except Exception as e:\n                    print(f\"[Request {request_id}] Stream Error: {e}\")\n\n            if buffer.strip():\n                print(f\"[Request {request_id}] {buffer.strip()}\")\n\n            if usage_info:\n                if \"prefill_time\" in usage_info:\n                    # print(f\"[Request {request_id}] Usage:\")\n                    # for key, value in usage_info.items():\n                    #     print(f\"  {key}: {value}\")\n                    prefill_speed = usage_info[\"prompt_tokens\"] / usage_info[\"prefill_time\"]\n                    decode_speed = usage_info[\"completion_tokens\"] / usage_info[\"decode_time\"]\n                    prefill_speeds.append(prefill_speed)\n                  
  decode_speeds.append(decode_speed)\n                    print(f'[Request {request_id}] prefill speed: {prefill_speed}')\n                    print(f'[Request {request_id}] decode speed: {decode_speed}')\n\n    except Exception as e:\n        print(f\"[Request {request_id}] Exception: {e}\")\n\nasync def main(concurrent_requests , prompt, max_tokens, model):\n    async with aiohttp.ClientSession() as session:\n        tasks = [fetch_event_stream(session, i , prompt, max_tokens, model) for i in range(concurrent_requests)]\n        await asyncio.gather(*tasks)\n    if len(prefill_speeds) != 0:\n        import numpy as np\n        print(f\"concurrency: {len(prefill_speeds)}\")\n        print(f\"total prefill speed: {np.sum(prefill_speeds)}\\n total decode speed: {np.sum(decode_speeds)}\")\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Event Stream Request Tester\")\n    parser.add_argument(\"--concurrent\", type=int, default=1, help=\"Number of concurrent requests\")\n    parser.add_argument(\"--model\", type=str, default=\"DeepSeek-V3\", help=\"Model name\")\n    # Restrict to the lengths handled below so that prompt is always defined\n    parser.add_argument(\"--prompt_lens\", type=int, default=1024, choices=[1024, 2048, 4096], help=\"prefill prompt lens: 1024, 2048, or 4096\")\n    parser.add_argument(\"--api_url\", type=str, default=\"http://localhost:10002/v1/chat/completions\", help=\"API URL\")\n    parser.add_argument(\"--max_tokens\", type=int, default=500, help=\"max decode tokens\")\n    \n    args = parser.parse_args()\n    SERVER_URL = args.api_url\n    max_tokens = args.max_tokens\n    model = args.model\n    if args.prompt_lens == 1024:\n        prompt = ktansformer_prompt1024\n    elif args.prompt_lens == 2048:\n        prompt = ktansformer_prompt1024 * 2\n    elif args.prompt_lens == 4096:\n        prompt = ktansformer_prompt1024 * 4\n\n\n    asyncio.run(main(args.concurrent, prompt, max_tokens, model))\n\n"
  },
  {
    "path": "archive/ktransformers/tests/triton_fp8gemm_test.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom typing import Optional\nimport pytest\nfrom typing import Tuple, Optional, Literal\nimport time\n# use dir path\nimport os\nimport sys\nsys.path.insert(0, \"/home/azure/ktransformers\")\nprint(sys.path)\nfrom ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant\nfrom safetensors import safe_open\n\nworld_size = 1\nrank = 0\nblock_size = 128\ngemm_impl: Literal[\"bf16\", \"fp8\"] = \"bf16\"\n# Assuming `fp8_gemm`, `act_quant`, `weight_dequant` and other relevant functions are already defined\n\ndef test_fp8_gemm_vs_torch_matmul():\n    # Test case 1: Create random matrices of size (M, K) and (K, N)\n    M, K, N = 64, 128, 256  # Matrix dimensions\n    x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')\n    weight = torch.randn(N, K, dtype=torch.bfloat16, device='cuda')\n\n    # Apply act_quant to both matrices\n    x_quantized, scale_x = act_quant(x, block_size)\n    weight_quantized, scale_w = act_quant(weight, block_size)\n    \n    # mk continous\n    x_quantized = x_quantized.contiguous()\n    weight_quantized = weight_quantized.contiguous()\n    scale_x = scale_x.contiguous()\n    scale_w = scale_w.contiguous()\n\n    # Perform fp8_gemm using the quantized tensors\n    result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight_quantized, scale_w)\n\n    # Perform torch.matmul using the original floating point tensors\n    result_torch_matmul = torch.matmul(x, weight.T)\n    print(f'result_torch_matmul: {result_torch_matmul.shape}')\n    print(f'result_fp8_gemm: {result_fp8_gemm.shape}')\n\n    print(f\"result_fp8_gemm:\\n {result_fp8_gemm}\")\n    print(f\"result_torch_matmul:\\n {result_torch_matmul}\")\n    \ndef test_fp8_gemm_vs_torch_matmul_load():\n    file_path = \"/mnt/data/model/DeepSeek-V3/model-00001-of-000163.safetensors\"\n    with safe_open(file_path, framework=\"pt\", device=0) as f:\n        weight = 
f.get_tensor(\"model.layers.0.mlp.down_proj.weight\")\n        scale = f.get_tensor(\"model.layers.0.mlp.down_proj.weight_scale_inv\")\n\n    # weight_dequant\n    weight_dequantized = weight_dequant(weight, scale)\n    print(f\"weight_dequantized: {weight_dequantized.shape}\")\n    N, K = weight_dequantized.shape\n    M = 64\n    x = torch.randn(2 ,M, K, dtype=torch.bfloat16, device='cuda')\n    x_quantized, scale_x = act_quant(x, block_size)\n    \n    # Test case 1: quantized x matmal with undequantized weight\n    result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)\n    print(f\"result_fp8_gemm:\\n {result_fp8_gemm}\")\n    print(f\"dtype {result_fp8_gemm.dtype}\")\n\n    # Perform torch.matmul using the original floating point tensors\n    result_torch_matmul = torch.matmul(x, weight_dequantized.to(torch.bfloat16).T)\n    print(f\"result_torch_matmul:\\n {result_torch_matmul}\")\n\ndef test_fp8_gemm_tplops():\n    file_path = \"/mnt/data/model/DeepSeek-V3/model-00001-of-000163.safetensors\"\n    with safe_open(file_path, framework=\"pt\", device=0) as f:\n        weight = f.get_tensor(\"model.layers.0.mlp.down_proj.weight\")\n        scale = f.get_tensor(\"model.layers.0.mlp.down_proj.weight_scale_inv\")\n\n    # weight_dequant\n    weight_dequantized = weight_dequant(weight, scale)\n    print(f\"weight_dequantized: {weight_dequantized.shape}\")\n    N, K = weight_dequantized.shape\n    M = 6400\n    x = torch.randn(2 ,M, K, dtype=torch.bfloat16, device='cuda')\n    # x_quantized, scale_x = act_quant(x, block_size)\n    \n    # Calculate time for 1000 fp8_gemm\n    i = 10\n    flops_per_gemm = 2 * M * N * K\n    total_flops = i * flops_per_gemm\n    \n    x_quantized, scale_x = act_quant(x, block_size)\n    result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)\n    x_quantized, scale_x = act_quant(x, block_size)\n    result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)\n\n    \n    t0 = time.time()\n    
torch.cuda.synchronize()\n    # Restart the timer only after pending GPU work has drained.\n    t0 = time.time()\n    # Use a throwaway loop variable so the iteration count `i` is not shadowed.\n    for _ in range(i):\n        x_quantized, scale_x = act_quant(x, block_size)\n        result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)\n    torch.cuda.synchronize()\n    t1 = time.time()\n    \n    total_time = t1 - t0\n    tflops = total_flops / total_time / 1e12\n    print(f\"total_time: {total_time}\")\n    print(f\"tflops: {tflops}\")\n    \n\n    \n    \nif __name__ == \"__main__\":\n    test_fp8_gemm_vs_torch_matmul()\n    test_fp8_gemm_vs_torch_matmul_load()\n    test_fp8_gemm_tplops()\n    "
  },
  {
    "path": "archive/ktransformers/util/ascend/ascend_utils.py",
    "content": "# coding=utf-8\n# Copyright (c) 2025. Huawei Technologies Co., Ltd. All rights reserved.\n# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nfrom datetime import timedelta\n\nimport torch\nimport torch_npu\nimport torch.distributed as dist\n\n_DATA_PARALLEL_SIZE = 0\n_TENSOR_PARALLEL_SIZE = 0\n_DATA_PARALLEL_GROUP = None\n_TENSOR_PARALLEL_RANKS = None\n_TENSOR_PARALLEL_GROUP = None\n_DATA_PARALLEL_GROUP_GLOO = None\n_DATA_PARALLEL_RANKS = None\n\n\ndef setup_model_parallel(distributed_timeout_minutes: int = 30, tp: int = 1):\n    global _DATA_PARALLEL_SIZE\n    global _DATA_PARALLEL_GROUP\n    global _DATA_PARALLEL_RANKS\n    global _TENSOR_PARALLEL_SIZE\n    global _TENSOR_PARALLEL_RANKS\n    global _TENSOR_PARALLEL_GROUP\n\n    # os.environ[\"MASTER_ADDR\"] = \"localhost\"\n    # os.environ[\"MASTER_PORT\"] = \"12345\"\n    local_rank = int(os.getenv(\"LOCAL_RANK\", '0'))\n    world_size = int(os.getenv(\"WORLD_SIZE\", '1'))\n    torch_npu.npu.set_device(local_rank)\n    tp_size = tp\n    dp_size = world_size // tp_size\n    _DATA_PARALLEL_SIZE = dp_size\n    _TENSOR_PARALLEL_SIZE = tp_size\n\n    torch.set_num_threads(8)\n    timeout = timedelta(minutes=distributed_timeout_minutes)\n    print(f\"start to init process group ------rank is {local_rank}, world_size is {world_size}\")\n    torch.distributed.init_process_group(\n        
backend='hccl',\n        world_size=world_size, rank=local_rank\n    )\n    print(f\"init process group success ------rank is {local_rank}, world_size is {world_size}\")\n\n    rank = torch.distributed.get_rank()\n    nccl_comm_cfgs = {}\n    # Each DP group is made up of ranks spaced tp_size apart\n    for dp_group_id in range(tp_size):\n        ranks = list(range(dp_group_id, world_size, tp_size))\n        dp_group = torch.distributed.new_group(\n            ranks, timeout=timeout, pg_options=get_nccl_options('dp', nccl_comm_cfgs)\n        )\n        if rank in ranks:\n            global _DATA_PARALLEL_GROUP\n            _DATA_PARALLEL_GROUP = dp_group\n            _DATA_PARALLEL_RANKS = ranks\n\n    # Each TP group is made up of tp_size consecutive ranks (there are dp_size such groups)\n    for tp_group_id in range(dp_size):\n        start_rank = tp_group_id * tp_size\n        end_rank = (tp_group_id + 1) * tp_size\n        ranks = list(range(start_rank, end_rank))\n        tp_group = torch.distributed.new_group(\n            ranks, timeout=timeout, pg_options=get_nccl_options('tp', nccl_comm_cfgs)\n        )\n        if rank in ranks:\n            global _TENSOR_PARALLEL_GROUP\n            _TENSOR_PARALLEL_GROUP = tp_group\n            _TENSOR_PARALLEL_RANKS = ranks\n    # seed must be the same in all processes\n    torch.manual_seed(1)\n    return local_rank, world_size\n\n\ndef get_tensor_parallel_size():\n    assert _TENSOR_PARALLEL_SIZE is not None, \"tensor parallel size is not set\"\n    return _TENSOR_PARALLEL_SIZE\n\n\ndef get_tensor_parallel_group():\n    assert _TENSOR_PARALLEL_GROUP is not None, \"tensor parallel group is not initialized\"\n    return _TENSOR_PARALLEL_GROUP\n\n\ndef get_tensor_parallel_rank():\n    assert _TENSOR_PARALLEL_RANKS is not None, \"tensor parallel rank is not initialized\"\n    return _TENSOR_PARALLEL_RANKS\n\n\ndef get_data_parallel_size():\n    assert _DATA_PARALLEL_SIZE is not None, \"data parallel size is not initialized\"\n    return _DATA_PARALLEL_SIZE\n\n\ndef get_data_parallel_gloo():\n    assert 
_DATA_PARALLEL_GROUP_GLOO is not None, \"data parallel gloo group is not initialized\"\n    return _DATA_PARALLEL_GROUP_GLOO\n\n\ndef get_data_parallel_group():\n    assert _DATA_PARALLEL_GROUP is not None, \"data parallel group is not initialized\"\n    return _DATA_PARALLEL_GROUP\n\n\ndef get_data_parallel_rank():\n    assert _DATA_PARALLEL_RANKS is not None, \"data parallel rank is not initialized\"\n    return _DATA_PARALLEL_RANKS\n\n\n\ndef get_nccl_options(pg_name, nccl_comm_cfgs):\n    if pg_name in nccl_comm_cfgs:\n        nccl_options = torch.distributed.ProcessGroupNCCL.Options()\n        nccl_options.config.cga_cluster_size = nccl_comm_cfgs[pg_name].get('cga_cluster_size', 4)\n        nccl_options.config.max_ctas = nccl_comm_cfgs[pg_name].get('max_ctas', 32)\n        nccl_options.config.min_ctas = nccl_comm_cfgs[pg_name].get('min_ctas', 1)\n        return nccl_options\n    else:\n        return None\n\n\ndef get_safetensors_cut_weight(name: str, weights: torch.Tensor):\n    translate_col_cut_tensors = [\"ffn_down\", \"attn_output\"]  # \"kv_b_proj\"\n    translate_row_cut_tensors = [\"ffn_gate\", \"ffn_up\", \"attn_q_b\"]\n    tp = get_tensor_parallel_size()\n    if tp == 1 or weights.shape == torch.Size([1]):\n        return weights\n    rank = torch.distributed.get_rank()\n    rank %= tp\n    assert 0 <= rank < tp and tp > 0, f\"unexpected {rank=}, {tp=}\"\n    if any(t in name for t in translate_col_cut_tensors):\n        if weights.dim() == 1:\n            return weights\n        dim = weights.shape[-1]\n        assert dim % tp == 0, f\"unexpected division {dim=}, {tp=}\"\n        chunk_size = dim // tp\n        output_weights = weights[:, rank * chunk_size:(rank + 1) * chunk_size]\n        # print(f\"col cut weights {name=} from {weights.shape=} to {output_weights.shape=}\")\n        return output_weights\n    elif any(t in name for t in translate_row_cut_tensors):\n        dim = weights.shape[0]\n        assert dim % tp == 0, f\"unexpected division 
{dim=}, {tp=}\"\n        chunk_size = dim // tp\n        output_weights = weights[rank * chunk_size: (rank + 1) * chunk_size:]\n        # print(f\"row cut weights {name=} from {weights.shape=} to {output_weights.shape=}\")\n        return output_weights\n    else:\n        return weights\n\n\ndef get_absort_weight(model, config):\n    if not dist.is_initialized():\n        return\n    local_rank = dist.get_rank()\n    tp = get_tensor_parallel_size()\n    local_rank %= tp\n    tp_heads = config.num_attention_heads // tp\n    for i in range(config.num_hidden_layers):\n        attn = model.model.layers[i].self_attn\n        if hasattr(attn, \"q_absorb\") and hasattr(attn, \"out_absorb\"):\n            continue\n        if not (hasattr(attn, \"kv_b_proj\")\n                and hasattr(attn, \"kv_lora_rank\")\n                and hasattr(attn, \"qk_nope_head_dim\")):\n            continue\n\n        kv_b_proj = attn.kv_b_proj.weight.view(config.num_attention_heads, -1, attn.kv_lora_rank)\n        q_absorb = kv_b_proj[:, :attn.qk_nope_head_dim, :].clone()\n        out_absorb = kv_b_proj[:, attn.qk_nope_head_dim:, :].clone()\n\n        q_absorb = q_absorb[local_rank * tp_heads: (local_rank + 1) * tp_heads, :, :].contiguous()\n        out_absorb = out_absorb[local_rank * tp_heads: (local_rank + 1) * tp_heads, :, :].contiguous()\n        out_absorb = out_absorb.transpose(1, 2).contiguous()\n\n        setattr(attn, \"q_absorb\", q_absorb)\n        setattr(attn, \"out_absorb\", out_absorb)\n\n        if hasattr(attn, \"orig_module\") and hasattr(attn.orig_module, \"kv_b_proj\"):\n            del attn.orig_module.kv_b_proj\n    dist.barrier(get_tensor_parallel_group())\n\n\ndef allredeuce_warpper(func):\n    def wrapper(*args, **kwargs):\n        orig_output = func(*args, **kwargs)\n        if isinstance(orig_output, tuple):\n            if get_tensor_parallel_size() > 1:\n                org_dtype = orig_output[0].dtype\n                if org_dtype == torch.bfloat16:\n       
             reduced = orig_output[0].to(dtype=torch.float16)\n                else:\n                    reduced = orig_output[0]\n                # Reduce the tensor that is kept; reducing a temporary copy would discard the sum.\n                dist.all_reduce(reduced, op=dist.ReduceOp.SUM, group=get_tensor_parallel_group())\n                bf_orig_output = reduced.to(dtype=org_dtype)\n            else:\n                bf_orig_output = orig_output[0]\n            return (bf_orig_output,) + orig_output[1:]\n        else:\n            if get_tensor_parallel_size() > 1:\n                org_dtype = orig_output.dtype\n                if org_dtype == torch.bfloat16:\n                    orig_output = orig_output.to(dtype=torch.float16)\n                dist.all_reduce(orig_output, op=dist.ReduceOp.SUM, group=get_tensor_parallel_group())\n                if org_dtype == torch.bfloat16:\n                    orig_output = orig_output.to(dtype=org_dtype)\n            return orig_output\n\n    return wrapper"
  },
  {
    "path": "archive/ktransformers/util/cuda_graph_runner.py",
    "content": "'''\nDescription  :  \nAuthor       : Boxin Zhang\nVersion      : 0.1.0\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport torch\nfrom typing import Dict\n\nclass CUDAGraphRunner:\n\n    def __init__(self):\n        self.graph = None\n        self.input_buffers: Dict[str, torch.Tensor] = {}\n        self.output_buffers: Dict[str, torch.Tensor] = {}\n\n    def capture(\n        self,\n        model,\n        cur_token,\n        position_ids,\n        cache_position,\n        past_key_values,\n        main_device,\n        **kwargs,\n    ) -> None:\n        assert self.graph is None\n        # Capture the graph.\n        torch.cuda.synchronize()\n        self.graph = torch.cuda.CUDAGraph()\n        #self.graph.enable_debug_mode()\n        self.model = model\n        inputs_embeds = model.model.embed_tokens(cur_token.to(\"cpu\")).to(main_device)\n        # torch.cuda.set_device can't set \"cuda\", must have a index\n        if main_device == \"cuda\":\n            main_device = \"cuda:0\"\n        torch.cuda.set_device(main_device)\n        self.main_device = main_device\n        capture_stream = torch.cuda.Stream()\n        with torch.cuda.graph(self.graph, stream = capture_stream):\n            logits=model(inputs_embeds=inputs_embeds, \n                         position_ids=position_ids,\n                         cache_position=cache_position,\n                         past_key_values=past_key_values,\n                         **kwargs)[0]\n            capture_stream.wait_stream(torch.cuda.current_stream())\n            torch.cuda.set_device(main_device)\n            torch.cuda.set_stream(capture_stream)\n        if past_key_values != None:    \n            past_key_values.change_seq_length(-1)\n        torch.cuda.synchronize(self.main_device)\n        #self.graph.debug_dump(\"cuda_graph_hooked.dot\")\n\n        # Save the input and output buffers.\n        self.input_buffers = {\n            \"inputs_embeds\": inputs_embeds,\n    
        \"position_ids\": position_ids,\n            \"cache_position\": cache_position,\n        }\n        self.output_buffers = {\"logits\": logits}\n        return\n\n    def forward(\n        self,\n        cur_token,\n        position_ids,\n        cache_position,\n    ) -> torch.Tensor:\n        # Copy the input tensors to the input buffers.\n        inputs_embeds = self.model.model.embed_tokens(cur_token.to(\"cpu\"))\n        self.input_buffers[\"inputs_embeds\"].copy_(inputs_embeds)\n        self.input_buffers[\"position_ids\"].copy_(position_ids)\n        self.input_buffers[\"cache_position\"].copy_(cache_position)\n\n        # Run the graph.\n        #print(\"begin replay\")\n        #time.sleep(1)\n        self.graph.replay()\n        torch.cuda.synchronize(self.main_device)\n        # Return the output tensor.\n        return self.output_buffers[\"logits\"]\n\n    def __call__(self, *args, **kwargs):\n        return self.forward(*args, **kwargs)\n"
  },
  {
    "path": "archive/ktransformers/util/custom_gguf.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : Azure-Tang, Boxin Zhang, chenht2022\nDate         : 2024-07-26 08:48:54\nVersion      : 1.0.0\nLastEditors  : kkk1nak0\nLastEditTime : 2024-08-14 08:20:45\nAdapted from https://github.com/99991/pygguf/blob/main/gguf.py\nCopyright (c) 2023-2024 The ggml authors\nCopyright (c) 2024 Thomas Germer\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\n# copied from llama.cpp/gguf-py/gguf/constants.py to satisfy dependence of gguf\n# GGUF specification\n# https://github.com/ggerganov/ggml/blob/master/docs/gguf.md\nimport struct\nimport warnings\nimport numpy as np\nimport re\nimport numpy.typing as npt\nfrom typing import Sequence\nimport os\nfrom enum import IntEnum\nimport torch\n\ntry:\n    import torch_npu\n    use_torch_npu = torch_npu.npu.is_available()\nexcept:\n    use_torch_npu = False\n\nif not torch.xpu.is_available() and not use_torch_npu:\n    import KTransformersOps\n\nimport ctypes\nimport math\n\nclass GGMLQuantizationType(IntEnum):\n    F32     = 0\n    F16     = 1\n    Q4_0    = 2\n    Q4_1    = 3\n    Q5_0    = 6\n    Q5_1    = 7\n    Q8_0    = 8\n    Q8_1    = 9\n    Q2_K    = 10\n    Q3_K    = 11\n    Q4_K    = 12\n    Q5_K    = 13\n    Q6_K    = 14\n    Q8_K    = 15\n    IQ2_XXS = 16\n    IQ2_XS  = 17\n    IQ3_XXS = 18\n    IQ1_S   = 19\n    IQ4_NL  = 20\n    IQ3_S   = 21\n    IQ2_S   = 22\n    IQ4_XS  = 23\n    I8      = 24\n    I16     = 25\n    I32     = 26\n    I64     = 27\n    F64     = 28\n    IQ1_M   = 29\n    BF16    = 30\n\nQK_K = 256\nGGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {\n    GGMLQuantizationType.F32:     (1, 4),\n    GGMLQuantizationType.F16:     (1, 2),\n    GGMLQuantizationType.Q4_0:    (32, 2 + 16),\n    GGMLQuantizationType.Q4_1:    (32, 2 + 2 + 16),\n    GGMLQuantizationType.Q5_0:    (32, 2 + 4 + 16),\n    GGMLQuantizationType.Q5_1:    (32, 2 + 2 + 4 + 16),\n    GGMLQuantizationType.Q8_0:    (32, 2 + 
32),\n    GGMLQuantizationType.Q8_1:    (32, 4 + 4 + 32),\n    GGMLQuantizationType.Q2_K:    (256, 2 + 2 + QK_K // 16 + QK_K // 4),\n    GGMLQuantizationType.Q3_K:    (256, 2 + QK_K // 4 + QK_K // 8 + 12),\n    GGMLQuantizationType.Q4_K:    (256, 2 + 2 + QK_K // 2 + 12),\n    GGMLQuantizationType.Q5_K:    (256, 2 + 2 + QK_K // 2 + QK_K // 8 + 12),\n    GGMLQuantizationType.Q6_K:    (256, 2 + QK_K // 2 + QK_K // 4 + QK_K // 16),\n    GGMLQuantizationType.Q8_K:    (256, 4 + QK_K + QK_K // 8),\n    GGMLQuantizationType.IQ2_XXS: (256, 2 + QK_K // 4),\n    GGMLQuantizationType.IQ2_XS:  (256, 2 + QK_K // 4 + QK_K // 32),\n    GGMLQuantizationType.IQ3_XXS: (256, 2 + QK_K // 4 + QK_K // 8),\n    GGMLQuantizationType.IQ1_S:   (256, 2 + QK_K // 8 + QK_K // 16),\n    GGMLQuantizationType.IQ4_NL:  (32, 2 + 16),\n    GGMLQuantizationType.IQ3_S:   (256, 2 + QK_K // 4 + QK_K // 8 + QK_K // 32 + 4),\n    GGMLQuantizationType.IQ2_S:   (256, 2 + QK_K // 4 + QK_K // 16),\n    GGMLQuantizationType.IQ4_XS:  (256, 2 + 2 + QK_K // 2 + QK_K // 64),\n    GGMLQuantizationType.I8:      (1, 1),\n    GGMLQuantizationType.I16:     (1, 2),\n    GGMLQuantizationType.I32:     (1, 4),\n    GGMLQuantizationType.I64:     (1, 8),\n    GGMLQuantizationType.F64:     (1, 8),\n    GGMLQuantizationType.IQ1_M:   (256, QK_K // 8 + QK_K // 16  + QK_K // 32),\n    GGMLQuantizationType.BF16:    (1, 2),\n}\n\n# copied from llama.cpp/gguf-py/gguf/quants.py to avoid dependence of gguf\ndef quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType):\n    block_size, type_size = GGML_QUANT_SIZES[quant_type]\n    if shape[-1] % block_size != 0:\n        raise ValueError(f\"Quantized tensor row size ({shape[-1]}) is not a multiple of {quant_type.name} block size ({block_size})\")\n    return (*shape[:-1], shape[-1] // block_size * type_size)\n\nGGML_TYPES = {\n    \"F32\": 0,\n    \"F16\": 1,\n    \"Q4_0\": 2,\n    \"Q5_0\": 6,\n    \"Q8_0\": 8,\n    \"Q2_K\": 10,\n    \"Q3_K\": 11,\n    
\"Q4_K\": 12,\n    \"Q5_K\": 13,\n    \"Q6_K\": 14,\n    \"IQ4_XS\": 23,\n    \"BF16\": 30,\n}\n\nGGML_NAMES = {ggml_type: name for name, ggml_type in GGML_TYPES.items()}\n\nGGML_BLOCK_SIZES = {\n    \"F32\": 4,\n    \"F16\": 2,\n    \"BF16\": 2,\n    \"Q4_0\": 2 + 16,\n    \"Q5_0\": 2 + 4 + 16,\n    \"Q8_0\": 2 + 32,\n    \"Q2_K\": 256 // 16 + 256 // 4 + 2 + 2,\n    \"Q3_K\": 256 // 8 + 256 // 4 + 12 + 2,\n    \"Q4_K\": 2 + 2 + 12 + 256 // 2,\n    \"Q5_K\": 2 + 2 + 12 + 256 // 8 + 256 // 2,\n    \"Q6_K\": 256 // 2 + 256 // 4 + 256 // 16 + 2,\n    \"IQ4_XS\": 2 + 2 + 256 // 2 + 256 // 64,\n    \"FP8\": 1,\n}\n\nGGML_ELEMENTS_PER_BLOCK = {\n    \"F32\": 1,\n    \"F16\": 1,\n    \"BF16\": 1,\n    \"Q4_0\": 32,\n    \"Q5_0\": 32,\n    \"Q8_0\": 32,\n    \"Q2_K\": 256,\n    \"Q3_K\": 256,\n    \"Q4_K\": 256,\n    \"Q5_K\": 256,\n    \"Q6_K\": 256,\n    \"IQ4_XS\": 256,\n    \"FP8\": 1,\n}\n\nDATA_TYPES = {\n    \"uint8\": 0,\n    \"int8\": 1,\n    \"uint16\": 2,\n    \"int16\": 3,\n    \"uint32\": 4,\n    \"int32\": 5,\n    \"float32\": 6,\n    \"bool\": 7,\n    \"string\": 8,\n    \"array\": 9,\n    \"uint64\": 10,\n    \"int64\": 11,\n    \"float64\": 12,\n    \"FP8\": 13,\n}\n\ndef read_value(f, data_type):\n    if data_type == DATA_TYPES[\"string\"]:\n        length = struct.unpack(\"<Q\", f.read(8))[0]\n        return f.read(length).decode(\"utf-8\")\n\n    elif data_type == DATA_TYPES[\"bool\"]:\n        return bool(struct.unpack(\"<?\", f.read(1))[0])\n\n    elif data_type == DATA_TYPES[\"uint8\"]:\n        return struct.unpack(\"<B\", f.read(1))[0]\n\n    elif data_type == DATA_TYPES[\"int8\"]:\n        return struct.unpack(\"<b\", f.read(1))[0]\n\n    elif data_type == DATA_TYPES[\"uint16\"]:\n        return struct.unpack(\"<H\", f.read(2))[0]\n\n    elif data_type == DATA_TYPES[\"int16\"]:\n        return struct.unpack(\"<h\", f.read(2))[0]\n\n    elif data_type == DATA_TYPES[\"uint32\"]:\n        return struct.unpack(\"<I\", f.read(4))[0]\n\n    elif 
data_type == DATA_TYPES[\"int32\"]:\n        return struct.unpack(\"<i\", f.read(4))[0]\n\n    elif data_type == DATA_TYPES[\"float32\"]:\n        return struct.unpack(\"<f\", f.read(4))[0]\n\n    elif data_type == DATA_TYPES[\"uint64\"]:\n        return struct.unpack(\"<Q\", f.read(8))[0]\n\n    elif data_type == DATA_TYPES[\"int64\"]:\n        return struct.unpack(\"<q\", f.read(8))[0]\n\n    elif data_type == DATA_TYPES[\"float64\"]:\n        return struct.unpack(\"<d\", f.read(8))[0]\n\n    elif data_type == DATA_TYPES[\"array\"]:\n        elem_type, count = struct.unpack(\"<IQ\", f.read(4 + 8))\n        return [read_value(f, elem_type) for _ in range(count)]\n\n    elif data_type == DATA_TYPES[\"FP8\"]:\n        return struct.unpack(\"<B\", f.read(1))[0]\n\n    else:\n        raise NotImplementedError(f\"Data type {data_type} not implemented\")\n\ndef dequantize_q2_k(data):\n    # C implementation\n    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1547\n    # C struct definition\n    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L74\n    block_size = GGML_BLOCK_SIZES[\"Q2_K\"]\n    num_blocks = len(data) // block_size\n\n    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)\n    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)\n\n    dmin = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32)\n    d = data_f16[:, -2].reshape(num_blocks, 1, 1).astype(np.float32)\n    scales = data_u8[:, :16].reshape(num_blocks, 16, 1)\n    qs = data_u8[:, 16:80].reshape(num_blocks, 64)\n\n    tmp = np.stack([\n        qs[:, 00:16] >> 0,\n        qs[:, 16:32] >> 0,\n        qs[:, 00:16] >> 2,\n        qs[:, 16:32] >> 2,\n        qs[:, 00:16] >> 4,\n        qs[:, 16:32] >> 4,\n        qs[:, 00:16] >> 6,\n        qs[:, 16:32] >> 6,\n        qs[:, 32:48] >> 0,\n        qs[:, 48:64] >> 0,\n        qs[:, 
32:48] >> 2,\n        qs[:, 48:64] >> 2,\n        qs[:, 32:48] >> 4,\n        qs[:, 48:64] >> 4,\n        qs[:, 32:48] >> 6,\n        qs[:, 48:64] >> 6,\n    ], axis=1)\n\n    return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4)\n\ndef dequantize_q2_k_gpu(data, device:str =\"cuda\", target_dtype = torch.get_default_dtype()):\n    block_size = GGML_BLOCK_SIZES[\"Q2_K\"]\n    ele_per_blk = GGML_ELEMENTS_PER_BLOCK[\"Q2_K\"]\n    data = np.frombuffer(data, dtype=data.dtype)\n    device = torch.device(device)\n    # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, \n    # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.\n    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)\n    return KTransformersOps.dequantize_q2_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)\n\ndef dequantize_q3_k(data):\n    # C implementation\n    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1723C32-L1723C42\n    # C struct definition\n    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L95\n    block_size = GGML_BLOCK_SIZES[\"Q3_K\"]\n    num_blocks = len(data) // block_size\n\n    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)\n    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)\n\n    d = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32)\n    bits = np.unpackbits(data_u8[:, :32].reshape(num_blocks, 32, 1), axis=-1, bitorder=\"little\")\n    bits = 4 ^ (bits << 2)\n    qs = data_u8[:, 32:32 + 64].astype(np.int16)\n    a, b, c = data_u8[:, 96: 96 + 12].reshape(num_blocks, 3, 4).transpose(1, 0, 2)\n    scales = np.zeros((num_blocks, 4, 4), dtype=np.uint8)\n    scales[:, 0] = (a & 15) | ((c & 3) << 4)\n    scales[:, 1] = (b & 15) | (((c >> 2) & 3) << 
4)\n    scales[:, 2] = (a >> 4) | (((c >> 4) & 3) << 4)\n    scales[:, 3] = (b >> 4) | ((c >> 6) << 4)\n    scales = scales.reshape(num_blocks, 16, 1).astype(np.int16)\n\n    return d * (scales - 32) * np.stack([\n        (((qs[:, 00:16] >> 0) & 3) - bits[:, :16, 0]),\n        (((qs[:, 16:32] >> 0) & 3) - bits[:, 16:, 0]),\n        (((qs[:, 00:16] >> 2) & 3) - bits[:, :16, 1]),\n        (((qs[:, 16:32] >> 2) & 3) - bits[:, 16:, 1]),\n        (((qs[:, 00:16] >> 4) & 3) - bits[:, :16, 2]),\n        (((qs[:, 16:32] >> 4) & 3) - bits[:, 16:, 2]),\n        (((qs[:, 00:16] >> 6) & 3) - bits[:, :16, 3]),\n        (((qs[:, 16:32] >> 6) & 3) - bits[:, 16:, 3]),\n        (((qs[:, 32:48] >> 0) & 3) - bits[:, :16, 4]),\n        (((qs[:, 48:64] >> 0) & 3) - bits[:, 16:, 4]),\n        (((qs[:, 32:48] >> 2) & 3) - bits[:, :16, 5]),\n        (((qs[:, 48:64] >> 2) & 3) - bits[:, 16:, 5]),\n        (((qs[:, 32:48] >> 4) & 3) - bits[:, :16, 6]),\n        (((qs[:, 48:64] >> 4) & 3) - bits[:, 16:, 6]),\n        (((qs[:, 32:48] >> 6) & 3) - bits[:, :16, 7]),\n        (((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7])\n    ], axis=1)\n\ndef dequantize_q3_k_gpu(data, device:str =\"cuda\", target_dtype = torch.get_default_dtype()):\n    block_size = GGML_BLOCK_SIZES[\"Q3_K\"]\n    ele_per_blk = GGML_ELEMENTS_PER_BLOCK[\"Q3_K\"]\n    data = np.frombuffer(data, dtype=data.dtype)\n    device = torch.device(device)\n    # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, \n    # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.\n    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)\n    return KTransformersOps.dequantize_q3_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)\n\ndef dequantize_q4_k(data):\n    # C implementation\n    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1929\n    # C 
struct definition\n    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L116\n    block_size = GGML_BLOCK_SIZES[\"Q4_K\"]\n    num_blocks = len(data) // block_size\n    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)\n    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)\n    # Casting to float32 because float16 is very slow on CPU\n    scale_factors = data_f16[:, 0].reshape(num_blocks, 1, 1).astype(np.float32)\n    scale_offsets = data_f16[:, 1].reshape(num_blocks, 1, 1).astype(np.float32)\n    qs1 = data_u8[:, 4:16].reshape(num_blocks, 12, 1)\n    qs2 = data_u8[:, 16:].reshape(num_blocks, 4, 32)\n    # Dequantize scales and offsets (6 bits and 4 + 2 bits)\n    factors = scale_factors * np.concatenate([qs1[:, 0:4] & 0b111111, (qs1[:, 8:] & 15) | ((qs1[:, 0:4] >> 6) << 4)], axis=1)\n    offsets = scale_offsets * np.concatenate([qs1[:, 4:8] & 0b111111, (qs1[:, 8:] >> 4) | ((qs1[:, 4:8] >> 6) << 4)], axis=1)\n    # Interleave low and high quantized bits\n    qs2 = np.stack([qs2 & 0xf, qs2 >> 4], axis=2).reshape(num_blocks, 8, 32)\n    # Dequantize final weights using scales and offsets\n    return factors * qs2 - offsets\n\ndef dequantize_q4_k_gpu(data, device:str =\"cuda\", target_dtype = torch.get_default_dtype()):\n    block_size = GGML_BLOCK_SIZES[\"Q4_K\"]\n    ele_per_blk = GGML_ELEMENTS_PER_BLOCK[\"Q4_K\"]\n    data = np.frombuffer(data, dtype=data.dtype)\n    device = torch.device(device)\n    # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, \n    # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.\n    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)\n    return KTransformersOps.dequantize_q4_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)\n\ndef dequantize_q5_k(data):\n    # C 
implementation\n    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L2129\n    # C struct definition\n    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L138\n    block_size = GGML_BLOCK_SIZES[\"Q5_K\"]\n    num_blocks = len(data) // block_size\n\n    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)\n    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)\n\n    d = data_f16[:, 0].reshape(num_blocks, 1).astype(np.float32)\n    dmin = data_f16[:, 1].reshape(num_blocks, 1).astype(np.float32)\n    scales = data_u8[:, 4:16].reshape(num_blocks, 12, 1)\n    qh = data_u8[:, 16: 16 + 32].reshape(num_blocks, 32, 1)\n    qs = data_u8[:, 48: 48 + 128].reshape(num_blocks, 4, 32)\n\n    bits = np.unpackbits(qh, axis=-1, bitorder=\"little\")\n\n    qs_hi_4 = qs >> 4\n    qs_lo_4 = qs & 15\n\n    scales_lo_6 = scales[:, :8] & 63\n    scales_hi_6 = scales[:, :8] >> 6\n    scales_lo_4 = scales[:, 8:] & 15\n    scales_hi_4 = scales[:, 8:] >> 4\n\n    m1 = dmin * scales_lo_6[:, 4]\n    m2 = dmin * scales_lo_6[:, 5]\n    m3 = dmin * scales_lo_6[:, 6]\n    m4 = dmin * scales_lo_6[:, 7]\n    m5 = dmin * (scales_hi_4[:, 0] | (scales_hi_6[:, 4] << 4))\n    m6 = dmin * (scales_hi_4[:, 1] | (scales_hi_6[:, 5] << 4))\n    m7 = dmin * (scales_hi_4[:, 2] | (scales_hi_6[:, 6] << 4))\n    m8 = dmin * (scales_hi_4[:, 3] | (scales_hi_6[:, 7] << 4))\n\n    d1 = d * scales_lo_6[:, 0]\n    d2 = d * scales_lo_6[:, 1]\n    d3 = d * scales_lo_6[:, 2]\n    d4 = d * scales_lo_6[:, 3]\n    d5 = d * (scales_lo_4[:, 0] | (scales_hi_6[:, 0] << 4))\n    d6 = d * (scales_lo_4[:, 1] | (scales_hi_6[:, 1] << 4))\n    d7 = d * (scales_lo_4[:, 2] | (scales_hi_6[:, 2] << 4))\n    d8 = d * (scales_lo_4[:, 3] | (scales_hi_6[:, 3] << 4))\n\n    return np.concatenate([\n        d1 * (qs_lo_4[:, 0] + (bits[:, :, 0] << 4)) - m1,\n        d2 * (qs_hi_4[:, 0] 
+ (bits[:, :, 1] << 4)) - m2,\n        d3 * (qs_lo_4[:, 1] + (bits[:, :, 2] << 4)) - m3,\n        d4 * (qs_hi_4[:, 1] + (bits[:, :, 3] << 4)) - m4,\n        d5 * (qs_lo_4[:, 2] + (bits[:, :, 4] << 4)) - m5,\n        d6 * (qs_hi_4[:, 2] + (bits[:, :, 5] << 4)) - m6,\n        d7 * (qs_lo_4[:, 3] + (bits[:, :, 6] << 4)) - m7,\n        d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8,\n    ], axis=1)\n\ndef dequantize_q5_k_gpu(data, device:str =\"cuda\", target_dtype = torch.get_default_dtype()):\n    block_size = GGML_BLOCK_SIZES[\"Q5_K\"]\n    ele_per_blk = GGML_ELEMENTS_PER_BLOCK[\"Q5_K\"]\n    data = np.frombuffer(data, dtype=data.dtype)\n    device = torch.device(device)\n    # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, \n    # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.\n    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)\n    return KTransformersOps.dequantize_q5_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)\n\ndef dequantize_q6_k(data):\n    # C implementation\n    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L2275\n    # C struct definition\n    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L152\n    block_size = GGML_BLOCK_SIZES[\"Q6_K\"]\n    num_blocks = len(data) // block_size\n\n    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)\n    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)\n    data_i8 = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, block_size)\n\n    scales = data_f16[:, -1].reshape(num_blocks, 1).astype(np.float32)\n    # TODO use uint8 and cast later?\n    ql = data_u8[:, :128].astype(np.int16)\n    qh = data_u8[:, 128:192].astype(np.int16)\n    sc = data_i8[:, 192:208, 
np.newaxis].astype(np.float32)\n\n    # Unpack bits, subtraction requires signed data type\n    q1 = (ql[:,   :32 ] & 0xF) | (((qh[:, :32] >> 0) & 3) << 4) - 32\n    q2 = (ql[:, 32:64 ] & 0xF) | (((qh[:, :32] >> 2) & 3) << 4) - 32\n    q3 = (ql[:,   :32 ] >>  4) | (((qh[:, :32] >> 4) & 3) << 4) - 32\n    q4 = (ql[:, 32:64 ] >>  4) | (((qh[:, :32] >> 6) & 3) << 4) - 32\n    q5 = (ql[:, 64:96 ] & 0xF) | (((qh[:, 32:] >> 0) & 3) << 4) - 32\n    q6 = (ql[:, 96:128] & 0xF) | (((qh[:, 32:] >> 2) & 3) << 4) - 32\n    q7 = (ql[:, 64:96 ] >>  4) | (((qh[:, 32:] >> 4) & 3) << 4) - 32\n    q8 = (ql[:, 96:128] >>  4) | (((qh[:, 32:] >> 6) & 3) << 4) - 32\n\n    # Dequantize\n    return scales * np.concatenate([\n        sc[:,  0] * q1[:, :16],\n        sc[:,  1] * q1[:, 16:],\n        sc[:,  2] * q2[:, :16],\n        sc[:,  3] * q2[:, 16:],\n        sc[:,  4] * q3[:, :16],\n        sc[:,  5] * q3[:, 16:],\n        sc[:,  6] * q4[:, :16],\n        sc[:,  7] * q4[:, 16:],\n        sc[:,  8] * q5[:, :16],\n        sc[:,  9] * q5[:, 16:],\n        sc[:, 10] * q6[:, :16],\n        sc[:, 11] * q6[:, 16:],\n        sc[:, 12] * q7[:, :16],\n        sc[:, 13] * q7[:, 16:],\n        sc[:, 14] * q8[:, :16],\n        sc[:, 15] * q8[:, 16:],\n    ], axis=1) \n\n# @torch.jit.script\ndef dequantize_q6_k_gpu(data: np.ndarray, device:str = \"cuda\", target_dtype = torch.get_default_dtype()):\n    block_size = GGML_BLOCK_SIZES[\"Q6_K\"]\n    ele_per_blk = GGML_ELEMENTS_PER_BLOCK[\"Q6_K\"]\n    device = torch.device(device)\n    num_blocks = len(data) // block_size\n    data = np.frombuffer(data, dtype=data.dtype)\n    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)\n    return KTransformersOps.dequantize_q6_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)\n\nkvalues_iq4nl = np.array([-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], dtype=np.int8)\n\ndef dequantize_iq4_xs(data):\n    # C 
implementation\n    # https://github.com/ggerganov/ggml/blob/21d3a308fcb7f31cb9beceaeebad4fb622f3c337/src/ggml-quants.c#L3568\n    # C struct definition\n    # https://github.com/ggerganov/ggml/blob/21d3a308fcb7f31cb9beceaeebad4fb622f3c337/src/ggml-common.h#L393\n    block_size = GGML_BLOCK_SIZES[\"IQ4_XS\"]\n    num_blocks = len(data) // block_size\n\n    d = np.frombuffer(data, dtype=np.float16)[0::block_size//2].astype(np.float32).reshape(num_blocks, 1)\n    scales_h = np.frombuffer(data, dtype=np.uint16)[1::block_size//2].reshape(num_blocks, 1)\n    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)[:, 4:]\n    scales_l = data_u8[:, :4].reshape(num_blocks, 4)\n    qs = data_u8[:, 4:].reshape(num_blocks, block_size - 8)\n\n    ls = np.zeros((num_blocks, QK_K // 32), dtype=np.int8)\n    for ib in range(QK_K // 32):\n        ls[:, ib] = ((scales_l[:, ib // 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h[:, 0] >> 2 * ib) & 3) << 4)\n\n    dl = (d * (ls - 32)).reshape(num_blocks, -1, 1)\n\n    qs_lo_4 = qs[:, :QK_K // 2].reshape(num_blocks, -1, 16) & 0xf\n    qs_hi_4 = qs[:, :QK_K // 2].reshape(num_blocks, -1, 16) >> 4\n\n    y = np.zeros((num_blocks, QK_K), dtype=np.float32)\n    for ib in range(QK_K // 32):\n        y[:, ib*32:(ib*32)+16] = dl[:, ib] * kvalues_iq4nl[qs_lo_4[:, ib]]\n        y[:, (ib*32)+16:(ib*32)+32] = dl[:, ib] * kvalues_iq4nl[qs_hi_4[:, ib]]\n\n    return y.flatten()\n\ndef dequantize_iq4_xs_gpu(data: np.ndarray, device:str = \"cuda\", target_dtype = torch.get_default_dtype()):\n    block_size = GGML_BLOCK_SIZES[\"IQ4_XS\"]\n    ele_per_blk = GGML_ELEMENTS_PER_BLOCK[\"IQ4_XS\"]\n    device = torch.device(device)\n    num_blocks = len(data) // block_size\n    data = np.frombuffer(data, dtype=data.dtype)\n    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)\n    return KTransformersOps.dequantize_iq4_xs(c_pointer, data.size, block_size, ele_per_blk, device, 
target_dtype)\n\ndef dequantize_q4_0(data):\n    # C implementation\n    # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-quants.c#L1515\n    # C struct definition\n    # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-common.h#L141\n    num_blocks = len(data) // GGML_BLOCK_SIZES[\"Q4_0\"]\n\n    scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 8)[:, :1].astype(np.float32)\n    qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 16)[:, 2:]\n\n    return np.concatenate([\n        scales * ((qs & 0xf).astype(np.int8) - 8),\n        scales * ((qs >> 4).astype(np.int8) - 8),\n    ], axis=1)\n\ndef dequantize_q4_0_gpu(data, device:str = \"cuda\", target_dtype = torch.get_default_dtype()):\n    raise NotImplementedError()\n\ndef dequantize_q5_0(data):\n    # C implementation\n    # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-quants.c#L1556\n    # C struct definition\n    # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-common.h#L161\n    num_blocks = len(data) // GGML_BLOCK_SIZES[\"Q5_0\"]\n\n    scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 2 + 8)[:, :1].astype(np.float32)\n    qh = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2:2 + 4]\n    qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2 + 4:]\n\n    bits = np.unpackbits(qh, axis=-1, bitorder=\"little\")\n\n    x0 = ((qs & 0xf).astype(np.int8) | (bits[:, :16] << 4)) - 16\n    x1 = ((qs >> 4).astype(np.int8) | (bits[:, 16:] << 4)) - 16\n\n    return np.concatenate([\n        scales * x0,\n        scales * x1,\n    ], axis=1)\n\ndef dequantize_q5_0_gpu(data, device:str = \"cuda\", target_dtype = torch.get_default_dtype()):\n    raise NotImplementedError()\n\ndef dequantize_q8_0(data):\n    # C struct definition\n    # 
https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L43\n    num_blocks = len(data) // GGML_BLOCK_SIZES[\"Q8_0\"]\n\n    scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 16)[:, :1].astype(np.float32)\n    qs = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, 2 + 32)[:, 2:]\n    return scales * qs\n\ndef dequantize_q8_0_gpu(data, device:str = \"cuda\", target_dtype = torch.get_default_dtype()):\n    # C struct definition\n    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L43\n    \n    block_size = GGML_BLOCK_SIZES[\"Q8_0\"]\n    ele_per_blk = GGML_ELEMENTS_PER_BLOCK[\"Q8_0\"]\n    device = torch.device(device)\n    data = np.frombuffer(data, dtype=data.dtype)\n    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)\n    return KTransformersOps.dequantize_q8_0(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)\n\n\ndef dequantize_f32(data):\n    return np.frombuffer(data, dtype=np.float32)\n\ndef dequantize_f32_gpu(data, device, target_dtype = torch.get_default_dtype()):\n    data = np.frombuffer(data, dtype=np.float32)\n    res = torch.from_numpy(data.copy())\n    res_gpu = torch.empty_like(res, device=device, dtype=target_dtype)\n    res_gpu.copy_(res)\n    return res_gpu\n\ndef dequantize_f16(data):\n    return np.frombuffer(data, dtype=np.float16)\n\ndef dequantize_f16_gpu(data, device, target_dtype = torch.get_default_dtype()):\n    data = np.frombuffer(data, dtype=np.float16)\n    res = torch.from_numpy(data.copy())\n    res_gpu = torch.empty_like(res, device=device, dtype=target_dtype)\n    res_gpu.copy_(res)\n    return res_gpu\n\ndef dequantize_bf16_gpu(data, device, target_dtype = torch.get_default_dtype()):\n    data = np.frombuffer(data, dtype=np.float16)\n    res = torch.from_numpy(data.copy())\n    res_gpu = torch.empty_like(res, device=device)\n    
res_gpu.copy_(res)\n    return res_gpu\n\nGGML_DEQUANTIZE = {\n    \"F32\": dequantize_f32,\n    \"F16\": dequantize_f16,\n    \"BF16\": dequantize_f16,\n    \"Q4_0\": dequantize_q4_0,\n    \"Q5_0\": dequantize_q5_0,\n    \"Q8_0\": dequantize_q8_0,\n    \"Q2_K\": dequantize_q2_k,\n    \"Q3_K\": dequantize_q3_k,\n    \"Q4_K\": dequantize_q4_k,\n    \"Q5_K\": dequantize_q5_k,\n    \"Q6_K\": dequantize_q6_k,\n    \"IQ4_XS\": dequantize_iq4_xs,\n}\n\nGGML_DEQUANTIZE_GPU = {\n    \"F32\": dequantize_f32_gpu,\n    \"F16\": dequantize_f16_gpu,\n    \"BF16\": dequantize_bf16_gpu,\n    \"Q4_0\": dequantize_q4_0_gpu,\n    \"Q5_0\": dequantize_q5_0_gpu,\n    \"Q8_0\": dequantize_q8_0_gpu,\n    \"Q2_K\": dequantize_q2_k_gpu,\n    \"Q3_K\": dequantize_q3_k_gpu,\n    \"Q4_K\": dequantize_q4_k_gpu,\n    \"Q5_K\": dequantize_q5_k_gpu,\n    \"Q6_K\": dequantize_q6_k_gpu,\n    \"IQ4_XS\": dequantize_iq4_xs_gpu,\n}\n\n\ndef translate_name_to_gguf_mixtral(name):\n    \n    replacement_template = {\n        \"w1.weight\": \"ffn_gate\",\n        \"w2.weight\": \"ffn_down\",\n        \"w3.weight\": \"ffn_up\"\n    }  \n\n    pattern = re.compile(r\"model.layers\\.(\\d+)\\.block_sparse_moe\\.experts\\.(\\d+)\\.(w\\d\\.weight)\")\n\n    def replace_match(match):\n        blk_id = match.group(1)\n        expert_id = match.group(2)\n        weight_type = match.group(3)\n        if weight_type in replacement_template:\n            return f\"blk.{blk_id}.{replacement_template[weight_type]}.{expert_id}.weight\"\n        else:\n            return match.group(0)\n\n    new_name = re.sub(pattern, replace_match, name)\n    \n    return new_name\n\ndef translate_name_to_gguf(name):\n\n    name = translate_name_to_gguf_mixtral(name)\n\n    if \".ffn_gate_exp.\" in name:\n        name = name.replace(\".ffn_gate_exp.\", \".ffn_gate_exps.\")\n    if \".ffn_up_exp.\" in name:\n        name = name.replace(\".ffn_up_exp.\", \".ffn_up_exps.\")\n    if \".ffn_down_exp.\" in name:\n        name = 
name.replace(\".ffn_down_exp.\", \".ffn_down_exps.\")\n    \n    m = re.match(r\"model\\.layers\\.(\\d+)\\.mlp\\.experts\\.(\\d+)\\.(gate_proj|up_proj|down_proj)\", name)\n    if m:\n        layer, expert, proj = m.groups()\n        if proj == \"gate_proj\":\n            return f\"blk.{layer}.{expert}.ffn_gate_exps\"\n        elif proj == \"up_proj\":\n            return f\"blk.{layer}.{expert}.ffn_up_exps\"\n        else:\n            return f\"blk.{layer}.{expert}.ffn_down_exps\"\n\n    m = re.match(r\"blk\\.(\\d+)\\.mlp\\.experts\\.(\\d+)\\.(gate_proj|up_proj|down_proj)\", name)\n    if m:\n        layer, expert, proj = m.groups()\n        if proj == \"gate_proj\":\n            return f\"blk.{layer}.{expert}.ffn_gate_exps\"\n        elif proj == \"up_proj\":\n            return f\"blk.{layer}.{expert}.ffn_up_exps\"\n        else:\n            return f\"blk.{layer}.{expert}.ffn_down_exps\"\n\n    name = name.replace(\"lm_head.\", \"output.\")\n    name = name.replace(\"model.embed_tokens.\", \"token_embd.\")\n    name = name.replace(\"model.norm.\", \"output_norm.\")\n    \n    name = name.replace(\"model.layers.\", \"blk.\")\n    name = name.replace(\".input_layernorm\", \".attn_norm\")\n    name = name.replace(\".mlp.down_proj\", \".ffn_down\")\n    name = name.replace(\".mlp.gate_proj\", \".ffn_gate\")\n    name = name.replace(\".mlp.up_proj\", \".ffn_up\")\n    name = name.replace(\".post_attention_layernorm\", \".ffn_norm\")\n    name = name.replace(\".self_attn.q_proj\", \".attn_q\")\n    name = name.replace(\".self_attn.k_proj\", \".attn_k\")\n    name = name.replace(\".self_attn.v_proj\", \".attn_v\")\n    name = name.replace(\".self_attn.o_proj\", \".attn_output\")\n    name = name.replace(\".self_attn.qkv_proj\", \".attn_qkv\")\n    name = name.replace(\".self_attn.kv_a_proj_with_mqa\", \".attn_kv_a_mqa\")\n    name = name.replace(\".self_attn.kv_a_layernorm\", \".attn_kv_a_norm\")\n    name = name.replace(\".self_attn.kv_b_proj\", \".attn_kv_b\")\n    
name = name.replace(\".self_attn.q_a_proj\", \".attn_q_a\")\n    name = name.replace(\".self_attn.q_a_layernorm\", \".attn_q_a_norm\")\n    name = name.replace(\".self_attn.q_b_proj\", \".attn_q_b\")\n\n    name = name.replace(\".self_attn.q_norm\", \".attn_q_norm\")\n    name = name.replace(\".self_attn.k_norm\", \".attn_k_norm\")\n    \n    name = name.replace(\".shared_expert.\", \".shared_experts.\")\n    name = name.replace(\".shared_expert_\", \".shared_experts_\")\n    name = name.replace(\".gate_up_proj.\", \".up_proj\")\n    \n    name = name.replace(\".mlp.shared_experts.down_proj\", \".ffn_down_shexp\")\n    name = name.replace(\".mlp.gate.e_score_correction_bias\", \".exp_probs_b.bias\")\n    name = name.replace(\".mlp.gate\", \".ffn_gate_inp\")\n    name = name.replace(\".mlp.shared_experts.gate_proj\", \".ffn_gate_shexp\")\n    name = name.replace(\".mlp.shared_experts.up_proj\", \".ffn_up_shexp\")\n    name = name.replace(\".mlp.shared_experts_gate\", \".ffn_gate_inp_shexp\")\n    name = name.replace(\".mlp.experts\", \"\")\n\n    name = name.replace(\".mlp.experts.ffn_down_exps\", \".ffn_down_exps\")\n    name = name.replace(\".mlp.experts.ffn_gate_exps\", \".ffn_gate_exps\")\n    name = name.replace(\".mlp.experts.ffn_up_exps\", \".ffn_up_exps\")\n\n    \n    name = name.replace(\".block_sparse_moe.gate.\", \".ffn_gate_inp.\")\n    name = name.replace(\".block_sparse_moe.experts\", \"\")\n    \n    name = name.replace(\".feed_forward.experts\", \"\")\n    name = name.replace(\".feed_forward.router\", \".ffn_gate_inp\")\n    name = name.replace(\".feed_forward.shared_experts.down_proj\", \".ffn_down_shexp\")\n    name = name.replace(\".feed_forward.shared_experts.gate_proj\", \".ffn_gate_shexp\")\n    name = name.replace(\".feed_forward.shared_experts.up_proj\", \".ffn_up_shexp\")\n    return name\n\nif __name__ == '__main__':\n    gguf_path = '/mnt/data/model/DeepSeek-Coder-V2-GGUF-WJH'\n    loader = GGUFLoader(gguf_path)\n    
loader.load_gguf_tensor('token_embd.weight')\n\n"
  },
  {
    "path": "archive/ktransformers/util/custom_loader.py",
    "content": "import struct\nimport warnings\nimport numpy as np\nimport re\nimport numpy.typing as npt\nfrom typing import Sequence\nimport os\nfrom enum import IntEnum\nimport torch\n\ntry:\n    import torch_npu\n    use_torch_npu = torch_npu.npu.is_available()\nexcept:\n    use_torch_npu = False\n\nif not torch.xpu.is_available() and not use_torch_npu:\n    import KTransformersOps\nfrom safetensors import safe_open\n\nif not use_torch_npu:\n    from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant\nfrom ktransformers.util.custom_gguf import *\nfrom safetensors.torch import save_file\nfrom abc import ABC, abstractmethod\nfrom typing import Dict, Any, Optional, Union\n\nclass ModelLoader(ABC):\n    \"\"\"\n    Abstract base class for model loaders.\n    Defines the interface that all model loaders must implement.\n    \"\"\"\n    tensor_file_map = {}\n    @abstractmethod\n    def has_tensor(self, name: str):\n        \"\"\"\n        Check if the tensor exists in the loader.\n        \n        Args:\n            name: Name of the tensor to check\n            \n        Returns:\n            bool: True if the tensor exists, False otherwise\n        \"\"\"\n        pass\n\nclass SafeTensorLoader(ModelLoader):\n    tensor_file_map: dict\n    tensor_type_map: dict\n    file_handle_map: dict\n    tensor_device_map: dict\n    \n    def __init__(self, file_path: str):\n        self.__load_tensor_file_map(file_path)\n\n    def __load_tensor_file_map(self, file_path: str):\n        # Normalize the given path so that it refers to a directory\n        if not os.path.exists(file_path):\n            raise FileNotFoundError(f\"Path not found: {file_path}\")\n        if os.path.isfile(file_path):\n            folder_path = os.path.dirname(file_path)\n        else:\n            folder_path = file_path\n        self.file_handle_map = {}\n        self.tensor_file_map = {}\n        self.tensor_type_map = {}\n        self.tensor_device_map = {}\n\n        found_safetensor = False\n     
   for root, _, files in os.walk(folder_path):\n            files = sorted(files)\n            for file in files:\n                if file.endswith(\".safetensors\"):\n                    found_safetensor = True\n                    file_path = os.path.join(root, file)\n                    if file not in self.file_handle_map:\n                        try:\n                            handle = safe_open(file_path, framework=\"pt\")\n                            self.file_handle_map[file] = handle\n                        except Exception as e:\n                            print(f\"Error opening Safetensor file {file_path}: {e}\")\n                            continue\n\n                    f = self.file_handle_map.get(file)\n                    if f is None:\n                        continue\n                    try:\n                        for key in f.keys():\n                            self.tensor_file_map[key] = file\n                    except Exception as e:\n                        print(f\"Error reading Safetensor file {file_path}: {e}\")\n\n        # if not found_safetensor:\n        #     raise FileNotFoundError(f\"No Safetensor files found in {folder_path}\")\n\n    def load_tensor(self, key: str, device: str = \"cpu\"):\n        if translate_name_to_gguf(key) in self.tensor_file_map:\n            key = translate_name_to_gguf(key)\n        elif key in self.tensor_file_map:\n            pass\n        else:\n            raise KeyError(f\"Key {key} not found in Safetensor files\")\n        file = self.tensor_file_map[key]\n        f = self.file_handle_map.get(file)\n        if f is None:\n            raise FileNotFoundError(f\"File {file} not found in Safetensor files\")\n        if use_torch_npu:\n            tensor = f.get_tensor(key).to(torch.float16)\n        else:\n            tensor = f.get_tensor(key)\n\n        return tensor.to(device)\n\n    def load_experts(self, key: str, device: str=\"cpu\"):\n        '''\n        Load experts from safetensor\n  
      key: the name of the experts\n        device: the device to load the experts to\n        return: dict, \n        {up: tensor, down: tensor, gate: tensor, up_type: int, down_type: int, gate_type: int}\n        {xxx}_type: the type of the up tensor, corresponding to the ggml type\n        '''\n        if self.has_tensor(translate_name_to_gguf(key)+\".ffn_gate_exps.weight\"):\n            # legacy branch for loading hybrid model\n            base_key = translate_name_to_gguf(key)\n            # Load experts from safetensor\n            gate_key = f\"{base_key}.ffn_gate_exps.weight\"\n            gate_type_key = f\"{base_key}.ffn_gate_exps.ggml_type\"\n            up_key = f\"{base_key}.ffn_up_exps.weight\"\n            up_type_key = f\"{base_key}.ffn_up_exps.ggml_type\"\n            down_key = f\"{base_key}.ffn_down_exps.weight\"\n            down_type_key = f\"{base_key}.ffn_down_exps.ggml_type\"\n            gate_tensor = self.load_tensor(gate_key, device).numpy()\n            up_tensor = self.load_tensor(up_key, device).numpy()\n            down_tensor = self.load_tensor(down_key, device).numpy()\n            gate_type = self.load_tensor(gate_type_key, device).item()\n            up_type = self.load_tensor(up_type_key, device).item()\n            down_type = self.load_tensor(down_type_key, device).item()\n\n            return {\n                \"up\": up_tensor,\n                \"gate\": gate_tensor,\n                \"down\": down_tensor,\n                \"up_type\": up_type,\n                \"gate_type\": gate_type,\n                \"down_type\": down_type\n            }\n\n        else:\n            # Load experts from safetensor\n            base_key = key  # e.g. 
\"model.layers.3.mlp.experts\"\n            experts_count = 0\n            \n            key_no_proj = False\n            if self.has_tensor(f\"{base_key}.{experts_count}.up.weight\"):\n                key_no_proj = True\n\n            # First, count how many experts we have by checking for expert 0's up_proj\n            while self.has_tensor(f\"{base_key}.{experts_count}.up_proj.weight\") or self.has_tensor(f\"{base_key}.{experts_count}.up.weight\"):\n                experts_count += 1\n            \n            if experts_count == 0:\n                raise ValueError(f\"No experts found for key {base_key}\")\n            \n            # Initialize empty lists to store tensors for each projection type\n            up_projs = []\n            gate_projs = []\n            down_projs = []\n            \n            # Load all expert weights\n            for expert_id in range(experts_count):\n\n                if key_no_proj:\n                    up_key = f\"{base_key}.{expert_id}.up.weight\"\n                    gate_key = f\"{base_key}.{expert_id}.gate.weight\"\n                    down_key = f\"{base_key}.{expert_id}.down.weight\"\n                else:\n                    up_key = f\"{base_key}.{expert_id}.up_proj.weight\"\n                    gate_key = f\"{base_key}.{expert_id}.gate_proj.weight\"\n                    down_key = f\"{base_key}.{expert_id}.down_proj.weight\"\n                \n                up_tensor = self.load_tensor(up_key, device)\n                gate_tensor = self.load_tensor(gate_key, device)\n                down_tensor = self.load_tensor(down_key, device)\n                \n                up_projs.append(up_tensor)\n                gate_projs.append(gate_tensor)\n                down_projs.append(down_tensor)\n            \n            # Stack the tensors along a new dimension\n            up_tensor = torch.stack(up_projs, dim=0)\n            gate_tensor = torch.stack(gate_projs, dim=0)\n            down_tensor = 
torch.stack(down_projs, dim=0)\n            \n            # Get original dtype for GGML type determination\n            orig_up_dtype = up_tensor.dtype\n            orig_gate_dtype = gate_tensor.dtype\n            orig_down_dtype = down_tensor.dtype\n            \n            # Convert to numpy with proper bfloat16 support\n            up_numpy = up_tensor.view(torch.uint16).numpy()\n            gate_numpy = gate_tensor.view(torch.uint16).numpy()\n            down_numpy = down_tensor.view(torch.uint16).numpy()\n            \n            # Determine tensor data types for GGML conversion\n            def get_ggml_type(dtype):\n                if dtype == torch.float32:\n                    return GGMLQuantizationType.F32\n                elif dtype == torch.float16:\n                    return GGMLQuantizationType.F16\n                elif dtype == torch.bfloat16:\n                    return GGMLQuantizationType.BF16\n                else:\n                    raise ValueError(f\"Unsupported tensor dtype: {dtype}\")\n            \n            return {\n                \"up\": up_numpy,\n                \"gate\": gate_numpy,\n                \"down\": down_numpy,\n                \"up_type\": get_ggml_type(orig_up_dtype),\n                \"gate_type\": get_ggml_type(orig_gate_dtype),\n                \"down_type\": get_ggml_type(orig_down_dtype)\n            }\n                \n    def load_gate(self, key: str, device: str=\"cpu\"):\n        '''\n        Load gate from safetensor\n        key: the name of the gate\n        device: the device to load the gate to\n        return: dict, \n        {'weight': tensor, 'e_score_correction_bias': tensor}\n        '''\n        target = [\"weight\", \"e_score_correction_bias\"]\n        res = {'weight': None, 'e_score_correction_bias': None}\n        if self.has_tensor(translate_name_to_gguf(key)+\".ffn_gate_exps.weight\"):\n            # legacy branch for loading hybrid model\n            base_key = key\n            for k in 
target:\n                translated_key = translate_name_to_gguf(f\"{base_key}.{k}\")\n                if self.has_tensor(translated_key):\n                    tensor = self.load_tensor(translated_key, device)\n                    res[k] = tensor\n        else:\n            # Load gate from safetensor\n            base_key = key\n            for k in target:\n                if self.has_tensor(f\"{base_key}.{k}\"):\n                    tensor = self.load_tensor(f\"{base_key}.{k}\", device)\n                    res[k] = tensor\n        return res\n\n    def close_all_handles(self):\n        for handle in self.file_handle_map.values():\n            handle.close()\n        self.file_handle_map.clear()\n\n    def load_dequantized_tensor(self, key: str, device: str = \"cpu\"):\n        if key in self.tensor_file_map and translate_name_to_gguf(key):\n            pass\n        elif translate_name_to_gguf(key) in self.tensor_file_map:\n            key = translate_name_to_gguf(key)\n        else:\n            raise KeyError(f\"Key {key} not found in Safetensor files\")\n        file = self.tensor_file_map[key]\n        f = self.file_handle_map.get(file)\n        if f is None:\n            raise FileNotFoundError(f\"File {file} not found in Safetensor files\")\n        tensor = f.get_tensor(key).to(device)\n        if key.endswith(\".weight\"):\n            if key[:-7] + \".weight_scale_inv\" in self.tensor_file_map:\n                weight_scale_inv = f.get_tensor(key[:-7] + \".weight_scale_inv\").to(device)\n                tensor = weight_dequant(tensor, weight_scale_inv)\n        return tensor.to(device)\n\n    def has_tensor(self, name: str):\n        return name in self.tensor_file_map or translate_name_to_gguf(name) in self.tensor_file_map\n\nclass GGUFLoader(ModelLoader):\n    tensor_info: dict\n    gguf_path: str\n    tensor_file_map: dict # {tensor_name: tensor_file_path}\n    gguf_file_meta: dict\n    safetensor_loader: SafeTensorLoader\n    def __init__(self, 
gguf_path: str, quantize: str = None):\n        # Check dir exist\n        if not os.path.exists(gguf_path):\n            raise FileNotFoundError(f\"GGUF dir not found: {gguf_path}\")\n        if os.path.isfile(gguf_path):\n            gguf_path = os.path.dirname(gguf_path)\n\n        self.safetensor_loader = None\n        \n        self.tensor_info = {}\n        self.gguf_path = gguf_path\n        self.tensor_file_map = {}\n        self.file_data_map = {}\n        self.gguf_file_meta = {}\n        self.tensor_device_map = {}\n\n        if use_torch_npu:\n            if quantize == \"w8a8_dynamic\":\n                safetensor_loader = W8A8SafeTensorLoader(gguf_path)\n            else:\n                safetensor_loader = SafeTensorLoader(gguf_path)\n            if safetensor_loader.tensor_file_map:\n                self.safetensor_loader = safetensor_loader\n                return\n\n        # Walk through all the .gguf files in the directory\n        found_gguf = False\n        for root, dirs, files in os.walk(gguf_path):\n            for file in files:\n                if file.endswith(\".gguf\"):\n                    found_gguf = True\n                    file_name = os.path.join(root, file)\n                    with open(file_name, \"rb\") as f:\n                        self.load_gguf(f)\n                        if file_name not in self.file_data_map:\n                            self.file_data_map[file_name] = np.memmap(file_name, mode = 'r')\n        if not found_gguf:\n            raise FileNotFoundError(f\"Cannot find any .gguf files in: {gguf_path}\")\n                            \n    def load_gguf(self, f):\n        f.seek(0)\n        assert f.read(4) == b'GGUF'\n        values = struct.unpack(\"<IQQ\", f.read(4+8+8))\n        version, n_tensors, n_kv = values\n        if version != 3:\n            warnings.warn(f\"Version {version} has never been tested, might not work\")\n\n        info = {}\n        for _ in range(n_kv):\n            name = 
read_value(f, DATA_TYPES[\"string\"])\n\n            data_type = struct.unpack(\"<I\", f.read(4))[0]\n\n            info[name] = read_value(f, data_type)\n\n        tensor_info = {}\n        for _ in range(n_tensors):\n            name = read_value(f, DATA_TYPES[\"string\"])\n            shape_len = read_value(f, DATA_TYPES[\"uint32\"])\n            shape = [read_value(f, DATA_TYPES[\"uint64\"]) for _ in range(shape_len)]\n            ggml_type = read_value(f, DATA_TYPES[\"uint32\"])\n            bad_offset = read_value(f, DATA_TYPES[\"uint64\"])\n            n_elems = int(math.prod(shape))\n            block_size, type_size = GGML_QUANT_SIZES[ggml_type]\n            n_bytes = n_elems * type_size // block_size\n            np_dims = tuple(reversed(shape))\n        \n            item_type: npt.DTypeLike\n            if ggml_type == GGMLQuantizationType.F16:\n                item_count = n_elems\n                item_type = np.float16\n            elif ggml_type == GGMLQuantizationType.F32:\n                item_count = n_elems\n                item_type = np.float32\n            elif ggml_type == GGMLQuantizationType.F64:\n                item_count = n_elems\n                item_type = np.float64\n            elif ggml_type == GGMLQuantizationType.I8:\n                item_count = n_elems\n                item_type = np.int8\n            elif ggml_type == GGMLQuantizationType.I16:\n                item_count = n_elems\n                item_type = np.int16\n            elif ggml_type == GGMLQuantizationType.I32:\n                item_count = n_elems\n                item_type = np.int32\n            elif ggml_type == GGMLQuantizationType.I64:\n                item_count = n_elems\n                item_type = np.int64\n            else:\n                item_count = n_bytes\n                item_type = np.uint8\n                np_dims = quant_shape_to_byte_shape(np_dims, ggml_type)\n\n            tensor_info[name] = {\n                \"ggml_type\": ggml_type,\n    
            \"shape\": shape,\n                \"bad_offset\": bad_offset,\n                \"item_type\": item_type,\n                \"item_count\": item_count,\n                \"np_dims\": np_dims\n            }\n\n        start = f.tell()\n        # Alignment is 32 by default.\n        # https://github.com/ggerganov/ggml/blob/e1daebbf9d38d510ba456c4d50b4500a73ac2b14/docs/gguf.md?plain=1#L253\n        alignment = info.get(\"general.alignment\", 32)\n\n        # Inconveniently, the offset defined in gguf files is relative to the\n        # end of the header and is unaligned.\n        # We need to compute the absolute file offset ourselves instead.\n        for t in tensor_info.values():\n            offset = start + t[\"bad_offset\"]\n            offset += (alignment - offset % alignment) % alignment\n            t[\"offset\"] = offset\n            \n        for name in tensor_info:\n            self.tensor_file_map[name] = f.name\n        self.tensor_info.update(tensor_info)\n        self.gguf_file_meta.update(info)\n    \n    def get_mmap_tensor(self, name):\n        name = translate_name_to_gguf(name)\n        t = self.tensor_info[name]\n        mmap_data = self.file_data_map[ self.tensor_file_map[name] ]\n\n        offset = t[\"offset\"]\n        item_type = t[\"item_type\"]\n        item_count = t[\"item_count\"]\n        itemsize = int(np.empty([], dtype = item_type).itemsize)\n        return mmap_data[offset : offset + itemsize * item_count]\n\n    def get_undequanted_tensor_and_ggml_type(self, name):\n        name = translate_name_to_gguf(name)\n        t = self.tensor_info[name]\n        data = self.get_mmap_tensor(name)\n        ggml_type = t[\"ggml_type\"]\n        data = torch.from_numpy(data)\n        return data, ggml_type\n\n    def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = \"cuda\", target_dtype = torch.get_default_dtype())->torch.Tensor:\n        name = translate_name_to_gguf(name)\n        t = 
self.tensor_info[name]\n        shape = t[\"shape\"]\n        ggml_type = t[\"ggml_type\"]\n        if ggml_type not in GGML_NAMES:\n            raise NotImplementedError(f\"ggml_type {ggml_type} not implemented\")\n        ggml_name = GGML_NAMES[ggml_type]\n\n        # TODO: experts may be fused in quant block, split it\n        assert elements_per_expert % GGML_ELEMENTS_PER_BLOCK[ggml_name] == 0, \"experts may be fused in quant block, please use CPU dequant\"\n\n        blocks_per_experts = elements_per_expert // GGML_ELEMENTS_PER_BLOCK[ggml_name]\n        block_size = GGML_BLOCK_SIZES[ggml_name]\n        offset = expert_id * block_size * blocks_per_experts\n        data = data[offset: offset + block_size * blocks_per_experts]\n\n        if \"cuda\" in device.lower():\n            values = GGML_DEQUANTIZE_GPU[ggml_name](data, device, target_dtype)\n        else:\n            values = GGML_DEQUANTIZE[ggml_name](data)\n            values = torch.from_numpy(values.copy())\n\n        if ggml_name == \"BF16\":\n            values = values.view(torch.bfloat16)\n        values = values.view(shape[-2::-1])\n\n        return values\n\n    def load_gguf_tensor(self, name: str, device:str = \"cpu\", target_dtype = None)->torch.Tensor:\n        name = translate_name_to_gguf(name)\n        t = self.tensor_info[name]\n        if target_dtype is None:\n            target_dtype = torch.get_default_dtype()\n        \n        shape = t[\"shape\"]\n        ggml_type = t[\"ggml_type\"]\n\n        if ggml_type not in GGML_NAMES:\n            raise NotImplementedError(f\"ggml_type {ggml_type} not implemented\")\n\n        ggml_name = GGML_NAMES[ggml_type]\n\n        data = self.get_mmap_tensor(name)\n\n        block_size = GGML_BLOCK_SIZES[ggml_name]\n        elements_per_block = GGML_ELEMENTS_PER_BLOCK[ggml_name]\n        num_elements = int(np.prod(shape))\n        num_blocks = num_elements // elements_per_block\n        \n        blocks_per_iter = 16384\n        if num_blocks > 
blocks_per_iter: # dequantize large tensors in chunks\n            values = torch.empty((num_blocks, elements_per_block), dtype=target_dtype, device=device)\n            for i in range( (num_blocks + blocks_per_iter - 1) // blocks_per_iter):\n                blocks_begin = i * blocks_per_iter\n                blocks_end = min(blocks_begin + blocks_per_iter, num_blocks)\n                if \"cuda\" in device.lower():\n                    try:\n                        cur_values = GGML_DEQUANTIZE_GPU[ggml_name](data[blocks_begin*block_size : blocks_end*block_size], device, target_dtype)\n                    except Exception:\n                        # Fall back to CPU dequantization if the GPU kernel fails\n                        cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size])\n                        cur_values = torch.from_numpy(cur_values.copy()).to(device)\n                else:\n                    cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size])\n                    cur_values = torch.from_numpy(cur_values.copy())\n                \n                cur_values = cur_values.view(-1, elements_per_block)\n                if ggml_name == \"BF16\":\n                    cur_values = cur_values.view(torch.bfloat16)\n                values[blocks_begin : blocks_end] = cur_values\n        else:\n            if \"cuda\" in device.lower():\n                values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)\n            else:\n                np_values = np.copy(GGML_DEQUANTIZE[ggml_name](data))\n                values = torch.from_numpy(np_values).to(device)\n                del np_values\n\n        if ggml_name == \"BF16\":\n            values = values.view(torch.bfloat16)\n            \n\n        values = values.view(shape[::-1])\n        if \"attn_q\" in name and self.gguf_file_meta['general.architecture'] in [\"llama\"]:\n            n_head = self.gguf_file_meta['llama.attention.head_count']\n            values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, 
*values.shape[1:])\n            .swapaxes(1, 2)\n            .reshape(values.shape))\n        elif \"attn_k\" in name and self.gguf_file_meta['general.architecture'] in [\"llama\"]:\n            n_head = self.gguf_file_meta['llama.attention.head_count_kv'] \n            values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, *values.shape[1:])\n            .swapaxes(1, 2)\n            .reshape(values.shape))\n        return values\n    def has_tensor(self, name: str):\n        name = translate_name_to_gguf(name)\n        return name in self.tensor_info\n\n    def get_ggml_type(self, name: str):\n        name = translate_name_to_gguf(name)\n        if name not in self.tensor_info:\n            raise KeyError(f\"Key {name} not found in GGUF files\")\n        return self.tensor_info[name][\"ggml_type\"]\n    \nclass ModelLoaderFactory:\n    \"\"\"\n    Factory class for creating model loaders.\n    Automatically detects the model format based on file extensions in the directory.\n    \"\"\"\n    \n    @staticmethod\n    def create_loader(path: str):\n        \"\"\"\n        Create a model loader for the given path by detecting the model format.\n        The function checks for the presence of .safetensors or .gguf files\n        in the specified path and creates the appropriate loader.\n        \n        Args:\n            path: Path to the model directory or file\n            \n        Returns:\n            An appropriate ModelLoader instance (SafeTensorLoader or GGUFLoader)\n        \n        Raises:\n            FileNotFoundError: If no supported model files are found in the path\n        \"\"\"\n        if not os.path.exists(path):\n            raise FileNotFoundError(f\"Path not found: {path}\")\n            \n        # Normalize to directory path if a file was provided\n        if os.path.isfile(path):\n            if path.endswith(\".safetensors\"):\n                return SafeTensorLoader(path)\n            elif path.endswith(\".gguf\"):\n          
      return GGUFLoader(path)\n            else:\n                folder_path = os.path.dirname(path)\n        else:\n            folder_path = path\n            \n        # Check for safetensors files\n        has_safetensors = False\n        has_gguf = False\n        \n        for root, _, files in os.walk(folder_path):\n            for file in files:\n                if file.endswith(\".safetensors\"):\n                    has_safetensors = True\n                    break\n                elif file.endswith(\".gguf\"):\n                    has_gguf = True\n                    break\n            if has_safetensors or has_gguf:\n                break\n                \n        # Create the appropriate loader based on detected file types\n        # Prioritize SafeTensor over GGUF if both are present\n        if has_safetensors:\n            try:\n                return SafeTensorLoader(folder_path)\n            except Exception as e:\n                print(f\"Failed to create SafeTensorLoader: {e}\")\n                # Fall through to try GGUF if SafeTensor fails\n                if not has_gguf:\n                    raise\n        \n        if has_gguf:\n            try:\n                return GGUFLoader(folder_path)\n            except Exception as e:\n                print(f\"Failed to create GGUFLoader: {e}\")\n                raise\n        \n        # No supported model files found\n        raise FileNotFoundError(f\"No .safetensors or .gguf files found in: {folder_path}\")\n\nclass W8A8SafeTensorLoader(SafeTensorLoader):\n    def load_tensor(self, key: str, device: str = \"cpu\"):\n        if key not in self.tensor_file_map:\n            raise KeyError(f\"Key {key} not found in Safetensor files\")\n        file = self.tensor_file_map[key]\n        f = self.file_handle_map.get(file)\n        if f is None:\n            raise FileNotFoundError(f\"File {file} not found in Safetensor files\")\n        tensor = f.get_tensor(key)\n        if 'deq_scale' in key:\n  
          tensor = torch.from_numpy(\n                np.frombuffer(tensor.to(torch.float16).to(torch.float32).numpy().tobytes(), dtype=np.int32).astype(np.int64))\n        if 'input_scale' in key:\n            tensor = tensor.to(torch.float16)\n        if \"weight_scale\" in key or \"weight_offset\" in key:\n            if \"ffn\" in key:\n                tensor = tensor.to(torch.float32)\n            else:\n                tensor = tensor.to(torch.float16)\n        if 'input_offset' in key:\n            tensor = tensor.to(torch.int8)\n        if tensor.dtype == torch.bfloat16:\n            tensor = tensor.to(torch.float16)\n        return tensor.to(device)\n\n    def load_dequantized_tensor(self, key: str, device: str = \"cpu\"):\n        tensor = self.load_tensor(key, device)\n        return tensor\n"
  },
  {
    "path": "archive/ktransformers/util/modeling_rope_utils.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport math\nfrom typing import Optional, Tuple\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import is_torch_available, logging\n\n\nlogger = logging.get_logger(__name__)\n\n\nif is_torch_available():\n    import torch\n\n\ndef _compute_default_rope_parameters(\n    config: Optional[PretrainedConfig] = None,\n    device: Optional[\"torch.device\"] = None,\n    seq_len: Optional[int] = None,\n    **rope_kwargs,\n) -> Tuple[\"torch.Tensor\", float]:\n    \"\"\"\n    Computes the inverse frequencies according to the original RoPE implementation\n    Args:\n        config ([`~transformers.PretrainedConfig`]):\n            The model configuration.\n        device (`torch.device`):\n            The device to use for initialization of the inverse frequencies.\n        seq_len (`int`, *optional*):\n            The current sequence length. 
Unused for this type of RoPE.\n        rope_kwargs (`Dict`, *optional*):\n            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.\n    Returns:\n        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the\n        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).\n    \"\"\"\n    if config is not None and len(rope_kwargs) > 0:\n        raise ValueError(\n            \"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in \"\n            f\"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}\"\n        )\n    if len(rope_kwargs) > 0:\n        base = rope_kwargs[\"base\"]\n        dim = rope_kwargs[\"dim\"]\n    elif config is not None:\n        base = config.rope_theta\n        partial_rotary_factor = config.partial_rotary_factor if hasattr(config, \"partial_rotary_factor\") else 1.0\n        head_dim = getattr(config, \"head_dim\", config.hidden_size // config.num_attention_heads)\n        dim = int(head_dim * partial_rotary_factor)\n\n    attention_factor = 1.0  # Unused in this type of RoPE\n\n    # Compute the inverse frequencies\n    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))\n    return inv_freq, attention_factor\n\n\ndef _compute_linear_scaling_rope_parameters(\n    config: Optional[PretrainedConfig] = None,\n    device: Optional[\"torch.device\"] = None,\n    seq_len: Optional[int] = None,\n    **rope_kwargs,\n) -> Tuple[\"torch.Tensor\", float]:\n    \"\"\"\n    Computes the inverse frequencies with linear scaling. 
Credits to the Reddit user /u/kaiokendev\n    Args:\n        config ([`~transformers.PretrainedConfig`]):\n            The model configuration.\n        device (`torch.device`):\n            The device to use for initialization of the inverse frequencies.\n        seq_len (`int`, *optional*):\n            The current sequence length. Unused for this type of RoPE.\n        rope_kwargs (`Dict`, *optional*):\n            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.\n    Returns:\n        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the\n        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).\n    \"\"\"\n    if config is not None and len(rope_kwargs) > 0:\n        raise ValueError(\n            \"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in \"\n            f\"`_compute_linear_scaling_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}\"\n        )\n    if len(rope_kwargs) > 0:\n        factor = rope_kwargs[\"factor\"]\n    elif config is not None:\n        factor = config.rope_scaling[\"factor\"]\n\n    # Gets the default RoPE parameters\n    inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)\n\n    # Then applies linear scaling to the frequencies.\n    # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so\n    # applying scaling to the inverse frequencies is equivalent.\n    inv_freq /= factor\n    return inv_freq, attention_factor\n\n\ndef _compute_dynamic_ntk_parameters(\n    config: Optional[PretrainedConfig] = None,\n    device: Optional[\"torch.device\"] = None,\n    seq_len: Optional[int] = None,\n    **rope_kwargs,\n) -> Tuple[\"torch.Tensor\", float]:\n    \"\"\"\n    Computes the inverse frequencies with NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla\n    Args:\n        config ([`~transformers.PretrainedConfig`]):\n            The model configuration.\n        device (`torch.device`):\n            The device to use for initialization of the inverse frequencies.\n        seq_len (`int`, *optional*):\n            The current sequence length, used to update the dynamic RoPE at inference time.\n        rope_kwargs (`Dict`, *optional*):\n            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.\n    Returns:\n        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the\n        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).\n    \"\"\"\n    # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling\n    if config is not None and len(rope_kwargs) > 0:\n        raise ValueError(\n            \"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in \"\n            f\"`_compute_dynamic_ntk_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}\"\n        )\n    if len(rope_kwargs) > 0:\n        base = rope_kwargs[\"base\"]\n        dim = rope_kwargs[\"dim\"]\n        max_position_embeddings = rope_kwargs[\"max_position_embeddings\"]\n        factor = rope_kwargs[\"factor\"]\n    elif config is not None:\n        base = config.rope_theta\n        partial_rotary_factor = config.partial_rotary_factor if hasattr(config, \"partial_rotary_factor\") else 1.0\n        head_dim = getattr(config, \"head_dim\", config.hidden_size // config.num_attention_heads)\n        dim = int(head_dim * partial_rotary_factor)\n        max_position_embeddings = config.max_position_embeddings\n        factor = config.rope_scaling[\"factor\"]\n\n    attention_factor = 1.0  # Unused in this type of RoPE\n\n    # seq_len: default to max_position_embeddings, e.g. 
at init time\n    seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings\n\n    # Compute the inverse frequencies\n    base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))\n    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))\n    return inv_freq, attention_factor\n\n\ndef _compute_yarn_parameters(\n    config: PretrainedConfig, device: \"torch.device\", seq_len: Optional[int] = None, **rope_kwargs\n) -> Tuple[\"torch.Tensor\", float]:\n    \"\"\"\n    Computes the inverse frequencies with NTK scaling. Please refer to the\n    [original paper](https://arxiv.org/abs/2309.00071)\n    Args:\n        config ([`~transformers.PretrainedConfig`]):\n            The model configuration.\n        device (`torch.device`):\n            The device to use for initialization of the inverse frequencies.\n        seq_len (`int`, *optional*):\n            The current sequence length. 
Unused for this type of RoPE.\n        rope_kwargs (`Dict`, *optional*):\n            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.\n    Returns:\n        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the\n        post-processing scaling factor applied to the computed cos/sin.\n    \"\"\"\n    # No need to keep BC with yarn, unreleased when this new pattern was created.\n    if len(rope_kwargs) > 0:\n        raise ValueError(\n            f\"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}\"\n        )\n\n    base = config.rope_theta\n    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, \"partial_rotary_factor\") else 1.0\n    head_dim = getattr(config, \"qk_rope_head_dim\", config.hidden_size // config.num_attention_heads)\n    dim = int(head_dim * partial_rotary_factor)\n    factor = config.rope_scaling[\"factor\"]\n    attention_factor = config.rope_scaling.get(\"attention_factor\")\n    mscale = config.rope_scaling.get(\"mscale\")\n    mscale_all_dim = config.rope_scaling.get(\"mscale_all_dim\")\n\n    # NOTE: DeepSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a\n    # `original_max_position_embeddings` field containing the pretrained value. 
They use the ratio between these two\n    # values to compute the default attention scaling factor, instead of using `factor`.\n    if \"original_max_position_embeddings\" in config.rope_scaling:\n        original_max_position_embeddings = config.rope_scaling[\"original_max_position_embeddings\"]\n        factor = config.max_position_embeddings / original_max_position_embeddings\n    else:\n        original_max_position_embeddings = config.max_position_embeddings\n\n    def get_mscale(scale, mscale=1):\n        if scale <= 1:\n            return 1.0\n        return 0.1 * mscale * math.log(scale) + 1.0\n\n    # Sets the attention factor as suggested in the paper\n    if attention_factor is None:\n        if mscale and mscale_all_dim:\n            attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))\n        else:\n            attention_factor = get_mscale(factor)\n\n    # Optional config options\n    # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)\n    beta_fast = config.rope_scaling.get(\"beta_fast\") or 32\n    beta_slow = config.rope_scaling.get(\"beta_slow\") or 1\n\n    # Compute the inverse frequencies\n    def find_correction_dim(num_rotations, dim, base, max_position_embeddings):\n        \"\"\"Inverse dimension formula to find the dimension based on the number of rotations\"\"\"\n        return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))\n\n    def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):\n        \"\"\"Find dimension range bounds based on rotations\"\"\"\n        low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))\n        high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))\n        return max(low, 0), min(high, dim - 1)\n\n    def linear_ramp_factor(min, max, dim):\n        if min == max:\n            max += 0.001  # Prevent 
singularity\n\n        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)\n        ramp_func = torch.clamp(linear_func, 0, 1)\n        return ramp_func\n\n    # Note on variable naming: \"interpolation\" comes from the original technique, where we interpolate the position IDs\n    # to expand the possible context length. In other words, interpolation = apply scaling factor.\n    pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)\n    inv_freq_extrapolation = 1.0 / pos_freqs\n    inv_freq_interpolation = 1.0 / (factor * pos_freqs)\n\n    low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings)\n\n    # Get n-dimensional rotational scaling corrected for extrapolation\n    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)\n    inv_freq = (\n        inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)\n        + inv_freq_extrapolation * inv_freq_extrapolation_factor\n    )\n    return inv_freq, attention_factor\n\n\ndef _compute_longrope_parameters(\n    config: PretrainedConfig, device: \"torch.device\", seq_len: Optional[int] = None, **rope_kwargs\n) -> Tuple[\"torch.Tensor\", float]:\n    \"\"\"\n    Computes the inverse frequencies with LongRoPE scaling. 
Please refer to the\n    [original implementation](https://github.com/microsoft/LongRoPE)\n    Args:\n        config ([`~transformers.PretrainedConfig`]):\n            The model configuration.\n        device (`torch.device`):\n            The device to use for initialization of the inverse frequencies.\n        seq_len (`int`, *optional*):\n            The current sequence length.\n        rope_kwargs (`Dict`, *optional*):\n            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.\n    Returns:\n        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the\n        post-processing scaling factor applied to the computed cos/sin.\n    \"\"\"\n    # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling\n    # No need to keep BC with longrope, unreleased when this new pattern was created.\n    if len(rope_kwargs) > 0:\n        raise ValueError(\n            \"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_longrope_parameters`, got \"\n            f\"{rope_kwargs}\"\n        )\n\n    base = config.rope_theta\n    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, \"partial_rotary_factor\") else 1.0\n    head_dim = getattr(config, \"head_dim\", config.hidden_size // config.num_attention_heads)\n    dim = int(head_dim * partial_rotary_factor)\n    long_factor = config.rope_scaling[\"long_factor\"]\n    short_factor = config.rope_scaling[\"short_factor\"]\n    factor = config.rope_scaling.get(\"factor\")\n    attention_factor = config.rope_scaling.get(\"attention_factor\")\n\n    # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a\n    # `original_max_position_embeddings` field containing the pretrained value. 
They use the ratio between these two\n    # values to compute the default attention scaling factor, instead of using `factor`.\n    if hasattr(config, \"original_max_position_embeddings\"):\n        original_max_position_embeddings = config.original_max_position_embeddings\n        factor = config.max_position_embeddings / config.original_max_position_embeddings\n    else:\n        original_max_position_embeddings = config.max_position_embeddings\n\n    # Sets the attention factor as suggested in the paper\n    if attention_factor is None:\n        if factor <= 1.0:\n            attention_factor = 1.0\n        else:\n            attention_factor = math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings))\n\n    # Compute the inverse frequencies -- scaled based on the target sequence length\n    if seq_len and seq_len > original_max_position_embeddings:\n        ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)\n    else:\n        ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)\n    inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim\n    inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)\n\n    return inv_freq, attention_factor\n\n\ndef _compute_llama3_parameters(\n    config: PretrainedConfig, device: \"torch.device\", seq_len: Optional[int] = None, **rope_kwargs\n) -> Tuple[\"torch.Tensor\", float]:\n    \"\"\"\n    Computes the inverse frequencies for llama 3.1.\n\n    Args:\n        config ([`~transformers.PretrainedConfig`]):\n            The model configuration.\n        device (`torch.device`):\n            The device to use for initialization of the inverse frequencies.\n        seq_len (`int`, *optional*):\n            The current sequence length. 
Unused for this type of RoPE.\n        rope_kwargs (`Dict`, *optional*):\n            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.\n    Returns:\n        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the\n        post-processing scaling factor applied to the computed cos/sin.\n    \"\"\"\n    # Gets the default RoPE parameters\n    inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)\n\n    factor = config.rope_scaling[\"factor\"]  # `8` in the original implementation\n    low_freq_factor = config.rope_scaling[\"low_freq_factor\"]  # `1` in the original implementation\n    high_freq_factor = config.rope_scaling[\"high_freq_factor\"]  # `4` in the original implementation\n    old_context_len = config.rope_scaling[\"original_max_position_embeddings\"]  # `8192` in the original implementation\n\n    low_freq_wavelen = old_context_len / low_freq_factor\n    high_freq_wavelen = old_context_len / high_freq_factor\n\n    wavelen = 2 * math.pi / inv_freq\n    # wavelen < high_freq_wavelen: do nothing\n    # wavelen > low_freq_wavelen: divide by factor\n    inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)\n    # otherwise: interpolate between the two, using a smooth factor\n    smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)\n    smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama\n    is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)\n    inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)\n\n    return inv_freq_llama, attention_factor\n\n\n# This maps the \"rope_type\" string field in rope config to the corresponding function to compute the RoPE parameters\n# from the model config. 
You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE\n# parameterizations, as long as the callable has the same signature.\nROPE_INIT_FUNCTIONS = {\n    \"default\": _compute_default_rope_parameters,\n    \"linear\": _compute_linear_scaling_rope_parameters,\n    \"dynamic\": _compute_dynamic_ntk_parameters,\n    \"yarn\": _compute_yarn_parameters,\n    \"longrope\": _compute_longrope_parameters,\n    \"llama3\": _compute_llama3_parameters,\n}\n\n\ndef _check_received_keys(\n    rope_type: str,\n    received_keys: set,\n    required_keys: set,\n    optional_keys: Optional[set] = None,\n    ignore_keys: Optional[set] = None,\n):\n    \"\"\"Compare the received keys in `config.rope_scaling` against the expected and optional keys\"\"\"\n    # BC: \"rope_type\" was originally \"type\" -- let's check for \"rope_type\" when \"type\" is present\n    if \"type\" in received_keys:\n        received_keys -= {\"type\"}\n        required_keys.add(\"rope_type\")\n\n    # Some models need to store model-specific keys, and we don't want to throw warning at them\n    if ignore_keys is not None:\n        received_keys -= ignore_keys\n\n    missing_keys = required_keys - received_keys\n    if missing_keys:\n        raise KeyError(f\"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}\")\n\n    if optional_keys is not None:\n        unused_keys = received_keys - required_keys - optional_keys\n    else:\n        unused_keys = received_keys - required_keys\n    if unused_keys:\n        logger.warning(f\"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}\")\n\n\ndef _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):\n    rope_scaling = config.rope_scaling\n    rope_type = rope_scaling.get(\"rope_type\", rope_scaling.get(\"type\", None))  # BC: \"rope_type\" was originally \"type\"\n    required_keys = {\"rope_type\"}\n    received_keys = 
set(rope_scaling.keys())\n    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)\n\n\ndef _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):\n    rope_scaling = config.rope_scaling\n    rope_type = rope_scaling.get(\"rope_type\", rope_scaling.get(\"type\", None))  # BC: \"rope_type\" was originally \"type\"\n    required_keys = {\"rope_type\", \"factor\"}\n    received_keys = set(rope_scaling.keys())\n    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)\n\n    factor = rope_scaling[\"factor\"]\n    if factor is None or not isinstance(factor, float) or factor < 1.0:\n        logger.warning(f\"`rope_scaling`'s factor field must be a float >= 1, got {factor}\")\n\n\ndef _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):\n    rope_scaling = config.rope_scaling\n    rope_type = rope_scaling.get(\"rope_type\", rope_scaling.get(\"type\", None))  # BC: \"rope_type\" was originally \"type\"\n    required_keys = {\"rope_type\", \"factor\"}\n    # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`\n    optional_keys = {\"original_max_position_embeddings\"}\n    received_keys = set(rope_scaling.keys())\n    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)\n\n    factor = rope_scaling[\"factor\"]\n    if factor is None or not isinstance(factor, float) or factor < 1.0:\n        logger.warning(f\"`rope_scaling`'s factor field must be a float >= 1, got {factor}\")\n\n\ndef _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):\n    rope_scaling = config.rope_scaling\n    rope_type = rope_scaling.get(\"rope_type\", rope_scaling.get(\"type\", None))  # BC: \"rope_type\" was originally \"type\"\n    required_keys = {\"rope_type\", \"factor\"}\n    optional_keys = {\n        
\"attention_factor\",\n        \"beta_fast\",\n        \"beta_slow\",\n        \"original_max_position_embeddings\",\n        \"mscale\",\n        \"mscale_all_dim\",\n    }\n    received_keys = set(rope_scaling.keys())\n    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)\n\n    factor = rope_scaling[\"factor\"]\n    if factor is None or not isinstance(factor, float) or factor < 1.0:\n        logger.warning(f\"`rope_scaling`'s factor field must be a float >= 1, got {factor}\")\n\n    attention_factor = rope_scaling.get(\"attention_factor\")\n    if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):\n        logger.warning(\n            f\"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}\"\n        )\n    beta_fast = rope_scaling.get(\"beta_fast\")\n    if beta_fast is not None and not isinstance(beta_fast, float):\n        logger.warning(f\"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}\")\n    beta_slow = rope_scaling.get(\"beta_slow\")\n    if beta_slow is not None and not isinstance(beta_slow, float):\n        logger.warning(f\"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}\")\n\n    if (beta_fast or 32) < (beta_slow or 1):\n        logger.warning(\n            f\"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} \"\n            f\"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)\"\n        )\n\n\ndef _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):\n    rope_scaling = config.rope_scaling\n    rope_type = rope_scaling.get(\"rope_type\", rope_scaling.get(\"type\", None))  # BC: \"rope_type\" was originally \"type\"\n    required_keys = {\"rope_type\", \"short_factor\", \"long_factor\"}\n    # TODO (joao): update logic for the inclusion of 
`original_max_position_embeddings`\n    optional_keys = {\"attention_factor\", \"factor\", \"original_max_position_embeddings\"}\n    received_keys = set(rope_scaling.keys())\n    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)\n\n    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, \"partial_rotary_factor\") else 1.0\n    head_dim = getattr(config, \"head_dim\", config.hidden_size // config.num_attention_heads)\n    dim = int(head_dim * partial_rotary_factor)\n\n    short_factor = rope_scaling.get(\"short_factor\")\n    # Warn unless the field is a list AND every element is a number\n    if not (isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor)):\n        logger.warning(f\"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}\")\n    if not len(short_factor) == dim // 2:\n        logger.warning(f\"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}\")\n\n    long_factor = rope_scaling.get(\"long_factor\")\n    # Warn unless the field is a list AND every element is a number\n    if not (isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor)):\n        logger.warning(f\"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}\")\n    if not len(long_factor) == dim // 2:\n        logger.warning(f\"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}\")\n\n    # Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over\n    # `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is\n    # unique to longrope (= undesirable)\n    if hasattr(config, \"original_max_position_embeddings\"):\n        logger.warning_once(\n            \"This model has set a `original_max_position_embeddings` field, to be used together with \"\n            \"`max_position_embeddings` to determine a scaling factor. 
Please set the `factor` field of `rope_scaling` \"\n            \"with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, \"\n            \"as it is compatible with most model architectures.\"\n        )\n    else:\n        factor = rope_scaling.get(\"factor\")\n        if factor is None:\n            logger.warning(\"Missing required keys in `rope_scaling`: 'factor'\")\n        elif not isinstance(factor, float) or factor < 1.0:\n            logger.warning(f\"`rope_scaling`'s factor field must be a float >= 1, got {factor}\")\n\n        attention_factor = rope_scaling.get(\"attention_factor\")\n        if attention_factor is not None:\n            if not isinstance(attention_factor, float) or attention_factor < 0.0:\n                logger.warning(\n                    f\"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}\"\n                )\n\n\ndef _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):\n    rope_scaling = config.rope_scaling\n    rope_type = rope_scaling.get(\"rope_type\", rope_scaling.get(\"type\", None))  # BC: \"rope_type\" was originally \"type\"\n    required_keys = {\"rope_type\", \"factor\", \"original_max_position_embeddings\", \"low_freq_factor\", \"high_freq_factor\"}\n    received_keys = set(rope_scaling.keys())\n    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)\n\n    factor = rope_scaling[\"factor\"]\n    if factor is None or not isinstance(factor, float) or factor < 1.0:\n        logger.warning(f\"`rope_scaling`'s factor field must be a float >= 1, got {factor}\")\n\n    low_freq_factor = rope_scaling[\"low_freq_factor\"]\n    high_freq_factor = rope_scaling[\"high_freq_factor\"]\n    if low_freq_factor is None or not isinstance(low_freq_factor, float):\n        logger.warning(f\"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}\")\n    if 
high_freq_factor is None or not isinstance(high_freq_factor, float):\n        logger.warning(f\"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}\")\n    if high_freq_factor <= low_freq_factor:\n        logger.warning(\n            \"`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor=\"\n            f\"{high_freq_factor} and low_freq_factor={low_freq_factor}\"\n        )\n\n    original_max_position_embeddings = rope_scaling[\"original_max_position_embeddings\"]\n    if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):\n        logger.warning(\n            \"`rope_scaling`'s original_max_position_embeddings field must be an integer, got \"\n            f\"{original_max_position_embeddings}\"\n        )\n    if original_max_position_embeddings >= config.max_position_embeddings:\n        logger.warning(\n            \"`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got \"\n            f\"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}\"\n        )\n\n\n# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types.\nROPE_VALIDATION_FUNCTIONS = {\n    \"default\": _validate_default_rope_parameters,\n    \"linear\": _validate_linear_scaling_rope_parameters,\n    \"dynamic\": _validate_dynamic_scaling_rope_parameters,\n    \"yarn\": _validate_yarn_parameters,\n    \"longrope\": _validate_longrope_parameters,\n    \"llama3\": _validate_llama3_parameters,\n}\n\n\ndef rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None):\n    \"\"\"\n    Validate the RoPE config arguments, given a `PretrainedConfig` object\n    \"\"\"\n    rope_scaling = getattr(config, \"rope_scaling\", None)  # not a default parameter in `PretrainedConfig`\n    if rope_scaling is None:\n        
return\n\n    # BC: \"rope_type\" was originally \"type\"\n    rope_type = rope_scaling.get(\"rope_type\", rope_scaling.get(\"type\", \"default\"))\n    validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)\n    if validation_fn is not None:\n        validation_fn(config, ignore_keys=ignore_keys)\n    else:\n        logger.warning(\n            f\"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'\"\n        )"
  },
  {
    "path": "archive/ktransformers/util/npu_graph_runner.py",
    "content": "'''\nDescription :\nAuthor      : Boxin Zhang\nVersion     : 0.1.0\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved.\n'''\nfrom typing import Dict\n\nimport threading\nimport torch\nimport torch_npu\n\n\nclass NPUGraphRunner:\n\n    def __init__(self, deviceId):\n        torch.npu.set_compile_mode(jit_compile=False)\n        self.deviceId = deviceId\n        self.input_buffers: Dict[str, torch.Tensor] = {}\n        self.output_buffers: Dict[str, torch.Tensor] = {}\n        self.past_key_value = None\n\n    def init(self, batch_size, seq_length):\n        self.graph = torch.npu.NPUGraph()\n        self.main_stream = torch_npu.npu.Stream(device=self.deviceId)\n        self.share_experts_stream = torch_npu.npu.Stream(device=self.deviceId)\n        self.logits = torch.zeros((batch_size, seq_length, 7168), dtype=torch.float16).to(self.deviceId)  # deepseekV3 hidden_size\n        self.workspace = None\n        self.model_capture = True\n        torch_npu.npu._subscribe_report(self.main_stream)\n\n    def destroy(self):\n        torch_npu.npu._unsubscribe_report(self.main_stream)\n        del self.graph\n        destory_runner(self.deviceId)\n\n    def capture(\n            self,\n            model,\n            cur_token,\n            position_ids,\n            cache_position,\n            past_key_values,\n            main_device,\n            **kwargs,\n    ) -> None:\n        inputs_embeds = model.model.embed_tokens(cur_token.to(\"cpu\")).to(main_device)\n        with torch.no_grad():\n            with torch.npu.graph(self.graph, stream=self.main_stream, auto_dispatch_capture=True):\n                logits = model(inputs_embeds=inputs_embeds,\n                            position_ids=position_ids,\n                            cache_position=cache_position,\n                            past_key_values=past_key_values,\n                            is_prefill=False,\n                            **kwargs)\n        self.input_buffers = {\n            
\"inputs_embeds\": inputs_embeds,\n            \"position_ids\": position_ids,\n            \"cache_position\": cache_position,\n        }\n        self.output_buffers = {\n            \"logits\": logits,\n        }\n\n    def forward(\n            self,\n            inputs_embeds,\n            position_ids,\n            cache_position,\n    ) -> torch.Tensor:\n        thread = threading.Thread(target=self.graph.update, kwargs={\"cpu_update_input\": [{\"actual_seq_lengths_kv\": self.past_key_value.position}]})\n        thread.start()\n\n        self.input_buffers[\"inputs_embeds\"].copy_(inputs_embeds)\n        self.input_buffers[\"position_ids\"].copy_(position_ids)\n        self.input_buffers[\"cache_position\"].copy_(cache_position)\n        torch_npu.npu.synchronize()\n        with torch_npu.npu.stream(self.main_stream):\n            # Run the graph.\n            self.graph.replay()\n        thread.join()\n\n        # Return the output tensor.\n        return self.output_buffers[\"logits\"]\n\n    def launch_callback(self, func, data, block, stream):\n        torch_npu.npu._launch_host_func(stream, func, data)\n\n    def __call__(self, *args, **kwargs):\n        return self.forward(*args, **kwargs)\n\nrunner_dict = dict()\n\ndef check_runner(deviceId: int):\n    runner = runner_dict.get(deviceId)\n    if runner is None:\n        return True\n    else:\n        return False\n\ndef destory_runner(deviceId: int):\n    # print(\"the new NPUGraphRunner and deviceId is \", deviceId)\n    runner = runner_dict.get(deviceId)\n    if runner is not None:\n        runner_dict[deviceId] = None\n\ndef get_or_create_runner(deviceId: int):\n    runner = runner_dict.get(deviceId)\n    if runner is None:\n        runner = NPUGraphRunner(deviceId)\n        runner_dict[deviceId] = runner\n    return runner"
  },
  {
    "path": "archive/ktransformers/util/textstream.py",
    "content": "from typing import Any, List, Optional, Set\nclass TextStreamer:\n\n    def __init__(self, tokenizer: \"AutoTokenizer\", skip_prompt: bool = False, **decode_kwargs):\n        self.tokenizer = tokenizer\n        self.skip_prompt = skip_prompt\n        self.decode_kwargs = decode_kwargs\n\n        # variables used in the streaming process\n        self.token_cache = []\n        self.print_len = 0\n        self.next_tokens_are_prompt = True\n\n    def reset(self):\n        self.token_cache = []\n        self.print_len = 0\n\n    def put(self, value)->Optional[str]:\n        \"\"\"\n        Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.\n        \"\"\"        \n        if not isinstance(value,int):\n            raise ValueError(\"TextStreamer only supports batch size 1, and int type input\")\n\n\n        if self.skip_prompt and self.next_tokens_are_prompt:\n            self.next_tokens_are_prompt = False\n            return None\n\n        # Add the new token to the cache and decodes the entire thing.\n        self.token_cache.append(value)\n        text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True,**self.decode_kwargs)\n\n        # After the symbol for a new line, we flush the cache.\n        if text.endswith(\"\\n\"):\n            printable_text = text[self.print_len :]\n            self.reset()\n        # If the last token is a CJK character, we print the characters.\n        elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):\n            printable_text = text[self.print_len :]\n            self.print_len += len(printable_text)\n        # Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,\n        # which may change with the subsequent token -- there are probably smarter ways to do this!)\n        else:\n            printable_text = text[self.print_len : text.rfind(\" \") + 1]\n            self.print_len += len(printable_text)\n 
       return printable_text\n\n    def end(self)->Optional[str]:\n        \"\"\"Flushes any remaining cache and prints a newline to stdout.\"\"\"\n        # Flush the cache, if it exists\n        if len(self.token_cache) > 0:\n            text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True, **self.decode_kwargs)\n            printable_text = text[self.print_len :]\n            self.reset()\n        else:\n            printable_text = \"\"\n\n        self.next_tokens_are_prompt = True\n        return printable_text\n   \n    def _is_chinese_char(self, cp):\n        \"\"\"Checks whether CP is the codepoint of a CJK character.\"\"\"\n        # This defines a \"chinese character\" as anything in the CJK Unicode block:\n        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)\n        #\n        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,\n        # despite its name. The modern Korean Hangul alphabet is a different block,\n        # as is Japanese Hiragana and Katakana. Those alphabets are used to write\n        # space-separated words, so they are not treated specially and handled\n        # like the all of the other languages.\n        if (\n            (cp >= 0x4E00 and cp <= 0x9FFF)\n            or (cp >= 0x3400 and cp <= 0x4DBF)  #\n            or (cp >= 0x20000 and cp <= 0x2A6DF)  #\n            or (cp >= 0x2A700 and cp <= 0x2B73F)  #\n            or (cp >= 0x2B740 and cp <= 0x2B81F)  #\n            or (cp >= 0x2B820 and cp <= 0x2CEAF)  #\n            or (cp >= 0xF900 and cp <= 0xFAFF)\n            or (cp >= 0x2F800 and cp <= 0x2FA1F)  #\n        ):  #\n            return True\n\n        return False"
  },
  {
    "path": "archive/ktransformers/util/utils.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : Boxin Zhang, Azure-Tang\nVersion      : 0.1.0\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport re\nimport sys\nimport threading\n\nimport torch\nimport torch.distributed as dist\nfrom torch import nn\nimport itertools\nimport time\nimport enum\nfrom transformers import (\n    LogitsProcessorList,\n    TemperatureLogitsWarper,\n    TopKLogitsWarper,\n    TopPLogitsWarper,\n    MinPLogitsWarper,\n    TypicalLogitsWarper,\n    EpsilonLogitsWarper,\n    EtaLogitsWarper,\n)\n\nfrom ktransformers.util.custom_loader import ModelLoaderFactory, ModelLoader, SafeTensorLoader, translate_name_to_gguf\nfrom ktransformers.operators import base_operator\nfrom ktransformers.models.custom_cache import StaticCache\nfrom ktransformers.util.cuda_graph_runner import CUDAGraphRunner\nfrom ktransformers.util.textstream import TextStreamer\nif not torch.xpu.is_available():\n    from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton\n# from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton\nimport socket\n\nwarm_uped = False\nCUR_DEVICE = None\nW8A8_ENABLE = False\nQ4_GGUF_LODER = None\n_USE_NPU_GRAPH = False\n_MAX_DECODE_PROFILE = 1\nWARM_UP_SKIP_CNT = [1, 1]\n_SPECULATE_STEP = 1\n\ntry:\n    import torch_npu\n    use_torch_npu = torch_npu.npu.is_available()\n    from ktransformers.util.ascend.ascend_utils import get_tensor_parallel_size\nexcept:\n    use_torch_npu = False\n\ndef get_use_npu_graph():\n    assert _USE_NPU_GRAPH is not None, \"use npu graph is not setting\"\n    return _USE_NPU_GRAPH\n\nfrom enum import StrEnum\n\nclass StatKey(StrEnum):\n    Embedding = \"Embedding\"\n    GraphCapture = \"GraphCapture\"\n    GraphReplay = \"GraphReplay\"\n    ExpertsForward1 = \"ExpertsForward1\"\n    ExpertsForward2 = \"ExpertsForward2\"\n    CPUExperts = \"CPUExperts\"\n    GraphDestroy = \"GraphDestroy\"\n    DecodeOneTokenPost = 
\"DecodeOneTokenPost\"\n    DecodeOneToken = \"DecodeOneToken\"\n    GraphInit = \"GraphInit\"\n\nclass TimeStat:\n    def __init__(self):\n        # open_status = os.environ[\"KT_PERF_STAT\"] if \"KT_PERF_STAT\" in os.environ else \"0\"\n        # if open_status == \"0\":\n        #     self.on = False\n        # else:\n        #     self.on = True\n        self.on = True\n        self.prefill_stats = dict()\n        self.decode_stats = dict()\n        for key in StatKey:\n            self.prefill_stats[key] = StatItem()\n            self.decode_stats[key] = StatItem()\n        self.reset_all()\n\n    def record_start_time(self):\n        start_time = time.time_ns()\n        return start_time\n\n    def add_time_stat(self, key: StatKey, time_ns, is_prefill):\n        if not key:\n            return\n        # torch.cuda.synchronize()\n        cost = time.time_ns() - time_ns\n        if is_prefill:\n            item = self.prefill_stats[key]\n        else:\n            item = self.decode_stats[key]\n        item.add_item(cost)\n\n    def print_all(self):\n        # rank = f\"[rank:{torch.distributed.get_rank()}]\"\n        rank = f\"[rank:0]\"\n        msg = f\"\\n{rank} Prefill Time Stat\\n\"\n        msg += rank + \" {:27}{:>15}{:>15}{:>15}{:>15}{:>15}\\n\".format(\"\", \"min(ms)\", \"max(ms)\", \"avg(ms)\", \"count\", \"total(ms)\")\n        for key, value in self.prefill_stats.items():\n            msg += rank + f\" {key.value:<25}:{value.get_stat()}\\n\"\n        msg += f\"\\n{rank} Decode Time Stat\\n\"\n        msg += rank + \" {:27}{:>15}{:>15}{:>15}{:>15}{:>15}\\n\".format(\"\", \"min(ms)\", \"max(ms)\", \"avg(ms)\", \"count\", \"total(ms)\")\n        for key, value in self.decode_stats.items():\n            msg += rank + f\" {key.value:<25}:{value.get_stat()}\\n\"\n        print(msg)\n\n    def reset_all(self):\n        for _, value in self.prefill_stats.items():\n            value.reset()\n        for _, value in self.decode_stats.items():\n            
value.reset()\n\n\nclass StatItem:\n    def __init__(self):\n        self.min_time = 100000000\n        self.max_time = 0\n        self.total_time_ns = 0\n        self.count = 0\n\n    def add_item(self, cost_time_ns):\n        self.count += 1\n        self.total_time_ns += cost_time_ns\n        self.min_time = min(self.min_time, cost_time_ns)\n        self.max_time = max(self.max_time, cost_time_ns)\n\n    def reset(self):\n        self.min_time = 100000000\n        self.max_time = 0\n        self.total_time_ns = 0\n        self.count = 0\n\n    def get_stat(self):\n        min_time = self.min_time / 1000 / 1000\n        max_time = self.max_time / 1000 / 1000\n        if self.count != 0:\n            avg_time = self.total_time_ns / self.count / 1000 / 1000\n        else:\n            avg_time = 0\n        total = self.total_time_ns / 1000 / 1000\n        return f\"{min_time:15.2f}{max_time:15.2f}{avg_time:15.2f}{self.count:15}{total:15.2f}\"\n\n\ntimeStat = TimeStat()\n\n\ndef get_free_ports(n: int, continue_prot: list):\n    sockets = []\n    ports = []\n    for _ in range(n):\n        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n        s.bind((\"\", 0)) \n        port = s.getsockname()[1]\n        if port in continue_prot:\n            s.close()\n            continue\n        ports.append(port)\n        sockets.append(s)\n    for s in sockets:\n        s.close()\n    return ports\n\ndef get_current_device():\n    if use_torch_npu:\n        return f\"npu:{torch.npu.current_device()}\"\n    else:\n        # Without torch_npu there is no `torch.npu` namespace, so query the CUDA runtime instead\n        return f\"cuda:{torch.cuda.current_device()}\"\n\ndef get_compute_capability(device:torch.device = None):\n    if use_torch_npu:\n        return 0\n    if torch.cuda.is_available():\n        if device is None:\n            num_gpus = torch.cuda.device_count()\n            min_compute_capability_major = 100\n            for gpu_id in range(num_gpus):\n                gpu_props = torch.cuda.get_device_properties(gpu_id)\n                
min_compute_capability_major = min(min_compute_capability_major, gpu_props.major)\n            return min_compute_capability_major\n        else:\n            return torch.cuda.get_device_properties(device)\n\ndef set_module(model, submodule_key, module):\n    tokens = submodule_key.split('.')\n    sub_tokens = tokens[:-1]\n    cur_mod = model\n    for s in sub_tokens:\n        if hasattr(cur_mod, s):\n            cur_mod = getattr(cur_mod, s)\n        else: # nn.ModuleList or nn.ModuleList\n            cur_mod=cur_mod[int(s)]\n    if hasattr(cur_mod, tokens[-1]):\n        setattr(cur_mod, tokens[-1], module)\n    else: # nn.ModuleList or nn.ModuleList\n        cur_mod[int(tokens[-1])] = module\n\ndef set_param(module: nn.Module, name: str, weights: torch.Tensor):\n    \n    param=nn.parameter.Parameter(weights, requires_grad=False)\n    if isinstance(module, nn.Linear) and len(weights.shape)==1:\n        param.unsqueeze_(0)\n    setattr(module, name, param)\n\ndef get_device(gguf_module_key:str, device_map:dict):\n    if gguf_module_key in device_map:\n        return device_map[gguf_module_key][\"generate_device\"]\n    else:\n        return \"cuda\"\n\ndef get_all_used_cuda_device(device_map:dict):\n    all_device_list = set()\n    for key in device_map:\n        all_device_list.add(device_map[key][\"generate_device\"]) if \"generate_device\" in device_map[key] else None\n        all_device_list.add(device_map[key][\"prefill_device\"]) if \"prefill_device\" in device_map[key] else None\n    if \"cpu\" in all_device_list:\n        all_device_list.remove(\"cpu\")\n    if use_torch_npu:\n        all_device_list = set([device.replace('cuda', 'npu') for device in all_device_list])\n    all_device_list = list(all_device_list)\n    return all_device_list\n\ndef load_cur_state_dict_npu(module: nn.Module, gguf_loader: ModelLoader, prefix: str = \"\", device=\"npu\"):\n    prefix = prefix.replace(\"orig_module.\", \"\")\n    persistent_buffers = {k: v for k, v in 
module._buffers.items() if k not in module._non_persistent_buffers_set}\n    local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items())\n    local_state = {k: v for k, v in local_name_params if v is not None}\n    for name, param in local_state.items():\n        key = prefix + name\n        translated_key = translate_name_to_gguf(key)\n        # TODO: Merge all loader.\n        # I know this is ugly but lets do it for now.\n        if gguf_loader.safetensor_loader is not None:\n            load_dequantized_tensor = gguf_loader.safetensor_loader.load_dequantized_tensor\n            tensor_file_map = gguf_loader.safetensor_loader.tensor_file_map\n        else:\n            load_dequantized_tensor = gguf_loader.load_gguf_tensor\n            tensor_file_map = gguf_loader.tensor_file_map\n        \n        if translated_key in tensor_file_map:\n            target_dtype = torch.get_default_dtype()\n            device = get_device(translated_key[:translated_key.rfind(\".\")], gguf_loader.tensor_device_map)\n            # Todo need fix\n            device = \"cpu\" if \"embd\" in translated_key else get_current_device()\n            print(f\"loading layer {translated_key} to {device}\")\n            torch.cuda.empty_cache()\n            weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)\n            set_param(module, name, weights)\n            del weights\n        else:\n            #print(load_config.tensor_file_map.keys())\n            raise Exception(f\"can't find {translated_key} in GGUF file!\")\n\ndef load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = \"\", device=\"cuda\"):\n    if use_torch_npu:\n        load_cur_state_dict_npu(module, gguf_loader, prefix, device)\n        return\n\n    prefix = prefix.replace(\"orig_module.\", \"\")\n    persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}\n    
local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items())\n    local_state = {k: v for k, v in local_name_params if v is not None}\n    for name, param in local_state.items():\n        key = prefix + name\n        translated_key = key\n        \n        # TODO: Merge all loader.\n        # I know this is ugly but lets do it for now.\n        if isinstance(gguf_loader, SafeTensorLoader):\n            load_dequantized_tensor = gguf_loader.load_dequantized_tensor\n        else:\n            load_dequantized_tensor = gguf_loader.load_gguf_tensor\n            tensor_file_map = gguf_loader.tensor_file_map\n        \n        if gguf_loader.has_tensor(translated_key) or \"kv_b_proj\" in translated_key:\n            target_dtype = torch.get_default_dtype()\n            device = get_device(translated_key[:translated_key.rfind(\".\")], gguf_loader.tensor_device_map)\n            print(f\"loading {translated_key} to {device}\")\n            if torch.cuda.is_available():\n                torch.cuda.empty_cache()\n            elif torch.xpu.is_available():\n                torch.xpu.empty_cache()\n            if \"kv_b_proj\" in translated_key and not gguf_loader.has_tensor(translated_key):\n                attn_k_b = load_dequantized_tensor(translated_key.replace(\"self_attn.kv_b_proj\", \"attn_k_b\"), device=device).to(dtype=target_dtype)\n                attn_k_b = attn_k_b.transpose(1, 2).contiguous()\n                attn_v_b = load_dequantized_tensor(translated_key.replace(\"self_attn.kv_b_proj\", \"attn_v_b\"), device=device).to(dtype=target_dtype)\n                kv_b_proj = torch.cat((attn_k_b, attn_v_b), dim=1)\n                kv_b_proj = kv_b_proj.contiguous() if kv_b_proj.ndim == 2 else kv_b_proj.flatten(0, 1).contiguous()\n                set_param(module, name, kv_b_proj)\n                del attn_k_b\n                del attn_v_b\n            else:\n                weights = load_dequantized_tensor(translated_key, 
device=device).to(dtype=target_dtype)\n                set_param(module, name, weights)\n                del weights\n        else:\n            #print(load_config.tensor_file_map.keys())\n            raise Exception(f\"can't find {translated_key} in GGUF file!\")\n\n  \ndef sync_all_device(all_device_list):\n    for device in all_device_list:\n        if \"cuda\" in device.lower():\n            torch.cuda.synchronize(device)\n        elif \"xpu\" in device.lower():\n            torch.xpu.synchronize(device)\n        elif use_torch_npu:\n            torch_npu.synchronize(device)\n        else:\n            raise RuntimeError(\"The device {} is not available\".format(device))\n\ntorch_device_mapping ={\"cuda\": \"cuda:0\", \"xpu\": \"xpu:0\"}\n\ndef xpu_fp16_model(config):\n    # This function is to check if we run this model on XPU with FP16 dtype\n    if not torch.xpu.is_available():\n        return False\n    if config.architectures[0] == \"DeepseekV3ForCausalLM\":\n        return True\n    if config.architectures[0] == \"Qwen3MoeForCausalLM\" and config.hidden_size == 4096:\n        # Qwen3-30B seems have precision issue with FP16\n        # so we only use FP16 for Qwen3-235B now\n        return True\n    return False\n\ndef load_weights(module:nn.Module, gguf_loader:ModelLoader, prefix='', device=\"cuda\"):\n    #print(f\"recursively loading weights {prefix}\")\n    if not isinstance(module, base_operator.BaseInjectedModule):\n        load_cur_state_dict(module, gguf_loader, prefix, device=device)\n        for name, child in module._modules.items():\n            load_weights(child, gguf_loader, prefix+name+\".\", device=device)\n    else:\n        module.load()\n\ndef tf_logits_warper(generation_config):\n        \"\"\"\n        This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances\n        used for multinomial sampling.\n        \"\"\"\n\n        # instantiate warpers list\n        warpers = 
LogitsProcessorList()\n\n        # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a\n        # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)\n        if generation_config.num_beams > 1:\n            if isinstance(generation_config._eos_token_tensor, list):\n                min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1\n            elif isinstance(generation_config._eos_token_tensor, torch.Tensor):\n                min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1\n            else:\n                min_tokens_to_keep = 2\n        else:\n            min_tokens_to_keep = 1\n\n        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files\n        # all samplers can be found in `generation_utils_samplers.py`\n        if generation_config.temperature is not None and generation_config.temperature != 1.0:\n            warpers.append(TemperatureLogitsWarper(generation_config.temperature))\n        if generation_config.top_k is not None and generation_config.top_k != 0:\n            warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))\n        if generation_config.top_p is not None and generation_config.top_p < 1.0:\n            warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))\n        if generation_config.min_p is not None:\n            # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)\n            warpers.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))\n        if generation_config.typical_p is not None and generation_config.typical_p < 1.0:\n            warpers.append(\n                TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)\n  
          )\n        if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:\n            warpers.append(\n                EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep)\n            )\n        if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:\n            warpers.append(\n               EtaLogitsWarper(\n                    epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device\n                )\n            )\n        # `LogitNormalization` should always be the last logit processor, when present\n        if generation_config.renormalize_logits is True:\n            warpers.append(LogitNormalization())\n        return warpers\ndef prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,\n                         mode = 'normal', force_think: bool = False, chunk_size = 16384, use_flashinfer_mla = False,\n                         num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None,\n                         static_cache = None, draft_model=None, draft_cache=None):\n    import os\n    \n    os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n    torch._dynamo.config.suppress_errors = True\n    batch_size, seq_length = inputs.shape\n    device_map = model.gguf_loader.tensor_device_map\n    if use_torch_npu:\n        CUR_DEVICE = f\"npu:{torch.npu.current_device()}\"\n        vocabulary_size = model.config.vocab_size\n        topp = torch.tensor([[model.generation_config.top_p]], dtype=torch.float16).npu()\n        topk = torch.tensor([[model.generation_config.top_k]], dtype=torch.int32).npu()\n        temperature = torch.tensor([[model.generation_config.temperature]], dtype=torch.float16).npu()\n        next_token_fake = torch.tensor([[1]], dtype=torch.int32).npu()\n        next_token_probs = torch.tensor([[1.0]], 
dtype=torch.float16).npu()\n        torch_device = torch.npu.current_device()\n    else:\n        torch_device = get_device('model.layers.0.self_attn', device_map)\n        torch_device = torch_device_mapping[torch_device] if torch_device in torch_device_mapping else torch_device\n    inputs = inputs.to(torch_device)\n    all_cuda_device = get_all_used_cuda_device(device_map)\n\n    tokens = []\n\n    def decode_one_tokens_npu(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph: bool = True):\n        if cuda_graph_runner is None:\n            use_cuda_graph = False\n        \n        inputs_embeds = model.model.embed_tokens(cur_token.to('cpu')).to(torch_device)\n        if use_cuda_graph:\n            if cuda_graph_runner.model_capture:\n                cuda_graph_runner.capture(model, cur_token, position_ids, cache_position, past_key_values, CUR_DEVICE, return_dict=False, use_cache=True)\n                cuda_graph_runner.model_capture = False\n\n            ret = cuda_graph_runner(inputs_embeds, position_ids, cache_position)\n            logits = ret[0]\n            next_token = torch.argmax(logits, dim=-1)\n        else:\n            torch_npu.npu.set_device(torch_device)\n            logits = model(inputs_embeds=inputs_embeds,\n                       position_ids=position_ids,\n                       cache_position=cache_position,\n                       past_key_values=past_key_values,\n                       return_dict=False, use_cache=True, is_prefill=False)[0]\n        if past_key_values != None:\n            past_key_values.change_seq_length(1)\n\n        if generation_config.do_sample:\n            logits = logits / temperature\n            torch.manual_seed(0)\n            probs = logits.view(batch_size, vocabulary_size)\n            sm = nn.Softmax(dim=-1)\n            probs = sm(probs).half().npu()\n            next_token = next_token_fake\n            
torch_npu._npu_topk_topp_sampling(probs, topk, topp, next_token, next_token_probs)\n            next_token = next_token.squeeze(-1)\n        else:\n            next_token_scores = logits_warper(inputs, logits[:, -1, :])\n            next_token = torch.argmax(next_token_scores, dim=-1)\n        \n        return next_token\n            \n    \n    def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph: bool = True):\n        if use_torch_npu:\n            return decode_one_tokens_npu(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph)\n        if cuda_graph_runner is None:\n            use_cuda_graph = False\n        if use_cuda_graph:\n            logits = cuda_graph_runner(cur_token, position_ids, cache_position)\n        else:\n            # custom_stream = torch.cuda.Stream()\n            if torch.cuda.is_available():\n                torch.cuda.set_device(torch_device)\n            elif torch.xpu.is_available():\n                torch.xpu.set_device(torch_device)\n            elif use_torch_npu:\n                torch_npu.set_device(torch_device)\n            else:\n                raise RuntimeError(f\"The device: {torch_device} is not available\")\n            inputs_embeds = model.model.embed_tokens(cur_token.to(\"cpu\")).to(torch_device)\n            # with torch.cuda.stream(custom_stream):\n            logits=model(inputs_embeds=inputs_embeds,\n                        position_ids=position_ids,\n                        cache_position=cache_position,\n                        past_key_values=past_key_values,\n                        return_dict=False, use_cache=True)[0]\n        if past_key_values != None and isinstance(past_key_values, StaticCache):\n            past_key_values.change_seq_length(1)\n        sync_all_device(all_cuda_device)\n        next_token_scores = logits_warper(inputs, 
logits[:, -1, :])\n        if generation_config.do_sample:\n            probs = nn.functional.softmax(next_token_scores, dim=-1)\n            next_token = torch.multinomial(probs, num_samples=1).squeeze(1)\n        else:\n            next_token = torch.argmax(next_token_scores, dim=-1)\n        return next_token\n\n    # TODO: use CUDA Graph for chunk prefill, may get small improvement\n    def chunk_prefill(inputs, cache_position, past_key_values):\n        if mode == \"long_context\":\n            inputs_embeds = model.model.embed_tokens(inputs.to(\"cpu\"))\n        else:\n            inputs_embeds = model.model.embed_tokens(inputs.to(\"cpu\")).to(torch_device)\n            # inputs_embeds = torch_npu.npu_format_cast_(inputs_embeds, 29)\n        if use_flashinfer_mla:\n            MLAWrapperSingleton.update_buffer(past_key_values.max_pages)\n            MLAWrapperSingleton.need_plan_all()\n\n        ret = model(\n            inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True, is_prefill=True\n        )\n        logits = ret[0][:,-1,:].unsqueeze(0).clone().to(torch_device)\n\n        return logits\n\n    def decode_wrapper(next_token, position_ids, cache_position, cuda_graph_runner, past_key_values, inputs, seq_length, prof=None):\n        global warm_uped\n        global _USE_NPU_GRAPH\n        if use_cuda_graph:\n            from ktransformers.util.npu_graph_runner import get_or_create_runner\n            npu_graph_runner = get_or_create_runner(CUR_DEVICE)\n            npu_graph_runner.init(batch_size, seq_length)\n            \n            with torch_npu.npu.stream(npu_graph_runner.main_stream):\n                gen_num_tokens = 1\n                while gen_num_tokens < max_new_tokens:\n                    start_time = timeStat.record_start_time()\n                    if use_flashinfer_mla:\n                        
MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,None,\n                                                    num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,\n                                                    model.model.layers[0].self_attn.softmax_scale, torch.bfloat16, torch.bfloat16)\n                    if gen_num_tokens == 1:\n                        warm_uped = True\n                        _USE_NPU_GRAPH = True\n                        #np_graph_runner.capture(model, draft_model, next_token, torch.tensor(draft_token), position_ids, cache_position, past_key_values, draft_cache, torch_device, return_dict=False, use_cache=True)\n                        cuda_graph_runner = npu_graph_runner\n                    next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph)\n                    next_token = next_token.to(torch_device)\n                    inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)\n                    generated_ids[:, cache_position] = next_token.int()\n                    tokens.append(int(next_token))\n                    \n                    seq_length += 1\n\n                    if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>':\n                        print(stream.end(), end=\"\", flush=True)\n                        break\n                    else:\n                        if torch.distributed.get_rank() % get_tensor_parallel_size() == 0:\n                            print(stream.put(next_token.item()), end=\"\", flush=True)\n\n                    cache_position += 1\n                    past_key_values.position[0] += 1\n                    position_ids = cache_position.unsqueeze(0)\n                    gen_num_tokens += 1\n                    \n                    if prof is not None:\n                        
prof.step()\n\n                npu_graph_runner.destroy()\n                _USE_NPU_GRAPH = False\n        else:\n            gen_num_tokens = 1\n            while gen_num_tokens < max_new_tokens:\n                if use_flashinfer_mla:\n                    MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,None,\n                                                num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,\n                                                model.model.layers[0].self_attn.softmax_scale, torch.bfloat16, torch.bfloat16)\n                next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph)\n                next_token = next_token.to(torch_device)\n                inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)\n                generated_ids[:, cache_position] = next_token.int()\n                tokens.append(int(next_token))\n                seq_length += 1\n\n                if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>':\n                    print(stream.end(), end=\"\", flush=True)\n                    break\n                else:\n                    if torch.distributed.get_rank() % get_tensor_parallel_size() == 0:\n                        print(stream.put(next_token.item()), end=\"\", flush=True)\n\n                cache_position += 1\n                past_key_values.position[0] += 1\n                position_ids = cache_position.unsqueeze(0)\n                gen_num_tokens += 1\n\n                if prof is not None:\n                    prof.step()\n        \n        if prof is not None:\n            prof.stop()\n    \n    if torch.cuda.is_available():\n        torch.cuda.set_device(torch_device)\n    elif torch.xpu.is_available():\n        torch.xpu.set_device(torch_device)\n    elif use_torch_npu:\n        
torch_npu.set_device(torch_device)\n    else:\n        raise RuntimeError(f\"The device: {torch_device} is not available\")\n\n    with torch.no_grad():\n\n        stream = TextStreamer(tokenizer)\n        if torch.xpu.is_available():\n            from ipex_llm.transformers.kv import DynamicUnbalancedFp8Cache, DynamicNormalCache\n            if model.config.architectures[0] in [\"DeepseekV3ForCausalLM\", \"DeepseekV2ForCausalLM\"]:\n                past_key_values = DynamicUnbalancedFp8Cache.from_legacy_cache(None)\n            else:\n                past_key_values = DynamicNormalCache.from_legacy_cache(None)\n        elif use_torch_npu and static_cache:\n            assert isinstance(static_cache, StaticCache), '[ERROR] static_cache format not equal to StaticCache'\n            past_key_values = static_cache\n            if past_key_values.max_batch_size < batch_size or past_key_values.max_cache_len < seq_length + max_new_tokens:\n                print('[WARN] current staticCache size exceeded, try create new staticCache...')\n                past_key_values = StaticCache(\n                    config=model.config, max_batch_size=1, max_cache_len=seq_length + max_new_tokens, device=device_map, dtype=model.dtype\n                )\n            else:\n                past_key_values.reset()\n        elif mode != 'long_context':\n            past_key_values = StaticCache(\n                config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype\n            )\n        else:\n            past_key_values = None\n\n        generation_config, model_kwargs = model._prepare_generation_config(\n            None, do_sample=False\n            # change this to modify generate config\n            #top_k=5, top_p=0.85, temperature=0.1\n        )\n        \n        logits_warper = tf_logits_warper(generation_config)\n\n        cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)\n 
       if use_torch_npu:\n            past_key_values.position[0] = seq_length + 1\n        generated_ids = torch.zeros(\n            batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device\n        )\n        generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)\n        start_time = time.time()\n        logits = None\n\n        def prefill_wrapper(prof=None):\n            nonlocal logits\n            chunk_start = 0\n            while chunk_start < seq_length:\n                chunk_end = min(chunk_start + chunk_size, seq_length)\n                if past_key_values != None:\n                    past_key_values.cur_idx=cache_position[chunk_start:chunk_end]\n                logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)\n                chunk_start += chunk_size\n                if prof is not None:\n                    prof.step()\n            if prof is not None:\n                prof.stop()\n            if logits is None:\n                raise ValueError('logits cannot be None')\n\n        if use_torch_npu:\n            global WARM_UP_SKIP_CNT\n            prof_prefill = os.environ[\"PROF_PREFILL\"] if \"PROF_PREFILL\" in os.environ else \"0\"\n            if prof_prefill == \"1\" and WARM_UP_SKIP_CNT[0] <= 0:\n                experimental_config = torch_npu.profiler._ExperimentalConfig(\n                    aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,\n                    profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False\n                )\n                with torch_npu.profiler.profile(\n                        activities=[\n                            torch_npu.profiler.ProfilerActivity.CPU,\n                            torch_npu.profiler.ProfilerActivity.NPU\n                        ],\n                        schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=8, repeat=1, skip_first=0),\n              
          on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(\"./prefill_prof\"),\n                        record_shapes=True,\n                        profile_memory=True,\n                        with_stack=False,\n                        with_flops=False,\n                        with_modules=False,\n                        experimental_config=experimental_config) as prof:\n                    prefill_wrapper(prof)\n            else:\n                prefill_wrapper()\n            WARM_UP_SKIP_CNT[0] -= 1\n        else:\n\n            chunk_start = 0\n            while chunk_start < seq_length:\n                chunk_end = min(chunk_start + chunk_size, seq_length)\n                if past_key_values != None:\n                    past_key_values.cur_idx=cache_position[chunk_start:chunk_end]\n                logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)\n                chunk_start += chunk_size\n\n        next_token_scores = logits_warper(inputs, logits[:, -1, :])\n        if generation_config.do_sample:\n            probs = nn.functional.softmax(next_token_scores, dim=-1)\n            next_token = torch.multinomial(probs, num_samples=1).squeeze(1)\n        else:\n            next_token = torch.argmax(next_token_scores, dim=-1)\n\n        first_token_time = time.time() - start_time\n\n        # print(f\"------------------------------------- prefill next_token {next_token}  draft_token {draft_token} \")\n        if use_flashinfer_mla:\n            MLAWrapperSingleton.reset_buffer()\n\n        prefill_count = seq_length\n        prefill_time = first_token_time\n        if use_torch_npu and torch.distributed.get_rank() % get_tensor_parallel_size() == 0:\n            if force_think:\n                print(\"<think>\")\n            print(stream.put(next_token.item()), end=\"\", flush=True)\n        elif not use_torch_npu:\n            if force_think:\n                print(\"<think>\")\n       
     print(stream.put(next_token.item()), end=\"\", flush=True)\n\n        generated_ids[:, seq_length] = next_token\n        tokens.append(int(next_token))\n        inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)\n        cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.int32)\n        position_ids = cache_position.unsqueeze(0)\n        seq_length += 1\n        \n        cuda_graph_runner = None\n        \n        start_time = time.time()\n\n        if not use_torch_npu:\n            for i in range(1, max_new_tokens):\n                if use_flashinfer_mla:\n                    MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,None,\n                                                num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,\n                                                model.model.layers[0].self_attn.softmax_scale, torch.bfloat16, torch.bfloat16)\n                global warm_uped\n                if use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):\n                    warm_uped = True\n                    cuda_graph_runner = CUDAGraphRunner()\n                    cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)\n                next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph).to(torch_device)\n                inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)\n                generated_ids[:, cache_position] = next_token.int()\n                tokens.append(int(next_token))\n                seq_length += 1\n                \n                if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>':\n                    print(stream.end(), 
end=\"\", flush=True)\n                    break\n                else:\n                    print(stream.put(next_token.item()), end=\"\", flush=True)\n                cache_position += 1\n                position_ids = cache_position.unsqueeze(0)\n        else:\n            prof_decode = os.environ[\"PROF_DECODE\"] if \"PROF_DECODE\" in os.environ else \"0\"\n            prof_ranks = os.environ[\"PROF_RANK\"] if \"PROF_RANK\" in os.environ else \"0\"\n            prof_ranks = [int(r.strip()) for r in prof_ranks.split(\",\")]\n            if prof_decode == \"1\" and torch.distributed.get_rank() in prof_ranks and WARM_UP_SKIP_CNT[1] <= 0:\n                experimental_config = torch_npu.profiler._ExperimentalConfig(\n                    aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,\n                    profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False\n                )\n                with torch_npu.profiler.profile(\n                        activities=[\n                            torch_npu.profiler.ProfilerActivity.CPU,\n                            torch_npu.profiler.ProfilerActivity.NPU\n                        ],\n                        schedule=torch_npu.profiler.schedule(wait=0, warmup=0, active=_MAX_DECODE_PROFILE, repeat=1, skip_first=0),\n                        on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(\"./decode_prof\"),\n                        record_shapes=True,\n                        profile_memory=True,\n                        with_stack=False,\n                        with_flops=False,\n                        with_modules=False,\n                        experimental_config=experimental_config) as prof:\n                    decode_wrapper(next_token, position_ids, cache_position, cuda_graph_runner, past_key_values, inputs, seq_length, prof)\n            else:\n                decode_wrapper(next_token, position_ids, cache_position, cuda_graph_runner, past_key_values, inputs, seq_length)\n   
         WARM_UP_SKIP_CNT[1] -= 1 \n\n    total_time = time.time() - start_time\n    tokens_generated = len(tokens)\n    tokens_per_second = tokens_generated / total_time\n\n    if not use_torch_npu:\n        print(\"\")\n\n        print(f\"prompt eval count:    {prefill_count} token(s)\")\n        print(f\"prompt eval duration: {prefill_time}s\")\n        print(f\"prompt eval rate:     {prefill_count/prefill_time} tokens/s\")\n        print(f\"eval count:           {tokens_generated} token(s)\")\n        print(f\"eval duration:        {total_time}s\")\n        print(f\"eval rate:            {tokens_per_second} tokens/s\")\n    else:\n        tp_size = get_tensor_parallel_size()\n        if torch.distributed.get_rank() % tp_size == 0:\n            rank = f\"[rank:{torch.distributed.get_rank()}]\"\n            msg = f\"\\n{rank} Eval Time\\n\"\n            msg += rank + f\"prompt eval count:    {prefill_count} token(s)\\n\"\n            msg += rank + f\"prompt eval duration: {prefill_time:.9f}s\\n\"\n            msg += rank + f\"prompt eval rate:     {prefill_count/prefill_time:.9f} tokens/s\\n\"\n            msg += rank + f\"eval count:           {tokens_generated} token(s)\\n\"\n            msg += rank + f\"eval duration:        {total_time:.9f}s\\n\"\n            msg += rank + f\"eval rate:            {tokens_per_second:.9f} tokens/s\\n\"\n            print(msg)\n\n    return tokens\n\nclass InferenceState(enum.Enum):\n    UNLOAD = 0\n    PREFILL = 1\n    GENERATE = 2\n    RESTORE = 3\n"
  },
  {
    "path": "archive/ktransformers/util/vendors.py",
    "content": "from __future__ import annotations\n\nfrom enum import IntEnum, auto\nfrom typing import Optional, Union, List\nimport torch\n\nclass GPUVendor(IntEnum):\n    NVIDIA = auto()\n    AMD = auto()\n    MooreThreads = auto()\n    MetaX = auto()\n    MUSA = auto()\n    Unknown = auto()\n\nclass DeviceManager:\n    \"\"\"\n    Device manager that provides a unified interface for handling different GPU vendors\n    \"\"\"\n    def __init__(self):\n        self.gpu_vendor = self._detect_gpu_vendor()\n        self.available_devices = self._get_available_devices()\n    \n    def _detect_gpu_vendor(self) -> GPUVendor:\n        \"\"\"Detect GPU vendor type\"\"\"\n        if not torch.cuda.is_available():\n            # Check MUSA availability (assuming a musa module exists)\n            try:\n                import musa\n                if musa.is_available():\n                    return GPUVendor.MUSA\n            except (ImportError, AttributeError):\n                pass\n            \n            return GPUVendor.Unknown\n        \n        device_name = torch.cuda.get_device_name(0).lower()\n        \n        if any(name in device_name for name in [\"nvidia\", \"geforce\", \"quadro\", \"tesla\", \"titan\", \"rtx\", \"gtx\"]):\n            return GPUVendor.NVIDIA\n        elif any(name in device_name for name in [\"amd\", \"radeon\", \"rx\", \"vega\", \"instinct\", \"firepro\", \"mi\"]):\n            return GPUVendor.AMD\n        elif any(name in device_name for name in [\"mthreads\", \"moore\", \"mtt\"]):\n            return GPUVendor.MooreThreads\n        elif any(name in device_name for name in [\"metax\", \"meta\"]):\n            return GPUVendor.MetaX\n        elif \"musa\" in device_name:\n            return GPUVendor.MUSA\n        \n        # Backend check\n        try:\n            if hasattr(torch.version, 'hip') and torch.version.hip is not None:\n                return GPUVendor.AMD\n            elif hasattr(torch.version, 'cuda') and 
torch.version.cuda is not None:\n                return GPUVendor.NVIDIA\n        except:\n            pass\n            \n        return GPUVendor.Unknown\n    \n    def _get_available_devices(self) -> List[int]:\n        \"\"\"Get list of available device indices\"\"\"\n        devices = []\n        \n        if self.gpu_vendor == GPUVendor.NVIDIA or self.gpu_vendor == GPUVendor.AMD:\n            devices = list(range(torch.cuda.device_count()))\n        elif self.gpu_vendor == GPUVendor.MUSA:\n            try:\n                import musa\n                devices = list(range(musa.device_count()))\n            except (ImportError, AttributeError):\n                pass\n            \n        return devices\n    \n    def get_device_str(self, device_id: Union[int, str]) -> str:\n        \"\"\"\n        Get device string for the given device ID\n        \n        Args:\n            device_id: Device index (0, 1, 2, etc.), -1 for CPU, or \"cpu\" string\n            \n        Returns:\n            Device string representation (e.g., \"cuda:0\", \"musa:1\", \"cpu\")\n        \"\"\"\n        if device_id == -1 or device_id == \"cpu\":\n            return \"cpu\"\n            \n        if isinstance(device_id, int):\n            if self.gpu_vendor == GPUVendor.NVIDIA or self.gpu_vendor == GPUVendor.AMD:\n                if device_id < torch.cuda.device_count():\n                    return f\"cuda:{device_id}\"\n            elif self.gpu_vendor == GPUVendor.MUSA:\n                try:\n                    import musa\n                    if device_id < musa.device_count():\n                        return f\"musa:{device_id}\"\n                except (ImportError, AttributeError):\n                    pass\n        \n        return \"cpu\"\n    \n    def to_torch_device(self, device_id: Union[int, str] = 0) -> torch.device:\n        \"\"\"\n        Convert device ID to torch.device object\n        \n        Args:\n            device_id: Device index (0, 1, 2, etc.), -1 
for CPU, or \"cpu\" string\n            \n        Returns:\n            torch.device object\n        \"\"\"\n        device_str = self.get_device_str(device_id)\n        \n        # Handle MUSA device\n        if device_str.startswith(\"musa:\"):\n            try:\n                import musa\n                index = int(device_str.split(\":\")[-1])\n                return musa.device(index)\n            except (ImportError, ValueError, AttributeError):\n                return torch.device(\"cpu\")\n        \n        # Standard PyTorch device\n        return torch.device(device_str)\n    \n    def move_tensor_to_device(self, tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor:\n        \"\"\"\n        Move tensor to specified device\n        \n        Args:\n            tensor: PyTorch tensor to move\n            device_id: Device index (0, 1, 2, etc.), -1 for CPU, or \"cpu\" string\n            \n        Returns:\n            Tensor moved to the specified device\n        \"\"\"\n        device = self.to_torch_device(device_id)\n        return tensor.to(device)\n    \n    def is_available(self, index: int = 0) -> bool:\n        \"\"\"\n        Check if device at specified index is available\n        \n        Args:\n            index: Device index to check\n            \n        Returns:\n            True if the device is available, False otherwise\n        \"\"\"\n        if index < 0:\n            return True  # CPU is always available\n            \n        return index in self.available_devices\n    \n    def get_all_devices(self) -> List[int]:\n        \"\"\"\n        Get all available device indices\n        \n        Returns:\n            List of available device indices (0, 1, 2, etc.)\n        \"\"\"\n        return self.available_devices\n\n# Create global device manager instance\ndevice_manager = DeviceManager()\n\n# Convenience functions\ndef get_device(device_id: Union[int, str] = 0) -> torch.device:\n    \"\"\"\n    Get torch.device 
object for the specified device ID\n    \n    Args:\n        device_id: Device index (0, 1, 2, etc.), -1 for CPU, or \"cpu\" string\n        \n    Returns:\n        torch.device object\n    \"\"\"\n    return device_manager.to_torch_device(device_id)\n\ndef to_device(tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor:\n    \"\"\"\n    Move tensor to specified device\n    \n    Args:\n        tensor: PyTorch tensor to move\n        device_id: Device index (0, 1, 2, etc.), -1 for CPU, or \"cpu\" string\n        \n    Returns:\n        Tensor moved to the specified device\n    \"\"\"\n    return device_manager.move_tensor_to_device(tensor, device_id)\n\n# Example usage, guarded so that importing this module does not allocate tensors or move them to a device\nif __name__ == \"__main__\":\n    # Get devices\n    cpu_device = get_device(-1)        # CPU using index -1\n    cpu_device2 = get_device(\"cpu\")    # CPU using string \"cpu\"\n    gpu0 = get_device(0)               # First GPU (falls back to CPU if unavailable)\n\n    # Move tensors\n    x = torch.randn(3, 3)\n    x_gpu = to_device(x, 0)            # Move to first GPU (or CPU fallback)\n    x_cpu1 = to_device(x, -1)          # Move to CPU using index -1\n    x_cpu2 = to_device(x, \"cpu\")       # Move to CPU using string \"cpu\""
  },
  {
    "path": "archive/ktransformers/util/weight_loader.py",
    "content": "from abc import ABC, abstractmethod\nimport os\nimport torch\nimport numpy as np\nfrom safetensors import safe_open\nfrom typing import Dict, Any, Optional, Union\n\nclass ModelLoader(ABC):\n    \"\"\"\n    Abstract base class for model loaders.\n    Defines the interface that all model loaders must implement.\n    \"\"\"\n    \n    @abstractmethod\n    def load_tensor(self, name: str, device: str = \"cpu\") -> torch.Tensor:\n        \"\"\"\n        Load a tensor by name.\n        \n        Args:\n            name: Name of the tensor to load\n            device: Device to load the tensor to\n            \n        Returns:\n            The loaded tensor\n        \"\"\"\n        pass\n    \n    @classmethod\n    @abstractmethod\n    def supports_format(cls, path: str) -> bool:\n        \"\"\"\n        Check if this loader supports the given path format.\n        \n        Args:\n            path: Path to check\n            \n        Returns:\n            True if this loader supports the given path, False otherwise\n        \"\"\"\n        pass\n\n\nclass SafeTensorLoader(ModelLoader):\n    \"\"\"\n    Loader for SafeTensor format models.\n    \"\"\"\n    \n    def __init__(self, path: str):\n        \"\"\"\n        Initialize the SafeTensor loader.\n        \n        Args:\n            path: Path to the model directory or file\n        \"\"\"\n        self.tensor_file_map = {}  # Maps tensor names to file paths\n        self.file_handle_map = {}  # Maps file names to file handles\n        self._load_tensor_file_map(path)\n    \n    def _load_tensor_file_map(self, path: str) -> None:\n        \"\"\"\n        Load the tensor file map from the given path.\n        \n        Args:\n            path: Path to the model directory or file\n        \"\"\"\n        # Normalize path to directory\n        if not os.path.exists(path):\n            raise FileNotFoundError(f\"Path not found: {path}\")\n        if os.path.isfile(path):\n            folder_path = 
os.path.dirname(path)\n        else:\n            folder_path = path\n\n        found_safetensor = False\n        for root, _, files in os.walk(folder_path):\n            files = sorted(files)\n            for file in files:\n                if file.endswith(\".safetensors\"):\n                    found_safetensor = True\n                    file_path = os.path.join(root, file)\n                    if file not in self.file_handle_map:\n                        try:\n                            handle = safe_open(file_path, framework=\"pt\")\n                            self.file_handle_map[file] = handle\n                        except Exception as e:\n                            print(f\"Error opening Safetensor file {file_path}: {e}\")\n                            continue\n\n                    f = self.file_handle_map.get(file)\n                    if f is None:\n                        continue\n                    try:\n                        for key in f.keys():\n                            self.tensor_file_map[key] = file\n                    except Exception as e:\n                        print(f\"Error reading Safetensor file {file_path}: {e}\")\n\n        if not found_safetensor:\n            # Not raising an error here allows for the factory to try other loaders\n            print(f\"No Safetensor files found in {folder_path}\")\n    \n    def load_tensor(self, name: str, device: str = \"cpu\") -> torch.Tensor:\n        \"\"\"\n        Load a tensor by name.\n        \n        Args:\n            name: Name of the tensor to load\n            device: Device to load the tensor to\n            \n        Returns:\n            The loaded tensor\n        \"\"\"\n        if name not in self.tensor_file_map:\n            raise KeyError(f\"Key {name} not found in Safetensor files\")\n        file = self.tensor_file_map[name]\n        f = self.file_handle_map.get(file)\n        if f is None:\n            raise FileNotFoundError(f\"File {file} not found in 
Safetensor files\")\n        tensor = f.get_tensor(name)\n        return tensor.to(device)\n    \n    def load_dequantized_tensor(self, name: str, device: str = \"cpu\") -> torch.Tensor:\n        \"\"\"\n        Load and dequantize a tensor.\n        \n        Args:\n            name: Name of the tensor to load\n            device: Device to load the tensor to\n            \n        Returns:\n            The dequantized tensor\n        \"\"\"\n        if name not in self.tensor_file_map:\n            raise KeyError(f\"Key {name} not found in Safetensor files\")\n        file = self.tensor_file_map[name]\n        f = self.file_handle_map.get(file)\n        if f is None:\n            raise FileNotFoundError(f\"File {file} not found in Safetensor files\")\n        tensor = f.get_tensor(name).to(device)\n        if name.endswith(\".weight\"):\n            if name[:-7] + \".weight_scale_inv\" in self.tensor_file_map:\n                weight_scale_inv = f.get_tensor(name[:-7] + \".weight_scale_inv\").to(device)\n                # Assuming weight_dequant function is imported\n                from ktransformers.ktransformers_ext.triton.fp8gemm import weight_dequant\n                tensor = weight_dequant(tensor, weight_scale_inv)\n        return tensor.to(device)\n    \n    def close_all_handles(self) -> None:\n        \"\"\"\n        Close all file handles.\n        \"\"\"\n        for handle in self.file_handle_map.values():\n            handle.close()\n        self.file_handle_map.clear()\n\n    @classmethod\n    def supports_format(cls, path: str) -> bool:\n        \"\"\"\n        Check if this loader supports the given path format.\n        \n        Args:\n            path: Path to check\n            \n        Returns:\n            True if safetensor files are found in the path, False otherwise\n        \"\"\"\n        # Normalize path to directory\n        if not os.path.exists(path):\n            return False\n        if os.path.isfile(path):\n            if 
path.endswith(\".safetensors\"):\n                return True\n            folder_path = os.path.dirname(path)\n        else:\n            folder_path = path\n            \n        # Check if any safetensor files exist in the folder\n        for root, _, files in os.walk(folder_path):\n            for file in files:\n                if file.endswith(\".safetensors\"):\n                    return True\n        return False\n\n\nclass GGUFLoader(ModelLoader):\n    \"\"\"\n    Loader for GGUF format models.\n    \"\"\"\n    \n    def __init__(self, path: str):\n        \"\"\"\n        Initialize the GGUF loader.\n        \n        Args:\n            path: Path to the model directory or file\n        \"\"\"\n        # Check if path exists\n        if not os.path.exists(path):\n            raise FileNotFoundError(f\"GGUF dir not found: {path}\")\n        if os.path.isfile(path):\n            self.gguf_path = os.path.dirname(path)\n        else:\n            self.gguf_path = path\n            \n        self.tensor_info = {}  # Stores tensor metadata\n        self.tensor_file_map = {}  # Maps tensor names to file paths\n        self.file_data_map = {}  # Maps file paths to memory-mapped data\n        self.gguf_file_meta = {}  # Stores GGUF metadata\n        \n        # For compatibility with the factory pattern\n        self.safetensor_loader = None\n        \n        # Scan all GGUF files in the directory\n        found_gguf = False\n        for root, _, files in os.walk(self.gguf_path):\n            for file in files:\n                if file.endswith(\".gguf\"):\n                    found_gguf = True\n                    file_path = os.path.join(root, file)\n                    with open(file_path, \"rb\") as f:\n                        self._load_gguf(f)\n                        if file_path not in self.file_data_map:\n                            self.file_data_map[file_path] = np.memmap(file_path, mode='r')\n        \n        if not found_gguf:\n            raise 
FileNotFoundError(f\"Cannot find any .gguf files in: {self.gguf_path}\")\n    \n    def _load_gguf(self, f) -> None:\n        \"\"\"\n        Load GGUF file metadata and tensor info.\n        \n        Args:\n            f: File handle of the GGUF file\n        \"\"\"\n        # Implementation should follow the original GGUFLoader._load_gguf\n        # This is a simplified version for illustration\n        f.seek(0)\n        assert f.read(4) == b'GGUF'\n        \n        # Read header\n        values = struct.unpack(\"<IQQ\", f.read(4+8+8))\n        version, n_tensors, n_kv = values\n        if version != 3:\n            warnings.warn(f\"Version {version} has never been tested, might not work\")\n\n        # Read key-value pairs\n        info = {}\n        for _ in range(n_kv):\n            name = self._read_value(f, 8)  # DATA_TYPES[\"string\"]\n            data_type = struct.unpack(\"<I\", f.read(4))[0]\n            info[name] = self._read_value(f, data_type)\n\n        # Read tensor info\n        tensor_info = {}\n        for _ in range(n_tensors):\n            name = self._read_value(f, 8)  # DATA_TYPES[\"string\"]\n            shape_len = self._read_value(f, 4)  # DATA_TYPES[\"uint32\"]\n            shape = [self._read_value(f, 10) for _ in range(shape_len)]  # DATA_TYPES[\"uint64\"]\n            ggml_type = self._read_value(f, 4)  # DATA_TYPES[\"uint32\"]\n            offset = self._read_value(f, 10)  # DATA_TYPES[\"uint64\"]\n            \n            # Additional tensor metadata would be calculated here\n            # For brevity, we're omitting the detailed tensor metadata calculation\n            tensor_info[name] = {\n                \"ggml_type\": ggml_type,\n                \"shape\": shape,\n                \"offset\": offset,\n                # ... 
other tensor metadata\n            }\n            \n        start = f.tell()\n        alignment = info.get(\"general.alignment\", 32)\n        \n        # Calculate actual file offsets\n        for t in tensor_info.values():\n            offset = start + t[\"offset\"]\n            offset += (alignment - offset % alignment) % alignment\n            t[\"offset\"] = offset\n            \n        # Update file maps\n        for name in tensor_info:\n            self.tensor_file_map[name] = f.name\n            \n        self.tensor_info.update(tensor_info)\n        self.gguf_file_meta.update(info)\n    \n    def _read_value(self, f, data_type) -> Any:\n        \"\"\"\n        Read a value from the file according to its data type.\n        \n        Args:\n            f: File handle\n            data_type: Type of data to read\n            \n        Returns:\n            The read value\n        \"\"\"\n        # Simplified implementation\n        # In a complete implementation, this would handle all data types\n        if data_type == 8:  # DATA_TYPES[\"string\"]\n            length = struct.unpack(\"<Q\", f.read(8))[0]\n            return f.read(length).decode(\"utf-8\")\n        elif data_type == 4:  # DATA_TYPES[\"uint32\"]\n            return struct.unpack(\"<I\", f.read(4))[0]\n        elif data_type == 10:  # DATA_TYPES[\"uint64\"]\n            return struct.unpack(\"<Q\", f.read(8))[0]\n        # ... 
handling for other data types\n        return None\n    \n    def load_tensor(self, name: str, device: str = \"cpu\") -> torch.Tensor:\n        \"\"\"\n        Load a tensor by name.\n        \n        Args:\n            name: Name of the tensor to load\n            device: Device to load the tensor to\n            \n        Returns:\n            The loaded tensor\n        \"\"\"\n        # This should call load_gguf_tensor with the appropriate parameters\n        return self.load_gguf_tensor(name, device)\n    \n    def load_gguf_tensor(self, name: str, device: str = \"cpu\", target_dtype = None) -> torch.Tensor:\n        \"\"\"\n        Load a GGUF tensor by name.\n        \n        Args:\n            name: Name of the tensor to load\n            device: Device to load the tensor to\n            target_dtype: Target data type for the tensor\n            \n        Returns:\n            The loaded tensor\n        \"\"\"\n        # Implementation would follow the original GGUFLoader.load_gguf_tensor\n        # This is a placeholder for illustration\n        if name not in self.tensor_info:\n            raise KeyError(f\"Tensor {name} not found\")\n            \n        # Actual implementation would dequantize the tensor data\n        # and return a torch.Tensor\n        return torch.zeros(1, device=device)  # Placeholder\n    \n    @classmethod\n    def supports_format(cls, path: str) -> bool:\n        \"\"\"\n        Check if this loader supports the given path format.\n        \n        Args:\n            path: Path to check\n            \n        Returns:\n            True if GGUF files are found in the path, False otherwise\n        \"\"\"\n        # Normalize path to directory\n        if not os.path.exists(path):\n            return False\n        if os.path.isfile(path):\n            return path.endswith(\".gguf\")\n        \n        # Check if any GGUF files exist in the folder\n        for root, _, files in os.walk(path):\n            for file in files:\n   
             if file.endswith(\".gguf\"):\n                    return True\n        return False"
  },
  {
    "path": "archive/ktransformers/website/.browserslistrc",
    "content": "> 1%\nlast 2 versions\nnot dead\nnot ie 11\n"
  },
  {
    "path": "archive/ktransformers/website/.eslintrc.js",
    "content": "module.exports = {\n  root: true,\n  env: {\n    node: true\n  },\n  'extends': [\n    'plugin:vue/vue3-essential',\n    'eslint:recommended',\n    '@vue/typescript/recommended'\n  ],\n  parserOptions: {\n    ecmaVersion: 2020\n  },\n  rules: {\n    'no-console': process.env.NODE_ENV === 'production' ? 'warn' : 'off',\n    'no-debugger': process.env.NODE_ENV === 'production' ? 'warn' : 'off'\n  },\n  overrides: [\n    {\n      files: [\n        '**/__tests__/*.{j,t}s?(x)',\n        '**/tests/unit/**/*.spec.{j,t}s?(x)'\n      ],\n      env: {\n        jest: true\n      }\n    }\n  ]\n}\n"
  },
  {
    "path": "archive/ktransformers/website/.gitignore",
    "content": ".DS_Store\nnode_modules\n/dist\n\n\n# local env files\n.env.local\n.env.*.local\n\n# Log files\nnpm-debug.log*\nyarn-debug.log*\nyarn-error.log*\npnpm-debug.log*\n\n# Editor directories and files\n.idea\n.vscode\n*.suo\n*.ntvs*\n*.njsproj\n*.sln\n*.sw?\n"
  },
  {
    "path": "archive/ktransformers/website/README.md",
    "content": "# \n\n## Project setup\n```\nnpm install\n```\n\n### Compiles and hot-reloads for development\n```\nnpm run serve\n```\n\n### Compiles and minifies for production\n```\nnpm run build\n```\n\n### Run your unit tests\n```\nnpm run test:unit\n```\n\n### Lints and fixes files\n```\nnpm run lint\n```\n\n### Customize configuration\nSee [Configuration Reference](https://cli.vuejs.org/config/).\n"
  },
  {
    "path": "archive/ktransformers/website/config.d.ts",
    "content": "declare module '*.js' {\n    const config: {\n      apiUrl: string;\n      port:number;\n    };\n    export { config };\n  }"
  },
  {
    "path": "archive/ktransformers/website/jest.config.js",
    "content": "module.exports = {\n  preset: '@vue/cli-plugin-unit-jest/presets/typescript'\n}\n"
  },
  {
    "path": "archive/ktransformers/website/package.json",
    "content": "{\n  \"name\": \"\",\n  \"version\": \"\",\n  \"private\": true,\n  \"scripts\": {\n    \"serve\": \"vue-cli-service serve\",\n    \"build\": \"vue-cli-service build\",\n    \"test:unit\": \"vue-cli-service test:unit\",\n    \"lint\": \"vue-cli-service lint\"\n  },\n  \"dependencies\": {\n    \"@types/pdfjs-dist\": \"^2.10.378\",\n    \"@types/websocket\": \"^1.0.10\",\n    \"@vue/cli\": \"^5.0.8\",\n    \"ant-design-vue\": \"^4.2.1\",\n    \"apexcharts\": \"^3.49.1\",\n    \"axios\": \"^1.7.0\",\n    \"axios-extensions\": \"^3.1.6\",\n    \"better-scroll\": \"^2.5.1\",\n    \"element-plus\": \"^2.7.3\",\n    \"marked\": \"^12.0.2\",\n    \"marked-highlight\": \"^2.1.1\",\n    \"pdf-lib\": \"^1.17.1\",\n    \"pdfobject\": \"^2.3.0\",\n    \"v-clipboard\": \"^3.0.0-next.1\",\n    \"vue\": \"^3.4.27\",\n    \"vue-i18n\": \"^9.13.1\",\n    \"vue-pdf\": \"^4.3.0\",\n    \"vue-router\": \"^4.0.3\",\n    \"vue3-apexcharts\": \"^1.5.3\",\n    \"vuex\": \"^4.0.0\",\n    \"webpack\": \"^5.91.0\",\n    \"webpack-cli\": \"^5.1.4\",\n    \"websocket\": \"^1.0.35\"\n  },\n  \"devDependencies\": {\n    \"@types/jest\": \"^27.0.1\",\n    \"@types/pdfobject\": \"^2.2.5\",\n    \"@typescript-eslint/eslint-plugin\": \"^5.4.0\",\n    \"@typescript-eslint/parser\": \"^5.4.0\",\n    \"@vue/cli-plugin-eslint\": \"~5.0.0\",\n    \"@vue/cli-plugin-router\": \"~5.0.0\",\n    \"@vue/cli-plugin-typescript\": \"~5.0.0\",\n    \"@vue/cli-plugin-unit-jest\": \"~5.0.0\",\n    \"@vue/cli-plugin-vuex\": \"~5.0.0\",\n    \"@vue/cli-service\": \"~5.0.0\",\n    \"@vue/eslint-config-typescript\": \"^9.1.0\",\n    \"@vue/test-utils\": \"^2.0.0-0\",\n    \"@vue/vue3-jest\": \"^27.0.0-alpha.1\",\n    \"babel-jest\": \"^27.0.6\",\n    \"eslint\": \"^7.32.0\",\n    \"eslint-plugin-vue\": \"^8.0.3\",\n    \"jest\": \"^27.0.5\",\n    \"stylus\": \"^0.55.0\",\n    \"stylus-loader\": \"^6.1.0\",\n    \"ts-jest\": \"^27.0.4\",\n    \"typescript\": \"~4.5.5\"\n  },\n  \"_id\": \"@\",\n  
\"readme\": \"ERROR: No README data found!\"\n}\n"
  },
  {
    "path": "archive/ktransformers/website/public/config.js",
    "content": "window.configWeb = {\n    apiUrl: 'http://119.255.238.12:15670/v1',\n    port: 8080,\n  };"
  },
  {
    "path": "archive/ktransformers/website/public/css/reset.css",
    "content": "html, body, div, span, applet, object, iframe,\nh1, h2, h3, h4, h5, h6, p, blockquote, pre,\na, abbr, acronym, address, big, cite, code,\ndel, dfn, em, img, ins, kbd, q, s, samp,\nsmall, strike, strong, sub, sup, tt, var,\nb, u, i, center,\ndl, dt, dd, ol, ul, li,\nfieldset, form, label, legend,textarea,\ntable, caption, tbody, tfoot, thead, tr, th, td,\narticle, aside, canvas, details, embed,\nfigure, figcaption, footer, header, hgroup,\nmenu, nav, output, ruby, section, summary,\ntime, mark, audio, video {\n    margin: 0;\n    padding: 0;\n    border: 0;\n    font-size: 100%;\n    *font: inherit;\n    font-family: Arial, Microsoft YaHei, SimHei, Tahoma, sans-serif !important;\n    vertical-align: baseline;\n}\n/* HTML5 display-role reset for older browsers */\narticle, aside, details, figcaption, figure,\nfooter, header, hgroup, menu, nav, section {\n    display: block;\n}\nbody {\n    line-height: 1;\n    -webkit-text-size-adjust: 100%!important;\n    margin: 0;\n}\nhtml,body {\n    height: 100%;\n    width: 100%;\n    overflow: hidden;\n}\nol, ul {\n    list-style: none;\n}\nblockquote, q {\n    quotes: none;\n}\nblockquote:before, blockquote:after,\nq:before, q:after {\n    content: '';\n    content: none;\n}\ntable {\n    border-collapse: collapse;\n    border-spacing: 0;\n}\n\n.clearfix:before,\n.clearfix:after {\n    content:\"\";\n    display:table\n}\n.clearfix:after {\n    clear:both\n}\n\n/*显示省略号*/\n.ellipsis{\n    overflow: hidden;\n    text-overflow: ellipsis;\n    white-space: nowrap;\n}\n"
  },
  {
    "path": "archive/ktransformers/website/public/index.html",
    "content": "<!DOCTYPE html>\n<html lang=\"\">\n  <head>\n    <meta charset=\"utf-8\">\n    <meta http-equiv=\"X-UA-Compatible\" content=\"IE=edge\">\n    <meta name=\"viewport\" content=\"width=device-width,initial-scale=1.0,maximum-scale=1.0,minimum-scale=1.0,user-scalable=no\">\n    <script src=\"./config.js\"></script>\n    <link rel=\"icon\" href=\"./balck.ico\" />\n    <link type=\"text/css\" rel=\"stylesheet\" href=\"<%= BASE_URL %>/css/reset.css\">\n    <title>KTransformers</title>\n  </head>\n  <body onselectstart='return false' onselect='return false'>\n    <noscript>\n      <strong>We're sorry but <%= htmlWebpackPlugin.options.title %> doesn't work properly without JavaScript enabled. Please enable it to continue.</strong>\n    </noscript>\n    <div id=\"app\"></div>\n    <!-- built files will be auto injected -->\n  </body>\n</html>\n"
  },
  {
    "path": "archive/ktransformers/website/src/App.vue",
    "content": "<template>\n  <div class=\"app-container\" @contextmenu.prevent.stop=\"\">\n    <keep-alive>\n      <router-view/>\n    </keep-alive>\n  </div>\n</template>\n\n<script setup lang=\"ts\">\n</script>\n\n<style lang=\"stylus\">\n  @import \"assets/iconfont/iconfont.css\"\n  #app\n  .app-container\n    width: 100%\n    height: 100%\n    position: relative\n</style>"
  },
  {
    "path": "archive/ktransformers/website/src/api/api-client.ts",
    "content": "import axios, { AxiosInstance } from 'axios';\nimport {baseURL} from '@/conf/config';\nconst apiClient: AxiosInstance = axios.create({\n    baseURL: baseURL,\n    // baseURL: '/api',\n    headers: {\n        'Content-Type': 'application/json',\n    },\n    withCredentials: true,\n});\nexport default apiClient;\n"
  },
  {
    "path": "archive/ktransformers/website/src/api/assistant.ts",
    "content": "import apiClient from './api-client';\nimport { IAssistant,IDeleteResult, IAssistantWithStatus } from '../utils/types';\nfunction filterAndConvert(\n    assistantsWithStatus: IAssistantWithStatus[],\n    statusCondition: string\n  ): IAssistant[] {\n    return assistantsWithStatus\n      .filter((assistant) => assistant.build_status.status === statusCondition)\n      .map(({ build_status, ...rest }) => rest);\n  }\n\ninterface IAssistantData {\n    model: string;\n    prefix_system_prompt?: string;\n    suffix_system_prompt?: string;\n    name?: string;\n    description?: string;\n    tools?: any[];\n    tool_resources?: object;\n    metadata?:{[key:string]:any}\n    top_p?: number;\n    temperature?: number;\n    response_format?: string;\n    instructions?: string;\n}\n\nexport const createAssistant = async (data: IAssistantData): Promise<IAssistant> => {\n    const assistant_data: {\n        model: string;\n        instructions?: string;\n        name?: string;\n        description?: string;\n        tools?: any[];\n        tool_resources?: object;\n        metadata?:{[key:string]:any}\n        top_p?: number;\n        temperature?: number;\n        response_format?: string;\n    } = {\n        model: data.model\n    };\n\n    if (data.prefix_system_prompt) {\n        assistant_data.instructions = data.prefix_system_prompt;\n    }\n    if (data.suffix_system_prompt) {\n        assistant_data.instructions = data.suffix_system_prompt;\n    }\n    if (data.name) {\n        assistant_data.name = data.name;\n    }\n    if (data.description) {\n        assistant_data.description = data.description;\n    }\n    if (data.tools) {\n        assistant_data.tools = data.tools;\n    }\n    if (data.tool_resources) {\n        assistant_data.tool_resources = data.tool_resources;\n    }\n    if (data.metadata) {\n        assistant_data.metadata = data.metadata\n    }\n    if (typeof data.top_p !== 'undefined') {\n        assistant_data.top_p = data.top_p;\n    
}\n    if (typeof data.temperature !== 'undefined') {\n        assistant_data.temperature = data.temperature;\n    }\n    if (data.response_format) {\n        assistant_data.response_format = data.response_format;\n    }\n    if (data.instructions) {\n        assistant_data.instructions = data.instructions;\n    }\n    console.log(assistant_data)\n    const response = await apiClient.post<IAssistant>(\n        '/assistants/',\n        assistant_data\n    );\n    console.log(\"response\", response)\n    return response.data;\n};\n\n\nexport const listAssistants = async (\n    limit?: number,\n    order?: string,\n    after?: string,\n    before?: string,\n    run_id?: string,\n): Promise<IAssistant[]> => {\n    const params: {\n        limit?: number,\n        order?: string,\n        after?: string,\n        before?: string,\n        run_id?: string\n    } = {};\n\n    if (typeof limit !== 'undefined') {\n        params.limit = limit;\n    }\n    if (typeof order !== 'undefined') {\n        params.order = order;\n    }\n    if (typeof after !== 'undefined') {\n        params.after = after;\n    }\n    if (typeof before !== 'undefined') {\n        params.before = before;\n    }\n    if (typeof run_id !== 'undefined') {\n        params.run_id = run_id;\n    }\n    const response = await apiClient.get<IAssistantWithStatus[]>('/assistants/status', {\n        params\n    });\n    let tmp = response.data\n    let result = [] as IAssistant[]\n    const filteredAssistants = filterAndConvert(tmp, 'completed');\n    return filteredAssistants\n};\n\nexport const getAssistant = async (\n    assistant_id: string\n): Promise<IAssistant> => {\n    const response = await apiClient.get<IAssistant>(`/assistants/${assistant_id}`);\n    return response.data;\n}\n\nexport const deleteAssistant = async (\n    assistant_id: string\n): Promise<IDeleteResult> => {\n    const response = await apiClient.delete<IDeleteResult>(`/assistants/${assistant_id}`);\n    return 
response.data;\n}\n\nexport const getRelatedThreadId = async (\n    assistant_id: string\n): Promise<string[]> => {\n    const response = await apiClient.get<string[]>(`/assistants/${assistant_id}/related_thread`);\n    return response.data;\n}\n\nexport const listAssistantsWithStatus = async (\n    limit?: number,\n    order?: string,\n    after?: string,\n    before?: string,\n    run_id?: string,\n): Promise<IAssistantWithStatus[]> => {\n    const params: {\n        limit?: number,\n        order?: string,\n        after?: string,\n        before?: string,\n        run_id?: string\n    } = {};\n\n    if (typeof limit !== 'undefined') {\n        params.limit = limit;\n    }\n    if (typeof order !== 'undefined') {\n        params.order = order;\n    }\n    if (typeof after !== 'undefined') {\n        params.after = after;\n    }\n    if (typeof before !== 'undefined') {\n        params.before = before;\n    }\n    if (typeof run_id !== 'undefined') {\n        params.run_id = run_id;\n    }\n    console.log(params)\n    const response = await apiClient.get<IAssistantWithStatus[]>('/assistants/status', {\n        params\n    });\n\n    return response.data;\n};\n\n\n"
  },
  {
    "path": "archive/ktransformers/website/src/api/message.ts",
    "content": "import apiClient from './api-client';\nimport { IMessage,IDeleteResult } from '../utils/types';\n\nexport const createMessage = async (\n    thread_id: string,\n    content: string,\n    role?: string,\n    attachments?: any[],\n    metadata?:{[key:string]:any}\n): Promise<IMessage> => {\n    const message_data: {\n        content: string;\n        role?: string;\n        attachments?: any[];\n        metadata?:{[key:string]:any}\n    } = {\n        content,\n    };\n\n    if (metadata) {\n        message_data.metadata = metadata;\n    }\n    if (role) {\n        message_data.role = role;\n    }\n    if (attachments) {\n        message_data.attachments = attachments;\n    }\n    const response = await apiClient.post<IMessage>(`/threads/${thread_id}/messages`, message_data);\n    return response.data;\n};\n\n\nexport const listMessages = async (\n    thread_id: string,\n    limit?: number,\n    order?: string,\n    after?: string,\n    before?: string,\n    run_id?: string,\n): Promise<IMessage[]> => {\n    const params: {\n        limit?: number,\n        order?: string,\n        after?: string,\n        before?: string,\n        run_id?: string\n    } = {};\n\n    if (typeof limit !== 'undefined') {\n        params.limit = limit;\n    }\n    if (typeof order !== 'undefined') {\n        params.order = order;\n    }\n    if (typeof after !== 'undefined') {\n        params.after = after;\n    }\n    if (typeof before !== 'undefined') {\n        params.before = before;\n    }\n    if (typeof run_id !== 'undefined') {\n        params.run_id = run_id;\n    }\n\n    const response = await apiClient.get<IMessage[]>(`/threads/${thread_id}/messages`, {\n        params\n    });\n\n    return response.data;\n};\nexport const deleteMessage = async(thread_id:string, message_id:string): Promise<IDeleteResult> => {\n    const response = await apiClient.delete<IDeleteResult>(`/threads/${thread_id}/messages/${message_id}`);\n    return response.data;\n}\n"
  },
  {
    "path": "archive/ktransformers/website/src/api/run.ts",
    "content": "import apiClient from './api-client';\nimport { IRun } from '../utils/types';\nimport {baseURL} from '@/conf/config';\ninterface IRunData {\n    assistant_id: string;\n    model?: string;\n    instructions?: string;\n    additional_instructions?: string;\n    additional_messages?: any[];\n    tools?: any[];\n    metadata?: { [key: string]: any }\n    temperature?: number;\n    top_p?: number;\n    stream?: boolean;\n    max_prompt_tokens?: number;\n    max_completion_tokens?: number;\n    truncation_strategy?: object;\n    tool_choice?: string;\n    response_format?: string | object;\n}\n\n\nexport async function* createRun(\n    data: IRunData,\n    thread_id: string\n): AsyncGenerator<string> {\n    const run_data = {\n        ...data, \n        assistant_id: data.assistant_id, \n    };\n\n    const response = await fetch(`${baseURL}/threads/${thread_id}/runs`, {\n        method: 'POST',\n        headers: {\n            'Content-Type': 'application/json',\n        },\n        body: JSON.stringify(run_data),\n    });\n\n    if (!response.ok) {\n        throw new Error(`HTTP error! 
status: ${response.status}`);\n    }\n\n    if (!response.body) {\n        throw new Error('Response body is missing');\n    }\n    const reader = response.body.getReader();\n    const decoder = new TextDecoder();\n    let buffer = '';\n    try {\n        while (true) {\n            const { done, value } = await reader.read();\n            if (done) return;\n            buffer += decoder.decode(value, { stream: true });\n\n            let eventIndex = buffer.indexOf(\"\\n\\n\");\n            while (eventIndex !== -1) {\n                const event = buffer.slice(0, eventIndex);\n                buffer = buffer.slice(eventIndex + 2);\n                if (event.startsWith(\"event: thread.run.created\")) {\n                    const dataIndex = event.indexOf(\"data: \");\n                    if (dataIndex !== -1) {\n                        const datads = event.slice(39, 75)\n                        yield datads;\n                    }\n                } else if (event.startsWith(\"event: thread.message.delta\")) {\n                    const dataIndex = event.indexOf(\"data: \");\n                    if (dataIndex !== -1) {\n                        const data = JSON.parse(event.slice(dataIndex + 6));\n                        yield data.delta.content[0].text.value || '';\n                    }\n                } else if (event.startsWith(\"event: done\")) {\n                    return;\n                }\n\n                eventIndex = buffer.indexOf(\"\\n\\n\");\n            }\n        }\n    } catch (e) {\n\n        console.error('An error occurred while reading the response stream:', e);\n        // throw e; \n        return e\n    }\n}\n// Function that cancels a run\nexport async function cancelRun(threadId: string, runId: string){\n    const run_data = {\n        thread_id:threadId,\n        run_id:runId,\n    };\n    try {\n        const response = await fetch(`${baseURL}/threads/${threadId}/runs/${runId}/cancel`, {\n            method: 'POST',\n        });\n\n        if 
(!response.ok) {\n            throw new Error(`HTTP error! status: ${response.status}`);\n        }\n\n        return response;\n    } catch (error) {\n        console.error('An error occurred while cancelling the run:', error);\n        throw error;\n    }\n}"
  },
  {
    "path": "archive/ktransformers/website/src/api/thread.ts",
    "content": "import apiClient from './api-client';\nimport { IThread, IMessage, IThreadAndMessageAndAssistant, IDeleteResult } from '../utils/types';\nexport const createThread = async (\n    message?: IMessage,\n    tool_resources?: object,\n    metadata?: { [key: string]: any }\n): Promise<IThread> => {\n    const thread_data: { message?: object, metadata?: { [key: string]: any } } = {};\n    if (message) {\n        thread_data.message = message;\n    }\n    if (metadata) {\n        thread_data.metadata = metadata;\n    }\n    const response = await apiClient.post<IThread>(\n        '/threads',\n        thread_data);\n    return response.data;\n};\n\nexport const listThreads = async (\n    limit?: number,\n    order?: string,\n): Promise<IThreadAndMessageAndAssistant[]> => {\n    const params: {\n        limit?: number,\n        order?: string,\n    } = { limit, order };\n    const response = await apiClient.get<IThreadAndMessageAndAssistant[]>('/threads', {\n        params\n    });\n\n    return response.data;\n};\n\nexport const deleteThread = async (\n    thread_id: string\n): Promise<IDeleteResult> => {\n    const response = await apiClient.delete<IDeleteResult>(`/threads/${thread_id}`);\n    return response.data;\n}\n\nexport const getThread = async (\n    thread_id: string\n): Promise<IThread> => {\n    const response = await apiClient.get<IThread>(`/threads/${thread_id}`);\n    return response.data;\n}"
  },
  {
    "path": "archive/ktransformers/website/src/assets/css/mixins.styl",
    "content": "\n/*Define color variables*/\n$bg_gray_light_normal = #F9F9F9\n$bg_gray_light_hover = #E8E8E8\n$bg_gray_light_active = #E8E8E8\n\n$border_gray_light_normal = rgba(0, 0, 0, .15)\n$border_gray_light_hover = #8080FF\n\n$gray_20 = #333333\n$gray_40 = #585858\n$gray_50 = #7F7F7F\n$gray_60 = #9F9F9F\n$gray_70 = #BFBFBF\n$gray_80 = #DFDFDF\n$gray_85 = #F2F2F2\n$gray_90 = #F7F7F7\n\n$gray = #53525B\n$gray_dark = #42414a\n$gray_hover = #121212\n$gray_action = #6C757D\n\n$primary = #409eff\n$primary_hover = #428bca\n$primary_middle = #9DDDF9\n$primary_light = #D4F0FC\n\n$cyan = #66CCCC\n$cyan_hover = #46C2C2\n\n\n/*Define common modules*/\n$input-duration = .25s\ninput-border()\n  -webkit-transition: border-color ease-in-out $input-duration,-webkit-box-shadow ease-in-out $input-duration\n  -o-transition: border-color ease-in-out $input-duration,box-shadow ease-in-out $input-duration\n  transition: border-color ease-in-out $input-duration,box-shadow ease-in-out $input-duration\ninput-focus()\n  border-color: #66afe9\n  outline: 0\n  z-index: 100\n  -webkit-box-shadow: inset 0 1px 1px rgba(0,0,0,.075),0 0 8px rgba(102,175,233,.6)\n  box-shadow: inset 0 1px 1px rgba(0,0,0,.075),0 0 8px rgba(102,175,233,.6)\n\n\n/*Define common class*/\n.flex-column\n  display: -webkit-box\n  display: -webkit-flex\n  display: flex\n  box-sizing: border-box\n  -webkit-box-orient: vertical\n  -webkit-box-direction: normal\n  -webkit-flex-direction: column\n  flex-direction: column\n  height: 100%\n\n.flex-row\n  position: relative\n  display: -webkit-box\n  display: -ms-flexbox\n  display: flex\n  box-sizing: border-box\n  -webkit-box-align: center\n  -ms-flex-align: center\n  align-items: center\n\n.flex-unit\n  -webkit-box-flex: 1\n  -ms-flex: 1\n  flex: 1\n  // overflow: hidden\n\n.clearfix\n  &:after\n    clear: both\n    content: \"\\20\"\n    display: block\n    height: 0\n    visibility: hidden\n\na,a:hover\n  text-decoration:none\n\nbutton:focus\n  outline: none\n\n.btn\n  
display: inline-block\n  margin-bottom: 0\n  padding:0px 15px\n  font-size: 14px\n  height: 34px\n  line-height: 32px\n  float: left /* Remove the whitespace between inline-block elements */\n  font-weight: normal\n  text-align: center\n  white-space: nowrap\n  vertical-align: middle\n  cursor: pointer\n  background-image: none\n  border-radius: 3px\n  -webkit-user-select: none\n  -moz-user-select: none\n  -ms-user-select: none\n  -o-user-select: none\n  user-select: none\n  &:hover\n    .dropdown-list\n      display: block\n  i\n    font-size: 16px\n  .text\n    float: right\n    margin-left: 3px\n\n.btn-gray\n  color: $gray_action\n  background-color: #FFFFFF\n  border: 1px solid $gray_action\n  &:not(.is-disabled):hover\n    color: #FFFFFF\n    background-color: $gray_action\n    border: 1px solid $gray_action\n\n.btn-primary\n  color: #FFFFFF\n  background-color: $primary\n  border: 1px solid $primary\n  &:not(.is-disabled):hover\n    color: #FFFFFF\n    background-color: $primary_hover\n    border: 1px solid $primary_hover\n\n.chat-box\n  position: relative\n  .chat-input\n    border: 1px solid $border_gray_light_normal\n    height: 48px\n    line-height: 48px\n    font-size: 16px\n    outline: 0\n    box-sizing: border-box\n    padding: 0 30px 0 20px\n    color: #7F7F7F\n    width: 800px\n    border-radius: 12px\n    position: relative\n    &:focus\n      input-focus()\n  i\n    position: absolute\n    font-size: 26px\n    right: 13px\n    bottom:0px\n    color: $border_gray_light_normal\n    z-index: 100\n    cursor: pointer\n    &:hover\n      color: $border_gray_light_hover\n"
  },
  {
    "path": "archive/ktransformers/website/src/assets/iconfont/demo.css",
    "content": "/* Logo 字体 */\n@font-face {\n  font-family: \"iconfont logo\";\n  src: url('https://at.alicdn.com/t/font_985780_km7mi63cihi.eot?t=1545807318834');\n  src: url('https://at.alicdn.com/t/font_985780_km7mi63cihi.eot?t=1545807318834#iefix') format('embedded-opentype'),\n    url('https://at.alicdn.com/t/font_985780_km7mi63cihi.woff?t=1545807318834') format('woff'),\n    url('https://at.alicdn.com/t/font_985780_km7mi63cihi.ttf?t=1545807318834') format('truetype'),\n    url('https://at.alicdn.com/t/font_985780_km7mi63cihi.svg?t=1545807318834#iconfont') format('svg');\n}\n\n.logo {\n  font-family: \"iconfont logo\";\n  font-size: 160px;\n  font-style: normal;\n  -webkit-font-smoothing: antialiased;\n  -moz-osx-font-smoothing: grayscale;\n}\n\n/* tabs */\n.nav-tabs {\n  position: relative;\n}\n\n.nav-tabs .nav-more {\n  position: absolute;\n  right: 0;\n  bottom: 0;\n  height: 42px;\n  line-height: 42px;\n  color: #666;\n}\n\n#tabs {\n  border-bottom: 1px solid #eee;\n}\n\n#tabs li {\n  cursor: pointer;\n  width: 100px;\n  height: 40px;\n  line-height: 40px;\n  text-align: center;\n  font-size: 16px;\n  border-bottom: 2px solid transparent;\n  position: relative;\n  z-index: 1;\n  margin-bottom: -1px;\n  color: #666;\n}\n\n\n#tabs .active {\n  border-bottom-color: #f00;\n  color: #222;\n}\n\n.tab-container .content {\n  display: none;\n}\n\n/* 页面布局 */\n.main {\n  padding: 30px 100px;\n  width: 960px;\n  margin: 0 auto;\n}\n\n.main .logo {\n  color: #333;\n  text-align: left;\n  margin-bottom: 30px;\n  line-height: 1;\n  height: 110px;\n  margin-top: -50px;\n  overflow: hidden;\n  *zoom: 1;\n}\n\n.main .logo a {\n  font-size: 160px;\n  color: #333;\n}\n\n.helps {\n  margin-top: 40px;\n}\n\n.helps pre {\n  padding: 20px;\n  margin: 10px 0;\n  border: solid 1px #e7e1cd;\n  background-color: #fffdef;\n  overflow: auto;\n}\n\n.icon_lists {\n  width: 100% !important;\n  overflow: hidden;\n  *zoom: 1;\n}\n\n.icon_lists li {\n  width: 100px;\n  margin-bottom: 10px;\n 
 margin-right: 20px;\n  text-align: center;\n  list-style: none !important;\n  cursor: default;\n}\n\n.icon_lists li .code-name {\n  line-height: 1.2;\n}\n\n.icon_lists .icon {\n  display: block;\n  height: 100px;\n  line-height: 100px;\n  font-size: 42px;\n  margin: 10px auto;\n  color: #333;\n  -webkit-transition: font-size 0.25s linear, width 0.25s linear;\n  -moz-transition: font-size 0.25s linear, width 0.25s linear;\n  transition: font-size 0.25s linear, width 0.25s linear;\n}\n\n.icon_lists .icon:hover {\n  font-size: 100px;\n}\n\n.icon_lists .svg-icon {\n  /* 通过设置 font-size 来改变图标大小 */\n  width: 1em;\n  /* 图标和文字相邻时，垂直对齐 */\n  vertical-align: -0.15em;\n  /* 通过设置 color 来改变 SVG 的颜色/fill */\n  fill: currentColor;\n  /* path 和 stroke 溢出 viewBox 部分在 IE 下会显示\n      normalize.css 中也包含这行 */\n  overflow: hidden;\n}\n\n.icon_lists li .name,\n.icon_lists li .code-name {\n  color: #666;\n}\n\n/* markdown 样式 */\n.markdown {\n  color: #666;\n  font-size: 14px;\n  line-height: 1.8;\n}\n\n.highlight {\n  line-height: 1.5;\n}\n\n.markdown img {\n  vertical-align: middle;\n  max-width: 100%;\n}\n\n.markdown h1 {\n  color: #404040;\n  font-weight: 500;\n  line-height: 40px;\n  margin-bottom: 24px;\n}\n\n.markdown h2,\n.markdown h3,\n.markdown h4,\n.markdown h5,\n.markdown h6 {\n  color: #404040;\n  margin: 1.6em 0 0.6em 0;\n  font-weight: 500;\n  clear: both;\n}\n\n.markdown h1 {\n  font-size: 28px;\n}\n\n.markdown h2 {\n  font-size: 22px;\n}\n\n.markdown h3 {\n  font-size: 16px;\n}\n\n.markdown h4 {\n  font-size: 14px;\n}\n\n.markdown h5 {\n  font-size: 12px;\n}\n\n.markdown h6 {\n  font-size: 12px;\n}\n\n.markdown hr {\n  height: 1px;\n  border: 0;\n  background: #e9e9e9;\n  margin: 16px 0;\n  clear: both;\n}\n\n.markdown p {\n  margin: 1em 0;\n}\n\n.markdown>p,\n.markdown>blockquote,\n.markdown>.highlight,\n.markdown>ol,\n.markdown>ul {\n  width: 80%;\n}\n\n.markdown ul>li {\n  list-style: circle;\n}\n\n.markdown>ul li,\n.markdown blockquote ul>li {\n  margin-left: 20px;\n  
padding-left: 4px;\n}\n\n.markdown>ul li p,\n.markdown>ol li p {\n  margin: 0.6em 0;\n}\n\n.markdown ol>li {\n  list-style: decimal;\n}\n\n.markdown>ol li,\n.markdown blockquote ol>li {\n  margin-left: 20px;\n  padding-left: 4px;\n}\n\n.markdown code {\n  margin: 0 3px;\n  padding: 0 5px;\n  background: #eee;\n  border-radius: 3px;\n}\n\n.markdown strong,\n.markdown b {\n  font-weight: 600;\n}\n\n.markdown>table {\n  border-collapse: collapse;\n  border-spacing:0;\n  empty-cells: show;\n  border: 1px solid #e9e9e9;\n  width: 95%;\n  margin-bottom: 24px;\n}\n\n.markdown>table th {\n  white-space: nowrap;\n  color: #333;\n  font-weight: 600;\n}\n\n.markdown>table th,\n.markdown>table td {\n  border: 1px solid #e9e9e9;\n  padding: 8px 16px;\n  text-align: left;\n}\n\n.markdown>table th {\n  background: #F7F7F7;\n}\n\n.markdown blockquote {\n  font-size: 90%;\n  color: #999;\n  border-left: 4px solid #e9e9e9;\n  padding-left: 0.8em;\n  margin: 1em 0;\n}\n\n.markdown blockquote p {\n  margin: 0;\n}\n\n.markdown .anchor {\n  opacity: 0;\n  transition: opacity 0.3s ease;\n  margin-left: 8px;\n}\n\n.markdown .waiting {\n  color: #ccc;\n}\n\n.markdown h1:hover .anchor,\n.markdown h2:hover .anchor,\n.markdown h3:hover .anchor,\n.markdown h4:hover .anchor,\n.markdown h5:hover .anchor,\n.markdown h6:hover .anchor {\n  opacity: 1;\n  display: inline-block;\n}\n\n.markdown>br,\n.markdown>p>br {\n  clear: both;\n}\n\n\n.hljs {\n  display: block;\n  background: white;\n  padding: 0.5em;\n  color: #333333;\n  overflow-x: auto;\n}\n\n.hljs-comment,\n.hljs-meta {\n  color: #969896;\n}\n\n.hljs-string,\n.hljs-variable,\n.hljs-template-variable,\n.hljs-strong,\n.hljs-emphasis,\n.hljs-quote {\n  color: #df5000;\n}\n\n.hljs-keyword,\n.hljs-selector-tag,\n.hljs-type {\n  color: #a71d5d;\n}\n\n.hljs-literal,\n.hljs-symbol,\n.hljs-bullet,\n.hljs-attribute {\n  color: #0086b3;\n}\n\n.hljs-section,\n.hljs-name {\n  color: #63a35c;\n}\n\n.hljs-tag {\n  color: 
#333333;\n}\n\n.hljs-title,\n.hljs-attr,\n.hljs-selector-id,\n.hljs-selector-class,\n.hljs-selector-attr,\n.hljs-selector-pseudo {\n  color: #795da3;\n}\n\n.hljs-addition {\n  color: #55a532;\n  background-color: #eaffea;\n}\n\n.hljs-deletion {\n  color: #bd2c00;\n  background-color: #ffecec;\n}\n\n.hljs-link {\n  text-decoration: underline;\n}\n\n/* 代码高亮 */\n/* PrismJS 1.15.0\nhttps://prismjs.com/download.html#themes=prism&languages=markup+css+clike+javascript */\n/**\n * prism.js default theme for JavaScript, CSS and HTML\n * Based on dabblet (http://dabblet.com)\n * @author Lea Verou\n */\ncode[class*=\"language-\"],\npre[class*=\"language-\"] {\n  color: black;\n  background: none;\n  text-shadow: 0 1px white;\n  font-family: Consolas, Monaco, 'Andale Mono', 'Ubuntu Mono', monospace;\n  text-align: left;\n  white-space: pre;\n  word-spacing: normal;\n  word-break: normal;\n  word-wrap: normal;\n  line-height: 1.5;\n\n  -moz-tab-size: 4;\n  -o-tab-size: 4;\n  tab-size: 4;\n\n  -webkit-hyphens: none;\n  -moz-hyphens: none;\n  -ms-hyphens: none;\n  hyphens: none;\n}\n\npre[class*=\"language-\"]::-moz-selection,\npre[class*=\"language-\"] ::-moz-selection,\ncode[class*=\"language-\"]::-moz-selection,\ncode[class*=\"language-\"] ::-moz-selection {\n  text-shadow: none;\n  background: #b3d4fc;\n}\n\npre[class*=\"language-\"]::selection,\npre[class*=\"language-\"] ::selection,\ncode[class*=\"language-\"]::selection,\ncode[class*=\"language-\"] ::selection {\n  text-shadow: none;\n  background: #b3d4fc;\n}\n\n@media print {\n\n  code[class*=\"language-\"],\n  pre[class*=\"language-\"] {\n    text-shadow: none;\n  }\n}\n\n/* Code blocks */\npre[class*=\"language-\"] {\n  padding: 1em;\n  margin: .5em 0;\n  overflow: auto;\n}\n\n:not(pre)>code[class*=\"language-\"],\npre[class*=\"language-\"] {\n  background: #f5f2f0;\n}\n\n/* Inline code */\n:not(pre)>code[class*=\"language-\"] {\n  padding: .1em;\n  border-radius: .3em;\n  white-space: 
normal;\n}\n\n.token.comment,\n.token.prolog,\n.token.doctype,\n.token.cdata {\n  color: slategray;\n}\n\n.token.punctuation {\n  color: #999;\n}\n\n.namespace {\n  opacity: .7;\n}\n\n.token.property,\n.token.tag,\n.token.boolean,\n.token.number,\n.token.constant,\n.token.symbol,\n.token.deleted {\n  color: #905;\n}\n\n.token.selector,\n.token.attr-name,\n.token.string,\n.token.char,\n.token.builtin,\n.token.inserted {\n  color: #690;\n}\n\n.token.operator,\n.token.entity,\n.token.url,\n.language-css .token.string,\n.style .token.string {\n  color: #9a6e3a;\n  background: hsla(0, 0%, 100%, .5);\n}\n\n.token.atrule,\n.token.attr-value,\n.token.keyword {\n  color: #07a;\n}\n\n.token.function,\n.token.class-name {\n  color: #DD4A68;\n}\n\n.token.regex,\n.token.important,\n.token.variable {\n  color: #e90;\n}\n\n.token.important,\n.token.bold {\n  font-weight: bold;\n}\n\n.token.italic {\n  font-style: italic;\n}\n\n.token.entity {\n  cursor: help;\n}\n"
  },
  {
    "path": "archive/ktransformers/website/src/assets/iconfont/demo_index.html",
    "content": "<!DOCTYPE html>\n<html>\n<head>\n  <meta charset=\"utf-8\"/>\n  <title>iconfont Demo</title>\n  <link rel=\"shortcut icon\" href=\"//img.alicdn.com/imgextra/i4/O1CN01Z5paLz1O0zuCC7osS_!!6000000001644-55-tps-83-82.svg\" type=\"image/x-icon\"/>\n  <link rel=\"icon\" type=\"image/svg+xml\" href=\"//img.alicdn.com/imgextra/i4/O1CN01Z5paLz1O0zuCC7osS_!!6000000001644-55-tps-83-82.svg\"/>\n  <link rel=\"stylesheet\" href=\"https://g.alicdn.com/thx/cube/1.3.2/cube.min.css\">\n  <link rel=\"stylesheet\" href=\"demo.css\">\n  <link rel=\"stylesheet\" href=\"iconfont.css\">\n  <script src=\"iconfont.js\"></script>\n  <!-- jQuery -->\n  <script src=\"https://a1.alicdn.com/oss/uploads/2018/12/26/7bfddb60-08e8-11e9-9b04-53e73bb6408b.js\"></script>\n  <!-- 代码高亮 -->\n  <script src=\"https://a1.alicdn.com/oss/uploads/2018/12/26/a3f714d0-08e6-11e9-8a15-ebf944d7534c.js\"></script>\n  <style>\n    .main .logo {\n      margin-top: 0;\n      height: auto;\n    }\n\n    .main .logo a {\n      display: flex;\n      align-items: center;\n    }\n\n    .main .logo .sub-title {\n      margin-left: 0.5em;\n      font-size: 22px;\n      color: #fff;\n      background: linear-gradient(-45deg, #3967FF, #B500FE);\n      -webkit-background-clip: text;\n      -webkit-text-fill-color: transparent;\n    }\n  </style>\n</head>\n<body>\n  <div class=\"main\">\n    <h1 class=\"logo\"><a href=\"https://www.iconfont.cn/\" title=\"iconfont 首页\" target=\"_blank\">\n      <img width=\"200\" src=\"https://img.alicdn.com/imgextra/i3/O1CN01Mn65HV1FfSEzR6DKv_!!6000000000514-55-tps-228-59.svg\">\n      \n    </a></h1>\n    <div class=\"nav-tabs\">\n      <ul id=\"tabs\" class=\"dib-box\">\n        <li class=\"dib active\"><span>Unicode</span></li>\n        <li class=\"dib\"><span>Font class</span></li>\n        <li class=\"dib\"><span>Symbol</span></li>\n      </ul>\n      \n      <a href=\"https://www.iconfont.cn/manage/index?manage_type=myprojects&projectId=4550268\" target=\"_blank\" 
class=\"nav-more\">查看项目</a>\n      \n    </div>\n    <div class=\"tab-container\">\n      <div class=\"content unicode\" style=\"display: block;\">\n          <ul class=\"icon_lists dib-box\">\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe8b0;</span>\n                <div class=\"name\">复制</div>\n                <div class=\"code-name\">&amp;#xe8b0;</div>\n              </li>\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe85e;</span>\n                <div class=\"name\">箭头下</div>\n                <div class=\"code-name\">&amp;#xe85e;</div>\n              </li>\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe651;</span>\n                <div class=\"name\">进度</div>\n                <div class=\"code-name\">&amp;#xe651;</div>\n              </li>\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe617;</span>\n                <div class=\"name\">环形进度条</div>\n                <div class=\"code-name\">&amp;#xe617;</div>\n              </li>\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe779;</span>\n                <div class=\"name\">向左1</div>\n                <div class=\"code-name\">&amp;#xe779;</div>\n              </li>\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe608;</span>\n                <div class=\"name\">点</div>\n                <div class=\"code-name\">&amp;#xe608;</div>\n              </li>\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe7dd;</span>\n                <div class=\"name\">编辑</div>\n                <div class=\"code-name\">&amp;#xe7dd;</div>\n              </li>\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe614;</span>\n                <div class=\"name\">删除</div>\n                
<div class=\"code-name\">&amp;#xe614;</div>\n              </li>\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe618;</span>\n                <div class=\"name\">上传</div>\n                <div class=\"code-name\">&amp;#xe618;</div>\n              </li>\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe621;</span>\n                <div class=\"name\">探索-选中</div>\n                <div class=\"code-name\">&amp;#xe621;</div>\n              </li>\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe657;</span>\n                <div class=\"name\">ellipsis</div>\n                <div class=\"code-name\">&amp;#xe657;</div>\n              </li>\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe60c;</span>\n                <div class=\"name\">发送</div>\n                <div class=\"code-name\">&amp;#xe60c;</div>\n              </li>\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe62d;</span>\n                <div class=\"name\">列表</div>\n                <div class=\"code-name\">&amp;#xe62d;</div>\n              </li>\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe639;</span>\n                <div class=\"name\">列表</div>\n                <div class=\"code-name\">&amp;#xe639;</div>\n              </li>\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe6bd;</span>\n                <div class=\"name\">重试</div>\n                <div class=\"code-name\">&amp;#xe6bd;</div>\n              </li>\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe826;</span>\n                <div class=\"name\">Fork 记录</div>\n                <div class=\"code-name\">&amp;#xe826;</div>\n              </li>\n          \n          </ul>\n          <div 
class=\"article markdown\">\n          <h2 id=\"unicode-\">Unicode usage</h2>\n          <hr>\n\n          <p>Unicode is the most basic way to use an icon font on a web page. Its characteristics:</p>\n          <ul>\n            <li>Icon size, color, and so on can be adjusted dynamically, just as with text in a font.</li>\n            <li>Multicolor is not supported by default; multicolor icons added directly are automatically desaturated.</li>\n          </ul>\n          <blockquote>\n            <p>Note: the new version of iconfont supports two ways to reference multicolor icons: SVG symbol references and color font icon mode. (To use color font icons, enable the \"Color\" option under \"Edit project\" and then regenerate.)</p>\n          </blockquote>\n          <p>To use Unicode references, follow these steps:</p>\n          <h3 id=\"-font-face\">Step 1: Copy the <code>@font-face</code> generated for your project</h3>\n<pre><code class=\"language-css\"\n>@font-face {\n  font-family: 'iconfont';\n  src: url('iconfont.woff2?t=1717950820214') format('woff2'),\n       url('iconfont.woff?t=1717950820214') format('woff'),\n       url('iconfont.ttf?t=1717950820214') format('truetype'),\n       url('iconfont.svg?t=1717950820214#iconfont') format('svg');\n}\n</code></pre>\n          <h3 id=\"-iconfont-\">Step 2: Define the style that uses iconfont</h3>\n<pre><code class=\"language-css\"\n>.iconfont {\n  font-family: \"iconfont\" !important;\n  font-size: 16px;\n  font-style: normal;\n  -webkit-font-smoothing: antialiased;\n  -moz-osx-font-smoothing: grayscale;\n}\n</code></pre>\n          <h3 id=\"-\">Step 3: Pick an icon, get its character code, and use it on the page</h3>\n<pre>\n<code class=\"language-html\"\n>&lt;span class=\"iconfont\"&gt;&amp;#x33;&lt;/span&gt;\n</code></pre>\n          <blockquote>\n            <p>\"iconfont\" is your project's font-family. You can check it by editing the project; the default is \"iconfont\".</p>\n          </blockquote>\n          </div>\n      </div>\n      <div class=\"content font-class\">\n        <ul class=\"icon_lists dib-box\">\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-copy\"></span>\n            <div class=\"name\">\n              复制\n            </div>\n            <div class=\"code-name\">.icon-copy\n            </div>\n          </li>\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-arrow-down\"></span>\n            <div class=\"name\">\n        
      箭头下\n            </div>\n            <div class=\"code-name\">.icon-arrow-down\n            </div>\n          </li>\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-usage-progress\"></span>\n            <div class=\"name\">\n              进度\n            </div>\n            <div class=\"code-name\">.icon-usage-progress\n            </div>\n          </li>\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-gen-progress\"></span>\n            <div class=\"name\">\n              环形进度条\n            </div>\n            <div class=\"code-name\">.icon-gen-progress\n            </div>\n          </li>\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-back\"></span>\n            <div class=\"name\">\n              向左1\n            </div>\n            <div class=\"code-name\">.icon-back\n            </div>\n          </li>\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-point\"></span>\n            <div class=\"name\">\n              点\n            </div>\n            <div class=\"code-name\">.icon-point\n            </div>\n          </li>\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-edit\"></span>\n            <div class=\"name\">\n              编辑\n            </div>\n            <div class=\"code-name\">.icon-edit\n            </div>\n          </li>\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-delete\"></span>\n            <div class=\"name\">\n              删除\n            </div>\n            <div class=\"code-name\">.icon-delete\n            </div>\n          </li>\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-upload-1\"></span>\n            <div class=\"name\">\n              上传\n            </div>\n            <div class=\"code-name\">.icon-upload-1\n            </div>\n          
</li>\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-explore\"></span>\n            <div class=\"name\">\n              探索-选中\n            </div>\n            <div class=\"code-name\">.icon-explore\n            </div>\n          </li>\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-ellipsis\"></span>\n            <div class=\"name\">\n              ellipsis\n            </div>\n            <div class=\"code-name\">.icon-ellipsis\n            </div>\n          </li>\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-sent\"></span>\n            <div class=\"name\">\n              发送\n            </div>\n            <div class=\"code-name\">.icon-sent\n            </div>\n          </li>\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-list-list\"></span>\n            <div class=\"name\">\n              列表\n            </div>\n            <div class=\"code-name\">.icon-list-list\n            </div>\n          </li>\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-list-icon\"></span>\n            <div class=\"name\">\n              列表\n            </div>\n            <div class=\"code-name\">.icon-list-icon\n            </div>\n          </li>\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-zhongshi\"></span>\n            <div class=\"name\">\n              重试\n            </div>\n            <div class=\"code-name\">.icon-zhongshi\n            </div>\n          </li>\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-log\"></span>\n            <div class=\"name\">\n              Fork 记录\n            </div>\n            <div class=\"code-name\">.icon-log\n            </div>\n          </li>\n          \n        </ul>\n        <div class=\"article markdown\">\n        <h2 id=\"font-class-\">font-class 
usage</h2>\n        <hr>\n\n        <p>font-class is a variant of the Unicode approach; it mainly addresses the fact that raw Unicode code points are unintuitive to write and carry no obvious meaning.</p>\n        <p>Compared with the Unicode approach, it has the following characteristics:</p>\n        <ul>\n          <li>The meaning is clearer and the markup is more intuitive than with Unicode; it is easy to tell which icon is which.</li>\n          <li>Because icons are defined by class, replacing an icon only requires changing the Unicode reference inside that class.</li>\n        </ul>\n        <p>Steps:</p>\n        <h3 id=\"-fontclass-\">Step 1: Include the fontclass code generated for your project:</h3>\n<pre><code class=\"language-html\">&lt;link rel=\"stylesheet\" href=\"./iconfont.css\"&gt;\n</code></pre>\n        <h3 id=\"-\">Step 2: Pick an icon, get its class name, and use it on the page:</h3>\n<pre><code class=\"language-html\">&lt;span class=\"iconfont icon-xxx\"&gt;&lt;/span&gt;\n</code></pre>\n        <blockquote>\n          <p>\"iconfont\" is your project's font-family. You can check it by editing the project; the default is \"iconfont\".</p>\n        </blockquote>\n      </div>\n      </div>\n      <div class=\"content symbol\">\n          <ul class=\"icon_lists dib-box\">\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" aria-hidden=\"true\">\n                  <use xlink:href=\"#icon-copy\"></use>\n                </svg>\n                <div class=\"name\">复制</div>\n                <div class=\"code-name\">#icon-copy</div>\n            </li>\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" aria-hidden=\"true\">\n                  <use xlink:href=\"#icon-arrow-down\"></use>\n                </svg>\n                <div class=\"name\">箭头下</div>\n                <div class=\"code-name\">#icon-arrow-down</div>\n            </li>\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" aria-hidden=\"true\">\n                  <use xlink:href=\"#icon-usage-progress\"></use>\n                </svg>\n                <div class=\"name\">进度</div>\n                <div class=\"code-name\">#icon-usage-progress</div>\n            </li>\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" 
aria-hidden=\"true\">\n                  <use xlink:href=\"#icon-gen-progress\"></use>\n                </svg>\n                <div class=\"name\">环形进度条</div>\n                <div class=\"code-name\">#icon-gen-progress</div>\n            </li>\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" aria-hidden=\"true\">\n                  <use xlink:href=\"#icon-back\"></use>\n                </svg>\n                <div class=\"name\">向左1</div>\n                <div class=\"code-name\">#icon-back</div>\n            </li>\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" aria-hidden=\"true\">\n                  <use xlink:href=\"#icon-point\"></use>\n                </svg>\n                <div class=\"name\">点</div>\n                <div class=\"code-name\">#icon-point</div>\n            </li>\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" aria-hidden=\"true\">\n                  <use xlink:href=\"#icon-edit\"></use>\n                </svg>\n                <div class=\"name\">编辑</div>\n                <div class=\"code-name\">#icon-edit</div>\n            </li>\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" aria-hidden=\"true\">\n                  <use xlink:href=\"#icon-delete\"></use>\n                </svg>\n                <div class=\"name\">删除</div>\n                <div class=\"code-name\">#icon-delete</div>\n            </li>\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" aria-hidden=\"true\">\n                  <use xlink:href=\"#icon-upload-1\"></use>\n                </svg>\n                <div class=\"name\">上传</div>\n                <div class=\"code-name\">#icon-upload-1</div>\n            </li>\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" aria-hidden=\"true\">\n                  <use 
xlink:href=\"#icon-explore\"></use>\n                </svg>\n                <div class=\"name\">探索-选中</div>\n                <div class=\"code-name\">#icon-explore</div>\n            </li>\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" aria-hidden=\"true\">\n                  <use xlink:href=\"#icon-ellipsis\"></use>\n                </svg>\n                <div class=\"name\">ellipsis</div>\n                <div class=\"code-name\">#icon-ellipsis</div>\n            </li>\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" aria-hidden=\"true\">\n                  <use xlink:href=\"#icon-sent\"></use>\n                </svg>\n                <div class=\"name\">发送</div>\n                <div class=\"code-name\">#icon-sent</div>\n            </li>\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" aria-hidden=\"true\">\n                  <use xlink:href=\"#icon-list-list\"></use>\n                </svg>\n                <div class=\"name\">列表</div>\n                <div class=\"code-name\">#icon-list-list</div>\n            </li>\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" aria-hidden=\"true\">\n                  <use xlink:href=\"#icon-list-icon\"></use>\n                </svg>\n                <div class=\"name\">列表</div>\n                <div class=\"code-name\">#icon-list-icon</div>\n            </li>\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" aria-hidden=\"true\">\n                  <use xlink:href=\"#icon-zhongshi\"></use>\n                </svg>\n                <div class=\"name\">重试</div>\n                <div class=\"code-name\">#icon-zhongshi</div>\n            </li>\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" aria-hidden=\"true\">\n                  <use xlink:href=\"#icon-log\"></use>\n              
  </svg>\n                <div class=\"name\">Fork 记录</div>\n                <div class=\"code-name\">#icon-log</div>\n            </li>\n          \n          </ul>\n          <div class=\"article markdown\">\n          <h2 id=\"symbol-\">Symbol usage</h2>\n          <hr>\n\n          <p>This is a new way to use icons; it is arguably where things are headed, and it is the approach the platform currently recommends. For background, see this <a href=\"\">article</a>.\n            This approach essentially builds a collection of SVGs. Compared with the other two, it has the following characteristics:</p>\n          <ul>\n            <li>Supports multicolor icons; it is no longer limited to a single color.</li>\n            <li>With a few tricks, icons can be styled like fonts, using <code>font-size</code> and <code>color</code>.</li>\n            <li>Narrower compatibility: supported in IE9+ and modern browsers.</li>\n            <li>Browser performance when rendering SVG is mediocre, worse than PNG.</li>\n          </ul>\n          <p>Steps:</p>\n          <h3 id=\"-symbol-\">Step 1: Include the symbol code generated for your project:</h3>\n<pre><code class=\"language-html\">&lt;script src=\"./iconfont.js\"&gt;&lt;/script&gt;\n</code></pre>\n          <h3 id=\"-css-\">Step 2: Add the shared CSS (it only needs to be included once):</h3>\n<pre><code class=\"language-html\">&lt;style&gt;\n.icon {\n  width: 1em;\n  height: 1em;\n  vertical-align: -0.15em;\n  fill: currentColor;\n  overflow: hidden;\n}\n&lt;/style&gt;\n</code></pre>\n          <h3 id=\"-\">Step 3: Pick an icon, get its class name, and use it on the page:</h3>\n<pre><code class=\"language-html\">&lt;svg class=\"icon\" aria-hidden=\"true\"&gt;\n  &lt;use xlink:href=\"#icon-xxx\"&gt;&lt;/use&gt;\n&lt;/svg&gt;\n</code></pre>\n          </div>\n      </div>\n\n    </div>\n  </div>\n  <script>\n  $(document).ready(function () {\n      $('.tab-container .content:first').show()\n\n      $('#tabs li').click(function (e) {\n        var tabContent = $('.tab-container .content')\n        var index = $(this).index()\n\n        if ($(this).hasClass('active')) {\n          return\n        } else {\n          $('#tabs li').removeClass('active')\n          $(this).addClass('active')\n\n          tabContent.hide().eq(index).fadeIn()\n        }\n      })\n    })\n  </script>\n</body>\n</html>\n"
  },
  {
    "path": "archive/ktransformers/website/src/assets/iconfont/iconfont.css",
    "content": "@font-face {\n  font-family: \"iconfont\"; /* Project id 4550268 */\n  src: url('iconfont.woff2?t=1717950820214') format('woff2'),\n       url('iconfont.woff?t=1717950820214') format('woff'),\n       url('iconfont.ttf?t=1717950820214') format('truetype'),\n       url('iconfont.svg?t=1717950820214#iconfont') format('svg');\n}\n\n.iconfont {\n  font-family: \"iconfont\" !important;\n  font-size: 16px;\n  font-style: normal;\n  -webkit-font-smoothing: antialiased;\n  -moz-osx-font-smoothing: grayscale;\n}\n\n.icon-copy:before {\n  content: \"\\e8b0\";\n}\n\n.icon-arrow-down:before {\n  content: \"\\e85e\";\n}\n\n.icon-usage-progress:before {\n  content: \"\\e651\";\n}\n\n.icon-gen-progress:before {\n  content: \"\\e617\";\n}\n\n.icon-back:before {\n  content: \"\\e779\";\n}\n\n.icon-point:before {\n  content: \"\\e608\";\n}\n\n.icon-edit:before {\n  content: \"\\e7dd\";\n}\n\n.icon-delete:before {\n  content: \"\\e614\";\n}\n\n.icon-upload-1:before {\n  content: \"\\e618\";\n}\n\n.icon-explore:before {\n  content: \"\\e621\";\n}\n\n.icon-ellipsis:before {\n  content: \"\\e657\";\n}\n\n.icon-sent:before {\n  content: \"\\e60c\";\n}\n\n.icon-list-list:before {\n  content: \"\\e62d\";\n}\n\n.icon-list-icon:before {\n  content: \"\\e639\";\n}\n\n.icon-zhongshi:before {\n  content: \"\\e6bd\";\n}\n\n.icon-log:before {\n  content: \"\\e826\";\n}\n\n"
  },
  {
    "path": "archive/ktransformers/website/src/assets/iconfont/iconfont.js",
    "content": "window._iconfont_svg_string_4550268='<svg><symbol id=\"icon-copy\" viewBox=\"0 0 1024 1024\"><path d=\"M394.666667 106.666667h448a74.666667 74.666667 0 0 1 74.666666 74.666666v448a74.666667 74.666667 0 0 1-74.666666 74.666667H394.666667a74.666667 74.666667 0 0 1-74.666667-74.666667V181.333333a74.666667 74.666667 0 0 1 74.666667-74.666666z m0 64a10.666667 10.666667 0 0 0-10.666667 10.666666v448a10.666667 10.666667 0 0 0 10.666667 10.666667h448a10.666667 10.666667 0 0 0 10.666666-10.666667V181.333333a10.666667 10.666667 0 0 0-10.666666-10.666666H394.666667z m245.333333 597.333333a32 32 0 0 1 64 0v74.666667a74.666667 74.666667 0 0 1-74.666667 74.666666H181.333333a74.666667 74.666667 0 0 1-74.666666-74.666666V394.666667a74.666667 74.666667 0 0 1 74.666666-74.666667h74.666667a32 32 0 0 1 0 64h-74.666667a10.666667 10.666667 0 0 0-10.666666 10.666667v448a10.666667 10.666667 0 0 0 10.666666 10.666666h448a10.666667 10.666667 0 0 0 10.666667-10.666666v-74.666667z\" fill=\"#000000\" ></path></symbol><symbol id=\"icon-arrow-down\" viewBox=\"0 0 1024 1024\"><path d=\"M554.666667 690.005333l228.864-228.864 60.330666 60.330667L512 853.333333l-331.861333-331.861333 60.330666-60.330667L469.333333 690.005333V170.666667h85.333334v519.338666z\"  ></path></symbol><symbol id=\"icon-usage-progress\" viewBox=\"0 0 1024 1024\"><path d=\"M512 125.098667A386.901333 386.901333 0 1 1 125.098667 512 386.901333 386.901333 0 0 1 512 125.098667z\" fill=\"#ACE9C5\" ></path><path d=\"M512 318.634667A193.365333 193.365333 0 1 1 318.634667 512 193.365333 193.365333 0 0 1 512 318.634667z\" fill=\"#2BA866\" ></path></symbol><symbol id=\"icon-gen-progress\" viewBox=\"0 0 1024 1024\"><path d=\"M692.004733 714.930578l96.018649 96.017519C715.492309 877.950022 618.525386 918.887417 512 918.887417c-104.225342 0-199.297978-39.187779-271.287664-103.631964l96.127152-96.126023C384.097201 759.135506 445.230905 783.258278 512 783.258278c69.07253 0 132.114084-25.817007 180.004733-68.3277z 
m-202.61185-609.200883L489.395143 241.670781C350.16053 253.157439 240.741722 369.800759 240.741722 512c0 66.767965 24.122773 127.900539 64.127717 175.160512l-96.126022 96.126022C144.299232 711.295717 105.112583 616.225342 105.112583 512c0-217.130949 170.07894-394.539514 384.2803-406.270305z m325.8637 134.984901C879.700768 312.702022 918.887417 407.774658 918.887417 512c0 101.921907-37.474331 195.091214-99.395814 266.479611l-96.270694-96.268432C760.774358 635.667779 783.258278 576.460009 783.258278 512c0-66.767965-24.122773-127.901669-64.128848-175.161642l96.127153-96.124892zM534.608247 105.728565c95.334852 5.221722 181.928406 43.261174 248.678287 103.013722l-96.127152 96.127152c-41.869845-35.444415-94.631841-58.422252-152.553395-63.199788l0.00226-135.941086z\" fill=\"#448AFF\" fill-opacity=\".6\" ></path><path d=\"M489.392883 105.729695L489.395143 241.670781C350.16053 253.157439 240.741722 369.800759 240.741722 512c0 66.767965 24.122773 127.900539 64.127717 175.160512l-96.126022 96.126022C144.299232 711.295717 105.112583 616.225342 105.112583 512c0-217.130949 170.07894-394.539514 384.2803-406.270305z\" fill=\"#448AFF\" ></path></symbol><symbol id=\"icon-back\" viewBox=\"0 0 1024 1024\"><path d=\"M671.968176 911.99957c-12.287381 0-24.576482-4.67206-33.951566-14.047144L286.048434 545.984249c-18.751888-18.719204-18.751888-49.12028 0-67.872168L638.016611 126.111222c18.751888-18.751888 49.12028-18.751888 67.872168 0 18.751888 18.719204 18.751888 49.12028 0 67.872168l-318.016611 318.047574L705.888778 830.047574c18.751888 18.751888 18.751888 49.12028 0 67.872168C696.544658 907.32751 684.255557 911.99957 671.968176 911.99957z\" fill=\"#2c2c2c\" ></path></symbol><symbol id=\"icon-point\" viewBox=\"0 0 1024 1024\"><path d=\"M512 307.2a204.86826667 204.86826667 0 0 1 0 409.6 204.8 204.8 0 0 1 0-409.6z\" fill=\"\" ></path></symbol><symbol id=\"icon-edit\" viewBox=\"0 0 1024 1024\"><path d=\"M899.072 125.44c-28.672-28.672-67.072-44.544-107.52-44.544s-78.848 15.872-107.52 
44.544L251.392 558.08c-34.304 34.304-60.416 74.752-78.336 119.808L88.576 896c-4.608 11.264-1.536 24.064 7.168 32.768 5.632 5.632 13.824 9.216 21.504 9.216 3.584 0 7.68-0.512 11.264-2.048l218.624-84.48c45.056-17.408 85.504-44.032 119.808-78.336l351.744-351.744 80.896-80.896c58.88-59.392 58.88-155.648-0.512-215.04z m-475.648 604.16c-28.16 28.16-61.44 50.176-98.816 64.512l-153.6 59.392 59.392-153.6c14.336-37.376 35.84-70.656 64.512-98.816L625.152 271.36l128.512 128.512-330.24 329.728z m432.64-432.128l-58.88 58.88-128.512-128.512L727.552 168.96c16.896-16.896 39.936-26.624 64.512-26.624s47.104 9.216 64.512 26.624c34.816 35.328 34.816 92.672-0.512 128.512z\" fill=\"#333333\" ></path></symbol><symbol id=\"icon-delete\" viewBox=\"0 0 1024 1024\"><path d=\"M742.4 944H281.6c-49.4 0-89.6-43.1-89.6-96V368h64v480c0 17.3 11.7 32 25.6 32h460.8c13.9 0 25.6-14.7 25.6-32V368h64v480c0 52.9-40.2 96-89.6 96z\"  ></path><path d=\"M384 368h64v416h-64zM592 368h64v416h-64zM64 224h896v64H64z\"  ></path><path d=\"M768 288H256V160c0-52.9 43.1-96 96-96h320c52.9 0 96 43.1 96 96v128z m-448-64h384v-64c0-17.6-14.4-32-32-32H352c-17.6 0-32 14.4-32 32v64z\"  ></path></symbol><symbol id=\"icon-upload-1\" viewBox=\"0 0 1024 1024\"><path d=\"M323.034074 291.934815l383.620741 0c9.481481 0 17.256296-8.533333 17.256296-18.962963 0-10.42963-7.68-18.962963-17.256296-18.962963L323.034074 254.008889c-9.481481 0-17.256296 8.533333-17.256296 18.962963C305.777778 283.496296 313.457778 291.934815 323.034074 291.934815z\" fill=\"#272536\" ></path><path d=\"M522.05037 328.628148c-1.232593-1.232593-2.844444-1.896296-4.740741-1.991111-1.706667-0.094815-3.318519-0.094815-5.025185 0-1.896296 0.094815-3.508148 0.758519-4.740741 1.991111L349.013333 487.253333c-3.887407 3.887407-1.896296 12.325926 4.456296 18.773333 6.447407 6.447407 14.791111 8.438519 18.773333 4.456296l125.060741-125.060741 0 367.122963c0 9.671111 7.86963 17.540741 17.540741 17.540741l0 0c9.671111 0 17.540741-7.86963 17.540741-17.540741L532.385185 
385.327407l125.060741 125.060741c3.887407 3.887407 12.325926 1.896296 18.773333-4.456296 6.447407-6.447407 8.438519-14.791111 4.456296-18.773333L522.05037 328.628148z\" fill=\"#272536\" ></path></symbol><symbol id=\"icon-explore\" viewBox=\"0 0 1024 1024\"><path d=\"M926.352541 89.231277c-0.029676-7.432273-1.212618-13.651928-2.837628-19.264762-31.228235-8.264221-71.898517 1.24127-106.283652 17.927301-7.049556 3.41068-23.762193 13.583366-48.51597 28.643364-10.237155 6.250354-19.264762 11.739369-23.251563 14.002922-0.384763 0.224104-0.608867 0.63752-0.958838 0.861624-67.557652-41.147142-146.571217-65.327868-231.319389-65.327868-246.251474 0-446.569802 200.319351-446.569802 446.564685 0 82.554204 22.904663 159.683862 62.105476 226.062666-46.315862 71.387887-69.2809 122.93182-63.283302 157.863401 1.24127 7.144724 13.555737 8.28878 20.316721 8.28878 137.989771 0 453.393207-302.802444 492.628814-341.399507C751.64859 393.022235 926.449755 184.667883 926.352541 89.231277L926.352541 89.231277zM305.847292 611.014084c-43.956118 0-79.744205-35.757388-79.744205-79.743182 0-43.956118 35.789111-79.744205 79.744205-79.744205 43.956118 0 79.743182 35.789111 79.743182 79.744205C385.591497 575.256696 349.803409 611.014084 305.847292 611.014084L305.847292 611.014084zM446.19783 387.730719c-52.760644 0-95.694479-42.937928-95.694479-95.692433 0-52.760644 42.933835-95.694479 95.694479-95.694479 52.761668 0 95.694479 42.933835 95.694479 95.694479C541.892309 344.79279 498.958474 387.730719 446.19783 387.730719L446.19783 387.730719zM893.595486 279.9469c-66.889433 99.330286-172.055634 218.596623-276.967032 321.751005-28.551266 28.104081-201.624067 195.822944-346.982666 285.198507 0.12689-0.097214 0.223081-0.160659 0.349971-0.224104 70.049403 45.708018 153.491837 72.536037 243.189741 72.536037 246.246357 0 446.565708-200.318328 446.565708-446.570825C959.716416 427.317319 935.282934 347.82587 893.595486 279.9469L893.595486 279.9469zM638.54051 799.720957c-35.180244 
0-63.793932-28.614711-63.793932-63.794955 0-35.184337 28.613688-63.799048 63.793932-63.799048 35.184337 0 63.793932 28.614711 63.793932 63.799048C702.334441 771.106246 673.724847 799.720957 638.54051 799.720957L638.54051 799.720957zM638.54051 799.720957\" fill=\"#615CED\" ></path></symbol><symbol id=\"icon-ellipsis\" viewBox=\"0 0 1024 1024\"><path d=\"M322.292 505.5m-66 0a66 66 0 1 0 132 0 66 66 0 1 0-132 0Z\" fill=\"#272636\" ></path><path d=\"M509.791 505.5m-66 0a66 66 0 1 0 132 0 66 66 0 1 0-132 0Z\" fill=\"#272636\" ></path><path d=\"M701.791 505.5m-66 0a66 66 0 1 0 132 0 66 66 0 1 0-132 0Z\" fill=\"#272636\" ></path></symbol><symbol id=\"icon-sent\" viewBox=\"0 0 1024 1024\"><path d=\"M998.976 554.3232C1031.232 539.6032 1031.328 515.7952 998.976 501.0432L122.88 101.3312C90.624 86.6112 64.448 103.5072 64.384 138.4832L64 426.9952 773.568 527.6672 64 628.3392 64.384 916.8832C64.448 952.1152 90.528 968.7872 122.88 954.0352L998.976 554.3232Z\"  ></path></symbol><symbol id=\"icon-list-list\" viewBox=\"0 0 1024 1024\"><path d=\"M419.037 287.953h413.124c17.673 0 32-14.327 32-32s-14.327-32-32-32H419.037c-17.673 0-32 14.327-32 32s14.327 32 32 32zM419.028 543.17h411.608c17.673 0 32-14.327 32-32s-14.327-32-32-32H419.028c-17.673 0-32 14.327-32 32s14.327 32 32 32zM832.161 735.802H419.037c-17.673 0-32 14.327-32 32s14.327 32 32 32h413.124c17.673 0 32-14.327 32-32s-14.327-32-32-32z\" fill=\"\" ></path><path d=\"M256.037 255.953m-64 0a64 64 0 1 0 128 0 64 64 0 1 0-128 0Z\" fill=\"\" ></path><path d=\"M256.037 510.787m-64 0a64 64 0 1 0 128 0 64 64 0 1 0-128 0Z\" fill=\"\" ></path><path d=\"M256.037 767.621m-64 0a64 64 0 1 0 128 0 64 64 0 1 0-128 0Z\" fill=\"\" ></path></symbol><symbol id=\"icon-list-icon\" viewBox=\"0 0 1024 1024\"><path d=\"M841.6 489.6h-214.4c-48 0-86.4-38.4-86.4-86.4V188.8c0-48 38.4-86.4 86.4-86.4h214.4c48 0 86.4 38.4 86.4 86.4v214.4c0 48-38.4 86.4-86.4 86.4z m-211.2-320c-12.8 0-22.4 9.6-22.4 22.4v214.4c0 12.8 9.6 22.4 22.4 22.4h214.4c12.8 0 22.4-9.6 
22.4-22.4V188.8c0-12.8-9.6-22.4-22.4-22.4h-214.4zM393.6 489.6H182.4c-48 0-86.4-38.4-86.4-86.4V188.8c0-48 38.4-86.4 86.4-86.4h214.4c48 0 86.4 38.4 86.4 86.4v214.4c-3.2 48-41.6 86.4-89.6 86.4z m-211.2-320c-12.8 0-22.4 9.6-22.4 19.2v214.4c0 12.8 9.6 22.4 22.4 22.4h214.4c12.8 0 22.4-9.6 22.4-22.4V188.8c0-12.8-9.6-22.4-22.4-22.4H182.4zM841.6 937.6h-214.4c-48 0-86.4-38.4-86.4-86.4v-214.4c0-48 38.4-86.4 86.4-86.4h214.4c48 0 86.4 38.4 86.4 86.4v214.4c0 48-38.4 86.4-86.4 86.4z m-211.2-323.2c-12.8 0-22.4 9.6-22.4 22.4v214.4c0 12.8 9.6 22.4 22.4 22.4h214.4c12.8 0 22.4-9.6 22.4-22.4v-214.4c0-12.8-9.6-22.4-22.4-22.4h-214.4zM393.6 937.6H182.4c-48 0-86.4-38.4-86.4-86.4v-214.4c0-48 38.4-86.4 86.4-86.4h214.4c48 0 86.4 38.4 86.4 86.4v214.4c-3.2 48-41.6 86.4-89.6 86.4zM182.4 614.4c-12.8 0-22.4 9.6-22.4 22.4v214.4c0 12.8 9.6 22.4 22.4 22.4h214.4c12.8 0 22.4-9.6 22.4-22.4v-214.4c0-12.8-9.6-22.4-22.4-22.4H182.4z\" fill=\"#333333\" ></path></symbol><symbol id=\"icon-zhongshi\" viewBox=\"0 0 1024 1024\"><path d=\"M973.53044 167.133265l-65.609003 50.468463A491.226376 491.226376 0 0 0 522.971405 33.282123C253.074841 33.282123 34.74388 247.370807 34.378166 512.220525c-0.365714 265.142289 218.550389 480.108685 488.593239 480.108686 211.016691 0 390.728306-131.291147 459.189873-315.245039a9.069695 9.069695 0 0 0-5.851416-11.775975l-65.82843-22.308523a9.435408 9.435408 0 0 0-11.775975 5.485702 392.48373 392.48373 0 0 1-92.525516 141.896839 402.650566 402.650566 0 0 1-282.915965 115.12661c-54.125598 0-106.495772-10.386263-155.793952-30.793077a398.627717 398.627717 0 0 1-212.845258-209.188123 383.779749 383.779749 0 0 1-31.451361-152.868244c0-53.1016 10.532549-104.374633 31.451361-152.868243 20.114243-46.738186 49.005609-88.795238 85.723245-124.85459a401.260854 401.260854 0 0 1 282.915965-115.12661c54.052456 0 106.422629 10.459406 155.720809 30.866219a398.627717 398.627717 0 0 1 159.52423 120.100314l-69.997565 53.686742a9.069695 9.069695 0 0 0 3.437707 16.091394l204.287562 49.151895c5.851416 
1.316569 11.556547-2.998851 11.556547-8.777124l0.950855-206.554986a9.508551 9.508551 0 0 0-15.213681-7.167985z\" fill=\"#000000\" ></path></symbol><symbol id=\"icon-log\" viewBox=\"0 0 1024 1024\"><path d=\"M288 64c70.692 0 128 57.308 128 128 0 58.192-38.833 107.315-91.998 122.867L324 571.5h225c48.8 0 84.134-19.864 110.1-62.009 15.655-25.408 27.76-58.805 36.092-100.127C648.71 390.177 616 344.408 616 291c0-70.692 57.308-128 128-128 70.692 0 128 57.308 128 128 0 62.814-45.245 115.06-104.923 125.925-9.94 52.391-25.407 95.81-46.677 130.334-38.644 62.721-96.365 95.58-169.189 96.231l-2.211 0.01H324l0.002 65.633c52.52 15.363 91.052 63.486 91.98 120.75L416 832c0 70.692-57.308 128-128 128-70.692 0-128-57.308-128-128 0-58.193 38.833-107.315 91.999-122.868V314.868C198.833 299.315 160 250.193 160 192c0-70.692 57.308-128 128-128z\" fill=\"#333333\" ></path></symbol></svg>',function(l){var t=(t=document.getElementsByTagName(\"script\"))[t.length-1],c=t.getAttribute(\"data-injectcss\"),t=t.getAttribute(\"data-disable-injectsvg\");if(!t){var i,o,e,a,h,n=function(t,c){c.parentNode.insertBefore(t,c)};if(c&&!l.__iconfont__svg__cssinject__){l.__iconfont__svg__cssinject__=!0;try{document.write(\"<style>.svgfont {display: inline-block;width: 1em;height: 1em;fill: currentColor;vertical-align: -0.1em;font-size:16px;}</style>\")}catch(t){console&&console.log(t)}}i=function(){var 
t,c=document.createElement(\"div\");c.innerHTML=l._iconfont_svg_string_4550268,(c=c.getElementsByTagName(\"svg\")[0])&&(c.setAttribute(\"aria-hidden\",\"true\"),c.style.position=\"absolute\",c.style.width=0,c.style.height=0,c.style.overflow=\"hidden\",c=c,(t=document.body).firstChild?n(c,t.firstChild):t.appendChild(c))},document.addEventListener?~[\"complete\",\"loaded\",\"interactive\"].indexOf(document.readyState)?setTimeout(i,0):(o=function(){document.removeEventListener(\"DOMContentLoaded\",o,!1),i()},document.addEventListener(\"DOMContentLoaded\",o,!1)):document.attachEvent&&(e=i,a=l.document,h=!1,d(),a.onreadystatechange=function(){\"complete\"==a.readyState&&(a.onreadystatechange=null,s())})}function s(){h||(h=!0,e())}function d(){try{a.documentElement.doScroll(\"left\")}catch(t){return void setTimeout(d,50)}s()}}(window);"
  },
  {
    "path": "archive/ktransformers/website/src/assets/iconfont/iconfont.json",
    "content": "{\n  \"id\": \"4550268\",\n  \"name\": \"Lexllama\",\n  \"font_family\": \"iconfont\",\n  \"css_prefix_text\": \"icon-\",\n  \"description\": \"Lexllama开源项目使用\",\n  \"glyphs\": [\n    {\n      \"icon_id\": \"11372665\",\n      \"name\": \"复制\",\n      \"font_class\": \"copy\",\n      \"unicode\": \"e8b0\",\n      \"unicode_decimal\": 59568\n    },\n    {\n      \"icon_id\": \"34202237\",\n      \"name\": \"箭头下\",\n      \"font_class\": \"arrow-down\",\n      \"unicode\": \"e85e\",\n      \"unicode_decimal\": 59486\n    },\n    {\n      \"icon_id\": \"7766233\",\n      \"name\": \"进度\",\n      \"font_class\": \"usage-progress\",\n      \"unicode\": \"e651\",\n      \"unicode_decimal\": 58961\n    },\n    {\n      \"icon_id\": \"38865122\",\n      \"name\": \"环形进度条\",\n      \"font_class\": \"gen-progress\",\n      \"unicode\": \"e617\",\n      \"unicode_decimal\": 58903\n    },\n    {\n      \"icon_id\": \"577406\",\n      \"name\": \"向左1\",\n      \"font_class\": \"back\",\n      \"unicode\": \"e779\",\n      \"unicode_decimal\": 59257\n    },\n    {\n      \"icon_id\": \"1920286\",\n      \"name\": \"点\",\n      \"font_class\": \"point\",\n      \"unicode\": \"e608\",\n      \"unicode_decimal\": 58888\n    },\n    {\n      \"icon_id\": \"8866967\",\n      \"name\": \"编辑\",\n      \"font_class\": \"edit\",\n      \"unicode\": \"e7dd\",\n      \"unicode_decimal\": 59357\n    },\n    {\n      \"icon_id\": \"10199175\",\n      \"name\": \"删除\",\n      \"font_class\": \"delete\",\n      \"unicode\": \"e614\",\n      \"unicode_decimal\": 58900\n    },\n    {\n      \"icon_id\": \"1010111\",\n      \"name\": \"上传\",\n      \"font_class\": \"upload-1\",\n      \"unicode\": \"e618\",\n      \"unicode_decimal\": 58904\n    },\n    {\n      \"icon_id\": \"351773\",\n      \"name\": \"探索-选中\",\n      \"font_class\": \"explore\",\n      \"unicode\": \"e621\",\n      \"unicode_decimal\": 58913\n    },\n    {\n      \"icon_id\": \"564941\",\n      \"name\": 
\"ellipsis\",\n      \"font_class\": \"ellipsis\",\n      \"unicode\": \"e657\",\n      \"unicode_decimal\": 58967\n    },\n    {\n      \"icon_id\": \"1048859\",\n      \"name\": \"发送\",\n      \"font_class\": \"sent\",\n      \"unicode\": \"e60c\",\n      \"unicode_decimal\": 58892\n    },\n    {\n      \"icon_id\": \"1304951\",\n      \"name\": \"列表\",\n      \"font_class\": \"list-list\",\n      \"unicode\": \"e62d\",\n      \"unicode_decimal\": 58925\n    },\n    {\n      \"icon_id\": \"8676284\",\n      \"name\": \"列表\",\n      \"font_class\": \"list-icon\",\n      \"unicode\": \"e639\",\n      \"unicode_decimal\": 58937\n    },\n    {\n      \"icon_id\": \"22290034\",\n      \"name\": \"重试\",\n      \"font_class\": \"zhongshi\",\n      \"unicode\": \"e6bd\",\n      \"unicode_decimal\": 59069\n    },\n    {\n      \"icon_id\": \"22961085\",\n      \"name\": \"Fork 记录\",\n      \"font_class\": \"log\",\n      \"unicode\": \"e826\",\n      \"unicode_decimal\": 59430\n    }\n  ]\n}\n"
  },
  {
    "path": "archive/ktransformers/website/src/components/chat/index.vue",
    "content": "<template>\n  <div class=\"chat-panel\">\n    <!-- <div class=\"chat-model\">{{ activeAssistant?.model }}</div> -->\n    <div class=\"chat-panel-inner flex-column\">\n      <div class=\"chat-init flex-unit flex-column\" v-if=\"isNotChating\">\n        <div class=\"assistant-info flex-column flex-unit\">\n          <div class=\"avatar\">\n            <img src=\"../../../public/images/avatar.png\" />\n          </div>\n          <div class=\"name\">\n            {{ activeAssistant.name }}\n          </div>\n          <div class=\"desc\">\n            {{ activeAssistant.description }}\n          </div>\n        </div>\n      </div>\n      <div class=\"chat-msg flex-unit\" v-else>\n        <ul>\n          <li\n            class=\"chat-msg-item flex-row\"\n            v-for=\"(msg, index) in localMessages\"\n            :key=\"index\"\n          >\n            <div class=\"avatar\" v-if=\"msg.role == 'user'\">\n              <img src=\"../../../public/images/user-filling.png\" />\n            </div>\n            <div class=\"avatar\" v-else>\n              <img src=\"../../../public/images/avatar.png\" />\n            </div>\n            <div class=\"msg flex-unit\">\n              <div class=\"title flex-row\">\n                <div class=\"name\">{{ msg.role }}</div>\n                <div class=\"time flex-row\">\n                  {{ timeFormat(msg.created_at) }}\n                </div>\n              </div>\n              <div\n                class=\"content\"\n                v-html=\"markedText(msg.content)\"\n                ref=\"content_Ref\"\n              ></div>\n              <div class=\"copy-btn flex-row\" v-show=\"msgBttnBoxShow[index]\">\n                <i\n                  class=\"iconfont icon-copy\"\n                  @click=\"copy(createText(msg.content))\"\n                ></i>\n              </div>\n            </div>\n          </li>\n        </ul>\n      </div>\n      <div class=\"scroll-box\" v-show=\"showScrollButton\" 
@click=\"scrollToBottom\">\n        <i class=\"iconfont icon-arrow-down\"></i>\n      </div>\n      <div class=\"chat-send\">\n        <div\n          class=\"chat-box flex-row\"\n          :style=\"{ height: textareaHeight + 'px' }\"\n          ref=\"chatBox_Ref\"\n        >\n          <button @click=\"StopOutput\" class=\"stop-btn\" v-show=\"isRunning\">\n            stop\n          </button>\n          <textarea\n            name=\"chat-input\"\n            class=\"chat-input flex-unit\"\n            :placeholder=\"inputPlaceholder\"\n            v-model=\"inputQuestion\"\n            @keydown=\"keyBoardCommitQuestion\"\n            :disabled=\"inputDisabled\"\n            :style=\"{ height: textareaHeight + 'px' }\"\n            @input=\"handleInput\"\n            ref=\"textarea_ref\"\n            maxlength=\"2000\"\n            cols=\"20\"\n          ></textarea>\n          <i class=\"iconfont icon-sent\" @click=\"clickCommitQuestion\"></i>\n        </div>\n      </div>\n    </div>\n  </div>\n</template>\n\n<script lang=\"ts\">\nimport {\n  defineComponent,\n  nextTick,\n  PropType,\n  ref,\n  watch,\n  computed,\n  onMounted,\n} from \"vue\";\nimport { IThread, IMessageData, IAssistant } from \"@/utils/types\";\nimport { marked } from \"marked\";\nimport { createMessage } from \"@/api/message\";\nimport { createRun, cancelRun } from \"@/api/run\";\nimport { getAssistant } from \"@/api/assistant\";\nimport { createThread } from \"@/api/thread\";\nimport BScroll from \"better-scroll\";\nimport { useRouter, useRoute } from \"vue-router\";\nimport { useI18n } from \"vue-i18n\";\nimport { ElMessage } from \"element-plus\";\nimport { tr } from \"element-plus/es/locale\";\nimport copy from \"@/utils/copy\";\nexport default defineComponent({\n  name: \"ChatChat\",\n  props: {\n    messages: {\n      type: Array as PropType<IMessageData[]>,\n      required: true,\n    },\n    chatInit: {\n      type: Boolean,\n      required: true,\n    },\n    activeAssistant: {\n    
  type: Object as PropType<IAssistant>,\n      required: true,\n    },\n    activeThread: {\n      type: Object as PropType<IThread>,\n      required: true,\n    },\n    inputDisabled: {\n      type: Boolean,\n      default: false,\n    },\n  },\n  setup(props, context) {\n    const { t } = useI18n();\n    const router = useRouter();\n    const route = useRoute();\n    const localMessages = ref<IMessageData[]>([...props.messages]);\n    const showScrollButton = ref(false);\n    const messageScroll = ref<BScroll | null>(null);\n    const inputQuestion = ref<string>(\"\");\n    const inputDisabled = ref(false);\n    const msgBttnBoxShow = ref<boolean[]>([]);\n    const answer = ref(\"\");\n    const activeThread = ref<IThread>({} as IThread);\n    const activeAssistant = ref<IAssistant>({} as IAssistant);\n    const isNotChating = ref(true);\n    const isRunning = ref(false);\n    const stopRunId = ref<string>(\"\");\n    const shouldContinueReceiving = ref(true);\n    const textareaHeight = ref(48);\n    const chatBox_Ref = ref();\n    const textarea_ref = ref();\n    const content_Ref = ref();\n    // Boolean if go\n    isNotChating.value = props.chatInit;\n    activeThread.value = props.activeThread;\n    activeAssistant.value = props.activeAssistant;\n    watch(\n      () => props.messages,\n      (newMessages) => {\n        localMessages.value = [...newMessages];\n        msgBttnBoxShow.value = new Array(newMessages.length).fill(true);\n      }\n    );\n    watch(\n      () => props.inputDisabled,\n      (newValue) => {\n        inputDisabled.value = newValue;\n      }\n    );\n    // Update scrollbars and scrolling events\n    watch(\n      () => localMessages.value,\n      (newMessages) => {\n        if (messageScroll.value) {\n          scrollToTop();\n          messageScroll.value.destroy();\n          messageScroll.value = null;\n        }\n        if (!isNotChating.value) {\n          nextTick(() => {\n            messageScroll.value = new 
BScroll(\".chat-msg\", {\n              click: true,\n              mouseWheel: true,\n              probeType: 3, //Only when set to 3 can the event of scrolling binding be triggered\n            });\n          });\n        }\n      },\n      {\n        immediate: true,\n        deep: true,\n      }\n    );\n    watch(\n      () => messageScroll.value,\n      (newValue) => {\n        if (newValue) {\n          messageScroll.value?.on(\"scroll\", handleScroll);\n          showScrollButton.value = false;\n          scrollToBottom();\n        }\n      }\n    );\n    watch(\n      () => props.chatInit,\n      (newValue) => {\n        isNotChating.value = newValue;\n      }\n    );\n    watch(\n      () => props.activeThread,\n      (newValue) => {\n        activeThread.value = newValue;\n      }\n    );\n    watch(\n      () => props.activeAssistant,\n      (newValue) => {\n        activeAssistant.value = newValue;\n      }\n    );\n\n    const handleInput = (event:any) => {\n      adjustHeight();\n      const maxLength = 2000; \n      if (inputQuestion.value?.length > maxLength) {\n        event.preventDefault(); \n        inputQuestion.value = inputQuestion.value.substring(0, maxLength); \n      }\n    };\n    const adjustHeight = () => {\n      const currentScrollTop = textarea_ref.value.scrollTop;\n      textarea_ref.value.style.height = textarea_ref.value.scrollHeight + \"px\";\n      chatBox_Ref.value.style.height = textarea_ref.value.style.height;\n      textarea_ref.value.scrollTop = currentScrollTop;\n    };\n\n    const inputPlaceholder = computed(() => {\n      if (typeof activeAssistant.value.name != \"undefined\") {\n        return replaceAssistant(t(\"chat.inputTip\"), activeAssistant.value.name);\n      } else {\n        return t(\"chat.inputTip\");\n      }\n    });\n    // Block events\n    const StopOutput = async () => {\n      shouldContinueReceiving.value = false;\n      try {\n        const response = await cancelRun(\n          
activeThread.value.id,\n          stopRunId.value\n        );\n        if (!response.ok) {\n          console.error(\"Failed to cancel run\");\n        }\n      } catch (error) {\n        console.error(\"Failed to cancel run:\", error);\n      }\n    };\n    // dialogue\n    const commitQuestion: () => void = async () => {\n      const question = inputQuestion.value;\n      // If it came in by clicking on assistants without clicking on thread, or through preview\n      if (Object.keys(activeThread.value).length == 0) {\n        try {\n          let res = {} as IThread;\n          // If you click thread and do not select assistant\n          if (route.name == \"preview\") {\n            let metadata = {\n              hidden: \"true\",\n            };\n            res = await createThread(undefined, undefined, metadata);\n          } else {\n            res = await createThread();\n          }\n          activeThread.value = res;\n        } catch (err) {\n          console.error(err);\n        }\n      }\n      //If you click thread and do not select assistant\n      else if (Object.keys(activeAssistant.value).length == 0) {\n        try {\n          const messageOfAssistant = props.messages.find(\n            (message) => message.role === \"assistant\"\n          );\n          if (messageOfAssistant && messageOfAssistant.assistant_id) {\n            const res = await getAssistant(messageOfAssistant.assistant_id);\n            activeAssistant.value = res;\n          }\n        } catch (err) {\n          console.error(err);\n        }\n      }\n      if (question) {\n        inputQuestion.value = \"\";\n        textareaHeight.value = 48;\n        // inputDisabled.value = true;\n        isNotChating.value = false;\n        isRunning.value = true;\n        await createMessage(activeThread.value.id, question)\n          .then((res: any) => {})\n          .catch((err: any) => {\n            ElMessage({\n              type: \"warning\",\n              message: \"Request 
error\",\n            });\n            return;\n          });\n        // Current message queue insertion issue\n        localMessages.value.push({\n          role: \"user\",\n          content: [\n            { type: \"text\", text: { value: question }, annotatons: [] },\n          ],\n          created_at: Date.now() / 1000,\n        });\n        msgBttnBoxShow.value.push(true);\n        // Insert answer into the current message queue\n        localMessages.value.push({\n          role: \"assistant\",\n          content: [{ type: \"text\", text: { value: \"\" }, annotatons: [] }],\n          created_at: Date.now() / 1000,\n        });\n        msgBttnBoxShow.value.push(false);\n        try {\n          const asyncGenerator = createRun(\n            {\n              assistant_id: activeAssistant.value.id,\n              stream: true,\n            },\n            activeThread.value.id\n          );\n          for await (const word of asyncGenerator) {\n            if (!shouldContinueReceiving.value) {\n              break;\n            }\n            if (word.length == 36) {\n              stopRunId.value = word;\n              console.log(stopRunId.value);\n            } else {\n              answer.value += word;\n              const index = localMessages.value.length - 1;\n              localMessages.value[index].content[0].text.value += word;\n              if (answer.value.length <= 3) {\n                localMessages.value[index].created_at = Date.now() / 1000;\n              }\n            }\n          }\n        } catch (err) {\n          console.error(err);\n        }\n        shouldContinueReceiving.value = true;\n        answer.value = \"\";\n        inputDisabled.value = false;\n        msgBttnBoxShow.value[msgBttnBoxShow.value.length - 1] = true;\n        scrollToBottom();\n        isRunning.value = false;\n        context.emit(\"updateAssistant\", true);\n        textarea_ref.value.focus();\n      }\n    };\n    // Keyboard event stabilization\n    
const keyBoardCommitQuestion = (event: any) => {\n      const question = inputQuestion.value?.trim();\n      if (event.keyCode === 13) {\n        event.preventDefault();\n\n        const cursorPosition = event.target.selectionStart;\n        if ((event.metaKey || event.ctrlKey) && question) {\n          event.target.value =\n            event.target.value.substring(0, cursorPosition) +\n            \"\\n\" +\n            event.target.value.substring(cursorPosition);\n          event.target.selectionStart = event.target.selectionEnd =\n            cursorPosition + 1;\n          adjustHeight();\n          return;\n        }\n        if (!question) {\n          ElMessage({\n            message: \"Please enter the content!\",\n            type: \"warning\",\n            plain: true,\n          });\n          return;\n        }\n        if (!isRunning.value) {\n          commitQuestion();\n          inputQuestion.value = \"\";\n        }\n      }\n    };\n    const clickCommitQuestion = () => {\n      if (!isRunning.value && inputQuestion.value?.trim() != \"\") {\n        commitQuestion();\n        return;\n      }\n      ElMessage({\n        message: \"Please enter the content!\",\n        type: \"warning\",\n        plain: true,\n      });\n    };\n    // Scroll to the bottom of the message list\n    const scrollToBottom = () => {\n      // Only scroll when the scroll instance exists\n      if (messageScroll.value) {\n        // Call scrollTo on messageScroll.value
to scroll to the bottom\n        messageScroll.value.scrollTo(0, messageScroll.value?.maxScrollY, 800);\n      }\n    };\n    // Scroll to the top of the message list\n    const scrollToTop = () => {\n      if (messageScroll.value) {\n        messageScroll.value.scrollTo(0, messageScroll.value?.minScrollY, 800);\n      }\n    };\n    // Handle scroll events: show the scroll button when far from the bottom\n    const handleScroll = (pos: any) => {\n      if (messageScroll.value) {\n        const distanceToBottom =\n          messageScroll.value.y - messageScroll.value.maxScrollY;\n        showScrollButton.value = distanceToBottom > 100;\n      }\n    };\n    // Replace every occurrence of the word assistant in the input with newString\n\n    function replaceAssistant(input: string, newString: string) {\n      return input.replace(/assistant/g, newString);\n    }\n    // Concatenate the text parts of the content array and render them to an HTML string with marked\n    const markedText = (content: object[]) => {\n      let context = \"\";\n      for (const item of content) {\n        if ((item as { type: string }).type === \"text\") {\n          context += ((item as { text: object }).text as { value: string })\n            .value;\n        }\n      }\n      return marked.parse(context);\n    };\n    // Concatenate the text parts of the content array into a plain string\n    const createText = (content: object[]) => {\n      let context = \"\";\n      for (const item of content) {\n        if ((item as { type: string }).type === \"text\") {\n          context += ((item as { text: object }).text as { value: string })\n            .value;\n        }\n      }\n      return context;\n    };\n    // Time formatting\n    const timeFormat = (timestamp: number | undefined) => {\n      if (!timestamp) {\n        return \"\";\n      }\n      const date = new Date(timestamp * 1000);\n      // Extract the individual date and time components\n      const year = date.getFullYear();\n      const month = String(date.getMonth() + 1).padStart(2, \"0\"); // getMonth() is zero-based, so add 1, then zero-pad\n      const day = 
String(date.getDate()).padStart(2, \"0\"); // Zero padding\n      const hours = String(date.getHours()).padStart(2, \"0\"); // Zero padding\n      const minutes = String(date.getMinutes()).padStart(2, \"0\"); // Zero padding\n      const seconds = String(date.getSeconds()).padStart(2, \"0\"); // Zero padding\n      // Format as \"YYYY-MM-DD HH: mm: ss\"\n      const formattedDate = `${year}-${month}-${day} ${hours}:${minutes}:${seconds}`;\n      return formattedDate;\n    };\n    onMounted(() => {\n      adjustHeight();\n    });\n    return {\n      inputQuestion,\n      inputDisabled,\n      msgBttnBoxShow,\n      localMessages,\n      textareaHeight,\n      answer,\n      StopOutput,\n      isNotChating,\n      handleInput,\n      chatBox_Ref,\n      adjustHeight,\n      content_Ref,\n      markedText,\n      timeFormat,\n      createText,\n      inputPlaceholder,\n      keyBoardCommitQuestion,\n      clickCommitQuestion,\n      messageScroll,\n      showScrollButton,\n      commitQuestion,\n      scrollToBottom,\n      scrollToTop,\n      isRunning,\n      copy,\n      replaceAssistant,\n      textarea_ref,\n    };\n  },\n});\n</script>\n\n<style scoped lang=\"stylus\">\n@import '@/assets/css/mixins.styl';\n\n.chat-panel {\n  justify-content: center;\n  display: flex;\n  position: relative;\n  height: 100%;\n\n  .chat-model {\n    font-size: 16px;\n    font-weight: bold;\n    position: absolute;\n    top: 20px;\n    left: 30px;\n  }\n\n  .chat-panel-inner {\n    width: 920px;\n    padding-top: 80px;\n  }\n\n  .chat-init {\n    padding: 0 20px;\n\n    .assistant-info {\n      text-align: center;\n      align-items: center;\n      justify-content: center;\n\n      .avatar img {\n        width: 70px;\n        height: 70px;\n      }\n\n      .name {\n        margin: 40px 0;\n        font-size: 20px;\n        font-weight: bold;\n      }\n\n      .desc {\n        color: $gray_40;\n      }\n    }\n\n    .assistant-tips {\n      margin-bottom: 80px;\n\n      .tips-item 
{\n        width: 44%;\n        height: 70px;\n        line-height: 70px;\n        float: left;\n        border: 1px solid $border_gray_light_normal;\n        border-radius: 8px;\n        margin-top: 10px;\n        margin-bottom: 10px;\n        padding: 0 20px;\n        color: $gray_40;\n\n        &:nth-child(odd) {\n          margin-left: 4%;\n          margin-right: 4%;\n        }\n\n        &:nth-child(even) {\n          margin-right: 4%;\n        }\n\n        .tips-ops {\n          display: none;\n          width: 24px;\n          height: 24px;\n          line-height: 24px;\n          border-radius: 4px;\n          text-align: center;\n          border: 1px solid $border_gray_light_normal;\n\n          i {\n            font-size: 20px;\n          }\n        }\n\n        &:hover {\n          cursor: pointer;\n          background-color: $bg_gray_light_hover;\n\n          .tips-ops {\n            display: block;\n            background-color: #FFFFFF;\n          }\n        }\n      }\n    }\n  }\n\n  .chat-msg {\n    overflow-y: hidden;\n\n    ul {\n      li.chat-msg-item {\n        margin-bottom: 40px;\n        align-items: flex-start !important;\n        // border: 1px solid;\n        border-radius: 15px;\n        padding: 20px;\n        margin-right: 20px;\n        background-color: #313344;\n        box-shadow: 12.5px 12.5px 10px rgba(0, 0, 0, 0.035), 10px 10px 8px rgba(0, 0, 0, 0.07);\n\n        .avatar {\n          margin-right: 15px;\n          width: 36px;\n          height: 36px;\n\n          img {\n            width: 100%;\n            height: 100%;\n            border-radius: 25px;\n          }\n        }\n\n        .msg {\n          .title {\n            display: flex;\n            align-items: center;\n            justify-content: space-between;\n            margin-bottom: 12px;\n            height: 36px;\n            line-height: 24px;\n\n            .time {\n              justify-content: center;\n              // margin-bottom: 12px;\n             
 line-height: 20px;\n              font-size: 14px;\n              color: $gray_80;\n            }\n\n            .name {\n              color: #edf2ea;\n              font-size: 16px;\n              font-weight: bold;\n              margin-right: 15px;\n            }\n\n            .tips {\n              font-size: 14px;\n              color: $gray_50;\n            }\n          }\n\n          .content {\n            max-width: 829px;\n            color: #edf2ea;\n            font-size: 14px;\n            line-height: 20px;\n            word-wrap: break-word;\n            margin-bottom: 12px;\n          }\n\n          .copy-btn {\n            margin-top: 10px;\n            justify-content: left;\n\n            i {\n              font-size: 20px;\n              color: $gray_70;\n\n              &:hover {\n                cursor: pointer;\n                color: $gray_50;\n\n                .tips-ops {\n                  display: block;\n                  background-color: #FFFFFF;\n                }\n              }\n            }\n          }\n        }\n      }\n    }\n  }\n\n  .chat-send {\n    width: 900px;\n    padding: 40px 0;\n    position: relative;\n\n    .chat-box {\n      width: 100%;\n      height: auto;\n      min-height: 48px;\n      max-height: 192px !important;\n      border: none;\n      border-radius: 15px;\n      background: white;\n      line-height: 48px;\n\n      // overflow: hidden;\n      .chat-input {\n        height: auto;\n        min-width: 900px;\n        max-height: 192px !important;\n        width: 100%;\n        border: none;\n        overflow-anchor: auto;\n        overflow-x: hidden;\n        overflow-y: auto;\n        resize: none;\n        background: white;\n        display: inline-block;\n      }\n\n      .chat-input::-webkit-scrollbar {\n        width: 10px;\n      }\n\n      .chat-input::-webkit-scrollbar-track {\n        background-color: #f1f1f1;\n      }\n\n      .chat-input::-webkit-scrollbar-thumb {\n        
background-color: #888;\n        border-radius: 5px;\n      }\n\n      .chat-input::-webkit-scrollbar-thumb:hover {\n        background-color: #555;\n      }\n\n      .chat-input::-webkit-resizer {\n        display: none;\n      }\n\n      .stop-btn {\n        border: none;\n        width: 60px;\n        position: absolute;\n        right: 50%;\n        transform: translateX(50%);\n        top: -40px;\n        -webkit-border-radius: 50;\n        -moz-border-radius: 50;\n        border-radius: 50px;\n        font-family: Arial;\n        color: #ffffff;\n        font-size: 16px;\n        background: #cacdd1;\n        padding: 10px 15px 10px 15px;\n        text-decoration: none;\n      }\n\n      .stop-btn:hover {\n        background: #8080e1;\n        text-decoration: none;\n        cursor: pointer;\n      }\n    }\n  }\n}\n\n.scroll-box {\n  position: absolute;\n  bottom: 130px;\n  right: 50%;\n  transform: translateX(50%);\n  margin: 0 auto;\n  width: 32px;\n  height: 32px;\n  border-radius: 16px;\n  border: 1px solid $gray_80;\n  background-color: var(--el-bg-color-overlay);\n  box-shadow: var(--el-box-shadow-lighter);\n  text-align: center;\n  line-height: 32px;\n  color: #1989fa;\n\n  i {\n    font-size: 24px;\n    color: $gray_60;\n  }\n\n  &:hover {\n    cursor: pointer;\n    background-color: $bg_gray_light_hover;\n\n    i {\n      color: $gray_50;\n    }\n  }\n}\n</style>"
  },
  {
    "path": "archive/ktransformers/website/src/conf/config.ts",
    "content": "declare global {\n    interface Window {\n      configWeb: {\n        apiUrl: string;\n        port: string;\n       };\n     }\n  }\n\nexport const baseURL = window.configWeb.apiUrl;\nexport const basePort = window.configWeb.port;\n"
  },
  {
    "path": "archive/ktransformers/website/src/locals/en.js",
    "content": "// en.js\nexport default {\n    home: {\n        explore: 'Explore',\n        language: 'Choose Language',\n        english: 'English',\n        chinese: 'Chinese',\n        today: 'Today',\n        previous:'Previous',\n        withoutAssistantTip:'The KTransformers of this record has been deleted. The user can only view historical conversation information and cannot continue the conversation!',\n        deleteThreadTip:'Deleting records will clear historical information~'\n    },\n    chat:{\n        inputTip:\"Send a message and chat with the KTransformers ~\",\n    },\n    explore:{\n        description: \"Based on Lexllama, let’s create your own KTransformers~\",\n        configuring: \"Configuring\",\n        completed: \"Completed\",\n        assistantName: \"Name\",\n        assistantDescription: \"Description\",\n        assistantStatus: \"Status\",\n        createAssistant: \"Create New KTransformers\",\n        deleteAssistant: \"Are you sure to delete this? After deleting the KTransformers, its KVCache will also be cleared simultaneously~\",\n    },\n    config:{\n        title:'Configure your KTransformers',\n        fileTip:\"Only support text, docx, .ppt, .pdf format.\",\n        reConfigTip:'Reconfig KTransformers needs to delete kvcache, please choose carefully',\n        secletFile:'Select Files',\n        outOfSize:'File size exceeds 10MB, please reselect',\n        fileExist:'The file already exists, please reselect',\n        createAssistant:'Assistant created successfully, click the build button to start building KVCache',\n    },\n    build:{\n        title:'Building Logs',\n        step1:'Parse uploded files',\n        parsingFileStep1:'File upload and reception completed',\n        parsingFileStep2:{\n            parse:\"Parsing\",\n            file:\"file(s)\",\n            total:'total',\n        },\n        parsingFileStep3:'Prompt loaded, ready to generate KVCache',\n        step2:'Generate KVCache',\n        
generateStep1:'Generate KVCache calculation plan',\n        generateStep2:{\n            calculate:\"calculating\",\n            token:\"tokens\",\n            total:'total',\n        },\n        generateStep3:'KVCache has been generated successfully',\n        durationTime:'Duration:',\n        remainTime:'Time left:',\n        buildProgress:'Building Progress',\n        storageUsage:'KVCache Storage Usage',\n    }\n}\n"
  },
  {
    "path": "archive/ktransformers/website/src/locals/index.js",
    "content": "// index.js\nimport { createI18n } from 'vue-i18n'\nimport zh from './zh'\nimport en from './en'\n\nconst messages = {\n  en,\n  zh,\n}\nconst language = (navigator.language || 'en').toLocaleLowerCase() // 这是获取浏览器的语言\nconst i18n = createI18n({\n  legacy: false, // you must set `false`, to use Compostion API\n  locale: localStorage.getItem('lang') || language.split('-')[0] || 'en', // 首先从缓存里拿，没有的话就用浏览器语言，\n  fallbackLocale: 'en', // 设置备用语言\n  messages, \n})\n\nexport default i18n"
  },
  {
    "path": "archive/ktransformers/website/src/locals/zh.js",
    "content": "// zh.js\nexport default {\n    home: {\n        explore: '探索',\n        language: '选择语言',\n        english: '英语',\n        chinese: '中文',\n        today: '今天',\n        previous:'历史',\n        withoutAssistantTip:'本记录的KTransformers已被删除，用户只能查看历史对话信息而无法继续对话!',\n        deleteThreadTip:'删除记录会清除历史信息哦～'\n    },\n    chat:{\n        inputTip:\"发送信息和 KTransformers 畅聊吧～\",\n    },\n    explore:{\n        description: \"基于Lexllama，一起来创建你的专属KTransformers吧~\",\n        configuring: \"配置中\",\n        completed: \"完成\",\n        assistantName: \"名称\",\n        assistantDescription: \"描述\",\n        assistantStatus: \"Status\",\n        createAssistant: \"创建新的KTransformers\",\n        deleteAssistant: \"是否确认删除KTransformers，删除KTransformers之后其KVCache也会被同步清理掉哦~\",\n    },\n    config:{\n        title:'配置你的KTransformers',\n        fileTip:\"仅支持上传文件格式为 .text, docx, .ppt, .pdf format.\",\n        secletFile:'选择文件',\n        outOfSize:'文件大小超出10MB，请重新选择',\n        fileExist:'文件已存在，请重新选择',\n        createAssistant:'KTransformers创建成功，点击build按钮开始构建KVCache',\n    },\n    build:{\n        title:'构建日志',\n        step1:'解析上传文件',\n        parsingFileStep1:'文件上传接收完成',\n        parsingFileStep2:{\n            parse:\"正在解析第\",\n            file:\"文件\",\n            total:'共',\n        },\n        parsingFileStep3:'Prompt装载完毕，准备生成KVCache',\n        step2:'生成 KVCache',\n        generateStep1:'生成KVCache计算计划',\n        generateStep2:{\n            calculate:\"正在计算\",\n            token:\"tokens\",\n            total:'共',\n        },\n        generateStep3:'KVCache已生成完成',\n        durationTime:'持续时间：',\n        remainTime:'剩余时间：',\n        buildProgress:'构建进度',\n        storageUsage:'存储使用：',\n        \n    }\n}\n"
  },
  {
    "path": "archive/ktransformers/website/src/main.ts",
    "content": "import { createApp } from 'vue'\nimport App from './App.vue'\nimport router from './router'\nimport store from './store'\nimport ElementPlus from 'element-plus'\nimport 'element-plus/dist/index.css'\nimport VueApexCharts from \"vue3-apexcharts\"\nimport i18n from '@/locals'\n\nconst app = createApp(App)\n\napp.use(ElementPlus)\n\napp.use(i18n)\napp.use(VueApexCharts)\napp.use(store)\napp.use(router)\napp.mount('#app')\n"
  },
  {
    "path": "archive/ktransformers/website/src/router/index.ts",
    "content": "import { createRouter, createWebHashHistory, RouteRecordRaw, createWebHistory } from 'vue-router'\nimport HomeView from '@/views/home.vue'\n\nconst routes: Array<RouteRecordRaw> = [\n  {\n    path: '/',\n    name: 'home',\n    component: HomeView,\n    redirect: '/chat',\n    children: [{\n      path: '/chat',\n      name: '',\n      component: () => import(/* webpackChunkName: \"about\" */ '../components/chat/index.vue')\n    },]\n  },\n\n]\n\nconst router = createRouter({\n  history: createWebHashHistory(),\n  routes\n})\n\nexport default router\n"
  },
  {
    "path": "archive/ktransformers/website/src/shims-vue.d.ts",
    "content": "/* eslint-disable */\ndeclare module '*.vue' {\n  import type { DefineComponent } from 'vue'\n  const component: DefineComponent<{}, {}, any>\n  export default component\n  \n}\n\ndeclare module '@/locals'\ndeclare module 'pdfobject';\n"
  },
  {
    "path": "archive/ktransformers/website/src/store/index.ts",
    "content": "import { createStore } from 'vuex'\n\nexport default createStore({\n  state: {\n  },\n  getters: {\n  },\n  mutations: {\n  },\n  actions: {\n  },\n  modules: {\n  }\n})\n"
  },
  {
    "path": "archive/ktransformers/website/src/utils/copy.ts",
    "content": "import { ElMessage } from \"element-plus\";\nconst copy = (value: string) => {\n  //Try using the navigator.clipboard.writeText method\n  if (navigator.clipboard && window.isSecureContext) {\n    navigator.clipboard.writeText(value)\n      .then(() => {\n        //Using ElMessage to Display Success Messages in Windows Systems\n        if (navigator.appVersion.includes(\"Win\")) {\n          ElMessage({\n            message: \"内容复制成功!\",\n            type: \"success\",\n            plain: true,\n          });\n        } else {\n          //Using custom DOM elements to display success messages in macOS system\n          showCopySuccessMessage();\n        }\n      })\n      .catch(() => {\n        //Using ElMessage to Display Failure Messages in Windows Systems\n        if (navigator.appVersion.includes(\"Win\")) {\n          ElMessage({\n            message: \"内容复制失败!\",\n            type: \"error\",\n            plain: true,\n          });\n        } else {\n          //Using custom DOM elements to display failure messages in macOS system\n          showCopyErrorMessage();\n        }\n      });\n  } else {\n    const textarea = document.createElement(\"textarea\");\n    textarea.value = value;\n    document.body.appendChild(textarea);\n    textarea.select();\n    try {\n      const successful = document.execCommand('copy');\n      if (successful) {\n        if (navigator.appVersion.includes(\"Win\")) {\n          ElMessage({\n            message: \"内容复制成功!\",\n            type: \"success\",\n            plain: true,\n          });\n        } else {\n          showCopySuccessMessage();\n        }\n      } else {\n        if (navigator.appVersion.includes(\"Win\")) {\n          ElMessage({\n            message: \"内容复制失败!\",\n            type: \"error\",\n            plain: true,\n          });\n        } else {\n          showCopyErrorMessage();\n        }\n      }\n    } catch (err) {\n      if (navigator.appVersion.includes(\"Win\")) {\n        
ElMessage({\n          message: \"内容复制失败!\",\n          type: \"error\",\n          plain: true,\n        });\n      } else {\n        showCopyErrorMessage();\n      }\n    }\n    document.body.removeChild(textarea);\n  }\n};\n\nfunction showCopySuccessMessage() {\n  const messageElement = document.createElement('div');\n  messageElement.textContent = '内容复制成功!';\n  messageElement.style.position = 'fixed';\n  messageElement.style.bottom = '10px';\n  messageElement.style.left = '50%';\n  messageElement.style.transform = 'translateX(-50%)';\n  messageElement.style.padding = '10px';\n  messageElement.style.backgroundColor = '#4CAF50';\n  messageElement.style.color = 'white';\n  messageElement.style.borderRadius = '15px';\n  messageElement.style.zIndex = '1000';\n  document.body.appendChild(messageElement);\n  setTimeout(() => {\n    document.body.removeChild(messageElement);\n  }, 3000);\n}\n\nfunction showCopyErrorMessage() {\n  const messageElement = document.createElement('div');\n  messageElement.textContent = '内容复制失败!';\n  messageElement.style.position = 'fixed';\n  messageElement.style.bottom = '10px';\n  messageElement.style.left = '50%';\n  messageElement.style.transform = 'translateX(-50%)';\n  messageElement.style.padding = '10px';\n  messageElement.style.backgroundColor = '#F44336';\n  messageElement.style.color = 'white';\n  messageElement.style.borderRadius = '5px';\n  messageElement.style.zIndex = '1000';\n  document.body.appendChild(messageElement);\n  setTimeout(() => {\n    document.body.removeChild(messageElement);\n  }, 3000);\n}\n\nexport default copy;"
  },
  {
    "path": "archive/ktransformers/website/src/utils/types.ts",
    "content": "export interface IAssistant {\n  id: string;\n  object: string;\n  created_at: number;\n  name?: string;\n  description?: string;\n  model: string;\n  instructions?: string;\n  tools: any[];\n  tool_resources?: object;\n  metadata?:{[key:string]:any}\n  top_p?: number;\n  temperature?: number;\n  response_format: string | object;\n}\n\nexport interface IAssistantWithStatus {\n  build_status:{status:string}\n  id: string;\n  object: string;\n  created_at: number;\n  name?: string;\n  description?: string;\n  model: string;\n  instructions?: string;\n  tools: any[];\n  tool_resources?: object;\n  metadata?:{[key:string]:any}\n  top_p?: number;\n  temperature?: number;\n  response_format: string | object;\n}\n\nexport interface IMessage {\n  id: string;\n  object: string;\n  created_at: number;\n  thread_id: string;\n  status: string;\n  incomplete_details?: object;\n  completed_at?: number;\n  incomplete_at?: number;\n  role: string;\n  content: any[];\n  assistant_id?: string;\n  run_id?: string;\n  attachments?: any[];\n  metadata:{[key:string]:any}\n}\n\nexport interface IThread {\n  id: string;\n  object: string;\n  created_at: number;\n  tool_resources?: object;\n  metadata?:{[key:string]:any}\n}\n\nexport interface IRun {\n  id: string;\n  object: string;\n  created_at: number;\n  thread_id: string,\n  assistant_id: string,\n  status: string,\n  required_action?: object,\n  last_error?: object,\n  expires_at?: number,\n  started_at?: number,\n  cancelled_at?: number,\n  failed_at?: number,\n  completed_at?: number,\n  incomplete_details?: object,\n  model: string,\n  instructions: string,\n  tools: any[],\n  metadata: Map<string, string>,\n  usage?: object,\n  temperature?: number,\n  top_p?: number,\n  max_prompt_tokens?: number,\n  max_completion_tokens?: number,\n  truncation_strategy: object,\n  tool_choice: string | object,\n  response_format: string | object,\n}\n\nexport interface IFile {\n  id: string,\n  bytes: number,\n  created_at: 
number,\n  filename: string,\n  object: string,\n  purpose: string,\n}\n\nexport interface IMessageData {\n  role: string;\n  content: any[];\n  created_at?: number;\n  assistant_id?: string,\n}\n\nexport interface IThreadAndMessageAndAssistant {\n\n  thread: IThread;\n  first_message: IMessage;\n  assistant: IAssistantWithStatus\n}\nexport interface IDeleteResult {\n  id: string;\n  object: string;\n  deleted: boolean;\n}\nexport interface IBuildData {\n  parsed_file_count:number;\n  total_file_count:number;\n  prefilling_current:number;\n  prefilling_total:number;\n  build_completed_time:number;\n  build_started_time:number;\n  storage_total:number;\n  storage_usage:number;\n  status:string\n}"
  },
  {
    "path": "archive/ktransformers/website/src/views/home.vue",
    "content": "<template>\n  <div class=\"home flex-row\">\n    <nav class=\"left-panel flex-column\">\n      <div class=\"logo-box\">\n        <div class=\"logo flex-row\">\n          <img class=\"img\" src=\"../../public/images/three.png\" />\n          <span class=\"text\">{{ projectName }}</span>\n        </div>\n        <div class=\"version\">{{ projectVersion }}</div>\n      </div>\n      <div class=\"divider\"></div>\n      <div class=\"assistant-box\">\n        <div class=\"assistant-list\">\n          <ul>\n            <li\n              class=\"assistant-item flex-row\"\n              v-for=\"(item, index) in assistantList\"\n              :key=\"index\"\n              @click=\"setActiveAssistant(item)\"\n            >\n              <img src=\"../../public/images/avatar.png\" />\n              <span class=\"name flex-unit\">{{ item.name }}</span>\n              <i class=\"iconfont icon-edit\"></i>\n            </li>\n          </ul>\n        </div>\n      </div>\n      <div class=\"divider\"></div>\n      <!-- History area -->\n      <div class=\"history-box flex-unit\">\n        <div class=\"\">\n          <div class=\"date\">{{ $t(\"home.today\") }}</div>\n          <ul>\n            <li\n              v-for=\"(item, index) in todayThreads\"\n              :key=\"index\"\n              class=\"chat-item\"\n              :class=\"{ active: activeThreadIndex === index }\"\n              @click=\"setActiveThreadIndex(index)\"\n            >\n              <div class=\"chat-abbr\">\n                {{ firstMessages[index] }}\n              </div>\n              <div class=\"chat-ops flex-row\">\n                <img src=\"../../public/images/avatar.png\" />\n                <div class=\"name flex-unit\">\n                  {{ assistantOfThread[index].name || \"\" }}\n                </div>\n                <i class=\"iconfont icon-delete\" @click=\"delThread(index)\"></i>\n              </div>\n            </li>\n          </ul>\n          <div 
class=\"date\" v-if=\"previousThreads.length > 0\">\n            {{ $t(\"home.previous\") }}\n          </div>\n          <ul>\n            <li\n              v-for=\"(item, index) in previousThreads\"\n              :key=\"index\"\n              class=\"chat-item\"\n              :class=\"{\n                active: activeThreadIndex === index + todayThreads.length,\n              }\"\n              @click=\"setActiveThreadIndex(index + todayThreads.length)\"\n            >\n              <div class=\"chat-abbr\">\n                {{ firstMessages[index + todayThreads.length] }}\n              </div>\n              <div class=\"chat-ops flex-row\">\n                <img src=\"../../public/images/avatar.png\" />\n                <div class=\"name flex-unit\">\n                  {{\n                    assistantOfThread[index + todayThreads.length].name || \"\"\n                  }}\n                </div>\n                <i\n                  class=\"iconfont icon-delete\"\n                  @click=\"delThread(index + todayThreads.length)\"\n                ></i>\n              </div>\n            </li>\n          </ul>\n        </div>\n      </div>\n      <div class=\"icon-box example-2\">\n        <div class=\"iconhub icon-content\" @click=\"navigateToIconHub\">\n          <svg\n            xmlns=\"http://www.w3.org/2000/svg\"\n            width=\"16\"\n            height=\"16\"\n            fill=\"currentColor\"\n            class=\"bi bi-github\"\n            viewBox=\"0 0 16 16\"\n            xml:space=\"preserve\"\n          >\n            <path\n              d=\"M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27s1.36.09 2 .27c1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 
1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.01 8.01 0 0 0 16 8c0-4.42-3.58-8-8-8\"\n              fill=\"currentColor\"\n            ></path>\n          </svg>\n          <div class=\"tooltip\">GitHub</div>\n        </div>\n        <div class=\"iconlanguage\" @click=\"changeLanguage\">\n          <svg\n            v-if=\"!flag\"\n            t=\"1719306572024\"\n            class=\"icon\"\n            viewBox=\"0 0 1024 1024\"\n            version=\"1.1\"\n            xmlns=\"http://www.w3.org/2000/svg\"\n            p-id=\"16849\"\n            data-spm-anchor-id=\"a313x.search_index.0.i21.366e3a81tz0TYS\"\n            width=\"18\"\n            height=\"18\"\n          >\n            <path\n              d=\"M64.064 768V192H448.64v64H127.936v192h320v64h-320v192h320v64H64.064z m511.872 0V192h64l256 447.68V192h64v576h-64l-256-447.168V768h-64z\"\n              p-id=\"16850\"\n              data-spm-anchor-id=\"a313x.search_index.0.i22.366e3a81tz0TYS\"\n              class=\"selected\"\n              fill=\"#000000\"\n            ></path>\n          </svg>\n          <svg\n            v-else\n            t=\"1719306494614\"\n            class=\"icon\"\n            viewBox=\"0 0 1024 1024\"\n            version=\"1.1\"\n            xmlns=\"http://www.w3.org/2000/svg\"\n            p-id=\"12325\"\n            width=\"18\"\n            height=\"18\"\n          >\n            <path\n              d=\"M1023.488 831.552h-96l-265.472-451.904c-8.96-12.8-16-25.344-21.44-37.888H638.08c2.176 12.992 3.2 40.128 3.2 81.408v408.32L576 836.928V256h101.568l257.024 445.632c14.592 20.992 23.232 34.368 25.92 40.128h1.6c-2.688-16.512-4.032-44.8-4.032-84.736v-399.36L1024 256l-0.512 575.552zM435.008 804.224c-42.752 21.76-96.384 32.64-160.896 32.64-83.2 0-149.76-25.6-199.488-76.736C24.896 708.928 0 641.344 0 557.12c0-90.432 27.968-163.2 84.032-218.368C140.032 283.52 211.072 256 297.344 256c55.552 0 101.376 7.616 
137.6 22.848v75.84a284.992 284.992 0 0 0-136.832-33.408c-64.768 0-117.504 20.864-158.208 62.592-40.768 41.728-61.184 98.048-61.184 168.96 0 67.2 19.008 120.576 57.024 160.128 38.016 39.552 87.744 59.328 149.248 59.328 57.536 0 107.52-12.544 150.016-37.76v69.696z\"\n              fill=\"#000000\"\n              p-id=\"12326\"\n              data-spm-anchor-id=\"a313x.search_index.0.i16.366e3a81tz0TYS\"\n              class=\"selected\"\n            ></path>\n          </svg>\n        </div>\n      </div>\n    </nav>\n    <router-view v-slot=\"{ Component }\" class=\"main-panel flex-unit\">\n      <component\n        :is=\"Component\"\n        :chatInit=\"chatInit\"\n        :activeAssistant=\"activeAssistant\"\n        :activeThread=\"activeThread\"\n        :messages=\"allMessageInCurrentThread\"\n        :completedAssistant=\"assistantList\"\n        :inputDisabled=\"inputDisabled\"\n        @updateAssistant=\"handleUpdateAssistant\"\n      />\n    </router-view>\n  </div>\n</template>\n\n<script lang=\"ts\">\nimport { defineComponent, ref, onMounted, computed, nextTick } from \"vue\";\nimport {\n  IThread,\n  IAssistant,\n  IMessageData,\n  IThreadAndMessageAndAssistant,\n  IAssistantWithStatus,\n} from \"@/utils/types\";\nimport { listThreads, deleteThread, getThread } from \"@/api/thread\";\nimport { ElMessage, ElMessageBox } from \"element-plus\";\nimport { listAssistants } from \"@/api/assistant\";\nimport { listMessages } from \"@/api/message\";\nimport { useRouter } from \"vue-router\";\nimport BScroll from \"better-scroll\";\nimport { useI18n } from \"vue-i18n\";\n\nexport default defineComponent({\n  name: \"HomeView\",\n  setup() {\n    const assistantList = ref<IAssistant[]>([]);\n    const threadsList = ref<IThread[]>([]);\n    const firstMessages = ref<string[]>([]);\n    const activeAssistant = ref({} as IAssistant);\n    const assistantOfThread = ref<IAssistantWithStatus[]>([]);\n    const threadAndMessages = 
ref<IThreadAndMessageAndAssistant[]>([]);\n    const assistantScroll = ref<BScroll | null>(null);\n    const historyScroll = ref<BScroll | null>(null);\n    const router = useRouter();\n    const { t, locale } = useI18n();\n    const flag = ref(true);\n    const changeLanguage = () => {\n      if (flag.value) {\n        locale.value = \"zh\";\n        localStorage.setItem(\"lang\", \"zh\");\n        flag.value = false;\n      } else {\n        locale.value = \"en\";\n        flag.value = true;\n        localStorage.setItem(\"lang\", \"en\");\n      }\n    };\n    // Initialize data\n    const initData = async () => {\n      try {\n        threadsList.value = [];\n        firstMessages.value = [];\n        assistantOfThread.value = [];\n\n        const assistantsRes = await listAssistants();\n        if (assistantsRes && assistantsRes.length > 0) {\n          assistantList.value = assistantsRes;\n          activeAssistant.value = assistantsRes[0];\n        }\n\n        const threadsRes = await listThreads(100);\n        if (threadsRes) {\n          threadAndMessages.value = threadsRes;\n          for (let t of threadsRes) {\n            if (t.thread && !t.thread.metadata?.hidden) {\n              threadsList.value.push(t.thread);\n              if (\n                t.first_message &&\n                t.first_message.content &&\n                t.first_message.content.length > 0\n              ) {\n                firstMessages.value.push(t.first_message.content[0].text.value);\n              } else {\n                firstMessages.value.push(\"no message yet\");\n              }\n              assistantOfThread.value.push(\n                t.assistant || ({} as IAssistantWithStatus)\n              );\n            }\n          }\n        }\n\n        assistantScroll.value = new BScroll(\".assistant-list\", {\n          click: true,\n          mouseWheel: true,\n          scrollbar: {\n            fade: true,\n            interactive: true,\n          },\n        
});\n\n        historyScroll.value = new BScroll(\".history-box\", {\n          click: true,\n          mouseWheel: true,\n          scrollbar: {\n            fade: true,\n            interactive: true,\n          },\n        });\n      } catch (err) {\n        console.error(\"Failed to initialize data:\", err);\n      }\n    };\n    const navigateToIconHub = () => {\n      window.open(\"https://github.com/kvcache-ai/Lexllama\");\n    };\n    const isEmptyObject = (obj: object): boolean => {\n      //Determine if the object is empty\n      return Object.keys(obj).length === 0;\n    };\n    //Jump route\n    const navigateToExplore = () => {\n      router.push(\"/explore\");\n    };\n    const navigatorToChat = () => {\n      router.push(\"/chat\");\n    };\n    // Calculate date\n    const todayThreads = computed(() => {\n      const today = Math.floor(Date.now() / 1000);\n      return threadsList.value.filter((thread) => {\n        return today - thread.created_at <= 86400;\n      });\n    });\n    const previousThreads = computed(() => {\n      const today = Math.floor(Date.now() / 1000);\n      return threadsList.value.filter((thread) => {\n        return today - thread.created_at > 86400;\n      });\n    });\n\n    onMounted(async () => {\n      initData();\n    });\n\n    return {\n      t,\n      flag,\n      assistantList,\n      isEmptyObject,\n      activeAssistant,\n      navigateToExplore,\n      navigatorToChat,\n      threadsList,\n      firstMessages,\n      navigateToIconHub,\n      assistantScroll,\n      historyScroll,\n      assistantOfThread,\n      changeLanguage,\n      initData,\n      todayThreads,\n      previousThreads,\n    };\n  },\n  data() {\n    return {\n      projectName: \"KTransformers\",\n      projectVersion: \"v0.01\",\n      activeThreadIndex: -1,\n      chatInit: true,\n      activeThread: {} as IThread,\n      allMessageInCurrentThread: [] as IMessageData[],\n      inputDisabled: false,\n      isSettingActiveThread: false,\n  
    isDeletingThread: false,\n      threadAndMessages: <IThreadAndMessageAndAssistant[]>[],\n    };\n  },\n  methods: {\n    setActiveAssistant(assistant: IAssistant) {\n      this.chatInit = true;\n      this.inputDisabled = false;\n      this.activeThreadIndex = -1;\n      this.activeAssistant = assistant;\n      this.activeThread = {} as IThread;\n      this.allMessageInCurrentThread = [];\n      if (this.$route.path != \"/chat\") {\n        this.navigatorToChat();\n      }\n    },\n    async setActiveThreadIndex(index: number) {\n      //If setting up an active thread, return directly\n      if (this.isSettingActiveThread) {\n        return;\n      }\n      this.isSettingActiveThread = true;\n      this.activeThreadIndex = index;\n      this.chatInit = false;\n      this.inputDisabled = false;\n      this.activeAssistant = {} as IAssistant;\n      this.activeThread = this.threadsList[index];\n      //If the assistant of the current thread is an empty object\n      if (this.isEmptyObject(this.assistantOfThread[index])) {\n        ElMessage({\n          message: this.t(\"home.withoutAssistantTip\"),\n          type: \"warning\",\n        });\n        this.inputDisabled = true;\n      }\n      try {\n        //Call asynchronous function to obtain the message list of the current thread\n        const res = await listMessages(this.activeThread.id, 100, \"asc\");\n        //Convert the obtained message list to the specified format and assign values to all messages of the current thread\n        this.allMessageInCurrentThread = res.map((m) => ({\n          role: m.role,\n          content: m.content,\n          assistant_id: m.assistant_id,\n          created_at: m.created_at,\n        }));\n      } catch (err) {\n        console.log(err);\n      } finally {\n        this.isSettingActiveThread = false;\n      }\n      if (this.$route.path != \"/chat\") {\n        this.navigatorToChat();\n      }\n    },\n\n    async delThread(index: number) {\n      // If the thread 
is currently being deleted, return directly\n      if (this.isDeletingThread) {\n        return;\n      }\n      this.isDeletingThread = true;\n      try {\n        //Pop up a confirmation box and ask the user if they are sure to delete the thread\n        await ElMessageBox.confirm(this.t(\"home.deleteThreadTip\"), \"Warning\", {\n          confirmButtonText: \"OK\",\n          cancelButtonText: \"Cancel\",\n          type: \"warning\",\n        });\n\n        const res = await deleteThread(this.threadsList[index].id);\n        this.threadsList.splice(index, 1);\n        this.firstMessages.splice(index, 1);\n        this.assistantOfThread.splice(index, 1);\n        // Jump to the first assistant or other suitable page\n        this.setActiveAssistant(this.assistantList[0]);\n        ElMessage({\n          type: \"success\",\n          message: \"Delete completed\",\n        });\n      } catch (err) {\n        // Specific error handling, such as logging or displaying specific error messages to users\n        console.error(\"Delete session failed:\", err);\n        ElMessage({\n          type: \"error\",\n          message: `Delete failed`, // Display specific error messages\n        });\n      } finally {\n        this.isDeletingThread = false; //Ensure that the delete thread flag is reset no matter what\n      }\n    },\n    // Handles the update of the assistant asynchronously.\n    async handleUpdateAssistant(value: any) {\n      await this.initData();\n      if (this.activeThreadIndex != -1) {\n        this.setActiveThreadIndex(this.activeThreadIndex);\n      } else if (this.activeAssistant.id) {\n        this.setActiveThreadIndex(0);\n      } else {\n        this.setActiveAssistant(this.assistantList[0]);\n      }\n    },\n  },\n});\n</script>\n\n\n<style lang=\"stylus\" rel=\"stylesheet/stylus\" scoped>\n@import '../assets/css/mixins.styl';\n\n.home {\n  width: 100%;\n  height: 100%;\n  position: relative;\n}\n\n.left-panel {\n  width: 320px;\n  height: 
100%;\n  background-color: #363433;\n  padding: 30px 30px;\n  .logo-box {\n    .logo {\n      .img {\n        width: 36px;\n        height: 36px;\n      }\n\n      .text {\n        font-size: 28px;\n        font-weight: bold;\n        margin-left: 10px;\n        color: #edf2ea;\n      }\n    }\n\n    .version {\n      text-align: right;\n      font-size: 14px;\n      color: #bdbdbd;\n    }\n  }\n\n  .divider {\n    border-bottom: 1px solid #D7D7D7;\n    width: 30%;\n    margin: 30px auto;\n  }\n\n  .lang-box {\n    position: relative;\n    width: 100%;\n    height: 30px;\n    margin: auto;\n    margin-bottom: 10px;\n\n    .el-dropdown {\n      font-size: 14px;\n      position: absolute;\n      top: 50%;\n      left: 50%;\n      transform: translate(-50%, -50%);\n    }\n  }\n\n  .assistant-box {\n    .assistant-list {\n      min-height: 50px;\n      max-height: 300px;\n      overflow: hidden;\n      position: relative;\n\n      ul > li.assistant-item {\n        padding: 8px 15px;\n        color: #edf2ea;\n\n        img {\n          width: 32px;\n          height: 32px;\n        }\n\n        .name {\n          margin-left: 12px;\n          font-size: 14px;\n          color: #edf2ea;\n        }\n\n        i.iconfont {\n          display: none;\n          margin-left: 10px;\n        }\n\n        &:hover {\n          background-color: $bg_gray_light_hover;\n          cursor: pointer;\n          border-radius: 4px;\n\n          .name {\n            color: #313433;\n          }\n\n          i.iconfont {\n            display: block;\n          }\n        }\n      }\n    }\n\n    .explore {\n      position: relative;\n      justify-content: center;\n      display: flex;\n      margin-top: 10px;\n\n      .explore-btn {\n        margin: 0 auto;\n        padding: 0 20px;\n        justify-content: center;\n        height: 32px;\n        line-height: 32px;\n        background-color: #FFFFFF;\n        border: 1px solid RGBA(0, 0, 0, 0.15);\n        border-radius: 16px;\n\n        
i {\n          color: #8080FF;\n        }\n\n        .text {\n          color: #7F7F7F;\n          margin-left: 4px;\n        }\n\n        &:hover {\n          background-color: #FAFAFA;\n          cursor: pointer;\n        }\n      }\n    }\n  }\n\n  .history-box {\n    position: relative;\n\n    .date {\n      font-size: 14px;\n      color: #7F7F7F;\n      margin: 8px 0;\n\n      &:first-child {\n        margin-top: 0;\n      }\n    }\n\n    li.chat-item {\n      padding: 12px 15px;\n      cursor: pointer;\n      background-color: #edf2ea;\n      border-radius: 4px;\n      margin-bottom: 10px;\n      font-size: 16px;\n\n      .chat-abbr {\n        font-size: 14px;\n        color: #313433;\n        white-space: nowrap;\n        overflow: hidden;\n        text-overflow: ellipsis;\n      }\n\n      .chat-ops {\n        display: flex;\n        margin-top: 5px;\n\n        img {\n          width: 16px;\n          height: 16px;\n        }\n\n        .name {\n          font-size: 12px;\n          color: #898989;\n          margin-left: 8px;\n        }\n\n        i.iconfont {\n          color: $gray_60;\n        }\n      }\n\n      &:hover, &.active {\n        transition: 0.3s all;\n        cursor: pointer;\n        background-color: #a2a79f;\n        .chat-abbr {\n          color: black;\n        }\n\n        .name, i.iconfont {\n          color: black;\n        }\n      }\n    }\n  }\n\n  .icon-box {\n    width: 100%;\n    display: flex;\n    flex-direction: row;\n    justify-content: flex-end;\n    align-items: center;\n\n    .iconhub {\n      width: 32px;\n      height: 24px;\n      background: white;\n      font-size: 30px;\n      border: none;\n      overflow: hidden;\n      border-radius: 15%;\n      display: flex;\n      flex-direction: column;\n      justify-content: center;\n      align-items: center;\n      color: #898989;\n      transition: all 0.5s;\n      cursor: pointer;\n    }\n\n    .iconhub:hover {\n      background: #e5e5e5;\n      text-decoration: 
none;\n    }\n\n    .iconlanguage {\n      margin-left: 15px;\n      width: 32px;\n      height: 24px;\n      background: white;\n      font-size: 30px;\n      border: none;\n      overflow: hidden;\n      border-radius: 15%;\n      display: flex;\n      flex-direction: column;\n      justify-content: center;\n      align-items: center;\n      color: #898989;\n      transition: all 0.5s;\n      cursor: pointer;\n    }\n\n    .iconlanguage:hover {\n      background: #e5e5e5;\n      text-decoration: none;\n    }\n  }\n}\n\nul {\n  list-style: none;\n}\n\n.example-2 {\n  display: flex;\n  justify-content: center;\n  align-items: center;\n}\n\n.example-2 .icon-content {\n  margin: 0 10px;\n  position: relative;\n}\n\n.example-2 .icon-content .tooltip {\n  position: absolute;\n  top: -30px;\n  left: 50%;\n  transform: translateX(-50%);\n  color: #fff;\n  padding: 6px 10px;\n  border-radius: 5px;\n  opacity: 0;\n  visibility: hidden;\n  font-size: 14px;\n  transition: all 0.3s ease;\n}\n\n.example-2 .icon-content:hover .tooltip {\n  opacity: 1;\n  visibility: visible;\n  top: -50px;\n}\n\n.main-panel {\n  height: 100%;\n  background-color: #f1f0ed;\n}\n</style>\n"
  },
  {
    "path": "archive/ktransformers/website/tests/unit/example.spec.ts",
    "content": "import { shallowMount } from '@vue/test-utils'\nimport HelloWorld from '@/components/HelloWorld.vue'\n\ndescribe('HelloWorld.vue', () => {\n  it('renders props.msg when passed', () => {\n    const msg = 'new message'\n    const wrapper = shallowMount(HelloWorld, {\n      props: { msg }\n    })\n    expect(wrapper.text()).toMatch(msg)\n  })\n})\n"
  },
  {
    "path": "archive/ktransformers/website/tsconfig.json",
    "content": "{\n  \"compilerOptions\": {\n    \"target\": \"es5\",\n    \"module\": \"esnext\",\n    \"strict\": true,\n    \"jsx\": \"preserve\",\n    \"importHelpers\": true,\n    \"moduleResolution\": \"node\",\n    \"skipLibCheck\": true,\n    \"esModuleInterop\": true,\n    \"allowSyntheticDefaultImports\": true,\n    \"forceConsistentCasingInFileNames\": true,\n    \"useDefineForClassFields\": true,\n    \"sourceMap\": true,\n    \"allowJs\": true,\n    \"baseUrl\": \".\",\n    \"types\": [\n      \"webpack-env\",\n      \"jest\"\n    ],\n    \"paths\": {\n      \"@/*\": [\n        \"src/*\"\n      ]\n    },\n    \"lib\": [\n      \"esnext\",\n      \"dom\",\n      \"dom.iterable\",\n      \"scripthost\"\n    ]\n  },\n  \"include\": [\n    \"src/**/*.ts\",\n    \"src/**/*.tsx\",\n    \"src/**/*.vue\",\n    \"tests/**/*.ts\",\n    \"tests/**/*.tsx\",\n    \"config.d.ts\"\n  ],\n \n  \"exclude\": [\n    \"node_modules\"\n  ]\n}"
  },
  {
    "path": "archive/ktransformers/website/vue.config.js",
    "content": "\nmodule.exports = {\n  // 配置 webpack-dev-server 行为。\n  devServer: {\n    open: false, // 编译后默认打开浏览器\n    host: '0.0.0.0',  // 域名\n    port: 8082,  // 端口\n    https: false,  // 是否https\n    proxy: {\n        '/api': {\n          target: 'http://localhost:9016/v1', // 你的后端服务器地址\n          changeOrigin: true, // 是否允许跨域\n          pathRewrite: {\n            '/api': '' // 将 '/api' 前缀替换为空，如果你的后端不需要这个前缀\n          }\n        }\n      }\n},\npublicPath: '/web/',  // 基本路径\noutputDir: 'dist', // 构建时的输出目录\nassetsDir: 'static', // 放置静态资源的目录\nindexPath: 'index.html', // html 的输出路径\nfilenameHashing: true, // 文件名哈希值\nlintOnSave: false, // 是否在保存的时候使用 `eslint-loader` 进行检查。\n\n// 组件是如何被渲染到页面中的？ （ast：抽象语法树；vDom：虚拟DOM）\n// template ---> ast ---> render ---> vDom ---> 真实的Dom ---> 页面\n// runtime-only：将template在打包的时候，就已经编译为render函数\n// runtime-compiler：在运行的时候才去编译template\nruntimeCompiler: false,\n\ntranspileDependencies: [], // babel-loader 默认会跳过 node_modules 依赖。\nproductionSourceMap: false, // 是否为生产环境构建生成 source map\n\n//调整内部的 webpack 配置\nconfigureWebpack: () => {},\n\nchainWebpack: () => {},\n  \n}"
  },
  {
    "path": "archive/merge_tensors/merge_safetensor_gguf.py",
    "content": "# this script targets to merge the fp8 safe tensor and the gguf quantized tensors.\n\nimport os\n# insert the path of the project\nimport sys\n# sys.path.insert(0, \"/home/azure/ktransformers\")\nimport argparse\nimport torch\nfrom ktransformers.util.custom_loader import GGUFLoader, translate_name_to_gguf\nfrom safetensors import safe_open\nfrom safetensors.torch import save_file\nimport re\nfrom collections import defaultdict\n\ndef read_safetensor_keys_from_folder(folder_path)->dict:\n    \"\"\"    \n    :param folder_path: folder path\n    :return: key_to_file_map\n    \"\"\"\n    # check if the folder path is exist\n    if not os.path.exists(folder_path):\n        raise FileNotFoundError(f\"GGUF dir not found: {folder_path}\")\n    if os.path.isfile(folder_path):\n        folder_path = os.path.dirname(folder_path)\n    \n    key_to_file_map = {}\n\n    found_safetensor = False\n    for root, dirs, files in os.walk(folder_path):\n        # sort files\n        files = sorted(files)\n        for file in files:\n            if file.endswith(\".safetensors\"):\n                found_safetensor = True\n                file_path = os.path.join(root, file)\n                try:\n                    with safe_open(file_path, framework=\"pt\") as f:\n                        for key in f.keys():\n                            if \"model.layers.61\" in key:\n                                # skip MTP layer\n                                continue\n                            # try:\n                            #     if int(key.split('.')[2]) > 4:\n                            #         continue\n                            # except:\n                            #     pass\n                            key_to_file_map[key] = file_path\n                except Exception as e:\n                    print(f\"Error reading Safetensor file {file_path}: {e}\")\n    \n    if not found_safetensor:\n        raise FileNotFoundError(f\"No Safetensor files found in 
{folder_path}\")\n    \n    return key_to_file_map\n\ntensor_from_gguf = [] # todo: add keys in gguf that should be used in the final tensor\n\ndef translate_name(name:str)->str:\n    \"\"\"\n    :param name: name of the tensor\n    :return: translated name\n    \"\"\"\n    name = translate_name_to_gguf(name)\n    name = name.replace(\".up_proj.\", \".ffn_up_exps.\")\n    name = name.replace(\".down_proj.\", \".ffn_down_exps.\")\n    name = name.replace(\".gate_proj.\", \".ffn_gate_exps.\")\n    name = name.replace(\".ffn_gate_inp.e_score_correction_bias\", \".exp_probs_b.bias\") \n    return name\n    \n\ndef combine_tensor_sources(safetensor_path:str, gguf_path:str):\n    gguf_loader = GGUFLoader(gguf_path)\n    gguf_tensor_file_map = gguf_loader.tensor_file_map\n    safetensor_tensor_file_map = read_safetensor_keys_from_folder(safetensor_path)\n    \n    # build a map for the key to the tensor\n    # according to the key, we can get the tensor from the file\n    \n    target_tensor_map = {}\n    for key in safetensor_tensor_file_map.keys():\n        # for all experts, we use the gguf tensor\n        if \".mlp.experts.\" in key:\n            if '.weight_scale_inv' in key:\n                continue\n            key = '.'.join(key.split('.')[:5]+key.split('.')[-2:])\n            translated_key = translate_name(key)\n            target_tensor_map[key] = gguf_tensor_file_map[translated_key]\n            continue\n        \n        if any(target_key in key for target_key in tensor_from_gguf):\n            target_tensor_map[key] = gguf_tensor_file_map[translate_name(key)]\n        else:\n            target_tensor_map[key] = safetensor_tensor_file_map[key]\n    \n    return target_tensor_map, gguf_loader\n\ndef write_combined_tensor(target_tensor_map: dict, output_path: str, gguf_loader: GGUFLoader):\n    # Ensure output directory exists\n    os.makedirs(output_path, exist_ok=True)\n    \n    # Cache for safetensor file handles and GGUF loaders\n    safetensors_cache = 
{}\n    gguf_cache = {}\n    \n    # Group tensors by layer\n    layer_groups = defaultdict(list)\n    non_layer_keys = []\n    layer_pattern = re.compile(r'\\.layers\\.(\\d+)\\.')\n    \n    for key in target_tensor_map:\n        match = layer_pattern.search(key)\n        if match:\n            layer_num = int(match.group(1))\n            layer_groups[layer_num].append(key)\n        else:\n            non_layer_keys.append(key)\n    \n    # Calculate total shards\n    total_shards = len(layer_groups) + (1 if non_layer_keys else 0) - 1\n    if total_shards == 0:\n        raise ValueError(\"No tensors to save\")\n    \n    shard_idx = 0\n    \n    # Save non-layer tensors to the first shard if they exist\n    if non_layer_keys:\n        tensors = {}\n        for key in non_layer_keys:\n            file_path = target_tensor_map[key]\n            tensor = None\n            ggml_type = None\n            if file_path.endswith('.safetensors'):\n                if file_path not in safetensors_cache:\n                    safetensors_cache[file_path] = safe_open(file_path, framework='pt')\n                f = safetensors_cache[file_path]\n                tensor = f.get_tensor(key)\n            elif file_path.endswith('.gguf'):\n                gguf_name = translate_name(key)\n                tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name)\n            else:\n                raise ValueError(f\"Unsupported file format: {file_path}\")\n            tensors[translate_name(key)] = tensor\n            if ggml_type:\n                ggml_type = torch.tensor(ggml_type)\n                ggml_key = translate_name(key)[:-7] + \".ggml_type\" if translate_name(key).endswith(\".weight\") else translate_name(key) + \".ggml_type\"\n                tensors[ggml_key] = ggml_type\n        \n        output_file = os.path.join(output_path, f\"model-{shard_idx:05}-of-{total_shards:05}.safetensors\")\n        print(f\"Saving non-layer tensors to {output_file}\")\n  
      save_file(tensors, output_file)\n        print(tensors.keys())\n\n        shard_idx += 1\n    \n    # Save each layer's tensors to subsequent shards\n    for layer_num in sorted(layer_groups.keys()):\n        layer_keys = layer_groups[layer_num]\n        tensors = {}\n        for key in layer_keys:\n            file_path = target_tensor_map[key]\n            tensor = None\n            ggml_type = None\n            if file_path.endswith('.safetensors'):\n                if file_path not in safetensors_cache:\n                    safetensors_cache[file_path] = safe_open(file_path, framework='pt')\n                f = safetensors_cache[file_path]\n                tensor = f.get_tensor(key)\n                tensor_info = tensor.shape\n            elif file_path.endswith('.gguf'):\n                gguf_name = translate_name(key)\n                tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name)\n                # tensor_info = gguf_loader.tensor_info[gguf_name]\n                # ggml_type = gguf_loader.tensor_info[gguf_name]['ggml_type']\n            else:\n                raise ValueError(f\"Unsupported file format: {file_path}\")\n            tensors[translate_name(key)] = tensor\n            if ggml_type:\n                ggml_type = torch.tensor(ggml_type)\n                ggml_key = translate_name(key)[:-7] + \".ggml_type\" if translate_name(key).endswith(\".weight\") else translate_name(key) + \".ggml_type\"\n                tensors[ggml_key] = ggml_type\n        \n        output_file = os.path.join(output_path, f\"model-{shard_idx:05}-of-{total_shards:05}.safetensors\")\n        print(f\"Saving layer {layer_num} to {output_file}\")\n        # print(tensors.keys())\n        save_file(tensors, output_file)\n        shard_idx += 1\n    \n    return\n    \ndef main():\n    # 创建命令行参数解析器\n    parser = argparse.ArgumentParser(description=\"Read parameters from Safetensor and GGUF files\")\n    
parser.add_argument(\"--safetensor_path\", type=str, help=\"Path to the Safetensor file\", default=\"/mnt/data/model/DeepSeek-V3\")\n    parser.add_argument(\"--gguf_path\", type=str, help=\"Path to the GGUF file\", default=\"/mnt/data/model/DeepseekV3-q4km-gguf\")\n    parser.add_argument(\"--output_path\", type=str, help=\"Path to the output file\", default=\"/mnt/data/model/ktrans-safetensors/DeepSeek-V3-q4km-fp8\")\n    \n    # print all the arguments\n    print(\"All the arguments:\")\n    print(parser.parse_args())\n    \n    # Parse the command-line arguments\n    args = parser.parse_args()\n\n    safetensor_path = args.safetensor_path\n    gguf_path = args.gguf_path\n    output_path = args.output_path\n    \n    target_tensor_map, gguf_loader = combine_tensor_sources(safetensor_path, gguf_path)\n    write_combined_tensor(target_tensor_map, output_path, gguf_loader)\n    \n    return\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "archive/merge_tensors/merge_safetensor_gguf_for_qwen3.py",
    "content": "# coding=utf-8\n# Copyright (c) 2025. Huawei Technologies Co., Ltd. All rights reserved.\n# Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport sys\nimport argparse\nimport torch\nfrom ktransformers.util.custom_loader import GGUFLoader, translate_name_to_gguf\nfrom safetensors import safe_open\nfrom safetensors.torch import save_file\nimport re\nfrom collections import defaultdict\n\ndef read_safetensor_keys_from_folder(folder_path) -> dict:\n    if not os.path.exists(folder_path):\n        raise FileNotFoundError(f\"Safetensors dir not found: {folder_path}\")\n    if os.path.isfile(folder_path):\n        folder_path = os.path.dirname(folder_path)\n\n    key_to_file_map = {}\n    found_safetensor = False\n\n    for root, dirs, files in os.walk(folder_path):\n        files = sorted(files)\n        for file in files:\n            if not file.endswith(\".safetensors\"):\n                continue\n            found_safetensor = True\n            file_path = os.path.join(root, file)\n            try:\n                with safe_open(file_path, framework=\"pt\") as f:\n                    for key in f.keys():\n                        key_to_file_map[key] = file_path\n            except Exception as e:\n                print(f\"Error reading Safetensor file {file_path}: {e}\")\n\n    if not found_safetensor:\n        raise FileNotFoundError(f\"No 
Safetensor files found in {folder_path}\")\n\n    return key_to_file_map\n\n\n# Optional: to take some non-MoE tensors from GGUF as well, add their key substrings to the list below\ntensor_from_gguf = []  # e.g. [\"self_attn.q_proj.weight\"]\n\n\ndef translate_name(name: str) -> str:\n    name = translate_name_to_gguf(name)\n    name = name.replace(\".up_proj.\", \".ffn_up_exps.\")\n    name = name.replace(\".down_proj.\", \".ffn_down_exps.\")\n    name = name.replace(\".gate_proj.\", \".ffn_gate_exps.\")\n    name = name.replace(\".ffn_gate_inp.e_score_correction_bias\", \".exp_probs_b.bias\")\n    return name\n\n\ndef combine_tensor_sources(safetensor_path: str, gguf_path: str):\n    gguf_loader = GGUFLoader(gguf_path)\n    gguf_tensor_file_map = gguf_loader.tensor_file_map\n    safetensor_tensor_file_map = read_safetensor_keys_from_folder(safetensor_path)\n\n    target_tensor_map = {}\n\n    for key, st_file in safetensor_tensor_file_map.items():\n        if \".mlp.experts.\" in key and key.endswith(\".weight\"):\n            parts = key.split(\".\")\n            if len(parts) < 8:\n                raise ValueError(f\"Unexpected MoE expert key format: {key}\")\n            norm_key = \".\".join(parts[:5] + parts[-2:])\n\n            gguf_name = translate_name(norm_key)\n            if gguf_name not in gguf_tensor_file_map:\n                raise KeyError(\n                    f\"[MoE] GGUF tensor not found for safetensors key {key} -> {gguf_name}\"\n                )\n            target_tensor_map[norm_key] = gguf_tensor_file_map[gguf_name]\n            continue\n        if any(tag in key for tag in tensor_from_gguf):\n            gguf_name = translate_name(key)\n            if gguf_name not in gguf_tensor_file_map:\n                raise KeyError(\n                    f\"[Non-MoE] GGUF tensor not found for safetensors key {key} -> {gguf_name}\"\n                )\n            target_tensor_map[key] = gguf_tensor_file_map[gguf_name]\n        else:\n            target_tensor_map[key] = st_file\n\n    return 
target_tensor_map, gguf_loader\n\n\ndef write_combined_tensor(target_tensor_map: dict, output_path: str, gguf_loader: GGUFLoader):\n    os.makedirs(output_path, exist_ok=True)\n\n    safetensors_cache = {}\n    layer_groups = defaultdict(list)\n    non_layer_keys = []\n    layer_pattern = re.compile(r\"\\.layers\\.(\\d+)\\.\")\n\n    for key in target_tensor_map:\n        m = layer_pattern.search(key)\n        if m:\n            layer_num = int(m.group(1))\n            layer_groups[layer_num].append(key)\n        else:\n            non_layer_keys.append(key)\n\n    total_shards = len(layer_groups) + (1 if non_layer_keys else 0) - 1\n    if total_shards <= 0:\n        raise ValueError(\"No tensors to save\")\n\n    shard_idx = 0\n\n    if non_layer_keys:\n        tensors = {}\n        for key in non_layer_keys:\n            file_path = target_tensor_map[key]\n            tensor = None\n            ggml_type = None\n\n            if file_path.endswith(\".safetensors\"):\n                if file_path not in safetensors_cache:\n                    safetensors_cache[file_path] = safe_open(file_path, framework=\"pt\")\n                f = safetensors_cache[file_path]\n                tensor = f.get_tensor(key)\n            elif file_path.endswith(\".gguf\"):\n                gguf_name = translate_name(key)\n                tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name)\n            else:\n                raise ValueError(f\"Unsupported file format: {file_path}\")\n\n            out_key = translate_name(key)\n            tensors[out_key] = tensor\n            if ggml_type is not None:\n                ggml_type = torch.tensor(ggml_type)\n                if out_key.endswith(\".weight\"):\n                    ggml_key = out_key[:-7] + \".ggml_type\"\n                else:\n                    ggml_key = out_key + \".ggml_type\"\n                tensors[ggml_key] = ggml_type\n\n        output_file = os.path.join(\n            output_path, 
f\"model-{shard_idx:05}-of-{total_shards:05}.safetensors\"\n        )\n        print(f\"[WRITE] Saving non-layer tensors to {output_file}\")\n        save_file(tensors, output_file)\n        shard_idx += 1\n\n    for layer_num in sorted(layer_groups.keys()):\n        layer_keys = layer_groups[layer_num]\n        tensors = {}\n\n        for key in layer_keys:\n            file_path = target_tensor_map[key]\n            tensor = None\n            ggml_type = None\n\n            if file_path.endswith(\".safetensors\"):\n                if file_path not in safetensors_cache:\n                    safetensors_cache[file_path] = safe_open(file_path, framework=\"pt\")\n                f = safetensors_cache[file_path]\n                tensor = f.get_tensor(key)\n            elif file_path.endswith(\".gguf\"):\n                gguf_name = translate_name(key)\n                tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name)\n            else:\n                raise ValueError(f\"Unsupported file format: {file_path}\")\n\n            out_key = translate_name(key)\n            tensors[out_key] = tensor\n            if ggml_type is not None:\n                ggml_type = torch.tensor(ggml_type)\n                if out_key.endswith(\".weight\"):\n                    ggml_key = out_key[:-7] + \".ggml_type\"\n                else:\n                    ggml_key = out_key + \".ggml_type\"\n                tensors[ggml_key] = ggml_type\n\n        output_file = os.path.join(\n            output_path, f\"model-{shard_idx:05}-of-{total_shards:05}.safetensors\"\n        )\n        print(f\"[WRITE] Saving layer {layer_num} to {output_file}\")\n        save_file(tensors, output_file)\n        shard_idx += 1\n\n\ndef main():\n    parser = argparse.ArgumentParser(\n        description=\"Merge FP8 safetensors and GGUF tensors for Qwen3-30B-A3B\"\n    )\n    parser.add_argument(\n        \"--safetensor_path\",\n        type=str,\n        help=\"Path to the FP8 
Safetensor folder\",\n        default=\"/mnt/data/model/Qwen3-30B-A3B-FP8\",\n    )\n    parser.add_argument(\n        \"--gguf_path\",\n        type=str,\n        help=\"Path to the GGUF file or folder\",\n        default=\"/mnt/data/model/Qwen3-30B-A3B-GGUF\",\n    )\n    parser.add_argument(\n        \"--output_path\",\n        type=str,\n        help=\"Path to the output safetensors folder\",\n        default=\"/mnt/data/model/ktrans-safetensors/Qwen3-30B-A3B-q4km-fp8\",\n    )\n\n    args = parser.parse_args()\n\n    print(\"[ARGS]\", args)\n\n    safetensor_path = args.safetensor_path\n    gguf_path = args.gguf_path\n    output_path = args.output_path\n\n    target_tensor_map, gguf_loader = combine_tensor_sources(safetensor_path, gguf_path)\n    write_combined_tensor(target_tensor_map, output_path, gguf_loader)\n\n\nif __name__ == \"__main__\":\n    main()\n\n"
  },
  {
    "path": "archive/pyproject.toml",
    "content": "[build-system]\nrequires = [\n  \"setuptools\",\n  \"torch >= 2.3.0\", \n  \"ninja\",\n  \"packaging\",\n  \"cpufeature\"\n  ]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\n\nname = \"ktransformers\"\n\ndynamic = [\"version\"]\n\ndependencies = [\n  \"torch >= 2.3.0\",\n  \"transformers\",\n  \"fastapi >= 0.111.0\",\n  \"uvicorn >= 0.30.1\",\n  \"langchain >= 0.2.0\",\n  \"blessed >= 1.20.0\",\n  \"accelerate >= 0.31.0\",\n  \"sentencepiece >= 0.1.97\",\n  \"setuptools\",\n  \"ninja\",\n  \"wheel\",\n  \"colorlog\",\n  \"build\",\n  \"fire\",\n  \"protobuf\",\n]\n\nrequires-python = \">=3.10\"\n\nauthors = [\n  {name = \"KVCache.AI\", email = \"zhang.mingxing@outlook.com\"}\n]\n\nmaintainers = [\n  {name = \"james0zan\", email = \"zhang.mingxing@outlook.com\"},\n  {name = \"awake\", email = \"awake@approaching.ai\"},\n  {name = \"unicorn chan\", email = \"nl@approaching.ai\"}\n]\n\ndescription = \"KTransformers, pronounced as Quick Transformers, is designed to enhance your Transformers experience with advanced kernel optimizations and placement/parallelism strategies.\"\n\nreadme = \"README.md\"\nlicense = {file = \"LICENSE\"}\n\nkeywords = [\"ktransformers\", \"llm\"]\n\nclassifiers = [\n  \"Development Status :: 4 - Beta\",\n  \"Programming Language :: Python :: 3.10\",\n  \"Programming Language :: Python :: 3.11\",\n  \"Programming Language :: Python :: 3.12\"\n]\n\n[project.urls]\nHomepage = \"https://kvcache.ai\"\nRepository = \"https://github.com/kvcache-ai/ktransformers.git\"\nIssues = \"https://github.com/kvcache-ai/ktransformers/issues\"\n\n\n[project.scripts]\nktransformers = \"ktransformers.server.main:main\"\n\n[tool.setuptools.packages.find]\nwhere = [\"./\", ]\ninclude = [\"ktransformers\",\"ktransformers.*\"]\n[tool.black]\nline-length = 120\npreview = true\nunstable = true\n"
  },
  {
    "path": "archive/requirements-local_chat.txt",
    "content": "fire\ntransformers\nnumpy\ntorch>=2.3.0\npackaging\ncpufeature; sys_platform == 'win32' or sys_platform == 'Windows'\nprotobuf\ntiktoken\nblobfile\n"
  },
  {
    "path": "archive/setup.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :\nAuthor       : chenxl\nDate         : 2024-07-27 16:15:27\nVersion      : 1.0.0\nLastEditors  : chenxl\nLastEditTime : 2024-08-14 16:36:19\nAdapted from:\nhttps://github.com/Dao-AILab/flash-attention/blob/v2.6.3/setup.py\nCopyright (c) 2023, Tri Dao.\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved.\n'''\n\nimport os\nimport sys\nimport re\nimport ast\nfrom collections import deque\nimport subprocess\nimport select\nimport time\nimport platform\nimport shutil\nfrom typing import List, Optional, Literal\nimport http.client\nimport urllib.request\nimport urllib.error\nfrom pathlib import Path\nfrom packaging.version import parse\nimport torch\nimport torch.version\nfrom wheel.bdist_wheel import bdist_wheel as _bdist_wheel\nfrom setuptools import setup, Extension\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME\ntry:\n    from torch_musa.utils.simple_porting import SimplePorting\n    from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME\nexcept ImportError:\n    MUSA_HOME=None\nKTRANSFORMERS_BUILD_XPU = torch.xpu.is_available()\n\n\ntry:\n    import torch_npu\n    KTRANSFORMERS_BUILD_NPU = torch_npu.npu.is_available()\nexcept:\n    KTRANSFORMERS_BUILD_NPU = False\n\n# 检测 DEV_BACKEND 环境变量\ndev_backend = os.environ.get(\"DEV_BACKEND\", \"\").lower()\nif dev_backend == \"xpu\":\n    triton_dep = [\n        \"pytorch-triton-xpu==3.3.0\"\n    ]\nelse:\n    triton_dep = [\"triton>=3.2\"]\n\nwith_balance = os.environ.get(\"USE_BALANCE_SERVE\", \"0\") == \"1\"\n\nclass CpuInstructInfo:\n    CPU_INSTRUCT = os.getenv(\"CPU_INSTRUCT\", \"NATIVE\")\n    FANCY = \"FANCY\"\n    AVX512 = \"AVX512\"\n    AVX2 = \"AVX2\"\n    CMAKE_NATIVE = \"-DLLAMA_NATIVE=ON\"\n    CMAKE_FANCY = \"-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON -DLLAMA_AVX512=ON -DLLAMA_AVX512_FANCY_SIMD=ON\"\n    CMAKE_AVX512 = 
\"-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON -DLLAMA_AVX512=ON\"\n    CMAKE_AVX2 = \"-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON\"\n\nclass VersionInfo:\n    THIS_DIR = os.path.dirname(os.path.abspath(__file__))\n    PACKAGE_NAME = \"ktransformers\"\n    BASE_WHEEL_URL:str = (\n        \"https://github.com/kvcache-ai/ktransformers/releases/download/{tag_name}/{wheel_filename}\"\n    )\n    FORCE_BUILD = os.getenv(\"KTRANSFORMERS_FORCE_BUILD\", \"FALSE\") == \"TRUE\"\n\n    def get_musa_bare_metal_version(self, musa_dir):\n        raw_output = subprocess.run(\n            [musa_dir + \"/bin/mcc\", \"-v\"], check=True,\n            stdout=subprocess.PIPE, stderr=subprocess.STDOUT).stdout.decode(\"utf-8\")\n        output = raw_output.split()\n        release_idx = output.index(\"version\") + 1\n        bare_metal_version = parse(output[release_idx].split(\",\")[0])\n        musa_version = f\"{bare_metal_version.major}{bare_metal_version.minor}\"\n        return musa_version\n\n    def get_rocm_bare_metal_version(self, rocm_dir):\n        \"\"\"\n        Get the ROCm version from the ROCm installation directory.\n\n        Args:\n            rocm_dir: Path to the ROCm installation directory\n\n        Returns:\n            A string representation of the ROCm version (e.g., \"63\" for ROCm 6.3)\n        \"\"\"\n        try:\n            # Try using rocm_agent_enumerator to get version info\n            raw_output = subprocess.check_output(\n                [rocm_dir + \"/bin/rocminfo\", \"--version\"],\n                universal_newlines=True,\n                stderr=subprocess.STDOUT)\n            # Extract version number from output\n            match = re.search(r'(\\d+\\.\\d+)', raw_output)\n            if match:\n                version_str = match.group(1)\n                version = parse(version_str)\n                rocm_version = f\"{version.major}{version.minor}\"\n                
return rocm_version\n        except (subprocess.CalledProcessError, FileNotFoundError):\n            # If rocminfo --version fails, try alternative methods\n            pass\n\n        try:\n            # Try reading version from release file\n            with open(os.path.join(rocm_dir, \"share/doc/hip/version.txt\"), \"r\") as f:\n                version_str = f.read().strip()\n                version = parse(version_str)\n                rocm_version = f\"{version.major}{version.minor}\"\n                return rocm_version\n        except (FileNotFoundError, IOError):\n            pass\n\n        # If all else fails, try to extract from directory name\n        dir_name = os.path.basename(os.path.normpath(rocm_dir))\n        match = re.search(r'rocm-(\\d+\\.\\d+)', dir_name)\n        if match:\n            version_str = match.group(1)\n            version = parse(version_str)\n            rocm_version = f\"{version.major}{version.minor}\"\n            return rocm_version\n\n        # Fallback to extracting from hipcc version\n        try:\n            raw_output = subprocess.check_output(\n                [rocm_dir + \"/bin/hipcc\", \"--version\"],\n                universal_newlines=True,\n                stderr=subprocess.STDOUT)\n            match = re.search(r'HIP version: (\\d+\\.\\d+)', raw_output)\n            if match:\n                version_str = match.group(1)\n                version = parse(version_str)\n                rocm_version = f\"{version.major}{version.minor}\"\n                return rocm_version\n        except (subprocess.CalledProcessError, FileNotFoundError):\n            pass\n\n        # If we still can't determine the version, raise an error\n        raise ValueError(f\"Could not determine ROCm version from directory: {rocm_dir}\")\n\n    def get_cuda_bare_metal_version(self, cuda_dir):\n        raw_output = subprocess.check_output(\n            [cuda_dir + \"/bin/nvcc\", \"-V\"], universal_newlines=True)\n        output = 
raw_output.split()\n        release_idx = output.index(\"release\") + 1\n        bare_metal_version = parse(output[release_idx].split(\",\")[0])\n        cuda_version = f\"{bare_metal_version.major}{bare_metal_version.minor}\"\n        return cuda_version\n\n    def get_cuda_version_of_torch(self):\n        if KTRANSFORMERS_BUILD_NPU:\n            return 'aarch64'\n        torch_cuda_version = parse(torch.version.cuda)\n        cuda_version = f\"{torch_cuda_version.major}{torch_cuda_version.minor}\"\n        return cuda_version\n\n    def get_platform(self,):\n        \"\"\"\n        Returns the platform name as used in wheel filenames.\n        \"\"\"\n        if sys.platform.startswith(\"linux\"):\n            return f'linux_{platform.uname().machine}'\n        elif sys.platform == \"win32\":\n            return \"win_amd64\"\n        else:\n            raise ValueError(\"Unsupported platform: {}\".format(sys.platform))\n\n    def get_cpu_instruct(self,):\n        if CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.FANCY:\n            return \"fancy\"\n        elif CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX512:\n            return \"avx512\"\n        elif CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX2:\n            return \"avx2\"\n        else:\n            print(\"Using native cpu instruct\")\n        if sys.platform.startswith(\"linux\"):\n            if KTRANSFORMERS_BUILD_NPU:\n                return 'aarch64'\n            with open('/proc/cpuinfo', 'r', encoding=\"utf-8\") as cpu_f:\n                cpuinfo = cpu_f.read()\n            flags_line = [line for line in cpuinfo.split(\n                '\\n') if line.startswith('flags')][0]\n            flags = flags_line.split(':')[1].strip().split(' ')\n            # fancy with AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI\n            for flag in flags:\n                if 'avx512bw' in flag:\n                    return 'fancy'\n            for flag in flags:\n                if 'avx512' in 
flag:\n                    return 'avx512'\n            for flag in flags:\n                if 'avx2' in flag:\n                    return 'avx2'\n            raise ValueError(\n                \"Unsupported cpu Instructions: {}\".format(flags_line))\n        elif sys.platform == \"win32\":\n            from cpufeature.extension import CPUFeature\n\n            if CPUFeature.get(\"AVX512bw\", False):\n                return 'fancy'\n            if CPUFeature.get(\"AVX512f\", False):\n                return 'avx512'\n            if CPUFeature.get(\"AVX2\", False):\n                return 'avx2'\n            raise ValueError(\n                \"Unsupported cpu Instructions: {}\".format(str(CPUFeature)))\n        else:\n            raise ValueError(\"Unsupported platform: {}\".format(sys.platform))\n\n    def get_torch_version(self,):\n        torch_version_raw = parse(torch.__version__)\n        torch_version = f\"{torch_version_raw.major}{torch_version_raw.minor}\"\n        return torch_version\n\n    def get_flash_version(self,):\n        version_file = os.path.join(\n            Path(VersionInfo.THIS_DIR), VersionInfo.PACKAGE_NAME, \"__init__.py\")\n        with open(version_file, \"r\", encoding=\"utf-8\") as f:\n            version_match = re.search(\n                r\"^__version__\\s*=\\s*(.*)$\", f.read(), re.MULTILINE)\n        flash_version = ast.literal_eval(version_match.group(1))\n        return flash_version\n\n    def get_package_version(self, full_version=False):\n        flash_version = str(self.get_flash_version())\n        torch_version = self.get_torch_version()\n        cpu_instruct = self.get_cpu_instruct()\n        backend_version = \"\"\n        if CUDA_HOME is not None:\n            backend_version = f\"cu{self.get_cuda_bare_metal_version(CUDA_HOME)}\"\n        elif MUSA_HOME is not None:\n            backend_version = f\"mu{self.get_musa_bare_metal_version(MUSA_HOME)}\"\n        elif ROCM_HOME is not None:\n            backend_version = 
f\"rocm{self.get_rocm_bare_metal_version(ROCM_HOME)}\"\n        elif torch.xpu.is_available():\n            backend_version = f\"xpu\"\n        elif KTRANSFORMERS_BUILD_NPU:\n            backend_version = f\"npu{torch_npu.__version__}\"\n        else:\n            raise ValueError(\"Unsupported backend: CUDA_HOME MUSA_HOME ROCM_HOME all not set and XPU is not available.\")\n        package_version = f\"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}\"\n        if full_version:\n            return package_version\n        if not VersionInfo.FORCE_BUILD:\n            return flash_version\n        return package_version\n\n\nclass BuildWheelsCommand(_bdist_wheel):\n    def get_wheel_name(self,):\n        version_info = VersionInfo()\n        package_version = version_info.get_package_version(full_version=True)\n        flash_version = version_info.get_flash_version()\n        python_version = f\"cp{sys.version_info.major}{sys.version_info.minor}\"\n        wheel_filename = f\"{VersionInfo.PACKAGE_NAME}-{package_version}-{python_version}-{python_version}-{version_info.get_platform()}.whl\"\n        wheel_url = VersionInfo.BASE_WHEEL_URL.format(tag_name=f\"v{flash_version}\", wheel_filename=wheel_filename)\n        return wheel_filename, wheel_url\n\n\n    def run(self):\n        if VersionInfo.FORCE_BUILD:\n            super().run()\n            return\n        wheel_filename, wheel_url = self.get_wheel_name()\n        print(\"Guessing wheel URL: \", wheel_url)\n        try:\n            urllib.request.urlretrieve(wheel_url, wheel_filename)\n            # Make the archive\n            # Lifted from the root wheel processing command\n            # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85\n            if not os.path.exists(self.dist_dir):\n                os.makedirs(self.dist_dir)\n\n            impl_tag, abi_tag, plat_tag = self.get_tag()\n            archive_basename = 
f\"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}\"\n\n            wheel_path = os.path.join(self.dist_dir, archive_basename + \".whl\")\n            print(\"Raw wheel path\", wheel_path)\n            shutil.move(wheel_filename, wheel_path)\n        except (urllib.error.HTTPError, urllib.error.URLError, http.client.RemoteDisconnected):\n            print(\"Precompiled wheel not found. Building from source...\")\n            # If the wheel could not be downloaded, build from source\n            super().run()\n\n\nANSI_ESCAPE = re.compile(\n    r'\\033[@-Z\\\\-_\\[\\]P]|\\033\\[[0-?]*[ -/]*[@-~]|\\033][^\\007\\033]*\\007|[\\000-\\037]'\n)\n\ndef colored(text, color=None, bold=False):\n    fmt = []\n    if color== 'red':\n        fmt.append('31')\n    elif color == 'green':\n        fmt.append('32')\n    if bold:\n        fmt.append('1')\n\n    return f\"\\033[{';'.join(fmt)}m{text}\\033[0m\"\n\n\ndef split_line(text: str) -> List[str]:\n    \"\"\"Split text into lines based on terminal width.\"\"\"\n    term_width = shutil.get_terminal_size().columns or 80\n    if not text.strip():\n        return []\n    # Split by explicit newlines and wrap long lines\n    lines = []\n    for line in text.split('\\n'):\n        while len(line) > term_width:\n            lines.append(line[:term_width])\n            line = line[term_width:]\n        if line:\n            lines.append(line)\n    return lines\n\n\n\nANSI_ESCAPE = re.compile(\n    r'\\033[@-Z\\\\-_\\[\\]P]|\\033\\[[0-?]*[ -/]*[@-~]|\\033][^\\007\\033]*\\007|[\\000-\\037]'\n)\n\ndef colored(text, color=None, bold=False):\n    fmt = []\n    if color== 'red':\n        fmt.append('31')\n    elif color == 'green':\n        fmt.append('32')\n    if bold:\n        fmt.append('1')\n\n    return f\"\\033[{';'.join(fmt)}m{text}\\033[0m\"\n\n\ndef split_line(text: str) -> List[str]:\n    \"\"\"Split text into lines based on terminal width.\"\"\"\n    term_width = shutil.get_terminal_size().columns or 80\n    if not 
text.strip():\n        return []\n    # Split by explicit newlines and wrap long lines\n    lines = []\n    for line in text.split('\\n'):\n        while len(line) > term_width:\n            lines.append(line[:term_width])\n            line = line[term_width:]\n        if line:\n            lines.append(line)\n    return lines\n\n\ndef run_command_with_live_tail(ext: str, command: List[str], output_lines: int = 20,\n                               refresh_rate: float = 0.1, cwd: Optional[str] = None):\n    \"\"\"\n    Execute a script-like command with real-time output of the last `output_lines` lines.\n\n    - during execution: displays the last `output_lines` lines of output in real-time.\n    - On success: Clears the displayed output.\n    - On failure: Prints the full command output.\n\n    Args:\n        ext (str): the name of the native extension currently building.\n        command (List[str]): The command to execute, as a list of arguments.\n        output_lines (int, optional): Number of terminal lines to display during live output. Defaults to 20.\n        refresh_rate (float, optional): Time in seconds between output refreshes. Defaults to 0.1.\n        cwd (Optional[str], optional): Working directory to run the command in. 
Defaults to current directory.\n    \"\"\"\n    # Dump all subprocess output without any buffering if stdout is not a terminal\n    if not sys.stdout.isatty():\n        return subprocess.run(command, cwd=cwd, check=True)\n    # Start time for elapsed time calculation\n    start = time.time()\n    # Buffer for all output\n    all_output = []\n    write_buffer = deque(maxlen=output_lines)\n    # Current number of lines from sub process displayed\n    current_lines = 0\n\n    # ANSI escape codes for terminal control\n    CLEAR_LINE = '\\033[K'\n    MOVE_UP = '\\033[1A'\n    SAVE_CURSOR = '\\0337'\n    RESTORE_CURSOR = '\\0338'\n    CLEAR_REMAINING = '\\033[J'\n\n    def write_progress(status: Literal['RUNNING', 'SUCCEED', 'FAILED'] = 'RUNNING',\n                       new_line: Optional[str] = None):\n        \"\"\"Update terminal display with latest output\"\"\"\n        nonlocal current_lines, process\n        sys.stdout.write(SAVE_CURSOR)\n        sys.stdout.write(MOVE_UP * current_lines)\n        banner = f\"ext={ext} pid={process.pid} status={status.upper()} elapsed=({time.time()-start:.2f}S)\\n\"\n        if status != 'FAILED':\n            banner = colored(banner, 'green', bold=True)\n        else:\n            banner = colored(banner, 'red', bold=True)\n        sys.stdout.write(CLEAR_LINE + banner)\n        if new_line is not None:\n            all_output.append(new_line)\n            write_buffer.extend(split_line(ANSI_ESCAPE.sub('', new_line).rstrip()))\n        elif status == 'RUNNING':\n            sys.stdout.write(RESTORE_CURSOR)\n            sys.stdout.flush()\n            return\n\n        sys.stdout.write(CLEAR_REMAINING)\n        if status == 'RUNNING':\n            current_lines = 1 + len(write_buffer)\n            for text in write_buffer:\n                sys.stdout.write(text + '\\n')\n        elif status == 'FAILED':\n            for text in all_output:\n                sys.stdout.write(text)\n        sys.stdout.flush()\n\n    # Start 
subprocess\n    sys.stdout.write(colored(f'ext={ext} command={\" \".join(str(c) for c in command)}\\n', bold=True))\n    sys.stdout.flush()\n    process = subprocess.Popen(\n        command,\n        stdout=subprocess.PIPE,\n        stderr=subprocess.STDOUT,\n        cwd=cwd,\n        text=True,\n        bufsize=1\n    )\n\n    try:\n        write_progress()\n        poll_obj = select.poll()\n        poll_obj.register(process.stdout, select.POLLIN)\n        while process.poll() is None:\n            poll_result = poll_obj.poll(refresh_rate * 1000)\n            if poll_result:\n                write_progress(new_line=process.stdout.readline())\n            else:\n                write_progress()\n\n        # Get any remaining output\n        while True:\n            line = process.stdout.readline()\n            if not line:\n                break\n            write_progress(new_line=line)\n    except BaseException as e:\n        process.terminate()\n        raise e\n    finally:\n        exit_code = process.wait()\n        write_progress(status='SUCCEED' if exit_code == 0 else 'FAILED')\n\n\n# Convert distutils Windows platform specifiers to CMake -A arguments\nPLAT_TO_CMAKE = {\n    \"win32\": \"Win32\",\n    \"win-amd64\": \"x64\",\n    \"win-arm32\": \"ARM\",\n    \"win-arm64\": \"ARM64\",\n}\n\n\nclass CMakeExtension(Extension):\n    def __init__(self, name: str, sourcedir: str) -> None:\n        super().__init__(name, sources=[])\n        print(name, sourcedir)\n        self.sourcedir = sourcedir\n\ndef get_cmake_abi_args(cmake_args):\n    if torch.compiled_with_cxx11_abi():\n        cmake_args.append(\"-D_GLIBCXX_USE_CXX11_ABI=1\")\n    else:\n        cmake_args.append(\"-D_GLIBCXX_USE_CXX11_ABI=0\")\n    return cmake_args\n\nclass CMakeBuild(BuildExtension):\n\n    def build_extension(self, ext) -> None:\n        if not isinstance(ext, CMakeExtension):\n            super().build_extension(ext)\n            return\n        ext_fullpath = Path.cwd() / 
self.get_ext_fullpath(ext.name)\n        extdir = ext_fullpath.parent.resolve()\n\n        # Using this requires trailing slash for auto-detection & inclusion of\n        # auxiliary \"native\" libs\n\n        debug = int(os.environ.get(\"DEBUG\", 0)\n                    ) if self.debug is None else self.debug\n        cfg = \"Debug\" if debug else \"Release\"\n\n        # CMake lets you override the generator - we need to check this.\n        # Can be set with Conda-Build, for example.\n        cmake_generator = os.environ.get(\"CMAKE_GENERATOR\", \"\")\n\n        # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON\n        # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code\n        # from Python.\n        cmake_args = [\n            f\"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}\",\n            f\"-DPYTHON_EXECUTABLE={sys.executable}\",\n            f\"-DCMAKE_BUILD_TYPE={cfg}\",  # not used on MSVC, but no harm\n        ]\n\n        if CUDA_HOME is not None:\n            cmake_args += [\"-DKTRANSFORMERS_USE_CUDA=ON\"]\n        elif MUSA_HOME is not None:\n            cmake_args += [\"-DKTRANSFORMERS_USE_MUSA=ON\"]\n        elif ROCM_HOME is not None:\n            cmake_args += [\"-DKTRANSFORMERS_USE_ROCM=ON\"]\n        elif KTRANSFORMERS_BUILD_XPU:\n            cmake_args += [\"-DKTRANSFORMERS_USE_XPU=ON\", \"-DKTRANSFORMERS_USE_CUDA=OFF\"]\n        elif KTRANSFORMERS_BUILD_NPU:\n            cmake_args += [\"-DKTRANSFORMERS_USE_NPU=ON\", \"-DKTRANSFORMERS_USE_CUDA=OFF\"]\n        else:\n            raise ValueError(\"Unsupported backend: CUDA_HOME, MUSA_HOME, and ROCM_HOME are not set and XPU is not available.\")\n        \n        cmake_args = get_cmake_abi_args(cmake_args)\n        # log cmake_args\n        print(\"CMake args:\", cmake_args)\n\n        build_args = []\n        if \"CMAKE_ARGS\" in os.environ:\n            cmake_args += [\n                item for item in os.environ[\"CMAKE_ARGS\"].split(\" \") if item]\n\n  
      if CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.FANCY:\n            cpu_args = CpuInstructInfo.CMAKE_FANCY\n        elif CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX512:\n            cpu_args = CpuInstructInfo.CMAKE_AVX512\n        elif CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX2:\n            cpu_args = CpuInstructInfo.CMAKE_AVX2\n        else:\n            cpu_args = CpuInstructInfo.CMAKE_NATIVE\n\n        cmake_args += [\n            item for item in cpu_args.split(\" \") if item\n        ]\n        # In this example, we pass in the version to C++. You might not need to.\n        cmake_args += [\n            f\"-DEXAMPLE_VERSION_INFO={self.distribution.get_version()}\"]\n        if self.compiler.compiler_type != \"msvc\":\n            if not cmake_generator or cmake_generator == \"Ninja\":\n                pass\n                # try:\n                #     import ninja\n\n                #     ninja_executable_path = Path(ninja.BIN_DIR) / \"ninja\"\n                #     cmake_args += [\n                #         \"-GNinja\",\n                #         f\"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}\",\n                #     ]\n                # except ImportError:\n                #     pass\n\n        else:\n            # Single config generators are handled \"normally\"\n            single_config = any(\n                x in cmake_generator for x in {\"NMake\", \"Ninja\"})\n\n            # CMake allows an arch-in-generator style for backward compatibility\n            contains_arch = any(x in cmake_generator for x in {\"ARM\", \"Win64\"})\n            if not single_config and not contains_arch and cmake_generator:\n                cmake_args += [\"-A\", PLAT_TO_CMAKE[self.plat_name]]\n\n            # Multi-config generators have a different way to specify configs\n            if not single_config:\n                cmake_args += [\n                    f\"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}\"\n              
  ]\n                build_args += [\"--config\", cfg]\n\n        if sys.platform.startswith(\"darwin\"):\n            # Cross-compile support for macOS - respect ARCHFLAGS if set\n            archs = re.findall(r\"-arch (\\S+)\", os.environ.get(\"ARCHFLAGS\", \"\"))\n            if archs:\n                cmake_args += [\n                    \"-DCMAKE_OSX_ARCHITECTURES={}\".format(\";\".join(archs))]\n\n        if \"CMAKE_BUILD_PARALLEL_LEVEL\" not in os.environ:\n            cpu_count = os.cpu_count()\n            if cpu_count is None:\n                cpu_count = 1\n            if hasattr(self, \"parallel\") and self.parallel:\n                build_args += [f\"--parallel={self.parallel}\"]\n            else:\n                build_args += [f\"--parallel={cpu_count}\"]\n        print(\"CMake args:\", cmake_args)\n        build_temp = Path(ext.sourcedir) / \"build\"\n        print(\"build_temp:\", build_temp)\n\n        if not build_temp.exists():\n            build_temp.mkdir(parents=True)\n        run_command_with_live_tail(ext.name,\n            [\"cmake\", ext.sourcedir, *cmake_args], cwd=build_temp\n        )\n        run_command_with_live_tail(ext.name,\n            [\"cmake\", \"--build\", build_temp, \"--verbose\", *build_args], cwd=build_temp\n        )\n\nif CUDA_HOME is not None or ROCM_HOME is not None:\n    ops_module = CUDAExtension('KTransformersOps', [\n        'csrc/ktransformers_ext/cuda/custom_gguf/dequant.cu',\n        'csrc/ktransformers_ext/cuda/binding.cpp',\n        'csrc/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu'\n    ],\n    extra_compile_args={\n            'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'],\n            'nvcc': [\n                '-O3',\n                # '--use_fast_math',\n                '-Xcompiler', '-fPIC',\n                '-DKTRANSFORMERS_USE_CUDA',\n            ]\n        }\n    )\nelif MUSA_HOME is not None:\n    SimplePorting(cuda_dir_path=\"csrc/ktransformers_ext/cuda\", mapping_rule={\n        # Common 
rules\n        \"at::cuda\": \"at::musa\",\n        \"#include <ATen/cuda/CUDAContext.h>\": \"#include \\\"torch_musa/csrc/aten/musa/MUSAContext.h\\\"\",\n        \"#include <c10/cuda/CUDAGuard.h>\": \"#include \\\"torch_musa/csrc/core/MUSAGuard.h\\\"\",\n        \"nv_bfloat16\": \"mt_bfloat16\",\n        }).run()\n    ops_module = MUSAExtension('KTransformersOps', [\n        'csrc/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu',\n        'csrc/ktransformers_ext/cuda_musa/binding.cpp',\n        # TODO: Add Marlin support for MUSA.\n        # 'csrc/ktransformers_ext/cuda_musa/gptq_marlin/gptq_marlin.mu'\n    ],\n    extra_compile_args={\n            'cxx': ['force_mcc'],\n            'mcc': [\n                '-O3',\n                '-DKTRANSFORMERS_USE_MUSA',\n                '-DTHRUST_IGNORE_CUB_VERSION_CHECK',\n            ]\n        }\n    )\nelif torch.xpu.is_available(): #XPUExtension is not available now.\n    ops_module = None\nelif KTRANSFORMERS_BUILD_NPU:\n    pass\nelse:\n    raise ValueError(\"Unsupported backend: CUDA_HOME ROCM_HOME MUSA_HOME are not set and XPU is not available.\")\n\nif not torch.xpu.is_available() and not KTRANSFORMERS_BUILD_NPU:\n    ext_modules = [\n        CMakeExtension(\"cpuinfer_ext\", os.fspath(Path(\"\").resolve() / \"csrc\" / \"ktransformers_ext\")),\n        ops_module,\n        CUDAExtension(\n            'vLLMMarlin', [\n                'csrc/custom_marlin/binding.cpp',\n                'csrc/custom_marlin/gptq_marlin/gptq_marlin.cu',\n                'csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu',\n            ],\n            extra_compile_args={\n                'cxx': ['-O3'],\n                'nvcc': ['-O3', '-Xcompiler', '-fPIC'],\n            },\n        )\n    ]\n    if with_balance:\n        print(\"using balance_serve\")\n        ext_modules.append(\n            CMakeExtension(\"balance_serve\", os.fspath(Path(\"\").resolve()/ \"csrc\"/ \"balance_serve\"))\n        )\n\n    setup(\n        
name=VersionInfo.PACKAGE_NAME,\n        version=VersionInfo().get_package_version(),\n        install_requires=triton_dep,\n        cmdclass={\"bdist_wheel\":BuildWheelsCommand ,\"build_ext\": CMakeBuild},\n        ext_modules=ext_modules\n    )\n\n\n\nelif torch.xpu.is_available():\n    ext_modules = [\n        CMakeExtension(\"cpuinfer_ext\", os.fspath(Path(\"\").resolve() / \"csrc\" / \"ktransformers_ext\")),\n    ]\n    setup(\n        name=VersionInfo.PACKAGE_NAME,\n        version=VersionInfo().get_package_version(),\n        install_requires=triton_dep,\n        cmdclass={\"bdist_wheel\":BuildWheelsCommand ,\"build_ext\": CMakeBuild},\n        ext_modules=ext_modules\n    )\n\nelif KTRANSFORMERS_BUILD_NPU:\n    ext_modules = [\n        CMakeExtension(\"cpuinfer_ext\", os.fspath(Path(\"\").resolve() / \"csrc\" / \"ktransformers_ext\")),\n    ] \n    if with_balance:\n        print(\"using balance_serve\")\n        ext_modules.append(\n            CMakeExtension(\"balance_serve\", os.fspath(Path(\"\").resolve()/ \"csrc\"/ \"balance_serve\"))\n        )\n\n    setup(\n        name=VersionInfo.PACKAGE_NAME,\n        version=VersionInfo().get_package_version(),\n        cmdclass={\"bdist_wheel\":BuildWheelsCommand ,\"build_ext\": CMakeBuild},\n        ext_modules=ext_modules\n    )\n"
  },
  {
    "path": "archive/third_party/llamafile/README.md",
    "content": "The code in this folder is copied from [Mozilla-Ocho/llamafile](https://github.com/Mozilla-Ocho/llamafile). Special thanks to the Mozilla-Ocho team.\n"
  },
  {
    "path": "archive/third_party/llamafile/bench.h",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/bench.h\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-\n// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi\n#pragma once\n\n#include <stdio.h>\n\n#include \"micros.h\"\n\n#define BENCH(x)                                                                       \\\n    do {                                                                               \\\n        x;                                                                             \\\n        __asm__ volatile(\"\" ::: \"memory\");                                             \\\n        long long start = micros();                                                    \\\n        for (int i = 0; i < ITERATIONS; ++i) {                                         \\\n            __asm__ volatile(\"\" ::: \"memory\");                                         \\\n            x;                                                                         \\\n            __asm__ volatile(\"\" ::: \"memory\");                                         \\\n        }                                                                              \\\n        printf(\"%9lld us %s\\n\", (micros() - start + ITERATIONS - 1) / ITERATIONS, #x); \\\n    } while (0)\n"
  },
  {
    "path": "archive/third_party/llamafile/flags.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/flags.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#include \"flags.h\"\n\nbool FLAG_precise = false;\n"
  },
  {
    "path": "archive/third_party/llamafile/flags.h",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/flags.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#pragma once\n\nextern bool FLAG_precise;\n"
  },
  {
    "path": "archive/third_party/llamafile/iqk_mul_mat.inc",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/iqk_mul_mat.inc\n// Copyrigth 2024 Iwan Kawrakow - Apache 2.0 Licens\n// with additions from\n// https://github.com/ikawrakow/ik_llama.cpp/blob/main/ggml/src/iqk/iqk_mul_mat.cpp\n// Copyrigth 2024-2025 Iwan Kawrakow - MIT Licens\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-\n// vi: set et ft=cpp fenc=utf-8 :vi\n//\n// Copyright 2024 Iwan Kawrakow\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n//\n//\n// Copyright (C) 2024-2025 Iwan Kawrakow\n// MIT license\n// SPDX-License-Identifier: MIT\n//\n\n#if defined(KTRANSFORMERS_USE_NPU) && KTRANSFORMERS_USE_NPU\n        // use ARM version\n        #include \"iqk_mul_mat_arm.inc\"\n#else\n        // use x86 version\n        #include \"iqk_mul_mat_x86.inc\"\n#endif"
  },
  {
    "path": "archive/third_party/llamafile/iqk_mul_mat_amd_avx2.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/iqk_mul_mat_amd_avx2.cpp\n// Copyrigth 2024 Iwan Kawrakow.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#if defined(__x86_64__) || defined(_M_X64)\n#include \"iqk_mul_mat.inc\"\n#endif  // __x86_64__\n"
  },
  {
    "path": "archive/third_party/llamafile/iqk_mul_mat_amd_zen4.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/iqk_mul_mat_amd_zen4.cpp\n// Copyrigth 2024 Iwan Kawrakow.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#if defined(__x86_64__) || defined(_M_X64)\n#define iqk_mul_mat iqk_mul_mat_zen4\n#define iqk_mul_mat_moe iqk_mul_mat_moe_zen4\n#include \"iqk_mul_mat.inc\"\n#endif  // __x86_64__\n"
  },
  {
    "path": "archive/third_party/llamafile/iqk_mul_mat_arm.inc",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/iqk_mul_mat.inc\n// Copyrigth 2024 Iwan Kawrakow.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-\n// vi: set et ft=cpp fenc=utf-8 :vi\n//\n// Copyright 2024 Iwan Kawrakow\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include <cstring>\n#include <type_traits>\n#if defined __x86_64__ || defined __aarch64__ || defined(_M_X64)\n\n#include \"llama.cpp/ggml-impl.h\"\n#include \"llama.cpp/ggml-quants.h\"\n#include \"sgemm.h\"\n\n// For i-quants, I had to explicitely specify which\n// functions to inline / not inline (at least for some\n// of the functions), else performance would be significantly\n// lower. 
This is worrisome as things can change with,\n// e.g., a different compiler version or running on a different\n// CPU.\n#ifdef _MSC_VER\n#define IQK_NOINLINE __declspec(noinline)\n#define IQK_ALWAYS_INLINE inline\n#else\n#define IQK_NOINLINE __attribute__((__noinline__))\n#define IQK_ALWAYS_INLINE __attribute__((always_inline))\n#endif\n\n#define GGML_COMMON_IMPL_C\n#include \"llama.cpp/ggml-common.h\"\n\n// clang-format off\n\n// This matrix - vector and matrix - matrix multiplication implementation\n// for legacy quants, k-quants and i-quants makes prompt processing 150-200%\n// (legacy and k-quants) or 250-400% (i-quants) faster\n// compared to mainline llama.cpp (and llamafile).\n// It provides implementations for ARM_NEON (all quants) and AVX2\n// (all quants except sub-4 bit i-quants).\n//\n// Main idea is that unpacking the quants and the block scales to\n// be ready for dot products with the corresponding Q8_Y quants\n// takes time (here 'Y' stands for K, 0, or 1, depending on quantization type).\n// Hence, if we are performing a QX x Q8_Y matrix matrix\n// multiplication (as needed for prompt processing), we can get\n// a significant speedup by reusing the unpacked QX quants and scales\n// for multiplication with several Q8_K columns. 
We also achieve fewer\n// loads from memory, which is the main purpose of tiling in general\n// purpose matrix multiplication packages.\n\n#include <utility>\n#include <array>\n\n#endif\n\nconstexpr ggml_type GGML_TYPE_Q8_0_X4 = static_cast<ggml_type>(98);\nconstexpr ggml_type GGML_TYPE_Q8_1_X4 = static_cast<ggml_type>(99);\n\n\nnamespace {\n#define GEMV_Q4K\n#define GEMV_Q6K\n#define GEMM_Q4K_Q6K\n\ntypedef struct {\n    int32_t i1;\n    int32_t i2;\n} mmid_row_mapping;\n\nstruct DataInfo {\n    float       * s;\n    const char  * cy;\n    size_t        bs;\n    size_t        by;\n    int           cur_y = 0;\n    int           ne11;\n    const mmid_row_mapping * row_mapping = nullptr;\n    size_t        bs2 = 0;\n\n    inline const char * src1_row(int iy) const {\n        if (!row_mapping) return cy + (cur_y + iy)*by;\n        int i11 = row_mapping[cur_y + iy].i1 % ne11;\n        int i12 = row_mapping[cur_y + iy].i2;\n        return cy + (i11 + i12*ne11)*by;\n    }\n\n    inline void store(int ix, int iy, float result) const {\n        *(dst_row(iy) + ix) = result;\n        //dst_row(iy)[ix] = result;\n    }\n    inline float* ptr(int ix, int iy) const {\n        return dst_row(iy) + ix;\n    }\n    inline float * dst_row(int iy) const {\n        if (!row_mapping) return s + (cur_y + iy)*bs;\n        int i12 = row_mapping[cur_y + iy].i2;\n        int i1  = row_mapping[cur_y + iy].i1;\n        int i2  = i12;\n        return s + i1*bs + i2*bs2;\n    }\n};\n\n/*\nmoonll \nchange param for set_mul_mat \nadd func16\n*/\n\ntypedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x);\ntypedef void (*mul_mat_t_v2)(int m, int n, int k, const void *vx, size_t bx, const DataInfo& info);\n\nstruct MulMat {\n    std::array<mul_mat_t, 8> funcs = {};\n    mul_mat_t func16 = nullptr;\n    mul_mat_t_v2 funcs_v2;\n    //inline void mul_mat_NxM(int n, const void * vx, size_t bx, DataInfo& info, int nrc_x, int nrc_y) {\n    IQK_NOINLINE void 
mul_mat_NxM(int n, const void * vx, size_t bx, DataInfo& info, int nrc_x, int nrc_y) {\n        constexpr int k_x_step = 64; // This works best on my Ryzen-7950X and M2 Max CPUs (but differences from other tile sizes are small)\n\n        if (func16 && nrc_y >= 16) {\n            int n_step = (nrc_y - info.cur_y)/16;\n            for (int ix = 0; ix < nrc_x; ix += k_x_step) {\n                auto this_info = info;\n                this_info.s += ix;\n                int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;\n                for (int iy = 0; iy < n_step; ++iy) {\n                    func16(n, (const void *)((const char *)vx + ix*bx), bx, this_info, this_nrc_x);\n                    this_info.cur_y += 16;\n                }\n            }\n            info.cur_y += 16 * n_step;\n            if (info.cur_y == nrc_y) return;\n        }\n\n        int n_step = (nrc_y - info.cur_y)/funcs.size();\n        if (n_step > 0) {\n            for (int ix = 0; ix < nrc_x; ix += k_x_step) {\n                auto this_info = info;\n                this_info.s += ix;\n                int this_nrc_x = ix + k_x_step <= nrc_x ? 
k_x_step : nrc_x - ix;\n                for (int iy = 0; iy < n_step; ++iy) {\n                    funcs.back()(n, (const void *)((const char *)vx + ix*bx), bx, this_info, this_nrc_x);\n                    this_info.cur_y += funcs.size();\n                }\n            }\n            info.cur_y += funcs.size() * n_step;\n        }\n        int n_left = nrc_y - info.cur_y;\n        if (n_left > 0) {\n            funcs[n_left-1](n, vx, bx, info, nrc_x);\n        }\n    }\n#if defined __x86_64__ || defined(_M_X64)\n    static IQK_NOINLINE bool set_mul_mat(int typeA, int typeB,int ne00, MulMat& mm, int Ny);\n#else\n    IQK_NOINLINE void mul_mat_NxM_v2(int n, const void * vx, size_t bx, DataInfo& info, int nrc_x, int nrc_y) {\n        funcs_v2(nrc_x, nrc_y, n, vx, bx, info);\n        return;\n    }\n    static IQK_NOINLINE bool set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int Ny);\n#endif\nprivate:\n    template <typename Dequantizer> static IQK_NOINLINE void set_functions(MulMat& m);\n};\n\ninline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) {\n    const uint16_t * scales = (const uint16_t *)scales8;\n    const uint32_t a0 = scales[0] | (scales[1] << 16);\n    const uint32_t a1 = scales[2] | (scales[3] << 16);\n    const uint32_t a2 = scales[4] | (scales[5] << 16);\n    aux32[3] = ((a2 >> 4) & 0x0f0f0f0f) | ((a1 >> 2) & 0x30303030);\n    aux32[1] = ((a2 >> 0) & 0x0f0f0f0f) | ((a0 >> 2) & 0x30303030);\n    aux32[2] = a1 & 0x3f3f3f3f;\n    aux32[0] = a0 & 0x3f3f3f3f;\n}\n\n/*\nmoonll\ndecoding tables\n*/\n#ifdef __AVX2__\nstatic const uint64_t iq1s_grid_us[2048] = {\n    0x0000000000000000, 0x0000000000000002, 0x0000000000000101, 0x0000000000000200,\n    0x0000000000000202, 0x0000000000010001, 0x0000000000010101, 0x0000000000020000,\n    0x0000000000020002, 0x0000000000020200, 0x0000000000020202, 0x0000000001000101,\n    0x0000000001010001, 0x0000000001010100, 0x0000000001010102, 0x0000000001020101,\n    0x0000000002000000, 
0x0000000002000002, 0x0000000002000200, 0x0000000002000202,\n    0x0000000002010101, 0x0000000002020000, 0x0000000002020002, 0x0000000002020200,\n    0x0000000002020202, 0x0000000100000100, 0x0000000100000101, 0x0000000100010001,\n    0x0000000100010100, 0x0000000100010102, 0x0000000100010201, 0x0000000100010202,\n    0x0000000100020101, 0x0000000101000001, 0x0000000101000102, 0x0000000101000201,\n    0x0000000101010002, 0x0000000101010101, 0x0000000101010202, 0x0000000101020001,\n    0x0000000101020100, 0x0000000101020102, 0x0000000101020200, 0x0000000102000101,\n    0x0000000102010001, 0x0000000102010100, 0x0000000102010102, 0x0000000102020101,\n    0x0000000200000000, 0x0000000200000002, 0x0000000200000200, 0x0000000200000202,\n    0x0000000200010101, 0x0000000200020000, 0x0000000200020002, 0x0000000200020200,\n    0x0000000200020202, 0x0000000201000101, 0x0000000201010001, 0x0000000201010201,\n    0x0000000201020100, 0x0000000201020201, 0x0000000202000000, 0x0000000202000002,\n    0x0000000202000200, 0x0000000202000202, 0x0000000202010001, 0x0000000202010101,\n    0x0000000202010201, 0x0000000202020000, 0x0000000202020002, 0x0000000202020200,\n    0x0000000202020202, 0x0000010000010001, 0x0000010000010100, 0x0000010000010102,\n    0x0000010000020101, 0x0000010001000001, 0x0000010001000201, 0x0000010001010101,\n    0x0000010001010202, 0x0000010001020100, 0x0000010001020101, 0x0000010002010001,\n    0x0000010002010201, 0x0000010002020101, 0x0000010100000001, 0x0000010100000100,\n    0x0000010100000101, 0x0000010100000102, 0x0000010100010101, 0x0000010100010200,\n    0x0000010100010202, 0x0000010100020201, 0x0000010101000000, 0x0000010101000101,\n    0x0000010101000202, 0x0000010101010000, 0x0000010101010001, 0x0000010101010100,\n    0x0000010101010101, 0x0000010101010102, 0x0000010101010201, 0x0000010101020000,\n    0x0000010101020002, 0x0000010101020101, 0x0000010101020200, 0x0000010101020202,\n    0x0000010102000001, 0x0000010102010001, 0x0000010102010101, 
0x0000010102010200,\n    0x0000010102010202, 0x0000010102020001, 0x0000010102020100, 0x0000010102020101,\n    0x0000010102020102, 0x0000010102020201, 0x0000010200010100, 0x0000010200010201,\n    0x0000010201000001, 0x0000010201000100, 0x0000010201010000, 0x0000010201010002,\n    0x0000010201010101, 0x0000010201010200, 0x0000010201020000, 0x0000010201020001,\n    0x0000010201020102, 0x0000010201020201, 0x0000010202000101, 0x0000010202010001,\n    0x0000010202010100, 0x0000010202010201, 0x0000020000000000, 0x0000020000000002,\n    0x0000020000000200, 0x0000020000000202, 0x0000020000010101, 0x0000020000020000,\n    0x0000020000020002, 0x0000020000020200, 0x0000020000020202, 0x0000020001000101,\n    0x0000020001010001, 0x0000020001010102, 0x0000020001020101, 0x0000020002000000,\n    0x0000020002000002, 0x0000020002000200, 0x0000020002000202, 0x0000020002010101,\n    0x0000020002020000, 0x0000020002020002, 0x0000020002020200, 0x0000020002020202,\n    0x0000020100000101, 0x0000020100010001, 0x0000020100010100, 0x0000020100010201,\n    0x0000020100020100, 0x0000020100020101, 0x0000020101000001, 0x0000020101010000,\n    0x0000020101010001, 0x0000020101010101, 0x0000020101020001, 0x0000020101020100,\n    0x0000020101020201, 0x0000020102010001, 0x0000020102010100, 0x0000020102010102,\n    0x0000020102010201, 0x0000020102020101, 0x0000020200000000, 0x0000020200000002,\n    0x0000020200000200, 0x0000020200000202, 0x0000020200010101, 0x0000020200020000,\n    0x0000020200020002, 0x0000020200020200, 0x0000020200020202, 0x0000020201000101,\n    0x0000020201010001, 0x0000020201010201, 0x0000020201020001, 0x0000020201020101,\n    0x0000020202000000, 0x0000020202000002, 0x0000020202000101, 0x0000020202000200,\n    0x0000020202000202, 0x0000020202010101, 0x0000020202020000, 0x0000020202020002,\n    0x0000020202020200, 0x0000020202020202, 0x0001000000010000, 0x0001000000010001,\n    0x0001000000010100, 0x0001000000010201, 0x0001000000020100, 0x0001000000020101,\n    0x0001000001000001, 
0x0001000001000100, 0x0001000001010000, 0x0001000001010101,\n    0x0001000001010200, 0x0001000001020001, 0x0001000001020100, 0x0001000001020101,\n    0x0001000001020201, 0x0001000002010001, 0x0001000002010100, 0x0001000002010102,\n    0x0001000002020001, 0x0001000002020101, 0x0001000100000001, 0x0001000100000100,\n    0x0001000100000102, 0x0001000100000201, 0x0001000100010000, 0x0001000100010002,\n    0x0001000100010101, 0x0001000100010200, 0x0001000100020001, 0x0001000100020100,\n    0x0001000100020201, 0x0001000101000101, 0x0001000101000202, 0x0001000101010000,\n    0x0001000101010001, 0x0001000101010002, 0x0001000101010100, 0x0001000101010101,\n    0x0001000101010102, 0x0001000101010201, 0x0001000101020000, 0x0001000101020101,\n    0x0001000102000100, 0x0001000102010002, 0x0001000102010101, 0x0001000102020001,\n    0x0001000102020100, 0x0001000200010001, 0x0001000200010100, 0x0001000200010102,\n    0x0001000200020101, 0x0001000201000000, 0x0001000201000102, 0x0001000201000201,\n    0x0001000201010002, 0x0001000201010101, 0x0001000201010200, 0x0001000201010202,\n    0x0001000201020100, 0x0001000201020102, 0x0001000202000101, 0x0001000202010001,\n    0x0001000202010100, 0x0001000202010102, 0x0001000202020101, 0x0001010000000001,\n    0x0001010000000102, 0x0001010000000201, 0x0001010000010100, 0x0001010000010101,\n    0x0001010000010200, 0x0001010000010201, 0x0001010000020001, 0x0001010000020102,\n    0x0001010001000001, 0x0001010001000101, 0x0001010001000102, 0x0001010001000200,\n    0x0001010001000202, 0x0001010001010001, 0x0001010001010100, 0x0001010001010101,\n    0x0001010001010102, 0x0001010001010201, 0x0001010001020002, 0x0001010001020101,\n    0x0001010001020200, 0x0001010002000100, 0x0001010002000201, 0x0001010002010000,\n    0x0001010002010100, 0x0001010002010101, 0x0001010002010200, 0x0001010002010201,\n    0x0001010002010202, 0x0001010002020001, 0x0001010002020100, 0x0001010002020101,\n    0x0001010002020201, 0x0001010100000002, 0x0001010100000101, 
0x0001010100000202,\n    0x0001010100010001, 0x0001010100010100, 0x0001010100010101, 0x0001010100010102,\n    0x0001010100010201, 0x0001010100020000, 0x0001010100020002, 0x0001010100020101,\n    0x0001010100020200, 0x0001010100020202, 0x0001010101000001, 0x0001010101000100,\n    0x0001010101000101, 0x0001010101000102, 0x0001010101010001, 0x0001010101010002,\n    0x0001010101010100, 0x0001010101010101, 0x0001010101010102, 0x0001010101010201,\n    0x0001010101010202, 0x0001010101020001, 0x0001010101020100, 0x0001010101020101,\n    0x0001010101020102, 0x0001010101020201, 0x0001010102000000, 0x0001010102000002,\n    0x0001010102000100, 0x0001010102000101, 0x0001010102000200, 0x0001010102000202,\n    0x0001010102010000, 0x0001010102010001, 0x0001010102010100, 0x0001010102010101,\n    0x0001010102010102, 0x0001010102010201, 0x0001010102010202, 0x0001010102020000,\n    0x0001010102020002, 0x0001010102020101, 0x0001010200000001, 0x0001010200000100,\n    0x0001010200000101, 0x0001010200000102, 0x0001010200010101, 0x0001010200010102,\n    0x0001010200010200, 0x0001010200010202, 0x0001010200020001, 0x0001010200020102,\n    0x0001010201000000, 0x0001010201000002, 0x0001010201000100, 0x0001010201000101,\n    0x0001010201000200, 0x0001010201000202, 0x0001010201010001, 0x0001010201010101,\n    0x0001010201010102, 0x0001010201010200, 0x0001010201010201, 0x0001010201020001,\n    0x0001010201020100, 0x0001010201020101, 0x0001010201020200, 0x0001010201020201,\n    0x0001010201020202, 0x0001010202000102, 0x0001010202000202, 0x0001010202010002,\n    0x0001010202010101, 0x0001010202020100, 0x0001010202020201, 0x0001020000010001,\n    0x0001020000010102, 0x0001020000020101, 0x0001020001000001, 0x0001020001000100,\n    0x0001020001000102, 0x0001020001000201, 0x0001020001010000, 0x0001020001010101,\n    0x0001020001010200, 0x0001020001010202, 0x0001020001020000, 0x0001020001020001,\n    0x0001020001020100, 0x0001020001020102, 0x0001020001020201, 0x0001020002000101,\n    0x0001020002010001, 
0x0001020002010100, 0x0001020002020101, 0x0001020100010000,\n    0x0001020100010002, 0x0001020100010101, 0x0001020100010202, 0x0001020100020001,\n    0x0001020100020101, 0x0001020101000002, 0x0001020101000100, 0x0001020101000101,\n    0x0001020101000200, 0x0001020101010001, 0x0001020101010100, 0x0001020101010101,\n    0x0001020101010102, 0x0001020101010201, 0x0001020101010202, 0x0001020101020000,\n    0x0001020101020101, 0x0001020101020202, 0x0001020102000201, 0x0001020102010001,\n    0x0001020102010002, 0x0001020102010101, 0x0001020102010200, 0x0001020102020001,\n    0x0001020102020102, 0x0001020102020201, 0x0001020200000201, 0x0001020200010102,\n    0x0001020200020100, 0x0001020200020102, 0x0001020201000100, 0x0001020201000102,\n    0x0001020201000201, 0x0001020201010000, 0x0001020201010002, 0x0001020201010101,\n    0x0001020201010200, 0x0001020201020001, 0x0001020201020102, 0x0001020201020201,\n    0x0001020202000101, 0x0001020202010001, 0x0001020202010102, 0x0001020202010202,\n    0x0002000000000000, 0x0002000000000002, 0x0002000000000200, 0x0002000000000202,\n    0x0002000000010101, 0x0002000000020000, 0x0002000000020002, 0x0002000000020101,\n    0x0002000000020200, 0x0002000000020202, 0x0002000001000101, 0x0002000001010001,\n    0x0002000001010201, 0x0002000001020001, 0x0002000001020101, 0x0002000002000000,\n    0x0002000002000002, 0x0002000002000200, 0x0002000002000202, 0x0002000002010101,\n    0x0002000002020000, 0x0002000002020002, 0x0002000002020101, 0x0002000002020200,\n    0x0002000002020202, 0x0002000100000101, 0x0002000100010001, 0x0002000100010100,\n    0x0002000100010201, 0x0002000100020101, 0x0002000101000002, 0x0002000101000100,\n    0x0002000101000201, 0x0002000101010101, 0x0002000101010200, 0x0002000101010202,\n    0x0002000101020001, 0x0002000101020100, 0x0002000101020101, 0x0002000101020102,\n    0x0002000102000101, 0x0002000102010000, 0x0002000102010102, 0x0002000102010201,\n    0x0002000102020101, 0x0002000200000001, 0x0002000200000200, 
0x0002000200000202,\n    0x0002000200010001, 0x0002000200010101, 0x0002000200020000, 0x0002000200020002,\n    0x0002000200020200, 0x0002000200020202, 0x0002000201000101, 0x0002000201010001,\n    0x0002000201010102, 0x0002000201010201, 0x0002000201020101, 0x0002000202000001,\n    0x0002000202000200, 0x0002000202000202, 0x0002000202010001, 0x0002000202010101,\n    0x0002000202020000, 0x0002000202020002, 0x0002000202020200, 0x0002000202020202,\n    0x0002010000000101, 0x0002010000010100, 0x0002010000010102, 0x0002010000010201,\n    0x0002010000020101, 0x0002010001000100, 0x0002010001000101, 0x0002010001000102,\n    0x0002010001000201, 0x0002010001010002, 0x0002010001010101, 0x0002010001010200,\n    0x0002010001010202, 0x0002010001020102, 0x0002010002000101, 0x0002010002010001,\n    0x0002010002010100, 0x0002010002010201, 0x0002010002020001, 0x0002010002020101,\n    0x0002010100000201, 0x0002010100010101, 0x0002010100020001, 0x0002010100020201,\n    0x0002010101000000, 0x0002010101000101, 0x0002010101000200, 0x0002010101010001,\n    0x0002010101010100, 0x0002010101010101, 0x0002010101010201, 0x0002010101020002,\n    0x0002010101020101, 0x0002010101020200, 0x0002010102000201, 0x0002010102010000,\n    0x0002010102010100, 0x0002010102010101, 0x0002010102010200, 0x0002010102010202,\n    0x0002010102020001, 0x0002010102020100, 0x0002010102020102, 0x0002010102020201,\n    0x0002010200000101, 0x0002010200010000, 0x0002010200010002, 0x0002010200010201,\n    0x0002010200020101, 0x0002010201000001, 0x0002010201000201, 0x0002010201010101,\n    0x0002010201020000, 0x0002010201020001, 0x0002010201020201, 0x0002010202000100,\n    0x0002010202000102, 0x0002010202010000, 0x0002010202010202, 0x0002020000000000,\n    0x0002020000000002, 0x0002020000000200, 0x0002020000000202, 0x0002020000010101,\n    0x0002020000020000, 0x0002020000020002, 0x0002020000020200, 0x0002020000020202,\n    0x0002020001000101, 0x0002020001010001, 0x0002020001010100, 0x0002020001020101,\n    0x0002020002000000, 
0x0002020002000002, 0x0002020002000200, 0x0002020002000202,\n    0x0002020002020000, 0x0002020002020002, 0x0002020002020200, 0x0002020002020202,\n    0x0002020100000201, 0x0002020100010001, 0x0002020100010100, 0x0002020100010201,\n    0x0002020100020101, 0x0002020101000102, 0x0002020101000201, 0x0002020101010002,\n    0x0002020101010101, 0x0002020101020001, 0x0002020101020100, 0x0002020101020102,\n    0x0002020101020201, 0x0002020102000101, 0x0002020102010000, 0x0002020102010102,\n    0x0002020102010201, 0x0002020102020100, 0x0002020102020101, 0x0002020200000000,\n    0x0002020200000002, 0x0002020200000200, 0x0002020200000202, 0x0002020200020000,\n    0x0002020200020002, 0x0002020200020200, 0x0002020200020202, 0x0002020201000101,\n    0x0002020201010001, 0x0002020201010102, 0x0002020201010201, 0x0002020201020101,\n    0x0002020202000000, 0x0002020202000002, 0x0002020202000200, 0x0002020202000202,\n    0x0002020202010101, 0x0002020202020000, 0x0002020202020002, 0x0002020202020200,\n    0x0002020202020202, 0x0100000000000101, 0x0100000000010001, 0x0100000000010102,\n    0x0100000000020101, 0x0100000001000201, 0x0100000001010002, 0x0100000001010101,\n    0x0100000001010200, 0x0100000001010202, 0x0100000001020001, 0x0100000001020100,\n    0x0100000001020102, 0x0100000002010100, 0x0100000002010201, 0x0100000002020001,\n    0x0100000002020102, 0x0100000100000000, 0x0100000100000001, 0x0100000100000100,\n    0x0100000100000102, 0x0100000100000201, 0x0100000100010002, 0x0100000100010101,\n    0x0100000100010102, 0x0100000100010200, 0x0100000100010202, 0x0100000100020001,\n    0x0100000100020102, 0x0100000100020201, 0x0100000101000101, 0x0100000101000200,\n    0x0100000101000202, 0x0100000101010001, 0x0100000101010100, 0x0100000101010101,\n    0x0100000101010102, 0x0100000101010201, 0x0100000101010202, 0x0100000101020101,\n    0x0100000101020200, 0x0100000101020202, 0x0100000102000001, 0x0100000102000100,\n    0x0100000102000102, 0x0100000102010000, 0x0100000102010002, 
0x0100000102010101,\n    0x0100000102020000, 0x0100000102020001, 0x0100000102020002, 0x0100000200000101,\n    0x0100000200010001, 0x0100000200010100, 0x0100000200010102, 0x0100000200020101,\n    0x0100000201000001, 0x0100000201010002, 0x0100000201010101, 0x0100000201010202,\n    0x0100000201020100, 0x0100000201020201, 0x0100000202000201, 0x0100000202010100,\n    0x0100000202020101, 0x0100010000000001, 0x0100010000010101, 0x0100010000010201,\n    0x0100010000020201, 0x0100010001000101, 0x0100010001000200, 0x0100010001000202,\n    0x0100010001010001, 0x0100010001010100, 0x0100010001010101, 0x0100010001010102,\n    0x0100010001020001, 0x0100010001020002, 0x0100010001020101, 0x0100010001020200,\n    0x0100010001020202, 0x0100010002000001, 0x0100010002000102, 0x0100010002000201,\n    0x0100010002010000, 0x0100010002010002, 0x0100010002010101, 0x0100010002020000,\n    0x0100010002020001, 0x0100010002020201, 0x0100010100000001, 0x0100010100000002,\n    0x0100010100000101, 0x0100010100000202, 0x0100010100010001, 0x0100010100010100,\n    0x0100010100010101, 0x0100010100010102, 0x0100010100010201, 0x0100010100020000,\n    0x0100010100020101, 0x0100010100020202, 0x0100010101000001, 0x0100010101000100,\n    0x0100010101000101, 0x0100010101000102, 0x0100010101000201, 0x0100010101010000,\n    0x0100010101010001, 0x0100010101010100, 0x0100010101010101, 0x0100010101010102,\n    0x0100010101010200, 0x0100010101010201, 0x0100010101020001, 0x0100010101020100,\n    0x0100010101020101, 0x0100010101020102, 0x0100010101020201, 0x0100010102000002,\n    0x0100010102000100, 0x0100010102000101, 0x0100010102000200, 0x0100010102010001,\n    0x0100010102010100, 0x0100010102010101, 0x0100010102010102, 0x0100010102010201,\n    0x0100010102010202, 0x0100010102020101, 0x0100010102020200, 0x0100010102020202,\n    0x0100010200000001, 0x0100010200000101, 0x0100010200000201, 0x0100010200010100,\n    0x0100010200010101, 0x0100010200010200, 0x0100010200010202, 0x0100010200020001,\n    0x0100010200020100, 
0x0100010200020201, 0x0100010201000000, 0x0100010201000002,\n    0x0100010201000101, 0x0100010201000200, 0x0100010201010000, 0x0100010201010001,\n    0x0100010201010002, 0x0100010201010101, 0x0100010201010102, 0x0100010201010201,\n    0x0100010201020002, 0x0100010201020101, 0x0100010201020200, 0x0100010202000001,\n    0x0100010202000101, 0x0100010202000202, 0x0100010202010100, 0x0100010202010101,\n    0x0100010202020001, 0x0100010202020100, 0x0100010202020102, 0x0100020000000101,\n    0x0100020000010001, 0x0100020000010101, 0x0100020000010202, 0x0100020000020101,\n    0x0100020001000002, 0x0100020001000201, 0x0100020001010000, 0x0100020001010101,\n    0x0100020001010200, 0x0100020001020001, 0x0100020001020100, 0x0100020001020102,\n    0x0100020001020201, 0x0100020002000101, 0x0100020002010001, 0x0100020002010100,\n    0x0100020002010102, 0x0100020002010201, 0x0100020002020101, 0x0100020100000001,\n    0x0100020100000101, 0x0100020100000102, 0x0100020100000202, 0x0100020100010000,\n    0x0100020100010100, 0x0100020100010101, 0x0100020100010200, 0x0100020100020001,\n    0x0100020100020100, 0x0100020100020102, 0x0100020101000000, 0x0100020101000101,\n    0x0100020101000202, 0x0100020101010001, 0x0100020101010002, 0x0100020101010100,\n    0x0100020101010101, 0x0100020101010102, 0x0100020101010201, 0x0100020101020000,\n    0x0100020101020002, 0x0100020101020101, 0x0100020101020102, 0x0100020101020202,\n    0x0100020102000102, 0x0100020102000201, 0x0100020102010002, 0x0100020102010101,\n    0x0100020102010102, 0x0100020102010200, 0x0100020102020001, 0x0100020102020100,\n    0x0100020102020102, 0x0100020102020201, 0x0100020200010102, 0x0100020201000100,\n    0x0100020201000102, 0x0100020201000201, 0x0100020201010101, 0x0100020201010200,\n    0x0100020201010202, 0x0100020201020100, 0x0100020201020201, 0x0100020202010100,\n    0x0100020202020101, 0x0101000000000001, 0x0101000000000100, 0x0101000000000101,\n    0x0101000000000102, 0x0101000000000201, 0x0101000000010002, 
0x0101000000010101,\n    0x0101000000010202, 0x0101000000020001, 0x0101000000020100, 0x0101000000020201,\n    0x0101000001000000, 0x0101000001000101, 0x0101000001000200, 0x0101000001010001,\n    0x0101000001010100, 0x0101000001010101, 0x0101000001010102, 0x0101000001010201,\n    0x0101000001020101, 0x0101000001020200, 0x0101000002000102, 0x0101000002000201,\n    0x0101000002010101, 0x0101000002010200, 0x0101000002020000, 0x0101000002020001,\n    0x0101000002020102, 0x0101000002020201, 0x0101000100000101, 0x0101000100000200,\n    0x0101000100000201, 0x0101000100000202, 0x0101000100010001, 0x0101000100010100,\n    0x0101000100010101, 0x0101000100010102, 0x0101000100010200, 0x0101000100010201,\n    0x0101000100020000, 0x0101000100020101, 0x0101000100020102, 0x0101000100020200,\n    0x0101000100020202, 0x0101000101000001, 0x0101000101000100, 0x0101000101000101,\n    0x0101000101000102, 0x0101000101000201, 0x0101000101010000, 0x0101000101010001,\n    0x0101000101010002, 0x0101000101010100, 0x0101000101010101, 0x0101000101010102,\n    0x0101000101010200, 0x0101000101010201, 0x0101000101010202, 0x0101000101020001,\n    0x0101000101020100, 0x0101000101020101, 0x0101000101020102, 0x0101000101020201,\n    0x0101000102000002, 0x0101000102000101, 0x0101000102010001, 0x0101000102010100,\n    0x0101000102010101, 0x0101000102010102, 0x0101000102010201, 0x0101000102020000,\n    0x0101000102020101, 0x0101000102020202, 0x0101000200000001, 0x0101000200000102,\n    0x0101000200010002, 0x0101000200010101, 0x0101000200010202, 0x0101000200020001,\n    0x0101000200020100, 0x0101000201000002, 0x0101000201000101, 0x0101000201000202,\n    0x0101000201010001, 0x0101000201010100, 0x0101000201010101, 0x0101000201010102,\n    0x0101000201010201, 0x0101000201020002, 0x0101000201020101, 0x0101000202000101,\n    0x0101000202010000, 0x0101000202010002, 0x0101000202010101, 0x0101000202010201,\n    0x0101000202010202, 0x0101000202020100, 0x0101010000000100, 0x0101010000000101,\n    0x0101010000010001, 
0x0101010000010100, 0x0101010000010101, 0x0101010000010102,\n    0x0101010000010200, 0x0101010000010201, 0x0101010000020001, 0x0101010000020101,\n    0x0101010000020200, 0x0101010000020202, 0x0101010001000001, 0x0101010001000100,\n    0x0101010001000101, 0x0101010001000102, 0x0101010001000201, 0x0101010001000202,\n    0x0101010001010000, 0x0101010001010001, 0x0101010001010100, 0x0101010001010101,\n    0x0101010001010102, 0x0101010001010200, 0x0101010001010201, 0x0101010001010202,\n    0x0101010001020001, 0x0101010001020002, 0x0101010001020100, 0x0101010001020101,\n    0x0101010001020102, 0x0101010001020201, 0x0101010002000000, 0x0101010002000200,\n    0x0101010002000202, 0x0101010002010001, 0x0101010002010100, 0x0101010002010101,\n    0x0101010002010102, 0x0101010002010201, 0x0101010002020001, 0x0101010002020100,\n    0x0101010002020101, 0x0101010002020202, 0x0101010100000001, 0x0101010100000002,\n    0x0101010100000100, 0x0101010100000101, 0x0101010100000102, 0x0101010100000201,\n    0x0101010100010000, 0x0101010100010001, 0x0101010100010002, 0x0101010100010100,\n    0x0101010100010101, 0x0101010100010102, 0x0101010100010201, 0x0101010100010202,\n    0x0101010100020001, 0x0101010100020100, 0x0101010100020101, 0x0101010100020102,\n    0x0101010100020201, 0x0101010101000000, 0x0101010101000001, 0x0101010101000002,\n    0x0101010101000100, 0x0101010101000101, 0x0101010101000102, 0x0101010101000200,\n    0x0101010101000201, 0x0101010101010000, 0x0101010101010001, 0x0101010101010002,\n    0x0101010101010100, 0x0101010101010101, 0x0101010101010102, 0x0101010101010200,\n    0x0101010101010201, 0x0101010101010202, 0x0101010101020000, 0x0101010101020001,\n    0x0101010101020100, 0x0101010101020101, 0x0101010101020102, 0x0101010101020200,\n    0x0101010101020201, 0x0101010101020202, 0x0101010102000001, 0x0101010102000100,\n    0x0101010102000101, 0x0101010102000201, 0x0101010102000202, 0x0101010102010000,\n    0x0101010102010001, 0x0101010102010100, 0x0101010102010101, 
0x0101010102010102,\n    0x0101010102010200, 0x0101010102010201, 0x0101010102020001, 0x0101010102020100,\n    0x0101010102020101, 0x0101010102020102, 0x0101010102020201, 0x0101010200000000,\n    0x0101010200000001, 0x0101010200000002, 0x0101010200000100, 0x0101010200000102,\n    0x0101010200000200, 0x0101010200000201, 0x0101010200010001, 0x0101010200010100,\n    0x0101010200010101, 0x0101010200010200, 0x0101010200010201, 0x0101010200020000,\n    0x0101010200020001, 0x0101010200020002, 0x0101010200020100, 0x0101010200020101,\n    0x0101010200020102, 0x0101010200020200, 0x0101010200020201, 0x0101010201000001,\n    0x0101010201000101, 0x0101010201000102, 0x0101010201000200, 0x0101010201000201,\n    0x0101010201000202, 0x0101010201010000, 0x0101010201010001, 0x0101010201010002,\n    0x0101010201010100, 0x0101010201010101, 0x0101010201010102, 0x0101010201010200,\n    0x0101010201010201, 0x0101010201010202, 0x0101010201020001, 0x0101010201020100,\n    0x0101010201020101, 0x0101010201020201, 0x0101010202000002, 0x0101010202000101,\n    0x0101010202000102, 0x0101010202000200, 0x0101010202000201, 0x0101010202000202,\n    0x0101010202010001, 0x0101010202010101, 0x0101010202010202, 0x0101010202020002,\n    0x0101010202020101, 0x0101010202020102, 0x0101010202020200, 0x0101010202020201,\n    0x0101020000000100, 0x0101020000000101, 0x0101020000000102, 0x0101020000000201,\n    0x0101020000010000, 0x0101020000010101, 0x0101020000010200, 0x0101020000020001,\n    0x0101020000020202, 0x0101020001000101, 0x0101020001000200, 0x0101020001000202,\n    0x0101020001010001, 0x0101020001010100, 0x0101020001010101, 0x0101020001010102,\n    0x0101020001010200, 0x0101020001010201, 0x0101020001020000, 0x0101020001020002,\n    0x0101020001020100, 0x0101020001020101, 0x0101020002000002, 0x0101020002000201,\n    0x0101020002010000, 0x0101020002010002, 0x0101020002010101, 0x0101020002010200,\n    0x0101020002020001, 0x0101020002020201, 0x0101020100000001, 0x0101020100000002,\n    0x0101020100000101, 
0x0101020100000202, 0x0101020100010001, 0x0101020100010100,\n    0x0101020100010101, 0x0101020100010102, 0x0101020100010201, 0x0101020100020101,\n    0x0101020101000001, 0x0101020101000100, 0x0101020101000101, 0x0101020101000102,\n    0x0101020101000201, 0x0101020101010000, 0x0101020101010001, 0x0101020101010002,\n    0x0101020101010100, 0x0101020101010101, 0x0101020101010102, 0x0101020101010200,\n    0x0101020101010201, 0x0101020101010202, 0x0101020101020001, 0x0101020101020100,\n    0x0101020101020101, 0x0101020101020102, 0x0101020101020201, 0x0101020102000001,\n    0x0101020102000101, 0x0101020102000201, 0x0101020102010001, 0x0101020102010100,\n    0x0101020102010101, 0x0101020102010102, 0x0101020102010200, 0x0101020102010201,\n    0x0101020102020101, 0x0101020200000100, 0x0101020200000200, 0x0101020200010101,\n    0x0101020200010202, 0x0101020200020000, 0x0101020200020101, 0x0101020200020102,\n    0x0101020200020201, 0x0101020201000101, 0x0101020201000200, 0x0101020201000201,\n    0x0101020201010001, 0x0101020201010101, 0x0101020201010102, 0x0101020201010200,\n    0x0101020201010201, 0x0101020201020002, 0x0101020201020101, 0x0101020201020200,\n    0x0101020201020202, 0x0101020202000001, 0x0101020202000202, 0x0101020202010002,\n    0x0101020202010101, 0x0101020202010102, 0x0101020202010200, 0x0101020202010202,\n    0x0101020202020001, 0x0102000000000101, 0x0102000000010100, 0x0102000000010102,\n    0x0102000000010201, 0x0102000000020101, 0x0102000001000100, 0x0102000001010000,\n    0x0102000001010101, 0x0102000001010102, 0x0102000001010200, 0x0102000001010202,\n    0x0102000001020001, 0x0102000001020100, 0x0102000001020102, 0x0102000001020201,\n    0x0102000002000001, 0x0102000002010102, 0x0102000002020101, 0x0102000100000001,\n    0x0102000100000100, 0x0102000100000102, 0x0102000100000201, 0x0102000100010002,\n    0x0102000100010101, 0x0102000100020001, 0x0102000100020002, 0x0102000100020102,\n    0x0102000100020201, 0x0102000101000101, 0x0102000101000201, 
0x0102000101010001,\n    0x0102000101010101, 0x0102000101010102, 0x0102000101010201, 0x0102000101020101,\n    0x0102000101020102, 0x0102000101020202, 0x0102000102000100, 0x0102000102000202,\n    0x0102000102010002, 0x0102000102010101, 0x0102000102020001, 0x0102000102020102,\n    0x0102000102020201, 0x0102000200010001, 0x0102000200010102, 0x0102000200010201,\n    0x0102000201000000, 0x0102000201000001, 0x0102000201000102, 0x0102000201010101,\n    0x0102000201010102, 0x0102000201010200, 0x0102000201020000, 0x0102000202000101,\n    0x0102000202010001, 0x0102000202010102, 0x0102000202020101, 0x0102010000010001,\n    0x0102010000010002, 0x0102010000010101, 0x0102010000010102, 0x0102010000010202,\n    0x0102010000020001, 0x0102010000020102, 0x0102010000020201, 0x0102010001000000,\n    0x0102010001000002, 0x0102010001000101, 0x0102010001000200, 0x0102010001000202,\n    0x0102010001010001, 0x0102010001010100, 0x0102010001010101, 0x0102010001010102,\n    0x0102010001010201, 0x0102010001010202, 0x0102010001020000, 0x0102010001020002,\n    0x0102010001020101, 0x0102010002000100, 0x0102010002000101, 0x0102010002000201,\n    0x0102010002010000, 0x0102010002010002, 0x0102010002010100, 0x0102010002010101,\n    0x0102010002010102, 0x0102010002010200, 0x0102010002010202, 0x0102010002020001,\n    0x0102010002020100, 0x0102010002020201, 0x0102010100000101, 0x0102010100000200,\n    0x0102010100000202, 0x0102010100010001, 0x0102010100010101, 0x0102010100010102,\n    0x0102010100010201, 0x0102010101000100, 0x0102010101000101, 0x0102010101000102,\n    0x0102010101000201, 0x0102010101010000, 0x0102010101010001, 0x0102010101010100,\n    0x0102010101010101, 0x0102010101010102, 0x0102010101010201, 0x0102010101020001,\n    0x0102010101020100, 0x0102010101020101, 0x0102010101020102, 0x0102010101020201,\n    0x0102010102000102, 0x0102010102000201, 0x0102010102000202, 0x0102010102010001,\n    0x0102010102010101, 0x0102010102010102, 0x0102010102010201, 0x0102010102010202,\n    0x0102010102020002, 
0x0102010102020101, 0x0102010102020102, 0x0102010102020200,\n    0x0102010200000002, 0x0102010200000201, 0x0102010200010101, 0x0102010200020000,\n    0x0102010200020102, 0x0102010200020200, 0x0102010200020201, 0x0102010201000000,\n    0x0102010201000101, 0x0102010201000200, 0x0102010201000202, 0x0102010201010001,\n    0x0102010201010100, 0x0102010201010101, 0x0102010201010102, 0x0102010201010200,\n    0x0102010201010202, 0x0102010201020000, 0x0102010201020101, 0x0102010201020200,\n    0x0102010202000000, 0x0102010202000002, 0x0102010202000101, 0x0102010202000202,\n    0x0102010202010100, 0x0102010202010102, 0x0102010202010200, 0x0102010202010201,\n    0x0102010202020000, 0x0102010202020100, 0x0102010202020102, 0x0102010202020202,\n    0x0102020000010102, 0x0102020000010201, 0x0102020000020101, 0x0102020001000001,\n    0x0102020001010002, 0x0102020001010101, 0x0102020001010202, 0x0102020001020001,\n    0x0102020001020201, 0x0102020002000101, 0x0102020002010001, 0x0102020002010200,\n    0x0102020002020102, 0x0102020100000001, 0x0102020100000100, 0x0102020100010000,\n    0x0102020100010101, 0x0102020100020001, 0x0102020100020100, 0x0102020100020102,\n    0x0102020100020201, 0x0102020101000000, 0x0102020101000001, 0x0102020101000101,\n    0x0102020101000102, 0x0102020101000200, 0x0102020101010001, 0x0102020101010100,\n    0x0102020101010101, 0x0102020101010102, 0x0102020101010201, 0x0102020101020000,\n    0x0102020101020101, 0x0102020101020202, 0x0102020102000002, 0x0102020102000100,\n    0x0102020102000202, 0x0102020102010101, 0x0102020102020001, 0x0102020102020100,\n    0x0102020102020101, 0x0102020102020201, 0x0102020200010001, 0x0102020200010102,\n    0x0102020200010200, 0x0102020201000001, 0x0102020201000100, 0x0102020201000201,\n    0x0102020201010000, 0x0102020201010101, 0x0102020201010200, 0x0102020201010202,\n    0x0102020201020100, 0x0102020201020101, 0x0102020201020201, 0x0102020202000102,\n    0x0102020202010100, 0x0102020202010200, 0x0102020202010202, 
0x0102020202020102,\n    0x0200000000000000, 0x0200000000000002, 0x0200000000000200, 0x0200000000000202,\n    0x0200000000020000, 0x0200000000020002, 0x0200000000020200, 0x0200000000020202,\n    0x0200000001000101, 0x0200000001010000, 0x0200000001010001, 0x0200000001010100,\n    0x0200000001010102, 0x0200000001010201, 0x0200000001020101, 0x0200000002000000,\n    0x0200000002000002, 0x0200000002000200, 0x0200000002000202, 0x0200000002010101,\n    0x0200000002020000, 0x0200000002020002, 0x0200000002020200, 0x0200000002020202,\n    0x0200000100000101, 0x0200000100010001, 0x0200000100010100, 0x0200000100010102,\n    0x0200000100010201, 0x0200000100020101, 0x0200000101000001, 0x0200000101000100,\n    0x0200000101000201, 0x0200000101010000, 0x0200000101010002, 0x0200000101010101,\n    0x0200000101010102, 0x0200000101010200, 0x0200000101010201, 0x0200000101020100,\n    0x0200000101020102, 0x0200000101020201, 0x0200000102000101, 0x0200000102000201,\n    0x0200000102010100, 0x0200000102010102, 0x0200000102010201, 0x0200000102020101,\n    0x0200000200000000, 0x0200000200000002, 0x0200000200000200, 0x0200000200000202,\n    0x0200000200010101, 0x0200000200020000, 0x0200000200020002, 0x0200000200020200,\n    0x0200000200020202, 0x0200000201010001, 0x0200000201010100, 0x0200000201010201,\n    0x0200000201020101, 0x0200000202000000, 0x0200000202000002, 0x0200000202000200,\n    0x0200000202000202, 0x0200000202010101, 0x0200000202020000, 0x0200000202020002,\n    0x0200000202020200, 0x0200000202020202, 0x0200010000010100, 0x0200010000010201,\n    0x0200010001000001, 0x0200010001000100, 0x0200010001010001, 0x0200010001010101,\n    0x0200010001010202, 0x0200010001020001, 0x0200010001020100, 0x0200010001020201,\n    0x0200010002010100, 0x0200010002010201, 0x0200010100000001, 0x0200010100000201,\n    0x0200010100010002, 0x0200010100010101, 0x0200010100010202, 0x0200010100020102,\n    0x0200010100020201, 0x0200010101000000, 0x0200010101000001, 0x0200010101000101,\n    0x0200010101000200, 
0x0200010101010001, 0x0200010101010100, 0x0200010101010101,\n    0x0200010101010102, 0x0200010101010201, 0x0200010101010202, 0x0200010101020101,\n    0x0200010101020102, 0x0200010101020200, 0x0200010101020202, 0x0200010102000001,\n    0x0200010102000100, 0x0200010102000102, 0x0200010102000201, 0x0200010102010000,\n    0x0200010102010002, 0x0200010102010101, 0x0200010102010200, 0x0200010102020102,\n    0x0200010200010001, 0x0200010200010102, 0x0200010200010201, 0x0200010200020101,\n    0x0200010201000001, 0x0200010201000100, 0x0200010201000201, 0x0200010201000202,\n    0x0200010201010000, 0x0200010201010101, 0x0200010201010201, 0x0200010201010202,\n    0x0200010201020001, 0x0200010201020102, 0x0200010201020202, 0x0200010202000101,\n    0x0200010202010001, 0x0200010202010202, 0x0200010202020100, 0x0200020000000000,\n    0x0200020000000002, 0x0200020000000200, 0x0200020000000202, 0x0200020000010101,\n    0x0200020000020000, 0x0200020000020002, 0x0200020000020200, 0x0200020000020202,\n    0x0200020001000001, 0x0200020001000101, 0x0200020001010001, 0x0200020001010100,\n    0x0200020001010201, 0x0200020001020101, 0x0200020001020201, 0x0200020002000000,\n    0x0200020002000002, 0x0200020002000200, 0x0200020002000202, 0x0200020002010101,\n    0x0200020002020000, 0x0200020002020002, 0x0200020002020200, 0x0200020002020202,\n    0x0200020100000101, 0x0200020100000102, 0x0200020100010001, 0x0200020100010100,\n    0x0200020100010102, 0x0200020100020101, 0x0200020101000001, 0x0200020101000100,\n    0x0200020101000102, 0x0200020101000201, 0x0200020101010000, 0x0200020101010002,\n    0x0200020101010101, 0x0200020101010202, 0x0200020101020001, 0x0200020101020100,\n    0x0200020102000101, 0x0200020102010102, 0x0200020102010201, 0x0200020102020101,\n    0x0200020200000000, 0x0200020200000002, 0x0200020200000200, 0x0200020200000202,\n    0x0200020200010101, 0x0200020200020000, 0x0200020200020002, 0x0200020200020200,\n    0x0200020200020202, 0x0200020201000101, 0x0200020201010001, 
0x0200020201010100,\n    0x0200020201010102, 0x0200020202000000, 0x0200020202000002, 0x0200020202000200,\n    0x0200020202000202, 0x0200020202010101, 0x0200020202020000, 0x0200020202020002,\n    0x0200020202020200, 0x0200020202020202, 0x0201000000000101, 0x0201000000010001,\n    0x0201000000010102, 0x0201000000010200, 0x0201000000010201, 0x0201000000020101,\n    0x0201000001000001, 0x0201000001000102, 0x0201000001000201, 0x0201000001010101,\n    0x0201000001010200, 0x0201000001010202, 0x0201000001020201, 0x0201000001020202,\n    0x0201000002000101, 0x0201000002010001, 0x0201000002010100, 0x0201000002010102,\n    0x0201000002010201, 0x0201000002020101, 0x0201000100000001, 0x0201000100000100,\n    0x0201000100000102, 0x0201000100000201, 0x0201000100010000, 0x0201000100010101,\n    0x0201000100010200, 0x0201000100010202, 0x0201000100020001, 0x0201000100020100,\n    0x0201000100020102, 0x0201000100020201, 0x0201000101000000, 0x0201000101000101,\n    0x0201000101010000, 0x0201000101010001, 0x0201000101010100, 0x0201000101010101,\n    0x0201000101010102, 0x0201000101010201, 0x0201000101020002, 0x0201000101020101,\n    0x0201000102000100, 0x0201000102000102, 0x0201000102010002, 0x0201000102010101,\n    0x0201000102010200, 0x0201000102020001, 0x0201000102020100, 0x0201000102020102,\n    0x0201000102020201, 0x0201000200000101, 0x0201000200010001, 0x0201000200010100,\n    0x0201000200010201, 0x0201000200020101, 0x0201000201000100, 0x0201000201000102,\n    0x0201000201000201, 0x0201000201010000, 0x0201000201010002, 0x0201000201010101,\n    0x0201000201010200, 0x0201000201020102, 0x0201000201020201, 0x0201000202000101,\n    0x0201000202010100, 0x0201000202010102, 0x0201000202020201, 0x0201010000000001,\n    0x0201010000000100, 0x0201010000000102, 0x0201010000010000, 0x0201010000010101,\n    0x0201010000010200, 0x0201010000020102, 0x0201010001000000, 0x0201010001000202,\n    0x0201010001010001, 0x0201010001010100, 0x0201010001010101, 0x0201010001010102,\n    0x0201010001010200, 
0x0201010001010201, 0x0201010001020000, 0x0201010001020001,\n    0x0201010001020002, 0x0201010001020101, 0x0201010002000100, 0x0201010002000102,\n    0x0201010002010002, 0x0201010002010100, 0x0201010002010101, 0x0201010002010200,\n    0x0201010002020001, 0x0201010002020201, 0x0201010100000000, 0x0201010100000101,\n    0x0201010100000200, 0x0201010100000202, 0x0201010100010000, 0x0201010100010001,\n    0x0201010100010100, 0x0201010100010101, 0x0201010100010102, 0x0201010100010201,\n    0x0201010100020001, 0x0201010100020101, 0x0201010100020201, 0x0201010100020202,\n    0x0201010101000001, 0x0201010101000100, 0x0201010101000101, 0x0201010101000102,\n    0x0201010101000201, 0x0201010101010000, 0x0201010101010001, 0x0201010101010002,\n    0x0201010101010100, 0x0201010101010101, 0x0201010101010102, 0x0201010101010200,\n    0x0201010101010201, 0x0201010101010202, 0x0201010101020001, 0x0201010101020100,\n    0x0201010101020101, 0x0201010101020102, 0x0201010101020201, 0x0201010102000001,\n    0x0201010102000101, 0x0201010102000200, 0x0201010102010001, 0x0201010102010002,\n    0x0201010102010100, 0x0201010102010101, 0x0201010102010102, 0x0201010102010201,\n    0x0201010102010202, 0x0201010102020000, 0x0201010102020002, 0x0201010102020101,\n    0x0201010102020200, 0x0201010102020202, 0x0201010200000001, 0x0201010200000100,\n    0x0201010200010000, 0x0201010200010101, 0x0201010200010201, 0x0201010200020000,\n    0x0201010200020102, 0x0201010200020201, 0x0201010201000101, 0x0201010201000200,\n    0x0201010201000201, 0x0201010201010001, 0x0201010201010002, 0x0201010201010101,\n    0x0201010201010102, 0x0201010201010201, 0x0201010201020101, 0x0201010201020200,\n    0x0201010202000002, 0x0201010202000100, 0x0201010202000201, 0x0201010202000202,\n    0x0201010202010002, 0x0201010202010100, 0x0201010202010101, 0x0201010202020100,\n    0x0201010202020102, 0x0201010202020201, 0x0201020000000101, 0x0201020000010102,\n    0x0201020000010201, 0x0201020000020101, 0x0201020001000001, 
0x0201020001000102,\n    0x0201020001010000, 0x0201020001010002, 0x0201020001010101, 0x0201020001010102,\n    0x0201020001010202, 0x0201020001020100, 0x0201020001020101, 0x0201020002000101,\n    0x0201020002010001, 0x0201020002010102, 0x0201020002010201, 0x0201020002020101,\n    0x0201020100000100, 0x0201020100000102, 0x0201020100000201, 0x0201020100010000,\n    0x0201020100010002, 0x0201020100010101, 0x0201020100010200, 0x0201020100010202,\n    0x0201020100020000, 0x0201020100020001, 0x0201020100020100, 0x0201020100020102,\n    0x0201020101000000, 0x0201020101000002, 0x0201020101000101, 0x0201020101000200,\n    0x0201020101000202, 0x0201020101010001, 0x0201020101010100, 0x0201020101010101,\n    0x0201020101010102, 0x0201020101010201, 0x0201020101020002, 0x0201020101020101,\n    0x0201020101020102, 0x0201020101020202, 0x0201020102000001, 0x0201020102000100,\n    0x0201020102010000, 0x0201020102010002, 0x0201020102010101, 0x0201020102010202,\n    0x0201020102020001, 0x0201020102020102, 0x0201020200000101, 0x0201020200010101,\n    0x0201020200020101, 0x0201020201000100, 0x0201020201000102, 0x0201020201000201,\n    0x0201020201010000, 0x0201020201010101, 0x0201020201010200, 0x0201020201020001,\n    0x0201020202000101, 0x0201020202010001, 0x0201020202010100, 0x0201020202010101,\n    0x0201020202010102, 0x0202000000000000, 0x0202000000000002, 0x0202000000000200,\n    0x0202000000000202, 0x0202000000010101, 0x0202000000020000, 0x0202000000020002,\n    0x0202000000020200, 0x0202000000020202, 0x0202000001000101, 0x0202000001010001,\n    0x0202000001010100, 0x0202000001010102, 0x0202000001010201, 0x0202000002000000,\n    0x0202000002000002, 0x0202000002000200, 0x0202000002000202, 0x0202000002010101,\n    0x0202000002020000, 0x0202000002020002, 0x0202000002020200, 0x0202000002020202,\n    0x0202000100000101, 0x0202000100000201, 0x0202000100010001, 0x0202000100010100,\n    0x0202000100010102, 0x0202000100010201, 0x0202000100010202, 0x0202000101000102,\n    0x0202000101000201, 
0x0202000101010001, 0x0202000101010101, 0x0202000101010200,\n    0x0202000101010202, 0x0202000101020001, 0x0202000101020100, 0x0202000102000101,\n    0x0202000102010000, 0x0202000102010002, 0x0202000102010102, 0x0202000102010201,\n    0x0202000200000002, 0x0202000200000200, 0x0202000200000202, 0x0202000200010000,\n    0x0202000200010201, 0x0202000200020002, 0x0202000200020200, 0x0202000200020202,\n    0x0202000201000101, 0x0202000201010001, 0x0202000201010102, 0x0202000201010201,\n    0x0202000201020101, 0x0202000202000000, 0x0202000202000002, 0x0202000202000200,\n    0x0202000202000202, 0x0202000202010101, 0x0202000202020000, 0x0202000202020002,\n    0x0202000202020200, 0x0202000202020202, 0x0202010000010201, 0x0202010000020101,\n    0x0202010001000001, 0x0202010001000100, 0x0202010001010000, 0x0202010001010100,\n    0x0202010001010101, 0x0202010001010200, 0x0202010001010202, 0x0202010001020001,\n    0x0202010001020101, 0x0202010001020102, 0x0202010001020200, 0x0202010001020201,\n    0x0202010002000101, 0x0202010100000102, 0x0202010100000201, 0x0202010100010000,\n    0x0202010100010002, 0x0202010100010101, 0x0202010100010200, 0x0202010100020102,\n    0x0202010100020201, 0x0202010101000002, 0x0202010101000101, 0x0202010101010001,\n    0x0202010101010100, 0x0202010101010101, 0x0202010101010102, 0x0202010101010201,\n    0x0202010101020101, 0x0202010101020202, 0x0202010102000001, 0x0202010102000100,\n    0x0202010102000101, 0x0202010102000102, 0x0202010102000201, 0x0202010102010002,\n    0x0202010102010101, 0x0202010102010200, 0x0202010200000101, 0x0202010200010001,\n    0x0202010200010102, 0x0202010200010202, 0x0202010200020001, 0x0202010200020101,\n    0x0202010201000100, 0x0202010201000102, 0x0202010201000202, 0x0202010201010002,\n    0x0202010201010101, 0x0202010201010102, 0x0202010201010200, 0x0202010201020000,\n    0x0202010201020002, 0x0202010202000102, 0x0202010202010000, 0x0202010202010101,\n    0x0202010202010102, 0x0202010202010201, 0x0202010202020001, 
0x0202010202020100,\n    0x0202010202020102, 0x0202020000000000, 0x0202020000000002, 0x0202020000000200,\n    0x0202020000000202, 0x0202020000020000, 0x0202020000020002, 0x0202020000020200,\n    0x0202020000020202, 0x0202020001010001, 0x0202020001010100, 0x0202020001010102,\n    0x0202020001010201, 0x0202020002000000, 0x0202020002000002, 0x0202020002000200,\n    0x0202020002000202, 0x0202020002010101, 0x0202020002020000, 0x0202020002020002,\n    0x0202020002020200, 0x0202020002020202, 0x0202020100000101, 0x0202020100010100,\n    0x0202020100010201, 0x0202020100020001, 0x0202020100020101, 0x0202020101000001,\n    0x0202020101010000, 0x0202020101010101, 0x0202020101010202, 0x0202020101020001,\n    0x0202020101020102, 0x0202020101020201, 0x0202020102010000, 0x0202020102010102,\n    0x0202020200000000, 0x0202020200000002, 0x0202020200000200, 0x0202020200000202,\n    0x0202020200020000, 0x0202020200020002, 0x0202020200020200, 0x0202020200020202,\n    0x0202020201010001, 0x0202020201010100, 0x0202020201010102, 0x0202020202000000,\n    0x0202020202000002, 0x0202020202000200, 0x0202020202000202, 0x0202020202010101,\n    0x0202020202020000, 0x0202020202020002, 0x0202020202020200, 0x0202020202020202,\n};\n#else\nstatic const uint32_t iq1s_grid_us[2048] = {\n    0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000,\n    0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101,\n    0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200,\n    0x02020202, 0x00000110, 0x00000111, 0x00010011, 0x00010110, 0x00010112, 0x00010211, 0x00010212,\n    0x00020111, 0x01000011, 0x01000112, 0x01000211, 0x01010012, 0x01010111, 0x01010212, 0x01020011,\n    0x01020110, 0x01020112, 0x01020210, 0x02000111, 0x02010011, 0x02010110, 0x02010112, 0x02020111,\n    0x00000020, 0x00000022, 0x00000220, 0x00000222, 0x00010121, 0x00020020, 0x00020022, 0x00020220,\n    0x00020222, 
0x01000121, 0x01010021, 0x01010221, 0x01020120, 0x01020221, 0x02000020, 0x02000022,\n    0x02000220, 0x02000222, 0x02010021, 0x02010121, 0x02010221, 0x02020020, 0x02020022, 0x02020220,\n    0x02020222, 0x00011001, 0x00011100, 0x00011102, 0x00021101, 0x01001001, 0x01001201, 0x01011101,\n    0x01011202, 0x01021100, 0x01021101, 0x02011001, 0x02011201, 0x02021101, 0x00001011, 0x00001110,\n    0x00001111, 0x00001112, 0x00011111, 0x00011210, 0x00011212, 0x00021211, 0x01001010, 0x01001111,\n    0x01001212, 0x01011010, 0x01011011, 0x01011110, 0x01011111, 0x01011112, 0x01011211, 0x01021010,\n    0x01021012, 0x01021111, 0x01021210, 0x01021212, 0x02001011, 0x02011011, 0x02011111, 0x02011210,\n    0x02011212, 0x02021011, 0x02021110, 0x02021111, 0x02021112, 0x02021211, 0x00011120, 0x00011221,\n    0x01001021, 0x01001120, 0x01011020, 0x01011022, 0x01011121, 0x01011220, 0x01021020, 0x01021021,\n    0x01021122, 0x01021221, 0x02001121, 0x02011021, 0x02011120, 0x02011221, 0x00002000, 0x00002002,\n    0x00002200, 0x00002202, 0x00012101, 0x00022000, 0x00022002, 0x00022200, 0x00022202, 0x01002101,\n    0x01012001, 0x01012102, 0x01022101, 0x02002000, 0x02002002, 0x02002200, 0x02002202, 0x02012101,\n    0x02022000, 0x02022002, 0x02022200, 0x02022202, 0x00002111, 0x00012011, 0x00012110, 0x00012211,\n    0x00022110, 0x00022111, 0x01002011, 0x01012010, 0x01012011, 0x01012111, 0x01022011, 0x01022110,\n    0x01022211, 0x02012011, 0x02012110, 0x02012112, 0x02012211, 0x02022111, 0x00002020, 0x00002022,\n    0x00002220, 0x00002222, 0x00012121, 0x00022020, 0x00022022, 0x00022220, 0x00022222, 0x01002121,\n    0x01012021, 0x01012221, 0x01022021, 0x01022121, 0x02002020, 0x02002022, 0x02002121, 0x02002220,\n    0x02002222, 0x02012121, 0x02022020, 0x02022022, 0x02022220, 0x02022222, 0x00110000, 0x00110001,\n    0x00110100, 0x00110201, 0x00120100, 0x00120101, 0x01100001, 0x01100100, 0x01110000, 0x01110101,\n    0x01110200, 0x01120001, 0x01120100, 0x01120101, 0x01120201, 0x02110001, 0x02110100, 
0x02110102,\n    0x02120001, 0x02120101, 0x00100011, 0x00100110, 0x00100112, 0x00100211, 0x00110010, 0x00110012,\n    0x00110111, 0x00110210, 0x00120011, 0x00120110, 0x00120211, 0x01100111, 0x01100212, 0x01110010,\n    0x01110011, 0x01110012, 0x01110110, 0x01110111, 0x01110112, 0x01110211, 0x01120010, 0x01120111,\n    0x02100110, 0x02110012, 0x02110111, 0x02120011, 0x02120110, 0x00110021, 0x00110120, 0x00110122,\n    0x00120121, 0x01100020, 0x01100122, 0x01100221, 0x01110022, 0x01110121, 0x01110220, 0x01110222,\n    0x01120120, 0x01120122, 0x02100121, 0x02110021, 0x02110120, 0x02110122, 0x02120121, 0x00101001,\n    0x00101102, 0x00101201, 0x00111100, 0x00111101, 0x00111200, 0x00111201, 0x00121001, 0x00121102,\n    0x01101001, 0x01101101, 0x01101102, 0x01101200, 0x01101202, 0x01111001, 0x01111100, 0x01111101,\n    0x01111102, 0x01111201, 0x01121002, 0x01121101, 0x01121200, 0x02101100, 0x02101201, 0x02111000,\n    0x02111100, 0x02111101, 0x02111200, 0x02111201, 0x02111202, 0x02121001, 0x02121100, 0x02121101,\n    0x02121201, 0x00101012, 0x00101111, 0x00101212, 0x00111011, 0x00111110, 0x00111111, 0x00111112,\n    0x00111211, 0x00121010, 0x00121012, 0x00121111, 0x00121210, 0x00121212, 0x01101011, 0x01101110,\n    0x01101111, 0x01101112, 0x01111011, 0x01111012, 0x01111110, 0x01111111, 0x01111112, 0x01111211,\n    0x01111212, 0x01121011, 0x01121110, 0x01121111, 0x01121112, 0x01121211, 0x02101010, 0x02101012,\n    0x02101110, 0x02101111, 0x02101210, 0x02101212, 0x02111010, 0x02111011, 0x02111110, 0x02111111,\n    0x02111112, 0x02111211, 0x02111212, 0x02121010, 0x02121012, 0x02121111, 0x00101021, 0x00101120,\n    0x00101121, 0x00101122, 0x00111121, 0x00111122, 0x00111220, 0x00111222, 0x00121021, 0x00121122,\n    0x01101020, 0x01101022, 0x01101120, 0x01101121, 0x01101220, 0x01101222, 0x01111021, 0x01111121,\n    0x01111122, 0x01111220, 0x01111221, 0x01121021, 0x01121120, 0x01121121, 0x01121220, 0x01121221,\n    0x01121222, 0x02101122, 0x02101222, 0x02111022, 0x02111121, 
0x02121120, 0x02121221, 0x00112001,\n    0x00112102, 0x00122101, 0x01102001, 0x01102100, 0x01102102, 0x01102201, 0x01112000, 0x01112101,\n    0x01112200, 0x01112202, 0x01122000, 0x01122001, 0x01122100, 0x01122102, 0x01122201, 0x02102101,\n    0x02112001, 0x02112100, 0x02122101, 0x00112010, 0x00112012, 0x00112111, 0x00112212, 0x00122011,\n    0x00122111, 0x01102012, 0x01102110, 0x01102111, 0x01102210, 0x01112011, 0x01112110, 0x01112111,\n    0x01112112, 0x01112211, 0x01112212, 0x01122010, 0x01122111, 0x01122212, 0x02102211, 0x02112011,\n    0x02112012, 0x02112111, 0x02112210, 0x02122011, 0x02122112, 0x02122211, 0x00102221, 0x00112122,\n    0x00122120, 0x00122122, 0x01102120, 0x01102122, 0x01102221, 0x01112020, 0x01112022, 0x01112121,\n    0x01112220, 0x01122021, 0x01122122, 0x01122221, 0x02102121, 0x02112021, 0x02112122, 0x02112222,\n    0x00200000, 0x00200002, 0x00200200, 0x00200202, 0x00210101, 0x00220000, 0x00220002, 0x00220101,\n    0x00220200, 0x00220202, 0x01200101, 0x01210001, 0x01210201, 0x01220001, 0x01220101, 0x02200000,\n    0x02200002, 0x02200200, 0x02200202, 0x02210101, 0x02220000, 0x02220002, 0x02220101, 0x02220200,\n    0x02220202, 0x00200111, 0x00210011, 0x00210110, 0x00210211, 0x00220111, 0x01200012, 0x01200110,\n    0x01200211, 0x01210111, 0x01210210, 0x01210212, 0x01220011, 0x01220110, 0x01220111, 0x01220112,\n    0x02200111, 0x02210010, 0x02210112, 0x02210211, 0x02220111, 0x00200021, 0x00200220, 0x00200222,\n    0x00210021, 0x00210121, 0x00220020, 0x00220022, 0x00220220, 0x00220222, 0x01200121, 0x01210021,\n    0x01210122, 0x01210221, 0x01220121, 0x02200021, 0x02200220, 0x02200222, 0x02210021, 0x02210121,\n    0x02220020, 0x02220022, 0x02220220, 0x02220222, 0x00201101, 0x00211100, 0x00211102, 0x00211201,\n    0x00221101, 0x01201100, 0x01201101, 0x01201102, 0x01201201, 0x01211002, 0x01211101, 0x01211200,\n    0x01211202, 0x01221102, 0x02201101, 0x02211001, 0x02211100, 0x02211201, 0x02221001, 0x02221101,\n    0x00201211, 0x00211111, 0x00221011, 
0x00221211, 0x01201010, 0x01201111, 0x01201210, 0x01211011,\n    0x01211110, 0x01211111, 0x01211211, 0x01221012, 0x01221111, 0x01221210, 0x02201211, 0x02211010,\n    0x02211110, 0x02211111, 0x02211210, 0x02211212, 0x02221011, 0x02221110, 0x02221112, 0x02221211,\n    0x00201121, 0x00211020, 0x00211022, 0x00211221, 0x00221121, 0x01201021, 0x01201221, 0x01211121,\n    0x01221020, 0x01221021, 0x01221221, 0x02201120, 0x02201122, 0x02211020, 0x02211222, 0x00202000,\n    0x00202002, 0x00202200, 0x00202202, 0x00212101, 0x00222000, 0x00222002, 0x00222200, 0x00222202,\n    0x01202101, 0x01212001, 0x01212100, 0x01222101, 0x02202000, 0x02202002, 0x02202200, 0x02202202,\n    0x02222000, 0x02222002, 0x02222200, 0x02222202, 0x00202211, 0x00212011, 0x00212110, 0x00212211,\n    0x00222111, 0x01202112, 0x01202211, 0x01212012, 0x01212111, 0x01222011, 0x01222110, 0x01222112,\n    0x01222211, 0x02202111, 0x02212010, 0x02212112, 0x02212211, 0x02222110, 0x02222111, 0x00202020,\n    0x00202022, 0x00202220, 0x00202222, 0x00222020, 0x00222022, 0x00222220, 0x00222222, 0x01202121,\n    0x01212021, 0x01212122, 0x01212221, 0x01222121, 0x02202020, 0x02202022, 0x02202220, 0x02202222,\n    0x02212121, 0x02222020, 0x02222022, 0x02222220, 0x02222222, 0x10000101, 0x10010001, 0x10010102,\n    0x10020101, 0x11000201, 0x11010002, 0x11010101, 0x11010200, 0x11010202, 0x11020001, 0x11020100,\n    0x11020102, 0x12010100, 0x12010201, 0x12020001, 0x12020102, 0x10000010, 0x10000011, 0x10000110,\n    0x10000112, 0x10000211, 0x10010012, 0x10010111, 0x10010112, 0x10010210, 0x10010212, 0x10020011,\n    0x10020112, 0x10020211, 0x11000111, 0x11000210, 0x11000212, 0x11010011, 0x11010110, 0x11010111,\n    0x11010112, 0x11010211, 0x11010212, 0x11020111, 0x11020210, 0x11020212, 0x12000011, 0x12000110,\n    0x12000112, 0x12010010, 0x12010012, 0x12010111, 0x12020010, 0x12020011, 0x12020012, 0x10000121,\n    0x10010021, 0x10010120, 0x10010122, 0x10020121, 0x11000021, 0x11010022, 0x11010121, 0x11010222,\n    0x11020120, 
0x11020221, 0x12000221, 0x12010120, 0x12020121, 0x10001001, 0x10011101, 0x10011201,\n    0x10021201, 0x11001101, 0x11001200, 0x11001202, 0x11011001, 0x11011100, 0x11011101, 0x11011102,\n    0x11021001, 0x11021002, 0x11021101, 0x11021200, 0x11021202, 0x12001001, 0x12001102, 0x12001201,\n    0x12011000, 0x12011002, 0x12011101, 0x12021000, 0x12021001, 0x12021201, 0x10001011, 0x10001012,\n    0x10001111, 0x10001212, 0x10011011, 0x10011110, 0x10011111, 0x10011112, 0x10011211, 0x10021010,\n    0x10021111, 0x10021212, 0x11001011, 0x11001110, 0x11001111, 0x11001112, 0x11001211, 0x11011010,\n    0x11011011, 0x11011110, 0x11011111, 0x11011112, 0x11011210, 0x11011211, 0x11021011, 0x11021110,\n    0x11021111, 0x11021112, 0x11021211, 0x12001012, 0x12001110, 0x12001111, 0x12001210, 0x12011011,\n    0x12011110, 0x12011111, 0x12011112, 0x12011211, 0x12011212, 0x12021111, 0x12021210, 0x12021212,\n    0x10001021, 0x10001121, 0x10001221, 0x10011120, 0x10011121, 0x10011220, 0x10011222, 0x10021021,\n    0x10021120, 0x10021221, 0x11001020, 0x11001022, 0x11001121, 0x11001220, 0x11011020, 0x11011021,\n    0x11011022, 0x11011121, 0x11011122, 0x11011221, 0x11021022, 0x11021121, 0x11021220, 0x12001021,\n    0x12001121, 0x12001222, 0x12011120, 0x12011121, 0x12021021, 0x12021120, 0x12021122, 0x10002101,\n    0x10012001, 0x10012101, 0x10012202, 0x10022101, 0x11002002, 0x11002201, 0x11012000, 0x11012101,\n    0x11012200, 0x11022001, 0x11022100, 0x11022102, 0x11022201, 0x12002101, 0x12012001, 0x12012100,\n    0x12012102, 0x12012201, 0x12022101, 0x10002011, 0x10002111, 0x10002112, 0x10002212, 0x10012010,\n    0x10012110, 0x10012111, 0x10012210, 0x10022011, 0x10022110, 0x10022112, 0x11002010, 0x11002111,\n    0x11002212, 0x11012011, 0x11012012, 0x11012110, 0x11012111, 0x11012112, 0x11012211, 0x11022010,\n    0x11022012, 0x11022111, 0x11022112, 0x11022212, 0x12002112, 0x12002211, 0x12012012, 0x12012111,\n    0x12012112, 0x12012210, 0x12022011, 0x12022110, 0x12022112, 0x12022211, 0x10012122, 
0x11002120,\n    0x11002122, 0x11002221, 0x11012121, 0x11012220, 0x11012222, 0x11022120, 0x11022221, 0x12012120,\n    0x12022121, 0x10100001, 0x10100100, 0x10100101, 0x10100102, 0x10100201, 0x10110002, 0x10110101,\n    0x10110202, 0x10120001, 0x10120100, 0x10120201, 0x11100000, 0x11100101, 0x11100200, 0x11110001,\n    0x11110100, 0x11110101, 0x11110102, 0x11110201, 0x11120101, 0x11120200, 0x12100102, 0x12100201,\n    0x12110101, 0x12110200, 0x12120000, 0x12120001, 0x12120102, 0x12120201, 0x10100111, 0x10100210,\n    0x10100211, 0x10100212, 0x10110011, 0x10110110, 0x10110111, 0x10110112, 0x10110210, 0x10110211,\n    0x10120010, 0x10120111, 0x10120112, 0x10120210, 0x10120212, 0x11100011, 0x11100110, 0x11100111,\n    0x11100112, 0x11100211, 0x11110010, 0x11110011, 0x11110012, 0x11110110, 0x11110111, 0x11110112,\n    0x11110210, 0x11110211, 0x11110212, 0x11120011, 0x11120110, 0x11120111, 0x11120112, 0x11120211,\n    0x12100012, 0x12100111, 0x12110011, 0x12110110, 0x12110111, 0x12110112, 0x12110211, 0x12120010,\n    0x12120111, 0x12120212, 0x10100021, 0x10100122, 0x10110022, 0x10110121, 0x10110222, 0x10120021,\n    0x10120120, 0x11100022, 0x11100121, 0x11100222, 0x11110021, 0x11110120, 0x11110121, 0x11110122,\n    0x11110221, 0x11120022, 0x11120121, 0x12100121, 0x12110020, 0x12110022, 0x12110121, 0x12110221,\n    0x12110222, 0x12120120, 0x10101100, 0x10101101, 0x10111001, 0x10111100, 0x10111101, 0x10111102,\n    0x10111200, 0x10111201, 0x10121001, 0x10121101, 0x10121200, 0x10121202, 0x11101001, 0x11101100,\n    0x11101101, 0x11101102, 0x11101201, 0x11101202, 0x11111000, 0x11111001, 0x11111100, 0x11111101,\n    0x11111102, 0x11111200, 0x11111201, 0x11111202, 0x11121001, 0x11121002, 0x11121100, 0x11121101,\n    0x11121102, 0x11121201, 0x12101000, 0x12101200, 0x12101202, 0x12111001, 0x12111100, 0x12111101,\n    0x12111102, 0x12111201, 0x12121001, 0x12121100, 0x12121101, 0x12121202, 0x10101011, 0x10101012,\n    0x10101110, 0x10101111, 0x10101112, 0x10101211, 0x10111010, 
0x10111011, 0x10111012, 0x10111110,\n    0x10111111, 0x10111112, 0x10111211, 0x10111212, 0x10121011, 0x10121110, 0x10121111, 0x10121112,\n    0x10121211, 0x11101010, 0x11101011, 0x11101012, 0x11101110, 0x11101111, 0x11101112, 0x11101210,\n    0x11101211, 0x11111010, 0x11111011, 0x11111012, 0x11111110, 0x11111111, 0x11111112, 0x11111210,\n    0x11111211, 0x11111212, 0x11121010, 0x11121011, 0x11121110, 0x11121111, 0x11121112, 0x11121210,\n    0x11121211, 0x11121212, 0x12101011, 0x12101110, 0x12101111, 0x12101211, 0x12101212, 0x12111010,\n    0x12111011, 0x12111110, 0x12111111, 0x12111112, 0x12111210, 0x12111211, 0x12121011, 0x12121110,\n    0x12121111, 0x12121112, 0x12121211, 0x10101020, 0x10101021, 0x10101022, 0x10101120, 0x10101122,\n    0x10101220, 0x10101221, 0x10111021, 0x10111120, 0x10111121, 0x10111220, 0x10111221, 0x10121020,\n    0x10121021, 0x10121022, 0x10121120, 0x10121121, 0x10121122, 0x10121220, 0x10121221, 0x11101021,\n    0x11101121, 0x11101122, 0x11101220, 0x11101221, 0x11101222, 0x11111020, 0x11111021, 0x11111022,\n    0x11111120, 0x11111121, 0x11111122, 0x11111220, 0x11111221, 0x11111222, 0x11121021, 0x11121120,\n    0x11121121, 0x11121221, 0x12101022, 0x12101121, 0x12101122, 0x12101220, 0x12101221, 0x12101222,\n    0x12111021, 0x12111121, 0x12111222, 0x12121022, 0x12121121, 0x12121122, 0x12121220, 0x12121221,\n    0x10102100, 0x10102101, 0x10102102, 0x10102201, 0x10112000, 0x10112101, 0x10112200, 0x10122001,\n    0x10122202, 0x11102101, 0x11102200, 0x11102202, 0x11112001, 0x11112100, 0x11112101, 0x11112102,\n    0x11112200, 0x11112201, 0x11122000, 0x11122002, 0x11122100, 0x11122101, 0x12102002, 0x12102201,\n    0x12112000, 0x12112002, 0x12112101, 0x12112200, 0x12122001, 0x12122201, 0x10102011, 0x10102012,\n    0x10102111, 0x10102212, 0x10112011, 0x10112110, 0x10112111, 0x10112112, 0x10112211, 0x10122111,\n    0x11102011, 0x11102110, 0x11102111, 0x11102112, 0x11102211, 0x11112010, 0x11112011, 0x11112012,\n    0x11112110, 0x11112111, 0x11112112, 
0x11112210, 0x11112211, 0x11112212, 0x11122011, 0x11122110,\n    0x11122111, 0x11122112, 0x11122211, 0x12102011, 0x12102111, 0x12102211, 0x12112011, 0x12112110,\n    0x12112111, 0x12112112, 0x12112210, 0x12112211, 0x12122111, 0x10102120, 0x10102220, 0x10112121,\n    0x10112222, 0x10122020, 0x10122121, 0x10122122, 0x10122221, 0x11102121, 0x11102220, 0x11102221,\n    0x11112021, 0x11112121, 0x11112122, 0x11112220, 0x11112221, 0x11122022, 0x11122121, 0x11122220,\n    0x11122222, 0x12102021, 0x12102222, 0x12112022, 0x12112121, 0x12112122, 0x12112220, 0x12112222,\n    0x12122021, 0x10200101, 0x10210100, 0x10210102, 0x10210201, 0x10220101, 0x11200100, 0x11210000,\n    0x11210101, 0x11210102, 0x11210200, 0x11210202, 0x11220001, 0x11220100, 0x11220102, 0x11220201,\n    0x12200001, 0x12210102, 0x12220101, 0x10200011, 0x10200110, 0x10200112, 0x10200211, 0x10210012,\n    0x10210111, 0x10220011, 0x10220012, 0x10220112, 0x10220211, 0x11200111, 0x11200211, 0x11210011,\n    0x11210111, 0x11210112, 0x11210211, 0x11220111, 0x11220112, 0x11220212, 0x12200110, 0x12200212,\n    0x12210012, 0x12210111, 0x12220011, 0x12220112, 0x12220211, 0x10210021, 0x10210122, 0x10210221,\n    0x11200020, 0x11200021, 0x11200122, 0x11210121, 0x11210122, 0x11210220, 0x11220020, 0x12200121,\n    0x12210021, 0x12210122, 0x12220121, 0x10211001, 0x10211002, 0x10211101, 0x10211102, 0x10211202,\n    0x10221001, 0x10221102, 0x10221201, 0x11201000, 0x11201002, 0x11201101, 0x11201200, 0x11201202,\n    0x11211001, 0x11211100, 0x11211101, 0x11211102, 0x11211201, 0x11211202, 0x11221000, 0x11221002,\n    0x11221101, 0x12201100, 0x12201101, 0x12201201, 0x12211000, 0x12211002, 0x12211100, 0x12211101,\n    0x12211102, 0x12211200, 0x12211202, 0x12221001, 0x12221100, 0x12221201, 0x10201111, 0x10201210,\n    0x10201212, 0x10211011, 0x10211111, 0x10211112, 0x10211211, 0x11201110, 0x11201111, 0x11201112,\n    0x11201211, 0x11211010, 0x11211011, 0x11211110, 0x11211111, 0x11211112, 0x11211211, 0x11221011,\n    0x11221110, 
0x11221111, 0x11221112, 0x11221211, 0x12201112, 0x12201211, 0x12201212, 0x12211011,\n    0x12211111, 0x12211112, 0x12211211, 0x12211212, 0x12221012, 0x12221111, 0x12221112, 0x12221210,\n    0x10201022, 0x10201221, 0x10211121, 0x10221020, 0x10221122, 0x10221220, 0x10221221, 0x11201020,\n    0x11201121, 0x11201220, 0x11201222, 0x11211021, 0x11211120, 0x11211121, 0x11211122, 0x11211220,\n    0x11211222, 0x11221020, 0x11221121, 0x11221220, 0x12201020, 0x12201022, 0x12201121, 0x12201222,\n    0x12211120, 0x12211122, 0x12211220, 0x12211221, 0x12221020, 0x12221120, 0x12221122, 0x12221222,\n    0x10212102, 0x10212201, 0x10222101, 0x11202001, 0x11212002, 0x11212101, 0x11212202, 0x11222001,\n    0x11222201, 0x12202101, 0x12212001, 0x12212200, 0x12222102, 0x10202011, 0x10202110, 0x10212010,\n    0x10212111, 0x10222011, 0x10222110, 0x10222112, 0x10222211, 0x11202010, 0x11202011, 0x11202111,\n    0x11202112, 0x11202210, 0x11212011, 0x11212110, 0x11212111, 0x11212112, 0x11212211, 0x11222010,\n    0x11222111, 0x11222212, 0x12202012, 0x12202110, 0x12202212, 0x12212111, 0x12222011, 0x12222110,\n    0x12222111, 0x12222211, 0x10212021, 0x10212122, 0x10212220, 0x11202021, 0x11202120, 0x11202221,\n    0x11212020, 0x11212121, 0x11212220, 0x11212222, 0x11222120, 0x11222121, 0x11222221, 0x12202122,\n    0x12212120, 0x12212220, 0x12212222, 0x12222122, 0x20000000, 0x20000002, 0x20000200, 0x20000202,\n    0x20020000, 0x20020002, 0x20020200, 0x20020202, 0x21000101, 0x21010000, 0x21010001, 0x21010100,\n    0x21010102, 0x21010201, 0x21020101, 0x22000000, 0x22000002, 0x22000200, 0x22000202, 0x22010101,\n    0x22020000, 0x22020002, 0x22020200, 0x22020202, 0x20000111, 0x20010011, 0x20010110, 0x20010112,\n    0x20010211, 0x20020111, 0x21000011, 0x21000110, 0x21000211, 0x21010010, 0x21010012, 0x21010111,\n    0x21010112, 0x21010210, 0x21010211, 0x21020110, 0x21020112, 0x21020211, 0x22000111, 0x22000211,\n    0x22010110, 0x22010112, 0x22010211, 0x22020111, 0x20000020, 0x20000022, 0x20000220, 
0x20000222,\n    0x20010121, 0x20020020, 0x20020022, 0x20020220, 0x20020222, 0x21010021, 0x21010120, 0x21010221,\n    0x21020121, 0x22000020, 0x22000022, 0x22000220, 0x22000222, 0x22010121, 0x22020020, 0x22020022,\n    0x22020220, 0x22020222, 0x20011100, 0x20011201, 0x21001001, 0x21001100, 0x21011001, 0x21011101,\n    0x21011202, 0x21021001, 0x21021100, 0x21021201, 0x22011100, 0x22011201, 0x20001011, 0x20001211,\n    0x20011012, 0x20011111, 0x20011212, 0x20021112, 0x20021211, 0x21001010, 0x21001011, 0x21001111,\n    0x21001210, 0x21011011, 0x21011110, 0x21011111, 0x21011112, 0x21011211, 0x21011212, 0x21021111,\n    0x21021112, 0x21021210, 0x21021212, 0x22001011, 0x22001110, 0x22001112, 0x22001211, 0x22011010,\n    0x22011012, 0x22011111, 0x22011210, 0x22021112, 0x20011021, 0x20011122, 0x20011221, 0x20021121,\n    0x21001021, 0x21001120, 0x21001221, 0x21001222, 0x21011020, 0x21011121, 0x21011221, 0x21011222,\n    0x21021021, 0x21021122, 0x21021222, 0x22001121, 0x22011021, 0x22011222, 0x22021120, 0x20002000,\n    0x20002002, 0x20002200, 0x20002202, 0x20012101, 0x20022000, 0x20022002, 0x20022200, 0x20022202,\n    0x21002001, 0x21002101, 0x21012001, 0x21012100, 0x21012201, 0x21022101, 0x21022201, 0x22002000,\n    0x22002002, 0x22002200, 0x22002202, 0x22012101, 0x22022000, 0x22022002, 0x22022200, 0x22022202,\n    0x20002111, 0x20002112, 0x20012011, 0x20012110, 0x20012112, 0x20022111, 0x21002011, 0x21002110,\n    0x21002112, 0x21002211, 0x21012010, 0x21012012, 0x21012111, 0x21012212, 0x21022011, 0x21022110,\n    0x22002111, 0x22012112, 0x22012211, 0x22022111, 0x20002020, 0x20002022, 0x20002220, 0x20002222,\n    0x20012121, 0x20022020, 0x20022022, 0x20022220, 0x20022222, 0x21002121, 0x21012021, 0x21012120,\n    0x21012122, 0x22002020, 0x22002022, 0x22002220, 0x22002222, 0x22012121, 0x22022020, 0x22022022,\n    0x22022220, 0x22022222, 0x20100101, 0x20110001, 0x20110102, 0x20110200, 0x20110201, 0x20120101,\n    0x21100001, 0x21100102, 0x21100201, 0x21110101, 0x21110200, 
0x21110202, 0x21120201, 0x21120202,\n    0x22100101, 0x22110001, 0x22110100, 0x22110102, 0x22110201, 0x22120101, 0x20100011, 0x20100110,\n    0x20100112, 0x20100211, 0x20110010, 0x20110111, 0x20110210, 0x20110212, 0x20120011, 0x20120110,\n    0x20120112, 0x20120211, 0x21100010, 0x21100111, 0x21110010, 0x21110011, 0x21110110, 0x21110111,\n    0x21110112, 0x21110211, 0x21120012, 0x21120111, 0x22100110, 0x22100112, 0x22110012, 0x22110111,\n    0x22110210, 0x22120011, 0x22120110, 0x22120112, 0x22120211, 0x20100121, 0x20110021, 0x20110120,\n    0x20110221, 0x20120121, 0x21100120, 0x21100122, 0x21100221, 0x21110020, 0x21110022, 0x21110121,\n    0x21110220, 0x21120122, 0x21120221, 0x22100121, 0x22110120, 0x22110122, 0x22120221, 0x20101001,\n    0x20101100, 0x20101102, 0x20111000, 0x20111101, 0x20111200, 0x20121102, 0x21101000, 0x21101202,\n    0x21111001, 0x21111100, 0x21111101, 0x21111102, 0x21111200, 0x21111201, 0x21121000, 0x21121001,\n    0x21121002, 0x21121101, 0x22101100, 0x22101102, 0x22111002, 0x22111100, 0x22111101, 0x22111200,\n    0x22121001, 0x22121201, 0x20101010, 0x20101111, 0x20101210, 0x20101212, 0x20111010, 0x20111011,\n    0x20111110, 0x20111111, 0x20111112, 0x20111211, 0x20121011, 0x20121111, 0x20121211, 0x20121212,\n    0x21101011, 0x21101110, 0x21101111, 0x21101112, 0x21101211, 0x21111010, 0x21111011, 0x21111012,\n    0x21111110, 0x21111111, 0x21111112, 0x21111210, 0x21111211, 0x21111212, 0x21121011, 0x21121110,\n    0x21121111, 0x21121112, 0x21121211, 0x22101011, 0x22101111, 0x22101210, 0x22111011, 0x22111012,\n    0x22111110, 0x22111111, 0x22111112, 0x22111211, 0x22111212, 0x22121010, 0x22121012, 0x22121111,\n    0x22121210, 0x22121212, 0x20101021, 0x20101120, 0x20111020, 0x20111121, 0x20111221, 0x20121020,\n    0x20121122, 0x20121221, 0x21101121, 0x21101220, 0x21101221, 0x21111021, 0x21111022, 0x21111121,\n    0x21111122, 0x21111221, 0x21121121, 0x21121220, 0x22101022, 0x22101120, 0x22101221, 0x22101222,\n    0x22111022, 0x22111120, 0x22111121, 
0x22121120, 0x22121122, 0x22121221, 0x20102101, 0x20112102,\n    0x20112201, 0x20122101, 0x21102001, 0x21102102, 0x21112000, 0x21112002, 0x21112101, 0x21112102,\n    0x21112202, 0x21122100, 0x21122101, 0x22102101, 0x22112001, 0x22112102, 0x22112201, 0x22122101,\n    0x20102110, 0x20102112, 0x20102211, 0x20112010, 0x20112012, 0x20112111, 0x20112210, 0x20112212,\n    0x20122010, 0x20122011, 0x20122110, 0x20122112, 0x21102010, 0x21102012, 0x21102111, 0x21102210,\n    0x21102212, 0x21112011, 0x21112110, 0x21112111, 0x21112112, 0x21112211, 0x21122012, 0x21122111,\n    0x21122112, 0x21122212, 0x22102011, 0x22102110, 0x22112010, 0x22112012, 0x22112111, 0x22112212,\n    0x22122011, 0x22122112, 0x20102121, 0x20112121, 0x20122121, 0x21102120, 0x21102122, 0x21102221,\n    0x21112020, 0x21112121, 0x21112220, 0x21122021, 0x22102121, 0x22112021, 0x22112120, 0x22112121,\n    0x22112122, 0x20200000, 0x20200002, 0x20200200, 0x20200202, 0x20210101, 0x20220000, 0x20220002,\n    0x20220200, 0x20220202, 0x21200101, 0x21210001, 0x21210100, 0x21210102, 0x21210201, 0x22200000,\n    0x22200002, 0x22200200, 0x22200202, 0x22210101, 0x22220000, 0x22220002, 0x22220200, 0x22220202,\n    0x20200111, 0x20200211, 0x20210011, 0x20210110, 0x20210112, 0x20210211, 0x20210212, 0x21200112,\n    0x21200211, 0x21210011, 0x21210111, 0x21210210, 0x21210212, 0x21220011, 0x21220110, 0x22200111,\n    0x22210010, 0x22210012, 0x22210112, 0x22210211, 0x20200022, 0x20200220, 0x20200222, 0x20210020,\n    0x20210221, 0x20220022, 0x20220220, 0x20220222, 0x21200121, 0x21210021, 0x21210122, 0x21210221,\n    0x21220121, 0x22200020, 0x22200022, 0x22200220, 0x22200222, 0x22210121, 0x22220020, 0x22220022,\n    0x22220220, 0x22220222, 0x20211201, 0x20221101, 0x21201001, 0x21201100, 0x21211000, 0x21211100,\n    0x21211101, 0x21211200, 0x21211202, 0x21221001, 0x21221101, 0x21221102, 0x21221200, 0x21221201,\n    0x22201101, 0x20201112, 0x20201211, 0x20211010, 0x20211012, 0x20211111, 0x20211210, 0x20221112,\n    0x20221211, 
0x21201012, 0x21201111, 0x21211011, 0x21211110, 0x21211111, 0x21211112, 0x21211211,\n    0x21221111, 0x21221212, 0x22201011, 0x22201110, 0x22201111, 0x22201112, 0x22201211, 0x22211012,\n    0x22211111, 0x22211210, 0x20201121, 0x20211021, 0x20211122, 0x20211222, 0x20221021, 0x20221121,\n    0x21201120, 0x21201122, 0x21201222, 0x21211022, 0x21211121, 0x21211122, 0x21211220, 0x21221020,\n    0x21221022, 0x22201122, 0x22211020, 0x22211121, 0x22211122, 0x22211221, 0x22221021, 0x22221120,\n    0x22221122, 0x20202000, 0x20202002, 0x20202200, 0x20202202, 0x20222000, 0x20222002, 0x20222200,\n    0x20222202, 0x21212001, 0x21212100, 0x21212102, 0x21212201, 0x22202000, 0x22202002, 0x22202200,\n    0x22202202, 0x22212101, 0x22222000, 0x22222002, 0x22222200, 0x22222202, 0x20202111, 0x20212110,\n    0x20212211, 0x20222011, 0x20222111, 0x21202011, 0x21212010, 0x21212111, 0x21212212, 0x21222011,\n    0x21222112, 0x21222211, 0x22212010, 0x22212112, 0x20202020, 0x20202022, 0x20202220, 0x20202222,\n    0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020,\n    0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222,\n};\n#endif\n\n#ifndef HAVE_FANCY_SIMD\nconst uint64_t keven_signs[128] = {\n    0x0101010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0x010101010101ffff,\n    0xff01010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0xff01010101ffffff,\n    0xff010101ff010101, 0x01010101ff0101ff, 0x01010101ff01ff01, 0xff010101ff01ffff,\n    0x01010101ffff0101, 0xff010101ffff01ff, 0xff010101ffffff01, 0x01010101ffffffff,\n    0xff0101ff01010101, 0x010101ff010101ff, 0x010101ff0101ff01, 0xff0101ff0101ffff,\n    0x010101ff01ff0101, 0xff0101ff01ff01ff, 0xff0101ff01ffff01, 0x010101ff01ffffff,\n    0x010101ffff010101, 0xff0101ffff0101ff, 0xff0101ffff01ff01, 0x010101ffff01ffff,\n    0xff0101ffffff0101, 0x010101ffffff01ff, 0x010101ffffffff01, 0xff0101ffffffffff,\n    0xff01ff0101010101, 0x0101ff01010101ff, 
0x0101ff010101ff01, 0xff01ff010101ffff,\n    0x0101ff0101ff0101, 0xff01ff0101ff01ff, 0xff01ff0101ffff01, 0x0101ff0101ffffff,\n    0x0101ff01ff010101, 0xff01ff01ff0101ff, 0xff01ff01ff01ff01, 0x0101ff01ff01ffff,\n    0xff01ff01ffff0101, 0x0101ff01ffff01ff, 0x0101ff01ffffff01, 0xff01ff01ffffffff,\n    0x0101ffff01010101, 0xff01ffff010101ff, 0xff01ffff0101ff01, 0x0101ffff0101ffff,\n    0xff01ffff01ff0101, 0x0101ffff01ff01ff, 0x0101ffff01ffff01, 0xff01ffff01ffffff,\n    0xff01ffffff010101, 0x0101ffffff0101ff, 0x0101ffffff01ff01, 0xff01ffffff01ffff,\n    0x0101ffffffff0101, 0xff01ffffffff01ff, 0xff01ffffffffff01, 0x0101ffffffffffff,\n    0xffff010101010101, 0x01ff0101010101ff, 0x01ff01010101ff01, 0xffff01010101ffff,\n    0x01ff010101ff0101, 0xffff010101ff01ff, 0xffff010101ffff01, 0x01ff010101ffffff,\n    0x01ff0101ff010101, 0xffff0101ff0101ff, 0xffff0101ff01ff01, 0x01ff0101ff01ffff,\n    0xffff0101ffff0101, 0x01ff0101ffff01ff, 0x01ff0101ffffff01, 0xffff0101ffffffff,\n    0x01ff01ff01010101, 0xffff01ff010101ff, 0xffff01ff0101ff01, 0x01ff01ff0101ffff,\n    0xffff01ff01ff0101, 0x01ff01ff01ff01ff, 0x01ff01ff01ffff01, 0xffff01ff01ffffff,\n    0xffff01ffff010101, 0x01ff01ffff0101ff, 0x01ff01ffff01ff01, 0xffff01ffff01ffff,\n    0x01ff01ffffff0101, 0xffff01ffffff01ff, 0xffff01ffffffff01, 0x01ff01ffffffffff,\n    0x01ffff0101010101, 0xffffff01010101ff, 0xffffff010101ff01, 0x01ffff010101ffff,\n    0xffffff0101ff0101, 0x01ffff0101ff01ff, 0x01ffff0101ffff01, 0xffffff0101ffffff,\n    0xffffff01ff010101, 0x01ffff01ff0101ff, 0x01ffff01ff01ff01, 0xffffff01ff01ffff,\n    0x01ffff01ffff0101, 0xffffff01ffff01ff, 0xffffff01ffffff01, 0x01ffff01ffffffff,\n    0xffffffff01010101, 0x01ffffff010101ff, 0x01ffffff0101ff01, 0xffffffff0101ffff,\n    0x01ffffff01ff0101, 0xffffffff01ff01ff, 0xffffffff01ffff01, 0x01ffffff01ffffff,\n    0x01ffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0x01ffffffff01ffff,\n    0xffffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 
0xffffffffffffffff,\n};\n#endif\n\n}\n\n/* moonll change mulmat\nadd typeB and strideB\n}*/\n\nbool iqk_mul_mat(long Nx, long Ny, long ne00,\n    int typeA, const void * A, long strideA,\n    int typeB, const void * B, long strideB,\n    float * C, long stride_C, int ith, int nth) {\n\n        MulMat mm;\n#if defined __x86_64__ || defined(_M_X64)\n        if (!MulMat::set_mul_mat(typeA, typeB, (int)ne00, mm, Ny)) {\n            return false;\n        }\n#else\n        int row_size_q8;\n        if (!MulMat::set_mul_mat(typeA, (int)ne00, mm, row_size_q8, Ny)) {\n            return false;\n        }\n#endif\n\n\n        size_t row_size_qx = strideA*ggml_type_size(ggml_type(typeA));\n        size_t row_size_qy = strideB*ggml_type_size(ggml_type(typeB));\n      \n        \n        auto nrc_x = (Nx + nth - 1)/nth;\n        auto first_x = ith*nrc_x;\n        if (first_x + nrc_x > Nx) nrc_x = Nx - first_x;\n\n        DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, row_size_qy, 0, 1, nullptr, 0};\n#ifdef __ARM_NEON\n#ifdef GEMM_Q4K_Q6K\n        if (Ny >= 8 && (typeA == GGML_TYPE_Q4_K || typeA == GGML_TYPE_Q6_K)) {\n            mm.mul_mat_NxM_v2(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny);\n        } else\n#endif\n#endif\n        {\n            mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny);\n        }\n\n        return true;\n}\n\n\nbool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int typeA, const void * A, const void * B,\n        float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) {\n    const mmid_row_mapping * row_mapping = (const mmid_row_mapping *)vrow_mapping;\n    assert(row_mapping != nullptr);\n\n    MulMat mm;\n    int row_size_q8;\n    /* moonll\n\n    if (!MulMat::set_mul_mat(typeA, ne00, mm, row_size_q8, Ny)) {\n        return false;\n    }*/\n    int row_size_qx = ggml_row_size((ggml_type)typeA, ne00);\n    int nrc_x = (Nx + nth - 
1)/nth;\n    int first_x = ith*nrc_x;\n    if (first_x + nrc_x > Nx) nrc_x = Nx - first_x;\n    DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), (size_t)row_size_q8, 0, ne11, row_mapping, nb2/sizeof(float)};\n    mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny);\n    return true;\n}\n\n#if defined __x86_64__ || defined(_M_X64)\n\n#if defined HAVE_FANCY_SIMD\n    #undef HAVE_FANCY_SIMD\n#endif\n#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__)\n    #define HAVE_FANCY_SIMD\n#endif\n//#define HAVE_FANCY_SIMD\n\nnamespace {\n\ninline float hsum_float_4(__m128 x) {\n    x = _mm_add_ps(x, _mm_movehl_ps(x, x));\n    x = _mm_add_ss(x, _mm_movehdup_ps(x));\n    return _mm_cvtss_f32(x);\n}\ninline float hsum_float_8(__m256 x) {\n    return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1)));\n}\n\n#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)\n\n\ntemplate <int nrc, typename block_q8 = block_q8_K> struct Q8 {\n\n    constexpr static int nrc_y = nrc;\n\n    Q8(const DataInfo& info) {\n        for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy);\n    }\n\n#ifdef HAVE_FANCY_SIMD\n    inline __m512i load_quants64(int iy, int i, int j) const { return _mm512_loadu_si512((const __m512i*)y[iy][i].qs + j); }\n#endif\n    inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].qs + j); }\n    inline __m256i load_bsums(int iy, int i) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].bsums); }\n    inline float scale(int iy, int i) const { return y[iy][i].d; }\n\n    const block_q8 * y[nrc_y];\n};\n\n// Handles q4_K and q5_K scales/mins\nstruct Scales8K {\n    template <typename Q8>\n    inline __m256i process_mins_and_scales(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) {\n      
  make_q4_scales(data, utmp);\n        const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));\n        const __m128i mins128 = _mm256_extracti128_si256(mins_and_scales, 1);\n        accum_mins(mins128, q8, i, c, accd);\n        const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);\n        return MM256_SET_M128I(sc128, sc128);\n    }\n#ifdef HAVE_FANCY_SIMD\n    template <typename Q8>\n    inline __m512i process_mins_and_scales_64(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) {\n        auto scales = process_mins_and_scales(data, c, i, q8, accd);\n        return _mm512_inserti32x8(_mm512_castsi256_si512(scales), scales, 1);\n    }\n#endif\n    template <typename Q8>\n    inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const {\n        const __m256i mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, shuffles[1]), _mm_shuffle_epi8(mins128, shuffles[0]));\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            const __m256i q8s = q8.load_bsums(iy, i);\n            const __m256i prod = _mm256_madd_epi16(mins, q8s);\n            accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(c*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]);\n        }\n    }\n#ifdef HAVE_FANCY_SIMD\n    const __m512i shuffles512[2] = {\n        _mm512_set_epi64(0x0706070607060706, 0x0302030203020302, 0x0706070607060706, 0x0302030203020302,\n                         0x0504050405040504, 0x0100010001000100, 0x0504050405040504, 0x0100010001000100),\n        _mm512_set_epi64(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a,\n                         0x0d0c0d0c0d0c0d0c, 0x0908090809080908, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908)\n    };\n#endif\n    const __m128i shuffles[2] = {_mm_set_epi32(0x07060706, 0x05040504, 0x03020302, 0x01000100),\n                                 _mm_set_epi32(0x0f0e0f0e, 0x0d0c0d0c, 0x0b0a0b0a, 0x09080908)};\n\n    
uint32_t utmp[4];\n};\n\ntemplate <typename Q8>\ninline void process_mins_16(const __m256i& all_scales, const Q8& q8, int i, float d, __m256 * accm) {\n    for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n        const __m256i prod  = _mm256_madd_epi16(all_scales, q8.load_bsums(iy, i));\n        accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]);\n    }\n}\ninline void prepare_scales_16(const __m256i& all_scales, __m256i * scales) {\n    const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);\n    const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);\n    scales[0] = MM256_SET_M128I(l_scales, l_scales);\n    scales[1] = MM256_SET_M128I(h_scales, h_scales);\n}\n\nstruct ScaleQ3 {\n    inline __m128i make_scales(const uint16_t * s8) const {\n        const uint16_t * scales16 = (const uint16_t *)s8;\n        uint32_t aux0 = scales16[0] | (scales16[1] << 16);\n        uint32_t aux1 = scales16[2] | (scales16[3] << 16);\n        uint32_t aux2 = scales16[4] | (scales16[5] << 16);\n        __m128i scales128 = _mm_set_epi32(\n            ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030),\n            ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030),\n             (aux1       & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030),\n             (aux0       & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030));\n        return _mm_add_epi8(scales128, m32);\n    }\n    const __m128i m32 = _mm_set1_epi8(-32);\n};\n\nstruct ScaleIQ4XS {\n    inline __m128i make_scales(const uint32_t scales_l, const uint16_t scales_h) {\n        uint32_t tmp32 = scales_h | (scales_h << 14);\n        const __m128i sh = _mm_slli_epi16(_mm_and_si128(_mm_srlv_epi32(_mm_set1_epi32(tmp32), hshift), hmask), 4);\n        const __m128i sl = _mm_and_si128(_mm_srlv_epi32(_mm_set1_epi32(scales_l), lshift), lmask);\n        return _mm_add_epi16(_mm_or_si128(sh, _mm_cvtepi8_epi16(_mm_shuffle_epi8(sl, lshuffle))), m32);\n    }\n    const __m128i hshift = 
_mm_set_epi32(12, 8, 4, 0);\n    const __m128i lshift = _mm_set_epi32(4, 0, 4, 0);\n    const __m128i hmask  = _mm_set1_epi16(0x03);\n    const __m128i lmask  = _mm_set1_epi8(0xf);\n    const __m128i lshuffle = _mm_set_epi32(0x07030602, 0x05010400, 0x07030602, 0x05010400);\n    const __m128i m32 = _mm_set1_epi16(-32);\n};\n\nstruct Scales8KBase {\n    template <typename Q8>\n    inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const {\n        const __m256i mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, shuffles[1]), _mm_shuffle_epi8(mins128, shuffles[0]));\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            const __m256i q8s = q8.load_bsums(iy, i);\n            const __m256i prod = _mm256_madd_epi16(mins, q8s);\n            accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(c*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]);\n        }\n    }\n    inline __m256i shuffle(__m128i mins) const {\n        return MM256_SET_M128I(_mm_shuffle_epi8(mins, shuffles[1]), _mm_shuffle_epi8(mins, shuffles[0]));\n    }\n    const __m128i shuffles[2] = {_mm_set_epi32(0x07060706, 0x05040504, 0x03020302, 0x01000100),\n                                 _mm_set_epi32(0x0f0e0f0e, 0x0d0c0d0c, 0x0b0a0b0a, 0x09080908)};\n};\n\ntemplate <typename Block>\nstruct BaseDequantizer {\n    BaseDequantizer(const void * vx, size_t bx) : vx(vx), bx(bx) {}\n    inline void new_row(int ix) {\n        x = (const Block *)((const char *)vx + bx*ix);\n    }\n\n    const void *  vx;\n    size_t        bx;\n    const Block * x;\n\n    float d;\n};\n\n__m128i inline load_iq4nl_values_128() {\n    static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241};\n    return _mm_loadu_si128((const __m128i *)kvalues_iq4nl);\n}\n\n__m256i inline load_iq4nl_values_256() {\n    auto val128 = load_iq4nl_values_128();\n    return MM256_SET_M128I(val128, val128);\n}\n\n#ifdef 
HAVE_FANCY_SIMD\n//====================================== Zen4 ==================================================\n\nstruct BlockPermuter {\n    const __m512i permute1 = _mm512_set_epi64(11, 10,  9,  8, 3, 2, 1, 0);\n    const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4);\n};\n\nstruct Q4Bits {\n    inline void prepare(const uint8_t * q4) {\n        auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);\n        auto tmp1 = _mm512_and_si512(q4bits, ml);\n        auto tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);\n        values[0] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);\n        values[1] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);\n        q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);\n        tmp1 = _mm512_and_si512(q4bits, ml);\n        tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);\n        values[2] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);\n        values[3] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);\n    }\n    inline void prepare64(const uint8_t * q4) {\n        auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);\n        values[0] = _mm512_and_si512(q4bits, ml);\n        values[1] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);\n        q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);\n        values[2] = _mm512_and_si512(q4bits, ml);\n        values[3] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);\n    }\n    __m512i values[4];\n    const __m512i ml = _mm512_set1_epi8(0xf);\n    BlockPermuter perm;\n};\n\nstruct Q2Bits {\n    inline void prepare(const uint8_t * q2) {\n\n        auto q2bits = _mm512_loadu_si512((const __m512i*)q2);\n        auto tmp = _mm512_srli_epi16(q2bits, 2);\n\n        values[0] = _mm512_permutex2var_epi64(q2bits, perm.permute1, tmp);\n        values[2] = _mm512_permutex2var_epi64(q2bits, perm.permute2, tmp);\n        values[1] = _mm512_and_si512(_mm512_srli_epi16(values[0], 4), ml);\n        values[3] 
= _mm512_and_si512(_mm512_srli_epi16(values[2], 4), ml);\n        values[0] = _mm512_and_si512(values[0], ml);\n        values[2] = _mm512_and_si512(values[2], ml);\n    }\n    __m512i values[4];\n    const __m512i ml = _mm512_set1_epi8(0x03);\n    BlockPermuter perm;\n};\n\nstruct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {\n    DequantizerQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        bits.prepare(x[i].qs);\n        auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);\n        scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);\n        scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);\n    }\n\n    Q4Bits bits;\n    Scales8K s8k;\n};\n\n/*\nmoonll DequantizerIQ4XS\n*/\n\n__m512i inline load_iq4nl_values_512() {\n    auto val256 = load_iq4nl_values_256();\n    return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);\n}\n\nstruct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {\n    DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {}\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        prepare(x[i].qs);\n        auto scales128 = siq4.make_scales(*(const uint32_t *)x[i].scales_l, x[i].scales_h);\n        s8k.accum_mins(scales128, q8, i, -128.f*d, accd);\n        auto scales256 = MM256_SET_M128I(scales128, scales128);\n        auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);\n        scales[0] = _mm512_shuffle_epi8(all_scales, shuffles[0]);\n        scales[1] = _mm512_shuffle_epi8(all_scales, shuffles[1]);\n        scales[2] = _mm512_shuffle_epi8(all_scales, shuffles[2]);\n   
     scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]);\n    }\n    inline void prepare(const uint8_t * q4) {\n        bits.prepare64(q4);\n        // We now have in bits.values[0]: 0...15, 32...47, 64...79, 96...111\n        //                bits.values[1]: 16..31, 48...63, 80...95, 112..127\n        //                etc.\n        auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]);\n        bits.values[1] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]));\n        bits.values[0] = _mm512_shuffle_epi8(values, tmp);\n        tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]);\n        bits.values[3] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]));\n        bits.values[2] = _mm512_shuffle_epi8(values, tmp);\n    }\n\n    Q4Bits bits;\n    Scales8KBase s8k;\n    ScaleIQ4XS siq4;\n    const __m512i values;\n    const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2,  9,  8, 1, 0);\n    const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4);\n    const __m512i shuffles[4] = {\n        _mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1),\n        _mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1),\n        _mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1),\n        _mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1),\n    };\n};\n\nstruct HighBit5 {\n    inline void apply(const uint8_t * h, Q4Bits& bits) {\n        auto hbits256 = _mm256_loadu_si256((const __m256i *)h);\n        auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbits256), _mm256_srli_epi16(hbits256, 1), 1);\n        bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_and_si512(_mm512_slli_epi16(hbits, 4), mh));\n        bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh));\n       
 bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_and_si512(hbits, mh));\n        bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh));\n    }\n    const __m512i mh = _mm512_set1_epi8(0x10);\n};\n\nstruct HighBit3 {\n    inline void apply(const uint8_t * h, Q2Bits& bits) {\n        auto hbits256 = _mm256_loadu_si256((const __m256i *)h);\n        auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbits256), _mm256_srli_epi16(hbits256, 1), 1);\n        bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh));\n        bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_and_si512(hbits, mh));\n        bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh));\n        bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_and_si512(_mm512_srli_epi16(hbits, 4), mh));\n    }\n    const __m512i mh = _mm512_set1_epi8(0x04);\n};\n\nstruct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {\n    DequantizerQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        bits.prepare(x[i].qs);\n        hbits.apply(x[i].qh, bits);\n        auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);\n        scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);\n        scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);\n    }\n\n    Q4Bits bits;\n    HighBit5 hbits;\n    Scales8K s8k;\n};\n\nstruct Scale16 {\n    inline void make_scales(const __m128i& scales8, __m512i * scales) const {\n        auto all_scales8 = MM256_SET_M128I(scales8, scales8);\n        auto scales1 = _mm256_shuffle_epi8(all_scales8, shuffle1);\n        auto scales2 = _mm256_shuffle_epi8(all_scales8, shuffle2);\n        scales[0] 
= _mm512_cvtepi8_epi16(scales1);\n        scales[1] = _mm512_cvtepi8_epi16(scales2);\n    }\n    template <typename Q8>\n    inline void process_mins_and_scales(int i, float c, const __m128i& mins8, const __m128i& scales8,\n        const Q8& q8, __m256 * accm, __m512i * scales) const {\n        process_mins_16(_mm256_cvtepi8_epi16(mins8), q8, i, c, accm);\n        make_scales(scales8, scales);\n    }\n    const __m256i shuffle1 = _mm256_set_epi32(0x07070707, 0x03030303, 0x06060606, 0x02020202,\n                                              0x05050505, 0x01010101, 0x04040404, 0x00000000);\n    const __m256i shuffle2 = _mm256_set_epi32(0x0f0f0f0f, 0x0b0b0b0b, 0x0e0e0e0e, 0x0a0a0a0a,\n                                              0x0d0d0d0d, 0x09090909, 0x0c0c0c0c, 0x08080808);\n};\n\nstruct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {\n    DequantizerQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        bits.prepare(x[i].qs);\n        const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);\n        const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);\n        const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);\n        sc16.process_mins_and_scales(i, -GGML_FP16_TO_FP32(x[i].dmin), mins8, scales8, q8, accm, scales);\n    }\n\n    Q2Bits bits;\n    Scale16 sc16;\n    const __m128i m4 = _mm_set1_epi8(0xf);\n\n};\n\nstruct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {\n    DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        bits.prepare(x[i].qs);\n        hbits.apply(x[i].hmask, bits);\n        auto scales128 = sc3.make_scales((const uint16_t 
*)x[i].scales);\n        sc16.process_mins_and_scales(i, -4.f*d, scales128, scales128, q8, accm, scales);\n    }\n\n    Q2Bits bits;\n    HighBit3 hbits;\n    ScaleQ3 sc3;\n    Scale16 sc16;\n    const __m128i m4  = _mm_set1_epi8(0xf);\n    const __m128i m32 = _mm_set1_epi8(-32);\n};\n\nstruct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {\n    DequantizerQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        bits.prepare64(x[i].ql);\n        add_high_bits(x[i].qh, bits);\n        auto scales128 = _mm_loadu_si128((const __m128i *)x[i].scales);\n        sc16.process_mins_and_scales(i, -32.f*d, scales128, scales128, q8, accm, scales);\n    }\n\n    inline void add_high_bits(const uint8_t * qh, Q4Bits& bits) const {\n        auto hbits = _mm512_loadu_si512((const __m512i *)qh);\n        auto tmp1 = _mm512_and_si512(_mm512_slli_epi16(hbits, 4), mh);\n        auto tmp2 = _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh);\n        bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_permutex2var_epi64(tmp1, bits.perm.permute1, tmp2));\n        bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_permutex2var_epi64(tmp1, bits.perm.permute2, tmp2));\n        tmp1 = _mm512_and_si512(hbits, mh);\n        tmp2 = _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh);\n        bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_permutex2var_epi64(tmp1, bits.perm.permute1, tmp2));\n        bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_permutex2var_epi64(tmp1, bits.perm.permute2, tmp2));\n    }\n\n    Q4Bits bits;\n    HighBit3 hbits;\n    Scale16 sc16;\n\n    const __m512i mh = _mm512_set1_epi8(0x30);\n\n};\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n    
const int nb = n / QK_K;\n\n    Q8<nrc_y> q8(info);\n\n    Dequantizer deq(vx, bx);\n\n    __m256  accm[nrc_y];\n    __m512  accd[nrc_y];\n    __m512i scales[2];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps();\n        for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps();\n\n        deq.new_row(ix);\n\n        for (int i = 0; i < nb; ++i) {\n\n            deq.new_block(i, q8, accm, scales);\n\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants(iy, i, 0));\n                const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants(iy, i, 1));\n                const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants(iy, i, 2));\n                const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants(iy, i, 3));\n                auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));\n                sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));\n                accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);\n            }\n\n        }\n\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));\n            info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));\n        }\n\n    }\n}\ntemplate <typename Q8>\ninline void compute_block(int iy, int i, float d, const Q8& q8, const __m512i * values, const __m512i * scales, __m512 * accd) {\n    const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[0], q8.load_quants64(iy, i, 0));\n    const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[1], 
q8.load_quants64(iy, i, 1));\n    const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[2], q8.load_quants64(iy, i, 2));\n    const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[3], q8.load_quants64(iy, i, 3));\n    auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));\n    sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));\n    accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);\n}\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n    const int nb = n / QK_K;\n\n    Q8<nrc_y> q8(info);\n\n    Dequantizer deq(vx, bx);\n\n    __m256  accm[nrc_y];\n    __m512  accd[nrc_y];\n    __m512i scales[2];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps();\n        for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps();\n\n        deq.new_row(ix);\n\n        for (int i = 0; i < nb; ++i) {\n\n            deq.new_block(i, q8, accm, scales);\n\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants64(iy, i, 0));\n                const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants64(iy, i, 1));\n                const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants64(iy, i, 2));\n                const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants64(iy, i, 3));\n                auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));\n                sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));\n                accd[iy] = 
_mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);\n            }\n\n        }\n\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));\n            info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));\n        }\n\n    }\n}\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n    const int nb = n / QK_K;\n\n    Q8<nrc_y> q8(info);\n\n    Dequantizer deq(vx, bx);\n\n    __m256  accm[nrc_y];\n    __m512  accd[nrc_y];\n    __m512i scales[4];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps();\n        for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps();\n\n        deq.new_row(ix);\n\n        for (int i = 0; i < nb; ++i) {\n\n            deq.new_block(i, q8, accm, scales);\n\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                const __m512i p1 = _mm512_maddubs_epi16(deq.bits.values[0], q8.load_quants64(iy, i, 0));\n                const __m512i p2 = _mm512_maddubs_epi16(deq.bits.values[1], q8.load_quants64(iy, i, 1));\n                const __m512i p3 = _mm512_maddubs_epi16(deq.bits.values[2], q8.load_quants64(iy, i, 2));\n                const __m512i p4 = _mm512_maddubs_epi16(deq.bits.values[3], q8.load_quants64(iy, i, 3));\n                auto sumi = _mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_setzero_si512(),\n                                    p1, scales[0]), p2, scales[1]), p3, scales[2]), p4, scales[3]);\n                accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);\n            }\n\n        }\n\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            auto sum256 = 
_mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));\n            info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));\n        }\n\n    }\n}\n\ntemplate <typename Dequantizer>\nstatic void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n    const int nb = n / QK_K;\n\n    constexpr int k_nx = 2;\n\n    Q8<1> q8(info);\n\n    Dequantizer deq1(vx, bx);\n    Dequantizer deq2(vx, bx);\n\n    Dequantizer * deq[k_nx];\n    deq[0] = &deq1;\n    deq[1] = &deq2;\n\n    __m512i scales[2*k_nx];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        auto accd = _mm512_setzero_ps();\n        auto accm = _mm256_setzero_ps();\n\n        for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_row(ix);\n\n        for (int i = 0; i < nb/k_nx; ++i) {\n\n            for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_block(k_nx*i+kx, q8, &accm, scales+2*kx);\n\n            for (int kx = 0; kx < k_nx; ++kx) {\n                compute_block(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, scales+2*kx, &accd);\n            }\n\n        }\n        if (2*(nb/2) < nb) {\n            int i0 = 2*(nb/2);\n            deq[0]->new_block(i0, q8, &accm, scales);\n            compute_block(0, i0, deq[0]->d, q8, deq[0]->bits.values, scales, &accd);\n        }\n\n        auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1));\n        info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256)));\n    }\n}\n\n#else\n// ===================================== Vanilla AVX2 =====================================\n\nstruct Q4Bits {\n    inline void prepare(const uint8_t * q4, int j) {\n        auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);\n        values[0] = _mm256_and_si256(q4bits, ml);\n        values[1] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);\n        q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);\n        values[2] 
= _mm256_and_si256(q4bits, ml);\n        values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);\n    }\n    inline void prepare64(const uint8_t * q4, int j) {\n        auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);\n        values[0] = _mm256_and_si256(q4bits, ml);\n        values[2] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);\n        q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);\n        values[1] = _mm256_and_si256(q4bits, ml);\n        values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);\n    }\n    inline void prepare16(const uint8_t * q4, int j) {\n        values[0] = dequant16(q4 + 64*j +  0);\n        values[1] = dequant16(q4 + 64*j + 16);\n        values[2] = dequant16(q4 + 64*j + 32);\n        values[3] = dequant16(q4 + 64*j + 48);\n    }\n    inline __m256i dequant16(const uint8_t * qs) const {\n        const __m128i aux128 = _mm_loadu_si128((const __m128i *)qs);\n        const __m256i aux256 = MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128);\n        return _mm256_and_si256(ml, aux256);\n    };\n    __m256i values[4];\n    const __m256i ml = _mm256_set1_epi8(0xf);\n};\n\nstruct Q2Bits {\n    inline void prepare(const uint8_t * q2, int j) {\n        auto q2bits = _mm256_loadu_si256((const __m256i *)q2 + j);\n        values[0] = _mm256_and_si256(q2bits, ml);\n        values[1] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), ml);\n        values[2] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), ml);\n        values[3] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), ml);\n    }\n    __m256i values[4];\n    const __m256i ml = _mm256_set1_epi8(0x03);\n};\n\nstruct HighBit5 {\n    inline void load(const uint8_t * h) { hbits = _mm256_loadu_si256((const __m256i *)h); }\n    inline void apply(Q4Bits& bits, bool do_shift) {\n        bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh));\n        bits.values[1] = _mm256_or_si256(bits.values[1], 
_mm256_and_si256(_mm256_slli_epi16(hbits, 3), mh));\n        bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));\n        bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh));\n        if (do_shift) {\n            hbits = _mm256_srli_epi16(hbits, 4);\n        }\n    }\n    const __m256i mh = _mm256_set1_epi8(0x10);\n    __m256i hbits;\n};\n\nstruct HighBit3 {\n    inline void load(const uint8_t * h) { hbits = _mm256_loadu_si256((const __m256i *)h); }\n    inline void apply(Q2Bits& bits, bool do_shift) {\n        bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));\n        bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh));\n        bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(hbits, mh));\n        bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(hbits, 1), mh));\n        if (do_shift) {\n            hbits = _mm256_srli_epi16(hbits, 4);\n        }\n    }\n    const __m256i mh = _mm256_set1_epi8(0x04);\n    __m256i hbits;\n};\n\n\n/*\ntemplate <typename Q8, typename Bits>\ninline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {\n    if (j == 0) {\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0)));\n            const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1)));\n            const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2)));\n            const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3)));\n            sumi[iy] = _mm256_add_epi32(_mm256_add_epi32(p1, p3), 
_mm256_add_epi32(p2, p4));\n        }\n    } else {\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4)));\n            const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5)));\n            const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6)));\n            const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7)));\n            sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p1, p3));\n            sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p2, p4));\n        }\n    }\n}*/\n\nstruct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {\n    DequantizerQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        return s8k.process_mins_and_scales(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);\n    }\n    inline void prepare(int i, int j) {\n        bits.prepare(x[i].qs, j);\n    }\n\n    Q4Bits bits;\n    Scales8K s8k;\n};\n\nstruct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {\n    DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_values()) {}\n    template <typename Q8>\n    inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        auto scales128 = siq4.make_scales(*(const uint32_t *)x[i].scales_l, x[i].scales_h);\n        s8k.accum_mins(scales128, q8, i, -128.f*d, accd);\n        return MM256_SET_M128I(scales128, scales128);\n    }\n    inline void prepare(int i, int j) {\n        bits.prepare16(x[i].qs, j);\n        bits.values[0] = _mm256_shuffle_epi8(values, bits.values[0]);\n        
bits.values[1] = _mm256_shuffle_epi8(values, bits.values[1]);\n        bits.values[2] = _mm256_shuffle_epi8(values, bits.values[2]);\n        bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]);\n    }\n\n    static __m256i load_values() {\n        static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241};\n        auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq4nl);\n        return MM256_SET_M128I(val128, val128);\n    }\n\n    Q4Bits bits;\n    Scales8K s8k;\n    ScaleIQ4XS siq4;\n    const __m256i values;\n};\n\nstruct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {\n    DequantizerQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        hbits.load(x[i].qh);\n        return s8k.process_mins_and_scales(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);\n    }\n    inline void prepare(int i, int j) {\n        bits.prepare(x[i].qs, j);\n        hbits.apply(bits, j == 0);\n    }\n\n    Q4Bits  bits;\n    HighBit5 hbits;\n    Scales8K s8k;\n};\n\ntemplate <typename Q8>\ninline void process_mins_and_scales_16(const __m128i& scales128, const Q8& q8, int i, float d,\n    __m256 * accm, __m256i * scales) {\n    const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);\n    process_mins_16(all_scales, q8, i, d, accm);\n    prepare_scales_16(all_scales, scales);\n}\n\nstruct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {\n    DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        hbits.load(x[i].hmask);\n        process_mins_and_scales_16(sc3.make_scales((const uint16_t *)x[i].scales), q8, i, -4.f*d, accm, scales);\n    }\n    inline void 
prepare(int i, int j) {\n        bits.prepare(x[i].qs, j);\n        hbits.apply(bits, j == 0);\n    }\n\n    Q2Bits  bits;\n    HighBit3 hbits;\n    ScaleQ3 sc3;\n\n    const __m128i m32 = _mm_set1_epi8(-32);\n};\n\nstruct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {\n    DequantizerQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);\n        const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);\n        const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);\n        process_mins_16(_mm256_cvtepi8_epi16(mins8), q8, i, -GGML_FP16_TO_FP32(x[i].dmin), accm);\n        prepare_scales_16(_mm256_cvtepi8_epi16(scales8), scales);\n    }\n    inline void prepare(int i, int j) {\n        bits.prepare(x[i].qs, j);\n    }\n\n    Q2Bits  bits;\n\n    const __m128i m4 = _mm_set1_epi8(0xf);\n};\n\nstruct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {\n    DequantizerQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        process_mins_and_scales_16(_mm_loadu_si128((const __m128i *)x[i].scales), q8, i, -32.f*d, accm, scales);\n    }\n    inline void prepare(int i, int j) {\n        bits.prepare64(x[i].ql, j);\n        auto hbits = _mm256_loadu_si256((const __m256i *)x[i].qh + j);\n        bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh));\n        bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));\n        bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(hbits, mh));\n        bits.values[3] = 
_mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(hbits, 2), mh));\n    }\n\n    Q4Bits  bits;\n    const __m256i mh = _mm256_set1_epi8(0x30);\n};\n\ninline __m256i get_scale_shuffle_8(int i);\n\ninline void set_scales_8(const __m256i& all_scales, int j, __m256i* scales);\n\ninline __m256i get_scale_shuffle_16(int i);\n\ninline void set_scales_16(const __m256i& all_scales, __m256i* scales);\n\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n%QK_K == 0);\n    const int nb = n/QK_K;\n\n    Q8<nrc_y> q8(info);\n\n    __m256i all_scales[2];\n    __m256i scales[4];\n    __m256  accd[nrc_y];\n\n    Dequantizer deq(vx, bx);\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        deq.new_row(ix);\n\n        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();\n\n        for (int i = 0; i < nb; ++i) {\n\n            deq.new_block(i, q8, accd, all_scales);\n\n            __m256i sumi[nrc_y];\n\n            for (int j = 0; j < QK_K/128; ++j) {\n                deq.prepare(i, j);\n                set_scales_16(all_scales[j], scales);\n                multiply_add(deq.bits, scales, j, i, q8, sumi);\n            }\n\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);\n            }\n\n        }\n\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            info.store(ix, iy, hsum_float_8(accd[iy]));\n        }\n\n    }\n\n}\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n    const int nb = n / QK_K;\n\n    Q8<nrc_y> q8(info);\n\n    Dequantizer deq(vx, bx);\n\n    __m256  accd[nrc_y];\n    __m256i scales[4];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        for (int iy = 0; iy < nrc_y; 
++iy) accd[iy] = _mm256_setzero_ps();\n\n        deq.new_row(ix);\n\n        for (int i = 0; i < nb; ++i) {\n\n            auto all_scales = deq.new_block(i, q8, accd);\n\n            __m256i sumi[nrc_y];\n\n            for (int j = 0; j < QK_K/128; ++j) {\n\n                deq.prepare(i, j);\n\n                set_scales_8(all_scales, j, scales);\n\n                multiply_add(deq.bits, scales, j, i, q8, sumi);\n\n            }\n\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i));\n                accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);\n            }\n\n        }\n\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            info.store(ix, iy, hsum_float_8(accd[iy]));\n        }\n\n    }\n}\n#endif  // Zen4 or vanilla AVX2\n\n\n\n//\n// ============================== Legacy quants\n//\n\nstruct DotHelper {\n    const __m256i m1 = _mm256_set1_epi16(1);\n#if defined(__AVX512VNNI__) && defined(__AVX512VL__)\n    inline __m256i dot(__m256i x, __m256i y) const {\n        return _mm256_dpbusd_epi32(_mm256_setzero_si256(), x, y);\n    }\n#else\n    inline __m256i dot(__m256i x, __m256i y) const {\n        return _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x, y));\n    }\n#endif\n};\n\nstruct SignedDot {\n    DotHelper helper;\n    inline __m256i compute(__m256i x, __m256i y) const {\n        return helper.dot(_mm256_sign_epi8(x, x), _mm256_sign_epi8(y, x));\n    }\n};\nstruct UnsignedDot {\n    DotHelper helper;\n    inline __m256i compute(__m256i x, __m256i y) const {\n        return helper.dot(x, y);\n    }\n};\ntemplate <typename Q8, typename Dot> struct Sum4 {\n    Dot dot;\n    inline __m256i compute(const __m256i * qx, const Q8 * y) const {\n        const __m256i p0 = dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[0].qs));\n        const __m256i p1 = dot.compute(qx[1], _mm256_loadu_si256((const __m256i *)y[1].qs));\n        const __m256i p2 = 
dot.compute(qx[2], _mm256_loadu_si256((const __m256i *)y[2].qs));\n        const __m256i p3 = dot.compute(qx[3], _mm256_loadu_si256((const __m256i *)y[3].qs));\n        const __m256i p01 = _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p0, p1));    // 0,0, 1,1, 0,0, 1,1\n        const __m256i p23 = _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p2, p3));    // 2,2, 3,3, 2,2, 3,3\n        return _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p01, p23)); // 0,1,2,3, 0,1,2,3\n    }\n};\n\nstruct Sum4_Q8 {\n    SignedDot dot;\n    static inline __m256i add1(__m256i a, __m256i b) {\n        return _mm256_add_epi32(_mm256_unpacklo_epi32(a, b), _mm256_unpackhi_epi32(a, b));\n    }\n    static inline __m256i add2(__m256i a, __m256i b) {\n        return _mm256_add_epi32(_mm256_unpacklo_epi64(a, b), _mm256_unpackhi_epi64(a, b));\n    }\n    inline __m256i compute(const __m256i * qx, const block_q8_0 * y) const {\n        const __m256i p0 = dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[0].qs));\n        const __m256i p1 = dot.compute(qx[1], _mm256_loadu_si256((const __m256i *)y[1].qs));\n        const __m256i p2 = dot.compute(qx[2], _mm256_loadu_si256((const __m256i *)y[2].qs));\n        const __m256i p3 = dot.compute(qx[3], _mm256_loadu_si256((const __m256i *)y[3].qs));\n        const __m256i p01 = add1(p0, p1);  // 0,1, 0,1, 0,1, 0,1\n        const __m256i p23 = add1(p2, p3);  // 2,3, 2,3, 2,3, 2,3\n        return add2(p01, p23); // returns 0,1,2,3, 0,1,2,3\n    }\n};\n\nstruct ScaleHelperQ_0 {\n    ggml_half scales8[4];\n    template <typename Q>\n    inline __m128 prepare4(const Q * y) {\n        for (int j = 0; j < 4; ++j) scales8[j] = y[j].d;\n        return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)scales8));\n    }\n    template <typename Q>\n    inline __m128 prepare4(__m128 other_scales, const Q * y) {\n        return _mm_mul_ps(other_scales, prepare4<Q>(y));\n    }\n    template <typename Q> inline float prepare1(const Q * y) const { 
return GGML_FP16_TO_FP32(y->d); }\n    template <typename Q> inline float prepare1(float d, const Q * y) const { return d*prepare1(y); }\n};\ntemplate <int min_value>\nstruct ScaleHelperQ_0_1 {\n    ggml_half scales8[4];\n    template <typename Q>\n    inline __m256 prepare4(const Q * y) {\n        for (int j = 0; j < 4; ++j) scales8[j] = y[j].d;\n        auto s4 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)scales8));\n        return _mm256_set_m128(_mm_mul_ps(s4, min), s4);\n    }\n    template <typename Q>\n    inline __m256 prepare4(__m256 other_scales, const Q * y) {\n        return _mm256_mul_ps(other_scales, prepare4<Q>(y));\n    }\n    template <typename Q> inline std::pair<float, float> prepare1(const Q * y) const {\n        float d = GGML_FP16_TO_FP32(y->d);\n        return std::make_pair(d, -d*float(min_value));\n    }\n    std::pair<float, float> inline prepare1(const std::pair<float, float>& dm, const block_q8_1 * y) const {\n        return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s));\n    }\n    const __m128 min = _mm_set1_ps(float(-min_value));\n};\n\nstruct ScaleHelperQ_1 {\n    uint32_t scales8[4];\n    const __m128i shuffle = _mm_set_epi16(0x0f0e, 0x0b0a, 0x0706, 0x0302, 0x0d0c, 0x0908, 0x0504, 0x0100);\n\n    template <typename Q>\n    inline __m256 prepare4(const Q * y) {\n        for (int j = 0; j < 4; ++j) {\n            // it is slightly faster to directly dereference (const uint32 *)&y[j].d, but some compilers\n            // complain that this breaks strict-aliasing rules.\n            memcpy(scales8 + j, &y[j].d, sizeof(uint32_t));\n        }\n        return _mm256_cvtph_ps(_mm_shuffle_epi8(_mm_loadu_si128((const __m128i *)scales8), shuffle));\n    }\n\n    template <typename Q>\n    inline __m256 prepare4(__m256 other_scales, const Q * y) {\n        return _mm256_mul_ps(other_scales, prepare4<Q>(y));\n    }\n\n    template <typename Q> inline std::pair<float, float> prepare1(const Q * y) const {\n 
       return std::make_pair(GGML_FP16_TO_FP32(y->d), GGML_FP16_TO_FP32(y->m));\n    }\n    template <typename Q> inline std::pair<float, float> prepare1(const std::pair<float, float>& dm, const Q * y) const {\n        return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->m));\n    }\n    std::pair<float, float> inline prepare1(const std::pair<float, float>& dm, const block_q8_1 * y) const {\n        return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s));\n    }\n};\n\nstruct MinusType0 {\n    inline __m256 compute(__m128 d, int) const { return _mm256_set_m128(d, d); }\n    inline float compute(float d, int) const { return d; }\n    inline float result(__m256 acc, int) const { return hsum_float_8(acc); }\n};\n\ntemplate <int nrc_y> struct MinusType1 {\n    __m128 accm[nrc_y];\n    MinusType1() { for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm_setzero_ps(); }\n    inline __m256 compute(__m256 dm, int iy) {\n        const __m128 d = _mm256_castps256_ps128(dm);\n        const __m128 m = _mm256_extractf128_ps(dm, 1);\n        accm[iy] = _mm_add_ps(accm[iy], m);\n        return _mm256_set_m128(d, d);\n    }\n    inline float compute(const std::pair<float, float>& dm, int iy) {\n        accm[iy] = _mm_add_ps(accm[iy], _mm_set1_ps(dm.second*0.25f));\n        return dm.first;\n    }\n    inline float result(__m256 acc, int iy) const {\n        const __m128 sum = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));\n        return hsum_float_4(_mm_add_ps(sum, accm[iy]));\n    }\n};\n\ntemplate <typename Minus, int nrc_y, bool is_multiple_of_4> struct AccumT {\n    __m256 acc[nrc_y];\n    Minus accm;\n    AccumT() {  for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = _mm256_setzero_ps(); }\n    template <typename Unpacker, typename Scales, typename Sum, typename Q8>\n    inline void compute(int nb, Unpacker& unp, Scales& scales, Sum& sum, const Q8 ** y, const DataInfo& info, int ix) {\n       
 auto qx = unp.quants();\n        __m256 dall[nrc_y];\n        for (int i = 0; i < nb/4; ++i) {\n            auto other_scales = unp.set_block_4(i);\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                auto s12 = scales.prepare4(other_scales, y[iy] + 4*i);\n                dall[iy] = accm.compute(s12, iy);\n            }\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                auto pall = sum.compute(qx, y[iy] + 4*i);\n                acc[iy] = _mm256_fmadd_ps(dall[iy], _mm256_cvtepi32_ps(pall), acc[iy]);\n            }\n        }\n        if (!is_multiple_of_4) {\n            for (int i = 4*(nb/4); i < nb; ++i) {\n                auto other_scales = unp.set_block(i);\n                for (int iy = 0; iy < nrc_y; ++iy) {\n                    auto s12 = scales.prepare1(other_scales, y[iy] + i);\n                    auto d = accm.compute(s12, iy);\n                    const __m256i p0 = sum.dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs));\n                    acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(p0), acc[iy]);\n                }\n            }\n        }\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            info.store(ix, iy, accm.result(acc[iy], iy));\n            //s[iy*bs] = accm.result(acc[iy], iy);\n        }\n    }\n};\n\ntemplate <int nrc_y, bool is_multiple_of_4>\nusing AccumType0 = AccumT<MinusType0, nrc_y, is_multiple_of_4>;\n\ntemplate <int nrc_y, bool is_multiple_of_4>\nusing AccumType1 = AccumT<MinusType1<nrc_y>, nrc_y, is_multiple_of_4>;\n\nusing Sum4Type0 = Sum4<block_q8_0, SignedDot>;\nusing Sum4Type1 = Sum4<block_q8_1, UnsignedDot>;\n\ntemplate <typename Unpacker, typename Sum4Type, typename AccumType, typename Scales, typename Q8, int nrc_y>\nvoid mul_mat_qX_q8_Helper(int nb, const void * vx, size_t bx, const DataInfo& info, const Q8 ** y, int nrc_x) {\n    Unpacker unp(vx, bx);\n    Sum4Type sum4;\n    Scales scales;\n    for (int ix = 0; ix < nrc_x; ++ix) {\n        
unp.set_row(ix);\n        AccumType accum;\n        accum.compute(nb, unp, scales, sum4, y, info, ix);\n    }\n}\n\ntemplate <typename Unpacker, int nrc_y>\nvoid mul_mat_qX_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n%Unpacker::block_size() == 0);\n    Q8<nrc_y, block_q8_0> q8(info);\n    int nb = n/Unpacker::block_size();\n    if (nb%4 == 0) {\n        mul_mat_qX_q8_Helper<Unpacker, Sum4Type0, AccumType0<nrc_y, true>, ScaleHelperQ_0, block_q8_0, nrc_y>(\n                nb, vx, bx, info, q8.y, nrc_x\n        );\n    } else {\n        mul_mat_qX_q8_Helper<Unpacker, Sum4Type0, AccumType0<nrc_y, false>, ScaleHelperQ_0, block_q8_0, nrc_y>(\n                nb, vx, bx, info, q8.y, nrc_x\n        );\n    }\n}\n\ntemplate <typename Unpacker, int nrc_y>\nvoid mul_mat_qX_1_q8_1_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n%Unpacker::block_size() == 0);\n    Q8<nrc_y, block_q8_1> q8(info);\n    int nb = n/Unpacker::block_size();\n    if (nb%4 == 0) {\n        mul_mat_qX_q8_Helper<Unpacker, Sum4Type1, AccumType1<nrc_y, true>, ScaleHelperQ_1, block_q8_1, nrc_y>(\n                nb, vx, bx, info, q8.y, nrc_x\n        );\n    } else {\n        mul_mat_qX_q8_Helper<Unpacker, Sum4Type1, AccumType1<nrc_y, false>, ScaleHelperQ_1, block_q8_1, nrc_y>(\n                nb, vx, bx, info, q8.y, nrc_x\n        );\n    }\n}\n\nstruct Dequantizer4bit {\n    const __m256i m4 = _mm256_set1_epi8(0xf);\n    inline __m256i dequant(const uint8_t * qs) const {\n        const __m128i aux128 = _mm_loadu_si128((const __m128i *)qs);\n        return _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128), m4);\n    }\n};\n\nstruct Q8_0_Dequantizer {\n    inline __m256i dequant(const block_q8_0 * x) const {\n        return _mm256_loadu_si256((const __m256i *)x->qs);\n    }\n};\n\nstruct Q8_0_1_Dequantizer {\n    inline __m256i dequant(const block_q8_0 * x) const {\n        return 
_mm256_add_epi8(_mm256_set1_epi8(127), _mm256_loadu_si256((const __m256i *)x->qs));\n    }\n};\n\nstruct Q4_0_Dequantizer {\n    Dequantizer4bit b4;\n    const __m256i m8 = _mm256_set1_epi8(-8);\n    inline __m256i dequant(const block_q4_0 * x) const {\n        return _mm256_add_epi8(b4.dequant(x->qs), m8);\n    }\n};\n\nstruct Q4_1_Dequantizer {\n    Dequantizer4bit b4;\n    inline __m256i dequant(const block_q4_1 * x) const {\n        return b4.dequant(x->qs);\n    }\n};\n\nstruct HBitDequantizer {\n    const __m256i shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);\n    const __m256i mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);\n    const __m256i minus1 = _mm256_set1_epi64x(-1);\n    inline __m256i to_bytes(const uint8_t * bits) const {\n        // Note: Data in all ggml quants is at least 2-byte aligned.\n        // => we can cast to uint16_t and use or on two consecutive entries\n        // which is faster than memcpy\n        const uint16_t * aux16 = (const uint16_t *)bits;\n        const uint32_t aux32 = aux16[0] | (aux16[1] << 16);\n        //uint32_t aux32; memcpy(&aux32, bits, sizeof(uint32_t));\n        __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(aux32), shuffle);\n        bytes = _mm256_or_si256(bytes, mask);\n        return _mm256_cmpeq_epi8(bytes, minus1);\n    }\n};\n\nstruct Q5_0_Dequantizer {\n    Dequantizer4bit b4;\n    HBitDequantizer hbit;\n    const __m256i mh = _mm256_set1_epi8((char)0xF0);\n    inline __m256i dequant(const block_q5_0 * x) const {\n        const __m256i vqh = _mm256_andnot_si256(hbit.to_bytes(x->qh), mh);\n        return _mm256_or_si256(b4.dequant(x->qs), vqh);\n    }\n};\n\nstruct Q5_1_Dequantizer {\n    Dequantizer4bit b4;\n    HBitDequantizer hbit;\n    const __m256i mh = _mm256_set1_epi8(0x10);\n    inline __m256i dequant(const block_q5_1 * x) const {\n        const __m256i vqh = _mm256_and_si256(hbit.to_bytes(x->qh), mh);\n        return 
_mm256_or_si256(b4.dequant(x->qs), vqh);\n    }\n};\n\ntemplate <typename Q, typename Scales, typename Dequantizer>\nstruct Q_Unpacker {\n    Q_Unpacker(const void * vx, size_t bx) : cx_0((const char *)vx), x((const Q*)cx_0), bx(bx) {}\n\n    const char * cx_0;\n    const Q    * x;\n    size_t       bx;\n\n    Scales scales;\n    Dequantizer deq;\n\n    __m256i qx[4];\n\n    inline const __m256i* quants() const { return qx; }\n\n    inline void set_row(int ix) { x = (const Q*)(cx_0 + ix*bx); }\n\n    inline auto set_block_4(int i) {\n        for (int j = 0; j < 4; ++j) {\n            qx[j] = deq.dequant(x + 4*i + j);\n        }\n        return scales.prepare4(x + 4*i);\n    }\n    inline auto set_block(int i) {\n        qx[0] = deq.dequant(x + i);\n        return scales.prepare1(x + i);\n    }\n};\n\nstruct Q8_0_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0, Q8_0_Dequantizer> {\n    Q8_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}\n    inline static int block_size() { return QK4_0; }\n};\nstruct Q8_0_1_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0_1<127>, Q8_0_1_Dequantizer> {\n    Q8_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}\n//    using Sum4T = Sum4TypeQ81;\n    inline static int block_size() { return QK8_0; }\n};\nstruct Q4_0_Unpacker final : public Q_Unpacker<block_q4_0, ScaleHelperQ_0, Q4_0_Dequantizer> {\n    Q4_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}\n    inline static int block_size() { return QK4_0; }\n};\nstruct Q5_0_Unpacker final : public Q_Unpacker<block_q5_0, ScaleHelperQ_0, Q5_0_Dequantizer> {\n    Q5_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}\n    inline static int block_size() { return QK5_0; }\n};\nstruct Q4_1_Unpacker final : public Q_Unpacker<block_q4_1, ScaleHelperQ_1, Q4_1_Dequantizer> {\n    Q4_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}\n    inline static int block_size() { return QK4_1; }\n};\nstruct 
Q5_1_Unpacker final : public Q_Unpacker<block_q5_1, ScaleHelperQ_1, Q5_1_Dequantizer> {\n    Q5_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}\n    inline static int block_size() { return QK5_1; }\n};\n\ntemplate <int nrc_y>\nvoid mul_mat_q8_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n%Q8_0_Unpacker::block_size() == 0);\n    Q8<nrc_y, block_q8_0> q8(info);\n    int nb = n/Q8_0_Unpacker::block_size();\n    if (nb%4 == 0) {\n        mul_mat_qX_q8_Helper<Q8_0_Unpacker, Sum4_Q8, AccumType0<nrc_y, true>, ScaleHelperQ_0, block_q8_0, nrc_y>(\n                nb, vx, bx, info, q8.y, nrc_x\n        );\n    } else {\n        mul_mat_qX_q8_Helper<Q8_0_Unpacker, Sum4_Q8, AccumType0<nrc_y, false>, ScaleHelperQ_0, block_q8_0, nrc_y>(\n                nb, vx, bx, info, q8.y, nrc_x\n        );\n    }\n}\n\n\n\n\n/*\nmoonll\nadd some structs for DequantizerIQ2XXS\nSimpleBits\nEvenSignHelper\n*/\nstruct SimpleBits {\n    __m256i values[4];\n};\n\n// fix for #829: add detection of AVX512VPOPCNTDQ\n#if defined(HAVE_FANCY_SIMD) && defined(__AVX512VPOPCNTDQ__)\n#define HAVE_AVX512_POPCNT 1\n#else\n#define HAVE_AVX512_POPCNT 0\n#endif\n\nstruct EvenSignHelper {\n    #if defined HAVE_FANCY_SIMD\n    // #pragma message(\"Using AVX512VPOPCNTDQ in even sign helper\")\n        union sbits_t {\n            __m128i vec;\n            __mmask32 mask[4];\n        };\n        IQK_ALWAYS_INLINE void sign_2_values(__m256i aux, __m256i * values) const {\n            aux = _mm256_and_si256(_mm256_srlv_epi32(aux, shifts), mask);\n            \n            // fix for #829: support Intel Cascade Lake CPUs; if the AVX512VPOPCNTDQ extension is not available, use a fallback implementation\n            #if HAVE_AVX512_POPCNT\n                auto pcnt = _mm256_popcnt_epi32(aux);\n                \n            #else\n                // fallback implementation using a standard bit-count approach\n                __m256i pcnt;\n                int* pcnt_ptr = reinterpret_cast<int*>(&pcnt);\n                int* aux_ptr = reinterpret_cast<int*>(&aux); // 
take the address of aux directly to avoid an unnecessary copy\n                \n                #pragma unroll 8  // hint the compiler to unroll the loop to improve SIMD throughput\n                for (int i = 0; i < 8; i++) {\n                    pcnt_ptr[i] = __builtin_popcount(aux_ptr[i]); // use the compiler's built-in popcount\n                }\n            #endif\n            \n            sbits_t sbits;\n            sbits.vec = _mm256_cvtepi32_epi8(_mm256_or_si256(aux, _mm256_slli_epi32(_mm256_and_si256(pcnt, mone), 7)));\n            values[0] = _mm256_mask_sub_epi8(values[0], sbits.mask[0], _mm256_setzero_si256(), values[0]);\n            values[1] = _mm256_mask_sub_epi8(values[1], sbits.mask[1], _mm256_setzero_si256(), values[1]);\n            //auto sign_bits = _mm256_cvtepi32_epi8(_mm256_or_si256(aux, _mm256_slli_epi32(_mm256_and_si256(pcnt, mone), 7)));\n            //const __mmask32 * m32 = (const __mmask32 *)&sign_bits;\n            //values[0] = _mm256_mask_sub_epi8(values[0], m32[0], _mm256_setzero_si256(), values[0]);\n            //values[1] = _mm256_mask_sub_epi8(values[1], m32[1], _mm256_setzero_si256(), values[1]);\n        }\n        const __m256i shifts = _mm256_set_epi32(21, 14, 7, 0, 21, 14, 7, 0);\n        const __m256i mask   = _mm256_set1_epi32(127);\n        const __m256i mone   = _mm256_set1_epi32(1);\n    #else\n        inline void sign_value(uint32_t aux32, __m256i& value) const {\n            auto signs = _mm256_set_epi64x(keven_signs[(aux32 >> 21) & 127], keven_signs[(aux32 >> 14) & 127],\n                                           keven_signs[(aux32 >>  7) & 127], keven_signs[(aux32 >>  0) & 127]);\n            value = _mm256_sign_epi8(value, signs);\n        }\n    #endif\n};\n\n/*\nmoonll add multiply_add for mul_mat_qX_K_q8_K_IQ_1\nadd func\nget_scale_shuffle_8\nget_scale_shuffle_16\nset_scales_16\n*/\n\ninline __m256i get_scale_shuffle_8(int i) {\n    return _mm256_set1_epi16((2*i) | ((2*i+1) << 8));\n}\n\ninline void set_scales_8(const __m256i& all_scales, int j, __m256i * scales) {\n    scales[0] = 
_mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+0));\n    scales[1] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+1));\n    scales[2] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+2));\n    scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+3));\n}\n\n\ninline __m256i get_scale_shuffle_16(int i) {\n    static const uint8_t k_shuffle[128] = {\n         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,     2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,\n         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,     6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,\n         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,    10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,\n        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,    14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,\n    };\n    return _mm256_loadu_si256((const __m256i*)k_shuffle + i);\n}\n\ninline void set_scales_16(const __m256i& all_scales, __m256i * scales) {\n    scales[0] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(0));\n    scales[1] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(1));\n    scales[2] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(2));\n    scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(3));\n}\n\n\ntemplate <typename Q8, typename Bits>\ninline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {\n    if (j == 0) {\n#ifdef HAVE_FANCY_SIMD\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            sumi[iy] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0)));\n            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1)));\n            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2)));\n            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], 
scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3)));\n        }\n#else\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0)));\n            const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1)));\n            const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2)));\n            const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3)));\n            sumi[iy] = _mm256_add_epi32(_mm256_add_epi32(p1, p3), _mm256_add_epi32(p2, p4));\n        }\n#endif\n    } else {\n#ifdef HAVE_FANCY_SIMD\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4)));\n            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5)));\n            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6)));\n            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7)));\n        }\n#else\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4)));\n            const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5)));\n            const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6)));\n            const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7)));\n            sumi[iy] = _mm256_add_epi32(sumi[iy], 
_mm256_add_epi32(p1, p3));\n            sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p2, p4));\n        }\n#endif\n    }\n}\n\n/*\nmoonll add multiply_add_1 for mul_mat_qX_K_q8_K_IQ_1\nadd func\nset_scales_8_iq\nset_scales_16_iq\n\nadd MUL_MAT\nmul_mat_qX_K_q8_K_IQ_1\nmul_mat_qX_K_q8_K_IQ_N\nmul_mat_qX_K_q8_K_IQ\n*/\n\ntemplate <typename Bits>\ninline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, const __m256i * q8, __m256i * sumi) {\n    if (j == 0) {\n#ifdef HAVE_FANCY_SIMD\n        auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]);\n        auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]);\n        auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]);\n        auto p4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[3], q8[3]);\n        sumi[0] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_packs_epi32(p1, p2));\n        sumi[1] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[1], _mm256_packs_epi32(p3, p4));\n#else\n        const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8[0]));\n        const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8[1]));\n        const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8[2]));\n        const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8[3]));\n        sumi[0] = _mm256_add_epi32(p1, p3);\n        sumi[1] = _mm256_add_epi32(p2, p4);\n#endif\n    } else {\n#ifdef HAVE_FANCY_SIMD\n        auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]);\n        auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]);\n        auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]);\n        auto p4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[3], q8[3]);\n        sumi[0] = 
_mm256_dpwssd_epi32(sumi[0], scales[0], _mm256_packs_epi32(p1, p2));\n        sumi[1] = _mm256_dpwssd_epi32(sumi[1], scales[1], _mm256_packs_epi32(p3, p4));\n#else\n        const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8[0]));\n        const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8[1]));\n        const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8[2]));\n        const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8[3]));\n        sumi[0] = _mm256_add_epi32(sumi[0], _mm256_add_epi32(p1, p3));\n        sumi[1] = _mm256_add_epi32(sumi[1], _mm256_add_epi32(p2, p4));\n#endif\n    }\n}\n\n\ninline void set_scales_8_iq(int j, const __m256i& all_scales, __m256i * scales) {\n    //#ifdef HAVE_FANCY_SIMD\n        auto shuffle = j == 0 ? _mm256_set_epi64x(0x0302030203020302, 0x0100010001000100, 0x0302030203020302, 0x0100010001000100)\n                              : _mm256_set_epi64x(0x0b0a0b0a0b0a0b0a, 0x0908090809080908, 0x0b0a0b0a0b0a0b0a, 0x0908090809080908);\n        scales[0] = _mm256_shuffle_epi8(all_scales, shuffle);\n        scales[1] = _mm256_shuffle_epi8(all_scales, _mm256_add_epi8(shuffle, _mm256_set1_epi8(4)));\n    //#else\n    //    set_scales_8(all_scales, j, scales);\n    //#endif\n    }\n    \ninline void set_scales_16_iq(const __m256i& all_scales, __m256i * scales) {\n    #ifdef HAVE_FANCY_SIMD\n        auto shuffle = _mm256_set_epi64x(0x0706070607060706, 0x0302030203020302, 0x0504050405040504, 0x0100010001000100);\n        scales[0] = _mm256_shuffle_epi8(all_scales, shuffle);\n        scales[1] = _mm256_shuffle_epi8(all_scales, _mm256_add_epi8(shuffle, _mm256_set1_epi8(8)));\n    #else\n        set_scales_16(all_scales, scales);\n    #endif\n    }\n    \ntemplate <typename Dequantizer>\nstatic void mul_mat_qX_K_q8_K_IQ_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n        const int 
nb = n / QK_K;\n        Q8<1> q8(info);\n        Dequantizer deq(vx, bx);\n        __m256i scales[2];\n        __m256i q8_quants[4];\n        for (int ix = 0; ix < nrc_x; ++ix) {\n    \n            __m256 accd = _mm256_setzero_ps();\n            deq.new_row(ix);\n    \n            for (int i = 0; i < nb; ++i) {\n    \n                __m256i sumi[2], all_scales[Dequantizer::num_blocks/8];\n                deq.new_block(i, all_scales);\n    \n                for (int j = 0; j < QK_K/128; ++j) {\n                    deq.prepare(i, j, q8, q8_quants);\n                    if constexpr (Dequantizer::num_blocks == 8) {\n                        set_scales_8_iq(j, all_scales[0], scales);\n                    } else {\n                        set_scales_16_iq(all_scales[j], scales);\n                    }\n                    multiply_add_1(j, deq.bits, scales, q8_quants, sumi);\n                }\n                accd = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(0, i)), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi[0], sumi[1])), accd);\n            }\n    \n            info.store(ix, 0, hsum_float_8(accd));\n        }\n    }\n\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    const int nb = n / QK_K;\n    Q8<nrc_y> q8(info);\n    Dequantizer deq(vx, bx);\n    __m256i scales[4];\n    __m256  accd[nrc_y];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();\n\n        deq.new_row(ix);\n\n        for (int i = 0; i < nb; ++i) {\n\n            __m256i sumi[nrc_y], all_scales[Dequantizer::num_blocks/8];\n            //for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = _mm256_setzero_si256();\n            __m256i mins;\n            float dmin = deq.new_block(i, all_scales, mins);\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                auto bsums = q8.load_bsums(iy, i);\n                auto prod  = 
_mm256_madd_epi16(mins, bsums);\n                accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(dmin*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]);\n            }\n\n            for (int j = 0; j < QK_K/128; ++j) {\n                deq.prepare(i, j);\n                if constexpr (Dequantizer::num_blocks == 8) {\n                    set_scales_8(all_scales[0], j, scales);\n                } else {\n                    set_scales_16(all_scales[j], scales);\n                }\n                //multiply_add_iq(deq.bits, scales, j, i, q8, sumi);\n                multiply_add(deq.bits, scales, j, i, q8, sumi);\n            }\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i));\n                accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);\n            }\n        }\n\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            info.store(ix, iy, hsum_float_8(accd[iy]));\n        }\n    }\n}\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void mul_mat_qX_K_q8_K_IQ(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n#ifdef HAVE_FANCY_SIMD\n    if constexpr (nrc_y == 1) {\n        mul_mat_qX_K_q8_K_IQ_1<Dequantizer>(n, vx, bx, info, nrc_x);\n    } else {\n        mul_mat_qX_K_q8_K_IQ_N<Dequantizer, nrc_y>(n, vx, bx, info, nrc_x);\n    }\n#else\n    mul_mat_qX_K_q8_K_IQ_N<Dequantizer, nrc_y>(n, vx, bx, info, nrc_x);\n#endif\n}\n\n/*\nmoonll iq1s\ncore func for iq1s mul_mat_iq1_s_q8_K\n\n*/\n\ntemplate <int nrc_y>\nstatic void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    GGML_ASSERT(n%QK_K == 0);\n    Q8<nrc_y, block_q8_K> q8(info);\n    __m256i qx[8];\n    __m256i scales[4];\n    __m256  acc[nrc_y] = {};\n    auto delta_mask = _mm_set1_epi16(-32768); // to avoid stupid overflow warnings when using 0x8000\n    __m256i shuffle0 = _mm256_set_epi64x(0x0302030203020302, 
0x0100010001000100, 0x0302030203020302, 0x0100010001000100);\n    for (int ix = 0; ix < nrc_x; ++ix) {\n        auto iq1s = (const block_iq1_s *)((const char *)vx + ix*bx);\n        for (int ibl = 0; ibl < n/QK_K; ++ibl) {\n            float d = GGML_FP16_TO_FP32(iq1s[ibl].d);\n            auto qhb = _mm_loadu_si128((const __m128i *)iq1s[ibl].qh);\n            auto scales128 = _mm_and_si128(_mm_srli_epi16(qhb, 12), _mm_set1_epi16(7));\n            scales128 = _mm_add_epi16(_mm_slli_epi16(scales128, 1), _mm_set1_epi16(1));\n#ifdef HAVE_FANCY_SIMD\n            auto mask = _mm_cmpeq_epi16_mask(_mm_and_si128(qhb, delta_mask), delta_mask);\n            auto deltas128 = _mm_mask_blend_epi16(mask, _mm_set1_epi16(-7), _mm_set1_epi16(-9));\n#else\n            auto mask = _mm_cmpeq_epi16(_mm_and_si128(qhb, delta_mask), delta_mask);\n            auto deltas128 = _mm_or_si128(_mm_and_si128(mask, _mm_set1_epi16(-9)), _mm_andnot_si128(mask, _mm_set1_epi16(-7)));\n#endif\n            deltas128 = _mm_mullo_epi16(scales128, deltas128);\n            scales128 = _mm_slli_epi16(scales128, 3);\n            auto deltas_l = _mm_unpacklo_epi16(deltas128, deltas128);\n            auto deltas_h = _mm_unpackhi_epi16(deltas128, deltas128);\n            auto deltas = MM256_SET_M128I(deltas_h, deltas_l); // blocks 0,0, 1,1, 2,2, ..., 7,7\n            auto all_scales = MM256_SET_M128I(scales128, scales128);\n            auto shuffle = shuffle0;\n            for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {\n                scales[ib64] = _mm256_shuffle_epi8(all_scales, shuffle);\n                shuffle = _mm256_add_epi8(shuffle, _mm256_set1_epi8(4));\n            }\n            const uint8_t  * qs = iq1s[ibl].qs;\n            const uint16_t * qh = iq1s[ibl].qh;\n            for (int ib = 0; ib < QK_K/32; ib += 2) {\n                qx[ib+0] = _mm256_set_epi64x(iq1s_grid_us[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid_us[qs[2] | ((qh[ib+0] << 2) & 0x700)],\n                                           
  iq1s_grid_us[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid_us[qs[0] | ((qh[ib+0] << 8) & 0x700)]);\n                qx[ib+1] = _mm256_set_epi64x(iq1s_grid_us[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid_us[qs[6] | ((qh[ib+1] << 2) & 0x700)],\n                                             iq1s_grid_us[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid_us[qs[4] | ((qh[ib+1] << 8) & 0x700)]);\n                qs += 8;\n            }\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                auto bsums = q8.load_bsums(iy, ibl);\n                auto sumi = _mm256_setzero_si256();\n                for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {\n                    auto qy1 = q8.load_quants(iy, ibl, 2*ib64+0);\n                    auto qy2 = q8.load_quants(iy, ibl, 2*ib64+1);\n#ifdef HAVE_FANCY_SIMD\n                    auto dot1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2*ib64+0], qy1);\n                    auto dot2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2*ib64+1], qy2);\n                    sumi = _mm256_dpwssd_epi32(sumi, scales[ib64], _mm256_packs_epi32(dot1, dot2));\n#else\n                    auto dot1 = _mm256_maddubs_epi16(qx[2*ib64+0], qy1);\n                    auto dot2 = _mm256_maddubs_epi16(qx[2*ib64+1], qy2);\n                    auto dot  = _mm256_add_epi16(_mm256_unpacklo_epi64(dot1, dot2), _mm256_unpackhi_epi64(dot1, dot2));\n                    sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(scales[ib64], dot));\n#endif\n                }\n#ifdef HAVE_FANCY_SIMD\n                sumi = _mm256_dpwssd_epi32(sumi, bsums, deltas);\n#else\n                sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(bsums, deltas));\n#endif\n                acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d*q8.scale(iy, ibl)), _mm256_cvtepi32_ps(sumi), acc[iy]);\n            }\n        }\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            info.store(ix, iy, 0.125f*hsum_float_8(acc[iy]));\n            acc[iy] = _mm256_setzero_ps();\n        }\n    
}\n}\n\n/*\nmoonll iq1s\nDequantizerIQ2XXS\nDequantizerIQ2XXS is an important dequantizer for DequantizerIQ1_S\n*/\n\nstruct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {\n    DequantizerIQ2XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n\n    constexpr static int num_blocks = 8;\n\n    union Data {\n        __m256i vec;\n        uint32_t val[8];\n    };\n\n    inline __m128i load_scales(int i) {\n        d = 0.125f * GGML_FP16_TO_FP32(x[i].d);\n        const uint16_t * a16 = (const uint16_t *)x[i].qs;\n        auto scales = _mm_srli_epi16(_mm_set_epi16(a16[31], a16[27], a16[23], a16[19], a16[15], a16[11], a16[7], a16[3]), 12);\n        return _mm_or_si128(_mm_slli_epi16(scales, 1), _mm_set1_epi16(1));\n    }\n\n    inline void new_block(int i, __m256i * scales) {\n        auto sc16 = load_scales(i);\n        scales[0] = MM256_SET_M128I(sc16, sc16);\n    }\n    inline float new_block(int i, __m256i * scales, __m256i& mins) {\n        auto sc16 = load_scales(i);\n        mins = scb.shuffle(sc16);\n        scales[0] = MM256_SET_M128I(sc16, sc16);\n        return -d*minv;\n    }\n\n    inline static void make4(const uint32_t * aux32, __m256i * values) {\n        const uint8_t * aux8 = (const uint8_t *)aux32;\n        values[0] = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[ 1]], iq2xxs_grid[aux8[ 0]]);\n        values[1] = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[ 9]], iq2xxs_grid[aux8[ 8]]);\n        values[2] = _mm256_set_epi64x(iq2xxs_grid[aux8[19]], iq2xxs_grid[aux8[18]], iq2xxs_grid[aux8[17]], iq2xxs_grid[aux8[16]]);\n        values[3] = _mm256_set_epi64x(iq2xxs_grid[aux8[27]], iq2xxs_grid[aux8[26]], iq2xxs_grid[aux8[25]], iq2xxs_grid[aux8[24]]);\n    }\n\n    IQK_ALWAYS_INLINE void sign_values(const uint32_t * aux32, __m256i * values) const {\n#ifdef HAVE_FANCY_SIMD\n        esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[3]), 
_mm_set1_epi32(aux32[1])), values+0);\n        esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[7]), _mm_set1_epi32(aux32[5])), values+2);\n#else\n        esh.sign_value(aux32[1], values[0]);\n        esh.sign_value(aux32[3], values[1]);\n        esh.sign_value(aux32[5], values[2]);\n        esh.sign_value(aux32[7], values[3]);\n#endif\n    }\n    inline void make4_signed(const uint32_t * aux32, const __m256i& min_value, __m256i * values) const {\n        make4(aux32, values);\n        sign_values(aux32, values);\n        for (int k = 0; k < 4; ++k) values[k] = _mm256_add_epi8(values[k], min_value);\n    }\n    inline void make4(const uint32_t * aux32, __m256i * values, __m256i * q8) const {\n        make4(aux32, values);\n        sign_values(aux32, q8);\n    }\n    inline void prepare(int i, int j) {\n        Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j);\n        make4_signed(data.val, min_value, bits.values);\n    }\n    inline void prepare(int i, int j, const Q8<1>& q8, __m256i * q8_quants) {\n        for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);\n        Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j);\n        make4(data.val, bits.values, q8_quants);\n    }\n\n    constexpr static int minv = 43;\n    SimpleBits bits;\n    Scales8KBase scb;\n    EvenSignHelper esh;\n    const __m256i min_value = _mm256_set1_epi8(minv);\n    const __m256i shuffle = _mm256_set_epi32(7, 5, 3, 1, 7, 5, 3, 1);\n};\n\n/*\nmoonll\nadd Q8_0_Unpacker && DequantizerIQ2XXS support\nadd func mul_mat_qX_K_q8_K_IQ\n*/\n\ntemplate <typename Dequantizer> void MulMat::set_functions(MulMat& m) {\n    if constexpr (std::is_same_v<Dequantizer, Q4_0_Unpacker> || std::is_same_v<Dequantizer, Q5_0_Unpacker> ||\n        std::is_same_v<Dequantizer, Q8_0_Unpacker>) {\n            m.funcs[0] = mul_mat_qX_0_q8_0_T<Dequantizer, 1>;\n            m.funcs[1] = mul_mat_qX_0_q8_0_T<Dequantizer, 2>;\n            m.funcs[2] = 
mul_mat_qX_0_q8_0_T<Dequantizer, 3>;\n            m.funcs[3] = mul_mat_qX_0_q8_0_T<Dequantizer, 4>;\n            m.funcs[4] = mul_mat_qX_0_q8_0_T<Dequantizer, 5>;\n            m.funcs[5] = mul_mat_qX_0_q8_0_T<Dequantizer, 6>;\n            m.funcs[6] = mul_mat_qX_0_q8_0_T<Dequantizer, 7>;\n            m.funcs[7] = mul_mat_qX_0_q8_0_T<Dequantizer, 8>;\n        }\n        else if constexpr (std::is_same_v<Dequantizer, Q4_1_Unpacker> || std::is_same_v<Dequantizer, Q5_1_Unpacker>|| std::is_same_v<Dequantizer, Q8_0_1_Unpacker>) {\n            m.funcs[0] = mul_mat_qX_1_q8_1_T<Dequantizer, 1>;\n            m.funcs[1] = mul_mat_qX_1_q8_1_T<Dequantizer, 2>;\n            m.funcs[2] = mul_mat_qX_1_q8_1_T<Dequantizer, 3>;\n            m.funcs[3] = mul_mat_qX_1_q8_1_T<Dequantizer, 4>;\n            m.funcs[4] = mul_mat_qX_1_q8_1_T<Dequantizer, 5>;\n            m.funcs[5] = mul_mat_qX_1_q8_1_T<Dequantizer, 6>;\n            m.funcs[6] = mul_mat_qX_1_q8_1_T<Dequantizer, 7>;\n            m.funcs[7] = mul_mat_qX_1_q8_1_T<Dequantizer, 8>;\n        }\n        else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2XXS>) {\n            m.funcs[0] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 1>;\n            m.funcs[1] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 2>;\n            m.funcs[2] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 3>;\n            m.funcs[3] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 4>;\n            m.funcs[4] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 5>;\n            m.funcs[5] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 6>;\n            m.funcs[6] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 7>;\n            m.funcs[7] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 8>;\n            }\n            else {\n#ifdef HAVE_FANCY_SIMD\n            if constexpr (std::is_same_v<Dequantizer, DequantizerIQ4XS>) {\n            m.funcs[0] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 1>;\n            m.funcs[1] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 2>;\n            m.funcs[2] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 3>;\n            
m.funcs[3] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 4>;\n            m.funcs[4] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 5>;\n            m.funcs[5] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 6>;\n            m.funcs[6] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 7>;\n            m.funcs[7] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 8>;\n            } else {\n            m.funcs[0] = mul_mat_qX_K_q8_K_AVX512_1<Dequantizer>;\n            m.funcs[1] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 2>;\n            m.funcs[2] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 3>;\n            m.funcs[3] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 4>;\n            m.funcs[4] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 5>;\n            m.funcs[5] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 6>;\n            m.funcs[6] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 7>;\n            m.funcs[7] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 8>;\n            }\n#else\n            if constexpr (std::is_same_v<Dequantizer, DequantizerQ2K> ||\n                          std::is_same_v<Dequantizer, DequantizerQ3K> ||\n                          std::is_same_v<Dequantizer, DequantizerQ6K>) {\n                m.funcs[0] = mul_mat_qY_K_q8_K_T<Dequantizer, 1>;\n                m.funcs[1] = mul_mat_qY_K_q8_K_T<Dequantizer, 2>;\n                m.funcs[2] = mul_mat_qY_K_q8_K_T<Dequantizer, 3>;\n                m.funcs[3] = mul_mat_qY_K_q8_K_T<Dequantizer, 4>;\n                m.funcs[4] = mul_mat_qY_K_q8_K_T<Dequantizer, 5>;\n                m.funcs[5] = mul_mat_qY_K_q8_K_T<Dequantizer, 6>;\n                m.funcs[6] = mul_mat_qY_K_q8_K_T<Dequantizer, 7>;\n                m.funcs[7] = mul_mat_qY_K_q8_K_T<Dequantizer, 8>;\n            } else {\n                m.funcs[0] = mul_mat_qX_K_q8_K_T<Dequantizer, 1>;\n                m.funcs[1] = mul_mat_qX_K_q8_K_T<Dequantizer, 2>;\n                m.funcs[2] = mul_mat_qX_K_q8_K_T<Dequantizer, 3>;\n                m.funcs[3] = mul_mat_qX_K_q8_K_T<Dequantizer, 4>;\n                m.funcs[4] = 
mul_mat_qX_K_q8_K_T<Dequantizer, 5>;\n                m.funcs[5] = mul_mat_qX_K_q8_K_T<Dequantizer, 6>;\n                m.funcs[6] = mul_mat_qX_K_q8_K_T<Dequantizer, 7>;\n                m.funcs[7] = mul_mat_qX_K_q8_K_T<Dequantizer, 8>;\n            }\n#endif\n        }\n}\n\nstruct QFBase {\n    #ifdef __AVX512F__\n        constexpr static int k_step = 16;\n        using Data = __m512;\n        using Acc  = __m512;\n        static inline Data load(const ggml_half * x) { return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)x)); }\n        static inline Data load(const float * x) { return _mm512_loadu_ps(x); }\n        static inline Data load(const ggml_bf16_t * x) {\n            return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i*)x)), 16));\n        }\n        static inline Acc acc(Acc prev, const Data& y, const Data& x) {\n            return _mm512_fmadd_ps(y, x, prev);\n        }\n        static inline Acc acc_first(const Data& y, const Data& x) {\n            return _mm512_mul_ps(y, x);\n        }\n        static inline Acc add(Acc x, Acc y) { return _mm512_add_ps(x, y); }\n        static inline float hsum(Acc acc) {\n            return _mm512_reduce_add_ps(acc);\n        }\n        template <typename Float>\n        static inline Data load4Floats(const Float * x) {\n            return _mm512_insertf32x4(_mm512_setzero_ps(), load128(x), 0);\n        }\n        static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) {\n            acc = _mm512_fmadd_ps(xv[0], _mm512_shuffle_ps(yv, yv, 0x00), acc);\n            acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc);\n            acc = _mm512_fmadd_ps(xv[2], _mm512_shuffle_ps(yv, yv, 0xaa), acc);\n            acc = _mm512_fmadd_ps(xv[3], _mm512_shuffle_ps(yv, yv, 0xff), acc);\n            return acc;\n        }\n        static inline Acc acc_r4_first(const Data * xv, const Data& yv) {\n            auto acc = _mm512_mul_ps(xv[0], 
_mm512_shuffle_ps(yv, yv, 0x00));\n            acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc);\n            acc = _mm512_fmadd_ps(xv[2], _mm512_shuffle_ps(yv, yv, 0xaa), acc);\n            acc = _mm512_fmadd_ps(xv[3], _mm512_shuffle_ps(yv, yv, 0xff), acc);\n            return acc;\n        }\n        static inline __m128 hsum_r4(Acc acc) {\n            auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc, 0), _mm512_extractf32x4_ps(acc, 1));\n            auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc, 2), _mm512_extractf32x4_ps(acc, 3));\n            return _mm_add_ps(sum1, sum2);\n        }\n    #else\n        constexpr static int k_step = 8;\n        using Data = __m256;\n        using Acc  = __m256;\n        static inline Data load(const ggml_half * x) { return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)x)); }\n        static inline Data load(const float * x) { return _mm256_loadu_ps(x); }\n        static inline Data load(const ggml_bf16_t * x) {\n            return _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i*)x)), 16));\n        }\n        static inline Acc acc(Acc prev, const Data& y, const Data& x) {\n            return _mm256_fmadd_ps(y, x, prev);\n        }\n        static inline Acc add(Acc x, Acc y) { return _mm256_add_ps(x, y); }\n        static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) {\n            acc = _mm256_fmadd_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00), acc);\n            acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc);\n            acc = _mm256_fmadd_ps(xv[2], _mm256_shuffle_ps(yv, yv, 0xaa), acc);\n            acc = _mm256_fmadd_ps(xv[3], _mm256_shuffle_ps(yv, yv, 0xff), acc);\n            return acc;\n        }\n        static inline Acc acc_r4_first(const Data * xv, const Data& yv) {\n            auto acc = _mm256_mul_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00));\n            acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 
0x55), acc);\n            acc = _mm256_fmadd_ps(xv[2], _mm256_shuffle_ps(yv, yv, 0xaa), acc);\n            acc = _mm256_fmadd_ps(xv[3], _mm256_shuffle_ps(yv, yv, 0xff), acc);\n            return acc;\n        }\n        static inline Acc acc_first(const Data& y, const Data& x) {\n            return _mm256_mul_ps(y, x);\n        }\n        static inline float hsum(Acc acc) {\n            return hsum_float_8(acc);\n        }\n        static inline __m128 hsum_r4(Acc acc) {\n            return _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));\n        }\n        template <typename Float>\n        static inline Data load4Floats(const Float * x) {\n            return _mm256_insertf128_ps(_mm256_setzero_ps(), load128(x), 0);\n        }\n    #endif\n        static inline __m128 load128(const ggml_half * x) { return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)x)); }\n        static inline __m128 load128(const float * x) { return _mm_loadu_ps(x); }\n        static inline __m128 load128(const ggml_bf16_t * x) {\n            return _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i*)x)), 16));\n        }\n    };\n    template <typename Float, int nrc_in> struct QFT final : public QFBase {\n        constexpr static int nrc = nrc_in;\n        QFT(const DataInfo& info) {\n            for (int iy = 0; iy < nrc; ++iy) y[iy] = (const Float *)info.src1_row(iy);\n        }\n        QFT(const char * cx, size_t bx) {\n            for (int iy = 0; iy < nrc; ++iy) y[iy] = (const Float *)(cx + iy*bx);\n        }\n        IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }\n        IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load4Floats(y[iy] + 4*i); }\n        IQK_ALWAYS_INLINE void load_r4(int ix, int i, Data * xv) const {\n            xv[0] = load1(ix+0, i);\n            xv[1] = load1(ix+1, i);\n            xv[2] = load1(ix+2, i);\n            xv[3] = load1(ix+3, i);\n    #ifdef 
__AVX512F__\n            auto t0 = _mm512_unpacklo_ps(xv[0], xv[1]);\n            auto t1 = _mm512_unpacklo_ps(xv[2], xv[3]);\n            auto t2 = _mm512_unpackhi_ps(xv[0], xv[1]);\n            auto t3 = _mm512_unpackhi_ps(xv[2], xv[3]);\n            xv[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(t0), _mm512_castps_pd(t1)));\n            xv[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(t0), _mm512_castps_pd(t1)));\n            xv[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(t2), _mm512_castps_pd(t3)));\n            xv[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(t2), _mm512_castps_pd(t3)));\n    #else\n            auto t0 = _mm256_unpacklo_ps(xv[0], xv[1]);\n            auto t1 = _mm256_unpacklo_ps(xv[2], xv[3]);\n            auto t2 = _mm256_unpackhi_ps(xv[0], xv[1]);\n            auto t3 = _mm256_unpackhi_ps(xv[2], xv[3]);\n            xv[0] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));\n            xv[1] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));\n            xv[2] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));\n            xv[3] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));\n    #endif\n        }\n        const Float * y[nrc];\n    };\n    \n\n\ntemplate <typename Qy, typename Qx>\nIQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {\n    int nb = n/QFBase::k_step;\n    int nb4 = n/4;\n    Qy y(info);\n    Qx x(cx + ix0*bx, bx);\n    QFBase::Data xv[Qx::nrc];\n    QFBase::Acc  acc[Qx::nrc*Qy::nrc];\n    auto yv = y.load1(0, 0);\n    for (int ix = 0; ix < Qx::nrc; ++ix) {\n        xv[ix] = x.load1(ix, 0);\n        acc[ix] = QFBase::acc_first(yv, xv[ix]);\n    }\n    for (int iy = 1; iy < Qy::nrc; ++iy) {\n        yv = y.load1(iy, 0);\n        for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = 
QFBase::acc_first(yv, xv[ix]);\n    }\n    for (int i = 1; i < nb; ++i) {\n        yv = y.load1(0, i);\n        for (int ix = 0; ix < Qx::nrc; ++ix) {\n            xv[ix] = x.load1(ix, i);\n            acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]);\n        }\n        for (int iy = 1; iy < Qy::nrc; ++iy) {\n            yv = y.load1(iy, i);\n            for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]);\n        }\n    }\n    for (int i = (QFBase::k_step/4)*nb; i < nb4; ++i) {\n        yv = y.load_tail(0, i);\n        for (int ix = 0; ix < Qx::nrc; ++ix) {\n            xv[ix] = x.load_tail(ix, i);\n            acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]);\n        }\n        for (int iy = 1; iy < Qy::nrc; ++iy) {\n            yv = y.load_tail(iy, i);\n            for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]);\n        }\n    }\n    for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix]));\n}\n// This will handle any of f16 x f32, f32 x f16, f16 x f16, f32 x f32, with computations done\n// in f32 (i.e., f16 is first converted to f32). 
It is easy to extend to computations done in\n// f16, but I don't have a CPU capable of f16 vector arithmetic, so not doing it for now.\ntemplate <int nrc_y, typename FloatX, typename FloatY>\nvoid mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    const char * cx = (const char *)vx;\n    // TBD if we want this\n    //if constexpr (nrc_y == 1) {\n    //    constexpr int k_nx = 2;\n    //    for (int ix = 0; ix < nrc_x/k_nx; ++ix) {\n    //        mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, k_nx>>(n, cx, bx, ix*k_nx, info);\n    //    }\n    //    if (int lastx = k_nx*(nrc_x/k_nx); lastx < nrc_x) {\n    //        int nx = nrc_x - lastx;\n    //        switch (nx) {\n    //            case 1: mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, lastx, info); break;\n    //            case 2: mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, lastx, info); break;\n    //            case 3: mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, lastx, info); break;\n    //        }\n    //        //mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, lastx, info);\n    //    }\n    //    return;\n    //}\n#ifdef __AVX512F__\n    constexpr int k_nx = 5;\n#else\n    constexpr int k_nx = nrc_y == 1 ? 
4 : 2;\n#endif\n    for (int ix = 0; ix < nrc_x/k_nx; ++ix) {\n        mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, k_nx>>(n, cx, bx, ix*k_nx, info);\n    }\n    int last_x = k_nx*(nrc_x/k_nx);\n    if (last_x == nrc_x) return;\n    int nx = nrc_x - last_x;\n#ifdef __AVX512F__\n    switch (nx) {\n        case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;\n        case 2: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, last_x, info); break;\n        case 3: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, last_x, info); break;\n        case 4: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 4>>(n, cx, bx, last_x, info); break;\n    }\n#else\n    if constexpr (nrc_y == 1) {\n        switch (nx) {\n            case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;\n            case 2: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, last_x, info); break;\n            case 3: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, last_x, info); break;\n        }\n    } else {\n        switch (nx) {\n            case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;\n        }\n    }\n#endif\n}\n\ntemplate <typename FloatX, typename FloatY>\nvoid set_mul_mat_f(MulMat& mm) {\n    for (auto& f : mm.funcs) f = nullptr;\n    mm.funcs[0] = mul_mat_fX_fY_T<1, FloatX, FloatY>;\n    mm.funcs[1] = mul_mat_fX_fY_T<2, FloatX, FloatY>;\n    mm.funcs[2] = mul_mat_fX_fY_T<3, FloatX, FloatY>;\n    mm.funcs[3] = mul_mat_fX_fY_T<4, FloatX, FloatY>;\n    mm.funcs[4] = mul_mat_fX_fY_T<5, FloatX, FloatY>;\n#ifndef __AVX512F__\n    mm.funcs[5] = mul_mat_fX_fY_T<6, FloatX, FloatY>;\n#endif\n}\n\n\n\n/*\nmoonll:\n- add typeB: set_mul_mat returns false when typeB is not the type expected for the weight matrix (typeA)\n- add GGML_TYPE_IQ2_XXS\n- add GGML_TYPE_IQ1_S\n- add GGML_TYPE_IQ4_XS\n*/\n\nbool MulMat::set_mul_mat(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {\n   
 (void)Ny;\n\n        auto expected_typeB = GGML_TYPE_Q8_K;\n    switch (typeA) {\n        case GGML_TYPE_Q2_K:\n            assert (ne00 % QK_K == 0);\n            MulMat::set_functions<DequantizerQ2K>(mm);\n            break;\n        case GGML_TYPE_Q3_K:\n            assert (ne00 % QK_K == 0);\n            MulMat::set_functions<DequantizerQ3K>(mm);\n            break;\n        case GGML_TYPE_Q4_K:\n            assert (ne00 % QK_K == 0);\n            MulMat::set_functions<DequantizerQ4K>(mm);\n            break;\n        case GGML_TYPE_Q5_K:\n            assert (ne00 % QK_K == 0);\n            MulMat::set_functions<DequantizerQ5K>(mm);\n            break;\n        case GGML_TYPE_Q6_K:\n            assert (ne00 % QK_K == 0);\n            MulMat::set_functions<DequantizerQ6K>(mm);\n            break;\n        case GGML_TYPE_IQ4_XS:\n            assert (ne00 % QK_K == 0);\n            MulMat::set_functions<DequantizerIQ4XS>(mm);\n            break;\n        case GGML_TYPE_IQ2_XXS:\n            assert (ne00 % QK_K == 0);\n            MulMat::set_functions<DequantizerIQ2XXS>(mm);\n            break;\n        case GGML_TYPE_Q4_0:\n            assert (ne00 % QK4_0 == 0);\n            MulMat::set_functions<Q4_0_Unpacker>(mm);\n            expected_typeB = GGML_TYPE_Q8_0;\n            break;\n        case GGML_TYPE_Q4_1:\n            assert (ne00 % QK4_1 == 0);\n            MulMat::set_functions<Q4_1_Unpacker>(mm);\n            expected_typeB = GGML_TYPE_Q8_1_X4;\n            break;\n        case GGML_TYPE_Q5_0:\n            assert (ne00 % QK5_0 == 0);\n            MulMat::set_functions<Q5_0_Unpacker>(mm);\n            expected_typeB = GGML_TYPE_Q8_0;\n            break;\n        case GGML_TYPE_Q5_1:\n            assert (ne00 % QK5_1 == 0);\n            MulMat::set_functions<Q5_1_Unpacker>(mm);\n            expected_typeB = GGML_TYPE_Q8_1_X4;\n            break;\n        case GGML_TYPE_Q8_0:\n            assert (ne00 % QK8_0 == 0);\n#ifdef HAVE_FANCY_SIMD\n            
MulMat::set_functions<Q8_0_1_Unpacker>(mm);\n            expected_typeB = GGML_TYPE_Q8_1_X4;\n#else\n            MulMat::set_functions<Q8_0_Unpacker>(mm);\n            expected_typeB = GGML_TYPE_Q8_0_X4;\n#endif\n            break;\n        case GGML_TYPE_IQ1_S:\n            mm.funcs[0] = mul_mat_iq1_s_q8_K<1>;\n            mm.funcs[1] = mul_mat_iq1_s_q8_K<2>;\n            mm.funcs[2] = mul_mat_iq1_s_q8_K<3>;\n            mm.funcs[3] = mul_mat_iq1_s_q8_K<4>;\n            mm.funcs[4] = mul_mat_iq1_s_q8_K<5>;\n            mm.funcs[5] = mul_mat_iq1_s_q8_K<6>;\n            mm.funcs[6] = mul_mat_iq1_s_q8_K<7>;\n            mm.funcs[7] = mul_mat_iq1_s_q8_K<8>;\n        #ifdef HAVE_FANCY_SIMD\n             mm.func16 = mul_mat_iq1_s_q8_K<16>;\n        #endif\n       // row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00);\n              expected_typeB = GGML_TYPE_Q8_K;\n            break;\n\n        default:\n        {\n            printf(\"case:%d\",typeA);\n            return false;\n        }\n            \n    }\n\n\n\n    return ggml_type(typeB) == expected_typeB;\n\n}\n\n} // namespace\n\n/*\nNote: iq1_s is not supported on ARM (no NEON kernel is provided for it below).\n*/\n#else   // __aarch64__\n#include <arm_neon.h>\n\nnamespace {\n\ntemplate <int nrc, typename block_q8 = block_q8_K> struct Q8 {\n\n    constexpr static int nrc_y = nrc;\n\n    Q8(const DataInfo& info) {\n        for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy);\n    }\n\n    inline int8x16_t load_quants_16(int iy, int i, int j) const { return vld1q_s8(y[iy][i].qs + 16*j); }\n    inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy][i].qs + 32*j); }\n    inline int8x16x4_t load_quants_64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy][i].qs + 64*j); }\n    inline int16x8x2_t load_bsums(int iy, int i) const { return vld1q_s16_x2(y[iy][i].bsums); }\n    inline int16x8_t load_bsums8(int iy, int i) const {\n        auto q8s = vld1q_s16_x2(y[iy][i].bsums);\n        return
vpaddq_s16(q8s.val[0], q8s.val[1]);\n    }\n    inline float scale(int iy, int i) const { return y[iy][i].d; }\n\n    const block_q8 * y[nrc_y];\n};\n\ntemplate <typename block_q>\nstruct BaseDequantizer {\n    BaseDequantizer(const void * vx, size_t bx, int nrc) : vx(vx), x(nullptr), bx(bx), nrc(nrc) {}\n    inline void new_row(int ix) { x = (const block_q *)((const char *)vx + ix*bx); }\n    const void * vx;\n    const block_q * x;\n    const size_t bx;\n    const int nrc;\n};\n\nstruct Q4bits {\n    const uint8x16_t m4b = vdupq_n_u8(0xf);\n    uint8x16x4_t b1, b2;\n    inline void prepare4(uint8x16x4_t& b, const uint8x16_t * val) const {\n        b.val[0] = vandq_u8(val[0], m4b);\n        b.val[2] = vshrq_n_u8(val[0], 4);\n        b.val[1] = vandq_u8(val[1], m4b);\n        b.val[3] = vshrq_n_u8(val[1], 4);\n    }\n    inline void prepare4_16(uint8x16x4_t& b, const uint8x16_t * val) const {\n        b.val[0] = vandq_u8(val[0], m4b);\n        b.val[1] = vshrq_n_u8(val[0], 4);\n        b.val[2] = vandq_u8(val[1], m4b);\n        b.val[3] = vshrq_n_u8(val[1], 4);\n    }\n    inline void prepare(const uint8_t * qs) {\n        auto q4bits = vld1q_u8_x2(qs);\n        prepare4(b1, q4bits.val);\n        q4bits = vld1q_u8_x2(qs+32);\n        prepare4(b2, q4bits.val);\n    }\n    inline void prepare_v2(const uint8_t * qs) {\n        auto q4bits = vld1q_u8_x4(qs);\n        prepare4(b1, q4bits.val+0);\n        prepare4(b2, q4bits.val+2);\n    }\n    inline void prepare64(const uint8_t * qs) {\n        auto q4bits = vld1q_u8_x4(qs);\n        b1.val[0] = vandq_u8(q4bits.val[0], m4b);\n        b1.val[1] = vandq_u8(q4bits.val[1], m4b);\n        b1.val[2] = vandq_u8(q4bits.val[2], m4b);\n        b1.val[3] = vandq_u8(q4bits.val[3], m4b);\n        b2.val[0] = vshrq_n_u8(q4bits.val[0], 4);\n        b2.val[1] = vshrq_n_u8(q4bits.val[1], 4);\n        b2.val[2] = vshrq_n_u8(q4bits.val[2], 4);\n        b2.val[3] = vshrq_n_u8(q4bits.val[3], 4);\n    }\n    inline void prepare16(const 
uint8_t * qs) {\n        auto q4bits = vld1q_u8_x2(qs);\n        prepare4_16(b1, q4bits.val);\n        q4bits = vld1q_u8_x2(qs+32);\n        prepare4_16(b2, q4bits.val);\n    }\n    inline void prepare16_v2(const uint8_t * qs) {\n        auto q4bits = vld1q_u8_x4(qs);\n        prepare4_16(b1, q4bits.val+0);\n        prepare4_16(b2, q4bits.val+2);\n    }\n};\n\nstruct Scales8 {\n    uint32_t utmp[4];\n    const uint8_t * sc8 = (const uint8_t *)utmp;\n    template <typename Q8, typename Qx>\n    inline int32x4x2_t process_scales_mins(const Qx& x, const Q8& q8, int i, float32x4_t * acc) {\n        make_q4_scales(x.scales, utmp);\n        int16x8_t mins = vmovl_s8(vld1_s8((const int8_t *)sc8 + 8));\n        accum_mins_8(mins, q8, acc, i, -GGML_FP16_TO_FP32(x.dmin));\n\n        uint8x8_t scales8 = vld1_u8(sc8);\n        uint16x8_t scales16 = vmovl_u8(scales8);\n        int32x4x2_t scales = {vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales16))),\n                              vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales16)))};\n        return scales;\n    }\n};\n\n\nstruct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {\n    DequantizerQ4K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 8; }\n    constexpr static bool should_scale_quants() { return false; }\n\n    template <typename Q8>\n    inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        return s8.process_scales_mins(x[i], q8, i, acc);\n    }\n    inline void prepare(int i, int j) {\n        if (nrc == 1) bits.prepare_v2(x[i].qs+64*j);\n        else bits.prepare(x[i].qs+64*j);\n    }\n\n    Q4bits bits;\n    Scales8 s8;\n\n    float d;\n};\n\n\nstruct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {\n    DequantizerQ6K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 16; 
}\n    constexpr static bool should_scale_quants() { return false; }\n\n    template <typename Q8>\n    inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        return process_scales_mins_16(vld1q_s8(x[i].scales), q8, acc, i, -32.f*d);\n    }\n    inline void prepare(int i, int j) {\n\n        auto hbits = vld1q_u8_x2(x[i].qh + 32*j);\n\n        bits.prepare64(x[i].ql+64*j);\n        bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), mhb));\n        bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), mhb));\n        bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 2), mhb));\n        bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 2), mhb));\n\n        bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], mhb));\n        bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], mhb));\n        bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 2), mhb));\n        bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 2), mhb));\n\n    }\n\n    Q4bits bits;\n\n    const uint8x16_t mhb = vdupq_n_u8(0x30);\n\n    float d;\n};\n\ntemplate <typename Dequantizer>\nstruct BlockQxK {\n    inline BlockQxK(const int maxn, const int maxk): maxn(maxn), maxk(maxk) {\n        values = (int8_t*)aligned_alloc(256, maxn * maxk * sizeof(int8_t));\n        scales = (int*)aligned_alloc(256,    maxn * maxk / SS * sizeof(int));\n        ds     = (float*)aligned_alloc(256,  maxn * maxk / QK * sizeof(int));\n        if constexpr (NeedSum) {\n            dmins = (float*)aligned_alloc(256, maxn * maxk / QK * sizeof(int));\n            scalems = (int16_t*)aligned_alloc(256, maxn * maxk / SS * sizeof(int16_t));\n        }\n    }\n    inline ~BlockQxK() {\n        free(values);\n        free(scales);\n        free(ds);\n        if constexpr (NeedSum) {\n      
      free(dmins);\n            free(scalems);\n        }\n    }\n    inline int FromDequantizer(const void * vx, size_t bx, int idx, int n_, int k_) {\n        n = n_;\n        k = k_;\n        bn = n / BS;\n        bk = k / QK;\n\n        Dequantizer deq(vx, bx, 1);\n        for (int i = 0; i < n; i += BS) {\n            for (int j = 0; j < BS; j ++) {\n                deq.new_row(j + i + idx);\n                for (int x = 0; x < bk; x ++) {\n                    {\n                        int8x16_t base = NeedSum ? vdupq_n_s8(0) : vdupq_n_s8(32);\n                        int32_t *dst = (int32_t*)(values + i*k + j*4 + x*QK*BS);\n                        deq.prepare(x, 0);\n                        int8x16_t v0 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b1.val[0]), base);\n                        int8x16_t v1 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b1.val[1]), base);\n                        int8x16_t v2 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b1.val[2]), base);\n                        int8x16_t v3 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b1.val[3]), base);\n                        *(dst + (0 + 0*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 0);\n                        *(dst + (1 + 0*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 1);\n                        *(dst + (2 + 0*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 2);\n                        *(dst + (3 + 0*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 3);\n                        *(dst + (0 + 1*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 0);\n                        *(dst + (1 + 1*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 1);\n                        *(dst + (2 + 1*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 2);\n                        *(dst + (3 + 1*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 3);\n                        *(dst + (0 + 2*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v2), 0);\n                    
    *(dst + (1 + 2*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v2), 1);\n                        *(dst + (2 + 2*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v2), 2);\n                        *(dst + (3 + 2*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v2), 3);\n                        *(dst + (0 + 3*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 0);\n                        *(dst + (1 + 3*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 1);\n                        *(dst + (2 + 3*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 2);\n                        *(dst + (3 + 3*4 + 0*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 3);\n                        v0 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b2.val[0]), base);\n                        v1 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b2.val[1]), base);\n                        v2 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b2.val[2]), base);\n                        v3 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b2.val[3]), base);\n                        *(dst + (0 + 0*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 0);\n                        *(dst + (1 + 0*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 1);\n                        *(dst + (2 + 0*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 2);\n                        *(dst + (3 + 0*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 3);\n                        *(dst + (0 + 1*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 0);\n                        *(dst + (1 + 1*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 1);\n                        *(dst + (2 + 1*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 2);\n                        *(dst + (3 + 1*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 3);\n                        *(dst + (0 + 2*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v2), 0);\n                        *(dst + (1 + 2*4 + 1*16)*BS) = 
vgetq_lane_s32(vreinterpretq_s32_s8(v2), 1);\n                        *(dst + (2 + 2*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v2), 2);\n                        *(dst + (3 + 2*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v2), 3);\n                        *(dst + (0 + 3*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 0);\n                        *(dst + (1 + 3*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 1);\n                        *(dst + (2 + 3*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 2);\n                        *(dst + (3 + 3*4 + 1*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 3);\n                        deq.prepare(x, 1);\n                        v0 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b1.val[0]), base);\n                        v1 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b1.val[1]), base);\n                        v2 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b1.val[2]), base);\n                        v3 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b1.val[3]), base);\n                        *(dst + (0 + 0*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 0);\n                        *(dst + (1 + 0*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 1);\n                        *(dst + (2 + 0*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 2);\n                        *(dst + (3 + 0*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 3);\n                        *(dst + (0 + 1*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 0);\n                        *(dst + (1 + 1*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 1);\n                        *(dst + (2 + 1*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 2);\n                        *(dst + (3 + 1*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 3);\n                        *(dst + (0 + 2*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v2), 0);\n                        *(dst + (1 + 2*4 + 
2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v2), 1);\n                        *(dst + (2 + 2*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v2), 2);\n                        *(dst + (3 + 2*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v2), 3);\n                        *(dst + (0 + 3*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 0);\n                        *(dst + (1 + 3*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 1);\n                        *(dst + (2 + 3*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 2);\n                        *(dst + (3 + 3*4 + 2*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 3);\n                        v0 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b2.val[0]), base);\n                        v1 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b2.val[1]), base);\n                        v2 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b2.val[2]), base);\n                        v3 = vsubq_s8(vreinterpretq_s8_u8(deq.bits.b2.val[3]), base);\n                        *(dst + (0 + 0*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 0);\n                        *(dst + (1 + 0*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 1);\n                        *(dst + (2 + 0*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 2);\n                        *(dst + (3 + 0*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v0), 3);\n                        *(dst + (0 + 1*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 0);\n                        *(dst + (1 + 1*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 1);\n                        *(dst + (2 + 1*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 2);\n                        *(dst + (3 + 1*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v1), 3);\n                        *(dst + (0 + 2*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v2), 0);\n                        *(dst + (1 + 2*4 + 3*16)*BS) = 
vgetq_lane_s32(vreinterpretq_s32_s8(v2), 1);\n                        *(dst + (2 + 2*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v2), 2);\n                        *(dst + (3 + 2*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v2), 3);\n                        *(dst + (0 + 3*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 0);\n                        *(dst + (1 + 3*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 1);\n                        *(dst + (2 + 3*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 2);\n                        *(dst + (3 + 3*4 + 3*16)*BS) = vgetq_lane_s32(vreinterpretq_s32_s8(v3), 3);\n                    }\n                    if constexpr (std::is_same_v<DequantizerQ6K, Dequantizer>)\n                    {\n                        int32_t *dst = (int32_t*)(scales + i*(k/SS) + j + x*QK/SS*BS);\n                        int8x16_t ss = vld1q_s8(deq.x[x].scales);\n                        int16x8_t s16_0 = vmovl_s8(vget_low_s8(ss));\n                        int16x8_t s16_1 = vmovl_s8(vget_high_s8(ss));\n                        int32x4_t s32_0 = vmovl_s16(vget_low_s16(s16_0));\n                        int32x4_t s32_1 = vmovl_s16(vget_high_s16(s16_0));\n                        int32x4_t s32_2 = vmovl_s16(vget_low_s16(s16_1));\n                        int32x4_t s32_3 = vmovl_s16(vget_high_s16(s16_1));\n                        *(dst + (0+0*4)*BS) = vgetq_lane_s32(s32_0, 0);\n                        *(dst + (1+0*4)*BS) = vgetq_lane_s32(s32_0, 1);\n                        *(dst + (2+0*4)*BS) = vgetq_lane_s32(s32_0, 2);\n                        *(dst + (3+0*4)*BS) = vgetq_lane_s32(s32_0, 3);\n                        *(dst + (0+1*4)*BS) = vgetq_lane_s32(s32_1, 0);\n                        *(dst + (1+1*4)*BS) = vgetq_lane_s32(s32_1, 1);\n                        *(dst + (2+1*4)*BS) = vgetq_lane_s32(s32_1, 2);\n                        *(dst + (3+1*4)*BS) = vgetq_lane_s32(s32_1, 3);\n                        *(dst + 
(0+2*4)*BS) = vgetq_lane_s32(s32_2, 0);\n                        *(dst + (1+2*4)*BS) = vgetq_lane_s32(s32_2, 1);\n                        *(dst + (2+2*4)*BS) = vgetq_lane_s32(s32_2, 2);\n                        *(dst + (3+2*4)*BS) = vgetq_lane_s32(s32_2, 3);\n                        *(dst + (0+3*4)*BS) = vgetq_lane_s32(s32_3, 0);\n                        *(dst + (1+3*4)*BS) = vgetq_lane_s32(s32_3, 1);\n                        *(dst + (2+3*4)*BS) = vgetq_lane_s32(s32_3, 2);\n                        *(dst + (3+3*4)*BS) = vgetq_lane_s32(s32_3, 3);\n                    }\n                    if constexpr (std::is_same_v<DequantizerQ4K, Dequantizer>)\n                    {\n                        int32_t *dst = (int32_t*)(scales + i*(k/SS) + j + x*QK/SS*BS);\n                        int16_t *dst2 = (int16_t*)(scalems + i*(k/SS) + j + x*QK/SS*BS);\n                        uint32_t utmp[4];\n                        const uint8_t * sc8 = (const uint8_t *)utmp;\n                        make_q4_scales(deq.x[x].scales, utmp);\n                        int8x16_t ss = vld1q_s8((const int8_t *)sc8);\n                        int16x8_t scale = vmovl_s8(vget_low_s8(ss));\n                        int16x8_t scale_min = vmovl_high_s8(ss);\n                        int32x4_t s32_0 = vmovl_s16(vget_low_s16(scale));\n                        int32x4_t s32_1 = vmovl_s16(vget_high_s16(scale));\n                        *(dst + (0+0*4)*BS) = vgetq_lane_s32(s32_0, 0);\n                        *(dst + (1+0*4)*BS) = vgetq_lane_s32(s32_0, 1);\n                        *(dst + (2+0*4)*BS) = vgetq_lane_s32(s32_0, 2);\n                        *(dst + (3+0*4)*BS) = vgetq_lane_s32(s32_0, 3);\n                        *(dst + (0+1*4)*BS) = vgetq_lane_s32(s32_1, 0);\n                        *(dst + (1+1*4)*BS) = vgetq_lane_s32(s32_1, 1);\n                        *(dst + (2+1*4)*BS) = vgetq_lane_s32(s32_1, 2);\n                        *(dst + (3+1*4)*BS) = vgetq_lane_s32(s32_1, 3);\n                        
*(dst2 + 0*BS) = vgetq_lane_s16(scale_min, 0);\n                        *(dst2 + 1*BS) = vgetq_lane_s16(scale_min, 1);\n                        *(dst2 + 2*BS) = vgetq_lane_s16(scale_min, 2);\n                        *(dst2 + 3*BS) = vgetq_lane_s16(scale_min, 3);\n                        *(dst2 + 4*BS) = vgetq_lane_s16(scale_min, 4);\n                        *(dst2 + 5*BS) = vgetq_lane_s16(scale_min, 5);\n                        *(dst2 + 6*BS) = vgetq_lane_s16(scale_min, 6);\n                        *(dst2 + 7*BS) = vgetq_lane_s16(scale_min, 7);\n                    }\n                    {\n                        float *dst = ds + i*bk + j + x*BS;\n                        *dst = GGML_FP16_TO_FP32(deq.x[x].d);\n                    }\n                    if constexpr (std::is_same_v<DequantizerQ4K, Dequantizer>)\n                    {\n                        float *dst = dmins + i*bk + j + x*BS;\n                        *dst = - GGML_FP16_TO_FP32(deq.x[x].dmin);\n                    }\n                }\n            }\n        }\n        return 0;\n    }\n\n    int8_t *values;     // [bn][k/4][BS][4]\n    int    *scales;     // [bn][k/SS][BS]\n    float  *ds;         // [bn][bk][BS]\n    float  *dmins;      // [bn][bk][BS]\n    int16_t *scalems;   // [bn][k/SS][BS]\n\n    static constexpr int BS = 8;\n    static constexpr int QK = 256;\n    static constexpr int SS = std::is_same_v<Dequantizer, DequantizerQ6K> ? 16 : 32;\n    static constexpr int NeedSum = std::is_same_v<Dequantizer, DequantizerQ6K> ? 
0 : 1;\n    const int maxn;\n    const int maxk;\n    int n;\n    int k;\n    int bn;\n    int bk;\n};\n\ntemplate <typename Dequantizer, int BN>\nIQK_NOINLINE void matmul_v2_kernel(const Dequantizer *a, const block_q8_K *y[BN], const DataInfo &info, int idx, int idy) {\n    constexpr int BS = a->BS;\n    constexpr int QK = a->QK;\n    constexpr int SS = a->SS;\n    for (int s = 0; s < a->n; s += BS) {\n        float32x4_t cc[BN][BS/4];\n        for (int i = 0; i < BN; i ++) {\n            for (int j = 0; j < BS/4; j ++) {\n                cc[i][j] = vdupq_n_f32(0);\n            }\n        }\n        const int8_t *a_ptr = a->values + s*a->k;\n        const int8_t *b_ptr[BN];\n        for (int k = 0; k < a->bk; k ++) {\n            for (int i = 0; i < BN; i ++) {\n                b_ptr[i] = y[i][k].qs;\n            }\n            int32x4_t cci[BN][BS/4];\n            if constexpr (BN == 4 && SS == 16) {\n                int64_t length = QK/SS;\n                auto ap = a_ptr;\n                auto sp = a->scales + s*a->k/SS + (k*QK/SS)*BS;\n                // asm volatile (\n                asm volatile (\n                    \" eor    %[c00].16b, %[c00].16b, %[c00].16b \\n\"\n                    \" eor    %[c10].16b, %[c10].16b, %[c10].16b \\n\"\n                    \" eor    %[c20].16b, %[c20].16b, %[c20].16b \\n\"\n                    \" eor    %[c30].16b, %[c30].16b, %[c30].16b \\n\"\n                    \" eor    %[c01].16b, %[c01].16b, %[c01].16b \\n\"\n                    \" eor    %[c11].16b, %[c11].16b, %[c11].16b \\n\"\n                    \" eor    %[c21].16b, %[c21].16b, %[c21].16b \\n\"\n                    \" eor    %[c31].16b, %[c31].16b, %[c31].16b \\n\"\n                    \" loop_%=: \\n\"\n                    \" subs   %[len], %[len], #1 \\n\"\n                    \" ld1    {v12.16b}, [%[bp0]], #16 \\n\"\n                    \" ld1    {v13.16b}, [%[bp1]], #16 \\n\"\n                    \" ld1    {v14.16b}, [%[bp2]], #16 \\n\"\n                   
 \" ld1    {v15.16b}, [%[bp3]], #16 \\n\"\n                    \" prfm   pldl1strm, [%[ap], #256] \\n\"\n                    \" ld1    {v8.16b},  [%[ap]], #16 \\n\"\n                    \" ld1    {v9.16b},  [%[ap]], #16 \\n\"\n                    \" eor    v0.16b, v0.16b, v0.16b \\n\"\n                    \" eor    v1.16b, v1.16b, v1.16b \\n\"\n                    \" eor    v2.16b, v2.16b, v2.16b \\n\"\n                    \" eor    v3.16b, v3.16b, v3.16b \\n\"\n                    \" eor    v4.16b, v4.16b, v4.16b \\n\"\n                    \" eor    v5.16b, v5.16b, v5.16b \\n\"\n                    \" eor    v6.16b, v6.16b, v6.16b \\n\"\n                    \" eor    v7.16b, v7.16b, v7.16b \\n\"\n                    \" ld1    {v10.16b}, [%[ap]], #16 \\n\"\n                    \" ld1    {v11.16b}, [%[ap]], #16 \\n\"\n                    \" sdot   v0.4s, v8.16b,  v12.4b[0] \\n\"\n                    \" sdot   v1.4s, v8.16b,  v13.4b[0] \\n\"\n                    \" sdot   v2.4s, v8.16b,  v14.4b[0] \\n\"\n                    \" sdot   v3.4s, v8.16b,  v15.4b[0] \\n\"\n                    \" sdot   v4.4s, v9.16b,  v12.4b[0] \\n\"\n                    \" sdot   v5.4s, v9.16b,  v13.4b[0] \\n\"\n                    \" sdot   v6.4s, v9.16b,  v14.4b[0] \\n\"\n                    \" sdot   v7.4s, v9.16b,  v15.4b[0] \\n\"\n                    \" prfm   pldl1strm, [%[ap], #256] \\n\"\n                    \" ld1    {v8.16b},  [%[ap]], #16 \\n\"\n                    \" ld1    {v9.16b},  [%[ap]], #16 \\n\"\n                    \" sdot   v0.4s, v10.16b, v12.4b[1] \\n\"\n                    \" sdot   v1.4s, v10.16b, v13.4b[1] \\n\"\n                    \" sdot   v2.4s, v10.16b, v14.4b[1] \\n\"\n                    \" sdot   v3.4s, v10.16b, v15.4b[1] \\n\"\n                    \" sdot   v4.4s, v11.16b, v12.4b[1] \\n\"\n                    \" sdot   v5.4s, v11.16b, v13.4b[1] \\n\"\n                    \" sdot   v6.4s, v11.16b, v14.4b[1] \\n\"\n                    \" sdot   v7.4s, 
v11.16b, v15.4b[1] \\n\"\n                    \" ld1    {v10.16b}, [%[ap]], #16 \\n\"\n                    \" ld1    {v11.16b}, [%[ap]], #16 \\n\"\n                    \" sdot   v0.4s, v8.16b,  v12.4b[2] \\n\"\n                    \" sdot   v1.4s, v8.16b,  v13.4b[2] \\n\"\n                    \" sdot   v2.4s, v8.16b,  v14.4b[2] \\n\"\n                    \" sdot   v3.4s, v8.16b,  v15.4b[2] \\n\"\n                    \" sdot   v4.4s, v9.16b,  v12.4b[2] \\n\"\n                    \" sdot   v5.4s, v9.16b,  v13.4b[2] \\n\"\n                    \" sdot   v6.4s, v9.16b,  v14.4b[2] \\n\"\n                    \" sdot   v7.4s, v9.16b,  v15.4b[2] \\n\"\n                    \" ld1    {v8.4s}, [%[sp]], #16 \\n\"\n                    \" ld1    {v9.4s}, [%[sp]], #16 \\n\"\n                    \" sdot   v0.4s, v10.16b, v12.4b[3] \\n\"\n                    \" sdot   v1.4s, v10.16b, v13.4b[3] \\n\"\n                    \" sdot   v2.4s, v10.16b, v14.4b[3] \\n\"\n                    \" sdot   v3.4s, v10.16b, v15.4b[3] \\n\"\n                    \" sdot   v4.4s, v11.16b, v12.4b[3] \\n\"\n                    \" sdot   v5.4s, v11.16b, v13.4b[3] \\n\"\n                    \" sdot   v6.4s, v11.16b, v14.4b[3] \\n\"\n                    \" sdot   v7.4s, v11.16b, v15.4b[3] \\n\"\n                    \" mla    %[c00].4s, v0.4s, v8.4s \\n\"\n                    \" mla    %[c10].4s, v1.4s, v8.4s \\n\"\n                    \" mla    %[c20].4s, v2.4s, v8.4s \\n\"\n                    \" mla    %[c30].4s, v3.4s, v8.4s \\n\"\n                    \" mla    %[c01].4s, v4.4s, v9.4s \\n\"\n                    \" mla    %[c11].4s, v5.4s, v9.4s \\n\"\n                    \" mla    %[c21].4s, v6.4s, v9.4s \\n\"\n                    \" mla    %[c31].4s, v7.4s, v9.4s \\n\"\n                    \" bne    loop_%= \\n\"\n                    \" exit_%=:\\n\"\n                    : [len]    \"+r\" (length)\n                    , [ap]     \"+r\" (ap)\n                    , [bp0]    \"+r\" (b_ptr[0])\n             
       , [bp1]    \"+r\" (b_ptr[1])\n                    , [bp2]    \"+r\" (b_ptr[2])\n                    , [bp3]    \"+r\" (b_ptr[3])\n                    , [sp]     \"+r\" (sp)\n                    , [c00]    \"+w\" (cci[0][0])\n                    , [c10]    \"+w\" (cci[1][0])\n                    , [c20]    \"+w\" (cci[2][0])\n                    , [c30]    \"+w\" (cci[3][0])\n                    , [c01]    \"+w\" (cci[0][1])\n                    , [c11]    \"+w\" (cci[1][1])\n                    , [c21]    \"+w\" (cci[2][1])\n                    , [c31]    \"+w\" (cci[3][1])\n                    :\n                    : \"v0\",  \"v1\",  \"v2\",  \"v3\"\n                    , \"v4\",  \"v5\",  \"v6\",  \"v7\"\n                    , \"v8\",  \"v9\",  \"v10\", \"v11\"\n                    , \"v12\", \"v13\", \"v14\", \"v15\"\n                    , \"memory\", \"cc\"\n                );\n                a_ptr += BS * QK;\n            } else if constexpr (BN == 4 && SS == 32) {\n                int64_t length = QK/SS;\n                auto ap = a_ptr;\n                auto sp = a->scales + s*a->k/SS + (k*QK/SS)*BS;\n                // asm volatile (\n                asm volatile (\n                    \" eor    %[c00].16b, %[c00].16b, %[c00].16b \\n\"\n                    \" eor    %[c10].16b, %[c10].16b, %[c10].16b \\n\"\n                    \" eor    %[c20].16b, %[c20].16b, %[c20].16b \\n\"\n                    \" eor    %[c30].16b, %[c30].16b, %[c30].16b \\n\"\n                    \" eor    %[c01].16b, %[c01].16b, %[c01].16b \\n\"\n                    \" eor    %[c11].16b, %[c11].16b, %[c11].16b \\n\"\n                    \" eor    %[c21].16b, %[c21].16b, %[c21].16b \\n\"\n                    \" eor    %[c31].16b, %[c31].16b, %[c31].16b \\n\"\n                    \" loop_%=: \\n\"\n                    \" subs   %[len], %[len], #1 \\n\"\n                    \" ld1    {v12.16b}, [%[bp0]], #16 \\n\"\n                    \" ld1    {v13.16b}, [%[bp1]], #16 \\n\"\n            
        \" ld1    {v14.16b}, [%[bp2]], #16 \\n\"\n                    \" ld1    {v15.16b}, [%[bp3]], #16 \\n\"\n                    \" prfm   pldl1strm, [%[ap], #256] \\n\"\n                    \" ld1    {v8.16b},  [%[ap]], #16 \\n\"\n                    \" ld1    {v9.16b},  [%[ap]], #16 \\n\"\n                    \" eor    v0.16b, v0.16b, v0.16b \\n\"\n                    \" eor    v1.16b, v1.16b, v1.16b \\n\"\n                    \" eor    v2.16b, v2.16b, v2.16b \\n\"\n                    \" eor    v3.16b, v3.16b, v3.16b \\n\"\n                    \" eor    v4.16b, v4.16b, v4.16b \\n\"\n                    \" eor    v5.16b, v5.16b, v5.16b \\n\"\n                    \" eor    v6.16b, v6.16b, v6.16b \\n\"\n                    \" eor    v7.16b, v7.16b, v7.16b \\n\"\n                    \" ld1    {v10.16b}, [%[ap]], #16 \\n\"\n                    \" ld1    {v11.16b}, [%[ap]], #16 \\n\"\n                    \" sdot   v0.4s, v8.16b,  v12.4b[0] \\n\"\n                    \" sdot   v1.4s, v8.16b,  v13.4b[0] \\n\"\n                    \" sdot   v2.4s, v8.16b,  v14.4b[0] \\n\"\n                    \" sdot   v3.4s, v8.16b,  v15.4b[0] \\n\"\n                    \" sdot   v4.4s, v9.16b,  v12.4b[0] \\n\"\n                    \" sdot   v5.4s, v9.16b,  v13.4b[0] \\n\"\n                    \" sdot   v6.4s, v9.16b,  v14.4b[0] \\n\"\n                    \" sdot   v7.4s, v9.16b,  v15.4b[0] \\n\"\n                    \" prfm   pldl1strm, [%[ap], #256] \\n\"\n                    \" ld1    {v8.16b},  [%[ap]], #16 \\n\"\n                    \" ld1    {v9.16b},  [%[ap]], #16 \\n\"\n                    \" sdot   v0.4s, v10.16b, v12.4b[1] \\n\"\n                    \" sdot   v1.4s, v10.16b, v13.4b[1] \\n\"\n                    \" sdot   v2.4s, v10.16b, v14.4b[1] \\n\"\n                    \" sdot   v3.4s, v10.16b, v15.4b[1] \\n\"\n                    \" sdot   v4.4s, v11.16b, v12.4b[1] \\n\"\n                    \" sdot   v5.4s, v11.16b, v13.4b[1] \\n\"\n                    \" sdot   
v6.4s, v11.16b, v14.4b[1] \\n\"\n                    \" sdot   v7.4s, v11.16b, v15.4b[1] \\n\"\n                    \" ld1    {v10.16b}, [%[ap]], #16 \\n\"\n                    \" ld1    {v11.16b}, [%[ap]], #16 \\n\"\n                    \" sdot   v0.4s, v8.16b,  v12.4b[2] \\n\"\n                    \" sdot   v1.4s, v8.16b,  v13.4b[2] \\n\"\n                    \" sdot   v2.4s, v8.16b,  v14.4b[2] \\n\"\n                    \" sdot   v3.4s, v8.16b,  v15.4b[2] \\n\"\n                    \" sdot   v4.4s, v9.16b,  v12.4b[2] \\n\"\n                    \" sdot   v5.4s, v9.16b,  v13.4b[2] \\n\"\n                    \" sdot   v6.4s, v9.16b,  v14.4b[2] \\n\"\n                    \" sdot   v7.4s, v9.16b,  v15.4b[2] \\n\"\n                    \" prfm   pldl1strm, [%[ap], #256] \\n\"\n                    \" ld1    {v8.16b},  [%[ap]], #16 \\n\"\n                    \" ld1    {v9.16b},  [%[ap]], #16 \\n\"\n                    \" sdot   v0.4s, v10.16b, v12.4b[3] \\n\"\n                    \" sdot   v1.4s, v10.16b, v13.4b[3] \\n\"\n                    \" sdot   v2.4s, v10.16b, v14.4b[3] \\n\"\n                    \" sdot   v3.4s, v10.16b, v15.4b[3] \\n\"\n                    \" sdot   v4.4s, v11.16b, v12.4b[3] \\n\"\n                    \" sdot   v5.4s, v11.16b, v13.4b[3] \\n\"\n                    \" sdot   v6.4s, v11.16b, v14.4b[3] \\n\"\n                    \" sdot   v7.4s, v11.16b, v15.4b[3] \\n\"\n                    \" ld1    {v10.16b}, [%[ap]], #16 \\n\"\n                    \" ld1    {v11.16b}, [%[ap]], #16 \\n\"\n                    \" ld1    {v12.16b}, [%[bp0]], #16 \\n\"\n                    \" ld1    {v13.16b}, [%[bp1]], #16 \\n\"\n                    \" ld1    {v14.16b}, [%[bp2]], #16 \\n\"\n                    \" ld1    {v15.16b}, [%[bp3]], #16 \\n\"\n                    \" sdot   v0.4s, v8.16b,  v12.4b[0] \\n\"\n                    \" sdot   v1.4s, v8.16b,  v13.4b[0] \\n\"\n                    \" sdot   v2.4s, v8.16b,  v14.4b[0] \\n\"\n                    \" sdot   
v3.4s, v8.16b,  v15.4b[0] \\n\"\n                    \" sdot   v4.4s, v9.16b,  v12.4b[0] \\n\"\n                    \" sdot   v5.4s, v9.16b,  v13.4b[0] \\n\"\n                    \" sdot   v6.4s, v9.16b,  v14.4b[0] \\n\"\n                    \" sdot   v7.4s, v9.16b,  v15.4b[0] \\n\"\n                    \" prfm   pldl1strm, [%[ap], #256] \\n\"\n                    \" ld1    {v8.16b},  [%[ap]], #16 \\n\"\n                    \" ld1    {v9.16b},  [%[ap]], #16 \\n\"\n                    \" sdot   v0.4s, v10.16b, v12.4b[1] \\n\"\n                    \" sdot   v1.4s, v10.16b, v13.4b[1] \\n\"\n                    \" sdot   v2.4s, v10.16b, v14.4b[1] \\n\"\n                    \" sdot   v3.4s, v10.16b, v15.4b[1] \\n\"\n                    \" sdot   v4.4s, v11.16b, v12.4b[1] \\n\"\n                    \" sdot   v5.4s, v11.16b, v13.4b[1] \\n\"\n                    \" sdot   v6.4s, v11.16b, v14.4b[1] \\n\"\n                    \" sdot   v7.4s, v11.16b, v15.4b[1] \\n\"\n                    \" ld1    {v10.16b}, [%[ap]], #16 \\n\"\n                    \" ld1    {v11.16b}, [%[ap]], #16 \\n\"\n                    \" sdot   v0.4s, v8.16b,  v12.4b[2] \\n\"\n                    \" sdot   v1.4s, v8.16b,  v13.4b[2] \\n\"\n                    \" sdot   v2.4s, v8.16b,  v14.4b[2] \\n\"\n                    \" sdot   v3.4s, v8.16b,  v15.4b[2] \\n\"\n                    \" sdot   v4.4s, v9.16b,  v12.4b[2] \\n\"\n                    \" sdot   v5.4s, v9.16b,  v13.4b[2] \\n\"\n                    \" sdot   v6.4s, v9.16b,  v14.4b[2] \\n\"\n                    \" sdot   v7.4s, v9.16b,  v15.4b[2] \\n\"\n                    \" ld1    {v8.4s}, [%[sp]], #16 \\n\"\n                    \" ld1    {v9.4s}, [%[sp]], #16 \\n\"\n                    \" sdot   v0.4s, v10.16b, v12.4b[3] \\n\"\n                    \" sdot   v1.4s, v10.16b, v13.4b[3] \\n\"\n                    \" sdot   v2.4s, v10.16b, v14.4b[3] \\n\"\n                    \" sdot   v3.4s, v10.16b, v15.4b[3] \\n\"\n                    \" sdot   
v4.4s, v11.16b, v12.4b[3] \\n\"\n                    \" sdot   v5.4s, v11.16b, v13.4b[3] \\n\"\n                    \" sdot   v6.4s, v11.16b, v14.4b[3] \\n\"\n                    \" sdot   v7.4s, v11.16b, v15.4b[3] \\n\"\n                    \" mla    %[c00].4s, v0.4s, v8.4s \\n\"\n                    \" mla    %[c10].4s, v1.4s, v8.4s \\n\"\n                    \" mla    %[c20].4s, v2.4s, v8.4s \\n\"\n                    \" mla    %[c30].4s, v3.4s, v8.4s \\n\"\n                    \" mla    %[c01].4s, v4.4s, v9.4s \\n\"\n                    \" mla    %[c11].4s, v5.4s, v9.4s \\n\"\n                    \" mla    %[c21].4s, v6.4s, v9.4s \\n\"\n                    \" mla    %[c31].4s, v7.4s, v9.4s \\n\"\n                    \" bne    loop_%= \\n\"\n                    \" exit_%=:\\n\"\n                    : [len]    \"+r\" (length)\n                    , [ap]     \"+r\" (ap)\n                    , [bp0]    \"+r\" (b_ptr[0])\n                    , [bp1]    \"+r\" (b_ptr[1])\n                    , [bp2]    \"+r\" (b_ptr[2])\n                    , [bp3]    \"+r\" (b_ptr[3])\n                    , [sp]     \"+r\" (sp)\n                    , [c00]    \"+w\" (cci[0][0])\n                    , [c10]    \"+w\" (cci[1][0])\n                    , [c20]    \"+w\" (cci[2][0])\n                    , [c30]    \"+w\" (cci[3][0])\n                    , [c01]    \"+w\" (cci[0][1])\n                    , [c11]    \"+w\" (cci[1][1])\n                    , [c21]    \"+w\" (cci[2][1])\n                    , [c31]    \"+w\" (cci[3][1])\n                    :\n                    : \"v0\",  \"v1\",  \"v2\",  \"v3\"\n                    , \"v4\",  \"v5\",  \"v6\",  \"v7\"\n                    , \"v8\",  \"v9\",  \"v10\", \"v11\"\n                    , \"v12\", \"v13\", \"v14\", \"v15\"\n                    , \"memory\", \"cc\"\n                );\n                a_ptr += BS * QK;\n            } else\n            {\n                for (int i = 0; i < BN; i ++) {\n                    for (int 
j = 0; j < BS/4; j ++) {\n                        cci[i][j] = vdupq_n_s32(0);\n                    }\n                }\n                for (int k0 = 0; k0 < QK/SS; k0 ++) {\n                    int32x4_t ccv[BN][BS/4];\n                    for (int i = 0; i < BN; i ++) {\n                        for (int j = 0; j < BS/4; j ++) {\n                            ccv[i][j] = vdupq_n_s32(0);\n                        }\n                    }\n                    #pragma unroll\n                    for (int k2 = 0; k2 < SS; k2 += 16) {\n                        const int OFFSET = 256;\n                        __builtin_prefetch((a_ptr + OFFSET + 0*64), 0, 0);\n                        __builtin_prefetch((a_ptr + OFFSET + 1*64), 0, 0);\n\n                        int8x16_t bb[BN];\n                        int8x16_t aa[BS/4];\n                        for (int i = 0; i < BN; i ++) {\n                            bb[i] = vld1q_s8(b_ptr[i]); b_ptr[i] += 16;\n                        }\n                        for (int k1 = 0; k1 < 4; k1 ++) {\n                            for (int i = 0; i < BS/4; i ++) {\n                                aa[i] = vld1q_s8(a_ptr); a_ptr += 16;\n                            }\n                            for (int i = 0; i < BN; i ++) {\n                                for (int j = 0; j < BS/4; j ++) {\n                                    ccv[i][j] = vdotq_laneq_s32(ccv[i][j], aa[j], bb[i], k1);\n                                }\n                            }\n                        }\n                    }\n                    int32x4_t scal[BS/4];\n                    for (int i = 0; i < BS/4; i ++) {\n                        scal[i] = vld1q_s32(a->scales + s*a->k/SS + (k*QK/SS+k0)*BS + i*4);\n                    }\n                    for (int i = 0; i < BN; i ++) {\n                        for (int j = 0; j < BS/4; j ++) {\n                            cci[i][j] = vmlaq_s32(cci[i][j], ccv[i][j], scal[j]);\n                        }\n                 
   }\n                }\n            }\n            float32x4_t scalf[BS/4];\n            for (int i = 0; i < BS/4; i ++) {\n                scalf[i] = vld1q_f32(a->ds + s*a->bk + k*BS + i*4);\n            }\n            for (int i = 0; i < BN; i ++) {\n                for (int j = 0; j < BS/4; j ++) {\n                    cc[i][j] = vfmaq_f32(cc[i][j], vcvtq_f32_s32(cci[i][j]), vmulq_n_f32(scalf[j], y[i][k].d));\n                }\n            }\n        }\n        if constexpr (a->NeedSum) {\n            const int16_t *a_ptr = a->scalems + s*a->k/SS;\n            const int16_t *b_ptr[BN];\n            for (int k = 0; k < a->bk; k ++) {\n                for (int i = 0; i < BN; i ++) {\n                    b_ptr[i] = y[i][k].bsums;\n                }\n                int32x4_t cci[BN][BS/4];\n                for (int i = 0; i < BN; i ++) {\n                    for (int j = 0; j < BS/4; j ++) {\n                        cci[i][j] = vdupq_n_s32(0);\n                    }\n                }\n                for (int k0 = 0; k0 < QK/SS/4; k0 ++) {\n                    int16x8_t bb[BN];\n                    int16x8_t aa[BS/8];\n                    for (int i = 0; i < BN; i ++) {\n                        bb[i] = vld1q_s16(b_ptr[i]); b_ptr[i] += 8;\n                    }\n                    for (int k1 = 0; k1 < 4; k1 ++) {\n                        for (int i = 0; i < BS/8; i ++) {\n                            aa[i] = vld1q_s16(a_ptr); a_ptr += 8;\n                        }\n                        for (int i = 0; i < BN; i ++) {\n                            for (int j = 0; j < BS/8; j ++) {\n                                cci[i][2*j+0] = vmlal_laneq_s16(cci[i][2*j+0], vget_low_s16(aa[j]), bb[i], 2*k1+0);\n                                cci[i][2*j+1] = vmlal_high_laneq_s16(cci[i][2*j+1], aa[j], bb[i], 2*k1+0);\n                                cci[i][2*j+0] = vmlal_laneq_s16(cci[i][2*j+0], vget_low_s16(aa[j]), bb[i], 2*k1+1);\n                                
cci[i][2*j+1] = vmlal_high_laneq_s16(cci[i][2*j+1], aa[j], bb[i], 2*k1+1);\n                            }\n                        }\n                    }\n                }\n                float32x4_t scalf[BS/4];\n                for (int i = 0; i < BS/4; i ++) {\n                    scalf[i] = vld1q_f32(a->dmins + s*a->bk + k*BS + i*4);\n                }\n                for (int i = 0; i < BN; i ++) {\n                    for (int j = 0; j < BS/4; j ++) {\n                        cc[i][j] = vfmaq_f32(cc[i][j], vcvtq_f32_s32(cci[i][j]), vmulq_n_f32(scalf[j], y[i][k].d));\n                    }\n                }\n            }\n        }\n        for (int i = 0; i < BN; i ++) {\n            for (int j = 0; j < BS/4; j ++) {\n                vst1q_f32(info.ptr(j*4+s+idx, i), cc[i][j]);\n            }\n        }\n    }\n    return;\n}\n\ntemplate <typename Dequantizer>\nIQK_NOINLINE void mul_mat_qX_K_q8_K_T_v2(int m, int n, int k, const void * vx, size_t bx, const DataInfo& info) {\n    constexpr int m_step = 64;\n    constexpr int n_step = 4;\n    assert(m%m_step == 0);\n    int n2 = n - (n%n_step);\n    int left = n%n_step;\n    BlockQxK<Dequantizer> xx(m_step, k);\n    for (int i = 0; i < m; i += m_step) {\n        auto this_info = info;\n        int bm = (m - i) < m_step ? 
(m - i) : m_step;\n        xx.FromDequantizer(vx, bx, i, bm, k);\n        for (int j = 0; j < n2; j += n_step) {\n            Q8<n_step, block_q8_K> q8(this_info);\n            matmul_v2_kernel<BlockQxK<Dequantizer>, n_step>(&xx, q8.y, this_info, i, j);\n            this_info.cur_y += n_step;\n        }\n        if (left) {\n            switch (left) {\n                case 1:\n                {\n                    Q8<1, block_q8_K> q8(this_info);\n                    matmul_v2_kernel<BlockQxK<Dequantizer>, 1>(&xx, q8.y, this_info, i, n2);\n                    this_info.cur_y += 1;\n                    break;\n                }\n                case 2:\n                {\n                    Q8<2, block_q8_K> q8(this_info);\n                    matmul_v2_kernel<BlockQxK<Dequantizer>, 2>(&xx, q8.y, this_info, i, n2);\n                    this_info.cur_y += 2;\n                    break;\n                }\n                case 3:\n                {\n                    Q8<3, block_q8_K> q8(this_info);\n                    matmul_v2_kernel<BlockQxK<Dequantizer>, 3>(&xx, q8.y, this_info, i, n2);\n                    this_info.cur_y += 3;\n                    break;\n                }\n            }\n        }\n    }\n    return;\n}\n\ntemplate <typename Q8>\nIQK_ALWAYS_INLINE void compute_8_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8,\n        const int32x4x2_t& scales, int iy, int i, int j, int32x4_t& sumi) {\n    auto mzero = vdupq_n_s32(0);\n    const int8x16_t * qs_1 = (const int8x16_t *)qx_1.val;\n    const int8x16_t * qs_2 = (const int8x16_t *)qx_2.val;\n\n    auto q8b_1 = q8.load_quants(iy, i, 4*j+0);\n    auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qs_1[0], q8b_1.val[0]), qs_1[1], q8b_1.val[1]); // block 1\n    auto q8b_2 = q8.load_quants(iy, i, 4*j+1);\n    auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qs_1[2], q8b_2.val[0]), qs_1[3], q8b_2.val[1]); // block 2\n    auto p12 = vpaddq_s32(p1, p2);\n\n    auto q8b_3 = 
q8.load_quants(iy, i, 4*j+2);\n    auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qs_2[0], q8b_3.val[0]), qs_2[1], q8b_3.val[1]); // block 3\n    auto q8b_4 = q8.load_quants(iy, i, 4*j+3);\n    auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qs_2[2], q8b_4.val[0]), qs_2[3], q8b_4.val[1]); // block 4\n    auto p34 = vpaddq_s32(p3, p4);\n\n    auto pall = vpaddq_s32(p12, p34);\n    sumi = vmlaq_s32(sumi, scales.val[j], pall);\n}\n\ntemplate <typename Q8>\nIQK_ALWAYS_INLINE void compute_8_blocks(const int8x16_t * qx, const Q8& q8,\n        const int32x4_t& scales, int iy, int i, int j, int32x4_t& sumi) {\n    auto mzero = vdupq_n_s32(0);\n\n    auto q8b_1 = q8.load_quants(iy, i, 4*j+0);\n    auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qx[0], q8b_1.val[0]), qx[1], q8b_1.val[1]); // block 1\n    auto q8b_2 = q8.load_quants(iy, i, 4*j+1);\n    auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qx[2], q8b_2.val[0]), qx[3], q8b_2.val[1]); // block 2\n    auto p12 = vpaddq_s32(p1, p2);\n\n    auto q8b_3 = q8.load_quants(iy, i, 4*j+2);\n    auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qx[4], q8b_3.val[0]), qx[5], q8b_3.val[1]); // block 3\n    auto q8b_4 = q8.load_quants(iy, i, 4*j+3);\n    auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qx[6], q8b_4.val[0]), qx[7], q8b_4.val[1]); // block 4\n    auto p34 = vpaddq_s32(p3, p4);\n\n    auto pall = vpaddq_s32(p12, p34);\n    sumi = vmlaq_s32(sumi, scales, pall);\n}\n\ntemplate <typename Q8>\nIQK_ALWAYS_INLINE void compute_16_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8,\n        const int32x4x4_t& scales, int iy, int i, int j, int32x4_t& sumi) {\n\n    auto mzero = vdupq_n_s32(0);\n    auto q8b_1 = q8.load_quants(iy, i, 4*j+0);\n    auto p1 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]),\n                         ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1])); // blocks 0, 0, 1, 1,\n    auto q8b_2 = q8.load_quants(iy, i, 4*j+1);\n    auto p2 = 
vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]),\n                         ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1])); // blocks 2, 2, 3, 3,\n    auto p12 = vpaddq_s32(p1, p2); // blocks 0, 1, 2, 3\n    sumi = vmlaq_s32(sumi, scales.val[2*j+0], p12);\n\n    auto q8b_3 = q8.load_quants(iy, i, 4*j+2);\n    auto p3 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]),\n                         ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1])); // blocks 4, 4, 5, 5,\n    auto q8b_4 = q8.load_quants(iy, i, 4*j+3);\n    auto p4 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]),\n                         ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1])); // blocks 6, 6, 7, 7,\n    auto p34 = vpaddq_s32(p3, p4); // blocks 4, 5, 6, 7\n    sumi = vmlaq_s32(sumi, scales.val[2*j+1], p34);\n}\n\ntemplate <int nrc_y, typename Dequantizer>\nIQK_NOINLINE void mul_mat_qX_K_q8_K_IQ(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n    const int nb = n / QK_K;\n\n    Q8<nrc_y, block_q8_K> q8(info);\n\n    Dequantizer deq(vx, bx, nrc_y);\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        deq.new_row(ix);\n\n        float32x4_t acc[nrc_y];\n        for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);\n\n        for (int i = 0; i < nb; ++i) {\n\n            int32x4_t sumi[nrc_y];\n            for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0);\n\n            if constexpr (Dequantizer::num_blocks() == 8) {\n                auto scales = deq.new_block(i);\n                deq.prepare(i, 0);\n#pragma GCC unroll 8\n                for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);\n                deq.prepare(i, 1);\n#pragma GCC unroll 8\n                for (int iy = 0; iy < nrc_y; ++iy) 
compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);\n            }\n            else if constexpr (Dequantizer::num_blocks() == 16) {\n                auto scales = deq.new_block(i);\n                deq.prepare(i, 0);\n#pragma GCC unroll 8\n                for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);\n                deq.prepare(i, 1);\n#pragma GCC unroll 8\n                for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);\n            }\n            else {\n                GGML_ASSERT(false);\n            }\n#pragma GCC unroll 8\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i)));\n            }\n        }\n#pragma GCC unroll 8\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            info.store(ix, iy, vaddvq_f32(acc[iy]));\n        }\n    }\n}\n\ntemplate <typename Q8>\ninline void accum_mins_8(const int16x8_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) {\n    for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n        auto q8s = q8.load_bsums8(iy, i);\n        int32x4_t b1 = vmull_s16(vget_low_s16(mins), vget_low_s16(q8s));\n        int32x4_t b2 = vmull_s16(vget_high_s16(mins), vget_high_s16(q8s));\n        float32x4_t prod = vcvtq_f32_s32(vaddq_s32(b1, b2));\n        acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i)));\n    }\n}\n\ntemplate <typename Q8>\ninline void accum_mins_16(const int16x8x2_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) {\n    for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n        auto q8s = q8.load_bsums(iy, i);\n        int32x4_t b1 = vmull_s16(vget_low_s16 (mins.val[0]), vget_low_s16 (q8s.val[0]));\n        int32x4_t b2 = vmull_s16(vget_high_s16(mins.val[0]), vget_high_s16(q8s.val[0]));\n        int32x4_t b3 = vmull_s16(vget_low_s16 (mins.val[1]), 
vget_low_s16 (q8s.val[1]));\n        int32x4_t b4 = vmull_s16(vget_high_s16(mins.val[1]), vget_high_s16(q8s.val[1]));\n        float32x4_t prod = vcvtq_f32_s32(vaddq_s32(vaddq_s32(b1, b2), vaddq_s32(b3, b4)));\n        acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i)));\n    }\n}\n\nstruct Q2bits {\n    const uint8x16_t m4b = vdupq_n_u8(0x03);\n    uint8x16x4_t b1, b2;\n    inline void prepare(const uint8_t * qs) {\n        auto q2bits = vld1q_u8_x2(qs);\n        b1.val[0] = vandq_u8(q2bits.val[0], m4b);\n        b1.val[1] = vandq_u8(q2bits.val[1], m4b);\n\n        q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);\n        q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);\n        b1.val[2] = vandq_u8(q2bits.val[0], m4b);\n        b1.val[3] = vandq_u8(q2bits.val[1], m4b);\n\n        q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);\n        q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);\n        b2.val[0] = vandq_u8(q2bits.val[0], m4b);\n        b2.val[1] = vandq_u8(q2bits.val[1], m4b);\n\n        q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);\n        q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);\n        b2.val[2] = vandq_u8(q2bits.val[0], m4b);\n        b2.val[3] = vandq_u8(q2bits.val[1], m4b);\n    }\n};\n\nstruct HighBit5 {\n    const uint8x16_t mhb = vdupq_n_u8(0x10);\n    uint8x16x2_t bits;\n    inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) {\n        b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 4), mhb));\n        b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 4), mhb));\n        b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 3), mhb));\n        b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 3), mhb));\n\n        b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb));\n        b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb));\n        b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb));\n     
   b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb));\n\n        if (do_shift) {\n            bits.val[0] = vshrq_n_u8(bits.val[0], 4);\n            bits.val[1] = vshrq_n_u8(bits.val[1], 4);\n        }\n    }\n};\n\nstruct HighBit3 {\n    const uint8x16_t mhb = vdupq_n_u8(0x04);\n    uint8x16x2_t bits;\n    inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) {\n        b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb));\n        b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb));\n        b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb));\n        b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb));\n\n        b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(bits.val[0], mhb));\n        b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(bits.val[1], mhb));\n        b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshrq_n_u8(bits.val[0], 1), mhb));\n        b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshrq_n_u8(bits.val[1], 1), mhb));\n\n        if (do_shift) {\n            bits.val[0] = vshrq_n_u8(bits.val[0], 4);\n            bits.val[1] = vshrq_n_u8(bits.val[1], 4);\n        }\n    }\n};\n\nstruct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {\n    DequantizerQ5K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 8; }\n    constexpr static bool should_scale_quants() { return false; }\n\n    template <typename Q8>\n    inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        h.bits = vld1q_u8_x2(x[i].qh);\n        return s8.process_scales_mins(x[i], q8, i, acc);\n    }\n    inline void prepare(int i, int j) {\n        bits.prepare(x[i].qs+64*j);\n        h.apply(bits.b1, bits.b2, j == 0);\n    }\n\n    Q4bits bits;\n    HighBit5 h;\n    Scales8 s8;\n\n    uint8x16x2_t hbits;\n\n    float d;\n};\n\ninline 
int32x4x4_t make_wider(const int16x8x2_t& scales16) {\n    int32x4x4_t scales = {\n        vmovl_s16(vget_low_s16 (scales16.val[0])),\n        vmovl_s16(vget_high_s16(scales16.val[0])),\n        vmovl_s16(vget_low_s16 (scales16.val[1])),\n        vmovl_s16(vget_high_s16(scales16.val[1])),\n    };\n    return scales;\n}\n\ntemplate <typename Q8>\ninline int32x4x4_t process_scales_mins_16(const int8x16_t& scales8, const Q8& q8, float32x4_t * acc, int i, float c) {\n    int16x8x2_t scales16;\n    scales16.val[0] = vmovl_s8(vget_low_s8(scales8));\n    scales16.val[1] = vmovl_s8(vget_high_s8(scales8));\n    accum_mins_16(scales16, q8, acc, i, c);\n    return make_wider(scales16);\n}\n\nstruct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {\n    DequantizerQ3K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 16; }\n    constexpr static bool should_scale_quants() { return false; }\n\n    template <typename Q8>\n    inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        h.bits = vld1q_u8_x2(x[i].hmask);\n        const uint16_t * sc16 = (const uint16_t *)x[i].scales;\n        uint32_t aux0 = sc16[0] | (sc16[1] << 16);\n        uint32_t aux1 = sc16[2] | (sc16[3] << 16);\n        uint32_t aux2 = sc16[4] | (sc16[5] << 16);\n        aux32[0] =  (aux0       & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030);\n        aux32[1] =  (aux1       & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030);\n        aux32[2] = ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030);\n        aux32[3] = ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030);\n        return process_scales_mins_16(vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32)), q8, acc, i, -4.f*d);\n    }\n\n    inline void prepare(int i, int j) {\n        bits.prepare(x[i].qs+32*j);\n        h.apply(bits.b1, bits.b2, j == 0);\n    }\n\n    uint32_t aux32[4];\n\n    Q2bits bits;\n\n    
HighBit3 h;\n\n    float d;\n};\n\nstruct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {\n    DequantizerQ2K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 16; }\n    constexpr static bool should_scale_quants() { return true; }\n\n    template <typename Q8>\n    inline void process_scales(int i, const Q8& q8, float32x4_t * acc) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        auto scales_and_mins = vld1q_u8(x[i].scales);\n        auto mins8 = vreinterpretq_s8_u8(vshrq_n_u8(scales_and_mins, 4));\n        int16x8x2_t scales16;\n        scales16.val[0] = vmovl_s8(vget_low_s8(mins8));\n        scales16.val[1] = vmovl_s8(vget_high_s8(mins8));\n        accum_mins_16(scales16, q8, acc, i, -GGML_FP16_TO_FP32(x[i].dmin));\n\n        scales8 = vandq_u8(scales_and_mins, vdupq_n_u8(0xf));\n    }\n\n    template <typename Q8>\n    inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {\n        process_scales(i, q8, acc);\n        int16x8x2_t scales16;\n        scales16.val[0] = vmovl_s8(vget_low_s8(vreinterpretq_s8_u8(scales8)));\n        scales16.val[1] = vmovl_s8(vget_high_s8(vreinterpretq_s8_u8(scales8)));\n        return make_wider(scales16);\n    }\n\n    template <typename Q8>\n    inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) {\n        auto m1 = vdupq_n_u8(1);\n        auto shuffle = vdupq_n_u8(8*j);\n        bits.b1.val[0] = vmulq_u8(bits.b1.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        bits.b1.val[1] = vmulq_u8(bits.b1.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        bits.b1.val[2] = vmulq_u8(bits.b1.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        bits.b1.val[3] = vmulq_u8(bits.b1.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        bits.b2.val[0] = vmulq_u8(bits.b2.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = 
vaddq_u8(shuffle, m1);\n        bits.b2.val[1] = vmulq_u8(bits.b2.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        bits.b2.val[2] = vmulq_u8(bits.b2.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        bits.b2.val[3] = vmulq_u8(bits.b2.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            auto q8b_1 = q8.load_quants(iy, i, 4*j+0);\n            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]),\n                    vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]);\n\n            auto q8b_2 = q8.load_quants(iy, i, 4*j+1);\n            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]),\n                    vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]);\n\n            auto q8b_3 = q8.load_quants(iy, i, 4*j+2);\n            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]),\n                    vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]);\n\n            auto q8b_4 = q8.load_quants(iy, i, 4*j+3);\n            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]),\n                    vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]);\n        }\n    }\n\n    inline void prepare(int i, int j) {\n        bits.prepare(x[i].qs+32*j);\n    }\n\n    uint32_t aux32[4];\n\n    uint8x16_t scales8;\n\n    Q2bits bits;\n\n    float d;\n};\n\nIQK_ALWAYS_INLINE void fusion_mul_mat_qX_K_q8_K_T_y1_d6k(\n    float32x4_t &acc,\n    const uint8_t *x_ql, // [128] 4bit\n    const uint8_t *x_qh, // [64] 2bit\n    const int8_t *x_scale, // [16] 8bit\n    float x_d,\n    const int8_t *y_qs, // [256] 8bit\n    const int16_t *y_bsums, // [16] 16bit\n    float y_d)\n{\n    float c0 = x_d * y_d;\n    float c1 = -32.0f * c0;\n    const int 
OFFSET = 1024;\n    __builtin_prefetch((x_ql + OFFSET + 0*64), 0, 0);\n    __builtin_prefetch((x_ql + OFFSET + 1*64), 0, 0);\n    __builtin_prefetch((x_ql + OFFSET + 2*64), 0, 0);\n\n    int16x8_t scale16_0, scale16_1;\n    {\n        int8x16_t tmp = vld1q_s8(x_scale);\n        scale16_0 = vmovl_s8(vget_low_s8(tmp));\n        scale16_1 = vmovl_high_s8(tmp);\n    }\n    {\n        int16x8_t q8s0 = vld1q_s16(y_bsums + 0);\n        int16x8_t q8s1 = vld1q_s16(y_bsums + 8);\n        int32x4_t b0 = vmull_s16(vget_low_s16(scale16_0), vget_low_s16(q8s0));\n        b0 = vmlal_high_s16(b0, scale16_0, q8s0);\n        b0 = vmlal_s16(b0, vget_low_s16(scale16_1), vget_low_s16(q8s1));\n        b0 = vmlal_high_s16(b0, scale16_1, q8s1);\n        acc = vfmaq_n_f32(acc, vcvtq_f32_s32(b0), c1);\n    }\n    uint8x16_t x0, x1, x2, x3, x4, x5, x6, x7;\n    int32x4_t sumi = vdupq_n_s32(0);\n    {\n        const uint8x16_t m0 = vdupq_n_u8(0x3f);\n        const uint8x16_t m1 = vdupq_n_u8(0x30);\n        const uint8x16_t m2 = vdupq_n_u8(0x0f);\n        x0 = vld1q_u8(x_ql + 0*16 + 0*64);\n        x1 = vld1q_u8(x_ql + 1*16 + 0*64);\n        x2 = vld1q_u8(x_ql + 2*16 + 0*64);\n        x3 = vld1q_u8(x_ql + 3*16 + 0*64);\n        uint8x16_t hbits0 = vld1q_u8(x_qh + 0*16 + 0*32);\n        uint8x16_t hbits1 = vld1q_u8(x_qh + 1*16 + 0*32);\n        x4 = vandq_u8(hbits0, m0);\n        x4 = vsriq_n_u8(x4, x0, 4);\n        x5 = vandq_u8(hbits1, m0);\n        x5 = vsriq_n_u8(x5, x1, 4);\n        x6 = vshrq_n_u8(hbits0, 2);\n        x6 = vsriq_n_u8(x6, x2, 4);\n        x7 = vshrq_n_u8(hbits1, 2);\n        x7 = vsriq_n_u8(x7, x3, 4);\n        x0 = vsliq_n_u8(x0, hbits0, 4);\n        x0 = vandq_u8(x0, m0);\n        x1 = vsliq_n_u8(x1, hbits1, 4);\n        x1 = vandq_u8(x1, m0);\n        hbits0 = vshlq_n_u8(hbits0, 2);\n        hbits0 = vandq_u8(hbits0, m1);\n        x2 = vandq_u8(x2, m2);\n        x2 = vorrq_u8(x2, hbits0);\n        hbits1 = vshlq_n_u8(hbits1, 2);\n        hbits1 = vandq_u8(hbits1, m1);\n  
      x3 = vandq_u8(x3, m2);\n        x3 = vorrq_u8(x3, hbits1);\n    }\n    {\n        int8x16_t base = vdupq_n_s8(32);\n        int8x16_t y0 = vld1q_s8(y_qs + 0*16 + 0*128);\n        int8x16_t y1 = vld1q_s8(y_qs + 1*16 + 0*128);\n        int8x16_t y2 = vld1q_s8(y_qs + 2*16 + 0*128);\n        int8x16_t y3 = vld1q_s8(y_qs + 3*16 + 0*128);\n        int8x16_t y4 = vld1q_s8(y_qs + 4*16 + 0*128);\n        int8x16_t y5 = vld1q_s8(y_qs + 5*16 + 0*128);\n        int8x16_t y6 = vld1q_s8(y_qs + 6*16 + 0*128);\n        int8x16_t y7 = vld1q_s8(y_qs + 7*16 + 0*128);\n        int32x4_t p00 = vdupq_n_s32(0);\n        int32x4_t p01 = vdupq_n_s32(0);\n        int32x4_t p10 = vdupq_n_s32(0);\n        int32x4_t p11 = vdupq_n_s32(0);\n        int32x4_t p20 = vdupq_n_s32(0);\n        int32x4_t p21 = vdupq_n_s32(0);\n        int32x4_t p30 = vdupq_n_s32(0);\n        int32x4_t p31 = vdupq_n_s32(0);\n        p00 = vdotq_s32(p00, vreinterpretq_s8_u8(x0), y0);\n        p01 = vdotq_s32(p01, vreinterpretq_s8_u8(x1), y1);\n        p10 = vdotq_s32(p10, vreinterpretq_s8_u8(x2), y2);\n        p11 = vdotq_s32(p11, vreinterpretq_s8_u8(x3), y3);\n        p20 = vdotq_s32(p20, vreinterpretq_s8_u8(x4), y4);\n        p21 = vdotq_s32(p21, vreinterpretq_s8_u8(x5), y5);\n        p30 = vdotq_s32(p30, vreinterpretq_s8_u8(x6), y6);\n        p31 = vdotq_s32(p31, vreinterpretq_s8_u8(x7), y7);\n        // p00 = vdotq_s32(p00, vsubq_s8(vreinterpretq_s8_u8(x0), base), y0);\n        // p01 = vdotq_s32(p01, vsubq_s8(vreinterpretq_s8_u8(x1), base), y1);\n        // p10 = vdotq_s32(p10, vsubq_s8(vreinterpretq_s8_u8(x2), base), y2);\n        // p11 = vdotq_s32(p11, vsubq_s8(vreinterpretq_s8_u8(x3), base), y3);\n        // p20 = vdotq_s32(p20, vsubq_s8(vreinterpretq_s8_u8(x4), base), y4);\n        // p21 = vdotq_s32(p21, vsubq_s8(vreinterpretq_s8_u8(x5), base), y5);\n        // p30 = vdotq_s32(p30, vsubq_s8(vreinterpretq_s8_u8(x6), base), y6);\n        // p31 = vdotq_s32(p31, vsubq_s8(vreinterpretq_s8_u8(x7), base), 
y7);\n        p00 = vpaddq_s32(p00, p01);\n        p10 = vpaddq_s32(p10, p11);\n        p20 = vpaddq_s32(p20, p21);\n        p30 = vpaddq_s32(p30, p31);\n        p00 = vpaddq_s32(p00, p10);\n        p20 = vpaddq_s32(p20, p30);\n        sumi = vmlaq_s32(sumi, vmovl_s16(vget_low_s16(scale16_0)), p00);\n        sumi = vmlaq_s32(sumi, vmovl_high_s16(scale16_0), p20);\n    }\n    {\n        const uint8x16_t m0 = vdupq_n_u8(0x3f);\n        const uint8x16_t m1 = vdupq_n_u8(0x30);\n        const uint8x16_t m2 = vdupq_n_u8(0x0f);\n        x0 = vld1q_u8(x_ql + 0*16 + 1*64);\n        x1 = vld1q_u8(x_ql + 1*16 + 1*64);\n        x2 = vld1q_u8(x_ql + 2*16 + 1*64);\n        x3 = vld1q_u8(x_ql + 3*16 + 1*64);\n        uint8x16_t hbits0 = vld1q_u8(x_qh + 0*16 + 1*32);\n        uint8x16_t hbits1 = vld1q_u8(x_qh + 1*16 + 1*32);\n        x4 = vandq_u8(hbits0, m0);\n        x4 = vsriq_n_u8(x4, x0, 4);\n        x5 = vandq_u8(hbits1, m0);\n        x5 = vsriq_n_u8(x5, x1, 4);\n        x6 = vshrq_n_u8(hbits0, 2);\n        x6 = vsriq_n_u8(x6, x2, 4);\n        x7 = vshrq_n_u8(hbits1, 2);\n        x7 = vsriq_n_u8(x7, x3, 4);\n        x0 = vsliq_n_u8(x0, hbits0, 4);\n        x0 = vandq_u8(x0, m0);\n        x1 = vsliq_n_u8(x1, hbits1, 4);\n        x1 = vandq_u8(x1, m0);\n        hbits0 = vshlq_n_u8(hbits0, 2);\n        hbits0 = vandq_u8(hbits0, m1);\n        x2 = vandq_u8(x2, m2);\n        x2 = vorrq_u8(x2, hbits0);\n        hbits1 = vshlq_n_u8(hbits1, 2);\n        hbits1 = vandq_u8(hbits1, m1);\n        x3 = vandq_u8(x3, m2);\n        x3 = vorrq_u8(x3, hbits1);\n    }\n    {\n        int8x16_t base = vdupq_n_s8(32);\n        int8x16_t y0 = vld1q_s8(y_qs + 0*16 + 1*128);\n        int8x16_t y1 = vld1q_s8(y_qs + 1*16 + 1*128);\n        int8x16_t y2 = vld1q_s8(y_qs + 2*16 + 1*128);\n        int8x16_t y3 = vld1q_s8(y_qs + 3*16 + 1*128);\n        int8x16_t y4 = vld1q_s8(y_qs + 4*16 + 1*128);\n        int8x16_t y5 = vld1q_s8(y_qs + 5*16 + 1*128);\n        int8x16_t y6 = vld1q_s8(y_qs + 6*16 + 
1*128);\n        int8x16_t y7 = vld1q_s8(y_qs + 7*16 + 1*128);\n        int32x4_t p00 = vdupq_n_s32(0);\n        int32x4_t p01 = vdupq_n_s32(0);\n        int32x4_t p10 = vdupq_n_s32(0);\n        int32x4_t p11 = vdupq_n_s32(0);\n        int32x4_t p20 = vdupq_n_s32(0);\n        int32x4_t p21 = vdupq_n_s32(0);\n        int32x4_t p30 = vdupq_n_s32(0);\n        int32x4_t p31 = vdupq_n_s32(0);\n        p00 = vdotq_s32(p00, vreinterpretq_s8_u8(x0), y0);\n        p01 = vdotq_s32(p01, vreinterpretq_s8_u8(x1), y1);\n        p10 = vdotq_s32(p10, vreinterpretq_s8_u8(x2), y2);\n        p11 = vdotq_s32(p11, vreinterpretq_s8_u8(x3), y3);\n        p20 = vdotq_s32(p20, vreinterpretq_s8_u8(x4), y4);\n        p21 = vdotq_s32(p21, vreinterpretq_s8_u8(x5), y5);\n        p30 = vdotq_s32(p30, vreinterpretq_s8_u8(x6), y6);\n        p31 = vdotq_s32(p31, vreinterpretq_s8_u8(x7), y7);\n        // p00 = vdotq_s32(p00, vsubq_s8(vreinterpretq_s8_u8(x0), base), y0);\n        // p01 = vdotq_s32(p01, vsubq_s8(vreinterpretq_s8_u8(x1), base), y1);\n        // p10 = vdotq_s32(p10, vsubq_s8(vreinterpretq_s8_u8(x2), base), y2);\n        // p11 = vdotq_s32(p11, vsubq_s8(vreinterpretq_s8_u8(x3), base), y3);\n        // p20 = vdotq_s32(p20, vsubq_s8(vreinterpretq_s8_u8(x4), base), y4);\n        // p21 = vdotq_s32(p21, vsubq_s8(vreinterpretq_s8_u8(x5), base), y5);\n        // p30 = vdotq_s32(p30, vsubq_s8(vreinterpretq_s8_u8(x6), base), y6);\n        // p31 = vdotq_s32(p31, vsubq_s8(vreinterpretq_s8_u8(x7), base), y7);\n        p00 = vpaddq_s32(p00, p01);\n        p10 = vpaddq_s32(p10, p11);\n        p20 = vpaddq_s32(p20, p21);\n        p30 = vpaddq_s32(p30, p31);\n        p00 = vpaddq_s32(p00, p10);\n        p20 = vpaddq_s32(p20, p30);\n        sumi = vmlaq_s32(sumi, vmovl_s16(vget_low_s16(scale16_1)), p00);\n        sumi = vmlaq_s32(sumi, vmovl_high_s16(scale16_1), p20);\n    }\n    {\n        acc = vfmaq_n_f32(acc, vcvtq_f32_s32(sumi), c0);\n    }\n    return;\n}\n\nIQK_ALWAYS_INLINE void 
fusion_mul_mat_qX_K_q8_K_T_y1_d4k(\n    float32x4_t &acc,\n    const uint8_t *x_scale, // [12] 8*2*6bits\n    const uint8_t *x_qs, // [128] 256*4bits\n    float x_d,\n    float x_dmin,\n    const int8_t *y_qs, // [256] 8bit\n    const int16_t *y_bsums, // [16] 16bit\n    float y_d)\n{\n    float c0 = x_d * y_d;\n    float c1 = -x_dmin * y_d;\n    const int OFFSET = 1024;\n    __builtin_prefetch((x_scale + OFFSET + 0*64), 0, 0);\n    __builtin_prefetch((x_scale + OFFSET + 1*64), 0, 0);\n\n    int16x8_t scale_min;\n    int16x8_t scale;\n    {\n        uint32_t utmp[4];\n        const uint8_t * sc8 = (const uint8_t *)utmp;\n        make_q4_scales(x_scale, utmp);\n        int8x16_t ss = vld1q_s8((const int8_t *)sc8);\n        scale = vmovl_s8(vget_low_s8(ss));\n        scale_min = vmovl_high_s8(ss);\n    }\n    {\n        int16x8_t q8s0 = vld1q_s16(y_bsums + 0);\n        int16x8_t q8s1 = vld1q_s16(y_bsums + 8);\n        q8s0 = vpaddq_s16(q8s0, q8s1);\n        int32x4_t b0 = vmull_s16(vget_low_s16(scale_min), vget_low_s16(q8s0));\n        b0 = vmlal_high_s16(b0, scale_min, q8s0);\n        acc = vfmaq_n_f32(acc, vcvtq_f32_s32(b0), c1);\n    }\n    int32x4_t sumi = vdupq_n_s32(0);\n    const uint8x16_t m4b = vdupq_n_u8(0x0f);\n    uint8x16_t x0, x1, x2, x3, x4, x5, x6, x7;\n    {\n        x0 = vld1q_u8(x_qs + 0*16 + 0*64);\n        x1 = vld1q_u8(x_qs + 1*16 + 0*64);\n        x4 = vld1q_u8(x_qs + 2*16 + 0*64);\n        x5 = vld1q_u8(x_qs + 3*16 + 0*64);\n        x2 = vshrq_n_u8(x0, 4);\n        x3 = vshrq_n_u8(x1, 4);\n        x6 = vshrq_n_u8(x4, 4);\n        x7 = vshrq_n_u8(x5, 4);\n        x0 = vandq_u8(x0, m4b);\n        x1 = vandq_u8(x1, m4b);\n        x4 = vandq_u8(x4, m4b);\n        x5 = vandq_u8(x5, m4b);\n    }\n    {\n        int8x16_t y0 = vld1q_s8(y_qs + 0*16 + 0*128);\n        int8x16_t y1 = vld1q_s8(y_qs + 1*16 + 0*128);\n        int8x16_t y2 = vld1q_s8(y_qs + 2*16 + 0*128);\n        int8x16_t y3 = vld1q_s8(y_qs + 3*16 + 0*128);\n        int8x16_t y4 = 
vld1q_s8(y_qs + 4*16 + 0*128);\n        int8x16_t y5 = vld1q_s8(y_qs + 5*16 + 0*128);\n        int8x16_t y6 = vld1q_s8(y_qs + 6*16 + 0*128);\n        int8x16_t y7 = vld1q_s8(y_qs + 7*16 + 0*128);\n        int32x4_t p0 = vdupq_n_s32(0);\n        int32x4_t p1 = vdupq_n_s32(0);\n        int32x4_t p2 = vdupq_n_s32(0);\n        int32x4_t p3 = vdupq_n_s32(0);\n        p0 = vdotq_s32(p0, vreinterpretq_s8_u8(x0), y0);\n        p1 = vdotq_s32(p1, vreinterpretq_s8_u8(x2), y2);\n        p2 = vdotq_s32(p2, vreinterpretq_s8_u8(x4), y4);\n        p3 = vdotq_s32(p3, vreinterpretq_s8_u8(x6), y6);\n        p0 = vdotq_s32(p0, vreinterpretq_s8_u8(x1), y1);\n        p1 = vdotq_s32(p1, vreinterpretq_s8_u8(x3), y3);\n        p2 = vdotq_s32(p2, vreinterpretq_s8_u8(x5), y5);\n        p3 = vdotq_s32(p3, vreinterpretq_s8_u8(x7), y7);\n        p0 = vpaddq_s32(p0, p1);\n        p2 = vpaddq_s32(p2, p3);\n        p0 = vpaddq_s32(p0, p2);\n        sumi = vmlaq_s32(sumi, vmovl_s16(vget_low_s16(scale)), p0);\n    }\n    {\n        x0 = vld1q_u8(x_qs + 0*16 + 1*64);\n        x1 = vld1q_u8(x_qs + 1*16 + 1*64);\n        x4 = vld1q_u8(x_qs + 2*16 + 1*64);\n        x5 = vld1q_u8(x_qs + 3*16 + 1*64);\n        x2 = vshrq_n_u8(x0, 4);\n        x3 = vshrq_n_u8(x1, 4);\n        x6 = vshrq_n_u8(x4, 4);\n        x7 = vshrq_n_u8(x5, 4);\n        x0 = vandq_u8(x0, m4b);\n        x1 = vandq_u8(x1, m4b);\n        x4 = vandq_u8(x4, m4b);\n        x5 = vandq_u8(x5, m4b);\n    }\n    {\n        int8x16_t y0 = vld1q_s8(y_qs + 0*16 + 1*128);\n        int8x16_t y1 = vld1q_s8(y_qs + 1*16 + 1*128);\n        int8x16_t y2 = vld1q_s8(y_qs + 2*16 + 1*128);\n        int8x16_t y3 = vld1q_s8(y_qs + 3*16 + 1*128);\n        int8x16_t y4 = vld1q_s8(y_qs + 4*16 + 1*128);\n        int8x16_t y5 = vld1q_s8(y_qs + 5*16 + 1*128);\n        int8x16_t y6 = vld1q_s8(y_qs + 6*16 + 1*128);\n        int8x16_t y7 = vld1q_s8(y_qs + 7*16 + 1*128);\n        int32x4_t p0 = vdupq_n_s32(0);\n        int32x4_t p1 = vdupq_n_s32(0);\n        int32x4_t 
p2 = vdupq_n_s32(0);\n        int32x4_t p3 = vdupq_n_s32(0);\n        p0 = vdotq_s32(p0, vreinterpretq_s8_u8(x0), y0);\n        p1 = vdotq_s32(p1, vreinterpretq_s8_u8(x2), y2);\n        p2 = vdotq_s32(p2, vreinterpretq_s8_u8(x4), y4);\n        p3 = vdotq_s32(p3, vreinterpretq_s8_u8(x6), y6);\n        p0 = vdotq_s32(p0, vreinterpretq_s8_u8(x1), y1);\n        p1 = vdotq_s32(p1, vreinterpretq_s8_u8(x3), y3);\n        p2 = vdotq_s32(p2, vreinterpretq_s8_u8(x5), y5);\n        p3 = vdotq_s32(p3, vreinterpretq_s8_u8(x7), y7);\n        p0 = vpaddq_s32(p0, p1);\n        p2 = vpaddq_s32(p2, p3);\n        p0 = vpaddq_s32(p0, p2);\n        sumi = vmlaq_s32(sumi, vmovl_high_s16(scale), p0);\n    }\n    {\n        acc = vfmaq_n_f32(acc, vcvtq_f32_s32(sumi), c0);\n    }\n}\n\ntemplate <int nrc_y, typename Dequantizer>\nIQK_NOINLINE void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n    const int nb = n / QK_K;\n\n    Q8<nrc_y, block_q8_K> q8(info);\n\n    Dequantizer deq(vx, bx, nrc_y);\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        deq.new_row(ix);\n\n        float32x4_t acc[nrc_y];\n        for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);\n\n//#pragma GCC unroll 4\n        for (int i = 0; i < nb; ++i) {\n#ifdef GEMV_Q6K\n            if constexpr (nrc_y == 1 && std::is_same<Dequantizer, DequantizerQ6K>::value) {\n                fusion_mul_mat_qX_K_q8_K_T_y1_d6k(\n                    acc[0],\n                    deq.x[i].ql,\n                    deq.x[i].qh,\n                    deq.x[i].scales,\n                    GGML_FP16_TO_FP32(deq.x[i].d),\n                    q8.y[0][i].qs,\n                    q8.y[0][i].bsums,\n                    q8.y[0][i].d);\n            } else\n#endif\n#ifdef GEMV_Q4K\n            if constexpr (nrc_y == 1 && std::is_same<Dequantizer, DequantizerQ4K>::value) {\n                fusion_mul_mat_qX_K_q8_K_T_y1_d4k(\n                    acc[0],\n               
     deq.x[i].scales,\n                    deq.x[i].qs,\n                    GGML_FP16_TO_FP32(deq.x[i].d),\n                    GGML_FP16_TO_FP32(deq.x[i].dmin),\n                    q8.y[0][i].qs,\n                    q8.y[0][i].bsums,\n                    q8.y[0][i].d);\n            } else\n#endif\n            {\n                int32x4_t sumi[nrc_y];\n                for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0);\n\n                if constexpr (nrc_y > 1 && Dequantizer::should_scale_quants()) {\n                    deq.process_scales(i, q8, acc);\n                    deq.prepare(i, 0);\n                    deq.compute(q8, i, 0, sumi);\n                    deq.prepare(i, 1);\n                    deq.compute(q8, i, 1, sumi);\n                } else {\n                    if constexpr (Dequantizer::num_blocks() == 8) {\n                        auto scales = deq.new_block(i, q8, acc);\n                        deq.prepare(i, 0);\n#pragma GCC unroll 8\n                        for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);\n                        deq.prepare(i, 1);\n#pragma GCC unroll 8\n                        for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);\n                    }\n                    else if constexpr (Dequantizer::num_blocks() == 16) {\n                        auto scales = deq.new_block(i, q8, acc);\n                        deq.prepare(i, 0);\n#pragma GCC unroll 8\n                        for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);\n                        deq.prepare(i, 1);\n#pragma GCC unroll 8\n                        for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);\n                    }\n                    else {\n                        GGML_ASSERT(false);\n               
     }\n                }\n\n#pragma GCC unroll 8\n                for (int iy = 0; iy < nrc_y; ++iy) {\n                    acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i)));\n                }\n            }\n        }\n\n#pragma GCC unroll 8\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            info.store(ix, iy, vaddvq_f32(acc[iy]));\n        }\n    }\n}\n\n// ============================= i-quants\n\nstruct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {\n\n    static int8x16_t load_values() {\n        static const int8_t iq4nl_values[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};\n        return vld1q_s8(iq4nl_values);\n    }\n\n    DequantizerIQ4XS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(load_values()) {}\n\n    constexpr static int num_blocks() { return 8; }\n    constexpr static bool should_scale_quants() { return false; }\n\n    inline void new_row(int ix) { x = (const block_iq4_xs *)((const char *)vx + bx*ix); }\n\n    template <typename Q8>\n    inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {\n        (void)q8;\n        (void)acc;\n        d = GGML_FP16_TO_FP32(x[i].d);\n        const uint16_t scales_h = x[i].scales_h;\n        const uint16_t * scales_l = (const uint16_t *)x[i].scales_l;\n        aux32[0] = scales_l[0] | (scales_l[1] << 16);\n        aux32[1] = aux32[0] >> 4;\n        // scl is ordered as 0, 2, 4, 6, 1, 3, 5, 7\n        uint8x8_t scl8 = vand_u8(vld1_u8((const uint8_t *)aux32), vdup_n_u8(0xf));\n        uint16_t * aux16 = (uint16_t *)aux32;\n        aux16[0] = scales_h << 4; aux16[1] = scales_h << 2; aux16[2] = scales_h; aux16[3] = scales_h >> 2;\n        // sch is ordered as 0, 4, 1, 5, 2, 6, 3, 7\n        uint8x8_t sch8 = vand_u8(vld1_u8((const uint8_t *)aux16), vdup_n_u8(0x30));\n        int8x8_t scales8 = vadd_s8(vreinterpret_s8_u8(vorr_u8(scl8, vtbl1_u8(sch8, 
vreinterpret_u8_u32(hshuff)))), vdup_n_s8(-32));\n        // shuffle 0, 2, 4, 6, 1, 3, 5, 7 -> 0, 1, 2, 3, 4, 5, 6, 7\n        scales8 = vtbl1_s8(scales8, vreinterpret_s8_u32(hshuff));\n        int16x8_t scales16 = vmovl_s8(scales8);\n        int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};\n        return scales;\n    }\n    inline void prepare(int i, int j) {\n        bits.prepare16(x[i].qs+64*j);\n        for (int k = 0; k < 4; ++k) {\n            bits.b1.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values, bits.b1.val[k]));\n            bits.b2.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values, bits.b2.val[k]));\n        }\n    }\n\n    Q4bits bits;\n    const int8x16_t values;\n    uint32_t aux32[2];\n\n    constexpr static uint32x2_t hshuff = {0x05010400, 0x07030602};\n\n    float d;\n};\n\nstruct SimpleBits {\n    uint8x16x4_t b1;\n    uint8x16x4_t b2;\n};\n\nIQK_ALWAYS_INLINE int32x4x2_t prepare_scales_8(const uint32x4_t& v1, const uint32x4_t& v2) {\n    int32x4x2_t scales;\n    auto one = vdupq_n_u32(1);\n    scales.val[0] = vreinterpretq_s32_u32(vsliq_n_u32(one, vshrq_n_u32(v1, 28), 1));\n    scales.val[1] = vreinterpretq_s32_u32(vsliq_n_u32(one, vshrq_n_u32(v2, 28), 1));\n    return scales;\n}\n\ninline void apply_signs_2(uint8x16_t * b, const uint64_t * signs, uint32_t sidx) {\n    auto s1 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >> 0) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >> 7) & 127))));\n    auto s2 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >>14) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >>21) & 127))));\n    b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s1));\n    b[1] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[1]), s2));\n}\n\nIQK_ALWAYS_INLINE int32x4_t prepare_scales_8(const uint32x4_t& v1) {\n    return vreinterpretq_s32_u32(vsliq_n_u32(vdupq_n_u32(1), vshrq_n_u32(v1, 28), 1));\n}\n\nstruct DequantizerIQ2XXS final : public 
BaseDequantizer<block_iq2_xxs> {\n    DequantizerIQ2XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    IQK_ALWAYS_INLINE float new_block(int i) const { return 0.125f * GGML_FP16_TO_FP32(x[i].d); }\n\n    inline int32x4_t unpack(int i, int j, uint8x16_t * q) const {\n        auto data = vld1q_u32_x2((const uint32_t *)(x[i].qs + 16*j));\n        prepare_all(data, q);\n        return prepare_scales_8(vuzp2q_u32(data.val[0], data.val[1]));\n    }\n\nprivate:\n\n    static inline void prepare2(uint8x16_t * b, const uint32_t * bits, const uint64_t * signs) {\n        const uint8_t * idx = (const uint8_t *)bits;\n        b[0] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[0]], iq2xxs_grid[idx[1]]});\n        b[1] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[2]], iq2xxs_grid[idx[3]]});\n        apply_signs_2(b, signs, bits[1]);\n    }\n\n    inline static void prepare_all(const uint32x4x2_t& data, uint8x16_t * quants) {\n        const uint32_t * q2 = (const uint32_t *)data.val;\n        prepare2(quants+0, q2+0, keven_signs);\n        prepare2(quants+2, q2+2, keven_signs);\n        prepare2(quants+4, q2+4, keven_signs);\n        prepare2(quants+6, q2+6, keven_signs);\n    }\n};\n\ninline int32x4x4_t prepare_4bit_scales16(const uint8_t * sc) {\n    auto aux = vld1_u8(sc);\n    auto scales_l = vand_u8(aux, vdup_n_u8(0xf));\n    auto scales_h = vshr_n_u8(aux, 4);\n    auto aux1 = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h));\n\n    auto scales8 = vreinterpretq_s8_u8(vorrq_u8(vshlq_n_u8(aux1, 1), vdupq_n_u8(1)));\n    int16x8x2_t scales16 = { vmovl_s8(vget_low_s8(scales8)), vmovl_s8(vget_high_s8(scales8)) };\n    return make_wider(scales16);\n}\n\nstruct DequantizerIQ2XS final : public BaseDequantizer<block_iq2_xs> {\n    DequantizerIQ2XS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 16; }\n    constexpr static bool should_scale_quants() { 
return false; }\n\n    SimpleBits bits;\n    float d;\n\n    inline int32x4x4_t new_block(int i) {\n        d = 0.125f * GGML_FP16_TO_FP32(x[i].d);\n        prepare_internal(i, 0);\n        return prepare_4bit_scales16(x[i].scales);\n    }\n\n    inline void prepare(int i, int j) {\n        if (j == 1) prepare_internal(i, 1);\n    }\n\nprivate:\n\n    static void make2(const uint16_t * qs, uint8x16_t * b) {\n        auto v1 = vcombine_s8(vld1_s8((const int8_t *)(iq2xs_grid + (qs[0] & 511))), vld1_s8((const int8_t *)(iq2xs_grid + (qs[1] & 511))));\n        auto v2 = vcombine_s8(vld1_s8((const int8_t *)(iq2xs_grid + (qs[2] & 511))), vld1_s8((const int8_t *)(iq2xs_grid + (qs[3] & 511))));\n        auto s1 = vcombine_s8(vld1_s8((const int8_t *)(keven_signs + (qs[0] >> 9))), vld1_s8((const int8_t *)(keven_signs + (qs[1] >> 9))));\n        auto s2 = vcombine_s8(vld1_s8((const int8_t *)(keven_signs + (qs[2] >> 9))), vld1_s8((const int8_t *)(keven_signs + (qs[3] >> 9))));\n        b[0] = vreinterpretq_u8_s8(vmulq_s8(v1, s1));\n        b[1] = vreinterpretq_u8_s8(vmulq_s8(v2, s2));\n    }\n\n    inline static void make4(const uint16_t * qs, uint8x16_t * b) {\n        make2(qs + 0, b + 0);\n        make2(qs + 4, b + 2);\n    }\n\n    IQK_ALWAYS_INLINE void prepare_internal(int i, int j) {\n        make4(x[i].qs + 16*j + 0, bits.b1.val);\n        make4(x[i].qs + 16*j + 8, bits.b2.val);\n    }\n\n};\n\n// So, I hate to include this table, but with the GCC 12.3 compiler\n// bundled in the Cosmopolitan tools, loading the unpacked sign bytes\n// from this table using the packed 8 sign bits as index is faster than\n// using the standard trick of vceqq_u8(vandq_u8(bits, mask), mask) to\n// expand the bits to bytes.\nstatic const uint64_t kall_signs[256] = {\n    0x0101010101010101, 0x01010101010101ff, 0x010101010101ff01, 0x010101010101ffff,\n    0x0101010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0x0101010101ffffff,\n    0x01010101ff010101, 0x01010101ff0101ff, 
0x01010101ff01ff01, 0x01010101ff01ffff,\n    0x01010101ffff0101, 0x01010101ffff01ff, 0x01010101ffffff01, 0x01010101ffffffff,\n    0x010101ff01010101, 0x010101ff010101ff, 0x010101ff0101ff01, 0x010101ff0101ffff,\n    0x010101ff01ff0101, 0x010101ff01ff01ff, 0x010101ff01ffff01, 0x010101ff01ffffff,\n    0x010101ffff010101, 0x010101ffff0101ff, 0x010101ffff01ff01, 0x010101ffff01ffff,\n    0x010101ffffff0101, 0x010101ffffff01ff, 0x010101ffffffff01, 0x010101ffffffffff,\n    0x0101ff0101010101, 0x0101ff01010101ff, 0x0101ff010101ff01, 0x0101ff010101ffff,\n    0x0101ff0101ff0101, 0x0101ff0101ff01ff, 0x0101ff0101ffff01, 0x0101ff0101ffffff,\n    0x0101ff01ff010101, 0x0101ff01ff0101ff, 0x0101ff01ff01ff01, 0x0101ff01ff01ffff,\n    0x0101ff01ffff0101, 0x0101ff01ffff01ff, 0x0101ff01ffffff01, 0x0101ff01ffffffff,\n    0x0101ffff01010101, 0x0101ffff010101ff, 0x0101ffff0101ff01, 0x0101ffff0101ffff,\n    0x0101ffff01ff0101, 0x0101ffff01ff01ff, 0x0101ffff01ffff01, 0x0101ffff01ffffff,\n    0x0101ffffff010101, 0x0101ffffff0101ff, 0x0101ffffff01ff01, 0x0101ffffff01ffff,\n    0x0101ffffffff0101, 0x0101ffffffff01ff, 0x0101ffffffffff01, 0x0101ffffffffffff,\n    0x01ff010101010101, 0x01ff0101010101ff, 0x01ff01010101ff01, 0x01ff01010101ffff,\n    0x01ff010101ff0101, 0x01ff010101ff01ff, 0x01ff010101ffff01, 0x01ff010101ffffff,\n    0x01ff0101ff010101, 0x01ff0101ff0101ff, 0x01ff0101ff01ff01, 0x01ff0101ff01ffff,\n    0x01ff0101ffff0101, 0x01ff0101ffff01ff, 0x01ff0101ffffff01, 0x01ff0101ffffffff,\n    0x01ff01ff01010101, 0x01ff01ff010101ff, 0x01ff01ff0101ff01, 0x01ff01ff0101ffff,\n    0x01ff01ff01ff0101, 0x01ff01ff01ff01ff, 0x01ff01ff01ffff01, 0x01ff01ff01ffffff,\n    0x01ff01ffff010101, 0x01ff01ffff0101ff, 0x01ff01ffff01ff01, 0x01ff01ffff01ffff,\n    0x01ff01ffffff0101, 0x01ff01ffffff01ff, 0x01ff01ffffffff01, 0x01ff01ffffffffff,\n    0x01ffff0101010101, 0x01ffff01010101ff, 0x01ffff010101ff01, 0x01ffff010101ffff,\n    0x01ffff0101ff0101, 0x01ffff0101ff01ff, 0x01ffff0101ffff01, 0x01ffff0101ffffff,\n    
0x01ffff01ff010101, 0x01ffff01ff0101ff, 0x01ffff01ff01ff01, 0x01ffff01ff01ffff,\n    0x01ffff01ffff0101, 0x01ffff01ffff01ff, 0x01ffff01ffffff01, 0x01ffff01ffffffff,\n    0x01ffffff01010101, 0x01ffffff010101ff, 0x01ffffff0101ff01, 0x01ffffff0101ffff,\n    0x01ffffff01ff0101, 0x01ffffff01ff01ff, 0x01ffffff01ffff01, 0x01ffffff01ffffff,\n    0x01ffffffff010101, 0x01ffffffff0101ff, 0x01ffffffff01ff01, 0x01ffffffff01ffff,\n    0x01ffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0x01ffffffffffffff,\n    0xff01010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0xff0101010101ffff,\n    0xff01010101ff0101, 0xff01010101ff01ff, 0xff01010101ffff01, 0xff01010101ffffff,\n    0xff010101ff010101, 0xff010101ff0101ff, 0xff010101ff01ff01, 0xff010101ff01ffff,\n    0xff010101ffff0101, 0xff010101ffff01ff, 0xff010101ffffff01, 0xff010101ffffffff,\n    0xff0101ff01010101, 0xff0101ff010101ff, 0xff0101ff0101ff01, 0xff0101ff0101ffff,\n    0xff0101ff01ff0101, 0xff0101ff01ff01ff, 0xff0101ff01ffff01, 0xff0101ff01ffffff,\n    0xff0101ffff010101, 0xff0101ffff0101ff, 0xff0101ffff01ff01, 0xff0101ffff01ffff,\n    0xff0101ffffff0101, 0xff0101ffffff01ff, 0xff0101ffffffff01, 0xff0101ffffffffff,\n    0xff01ff0101010101, 0xff01ff01010101ff, 0xff01ff010101ff01, 0xff01ff010101ffff,\n    0xff01ff0101ff0101, 0xff01ff0101ff01ff, 0xff01ff0101ffff01, 0xff01ff0101ffffff,\n    0xff01ff01ff010101, 0xff01ff01ff0101ff, 0xff01ff01ff01ff01, 0xff01ff01ff01ffff,\n    0xff01ff01ffff0101, 0xff01ff01ffff01ff, 0xff01ff01ffffff01, 0xff01ff01ffffffff,\n    0xff01ffff01010101, 0xff01ffff010101ff, 0xff01ffff0101ff01, 0xff01ffff0101ffff,\n    0xff01ffff01ff0101, 0xff01ffff01ff01ff, 0xff01ffff01ffff01, 0xff01ffff01ffffff,\n    0xff01ffffff010101, 0xff01ffffff0101ff, 0xff01ffffff01ff01, 0xff01ffffff01ffff,\n    0xff01ffffffff0101, 0xff01ffffffff01ff, 0xff01ffffffffff01, 0xff01ffffffffffff,\n    0xffff010101010101, 0xffff0101010101ff, 0xffff01010101ff01, 0xffff01010101ffff,\n    0xffff010101ff0101, 0xffff010101ff01ff, 
0xffff010101ffff01, 0xffff010101ffffff,\n    0xffff0101ff010101, 0xffff0101ff0101ff, 0xffff0101ff01ff01, 0xffff0101ff01ffff,\n    0xffff0101ffff0101, 0xffff0101ffff01ff, 0xffff0101ffffff01, 0xffff0101ffffffff,\n    0xffff01ff01010101, 0xffff01ff010101ff, 0xffff01ff0101ff01, 0xffff01ff0101ffff,\n    0xffff01ff01ff0101, 0xffff01ff01ff01ff, 0xffff01ff01ffff01, 0xffff01ff01ffffff,\n    0xffff01ffff010101, 0xffff01ffff0101ff, 0xffff01ffff01ff01, 0xffff01ffff01ffff,\n    0xffff01ffffff0101, 0xffff01ffffff01ff, 0xffff01ffffffff01, 0xffff01ffffffffff,\n    0xffffff0101010101, 0xffffff01010101ff, 0xffffff010101ff01, 0xffffff010101ffff,\n    0xffffff0101ff0101, 0xffffff0101ff01ff, 0xffffff0101ffff01, 0xffffff0101ffffff,\n    0xffffff01ff010101, 0xffffff01ff0101ff, 0xffffff01ff01ff01, 0xffffff01ff01ffff,\n    0xffffff01ffff0101, 0xffffff01ffff01ff, 0xffffff01ffffff01, 0xffffff01ffffffff,\n    0xffffffff01010101, 0xffffffff010101ff, 0xffffffff0101ff01, 0xffffffff0101ffff,\n    0xffffffff01ff0101, 0xffffffff01ff01ff, 0xffffffff01ffff01, 0xffffffff01ffffff,\n    0xffffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0xffffffffff01ffff,\n    0xffffffffffff0101, 0xffffffffffff01ff, 0xffffffffffffff01, 0xffffffffffffffff,\n};\n\nstruct SignHelper {\n\n    IQK_ALWAYS_INLINE void apply_signs_1x(uint8x16_t * b, const uint8_t * sign_bits) const {\n        auto s = vreinterpretq_s8_u64(uint64x2_t{kall_signs[sign_bits[0]], kall_signs[sign_bits[1]]});\n        // Normally we would expect this to be faster, but it isn't.\n        // auto aux = vcombine_u8(vdup_n_u8(sign_bits[0]), vdup_n_u8(sign_bits[1]));\n        // auto s = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux, smask), smask), m1));\n        b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s));\n    }\n\n    // We would need these two if we weren't loading from the unpacked sign table.\n    //const uint8x16_t smask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201));\n    //const uint8x16_t m1    = 
vdupq_n_u8(1);\n};\n\nstruct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {\n    DequantizerIQ2S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 16; }\n    constexpr static bool should_scale_quants() { return false; }\n\n    SimpleBits bits;\n    float d;\n\n    inline int32x4x4_t new_block(int i) {\n        d = 0.125f * GGML_FP16_TO_FP32(x[i].d);\n        prepare_internal(i, 0, bits);\n        return prepare_4bit_scales16(x[i].scales);\n    }\n\n    inline void prepare(int i, int j) {\n        if (j == 1) prepare_internal(i, 1, bits);\n    }\n\nprivate:\n\n    static void make4(const SignHelper& sh, const uint8_t * sign_bits, const uint8_t * qs, const uint8_t * qh, uint8x16_t * b) {\n        uint32_t aux32[2];\n        const uint16_t * aux16 = (const uint16_t *)aux32;\n        for (int k = 0; k < 2; ++k) {\n            aux32[1] = (qh[k] << 4) | (qh[k] << 18);\n            aux32[0] = (aux32[1] << 4) & 0x03000300;\n            aux32[1] &= 0x03000300;\n            b[2*k+0] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+0] | aux16[0]))),\n                                   vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+1] | aux16[1]))));\n            b[2*k+1] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+2] | aux16[2]))),\n                                   vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+3] | aux16[3]))));\n            sh.apply_signs_1x(b+2*k+0, sign_bits); sign_bits += 2;\n            sh.apply_signs_1x(b+2*k+1, sign_bits); sign_bits += 2;\n        }\n    }\n\n    void prepare_internal(int i, int j, SimpleBits& sb) {\n\n        const auto * qs = x[i].qs + 16*j;\n        const auto * qh = x[i].qh + 4*j;\n        const auto * sign_bits = qs + QK_K/8;\n\n        make4(sh, sign_bits+0, qs+0, qh+0, sb.b1.val);\n        make4(sh, sign_bits+8, qs+8, qh+2, sb.b2.val);\n    }\n\n    SignHelper sh;\n};\n\nstruct DequantizerIQ3XXS final : public 
BaseDequantizer<block_iq3_xxs> {\n    DequantizerIQ3XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    IQK_ALWAYS_INLINE float new_block(int i) const { return 0.25f * GGML_FP16_TO_FP32(x[i].d); }\n\n    inline int32x4_t unpack(int i, int j, uint8x16_t * q) const {\n        auto q3data = vld1q_u8_x2(x[i].qs + 32*j);\n        auto gas = vld1q_u32((const uint32_t *)(x[i].qs + QK_K/4 + 16*j));\n        prepare_block((const uint8_t *)q3data.val, (const uint32_t *)&gas, q);\n        return prepare_scales_8(gas);\n    }\n\nprivate:\n\n    inline static void make2(const uint8_t * q3, const uint32_t sidx, uint8x16_t * b) {\n        b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[0]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[3]]});\n        b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[4]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[7]]});\n        apply_signs_2(b, keven_signs, sidx);\n    }\n    inline static void prepare_block(const uint8_t * q3, const uint32_t * signs, uint8x16_t * quants) {\n        make2(q3+ 0, signs[0], quants + 0);\n        make2(q3+ 8, signs[1], quants + 2);\n        make2(q3+16, signs[2], quants + 4);\n        make2(q3+24, signs[3], quants + 6);\n    }\n};\n\nstruct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {\n    DequantizerIQ3S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 8; }\n    constexpr static bool should_scale_quants() { return false; }\n\n    SimpleBits bits;\n    float d;\n\n    inline int32x4x2_t new_block(int i) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        uint32_t scales32[2];\n        auto qs = vld1q_u8_x2(x[i].qs);\n        auto signs = vld1q_u8(x[i].signs);\n\n        prepare_block((const uint8_t *)qs.val, x[i].qh, (const uint8_t *)&signs);\n\n        std::memcpy(scales32, x[i].scales, 4);\n        scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 
0x01010101;\n        scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101;\n        auto scales8 = vld1_u8((const uint8_t *)scales32); // 0, 2, 4, 6, 1, 3, 5, 7\n        scales8 = vtbl1_u8(scales8, vreinterpret_u8_u64(vdup_n_u64(0x0703060205010400)));\n        auto scales16 = vreinterpretq_s16_u16(vmovl_u8(scales8));\n        int32x4x2_t scales;\n        scales.val[0] = vmovl_s16(vget_low_s16(scales16));\n        scales.val[1] = vmovl_s16(vget_high_s16(scales16));\n        return scales;\n    }\n\n    inline void prepare(int i, int j) {\n        if (j == 1) {\n            auto qs = vld1q_u8_x2(x[i].qs + 32);\n            auto signs = vld1q_u8(x[i].signs + 16);\n            prepare_block((const uint8_t *)qs.val, x[i].qh + 4, (const uint8_t *)&signs);\n        }\n    }\n\nprivate:\n\n    static inline void make2(const SignHelper& sh, const uint8_t * sign_bits, const uint16x8_t& idx_l, uint8_t qh,\n            const int16x8_t& hshift, uint8x16_t * b) {\n        auto vindex = vorrq_u16(idx_l, vandq_u16(vshlq_u16(vdupq_n_u16(qh), hshift), vdupq_n_u16(256)));\n        const uint16_t * idx = (const uint16_t *)&vindex;\n        b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[0]], iq3s_grid[idx[1]], iq3s_grid[idx[2]], iq3s_grid[idx[3]]});\n        sh.apply_signs_1x(b+0, sign_bits+0);\n        b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[4]], iq3s_grid[idx[5]], iq3s_grid[idx[6]], iq3s_grid[idx[7]]});\n        sh.apply_signs_1x(b+1, sign_bits+2);\n    }\n    static inline void make4(const SignHelper& sh, const uint8_t * sign_bits, const uint8_t * qs, const uint8_t * qh,\n            const int16x8_t& hshift, uint8x16_t * b) {\n        auto idx_l = vld1q_u8(qs);\n        make2(sh, sign_bits+0, vmovl_u8(vget_low_u8 (idx_l)), qh[0], hshift, b+0);\n        make2(sh, sign_bits+4, vmovl_u8(vget_high_u8(idx_l)), qh[1], hshift, b+2);\n    }\n\n    static int16x8_t load_shift() {\n        static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1};\n        return 
vld1q_s16(k_shift);\n    }\n\n    inline void prepare_block(const uint8_t * qs, const uint8_t * qh, const uint8_t * sign_bits) {\n        auto signs = vld1q_u8(sign_bits);\n        auto s = (const uint8_t *)&signs;\n        make4(sh, s + 0, qs+ 0, qh+0, hshift, bits.b1.val);\n        make4(sh, s + 8, qs+16, qh+2, hshift, bits.b2.val);\n    }\n\n    SignHelper sh;\n    const int16x8_t hshift = load_shift();\n\n};\n\ntemplate <int nrc_y, typename Dequantizer>\nIQK_NOINLINE void mul_mat_qX_K_q8_K_IQXXS(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n    const int nb = n / QK_K;\n\n    Q8<nrc_y, block_q8_K> q8(info);\n    Dequantizer deq(vx, bx, nrc_y);\n    uint8x16_t  qx[8];\n    int32x4_t   sumi[nrc_y];\n    float32x4_t acc[nrc_y];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        deq.new_row(ix);\n        for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);\n\n        for (int i = 0; i < nb; ++i) {\n            float d = deq.new_block(i);\n            auto scales = deq.unpack(i, 0, qx);\n#pragma GCC unroll 8\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                sumi[iy] = vdupq_n_s32(0);\n                compute_8_blocks((const int8x16_t *)qx, q8, scales, iy, i, 0, sumi[iy]);\n            }\n            scales = deq.unpack(i, 1, qx);\n#pragma GCC unroll 8\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                compute_8_blocks((const int8x16_t *)qx, q8, scales, iy, i, 1, sumi[iy]);\n                acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*q8.scale(iy, i)), vcvtq_f32_s32(sumi[iy]));\n            }\n        }\n#pragma GCC unroll 8\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            info.store(ix, iy, vaddvq_f32(acc[iy]));\n        }\n    }\n}\n\n// =========================================== Legacy quants\n\ntemplate <typename Block>\ninline float16x4_t load_scales_q0(const Block * x, ggml_half * aux) {\n    for (int k = 0; k < 4; ++k) aux[k] = x[k].d;\n    return 
vld1_f16((const float16_t *)aux);\n}\n\ntemplate <typename Block>\ninline float16x8_t load_scales_q1(const Block * x, ggml_half * aux) {\n    if constexpr (std::is_same_v<Block, block_q8_1>) {\n        for (int k = 0; k < 4; ++k) { aux[k] = x[k].d; aux[k+4] = x[k].s; }\n    } else {\n        for (int k = 0; k < 4; ++k) { aux[k] = x[k].d; aux[k+4] = x[k].m; }\n    }\n    return vld1q_f16((const float16_t *)aux);\n}\n\nstruct Q4LegacyBits {\n    template <typename Block>\n    inline void prepare(const Block * x) {\n        for (int i = 0; i < 4; ++i) {\n            auto q4bits = vld1q_u8(x[i].qs);\n            b[2*i+0] = vreinterpretq_s8_u8(vandq_u8(q4bits, m4b));\n            b[2*i+1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits, 4));\n        }\n    }\n    inline void prepare1(const uint8_t * qs, int8x16_t * q) const {\n        auto q4bits = vld1q_u8(qs);\n        q[0] = vreinterpretq_s8_u8(vandq_u8(q4bits, m4b));\n        q[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits, 4));\n    }\n    inline void prepare1(const uint8_t * qs) {\n        prepare1(qs, b);\n    }\n    const uint8x16_t m4b = vdupq_n_u8(0xf);\n    int8x16_t b[8];\n};\n\n// One would think this commented out version would do better than the one below\n// because it offers more opportunities to execute instructions in parallel.\n// Instead, it runs significantly slower. Why? 
If the compiler is running out of vector registers\n// cannot it just do the sequential version below on its own?\n//inline int32x4_t sum_4_blocks(const int8x16_t * b, const int8_t * qs) {\n//    const auto q8b_1 = vld1q_s8_x2(qs + 0);\n//    auto p12 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[0], q8b_1.val[0]), b[1], q8b_1.val[1]);\n//    const auto q8b_2 = vld1q_s8_x2(qs + 32);\n//    auto p34 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[2], q8b_2.val[0]), b[3], q8b_2.val[1]);\n//    auto p1234 = vpaddq_s32(p12, p34);\n//    const auto q8b_3 = vld1q_s8_x2(qs + 64);\n//    auto p56 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[4], q8b_3.val[0]), b[5], q8b_3.val[1]);\n//    const auto q8b_4 = vld1q_s8_x2(qs + 96);\n//    auto p78 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[6], q8b_4.val[0]), b[7], q8b_4.val[1]);\n//    return vpaddq_s32(p1234, vpaddq_s32(p56, p78));\n//}\n\ninline int32x4_t sum_4_blocks(const int8x16_t * b, const int8_t * qs) {\n    auto q8b = vld1q_s8_x2(qs + 0);\n    auto p12 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[0], q8b.val[0]), b[1], q8b.val[1]);\n    q8b = vld1q_s8_x2(qs + 32);\n    auto p34 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[2], q8b.val[0]), b[3], q8b.val[1]);\n    auto p1234 = vpaddq_s32(p12, p34);\n    q8b = vld1q_s8_x2(qs + 64);\n    auto p56 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[4], q8b.val[0]), b[5], q8b.val[1]);\n    q8b = vld1q_s8_x2(qs + 96);\n    auto p78 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[6], q8b.val[0]), b[7], q8b.val[1]);\n    return vpaddq_s32(p1234, vpaddq_s32(p56, p78));\n}\n\ntypedef struct {\n    ggml_half d[4];\n    int8_t qs[4*QK8_0];\n} block_q8_0_x4;\nstatic_assert(sizeof(block_q8_0_x4) == 4*sizeof(block_q8_0), \"wrong q8_0_x4 block size/padding\");\n\ntemplate <int nrc> struct Q80 {\n\n    constexpr static int nrc_y = nrc;\n\n    Q80(const DataInfo& info) {\n        for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_0 
*)info.src1_row(iy);\n    }\n\n    inline const int8_t * quant_data(int iy, int i) const {\n        const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y[iy] + i;\n        return y4->qs;\n    }\n\n    inline float16x4_t load_scales(int iy, int i) const {\n        const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y[iy] + i;\n        return vld1_f16((const float16_t *)y4->d);\n    }\n\n    template <typename Dequantizer>\n    inline void process_scales(int i, Dequantizer& deq, float16x4_t * sc16, float32x4_t * /*acc*/) const {\n        auto qx_scales = deq.new_block(i);\n        for (int iy = 0; iy < nrc; ++iy) {\n            auto q8_scales = load_scales(iy, i);\n            sc16[iy] = vmul_f16(qx_scales, q8_scales);\n        }\n    }\n\n    template <typename Dequantizer>\n    inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const {\n        deq.prepare1(i);\n        float d = GGML_FP16_TO_FP32(deq.x[i].d);\n        for (int iy = 0; iy < nrc; ++iy) {\n            auto q8b = vld1q_s8_x2(y[iy][i].qs);\n            auto p = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), deq.bits.b[0], q8b.val[0]), deq.bits.b[1], q8b.val[1]);\n            acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*GGML_FP16_TO_FP32(y[iy][i].d)), vcvtq_f32_s32(p));\n        }\n    }\n\n    const block_q8_0 * y[nrc_y];\n};\n\ntypedef struct {\n    ggml_half d[8];\n    int8_t qs[4*QK8_1];\n} block_q8_1_x4;\nstatic_assert(sizeof(block_q8_1_x4) == 4*sizeof(block_q8_1), \"wrong q8_1_x4 block size/padding\");\n\ntemplate <int nrc> struct Q81 {\n\n    constexpr static int nrc_y = nrc;\n\n    Q81(const DataInfo& info) {\n        for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_1 *)info.src1_row(iy);\n    }\n\n    inline const int8_t * quant_data(int iy, int i) const {\n        const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y[iy] + i;\n        return y4->qs;\n    }\n\n    inline float16x8_t load_scales(int iy, int i) const {\n        const block_q8_1_x4 * y4 = (const block_q8_1_x4 
*)y[iy] + i;\n        return vld1q_f16((const float16_t *)y4->d);\n    }\n\n    template <typename Dequantizer>\n    inline void process_scales(int i, Dequantizer& deq, float16x4_t * sc16, float32x4_t * acc) const {\n        auto qx_scales = deq.new_block(i);\n        for (int iy = 0; iy < nrc; ++iy) {\n            auto q8_scales = load_scales(iy, i);\n            auto m = vmul_f16(vget_high_f16(qx_scales), vget_high_f16(q8_scales));\n            acc[iy] = vaddq_f32(acc[iy], vcvt_f32_f16(m));\n            sc16[iy] = vmul_f16(vget_low_f16(qx_scales), vget_low_f16(q8_scales));\n        }\n    }\n\n    template <typename Dequantizer>\n    inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const {\n        deq.prepare1(i);\n        float d = GGML_FP16_TO_FP32(deq.x[i].d), m = 0.25f*GGML_FP16_TO_FP32(deq.x[i].m);\n        for (int iy = 0; iy < nrc; ++iy) {\n            auto q8b = vld1q_s8_x2(y[iy][i].qs);\n            auto p = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), deq.bits.b[0], q8b.val[0]), deq.bits.b[1], q8b.val[1]);\n            acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*GGML_FP16_TO_FP32(y[iy][i].d)), vcvtq_f32_s32(p));\n            acc[iy] = vaddq_f32(acc[iy], vdupq_n_f32(m*GGML_FP16_TO_FP32(y[iy][i].s)));\n        }\n    }\n\n    const block_q8_1 * y[nrc_y];\n};\n\ntemplate <typename block_q>\nstruct BaseLegacyDequantizer {\n\n    BaseLegacyDequantizer(const void * vx, size_t bx) : vx(vx), x(nullptr), bx(bx) {}\n\n    inline void new_row(int ix) { x = (const block_q *)((const char *)vx + bx*ix); }\n\n    Q4LegacyBits bits;\n\n    const void * vx;\n    const block_q * x;\n    size_t bx;\n};\n\nstruct DequantizerQ40 final : public BaseLegacyDequantizer<block_q4_0> {\n\n    DequantizerQ40(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}\n\n    inline void prepare1(int i, int8x16_t * q) const {\n        bits.prepare1(x[i].qs, q);\n        q[0] = vaddq_s8(q[0], m8);\n        q[1] = vaddq_s8(q[1], m8);\n    }\n    inline void 
prepare1(int i) {\n        prepare1(i, bits.b);\n    }\n\n    inline float16x4_t new_block(int i) {\n        ggml_half aux[4];\n        for (int k = 0; k < 4; ++k) {\n            aux[k] = x[4*i+k].d;\n            prepare1(4*i+k, bits.b + 2*k);\n        }\n        return vld1_f16((const float16_t *)aux);\n    }\n\n    const int8x16_t m8 = vdupq_n_s8(-8);\n    //ggml_half aux[4];\n};\n\nstruct DequantizerQ41 : public BaseLegacyDequantizer<block_q4_1> {\n\n    DequantizerQ41(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}\n\n    inline void prepare1(int i) {\n        bits.prepare1(x[i].qs);\n    }\n\n    inline float16x8_t new_block(int i) {\n        uint32_t aux32[4];\n        const uint32_t * s32 = (const uint32_t *)&x[4*i].d;\n        for (int k = 0; k < 4; ++k) {\n            aux32[k] = *s32; s32 += sizeof(block_q4_1)/4;\n            bits.prepare1(x[4*i+k].qs, bits.b + 2*k);\n        }\n        return vreinterpretq_f16_u8(vqtbl1q_u8(vld1q_u8((const uint8_t *)aux32), vreinterpretq_u8_u64(shuffle)));\n    }\n    // Leaving this commented out attempt to be reminded that I already tried this.\n    // It has basically the same performance as the version above.\n    //inline float16x8_t new_block(int i) {\n    //    uint32x4_t scales = {};\n    //    const block_q4_1 * xi = x + 4*i;\n    //    const uint32_t * s32 = (const uint32_t *)&xi->d;\n    //    scales = vsetq_lane_u32(*s32, scales, 0); s32 += sizeof(block_q4_1)/4;\n    //    bits.prepare1(xi[0].qs, bits.b + 0);\n    //    scales = vsetq_lane_u32(*s32, scales, 1); s32 += sizeof(block_q4_1)/4;\n    //    bits.prepare1(xi[1].qs, bits.b + 2);\n    //    scales = vsetq_lane_u32(*s32, scales, 2); s32 += sizeof(block_q4_1)/4;\n    //    bits.prepare1(xi[2].qs, bits.b + 4);\n    //    scales = vsetq_lane_u32(*s32, scales, 3);\n    //    bits.prepare1(xi[3].qs, bits.b + 6);\n    //    return vreinterpretq_f16_u8(vqtbl1q_u8(vreinterpretq_u8_u32(scales), vreinterpretq_u8_u64(shuffle)));\n    //}\n\n    const 
uint64x2_t shuffle = {0x0d0c090805040100, 0x0f0e0b0a07060302};\n};\n\nstruct HighBit5Legacy {\n    inline uint8x16_t to_bytes(const uint8_t * qh) const {\n        uint8x16_t h = vqtbl1q_u8(vreinterpretq_u8_u16(vdupq_n_u16(*(const uint16_t *)qh)), shuffle);\n        return vceqq_u8(vandq_u8(h, vreinterpretq_u8_u64(mask)), vreinterpretq_u8_u64(mask));\n    }\n    inline uint8x16_t to_negated_bytes(const uint8_t * qh) const {\n        uint8x16_t h = vqtbl1q_u8(vreinterpretq_u8_u16(vdupq_n_u16(*(const uint16_t *)qh)), shuffle);\n        return vceqq_u8(vandq_u8(h, vreinterpretq_u8_u64(mask)), vdupq_n_u8(0));\n    }\n    const uint64x2_t mask = vdupq_n_u64(0x8040201008040201);\n    const uint8x16_t shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1));\n};\n\nstruct DequantizerQ50 final : public BaseLegacyDequantizer<block_q5_0> {\n\n    DequantizerQ50(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}\n\n    inline void prepare1(int i, int8x16_t * q) const {\n        bits.prepare1(x[i].qs, q);\n        auto qh = x[i].qh;\n        q[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[0]), vandq_u8(mh, hbits.to_negated_bytes(qh+0))));\n        q[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[1]), vandq_u8(mh, hbits.to_negated_bytes(qh+2))));\n    }\n    inline void prepare1(int i) {\n        prepare1(i, bits.b);\n    }\n\n    inline float16x4_t new_block(int i) {\n        ggml_half aux[4];\n        for (int k = 0; k < 4; ++k) {\n            aux[k] = x[4*i+k].d;\n            prepare1(4*i+k, bits.b + 2*k);\n        }\n        return vld1_f16((const float16_t *)aux);\n    }\n\n    HighBit5Legacy hbits;\n\n    const uint8x16_t mh = vdupq_n_u8(0xf0);\n\n};\n\nstruct DequantizerQ80 final : public BaseLegacyDequantizer<block_q8_0> {\n\n    DequantizerQ80(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}\n\n    inline void prepare1(int i) {\n        bits.b[0] = vld1q_s8(x[i].qs);\n        bits.b[1] = vld1q_s8(x[i].qs+16);\n    }\n\n    inline 
float16x4_t new_block(int i) {\n        ggml_half aux[4];\n        for (int k = 0; k < 4; ++k) {\n            aux[k] = x[4*i+k].d;\n            bits.b[2*k+0] = vld1q_s8(x[4*i+k].qs);\n            bits.b[2*k+1] = vld1q_s8(x[4*i+k].qs+16);\n        }\n        return vld1_f16((const float16_t *)aux);\n    }\n\n};\n\nstruct DequantizerQ51 final : public BaseLegacyDequantizer<block_q5_1> {\n\n    DequantizerQ51(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}\n\n    inline void prepare1(int i, int8x16_t * q) const {\n        bits.prepare1(x[i].qs, q);\n        auto qh = x[i].qh;\n        q[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[0]), vandq_u8(mh, hbits.to_bytes(qh+0))));\n        q[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[1]), vandq_u8(mh, hbits.to_bytes(qh+2))));\n    }\n    inline void prepare1(int i) {\n        bits.prepare1(x[i].qs, bits.b);\n    }\n\n    inline float16x8_t new_block(int i) {\n        uint32_t aux32[4];\n        const uint32_t * s32 = (const uint32_t *)&x[4*i].d;\n        for (int k = 0; k < 4; ++k) {\n            aux32[k] = *s32; s32 += sizeof(block_q5_1)/4;\n            prepare1(4*i+k, bits.b + 2*k);\n        }\n        return vreinterpretq_f16_u8(vqtbl1q_u8(vld1q_u8((const uint8_t *)aux32), vreinterpretq_u8_u64(shuffle)));\n    }\n\n    HighBit5Legacy hbits;\n\n    const uint8x16_t mh = vdupq_n_u8(0x10);\n    const uint64x2_t shuffle = {0x0d0c090805040100, 0x0f0e0b0a07060302};\n\n};\n\ntemplate <typename Dequantizer, typename Q8>\ninline void sum_4(int i, Dequantizer& deq, const Q8& q8, const float16x4_t * sc16, float32x4_t * acc) {\n    for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n        auto pall = sum_4_blocks(deq.bits.b, q8.quant_data(iy, i));\n        auto scale = vcvt_f32_f16(sc16[iy]);\n        acc[iy] = vmlaq_f32(acc[iy], scale, vcvtq_f32_s32(pall));\n    }\n}\n\ntemplate <typename Dequantizer, typename Q8>\ninline void mul_mat_qX_Y_q8_Y(int n, Dequantizer& deq, Q8& q8, const DataInfo& info, int 
nrc_x) {\n    const int nb = n / QK4_1;\n\n    float16x4_t sc16[Q8::nrc_y];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        deq.new_row(ix);\n\n        float32x4_t acc[Q8::nrc_y];\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);\n\n        for (int i = 0; i < nb/4; ++i) {\n            q8.process_scales(i, deq, sc16, acc);\n            sum_4(i, deq, q8, sc16, acc);\n        }\n        for (int i = 4*(nb/4); i < nb; ++i) {\n            q8.process_1_block(i, deq, acc);\n        }\n\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            info.store(ix, iy, vaddvq_f32(acc[iy]));\n        }\n    }\n}\n\ntemplate <typename Dequantizer, typename Q8>\ninline void mul_mat_qX_Y_q8_Y_1(int n, Dequantizer& deq1, Dequantizer& deq2, Q8& q8, const DataInfo& info, int nrc_x) {\n    const int nb = n / QK4_1;\n\n    float16x4_t sc16[2];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        deq1.new_row(ix);\n        deq2.new_row(ix);\n\n        float32x4_t acc[2] = { vdupq_n_f32(0.f), vdupq_n_f32(0.f) };\n\n        for (int i = 0; i < nb/8; ++i) {\n            q8.process_scales(2*i+0, deq1, sc16+0, acc+0);\n            q8.process_scales(2*i+1, deq2, sc16+1, acc+1);\n            sum_4(2*i+0, deq1, q8, sc16+0, acc+0);\n            sum_4(2*i+1, deq2, q8, sc16+1, acc+1);\n        }\n        for (int i = 2*(nb/8); i < nb/4; ++i) {\n            q8.process_scales(i, deq1, sc16, acc);\n            sum_4(i, deq1, q8, sc16, acc);\n        }\n        for (int i = 4*(nb/4); i < nb; ++i) {\n            q8.process_1_block(i, deq1, acc);\n        }\n\n        info.store(ix, 0, vaddvq_f32(vaddq_f32(acc[0], acc[1])));\n    }\n}\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void IQK_NOINLINE mul_mat_qX_1_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    Q81<nrc_y> q8(info);\n    if constexpr (nrc_y == 1) {\n        Dequantizer deq1(vx, bx), deq2(vx, bx);\n        mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);\n    } 
else {\n        Dequantizer deq(vx, bx);\n        mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x);\n    }\n}\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void IQK_NOINLINE mul_mat_qX_0_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    Q80<nrc_y> q8(info);\n    if constexpr (nrc_y == 1) {\n        Dequantizer deq1(vx, bx), deq2(vx, bx);\n        mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);\n    } else {\n        Dequantizer deq(vx, bx);\n        mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x);\n    }\n}\n\ntemplate <typename Dequantizer>\nstatic void IQK_NOINLINE mul_mat_qX_1_q8_1_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    Dequantizer deq1(vx, bx), deq2(vx, bx);\n    Q81<1> q8(info);\n    mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);\n}\n\ntemplate <typename Dequantizer>\nstatic void IQK_NOINLINE mul_mat_qX_0_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    Dequantizer deq1(vx, bx), deq2(vx, bx);\n    Q80<1> q8(info);\n    // Two dequantizers require the two-accumulator variant (mul_mat_qX_Y_q8_Y_1).\n    mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);\n}\n\ntemplate <typename Dequantizer> void MulMat::set_functions(MulMat& m) {\n    if constexpr (std::is_same_v<Dequantizer, DequantizerQ40> || std::is_same_v<Dequantizer, DequantizerQ50> ||\n                  std::is_same_v<Dequantizer, DequantizerQ80>) {\n        m.funcs[0] = mul_mat_qX_0_q8_0<Dequantizer, 1>;\n        m.funcs[1] = mul_mat_qX_0_q8_0<Dequantizer, 2>;\n        m.funcs[2] = mul_mat_qX_0_q8_0<Dequantizer, 3>;\n        m.funcs[3] = mul_mat_qX_0_q8_0<Dequantizer, 4>;\n        m.funcs[4] = mul_mat_qX_0_q8_0<Dequantizer, 5>;\n        m.funcs[5] = mul_mat_qX_0_q8_0<Dequantizer, 6>;\n        m.funcs[6] = mul_mat_qX_0_q8_0<Dequantizer, 7>;\n        m.funcs[7] = mul_mat_qX_0_q8_0<Dequantizer, 8>;\n    }\n    else if constexpr (std::is_same_v<Dequantizer, DequantizerQ41> || std::is_same_v<Dequantizer, DequantizerQ51>) {\n        m.funcs[0] = mul_mat_qX_1_q8_1<Dequantizer, 1>;\n      
  m.funcs[1] = mul_mat_qX_1_q8_1<Dequantizer, 2>;\n        m.funcs[2] = mul_mat_qX_1_q8_1<Dequantizer, 3>;\n        m.funcs[3] = mul_mat_qX_1_q8_1<Dequantizer, 4>;\n        m.funcs[4] = mul_mat_qX_1_q8_1<Dequantizer, 5>;\n        m.funcs[5] = mul_mat_qX_1_q8_1<Dequantizer, 6>;\n        m.funcs[6] = mul_mat_qX_1_q8_1<Dequantizer, 7>;\n        m.funcs[7] = mul_mat_qX_1_q8_1<Dequantizer, 8>;\n    }\n    else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2XXS> || std::is_same_v<Dequantizer, DequantizerIQ3XXS>) {\n        m.funcs[0] = mul_mat_qX_K_q8_K_IQXXS<1, Dequantizer>;\n        m.funcs[1] = mul_mat_qX_K_q8_K_IQXXS<2, Dequantizer>;\n        m.funcs[2] = mul_mat_qX_K_q8_K_IQXXS<3, Dequantizer>;\n        m.funcs[3] = mul_mat_qX_K_q8_K_IQXXS<4, Dequantizer>;\n        m.funcs[4] = mul_mat_qX_K_q8_K_IQXXS<5, Dequantizer>;\n        m.funcs[5] = mul_mat_qX_K_q8_K_IQXXS<6, Dequantizer>;\n        m.funcs[6] = mul_mat_qX_K_q8_K_IQXXS<7, Dequantizer>;\n        m.funcs[7] = mul_mat_qX_K_q8_K_IQXXS<8, Dequantizer>;\n    }\n    else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2S> ||\n                       std::is_same_v<Dequantizer, DequantizerIQ3S> ||\n                       std::is_same_v<Dequantizer, DequantizerIQ2XS>) {\n        m.funcs[0] = mul_mat_qX_K_q8_K_IQ<1, Dequantizer>;\n        m.funcs[1] = mul_mat_qX_K_q8_K_IQ<2, Dequantizer>;\n        m.funcs[2] = mul_mat_qX_K_q8_K_IQ<3, Dequantizer>;\n        m.funcs[3] = mul_mat_qX_K_q8_K_IQ<4, Dequantizer>;\n        m.funcs[4] = mul_mat_qX_K_q8_K_IQ<5, Dequantizer>;\n        m.funcs[5] = mul_mat_qX_K_q8_K_IQ<6, Dequantizer>;\n        m.funcs[6] = mul_mat_qX_K_q8_K_IQ<7, Dequantizer>;\n        m.funcs[7] = mul_mat_qX_K_q8_K_IQ<8, Dequantizer>;\n    }\n    else {\n        m.funcs[0] = mul_mat_qX_K_q8_K_T<1, Dequantizer>;\n        m.funcs[1] = mul_mat_qX_K_q8_K_T<2, Dequantizer>;\n        m.funcs[2] = mul_mat_qX_K_q8_K_T<3, Dequantizer>;\n        m.funcs[3] = mul_mat_qX_K_q8_K_T<4, Dequantizer>;\n        
m.funcs[4] = mul_mat_qX_K_q8_K_T<5, Dequantizer>;\n        m.funcs[5] = mul_mat_qX_K_q8_K_T<6, Dequantizer>;\n        m.funcs[6] = mul_mat_qX_K_q8_K_T<7, Dequantizer>;\n        m.funcs[7] = mul_mat_qX_K_q8_K_T<8, Dequantizer>;\n        m.funcs_v2 = mul_mat_qX_K_q8_K_T_v2<Dequantizer>;\n    }\n}\n\nbool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int Ny) {\n    row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00);\n\n    (void)Ny;\n    // Uncommenting the check below would disable iqk_mul_mat for matrix x vector multiplications with i-quants.\n    //if (Ny == 1 && (typeA == GGML_TYPE_IQ2_XXS || typeA == GGML_TYPE_IQ2_XS || typeA == GGML_TYPE_IQ2_S ||\n    //                typeA == GGML_TYPE_IQ3_XXS || typeA == GGML_TYPE_IQ3_S)) return false;\n\n    switch (typeA) {\n        case GGML_TYPE_Q2_K:\n            MulMat::set_functions<DequantizerQ2K>(m);\n            break;\n        case GGML_TYPE_Q3_K:\n            MulMat::set_functions<DequantizerQ3K>(m);\n            break;\n        case GGML_TYPE_Q4_K:\n            MulMat::set_functions<DequantizerQ4K>(m);\n            break;\n        case GGML_TYPE_Q5_K:\n            MulMat::set_functions<DequantizerQ5K>(m);\n            break;\n        case GGML_TYPE_Q6_K:\n            MulMat::set_functions<DequantizerQ6K>(m);\n            break;\n        case GGML_TYPE_IQ4_XS:\n            MulMat::set_functions<DequantizerIQ4XS>(m);\n            break;\n        case GGML_TYPE_IQ3_S:\n            MulMat::set_functions<DequantizerIQ3S>(m);\n            break;\n        case GGML_TYPE_IQ3_XXS:\n            MulMat::set_functions<DequantizerIQ3XXS>(m);\n            break;\n        case GGML_TYPE_IQ2_S:\n            MulMat::set_functions<DequantizerIQ2S>(m);\n            break;\n        case GGML_TYPE_IQ2_XS:\n            MulMat::set_functions<DequantizerIQ2XS>(m);\n            break;\n        case GGML_TYPE_IQ2_XXS:\n            MulMat::set_functions<DequantizerIQ2XXS>(m);\n            break;\n        case GGML_TYPE_Q4_0:\n            
MulMat::set_functions<DequantizerQ40>(m);\n            row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);\n            break;\n        case GGML_TYPE_Q4_1:\n            MulMat::set_functions<DequantizerQ41>(m);\n            row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00);\n            break;\n        case GGML_TYPE_Q5_0:\n            MulMat::set_functions<DequantizerQ50>(m);\n            row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);\n            break;\n        case GGML_TYPE_Q5_1:\n            MulMat::set_functions<DequantizerQ51>(m);\n            row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00);\n            break;\n        case GGML_TYPE_Q8_0:\n            MulMat::set_functions<DequantizerQ80>(m);\n            row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);\n            break;\n        default:\n            return false;\n    }\n    return true;\n}\n\n}\n\n#endif // __x86_64__ or __aarch64__\n"
  },
  {
    "path": "archive/third_party/llamafile/iqk_mul_mat_arm82.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/iqk_mul_mat_arm82.cpp\n// Copyrigth 2024 Iwan Kawrakow.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#ifdef __aarch64__\n#define iqk_mul_mat iqk_mul_mat_arm82\n#define iqk_mul_mat_moe iqk_mul_mat_moe_arm82\n#include \"iqk_mul_mat.inc\"\n#endif  // __aarch64__\n"
  },
  {
    "path": "archive/third_party/llamafile/iqk_mul_mat_x86.inc",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/iqk_mul_mat.inc\n// Copyrigth 2024 Iwan Kawrakow - Apache 2.0 Licens\n// with additions from\n// https://github.com/ikawrakow/ik_llama.cpp/blob/main/ggml/src/iqk/iqk_mul_mat.cpp\n// Copyrigth 2024-2025 Iwan Kawrakow - MIT Licens\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-\n// vi: set et ft=cpp fenc=utf-8 :vi\n//\n// Copyright 2024 Iwan Kawrakow\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n//\n//\n// Copyright (C) 2024-2025 Iwan Kawrakow\n// MIT license\n// SPDX-License-Identifier: MIT\n//\n\n#include <cstring>\n#include <type_traits>\n#if defined __x86_64__ || defined __aarch64__ || defined(_M_X64)\n\n#include \"llama.cpp/ggml-impl.h\"\n#include \"llama.cpp/ggml-quants.h\"\n#include \"sgemm.h\"\n\n// For i-quants, I had to explicitely specify which\n// functions to inline / not inline (at least for some\n// of the functions), else performance would be significantly\n// lower. 
This is worrisome as things can change with,\n// e.g., a different compiler version or running on a different\n// CPU.\n#ifdef _MSC_VER\n#define IQK_NOINLINE __declspec(noinline)\n#define IQK_ALWAYS_INLINE inline\n#else\n#define IQK_NOINLINE __attribute__((__noinline__))\n#define IQK_ALWAYS_INLINE __attribute__((always_inline))\n#endif\n\n#define GGML_COMMON_IMPL_C\n#include \"llama.cpp/ggml-common.h\"\n\n// clang-format off\n\n// This matrix - vector and matrix - matrix multiplication implementation\n// for legacy quants, k-quants and i-quants makes prompt processing 150-200%\n// (legacy and k-quants) or 250-400% (i-quants) faster\n// compared to mainline llama.cpp (and llamafile).\n// It provides implementations for ARM_NEON (all quants) and AVX2\n// (all quants except sub-4 bit i-quants).\n//\n// The main idea is that unpacking the quants and the block scales to\n// be ready for dot products with the corresponding Q8_Y quants\n// takes time (here 'Y' stands for K, 0, or 1, depending on quantization type).\n// Hence, if we are performing a QX x Q8_Y matrix matrix\n// multiplication (as needed for prompt processing), we can get\n// a significant speedup by reusing the unpacked QX quants and scales\n// for multiplication with several Q8_Y columns. 
We also achieve fewer\n// loads from memory, which is the main purpose of tiling in general\n// purpose matrix multiplication packages.\n\n#include <utility>\n#include <array>\n\n#endif\n\nconstexpr ggml_type GGML_TYPE_Q8_0_X4 = static_cast<ggml_type>(98);\nconstexpr ggml_type GGML_TYPE_Q8_1_X4 = static_cast<ggml_type>(99);\n\n\nnamespace {\n\ntypedef struct {\n    int32_t i1;\n    int32_t i2;\n} mmid_row_mapping;\n\nstruct DataInfo {\n    float       * s;\n    const char  * cy;\n    size_t        bs;\n    size_t        by;\n    int           cur_y = 0;\n    int           ne11;\n    const mmid_row_mapping * row_mapping = nullptr;\n    size_t        bs2 = 0;\n\n    inline const char * src1_row(int iy) const {\n        if (!row_mapping) return cy + (cur_y + iy)*by;\n        int i11 = row_mapping[cur_y + iy].i1 % ne11;\n        int i12 = row_mapping[cur_y + iy].i2;\n        return cy + (i11 + i12*ne11)*by;\n    }\n\n    inline void store(int ix, int iy, float result) const {\n        *(dst_row(iy) + ix) = result;\n        //dst_row(iy)[ix] = result;\n    }\n    inline float * dst_row(int iy) const {\n        if (!row_mapping) return s + (cur_y + iy)*bs;\n        int i12 = row_mapping[cur_y + iy].i2;\n        int i1  = row_mapping[cur_y + iy].i1;\n        int i2  = i12;\n        return s + i1*bs + i2*bs2;\n    }\n};\n\n/*\nmoonll \nchange param for set_mul_mat \nadd func16\n*/\n\ntypedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x);\n\nstruct MulMat {\n    std::array<mul_mat_t, 8> funcs = {};\n    mul_mat_t func16 = nullptr;\n    //inline void mul_mat_NxM(int n, const void * vx, size_t bx, DataInfo& info, int nrc_x, int nrc_y) {\n    IQK_NOINLINE void mul_mat_NxM(int n, const void * vx, size_t bx, DataInfo& info, int nrc_x, int nrc_y) {\n        constexpr int k_x_step = 64; // This works best on my Ryzen-7950X and M2 Max CPUs (but differences to other tile size are small)\n\n        // copied from 
https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L162\n        // MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow\n        if (func16 && nrc_y >= 16) {\n            int n_step = (nrc_y - info.cur_y)/16;\n            for (int ix = 0; ix < nrc_x; ix += k_x_step) {\n                auto this_info = info;\n                this_info.s += ix;\n                int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;\n                for (int iy = 0; iy < n_step; ++iy) {\n                    func16(n, (const void *)((const char *)vx + ix*bx), bx, this_info, this_nrc_x);\n                    this_info.cur_y += 16;\n                }\n            }\n            info.cur_y += 16 * n_step;\n            if (info.cur_y == nrc_y) return;\n        }\n        // end copy\n\n        int n_step = (nrc_y - info.cur_y)/funcs.size();\n        if (n_step > 0) {\n            for (int ix = 0; ix < nrc_x; ix += k_x_step) {\n                auto this_info = info;\n                this_info.s += ix;\n                int this_nrc_x = ix + k_x_step <= nrc_x ? 
k_x_step : nrc_x - ix;\n                for (int iy = 0; iy < n_step; ++iy) {\n                    funcs.back()(n, (const void *)((const char *)vx + ix*bx), bx, this_info, this_nrc_x);\n                    this_info.cur_y += funcs.size();\n                }\n            }\n            info.cur_y += funcs.size() * n_step;\n        }\n        int n_left = nrc_y - info.cur_y;\n        if (n_left > 0) {\n            funcs[n_left-1](n, vx, bx, info, nrc_x);\n        }\n    }\n    static IQK_NOINLINE bool set_mul_mat(int typeA, int typeB,int ne00, MulMat& mm, int Ny);\nprivate:\n    template <typename Dequantizer> static IQK_NOINLINE void set_functions(MulMat& m);\n};\n\ninline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) {\n    const uint16_t * scales = (const uint16_t *)scales8;\n    const uint32_t a0 = scales[0] | (scales[1] << 16);\n    const uint32_t a1 = scales[2] | (scales[3] << 16);\n    const uint32_t a2 = scales[4] | (scales[5] << 16);\n    aux32[3] = ((a2 >> 4) & 0x0f0f0f0f) | ((a1 >> 2) & 0x30303030);\n    aux32[1] = ((a2 >> 0) & 0x0f0f0f0f) | ((a0 >> 2) & 0x30303030);\n    aux32[2] = a1 & 0x3f3f3f3f;\n    aux32[0] = a0 & 0x3f3f3f3f;\n}\n\n/*\nmoonll\ndecoding tables\n*/\n// copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L570\n// MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow\n#ifdef __AVX2__\nstatic const uint64_t iq1s_grid_us[2048] = {\n    0x0000000000000000, 0x0000000000000002, 0x0000000000000101, 0x0000000000000200,\n    0x0000000000000202, 0x0000000000010001, 0x0000000000010101, 0x0000000000020000,\n    0x0000000000020002, 0x0000000000020200, 0x0000000000020202, 0x0000000001000101,\n    0x0000000001010001, 0x0000000001010100, 0x0000000001010102, 0x0000000001020101,\n    0x0000000002000000, 0x0000000002000002, 0x0000000002000200, 0x0000000002000202,\n    0x0000000002010101, 0x0000000002020000, 0x0000000002020002, 0x0000000002020200,\n    
0x0000000002020202, 0x0000000100000100, 0x0000000100000101, 0x0000000100010001,\n    0x0000000100010100, 0x0000000100010102, 0x0000000100010201, 0x0000000100010202,\n    0x0000000100020101, 0x0000000101000001, 0x0000000101000102, 0x0000000101000201,\n    0x0000000101010002, 0x0000000101010101, 0x0000000101010202, 0x0000000101020001,\n    0x0000000101020100, 0x0000000101020102, 0x0000000101020200, 0x0000000102000101,\n    0x0000000102010001, 0x0000000102010100, 0x0000000102010102, 0x0000000102020101,\n    0x0000000200000000, 0x0000000200000002, 0x0000000200000200, 0x0000000200000202,\n    0x0000000200010101, 0x0000000200020000, 0x0000000200020002, 0x0000000200020200,\n    0x0000000200020202, 0x0000000201000101, 0x0000000201010001, 0x0000000201010201,\n    0x0000000201020100, 0x0000000201020201, 0x0000000202000000, 0x0000000202000002,\n    0x0000000202000200, 0x0000000202000202, 0x0000000202010001, 0x0000000202010101,\n    0x0000000202010201, 0x0000000202020000, 0x0000000202020002, 0x0000000202020200,\n    0x0000000202020202, 0x0000010000010001, 0x0000010000010100, 0x0000010000010102,\n    0x0000010000020101, 0x0000010001000001, 0x0000010001000201, 0x0000010001010101,\n    0x0000010001010202, 0x0000010001020100, 0x0000010001020101, 0x0000010002010001,\n    0x0000010002010201, 0x0000010002020101, 0x0000010100000001, 0x0000010100000100,\n    0x0000010100000101, 0x0000010100000102, 0x0000010100010101, 0x0000010100010200,\n    0x0000010100010202, 0x0000010100020201, 0x0000010101000000, 0x0000010101000101,\n    0x0000010101000202, 0x0000010101010000, 0x0000010101010001, 0x0000010101010100,\n    0x0000010101010101, 0x0000010101010102, 0x0000010101010201, 0x0000010101020000,\n    0x0000010101020002, 0x0000010101020101, 0x0000010101020200, 0x0000010101020202,\n    0x0000010102000001, 0x0000010102010001, 0x0000010102010101, 0x0000010102010200,\n    0x0000010102010202, 0x0000010102020001, 0x0000010102020100, 0x0000010102020101,\n    0x0000010102020102, 0x0000010102020201, 
0x0000010200010100, 0x0000010200010201,\n    0x0000010201000001, 0x0000010201000100, 0x0000010201010000, 0x0000010201010002,\n    0x0000010201010101, 0x0000010201010200, 0x0000010201020000, 0x0000010201020001,\n    0x0000010201020102, 0x0000010201020201, 0x0000010202000101, 0x0000010202010001,\n    0x0000010202010100, 0x0000010202010201, 0x0000020000000000, 0x0000020000000002,\n    0x0000020000000200, 0x0000020000000202, 0x0000020000010101, 0x0000020000020000,\n    0x0000020000020002, 0x0000020000020200, 0x0000020000020202, 0x0000020001000101,\n    0x0000020001010001, 0x0000020001010102, 0x0000020001020101, 0x0000020002000000,\n    0x0000020002000002, 0x0000020002000200, 0x0000020002000202, 0x0000020002010101,\n    0x0000020002020000, 0x0000020002020002, 0x0000020002020200, 0x0000020002020202,\n    0x0000020100000101, 0x0000020100010001, 0x0000020100010100, 0x0000020100010201,\n    0x0000020100020100, 0x0000020100020101, 0x0000020101000001, 0x0000020101010000,\n    0x0000020101010001, 0x0000020101010101, 0x0000020101020001, 0x0000020101020100,\n    0x0000020101020201, 0x0000020102010001, 0x0000020102010100, 0x0000020102010102,\n    0x0000020102010201, 0x0000020102020101, 0x0000020200000000, 0x0000020200000002,\n    0x0000020200000200, 0x0000020200000202, 0x0000020200010101, 0x0000020200020000,\n    0x0000020200020002, 0x0000020200020200, 0x0000020200020202, 0x0000020201000101,\n    0x0000020201010001, 0x0000020201010201, 0x0000020201020001, 0x0000020201020101,\n    0x0000020202000000, 0x0000020202000002, 0x0000020202000101, 0x0000020202000200,\n    0x0000020202000202, 0x0000020202010101, 0x0000020202020000, 0x0000020202020002,\n    0x0000020202020200, 0x0000020202020202, 0x0001000000010000, 0x0001000000010001,\n    0x0001000000010100, 0x0001000000010201, 0x0001000000020100, 0x0001000000020101,\n    0x0001000001000001, 0x0001000001000100, 0x0001000001010000, 0x0001000001010101,\n    0x0001000001010200, 0x0001000001020001, 0x0001000001020100, 0x0001000001020101,\n    
0x0001000001020201, 0x0001000002010001, 0x0001000002010100, 0x0001000002010102,\n    0x0001000002020001, 0x0001000002020101, 0x0001000100000001, 0x0001000100000100,\n    0x0001000100000102, 0x0001000100000201, 0x0001000100010000, 0x0001000100010002,\n    0x0001000100010101, 0x0001000100010200, 0x0001000100020001, 0x0001000100020100,\n    0x0001000100020201, 0x0001000101000101, 0x0001000101000202, 0x0001000101010000,\n    0x0001000101010001, 0x0001000101010002, 0x0001000101010100, 0x0001000101010101,\n    0x0001000101010102, 0x0001000101010201, 0x0001000101020000, 0x0001000101020101,\n    0x0001000102000100, 0x0001000102010002, 0x0001000102010101, 0x0001000102020001,\n    0x0001000102020100, 0x0001000200010001, 0x0001000200010100, 0x0001000200010102,\n    0x0001000200020101, 0x0001000201000000, 0x0001000201000102, 0x0001000201000201,\n    0x0001000201010002, 0x0001000201010101, 0x0001000201010200, 0x0001000201010202,\n    0x0001000201020100, 0x0001000201020102, 0x0001000202000101, 0x0001000202010001,\n    0x0001000202010100, 0x0001000202010102, 0x0001000202020101, 0x0001010000000001,\n    0x0001010000000102, 0x0001010000000201, 0x0001010000010100, 0x0001010000010101,\n    0x0001010000010200, 0x0001010000010201, 0x0001010000020001, 0x0001010000020102,\n    0x0001010001000001, 0x0001010001000101, 0x0001010001000102, 0x0001010001000200,\n    0x0001010001000202, 0x0001010001010001, 0x0001010001010100, 0x0001010001010101,\n    0x0001010001010102, 0x0001010001010201, 0x0001010001020002, 0x0001010001020101,\n    0x0001010001020200, 0x0001010002000100, 0x0001010002000201, 0x0001010002010000,\n    0x0001010002010100, 0x0001010002010101, 0x0001010002010200, 0x0001010002010201,\n    0x0001010002010202, 0x0001010002020001, 0x0001010002020100, 0x0001010002020101,\n    0x0001010002020201, 0x0001010100000002, 0x0001010100000101, 0x0001010100000202,\n    0x0001010100010001, 0x0001010100010100, 0x0001010100010101, 0x0001010100010102,\n    0x0001010100010201, 0x0001010100020000, 
0x0001010100020002, 0x0001010100020101,\n    0x0001010100020200, 0x0001010100020202, 0x0001010101000001, 0x0001010101000100,\n    0x0001010101000101, 0x0001010101000102, 0x0001010101010001, 0x0001010101010002,\n    0x0001010101010100, 0x0001010101010101, 0x0001010101010102, 0x0001010101010201,\n    0x0001010101010202, 0x0001010101020001, 0x0001010101020100, 0x0001010101020101,\n    0x0001010101020102, 0x0001010101020201, 0x0001010102000000, 0x0001010102000002,\n    0x0001010102000100, 0x0001010102000101, 0x0001010102000200, 0x0001010102000202,\n    0x0001010102010000, 0x0001010102010001, 0x0001010102010100, 0x0001010102010101,\n    0x0001010102010102, 0x0001010102010201, 0x0001010102010202, 0x0001010102020000,\n    0x0001010102020002, 0x0001010102020101, 0x0001010200000001, 0x0001010200000100,\n    0x0001010200000101, 0x0001010200000102, 0x0001010200010101, 0x0001010200010102,\n    0x0001010200010200, 0x0001010200010202, 0x0001010200020001, 0x0001010200020102,\n    0x0001010201000000, 0x0001010201000002, 0x0001010201000100, 0x0001010201000101,\n    0x0001010201000200, 0x0001010201000202, 0x0001010201010001, 0x0001010201010101,\n    0x0001010201010102, 0x0001010201010200, 0x0001010201010201, 0x0001010201020001,\n    0x0001010201020100, 0x0001010201020101, 0x0001010201020200, 0x0001010201020201,\n    0x0001010201020202, 0x0001010202000102, 0x0001010202000202, 0x0001010202010002,\n    0x0001010202010101, 0x0001010202020100, 0x0001010202020201, 0x0001020000010001,\n    0x0001020000010102, 0x0001020000020101, 0x0001020001000001, 0x0001020001000100,\n    0x0001020001000102, 0x0001020001000201, 0x0001020001010000, 0x0001020001010101,\n    0x0001020001010200, 0x0001020001010202, 0x0001020001020000, 0x0001020001020001,\n    0x0001020001020100, 0x0001020001020102, 0x0001020001020201, 0x0001020002000101,\n    0x0001020002010001, 0x0001020002010100, 0x0001020002020101, 0x0001020100010000,\n    0x0001020100010002, 0x0001020100010101, 0x0001020100010202, 0x0001020100020001,\n    
0x0001020100020101, 0x0001020101000002, 0x0001020101000100, 0x0001020101000101,\n    0x0001020101000200, 0x0001020101010001, 0x0001020101010100, 0x0001020101010101,\n    0x0001020101010102, 0x0001020101010201, 0x0001020101010202, 0x0001020101020000,\n    0x0001020101020101, 0x0001020101020202, 0x0001020102000201, 0x0001020102010001,\n    0x0001020102010002, 0x0001020102010101, 0x0001020102010200, 0x0001020102020001,\n    0x0001020102020102, 0x0001020102020201, 0x0001020200000201, 0x0001020200010102,\n    0x0001020200020100, 0x0001020200020102, 0x0001020201000100, 0x0001020201000102,\n    0x0001020201000201, 0x0001020201010000, 0x0001020201010002, 0x0001020201010101,\n    0x0001020201010200, 0x0001020201020001, 0x0001020201020102, 0x0001020201020201,\n    0x0001020202000101, 0x0001020202010001, 0x0001020202010102, 0x0001020202010202,\n    0x0002000000000000, 0x0002000000000002, 0x0002000000000200, 0x0002000000000202,\n    0x0002000000010101, 0x0002000000020000, 0x0002000000020002, 0x0002000000020101,\n    0x0002000000020200, 0x0002000000020202, 0x0002000001000101, 0x0002000001010001,\n    0x0002000001010201, 0x0002000001020001, 0x0002000001020101, 0x0002000002000000,\n    0x0002000002000002, 0x0002000002000200, 0x0002000002000202, 0x0002000002010101,\n    0x0002000002020000, 0x0002000002020002, 0x0002000002020101, 0x0002000002020200,\n    0x0002000002020202, 0x0002000100000101, 0x0002000100010001, 0x0002000100010100,\n    0x0002000100010201, 0x0002000100020101, 0x0002000101000002, 0x0002000101000100,\n    0x0002000101000201, 0x0002000101010101, 0x0002000101010200, 0x0002000101010202,\n    0x0002000101020001, 0x0002000101020100, 0x0002000101020101, 0x0002000101020102,\n    0x0002000102000101, 0x0002000102010000, 0x0002000102010102, 0x0002000102010201,\n    0x0002000102020101, 0x0002000200000001, 0x0002000200000200, 0x0002000200000202,\n    0x0002000200010001, 0x0002000200010101, 0x0002000200020000, 0x0002000200020002,\n    0x0002000200020200, 0x0002000200020202, 
0x0002000201000101, 0x0002000201010001,\n    0x0002000201010102, 0x0002000201010201, 0x0002000201020101, 0x0002000202000001,\n    0x0002000202000200, 0x0002000202000202, 0x0002000202010001, 0x0002000202010101,\n    0x0002000202020000, 0x0002000202020002, 0x0002000202020200, 0x0002000202020202,\n    0x0002010000000101, 0x0002010000010100, 0x0002010000010102, 0x0002010000010201,\n    0x0002010000020101, 0x0002010001000100, 0x0002010001000101, 0x0002010001000102,\n    0x0002010001000201, 0x0002010001010002, 0x0002010001010101, 0x0002010001010200,\n    0x0002010001010202, 0x0002010001020102, 0x0002010002000101, 0x0002010002010001,\n    0x0002010002010100, 0x0002010002010201, 0x0002010002020001, 0x0002010002020101,\n    0x0002010100000201, 0x0002010100010101, 0x0002010100020001, 0x0002010100020201,\n    0x0002010101000000, 0x0002010101000101, 0x0002010101000200, 0x0002010101010001,\n    0x0002010101010100, 0x0002010101010101, 0x0002010101010201, 0x0002010101020002,\n    0x0002010101020101, 0x0002010101020200, 0x0002010102000201, 0x0002010102010000,\n    0x0002010102010100, 0x0002010102010101, 0x0002010102010200, 0x0002010102010202,\n    0x0002010102020001, 0x0002010102020100, 0x0002010102020102, 0x0002010102020201,\n    0x0002010200000101, 0x0002010200010000, 0x0002010200010002, 0x0002010200010201,\n    0x0002010200020101, 0x0002010201000001, 0x0002010201000201, 0x0002010201010101,\n    0x0002010201020000, 0x0002010201020001, 0x0002010201020201, 0x0002010202000100,\n    0x0002010202000102, 0x0002010202010000, 0x0002010202010202, 0x0002020000000000,\n    0x0002020000000002, 0x0002020000000200, 0x0002020000000202, 0x0002020000010101,\n    0x0002020000020000, 0x0002020000020002, 0x0002020000020200, 0x0002020000020202,\n    0x0002020001000101, 0x0002020001010001, 0x0002020001010100, 0x0002020001020101,\n    0x0002020002000000, 0x0002020002000002, 0x0002020002000200, 0x0002020002000202,\n    0x0002020002020000, 0x0002020002020002, 0x0002020002020200, 0x0002020002020202,\n    
0x0002020100000201, 0x0002020100010001, 0x0002020100010100, 0x0002020100010201,\n    0x0002020100020101, 0x0002020101000102, 0x0002020101000201, 0x0002020101010002,\n    0x0002020101010101, 0x0002020101020001, 0x0002020101020100, 0x0002020101020102,\n    0x0002020101020201, 0x0002020102000101, 0x0002020102010000, 0x0002020102010102,\n    0x0002020102010201, 0x0002020102020100, 0x0002020102020101, 0x0002020200000000,\n    0x0002020200000002, 0x0002020200000200, 0x0002020200000202, 0x0002020200020000,\n    0x0002020200020002, 0x0002020200020200, 0x0002020200020202, 0x0002020201000101,\n    0x0002020201010001, 0x0002020201010102, 0x0002020201010201, 0x0002020201020101,\n    0x0002020202000000, 0x0002020202000002, 0x0002020202000200, 0x0002020202000202,\n    0x0002020202010101, 0x0002020202020000, 0x0002020202020002, 0x0002020202020200,\n    0x0002020202020202, 0x0100000000000101, 0x0100000000010001, 0x0100000000010102,\n    0x0100000000020101, 0x0100000001000201, 0x0100000001010002, 0x0100000001010101,\n    0x0100000001010200, 0x0100000001010202, 0x0100000001020001, 0x0100000001020100,\n    0x0100000001020102, 0x0100000002010100, 0x0100000002010201, 0x0100000002020001,\n    0x0100000002020102, 0x0100000100000000, 0x0100000100000001, 0x0100000100000100,\n    0x0100000100000102, 0x0100000100000201, 0x0100000100010002, 0x0100000100010101,\n    0x0100000100010102, 0x0100000100010200, 0x0100000100010202, 0x0100000100020001,\n    0x0100000100020102, 0x0100000100020201, 0x0100000101000101, 0x0100000101000200,\n    0x0100000101000202, 0x0100000101010001, 0x0100000101010100, 0x0100000101010101,\n    0x0100000101010102, 0x0100000101010201, 0x0100000101010202, 0x0100000101020101,\n    0x0100000101020200, 0x0100000101020202, 0x0100000102000001, 0x0100000102000100,\n    0x0100000102000102, 0x0100000102010000, 0x0100000102010002, 0x0100000102010101,\n    0x0100000102020000, 0x0100000102020001, 0x0100000102020002, 0x0100000200000101,\n    0x0100000200010001, 0x0100000200010100, 
0x0100000200010102, 0x0100000200020101,\n    0x0100000201000001, 0x0100000201010002, 0x0100000201010101, 0x0100000201010202,\n    0x0100000201020100, 0x0100000201020201, 0x0100000202000201, 0x0100000202010100,\n    0x0100000202020101, 0x0100010000000001, 0x0100010000010101, 0x0100010000010201,\n    0x0100010000020201, 0x0100010001000101, 0x0100010001000200, 0x0100010001000202,\n    0x0100010001010001, 0x0100010001010100, 0x0100010001010101, 0x0100010001010102,\n    0x0100010001020001, 0x0100010001020002, 0x0100010001020101, 0x0100010001020200,\n    0x0100010001020202, 0x0100010002000001, 0x0100010002000102, 0x0100010002000201,\n    0x0100010002010000, 0x0100010002010002, 0x0100010002010101, 0x0100010002020000,\n    0x0100010002020001, 0x0100010002020201, 0x0100010100000001, 0x0100010100000002,\n    0x0100010100000101, 0x0100010100000202, 0x0100010100010001, 0x0100010100010100,\n    0x0100010100010101, 0x0100010100010102, 0x0100010100010201, 0x0100010100020000,\n    0x0100010100020101, 0x0100010100020202, 0x0100010101000001, 0x0100010101000100,\n    0x0100010101000101, 0x0100010101000102, 0x0100010101000201, 0x0100010101010000,\n    0x0100010101010001, 0x0100010101010100, 0x0100010101010101, 0x0100010101010102,\n    0x0100010101010200, 0x0100010101010201, 0x0100010101020001, 0x0100010101020100,\n    0x0100010101020101, 0x0100010101020102, 0x0100010101020201, 0x0100010102000002,\n    0x0100010102000100, 0x0100010102000101, 0x0100010102000200, 0x0100010102010001,\n    0x0100010102010100, 0x0100010102010101, 0x0100010102010102, 0x0100010102010201,\n    0x0100010102010202, 0x0100010102020101, 0x0100010102020200, 0x0100010102020202,\n    0x0100010200000001, 0x0100010200000101, 0x0100010200000201, 0x0100010200010100,\n    0x0100010200010101, 0x0100010200010200, 0x0100010200010202, 0x0100010200020001,\n    0x0100010200020100, 0x0100010200020201, 0x0100010201000000, 0x0100010201000002,\n    0x0100010201000101, 0x0100010201000200, 0x0100010201010000, 0x0100010201010001,\n    
0x0100010201010002, 0x0100010201010101, 0x0100010201010102, 0x0100010201010201,\n    0x0100010201020002, 0x0100010201020101, 0x0100010201020200, 0x0100010202000001,\n    0x0100010202000101, 0x0100010202000202, 0x0100010202010100, 0x0100010202010101,\n    0x0100010202020001, 0x0100010202020100, 0x0100010202020102, 0x0100020000000101,\n    0x0100020000010001, 0x0100020000010101, 0x0100020000010202, 0x0100020000020101,\n    0x0100020001000002, 0x0100020001000201, 0x0100020001010000, 0x0100020001010101,\n    0x0100020001010200, 0x0100020001020001, 0x0100020001020100, 0x0100020001020102,\n    0x0100020001020201, 0x0100020002000101, 0x0100020002010001, 0x0100020002010100,\n    0x0100020002010102, 0x0100020002010201, 0x0100020002020101, 0x0100020100000001,\n    0x0100020100000101, 0x0100020100000102, 0x0100020100000202, 0x0100020100010000,\n    0x0100020100010100, 0x0100020100010101, 0x0100020100010200, 0x0100020100020001,\n    0x0100020100020100, 0x0100020100020102, 0x0100020101000000, 0x0100020101000101,\n    0x0100020101000202, 0x0100020101010001, 0x0100020101010002, 0x0100020101010100,\n    0x0100020101010101, 0x0100020101010102, 0x0100020101010201, 0x0100020101020000,\n    0x0100020101020002, 0x0100020101020101, 0x0100020101020102, 0x0100020101020202,\n    0x0100020102000102, 0x0100020102000201, 0x0100020102010002, 0x0100020102010101,\n    0x0100020102010102, 0x0100020102010200, 0x0100020102020001, 0x0100020102020100,\n    0x0100020102020102, 0x0100020102020201, 0x0100020200010102, 0x0100020201000100,\n    0x0100020201000102, 0x0100020201000201, 0x0100020201010101, 0x0100020201010200,\n    0x0100020201010202, 0x0100020201020100, 0x0100020201020201, 0x0100020202010100,\n    0x0100020202020101, 0x0101000000000001, 0x0101000000000100, 0x0101000000000101,\n    0x0101000000000102, 0x0101000000000201, 0x0101000000010002, 0x0101000000010101,\n    0x0101000000010202, 0x0101000000020001, 0x0101000000020100, 0x0101000000020201,\n    0x0101000001000000, 0x0101000001000101, 
0x0101000001000200, 0x0101000001010001,\n    0x0101000001010100, 0x0101000001010101, 0x0101000001010102, 0x0101000001010201,\n    0x0101000001020101, 0x0101000001020200, 0x0101000002000102, 0x0101000002000201,\n    0x0101000002010101, 0x0101000002010200, 0x0101000002020000, 0x0101000002020001,\n    0x0101000002020102, 0x0101000002020201, 0x0101000100000101, 0x0101000100000200,\n    0x0101000100000201, 0x0101000100000202, 0x0101000100010001, 0x0101000100010100,\n    0x0101000100010101, 0x0101000100010102, 0x0101000100010200, 0x0101000100010201,\n    0x0101000100020000, 0x0101000100020101, 0x0101000100020102, 0x0101000100020200,\n    0x0101000100020202, 0x0101000101000001, 0x0101000101000100, 0x0101000101000101,\n    0x0101000101000102, 0x0101000101000201, 0x0101000101010000, 0x0101000101010001,\n    0x0101000101010002, 0x0101000101010100, 0x0101000101010101, 0x0101000101010102,\n    0x0101000101010200, 0x0101000101010201, 0x0101000101010202, 0x0101000101020001,\n    0x0101000101020100, 0x0101000101020101, 0x0101000101020102, 0x0101000101020201,\n    0x0101000102000002, 0x0101000102000101, 0x0101000102010001, 0x0101000102010100,\n    0x0101000102010101, 0x0101000102010102, 0x0101000102010201, 0x0101000102020000,\n    0x0101000102020101, 0x0101000102020202, 0x0101000200000001, 0x0101000200000102,\n    0x0101000200010002, 0x0101000200010101, 0x0101000200010202, 0x0101000200020001,\n    0x0101000200020100, 0x0101000201000002, 0x0101000201000101, 0x0101000201000202,\n    0x0101000201010001, 0x0101000201010100, 0x0101000201010101, 0x0101000201010102,\n    0x0101000201010201, 0x0101000201020002, 0x0101000201020101, 0x0101000202000101,\n    0x0101000202010000, 0x0101000202010002, 0x0101000202010101, 0x0101000202010201,\n    0x0101000202010202, 0x0101000202020100, 0x0101010000000100, 0x0101010000000101,\n    0x0101010000010001, 0x0101010000010100, 0x0101010000010101, 0x0101010000010102,\n    0x0101010000010200, 0x0101010000010201, 0x0101010000020001, 0x0101010000020101,\n    
0x0101010000020200, 0x0101010000020202, 0x0101010001000001, 0x0101010001000100,\n    0x0101010001000101, 0x0101010001000102, 0x0101010001000201, 0x0101010001000202,\n    0x0101010001010000, 0x0101010001010001, 0x0101010001010100, 0x0101010001010101,\n    0x0101010001010102, 0x0101010001010200, 0x0101010001010201, 0x0101010001010202,\n    0x0101010001020001, 0x0101010001020002, 0x0101010001020100, 0x0101010001020101,\n    0x0101010001020102, 0x0101010001020201, 0x0101010002000000, 0x0101010002000200,\n    0x0101010002000202, 0x0101010002010001, 0x0101010002010100, 0x0101010002010101,\n    0x0101010002010102, 0x0101010002010201, 0x0101010002020001, 0x0101010002020100,\n    0x0101010002020101, 0x0101010002020202, 0x0101010100000001, 0x0101010100000002,\n    0x0101010100000100, 0x0101010100000101, 0x0101010100000102, 0x0101010100000201,\n    0x0101010100010000, 0x0101010100010001, 0x0101010100010002, 0x0101010100010100,\n    0x0101010100010101, 0x0101010100010102, 0x0101010100010201, 0x0101010100010202,\n    0x0101010100020001, 0x0101010100020100, 0x0101010100020101, 0x0101010100020102,\n    0x0101010100020201, 0x0101010101000000, 0x0101010101000001, 0x0101010101000002,\n    0x0101010101000100, 0x0101010101000101, 0x0101010101000102, 0x0101010101000200,\n    0x0101010101000201, 0x0101010101010000, 0x0101010101010001, 0x0101010101010002,\n    0x0101010101010100, 0x0101010101010101, 0x0101010101010102, 0x0101010101010200,\n    0x0101010101010201, 0x0101010101010202, 0x0101010101020000, 0x0101010101020001,\n    0x0101010101020100, 0x0101010101020101, 0x0101010101020102, 0x0101010101020200,\n    0x0101010101020201, 0x0101010101020202, 0x0101010102000001, 0x0101010102000100,\n    0x0101010102000101, 0x0101010102000201, 0x0101010102000202, 0x0101010102010000,\n    0x0101010102010001, 0x0101010102010100, 0x0101010102010101, 0x0101010102010102,\n    0x0101010102010200, 0x0101010102010201, 0x0101010102020001, 0x0101010102020100,\n    0x0101010102020101, 0x0101010102020102, 
0x0101010102020201, 0x0101010200000000,\n    0x0101010200000001, 0x0101010200000002, 0x0101010200000100, 0x0101010200000102,\n    0x0101010200000200, 0x0101010200000201, 0x0101010200010001, 0x0101010200010100,\n    0x0101010200010101, 0x0101010200010200, 0x0101010200010201, 0x0101010200020000,\n    0x0101010200020001, 0x0101010200020002, 0x0101010200020100, 0x0101010200020101,\n    0x0101010200020102, 0x0101010200020200, 0x0101010200020201, 0x0101010201000001,\n    0x0101010201000101, 0x0101010201000102, 0x0101010201000200, 0x0101010201000201,\n    0x0101010201000202, 0x0101010201010000, 0x0101010201010001, 0x0101010201010002,\n    0x0101010201010100, 0x0101010201010101, 0x0101010201010102, 0x0101010201010200,\n    0x0101010201010201, 0x0101010201010202, 0x0101010201020001, 0x0101010201020100,\n    0x0101010201020101, 0x0101010201020201, 0x0101010202000002, 0x0101010202000101,\n    0x0101010202000102, 0x0101010202000200, 0x0101010202000201, 0x0101010202000202,\n    0x0101010202010001, 0x0101010202010101, 0x0101010202010202, 0x0101010202020002,\n    0x0101010202020101, 0x0101010202020102, 0x0101010202020200, 0x0101010202020201,\n    0x0101020000000100, 0x0101020000000101, 0x0101020000000102, 0x0101020000000201,\n    0x0101020000010000, 0x0101020000010101, 0x0101020000010200, 0x0101020000020001,\n    0x0101020000020202, 0x0101020001000101, 0x0101020001000200, 0x0101020001000202,\n    0x0101020001010001, 0x0101020001010100, 0x0101020001010101, 0x0101020001010102,\n    0x0101020001010200, 0x0101020001010201, 0x0101020001020000, 0x0101020001020002,\n    0x0101020001020100, 0x0101020001020101, 0x0101020002000002, 0x0101020002000201,\n    0x0101020002010000, 0x0101020002010002, 0x0101020002010101, 0x0101020002010200,\n    0x0101020002020001, 0x0101020002020201, 0x0101020100000001, 0x0101020100000002,\n    0x0101020100000101, 0x0101020100000202, 0x0101020100010001, 0x0101020100010100,\n    0x0101020100010101, 0x0101020100010102, 0x0101020100010201, 0x0101020100020101,\n    
0x0101020101000001, 0x0101020101000100, 0x0101020101000101, 0x0101020101000102,\n    0x0101020101000201, 0x0101020101010000, 0x0101020101010001, 0x0101020101010002,\n    0x0101020101010100, 0x0101020101010101, 0x0101020101010102, 0x0101020101010200,\n    0x0101020101010201, 0x0101020101010202, 0x0101020101020001, 0x0101020101020100,\n    0x0101020101020101, 0x0101020101020102, 0x0101020101020201, 0x0101020102000001,\n    0x0101020102000101, 0x0101020102000201, 0x0101020102010001, 0x0101020102010100,\n    0x0101020102010101, 0x0101020102010102, 0x0101020102010200, 0x0101020102010201,\n    0x0101020102020101, 0x0101020200000100, 0x0101020200000200, 0x0101020200010101,\n    0x0101020200010202, 0x0101020200020000, 0x0101020200020101, 0x0101020200020102,\n    0x0101020200020201, 0x0101020201000101, 0x0101020201000200, 0x0101020201000201,\n    0x0101020201010001, 0x0101020201010101, 0x0101020201010102, 0x0101020201010200,\n    0x0101020201010201, 0x0101020201020002, 0x0101020201020101, 0x0101020201020200,\n    0x0101020201020202, 0x0101020202000001, 0x0101020202000202, 0x0101020202010002,\n    0x0101020202010101, 0x0101020202010102, 0x0101020202010200, 0x0101020202010202,\n    0x0101020202020001, 0x0102000000000101, 0x0102000000010100, 0x0102000000010102,\n    0x0102000000010201, 0x0102000000020101, 0x0102000001000100, 0x0102000001010000,\n    0x0102000001010101, 0x0102000001010102, 0x0102000001010200, 0x0102000001010202,\n    0x0102000001020001, 0x0102000001020100, 0x0102000001020102, 0x0102000001020201,\n    0x0102000002000001, 0x0102000002010102, 0x0102000002020101, 0x0102000100000001,\n    0x0102000100000100, 0x0102000100000102, 0x0102000100000201, 0x0102000100010002,\n    0x0102000100010101, 0x0102000100020001, 0x0102000100020002, 0x0102000100020102,\n    0x0102000100020201, 0x0102000101000101, 0x0102000101000201, 0x0102000101010001,\n    0x0102000101010101, 0x0102000101010102, 0x0102000101010201, 0x0102000101020101,\n    0x0102000101020102, 0x0102000101020202, 
0x0102000102000100, 0x0102000102000202,\n    0x0102000102010002, 0x0102000102010101, 0x0102000102020001, 0x0102000102020102,\n    0x0102000102020201, 0x0102000200010001, 0x0102000200010102, 0x0102000200010201,\n    0x0102000201000000, 0x0102000201000001, 0x0102000201000102, 0x0102000201010101,\n    0x0102000201010102, 0x0102000201010200, 0x0102000201020000, 0x0102000202000101,\n    0x0102000202010001, 0x0102000202010102, 0x0102000202020101, 0x0102010000010001,\n    0x0102010000010002, 0x0102010000010101, 0x0102010000010102, 0x0102010000010202,\n    0x0102010000020001, 0x0102010000020102, 0x0102010000020201, 0x0102010001000000,\n    0x0102010001000002, 0x0102010001000101, 0x0102010001000200, 0x0102010001000202,\n    0x0102010001010001, 0x0102010001010100, 0x0102010001010101, 0x0102010001010102,\n    0x0102010001010201, 0x0102010001010202, 0x0102010001020000, 0x0102010001020002,\n    0x0102010001020101, 0x0102010002000100, 0x0102010002000101, 0x0102010002000201,\n    0x0102010002010000, 0x0102010002010002, 0x0102010002010100, 0x0102010002010101,\n    0x0102010002010102, 0x0102010002010200, 0x0102010002010202, 0x0102010002020001,\n    0x0102010002020100, 0x0102010002020201, 0x0102010100000101, 0x0102010100000200,\n    0x0102010100000202, 0x0102010100010001, 0x0102010100010101, 0x0102010100010102,\n    0x0102010100010201, 0x0102010101000100, 0x0102010101000101, 0x0102010101000102,\n    0x0102010101000201, 0x0102010101010000, 0x0102010101010001, 0x0102010101010100,\n    0x0102010101010101, 0x0102010101010102, 0x0102010101010201, 0x0102010101020001,\n    0x0102010101020100, 0x0102010101020101, 0x0102010101020102, 0x0102010101020201,\n    0x0102010102000102, 0x0102010102000201, 0x0102010102000202, 0x0102010102010001,\n    0x0102010102010101, 0x0102010102010102, 0x0102010102010201, 0x0102010102010202,\n    0x0102010102020002, 0x0102010102020101, 0x0102010102020102, 0x0102010102020200,\n    0x0102010200000002, 0x0102010200000201, 0x0102010200010101, 0x0102010200020000,\n    
0x0102010200020102, 0x0102010200020200, 0x0102010200020201, 0x0102010201000000,\n    0x0102010201000101, 0x0102010201000200, 0x0102010201000202, 0x0102010201010001,\n    0x0102010201010100, 0x0102010201010101, 0x0102010201010102, 0x0102010201010200,\n    0x0102010201010202, 0x0102010201020000, 0x0102010201020101, 0x0102010201020200,\n    0x0102010202000000, 0x0102010202000002, 0x0102010202000101, 0x0102010202000202,\n    0x0102010202010100, 0x0102010202010102, 0x0102010202010200, 0x0102010202010201,\n    0x0102010202020000, 0x0102010202020100, 0x0102010202020102, 0x0102010202020202,\n    0x0102020000010102, 0x0102020000010201, 0x0102020000020101, 0x0102020001000001,\n    0x0102020001010002, 0x0102020001010101, 0x0102020001010202, 0x0102020001020001,\n    0x0102020001020201, 0x0102020002000101, 0x0102020002010001, 0x0102020002010200,\n    0x0102020002020102, 0x0102020100000001, 0x0102020100000100, 0x0102020100010000,\n    0x0102020100010101, 0x0102020100020001, 0x0102020100020100, 0x0102020100020102,\n    0x0102020100020201, 0x0102020101000000, 0x0102020101000001, 0x0102020101000101,\n    0x0102020101000102, 0x0102020101000200, 0x0102020101010001, 0x0102020101010100,\n    0x0102020101010101, 0x0102020101010102, 0x0102020101010201, 0x0102020101020000,\n    0x0102020101020101, 0x0102020101020202, 0x0102020102000002, 0x0102020102000100,\n    0x0102020102000202, 0x0102020102010101, 0x0102020102020001, 0x0102020102020100,\n    0x0102020102020101, 0x0102020102020201, 0x0102020200010001, 0x0102020200010102,\n    0x0102020200010200, 0x0102020201000001, 0x0102020201000100, 0x0102020201000201,\n    0x0102020201010000, 0x0102020201010101, 0x0102020201010200, 0x0102020201010202,\n    0x0102020201020100, 0x0102020201020101, 0x0102020201020201, 0x0102020202000102,\n    0x0102020202010100, 0x0102020202010200, 0x0102020202010202, 0x0102020202020102,\n    0x0200000000000000, 0x0200000000000002, 0x0200000000000200, 0x0200000000000202,\n    0x0200000000020000, 0x0200000000020002, 
0x0200000000020200, 0x0200000000020202,\n    0x0200000001000101, 0x0200000001010000, 0x0200000001010001, 0x0200000001010100,\n    0x0200000001010102, 0x0200000001010201, 0x0200000001020101, 0x0200000002000000,\n    0x0200000002000002, 0x0200000002000200, 0x0200000002000202, 0x0200000002010101,\n    0x0200000002020000, 0x0200000002020002, 0x0200000002020200, 0x0200000002020202,\n    0x0200000100000101, 0x0200000100010001, 0x0200000100010100, 0x0200000100010102,\n    0x0200000100010201, 0x0200000100020101, 0x0200000101000001, 0x0200000101000100,\n    0x0200000101000201, 0x0200000101010000, 0x0200000101010002, 0x0200000101010101,\n    0x0200000101010102, 0x0200000101010200, 0x0200000101010201, 0x0200000101020100,\n    0x0200000101020102, 0x0200000101020201, 0x0200000102000101, 0x0200000102000201,\n    0x0200000102010100, 0x0200000102010102, 0x0200000102010201, 0x0200000102020101,\n    0x0200000200000000, 0x0200000200000002, 0x0200000200000200, 0x0200000200000202,\n    0x0200000200010101, 0x0200000200020000, 0x0200000200020002, 0x0200000200020200,\n    0x0200000200020202, 0x0200000201010001, 0x0200000201010100, 0x0200000201010201,\n    0x0200000201020101, 0x0200000202000000, 0x0200000202000002, 0x0200000202000200,\n    0x0200000202000202, 0x0200000202010101, 0x0200000202020000, 0x0200000202020002,\n    0x0200000202020200, 0x0200000202020202, 0x0200010000010100, 0x0200010000010201,\n    0x0200010001000001, 0x0200010001000100, 0x0200010001010001, 0x0200010001010101,\n    0x0200010001010202, 0x0200010001020001, 0x0200010001020100, 0x0200010001020201,\n    0x0200010002010100, 0x0200010002010201, 0x0200010100000001, 0x0200010100000201,\n    0x0200010100010002, 0x0200010100010101, 0x0200010100010202, 0x0200010100020102,\n    0x0200010100020201, 0x0200010101000000, 0x0200010101000001, 0x0200010101000101,\n    0x0200010101000200, 0x0200010101010001, 0x0200010101010100, 0x0200010101010101,\n    0x0200010101010102, 0x0200010101010201, 0x0200010101010202, 0x0200010101020101,\n    
0x0200010101020102, 0x0200010101020200, 0x0200010101020202, 0x0200010102000001,\n    0x0200010102000100, 0x0200010102000102, 0x0200010102000201, 0x0200010102010000,\n    0x0200010102010002, 0x0200010102010101, 0x0200010102010200, 0x0200010102020102,\n    0x0200010200010001, 0x0200010200010102, 0x0200010200010201, 0x0200010200020101,\n    0x0200010201000001, 0x0200010201000100, 0x0200010201000201, 0x0200010201000202,\n    0x0200010201010000, 0x0200010201010101, 0x0200010201010201, 0x0200010201010202,\n    0x0200010201020001, 0x0200010201020102, 0x0200010201020202, 0x0200010202000101,\n    0x0200010202010001, 0x0200010202010202, 0x0200010202020100, 0x0200020000000000,\n    0x0200020000000002, 0x0200020000000200, 0x0200020000000202, 0x0200020000010101,\n    0x0200020000020000, 0x0200020000020002, 0x0200020000020200, 0x0200020000020202,\n    0x0200020001000001, 0x0200020001000101, 0x0200020001010001, 0x0200020001010100,\n    0x0200020001010201, 0x0200020001020101, 0x0200020001020201, 0x0200020002000000,\n    0x0200020002000002, 0x0200020002000200, 0x0200020002000202, 0x0200020002010101,\n    0x0200020002020000, 0x0200020002020002, 0x0200020002020200, 0x0200020002020202,\n    0x0200020100000101, 0x0200020100000102, 0x0200020100010001, 0x0200020100010100,\n    0x0200020100010102, 0x0200020100020101, 0x0200020101000001, 0x0200020101000100,\n    0x0200020101000102, 0x0200020101000201, 0x0200020101010000, 0x0200020101010002,\n    0x0200020101010101, 0x0200020101010202, 0x0200020101020001, 0x0200020101020100,\n    0x0200020102000101, 0x0200020102010102, 0x0200020102010201, 0x0200020102020101,\n    0x0200020200000000, 0x0200020200000002, 0x0200020200000200, 0x0200020200000202,\n    0x0200020200010101, 0x0200020200020000, 0x0200020200020002, 0x0200020200020200,\n    0x0200020200020202, 0x0200020201000101, 0x0200020201010001, 0x0200020201010100,\n    0x0200020201010102, 0x0200020202000000, 0x0200020202000002, 0x0200020202000200,\n    0x0200020202000202, 0x0200020202010101, 
0x0200020202020000, 0x0200020202020002,\n    0x0200020202020200, 0x0200020202020202, 0x0201000000000101, 0x0201000000010001,\n    0x0201000000010102, 0x0201000000010200, 0x0201000000010201, 0x0201000000020101,\n    0x0201000001000001, 0x0201000001000102, 0x0201000001000201, 0x0201000001010101,\n    0x0201000001010200, 0x0201000001010202, 0x0201000001020201, 0x0201000001020202,\n    0x0201000002000101, 0x0201000002010001, 0x0201000002010100, 0x0201000002010102,\n    0x0201000002010201, 0x0201000002020101, 0x0201000100000001, 0x0201000100000100,\n    0x0201000100000102, 0x0201000100000201, 0x0201000100010000, 0x0201000100010101,\n    0x0201000100010200, 0x0201000100010202, 0x0201000100020001, 0x0201000100020100,\n    0x0201000100020102, 0x0201000100020201, 0x0201000101000000, 0x0201000101000101,\n    0x0201000101010000, 0x0201000101010001, 0x0201000101010100, 0x0201000101010101,\n    0x0201000101010102, 0x0201000101010201, 0x0201000101020002, 0x0201000101020101,\n    0x0201000102000100, 0x0201000102000102, 0x0201000102010002, 0x0201000102010101,\n    0x0201000102010200, 0x0201000102020001, 0x0201000102020100, 0x0201000102020102,\n    0x0201000102020201, 0x0201000200000101, 0x0201000200010001, 0x0201000200010100,\n    0x0201000200010201, 0x0201000200020101, 0x0201000201000100, 0x0201000201000102,\n    0x0201000201000201, 0x0201000201010000, 0x0201000201010002, 0x0201000201010101,\n    0x0201000201010200, 0x0201000201020102, 0x0201000201020201, 0x0201000202000101,\n    0x0201000202010100, 0x0201000202010102, 0x0201000202020201, 0x0201010000000001,\n    0x0201010000000100, 0x0201010000000102, 0x0201010000010000, 0x0201010000010101,\n    0x0201010000010200, 0x0201010000020102, 0x0201010001000000, 0x0201010001000202,\n    0x0201010001010001, 0x0201010001010100, 0x0201010001010101, 0x0201010001010102,\n    0x0201010001010200, 0x0201010001010201, 0x0201010001020000, 0x0201010001020001,\n    0x0201010001020002, 0x0201010001020101, 0x0201010002000100, 0x0201010002000102,\n    
0x0201010002010002, 0x0201010002010100, 0x0201010002010101, 0x0201010002010200,\n    0x0201010002020001, 0x0201010002020201, 0x0201010100000000, 0x0201010100000101,\n    0x0201010100000200, 0x0201010100000202, 0x0201010100010000, 0x0201010100010001,\n    0x0201010100010100, 0x0201010100010101, 0x0201010100010102, 0x0201010100010201,\n    0x0201010100020001, 0x0201010100020101, 0x0201010100020201, 0x0201010100020202,\n    0x0201010101000001, 0x0201010101000100, 0x0201010101000101, 0x0201010101000102,\n    0x0201010101000201, 0x0201010101010000, 0x0201010101010001, 0x0201010101010002,\n    0x0201010101010100, 0x0201010101010101, 0x0201010101010102, 0x0201010101010200,\n    0x0201010101010201, 0x0201010101010202, 0x0201010101020001, 0x0201010101020100,\n    0x0201010101020101, 0x0201010101020102, 0x0201010101020201, 0x0201010102000001,\n    0x0201010102000101, 0x0201010102000200, 0x0201010102010001, 0x0201010102010002,\n    0x0201010102010100, 0x0201010102010101, 0x0201010102010102, 0x0201010102010201,\n    0x0201010102010202, 0x0201010102020000, 0x0201010102020002, 0x0201010102020101,\n    0x0201010102020200, 0x0201010102020202, 0x0201010200000001, 0x0201010200000100,\n    0x0201010200010000, 0x0201010200010101, 0x0201010200010201, 0x0201010200020000,\n    0x0201010200020102, 0x0201010200020201, 0x0201010201000101, 0x0201010201000200,\n    0x0201010201000201, 0x0201010201010001, 0x0201010201010002, 0x0201010201010101,\n    0x0201010201010102, 0x0201010201010201, 0x0201010201020101, 0x0201010201020200,\n    0x0201010202000002, 0x0201010202000100, 0x0201010202000201, 0x0201010202000202,\n    0x0201010202010002, 0x0201010202010100, 0x0201010202010101, 0x0201010202020100,\n    0x0201010202020102, 0x0201010202020201, 0x0201020000000101, 0x0201020000010102,\n    0x0201020000010201, 0x0201020000020101, 0x0201020001000001, 0x0201020001000102,\n    0x0201020001010000, 0x0201020001010002, 0x0201020001010101, 0x0201020001010102,\n    0x0201020001010202, 0x0201020001020100, 
0x0201020001020101, 0x0201020002000101,\n    0x0201020002010001, 0x0201020002010102, 0x0201020002010201, 0x0201020002020101,\n    0x0201020100000100, 0x0201020100000102, 0x0201020100000201, 0x0201020100010000,\n    0x0201020100010002, 0x0201020100010101, 0x0201020100010200, 0x0201020100010202,\n    0x0201020100020000, 0x0201020100020001, 0x0201020100020100, 0x0201020100020102,\n    0x0201020101000000, 0x0201020101000002, 0x0201020101000101, 0x0201020101000200,\n    0x0201020101000202, 0x0201020101010001, 0x0201020101010100, 0x0201020101010101,\n    0x0201020101010102, 0x0201020101010201, 0x0201020101020002, 0x0201020101020101,\n    0x0201020101020102, 0x0201020101020202, 0x0201020102000001, 0x0201020102000100,\n    0x0201020102010000, 0x0201020102010002, 0x0201020102010101, 0x0201020102010202,\n    0x0201020102020001, 0x0201020102020102, 0x0201020200000101, 0x0201020200010101,\n    0x0201020200020101, 0x0201020201000100, 0x0201020201000102, 0x0201020201000201,\n    0x0201020201010000, 0x0201020201010101, 0x0201020201010200, 0x0201020201020001,\n    0x0201020202000101, 0x0201020202010001, 0x0201020202010100, 0x0201020202010101,\n    0x0201020202010102, 0x0202000000000000, 0x0202000000000002, 0x0202000000000200,\n    0x0202000000000202, 0x0202000000010101, 0x0202000000020000, 0x0202000000020002,\n    0x0202000000020200, 0x0202000000020202, 0x0202000001000101, 0x0202000001010001,\n    0x0202000001010100, 0x0202000001010102, 0x0202000001010201, 0x0202000002000000,\n    0x0202000002000002, 0x0202000002000200, 0x0202000002000202, 0x0202000002010101,\n    0x0202000002020000, 0x0202000002020002, 0x0202000002020200, 0x0202000002020202,\n    0x0202000100000101, 0x0202000100000201, 0x0202000100010001, 0x0202000100010100,\n    0x0202000100010102, 0x0202000100010201, 0x0202000100010202, 0x0202000101000102,\n    0x0202000101000201, 0x0202000101010001, 0x0202000101010101, 0x0202000101010200,\n    0x0202000101010202, 0x0202000101020001, 0x0202000101020100, 0x0202000102000101,\n    
0x0202000102010000, 0x0202000102010002, 0x0202000102010102, 0x0202000102010201,\n    0x0202000200000002, 0x0202000200000200, 0x0202000200000202, 0x0202000200010000,\n    0x0202000200010201, 0x0202000200020002, 0x0202000200020200, 0x0202000200020202,\n    0x0202000201000101, 0x0202000201010001, 0x0202000201010102, 0x0202000201010201,\n    0x0202000201020101, 0x0202000202000000, 0x0202000202000002, 0x0202000202000200,\n    0x0202000202000202, 0x0202000202010101, 0x0202000202020000, 0x0202000202020002,\n    0x0202000202020200, 0x0202000202020202, 0x0202010000010201, 0x0202010000020101,\n    0x0202010001000001, 0x0202010001000100, 0x0202010001010000, 0x0202010001010100,\n    0x0202010001010101, 0x0202010001010200, 0x0202010001010202, 0x0202010001020001,\n    0x0202010001020101, 0x0202010001020102, 0x0202010001020200, 0x0202010001020201,\n    0x0202010002000101, 0x0202010100000102, 0x0202010100000201, 0x0202010100010000,\n    0x0202010100010002, 0x0202010100010101, 0x0202010100010200, 0x0202010100020102,\n    0x0202010100020201, 0x0202010101000002, 0x0202010101000101, 0x0202010101010001,\n    0x0202010101010100, 0x0202010101010101, 0x0202010101010102, 0x0202010101010201,\n    0x0202010101020101, 0x0202010101020202, 0x0202010102000001, 0x0202010102000100,\n    0x0202010102000101, 0x0202010102000102, 0x0202010102000201, 0x0202010102010002,\n    0x0202010102010101, 0x0202010102010200, 0x0202010200000101, 0x0202010200010001,\n    0x0202010200010102, 0x0202010200010202, 0x0202010200020001, 0x0202010200020101,\n    0x0202010201000100, 0x0202010201000102, 0x0202010201000202, 0x0202010201010002,\n    0x0202010201010101, 0x0202010201010102, 0x0202010201010200, 0x0202010201020000,\n    0x0202010201020002, 0x0202010202000102, 0x0202010202010000, 0x0202010202010101,\n    0x0202010202010102, 0x0202010202010201, 0x0202010202020001, 0x0202010202020100,\n    0x0202010202020102, 0x0202020000000000, 0x0202020000000002, 0x0202020000000200,\n    0x0202020000000202, 0x0202020000020000, 
0x0202020000020002, 0x0202020000020200,\n    0x0202020000020202, 0x0202020001010001, 0x0202020001010100, 0x0202020001010102,\n    0x0202020001010201, 0x0202020002000000, 0x0202020002000002, 0x0202020002000200,\n    0x0202020002000202, 0x0202020002010101, 0x0202020002020000, 0x0202020002020002,\n    0x0202020002020200, 0x0202020002020202, 0x0202020100000101, 0x0202020100010100,\n    0x0202020100010201, 0x0202020100020001, 0x0202020100020101, 0x0202020101000001,\n    0x0202020101010000, 0x0202020101010101, 0x0202020101010202, 0x0202020101020001,\n    0x0202020101020102, 0x0202020101020201, 0x0202020102010000, 0x0202020102010102,\n    0x0202020200000000, 0x0202020200000002, 0x0202020200000200, 0x0202020200000202,\n    0x0202020200020000, 0x0202020200020002, 0x0202020200020200, 0x0202020200020202,\n    0x0202020201010001, 0x0202020201010100, 0x0202020201010102, 0x0202020202000000,\n    0x0202020202000002, 0x0202020202000200, 0x0202020202000202, 0x0202020202010101,\n    0x0202020202020000, 0x0202020202020002, 0x0202020202020200, 0x0202020202020202,\n};\n#else\nstatic const uint32_t iq1s_grid_us[2048] = {\n    0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000,\n    0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101,\n    0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200,\n    0x02020202, 0x00000110, 0x00000111, 0x00010011, 0x00010110, 0x00010112, 0x00010211, 0x00010212,\n    0x00020111, 0x01000011, 0x01000112, 0x01000211, 0x01010012, 0x01010111, 0x01010212, 0x01020011,\n    0x01020110, 0x01020112, 0x01020210, 0x02000111, 0x02010011, 0x02010110, 0x02010112, 0x02020111,\n    0x00000020, 0x00000022, 0x00000220, 0x00000222, 0x00010121, 0x00020020, 0x00020022, 0x00020220,\n    0x00020222, 0x01000121, 0x01010021, 0x01010221, 0x01020120, 0x01020221, 0x02000020, 0x02000022,\n    0x02000220, 0x02000222, 0x02010021, 0x02010121, 0x02010221, 
0x02020020, 0x02020022, 0x02020220,\n    0x02020222, 0x00011001, 0x00011100, 0x00011102, 0x00021101, 0x01001001, 0x01001201, 0x01011101,\n    0x01011202, 0x01021100, 0x01021101, 0x02011001, 0x02011201, 0x02021101, 0x00001011, 0x00001110,\n    0x00001111, 0x00001112, 0x00011111, 0x00011210, 0x00011212, 0x00021211, 0x01001010, 0x01001111,\n    0x01001212, 0x01011010, 0x01011011, 0x01011110, 0x01011111, 0x01011112, 0x01011211, 0x01021010,\n    0x01021012, 0x01021111, 0x01021210, 0x01021212, 0x02001011, 0x02011011, 0x02011111, 0x02011210,\n    0x02011212, 0x02021011, 0x02021110, 0x02021111, 0x02021112, 0x02021211, 0x00011120, 0x00011221,\n    0x01001021, 0x01001120, 0x01011020, 0x01011022, 0x01011121, 0x01011220, 0x01021020, 0x01021021,\n    0x01021122, 0x01021221, 0x02001121, 0x02011021, 0x02011120, 0x02011221, 0x00002000, 0x00002002,\n    0x00002200, 0x00002202, 0x00012101, 0x00022000, 0x00022002, 0x00022200, 0x00022202, 0x01002101,\n    0x01012001, 0x01012102, 0x01022101, 0x02002000, 0x02002002, 0x02002200, 0x02002202, 0x02012101,\n    0x02022000, 0x02022002, 0x02022200, 0x02022202, 0x00002111, 0x00012011, 0x00012110, 0x00012211,\n    0x00022110, 0x00022111, 0x01002011, 0x01012010, 0x01012011, 0x01012111, 0x01022011, 0x01022110,\n    0x01022211, 0x02012011, 0x02012110, 0x02012112, 0x02012211, 0x02022111, 0x00002020, 0x00002022,\n    0x00002220, 0x00002222, 0x00012121, 0x00022020, 0x00022022, 0x00022220, 0x00022222, 0x01002121,\n    0x01012021, 0x01012221, 0x01022021, 0x01022121, 0x02002020, 0x02002022, 0x02002121, 0x02002220,\n    0x02002222, 0x02012121, 0x02022020, 0x02022022, 0x02022220, 0x02022222, 0x00110000, 0x00110001,\n    0x00110100, 0x00110201, 0x00120100, 0x00120101, 0x01100001, 0x01100100, 0x01110000, 0x01110101,\n    0x01110200, 0x01120001, 0x01120100, 0x01120101, 0x01120201, 0x02110001, 0x02110100, 0x02110102,\n    0x02120001, 0x02120101, 0x00100011, 0x00100110, 0x00100112, 0x00100211, 0x00110010, 0x00110012,\n    0x00110111, 0x00110210, 0x00120011, 
0x00120110, 0x00120211, 0x01100111, 0x01100212, 0x01110010,\n    0x01110011, 0x01110012, 0x01110110, 0x01110111, 0x01110112, 0x01110211, 0x01120010, 0x01120111,\n    0x02100110, 0x02110012, 0x02110111, 0x02120011, 0x02120110, 0x00110021, 0x00110120, 0x00110122,\n    0x00120121, 0x01100020, 0x01100122, 0x01100221, 0x01110022, 0x01110121, 0x01110220, 0x01110222,\n    0x01120120, 0x01120122, 0x02100121, 0x02110021, 0x02110120, 0x02110122, 0x02120121, 0x00101001,\n    0x00101102, 0x00101201, 0x00111100, 0x00111101, 0x00111200, 0x00111201, 0x00121001, 0x00121102,\n    0x01101001, 0x01101101, 0x01101102, 0x01101200, 0x01101202, 0x01111001, 0x01111100, 0x01111101,\n    0x01111102, 0x01111201, 0x01121002, 0x01121101, 0x01121200, 0x02101100, 0x02101201, 0x02111000,\n    0x02111100, 0x02111101, 0x02111200, 0x02111201, 0x02111202, 0x02121001, 0x02121100, 0x02121101,\n    0x02121201, 0x00101012, 0x00101111, 0x00101212, 0x00111011, 0x00111110, 0x00111111, 0x00111112,\n    0x00111211, 0x00121010, 0x00121012, 0x00121111, 0x00121210, 0x00121212, 0x01101011, 0x01101110,\n    0x01101111, 0x01101112, 0x01111011, 0x01111012, 0x01111110, 0x01111111, 0x01111112, 0x01111211,\n    0x01111212, 0x01121011, 0x01121110, 0x01121111, 0x01121112, 0x01121211, 0x02101010, 0x02101012,\n    0x02101110, 0x02101111, 0x02101210, 0x02101212, 0x02111010, 0x02111011, 0x02111110, 0x02111111,\n    0x02111112, 0x02111211, 0x02111212, 0x02121010, 0x02121012, 0x02121111, 0x00101021, 0x00101120,\n    0x00101121, 0x00101122, 0x00111121, 0x00111122, 0x00111220, 0x00111222, 0x00121021, 0x00121122,\n    0x01101020, 0x01101022, 0x01101120, 0x01101121, 0x01101220, 0x01101222, 0x01111021, 0x01111121,\n    0x01111122, 0x01111220, 0x01111221, 0x01121021, 0x01121120, 0x01121121, 0x01121220, 0x01121221,\n    0x01121222, 0x02101122, 0x02101222, 0x02111022, 0x02111121, 0x02121120, 0x02121221, 0x00112001,\n    0x00112102, 0x00122101, 0x01102001, 0x01102100, 0x01102102, 0x01102201, 0x01112000, 0x01112101,\n    0x01112200, 
0x01112202, 0x01122000, 0x01122001, 0x01122100, 0x01122102, 0x01122201, 0x02102101,\n    0x02112001, 0x02112100, 0x02122101, 0x00112010, 0x00112012, 0x00112111, 0x00112212, 0x00122011,\n    0x00122111, 0x01102012, 0x01102110, 0x01102111, 0x01102210, 0x01112011, 0x01112110, 0x01112111,\n    0x01112112, 0x01112211, 0x01112212, 0x01122010, 0x01122111, 0x01122212, 0x02102211, 0x02112011,\n    0x02112012, 0x02112111, 0x02112210, 0x02122011, 0x02122112, 0x02122211, 0x00102221, 0x00112122,\n    0x00122120, 0x00122122, 0x01102120, 0x01102122, 0x01102221, 0x01112020, 0x01112022, 0x01112121,\n    0x01112220, 0x01122021, 0x01122122, 0x01122221, 0x02102121, 0x02112021, 0x02112122, 0x02112222,\n    0x00200000, 0x00200002, 0x00200200, 0x00200202, 0x00210101, 0x00220000, 0x00220002, 0x00220101,\n    0x00220200, 0x00220202, 0x01200101, 0x01210001, 0x01210201, 0x01220001, 0x01220101, 0x02200000,\n    0x02200002, 0x02200200, 0x02200202, 0x02210101, 0x02220000, 0x02220002, 0x02220101, 0x02220200,\n    0x02220202, 0x00200111, 0x00210011, 0x00210110, 0x00210211, 0x00220111, 0x01200012, 0x01200110,\n    0x01200211, 0x01210111, 0x01210210, 0x01210212, 0x01220011, 0x01220110, 0x01220111, 0x01220112,\n    0x02200111, 0x02210010, 0x02210112, 0x02210211, 0x02220111, 0x00200021, 0x00200220, 0x00200222,\n    0x00210021, 0x00210121, 0x00220020, 0x00220022, 0x00220220, 0x00220222, 0x01200121, 0x01210021,\n    0x01210122, 0x01210221, 0x01220121, 0x02200021, 0x02200220, 0x02200222, 0x02210021, 0x02210121,\n    0x02220020, 0x02220022, 0x02220220, 0x02220222, 0x00201101, 0x00211100, 0x00211102, 0x00211201,\n    0x00221101, 0x01201100, 0x01201101, 0x01201102, 0x01201201, 0x01211002, 0x01211101, 0x01211200,\n    0x01211202, 0x01221102, 0x02201101, 0x02211001, 0x02211100, 0x02211201, 0x02221001, 0x02221101,\n    0x00201211, 0x00211111, 0x00221011, 0x00221211, 0x01201010, 0x01201111, 0x01201210, 0x01211011,\n    0x01211110, 0x01211111, 0x01211211, 0x01221012, 0x01221111, 0x01221210, 0x02201211, 
0x02211010,\n    0x02211110, 0x02211111, 0x02211210, 0x02211212, 0x02221011, 0x02221110, 0x02221112, 0x02221211,\n    0x00201121, 0x00211020, 0x00211022, 0x00211221, 0x00221121, 0x01201021, 0x01201221, 0x01211121,\n    0x01221020, 0x01221021, 0x01221221, 0x02201120, 0x02201122, 0x02211020, 0x02211222, 0x00202000,\n    0x00202002, 0x00202200, 0x00202202, 0x00212101, 0x00222000, 0x00222002, 0x00222200, 0x00222202,\n    0x01202101, 0x01212001, 0x01212100, 0x01222101, 0x02202000, 0x02202002, 0x02202200, 0x02202202,\n    0x02222000, 0x02222002, 0x02222200, 0x02222202, 0x00202211, 0x00212011, 0x00212110, 0x00212211,\n    0x00222111, 0x01202112, 0x01202211, 0x01212012, 0x01212111, 0x01222011, 0x01222110, 0x01222112,\n    0x01222211, 0x02202111, 0x02212010, 0x02212112, 0x02212211, 0x02222110, 0x02222111, 0x00202020,\n    0x00202022, 0x00202220, 0x00202222, 0x00222020, 0x00222022, 0x00222220, 0x00222222, 0x01202121,\n    0x01212021, 0x01212122, 0x01212221, 0x01222121, 0x02202020, 0x02202022, 0x02202220, 0x02202222,\n    0x02212121, 0x02222020, 0x02222022, 0x02222220, 0x02222222, 0x10000101, 0x10010001, 0x10010102,\n    0x10020101, 0x11000201, 0x11010002, 0x11010101, 0x11010200, 0x11010202, 0x11020001, 0x11020100,\n    0x11020102, 0x12010100, 0x12010201, 0x12020001, 0x12020102, 0x10000010, 0x10000011, 0x10000110,\n    0x10000112, 0x10000211, 0x10010012, 0x10010111, 0x10010112, 0x10010210, 0x10010212, 0x10020011,\n    0x10020112, 0x10020211, 0x11000111, 0x11000210, 0x11000212, 0x11010011, 0x11010110, 0x11010111,\n    0x11010112, 0x11010211, 0x11010212, 0x11020111, 0x11020210, 0x11020212, 0x12000011, 0x12000110,\n    0x12000112, 0x12010010, 0x12010012, 0x12010111, 0x12020010, 0x12020011, 0x12020012, 0x10000121,\n    0x10010021, 0x10010120, 0x10010122, 0x10020121, 0x11000021, 0x11010022, 0x11010121, 0x11010222,\n    0x11020120, 0x11020221, 0x12000221, 0x12010120, 0x12020121, 0x10001001, 0x10011101, 0x10011201,\n    0x10021201, 0x11001101, 0x11001200, 0x11001202, 0x11011001, 
0x11011100, 0x11011101, 0x11011102,\n    0x11021001, 0x11021002, 0x11021101, 0x11021200, 0x11021202, 0x12001001, 0x12001102, 0x12001201,\n    0x12011000, 0x12011002, 0x12011101, 0x12021000, 0x12021001, 0x12021201, 0x10001011, 0x10001012,\n    0x10001111, 0x10001212, 0x10011011, 0x10011110, 0x10011111, 0x10011112, 0x10011211, 0x10021010,\n    0x10021111, 0x10021212, 0x11001011, 0x11001110, 0x11001111, 0x11001112, 0x11001211, 0x11011010,\n    0x11011011, 0x11011110, 0x11011111, 0x11011112, 0x11011210, 0x11011211, 0x11021011, 0x11021110,\n    0x11021111, 0x11021112, 0x11021211, 0x12001012, 0x12001110, 0x12001111, 0x12001210, 0x12011011,\n    0x12011110, 0x12011111, 0x12011112, 0x12011211, 0x12011212, 0x12021111, 0x12021210, 0x12021212,\n    0x10001021, 0x10001121, 0x10001221, 0x10011120, 0x10011121, 0x10011220, 0x10011222, 0x10021021,\n    0x10021120, 0x10021221, 0x11001020, 0x11001022, 0x11001121, 0x11001220, 0x11011020, 0x11011021,\n    0x11011022, 0x11011121, 0x11011122, 0x11011221, 0x11021022, 0x11021121, 0x11021220, 0x12001021,\n    0x12001121, 0x12001222, 0x12011120, 0x12011121, 0x12021021, 0x12021120, 0x12021122, 0x10002101,\n    0x10012001, 0x10012101, 0x10012202, 0x10022101, 0x11002002, 0x11002201, 0x11012000, 0x11012101,\n    0x11012200, 0x11022001, 0x11022100, 0x11022102, 0x11022201, 0x12002101, 0x12012001, 0x12012100,\n    0x12012102, 0x12012201, 0x12022101, 0x10002011, 0x10002111, 0x10002112, 0x10002212, 0x10012010,\n    0x10012110, 0x10012111, 0x10012210, 0x10022011, 0x10022110, 0x10022112, 0x11002010, 0x11002111,\n    0x11002212, 0x11012011, 0x11012012, 0x11012110, 0x11012111, 0x11012112, 0x11012211, 0x11022010,\n    0x11022012, 0x11022111, 0x11022112, 0x11022212, 0x12002112, 0x12002211, 0x12012012, 0x12012111,\n    0x12012112, 0x12012210, 0x12022011, 0x12022110, 0x12022112, 0x12022211, 0x10012122, 0x11002120,\n    0x11002122, 0x11002221, 0x11012121, 0x11012220, 0x11012222, 0x11022120, 0x11022221, 0x12012120,\n    0x12022121, 0x10100001, 0x10100100, 
0x10100101, 0x10100102, 0x10100201, 0x10110002, 0x10110101,\n    0x10110202, 0x10120001, 0x10120100, 0x10120201, 0x11100000, 0x11100101, 0x11100200, 0x11110001,\n    0x11110100, 0x11110101, 0x11110102, 0x11110201, 0x11120101, 0x11120200, 0x12100102, 0x12100201,\n    0x12110101, 0x12110200, 0x12120000, 0x12120001, 0x12120102, 0x12120201, 0x10100111, 0x10100210,\n    0x10100211, 0x10100212, 0x10110011, 0x10110110, 0x10110111, 0x10110112, 0x10110210, 0x10110211,\n    0x10120010, 0x10120111, 0x10120112, 0x10120210, 0x10120212, 0x11100011, 0x11100110, 0x11100111,\n    0x11100112, 0x11100211, 0x11110010, 0x11110011, 0x11110012, 0x11110110, 0x11110111, 0x11110112,\n    0x11110210, 0x11110211, 0x11110212, 0x11120011, 0x11120110, 0x11120111, 0x11120112, 0x11120211,\n    0x12100012, 0x12100111, 0x12110011, 0x12110110, 0x12110111, 0x12110112, 0x12110211, 0x12120010,\n    0x12120111, 0x12120212, 0x10100021, 0x10100122, 0x10110022, 0x10110121, 0x10110222, 0x10120021,\n    0x10120120, 0x11100022, 0x11100121, 0x11100222, 0x11110021, 0x11110120, 0x11110121, 0x11110122,\n    0x11110221, 0x11120022, 0x11120121, 0x12100121, 0x12110020, 0x12110022, 0x12110121, 0x12110221,\n    0x12110222, 0x12120120, 0x10101100, 0x10101101, 0x10111001, 0x10111100, 0x10111101, 0x10111102,\n    0x10111200, 0x10111201, 0x10121001, 0x10121101, 0x10121200, 0x10121202, 0x11101001, 0x11101100,\n    0x11101101, 0x11101102, 0x11101201, 0x11101202, 0x11111000, 0x11111001, 0x11111100, 0x11111101,\n    0x11111102, 0x11111200, 0x11111201, 0x11111202, 0x11121001, 0x11121002, 0x11121100, 0x11121101,\n    0x11121102, 0x11121201, 0x12101000, 0x12101200, 0x12101202, 0x12111001, 0x12111100, 0x12111101,\n    0x12111102, 0x12111201, 0x12121001, 0x12121100, 0x12121101, 0x12121202, 0x10101011, 0x10101012,\n    0x10101110, 0x10101111, 0x10101112, 0x10101211, 0x10111010, 0x10111011, 0x10111012, 0x10111110,\n    0x10111111, 0x10111112, 0x10111211, 0x10111212, 0x10121011, 0x10121110, 0x10121111, 0x10121112,\n    0x10121211, 
0x11101010, 0x11101011, 0x11101012, 0x11101110, 0x11101111, 0x11101112, 0x11101210,\n    0x11101211, 0x11111010, 0x11111011, 0x11111012, 0x11111110, 0x11111111, 0x11111112, 0x11111210,\n    0x11111211, 0x11111212, 0x11121010, 0x11121011, 0x11121110, 0x11121111, 0x11121112, 0x11121210,\n    0x11121211, 0x11121212, 0x12101011, 0x12101110, 0x12101111, 0x12101211, 0x12101212, 0x12111010,\n    0x12111011, 0x12111110, 0x12111111, 0x12111112, 0x12111210, 0x12111211, 0x12121011, 0x12121110,\n    0x12121111, 0x12121112, 0x12121211, 0x10101020, 0x10101021, 0x10101022, 0x10101120, 0x10101122,\n    0x10101220, 0x10101221, 0x10111021, 0x10111120, 0x10111121, 0x10111220, 0x10111221, 0x10121020,\n    0x10121021, 0x10121022, 0x10121120, 0x10121121, 0x10121122, 0x10121220, 0x10121221, 0x11101021,\n    0x11101121, 0x11101122, 0x11101220, 0x11101221, 0x11101222, 0x11111020, 0x11111021, 0x11111022,\n    0x11111120, 0x11111121, 0x11111122, 0x11111220, 0x11111221, 0x11111222, 0x11121021, 0x11121120,\n    0x11121121, 0x11121221, 0x12101022, 0x12101121, 0x12101122, 0x12101220, 0x12101221, 0x12101222,\n    0x12111021, 0x12111121, 0x12111222, 0x12121022, 0x12121121, 0x12121122, 0x12121220, 0x12121221,\n    0x10102100, 0x10102101, 0x10102102, 0x10102201, 0x10112000, 0x10112101, 0x10112200, 0x10122001,\n    0x10122202, 0x11102101, 0x11102200, 0x11102202, 0x11112001, 0x11112100, 0x11112101, 0x11112102,\n    0x11112200, 0x11112201, 0x11122000, 0x11122002, 0x11122100, 0x11122101, 0x12102002, 0x12102201,\n    0x12112000, 0x12112002, 0x12112101, 0x12112200, 0x12122001, 0x12122201, 0x10102011, 0x10102012,\n    0x10102111, 0x10102212, 0x10112011, 0x10112110, 0x10112111, 0x10112112, 0x10112211, 0x10122111,\n    0x11102011, 0x11102110, 0x11102111, 0x11102112, 0x11102211, 0x11112010, 0x11112011, 0x11112012,\n    0x11112110, 0x11112111, 0x11112112, 0x11112210, 0x11112211, 0x11112212, 0x11122011, 0x11122110,\n    0x11122111, 0x11122112, 0x11122211, 0x12102011, 0x12102111, 0x12102211, 0x12112011, 
0x12112110,\n    0x12112111, 0x12112112, 0x12112210, 0x12112211, 0x12122111, 0x10102120, 0x10102220, 0x10112121,\n    0x10112222, 0x10122020, 0x10122121, 0x10122122, 0x10122221, 0x11102121, 0x11102220, 0x11102221,\n    0x11112021, 0x11112121, 0x11112122, 0x11112220, 0x11112221, 0x11122022, 0x11122121, 0x11122220,\n    0x11122222, 0x12102021, 0x12102222, 0x12112022, 0x12112121, 0x12112122, 0x12112220, 0x12112222,\n    0x12122021, 0x10200101, 0x10210100, 0x10210102, 0x10210201, 0x10220101, 0x11200100, 0x11210000,\n    0x11210101, 0x11210102, 0x11210200, 0x11210202, 0x11220001, 0x11220100, 0x11220102, 0x11220201,\n    0x12200001, 0x12210102, 0x12220101, 0x10200011, 0x10200110, 0x10200112, 0x10200211, 0x10210012,\n    0x10210111, 0x10220011, 0x10220012, 0x10220112, 0x10220211, 0x11200111, 0x11200211, 0x11210011,\n    0x11210111, 0x11210112, 0x11210211, 0x11220111, 0x11220112, 0x11220212, 0x12200110, 0x12200212,\n    0x12210012, 0x12210111, 0x12220011, 0x12220112, 0x12220211, 0x10210021, 0x10210122, 0x10210221,\n    0x11200020, 0x11200021, 0x11200122, 0x11210121, 0x11210122, 0x11210220, 0x11220020, 0x12200121,\n    0x12210021, 0x12210122, 0x12220121, 0x10211001, 0x10211002, 0x10211101, 0x10211102, 0x10211202,\n    0x10221001, 0x10221102, 0x10221201, 0x11201000, 0x11201002, 0x11201101, 0x11201200, 0x11201202,\n    0x11211001, 0x11211100, 0x11211101, 0x11211102, 0x11211201, 0x11211202, 0x11221000, 0x11221002,\n    0x11221101, 0x12201100, 0x12201101, 0x12201201, 0x12211000, 0x12211002, 0x12211100, 0x12211101,\n    0x12211102, 0x12211200, 0x12211202, 0x12221001, 0x12221100, 0x12221201, 0x10201111, 0x10201210,\n    0x10201212, 0x10211011, 0x10211111, 0x10211112, 0x10211211, 0x11201110, 0x11201111, 0x11201112,\n    0x11201211, 0x11211010, 0x11211011, 0x11211110, 0x11211111, 0x11211112, 0x11211211, 0x11221011,\n    0x11221110, 0x11221111, 0x11221112, 0x11221211, 0x12201112, 0x12201211, 0x12201212, 0x12211011,\n    0x12211111, 0x12211112, 0x12211211, 0x12211212, 0x12221012, 
0x12221111, 0x12221112, 0x12221210,\n    0x10201022, 0x10201221, 0x10211121, 0x10221020, 0x10221122, 0x10221220, 0x10221221, 0x11201020,\n    0x11201121, 0x11201220, 0x11201222, 0x11211021, 0x11211120, 0x11211121, 0x11211122, 0x11211220,\n    0x11211222, 0x11221020, 0x11221121, 0x11221220, 0x12201020, 0x12201022, 0x12201121, 0x12201222,\n    0x12211120, 0x12211122, 0x12211220, 0x12211221, 0x12221020, 0x12221120, 0x12221122, 0x12221222,\n    0x10212102, 0x10212201, 0x10222101, 0x11202001, 0x11212002, 0x11212101, 0x11212202, 0x11222001,\n    0x11222201, 0x12202101, 0x12212001, 0x12212200, 0x12222102, 0x10202011, 0x10202110, 0x10212010,\n    0x10212111, 0x10222011, 0x10222110, 0x10222112, 0x10222211, 0x11202010, 0x11202011, 0x11202111,\n    0x11202112, 0x11202210, 0x11212011, 0x11212110, 0x11212111, 0x11212112, 0x11212211, 0x11222010,\n    0x11222111, 0x11222212, 0x12202012, 0x12202110, 0x12202212, 0x12212111, 0x12222011, 0x12222110,\n    0x12222111, 0x12222211, 0x10212021, 0x10212122, 0x10212220, 0x11202021, 0x11202120, 0x11202221,\n    0x11212020, 0x11212121, 0x11212220, 0x11212222, 0x11222120, 0x11222121, 0x11222221, 0x12202122,\n    0x12212120, 0x12212220, 0x12212222, 0x12222122, 0x20000000, 0x20000002, 0x20000200, 0x20000202,\n    0x20020000, 0x20020002, 0x20020200, 0x20020202, 0x21000101, 0x21010000, 0x21010001, 0x21010100,\n    0x21010102, 0x21010201, 0x21020101, 0x22000000, 0x22000002, 0x22000200, 0x22000202, 0x22010101,\n    0x22020000, 0x22020002, 0x22020200, 0x22020202, 0x20000111, 0x20010011, 0x20010110, 0x20010112,\n    0x20010211, 0x20020111, 0x21000011, 0x21000110, 0x21000211, 0x21010010, 0x21010012, 0x21010111,\n    0x21010112, 0x21010210, 0x21010211, 0x21020110, 0x21020112, 0x21020211, 0x22000111, 0x22000211,\n    0x22010110, 0x22010112, 0x22010211, 0x22020111, 0x20000020, 0x20000022, 0x20000220, 0x20000222,\n    0x20010121, 0x20020020, 0x20020022, 0x20020220, 0x20020222, 0x21010021, 0x21010120, 0x21010221,\n    0x21020121, 0x22000020, 0x22000022, 
0x22000220, 0x22000222, 0x22010121, 0x22020020, 0x22020022,\n    0x22020220, 0x22020222, 0x20011100, 0x20011201, 0x21001001, 0x21001100, 0x21011001, 0x21011101,\n    0x21011202, 0x21021001, 0x21021100, 0x21021201, 0x22011100, 0x22011201, 0x20001011, 0x20001211,\n    0x20011012, 0x20011111, 0x20011212, 0x20021112, 0x20021211, 0x21001010, 0x21001011, 0x21001111,\n    0x21001210, 0x21011011, 0x21011110, 0x21011111, 0x21011112, 0x21011211, 0x21011212, 0x21021111,\n    0x21021112, 0x21021210, 0x21021212, 0x22001011, 0x22001110, 0x22001112, 0x22001211, 0x22011010,\n    0x22011012, 0x22011111, 0x22011210, 0x22021112, 0x20011021, 0x20011122, 0x20011221, 0x20021121,\n    0x21001021, 0x21001120, 0x21001221, 0x21001222, 0x21011020, 0x21011121, 0x21011221, 0x21011222,\n    0x21021021, 0x21021122, 0x21021222, 0x22001121, 0x22011021, 0x22011222, 0x22021120, 0x20002000,\n    0x20002002, 0x20002200, 0x20002202, 0x20012101, 0x20022000, 0x20022002, 0x20022200, 0x20022202,\n    0x21002001, 0x21002101, 0x21012001, 0x21012100, 0x21012201, 0x21022101, 0x21022201, 0x22002000,\n    0x22002002, 0x22002200, 0x22002202, 0x22012101, 0x22022000, 0x22022002, 0x22022200, 0x22022202,\n    0x20002111, 0x20002112, 0x20012011, 0x20012110, 0x20012112, 0x20022111, 0x21002011, 0x21002110,\n    0x21002112, 0x21002211, 0x21012010, 0x21012012, 0x21012111, 0x21012212, 0x21022011, 0x21022110,\n    0x22002111, 0x22012112, 0x22012211, 0x22022111, 0x20002020, 0x20002022, 0x20002220, 0x20002222,\n    0x20012121, 0x20022020, 0x20022022, 0x20022220, 0x20022222, 0x21002121, 0x21012021, 0x21012120,\n    0x21012122, 0x22002020, 0x22002022, 0x22002220, 0x22002222, 0x22012121, 0x22022020, 0x22022022,\n    0x22022220, 0x22022222, 0x20100101, 0x20110001, 0x20110102, 0x20110200, 0x20110201, 0x20120101,\n    0x21100001, 0x21100102, 0x21100201, 0x21110101, 0x21110200, 0x21110202, 0x21120201, 0x21120202,\n    0x22100101, 0x22110001, 0x22110100, 0x22110102, 0x22110201, 0x22120101, 0x20100011, 0x20100110,\n    0x20100112, 
0x20100211, 0x20110010, 0x20110111, 0x20110210, 0x20110212, 0x20120011, 0x20120110,\n    0x20120112, 0x20120211, 0x21100010, 0x21100111, 0x21110010, 0x21110011, 0x21110110, 0x21110111,\n    0x21110112, 0x21110211, 0x21120012, 0x21120111, 0x22100110, 0x22100112, 0x22110012, 0x22110111,\n    0x22110210, 0x22120011, 0x22120110, 0x22120112, 0x22120211, 0x20100121, 0x20110021, 0x20110120,\n    0x20110221, 0x20120121, 0x21100120, 0x21100122, 0x21100221, 0x21110020, 0x21110022, 0x21110121,\n    0x21110220, 0x21120122, 0x21120221, 0x22100121, 0x22110120, 0x22110122, 0x22120221, 0x20101001,\n    0x20101100, 0x20101102, 0x20111000, 0x20111101, 0x20111200, 0x20121102, 0x21101000, 0x21101202,\n    0x21111001, 0x21111100, 0x21111101, 0x21111102, 0x21111200, 0x21111201, 0x21121000, 0x21121001,\n    0x21121002, 0x21121101, 0x22101100, 0x22101102, 0x22111002, 0x22111100, 0x22111101, 0x22111200,\n    0x22121001, 0x22121201, 0x20101010, 0x20101111, 0x20101210, 0x20101212, 0x20111010, 0x20111011,\n    0x20111110, 0x20111111, 0x20111112, 0x20111211, 0x20121011, 0x20121111, 0x20121211, 0x20121212,\n    0x21101011, 0x21101110, 0x21101111, 0x21101112, 0x21101211, 0x21111010, 0x21111011, 0x21111012,\n    0x21111110, 0x21111111, 0x21111112, 0x21111210, 0x21111211, 0x21111212, 0x21121011, 0x21121110,\n    0x21121111, 0x21121112, 0x21121211, 0x22101011, 0x22101111, 0x22101210, 0x22111011, 0x22111012,\n    0x22111110, 0x22111111, 0x22111112, 0x22111211, 0x22111212, 0x22121010, 0x22121012, 0x22121111,\n    0x22121210, 0x22121212, 0x20101021, 0x20101120, 0x20111020, 0x20111121, 0x20111221, 0x20121020,\n    0x20121122, 0x20121221, 0x21101121, 0x21101220, 0x21101221, 0x21111021, 0x21111022, 0x21111121,\n    0x21111122, 0x21111221, 0x21121121, 0x21121220, 0x22101022, 0x22101120, 0x22101221, 0x22101222,\n    0x22111022, 0x22111120, 0x22111121, 0x22121120, 0x22121122, 0x22121221, 0x20102101, 0x20112102,\n    0x20112201, 0x20122101, 0x21102001, 0x21102102, 0x21112000, 0x21112002, 0x21112101, 
0x21112102,\n    0x21112202, 0x21122100, 0x21122101, 0x22102101, 0x22112001, 0x22112102, 0x22112201, 0x22122101,\n    0x20102110, 0x20102112, 0x20102211, 0x20112010, 0x20112012, 0x20112111, 0x20112210, 0x20112212,\n    0x20122010, 0x20122011, 0x20122110, 0x20122112, 0x21102010, 0x21102012, 0x21102111, 0x21102210,\n    0x21102212, 0x21112011, 0x21112110, 0x21112111, 0x21112112, 0x21112211, 0x21122012, 0x21122111,\n    0x21122112, 0x21122212, 0x22102011, 0x22102110, 0x22112010, 0x22112012, 0x22112111, 0x22112212,\n    0x22122011, 0x22122112, 0x20102121, 0x20112121, 0x20122121, 0x21102120, 0x21102122, 0x21102221,\n    0x21112020, 0x21112121, 0x21112220, 0x21122021, 0x22102121, 0x22112021, 0x22112120, 0x22112121,\n    0x22112122, 0x20200000, 0x20200002, 0x20200200, 0x20200202, 0x20210101, 0x20220000, 0x20220002,\n    0x20220200, 0x20220202, 0x21200101, 0x21210001, 0x21210100, 0x21210102, 0x21210201, 0x22200000,\n    0x22200002, 0x22200200, 0x22200202, 0x22210101, 0x22220000, 0x22220002, 0x22220200, 0x22220202,\n    0x20200111, 0x20200211, 0x20210011, 0x20210110, 0x20210112, 0x20210211, 0x20210212, 0x21200112,\n    0x21200211, 0x21210011, 0x21210111, 0x21210210, 0x21210212, 0x21220011, 0x21220110, 0x22200111,\n    0x22210010, 0x22210012, 0x22210112, 0x22210211, 0x20200022, 0x20200220, 0x20200222, 0x20210020,\n    0x20210221, 0x20220022, 0x20220220, 0x20220222, 0x21200121, 0x21210021, 0x21210122, 0x21210221,\n    0x21220121, 0x22200020, 0x22200022, 0x22200220, 0x22200222, 0x22210121, 0x22220020, 0x22220022,\n    0x22220220, 0x22220222, 0x20211201, 0x20221101, 0x21201001, 0x21201100, 0x21211000, 0x21211100,\n    0x21211101, 0x21211200, 0x21211202, 0x21221001, 0x21221101, 0x21221102, 0x21221200, 0x21221201,\n    0x22201101, 0x20201112, 0x20201211, 0x20211010, 0x20211012, 0x20211111, 0x20211210, 0x20221112,\n    0x20221211, 0x21201012, 0x21201111, 0x21211011, 0x21211110, 0x21211111, 0x21211112, 0x21211211,\n    0x21221111, 0x21221212, 0x22201011, 0x22201110, 0x22201111, 
0x22201112, 0x22201211, 0x22211012,\n    0x22211111, 0x22211210, 0x20201121, 0x20211021, 0x20211122, 0x20211222, 0x20221021, 0x20221121,\n    0x21201120, 0x21201122, 0x21201222, 0x21211022, 0x21211121, 0x21211122, 0x21211220, 0x21221020,\n    0x21221022, 0x22201122, 0x22211020, 0x22211121, 0x22211122, 0x22211221, 0x22221021, 0x22221120,\n    0x22221122, 0x20202000, 0x20202002, 0x20202200, 0x20202202, 0x20222000, 0x20222002, 0x20222200,\n    0x20222202, 0x21212001, 0x21212100, 0x21212102, 0x21212201, 0x22202000, 0x22202002, 0x22202200,\n    0x22202202, 0x22212101, 0x22222000, 0x22222002, 0x22222200, 0x22222202, 0x20202111, 0x20212110,\n    0x20212211, 0x20222011, 0x20222111, 0x21202011, 0x21212010, 0x21212111, 0x21212212, 0x21222011,\n    0x21222112, 0x21222211, 0x22212010, 0x22212112, 0x20202020, 0x20202022, 0x20202220, 0x20202222,\n    0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020,\n    0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222,\n};\n#endif\n// end copy https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L570\n\n#ifndef HAVE_FANCY_SIMD\nconst uint64_t keven_signs[128] = {\n    0x0101010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0x010101010101ffff,\n    0xff01010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0xff01010101ffffff,\n    0xff010101ff010101, 0x01010101ff0101ff, 0x01010101ff01ff01, 0xff010101ff01ffff,\n    0x01010101ffff0101, 0xff010101ffff01ff, 0xff010101ffffff01, 0x01010101ffffffff,\n    0xff0101ff01010101, 0x010101ff010101ff, 0x010101ff0101ff01, 0xff0101ff0101ffff,\n    0x010101ff01ff0101, 0xff0101ff01ff01ff, 0xff0101ff01ffff01, 0x010101ff01ffffff,\n    0x010101ffff010101, 0xff0101ffff0101ff, 0xff0101ffff01ff01, 0x010101ffff01ffff,\n    0xff0101ffffff0101, 0x010101ffffff01ff, 0x010101ffffffff01, 0xff0101ffffffffff,\n    0xff01ff0101010101, 0x0101ff01010101ff, 
0x0101ff010101ff01, 0xff01ff010101ffff,\n    0x0101ff0101ff0101, 0xff01ff0101ff01ff, 0xff01ff0101ffff01, 0x0101ff0101ffffff,\n    0x0101ff01ff010101, 0xff01ff01ff0101ff, 0xff01ff01ff01ff01, 0x0101ff01ff01ffff,\n    0xff01ff01ffff0101, 0x0101ff01ffff01ff, 0x0101ff01ffffff01, 0xff01ff01ffffffff,\n    0x0101ffff01010101, 0xff01ffff010101ff, 0xff01ffff0101ff01, 0x0101ffff0101ffff,\n    0xff01ffff01ff0101, 0x0101ffff01ff01ff, 0x0101ffff01ffff01, 0xff01ffff01ffffff,\n    0xff01ffffff010101, 0x0101ffffff0101ff, 0x0101ffffff01ff01, 0xff01ffffff01ffff,\n    0x0101ffffffff0101, 0xff01ffffffff01ff, 0xff01ffffffffff01, 0x0101ffffffffffff,\n    0xffff010101010101, 0x01ff0101010101ff, 0x01ff01010101ff01, 0xffff01010101ffff,\n    0x01ff010101ff0101, 0xffff010101ff01ff, 0xffff010101ffff01, 0x01ff010101ffffff,\n    0x01ff0101ff010101, 0xffff0101ff0101ff, 0xffff0101ff01ff01, 0x01ff0101ff01ffff,\n    0xffff0101ffff0101, 0x01ff0101ffff01ff, 0x01ff0101ffffff01, 0xffff0101ffffffff,\n    0x01ff01ff01010101, 0xffff01ff010101ff, 0xffff01ff0101ff01, 0x01ff01ff0101ffff,\n    0xffff01ff01ff0101, 0x01ff01ff01ff01ff, 0x01ff01ff01ffff01, 0xffff01ff01ffffff,\n    0xffff01ffff010101, 0x01ff01ffff0101ff, 0x01ff01ffff01ff01, 0xffff01ffff01ffff,\n    0x01ff01ffffff0101, 0xffff01ffffff01ff, 0xffff01ffffffff01, 0x01ff01ffffffffff,\n    0x01ffff0101010101, 0xffffff01010101ff, 0xffffff010101ff01, 0x01ffff010101ffff,\n    0xffffff0101ff0101, 0x01ffff0101ff01ff, 0x01ffff0101ffff01, 0xffffff0101ffffff,\n    0xffffff01ff010101, 0x01ffff01ff0101ff, 0x01ffff01ff01ff01, 0xffffff01ff01ffff,\n    0x01ffff01ffff0101, 0xffffff01ffff01ff, 0xffffff01ffffff01, 0x01ffff01ffffffff,\n    0xffffffff01010101, 0x01ffffff010101ff, 0x01ffffff0101ff01, 0xffffffff0101ffff,\n    0x01ffffff01ff0101, 0xffffffff01ff01ff, 0xffffffff01ffff01, 0x01ffffff01ffffff,\n    0x01ffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0x01ffffffff01ffff,\n    0xffffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 
0xffffffffffffffff,\n};\n#endif\n\n}\n\n/* moonll change mulmat\nadd typeB and strideB\n}*/\n\n// Adapted from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L406\n// MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow\nbool iqk_mul_mat(long Nx, long Ny, long ne00,\n    int typeA, const void * A, long strideA,\n    int typeB, const void * B, long strideB,\n    float * C, long stride_C, int ith, int nth) {\n\n        MulMat mm;\n    \n        if (!MulMat::set_mul_mat(typeA, typeB, ne00, mm, Ny)) {\n            return false;\n        }\n\n        size_t row_size_qx = strideA*ggml_type_size(ggml_type(typeA));\n        size_t row_size_qy = strideB*ggml_type_size(ggml_type(typeB));\n      \n        \n        auto nrc_x = (Nx + nth - 1)/nth;\n        auto first_x = ith*nrc_x;\n        if (first_x + nrc_x > Nx) nrc_x = Nx - first_x;\n\n        DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, row_size_qy, 0, 1, nullptr, 0};\n\n        mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny);\n\n        return true;\n}\n// end adapted from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L406\n\n\nbool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int typeA, const void * A, const void * B,\n        float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) {\n    const mmid_row_mapping * row_mapping = (const mmid_row_mapping *)vrow_mapping;\n    assert(row_mapping != nullptr);\n\n    MulMat mm;\n    int row_size_q8;\n    /* moonll\n\n    if (!MulMat::set_mul_mat(typeA, ne00, mm, row_size_q8, Ny)) {\n        return false;\n    }*/\n    int row_size_qx = ggml_row_size((ggml_type)typeA, ne00);\n    int nrc_x = (Nx + nth - 1)/nth;\n    int first_x = ith*nrc_x;\n    if (first_x + nrc_x > Nx) nrc_x = Nx - first_x;\n    DataInfo info{C + first_x, (const char *)B, 
nb1/sizeof(float), (size_t)row_size_q8, 0, ne11, row_mapping, nb2/sizeof(float)};\n    mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny);\n    return true;\n}\n\n#if defined __x86_64__ || defined(_M_X64)\n\n#if defined HAVE_FANCY_SIMD\n    #undef HAVE_FANCY_SIMD\n#endif\n#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__)\n    #define HAVE_FANCY_SIMD\n#endif\n//#define HAVE_FANCY_SIMD\n\nnamespace {\n\ninline float hsum_float_4(__m128 x) {\n    x = _mm_add_ps(x, _mm_movehl_ps(x, x));\n    x = _mm_add_ss(x, _mm_movehdup_ps(x));\n    return _mm_cvtss_f32(x);\n}\ninline float hsum_float_8(__m256 x) {\n    return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1)));\n}\n\n#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)\n\n\ntemplate <int nrc, typename block_q8 = block_q8_K> struct Q8 {\n\n    constexpr static int nrc_y = nrc;\n\n    Q8(const DataInfo& info) {\n        for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy);\n    }\n\n#ifdef HAVE_FANCY_SIMD\n    inline __m512i load_quants64(int iy, int i, int j) const { return _mm512_loadu_si512((const __m512i*)y[iy][i].qs + j); }\n#endif\n    inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].qs + j); }\n    inline __m256i load_bsums(int iy, int i) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].bsums); }\n    inline float scale(int iy, int i) const { return y[iy][i].d; }\n\n    const block_q8 * y[nrc_y];\n};\n\n// Handles q4_K and q5_K scales/mins\nstruct Scales8K {\n    template <typename Q8>\n    inline __m256i process_mins_and_scales(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) {\n        make_q4_scales(data, utmp);\n        const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], 
utmp[0]));\n        const __m128i mins128 = _mm256_extracti128_si256(mins_and_scales, 1);\n        accum_mins(mins128, q8, i, c, accd);\n        const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);\n        return MM256_SET_M128I(sc128, sc128);\n    }\n#ifdef HAVE_FANCY_SIMD\n    template <typename Q8>\n    inline __m512i process_mins_and_scales_64(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) {\n        auto scales = process_mins_and_scales(data, c, i, q8, accd);\n        return _mm512_inserti32x8(_mm512_castsi256_si512(scales), scales, 1);\n    }\n#endif\n    template <typename Q8>\n    inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const {\n        const __m256i mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, shuffles[1]), _mm_shuffle_epi8(mins128, shuffles[0]));\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            const __m256i q8s = q8.load_bsums(iy, i);\n            const __m256i prod = _mm256_madd_epi16(mins, q8s);\n            accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(c*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]);\n        }\n    }\n#ifdef HAVE_FANCY_SIMD\n    const __m512i shuffles512[2] = {\n        _mm512_set_epi64(0x0706070607060706, 0x0302030203020302, 0x0706070607060706, 0x0302030203020302,\n                         0x0504050405040504, 0x0100010001000100, 0x0504050405040504, 0x0100010001000100),\n        _mm512_set_epi64(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a,\n                         0x0d0c0d0c0d0c0d0c, 0x0908090809080908, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908)\n    };\n#endif\n    const __m128i shuffles[2] = {_mm_set_epi32(0x07060706, 0x05040504, 0x03020302, 0x01000100),\n                                 _mm_set_epi32(0x0f0e0f0e, 0x0d0c0d0c, 0x0b0a0b0a, 0x09080908)};\n\n    uint32_t utmp[4];\n};\n\ntemplate <typename Q8>\ninline void process_mins_16(const __m256i& all_scales, const Q8& q8, int i, float d, 
__m256 * accm) {\n    for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n        const __m256i prod  = _mm256_madd_epi16(all_scales, q8.load_bsums(iy, i));\n        accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]);\n    }\n}\ninline void prepare_scales_16(const __m256i& all_scales, __m256i * scales) {\n    const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);\n    const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);\n    scales[0] = MM256_SET_M128I(l_scales, l_scales);\n    scales[1] = MM256_SET_M128I(h_scales, h_scales);\n}\n\nstruct ScaleQ3 {\n    inline __m128i make_scales(const uint16_t * s8) const {\n        const uint16_t * scales16 = (const uint16_t *)s8;\n        uint32_t aux0 = scales16[0] | (scales16[1] << 16);\n        uint32_t aux1 = scales16[2] | (scales16[3] << 16);\n        uint32_t aux2 = scales16[4] | (scales16[5] << 16);\n        __m128i scales128 = _mm_set_epi32(\n            ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030),\n            ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030),\n             (aux1       & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030),\n             (aux0       & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030));\n        return _mm_add_epi8(scales128, m32);\n    }\n    const __m128i m32 = _mm_set1_epi8(-32);\n};\n\nstruct ScaleIQ4XS {\n    inline __m128i make_scales(const uint32_t scales_l, const uint16_t scales_h) {\n        uint32_t tmp32 = scales_h | (scales_h << 14);\n        const __m128i sh = _mm_slli_epi16(_mm_and_si128(_mm_srlv_epi32(_mm_set1_epi32(tmp32), hshift), hmask), 4);\n        const __m128i sl = _mm_and_si128(_mm_srlv_epi32(_mm_set1_epi32(scales_l), lshift), lmask);\n        return _mm_add_epi16(_mm_or_si128(sh, _mm_cvtepi8_epi16(_mm_shuffle_epi8(sl, lshuffle))), m32);\n    }\n    const __m128i hshift = _mm_set_epi32(12, 8, 4, 0);\n    const __m128i lshift = _mm_set_epi32(4, 0, 4, 0);\n    const __m128i hmask  = _mm_set1_epi16(0x03);\n   
 const __m128i lmask  = _mm_set1_epi8(0xf);\n    const __m128i lshuffle = _mm_set_epi32(0x07030602, 0x05010400, 0x07030602, 0x05010400);\n    const __m128i m32 = _mm_set1_epi16(-32);\n};\n\n// Copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L1455\n// MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow\nstruct Scales8KBase {\n    template <typename Q8>\n    inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const {\n        const __m256i mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, shuffles[1]), _mm_shuffle_epi8(mins128, shuffles[0]));\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            const __m256i q8s = q8.load_bsums(iy, i);\n            const __m256i prod = _mm256_madd_epi16(mins, q8s);\n            accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(c*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]);\n        }\n    }\n    inline __m256i shuffle(__m128i mins) const {\n        return MM256_SET_M128I(_mm_shuffle_epi8(mins, shuffles[1]), _mm_shuffle_epi8(mins, shuffles[0]));\n    }\n    const __m128i shuffles[2] = {_mm_set_epi32(0x07060706, 0x05040504, 0x03020302, 0x01000100),\n                                 _mm_set_epi32(0x0f0e0f0e, 0x0d0c0d0c, 0x0b0a0b0a, 0x09080908)};\n};\n// end copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L1455\n\ntemplate <typename Block>\nstruct BaseDequantizer {\n    BaseDequantizer(const void * vx, size_t bx) : vx(vx), bx(bx) {}\n    inline void new_row(int ix) {\n        x = (const Block *)((const char *)vx + bx*ix);\n    }\n\n    const void *  vx;\n    size_t        bx;\n    const Block * x;\n\n    float d;\n};\n\n// Copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L1698\n// MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow\n__m128i inline 
load_iq4nl_values_128() {\n    static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241};\n    return _mm_loadu_si128((const __m128i *)kvalues_iq4nl);\n}\n\n__m256i inline load_iq4nl_values_256() {\n    auto val128 = load_iq4nl_values_128();\n    return MM256_SET_M128I(val128, val128);\n}\n// end copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L1698\n\n#ifdef HAVE_FANCY_SIMD\n//====================================== Zen4 ==================================================\n\nstruct BlockPermuter {\n    const __m512i permute1 = _mm512_set_epi64(11, 10,  9,  8, 3, 2, 1, 0);\n    const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4);\n};\n\nstruct Q4Bits {\n    inline void prepare(const uint8_t * q4) {\n        auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);\n        auto tmp1 = _mm512_and_si512(q4bits, ml);\n        auto tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);\n        values[0] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);\n        values[1] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);\n        q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);\n        tmp1 = _mm512_and_si512(q4bits, ml);\n        tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);\n        values[2] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);\n        values[3] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);\n    }\n    inline void prepare64(const uint8_t * q4) {\n        auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);\n        values[0] = _mm512_and_si512(q4bits, ml);\n        values[1] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);\n        q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);\n        values[2] = _mm512_and_si512(q4bits, ml);\n        values[3] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);\n    }\n    __m512i values[4];\n    
const __m512i ml = _mm512_set1_epi8(0xf);\n    BlockPermuter perm;\n};\n\nstruct Q2Bits {\n    inline void prepare(const uint8_t * q2) {\n\n        auto q2bits = _mm512_loadu_si512((const __m512i*)q2);\n        auto tmp = _mm512_srli_epi16(q2bits, 2);\n\n        values[0] = _mm512_permutex2var_epi64(q2bits, perm.permute1, tmp);\n        values[2] = _mm512_permutex2var_epi64(q2bits, perm.permute2, tmp);\n        values[1] = _mm512_and_si512(_mm512_srli_epi16(values[0], 4), ml);\n        values[3] = _mm512_and_si512(_mm512_srli_epi16(values[2], 4), ml);\n        values[0] = _mm512_and_si512(values[0], ml);\n        values[2] = _mm512_and_si512(values[2], ml);\n    }\n    __m512i values[4];\n    const __m512i ml = _mm512_set1_epi8(0x03);\n    BlockPermuter perm;\n};\n\nstruct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {\n    DequantizerQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        bits.prepare(x[i].qs);\n        auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);\n        scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);\n        scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);\n    }\n\n    Q4Bits bits;\n    Scales8K s8k;\n};\n\n/*\nmoonll DequantizerIQ4XS\n*/\n\n// Copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L1775\n// MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow\n__m512i inline load_iq4nl_values_512() {\n    auto val256 = load_iq4nl_values_256();\n    return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);\n}\n// end copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L1775\n\n// Copied from 
https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L1781\n// MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow\nstruct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {\n    // Copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L1782\n    DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {}\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        prepare(x[i].qs);\n        auto scales128 = siq4.make_scales(*(const uint32_t *)x[i].scales_l, x[i].scales_h);\n        s8k.accum_mins(scales128, q8, i, -128.f*d, accd);\n        auto scales256 = MM256_SET_M128I(scales128, scales128);\n        auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);\n        scales[0] = _mm512_shuffle_epi8(all_scales, shuffles[0]);\n        scales[1] = _mm512_shuffle_epi8(all_scales, shuffles[1]);\n        scales[2] = _mm512_shuffle_epi8(all_scales, shuffles[2]);\n        scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]);\n    }\n    inline void prepare(const uint8_t * q4) {\n        bits.prepare64(q4);\n        // We now have in bits.values[0]: 0...15, 32...47, 64...79, 96...111\n        //                bits.values[1]: 16..31, 48...63, 80...95, 112..127\n        //                etc.\n        auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]);\n        bits.values[1] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]));\n        bits.values[0] = _mm512_shuffle_epi8(values, tmp);\n        tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]);\n        bits.values[3] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[2], permute2, 
bits.values[3]));\n        bits.values[2] = _mm512_shuffle_epi8(values, tmp);\n    }\n\n    Q4Bits bits;\n    Scales8KBase s8k;\n    ScaleIQ4XS siq4;\n    const __m512i values;\n    const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2,  9,  8, 1, 0);\n    const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4);\n    const __m512i shuffles[4] = {\n        _mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1),\n        _mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1),\n        _mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1),\n        _mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1),\n    };\n};\n// end copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L1781\n\nstruct HighBit5 {\n    inline void apply(const uint8_t * h, Q4Bits& bits) {\n        auto hbits256 = _mm256_loadu_si256((const __m256i *)h);\n        auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbits256), _mm256_srli_epi16(hbits256, 1), 1);\n        bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_and_si512(_mm512_slli_epi16(hbits, 4), mh));\n        bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh));\n        bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_and_si512(hbits, mh));\n        bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh));\n    }\n    const __m512i mh = _mm512_set1_epi8(0x10);\n};\n\nstruct HighBit3 {\n    inline void apply(const uint8_t * h, Q2Bits& bits) {\n        auto hbits256 = _mm256_loadu_si256((const __m256i *)h);\n        auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbits256), _mm256_srli_epi16(hbits256, 1), 1);\n        bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh));\n        bits.values[1] = 
_mm512_or_si512(bits.values[1], _mm512_and_si512(hbits, mh));\n        bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh));\n        bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_and_si512(_mm512_srli_epi16(hbits, 4), mh));\n    }\n    const __m512i mh = _mm512_set1_epi8(0x04);\n};\n\nstruct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {\n    DequantizerQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        bits.prepare(x[i].qs);\n        hbits.apply(x[i].qh, bits);\n        auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);\n        scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);\n        scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);\n    }\n\n    Q4Bits bits;\n    HighBit5 hbits;\n    Scales8K s8k;\n};\n\nstruct Scale16 {\n    inline void make_scales(const __m128i& scales8, __m512i * scales) const {\n        auto all_scales8 = MM256_SET_M128I(scales8, scales8);\n        auto scales1 = _mm256_shuffle_epi8(all_scales8, shuffle1);\n        auto scales2 = _mm256_shuffle_epi8(all_scales8, shuffle2);\n        scales[0] = _mm512_cvtepi8_epi16(scales1);\n        scales[1] = _mm512_cvtepi8_epi16(scales2);\n    }\n    template <typename Q8>\n    inline void process_mins_and_scales(int i, float c, const __m128i& mins8, const __m128i& scales8,\n        const Q8& q8, __m256 * accm, __m512i * scales) const {\n        process_mins_16(_mm256_cvtepi8_epi16(mins8), q8, i, c, accm);\n        make_scales(scales8, scales);\n    }\n    const __m256i shuffle1 = _mm256_set_epi32(0x07070707, 0x03030303, 0x06060606, 0x02020202,\n                                              0x05050505, 0x01010101, 0x04040404, 0x00000000);\n    const __m256i shuffle2 = 
_mm256_set_epi32(0x0f0f0f0f, 0x0b0b0b0b, 0x0e0e0e0e, 0x0a0a0a0a,\n                                              0x0d0d0d0d, 0x09090909, 0x0c0c0c0c, 0x08080808);\n};\n\nstruct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {\n    DequantizerQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        bits.prepare(x[i].qs);\n        const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);\n        const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);\n        const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);\n        sc16.process_mins_and_scales(i, -GGML_FP16_TO_FP32(x[i].dmin), mins8, scales8, q8, accm, scales);\n    }\n\n    Q2Bits bits;\n    Scale16 sc16;\n    const __m128i m4 = _mm_set1_epi8(0xf);\n\n};\n\nstruct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {\n    DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        bits.prepare(x[i].qs);\n        hbits.apply(x[i].hmask, bits);\n        auto scales128 = sc3.make_scales((const uint16_t *)x[i].scales);\n        sc16.process_mins_and_scales(i, -4.f*d, scales128, scales128, q8, accm, scales);\n    }\n\n    Q2Bits bits;\n    HighBit3 hbits;\n    ScaleQ3 sc3;\n    Scale16 sc16;\n    const __m128i m4  = _mm_set1_epi8(0xf);\n    const __m128i m32 = _mm_set1_epi8(-32);\n};\n\nstruct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {\n    DequantizerQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        bits.prepare64(x[i].ql);\n        
add_high_bits(x[i].qh, bits);\n        auto scales128 = _mm_loadu_si128((const __m128i *)x[i].scales);\n        sc16.process_mins_and_scales(i, -32.f*d, scales128, scales128, q8, accm, scales);\n    }\n\n    inline void add_high_bits(const uint8_t * qh, Q4Bits& bits) const {\n        auto hbits = _mm512_loadu_si512((const __m512i *)qh);\n        auto tmp1 = _mm512_and_si512(_mm512_slli_epi16(hbits, 4), mh);\n        auto tmp2 = _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh);\n        bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_permutex2var_epi64(tmp1, bits.perm.permute1, tmp2));\n        bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_permutex2var_epi64(tmp1, bits.perm.permute2, tmp2));\n        tmp1 = _mm512_and_si512(hbits, mh);\n        tmp2 = _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh);\n        bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_permutex2var_epi64(tmp1, bits.perm.permute1, tmp2));\n        bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_permutex2var_epi64(tmp1, bits.perm.permute2, tmp2));\n    }\n\n    Q4Bits bits;\n    HighBit3 hbits;\n    Scale16 sc16;\n\n    const __m512i mh = _mm512_set1_epi8(0x30);\n\n};\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n    const int nb = n / QK_K;\n\n    Q8<nrc_y> q8(info);\n\n    Dequantizer deq(vx, bx);\n\n    __m256  accm[nrc_y];\n    __m512  accd[nrc_y];\n    __m512i scales[2];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps();\n        for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps();\n\n        deq.new_row(ix);\n\n        for (int i = 0; i < nb; ++i) {\n\n            deq.new_block(i, q8, accm, scales);\n\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), 
deq.bits.values[0], q8.load_quants(iy, i, 0));\n                const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants(iy, i, 1));\n                const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants(iy, i, 2));\n                const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants(iy, i, 3));\n                auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));\n                sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));\n                accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);\n            }\n\n        }\n\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));\n            info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));\n        }\n\n    }\n}\n// Copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L2408\n// MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow\ntemplate <typename Q8>\ninline void compute_block(int iy, int i, float d, const Q8& q8, const __m512i * values, const __m512i * scales, __m512 * accd) {\n    const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[0], q8.load_quants64(iy, i, 0));\n    const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[1], q8.load_quants64(iy, i, 1));\n    const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[2], q8.load_quants64(iy, i, 2));\n    const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[3], q8.load_quants64(iy, i, 3));\n    auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));\n    sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, 
p4));\n    accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);\n}\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n    const int nb = n / QK_K;\n\n    Q8<nrc_y> q8(info);\n\n    Dequantizer deq(vx, bx);\n\n    __m256  accm[nrc_y];\n    __m512  accd[nrc_y];\n    __m512i scales[2];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps();\n        for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps();\n\n        deq.new_row(ix);\n\n        for (int i = 0; i < nb; ++i) {\n\n            deq.new_block(i, q8, accm, scales);\n\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants64(iy, i, 0));\n                const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants64(iy, i, 1));\n                const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants64(iy, i, 2));\n                const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants64(iy, i, 3));\n                auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));\n                sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));\n                accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);\n            }\n\n        }\n\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));\n            info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));\n        }\n\n    }\n}\n\ntemplate <typename Dequantizer, 
int nrc_y>\nstatic void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n    const int nb = n / QK_K;\n\n    Q8<nrc_y> q8(info);\n\n    Dequantizer deq(vx, bx);\n\n    __m256  accm[nrc_y];\n    __m512  accd[nrc_y];\n    __m512i scales[4];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps();\n        for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps();\n\n        deq.new_row(ix);\n\n        for (int i = 0; i < nb; ++i) {\n\n            deq.new_block(i, q8, accm, scales);\n\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                const __m512i p1 = _mm512_maddubs_epi16(deq.bits.values[0], q8.load_quants64(iy, i, 0));\n                const __m512i p2 = _mm512_maddubs_epi16(deq.bits.values[1], q8.load_quants64(iy, i, 1));\n                const __m512i p3 = _mm512_maddubs_epi16(deq.bits.values[2], q8.load_quants64(iy, i, 2));\n                const __m512i p4 = _mm512_maddubs_epi16(deq.bits.values[3], q8.load_quants64(iy, i, 3));\n                auto sumi = _mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_setzero_si512(),\n                                    p1, scales[0]), p2, scales[1]), p3, scales[2]), p4, scales[3]);\n                accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);\n            }\n\n        }\n\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));\n            info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));\n        }\n\n    }\n}\n\ntemplate <typename Dequantizer>\nstatic void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n    const int nb = n / QK_K;\n\n    constexpr int k_nx = 2;\n\n    
Q8<1> q8(info);\n\n    Dequantizer deq1(vx, bx);\n    Dequantizer deq2(vx, bx);\n\n    Dequantizer * deq[k_nx];\n    deq[0] = &deq1;\n    deq[1] = &deq2;\n\n    __m512i scales[2*k_nx];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        auto accd = _mm512_setzero_ps();\n        auto accm = _mm256_setzero_ps();\n\n        for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_row(ix);\n\n        for (int i = 0; i < nb/k_nx; ++i) {\n\n            for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_block(k_nx*i+kx, q8, &accm, scales+2*kx);\n\n            for (int kx = 0; kx < k_nx; ++kx) {\n                compute_block(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, scales+2*kx, &accd);\n            }\n\n        }\n        if (2*(nb/2) < nb) {\n            int i0 = 2*(nb/2);\n            deq[0]->new_block(i0, q8, &accm, scales);\n            compute_block(0, i0, deq[0]->d, q8, deq[0]->bits.values, scales, &accd);\n        }\n\n        auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1));\n        info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256)));\n    }\n}\n// end copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L2408\n\n#else\n// ===================================== Vanilla AVX2 =====================================\n\nstruct Q4Bits {\n    inline void prepare(const uint8_t * q4, int j) {\n        auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);\n        values[0] = _mm256_and_si256(q4bits, ml);\n        values[1] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);\n        q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);\n        values[2] = _mm256_and_si256(q4bits, ml);\n        values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);\n    }\n    inline void prepare64(const uint8_t * q4, int j) {\n        auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);\n        values[0] = _mm256_and_si256(q4bits, 
ml);\n        values[2] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);\n        q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);\n        values[1] = _mm256_and_si256(q4bits, ml);\n        values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);\n    }\n    inline void prepare16(const uint8_t * q4, int j) {\n        values[0] = dequant16(q4 + 64*j +  0);\n        values[1] = dequant16(q4 + 64*j + 16);\n        values[2] = dequant16(q4 + 64*j + 32);\n        values[3] = dequant16(q4 + 64*j + 48);\n    }\n    inline __m256i dequant16(const uint8_t * qs) const {\n        const __m128i aux128 = _mm_loadu_si128((const __m128i *)qs);\n        const __m256i aux256 = MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128);\n        return _mm256_and_si256(ml, aux256);\n    };\n    __m256i values[4];\n    const __m256i ml = _mm256_set1_epi8(0xf);\n};\n\nstruct Q2Bits {\n    inline void prepare(const uint8_t * q2, int j) {\n        auto q2bits = _mm256_loadu_si256((const __m256i *)q2 + j);\n        values[0] = _mm256_and_si256(q2bits, ml);\n        values[1] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), ml);\n        values[2] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), ml);\n        values[3] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), ml);\n    }\n    __m256i values[4];\n    const __m256i ml = _mm256_set1_epi8(0x03);\n};\n\nstruct HighBit5 {\n    inline void load(const uint8_t * h) { hbits = _mm256_loadu_si256((const __m256i *)h); }\n    inline void apply(Q4Bits& bits, bool do_shift) {\n        bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh));\n        bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 3), mh));\n        bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));\n        bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh));\n        if (do_shift) {\n 
           hbits = _mm256_srli_epi16(hbits, 4);\n        }\n    }\n    const __m256i mh = _mm256_set1_epi8(0x10);\n    __m256i hbits;\n};\n\nstruct HighBit3 {\n    inline void load(const uint8_t * h) { hbits = _mm256_loadu_si256((const __m256i *)h); }\n    inline void apply(Q2Bits& bits, bool do_shift) {\n        bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));\n        bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh));\n        bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(hbits, mh));\n        bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(hbits, 1), mh));\n        if (do_shift) {\n            hbits = _mm256_srli_epi16(hbits, 4);\n        }\n    }\n    const __m256i mh = _mm256_set1_epi8(0x04);\n    __m256i hbits;\n};\n\n\n/*\ntemplate <typename Q8, typename Bits>\ninline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {\n    if (j == 0) {\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0)));\n            const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1)));\n            const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2)));\n            const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3)));\n            sumi[iy] = _mm256_add_epi32(_mm256_add_epi32(p1, p3), _mm256_add_epi32(p2, p4));\n        }\n    } else {\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4)));\n            const __m256i p2 = _mm256_madd_epi16(scales[1], 
_mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5)));\n            const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6)));\n            const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7)));\n            sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p1, p3));\n            sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p2, p4));\n        }\n    }\n}*/\n\nstruct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {\n    DequantizerQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        return s8k.process_mins_and_scales(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);\n    }\n    inline void prepare(int i, int j) {\n        bits.prepare(x[i].qs, j);\n    }\n\n    Q4Bits bits;\n    Scales8K s8k;\n};\n\nstruct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {\n    DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_values()) {}\n    template <typename Q8>\n    inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        auto scales128 = siq4.make_scales(*(const uint32_t *)x[i].scales_l, x[i].scales_h);\n        s8k.accum_mins(scales128, q8, i, -128.f*d, accd);\n        return MM256_SET_M128I(scales128, scales128);\n    }\n    inline void prepare(int i, int j) {\n        bits.prepare16(x[i].qs, j);\n        bits.values[0] = _mm256_shuffle_epi8(values, bits.values[0]);\n        bits.values[1] = _mm256_shuffle_epi8(values, bits.values[1]);\n        bits.values[2] = _mm256_shuffle_epi8(values, bits.values[2]);\n        bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]);\n    }\n\n    static __m256i load_values() {\n        static const uint8_t kvalues_iq4nl[16] = 
{1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241};\n        auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq4nl);\n        return MM256_SET_M128I(val128, val128);\n    }\n\n    Q4Bits bits;\n    Scales8K s8k;\n    ScaleIQ4XS siq4;\n    const __m256i values;\n};\n\nstruct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {\n    DequantizerQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        hbits.load(x[i].qh);\n        return s8k.process_mins_and_scales(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);\n    }\n    inline void prepare(int i, int j) {\n        bits.prepare(x[i].qs, j);\n        hbits.apply(bits, j == 0);\n    }\n\n    Q4Bits  bits;\n    HighBit5 hbits;\n    Scales8K s8k;\n};\n\ntemplate <typename Q8>\ninline void process_mins_and_scales_16(const __m128i& scales128, const Q8& q8, int i, float d,\n    __m256 * accm, __m256i * scales) {\n    const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);\n    process_mins_16(all_scales, q8, i, d, accm);\n    prepare_scales_16(all_scales, scales);\n}\n\nstruct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {\n    DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        hbits.load(x[i].hmask);\n        process_mins_and_scales_16(sc3.make_scales((const uint16_t *)x[i].scales), q8, i, -4.f*d, accm, scales);\n    }\n    inline void prepare(int i, int j) {\n        bits.prepare(x[i].qs, j);\n        hbits.apply(bits, j == 0);\n    }\n\n    Q2Bits  bits;\n    HighBit3 hbits;\n    ScaleQ3 sc3;\n\n    const __m128i m32 = _mm_set1_epi8(-32);\n};\n\nstruct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {\n    
DequantizerQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);\n        const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);\n        const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);\n        process_mins_16(_mm256_cvtepi8_epi16(mins8), q8, i, -GGML_FP16_TO_FP32(x[i].dmin), accm);\n        prepare_scales_16(_mm256_cvtepi8_epi16(scales8), scales);\n    }\n    inline void prepare(int i, int j) {\n        bits.prepare(x[i].qs, j);\n    }\n\n    Q2Bits  bits;\n\n    const __m128i m4 = _mm_set1_epi8(0xf);\n};\n\nstruct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {\n    DequantizerQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        process_mins_and_scales_16(_mm_loadu_si128((const __m128i *)x[i].scales), q8, i, -32.f*d, accm, scales);\n    }\n    inline void prepare(int i, int j) {\n        bits.prepare64(x[i].ql, j);\n        auto hbits = _mm256_loadu_si256((const __m256i *)x[i].qh + j);\n        bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh));\n        bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));\n        bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(hbits, mh));\n        bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(hbits, 2), mh));\n    }\n\n    Q4Bits  bits;\n    const __m256i mh = _mm256_set1_epi8(0x30);\n};\n\n\ninline __m256i get_scale_shuffle_8(int i);\n\ninline void set_scales_8(const __m256i& all_scales, int j, __m256i* scales);\n\ninline 
__m256i get_scale_shuffle_16(int i);\n\ninline void set_scales_16(const __m256i& all_scales, __m256i* scales);\n\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n%QK_K == 0);\n    const int nb = n/QK_K;\n\n    Q8<nrc_y> q8(info);\n\n    __m256i all_scales[2];\n    __m256i scales[4];\n    __m256  accd[nrc_y];\n\n    Dequantizer deq(vx, bx);\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        deq.new_row(ix);\n\n        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();\n\n        for (int i = 0; i < nb; ++i) {\n\n            deq.new_block(i, q8, accd, all_scales);\n\n            __m256i sumi[nrc_y];\n\n            for (int j = 0; j < QK_K/128; ++j) {\n                deq.prepare(i, j);\n                set_scales_16(all_scales[j], scales);\n                multiply_add(deq.bits, scales, j, i, q8, sumi);\n            }\n\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);\n            }\n\n        }\n\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            info.store(ix, iy, hsum_float_8(accd[iy]));\n        }\n\n    }\n\n}\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n    const int nb = n / QK_K;\n\n    Q8<nrc_y> q8(info);\n\n    Dequantizer deq(vx, bx);\n\n    __m256  accd[nrc_y];\n    __m256i scales[4];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();\n\n        deq.new_row(ix);\n\n        for (int i = 0; i < nb; ++i) {\n\n            auto all_scales = deq.new_block(i, q8, accd);\n\n            __m256i sumi[nrc_y];\n\n            for (int j = 0; j < QK_K/128; ++j) {\n\n                deq.prepare(i, j);\n\n     
           set_scales_8(all_scales, j, scales);\n\n                multiply_add(deq.bits, scales, j, i, q8, sumi);\n\n            }\n\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i));\n                accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);\n            }\n\n        }\n\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            info.store(ix, iy, hsum_float_8(accd[iy]));\n        }\n\n    }\n}\n#endif  // Zen4 or vanilla AVX2\n\n\n\n//\n// ============================== Legacy quants\n//\n\nstruct DotHelper {\n    const __m256i m1 = _mm256_set1_epi16(1);\n#if defined(__AVX512VNNI__) && defined(__AVX512VL__)\n    inline __m256i dot(__m256i x, __m256i y) const {\n        return _mm256_dpbusd_epi32(_mm256_setzero_si256(), x, y);\n    }\n#else\n    inline __m256i dot(__m256i x, __m256i y) const {\n        return _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x, y));\n    }\n#endif\n};\n\nstruct SignedDot {\n    DotHelper helper;\n    inline __m256i compute(__m256i x, __m256i y) const {\n        return helper.dot(_mm256_sign_epi8(x, x), _mm256_sign_epi8(y, x));\n    }\n};\nstruct UnsignedDot {\n    DotHelper helper;\n    inline __m256i compute(__m256i x, __m256i y) const {\n        return helper.dot(x, y);\n    }\n};\ntemplate <typename Q8, typename Dot> struct Sum4 {\n    Dot dot;\n    inline __m256i compute(const __m256i * qx, const Q8 * y) const {\n        const __m256i p0 = dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[0].qs));\n        const __m256i p1 = dot.compute(qx[1], _mm256_loadu_si256((const __m256i *)y[1].qs));\n        const __m256i p2 = dot.compute(qx[2], _mm256_loadu_si256((const __m256i *)y[2].qs));\n        const __m256i p3 = dot.compute(qx[3], _mm256_loadu_si256((const __m256i *)y[3].qs));\n        const __m256i p01 = _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p0, p1));    // 0,0, 1,1, 0,0, 1,1\n        const __m256i p23 = 
_mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p2, p3));    // 2,2, 3,3, 2,2, 3,3\n        return _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p01, p23)); // 0,1,2,3, 0,1,2,3\n    }\n};\n\nstruct Sum4_Q8 {\n    SignedDot dot;\n    static inline __m256i add1(__m256i a, __m256i b) {\n        return _mm256_add_epi32(_mm256_unpacklo_epi32(a, b), _mm256_unpackhi_epi32(a, b));\n    }\n    static inline __m256i add2(__m256i a, __m256i b) {\n        return _mm256_add_epi32(_mm256_unpacklo_epi64(a, b), _mm256_unpackhi_epi64(a, b));\n    }\n    inline __m256i compute(const __m256i * qx, const block_q8_0 * y) const {\n        const __m256i p0 = dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[0].qs));\n        const __m256i p1 = dot.compute(qx[1], _mm256_loadu_si256((const __m256i *)y[1].qs));\n        const __m256i p2 = dot.compute(qx[2], _mm256_loadu_si256((const __m256i *)y[2].qs));\n        const __m256i p3 = dot.compute(qx[3], _mm256_loadu_si256((const __m256i *)y[3].qs));\n        const __m256i p01 = add1(p0, p1);  // 0,1, 0,1, 0,1, 0,1\n        const __m256i p23 = add1(p2, p3);  // 2,3, 2,3, 2,3, 2,3\n        return add2(p01, p23); // returns 0,1,2,3, 0,1,2,3\n    }\n};\n\nstruct ScaleHelperQ_0 {\n    ggml_half scales8[4];\n    template <typename Q>\n    inline __m128 prepare4(const Q * y) {\n        for (int j = 0; j < 4; ++j) scales8[j] = y[j].d;\n        return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)scales8));\n    }\n    template <typename Q>\n    inline __m128 prepare4(__m128 other_scales, const Q * y) {\n        return _mm_mul_ps(other_scales, prepare4<Q>(y));\n    }\n    template <typename Q> inline float prepare1(const Q * y) const { return GGML_FP16_TO_FP32(y->d); }\n    template <typename Q> inline float prepare1(float d, const Q * y) const { return d*prepare1(y); }\n};\n// Copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L8187\n// MIT licensed, 
Copyright (c) 2024-2025 Iwan Kawrakow\ntemplate <int min_value>\nstruct ScaleHelperQ_0_1 {\n    ggml_half scales8[4];\n    template <typename Q>\n    inline __m256 prepare4(const Q * y) {\n        for (int j = 0; j < 4; ++j) scales8[j] = y[j].d;\n        auto s4 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)scales8));\n        return _mm256_set_m128(_mm_mul_ps(s4, min), s4);\n    }\n    template <typename Q>\n    inline __m256 prepare4(__m256 other_scales, const Q * y) {\n        return _mm256_mul_ps(other_scales, prepare4<Q>(y));\n    }\n    template <typename Q> inline std::pair<float, float> prepare1(const Q * y) const {\n        float d = GGML_FP16_TO_FP32(y->d);\n        return std::make_pair(d, -d*float(min_value));\n    }\n    std::pair<float, float> inline prepare1(const std::pair<float, float>& dm, const block_q8_1 * y) const {\n        return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s));\n    }\n    const __m128 min = _mm_set1_ps(float(-min_value));\n};\n// end copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L8187\n\nstruct ScaleHelperQ_1 {\n    uint32_t scales8[4];\n    const __m128i shuffle = _mm_set_epi16(0x0f0e, 0x0b0a, 0x0706, 0x0302, 0x0d0c, 0x0908, 0x0504, 0x0100);\n\n    template <typename Q>\n    inline __m256 prepare4(const Q * y) {\n        for (int j = 0; j < 4; ++j) {\n            // it is slightly faster to directly dereference (const uint32_t *)&y[j].d, but some compilers\n            // complain that this breaks strict-aliasing rules.\n            memcpy(scales8 + j, &y[j].d, sizeof(uint32_t));\n        }\n        return _mm256_cvtph_ps(_mm_shuffle_epi8(_mm_loadu_si128((const __m128i *)scales8), shuffle));\n    }\n\n    template <typename Q>\n    inline __m256 prepare4(__m256 other_scales, const Q * y) {\n        return _mm256_mul_ps(other_scales, prepare4<Q>(y));\n    }\n\n    template <typename Q> inline 
std::pair<float, float> prepare1(const Q * y) const {\n        return std::make_pair(GGML_FP16_TO_FP32(y->d), GGML_FP16_TO_FP32(y->m));\n    }\n    template <typename Q> inline std::pair<float, float> prepare1(const std::pair<float, float>& dm, const Q * y) const {\n        return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->m));\n    }\n    std::pair<float, float> inline prepare1(const std::pair<float, float>& dm, const block_q8_1 * y) const {\n        return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s));\n    }\n};\n\nstruct MinusType0 {\n    inline __m256 compute(__m128 d, int) const { return _mm256_set_m128(d, d); }\n    inline float compute(float d, int) const { return d; }\n    inline float result(__m256 acc, int) const { return hsum_float_8(acc); }\n};\n\ntemplate <int nrc_y> struct MinusType1 {\n    __m128 accm[nrc_y];\n    MinusType1() { for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm_setzero_ps(); }\n    inline __m256 compute(__m256 dm, int iy) {\n        const __m128 d = _mm256_castps256_ps128(dm);\n        const __m128 m = _mm256_extractf128_ps(dm, 1);\n        accm[iy] = _mm_add_ps(accm[iy], m);\n        return _mm256_set_m128(d, d);\n    }\n    inline float compute(const std::pair<float, float>& dm, int iy) {\n        accm[iy] = _mm_add_ps(accm[iy], _mm_set1_ps(dm.second*0.25f));\n        return dm.first;\n    }\n    inline float result(__m256 acc, int iy) const {\n        const __m128 sum = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));\n        return hsum_float_4(_mm_add_ps(sum, accm[iy]));\n    }\n};\n\ntemplate <typename Minus, int nrc_y, bool is_multiple_of_4> struct AccumT {\n    __m256 acc[nrc_y];\n    Minus accm;\n    AccumT() {  for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = _mm256_setzero_ps(); }\n    template <typename Unpacker, typename Scales, typename Sum, typename Q8>\n    inline void compute(int nb, Unpacker& unp, Scales& scales, Sum& 
sum, const Q8 ** y, const DataInfo& info, int ix) {\n        auto qx = unp.quants();\n        __m256 dall[nrc_y];\n        for (int i = 0; i < nb/4; ++i) {\n            auto other_scales = unp.set_block_4(i);\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                auto s12 = scales.prepare4(other_scales, y[iy] + 4*i);\n                dall[iy] = accm.compute(s12, iy);\n            }\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                auto pall = sum.compute(qx, y[iy] + 4*i);\n                acc[iy] = _mm256_fmadd_ps(dall[iy], _mm256_cvtepi32_ps(pall), acc[iy]);\n            }\n        }\n        if (!is_multiple_of_4) {\n            for (int i = 4*(nb/4); i < nb; ++i) {\n                auto other_scales = unp.set_block(i);\n                for (int iy = 0; iy < nrc_y; ++iy) {\n                    auto s12 = scales.prepare1(other_scales, y[iy] + i);\n                    auto d = accm.compute(s12, iy);\n                    const __m256i p0 = sum.dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs));\n                    acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(p0), acc[iy]);\n                }\n            }\n        }\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            info.store(ix, iy, accm.result(acc[iy], iy));\n            //s[iy*bs] = accm.result(acc[iy], iy);\n        }\n    }\n};\n\ntemplate <int nrc_y, bool is_multiple_of_4>\nusing AccumType0 = AccumT<MinusType0, nrc_y, is_multiple_of_4>;\n\ntemplate <int nrc_y, bool is_multiple_of_4>\nusing AccumType1 = AccumT<MinusType1<nrc_y>, nrc_y, is_multiple_of_4>;\n\nusing Sum4Type0 = Sum4<block_q8_0, SignedDot>;\nusing Sum4Type1 = Sum4<block_q8_1, UnsignedDot>;\n\ntemplate <typename Unpacker, typename Sum4Type, typename AccumType, typename Scales, typename Q8, int nrc_y>\nvoid mul_mat_qX_q8_Helper(int nb, const void * vx, size_t bx, const DataInfo& info, const Q8 ** y, int nrc_x) {\n    Unpacker unp(vx, bx);\n    Sum4Type sum4;\n    Scales 
scales;\n    for (int ix = 0; ix < nrc_x; ++ix) {\n        unp.set_row(ix);\n        AccumType accum;\n        accum.compute(nb, unp, scales, sum4, y, info, ix);\n    }\n}\n\ntemplate <typename Unpacker, int nrc_y>\nvoid mul_mat_qX_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n%Unpacker::block_size() == 0);\n    Q8<nrc_y, block_q8_0> q8(info);\n    int nb = n/Unpacker::block_size();\n    if (nb%4 == 0) {\n        mul_mat_qX_q8_Helper<Unpacker, Sum4Type0, AccumType0<nrc_y, true>, ScaleHelperQ_0, block_q8_0, nrc_y>(\n                nb, vx, bx, info, q8.y, nrc_x\n        );\n    } else {\n        mul_mat_qX_q8_Helper<Unpacker, Sum4Type0, AccumType0<nrc_y, false>, ScaleHelperQ_0, block_q8_0, nrc_y>(\n                nb, vx, bx, info, q8.y, nrc_x\n        );\n    }\n}\n\ntemplate <typename Unpacker, int nrc_y>\nvoid mul_mat_qX_1_q8_1_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n%Unpacker::block_size() == 0);\n    Q8<nrc_y, block_q8_1> q8(info);\n    int nb = n/Unpacker::block_size();\n    if (nb%4 == 0) {\n        mul_mat_qX_q8_Helper<Unpacker, Sum4Type1, AccumType1<nrc_y, true>, ScaleHelperQ_1, block_q8_1, nrc_y>(\n                nb, vx, bx, info, q8.y, nrc_x\n        );\n    } else {\n        mul_mat_qX_q8_Helper<Unpacker, Sum4Type1, AccumType1<nrc_y, false>, ScaleHelperQ_1, block_q8_1, nrc_y>(\n                nb, vx, bx, info, q8.y, nrc_x\n        );\n    }\n}\n\nstruct Dequantizer4bit {\n    const __m256i m4 = _mm256_set1_epi8(0xf);\n    inline __m256i dequant(const uint8_t * qs) const {\n        const __m128i aux128 = _mm_loadu_si128((const __m128i *)qs);\n        return _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128), m4);\n    }\n};\n\nstruct Q8_0_Dequantizer {\n    inline __m256i dequant(const block_q8_0 * x) const {\n        return _mm256_loadu_si256((const __m256i *)x->qs);\n    }\n};\n\n// Copied from 
https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L8455\n// MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow\nstruct Q8_0_1_Dequantizer {\n    inline __m256i dequant(const block_q8_0 * x) const {\n        return _mm256_add_epi8(_mm256_set1_epi8(127), _mm256_loadu_si256((const __m256i *)x->qs));\n    }\n};\n// end copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L8455\n\nstruct Q4_0_Dequantizer {\n    Dequantizer4bit b4;\n    const __m256i m8 = _mm256_set1_epi8(-8);\n    inline __m256i dequant(const block_q4_0 * x) const {\n        return _mm256_add_epi8(b4.dequant(x->qs), m8);\n    }\n};\n\nstruct Q4_1_Dequantizer {\n    Dequantizer4bit b4;\n    inline __m256i dequant(const block_q4_1 * x) const {\n        return b4.dequant(x->qs);\n    }\n};\n\nstruct HBitDequantizer {\n    const __m256i shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);\n    const __m256i mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);\n    const __m256i minus1 = _mm256_set1_epi64x(-1);\n    inline __m256i to_bytes(const uint8_t * bits) const {\n        // Note: Data in all ggml quants is at least 2-byte aligned.\n        // => we can cast to uint16_t and OR two consecutive entries together,\n        // which is faster than memcpy\n        const uint16_t * aux16 = (const uint16_t *)bits;\n        // Widen before shifting: aux16[1] would otherwise be promoted to a signed int.\n        const uint32_t aux32 = aux16[0] | ((uint32_t)aux16[1] << 16);\n        //uint32_t aux32; memcpy(&aux32, bits, sizeof(uint32_t));\n        __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(aux32), shuffle);\n        bytes = _mm256_or_si256(bytes, mask);\n        return _mm256_cmpeq_epi8(bytes, minus1);\n    }\n};\n\nstruct Q5_0_Dequantizer {\n    Dequantizer4bit b4;\n    HBitDequantizer hbit;\n    const __m256i mh = _mm256_set1_epi8((char)0xF0);\n    inline __m256i dequant(const block_q5_0 * x) const {\n        
const __m256i vqh = _mm256_andnot_si256(hbit.to_bytes(x->qh), mh);\n        return _mm256_or_si256(b4.dequant(x->qs), vqh);\n    }\n};\n\nstruct Q5_1_Dequantizer {\n    Dequantizer4bit b4;\n    HBitDequantizer hbit;\n    const __m256i mh = _mm256_set1_epi8(0x10);\n    inline __m256i dequant(const block_q5_1 * x) const {\n        const __m256i vqh = _mm256_and_si256(hbit.to_bytes(x->qh), mh);\n        return _mm256_or_si256(b4.dequant(x->qs), vqh);\n    }\n};\n\ntemplate <typename Q, typename Scales, typename Dequantizer>\nstruct Q_Unpacker {\n    Q_Unpacker(const void * vx, size_t bx) : cx_0((const char *)vx), x((const Q*)cx_0), bx(bx) {}\n\n    const char * cx_0;\n    const Q    * x;\n    size_t       bx;\n\n    Scales scales;\n    Dequantizer deq;\n\n    __m256i qx[4];\n\n    inline const __m256i* quants() const { return qx; }\n\n    inline void set_row(int ix) { x = (const Q*)(cx_0 + ix*bx); }\n\n    inline auto set_block_4(int i) {\n        for (int j = 0; j < 4; ++j) {\n            qx[j] = deq.dequant(x + 4*i + j);\n        }\n        return scales.prepare4(x + 4*i);\n    }\n    inline auto set_block(int i) {\n        qx[0] = deq.dequant(x + i);\n        return scales.prepare1(x + i);\n    }\n};\n\nstruct Q8_0_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0, Q8_0_Dequantizer> {\n    Q8_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}\n    inline static int block_size() { return QK8_0; }\n};\n// Copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L8574\n// MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow\nstruct Q8_0_1_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0_1<127>, Q8_0_1_Dequantizer> {\n    Q8_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}\n//    using Sum4T = Sum4TypeQ81;\n    inline static int block_size() { return QK8_0; }\n};\n// end copied from 
https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L8574\nstruct Q4_0_Unpacker final : public Q_Unpacker<block_q4_0, ScaleHelperQ_0, Q4_0_Dequantizer> {\n    Q4_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}\n    inline static int block_size() { return QK4_0; }\n};\nstruct Q5_0_Unpacker final : public Q_Unpacker<block_q5_0, ScaleHelperQ_0, Q5_0_Dequantizer> {\n    Q5_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}\n    inline static int block_size() { return QK5_0; }\n};\nstruct Q4_1_Unpacker final : public Q_Unpacker<block_q4_1, ScaleHelperQ_1, Q4_1_Dequantizer> {\n    Q4_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}\n    inline static int block_size() { return QK4_1; }\n};\nstruct Q5_1_Unpacker final : public Q_Unpacker<block_q5_1, ScaleHelperQ_1, Q5_1_Dequantizer> {\n    Q5_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}\n    inline static int block_size() { return QK5_1; }\n};\n\ntemplate <int nrc_y>\nvoid mul_mat_q8_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n%Q8_0_Unpacker::block_size() == 0);\n    Q8<nrc_y, block_q8_0> q8(info);\n    int nb = n/Q8_0_Unpacker::block_size();\n    if (nb%4 == 0) {\n        mul_mat_qX_q8_Helper<Q8_0_Unpacker, Sum4_Q8, AccumType0<nrc_y, true>, ScaleHelperQ_0, block_q8_0, nrc_y>(\n                nb, vx, bx, info, q8.y, nrc_x\n        );\n    } else {\n        mul_mat_qX_q8_Helper<Q8_0_Unpacker, Sum4_Q8, AccumType0<nrc_y, false>, ScaleHelperQ_0, block_q8_0, nrc_y>(\n                nb, vx, bx, info, q8.y, nrc_x\n        );\n    }\n}\n\n\n\n\n/*\nmoonll\nadd some structs for DequantizerIQ2XXS\nSimpleBits\nEvenSignHelper\n*/\nstruct SimpleBits {\n    __m256i values[4];\n};\n\n// fix for #829: Add checks of AVX512VPOPCNTDQ\n#if defined(HAVE_FANCY_SIMD) && defined(__AVX512VPOPCNTDQ__)\n#define HAVE_AVX512_POPCNT 1\n#else\n#define HAVE_AVX512_POPCNT 
0\n#endif\n\n// Copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L7736\n// with the addition of a branch that handles a missing _mm256_popcnt_epi32 instruction\n// MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow\nstruct EvenSignHelper {\n    #if defined HAVE_FANCY_SIMD\n    // #pragma message(\"Using AVX512VPOPCNTDQ in even sign helper\")\n        union sbits_t {\n            __m128i vec;\n            __mmask32 mask[4];\n        };\n        IQK_ALWAYS_INLINE void sign_2_values(__m256i aux, __m256i * values) const {\n            aux = _mm256_and_si256(_mm256_srlv_epi32(aux, shifts), mask);\n            \n            // fix for #829: Compatibility with processors using Intel Cascade Lake architecture\n            // If AVX512VPOPCNTDQ extension is not supported, use alternative implementation\n            #if HAVE_AVX512_POPCNT\n                auto pcnt = _mm256_popcnt_epi32(aux);\n                \n            #else\n                // Alternative implementation: Using standard bit counting method\n                __m256i pcnt;\n                int* pcnt_ptr = reinterpret_cast<int*>(&pcnt);\n                int* aux_ptr = reinterpret_cast<int*>(&aux); // Get address of aux directly, avoid unnecessary copies\n                \n                #pragma unroll 8  // Hint compiler to unroll loops, increasing throughput of SIMD computing\n                for (int i = 0; i < 8; i++) {\n                    pcnt_ptr[i] = __builtin_popcount(aux_ptr[i]); // Use compiler builtin popcount\n                }\n            #endif\n            \n            sbits_t sbits;\n            sbits.vec = _mm256_cvtepi32_epi8(_mm256_or_si256(aux, _mm256_slli_epi32(_mm256_and_si256(pcnt, mone), 7)));\n            values[0] = _mm256_mask_sub_epi8(values[0], sbits.mask[0], _mm256_setzero_si256(), values[0]);\n            values[1] = _mm256_mask_sub_epi8(values[1], sbits.mask[1], _mm256_setzero_si256(), 
values[1]);\n            //auto sign_bits = _mm256_cvtepi32_epi8(_mm256_or_si256(aux, _mm256_slli_epi32(_mm256_and_si256(pcnt, mone), 7)));\n            //const __mmask32 * m32 = (const __mmask32 *)&sign_bits;\n            //values[0] = _mm256_mask_sub_epi8(values[0], m32[0], _mm256_setzero_si256(), values[0]);\n            //values[1] = _mm256_mask_sub_epi8(values[1], m32[1], _mm256_setzero_si256(), values[1]);\n        }\n        const __m256i shifts = _mm256_set_epi32(21, 14, 7, 0, 21, 14, 7, 0);\n        const __m256i mask   = _mm256_set1_epi32(127);\n        const __m256i mone   = _mm256_set1_epi32(1);\n    #else\n        inline void sign_value(uint32_t aux32, __m256i& value) const {\n            auto signs = _mm256_set_epi64x(keven_signs[(aux32 >> 21) & 127], keven_signs[(aux32 >> 14) & 127],\n                                           keven_signs[(aux32 >>  7) & 127], keven_signs[(aux32 >>  0) & 127]);\n            value = _mm256_sign_epi8(value, signs);\n        }\n    #endif\n};\n\n/*\nmoonll ad multiply_add for mul_mat_qX_K_q8_K_IQ_1\nadd func\nget_scale_shuffle_8\nget_scale_shuffle_16\nset_scales_16\n*/\n\n// Copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L1578\n// MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow\ninline __m256i get_scale_shuffle_8(int i) {\n    return _mm256_set1_epi16((2*i) | ((2*i+1) << 8));\n}\n\ninline void set_scales_8(const __m256i& all_scales, int j, __m256i * scales) {\n    scales[0] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+0));\n    scales[1] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+1));\n    scales[2] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+2));\n    scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+3));\n}\n\n\ninline __m256i get_scale_shuffle_16(int i) {\n    static const uint8_t k_shuffle[128] = {\n         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,     2, 3, 2, 3, 
2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,\n         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,     6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,\n         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,    10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,\n        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,    14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,\n    };\n    return _mm256_loadu_si256((const __m256i*)k_shuffle + i);\n}\n\ninline void set_scales_16(const __m256i& all_scales, __m256i * scales) {\n    scales[0] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(0));\n    scales[1] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(1));\n    scales[2] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(2));\n    scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(3));\n}\n\ntemplate <typename Q8, typename Bits>\ninline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {\n    if (j == 0) {\n#ifdef HAVE_FANCY_SIMD\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            sumi[iy] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0)));\n            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1)));\n            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2)));\n            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3)));\n        }\n#else\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0)));\n            const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1)));\n            const __m256i p3 = _mm256_madd_epi16(scales[2], 
_mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2)));\n            const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3)));\n            sumi[iy] = _mm256_add_epi32(_mm256_add_epi32(p1, p3), _mm256_add_epi32(p2, p4));\n        }\n#endif\n    } else {\n#ifdef HAVE_FANCY_SIMD\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4)));\n            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5)));\n            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6)));\n            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7)));\n        }\n#else\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4)));\n            const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5)));\n            const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6)));\n            const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7)));\n            sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p1, p3));\n            sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p2, p4));\n        }\n#endif\n    }\n}\n\n/*\nmoonll ad multiply_add_1 for mul_mat_qX_K_q8_K_IQ_1\nadd func\nset_scales_8_iq\nset_scales_16_iq\n\nadd MUL_MAT\nmul_mat_qX_K_q8_K_IQ_1\nmul_mat_qX_K_q8_K_IQ_N\nmul_mat_qX_K_q8_K_IQ\n*/\n\ntemplate <typename Bits>\ninline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, const __m256i * q8, __m256i * 
sumi) {\n    if (j == 0) {\n#ifdef HAVE_FANCY_SIMD\n        auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]);\n        auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]);\n        auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]);\n        auto p4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[3], q8[3]);\n        sumi[0] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_packs_epi32(p1, p2));\n        sumi[1] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[1], _mm256_packs_epi32(p3, p4));\n#else\n        const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8[0]));\n        const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8[1]));\n        const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8[2]));\n        const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8[3]));\n        sumi[0] = _mm256_add_epi32(p1, p3);\n        sumi[1] = _mm256_add_epi32(p2, p4);\n#endif\n    } else {\n#ifdef HAVE_FANCY_SIMD\n        auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]);\n        auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]);\n        auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]);\n        auto p4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[3], q8[3]);\n        sumi[0] = _mm256_dpwssd_epi32(sumi[0], scales[0], _mm256_packs_epi32(p1, p2));\n        sumi[1] = _mm256_dpwssd_epi32(sumi[1], scales[1], _mm256_packs_epi32(p3, p4));\n#else\n        const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8[0]));\n        const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8[1]));\n        const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], 
q8[2]));\n        const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8[3]));\n        sumi[0] = _mm256_add_epi32(sumi[0], _mm256_add_epi32(p1, p3));\n        sumi[1] = _mm256_add_epi32(sumi[1], _mm256_add_epi32(p2, p4));\n#endif\n    }\n}\n// end copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L1578\n\n\n// Copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L7278\n// MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow\ninline void set_scales_8_iq(int j, const __m256i& all_scales, __m256i * scales) {\n    //#ifdef HAVE_FANCY_SIMD\n        auto shuffle = j == 0 ? _mm256_set_epi64x(0x0302030203020302, 0x0100010001000100, 0x0302030203020302, 0x0100010001000100)\n                              : _mm256_set_epi64x(0x0b0a0b0a0b0a0b0a, 0x0908090809080908, 0x0b0a0b0a0b0a0b0a, 0x0908090809080908);\n        scales[0] = _mm256_shuffle_epi8(all_scales, shuffle);\n        scales[1] = _mm256_shuffle_epi8(all_scales, _mm256_add_epi8(shuffle, _mm256_set1_epi8(4)));\n    //#else\n    //    set_scales_8(all_scales, j, scales);\n    //#endif\n    }\n    \ninline void set_scales_16_iq(const __m256i& all_scales, __m256i * scales) {\n    #ifdef HAVE_FANCY_SIMD\n        auto shuffle = _mm256_set_epi64x(0x0706070607060706, 0x0302030203020302, 0x0504050405040504, 0x0100010001000100);\n        scales[0] = _mm256_shuffle_epi8(all_scales, shuffle);\n        scales[1] = _mm256_shuffle_epi8(all_scales, _mm256_add_epi8(shuffle, _mm256_set1_epi8(8)));\n    #else\n        set_scales_16(all_scales, scales);\n    #endif\n    }\n// end copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L7278\n    \n// Copied from 
https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L7299\n// MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow\ntemplate <typename Dequantizer>\nstatic void mul_mat_qX_K_q8_K_IQ_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n        const int nb = n / QK_K;\n        Q8<1> q8(info);\n        Dequantizer deq(vx, bx);\n        __m256i scales[2];\n        __m256i q8_quants[4];\n        for (int ix = 0; ix < nrc_x; ++ix) {\n    \n            __m256 accd = _mm256_setzero_ps();\n            deq.new_row(ix);\n    \n            for (int i = 0; i < nb; ++i) {\n    \n                __m256i sumi[2], all_scales[Dequantizer::num_blocks/8];\n                deq.new_block(i, all_scales);\n    \n                for (int j = 0; j < QK_K/128; ++j) {\n                    deq.prepare(i, j, q8, q8_quants);\n                    if constexpr (Dequantizer::num_blocks == 8) {\n                        set_scales_8_iq(j, all_scales[0], scales);\n                    } else {\n                        set_scales_16_iq(all_scales[j], scales);\n                    }\n                    multiply_add_1(j, deq.bits, scales, q8_quants, sumi);\n                }\n                accd = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(0, i)), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi[0], sumi[1])), accd);\n            }\n    \n            info.store(ix, 0, hsum_float_8(accd));\n        }\n    }\n\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    const int nb = n / QK_K;\n    Q8<nrc_y> q8(info);\n    Dequantizer deq(vx, bx);\n    __m256i scales[4];\n    __m256  accd[nrc_y];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();\n\n        deq.new_row(ix);\n\n        for (int i = 0; i < nb; ++i) {\n\n            __m256i sumi[nrc_y], 
all_scales[Dequantizer::num_blocks/8];\n            //for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = _mm256_setzero_si256();\n            __m256i mins;\n            float dmin = deq.new_block(i, all_scales, mins);\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                auto bsums = q8.load_bsums(iy, i);\n                auto prod  = _mm256_madd_epi16(mins, bsums);\n                accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(dmin*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]);\n            }\n\n            for (int j = 0; j < QK_K/128; ++j) {\n                deq.prepare(i, j);\n                if constexpr (Dequantizer::num_blocks == 8) {\n                    set_scales_8(all_scales[0], j, scales);\n                } else {\n                    set_scales_16(all_scales[j], scales);\n                }\n                //multiply_add_iq(deq.bits, scales, j, i, q8, sumi);\n                multiply_add(deq.bits, scales, j, i, q8, sumi);\n            }\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i));\n                accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);\n            }\n        }\n\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            info.store(ix, iy, hsum_float_8(accd[iy]));\n        }\n    }\n}\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void mul_mat_qX_K_q8_K_IQ(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n#ifdef HAVE_FANCY_SIMD\n    if constexpr (nrc_y == 1) {\n        mul_mat_qX_K_q8_K_IQ_1<Dequantizer>(n, vx, bx, info, nrc_x);\n    } else {\n        mul_mat_qX_K_q8_K_IQ_N<Dequantizer, nrc_y>(n, vx, bx, info, nrc_x);\n    }\n#else\n    mul_mat_qX_K_q8_K_IQ_N<Dequantizer, nrc_y>(n, vx, bx, info, nrc_x);\n#endif\n}\n// end copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L7299\n\n/*\nmoonll 
iq1s\ncore func for iq1s mul_mat_iq1_s_q8_K\n\n*/\n// Copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L3813\n// MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow\ntemplate <int nrc_y>\nstatic void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    GGML_ASSERT(n%QK_K == 0);\n    Q8<nrc_y, block_q8_K> q8(info);\n    __m256i qx[8];\n    __m256i scales[4];\n    __m256  acc[nrc_y] = {};\n    auto delta_mask = _mm_set1_epi16(-32768); // to avoid stupid overflow warnings when using 0x8000\n    __m256i shuffle0 = _mm256_set_epi64x(0x0302030203020302, 0x0100010001000100, 0x0302030203020302, 0x0100010001000100);\n    for (int ix = 0; ix < nrc_x; ++ix) {\n        auto iq1s = (const block_iq1_s *)((const char *)vx + ix*bx);\n        for (int ibl = 0; ibl < n/QK_K; ++ibl) {\n            float d = GGML_FP16_TO_FP32(iq1s[ibl].d);\n            auto qhb = _mm_loadu_si128((const __m128i *)iq1s[ibl].qh);\n            auto scales128 = _mm_and_si128(_mm_srli_epi16(qhb, 12), _mm_set1_epi16(7));\n            scales128 = _mm_add_epi16(_mm_slli_epi16(scales128, 1), _mm_set1_epi16(1));\n#ifdef HAVE_FANCY_SIMD\n            auto mask = _mm_cmpeq_epi16_mask(_mm_and_si128(qhb, delta_mask), delta_mask);\n            auto deltas128 = _mm_mask_blend_epi16(mask, _mm_set1_epi16(-7), _mm_set1_epi16(-9));\n#else\n            auto mask = _mm_cmpeq_epi16(_mm_and_si128(qhb, delta_mask), delta_mask);\n            auto deltas128 = _mm_or_si128(_mm_and_si128(mask, _mm_set1_epi16(-9)), _mm_andnot_si128(mask, _mm_set1_epi16(-7)));\n#endif\n            deltas128 = _mm_mullo_epi16(scales128, deltas128);\n            scales128 = _mm_slli_epi16(scales128, 3);\n            auto deltas_l = _mm_unpacklo_epi16(deltas128, deltas128);\n            auto deltas_h = _mm_unpackhi_epi16(deltas128, deltas128);\n            auto deltas = MM256_SET_M128I(deltas_h, deltas_l); // blocks 0,0, 1,1, 2,2, 
..., 7,7\n            auto all_scales = MM256_SET_M128I(scales128, scales128);\n            auto shuffle = shuffle0;\n            for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {\n                scales[ib64] = _mm256_shuffle_epi8(all_scales, shuffle);\n                shuffle = _mm256_add_epi8(shuffle, _mm256_set1_epi8(4));\n            }\n            const uint8_t  * qs = iq1s[ibl].qs;\n            const uint16_t * qh = iq1s[ibl].qh;\n            for (int ib = 0; ib < QK_K/32; ib += 2) {\n                qx[ib+0] = _mm256_set_epi64x(iq1s_grid_us[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid_us[qs[2] | ((qh[ib+0] << 2) & 0x700)],\n                                             iq1s_grid_us[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid_us[qs[0] | ((qh[ib+0] << 8) & 0x700)]);\n                qx[ib+1] = _mm256_set_epi64x(iq1s_grid_us[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid_us[qs[6] | ((qh[ib+1] << 2) & 0x700)],\n                                             iq1s_grid_us[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid_us[qs[4] | ((qh[ib+1] << 8) & 0x700)]);\n                qs += 8;\n            }\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                auto bsums = q8.load_bsums(iy, ibl);\n                auto sumi = _mm256_setzero_si256();\n                for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {\n                    auto qy1 = q8.load_quants(iy, ibl, 2*ib64+0);\n                    auto qy2 = q8.load_quants(iy, ibl, 2*ib64+1);\n#ifdef HAVE_FANCY_SIMD\n                    auto dot1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2*ib64+0], qy1);\n                    auto dot2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2*ib64+1], qy2);\n                    sumi = _mm256_dpwssd_epi32(sumi, scales[ib64], _mm256_packs_epi32(dot1, dot2));\n#else\n                    auto dot1 = _mm256_maddubs_epi16(qx[2*ib64+0], qy1);\n                    auto dot2 = _mm256_maddubs_epi16(qx[2*ib64+1], qy2);\n                    auto dot  = 
_mm256_add_epi16(_mm256_unpacklo_epi64(dot1, dot2), _mm256_unpackhi_epi64(dot1, dot2));\n                    sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(scales[ib64], dot));\n#endif\n                }\n#ifdef HAVE_FANCY_SIMD\n                sumi = _mm256_dpwssd_epi32(sumi, bsums, deltas);\n#else\n                sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(bsums, deltas));\n#endif\n                acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d*q8.scale(iy, ibl)), _mm256_cvtepi32_ps(sumi), acc[iy]);\n            }\n        }\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            info.store(ix, iy, 0.125f*hsum_float_8(acc[iy]));\n            acc[iy] = _mm256_setzero_ps();\n        }\n    }\n}\n// end copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L3813\n\n/*\nmoonll iq1s\nDequantizerIQ2XXS\nDequantizerIQ2XXS is important Dequantizer for DequantizerIQ1_S\n*/\n\n// Copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L8035\n// MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow\nstruct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {\n    DequantizerIQ2XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n\n    constexpr static int num_blocks = 8;\n\n    union Data {\n        __m256i vec;\n        uint32_t val[8];\n    };\n\n    inline __m128i load_scales(int i) {\n        d = 0.125f * GGML_FP16_TO_FP32(x[i].d);\n        const uint16_t * a16 = (const uint16_t *)x[i].qs;\n        auto scales = _mm_srli_epi16(_mm_set_epi16(a16[31], a16[27], a16[23], a16[19], a16[15], a16[11], a16[7], a16[3]), 12);\n        return _mm_or_si128(_mm_slli_epi16(scales, 1), _mm_set1_epi16(1));\n    }\n\n    inline void new_block(int i, __m256i * scales) {\n        auto sc16 = load_scales(i);\n        scales[0] = MM256_SET_M128I(sc16, sc16);\n    }\n    inline float new_block(int i, __m256i * scales, 
__m256i& mins) {\n        auto sc16 = load_scales(i);\n        mins = scb.shuffle(sc16);\n        scales[0] = MM256_SET_M128I(sc16, sc16);\n        return -d*minv;\n    }\n\n    inline static void make4(const uint32_t * aux32, __m256i * values) {\n        const uint8_t * aux8 = (const uint8_t *)aux32;\n        values[0] = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[ 1]], iq2xxs_grid[aux8[ 0]]);\n        values[1] = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[ 9]], iq2xxs_grid[aux8[ 8]]);\n        values[2] = _mm256_set_epi64x(iq2xxs_grid[aux8[19]], iq2xxs_grid[aux8[18]], iq2xxs_grid[aux8[17]], iq2xxs_grid[aux8[16]]);\n        values[3] = _mm256_set_epi64x(iq2xxs_grid[aux8[27]], iq2xxs_grid[aux8[26]], iq2xxs_grid[aux8[25]], iq2xxs_grid[aux8[24]]);\n    }\n\n    IQK_ALWAYS_INLINE void sign_values(const uint32_t * aux32, __m256i * values) const {\n#ifdef HAVE_FANCY_SIMD\n        esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[3]), _mm_set1_epi32(aux32[1])), values+0);\n        esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[7]), _mm_set1_epi32(aux32[5])), values+2);\n#else\n        esh.sign_value(aux32[1], values[0]);\n        esh.sign_value(aux32[3], values[1]);\n        esh.sign_value(aux32[5], values[2]);\n        esh.sign_value(aux32[7], values[3]);\n#endif\n    }\n    inline void make4_signed(const uint32_t * aux32, const __m256i& min_value, __m256i * values) const {\n        make4(aux32, values);\n        sign_values(aux32, values);\n        for (int k = 0; k < 4; ++k) values[k] = _mm256_add_epi8(values[k], min_value);\n    }\n    inline void make4(const uint32_t * aux32, __m256i * values, __m256i * q8) const {\n        make4(aux32, values);\n        sign_values(aux32, q8);\n    }\n    inline void prepare(int i, int j) {\n        Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j);\n        make4_signed(data.val, min_value, bits.values);\n    }\n    inline 
void prepare(int i, int j, const Q8<1>& q8, __m256i * q8_quants) {\n        for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);\n        Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j);\n        make4(data.val, bits.values, q8_quants);\n    }\n\n    constexpr static int minv = 43;\n    SimpleBits bits;\n    Scales8KBase scb;\n    EvenSignHelper esh;\n    const __m256i min_value = _mm256_set1_epi8(minv);\n    const __m256i shuffle = _mm256_set_epi32(7, 5, 3, 1, 7, 5, 3, 1);\n};\n\n/*\nmoonll\nadd Q8_0_Unpacker && DequantizerIQ2XXS support\nadd func mul_mat_qX_K_q8_K_IQ\n*/\n\n// Copied/adapted from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L9092\n// MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow\ntemplate <typename Dequantizer> void MulMat::set_functions(MulMat& m) {\n    if constexpr (std::is_same_v<Dequantizer, Q4_0_Unpacker> || std::is_same_v<Dequantizer, Q5_0_Unpacker> ||\n        std::is_same_v<Dequantizer, Q8_0_Unpacker>) {\n            m.funcs[0] = mul_mat_qX_0_q8_0_T<Dequantizer, 1>;\n            m.funcs[1] = mul_mat_qX_0_q8_0_T<Dequantizer, 2>;\n            m.funcs[2] = mul_mat_qX_0_q8_0_T<Dequantizer, 3>;\n            m.funcs[3] = mul_mat_qX_0_q8_0_T<Dequantizer, 4>;\n            m.funcs[4] = mul_mat_qX_0_q8_0_T<Dequantizer, 5>;\n            m.funcs[5] = mul_mat_qX_0_q8_0_T<Dequantizer, 6>;\n            m.funcs[6] = mul_mat_qX_0_q8_0_T<Dequantizer, 7>;\n            m.funcs[7] = mul_mat_qX_0_q8_0_T<Dequantizer, 8>;\n        }\n        else if constexpr (std::is_same_v<Dequantizer, Q4_1_Unpacker> || std::is_same_v<Dequantizer, Q5_1_Unpacker>|| std::is_same_v<Dequantizer, Q8_0_1_Unpacker>) {\n            m.funcs[0] = mul_mat_qX_1_q8_1_T<Dequantizer, 1>;\n            m.funcs[1] = mul_mat_qX_1_q8_1_T<Dequantizer, 2>;\n            m.funcs[2] = mul_mat_qX_1_q8_1_T<Dequantizer, 3>;\n            m.funcs[3] = mul_mat_qX_1_q8_1_T<Dequantizer, 
4>;\n            m.funcs[4] = mul_mat_qX_1_q8_1_T<Dequantizer, 5>;\n            m.funcs[5] = mul_mat_qX_1_q8_1_T<Dequantizer, 6>;\n            m.funcs[6] = mul_mat_qX_1_q8_1_T<Dequantizer, 7>;\n            m.funcs[7] = mul_mat_qX_1_q8_1_T<Dequantizer, 8>;\n        }\n        else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2XXS>) {\n            m.funcs[0] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 1>;\n            m.funcs[1] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 2>;\n            m.funcs[2] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 3>;\n            m.funcs[3] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 4>;\n            m.funcs[4] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 5>;\n            m.funcs[5] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 6>;\n            m.funcs[6] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 7>;\n            m.funcs[7] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 8>;\n            }\n            else {\n#ifdef HAVE_FANCY_SIMD\n            if constexpr (std::is_same_v<Dequantizer, DequantizerIQ4XS>) {\n            m.funcs[0] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 1>;\n            m.funcs[1] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 2>;\n            m.funcs[2] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 3>;\n            m.funcs[3] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 4>;\n            m.funcs[4] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 5>;\n            m.funcs[5] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 6>;\n            m.funcs[6] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 7>;\n            m.funcs[7] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 8>;\n            } else {\n            m.funcs[0] = mul_mat_qX_K_q8_K_AVX512_1<Dequantizer>;\n            m.funcs[1] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 2>;\n            m.funcs[2] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 3>;\n            m.funcs[3] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 4>;\n            m.funcs[4] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 5>;\n            m.funcs[5] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 6>;\n            m.funcs[6] = 
mul_mat_qX_K_q8_K_AVX512<Dequantizer, 7>;\n            m.funcs[7] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 8>;\n            }\n#else\n            if constexpr (std::is_same_v<Dequantizer, DequantizerQ2K> ||\n                          std::is_same_v<Dequantizer, DequantizerQ3K> ||\n                          std::is_same_v<Dequantizer, DequantizerQ6K>) {\n                m.funcs[0] = mul_mat_qY_K_q8_K_T<Dequantizer, 1>;\n                m.funcs[1] = mul_mat_qY_K_q8_K_T<Dequantizer, 2>;\n                m.funcs[2] = mul_mat_qY_K_q8_K_T<Dequantizer, 3>;\n                m.funcs[3] = mul_mat_qY_K_q8_K_T<Dequantizer, 4>;\n                m.funcs[4] = mul_mat_qY_K_q8_K_T<Dequantizer, 5>;\n                m.funcs[5] = mul_mat_qY_K_q8_K_T<Dequantizer, 6>;\n                m.funcs[6] = mul_mat_qY_K_q8_K_T<Dequantizer, 7>;\n                m.funcs[7] = mul_mat_qY_K_q8_K_T<Dequantizer, 8>;\n            } else {\n                m.funcs[0] = mul_mat_qX_K_q8_K_T<Dequantizer, 1>;\n                m.funcs[1] = mul_mat_qX_K_q8_K_T<Dequantizer, 2>;\n                m.funcs[2] = mul_mat_qX_K_q8_K_T<Dequantizer, 3>;\n                m.funcs[3] = mul_mat_qX_K_q8_K_T<Dequantizer, 4>;\n                m.funcs[4] = mul_mat_qX_K_q8_K_T<Dequantizer, 5>;\n                m.funcs[5] = mul_mat_qX_K_q8_K_T<Dequantizer, 6>;\n                m.funcs[6] = mul_mat_qX_K_q8_K_T<Dequantizer, 7>;\n                m.funcs[7] = mul_mat_qX_K_q8_K_T<Dequantizer, 8>;\n            }\n#endif\n        }\n}\n// end copied/adapted from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L9092\n\n// Copied from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L8622\n// MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow\nstruct QFBase {\n    #ifdef __AVX512F__\n        constexpr static int k_step = 16;\n        using Data = __m512;\n        using Acc  = __m512;\n        static inline 
Data load(const ggml_half * x) { return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)x)); }\n        static inline Data load(const float * x) { return _mm512_loadu_ps(x); }\n        static inline Data load(const ggml_bf16_t * x) {\n            return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i*)x)), 16));\n        }\n        static inline Acc acc(Acc prev, const Data& y, const Data& x) {\n            return _mm512_fmadd_ps(y, x, prev);\n        }\n        static inline Acc acc_first(const Data& y, const Data& x) {\n            return _mm512_mul_ps(y, x);\n        }\n        static inline Acc add(Acc x, Acc y) { return _mm512_add_ps(x, y); }\n        static inline float hsum(Acc acc) {\n            return _mm512_reduce_add_ps(acc);\n        }\n        template <typename Float>\n        static inline Data load4Floats(const Float * x) {\n            return _mm512_insertf32x4(_mm512_setzero_ps(), load128(x), 0);\n        }\n        static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) {\n            acc = _mm512_fmadd_ps(xv[0], _mm512_shuffle_ps(yv, yv, 0x00), acc);\n            acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc);\n            acc = _mm512_fmadd_ps(xv[2], _mm512_shuffle_ps(yv, yv, 0xaa), acc);\n            acc = _mm512_fmadd_ps(xv[3], _mm512_shuffle_ps(yv, yv, 0xff), acc);\n            return acc;\n        }\n        static inline Acc acc_r4_first(const Data * xv, const Data& yv) {\n            auto acc = _mm512_mul_ps(xv[0], _mm512_shuffle_ps(yv, yv, 0x00));\n            acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc);\n            acc = _mm512_fmadd_ps(xv[2], _mm512_shuffle_ps(yv, yv, 0xaa), acc);\n            acc = _mm512_fmadd_ps(xv[3], _mm512_shuffle_ps(yv, yv, 0xff), acc);\n            return acc;\n        }\n        static inline __m128 hsum_r4(Acc acc) {\n            auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc, 0), 
_mm512_extractf32x4_ps(acc, 1));\n            auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc, 2), _mm512_extractf32x4_ps(acc, 3));\n            return _mm_add_ps(sum1, sum2);\n        }\n    #else\n        constexpr static int k_step = 8;\n        using Data = __m256;\n        using Acc  = __m256;\n        static inline Data load(const ggml_half * x) { return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)x)); }\n        static inline Data load(const float * x) { return _mm256_loadu_ps(x); }\n        static inline Data load(const ggml_bf16_t * x) {\n            return _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i*)x)), 16));\n        }\n        static inline Acc acc(Acc prev, const Data& y, const Data& x) {\n            return _mm256_fmadd_ps(y, x, prev);\n        }\n        static inline Acc add(Acc x, Acc y) { return _mm256_add_ps(x, y); }\n        static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) {\n            acc = _mm256_fmadd_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00), acc);\n            acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc);\n            acc = _mm256_fmadd_ps(xv[2], _mm256_shuffle_ps(yv, yv, 0xaa), acc);\n            acc = _mm256_fmadd_ps(xv[3], _mm256_shuffle_ps(yv, yv, 0xff), acc);\n            return acc;\n        }\n        static inline Acc acc_r4_first(const Data * xv, const Data& yv) {\n            auto acc = _mm256_mul_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00));\n            acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc);\n            acc = _mm256_fmadd_ps(xv[2], _mm256_shuffle_ps(yv, yv, 0xaa), acc);\n            acc = _mm256_fmadd_ps(xv[3], _mm256_shuffle_ps(yv, yv, 0xff), acc);\n            return acc;\n        }\n        static inline Acc acc_first(const Data& y, const Data& x) {\n            return _mm256_mul_ps(y, x);\n        }\n        static inline float hsum(Acc acc) {\n            return hsum_float_8(acc);\n        }\n        
static inline __m128 hsum_r4(Acc acc) {\n            return _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));\n        }\n        template <typename Float>\n        static inline Data load4Floats(const Float * x) {\n            return _mm256_insertf128_ps(_mm256_setzero_ps(), load128(x), 0);\n        }\n    #endif\n        static inline __m128 load128(const ggml_half * x) { return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)x)); }\n        static inline __m128 load128(const float * x) { return _mm_loadu_ps(x); }\n        static inline __m128 load128(const ggml_bf16_t * x) {\n            return _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i*)x)), 16));\n        }\n    };\n    template <typename Float, int nrc_in> struct QFT final : public QFBase {\n        constexpr static int nrc = nrc_in;\n        QFT(const DataInfo& info) {\n            for (int iy = 0; iy < nrc; ++iy) y[iy] = (const Float *)info.src1_row(iy);\n        }\n        QFT(const char * cx, size_t bx) {\n            for (int iy = 0; iy < nrc; ++iy) y[iy] = (const Float *)(cx + iy*bx);\n        }\n        IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }\n        IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load4Floats(y[iy] + 4*i); }\n        IQK_ALWAYS_INLINE void load_r4(int ix, int i, Data * xv) const {\n            xv[0] = load1(ix+0, i);\n            xv[1] = load1(ix+1, i);\n            xv[2] = load1(ix+2, i);\n            xv[3] = load1(ix+3, i);\n    #ifdef __AVX512F__\n            auto t0 = _mm512_unpacklo_ps(xv[0], xv[1]);\n            auto t1 = _mm512_unpacklo_ps(xv[2], xv[3]);\n            auto t2 = _mm512_unpackhi_ps(xv[0], xv[1]);\n            auto t3 = _mm512_unpackhi_ps(xv[2], xv[3]);\n            xv[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(t0), _mm512_castps_pd(t1)));\n            xv[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(t0), 
_mm512_castps_pd(t1)));\n            xv[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(t2), _mm512_castps_pd(t3)));\n            xv[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(t2), _mm512_castps_pd(t3)));\n    #else\n            auto t0 = _mm256_unpacklo_ps(xv[0], xv[1]);\n            auto t1 = _mm256_unpacklo_ps(xv[2], xv[3]);\n            auto t2 = _mm256_unpackhi_ps(xv[0], xv[1]);\n            auto t3 = _mm256_unpackhi_ps(xv[2], xv[3]);\n            xv[0] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));\n            xv[1] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));\n            xv[2] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));\n            xv[3] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));\n    #endif\n        }\n        const Float * y[nrc];\n    };\n    \n\n\ntemplate <typename Qy, typename Qx>\nIQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {\n    int nb = n/QFBase::k_step;\n    int nb4 = n/4;\n    Qy y(info);\n    Qx x(cx + ix0*bx, bx);\n    QFBase::Data xv[Qx::nrc];\n    QFBase::Acc  acc[Qx::nrc*Qy::nrc];\n    auto yv = y.load1(0, 0);\n    for (int ix = 0; ix < Qx::nrc; ++ix) {\n        xv[ix] = x.load1(ix, 0);\n        acc[ix] = QFBase::acc_first(yv, xv[ix]);\n    }\n    for (int iy = 1; iy < Qy::nrc; ++iy) {\n        yv = y.load1(iy, 0);\n        for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc_first(yv, xv[ix]);\n    }\n    for (int i = 1; i < nb; ++i) {\n        yv = y.load1(0, i);\n        for (int ix = 0; ix < Qx::nrc; ++ix) {\n            xv[ix] = x.load1(ix, i);\n            acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]);\n        }\n        for (int iy = 1; iy < Qy::nrc; ++iy) {\n            yv = y.load1(iy, i);\n            for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = 
QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]);\n        }\n    }\n    for (int i = (QFBase::k_step/4)*nb; i < nb4; ++i) {\n        yv = y.load_tail(0, i);\n        for (int ix = 0; ix < Qx::nrc; ++ix) {\n            xv[ix] = x.load_tail(ix, i);\n            acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]);\n        }\n        for (int iy = 1; iy < Qy::nrc; ++iy) {\n            yv = y.load_tail(iy, i);\n            for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]);\n        }\n    }\n    for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix]));\n}\n// This will handle any of f16 x f32, f32 x f16, f16 x f16, f32 x f32, with computations done\n// in f32 (i.e., f16 is first converted to f32). It is easy to extend to computations done in\n// f16, but I don't have a CPU capable of f16 vector arithmetic, so not doing it for now.\ntemplate <int nrc_y, typename FloatX, typename FloatY>\nvoid mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    const char * cx = (const char *)vx;\n    // TBD if we want this\n    //if constexpr (nrc_y == 1) {\n    //    constexpr int k_nx = 2;\n    //    for (int ix = 0; ix < nrc_x/k_nx; ++ix) {\n    //        mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, k_nx>>(n, cx, bx, ix*k_nx, info);\n    //    }\n    //    if (int lastx = k_nx*(nrc_x/k_nx); lastx < nrc_x) {\n    //        int nx = nrc_x - lastx;\n    //        switch (nx) {\n    //            case 1: mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, lastx, info); break;\n    //            case 2: mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, lastx, info); break;\n    //            case 3: mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, lastx, info); break;\n    //        }\n    //        //mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, lastx, info);\n    
//    }\n    //    return;\n    //}\n#ifdef __AVX512F__\n    constexpr int k_nx = 5;\n#else\n    constexpr int k_nx = nrc_y == 1 ? 4 : 2;\n#endif\n    for (int ix = 0; ix < nrc_x/k_nx; ++ix) {\n        mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, k_nx>>(n, cx, bx, ix*k_nx, info);\n    }\n    int last_x = k_nx*(nrc_x/k_nx);\n    if (last_x == nrc_x) return;\n    int nx = nrc_x - last_x;\n#ifdef __AVX512F__\n    switch (nx) {\n        case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;\n        case 2: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, last_x, info); break;\n        case 3: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, last_x, info); break;\n        case 4: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 4>>(n, cx, bx, last_x, info); break;\n    }\n#else\n    if constexpr (nrc_y == 1) {\n        switch (nx) {\n            case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;\n            case 2: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, last_x, info); break;\n            case 3: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, last_x, info); break;\n        }\n    } else {\n        switch (nx) {\n            case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;\n        }\n    }\n#endif\n}\n\ntemplate <typename FloatX, typename FloatY>\nvoid set_mul_mat_f(MulMat& mm) {\n    for (auto& f : mm.funcs) f = nullptr;\n    mm.funcs[0] = mul_mat_fX_fY_T<1, FloatX, FloatY>;\n    mm.funcs[1] = mul_mat_fX_fY_T<2, FloatX, FloatY>;\n    mm.funcs[2] = mul_mat_fX_fY_T<3, FloatX, FloatY>;\n    mm.funcs[3] = mul_mat_fX_fY_T<4, FloatX, FloatY>;\n    mm.funcs[4] = mul_mat_fX_fY_T<5, FloatX, FloatY>;\n#ifndef __AVX512F__\n    mm.funcs[5] = mul_mat_fX_fY_T<6, FloatX, FloatY>;\n#endif\n}\n// end copied from 
https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L8622\n\n/*\nmoonll\nadd typeB; set_mul_mat returns false when typeB is not the type expected for the weight matrix\nadd GGML_TYPE_IQ2_XXS\nadd GGML_TYPE_IQ1_S\nadd GGML_TYPE_IQ4_XS\n*/\n\n// Modifications extracted from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L9231\n// MIT licensed, Copyright (c) 2024-2025 Iwan Kawrakow\nbool MulMat::set_mul_mat(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {\n    (void)Ny;\n\n        auto expected_typeB = GGML_TYPE_Q8_K;\n    switch (typeA) {\n        case GGML_TYPE_Q2_K:\n            assert (ne00 % QK_K == 0);\n            MulMat::set_functions<DequantizerQ2K>(mm);\n            break;\n        case GGML_TYPE_Q3_K:\n            assert (ne00 % QK_K == 0);\n            MulMat::set_functions<DequantizerQ3K>(mm);\n            break;\n        case GGML_TYPE_Q4_K:\n            assert (ne00 % QK_K == 0);\n            MulMat::set_functions<DequantizerQ4K>(mm);\n            break;\n        case GGML_TYPE_Q5_K:\n            assert (ne00 % QK_K == 0);\n            MulMat::set_functions<DequantizerQ5K>(mm);\n            break;\n        case GGML_TYPE_Q6_K:\n            assert (ne00 % QK_K == 0);\n            MulMat::set_functions<DequantizerQ6K>(mm);\n            break;\n        case GGML_TYPE_IQ4_XS:\n            assert (ne00 % QK_K == 0);\n            MulMat::set_functions<DequantizerIQ4XS>(mm);\n            break;\n        case GGML_TYPE_IQ2_XXS:\n            assert (ne00 % QK_K == 0);\n            MulMat::set_functions<DequantizerIQ2XXS>(mm);\n            break;\n        case GGML_TYPE_Q4_0:\n            assert (ne00 % QK4_0 == 0);\n            MulMat::set_functions<Q4_0_Unpacker>(mm);\n            expected_typeB = GGML_TYPE_Q8_0;\n            break;\n        case GGML_TYPE_Q4_1:\n            assert (ne00 % QK4_1 == 0);\n            MulMat::set_functions<Q4_1_Unpacker>(mm);\n            
expected_typeB = GGML_TYPE_Q8_1_X4;\n            break;\n        case GGML_TYPE_Q5_0:\n            assert (ne00 % QK5_0 == 0);\n            MulMat::set_functions<Q5_0_Unpacker>(mm);\n            expected_typeB = GGML_TYPE_Q8_0;\n            break;\n        case GGML_TYPE_Q5_1:\n            assert (ne00 % QK5_1 == 0);\n            MulMat::set_functions<Q5_1_Unpacker>(mm);\n            expected_typeB = GGML_TYPE_Q8_1_X4;\n            break;\n        case GGML_TYPE_Q8_0:\n            assert (ne00 % QK8_0 == 0);\n#ifdef HAVE_FANCY_SIMD\n            MulMat::set_functions<Q8_0_1_Unpacker>(mm);\n            expected_typeB = GGML_TYPE_Q8_1_X4;\n#else\n            MulMat::set_functions<Q8_0_Unpacker>(mm);\n            expected_typeB = GGML_TYPE_Q8_0_X4;\n#endif\n            break;\n        case GGML_TYPE_IQ1_S:\n            mm.funcs[0] = mul_mat_iq1_s_q8_K<1>;\n            mm.funcs[1] = mul_mat_iq1_s_q8_K<2>;\n            mm.funcs[2] = mul_mat_iq1_s_q8_K<3>;\n            mm.funcs[3] = mul_mat_iq1_s_q8_K<4>;\n            mm.funcs[4] = mul_mat_iq1_s_q8_K<5>;\n            mm.funcs[5] = mul_mat_iq1_s_q8_K<6>;\n            mm.funcs[6] = mul_mat_iq1_s_q8_K<7>;\n            mm.funcs[7] = mul_mat_iq1_s_q8_K<8>;\n        #ifdef HAVE_FANCY_SIMD\n             mm.func16 = mul_mat_iq1_s_q8_K<16>;\n        #endif\n       // row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00);\n              expected_typeB = GGML_TYPE_Q8_K;\n            break;\n\n        default:\n        {\n            // printf(\"case:%d\",typeA);\n            return false;\n        }\n            \n    }\n\n\n\n    return ggml_type(typeB) == expected_typeB;\n\n}\n// end extracted from https://github.com/ikawrakow/ik_llama.cpp/blob/474435f58b6a26bc549589966482207fee94aa60/ggml/src/iqk/iqk_mul_mat.cpp#L9231\n\n} // namespace\n\n/*\niq1_s is not supported on ARM\n*/\n#else   // __aarch64__\n\nnamespace {\n\ntemplate <int nrc, typename block_q8 = block_q8_K> struct Q8 {\n\n    constexpr static int nrc_y = nrc;\n\n    Q8(const 
DataInfo& info) {\n        for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy);\n    }\n\n    inline int8x16_t load_quants_16(int iy, int i, int j) const { return vld1q_s8(y[iy][i].qs + 16*j); }\n    inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy][i].qs + 32*j); }\n    inline int8x16x4_t load_quants_64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy][i].qs + 64*j); }\n    inline int16x8x2_t load_bsums(int iy, int i) const { return vld1q_s16_x2(y[iy][i].bsums); }\n    inline int16x8_t load_bsums8(int iy, int i) const {\n        auto q8s = vld1q_s16_x2(y[iy][i].bsums);\n        return vpaddq_s16(q8s.val[0], q8s.val[1]);\n    }\n    inline float scale(int iy, int i) const { return y[iy][i].d; }\n\n    const block_q8 * y[nrc_y];\n};\n\ntemplate <int nrc_y, typename Dequantizer>\nIQK_NOINLINE void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n    const int nb = n / QK_K;\n\n    Q8<nrc_y, block_q8_K> q8(info);\n\n    Dequantizer deq(vx, bx, nrc_y);\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        deq.new_row(ix);\n\n        float32x4_t acc[nrc_y];\n        for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);\n\n//#pragma GCC unroll 4\n        for (int i = 0; i < nb; ++i) {\n\n            int32x4_t sumi[nrc_y];\n            for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0);\n\n            if constexpr (nrc_y > 1 && Dequantizer::should_scale_quants()) {\n                deq.process_scales(i, q8, acc);\n                deq.prepare(i, 0);\n                deq.compute(q8, i, 0, sumi);\n                deq.prepare(i, 1);\n                deq.compute(q8, i, 1, sumi);\n            } else {\n                if constexpr (Dequantizer::num_blocks() == 8) {\n                    auto scales = deq.new_block(i, q8, acc);\n                    deq.prepare(i, 0);\n#pragma GCC unroll 8\n                    for (int iy = 
0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);\n                    deq.prepare(i, 1);\n#pragma GCC unroll 8\n                    for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);\n                }\n                else if constexpr (Dequantizer::num_blocks() == 16) {\n                    auto scales = deq.new_block(i, q8, acc);\n                    deq.prepare(i, 0);\n#pragma GCC unroll 8\n                    for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);\n                    deq.prepare(i, 1);\n#pragma GCC unroll 8\n                    for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);\n                }\n                else {\n                    GGML_ASSERT(false);\n                }\n            }\n\n#pragma GCC unroll 8\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i)));\n            }\n        }\n\n#pragma GCC unroll 8\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            info.store(ix, iy, vaddvq_f32(acc[iy]));\n        }\n    }\n}\ntemplate <int nrc_y, typename Dequantizer>\nIQK_NOINLINE void mul_mat_qX_K_q8_K_IQ(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n    const int nb = n / QK_K;\n\n    Q8<nrc_y, block_q8_K> q8(info);\n\n    Dequantizer deq(vx, bx, nrc_y);\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        deq.new_row(ix);\n\n        float32x4_t acc[nrc_y];\n        for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);\n\n        for (int i = 0; i < nb; ++i) {\n\n            int32x4_t sumi[nrc_y];\n            for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0);\n\n            if constexpr (Dequantizer::num_blocks() == 8) {\n        
        auto scales = deq.new_block(i);\n                deq.prepare(i, 0);\n#pragma GCC unroll 8\n                for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);\n                deq.prepare(i, 1);\n#pragma GCC unroll 8\n                for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);\n            }\n            else if constexpr (Dequantizer::num_blocks() == 16) {\n                auto scales = deq.new_block(i);\n                deq.prepare(i, 0);\n#pragma GCC unroll 8\n                for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);\n                deq.prepare(i, 1);\n#pragma GCC unroll 8\n                for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);\n            }\n            else {\n                GGML_ASSERT(false);\n            }\n#pragma GCC unroll 8\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i)));\n            }\n        }\n#pragma GCC unroll 8\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            info.store(ix, iy, vaddvq_f32(acc[iy]));\n        }\n    }\n}\n\ntemplate <typename Q8>\nIQK_ALWAYS_INLINE void compute_8_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8,\n        const int32x4x2_t& scales, int iy, int i, int j, int32x4_t& sumi) {\n    auto mzero = vdupq_n_s32(0);\n    const int8x16_t * qs_1 = (const int8x16_t *)qx_1.val;\n    const int8x16_t * qs_2 = (const int8x16_t *)qx_2.val;\n\n    auto q8b_1 = q8.load_quants(iy, i, 4*j+0);\n    auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qs_1[0], q8b_1.val[0]), qs_1[1], q8b_1.val[1]); // block 1\n    auto q8b_2 = q8.load_quants(iy, i, 4*j+1);\n    auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qs_1[2], q8b_2.val[0]), 
qs_1[3], q8b_2.val[1]); // block 2\n    auto p12 = vpaddq_s32(p1, p2);\n\n    auto q8b_3 = q8.load_quants(iy, i, 4*j+2);\n    auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qs_2[0], q8b_3.val[0]), qs_2[1], q8b_3.val[1]); // block 3\n    auto q8b_4 = q8.load_quants(iy, i, 4*j+3);\n    auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qs_2[2], q8b_4.val[0]), qs_2[3], q8b_4.val[1]); // block 4\n    auto p34 = vpaddq_s32(p3, p4);\n\n    auto pall = vpaddq_s32(p12, p34);\n    sumi = vmlaq_s32(sumi, scales.val[j], pall);\n}\ntemplate <typename Q8>\nIQK_ALWAYS_INLINE void compute_8_blocks(const int8x16_t * qx, const Q8& q8,\n        const int32x4_t& scales, int iy, int i, int j, int32x4_t& sumi) {\n    auto mzero = vdupq_n_s32(0);\n\n    auto q8b_1 = q8.load_quants(iy, i, 4*j+0);\n    auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qx[0], q8b_1.val[0]), qx[1], q8b_1.val[1]); // block 1\n    auto q8b_2 = q8.load_quants(iy, i, 4*j+1);\n    auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qx[2], q8b_2.val[0]), qx[3], q8b_2.val[1]); // block 2\n    auto p12 = vpaddq_s32(p1, p2);\n\n    auto q8b_3 = q8.load_quants(iy, i, 4*j+2);\n    auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qx[4], q8b_3.val[0]), qx[5], q8b_3.val[1]); // block 3\n    auto q8b_4 = q8.load_quants(iy, i, 4*j+3);\n    auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qx[6], q8b_4.val[0]), qx[7], q8b_4.val[1]); // block 4\n    auto p34 = vpaddq_s32(p3, p4);\n\n    auto pall = vpaddq_s32(p12, p34);\n    sumi = vmlaq_s32(sumi, scales, pall);\n}\n\ntemplate <typename Q8>\nIQK_ALWAYS_INLINE void compute_16_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8,\n        const int32x4x4_t& scales, int iy, int i, int j, int32x4_t& sumi) {\n\n    auto mzero = vdupq_n_s32(0);\n    auto q8b_1 = q8.load_quants(iy, i, 4*j+0);\n    auto p1 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]),\n                         ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[1]), 
q8b_1.val[1])); // blocks 0, 0, 1, 1,\n    auto q8b_2 = q8.load_quants(iy, i, 4*j+1);\n    auto p2 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]),\n                         ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1])); // blocks 2, 2, 3, 3,\n    auto p12 = vpaddq_s32(p1, p2); // blocks 0, 1, 2, 3\n    sumi = vmlaq_s32(sumi, scales.val[2*j+0], p12);\n\n    auto q8b_3 = q8.load_quants(iy, i, 4*j+2);\n    auto p3 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]),\n                         ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1])); // blocks 4, 4, 5, 5,\n    auto q8b_4 = q8.load_quants(iy, i, 4*j+3);\n    auto p4 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]),\n                         ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1])); // blocks 6, 6, 7, 7,\n    auto p34 = vpaddq_s32(p3, p4); // blocks 4, 5, 6, 7\n    sumi = vmlaq_s32(sumi, scales.val[2*j+1], p34);\n}\n\ntemplate <typename Q8>\ninline void accum_mins_8(const int16x8_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) {\n    for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n        auto q8s = q8.load_bsums8(iy, i);\n        int32x4_t b1 = vmull_s16(vget_low_s16(mins), vget_low_s16(q8s));\n        int32x4_t b2 = vmull_s16(vget_high_s16(mins), vget_high_s16(q8s));\n        float32x4_t prod = vcvtq_f32_s32(vaddq_s32(b1, b2));\n        acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i)));\n    }\n}\ntemplate <typename Q8>\ninline void accum_mins_16(const int16x8x2_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) {\n    for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n        auto q8s = q8.load_bsums(iy, i);\n        int32x4_t b1 = vmull_s16(vget_low_s16 (mins.val[0]), vget_low_s16 (q8s.val[0]));\n        int32x4_t b2 = vmull_s16(vget_high_s16(mins.val[0]), vget_high_s16(q8s.val[0]));\n        int32x4_t b3 = 
vmull_s16(vget_low_s16 (mins.val[1]), vget_low_s16 (q8s.val[1]));\n        int32x4_t b4 = vmull_s16(vget_high_s16(mins.val[1]), vget_high_s16(q8s.val[1]));\n        float32x4_t prod = vcvtq_f32_s32(vaddq_s32(vaddq_s32(b1, b2), vaddq_s32(b3, b4)));\n        acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i)));\n    }\n}\n\nstruct Scales8 {\n    uint32_t utmp[4];\n    const uint8_t * sc8 = (const uint8_t *)utmp;\n    template <typename Q8, typename Qx>\n    inline int32x4x2_t process_scales_mins(const Qx& x, const Q8& q8, int i, float32x4_t * acc) {\n        make_q4_scales(x.scales, utmp);\n        int16x8_t mins = vmovl_s8(vld1_s8((const int8_t *)sc8 + 8));\n        accum_mins_8(mins, q8, acc, i, -GGML_FP16_TO_FP32(x.dmin));\n\n        uint8x8_t scales8 = vld1_u8(sc8);\n        uint16x8_t scales16 = vmovl_u8(scales8);\n        int32x4x2_t scales = {vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales16))),\n                              vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales16)))};\n        return scales;\n    }\n};\n\nstruct Q4bits {\n    const uint8x16_t m4b = vdupq_n_u8(0xf);\n    uint8x16x4_t b1, b2;\n    inline void prepare4(uint8x16x4_t& b, const uint8x16_t * val) const {\n        b.val[0] = vandq_u8(val[0], m4b);\n        b.val[2] = vshrq_n_u8(val[0], 4);\n        b.val[1] = vandq_u8(val[1], m4b);\n        b.val[3] = vshrq_n_u8(val[1], 4);\n    }\n    inline void prepare4_16(uint8x16x4_t& b, const uint8x16_t * val) const {\n        b.val[0] = vandq_u8(val[0], m4b);\n        b.val[1] = vshrq_n_u8(val[0], 4);\n        b.val[2] = vandq_u8(val[1], m4b);\n        b.val[3] = vshrq_n_u8(val[1], 4);\n    }\n    inline void prepare(const uint8_t * qs) {\n        auto q4bits = vld1q_u8_x2(qs);\n        prepare4(b1, q4bits.val);\n        q4bits = vld1q_u8_x2(qs+32);\n        prepare4(b2, q4bits.val);\n    }\n    inline void prepare_v2(const uint8_t * qs) {\n        auto q4bits = vld1q_u8_x4(qs);\n        prepare4(b1, q4bits.val+0);\n        
prepare4(b2, q4bits.val+2);\n    }\n    inline void prepare64(const uint8_t * qs) {\n        auto q4bits = vld1q_u8_x4(qs);\n        b1.val[0] = vandq_u8(q4bits.val[0], m4b);\n        b1.val[1] = vandq_u8(q4bits.val[1], m4b);\n        b1.val[2] = vandq_u8(q4bits.val[2], m4b);\n        b1.val[3] = vandq_u8(q4bits.val[3], m4b);\n        b2.val[0] = vshrq_n_u8(q4bits.val[0], 4);\n        b2.val[1] = vshrq_n_u8(q4bits.val[1], 4);\n        b2.val[2] = vshrq_n_u8(q4bits.val[2], 4);\n        b2.val[3] = vshrq_n_u8(q4bits.val[3], 4);\n    }\n    inline void prepare16(const uint8_t * qs) {\n        auto q4bits = vld1q_u8_x2(qs);\n        prepare4_16(b1, q4bits.val);\n        q4bits = vld1q_u8_x2(qs+32);\n        prepare4_16(b2, q4bits.val);\n    }\n    inline void prepare16_v2(const uint8_t * qs) {\n        auto q4bits = vld1q_u8_x4(qs);\n        prepare4_16(b1, q4bits.val+0);\n        prepare4_16(b2, q4bits.val+2);\n    }\n};\n\nstruct Q2bits {\n    const uint8x16_t m4b = vdupq_n_u8(0x03);\n    uint8x16x4_t b1, b2;\n    inline void prepare(const uint8_t * qs) {\n        auto q2bits = vld1q_u8_x2(qs);\n        b1.val[0] = vandq_u8(q2bits.val[0], m4b);\n        b1.val[1] = vandq_u8(q2bits.val[1], m4b);\n\n        q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);\n        q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);\n        b1.val[2] = vandq_u8(q2bits.val[0], m4b);\n        b1.val[3] = vandq_u8(q2bits.val[1], m4b);\n\n        q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);\n        q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);\n        b2.val[0] = vandq_u8(q2bits.val[0], m4b);\n        b2.val[1] = vandq_u8(q2bits.val[1], m4b);\n\n        q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);\n        q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);\n        b2.val[2] = vandq_u8(q2bits.val[0], m4b);\n        b2.val[3] = vandq_u8(q2bits.val[1], m4b);\n    }\n};\n\ntemplate <typename block_q>\nstruct BaseDequantizer {\n    BaseDequantizer(const void * vx, size_t bx, int nrc) : vx(vx), 
x(nullptr), bx(bx), nrc(nrc) {}\n    inline void new_row(int ix) { x = (const block_q *)((const char *)vx + ix*bx); }\n    const void * vx;\n    const block_q * x;\n    const size_t bx;\n    const int nrc;\n};\n\nstruct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {\n    DequantizerQ4K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 8; }\n    constexpr static bool should_scale_quants() { return false; }\n\n    template <typename Q8>\n    inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        return s8.process_scales_mins(x[i], q8, i, acc);\n    }\n    inline void prepare(int i, int j) {\n        if (nrc == 1) bits.prepare_v2(x[i].qs+64*j);\n        else bits.prepare(x[i].qs+64*j);\n    }\n\n    Q4bits bits;\n    Scales8 s8;\n\n    float d;\n};\n\nstruct HighBit5 {\n    const uint8x16_t mhb = vdupq_n_u8(0x10);\n    uint8x16x2_t bits;\n    inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) {\n        b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 4), mhb));\n        b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 4), mhb));\n        b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 3), mhb));\n        b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 3), mhb));\n\n        b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb));\n        b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb));\n        b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb));\n        b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb));\n\n        if (do_shift) {\n            bits.val[0] = vshrq_n_u8(bits.val[0], 4);\n            bits.val[1] = vshrq_n_u8(bits.val[1], 4);\n        }\n    }\n};\n\nstruct HighBit3 {\n    const uint8x16_t mhb = vdupq_n_u8(0x04);\n    uint8x16x2_t 
bits;\n    inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) {\n        b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb));\n        b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb));\n        b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb));\n        b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb));\n\n        b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(bits.val[0], mhb));\n        b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(bits.val[1], mhb));\n        b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshrq_n_u8(bits.val[0], 1), mhb));\n        b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshrq_n_u8(bits.val[1], 1), mhb));\n\n        if (do_shift) {\n            bits.val[0] = vshrq_n_u8(bits.val[0], 4);\n            bits.val[1] = vshrq_n_u8(bits.val[1], 4);\n        }\n    }\n};\n\nstruct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {\n    DequantizerQ5K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 8; }\n    constexpr static bool should_scale_quants() { return false; }\n\n    template <typename Q8>\n    inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        h.bits = vld1q_u8_x2(x[i].qh);\n        return s8.process_scales_mins(x[i], q8, i, acc);\n    }\n    inline void prepare(int i, int j) {\n        bits.prepare(x[i].qs+64*j);\n        h.apply(bits.b1, bits.b2, j == 0);\n    }\n\n    Q4bits bits;\n    HighBit5 h;\n    Scales8 s8;\n\n    uint8x16x2_t hbits;\n\n    float d;\n};\n\ninline int32x4x4_t make_wider(const int16x8x2_t& scales16) {\n    int32x4x4_t scales = {\n        vmovl_s16(vget_low_s16 (scales16.val[0])),\n        vmovl_s16(vget_high_s16(scales16.val[0])),\n        vmovl_s16(vget_low_s16 (scales16.val[1])),\n        vmovl_s16(vget_high_s16(scales16.val[1])),\n    };\n    return scales;\n}\n\ntemplate 
<typename Q8>\ninline int32x4x4_t process_scales_mins_16(const int8x16_t& scales8, const Q8& q8, float32x4_t * acc, int i, float c) {\n    int16x8x2_t scales16;\n    scales16.val[0] = vmovl_s8(vget_low_s8(scales8));\n    scales16.val[1] = vmovl_s8(vget_high_s8(scales8));\n    accum_mins_16(scales16, q8, acc, i, c);\n    return make_wider(scales16);\n}\n\nstruct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {\n    DequantizerQ6K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 16; }\n    constexpr static bool should_scale_quants() { return false; }\n\n    template <typename Q8>\n    inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        return process_scales_mins_16(vld1q_s8(x[i].scales), q8, acc, i, -32.f*d);\n    }\n    inline void prepare(int i, int j) {\n\n        auto hbits = vld1q_u8_x2(x[i].qh + 32*j);\n\n        bits.prepare64(x[i].ql+64*j);\n        bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), mhb));\n        bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), mhb));\n        bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 2), mhb));\n        bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 2), mhb));\n\n        bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], mhb));\n        bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], mhb));\n        bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 2), mhb));\n        bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 2), mhb));\n\n    }\n\n    Q4bits bits;\n\n    const uint8x16_t mhb = vdupq_n_u8(0x30);\n\n    float d;\n};\n\nstruct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {\n    DequantizerQ3K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, 
nrc) {}\n\n    constexpr static int num_blocks() { return 16; }\n    constexpr static bool should_scale_quants() { return false; }\n\n    template <typename Q8>\n    inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        h.bits = vld1q_u8_x2(x[i].hmask);\n        const uint16_t * sc16 = (const uint16_t *)x[i].scales;\n        uint32_t aux0 = sc16[0] | (sc16[1] << 16);\n        uint32_t aux1 = sc16[2] | (sc16[3] << 16);\n        uint32_t aux2 = sc16[4] | (sc16[5] << 16);\n        aux32[0] =  (aux0       & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030);\n        aux32[1] =  (aux1       & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030);\n        aux32[2] = ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030);\n        aux32[3] = ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030);\n        return process_scales_mins_16(vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32)), q8, acc, i, -4.f*d);\n    }\n\n    inline void prepare(int i, int j) {\n        bits.prepare(x[i].qs+32*j);\n        h.apply(bits.b1, bits.b2, j == 0);\n    }\n\n    uint32_t aux32[4];\n\n    Q2bits bits;\n\n    HighBit3 h;\n\n    float d;\n};\n\nstruct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {\n    DequantizerQ2K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 16; }\n    constexpr static bool should_scale_quants() { return true; }\n\n    template <typename Q8>\n    inline void process_scales(int i, const Q8& q8, float32x4_t * acc) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        auto scales_and_mins = vld1q_u8(x[i].scales);\n        auto mins8 = vreinterpretq_s8_u8(vshrq_n_u8(scales_and_mins, 4));\n        int16x8x2_t scales16;\n        scales16.val[0] = vmovl_s8(vget_low_s8(mins8));\n        scales16.val[1] = vmovl_s8(vget_high_s8(mins8));\n        accum_mins_16(scales16, q8, acc, i, -GGML_FP16_TO_FP32(x[i].dmin));\n\n        scales8 = 
vandq_u8(scales_and_mins, vdupq_n_u8(0xf));\n    }\n\n    template <typename Q8>\n    inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {\n        process_scales(i, q8, acc);\n        int16x8x2_t scales16;\n        scales16.val[0] = vmovl_s8(vget_low_s8(vreinterpretq_s8_u8(scales8)));\n        scales16.val[1] = vmovl_s8(vget_high_s8(vreinterpretq_s8_u8(scales8)));\n        return make_wider(scales16);\n    }\n\n    template <typename Q8>\n    inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) {\n        auto m1 = vdupq_n_u8(1);\n        auto shuffle = vdupq_n_u8(8*j);\n        bits.b1.val[0] = vmulq_u8(bits.b1.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        bits.b1.val[1] = vmulq_u8(bits.b1.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        bits.b1.val[2] = vmulq_u8(bits.b1.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        bits.b1.val[3] = vmulq_u8(bits.b1.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        bits.b2.val[0] = vmulq_u8(bits.b2.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        bits.b2.val[1] = vmulq_u8(bits.b2.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        bits.b2.val[2] = vmulq_u8(bits.b2.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        bits.b2.val[3] = vmulq_u8(bits.b2.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            auto q8b_1 = q8.load_quants(iy, i, 4*j+0);\n            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]),\n                    vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]);\n\n            auto q8b_2 = q8.load_quants(iy, i, 4*j+1);\n            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]),\n 
                   vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]);\n\n            auto q8b_3 = q8.load_quants(iy, i, 4*j+2);\n            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]),\n                    vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]);\n\n            auto q8b_4 = q8.load_quants(iy, i, 4*j+3);\n            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]),\n                    vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]);\n        }\n    }\n\n    inline void prepare(int i, int j) {\n        bits.prepare(x[i].qs+32*j);\n    }\n\n    uint32_t aux32[4];\n\n    uint8x16_t scales8;\n\n    Q2bits bits;\n\n    float d;\n};\n\n// ============================= i-quants\n\nstruct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {\n\n    static int8x16_t load_values() {\n        static const int8_t iq4nl_values[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};\n        return vld1q_s8(iq4nl_values);\n    }\n\n    DequantizerIQ4XS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(load_values()) {}\n\n    constexpr static int num_blocks() { return 8; }\n    constexpr static bool should_scale_quants() { return false; }\n\n    inline void new_row(int ix) { x = (const block_iq4_xs *)((const char *)vx + bx*ix); }\n\n    template <typename Q8>\n    inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {\n        (void)q8;\n        (void)acc;\n        d = GGML_FP16_TO_FP32(x[i].d);\n        const uint16_t scales_h = x[i].scales_h;\n        const uint16_t * scales_l = (const uint16_t *)x[i].scales_l;\n        aux32[0] = scales_l[0] | (scales_l[1] << 16);\n        aux32[1] = aux32[0] >> 4;\n        // scl is ordered as 0, 2, 4, 6, 1, 3, 5, 7\n        uint8x8_t scl8 = vand_u8(vld1_u8((const uint8_t *)aux32), vdup_n_u8(0xf));\n        uint16_t * aux16 = (uint16_t 
*)aux32;\n        aux16[0] = scales_h << 4; aux16[1] = scales_h << 2; aux16[2] = scales_h; aux16[3] = scales_h >> 2;\n        // sch is ordered as 0, 4, 1, 5, 2, 6, 3, 7\n        uint8x8_t sch8 = vand_u8(vld1_u8((const uint8_t *)aux16), vdup_n_u8(0x30));\n        int8x8_t scales8 = vadd_s8(vreinterpret_s8_u8(vorr_u8(scl8, vtbl1_u8(sch8, vreinterpret_u8_u32(hshuff)))), vdup_n_s8(-32));\n        // shuffle 0, 2, 4, 6, 1, 3, 5, 7 -> 0, 1, 2, 3, 4, 5, 6, 7\n        scales8 = vtbl1_s8(scales8, vreinterpret_s8_u32(hshuff));\n        int16x8_t scales16 = vmovl_s8(scales8);\n        int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};\n        return scales;\n    }\n    inline void prepare(int i, int j) {\n        bits.prepare16(x[i].qs+64*j);\n        for (int k = 0; k < 4; ++k) {\n            bits.b1.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values, bits.b1.val[k]));\n            bits.b2.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values, bits.b2.val[k]));\n        }\n    }\n\n    Q4bits bits;\n    const int8x16_t values;\n    uint32_t aux32[2];\n\n    constexpr static uint32x2_t hshuff = {0x05010400, 0x07030602};\n\n    float d;\n};\n\nstruct SimpleBits {\n    uint8x16x4_t b1;\n    uint8x16x4_t b2;\n};\n\nIQK_ALWAYS_INLINE int32x4x2_t prepare_scales_8(const uint32x4_t& v1, const uint32x4_t& v2) {\n    int32x4x2_t scales;\n    auto one = vdupq_n_u32(1);\n    scales.val[0] = vreinterpretq_s32_u32(vsliq_n_u32(one, vshrq_n_u32(v1, 28), 1));\n    scales.val[1] = vreinterpretq_s32_u32(vsliq_n_u32(one, vshrq_n_u32(v2, 28), 1));\n    return scales;\n}\n\ninline void apply_signs_2(uint8x16_t * b, const uint64_t * signs, uint32_t sidx) {\n    auto s1 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >> 0) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >> 7) & 127))));\n    auto s2 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >>14) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >>21) & 127))));\n    b[0] = 
vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s1));\n    b[1] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[1]), s2));\n}\n\nIQK_ALWAYS_INLINE int32x4_t prepare_scales_8(const uint32x4_t& v1) {\n    return vreinterpretq_s32_u32(vsliq_n_u32(vdupq_n_u32(1), vshrq_n_u32(v1, 28), 1));\n}\n\nstruct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {\n    DequantizerIQ2XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    IQK_ALWAYS_INLINE float new_block(int i) const { return 0.125f * GGML_FP16_TO_FP32(x[i].d); }\n\n    inline int32x4_t unpack(int i, int j, uint8x16_t * q) const {\n        auto data = vld1q_u32_x2((const uint32_t *)(x[i].qs + 16*j));\n        prepare_all(data, q);\n        return prepare_scales_8(vuzp2q_u32(data.val[0], data.val[1]));\n    }\n\nprivate:\n\n    static inline void prepare2(uint8x16_t * b, const uint32_t * bits, const uint64_t * signs) {\n        const uint8_t * idx = (const uint8_t *)bits;\n        b[0] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[0]], iq2xxs_grid[idx[1]]});\n        b[1] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[2]], iq2xxs_grid[idx[3]]});\n        apply_signs_2(b, signs, bits[1]);\n    }\n\n    inline static void prepare_all(const uint32x4x2_t& data, uint8x16_t * quants) {\n        const uint32_t * q2 = (const uint32_t *)data.val;\n        prepare2(quants+0, q2+0, keven_signs);\n        prepare2(quants+2, q2+2, keven_signs);\n        prepare2(quants+4, q2+4, keven_signs);\n        prepare2(quants+6, q2+6, keven_signs);\n    }\n};\n\ninline int32x4x4_t prepare_4bit_scales16(const uint8_t * sc) {\n    auto aux = vld1_u8(sc);\n    auto scales_l = vand_u8(aux, vdup_n_u8(0xf));\n    auto scales_h = vshr_n_u8(aux, 4);\n    auto aux1 = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h));\n\n    auto scales8 = vreinterpretq_s8_u8(vorrq_u8(vshlq_n_u8(aux1, 1), vdupq_n_u8(1)));\n    int16x8x2_t scales16 = { vmovl_s8(vget_low_s8(scales8)), 
vmovl_s8(vget_high_s8(scales8)) };\n    return make_wider(scales16);\n}\n\nstruct DequantizerIQ2XS final : public BaseDequantizer<block_iq2_xs> {\n    DequantizerIQ2XS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 16; }\n    constexpr static bool should_scale_quants() { return false; }\n\n    SimpleBits bits;\n    float d;\n\n    inline int32x4x4_t new_block(int i) {\n        d = 0.125f * GGML_FP16_TO_FP32(x[i].d);\n        prepare_internal(i, 0);\n        return prepare_4bit_scales16(x[i].scales);\n    }\n\n    inline void prepare(int i, int j) {\n        if (j == 1) prepare_internal(i, 1);\n    }\n\nprivate:\n\n    static void make2(const uint16_t * qs, uint8x16_t * b) {\n        auto v1 = vcombine_s8(vld1_s8((const int8_t *)(iq2xs_grid + (qs[0] & 511))), vld1_s8((const int8_t *)(iq2xs_grid + (qs[1] & 511))));\n        auto v2 = vcombine_s8(vld1_s8((const int8_t *)(iq2xs_grid + (qs[2] & 511))), vld1_s8((const int8_t *)(iq2xs_grid + (qs[3] & 511))));\n        auto s1 = vcombine_s8(vld1_s8((const int8_t *)(keven_signs + (qs[0] >> 9))), vld1_s8((const int8_t *)(keven_signs + (qs[1] >> 9))));\n        auto s2 = vcombine_s8(vld1_s8((const int8_t *)(keven_signs + (qs[2] >> 9))), vld1_s8((const int8_t *)(keven_signs + (qs[3] >> 9))));\n        b[0] = vreinterpretq_u8_s8(vmulq_s8(v1, s1));\n        b[1] = vreinterpretq_u8_s8(vmulq_s8(v2, s2));\n    }\n\n    inline static void make4(const uint16_t * qs, uint8x16_t * b) {\n        make2(qs + 0, b + 0);\n        make2(qs + 4, b + 2);\n    }\n\n    IQK_ALWAYS_INLINE void prepare_internal(int i, int j) {\n        make4(x[i].qs + 16*j + 0, bits.b1.val);\n        make4(x[i].qs + 16*j + 8, bits.b2.val);\n    }\n\n};\n\n// So, I hate to include this table, but with the GCC 12.3 compiler\n// bundled in the Cosmopolitan tools, loading the unpacked sign bytes\n// from this table using the packed 8 sign bits as index is faster than\n// using the standard trick 
of vceqq_u8(vandq_u8(bits, mask), mask) to\n// expand the bits to bytes.\nstatic const uint64_t kall_signs[256] = {\n    0x0101010101010101, 0x01010101010101ff, 0x010101010101ff01, 0x010101010101ffff,\n    0x0101010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0x0101010101ffffff,\n    0x01010101ff010101, 0x01010101ff0101ff, 0x01010101ff01ff01, 0x01010101ff01ffff,\n    0x01010101ffff0101, 0x01010101ffff01ff, 0x01010101ffffff01, 0x01010101ffffffff,\n    0x010101ff01010101, 0x010101ff010101ff, 0x010101ff0101ff01, 0x010101ff0101ffff,\n    0x010101ff01ff0101, 0x010101ff01ff01ff, 0x010101ff01ffff01, 0x010101ff01ffffff,\n    0x010101ffff010101, 0x010101ffff0101ff, 0x010101ffff01ff01, 0x010101ffff01ffff,\n    0x010101ffffff0101, 0x010101ffffff01ff, 0x010101ffffffff01, 0x010101ffffffffff,\n    0x0101ff0101010101, 0x0101ff01010101ff, 0x0101ff010101ff01, 0x0101ff010101ffff,\n    0x0101ff0101ff0101, 0x0101ff0101ff01ff, 0x0101ff0101ffff01, 0x0101ff0101ffffff,\n    0x0101ff01ff010101, 0x0101ff01ff0101ff, 0x0101ff01ff01ff01, 0x0101ff01ff01ffff,\n    0x0101ff01ffff0101, 0x0101ff01ffff01ff, 0x0101ff01ffffff01, 0x0101ff01ffffffff,\n    0x0101ffff01010101, 0x0101ffff010101ff, 0x0101ffff0101ff01, 0x0101ffff0101ffff,\n    0x0101ffff01ff0101, 0x0101ffff01ff01ff, 0x0101ffff01ffff01, 0x0101ffff01ffffff,\n    0x0101ffffff010101, 0x0101ffffff0101ff, 0x0101ffffff01ff01, 0x0101ffffff01ffff,\n    0x0101ffffffff0101, 0x0101ffffffff01ff, 0x0101ffffffffff01, 0x0101ffffffffffff,\n    0x01ff010101010101, 0x01ff0101010101ff, 0x01ff01010101ff01, 0x01ff01010101ffff,\n    0x01ff010101ff0101, 0x01ff010101ff01ff, 0x01ff010101ffff01, 0x01ff010101ffffff,\n    0x01ff0101ff010101, 0x01ff0101ff0101ff, 0x01ff0101ff01ff01, 0x01ff0101ff01ffff,\n    0x01ff0101ffff0101, 0x01ff0101ffff01ff, 0x01ff0101ffffff01, 0x01ff0101ffffffff,\n    0x01ff01ff01010101, 0x01ff01ff010101ff, 0x01ff01ff0101ff01, 0x01ff01ff0101ffff,\n    0x01ff01ff01ff0101, 0x01ff01ff01ff01ff, 0x01ff01ff01ffff01, 0x01ff01ff01ffffff,\n    
0x01ff01ffff010101, 0x01ff01ffff0101ff, 0x01ff01ffff01ff01, 0x01ff01ffff01ffff,\n    0x01ff01ffffff0101, 0x01ff01ffffff01ff, 0x01ff01ffffffff01, 0x01ff01ffffffffff,\n    0x01ffff0101010101, 0x01ffff01010101ff, 0x01ffff010101ff01, 0x01ffff010101ffff,\n    0x01ffff0101ff0101, 0x01ffff0101ff01ff, 0x01ffff0101ffff01, 0x01ffff0101ffffff,\n    0x01ffff01ff010101, 0x01ffff01ff0101ff, 0x01ffff01ff01ff01, 0x01ffff01ff01ffff,\n    0x01ffff01ffff0101, 0x01ffff01ffff01ff, 0x01ffff01ffffff01, 0x01ffff01ffffffff,\n    0x01ffffff01010101, 0x01ffffff010101ff, 0x01ffffff0101ff01, 0x01ffffff0101ffff,\n    0x01ffffff01ff0101, 0x01ffffff01ff01ff, 0x01ffffff01ffff01, 0x01ffffff01ffffff,\n    0x01ffffffff010101, 0x01ffffffff0101ff, 0x01ffffffff01ff01, 0x01ffffffff01ffff,\n    0x01ffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0x01ffffffffffffff,\n    0xff01010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0xff0101010101ffff,\n    0xff01010101ff0101, 0xff01010101ff01ff, 0xff01010101ffff01, 0xff01010101ffffff,\n    0xff010101ff010101, 0xff010101ff0101ff, 0xff010101ff01ff01, 0xff010101ff01ffff,\n    0xff010101ffff0101, 0xff010101ffff01ff, 0xff010101ffffff01, 0xff010101ffffffff,\n    0xff0101ff01010101, 0xff0101ff010101ff, 0xff0101ff0101ff01, 0xff0101ff0101ffff,\n    0xff0101ff01ff0101, 0xff0101ff01ff01ff, 0xff0101ff01ffff01, 0xff0101ff01ffffff,\n    0xff0101ffff010101, 0xff0101ffff0101ff, 0xff0101ffff01ff01, 0xff0101ffff01ffff,\n    0xff0101ffffff0101, 0xff0101ffffff01ff, 0xff0101ffffffff01, 0xff0101ffffffffff,\n    0xff01ff0101010101, 0xff01ff01010101ff, 0xff01ff010101ff01, 0xff01ff010101ffff,\n    0xff01ff0101ff0101, 0xff01ff0101ff01ff, 0xff01ff0101ffff01, 0xff01ff0101ffffff,\n    0xff01ff01ff010101, 0xff01ff01ff0101ff, 0xff01ff01ff01ff01, 0xff01ff01ff01ffff,\n    0xff01ff01ffff0101, 0xff01ff01ffff01ff, 0xff01ff01ffffff01, 0xff01ff01ffffffff,\n    0xff01ffff01010101, 0xff01ffff010101ff, 0xff01ffff0101ff01, 0xff01ffff0101ffff,\n    0xff01ffff01ff0101, 0xff01ffff01ff01ff, 
0xff01ffff01ffff01, 0xff01ffff01ffffff,\n    0xff01ffffff010101, 0xff01ffffff0101ff, 0xff01ffffff01ff01, 0xff01ffffff01ffff,\n    0xff01ffffffff0101, 0xff01ffffffff01ff, 0xff01ffffffffff01, 0xff01ffffffffffff,\n    0xffff010101010101, 0xffff0101010101ff, 0xffff01010101ff01, 0xffff01010101ffff,\n    0xffff010101ff0101, 0xffff010101ff01ff, 0xffff010101ffff01, 0xffff010101ffffff,\n    0xffff0101ff010101, 0xffff0101ff0101ff, 0xffff0101ff01ff01, 0xffff0101ff01ffff,\n    0xffff0101ffff0101, 0xffff0101ffff01ff, 0xffff0101ffffff01, 0xffff0101ffffffff,\n    0xffff01ff01010101, 0xffff01ff010101ff, 0xffff01ff0101ff01, 0xffff01ff0101ffff,\n    0xffff01ff01ff0101, 0xffff01ff01ff01ff, 0xffff01ff01ffff01, 0xffff01ff01ffffff,\n    0xffff01ffff010101, 0xffff01ffff0101ff, 0xffff01ffff01ff01, 0xffff01ffff01ffff,\n    0xffff01ffffff0101, 0xffff01ffffff01ff, 0xffff01ffffffff01, 0xffff01ffffffffff,\n    0xffffff0101010101, 0xffffff01010101ff, 0xffffff010101ff01, 0xffffff010101ffff,\n    0xffffff0101ff0101, 0xffffff0101ff01ff, 0xffffff0101ffff01, 0xffffff0101ffffff,\n    0xffffff01ff010101, 0xffffff01ff0101ff, 0xffffff01ff01ff01, 0xffffff01ff01ffff,\n    0xffffff01ffff0101, 0xffffff01ffff01ff, 0xffffff01ffffff01, 0xffffff01ffffffff,\n    0xffffffff01010101, 0xffffffff010101ff, 0xffffffff0101ff01, 0xffffffff0101ffff,\n    0xffffffff01ff0101, 0xffffffff01ff01ff, 0xffffffff01ffff01, 0xffffffff01ffffff,\n    0xffffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0xffffffffff01ffff,\n    0xffffffffffff0101, 0xffffffffffff01ff, 0xffffffffffffff01, 0xffffffffffffffff,\n};\n\nstruct SignHelper {\n\n    IQK_ALWAYS_INLINE void apply_signs_1x(uint8x16_t * b, const uint8_t * sign_bits) const {\n        auto s = vreinterpretq_s8_u64(uint64x2_t{kall_signs[sign_bits[0]], kall_signs[sign_bits[1]]});\n        // Normally we would expect this to be faster, but it isn't.\n        // auto aux = vcombine_u8(vdup_n_u8(sign_bits[0]), vdup_n_u8(sign_bits[1]));\n        // auto s = 
vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux, smask), smask), m1));\n        b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s));\n    }\n\n    // We would need these two if we weren't loading from the unpacked sign table.\n    //const uint8x16_t smask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201));\n    //const uint8x16_t m1    = vdupq_n_u8(1);\n};\n\nstruct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {\n    DequantizerIQ2S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 16; }\n    constexpr static bool should_scale_quants() { return false; }\n\n    SimpleBits bits;\n    float d;\n\n    inline int32x4x4_t new_block(int i) {\n        d = 0.125f * GGML_FP16_TO_FP32(x[i].d);\n        prepare_internal(i, 0, bits);\n        return prepare_4bit_scales16(x[i].scales);\n    }\n\n    inline void prepare(int i, int j) {\n        if (j == 1) prepare_internal(i, 1, bits);\n    }\n\nprivate:\n\n    static void make4(const SignHelper& sh, const uint8_t * sign_bits, const uint8_t * qs, const uint8_t * qh, uint8x16_t * b) {\n        uint32_t aux32[2];\n        const uint16_t * aux16 = (const uint16_t *)aux32;\n        for (int k = 0; k < 2; ++k) {\n            aux32[1] = (qh[k] << 4) | (qh[k] << 18);\n            aux32[0] = (aux32[1] << 4) & 0x03000300;\n            aux32[1] &= 0x03000300;\n            b[2*k+0] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+0] | aux16[0]))),\n                                   vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+1] | aux16[1]))));\n            b[2*k+1] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+2] | aux16[2]))),\n                                   vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+3] | aux16[3]))));\n            sh.apply_signs_1x(b+2*k+0, sign_bits); sign_bits += 2;\n            sh.apply_signs_1x(b+2*k+1, sign_bits); sign_bits += 2;\n        }\n    }\n\n    void 
prepare_internal(int i, int j, SimpleBits& sb) {\n\n        const auto * qs = x[i].qs + 16*j;\n        const auto * qh = x[i].qh + 4*j;\n        const auto * sign_bits = qs + QK_K/8;\n\n        make4(sh, sign_bits+0, qs+0, qh+0, sb.b1.val);\n        make4(sh, sign_bits+8, qs+8, qh+2, sb.b2.val);\n    }\n\n    SignHelper sh;\n};\n\nstruct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {\n    DequantizerIQ3XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    IQK_ALWAYS_INLINE float new_block(int i) const { return 0.25f * GGML_FP16_TO_FP32(x[i].d); }\n\n    inline int32x4_t unpack(int i, int j, uint8x16_t * q) const {\n        auto q3data = vld1q_u8_x2(x[i].qs + 32*j);\n        auto gas = vld1q_u32((const uint32_t *)(x[i].qs + QK_K/4 + 16*j));\n        prepare_block((const uint8_t *)q3data.val, (const uint32_t *)&gas, q);\n        return prepare_scales_8(gas);\n    }\n\nprivate:\n\n    inline static void make2(const uint8_t * q3, const uint32_t sidx, uint8x16_t * b) {\n        b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[0]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[3]]});\n        b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[4]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[7]]});\n        apply_signs_2(b, keven_signs, sidx);\n    }\n    inline static void prepare_block(const uint8_t * q3, const uint32_t * signs, uint8x16_t * quants) {\n        make2(q3+ 0, signs[0], quants + 0);\n        make2(q3+ 8, signs[1], quants + 2);\n        make2(q3+16, signs[2], quants + 4);\n        make2(q3+24, signs[3], quants + 6);\n    }\n};\n\nstruct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {\n    DequantizerIQ3S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 8; }\n    constexpr static bool should_scale_quants() { return false; }\n\n    SimpleBits bits;\n    float d;\n\n    inline int32x4x2_t new_block(int 
i) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        uint32_t scales32[2];\n        auto qs = vld1q_u8_x2(x[i].qs);\n        auto signs = vld1q_u8(x[i].signs);\n\n        prepare_block((const uint8_t *)qs.val, x[i].qh, (const uint8_t *)&signs);\n\n        std::memcpy(scales32, x[i].scales, 4);\n        scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101;\n        scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101;\n        auto scales8 = vld1_u8((const uint8_t *)scales32); // 0, 2, 4, 6, 1, 3, 5, 7\n        scales8 = vtbl1_u8(scales8, vreinterpret_u8_u64(vdup_n_u64(0x0703060205010400)));\n        auto scales16 = vreinterpretq_s16_u16(vmovl_u8(scales8));\n        int32x4x2_t scales;\n        scales.val[0] = vmovl_s16(vget_low_s16(scales16));\n        scales.val[1] = vmovl_s16(vget_high_s16(scales16));\n        return scales;\n    }\n\n    inline void prepare(int i, int j) {\n        if (j == 1) {\n            auto qs = vld1q_u8_x2(x[i].qs + 32);\n            auto signs = vld1q_u8(x[i].signs + 16);\n            prepare_block((const uint8_t *)qs.val, x[i].qh + 4, (const uint8_t *)&signs);\n        }\n    }\n\nprivate:\n\n    static inline void make2(const SignHelper& sh, const uint8_t * sign_bits, const uint16x8_t& idx_l, uint8_t qh,\n            const int16x8_t& hshift, uint8x16_t * b) {\n        auto vindex = vorrq_u16(idx_l, vandq_u16(vshlq_u16(vdupq_n_u16(qh), hshift), vdupq_n_u16(256)));\n        const uint16_t * idx = (const uint16_t *)&vindex;\n        b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[0]], iq3s_grid[idx[1]], iq3s_grid[idx[2]], iq3s_grid[idx[3]]});\n        sh.apply_signs_1x(b+0, sign_bits+0);\n        b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[4]], iq3s_grid[idx[5]], iq3s_grid[idx[6]], iq3s_grid[idx[7]]});\n        sh.apply_signs_1x(b+1, sign_bits+2);\n    }\n    static inline void make4(const SignHelper& sh, const uint8_t * sign_bits, const uint8_t * qs, const uint8_t * qh,\n            const int16x8_t& 
hshift, uint8x16_t * b) {\n        auto idx_l = vld1q_u8(qs);\n        make2(sh, sign_bits+0, vmovl_u8(vget_low_u8 (idx_l)), qh[0], hshift, b+0);\n        make2(sh, sign_bits+4, vmovl_u8(vget_high_u8(idx_l)), qh[1], hshift, b+2);\n    }\n\n    static int16x8_t load_shift() {\n        static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1};\n        return vld1q_s16(k_shift);\n    }\n\n    inline void prepare_block(const uint8_t * qs, const uint8_t * qh, const uint8_t * sign_bits) {\n        auto signs = vld1q_u8(sign_bits);\n        auto s = (const uint8_t *)&signs;\n        make4(sh, s + 0, qs+ 0, qh+0, hshift, bits.b1.val);\n        make4(sh, s + 8, qs+16, qh+2, hshift, bits.b2.val);\n    }\n\n    SignHelper sh;\n    const int16x8_t hshift = load_shift();\n\n};\n\ntemplate <int nrc_y, typename Dequantizer>\nIQK_NOINLINE void mul_mat_qX_K_q8_K_IQXXS(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n    const int nb = n / QK_K;\n\n    Q8<nrc_y, block_q8_K> q8(info);\n    Dequantizer deq(vx, bx, nrc_y);\n    uint8x16_t  qx[8];\n    int32x4_t   sumi[nrc_y];\n    float32x4_t acc[nrc_y];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        deq.new_row(ix);\n        for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);\n\n        for (int i = 0; i < nb; ++i) {\n            float d = deq.new_block(i);\n            auto scales = deq.unpack(i, 0, qx);\n#pragma GCC unroll 8\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                sumi[iy] = vdupq_n_s32(0);\n                compute_8_blocks((const int8x16_t *)qx, q8, scales, iy, i, 0, sumi[iy]);\n            }\n            scales = deq.unpack(i, 1, qx);\n#pragma GCC unroll 8\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                compute_8_blocks((const int8x16_t *)qx, q8, scales, iy, i, 1, sumi[iy]);\n                acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*q8.scale(iy, i)), vcvtq_f32_s32(sumi[iy]));\n            }\n        }\n#pragma GCC 
unroll 8\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            info.store(ix, iy, vaddvq_f32(acc[iy]));\n        }\n    }\n}\n\n// =========================================== Legacy quants\n\ntemplate <typename Block>\ninline float16x4_t load_scales_q0(const Block * x, ggml_half * aux) {\n    for (int k = 0; k < 4; ++k) aux[k] = x[k].d;\n    return vld1_f16((const float16_t *)aux);\n}\n\ntemplate <typename Block>\ninline float16x8_t load_scales_q1(const Block * x, ggml_half * aux) {\n    if constexpr (std::is_same_v<Block, block_q8_1>) {\n        for (int k = 0; k < 4; ++k) { aux[k] = x[k].d; aux[k+4] = x[k].s; }\n    } else {\n        for (int k = 0; k < 4; ++k) { aux[k] = x[k].d; aux[k+4] = x[k].m; }\n    }\n    return vld1q_f16((const float16_t *)aux);\n}\n\nstruct Q4LegacyBits {\n    template <typename Block>\n    inline void prepare(const Block * x) {\n        for (int i = 0; i < 4; ++i) {\n            auto q4bits = vld1q_u8(x[i].qs);\n            b[2*i+0] = vreinterpretq_s8_u8(vandq_u8(q4bits, m4b));\n            b[2*i+1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits, 4));\n        }\n    }\n    inline void prepare1(const uint8_t * qs, int8x16_t * q) const {\n        auto q4bits = vld1q_u8(qs);\n        q[0] = vreinterpretq_s8_u8(vandq_u8(q4bits, m4b));\n        q[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits, 4));\n    }\n    inline void prepare1(const uint8_t * qs) {\n        prepare1(qs, b);\n    }\n    const uint8x16_t m4b = vdupq_n_u8(0xf);\n    int8x16_t b[8];\n};\n\n// One would think this commented out version would do better than the one below\n// because it offers more opportunities to execute instructions in parallel.\n// Instead, it runs significantly slower. Why? 
If the compiler is running out of vector registers\n// cannot it just do the sequential version below on its own?\n//inline int32x4_t sum_4_blocks(const int8x16_t * b, const int8_t * qs) {\n//    const auto q8b_1 = vld1q_s8_x2(qs + 0);\n//    auto p12 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[0], q8b_1.val[0]), b[1], q8b_1.val[1]);\n//    const auto q8b_2 = vld1q_s8_x2(qs + 32);\n//    auto p34 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[2], q8b_2.val[0]), b[3], q8b_2.val[1]);\n//    auto p1234 = vpaddq_s32(p12, p34);\n//    const auto q8b_3 = vld1q_s8_x2(qs + 64);\n//    auto p56 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[4], q8b_3.val[0]), b[5], q8b_3.val[1]);\n//    const auto q8b_4 = vld1q_s8_x2(qs + 96);\n//    auto p78 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[6], q8b_4.val[0]), b[7], q8b_4.val[1]);\n//    return vpaddq_s32(p1234, vpaddq_s32(p56, p78));\n//}\n\ninline int32x4_t sum_4_blocks(const int8x16_t * b, const int8_t * qs) {\n    auto q8b = vld1q_s8_x2(qs + 0);\n    auto p12 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[0], q8b.val[0]), b[1], q8b.val[1]);\n    q8b = vld1q_s8_x2(qs + 32);\n    auto p34 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[2], q8b.val[0]), b[3], q8b.val[1]);\n    auto p1234 = vpaddq_s32(p12, p34);\n    q8b = vld1q_s8_x2(qs + 64);\n    auto p56 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[4], q8b.val[0]), b[5], q8b.val[1]);\n    q8b = vld1q_s8_x2(qs + 96);\n    auto p78 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[6], q8b.val[0]), b[7], q8b.val[1]);\n    return vpaddq_s32(p1234, vpaddq_s32(p56, p78));\n}\n\ntemplate <int nrc> struct Q80 {\n\n    constexpr static int nrc_y = nrc;\n\n    Q80(const DataInfo& info) {\n        for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_0 *)info.src1_row(iy);\n    }\n\n    inline const int8_t * quant_data(int iy, int i) const {\n        const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y[iy] + i;\n        return y4->qs;\n    }\n\n    inline 
float16x4_t load_scales(int iy, int i) const {\n        const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y[iy] + i;\n        return vld1_f16((const float16_t *)y4->d);\n    }\n\n    template <typename Dequantizer>\n    inline void process_scales(int i, Dequantizer& deq, float16x4_t * sc16, float32x4_t * /*acc*/) const {\n        auto qx_scales = deq.new_block(i);\n        for (int iy = 0; iy < nrc; ++iy) {\n            auto q8_scales = load_scales(iy, i);\n            sc16[iy] = vmul_f16(qx_scales, q8_scales);\n        }\n    }\n\n    template <typename Dequantizer>\n    inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const {\n        deq.prepare1(i);\n        float d = GGML_FP16_TO_FP32(deq.x[i].d);\n        for (int iy = 0; iy < nrc; ++iy) {\n            auto q8b = vld1q_s8_x2(y[iy][i].qs);\n            auto p = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), deq.bits.b[0], q8b.val[0]), deq.bits.b[1], q8b.val[1]);\n            acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*GGML_FP16_TO_FP32(y[iy][i].d)), vcvtq_f32_s32(p));\n        }\n    }\n\n    const block_q8_0 * y[nrc_y];\n};\n\ntemplate <int nrc> struct Q81 {\n\n    constexpr static int nrc_y = nrc;\n\n    Q81(const DataInfo& info) {\n        for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_1 *)info.src1_row(iy);\n    }\n\n    inline const int8_t * quant_data(int iy, int i) const {\n        const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y[iy] + i;\n        return y4->qs;\n    }\n\n    inline float16x8_t load_scales(int iy, int i) const {\n        const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y[iy] + i;\n        return vld1q_f16((const float16_t *)y4->d);\n    }\n\n    template <typename Dequantizer>\n    inline void process_scales(int i, Dequantizer& deq, float16x4_t * sc16, float32x4_t * acc) const {\n        auto qx_scales = deq.new_block(i);\n        for (int iy = 0; iy < nrc; ++iy) {\n            auto q8_scales = load_scales(iy, i);\n            auto m = 
vmul_f16(vget_high_f16(qx_scales), vget_high_f16(q8_scales));\n            acc[iy] = vaddq_f32(acc[iy], vcvt_f32_f16(m));\n            sc16[iy] = vmul_f16(vget_low_f16(qx_scales), vget_low_f16(q8_scales));\n        }\n    }\n\n    template <typename Dequantizer>\n    inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const {\n        deq.prepare1(i);\n        float d = GGML_FP16_TO_FP32(deq.x[i].d), m = 0.25f*GGML_FP16_TO_FP32(deq.x[i].m);\n        for (int iy = 0; iy < nrc; ++iy) {\n            auto q8b = vld1q_s8_x2(y[iy][i].qs);\n            auto p = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), deq.bits.b[0], q8b.val[0]), deq.bits.b[1], q8b.val[1]);\n            acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*GGML_FP16_TO_FP32(y[iy][i].d)), vcvtq_f32_s32(p));\n            acc[iy] = vaddq_f32(acc[iy], vdupq_n_f32(m*GGML_FP16_TO_FP32(y[iy][i].s)));\n        }\n    }\n\n    const block_q8_1 * y[nrc_y];\n};\n\ntemplate <typename block_q>\nstruct BaseLegacyDequantizer {\n\n    BaseLegacyDequantizer(const void * vx, size_t bx) : vx(vx), x(nullptr), bx(bx) {}\n\n    inline void new_row(int ix) { x = (const block_q *)((const char *)vx + bx*ix); }\n\n    Q4LegacyBits bits;\n\n    const void * vx;\n    const block_q * x;\n    size_t bx;\n};\n\nstruct DequantizerQ40 final : public BaseLegacyDequantizer<block_q4_0> {\n\n    DequantizerQ40(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}\n\n    inline void prepare1(int i, int8x16_t * q) const {\n        bits.prepare1(x[i].qs, q);\n        q[0] = vaddq_s8(q[0], m8);\n        q[1] = vaddq_s8(q[1], m8);\n    }\n    inline void prepare1(int i) {\n        prepare1(i, bits.b);\n    }\n\n    inline float16x4_t new_block(int i) {\n        ggml_half aux[4];\n        for (int k = 0; k < 4; ++k) {\n            aux[k] = x[4*i+k].d;\n            prepare1(4*i+k, bits.b + 2*k);\n        }\n        return vld1_f16((const float16_t *)aux);\n    }\n\n    const int8x16_t m8 = vdupq_n_s8(-8);\n    //ggml_half 
aux[4];\n};\n\nstruct DequantizerQ41 : public BaseLegacyDequantizer<block_q4_1> {\n\n    DequantizerQ41(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}\n\n    inline void prepare1(int i) {\n        bits.prepare1(x[i].qs);\n    }\n\n    inline float16x8_t new_block(int i) {\n        uint32_t aux32[4];\n        const uint32_t * s32 = (const uint32_t *)&x[4*i].d;\n        for (int k = 0; k < 4; ++k) {\n            aux32[k] = *s32; s32 += sizeof(block_q4_1)/4;\n            bits.prepare1(x[4*i+k].qs, bits.b + 2*k);\n        }\n        return vreinterpretq_f16_u8(vqtbl1q_u8(vld1q_u8((const uint8_t *)aux32), vreinterpretq_u8_u64(shuffle)));\n    }\n    // Leaving this commented out attempt to be reminded that I already tried this.\n    // It has basically the same performance as the version above.\n    //inline float16x8_t new_block(int i) {\n    //    uint32x4_t scales = {};\n    //    const block_q4_1 * xi = x + 4*i;\n    //    const uint32_t * s32 = (const uint32_t *)&xi->d;\n    //    scales = vsetq_lane_u32(*s32, scales, 0); s32 += sizeof(block_q4_1)/4;\n    //    bits.prepare1(xi[0].qs, bits.b + 0);\n    //    scales = vsetq_lane_u32(*s32, scales, 1); s32 += sizeof(block_q4_1)/4;\n    //    bits.prepare1(xi[1].qs, bits.b + 2);\n    //    scales = vsetq_lane_u32(*s32, scales, 2); s32 += sizeof(block_q4_1)/4;\n    //    bits.prepare1(xi[2].qs, bits.b + 4);\n    //    scales = vsetq_lane_u32(*s32, scales, 3);\n    //    bits.prepare1(xi[3].qs, bits.b + 6);\n    //    return vreinterpretq_f16_u8(vqtbl1q_u8(vreinterpretq_u8_u32(scales), vreinterpretq_u8_u64(shuffle)));\n    //}\n\n    const uint64x2_t shuffle = {0x0d0c090805040100, 0x0f0e0b0a07060302};\n};\n\nstruct HighBit5Legacy {\n    inline uint8x16_t to_bytes(const uint8_t * qh) const {\n        uint8x16_t h = vqtbl1q_u8(vreinterpretq_u8_u16(vdupq_n_u16(*(const uint16_t *)qh)), shuffle);\n        return vceqq_u8(vandq_u8(h, vreinterpretq_u8_u64(mask)), vreinterpretq_u8_u64(mask));\n    }\n    inline 
uint8x16_t to_negated_bytes(const uint8_t * qh) const {\n        uint8x16_t h = vqtbl1q_u8(vreinterpretq_u8_u16(vdupq_n_u16(*(const uint16_t *)qh)), shuffle);\n        return vceqq_u8(vandq_u8(h, vreinterpretq_u8_u64(mask)), vdupq_n_u8(0));\n    }\n    const uint64x2_t mask = vdupq_n_u64(0x8040201008040201);\n    const uint8x16_t shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1));\n};\n\nstruct DequantizerQ50 final : public BaseLegacyDequantizer<block_q5_0> {\n\n    DequantizerQ50(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}\n\n    inline void prepare1(int i, int8x16_t * q) const {\n        bits.prepare1(x[i].qs, q);\n        auto qh = x[i].qh;\n        q[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[0]), vandq_u8(mh, hbits.to_negated_bytes(qh+0))));\n        q[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[1]), vandq_u8(mh, hbits.to_negated_bytes(qh+2))));\n    }\n    inline void prepare1(int i) {\n        prepare1(i, bits.b);\n    }\n\n    inline float16x4_t new_block(int i) {\n        ggml_half aux[4];\n        for (int k = 0; k < 4; ++k) {\n            aux[k] = x[4*i+k].d;\n            prepare1(4*i+k, bits.b + 2*k);\n        }\n        return vld1_f16((const float16_t *)aux);\n    }\n\n    HighBit5Legacy hbits;\n\n    const uint8x16_t mh = vdupq_n_u8(0xf0);\n\n};\n\nstruct DequantizerQ80 final : public BaseLegacyDequantizer<block_q8_0> {\n\n    DequantizerQ80(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}\n\n    inline void prepare1(int i) {\n        bits.b[0] = vld1q_s8(x[i].qs);\n        bits.b[1] = vld1q_s8(x[i].qs+16);\n    }\n\n    inline float16x4_t new_block(int i) {\n        ggml_half aux[4];\n        for (int k = 0; k < 4; ++k) {\n            aux[k] = x[4*i+k].d;\n            bits.b[2*k+0] = vld1q_s8(x[4*i+k].qs);\n            bits.b[2*k+1] = vld1q_s8(x[4*i+k].qs+16);\n        }\n        return vld1_f16((const float16_t *)aux);\n    }\n\n};\n\nstruct DequantizerQ51 final : public 
BaseLegacyDequantizer<block_q5_1> {\n\n    DequantizerQ51(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}\n\n    inline void prepare1(int i, int8x16_t * q) const {\n        bits.prepare1(x[i].qs, q);\n        auto qh = x[i].qh;\n        q[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[0]), vandq_u8(mh, hbits.to_bytes(qh+0))));\n        q[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[1]), vandq_u8(mh, hbits.to_bytes(qh+2))));\n    }\n    inline void prepare1(int i) {\n        bits.prepare1(x[i].qs, bits.b);\n    }\n\n    inline float16x8_t new_block(int i) {\n        uint32_t aux32[4];\n        const uint32_t * s32 = (const uint32_t *)&x[4*i].d;\n        for (int k = 0; k < 4; ++k) {\n            aux32[k] = *s32; s32 += sizeof(block_q5_1)/4;\n            prepare1(4*i+k, bits.b + 2*k);\n        }\n        return vreinterpretq_f16_u8(vqtbl1q_u8(vld1q_u8((const uint8_t *)aux32), vreinterpretq_u8_u64(shuffle)));\n    }\n\n    HighBit5Legacy hbits;\n\n    const uint8x16_t mh = vdupq_n_u8(0x10);\n    const uint64x2_t shuffle = {0x0d0c090805040100, 0x0f0e0b0a07060302};\n\n};\n\ntemplate <typename Dequantizer, typename Q8>\ninline void sum_4(int i, Dequantizer& deq, const Q8& q8, const float16x4_t * sc16, float32x4_t * acc) {\n    for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n        auto pall = sum_4_blocks(deq.bits.b, q8.quant_data(iy, i));\n        auto scale = vcvt_f32_f16(sc16[iy]);\n        acc[iy] = vmlaq_f32(acc[iy], scale, vcvtq_f32_s32(pall));\n    }\n}\n\ntemplate <typename Dequantizer, typename Q8>\ninline void mul_mat_qX_Y_q8_Y(int n, Dequantizer& deq, Q8& q8, const DataInfo& info, int nrc_x) {\n    const int nb = n / QK4_1;\n\n    float16x4_t sc16[Q8::nrc_y];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        deq.new_row(ix);\n\n        float32x4_t acc[Q8::nrc_y];\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);\n\n        for (int i = 0; i < nb/4; ++i) {\n            q8.process_scales(i, deq, sc16, 
acc);\n            sum_4(i, deq, q8, sc16, acc);\n        }\n        for (int i = 4*(nb/4); i < nb; ++i) {\n            q8.process_1_block(i, deq, acc);\n        }\n\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            info.store(ix, iy, vaddvq_f32(acc[iy]));\n        }\n    }\n}\n\ntemplate <typename Dequantizer, typename Q8>\ninline void mul_mat_qX_Y_q8_Y_1(int n, Dequantizer& deq1, Dequantizer& deq2, Q8& q8, const DataInfo& info, int nrc_x) {\n    const int nb = n / QK4_1;\n\n    float16x4_t sc16[2];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        deq1.new_row(ix);\n        deq2.new_row(ix);\n\n        float32x4_t acc[2] = { vdupq_n_f32(0.f), vdupq_n_f32(0.f) };\n\n        for (int i = 0; i < nb/8; ++i) {\n            q8.process_scales(2*i+0, deq1, sc16+0, acc+0);\n            q8.process_scales(2*i+1, deq2, sc16+1, acc+1);\n            sum_4(2*i+0, deq1, q8, sc16+0, acc+0);\n            sum_4(2*i+1, deq2, q8, sc16+1, acc+1);\n        }\n        for (int i = 2*(nb/8); i < nb/4; ++i) {\n            q8.process_scales(i, deq1, sc16, acc);\n            sum_4(i, deq1, q8, sc16, acc);\n        }\n        for (int i = 4*(nb/4); i < nb; ++i) {\n            q8.process_1_block(i, deq1, acc);\n        }\n\n        info.store(ix, 0, vaddvq_f32(vaddq_f32(acc[0], acc[1])));\n    }\n}\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void IQK_NOINLINE mul_mat_qX_1_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    Q81<nrc_y> q8(info);\n    if constexpr (nrc_y == 1) {\n        Dequantizer deq1(vx, bx), deq2(vx, bx);\n        mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);\n    } else {\n        Dequantizer deq(vx, bx);\n        mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x);\n    }\n}\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void IQK_NOINLINE mul_mat_qX_0_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    Q80<nrc_y> q8(info);\n    if constexpr (nrc_y == 1) {\n        Dequantizer deq1(vx, 
bx), deq2(vx, bx);\n        mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);\n    } else {\n        Dequantizer deq(vx, bx);\n        mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x);\n    }\n}\n\ntemplate <typename Dequantizer>\nstatic void IQK_NOINLINE mul_mat_qX_1_q8_1_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    Dequantizer deq1(vx, bx), deq2(vx, bx);\n    Q81<1> q8(info);\n    mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);\n}\n\ntemplate <typename Dequantizer>\nstatic void IQK_NOINLINE mul_mat_qX_0_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    Dequantizer deq1(vx, bx), deq2(vx, bx);\n    Q80<1> q8(info);\n    mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);\n}\n\ntemplate <typename Dequantizer> void MulMat::set_functions(MulMat& m) {\n    if constexpr (std::is_same_v<Dequantizer, DequantizerQ40> || std::is_same_v<Dequantizer, DequantizerQ50> ||\n                  std::is_same_v<Dequantizer, DequantizerQ80>) {\n        m.funcs[0] = mul_mat_qX_0_q8_0<Dequantizer, 1>;\n        m.funcs[1] = mul_mat_qX_0_q8_0<Dequantizer, 2>;\n        m.funcs[2] = mul_mat_qX_0_q8_0<Dequantizer, 3>;\n        m.funcs[3] = mul_mat_qX_0_q8_0<Dequantizer, 4>;\n        m.funcs[4] = mul_mat_qX_0_q8_0<Dequantizer, 5>;\n        m.funcs[5] = mul_mat_qX_0_q8_0<Dequantizer, 6>;\n        m.funcs[6] = mul_mat_qX_0_q8_0<Dequantizer, 7>;\n        m.funcs[7] = mul_mat_qX_0_q8_0<Dequantizer, 8>;\n    }\n    else if constexpr (std::is_same_v<Dequantizer, DequantizerQ41> || std::is_same_v<Dequantizer, DequantizerQ51>) {\n        m.funcs[0] = mul_mat_qX_1_q8_1<Dequantizer, 1>;\n        m.funcs[1] = mul_mat_qX_1_q8_1<Dequantizer, 2>;\n        m.funcs[2] = mul_mat_qX_1_q8_1<Dequantizer, 3>;\n        m.funcs[3] = mul_mat_qX_1_q8_1<Dequantizer, 4>;\n        m.funcs[4] = mul_mat_qX_1_q8_1<Dequantizer, 5>;\n        m.funcs[5] = mul_mat_qX_1_q8_1<Dequantizer, 6>;\n        m.funcs[6] = mul_mat_qX_1_q8_1<Dequantizer, 7>;\n        
m.funcs[7] = mul_mat_qX_1_q8_1<Dequantizer, 8>;\n    }\n    else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2XXS> || std::is_same_v<Dequantizer, DequantizerIQ3XXS>) {\n        m.funcs[0] = mul_mat_qX_K_q8_K_IQXXS<1, Dequantizer>;\n        m.funcs[1] = mul_mat_qX_K_q8_K_IQXXS<2, Dequantizer>;\n        m.funcs[2] = mul_mat_qX_K_q8_K_IQXXS<3, Dequantizer>;\n        m.funcs[3] = mul_mat_qX_K_q8_K_IQXXS<4, Dequantizer>;\n        m.funcs[4] = mul_mat_qX_K_q8_K_IQXXS<5, Dequantizer>;\n        m.funcs[5] = mul_mat_qX_K_q8_K_IQXXS<6, Dequantizer>;\n        m.funcs[6] = mul_mat_qX_K_q8_K_IQXXS<7, Dequantizer>;\n        m.funcs[7] = mul_mat_qX_K_q8_K_IQXXS<8, Dequantizer>;\n    }\n    else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2S> ||\n                       std::is_same_v<Dequantizer, DequantizerIQ3S> ||\n                       std::is_same_v<Dequantizer, DequantizerIQ2XS>) {\n        m.funcs[0] = mul_mat_qX_K_q8_K_IQ<1, Dequantizer>;\n        m.funcs[1] = mul_mat_qX_K_q8_K_IQ<2, Dequantizer>;\n        m.funcs[2] = mul_mat_qX_K_q8_K_IQ<3, Dequantizer>;\n        m.funcs[3] = mul_mat_qX_K_q8_K_IQ<4, Dequantizer>;\n        m.funcs[4] = mul_mat_qX_K_q8_K_IQ<5, Dequantizer>;\n        m.funcs[5] = mul_mat_qX_K_q8_K_IQ<6, Dequantizer>;\n        m.funcs[6] = mul_mat_qX_K_q8_K_IQ<7, Dequantizer>;\n        m.funcs[7] = mul_mat_qX_K_q8_K_IQ<8, Dequantizer>;\n    }\n    else {\n        m.funcs[0] = mul_mat_qX_K_q8_K_T<1, Dequantizer>;\n        m.funcs[1] = mul_mat_qX_K_q8_K_T<2, Dequantizer>;\n        m.funcs[2] = mul_mat_qX_K_q8_K_T<3, Dequantizer>;\n        m.funcs[3] = mul_mat_qX_K_q8_K_T<4, Dequantizer>;\n        m.funcs[4] = mul_mat_qX_K_q8_K_T<5, Dequantizer>;\n        m.funcs[5] = mul_mat_qX_K_q8_K_T<6, Dequantizer>;\n        m.funcs[6] = mul_mat_qX_K_q8_K_T<7, Dequantizer>;\n        m.funcs[7] = mul_mat_qX_K_q8_K_T<8, Dequantizer>;\n    }\n}\n\nbool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int Ny) {\n    row_size_q8 = 
ggml_row_size(GGML_TYPE_Q8_K, ne00);\n\n    (void)Ny;\n    // Uncommenting the check below would disable iqk_mul_mat for matrix x vector multiplications (Ny == 1) of the listed IQ2/IQ3 types.\n    //if (Ny == 1 && (typeA == GGML_TYPE_IQ2_XXS || typeA == GGML_TYPE_IQ2_XS || typeA == GGML_TYPE_IQ2_S ||\n    //                typeA == GGML_TYPE_IQ3_XXS || typeA == GGML_TYPE_IQ3_S)) return false;\n\n    switch (typeA) {\n        case GGML_TYPE_Q2_K:\n            MulMat::set_functions<DequantizerQ2K>(m);\n            break;\n        case GGML_TYPE_Q3_K:\n            MulMat::set_functions<DequantizerQ3K>(m);\n            break;\n        case GGML_TYPE_Q4_K:\n            MulMat::set_functions<DequantizerQ4K>(m);\n            break;\n        case GGML_TYPE_Q5_K:\n            MulMat::set_functions<DequantizerQ5K>(m);\n            break;\n        case GGML_TYPE_Q6_K:\n            MulMat::set_functions<DequantizerQ6K>(m);\n            break;\n        case GGML_TYPE_IQ4_XS:\n            MulMat::set_functions<DequantizerIQ4XS>(m);\n            break;\n        case GGML_TYPE_IQ3_S:\n            MulMat::set_functions<DequantizerIQ3S>(m);\n            break;\n        case GGML_TYPE_IQ3_XXS:\n            MulMat::set_functions<DequantizerIQ3XXS>(m);\n            break;\n        case GGML_TYPE_IQ2_S:\n            MulMat::set_functions<DequantizerIQ2S>(m);\n            break;\n        case GGML_TYPE_IQ2_XS:\n            MulMat::set_functions<DequantizerIQ2XS>(m);\n            break;\n        case GGML_TYPE_IQ2_XXS:\n            MulMat::set_functions<DequantizerIQ2XXS>(m);\n            break;\n        case GGML_TYPE_Q4_0:\n            MulMat::set_functions<DequantizerQ40>(m);\n            row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);\n            break;\n        case GGML_TYPE_Q4_1:\n            MulMat::set_functions<DequantizerQ41>(m);\n            row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00);\n            break;\n        case GGML_TYPE_Q5_0:\n            MulMat::set_functions<DequantizerQ50>(m);\n            
row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);\n            break;\n        case GGML_TYPE_Q5_1:\n            MulMat::set_functions<DequantizerQ51>(m);\n            row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00);\n            break;\n        case GGML_TYPE_Q8_0:\n            MulMat::set_functions<DequantizerQ80>(m);\n            row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);\n            break;\n        default:\n            return false;\n    }\n    return true;\n}\n\n}\n\n#endif // __x86_64__ or __aarch64__\n"
  },
  {
    "path": "archive/third_party/llamafile/macros.h",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/macros.h\n// Copyright 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-\n// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi\n#pragma once\n\n#define MIN(X, Y) ((Y) > (X) ? (X) : (Y))\n#define MAX(X, Y) ((Y) < (X) ? (X) : (Y))\n#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))\n#define ROUNDUP(X, K) (((X) + (K) - 1) & -(K))\n#define ARRAYLEN(A) ((sizeof(A) / sizeof(*(A))) / ((unsigned)!(sizeof(A) % sizeof(*(A)))))\n"
  },
  {
    "path": "archive/third_party/llamafile/micros.h",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/micros.h\n// Copyright 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-\n// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi\n#pragma once\n\n#include <ctime>\n\n#ifndef _WIN32\n#include <unistd.h>\n#else\n#include <windows.h>\n#endif\n\n#ifdef _WIN32\nstatic long long GetQueryPerformanceFrequency() {\n    LARGE_INTEGER t;\n    QueryPerformanceFrequency(&t);\n    return t.QuadPart;\n}\nstatic long long GetQueryPerformanceCounter() {\n    LARGE_INTEGER t;\n    QueryPerformanceCounter(&t);\n    return t.QuadPart;\n}\n#endif\n\nstatic long long micros(void) {\n#ifndef _WIN32\n    struct timespec ts;\n    clock_gettime(CLOCK_REALTIME, &ts);\n    return ts.tv_sec * 1000000 + (ts.tv_nsec + 999) / 1000;\n#else\n    static long long timer_freq = GetQueryPerformanceFrequency();\n    static long long timer_start = GetQueryPerformanceCounter();\n    return ((GetQueryPerformanceCounter() - timer_start) * 1000000) / timer_freq;\n#endif\n}\n"
  },
  {
    "path": "archive/third_party/llamafile/numba.h",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/numba.h\n// Copyright 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#pragma once\n\ninline int rand32(void) {\n    static unsigned long long lcg = 1;\n    lcg *= 6364136223846793005;\n    lcg += 1442695040888963407;\n    return lcg >> 32;\n}\n\ninline int popcount(unsigned x) {\n    x = x - ((x >> 1) & 0x55555555);\n    x = ((x >> 2) & 0x33333333) + (x & 0x33333333);\n    x = (x + (x >> 4)) & 0x0F0F0F0F;\n    x = (x + (x >> 16));\n    return (x + (x >> 8)) & 0x0000003F;\n}\n\ninline int hamming(int x, int y) {\n    return popcount(x ^ y);\n}\n\ninline float float01(unsigned x) {  // (0,1)\n    return 1.f / 8388608 * ((x >> 9) + .5f);\n}\n\ninline float numba(void) {  // (-1,1)\n    return float01(rand32()) * 2.f - 1.f;\n}\n\ntemplate <typename T>\nvoid randomize(T* A, int n) {\n    for (int i = 0; i < n; ++i)\n        A[i] = numba();\n}\n\ntemplate <typename T>\nvoid randomize(int m, int n, T* A, int lda) {\n    for (int j = 0; j < n; ++j)\n        for (int i = 0; i < m; ++i)\n            A[lda * j + i] = numba();\n}\n\ntemplate <typename T, typename U>\nvoid broadcast(T* A, int n, U x) {\n    for (int i = 0; i < n; ++i)\n        A[i] = x;\n}\n\ntemplate <typename T, typename U>\nvoid broadcast(int m, int n, T* A, int lda, U x) {\n    for (int j = 0; j < n; ++j)\n        for (int i = 0; i < m; ++i)\n            A[lda * j + i] = x;\n}\n"
  },
  {
    "path": "archive/third_party/llamafile/sgemm.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/sgemm.cpp\n// Copyright 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-\n// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi\n//\n// Copyright 2024 Mozilla Foundation\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#if defined(KTRANSFORMERS_USE_NPU) && KTRANSFORMERS_USE_NPU\n        // use ARM version\n        #include \"sgemm_arm.cpp\"\n#else\n        // use x86 version\n        #include \"sgemm_x86.cpp\"\n#endif"
  },
  {
    "path": "archive/third_party/llamafile/sgemm.h",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/sgemm.h\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#pragma once\n#include <stdbool.h>\n#include <cstddef>\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\nstruct ggml_tensor;\nstruct ggml_compute_params;\n/*moonll old\nadd more params typeb...\n*/\n\n\nbool iqk_mul_mat(long, long, long,int, const void*, long, int, const void*, long,float*, long, int, int);\nbool iqk_mul_mat_zen4(long, long, long,int, const void*, long, int, const void*, long,float*, long, int, int);\nbool iqk_mul_mat_arm82(long, long, long,int, const void*, long, int, const void*, long,float*, long, int, int);\n\n\nbool iqk_mul_mat_moe(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);\nbool iqk_mul_mat_moe_zen4(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);\nbool iqk_mul_mat_moe_arm82(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);\nbool iqk_mul_mat_moe_unsupported(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);\n\nbool llamafile_sgemm(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\nbool llamafile_mixmul(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\nsize_t llamafile_mixmul_needs(const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*);\n\nbool llamafile_sgemm_unsupported(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\nbool llamafile_sgemm_amd_avx(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\nbool llamafile_sgemm_amd_fma(long, long, long, const void*, long, const 
void*, long, void*, long, int, int, int, int, int, int, int);\nbool llamafile_sgemm_amd_avx2(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\nbool llamafile_sgemm_amd_avxvnni(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\nbool llamafile_sgemm_amd_avx512f(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\nbool llamafile_sgemm_amd_zen4(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\nbool llamafile_sgemm_arm80(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\nbool llamafile_sgemm_arm82(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\n\nbool llamafile_mixmul_unsupported(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\nbool llamafile_mixmul_amd_avx(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\nbool llamafile_mixmul_amd_fma(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\nbool llamafile_mixmul_amd_avx2(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\nbool llamafile_mixmul_amd_avxvnni(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\nbool llamafile_mixmul_amd_avx512f(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\nbool llamafile_mixmul_amd_zen4(const struct ggml_compute_params*, const struct 
ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\nbool llamafile_mixmul_arm80(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\nbool llamafile_mixmul_arm82(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\nbool llamafile_mixmul_iqk(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);\n\n#ifdef __cplusplus\n}\n#endif\n"
  },
  {
    "path": "archive/third_party/llamafile/sgemm_arm.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/sgemm.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-\n// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi\n//\n// Copyright 2024 Mozilla Foundation\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include \"sgemm.h\"\n// #include <cosmo.h>\n// #include <cpuid.h>\n// #include <libc/sysv/consts/hwcap.h>\n#include <stdio.h>\n// #include <sys/auxv.h>\n#include <cassert>\n// #include \"llamafile.h\"\n\nstatic const struct GemmFuncs {\n    bool (*sgemm)(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\n    bool (*mixmul)(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\n    bool (*iqk_mixmul)(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);\n    // typeof(llamafile_sgemm)* sgemm;\n    // typeof(llamafile_mixmul)* mixmul;\n    // typeof(llamafile_mixmul_iqk)* iqk_mixmul = iqk_mul_mat_moe_unsupported;\n    GemmFuncs() {\n#if defined(__x86_64__) || defined(_M_X64)\n        // if (X86_HAVE(AVX)) {\n        //     if (X86_HAVE(FMA)) {\n        //         if (X86_HAVE(AVX2)) {\n        //             if (X86_HAVE(AVX512F)) {\n        //          
       if (X86_HAVE(AVX512VL) &&     //\n        //                     X86_HAVE(AVX512BW) &&     //\n        //                     X86_HAVE(AVX512DQ) &&     //\n        //                     X86_HAVE(AVX512_VNNI) &&  //\n        //                     X86_HAVE(AVX512_BF16)) {\n        //                     // AMD Zen4+ (2023-)\n        //                     sgemm = llamafile_sgemm_amd_zen4;\n        //                     mixmul = llamafile_mixmul_amd_zen4;\n        //                     iqk_mixmul = iqk_mul_mat_moe_zen4;\n        //                 } else {\n        //                     // Intel Xeon Skylake+ (2015-)\n        //                     sgemm = llamafile_sgemm_amd_avx512f;\n        //                     mixmul = llamafile_mixmul_amd_avx512f;\n        //                     iqk_mixmul = iqk_mul_mat_moe;\n        //                 }\n        //             } else if (X86_HAVE(AVXVNNI)) {\n        //                 // Intel Alderlake (2021-)\n        //                 sgemm = llamafile_sgemm_amd_avxvnni;\n        //                 mixmul = llamafile_mixmul_amd_avxvnni;\n        //                 iqk_mixmul = iqk_mul_mat_moe;\n        //             } else {\n        //                 // Intel Haswell/Broadwell/Skylake (2013-2020)\n        //                 // AMD Excavator (2015-2022)\n        //                 sgemm = llamafile_sgemm_amd_avx2;\n        //                 mixmul = llamafile_mixmul_amd_avx2;\n        //                 if (X86_HAVE(F16C))\n        //                     iqk_mixmul = iqk_mul_mat_moe;\n        //             }\n        //         } else {\n        //             // AMD Piledriver (2011-2014)\n        //             sgemm = llamafile_sgemm_amd_fma;\n        //             mixmul = llamafile_mixmul_amd_fma;\n        //             if (X86_HAVE(F16C))\n        //                 iqk_mixmul = iqk_mul_mat_moe;\n        //         }\n        //     } else {\n        //         // Intel Sandybridge/Ivybridge 
(2010-2012)\n        //         // AMD Bulldozer (2011)\n        //         sgemm = llamafile_sgemm_amd_avx;\n        //         mixmul = llamafile_mixmul_amd_avx;\n        //     }\n        // } else {\n        //     // AMD K8/Barcelona (2003-2010)\n        //     // Intel Core/Nehalem (2006-2009)\n        //     sgemm = llamafile_sgemm_unsupported;\n        //     mixmul = llamafile_mixmul_unsupported;\n        // }\n\n#if defined(__AVX__)\n#if defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)))\n#if defined(__AVX2__)\n#if defined(__AVX512F__)\n#if defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) && defined(__AVX512VNNI__) && defined(__AVX512BF16__)\n        // AMD Zen4+ (2023-)\n        sgemm = llamafile_sgemm_amd_zen4;\n        mixmul = llamafile_mixmul_amd_zen4;\n        iqk_mixmul = iqk_mul_mat_moe_zen4;\n#else\n        // Intel Xeon Skylake+ (2015-)\n        sgemm = llamafile_sgemm_amd_avx512f;\n        mixmul = llamafile_mixmul_amd_avx512f;\n        iqk_mixmul = iqk_mul_mat_moe;\n#endif\n#elif defined(__AVXVNNI__)\n        // Intel Alderlake (2021-)\n        sgemm = llamafile_sgemm_amd_avxvnni;\n        mixmul = llamafile_mixmul_amd_avxvnni;\n        iqk_mixmul = iqk_mul_mat_moe;\n#else\n        // Intel Haswell/Broadwell/Skylake (2013-2020)\n        // AMD Excavator (2015-2022)\n        sgemm = llamafile_sgemm_amd_avx2;\n        mixmul = llamafile_mixmul_amd_avx2;\n#if defined(__F16C__)\n        iqk_mixmul = iqk_mul_mat_moe;\n#endif\n#endif\n#else\n        // AMD Piledriver (2011-2014)\n        sgemm = llamafile_sgemm_amd_fma;\n        mixmul = llamafile_mixmul_amd_fma;\n#if defined(__F16C__)\n        iqk_mixmul = iqk_mul_mat_moe;\n#endif\n#endif\n#else\n        // Intel Sandybridge/Ivybridge (2010-2012)\n        // AMD Bulldozer (2011)\n        sgemm = llamafile_sgemm_amd_avx;\n        mixmul = llamafile_mixmul_amd_avx;\n#endif\n#else\n        // AMD K8/Barcelona (2003-2010)\n        // Intel 
Core/Nehalem (2006-2009)\n        sgemm = llamafile_sgemm_unsupported;\n        mixmul = llamafile_mixmul_unsupported;\n#endif\n\n#elif defined(__aarch64__)\n//        long hwcap = getauxval(AT_HWCAP);\n//        if ((hwcap & HWCAP_FPHP) &&     // fp16 scalar isa (ID_AA64PFR0_EL1.FP == 1)\n//            (hwcap & HWCAP_ASIMDHP) &&  // fp16 vector isa (ID_AA64PFR0_EL1.AdvSIMD == 1)\n//            (hwcap & HWCAP_ASIMDDP)) {  // dotprod isa (ID_AA64ISAR0_EL1.DP == 1)\n//            // e.g. Apple M1, Raspberry Pi 5\n//            sgemm = llamafile_sgemm_arm82;\n//            mixmul = llamafile_mixmul_arm82;\n//            iqk_mixmul = iqk_mul_mat_moe_arm82;\n//        } else {\n            // ARM64 baseline ISA\n            sgemm = llamafile_sgemm_arm80;\n            mixmul = llamafile_mixmul_arm80;\n//        }\n#else\n        sgemm = llamafile_sgemm_unsupported;\n        mixmul = llamafile_mixmul_unsupported;\n#endif\n    }\n} funcs;\n\n/**\n * Performs optimized matrix multiplication on CPU.\n *\n * This subroutine may compute C = Aᵀ * B with column major ordering.\n * Despite its name, this isn't a generalized implementation. 
Work is\n * only performed when a handwritten kernel is written and available.\n * Otherwise the caller should fall back to a general matmul routine.\n *\n * @param m is rows in `A` and `C`\n * @param n is cols in `B` and `C`\n * @param k is cols in `A` and rows in `B`\n * @param A is first input matrix (always transposed)\n * @param lda is row stride of `A`\n * @param B is second input matrix (never transposed)\n * @param ldb is row stride of `B`\n * @param C is input/output array of output matrices\n * @param ldc is row stride of `C`\n * @param ith is thread id (must be less than `nth`)\n * @param nth is number of threads (must be greater than zero)\n * @param task is GGML task type\n * @param Atype is GGML data type of `A`\n * @param Btype is GGML data type of `B`\n * @param Ctype is GGML data type of `C`\n * @param precision may be used to control the internal compute type\n * @return true if this function was able to service the matmul request\n */\nbool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {\n    return funcs.sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, task, Atype, Btype, Ctype,\n                       precision);\n}\n\n/**\n * Performs \"mixture of experts\" tensor multiplication on CPU.\n */\nbool llamafile_mixmul(const ggml_compute_params* params, const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan, ggml_tensor* result) {\n    return funcs.mixmul(params, weights, thought, plan, result);\n}\n\nbool llamafile_mixmul_iqk(long Nx, long Ny, long ne00, int ne11, int typeA, const void* A, const void* B, float* C, long nb1, long nb2, const void* vrow_mapping, int ith, int nth) {\n    return funcs.iqk_mixmul(Nx, Ny, ne00, ne11, typeA, A, B, C, nb1, nb2, vrow_mapping, ith, nth);\n}\n"
  },
  {
    "path": "archive/third_party/llamafile/sgemm_x86.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/sgemm.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-\n// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi\n//\n// Copyright 2024 Mozilla Foundation\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include \"sgemm.h\"\n// #include <cosmo.h>\n// #include <cpuid.h>\n// #include <libc/sysv/consts/hwcap.h>\n#include <stdio.h>\n// #include <sys/auxv.h>\n#include <cassert>\n// #include \"llamafile.h\"\n\nstatic const struct GemmFuncs {\n    bool (*sgemm)(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\n    bool (*mixmul)(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\n    bool (*iqk_mixmul)(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);\n    // typeof(llamafile_sgemm)* sgemm;\n    // typeof(llamafile_mixmul)* mixmul;\n    // typeof(llamafile_mixmul_iqk)* iqk_mixmul = iqk_mul_mat_moe_unsupported;\n    GemmFuncs() {\n#if defined(__x86_64__) || defined(_M_X64)\n        // if (X86_HAVE(AVX)) {\n        //     if (X86_HAVE(FMA)) {\n        //         if (X86_HAVE(AVX2)) {\n        //             if (X86_HAVE(AVX512F)) {\n        //          
       if (X86_HAVE(AVX512VL) &&     //\n        //                     X86_HAVE(AVX512BW) &&     //\n        //                     X86_HAVE(AVX512DQ) &&     //\n        //                     X86_HAVE(AVX512_VNNI) &&  //\n        //                     X86_HAVE(AVX512_BF16)) {\n        //                     // AMD Zen4+ (2023-)\n        //                     sgemm = llamafile_sgemm_amd_zen4;\n        //                     mixmul = llamafile_mixmul_amd_zen4;\n        //                     iqk_mixmul = iqk_mul_mat_moe_zen4;\n        //                 } else {\n        //                     // Intel Xeon Skylake+ (2015-)\n        //                     sgemm = llamafile_sgemm_amd_avx512f;\n        //                     mixmul = llamafile_mixmul_amd_avx512f;\n        //                     iqk_mixmul = iqk_mul_mat_moe;\n        //                 }\n        //             } else if (X86_HAVE(AVXVNNI)) {\n        //                 // Intel Alderlake (2021-)\n        //                 sgemm = llamafile_sgemm_amd_avxvnni;\n        //                 mixmul = llamafile_mixmul_amd_avxvnni;\n        //                 iqk_mixmul = iqk_mul_mat_moe;\n        //             } else {\n        //                 // Intel Haswell/Broadwell/Skylake (2013-2020)\n        //                 // AMD Excavator (2015-2022)\n        //                 sgemm = llamafile_sgemm_amd_avx2;\n        //                 mixmul = llamafile_mixmul_amd_avx2;\n        //                 if (X86_HAVE(F16C))\n        //                     iqk_mixmul = iqk_mul_mat_moe;\n        //             }\n        //         } else {\n        //             // AMD Piledriver (2011-2014)\n        //             sgemm = llamafile_sgemm_amd_fma;\n        //             mixmul = llamafile_mixmul_amd_fma;\n        //             if (X86_HAVE(F16C))\n        //                 iqk_mixmul = iqk_mul_mat_moe;\n        //         }\n        //     } else {\n        //         // Intel Sandybridge/Ivybridge 
(2010-2012)\n        //         // AMD Bulldozer (2011)\n        //         sgemm = llamafile_sgemm_amd_avx;\n        //         mixmul = llamafile_mixmul_amd_avx;\n        //     }\n        // } else {\n        //     // AMD K8/Barcelona (2003-2010)\n        //     // Intel Core/Nehalem (2006-2009)\n        //     sgemm = llamafile_sgemm_unsupported;\n        //     mixmul = llamafile_mixmul_unsupported;\n        // }\n\n#if defined(__AVX__)\n#if defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)))\n#if defined(__AVX2__)\n#if defined(__AVX512F__)\n#if defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) && defined(__AVX512VNNI__) && defined(__AVX512BF16__)\n        // AMD Zen4+ (2023-)\n        sgemm = llamafile_sgemm_amd_zen4;\n        mixmul = llamafile_mixmul_amd_zen4;\n        iqk_mixmul = iqk_mul_mat_moe_zen4;\n#else\n        // Intel Xeon Skylake+ (2015-)\n        sgemm = llamafile_sgemm_amd_avx512f;\n        mixmul = llamafile_mixmul_amd_avx512f;\n        iqk_mixmul = iqk_mul_mat_moe;\n#endif\n#elif defined(__AVXVNNI__)\n        // Intel Alderlake (2021-)\n        sgemm = llamafile_sgemm_amd_avxvnni;\n        mixmul = llamafile_mixmul_amd_avxvnni;\n        iqk_mixmul = iqk_mul_mat_moe;\n#else\n        // Intel Haswell/Broadwell/Skylake (2013-2020)\n        // AMD Excavator (2015-2022)\n        sgemm = llamafile_sgemm_amd_avx2;\n        mixmul = llamafile_mixmul_amd_avx2;\n#if defined(__F16C__)\n        iqk_mixmul = iqk_mul_mat_moe;\n#endif\n#endif\n#else\n        // AMD Piledriver (2011-2014)\n        sgemm = llamafile_sgemm_amd_fma;\n        mixmul = llamafile_mixmul_amd_fma;\n#if defined(__F16C__)\n        iqk_mixmul = iqk_mul_mat_moe;\n#endif\n#endif\n#else\n        // Intel Sandybridge/Ivybridge (2010-2012)\n        // AMD Bulldozer (2011)\n        sgemm = llamafile_sgemm_amd_avx;\n        mixmul = llamafile_mixmul_amd_avx;\n#endif\n#else\n        // AMD K8/Barcelona (2003-2010)\n        // Intel 
Core/Nehalem (2006-2009)\n        sgemm = llamafile_sgemm_unsupported;\n        mixmul = llamafile_mixmul_unsupported;\n#endif\n\n#elif defined(__aarch64__)\n        long hwcap = getauxval(AT_HWCAP);\n        if ((hwcap & HWCAP_FPHP) &&     // fp16 scalar isa (ID_AA64PFR0_EL1.FP == 1)\n            (hwcap & HWCAP_ASIMDHP) &&  // fp16 vector isa (ID_AA64PFR0_EL1.AdvSIMD == 1)\n            (hwcap & HWCAP_ASIMDDP)) {  // dotprod isa (ID_AA64ISAR0_EL1.DP == 1)\n            // e.g. Apple M1, Raspberry Pi 5\n            sgemm = llamafile_sgemm_arm82;\n            mixmul = llamafile_mixmul_arm82;\n            iqk_mixmul = iqk_mul_mat_moe_arm82;\n        } else {\n            // ARM64 baseline ISA\n            sgemm = llamafile_sgemm_arm80;\n            mixmul = llamafile_mixmul_arm80;\n        }\n#else\n        sgemm = llamafile_sgemm_unsupported;\n        mixmul = llamafile_mixmul_unsupported;\n#endif\n    }\n} funcs;\n\n/**\n * Performs optimized matrix multiplication on CPU.\n *\n * This subroutine may compute C = Aᵀ * B with column major ordering.\n * Despite its name, this isn't a generalized implementation. 
Work is\n * only performed when a handwritten kernel is written and available.\n * Otherwise the caller should fall back to a general matmul routine.\n *\n * @param m is rows in `A` and `C`\n * @param n is cols in `B` and `C`\n * @param k is cols in `A` and rows in `B`\n * @param A is first input matrix (always transposed)\n * @param lda is row stride of `A`\n * @param B is second input matrix (never transposed)\n * @param ldb is row stride of `B`\n * @param C is input/output array of output matrices\n * @param ldc is row stride of `C`\n * @param ith is thread id (must be less than `nth`)\n * @param nth is number of threads (must be greater than zero)\n * @param task is GGML task type\n * @param Atype is GGML data type of `A`\n * @param Btype is GGML data type of `B`\n * @param Ctype is GGML data type of `C`\n * @param precision may be used to control the internal compute type\n * @return true if this function was able to service the matmul request\n */\nbool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {\n    return funcs.sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, task, Atype, Btype, Ctype,\n                       precision);\n}\n\n/**\n * Performs \"mixture of experts\" tensor multiplication on CPU.\n */\nbool llamafile_mixmul(const ggml_compute_params* params, const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan, ggml_tensor* result) {\n    return funcs.mixmul(params, weights, thought, plan, result);\n}\n\nbool llamafile_mixmul_iqk(long Nx, long Ny, long ne00, int ne11, int typeA, const void* A, const void* B, float* C, long nb1, long nb2, const void* vrow_mapping, int ith, int nth) {\n    return funcs.iqk_mixmul(Nx, Ny, ne00, ne11, typeA, A, B, C, nb1, nb2, vrow_mapping, ith, nth);\n}\n"
  },
  {
    "path": "archive/third_party/llamafile/tinyblas_cpu.h",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu.h\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-\n// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi\n//\n// Copyright 2024 Mozilla Foundation\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n//\n//\n//                                ██████╗ ██╗   █████╗ ██████╗\n//         ██████╗██╗██╗ ██╗██═██╗██╔══██╗██║  ██╔══██╗██╔═══╝\n//         ╚═██╔═╝██║███▄██║██ ██║██████╔╝██║  ███████║██████╗\n//           ██║  ██║██▀███║╚███╔╝██╔══██╗██║  ██╔══██║╔═══██║\n//           ██║  ██║██║ ██║ ███║ ██████╔╝████╗██║  ██║██████║\n//           ╚═╝  ╚═╝╚═╝ ╚═╝ ╚══╝ ╚═════╝ ╚═══╝╚═╝  ╚═╝╚═════╝\n//\n//                   BASIC LINEAR ALGEBRA SUBPROGRAMS\n//\n//\n// This file implements multithreaded CPU matrix multiplication for the\n// common contiguous use case C = Aᵀ * B. These kernels are designed to\n// have excellent performance[1] for matrices that fit in the CPU cache\n// without imposing any overhead such as cache filling or malloc calls.\n//\n// This implementation does not guarantee any upper bound with rounding\n// errors, which grow along with k. Our goal's to maximally exploit the\n// hardware for performance, and then use whatever resources remain for\n// improving numerical accuracy.\n//\n// [1] J. 
Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].\n//     Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].\n\n#pragma once\n\n#include \"llama.cpp/ggml-impl.h\"\n#include \"llama.cpp/ggml-quants.h\"\n// #include \"log.h\"\n#include \"flags.h\"\n#include \"sgemm.h\"\n// #include <cosmo.h>\n\n#pragma GCC diagnostic ignored \"-Wpedantic\"\n#pragma GCC diagnostic ignored \"-Wignored-attributes\"\n\n#define ROW_ALIGN 64\n#define MATRIX_ALIGN 4096\n#define MAX_ALIGN 4096\n\n#ifdef _MSC_VER\n#define NOINLINE __declspec(noinline)\n#else\n#define NOINLINE __attribute__((__noinline__))\n#endif\n\n#if defined(__ARM_NEON) || defined(__AVX512F__)\n#define VECTOR_REGISTERS 32\n#else\n#define VECTOR_REGISTERS 16\n#endif\n\n#if 0\n#define NOT_SUPPORTED tinyBLAS_not_supported(__FILE__, __LINE__)\n#else\n#define NOT_SUPPORTED false\n#endif\n#define WANT_QUANTIZATION false\n\nnamespace {\n\nbool tinyBLAS_not_supported(const char* file, int line) {\n    // tinylogf(\"%s:%d: tinyBLAS not supported\\n\", file, line);\n    return false;\n}\n\ninline float unhalf(ggml_fp16_t d) {\n    return GGML_FP16_TO_FP32(d);\n}\ninline float unhalf(ggml_bf16_t d) {\n    return GGML_BF16_TO_FP32(d);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n// MATRIX MEMORY INDEXING\n\n#define NCA 1\n#define NCB 2\n#define NCC 4\n\n#define INDEX(A, lda, j, i) (CONFIG & NC##A ? 
((T##A**)A)[j] + i : A + lda * (j) + i)\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n// GGML TYPE TRAITS\n\ntemplate <typename T>\nstruct ggml_type_trait;\ntemplate <>\nstruct ggml_type_trait<float> {\n    static constexpr ggml_type id = GGML_TYPE_F32;\n};\ntemplate <>\nstruct ggml_type_trait<ggml_bf16_t> {\n    static constexpr ggml_type id = GGML_TYPE_BF16;\n};\ntemplate <>\nstruct ggml_type_trait<ggml_fp16_t> {\n    static constexpr ggml_type id = GGML_TYPE_F16;\n};\ntemplate <>\nstruct ggml_type_trait<block_q8_0> {\n    static constexpr ggml_type id = GGML_TYPE_Q8_0;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n// VECTORIZED ARITHMETIC OPERATIONS\n\n#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)\ninline __m128 add(__m128 x, __m128 y) {\n    return _mm_add_ps(x, y);\n}\ninline __m128 sub(__m128 x, __m128 y) {\n    return _mm_sub_ps(x, y);\n}\ninline __m128 mul(__m128 x, __m128 y) {\n    return _mm_mul_ps(x, y);\n}\n#endif  // __SSE__\n\n#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)\ninline __m256 add(__m256 x, __m256 y) {\n    return _mm256_add_ps(x, y);\n}\ninline __m256 sub(__m256 x, __m256 y) {\n    return _mm256_sub_ps(x, y);\n}\ninline __m256 mul(__m256 x, __m256 y) {\n    return _mm256_mul_ps(x, y);\n}\n#endif  // __AVX__\n\n#if defined(__AVX512F__)\ninline __m512 add(__m512 x, __m512 y) {\n    return _mm512_add_ps(x, y);\n}\ninline __m512 sub(__m512 x, __m512 y) {\n    return _mm512_sub_ps(x, y);\n}\ninline __m512 mul(__m512 x, __m512 y) {\n    return _mm512_mul_ps(x, y);\n}\n#endif  // __AVX512F__\n\n#if defined(__ARM_NEON)\ninline float32x4_t add(float32x4_t x, float32x4_t y) {\n    return vaddq_f32(x, y);\n}\ninline float32x4_t sub(float32x4_t x, float32x4_t y) {\n    return vsubq_f32(x, y);\n}\ninline float32x4_t mul(float32x4_t x, float32x4_t y) {\n    return 
vmulq_f32(x, y);\n}\n#endif  // __ARM_NEON\n\n#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)\ninline float16x8_t add(float16x8_t x, float16x8_t y) {\n    return vaddq_f16(x, y);\n}\ninline float16x8_t sub(float16x8_t x, float16x8_t y) {\n    return vsubq_f16(x, y);\n}\ninline float16x8_t mul(float16x8_t x, float16x8_t y) {\n    return vmulq_f16(x, y);\n}\n#endif  // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n// VECTORIZED FUSED MULTIPLY ADD\n\n/**\n * Computes a * b + c.\n */\ntemplate <typename T, typename U>\ninline U madd(T a, T b, U c) {\n    return add(mul(a, b), c);\n}\n\n/**\n * Computes a * b + c with error correction.\n *\n * @see W. Kahan, \"Further remarks on reducing truncation errors,\"\n *    Communications of the ACM, vol. 8, no. 1, p. 40, Jan. 1965,\n *    doi: 10.1145/363707.363723.\n */\ntemplate <typename T, typename U>\ninline U madder(T a, T b, U c, U* e) {\n    U y = sub(mul(a, b), *e);\n    U t = add(c, y);\n    *e = sub(sub(t, c), y);\n    return t;\n}\n\n#ifdef __ARM_NEON\ninline float32x4_t badder(float32x4_t a, float b, float32x4_t c, float32x4_t* e) {\n    float32x4_t y = sub(vmulq_n_f32(a, b), *e);\n    float32x4_t t = add(c, y);\n    *e = sub(sub(t, c), y);\n    return t;\n}\n#endif\n\n#if defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)))\n#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)\ntemplate <>\ninline __m256 madd(__m256 a, __m256 b, __m256 c) {\n    return _mm256_fmadd_ps(a, b, c);\n}\n#endif\n#if defined(__AVX512F__)\ntemplate <>\ninline __m512 madd(__m512 a, __m512 b, __m512 c) {\n    return _mm512_fmadd_ps(a, b, c);\n}\n#endif\n#endif\n\n#if defined(__ARM_FEATURE_FMA)\ntemplate <>\ninline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {\n    return vfmaq_f32(c, a, b);\n}\n#if 0  // todo: this specialization chops gcc 12.3 performance in half\n#if 
defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) && 0\ntemplate <>\ninline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {\n    return vfmaq_f16(c, b, a);\n}\n#endif\n#endif\n#endif\n\n#if defined(__AVX512BF16__)\ntemplate <>\ninline __m512 madd(__m512bh x, __m512bh y, __m512 z) {\n    return _mm512_dpbf16_ps(z, x, y);\n}\ntemplate <>\ninline __m512 madder(__m512bh x, __m512bh y, __m512 z, __m512* _) {\n    return _mm512_dpbf16_ps(z, x, y);\n}\n#endif\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n// VECTORIZED HORIZONTAL SUM\n\n#if defined(__ARM_NEON)\ninline float hsum(float32x4_t x) {\n    return vaddvq_f32(x);\n}\n#endif  // __ARM_NEON\n\n#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)\ninline float hsum(float16x8_t x) {\n    // todo: this works great on clang but it produces terrible code on gcc 12.3\n    return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)), vcvt_f32_f16(vget_high_f16(x))));\n}\n#endif  // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC\n\n#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)\ninline float hsum(__m128 x) {\n#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)\n    x = _mm_add_ps(x, _mm_movehl_ps(x, x));\n    x = _mm_add_ss(x, _mm_movehdup_ps(x));\n#else\n    __m128 t;\n    t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));\n    x = _mm_add_ps(x, t);\n    t = _mm_movehl_ps(t, x);\n    x = _mm_add_ss(x, t);\n#endif\n    return _mm_cvtss_f32(x);\n}\n#endif\n\n#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)\ninline float hsum(__m256 x) {\n    return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x)));\n}\n#endif  // __AVX__\n\n#if defined(__AVX512F__)\ninline float hsum(__m512 x) {\n    return _mm512_reduce_add_ps(x);\n}\n#endif  // 
__AVX512F__\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n// VECTORIZED MEMORY LOADING\n\ntemplate <typename T, typename U>\nT load(const U*);\n\ntemplate <>\ninline float load(const float* p) {\n    return *p;\n}\ntemplate <>\ninline float load(const ggml_fp16_t* p) {\n    return unhalf(*p);\n}\ntemplate <>\ninline float load(const ggml_bf16_t* p) {\n    return unhalf(*p);\n}\n\n#if defined(__ARM_NEON)\ntemplate <>\ninline float32x4_t load(const float* p) {\n    return vld1q_f32(p);\n}\ntemplate <>\ninline float32x4_t load(const ggml_bf16_t* p) {\n    return vreinterpretq_f32_u32(vshll_n_u16(vld1_u16((const unsigned short*)p), 16));\n}\n#if !defined(_MSC_VER)\ntemplate <>\ninline float16x8_t load(const ggml_fp16_t* p) {\n    return vld1q_f16((const float16_t*)p);\n}\ntemplate <>\ninline float32x4_t load(const ggml_fp16_t* p) {\n    return vcvt_f32_f16(vld1_f16((const float16_t*)p));\n}\n#endif  // _MSC_VER\n#endif  // __ARM_NEON\n\n#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)\ntemplate <>\ninline __m128 load(const float* p) {\n    return _mm_loadu_ps(p);\n}\n#endif  // __SSE__\n\n#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)\ntemplate <>\ninline __m256 load(const float* p) {\n    return _mm256_loadu_ps(p);\n}\n#endif  // __AVX__\n\n#if defined(__AVX2__) || defined(__AVX512F__)\ntemplate <>\ninline __m256 load(const ggml_bf16_t* p) {\n    return _mm256_castsi256_ps(\n        _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i*)p)), 16));\n}\n#endif  // __AVX2__\n\n#if defined(__F16C__)\ntemplate <>\ninline __m256 load(const ggml_fp16_t* p) {\n    return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)p));\n}\n#endif  // __F16C__\n\n#if defined(__AVX512F__)\ntemplate <>\ninline __m512 load(const float* p) {\n    return _mm512_loadu_ps(p);\n}\ntemplate <>\ninline __m512 load(const ggml_fp16_t* p) {\n    return 
_mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)p));\n}\ntemplate <>\ninline __m512 load(const ggml_bf16_t* p) {\n    return _mm512_castsi512_ps(\n        _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i*)p)), 16));\n}\n#endif  // __AVX512F__\n\n#if defined(__AVX512BF16__)\ntemplate <>\ninline __m512bh load(const ggml_bf16_t* p) {\n    return (__m512bh)_mm512_loadu_ps((const float*)p);\n}\ntemplate <>\ninline __m512bh load(const float* p) {\n    return _mm512_cvtne2ps_pbh(_mm512_loadu_ps(p + 16), _mm512_loadu_ps(p));\n}\n#endif  // __AVX512BF16__\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n// FLOATING POINT OUTPUT STREAMING\n\ninline void store(float* p, float f) {\n    *p = f;\n}\n\ninline void store(ggml_fp16_t* p, float f) {\n    *p = GGML_FP32_TO_FP16(f);\n}\n\ninline void store(ggml_bf16_t* p, float f) {\n    *p = GGML_FP32_TO_BF16(f);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n// FLOATING POINT MATRIX MULTIPLICATION\n\ntemplate <int CONFIG, int KN, typename D, typename V, typename TA, typename TB, typename TC>\nclass tinyBLAS {\n   public:\n    tinyBLAS(long k, const TA* A, long lda, const TB* B, long ldb, TC* C, long ldc, int ith, int nth)\n        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {\n    }\n\n    void matmul(long m, long n, int task) {\n        if (task == GGML_TASK_TYPE_COMPUTE)\n            mnpack(0, m, 0, n);\n    }\n\n   private:\n    NOINLINE void mnpack(long m0, long m, long n0, long n) {\n        long mc, nc, mp, np;\n\n#if VECTOR_REGISTERS == 32\n        if (!FLAG_precise) {\n            switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {\n                case 0x55:\n                    mc = 5;\n                    nc = 5;\n                    gemm<5, 5, false>(m0, m, n0, n);\n                    break;\n                case 0x54:\n                case 0x53:\n 
               case 0x52:\n                case 0x45:\n                case 0x44:\n                case 0x43:\n                case 0x42:\n                case 0x35:\n                case 0x34:\n                case 0x33:\n                case 0x32:\n                case 0x25:\n                case 0x24:\n                case 0x23:\n                case 0x22:\n                    mc = 2;\n                    nc = 2;\n                    gemm<2, 2, false>(m0, m, n0, n);\n                    break;\n                case 0x51:\n                case 0x41:\n                case 0x31:\n                case 0x21:\n                    mc = 2;\n                    nc = 1;\n                    gemm<2, 1, false>(m0, m, n0, n);\n                    break;\n                case 0x15:\n                case 0x14:\n                case 0x13:\n                case 0x12:\n                    mc = 1;\n                    nc = 2;\n                    gemm<1, 2, false>(m0, m, n0, n);\n                    break;\n                case 0x11:\n                    mc = 1;\n                    nc = 1;\n                    gemm<1, 1, false>(m0, m, n0, n);\n                    break;\n                default:\n                    return;\n            }\n        } else {\n            switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 3)) {\n                case 0x43:\n                    mc = 4;\n                    nc = 3;\n                    gemm<4, 3, true>(m0, m, n0, n);\n                    break;\n                case 0x42:\n                case 0x33:\n                case 0x32:\n                case 0x23:\n                case 0x22:\n                    mc = 2;\n                    nc = 2;\n                    gemm<2, 2, true>(m0, m, n0, n);\n                    break;\n                case 0x41:\n                case 0x31:\n                case 0x21:\n                    mc = 2;\n                    nc = 1;\n                    gemm<2, 1, true>(m0, m, n0, n);\n                    break;\n     
           case 0x13:\n                case 0x12:\n                    mc = 1;\n                    nc = 2;\n                    gemm<1, 2, true>(m0, m, n0, n);\n                    break;\n                case 0x11:\n                    mc = 1;\n                    nc = 1;\n                    gemm<1, 1, true>(m0, m, n0, n);\n                    break;\n                default:\n                    return;\n            }\n        }\n#endif\n\n#if VECTOR_REGISTERS == 16\n        if (!FLAG_precise) {\n            switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 3)) {\n                case 0x43:\n                    mc = 4;\n                    nc = 3;\n                    gemm<4, 3, false>(m0, m, n0, n);\n                    break;\n                case 0x42:\n                case 0x33:\n                case 0x32:\n                case 0x23:\n                case 0x22:\n                    mc = 2;\n                    nc = 2;\n                    gemm<2, 2, false>(m0, m, n0, n);\n                    break;\n                case 0x41:\n                case 0x31:\n                case 0x21:\n                    mc = 2;\n                    nc = 1;\n                    gemm<2, 1, false>(m0, m, n0, n);\n                    break;\n                case 0x13:\n                case 0x12:\n                    mc = 1;\n                    nc = 2;\n                    gemm<1, 2, false>(m0, m, n0, n);\n                    break;\n                case 0x11:\n                    mc = 1;\n                    nc = 1;\n                    gemm<1, 1, false>(m0, m, n0, n);\n                    break;\n                default:\n                    return;\n            }\n        } else {\n            switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 2)) {\n                case 0x32:\n                    mc = 3;\n                    nc = 2;\n                    gemm<3, 2, true>(m0, m, n0, n);\n                    break;\n                case 0x23:\n                    mc = 2;\n                    
nc = 3;\n                    gemm<2, 3, true>(m0, m, n0, n);\n                    break;\n                case 0x22:\n                    mc = 2;\n                    nc = 2;\n                    gemm<2, 2, true>(m0, m, n0, n);\n                    break;\n                case 0x31:\n                case 0x21:\n                    mc = 2;\n                    nc = 1;\n                    gemm<2, 1, true>(m0, m, n0, n);\n                    break;\n                case 0x12:\n                    mc = 1;\n                    nc = 2;\n                    gemm<1, 2, true>(m0, m, n0, n);\n                    break;\n                case 0x11:\n                    mc = 1;\n                    nc = 1;\n                    gemm<1, 1, true>(m0, m, n0, n);\n                    break;\n                default:\n                    return;\n            }\n        }\n#endif\n\n        mp = m0 + (m - m0) / mc * mc;\n        np = n0 + (n - n0) / nc * nc;\n        mnpack(mp, m, n0, np);\n        mnpack(m0, m, np, n);\n    }\n\n    template <int RM, int RN, int PRECISE>\n    NOINLINE void gemm(long m0, long m, long n0, long n) {\n        long ytiles = RM > 1 ? (m - m0) / RM : 1;\n        long xtiles = RN > 1 ? 
(n - n0) / RN : 1;\n        long tiles = xtiles * ytiles;\n        long duty = (tiles + nth - 1) / nth;\n        long start = duty * ith;\n        long end = start + duty;\n        if (end > tiles)\n            end = tiles;\n        for (long job = start; job < end; ++job) {\n            long ii = m0 + job / xtiles * RM;\n            long jj = n0 + job % xtiles * RN;\n            D Cv[RN][RM] = {};\n            D Ce[RN][RM] = {};\n            for (long l = 0; l < k; l += KN)\n#pragma GCC unroll 100\n                for (int j = 0; j < RN; ++j)\n#pragma GCC unroll 100\n                    for (int i = 0; i < RM; ++i)\n                        if (PRECISE)\n                            Cv[j][i] = madder(load<V>(INDEX(A, lda, ii + i, l)),  //\n                                              load<V>(INDEX(B, ldb, jj + j, l)),  //\n                                              Cv[j][i], &Ce[j][i]);\n                        else\n                            Cv[j][i] = madd(load<V>(INDEX(A, lda, ii + i, l)),  //\n                                            load<V>(INDEX(B, ldb, jj + j, l)),  //\n                                            Cv[j][i]);\n#pragma GCC unroll 100\n            for (int j = 0; j < RN; ++j)\n#pragma GCC unroll 100\n                for (int i = 0; i < RM; ++i)\n                    store(INDEX(C, ldc, jj + j, ii + i), hsum(Cv[j][i]));\n        }\n    }\n\n    const TA* const A;\n    const TB* const B;\n    TC* const C;\n    const long k;\n    const long lda;\n    const long ldb;\n    const long ldc;\n    const int ith;\n    const int nth;\n};\n\n//////////////////////////////////////////////////////////////////////////////////////////\n// QUANT ZERO MATRIX MULTIPLICATION\n\n#if defined(__ARM_FEATURE_DOTPROD)\ntemplate <int CONFIG, typename TA, typename TB, typename TC>\nclass tinyBLAS_Q0_ARM {\n   public:\n    tinyBLAS_Q0_ARM(long k, const TA* A, long lda, const TB* B, long ldb, TC* C, long ldc, int ith, int nth)\n        : A(A), B(B), C(C), k(k), 
lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {\n    }\n\n    void matmul(long m, long n, int task) {\n        if (task == GGML_TASK_TYPE_COMPUTE)\n            mnpack(0, m, 0, n);\n    }\n\n   private:\n    NOINLINE void mnpack(long m0, long m, long n0, long n) {\n        long mc, nc, mp, np;\n\n        if (!FLAG_precise) {\n            switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3)) {\n                case 0x33:\n                    mc = 3;\n                    nc = 3;\n                    gemm<3, 3, false>(m0, m, n0, n);\n                    break;\n                case 0x32:\n                case 0x23:\n                case 0x22:\n                    mc = 2;\n                    nc = 2;\n                    gemm<2, 2, false>(m0, m, n0, n);\n                    break;\n                case 0x31:\n                case 0x21:\n                    mc = 2;\n                    nc = 1;\n                    gemm<2, 1, false>(m0, m, n0, n);\n                    break;\n                case 0x13:\n                case 0x12:\n                    mc = 1;\n                    nc = 2;\n                    gemm<1, 2, false>(m0, m, n0, n);\n                    break;\n                case 0x11:\n                    mc = 1;\n                    nc = 1;\n                    gemm<1, 1, false>(m0, m, n0, n);\n                    break;\n                default:\n                    return;\n            }\n        } else {\n            switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3)) {\n                case 0x33:\n                    mc = 3;\n                    nc = 3;\n                    gemm<3, 3, true>(m0, m, n0, n);\n                    break;\n                case 0x32:\n                case 0x23:\n                case 0x22:\n                    mc = 2;\n                    nc = 2;\n                    gemm<2, 2, true>(m0, m, n0, n);\n                    break;\n                case 0x31:\n                case 0x21:\n                    mc = 2;\n                    nc = 
1;\n                    gemm<2, 1, true>(m0, m, n0, n);\n                    break;\n                case 0x13:\n                case 0x12:\n                    mc = 1;\n                    nc = 2;\n                    gemm<1, 2, true>(m0, m, n0, n);\n                    break;\n                case 0x11:\n                    mc = 1;\n                    nc = 1;\n                    gemm<1, 1, true>(m0, m, n0, n);\n                    break;\n                default:\n                    return;\n            }\n        }\n\n        mp = m0 + (m - m0) / mc * mc;\n        np = n0 + (n - n0) / nc * nc;\n        mnpack(mp, m, n0, np);\n        mnpack(m0, m, np, n);\n    }\n\n    template <int RM, int RN, int PRECISE>\n    NOINLINE void gemm(long m0, long m, long n0, long n) {\n        long ytiles = RM > 1 ? (m - m0) / RM : 1;\n        long xtiles = RN > 1 ? (n - n0) / RN : 1;\n        long tiles = xtiles * ytiles;\n        long duty = (tiles + nth - 1) / nth;\n        long start = duty * ith;\n        long end = start + duty;\n        if (end > tiles)\n            end = tiles;\n        for (long job = start; job < end; ++job) {\n            long ii = m0 + job / xtiles * RM;\n            long jj = n0 + job % xtiles * RN;\n            float32x4_t Cv[RN][RM] = {};\n            float32x4_t Ce[RN][RM] = {};\n            for (int l = 0; l < k; ++l)\n#pragma GCC unroll 100\n                for (int j = 0; j < RN; ++j)\n#pragma GCC unroll 100\n                    for (int i = 0; i < RM; ++i) {\n                        float32x4_t a = vcvtq_f32_s32(vdotq_s32(\n                            vdotq_s32(vdupq_n_s32(0), load_lo(INDEX(A, lda, ii + i, l)),\n                                      load_lo(INDEX(B, ldb, jj + j, l))),\n                            load_hi(INDEX(A, lda, ii + i, l)), load_hi(INDEX(B, ldb, jj + j, l))));\n                        float b = unhalf(INDEX(A, lda, ii + i, l)->d) *\n                                  unhalf(INDEX(B, ldb, jj + j, l)->d);\n               
         if (PRECISE)\n                            Cv[j][i] = badder(a, b, Cv[j][i], &Ce[j][i]);\n                        else\n                            Cv[j][i] = vmlaq_n_f32(Cv[j][i], a, b);\n                    }\n#pragma GCC unroll 100\n            for (int j = 0; j < RN; ++j)\n#pragma GCC unroll 100\n                for (int i = 0; i < RM; ++i)\n                    store(INDEX(C, ldc, jj + j, ii + i), hsum(Cv[j][i]));\n        }\n    }\n\n    inline int8x16_t load_lo(const block_q8_0* b) {\n        return vld1q_s8(b->qs);\n    }\n\n    inline int8x16_t load_hi(const block_q8_0* b) {\n        return vld1q_s8(b->qs + 16);\n    }\n\n    inline int8x16_t load_lo(const block_q4_0* b) {\n        return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs), vdupq_n_u8(0x0f))),\n                        vdupq_n_s8(0x8));\n    }\n\n    inline int8x16_t load_hi(const block_q4_0* b) {\n        return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)), vdupq_n_s8(0x8));\n    }\n\n    const TA* const A;\n    const TB* const B;\n    TC* const C;\n    const long k;\n    const long lda;\n    const long ldb;\n    const long ldc;\n    const int ith;\n    const int nth;\n};\n#endif  // __ARM_FEATURE_DOTPROD\n\n#if defined(__AVX2__) || defined(__AVX512F__)\ntemplate <int CONFIG, typename TA, typename TB, typename TC>\nclass tinyBLAS_Q0_AVX2 {\n   public:\n    tinyBLAS_Q0_AVX2(long k, const TA* A, long lda, const TB* B, long ldb, TC* C, long ldc, int ith, int nth)\n        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {\n    }\n\n    void matmul(long m, long n, int task) {\n        if (task == GGML_TASK_TYPE_COMPUTE)\n            mnpack(0, m, 0, n);\n    }\n\n   private:\n    void mnpack(long m0, long m, long n0, long n) {\n        long mc, nc, mp, np;\n\n#if VECTOR_REGISTERS == 32\n        if (!FLAG_precise) {\n            switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3)) {\n                case 0x33:\n                    mc = 3;\n           
         nc = 3;\n                    gemm<3, 3, false>(m0, m, n0, n);\n                    break;\n                case 0x32:\n                case 0x23:\n                case 0x22:\n                    mc = 2;\n                    nc = 2;\n                    gemm<2, 2, false>(m0, m, n0, n);\n                    break;\n                case 0x31:\n                case 0x21:\n                    mc = 2;\n                    nc = 1;\n                    gemm<2, 1, true>(m0, m, n0, n);\n                    break;\n                case 0x13:\n                case 0x12:\n                    mc = 1;\n                    nc = 2;\n                    gemm<1, 2, true>(m0, m, n0, n);\n                    break;\n                case 0x11:\n                    mc = 1;\n                    nc = 1;\n                    gemm<1, 1, true>(m0, m, n0, n);\n                    break;\n                default:\n                    return;\n            }\n        } else {\n            switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3)) {\n                case 0x33:\n                    mc = 3;\n                    nc = 3;\n                    gemm<3, 3, true>(m0, m, n0, n);\n                    break;\n                case 0x32:\n                case 0x23:\n                case 0x22:\n                    mc = 2;\n                    nc = 2;\n                    gemm<2, 2, true>(m0, m, n0, n);\n                    break;\n                case 0x31:\n                case 0x21:\n                    mc = 2;\n                    nc = 1;\n                    gemm<2, 1, true>(m0, m, n0, n);\n                    break;\n                case 0x13:\n                case 0x12:\n                    mc = 1;\n                    nc = 2;\n                    gemm<1, 2, true>(m0, m, n0, n);\n                    break;\n                case 0x11:\n                    mc = 1;\n                    nc = 1;\n                    gemm<1, 1, true>(m0, m, n0, n);\n                    break;\n                
default:\n                    return;\n            }\n        }\n#endif\n\n#if VECTOR_REGISTERS == 16\n        if (!FLAG_precise) {\n            switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 2)) {\n                case 0x32:\n                    mc = 3;\n                    nc = 2;\n                    gemm<3, 2, false>(m0, m, n0, n);\n                    break;\n                case 0x23:\n                    mc = 2;\n                    nc = 3;\n                    gemm<2, 3, false>(m0, m, n0, n);\n                    break;\n                case 0x22:\n                    mc = 2;\n                    nc = 2;\n                    gemm<2, 2, false>(m0, m, n0, n);\n                    break;\n                case 0x31:\n                case 0x21:\n                    mc = 2;\n                    nc = 1;\n                    gemm<2, 1, false>(m0, m, n0, n);\n                    break;\n                case 0x12:\n                    mc = 1;\n                    nc = 2;\n                    gemm<1, 2, false>(m0, m, n0, n);\n                    break;\n                case 0x11:\n                    mc = 1;\n                    nc = 1;\n                    gemm<1, 1, false>(m0, m, n0, n);\n                    break;\n                default:\n                    return;\n            }\n        } else {\n            switch ((MIN(m - m0, 2) << 4) | MIN(n - n0, 1)) {\n                case 0x21:\n                    mc = 2;\n                    nc = 1;\n                    gemm<2, 1, true>(m0, m, n0, n);\n                    break;\n                case 0x12:\n                    mc = 1;\n                    nc = 2;\n                    gemm<1, 2, true>(m0, m, n0, n);\n                    break;\n                case 0x11:\n                    mc = 1;\n                    nc = 1;\n                    gemm<1, 1, true>(m0, m, n0, n);\n                    break;\n                default:\n                    return;\n            }\n        }\n#endif\n\n        mp = m0 + (m - 
m0) / mc * mc;\n        np = n0 + (n - n0) / nc * nc;\n        mnpack(mp, m, n0, np);\n        mnpack(m0, m, np, n);\n    }\n\n    template <int RM, int RN, int PRECISE>\n    NOINLINE void gemm(long m0, long m, long n0, long n) {\n        long ytiles = RM > 1 ? (m - m0) / RM : 1;\n        long xtiles = RN > 1 ? (n - n0) / RN : 1;\n        long tiles = xtiles * ytiles;\n        long duty = (tiles + nth - 1) / nth;\n        long start = duty * ith;\n        long end = start + duty;\n        if (end > tiles)\n            end = tiles;\n        for (long job = start; job < end; ++job) {\n            long ii = m0 + job / xtiles * RM;\n            long jj = n0 + job % xtiles * RN;\n            __m256 Cv[RN][RM] = {};\n            __m256 Ce[RN][RM] = {};\n            for (long l = 0; l < k; ++l)\n#pragma GCC unroll 100\n                for (int j = 0; j < RN; ++j)\n#pragma GCC unroll 100\n                    for (int i = 0; i < RM; ++i) {\n                        __m256 a = _mm256_set1_ps(unhalf(INDEX(A, lda, ii + i, l)->d) *\n                                                  unhalf(INDEX(B, ldb, jj + j, l)->d));\n                        __m256 b = updot(_mm256_sign_epi8(load(INDEX(A, lda, ii + i, l)),\n                                                          load(INDEX(A, lda, ii + i, l))),\n                                         _mm256_sign_epi8(load(INDEX(B, ldb, jj + j, l)),\n                                                          load(INDEX(A, lda, ii + i, l))));\n                        if (PRECISE)\n                            Cv[j][i] = madder(a, b, Cv[j][i], &Ce[j][i]);\n                        else\n                            Cv[j][i] = madd(a, b, Cv[j][i]);\n                    }\n#pragma GCC unroll 100\n            for (int j = 0; j < RN; ++j)\n#pragma GCC unroll 100\n                for (int i = 0; i < RM; ++i)\n                    store(INDEX(C, ldc, jj + j, ii + i), hsum(Cv[j][i]));\n        }\n    }\n\n    inline __m256i load(const block_q8_0* b) {\n  
      return _mm256_loadu_si256((const __m256i*)b->qs);\n    }\n\n    inline __m256i load(const block_q4_0* b) {\n        __m128i x = _mm_loadu_si128((const __m128i*)b->qs);\n        return _mm256_sub_epi8(_mm256_and_si256(_mm256_set1_epi8(15),\n                                                _mm256_insertf128_si256(_mm256_castsi128_si256(x),\n                                                                        _mm_srli_epi16(x, 4), 1)),\n                               _mm256_set1_epi8(8));\n    }\n\n    inline __m256 updot(__m256i u, __m256i s) {\n        __m256i res;\n#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))\n        res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);\n#else\n        res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));\n#endif\n        return _mm256_cvtepi32_ps(res);\n    }\n\n    const TA* const A;\n    const TB* const B;\n    TC* const C;\n    const long k;\n    const long lda;\n    const long ldb;\n    const long ldc;\n    const int ith;\n    const int nth;\n};\n#endif  // __AVX2__\n\n}  // namespace\n"
  },
  {
    "path": "archive/third_party/llamafile/tinyblas_cpu_mixmul.inc",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul.inc\n// Copyright 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-\n// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi\n//\n// Copyright 2024 Mozilla Foundation\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include \"tinyblas_cpu.h\"\n\n//\n//\n//                                ██████╗ ██╗   █████╗ ██████╗\n//         ██████╗██╗██╗ ██╗██═██╗██╔══██╗██║  ██╔══██╗██╔═══╝\n//         ╚═██╔═╝██║███▄██║██ ██║██████╔╝██║  ███████║██████╗\n//           ██║  ██║██▀███║╚███╔╝██╔══██╗██║  ██╔══██║╔═══██║\n//           ██║  ██║██║ ██║ ███║ ██████╔╝████╗██║  ██║██████║\n//           ╚═╝  ╚═╝╚═╝ ╚═╝ ╚══╝ ╚═════╝ ╚═══╝╚═╝  ╚═╝╚═════╝\n//\n//               MIXTURE OF EXPERTS TENSOR MULTIPLICATION\n//\n//\n// SHAPES\n//\n//   - weights [cols, rows, experts]\n//   - thought [cols, tasks, tokens] w/ tasks ≤ thinkers\n//   - result  [rows, thinkers, tokens] w/ thinkers ≤ experts\n//   - plan    [thinkers, tokens] w/ i32 < experts\n//\n// DEFINITION\n//\n//   for thinker in range(thinkers):\n//     for token in range(tokens):\n//       for row in range(rows):\n//         c = 0\n//         for col in range(cols):\n//           expert = plan[token][thinker]\n//           a = weights[expert][row][col]\n//           b = thought[token][thinker 
% tasks][col]\n//           c += a * b\n//         result[token][thinker][row] = c\n//\n// REGULARITIES\n//\n//   - tokens can be odd\n//   - thinkers is usually 2\n//   - tasks is usually 1 or 2\n//   - cols should be a multiple of 64\n//   - rows should be a multiple of 64\n//   - experts is usually 8 but could be 60\n//   - tokens is always 1 for token generation\n//   - tokens can be huge for prompt processing\n//\n// EXAMPLE\n//\n//   mixtral 8x7b w/ 217 token prompt\n//\n//           |  ne*0 ne*1 ne*2 ne*3 | nb*0    nb*1      nb*2       nb*3 | type\n//   =========================================================================\n//   weights | 16384 6144    8    1 |   18  0x2400 0x3600000 0x1b000000 | q4_0\n//   thought | 16384    2  217    1 |    4 0x10000   0x20000  0x1b20000 | f32\n//   result  |  6144    2  217    1 |    4  0x6000    0xc000   0xa2c000 | f32\n//   plan    |     2  217    1    1 |    4    0x20    0x1b20     0x1b20 | i32\n//\n\nnamespace {\n\nclass MixMul {\n   public:\n    MixMul(const ggml_compute_params* params, const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan, ggml_tensor* result)\n        : params(params),\n          weights(weights),\n          thought(thought),\n          plan(plan),\n          result(result),\n          rows(weights->ne[1]),\n          cols(weights->ne[0]),\n          experts(weights->ne[2]),\n          thinkers(plan->ne[0]),\n          tasks(thought->ne[1]),\n          tokens(thought->ne[2]),\n          ldq((cols * 2 + ROW_ALIGN - 1) & -ROW_ALIGN),\n          wdata_((char*)(((uintptr_t)params->wdata + MAX_ALIGN - 1) & -MAX_ALIGN)),\n          allocated_(0) {\n    }\n\n    bool allocate_shared_memory() {\n        if (!(quantized_thought_ = allocate<char>(MATRIX_ALIGN, tokens * tasks * ldq)))\n            return false;\n        if (!(rowptr_result_ = allocate<uintptr_t>(ROW_ALIGN, experts * tokens * thinkers)))\n            return false;\n        if (!(rowptr_thought_ = 
allocate<uintptr_t>(ROW_ALIGN, experts * tokens * thinkers)))\n            return false;\n        if (!(rowptr_count_ = allocate<long>(sizeof(long), experts)))\n            return false;\n        return true;\n    }\n\n    size_t get_allocated_bytes() {\n        return (wdata_ - (char*)params->wdata) + allocated_;\n    }\n\n    bool mixmul() {\n        // invariants\n        assert(tasks <= thinkers);\n        assert(thinkers <= experts);\n        assert(tokens == plan->ne[1]);\n        assert(rows == result->ne[0]);\n        assert(cols == thought->ne[0]);\n        assert(tokens == result->ne[2]);\n        assert(thinkers == result->ne[1]);\n\n        // dimensionality\n        assert(plan->ne[2] == 1);\n        assert(plan->ne[3] == 1);\n        assert(result->ne[3] == 1);\n        assert(weights->ne[3] == 1);\n        assert(thought->ne[3] == 1);\n\n        // miscellaneous\n        assert(params->nth > 0);\n        assert(params->ith < params->nth);\n        assert(plan->type == GGML_TYPE_I32);\n\n        // check nb01 is convertible to lda\n        if (weights->nb[1] % ggml_type_size(weights->type))\n            return false;\n\n        // no support for column strides\n        if (result->nb[0] != ggml_type_size(result->type))\n            return false;\n        if (thought->nb[0] != ggml_type_size(thought->type))\n            return false;\n        if (weights->nb[0] != ggml_type_size(weights->type))\n            return false;\n\n        // supported output types\n        switch (result->type) {\n            case GGML_TYPE_F32:\n                return mixmuler<float>();\n            default:\n                return false;\n        }\n    }\n\n   private:\n    template <typename TC>\n    bool mixmuler() {\n        switch (weights->type) {\n            case GGML_TYPE_F32:\n                if (thought->type != GGML_TYPE_F32)\n                    return false;\n#if defined(__AVX512F__)\n                return mixmat<16, 1, tinyBLAS<NCB | NCC, 16, __m512, __m512, 
float, float, TC>, float,\n                              float, TC>();\n#elif defined(__AVX__) || defined(__AVX2__)\n                return mixmat<8, 1, tinyBLAS<NCB | NCC, 8, __m256, __m256, float, float, TC>, float,\n                              float, TC>();\n#elif defined(__SSE__)\n                return mixmat<4, 1, tinyBLAS<NCB | NCC, 4, __m128, __m128, float, float, TC>, float,\n                              float, TC>();\n#elif defined(__ARM_NEON)\n                return mixmat<4, 1, tinyBLAS<NCB | NCC, 4, float32x4_t, float32x4_t, float, float, TC>,\n                              float, float, TC>();\n#else\n                return false;\n#endif\n\n            case GGML_TYPE_BF16:\n                if (thought->type != GGML_TYPE_F32 && thought->type != GGML_TYPE_BF16)\n                    return false;\n#if defined(__AVX512BF16__)\n                if (!FLAG_precise) {\n                    return mixmat<\n                        32, 1, tinyBLAS<NCB | NCC, 32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, TC>,\n                        ggml_bf16_t, ggml_bf16_t, TC>();\n                } else {\n                    return mixmat<16, 1,\n                                  tinyBLAS<NCB | NCC, 16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, TC>,\n                                  ggml_bf16_t, ggml_bf16_t, TC>();\n                }\n#elif defined(__AVX512F__)\n                return mixmat<16, 1,\n                              tinyBLAS<NCB | NCC, 16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, TC>,\n                              ggml_bf16_t, ggml_bf16_t, TC>();\n#elif defined(__AVX2__)\n                return mixmat<8, 1,\n                              tinyBLAS<NCB | NCC, 8, __m256, __m256, ggml_bf16_t, ggml_bf16_t, TC>,\n                              ggml_bf16_t, ggml_bf16_t, TC>();\n#elif defined(__ARM_NEON) && !defined(_MSC_VER)\n                return mixmat<\n                    4, 1,\n                    tinyBLAS<NCB | NCC, 4, float32x4_t, float32x4_t, ggml_bf16_t, 
ggml_bf16_t, TC>,\n                    ggml_bf16_t, ggml_bf16_t, TC>();\n#else\n                return false;\n#endif\n\n            case GGML_TYPE_F16:\n                if (thought->type != GGML_TYPE_F32 && thought->type != GGML_TYPE_F16)\n                    return false;\n#if defined(__AVX512F__)\n                return mixmat<16, 1,\n                              tinyBLAS<NCB | NCC, 16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, TC>,\n                              ggml_fp16_t, ggml_fp16_t, TC>();\n#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)\n                // if (X86_CHECK(F16C)) {\n                return mixmat<8, 1,\n                              tinyBLAS<NCB | NCC, 8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, TC>,\n                              ggml_fp16_t, ggml_fp16_t, TC>();\n                // } else {\n                //     return false;\n                // }\n#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)\n                if (result->op_params[0] == GGML_PREC_F32) {\n                    return mixmat<\n                        4, 1,\n                        tinyBLAS<NCB | NCC, 4, float32x4_t, float32x4_t, ggml_fp16_t, ggml_fp16_t, TC>,\n                        ggml_fp16_t, ggml_fp16_t, TC>();\n                } else {\n                    return mixmat<\n                        8, 1,\n                        tinyBLAS<NCB | NCC, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, TC>,\n                        ggml_fp16_t, ggml_fp16_t, TC>();\n                }\n#elif defined(__ARM_NEON) && !defined(_MSC_VER)\n                return mixmat<\n                    4, 1,\n                    tinyBLAS<NCB | NCC, 4, float32x4_t, float32x4_t, ggml_fp16_t, ggml_fp16_t, TC>,\n                    ggml_fp16_t, ggml_fp16_t, TC>();\n#else\n                return false;\n#endif\n\n            case GGML_TYPE_Q4_0:\n                if (thought->type != GGML_TYPE_F32 && thought->type != GGML_TYPE_Q8_0)\n                 
   return false;\n#if defined(__AVX2__) || defined(__AVX512F__)\n                return mixmat<32, 32, tinyBLAS_Q0_AVX2<NCB | NCC, block_q4_0, block_q8_0, TC>,\n                              block_q4_0, block_q8_0, TC>();\n#elif defined(__ARM_FEATURE_DOTPROD)\n                return mixmat<32, 32, tinyBLAS_Q0_ARM<NCB | NCC, block_q4_0, block_q8_0, TC>,\n                              block_q4_0, block_q8_0, TC>();\n#else\n                return false;\n#endif\n\n            case GGML_TYPE_Q8_0:\n                if (thought->type != GGML_TYPE_F32 && thought->type != GGML_TYPE_Q8_0)\n                    return false;\n#if defined(__AVX2__) || defined(__AVX512F__)\n                return mixmat<32, 32, tinyBLAS_Q0_AVX2<NCB | NCC, block_q8_0, block_q8_0, TC>,\n                              block_q8_0, block_q8_0, TC>();\n#elif defined(__ARM_FEATURE_DOTPROD)\n                return mixmat<32, 32, tinyBLAS_Q0_ARM<NCB | NCC, block_q8_0, block_q8_0, TC>,\n                              block_q8_0, block_q8_0, TC>();\n#else\n                return false;\n#endif\n\n            default:\n                return false;\n        }\n    }\n\n    template <int KN, int BS, typename BLAS, typename TA, typename TB, typename TC>\n    bool mixmat() {\n        if (cols % KN)\n            return false;\n        switch (params->type) {\n            case GGML_TASK_TYPE_INIT:\n                if (thought->type != ggml_type_trait<TB>::id)\n                    quantize_thought(ggml_type_trait<TB>::id);\n                build_row_pointers(ggml_type_trait<TB>::id);\n                return true;\n            case GGML_TASK_TYPE_COMPUTE:\n                assert(!(cols % BS));\n                assert(!(weights->nb[1] % sizeof(TA)));\n                for (int expert = 0; expert < experts; ++expert) {\n                    BLAS tb{cols / BS,\n                            (const TA*)((const char*)weights->data + expert * weights->nb[2]),\n                            (long)(weights->nb[1] / 
sizeof(TA)),\n                            (const TB*)(rowptr_thought_ + expert * tokens * thinkers),\n                            0,\n                            (TC*)(rowptr_result_ + expert * tokens * thinkers),\n                            0,\n                            params->ith,\n                            params->nth};\n                    tb.matmul(rows, rowptr_count_[expert], GGML_TASK_TYPE_COMPUTE);\n                }\n                return true;\n            default:\n                return true;\n        }\n    }\n\n    void build_row_pointers(ggml_type vec_dot_type) {\n        for (int expert = params->ith; expert < experts; expert += params->nth) {\n            long count = 0;\n            for (long token = 0; token < tokens; ++token)\n                for (int thinker = 0; thinker < thinkers; ++thinker)\n                    if (expert == *(const int32_t*)((const char*)plan->data +\n                                                    token * plan->nb[1] + thinker * plan->nb[0])) {\n                        long row = count++;\n                        long idx = expert * thinkers * tokens + row;\n                        rowptr_result_[idx] =\n                            (uintptr_t)((char*)result->data + token * result->nb[2] +\n                                        thinker * result->nb[1]);\n                        if (thought->type == vec_dot_type)\n                            rowptr_thought_[idx] =\n                                (uintptr_t)((char*)thought->data + token * thought->nb[2] +\n                                            thinker % tasks * thought->nb[1]);\n                        else\n                            rowptr_thought_[idx] =\n                                (uintptr_t)((char*)quantized_thought_ + token * tasks * ldq +\n                                            thinker % tasks * ldq);\n                    }\n            rowptr_count_[expert] = count;\n        }\n    }\n\n    void quantize_thought(ggml_type vec_dot_type) 
{\n        long chore = 0;\n        for (long token = 0; token < tokens; ++token)\n            for (int task = 0; task < tasks; ++task)\n                if (chore++ % params->nth == params->ith)\n                    quantize_row(quantized_thought_ + token * tasks * ldq + task * ldq,\n                                 (const float*)((const char*)thought->data +\n                                                token * thought->nb[2] + task * thought->nb[1]),\n                                 vec_dot_type);\n    }\n\n    void quantize_row(void* dst, const float* src, ggml_type type) {\n        assert((long)ggml_row_size(type, cols) <= ldq);\n        switch (type) {\n            case GGML_TYPE_F16:\n                ggml_fp32_to_fp16_row(src, (ggml_fp16_t*)dst, cols);\n                break;\n            case GGML_TYPE_BF16:\n                ggml_fp32_to_bf16_row(src, (ggml_bf16_t*)dst, cols);\n                break;\n            case GGML_TYPE_Q8_0:\n                quantize_row_q8_0((const float*)src, (block_q8_0*)dst, cols);\n                break;\n            default:\n                GGML_UNREACHABLE();\n        }\n    }\n\n    template <typename T>\n    T* allocate(size_t align, size_t elems) {\n        T* res = nullptr;\n        size_t need = sizeof(T) * elems;\n        size_t base = allocated_;\n        base += align - 1;\n        base &= -align;\n        size_t toto = base + need;\n        if (toto >= allocated_ && toto <= params->wsize) {\n            res = (T*)(wdata_ + base);\n            allocated_ = toto;\n        }\n        return res;\n    }\n\n    const ggml_compute_params* const params;\n    const ggml_tensor* const weights;\n    const ggml_tensor* const thought;\n    const ggml_tensor* const plan;\n    ggml_tensor* const result;\n    const long rows;\n    const long cols;\n    const int experts;\n    const int thinkers;\n    const int tasks;\n    const long tokens;\n    const long ldq;\n\n    // variables\n    char* const wdata_;\n    size_t 
allocated_;\n\n    // shared memory\n    long* rowptr_count_ /*[experts]*/;\n    char* quantized_thought_ /*[tokens][tasks][cols][2]*/;\n    uintptr_t* rowptr_result_ /*[experts][tokens*thinkers]*/;\n    uintptr_t* rowptr_thought_ /*[experts][tokens*thinkers]*/;\n};\n\n}  // namespace\n\n/**\n * Performs \"mixture of experts\" tensor multiplication on CPU.\n */\nbool llamafile_mixmul(const ggml_compute_params* params, const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan, ggml_tensor* result) {\n    MixMul mm{params, weights, thought, plan, result};\n    return mm.allocate_shared_memory() && mm.mixmul();\n}\n"
  },
  {
    "path": "archive/third_party/llamafile/tinyblas_cpu_mixmul_amd_avx.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_amd_avx.cpp\n// Copyright 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#if defined(__x86_64__) || defined(_M_X64)\n#define llamafile_mixmul llamafile_mixmul_amd_avx\n#include \"tinyblas_cpu_mixmul.inc\"\n\n/**\n * Returns number of shared memory bytes llamafile_mixmul() needs.\n */\nsize_t llamafile_mixmul_needs(const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan) {\n    ggml_compute_params params{};\n    params.wsize = 0x7ffff000;\n    params.wdata = (void*)0x1000;\n    MixMul mm{&params, weights, thought, plan, 0};\n    if (mm.allocate_shared_memory())\n        return mm.get_allocated_bytes();\n    else\n        return 0;\n}\n\n#endif  // __x86_64__\n"
  },
  {
    "path": "archive/third_party/llamafile/tinyblas_cpu_mixmul_amd_avx2.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_amd_avx2.cpp\n// Copyright 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#if defined(__x86_64__) || defined(_M_X64)\n#define llamafile_mixmul llamafile_mixmul_amd_avx2\n#include \"tinyblas_cpu_mixmul.inc\"\n#endif  // __x86_64__\n"
  },
  {
    "path": "archive/third_party/llamafile/tinyblas_cpu_mixmul_amd_avx512f.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_amd_avx512f.cpp\n// Copyright 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#if defined(__x86_64__) || defined(_M_X64)\n#define llamafile_mixmul llamafile_mixmul_amd_avx512f\n#include \"tinyblas_cpu_mixmul.inc\"\n#endif  // __x86_64__\n"
  },
  {
    "path": "archive/third_party/llamafile/tinyblas_cpu_mixmul_amd_avxvnni.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_amd_avxvnni.cpp\n// Copyright 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#if defined(__x86_64__) || defined(_M_X64)\n#define llamafile_mixmul llamafile_mixmul_amd_avxvnni\n#include \"tinyblas_cpu_mixmul.inc\"\n#endif  // __x86_64__\n"
  },
  {
    "path": "archive/third_party/llamafile/tinyblas_cpu_mixmul_amd_fma.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_amd_fma.cpp\n// Copyright 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#if defined(__x86_64__) || defined(_M_X64)\n#define llamafile_mixmul llamafile_mixmul_amd_fma\n#include \"tinyblas_cpu_mixmul.inc\"\n#endif  // __x86_64__\n"
  },
  {
    "path": "archive/third_party/llamafile/tinyblas_cpu_mixmul_amd_zen4.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_amd_zen4.cpp\n// Copyright 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#if defined(__x86_64__) || defined(_M_X64)\n#define llamafile_mixmul llamafile_mixmul_amd_zen4\n#include \"tinyblas_cpu_mixmul.inc\"\n#endif  // __x86_64__\n"
  },
  {
    "path": "archive/third_party/llamafile/tinyblas_cpu_mixmul_arm80.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_arm80.cpp\n// Copyright 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#ifdef __aarch64__\n#define llamafile_mixmul llamafile_mixmul_arm80\n#include \"tinyblas_cpu_mixmul.inc\"\n\n/**\n * Returns number of shared memory bytes llamafile_mixmul() needs.\n */\nsize_t llamafile_mixmul_needs(const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan) {\n    ggml_compute_params params{};\n    params.wsize = 0x7ffff000;\n    params.wdata = (void*)0x1000;\n    MixMul mm{&params, weights, thought, plan, 0};\n    if (mm.allocate_shared_memory())\n        return mm.get_allocated_bytes();\n    else\n        return 0;\n}\n\n#endif  // __aarch64__\n"
  },
  {
    "path": "archive/third_party/llamafile/tinyblas_cpu_mixmul_arm82.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_arm82.cpp\n// Copyright 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#ifdef __aarch64__\n#define llamafile_mixmul llamafile_mixmul_arm82\n#include \"tinyblas_cpu_mixmul.inc\"\n#endif  // __aarch64__\n"
  },
  {
    "path": "archive/third_party/llamafile/tinyblas_cpu_sgemm.inc",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm.inc\n// Copyright 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-\n// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi\n//\n// Copyright 2024 Mozilla Foundation\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n//\n//\n//                                ██████╗ ██╗   █████╗ ██████╗\n//         ██████╗██╗██╗ ██╗██═██╗██╔══██╗██║  ██╔══██╗██╔═══╝\n//         ╚═██╔═╝██║███▄██║██ ██║██████╔╝██║  ███████║██████╗\n//           ██║  ██║██▀███║╚███╔╝██╔══██╗██║  ██╔══██║╔═══██║\n//           ██║  ██║██║ ██║ ███║ ██████╔╝████╗██║  ██║██████║\n//           ╚═╝  ╚═╝╚═╝ ╚═╝ ╚══╝ ╚═════╝ ╚═══╝╚═╝  ╚═╝╚═════╝\n//\n//                   BASIC LINEAR ALGEBRA SUBPROGRAMS\n//\n//\n// This file implements multithreaded CPU matrix multiplication for the\n// common contiguous use case C = Aᵀ * B. These kernels are designed to\n// have excellent performance[1] for matrices that fit in the CPU cache\n// without imposing any overhead such as cache filling or malloc calls.\n//\n// This implementation does not guarantee any upper bound with rounding\n// errors, which grow along with k. Our goal's to maximally exploit the\n// hardware for performance, and then use whatever resources remain for\n// improving numerical accuracy.\n//\n// [1] J. 
Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].\n//     Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].\n\n#if defined(KTRANSFORMERS_USE_NPU) && KTRANSFORMERS_USE_NPU\n        // use ARM version\n        #include \"tinyblas_cpu_sgemm_arm.inc\"\n#else\n        // use x86 version\n        #include \"tinyblas_cpu_sgemm_x86.inc\"\n#endif"
  },
  {
    "path": "archive/third_party/llamafile/tinyblas_cpu_sgemm_amd_avx.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm_amd_avx.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#if defined(__x86_64__) || defined(_M_X64)\n#define llamafile_sgemm llamafile_sgemm_amd_avx\n#include \"tinyblas_cpu_sgemm.inc\"\n#endif  // __x86_64__\n"
  },
  {
    "path": "archive/third_party/llamafile/tinyblas_cpu_sgemm_amd_avx2.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm_amd_avx2.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#if defined(__x86_64__) || defined(_M_X64)\n#define llamafile_sgemm llamafile_sgemm_amd_avx2\n#include \"tinyblas_cpu_sgemm.inc\"\n#endif  // __x86_64__\n"
  },
  {
    "path": "archive/third_party/llamafile/tinyblas_cpu_sgemm_amd_avx512f.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm_amd_avx512f.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#if defined(__x86_64__) || defined(_M_X64)\n#define llamafile_sgemm llamafile_sgemm_amd_avx512f\n#include \"tinyblas_cpu_sgemm.inc\"\n#endif  // __x86_64__\n"
  },
  {
    "path": "archive/third_party/llamafile/tinyblas_cpu_sgemm_amd_avxvnni.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm_amd_avxvnni.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#if defined(__x86_64__) || defined(_M_X64)\n#define llamafile_sgemm llamafile_sgemm_amd_avxvnni\n#include \"tinyblas_cpu_sgemm.inc\"\n#endif  // __x86_64__\n"
  },
  {
    "path": "archive/third_party/llamafile/tinyblas_cpu_sgemm_amd_fma.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm_amd_fma.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#if defined(__x86_64__) || defined(_M_X64)\n#define llamafile_sgemm llamafile_sgemm_amd_fma\n#include \"tinyblas_cpu_sgemm.inc\"\n#endif  // __x86_64__\n"
  },
  {
    "path": "archive/third_party/llamafile/tinyblas_cpu_sgemm_amd_zen4.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm_amd_zen4.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#if defined(__x86_64__) || defined(_M_X64)\n#define llamafile_sgemm llamafile_sgemm_amd_zen4\n#define iqk_mul_mat iqk_mul_mat_zen4\n#include \"tinyblas_cpu_sgemm.inc\"\n#endif  // __x86_64__\n"
  },
  {
    "path": "archive/third_party/llamafile/tinyblas_cpu_sgemm_arm.inc",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm.inc\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-\n// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi\n//\n// Copyright 2024 Mozilla Foundation\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include \"tinyblas_cpu.h\"\n#include <arm_neon.h>\n#include <ostream>\n#include <iostream>\n//\n//\n//                                ██████╗ ██╗   █████╗ ██████╗\n//         ██████╗██╗██╗ ██╗██═██╗██╔══██╗██║  ██╔══██╗██╔═══╝\n//         ╚═██╔═╝██║███▄██║██ ██║██████╔╝██║  ███████║██████╗\n//           ██║  ██║██▀███║╚███╔╝██╔══██╗██║  ██╔══██║╔═══██║\n//           ██║  ██║██║ ██║ ███║ ██████╔╝████╗██║  ██║██████║\n//           ╚═╝  ╚═╝╚═╝ ╚═╝ ╚══╝ ╚═════╝ ╚═══╝╚═╝  ╚═╝╚═════╝\n//\n//                   BASIC LINEAR ALGEBRA SUBPROGRAMS\n//\n//\n// This file implements multithreaded CPU matrix multiplication for the\n// common contiguous use case C = Aᵀ * B. These kernels are designed to\n// have excellent performance[1] for matrices that fit in the CPU cache\n// without imposing any overhead such as cache filling or malloc calls.\n//\n// This implementation does not guarantee any upper bound with rounding\n// errors, which grow along with k. 
Our goal's to maximally exploit the\n// hardware for performance, and then use whatever resources remain for\n// improving numerical accuracy.\n//\n// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].\n//     Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].\n\nnamespace {\n\ntemplate <typename TC>\nvoid SgemmHelperN1Neon2(long m, long n, long k, const float16_t* A, long lda, const float16_t* B, long ldb,\n                        TC* C, long ldc, int ith, int nth) {\n    // A m * k    B n * k    c n * m\n    const long NVL = 8;\n    long kk = k / (NVL * 4);\n    kk = kk * (NVL * 4);\n    long length = (m / nth) + (ith < (m % nth) ? 1 : 0);\n    long startRow = ith * (m / nth) + (ith < (m % nth) ? ith : (m % nth));\n    long endRow = startRow + length;\n    for (long i = startRow; i < endRow; i ++) {\n        const float16_t* tA = A + i * lda;\n        float32x4_t c0 = vdupq_n_f32(0);\n        float32x4_t c1 = vdupq_n_f32(0);\n        float32x4_t c2 = vdupq_n_f32(0);\n        float32x4_t c3 = vdupq_n_f32(0);\n        float32x4_t c4 = vdupq_n_f32(0);\n        float32x4_t c5 = vdupq_n_f32(0);\n        float32x4_t c6 = vdupq_n_f32(0);\n        float32x4_t c7 = vdupq_n_f32(0);\n        for (long j = 0; j < kk; j += NVL * 4) {\n            __builtin_prefetch(tA + 192, 0, 0);\n            float16x8_t a0 = vld1q_f16(tA + j);\n            float16x8_t b0 = vld1q_f16(B + j);\n            c0 = vfmlalq_low_f16(c0, a0, b0);\n            c1 = vfmlalq_high_f16(c1, a0, b0);\n            float16x8_t a1 = vld1q_f16(tA + j + NVL);\n            float16x8_t b1 = vld1q_f16(B + j + NVL);\n            c2 = vfmlalq_low_f16(c2, a1, b1);\n            c3 = vfmlalq_high_f16(c3, a1, b1);\n            float16x8_t a2 = vld1q_f16(tA + j + NVL * 2);\n            float16x8_t b2 = vld1q_f16(B + j + NVL * 2);\n            c4 = vfmlalq_low_f16(c4, a2, b2);\n            c5 = vfmlalq_high_f16(c5, a2, b2);\n            float16x8_t a3 = vld1q_f16(tA + j + NVL * 3);\n        
    float16x8_t b3 = vld1q_f16(B + j + NVL * 3);\n            c6 = vfmlalq_low_f16(c6, a3, b3);\n            c7 = vfmlalq_high_f16(c7, a3, b3);\n        }\n        if (k - kk >= NVL * 2) {\n            float16x8_t a0 = vld1q_f16(tA + kk);\n            float16x8_t b0 = vld1q_f16(B + kk);\n            c0 = vfmlalq_low_f16(c0, a0, b0);\n            c1 = vfmlalq_high_f16(c1, a0, b0);\n            float16x8_t a1 = vld1q_f16(tA + kk + NVL);\n            float16x8_t b1 = vld1q_f16(B + kk + NVL);\n            c2 = vfmlalq_low_f16(c2, a1, b1);\n            c3 = vfmlalq_high_f16(c3, a1, b1);\n            kk += NVL * 2;\n        }\n        if (k - kk >= NVL) {\n            float16x8_t a = vld1q_f16(tA + kk);\n            float16x8_t b = vld1q_f16(B + kk);\n            c0 = vfmlalq_low_f16(c0, a, b);\n            c1 = vfmlalq_high_f16(c1, a, b);\n            kk += NVL;\n        }\n        TC sum = 0.0f;\n        for (long j = kk; j < k; j ++) {\n            sum += (float32_t)tA[j] * (float32_t)B[j];\n        }\n        c0 = vaddq_f32(c0, c1);\n        c2 = vaddq_f32(c2, c3);\n        c4 = vaddq_f32(c4, c5);\n        c6 = vaddq_f32(c6, c7);\n        c0 = vaddq_f32(c0, c2);\n        c4 = vaddq_f32(c4, c6);\n        sum += vaddvq_f32(c0) + vaddvq_f32(c4);\n        C[i] = sum;\n    }\n    return;\n}\n\ntemplate <typename TC>\nvoid SgemmHelperN1(long m, long n, long k, const ggml_fp16_t* A_, long lda, const ggml_fp16_t* B_, long ldb,\n                   TC* C, long ldc, int ith, int nth) {\n    // A m * k    B n * k    c n * m\n    float16_t *A = (float16_t*)A_;\n    float16_t *B = (float16_t*)B_;\n    long rowsPerThread = m / nth;\n    long startRow = ith * rowsPerThread;\n    long endRow = (ith == nth - 1) ? 
m : startRow + rowsPerThread;\n    for (long i = startRow; i < endRow; i ++) {\n        TC sum = 0.0f;\n        for (long j = 0; j < k; j ++) {\n            sum += (float32_t)A[i * lda + j] * (float32_t)B[j];\n        }\n        C[i] = sum;\n    }\n    return;\n}\n\ntemplate <typename TC>\nbool llamafile_sgemm_impl(long m, long n, long k, const void* A, long lda, const void* B, long ldb, TC* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {\n    // std::cout << \"tinyBLAS tinyBLAS NOT_SUPPORTED FP16  55, n: \" << n << \", m: \" << m << \", k: \" << k << \", FLAG_precise: \" << FLAG_precise << \"\\n\"<<std::endl;\n    switch (Atype) {\n        case GGML_TYPE_F32: {\n            if (Btype != GGML_TYPE_F32)\n                return NOT_SUPPORTED;\n#if defined(__AVX512F__)\n            if (k % 16)\n                return NOT_SUPPORTED;\n            tinyBLAS<0, 16, __m512, __m512, float, float, TC> tb{\n                k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#elif defined(__AVX__) || defined(__AVX2__)\n            if (k % 8)\n                return NOT_SUPPORTED;\n            tinyBLAS<0, 8, __m256, __m256, float, float, TC> tb{\n                k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#elif defined(__ARM_NEON)\n            if (k % 4)\n                return NOT_SUPPORTED;\n            tinyBLAS<0, 4, float32x4_t, float32x4_t, float, float, TC> tb{\n                k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#else\n            return NOT_SUPPORTED;\n#endif\n        }\n\n        case GGML_TYPE_BF16: {\n#if defined(__AVX512BF16__)\n            if (k % 32)\n                return NOT_SUPPORTED;\n            if (Btype == GGML_TYPE_F32 && n < 2) {\n                tinyBLAS<0, 16, 
__m512, __m512, ggml_bf16_t, float, TC> tb{\n                    k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};\n                tb.matmul(m, n, task);\n                return true;\n            }\n            if (Btype == GGML_TYPE_F32)\n                return WANT_QUANTIZATION;\n            if (Btype != GGML_TYPE_BF16)\n                return NOT_SUPPORTED;\n            if (!FLAG_precise) {\n                tinyBLAS<0, 32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, TC> tb{\n                    k, (const ggml_bf16_t*)A, lda, (const ggml_bf16_t*)B, ldb, C, ldc, ith, nth};\n                tb.matmul(m, n, task);\n                return true;\n            } else {\n                tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, TC> tb{\n                    k, (const ggml_bf16_t*)A, lda, (const ggml_bf16_t*)B, ldb, C, ldc, ith, nth};\n                tb.matmul(m, n, task);\n                return true;\n            }\n#elif defined(__AVX512F__)\n            if (k % 16)\n                return NOT_SUPPORTED;\n            tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{\n                k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#elif defined(__AVX2__)\n            if (k % 8)\n                return NOT_SUPPORTED;\n            if (Btype != GGML_TYPE_F32)\n                return NOT_SUPPORTED;\n            tinyBLAS<0, 8, __m256, __m256, ggml_bf16_t, float, TC> tb{\n                k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#elif defined(__ARM_NEON) && !defined(_MSC_VER)\n            if (k % 4)\n                return NOT_SUPPORTED;\n            if (Btype != GGML_TYPE_F32)\n                return NOT_SUPPORTED;\n            tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_bf16_t, float, TC> tb{\n                k, (const ggml_bf16_t*)A, lda, (const 
float*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#else\n            return NOT_SUPPORTED;\n#endif\n        }\n\n        case GGML_TYPE_F16: {\n#if defined(__AVX512F__)\n            if (k % 16)\n                return NOT_SUPPORTED;\n            if (Btype == GGML_TYPE_F32 && n < 2) {\n                tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, float, TC> tb{\n                    k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};\n                tb.matmul(m, n, task);\n                return true;\n            }\n            if (Btype == GGML_TYPE_F32)\n                return WANT_QUANTIZATION;\n            if (Btype != GGML_TYPE_F16)\n                return NOT_SUPPORTED;\n            tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, TC> tb{\n                k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)\n            // if (X86_CHECK(F16C)) {\n            if (k % 8)\n                return NOT_SUPPORTED;\n            if (Btype == GGML_TYPE_F32 && n < 2) {\n                tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, float, TC> tb{\n                    k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};\n                tb.matmul(m, n, task);\n                return true;\n            }\n            if (Btype == GGML_TYPE_F32)\n                return WANT_QUANTIZATION;\n            if (Btype != GGML_TYPE_F16)\n                return NOT_SUPPORTED;\n            tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, TC> tb{\n                k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n            // } else {\n            //     return NOT_SUPPORTED;\n            // }\n#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) 
&& !defined(_MSC_VER)\n            if (n < 2 && !FLAG_precise) {\n                // TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?\n                if (Btype == GGML_TYPE_F16 && task == GGML_TASK_TYPE_COMPUTE) {\n                    SgemmHelperN1Neon2<TC>(m, n, k, (const float16_t*)A, lda, (const float16_t*)B, ldb, C, ldc, ith, nth);\n                    // SgemmHelperN1<TC>(m, n, k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth);\n                    return true;\n                }\n                return NOT_SUPPORTED;\n            }\n            if (precision == GGML_PREC_F32) {\n                if (k % 4)\n                    return NOT_SUPPORTED;\n                if (Btype != GGML_TYPE_F32)\n                    return NOT_SUPPORTED;\n                tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{\n                    k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};\n                tb.matmul(m, n, task);\n                return true;\n            } else {\n                if (k % 8)\n                    return NOT_SUPPORTED;\n                if (Btype == GGML_TYPE_F32)\n                    return WANT_QUANTIZATION;\n                if (Btype != GGML_TYPE_F16)\n                    return NOT_SUPPORTED;\n                tinyBLAS<0, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, TC> tb{\n                    k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};\n                tb.matmul(m, n, task);\n                return true;\n            }\n#elif defined(__ARM_NEON) && !defined(_MSC_VER)\n            if (n < 2 && !FLAG_precise) {\n                // TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?\n                // printf(\"tinyBLAS tinyBLAS NOT_SUPPORTED FP16 225, m: %ld, n: %ld, k: %ld\\n\", m, n, k);\n                if (Btype == GGML_TYPE_F16 && task == GGML_TASK_TYPE_COMPUTE) {\n                    
SgemmHelperN1Neon2<TC>(m, n, k, (const float16_t*)A, lda, (const float16_t*)B, ldb, C, ldc, ith, nth);\n                    // SgemmHelperN1<TC>(m, n, k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth);\n                    return true;\n                }\n                std::cout << \"tinyBLAS tinyBLAS NOT_SUPPORTED FP16 231, n: \" << n << \", m: \" << m << \", k: \" << k << \", FLAG_precise: \" << FLAG_precise << \"\\n\"<<std::endl;\n                return NOT_SUPPORTED;\n            }\n            if (k % 4) {\n                // std::cout << \"tinyBLAS tinyBLAS NOT_SUPPORTED FP16  215\" <<std::endl;\n                return NOT_SUPPORTED;\n            }\n            if (Btype != GGML_TYPE_F32) {\n                // std::cout << \"tinyBLAS tinyBLAS NOT_SUPPORTED FP16  218\" <<std::endl;\n                return NOT_SUPPORTED;\n            }\n            // std::cout << \"tinyBLAS tinyBLAS true FP16\" <<std::endl;\n            tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{\n                k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#else\n            // std::cout << \"tinyBLAS tinyBLAS NOT_SUPPORTED FP16\" <<std::endl;\n            return NOT_SUPPORTED;\n#endif\n        }\n\n        case GGML_TYPE_Q8_0: {\n            if (Btype == GGML_TYPE_F32)\n                return WANT_QUANTIZATION;\n            if (Btype != GGML_TYPE_Q8_0)\n                return NOT_SUPPORTED;\n#if defined(__AVX2__) || defined(__AVX512F__)\n            tinyBLAS_Q0_AVX2<0, block_q8_0, block_q8_0, TC> tb{\n                k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#elif defined(__ARM_FEATURE_DOTPROD)\n            tinyBLAS_Q0_ARM<0, block_q8_0, block_q8_0, TC> tb{\n                k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};\n      
      tb.matmul(m, n, task);\n            return true;\n#else\n            return NOT_SUPPORTED;\n#endif\n        }\n\n        case GGML_TYPE_Q4_0: {\n            if (Btype == GGML_TYPE_F32)\n                return WANT_QUANTIZATION;\n            if (Btype != GGML_TYPE_Q8_0)\n                return NOT_SUPPORTED;\n#if defined(__AVX2__) || defined(__AVX512F__)\n            tinyBLAS_Q0_AVX2<0, block_q4_0, block_q8_0, TC> tb{\n                k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#elif defined(__ARM_FEATURE_DOTPROD)\n            tinyBLAS_Q0_ARM<0, block_q4_0, block_q8_0, TC> tb{\n                k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#else\n            return NOT_SUPPORTED;\n#endif\n        }\n\n        default:\n            return NOT_SUPPORTED;\n    }\n\n    (void)m;\n    (void)n;\n    (void)k;\n    (void)A;\n    (void)lda;\n    (void)B;\n    (void)ldb;\n    (void)C;\n    (void)ldc;\n    (void)ith;\n    (void)nth;\n    (void)Atype;\n    (void)Btype;\n    (void)precision;\n}\n\n}  // namespace\n\n/**\n * Performs optimized matrix multiplication on CPU.\n *\n * This subroutine may compute C = Aᵀ * B with column major ordering.\n * Despite its name, this isn't a generalized implementation. 
Work is\n * only performed when a handwritten kernel is written and available.\n * Otherwise the caller should fall back to a general matmul routine.\n *\n * For example, for single-threaded single-precision GEMM you can say\n *\n *     llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, 0, 1,\n *                     GGML_TASK_TYPE_COMPUTE,\n *                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32,\n *                     GGML_PREC_DEFAULT);\n *\n * @param m is rows in `A` and `C`\n * @param n is cols in `B` and `C`\n * @param k is cols in `A` and rows in `B`\n * @param A is first input matrix (always transposed)\n * @param lda is row stride of `A`\n * @param B is second input matrix (never transposed)\n * @param ldb is row stride of `B`\n * @param C is input/output array of output matrices\n * @param ldc is row stride of `C`\n * @param ith is thread id (must be less than `nth`)\n * @param nth is number of threads (must be greater than zero)\n * @param task is GGML task type (e.g. `GGML_TASK_TYPE_COMPUTE`)\n * @param Atype is GGML data type of `A`\n * @param Btype is GGML data type of `B`\n * @param Ctype is GGML data type of `C`\n * @param precision may be used to control the internal compute type\n * @return true if this function was able to service the matmul request\n */\nbool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {\n    assert(m >= 0);\n    assert(n >= 0);\n    assert(k >= 0);\n    assert(lda >= k);\n    assert(ldb >= k);\n    assert(ldc >= m);\n    assert(nth > 0);\n    assert(ith < nth);\n\n#if QK_K == 256\n#if defined(__x86_64__) || defined(_M_X64)\n#if defined(__AVX2__) && (defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))))\n    // if (X86_CHECK(AVX2) && X86_CHECK(FMA)) {\n    if (Btype == GGML_TYPE_Q8_K && Ctype == GGML_TYPE_F32){\n        if (iqk_mul_mat(m, n, k * QK_K, Atype, A, B, (float*)C, ldc, ith, nth)) {\n            return true;\n        }\n    }\n    if ((Btype == 
GGML_TYPE_Q8_0 || Btype == GGML_TYPE_Q8_1) && Ctype == GGML_TYPE_F32) {\n        // assert(QK8_0 == QK8_1 == QK4_0 == QK4_1 == QK5_0 == QK5_1 == 32);\n        assert((QK8_0 == 32) && (QK8_1 == 32) && (QK4_0 == 32) && (QK4_1 == 32) && (QK5_0 == 32) && (QK5_1 == 32));\n        if (iqk_mul_mat(m, n, k * QK8_0, Atype, A, B, (float*)C, ldc, ith, nth)) {\n            return true;\n        }\n    }\n    // }\n#endif\n#elif defined __aarch64__ && defined __ARM_FEATURE_DOTPROD && !defined _MSC_VER\n    if (Btype == GGML_TYPE_Q8_K && Ctype == GGML_TYPE_F32) {\n        if (iqk_mul_mat(m, n, k * QK_K, Atype, A, k, Btype, B, k, (float*)C, ldc, ith, nth)) {\n            return true;\n        }\n    }\n    if ((Btype == GGML_TYPE_Q8_0 || Btype == GGML_TYPE_Q8_1) && Ctype == GGML_TYPE_F32) {\n        // assert(QK8_0 == QK8_1 == QK4_0 == QK4_1 == QK5_0 == QK5_1 == 32);\n        assert((QK8_0 == 32) && (QK8_1 == 32) && (QK4_0 == 32) && (QK4_1 == 32) && (QK5_0 == 32) && (QK5_1 == 32));\n        if (iqk_mul_mat(m, n, k * QK8_0, Atype, A, k, Btype, B, k, (float*)C, ldc, ith, nth)) {\n            return true;\n        }\n    }\n#endif\n#endif\n\n    switch (Ctype) {\n        case GGML_TYPE_F32:\n            return llamafile_sgemm_impl(m, n, k, A, lda, B, ldb, (float*)C, ldc, ith, nth, task, Atype,\n                                        Btype, Ctype, precision);\n        default:\n            return NOT_SUPPORTED;\n    }\n}\n"
  },
  {
    "path": "archive/third_party/llamafile/tinyblas_cpu_sgemm_arm80.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm_arm80.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#ifdef __aarch64__\n#define llamafile_sgemm llamafile_sgemm_arm80\n#include \"tinyblas_cpu_sgemm.inc\"\n#endif  // __aarch64__\n"
  },
  {
    "path": "archive/third_party/llamafile/tinyblas_cpu_sgemm_arm82.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm_arm82.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#ifdef __aarch64__\n#define llamafile_sgemm llamafile_sgemm_arm82\n#define iqk_mul_mat iqk_mul_mat_arm82\n#include \"tinyblas_cpu_sgemm.inc\"\n#endif  // __aarch64__\n"
  },
  {
    "path": "archive/third_party/llamafile/tinyblas_cpu_sgemm_x86.inc",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm.inc\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-\n// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi\n//\n// Copyright 2024 Mozilla Foundation\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include \"tinyblas_cpu.h\"\n\n//\n//\n//                                ██████╗ ██╗   █████╗ ██████╗\n//         ██████╗██╗██╗ ██╗██═██╗██╔══██╗██║  ██╔══██╗██╔═══╝\n//         ╚═██╔═╝██║███▄██║██ ██║██████╔╝██║  ███████║██████╗\n//           ██║  ██║██▀███║╚███╔╝██╔══██╗██║  ██╔══██║╔═══██║\n//           ██║  ██║██║ ██║ ███║ ██████╔╝████╗██║  ██║██████║\n//           ╚═╝  ╚═╝╚═╝ ╚═╝ ╚══╝ ╚═════╝ ╚═══╝╚═╝  ╚═╝╚═════╝\n//\n//                   BASIC LINEAR ALGEBRA SUBPROGRAMS\n//\n//\n// This file implements multithreaded CPU matrix multiplication for the\n// common contiguous use case C = Aᵀ * B. These kernels are designed to\n// have excellent performance[1] for matrices that fit in the CPU cache\n// without imposing any overhead such as cache filling or malloc calls.\n//\n// This implementation does not guarantee any upper bound with rounding\n// errors, which grow along with k. 
Our goal's to maximally exploit the\n// hardware for performance, and then use whatever resources remain for\n// improving numerical accuracy.\n//\n// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].\n//     Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].\n\nnamespace {\n\ntemplate <typename TC>\nbool llamafile_sgemm_impl(long m, long n, long k, const void* A, long lda, const void* B, long ldb, TC* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {\n    switch (Atype) {\n        case GGML_TYPE_F32: {\n            if (Btype != GGML_TYPE_F32)\n                return NOT_SUPPORTED;\n#if defined(__AVX512F__)\n            if (k % 16)\n                return NOT_SUPPORTED;\n            tinyBLAS<0, 16, __m512, __m512, float, float, TC> tb{\n                k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#elif defined(__AVX__) || defined(__AVX2__)\n            if (k % 8)\n                return NOT_SUPPORTED;\n            tinyBLAS<0, 8, __m256, __m256, float, float, TC> tb{\n                k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#elif defined(__ARM_NEON)\n            if (k % 4)\n                return NOT_SUPPORTED;\n            tinyBLAS<0, 4, float32x4_t, float32x4_t, float, float, TC> tb{\n                k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#else\n            return NOT_SUPPORTED;\n#endif\n        }\n\n        case GGML_TYPE_BF16: {\n#if defined(__AVX512BF16__)\n            if (k % 32)\n                return NOT_SUPPORTED;\n            if (Btype == GGML_TYPE_F32 && n < 2) {\n                tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{\n                    k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, 
nth};\n                tb.matmul(m, n, task);\n                return true;\n            }\n            if (Btype == GGML_TYPE_F32)\n                return WANT_QUANTIZATION;\n            if (Btype != GGML_TYPE_BF16)\n                return NOT_SUPPORTED;\n            if (!FLAG_precise) {\n                tinyBLAS<0, 32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, TC> tb{\n                    k, (const ggml_bf16_t*)A, lda, (const ggml_bf16_t*)B, ldb, C, ldc, ith, nth};\n                tb.matmul(m, n, task);\n                return true;\n            } else {\n                tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, TC> tb{\n                    k, (const ggml_bf16_t*)A, lda, (const ggml_bf16_t*)B, ldb, C, ldc, ith, nth};\n                tb.matmul(m, n, task);\n                return true;\n            }\n#elif defined(__AVX512F__)\n            if (k % 16)\n                return NOT_SUPPORTED;\n            tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{\n                k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#elif defined(__AVX2__)\n            if (k % 8)\n                return NOT_SUPPORTED;\n            if (Btype != GGML_TYPE_F32)\n                return NOT_SUPPORTED;\n            tinyBLAS<0, 8, __m256, __m256, ggml_bf16_t, float, TC> tb{\n                k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#elif defined(__ARM_NEON) && !defined(_MSC_VER)\n            if (k % 4)\n                return NOT_SUPPORTED;\n            if (Btype != GGML_TYPE_F32)\n                return NOT_SUPPORTED;\n            tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_bf16_t, float, TC> tb{\n                k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#else\n            return 
NOT_SUPPORTED;\n#endif\n        }\n\n        case GGML_TYPE_F16: {\n#if defined(__AVX512F__)\n            if (k % 16)\n                return NOT_SUPPORTED;\n            if (Btype == GGML_TYPE_F32 && n < 2) {\n                tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, float, TC> tb{\n                    k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};\n                tb.matmul(m, n, task);\n                return true;\n            }\n            if (Btype == GGML_TYPE_F32)\n                return WANT_QUANTIZATION;\n            if (Btype != GGML_TYPE_F16)\n                return NOT_SUPPORTED;\n            tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, TC> tb{\n                k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)\n            // if (X86_CHECK(F16C)) {\n            if (k % 8)\n                return NOT_SUPPORTED;\n            if (Btype == GGML_TYPE_F32 && n < 2) {\n                tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, float, TC> tb{\n                    k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};\n                tb.matmul(m, n, task);\n                return true;\n            }\n            if (Btype == GGML_TYPE_F32)\n                return WANT_QUANTIZATION;\n            if (Btype != GGML_TYPE_F16)\n                return NOT_SUPPORTED;\n            tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, TC> tb{\n                k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n            // } else {\n            //     return NOT_SUPPORTED;\n            // }\n#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)\n            if (n < 2 && !FLAG_precise)\n                // TODO(jart): Why is 
ggml_vec_dot_f16_unroll() so fast at matvec?\n                return NOT_SUPPORTED;\n            if (precision == GGML_PREC_F32) {\n                if (k % 4)\n                    return NOT_SUPPORTED;\n                if (Btype != GGML_TYPE_F32)\n                    return NOT_SUPPORTED;\n                tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{\n                    k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};\n                tb.matmul(m, n, task);\n                return true;\n            } else {\n                if (k % 8)\n                    return NOT_SUPPORTED;\n                if (Btype == GGML_TYPE_F32)\n                    return WANT_QUANTIZATION;\n                if (Btype != GGML_TYPE_F16)\n                    return NOT_SUPPORTED;\n                tinyBLAS<0, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, TC> tb{\n                    k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};\n                tb.matmul(m, n, task);\n                return true;\n            }\n#elif defined(__ARM_NEON) && !defined(_MSC_VER)\n            if (n < 2 && !FLAG_precise)\n                // TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?\n                return NOT_SUPPORTED;\n            if (k % 4)\n                return NOT_SUPPORTED;\n            if (Btype != GGML_TYPE_F32)\n                return NOT_SUPPORTED;\n            tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{\n                k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#else\n            return NOT_SUPPORTED;\n#endif\n        }\n\n        case GGML_TYPE_Q8_0: {\n            if (Btype == GGML_TYPE_F32)\n                return WANT_QUANTIZATION;\n            if (Btype != GGML_TYPE_Q8_0)\n                return NOT_SUPPORTED;\n#if defined(__AVX2__) || defined(__AVX512F__)\n            
tinyBLAS_Q0_AVX2<0, block_q8_0, block_q8_0, TC> tb{\n                k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#elif defined(__ARM_FEATURE_DOTPROD)\n            tinyBLAS_Q0_ARM<0, block_q8_0, block_q8_0, TC> tb{\n                k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#else\n            return NOT_SUPPORTED;\n#endif\n        }\n\n        case GGML_TYPE_Q4_0: {\n            if (Btype == GGML_TYPE_F32)\n                return WANT_QUANTIZATION;\n            if (Btype != GGML_TYPE_Q8_0)\n                return NOT_SUPPORTED;\n#if defined(__AVX2__) || defined(__AVX512F__)\n            tinyBLAS_Q0_AVX2<0, block_q4_0, block_q8_0, TC> tb{\n                k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#elif defined(__ARM_FEATURE_DOTPROD)\n            tinyBLAS_Q0_ARM<0, block_q4_0, block_q8_0, TC> tb{\n                k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#else\n            return NOT_SUPPORTED;\n#endif\n        }\n\n        default:\n            return NOT_SUPPORTED;\n    }\n\n    (void)m;\n    (void)n;\n    (void)k;\n    (void)A;\n    (void)lda;\n    (void)B;\n    (void)ldb;\n    (void)C;\n    (void)ldc;\n    (void)ith;\n    (void)nth;\n    (void)Atype;\n    (void)Btype;\n    (void)precision;\n}\n\n}  // namespace\n\n/**\n * Performs optimized matrix multiplication on CPU.\n *\n * This subroutine may compute C = Aᵀ * B with column major ordering.\n * Despite its name, this isn't a generalized implementation. 
Work is\n * only performed when a handwritten kernel is available.\n * Otherwise the caller should fall back to a general matmul routine.\n *\n * For example, for single-threaded single-precision GEMM you can say\n *\n *     llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, 0, 1, task,\n *                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32,\n *                     GGML_PREC_DEFAULT);\n *\n * @param m is rows in `A` and `C`\n * @param n is cols in `B` and `C`\n * @param k is cols in `A` and rows in `B`\n * @param A is first input matrix (always transposed)\n * @param lda is row stride of `A`\n * @param B is second input matrix (never transposed)\n * @param ldb is row stride of `B`\n * @param C is input/output array of output matrices\n * @param ldc is row stride of `C`\n * @param ith is thread id (must be less than `nth`)\n * @param nth is number of threads (must be greater than zero)\n * @param task is the task identifier forwarded to the selected kernel\n * @param Atype is GGML data type of `A`\n * @param Btype is GGML data type of `B`\n * @param Ctype is GGML data type of `C`\n * @param precision may be used to control the internal compute type\n * @return true if this function was able to service the matmul request\n */\nbool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {\n    assert(m >= 0);\n    assert(n >= 0);\n    assert(k >= 0);\n    assert(lda >= k);\n    assert(ldb >= k);\n    assert(ldc >= m);\n    assert(nth > 0);\n    assert(ith < nth);\n\n#if QK_K == 256\n#if defined(__x86_64__) || defined(_M_X64)\n#if defined(__AVX2__) && (defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))))\n    // moonll: accept more Btype values.\n\n    if (Ctype == GGML_TYPE_F32){\n        if (iqk_mul_mat(m, n, k * ggml_blck_size(ggml_type(Atype)), Atype, A,lda,Btype, B,ldb, (float*)C, ldc, ith, nth)) {\n            return true;\n        }\n    
}\n\n#endif\n#elif defined __aarch64__ && defined __ARM_FEATURE_DOTPROD && !defined _MSC_VER\n    if (Btype == GGML_TYPE_Q8_K && Ctype == GGML_TYPE_F32) {\n        if (iqk_mul_mat(m, n, k * QK_K, Atype, A, B, (float*)C, ldc, ith, nth)) {\n            return true;\n        }\n    }\n    if ((Btype == GGML_TYPE_Q8_0 || Btype == GGML_TYPE_Q8_1) && Ctype == GGML_TYPE_F32) {\n        // assert(QK8_0 == QK8_1 == QK4_0 == QK4_1 == QK5_0 == QK5_1 == 32);\n        assert((QK8_0 == 32) && (QK8_1 == 32) && (QK4_0 == 32) && (QK4_1 == 32) && (QK5_0 == 32) && (QK5_1 == 32));\n        if (iqk_mul_mat(m, n, k * QK8_0, Atype, A, B, (float*)C, ldc, ith, nth)) {\n            return true;\n        }\n    }\n#endif\n#endif\n\n    switch (Ctype) {\n        case GGML_TYPE_F32:\n            return llamafile_sgemm_impl(m, n, k, A, lda, B, ldb, (float*)C, ldc, ith, nth, task, Atype,\n                                        Btype, Ctype, precision);\n        default:\n            return NOT_SUPPORTED;\n    }\n}\n"
  },
  {
    "path": "archive/third_party/llamafile/tinyblas_cpu_unsupported.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_unsupported.cpp\n// Copyright 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-\n// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi\n//\n// Copyright 2024 Mozilla Foundation\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include \"sgemm.h\"\n\nbool llamafile_sgemm_unsupported(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {\n    return false;\n}\n\nbool llamafile_mixmul_unsupported(const struct ggml_compute_params* params,\n                                  const struct ggml_tensor* weights,\n                                  const struct ggml_tensor* thought,\n                                  const struct ggml_tensor* plan,\n                                  struct ggml_tensor* result) {\n    return false;\n}\n\nbool iqk_mul_mat_moe_unsupported(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int) {\n    return false;\n}\n"
  },
  {
    "path": "archive/third_party/nlohmann/json.hpp",
    "content": "//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n/****************************************************************************\\\n * Note on documentation: The source files contain links to the online      *\n * documentation of the public API at https://json.nlohmann.me. This URL    *\n * contains the most recent documentation and should also be applicable to  *\n * previous versions; documentation for deprecated functions is not         *\n * removed, but marked deprecated. See \"Generate documentation\" section in  *\n * file docs/README.md.                                                     *\n\\****************************************************************************/\n\n#ifndef INCLUDE_NLOHMANN_JSON_HPP_\n#define INCLUDE_NLOHMANN_JSON_HPP_\n\n#include <algorithm> // all_of, find, for_each\n#include <cstddef> // nullptr_t, ptrdiff_t, size_t\n#include <functional> // hash, less\n#include <initializer_list> // initializer_list\n#ifndef JSON_NO_IO\n    #include <iosfwd> // istream, ostream\n#endif  // JSON_NO_IO\n#include <iterator> // random_access_iterator_tag\n#include <memory> // unique_ptr\n#include <string> // string, stoi, to_string\n#include <utility> // declval, forward, move, pair, swap\n#include <vector> // vector\n\n// #include <nlohmann/adl_serializer.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#include <utility>\n\n// #include <nlohmann/detail/abi_macros.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  
JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n// This file contains all macro definitions affecting or depending on the ABI\n\n#ifndef JSON_SKIP_LIBRARY_VERSION_CHECK\n    #if defined(NLOHMANN_JSON_VERSION_MAJOR) && defined(NLOHMANN_JSON_VERSION_MINOR) && defined(NLOHMANN_JSON_VERSION_PATCH)\n        #if NLOHMANN_JSON_VERSION_MAJOR != 3 || NLOHMANN_JSON_VERSION_MINOR != 11 || NLOHMANN_JSON_VERSION_PATCH != 3\n            #warning \"Already included a different version of the library!\"\n        #endif\n    #endif\n#endif\n\n#define NLOHMANN_JSON_VERSION_MAJOR 3   // NOLINT(modernize-macro-to-enum)\n#define NLOHMANN_JSON_VERSION_MINOR 11  // NOLINT(modernize-macro-to-enum)\n#define NLOHMANN_JSON_VERSION_PATCH 3   // NOLINT(modernize-macro-to-enum)\n\n#ifndef JSON_DIAGNOSTICS\n    #define JSON_DIAGNOSTICS 0\n#endif\n\n#ifndef JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON\n    #define JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON 0\n#endif\n\n#if JSON_DIAGNOSTICS\n    #define NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS _diag\n#else\n    #define NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS\n#endif\n\n#if JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON\n    #define NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON _ldvcmp\n#else\n    #define NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON\n#endif\n\n#ifndef NLOHMANN_JSON_NAMESPACE_NO_VERSION\n    #define NLOHMANN_JSON_NAMESPACE_NO_VERSION 0\n#endif\n\n// Construct the namespace ABI tags component\n#define NLOHMANN_JSON_ABI_TAGS_CONCAT_EX(a, b) json_abi ## a ## b\n#define NLOHMANN_JSON_ABI_TAGS_CONCAT(a, b) \\\n    NLOHMANN_JSON_ABI_TAGS_CONCAT_EX(a, b)\n\n#define NLOHMANN_JSON_ABI_TAGS                                       \\\n    NLOHMANN_JSON_ABI_TAGS_CONCAT(                                   \\\n            
NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS,                       \\\n            NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON)\n\n// Construct the namespace version component\n#define NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT_EX(major, minor, patch) \\\n    _v ## major ## _ ## minor ## _ ## patch\n#define NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT(major, minor, patch) \\\n    NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT_EX(major, minor, patch)\n\n#if NLOHMANN_JSON_NAMESPACE_NO_VERSION\n#define NLOHMANN_JSON_NAMESPACE_VERSION\n#else\n#define NLOHMANN_JSON_NAMESPACE_VERSION                                 \\\n    NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT(NLOHMANN_JSON_VERSION_MAJOR, \\\n                                           NLOHMANN_JSON_VERSION_MINOR, \\\n                                           NLOHMANN_JSON_VERSION_PATCH)\n#endif\n\n// Combine namespace components\n#define NLOHMANN_JSON_NAMESPACE_CONCAT_EX(a, b) a ## b\n#define NLOHMANN_JSON_NAMESPACE_CONCAT(a, b) \\\n    NLOHMANN_JSON_NAMESPACE_CONCAT_EX(a, b)\n\n#ifndef NLOHMANN_JSON_NAMESPACE\n#define NLOHMANN_JSON_NAMESPACE               \\\n    nlohmann::NLOHMANN_JSON_NAMESPACE_CONCAT( \\\n            NLOHMANN_JSON_ABI_TAGS,           \\\n            NLOHMANN_JSON_NAMESPACE_VERSION)\n#endif\n\n#ifndef NLOHMANN_JSON_NAMESPACE_BEGIN\n#define NLOHMANN_JSON_NAMESPACE_BEGIN                \\\n    namespace nlohmann                               \\\n    {                                                \\\n    inline namespace NLOHMANN_JSON_NAMESPACE_CONCAT( \\\n                NLOHMANN_JSON_ABI_TAGS,              \\\n                NLOHMANN_JSON_NAMESPACE_VERSION)     \\\n    {\n#endif\n\n#ifndef NLOHMANN_JSON_NAMESPACE_END\n#define NLOHMANN_JSON_NAMESPACE_END                                     \\\n    }  /* namespace (inline namespace) NOLINT(readability/namespace) */ \\\n    }  // namespace nlohmann\n#endif\n\n// #include <nlohmann/detail/conversions/from_json.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     | 
  | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#include <algorithm> // transform\n#include <array> // array\n#include <forward_list> // forward_list\n#include <iterator> // inserter, front_inserter, end\n#include <map> // map\n#include <string> // string\n#include <tuple> // tuple, make_tuple\n#include <type_traits> // is_arithmetic, is_same, is_enum, underlying_type, is_convertible\n#include <unordered_map> // unordered_map\n#include <utility> // pair, declval\n#include <valarray> // valarray\n\n// #include <nlohmann/detail/exceptions.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#include <cstddef> // nullptr_t\n#include <exception> // exception\n#if JSON_DIAGNOSTICS\n    #include <numeric> // accumulate\n#endif\n#include <stdexcept> // runtime_error\n#include <string> // to_string\n#include <vector> // vector\n\n// #include <nlohmann/detail/value_t.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#include <array> // array\n#include <cstddef> // size_t\n#include <cstdint> // uint8_t\n#include <string> // string\n\n// #include <nlohmann/detail/macro_scope.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  
https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#include <utility> // declval, pair\n// #include <nlohmann/detail/meta/detected.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#include <type_traits>\n\n// #include <nlohmann/detail/meta/void_t.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n// #include <nlohmann/detail/abi_macros.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\n\ntemplate<typename ...Ts> struct make_void\n{\n    using type = void;\n};\ntemplate<typename ...Ts> using void_t = typename make_void<Ts...>::type;\n\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\n\n// https://en.cppreference.com/w/cpp/experimental/is_detected\nstruct nonesuch\n{\n    nonesuch() = delete;\n    ~nonesuch() = delete;\n    nonesuch(nonesuch const&) = delete;\n    nonesuch(nonesuch const&&) = delete;\n    void operator=(nonesuch const&) = delete;\n    void operator=(nonesuch&&) = delete;\n};\n\ntemplate<class Default,\n         class AlwaysVoid,\n         template<class...> class Op,\n         class... Args>\nstruct detector\n{\n    using value_t = std::false_type;\n    using type = Default;\n};\n\ntemplate<class Default, template<class...> class Op, class... 
Args>\nstruct detector<Default, void_t<Op<Args...>>, Op, Args...>\n{\n    using value_t = std::true_type;\n    using type = Op<Args...>;\n};\n\ntemplate<template<class...> class Op, class... Args>\nusing is_detected = typename detector<nonesuch, void, Op, Args...>::value_t;\n\ntemplate<template<class...> class Op, class... Args>\nstruct is_detected_lazy : is_detected<Op, Args...> { };\n\ntemplate<template<class...> class Op, class... Args>\nusing detected_t = typename detector<nonesuch, void, Op, Args...>::type;\n\ntemplate<class Default, template<class...> class Op, class... Args>\nusing detected_or = detector<Default, void, Op, Args...>;\n\ntemplate<class Default, template<class...> class Op, class... Args>\nusing detected_or_t = typename detected_or<Default, Op, Args...>::type;\n\ntemplate<class Expected, template<class...> class Op, class... Args>\nusing is_detected_exact = std::is_same<Expected, detected_t<Op, Args...>>;\n\ntemplate<class To, template<class...> class Op, class... Args>\nusing is_detected_convertible =\n    std::is_convertible<detected_t<Op, Args...>, To>;\n\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include <nlohmann/thirdparty/hedley/hedley.hpp>\n\n\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-FileCopyrightText: 2016-2021 Evan Nemerson <evan@nemerson.com>\n// SPDX-License-Identifier: MIT\n\n/* Hedley - https://nemequ.github.io/hedley\n * Created by Evan Nemerson <evan@nemerson.com>\n */\n\n#if !defined(JSON_HEDLEY_VERSION) || (JSON_HEDLEY_VERSION < 15)\n#if defined(JSON_HEDLEY_VERSION)\n    #undef JSON_HEDLEY_VERSION\n#endif\n#define JSON_HEDLEY_VERSION 15\n\n#if defined(JSON_HEDLEY_STRINGIFY_EX)\n    #undef JSON_HEDLEY_STRINGIFY_EX\n#endif\n#define JSON_HEDLEY_STRINGIFY_EX(x) #x\n\n#if 
defined(JSON_HEDLEY_STRINGIFY)\n    #undef JSON_HEDLEY_STRINGIFY\n#endif\n#define JSON_HEDLEY_STRINGIFY(x) JSON_HEDLEY_STRINGIFY_EX(x)\n\n#if defined(JSON_HEDLEY_CONCAT_EX)\n    #undef JSON_HEDLEY_CONCAT_EX\n#endif\n#define JSON_HEDLEY_CONCAT_EX(a,b) a##b\n\n#if defined(JSON_HEDLEY_CONCAT)\n    #undef JSON_HEDLEY_CONCAT\n#endif\n#define JSON_HEDLEY_CONCAT(a,b) JSON_HEDLEY_CONCAT_EX(a,b)\n\n#if defined(JSON_HEDLEY_CONCAT3_EX)\n    #undef JSON_HEDLEY_CONCAT3_EX\n#endif\n#define JSON_HEDLEY_CONCAT3_EX(a,b,c) a##b##c\n\n#if defined(JSON_HEDLEY_CONCAT3)\n    #undef JSON_HEDLEY_CONCAT3\n#endif\n#define JSON_HEDLEY_CONCAT3(a,b,c) JSON_HEDLEY_CONCAT3_EX(a,b,c)\n\n#if defined(JSON_HEDLEY_VERSION_ENCODE)\n    #undef JSON_HEDLEY_VERSION_ENCODE\n#endif\n#define JSON_HEDLEY_VERSION_ENCODE(major,minor,revision) (((major) * 1000000) + ((minor) * 1000) + (revision))\n\n#if defined(JSON_HEDLEY_VERSION_DECODE_MAJOR)\n    #undef JSON_HEDLEY_VERSION_DECODE_MAJOR\n#endif\n#define JSON_HEDLEY_VERSION_DECODE_MAJOR(version) ((version) / 1000000)\n\n#if defined(JSON_HEDLEY_VERSION_DECODE_MINOR)\n    #undef JSON_HEDLEY_VERSION_DECODE_MINOR\n#endif\n#define JSON_HEDLEY_VERSION_DECODE_MINOR(version) (((version) % 1000000) / 1000)\n\n#if defined(JSON_HEDLEY_VERSION_DECODE_REVISION)\n    #undef JSON_HEDLEY_VERSION_DECODE_REVISION\n#endif\n#define JSON_HEDLEY_VERSION_DECODE_REVISION(version) ((version) % 1000)\n\n#if defined(JSON_HEDLEY_GNUC_VERSION)\n    #undef JSON_HEDLEY_GNUC_VERSION\n#endif\n#if defined(__GNUC__) && defined(__GNUC_PATCHLEVEL__)\n    #define JSON_HEDLEY_GNUC_VERSION JSON_HEDLEY_VERSION_ENCODE(__GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__)\n#elif defined(__GNUC__)\n    #define JSON_HEDLEY_GNUC_VERSION JSON_HEDLEY_VERSION_ENCODE(__GNUC__, __GNUC_MINOR__, 0)\n#endif\n\n#if defined(JSON_HEDLEY_GNUC_VERSION_CHECK)\n    #undef JSON_HEDLEY_GNUC_VERSION_CHECK\n#endif\n#if defined(JSON_HEDLEY_GNUC_VERSION)\n    #define JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) 
(JSON_HEDLEY_GNUC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))\n#else\n    #define JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) (0)\n#endif\n\n#if defined(JSON_HEDLEY_MSVC_VERSION)\n    #undef JSON_HEDLEY_MSVC_VERSION\n#endif\n#if defined(_MSC_FULL_VER) && (_MSC_FULL_VER >= 140000000) && !defined(__ICL)\n    #define JSON_HEDLEY_MSVC_VERSION JSON_HEDLEY_VERSION_ENCODE(_MSC_FULL_VER / 10000000, (_MSC_FULL_VER % 10000000) / 100000, (_MSC_FULL_VER % 100000) / 100)\n#elif defined(_MSC_FULL_VER) && !defined(__ICL)\n    #define JSON_HEDLEY_MSVC_VERSION JSON_HEDLEY_VERSION_ENCODE(_MSC_FULL_VER / 1000000, (_MSC_FULL_VER % 1000000) / 10000, (_MSC_FULL_VER % 10000) / 10)\n#elif defined(_MSC_VER) && !defined(__ICL)\n    #define JSON_HEDLEY_MSVC_VERSION JSON_HEDLEY_VERSION_ENCODE(_MSC_VER / 100, _MSC_VER % 100, 0)\n#endif\n\n#if defined(JSON_HEDLEY_MSVC_VERSION_CHECK)\n    #undef JSON_HEDLEY_MSVC_VERSION_CHECK\n#endif\n#if !defined(JSON_HEDLEY_MSVC_VERSION)\n    #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (0)\n#elif defined(_MSC_VER) && (_MSC_VER >= 1400)\n    #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (_MSC_FULL_VER >= ((major * 10000000) + (minor * 100000) + (patch)))\n#elif defined(_MSC_VER) && (_MSC_VER >= 1200)\n    #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (_MSC_FULL_VER >= ((major * 1000000) + (minor * 10000) + (patch)))\n#else\n    #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (_MSC_VER >= ((major * 100) + (minor)))\n#endif\n\n#if defined(JSON_HEDLEY_INTEL_VERSION)\n    #undef JSON_HEDLEY_INTEL_VERSION\n#endif\n#if defined(__INTEL_COMPILER) && defined(__INTEL_COMPILER_UPDATE) && !defined(__ICL)\n    #define JSON_HEDLEY_INTEL_VERSION JSON_HEDLEY_VERSION_ENCODE(__INTEL_COMPILER / 100, __INTEL_COMPILER % 100, __INTEL_COMPILER_UPDATE)\n#elif defined(__INTEL_COMPILER) && !defined(__ICL)\n    #define JSON_HEDLEY_INTEL_VERSION JSON_HEDLEY_VERSION_ENCODE(__INTEL_COMPILER / 100, __INTEL_COMPILER % 100, 
0)\n#endif\n\n#if defined(JSON_HEDLEY_INTEL_VERSION_CHECK)\n    #undef JSON_HEDLEY_INTEL_VERSION_CHECK\n#endif\n#if defined(JSON_HEDLEY_INTEL_VERSION)\n    #define JSON_HEDLEY_INTEL_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_INTEL_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))\n#else\n    #define JSON_HEDLEY_INTEL_VERSION_CHECK(major,minor,patch) (0)\n#endif\n\n#if defined(JSON_HEDLEY_INTEL_CL_VERSION)\n    #undef JSON_HEDLEY_INTEL_CL_VERSION\n#endif\n#if defined(__INTEL_COMPILER) && defined(__INTEL_COMPILER_UPDATE) && defined(__ICL)\n    #define JSON_HEDLEY_INTEL_CL_VERSION JSON_HEDLEY_VERSION_ENCODE(__INTEL_COMPILER, __INTEL_COMPILER_UPDATE, 0)\n#endif\n\n#if defined(JSON_HEDLEY_INTEL_CL_VERSION_CHECK)\n    #undef JSON_HEDLEY_INTEL_CL_VERSION_CHECK\n#endif\n#if defined(JSON_HEDLEY_INTEL_CL_VERSION)\n    #define JSON_HEDLEY_INTEL_CL_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_INTEL_CL_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))\n#else\n    #define JSON_HEDLEY_INTEL_CL_VERSION_CHECK(major,minor,patch) (0)\n#endif\n\n#if defined(JSON_HEDLEY_PGI_VERSION)\n    #undef JSON_HEDLEY_PGI_VERSION\n#endif\n#if defined(__PGI) && defined(__PGIC__) && defined(__PGIC_MINOR__) && defined(__PGIC_PATCHLEVEL__)\n    #define JSON_HEDLEY_PGI_VERSION JSON_HEDLEY_VERSION_ENCODE(__PGIC__, __PGIC_MINOR__, __PGIC_PATCHLEVEL__)\n#endif\n\n#if defined(JSON_HEDLEY_PGI_VERSION_CHECK)\n    #undef JSON_HEDLEY_PGI_VERSION_CHECK\n#endif\n#if defined(JSON_HEDLEY_PGI_VERSION)\n    #define JSON_HEDLEY_PGI_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_PGI_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))\n#else\n    #define JSON_HEDLEY_PGI_VERSION_CHECK(major,minor,patch) (0)\n#endif\n\n#if defined(JSON_HEDLEY_SUNPRO_VERSION)\n    #undef JSON_HEDLEY_SUNPRO_VERSION\n#endif\n#if defined(__SUNPRO_C) && (__SUNPRO_C > 0x1000)\n    #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((((__SUNPRO_C >> 16) & 0xf) * 10) + ((__SUNPRO_C >> 12) & 0xf), 
(((__SUNPRO_C >> 8) & 0xf) * 10) + ((__SUNPRO_C >> 4) & 0xf), (__SUNPRO_C & 0xf) * 10)\n#elif defined(__SUNPRO_C)\n    #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((__SUNPRO_C >> 8) & 0xf, (__SUNPRO_C >> 4) & 0xf, (__SUNPRO_C) & 0xf)\n#elif defined(__SUNPRO_CC) && (__SUNPRO_CC > 0x1000)\n    #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((((__SUNPRO_CC >> 16) & 0xf) * 10) + ((__SUNPRO_CC >> 12) & 0xf), (((__SUNPRO_CC >> 8) & 0xf) * 10) + ((__SUNPRO_CC >> 4) & 0xf), (__SUNPRO_CC & 0xf) * 10)\n#elif defined(__SUNPRO_CC)\n    #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((__SUNPRO_CC >> 8) & 0xf, (__SUNPRO_CC >> 4) & 0xf, (__SUNPRO_CC) & 0xf)\n#endif\n\n#if defined(JSON_HEDLEY_SUNPRO_VERSION_CHECK)\n    #undef JSON_HEDLEY_SUNPRO_VERSION_CHECK\n#endif\n#if defined(JSON_HEDLEY_SUNPRO_VERSION)\n    #define JSON_HEDLEY_SUNPRO_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_SUNPRO_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))\n#else\n    #define JSON_HEDLEY_SUNPRO_VERSION_CHECK(major,minor,patch) (0)\n#endif\n\n#if defined(JSON_HEDLEY_EMSCRIPTEN_VERSION)\n    #undef JSON_HEDLEY_EMSCRIPTEN_VERSION\n#endif\n#if defined(__EMSCRIPTEN__)\n    #define JSON_HEDLEY_EMSCRIPTEN_VERSION JSON_HEDLEY_VERSION_ENCODE(__EMSCRIPTEN_major__, __EMSCRIPTEN_minor__, __EMSCRIPTEN_tiny__)\n#endif\n\n#if defined(JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK)\n    #undef JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK\n#endif\n#if defined(JSON_HEDLEY_EMSCRIPTEN_VERSION)\n    #define JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_EMSCRIPTEN_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))\n#else\n    #define JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK(major,minor,patch) (0)\n#endif\n\n#if defined(JSON_HEDLEY_ARM_VERSION)\n    #undef JSON_HEDLEY_ARM_VERSION\n#endif\n#if defined(__CC_ARM) && defined(__ARMCOMPILER_VERSION)\n    #define JSON_HEDLEY_ARM_VERSION JSON_HEDLEY_VERSION_ENCODE(__ARMCOMPILER_VERSION / 1000000, 
(__ARMCOMPILER_VERSION % 1000000) / 10000, (__ARMCOMPILER_VERSION % 10000) / 100)\n#elif defined(__CC_ARM) && defined(__ARMCC_VERSION)\n    #define JSON_HEDLEY_ARM_VERSION JSON_HEDLEY_VERSION_ENCODE(__ARMCC_VERSION / 1000000, (__ARMCC_VERSION % 1000000) / 10000, (__ARMCC_VERSION % 10000) / 100)\n#endif\n\n#if defined(JSON_HEDLEY_ARM_VERSION_CHECK)\n    #undef JSON_HEDLEY_ARM_VERSION_CHECK\n#endif\n#if defined(JSON_HEDLEY_ARM_VERSION)\n    #define JSON_HEDLEY_ARM_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_ARM_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))\n#else\n    #define JSON_HEDLEY_ARM_VERSION_CHECK(major,minor,patch) (0)\n#endif\n\n#if defined(JSON_HEDLEY_IBM_VERSION)\n    #undef JSON_HEDLEY_IBM_VERSION\n#endif\n#if defined(__ibmxl__)\n    #define JSON_HEDLEY_IBM_VERSION JSON_HEDLEY_VERSION_ENCODE(__ibmxl_version__, __ibmxl_release__, __ibmxl_modification__)\n#elif defined(__xlC__) && defined(__xlC_ver__)\n    #define JSON_HEDLEY_IBM_VERSION JSON_HEDLEY_VERSION_ENCODE(__xlC__ >> 8, __xlC__ & 0xff, (__xlC_ver__ >> 8) & 0xff)\n#elif defined(__xlC__)\n    #define JSON_HEDLEY_IBM_VERSION JSON_HEDLEY_VERSION_ENCODE(__xlC__ >> 8, __xlC__ & 0xff, 0)\n#endif\n\n#if defined(JSON_HEDLEY_IBM_VERSION_CHECK)\n    #undef JSON_HEDLEY_IBM_VERSION_CHECK\n#endif\n#if defined(JSON_HEDLEY_IBM_VERSION)\n    #define JSON_HEDLEY_IBM_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_IBM_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))\n#else\n    #define JSON_HEDLEY_IBM_VERSION_CHECK(major,minor,patch) (0)\n#endif\n\n#if defined(JSON_HEDLEY_TI_VERSION)\n    #undef JSON_HEDLEY_TI_VERSION\n#endif\n#if \\\n    defined(__TI_COMPILER_VERSION__) && \\\n    ( \\\n      defined(__TMS470__) || defined(__TI_ARM__) || \\\n      defined(__MSP430__) || \\\n      defined(__TMS320C2000__) \\\n    )\n#if (__TI_COMPILER_VERSION__ >= 16000000)\n    #define JSON_HEDLEY_TI_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 
1000, (__TI_COMPILER_VERSION__ % 1000))\n#endif\n#endif\n\n#if defined(JSON_HEDLEY_TI_VERSION_CHECK)\n    #undef JSON_HEDLEY_TI_VERSION_CHECK\n#endif\n#if defined(JSON_HEDLEY_TI_VERSION)\n    #define JSON_HEDLEY_TI_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))\n#else\n    #define JSON_HEDLEY_TI_VERSION_CHECK(major,minor,patch) (0)\n#endif\n\n#if defined(JSON_HEDLEY_TI_CL2000_VERSION)\n    #undef JSON_HEDLEY_TI_CL2000_VERSION\n#endif\n#if defined(__TI_COMPILER_VERSION__) && defined(__TMS320C2000__)\n    #define JSON_HEDLEY_TI_CL2000_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000))\n#endif\n\n#if defined(JSON_HEDLEY_TI_CL2000_VERSION_CHECK)\n    #undef JSON_HEDLEY_TI_CL2000_VERSION_CHECK\n#endif\n#if defined(JSON_HEDLEY_TI_CL2000_VERSION)\n    #define JSON_HEDLEY_TI_CL2000_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL2000_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))\n#else\n    #define JSON_HEDLEY_TI_CL2000_VERSION_CHECK(major,minor,patch) (0)\n#endif\n\n#if defined(JSON_HEDLEY_TI_CL430_VERSION)\n    #undef JSON_HEDLEY_TI_CL430_VERSION\n#endif\n#if defined(__TI_COMPILER_VERSION__) && defined(__MSP430__)\n    #define JSON_HEDLEY_TI_CL430_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000))\n#endif\n\n#if defined(JSON_HEDLEY_TI_CL430_VERSION_CHECK)\n    #undef JSON_HEDLEY_TI_CL430_VERSION_CHECK\n#endif\n#if defined(JSON_HEDLEY_TI_CL430_VERSION)\n    #define JSON_HEDLEY_TI_CL430_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL430_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))\n#else\n    #define JSON_HEDLEY_TI_CL430_VERSION_CHECK(major,minor,patch) (0)\n#endif\n\n#if defined(JSON_HEDLEY_TI_ARMCL_VERSION)\n    #undef JSON_HEDLEY_TI_ARMCL_VERSION\n#endif\n#if 
defined(__TI_COMPILER_VERSION__) && (defined(__TMS470__) || defined(__TI_ARM__))\n    #define JSON_HEDLEY_TI_ARMCL_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000))\n#endif\n\n#if defined(JSON_HEDLEY_TI_ARMCL_VERSION_CHECK)\n    #undef JSON_HEDLEY_TI_ARMCL_VERSION_CHECK\n#endif\n#if defined(JSON_HEDLEY_TI_ARMCL_VERSION)\n    #define JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_ARMCL_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))\n#else\n    #define JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(major,minor,patch) (0)\n#endif\n\n#if defined(JSON_HEDLEY_TI_CL6X_VERSION)\n    #undef JSON_HEDLEY_TI_CL6X_VERSION\n#endif\n#if defined(__TI_COMPILER_VERSION__) && defined(__TMS320C6X__)\n    #define JSON_HEDLEY_TI_CL6X_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000))\n#endif\n\n#if defined(JSON_HEDLEY_TI_CL6X_VERSION_CHECK)\n    #undef JSON_HEDLEY_TI_CL6X_VERSION_CHECK\n#endif\n#if defined(JSON_HEDLEY_TI_CL6X_VERSION)\n    #define JSON_HEDLEY_TI_CL6X_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL6X_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))\n#else\n    #define JSON_HEDLEY_TI_CL6X_VERSION_CHECK(major,minor,patch) (0)\n#endif\n\n#if defined(JSON_HEDLEY_TI_CL7X_VERSION)\n    #undef JSON_HEDLEY_TI_CL7X_VERSION\n#endif\n#if defined(__TI_COMPILER_VERSION__) && defined(__C7000__)\n    #define JSON_HEDLEY_TI_CL7X_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000))\n#endif\n\n#if defined(JSON_HEDLEY_TI_CL7X_VERSION_CHECK)\n    #undef JSON_HEDLEY_TI_CL7X_VERSION_CHECK\n#endif\n#if defined(JSON_HEDLEY_TI_CL7X_VERSION)\n    #define JSON_HEDLEY_TI_CL7X_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL7X_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, 
minor, patch))\n#else\n    #define JSON_HEDLEY_TI_CL7X_VERSION_CHECK(major,minor,patch) (0)\n#endif\n\n#if defined(JSON_HEDLEY_TI_CLPRU_VERSION)\n    #undef JSON_HEDLEY_TI_CLPRU_VERSION\n#endif\n#if defined(__TI_COMPILER_VERSION__) && defined(__PRU__)\n    #define JSON_HEDLEY_TI_CLPRU_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000))\n#endif\n\n#if defined(JSON_HEDLEY_TI_CLPRU_VERSION_CHECK)\n    #undef JSON_HEDLEY_TI_CLPRU_VERSION_CHECK\n#endif\n#if defined(JSON_HEDLEY_TI_CLPRU_VERSION)\n    #define JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CLPRU_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))\n#else\n    #define JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(major,minor,patch) (0)\n#endif\n\n#if defined(JSON_HEDLEY_CRAY_VERSION)\n    #undef JSON_HEDLEY_CRAY_VERSION\n#endif\n#if defined(_CRAYC)\n    #if defined(_RELEASE_PATCHLEVEL)\n        #define JSON_HEDLEY_CRAY_VERSION JSON_HEDLEY_VERSION_ENCODE(_RELEASE_MAJOR, _RELEASE_MINOR, _RELEASE_PATCHLEVEL)\n    #else\n        #define JSON_HEDLEY_CRAY_VERSION JSON_HEDLEY_VERSION_ENCODE(_RELEASE_MAJOR, _RELEASE_MINOR, 0)\n    #endif\n#endif\n\n#if defined(JSON_HEDLEY_CRAY_VERSION_CHECK)\n    #undef JSON_HEDLEY_CRAY_VERSION_CHECK\n#endif\n#if defined(JSON_HEDLEY_CRAY_VERSION)\n    #define JSON_HEDLEY_CRAY_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_CRAY_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))\n#else\n    #define JSON_HEDLEY_CRAY_VERSION_CHECK(major,minor,patch) (0)\n#endif\n\n#if defined(JSON_HEDLEY_IAR_VERSION)\n    #undef JSON_HEDLEY_IAR_VERSION\n#endif\n#if defined(__IAR_SYSTEMS_ICC__)\n    #if __VER__ > 1000\n        #define JSON_HEDLEY_IAR_VERSION JSON_HEDLEY_VERSION_ENCODE((__VER__ / 1000000), ((__VER__ / 1000) % 1000), (__VER__ % 1000))\n    #else\n        #define JSON_HEDLEY_IAR_VERSION JSON_HEDLEY_VERSION_ENCODE(__VER__ / 100, __VER__ % 100, 0)\n    
#endif\n#endif\n\n#if defined(JSON_HEDLEY_IAR_VERSION_CHECK)\n    #undef JSON_HEDLEY_IAR_VERSION_CHECK\n#endif\n#if defined(JSON_HEDLEY_IAR_VERSION)\n    #define JSON_HEDLEY_IAR_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_IAR_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))\n#else\n    #define JSON_HEDLEY_IAR_VERSION_CHECK(major,minor,patch) (0)\n#endif\n\n#if defined(JSON_HEDLEY_TINYC_VERSION)\n    #undef JSON_HEDLEY_TINYC_VERSION\n#endif\n#if defined(__TINYC__)\n    #define JSON_HEDLEY_TINYC_VERSION JSON_HEDLEY_VERSION_ENCODE(__TINYC__ / 1000, (__TINYC__ / 100) % 10, __TINYC__ % 100)\n#endif\n\n#if defined(JSON_HEDLEY_TINYC_VERSION_CHECK)\n    #undef JSON_HEDLEY_TINYC_VERSION_CHECK\n#endif\n#if defined(JSON_HEDLEY_TINYC_VERSION)\n    #define JSON_HEDLEY_TINYC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TINYC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))\n#else\n    #define JSON_HEDLEY_TINYC_VERSION_CHECK(major,minor,patch) (0)\n#endif\n\n#if defined(JSON_HEDLEY_DMC_VERSION)\n    #undef JSON_HEDLEY_DMC_VERSION\n#endif\n#if defined(__DMC__)\n    #define JSON_HEDLEY_DMC_VERSION JSON_HEDLEY_VERSION_ENCODE(__DMC__ >> 8, (__DMC__ >> 4) & 0xf, __DMC__ & 0xf)\n#endif\n\n#if defined(JSON_HEDLEY_DMC_VERSION_CHECK)\n    #undef JSON_HEDLEY_DMC_VERSION_CHECK\n#endif\n#if defined(JSON_HEDLEY_DMC_VERSION)\n    #define JSON_HEDLEY_DMC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_DMC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))\n#else\n    #define JSON_HEDLEY_DMC_VERSION_CHECK(major,minor,patch) (0)\n#endif\n\n#if defined(JSON_HEDLEY_COMPCERT_VERSION)\n    #undef JSON_HEDLEY_COMPCERT_VERSION\n#endif\n#if defined(__COMPCERT_VERSION__)\n    #define JSON_HEDLEY_COMPCERT_VERSION JSON_HEDLEY_VERSION_ENCODE(__COMPCERT_VERSION__ / 10000, (__COMPCERT_VERSION__ / 100) % 100, __COMPCERT_VERSION__ % 100)\n#endif\n\n#if defined(JSON_HEDLEY_COMPCERT_VERSION_CHECK)\n    #undef JSON_HEDLEY_COMPCERT_VERSION_CHECK\n#endif\n#if 
defined(JSON_HEDLEY_COMPCERT_VERSION)\n    #define JSON_HEDLEY_COMPCERT_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_COMPCERT_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))\n#else\n    #define JSON_HEDLEY_COMPCERT_VERSION_CHECK(major,minor,patch) (0)\n#endif\n\n#if defined(JSON_HEDLEY_PELLES_VERSION)\n    #undef JSON_HEDLEY_PELLES_VERSION\n#endif\n#if defined(__POCC__)\n    #define JSON_HEDLEY_PELLES_VERSION JSON_HEDLEY_VERSION_ENCODE(__POCC__ / 100, __POCC__ % 100, 0)\n#endif\n\n#if defined(JSON_HEDLEY_PELLES_VERSION_CHECK)\n    #undef JSON_HEDLEY_PELLES_VERSION_CHECK\n#endif\n#if defined(JSON_HEDLEY_PELLES_VERSION)\n    #define JSON_HEDLEY_PELLES_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_PELLES_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))\n#else\n    #define JSON_HEDLEY_PELLES_VERSION_CHECK(major,minor,patch) (0)\n#endif\n\n#if defined(JSON_HEDLEY_MCST_LCC_VERSION)\n    #undef JSON_HEDLEY_MCST_LCC_VERSION\n#endif\n#if defined(__LCC__) && defined(__LCC_MINOR__)\n    #define JSON_HEDLEY_MCST_LCC_VERSION JSON_HEDLEY_VERSION_ENCODE(__LCC__ / 100, __LCC__ % 100, __LCC_MINOR__)\n#endif\n\n#if defined(JSON_HEDLEY_MCST_LCC_VERSION_CHECK)\n    #undef JSON_HEDLEY_MCST_LCC_VERSION_CHECK\n#endif\n#if defined(JSON_HEDLEY_MCST_LCC_VERSION)\n    #define JSON_HEDLEY_MCST_LCC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_MCST_LCC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))\n#else\n    #define JSON_HEDLEY_MCST_LCC_VERSION_CHECK(major,minor,patch) (0)\n#endif\n\n#if defined(JSON_HEDLEY_GCC_VERSION)\n    #undef JSON_HEDLEY_GCC_VERSION\n#endif\n#if \\\n    defined(JSON_HEDLEY_GNUC_VERSION) && \\\n    !defined(__clang__) && \\\n    !defined(JSON_HEDLEY_INTEL_VERSION) && \\\n    !defined(JSON_HEDLEY_PGI_VERSION) && \\\n    !defined(JSON_HEDLEY_ARM_VERSION) && \\\n    !defined(JSON_HEDLEY_CRAY_VERSION) && \\\n    !defined(JSON_HEDLEY_TI_VERSION) && \\\n    !defined(JSON_HEDLEY_TI_ARMCL_VERSION) && \\\n    
!defined(JSON_HEDLEY_TI_CL430_VERSION) && \\\n    !defined(JSON_HEDLEY_TI_CL2000_VERSION) && \\\n    !defined(JSON_HEDLEY_TI_CL6X_VERSION) && \\\n    !defined(JSON_HEDLEY_TI_CL7X_VERSION) && \\\n    !defined(JSON_HEDLEY_TI_CLPRU_VERSION) && \\\n    !defined(__COMPCERT__) && \\\n    !defined(JSON_HEDLEY_MCST_LCC_VERSION)\n    #define JSON_HEDLEY_GCC_VERSION JSON_HEDLEY_GNUC_VERSION\n#endif\n\n#if defined(JSON_HEDLEY_GCC_VERSION_CHECK)\n    #undef JSON_HEDLEY_GCC_VERSION_CHECK\n#endif\n#if defined(JSON_HEDLEY_GCC_VERSION)\n    #define JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_GCC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))\n#else\n    #define JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) (0)\n#endif\n\n#if defined(JSON_HEDLEY_HAS_ATTRIBUTE)\n    #undef JSON_HEDLEY_HAS_ATTRIBUTE\n#endif\n#if \\\n  defined(__has_attribute) && \\\n  ( \\\n    (!defined(JSON_HEDLEY_IAR_VERSION) || JSON_HEDLEY_IAR_VERSION_CHECK(8,5,9)) \\\n  )\n#  define JSON_HEDLEY_HAS_ATTRIBUTE(attribute) __has_attribute(attribute)\n#else\n#  define JSON_HEDLEY_HAS_ATTRIBUTE(attribute) (0)\n#endif\n\n#if defined(JSON_HEDLEY_GNUC_HAS_ATTRIBUTE)\n    #undef JSON_HEDLEY_GNUC_HAS_ATTRIBUTE\n#endif\n#if defined(__has_attribute)\n    #define JSON_HEDLEY_GNUC_HAS_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_HAS_ATTRIBUTE(attribute)\n#else\n    #define JSON_HEDLEY_GNUC_HAS_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch)\n#endif\n\n#if defined(JSON_HEDLEY_GCC_HAS_ATTRIBUTE)\n    #undef JSON_HEDLEY_GCC_HAS_ATTRIBUTE\n#endif\n#if defined(__has_attribute)\n    #define JSON_HEDLEY_GCC_HAS_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_HAS_ATTRIBUTE(attribute)\n#else\n    #define JSON_HEDLEY_GCC_HAS_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch)\n#endif\n\n#if defined(JSON_HEDLEY_HAS_CPP_ATTRIBUTE)\n    #undef JSON_HEDLEY_HAS_CPP_ATTRIBUTE\n#endif\n#if \\\n    defined(__has_cpp_attribute) 
&& \\\n    defined(__cplusplus) && \\\n    (!defined(JSON_HEDLEY_SUNPRO_VERSION) || JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,15,0))\n    #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE(attribute) __has_cpp_attribute(attribute)\n#else\n    #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE(attribute) (0)\n#endif\n\n#if defined(JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS)\n    #undef JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS\n#endif\n#if !defined(__cplusplus) || !defined(__has_cpp_attribute)\n    #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS(ns,attribute) (0)\n#elif \\\n    !defined(JSON_HEDLEY_PGI_VERSION) && \\\n    !defined(JSON_HEDLEY_IAR_VERSION) && \\\n    (!defined(JSON_HEDLEY_SUNPRO_VERSION) || JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,15,0)) && \\\n    (!defined(JSON_HEDLEY_MSVC_VERSION) || JSON_HEDLEY_MSVC_VERSION_CHECK(19,20,0))\n    #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS(ns,attribute) JSON_HEDLEY_HAS_CPP_ATTRIBUTE(ns::attribute)\n#else\n    #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS(ns,attribute) (0)\n#endif\n\n#if defined(JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE)\n    #undef JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE\n#endif\n#if defined(__has_cpp_attribute) && defined(__cplusplus)\n    #define JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE(attribute,major,minor,patch) __has_cpp_attribute(attribute)\n#else\n    #define JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch)\n#endif\n\n#if defined(JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE)\n    #undef JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE\n#endif\n#if defined(__has_cpp_attribute) && defined(__cplusplus)\n    #define JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE(attribute,major,minor,patch) __has_cpp_attribute(attribute)\n#else\n    #define JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch)\n#endif\n\n#if defined(JSON_HEDLEY_HAS_BUILTIN)\n    #undef JSON_HEDLEY_HAS_BUILTIN\n#endif\n#if defined(__has_builtin)\n    #define JSON_HEDLEY_HAS_BUILTIN(builtin) __has_builtin(builtin)\n#else\n    #define 
JSON_HEDLEY_HAS_BUILTIN(builtin) (0)\n#endif\n\n#if defined(JSON_HEDLEY_GNUC_HAS_BUILTIN)\n    #undef JSON_HEDLEY_GNUC_HAS_BUILTIN\n#endif\n#if defined(__has_builtin)\n    #define JSON_HEDLEY_GNUC_HAS_BUILTIN(builtin,major,minor,patch) __has_builtin(builtin)\n#else\n    #define JSON_HEDLEY_GNUC_HAS_BUILTIN(builtin,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch)\n#endif\n\n#if defined(JSON_HEDLEY_GCC_HAS_BUILTIN)\n    #undef JSON_HEDLEY_GCC_HAS_BUILTIN\n#endif\n#if defined(__has_builtin)\n    #define JSON_HEDLEY_GCC_HAS_BUILTIN(builtin,major,minor,patch) __has_builtin(builtin)\n#else\n    #define JSON_HEDLEY_GCC_HAS_BUILTIN(builtin,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch)\n#endif\n\n#if defined(JSON_HEDLEY_HAS_FEATURE)\n    #undef JSON_HEDLEY_HAS_FEATURE\n#endif\n#if defined(__has_feature)\n    #define JSON_HEDLEY_HAS_FEATURE(feature) __has_feature(feature)\n#else\n    #define JSON_HEDLEY_HAS_FEATURE(feature) (0)\n#endif\n\n#if defined(JSON_HEDLEY_GNUC_HAS_FEATURE)\n    #undef JSON_HEDLEY_GNUC_HAS_FEATURE\n#endif\n#if defined(__has_feature)\n    #define JSON_HEDLEY_GNUC_HAS_FEATURE(feature,major,minor,patch) __has_feature(feature)\n#else\n    #define JSON_HEDLEY_GNUC_HAS_FEATURE(feature,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch)\n#endif\n\n#if defined(JSON_HEDLEY_GCC_HAS_FEATURE)\n    #undef JSON_HEDLEY_GCC_HAS_FEATURE\n#endif\n#if defined(__has_feature)\n    #define JSON_HEDLEY_GCC_HAS_FEATURE(feature,major,minor,patch) __has_feature(feature)\n#else\n    #define JSON_HEDLEY_GCC_HAS_FEATURE(feature,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch)\n#endif\n\n#if defined(JSON_HEDLEY_HAS_EXTENSION)\n    #undef JSON_HEDLEY_HAS_EXTENSION\n#endif\n#if defined(__has_extension)\n    #define JSON_HEDLEY_HAS_EXTENSION(extension) __has_extension(extension)\n#else\n    #define JSON_HEDLEY_HAS_EXTENSION(extension) (0)\n#endif\n\n#if defined(JSON_HEDLEY_GNUC_HAS_EXTENSION)\n    #undef 
JSON_HEDLEY_GNUC_HAS_EXTENSION\n#endif\n#if defined(__has_extension)\n    #define JSON_HEDLEY_GNUC_HAS_EXTENSION(extension,major,minor,patch) __has_extension(extension)\n#else\n    #define JSON_HEDLEY_GNUC_HAS_EXTENSION(extension,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch)\n#endif\n\n#if defined(JSON_HEDLEY_GCC_HAS_EXTENSION)\n    #undef JSON_HEDLEY_GCC_HAS_EXTENSION\n#endif\n#if defined(__has_extension)\n    #define JSON_HEDLEY_GCC_HAS_EXTENSION(extension,major,minor,patch) __has_extension(extension)\n#else\n    #define JSON_HEDLEY_GCC_HAS_EXTENSION(extension,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch)\n#endif\n\n#if defined(JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE)\n    #undef JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE\n#endif\n#if defined(__has_declspec_attribute)\n    #define JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE(attribute) __has_declspec_attribute(attribute)\n#else\n    #define JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE(attribute) (0)\n#endif\n\n#if defined(JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE)\n    #undef JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE\n#endif\n#if defined(__has_declspec_attribute)\n    #define JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE(attribute,major,minor,patch) __has_declspec_attribute(attribute)\n#else\n    #define JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch)\n#endif\n\n#if defined(JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE)\n    #undef JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE\n#endif\n#if defined(__has_declspec_attribute)\n    #define JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE(attribute,major,minor,patch) __has_declspec_attribute(attribute)\n#else\n    #define JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch)\n#endif\n\n#if defined(JSON_HEDLEY_HAS_WARNING)\n    #undef JSON_HEDLEY_HAS_WARNING\n#endif\n#if defined(__has_warning)\n    #define JSON_HEDLEY_HAS_WARNING(warning) 
__has_warning(warning)\n#else\n    #define JSON_HEDLEY_HAS_WARNING(warning) (0)\n#endif\n\n#if defined(JSON_HEDLEY_GNUC_HAS_WARNING)\n    #undef JSON_HEDLEY_GNUC_HAS_WARNING\n#endif\n#if defined(__has_warning)\n    #define JSON_HEDLEY_GNUC_HAS_WARNING(warning,major,minor,patch) __has_warning(warning)\n#else\n    #define JSON_HEDLEY_GNUC_HAS_WARNING(warning,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch)\n#endif\n\n#if defined(JSON_HEDLEY_GCC_HAS_WARNING)\n    #undef JSON_HEDLEY_GCC_HAS_WARNING\n#endif\n#if defined(__has_warning)\n    #define JSON_HEDLEY_GCC_HAS_WARNING(warning,major,minor,patch) __has_warning(warning)\n#else\n    #define JSON_HEDLEY_GCC_HAS_WARNING(warning,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch)\n#endif\n\n#if \\\n    (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L)) || \\\n    defined(__clang__) || \\\n    JSON_HEDLEY_GCC_VERSION_CHECK(3,0,0) || \\\n    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \\\n    JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) || \\\n    JSON_HEDLEY_PGI_VERSION_CHECK(18,4,0) || \\\n    JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \\\n    JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \\\n    JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,7,0) || \\\n    JSON_HEDLEY_TI_CL430_VERSION_CHECK(2,0,1) || \\\n    JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,1,0) || \\\n    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,0,0) || \\\n    JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \\\n    JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \\\n    JSON_HEDLEY_CRAY_VERSION_CHECK(5,0,0) || \\\n    JSON_HEDLEY_TINYC_VERSION_CHECK(0,9,17) || \\\n    JSON_HEDLEY_SUNPRO_VERSION_CHECK(8,0,0) || \\\n    (JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) && defined(__C99_PRAGMA_OPERATOR))\n    #define JSON_HEDLEY_PRAGMA(value) _Pragma(#value)\n#elif JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0)\n    #define JSON_HEDLEY_PRAGMA(value) __pragma(value)\n#else\n    #define JSON_HEDLEY_PRAGMA(value)\n#endif\n\n#if defined(JSON_HEDLEY_DIAGNOSTIC_PUSH)\n    #undef 
JSON_HEDLEY_DIAGNOSTIC_PUSH\n#endif\n#if defined(JSON_HEDLEY_DIAGNOSTIC_POP)\n    #undef JSON_HEDLEY_DIAGNOSTIC_POP\n#endif\n#if defined(__clang__)\n    #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma(\"clang diagnostic push\")\n    #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma(\"clang diagnostic pop\")\n#elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma(\"warning(push)\")\n    #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma(\"warning(pop)\")\n#elif JSON_HEDLEY_GCC_VERSION_CHECK(4,6,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma(\"GCC diagnostic push\")\n    #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma(\"GCC diagnostic pop\")\n#elif \\\n    JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) || \\\n    JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_PUSH __pragma(warning(push))\n    #define JSON_HEDLEY_DIAGNOSTIC_POP __pragma(warning(pop))\n#elif JSON_HEDLEY_ARM_VERSION_CHECK(5,6,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma(\"push\")\n    #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma(\"pop\")\n#elif \\\n    JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \\\n    JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \\\n    JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,4,0) || \\\n    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,1,0) || \\\n    JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \\\n    JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma(\"diag_push\")\n    #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma(\"diag_pop\")\n#elif JSON_HEDLEY_PELLES_VERSION_CHECK(2,90,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma(\"warning(push)\")\n    #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma(\"warning(pop)\")\n#else\n    #define JSON_HEDLEY_DIAGNOSTIC_PUSH\n    #define JSON_HEDLEY_DIAGNOSTIC_POP\n#endif\n\n/* JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_ is for\n   HEDLEY INTERNAL USE ONLY.  API subject to change without notice. 
*/\n#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_)\n    #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_\n#endif\n#if defined(__cplusplus)\n#  if JSON_HEDLEY_HAS_WARNING(\"-Wc++98-compat\")\n#    if JSON_HEDLEY_HAS_WARNING(\"-Wc++17-extensions\")\n#      if JSON_HEDLEY_HAS_WARNING(\"-Wc++1z-extensions\")\n#        define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(xpr) \\\n    JSON_HEDLEY_DIAGNOSTIC_PUSH \\\n    _Pragma(\"clang diagnostic ignored \\\"-Wc++98-compat\\\"\") \\\n    _Pragma(\"clang diagnostic ignored \\\"-Wc++17-extensions\\\"\") \\\n    _Pragma(\"clang diagnostic ignored \\\"-Wc++1z-extensions\\\"\") \\\n    xpr \\\n    JSON_HEDLEY_DIAGNOSTIC_POP\n#      else\n#        define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(xpr) \\\n    JSON_HEDLEY_DIAGNOSTIC_PUSH \\\n    _Pragma(\"clang diagnostic ignored \\\"-Wc++98-compat\\\"\") \\\n    _Pragma(\"clang diagnostic ignored \\\"-Wc++17-extensions\\\"\") \\\n    xpr \\\n    JSON_HEDLEY_DIAGNOSTIC_POP\n#      endif\n#    else\n#      define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(xpr) \\\n    JSON_HEDLEY_DIAGNOSTIC_PUSH \\\n    _Pragma(\"clang diagnostic ignored \\\"-Wc++98-compat\\\"\") \\\n    xpr \\\n    JSON_HEDLEY_DIAGNOSTIC_POP\n#    endif\n#  endif\n#endif\n#if !defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(x) x\n#endif\n\n#if defined(JSON_HEDLEY_CONST_CAST)\n    #undef JSON_HEDLEY_CONST_CAST\n#endif\n#if defined(__cplusplus)\n#  define JSON_HEDLEY_CONST_CAST(T, expr) (const_cast<T>(expr))\n#elif \\\n  JSON_HEDLEY_HAS_WARNING(\"-Wcast-qual\") || \\\n  JSON_HEDLEY_GCC_VERSION_CHECK(4,6,0) || \\\n  JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0)\n#  define JSON_HEDLEY_CONST_CAST(T, expr) (__extension__ ({ \\\n        JSON_HEDLEY_DIAGNOSTIC_PUSH \\\n        JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL \\\n        ((T) (expr)); \\\n        JSON_HEDLEY_DIAGNOSTIC_POP \\\n    }))\n#else\n#  define 
JSON_HEDLEY_CONST_CAST(T, expr) ((T) (expr))\n#endif\n\n#if defined(JSON_HEDLEY_REINTERPRET_CAST)\n    #undef JSON_HEDLEY_REINTERPRET_CAST\n#endif\n#if defined(__cplusplus)\n    #define JSON_HEDLEY_REINTERPRET_CAST(T, expr) (reinterpret_cast<T>(expr))\n#else\n    #define JSON_HEDLEY_REINTERPRET_CAST(T, expr) ((T) (expr))\n#endif\n\n#if defined(JSON_HEDLEY_STATIC_CAST)\n    #undef JSON_HEDLEY_STATIC_CAST\n#endif\n#if defined(__cplusplus)\n    #define JSON_HEDLEY_STATIC_CAST(T, expr) (static_cast<T>(expr))\n#else\n    #define JSON_HEDLEY_STATIC_CAST(T, expr) ((T) (expr))\n#endif\n\n#if defined(JSON_HEDLEY_CPP_CAST)\n    #undef JSON_HEDLEY_CPP_CAST\n#endif\n#if defined(__cplusplus)\n#  if JSON_HEDLEY_HAS_WARNING(\"-Wold-style-cast\")\n#    define JSON_HEDLEY_CPP_CAST(T, expr) \\\n    JSON_HEDLEY_DIAGNOSTIC_PUSH \\\n    _Pragma(\"clang diagnostic ignored \\\"-Wold-style-cast\\\"\") \\\n    ((T) (expr)) \\\n    JSON_HEDLEY_DIAGNOSTIC_POP\n#  elif JSON_HEDLEY_IAR_VERSION_CHECK(8,3,0)\n#    define JSON_HEDLEY_CPP_CAST(T, expr) \\\n    JSON_HEDLEY_DIAGNOSTIC_PUSH \\\n    _Pragma(\"diag_suppress=Pe137\") \\\n    ((T) (expr)) \\\n    JSON_HEDLEY_DIAGNOSTIC_POP\n#  else\n#    define JSON_HEDLEY_CPP_CAST(T, expr) ((T) (expr))\n#  endif\n#else\n#  define JSON_HEDLEY_CPP_CAST(T, expr) (expr)\n#endif\n\n#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED)\n    #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED\n#endif\n#if JSON_HEDLEY_HAS_WARNING(\"-Wdeprecated-declarations\")\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma(\"clang diagnostic ignored \\\"-Wdeprecated-declarations\\\"\")\n#elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma(\"warning(disable:1478 1786)\")\n#elif JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED __pragma(warning(disable:1478 1786))\n#elif JSON_HEDLEY_PGI_VERSION_CHECK(20,7,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED 
_Pragma(\"diag_suppress 1215,1216,1444,1445\")\n#elif JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma(\"diag_suppress 1215,1444\")\n#elif JSON_HEDLEY_GCC_VERSION_CHECK(4,3,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma(\"GCC diagnostic ignored \\\"-Wdeprecated-declarations\\\"\")\n#elif JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED __pragma(warning(disable:4996))\n#elif JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma(\"diag_suppress 1215,1444\")\n#elif \\\n    JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \\\n    (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \\\n    (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \\\n    (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \\\n    (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \\\n    JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \\\n    JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma(\"diag_suppress 1291,1718\")\n#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,13,0) && !defined(__cplusplus)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma(\"error_messages(off,E_DEPRECATED_ATT,E_DEPRECATED_ATT_MESS)\")\n#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,13,0) && defined(__cplusplus)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma(\"error_messages(off,symdeprecated,symdeprecated2)\")\n#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED 
_Pragma(\"diag_suppress=Pe1444,Pe1215\")\n#elif JSON_HEDLEY_PELLES_VERSION_CHECK(2,90,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma(\"warn(disable:2241)\")\n#else\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED\n#endif\n\n#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS)\n    #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS\n#endif\n#if JSON_HEDLEY_HAS_WARNING(\"-Wunknown-pragmas\")\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma(\"clang diagnostic ignored \\\"-Wunknown-pragmas\\\"\")\n#elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma(\"warning(disable:161)\")\n#elif JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS __pragma(warning(disable:161))\n#elif JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma(\"diag_suppress 1675\")\n#elif JSON_HEDLEY_GCC_VERSION_CHECK(4,3,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma(\"GCC diagnostic ignored \\\"-Wunknown-pragmas\\\"\")\n#elif JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS __pragma(warning(disable:4068))\n#elif \\\n    JSON_HEDLEY_TI_VERSION_CHECK(16,9,0) || \\\n    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,0,0) || \\\n    JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \\\n    JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,3,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma(\"diag_suppress 163\")\n#elif JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,0,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma(\"diag_suppress 163\")\n#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma(\"diag_suppress=Pe161\")\n#elif JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma(\"diag_suppress 
161\")\n#else\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS\n#endif\n\n#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES)\n    #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES\n#endif\n#if JSON_HEDLEY_HAS_WARNING(\"-Wunknown-attributes\")\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma(\"clang diagnostic ignored \\\"-Wunknown-attributes\\\"\")\n#elif JSON_HEDLEY_GCC_VERSION_CHECK(4,6,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma(\"GCC diagnostic ignored \\\"-Wdeprecated-declarations\\\"\")\n#elif JSON_HEDLEY_INTEL_VERSION_CHECK(17,0,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma(\"warning(disable:1292)\")\n#elif JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES __pragma(warning(disable:1292))\n#elif JSON_HEDLEY_MSVC_VERSION_CHECK(19,0,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES __pragma(warning(disable:5030))\n#elif JSON_HEDLEY_PGI_VERSION_CHECK(20,7,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma(\"diag_suppress 1097,1098\")\n#elif JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma(\"diag_suppress 1097\")\n#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,14,0) && defined(__cplusplus)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma(\"error_messages(off,attrskipunsup)\")\n#elif \\\n    JSON_HEDLEY_TI_VERSION_CHECK(18,1,0) || \\\n    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,3,0) || \\\n    JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma(\"diag_suppress 1173\")\n#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma(\"diag_suppress=Pe1097\")\n#elif JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)\n    #define 
JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma(\"diag_suppress 1097\")\n#else\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES\n#endif\n\n#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL)\n    #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL\n#endif\n#if JSON_HEDLEY_HAS_WARNING(\"-Wcast-qual\")\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL _Pragma(\"clang diagnostic ignored \\\"-Wcast-qual\\\"\")\n#elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL _Pragma(\"warning(disable:2203 2331)\")\n#elif JSON_HEDLEY_GCC_VERSION_CHECK(3,0,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL _Pragma(\"GCC diagnostic ignored \\\"-Wcast-qual\\\"\")\n#else\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL\n#endif\n\n#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION)\n    #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION\n#endif\n#if JSON_HEDLEY_HAS_WARNING(\"-Wunused-function\")\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION _Pragma(\"clang diagnostic ignored \\\"-Wunused-function\\\"\")\n#elif JSON_HEDLEY_GCC_VERSION_CHECK(3,4,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION _Pragma(\"GCC diagnostic ignored \\\"-Wunused-function\\\"\")\n#elif JSON_HEDLEY_MSVC_VERSION_CHECK(1,0,0)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION __pragma(warning(disable:4505))\n#elif JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION _Pragma(\"diag_suppress 3142\")\n#else\n    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION\n#endif\n\n#if defined(JSON_HEDLEY_DEPRECATED)\n    #undef JSON_HEDLEY_DEPRECATED\n#endif\n#if defined(JSON_HEDLEY_DEPRECATED_FOR)\n    #undef JSON_HEDLEY_DEPRECATED_FOR\n#endif\n#if \\\n    JSON_HEDLEY_MSVC_VERSION_CHECK(14,0,0) || \\\n    JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0)\n    #define JSON_HEDLEY_DEPRECATED(since) __declspec(deprecated(\"Since \" # 
since))\n    #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) __declspec(deprecated(\"Since \" #since \"; use \" #replacement))\n#elif \\\n    (JSON_HEDLEY_HAS_EXTENSION(attribute_deprecated_with_message) && !defined(JSON_HEDLEY_IAR_VERSION)) || \\\n    JSON_HEDLEY_GCC_VERSION_CHECK(4,5,0) || \\\n    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \\\n    JSON_HEDLEY_ARM_VERSION_CHECK(5,6,0) || \\\n    JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,13,0) || \\\n    JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) || \\\n    JSON_HEDLEY_TI_VERSION_CHECK(18,1,0) || \\\n    JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(18,1,0) || \\\n    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,3,0) || \\\n    JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \\\n    JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,3,0) || \\\n    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)\n    #define JSON_HEDLEY_DEPRECATED(since) __attribute__((__deprecated__(\"Since \" #since)))\n    #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) __attribute__((__deprecated__(\"Since \" #since \"; use \" #replacement)))\n#elif defined(__cplusplus) && (__cplusplus >= 201402L)\n    #define JSON_HEDLEY_DEPRECATED(since) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[deprecated(\"Since \" #since)]])\n    #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[deprecated(\"Since \" #since \"; use \" #replacement)]])\n#elif \\\n    JSON_HEDLEY_HAS_ATTRIBUTE(deprecated) || \\\n    JSON_HEDLEY_GCC_VERSION_CHECK(3,1,0) || \\\n    JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \\\n    JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \\\n    (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \\\n    (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \\\n    (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    
JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \\\n    (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \\\n    JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \\\n    JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \\\n    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) || \\\n    JSON_HEDLEY_IAR_VERSION_CHECK(8,10,0)\n    #define JSON_HEDLEY_DEPRECATED(since) __attribute__((__deprecated__))\n    #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) __attribute__((__deprecated__))\n#elif \\\n    JSON_HEDLEY_MSVC_VERSION_CHECK(13,10,0) || \\\n    JSON_HEDLEY_PELLES_VERSION_CHECK(6,50,0) || \\\n    JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0)\n    #define JSON_HEDLEY_DEPRECATED(since) __declspec(deprecated)\n    #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) __declspec(deprecated)\n#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0)\n    #define JSON_HEDLEY_DEPRECATED(since) _Pragma(\"deprecated\")\n    #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) _Pragma(\"deprecated\")\n#else\n    #define JSON_HEDLEY_DEPRECATED(since)\n    #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement)\n#endif\n\n#if defined(JSON_HEDLEY_UNAVAILABLE)\n    #undef JSON_HEDLEY_UNAVAILABLE\n#endif\n#if \\\n    JSON_HEDLEY_HAS_ATTRIBUTE(warning) || \\\n    JSON_HEDLEY_GCC_VERSION_CHECK(4,3,0) || \\\n    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \\\n    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)\n    #define JSON_HEDLEY_UNAVAILABLE(available_since) __attribute__((__warning__(\"Not available until \" #available_since)))\n#else\n    #define JSON_HEDLEY_UNAVAILABLE(available_since)\n#endif\n\n#if defined(JSON_HEDLEY_WARN_UNUSED_RESULT)\n    #undef JSON_HEDLEY_WARN_UNUSED_RESULT\n#endif\n#if defined(JSON_HEDLEY_WARN_UNUSED_RESULT_MSG)\n    #undef JSON_HEDLEY_WARN_UNUSED_RESULT_MSG\n#endif\n#if \\\n    JSON_HEDLEY_HAS_ATTRIBUTE(warn_unused_result) || \\\n    JSON_HEDLEY_GCC_VERSION_CHECK(3,4,0) || \\\n    
JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \\\n    JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \\\n    (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \\\n    (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \\\n    (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \\\n    (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \\\n    JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \\\n    JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \\\n    (JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,15,0) && defined(__cplusplus)) || \\\n    JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) || \\\n    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)\n    #define JSON_HEDLEY_WARN_UNUSED_RESULT __attribute__((__warn_unused_result__))\n    #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) __attribute__((__warn_unused_result__))\n#elif (JSON_HEDLEY_HAS_CPP_ATTRIBUTE(nodiscard) >= 201907L)\n    #define JSON_HEDLEY_WARN_UNUSED_RESULT JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[nodiscard]])\n    #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[nodiscard(msg)]])\n#elif JSON_HEDLEY_HAS_CPP_ATTRIBUTE(nodiscard)\n    #define JSON_HEDLEY_WARN_UNUSED_RESULT JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[nodiscard]])\n    #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[nodiscard]])\n#elif defined(_Check_return_) /* SAL */\n    #define JSON_HEDLEY_WARN_UNUSED_RESULT _Check_return_\n    #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) _Check_return_\n#else\n    #define JSON_HEDLEY_WARN_UNUSED_RESULT\n    #define 
JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg)\n#endif\n\n#if defined(JSON_HEDLEY_SENTINEL)\n    #undef JSON_HEDLEY_SENTINEL\n#endif\n#if \\\n    JSON_HEDLEY_HAS_ATTRIBUTE(sentinel) || \\\n    JSON_HEDLEY_GCC_VERSION_CHECK(4,0,0) || \\\n    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \\\n    JSON_HEDLEY_ARM_VERSION_CHECK(5,4,0) || \\\n    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)\n    #define JSON_HEDLEY_SENTINEL(position) __attribute__((__sentinel__(position)))\n#else\n    #define JSON_HEDLEY_SENTINEL(position)\n#endif\n\n#if defined(JSON_HEDLEY_NO_RETURN)\n    #undef JSON_HEDLEY_NO_RETURN\n#endif\n#if JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0)\n    #define JSON_HEDLEY_NO_RETURN __noreturn\n#elif \\\n    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \\\n    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)\n    #define JSON_HEDLEY_NO_RETURN __attribute__((__noreturn__))\n#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L\n    #define JSON_HEDLEY_NO_RETURN _Noreturn\n#elif defined(__cplusplus) && (__cplusplus >= 201103L)\n    #define JSON_HEDLEY_NO_RETURN JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[noreturn]])\n#elif \\\n    JSON_HEDLEY_HAS_ATTRIBUTE(noreturn) || \\\n    JSON_HEDLEY_GCC_VERSION_CHECK(3,2,0) || \\\n    JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \\\n    JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \\\n    JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \\\n    JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \\\n    (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \\\n    (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \\\n    (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \\\n    (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    
JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \\\n    JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \\\n    JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \\\n    JSON_HEDLEY_IAR_VERSION_CHECK(8,10,0)\n    #define JSON_HEDLEY_NO_RETURN __attribute__((__noreturn__))\n#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0)\n    #define JSON_HEDLEY_NO_RETURN _Pragma(\"does_not_return\")\n#elif \\\n    JSON_HEDLEY_MSVC_VERSION_CHECK(13,10,0) || \\\n    JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0)\n    #define JSON_HEDLEY_NO_RETURN __declspec(noreturn)\n#elif JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,0,0) && defined(__cplusplus)\n    #define JSON_HEDLEY_NO_RETURN _Pragma(\"FUNC_NEVER_RETURNS;\")\n#elif JSON_HEDLEY_COMPCERT_VERSION_CHECK(3,2,0)\n    #define JSON_HEDLEY_NO_RETURN __attribute((noreturn))\n#elif JSON_HEDLEY_PELLES_VERSION_CHECK(9,0,0)\n    #define JSON_HEDLEY_NO_RETURN __declspec(noreturn)\n#else\n    #define JSON_HEDLEY_NO_RETURN\n#endif\n\n#if defined(JSON_HEDLEY_NO_ESCAPE)\n    #undef JSON_HEDLEY_NO_ESCAPE\n#endif\n#if JSON_HEDLEY_HAS_ATTRIBUTE(noescape)\n    #define JSON_HEDLEY_NO_ESCAPE __attribute__((__noescape__))\n#else\n    #define JSON_HEDLEY_NO_ESCAPE\n#endif\n\n#if defined(JSON_HEDLEY_UNREACHABLE)\n    #undef JSON_HEDLEY_UNREACHABLE\n#endif\n#if defined(JSON_HEDLEY_UNREACHABLE_RETURN)\n    #undef JSON_HEDLEY_UNREACHABLE_RETURN\n#endif\n#if defined(JSON_HEDLEY_ASSUME)\n    #undef JSON_HEDLEY_ASSUME\n#endif\n#if \\\n    JSON_HEDLEY_MSVC_VERSION_CHECK(13,10,0) || \\\n    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \\\n    JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0)\n    #define JSON_HEDLEY_ASSUME(expr) __assume(expr)\n#elif JSON_HEDLEY_HAS_BUILTIN(__builtin_assume)\n    #define JSON_HEDLEY_ASSUME(expr) __builtin_assume(expr)\n#elif \\\n    JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,2,0) || \\\n    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(4,0,0)\n    #if defined(__cplusplus)\n        #define JSON_HEDLEY_ASSUME(expr) std::_nassert(expr)\n    #else\n        #define 
JSON_HEDLEY_ASSUME(expr) _nassert(expr)\n    #endif\n#endif\n#if \\\n    (JSON_HEDLEY_HAS_BUILTIN(__builtin_unreachable) && (!defined(JSON_HEDLEY_ARM_VERSION))) || \\\n    JSON_HEDLEY_GCC_VERSION_CHECK(4,5,0) || \\\n    JSON_HEDLEY_PGI_VERSION_CHECK(18,10,0) || \\\n    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \\\n    JSON_HEDLEY_IBM_VERSION_CHECK(13,1,5) || \\\n    JSON_HEDLEY_CRAY_VERSION_CHECK(10,0,0) || \\\n    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)\n    #define JSON_HEDLEY_UNREACHABLE() __builtin_unreachable()\n#elif defined(JSON_HEDLEY_ASSUME)\n    #define JSON_HEDLEY_UNREACHABLE() JSON_HEDLEY_ASSUME(0)\n#endif\n#if !defined(JSON_HEDLEY_ASSUME)\n    #if defined(JSON_HEDLEY_UNREACHABLE)\n        #define JSON_HEDLEY_ASSUME(expr) JSON_HEDLEY_STATIC_CAST(void, ((expr) ? 1 : (JSON_HEDLEY_UNREACHABLE(), 1)))\n    #else\n        #define JSON_HEDLEY_ASSUME(expr) JSON_HEDLEY_STATIC_CAST(void, expr)\n    #endif\n#endif\n#if defined(JSON_HEDLEY_UNREACHABLE)\n    #if  \\\n        JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,2,0) || \\\n        JSON_HEDLEY_TI_CL6X_VERSION_CHECK(4,0,0)\n        #define JSON_HEDLEY_UNREACHABLE_RETURN(value) return (JSON_HEDLEY_STATIC_CAST(void, JSON_HEDLEY_ASSUME(0)), (value))\n    #else\n        #define JSON_HEDLEY_UNREACHABLE_RETURN(value) JSON_HEDLEY_UNREACHABLE()\n    #endif\n#else\n    #define JSON_HEDLEY_UNREACHABLE_RETURN(value) return (value)\n#endif\n#if !defined(JSON_HEDLEY_UNREACHABLE)\n    #define JSON_HEDLEY_UNREACHABLE() JSON_HEDLEY_ASSUME(0)\n#endif\n\nJSON_HEDLEY_DIAGNOSTIC_PUSH\n#if JSON_HEDLEY_HAS_WARNING(\"-Wpedantic\")\n    #pragma clang diagnostic ignored \"-Wpedantic\"\n#endif\n#if JSON_HEDLEY_HAS_WARNING(\"-Wc++98-compat-pedantic\") && defined(__cplusplus)\n    #pragma clang diagnostic ignored \"-Wc++98-compat-pedantic\"\n#endif\n#if JSON_HEDLEY_GCC_HAS_WARNING(\"-Wvariadic-macros\",4,0,0)\n    #if defined(__clang__)\n        #pragma clang diagnostic ignored \"-Wvariadic-macros\"\n    #elif 
defined(JSON_HEDLEY_GCC_VERSION)\n        #pragma GCC diagnostic ignored \"-Wvariadic-macros\"\n    #endif\n#endif\n#if defined(JSON_HEDLEY_NON_NULL)\n    #undef JSON_HEDLEY_NON_NULL\n#endif\n#if \\\n    JSON_HEDLEY_HAS_ATTRIBUTE(nonnull) || \\\n    JSON_HEDLEY_GCC_VERSION_CHECK(3,3,0) || \\\n    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \\\n    JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0)\n    #define JSON_HEDLEY_NON_NULL(...) __attribute__((__nonnull__(__VA_ARGS__)))\n#else\n    #define JSON_HEDLEY_NON_NULL(...)\n#endif\nJSON_HEDLEY_DIAGNOSTIC_POP\n\n#if defined(JSON_HEDLEY_PRINTF_FORMAT)\n    #undef JSON_HEDLEY_PRINTF_FORMAT\n#endif\n#if defined(__MINGW32__) && JSON_HEDLEY_GCC_HAS_ATTRIBUTE(format,4,4,0) && !defined(__USE_MINGW_ANSI_STDIO)\n    #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) __attribute__((__format__(ms_printf, string_idx, first_to_check)))\n#elif defined(__MINGW32__) && JSON_HEDLEY_GCC_HAS_ATTRIBUTE(format,4,4,0) && defined(__USE_MINGW_ANSI_STDIO)\n    #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) __attribute__((__format__(gnu_printf, string_idx, first_to_check)))\n#elif \\\n    JSON_HEDLEY_HAS_ATTRIBUTE(format) || \\\n    JSON_HEDLEY_GCC_VERSION_CHECK(3,1,0) || \\\n    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \\\n    JSON_HEDLEY_ARM_VERSION_CHECK(5,6,0) || \\\n    JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \\\n    JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \\\n    (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \\\n    (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \\\n    (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \\\n    (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    
JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \\\n    JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \\\n    JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \\\n    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)\n    #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) __attribute__((__format__(__printf__, string_idx, first_to_check)))\n#elif JSON_HEDLEY_PELLES_VERSION_CHECK(6,0,0)\n    #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) __declspec(vaformat(printf,string_idx,first_to_check))\n#else\n    #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check)\n#endif\n\n#if defined(JSON_HEDLEY_CONSTEXPR)\n    #undef JSON_HEDLEY_CONSTEXPR\n#endif\n#if defined(__cplusplus)\n    #if __cplusplus >= 201103L\n        #define JSON_HEDLEY_CONSTEXPR JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(constexpr)\n    #endif\n#endif\n#if !defined(JSON_HEDLEY_CONSTEXPR)\n    #define JSON_HEDLEY_CONSTEXPR\n#endif\n\n#if defined(JSON_HEDLEY_PREDICT)\n    #undef JSON_HEDLEY_PREDICT\n#endif\n#if defined(JSON_HEDLEY_LIKELY)\n    #undef JSON_HEDLEY_LIKELY\n#endif\n#if defined(JSON_HEDLEY_UNLIKELY)\n    #undef JSON_HEDLEY_UNLIKELY\n#endif\n#if defined(JSON_HEDLEY_UNPREDICTABLE)\n    #undef JSON_HEDLEY_UNPREDICTABLE\n#endif\n#if JSON_HEDLEY_HAS_BUILTIN(__builtin_unpredictable)\n    #define JSON_HEDLEY_UNPREDICTABLE(expr) __builtin_unpredictable((expr))\n#endif\n#if \\\n  (JSON_HEDLEY_HAS_BUILTIN(__builtin_expect_with_probability) && !defined(JSON_HEDLEY_PGI_VERSION)) || \\\n  JSON_HEDLEY_GCC_VERSION_CHECK(9,0,0) || \\\n  JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)\n#  define JSON_HEDLEY_PREDICT(expr, value, probability) __builtin_expect_with_probability(  (expr), (value), (probability))\n#  define JSON_HEDLEY_PREDICT_TRUE(expr, probability)   __builtin_expect_with_probability(!!(expr),    1   , (probability))\n#  define JSON_HEDLEY_PREDICT_FALSE(expr, probability)  __builtin_expect_with_probability(!!(expr),    0   , (probability))\n#  define JSON_HEDLEY_LIKELY(expr)           
           __builtin_expect                 (!!(expr),    1                  )\n#  define JSON_HEDLEY_UNLIKELY(expr)                    __builtin_expect                 (!!(expr),    0                  )\n#elif \\\n  (JSON_HEDLEY_HAS_BUILTIN(__builtin_expect) && !defined(JSON_HEDLEY_INTEL_CL_VERSION)) || \\\n  JSON_HEDLEY_GCC_VERSION_CHECK(3,0,0) || \\\n  JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \\\n  (JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,15,0) && defined(__cplusplus)) || \\\n  JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \\\n  JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \\\n  JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \\\n  JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,7,0) || \\\n  JSON_HEDLEY_TI_CL430_VERSION_CHECK(3,1,0) || \\\n  JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,1,0) || \\\n  JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,1,0) || \\\n  JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \\\n  JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \\\n  JSON_HEDLEY_TINYC_VERSION_CHECK(0,9,27) || \\\n  JSON_HEDLEY_CRAY_VERSION_CHECK(8,1,0) || \\\n  JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)\n#  define JSON_HEDLEY_PREDICT(expr, expected, probability) \\\n    (((probability) >= 0.9) ? __builtin_expect((expr), (expected)) : (JSON_HEDLEY_STATIC_CAST(void, expected), (expr)))\n#  define JSON_HEDLEY_PREDICT_TRUE(expr, probability) \\\n    (__extension__ ({ \\\n        double hedley_probability_ = (probability); \\\n        ((hedley_probability_ >= 0.9) ? __builtin_expect(!!(expr), 1) : ((hedley_probability_ <= 0.1) ? __builtin_expect(!!(expr), 0) : !!(expr))); \\\n    }))\n#  define JSON_HEDLEY_PREDICT_FALSE(expr, probability) \\\n    (__extension__ ({ \\\n        double hedley_probability_ = (probability); \\\n        ((hedley_probability_ >= 0.9) ? __builtin_expect(!!(expr), 0) : ((hedley_probability_ <= 0.1) ? 
__builtin_expect(!!(expr), 1) : !!(expr))); \\\n    }))\n#  define JSON_HEDLEY_LIKELY(expr)   __builtin_expect(!!(expr), 1)\n#  define JSON_HEDLEY_UNLIKELY(expr) __builtin_expect(!!(expr), 0)\n#else\n#  define JSON_HEDLEY_PREDICT(expr, expected, probability) (JSON_HEDLEY_STATIC_CAST(void, expected), (expr))\n#  define JSON_HEDLEY_PREDICT_TRUE(expr, probability) (!!(expr))\n#  define JSON_HEDLEY_PREDICT_FALSE(expr, probability) (!!(expr))\n#  define JSON_HEDLEY_LIKELY(expr) (!!(expr))\n#  define JSON_HEDLEY_UNLIKELY(expr) (!!(expr))\n#endif\n#if !defined(JSON_HEDLEY_UNPREDICTABLE)\n    #define JSON_HEDLEY_UNPREDICTABLE(expr) JSON_HEDLEY_PREDICT(expr, 1, 0.5)\n#endif\n\n#if defined(JSON_HEDLEY_MALLOC)\n    #undef JSON_HEDLEY_MALLOC\n#endif\n#if \\\n    JSON_HEDLEY_HAS_ATTRIBUTE(malloc) || \\\n    JSON_HEDLEY_GCC_VERSION_CHECK(3,1,0) || \\\n    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \\\n    JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \\\n    JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \\\n    JSON_HEDLEY_IBM_VERSION_CHECK(12,1,0) || \\\n    JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \\\n    (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \\\n    (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \\\n    (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \\\n    (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \\\n    JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \\\n    JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \\\n    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)\n    #define JSON_HEDLEY_MALLOC __attribute__((__malloc__))\n#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0)\n    #define JSON_HEDLEY_MALLOC 
_Pragma(\"returns_new_memory\")\n#elif \\\n    JSON_HEDLEY_MSVC_VERSION_CHECK(14,0,0) || \\\n    JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0)\n    #define JSON_HEDLEY_MALLOC __declspec(restrict)\n#else\n    #define JSON_HEDLEY_MALLOC\n#endif\n\n#if defined(JSON_HEDLEY_PURE)\n    #undef JSON_HEDLEY_PURE\n#endif\n#if \\\n  JSON_HEDLEY_HAS_ATTRIBUTE(pure) || \\\n  JSON_HEDLEY_GCC_VERSION_CHECK(2,96,0) || \\\n  JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \\\n  JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \\\n  JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \\\n  JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \\\n  JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \\\n  (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n  JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \\\n  (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n  JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \\\n  (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n  JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \\\n  (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n  JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \\\n  JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \\\n  JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \\\n  JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) || \\\n  JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)\n#  define JSON_HEDLEY_PURE __attribute__((__pure__))\n#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0)\n#  define JSON_HEDLEY_PURE _Pragma(\"does_not_write_global_data\")\n#elif defined(__cplusplus) && \\\n    ( \\\n      JSON_HEDLEY_TI_CL430_VERSION_CHECK(2,0,1) || \\\n      JSON_HEDLEY_TI_CL6X_VERSION_CHECK(4,0,0) || \\\n      JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) \\\n    )\n#  define JSON_HEDLEY_PURE _Pragma(\"FUNC_IS_PURE;\")\n#else\n#  define JSON_HEDLEY_PURE\n#endif\n\n#if defined(JSON_HEDLEY_CONST)\n    #undef JSON_HEDLEY_CONST\n#endif\n#if \\\n    
JSON_HEDLEY_HAS_ATTRIBUTE(const) || \\\n    JSON_HEDLEY_GCC_VERSION_CHECK(2,5,0) || \\\n    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \\\n    JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \\\n    JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \\\n    JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \\\n    JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \\\n    (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \\\n    (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \\\n    (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \\\n    (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \\\n    JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \\\n    JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \\\n    JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) || \\\n    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)\n    #define JSON_HEDLEY_CONST __attribute__((__const__))\n#elif \\\n    JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0)\n    #define JSON_HEDLEY_CONST _Pragma(\"no_side_effect\")\n#else\n    #define JSON_HEDLEY_CONST JSON_HEDLEY_PURE\n#endif\n\n#if defined(JSON_HEDLEY_RESTRICT)\n    #undef JSON_HEDLEY_RESTRICT\n#endif\n#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) && !defined(__cplusplus)\n    #define JSON_HEDLEY_RESTRICT restrict\n#elif \\\n    JSON_HEDLEY_GCC_VERSION_CHECK(3,1,0) || \\\n    JSON_HEDLEY_MSVC_VERSION_CHECK(14,0,0) || \\\n    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \\\n    JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) || \\\n    JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \\\n    JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \\\n    JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) || \\\n    JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \\\n  
  JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,2,4) || \\\n    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,1,0) || \\\n    JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \\\n    (JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,14,0) && defined(__cplusplus)) || \\\n    JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) || \\\n    defined(__clang__) || \\\n    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)\n    #define JSON_HEDLEY_RESTRICT __restrict\n#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,3,0) && !defined(__cplusplus)\n    #define JSON_HEDLEY_RESTRICT _Restrict\n#else\n    #define JSON_HEDLEY_RESTRICT\n#endif\n\n#if defined(JSON_HEDLEY_INLINE)\n    #undef JSON_HEDLEY_INLINE\n#endif\n#if \\\n    (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L)) || \\\n    (defined(__cplusplus) && (__cplusplus >= 199711L))\n    #define JSON_HEDLEY_INLINE inline\n#elif \\\n    defined(JSON_HEDLEY_GCC_VERSION) || \\\n    JSON_HEDLEY_ARM_VERSION_CHECK(6,2,0)\n    #define JSON_HEDLEY_INLINE __inline__\n#elif \\\n    JSON_HEDLEY_MSVC_VERSION_CHECK(12,0,0) || \\\n    JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) || \\\n    JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \\\n    JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,1,0) || \\\n    JSON_HEDLEY_TI_CL430_VERSION_CHECK(3,1,0) || \\\n    JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,2,0) || \\\n    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,0,0) || \\\n    JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \\\n    JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \\\n    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)\n    #define JSON_HEDLEY_INLINE __inline\n#else\n    #define JSON_HEDLEY_INLINE\n#endif\n\n#if defined(JSON_HEDLEY_ALWAYS_INLINE)\n    #undef JSON_HEDLEY_ALWAYS_INLINE\n#endif\n#if \\\n  JSON_HEDLEY_HAS_ATTRIBUTE(always_inline) || \\\n  JSON_HEDLEY_GCC_VERSION_CHECK(4,0,0) || \\\n  JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \\\n  JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \\\n  JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \\\n  JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \\\n  
JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \\\n  (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n  JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \\\n  (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n  JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \\\n  (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n  JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \\\n  (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n  JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \\\n  JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \\\n  JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \\\n  JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) || \\\n  JSON_HEDLEY_IAR_VERSION_CHECK(8,10,0)\n#  define JSON_HEDLEY_ALWAYS_INLINE __attribute__((__always_inline__)) JSON_HEDLEY_INLINE\n#elif \\\n  JSON_HEDLEY_MSVC_VERSION_CHECK(12,0,0) || \\\n  JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0)\n#  define JSON_HEDLEY_ALWAYS_INLINE __forceinline\n#elif defined(__cplusplus) && \\\n    ( \\\n      JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \\\n      JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \\\n      JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \\\n      JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,1,0) || \\\n      JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \\\n      JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) \\\n    )\n#  define JSON_HEDLEY_ALWAYS_INLINE _Pragma(\"FUNC_ALWAYS_INLINE;\")\n#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0)\n#  define JSON_HEDLEY_ALWAYS_INLINE _Pragma(\"inline=forced\")\n#else\n#  define JSON_HEDLEY_ALWAYS_INLINE JSON_HEDLEY_INLINE\n#endif\n\n#if defined(JSON_HEDLEY_NEVER_INLINE)\n    #undef JSON_HEDLEY_NEVER_INLINE\n#endif\n#if \\\n    JSON_HEDLEY_HAS_ATTRIBUTE(noinline) || \\\n    JSON_HEDLEY_GCC_VERSION_CHECK(4,0,0) || \\\n    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \\\n    JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \\\n    
JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \\\n    JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \\\n    JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \\\n    (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \\\n    (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \\\n    (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \\\n    (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \\\n    JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \\\n    JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \\\n    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) || \\\n    JSON_HEDLEY_IAR_VERSION_CHECK(8,10,0)\n    #define JSON_HEDLEY_NEVER_INLINE __attribute__((__noinline__))\n#elif \\\n    JSON_HEDLEY_MSVC_VERSION_CHECK(13,10,0) || \\\n    JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0)\n    #define JSON_HEDLEY_NEVER_INLINE __declspec(noinline)\n#elif JSON_HEDLEY_PGI_VERSION_CHECK(10,2,0)\n    #define JSON_HEDLEY_NEVER_INLINE _Pragma(\"noinline\")\n#elif JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,0,0) && defined(__cplusplus)\n    #define JSON_HEDLEY_NEVER_INLINE _Pragma(\"FUNC_CANNOT_INLINE;\")\n#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0)\n    #define JSON_HEDLEY_NEVER_INLINE _Pragma(\"inline=never\")\n#elif JSON_HEDLEY_COMPCERT_VERSION_CHECK(3,2,0)\n    #define JSON_HEDLEY_NEVER_INLINE __attribute((noinline))\n#elif JSON_HEDLEY_PELLES_VERSION_CHECK(9,0,0)\n    #define JSON_HEDLEY_NEVER_INLINE __declspec(noinline)\n#else\n    #define JSON_HEDLEY_NEVER_INLINE\n#endif\n\n#if defined(JSON_HEDLEY_PRIVATE)\n    #undef JSON_HEDLEY_PRIVATE\n#endif\n#if defined(JSON_HEDLEY_PUBLIC)\n    #undef JSON_HEDLEY_PUBLIC\n#endif\n#if defined(JSON_HEDLEY_IMPORT)\n    #undef 
JSON_HEDLEY_IMPORT\n#endif\n#if defined(_WIN32) || defined(__CYGWIN__)\n#  define JSON_HEDLEY_PRIVATE\n#  define JSON_HEDLEY_PUBLIC   __declspec(dllexport)\n#  define JSON_HEDLEY_IMPORT   __declspec(dllimport)\n#else\n#  if \\\n    JSON_HEDLEY_HAS_ATTRIBUTE(visibility) || \\\n    JSON_HEDLEY_GCC_VERSION_CHECK(3,3,0) || \\\n    JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \\\n    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \\\n    JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \\\n    JSON_HEDLEY_IBM_VERSION_CHECK(13,1,0) || \\\n    ( \\\n      defined(__TI_EABI__) && \\\n      ( \\\n        (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \\\n        JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) \\\n      ) \\\n    ) || \\\n    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)\n#    define JSON_HEDLEY_PRIVATE __attribute__((__visibility__(\"hidden\")))\n#    define JSON_HEDLEY_PUBLIC  __attribute__((__visibility__(\"default\")))\n#  else\n#    define JSON_HEDLEY_PRIVATE\n#    define JSON_HEDLEY_PUBLIC\n#  endif\n#  define JSON_HEDLEY_IMPORT    extern\n#endif\n\n#if defined(JSON_HEDLEY_NO_THROW)\n    #undef JSON_HEDLEY_NO_THROW\n#endif\n#if \\\n    JSON_HEDLEY_HAS_ATTRIBUTE(nothrow) || \\\n    JSON_HEDLEY_GCC_VERSION_CHECK(3,3,0) || \\\n    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \\\n    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)\n    #define JSON_HEDLEY_NO_THROW __attribute__((__nothrow__))\n#elif \\\n    JSON_HEDLEY_MSVC_VERSION_CHECK(13,1,0) || \\\n    JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) || \\\n    JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0)\n    #define JSON_HEDLEY_NO_THROW __declspec(nothrow)\n#else\n    #define JSON_HEDLEY_NO_THROW\n#endif\n\n#if defined(JSON_HEDLEY_FALL_THROUGH)\n    #undef JSON_HEDLEY_FALL_THROUGH\n#endif\n#if \\\n    JSON_HEDLEY_HAS_ATTRIBUTE(fallthrough) || \\\n    JSON_HEDLEY_GCC_VERSION_CHECK(7,0,0) || \\\n    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)\n    #define JSON_HEDLEY_FALL_THROUGH 
__attribute__((__fallthrough__))\n#elif JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS(clang,fallthrough)\n    #define JSON_HEDLEY_FALL_THROUGH JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[clang::fallthrough]])\n#elif JSON_HEDLEY_HAS_CPP_ATTRIBUTE(fallthrough)\n    #define JSON_HEDLEY_FALL_THROUGH JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[fallthrough]])\n#elif defined(__fallthrough) /* SAL */\n    #define JSON_HEDLEY_FALL_THROUGH __fallthrough\n#else\n    #define JSON_HEDLEY_FALL_THROUGH\n#endif\n\n#if defined(JSON_HEDLEY_RETURNS_NON_NULL)\n    #undef JSON_HEDLEY_RETURNS_NON_NULL\n#endif\n#if \\\n    JSON_HEDLEY_HAS_ATTRIBUTE(returns_nonnull) || \\\n    JSON_HEDLEY_GCC_VERSION_CHECK(4,9,0) || \\\n    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)\n    #define JSON_HEDLEY_RETURNS_NON_NULL __attribute__((__returns_nonnull__))\n#elif defined(_Ret_notnull_) /* SAL */\n    #define JSON_HEDLEY_RETURNS_NON_NULL _Ret_notnull_\n#else\n    #define JSON_HEDLEY_RETURNS_NON_NULL\n#endif\n\n#if defined(JSON_HEDLEY_ARRAY_PARAM)\n    #undef JSON_HEDLEY_ARRAY_PARAM\n#endif\n#if \\\n    defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) && \\\n    !defined(__STDC_NO_VLA__) && \\\n    !defined(__cplusplus) && \\\n    !defined(JSON_HEDLEY_PGI_VERSION) && \\\n    !defined(JSON_HEDLEY_TINYC_VERSION)\n    #define JSON_HEDLEY_ARRAY_PARAM(name) (name)\n#else\n    #define JSON_HEDLEY_ARRAY_PARAM(name)\n#endif\n\n#if defined(JSON_HEDLEY_IS_CONSTANT)\n    #undef JSON_HEDLEY_IS_CONSTANT\n#endif\n#if defined(JSON_HEDLEY_REQUIRE_CONSTEXPR)\n    #undef JSON_HEDLEY_REQUIRE_CONSTEXPR\n#endif\n/* JSON_HEDLEY_IS_CONSTEXPR_ is for\n   HEDLEY INTERNAL USE ONLY.  API subject to change without notice. 
*/\n#if defined(JSON_HEDLEY_IS_CONSTEXPR_)\n    #undef JSON_HEDLEY_IS_CONSTEXPR_\n#endif\n#if \\\n    JSON_HEDLEY_HAS_BUILTIN(__builtin_constant_p) || \\\n    JSON_HEDLEY_GCC_VERSION_CHECK(3,4,0) || \\\n    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \\\n    JSON_HEDLEY_TINYC_VERSION_CHECK(0,9,19) || \\\n    JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \\\n    JSON_HEDLEY_IBM_VERSION_CHECK(13,1,0) || \\\n    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,1,0) || \\\n    (JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0) && !defined(__cplusplus)) || \\\n    JSON_HEDLEY_CRAY_VERSION_CHECK(8,1,0) || \\\n    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)\n    #define JSON_HEDLEY_IS_CONSTANT(expr) __builtin_constant_p(expr)\n#endif\n#if !defined(__cplusplus)\n#  if \\\n       JSON_HEDLEY_HAS_BUILTIN(__builtin_types_compatible_p) || \\\n       JSON_HEDLEY_GCC_VERSION_CHECK(3,4,0) || \\\n       JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \\\n       JSON_HEDLEY_IBM_VERSION_CHECK(13,1,0) || \\\n       JSON_HEDLEY_CRAY_VERSION_CHECK(8,1,0) || \\\n       JSON_HEDLEY_ARM_VERSION_CHECK(5,4,0) || \\\n       JSON_HEDLEY_TINYC_VERSION_CHECK(0,9,24)\n#if defined(__INTPTR_TYPE__)\n    #define JSON_HEDLEY_IS_CONSTEXPR_(expr) __builtin_types_compatible_p(__typeof__((1 ? (void*) ((__INTPTR_TYPE__) ((expr) * 0)) : (int*) 0)), int*)\n#else\n    #include <stdint.h>\n    #define JSON_HEDLEY_IS_CONSTEXPR_(expr) __builtin_types_compatible_p(__typeof__((1 ? 
(void*) ((intptr_t) ((expr) * 0)) : (int*) 0)), int*)\n#endif\n#  elif \\\n       ( \\\n          defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L) && \\\n          !defined(JSON_HEDLEY_SUNPRO_VERSION) && \\\n          !defined(JSON_HEDLEY_PGI_VERSION) && \\\n          !defined(JSON_HEDLEY_IAR_VERSION)) || \\\n       (JSON_HEDLEY_HAS_EXTENSION(c_generic_selections) && !defined(JSON_HEDLEY_IAR_VERSION)) || \\\n       JSON_HEDLEY_GCC_VERSION_CHECK(4,9,0) || \\\n       JSON_HEDLEY_INTEL_VERSION_CHECK(17,0,0) || \\\n       JSON_HEDLEY_IBM_VERSION_CHECK(12,1,0) || \\\n       JSON_HEDLEY_ARM_VERSION_CHECK(5,3,0)\n#if defined(__INTPTR_TYPE__)\n    #define JSON_HEDLEY_IS_CONSTEXPR_(expr) _Generic((1 ? (void*) ((__INTPTR_TYPE__) ((expr) * 0)) : (int*) 0), int*: 1, void*: 0)\n#else\n    #include <stdint.h>\n    #define JSON_HEDLEY_IS_CONSTEXPR_(expr) _Generic((1 ? (void*) ((intptr_t) ((expr) * 0)) : (int*) 0), int*: 1, void*: 0)\n#endif\n#  elif \\\n       defined(JSON_HEDLEY_GCC_VERSION) || \\\n       defined(JSON_HEDLEY_INTEL_VERSION) || \\\n       defined(JSON_HEDLEY_TINYC_VERSION) || \\\n       defined(JSON_HEDLEY_TI_ARMCL_VERSION) || \\\n       JSON_HEDLEY_TI_CL430_VERSION_CHECK(18,12,0) || \\\n       defined(JSON_HEDLEY_TI_CL2000_VERSION) || \\\n       defined(JSON_HEDLEY_TI_CL6X_VERSION) || \\\n       defined(JSON_HEDLEY_TI_CL7X_VERSION) || \\\n       defined(JSON_HEDLEY_TI_CLPRU_VERSION) || \\\n       defined(__clang__)\n#    define JSON_HEDLEY_IS_CONSTEXPR_(expr) ( \\\n        sizeof(void) != \\\n        sizeof(*( \\\n                  1 ? 
\\\n                  ((void*) ((expr) * 0L) ) : \\\n((struct { char v[sizeof(void) * 2]; } *) 1) \\\n                ) \\\n              ) \\\n                                            )\n#  endif\n#endif\n#if defined(JSON_HEDLEY_IS_CONSTEXPR_)\n    #if !defined(JSON_HEDLEY_IS_CONSTANT)\n        #define JSON_HEDLEY_IS_CONSTANT(expr) JSON_HEDLEY_IS_CONSTEXPR_(expr)\n    #endif\n    #define JSON_HEDLEY_REQUIRE_CONSTEXPR(expr) (JSON_HEDLEY_IS_CONSTEXPR_(expr) ? (expr) : (-1))\n#else\n    #if !defined(JSON_HEDLEY_IS_CONSTANT)\n        #define JSON_HEDLEY_IS_CONSTANT(expr) (0)\n    #endif\n    #define JSON_HEDLEY_REQUIRE_CONSTEXPR(expr) (expr)\n#endif\n\n#if defined(JSON_HEDLEY_BEGIN_C_DECLS)\n    #undef JSON_HEDLEY_BEGIN_C_DECLS\n#endif\n#if defined(JSON_HEDLEY_END_C_DECLS)\n    #undef JSON_HEDLEY_END_C_DECLS\n#endif\n#if defined(JSON_HEDLEY_C_DECL)\n    #undef JSON_HEDLEY_C_DECL\n#endif\n#if defined(__cplusplus)\n    #define JSON_HEDLEY_BEGIN_C_DECLS extern \"C\" {\n    #define JSON_HEDLEY_END_C_DECLS }\n    #define JSON_HEDLEY_C_DECL extern \"C\"\n#else\n    #define JSON_HEDLEY_BEGIN_C_DECLS\n    #define JSON_HEDLEY_END_C_DECLS\n    #define JSON_HEDLEY_C_DECL\n#endif\n\n#if defined(JSON_HEDLEY_STATIC_ASSERT)\n    #undef JSON_HEDLEY_STATIC_ASSERT\n#endif\n#if \\\n  !defined(__cplusplus) && ( \\\n      (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L)) || \\\n      (JSON_HEDLEY_HAS_FEATURE(c_static_assert) && !defined(JSON_HEDLEY_INTEL_CL_VERSION)) || \\\n      JSON_HEDLEY_GCC_VERSION_CHECK(6,0,0) || \\\n      JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \\\n      defined(_Static_assert) \\\n    )\n#  define JSON_HEDLEY_STATIC_ASSERT(expr, message) _Static_assert(expr, message)\n#elif \\\n  (defined(__cplusplus) && (__cplusplus >= 201103L)) || \\\n  JSON_HEDLEY_MSVC_VERSION_CHECK(16,0,0) || \\\n  JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0)\n#  define JSON_HEDLEY_STATIC_ASSERT(expr, message) 
JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(static_assert(expr, message))\n#else\n#  define JSON_HEDLEY_STATIC_ASSERT(expr, message)\n#endif\n\n#if defined(JSON_HEDLEY_NULL)\n    #undef JSON_HEDLEY_NULL\n#endif\n#if defined(__cplusplus)\n    #if __cplusplus >= 201103L\n        #define JSON_HEDLEY_NULL JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(nullptr)\n    #elif defined(NULL)\n        #define JSON_HEDLEY_NULL NULL\n    #else\n        #define JSON_HEDLEY_NULL JSON_HEDLEY_STATIC_CAST(void*, 0)\n    #endif\n#elif defined(NULL)\n    #define JSON_HEDLEY_NULL NULL\n#else\n    #define JSON_HEDLEY_NULL ((void*) 0)\n#endif\n\n#if defined(JSON_HEDLEY_MESSAGE)\n    #undef JSON_HEDLEY_MESSAGE\n#endif\n#if JSON_HEDLEY_HAS_WARNING(\"-Wunknown-pragmas\")\n#  define JSON_HEDLEY_MESSAGE(msg) \\\n    JSON_HEDLEY_DIAGNOSTIC_PUSH \\\n    JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS \\\n    JSON_HEDLEY_PRAGMA(message msg) \\\n    JSON_HEDLEY_DIAGNOSTIC_POP\n#elif \\\n  JSON_HEDLEY_GCC_VERSION_CHECK(4,4,0) || \\\n  JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0)\n#  define JSON_HEDLEY_MESSAGE(msg) JSON_HEDLEY_PRAGMA(message msg)\n#elif JSON_HEDLEY_CRAY_VERSION_CHECK(5,0,0)\n#  define JSON_HEDLEY_MESSAGE(msg) JSON_HEDLEY_PRAGMA(_CRI message msg)\n#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0)\n#  define JSON_HEDLEY_MESSAGE(msg) JSON_HEDLEY_PRAGMA(message(msg))\n#elif JSON_HEDLEY_PELLES_VERSION_CHECK(2,0,0)\n#  define JSON_HEDLEY_MESSAGE(msg) JSON_HEDLEY_PRAGMA(message(msg))\n#else\n#  define JSON_HEDLEY_MESSAGE(msg)\n#endif\n\n#if defined(JSON_HEDLEY_WARNING)\n    #undef JSON_HEDLEY_WARNING\n#endif\n#if JSON_HEDLEY_HAS_WARNING(\"-Wunknown-pragmas\")\n#  define JSON_HEDLEY_WARNING(msg) \\\n    JSON_HEDLEY_DIAGNOSTIC_PUSH \\\n    JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS \\\n    JSON_HEDLEY_PRAGMA(clang warning msg) \\\n    JSON_HEDLEY_DIAGNOSTIC_POP\n#elif \\\n  JSON_HEDLEY_GCC_VERSION_CHECK(4,8,0) || \\\n  JSON_HEDLEY_PGI_VERSION_CHECK(18,4,0) || \\\n  
JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0)\n#  define JSON_HEDLEY_WARNING(msg) JSON_HEDLEY_PRAGMA(GCC warning msg)\n#elif \\\n  JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) || \\\n  JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0)\n#  define JSON_HEDLEY_WARNING(msg) JSON_HEDLEY_PRAGMA(message(msg))\n#else\n#  define JSON_HEDLEY_WARNING(msg) JSON_HEDLEY_MESSAGE(msg)\n#endif\n\n#if defined(JSON_HEDLEY_REQUIRE)\n    #undef JSON_HEDLEY_REQUIRE\n#endif\n#if defined(JSON_HEDLEY_REQUIRE_MSG)\n    #undef JSON_HEDLEY_REQUIRE_MSG\n#endif\n#if JSON_HEDLEY_HAS_ATTRIBUTE(diagnose_if)\n#  if JSON_HEDLEY_HAS_WARNING(\"-Wgcc-compat\")\n#    define JSON_HEDLEY_REQUIRE(expr) \\\n    JSON_HEDLEY_DIAGNOSTIC_PUSH \\\n    _Pragma(\"clang diagnostic ignored \\\"-Wgcc-compat\\\"\") \\\n    __attribute__((diagnose_if(!(expr), #expr, \"error\"))) \\\n    JSON_HEDLEY_DIAGNOSTIC_POP\n#    define JSON_HEDLEY_REQUIRE_MSG(expr,msg) \\\n    JSON_HEDLEY_DIAGNOSTIC_PUSH \\\n    _Pragma(\"clang diagnostic ignored \\\"-Wgcc-compat\\\"\") \\\n    __attribute__((diagnose_if(!(expr), msg, \"error\"))) \\\n    JSON_HEDLEY_DIAGNOSTIC_POP\n#  else\n#    define JSON_HEDLEY_REQUIRE(expr) __attribute__((diagnose_if(!(expr), #expr, \"error\")))\n#    define JSON_HEDLEY_REQUIRE_MSG(expr,msg) __attribute__((diagnose_if(!(expr), msg, \"error\")))\n#  endif\n#else\n#  define JSON_HEDLEY_REQUIRE(expr)\n#  define JSON_HEDLEY_REQUIRE_MSG(expr,msg)\n#endif\n\n#if defined(JSON_HEDLEY_FLAGS)\n    #undef JSON_HEDLEY_FLAGS\n#endif\n#if JSON_HEDLEY_HAS_ATTRIBUTE(flag_enum) && (!defined(__cplusplus) || JSON_HEDLEY_HAS_WARNING(\"-Wbitfield-enum-conversion\"))\n    #define JSON_HEDLEY_FLAGS __attribute__((__flag_enum__))\n#else\n    #define JSON_HEDLEY_FLAGS\n#endif\n\n#if defined(JSON_HEDLEY_FLAGS_CAST)\n    #undef JSON_HEDLEY_FLAGS_CAST\n#endif\n#if JSON_HEDLEY_INTEL_VERSION_CHECK(19,0,0)\n#  define JSON_HEDLEY_FLAGS_CAST(T, expr) (__extension__ ({ \\\n        JSON_HEDLEY_DIAGNOSTIC_PUSH \\\n        _Pragma(\"warning(disable:188)\") 
\\\n        ((T) (expr)); \\\n        JSON_HEDLEY_DIAGNOSTIC_POP \\\n    }))\n#else\n#  define JSON_HEDLEY_FLAGS_CAST(T, expr) JSON_HEDLEY_STATIC_CAST(T, expr)\n#endif\n\n#if defined(JSON_HEDLEY_EMPTY_BASES)\n    #undef JSON_HEDLEY_EMPTY_BASES\n#endif\n#if \\\n    (JSON_HEDLEY_MSVC_VERSION_CHECK(19,0,23918) && !JSON_HEDLEY_MSVC_VERSION_CHECK(20,0,0)) || \\\n    JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0)\n    #define JSON_HEDLEY_EMPTY_BASES __declspec(empty_bases)\n#else\n    #define JSON_HEDLEY_EMPTY_BASES\n#endif\n\n/* Remaining macros are deprecated. */\n\n#if defined(JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK)\n    #undef JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK\n#endif\n#if defined(__clang__)\n    #define JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK(major,minor,patch) (0)\n#else\n    #define JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK(major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch)\n#endif\n\n#if defined(JSON_HEDLEY_CLANG_HAS_ATTRIBUTE)\n    #undef JSON_HEDLEY_CLANG_HAS_ATTRIBUTE\n#endif\n#define JSON_HEDLEY_CLANG_HAS_ATTRIBUTE(attribute) JSON_HEDLEY_HAS_ATTRIBUTE(attribute)\n\n#if defined(JSON_HEDLEY_CLANG_HAS_CPP_ATTRIBUTE)\n    #undef JSON_HEDLEY_CLANG_HAS_CPP_ATTRIBUTE\n#endif\n#define JSON_HEDLEY_CLANG_HAS_CPP_ATTRIBUTE(attribute) JSON_HEDLEY_HAS_CPP_ATTRIBUTE(attribute)\n\n#if defined(JSON_HEDLEY_CLANG_HAS_BUILTIN)\n    #undef JSON_HEDLEY_CLANG_HAS_BUILTIN\n#endif\n#define JSON_HEDLEY_CLANG_HAS_BUILTIN(builtin) JSON_HEDLEY_HAS_BUILTIN(builtin)\n\n#if defined(JSON_HEDLEY_CLANG_HAS_FEATURE)\n    #undef JSON_HEDLEY_CLANG_HAS_FEATURE\n#endif\n#define JSON_HEDLEY_CLANG_HAS_FEATURE(feature) JSON_HEDLEY_HAS_FEATURE(feature)\n\n#if defined(JSON_HEDLEY_CLANG_HAS_EXTENSION)\n    #undef JSON_HEDLEY_CLANG_HAS_EXTENSION\n#endif\n#define JSON_HEDLEY_CLANG_HAS_EXTENSION(extension) JSON_HEDLEY_HAS_EXTENSION(extension)\n\n#if defined(JSON_HEDLEY_CLANG_HAS_DECLSPEC_ATTRIBUTE)\n    #undef JSON_HEDLEY_CLANG_HAS_DECLSPEC_ATTRIBUTE\n#endif\n#define 
JSON_HEDLEY_CLANG_HAS_DECLSPEC_ATTRIBUTE(attribute) JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE(attribute)\n\n#if defined(JSON_HEDLEY_CLANG_HAS_WARNING)\n    #undef JSON_HEDLEY_CLANG_HAS_WARNING\n#endif\n#define JSON_HEDLEY_CLANG_HAS_WARNING(warning) JSON_HEDLEY_HAS_WARNING(warning)\n\n#endif /* !defined(JSON_HEDLEY_VERSION) || (JSON_HEDLEY_VERSION < X) */\n\n\n// This file contains all internal macro definitions (except those affecting ABI)\n// You MUST include macro_unscope.hpp at the end of json.hpp to undef all of them\n\n// #include <nlohmann/detail/abi_macros.hpp>\n\n\n// exclude unsupported compilers\n#if !defined(JSON_SKIP_UNSUPPORTED_COMPILER_CHECK)\n    #if defined(__clang__)\n        #if (__clang_major__ * 10000 + __clang_minor__ * 100 + __clang_patchlevel__) < 30400\n            #error \"unsupported Clang version - see https://github.com/nlohmann/json#supported-compilers\"\n        #endif\n    #elif defined(__GNUC__) && !(defined(__ICC) || defined(__INTEL_COMPILER))\n        #if (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) < 40800\n            #error \"unsupported GCC version - see https://github.com/nlohmann/json#supported-compilers\"\n        #endif\n    #endif\n#endif\n\n// C++ language standard detection\n// if the user manually specified the used c++ version this is skipped\n#if !defined(JSON_HAS_CPP_20) && !defined(JSON_HAS_CPP_17) && !defined(JSON_HAS_CPP_14) && !defined(JSON_HAS_CPP_11)\n    #if (defined(__cplusplus) && __cplusplus >= 202002L) || (defined(_MSVC_LANG) && _MSVC_LANG >= 202002L)\n        #define JSON_HAS_CPP_20\n        #define JSON_HAS_CPP_17\n        #define JSON_HAS_CPP_14\n    #elif (defined(__cplusplus) && __cplusplus >= 201703L) || (defined(_HAS_CXX17) && _HAS_CXX17 == 1) // fix for issue #464\n        #define JSON_HAS_CPP_17\n        #define JSON_HAS_CPP_14\n    #elif (defined(__cplusplus) && __cplusplus >= 201402L) || (defined(_HAS_CXX14) && _HAS_CXX14 == 1)\n        #define JSON_HAS_CPP_14\n    #endif\n    // the 
cpp 11 flag is always specified because it is the minimal required version\n    #define JSON_HAS_CPP_11\n#endif\n\n#ifdef __has_include\n    #if __has_include(<version>)\n        #include <version>\n    #endif\n#endif\n\n#if !defined(JSON_HAS_FILESYSTEM) && !defined(JSON_HAS_EXPERIMENTAL_FILESYSTEM)\n    #ifdef JSON_HAS_CPP_17\n        #if defined(__cpp_lib_filesystem)\n            #define JSON_HAS_FILESYSTEM 1\n        #elif defined(__cpp_lib_experimental_filesystem)\n            #define JSON_HAS_EXPERIMENTAL_FILESYSTEM 1\n        #elif !defined(__has_include)\n            #define JSON_HAS_EXPERIMENTAL_FILESYSTEM 1\n        #elif __has_include(<filesystem>)\n            #define JSON_HAS_FILESYSTEM 1\n        #elif __has_include(<experimental/filesystem>)\n            #define JSON_HAS_EXPERIMENTAL_FILESYSTEM 1\n        #endif\n\n        // std::filesystem does not work on MinGW GCC 8: https://sourceforge.net/p/mingw-w64/bugs/737/\n        #if defined(__MINGW32__) && defined(__GNUC__) && __GNUC__ == 8\n            #undef JSON_HAS_FILESYSTEM\n            #undef JSON_HAS_EXPERIMENTAL_FILESYSTEM\n        #endif\n\n        // no filesystem support before GCC 8: https://en.cppreference.com/w/cpp/compiler_support\n        #if defined(__GNUC__) && !defined(__clang__) && __GNUC__ < 8\n            #undef JSON_HAS_FILESYSTEM\n            #undef JSON_HAS_EXPERIMENTAL_FILESYSTEM\n        #endif\n\n        // no filesystem support before Clang 7: https://en.cppreference.com/w/cpp/compiler_support\n        #if defined(__clang_major__) && __clang_major__ < 7\n            #undef JSON_HAS_FILESYSTEM\n            #undef JSON_HAS_EXPERIMENTAL_FILESYSTEM\n        #endif\n\n        // no filesystem support before MSVC 19.14: https://en.cppreference.com/w/cpp/compiler_support\n        #if defined(_MSC_VER) && _MSC_VER < 1914\n            #undef JSON_HAS_FILESYSTEM\n            #undef JSON_HAS_EXPERIMENTAL_FILESYSTEM\n        #endif\n\n        // no filesystem support before iOS 13\n      
  #if defined(__IPHONE_OS_VERSION_MIN_REQUIRED) && __IPHONE_OS_VERSION_MIN_REQUIRED < 130000\n            #undef JSON_HAS_FILESYSTEM\n            #undef JSON_HAS_EXPERIMENTAL_FILESYSTEM\n        #endif\n\n        // no filesystem support before macOS Catalina\n        #if defined(__MAC_OS_X_VERSION_MIN_REQUIRED) && __MAC_OS_X_VERSION_MIN_REQUIRED < 101500\n            #undef JSON_HAS_FILESYSTEM\n            #undef JSON_HAS_EXPERIMENTAL_FILESYSTEM\n        #endif\n    #endif\n#endif\n\n#ifndef JSON_HAS_EXPERIMENTAL_FILESYSTEM\n    #define JSON_HAS_EXPERIMENTAL_FILESYSTEM 0\n#endif\n\n#ifndef JSON_HAS_FILESYSTEM\n    #define JSON_HAS_FILESYSTEM 0\n#endif\n\n#ifndef JSON_HAS_THREE_WAY_COMPARISON\n    #if defined(__cpp_impl_three_way_comparison) && __cpp_impl_three_way_comparison >= 201907L \\\n        && defined(__cpp_lib_three_way_comparison) && __cpp_lib_three_way_comparison >= 201907L\n        #define JSON_HAS_THREE_WAY_COMPARISON 1\n    #else\n        #define JSON_HAS_THREE_WAY_COMPARISON 0\n    #endif\n#endif\n\n#ifndef JSON_HAS_RANGES\n    // ranges header shipping in GCC 11.1.0 (released 2021-04-27) has syntax error\n    #if defined(__GLIBCXX__) && __GLIBCXX__ == 20210427\n        #define JSON_HAS_RANGES 0\n    #elif defined(__cpp_lib_ranges)\n        #define JSON_HAS_RANGES 1\n    #else\n        #define JSON_HAS_RANGES 0\n    #endif\n#endif\n\n#ifndef JSON_HAS_STATIC_RTTI\n    #if !defined(_HAS_STATIC_RTTI) || _HAS_STATIC_RTTI != 0\n        #define JSON_HAS_STATIC_RTTI 1\n    #else\n        #define JSON_HAS_STATIC_RTTI 0\n    #endif\n#endif\n\n#ifdef JSON_HAS_CPP_17\n    #define JSON_INLINE_VARIABLE inline\n#else\n    #define JSON_INLINE_VARIABLE\n#endif\n\n#if JSON_HEDLEY_HAS_ATTRIBUTE(no_unique_address)\n    #define JSON_NO_UNIQUE_ADDRESS [[no_unique_address]]\n#else\n    #define JSON_NO_UNIQUE_ADDRESS\n#endif\n\n// disable documentation warnings on clang\n#if defined(__clang__)\n    #pragma clang diagnostic push\n    #pragma clang diagnostic ignored 
\"-Wdocumentation\"\n    #pragma clang diagnostic ignored \"-Wdocumentation-unknown-command\"\n#endif\n\n// allow disabling exceptions\n#if (defined(__cpp_exceptions) || defined(__EXCEPTIONS) || defined(_CPPUNWIND)) && !defined(JSON_NOEXCEPTION)\n    #define JSON_THROW(exception) throw exception\n    #define JSON_TRY try\n    #define JSON_CATCH(exception) catch(exception)\n    #define JSON_INTERNAL_CATCH(exception) catch(exception)\n#else\n    #include <cstdlib>\n    #define JSON_THROW(exception) std::abort()\n    #define JSON_TRY if(true)\n    #define JSON_CATCH(exception) if(false)\n    #define JSON_INTERNAL_CATCH(exception) if(false)\n#endif\n\n// override exception macros\n#if defined(JSON_THROW_USER)\n    #undef JSON_THROW\n    #define JSON_THROW JSON_THROW_USER\n#endif\n#if defined(JSON_TRY_USER)\n    #undef JSON_TRY\n    #define JSON_TRY JSON_TRY_USER\n#endif\n#if defined(JSON_CATCH_USER)\n    #undef JSON_CATCH\n    #define JSON_CATCH JSON_CATCH_USER\n    #undef JSON_INTERNAL_CATCH\n    #define JSON_INTERNAL_CATCH JSON_CATCH_USER\n#endif\n#if defined(JSON_INTERNAL_CATCH_USER)\n    #undef JSON_INTERNAL_CATCH\n    #define JSON_INTERNAL_CATCH JSON_INTERNAL_CATCH_USER\n#endif\n\n// allow overriding assert\n#if !defined(JSON_ASSERT)\n    #include <cassert> // assert\n    #define JSON_ASSERT(x) assert(x)\n#endif\n\n// allow to access some private functions (needed by the test suite)\n#if defined(JSON_TESTS_PRIVATE)\n    #define JSON_PRIVATE_UNLESS_TESTED public\n#else\n    #define JSON_PRIVATE_UNLESS_TESTED private\n#endif\n\n/*!\n@brief macro to briefly define a mapping between an enum and JSON\n@def NLOHMANN_JSON_SERIALIZE_ENUM\n@since version 3.4.0\n*/\n#define NLOHMANN_JSON_SERIALIZE_ENUM(ENUM_TYPE, ...)                                            
\\\n    template<typename BasicJsonType>                                                            \\\n    inline void to_json(BasicJsonType& j, const ENUM_TYPE& e)                                   \\\n    {                                                                                           \\\n        static_assert(std::is_enum<ENUM_TYPE>::value, #ENUM_TYPE \" must be an enum!\");          \\\n        static const std::pair<ENUM_TYPE, BasicJsonType> m[] = __VA_ARGS__;                     \\\n        auto it = std::find_if(std::begin(m), std::end(m),                                      \\\n                               [e](const std::pair<ENUM_TYPE, BasicJsonType>& ej_pair) -> bool  \\\n        {                                                                                       \\\n            return ej_pair.first == e;                                                          \\\n        });                                                                                     \\\n        j = ((it != std::end(m)) ? 
it : std::begin(m))->second;                                 \\\n    }                                                                                           \\\n    template<typename BasicJsonType>                                                            \\\n    inline void from_json(const BasicJsonType& j, ENUM_TYPE& e)                                 \\\n    {                                                                                           \\\n        static_assert(std::is_enum<ENUM_TYPE>::value, #ENUM_TYPE \" must be an enum!\");          \\\n        static const std::pair<ENUM_TYPE, BasicJsonType> m[] = __VA_ARGS__;                     \\\n        auto it = std::find_if(std::begin(m), std::end(m),                                      \\\n                               [&j](const std::pair<ENUM_TYPE, BasicJsonType>& ej_pair) -> bool \\\n        {                                                                                       \\\n            return ej_pair.second == j;                                                         \\\n        });                                                                                     \\\n        e = ((it != std::end(m)) ? it : std::begin(m))->first;                                  \\\n    }\n\n// Ugly macros to avoid uglier copy-paste when specializing basic_json. 
They\n// may be removed in the future once the class is split.\n\n#define NLOHMANN_BASIC_JSON_TPL_DECLARATION                                \\\n    template<template<typename, typename, typename...> class ObjectType,   \\\n             template<typename, typename...> class ArrayType,              \\\n             class StringType, class BooleanType, class NumberIntegerType, \\\n             class NumberUnsignedType, class NumberFloatType,              \\\n             template<typename> class AllocatorType,                       \\\n             template<typename, typename = void> class JSONSerializer,     \\\n             class BinaryType,                                             \\\n             class CustomBaseClass>\n\n#define NLOHMANN_BASIC_JSON_TPL                                            \\\n    basic_json<ObjectType, ArrayType, StringType, BooleanType,             \\\n    NumberIntegerType, NumberUnsignedType, NumberFloatType,                \\\n    AllocatorType, JSONSerializer, BinaryType, CustomBaseClass>\n\n// Macros to simplify conversion from/to types\n\n#define NLOHMANN_JSON_EXPAND( x ) x\n#define NLOHMANN_JSON_GET_MACRO(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, _61, _62, _63, _64, NAME,...) NAME\n#define NLOHMANN_JSON_PASTE(...) 
NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_GET_MACRO(__VA_ARGS__, \\\n        NLOHMANN_JSON_PASTE64, \\\n        NLOHMANN_JSON_PASTE63, \\\n        NLOHMANN_JSON_PASTE62, \\\n        NLOHMANN_JSON_PASTE61, \\\n        NLOHMANN_JSON_PASTE60, \\\n        NLOHMANN_JSON_PASTE59, \\\n        NLOHMANN_JSON_PASTE58, \\\n        NLOHMANN_JSON_PASTE57, \\\n        NLOHMANN_JSON_PASTE56, \\\n        NLOHMANN_JSON_PASTE55, \\\n        NLOHMANN_JSON_PASTE54, \\\n        NLOHMANN_JSON_PASTE53, \\\n        NLOHMANN_JSON_PASTE52, \\\n        NLOHMANN_JSON_PASTE51, \\\n        NLOHMANN_JSON_PASTE50, \\\n        NLOHMANN_JSON_PASTE49, \\\n        NLOHMANN_JSON_PASTE48, \\\n        NLOHMANN_JSON_PASTE47, \\\n        NLOHMANN_JSON_PASTE46, \\\n        NLOHMANN_JSON_PASTE45, \\\n        NLOHMANN_JSON_PASTE44, \\\n        NLOHMANN_JSON_PASTE43, \\\n        NLOHMANN_JSON_PASTE42, \\\n        NLOHMANN_JSON_PASTE41, \\\n        NLOHMANN_JSON_PASTE40, \\\n        NLOHMANN_JSON_PASTE39, \\\n        NLOHMANN_JSON_PASTE38, \\\n        NLOHMANN_JSON_PASTE37, \\\n        NLOHMANN_JSON_PASTE36, \\\n        NLOHMANN_JSON_PASTE35, \\\n        NLOHMANN_JSON_PASTE34, \\\n        NLOHMANN_JSON_PASTE33, \\\n        NLOHMANN_JSON_PASTE32, \\\n        NLOHMANN_JSON_PASTE31, \\\n        NLOHMANN_JSON_PASTE30, \\\n        NLOHMANN_JSON_PASTE29, \\\n        NLOHMANN_JSON_PASTE28, \\\n        NLOHMANN_JSON_PASTE27, \\\n        NLOHMANN_JSON_PASTE26, \\\n        NLOHMANN_JSON_PASTE25, \\\n        NLOHMANN_JSON_PASTE24, \\\n        NLOHMANN_JSON_PASTE23, \\\n        NLOHMANN_JSON_PASTE22, \\\n        NLOHMANN_JSON_PASTE21, \\\n        NLOHMANN_JSON_PASTE20, \\\n        NLOHMANN_JSON_PASTE19, \\\n        NLOHMANN_JSON_PASTE18, \\\n        NLOHMANN_JSON_PASTE17, \\\n        NLOHMANN_JSON_PASTE16, \\\n        NLOHMANN_JSON_PASTE15, \\\n        NLOHMANN_JSON_PASTE14, \\\n        NLOHMANN_JSON_PASTE13, \\\n        NLOHMANN_JSON_PASTE12, \\\n        NLOHMANN_JSON_PASTE11, \\\n        NLOHMANN_JSON_PASTE10, \\\n        
NLOHMANN_JSON_PASTE9, \\\n        NLOHMANN_JSON_PASTE8, \\\n        NLOHMANN_JSON_PASTE7, \\\n        NLOHMANN_JSON_PASTE6, \\\n        NLOHMANN_JSON_PASTE5, \\\n        NLOHMANN_JSON_PASTE4, \\\n        NLOHMANN_JSON_PASTE3, \\\n        NLOHMANN_JSON_PASTE2, \\\n        NLOHMANN_JSON_PASTE1)(__VA_ARGS__))\n#define NLOHMANN_JSON_PASTE2(func, v1) func(v1)\n#define NLOHMANN_JSON_PASTE3(func, v1, v2) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE2(func, v2)\n#define NLOHMANN_JSON_PASTE4(func, v1, v2, v3) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE3(func, v2, v3)\n#define NLOHMANN_JSON_PASTE5(func, v1, v2, v3, v4) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE4(func, v2, v3, v4)\n#define NLOHMANN_JSON_PASTE6(func, v1, v2, v3, v4, v5) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE5(func, v2, v3, v4, v5)\n#define NLOHMANN_JSON_PASTE7(func, v1, v2, v3, v4, v5, v6) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE6(func, v2, v3, v4, v5, v6)\n#define NLOHMANN_JSON_PASTE8(func, v1, v2, v3, v4, v5, v6, v7) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE7(func, v2, v3, v4, v5, v6, v7)\n#define NLOHMANN_JSON_PASTE9(func, v1, v2, v3, v4, v5, v6, v7, v8) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE8(func, v2, v3, v4, v5, v6, v7, v8)\n#define NLOHMANN_JSON_PASTE10(func, v1, v2, v3, v4, v5, v6, v7, v8, v9) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE9(func, v2, v3, v4, v5, v6, v7, v8, v9)\n#define NLOHMANN_JSON_PASTE11(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE10(func, v2, v3, v4, v5, v6, v7, v8, v9, v10)\n#define NLOHMANN_JSON_PASTE12(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE11(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11)\n#define NLOHMANN_JSON_PASTE13(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE12(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12)\n#define 
NLOHMANN_JSON_PASTE14(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE13(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13)\n#define NLOHMANN_JSON_PASTE15(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE14(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14)\n#define NLOHMANN_JSON_PASTE16(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE15(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15)\n#define NLOHMANN_JSON_PASTE17(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE16(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16)\n#define NLOHMANN_JSON_PASTE18(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE17(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17)\n#define NLOHMANN_JSON_PASTE19(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE18(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18)\n#define NLOHMANN_JSON_PASTE20(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE19(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19)\n#define NLOHMANN_JSON_PASTE21(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE20(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20)\n#define NLOHMANN_JSON_PASTE22(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, 
v14, v15, v16, v17, v18, v19, v20, v21) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE21(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21)\n#define NLOHMANN_JSON_PASTE23(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE22(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22)\n#define NLOHMANN_JSON_PASTE24(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE23(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23)\n#define NLOHMANN_JSON_PASTE25(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE24(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24)\n#define NLOHMANN_JSON_PASTE26(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE25(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25)\n#define NLOHMANN_JSON_PASTE27(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE26(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26)\n#define NLOHMANN_JSON_PASTE28(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE27(func, v2, v3, v4, v5, 
v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27)\n#define NLOHMANN_JSON_PASTE29(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE28(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28)\n#define NLOHMANN_JSON_PASTE30(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE29(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29)\n#define NLOHMANN_JSON_PASTE31(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE30(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30)\n#define NLOHMANN_JSON_PASTE32(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE31(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31)\n#define NLOHMANN_JSON_PASTE33(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE32(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32)\n#define 
NLOHMANN_JSON_PASTE34(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE33(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33)\n#define NLOHMANN_JSON_PASTE35(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE34(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34)\n#define NLOHMANN_JSON_PASTE36(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE35(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35)\n#define NLOHMANN_JSON_PASTE37(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE36(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36)\n#define NLOHMANN_JSON_PASTE38(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE37(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, 
v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37)\n#define NLOHMANN_JSON_PASTE39(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE38(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38)\n#define NLOHMANN_JSON_PASTE40(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE39(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39)\n#define NLOHMANN_JSON_PASTE41(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE40(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40)\n#define NLOHMANN_JSON_PASTE42(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE41(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41)\n#define 
NLOHMANN_JSON_PASTE43(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE42(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42)\n#define NLOHMANN_JSON_PASTE44(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE43(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43)\n#define NLOHMANN_JSON_PASTE45(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE44(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44)\n#define NLOHMANN_JSON_PASTE46(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE45(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, 
v40, v41, v42, v43, v44, v45)\n#define NLOHMANN_JSON_PASTE47(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE46(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46)\n#define NLOHMANN_JSON_PASTE48(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE47(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47)\n#define NLOHMANN_JSON_PASTE49(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE48(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48)\n#define NLOHMANN_JSON_PASTE50(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE49(func, v2, 
v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49)\n#define NLOHMANN_JSON_PASTE51(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE50(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50)\n#define NLOHMANN_JSON_PASTE52(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE51(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51)\n#define NLOHMANN_JSON_PASTE53(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE52(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52)\n#define 
NLOHMANN_JSON_PASTE54(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE53(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53)\n#define NLOHMANN_JSON_PASTE55(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE54(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54)\n#define NLOHMANN_JSON_PASTE56(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE55(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55)\n#define NLOHMANN_JSON_PASTE57(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, 
v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE56(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56)\n#define NLOHMANN_JSON_PASTE58(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE57(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57)\n#define NLOHMANN_JSON_PASTE59(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE58(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58)\n#define NLOHMANN_JSON_PASTE60(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, 
v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE59(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59)\n#define NLOHMANN_JSON_PASTE61(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE60(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60)\n#define NLOHMANN_JSON_PASTE62(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE61(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61)\n#define NLOHMANN_JSON_PASTE63(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, 
v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE62(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62)\n#define NLOHMANN_JSON_PASTE64(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62, v63) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE63(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62, v63)\n\n#define NLOHMANN_JSON_TO(v1) nlohmann_json_j[#v1] = nlohmann_json_t.v1;\n#define NLOHMANN_JSON_FROM(v1) nlohmann_json_j.at(#v1).get_to(nlohmann_json_t.v1);\n#define NLOHMANN_JSON_FROM_WITH_DEFAULT(v1) nlohmann_json_t.v1 = nlohmann_json_j.value(#v1, nlohmann_json_default_obj.v1);\n\n/*!\n@brief macro\n@def NLOHMANN_DEFINE_TYPE_INTRUSIVE\n@since version 3.9.0\n*/\n#define NLOHMANN_DEFINE_TYPE_INTRUSIVE(Type, ...)  
\\\n    friend void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \\\n    friend void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM, __VA_ARGS__)) }\n\n#define NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(Type, ...)  \\\n    friend void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \\\n    friend void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { const Type nlohmann_json_default_obj{}; NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM_WITH_DEFAULT, __VA_ARGS__)) }\n\n#define NLOHMANN_DEFINE_TYPE_INTRUSIVE_ONLY_SERIALIZE(Type, ...)  \\\n    friend void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) }\n\n/*!\n@brief macro\n@def NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE\n@since version 3.9.0\n*/\n#define NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(Type, ...)  \\\n    inline void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \\\n    inline void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM, __VA_ARGS__)) }\n\n#define NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_ONLY_SERIALIZE(Type, ...)  \\\n    inline void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) }\n\n#define NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(Type, ...)  
\\\n    inline void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \\\n    inline void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { const Type nlohmann_json_default_obj{}; NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM_WITH_DEFAULT, __VA_ARGS__)) }\n\n/*!\n@brief macro\n@def NLOHMANN_DEFINE_DERIVED_TYPE_INTRUSIVE\n@since version 3.11.x\n*/\n#define NLOHMANN_DEFINE_DERIVED_TYPE_INTRUSIVE(Type, BaseType, ...)  \\\n    friend void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { nlohmann::to_json(nlohmann_json_j, static_cast<const BaseType &>(nlohmann_json_t)); NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \\\n    friend void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { nlohmann::from_json(nlohmann_json_j, static_cast<BaseType&>(nlohmann_json_t)); NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM, __VA_ARGS__)) }\n\n#define NLOHMANN_DEFINE_DERIVED_TYPE_INTRUSIVE_WITH_DEFAULT(Type, BaseType, ...)  \\\n    friend void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { nlohmann::to_json(nlohmann_json_j, static_cast<const BaseType&>(nlohmann_json_t)); NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \\\n    friend void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { nlohmann::from_json(nlohmann_json_j, static_cast<BaseType&>(nlohmann_json_t)); const Type nlohmann_json_default_obj{}; NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM_WITH_DEFAULT, __VA_ARGS__)) }\n\n/*!\n@brief macro\n@def NLOHMANN_DEFINE_DERIVED_TYPE_NON_INTRUSIVE\n@since version 3.11.x\n*/\n#define NLOHMANN_DEFINE_DERIVED_TYPE_NON_INTRUSIVE(Type, BaseType, ...)  
\\\n    inline void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { nlohmann::to_json(nlohmann_json_j, static_cast<const BaseType &>(nlohmann_json_t)); NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \\\n    inline void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { nlohmann::from_json(nlohmann_json_j, static_cast<BaseType&>(nlohmann_json_t)); NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM, __VA_ARGS__)) }\n\n#define NLOHMANN_DEFINE_DERIVED_TYPE_NON_INTRUSIVE_WITH_DEFAULT(Type, BaseType, ...)  \\\n    inline void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { nlohmann::to_json(nlohmann_json_j, static_cast<const BaseType &>(nlohmann_json_t)); NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \\\n    inline void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { nlohmann::from_json(nlohmann_json_j, static_cast<BaseType&>(nlohmann_json_t)); const Type nlohmann_json_default_obj{}; NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM_WITH_DEFAULT, __VA_ARGS__)) }\n\n\n// inspired from https://stackoverflow.com/a/26745591\n// allows to call any std function as if (e.g. with begin):\n// using std::begin; begin(x);\n//\n// it allows using the detected idiom to retrieve the return type\n// of such an expression\n#define NLOHMANN_CAN_CALL_STD_FUNC_IMPL(std_name)                                 \\\n    namespace detail {                                                            \\\n    using std::std_name;                                                          \\\n    \\\n    template<typename... 
T>                                                       \\\n    using result_of_##std_name = decltype(std_name(std::declval<T>()...));        \\\n    }                                                                             \\\n    \\\n    namespace detail2 {                                                           \\\n    struct std_name##_tag                                                         \\\n    {                                                                             \\\n    };                                                                            \\\n    \\\n    template<typename... T>                                                       \\\n    std_name##_tag std_name(T&&...);                                              \\\n    \\\n    template<typename... T>                                                       \\\n    using result_of_##std_name = decltype(std_name(std::declval<T>()...));        \\\n    \\\n    template<typename... T>                                                       \\\n    struct would_call_std_##std_name                                              \\\n    {                                                                             \\\n        static constexpr auto const value = ::nlohmann::detail::                  \\\n                                            is_detected_exact<std_name##_tag, result_of_##std_name, T...>::value; \\\n    };                                                                            \\\n    } /* namespace detail2 */ \\\n    \\\n    template<typename... 
T>                                                       \\\n    struct would_call_std_##std_name : detail2::would_call_std_##std_name<T...>   \\\n    {                                                                             \\\n    }\n\n#ifndef JSON_USE_IMPLICIT_CONVERSIONS\n    #define JSON_USE_IMPLICIT_CONVERSIONS 1\n#endif\n\n#if JSON_USE_IMPLICIT_CONVERSIONS\n    #define JSON_EXPLICIT\n#else\n    #define JSON_EXPLICIT explicit\n#endif\n\n#ifndef JSON_DISABLE_ENUM_SERIALIZATION\n    #define JSON_DISABLE_ENUM_SERIALIZATION 0\n#endif\n\n#ifndef JSON_USE_GLOBAL_UDLS\n    #define JSON_USE_GLOBAL_UDLS 1\n#endif\n\n#if JSON_HAS_THREE_WAY_COMPARISON\n    #include <compare> // partial_ordering\n#endif\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\n\n///////////////////////////\n// JSON type enumeration //\n///////////////////////////\n\n/*!\n@brief the JSON type enumeration\n\nThis enumeration collects the different JSON types. It is internally used to\ndistinguish the stored values, and the functions @ref basic_json::is_null(),\n@ref basic_json::is_object(), @ref basic_json::is_array(),\n@ref basic_json::is_string(), @ref basic_json::is_boolean(),\n@ref basic_json::is_number() (with @ref basic_json::is_number_integer(),\n@ref basic_json::is_number_unsigned(), and @ref basic_json::is_number_float()),\n@ref basic_json::is_discarded(), @ref basic_json::is_primitive(), and\n@ref basic_json::is_structured() rely on it.\n\n@note There are three enumeration entries (number_integer, number_unsigned, and\nnumber_float), because the library distinguishes these three types for numbers:\n@ref basic_json::number_unsigned_t is used for unsigned integers,\n@ref basic_json::number_integer_t is used for signed integers, and\n@ref basic_json::number_float_t is used for floating-point numbers or to\napproximate integers which do not fit in the limits of their respective type.\n\n@sa see @ref basic_json::basic_json(const value_t value_type) -- create a JSON\nvalue with the 
default value for a given type\n\n@since version 1.0.0\n*/\nenum class value_t : std::uint8_t\n{\n    null,             ///< null value\n    object,           ///< object (unordered set of name/value pairs)\n    array,            ///< array (ordered collection of values)\n    string,           ///< string value\n    boolean,          ///< boolean value\n    number_integer,   ///< number value (signed integer)\n    number_unsigned,  ///< number value (unsigned integer)\n    number_float,     ///< number value (floating-point)\n    binary,           ///< binary array (ordered collection of bytes)\n    discarded         ///< discarded by the parser callback function\n};\n\n/*!\n@brief comparison operator for JSON types\n\nReturns an ordering that is similar to Python:\n- order: null < boolean < number < object < array < string < binary\n- furthermore, each type is not smaller than itself\n- discarded values are not comparable\n- binary is represented as a b\"\" string in python and directly comparable to a\n  string; however, making a binary array directly comparable with a string would\n  be surprising behavior in a JSON file.\n\n@since version 1.0.0\n*/\n#if JSON_HAS_THREE_WAY_COMPARISON\n    inline std::partial_ordering operator<=>(const value_t lhs, const value_t rhs) noexcept // *NOPAD*\n#else\n    inline bool operator<(const value_t lhs, const value_t rhs) noexcept\n#endif\n{\n    static constexpr std::array<std::uint8_t, 9> order = {{\n            0 /* null */, 3 /* object */, 4 /* array */, 5 /* string */,\n            1 /* boolean */, 2 /* integer */, 2 /* unsigned */, 2 /* float */,\n            6 /* binary */\n        }\n    };\n\n    const auto l_index = static_cast<std::size_t>(lhs);\n    const auto r_index = static_cast<std::size_t>(rhs);\n#if JSON_HAS_THREE_WAY_COMPARISON\n    if (l_index < order.size() && r_index < order.size())\n    {\n        return order[l_index] <=> order[r_index]; // *NOPAD*\n    }\n    return 
std::partial_ordering::unordered;\n#else\n    return l_index < order.size() && r_index < order.size() && order[l_index] < order[r_index];\n#endif\n}\n\n// GCC selects the built-in operator< over an operator rewritten from\n// a user-defined spaceship operator\n// Clang, MSVC, and ICC select the rewritten candidate\n// (see GCC bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=105200)\n#if JSON_HAS_THREE_WAY_COMPARISON && defined(__GNUC__)\ninline bool operator<(const value_t lhs, const value_t rhs) noexcept\n{\n    return std::is_lt(lhs <=> rhs); // *NOPAD*\n}\n#endif\n\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include <nlohmann/detail/string_escape.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n// #include <nlohmann/detail/abi_macros.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\n\n/*!\n@brief replace all occurrences of a substring by another string\n\n@param[in,out] s  the string to manipulate; changed so that all\n               occurrences of @a f are replaced with @a t\n@param[in]     f  the substring to replace with @a t\n@param[in]     t  the string to replace @a f\n\n@pre The search string @a f must not be empty. 
**This precondition is\nenforced with an assertion.**\n\n@since version 2.0.0\n*/\ntemplate<typename StringType>\ninline void replace_substring(StringType& s, const StringType& f,\n                              const StringType& t)\n{\n    JSON_ASSERT(!f.empty());\n    for (auto pos = s.find(f);                // find first occurrence of f\n            pos != StringType::npos;          // make sure f was found\n            s.replace(pos, f.size(), t),      // replace with t, and\n            pos = s.find(f, pos + t.size()))  // find next occurrence of f\n    {}\n}\n\n/*!\n * @brief string escaping as described in RFC 6901 (Sect. 4)\n * @param[in] s string to escape\n * @return    escaped string\n *\n * Note the order of escaping \"~\" to \"~0\" and \"/\" to \"~1\" is important.\n */\ntemplate<typename StringType>\ninline StringType escape(StringType s)\n{\n    replace_substring(s, StringType{\"~\"}, StringType{\"~0\"});\n    replace_substring(s, StringType{\"/\"}, StringType{\"~1\"});\n    return s;\n}\n\n/*!\n * @brief string unescaping as described in RFC 6901 (Sect. 
4)\n * @param[in,out] s string to unescape; modified in place\n *\n * Note the order of unescaping \"~1\" to \"/\" and \"~0\" to \"~\" is important.\n */\ntemplate<typename StringType>\nstatic void unescape(StringType& s)\n{\n    replace_substring(s, StringType{\"~1\"}, StringType{\"/\"});\n    replace_substring(s, StringType{\"~0\"}, StringType{\"~\"});\n}\n\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include <nlohmann/detail/input/position_t.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#include <cstddef> // size_t\n\n// #include <nlohmann/detail/abi_macros.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\n\n/// struct to capture the start position of the current token\nstruct position_t\n{\n    /// the total number of characters read\n    std::size_t chars_read_total = 0;\n    /// the number of characters read in the current line\n    std::size_t chars_read_current_line = 0;\n    /// the number of lines read\n    std::size_t lines_read = 0;\n\n    /// conversion to size_t to preserve SAX interface\n    constexpr operator size_t() const\n    {\n        return chars_read_total;\n    }\n};\n\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include <nlohmann/detail/macro_scope.hpp>\n\n// #include <nlohmann/detail/meta/cpp_future.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-FileCopyrightText: 2018 The Abseil Authors\n// SPDX-License-Identifier: MIT\n\n\n#include <array> // array\n#include <cstddef> // size_t\n#include 
<type_traits> // conditional, enable_if, false_type, integral_constant, is_constructible, is_integral, is_same, remove_cv, remove_reference, true_type\n#include <utility> // index_sequence, make_index_sequence, index_sequence_for\n\n// #include <nlohmann/detail/macro_scope.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\n\ntemplate<typename T>\nusing uncvref_t = typename std::remove_cv<typename std::remove_reference<T>::type>::type;\n\n#ifdef JSON_HAS_CPP_14\n\n// the following utilities are natively available in C++14\nusing std::enable_if_t;\nusing std::index_sequence;\nusing std::make_index_sequence;\nusing std::index_sequence_for;\n\n#else\n\n// alias templates to reduce boilerplate\ntemplate<bool B, typename T = void>\nusing enable_if_t = typename std::enable_if<B, T>::type;\n\n// The following code is taken from https://github.com/abseil/abseil-cpp/blob/10cb35e459f5ecca5b2ff107635da0bfa41011b4/absl/utility/utility.h\n// which is part of Google Abseil (https://github.com/abseil/abseil-cpp), licensed under the Apache License 2.0.\n\n//// START OF CODE FROM GOOGLE ABSEIL\n\n// integer_sequence\n//\n// Class template representing a compile-time integer sequence. An instantiation\n// of `integer_sequence<T, Ints...>` has a sequence of integers encoded in its\n// type through its template arguments (which is a common need when\n// working with C++11 variadic templates). `absl::integer_sequence` is designed\n// to be a drop-in replacement for C++14's `std::integer_sequence`.\n//\n// Example:\n//\n//   template< class T, T... Ints >\n//   void user_function(integer_sequence<T, Ints...>);\n//\n//   int main()\n//   {\n//     // user_function's `T` will be deduced to `int` and `Ints...`\n//     // will be deduced to `0, 1, 2, 3, 4`.\n//     user_function(make_integer_sequence<int, 5>());\n//   }\ntemplate <typename T, T... 
Ints>\nstruct integer_sequence\n{\n    using value_type = T;\n    static constexpr std::size_t size() noexcept\n    {\n        return sizeof...(Ints);\n    }\n};\n\n// index_sequence\n//\n// A helper template for an `integer_sequence` of `size_t`,\n// `absl::index_sequence` is designed to be a drop-in replacement for C++14's\n// `std::index_sequence`.\ntemplate <size_t... Ints>\nusing index_sequence = integer_sequence<size_t, Ints...>;\n\nnamespace utility_internal\n{\n\ntemplate <typename Seq, size_t SeqSize, size_t Rem>\nstruct Extend;\n\n// Note that SeqSize == sizeof...(Ints). It's passed explicitly for efficiency.\ntemplate <typename T, T... Ints, size_t SeqSize>\nstruct Extend<integer_sequence<T, Ints...>, SeqSize, 0>\n{\n    using type = integer_sequence < T, Ints..., (Ints + SeqSize)... >;\n};\n\ntemplate <typename T, T... Ints, size_t SeqSize>\nstruct Extend<integer_sequence<T, Ints...>, SeqSize, 1>\n{\n    using type = integer_sequence < T, Ints..., (Ints + SeqSize)..., 2 * SeqSize >;\n};\n\n// Recursion helper for 'make_integer_sequence<T, N>'.\n// 'Gen<T, N>::type' is an alias for 'integer_sequence<T, 0, 1, ... 
N-1>'.\ntemplate <typename T, size_t N>\nstruct Gen\n{\n    using type =\n        typename Extend < typename Gen < T, N / 2 >::type, N / 2, N % 2 >::type;\n};\n\ntemplate <typename T>\nstruct Gen<T, 0>\n{\n    using type = integer_sequence<T>;\n};\n\n}  // namespace utility_internal\n\n// Compile-time sequences of integers\n\n// make_integer_sequence\n//\n// This template alias is equivalent to\n// `integer_sequence<T, 0, 1, ..., N-1>`, and is designed to be a drop-in\n// replacement for C++14's `std::make_integer_sequence`.\ntemplate <typename T, T N>\nusing make_integer_sequence = typename utility_internal::Gen<T, N>::type;\n\n// make_index_sequence\n//\n// This template alias is equivalent to `index_sequence<0, 1, ..., N-1>`,\n// and is designed to be a drop-in replacement for C++14's\n// `std::make_index_sequence`.\ntemplate <size_t N>\nusing make_index_sequence = make_integer_sequence<size_t, N>;\n\n// index_sequence_for\n//\n// Converts a typename pack into an index sequence of the same length, and\n// is designed to be a drop-in replacement for C++14's\n// `std::index_sequence_for`.\ntemplate <typename... Ts>\nusing index_sequence_for = make_index_sequence<sizeof...(Ts)>;\n\n//// END OF CODE FROM GOOGLE ABSEIL\n\n#endif\n\n// dispatch utility (taken from ranges-v3)\ntemplate<unsigned N> struct priority_tag : priority_tag < N - 1 > {};\ntemplate<> struct priority_tag<0> {};\n\n// taken from ranges-v3\ntemplate<typename T>\nstruct static_const\n{\n    static JSON_INLINE_VARIABLE constexpr T value{};\n};\n\n#ifndef JSON_HAS_CPP_17\n    template<typename T>\n    constexpr T static_const<T>::value;\n#endif\n\ntemplate<typename T, typename... Args>\ninline constexpr std::array<T, sizeof...(Args)> make_array(Args&& ... 
args)\n{\n    return std::array<T, sizeof...(Args)> {{static_cast<T>(std::forward<Args>(args))...}};\n}\n\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include <nlohmann/detail/meta/type_traits.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#include <limits> // numeric_limits\n#include <type_traits> // false_type, is_constructible, is_integral, is_same, true_type\n#include <utility> // declval\n#include <tuple> // tuple\n#include <string> // char_traits\n\n// #include <nlohmann/detail/iterators/iterator_traits.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#include <iterator> // random_access_iterator_tag\n\n// #include <nlohmann/detail/abi_macros.hpp>\n\n// #include <nlohmann/detail/meta/void_t.hpp>\n\n// #include <nlohmann/detail/meta/cpp_future.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\n\ntemplate<typename It, typename = void>\nstruct iterator_types {};\n\ntemplate<typename It>\nstruct iterator_types <\n    It,\n    void_t<typename It::difference_type, typename It::value_type, typename It::pointer,\n    typename It::reference, typename It::iterator_category >>\n{\n    using difference_type = typename It::difference_type;\n    using value_type = typename It::value_type;\n    using pointer = typename It::pointer;\n    using reference = typename It::reference;\n    using iterator_category = typename It::iterator_category;\n};\n\n// This is required as some compilers implement std::iterator_traits in a way that\n// 
doesn't work with SFINAE. See https://github.com/nlohmann/json/issues/1341.\ntemplate<typename T, typename = void>\nstruct iterator_traits\n{\n};\n\ntemplate<typename T>\nstruct iterator_traits < T, enable_if_t < !std::is_pointer<T>::value >>\n            : iterator_types<T>\n{\n};\n\ntemplate<typename T>\nstruct iterator_traits<T*, enable_if_t<std::is_object<T>::value>>\n{\n    using iterator_category = std::random_access_iterator_tag;\n    using value_type = T;\n    using difference_type = ptrdiff_t;\n    using pointer = T*;\n    using reference = T&;\n};\n\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include <nlohmann/detail/macro_scope.hpp>\n\n// #include <nlohmann/detail/meta/call_std/begin.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n// #include <nlohmann/detail/macro_scope.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\n\nNLOHMANN_CAN_CALL_STD_FUNC_IMPL(begin);\n\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include <nlohmann/detail/meta/call_std/end.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n// #include <nlohmann/detail/macro_scope.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\n\nNLOHMANN_CAN_CALL_STD_FUNC_IMPL(end);\n\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include <nlohmann/detail/meta/cpp_future.hpp>\n\n// #include <nlohmann/detail/meta/detected.hpp>\n\n// #include <nlohmann/json_fwd.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  
https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n#ifndef INCLUDE_NLOHMANN_JSON_FWD_HPP_\n    #define INCLUDE_NLOHMANN_JSON_FWD_HPP_\n\n    #include <cstdint> // int64_t, uint64_t\n    #include <map> // map\n    #include <memory> // allocator\n    #include <string> // string\n    #include <vector> // vector\n\n    // #include <nlohmann/detail/abi_macros.hpp>\n\n\n    /*!\n    @brief namespace for Niels Lohmann\n    @see https://github.com/nlohmann\n    @since version 1.0.0\n    */\n    NLOHMANN_JSON_NAMESPACE_BEGIN\n\n    /*!\n    @brief default JSONSerializer template argument\n\n    This serializer ignores the template arguments and uses ADL\n    ([argument-dependent lookup](https://en.cppreference.com/w/cpp/language/adl))\n    for serialization.\n    */\n    template<typename T = void, typename SFINAE = void>\n    struct adl_serializer;\n\n    /// a class to store JSON values\n    /// @sa https://json.nlohmann.me/api/basic_json/\n    template<template<typename U, typename V, typename... Args> class ObjectType =\n    std::map,\n    template<typename U, typename... 
Args> class ArrayType = std::vector,\n    class StringType = std::string, class BooleanType = bool,\n    class NumberIntegerType = std::int64_t,\n    class NumberUnsignedType = std::uint64_t,\n    class NumberFloatType = double,\n    template<typename U> class AllocatorType = std::allocator,\n    template<typename T, typename SFINAE = void> class JSONSerializer =\n    adl_serializer,\n    class BinaryType = std::vector<std::uint8_t>, // cppcheck-suppress syntaxError\n    class CustomBaseClass = void>\n    class basic_json;\n\n    /// @brief JSON Pointer defines a string syntax for identifying a specific value within a JSON document\n    /// @sa https://json.nlohmann.me/api/json_pointer/\n    template<typename RefStringType>\n    class json_pointer;\n\n    /*!\n    @brief default specialization\n    @sa https://json.nlohmann.me/api/json/\n    */\n    using json = basic_json<>;\n\n    /// @brief a minimal map-like container that preserves insertion order\n    /// @sa https://json.nlohmann.me/api/ordered_map/\n    template<class Key, class T, class IgnoredLess, class Allocator>\n    struct ordered_map;\n\n    /// @brief specialization that maintains the insertion order of object keys\n    /// @sa https://json.nlohmann.me/api/ordered_json/\n    using ordered_json = basic_json<nlohmann::ordered_map>;\n\n    NLOHMANN_JSON_NAMESPACE_END\n\n#endif  // INCLUDE_NLOHMANN_JSON_FWD_HPP_\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\n/*!\n@brief detail namespace with internal helper functions\n\nThis namespace collects functions that should not be exposed,\nimplementations of some @ref basic_json methods, and meta-programming helpers.\n\n@since version 2.1.0\n*/\nnamespace detail\n{\n\n/////////////\n// helpers //\n/////////////\n\n// Note to maintainers:\n//\n// Every trait in this file expects a non CV-qualified type.\n// The only exceptions are in the 'aliases for detected' section\n// (i.e. 
those of the form: decltype(T::member_function(std::declval<T>())))\n//\n// In this case, T has to be properly CV-qualified to constrain the function arguments\n// (e.g. to_json(BasicJsonType&, const T&))\n\ntemplate<typename> struct is_basic_json : std::false_type {};\n\nNLOHMANN_BASIC_JSON_TPL_DECLARATION\nstruct is_basic_json<NLOHMANN_BASIC_JSON_TPL> : std::true_type {};\n\n// used by exceptions create() member functions\n// true_type for pointer to possibly cv-qualified basic_json or std::nullptr_t\n// false_type otherwise\ntemplate<typename BasicJsonContext>\nstruct is_basic_json_context :\n    std::integral_constant < bool,\n    is_basic_json<typename std::remove_cv<typename std::remove_pointer<BasicJsonContext>::type>::type>::value\n    || std::is_same<BasicJsonContext, std::nullptr_t>::value >\n{};\n\n//////////////////////\n// json_ref helpers //\n//////////////////////\n\ntemplate<typename>\nclass json_ref;\n\ntemplate<typename>\nstruct is_json_ref : std::false_type {};\n\ntemplate<typename T>\nstruct is_json_ref<json_ref<T>> : std::true_type {};\n\n//////////////////////////\n// aliases for detected //\n//////////////////////////\n\ntemplate<typename T>\nusing mapped_type_t = typename T::mapped_type;\n\ntemplate<typename T>\nusing key_type_t = typename T::key_type;\n\ntemplate<typename T>\nusing value_type_t = typename T::value_type;\n\ntemplate<typename T>\nusing difference_type_t = typename T::difference_type;\n\ntemplate<typename T>\nusing pointer_t = typename T::pointer;\n\ntemplate<typename T>\nusing reference_t = typename T::reference;\n\ntemplate<typename T>\nusing iterator_category_t = typename T::iterator_category;\n\ntemplate<typename T, typename... Args>\nusing to_json_function = decltype(T::to_json(std::declval<Args>()...));\n\ntemplate<typename T, typename... 
Args>\nusing from_json_function = decltype(T::from_json(std::declval<Args>()...));\n\ntemplate<typename T, typename U>\nusing get_template_function = decltype(std::declval<T>().template get<U>());\n\n// trait checking if JSONSerializer<T>::from_json(json const&, udt&) exists\ntemplate<typename BasicJsonType, typename T, typename = void>\nstruct has_from_json : std::false_type {};\n\n// trait checking if j.get<T> is valid\n// use this trait instead of std::is_constructible or std::is_convertible,\n// both rely on, or make use of implicit conversions, and thus fail when T\n// has several constructors/operator= (see https://github.com/nlohmann/json/issues/958)\ntemplate <typename BasicJsonType, typename T>\nstruct is_getable\n{\n    static constexpr bool value = is_detected<get_template_function, const BasicJsonType&, T>::value;\n};\n\ntemplate<typename BasicJsonType, typename T>\nstruct has_from_json < BasicJsonType, T, enable_if_t < !is_basic_json<T>::value >>\n{\n    using serializer = typename BasicJsonType::template json_serializer<T, void>;\n\n    static constexpr bool value =\n        is_detected_exact<void, from_json_function, serializer,\n        const BasicJsonType&, T&>::value;\n};\n\n// This trait checks if JSONSerializer<T>::from_json(json const&) exists\n// this overload is used for non-default-constructible user-defined-types\ntemplate<typename BasicJsonType, typename T, typename = void>\nstruct has_non_default_from_json : std::false_type {};\n\ntemplate<typename BasicJsonType, typename T>\nstruct has_non_default_from_json < BasicJsonType, T, enable_if_t < !is_basic_json<T>::value >>\n{\n    using serializer = typename BasicJsonType::template json_serializer<T, void>;\n\n    static constexpr bool value =\n        is_detected_exact<T, from_json_function, serializer,\n        const BasicJsonType&>::value;\n};\n\n// This trait checks if BasicJsonType::json_serializer<T>::to_json exists\n// Do not evaluate the trait when T is a basic_json type, to avoid 
template instantiation infinite recursion.\ntemplate<typename BasicJsonType, typename T, typename = void>\nstruct has_to_json : std::false_type {};\n\ntemplate<typename BasicJsonType, typename T>\nstruct has_to_json < BasicJsonType, T, enable_if_t < !is_basic_json<T>::value >>\n{\n    using serializer = typename BasicJsonType::template json_serializer<T, void>;\n\n    static constexpr bool value =\n        is_detected_exact<void, to_json_function, serializer, BasicJsonType&,\n        T>::value;\n};\n\ntemplate<typename T>\nusing detect_key_compare = typename T::key_compare;\n\ntemplate<typename T>\nstruct has_key_compare : std::integral_constant<bool, is_detected<detect_key_compare, T>::value> {};\n\n// obtains the actual object key comparator\ntemplate<typename BasicJsonType>\nstruct actual_object_comparator\n{\n    using object_t = typename BasicJsonType::object_t;\n    using object_comparator_t = typename BasicJsonType::default_object_comparator_t;\n    using type = typename std::conditional < has_key_compare<object_t>::value,\n          typename object_t::key_compare, object_comparator_t>::type;\n};\n\ntemplate<typename BasicJsonType>\nusing actual_object_comparator_t = typename actual_object_comparator<BasicJsonType>::type;\n\n/////////////////\n// char_traits //\n/////////////////\n\n// Primary template of char_traits calls std char_traits\ntemplate<typename T>\nstruct char_traits : std::char_traits<T>\n{};\n\n// Explicitly define char traits for unsigned char since it is not standard\ntemplate<>\nstruct char_traits<unsigned char> : std::char_traits<char>\n{\n    using char_type = unsigned char;\n    using int_type = uint64_t;\n\n    // Redefine to_int_type function\n    static int_type to_int_type(char_type c) noexcept\n    {\n        return static_cast<int_type>(c);\n    }\n\n    static char_type to_char_type(int_type i) noexcept\n    {\n        return static_cast<char_type>(i);\n    }\n\n    static constexpr int_type eof() noexcept\n    {\n        return 
static_cast<int_type>(EOF);\n    }\n};\n\n// Explicitly define char traits for signed char since it is not standard\ntemplate<>\nstruct char_traits<signed char> : std::char_traits<char>\n{\n    using char_type = signed char;\n    using int_type = uint64_t;\n\n    // Redefine to_int_type function\n    static int_type to_int_type(char_type c) noexcept\n    {\n        return static_cast<int_type>(c);\n    }\n\n    static char_type to_char_type(int_type i) noexcept\n    {\n        return static_cast<char_type>(i);\n    }\n\n    static constexpr int_type eof() noexcept\n    {\n        return static_cast<int_type>(EOF);\n    }\n};\n\n///////////////////\n// is_ functions //\n///////////////////\n\n// https://en.cppreference.com/w/cpp/types/conjunction\ntemplate<class...> struct conjunction : std::true_type { };\ntemplate<class B> struct conjunction<B> : B { };\ntemplate<class B, class... Bn>\nstruct conjunction<B, Bn...>\n: std::conditional<static_cast<bool>(B::value), conjunction<Bn...>, B>::type {};\n\n// https://en.cppreference.com/w/cpp/types/negation\ntemplate<class B> struct negation : std::integral_constant < bool, !B::value > { };\n\n// Reimplementation of is_constructible and is_default_constructible, due to them being broken for\n// std::pair and std::tuple until LWG 2367 fix (see https://cplusplus.github.io/LWG/lwg-defects.html#2367).\n// This causes compile errors in e.g. clang 3.5 or gcc 4.9.\ntemplate <typename T>\nstruct is_default_constructible : std::is_default_constructible<T> {};\n\ntemplate <typename T1, typename T2>\nstruct is_default_constructible<std::pair<T1, T2>>\n            : conjunction<is_default_constructible<T1>, is_default_constructible<T2>> {};\n\ntemplate <typename T1, typename T2>\nstruct is_default_constructible<const std::pair<T1, T2>>\n            : conjunction<is_default_constructible<T1>, is_default_constructible<T2>> {};\n\ntemplate <typename... 
Ts>\nstruct is_default_constructible<std::tuple<Ts...>>\n            : conjunction<is_default_constructible<Ts>...> {};\n\ntemplate <typename... Ts>\nstruct is_default_constructible<const std::tuple<Ts...>>\n            : conjunction<is_default_constructible<Ts>...> {};\n\ntemplate <typename T, typename... Args>\nstruct is_constructible : std::is_constructible<T, Args...> {};\n\ntemplate <typename T1, typename T2>\nstruct is_constructible<std::pair<T1, T2>> : is_default_constructible<std::pair<T1, T2>> {};\n\ntemplate <typename T1, typename T2>\nstruct is_constructible<const std::pair<T1, T2>> : is_default_constructible<const std::pair<T1, T2>> {};\n\ntemplate <typename... Ts>\nstruct is_constructible<std::tuple<Ts...>> : is_default_constructible<std::tuple<Ts...>> {};\n\ntemplate <typename... Ts>\nstruct is_constructible<const std::tuple<Ts...>> : is_default_constructible<const std::tuple<Ts...>> {};\n\ntemplate<typename T, typename = void>\nstruct is_iterator_traits : std::false_type {};\n\ntemplate<typename T>\nstruct is_iterator_traits<iterator_traits<T>>\n{\n  private:\n    using traits = iterator_traits<T>;\n\n  public:\n    static constexpr auto value =\n        is_detected<value_type_t, traits>::value &&\n        is_detected<difference_type_t, traits>::value &&\n        is_detected<pointer_t, traits>::value &&\n        is_detected<iterator_category_t, traits>::value &&\n        is_detected<reference_t, traits>::value;\n};\n\ntemplate<typename T>\nstruct is_range\n{\n  private:\n    using t_ref = typename std::add_lvalue_reference<T>::type;\n\n    using iterator = detected_t<result_of_begin, t_ref>;\n    using sentinel = detected_t<result_of_end, t_ref>;\n\n    // to be 100% correct, it should use https://en.cppreference.com/w/cpp/iterator/input_or_output_iterator\n    // and https://en.cppreference.com/w/cpp/iterator/sentinel_for\n    // but reimplementing these would be too much work, as a lot of other concepts are used underneath\n    static constexpr 
auto is_iterator_begin =\n        is_iterator_traits<iterator_traits<iterator>>::value;\n\n  public:\n    static constexpr bool value = !std::is_same<iterator, nonesuch>::value && !std::is_same<sentinel, nonesuch>::value && is_iterator_begin;\n};\n\ntemplate<typename R>\nusing iterator_t = enable_if_t<is_range<R>::value, result_of_begin<decltype(std::declval<R&>())>>;\n\ntemplate<typename T>\nusing range_value_t = value_type_t<iterator_traits<iterator_t<T>>>;\n\n// The following implementation of is_complete_type is taken from\n// https://blogs.msdn.microsoft.com/vcblog/2015/12/02/partial-support-for-expression-sfinae-in-vs-2015-update-1/\n// and is written by Xiang Fan who agreed to using it in this library.\n\ntemplate<typename T, typename = void>\nstruct is_complete_type : std::false_type {};\n\ntemplate<typename T>\nstruct is_complete_type<T, decltype(void(sizeof(T)))> : std::true_type {};\n\ntemplate<typename BasicJsonType, typename CompatibleObjectType,\n         typename = void>\nstruct is_compatible_object_type_impl : std::false_type {};\n\ntemplate<typename BasicJsonType, typename CompatibleObjectType>\nstruct is_compatible_object_type_impl <\n    BasicJsonType, CompatibleObjectType,\n    enable_if_t < is_detected<mapped_type_t, CompatibleObjectType>::value&&\n    is_detected<key_type_t, CompatibleObjectType>::value >>\n{\n    using object_t = typename BasicJsonType::object_t;\n\n    // macOS's is_constructible does not play well with nonesuch...\n    static constexpr bool value =\n        is_constructible<typename object_t::key_type,\n        typename CompatibleObjectType::key_type>::value &&\n        is_constructible<typename object_t::mapped_type,\n        typename CompatibleObjectType::mapped_type>::value;\n};\n\ntemplate<typename BasicJsonType, typename CompatibleObjectType>\nstruct is_compatible_object_type\n    : is_compatible_object_type_impl<BasicJsonType, CompatibleObjectType> {};\n\ntemplate<typename BasicJsonType, typename 
ConstructibleObjectType,\n         typename = void>\nstruct is_constructible_object_type_impl : std::false_type {};\n\ntemplate<typename BasicJsonType, typename ConstructibleObjectType>\nstruct is_constructible_object_type_impl <\n    BasicJsonType, ConstructibleObjectType,\n    enable_if_t < is_detected<mapped_type_t, ConstructibleObjectType>::value&&\n    is_detected<key_type_t, ConstructibleObjectType>::value >>\n{\n    using object_t = typename BasicJsonType::object_t;\n\n    static constexpr bool value =\n        (is_default_constructible<ConstructibleObjectType>::value &&\n         (std::is_move_assignable<ConstructibleObjectType>::value ||\n          std::is_copy_assignable<ConstructibleObjectType>::value) &&\n         (is_constructible<typename ConstructibleObjectType::key_type,\n          typename object_t::key_type>::value &&\n          std::is_same <\n          typename object_t::mapped_type,\n          typename ConstructibleObjectType::mapped_type >::value)) ||\n        (has_from_json<BasicJsonType,\n         typename ConstructibleObjectType::mapped_type>::value ||\n         has_non_default_from_json <\n         BasicJsonType,\n         typename ConstructibleObjectType::mapped_type >::value);\n};\n\ntemplate<typename BasicJsonType, typename ConstructibleObjectType>\nstruct is_constructible_object_type\n    : is_constructible_object_type_impl<BasicJsonType,\n      ConstructibleObjectType> {};\n\ntemplate<typename BasicJsonType, typename CompatibleStringType>\nstruct is_compatible_string_type\n{\n    static constexpr auto value =\n        is_constructible<typename BasicJsonType::string_t, CompatibleStringType>::value;\n};\n\ntemplate<typename BasicJsonType, typename ConstructibleStringType>\nstruct is_constructible_string_type\n{\n    // launder type through decltype() to fix compilation failure on ICPC\n#ifdef __INTEL_COMPILER\n    using laundered_type = decltype(std::declval<ConstructibleStringType>());\n#else\n    using laundered_type = 
ConstructibleStringType;\n#endif\n\n    static constexpr auto value =\n        conjunction <\n        is_constructible<laundered_type, typename BasicJsonType::string_t>,\n        is_detected_exact<typename BasicJsonType::string_t::value_type,\n        value_type_t, laundered_type >>::value;\n};\n\ntemplate<typename BasicJsonType, typename CompatibleArrayType, typename = void>\nstruct is_compatible_array_type_impl : std::false_type {};\n\ntemplate<typename BasicJsonType, typename CompatibleArrayType>\nstruct is_compatible_array_type_impl <\n    BasicJsonType, CompatibleArrayType,\n    enable_if_t <\n    is_detected<iterator_t, CompatibleArrayType>::value&&\n    is_iterator_traits<iterator_traits<detected_t<iterator_t, CompatibleArrayType>>>::value&&\n// special case for types like std::filesystem::path whose iterator's value_type are themselves\n// c.f. https://github.com/nlohmann/json/pull/3073\n    !std::is_same<CompatibleArrayType, detected_t<range_value_t, CompatibleArrayType>>::value >>\n{\n    static constexpr bool value =\n        is_constructible<BasicJsonType,\n        range_value_t<CompatibleArrayType>>::value;\n};\n\ntemplate<typename BasicJsonType, typename CompatibleArrayType>\nstruct is_compatible_array_type\n    : is_compatible_array_type_impl<BasicJsonType, CompatibleArrayType> {};\n\ntemplate<typename BasicJsonType, typename ConstructibleArrayType, typename = void>\nstruct is_constructible_array_type_impl : std::false_type {};\n\ntemplate<typename BasicJsonType, typename ConstructibleArrayType>\nstruct is_constructible_array_type_impl <\n    BasicJsonType, ConstructibleArrayType,\n    enable_if_t<std::is_same<ConstructibleArrayType,\n    typename BasicJsonType::value_type>::value >>\n            : std::true_type {};\n\ntemplate<typename BasicJsonType, typename ConstructibleArrayType>\nstruct is_constructible_array_type_impl <\n    BasicJsonType, ConstructibleArrayType,\n    enable_if_t < !std::is_same<ConstructibleArrayType,\n    typename 
BasicJsonType::value_type>::value&&\n    !is_compatible_string_type<BasicJsonType, ConstructibleArrayType>::value&&\n    is_default_constructible<ConstructibleArrayType>::value&&\n(std::is_move_assignable<ConstructibleArrayType>::value ||\n std::is_copy_assignable<ConstructibleArrayType>::value)&&\nis_detected<iterator_t, ConstructibleArrayType>::value&&\nis_iterator_traits<iterator_traits<detected_t<iterator_t, ConstructibleArrayType>>>::value&&\nis_detected<range_value_t, ConstructibleArrayType>::value&&\n// special case for types like std::filesystem::path whose iterator's value_type are themselves\n// c.f. https://github.com/nlohmann/json/pull/3073\n!std::is_same<ConstructibleArrayType, detected_t<range_value_t, ConstructibleArrayType>>::value&&\n        is_complete_type <\n        detected_t<range_value_t, ConstructibleArrayType >>::value >>\n{\n    using value_type = range_value_t<ConstructibleArrayType>;\n\n    static constexpr bool value =\n        std::is_same<value_type,\n        typename BasicJsonType::array_t::value_type>::value ||\n        has_from_json<BasicJsonType,\n        value_type>::value ||\n        has_non_default_from_json <\n        BasicJsonType,\n        value_type >::value;\n};\n\ntemplate<typename BasicJsonType, typename ConstructibleArrayType>\nstruct is_constructible_array_type\n    : is_constructible_array_type_impl<BasicJsonType, ConstructibleArrayType> {};\n\ntemplate<typename RealIntegerType, typename CompatibleNumberIntegerType,\n         typename = void>\nstruct is_compatible_integer_type_impl : std::false_type {};\n\ntemplate<typename RealIntegerType, typename CompatibleNumberIntegerType>\nstruct is_compatible_integer_type_impl <\n    RealIntegerType, CompatibleNumberIntegerType,\n    enable_if_t < std::is_integral<RealIntegerType>::value&&\n    std::is_integral<CompatibleNumberIntegerType>::value&&\n    !std::is_same<bool, CompatibleNumberIntegerType>::value >>\n{\n    // is there an assert somewhere on overflows?\n    using 
RealLimits = std::numeric_limits<RealIntegerType>;\n    using CompatibleLimits = std::numeric_limits<CompatibleNumberIntegerType>;\n\n    static constexpr auto value =\n        is_constructible<RealIntegerType,\n        CompatibleNumberIntegerType>::value &&\n        CompatibleLimits::is_integer &&\n        RealLimits::is_signed == CompatibleLimits::is_signed;\n};\n\ntemplate<typename RealIntegerType, typename CompatibleNumberIntegerType>\nstruct is_compatible_integer_type\n    : is_compatible_integer_type_impl<RealIntegerType,\n      CompatibleNumberIntegerType> {};\n\ntemplate<typename BasicJsonType, typename CompatibleType, typename = void>\nstruct is_compatible_type_impl: std::false_type {};\n\ntemplate<typename BasicJsonType, typename CompatibleType>\nstruct is_compatible_type_impl <\n    BasicJsonType, CompatibleType,\n    enable_if_t<is_complete_type<CompatibleType>::value >>\n{\n    static constexpr bool value =\n        has_to_json<BasicJsonType, CompatibleType>::value;\n};\n\ntemplate<typename BasicJsonType, typename CompatibleType>\nstruct is_compatible_type\n    : is_compatible_type_impl<BasicJsonType, CompatibleType> {};\n\ntemplate<typename T1, typename T2>\nstruct is_constructible_tuple : std::false_type {};\n\ntemplate<typename T1, typename... 
Args>\nstruct is_constructible_tuple<T1, std::tuple<Args...>> : conjunction<is_constructible<T1, Args>...> {};\n\ntemplate<typename BasicJsonType, typename T>\nstruct is_json_iterator_of : std::false_type {};\n\ntemplate<typename BasicJsonType>\nstruct is_json_iterator_of<BasicJsonType, typename BasicJsonType::iterator> : std::true_type {};\n\ntemplate<typename BasicJsonType>\nstruct is_json_iterator_of<BasicJsonType, typename BasicJsonType::const_iterator> : std::true_type\n{};\n\n// checks if a given type T is a template specialization of Primary\ntemplate<template <typename...> class Primary, typename T>\nstruct is_specialization_of : std::false_type {};\n\ntemplate<template <typename...> class Primary, typename... Args>\nstruct is_specialization_of<Primary, Primary<Args...>> : std::true_type {};\n\ntemplate<typename T>\nusing is_json_pointer = is_specialization_of<::nlohmann::json_pointer, uncvref_t<T>>;\n\n// checks if A and B are comparable using Compare functor\ntemplate<typename Compare, typename A, typename B, typename = void>\nstruct is_comparable : std::false_type {};\n\ntemplate<typename Compare, typename A, typename B>\nstruct is_comparable<Compare, A, B, void_t<\ndecltype(std::declval<Compare>()(std::declval<A>(), std::declval<B>())),\ndecltype(std::declval<Compare>()(std::declval<B>(), std::declval<A>()))\n>> : std::true_type {};\n\ntemplate<typename T>\nusing detect_is_transparent = typename T::is_transparent;\n\n// type trait to check if KeyType can be used as object key (without a BasicJsonType)\n// see is_usable_as_basic_json_key_type below\ntemplate<typename Comparator, typename ObjectKeyType, typename KeyTypeCVRef, bool RequireTransparentComparator = true,\n         bool ExcludeObjectKeyType = RequireTransparentComparator, typename KeyType = uncvref_t<KeyTypeCVRef>>\nusing is_usable_as_key_type = typename std::conditional <\n                              is_comparable<Comparator, ObjectKeyType, KeyTypeCVRef>::value\n                             
 && !(ExcludeObjectKeyType && std::is_same<KeyType,\n                                   ObjectKeyType>::value)\n                              && (!RequireTransparentComparator\n                                  || is_detected <detect_is_transparent, Comparator>::value)\n                              && !is_json_pointer<KeyType>::value,\n                              std::true_type,\n                              std::false_type >::type;\n\n// type trait to check if KeyType can be used as object key\n// true if:\n//   - KeyType is comparable with BasicJsonType::object_t::key_type\n//   - if ExcludeObjectKeyType is true, KeyType is not BasicJsonType::object_t::key_type\n//   - the comparator is transparent or RequireTransparentComparator is false\n//   - KeyType is not a JSON iterator or json_pointer\ntemplate<typename BasicJsonType, typename KeyTypeCVRef, bool RequireTransparentComparator = true,\n         bool ExcludeObjectKeyType = RequireTransparentComparator, typename KeyType = uncvref_t<KeyTypeCVRef>>\nusing is_usable_as_basic_json_key_type = typename std::conditional <\n        is_usable_as_key_type<typename BasicJsonType::object_comparator_t,\n        typename BasicJsonType::object_t::key_type, KeyTypeCVRef,\n        RequireTransparentComparator, ExcludeObjectKeyType>::value\n        && !is_json_iterator_of<BasicJsonType, KeyType>::value,\n        std::true_type,\n        std::false_type >::type;\n\ntemplate<typename ObjectType, typename KeyType>\nusing detect_erase_with_key_type = decltype(std::declval<ObjectType&>().erase(std::declval<KeyType>()));\n\n// type trait to check if object_t has an erase() member function accepting KeyType\ntemplate<typename BasicJsonType, typename KeyType>\nusing has_erase_with_key_type = typename std::conditional <\n                                is_detected <\n                                detect_erase_with_key_type,\n                                typename BasicJsonType::object_t, KeyType >::value,\n                      
          std::true_type,\n                                std::false_type >::type;\n\n// a naive helper to check if a type is an ordered_map (exploits the fact that\n// ordered_map inherits capacity() from std::vector)\ntemplate <typename T>\nstruct is_ordered_map\n{\n    using one = char;\n\n    struct two\n    {\n        char x[2]; // NOLINT(cppcoreguidelines-avoid-c-arrays,hicpp-avoid-c-arrays,modernize-avoid-c-arrays)\n    };\n\n    template <typename C> static one test( decltype(&C::capacity) ) ;\n    template <typename C> static two test(...);\n\n    enum { value = sizeof(test<T>(nullptr)) == sizeof(char) }; // NOLINT(cppcoreguidelines-pro-type-vararg,hicpp-vararg)\n};\n\n// to avoid useless casts (see https://github.com/nlohmann/json/issues/2893#issuecomment-889152324)\ntemplate < typename T, typename U, enable_if_t < !std::is_same<T, U>::value, int > = 0 >\nT conditional_static_cast(U value)\n{\n    return static_cast<T>(value);\n}\n\ntemplate<typename T, typename U, enable_if_t<std::is_same<T, U>::value, int> = 0>\nT conditional_static_cast(U value)\n{\n    return value;\n}\n\ntemplate<typename... Types>\nusing all_integral = conjunction<std::is_integral<Types>...>;\n\ntemplate<typename... Types>\nusing all_signed = conjunction<std::is_signed<Types>...>;\n\ntemplate<typename... Types>\nusing all_unsigned = conjunction<std::is_unsigned<Types>...>;\n\n// there's a disjunction trait in another PR; replace when merged\ntemplate<typename... 
Types>\nusing same_sign = std::integral_constant < bool,\n      all_signed<Types...>::value || all_unsigned<Types...>::value >;\n\ntemplate<typename OfType, typename T>\nusing never_out_of_range = std::integral_constant < bool,\n      (std::is_signed<OfType>::value && (sizeof(T) < sizeof(OfType)))\n      || (same_sign<OfType, T>::value && sizeof(OfType) == sizeof(T)) >;\n\ntemplate<typename OfType, typename T,\n         bool OfTypeSigned = std::is_signed<OfType>::value,\n         bool TSigned = std::is_signed<T>::value>\nstruct value_in_range_of_impl2;\n\ntemplate<typename OfType, typename T>\nstruct value_in_range_of_impl2<OfType, T, false, false>\n{\n    static constexpr bool test(T val)\n    {\n        using CommonType = typename std::common_type<OfType, T>::type;\n        return static_cast<CommonType>(val) <= static_cast<CommonType>((std::numeric_limits<OfType>::max)());\n    }\n};\n\ntemplate<typename OfType, typename T>\nstruct value_in_range_of_impl2<OfType, T, true, false>\n{\n    static constexpr bool test(T val)\n    {\n        using CommonType = typename std::common_type<OfType, T>::type;\n        return static_cast<CommonType>(val) <= static_cast<CommonType>((std::numeric_limits<OfType>::max)());\n    }\n};\n\ntemplate<typename OfType, typename T>\nstruct value_in_range_of_impl2<OfType, T, false, true>\n{\n    static constexpr bool test(T val)\n    {\n        using CommonType = typename std::common_type<OfType, T>::type;\n        return val >= 0 && static_cast<CommonType>(val) <= static_cast<CommonType>((std::numeric_limits<OfType>::max)());\n    }\n};\n\ntemplate<typename OfType, typename T>\nstruct value_in_range_of_impl2<OfType, T, true, true>\n{\n    static constexpr bool test(T val)\n    {\n        using CommonType = typename std::common_type<OfType, T>::type;\n        return static_cast<CommonType>(val) >= static_cast<CommonType>((std::numeric_limits<OfType>::min)())\n               && static_cast<CommonType>(val) <= 
static_cast<CommonType>((std::numeric_limits<OfType>::max)());\n    }\n};\n\ntemplate<typename OfType, typename T,\n         bool NeverOutOfRange = never_out_of_range<OfType, T>::value,\n         typename = detail::enable_if_t<all_integral<OfType, T>::value>>\nstruct value_in_range_of_impl1;\n\ntemplate<typename OfType, typename T>\nstruct value_in_range_of_impl1<OfType, T, false>\n{\n    static constexpr bool test(T val)\n    {\n        return value_in_range_of_impl2<OfType, T>::test(val);\n    }\n};\n\ntemplate<typename OfType, typename T>\nstruct value_in_range_of_impl1<OfType, T, true>\n{\n    static constexpr bool test(T /*val*/)\n    {\n        return true;\n    }\n};\n\ntemplate<typename OfType, typename T>\ninline constexpr bool value_in_range_of(T val)\n{\n    return value_in_range_of_impl1<OfType, T>::test(val);\n}\n\ntemplate<bool Value>\nusing bool_constant = std::integral_constant<bool, Value>;\n\n///////////////////////////////////////////////////////////////////////////////\n// is_c_string\n///////////////////////////////////////////////////////////////////////////////\n\nnamespace impl\n{\n\ntemplate<typename T>\ninline constexpr bool is_c_string()\n{\n    using TUnExt = typename std::remove_extent<T>::type;\n    using TUnCVExt = typename std::remove_cv<TUnExt>::type;\n    using TUnPtr = typename std::remove_pointer<T>::type;\n    using TUnCVPtr = typename std::remove_cv<TUnPtr>::type;\n    return\n        (std::is_array<T>::value && std::is_same<TUnCVExt, char>::value)\n        || (std::is_pointer<T>::value && std::is_same<TUnCVPtr, char>::value);\n}\n\n}  // namespace impl\n\n// checks whether T is a [cv] char */[cv] char[] C string\ntemplate<typename T>\nstruct is_c_string : bool_constant<impl::is_c_string<T>()> {};\n\ntemplate<typename T>\nusing is_c_string_uncvref = is_c_string<uncvref_t<T>>;\n\n///////////////////////////////////////////////////////////////////////////////\n// 
is_transparent\n///////////////////////////////////////////////////////////////////////////////\n\nnamespace impl\n{\n\ntemplate<typename T>\ninline constexpr bool is_transparent()\n{\n    return is_detected<detect_is_transparent, T>::value;\n}\n\n}  // namespace impl\n\n// checks whether T has a member named is_transparent\ntemplate<typename T>\nstruct is_transparent : bool_constant<impl::is_transparent<T>()> {};\n\n///////////////////////////////////////////////////////////////////////////////\n\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include <nlohmann/detail/string_concat.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#include <cstring> // strlen\n#include <string> // string\n#include <utility> // forward\n\n// #include <nlohmann/detail/meta/cpp_future.hpp>\n\n// #include <nlohmann/detail/meta/detected.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\n\ninline std::size_t concat_length()\n{\n    return 0;\n}\n\ntemplate<typename... Args>\ninline std::size_t concat_length(const char* cstr, const Args& ... rest);\n\ntemplate<typename StringType, typename... Args>\ninline std::size_t concat_length(const StringType& str, const Args& ... rest);\n\ntemplate<typename... Args>\ninline std::size_t concat_length(const char /*c*/, const Args& ... rest)\n{\n    return 1 + concat_length(rest...);\n}\n\ntemplate<typename... Args>\ninline std::size_t concat_length(const char* cstr, const Args& ... rest)\n{\n    // cppcheck-suppress ignoredReturnValue\n    return ::strlen(cstr) + concat_length(rest...);\n}\n\ntemplate<typename StringType, typename... Args>\ninline std::size_t concat_length(const StringType& str, const Args& ... 
rest)\n{\n    return str.size() + concat_length(rest...);\n}\n\ntemplate<typename OutStringType>\ninline void concat_into(OutStringType& /*out*/)\n{}\n\ntemplate<typename StringType, typename Arg>\nusing string_can_append = decltype(std::declval<StringType&>().append(std::declval < Arg && > ()));\n\ntemplate<typename StringType, typename Arg>\nusing detect_string_can_append = is_detected<string_can_append, StringType, Arg>;\n\ntemplate<typename StringType, typename Arg>\nusing string_can_append_op = decltype(std::declval<StringType&>() += std::declval < Arg && > ());\n\ntemplate<typename StringType, typename Arg>\nusing detect_string_can_append_op = is_detected<string_can_append_op, StringType, Arg>;\n\ntemplate<typename StringType, typename Arg>\nusing string_can_append_iter = decltype(std::declval<StringType&>().append(std::declval<const Arg&>().begin(), std::declval<const Arg&>().end()));\n\ntemplate<typename StringType, typename Arg>\nusing detect_string_can_append_iter = is_detected<string_can_append_iter, StringType, Arg>;\n\ntemplate<typename StringType, typename Arg>\nusing string_can_append_data = decltype(std::declval<StringType&>().append(std::declval<const Arg&>().data(), std::declval<const Arg&>().size()));\n\ntemplate<typename StringType, typename Arg>\nusing detect_string_can_append_data = is_detected<string_can_append_data, StringType, Arg>;\n\ntemplate < typename OutStringType, typename Arg, typename... Args,\n           enable_if_t < !detect_string_can_append<OutStringType, Arg>::value\n                         && detect_string_can_append_op<OutStringType, Arg>::value, int > = 0 >\ninline void concat_into(OutStringType& out, Arg && arg, Args && ... rest);\n\ntemplate < typename OutStringType, typename Arg, typename... 
Args,\n           enable_if_t < !detect_string_can_append<OutStringType, Arg>::value\n                         && !detect_string_can_append_op<OutStringType, Arg>::value\n                         && detect_string_can_append_iter<OutStringType, Arg>::value, int > = 0 >\ninline void concat_into(OutStringType& out, const Arg& arg, Args && ... rest);\n\ntemplate < typename OutStringType, typename Arg, typename... Args,\n           enable_if_t < !detect_string_can_append<OutStringType, Arg>::value\n                         && !detect_string_can_append_op<OutStringType, Arg>::value\n                         && !detect_string_can_append_iter<OutStringType, Arg>::value\n                         && detect_string_can_append_data<OutStringType, Arg>::value, int > = 0 >\ninline void concat_into(OutStringType& out, const Arg& arg, Args && ... rest);\n\ntemplate<typename OutStringType, typename Arg, typename... Args,\n         enable_if_t<detect_string_can_append<OutStringType, Arg>::value, int> = 0>\ninline void concat_into(OutStringType& out, Arg && arg, Args && ... rest)\n{\n    out.append(std::forward<Arg>(arg));\n    concat_into(out, std::forward<Args>(rest)...);\n}\n\ntemplate < typename OutStringType, typename Arg, typename... Args,\n           enable_if_t < !detect_string_can_append<OutStringType, Arg>::value\n                         && detect_string_can_append_op<OutStringType, Arg>::value, int > >\ninline void concat_into(OutStringType& out, Arg&& arg, Args&& ... rest)\n{\n    out += std::forward<Arg>(arg);\n    concat_into(out, std::forward<Args>(rest)...);\n}\n\ntemplate < typename OutStringType, typename Arg, typename... Args,\n           enable_if_t < !detect_string_can_append<OutStringType, Arg>::value\n                         && !detect_string_can_append_op<OutStringType, Arg>::value\n                         && detect_string_can_append_iter<OutStringType, Arg>::value, int > >\ninline void concat_into(OutStringType& out, const Arg& arg, Args&& ... 
rest)\n{\n    out.append(arg.begin(), arg.end());\n    concat_into(out, std::forward<Args>(rest)...);\n}\n\ntemplate < typename OutStringType, typename Arg, typename... Args,\n           enable_if_t < !detect_string_can_append<OutStringType, Arg>::value\n                         && !detect_string_can_append_op<OutStringType, Arg>::value\n                         && !detect_string_can_append_iter<OutStringType, Arg>::value\n                         && detect_string_can_append_data<OutStringType, Arg>::value, int > >\ninline void concat_into(OutStringType& out, const Arg& arg, Args&& ... rest)\n{\n    out.append(arg.data(), arg.size());\n    concat_into(out, std::forward<Args>(rest)...);\n}\n\ntemplate<typename OutStringType = std::string, typename... Args>\ninline OutStringType concat(Args && ... args)\n{\n    OutStringType str;\n    str.reserve(concat_length(args...));\n    concat_into(str, std::forward<Args>(args)...);\n    return str;\n}\n\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\n\n////////////////\n// exceptions //\n////////////////\n\n/// @brief general exception of the @ref basic_json class\n/// @sa https://json.nlohmann.me/api/basic_json/exception/\nclass exception : public std::exception\n{\n  public:\n    /// returns the explanatory string\n    const char* what() const noexcept override\n    {\n        return m.what();\n    }\n\n    /// the id of the exception\n    const int id; // NOLINT(cppcoreguidelines-non-private-member-variables-in-classes)\n\n  protected:\n    JSON_HEDLEY_NON_NULL(3)\n    exception(int id_, const char* what_arg) : id(id_), m(what_arg) {} // NOLINT(bugprone-throw-keyword-missing)\n\n    static std::string name(const std::string& ename, int id_)\n    {\n        return concat(\"[json.exception.\", ename, '.', std::to_string(id_), \"] \");\n    }\n\n    static std::string diagnostics(std::nullptr_t /*leaf_element*/)\n    {\n        return \"\";\n    }\n\n    
template<typename BasicJsonType>\n    static std::string diagnostics(const BasicJsonType* leaf_element)\n    {\n#if JSON_DIAGNOSTICS\n        std::vector<std::string> tokens;\n        for (const auto* current = leaf_element; current != nullptr && current->m_parent != nullptr; current = current->m_parent)\n        {\n            switch (current->m_parent->type())\n            {\n                case value_t::array:\n                {\n                    for (std::size_t i = 0; i < current->m_parent->m_data.m_value.array->size(); ++i)\n                    {\n                        if (&current->m_parent->m_data.m_value.array->operator[](i) == current)\n                        {\n                            tokens.emplace_back(std::to_string(i));\n                            break;\n                        }\n                    }\n                    break;\n                }\n\n                case value_t::object:\n                {\n                    for (const auto& element : *current->m_parent->m_data.m_value.object)\n                    {\n                        if (&element.second == current)\n                        {\n                            tokens.emplace_back(element.first.c_str());\n                            break;\n                        }\n                    }\n                    break;\n                }\n\n                case value_t::null: // LCOV_EXCL_LINE\n                case value_t::string: // LCOV_EXCL_LINE\n                case value_t::boolean: // LCOV_EXCL_LINE\n                case value_t::number_integer: // LCOV_EXCL_LINE\n                case value_t::number_unsigned: // LCOV_EXCL_LINE\n                case value_t::number_float: // LCOV_EXCL_LINE\n                case value_t::binary: // LCOV_EXCL_LINE\n                case value_t::discarded: // LCOV_EXCL_LINE\n                default:   // LCOV_EXCL_LINE\n                    break; // LCOV_EXCL_LINE\n            }\n        }\n\n        if (tokens.empty())\n        {\n   
         return \"\";\n        }\n\n        auto str = std::accumulate(tokens.rbegin(), tokens.rend(), std::string{},\n                                   [](const std::string & a, const std::string & b)\n        {\n            return concat(a, '/', detail::escape(b));\n        });\n        return concat('(', str, \") \");\n#else\n        static_cast<void>(leaf_element);\n        return \"\";\n#endif\n    }\n\n  private:\n    /// an exception object as storage for error messages\n    std::runtime_error m;\n};\n\n/// @brief exception indicating a parse error\n/// @sa https://json.nlohmann.me/api/basic_json/parse_error/\nclass parse_error : public exception\n{\n  public:\n    /*!\n    @brief create a parse error exception\n    @param[in] id_       the id of the exception\n    @param[in] pos       the position where the error occurred (or with\n                         chars_read_total=0 if the position cannot be\n                         determined)\n    @param[in] what_arg  the explanatory string\n    @return parse_error object\n    */\n    template<typename BasicJsonContext, enable_if_t<is_basic_json_context<BasicJsonContext>::value, int> = 0>\n    static parse_error create(int id_, const position_t& pos, const std::string& what_arg, BasicJsonContext context)\n    {\n        const std::string w = concat(exception::name(\"parse_error\", id_), \"parse error\",\n                                     position_string(pos), \": \", exception::diagnostics(context), what_arg);\n        return {id_, pos.chars_read_total, w.c_str()};\n    }\n\n    template<typename BasicJsonContext, enable_if_t<is_basic_json_context<BasicJsonContext>::value, int> = 0>\n    static parse_error create(int id_, std::size_t byte_, const std::string& what_arg, BasicJsonContext context)\n    {\n        const std::string w = concat(exception::name(\"parse_error\", id_), \"parse error\",\n                                     (byte_ != 0 ? 
(concat(\" at byte \", std::to_string(byte_))) : \"\"),\n                                     \": \", exception::diagnostics(context), what_arg);\n        return {id_, byte_, w.c_str()};\n    }\n\n    /*!\n    @brief byte index of the parse error\n\n    The byte index of the last read character in the input file.\n\n    @note For an input with n bytes, 1 is the index of the first character and\n          n+1 is the index of the terminating null byte or the end of file.\n          This also holds true when reading a byte vector (CBOR or MessagePack).\n    */\n    const std::size_t byte;\n\n  private:\n    parse_error(int id_, std::size_t byte_, const char* what_arg)\n        : exception(id_, what_arg), byte(byte_) {}\n\n    static std::string position_string(const position_t& pos)\n    {\n        return concat(\" at line \", std::to_string(pos.lines_read + 1),\n                      \", column \", std::to_string(pos.chars_read_current_line));\n    }\n};\n\n/// @brief exception indicating errors with iterators\n/// @sa https://json.nlohmann.me/api/basic_json/invalid_iterator/\nclass invalid_iterator : public exception\n{\n  public:\n    template<typename BasicJsonContext, enable_if_t<is_basic_json_context<BasicJsonContext>::value, int> = 0>\n    static invalid_iterator create(int id_, const std::string& what_arg, BasicJsonContext context)\n    {\n        const std::string w = concat(exception::name(\"invalid_iterator\", id_), exception::diagnostics(context), what_arg);\n        return {id_, w.c_str()};\n    }\n\n  private:\n    JSON_HEDLEY_NON_NULL(3)\n    invalid_iterator(int id_, const char* what_arg)\n        : exception(id_, what_arg) {}\n};\n\n/// @brief exception indicating executing a member function with a wrong type\n/// @sa https://json.nlohmann.me/api/basic_json/type_error/\nclass type_error : public exception\n{\n  public:\n    template<typename BasicJsonContext, enable_if_t<is_basic_json_context<BasicJsonContext>::value, int> = 0>\n    static type_error 
create(int id_, const std::string& what_arg, BasicJsonContext context)\n    {\n        const std::string w = concat(exception::name(\"type_error\", id_), exception::diagnostics(context), what_arg);\n        return {id_, w.c_str()};\n    }\n\n  private:\n    JSON_HEDLEY_NON_NULL(3)\n    type_error(int id_, const char* what_arg) : exception(id_, what_arg) {}\n};\n\n/// @brief exception indicating access out of the defined range\n/// @sa https://json.nlohmann.me/api/basic_json/out_of_range/\nclass out_of_range : public exception\n{\n  public:\n    template<typename BasicJsonContext, enable_if_t<is_basic_json_context<BasicJsonContext>::value, int> = 0>\n    static out_of_range create(int id_, const std::string& what_arg, BasicJsonContext context)\n    {\n        const std::string w = concat(exception::name(\"out_of_range\", id_), exception::diagnostics(context), what_arg);\n        return {id_, w.c_str()};\n    }\n\n  private:\n    JSON_HEDLEY_NON_NULL(3)\n    out_of_range(int id_, const char* what_arg) : exception(id_, what_arg) {}\n};\n\n/// @brief exception indicating other library errors\n/// @sa https://json.nlohmann.me/api/basic_json/other_error/\nclass other_error : public exception\n{\n  public:\n    template<typename BasicJsonContext, enable_if_t<is_basic_json_context<BasicJsonContext>::value, int> = 0>\n    static other_error create(int id_, const std::string& what_arg, BasicJsonContext context)\n    {\n        const std::string w = concat(exception::name(\"other_error\", id_), exception::diagnostics(context), what_arg);\n        return {id_, w.c_str()};\n    }\n\n  private:\n    JSON_HEDLEY_NON_NULL(3)\n    other_error(int id_, const char* what_arg) : exception(id_, what_arg) {}\n};\n\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include <nlohmann/detail/macro_scope.hpp>\n\n// #include <nlohmann/detail/meta/cpp_future.hpp>\n\n// #include <nlohmann/detail/meta/identity_tag.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for 
Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n// #include <nlohmann/detail/abi_macros.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\n\n// dispatching helper struct\ntemplate <class T> struct identity_tag {};\n\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include <nlohmann/detail/meta/std_fs.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n// #include <nlohmann/detail/macro_scope.hpp>\n\n\n#if JSON_HAS_EXPERIMENTAL_FILESYSTEM\n#include <experimental/filesystem>\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\nnamespace std_fs = std::experimental::filesystem;\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n#elif JSON_HAS_FILESYSTEM\n#include <filesystem>\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\nnamespace std_fs = std::filesystem;\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n#endif\n\n// #include <nlohmann/detail/meta/type_traits.hpp>\n\n// #include <nlohmann/detail/string_concat.hpp>\n\n// #include <nlohmann/detail/value_t.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\n\ntemplate<typename BasicJsonType>\ninline void from_json(const BasicJsonType& j, typename std::nullptr_t& n)\n{\n    if (JSON_HEDLEY_UNLIKELY(!j.is_null()))\n    {\n        JSON_THROW(type_error::create(302, concat(\"type must be null, but is \", j.type_name()), &j));\n    }\n    n = nullptr;\n}\n\n// overloads for basic_json template parameters\ntemplate < typename BasicJsonType, typename ArithmeticType,\n           enable_if_t < std::is_arithmetic<ArithmeticType>::value&&\n        
                 !std::is_same<ArithmeticType, typename BasicJsonType::boolean_t>::value,\n                         int > = 0 >\nvoid get_arithmetic_value(const BasicJsonType& j, ArithmeticType& val)\n{\n    switch (static_cast<value_t>(j))\n    {\n        case value_t::number_unsigned:\n        {\n            val = static_cast<ArithmeticType>(*j.template get_ptr<const typename BasicJsonType::number_unsigned_t*>());\n            break;\n        }\n        case value_t::number_integer:\n        {\n            val = static_cast<ArithmeticType>(*j.template get_ptr<const typename BasicJsonType::number_integer_t*>());\n            break;\n        }\n        case value_t::number_float:\n        {\n            val = static_cast<ArithmeticType>(*j.template get_ptr<const typename BasicJsonType::number_float_t*>());\n            break;\n        }\n\n        case value_t::null:\n        case value_t::object:\n        case value_t::array:\n        case value_t::string:\n        case value_t::boolean:\n        case value_t::binary:\n        case value_t::discarded:\n        default:\n            JSON_THROW(type_error::create(302, concat(\"type must be number, but is \", j.type_name()), &j));\n    }\n}\n\ntemplate<typename BasicJsonType>\ninline void from_json(const BasicJsonType& j, typename BasicJsonType::boolean_t& b)\n{\n    if (JSON_HEDLEY_UNLIKELY(!j.is_boolean()))\n    {\n        JSON_THROW(type_error::create(302, concat(\"type must be boolean, but is \", j.type_name()), &j));\n    }\n    b = *j.template get_ptr<const typename BasicJsonType::boolean_t*>();\n}\n\ntemplate<typename BasicJsonType>\ninline void from_json(const BasicJsonType& j, typename BasicJsonType::string_t& s)\n{\n    if (JSON_HEDLEY_UNLIKELY(!j.is_string()))\n    {\n        JSON_THROW(type_error::create(302, concat(\"type must be string, but is \", j.type_name()), &j));\n    }\n    s = *j.template get_ptr<const typename BasicJsonType::string_t*>();\n}\n\ntemplate <\n    typename BasicJsonType, typename 
StringType,\n    enable_if_t <\n        std::is_assignable<StringType&, const typename BasicJsonType::string_t>::value\n        && is_detected_exact<typename BasicJsonType::string_t::value_type, value_type_t, StringType>::value\n        && !std::is_same<typename BasicJsonType::string_t, StringType>::value\n        && !is_json_ref<StringType>::value, int > = 0 >\ninline void from_json(const BasicJsonType& j, StringType& s)\n{\n    if (JSON_HEDLEY_UNLIKELY(!j.is_string()))\n    {\n        JSON_THROW(type_error::create(302, concat(\"type must be string, but is \", j.type_name()), &j));\n    }\n\n    s = *j.template get_ptr<const typename BasicJsonType::string_t*>();\n}\n\ntemplate<typename BasicJsonType>\ninline void from_json(const BasicJsonType& j, typename BasicJsonType::number_float_t& val)\n{\n    get_arithmetic_value(j, val);\n}\n\ntemplate<typename BasicJsonType>\ninline void from_json(const BasicJsonType& j, typename BasicJsonType::number_unsigned_t& val)\n{\n    get_arithmetic_value(j, val);\n}\n\ntemplate<typename BasicJsonType>\ninline void from_json(const BasicJsonType& j, typename BasicJsonType::number_integer_t& val)\n{\n    get_arithmetic_value(j, val);\n}\n\n#if !JSON_DISABLE_ENUM_SERIALIZATION\ntemplate<typename BasicJsonType, typename EnumType,\n         enable_if_t<std::is_enum<EnumType>::value, int> = 0>\ninline void from_json(const BasicJsonType& j, EnumType& e)\n{\n    typename std::underlying_type<EnumType>::type val;\n    get_arithmetic_value(j, val);\n    e = static_cast<EnumType>(val);\n}\n#endif  // JSON_DISABLE_ENUM_SERIALIZATION\n\n// forward_list doesn't have an insert method\ntemplate<typename BasicJsonType, typename T, typename Allocator,\n         enable_if_t<is_getable<BasicJsonType, T>::value, int> = 0>\ninline void from_json(const BasicJsonType& j, std::forward_list<T, Allocator>& l)\n{\n    if (JSON_HEDLEY_UNLIKELY(!j.is_array()))\n    {\n        JSON_THROW(type_error::create(302, concat(\"type must be array, but is \", 
j.type_name()), &j));\n    }\n    l.clear();\n    std::transform(j.rbegin(), j.rend(),\n                   std::front_inserter(l), [](const BasicJsonType & i)\n    {\n        return i.template get<T>();\n    });\n}\n\n// valarray doesn't have an insert method\ntemplate<typename BasicJsonType, typename T,\n         enable_if_t<is_getable<BasicJsonType, T>::value, int> = 0>\ninline void from_json(const BasicJsonType& j, std::valarray<T>& l)\n{\n    if (JSON_HEDLEY_UNLIKELY(!j.is_array()))\n    {\n        JSON_THROW(type_error::create(302, concat(\"type must be array, but is \", j.type_name()), &j));\n    }\n    l.resize(j.size());\n    std::transform(j.begin(), j.end(), std::begin(l),\n                   [](const BasicJsonType & elem)\n    {\n        return elem.template get<T>();\n    });\n}\n\ntemplate<typename BasicJsonType, typename T, std::size_t N>\nauto from_json(const BasicJsonType& j, T (&arr)[N])  // NOLINT(cppcoreguidelines-avoid-c-arrays,hicpp-avoid-c-arrays,modernize-avoid-c-arrays)\n-> decltype(j.template get<T>(), void())\n{\n    for (std::size_t i = 0; i < N; ++i)\n    {\n        arr[i] = j.at(i).template get<T>();\n    }\n}\n\ntemplate<typename BasicJsonType>\ninline void from_json_array_impl(const BasicJsonType& j, typename BasicJsonType::array_t& arr, priority_tag<3> /*unused*/)\n{\n    arr = *j.template get_ptr<const typename BasicJsonType::array_t*>();\n}\n\ntemplate<typename BasicJsonType, typename T, std::size_t N>\nauto from_json_array_impl(const BasicJsonType& j, std::array<T, N>& arr,\n                          priority_tag<2> /*unused*/)\n-> decltype(j.template get<T>(), void())\n{\n    for (std::size_t i = 0; i < N; ++i)\n    {\n        arr[i] = j.at(i).template get<T>();\n    }\n}\n\ntemplate<typename BasicJsonType, typename ConstructibleArrayType,\n         enable_if_t<\n             std::is_assignable<ConstructibleArrayType&, ConstructibleArrayType>::value,\n             int> = 0>\nauto from_json_array_impl(const BasicJsonType& j, 
ConstructibleArrayType& arr, priority_tag<1> /*unused*/)\n-> decltype(\n    arr.reserve(std::declval<typename ConstructibleArrayType::size_type>()),\n    j.template get<typename ConstructibleArrayType::value_type>(),\n    void())\n{\n    using std::end;\n\n    ConstructibleArrayType ret;\n    ret.reserve(j.size());\n    std::transform(j.begin(), j.end(),\n                   std::inserter(ret, end(ret)), [](const BasicJsonType & i)\n    {\n        // get<BasicJsonType>() returns *this, this won't call a from_json\n        // method when value_type is BasicJsonType\n        return i.template get<typename ConstructibleArrayType::value_type>();\n    });\n    arr = std::move(ret);\n}\n\ntemplate<typename BasicJsonType, typename ConstructibleArrayType,\n         enable_if_t<\n             std::is_assignable<ConstructibleArrayType&, ConstructibleArrayType>::value,\n             int> = 0>\ninline void from_json_array_impl(const BasicJsonType& j, ConstructibleArrayType& arr,\n                                 priority_tag<0> /*unused*/)\n{\n    using std::end;\n\n    ConstructibleArrayType ret;\n    std::transform(\n        j.begin(), j.end(), std::inserter(ret, end(ret)),\n        [](const BasicJsonType & i)\n    {\n        // get<BasicJsonType>() returns *this, this won't call a from_json\n        // method when value_type is BasicJsonType\n        return i.template get<typename ConstructibleArrayType::value_type>();\n    });\n    arr = std::move(ret);\n}\n\ntemplate < typename BasicJsonType, typename ConstructibleArrayType,\n           enable_if_t <\n               is_constructible_array_type<BasicJsonType, ConstructibleArrayType>::value&&\n               !is_constructible_object_type<BasicJsonType, ConstructibleArrayType>::value&&\n               !is_constructible_string_type<BasicJsonType, ConstructibleArrayType>::value&&\n               !std::is_same<ConstructibleArrayType, typename BasicJsonType::binary_t>::value&&\n               
!is_basic_json<ConstructibleArrayType>::value,\n               int > = 0 >\nauto from_json(const BasicJsonType& j, ConstructibleArrayType& arr)\n-> decltype(from_json_array_impl(j, arr, priority_tag<3> {}),\nj.template get<typename ConstructibleArrayType::value_type>(),\nvoid())\n{\n    if (JSON_HEDLEY_UNLIKELY(!j.is_array()))\n    {\n        JSON_THROW(type_error::create(302, concat(\"type must be array, but is \", j.type_name()), &j));\n    }\n\n    from_json_array_impl(j, arr, priority_tag<3> {});\n}\n\ntemplate < typename BasicJsonType, typename T, std::size_t... Idx >\nstd::array<T, sizeof...(Idx)> from_json_inplace_array_impl(BasicJsonType&& j,\n        identity_tag<std::array<T, sizeof...(Idx)>> /*unused*/, index_sequence<Idx...> /*unused*/)\n{\n    return { { std::forward<BasicJsonType>(j).at(Idx).template get<T>()... } };\n}\n\ntemplate < typename BasicJsonType, typename T, std::size_t N >\nauto from_json(BasicJsonType&& j, identity_tag<std::array<T, N>> tag)\n-> decltype(from_json_inplace_array_impl(std::forward<BasicJsonType>(j), tag, make_index_sequence<N> {}))\n{\n    if (JSON_HEDLEY_UNLIKELY(!j.is_array()))\n    {\n        JSON_THROW(type_error::create(302, concat(\"type must be array, but is \", j.type_name()), &j));\n    }\n\n    return from_json_inplace_array_impl(std::forward<BasicJsonType>(j), tag, make_index_sequence<N> {});\n}\n\ntemplate<typename BasicJsonType>\ninline void from_json(const BasicJsonType& j, typename BasicJsonType::binary_t& bin)\n{\n    if (JSON_HEDLEY_UNLIKELY(!j.is_binary()))\n    {\n        JSON_THROW(type_error::create(302, concat(\"type must be binary, but is \", j.type_name()), &j));\n    }\n\n    bin = *j.template get_ptr<const typename BasicJsonType::binary_t*>();\n}\n\ntemplate<typename BasicJsonType, typename ConstructibleObjectType,\n         enable_if_t<is_constructible_object_type<BasicJsonType, ConstructibleObjectType>::value, int> = 0>\ninline void from_json(const BasicJsonType& j, ConstructibleObjectType& 
obj)\n{\n    if (JSON_HEDLEY_UNLIKELY(!j.is_object()))\n    {\n        JSON_THROW(type_error::create(302, concat(\"type must be object, but is \", j.type_name()), &j));\n    }\n\n    ConstructibleObjectType ret;\n    const auto* inner_object = j.template get_ptr<const typename BasicJsonType::object_t*>();\n    using value_type = typename ConstructibleObjectType::value_type;\n    std::transform(\n        inner_object->begin(), inner_object->end(),\n        std::inserter(ret, ret.begin()),\n        [](typename BasicJsonType::object_t::value_type const & p)\n    {\n        return value_type(p.first, p.second.template get<typename ConstructibleObjectType::mapped_type>());\n    });\n    obj = std::move(ret);\n}\n\n// overload for arithmetic types, not chosen for basic_json template arguments\n// (BooleanType, etc..); note: Is it really necessary to provide explicit\n// overloads for boolean_t etc. in case of a custom BooleanType which is not\n// an arithmetic type?\ntemplate < typename BasicJsonType, typename ArithmeticType,\n           enable_if_t <\n               std::is_arithmetic<ArithmeticType>::value&&\n               !std::is_same<ArithmeticType, typename BasicJsonType::number_unsigned_t>::value&&\n               !std::is_same<ArithmeticType, typename BasicJsonType::number_integer_t>::value&&\n               !std::is_same<ArithmeticType, typename BasicJsonType::number_float_t>::value&&\n               !std::is_same<ArithmeticType, typename BasicJsonType::boolean_t>::value,\n               int > = 0 >\ninline void from_json(const BasicJsonType& j, ArithmeticType& val)\n{\n    switch (static_cast<value_t>(j))\n    {\n        case value_t::number_unsigned:\n        {\n            val = static_cast<ArithmeticType>(*j.template get_ptr<const typename BasicJsonType::number_unsigned_t*>());\n            break;\n        }\n        case value_t::number_integer:\n        {\n            val = static_cast<ArithmeticType>(*j.template get_ptr<const typename 
BasicJsonType::number_integer_t*>());\n            break;\n        }\n        case value_t::number_float:\n        {\n            val = static_cast<ArithmeticType>(*j.template get_ptr<const typename BasicJsonType::number_float_t*>());\n            break;\n        }\n        case value_t::boolean:\n        {\n            val = static_cast<ArithmeticType>(*j.template get_ptr<const typename BasicJsonType::boolean_t*>());\n            break;\n        }\n\n        case value_t::null:\n        case value_t::object:\n        case value_t::array:\n        case value_t::string:\n        case value_t::binary:\n        case value_t::discarded:\n        default:\n            JSON_THROW(type_error::create(302, concat(\"type must be number, but is \", j.type_name()), &j));\n    }\n}\n\ntemplate<typename BasicJsonType, typename... Args, std::size_t... Idx>\nstd::tuple<Args...> from_json_tuple_impl_base(BasicJsonType&& j, index_sequence<Idx...> /*unused*/)\n{\n    return std::make_tuple(std::forward<BasicJsonType>(j).at(Idx).template get<Args>()...);\n}\n\ntemplate < typename BasicJsonType, class A1, class A2 >\nstd::pair<A1, A2> from_json_tuple_impl(BasicJsonType&& j, identity_tag<std::pair<A1, A2>> /*unused*/, priority_tag<0> /*unused*/)\n{\n    return {std::forward<BasicJsonType>(j).at(0).template get<A1>(),\n            std::forward<BasicJsonType>(j).at(1).template get<A2>()};\n}\n\ntemplate<typename BasicJsonType, typename A1, typename A2>\ninline void from_json_tuple_impl(BasicJsonType&& j, std::pair<A1, A2>& p, priority_tag<1> /*unused*/)\n{\n    p = from_json_tuple_impl(std::forward<BasicJsonType>(j), identity_tag<std::pair<A1, A2>> {}, priority_tag<0> {});\n}\n\ntemplate<typename BasicJsonType, typename... 
Args>\nstd::tuple<Args...> from_json_tuple_impl(BasicJsonType&& j, identity_tag<std::tuple<Args...>> /*unused*/, priority_tag<2> /*unused*/)\n{\n    return from_json_tuple_impl_base<BasicJsonType, Args...>(std::forward<BasicJsonType>(j), index_sequence_for<Args...> {});\n}\n\ntemplate<typename BasicJsonType, typename... Args>\ninline void from_json_tuple_impl(BasicJsonType&& j, std::tuple<Args...>& t, priority_tag<3> /*unused*/)\n{\n    t = from_json_tuple_impl_base<BasicJsonType, Args...>(std::forward<BasicJsonType>(j), index_sequence_for<Args...> {});\n}\n\ntemplate<typename BasicJsonType, typename TupleRelated>\nauto from_json(BasicJsonType&& j, TupleRelated&& t)\n-> decltype(from_json_tuple_impl(std::forward<BasicJsonType>(j), std::forward<TupleRelated>(t), priority_tag<3> {}))\n{\n    if (JSON_HEDLEY_UNLIKELY(!j.is_array()))\n    {\n        JSON_THROW(type_error::create(302, concat(\"type must be array, but is \", j.type_name()), &j));\n    }\n\n    return from_json_tuple_impl(std::forward<BasicJsonType>(j), std::forward<TupleRelated>(t), priority_tag<3> {});\n}\n\ntemplate < typename BasicJsonType, typename Key, typename Value, typename Compare, typename Allocator,\n           typename = enable_if_t < !std::is_constructible <\n                                        typename BasicJsonType::string_t, Key >::value >>\ninline void from_json(const BasicJsonType& j, std::map<Key, Value, Compare, Allocator>& m)\n{\n    if (JSON_HEDLEY_UNLIKELY(!j.is_array()))\n    {\n        JSON_THROW(type_error::create(302, concat(\"type must be array, but is \", j.type_name()), &j));\n    }\n    m.clear();\n    for (const auto& p : j)\n    {\n        if (JSON_HEDLEY_UNLIKELY(!p.is_array()))\n        {\n            JSON_THROW(type_error::create(302, concat(\"type must be array, but is \", p.type_name()), &j));\n        }\n        m.emplace(p.at(0).template get<Key>(), p.at(1).template get<Value>());\n    }\n}\n\ntemplate < typename BasicJsonType, typename Key, typename Value, 
typename Hash, typename KeyEqual, typename Allocator,\n           typename = enable_if_t < !std::is_constructible <\n                                        typename BasicJsonType::string_t, Key >::value >>\ninline void from_json(const BasicJsonType& j, std::unordered_map<Key, Value, Hash, KeyEqual, Allocator>& m)\n{\n    if (JSON_HEDLEY_UNLIKELY(!j.is_array()))\n    {\n        JSON_THROW(type_error::create(302, concat(\"type must be array, but is \", j.type_name()), &j));\n    }\n    m.clear();\n    for (const auto& p : j)\n    {\n        if (JSON_HEDLEY_UNLIKELY(!p.is_array()))\n        {\n            JSON_THROW(type_error::create(302, concat(\"type must be array, but is \", p.type_name()), &j));\n        }\n        m.emplace(p.at(0).template get<Key>(), p.at(1).template get<Value>());\n    }\n}\n\n#if JSON_HAS_FILESYSTEM || JSON_HAS_EXPERIMENTAL_FILESYSTEM\ntemplate<typename BasicJsonType>\ninline void from_json(const BasicJsonType& j, std_fs::path& p)\n{\n    if (JSON_HEDLEY_UNLIKELY(!j.is_string()))\n    {\n        JSON_THROW(type_error::create(302, concat(\"type must be string, but is \", j.type_name()), &j));\n    }\n    p = *j.template get_ptr<const typename BasicJsonType::string_t*>();\n}\n#endif\n\nstruct from_json_fn\n{\n    template<typename BasicJsonType, typename T>\n    auto operator()(const BasicJsonType& j, T&& val) const\n    noexcept(noexcept(from_json(j, std::forward<T>(val))))\n    -> decltype(from_json(j, std::forward<T>(val)))\n    {\n        return from_json(j, std::forward<T>(val));\n    }\n};\n\n}  // namespace detail\n\n#ifndef JSON_HAS_CPP_17\n/// namespace to hold default `from_json` function\n/// to see why this is required:\n/// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2015/n4381.html\nnamespace // NOLINT(cert-dcl59-cpp,fuchsia-header-anon-namespaces,google-build-namespaces)\n{\n#endif\nJSON_INLINE_VARIABLE constexpr const auto& from_json = // NOLINT(misc-definitions-in-headers)\n    
detail::static_const<detail::from_json_fn>::value;\n#ifndef JSON_HAS_CPP_17\n}  // namespace\n#endif\n\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include <nlohmann/detail/conversions/to_json.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#include <algorithm> // copy\n#include <iterator> // begin, end\n#include <string> // string\n#include <tuple> // tuple, get\n#include <type_traits> // is_same, is_constructible, is_floating_point, is_enum, underlying_type\n#include <utility> // move, forward, declval, pair\n#include <valarray> // valarray\n#include <vector> // vector\n\n// #include <nlohmann/detail/iterators/iteration_proxy.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#include <cstddef> // size_t\n#include <iterator> // input_iterator_tag\n#include <string> // string, to_string\n#include <tuple> // tuple_size, get, tuple_element\n#include <utility> // move\n\n#if JSON_HAS_RANGES\n    #include <ranges> // enable_borrowed_range\n#endif\n\n// #include <nlohmann/detail/abi_macros.hpp>\n\n// #include <nlohmann/detail/meta/type_traits.hpp>\n\n// #include <nlohmann/detail/value_t.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\n\ntemplate<typename string_type>\nvoid int_to_string( string_type& target, std::size_t value )\n{\n    // For ADL\n    using std::to_string;\n    target = to_string(value);\n}\ntemplate<typename IteratorType> class iteration_proxy_value\n{\n  public:\n    using difference_type = std::ptrdiff_t;\n    using value_type = 
iteration_proxy_value;\n    using pointer = value_type *;\n    using reference = value_type &;\n    using iterator_category = std::input_iterator_tag;\n    using string_type = typename std::remove_cv< typename std::remove_reference<decltype( std::declval<IteratorType>().key() ) >::type >::type;\n\n  private:\n    /// the iterator\n    IteratorType anchor{};\n    /// an index for arrays (used to create key names)\n    std::size_t array_index = 0;\n    /// last stringified array index\n    mutable std::size_t array_index_last = 0;\n    /// a string representation of the array index\n    mutable string_type array_index_str = \"0\";\n    /// an empty string (to return a reference for primitive values)\n    string_type empty_str{};\n\n  public:\n    explicit iteration_proxy_value() = default;\n    explicit iteration_proxy_value(IteratorType it, std::size_t array_index_ = 0)\n    noexcept(std::is_nothrow_move_constructible<IteratorType>::value\n             && std::is_nothrow_default_constructible<string_type>::value)\n        : anchor(std::move(it))\n        , array_index(array_index_)\n    {}\n\n    iteration_proxy_value(iteration_proxy_value const&) = default;\n    iteration_proxy_value& operator=(iteration_proxy_value const&) = default;\n    // older GCCs are a bit fussy and require explicit noexcept specifiers on defaulted functions\n    iteration_proxy_value(iteration_proxy_value&&)\n    noexcept(std::is_nothrow_move_constructible<IteratorType>::value\n             && std::is_nothrow_move_constructible<string_type>::value) = default; // NOLINT(hicpp-noexcept-move,performance-noexcept-move-constructor,cppcoreguidelines-noexcept-move-operations)\n    iteration_proxy_value& operator=(iteration_proxy_value&&)\n    noexcept(std::is_nothrow_move_assignable<IteratorType>::value\n             && std::is_nothrow_move_assignable<string_type>::value) = default; // NOLINT(hicpp-noexcept-move,performance-noexcept-move-constructor,cppcoreguidelines-noexcept-move-operations)\n    
~iteration_proxy_value() = default;\n\n    /// dereference operator (needed for range-based for)\n    const iteration_proxy_value& operator*() const\n    {\n        return *this;\n    }\n\n    /// increment operator (needed for range-based for)\n    iteration_proxy_value& operator++()\n    {\n        ++anchor;\n        ++array_index;\n\n        return *this;\n    }\n\n    iteration_proxy_value operator++(int)& // NOLINT(cert-dcl21-cpp)\n    {\n        auto tmp = iteration_proxy_value(anchor, array_index);\n        ++anchor;\n        ++array_index;\n        return tmp;\n    }\n\n    /// equality operator (needed for InputIterator)\n    bool operator==(const iteration_proxy_value& o) const\n    {\n        return anchor == o.anchor;\n    }\n\n    /// inequality operator (needed for range-based for)\n    bool operator!=(const iteration_proxy_value& o) const\n    {\n        return anchor != o.anchor;\n    }\n\n    /// return key of the iterator\n    const string_type& key() const\n    {\n        JSON_ASSERT(anchor.m_object != nullptr);\n\n        switch (anchor.m_object->type())\n        {\n            // use integer array index as key\n            case value_t::array:\n            {\n                if (array_index != array_index_last)\n                {\n                    int_to_string( array_index_str, array_index );\n                    array_index_last = array_index;\n                }\n                return array_index_str;\n            }\n\n            // use key from the object\n            case value_t::object:\n                return anchor.key();\n\n            // use an empty key for all primitive types\n            case value_t::null:\n            case value_t::string:\n            case value_t::boolean:\n            case value_t::number_integer:\n            case value_t::number_unsigned:\n            case value_t::number_float:\n            case value_t::binary:\n            case value_t::discarded:\n            default:\n                return 
empty_str;\n        }\n    }\n\n    /// return value of the iterator\n    typename IteratorType::reference value() const\n    {\n        return anchor.value();\n    }\n};\n\n/// proxy class for the items() function\ntemplate<typename IteratorType> class iteration_proxy\n{\n  private:\n    /// the container to iterate\n    typename IteratorType::pointer container = nullptr;\n\n  public:\n    explicit iteration_proxy() = default;\n\n    /// construct iteration proxy from a container\n    explicit iteration_proxy(typename IteratorType::reference cont) noexcept\n        : container(&cont) {}\n\n    iteration_proxy(iteration_proxy const&) = default;\n    iteration_proxy& operator=(iteration_proxy const&) = default;\n    iteration_proxy(iteration_proxy&&) noexcept = default;\n    iteration_proxy& operator=(iteration_proxy&&) noexcept = default;\n    ~iteration_proxy() = default;\n\n    /// return iterator begin (needed for range-based for)\n    iteration_proxy_value<IteratorType> begin() const noexcept\n    {\n        return iteration_proxy_value<IteratorType>(container->begin());\n    }\n\n    /// return iterator end (needed for range-based for)\n    iteration_proxy_value<IteratorType> end() const noexcept\n    {\n        return iteration_proxy_value<IteratorType>(container->end());\n    }\n};\n\n// Structured Bindings Support\n// For further reference see https://blog.tartanllama.xyz/structured-bindings/\n// And see https://github.com/nlohmann/json/pull/1391\ntemplate<std::size_t N, typename IteratorType, enable_if_t<N == 0, int> = 0>\nauto get(const nlohmann::detail::iteration_proxy_value<IteratorType>& i) -> decltype(i.key())\n{\n    return i.key();\n}\n// Structured Bindings Support\n// For further reference see https://blog.tartanllama.xyz/structured-bindings/\n// And see https://github.com/nlohmann/json/pull/1391\ntemplate<std::size_t N, typename IteratorType, enable_if_t<N == 1, int> = 0>\nauto get(const nlohmann::detail::iteration_proxy_value<IteratorType>& i) 
-> decltype(i.value())\n{\n    return i.value();\n}\n\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n\n// The Addition to the STD Namespace is required to add\n// Structured Bindings Support to the iteration_proxy_value class\n// For further reference see https://blog.tartanllama.xyz/structured-bindings/\n// And see https://github.com/nlohmann/json/pull/1391\nnamespace std\n{\n\n#if defined(__clang__)\n    // Fix: https://github.com/nlohmann/json/issues/1401\n    #pragma clang diagnostic push\n    #pragma clang diagnostic ignored \"-Wmismatched-tags\"\n#endif\ntemplate<typename IteratorType>\nclass tuple_size<::nlohmann::detail::iteration_proxy_value<IteratorType>> // NOLINT(cert-dcl58-cpp)\n            : public std::integral_constant<std::size_t, 2> {};\n\ntemplate<std::size_t N, typename IteratorType>\nclass tuple_element<N, ::nlohmann::detail::iteration_proxy_value<IteratorType >> // NOLINT(cert-dcl58-cpp)\n{\n  public:\n    using type = decltype(\n                     get<N>(std::declval <\n                            ::nlohmann::detail::iteration_proxy_value<IteratorType >> ()));\n};\n#if defined(__clang__)\n    #pragma clang diagnostic pop\n#endif\n\n}  // namespace std\n\n#if JSON_HAS_RANGES\n    template <typename IteratorType>\n    inline constexpr bool ::std::ranges::enable_borrowed_range<::nlohmann::detail::iteration_proxy<IteratorType>> = true;\n#endif\n\n// #include <nlohmann/detail/macro_scope.hpp>\n\n// #include <nlohmann/detail/meta/cpp_future.hpp>\n\n// #include <nlohmann/detail/meta/std_fs.hpp>\n\n// #include <nlohmann/detail/meta/type_traits.hpp>\n\n// #include <nlohmann/detail/value_t.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\n\n//////////////////\n// constructors //\n//////////////////\n\n/*\n * Note all external_constructor<>::construct functions need to call\n * j.m_data.m_value.destroy(j.m_data.m_type) to avoid a memory leak in case j contains an\n * allocated value (e.g., a string). 
See bug issue\n * https://github.com/nlohmann/json/issues/2865 for more information.\n */\n\ntemplate<value_t> struct external_constructor;\n\ntemplate<>\nstruct external_constructor<value_t::boolean>\n{\n    template<typename BasicJsonType>\n    static void construct(BasicJsonType& j, typename BasicJsonType::boolean_t b) noexcept\n    {\n        j.m_data.m_value.destroy(j.m_data.m_type);\n        j.m_data.m_type = value_t::boolean;\n        j.m_data.m_value = b;\n        j.assert_invariant();\n    }\n};\n\ntemplate<>\nstruct external_constructor<value_t::string>\n{\n    template<typename BasicJsonType>\n    static void construct(BasicJsonType& j, const typename BasicJsonType::string_t& s)\n    {\n        j.m_data.m_value.destroy(j.m_data.m_type);\n        j.m_data.m_type = value_t::string;\n        j.m_data.m_value = s;\n        j.assert_invariant();\n    }\n\n    template<typename BasicJsonType>\n    static void construct(BasicJsonType& j, typename BasicJsonType::string_t&& s)\n    {\n        j.m_data.m_value.destroy(j.m_data.m_type);\n        j.m_data.m_type = value_t::string;\n        j.m_data.m_value = std::move(s);\n        j.assert_invariant();\n    }\n\n    template < typename BasicJsonType, typename CompatibleStringType,\n               enable_if_t < !std::is_same<CompatibleStringType, typename BasicJsonType::string_t>::value,\n                             int > = 0 >\n    static void construct(BasicJsonType& j, const CompatibleStringType& str)\n    {\n        j.m_data.m_value.destroy(j.m_data.m_type);\n        j.m_data.m_type = value_t::string;\n        j.m_data.m_value.string = j.template create<typename BasicJsonType::string_t>(str);\n        j.assert_invariant();\n    }\n};\n\ntemplate<>\nstruct external_constructor<value_t::binary>\n{\n    template<typename BasicJsonType>\n    static void construct(BasicJsonType& j, const typename BasicJsonType::binary_t& b)\n    {\n        j.m_data.m_value.destroy(j.m_data.m_type);\n        j.m_data.m_type = 
value_t::binary;\n        j.m_data.m_value = typename BasicJsonType::binary_t(b);\n        j.assert_invariant();\n    }\n\n    template<typename BasicJsonType>\n    static void construct(BasicJsonType& j, typename BasicJsonType::binary_t&& b)\n    {\n        j.m_data.m_value.destroy(j.m_data.m_type);\n        j.m_data.m_type = value_t::binary;\n        j.m_data.m_value = typename BasicJsonType::binary_t(std::move(b));\n        j.assert_invariant();\n    }\n};\n\ntemplate<>\nstruct external_constructor<value_t::number_float>\n{\n    template<typename BasicJsonType>\n    static void construct(BasicJsonType& j, typename BasicJsonType::number_float_t val) noexcept\n    {\n        j.m_data.m_value.destroy(j.m_data.m_type);\n        j.m_data.m_type = value_t::number_float;\n        j.m_data.m_value = val;\n        j.assert_invariant();\n    }\n};\n\ntemplate<>\nstruct external_constructor<value_t::number_unsigned>\n{\n    template<typename BasicJsonType>\n    static void construct(BasicJsonType& j, typename BasicJsonType::number_unsigned_t val) noexcept\n    {\n        j.m_data.m_value.destroy(j.m_data.m_type);\n        j.m_data.m_type = value_t::number_unsigned;\n        j.m_data.m_value = val;\n        j.assert_invariant();\n    }\n};\n\ntemplate<>\nstruct external_constructor<value_t::number_integer>\n{\n    template<typename BasicJsonType>\n    static void construct(BasicJsonType& j, typename BasicJsonType::number_integer_t val) noexcept\n    {\n        j.m_data.m_value.destroy(j.m_data.m_type);\n        j.m_data.m_type = value_t::number_integer;\n        j.m_data.m_value = val;\n        j.assert_invariant();\n    }\n};\n\ntemplate<>\nstruct external_constructor<value_t::array>\n{\n    template<typename BasicJsonType>\n    static void construct(BasicJsonType& j, const typename BasicJsonType::array_t& arr)\n    {\n        j.m_data.m_value.destroy(j.m_data.m_type);\n        j.m_data.m_type = value_t::array;\n        j.m_data.m_value = arr;\n        j.set_parents();\n   
     j.assert_invariant();\n    }\n\n    template<typename BasicJsonType>\n    static void construct(BasicJsonType& j, typename BasicJsonType::array_t&& arr)\n    {\n        j.m_data.m_value.destroy(j.m_data.m_type);\n        j.m_data.m_type = value_t::array;\n        j.m_data.m_value = std::move(arr);\n        j.set_parents();\n        j.assert_invariant();\n    }\n\n    template < typename BasicJsonType, typename CompatibleArrayType,\n               enable_if_t < !std::is_same<CompatibleArrayType, typename BasicJsonType::array_t>::value,\n                             int > = 0 >\n    static void construct(BasicJsonType& j, const CompatibleArrayType& arr)\n    {\n        using std::begin;\n        using std::end;\n\n        j.m_data.m_value.destroy(j.m_data.m_type);\n        j.m_data.m_type = value_t::array;\n        j.m_data.m_value.array = j.template create<typename BasicJsonType::array_t>(begin(arr), end(arr));\n        j.set_parents();\n        j.assert_invariant();\n    }\n\n    template<typename BasicJsonType>\n    static void construct(BasicJsonType& j, const std::vector<bool>& arr)\n    {\n        j.m_data.m_value.destroy(j.m_data.m_type);\n        j.m_data.m_type = value_t::array;\n        j.m_data.m_value = value_t::array;\n        j.m_data.m_value.array->reserve(arr.size());\n        for (const bool x : arr)\n        {\n            j.m_data.m_value.array->push_back(x);\n            j.set_parent(j.m_data.m_value.array->back());\n        }\n        j.assert_invariant();\n    }\n\n    template<typename BasicJsonType, typename T,\n             enable_if_t<std::is_convertible<T, BasicJsonType>::value, int> = 0>\n    static void construct(BasicJsonType& j, const std::valarray<T>& arr)\n    {\n        j.m_data.m_value.destroy(j.m_data.m_type);\n        j.m_data.m_type = value_t::array;\n        j.m_data.m_value = value_t::array;\n        j.m_data.m_value.array->resize(arr.size());\n        if (arr.size() > 0)\n        {\n            std::copy(std::begin(arr), 
std::end(arr), j.m_data.m_value.array->begin());\n        }\n        j.set_parents();\n        j.assert_invariant();\n    }\n};\n\ntemplate<>\nstruct external_constructor<value_t::object>\n{\n    template<typename BasicJsonType>\n    static void construct(BasicJsonType& j, const typename BasicJsonType::object_t& obj)\n    {\n        j.m_data.m_value.destroy(j.m_data.m_type);\n        j.m_data.m_type = value_t::object;\n        j.m_data.m_value = obj;\n        j.set_parents();\n        j.assert_invariant();\n    }\n\n    template<typename BasicJsonType>\n    static void construct(BasicJsonType& j, typename BasicJsonType::object_t&& obj)\n    {\n        j.m_data.m_value.destroy(j.m_data.m_type);\n        j.m_data.m_type = value_t::object;\n        j.m_data.m_value = std::move(obj);\n        j.set_parents();\n        j.assert_invariant();\n    }\n\n    template < typename BasicJsonType, typename CompatibleObjectType,\n               enable_if_t < !std::is_same<CompatibleObjectType, typename BasicJsonType::object_t>::value, int > = 0 >\n    static void construct(BasicJsonType& j, const CompatibleObjectType& obj)\n    {\n        using std::begin;\n        using std::end;\n\n        j.m_data.m_value.destroy(j.m_data.m_type);\n        j.m_data.m_type = value_t::object;\n        j.m_data.m_value.object = j.template create<typename BasicJsonType::object_t>(begin(obj), end(obj));\n        j.set_parents();\n        j.assert_invariant();\n    }\n};\n\n/////////////\n// to_json //\n/////////////\n\ntemplate<typename BasicJsonType, typename T,\n         enable_if_t<std::is_same<T, typename BasicJsonType::boolean_t>::value, int> = 0>\ninline void to_json(BasicJsonType& j, T b) noexcept\n{\n    external_constructor<value_t::boolean>::construct(j, b);\n}\n\ntemplate < typename BasicJsonType, typename BoolRef,\n           enable_if_t <\n               ((std::is_same<std::vector<bool>::reference, BoolRef>::value\n                 && !std::is_same <std::vector<bool>::reference, 
typename BasicJsonType::boolean_t&>::value)\n                || (std::is_same<std::vector<bool>::const_reference, BoolRef>::value\n                    && !std::is_same <detail::uncvref_t<std::vector<bool>::const_reference>,\n                                      typename BasicJsonType::boolean_t >::value))\n               && std::is_convertible<const BoolRef&, typename BasicJsonType::boolean_t>::value, int > = 0 >\ninline void to_json(BasicJsonType& j, const BoolRef& b) noexcept\n{\n    external_constructor<value_t::boolean>::construct(j, static_cast<typename BasicJsonType::boolean_t>(b));\n}\n\ntemplate<typename BasicJsonType, typename CompatibleString,\n         enable_if_t<std::is_constructible<typename BasicJsonType::string_t, CompatibleString>::value, int> = 0>\ninline void to_json(BasicJsonType& j, const CompatibleString& s)\n{\n    external_constructor<value_t::string>::construct(j, s);\n}\n\ntemplate<typename BasicJsonType>\ninline void to_json(BasicJsonType& j, typename BasicJsonType::string_t&& s)\n{\n    external_constructor<value_t::string>::construct(j, std::move(s));\n}\n\ntemplate<typename BasicJsonType, typename FloatType,\n         enable_if_t<std::is_floating_point<FloatType>::value, int> = 0>\ninline void to_json(BasicJsonType& j, FloatType val) noexcept\n{\n    external_constructor<value_t::number_float>::construct(j, static_cast<typename BasicJsonType::number_float_t>(val));\n}\n\ntemplate<typename BasicJsonType, typename CompatibleNumberUnsignedType,\n         enable_if_t<is_compatible_integer_type<typename BasicJsonType::number_unsigned_t, CompatibleNumberUnsignedType>::value, int> = 0>\ninline void to_json(BasicJsonType& j, CompatibleNumberUnsignedType val) noexcept\n{\n    external_constructor<value_t::number_unsigned>::construct(j, static_cast<typename BasicJsonType::number_unsigned_t>(val));\n}\n\ntemplate<typename BasicJsonType, typename CompatibleNumberIntegerType,\n         enable_if_t<is_compatible_integer_type<typename 
BasicJsonType::number_integer_t, CompatibleNumberIntegerType>::value, int> = 0>\ninline void to_json(BasicJsonType& j, CompatibleNumberIntegerType val) noexcept\n{\n    external_constructor<value_t::number_integer>::construct(j, static_cast<typename BasicJsonType::number_integer_t>(val));\n}\n\n#if !JSON_DISABLE_ENUM_SERIALIZATION\ntemplate<typename BasicJsonType, typename EnumType,\n         enable_if_t<std::is_enum<EnumType>::value, int> = 0>\ninline void to_json(BasicJsonType& j, EnumType e) noexcept\n{\n    using underlying_type = typename std::underlying_type<EnumType>::type;\n    static constexpr value_t integral_value_t = std::is_unsigned<underlying_type>::value ? value_t::number_unsigned : value_t::number_integer;\n    external_constructor<integral_value_t>::construct(j, static_cast<underlying_type>(e));\n}\n#endif  // JSON_DISABLE_ENUM_SERIALIZATION\n\ntemplate<typename BasicJsonType>\ninline void to_json(BasicJsonType& j, const std::vector<bool>& e)\n{\n    external_constructor<value_t::array>::construct(j, e);\n}\n\ntemplate < typename BasicJsonType, typename CompatibleArrayType,\n           enable_if_t < is_compatible_array_type<BasicJsonType,\n                         CompatibleArrayType>::value&&\n                         !is_compatible_object_type<BasicJsonType, CompatibleArrayType>::value&&\n                         !is_compatible_string_type<BasicJsonType, CompatibleArrayType>::value&&\n                         !std::is_same<typename BasicJsonType::binary_t, CompatibleArrayType>::value&&\n                         !is_basic_json<CompatibleArrayType>::value,\n                         int > = 0 >\ninline void to_json(BasicJsonType& j, const CompatibleArrayType& arr)\n{\n    external_constructor<value_t::array>::construct(j, arr);\n}\n\ntemplate<typename BasicJsonType>\ninline void to_json(BasicJsonType& j, const typename BasicJsonType::binary_t& bin)\n{\n    external_constructor<value_t::binary>::construct(j, bin);\n}\n\ntemplate<typename 
BasicJsonType, typename T,\n         enable_if_t<std::is_convertible<T, BasicJsonType>::value, int> = 0>\ninline void to_json(BasicJsonType& j, const std::valarray<T>& arr)\n{\n    external_constructor<value_t::array>::construct(j, std::move(arr));\n}\n\ntemplate<typename BasicJsonType>\ninline void to_json(BasicJsonType& j, typename BasicJsonType::array_t&& arr)\n{\n    external_constructor<value_t::array>::construct(j, std::move(arr));\n}\n\ntemplate < typename BasicJsonType, typename CompatibleObjectType,\n           enable_if_t < is_compatible_object_type<BasicJsonType, CompatibleObjectType>::value&& !is_basic_json<CompatibleObjectType>::value, int > = 0 >\ninline void to_json(BasicJsonType& j, const CompatibleObjectType& obj)\n{\n    external_constructor<value_t::object>::construct(j, obj);\n}\n\ntemplate<typename BasicJsonType>\ninline void to_json(BasicJsonType& j, typename BasicJsonType::object_t&& obj)\n{\n    external_constructor<value_t::object>::construct(j, std::move(obj));\n}\n\ntemplate <\n    typename BasicJsonType, typename T, std::size_t N,\n    enable_if_t < !std::is_constructible<typename BasicJsonType::string_t,\n                  const T(&)[N]>::value, // NOLINT(cppcoreguidelines-avoid-c-arrays,hicpp-avoid-c-arrays,modernize-avoid-c-arrays)\n                  int > = 0 >\ninline void to_json(BasicJsonType& j, const T(&arr)[N]) // NOLINT(cppcoreguidelines-avoid-c-arrays,hicpp-avoid-c-arrays,modernize-avoid-c-arrays)\n{\n    external_constructor<value_t::array>::construct(j, arr);\n}\n\ntemplate < typename BasicJsonType, typename T1, typename T2, enable_if_t < std::is_constructible<BasicJsonType, T1>::value&& std::is_constructible<BasicJsonType, T2>::value, int > = 0 >\ninline void to_json(BasicJsonType& j, const std::pair<T1, T2>& p)\n{\n    j = { p.first, p.second };\n}\n\n// for https://github.com/nlohmann/json/pull/1134\ntemplate<typename BasicJsonType, typename T,\n         enable_if_t<std::is_same<T, iteration_proxy_value<typename 
BasicJsonType::iterator>>::value, int> = 0>\ninline void to_json(BasicJsonType& j, const T& b)\n{\n    j = { {b.key(), b.value()} };\n}\n\ntemplate<typename BasicJsonType, typename Tuple, std::size_t... Idx>\ninline void to_json_tuple_impl(BasicJsonType& j, const Tuple& t, index_sequence<Idx...> /*unused*/)\n{\n    j = { std::get<Idx>(t)... };\n}\n\ntemplate<typename BasicJsonType, typename T, enable_if_t<is_constructible_tuple<BasicJsonType, T>::value, int > = 0>\ninline void to_json(BasicJsonType& j, const T& t)\n{\n    to_json_tuple_impl(j, t, make_index_sequence<std::tuple_size<T>::value> {});\n}\n\n#if JSON_HAS_FILESYSTEM || JSON_HAS_EXPERIMENTAL_FILESYSTEM\ntemplate<typename BasicJsonType>\ninline void to_json(BasicJsonType& j, const std_fs::path& p)\n{\n    j = p.string();\n}\n#endif\n\nstruct to_json_fn\n{\n    template<typename BasicJsonType, typename T>\n    auto operator()(BasicJsonType& j, T&& val) const noexcept(noexcept(to_json(j, std::forward<T>(val))))\n    -> decltype(to_json(j, std::forward<T>(val)), void())\n    {\n        return to_json(j, std::forward<T>(val));\n    }\n};\n}  // namespace detail\n\n#ifndef JSON_HAS_CPP_17\n/// namespace to hold default `to_json` function\n/// to see why this is required:\n/// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2015/n4381.html\nnamespace // NOLINT(cert-dcl59-cpp,fuchsia-header-anon-namespaces,google-build-namespaces)\n{\n#endif\nJSON_INLINE_VARIABLE constexpr const auto& to_json = // NOLINT(misc-definitions-in-headers)\n    detail::static_const<detail::to_json_fn>::value;\n#ifndef JSON_HAS_CPP_17\n}  // namespace\n#endif\n\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include <nlohmann/detail/meta/identity_tag.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\n\n/// @sa https://json.nlohmann.me/api/adl_serializer/\ntemplate<typename ValueType, typename>\nstruct adl_serializer\n{\n    /// @brief convert a JSON value to any value type\n    /// @sa https://json.nlohmann.me/api/adl_serializer/from_json/\n    
template<typename BasicJsonType, typename TargetType = ValueType>\n    static auto from_json(BasicJsonType && j, TargetType& val) noexcept(\n        noexcept(::nlohmann::from_json(std::forward<BasicJsonType>(j), val)))\n    -> decltype(::nlohmann::from_json(std::forward<BasicJsonType>(j), val), void())\n    {\n        ::nlohmann::from_json(std::forward<BasicJsonType>(j), val);\n    }\n\n    /// @brief convert a JSON value to any value type\n    /// @sa https://json.nlohmann.me/api/adl_serializer/from_json/\n    template<typename BasicJsonType, typename TargetType = ValueType>\n    static auto from_json(BasicJsonType && j) noexcept(\n    noexcept(::nlohmann::from_json(std::forward<BasicJsonType>(j), detail::identity_tag<TargetType> {})))\n    -> decltype(::nlohmann::from_json(std::forward<BasicJsonType>(j), detail::identity_tag<TargetType> {}))\n    {\n        return ::nlohmann::from_json(std::forward<BasicJsonType>(j), detail::identity_tag<TargetType> {});\n    }\n\n    /// @brief convert any value type to a JSON value\n    /// @sa https://json.nlohmann.me/api/adl_serializer/to_json/\n    template<typename BasicJsonType, typename TargetType = ValueType>\n    static auto to_json(BasicJsonType& j, TargetType && val) noexcept(\n        noexcept(::nlohmann::to_json(j, std::forward<TargetType>(val))))\n    -> decltype(::nlohmann::to_json(j, std::forward<TargetType>(val)), void())\n    {\n        ::nlohmann::to_json(j, std::forward<TargetType>(val));\n    }\n};\n\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include <nlohmann/byte_container_with_subtype.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#include <cstdint> // uint8_t, uint64_t\n#include <tuple> // tie\n#include <utility> // move\n\n// #include 
<nlohmann/detail/abi_macros.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\n\n/// @brief an internal type for a backed binary type\n/// @sa https://json.nlohmann.me/api/byte_container_with_subtype/\ntemplate<typename BinaryType>\nclass byte_container_with_subtype : public BinaryType\n{\n  public:\n    using container_type = BinaryType;\n    using subtype_type = std::uint64_t;\n\n    /// @sa https://json.nlohmann.me/api/byte_container_with_subtype/byte_container_with_subtype/\n    byte_container_with_subtype() noexcept(noexcept(container_type()))\n        : container_type()\n    {}\n\n    /// @sa https://json.nlohmann.me/api/byte_container_with_subtype/byte_container_with_subtype/\n    byte_container_with_subtype(const container_type& b) noexcept(noexcept(container_type(b)))\n        : container_type(b)\n    {}\n\n    /// @sa https://json.nlohmann.me/api/byte_container_with_subtype/byte_container_with_subtype/\n    byte_container_with_subtype(container_type&& b) noexcept(noexcept(container_type(std::move(b))))\n        : container_type(std::move(b))\n    {}\n\n    /// @sa https://json.nlohmann.me/api/byte_container_with_subtype/byte_container_with_subtype/\n    byte_container_with_subtype(const container_type& b, subtype_type subtype_) noexcept(noexcept(container_type(b)))\n        : container_type(b)\n        , m_subtype(subtype_)\n        , m_has_subtype(true)\n    {}\n\n    /// @sa https://json.nlohmann.me/api/byte_container_with_subtype/byte_container_with_subtype/\n    byte_container_with_subtype(container_type&& b, subtype_type subtype_) noexcept(noexcept(container_type(std::move(b))))\n        : container_type(std::move(b))\n        , m_subtype(subtype_)\n        , m_has_subtype(true)\n    {}\n\n    bool operator==(const byte_container_with_subtype& rhs) const\n    {\n        return std::tie(static_cast<const BinaryType&>(*this), m_subtype, m_has_subtype) ==\n               std::tie(static_cast<const BinaryType&>(rhs), rhs.m_subtype, rhs.m_has_subtype);\n    }\n\n   
 bool operator!=(const byte_container_with_subtype& rhs) const\n    {\n        return !(rhs == *this);\n    }\n\n    /// @brief sets the binary subtype\n    /// @sa https://json.nlohmann.me/api/byte_container_with_subtype/set_subtype/\n    void set_subtype(subtype_type subtype_) noexcept\n    {\n        m_subtype = subtype_;\n        m_has_subtype = true;\n    }\n\n    /// @brief return the binary subtype\n    /// @sa https://json.nlohmann.me/api/byte_container_with_subtype/subtype/\n    constexpr subtype_type subtype() const noexcept\n    {\n        return m_has_subtype ? m_subtype : static_cast<subtype_type>(-1);\n    }\n\n    /// @brief return whether the value has a subtype\n    /// @sa https://json.nlohmann.me/api/byte_container_with_subtype/has_subtype/\n    constexpr bool has_subtype() const noexcept\n    {\n        return m_has_subtype;\n    }\n\n    /// @brief clears the binary subtype\n    /// @sa https://json.nlohmann.me/api/byte_container_with_subtype/clear_subtype/\n    void clear_subtype() noexcept\n    {\n        m_subtype = 0;\n        m_has_subtype = false;\n    }\n\n  private:\n    subtype_type m_subtype = 0;\n    bool m_has_subtype = false;\n};\n\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include <nlohmann/detail/conversions/from_json.hpp>\n\n// #include <nlohmann/detail/conversions/to_json.hpp>\n\n// #include <nlohmann/detail/exceptions.hpp>\n\n// #include <nlohmann/detail/hash.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#include <cstdint> // uint8_t\n#include <cstddef> // size_t\n#include <functional> // hash\n\n// #include <nlohmann/detail/abi_macros.hpp>\n\n// #include <nlohmann/detail/value_t.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\n\n// boost::hash_combine\ninline 
std::size_t combine(std::size_t seed, std::size_t h) noexcept\n{\n    seed ^= h + 0x9e3779b9 + (seed << 6U) + (seed >> 2U);\n    return seed;\n}\n\n/*!\n@brief hash a JSON value\n\nThe hash function tries to rely on std::hash where possible. Furthermore, the\ntype of the JSON value is taken into account to have different hash values for\nnull, 0, 0U, and false, etc.\n\n@tparam BasicJsonType basic_json specialization\n@param j JSON value to hash\n@return hash value of j\n*/\ntemplate<typename BasicJsonType>\nstd::size_t hash(const BasicJsonType& j)\n{\n    using string_t = typename BasicJsonType::string_t;\n    using number_integer_t = typename BasicJsonType::number_integer_t;\n    using number_unsigned_t = typename BasicJsonType::number_unsigned_t;\n    using number_float_t = typename BasicJsonType::number_float_t;\n\n    const auto type = static_cast<std::size_t>(j.type());\n    switch (j.type())\n    {\n        case BasicJsonType::value_t::null:\n        case BasicJsonType::value_t::discarded:\n        {\n            return combine(type, 0);\n        }\n\n        case BasicJsonType::value_t::object:\n        {\n            auto seed = combine(type, j.size());\n            for (const auto& element : j.items())\n            {\n                const auto h = std::hash<string_t> {}(element.key());\n                seed = combine(seed, h);\n                seed = combine(seed, hash(element.value()));\n            }\n            return seed;\n        }\n\n        case BasicJsonType::value_t::array:\n        {\n            auto seed = combine(type, j.size());\n            for (const auto& element : j)\n            {\n                seed = combine(seed, hash(element));\n            }\n            return seed;\n        }\n\n        case BasicJsonType::value_t::string:\n        {\n            const auto h = std::hash<string_t> {}(j.template get_ref<const string_t&>());\n            return combine(type, h);\n        }\n\n        case BasicJsonType::value_t::boolean:\n      
  {\n            const auto h = std::hash<bool> {}(j.template get<bool>());\n            return combine(type, h);\n        }\n\n        case BasicJsonType::value_t::number_integer:\n        {\n            const auto h = std::hash<number_integer_t> {}(j.template get<number_integer_t>());\n            return combine(type, h);\n        }\n\n        case BasicJsonType::value_t::number_unsigned:\n        {\n            const auto h = std::hash<number_unsigned_t> {}(j.template get<number_unsigned_t>());\n            return combine(type, h);\n        }\n\n        case BasicJsonType::value_t::number_float:\n        {\n            const auto h = std::hash<number_float_t> {}(j.template get<number_float_t>());\n            return combine(type, h);\n        }\n\n        case BasicJsonType::value_t::binary:\n        {\n            auto seed = combine(type, j.get_binary().size());\n            const auto h = std::hash<bool> {}(j.get_binary().has_subtype());\n            seed = combine(seed, h);\n            seed = combine(seed, static_cast<std::size_t>(j.get_binary().subtype()));\n            for (const auto byte : j.get_binary())\n            {\n                seed = combine(seed, std::hash<std::uint8_t> {}(byte));\n            }\n            return seed;\n        }\n\n        default:                   // LCOV_EXCL_LINE\n            JSON_ASSERT(false); // NOLINT(cert-dcl03-c,hicpp-static-assert,misc-static-assert) LCOV_EXCL_LINE\n            return 0;              // LCOV_EXCL_LINE\n    }\n}\n\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include <nlohmann/detail/input/binary_reader.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#include <algorithm> // generate_n\n#include <array> // array\n#include 
<cmath> // ldexp\n#include <cstddef> // size_t\n#include <cstdint> // uint8_t, uint16_t, uint32_t, uint64_t\n#include <cstdio> // snprintf\n#include <cstring> // memcpy\n#include <iterator> // back_inserter\n#include <limits> // numeric_limits\n#include <string> // char_traits, string\n#include <utility> // make_pair, move\n#include <vector> // vector\n\n// #include <nlohmann/detail/exceptions.hpp>\n\n// #include <nlohmann/detail/input/input_adapters.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#include <array> // array\n#include <cstddef> // size_t\n#include <cstring> // strlen\n#include <iterator> // begin, end, iterator_traits, random_access_iterator_tag, distance, next\n#include <memory> // shared_ptr, make_shared, addressof\n#include <numeric> // accumulate\n#include <string> // string, char_traits\n#include <type_traits> // enable_if, is_base_of, is_pointer, is_integral, remove_pointer\n#include <utility> // pair, declval\n\n#ifndef JSON_NO_IO\n    #include <cstdio>   // FILE *\n    #include <istream>  // istream\n#endif                  // JSON_NO_IO\n\n// #include <nlohmann/detail/iterators/iterator_traits.hpp>\n\n// #include <nlohmann/detail/macro_scope.hpp>\n\n// #include <nlohmann/detail/meta/type_traits.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\n\n/// the supported input formats\nenum class input_format_t { json, cbor, msgpack, ubjson, bson, bjdata };\n\n////////////////////\n// input adapters //\n////////////////////\n\n#ifndef JSON_NO_IO\n/*!\nInput adapter for stdio file access. This adapter reads only 1 byte at a time and does not use any\n buffer. 
This adapter is a very low-level adapter.\n*/\nclass file_input_adapter\n{\n  public:\n    using char_type = char;\n\n    JSON_HEDLEY_NON_NULL(2)\n    explicit file_input_adapter(std::FILE* f) noexcept\n        : m_file(f)\n    {\n        JSON_ASSERT(m_file != nullptr);\n    }\n\n    // make class move-only\n    file_input_adapter(const file_input_adapter&) = delete;\n    file_input_adapter(file_input_adapter&&) noexcept = default;\n    file_input_adapter& operator=(const file_input_adapter&) = delete;\n    file_input_adapter& operator=(file_input_adapter&&) = delete;\n    ~file_input_adapter() = default;\n\n    std::char_traits<char>::int_type get_character() noexcept\n    {\n        return std::fgetc(m_file);\n    }\n\n  private:\n    /// the file pointer to read from\n    std::FILE* m_file;\n};\n\n/*!\nInput adapter for a (caching) istream. Ignores a UTF Byte Order Mark at\nbeginning of input. Does not support changing the underlying std::streambuf\nin mid-input. Maintains underlying std::istream and std::streambuf to support\nsubsequent use of standard std::istream operations to process any input\ncharacters following those used in parsing the JSON input.  
Clears the\nstd::istream flags; any input errors (e.g., EOF) will be detected by the first\nsubsequent call for input from the std::istream.\n*/\nclass input_stream_adapter\n{\n  public:\n    using char_type = char;\n\n    ~input_stream_adapter()\n    {\n        // clear stream flags; we use underlying streambuf I/O, do not\n        // maintain ifstream flags, except eof\n        if (is != nullptr)\n        {\n            is->clear(is->rdstate() & std::ios::eofbit);\n        }\n    }\n\n    explicit input_stream_adapter(std::istream& i)\n        : is(&i), sb(i.rdbuf())\n    {}\n\n    // delete because of pointer members\n    input_stream_adapter(const input_stream_adapter&) = delete;\n    input_stream_adapter& operator=(input_stream_adapter&) = delete;\n    input_stream_adapter& operator=(input_stream_adapter&&) = delete;\n\n    input_stream_adapter(input_stream_adapter&& rhs) noexcept\n        : is(rhs.is), sb(rhs.sb)\n    {\n        rhs.is = nullptr;\n        rhs.sb = nullptr;\n    }\n\n    // std::istream/std::streambuf use std::char_traits<char>::to_int_type, to\n    // ensure that std::char_traits<char>::eof() and the character 0xFF do not\n    // end up as the same value, e.g. 0xFFFFFFFF.\n    std::char_traits<char>::int_type get_character()\n    {\n        auto res = sb->sbumpc();\n        // set eof manually, as we don't use the istream interface.\n        if (JSON_HEDLEY_UNLIKELY(res == std::char_traits<char>::eof()))\n        {\n            is->clear(is->rdstate() | std::ios::eofbit);\n        }\n        return res;\n    }\n\n  private:\n    /// the associated input stream\n    std::istream* is = nullptr;\n    std::streambuf* sb = nullptr;\n};\n#endif  // JSON_NO_IO\n\n// General-purpose iterator-based adapter. 
It might not be as fast as\n// theoretically possible for some containers, but it is extremely versatile.\ntemplate<typename IteratorType>\nclass iterator_input_adapter\n{\n  public:\n    using char_type = typename std::iterator_traits<IteratorType>::value_type;\n\n    iterator_input_adapter(IteratorType first, IteratorType last)\n        : current(std::move(first)), end(std::move(last))\n    {}\n\n    typename char_traits<char_type>::int_type get_character()\n    {\n        if (JSON_HEDLEY_LIKELY(current != end))\n        {\n            auto result = char_traits<char_type>::to_int_type(*current);\n            std::advance(current, 1);\n            return result;\n        }\n\n        return char_traits<char_type>::eof();\n    }\n\n  private:\n    IteratorType current;\n    IteratorType end;\n\n    template<typename BaseInputAdapter, size_t T>\n    friend struct wide_string_input_helper;\n\n    bool empty() const\n    {\n        return current == end;\n    }\n};\n\ntemplate<typename BaseInputAdapter, size_t T>\nstruct wide_string_input_helper;\n\ntemplate<typename BaseInputAdapter>\nstruct wide_string_input_helper<BaseInputAdapter, 4>\n{\n    // UTF-32\n    static void fill_buffer(BaseInputAdapter& input,\n                            std::array<std::char_traits<char>::int_type, 4>& utf8_bytes,\n                            size_t& utf8_bytes_index,\n                            size_t& utf8_bytes_filled)\n    {\n        utf8_bytes_index = 0;\n\n        if (JSON_HEDLEY_UNLIKELY(input.empty()))\n        {\n            utf8_bytes[0] = std::char_traits<char>::eof();\n            utf8_bytes_filled = 1;\n        }\n        else\n        {\n            // get the current character\n            const auto wc = input.get_character();\n\n            // UTF-32 to UTF-8 encoding\n            if (wc < 0x80)\n            {\n                utf8_bytes[0] = static_cast<std::char_traits<char>::int_type>(wc);\n                utf8_bytes_filled = 1;\n            }\n            else if 
(wc <= 0x7FF)\n            {\n                utf8_bytes[0] = static_cast<std::char_traits<char>::int_type>(0xC0u | ((static_cast<unsigned int>(wc) >> 6u) & 0x1Fu));\n                utf8_bytes[1] = static_cast<std::char_traits<char>::int_type>(0x80u | (static_cast<unsigned int>(wc) & 0x3Fu));\n                utf8_bytes_filled = 2;\n            }\n            else if (wc <= 0xFFFF)\n            {\n                utf8_bytes[0] = static_cast<std::char_traits<char>::int_type>(0xE0u | ((static_cast<unsigned int>(wc) >> 12u) & 0x0Fu));\n                utf8_bytes[1] = static_cast<std::char_traits<char>::int_type>(0x80u | ((static_cast<unsigned int>(wc) >> 6u) & 0x3Fu));\n                utf8_bytes[2] = static_cast<std::char_traits<char>::int_type>(0x80u | (static_cast<unsigned int>(wc) & 0x3Fu));\n                utf8_bytes_filled = 3;\n            }\n            else if (wc <= 0x10FFFF)\n            {\n                utf8_bytes[0] = static_cast<std::char_traits<char>::int_type>(0xF0u | ((static_cast<unsigned int>(wc) >> 18u) & 0x07u));\n                utf8_bytes[1] = static_cast<std::char_traits<char>::int_type>(0x80u | ((static_cast<unsigned int>(wc) >> 12u) & 0x3Fu));\n                utf8_bytes[2] = static_cast<std::char_traits<char>::int_type>(0x80u | ((static_cast<unsigned int>(wc) >> 6u) & 0x3Fu));\n                utf8_bytes[3] = static_cast<std::char_traits<char>::int_type>(0x80u | (static_cast<unsigned int>(wc) & 0x3Fu));\n                utf8_bytes_filled = 4;\n            }\n            else\n            {\n                // unknown character\n                utf8_bytes[0] = static_cast<std::char_traits<char>::int_type>(wc);\n                utf8_bytes_filled = 1;\n            }\n        }\n    }\n};\n\ntemplate<typename BaseInputAdapter>\nstruct wide_string_input_helper<BaseInputAdapter, 2>\n{\n    // UTF-16\n    static void fill_buffer(BaseInputAdapter& input,\n                            std::array<std::char_traits<char>::int_type, 4>& utf8_bytes,\n  
                          size_t& utf8_bytes_index,\n                            size_t& utf8_bytes_filled)\n    {\n        utf8_bytes_index = 0;\n\n        if (JSON_HEDLEY_UNLIKELY(input.empty()))\n        {\n            utf8_bytes[0] = std::char_traits<char>::eof();\n            utf8_bytes_filled = 1;\n        }\n        else\n        {\n            // get the current character\n            const auto wc = input.get_character();\n\n            // UTF-16 to UTF-8 encoding\n            if (wc < 0x80)\n            {\n                utf8_bytes[0] = static_cast<std::char_traits<char>::int_type>(wc);\n                utf8_bytes_filled = 1;\n            }\n            else if (wc <= 0x7FF)\n            {\n                utf8_bytes[0] = static_cast<std::char_traits<char>::int_type>(0xC0u | ((static_cast<unsigned int>(wc) >> 6u)));\n                utf8_bytes[1] = static_cast<std::char_traits<char>::int_type>(0x80u | (static_cast<unsigned int>(wc) & 0x3Fu));\n                utf8_bytes_filled = 2;\n            }\n            else if (0xD800 > wc || wc >= 0xE000)\n            {\n                utf8_bytes[0] = static_cast<std::char_traits<char>::int_type>(0xE0u | ((static_cast<unsigned int>(wc) >> 12u)));\n                utf8_bytes[1] = static_cast<std::char_traits<char>::int_type>(0x80u | ((static_cast<unsigned int>(wc) >> 6u) & 0x3Fu));\n                utf8_bytes[2] = static_cast<std::char_traits<char>::int_type>(0x80u | (static_cast<unsigned int>(wc) & 0x3Fu));\n                utf8_bytes_filled = 3;\n            }\n            else\n            {\n                if (JSON_HEDLEY_UNLIKELY(!input.empty()))\n                {\n                    const auto wc2 = static_cast<unsigned int>(input.get_character());\n                    const auto charcode = 0x10000u + (((static_cast<unsigned int>(wc) & 0x3FFu) << 10u) | (wc2 & 0x3FFu));\n                    utf8_bytes[0] = static_cast<std::char_traits<char>::int_type>(0xF0u | (charcode >> 18u));\n                    
utf8_bytes[1] = static_cast<std::char_traits<char>::int_type>(0x80u | ((charcode >> 12u) & 0x3Fu));\n                    utf8_bytes[2] = static_cast<std::char_traits<char>::int_type>(0x80u | ((charcode >> 6u) & 0x3Fu));\n                    utf8_bytes[3] = static_cast<std::char_traits<char>::int_type>(0x80u | (charcode & 0x3Fu));\n                    utf8_bytes_filled = 4;\n                }\n                else\n                {\n                    utf8_bytes[0] = static_cast<std::char_traits<char>::int_type>(wc);\n                    utf8_bytes_filled = 1;\n                }\n            }\n        }\n    }\n};\n\n// Wraps another input adapter to convert wide character types into individual bytes.\ntemplate<typename BaseInputAdapter, typename WideCharType>\nclass wide_string_input_adapter\n{\n  public:\n    using char_type = char;\n\n    wide_string_input_adapter(BaseInputAdapter base)\n        : base_adapter(base) {}\n\n    typename std::char_traits<char>::int_type get_character() noexcept\n    {\n        // check if buffer needs to be filled\n        if (utf8_bytes_index == utf8_bytes_filled)\n        {\n            fill_buffer<sizeof(WideCharType)>();\n\n            JSON_ASSERT(utf8_bytes_filled > 0);\n            JSON_ASSERT(utf8_bytes_index == 0);\n        }\n\n        // use buffer\n        JSON_ASSERT(utf8_bytes_filled > 0);\n        JSON_ASSERT(utf8_bytes_index < utf8_bytes_filled);\n        return utf8_bytes[utf8_bytes_index++];\n    }\n\n  private:\n    BaseInputAdapter base_adapter;\n\n    template<size_t T>\n    void fill_buffer()\n    {\n        wide_string_input_helper<BaseInputAdapter, T>::fill_buffer(base_adapter, utf8_bytes, utf8_bytes_index, utf8_bytes_filled);\n    }\n\n    /// a buffer for UTF-8 bytes\n    std::array<std::char_traits<char>::int_type, 4> utf8_bytes = {{0, 0, 0, 0}};\n\n    /// index to the utf8_codes array for the next valid byte\n    std::size_t utf8_bytes_index = 0;\n    /// number of valid bytes in the utf8_codes array\n 
   std::size_t utf8_bytes_filled = 0;\n};\n\ntemplate<typename IteratorType, typename Enable = void>\nstruct iterator_input_adapter_factory\n{\n    using iterator_type = IteratorType;\n    using char_type = typename std::iterator_traits<iterator_type>::value_type;\n    using adapter_type = iterator_input_adapter<iterator_type>;\n\n    static adapter_type create(IteratorType first, IteratorType last)\n    {\n        return adapter_type(std::move(first), std::move(last));\n    }\n};\n\ntemplate<typename T>\nstruct is_iterator_of_multibyte\n{\n    using value_type = typename std::iterator_traits<T>::value_type;\n    enum\n    {\n        value = sizeof(value_type) > 1\n    };\n};\n\ntemplate<typename IteratorType>\nstruct iterator_input_adapter_factory<IteratorType, enable_if_t<is_iterator_of_multibyte<IteratorType>::value>>\n{\n    using iterator_type = IteratorType;\n    using char_type = typename std::iterator_traits<iterator_type>::value_type;\n    using base_adapter_type = iterator_input_adapter<iterator_type>;\n    using adapter_type = wide_string_input_adapter<base_adapter_type, char_type>;\n\n    static adapter_type create(IteratorType first, IteratorType last)\n    {\n        return adapter_type(base_adapter_type(std::move(first), std::move(last)));\n    }\n};\n\n// General purpose iterator-based input\ntemplate<typename IteratorType>\ntypename iterator_input_adapter_factory<IteratorType>::adapter_type input_adapter(IteratorType first, IteratorType last)\n{\n    using factory_type = iterator_input_adapter_factory<IteratorType>;\n    return factory_type::create(first, last);\n}\n\n// Convenience shorthand from container to iterator\n// Enables ADL on begin(container) and end(container)\n// Encloses the using declarations in a namespace so that they do not leak into the outer scope\n\nnamespace container_input_adapter_factory_impl\n{\n\nusing std::begin;\nusing std::end;\n\ntemplate<typename ContainerType, typename Enable = void>\nstruct container_input_adapter_factory 
{};\n\ntemplate<typename ContainerType>\nstruct container_input_adapter_factory< ContainerType,\n       void_t<decltype(begin(std::declval<ContainerType>()), end(std::declval<ContainerType>()))>>\n       {\n           using adapter_type = decltype(input_adapter(begin(std::declval<ContainerType>()), end(std::declval<ContainerType>())));\n\n           static adapter_type create(const ContainerType& container)\n{\n    return input_adapter(begin(container), end(container));\n}\n       };\n\n}  // namespace container_input_adapter_factory_impl\n\ntemplate<typename ContainerType>\ntypename container_input_adapter_factory_impl::container_input_adapter_factory<ContainerType>::adapter_type input_adapter(const ContainerType& container)\n{\n    return container_input_adapter_factory_impl::container_input_adapter_factory<ContainerType>::create(container);\n}\n\n#ifndef JSON_NO_IO\n// Special cases with fast paths\ninline file_input_adapter input_adapter(std::FILE* file)\n{\n    return file_input_adapter(file);\n}\n\ninline input_stream_adapter input_adapter(std::istream& stream)\n{\n    return input_stream_adapter(stream);\n}\n\ninline input_stream_adapter input_adapter(std::istream&& stream)\n{\n    return input_stream_adapter(stream);\n}\n#endif  // JSON_NO_IO\n\nusing contiguous_bytes_input_adapter = decltype(input_adapter(std::declval<const char*>(), std::declval<const char*>()));\n\n// Null-delimited strings, and the like.\ntemplate < typename CharT,\n           typename std::enable_if <\n               std::is_pointer<CharT>::value&&\n               !std::is_array<CharT>::value&&\n               std::is_integral<typename std::remove_pointer<CharT>::type>::value&&\n               sizeof(typename std::remove_pointer<CharT>::type) == 1,\n               int >::type = 0 >\ncontiguous_bytes_input_adapter input_adapter(CharT b)\n{\n    auto length = std::strlen(reinterpret_cast<const char*>(b));\n    const auto* ptr = reinterpret_cast<const char*>(b);\n    return 
input_adapter(ptr, ptr + length);\n}\n\ntemplate<typename T, std::size_t N>\nauto input_adapter(T (&array)[N]) -> decltype(input_adapter(array, array + N)) // NOLINT(cppcoreguidelines-avoid-c-arrays,hicpp-avoid-c-arrays,modernize-avoid-c-arrays)\n{\n    return input_adapter(array, array + N);\n}\n\n// This class only handles inputs of input_buffer_adapter type.\n// It's required so that expressions like {ptr, len} can be implicitly cast\n// to the correct adapter.\nclass span_input_adapter\n{\n  public:\n    template < typename CharT,\n               typename std::enable_if <\n                   std::is_pointer<CharT>::value&&\n                   std::is_integral<typename std::remove_pointer<CharT>::type>::value&&\n                   sizeof(typename std::remove_pointer<CharT>::type) == 1,\n                   int >::type = 0 >\n    span_input_adapter(CharT b, std::size_t l)\n        : ia(reinterpret_cast<const char*>(b), reinterpret_cast<const char*>(b) + l) {}\n\n    template<class IteratorType,\n             typename std::enable_if<\n                 std::is_same<typename iterator_traits<IteratorType>::iterator_category, std::random_access_iterator_tag>::value,\n                 int>::type = 0>\n    span_input_adapter(IteratorType first, IteratorType last)\n        : ia(input_adapter(first, last)) {}\n\n    contiguous_bytes_input_adapter&& get()\n    {\n        return std::move(ia); // NOLINT(hicpp-move-const-arg,performance-move-const-arg)\n    }\n\n  private:\n    contiguous_bytes_input_adapter ia;\n};\n\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include <nlohmann/detail/input/json_sax.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#include <cstddef>\n#include <string> // 
string\n#include <utility> // move\n#include <vector> // vector\n\n// #include <nlohmann/detail/exceptions.hpp>\n\n// #include <nlohmann/detail/macro_scope.hpp>\n\n// #include <nlohmann/detail/string_concat.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\n\n/*!\n@brief SAX interface\n\nThis class describes the SAX interface used by @ref nlohmann::json::sax_parse.\nEach function is called in different situations while the input is parsed. The\nboolean return value informs the parser whether to continue processing the\ninput.\n*/\ntemplate<typename BasicJsonType>\nstruct json_sax\n{\n    using number_integer_t = typename BasicJsonType::number_integer_t;\n    using number_unsigned_t = typename BasicJsonType::number_unsigned_t;\n    using number_float_t = typename BasicJsonType::number_float_t;\n    using string_t = typename BasicJsonType::string_t;\n    using binary_t = typename BasicJsonType::binary_t;\n\n    /*!\n    @brief a null value was read\n    @return whether parsing should proceed\n    */\n    virtual bool null() = 0;\n\n    /*!\n    @brief a boolean value was read\n    @param[in] val  boolean value\n    @return whether parsing should proceed\n    */\n    virtual bool boolean(bool val) = 0;\n\n    /*!\n    @brief an integer number was read\n    @param[in] val  integer value\n    @return whether parsing should proceed\n    */\n    virtual bool number_integer(number_integer_t val) = 0;\n\n    /*!\n    @brief an unsigned integer number was read\n    @param[in] val  unsigned integer value\n    @return whether parsing should proceed\n    */\n    virtual bool number_unsigned(number_unsigned_t val) = 0;\n\n    /*!\n    @brief a floating-point number was read\n    @param[in] val  floating-point value\n    @param[in] s    raw token value\n    @return whether parsing should proceed\n    */\n    virtual bool number_float(number_float_t val, const string_t& s) = 0;\n\n    /*!\n    @brief a string value was read\n    @param[in] val  string value\n    @return whether parsing 
should proceed\n    @note It is safe to move the passed string value.\n    */\n    virtual bool string(string_t& val) = 0;\n\n    /*!\n    @brief a binary value was read\n    @param[in] val  binary value\n    @return whether parsing should proceed\n    @note It is safe to move the passed binary value.\n    */\n    virtual bool binary(binary_t& val) = 0;\n\n    /*!\n    @brief the beginning of an object was read\n    @param[in] elements  number of object elements or -1 if unknown\n    @return whether parsing should proceed\n    @note binary formats may report the number of elements\n    */\n    virtual bool start_object(std::size_t elements) = 0;\n\n    /*!\n    @brief an object key was read\n    @param[in] val  object key\n    @return whether parsing should proceed\n    @note It is safe to move the passed string.\n    */\n    virtual bool key(string_t& val) = 0;\n\n    /*!\n    @brief the end of an object was read\n    @return whether parsing should proceed\n    */\n    virtual bool end_object() = 0;\n\n    /*!\n    @brief the beginning of an array was read\n    @param[in] elements  number of array elements or -1 if unknown\n    @return whether parsing should proceed\n    @note binary formats may report the number of elements\n    */\n    virtual bool start_array(std::size_t elements) = 0;\n\n    /*!\n    @brief the end of an array was read\n    @return whether parsing should proceed\n    */\n    virtual bool end_array() = 0;\n\n    /*!\n    @brief a parse error occurred\n    @param[in] position    the position in the input where the error occurs\n    @param[in] last_token  the last read token\n    @param[in] ex          an exception object describing the error\n    @return whether parsing should proceed (must return false)\n    */\n    virtual bool parse_error(std::size_t position,\n                             const std::string& last_token,\n                             const detail::exception& ex) = 0;\n\n    json_sax() = default;\n    json_sax(const json_sax&) 
= default;\n    json_sax(json_sax&&) noexcept = default;\n    json_sax& operator=(const json_sax&) = default;\n    json_sax& operator=(json_sax&&) noexcept = default;\n    virtual ~json_sax() = default;\n};\n\nnamespace detail\n{\n/*!\n@brief SAX implementation to create a JSON value from SAX events\n\nThis class implements the @ref json_sax interface and processes the SAX events\nto create a JSON value which makes it basically a DOM parser. The structure or\nhierarchy of the JSON value is managed by the stack `ref_stack` which contains\na pointer to the respective array or object for each recursion depth.\n\nAfter successful parsing, the value that is passed by reference to the\nconstructor contains the parsed value.\n\n@tparam BasicJsonType  the JSON type\n*/\ntemplate<typename BasicJsonType>\nclass json_sax_dom_parser\n{\n  public:\n    using number_integer_t = typename BasicJsonType::number_integer_t;\n    using number_unsigned_t = typename BasicJsonType::number_unsigned_t;\n    using number_float_t = typename BasicJsonType::number_float_t;\n    using string_t = typename BasicJsonType::string_t;\n    using binary_t = typename BasicJsonType::binary_t;\n\n    /*!\n    @param[in,out] r  reference to a JSON value that is manipulated while\n                       parsing\n    @param[in] allow_exceptions_  whether parse errors yield exceptions\n    */\n    explicit json_sax_dom_parser(BasicJsonType& r, const bool allow_exceptions_ = true)\n        : root(r), allow_exceptions(allow_exceptions_)\n    {}\n\n    // make class move-only\n    json_sax_dom_parser(const json_sax_dom_parser&) = delete;\n    json_sax_dom_parser(json_sax_dom_parser&&) = default; // NOLINT(hicpp-noexcept-move,performance-noexcept-move-constructor)\n    json_sax_dom_parser& operator=(const json_sax_dom_parser&) = delete;\n    json_sax_dom_parser& operator=(json_sax_dom_parser&&) = default; // NOLINT(hicpp-noexcept-move,performance-noexcept-move-constructor)\n    ~json_sax_dom_parser() = 
default;\n\n    bool null()\n    {\n        handle_value(nullptr);\n        return true;\n    }\n\n    bool boolean(bool val)\n    {\n        handle_value(val);\n        return true;\n    }\n\n    bool number_integer(number_integer_t val)\n    {\n        handle_value(val);\n        return true;\n    }\n\n    bool number_unsigned(number_unsigned_t val)\n    {\n        handle_value(val);\n        return true;\n    }\n\n    bool number_float(number_float_t val, const string_t& /*unused*/)\n    {\n        handle_value(val);\n        return true;\n    }\n\n    bool string(string_t& val)\n    {\n        handle_value(val);\n        return true;\n    }\n\n    bool binary(binary_t& val)\n    {\n        handle_value(std::move(val));\n        return true;\n    }\n\n    bool start_object(std::size_t len)\n    {\n        ref_stack.push_back(handle_value(BasicJsonType::value_t::object));\n\n        if (JSON_HEDLEY_UNLIKELY(len != static_cast<std::size_t>(-1) && len > ref_stack.back()->max_size()))\n        {\n            JSON_THROW(out_of_range::create(408, concat(\"excessive object size: \", std::to_string(len)), ref_stack.back()));\n        }\n\n        return true;\n    }\n\n    bool key(string_t& val)\n    {\n        JSON_ASSERT(!ref_stack.empty());\n        JSON_ASSERT(ref_stack.back()->is_object());\n\n        // add null at given key and store the reference for later\n        object_element = &(ref_stack.back()->m_data.m_value.object->operator[](val));\n        return true;\n    }\n\n    bool end_object()\n    {\n        JSON_ASSERT(!ref_stack.empty());\n        JSON_ASSERT(ref_stack.back()->is_object());\n\n        ref_stack.back()->set_parents();\n        ref_stack.pop_back();\n        return true;\n    }\n\n    bool start_array(std::size_t len)\n    {\n        ref_stack.push_back(handle_value(BasicJsonType::value_t::array));\n\n        if (JSON_HEDLEY_UNLIKELY(len != static_cast<std::size_t>(-1) && len > ref_stack.back()->max_size()))\n        {\n            
JSON_THROW(out_of_range::create(408, concat(\"excessive array size: \", std::to_string(len)), ref_stack.back()));\n        }\n\n        return true;\n    }\n\n    bool end_array()\n    {\n        JSON_ASSERT(!ref_stack.empty());\n        JSON_ASSERT(ref_stack.back()->is_array());\n\n        ref_stack.back()->set_parents();\n        ref_stack.pop_back();\n        return true;\n    }\n\n    template<class Exception>\n    bool parse_error(std::size_t /*unused*/, const std::string& /*unused*/,\n                     const Exception& ex)\n    {\n        errored = true;\n        static_cast<void>(ex);\n        if (allow_exceptions)\n        {\n            JSON_THROW(ex);\n        }\n        return false;\n    }\n\n    constexpr bool is_errored() const\n    {\n        return errored;\n    }\n\n  private:\n    /*!\n    @invariant If the ref stack is empty, then the passed value will be the new\n               root.\n    @invariant If the ref stack contains a value, then it is an array or an\n               object to which we can add elements\n    */\n    template<typename Value>\n    JSON_HEDLEY_RETURNS_NON_NULL\n    BasicJsonType* handle_value(Value&& v)\n    {\n        if (ref_stack.empty())\n        {\n            root = BasicJsonType(std::forward<Value>(v));\n            return &root;\n        }\n\n        JSON_ASSERT(ref_stack.back()->is_array() || ref_stack.back()->is_object());\n\n        if (ref_stack.back()->is_array())\n        {\n            ref_stack.back()->m_data.m_value.array->emplace_back(std::forward<Value>(v));\n            return &(ref_stack.back()->m_data.m_value.array->back());\n        }\n\n        JSON_ASSERT(ref_stack.back()->is_object());\n        JSON_ASSERT(object_element);\n        *object_element = BasicJsonType(std::forward<Value>(v));\n        return object_element;\n    }\n\n    /// the parsed JSON value\n    BasicJsonType& root;\n    /// stack to model hierarchy of values\n    std::vector<BasicJsonType*> ref_stack {};\n    /// helper to hold 
the reference for the next object element\n    BasicJsonType* object_element = nullptr;\n    /// whether a syntax error occurred\n    bool errored = false;\n    /// whether to throw exceptions in case of errors\n    const bool allow_exceptions = true;\n};\n\ntemplate<typename BasicJsonType>\nclass json_sax_dom_callback_parser\n{\n  public:\n    using number_integer_t = typename BasicJsonType::number_integer_t;\n    using number_unsigned_t = typename BasicJsonType::number_unsigned_t;\n    using number_float_t = typename BasicJsonType::number_float_t;\n    using string_t = typename BasicJsonType::string_t;\n    using binary_t = typename BasicJsonType::binary_t;\n    using parser_callback_t = typename BasicJsonType::parser_callback_t;\n    using parse_event_t = typename BasicJsonType::parse_event_t;\n\n    json_sax_dom_callback_parser(BasicJsonType& r,\n                                 const parser_callback_t cb,\n                                 const bool allow_exceptions_ = true)\n        : root(r), callback(cb), allow_exceptions(allow_exceptions_)\n    {\n        keep_stack.push_back(true);\n    }\n\n    // make class move-only\n    json_sax_dom_callback_parser(const json_sax_dom_callback_parser&) = delete;\n    json_sax_dom_callback_parser(json_sax_dom_callback_parser&&) = default; // NOLINT(hicpp-noexcept-move,performance-noexcept-move-constructor)\n    json_sax_dom_callback_parser& operator=(const json_sax_dom_callback_parser&) = delete;\n    json_sax_dom_callback_parser& operator=(json_sax_dom_callback_parser&&) = default; // NOLINT(hicpp-noexcept-move,performance-noexcept-move-constructor)\n    ~json_sax_dom_callback_parser() = default;\n\n    bool null()\n    {\n        handle_value(nullptr);\n        return true;\n    }\n\n    bool boolean(bool val)\n    {\n        handle_value(val);\n        return true;\n    }\n\n    bool number_integer(number_integer_t val)\n    {\n        handle_value(val);\n        return true;\n    }\n\n    bool 
number_unsigned(number_unsigned_t val)\n    {\n        handle_value(val);\n        return true;\n    }\n\n    bool number_float(number_float_t val, const string_t& /*unused*/)\n    {\n        handle_value(val);\n        return true;\n    }\n\n    bool string(string_t& val)\n    {\n        handle_value(val);\n        return true;\n    }\n\n    bool binary(binary_t& val)\n    {\n        handle_value(std::move(val));\n        return true;\n    }\n\n    bool start_object(std::size_t len)\n    {\n        // check callback for object start\n        const bool keep = callback(static_cast<int>(ref_stack.size()), parse_event_t::object_start, discarded);\n        keep_stack.push_back(keep);\n\n        auto val = handle_value(BasicJsonType::value_t::object, true);\n        ref_stack.push_back(val.second);\n\n        // check object limit\n        if (ref_stack.back() && JSON_HEDLEY_UNLIKELY(len != static_cast<std::size_t>(-1) && len > ref_stack.back()->max_size()))\n        {\n            JSON_THROW(out_of_range::create(408, concat(\"excessive object size: \", std::to_string(len)), ref_stack.back()));\n        }\n\n        return true;\n    }\n\n    bool key(string_t& val)\n    {\n        BasicJsonType k = BasicJsonType(val);\n\n        // check callback for key\n        const bool keep = callback(static_cast<int>(ref_stack.size()), parse_event_t::key, k);\n        key_keep_stack.push_back(keep);\n\n        // add discarded value at given key and store the reference for later\n        if (keep && ref_stack.back())\n        {\n            object_element = &(ref_stack.back()->m_data.m_value.object->operator[](val) = discarded);\n        }\n\n        return true;\n    }\n\n    bool end_object()\n    {\n        if (ref_stack.back())\n        {\n            if (!callback(static_cast<int>(ref_stack.size()) - 1, parse_event_t::object_end, *ref_stack.back()))\n            {\n                // discard object\n                *ref_stack.back() = discarded;\n            }\n            
else\n            {\n                ref_stack.back()->set_parents();\n            }\n        }\n\n        JSON_ASSERT(!ref_stack.empty());\n        JSON_ASSERT(!keep_stack.empty());\n        ref_stack.pop_back();\n        keep_stack.pop_back();\n\n        if (!ref_stack.empty() && ref_stack.back() && ref_stack.back()->is_structured())\n        {\n            // remove discarded value\n            for (auto it = ref_stack.back()->begin(); it != ref_stack.back()->end(); ++it)\n            {\n                if (it->is_discarded())\n                {\n                    ref_stack.back()->erase(it);\n                    break;\n                }\n            }\n        }\n\n        return true;\n    }\n\n    bool start_array(std::size_t len)\n    {\n        const bool keep = callback(static_cast<int>(ref_stack.size()), parse_event_t::array_start, discarded);\n        keep_stack.push_back(keep);\n\n        auto val = handle_value(BasicJsonType::value_t::array, true);\n        ref_stack.push_back(val.second);\n\n        // check array limit\n        if (ref_stack.back() && JSON_HEDLEY_UNLIKELY(len != static_cast<std::size_t>(-1) && len > ref_stack.back()->max_size()))\n        {\n            JSON_THROW(out_of_range::create(408, concat(\"excessive array size: \", std::to_string(len)), ref_stack.back()));\n        }\n\n        return true;\n    }\n\n    bool end_array()\n    {\n        bool keep = true;\n\n        if (ref_stack.back())\n        {\n            keep = callback(static_cast<int>(ref_stack.size()) - 1, parse_event_t::array_end, *ref_stack.back());\n            if (keep)\n            {\n                ref_stack.back()->set_parents();\n            }\n            else\n            {\n                // discard array\n                *ref_stack.back() = discarded;\n            }\n        }\n\n        JSON_ASSERT(!ref_stack.empty());\n        JSON_ASSERT(!keep_stack.empty());\n        ref_stack.pop_back();\n        keep_stack.pop_back();\n\n        // remove 
discarded value\n        if (!keep && !ref_stack.empty() && ref_stack.back()->is_array())\n        {\n            ref_stack.back()->m_data.m_value.array->pop_back();\n        }\n\n        return true;\n    }\n\n    template<class Exception>\n    bool parse_error(std::size_t /*unused*/, const std::string& /*unused*/,\n                     const Exception& ex)\n    {\n        errored = true;\n        static_cast<void>(ex);\n        if (allow_exceptions)\n        {\n            JSON_THROW(ex);\n        }\n        return false;\n    }\n\n    constexpr bool is_errored() const\n    {\n        return errored;\n    }\n\n  private:\n    /*!\n    @param[in] v  value to add to the JSON value we build during parsing\n    @param[in] skip_callback  whether we should skip calling the callback\n               function; this is required after start_array() and\n               start_object() SAX events, because otherwise we would call the\n               callback function with an empty array or object, respectively.\n\n    @invariant If the ref stack is empty, then the passed value will be the new\n               root.\n    @invariant If the ref stack contains a value, then it is an array or an\n               object to which we can add elements\n\n    @return pair of boolean (whether value should be kept) and pointer (to the\n            passed value in the ref_stack hierarchy; nullptr if not kept)\n    */\n    template<typename Value>\n    std::pair<bool, BasicJsonType*> handle_value(Value&& v, const bool skip_callback = false)\n    {\n        JSON_ASSERT(!keep_stack.empty());\n\n        // do not handle this value if we know it would be added to a discarded\n        // container\n        if (!keep_stack.back())\n        {\n            return {false, nullptr};\n        }\n\n        // create value\n        auto value = BasicJsonType(std::forward<Value>(v));\n\n        // check callback\n        const bool keep = skip_callback || callback(static_cast<int>(ref_stack.size()), 
parse_event_t::value, value);\n\n        // do not handle this value if we just learnt it shall be discarded\n        if (!keep)\n        {\n            return {false, nullptr};\n        }\n\n        if (ref_stack.empty())\n        {\n            root = std::move(value);\n            return {true, & root};\n        }\n\n        // skip this value if we already decided to skip the parent\n        // (https://github.com/nlohmann/json/issues/971#issuecomment-413678360)\n        if (!ref_stack.back())\n        {\n            return {false, nullptr};\n        }\n\n        // we now only expect arrays and objects\n        JSON_ASSERT(ref_stack.back()->is_array() || ref_stack.back()->is_object());\n\n        // array\n        if (ref_stack.back()->is_array())\n        {\n            ref_stack.back()->m_data.m_value.array->emplace_back(std::move(value));\n            return {true, & (ref_stack.back()->m_data.m_value.array->back())};\n        }\n\n        // object\n        JSON_ASSERT(ref_stack.back()->is_object());\n        // check if we should store an element for the current key\n        JSON_ASSERT(!key_keep_stack.empty());\n        const bool store_element = key_keep_stack.back();\n        key_keep_stack.pop_back();\n\n        if (!store_element)\n        {\n            return {false, nullptr};\n        }\n\n        JSON_ASSERT(object_element);\n        *object_element = std::move(value);\n        return {true, object_element};\n    }\n\n    /// the parsed JSON value\n    BasicJsonType& root;\n    /// stack to model hierarchy of values\n    std::vector<BasicJsonType*> ref_stack {};\n    /// stack to manage which values to keep\n    std::vector<bool> keep_stack {}; // NOLINT(readability-redundant-member-init)\n    /// stack to manage which object keys to keep\n    std::vector<bool> key_keep_stack {}; // NOLINT(readability-redundant-member-init)\n    /// helper to hold the reference for the next object element\n    BasicJsonType* object_element = nullptr;\n    /// 
whether a syntax error occurred\n    bool errored = false;\n    /// callback function\n    const parser_callback_t callback = nullptr;\n    /// whether to throw exceptions in case of errors\n    const bool allow_exceptions = true;\n    /// a discarded value for the callback\n    BasicJsonType discarded = BasicJsonType::value_t::discarded;\n};\n\ntemplate<typename BasicJsonType>\nclass json_sax_acceptor\n{\n  public:\n    using number_integer_t = typename BasicJsonType::number_integer_t;\n    using number_unsigned_t = typename BasicJsonType::number_unsigned_t;\n    using number_float_t = typename BasicJsonType::number_float_t;\n    using string_t = typename BasicJsonType::string_t;\n    using binary_t = typename BasicJsonType::binary_t;\n\n    bool null()\n    {\n        return true;\n    }\n\n    bool boolean(bool /*unused*/)\n    {\n        return true;\n    }\n\n    bool number_integer(number_integer_t /*unused*/)\n    {\n        return true;\n    }\n\n    bool number_unsigned(number_unsigned_t /*unused*/)\n    {\n        return true;\n    }\n\n    bool number_float(number_float_t /*unused*/, const string_t& /*unused*/)\n    {\n        return true;\n    }\n\n    bool string(string_t& /*unused*/)\n    {\n        return true;\n    }\n\n    bool binary(binary_t& /*unused*/)\n    {\n        return true;\n    }\n\n    bool start_object(std::size_t /*unused*/ = static_cast<std::size_t>(-1))\n    {\n        return true;\n    }\n\n    bool key(string_t& /*unused*/)\n    {\n        return true;\n    }\n\n    bool end_object()\n    {\n        return true;\n    }\n\n    bool start_array(std::size_t /*unused*/ = static_cast<std::size_t>(-1))\n    {\n        return true;\n    }\n\n    bool end_array()\n    {\n        return true;\n    }\n\n    bool parse_error(std::size_t /*unused*/, const std::string& /*unused*/, const detail::exception& /*unused*/)\n    {\n        return false;\n    }\n};\n\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include 
<nlohmann/detail/input/lexer.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#include <array> // array\n#include <clocale> // localeconv\n#include <cstddef> // size_t\n#include <cstdio> // snprintf\n#include <cstdlib> // strtof, strtod, strtold, strtoll, strtoull\n#include <initializer_list> // initializer_list\n#include <string> // char_traits, string\n#include <utility> // move\n#include <vector> // vector\n\n// #include <nlohmann/detail/input/input_adapters.hpp>\n\n// #include <nlohmann/detail/input/position_t.hpp>\n\n// #include <nlohmann/detail/macro_scope.hpp>\n\n// #include <nlohmann/detail/meta/type_traits.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\n\n///////////\n// lexer //\n///////////\n\ntemplate<typename BasicJsonType>\nclass lexer_base\n{\n  public:\n    /// token types for the parser\n    enum class token_type\n    {\n        uninitialized,    ///< indicating the scanner is uninitialized\n        literal_true,     ///< the `true` literal\n        literal_false,    ///< the `false` literal\n        literal_null,     ///< the `null` literal\n        value_string,     ///< a string -- use get_string() for actual value\n        value_unsigned,   ///< an unsigned integer -- use get_number_unsigned() for actual value\n        value_integer,    ///< a signed integer -- use get_number_integer() for actual value\n        value_float,      ///< a floating-point number -- use get_number_float() for actual value\n        begin_array,      ///< the character for array begin `[`\n        begin_object,     ///< the character for object begin `{`\n        end_array,        ///< the character for array end `]`\n        end_object,       ///< the character for object end `}`\n        
name_separator,   ///< the name separator `:`\n        value_separator,  ///< the value separator `,`\n        parse_error,      ///< indicating a parse error\n        end_of_input,     ///< indicating the end of the input buffer\n        literal_or_value  ///< a literal or the begin of a value (only for diagnostics)\n    };\n\n    /// return name of values of type token_type (only used for errors)\n    JSON_HEDLEY_RETURNS_NON_NULL\n    JSON_HEDLEY_CONST\n    static const char* token_type_name(const token_type t) noexcept\n    {\n        switch (t)\n        {\n            case token_type::uninitialized:\n                return \"<uninitialized>\";\n            case token_type::literal_true:\n                return \"true literal\";\n            case token_type::literal_false:\n                return \"false literal\";\n            case token_type::literal_null:\n                return \"null literal\";\n            case token_type::value_string:\n                return \"string literal\";\n            case token_type::value_unsigned:\n            case token_type::value_integer:\n            case token_type::value_float:\n                return \"number literal\";\n            case token_type::begin_array:\n                return \"'['\";\n            case token_type::begin_object:\n                return \"'{'\";\n            case token_type::end_array:\n                return \"']'\";\n            case token_type::end_object:\n                return \"'}'\";\n            case token_type::name_separator:\n                return \"':'\";\n            case token_type::value_separator:\n                return \"','\";\n            case token_type::parse_error:\n                return \"<parse error>\";\n            case token_type::end_of_input:\n                return \"end of input\";\n            case token_type::literal_or_value:\n                return \"'[', '{', or a literal\";\n            // LCOV_EXCL_START\n            default: // catch non-enum values\n     
           return \"unknown token\";\n                // LCOV_EXCL_STOP\n        }\n    }\n};\n/*!\n@brief lexical analysis\n\nThis class organizes the lexical analysis during JSON deserialization.\n*/\ntemplate<typename BasicJsonType, typename InputAdapterType>\nclass lexer : public lexer_base<BasicJsonType>\n{\n    using number_integer_t = typename BasicJsonType::number_integer_t;\n    using number_unsigned_t = typename BasicJsonType::number_unsigned_t;\n    using number_float_t = typename BasicJsonType::number_float_t;\n    using string_t = typename BasicJsonType::string_t;\n    using char_type = typename InputAdapterType::char_type;\n    using char_int_type = typename char_traits<char_type>::int_type;\n\n  public:\n    using token_type = typename lexer_base<BasicJsonType>::token_type;\n\n    explicit lexer(InputAdapterType&& adapter, bool ignore_comments_ = false) noexcept\n        : ia(std::move(adapter))\n        , ignore_comments(ignore_comments_)\n        , decimal_point_char(static_cast<char_int_type>(get_decimal_point()))\n    {}\n\n    // delete because of pointer members\n    lexer(const lexer&) = delete;\n    lexer(lexer&&) = default; // NOLINT(hicpp-noexcept-move,performance-noexcept-move-constructor)\n    lexer& operator=(lexer&) = delete;\n    lexer& operator=(lexer&&) = default; // NOLINT(hicpp-noexcept-move,performance-noexcept-move-constructor)\n    ~lexer() = default;\n\n  private:\n    /////////////////////\n    // locales\n    /////////////////////\n\n    /// return the locale-dependent decimal point\n    JSON_HEDLEY_PURE\n    static char get_decimal_point() noexcept\n    {\n        const auto* loc = localeconv();\n        JSON_ASSERT(loc != nullptr);\n        return (loc->decimal_point == nullptr) ? '.' 
: *(loc->decimal_point);\n    }\n\n    /////////////////////\n    // scan functions\n    /////////////////////\n\n    /*!\n    @brief get codepoint from 4 hex characters following `\\u`\n\n    For input \"\\u c1 c2 c3 c4\" the codepoint is:\n      (c1 * 0x1000) + (c2 * 0x0100) + (c3 * 0x0010) + c4\n    = (c1 << 12) + (c2 << 8) + (c3 << 4) + (c4 << 0)\n\n    Furthermore, the possible characters '0'..'9', 'A'..'F', and 'a'..'f'\n    must be converted to the integers 0x0..0x9, 0xA..0xF, 0xA..0xF, resp. The\n    conversion is done by subtracting the offset (0x30, 0x37, and 0x57)\n    between the ASCII value of the character and the desired integer value.\n\n    @return codepoint (0x0000..0xFFFF) or -1 in case of an error (e.g. EOF or\n            non-hex character)\n    */\n    int get_codepoint()\n    {\n        // this function only makes sense after reading `\\u`\n        JSON_ASSERT(current == 'u');\n        int codepoint = 0;\n\n        const auto factors = { 12u, 8u, 4u, 0u };\n        for (const auto factor : factors)\n        {\n            get();\n\n            if (current >= '0' && current <= '9')\n            {\n                codepoint += static_cast<int>((static_cast<unsigned int>(current) - 0x30u) << factor);\n            }\n            else if (current >= 'A' && current <= 'F')\n            {\n                codepoint += static_cast<int>((static_cast<unsigned int>(current) - 0x37u) << factor);\n            }\n            else if (current >= 'a' && current <= 'f')\n            {\n                codepoint += static_cast<int>((static_cast<unsigned int>(current) - 0x57u) << factor);\n            }\n            else\n            {\n                return -1;\n            }\n        }\n\n        JSON_ASSERT(0x0000 <= codepoint && codepoint <= 0xFFFF);\n        return codepoint;\n    }\n\n    /*!\n    @brief check if the next byte(s) are inside a given range\n\n    Adds the current byte and, for each passed range, reads a new byte and\n    checks if it is 
inside the range. If a violation was detected, set up an\n    error message and return false. Otherwise, return true.\n\n    @param[in] ranges  list of integers; interpreted as list of pairs of\n                       inclusive lower and upper bound, respectively\n\n    @pre The passed list @a ranges must have 2, 4, or 6 elements; that is,\n         1, 2, or 3 pairs. This precondition is enforced by an assertion.\n\n    @return true if and only if no range violation was detected\n    */\n    bool next_byte_in_range(std::initializer_list<char_int_type> ranges)\n    {\n        JSON_ASSERT(ranges.size() == 2 || ranges.size() == 4 || ranges.size() == 6);\n        add(current);\n\n        for (auto range = ranges.begin(); range != ranges.end(); ++range)\n        {\n            get();\n            if (JSON_HEDLEY_LIKELY(*range <= current && current <= *(++range))) // NOLINT(bugprone-inc-dec-in-conditions)\n            {\n                add(current);\n            }\n            else\n            {\n                error_message = \"invalid string: ill-formed UTF-8 byte\";\n                return false;\n            }\n        }\n\n        return true;\n    }\n\n    /*!\n    @brief scan a string literal\n\n    This function scans a string according to Sect. 7 of RFC 8259. While\n    scanning, bytes are escaped and copied into buffer token_buffer. 
Then the\n    function returns successfully, token_buffer is *not* null-terminated (as it\n    may contain \\0 bytes), and token_buffer.size() is the number of bytes in the\n    string.\n\n    @return token_type::value_string if string could be successfully scanned,\n            token_type::parse_error otherwise\n\n    @note In case of errors, variable error_message contains a textual\n          description.\n    */\n    token_type scan_string()\n    {\n        // reset token_buffer (ignore opening quote)\n        reset();\n\n        // we entered the function by reading an open quote\n        JSON_ASSERT(current == '\\\"');\n\n        while (true)\n        {\n            // get next character\n            switch (get())\n            {\n                // end of file while parsing string\n                case char_traits<char_type>::eof():\n                {\n                    error_message = \"invalid string: missing closing quote\";\n                    return token_type::parse_error;\n                }\n\n                // closing quote\n                case '\\\"':\n                {\n                    return token_type::value_string;\n                }\n\n                // escapes\n                case '\\\\':\n                {\n                    switch (get())\n                    {\n                        // quotation mark\n                        case '\\\"':\n                            add('\\\"');\n                            break;\n                        // reverse solidus\n                        case '\\\\':\n                            add('\\\\');\n                            break;\n                        // solidus\n                        case '/':\n                            add('/');\n                            break;\n                        // backspace\n                        case 'b':\n                            add('\\b');\n                            break;\n                        // form feed\n                        
case 'f':\n                            add('\\f');\n                            break;\n                        // line feed\n                        case 'n':\n                            add('\\n');\n                            break;\n                        // carriage return\n                        case 'r':\n                            add('\\r');\n                            break;\n                        // tab\n                        case 't':\n                            add('\\t');\n                            break;\n\n                        // unicode escapes\n                        case 'u':\n                        {\n                            const int codepoint1 = get_codepoint();\n                            int codepoint = codepoint1; // start with codepoint1\n\n                            if (JSON_HEDLEY_UNLIKELY(codepoint1 == -1))\n                            {\n                                error_message = \"invalid string: '\\\\u' must be followed by 4 hex digits\";\n                                return token_type::parse_error;\n                            }\n\n                            // check if code point is a high surrogate\n                            if (0xD800 <= codepoint1 && codepoint1 <= 0xDBFF)\n                            {\n                                // expect next \\uxxxx entry\n                                if (JSON_HEDLEY_LIKELY(get() == '\\\\' && get() == 'u'))\n                                {\n                                    const int codepoint2 = get_codepoint();\n\n                                    if (JSON_HEDLEY_UNLIKELY(codepoint2 == -1))\n                                    {\n                                        error_message = \"invalid string: '\\\\u' must be followed by 4 hex digits\";\n                                        return token_type::parse_error;\n                                    }\n\n                                    // check if codepoint2 is a low surrogate\n        
                            if (JSON_HEDLEY_LIKELY(0xDC00 <= codepoint2 && codepoint2 <= 0xDFFF))\n                                    {\n                                        // overwrite codepoint\n                                        codepoint = static_cast<int>(\n                                                        // high surrogate occupies the most significant 22 bits\n                                                        (static_cast<unsigned int>(codepoint1) << 10u)\n                                                        // low surrogate occupies the least significant 15 bits\n                                                        + static_cast<unsigned int>(codepoint2)\n                                                        // there is still the 0xD800, 0xDC00 and 0x10000 noise\n                                                        // in the result, so we have to subtract with:\n                                                        // (0xD800 << 10) + DC00 - 0x10000 = 0x35FDC00\n                                                        - 0x35FDC00u);\n                                    }\n                                    else\n                                    {\n                                        error_message = \"invalid string: surrogate U+D800..U+DBFF must be followed by U+DC00..U+DFFF\";\n                                        return token_type::parse_error;\n                                    }\n                                }\n                                else\n                                {\n                                    error_message = \"invalid string: surrogate U+D800..U+DBFF must be followed by U+DC00..U+DFFF\";\n                                    return token_type::parse_error;\n                                }\n                            }\n                            else\n                            {\n                                if (JSON_HEDLEY_UNLIKELY(0xDC00 <= codepoint1 && codepoint1 <= 
0xDFFF))\n                                {\n                                    error_message = \"invalid string: surrogate U+DC00..U+DFFF must follow U+D800..U+DBFF\";\n                                    return token_type::parse_error;\n                                }\n                            }\n\n                            // result of the above calculation yields a proper codepoint\n                            JSON_ASSERT(0x00 <= codepoint && codepoint <= 0x10FFFF);\n\n                            // translate codepoint into bytes\n                            if (codepoint < 0x80)\n                            {\n                                // 1-byte characters: 0xxxxxxx (ASCII)\n                                add(static_cast<char_int_type>(codepoint));\n                            }\n                            else if (codepoint <= 0x7FF)\n                            {\n                                // 2-byte characters: 110xxxxx 10xxxxxx\n                                add(static_cast<char_int_type>(0xC0u | (static_cast<unsigned int>(codepoint) >> 6u)));\n                                add(static_cast<char_int_type>(0x80u | (static_cast<unsigned int>(codepoint) & 0x3Fu)));\n                            }\n                            else if (codepoint <= 0xFFFF)\n                            {\n                                // 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx\n                                add(static_cast<char_int_type>(0xE0u | (static_cast<unsigned int>(codepoint) >> 12u)));\n                                add(static_cast<char_int_type>(0x80u | ((static_cast<unsigned int>(codepoint) >> 6u) & 0x3Fu)));\n                                add(static_cast<char_int_type>(0x80u | (static_cast<unsigned int>(codepoint) & 0x3Fu)));\n                            }\n                            else\n                            {\n                                // 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx\n                      
          add(static_cast<char_int_type>(0xF0u | (static_cast<unsigned int>(codepoint) >> 18u)));\n                                add(static_cast<char_int_type>(0x80u | ((static_cast<unsigned int>(codepoint) >> 12u) & 0x3Fu)));\n                                add(static_cast<char_int_type>(0x80u | ((static_cast<unsigned int>(codepoint) >> 6u) & 0x3Fu)));\n                                add(static_cast<char_int_type>(0x80u | (static_cast<unsigned int>(codepoint) & 0x3Fu)));\n                            }\n\n                            break;\n                        }\n\n                        // other characters after escape\n                        default:\n                            error_message = \"invalid string: forbidden character after backslash\";\n                            return token_type::parse_error;\n                    }\n\n                    break;\n                }\n\n                // invalid control characters\n                case 0x00:\n                {\n                    error_message = \"invalid string: control character U+0000 (NUL) must be escaped to \\\\u0000\";\n                    return token_type::parse_error;\n                }\n\n                case 0x01:\n                {\n                    error_message = \"invalid string: control character U+0001 (SOH) must be escaped to \\\\u0001\";\n                    return token_type::parse_error;\n                }\n\n                case 0x02:\n                {\n                    error_message = \"invalid string: control character U+0002 (STX) must be escaped to \\\\u0002\";\n                    return token_type::parse_error;\n                }\n\n                case 0x03:\n                {\n                    error_message = \"invalid string: control character U+0003 (ETX) must be escaped to \\\\u0003\";\n                    return token_type::parse_error;\n                }\n\n                case 0x04:\n                {\n                    error_message = 
\"invalid string: control character U+0004 (EOT) must be escaped to \\\\u0004\";\n                    return token_type::parse_error;\n                }\n\n                case 0x05:\n                {\n                    error_message = \"invalid string: control character U+0005 (ENQ) must be escaped to \\\\u0005\";\n                    return token_type::parse_error;\n                }\n\n                case 0x06:\n                {\n                    error_message = \"invalid string: control character U+0006 (ACK) must be escaped to \\\\u0006\";\n                    return token_type::parse_error;\n                }\n\n                case 0x07:\n                {\n                    error_message = \"invalid string: control character U+0007 (BEL) must be escaped to \\\\u0007\";\n                    return token_type::parse_error;\n                }\n\n                case 0x08:\n                {\n                    error_message = \"invalid string: control character U+0008 (BS) must be escaped to \\\\u0008 or \\\\b\";\n                    return token_type::parse_error;\n                }\n\n                case 0x09:\n                {\n                    error_message = \"invalid string: control character U+0009 (HT) must be escaped to \\\\u0009 or \\\\t\";\n                    return token_type::parse_error;\n                }\n\n                case 0x0A:\n                {\n                    error_message = \"invalid string: control character U+000A (LF) must be escaped to \\\\u000A or \\\\n\";\n                    return token_type::parse_error;\n                }\n\n                case 0x0B:\n                {\n                    error_message = \"invalid string: control character U+000B (VT) must be escaped to \\\\u000B\";\n                    return token_type::parse_error;\n                }\n\n                case 0x0C:\n                {\n                    error_message = \"invalid string: control character U+000C (FF) must be escaped 
to \\\\u000C or \\\\f\";\n                    return token_type::parse_error;\n                }\n\n                case 0x0D:\n                {\n                    error_message = \"invalid string: control character U+000D (CR) must be escaped to \\\\u000D or \\\\r\";\n                    return token_type::parse_error;\n                }\n\n                case 0x0E:\n                {\n                    error_message = \"invalid string: control character U+000E (SO) must be escaped to \\\\u000E\";\n                    return token_type::parse_error;\n                }\n\n                case 0x0F:\n                {\n                    error_message = \"invalid string: control character U+000F (SI) must be escaped to \\\\u000F\";\n                    return token_type::parse_error;\n                }\n\n                case 0x10:\n                {\n                    error_message = \"invalid string: control character U+0010 (DLE) must be escaped to \\\\u0010\";\n                    return token_type::parse_error;\n                }\n\n                case 0x11:\n                {\n                    error_message = \"invalid string: control character U+0011 (DC1) must be escaped to \\\\u0011\";\n                    return token_type::parse_error;\n                }\n\n                case 0x12:\n                {\n                    error_message = \"invalid string: control character U+0012 (DC2) must be escaped to \\\\u0012\";\n                    return token_type::parse_error;\n                }\n\n                case 0x13:\n                {\n                    error_message = \"invalid string: control character U+0013 (DC3) must be escaped to \\\\u0013\";\n                    return token_type::parse_error;\n                }\n\n                case 0x14:\n                {\n                    error_message = \"invalid string: control character U+0014 (DC4) must be escaped to \\\\u0014\";\n                    return token_type::parse_error;\n   
             }\n\n                case 0x15:\n                {\n                    error_message = \"invalid string: control character U+0015 (NAK) must be escaped to \\\\u0015\";\n                    return token_type::parse_error;\n                }\n\n                case 0x16:\n                {\n                    error_message = \"invalid string: control character U+0016 (SYN) must be escaped to \\\\u0016\";\n                    return token_type::parse_error;\n                }\n\n                case 0x17:\n                {\n                    error_message = \"invalid string: control character U+0017 (ETB) must be escaped to \\\\u0017\";\n                    return token_type::parse_error;\n                }\n\n                case 0x18:\n                {\n                    error_message = \"invalid string: control character U+0018 (CAN) must be escaped to \\\\u0018\";\n                    return token_type::parse_error;\n                }\n\n                case 0x19:\n                {\n                    error_message = \"invalid string: control character U+0019 (EM) must be escaped to \\\\u0019\";\n                    return token_type::parse_error;\n                }\n\n                case 0x1A:\n                {\n                    error_message = \"invalid string: control character U+001A (SUB) must be escaped to \\\\u001A\";\n                    return token_type::parse_error;\n                }\n\n                case 0x1B:\n                {\n                    error_message = \"invalid string: control character U+001B (ESC) must be escaped to \\\\u001B\";\n                    return token_type::parse_error;\n                }\n\n                case 0x1C:\n                {\n                    error_message = \"invalid string: control character U+001C (FS) must be escaped to \\\\u001C\";\n                    return token_type::parse_error;\n                }\n\n                case 0x1D:\n                {\n                    
error_message = \"invalid string: control character U+001D (GS) must be escaped to \\\\u001D\";\n                    return token_type::parse_error;\n                }\n\n                case 0x1E:\n                {\n                    error_message = \"invalid string: control character U+001E (RS) must be escaped to \\\\u001E\";\n                    return token_type::parse_error;\n                }\n\n                case 0x1F:\n                {\n                    error_message = \"invalid string: control character U+001F (US) must be escaped to \\\\u001F\";\n                    return token_type::parse_error;\n                }\n\n                // U+0020..U+007F (except U+0022 (quote) and U+005C (backslash))\n                case 0x20:\n                case 0x21:\n                case 0x23:\n                case 0x24:\n                case 0x25:\n                case 0x26:\n                case 0x27:\n                case 0x28:\n                case 0x29:\n                case 0x2A:\n                case 0x2B:\n                case 0x2C:\n                case 0x2D:\n                case 0x2E:\n                case 0x2F:\n                case 0x30:\n                case 0x31:\n                case 0x32:\n                case 0x33:\n                case 0x34:\n                case 0x35:\n                case 0x36:\n                case 0x37:\n                case 0x38:\n                case 0x39:\n                case 0x3A:\n                case 0x3B:\n                case 0x3C:\n                case 0x3D:\n                case 0x3E:\n                case 0x3F:\n                case 0x40:\n                case 0x41:\n                case 0x42:\n                case 0x43:\n                case 0x44:\n                case 0x45:\n                case 0x46:\n                case 0x47:\n                case 0x48:\n                case 0x49:\n                case 0x4A:\n                case 0x4B:\n                case 0x4C:\n                case 0x4D:\n           
     case 0x4E:\n                case 0x4F:\n                case 0x50:\n                case 0x51:\n                case 0x52:\n                case 0x53:\n                case 0x54:\n                case 0x55:\n                case 0x56:\n                case 0x57:\n                case 0x58:\n                case 0x59:\n                case 0x5A:\n                case 0x5B:\n                case 0x5D:\n                case 0x5E:\n                case 0x5F:\n                case 0x60:\n                case 0x61:\n                case 0x62:\n                case 0x63:\n                case 0x64:\n                case 0x65:\n                case 0x66:\n                case 0x67:\n                case 0x68:\n                case 0x69:\n                case 0x6A:\n                case 0x6B:\n                case 0x6C:\n                case 0x6D:\n                case 0x6E:\n                case 0x6F:\n                case 0x70:\n                case 0x71:\n                case 0x72:\n                case 0x73:\n                case 0x74:\n                case 0x75:\n                case 0x76:\n                case 0x77:\n                case 0x78:\n                case 0x79:\n                case 0x7A:\n                case 0x7B:\n                case 0x7C:\n                case 0x7D:\n                case 0x7E:\n                case 0x7F:\n                {\n                    add(current);\n                    break;\n                }\n\n                // U+0080..U+07FF: bytes C2..DF 80..BF\n                case 0xC2:\n                case 0xC3:\n                case 0xC4:\n                case 0xC5:\n                case 0xC6:\n                case 0xC7:\n                case 0xC8:\n                case 0xC9:\n                case 0xCA:\n                case 0xCB:\n                case 0xCC:\n                case 0xCD:\n                case 0xCE:\n                case 0xCF:\n                case 0xD0:\n                case 0xD1:\n                case 0xD2:\n    
            case 0xD3:\n                case 0xD4:\n                case 0xD5:\n                case 0xD6:\n                case 0xD7:\n                case 0xD8:\n                case 0xD9:\n                case 0xDA:\n                case 0xDB:\n                case 0xDC:\n                case 0xDD:\n                case 0xDE:\n                case 0xDF:\n                {\n                    if (JSON_HEDLEY_UNLIKELY(!next_byte_in_range({0x80, 0xBF})))\n                    {\n                        return token_type::parse_error;\n                    }\n                    break;\n                }\n\n                // U+0800..U+0FFF: bytes E0 A0..BF 80..BF\n                case 0xE0:\n                {\n                    if (JSON_HEDLEY_UNLIKELY(!(next_byte_in_range({0xA0, 0xBF, 0x80, 0xBF}))))\n                    {\n                        return token_type::parse_error;\n                    }\n                    break;\n                }\n\n                // U+1000..U+CFFF: bytes E1..EC 80..BF 80..BF\n                // U+E000..U+FFFF: bytes EE..EF 80..BF 80..BF\n                case 0xE1:\n                case 0xE2:\n                case 0xE3:\n                case 0xE4:\n                case 0xE5:\n                case 0xE6:\n                case 0xE7:\n                case 0xE8:\n                case 0xE9:\n                case 0xEA:\n                case 0xEB:\n                case 0xEC:\n                case 0xEE:\n                case 0xEF:\n                {\n                    if (JSON_HEDLEY_UNLIKELY(!(next_byte_in_range({0x80, 0xBF, 0x80, 0xBF}))))\n                    {\n                        return token_type::parse_error;\n                    }\n                    break;\n                }\n\n                // U+D000..U+D7FF: bytes ED 80..9F 80..BF\n                case 0xED:\n                {\n                    if (JSON_HEDLEY_UNLIKELY(!(next_byte_in_range({0x80, 0x9F, 0x80, 0xBF}))))\n                    {\n                       
 return token_type::parse_error;\n                    }\n                    break;\n                }\n\n                // U+10000..U+3FFFF F0 90..BF 80..BF 80..BF\n                case 0xF0:\n                {\n                    if (JSON_HEDLEY_UNLIKELY(!(next_byte_in_range({0x90, 0xBF, 0x80, 0xBF, 0x80, 0xBF}))))\n                    {\n                        return token_type::parse_error;\n                    }\n                    break;\n                }\n\n                // U+40000..U+FFFFF F1..F3 80..BF 80..BF 80..BF\n                case 0xF1:\n                case 0xF2:\n                case 0xF3:\n                {\n                    if (JSON_HEDLEY_UNLIKELY(!(next_byte_in_range({0x80, 0xBF, 0x80, 0xBF, 0x80, 0xBF}))))\n                    {\n                        return token_type::parse_error;\n                    }\n                    break;\n                }\n\n                // U+100000..U+10FFFF F4 80..8F 80..BF 80..BF\n                case 0xF4:\n                {\n                    if (JSON_HEDLEY_UNLIKELY(!(next_byte_in_range({0x80, 0x8F, 0x80, 0xBF, 0x80, 0xBF}))))\n                    {\n                        return token_type::parse_error;\n                    }\n                    break;\n                }\n\n                // remaining bytes (80..C1 and F5..FF) are ill-formed\n                default:\n                {\n                    error_message = \"invalid string: ill-formed UTF-8 byte\";\n                    return token_type::parse_error;\n                }\n            }\n        }\n    }\n\n    /*!\n     * @brief scan a comment\n     * @return whether comment could be scanned successfully\n     */\n    bool scan_comment()\n    {\n        switch (get())\n        {\n            // single-line comments skip input until a newline or EOF is read\n            case '/':\n            {\n                while (true)\n                {\n                    switch (get())\n                    {\n                       
 case '\\n':\n                        case '\\r':\n                        case char_traits<char_type>::eof():\n                        case '\\0':\n                            return true;\n\n                        default:\n                            break;\n                    }\n                }\n            }\n\n            // multi-line comments skip input until */ is read\n            case '*':\n            {\n                while (true)\n                {\n                    switch (get())\n                    {\n                        case char_traits<char_type>::eof():\n                        case '\\0':\n                        {\n                            error_message = \"invalid comment; missing closing '*/'\";\n                            return false;\n                        }\n\n                        case '*':\n                        {\n                            switch (get())\n                            {\n                                case '/':\n                                    return true;\n\n                                default:\n                                {\n                                    unget();\n                                    continue;\n                                }\n                            }\n                        }\n\n                        default:\n                            continue;\n                    }\n                }\n            }\n\n            // unexpected character after reading '/'\n            default:\n            {\n                error_message = \"invalid comment; expecting '/' or '*' after '/'\";\n                return false;\n            }\n        }\n    }\n\n    JSON_HEDLEY_NON_NULL(2)\n    static void strtof(float& f, const char* str, char** endptr) noexcept\n    {\n        f = std::strtof(str, endptr);\n    }\n\n    JSON_HEDLEY_NON_NULL(2)\n    static void strtof(double& f, const char* str, char** endptr) noexcept\n    {\n        f = std::strtod(str, 
endptr);\n    }\n\n    JSON_HEDLEY_NON_NULL(2)\n    static void strtof(long double& f, const char* str, char** endptr) noexcept\n    {\n        f = std::strtold(str, endptr);\n    }\n\n    /*!\n    @brief scan a number literal\n\n    This function scans a string according to Sect. 6 of RFC 8259.\n\n    The function is realized with a deterministic finite state machine derived\n    from the grammar described in RFC 8259. Starting in state \"init\", the\n    input is read and used to determine the next state. Only state \"done\"\n    accepts the number. State \"error\" is a trap state to model errors. In the\n    table below, \"anything\" means any character but the ones listed before.\n\n    state    | 0        | 1-9      | e E      | +       | -       | .        | anything\n    ---------|----------|----------|----------|---------|---------|----------|-----------\n    init     | zero     | any1     | [error]  | [error] | minus   | [error]  | [error]\n    minus    | zero     | any1     | [error]  | [error] | [error] | [error]  | [error]\n    zero     | done     | done     | exponent | done    | done    | decimal1 | done\n    any1     | any1     | any1     | exponent | done    | done    | decimal1 | done\n    decimal1 | decimal2 | decimal2 | [error]  | [error] | [error] | [error]  | [error]\n    decimal2 | decimal2 | decimal2 | exponent | done    | done    | done     | done\n    exponent | any2     | any2     | [error]  | sign    | sign    | [error]  | [error]\n    sign     | any2     | any2     | [error]  | [error] | [error] | [error]  | [error]\n    any2     | any2     | any2     | done     | done    | done    | done     | done\n\n    The state machine is realized with one label per state (prefixed with\n    \"scan_number_\") and `goto` statements between them. The state machine\n    contains cycles, but any cycle can be left when EOF is read. Therefore,\n    the function is guaranteed to terminate.\n\n    During scanning, the read bytes are stored in token_buffer. 
This string is\n    then converted to a signed integer, an unsigned integer, or a\n    floating-point number.\n\n    @return token_type::value_unsigned, token_type::value_integer, or\n            token_type::value_float if number could be successfully scanned,\n            token_type::parse_error otherwise\n\n    @note The scanner is independent of the current locale. Internally, the\n          locale's decimal point is used instead of `.` to work with the\n          locale-dependent converters.\n    */\n    token_type scan_number()  // lgtm [cpp/use-of-goto] `goto` is used in this function to implement the number-parsing state machine described above. By design, any finite input will eventually reach the \"done\" state or return token_type::parse_error. In each intermediate state, 1 byte of the input is appended to the token_buffer vector, and only the already initialized variables token_buffer, number_type, and error_message are manipulated.\n    {\n        // reset token_buffer to store the number's bytes\n        reset();\n\n        // the type of the parsed number; initially set to unsigned; will be\n        // changed if minus sign, decimal point or exponent is read\n        token_type number_type = token_type::value_unsigned;\n\n        // state (init): we just found out we need to scan a number\n        switch (current)\n        {\n            case '-':\n            {\n                add(current);\n                goto scan_number_minus;\n            }\n\n            case '0':\n            {\n                add(current);\n                goto scan_number_zero;\n            }\n\n            case '1':\n            case '2':\n            case '3':\n            case '4':\n            case '5':\n            case '6':\n            case '7':\n            case '8':\n            case '9':\n            {\n                add(current);\n                goto scan_number_any1;\n            }\n\n            // all other characters are rejected outside scan_number()\n   
         default:            // LCOV_EXCL_LINE\n                JSON_ASSERT(false); // NOLINT(cert-dcl03-c,hicpp-static-assert,misc-static-assert) LCOV_EXCL_LINE\n        }\n\nscan_number_minus:\n        // state: we just parsed a leading minus sign\n        number_type = token_type::value_integer;\n        switch (get())\n        {\n            case '0':\n            {\n                add(current);\n                goto scan_number_zero;\n            }\n\n            case '1':\n            case '2':\n            case '3':\n            case '4':\n            case '5':\n            case '6':\n            case '7':\n            case '8':\n            case '9':\n            {\n                add(current);\n                goto scan_number_any1;\n            }\n\n            default:\n            {\n                error_message = \"invalid number; expected digit after '-'\";\n                return token_type::parse_error;\n            }\n        }\n\nscan_number_zero:\n        // state: we just parsed a zero (maybe with a leading minus sign)\n        switch (get())\n        {\n            case '.':\n            {\n                add(decimal_point_char);\n                goto scan_number_decimal1;\n            }\n\n            case 'e':\n            case 'E':\n            {\n                add(current);\n                goto scan_number_exponent;\n            }\n\n            default:\n                goto scan_number_done;\n        }\n\nscan_number_any1:\n        // state: we just parsed a number 0-9 (maybe with a leading minus sign)\n        switch (get())\n        {\n            case '0':\n            case '1':\n            case '2':\n            case '3':\n            case '4':\n            case '5':\n            case '6':\n            case '7':\n            case '8':\n            case '9':\n            {\n                add(current);\n                goto scan_number_any1;\n            }\n\n            case '.':\n            {\n                
add(decimal_point_char);\n                goto scan_number_decimal1;\n            }\n\n            case 'e':\n            case 'E':\n            {\n                add(current);\n                goto scan_number_exponent;\n            }\n\n            default:\n                goto scan_number_done;\n        }\n\nscan_number_decimal1:\n        // state: we just parsed a decimal point\n        number_type = token_type::value_float;\n        switch (get())\n        {\n            case '0':\n            case '1':\n            case '2':\n            case '3':\n            case '4':\n            case '5':\n            case '6':\n            case '7':\n            case '8':\n            case '9':\n            {\n                add(current);\n                goto scan_number_decimal2;\n            }\n\n            default:\n            {\n                error_message = \"invalid number; expected digit after '.'\";\n                return token_type::parse_error;\n            }\n        }\n\nscan_number_decimal2:\n        // we just parsed at least one number after a decimal point\n        switch (get())\n        {\n            case '0':\n            case '1':\n            case '2':\n            case '3':\n            case '4':\n            case '5':\n            case '6':\n            case '7':\n            case '8':\n            case '9':\n            {\n                add(current);\n                goto scan_number_decimal2;\n            }\n\n            case 'e':\n            case 'E':\n            {\n                add(current);\n                goto scan_number_exponent;\n            }\n\n            default:\n                goto scan_number_done;\n        }\n\nscan_number_exponent:\n        // we just parsed an exponent\n        number_type = token_type::value_float;\n        switch (get())\n        {\n            case '+':\n            case '-':\n            {\n                add(current);\n                goto scan_number_sign;\n            }\n\n            
case '0':\n            case '1':\n            case '2':\n            case '3':\n            case '4':\n            case '5':\n            case '6':\n            case '7':\n            case '8':\n            case '9':\n            {\n                add(current);\n                goto scan_number_any2;\n            }\n\n            default:\n            {\n                error_message =\n                    \"invalid number; expected '+', '-', or digit after exponent\";\n                return token_type::parse_error;\n            }\n        }\n\nscan_number_sign:\n        // we just parsed an exponent sign\n        switch (get())\n        {\n            case '0':\n            case '1':\n            case '2':\n            case '3':\n            case '4':\n            case '5':\n            case '6':\n            case '7':\n            case '8':\n            case '9':\n            {\n                add(current);\n                goto scan_number_any2;\n            }\n\n            default:\n            {\n                error_message = \"invalid number; expected digit after exponent sign\";\n                return token_type::parse_error;\n            }\n        }\n\nscan_number_any2:\n        // we just parsed a number after the exponent or exponent sign\n        switch (get())\n        {\n            case '0':\n            case '1':\n            case '2':\n            case '3':\n            case '4':\n            case '5':\n            case '6':\n            case '7':\n            case '8':\n            case '9':\n            {\n                add(current);\n                goto scan_number_any2;\n            }\n\n            default:\n                goto scan_number_done;\n        }\n\nscan_number_done:\n        // unget the character after the number (we only read it to know that\n        // we are done scanning a number)\n        unget();\n\n        char* endptr = nullptr; // NOLINT(cppcoreguidelines-pro-type-vararg,hicpp-vararg)\n        errno = 0;\n\n     
   // try to parse integers first and fall back to floats\n        if (number_type == token_type::value_unsigned)\n        {\n            const auto x = std::strtoull(token_buffer.data(), &endptr, 10);\n\n            // we checked the number format before\n            JSON_ASSERT(endptr == token_buffer.data() + token_buffer.size());\n\n            if (errno == 0)\n            {\n                value_unsigned = static_cast<number_unsigned_t>(x);\n                if (value_unsigned == x)\n                {\n                    return token_type::value_unsigned;\n                }\n            }\n        }\n        else if (number_type == token_type::value_integer)\n        {\n            const auto x = std::strtoll(token_buffer.data(), &endptr, 10);\n\n            // we checked the number format before\n            JSON_ASSERT(endptr == token_buffer.data() + token_buffer.size());\n\n            if (errno == 0)\n            {\n                value_integer = static_cast<number_integer_t>(x);\n                if (value_integer == x)\n                {\n                    return token_type::value_integer;\n                }\n            }\n        }\n\n        // this code is reached if we parse a floating-point number or if an\n        // integer conversion above failed\n        strtof(value_float, token_buffer.data(), &endptr);\n\n        // we checked the number format before\n        JSON_ASSERT(endptr == token_buffer.data() + token_buffer.size());\n\n        return token_type::value_float;\n    }\n\n    /*!\n    @param[in] literal_text  the literal text to expect\n    @param[in] length        the length of the passed literal text\n    @param[in] return_type   the token type to return on success\n    */\n    JSON_HEDLEY_NON_NULL(2)\n    token_type scan_literal(const char_type* literal_text, const std::size_t length,\n                            token_type return_type)\n    {\n        JSON_ASSERT(char_traits<char_type>::to_char_type(current) == literal_text[0]);\n  
      for (std::size_t i = 1; i < length; ++i)\n        {\n            if (JSON_HEDLEY_UNLIKELY(char_traits<char_type>::to_char_type(get()) != literal_text[i]))\n            {\n                error_message = \"invalid literal\";\n                return token_type::parse_error;\n            }\n        }\n        return return_type;\n    }\n\n    /////////////////////\n    // input management\n    /////////////////////\n\n    /// reset token_buffer; current character is beginning of token\n    void reset() noexcept\n    {\n        token_buffer.clear();\n        token_string.clear();\n        token_string.push_back(char_traits<char_type>::to_char_type(current));\n    }\n\n    /*\n    @brief get next character from the input\n\n    This function provides the interface to the used input adapter. It does\n    not throw in case the input reached EOF, but returns a\n    `char_traits<char>::eof()` in that case.  Stores the scanned characters\n    for use in error messages.\n\n    @return character read from the input\n    */\n    char_int_type get()\n    {\n        ++position.chars_read_total;\n        ++position.chars_read_current_line;\n\n        if (next_unget)\n        {\n            // just reset the next_unget variable and work with current\n            next_unget = false;\n        }\n        else\n        {\n            current = ia.get_character();\n        }\n\n        if (JSON_HEDLEY_LIKELY(current != char_traits<char_type>::eof()))\n        {\n            token_string.push_back(char_traits<char_type>::to_char_type(current));\n        }\n\n        if (current == '\\n')\n        {\n            ++position.lines_read;\n            position.chars_read_current_line = 0;\n        }\n\n        return current;\n    }\n\n    /*!\n    @brief unget current character (read it again on next get)\n\n    We implement unget by setting variable next_unget to true. 
The input is not\n    changed - we just simulate ungetting by modifying chars_read_total,\n    chars_read_current_line, and token_string. The next call to get() will\n    behave as if the unget character is read again.\n    */\n    void unget()\n    {\n        next_unget = true;\n\n        --position.chars_read_total;\n\n        // in case we \"unget\" a newline, we have to also decrement the lines_read\n        if (position.chars_read_current_line == 0)\n        {\n            if (position.lines_read > 0)\n            {\n                --position.lines_read;\n            }\n        }\n        else\n        {\n            --position.chars_read_current_line;\n        }\n\n        if (JSON_HEDLEY_LIKELY(current != char_traits<char_type>::eof()))\n        {\n            JSON_ASSERT(!token_string.empty());\n            token_string.pop_back();\n        }\n    }\n\n    /// add a character to token_buffer\n    void add(char_int_type c)\n    {\n        token_buffer.push_back(static_cast<typename string_t::value_type>(c));\n    }\n\n  public:\n    /////////////////////\n    // value getters\n    /////////////////////\n\n    /// return integer value\n    constexpr number_integer_t get_number_integer() const noexcept\n    {\n        return value_integer;\n    }\n\n    /// return unsigned integer value\n    constexpr number_unsigned_t get_number_unsigned() const noexcept\n    {\n        return value_unsigned;\n    }\n\n    /// return floating-point value\n    constexpr number_float_t get_number_float() const noexcept\n    {\n        return value_float;\n    }\n\n    /// return current string value (implicitly resets the token; useful only once)\n    string_t& get_string()\n    {\n        return token_buffer;\n    }\n\n    /////////////////////\n    // diagnostics\n    /////////////////////\n\n    /// return position of last read token\n    constexpr position_t get_position() const noexcept\n    {\n        return position;\n    }\n\n    /// return the last read token (for 
errors only).  Will never contain EOF\n    /// (an arbitrary value that is not a valid char value, often -1), because\n    /// 255 may legitimately occur.  May contain NUL, which should be escaped.\n    std::string get_token_string() const\n    {\n        // escape control characters\n        std::string result;\n        for (const auto c : token_string)\n        {\n            if (static_cast<unsigned char>(c) <= '\\x1F')\n            {\n                // escape control characters\n                std::array<char, 9> cs{{}};\n                static_cast<void>((std::snprintf)(cs.data(), cs.size(), \"<U+%.4X>\", static_cast<unsigned char>(c))); // NOLINT(cppcoreguidelines-pro-type-vararg,hicpp-vararg)\n                result += cs.data();\n            }\n            else\n            {\n                // add character as is\n                result.push_back(static_cast<std::string::value_type>(c));\n            }\n        }\n\n        return result;\n    }\n\n    /// return syntax error message\n    JSON_HEDLEY_RETURNS_NON_NULL\n    constexpr const char* get_error_message() const noexcept\n    {\n        return error_message;\n    }\n\n    /////////////////////\n    // actual scanner\n    /////////////////////\n\n    /*!\n    @brief skip the UTF-8 byte order mark\n    @return true iff there is no BOM or the correct BOM has been skipped\n    */\n    bool skip_bom()\n    {\n        if (get() == 0xEF)\n        {\n            // check if we completely parse the BOM\n            return get() == 0xBB && get() == 0xBF;\n        }\n\n        // the first character is not the beginning of the BOM; unget it to\n        // process it later\n        unget();\n        return true;\n    }\n\n    void skip_whitespace()\n    {\n        do\n        {\n            get();\n        }\n        while (current == ' ' || current == '\\t' || current == '\\n' || current == '\\r');\n    }\n\n    token_type scan()\n    {\n        // initially, skip the BOM\n        if 
(position.chars_read_total == 0 && !skip_bom())\n        {\n            error_message = \"invalid BOM; must be 0xEF 0xBB 0xBF if given\";\n            return token_type::parse_error;\n        }\n\n        // read next character and ignore whitespace\n        skip_whitespace();\n\n        // ignore comments\n        while (ignore_comments && current == '/')\n        {\n            if (!scan_comment())\n            {\n                return token_type::parse_error;\n            }\n\n            // skip following whitespace\n            skip_whitespace();\n        }\n\n        switch (current)\n        {\n            // structural characters\n            case '[':\n                return token_type::begin_array;\n            case ']':\n                return token_type::end_array;\n            case '{':\n                return token_type::begin_object;\n            case '}':\n                return token_type::end_object;\n            case ':':\n                return token_type::name_separator;\n            case ',':\n                return token_type::value_separator;\n\n            // literals\n            case 't':\n            {\n                std::array<char_type, 4> true_literal = {{static_cast<char_type>('t'), static_cast<char_type>('r'), static_cast<char_type>('u'), static_cast<char_type>('e')}};\n                return scan_literal(true_literal.data(), true_literal.size(), token_type::literal_true);\n            }\n            case 'f':\n            {\n                std::array<char_type, 5> false_literal = {{static_cast<char_type>('f'), static_cast<char_type>('a'), static_cast<char_type>('l'), static_cast<char_type>('s'), static_cast<char_type>('e')}};\n                return scan_literal(false_literal.data(), false_literal.size(), token_type::literal_false);\n            }\n            case 'n':\n            {\n                std::array<char_type, 4> null_literal = {{static_cast<char_type>('n'), static_cast<char_type>('u'), static_cast<char_type>('l'), 
static_cast<char_type>('l')}};\n                return scan_literal(null_literal.data(), null_literal.size(), token_type::literal_null);\n            }\n\n            // string\n            case '\\\"':\n                return scan_string();\n\n            // number\n            case '-':\n            case '0':\n            case '1':\n            case '2':\n            case '3':\n            case '4':\n            case '5':\n            case '6':\n            case '7':\n            case '8':\n            case '9':\n                return scan_number();\n\n            // end of input (the null byte is needed when parsing from\n            // string literals)\n            case '\\0':\n            case char_traits<char_type>::eof():\n                return token_type::end_of_input;\n\n            // error\n            default:\n                error_message = \"invalid literal\";\n                return token_type::parse_error;\n        }\n    }\n\n  private:\n    /// input adapter\n    InputAdapterType ia;\n\n    /// whether comments should be ignored (true) or signaled as errors (false)\n    const bool ignore_comments = false;\n\n    /// the current character\n    char_int_type current = char_traits<char_type>::eof();\n\n    /// whether the next get() call should just return current\n    bool next_unget = false;\n\n    /// the start position of the current token\n    position_t position {};\n\n    /// raw input token string (for error messages)\n    std::vector<char_type> token_string {};\n\n    /// buffer for variable-length tokens (numbers, strings)\n    string_t token_buffer {};\n\n    /// a description of occurred lexer errors\n    const char* error_message = \"\";\n\n    // number values\n    number_integer_t value_integer = 0;\n    number_unsigned_t value_unsigned = 0;\n    number_float_t value_float = 0;\n\n    /// the decimal point\n    const char_int_type decimal_point_char = '.';\n};\n\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include 
<nlohmann/detail/macro_scope.hpp>\n\n// #include <nlohmann/detail/meta/is_sax.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#include <cstdint> // size_t\n#include <utility> // declval\n#include <string> // string\n\n// #include <nlohmann/detail/abi_macros.hpp>\n\n// #include <nlohmann/detail/meta/detected.hpp>\n\n// #include <nlohmann/detail/meta/type_traits.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\n\ntemplate<typename T>\nusing null_function_t = decltype(std::declval<T&>().null());\n\ntemplate<typename T>\nusing boolean_function_t =\n    decltype(std::declval<T&>().boolean(std::declval<bool>()));\n\ntemplate<typename T, typename Integer>\nusing number_integer_function_t =\n    decltype(std::declval<T&>().number_integer(std::declval<Integer>()));\n\ntemplate<typename T, typename Unsigned>\nusing number_unsigned_function_t =\n    decltype(std::declval<T&>().number_unsigned(std::declval<Unsigned>()));\n\ntemplate<typename T, typename Float, typename String>\nusing number_float_function_t = decltype(std::declval<T&>().number_float(\n                                    std::declval<Float>(), std::declval<const String&>()));\n\ntemplate<typename T, typename String>\nusing string_function_t =\n    decltype(std::declval<T&>().string(std::declval<String&>()));\n\ntemplate<typename T, typename Binary>\nusing binary_function_t =\n    decltype(std::declval<T&>().binary(std::declval<Binary&>()));\n\ntemplate<typename T>\nusing start_object_function_t =\n    decltype(std::declval<T&>().start_object(std::declval<std::size_t>()));\n\ntemplate<typename T, typename String>\nusing key_function_t =\n    decltype(std::declval<T&>().key(std::declval<String&>()));\n\ntemplate<typename T>\nusing 
end_object_function_t = decltype(std::declval<T&>().end_object());\n\ntemplate<typename T>\nusing start_array_function_t =\n    decltype(std::declval<T&>().start_array(std::declval<std::size_t>()));\n\ntemplate<typename T>\nusing end_array_function_t = decltype(std::declval<T&>().end_array());\n\ntemplate<typename T, typename Exception>\nusing parse_error_function_t = decltype(std::declval<T&>().parse_error(\n        std::declval<std::size_t>(), std::declval<const std::string&>(),\n        std::declval<const Exception&>()));\n\ntemplate<typename SAX, typename BasicJsonType>\nstruct is_sax\n{\n  private:\n    static_assert(is_basic_json<BasicJsonType>::value,\n                  \"BasicJsonType must be of type basic_json<...>\");\n\n    using number_integer_t = typename BasicJsonType::number_integer_t;\n    using number_unsigned_t = typename BasicJsonType::number_unsigned_t;\n    using number_float_t = typename BasicJsonType::number_float_t;\n    using string_t = typename BasicJsonType::string_t;\n    using binary_t = typename BasicJsonType::binary_t;\n    using exception_t = typename BasicJsonType::exception;\n\n  public:\n    static constexpr bool value =\n        is_detected_exact<bool, null_function_t, SAX>::value &&\n        is_detected_exact<bool, boolean_function_t, SAX>::value &&\n        is_detected_exact<bool, number_integer_function_t, SAX, number_integer_t>::value &&\n        is_detected_exact<bool, number_unsigned_function_t, SAX, number_unsigned_t>::value &&\n        is_detected_exact<bool, number_float_function_t, SAX, number_float_t, string_t>::value &&\n        is_detected_exact<bool, string_function_t, SAX, string_t>::value &&\n        is_detected_exact<bool, binary_function_t, SAX, binary_t>::value &&\n        is_detected_exact<bool, start_object_function_t, SAX>::value &&\n        is_detected_exact<bool, key_function_t, SAX, string_t>::value &&\n        is_detected_exact<bool, end_object_function_t, SAX>::value &&\n        is_detected_exact<bool, 
start_array_function_t, SAX>::value &&\n        is_detected_exact<bool, end_array_function_t, SAX>::value &&\n        is_detected_exact<bool, parse_error_function_t, SAX, exception_t>::value;\n};\n\ntemplate<typename SAX, typename BasicJsonType>\nstruct is_sax_static_asserts\n{\n  private:\n    static_assert(is_basic_json<BasicJsonType>::value,\n                  \"BasicJsonType must be of type basic_json<...>\");\n\n    using number_integer_t = typename BasicJsonType::number_integer_t;\n    using number_unsigned_t = typename BasicJsonType::number_unsigned_t;\n    using number_float_t = typename BasicJsonType::number_float_t;\n    using string_t = typename BasicJsonType::string_t;\n    using binary_t = typename BasicJsonType::binary_t;\n    using exception_t = typename BasicJsonType::exception;\n\n  public:\n    static_assert(is_detected_exact<bool, null_function_t, SAX>::value,\n                  \"Missing/invalid function: bool null()\");\n    static_assert(is_detected_exact<bool, boolean_function_t, SAX>::value,\n                  \"Missing/invalid function: bool boolean(bool)\");\n    static_assert(is_detected_exact<bool, boolean_function_t, SAX>::value,\n                  \"Missing/invalid function: bool boolean(bool)\");\n    static_assert(\n        is_detected_exact<bool, number_integer_function_t, SAX,\n        number_integer_t>::value,\n        \"Missing/invalid function: bool number_integer(number_integer_t)\");\n    static_assert(\n        is_detected_exact<bool, number_unsigned_function_t, SAX,\n        number_unsigned_t>::value,\n        \"Missing/invalid function: bool number_unsigned(number_unsigned_t)\");\n    static_assert(is_detected_exact<bool, number_float_function_t, SAX,\n                  number_float_t, string_t>::value,\n                  \"Missing/invalid function: bool number_float(number_float_t, const string_t&)\");\n    static_assert(\n        is_detected_exact<bool, string_function_t, SAX, string_t>::value,\n        \"Missing/invalid 
function: bool string(string_t&)\");\n    static_assert(\n        is_detected_exact<bool, binary_function_t, SAX, binary_t>::value,\n        \"Missing/invalid function: bool binary(binary_t&)\");\n    static_assert(is_detected_exact<bool, start_object_function_t, SAX>::value,\n                  \"Missing/invalid function: bool start_object(std::size_t)\");\n    static_assert(is_detected_exact<bool, key_function_t, SAX, string_t>::value,\n                  \"Missing/invalid function: bool key(string_t&)\");\n    static_assert(is_detected_exact<bool, end_object_function_t, SAX>::value,\n                  \"Missing/invalid function: bool end_object()\");\n    static_assert(is_detected_exact<bool, start_array_function_t, SAX>::value,\n                  \"Missing/invalid function: bool start_array(std::size_t)\");\n    static_assert(is_detected_exact<bool, end_array_function_t, SAX>::value,\n                  \"Missing/invalid function: bool end_array()\");\n    static_assert(\n        is_detected_exact<bool, parse_error_function_t, SAX, exception_t>::value,\n        \"Missing/invalid function: bool parse_error(std::size_t, const \"\n        \"std::string&, const exception&)\");\n};\n\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include <nlohmann/detail/meta/type_traits.hpp>\n\n// #include <nlohmann/detail/string_concat.hpp>\n\n// #include <nlohmann/detail/value_t.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\n\n/// how to treat CBOR tags\nenum class cbor_tag_handler_t\n{\n    error,   ///< throw a parse_error exception in case of a tag\n    ignore,  ///< ignore tags\n    store    ///< store tags as binary type\n};\n\n/*!\n@brief determine system byte order\n\n@return true if and only if system's byte order is little endian\n\n@note from https://stackoverflow.com/a/1001328/266378\n*/\nstatic inline bool little_endianness(int num = 1) noexcept\n{\n    return *reinterpret_cast<char*>(&num) == 1;\n}\n\n///////////////////\n// binary reader 
//\n///////////////////\n\n/*!\n@brief deserialization of CBOR, MessagePack, and UBJSON values\n*/\ntemplate<typename BasicJsonType, typename InputAdapterType, typename SAX = json_sax_dom_parser<BasicJsonType>>\nclass binary_reader\n{\n    using number_integer_t = typename BasicJsonType::number_integer_t;\n    using number_unsigned_t = typename BasicJsonType::number_unsigned_t;\n    using number_float_t = typename BasicJsonType::number_float_t;\n    using string_t = typename BasicJsonType::string_t;\n    using binary_t = typename BasicJsonType::binary_t;\n    using json_sax_t = SAX;\n    using char_type = typename InputAdapterType::char_type;\n    using char_int_type = typename char_traits<char_type>::int_type;\n\n  public:\n    /*!\n    @brief create a binary reader\n\n    @param[in] adapter  input adapter to read from\n    */\n    explicit binary_reader(InputAdapterType&& adapter, const input_format_t format = input_format_t::json) noexcept : ia(std::move(adapter)), input_format(format)\n    {\n        (void)detail::is_sax_static_asserts<SAX, BasicJsonType> {};\n    }\n\n    // make class move-only\n    binary_reader(const binary_reader&) = delete;\n    binary_reader(binary_reader&&) = default; // NOLINT(hicpp-noexcept-move,performance-noexcept-move-constructor)\n    binary_reader& operator=(const binary_reader&) = delete;\n    binary_reader& operator=(binary_reader&&) = default; // NOLINT(hicpp-noexcept-move,performance-noexcept-move-constructor)\n    ~binary_reader() = default;\n\n    /*!\n    @param[in] format  the binary format to parse\n    @param[in] sax_    a SAX event processor\n    @param[in] strict  whether to expect the input to be consumed completely\n    @param[in] tag_handler  how to treat CBOR tags\n\n    @return whether parsing was successful\n    */\n    JSON_HEDLEY_NON_NULL(3)\n    bool sax_parse(const input_format_t format,\n                   json_sax_t* sax_,\n                   const bool strict = true,\n                   const 
cbor_tag_handler_t tag_handler = cbor_tag_handler_t::error)\n    {\n        sax = sax_;\n        bool result = false;\n\n        switch (format)\n        {\n            case input_format_t::bson:\n                result = parse_bson_internal();\n                break;\n\n            case input_format_t::cbor:\n                result = parse_cbor_internal(true, tag_handler);\n                break;\n\n            case input_format_t::msgpack:\n                result = parse_msgpack_internal();\n                break;\n\n            case input_format_t::ubjson:\n            case input_format_t::bjdata:\n                result = parse_ubjson_internal();\n                break;\n\n            case input_format_t::json: // LCOV_EXCL_LINE\n            default:            // LCOV_EXCL_LINE\n                JSON_ASSERT(false); // NOLINT(cert-dcl03-c,hicpp-static-assert,misc-static-assert) LCOV_EXCL_LINE\n        }\n\n        // strict mode: next byte must be EOF\n        if (result && strict)\n        {\n            if (input_format == input_format_t::ubjson || input_format == input_format_t::bjdata)\n            {\n                get_ignore_noop();\n            }\n            else\n            {\n                get();\n            }\n\n            if (JSON_HEDLEY_UNLIKELY(current != char_traits<char_type>::eof()))\n            {\n                return sax->parse_error(chars_read, get_token_string(), parse_error::create(110, chars_read,\n                                        exception_message(input_format, concat(\"expected end of input; last byte: 0x\", get_token_string()), \"value\"), nullptr));\n            }\n        }\n\n        return result;\n    }\n\n  private:\n    //////////\n    // BSON //\n    //////////\n\n    /*!\n    @brief Reads in a BSON-object and passes it to the SAX-parser.\n    @return whether a valid BSON-value was passed to the SAX parser\n    */\n    bool parse_bson_internal()\n    {\n        std::int32_t document_size{};\n        
get_number<std::int32_t, true>(input_format_t::bson, document_size);\n\n        if (JSON_HEDLEY_UNLIKELY(!sax->start_object(static_cast<std::size_t>(-1))))\n        {\n            return false;\n        }\n\n        if (JSON_HEDLEY_UNLIKELY(!parse_bson_element_list(/*is_array*/false)))\n        {\n            return false;\n        }\n\n        return sax->end_object();\n    }\n\n    /*!\n    @brief Parses a C-style string from the BSON input.\n    @param[in,out] result  A reference to the string variable where the read\n                            string is to be stored.\n    @return `true` if the \\x00-byte indicating the end of the string was\n             encountered before the EOF; `false` indicates an unexpected EOF.\n    */\n    bool get_bson_cstr(string_t& result)\n    {\n        auto out = std::back_inserter(result);\n        while (true)\n        {\n            get();\n            if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::bson, \"cstring\")))\n            {\n                return false;\n            }\n            if (current == 0x00)\n            {\n                return true;\n            }\n            *out++ = static_cast<typename string_t::value_type>(current);\n        }\n    }\n\n    /*!\n    @brief Parses a zero-terminated string of length @a len from the BSON\n           input.\n    @param[in] len  The length (including the zero-byte at the end) of the\n                    string to be read.\n    @param[in,out] result  A reference to the string variable where the read\n                            string is to be stored.\n    @tparam NumberType The type of the length @a len\n    @pre len >= 1\n    @return `true` if the string was successfully parsed\n    */\n    template<typename NumberType>\n    bool get_bson_string(const NumberType len, string_t& result)\n    {\n        if (JSON_HEDLEY_UNLIKELY(len < 1))\n        {\n            auto last_token = get_token_string();\n            return sax->parse_error(chars_read, last_token, 
parse_error::create(112, chars_read,\n                                    exception_message(input_format_t::bson, concat(\"string length must be at least 1, is \", std::to_string(len)), \"string\"), nullptr));\n        }\n\n        return get_string(input_format_t::bson, len - static_cast<NumberType>(1), result) && get() != char_traits<char_type>::eof();\n    }\n\n    /*!\n    @brief Parses a byte array input of length @a len from the BSON input.\n    @param[in] len  The length of the byte array to be read.\n    @param[in,out] result  A reference to the binary variable where the read\n                            array is to be stored.\n    @tparam NumberType The type of the length @a len\n    @pre len >= 0\n    @return `true` if the byte array was successfully parsed\n    */\n    template<typename NumberType>\n    bool get_bson_binary(const NumberType len, binary_t& result)\n    {\n        if (JSON_HEDLEY_UNLIKELY(len < 0))\n        {\n            auto last_token = get_token_string();\n            return sax->parse_error(chars_read, last_token, parse_error::create(112, chars_read,\n                                    exception_message(input_format_t::bson, concat(\"byte array length cannot be negative, is \", std::to_string(len)), \"binary\"), nullptr));\n        }\n\n        // All BSON binary values have a subtype\n        std::uint8_t subtype{};\n        get_number<std::uint8_t>(input_format_t::bson, subtype);\n        result.set_subtype(subtype);\n\n        return get_binary(input_format_t::bson, len, result);\n    }\n\n    /*!\n    @brief Read a BSON document element of the given @a element_type.\n    @param[in] element_type The BSON element type, c.f. http://bsonspec.org/spec.html\n    @param[in] element_type_parse_position The position in the input stream,\n               where the `element_type` was read.\n    @warning Not all BSON element types are supported yet. 
An unsupported\n             @a element_type will give rise to a parse_error.114:\n             Unsupported BSON record type 0x...\n    @return whether a valid BSON-object/array was passed to the SAX parser\n    */\n    bool parse_bson_element_internal(const char_int_type element_type,\n                                     const std::size_t element_type_parse_position)\n    {\n        switch (element_type)\n        {\n            case 0x01: // double\n            {\n                double number{};\n                return get_number<double, true>(input_format_t::bson, number) && sax->number_float(static_cast<number_float_t>(number), \"\");\n            }\n\n            case 0x02: // string\n            {\n                std::int32_t len{};\n                string_t value;\n                return get_number<std::int32_t, true>(input_format_t::bson, len) && get_bson_string(len, value) && sax->string(value);\n            }\n\n            case 0x03: // object\n            {\n                return parse_bson_internal();\n            }\n\n            case 0x04: // array\n            {\n                return parse_bson_array();\n            }\n\n            case 0x05: // binary\n            {\n                std::int32_t len{};\n                binary_t value;\n                return get_number<std::int32_t, true>(input_format_t::bson, len) && get_bson_binary(len, value) && sax->binary(value);\n            }\n\n            case 0x08: // boolean\n            {\n                return sax->boolean(get() != 0);\n            }\n\n            case 0x0A: // null\n            {\n                return sax->null();\n            }\n\n            case 0x10: // int32\n            {\n                std::int32_t value{};\n                return get_number<std::int32_t, true>(input_format_t::bson, value) && sax->number_integer(value);\n            }\n\n            case 0x12: // int64\n            {\n                std::int64_t value{};\n                return 
get_number<std::int64_t, true>(input_format_t::bson, value) && sax->number_integer(value);\n            }\n\n            default: // anything else not supported (yet)\n            {\n                std::array<char, 3> cr{{}};\n                static_cast<void>((std::snprintf)(cr.data(), cr.size(), \"%.2hhX\", static_cast<unsigned char>(element_type))); // NOLINT(cppcoreguidelines-pro-type-vararg,hicpp-vararg)\n                const std::string cr_str{cr.data()};\n                return sax->parse_error(element_type_parse_position, cr_str,\n                                        parse_error::create(114, element_type_parse_position, concat(\"Unsupported BSON record type 0x\", cr_str), nullptr));\n            }\n        }\n    }\n\n    /*!\n    @brief Read a BSON element list (as specified in the BSON-spec)\n\n    The same binary layout is used for objects and arrays, hence it must be\n    indicated with the argument @a is_array which one is expected\n    (true --> array, false --> object).\n\n    @param[in] is_array Determines if the element list being read is to be\n                        treated as an object (@a is_array == false), or as an\n                        array (@a is_array == true).\n    @return whether a valid BSON-object/array was passed to the SAX parser\n    */\n    bool parse_bson_element_list(const bool is_array)\n    {\n        string_t key;\n\n        while (auto element_type = get())\n        {\n            if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::bson, \"element list\")))\n            {\n                return false;\n            }\n\n            const std::size_t element_type_parse_position = chars_read;\n            if (JSON_HEDLEY_UNLIKELY(!get_bson_cstr(key)))\n            {\n                return false;\n            }\n\n            if (!is_array && !sax->key(key))\n            {\n                return false;\n            }\n\n            if (JSON_HEDLEY_UNLIKELY(!parse_bson_element_internal(element_type, 
element_type_parse_position)))\n            {\n                return false;\n            }\n\n            // get_bson_cstr only appends\n            key.clear();\n        }\n\n        return true;\n    }\n\n    /*!\n    @brief Reads an array from the BSON input and passes it to the SAX-parser.\n    @return whether a valid BSON-array was passed to the SAX parser\n    */\n    bool parse_bson_array()\n    {\n        std::int32_t document_size{};\n        get_number<std::int32_t, true>(input_format_t::bson, document_size);\n\n        if (JSON_HEDLEY_UNLIKELY(!sax->start_array(static_cast<std::size_t>(-1))))\n        {\n            return false;\n        }\n\n        if (JSON_HEDLEY_UNLIKELY(!parse_bson_element_list(/*is_array*/true)))\n        {\n            return false;\n        }\n\n        return sax->end_array();\n    }\n\n    //////////\n    // CBOR //\n    //////////\n\n    /*!\n    @param[in] get_char  whether a new character should be retrieved from the\n                         input (true) or whether the last read character should\n                         be considered instead (false)\n    @param[in] tag_handler how CBOR tags should be treated\n\n    @return whether a valid CBOR value was passed to the SAX parser\n    */\n    bool parse_cbor_internal(const bool get_char,\n                             const cbor_tag_handler_t tag_handler)\n    {\n        switch (get_char ? 
get() : current)\n        {\n            // EOF\n            case char_traits<char_type>::eof():\n                return unexpect_eof(input_format_t::cbor, \"value\");\n\n            // Integer 0x00..0x17 (0..23)\n            case 0x00:\n            case 0x01:\n            case 0x02:\n            case 0x03:\n            case 0x04:\n            case 0x05:\n            case 0x06:\n            case 0x07:\n            case 0x08:\n            case 0x09:\n            case 0x0A:\n            case 0x0B:\n            case 0x0C:\n            case 0x0D:\n            case 0x0E:\n            case 0x0F:\n            case 0x10:\n            case 0x11:\n            case 0x12:\n            case 0x13:\n            case 0x14:\n            case 0x15:\n            case 0x16:\n            case 0x17:\n                return sax->number_unsigned(static_cast<number_unsigned_t>(current));\n\n            case 0x18: // Unsigned integer (one-byte uint8_t follows)\n            {\n                std::uint8_t number{};\n                return get_number(input_format_t::cbor, number) && sax->number_unsigned(number);\n            }\n\n            case 0x19: // Unsigned integer (two-byte uint16_t follows)\n            {\n                std::uint16_t number{};\n                return get_number(input_format_t::cbor, number) && sax->number_unsigned(number);\n            }\n\n            case 0x1A: // Unsigned integer (four-byte uint32_t follows)\n            {\n                std::uint32_t number{};\n                return get_number(input_format_t::cbor, number) && sax->number_unsigned(number);\n            }\n\n            case 0x1B: // Unsigned integer (eight-byte uint64_t follows)\n            {\n                std::uint64_t number{};\n                return get_number(input_format_t::cbor, number) && sax->number_unsigned(number);\n            }\n\n            // Negative integer -1-0x00..-1-0x17 (-1..-24)\n            case 0x20:\n            case 0x21:\n            case 0x22:\n            
case 0x23:\n            case 0x24:\n            case 0x25:\n            case 0x26:\n            case 0x27:\n            case 0x28:\n            case 0x29:\n            case 0x2A:\n            case 0x2B:\n            case 0x2C:\n            case 0x2D:\n            case 0x2E:\n            case 0x2F:\n            case 0x30:\n            case 0x31:\n            case 0x32:\n            case 0x33:\n            case 0x34:\n            case 0x35:\n            case 0x36:\n            case 0x37:\n                return sax->number_integer(static_cast<std::int8_t>(0x20 - 1 - current));\n\n            case 0x38: // Negative integer (one-byte uint8_t follows)\n            {\n                std::uint8_t number{};\n                return get_number(input_format_t::cbor, number) && sax->number_integer(static_cast<number_integer_t>(-1) - number);\n            }\n\n            case 0x39: // Negative integer -1-n (two-byte uint16_t follows)\n            {\n                std::uint16_t number{};\n                return get_number(input_format_t::cbor, number) && sax->number_integer(static_cast<number_integer_t>(-1) - number);\n            }\n\n            case 0x3A: // Negative integer -1-n (four-byte uint32_t follows)\n            {\n                std::uint32_t number{};\n                return get_number(input_format_t::cbor, number) && sax->number_integer(static_cast<number_integer_t>(-1) - number);\n            }\n\n            case 0x3B: // Negative integer -1-n (eight-byte uint64_t follows)\n            {\n                std::uint64_t number{};\n                return get_number(input_format_t::cbor, number) && sax->number_integer(static_cast<number_integer_t>(-1)\n                        - static_cast<number_integer_t>(number));\n            }\n\n            // Binary data (0x00..0x17 bytes follow)\n            case 0x40:\n            case 0x41:\n            case 0x42:\n            case 0x43:\n            case 0x44:\n            case 0x45:\n            case 0x46:\n         
   case 0x47:\n            case 0x48:\n            case 0x49:\n            case 0x4A:\n            case 0x4B:\n            case 0x4C:\n            case 0x4D:\n            case 0x4E:\n            case 0x4F:\n            case 0x50:\n            case 0x51:\n            case 0x52:\n            case 0x53:\n            case 0x54:\n            case 0x55:\n            case 0x56:\n            case 0x57:\n            case 0x58: // Binary data (one-byte uint8_t for n follows)\n            case 0x59: // Binary data (two-byte uint16_t for n follow)\n            case 0x5A: // Binary data (four-byte uint32_t for n follow)\n            case 0x5B: // Binary data (eight-byte uint64_t for n follow)\n            case 0x5F: // Binary data (indefinite length)\n            {\n                binary_t b;\n                return get_cbor_binary(b) && sax->binary(b);\n            }\n\n            // UTF-8 string (0x00..0x17 bytes follow)\n            case 0x60:\n            case 0x61:\n            case 0x62:\n            case 0x63:\n            case 0x64:\n            case 0x65:\n            case 0x66:\n            case 0x67:\n            case 0x68:\n            case 0x69:\n            case 0x6A:\n            case 0x6B:\n            case 0x6C:\n            case 0x6D:\n            case 0x6E:\n            case 0x6F:\n            case 0x70:\n            case 0x71:\n            case 0x72:\n            case 0x73:\n            case 0x74:\n            case 0x75:\n            case 0x76:\n            case 0x77:\n            case 0x78: // UTF-8 string (one-byte uint8_t for n follows)\n            case 0x79: // UTF-8 string (two-byte uint16_t for n follow)\n            case 0x7A: // UTF-8 string (four-byte uint32_t for n follow)\n            case 0x7B: // UTF-8 string (eight-byte uint64_t for n follow)\n            case 0x7F: // UTF-8 string (indefinite length)\n            {\n                string_t s;\n                return get_cbor_string(s) && sax->string(s);\n            }\n\n            // 
array (0x00..0x17 data items follow)\n            case 0x80:\n            case 0x81:\n            case 0x82:\n            case 0x83:\n            case 0x84:\n            case 0x85:\n            case 0x86:\n            case 0x87:\n            case 0x88:\n            case 0x89:\n            case 0x8A:\n            case 0x8B:\n            case 0x8C:\n            case 0x8D:\n            case 0x8E:\n            case 0x8F:\n            case 0x90:\n            case 0x91:\n            case 0x92:\n            case 0x93:\n            case 0x94:\n            case 0x95:\n            case 0x96:\n            case 0x97:\n                return get_cbor_array(\n                           conditional_static_cast<std::size_t>(static_cast<unsigned int>(current) & 0x1Fu), tag_handler);\n\n            case 0x98: // array (one-byte uint8_t for n follows)\n            {\n                std::uint8_t len{};\n                return get_number(input_format_t::cbor, len) && get_cbor_array(static_cast<std::size_t>(len), tag_handler);\n            }\n\n            case 0x99: // array (two-byte uint16_t for n follow)\n            {\n                std::uint16_t len{};\n                return get_number(input_format_t::cbor, len) && get_cbor_array(static_cast<std::size_t>(len), tag_handler);\n            }\n\n            case 0x9A: // array (four-byte uint32_t for n follow)\n            {\n                std::uint32_t len{};\n                return get_number(input_format_t::cbor, len) && get_cbor_array(conditional_static_cast<std::size_t>(len), tag_handler);\n            }\n\n            case 0x9B: // array (eight-byte uint64_t for n follow)\n            {\n                std::uint64_t len{};\n                return get_number(input_format_t::cbor, len) && get_cbor_array(conditional_static_cast<std::size_t>(len), tag_handler);\n            }\n\n            case 0x9F: // array (indefinite length)\n                return get_cbor_array(static_cast<std::size_t>(-1), tag_handler);\n\n            
// map (0x00..0x17 pairs of data items follow)\n            case 0xA0:\n            case 0xA1:\n            case 0xA2:\n            case 0xA3:\n            case 0xA4:\n            case 0xA5:\n            case 0xA6:\n            case 0xA7:\n            case 0xA8:\n            case 0xA9:\n            case 0xAA:\n            case 0xAB:\n            case 0xAC:\n            case 0xAD:\n            case 0xAE:\n            case 0xAF:\n            case 0xB0:\n            case 0xB1:\n            case 0xB2:\n            case 0xB3:\n            case 0xB4:\n            case 0xB5:\n            case 0xB6:\n            case 0xB7:\n                return get_cbor_object(conditional_static_cast<std::size_t>(static_cast<unsigned int>(current) & 0x1Fu), tag_handler);\n\n            case 0xB8: // map (one-byte uint8_t for n follows)\n            {\n                std::uint8_t len{};\n                return get_number(input_format_t::cbor, len) && get_cbor_object(static_cast<std::size_t>(len), tag_handler);\n            }\n\n            case 0xB9: // map (two-byte uint16_t for n follow)\n            {\n                std::uint16_t len{};\n                return get_number(input_format_t::cbor, len) && get_cbor_object(static_cast<std::size_t>(len), tag_handler);\n            }\n\n            case 0xBA: // map (four-byte uint32_t for n follow)\n            {\n                std::uint32_t len{};\n                return get_number(input_format_t::cbor, len) && get_cbor_object(conditional_static_cast<std::size_t>(len), tag_handler);\n            }\n\n            case 0xBB: // map (eight-byte uint64_t for n follow)\n            {\n                std::uint64_t len{};\n                return get_number(input_format_t::cbor, len) && get_cbor_object(conditional_static_cast<std::size_t>(len), tag_handler);\n            }\n\n            case 0xBF: // map (indefinite length)\n                return get_cbor_object(static_cast<std::size_t>(-1), tag_handler);\n\n            case 0xC6: // tagged 
item\n            case 0xC7:\n            case 0xC8:\n            case 0xC9:\n            case 0xCA:\n            case 0xCB:\n            case 0xCC:\n            case 0xCD:\n            case 0xCE:\n            case 0xCF:\n            case 0xD0:\n            case 0xD1:\n            case 0xD2:\n            case 0xD3:\n            case 0xD4:\n            case 0xD8: // tagged item (1 byte follows)\n            case 0xD9: // tagged item (2 bytes follow)\n            case 0xDA: // tagged item (4 bytes follow)\n            case 0xDB: // tagged item (8 bytes follow)\n            {\n                switch (tag_handler)\n                {\n                    case cbor_tag_handler_t::error:\n                    {\n                        auto last_token = get_token_string();\n                        return sax->parse_error(chars_read, last_token, parse_error::create(112, chars_read,\n                                                exception_message(input_format_t::cbor, concat(\"invalid byte: 0x\", last_token), \"value\"), nullptr));\n                    }\n\n                    case cbor_tag_handler_t::ignore:\n                    {\n                        // ignore binary subtype\n                        switch (current)\n                        {\n                            case 0xD8:\n                            {\n                                std::uint8_t subtype_to_ignore{};\n                                get_number(input_format_t::cbor, subtype_to_ignore);\n                                break;\n                            }\n                            case 0xD9:\n                            {\n                                std::uint16_t subtype_to_ignore{};\n                                get_number(input_format_t::cbor, subtype_to_ignore);\n                                break;\n                            }\n                            case 0xDA:\n                            {\n                                std::uint32_t subtype_to_ignore{};\n        
                        get_number(input_format_t::cbor, subtype_to_ignore);\n                                break;\n                            }\n                            case 0xDB:\n                            {\n                                std::uint64_t subtype_to_ignore{};\n                                get_number(input_format_t::cbor, subtype_to_ignore);\n                                break;\n                            }\n                            default:\n                                break;\n                        }\n                        return parse_cbor_internal(true, tag_handler);\n                    }\n\n                    case cbor_tag_handler_t::store:\n                    {\n                        binary_t b;\n                        // use binary subtype and store in binary container\n                        switch (current)\n                        {\n                            case 0xD8:\n                            {\n                                std::uint8_t subtype{};\n                                get_number(input_format_t::cbor, subtype);\n                                b.set_subtype(detail::conditional_static_cast<typename binary_t::subtype_type>(subtype));\n                                break;\n                            }\n                            case 0xD9:\n                            {\n                                std::uint16_t subtype{};\n                                get_number(input_format_t::cbor, subtype);\n                                b.set_subtype(detail::conditional_static_cast<typename binary_t::subtype_type>(subtype));\n                                break;\n                            }\n                            case 0xDA:\n                            {\n                                std::uint32_t subtype{};\n                                get_number(input_format_t::cbor, subtype);\n                                b.set_subtype(detail::conditional_static_cast<typename 
binary_t::subtype_type>(subtype));\n                                break;\n                            }\n                            case 0xDB:\n                            {\n                                std::uint64_t subtype{};\n                                get_number(input_format_t::cbor, subtype);\n                                b.set_subtype(detail::conditional_static_cast<typename binary_t::subtype_type>(subtype));\n                                break;\n                            }\n                            default:\n                                return parse_cbor_internal(true, tag_handler);\n                        }\n                        get();\n                        return get_cbor_binary(b) && sax->binary(b);\n                    }\n\n                    default:                 // LCOV_EXCL_LINE\n                        JSON_ASSERT(false); // NOLINT(cert-dcl03-c,hicpp-static-assert,misc-static-assert) LCOV_EXCL_LINE\n                        return false;        // LCOV_EXCL_LINE\n                }\n            }\n\n            case 0xF4: // false\n                return sax->boolean(false);\n\n            case 0xF5: // true\n                return sax->boolean(true);\n\n            case 0xF6: // null\n                return sax->null();\n\n            case 0xF9: // Half-Precision Float (two-byte IEEE 754)\n            {\n                const auto byte1_raw = get();\n                if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::cbor, \"number\")))\n                {\n                    return false;\n                }\n                const auto byte2_raw = get();\n                if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::cbor, \"number\")))\n                {\n                    return false;\n                }\n\n                const auto byte1 = static_cast<unsigned char>(byte1_raw);\n                const auto byte2 = static_cast<unsigned char>(byte2_raw);\n\n                // code from RFC 7049, 
Appendix D, Figure 3:\n                // As half-precision floating-point numbers were only added\n                // to IEEE 754 in 2008, today's programming platforms often\n                // still only have limited support for them. It is very\n                // easy to include at least decoding support for them even\n                // without such support. An example of a small decoder for\n                // half-precision floating-point numbers in the C language\n                // is shown in Fig. 3.\n                const auto half = static_cast<unsigned int>((byte1 << 8u) + byte2);\n                const double val = [&half]\n                {\n                    const int exp = (half >> 10u) & 0x1Fu;\n                    const unsigned int mant = half & 0x3FFu;\n                    JSON_ASSERT(0 <= exp && exp <= 32);\n                    JSON_ASSERT(mant <= 1024);\n                    switch (exp)\n                    {\n                        case 0:\n                            return std::ldexp(mant, -24);\n                        case 31:\n                            return (mant == 0)\n                            ? std::numeric_limits<double>::infinity()\n                            : std::numeric_limits<double>::quiet_NaN();\n                        default:\n                            return std::ldexp(mant + 1024, exp - 25);\n                    }\n                }();\n                return sax->number_float((half & 0x8000u) != 0\n                                         ? 
static_cast<number_float_t>(-val)\n                                         : static_cast<number_float_t>(val), \"\");\n            }\n\n            case 0xFA: // Single-Precision Float (four-byte IEEE 754)\n            {\n                float number{};\n                return get_number(input_format_t::cbor, number) && sax->number_float(static_cast<number_float_t>(number), \"\");\n            }\n\n            case 0xFB: // Double-Precision Float (eight-byte IEEE 754)\n            {\n                double number{};\n                return get_number(input_format_t::cbor, number) && sax->number_float(static_cast<number_float_t>(number), \"\");\n            }\n\n            default: // anything else (0xFF is handled inside the other types)\n            {\n                auto last_token = get_token_string();\n                return sax->parse_error(chars_read, last_token, parse_error::create(112, chars_read,\n                                        exception_message(input_format_t::cbor, concat(\"invalid byte: 0x\", last_token), \"value\"), nullptr));\n            }\n        }\n    }\n\n    /*!\n    @brief reads a CBOR string\n\n    This function first reads starting bytes to determine the expected\n    string length and then copies this number of bytes into a string.\n    Additionally, CBOR's strings with indefinite lengths are supported.\n\n    @param[out] result  created string\n\n    @return whether string creation completed\n    */\n    bool get_cbor_string(string_t& result)\n    {\n        if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::cbor, \"string\")))\n        {\n            return false;\n        }\n\n        switch (current)\n        {\n            // UTF-8 string (0x00..0x17 bytes follow)\n            case 0x60:\n            case 0x61:\n            case 0x62:\n            case 0x63:\n            case 0x64:\n            case 0x65:\n            case 0x66:\n            case 0x67:\n            case 0x68:\n            case 0x69:\n            case 
0x6A:\n            case 0x6B:\n            case 0x6C:\n            case 0x6D:\n            case 0x6E:\n            case 0x6F:\n            case 0x70:\n            case 0x71:\n            case 0x72:\n            case 0x73:\n            case 0x74:\n            case 0x75:\n            case 0x76:\n            case 0x77:\n            {\n                return get_string(input_format_t::cbor, static_cast<unsigned int>(current) & 0x1Fu, result);\n            }\n\n            case 0x78: // UTF-8 string (one-byte uint8_t for n follows)\n            {\n                std::uint8_t len{};\n                return get_number(input_format_t::cbor, len) && get_string(input_format_t::cbor, len, result);\n            }\n\n            case 0x79: // UTF-8 string (two-byte uint16_t for n follow)\n            {\n                std::uint16_t len{};\n                return get_number(input_format_t::cbor, len) && get_string(input_format_t::cbor, len, result);\n            }\n\n            case 0x7A: // UTF-8 string (four-byte uint32_t for n follow)\n            {\n                std::uint32_t len{};\n                return get_number(input_format_t::cbor, len) && get_string(input_format_t::cbor, len, result);\n            }\n\n            case 0x7B: // UTF-8 string (eight-byte uint64_t for n follow)\n            {\n                std::uint64_t len{};\n                return get_number(input_format_t::cbor, len) && get_string(input_format_t::cbor, len, result);\n            }\n\n            case 0x7F: // UTF-8 string (indefinite length)\n            {\n                while (get() != 0xFF)\n                {\n                    string_t chunk;\n                    if (!get_cbor_string(chunk))\n                    {\n                        return false;\n                    }\n                    result.append(chunk);\n                }\n                return true;\n            }\n\n            default:\n            {\n                auto last_token = get_token_string();\n           
     return sax->parse_error(chars_read, last_token, parse_error::create(113, chars_read,\n                                        exception_message(input_format_t::cbor, concat(\"expected length specification (0x60-0x7B) or indefinite string type (0x7F); last byte: 0x\", last_token), \"string\"), nullptr));\n            }\n        }\n    }\n\n    /*!\n    @brief reads a CBOR byte array\n\n    This function first reads starting bytes to determine the expected\n    byte array length and then copies this number of bytes into the byte array.\n    Additionally, CBOR's byte arrays with indefinite lengths are supported.\n\n    @param[out] result  created byte array\n\n    @return whether byte array creation completed\n    */\n    bool get_cbor_binary(binary_t& result)\n    {\n        if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::cbor, \"binary\")))\n        {\n            return false;\n        }\n\n        switch (current)\n        {\n            // Binary data (0x00..0x17 bytes follow)\n            case 0x40:\n            case 0x41:\n            case 0x42:\n            case 0x43:\n            case 0x44:\n            case 0x45:\n            case 0x46:\n            case 0x47:\n            case 0x48:\n            case 0x49:\n            case 0x4A:\n            case 0x4B:\n            case 0x4C:\n            case 0x4D:\n            case 0x4E:\n            case 0x4F:\n            case 0x50:\n            case 0x51:\n            case 0x52:\n            case 0x53:\n            case 0x54:\n            case 0x55:\n            case 0x56:\n            case 0x57:\n            {\n                return get_binary(input_format_t::cbor, static_cast<unsigned int>(current) & 0x1Fu, result);\n            }\n\n            case 0x58: // Binary data (one-byte uint8_t for n follows)\n            {\n                std::uint8_t len{};\n                return get_number(input_format_t::cbor, len) &&\n                       get_binary(input_format_t::cbor, len, result);\n            
}\n\n            case 0x59: // Binary data (two-byte uint16_t for n follow)\n            {\n                std::uint16_t len{};\n                return get_number(input_format_t::cbor, len) &&\n                       get_binary(input_format_t::cbor, len, result);\n            }\n\n            case 0x5A: // Binary data (four-byte uint32_t for n follow)\n            {\n                std::uint32_t len{};\n                return get_number(input_format_t::cbor, len) &&\n                       get_binary(input_format_t::cbor, len, result);\n            }\n\n            case 0x5B: // Binary data (eight-byte uint64_t for n follow)\n            {\n                std::uint64_t len{};\n                return get_number(input_format_t::cbor, len) &&\n                       get_binary(input_format_t::cbor, len, result);\n            }\n\n            case 0x5F: // Binary data (indefinite length)\n            {\n                while (get() != 0xFF)\n                {\n                    binary_t chunk;\n                    if (!get_cbor_binary(chunk))\n                    {\n                        return false;\n                    }\n                    result.insert(result.end(), chunk.begin(), chunk.end());\n                }\n                return true;\n            }\n\n            default:\n            {\n                auto last_token = get_token_string();\n                return sax->parse_error(chars_read, last_token, parse_error::create(113, chars_read,\n                                        exception_message(input_format_t::cbor, concat(\"expected length specification (0x40-0x5B) or indefinite binary array type (0x5F); last byte: 0x\", last_token), \"binary\"), nullptr));\n            }\n        }\n    }\n\n    /*!\n    @param[in] len  the length of the array or static_cast<std::size_t>(-1) for an\n                    array of indefinite size\n    @param[in] tag_handler how CBOR tags should be treated\n    @return whether array creation completed\n    */\n  
  bool get_cbor_array(const std::size_t len,\n                        const cbor_tag_handler_t tag_handler)\n    {\n        if (JSON_HEDLEY_UNLIKELY(!sax->start_array(len)))\n        {\n            return false;\n        }\n\n        if (len != static_cast<std::size_t>(-1))\n        {\n            for (std::size_t i = 0; i < len; ++i)\n            {\n                if (JSON_HEDLEY_UNLIKELY(!parse_cbor_internal(true, tag_handler)))\n                {\n                    return false;\n                }\n            }\n        }\n        else\n        {\n            while (get() != 0xFF)\n            {\n                if (JSON_HEDLEY_UNLIKELY(!parse_cbor_internal(false, tag_handler)))\n                {\n                    return false;\n                }\n            }\n        }\n\n        return sax->end_array();\n    }\n\n    /*!\n    @param[in] len  the length of the object or static_cast<std::size_t>(-1) for an\n                    object of indefinite size\n    @param[in] tag_handler how CBOR tags should be treated\n    @return whether object creation completed\n    */\n    bool get_cbor_object(const std::size_t len,\n                         const cbor_tag_handler_t tag_handler)\n    {\n        if (JSON_HEDLEY_UNLIKELY(!sax->start_object(len)))\n        {\n            return false;\n        }\n\n        if (len != 0)\n        {\n            string_t key;\n            if (len != static_cast<std::size_t>(-1))\n            {\n                for (std::size_t i = 0; i < len; ++i)\n                {\n                    get();\n                    if (JSON_HEDLEY_UNLIKELY(!get_cbor_string(key) || !sax->key(key)))\n                    {\n                        return false;\n                    }\n\n                    if (JSON_HEDLEY_UNLIKELY(!parse_cbor_internal(true, tag_handler)))\n                    {\n                        return false;\n                    }\n                    key.clear();\n                }\n            }\n            else\n       
     {\n                while (get() != 0xFF)\n                {\n                    if (JSON_HEDLEY_UNLIKELY(!get_cbor_string(key) || !sax->key(key)))\n                    {\n                        return false;\n                    }\n\n                    if (JSON_HEDLEY_UNLIKELY(!parse_cbor_internal(true, tag_handler)))\n                    {\n                        return false;\n                    }\n                    key.clear();\n                }\n            }\n        }\n\n        return sax->end_object();\n    }\n\n    /////////////\n    // MsgPack //\n    /////////////\n\n    /*!\n    @return whether a valid MessagePack value was passed to the SAX parser\n    */\n    bool parse_msgpack_internal()\n    {\n        switch (get())\n        {\n            // EOF\n            case char_traits<char_type>::eof():\n                return unexpect_eof(input_format_t::msgpack, \"value\");\n\n            // positive fixint\n            case 0x00:\n            case 0x01:\n            case 0x02:\n            case 0x03:\n            case 0x04:\n            case 0x05:\n            case 0x06:\n            case 0x07:\n            case 0x08:\n            case 0x09:\n            case 0x0A:\n            case 0x0B:\n            case 0x0C:\n            case 0x0D:\n            case 0x0E:\n            case 0x0F:\n            case 0x10:\n            case 0x11:\n            case 0x12:\n            case 0x13:\n            case 0x14:\n            case 0x15:\n            case 0x16:\n            case 0x17:\n            case 0x18:\n            case 0x19:\n            case 0x1A:\n            case 0x1B:\n            case 0x1C:\n            case 0x1D:\n            case 0x1E:\n            case 0x1F:\n            case 0x20:\n            case 0x21:\n            case 0x22:\n            case 0x23:\n            case 0x24:\n            case 0x25:\n            case 0x26:\n            case 0x27:\n            case 0x28:\n            case 0x29:\n            case 0x2A:\n            case 
0x2B:\n            case 0x2C:\n            case 0x2D:\n            case 0x2E:\n            case 0x2F:\n            case 0x30:\n            case 0x31:\n            case 0x32:\n            case 0x33:\n            case 0x34:\n            case 0x35:\n            case 0x36:\n            case 0x37:\n            case 0x38:\n            case 0x39:\n            case 0x3A:\n            case 0x3B:\n            case 0x3C:\n            case 0x3D:\n            case 0x3E:\n            case 0x3F:\n            case 0x40:\n            case 0x41:\n            case 0x42:\n            case 0x43:\n            case 0x44:\n            case 0x45:\n            case 0x46:\n            case 0x47:\n            case 0x48:\n            case 0x49:\n            case 0x4A:\n            case 0x4B:\n            case 0x4C:\n            case 0x4D:\n            case 0x4E:\n            case 0x4F:\n            case 0x50:\n            case 0x51:\n            case 0x52:\n            case 0x53:\n            case 0x54:\n            case 0x55:\n            case 0x56:\n            case 0x57:\n            case 0x58:\n            case 0x59:\n            case 0x5A:\n            case 0x5B:\n            case 0x5C:\n            case 0x5D:\n            case 0x5E:\n            case 0x5F:\n            case 0x60:\n            case 0x61:\n            case 0x62:\n            case 0x63:\n            case 0x64:\n            case 0x65:\n            case 0x66:\n            case 0x67:\n            case 0x68:\n            case 0x69:\n            case 0x6A:\n            case 0x6B:\n            case 0x6C:\n            case 0x6D:\n            case 0x6E:\n            case 0x6F:\n            case 0x70:\n            case 0x71:\n            case 0x72:\n            case 0x73:\n            case 0x74:\n            case 0x75:\n            case 0x76:\n            case 0x77:\n            case 0x78:\n            case 0x79:\n            case 0x7A:\n            case 0x7B:\n            case 0x7C:\n            case 0x7D:\n            case 0x7E:\n 
           case 0x7F:\n                return sax->number_unsigned(static_cast<number_unsigned_t>(current));\n\n            // fixmap\n            case 0x80:\n            case 0x81:\n            case 0x82:\n            case 0x83:\n            case 0x84:\n            case 0x85:\n            case 0x86:\n            case 0x87:\n            case 0x88:\n            case 0x89:\n            case 0x8A:\n            case 0x8B:\n            case 0x8C:\n            case 0x8D:\n            case 0x8E:\n            case 0x8F:\n                return get_msgpack_object(conditional_static_cast<std::size_t>(static_cast<unsigned int>(current) & 0x0Fu));\n\n            // fixarray\n            case 0x90:\n            case 0x91:\n            case 0x92:\n            case 0x93:\n            case 0x94:\n            case 0x95:\n            case 0x96:\n            case 0x97:\n            case 0x98:\n            case 0x99:\n            case 0x9A:\n            case 0x9B:\n            case 0x9C:\n            case 0x9D:\n            case 0x9E:\n            case 0x9F:\n                return get_msgpack_array(conditional_static_cast<std::size_t>(static_cast<unsigned int>(current) & 0x0Fu));\n\n            // fixstr\n            case 0xA0:\n            case 0xA1:\n            case 0xA2:\n            case 0xA3:\n            case 0xA4:\n            case 0xA5:\n            case 0xA6:\n            case 0xA7:\n            case 0xA8:\n            case 0xA9:\n            case 0xAA:\n            case 0xAB:\n            case 0xAC:\n            case 0xAD:\n            case 0xAE:\n            case 0xAF:\n            case 0xB0:\n            case 0xB1:\n            case 0xB2:\n            case 0xB3:\n            case 0xB4:\n            case 0xB5:\n            case 0xB6:\n            case 0xB7:\n            case 0xB8:\n            case 0xB9:\n            case 0xBA:\n            case 0xBB:\n            case 0xBC:\n            case 0xBD:\n            case 0xBE:\n            case 0xBF:\n            case 0xD9: // 
str 8\n            case 0xDA: // str 16\n            case 0xDB: // str 32\n            {\n                string_t s;\n                return get_msgpack_string(s) && sax->string(s);\n            }\n\n            case 0xC0: // nil\n                return sax->null();\n\n            case 0xC2: // false\n                return sax->boolean(false);\n\n            case 0xC3: // true\n                return sax->boolean(true);\n\n            case 0xC4: // bin 8\n            case 0xC5: // bin 16\n            case 0xC6: // bin 32\n            case 0xC7: // ext 8\n            case 0xC8: // ext 16\n            case 0xC9: // ext 32\n            case 0xD4: // fixext 1\n            case 0xD5: // fixext 2\n            case 0xD6: // fixext 4\n            case 0xD7: // fixext 8\n            case 0xD8: // fixext 16\n            {\n                binary_t b;\n                return get_msgpack_binary(b) && sax->binary(b);\n            }\n\n            case 0xCA: // float 32\n            {\n                float number{};\n                return get_number(input_format_t::msgpack, number) && sax->number_float(static_cast<number_float_t>(number), \"\");\n            }\n\n            case 0xCB: // float 64\n            {\n                double number{};\n                return get_number(input_format_t::msgpack, number) && sax->number_float(static_cast<number_float_t>(number), \"\");\n            }\n\n            case 0xCC: // uint 8\n            {\n                std::uint8_t number{};\n                return get_number(input_format_t::msgpack, number) && sax->number_unsigned(number);\n            }\n\n            case 0xCD: // uint 16\n            {\n                std::uint16_t number{};\n                return get_number(input_format_t::msgpack, number) && sax->number_unsigned(number);\n            }\n\n            case 0xCE: // uint 32\n            {\n                std::uint32_t number{};\n                return get_number(input_format_t::msgpack, number) && 
sax->number_unsigned(number);\n            }\n\n            case 0xCF: // uint 64\n            {\n                std::uint64_t number{};\n                return get_number(input_format_t::msgpack, number) && sax->number_unsigned(number);\n            }\n\n            case 0xD0: // int 8\n            {\n                std::int8_t number{};\n                return get_number(input_format_t::msgpack, number) && sax->number_integer(number);\n            }\n\n            case 0xD1: // int 16\n            {\n                std::int16_t number{};\n                return get_number(input_format_t::msgpack, number) && sax->number_integer(number);\n            }\n\n            case 0xD2: // int 32\n            {\n                std::int32_t number{};\n                return get_number(input_format_t::msgpack, number) && sax->number_integer(number);\n            }\n\n            case 0xD3: // int 64\n            {\n                std::int64_t number{};\n                return get_number(input_format_t::msgpack, number) && sax->number_integer(number);\n            }\n\n            case 0xDC: // array 16\n            {\n                std::uint16_t len{};\n                return get_number(input_format_t::msgpack, len) && get_msgpack_array(static_cast<std::size_t>(len));\n            }\n\n            case 0xDD: // array 32\n            {\n                std::uint32_t len{};\n                return get_number(input_format_t::msgpack, len) && get_msgpack_array(conditional_static_cast<std::size_t>(len));\n            }\n\n            case 0xDE: // map 16\n            {\n                std::uint16_t len{};\n                return get_number(input_format_t::msgpack, len) && get_msgpack_object(static_cast<std::size_t>(len));\n            }\n\n            case 0xDF: // map 32\n            {\n                std::uint32_t len{};\n                return get_number(input_format_t::msgpack, len) && get_msgpack_object(conditional_static_cast<std::size_t>(len));\n            }\n\n   
         // negative fixint\n            case 0xE0:\n            case 0xE1:\n            case 0xE2:\n            case 0xE3:\n            case 0xE4:\n            case 0xE5:\n            case 0xE6:\n            case 0xE7:\n            case 0xE8:\n            case 0xE9:\n            case 0xEA:\n            case 0xEB:\n            case 0xEC:\n            case 0xED:\n            case 0xEE:\n            case 0xEF:\n            case 0xF0:\n            case 0xF1:\n            case 0xF2:\n            case 0xF3:\n            case 0xF4:\n            case 0xF5:\n            case 0xF6:\n            case 0xF7:\n            case 0xF8:\n            case 0xF9:\n            case 0xFA:\n            case 0xFB:\n            case 0xFC:\n            case 0xFD:\n            case 0xFE:\n            case 0xFF:\n                return sax->number_integer(static_cast<std::int8_t>(current));\n\n            default: // anything else\n            {\n                auto last_token = get_token_string();\n                return sax->parse_error(chars_read, last_token, parse_error::create(112, chars_read,\n                                        exception_message(input_format_t::msgpack, concat(\"invalid byte: 0x\", last_token), \"value\"), nullptr));\n            }\n        }\n    }\n\n    /*!\n    @brief reads a MessagePack string\n\n    This function first reads starting bytes to determine the expected\n    string length and then copies this number of bytes into a string.\n\n    @param[out] result  created string\n\n    @return whether string creation completed\n    */\n    bool get_msgpack_string(string_t& result)\n    {\n        if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format_t::msgpack, \"string\")))\n        {\n            return false;\n        }\n\n        switch (current)\n        {\n            // fixstr\n            case 0xA0:\n            case 0xA1:\n            case 0xA2:\n            case 0xA3:\n            case 0xA4:\n            case 0xA5:\n            case 0xA6:\n            
case 0xA7:\n            case 0xA8:\n            case 0xA9:\n            case 0xAA:\n            case 0xAB:\n            case 0xAC:\n            case 0xAD:\n            case 0xAE:\n            case 0xAF:\n            case 0xB0:\n            case 0xB1:\n            case 0xB2:\n            case 0xB3:\n            case 0xB4:\n            case 0xB5:\n            case 0xB6:\n            case 0xB7:\n            case 0xB8:\n            case 0xB9:\n            case 0xBA:\n            case 0xBB:\n            case 0xBC:\n            case 0xBD:\n            case 0xBE:\n            case 0xBF:\n            {\n                return get_string(input_format_t::msgpack, static_cast<unsigned int>(current) & 0x1Fu, result);\n            }\n\n            case 0xD9: // str 8\n            {\n                std::uint8_t len{};\n                return get_number(input_format_t::msgpack, len) && get_string(input_format_t::msgpack, len, result);\n            }\n\n            case 0xDA: // str 16\n            {\n                std::uint16_t len{};\n                return get_number(input_format_t::msgpack, len) && get_string(input_format_t::msgpack, len, result);\n            }\n\n            case 0xDB: // str 32\n            {\n                std::uint32_t len{};\n                return get_number(input_format_t::msgpack, len) && get_string(input_format_t::msgpack, len, result);\n            }\n\n            default:\n            {\n                auto last_token = get_token_string();\n                return sax->parse_error(chars_read, last_token, parse_error::create(113, chars_read,\n                                        exception_message(input_format_t::msgpack, concat(\"expected length specification (0xA0-0xBF, 0xD9-0xDB); last byte: 0x\", last_token), \"string\"), nullptr));\n            }\n        }\n    }\n\n    /*!\n    @brief reads a MessagePack byte array\n\n    This function first reads starting bytes to determine the expected\n    byte array length and then copies this 
number of bytes into a byte array.\n\n    @param[out] result  created byte array\n\n    @return whether byte array creation completed\n    */\n    bool get_msgpack_binary(binary_t& result)\n    {\n        // helper function to set the subtype\n        auto assign_and_return_true = [&result](std::int8_t subtype)\n        {\n            result.set_subtype(static_cast<std::uint8_t>(subtype));\n            return true;\n        };\n\n        switch (current)\n        {\n            case 0xC4: // bin 8\n            {\n                std::uint8_t len{};\n                return get_number(input_format_t::msgpack, len) &&\n                       get_binary(input_format_t::msgpack, len, result);\n            }\n\n            case 0xC5: // bin 16\n            {\n                std::uint16_t len{};\n                return get_number(input_format_t::msgpack, len) &&\n                       get_binary(input_format_t::msgpack, len, result);\n            }\n\n            case 0xC6: // bin 32\n            {\n                std::uint32_t len{};\n                return get_number(input_format_t::msgpack, len) &&\n                       get_binary(input_format_t::msgpack, len, result);\n            }\n\n            case 0xC7: // ext 8\n            {\n                std::uint8_t len{};\n                std::int8_t subtype{};\n                return get_number(input_format_t::msgpack, len) &&\n                       get_number(input_format_t::msgpack, subtype) &&\n                       get_binary(input_format_t::msgpack, len, result) &&\n                       assign_and_return_true(subtype);\n            }\n\n            case 0xC8: // ext 16\n            {\n                std::uint16_t len{};\n                std::int8_t subtype{};\n                return get_number(input_format_t::msgpack, len) &&\n                       get_number(input_format_t::msgpack, subtype) &&\n                       get_binary(input_format_t::msgpack, len, result) &&\n                       
assign_and_return_true(subtype);\n            }\n\n            case 0xC9: // ext 32\n            {\n                std::uint32_t len{};\n                std::int8_t subtype{};\n                return get_number(input_format_t::msgpack, len) &&\n                       get_number(input_format_t::msgpack, subtype) &&\n                       get_binary(input_format_t::msgpack, len, result) &&\n                       assign_and_return_true(subtype);\n            }\n\n            case 0xD4: // fixext 1\n            {\n                std::int8_t subtype{};\n                return get_number(input_format_t::msgpack, subtype) &&\n                       get_binary(input_format_t::msgpack, 1, result) &&\n                       assign_and_return_true(subtype);\n            }\n\n            case 0xD5: // fixext 2\n            {\n                std::int8_t subtype{};\n                return get_number(input_format_t::msgpack, subtype) &&\n                       get_binary(input_format_t::msgpack, 2, result) &&\n                       assign_and_return_true(subtype);\n            }\n\n            case 0xD6: // fixext 4\n            {\n                std::int8_t subtype{};\n                return get_number(input_format_t::msgpack, subtype) &&\n                       get_binary(input_format_t::msgpack, 4, result) &&\n                       assign_and_return_true(subtype);\n            }\n\n            case 0xD7: // fixext 8\n            {\n                std::int8_t subtype{};\n                return get_number(input_format_t::msgpack, subtype) &&\n                       get_binary(input_format_t::msgpack, 8, result) &&\n                       assign_and_return_true(subtype);\n            }\n\n            case 0xD8: // fixext 16\n            {\n                std::int8_t subtype{};\n                return get_number(input_format_t::msgpack, subtype) &&\n                       get_binary(input_format_t::msgpack, 16, result) &&\n                       
assign_and_return_true(subtype);\n            }\n\n            default:           // LCOV_EXCL_LINE\n                return false;  // LCOV_EXCL_LINE\n        }\n    }\n\n    /*!\n    @param[in] len  the length of the array\n    @return whether array creation completed\n    */\n    bool get_msgpack_array(const std::size_t len)\n    {\n        if (JSON_HEDLEY_UNLIKELY(!sax->start_array(len)))\n        {\n            return false;\n        }\n\n        for (std::size_t i = 0; i < len; ++i)\n        {\n            if (JSON_HEDLEY_UNLIKELY(!parse_msgpack_internal()))\n            {\n                return false;\n            }\n        }\n\n        return sax->end_array();\n    }\n\n    /*!\n    @param[in] len  the length of the object\n    @return whether object creation completed\n    */\n    bool get_msgpack_object(const std::size_t len)\n    {\n        if (JSON_HEDLEY_UNLIKELY(!sax->start_object(len)))\n        {\n            return false;\n        }\n\n        string_t key;\n        for (std::size_t i = 0; i < len; ++i)\n        {\n            get();\n            if (JSON_HEDLEY_UNLIKELY(!get_msgpack_string(key) || !sax->key(key)))\n            {\n                return false;\n            }\n\n            if (JSON_HEDLEY_UNLIKELY(!parse_msgpack_internal()))\n            {\n                return false;\n            }\n            key.clear();\n        }\n\n        return sax->end_object();\n    }\n\n    ////////////\n    // UBJSON //\n    ////////////\n\n    /*!\n    @param[in] get_char  whether a new character should be retrieved from the\n                         input (true, default) or whether the last read\n                         character should be considered instead\n\n    @return whether a valid UBJSON value was passed to the SAX parser\n    */\n    bool parse_ubjson_internal(const bool get_char = true)\n    {\n        return get_ubjson_value(get_char ? 
get_ignore_noop() : current);\n    }\n\n    /*!\n    @brief reads a UBJSON string\n\n    This function is either called after reading the 'S' byte explicitly\n    indicating a string, or in case of an object key where the 'S' byte can be\n    left out.\n\n    @param[out] result   created string\n    @param[in] get_char  whether a new character should be retrieved from the\n                         input (true, default) or whether the last read\n                         character should be considered instead\n\n    @return whether string creation completed\n    */\n    bool get_ubjson_string(string_t& result, const bool get_char = true)\n    {\n        if (get_char)\n        {\n            get();  // TODO(niels): may we ignore N here?\n        }\n\n        if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format, \"value\")))\n        {\n            return false;\n        }\n\n        switch (current)\n        {\n            case 'U':\n            {\n                std::uint8_t len{};\n                return get_number(input_format, len) && get_string(input_format, len, result);\n            }\n\n            case 'i':\n            {\n                std::int8_t len{};\n                return get_number(input_format, len) && get_string(input_format, len, result);\n            }\n\n            case 'I':\n            {\n                std::int16_t len{};\n                return get_number(input_format, len) && get_string(input_format, len, result);\n            }\n\n            case 'l':\n            {\n                std::int32_t len{};\n                return get_number(input_format, len) && get_string(input_format, len, result);\n            }\n\n            case 'L':\n            {\n                std::int64_t len{};\n                return get_number(input_format, len) && get_string(input_format, len, result);\n            }\n\n            case 'u':\n            {\n                if (input_format != input_format_t::bjdata)\n                {\n                    
break;\n                }\n                std::uint16_t len{};\n                return get_number(input_format, len) && get_string(input_format, len, result);\n            }\n\n            case 'm':\n            {\n                if (input_format != input_format_t::bjdata)\n                {\n                    break;\n                }\n                std::uint32_t len{};\n                return get_number(input_format, len) && get_string(input_format, len, result);\n            }\n\n            case 'M':\n            {\n                if (input_format != input_format_t::bjdata)\n                {\n                    break;\n                }\n                std::uint64_t len{};\n                return get_number(input_format, len) && get_string(input_format, len, result);\n            }\n\n            default:\n                break;\n        }\n        auto last_token = get_token_string();\n        std::string message;\n\n        if (input_format != input_format_t::bjdata)\n        {\n            message = \"expected length type specification (U, i, I, l, L); last byte: 0x\" + last_token;\n        }\n        else\n        {\n            message = \"expected length type specification (U, i, u, I, m, l, M, L); last byte: 0x\" + last_token;\n        }\n        return sax->parse_error(chars_read, last_token, parse_error::create(113, chars_read, exception_message(input_format, message, \"string\"), nullptr));\n    }\n\n    /*!\n    @param[out] dim  an integer vector storing the ND array dimensions\n    @return whether reading ND array size vector is successful\n    */\n    bool get_ubjson_ndarray_size(std::vector<size_t>& dim)\n    {\n        std::pair<std::size_t, char_int_type> size_and_type;\n        size_t dimlen = 0;\n        bool no_ndarray = true;\n\n        if (JSON_HEDLEY_UNLIKELY(!get_ubjson_size_type(size_and_type, no_ndarray)))\n        {\n            return false;\n        }\n\n        if (size_and_type.first != npos)\n        {\n            if 
(size_and_type.second != 0)\n            {\n                if (size_and_type.second != 'N')\n                {\n                    for (std::size_t i = 0; i < size_and_type.first; ++i)\n                    {\n                        if (JSON_HEDLEY_UNLIKELY(!get_ubjson_size_value(dimlen, no_ndarray, size_and_type.second)))\n                        {\n                            return false;\n                        }\n                        dim.push_back(dimlen);\n                    }\n                }\n            }\n            else\n            {\n                for (std::size_t i = 0; i < size_and_type.first; ++i)\n                {\n                    if (JSON_HEDLEY_UNLIKELY(!get_ubjson_size_value(dimlen, no_ndarray)))\n                    {\n                        return false;\n                    }\n                    dim.push_back(dimlen);\n                }\n            }\n        }\n        else\n        {\n            while (current != ']')\n            {\n                if (JSON_HEDLEY_UNLIKELY(!get_ubjson_size_value(dimlen, no_ndarray, current)))\n                {\n                    return false;\n                }\n                dim.push_back(dimlen);\n                get_ignore_noop();\n            }\n        }\n        return true;\n    }\n\n    /*!\n    @param[out] result  determined size\n    @param[in,out] is_ndarray  for input, `true` means already inside an ndarray vector\n                               or ndarray dimension is not allowed; `false` means ndarray\n                               is allowed; for output, `true` means an ndarray is found;\n                               is_ndarray can only return `true` when its initial value\n                               is `false`\n    @param[in] prefix  type marker if already read, otherwise set to 0\n\n    @return whether size determination completed\n    */\n    bool get_ubjson_size_value(std::size_t& result, bool& is_ndarray, char_int_type prefix = 0)\n    {\n        if 
(prefix == 0)\n        {\n            prefix = get_ignore_noop();\n        }\n\n        switch (prefix)\n        {\n            case 'U':\n            {\n                std::uint8_t number{};\n                if (JSON_HEDLEY_UNLIKELY(!get_number(input_format, number)))\n                {\n                    return false;\n                }\n                result = static_cast<std::size_t>(number);\n                return true;\n            }\n\n            case 'i':\n            {\n                std::int8_t number{};\n                if (JSON_HEDLEY_UNLIKELY(!get_number(input_format, number)))\n                {\n                    return false;\n                }\n                if (number < 0)\n                {\n                    return sax->parse_error(chars_read, get_token_string(), parse_error::create(113, chars_read,\n                                            exception_message(input_format, \"count in an optimized container must be positive\", \"size\"), nullptr));\n                }\n                result = static_cast<std::size_t>(number); // NOLINT(bugprone-signed-char-misuse,cert-str34-c): number is not a char\n                return true;\n            }\n\n            case 'I':\n            {\n                std::int16_t number{};\n                if (JSON_HEDLEY_UNLIKELY(!get_number(input_format, number)))\n                {\n                    return false;\n                }\n                if (number < 0)\n                {\n                    return sax->parse_error(chars_read, get_token_string(), parse_error::create(113, chars_read,\n                                            exception_message(input_format, \"count in an optimized container must be positive\", \"size\"), nullptr));\n                }\n                result = static_cast<std::size_t>(number);\n                return true;\n            }\n\n            case 'l':\n            {\n                std::int32_t number{};\n                if 
(JSON_HEDLEY_UNLIKELY(!get_number(input_format, number)))\n                {\n                    return false;\n                }\n                if (number < 0)\n                {\n                    return sax->parse_error(chars_read, get_token_string(), parse_error::create(113, chars_read,\n                                            exception_message(input_format, \"count in an optimized container must be positive\", \"size\"), nullptr));\n                }\n                result = static_cast<std::size_t>(number);\n                return true;\n            }\n\n            case 'L':\n            {\n                std::int64_t number{};\n                if (JSON_HEDLEY_UNLIKELY(!get_number(input_format, number)))\n                {\n                    return false;\n                }\n                if (number < 0)\n                {\n                    return sax->parse_error(chars_read, get_token_string(), parse_error::create(113, chars_read,\n                                            exception_message(input_format, \"count in an optimized container must be positive\", \"size\"), nullptr));\n                }\n                if (!value_in_range_of<std::size_t>(number))\n                {\n                    return sax->parse_error(chars_read, get_token_string(), out_of_range::create(408,\n                                            exception_message(input_format, \"integer value overflow\", \"size\"), nullptr));\n                }\n                result = static_cast<std::size_t>(number);\n                return true;\n            }\n\n            case 'u':\n            {\n                if (input_format != input_format_t::bjdata)\n                {\n                    break;\n                }\n                std::uint16_t number{};\n                if (JSON_HEDLEY_UNLIKELY(!get_number(input_format, number)))\n                {\n                    return false;\n                }\n                result = static_cast<std::size_t>(number);\n  
              return true;\n            }\n\n            case 'm':\n            {\n                if (input_format != input_format_t::bjdata)\n                {\n                    break;\n                }\n                std::uint32_t number{};\n                if (JSON_HEDLEY_UNLIKELY(!get_number(input_format, number)))\n                {\n                    return false;\n                }\n                result = conditional_static_cast<std::size_t>(number);\n                return true;\n            }\n\n            case 'M':\n            {\n                if (input_format != input_format_t::bjdata)\n                {\n                    break;\n                }\n                std::uint64_t number{};\n                if (JSON_HEDLEY_UNLIKELY(!get_number(input_format, number)))\n                {\n                    return false;\n                }\n                if (!value_in_range_of<std::size_t>(number))\n                {\n                    return sax->parse_error(chars_read, get_token_string(), out_of_range::create(408,\n                                            exception_message(input_format, \"integer value overflow\", \"size\"), nullptr));\n                }\n                result = detail::conditional_static_cast<std::size_t>(number);\n                return true;\n            }\n\n            case '[':\n            {\n                if (input_format != input_format_t::bjdata)\n                {\n                    break;\n                }\n                if (is_ndarray) // ndarray dimensional vector can only contain integers, and can not embed another array\n                {\n                    return sax->parse_error(chars_read, get_token_string(), parse_error::create(113, chars_read, exception_message(input_format, \"ndarray dimensional vector is not allowed\", \"size\"), nullptr));\n                }\n                std::vector<size_t> dim;\n                if (JSON_HEDLEY_UNLIKELY(!get_ubjson_ndarray_size(dim)))\n         
       {\n                    return false;\n                }\n                if (dim.size() == 1 || (dim.size() == 2 && dim.at(0) == 1)) // return normal array size if 1D row vector\n                {\n                    result = dim.at(dim.size() - 1);\n                    return true;\n                }\n                if (!dim.empty())  // if ndarray, convert to an object in JData annotated array format\n                {\n                    for (auto i : dim) // test if any dimension in an ndarray is 0, if so, return a 1D empty container\n                    {\n                        if ( i == 0 )\n                        {\n                            result = 0;\n                            return true;\n                        }\n                    }\n\n                    string_t key = \"_ArraySize_\";\n                    if (JSON_HEDLEY_UNLIKELY(!sax->start_object(3) || !sax->key(key) || !sax->start_array(dim.size())))\n                    {\n                        return false;\n                    }\n                    result = 1;\n                    for (auto i : dim)\n                    {\n                        result *= i;\n                        if (result == 0 || result == npos) // because dim elements shall not have zeros, result = 0 means overflow happened; it also can't be npos as it is used to initialize size in get_ubjson_size_type()\n                        {\n                            return sax->parse_error(chars_read, get_token_string(), out_of_range::create(408, exception_message(input_format, \"excessive ndarray size caused overflow\", \"size\"), nullptr));\n                        }\n                        if (JSON_HEDLEY_UNLIKELY(!sax->number_unsigned(static_cast<number_unsigned_t>(i))))\n                        {\n                            return false;\n                        }\n                    }\n                    is_ndarray = true;\n                    return sax->end_array();\n                }\n        
        result = 0;\n                return true;\n            }\n\n            default:\n                break;\n        }\n        auto last_token = get_token_string();\n        std::string message;\n\n        if (input_format != input_format_t::bjdata)\n        {\n            message = \"expected length type specification (U, i, I, l, L) after '#'; last byte: 0x\" + last_token;\n        }\n        else\n        {\n            message = \"expected length type specification (U, i, u, I, m, l, M, L) after '#'; last byte: 0x\" + last_token;\n        }\n        return sax->parse_error(chars_read, last_token, parse_error::create(113, chars_read, exception_message(input_format, message, \"size\"), nullptr));\n    }\n\n    /*!\n    @brief determine the type and size for a container\n\n    In the optimized UBJSON format, a type and a size can be provided to allow\n    for a more compact representation.\n\n    @param[out] result  pair of the size and the type\n    @param[in] inside_ndarray  whether the parser is parsing an ND array dimensional vector\n\n    @return whether pair creation completed\n    */\n    bool get_ubjson_size_type(std::pair<std::size_t, char_int_type>& result, bool inside_ndarray = false)\n    {\n        result.first = npos; // size\n        result.second = 0; // type\n        bool is_ndarray = false;\n\n        get_ignore_noop();\n\n        if (current == '$')\n        {\n            result.second = get();  // must not ignore 'N', because 'N' maybe the type\n            if (input_format == input_format_t::bjdata\n                    && JSON_HEDLEY_UNLIKELY(std::binary_search(bjd_optimized_type_markers.begin(), bjd_optimized_type_markers.end(), result.second)))\n            {\n                auto last_token = get_token_string();\n                return sax->parse_error(chars_read, last_token, parse_error::create(112, chars_read,\n                                        exception_message(input_format, concat(\"marker 0x\", last_token, \" is not a 
permitted optimized array type\"), \"type\"), nullptr));\n            }\n\n            if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format, \"type\")))\n            {\n                return false;\n            }\n\n            get_ignore_noop();\n            if (JSON_HEDLEY_UNLIKELY(current != '#'))\n            {\n                if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format, \"value\")))\n                {\n                    return false;\n                }\n                auto last_token = get_token_string();\n                return sax->parse_error(chars_read, last_token, parse_error::create(112, chars_read,\n                                        exception_message(input_format, concat(\"expected '#' after type information; last byte: 0x\", last_token), \"size\"), nullptr));\n            }\n\n            const bool is_error = get_ubjson_size_value(result.first, is_ndarray);\n            if (input_format == input_format_t::bjdata && is_ndarray)\n            {\n                if (inside_ndarray)\n                {\n                    return sax->parse_error(chars_read, get_token_string(), parse_error::create(112, chars_read,\n                                            exception_message(input_format, \"ndarray can not be recursive\", \"size\"), nullptr));\n                }\n                result.second |= (1 << 8); // use bit 8 to indicate ndarray, all UBJSON and BJData markers should be ASCII letters\n            }\n            return is_error;\n        }\n\n        if (current == '#')\n        {\n            const bool is_error = get_ubjson_size_value(result.first, is_ndarray);\n            if (input_format == input_format_t::bjdata && is_ndarray)\n            {\n                return sax->parse_error(chars_read, get_token_string(), parse_error::create(112, chars_read,\n                                        exception_message(input_format, \"ndarray requires both type and size\", \"size\"), nullptr));\n            }\n            return 
is_error;\n        }\n\n        return true;\n    }\n\n    /*!\n    @param prefix  the previously read or set type prefix\n    @return whether value creation completed\n    */\n    bool get_ubjson_value(const char_int_type prefix)\n    {\n        switch (prefix)\n        {\n            case char_traits<char_type>::eof():  // EOF\n                return unexpect_eof(input_format, \"value\");\n\n            case 'T':  // true\n                return sax->boolean(true);\n            case 'F':  // false\n                return sax->boolean(false);\n\n            case 'Z':  // null\n                return sax->null();\n\n            case 'U':\n            {\n                std::uint8_t number{};\n                return get_number(input_format, number) && sax->number_unsigned(number);\n            }\n\n            case 'i':\n            {\n                std::int8_t number{};\n                return get_number(input_format, number) && sax->number_integer(number);\n            }\n\n            case 'I':\n            {\n                std::int16_t number{};\n                return get_number(input_format, number) && sax->number_integer(number);\n            }\n\n            case 'l':\n            {\n                std::int32_t number{};\n                return get_number(input_format, number) && sax->number_integer(number);\n            }\n\n            case 'L':\n            {\n                std::int64_t number{};\n                return get_number(input_format, number) && sax->number_integer(number);\n            }\n\n            case 'u':\n            {\n                if (input_format != input_format_t::bjdata)\n                {\n                    break;\n                }\n                std::uint16_t number{};\n                return get_number(input_format, number) && sax->number_unsigned(number);\n            }\n\n            case 'm':\n            {\n                if (input_format != input_format_t::bjdata)\n                {\n                    
break;\n                }\n                std::uint32_t number{};\n                return get_number(input_format, number) && sax->number_unsigned(number);\n            }\n\n            case 'M':\n            {\n                if (input_format != input_format_t::bjdata)\n                {\n                    break;\n                }\n                std::uint64_t number{};\n                return get_number(input_format, number) && sax->number_unsigned(number);\n            }\n\n            case 'h':\n            {\n                if (input_format != input_format_t::bjdata)\n                {\n                    break;\n                }\n                const auto byte1_raw = get();\n                if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format, \"number\")))\n                {\n                    return false;\n                }\n                const auto byte2_raw = get();\n                if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format, \"number\")))\n                {\n                    return false;\n                }\n\n                const auto byte1 = static_cast<unsigned char>(byte1_raw);\n                const auto byte2 = static_cast<unsigned char>(byte2_raw);\n\n                // code from RFC 7049, Appendix D, Figure 3:\n                // As half-precision floating-point numbers were only added\n                // to IEEE 754 in 2008, today's programming platforms often\n                // still only have limited support for them. It is very\n                // easy to include at least decoding support for them even\n                // without such support. An example of a small decoder for\n                // half-precision floating-point numbers in the C language\n                // is shown in Fig. 
3.\n                const auto half = static_cast<unsigned int>((byte2 << 8u) + byte1);\n                const double val = [&half]\n                {\n                    const int exp = (half >> 10u) & 0x1Fu;\n                    const unsigned int mant = half & 0x3FFu;\n                    JSON_ASSERT(0 <= exp && exp <= 32);\n                    JSON_ASSERT(mant <= 1024);\n                    switch (exp)\n                    {\n                        case 0:\n                            return std::ldexp(mant, -24);\n                        case 31:\n                            return (mant == 0)\n                            ? std::numeric_limits<double>::infinity()\n                            : std::numeric_limits<double>::quiet_NaN();\n                        default:\n                            return std::ldexp(mant + 1024, exp - 25);\n                    }\n                }();\n                return sax->number_float((half & 0x8000u) != 0\n                                         ? 
static_cast<number_float_t>(-val)\n                                         : static_cast<number_float_t>(val), \"\");\n            }\n\n            case 'd':\n            {\n                float number{};\n                return get_number(input_format, number) && sax->number_float(static_cast<number_float_t>(number), \"\");\n            }\n\n            case 'D':\n            {\n                double number{};\n                return get_number(input_format, number) && sax->number_float(static_cast<number_float_t>(number), \"\");\n            }\n\n            case 'H':\n            {\n                return get_ubjson_high_precision_number();\n            }\n\n            case 'C':  // char\n            {\n                get();\n                if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format, \"char\")))\n                {\n                    return false;\n                }\n                if (JSON_HEDLEY_UNLIKELY(current > 127))\n                {\n                    auto last_token = get_token_string();\n                    return sax->parse_error(chars_read, last_token, parse_error::create(113, chars_read,\n                                            exception_message(input_format, concat(\"byte after 'C' must be in range 0x00..0x7F; last byte: 0x\", last_token), \"char\"), nullptr));\n                }\n                string_t s(1, static_cast<typename string_t::value_type>(current));\n                return sax->string(s);\n            }\n\n            case 'S':  // string\n            {\n                string_t s;\n                return get_ubjson_string(s) && sax->string(s);\n            }\n\n            case '[':  // array\n                return get_ubjson_array();\n\n            case '{':  // object\n                return get_ubjson_object();\n\n            default: // anything else\n                break;\n        }\n        auto last_token = get_token_string();\n        return sax->parse_error(chars_read, last_token, 
parse_error::create(112, chars_read, exception_message(input_format, \"invalid byte: 0x\" + last_token, \"value\"), nullptr));\n    }\n\n    /*!\n    @return whether array creation completed\n    */\n    bool get_ubjson_array()\n    {\n        std::pair<std::size_t, char_int_type> size_and_type;\n        if (JSON_HEDLEY_UNLIKELY(!get_ubjson_size_type(size_and_type)))\n        {\n            return false;\n        }\n\n        // if bit-8 of size_and_type.second is set to 1, encode bjdata ndarray as an object in JData annotated array format (https://github.com/NeuroJSON/jdata):\n        // {\"_ArrayType_\" : \"typeid\", \"_ArraySize_\" : [n1, n2, ...], \"_ArrayData_\" : [v1, v2, ...]}\n\n        if (input_format == input_format_t::bjdata && size_and_type.first != npos && (size_and_type.second & (1 << 8)) != 0)\n        {\n            size_and_type.second &= ~(static_cast<char_int_type>(1) << 8);  // use bit 8 to indicate ndarray, here we remove the bit to restore the type marker\n            auto it = std::lower_bound(bjd_types_map.begin(), bjd_types_map.end(), size_and_type.second, [](const bjd_type & p, char_int_type t)\n            {\n                return p.first < t;\n            });\n            string_t key = \"_ArrayType_\";\n            if (JSON_HEDLEY_UNLIKELY(it == bjd_types_map.end() || it->first != size_and_type.second))\n            {\n                auto last_token = get_token_string();\n                return sax->parse_error(chars_read, last_token, parse_error::create(112, chars_read,\n                                        exception_message(input_format, \"invalid byte: 0x\" + last_token, \"type\"), nullptr));\n            }\n\n            string_t type = it->second; // sax->string() takes a reference\n            if (JSON_HEDLEY_UNLIKELY(!sax->key(key) || !sax->string(type)))\n            {\n                return false;\n            }\n\n            if (size_and_type.second == 'C')\n            {\n                size_and_type.second = 'U';\n  
          }\n\n            key = \"_ArrayData_\";\n            if (JSON_HEDLEY_UNLIKELY(!sax->key(key) || !sax->start_array(size_and_type.first) ))\n            {\n                return false;\n            }\n\n            for (std::size_t i = 0; i < size_and_type.first; ++i)\n            {\n                if (JSON_HEDLEY_UNLIKELY(!get_ubjson_value(size_and_type.second)))\n                {\n                    return false;\n                }\n            }\n\n            return (sax->end_array() && sax->end_object());\n        }\n\n        if (size_and_type.first != npos)\n        {\n            if (JSON_HEDLEY_UNLIKELY(!sax->start_array(size_and_type.first)))\n            {\n                return false;\n            }\n\n            if (size_and_type.second != 0)\n            {\n                if (size_and_type.second != 'N')\n                {\n                    for (std::size_t i = 0; i < size_and_type.first; ++i)\n                    {\n                        if (JSON_HEDLEY_UNLIKELY(!get_ubjson_value(size_and_type.second)))\n                        {\n                            return false;\n                        }\n                    }\n                }\n            }\n            else\n            {\n                for (std::size_t i = 0; i < size_and_type.first; ++i)\n                {\n                    if (JSON_HEDLEY_UNLIKELY(!parse_ubjson_internal()))\n                    {\n                        return false;\n                    }\n                }\n            }\n        }\n        else\n        {\n            if (JSON_HEDLEY_UNLIKELY(!sax->start_array(static_cast<std::size_t>(-1))))\n            {\n                return false;\n            }\n\n            while (current != ']')\n            {\n                if (JSON_HEDLEY_UNLIKELY(!parse_ubjson_internal(false)))\n                {\n                    return false;\n                }\n                get_ignore_noop();\n            }\n        }\n\n        return 
sax->end_array();\n    }\n\n    /*!\n    @return whether object creation completed\n    */\n    bool get_ubjson_object()\n    {\n        std::pair<std::size_t, char_int_type> size_and_type;\n        if (JSON_HEDLEY_UNLIKELY(!get_ubjson_size_type(size_and_type)))\n        {\n            return false;\n        }\n\n        // do not accept ND-array size in objects in BJData\n        if (input_format == input_format_t::bjdata && size_and_type.first != npos && (size_and_type.second & (1 << 8)) != 0)\n        {\n            auto last_token = get_token_string();\n            return sax->parse_error(chars_read, last_token, parse_error::create(112, chars_read,\n                                    exception_message(input_format, \"BJData object does not support ND-array size in optimized format\", \"object\"), nullptr));\n        }\n\n        string_t key;\n        if (size_and_type.first != npos)\n        {\n            if (JSON_HEDLEY_UNLIKELY(!sax->start_object(size_and_type.first)))\n            {\n                return false;\n            }\n\n            if (size_and_type.second != 0)\n            {\n                for (std::size_t i = 0; i < size_and_type.first; ++i)\n                {\n                    if (JSON_HEDLEY_UNLIKELY(!get_ubjson_string(key) || !sax->key(key)))\n                    {\n                        return false;\n                    }\n                    if (JSON_HEDLEY_UNLIKELY(!get_ubjson_value(size_and_type.second)))\n                    {\n                        return false;\n                    }\n                    key.clear();\n                }\n            }\n            else\n            {\n                for (std::size_t i = 0; i < size_and_type.first; ++i)\n                {\n                    if (JSON_HEDLEY_UNLIKELY(!get_ubjson_string(key) || !sax->key(key)))\n                    {\n                        return false;\n                    }\n                    if (JSON_HEDLEY_UNLIKELY(!parse_ubjson_internal()))\n       
             {\n                        return false;\n                    }\n                    key.clear();\n                }\n            }\n        }\n        else\n        {\n            if (JSON_HEDLEY_UNLIKELY(!sax->start_object(static_cast<std::size_t>(-1))))\n            {\n                return false;\n            }\n\n            while (current != '}')\n            {\n                if (JSON_HEDLEY_UNLIKELY(!get_ubjson_string(key, false) || !sax->key(key)))\n                {\n                    return false;\n                }\n                if (JSON_HEDLEY_UNLIKELY(!parse_ubjson_internal()))\n                {\n                    return false;\n                }\n                get_ignore_noop();\n                key.clear();\n            }\n        }\n\n        return sax->end_object();\n    }\n\n    // Note, no reader for UBJSON binary types is implemented because they do\n    // not exist\n\n    bool get_ubjson_high_precision_number()\n    {\n        // get size of following number string\n        std::size_t size{};\n        bool no_ndarray = true;\n        auto res = get_ubjson_size_value(size, no_ndarray);\n        if (JSON_HEDLEY_UNLIKELY(!res))\n        {\n            return res;\n        }\n\n        // get number string\n        std::vector<char> number_vector;\n        for (std::size_t i = 0; i < size; ++i)\n        {\n            get();\n            if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(input_format, \"number\")))\n            {\n                return false;\n            }\n            number_vector.push_back(static_cast<char>(current));\n        }\n\n        // parse number string\n        using ia_type = decltype(detail::input_adapter(number_vector));\n        auto number_lexer = detail::lexer<BasicJsonType, ia_type>(detail::input_adapter(number_vector), false);\n        const auto result_number = number_lexer.scan();\n        const auto number_string = number_lexer.get_token_string();\n        const auto result_remainder = 
number_lexer.scan();\n\n        using token_type = typename detail::lexer_base<BasicJsonType>::token_type;\n\n        if (JSON_HEDLEY_UNLIKELY(result_remainder != token_type::end_of_input))\n        {\n            return sax->parse_error(chars_read, number_string, parse_error::create(115, chars_read,\n                                    exception_message(input_format, concat(\"invalid number text: \", number_lexer.get_token_string()), \"high-precision number\"), nullptr));\n        }\n\n        switch (result_number)\n        {\n            case token_type::value_integer:\n                return sax->number_integer(number_lexer.get_number_integer());\n            case token_type::value_unsigned:\n                return sax->number_unsigned(number_lexer.get_number_unsigned());\n            case token_type::value_float:\n                return sax->number_float(number_lexer.get_number_float(), std::move(number_string));\n            case token_type::uninitialized:\n            case token_type::literal_true:\n            case token_type::literal_false:\n            case token_type::literal_null:\n            case token_type::value_string:\n            case token_type::begin_array:\n            case token_type::begin_object:\n            case token_type::end_array:\n            case token_type::end_object:\n            case token_type::name_separator:\n            case token_type::value_separator:\n            case token_type::parse_error:\n            case token_type::end_of_input:\n            case token_type::literal_or_value:\n            default:\n                return sax->parse_error(chars_read, number_string, parse_error::create(115, chars_read,\n                                        exception_message(input_format, concat(\"invalid number text: \", number_lexer.get_token_string()), \"high-precision number\"), nullptr));\n        }\n    }\n\n    ///////////////////////\n    // Utility functions //\n    ///////////////////////\n\n    /*!\n    @brief get next 
character from the input\n\n    This function provides the interface to the input adapter in use. It does\n    not throw in case the input reached EOF, but returns a negative-valued\n    `char_traits<char_type>::eof()` in that case.\n\n    @return character read from the input\n    */\n    char_int_type get()\n    {\n        ++chars_read;\n        return current = ia.get_character();\n    }\n\n    /*!\n    @return character read from the input after ignoring all 'N' entries\n    */\n    char_int_type get_ignore_noop()\n    {\n        do\n        {\n            get();\n        }\n        while (current == 'N');\n\n        return current;\n    }\n\n    /*!\n    @brief read a number from the input\n\n    @tparam NumberType the type of the number\n    @param[in] format   the current format (for diagnostics)\n    @param[out] result  number of type @a NumberType\n\n    @return whether conversion completed\n\n    @note This function needs to respect the system's endianness, because\n          bytes in CBOR, MessagePack, and UBJSON are stored in network order\n          (big endian) and therefore need reordering on little endian systems.\n          On the other hand, BSON and BJData use little endian and should reorder\n          on big endian systems.\n    */\n    template<typename NumberType, bool InputIsLittleEndian = false>\n    bool get_number(const input_format_t format, NumberType& result)\n    {\n        // step 1: read input into array with system's byte order\n        std::array<std::uint8_t, sizeof(NumberType)> vec{};\n        for (std::size_t i = 0; i < sizeof(NumberType); ++i)\n        {\n            get();\n            if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(format, \"number\")))\n            {\n                return false;\n            }\n\n            // reverse byte order prior to conversion if necessary\n            if (is_little_endian != (InputIsLittleEndian || format == input_format_t::bjdata))\n            {\n                vec[sizeof(NumberType) - i - 1] = 
static_cast<std::uint8_t>(current);\n            }\n            else\n            {\n                vec[i] = static_cast<std::uint8_t>(current); // LCOV_EXCL_LINE\n            }\n        }\n\n        // step 2: convert array into number of type NumberType and return\n        std::memcpy(&result, vec.data(), sizeof(NumberType));\n        return true;\n    }\n\n    /*!\n    @brief create a string by reading characters from the input\n\n    @tparam NumberType the integer type of @a len\n    @param[in] format the current format (for diagnostics)\n    @param[in] len number of characters to read\n    @param[out] result string created by reading @a len bytes\n\n    @return whether string creation completed\n\n    @note We cannot reserve @a len bytes for the result, because @a len\n          may be too large. Usually, @ref unexpect_eof() detects the end of\n          the input before we run out of string memory.\n    */\n    template<typename NumberType>\n    bool get_string(const input_format_t format,\n                    const NumberType len,\n                    string_t& result)\n    {\n        bool success = true;\n        for (NumberType i = 0; i < len; i++)\n        {\n            get();\n            if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(format, \"string\")))\n            {\n                success = false;\n                break;\n            }\n            result.push_back(static_cast<typename string_t::value_type>(current));\n        }\n        return success;\n    }\n\n    /*!\n    @brief create a byte array by reading bytes from the input\n\n    @tparam NumberType the integer type of @a len\n    @param[in] format the current format (for diagnostics)\n    @param[in] len number of bytes to read\n    @param[out] result byte array created by reading @a len bytes\n\n    @return whether byte array creation completed\n\n    @note We cannot reserve @a len bytes for the result, because @a len\n          may be too large. 
Usually, @ref unexpect_eof() detects the end of\n          the input before we run out of memory.\n    */\n    template<typename NumberType>\n    bool get_binary(const input_format_t format,\n                    const NumberType len,\n                    binary_t& result)\n    {\n        bool success = true;\n        for (NumberType i = 0; i < len; i++)\n        {\n            get();\n            if (JSON_HEDLEY_UNLIKELY(!unexpect_eof(format, \"binary\")))\n            {\n                success = false;\n                break;\n            }\n            result.push_back(static_cast<std::uint8_t>(current));\n        }\n        return success;\n    }\n\n    /*!\n    @param[in] format   the current format (for diagnostics)\n    @param[in] context  further context information (for diagnostics)\n    @return whether the last read character is not EOF\n    */\n    JSON_HEDLEY_NON_NULL(3)\n    bool unexpect_eof(const input_format_t format, const char* context) const\n    {\n        if (JSON_HEDLEY_UNLIKELY(current == char_traits<char_type>::eof()))\n        {\n            return sax->parse_error(chars_read, \"<end of file>\",\n                                    parse_error::create(110, chars_read, exception_message(format, \"unexpected end of input\", context), nullptr));\n        }\n        return true;\n    }\n\n    /*!\n    @return a string representation of the last read byte\n    */\n    std::string get_token_string() const\n    {\n        std::array<char, 3> cr{{}};\n        static_cast<void>((std::snprintf)(cr.data(), cr.size(), \"%.2hhX\", static_cast<unsigned char>(current))); // NOLINT(cppcoreguidelines-pro-type-vararg,hicpp-vararg)\n        return std::string{cr.data()};\n    }\n\n    /*!\n    @param[in] format   the current format\n    @param[in] detail   a detailed error message\n    @param[in] context  further context information\n    @return a message string to use in the parse_error exceptions\n    */\n    std::string exception_message(const 
input_format_t format,\n                                  const std::string& detail,\n                                  const std::string& context) const\n    {\n        std::string error_msg = \"syntax error while parsing \";\n\n        switch (format)\n        {\n            case input_format_t::cbor:\n                error_msg += \"CBOR\";\n                break;\n\n            case input_format_t::msgpack:\n                error_msg += \"MessagePack\";\n                break;\n\n            case input_format_t::ubjson:\n                error_msg += \"UBJSON\";\n                break;\n\n            case input_format_t::bson:\n                error_msg += \"BSON\";\n                break;\n\n            case input_format_t::bjdata:\n                error_msg += \"BJData\";\n                break;\n\n            case input_format_t::json: // LCOV_EXCL_LINE\n            default:            // LCOV_EXCL_LINE\n                JSON_ASSERT(false); // NOLINT(cert-dcl03-c,hicpp-static-assert,misc-static-assert) LCOV_EXCL_LINE\n        }\n\n        return concat(error_msg, ' ', context, \": \", detail);\n    }\n\n  private:\n    static JSON_INLINE_VARIABLE constexpr std::size_t npos = static_cast<std::size_t>(-1);\n\n    /// input adapter\n    InputAdapterType ia;\n\n    /// the current character\n    char_int_type current = char_traits<char_type>::eof();\n\n    /// the number of characters read\n    std::size_t chars_read = 0;\n\n    /// whether we can assume little endianness\n    const bool is_little_endian = little_endianness();\n\n    /// input format\n    const input_format_t input_format = input_format_t::json;\n\n    /// the SAX parser\n    json_sax_t* sax = nullptr;\n\n    // excluded markers in bjdata optimized type\n#define JSON_BINARY_READER_MAKE_BJD_OPTIMIZED_TYPE_MARKERS_ \\\n    make_array<char_int_type>('F', 'H', 'N', 'S', 'T', 'Z', '[', '{')\n\n#define JSON_BINARY_READER_MAKE_BJD_TYPES_MAP_ \\\n    make_array<bjd_type>(                      \\\n    
bjd_type{'C', \"char\"},                     \\\n    bjd_type{'D', \"double\"},                   \\\n    bjd_type{'I', \"int16\"},                    \\\n    bjd_type{'L', \"int64\"},                    \\\n    bjd_type{'M', \"uint64\"},                   \\\n    bjd_type{'U', \"uint8\"},                    \\\n    bjd_type{'d', \"single\"},                   \\\n    bjd_type{'i', \"int8\"},                     \\\n    bjd_type{'l', \"int32\"},                    \\\n    bjd_type{'m', \"uint32\"},                   \\\n    bjd_type{'u', \"uint16\"})\n\n  JSON_PRIVATE_UNLESS_TESTED:\n    // lookup tables\n    // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)\n    const decltype(JSON_BINARY_READER_MAKE_BJD_OPTIMIZED_TYPE_MARKERS_) bjd_optimized_type_markers =\n        JSON_BINARY_READER_MAKE_BJD_OPTIMIZED_TYPE_MARKERS_;\n\n    using bjd_type = std::pair<char_int_type, string_t>;\n    // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)\n    const decltype(JSON_BINARY_READER_MAKE_BJD_TYPES_MAP_) bjd_types_map =\n        JSON_BINARY_READER_MAKE_BJD_TYPES_MAP_;\n\n#undef JSON_BINARY_READER_MAKE_BJD_OPTIMIZED_TYPE_MARKERS_\n#undef JSON_BINARY_READER_MAKE_BJD_TYPES_MAP_\n};\n\n#ifndef JSON_HAS_CPP_17\n    template<typename BasicJsonType, typename InputAdapterType, typename SAX>\n    constexpr std::size_t binary_reader<BasicJsonType, InputAdapterType, SAX>::npos;\n#endif\n\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include <nlohmann/detail/input/input_adapters.hpp>\n\n// #include <nlohmann/detail/input/lexer.hpp>\n\n// #include <nlohmann/detail/input/parser.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#include <cmath> // isfinite\n#include 
<cstdint> // uint8_t\n#include <functional> // function\n#include <string> // string\n#include <utility> // move\n#include <vector> // vector\n\n// #include <nlohmann/detail/exceptions.hpp>\n\n// #include <nlohmann/detail/input/input_adapters.hpp>\n\n// #include <nlohmann/detail/input/json_sax.hpp>\n\n// #include <nlohmann/detail/input/lexer.hpp>\n\n// #include <nlohmann/detail/macro_scope.hpp>\n\n// #include <nlohmann/detail/meta/is_sax.hpp>\n\n// #include <nlohmann/detail/string_concat.hpp>\n\n// #include <nlohmann/detail/value_t.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\n////////////\n// parser //\n////////////\n\nenum class parse_event_t : std::uint8_t\n{\n    /// the parser read `{` and started to process a JSON object\n    object_start,\n    /// the parser read `}` and finished processing a JSON object\n    object_end,\n    /// the parser read `[` and started to process a JSON array\n    array_start,\n    /// the parser read `]` and finished processing a JSON array\n    array_end,\n    /// the parser read a key of a value in an object\n    key,\n    /// the parser finished reading a JSON value\n    value\n};\n\ntemplate<typename BasicJsonType>\nusing parser_callback_t =\n    std::function<bool(int /*depth*/, parse_event_t /*event*/, BasicJsonType& /*parsed*/)>;\n\n/*!\n@brief syntax analysis\n\nThis class implements a recursive descent parser.\n*/\ntemplate<typename BasicJsonType, typename InputAdapterType>\nclass parser\n{\n    using number_integer_t = typename BasicJsonType::number_integer_t;\n    using number_unsigned_t = typename BasicJsonType::number_unsigned_t;\n    using number_float_t = typename BasicJsonType::number_float_t;\n    using string_t = typename BasicJsonType::string_t;\n    using lexer_t = lexer<BasicJsonType, InputAdapterType>;\n    using token_type = typename lexer_t::token_type;\n\n  public:\n    /// a parser reading from an input adapter\n    explicit parser(InputAdapterType&& adapter,\n                    const 
parser_callback_t<BasicJsonType> cb = nullptr,\n                    const bool allow_exceptions_ = true,\n                    const bool skip_comments = false)\n        : callback(cb)\n        , m_lexer(std::move(adapter), skip_comments)\n        , allow_exceptions(allow_exceptions_)\n    {\n        // read first token\n        get_token();\n    }\n\n    /*!\n    @brief public parser interface\n\n    @param[in] strict      whether to expect the last token to be EOF\n    @param[in,out] result  parsed JSON value\n\n    @throw parse_error.101 in case of an unexpected token\n    @throw parse_error.102 if to_unicode fails or surrogate error\n    @throw parse_error.103 if to_unicode fails\n    */\n    void parse(const bool strict, BasicJsonType& result)\n    {\n        if (callback)\n        {\n            json_sax_dom_callback_parser<BasicJsonType> sdp(result, callback, allow_exceptions);\n            sax_parse_internal(&sdp);\n\n            // in strict mode, input must be completely read\n            if (strict && (get_token() != token_type::end_of_input))\n            {\n                sdp.parse_error(m_lexer.get_position(),\n                                m_lexer.get_token_string(),\n                                parse_error::create(101, m_lexer.get_position(),\n                                                    exception_message(token_type::end_of_input, \"value\"), nullptr));\n            }\n\n            // in case of an error, return discarded value\n            if (sdp.is_errored())\n            {\n                result = value_t::discarded;\n                return;\n            }\n\n            // set top-level value to null if it was discarded by the callback\n            // function\n            if (result.is_discarded())\n            {\n                result = nullptr;\n            }\n        }\n        else\n        {\n            json_sax_dom_parser<BasicJsonType> sdp(result, allow_exceptions);\n            sax_parse_internal(&sdp);\n\n            
// in strict mode, input must be completely read\n            if (strict && (get_token() != token_type::end_of_input))\n            {\n                sdp.parse_error(m_lexer.get_position(),\n                                m_lexer.get_token_string(),\n                                parse_error::create(101, m_lexer.get_position(), exception_message(token_type::end_of_input, \"value\"), nullptr));\n            }\n\n            // in case of an error, return discarded value\n            if (sdp.is_errored())\n            {\n                result = value_t::discarded;\n                return;\n            }\n        }\n\n        result.assert_invariant();\n    }\n\n    /*!\n    @brief public accept interface\n\n    @param[in] strict  whether to expect the last token to be EOF\n    @return whether the input is a proper JSON text\n    */\n    bool accept(const bool strict = true)\n    {\n        json_sax_acceptor<BasicJsonType> sax_acceptor;\n        return sax_parse(&sax_acceptor, strict);\n    }\n\n    template<typename SAX>\n    JSON_HEDLEY_NON_NULL(2)\n    bool sax_parse(SAX* sax, const bool strict = true)\n    {\n        (void)detail::is_sax_static_asserts<SAX, BasicJsonType> {};\n        const bool result = sax_parse_internal(sax);\n\n        // strict mode: next byte must be EOF\n        if (result && strict && (get_token() != token_type::end_of_input))\n        {\n            return sax->parse_error(m_lexer.get_position(),\n                                    m_lexer.get_token_string(),\n                                    parse_error::create(101, m_lexer.get_position(), exception_message(token_type::end_of_input, \"value\"), nullptr));\n        }\n\n        return result;\n    }\n\n  private:\n    template<typename SAX>\n    JSON_HEDLEY_NON_NULL(2)\n    bool sax_parse_internal(SAX* sax)\n    {\n        // stack to remember the hierarchy of structured values we are parsing\n        // true = array; false = object\n        std::vector<bool> states;\n        // 
value to avoid a goto (see comment where set to true)\n        bool skip_to_state_evaluation = false;\n\n        while (true)\n        {\n            if (!skip_to_state_evaluation)\n            {\n                // invariant: get_token() was called before each iteration\n                switch (last_token)\n                {\n                    case token_type::begin_object:\n                    {\n                        if (JSON_HEDLEY_UNLIKELY(!sax->start_object(static_cast<std::size_t>(-1))))\n                        {\n                            return false;\n                        }\n\n                        // closing } -> we are done\n                        if (get_token() == token_type::end_object)\n                        {\n                            if (JSON_HEDLEY_UNLIKELY(!sax->end_object()))\n                            {\n                                return false;\n                            }\n                            break;\n                        }\n\n                        // parse key\n                        if (JSON_HEDLEY_UNLIKELY(last_token != token_type::value_string))\n                        {\n                            return sax->parse_error(m_lexer.get_position(),\n                                                    m_lexer.get_token_string(),\n                                                    parse_error::create(101, m_lexer.get_position(), exception_message(token_type::value_string, \"object key\"), nullptr));\n                        }\n                        if (JSON_HEDLEY_UNLIKELY(!sax->key(m_lexer.get_string())))\n                        {\n                            return false;\n                        }\n\n                        // parse separator (:)\n                        if (JSON_HEDLEY_UNLIKELY(get_token() != token_type::name_separator))\n                        {\n                            return sax->parse_error(m_lexer.get_position(),\n                                                    
m_lexer.get_token_string(),\n                                                    parse_error::create(101, m_lexer.get_position(), exception_message(token_type::name_separator, \"object separator\"), nullptr));\n                        }\n\n                        // remember we are now inside an object\n                        states.push_back(false);\n\n                        // parse values\n                        get_token();\n                        continue;\n                    }\n\n                    case token_type::begin_array:\n                    {\n                        if (JSON_HEDLEY_UNLIKELY(!sax->start_array(static_cast<std::size_t>(-1))))\n                        {\n                            return false;\n                        }\n\n                        // closing ] -> we are done\n                        if (get_token() == token_type::end_array)\n                        {\n                            if (JSON_HEDLEY_UNLIKELY(!sax->end_array()))\n                            {\n                                return false;\n                            }\n                            break;\n                        }\n\n                        // remember we are now inside an array\n                        states.push_back(true);\n\n                        // parse values (no need to call get_token)\n                        continue;\n                    }\n\n                    case token_type::value_float:\n                    {\n                        const auto res = m_lexer.get_number_float();\n\n                        if (JSON_HEDLEY_UNLIKELY(!std::isfinite(res)))\n                        {\n                            return sax->parse_error(m_lexer.get_position(),\n                                                    m_lexer.get_token_string(),\n                                                    out_of_range::create(406, concat(\"number overflow parsing '\", m_lexer.get_token_string(), '\\''), nullptr));\n                        
}\n\n                        if (JSON_HEDLEY_UNLIKELY(!sax->number_float(res, m_lexer.get_string())))\n                        {\n                            return false;\n                        }\n\n                        break;\n                    }\n\n                    case token_type::literal_false:\n                    {\n                        if (JSON_HEDLEY_UNLIKELY(!sax->boolean(false)))\n                        {\n                            return false;\n                        }\n                        break;\n                    }\n\n                    case token_type::literal_null:\n                    {\n                        if (JSON_HEDLEY_UNLIKELY(!sax->null()))\n                        {\n                            return false;\n                        }\n                        break;\n                    }\n\n                    case token_type::literal_true:\n                    {\n                        if (JSON_HEDLEY_UNLIKELY(!sax->boolean(true)))\n                        {\n                            return false;\n                        }\n                        break;\n                    }\n\n                    case token_type::value_integer:\n                    {\n                        if (JSON_HEDLEY_UNLIKELY(!sax->number_integer(m_lexer.get_number_integer())))\n                        {\n                            return false;\n                        }\n                        break;\n                    }\n\n                    case token_type::value_string:\n                    {\n                        if (JSON_HEDLEY_UNLIKELY(!sax->string(m_lexer.get_string())))\n                        {\n                            return false;\n                        }\n                        break;\n                    }\n\n                    case token_type::value_unsigned:\n                    {\n                        if (JSON_HEDLEY_UNLIKELY(!sax->number_unsigned(m_lexer.get_number_unsigned())))\n            
            {\n                            return false;\n                        }\n                        break;\n                    }\n\n                    case token_type::parse_error:\n                    {\n                        // using \"uninitialized\" to avoid \"expected\" message\n                        return sax->parse_error(m_lexer.get_position(),\n                                                m_lexer.get_token_string(),\n                                                parse_error::create(101, m_lexer.get_position(), exception_message(token_type::uninitialized, \"value\"), nullptr));\n                    }\n                    case token_type::end_of_input:\n                    {\n                        if (JSON_HEDLEY_UNLIKELY(m_lexer.get_position().chars_read_total == 1))\n                        {\n                            return sax->parse_error(m_lexer.get_position(),\n                                                    m_lexer.get_token_string(),\n                                                    parse_error::create(101, m_lexer.get_position(),\n                                                            \"attempting to parse an empty input; check that your input string or stream contains the expected JSON\", nullptr));\n                        }\n\n                        return sax->parse_error(m_lexer.get_position(),\n                                                m_lexer.get_token_string(),\n                                                parse_error::create(101, m_lexer.get_position(), exception_message(token_type::literal_or_value, \"value\"), nullptr));\n                    }\n                    case token_type::uninitialized:\n                    case token_type::end_array:\n                    case token_type::end_object:\n                    case token_type::name_separator:\n                    case token_type::value_separator:\n                    case token_type::literal_or_value:\n                    default: // the 
last token was unexpected\n                    {\n                        return sax->parse_error(m_lexer.get_position(),\n                                                m_lexer.get_token_string(),\n                                                parse_error::create(101, m_lexer.get_position(), exception_message(token_type::literal_or_value, \"value\"), nullptr));\n                    }\n                }\n            }\n            else\n            {\n                skip_to_state_evaluation = false;\n            }\n\n            // we reached this line after we successfully parsed a value\n            if (states.empty())\n            {\n                // empty stack: we reached the end of the hierarchy: done\n                return true;\n            }\n\n            if (states.back())  // array\n            {\n                // comma -> next value\n                if (get_token() == token_type::value_separator)\n                {\n                    // parse a new value\n                    get_token();\n                    continue;\n                }\n\n                // closing ]\n                if (JSON_HEDLEY_LIKELY(last_token == token_type::end_array))\n                {\n                    if (JSON_HEDLEY_UNLIKELY(!sax->end_array()))\n                    {\n                        return false;\n                    }\n\n                    // We are done with this array. 
Before we can parse a\n                    // new value, we need to evaluate the new state first.\n                    // By setting skip_to_state_evaluation to true, we\n                    // are effectively jumping to the beginning of this if.\n                    JSON_ASSERT(!states.empty());\n                    states.pop_back();\n                    skip_to_state_evaluation = true;\n                    continue;\n                }\n\n                return sax->parse_error(m_lexer.get_position(),\n                                        m_lexer.get_token_string(),\n                                        parse_error::create(101, m_lexer.get_position(), exception_message(token_type::end_array, \"array\"), nullptr));\n            }\n\n            // states.back() is false -> object\n\n            // comma -> next value\n            if (get_token() == token_type::value_separator)\n            {\n                // parse key\n                if (JSON_HEDLEY_UNLIKELY(get_token() != token_type::value_string))\n                {\n                    return sax->parse_error(m_lexer.get_position(),\n                                            m_lexer.get_token_string(),\n                                            parse_error::create(101, m_lexer.get_position(), exception_message(token_type::value_string, \"object key\"), nullptr));\n                }\n\n                if (JSON_HEDLEY_UNLIKELY(!sax->key(m_lexer.get_string())))\n                {\n                    return false;\n                }\n\n                // parse separator (:)\n                if (JSON_HEDLEY_UNLIKELY(get_token() != token_type::name_separator))\n                {\n                    return sax->parse_error(m_lexer.get_position(),\n                                            m_lexer.get_token_string(),\n                                            parse_error::create(101, m_lexer.get_position(), exception_message(token_type::name_separator, \"object separator\"), nullptr));\n            
    }\n\n                // parse values\n                get_token();\n                continue;\n            }\n\n            // closing }\n            if (JSON_HEDLEY_LIKELY(last_token == token_type::end_object))\n            {\n                if (JSON_HEDLEY_UNLIKELY(!sax->end_object()))\n                {\n                    return false;\n                }\n\n                // We are done with this object. Before we can parse a\n                // new value, we need to evaluate the new state first.\n                // By setting skip_to_state_evaluation to true, we\n                // are effectively jumping to the beginning of this if.\n                JSON_ASSERT(!states.empty());\n                states.pop_back();\n                skip_to_state_evaluation = true;\n                continue;\n            }\n\n            return sax->parse_error(m_lexer.get_position(),\n                                    m_lexer.get_token_string(),\n                                    parse_error::create(101, m_lexer.get_position(), exception_message(token_type::end_object, \"object\"), nullptr));\n        }\n    }\n\n    /// get next token from lexer\n    token_type get_token()\n    {\n        return last_token = m_lexer.scan();\n    }\n\n    std::string exception_message(const token_type expected, const std::string& context)\n    {\n        std::string error_msg = \"syntax error \";\n\n        if (!context.empty())\n        {\n            error_msg += concat(\"while parsing \", context, ' ');\n        }\n\n        error_msg += \"- \";\n\n        if (last_token == token_type::parse_error)\n        {\n            error_msg += concat(m_lexer.get_error_message(), \"; last read: '\",\n                                m_lexer.get_token_string(), '\\'');\n        }\n        else\n        {\n            error_msg += concat(\"unexpected \", lexer_t::token_type_name(last_token));\n        }\n\n        if (expected != token_type::uninitialized)\n        {\n            error_msg 
+= concat(\"; expected \", lexer_t::token_type_name(expected));\n        }\n\n        return error_msg;\n    }\n\n  private:\n    /// callback function\n    const parser_callback_t<BasicJsonType> callback = nullptr;\n    /// the type of the last read token\n    token_type last_token = token_type::uninitialized;\n    /// the lexer\n    lexer_t m_lexer;\n    /// whether to throw exceptions in case of errors\n    const bool allow_exceptions = true;\n};\n\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include <nlohmann/detail/iterators/internal_iterator.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n// #include <nlohmann/detail/abi_macros.hpp>\n\n// #include <nlohmann/detail/iterators/primitive_iterator.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#include <cstddef> // ptrdiff_t\n#include <limits>  // numeric_limits\n\n// #include <nlohmann/detail/macro_scope.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\n\n/*\n@brief an iterator for primitive JSON types\n\nThis class models an iterator for primitive JSON types (boolean, number,\nstring). Its only purpose is to allow the iterator/const_iterator classes\nto \"iterate\" over primitive values. Internally, the iterator is modeled by\na `difference_type` variable. 
Value begin_value (`0`) models the begin,\nend_value (`1`) models past the end.\n*/\nclass primitive_iterator_t\n{\n  private:\n    using difference_type = std::ptrdiff_t;\n    static constexpr difference_type begin_value = 0;\n    static constexpr difference_type end_value = begin_value + 1;\n\n  JSON_PRIVATE_UNLESS_TESTED:\n    /// iterator as signed integer type\n    difference_type m_it = (std::numeric_limits<std::ptrdiff_t>::min)();\n\n  public:\n    constexpr difference_type get_value() const noexcept\n    {\n        return m_it;\n    }\n\n    /// set iterator to a defined beginning\n    void set_begin() noexcept\n    {\n        m_it = begin_value;\n    }\n\n    /// set iterator to a defined past the end\n    void set_end() noexcept\n    {\n        m_it = end_value;\n    }\n\n    /// return whether the iterator can be dereferenced\n    constexpr bool is_begin() const noexcept\n    {\n        return m_it == begin_value;\n    }\n\n    /// return whether the iterator is at end\n    constexpr bool is_end() const noexcept\n    {\n        return m_it == end_value;\n    }\n\n    friend constexpr bool operator==(primitive_iterator_t lhs, primitive_iterator_t rhs) noexcept\n    {\n        return lhs.m_it == rhs.m_it;\n    }\n\n    friend constexpr bool operator<(primitive_iterator_t lhs, primitive_iterator_t rhs) noexcept\n    {\n        return lhs.m_it < rhs.m_it;\n    }\n\n    primitive_iterator_t operator+(difference_type n) noexcept\n    {\n        auto result = *this;\n        result += n;\n        return result;\n    }\n\n    friend constexpr difference_type operator-(primitive_iterator_t lhs, primitive_iterator_t rhs) noexcept\n    {\n        return lhs.m_it - rhs.m_it;\n    }\n\n    primitive_iterator_t& operator++() noexcept\n    {\n        ++m_it;\n        return *this;\n    }\n\n    primitive_iterator_t operator++(int)& noexcept // NOLINT(cert-dcl21-cpp)\n    {\n        auto result = *this;\n        ++m_it;\n        return result;\n    }\n\n    
primitive_iterator_t& operator--() noexcept\n    {\n        --m_it;\n        return *this;\n    }\n\n    primitive_iterator_t operator--(int)& noexcept // NOLINT(cert-dcl21-cpp)\n    {\n        auto result = *this;\n        --m_it;\n        return result;\n    }\n\n    primitive_iterator_t& operator+=(difference_type n) noexcept\n    {\n        m_it += n;\n        return *this;\n    }\n\n    primitive_iterator_t& operator-=(difference_type n) noexcept\n    {\n        m_it -= n;\n        return *this;\n    }\n};\n\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\n\n/*!\n@brief an iterator value\n\n@note This structure could easily be a union, but MSVC currently does not allow\nunion members with complex constructors, see https://github.com/nlohmann/json/pull/105.\n*/\ntemplate<typename BasicJsonType> struct internal_iterator\n{\n    /// iterator for JSON objects\n    typename BasicJsonType::object_t::iterator object_iterator {};\n    /// iterator for JSON arrays\n    typename BasicJsonType::array_t::iterator array_iterator {};\n    /// generic iterator for all other types\n    primitive_iterator_t primitive_iterator {};\n};\n\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include <nlohmann/detail/iterators/iter_impl.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#include <iterator> // iterator, random_access_iterator_tag, bidirectional_iterator_tag, advance, next\n#include <type_traits> // conditional, is_const, remove_const\n\n// #include <nlohmann/detail/exceptions.hpp>\n\n// #include <nlohmann/detail/iterators/internal_iterator.hpp>\n\n// #include <nlohmann/detail/iterators/primitive_iterator.hpp>\n\n// #include 
<nlohmann/detail/macro_scope.hpp>\n\n// #include <nlohmann/detail/meta/cpp_future.hpp>\n\n// #include <nlohmann/detail/meta/type_traits.hpp>\n\n// #include <nlohmann/detail/value_t.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\n\n// forward declare, to be able to friend it later on\ntemplate<typename IteratorType> class iteration_proxy;\ntemplate<typename IteratorType> class iteration_proxy_value;\n\n/*!\n@brief a template for a bidirectional iterator for the @ref basic_json class\nThis class implements both iterators (iterator and const_iterator) for the\n@ref basic_json class.\n@note An iterator is called *initialized* when a pointer to a JSON value has\n      been set (e.g., by a constructor or a copy assignment). If the iterator is\n      default-constructed, it is *uninitialized* and most methods are undefined.\n      **The library uses assertions to detect calls on uninitialized iterators.**\n@requirement The class satisfies the following concept requirements:\n-\n[BidirectionalIterator](https://en.cppreference.com/w/cpp/named_req/BidirectionalIterator):\n  The iterator can be moved in both directions (i.e.\n  incremented and decremented).\n@since version 1.0.0, simplified in version 2.0.9, change to bidirectional\n       iterators in version 3.0.0 (see https://github.com/nlohmann/json/issues/593)\n*/\ntemplate<typename BasicJsonType>\nclass iter_impl // NOLINT(cppcoreguidelines-special-member-functions,hicpp-special-member-functions)\n{\n    /// the iterator with BasicJsonType of different const-ness\n    using other_iter_impl = iter_impl<typename std::conditional<std::is_const<BasicJsonType>::value, typename std::remove_const<BasicJsonType>::type, const BasicJsonType>::type>;\n    /// allow basic_json to access private members\n    friend other_iter_impl;\n    friend BasicJsonType;\n    friend iteration_proxy<iter_impl>;\n    friend iteration_proxy_value<iter_impl>;\n\n    using object_t = typename 
BasicJsonType::object_t;\n    using array_t = typename BasicJsonType::array_t;\n    // make sure BasicJsonType is basic_json or const basic_json\n    static_assert(is_basic_json<typename std::remove_const<BasicJsonType>::type>::value,\n                  \"iter_impl only accepts (const) basic_json\");\n    // superficial check for the LegacyBidirectionalIterator named requirement\n    static_assert(std::is_base_of<std::bidirectional_iterator_tag, std::bidirectional_iterator_tag>::value\n                  &&  std::is_base_of<std::bidirectional_iterator_tag, typename std::iterator_traits<typename array_t::iterator>::iterator_category>::value,\n                  \"basic_json iterator assumes array and object type iterators satisfy the LegacyBidirectionalIterator named requirement.\");\n\n  public:\n    /// The std::iterator class template (used as a base class to provide typedefs) is deprecated in C++17.\n    /// The C++ Standard has never required user-defined iterators to derive from std::iterator.\n    /// A user-defined iterator should provide publicly accessible typedefs named\n    /// iterator_category, value_type, difference_type, pointer, and reference.\n    /// Note that value_type is required to be non-const, even for constant iterators.\n    using iterator_category = std::bidirectional_iterator_tag;\n\n    /// the type of the values when the iterator is dereferenced\n    using value_type = typename BasicJsonType::value_type;\n    /// a type to represent differences between iterators\n    using difference_type = typename BasicJsonType::difference_type;\n    /// defines a pointer to the type iterated over (value_type)\n    using pointer = typename std::conditional<std::is_const<BasicJsonType>::value,\n          typename BasicJsonType::const_pointer,\n          typename BasicJsonType::pointer>::type;\n    /// defines a reference to the type iterated over (value_type)\n    using reference =\n        typename 
std::conditional<std::is_const<BasicJsonType>::value,\n        typename BasicJsonType::const_reference,\n        typename BasicJsonType::reference>::type;\n\n    iter_impl() = default;\n    ~iter_impl() = default;\n    iter_impl(iter_impl&&) noexcept = default;\n    iter_impl& operator=(iter_impl&&) noexcept = default;\n\n    /*!\n    @brief constructor for a given JSON instance\n    @param[in] object  pointer to a JSON object for this iterator\n    @pre object != nullptr\n    @post The iterator is initialized; i.e. `m_object != nullptr`.\n    */\n    explicit iter_impl(pointer object) noexcept : m_object(object)\n    {\n        JSON_ASSERT(m_object != nullptr);\n\n        switch (m_object->m_data.m_type)\n        {\n            case value_t::object:\n            {\n                m_it.object_iterator = typename object_t::iterator();\n                break;\n            }\n\n            case value_t::array:\n            {\n                m_it.array_iterator = typename array_t::iterator();\n                break;\n            }\n\n            case value_t::null:\n            case value_t::string:\n            case value_t::boolean:\n            case value_t::number_integer:\n            case value_t::number_unsigned:\n            case value_t::number_float:\n            case value_t::binary:\n            case value_t::discarded:\n            default:\n            {\n                m_it.primitive_iterator = primitive_iterator_t();\n                break;\n            }\n        }\n    }\n\n    /*!\n    @note The conventional copy constructor and copy assignment are implicitly\n          defined. Combined with the following converting constructor and\n          assignment, they support: (1) copy from iterator to iterator, (2)\n          copy from const iterator to const iterator, and (3) conversion from\n          iterator to const iterator. 
However conversion from const iterator\n          to iterator is not defined.\n    */\n\n    /*!\n    @brief const copy constructor\n    @param[in] other const iterator to copy from\n    @note This copy constructor had to be defined explicitly to circumvent a bug\n          occurring on msvc v19.0 compiler (VS 2015) debug build. For more\n          information refer to: https://github.com/nlohmann/json/issues/1608\n    */\n    iter_impl(const iter_impl<const BasicJsonType>& other) noexcept\n        : m_object(other.m_object), m_it(other.m_it)\n    {}\n\n    /*!\n    @brief converting assignment\n    @param[in] other const iterator to copy from\n    @return const/non-const iterator\n    @note It is not checked whether @a other is initialized.\n    */\n    iter_impl& operator=(const iter_impl<const BasicJsonType>& other) noexcept\n    {\n        if (&other != this)\n        {\n            m_object = other.m_object;\n            m_it = other.m_it;\n        }\n        return *this;\n    }\n\n    /*!\n    @brief converting constructor\n    @param[in] other  non-const iterator to copy from\n    @note It is not checked whether @a other is initialized.\n    */\n    iter_impl(const iter_impl<typename std::remove_const<BasicJsonType>::type>& other) noexcept\n        : m_object(other.m_object), m_it(other.m_it)\n    {}\n\n    /*!\n    @brief converting assignment\n    @param[in] other  non-const iterator to copy from\n    @return const/non-const iterator\n    @note It is not checked whether @a other is initialized.\n    */\n    iter_impl& operator=(const iter_impl<typename std::remove_const<BasicJsonType>::type>& other) noexcept // NOLINT(cert-oop54-cpp)\n    {\n        m_object = other.m_object;\n        m_it = other.m_it;\n        return *this;\n    }\n\n  JSON_PRIVATE_UNLESS_TESTED:\n    /*!\n    @brief set the iterator to the first value\n    @pre The iterator is initialized; i.e. 
`m_object != nullptr`.\n    */\n    void set_begin() noexcept\n    {\n        JSON_ASSERT(m_object != nullptr);\n\n        switch (m_object->m_data.m_type)\n        {\n            case value_t::object:\n            {\n                m_it.object_iterator = m_object->m_data.m_value.object->begin();\n                break;\n            }\n\n            case value_t::array:\n            {\n                m_it.array_iterator = m_object->m_data.m_value.array->begin();\n                break;\n            }\n\n            case value_t::null:\n            {\n                // set to end so begin()==end() is true: null is empty\n                m_it.primitive_iterator.set_end();\n                break;\n            }\n\n            case value_t::string:\n            case value_t::boolean:\n            case value_t::number_integer:\n            case value_t::number_unsigned:\n            case value_t::number_float:\n            case value_t::binary:\n            case value_t::discarded:\n            default:\n            {\n                m_it.primitive_iterator.set_begin();\n                break;\n            }\n        }\n    }\n\n    /*!\n    @brief set the iterator past the last value\n    @pre The iterator is initialized; i.e. 
`m_object != nullptr`.\n    */\n    void set_end() noexcept\n    {\n        JSON_ASSERT(m_object != nullptr);\n\n        switch (m_object->m_data.m_type)\n        {\n            case value_t::object:\n            {\n                m_it.object_iterator = m_object->m_data.m_value.object->end();\n                break;\n            }\n\n            case value_t::array:\n            {\n                m_it.array_iterator = m_object->m_data.m_value.array->end();\n                break;\n            }\n\n            case value_t::null:\n            case value_t::string:\n            case value_t::boolean:\n            case value_t::number_integer:\n            case value_t::number_unsigned:\n            case value_t::number_float:\n            case value_t::binary:\n            case value_t::discarded:\n            default:\n            {\n                m_it.primitive_iterator.set_end();\n                break;\n            }\n        }\n    }\n\n  public:\n    /*!\n    @brief return a reference to the value pointed to by the iterator\n    @pre The iterator is initialized; i.e. 
`m_object != nullptr`.\n    */\n    reference operator*() const\n    {\n        JSON_ASSERT(m_object != nullptr);\n\n        switch (m_object->m_data.m_type)\n        {\n            case value_t::object:\n            {\n                JSON_ASSERT(m_it.object_iterator != m_object->m_data.m_value.object->end());\n                return m_it.object_iterator->second;\n            }\n\n            case value_t::array:\n            {\n                JSON_ASSERT(m_it.array_iterator != m_object->m_data.m_value.array->end());\n                return *m_it.array_iterator;\n            }\n\n            case value_t::null:\n                JSON_THROW(invalid_iterator::create(214, \"cannot get value\", m_object));\n\n            case value_t::string:\n            case value_t::boolean:\n            case value_t::number_integer:\n            case value_t::number_unsigned:\n            case value_t::number_float:\n            case value_t::binary:\n            case value_t::discarded:\n            default:\n            {\n                if (JSON_HEDLEY_LIKELY(m_it.primitive_iterator.is_begin()))\n                {\n                    return *m_object;\n                }\n\n                JSON_THROW(invalid_iterator::create(214, \"cannot get value\", m_object));\n            }\n        }\n    }\n\n    /*!\n    @brief dereference the iterator\n    @pre The iterator is initialized; i.e. 
`m_object != nullptr`.\n    */\n    pointer operator->() const\n    {\n        JSON_ASSERT(m_object != nullptr);\n\n        switch (m_object->m_data.m_type)\n        {\n            case value_t::object:\n            {\n                JSON_ASSERT(m_it.object_iterator != m_object->m_data.m_value.object->end());\n                return &(m_it.object_iterator->second);\n            }\n\n            case value_t::array:\n            {\n                JSON_ASSERT(m_it.array_iterator != m_object->m_data.m_value.array->end());\n                return &*m_it.array_iterator;\n            }\n\n            case value_t::null:\n            case value_t::string:\n            case value_t::boolean:\n            case value_t::number_integer:\n            case value_t::number_unsigned:\n            case value_t::number_float:\n            case value_t::binary:\n            case value_t::discarded:\n            default:\n            {\n                if (JSON_HEDLEY_LIKELY(m_it.primitive_iterator.is_begin()))\n                {\n                    return m_object;\n                }\n\n                JSON_THROW(invalid_iterator::create(214, \"cannot get value\", m_object));\n            }\n        }\n    }\n\n    /*!\n    @brief post-increment (it++)\n    @pre The iterator is initialized; i.e. `m_object != nullptr`.\n    */\n    iter_impl operator++(int)& // NOLINT(cert-dcl21-cpp)\n    {\n        auto result = *this;\n        ++(*this);\n        return result;\n    }\n\n    /*!\n    @brief pre-increment (++it)\n    @pre The iterator is initialized; i.e. 
`m_object != nullptr`.\n    */\n    iter_impl& operator++()\n    {\n        JSON_ASSERT(m_object != nullptr);\n\n        switch (m_object->m_data.m_type)\n        {\n            case value_t::object:\n            {\n                std::advance(m_it.object_iterator, 1);\n                break;\n            }\n\n            case value_t::array:\n            {\n                std::advance(m_it.array_iterator, 1);\n                break;\n            }\n\n            case value_t::null:\n            case value_t::string:\n            case value_t::boolean:\n            case value_t::number_integer:\n            case value_t::number_unsigned:\n            case value_t::number_float:\n            case value_t::binary:\n            case value_t::discarded:\n            default:\n            {\n                ++m_it.primitive_iterator;\n                break;\n            }\n        }\n\n        return *this;\n    }\n\n    /*!\n    @brief post-decrement (it--)\n    @pre The iterator is initialized; i.e. `m_object != nullptr`.\n    */\n    iter_impl operator--(int)& // NOLINT(cert-dcl21-cpp)\n    {\n        auto result = *this;\n        --(*this);\n        return result;\n    }\n\n    /*!\n    @brief pre-decrement (--it)\n    @pre The iterator is initialized; i.e. 
`m_object != nullptr`.\n    */\n    iter_impl& operator--()\n    {\n        JSON_ASSERT(m_object != nullptr);\n\n        switch (m_object->m_data.m_type)\n        {\n            case value_t::object:\n            {\n                std::advance(m_it.object_iterator, -1);\n                break;\n            }\n\n            case value_t::array:\n            {\n                std::advance(m_it.array_iterator, -1);\n                break;\n            }\n\n            case value_t::null:\n            case value_t::string:\n            case value_t::boolean:\n            case value_t::number_integer:\n            case value_t::number_unsigned:\n            case value_t::number_float:\n            case value_t::binary:\n            case value_t::discarded:\n            default:\n            {\n                --m_it.primitive_iterator;\n                break;\n            }\n        }\n\n        return *this;\n    }\n\n    /*!\n    @brief comparison: equal\n    @pre The iterator is initialized; i.e. 
`m_object != nullptr`.\n    */\n    template < typename IterImpl, detail::enable_if_t < (std::is_same<IterImpl, iter_impl>::value || std::is_same<IterImpl, other_iter_impl>::value), std::nullptr_t > = nullptr >\n    bool operator==(const IterImpl& other) const\n    {\n        // if objects are not the same, the comparison is undefined\n        if (JSON_HEDLEY_UNLIKELY(m_object != other.m_object))\n        {\n            JSON_THROW(invalid_iterator::create(212, \"cannot compare iterators of different containers\", m_object));\n        }\n\n        JSON_ASSERT(m_object != nullptr);\n\n        switch (m_object->m_data.m_type)\n        {\n            case value_t::object:\n                return (m_it.object_iterator == other.m_it.object_iterator);\n\n            case value_t::array:\n                return (m_it.array_iterator == other.m_it.array_iterator);\n\n            case value_t::null:\n            case value_t::string:\n            case value_t::boolean:\n            case value_t::number_integer:\n            case value_t::number_unsigned:\n            case value_t::number_float:\n            case value_t::binary:\n            case value_t::discarded:\n            default:\n                return (m_it.primitive_iterator == other.m_it.primitive_iterator);\n        }\n    }\n\n    /*!\n    @brief comparison: not equal\n    @pre The iterator is initialized; i.e. `m_object != nullptr`.\n    */\n    template < typename IterImpl, detail::enable_if_t < (std::is_same<IterImpl, iter_impl>::value || std::is_same<IterImpl, other_iter_impl>::value), std::nullptr_t > = nullptr >\n    bool operator!=(const IterImpl& other) const\n    {\n        return !operator==(other);\n    }\n\n    /*!\n    @brief comparison: smaller\n    @pre The iterator is initialized; i.e. 
`m_object != nullptr`.\n    */\n    bool operator<(const iter_impl& other) const\n    {\n        // if objects are not the same, the comparison is undefined\n        if (JSON_HEDLEY_UNLIKELY(m_object != other.m_object))\n        {\n            JSON_THROW(invalid_iterator::create(212, \"cannot compare iterators of different containers\", m_object));\n        }\n\n        JSON_ASSERT(m_object != nullptr);\n\n        switch (m_object->m_data.m_type)\n        {\n            case value_t::object:\n                JSON_THROW(invalid_iterator::create(213, \"cannot compare order of object iterators\", m_object));\n\n            case value_t::array:\n                return (m_it.array_iterator < other.m_it.array_iterator);\n\n            case value_t::null:\n            case value_t::string:\n            case value_t::boolean:\n            case value_t::number_integer:\n            case value_t::number_unsigned:\n            case value_t::number_float:\n            case value_t::binary:\n            case value_t::discarded:\n            default:\n                return (m_it.primitive_iterator < other.m_it.primitive_iterator);\n        }\n    }\n\n    /*!\n    @brief comparison: less than or equal\n    @pre The iterator is initialized; i.e. `m_object != nullptr`.\n    */\n    bool operator<=(const iter_impl& other) const\n    {\n        return !other.operator < (*this);\n    }\n\n    /*!\n    @brief comparison: greater than\n    @pre The iterator is initialized; i.e. `m_object != nullptr`.\n    */\n    bool operator>(const iter_impl& other) const\n    {\n        return !operator<=(other);\n    }\n\n    /*!\n    @brief comparison: greater than or equal\n    @pre The iterator is initialized; i.e. `m_object != nullptr`.\n    */\n    bool operator>=(const iter_impl& other) const\n    {\n        return !operator<(other);\n    }\n\n    /*!\n    @brief add to iterator\n    @pre The iterator is initialized; i.e. 
`m_object != nullptr`.\n    */\n    iter_impl& operator+=(difference_type i)\n    {\n        JSON_ASSERT(m_object != nullptr);\n\n        switch (m_object->m_data.m_type)\n        {\n            case value_t::object:\n                JSON_THROW(invalid_iterator::create(209, \"cannot use offsets with object iterators\", m_object));\n\n            case value_t::array:\n            {\n                std::advance(m_it.array_iterator, i);\n                break;\n            }\n\n            case value_t::null:\n            case value_t::string:\n            case value_t::boolean:\n            case value_t::number_integer:\n            case value_t::number_unsigned:\n            case value_t::number_float:\n            case value_t::binary:\n            case value_t::discarded:\n            default:\n            {\n                m_it.primitive_iterator += i;\n                break;\n            }\n        }\n\n        return *this;\n    }\n\n    /*!\n    @brief subtract from iterator\n    @pre The iterator is initialized; i.e. `m_object != nullptr`.\n    */\n    iter_impl& operator-=(difference_type i)\n    {\n        return operator+=(-i);\n    }\n\n    /*!\n    @brief add to iterator\n    @pre The iterator is initialized; i.e. `m_object != nullptr`.\n    */\n    iter_impl operator+(difference_type i) const\n    {\n        auto result = *this;\n        result += i;\n        return result;\n    }\n\n    /*!\n    @brief addition of distance and iterator\n    @pre The iterator is initialized; i.e. `m_object != nullptr`.\n    */\n    friend iter_impl operator+(difference_type i, const iter_impl& it)\n    {\n        auto result = it;\n        result += i;\n        return result;\n    }\n\n    /*!\n    @brief subtract from iterator\n    @pre The iterator is initialized; i.e. 
`m_object != nullptr`.\n    */\n    iter_impl operator-(difference_type i) const\n    {\n        auto result = *this;\n        result -= i;\n        return result;\n    }\n\n    /*!\n    @brief return difference\n    @pre The iterator is initialized; i.e. `m_object != nullptr`.\n    */\n    difference_type operator-(const iter_impl& other) const\n    {\n        JSON_ASSERT(m_object != nullptr);\n\n        switch (m_object->m_data.m_type)\n        {\n            case value_t::object:\n                JSON_THROW(invalid_iterator::create(209, \"cannot use offsets with object iterators\", m_object));\n\n            case value_t::array:\n                return m_it.array_iterator - other.m_it.array_iterator;\n\n            case value_t::null:\n            case value_t::string:\n            case value_t::boolean:\n            case value_t::number_integer:\n            case value_t::number_unsigned:\n            case value_t::number_float:\n            case value_t::binary:\n            case value_t::discarded:\n            default:\n                return m_it.primitive_iterator - other.m_it.primitive_iterator;\n        }\n    }\n\n    /*!\n    @brief access to successor\n    @pre The iterator is initialized; i.e. 
`m_object != nullptr`.\n    */\n    reference operator[](difference_type n) const\n    {\n        JSON_ASSERT(m_object != nullptr);\n\n        switch (m_object->m_data.m_type)\n        {\n            case value_t::object:\n                JSON_THROW(invalid_iterator::create(208, \"cannot use operator[] for object iterators\", m_object));\n\n            case value_t::array:\n                return *std::next(m_it.array_iterator, n);\n\n            case value_t::null:\n                JSON_THROW(invalid_iterator::create(214, \"cannot get value\", m_object));\n\n            case value_t::string:\n            case value_t::boolean:\n            case value_t::number_integer:\n            case value_t::number_unsigned:\n            case value_t::number_float:\n            case value_t::binary:\n            case value_t::discarded:\n            default:\n            {\n                if (JSON_HEDLEY_LIKELY(m_it.primitive_iterator.get_value() == -n))\n                {\n                    return *m_object;\n                }\n\n                JSON_THROW(invalid_iterator::create(214, \"cannot get value\", m_object));\n            }\n        }\n    }\n\n    /*!\n    @brief return the key of an object iterator\n    @pre The iterator is initialized; i.e. `m_object != nullptr`.\n    */\n    const typename object_t::key_type& key() const\n    {\n        JSON_ASSERT(m_object != nullptr);\n\n        if (JSON_HEDLEY_LIKELY(m_object->is_object()))\n        {\n            return m_it.object_iterator->first;\n        }\n\n        JSON_THROW(invalid_iterator::create(207, \"cannot use key() for non-object iterators\", m_object));\n    }\n\n    /*!\n    @brief return the value of an iterator\n    @pre The iterator is initialized; i.e. 
`m_object != nullptr`.\n    */\n    reference value() const\n    {\n        return operator*();\n    }\n\n  JSON_PRIVATE_UNLESS_TESTED:\n    /// associated JSON instance\n    pointer m_object = nullptr;\n    /// the actual iterator of the associated instance\n    internal_iterator<typename std::remove_const<BasicJsonType>::type> m_it {};\n};\n\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include <nlohmann/detail/iterators/iteration_proxy.hpp>\n\n// #include <nlohmann/detail/iterators/json_reverse_iterator.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#include <cstddef> // ptrdiff_t\n#include <iterator> // reverse_iterator\n#include <utility> // declval\n\n// #include <nlohmann/detail/abi_macros.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\n\n//////////////////////\n// reverse_iterator //\n//////////////////////\n\n/*!\n@brief a template for a reverse iterator class\n\n@tparam Base the base iterator type to reverse. 
Valid types are @ref\niterator (to create @ref reverse_iterator) and @ref const_iterator (to\ncreate @ref const_reverse_iterator).\n\n@requirement The class satisfies the following concept requirements:\n-\n[BidirectionalIterator](https://en.cppreference.com/w/cpp/named_req/BidirectionalIterator):\n  The iterator can be moved in both directions (i.e.\n  incremented and decremented).\n- [OutputIterator](https://en.cppreference.com/w/cpp/named_req/OutputIterator):\n  It is possible to write to the pointed-to element (only if @a Base is\n  @ref iterator).\n\n@since version 1.0.0\n*/\ntemplate<typename Base>\nclass json_reverse_iterator : public std::reverse_iterator<Base>\n{\n  public:\n    using difference_type = std::ptrdiff_t;\n    /// shortcut to the reverse iterator adapter\n    using base_iterator = std::reverse_iterator<Base>;\n    /// the reference type for the pointed-to element\n    using reference = typename Base::reference;\n\n    /// create reverse iterator from iterator\n    explicit json_reverse_iterator(const typename base_iterator::iterator_type& it) noexcept\n        : base_iterator(it) {}\n\n    /// create reverse iterator from base class\n    explicit json_reverse_iterator(const base_iterator& it) noexcept : base_iterator(it) {}\n\n    /// post-increment (it++)\n    json_reverse_iterator operator++(int)& // NOLINT(cert-dcl21-cpp)\n    {\n        return static_cast<json_reverse_iterator>(base_iterator::operator++(1));\n    }\n\n    /// pre-increment (++it)\n    json_reverse_iterator& operator++()\n    {\n        return static_cast<json_reverse_iterator&>(base_iterator::operator++());\n    }\n\n    /// post-decrement (it--)\n    json_reverse_iterator operator--(int)& // NOLINT(cert-dcl21-cpp)\n    {\n        return static_cast<json_reverse_iterator>(base_iterator::operator--(1));\n    }\n\n    /// pre-decrement (--it)\n    json_reverse_iterator& operator--()\n    {\n        return 
static_cast<json_reverse_iterator&>(base_iterator::operator--());\n    }\n\n    /// add to iterator\n    json_reverse_iterator& operator+=(difference_type i)\n    {\n        return static_cast<json_reverse_iterator&>(base_iterator::operator+=(i));\n    }\n\n    /// add to iterator\n    json_reverse_iterator operator+(difference_type i) const\n    {\n        return static_cast<json_reverse_iterator>(base_iterator::operator+(i));\n    }\n\n    /// subtract from iterator\n    json_reverse_iterator operator-(difference_type i) const\n    {\n        return static_cast<json_reverse_iterator>(base_iterator::operator-(i));\n    }\n\n    /// return difference\n    difference_type operator-(const json_reverse_iterator& other) const\n    {\n        return base_iterator(*this) - base_iterator(other);\n    }\n\n    /// access to successor\n    reference operator[](difference_type n) const\n    {\n        return *(this->operator+(n));\n    }\n\n    /// return the key of an object iterator\n    auto key() const -> decltype(std::declval<Base>().key())\n    {\n        auto it = --this->base();\n        return it.key();\n    }\n\n    /// return the value of an iterator\n    reference value() const\n    {\n        auto it = --this->base();\n        return it.operator * ();\n    }\n};\n\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include <nlohmann/detail/iterators/primitive_iterator.hpp>\n\n// #include <nlohmann/detail/json_custom_base_class.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n\n#include <type_traits> // conditional, is_same\n\n// #include <nlohmann/detail/abi_macros.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\n\n/*!\n@brief Default base class of the @ref basic_json class.\n\nSo 
that the correct implementations of the copy / move ctors / assign operators\nof @ref basic_json do not require complex case distinctions\n(no base class / custom base class used as customization point),\n@ref basic_json always has a base class.\nBy default, this class is used because it is empty and thus has no effect\non the behavior of @ref basic_json.\n*/\nstruct json_default_base {};\n\ntemplate<class T>\nusing json_base_class = typename std::conditional <\n                        std::is_same<T, void>::value,\n                        json_default_base,\n                        T\n                        >::type;\n\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include <nlohmann/detail/json_pointer.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#include <algorithm> // all_of\n#include <cctype> // isdigit\n#include <cerrno> // errno, ERANGE\n#include <cstdlib> // strtoull\n#ifndef JSON_NO_IO\n    #include <iosfwd> // ostream\n#endif  // JSON_NO_IO\n#include <limits> // max\n#include <numeric> // accumulate\n#include <string> // string\n#include <utility> // move\n#include <vector> // vector\n\n// #include <nlohmann/detail/exceptions.hpp>\n\n// #include <nlohmann/detail/macro_scope.hpp>\n\n// #include <nlohmann/detail/string_concat.hpp>\n\n// #include <nlohmann/detail/string_escape.hpp>\n\n// #include <nlohmann/detail/value_t.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\n\n/// @brief JSON Pointer defines a string syntax for identifying a specific value within a JSON document\n/// @sa https://json.nlohmann.me/api/json_pointer/\ntemplate<typename RefStringType>\nclass json_pointer\n{\n    // allow basic_json to access private members\n    NLOHMANN_BASIC_JSON_TPL_DECLARATION\n    friend class 
basic_json;\n\n    template<typename>\n    friend class json_pointer;\n\n    template<typename T>\n    struct string_t_helper\n    {\n        using type = T;\n    };\n\n    NLOHMANN_BASIC_JSON_TPL_DECLARATION\n    struct string_t_helper<NLOHMANN_BASIC_JSON_TPL>\n    {\n        using type = StringType;\n    };\n\n  public:\n    // for backwards compatibility accept BasicJsonType\n    using string_t = typename string_t_helper<RefStringType>::type;\n\n    /// @brief create JSON pointer\n    /// @sa https://json.nlohmann.me/api/json_pointer/json_pointer/\n    explicit json_pointer(const string_t& s = \"\")\n        : reference_tokens(split(s))\n    {}\n\n    /// @brief return a string representation of the JSON pointer\n    /// @sa https://json.nlohmann.me/api/json_pointer/to_string/\n    string_t to_string() const\n    {\n        return std::accumulate(reference_tokens.begin(), reference_tokens.end(),\n                               string_t{},\n                               [](const string_t& a, const string_t& b)\n        {\n            return detail::concat(a, '/', detail::escape(b));\n        });\n    }\n\n    /// @brief return a string representation of the JSON pointer\n    /// @sa https://json.nlohmann.me/api/json_pointer/operator_string/\n    JSON_HEDLEY_DEPRECATED_FOR(3.11.0, to_string())\n    operator string_t() const\n    {\n        return to_string();\n    }\n\n#ifndef JSON_NO_IO\n    /// @brief write string representation of the JSON pointer to stream\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_ltlt/\n    friend std::ostream& operator<<(std::ostream& o, const json_pointer& ptr)\n    {\n        o << ptr.to_string();\n        return o;\n    }\n#endif\n\n    /// @brief append another JSON pointer at the end of this JSON pointer\n    /// @sa https://json.nlohmann.me/api/json_pointer/operator_slasheq/\n    json_pointer& operator/=(const json_pointer& ptr)\n    {\n        reference_tokens.insert(reference_tokens.end(),\n                       
         ptr.reference_tokens.begin(),\n                                ptr.reference_tokens.end());\n        return *this;\n    }\n\n    /// @brief append an unescaped reference token at the end of this JSON pointer\n    /// @sa https://json.nlohmann.me/api/json_pointer/operator_slasheq/\n    json_pointer& operator/=(string_t token)\n    {\n        push_back(std::move(token));\n        return *this;\n    }\n\n    /// @brief append an array index at the end of this JSON pointer\n    /// @sa https://json.nlohmann.me/api/json_pointer/operator_slasheq/\n    json_pointer& operator/=(std::size_t array_idx)\n    {\n        return *this /= std::to_string(array_idx);\n    }\n\n    /// @brief create a new JSON pointer by appending the right JSON pointer at the end of the left JSON pointer\n    /// @sa https://json.nlohmann.me/api/json_pointer/operator_slash/\n    friend json_pointer operator/(const json_pointer& lhs,\n                                  const json_pointer& rhs)\n    {\n        return json_pointer(lhs) /= rhs;\n    }\n\n    /// @brief create a new JSON pointer by appending the unescaped token at the end of the JSON pointer\n    /// @sa https://json.nlohmann.me/api/json_pointer/operator_slash/\n    friend json_pointer operator/(const json_pointer& lhs, string_t token) // NOLINT(performance-unnecessary-value-param)\n    {\n        return json_pointer(lhs) /= std::move(token);\n    }\n\n    /// @brief create a new JSON pointer by appending the array-index-token at the end of the JSON pointer\n    /// @sa https://json.nlohmann.me/api/json_pointer/operator_slash/\n    friend json_pointer operator/(const json_pointer& lhs, std::size_t array_idx)\n    {\n        return json_pointer(lhs) /= array_idx;\n    }\n\n    /// @brief returns the parent of this JSON pointer\n    /// @sa https://json.nlohmann.me/api/json_pointer/parent_pointer/\n    json_pointer parent_pointer() const\n    {\n        if (empty())\n        {\n            return *this;\n        }\n\n        
json_pointer res = *this;\n        res.pop_back();\n        return res;\n    }\n\n    /// @brief remove last reference token\n    /// @sa https://json.nlohmann.me/api/json_pointer/pop_back/\n    void pop_back()\n    {\n        if (JSON_HEDLEY_UNLIKELY(empty()))\n        {\n            JSON_THROW(detail::out_of_range::create(405, \"JSON pointer has no parent\", nullptr));\n        }\n\n        reference_tokens.pop_back();\n    }\n\n    /// @brief return last reference token\n    /// @sa https://json.nlohmann.me/api/json_pointer/back/\n    const string_t& back() const\n    {\n        if (JSON_HEDLEY_UNLIKELY(empty()))\n        {\n            JSON_THROW(detail::out_of_range::create(405, \"JSON pointer has no parent\", nullptr));\n        }\n\n        return reference_tokens.back();\n    }\n\n    /// @brief append an unescaped token at the end of the reference pointer\n    /// @sa https://json.nlohmann.me/api/json_pointer/push_back/\n    void push_back(const string_t& token)\n    {\n        reference_tokens.push_back(token);\n    }\n\n    /// @brief append an unescaped token at the end of the reference pointer\n    /// @sa https://json.nlohmann.me/api/json_pointer/push_back/\n    void push_back(string_t&& token)\n    {\n        reference_tokens.push_back(std::move(token));\n    }\n\n    /// @brief return whether pointer points to the root document\n    /// @sa https://json.nlohmann.me/api/json_pointer/empty/\n    bool empty() const noexcept\n    {\n        return reference_tokens.empty();\n    }\n\n  private:\n    /*!\n    @param[in] s  reference token to be converted into an array index\n\n    @return integer representation of @a s\n\n    @throw parse_error.106  if an array index begins with '0'\n    @throw parse_error.109  if an array index does not begin with a digit\n    @throw out_of_range.404 if string @a s could not be converted to an integer\n    @throw out_of_range.410 if an array index exceeds size_type\n    */\n    template<typename BasicJsonType>\n    static 
typename BasicJsonType::size_type array_index(const string_t& s)\n    {\n        using size_type = typename BasicJsonType::size_type;\n\n        // error condition (cf. RFC 6901, Sect. 4)\n        if (JSON_HEDLEY_UNLIKELY(s.size() > 1 && s[0] == '0'))\n        {\n            JSON_THROW(detail::parse_error::create(106, 0, detail::concat(\"array index '\", s, \"' must not begin with '0'\"), nullptr));\n        }\n\n        // error condition (cf. RFC 6901, Sect. 4)\n        if (JSON_HEDLEY_UNLIKELY(s.size() > 1 && !(s[0] >= '1' && s[0] <= '9')))\n        {\n            JSON_THROW(detail::parse_error::create(109, 0, detail::concat(\"array index '\", s, \"' is not a number\"), nullptr));\n        }\n\n        const char* p = s.c_str();\n        char* p_end = nullptr;\n        errno = 0; // strtoull doesn't reset errno\n        const unsigned long long res = std::strtoull(p, &p_end, 10); // NOLINT(runtime/int)\n        if (p == p_end // invalid input or empty string\n                || errno == ERANGE // out of range\n                || JSON_HEDLEY_UNLIKELY(static_cast<std::size_t>(p_end - p) != s.size())) // incomplete read\n        {\n            JSON_THROW(detail::out_of_range::create(404, detail::concat(\"unresolved reference token '\", s, \"'\"), nullptr));\n        }\n\n        // only triggered on special platforms (like 32bit), see also\n        // https://github.com/nlohmann/json/pull/2203\n        if (res >= static_cast<unsigned long long>((std::numeric_limits<size_type>::max)()))  // NOLINT(runtime/int)\n        {\n            JSON_THROW(detail::out_of_range::create(410, detail::concat(\"array index \", s, \" exceeds size_type\"), nullptr));   // LCOV_EXCL_LINE\n        }\n\n        return static_cast<size_type>(res);\n    }\n\n  JSON_PRIVATE_UNLESS_TESTED:\n    json_pointer top() const\n    {\n        if (JSON_HEDLEY_UNLIKELY(empty()))\n        {\n            JSON_THROW(detail::out_of_range::create(405, \"JSON pointer has no parent\", nullptr));\n        
}\n\n        json_pointer result = *this;\n        result.reference_tokens = {reference_tokens[0]};\n        return result;\n    }\n\n  private:\n    /*!\n    @brief create and return a reference to the pointed to value\n\n    @complexity Linear in the number of reference tokens.\n\n    @throw parse_error.109 if array index is not a number\n    @throw type_error.313 if value cannot be unflattened\n    */\n    template<typename BasicJsonType>\n    BasicJsonType& get_and_create(BasicJsonType& j) const\n    {\n        auto* result = &j;\n\n        // in case no reference tokens exist, return a reference to the JSON value\n        // j which will be overwritten by a primitive value\n        for (const auto& reference_token : reference_tokens)\n        {\n            switch (result->type())\n            {\n                case detail::value_t::null:\n                {\n                    if (reference_token == \"0\")\n                    {\n                        // start a new array if reference token is 0\n                        result = &result->operator[](0);\n                    }\n                    else\n                    {\n                        // start a new object otherwise\n                        result = &result->operator[](reference_token);\n                    }\n                    break;\n                }\n\n                case detail::value_t::object:\n                {\n                    // create an entry in the object\n                    result = &result->operator[](reference_token);\n                    break;\n                }\n\n                case detail::value_t::array:\n                {\n                    // create an entry in the array\n                    result = &result->operator[](array_index<BasicJsonType>(reference_token));\n                    break;\n                }\n\n                /*\n                The following code is only reached if there exists a reference\n                token _and_ the current value 
is primitive. In this case, we have\n                an error situation, because primitive values may only occur as\n                single value; that is, with an empty list of reference tokens.\n                */\n                case detail::value_t::string:\n                case detail::value_t::boolean:\n                case detail::value_t::number_integer:\n                case detail::value_t::number_unsigned:\n                case detail::value_t::number_float:\n                case detail::value_t::binary:\n                case detail::value_t::discarded:\n                default:\n                    JSON_THROW(detail::type_error::create(313, \"invalid value to unflatten\", &j));\n            }\n        }\n\n        return *result;\n    }\n\n    /*!\n    @brief return a reference to the pointed to value\n\n    @note This version does not throw if a value is not present, but tries to\n          create nested values instead. For instance, calling this function\n          with pointer `\"/this/that\"` on a null value is equivalent to calling\n          `operator[](\"this\").operator[](\"that\")` on that value, effectively\n          changing the null value to an object.\n\n    @param[in] ptr  a JSON value\n\n    @return reference to the JSON value pointed to by the JSON pointer\n\n    @complexity Linear in the length of the JSON pointer.\n\n    @throw parse_error.106   if an array index begins with '0'\n    @throw parse_error.109   if an array index was not a number\n    @throw out_of_range.404  if the JSON pointer can not be resolved\n    */\n    template<typename BasicJsonType>\n    BasicJsonType& get_unchecked(BasicJsonType* ptr) const\n    {\n        for (const auto& reference_token : reference_tokens)\n        {\n            // convert null values to arrays or objects before continuing\n            if (ptr->is_null())\n            {\n                // check if reference token is a number\n                const bool nums =\n                    
std::all_of(reference_token.begin(), reference_token.end(),\n                                [](const unsigned char x)\n                {\n                    return std::isdigit(x);\n                });\n\n                // change value to array for numbers or \"-\" or to object otherwise\n                *ptr = (nums || reference_token == \"-\")\n                       ? detail::value_t::array\n                       : detail::value_t::object;\n            }\n\n            switch (ptr->type())\n            {\n                case detail::value_t::object:\n                {\n                    // use unchecked object access\n                    ptr = &ptr->operator[](reference_token);\n                    break;\n                }\n\n                case detail::value_t::array:\n                {\n                    if (reference_token == \"-\")\n                    {\n                        // explicitly treat \"-\" as index beyond the end\n                        ptr = &ptr->operator[](ptr->m_data.m_value.array->size());\n                    }\n                    else\n                    {\n                        // convert array index to number; unchecked access\n                        ptr = &ptr->operator[](array_index<BasicJsonType>(reference_token));\n                    }\n                    break;\n                }\n\n                case detail::value_t::null:\n                case detail::value_t::string:\n                case detail::value_t::boolean:\n                case detail::value_t::number_integer:\n                case detail::value_t::number_unsigned:\n                case detail::value_t::number_float:\n                case detail::value_t::binary:\n                case detail::value_t::discarded:\n                default:\n                    JSON_THROW(detail::out_of_range::create(404, detail::concat(\"unresolved reference token '\", reference_token, \"'\"), ptr));\n            }\n        }\n\n        return *ptr;\n    }\n\n    
/*!\n    @throw parse_error.106   if an array index begins with '0'\n    @throw parse_error.109   if an array index was not a number\n    @throw out_of_range.402  if the array index '-' is used\n    @throw out_of_range.404  if the JSON pointer can not be resolved\n    */\n    template<typename BasicJsonType>\n    BasicJsonType& get_checked(BasicJsonType* ptr) const\n    {\n        for (const auto& reference_token : reference_tokens)\n        {\n            switch (ptr->type())\n            {\n                case detail::value_t::object:\n                {\n                    // note: at performs range check\n                    ptr = &ptr->at(reference_token);\n                    break;\n                }\n\n                case detail::value_t::array:\n                {\n                    if (JSON_HEDLEY_UNLIKELY(reference_token == \"-\"))\n                    {\n                        // \"-\" always fails the range check\n                        JSON_THROW(detail::out_of_range::create(402, detail::concat(\n                                \"array index '-' (\", std::to_string(ptr->m_data.m_value.array->size()),\n                                \") is out of range\"), ptr));\n                    }\n\n                    // note: at performs range check\n                    ptr = &ptr->at(array_index<BasicJsonType>(reference_token));\n                    break;\n                }\n\n                case detail::value_t::null:\n                case detail::value_t::string:\n                case detail::value_t::boolean:\n                case detail::value_t::number_integer:\n                case detail::value_t::number_unsigned:\n                case detail::value_t::number_float:\n                case detail::value_t::binary:\n                case detail::value_t::discarded:\n                default:\n                    JSON_THROW(detail::out_of_range::create(404, detail::concat(\"unresolved reference token '\", reference_token, \"'\"), ptr));\n            
}\n        }\n\n        return *ptr;\n    }\n\n    /*!\n    @brief return a const reference to the pointed to value\n\n    @param[in] ptr  a JSON value\n\n    @return const reference to the JSON value pointed to by the JSON\n    pointer\n\n    @throw parse_error.106   if an array index begins with '0'\n    @throw parse_error.109   if an array index was not a number\n    @throw out_of_range.402  if the array index '-' is used\n    @throw out_of_range.404  if the JSON pointer can not be resolved\n    */\n    template<typename BasicJsonType>\n    const BasicJsonType& get_unchecked(const BasicJsonType* ptr) const\n    {\n        for (const auto& reference_token : reference_tokens)\n        {\n            switch (ptr->type())\n            {\n                case detail::value_t::object:\n                {\n                    // use unchecked object access\n                    ptr = &ptr->operator[](reference_token);\n                    break;\n                }\n\n                case detail::value_t::array:\n                {\n                    if (JSON_HEDLEY_UNLIKELY(reference_token == \"-\"))\n                    {\n                        // \"-\" cannot be used for const access\n                        JSON_THROW(detail::out_of_range::create(402, detail::concat(\"array index '-' (\", std::to_string(ptr->m_data.m_value.array->size()), \") is out of range\"), ptr));\n                    }\n\n                    // use unchecked array access\n                    ptr = &ptr->operator[](array_index<BasicJsonType>(reference_token));\n                    break;\n                }\n\n                case detail::value_t::null:\n                case detail::value_t::string:\n                case detail::value_t::boolean:\n                case detail::value_t::number_integer:\n                case detail::value_t::number_unsigned:\n                case detail::value_t::number_float:\n                case detail::value_t::binary:\n                case 
detail::value_t::discarded:\n                default:\n                    JSON_THROW(detail::out_of_range::create(404, detail::concat(\"unresolved reference token '\", reference_token, \"'\"), ptr));\n            }\n        }\n\n        return *ptr;\n    }\n\n    /*!\n    @throw parse_error.106   if an array index begins with '0'\n    @throw parse_error.109   if an array index was not a number\n    @throw out_of_range.402  if the array index '-' is used\n    @throw out_of_range.404  if the JSON pointer can not be resolved\n    */\n    template<typename BasicJsonType>\n    const BasicJsonType& get_checked(const BasicJsonType* ptr) const\n    {\n        for (const auto& reference_token : reference_tokens)\n        {\n            switch (ptr->type())\n            {\n                case detail::value_t::object:\n                {\n                    // note: at performs range check\n                    ptr = &ptr->at(reference_token);\n                    break;\n                }\n\n                case detail::value_t::array:\n                {\n                    if (JSON_HEDLEY_UNLIKELY(reference_token == \"-\"))\n                    {\n                        // \"-\" always fails the range check\n                        JSON_THROW(detail::out_of_range::create(402, detail::concat(\n                                \"array index '-' (\", std::to_string(ptr->m_data.m_value.array->size()),\n                                \") is out of range\"), ptr));\n                    }\n\n                    // note: at performs range check\n                    ptr = &ptr->at(array_index<BasicJsonType>(reference_token));\n                    break;\n                }\n\n                case detail::value_t::null:\n                case detail::value_t::string:\n                case detail::value_t::boolean:\n                case detail::value_t::number_integer:\n                case detail::value_t::number_unsigned:\n                case detail::value_t::number_float:\n       
         case detail::value_t::binary:\n                case detail::value_t::discarded:\n                default:\n                    JSON_THROW(detail::out_of_range::create(404, detail::concat(\"unresolved reference token '\", reference_token, \"'\"), ptr));\n            }\n        }\n\n        return *ptr;\n    }\n\n    /*!\n    @throw parse_error.106   if an array index begins with '0'\n    @throw parse_error.109   if an array index was not a number\n    */\n    template<typename BasicJsonType>\n    bool contains(const BasicJsonType* ptr) const\n    {\n        for (const auto& reference_token : reference_tokens)\n        {\n            switch (ptr->type())\n            {\n                case detail::value_t::object:\n                {\n                    if (!ptr->contains(reference_token))\n                    {\n                        // we did not find the key in the object\n                        return false;\n                    }\n\n                    ptr = &ptr->operator[](reference_token);\n                    break;\n                }\n\n                case detail::value_t::array:\n                {\n                    if (JSON_HEDLEY_UNLIKELY(reference_token == \"-\"))\n                    {\n                        // \"-\" always fails the range check\n                        return false;\n                    }\n                    if (JSON_HEDLEY_UNLIKELY(reference_token.size() == 1 && !(\"0\" <= reference_token && reference_token <= \"9\")))\n                    {\n                        // invalid char\n                        return false;\n                    }\n                    if (JSON_HEDLEY_UNLIKELY(reference_token.size() > 1))\n                    {\n                        if (JSON_HEDLEY_UNLIKELY(!('1' <= reference_token[0] && reference_token[0] <= '9')))\n                        {\n                            // first char should be between '1' and '9'\n                            return false;\n                        }\n 
                       for (std::size_t i = 1; i < reference_token.size(); i++)\n                        {\n                            if (JSON_HEDLEY_UNLIKELY(!('0' <= reference_token[i] && reference_token[i] <= '9')))\n                            {\n                                // other char should be between '0' and '9'\n                                return false;\n                            }\n                        }\n                    }\n\n                    const auto idx = array_index<BasicJsonType>(reference_token);\n                    if (idx >= ptr->size())\n                    {\n                        // index out of range\n                        return false;\n                    }\n\n                    ptr = &ptr->operator[](idx);\n                    break;\n                }\n\n                case detail::value_t::null:\n                case detail::value_t::string:\n                case detail::value_t::boolean:\n                case detail::value_t::number_integer:\n                case detail::value_t::number_unsigned:\n                case detail::value_t::number_float:\n                case detail::value_t::binary:\n                case detail::value_t::discarded:\n                default:\n                {\n                    // we do not expect primitive values if there is still a\n                    // reference token to process\n                    return false;\n                }\n            }\n        }\n\n        // no reference token left means the pointer resolved to a value\n        return true;\n    }\n\n    /*!\n    @brief split the string input into reference tokens\n\n    @note This function is only called by the json_pointer constructor.\n          All exceptions below are documented there.\n\n    @throw parse_error.107  if the pointer is nonempty and does not begin with '/'\n    @throw parse_error.108  if character '~' is not followed by '0' or '1'\n    */\n    static std::vector<string_t> split(const string_t& 
reference_string)\n    {\n        std::vector<string_t> result;\n\n        // special case: empty reference string -> no reference tokens\n        if (reference_string.empty())\n        {\n            return result;\n        }\n\n        // check if nonempty reference string begins with slash\n        if (JSON_HEDLEY_UNLIKELY(reference_string[0] != '/'))\n        {\n            JSON_THROW(detail::parse_error::create(107, 1, detail::concat(\"JSON pointer must be empty or begin with '/' - was: '\", reference_string, \"'\"), nullptr));\n        }\n\n        // extract the reference tokens:\n        // - slash: position of the last read slash (or end of string)\n        // - start: position after the previous slash\n        for (\n            // search for the first slash after the first character\n            std::size_t slash = reference_string.find_first_of('/', 1),\n            // set the beginning of the first reference token\n            start = 1;\n            // we can stop if start == 0 (if slash == string_t::npos)\n            start != 0;\n            // set the beginning of the next reference token\n            // (will eventually be 0 if slash == string_t::npos)\n            start = (slash == string_t::npos) ? 
0 : slash + 1,\n            // find next slash\n            slash = reference_string.find_first_of('/', start))\n        {\n            // use the text between the beginning of the reference token\n            // (start) and the last slash (slash).\n            auto reference_token = reference_string.substr(start, slash - start);\n\n            // check reference tokens are properly escaped\n            for (std::size_t pos = reference_token.find_first_of('~');\n                    pos != string_t::npos;\n                    pos = reference_token.find_first_of('~', pos + 1))\n            {\n                JSON_ASSERT(reference_token[pos] == '~');\n\n                // ~ must be followed by 0 or 1\n                if (JSON_HEDLEY_UNLIKELY(pos == reference_token.size() - 1 ||\n                                         (reference_token[pos + 1] != '0' &&\n                                          reference_token[pos + 1] != '1')))\n                {\n                    JSON_THROW(detail::parse_error::create(108, 0, \"escape character '~' must be followed with '0' or '1'\", nullptr));\n                }\n            }\n\n            // finally, store the reference token\n            detail::unescape(reference_token);\n            result.push_back(reference_token);\n        }\n\n        return result;\n    }\n\n  private:\n    /*!\n    @param[in] reference_string  the reference string to the current value\n    @param[in] value             the value to consider\n    @param[in,out] result        the result object to insert values to\n\n    @note Empty objects or arrays are flattened to `null`.\n    */\n    template<typename BasicJsonType>\n    static void flatten(const string_t& reference_string,\n                        const BasicJsonType& value,\n                        BasicJsonType& result)\n    {\n        switch (value.type())\n        {\n            case detail::value_t::array:\n            {\n                if (value.m_data.m_value.array->empty())\n              
  {\n                    // flatten empty array as null\n                    result[reference_string] = nullptr;\n                }\n                else\n                {\n                    // iterate array and use index as reference string\n                    for (std::size_t i = 0; i < value.m_data.m_value.array->size(); ++i)\n                    {\n                        flatten(detail::concat(reference_string, '/', std::to_string(i)),\n                                value.m_data.m_value.array->operator[](i), result);\n                    }\n                }\n                break;\n            }\n\n            case detail::value_t::object:\n            {\n                if (value.m_data.m_value.object->empty())\n                {\n                    // flatten empty object as null\n                    result[reference_string] = nullptr;\n                }\n                else\n                {\n                    // iterate object and use keys as reference string\n                    for (const auto& element : *value.m_data.m_value.object)\n                    {\n                        flatten(detail::concat(reference_string, '/', detail::escape(element.first)), element.second, result);\n                    }\n                }\n                break;\n            }\n\n            case detail::value_t::null:\n            case detail::value_t::string:\n            case detail::value_t::boolean:\n            case detail::value_t::number_integer:\n            case detail::value_t::number_unsigned:\n            case detail::value_t::number_float:\n            case detail::value_t::binary:\n            case detail::value_t::discarded:\n            default:\n            {\n                // add primitive value with its reference string\n                result[reference_string] = value;\n                break;\n            }\n        }\n    }\n\n    /*!\n    @param[in] value  flattened JSON\n\n    @return unflattened JSON\n\n    @throw parse_error.109 
if array index is not a number\n    @throw type_error.314  if value is not an object\n    @throw type_error.315  if object values are not primitive\n    @throw type_error.313  if value cannot be unflattened\n    */\n    template<typename BasicJsonType>\n    static BasicJsonType\n    unflatten(const BasicJsonType& value)\n    {\n        if (JSON_HEDLEY_UNLIKELY(!value.is_object()))\n        {\n            JSON_THROW(detail::type_error::create(314, \"only objects can be unflattened\", &value));\n        }\n\n        BasicJsonType result;\n\n        // iterate the JSON object values\n        for (const auto& element : *value.m_data.m_value.object)\n        {\n            if (JSON_HEDLEY_UNLIKELY(!element.second.is_primitive()))\n            {\n                JSON_THROW(detail::type_error::create(315, \"values in object must be primitive\", &element.second));\n            }\n\n            // assign value to reference pointed to by JSON pointer; Note that if\n            // the JSON pointer is \"\" (i.e., points to the whole value), function\n            // get_and_create returns a reference to result itself. 
An assignment\n            // will then create a primitive value.\n            json_pointer(element.first).get_and_create(result) = element.second;\n        }\n\n        return result;\n    }\n\n    // can't use conversion operator because of ambiguity\n    json_pointer<string_t> convert() const&\n    {\n        json_pointer<string_t> result;\n        result.reference_tokens = reference_tokens;\n        return result;\n    }\n\n    json_pointer<string_t> convert()&&\n    {\n        json_pointer<string_t> result;\n        result.reference_tokens = std::move(reference_tokens);\n        return result;\n    }\n\n  public:\n#if JSON_HAS_THREE_WAY_COMPARISON\n    /// @brief compares two JSON pointers for equality\n    /// @sa https://json.nlohmann.me/api/json_pointer/operator_eq/\n    template<typename RefStringTypeRhs>\n    bool operator==(const json_pointer<RefStringTypeRhs>& rhs) const noexcept\n    {\n        return reference_tokens == rhs.reference_tokens;\n    }\n\n    /// @brief compares JSON pointer and string for equality\n    /// @sa https://json.nlohmann.me/api/json_pointer/operator_eq/\n    JSON_HEDLEY_DEPRECATED_FOR(3.11.2, operator==(json_pointer))\n    bool operator==(const string_t& rhs) const\n    {\n        return *this == json_pointer(rhs);\n    }\n\n    /// @brief 3-way compares two JSON pointers\n    template<typename RefStringTypeRhs>\n    std::strong_ordering operator<=>(const json_pointer<RefStringTypeRhs>& rhs) const noexcept // *NOPAD*\n    {\n        return  reference_tokens <=> rhs.reference_tokens; // *NOPAD*\n    }\n#else\n    /// @brief compares two JSON pointers for equality\n    /// @sa https://json.nlohmann.me/api/json_pointer/operator_eq/\n    template<typename RefStringTypeLhs, typename RefStringTypeRhs>\n    // NOLINTNEXTLINE(readability-redundant-declaration)\n    friend bool operator==(const json_pointer<RefStringTypeLhs>& lhs,\n                           const json_pointer<RefStringTypeRhs>& rhs) noexcept;\n\n    /// @brief 
compares JSON pointer and string for equality\n    /// @sa https://json.nlohmann.me/api/json_pointer/operator_eq/\n    template<typename RefStringTypeLhs, typename StringType>\n    // NOLINTNEXTLINE(readability-redundant-declaration)\n    friend bool operator==(const json_pointer<RefStringTypeLhs>& lhs,\n                           const StringType& rhs);\n\n    /// @brief compares string and JSON pointer for equality\n    /// @sa https://json.nlohmann.me/api/json_pointer/operator_eq/\n    template<typename RefStringTypeRhs, typename StringType>\n    // NOLINTNEXTLINE(readability-redundant-declaration)\n    friend bool operator==(const StringType& lhs,\n                           const json_pointer<RefStringTypeRhs>& rhs);\n\n    /// @brief compares two JSON pointers for inequality\n    /// @sa https://json.nlohmann.me/api/json_pointer/operator_ne/\n    template<typename RefStringTypeLhs, typename RefStringTypeRhs>\n    // NOLINTNEXTLINE(readability-redundant-declaration)\n    friend bool operator!=(const json_pointer<RefStringTypeLhs>& lhs,\n                           const json_pointer<RefStringTypeRhs>& rhs) noexcept;\n\n    /// @brief compares JSON pointer and string for inequality\n    /// @sa https://json.nlohmann.me/api/json_pointer/operator_ne/\n    template<typename RefStringTypeLhs, typename StringType>\n    // NOLINTNEXTLINE(readability-redundant-declaration)\n    friend bool operator!=(const json_pointer<RefStringTypeLhs>& lhs,\n                           const StringType& rhs);\n\n    /// @brief compares string and JSON pointer for inequality\n    /// @sa https://json.nlohmann.me/api/json_pointer/operator_ne/\n    template<typename RefStringTypeRhs, typename StringType>\n    // NOLINTNEXTLINE(readability-redundant-declaration)\n    friend bool operator!=(const StringType& lhs,\n                           const json_pointer<RefStringTypeRhs>& rhs);\n\n    /// @brief compares two JSON pointers for less-than\n    template<typename RefStringTypeLhs, typename 
RefStringTypeRhs>\n    // NOLINTNEXTLINE(readability-redundant-declaration)\n    friend bool operator<(const json_pointer<RefStringTypeLhs>& lhs,\n                          const json_pointer<RefStringTypeRhs>& rhs) noexcept;\n#endif\n\n  private:\n    /// the reference tokens\n    std::vector<string_t> reference_tokens;\n};\n\n#if !JSON_HAS_THREE_WAY_COMPARISON\n// functions cannot be defined inside class due to ODR violations\ntemplate<typename RefStringTypeLhs, typename RefStringTypeRhs>\ninline bool operator==(const json_pointer<RefStringTypeLhs>& lhs,\n                       const json_pointer<RefStringTypeRhs>& rhs) noexcept\n{\n    return lhs.reference_tokens == rhs.reference_tokens;\n}\n\ntemplate<typename RefStringTypeLhs,\n         typename StringType = typename json_pointer<RefStringTypeLhs>::string_t>\nJSON_HEDLEY_DEPRECATED_FOR(3.11.2, operator==(json_pointer, json_pointer))\ninline bool operator==(const json_pointer<RefStringTypeLhs>& lhs,\n                       const StringType& rhs)\n{\n    return lhs == json_pointer<RefStringTypeLhs>(rhs);\n}\n\ntemplate<typename RefStringTypeRhs,\n         typename StringType = typename json_pointer<RefStringTypeRhs>::string_t>\nJSON_HEDLEY_DEPRECATED_FOR(3.11.2, operator==(json_pointer, json_pointer))\ninline bool operator==(const StringType& lhs,\n                       const json_pointer<RefStringTypeRhs>& rhs)\n{\n    return json_pointer<RefStringTypeRhs>(lhs) == rhs;\n}\n\ntemplate<typename RefStringTypeLhs, typename RefStringTypeRhs>\ninline bool operator!=(const json_pointer<RefStringTypeLhs>& lhs,\n                       const json_pointer<RefStringTypeRhs>& rhs) noexcept\n{\n    return !(lhs == rhs);\n}\n\ntemplate<typename RefStringTypeLhs,\n         typename StringType = typename json_pointer<RefStringTypeLhs>::string_t>\nJSON_HEDLEY_DEPRECATED_FOR(3.11.2, operator!=(json_pointer, json_pointer))\ninline bool operator!=(const json_pointer<RefStringTypeLhs>& lhs,\n                       const StringType& 
rhs)\n{\n    return !(lhs == rhs);\n}\n\ntemplate<typename RefStringTypeRhs,\n         typename StringType = typename json_pointer<RefStringTypeRhs>::string_t>\nJSON_HEDLEY_DEPRECATED_FOR(3.11.2, operator!=(json_pointer, json_pointer))\ninline bool operator!=(const StringType& lhs,\n                       const json_pointer<RefStringTypeRhs>& rhs)\n{\n    return !(lhs == rhs);\n}\n\ntemplate<typename RefStringTypeLhs, typename RefStringTypeRhs>\ninline bool operator<(const json_pointer<RefStringTypeLhs>& lhs,\n                      const json_pointer<RefStringTypeRhs>& rhs) noexcept\n{\n    return lhs.reference_tokens < rhs.reference_tokens;\n}\n#endif\n\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include <nlohmann/detail/json_ref.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#include <initializer_list>\n#include <utility>\n\n// #include <nlohmann/detail/abi_macros.hpp>\n\n// #include <nlohmann/detail/meta/type_traits.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\n\ntemplate<typename BasicJsonType>\nclass json_ref\n{\n  public:\n    using value_type = BasicJsonType;\n\n    json_ref(value_type&& value)\n        : owned_value(std::move(value))\n    {}\n\n    json_ref(const value_type& value)\n        : value_ref(&value)\n    {}\n\n    json_ref(std::initializer_list<json_ref> init)\n        : owned_value(init)\n    {}\n\n    template <\n        class... Args,\n        enable_if_t<std::is_constructible<value_type, Args...>::value, int> = 0 >\n    json_ref(Args && ... 
args)\n        : owned_value(std::forward<Args>(args)...)\n    {}\n\n    // class should be movable only\n    json_ref(json_ref&&) noexcept = default;\n    json_ref(const json_ref&) = delete;\n    json_ref& operator=(const json_ref&) = delete;\n    json_ref& operator=(json_ref&&) = delete;\n    ~json_ref() = default;\n\n    value_type moved_or_copied() const\n    {\n        if (value_ref == nullptr)\n        {\n            return std::move(owned_value);\n        }\n        return *value_ref;\n    }\n\n    value_type const& operator*() const\n    {\n        return value_ref ? *value_ref : owned_value;\n    }\n\n    value_type const* operator->() const\n    {\n        return &** this;\n    }\n\n  private:\n    mutable value_type owned_value = nullptr;\n    value_type const* value_ref = nullptr;\n};\n\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include <nlohmann/detail/macro_scope.hpp>\n\n// #include <nlohmann/detail/string_concat.hpp>\n\n// #include <nlohmann/detail/string_escape.hpp>\n\n// #include <nlohmann/detail/meta/cpp_future.hpp>\n\n// #include <nlohmann/detail/meta/type_traits.hpp>\n\n// #include <nlohmann/detail/output/binary_writer.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#include <algorithm> // reverse\n#include <array> // array\n#include <map> // map\n#include <cmath> // isnan, isinf\n#include <cstdint> // uint8_t, uint16_t, uint32_t, uint64_t\n#include <cstring> // memcpy\n#include <limits> // numeric_limits\n#include <string> // string\n#include <utility> // move\n#include <vector> // vector\n\n// #include <nlohmann/detail/input/binary_reader.hpp>\n\n// #include <nlohmann/detail/macro_scope.hpp>\n\n// #include <nlohmann/detail/output/output_adapters.hpp>\n//     __ _____ 
_____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#include <algorithm> // copy\n#include <cstddef> // size_t\n#include <iterator> // back_inserter\n#include <memory> // shared_ptr, make_shared\n#include <string> // basic_string\n#include <vector> // vector\n\n#ifndef JSON_NO_IO\n    #include <ios>      // streamsize\n    #include <ostream>  // basic_ostream\n#endif  // JSON_NO_IO\n\n// #include <nlohmann/detail/macro_scope.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\n\n/// abstract output adapter interface\ntemplate<typename CharType> struct output_adapter_protocol\n{\n    virtual void write_character(CharType c) = 0;\n    virtual void write_characters(const CharType* s, std::size_t length) = 0;\n    virtual ~output_adapter_protocol() = default;\n\n    output_adapter_protocol() = default;\n    output_adapter_protocol(const output_adapter_protocol&) = default;\n    output_adapter_protocol(output_adapter_protocol&&) noexcept = default;\n    output_adapter_protocol& operator=(const output_adapter_protocol&) = default;\n    output_adapter_protocol& operator=(output_adapter_protocol&&) noexcept = default;\n};\n\n/// a type to simplify interfaces\ntemplate<typename CharType>\nusing output_adapter_t = std::shared_ptr<output_adapter_protocol<CharType>>;\n\n/// output adapter for byte vectors\ntemplate<typename CharType, typename AllocatorType = std::allocator<CharType>>\nclass output_vector_adapter : public output_adapter_protocol<CharType>\n{\n  public:\n    explicit output_vector_adapter(std::vector<CharType, AllocatorType>& vec) noexcept\n        : v(vec)\n    {}\n\n    void write_character(CharType c) override\n    {\n        v.push_back(c);\n    }\n\n    JSON_HEDLEY_NON_NULL(2)\n    void write_characters(const 
CharType* s, std::size_t length) override\n    {\n        v.insert(v.end(), s, s + length);\n    }\n\n  private:\n    std::vector<CharType, AllocatorType>& v;\n};\n\n#ifndef JSON_NO_IO\n/// output adapter for output streams\ntemplate<typename CharType>\nclass output_stream_adapter : public output_adapter_protocol<CharType>\n{\n  public:\n    explicit output_stream_adapter(std::basic_ostream<CharType>& s) noexcept\n        : stream(s)\n    {}\n\n    void write_character(CharType c) override\n    {\n        stream.put(c);\n    }\n\n    JSON_HEDLEY_NON_NULL(2)\n    void write_characters(const CharType* s, std::size_t length) override\n    {\n        stream.write(s, static_cast<std::streamsize>(length));\n    }\n\n  private:\n    std::basic_ostream<CharType>& stream;\n};\n#endif  // JSON_NO_IO\n\n/// output adapter for basic_string\ntemplate<typename CharType, typename StringType = std::basic_string<CharType>>\nclass output_string_adapter : public output_adapter_protocol<CharType>\n{\n  public:\n    explicit output_string_adapter(StringType& s) noexcept\n        : str(s)\n    {}\n\n    void write_character(CharType c) override\n    {\n        str.push_back(c);\n    }\n\n    JSON_HEDLEY_NON_NULL(2)\n    void write_characters(const CharType* s, std::size_t length) override\n    {\n        str.append(s, length);\n    }\n\n  private:\n    StringType& str;\n};\n\ntemplate<typename CharType, typename StringType = std::basic_string<CharType>>\nclass output_adapter\n{\n  public:\n    template<typename AllocatorType = std::allocator<CharType>>\n    output_adapter(std::vector<CharType, AllocatorType>& vec)\n        : oa(std::make_shared<output_vector_adapter<CharType, AllocatorType>>(vec)) {}\n\n#ifndef JSON_NO_IO\n    output_adapter(std::basic_ostream<CharType>& s)\n        : oa(std::make_shared<output_stream_adapter<CharType>>(s)) {}\n#endif  // JSON_NO_IO\n\n    output_adapter(StringType& s)\n        : oa(std::make_shared<output_string_adapter<CharType, StringType>>(s)) 
{}\n\n    operator output_adapter_t<CharType>()\n    {\n        return oa;\n    }\n\n  private:\n    output_adapter_t<CharType> oa = nullptr;\n};\n\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include <nlohmann/detail/string_concat.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\n\n///////////////////\n// binary writer //\n///////////////////\n\n/*!\n@brief serialization to binary formats such as BSON, CBOR, and MessagePack\n*/\ntemplate<typename BasicJsonType, typename CharType>\nclass binary_writer\n{\n    using string_t = typename BasicJsonType::string_t;\n    using binary_t = typename BasicJsonType::binary_t;\n    using number_float_t = typename BasicJsonType::number_float_t;\n\n  public:\n    /*!\n    @brief create a binary writer\n\n    @param[in] adapter  output adapter to write to\n    */\n    explicit binary_writer(output_adapter_t<CharType> adapter) : oa(std::move(adapter))\n    {\n        JSON_ASSERT(oa);\n    }\n\n    /*!\n    @param[in] j  JSON value to serialize\n    @pre       j.type() == value_t::object\n    */\n    void write_bson(const BasicJsonType& j)\n    {\n        switch (j.type())\n        {\n            case value_t::object:\n            {\n                write_bson_object(*j.m_data.m_value.object);\n                break;\n            }\n\n            case value_t::null:\n            case value_t::array:\n            case value_t::string:\n            case value_t::boolean:\n            case value_t::number_integer:\n            case value_t::number_unsigned:\n            case value_t::number_float:\n            case value_t::binary:\n            case value_t::discarded:\n            default:\n            {\n                JSON_THROW(type_error::create(317, concat(\"to serialize to BSON, top-level type must be object, but is \", j.type_name()), &j));\n            }\n        }\n    }\n\n    /*!\n    @param[in] j  JSON value to serialize\n    */\n    void write_cbor(const BasicJsonType& j)\n    {\n        switch (j.type())\n        
{\n            case value_t::null:\n            {\n                oa->write_character(to_char_type(0xF6));\n                break;\n            }\n\n            case value_t::boolean:\n            {\n                oa->write_character(j.m_data.m_value.boolean\n                                    ? to_char_type(0xF5)\n                                    : to_char_type(0xF4));\n                break;\n            }\n\n            case value_t::number_integer:\n            {\n                if (j.m_data.m_value.number_integer >= 0)\n                {\n                    // CBOR does not differentiate between positive signed\n                    // integers and unsigned integers. Therefore, we used the\n                    // code from the value_t::number_unsigned case here.\n                    if (j.m_data.m_value.number_integer <= 0x17)\n                    {\n                        write_number(static_cast<std::uint8_t>(j.m_data.m_value.number_integer));\n                    }\n                    else if (j.m_data.m_value.number_integer <= (std::numeric_limits<std::uint8_t>::max)())\n                    {\n                        oa->write_character(to_char_type(0x18));\n                        write_number(static_cast<std::uint8_t>(j.m_data.m_value.number_integer));\n                    }\n                    else if (j.m_data.m_value.number_integer <= (std::numeric_limits<std::uint16_t>::max)())\n                    {\n                        oa->write_character(to_char_type(0x19));\n                        write_number(static_cast<std::uint16_t>(j.m_data.m_value.number_integer));\n                    }\n                    else if (j.m_data.m_value.number_integer <= (std::numeric_limits<std::uint32_t>::max)())\n                    {\n                        oa->write_character(to_char_type(0x1A));\n                        write_number(static_cast<std::uint32_t>(j.m_data.m_value.number_integer));\n                    }\n                    else\n            
        {\n                        oa->write_character(to_char_type(0x1B));\n                        write_number(static_cast<std::uint64_t>(j.m_data.m_value.number_integer));\n                    }\n                }\n                else\n                {\n                    // The conversions below encode the sign in the first\n                    // byte, and the value is converted to a positive number.\n                    const auto positive_number = -1 - j.m_data.m_value.number_integer;\n                    if (j.m_data.m_value.number_integer >= -24)\n                    {\n                        write_number(static_cast<std::uint8_t>(0x20 + positive_number));\n                    }\n                    else if (positive_number <= (std::numeric_limits<std::uint8_t>::max)())\n                    {\n                        oa->write_character(to_char_type(0x38));\n                        write_number(static_cast<std::uint8_t>(positive_number));\n                    }\n                    else if (positive_number <= (std::numeric_limits<std::uint16_t>::max)())\n                    {\n                        oa->write_character(to_char_type(0x39));\n                        write_number(static_cast<std::uint16_t>(positive_number));\n                    }\n                    else if (positive_number <= (std::numeric_limits<std::uint32_t>::max)())\n                    {\n                        oa->write_character(to_char_type(0x3A));\n                        write_number(static_cast<std::uint32_t>(positive_number));\n                    }\n                    else\n                    {\n                        oa->write_character(to_char_type(0x3B));\n                        write_number(static_cast<std::uint64_t>(positive_number));\n                    }\n                }\n                break;\n            }\n\n            case value_t::number_unsigned:\n            {\n                if (j.m_data.m_value.number_unsigned <= 0x17)\n                {\n      
              write_number(static_cast<std::uint8_t>(j.m_data.m_value.number_unsigned));\n                }\n                else if (j.m_data.m_value.number_unsigned <= (std::numeric_limits<std::uint8_t>::max)())\n                {\n                    oa->write_character(to_char_type(0x18));\n                    write_number(static_cast<std::uint8_t>(j.m_data.m_value.number_unsigned));\n                }\n                else if (j.m_data.m_value.number_unsigned <= (std::numeric_limits<std::uint16_t>::max)())\n                {\n                    oa->write_character(to_char_type(0x19));\n                    write_number(static_cast<std::uint16_t>(j.m_data.m_value.number_unsigned));\n                }\n                else if (j.m_data.m_value.number_unsigned <= (std::numeric_limits<std::uint32_t>::max)())\n                {\n                    oa->write_character(to_char_type(0x1A));\n                    write_number(static_cast<std::uint32_t>(j.m_data.m_value.number_unsigned));\n                }\n                else\n                {\n                    oa->write_character(to_char_type(0x1B));\n                    write_number(static_cast<std::uint64_t>(j.m_data.m_value.number_unsigned));\n                }\n                break;\n            }\n\n            case value_t::number_float:\n            {\n                if (std::isnan(j.m_data.m_value.number_float))\n                {\n                    // NaN is 0xf97e00 in CBOR\n                    oa->write_character(to_char_type(0xF9));\n                    oa->write_character(to_char_type(0x7E));\n                    oa->write_character(to_char_type(0x00));\n                }\n                else if (std::isinf(j.m_data.m_value.number_float))\n                {\n                    // Infinity is 0xf97c00, -Infinity is 0xf9fc00\n                    oa->write_character(to_char_type(0xf9));\n                    oa->write_character(j.m_data.m_value.number_float > 0 ? 
to_char_type(0x7C) : to_char_type(0xFC));\n                    oa->write_character(to_char_type(0x00));\n                }\n                else\n                {\n                    write_compact_float(j.m_data.m_value.number_float, detail::input_format_t::cbor);\n                }\n                break;\n            }\n\n            case value_t::string:\n            {\n                // step 1: write control byte and the string length\n                const auto N = j.m_data.m_value.string->size();\n                if (N <= 0x17)\n                {\n                    write_number(static_cast<std::uint8_t>(0x60 + N));\n                }\n                else if (N <= (std::numeric_limits<std::uint8_t>::max)())\n                {\n                    oa->write_character(to_char_type(0x78));\n                    write_number(static_cast<std::uint8_t>(N));\n                }\n                else if (N <= (std::numeric_limits<std::uint16_t>::max)())\n                {\n                    oa->write_character(to_char_type(0x79));\n                    write_number(static_cast<std::uint16_t>(N));\n                }\n                else if (N <= (std::numeric_limits<std::uint32_t>::max)())\n                {\n                    oa->write_character(to_char_type(0x7A));\n                    write_number(static_cast<std::uint32_t>(N));\n                }\n                // LCOV_EXCL_START\n                else if (N <= (std::numeric_limits<std::uint64_t>::max)())\n                {\n                    oa->write_character(to_char_type(0x7B));\n                    write_number(static_cast<std::uint64_t>(N));\n                }\n                // LCOV_EXCL_STOP\n\n                // step 2: write the string\n                oa->write_characters(\n                    reinterpret_cast<const CharType*>(j.m_data.m_value.string->c_str()),\n                    j.m_data.m_value.string->size());\n                break;\n            }\n\n            case value_t::array:\n   
         {\n                // step 1: write control byte and the array size\n                const auto N = j.m_data.m_value.array->size();\n                if (N <= 0x17)\n                {\n                    write_number(static_cast<std::uint8_t>(0x80 + N));\n                }\n                else if (N <= (std::numeric_limits<std::uint8_t>::max)())\n                {\n                    oa->write_character(to_char_type(0x98));\n                    write_number(static_cast<std::uint8_t>(N));\n                }\n                else if (N <= (std::numeric_limits<std::uint16_t>::max)())\n                {\n                    oa->write_character(to_char_type(0x99));\n                    write_number(static_cast<std::uint16_t>(N));\n                }\n                else if (N <= (std::numeric_limits<std::uint32_t>::max)())\n                {\n                    oa->write_character(to_char_type(0x9A));\n                    write_number(static_cast<std::uint32_t>(N));\n                }\n                // LCOV_EXCL_START\n                else if (N <= (std::numeric_limits<std::uint64_t>::max)())\n                {\n                    oa->write_character(to_char_type(0x9B));\n                    write_number(static_cast<std::uint64_t>(N));\n                }\n                // LCOV_EXCL_STOP\n\n                // step 2: write each element\n                for (const auto& el : *j.m_data.m_value.array)\n                {\n                    write_cbor(el);\n                }\n                break;\n            }\n\n            case value_t::binary:\n            {\n                if (j.m_data.m_value.binary->has_subtype())\n                {\n                    if (j.m_data.m_value.binary->subtype() <= (std::numeric_limits<std::uint8_t>::max)())\n                    {\n                        write_number(static_cast<std::uint8_t>(0xd8));\n                        write_number(static_cast<std::uint8_t>(j.m_data.m_value.binary->subtype()));\n                
    }\n                    else if (j.m_data.m_value.binary->subtype() <= (std::numeric_limits<std::uint16_t>::max)())\n                    {\n                        write_number(static_cast<std::uint8_t>(0xd9));\n                        write_number(static_cast<std::uint16_t>(j.m_data.m_value.binary->subtype()));\n                    }\n                    else if (j.m_data.m_value.binary->subtype() <= (std::numeric_limits<std::uint32_t>::max)())\n                    {\n                        write_number(static_cast<std::uint8_t>(0xda));\n                        write_number(static_cast<std::uint32_t>(j.m_data.m_value.binary->subtype()));\n                    }\n                    else if (j.m_data.m_value.binary->subtype() <= (std::numeric_limits<std::uint64_t>::max)())\n                    {\n                        write_number(static_cast<std::uint8_t>(0xdb));\n                        write_number(static_cast<std::uint64_t>(j.m_data.m_value.binary->subtype()));\n                    }\n                }\n\n                // step 1: write control byte and the binary array size\n                const auto N = j.m_data.m_value.binary->size();\n                if (N <= 0x17)\n                {\n                    write_number(static_cast<std::uint8_t>(0x40 + N));\n                }\n                else if (N <= (std::numeric_limits<std::uint8_t>::max)())\n                {\n                    oa->write_character(to_char_type(0x58));\n                    write_number(static_cast<std::uint8_t>(N));\n                }\n                else if (N <= (std::numeric_limits<std::uint16_t>::max)())\n                {\n                    oa->write_character(to_char_type(0x59));\n                    write_number(static_cast<std::uint16_t>(N));\n                }\n                else if (N <= (std::numeric_limits<std::uint32_t>::max)())\n                {\n                    oa->write_character(to_char_type(0x5A));\n                    
write_number(static_cast<std::uint32_t>(N));\n                }\n                // LCOV_EXCL_START\n                else if (N <= (std::numeric_limits<std::uint64_t>::max)())\n                {\n                    oa->write_character(to_char_type(0x5B));\n                    write_number(static_cast<std::uint64_t>(N));\n                }\n                // LCOV_EXCL_STOP\n\n                // step 2: write each element\n                oa->write_characters(\n                    reinterpret_cast<const CharType*>(j.m_data.m_value.binary->data()),\n                    N);\n\n                break;\n            }\n\n            case value_t::object:\n            {\n                // step 1: write control byte and the object size\n                const auto N = j.m_data.m_value.object->size();\n                if (N <= 0x17)\n                {\n                    write_number(static_cast<std::uint8_t>(0xA0 + N));\n                }\n                else if (N <= (std::numeric_limits<std::uint8_t>::max)())\n                {\n                    oa->write_character(to_char_type(0xB8));\n                    write_number(static_cast<std::uint8_t>(N));\n                }\n                else if (N <= (std::numeric_limits<std::uint16_t>::max)())\n                {\n                    oa->write_character(to_char_type(0xB9));\n                    write_number(static_cast<std::uint16_t>(N));\n                }\n                else if (N <= (std::numeric_limits<std::uint32_t>::max)())\n                {\n                    oa->write_character(to_char_type(0xBA));\n                    write_number(static_cast<std::uint32_t>(N));\n                }\n                // LCOV_EXCL_START\n                else if (N <= (std::numeric_limits<std::uint64_t>::max)())\n                {\n                    oa->write_character(to_char_type(0xBB));\n                    write_number(static_cast<std::uint64_t>(N));\n                }\n                // LCOV_EXCL_STOP\n\n              
  // step 2: write each element\n                for (const auto& el : *j.m_data.m_value.object)\n                {\n                    write_cbor(el.first);\n                    write_cbor(el.second);\n                }\n                break;\n            }\n\n            case value_t::discarded:\n            default:\n                break;\n        }\n    }\n\n    /*!\n    @param[in] j  JSON value to serialize\n    */\n    void write_msgpack(const BasicJsonType& j)\n    {\n        switch (j.type())\n        {\n            case value_t::null: // nil\n            {\n                oa->write_character(to_char_type(0xC0));\n                break;\n            }\n\n            case value_t::boolean: // true and false\n            {\n                oa->write_character(j.m_data.m_value.boolean\n                                    ? to_char_type(0xC3)\n                                    : to_char_type(0xC2));\n                break;\n            }\n\n            case value_t::number_integer:\n            {\n                if (j.m_data.m_value.number_integer >= 0)\n                {\n                    // MessagePack does not differentiate between positive\n                    // signed integers and unsigned integers. 
Therefore, we used\n                    // the code from the value_t::number_unsigned case here.\n                    if (j.m_data.m_value.number_unsigned < 128)\n                    {\n                        // positive fixnum\n                        write_number(static_cast<std::uint8_t>(j.m_data.m_value.number_integer));\n                    }\n                    else if (j.m_data.m_value.number_unsigned <= (std::numeric_limits<std::uint8_t>::max)())\n                    {\n                        // uint 8\n                        oa->write_character(to_char_type(0xCC));\n                        write_number(static_cast<std::uint8_t>(j.m_data.m_value.number_integer));\n                    }\n                    else if (j.m_data.m_value.number_unsigned <= (std::numeric_limits<std::uint16_t>::max)())\n                    {\n                        // uint 16\n                        oa->write_character(to_char_type(0xCD));\n                        write_number(static_cast<std::uint16_t>(j.m_data.m_value.number_integer));\n                    }\n                    else if (j.m_data.m_value.number_unsigned <= (std::numeric_limits<std::uint32_t>::max)())\n                    {\n                        // uint 32\n                        oa->write_character(to_char_type(0xCE));\n                        write_number(static_cast<std::uint32_t>(j.m_data.m_value.number_integer));\n                    }\n                    else if (j.m_data.m_value.number_unsigned <= (std::numeric_limits<std::uint64_t>::max)())\n                    {\n                        // uint 64\n                        oa->write_character(to_char_type(0xCF));\n                        write_number(static_cast<std::uint64_t>(j.m_data.m_value.number_integer));\n                    }\n                }\n                else\n                {\n                    if (j.m_data.m_value.number_integer >= -32)\n                    {\n                        // negative fixnum\n                       
 write_number(static_cast<std::int8_t>(j.m_data.m_value.number_integer));\n                    }\n                    else if (j.m_data.m_value.number_integer >= (std::numeric_limits<std::int8_t>::min)() &&\n                             j.m_data.m_value.number_integer <= (std::numeric_limits<std::int8_t>::max)())\n                    {\n                        // int 8\n                        oa->write_character(to_char_type(0xD0));\n                        write_number(static_cast<std::int8_t>(j.m_data.m_value.number_integer));\n                    }\n                    else if (j.m_data.m_value.number_integer >= (std::numeric_limits<std::int16_t>::min)() &&\n                             j.m_data.m_value.number_integer <= (std::numeric_limits<std::int16_t>::max)())\n                    {\n                        // int 16\n                        oa->write_character(to_char_type(0xD1));\n                        write_number(static_cast<std::int16_t>(j.m_data.m_value.number_integer));\n                    }\n                    else if (j.m_data.m_value.number_integer >= (std::numeric_limits<std::int32_t>::min)() &&\n                             j.m_data.m_value.number_integer <= (std::numeric_limits<std::int32_t>::max)())\n                    {\n                        // int 32\n                        oa->write_character(to_char_type(0xD2));\n                        write_number(static_cast<std::int32_t>(j.m_data.m_value.number_integer));\n                    }\n                    else if (j.m_data.m_value.number_integer >= (std::numeric_limits<std::int64_t>::min)() &&\n                             j.m_data.m_value.number_integer <= (std::numeric_limits<std::int64_t>::max)())\n                    {\n                        // int 64\n                        oa->write_character(to_char_type(0xD3));\n                        write_number(static_cast<std::int64_t>(j.m_data.m_value.number_integer));\n                    }\n                }\n                
break;\n            }\n\n            case value_t::number_unsigned:\n            {\n                if (j.m_data.m_value.number_unsigned < 128)\n                {\n                    // positive fixnum\n                    write_number(static_cast<std::uint8_t>(j.m_data.m_value.number_integer));\n                }\n                else if (j.m_data.m_value.number_unsigned <= (std::numeric_limits<std::uint8_t>::max)())\n                {\n                    // uint 8\n                    oa->write_character(to_char_type(0xCC));\n                    write_number(static_cast<std::uint8_t>(j.m_data.m_value.number_integer));\n                }\n                else if (j.m_data.m_value.number_unsigned <= (std::numeric_limits<std::uint16_t>::max)())\n                {\n                    // uint 16\n                    oa->write_character(to_char_type(0xCD));\n                    write_number(static_cast<std::uint16_t>(j.m_data.m_value.number_integer));\n                }\n                else if (j.m_data.m_value.number_unsigned <= (std::numeric_limits<std::uint32_t>::max)())\n                {\n                    // uint 32\n                    oa->write_character(to_char_type(0xCE));\n                    write_number(static_cast<std::uint32_t>(j.m_data.m_value.number_integer));\n                }\n                else if (j.m_data.m_value.number_unsigned <= (std::numeric_limits<std::uint64_t>::max)())\n                {\n                    // uint 64\n                    oa->write_character(to_char_type(0xCF));\n                    write_number(static_cast<std::uint64_t>(j.m_data.m_value.number_integer));\n                }\n                break;\n            }\n\n            case value_t::number_float:\n            {\n                write_compact_float(j.m_data.m_value.number_float, detail::input_format_t::msgpack);\n                break;\n            }\n\n            case value_t::string:\n            {\n                // step 1: write control byte and the 
string length\n                const auto N = j.m_data.m_value.string->size();\n                if (N <= 31)\n                {\n                    // fixstr\n                    write_number(static_cast<std::uint8_t>(0xA0 | N));\n                }\n                else if (N <= (std::numeric_limits<std::uint8_t>::max)())\n                {\n                    // str 8\n                    oa->write_character(to_char_type(0xD9));\n                    write_number(static_cast<std::uint8_t>(N));\n                }\n                else if (N <= (std::numeric_limits<std::uint16_t>::max)())\n                {\n                    // str 16\n                    oa->write_character(to_char_type(0xDA));\n                    write_number(static_cast<std::uint16_t>(N));\n                }\n                else if (N <= (std::numeric_limits<std::uint32_t>::max)())\n                {\n                    // str 32\n                    oa->write_character(to_char_type(0xDB));\n                    write_number(static_cast<std::uint32_t>(N));\n                }\n\n                // step 2: write the string\n                oa->write_characters(\n                    reinterpret_cast<const CharType*>(j.m_data.m_value.string->c_str()),\n                    j.m_data.m_value.string->size());\n                break;\n            }\n\n            case value_t::array:\n            {\n                // step 1: write control byte and the array size\n                const auto N = j.m_data.m_value.array->size();\n                if (N <= 15)\n                {\n                    // fixarray\n                    write_number(static_cast<std::uint8_t>(0x90 | N));\n                }\n                else if (N <= (std::numeric_limits<std::uint16_t>::max)())\n                {\n                    // array 16\n                    oa->write_character(to_char_type(0xDC));\n                    write_number(static_cast<std::uint16_t>(N));\n                }\n                else if (N <= 
(std::numeric_limits<std::uint32_t>::max)())\n                {\n                    // array 32\n                    oa->write_character(to_char_type(0xDD));\n                    write_number(static_cast<std::uint32_t>(N));\n                }\n\n                // step 2: write each element\n                for (const auto& el : *j.m_data.m_value.array)\n                {\n                    write_msgpack(el);\n                }\n                break;\n            }\n\n            case value_t::binary:\n            {\n                // step 0: determine if the binary type has a set subtype to\n                // determine whether or not to use the ext or fixext types\n                const bool use_ext = j.m_data.m_value.binary->has_subtype();\n\n                // step 1: write control byte and the byte string length\n                const auto N = j.m_data.m_value.binary->size();\n                if (N <= (std::numeric_limits<std::uint8_t>::max)())\n                {\n                    std::uint8_t output_type{};\n                    bool fixed = true;\n                    if (use_ext)\n                    {\n                        switch (N)\n                        {\n                            case 1:\n                                output_type = 0xD4; // fixext 1\n                                break;\n                            case 2:\n                                output_type = 0xD5; // fixext 2\n                                break;\n                            case 4:\n                                output_type = 0xD6; // fixext 4\n                                break;\n                            case 8:\n                                output_type = 0xD7; // fixext 8\n                                break;\n                            case 16:\n                                output_type = 0xD8; // fixext 16\n                                break;\n                            default:\n                                output_type = 0xC7; 
// ext 8\n                                fixed = false;\n                                break;\n                        }\n\n                    }\n                    else\n                    {\n                        output_type = 0xC4; // bin 8\n                        fixed = false;\n                    }\n\n                    oa->write_character(to_char_type(output_type));\n                    if (!fixed)\n                    {\n                        write_number(static_cast<std::uint8_t>(N));\n                    }\n                }\n                else if (N <= (std::numeric_limits<std::uint16_t>::max)())\n                {\n                    const std::uint8_t output_type = use_ext\n                                                     ? 0xC8 // ext 16\n                                                     : 0xC5; // bin 16\n\n                    oa->write_character(to_char_type(output_type));\n                    write_number(static_cast<std::uint16_t>(N));\n                }\n                else if (N <= (std::numeric_limits<std::uint32_t>::max)())\n                {\n                    const std::uint8_t output_type = use_ext\n                                                     ? 
0xC9 // ext 32\n                                                     : 0xC6; // bin 32\n\n                    oa->write_character(to_char_type(output_type));\n                    write_number(static_cast<std::uint32_t>(N));\n                }\n\n                // step 1.5: if this is an ext type, write the subtype\n                if (use_ext)\n                {\n                    write_number(static_cast<std::int8_t>(j.m_data.m_value.binary->subtype()));\n                }\n\n                // step 2: write the byte string\n                oa->write_characters(\n                    reinterpret_cast<const CharType*>(j.m_data.m_value.binary->data()),\n                    N);\n\n                break;\n            }\n\n            case value_t::object:\n            {\n                // step 1: write control byte and the object size\n                const auto N = j.m_data.m_value.object->size();\n                if (N <= 15)\n                {\n                    // fixmap\n                    write_number(static_cast<std::uint8_t>(0x80 | (N & 0xF)));\n                }\n                else if (N <= (std::numeric_limits<std::uint16_t>::max)())\n                {\n                    // map 16\n                    oa->write_character(to_char_type(0xDE));\n                    write_number(static_cast<std::uint16_t>(N));\n                }\n                else if (N <= (std::numeric_limits<std::uint32_t>::max)())\n                {\n                    // map 32\n                    oa->write_character(to_char_type(0xDF));\n                    write_number(static_cast<std::uint32_t>(N));\n                }\n\n                // step 2: write each element\n                for (const auto& el : *j.m_data.m_value.object)\n                {\n                    write_msgpack(el.first);\n                    write_msgpack(el.second);\n                }\n                break;\n            }\n\n            case value_t::discarded:\n            default:\n                
break;\n        }\n    }\n\n    /*!\n    @param[in] j  JSON value to serialize\n    @param[in] use_count   whether to use '#' prefixes (optimized format)\n    @param[in] use_type    whether to use '$' prefixes (optimized format)\n    @param[in] add_prefix  whether prefixes need to be used for this value\n    @param[in] use_bjdata  whether to write in BJData format, default is false\n    */\n    void write_ubjson(const BasicJsonType& j, const bool use_count,\n                      const bool use_type, const bool add_prefix = true,\n                      const bool use_bjdata = false)\n    {\n        switch (j.type())\n        {\n            case value_t::null:\n            {\n                if (add_prefix)\n                {\n                    oa->write_character(to_char_type('Z'));\n                }\n                break;\n            }\n\n            case value_t::boolean:\n            {\n                if (add_prefix)\n                {\n                    oa->write_character(j.m_data.m_value.boolean\n                                        ? 
to_char_type('T')\n                                        : to_char_type('F'));\n                }\n                break;\n            }\n\n            case value_t::number_integer:\n            {\n                write_number_with_ubjson_prefix(j.m_data.m_value.number_integer, add_prefix, use_bjdata);\n                break;\n            }\n\n            case value_t::number_unsigned:\n            {\n                write_number_with_ubjson_prefix(j.m_data.m_value.number_unsigned, add_prefix, use_bjdata);\n                break;\n            }\n\n            case value_t::number_float:\n            {\n                write_number_with_ubjson_prefix(j.m_data.m_value.number_float, add_prefix, use_bjdata);\n                break;\n            }\n\n            case value_t::string:\n            {\n                if (add_prefix)\n                {\n                    oa->write_character(to_char_type('S'));\n                }\n                write_number_with_ubjson_prefix(j.m_data.m_value.string->size(), true, use_bjdata);\n                oa->write_characters(\n                    reinterpret_cast<const CharType*>(j.m_data.m_value.string->c_str()),\n                    j.m_data.m_value.string->size());\n                break;\n            }\n\n            case value_t::array:\n            {\n                if (add_prefix)\n                {\n                    oa->write_character(to_char_type('['));\n                }\n\n                bool prefix_required = true;\n                if (use_type && !j.m_data.m_value.array->empty())\n                {\n                    JSON_ASSERT(use_count);\n                    const CharType first_prefix = ubjson_prefix(j.front(), use_bjdata);\n                    const bool same_prefix = std::all_of(j.begin() + 1, j.end(),\n                                                         [this, first_prefix, use_bjdata](const BasicJsonType & v)\n                    {\n                        return ubjson_prefix(v, use_bjdata) == 
first_prefix;\n                    });\n\n                    std::vector<CharType> bjdx = {'[', '{', 'S', 'H', 'T', 'F', 'N', 'Z'}; // excluded markers in bjdata optimized type\n\n                    if (same_prefix && !(use_bjdata && std::find(bjdx.begin(), bjdx.end(), first_prefix) != bjdx.end()))\n                    {\n                        prefix_required = false;\n                        oa->write_character(to_char_type('$'));\n                        oa->write_character(first_prefix);\n                    }\n                }\n\n                if (use_count)\n                {\n                    oa->write_character(to_char_type('#'));\n                    write_number_with_ubjson_prefix(j.m_data.m_value.array->size(), true, use_bjdata);\n                }\n\n                for (const auto& el : *j.m_data.m_value.array)\n                {\n                    write_ubjson(el, use_count, use_type, prefix_required, use_bjdata);\n                }\n\n                if (!use_count)\n                {\n                    oa->write_character(to_char_type(']'));\n                }\n\n                break;\n            }\n\n            case value_t::binary:\n            {\n                if (add_prefix)\n                {\n                    oa->write_character(to_char_type('['));\n                }\n\n                if (use_type && !j.m_data.m_value.binary->empty())\n                {\n                    JSON_ASSERT(use_count);\n                    oa->write_character(to_char_type('$'));\n                    oa->write_character('U');\n                }\n\n                if (use_count)\n                {\n                    oa->write_character(to_char_type('#'));\n                    write_number_with_ubjson_prefix(j.m_data.m_value.binary->size(), true, use_bjdata);\n                }\n\n                if (use_type)\n                {\n                    oa->write_characters(\n                        reinterpret_cast<const 
CharType*>(j.m_data.m_value.binary->data()),\n                        j.m_data.m_value.binary->size());\n                }\n                else\n                {\n                    for (size_t i = 0; i < j.m_data.m_value.binary->size(); ++i)\n                    {\n                        oa->write_character(to_char_type('U'));\n                        oa->write_character(j.m_data.m_value.binary->data()[i]);\n                    }\n                }\n\n                if (!use_count)\n                {\n                    oa->write_character(to_char_type(']'));\n                }\n\n                break;\n            }\n\n            case value_t::object:\n            {\n                if (use_bjdata && j.m_data.m_value.object->size() == 3 && j.m_data.m_value.object->find(\"_ArrayType_\") != j.m_data.m_value.object->end() && j.m_data.m_value.object->find(\"_ArraySize_\") != j.m_data.m_value.object->end() && j.m_data.m_value.object->find(\"_ArrayData_\") != j.m_data.m_value.object->end())\n                {\n                    if (!write_bjdata_ndarray(*j.m_data.m_value.object, use_count, use_type))  // decode bjdata ndarray in the JData format (https://github.com/NeuroJSON/jdata)\n                    {\n                        break;\n                    }\n                }\n\n                if (add_prefix)\n                {\n                    oa->write_character(to_char_type('{'));\n                }\n\n                bool prefix_required = true;\n                if (use_type && !j.m_data.m_value.object->empty())\n                {\n                    JSON_ASSERT(use_count);\n                    const CharType first_prefix = ubjson_prefix(j.front(), use_bjdata);\n                    const bool same_prefix = std::all_of(j.begin(), j.end(),\n                                                         [this, first_prefix, use_bjdata](const BasicJsonType & v)\n                    {\n                        return ubjson_prefix(v, use_bjdata) == 
first_prefix;\n                    });\n\n                    std::vector<CharType> bjdx = {'[', '{', 'S', 'H', 'T', 'F', 'N', 'Z'}; // excluded markers in bjdata optimized type\n\n                    if (same_prefix && !(use_bjdata && std::find(bjdx.begin(), bjdx.end(), first_prefix) != bjdx.end()))\n                    {\n                        prefix_required = false;\n                        oa->write_character(to_char_type('$'));\n                        oa->write_character(first_prefix);\n                    }\n                }\n\n                if (use_count)\n                {\n                    oa->write_character(to_char_type('#'));\n                    write_number_with_ubjson_prefix(j.m_data.m_value.object->size(), true, use_bjdata);\n                }\n\n                for (const auto& el : *j.m_data.m_value.object)\n                {\n                    write_number_with_ubjson_prefix(el.first.size(), true, use_bjdata);\n                    oa->write_characters(\n                        reinterpret_cast<const CharType*>(el.first.c_str()),\n                        el.first.size());\n                    write_ubjson(el.second, use_count, use_type, prefix_required, use_bjdata);\n                }\n\n                if (!use_count)\n                {\n                    oa->write_character(to_char_type('}'));\n                }\n\n                break;\n            }\n\n            case value_t::discarded:\n            default:\n                break;\n        }\n    }\n\n  private:\n    //////////\n    // BSON //\n    //////////\n\n    /*!\n    @return The size of a BSON document entry header, including the id marker\n            and the entry name size (and its null-terminator).\n    */\n    static std::size_t calc_bson_entry_header_size(const string_t& name, const BasicJsonType& j)\n    {\n        const auto it = name.find(static_cast<typename string_t::value_type>(0));\n        if (JSON_HEDLEY_UNLIKELY(it != BasicJsonType::string_t::npos))\n  
      {\n            JSON_THROW(out_of_range::create(409, concat(\"BSON key cannot contain code point U+0000 (at byte \", std::to_string(it), \")\"), &j));\n            static_cast<void>(j);\n        }\n\n        return /*id*/ 1ul + name.size() + /*zero-terminator*/1u;\n    }\n\n    /*!\n    @brief Writes the given @a element_type and @a name to the output adapter\n    */\n    void write_bson_entry_header(const string_t& name,\n                                 const std::uint8_t element_type)\n    {\n        oa->write_character(to_char_type(element_type)); // element type marker\n        oa->write_characters(\n            reinterpret_cast<const CharType*>(name.c_str()),\n            name.size() + 1u);\n    }\n\n    /*!\n    @brief Writes a BSON element with key @a name and boolean value @a value\n    */\n    void write_bson_boolean(const string_t& name,\n                            const bool value)\n    {\n        write_bson_entry_header(name, 0x08);\n        oa->write_character(value ? to_char_type(0x01) : to_char_type(0x00));\n    }\n\n    /*!\n    @brief Writes a BSON element with key @a name and double value @a value\n    */\n    void write_bson_double(const string_t& name,\n                           const double value)\n    {\n        write_bson_entry_header(name, 0x01);\n        write_number<double>(value, true);\n    }\n\n    /*!\n    @return The size of the BSON-encoded string in @a value\n    */\n    static std::size_t calc_bson_string_size(const string_t& value)\n    {\n        return sizeof(std::int32_t) + value.size() + 1ul;\n    }\n\n    /*!\n    @brief Writes a BSON element with key @a name and string value @a value\n    */\n    void write_bson_string(const string_t& name,\n                           const string_t& value)\n    {\n        write_bson_entry_header(name, 0x02);\n\n        write_number<std::int32_t>(static_cast<std::int32_t>(value.size() + 1ul), true);\n        oa->write_characters(\n            reinterpret_cast<const CharType*>(value.c_str()),\n   
         value.size() + 1);\n    }\n\n    /*!\n    @brief Writes a BSON element with key @a name and null value\n    */\n    void write_bson_null(const string_t& name)\n    {\n        write_bson_entry_header(name, 0x0A);\n    }\n\n    /*!\n    @return The size of the BSON-encoded integer @a value\n    */\n    static std::size_t calc_bson_integer_size(const std::int64_t value)\n    {\n        return (std::numeric_limits<std::int32_t>::min)() <= value && value <= (std::numeric_limits<std::int32_t>::max)()\n               ? sizeof(std::int32_t)\n               : sizeof(std::int64_t);\n    }\n\n    /*!\n    @brief Writes a BSON element with key @a name and integer @a value\n    */\n    void write_bson_integer(const string_t& name,\n                            const std::int64_t value)\n    {\n        if ((std::numeric_limits<std::int32_t>::min)() <= value && value <= (std::numeric_limits<std::int32_t>::max)())\n        {\n            write_bson_entry_header(name, 0x10); // int32\n            write_number<std::int32_t>(static_cast<std::int32_t>(value), true);\n        }\n        else\n        {\n            write_bson_entry_header(name, 0x12); // int64\n            write_number<std::int64_t>(static_cast<std::int64_t>(value), true);\n        }\n    }\n\n    /*!\n    @return The size of the BSON-encoded unsigned integer in @a value\n    */\n    static constexpr std::size_t calc_bson_unsigned_size(const std::uint64_t value) noexcept\n    {\n        return (value <= static_cast<std::uint64_t>((std::numeric_limits<std::int32_t>::max)()))\n               ? 
sizeof(std::int32_t)\n               : sizeof(std::int64_t);\n    }\n\n    /*!\n    @brief Writes a BSON element with key @a name and unsigned @a value\n    */\n    void write_bson_unsigned(const string_t& name,\n                             const BasicJsonType& j)\n    {\n        if (j.m_data.m_value.number_unsigned <= static_cast<std::uint64_t>((std::numeric_limits<std::int32_t>::max)()))\n        {\n            write_bson_entry_header(name, 0x10 /* int32 */);\n            write_number<std::int32_t>(static_cast<std::int32_t>(j.m_data.m_value.number_unsigned), true);\n        }\n        else if (j.m_data.m_value.number_unsigned <= static_cast<std::uint64_t>((std::numeric_limits<std::int64_t>::max)()))\n        {\n            write_bson_entry_header(name, 0x12 /* int64 */);\n            write_number<std::int64_t>(static_cast<std::int64_t>(j.m_data.m_value.number_unsigned), true);\n        }\n        else\n        {\n            JSON_THROW(out_of_range::create(407, concat(\"integer number \", std::to_string(j.m_data.m_value.number_unsigned), \" cannot be represented by BSON as it does not fit int64\"), &j));\n        }\n    }\n\n    /*!\n    @brief Writes a BSON element with key @a name and object @a value\n    */\n    void write_bson_object_entry(const string_t& name,\n                                 const typename BasicJsonType::object_t& value)\n    {\n        write_bson_entry_header(name, 0x03); // object\n        write_bson_object(value);\n    }\n\n    /*!\n    @return The size of the BSON-encoded array @a value\n    */\n    static std::size_t calc_bson_array_size(const typename BasicJsonType::array_t& value)\n    {\n        std::size_t array_index = 0ul;\n\n        const std::size_t embedded_document_size = std::accumulate(std::begin(value), std::end(value), static_cast<std::size_t>(0), [&array_index](std::size_t result, const typename BasicJsonType::array_t::value_type & el)\n        {\n            return result + 
calc_bson_element_size(std::to_string(array_index++), el);\n        });\n\n        return sizeof(std::int32_t) + embedded_document_size + 1ul;\n    }\n\n    /*!\n    @return The size of the BSON-encoded binary array @a value\n    */\n    static std::size_t calc_bson_binary_size(const typename BasicJsonType::binary_t& value)\n    {\n        return sizeof(std::int32_t) + value.size() + 1ul;\n    }\n\n    /*!\n    @brief Writes a BSON element with key @a name and array @a value\n    */\n    void write_bson_array(const string_t& name,\n                          const typename BasicJsonType::array_t& value)\n    {\n        write_bson_entry_header(name, 0x04); // array\n        write_number<std::int32_t>(static_cast<std::int32_t>(calc_bson_array_size(value)), true);\n\n        std::size_t array_index = 0ul;\n\n        for (const auto& el : value)\n        {\n            write_bson_element(std::to_string(array_index++), el);\n        }\n\n        oa->write_character(to_char_type(0x00));\n    }\n\n    /*!\n    @brief Writes a BSON element with key @a name and binary value @a value\n    */\n    void write_bson_binary(const string_t& name,\n                           const binary_t& value)\n    {\n        write_bson_entry_header(name, 0x05);\n\n        write_number<std::int32_t>(static_cast<std::int32_t>(value.size()), true);\n        write_number(value.has_subtype() ? 
static_cast<std::uint8_t>(value.subtype()) : static_cast<std::uint8_t>(0x00));\n\n        oa->write_characters(reinterpret_cast<const CharType*>(value.data()), value.size());\n    }\n\n    /*!\n    @brief Calculates the size necessary to serialize the JSON value @a j with its @a name\n    @return The calculated size for the BSON document entry for @a j with the given @a name.\n    */\n    static std::size_t calc_bson_element_size(const string_t& name,\n            const BasicJsonType& j)\n    {\n        const auto header_size = calc_bson_entry_header_size(name, j);\n        switch (j.type())\n        {\n            case value_t::object:\n                return header_size + calc_bson_object_size(*j.m_data.m_value.object);\n\n            case value_t::array:\n                return header_size + calc_bson_array_size(*j.m_data.m_value.array);\n\n            case value_t::binary:\n                return header_size + calc_bson_binary_size(*j.m_data.m_value.binary);\n\n            case value_t::boolean:\n                return header_size + 1ul;\n\n            case value_t::number_float:\n                return header_size + 8ul;\n\n            case value_t::number_integer:\n                return header_size + calc_bson_integer_size(j.m_data.m_value.number_integer);\n\n            case value_t::number_unsigned:\n                return header_size + calc_bson_unsigned_size(j.m_data.m_value.number_unsigned);\n\n            case value_t::string:\n                return header_size + calc_bson_string_size(*j.m_data.m_value.string);\n\n            case value_t::null:\n                return header_size + 0ul;\n\n            // LCOV_EXCL_START\n            case value_t::discarded:\n            default:\n                JSON_ASSERT(false); // NOLINT(cert-dcl03-c,hicpp-static-assert,misc-static-assert)\n                return 0ul;\n                // LCOV_EXCL_STOP\n        }\n    }\n\n    /*!\n    @brief Serializes the JSON value @a j to BSON and associates it with the\n     
      key @a name.\n    @param name The name to associate with the JSON entity @a j within the\n                current BSON document\n    */\n    void write_bson_element(const string_t& name,\n                            const BasicJsonType& j)\n    {\n        switch (j.type())\n        {\n            case value_t::object:\n                return write_bson_object_entry(name, *j.m_data.m_value.object);\n\n            case value_t::array:\n                return write_bson_array(name, *j.m_data.m_value.array);\n\n            case value_t::binary:\n                return write_bson_binary(name, *j.m_data.m_value.binary);\n\n            case value_t::boolean:\n                return write_bson_boolean(name, j.m_data.m_value.boolean);\n\n            case value_t::number_float:\n                return write_bson_double(name, j.m_data.m_value.number_float);\n\n            case value_t::number_integer:\n                return write_bson_integer(name, j.m_data.m_value.number_integer);\n\n            case value_t::number_unsigned:\n                return write_bson_unsigned(name, j);\n\n            case value_t::string:\n                return write_bson_string(name, *j.m_data.m_value.string);\n\n            case value_t::null:\n                return write_bson_null(name);\n\n            // LCOV_EXCL_START\n            case value_t::discarded:\n            default:\n                JSON_ASSERT(false); // NOLINT(cert-dcl03-c,hicpp-static-assert,misc-static-assert)\n                return;\n                // LCOV_EXCL_STOP\n        }\n    }\n\n    /*!\n    @brief Calculates the size of the BSON serialization of the given\n           JSON-object @a j.\n    @param[in] value  JSON value to serialize\n    @pre       value.type() == value_t::object\n    */\n    static std::size_t calc_bson_object_size(const typename BasicJsonType::object_t& value)\n    {\n        const std::size_t document_size = std::accumulate(value.begin(), value.end(), static_cast<std::size_t>(0),\n         
                                 [](size_t result, const typename BasicJsonType::object_t::value_type & el)\n        {\n            return result += calc_bson_element_size(el.first, el.second);\n        });\n\n        return sizeof(std::int32_t) + document_size + 1ul;\n    }\n\n    /*!\n    @param[in] value  JSON value to serialize\n    @pre       value.type() == value_t::object\n    */\n    void write_bson_object(const typename BasicJsonType::object_t& value)\n    {\n        write_number<std::int32_t>(static_cast<std::int32_t>(calc_bson_object_size(value)), true);\n\n        for (const auto& el : value)\n        {\n            write_bson_element(el.first, el.second);\n        }\n\n        oa->write_character(to_char_type(0x00));\n    }\n\n    //////////\n    // CBOR //\n    //////////\n\n    static constexpr CharType get_cbor_float_prefix(float /*unused*/)\n    {\n        return to_char_type(0xFA);  // Single-Precision Float\n    }\n\n    static constexpr CharType get_cbor_float_prefix(double /*unused*/)\n    {\n        return to_char_type(0xFB);  // Double-Precision Float\n    }\n\n    /////////////\n    // MsgPack //\n    /////////////\n\n    static constexpr CharType get_msgpack_float_prefix(float /*unused*/)\n    {\n        return to_char_type(0xCA);  // float 32\n    }\n\n    static constexpr CharType get_msgpack_float_prefix(double /*unused*/)\n    {\n        return to_char_type(0xCB);  // float 64\n    }\n\n    ////////////\n    // UBJSON //\n    ////////////\n\n    // UBJSON: write number (floating point)\n    template<typename NumberType, typename std::enable_if<\n                 std::is_floating_point<NumberType>::value, int>::type = 0>\n    void write_number_with_ubjson_prefix(const NumberType n,\n                                         const bool add_prefix,\n                                         const bool use_bjdata)\n    {\n        if (add_prefix)\n        {\n            oa->write_character(get_ubjson_float_prefix(n));\n        }\n        
write_number(n, use_bjdata);\n    }\n\n    // UBJSON: write number (unsigned integer)\n    template<typename NumberType, typename std::enable_if<\n                 std::is_unsigned<NumberType>::value, int>::type = 0>\n    void write_number_with_ubjson_prefix(const NumberType n,\n                                         const bool add_prefix,\n                                         const bool use_bjdata)\n    {\n        if (n <= static_cast<std::uint64_t>((std::numeric_limits<std::int8_t>::max)()))\n        {\n            if (add_prefix)\n            {\n                oa->write_character(to_char_type('i'));  // int8\n            }\n            write_number(static_cast<std::uint8_t>(n), use_bjdata);\n        }\n        else if (n <= (std::numeric_limits<std::uint8_t>::max)())\n        {\n            if (add_prefix)\n            {\n                oa->write_character(to_char_type('U'));  // uint8\n            }\n            write_number(static_cast<std::uint8_t>(n), use_bjdata);\n        }\n        else if (n <= static_cast<std::uint64_t>((std::numeric_limits<std::int16_t>::max)()))\n        {\n            if (add_prefix)\n            {\n                oa->write_character(to_char_type('I'));  // int16\n            }\n            write_number(static_cast<std::int16_t>(n), use_bjdata);\n        }\n        else if (use_bjdata && n <= static_cast<uint64_t>((std::numeric_limits<uint16_t>::max)()))\n        {\n            if (add_prefix)\n            {\n                oa->write_character(to_char_type('u'));  // uint16 - bjdata only\n            }\n            write_number(static_cast<std::uint16_t>(n), use_bjdata);\n        }\n        else if (n <= static_cast<std::uint64_t>((std::numeric_limits<std::int32_t>::max)()))\n        {\n            if (add_prefix)\n            {\n                oa->write_character(to_char_type('l'));  // int32\n            }\n            write_number(static_cast<std::int32_t>(n), use_bjdata);\n        }\n        else if (use_bjdata && n <= 
static_cast<uint64_t>((std::numeric_limits<uint32_t>::max)()))\n        {\n            if (add_prefix)\n            {\n                oa->write_character(to_char_type('m'));  // uint32 - bjdata only\n            }\n            write_number(static_cast<std::uint32_t>(n), use_bjdata);\n        }\n        else if (n <= static_cast<std::uint64_t>((std::numeric_limits<std::int64_t>::max)()))\n        {\n            if (add_prefix)\n            {\n                oa->write_character(to_char_type('L'));  // int64\n            }\n            write_number(static_cast<std::int64_t>(n), use_bjdata);\n        }\n        else if (use_bjdata && n <= (std::numeric_limits<uint64_t>::max)())\n        {\n            if (add_prefix)\n            {\n                oa->write_character(to_char_type('M'));  // uint64 - bjdata only\n            }\n            write_number(static_cast<std::uint64_t>(n), use_bjdata);\n        }\n        else\n        {\n            if (add_prefix)\n            {\n                oa->write_character(to_char_type('H'));  // high-precision number\n            }\n\n            const auto number = BasicJsonType(n).dump();\n            write_number_with_ubjson_prefix(number.size(), true, use_bjdata);\n            for (std::size_t i = 0; i < number.size(); ++i)\n            {\n                oa->write_character(to_char_type(static_cast<std::uint8_t>(number[i])));\n            }\n        }\n    }\n\n    // UBJSON: write number (signed integer)\n    template < typename NumberType, typename std::enable_if <\n                   std::is_signed<NumberType>::value&&\n                   !std::is_floating_point<NumberType>::value, int >::type = 0 >\n    void write_number_with_ubjson_prefix(const NumberType n,\n                                         const bool add_prefix,\n                                         const bool use_bjdata)\n    {\n        if ((std::numeric_limits<std::int8_t>::min)() <= n && n <= (std::numeric_limits<std::int8_t>::max)())\n        {\n      
      if (add_prefix)\n            {\n                oa->write_character(to_char_type('i'));  // int8\n            }\n            write_number(static_cast<std::int8_t>(n), use_bjdata);\n        }\n        else if (static_cast<std::int64_t>((std::numeric_limits<std::uint8_t>::min)()) <= n && n <= static_cast<std::int64_t>((std::numeric_limits<std::uint8_t>::max)()))\n        {\n            if (add_prefix)\n            {\n                oa->write_character(to_char_type('U'));  // uint8\n            }\n            write_number(static_cast<std::uint8_t>(n), use_bjdata);\n        }\n        else if ((std::numeric_limits<std::int16_t>::min)() <= n && n <= (std::numeric_limits<std::int16_t>::max)())\n        {\n            if (add_prefix)\n            {\n                oa->write_character(to_char_type('I'));  // int16\n            }\n            write_number(static_cast<std::int16_t>(n), use_bjdata);\n        }\n        else if (use_bjdata && (static_cast<std::int64_t>((std::numeric_limits<std::uint16_t>::min)()) <= n && n <= static_cast<std::int64_t>((std::numeric_limits<std::uint16_t>::max)())))\n        {\n            if (add_prefix)\n            {\n                oa->write_character(to_char_type('u'));  // uint16 - bjdata only\n            }\n            write_number(static_cast<uint16_t>(n), use_bjdata);\n        }\n        else if ((std::numeric_limits<std::int32_t>::min)() <= n && n <= (std::numeric_limits<std::int32_t>::max)())\n        {\n            if (add_prefix)\n            {\n                oa->write_character(to_char_type('l'));  // int32\n            }\n            write_number(static_cast<std::int32_t>(n), use_bjdata);\n        }\n        else if (use_bjdata && (static_cast<std::int64_t>((std::numeric_limits<std::uint32_t>::min)()) <= n && n <= static_cast<std::int64_t>((std::numeric_limits<std::uint32_t>::max)())))\n        {\n            if (add_prefix)\n            {\n                oa->write_character(to_char_type('m'));  // uint32 - bjdata 
only\n            }\n            write_number(static_cast<uint32_t>(n), use_bjdata);\n        }\n        else if ((std::numeric_limits<std::int64_t>::min)() <= n && n <= (std::numeric_limits<std::int64_t>::max)())\n        {\n            if (add_prefix)\n            {\n                oa->write_character(to_char_type('L'));  // int64\n            }\n            write_number(static_cast<std::int64_t>(n), use_bjdata);\n        }\n        // LCOV_EXCL_START\n        else\n        {\n            if (add_prefix)\n            {\n                oa->write_character(to_char_type('H'));  // high-precision number\n            }\n\n            const auto number = BasicJsonType(n).dump();\n            write_number_with_ubjson_prefix(number.size(), true, use_bjdata);\n            for (std::size_t i = 0; i < number.size(); ++i)\n            {\n                oa->write_character(to_char_type(static_cast<std::uint8_t>(number[i])));\n            }\n        }\n        // LCOV_EXCL_STOP\n    }\n\n    /*!\n    @brief determine the type prefix of container values\n    */\n    CharType ubjson_prefix(const BasicJsonType& j, const bool use_bjdata) const noexcept\n    {\n        switch (j.type())\n        {\n            case value_t::null:\n                return 'Z';\n\n            case value_t::boolean:\n                return j.m_data.m_value.boolean ? 
'T' : 'F';\n\n            case value_t::number_integer:\n            {\n                if ((std::numeric_limits<std::int8_t>::min)() <= j.m_data.m_value.number_integer && j.m_data.m_value.number_integer <= (std::numeric_limits<std::int8_t>::max)())\n                {\n                    return 'i';\n                }\n                if ((std::numeric_limits<std::uint8_t>::min)() <= j.m_data.m_value.number_integer && j.m_data.m_value.number_integer <= (std::numeric_limits<std::uint8_t>::max)())\n                {\n                    return 'U';\n                }\n                if ((std::numeric_limits<std::int16_t>::min)() <= j.m_data.m_value.number_integer && j.m_data.m_value.number_integer <= (std::numeric_limits<std::int16_t>::max)())\n                {\n                    return 'I';\n                }\n                if (use_bjdata && ((std::numeric_limits<std::uint16_t>::min)() <= j.m_data.m_value.number_integer && j.m_data.m_value.number_integer <= (std::numeric_limits<std::uint16_t>::max)()))\n                {\n                    return 'u';\n                }\n                if ((std::numeric_limits<std::int32_t>::min)() <= j.m_data.m_value.number_integer && j.m_data.m_value.number_integer <= (std::numeric_limits<std::int32_t>::max)())\n                {\n                    return 'l';\n                }\n                if (use_bjdata && ((std::numeric_limits<std::uint32_t>::min)() <= j.m_data.m_value.number_integer && j.m_data.m_value.number_integer <= (std::numeric_limits<std::uint32_t>::max)()))\n                {\n                    return 'm';\n                }\n                if ((std::numeric_limits<std::int64_t>::min)() <= j.m_data.m_value.number_integer && j.m_data.m_value.number_integer <= (std::numeric_limits<std::int64_t>::max)())\n                {\n                    return 'L';\n                }\n                // anything else is treated as high-precision number\n                return 'H'; // LCOV_EXCL_LINE\n            
}\n\n            case value_t::number_unsigned:\n            {\n                if (j.m_data.m_value.number_unsigned <= static_cast<std::uint64_t>((std::numeric_limits<std::int8_t>::max)()))\n                {\n                    return 'i';\n                }\n                if (j.m_data.m_value.number_unsigned <= static_cast<std::uint64_t>((std::numeric_limits<std::uint8_t>::max)()))\n                {\n                    return 'U';\n                }\n                if (j.m_data.m_value.number_unsigned <= static_cast<std::uint64_t>((std::numeric_limits<std::int16_t>::max)()))\n                {\n                    return 'I';\n                }\n                if (use_bjdata && j.m_data.m_value.number_unsigned <= static_cast<std::uint64_t>((std::numeric_limits<std::uint16_t>::max)()))\n                {\n                    return 'u';\n                }\n                if (j.m_data.m_value.number_unsigned <= static_cast<std::uint64_t>((std::numeric_limits<std::int32_t>::max)()))\n                {\n                    return 'l';\n                }\n                if (use_bjdata && j.m_data.m_value.number_unsigned <= static_cast<std::uint64_t>((std::numeric_limits<std::uint32_t>::max)()))\n                {\n                    return 'm';\n                }\n                if (j.m_data.m_value.number_unsigned <= static_cast<std::uint64_t>((std::numeric_limits<std::int64_t>::max)()))\n                {\n                    return 'L';\n                }\n                if (use_bjdata && j.m_data.m_value.number_unsigned <= (std::numeric_limits<std::uint64_t>::max)())\n                {\n                    return 'M';\n                }\n                // anything else is treated as high-precision number\n                return 'H'; // LCOV_EXCL_LINE\n            }\n\n            case value_t::number_float:\n                return get_ubjson_float_prefix(j.m_data.m_value.number_float);\n\n            case value_t::string:\n                return 
'S';\n\n            case value_t::array: // fallthrough\n            case value_t::binary:\n                return '[';\n\n            case value_t::object:\n                return '{';\n\n            case value_t::discarded:\n            default:  // discarded values\n                return 'N';\n        }\n    }\n\n    static constexpr CharType get_ubjson_float_prefix(float /*unused*/)\n    {\n        return 'd';  // float 32\n    }\n\n    static constexpr CharType get_ubjson_float_prefix(double /*unused*/)\n    {\n        return 'D';  // float 64\n    }\n\n    /*!\n    @return false if the object is successfully converted to a bjdata ndarray, true if the type or size is invalid\n    */\n    bool write_bjdata_ndarray(const typename BasicJsonType::object_t& value, const bool use_count, const bool use_type)\n    {\n        std::map<string_t, CharType> bjdtype = {{\"uint8\", 'U'},  {\"int8\", 'i'},  {\"uint16\", 'u'}, {\"int16\", 'I'},\n            {\"uint32\", 'm'}, {\"int32\", 'l'}, {\"uint64\", 'M'}, {\"int64\", 'L'}, {\"single\", 'd'}, {\"double\", 'D'}, {\"char\", 'C'}\n        };\n\n        string_t key = \"_ArrayType_\";\n        auto it = bjdtype.find(static_cast<string_t>(value.at(key)));\n        if (it == bjdtype.end())\n        {\n            return true;\n        }\n        CharType dtype = it->second;\n\n        key = \"_ArraySize_\";\n        std::size_t len = (value.at(key).empty() ? 
0 : 1);\n        for (const auto& el : value.at(key))\n        {\n            len *= static_cast<std::size_t>(el.m_data.m_value.number_unsigned);\n        }\n\n        key = \"_ArrayData_\";\n        if (value.at(key).size() != len)\n        {\n            return true;\n        }\n\n        oa->write_character('[');\n        oa->write_character('$');\n        oa->write_character(dtype);\n        oa->write_character('#');\n\n        key = \"_ArraySize_\";\n        write_ubjson(value.at(key), use_count, use_type, true,  true);\n\n        key = \"_ArrayData_\";\n        if (dtype == 'U' || dtype == 'C')\n        {\n            for (const auto& el : value.at(key))\n            {\n                write_number(static_cast<std::uint8_t>(el.m_data.m_value.number_unsigned), true);\n            }\n        }\n        else if (dtype == 'i')\n        {\n            for (const auto& el : value.at(key))\n            {\n                write_number(static_cast<std::int8_t>(el.m_data.m_value.number_integer), true);\n            }\n        }\n        else if (dtype == 'u')\n        {\n            for (const auto& el : value.at(key))\n            {\n                write_number(static_cast<std::uint16_t>(el.m_data.m_value.number_unsigned), true);\n            }\n        }\n        else if (dtype == 'I')\n        {\n            for (const auto& el : value.at(key))\n            {\n                write_number(static_cast<std::int16_t>(el.m_data.m_value.number_integer), true);\n            }\n        }\n        else if (dtype == 'm')\n        {\n            for (const auto& el : value.at(key))\n            {\n                write_number(static_cast<std::uint32_t>(el.m_data.m_value.number_unsigned), true);\n            }\n        }\n        else if (dtype == 'l')\n        {\n            for (const auto& el : value.at(key))\n            {\n                write_number(static_cast<std::int32_t>(el.m_data.m_value.number_integer), true);\n            }\n        }\n        else if (dtype == 
'M')\n        {\n            for (const auto& el : value.at(key))\n            {\n                write_number(static_cast<std::uint64_t>(el.m_data.m_value.number_unsigned), true);\n            }\n        }\n        else if (dtype == 'L')\n        {\n            for (const auto& el : value.at(key))\n            {\n                write_number(static_cast<std::int64_t>(el.m_data.m_value.number_integer), true);\n            }\n        }\n        else if (dtype == 'd')\n        {\n            for (const auto& el : value.at(key))\n            {\n                write_number(static_cast<float>(el.m_data.m_value.number_float), true);\n            }\n        }\n        else if (dtype == 'D')\n        {\n            for (const auto& el : value.at(key))\n            {\n                write_number(static_cast<double>(el.m_data.m_value.number_float), true);\n            }\n        }\n        return false;\n    }\n\n    ///////////////////////\n    // Utility functions //\n    ///////////////////////\n\n    /*\n    @brief write a number to the output adapter\n    @param[in] n number of type @a NumberType\n    @param[in] OutputIsLittleEndian Set to true if output data is\n                                 required to be little endian\n    @tparam NumberType the type of the number\n\n    @note This function needs to respect the system's endianness, because bytes\n          in CBOR, MessagePack, and UBJSON are stored in network order (big\n          endian) and therefore need reordering on little endian systems.\n          On the other hand, BSON and BJData use little endian and should reorder\n          on big endian systems.\n    */\n    template<typename NumberType>\n    void write_number(const NumberType n, const bool OutputIsLittleEndian = false)\n    {\n        // step 1: write number to array of length sizeof(NumberType)\n        std::array<CharType, sizeof(NumberType)> vec{};\n        std::memcpy(vec.data(), &n, sizeof(NumberType));\n\n        // step 2: write array to output (with 
possible reordering)\n        if (is_little_endian != OutputIsLittleEndian)\n        {\n            // reverse byte order prior to conversion if necessary\n            std::reverse(vec.begin(), vec.end());\n        }\n\n        oa->write_characters(vec.data(), sizeof(NumberType));\n    }\n\n    void write_compact_float(const number_float_t n, detail::input_format_t format)\n    {\n#ifdef __GNUC__\n#pragma GCC diagnostic push\n#pragma GCC diagnostic ignored \"-Wfloat-equal\"\n#endif\n        if (static_cast<double>(n) >= static_cast<double>(std::numeric_limits<float>::lowest()) &&\n                static_cast<double>(n) <= static_cast<double>((std::numeric_limits<float>::max)()) &&\n                static_cast<double>(static_cast<float>(n)) == static_cast<double>(n))\n        {\n            oa->write_character(format == detail::input_format_t::cbor\n                                ? get_cbor_float_prefix(static_cast<float>(n))\n                                : get_msgpack_float_prefix(static_cast<float>(n)));\n            write_number(static_cast<float>(n));\n        }\n        else\n        {\n            oa->write_character(format == detail::input_format_t::cbor\n                                ? get_cbor_float_prefix(n)\n                                : get_msgpack_float_prefix(n));\n            write_number(n);\n        }\n#ifdef __GNUC__\n#pragma GCC diagnostic pop\n#endif\n    }\n\n  public:\n    // The following to_char_type functions implement the conversion\n    // between uint8_t and CharType. 
In case CharType is not unsigned,\n    // such a conversion is required to allow values greater than 128.\n    // See <https://github.com/nlohmann/json/issues/1286> for a discussion.\n    template < typename C = CharType,\n               enable_if_t < std::is_signed<C>::value && std::is_signed<char>::value > * = nullptr >\n    static constexpr CharType to_char_type(std::uint8_t x) noexcept\n    {\n        return *reinterpret_cast<char*>(&x);\n    }\n\n    template < typename C = CharType,\n               enable_if_t < std::is_signed<C>::value && std::is_unsigned<char>::value > * = nullptr >\n    static CharType to_char_type(std::uint8_t x) noexcept\n    {\n        static_assert(sizeof(std::uint8_t) == sizeof(CharType), \"size of CharType must be equal to std::uint8_t\");\n        static_assert(std::is_trivial<CharType>::value, \"CharType must be trivial\");\n        CharType result;\n        std::memcpy(&result, &x, sizeof(x));\n        return result;\n    }\n\n    template<typename C = CharType,\n             enable_if_t<std::is_unsigned<C>::value>* = nullptr>\n    static constexpr CharType to_char_type(std::uint8_t x) noexcept\n    {\n        return x;\n    }\n\n    template < typename InputCharType, typename C = CharType,\n               enable_if_t <\n                   std::is_signed<C>::value &&\n                   std::is_signed<char>::value &&\n                   std::is_same<char, typename std::remove_cv<InputCharType>::type>::value\n                   > * = nullptr >\n    static constexpr CharType to_char_type(InputCharType x) noexcept\n    {\n        return x;\n    }\n\n  private:\n    /// whether we can assume little endianness\n    const bool is_little_endian = little_endianness();\n\n    /// the output\n    output_adapter_t<CharType> oa = nullptr;\n};\n\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include <nlohmann/detail/output/output_adapters.hpp>\n\n// #include <nlohmann/detail/output/serializer.hpp>\n//     __ _____ _____ _____\n//  
__|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2008-2009 Björn Hoehrmann <bjoern@hoehrmann.de>\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#include <algorithm> // reverse, remove, fill, find, none_of\n#include <array> // array\n#include <clocale> // localeconv, lconv\n#include <cmath> // labs, isfinite, isnan, signbit\n#include <cstddef> // size_t, ptrdiff_t\n#include <cstdint> // uint8_t\n#include <cstdio> // snprintf\n#include <limits> // numeric_limits\n#include <string> // string, char_traits\n#include <iomanip> // setfill, setw\n#include <type_traits> // is_same\n#include <utility> // move\n\n// #include <nlohmann/detail/conversions/to_chars.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2009 Florian Loitsch <https://florian.loitsch.com/>\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#include <array> // array\n#include <cmath>   // signbit, isfinite\n#include <cstdint> // intN_t, uintN_t\n#include <cstring> // memcpy, memmove\n#include <limits> // numeric_limits\n#include <type_traits> // conditional\n\n// #include <nlohmann/detail/macro_scope.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\n\n/*!\n@brief implements the Grisu2 algorithm for binary to decimal floating-point\nconversion.\n\nThis implementation is a slightly modified version of the reference\nimplementation which may be obtained from\nhttp://florian.loitsch.com/publications (bench.tar.gz).\n\nThe code is distributed under the MIT license, Copyright (c) 2009 Florian Loitsch.\n\nFor a detailed description of the algorithm see:\n\n[1] 
Loitsch, \"Printing Floating-Point Numbers Quickly and Accurately with\n    Integers\", Proceedings of the ACM SIGPLAN 2010 Conference on Programming\n    Language Design and Implementation, PLDI 2010\n[2] Burger, Dybvig, \"Printing Floating-Point Numbers Quickly and Accurately\",\n    Proceedings of the ACM SIGPLAN 1996 Conference on Programming Language\n    Design and Implementation, PLDI 1996\n*/\nnamespace dtoa_impl\n{\n\ntemplate<typename Target, typename Source>\nTarget reinterpret_bits(const Source source)\n{\n    static_assert(sizeof(Target) == sizeof(Source), \"size mismatch\");\n\n    Target target;\n    std::memcpy(&target, &source, sizeof(Source));\n    return target;\n}\n\nstruct diyfp // f * 2^e\n{\n    static constexpr int kPrecision = 64; // = q\n\n    std::uint64_t f = 0;\n    int e = 0;\n\n    constexpr diyfp(std::uint64_t f_, int e_) noexcept : f(f_), e(e_) {}\n\n    /*!\n    @brief returns x - y\n    @pre x.e == y.e and x.f >= y.f\n    */\n    static diyfp sub(const diyfp& x, const diyfp& y) noexcept\n    {\n        JSON_ASSERT(x.e == y.e);\n        JSON_ASSERT(x.f >= y.f);\n\n        return {x.f - y.f, x.e};\n    }\n\n    /*!\n    @brief returns x * y\n    @note The result is rounded. 
(Only the upper q bits are returned.)\n    */\n    static diyfp mul(const diyfp& x, const diyfp& y) noexcept\n    {\n        static_assert(kPrecision == 64, \"internal error\");\n\n        // Computes:\n        //  f = round((x.f * y.f) / 2^q)\n        //  e = x.e + y.e + q\n\n        // Emulate the 64-bit * 64-bit multiplication:\n        //\n        // p = u * v\n        //   = (u_lo + 2^32 u_hi) (v_lo + 2^32 v_hi)\n        //   = (u_lo v_lo         ) + 2^32 ((u_lo v_hi         ) + (u_hi v_lo         )) + 2^64 (u_hi v_hi         )\n        //   = (p0                ) + 2^32 ((p1                ) + (p2                )) + 2^64 (p3                )\n        //   = (p0_lo + 2^32 p0_hi) + 2^32 ((p1_lo + 2^32 p1_hi) + (p2_lo + 2^32 p2_hi)) + 2^64 (p3                )\n        //   = (p0_lo             ) + 2^32 (p0_hi + p1_lo + p2_lo                      ) + 2^64 (p1_hi + p2_hi + p3)\n        //   = (p0_lo             ) + 2^32 (Q                                          ) + 2^64 (H                 )\n        //   = (p0_lo             ) + 2^32 (Q_lo + 2^32 Q_hi                           ) + 2^64 (H                 )\n        //\n        // (Since Q might be larger than 2^32 - 1)\n        //\n        //   = (p0_lo + 2^32 Q_lo) + 2^64 (Q_hi + H)\n        //\n        // (Q_hi + H does not overflow a 64-bit int)\n        //\n        //   = p_lo + 2^64 p_hi\n\n        const std::uint64_t u_lo = x.f & 0xFFFFFFFFu;\n        const std::uint64_t u_hi = x.f >> 32u;\n        const std::uint64_t v_lo = y.f & 0xFFFFFFFFu;\n        const std::uint64_t v_hi = y.f >> 32u;\n\n        const std::uint64_t p0 = u_lo * v_lo;\n        const std::uint64_t p1 = u_lo * v_hi;\n        const std::uint64_t p2 = u_hi * v_lo;\n        const std::uint64_t p3 = u_hi * v_hi;\n\n        const std::uint64_t p0_hi = p0 >> 32u;\n        const std::uint64_t p1_lo = p1 & 0xFFFFFFFFu;\n        const std::uint64_t p1_hi = p1 >> 32u;\n        const std::uint64_t p2_lo = p2 & 0xFFFFFFFFu;\n        const 
std::uint64_t p2_hi = p2 >> 32u;\n\n        std::uint64_t Q = p0_hi + p1_lo + p2_lo;\n\n        // The full product might now be computed as\n        //\n        // p_hi = p3 + p2_hi + p1_hi + (Q >> 32)\n        // p_lo = p0_lo + (Q << 32)\n        //\n        // But in this particular case here, the full p_lo is not required.\n        // Effectively we only need to add the highest bit in p_lo to p_hi (and\n        // Q_hi + 1 does not overflow).\n\n        Q += std::uint64_t{1} << (64u - 32u - 1u); // round, ties up\n\n        const std::uint64_t h = p3 + p2_hi + p1_hi + (Q >> 32u);\n\n        return {h, x.e + y.e + 64};\n    }\n\n    /*!\n    @brief normalize x such that the significand is >= 2^(q-1)\n    @pre x.f != 0\n    */\n    static diyfp normalize(diyfp x) noexcept\n    {\n        JSON_ASSERT(x.f != 0);\n\n        while ((x.f >> 63u) == 0)\n        {\n            x.f <<= 1u;\n            x.e--;\n        }\n\n        return x;\n    }\n\n    /*!\n    @brief normalize x such that the result has the exponent target_exponent\n    @pre x.e >= target_exponent and the upper x.e - target_exponent bits of x.f must be zero.\n    */\n    static diyfp normalize_to(const diyfp& x, const int target_exponent) noexcept\n    {\n        const int delta = x.e - target_exponent;\n\n        JSON_ASSERT(delta >= 0);\n        JSON_ASSERT(((x.f << delta) >> delta) == x.f);\n\n        return {x.f << delta, target_exponent};\n    }\n};\n\nstruct boundaries\n{\n    diyfp w;\n    diyfp minus;\n    diyfp plus;\n};\n\n/*!\nCompute the (normalized) diyfp representing the input number 'value' and its\nboundaries.\n\n@pre value must be finite and positive\n*/\ntemplate<typename FloatType>\nboundaries compute_boundaries(FloatType value)\n{\n    JSON_ASSERT(std::isfinite(value));\n    JSON_ASSERT(value > 0);\n\n    // Convert the IEEE representation into a diyfp.\n    //\n    // If v is denormal:\n    //      value = 0.F * 2^(1 - bias) = (          F) * 2^(1 - bias - (p-1))\n    // If v is normalized:\n    //      value = 1.F * 2^(E 
- bias) = (2^(p-1) + F) * 2^(E - bias - (p-1))\n\n    static_assert(std::numeric_limits<FloatType>::is_iec559,\n                  \"internal error: dtoa_short requires an IEEE-754 floating-point implementation\");\n\n    constexpr int      kPrecision = std::numeric_limits<FloatType>::digits; // = p (includes the hidden bit)\n    constexpr int      kBias      = std::numeric_limits<FloatType>::max_exponent - 1 + (kPrecision - 1);\n    constexpr int      kMinExp    = 1 - kBias;\n    constexpr std::uint64_t kHiddenBit = std::uint64_t{1} << (kPrecision - 1); // = 2^(p-1)\n\n    using bits_type = typename std::conditional<kPrecision == 24, std::uint32_t, std::uint64_t >::type;\n\n    const auto bits = static_cast<std::uint64_t>(reinterpret_bits<bits_type>(value));\n    const std::uint64_t E = bits >> (kPrecision - 1);\n    const std::uint64_t F = bits & (kHiddenBit - 1);\n\n    const bool is_denormal = E == 0;\n    const diyfp v = is_denormal\n                    ? diyfp(F, kMinExp)\n                    : diyfp(F + kHiddenBit, static_cast<int>(E) - kBias);\n\n    // Compute the boundaries m- and m+ of the floating-point value\n    // v = f * 2^e.\n    //\n    // Determine v- and v+, the floating-point predecessor and successor of v,\n    // respectively.\n    //\n    //      v- = v - 2^e        if f != 2^(p-1) or e == e_min                (A)\n    //         = v - 2^(e-1)    if f == 2^(p-1) and e > e_min                (B)\n    //\n    //      v+ = v + 2^e\n    //\n    // Let m- = (v- + v) / 2 and m+ = (v + v+) / 2. 
All real numbers _strictly_\n    // between m- and m+ round to v, regardless of how the input rounding\n    // algorithm breaks ties.\n    //\n    //      ---+-------------+-------------+-------------+-------------+---  (A)\n    //         v-            m-            v             m+            v+\n    //\n    //      -----------------+------+------+-------------+-------------+---  (B)\n    //                       v-     m-     v             m+            v+\n\n    const bool lower_boundary_is_closer = F == 0 && E > 1;\n    const diyfp m_plus = diyfp(2 * v.f + 1, v.e - 1);\n    const diyfp m_minus = lower_boundary_is_closer\n                          ? diyfp(4 * v.f - 1, v.e - 2)  // (B)\n                          : diyfp(2 * v.f - 1, v.e - 1); // (A)\n\n    // Determine the normalized w+ = m+.\n    const diyfp w_plus = diyfp::normalize(m_plus);\n\n    // Determine w- = m- such that e_(w-) = e_(w+).\n    const diyfp w_minus = diyfp::normalize_to(m_minus, w_plus.e);\n\n    return {diyfp::normalize(v), w_minus, w_plus};\n}\n\n// Given normalized diyfp w, Grisu needs to find a (normalized) cached\n// power-of-ten c, such that the exponent of the product c * w = f * 2^e lies\n// within a certain range [alpha, gamma] (Definition 3.2 from [1])\n//\n//      alpha <= e = e_c + e_w + q <= gamma\n//\n// or\n//\n//      f_c * f_w * 2^alpha <= f_c 2^(e_c) * f_w 2^(e_w) * 2^q\n//                          <= f_c * f_w * 2^gamma\n//\n// Since c and w are normalized, i.e. 2^(q-1) <= f < 2^q, this implies\n//\n//      2^(q-1) * 2^(q-1) * 2^alpha <= c * w * 2^q < 2^q * 2^q * 2^gamma\n//\n// or\n//\n//      2^(q - 2 + alpha) <= c * w < 2^(q + gamma)\n//\n// The choice of (alpha,gamma) determines the size of the table and the form of\n// the digit generation procedure. 
Using (alpha,gamma)=(-60,-32) works out well\n// in practice:\n//\n// The idea is to cut the number c * w = f * 2^e into two parts, which can be\n// processed independently: An integral part p1, and a fractional part p2:\n//\n//      f * 2^e = ( (f div 2^-e) * 2^-e + (f mod 2^-e) ) * 2^e\n//              = (f div 2^-e) + (f mod 2^-e) * 2^e\n//              = p1 + p2 * 2^e\n//\n// The conversion of p1 into decimal form requires a series of divisions and\n// modulos by (a power of) 10. These operations are faster for 32-bit than for\n// 64-bit integers, so p1 should ideally fit into a 32-bit integer. This can be\n// achieved by choosing\n//\n//      -e >= 32   or   e <= -32 := gamma\n//\n// In order to convert the fractional part\n//\n//      p2 * 2^e = p2 / 2^-e = d[-1] / 10^1 + d[-2] / 10^2 + ...\n//\n// into decimal form, the fraction is repeatedly multiplied by 10 and the digits\n// d[-i] are extracted in order:\n//\n//      (10 * p2) div 2^-e = d[-1]\n//      (10 * p2) mod 2^-e = d[-2] / 10^1 + ...\n//\n// The multiplication by 10 must not overflow. 
It is sufficient to choose\n//\n//      10 * p2 < 16 * p2 = 2^4 * p2 <= 2^64.\n//\n// Since p2 = f mod 2^-e < 2^-e,\n//\n//      -e <= 60   or   e >= -60 := alpha\n\nconstexpr int kAlpha = -60;\nconstexpr int kGamma = -32;\n\nstruct cached_power // c = f * 2^e ~= 10^k\n{\n    std::uint64_t f;\n    int e;\n    int k;\n};\n\n/*!\nFor a normalized diyfp w = f * 2^e, this function returns a (normalized) cached\npower-of-ten c = f_c * 2^e_c, such that the exponent of the product w * c\nsatisfies (Definition 3.2 from [1])\n\n     alpha <= e_c + e + q <= gamma.\n*/\ninline cached_power get_cached_power_for_binary_exponent(int e)\n{\n    // Now\n    //\n    //      alpha <= e_c + e + q <= gamma                                    (1)\n    //      ==> f_c * 2^alpha <= c * 2^e * 2^q\n    //\n    // and since the c's are normalized, 2^(q-1) <= f_c,\n    //\n    //      ==> 2^(q - 1 + alpha) <= c * 2^(e + q)\n    //      ==> 2^(alpha - e - 1) <= c\n    //\n    // If c were an exact power of ten, i.e. c = 10^k, one may determine k as\n    //\n    //      k = ceil( log_10( 2^(alpha - e - 1) ) )\n    //        = ceil( (alpha - e - 1) * log_10(2) )\n    //\n    // From the paper:\n    // \"In theory the result of the procedure could be wrong since c is rounded,\n    //  and the computation itself is approximated [...]. 
In practice, however,\n    //  this simple function is sufficient.\"\n    //\n    // For IEEE double precision floating-point numbers converted into\n    // normalized diyfp's w = f * 2^e, with q = 64,\n    //\n    //      e >= -1022      (min IEEE exponent)\n    //           -52        (p - 1)\n    //           -52        (p - 1, possibly normalize denormal IEEE numbers)\n    //           -11        (normalize the diyfp)\n    //         = -1137\n    //\n    // and\n    //\n    //      e <= +1023      (max IEEE exponent)\n    //           -52        (p - 1)\n    //           -11        (normalize the diyfp)\n    //         = 960\n    //\n    // This binary exponent range [-1137,960] results in a decimal exponent\n    // range [-307,324]. One does not need to store a cached power for each\n    // k in this range. For each such k it suffices to find a cached power\n    // such that the exponent of the product lies in [alpha,gamma].\n    // This implies that the difference of the decimal exponents of adjacent\n    // table entries must be less than or equal to\n    //\n    //      floor( (gamma - alpha) * log_10(2) ) = 8.\n    //\n    // (A smaller distance gamma-alpha would require a larger table.)\n\n    // NB:\n    // Actually this function returns c, such that -60 <= e_c + e + 64 <= -34.\n\n    constexpr int kCachedPowersMinDecExp = -300;\n    constexpr int kCachedPowersDecStep = 8;\n\n    static constexpr std::array<cached_power, 79> kCachedPowers =\n    {\n        {\n            { 0xAB70FE17C79AC6CA, -1060, -300 },\n            { 0xFF77B1FCBEBCDC4F, -1034, -292 },\n            { 0xBE5691EF416BD60C, -1007, -284 },\n            { 0x8DD01FAD907FFC3C,  -980, -276 },\n            { 0xD3515C2831559A83,  -954, -268 },\n            { 0x9D71AC8FADA6C9B5,  -927, -260 },\n            { 0xEA9C227723EE8BCB,  -901, -252 },\n            { 0xAECC49914078536D,  -874, -244 },\n            { 0x823C12795DB6CE57,  -847, -236 },\n            { 0xC21094364DFB5637,  -821, -228 },\n     
       { 0x9096EA6F3848984F,  -794, -220 },\n            { 0xD77485CB25823AC7,  -768, -212 },\n            { 0xA086CFCD97BF97F4,  -741, -204 },\n            { 0xEF340A98172AACE5,  -715, -196 },\n            { 0xB23867FB2A35B28E,  -688, -188 },\n            { 0x84C8D4DFD2C63F3B,  -661, -180 },\n            { 0xC5DD44271AD3CDBA,  -635, -172 },\n            { 0x936B9FCEBB25C996,  -608, -164 },\n            { 0xDBAC6C247D62A584,  -582, -156 },\n            { 0xA3AB66580D5FDAF6,  -555, -148 },\n            { 0xF3E2F893DEC3F126,  -529, -140 },\n            { 0xB5B5ADA8AAFF80B8,  -502, -132 },\n            { 0x87625F056C7C4A8B,  -475, -124 },\n            { 0xC9BCFF6034C13053,  -449, -116 },\n            { 0x964E858C91BA2655,  -422, -108 },\n            { 0xDFF9772470297EBD,  -396, -100 },\n            { 0xA6DFBD9FB8E5B88F,  -369,  -92 },\n            { 0xF8A95FCF88747D94,  -343,  -84 },\n            { 0xB94470938FA89BCF,  -316,  -76 },\n            { 0x8A08F0F8BF0F156B,  -289,  -68 },\n            { 0xCDB02555653131B6,  -263,  -60 },\n            { 0x993FE2C6D07B7FAC,  -236,  -52 },\n            { 0xE45C10C42A2B3B06,  -210,  -44 },\n            { 0xAA242499697392D3,  -183,  -36 },\n            { 0xFD87B5F28300CA0E,  -157,  -28 },\n            { 0xBCE5086492111AEB,  -130,  -20 },\n            { 0x8CBCCC096F5088CC,  -103,  -12 },\n            { 0xD1B71758E219652C,   -77,   -4 },\n            { 0x9C40000000000000,   -50,    4 },\n            { 0xE8D4A51000000000,   -24,   12 },\n            { 0xAD78EBC5AC620000,     3,   20 },\n            { 0x813F3978F8940984,    30,   28 },\n            { 0xC097CE7BC90715B3,    56,   36 },\n            { 0x8F7E32CE7BEA5C70,    83,   44 },\n            { 0xD5D238A4ABE98068,   109,   52 },\n            { 0x9F4F2726179A2245,   136,   60 },\n            { 0xED63A231D4C4FB27,   162,   68 },\n            { 0xB0DE65388CC8ADA8,   189,   76 },\n            { 0x83C7088E1AAB65DB,   216,   84 },\n            { 0xC45D1DF942711D9A,   242,   92 },\n     
       { 0x924D692CA61BE758,   269,  100 },\n            { 0xDA01EE641A708DEA,   295,  108 },\n            { 0xA26DA3999AEF774A,   322,  116 },\n            { 0xF209787BB47D6B85,   348,  124 },\n            { 0xB454E4A179DD1877,   375,  132 },\n            { 0x865B86925B9BC5C2,   402,  140 },\n            { 0xC83553C5C8965D3D,   428,  148 },\n            { 0x952AB45CFA97A0B3,   455,  156 },\n            { 0xDE469FBD99A05FE3,   481,  164 },\n            { 0xA59BC234DB398C25,   508,  172 },\n            { 0xF6C69A72A3989F5C,   534,  180 },\n            { 0xB7DCBF5354E9BECE,   561,  188 },\n            { 0x88FCF317F22241E2,   588,  196 },\n            { 0xCC20CE9BD35C78A5,   614,  204 },\n            { 0x98165AF37B2153DF,   641,  212 },\n            { 0xE2A0B5DC971F303A,   667,  220 },\n            { 0xA8D9D1535CE3B396,   694,  228 },\n            { 0xFB9B7CD9A4A7443C,   720,  236 },\n            { 0xBB764C4CA7A44410,   747,  244 },\n            { 0x8BAB8EEFB6409C1A,   774,  252 },\n            { 0xD01FEF10A657842C,   800,  260 },\n            { 0x9B10A4E5E9913129,   827,  268 },\n            { 0xE7109BFBA19C0C9D,   853,  276 },\n            { 0xAC2820D9623BF429,   880,  284 },\n            { 0x80444B5E7AA7CF85,   907,  292 },\n            { 0xBF21E44003ACDD2D,   933,  300 },\n            { 0x8E679C2F5E44FF8F,   960,  308 },\n            { 0xD433179D9C8CB841,   986,  316 },\n            { 0x9E19DB92B4E31BA9,  1013,  324 },\n        }\n    };\n\n    // This computation gives exactly the same results for k as\n    //      k = ceil((kAlpha - e - 1) * 0.30102999566398114)\n    // for |e| <= 1500, but doesn't require floating-point operations.\n    // NB: log_10(2) ~= 78913 / 2^18\n    JSON_ASSERT(e >= -1500);\n    JSON_ASSERT(e <=  1500);\n    const int f = kAlpha - e - 1;\n    const int k = (f * 78913) / (1 << 18) + static_cast<int>(f > 0);\n\n    const int index = (-kCachedPowersMinDecExp + k + (kCachedPowersDecStep - 1)) / kCachedPowersDecStep;\n    JSON_ASSERT(index 
>= 0);\n    JSON_ASSERT(static_cast<std::size_t>(index) < kCachedPowers.size());\n\n    const cached_power cached = kCachedPowers[static_cast<std::size_t>(index)];\n    JSON_ASSERT(kAlpha <= cached.e + e + 64);\n    JSON_ASSERT(kGamma >= cached.e + e + 64);\n\n    return cached;\n}\n\n/*!\nFor n != 0, returns k, such that pow10 := 10^(k-1) <= n < 10^k.\nFor n == 0, returns 1 and sets pow10 := 1.\n*/\ninline int find_largest_pow10(const std::uint32_t n, std::uint32_t& pow10)\n{\n    // LCOV_EXCL_START\n    if (n >= 1000000000)\n    {\n        pow10 = 1000000000;\n        return 10;\n    }\n    // LCOV_EXCL_STOP\n    if (n >= 100000000)\n    {\n        pow10 = 100000000;\n        return  9;\n    }\n    if (n >= 10000000)\n    {\n        pow10 = 10000000;\n        return  8;\n    }\n    if (n >= 1000000)\n    {\n        pow10 = 1000000;\n        return  7;\n    }\n    if (n >= 100000)\n    {\n        pow10 = 100000;\n        return  6;\n    }\n    if (n >= 10000)\n    {\n        pow10 = 10000;\n        return  5;\n    }\n    if (n >= 1000)\n    {\n        pow10 = 1000;\n        return  4;\n    }\n    if (n >= 100)\n    {\n        pow10 = 100;\n        return  3;\n    }\n    if (n >= 10)\n    {\n        pow10 = 10;\n        return  2;\n    }\n\n    pow10 = 1;\n    return 1;\n}\n\ninline void grisu2_round(char* buf, int len, std::uint64_t dist, std::uint64_t delta,\n                         std::uint64_t rest, std::uint64_t ten_k)\n{\n    JSON_ASSERT(len >= 1);\n    JSON_ASSERT(dist <= delta);\n    JSON_ASSERT(rest <= delta);\n    JSON_ASSERT(ten_k > 0);\n\n    //               <--------------------------- delta ---->\n    //                                  <---- dist --------->\n    // --------------[------------------+-------------------]--------------\n    //               M-                 w                   M+\n    //\n    //                                  ten_k\n    //                                <------>\n    //                                       <---- 
rest ---->\n    // --------------[------------------+----+--------------]--------------\n    //                                  w    V\n    //                                       = buf * 10^k\n    //\n    // ten_k represents a unit-in-the-last-place in the decimal representation\n    // stored in buf.\n    // Decrement buf by ten_k while this takes buf closer to w.\n\n    // The tests are written in this order to avoid overflow in unsigned\n    // integer arithmetic.\n\n    while (rest < dist\n            && delta - rest >= ten_k\n            && (rest + ten_k < dist || dist - rest > rest + ten_k - dist))\n    {\n        JSON_ASSERT(buf[len - 1] != '0');\n        buf[len - 1]--;\n        rest += ten_k;\n    }\n}\n\n/*!\nGenerates V = buffer * 10^decimal_exponent, such that M- <= V <= M+.\nM- and M+ must be normalized and share the same exponent -60 <= e <= -32.\n*/\ninline void grisu2_digit_gen(char* buffer, int& length, int& decimal_exponent,\n                             diyfp M_minus, diyfp w, diyfp M_plus)\n{\n    static_assert(kAlpha >= -60, \"internal error\");\n    static_assert(kGamma <= -32, \"internal error\");\n\n    // Generates the digits (and the exponent) of a decimal floating-point\n    // number V = buffer * 10^decimal_exponent in the range [M-, M+]. 
The diyfp's\n    // w, M- and M+ share the same exponent e, which satisfies alpha <= e <= gamma.\n    //\n    //               <--------------------------- delta ---->\n    //                                  <---- dist --------->\n    // --------------[------------------+-------------------]--------------\n    //               M-                 w                   M+\n    //\n    // Grisu2 generates the digits of M+ from left to right and stops as soon as\n    // V is in [M-,M+].\n\n    JSON_ASSERT(M_plus.e >= kAlpha);\n    JSON_ASSERT(M_plus.e <= kGamma);\n\n    std::uint64_t delta = diyfp::sub(M_plus, M_minus).f; // (significand of (M+ - M-), implicit exponent is e)\n    std::uint64_t dist  = diyfp::sub(M_plus, w      ).f; // (significand of (M+ - w ), implicit exponent is e)\n\n    // Split M+ = f * 2^e into two parts p1 and p2 (note: e < 0):\n    //\n    //      M+ = f * 2^e\n    //         = ((f div 2^-e) * 2^-e + (f mod 2^-e)) * 2^e\n    //         = ((p1        ) * 2^-e + (p2        )) * 2^e\n    //         = p1 + p2 * 2^e\n\n    const diyfp one(std::uint64_t{1} << -M_plus.e, M_plus.e);\n\n    auto p1 = static_cast<std::uint32_t>(M_plus.f >> -one.e); // p1 = f div 2^-e (Since -e >= 32, p1 fits into a 32-bit int.)\n    std::uint64_t p2 = M_plus.f & (one.f - 1);                    // p2 = f mod 2^-e\n\n    // 1)\n    //\n    // Generate the digits of the integral part p1 = d[n-1]...d[1]d[0]\n\n    JSON_ASSERT(p1 > 0);\n\n    std::uint32_t pow10{};\n    const int k = find_largest_pow10(p1, pow10);\n\n    //      10^(k-1) <= p1 < 10^k, pow10 = 10^(k-1)\n    //\n    //      p1 = (p1 div 10^(k-1)) * 10^(k-1) + (p1 mod 10^(k-1))\n    //         = (d[k-1]         ) * 10^(k-1) + (p1 mod 10^(k-1))\n    //\n    //      M+ = p1                                             + p2 * 2^e\n    //         = d[k-1] * 10^(k-1) + (p1 mod 10^(k-1))          + p2 * 2^e\n    //         = d[k-1] * 10^(k-1) + ((p1 mod 10^(k-1)) * 2^-e + p2) * 2^e\n    //         = d[k-1] * 10^(k-1) + 
(                         rest) * 2^e\n    //\n    // Now generate the digits d[n] of p1 from left to right (n = k-1,...,0)\n    //\n    //      p1 = d[k-1]...d[n] * 10^n + d[n-1]...d[0]\n    //\n    // but stop as soon as\n    //\n    //      rest * 2^e = (d[n-1]...d[0] * 2^-e + p2) * 2^e <= delta * 2^e\n\n    int n = k;\n    while (n > 0)\n    {\n        // Invariants:\n        //      M+ = buffer * 10^n + (p1 + p2 * 2^e)    (buffer = 0 for n = k)\n        //      pow10 = 10^(n-1) <= p1 < 10^n\n        //\n        const std::uint32_t d = p1 / pow10;  // d = p1 div 10^(n-1)\n        const std::uint32_t r = p1 % pow10;  // r = p1 mod 10^(n-1)\n        //\n        //      M+ = buffer * 10^n + (d * 10^(n-1) + r) + p2 * 2^e\n        //         = (buffer * 10 + d) * 10^(n-1) + (r + p2 * 2^e)\n        //\n        JSON_ASSERT(d <= 9);\n        buffer[length++] = static_cast<char>('0' + d); // buffer := buffer * 10 + d\n        //\n        //      M+ = buffer * 10^(n-1) + (r + p2 * 2^e)\n        //\n        p1 = r;\n        n--;\n        //\n        //      M+ = buffer * 10^n + (p1 + p2 * 2^e)\n        //      pow10 = 10^n\n        //\n\n        // Now check if enough digits have been generated.\n        // Compute\n        //\n        //      p1 + p2 * 2^e = (p1 * 2^-e + p2) * 2^e = rest * 2^e\n        //\n        // Note:\n        // Since rest and delta share the same exponent e, it suffices to\n        // compare the significands.\n        const std::uint64_t rest = (std::uint64_t{p1} << -one.e) + p2;\n        if (rest <= delta)\n        {\n            // V = buffer * 10^n, with M- <= V <= M+.\n\n            decimal_exponent += n;\n\n            // We may now just stop. 
But instead look if the buffer could be\n            // decremented to bring V closer to w.\n            //\n            // pow10 = 10^n is now 1 ulp in the decimal representation V.\n            // The rounding procedure works with diyfp's with an implicit\n            // exponent of e.\n            //\n            //      10^n = (10^n * 2^-e) * 2^e = ulp * 2^e\n            //\n            const std::uint64_t ten_n = std::uint64_t{pow10} << -one.e;\n            grisu2_round(buffer, length, dist, delta, rest, ten_n);\n\n            return;\n        }\n\n        pow10 /= 10;\n        //\n        //      pow10 = 10^(n-1) <= p1 < 10^n\n        // Invariants restored.\n    }\n\n    // 2)\n    //\n    // The digits of the integral part have been generated:\n    //\n    //      M+ = d[k-1]...d[1]d[0] + p2 * 2^e\n    //         = buffer            + p2 * 2^e\n    //\n    // Now generate the digits of the fractional part p2 * 2^e.\n    //\n    // Note:\n    // No decimal point is generated: the exponent is adjusted instead.\n    //\n    // p2 actually represents the fraction\n    //\n    //      p2 * 2^e\n    //          = p2 / 2^-e\n    //          = d[-1] / 10^1 + d[-2] / 10^2 + ...\n    //\n    // Now generate the digits d[-m] of p2 from left to right (m = 1,2,...)\n    //\n    //      p2 * 2^e = d[-1]d[-2]...d[-m] * 10^-m\n    //                      + 10^-m * (d[-m-1] / 10^1 + d[-m-2] / 10^2 + ...)\n    //\n    // using\n    //\n    //      10^m * p2 = ((10^m * p2) div 2^-e) * 2^-e + ((10^m * p2) mod 2^-e)\n    //                = (                   d) * 2^-e + (                   r)\n    //\n    // or\n    //      10^m * p2 * 2^e = d + r * 2^e\n    //\n    // i.e.\n    //\n    //      M+ = buffer + p2 * 2^e\n    //         = buffer + 10^-m * (d + r * 2^e)\n    //         = (buffer * 10^m + d) * 10^-m + 10^-m * r * 2^e\n    //\n    // and stop as soon as 10^-m * r * 2^e <= delta * 2^e\n\n    JSON_ASSERT(p2 > delta);\n\n    int m = 0;\n    for (;;)\n    {\n        // 
Invariant:\n        //      M+ = buffer * 10^-m + 10^-m * (d[-m-1] / 10 + d[-m-2] / 10^2 + ...) * 2^e\n        //         = buffer * 10^-m + 10^-m * (p2                                 ) * 2^e\n        //         = buffer * 10^-m + 10^-m * (1/10 * (10 * p2)                   ) * 2^e\n        //         = buffer * 10^-m + 10^-m * (1/10 * ((10*p2 div 2^-e) * 2^-e + (10*p2 mod 2^-e)) * 2^e\n        //\n        JSON_ASSERT(p2 <= (std::numeric_limits<std::uint64_t>::max)() / 10);\n        p2 *= 10;\n        const std::uint64_t d = p2 >> -one.e;     // d = (10 * p2) div 2^-e\n        const std::uint64_t r = p2 & (one.f - 1); // r = (10 * p2) mod 2^-e\n        //\n        //      M+ = buffer * 10^-m + 10^-m * (1/10 * (d * 2^-e + r) * 2^e\n        //         = buffer * 10^-m + 10^-m * (1/10 * (d + r * 2^e))\n        //         = (buffer * 10 + d) * 10^(-m-1) + 10^(-m-1) * r * 2^e\n        //\n        JSON_ASSERT(d <= 9);\n        buffer[length++] = static_cast<char>('0' + d); // buffer := buffer * 10 + d\n        //\n        //      M+ = buffer * 10^(-m-1) + 10^(-m-1) * r * 2^e\n        //\n        p2 = r;\n        m++;\n        //\n        //      M+ = buffer * 10^-m + 10^-m * p2 * 2^e\n        // Invariant restored.\n\n        // Check if enough digits have been generated.\n        //\n        //      10^-m * p2 * 2^e <= delta * 2^e\n        //              p2 * 2^e <= 10^m * delta * 2^e\n        //                    p2 <= 10^m * delta\n        delta *= 10;\n        dist  *= 10;\n        if (p2 <= delta)\n        {\n            break;\n        }\n    }\n\n    // V = buffer * 10^-m, with M- <= V <= M+.\n\n    decimal_exponent -= m;\n\n    // 1 ulp in the decimal representation is now 10^-m.\n    // Since delta and dist are now scaled by 10^m, we need to do the\n    // same with ulp in order to keep the units in sync.\n    //\n    //      10^m * 10^-m = 1 = 2^-e * 2^e = ten_m * 2^e\n    //\n    const std::uint64_t ten_m = one.f;\n    grisu2_round(buffer, length, dist, 
delta, p2, ten_m);\n\n    // By construction this algorithm generates the shortest possible decimal\n    // number (Loitsch, Theorem 6.2) which rounds back to w.\n    // For an input number of precision p, at least\n    //\n    //      N = 1 + ceil(p * log_10(2))\n    //\n    // decimal digits are sufficient to identify all binary floating-point\n    // numbers (Matula, \"In-and-Out conversions\").\n    // This implies that the algorithm does not produce more than N decimal\n    // digits.\n    //\n    //      N = 17 for p = 53 (IEEE double precision)\n    //      N = 9  for p = 24 (IEEE single precision)\n}\n\n/*!\nv = buf * 10^decimal_exponent\nlen is the length of the buffer (number of decimal digits)\nThe buffer must be large enough, i.e. >= max_digits10.\n*/\nJSON_HEDLEY_NON_NULL(1)\ninline void grisu2(char* buf, int& len, int& decimal_exponent,\n                   diyfp m_minus, diyfp v, diyfp m_plus)\n{\n    JSON_ASSERT(m_plus.e == m_minus.e);\n    JSON_ASSERT(m_plus.e == v.e);\n\n    //  --------(-----------------------+-----------------------)--------    (A)\n    //          m-                      v                       m+\n    //\n    //  --------------------(-----------+-----------------------)--------    (B)\n    //                      m-          v                       m+\n    //\n    // First scale v (and m- and m+) such that the exponent is in the range\n    // [alpha, gamma].\n\n    const cached_power cached = get_cached_power_for_binary_exponent(m_plus.e);\n\n    const diyfp c_minus_k(cached.f, cached.e); // = c ~= 10^-k\n\n    // The exponent of the products is = v.e + c_minus_k.e + q and is in the range [alpha,gamma]\n    const diyfp w       = diyfp::mul(v,       c_minus_k);\n    const diyfp w_minus = diyfp::mul(m_minus, c_minus_k);\n    const diyfp w_plus  = diyfp::mul(m_plus,  c_minus_k);\n\n    //  ----(---+---)---------------(---+---)---------------(---+---)----\n    //          w-                      w                       w+\n    //   
       = c*m-                  = c*v                   = c*m+\n    //\n    // diyfp::mul rounds its result and c_minus_k is approximated too. w, w- and\n    // w+ are now off by a small amount.\n    // In fact:\n    //\n    //      w - v * 10^k < 1 ulp\n    //\n    // To account for this inaccuracy, add resp. subtract 1 ulp.\n    //\n    //  --------+---[---------------(---+---)---------------]---+--------\n    //          w-  M-                  w                   M+  w+\n    //\n    // Now any number in [M-, M+] (bounds included) will round to w when input,\n    // regardless of how the input rounding algorithm breaks ties.\n    //\n    // And digit_gen generates the shortest possible such number in [M-, M+].\n    // Note that this does not mean that Grisu2 always generates the shortest\n    // possible number in the interval (m-, m+).\n    const diyfp M_minus(w_minus.f + 1, w_minus.e);\n    const diyfp M_plus (w_plus.f  - 1, w_plus.e );\n\n    decimal_exponent = -cached.k; // = -(-k) = k\n\n    grisu2_digit_gen(buf, len, decimal_exponent, M_minus, w, M_plus);\n}\n\n/*!\nv = buf * 10^decimal_exponent\nlen is the length of the buffer (number of decimal digits)\nThe buffer must be large enough, i.e. >= max_digits10.\n*/\ntemplate<typename FloatType>\nJSON_HEDLEY_NON_NULL(1)\nvoid grisu2(char* buf, int& len, int& decimal_exponent, FloatType value)\n{\n    static_assert(diyfp::kPrecision >= std::numeric_limits<FloatType>::digits + 3,\n                  \"internal error: not enough precision\");\n\n    JSON_ASSERT(std::isfinite(value));\n    JSON_ASSERT(value > 0);\n\n    // If the neighbors (and boundaries) of 'value' are always computed for double-precision\n    // numbers, all float's can be recovered using strtod (and strtof). 
However, the resulting\n    // decimal representations are not exactly \"short\".\n    //\n    // The documentation for 'std::to_chars' (https://en.cppreference.com/w/cpp/utility/to_chars)\n    // says \"value is converted to a string as if by std::sprintf in the default (\"C\") locale\"\n    // and since sprintf promotes floats to doubles, I think this is exactly what 'std::to_chars'\n    // does.\n    // On the other hand, the documentation for 'std::to_chars' requires that \"parsing the\n    // representation using the corresponding std::from_chars function recovers value exactly\". That\n    // indicates that single precision floating-point numbers should be recovered using\n    // 'std::strtof'.\n    //\n    // NB: If the neighbors are computed for single-precision numbers, there is a single float\n    //     (7.0385307e-26f) which can't be recovered using strtod. The resulting double precision\n    //     value is off by 1 ulp.\n#if 0 // NOLINT(readability-avoid-unconditional-preprocessor-if)\n    const boundaries w = compute_boundaries(static_cast<double>(value));\n#else\n    const boundaries w = compute_boundaries(value);\n#endif\n\n    grisu2(buf, len, decimal_exponent, w.minus, w.w, w.plus);\n}\n\n/*!\n@brief appends a decimal representation of e to buf\n@return a pointer to the element following the exponent.\n@pre -1000 < e < 1000\n*/\nJSON_HEDLEY_NON_NULL(1)\nJSON_HEDLEY_RETURNS_NON_NULL\ninline char* append_exponent(char* buf, int e)\n{\n    JSON_ASSERT(e > -1000);\n    JSON_ASSERT(e <  1000);\n\n    if (e < 0)\n    {\n        e = -e;\n        *buf++ = '-';\n    }\n    else\n    {\n        *buf++ = '+';\n    }\n\n    auto k = static_cast<std::uint32_t>(e);\n    if (k < 10)\n    {\n        // Always print at least two digits in the exponent.\n        // This is for compatibility with printf(\"%g\").\n        *buf++ = '0';\n        *buf++ = static_cast<char>('0' + k);\n    }\n    else if (k < 100)\n    {\n        *buf++ = static_cast<char>('0' + k / 
10);\n        k %= 10;\n        *buf++ = static_cast<char>('0' + k);\n    }\n    else\n    {\n        *buf++ = static_cast<char>('0' + k / 100);\n        k %= 100;\n        *buf++ = static_cast<char>('0' + k / 10);\n        k %= 10;\n        *buf++ = static_cast<char>('0' + k);\n    }\n\n    return buf;\n}\n\n/*!\n@brief prettify v = buf * 10^decimal_exponent\n\nIf v is in the range [10^min_exp, 10^max_exp) it will be printed in fixed-point\nnotation. Otherwise it will be printed in exponential notation.\n\n@pre min_exp < 0\n@pre max_exp > 0\n*/\nJSON_HEDLEY_NON_NULL(1)\nJSON_HEDLEY_RETURNS_NON_NULL\ninline char* format_buffer(char* buf, int len, int decimal_exponent,\n                           int min_exp, int max_exp)\n{\n    JSON_ASSERT(min_exp < 0);\n    JSON_ASSERT(max_exp > 0);\n\n    const int k = len;\n    const int n = len + decimal_exponent;\n\n    // v = buf * 10^(n-k)\n    // k is the length of the buffer (number of decimal digits)\n    // n is the position of the decimal point relative to the start of the buffer.\n\n    if (k <= n && n <= max_exp)\n    {\n        // digits[000]\n        // len <= max_exp + 2\n\n        std::memset(buf + k, '0', static_cast<size_t>(n) - static_cast<size_t>(k));\n        // Make it look like a floating-point number (#362, #378)\n        buf[n + 0] = '.';\n        buf[n + 1] = '0';\n        return buf + (static_cast<size_t>(n) + 2);\n    }\n\n    if (0 < n && n <= max_exp)\n    {\n        // dig.its\n        // len <= max_digits10 + 1\n\n        JSON_ASSERT(k > n);\n\n        std::memmove(buf + (static_cast<size_t>(n) + 1), buf + n, static_cast<size_t>(k) - static_cast<size_t>(n));\n        buf[n] = '.';\n        return buf + (static_cast<size_t>(k) + 1U);\n    }\n\n    if (min_exp < n && n <= 0)\n    {\n        // 0.[000]digits\n        // len <= 2 + (-min_exp - 1) + max_digits10\n\n        std::memmove(buf + (2 + static_cast<size_t>(-n)), buf, static_cast<size_t>(k));\n        buf[0] = '0';\n        buf[1] = '.';\n     
   std::memset(buf + 2, '0', static_cast<size_t>(-n));\n        return buf + (2U + static_cast<size_t>(-n) + static_cast<size_t>(k));\n    }\n\n    if (k == 1)\n    {\n        // dE+123\n        // len <= 1 + 5\n\n        buf += 1;\n    }\n    else\n    {\n        // d.igitsE+123\n        // len <= max_digits10 + 1 + 5\n\n        std::memmove(buf + 2, buf + 1, static_cast<size_t>(k) - 1);\n        buf[1] = '.';\n        buf += 1 + static_cast<size_t>(k);\n    }\n\n    *buf++ = 'e';\n    return append_exponent(buf, n - 1);\n}\n\n}  // namespace dtoa_impl\n\n/*!\n@brief generates a decimal representation of the floating-point number value in [first, last).\n\nThe format of the resulting decimal representation is similar to printf's %g\nformat. Returns an iterator pointing past-the-end of the decimal representation.\n\n@note The input number must be finite, i.e. NaN's and Inf's are not supported.\n@note The buffer must be large enough.\n@note The result is NOT null-terminated.\n*/\ntemplate<typename FloatType>\nJSON_HEDLEY_NON_NULL(1, 2)\nJSON_HEDLEY_RETURNS_NON_NULL\nchar* to_chars(char* first, const char* last, FloatType value)\n{\n    static_cast<void>(last); // maybe unused - fix warning\n    JSON_ASSERT(std::isfinite(value));\n\n    // Use signbit(value) instead of (value < 0) since signbit works for -0.\n    if (std::signbit(value))\n    {\n        value = -value;\n        *first++ = '-';\n    }\n\n#ifdef __GNUC__\n#pragma GCC diagnostic push\n#pragma GCC diagnostic ignored \"-Wfloat-equal\"\n#endif\n    if (value == 0) // +-0\n    {\n        *first++ = '0';\n        // Make it look like a floating-point number (#362, #378)\n        *first++ = '.';\n        *first++ = '0';\n        return first;\n    }\n#ifdef __GNUC__\n#pragma GCC diagnostic pop\n#endif\n\n    JSON_ASSERT(last - first >= std::numeric_limits<FloatType>::max_digits10);\n\n    // Compute v = buffer * 10^decimal_exponent.\n    // The decimal digits are stored in the buffer, which needs to be 
interpreted\n    // as an unsigned decimal integer.\n    // len is the length of the buffer, i.e. the number of decimal digits.\n    int len = 0;\n    int decimal_exponent = 0;\n    dtoa_impl::grisu2(first, len, decimal_exponent, value);\n\n    JSON_ASSERT(len <= std::numeric_limits<FloatType>::max_digits10);\n\n    // Format the buffer like printf(\"%.*g\", prec, value)\n    constexpr int kMinExp = -4;\n    // Use digits10 here to increase compatibility with version 2.\n    constexpr int kMaxExp = std::numeric_limits<FloatType>::digits10;\n\n    JSON_ASSERT(last - first >= kMaxExp + 2);\n    JSON_ASSERT(last - first >= 2 + (-kMinExp - 1) + std::numeric_limits<FloatType>::max_digits10);\n    JSON_ASSERT(last - first >= std::numeric_limits<FloatType>::max_digits10 + 6);\n\n    return dtoa_impl::format_buffer(first, len, decimal_exponent, kMinExp, kMaxExp);\n}\n\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include <nlohmann/detail/exceptions.hpp>\n\n// #include <nlohmann/detail/macro_scope.hpp>\n\n// #include <nlohmann/detail/meta/cpp_future.hpp>\n\n// #include <nlohmann/detail/output/binary_writer.hpp>\n\n// #include <nlohmann/detail/output/output_adapters.hpp>\n\n// #include <nlohmann/detail/string_concat.hpp>\n\n// #include <nlohmann/detail/value_t.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\nnamespace detail\n{\n\n///////////////////\n// serialization //\n///////////////////\n\n/// how to treat decoding errors\nenum class error_handler_t\n{\n    strict,  ///< throw a type_error exception in case of invalid UTF-8\n    replace, ///< replace invalid UTF-8 sequences with U+FFFD\n    ignore   ///< ignore invalid UTF-8 sequences\n};\n\ntemplate<typename BasicJsonType>\nclass serializer\n{\n    using string_t = typename BasicJsonType::string_t;\n    using number_float_t = typename BasicJsonType::number_float_t;\n    using number_integer_t = typename BasicJsonType::number_integer_t;\n    using number_unsigned_t = typename BasicJsonType::number_unsigned_t;\n    
using binary_char_t = typename BasicJsonType::binary_t::value_type;\n    static constexpr std::uint8_t UTF8_ACCEPT = 0;\n    static constexpr std::uint8_t UTF8_REJECT = 1;\n\n  public:\n    /*!\n    @param[in] s  output stream to serialize to\n    @param[in] ichar  indentation character to use\n    @param[in] error_handler_  how to react to decoding errors\n    */\n    serializer(output_adapter_t<char> s, const char ichar,\n               error_handler_t error_handler_ = error_handler_t::strict)\n        : o(std::move(s))\n        , loc(std::localeconv())\n        , thousands_sep(loc->thousands_sep == nullptr ? '\\0' : std::char_traits<char>::to_char_type(* (loc->thousands_sep)))\n        , decimal_point(loc->decimal_point == nullptr ? '\\0' : std::char_traits<char>::to_char_type(* (loc->decimal_point)))\n        , indent_char(ichar)\n        , indent_string(512, indent_char)\n        , error_handler(error_handler_)\n    {}\n\n    // delete because of pointer members\n    serializer(const serializer&) = delete;\n    serializer& operator=(const serializer&) = delete;\n    serializer(serializer&&) = delete;\n    serializer& operator=(serializer&&) = delete;\n    ~serializer() = default;\n\n    /*!\n    @brief internal implementation of the serialization function\n\n    This function is called by the public member function dump and organizes\n    the serialization internally. The indentation level is propagated as\n    an additional parameter. 
In case of arrays and objects, the function is\n    called recursively.\n\n    - strings and object keys are escaped using `dump_escaped()`\n    - integer numbers are converted via `dump_integer()`\n    - floating-point numbers are converted to a string via `dump_float()`, in a format similar to `\"%g\"`\n    - binary values are serialized as objects containing the subtype and the\n      byte array\n\n    @param[in] val               value to serialize\n    @param[in] pretty_print      whether the output shall be pretty-printed\n    @param[in] ensure_ascii If @a ensure_ascii is true, all non-ASCII characters\n    in the output are escaped with `\\uXXXX` sequences, and the result consists\n    of ASCII characters only.\n    @param[in] indent_step       the number of indentation characters per nesting level\n    @param[in] current_indent    the current indent level (only used internally)\n    */\n    void dump(const BasicJsonType& val,\n              const bool pretty_print,\n              const bool ensure_ascii,\n              const unsigned int indent_step,\n              const unsigned int current_indent = 0)\n    {\n        switch (val.m_data.m_type)\n        {\n            case value_t::object:\n            {\n                if (val.m_data.m_value.object->empty())\n                {\n                    o->write_characters(\"{}\", 2);\n                    return;\n                }\n\n                if (pretty_print)\n                {\n                    o->write_characters(\"{\\n\", 2);\n\n                    // variable to hold indentation for recursive calls\n                    const auto new_indent = current_indent + indent_step;\n                    if (JSON_HEDLEY_UNLIKELY(indent_string.size() < new_indent))\n                    {\n                        indent_string.resize(indent_string.size() * 2, ' ');\n                    }\n\n                    // first n-1 elements\n                    auto i = val.m_data.m_value.object->cbegin();\n                    for (std::size_t cnt = 0; cnt < 
val.m_data.m_value.object->size() - 1; ++cnt, ++i)\n                    {\n                        o->write_characters(indent_string.c_str(), new_indent);\n                        o->write_character('\\\"');\n                        dump_escaped(i->first, ensure_ascii);\n                        o->write_characters(\"\\\": \", 3);\n                        dump(i->second, true, ensure_ascii, indent_step, new_indent);\n                        o->write_characters(\",\\n\", 2);\n                    }\n\n                    // last element\n                    JSON_ASSERT(i != val.m_data.m_value.object->cend());\n                    JSON_ASSERT(std::next(i) == val.m_data.m_value.object->cend());\n                    o->write_characters(indent_string.c_str(), new_indent);\n                    o->write_character('\\\"');\n                    dump_escaped(i->first, ensure_ascii);\n                    o->write_characters(\"\\\": \", 3);\n                    dump(i->second, true, ensure_ascii, indent_step, new_indent);\n\n                    o->write_character('\\n');\n                    o->write_characters(indent_string.c_str(), current_indent);\n                    o->write_character('}');\n                }\n                else\n                {\n                    o->write_character('{');\n\n                    // first n-1 elements\n                    auto i = val.m_data.m_value.object->cbegin();\n                    for (std::size_t cnt = 0; cnt < val.m_data.m_value.object->size() - 1; ++cnt, ++i)\n                    {\n                        o->write_character('\\\"');\n                        dump_escaped(i->first, ensure_ascii);\n                        o->write_characters(\"\\\":\", 2);\n                        dump(i->second, false, ensure_ascii, indent_step, current_indent);\n                        o->write_character(',');\n                    }\n\n                    // last element\n                    JSON_ASSERT(i != 
val.m_data.m_value.object->cend());\n                    JSON_ASSERT(std::next(i) == val.m_data.m_value.object->cend());\n                    o->write_character('\\\"');\n                    dump_escaped(i->first, ensure_ascii);\n                    o->write_characters(\"\\\":\", 2);\n                    dump(i->second, false, ensure_ascii, indent_step, current_indent);\n\n                    o->write_character('}');\n                }\n\n                return;\n            }\n\n            case value_t::array:\n            {\n                if (val.m_data.m_value.array->empty())\n                {\n                    o->write_characters(\"[]\", 2);\n                    return;\n                }\n\n                if (pretty_print)\n                {\n                    o->write_characters(\"[\\n\", 2);\n\n                    // variable to hold indentation for recursive calls\n                    const auto new_indent = current_indent + indent_step;\n                    if (JSON_HEDLEY_UNLIKELY(indent_string.size() < new_indent))\n                    {\n                        indent_string.resize(indent_string.size() * 2, ' ');\n                    }\n\n                    // first n-1 elements\n                    for (auto i = val.m_data.m_value.array->cbegin();\n                            i != val.m_data.m_value.array->cend() - 1; ++i)\n                    {\n                        o->write_characters(indent_string.c_str(), new_indent);\n                        dump(*i, true, ensure_ascii, indent_step, new_indent);\n                        o->write_characters(\",\\n\", 2);\n                    }\n\n                    // last element\n                    JSON_ASSERT(!val.m_data.m_value.array->empty());\n                    o->write_characters(indent_string.c_str(), new_indent);\n                    dump(val.m_data.m_value.array->back(), true, ensure_ascii, indent_step, new_indent);\n\n                    o->write_character('\\n');\n                    
o->write_characters(indent_string.c_str(), current_indent);\n                    o->write_character(']');\n                }\n                else\n                {\n                    o->write_character('[');\n\n                    // first n-1 elements\n                    for (auto i = val.m_data.m_value.array->cbegin();\n                            i != val.m_data.m_value.array->cend() - 1; ++i)\n                    {\n                        dump(*i, false, ensure_ascii, indent_step, current_indent);\n                        o->write_character(',');\n                    }\n\n                    // last element\n                    JSON_ASSERT(!val.m_data.m_value.array->empty());\n                    dump(val.m_data.m_value.array->back(), false, ensure_ascii, indent_step, current_indent);\n\n                    o->write_character(']');\n                }\n\n                return;\n            }\n\n            case value_t::string:\n            {\n                o->write_character('\\\"');\n                dump_escaped(*val.m_data.m_value.string, ensure_ascii);\n                o->write_character('\\\"');\n                return;\n            }\n\n            case value_t::binary:\n            {\n                if (pretty_print)\n                {\n                    o->write_characters(\"{\\n\", 2);\n\n                    // variable to hold indentation for recursive calls\n                    const auto new_indent = current_indent + indent_step;\n                    if (JSON_HEDLEY_UNLIKELY(indent_string.size() < new_indent))\n                    {\n                        indent_string.resize(indent_string.size() * 2, ' ');\n                    }\n\n                    o->write_characters(indent_string.c_str(), new_indent);\n\n                    o->write_characters(\"\\\"bytes\\\": [\", 10);\n\n                    if (!val.m_data.m_value.binary->empty())\n                    {\n                        for (auto i = 
val.m_data.m_value.binary->cbegin();\n                                i != val.m_data.m_value.binary->cend() - 1; ++i)\n                        {\n                            dump_integer(*i);\n                            o->write_characters(\", \", 2);\n                        }\n                        dump_integer(val.m_data.m_value.binary->back());\n                    }\n\n                    o->write_characters(\"],\\n\", 3);\n                    o->write_characters(indent_string.c_str(), new_indent);\n\n                    o->write_characters(\"\\\"subtype\\\": \", 11);\n                    if (val.m_data.m_value.binary->has_subtype())\n                    {\n                        dump_integer(val.m_data.m_value.binary->subtype());\n                    }\n                    else\n                    {\n                        o->write_characters(\"null\", 4);\n                    }\n                    o->write_character('\\n');\n                    o->write_characters(indent_string.c_str(), current_indent);\n                    o->write_character('}');\n                }\n                else\n                {\n                    o->write_characters(\"{\\\"bytes\\\":[\", 10);\n\n                    if (!val.m_data.m_value.binary->empty())\n                    {\n                        for (auto i = val.m_data.m_value.binary->cbegin();\n                                i != val.m_data.m_value.binary->cend() - 1; ++i)\n                        {\n                            dump_integer(*i);\n                            o->write_character(',');\n                        }\n                        dump_integer(val.m_data.m_value.binary->back());\n                    }\n\n                    o->write_characters(\"],\\\"subtype\\\":\", 12);\n                    if (val.m_data.m_value.binary->has_subtype())\n                    {\n                        dump_integer(val.m_data.m_value.binary->subtype());\n                        o->write_character('}');\n     
               }\n                    else\n                    {\n                        o->write_characters(\"null}\", 5);\n                    }\n                }\n                return;\n            }\n\n            case value_t::boolean:\n            {\n                if (val.m_data.m_value.boolean)\n                {\n                    o->write_characters(\"true\", 4);\n                }\n                else\n                {\n                    o->write_characters(\"false\", 5);\n                }\n                return;\n            }\n\n            case value_t::number_integer:\n            {\n                dump_integer(val.m_data.m_value.number_integer);\n                return;\n            }\n\n            case value_t::number_unsigned:\n            {\n                dump_integer(val.m_data.m_value.number_unsigned);\n                return;\n            }\n\n            case value_t::number_float:\n            {\n                dump_float(val.m_data.m_value.number_float);\n                return;\n            }\n\n            case value_t::discarded:\n            {\n                o->write_characters(\"<discarded>\", 11);\n                return;\n            }\n\n            case value_t::null:\n            {\n                o->write_characters(\"null\", 4);\n                return;\n            }\n\n            default:            // LCOV_EXCL_LINE\n                JSON_ASSERT(false); // NOLINT(cert-dcl03-c,hicpp-static-assert,misc-static-assert) LCOV_EXCL_LINE\n        }\n    }\n\n  JSON_PRIVATE_UNLESS_TESTED:\n    /*!\n    @brief dump escaped string\n\n    Escape a string by replacing certain special characters by a sequence of an\n    escape character (backslash) and another character and other control\n    characters by a sequence of \"\\u\" followed by a four-digit hex\n    representation. 
The escaped string is written to output stream @a o.\n\n    @param[in] s  the string to escape\n    @param[in] ensure_ascii  whether to escape non-ASCII characters with\n                             \\uXXXX sequences\n\n    @complexity Linear in the length of string @a s.\n    */\n    void dump_escaped(const string_t& s, const bool ensure_ascii)\n    {\n        std::uint32_t codepoint{};\n        std::uint8_t state = UTF8_ACCEPT;\n        std::size_t bytes = 0;  // number of bytes written to string_buffer\n\n        // number of bytes written at the point of the last valid byte\n        std::size_t bytes_after_last_accept = 0;\n        std::size_t undumped_chars = 0;\n\n        for (std::size_t i = 0; i < s.size(); ++i)\n        {\n            const auto byte = static_cast<std::uint8_t>(s[i]);\n\n            switch (decode(state, codepoint, byte))\n            {\n                case UTF8_ACCEPT:  // decode found a new code point\n                {\n                    switch (codepoint)\n                    {\n                        case 0x08: // backspace\n                        {\n                            string_buffer[bytes++] = '\\\\';\n                            string_buffer[bytes++] = 'b';\n                            break;\n                        }\n\n                        case 0x09: // horizontal tab\n                        {\n                            string_buffer[bytes++] = '\\\\';\n                            string_buffer[bytes++] = 't';\n                            break;\n                        }\n\n                        case 0x0A: // newline\n                        {\n                            string_buffer[bytes++] = '\\\\';\n                            string_buffer[bytes++] = 'n';\n                            break;\n                        }\n\n                        case 0x0C: // formfeed\n                        {\n                            string_buffer[bytes++] = '\\\\';\n                            
string_buffer[bytes++] = 'f';\n                            break;\n                        }\n\n                        case 0x0D: // carriage return\n                        {\n                            string_buffer[bytes++] = '\\\\';\n                            string_buffer[bytes++] = 'r';\n                            break;\n                        }\n\n                        case 0x22: // quotation mark\n                        {\n                            string_buffer[bytes++] = '\\\\';\n                            string_buffer[bytes++] = '\\\"';\n                            break;\n                        }\n\n                        case 0x5C: // reverse solidus\n                        {\n                            string_buffer[bytes++] = '\\\\';\n                            string_buffer[bytes++] = '\\\\';\n                            break;\n                        }\n\n                        default:\n                        {\n                            // escape control characters (0x00..0x1F) or, if\n                            // ensure_ascii parameter is used, non-ASCII characters\n                            if ((codepoint <= 0x1F) || (ensure_ascii && (codepoint >= 0x7F)))\n                            {\n                                if (codepoint <= 0xFFFF)\n                                {\n                                    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-vararg,hicpp-vararg)\n                                    static_cast<void>((std::snprintf)(string_buffer.data() + bytes, 7, \"\\\\u%04x\",\n                                                                      static_cast<std::uint16_t>(codepoint)));\n                                    bytes += 6;\n                                }\n                                else\n                                {\n                                    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-vararg,hicpp-vararg)\n                                    
static_cast<void>((std::snprintf)(string_buffer.data() + bytes, 13, \"\\\\u%04x\\\\u%04x\",\n                                                                      static_cast<std::uint16_t>(0xD7C0u + (codepoint >> 10u)),\n                                                                      static_cast<std::uint16_t>(0xDC00u + (codepoint & 0x3FFu))));\n                                    bytes += 12;\n                                }\n                            }\n                            else\n                            {\n                                // copy byte to buffer (all previous bytes\n                                // have been copied in the default case above)\n                                string_buffer[bytes++] = s[i];\n                            }\n                            break;\n                        }\n                    }\n\n                    // write buffer and reset index; there must be 13 bytes\n                    // left, as this is the maximal number of bytes to be\n                    // written (\"\\uxxxx\\uxxxx\\0\") for one code point\n                    if (string_buffer.size() - bytes < 13)\n                    {\n                        o->write_characters(string_buffer.data(), bytes);\n                        bytes = 0;\n                    }\n\n                    // remember the byte position of this accept\n                    bytes_after_last_accept = bytes;\n                    undumped_chars = 0;\n                    break;\n                }\n\n                case UTF8_REJECT:  // decode found invalid UTF-8 byte\n                {\n                    switch (error_handler)\n                    {\n                        case error_handler_t::strict:\n                        {\n                            JSON_THROW(type_error::create(316, concat(\"invalid UTF-8 byte at index \", std::to_string(i), \": 0x\", hex_bytes(byte | 0)), nullptr));\n                        }\n\n                        case 
error_handler_t::ignore:\n                        case error_handler_t::replace:\n                        {\n                            // in case we saw this character the first time, we\n                            // would like to read it again, because the byte\n                            // may be OK for itself, but just not OK for the\n                            // previous sequence\n                            if (undumped_chars > 0)\n                            {\n                                --i;\n                            }\n\n                            // reset length buffer to the last accepted index;\n                            // thus removing/ignoring the invalid characters\n                            bytes = bytes_after_last_accept;\n\n                            if (error_handler == error_handler_t::replace)\n                            {\n                                // add a replacement character\n                                if (ensure_ascii)\n                                {\n                                    string_buffer[bytes++] = '\\\\';\n                                    string_buffer[bytes++] = 'u';\n                                    string_buffer[bytes++] = 'f';\n                                    string_buffer[bytes++] = 'f';\n                                    string_buffer[bytes++] = 'f';\n                                    string_buffer[bytes++] = 'd';\n                                }\n                                else\n                                {\n                                    string_buffer[bytes++] = detail::binary_writer<BasicJsonType, char>::to_char_type('\\xEF');\n                                    string_buffer[bytes++] = detail::binary_writer<BasicJsonType, char>::to_char_type('\\xBF');\n                                    string_buffer[bytes++] = detail::binary_writer<BasicJsonType, char>::to_char_type('\\xBD');\n                                }\n\n                                
// write buffer and reset index; there must be 13 bytes\n                                // left, as this is the maximal number of bytes to be\n                                // written (\"\\uxxxx\\uxxxx\\0\") for one code point\n                                if (string_buffer.size() - bytes < 13)\n                                {\n                                    o->write_characters(string_buffer.data(), bytes);\n                                    bytes = 0;\n                                }\n\n                                bytes_after_last_accept = bytes;\n                            }\n\n                            undumped_chars = 0;\n\n                            // continue processing the string\n                            state = UTF8_ACCEPT;\n                            break;\n                        }\n\n                        default:            // LCOV_EXCL_LINE\n                            JSON_ASSERT(false); // NOLINT(cert-dcl03-c,hicpp-static-assert,misc-static-assert) LCOV_EXCL_LINE\n                    }\n                    break;\n                }\n\n                default:  // decode found yet incomplete multi-byte code point\n                {\n                    if (!ensure_ascii)\n                    {\n                        // code point will not be escaped - copy byte to buffer\n                        string_buffer[bytes++] = s[i];\n                    }\n                    ++undumped_chars;\n                    break;\n                }\n            }\n        }\n\n        // we finished processing the string\n        if (JSON_HEDLEY_LIKELY(state == UTF8_ACCEPT))\n        {\n            // write buffer\n            if (bytes > 0)\n            {\n                o->write_characters(string_buffer.data(), bytes);\n            }\n        }\n        else\n        {\n            // we finish reading, but do not accept: string was incomplete\n            switch (error_handler)\n            {\n                case 
error_handler_t::strict:\n                {\n                    JSON_THROW(type_error::create(316, concat(\"incomplete UTF-8 string; last byte: 0x\", hex_bytes(static_cast<std::uint8_t>(s.back() | 0))), nullptr));\n                }\n\n                case error_handler_t::ignore:\n                {\n                    // write all accepted bytes\n                    o->write_characters(string_buffer.data(), bytes_after_last_accept);\n                    break;\n                }\n\n                case error_handler_t::replace:\n                {\n                    // write all accepted bytes\n                    o->write_characters(string_buffer.data(), bytes_after_last_accept);\n                    // add a replacement character\n                    if (ensure_ascii)\n                    {\n                        o->write_characters(\"\\\\ufffd\", 6);\n                    }\n                    else\n                    {\n                        o->write_characters(\"\\xEF\\xBF\\xBD\", 3);\n                    }\n                    break;\n                }\n\n                default:            // LCOV_EXCL_LINE\n                    JSON_ASSERT(false); // NOLINT(cert-dcl03-c,hicpp-static-assert,misc-static-assert) LCOV_EXCL_LINE\n            }\n        }\n    }\n\n  private:\n    /*!\n    @brief count digits\n\n    Count the number of decimal (base 10) digits for an input unsigned integer.\n\n    @param[in] x  unsigned integer number to count its digits\n    @return    number of decimal digits\n    */\n    inline unsigned int count_digits(number_unsigned_t x) noexcept\n    {\n        unsigned int n_digits = 1;\n        for (;;)\n        {\n            if (x < 10)\n            {\n                return n_digits;\n            }\n            if (x < 100)\n            {\n                return n_digits + 1;\n            }\n            if (x < 1000)\n            {\n                return n_digits + 2;\n            }\n            if (x < 10000)\n            
{\n                return n_digits + 3;\n            }\n            x = x / 10000u;\n            n_digits += 4;\n        }\n    }\n\n    /*!\n     * @brief convert a byte to an uppercase hex representation\n     * @param[in] byte byte to represent\n     * @return representation (\"00\"..\"FF\")\n     */\n    static std::string hex_bytes(std::uint8_t byte)\n    {\n        std::string result = \"FF\";\n        constexpr const char* nibble_to_hex = \"0123456789ABCDEF\";\n        result[0] = nibble_to_hex[byte / 16];\n        result[1] = nibble_to_hex[byte % 16];\n        return result;\n    }\n\n    // templates to avoid warnings about useless casts\n    template <typename NumberType, enable_if_t<std::is_signed<NumberType>::value, int> = 0>\n    bool is_negative_number(NumberType x)\n    {\n        return x < 0;\n    }\n\n    template < typename NumberType, enable_if_t <std::is_unsigned<NumberType>::value, int > = 0 >\n    bool is_negative_number(NumberType /*unused*/)\n    {\n        return false;\n    }\n\n    /*!\n    @brief dump an integer\n\n    Dump a given integer to output stream @a o. 
Works internally with\n    @a number_buffer.\n\n    @param[in] x  integer number (signed or unsigned) to dump\n    @tparam NumberType either @a number_integer_t or @a number_unsigned_t\n    */\n    template < typename NumberType, detail::enable_if_t <\n                   std::is_integral<NumberType>::value ||\n                   std::is_same<NumberType, number_unsigned_t>::value ||\n                   std::is_same<NumberType, number_integer_t>::value ||\n                   std::is_same<NumberType, binary_char_t>::value,\n                   int > = 0 >\n    void dump_integer(NumberType x)\n    {\n        static constexpr std::array<std::array<char, 2>, 100> digits_to_99\n        {\n            {\n                {{'0', '0'}}, {{'0', '1'}}, {{'0', '2'}}, {{'0', '3'}}, {{'0', '4'}}, {{'0', '5'}}, {{'0', '6'}}, {{'0', '7'}}, {{'0', '8'}}, {{'0', '9'}},\n                {{'1', '0'}}, {{'1', '1'}}, {{'1', '2'}}, {{'1', '3'}}, {{'1', '4'}}, {{'1', '5'}}, {{'1', '6'}}, {{'1', '7'}}, {{'1', '8'}}, {{'1', '9'}},\n                {{'2', '0'}}, {{'2', '1'}}, {{'2', '2'}}, {{'2', '3'}}, {{'2', '4'}}, {{'2', '5'}}, {{'2', '6'}}, {{'2', '7'}}, {{'2', '8'}}, {{'2', '9'}},\n                {{'3', '0'}}, {{'3', '1'}}, {{'3', '2'}}, {{'3', '3'}}, {{'3', '4'}}, {{'3', '5'}}, {{'3', '6'}}, {{'3', '7'}}, {{'3', '8'}}, {{'3', '9'}},\n                {{'4', '0'}}, {{'4', '1'}}, {{'4', '2'}}, {{'4', '3'}}, {{'4', '4'}}, {{'4', '5'}}, {{'4', '6'}}, {{'4', '7'}}, {{'4', '8'}}, {{'4', '9'}},\n                {{'5', '0'}}, {{'5', '1'}}, {{'5', '2'}}, {{'5', '3'}}, {{'5', '4'}}, {{'5', '5'}}, {{'5', '6'}}, {{'5', '7'}}, {{'5', '8'}}, {{'5', '9'}},\n                {{'6', '0'}}, {{'6', '1'}}, {{'6', '2'}}, {{'6', '3'}}, {{'6', '4'}}, {{'6', '5'}}, {{'6', '6'}}, {{'6', '7'}}, {{'6', '8'}}, {{'6', '9'}},\n                {{'7', '0'}}, {{'7', '1'}}, {{'7', '2'}}, {{'7', '3'}}, {{'7', '4'}}, {{'7', '5'}}, {{'7', '6'}}, {{'7', '7'}}, {{'7', '8'}}, {{'7', '9'}},\n                {{'8', '0'}}, {{'8', 
'1'}}, {{'8', '2'}}, {{'8', '3'}}, {{'8', '4'}}, {{'8', '5'}}, {{'8', '6'}}, {{'8', '7'}}, {{'8', '8'}}, {{'8', '9'}},\n                {{'9', '0'}}, {{'9', '1'}}, {{'9', '2'}}, {{'9', '3'}}, {{'9', '4'}}, {{'9', '5'}}, {{'9', '6'}}, {{'9', '7'}}, {{'9', '8'}}, {{'9', '9'}},\n            }\n        };\n\n        // special case for \"0\"\n        if (x == 0)\n        {\n            o->write_character('0');\n            return;\n        }\n\n        // use a pointer to fill the buffer\n        auto buffer_ptr = number_buffer.begin(); // NOLINT(llvm-qualified-auto,readability-qualified-auto,cppcoreguidelines-pro-type-vararg,hicpp-vararg)\n\n        number_unsigned_t abs_value;\n\n        unsigned int n_chars{};\n\n        if (is_negative_number(x))\n        {\n            *buffer_ptr = '-';\n            abs_value = remove_sign(static_cast<number_integer_t>(x));\n\n            // account one more byte for the minus sign\n            n_chars = 1 + count_digits(abs_value);\n        }\n        else\n        {\n            abs_value = static_cast<number_unsigned_t>(x);\n            n_chars = count_digits(abs_value);\n        }\n\n        // spare 1 byte for '\\0'\n        JSON_ASSERT(n_chars < number_buffer.size() - 1);\n\n        // jump to the end to generate the string from backward,\n        // so we later avoid reversing the result\n        buffer_ptr += n_chars;\n\n        // Fast int2ascii implementation inspired by \"Fastware\" talk by Andrei Alexandrescu\n        // See: https://www.youtube.com/watch?v=o4-CwDo2zpg\n        while (abs_value >= 100)\n        {\n            const auto digits_index = static_cast<unsigned>((abs_value % 100));\n            abs_value /= 100;\n            *(--buffer_ptr) = digits_to_99[digits_index][1];\n            *(--buffer_ptr) = digits_to_99[digits_index][0];\n        }\n\n        if (abs_value >= 10)\n        {\n            const auto digits_index = static_cast<unsigned>(abs_value);\n            *(--buffer_ptr) = 
digits_to_99[digits_index][1];\n            *(--buffer_ptr) = digits_to_99[digits_index][0];\n        }\n        else\n        {\n            *(--buffer_ptr) = static_cast<char>('0' + abs_value);\n        }\n\n        o->write_characters(number_buffer.data(), n_chars);\n    }\n\n    /*!\n    @brief dump a floating-point number\n\n    Dump a given floating-point number to output stream @a o. Works internally\n    with @a number_buffer.\n\n    @param[in] x  floating-point number to dump\n    */\n    void dump_float(number_float_t x)\n    {\n        // NaN / inf\n        if (!std::isfinite(x))\n        {\n            o->write_characters(\"null\", 4);\n            return;\n        }\n\n        // If number_float_t is an IEEE-754 single or double precision number,\n        // use the Grisu2 algorithm to produce short numbers which are\n        // guaranteed to round-trip, using strtof and strtod, resp.\n        //\n        // NB: The test below works if <long double> == <double>.\n        static constexpr bool is_ieee_single_or_double\n            = (std::numeric_limits<number_float_t>::is_iec559 && std::numeric_limits<number_float_t>::digits == 24 && std::numeric_limits<number_float_t>::max_exponent == 128) ||\n              (std::numeric_limits<number_float_t>::is_iec559 && std::numeric_limits<number_float_t>::digits == 53 && std::numeric_limits<number_float_t>::max_exponent == 1024);\n\n        dump_float(x, std::integral_constant<bool, is_ieee_single_or_double>());\n    }\n\n    void dump_float(number_float_t x, std::true_type /*is_ieee_single_or_double*/)\n    {\n        auto* begin = number_buffer.data();\n        auto* end = ::nlohmann::detail::to_chars(begin, begin + number_buffer.size(), x);\n\n        o->write_characters(begin, static_cast<size_t>(end - begin));\n    }\n\n    void dump_float(number_float_t x, std::false_type /*is_ieee_single_or_double*/)\n    {\n        // get number of digits for a float -> text -> float round-trip\n        static constexpr 
auto d = std::numeric_limits<number_float_t>::max_digits10;\n\n        // the actual conversion\n        // NOLINTNEXTLINE(cppcoreguidelines-pro-type-vararg,hicpp-vararg)\n        std::ptrdiff_t len = (std::snprintf)(number_buffer.data(), number_buffer.size(), \"%.*g\", d, x);\n\n        // negative value indicates an error\n        JSON_ASSERT(len > 0);\n        // check if buffer was large enough\n        JSON_ASSERT(static_cast<std::size_t>(len) < number_buffer.size());\n\n        // erase thousands separator\n        if (thousands_sep != '\\0')\n        {\n            // NOLINTNEXTLINE(readability-qualified-auto,llvm-qualified-auto): std::remove returns an iterator, see https://github.com/nlohmann/json/issues/3081\n            const auto end = std::remove(number_buffer.begin(), number_buffer.begin() + len, thousands_sep);\n            std::fill(end, number_buffer.end(), '\\0');\n            JSON_ASSERT((end - number_buffer.begin()) <= len);\n            len = (end - number_buffer.begin());\n        }\n\n        // convert decimal point to '.'\n        if (decimal_point != '\\0' && decimal_point != '.')\n        {\n            // NOLINTNEXTLINE(readability-qualified-auto,llvm-qualified-auto): std::find returns an iterator, see https://github.com/nlohmann/json/issues/3081\n            const auto dec_pos = std::find(number_buffer.begin(), number_buffer.end(), decimal_point);\n            if (dec_pos != number_buffer.end())\n            {\n                *dec_pos = '.';\n            }\n        }\n\n        o->write_characters(number_buffer.data(), static_cast<std::size_t>(len));\n\n        // determine if we need to append \".0\"\n        const bool value_is_int_like =\n            std::none_of(number_buffer.begin(), number_buffer.begin() + len + 1,\n                         [](char c)\n        {\n            return c == '.' 
|| c == 'e';\n        });\n\n        if (value_is_int_like)\n        {\n            o->write_characters(\".0\", 2);\n        }\n    }\n\n    /*!\n    @brief check whether a string is UTF-8 encoded\n\n    The function checks each byte of a string whether it is UTF-8 encoded. The\n    result of the check is stored in the @a state parameter. The function must\n    be called initially with state 0 (accept). State 1 means the string must\n    be rejected, because the current byte is not allowed. If the string is\n    completely processed, but the state is non-zero, the string ended\n    prematurely; that is, the last byte indicated more bytes should have\n    followed.\n\n    @param[in,out] state  the state of the decoding\n    @param[in,out] codep  codepoint (valid only if resulting state is UTF8_ACCEPT)\n    @param[in] byte       next byte to decode\n    @return               new state\n\n    @note The function has been edited: a std::array is used.\n\n    @copyright Copyright (c) 2008-2009 Bjoern Hoehrmann <bjoern@hoehrmann.de>\n    @sa http://bjoern.hoehrmann.de/utf-8/decoder/dfa/\n    */\n    static std::uint8_t decode(std::uint8_t& state, std::uint32_t& codep, const std::uint8_t byte) noexcept\n    {\n        static const std::array<std::uint8_t, 400> utf8d =\n        {\n            {\n                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 00..1F\n                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 20..3F\n                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 40..5F\n                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 60..7F\n                1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, // 80..9F\n                7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // A0..BF\n                8, 8, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // C0..DF\n                0xA, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x4, 0x3, 0x3, // E0..EF\n                0xB, 0x6, 0x6, 0x6, 0x5, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, // F0..FF\n                0x0, 0x1, 0x2, 0x3, 0x5, 0x8, 0x7, 0x1, 0x1, 0x1, 0x4, 0x6, 0x1, 0x1, 0x1, 0x1, // s0..s0\n                1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, // s1..s2\n                1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, // s3..s4\n                1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 3, 1, 1, 1, 1, 1, 1, // s5..s6\n                1, 3, 1, 1, 1, 1, 1, 3, 1, 3, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 // s7..s8\n            }\n        };\n\n        JSON_ASSERT(byte < utf8d.size());\n        const std::uint8_t type = utf8d[byte];\n\n        codep = (state != UTF8_ACCEPT)\n                ? (byte & 0x3fu) | (codep << 6u)\n                : (0xFFu >> type) & (byte);\n\n        const std::size_t index = 256u + static_cast<size_t>(state) * 16u + static_cast<size_t>(type);\n        JSON_ASSERT(index < utf8d.size());\n        state = utf8d[index];\n        return state;\n    }\n\n    /*\n     * Overload to make the compiler happy while it is instantiating\n     * dump_integer for number_unsigned_t.\n     * Must never be called.\n     */\n    number_unsigned_t remove_sign(number_unsigned_t x)\n    {\n        JSON_ASSERT(false); // NOLINT(cert-dcl03-c,hicpp-static-assert,misc-static-assert) LCOV_EXCL_LINE\n        return x; // LCOV_EXCL_LINE\n    }\n\n    /*\n     * Helper function for dump_integer\n     *\n     * This function takes a negative signed integer and returns its absolute\n     * value as unsigned integer. 
The plus/minus shuffling is necessary as we can\n     * not directly remove the sign of an arbitrary signed integer as the\n     * absolute values of INT_MIN and INT_MAX are usually not the same. See\n     * #1708 for details.\n     */\n    inline number_unsigned_t remove_sign(number_integer_t x) noexcept\n    {\n        JSON_ASSERT(x < 0 && x < (std::numeric_limits<number_integer_t>::max)()); // NOLINT(misc-redundant-expression)\n        return static_cast<number_unsigned_t>(-(x + 1)) + 1;\n    }\n\n  private:\n    /// the output of the serializer\n    output_adapter_t<char> o = nullptr;\n\n    /// a (hopefully) large enough character buffer\n    std::array<char, 64> number_buffer{{}};\n\n    /// the locale\n    const std::lconv* loc = nullptr;\n    /// the locale's thousand separator character\n    const char thousands_sep = '\\0';\n    /// the locale's decimal point character\n    const char decimal_point = '\\0';\n\n    /// string buffer\n    std::array<char, 512> string_buffer{{}};\n\n    /// the indentation character\n    const char indent_char;\n    /// the indentation string\n    string_t indent_string;\n\n    /// error_handler how to react on decoding errors\n    const error_handler_t error_handler;\n};\n\n}  // namespace detail\nNLOHMANN_JSON_NAMESPACE_END\n\n// #include <nlohmann/detail/value_t.hpp>\n\n// #include <nlohmann/json_fwd.hpp>\n\n// #include <nlohmann/ordered_map.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#include <functional> // equal_to, less\n#include <initializer_list> // initializer_list\n#include <iterator> // input_iterator_tag, iterator_traits\n#include <memory> // allocator\n#include <stdexcept> // for out_of_range\n#include <type_traits> // enable_if, 
is_convertible\n#include <utility> // pair\n#include <vector> // vector\n\n// #include <nlohmann/detail/macro_scope.hpp>\n\n// #include <nlohmann/detail/meta/type_traits.hpp>\n\n\nNLOHMANN_JSON_NAMESPACE_BEGIN\n\n/// ordered_map: a minimal map-like container that preserves insertion order\n/// for use within nlohmann::basic_json<ordered_map>\ntemplate <class Key, class T, class IgnoredLess = std::less<Key>,\n          class Allocator = std::allocator<std::pair<const Key, T>>>\n                  struct ordered_map : std::vector<std::pair<const Key, T>, Allocator>\n{\n    using key_type = Key;\n    using mapped_type = T;\n    using Container = std::vector<std::pair<const Key, T>, Allocator>;\n    using iterator = typename Container::iterator;\n    using const_iterator = typename Container::const_iterator;\n    using size_type = typename Container::size_type;\n    using value_type = typename Container::value_type;\n#ifdef JSON_HAS_CPP_14\n    using key_compare = std::equal_to<>;\n#else\n    using key_compare = std::equal_to<Key>;\n#endif\n\n    // Explicit constructors instead of `using Container::Container`\n    // otherwise older compilers choke on it (GCC <= 5.5, xcode <= 9.4)\n    ordered_map() noexcept(noexcept(Container())) : Container{} {}\n    explicit ordered_map(const Allocator& alloc) noexcept(noexcept(Container(alloc))) : Container{alloc} {}\n    template <class It>\n    ordered_map(It first, It last, const Allocator& alloc = Allocator())\n        : Container{first, last, alloc} {}\n    ordered_map(std::initializer_list<value_type> init, const Allocator& alloc = Allocator() )\n        : Container{init, alloc} {}\n\n    std::pair<iterator, bool> emplace(const key_type& key, T&& t)\n    {\n        for (auto it = this->begin(); it != this->end(); ++it)\n        {\n            if (m_compare(it->first, key))\n            {\n                return {it, false};\n            }\n        }\n        Container::emplace_back(key, std::forward<T>(t));\n        return 
{std::prev(this->end()), true};\n    }\n\n    template<class KeyType, detail::enable_if_t<\n                 detail::is_usable_as_key_type<key_compare, key_type, KeyType>::value, int> = 0>\n    std::pair<iterator, bool> emplace(KeyType && key, T && t)\n    {\n        for (auto it = this->begin(); it != this->end(); ++it)\n        {\n            if (m_compare(it->first, key))\n            {\n                return {it, false};\n            }\n        }\n        Container::emplace_back(std::forward<KeyType>(key), std::forward<T>(t));\n        return {std::prev(this->end()), true};\n    }\n\n    T& operator[](const key_type& key)\n    {\n        return emplace(key, T{}).first->second;\n    }\n\n    template<class KeyType, detail::enable_if_t<\n                 detail::is_usable_as_key_type<key_compare, key_type, KeyType>::value, int> = 0>\n    T & operator[](KeyType && key)\n    {\n        return emplace(std::forward<KeyType>(key), T{}).first->second;\n    }\n\n    const T& operator[](const key_type& key) const\n    {\n        return at(key);\n    }\n\n    template<class KeyType, detail::enable_if_t<\n                 detail::is_usable_as_key_type<key_compare, key_type, KeyType>::value, int> = 0>\n    const T & operator[](KeyType && key) const\n    {\n        return at(std::forward<KeyType>(key));\n    }\n\n    T& at(const key_type& key)\n    {\n        for (auto it = this->begin(); it != this->end(); ++it)\n        {\n            if (m_compare(it->first, key))\n            {\n                return it->second;\n            }\n        }\n\n        JSON_THROW(std::out_of_range(\"key not found\"));\n    }\n\n    template<class KeyType, detail::enable_if_t<\n                 detail::is_usable_as_key_type<key_compare, key_type, KeyType>::value, int> = 0>\n    T & at(KeyType && key) // NOLINT(cppcoreguidelines-missing-std-forward)\n    {\n        for (auto it = this->begin(); it != this->end(); ++it)\n        {\n            if (m_compare(it->first, key))\n            {\n   
             return it->second;\n            }\n        }\n\n        JSON_THROW(std::out_of_range(\"key not found\"));\n    }\n\n    const T& at(const key_type& key) const\n    {\n        for (auto it = this->begin(); it != this->end(); ++it)\n        {\n            if (m_compare(it->first, key))\n            {\n                return it->second;\n            }\n        }\n\n        JSON_THROW(std::out_of_range(\"key not found\"));\n    }\n\n    template<class KeyType, detail::enable_if_t<\n                 detail::is_usable_as_key_type<key_compare, key_type, KeyType>::value, int> = 0>\n    const T & at(KeyType && key) const // NOLINT(cppcoreguidelines-missing-std-forward)\n    {\n        for (auto it = this->begin(); it != this->end(); ++it)\n        {\n            if (m_compare(it->first, key))\n            {\n                return it->second;\n            }\n        }\n\n        JSON_THROW(std::out_of_range(\"key not found\"));\n    }\n\n    size_type erase(const key_type& key)\n    {\n        for (auto it = this->begin(); it != this->end(); ++it)\n        {\n            if (m_compare(it->first, key))\n            {\n                // Since we cannot move const Keys, re-construct them in place\n                for (auto next = it; ++next != this->end(); ++it)\n                {\n                    it->~value_type(); // Destroy but keep allocation\n                    new (&*it) value_type{std::move(*next)};\n                }\n                Container::pop_back();\n                return 1;\n            }\n        }\n        return 0;\n    }\n\n    template<class KeyType, detail::enable_if_t<\n                 detail::is_usable_as_key_type<key_compare, key_type, KeyType>::value, int> = 0>\n    size_type erase(KeyType && key) // NOLINT(cppcoreguidelines-missing-std-forward)\n    {\n        for (auto it = this->begin(); it != this->end(); ++it)\n        {\n            if (m_compare(it->first, key))\n            {\n                // Since we cannot move const 
Keys, re-construct them in place\n                for (auto next = it; ++next != this->end(); ++it)\n                {\n                    it->~value_type(); // Destroy but keep allocation\n                    new (&*it) value_type{std::move(*next)};\n                }\n                Container::pop_back();\n                return 1;\n            }\n        }\n        return 0;\n    }\n\n    iterator erase(iterator pos)\n    {\n        return erase(pos, std::next(pos));\n    }\n\n    iterator erase(iterator first, iterator last)\n    {\n        if (first == last)\n        {\n            return first;\n        }\n\n        const auto elements_affected = std::distance(first, last);\n        const auto offset = std::distance(Container::begin(), first);\n\n        // This is the start situation. We need to delete elements_affected\n        // elements (3 in this example: e, f, g), and need to return an\n        // iterator past the last deleted element (h in this example).\n        // Note that offset is the distance from the start of the vector\n        // to first. We will need this later.\n\n        // [ a, b, c, d, e, f, g, h, i, j ]\n        //               ^        ^\n        //             first    last\n\n        // Since we cannot move const Keys, we re-construct them in place.\n        // We start at first and re-construct (viz. copy) the elements from\n        // the back of the vector. 
Example for first iteration:\n\n        //               ,--------.\n        //               v        |   destroy e and re-construct with h\n        // [ a, b, c, d, e, f, g, h, i, j ]\n        //               ^        ^\n        //               it       it + elements_affected\n\n        for (auto it = first; std::next(it, elements_affected) != Container::end(); ++it)\n        {\n            it->~value_type(); // destroy but keep allocation\n            new (&*it) value_type{std::move(*std::next(it, elements_affected))}; // \"move\" next element to it\n        }\n\n        // [ a, b, c, d, h, i, j, h, i, j ]\n        //               ^        ^\n        //             first    last\n\n        // remove the unneeded elements at the end of the vector\n        Container::resize(this->size() - static_cast<size_type>(elements_affected));\n\n        // [ a, b, c, d, h, i, j ]\n        //               ^        ^\n        //             first    last\n\n        // first is now pointing past the last deleted element, but we cannot\n        // use this iterator, because it may have been invalidated by the\n        // resize call. 
Instead, we can return begin() + offset.\n        return Container::begin() + offset;\n    }\n\n    size_type count(const key_type& key) const\n    {\n        for (auto it = this->begin(); it != this->end(); ++it)\n        {\n            if (m_compare(it->first, key))\n            {\n                return 1;\n            }\n        }\n        return 0;\n    }\n\n    template<class KeyType, detail::enable_if_t<\n                 detail::is_usable_as_key_type<key_compare, key_type, KeyType>::value, int> = 0>\n    size_type count(KeyType && key) const // NOLINT(cppcoreguidelines-missing-std-forward)\n    {\n        for (auto it = this->begin(); it != this->end(); ++it)\n        {\n            if (m_compare(it->first, key))\n            {\n                return 1;\n            }\n        }\n        return 0;\n    }\n\n    iterator find(const key_type& key)\n    {\n        for (auto it = this->begin(); it != this->end(); ++it)\n        {\n            if (m_compare(it->first, key))\n            {\n                return it;\n            }\n        }\n        return Container::end();\n    }\n\n    template<class KeyType, detail::enable_if_t<\n                 detail::is_usable_as_key_type<key_compare, key_type, KeyType>::value, int> = 0>\n    iterator find(KeyType && key) // NOLINT(cppcoreguidelines-missing-std-forward)\n    {\n        for (auto it = this->begin(); it != this->end(); ++it)\n        {\n            if (m_compare(it->first, key))\n            {\n                return it;\n            }\n        }\n        return Container::end();\n    }\n\n    const_iterator find(const key_type& key) const\n    {\n        for (auto it = this->begin(); it != this->end(); ++it)\n        {\n            if (m_compare(it->first, key))\n            {\n                return it;\n            }\n        }\n        return Container::end();\n    }\n\n    std::pair<iterator, bool> insert( value_type&& value )\n    {\n        return emplace(value.first, std::move(value.second));\n    
}\n\n    std::pair<iterator, bool> insert( const value_type& value )\n    {\n        for (auto it = this->begin(); it != this->end(); ++it)\n        {\n            if (m_compare(it->first, value.first))\n            {\n                return {it, false};\n            }\n        }\n        Container::push_back(value);\n        return {--this->end(), true};\n    }\n\n    template<typename InputIt>\n    using require_input_iter = typename std::enable_if<std::is_convertible<typename std::iterator_traits<InputIt>::iterator_category,\n            std::input_iterator_tag>::value>::type;\n\n    template<typename InputIt, typename = require_input_iter<InputIt>>\n    void insert(InputIt first, InputIt last)\n    {\n        for (auto it = first; it != last; ++it)\n        {\n            insert(*it);\n        }\n    }\n\nprivate:\n    JSON_NO_UNIQUE_ADDRESS key_compare m_compare = key_compare();\n};\n\nNLOHMANN_JSON_NAMESPACE_END\n\n\n#if defined(JSON_HAS_CPP_17)\n    #if JSON_HAS_STATIC_RTTI\n        #include <any>\n    #endif\n    #include <string_view>\n#endif\n\n/*!\n@brief namespace for Niels Lohmann\n@see https://github.com/nlohmann\n@since version 1.0.0\n*/\nNLOHMANN_JSON_NAMESPACE_BEGIN\n\n/*!\n@brief a class to store JSON values\n\n@internal\n@invariant The member variables @a m_value and @a m_type have the following\nrelationship:\n- If `m_type == value_t::object`, then `m_value.object != nullptr`.\n- If `m_type == value_t::array`, then `m_value.array != nullptr`.\n- If `m_type == value_t::string`, then `m_value.string != nullptr`.\nThe invariants are checked by member function assert_invariant().\n\n@note ObjectType trick from https://stackoverflow.com/a/9860911\n@endinternal\n\n@since version 1.0.0\n\n@nosubgrouping\n*/\nNLOHMANN_BASIC_JSON_TPL_DECLARATION\nclass basic_json // NOLINT(cppcoreguidelines-special-member-functions,hicpp-special-member-functions)\n    : public ::nlohmann::detail::json_base_class<CustomBaseClass>\n{\n  private:\n    
template<detail::value_t> friend struct detail::external_constructor;\n\n    template<typename>\n    friend class ::nlohmann::json_pointer;\n    // can be restored when json_pointer backwards compatibility is removed\n    // friend ::nlohmann::json_pointer<StringType>;\n\n    template<typename BasicJsonType, typename InputType>\n    friend class ::nlohmann::detail::parser;\n    friend ::nlohmann::detail::serializer<basic_json>;\n    template<typename BasicJsonType>\n    friend class ::nlohmann::detail::iter_impl;\n    template<typename BasicJsonType, typename CharType>\n    friend class ::nlohmann::detail::binary_writer;\n    template<typename BasicJsonType, typename InputType, typename SAX>\n    friend class ::nlohmann::detail::binary_reader;\n    template<typename BasicJsonType>\n    friend class ::nlohmann::detail::json_sax_dom_parser;\n    template<typename BasicJsonType>\n    friend class ::nlohmann::detail::json_sax_dom_callback_parser;\n    friend class ::nlohmann::detail::exception;\n\n    /// workaround type for MSVC\n    using basic_json_t = NLOHMANN_BASIC_JSON_TPL;\n    using json_base_class_t = ::nlohmann::detail::json_base_class<CustomBaseClass>;\n\n  JSON_PRIVATE_UNLESS_TESTED:\n    // convenience aliases for types residing in namespace detail;\n    using lexer = ::nlohmann::detail::lexer_base<basic_json>;\n\n    template<typename InputAdapterType>\n    static ::nlohmann::detail::parser<basic_json, InputAdapterType> parser(\n        InputAdapterType adapter,\n        detail::parser_callback_t<basic_json>cb = nullptr,\n        const bool allow_exceptions = true,\n        const bool ignore_comments = false\n                                 )\n    {\n        return ::nlohmann::detail::parser<basic_json, InputAdapterType>(std::move(adapter),\n                std::move(cb), allow_exceptions, ignore_comments);\n    }\n\n  private:\n    using primitive_iterator_t = ::nlohmann::detail::primitive_iterator_t;\n    template<typename BasicJsonType>\n    using 
internal_iterator = ::nlohmann::detail::internal_iterator<BasicJsonType>;\n    template<typename BasicJsonType>\n    using iter_impl = ::nlohmann::detail::iter_impl<BasicJsonType>;\n    template<typename Iterator>\n    using iteration_proxy = ::nlohmann::detail::iteration_proxy<Iterator>;\n    template<typename Base> using json_reverse_iterator = ::nlohmann::detail::json_reverse_iterator<Base>;\n\n    template<typename CharType>\n    using output_adapter_t = ::nlohmann::detail::output_adapter_t<CharType>;\n\n    template<typename InputType>\n    using binary_reader = ::nlohmann::detail::binary_reader<basic_json, InputType>;\n    template<typename CharType> using binary_writer = ::nlohmann::detail::binary_writer<basic_json, CharType>;\n\n  JSON_PRIVATE_UNLESS_TESTED:\n    using serializer = ::nlohmann::detail::serializer<basic_json>;\n\n  public:\n    using value_t = detail::value_t;\n    /// JSON Pointer, see @ref nlohmann::json_pointer\n    using json_pointer = ::nlohmann::json_pointer<StringType>;\n    template<typename T, typename SFINAE>\n    using json_serializer = JSONSerializer<T, SFINAE>;\n    /// how to treat decoding errors\n    using error_handler_t = detail::error_handler_t;\n    /// how to treat CBOR tags\n    using cbor_tag_handler_t = detail::cbor_tag_handler_t;\n    /// helper type for initializer lists of basic_json values\n    using initializer_list_t = std::initializer_list<detail::json_ref<basic_json>>;\n\n    using input_format_t = detail::input_format_t;\n    /// SAX interface type, see @ref nlohmann::json_sax\n    using json_sax_t = json_sax<basic_json>;\n\n    ////////////////\n    // exceptions //\n    ////////////////\n\n    /// @name exceptions\n    /// Classes to implement user-defined exceptions.\n    /// @{\n\n    using exception = detail::exception;\n    using parse_error = detail::parse_error;\n    using invalid_iterator = detail::invalid_iterator;\n    using type_error = detail::type_error;\n    using out_of_range = 
detail::out_of_range;\n    using other_error = detail::other_error;\n\n    /// @}\n\n    /////////////////////\n    // container types //\n    /////////////////////\n\n    /// @name container types\n    /// The canonic container types to use @ref basic_json like any other STL\n    /// container.\n    /// @{\n\n    /// the type of elements in a basic_json container\n    using value_type = basic_json;\n\n    /// the type of an element reference\n    using reference = value_type&;\n    /// the type of an element const reference\n    using const_reference = const value_type&;\n\n    /// a type to represent differences between iterators\n    using difference_type = std::ptrdiff_t;\n    /// a type to represent container sizes\n    using size_type = std::size_t;\n\n    /// the allocator type\n    using allocator_type = AllocatorType<basic_json>;\n\n    /// the type of an element pointer\n    using pointer = typename std::allocator_traits<allocator_type>::pointer;\n    /// the type of an element const pointer\n    using const_pointer = typename std::allocator_traits<allocator_type>::const_pointer;\n\n    /// an iterator for a basic_json container\n    using iterator = iter_impl<basic_json>;\n    /// a const iterator for a basic_json container\n    using const_iterator = iter_impl<const basic_json>;\n    /// a reverse iterator for a basic_json container\n    using reverse_iterator = json_reverse_iterator<typename basic_json::iterator>;\n    /// a const reverse iterator for a basic_json container\n    using const_reverse_iterator = json_reverse_iterator<typename basic_json::const_iterator>;\n\n    /// @}\n\n    /// @brief returns the allocator associated with the container\n    /// @sa https://json.nlohmann.me/api/basic_json/get_allocator/\n    static allocator_type get_allocator()\n    {\n        return allocator_type();\n    }\n\n    /// @brief returns version information on the library\n    /// @sa https://json.nlohmann.me/api/basic_json/meta/\n    
JSON_HEDLEY_WARN_UNUSED_RESULT\n    static basic_json meta()\n    {\n        basic_json result;\n\n        result[\"copyright\"] = \"(C) 2013-2023 Niels Lohmann\";\n        result[\"name\"] = \"JSON for Modern C++\";\n        result[\"url\"] = \"https://github.com/nlohmann/json\";\n        result[\"version\"][\"string\"] =\n            detail::concat(std::to_string(NLOHMANN_JSON_VERSION_MAJOR), '.',\n                           std::to_string(NLOHMANN_JSON_VERSION_MINOR), '.',\n                           std::to_string(NLOHMANN_JSON_VERSION_PATCH));\n        result[\"version\"][\"major\"] = NLOHMANN_JSON_VERSION_MAJOR;\n        result[\"version\"][\"minor\"] = NLOHMANN_JSON_VERSION_MINOR;\n        result[\"version\"][\"patch\"] = NLOHMANN_JSON_VERSION_PATCH;\n\n#ifdef _WIN32\n        result[\"platform\"] = \"win32\";\n#elif defined __linux__\n        result[\"platform\"] = \"linux\";\n#elif defined __APPLE__\n        result[\"platform\"] = \"apple\";\n#elif defined __unix__\n        result[\"platform\"] = \"unix\";\n#else\n        result[\"platform\"] = \"unknown\";\n#endif\n\n#if defined(__ICC) || defined(__INTEL_COMPILER)\n        result[\"compiler\"] = {{\"family\", \"icc\"}, {\"version\", __INTEL_COMPILER}};\n#elif defined(__clang__)\n        result[\"compiler\"] = {{\"family\", \"clang\"}, {\"version\", __clang_version__}};\n#elif defined(__GNUC__) || defined(__GNUG__)\n        result[\"compiler\"] = {{\"family\", \"gcc\"}, {\"version\", detail::concat(\n                    std::to_string(__GNUC__), '.',\n                    std::to_string(__GNUC_MINOR__), '.',\n                    std::to_string(__GNUC_PATCHLEVEL__))\n            }\n        };\n#elif defined(__HP_cc) || defined(__HP_aCC)\n        result[\"compiler\"] = \"hp\";\n#elif defined(__IBMCPP__)\n        result[\"compiler\"] = {{\"family\", \"ilecpp\"}, {\"version\", __IBMCPP__}};\n#elif defined(_MSC_VER)\n        result[\"compiler\"] = {{\"family\", \"msvc\"}, {\"version\", _MSC_VER}};\n#elif 
defined(__PGI)\n        result[\"compiler\"] = {{\"family\", \"pgcpp\"}, {\"version\", __PGI}};\n#elif defined(__SUNPRO_CC)\n        result[\"compiler\"] = {{\"family\", \"sunpro\"}, {\"version\", __SUNPRO_CC}};\n#else\n        result[\"compiler\"] = {{\"family\", \"unknown\"}, {\"version\", \"unknown\"}};\n#endif\n\n#if defined(_MSVC_LANG)\n        result[\"compiler\"][\"c++\"] = std::to_string(_MSVC_LANG);\n#elif defined(__cplusplus)\n        result[\"compiler\"][\"c++\"] = std::to_string(__cplusplus);\n#else\n        result[\"compiler\"][\"c++\"] = \"unknown\";\n#endif\n        return result;\n    }\n\n    ///////////////////////////\n    // JSON value data types //\n    ///////////////////////////\n\n    /// @name JSON value data types\n    /// The data types to store a JSON value. These types are derived from\n    /// the template arguments passed to class @ref basic_json.\n    /// @{\n\n    /// @brief default object key comparator type\n    /// The actual object key comparator type (@ref object_comparator_t) may be\n    /// different.\n    /// @sa https://json.nlohmann.me/api/basic_json/default_object_comparator_t/\n#if defined(JSON_HAS_CPP_14)\n    // use of transparent comparator avoids unnecessary repeated construction of temporaries\n    // in functions involving lookup by key with types other than object_t::key_type (aka. 
StringType)\n    using default_object_comparator_t = std::less<>;\n#else\n    using default_object_comparator_t = std::less<StringType>;\n#endif\n\n    /// @brief a type for an object\n    /// @sa https://json.nlohmann.me/api/basic_json/object_t/\n    using object_t = ObjectType<StringType,\n          basic_json,\n          default_object_comparator_t,\n          AllocatorType<std::pair<const StringType,\n          basic_json>>>;\n\n    /// @brief a type for an array\n    /// @sa https://json.nlohmann.me/api/basic_json/array_t/\n    using array_t = ArrayType<basic_json, AllocatorType<basic_json>>;\n\n    /// @brief a type for a string\n    /// @sa https://json.nlohmann.me/api/basic_json/string_t/\n    using string_t = StringType;\n\n    /// @brief a type for a boolean\n    /// @sa https://json.nlohmann.me/api/basic_json/boolean_t/\n    using boolean_t = BooleanType;\n\n    /// @brief a type for a number (integer)\n    /// @sa https://json.nlohmann.me/api/basic_json/number_integer_t/\n    using number_integer_t = NumberIntegerType;\n\n    /// @brief a type for a number (unsigned)\n    /// @sa https://json.nlohmann.me/api/basic_json/number_unsigned_t/\n    using number_unsigned_t = NumberUnsignedType;\n\n    /// @brief a type for a number (floating-point)\n    /// @sa https://json.nlohmann.me/api/basic_json/number_float_t/\n    using number_float_t = NumberFloatType;\n\n    /// @brief a type for a packed binary type\n    /// @sa https://json.nlohmann.me/api/basic_json/binary_t/\n    using binary_t = nlohmann::byte_container_with_subtype<BinaryType>;\n\n    /// @brief object key comparator type\n    /// @sa https://json.nlohmann.me/api/basic_json/object_comparator_t/\n    using object_comparator_t = detail::actual_object_comparator_t<basic_json>;\n\n    /// @}\n\n  private:\n\n    /// helper for exception-safe object creation\n    template<typename T, typename... Args>\n    JSON_HEDLEY_RETURNS_NON_NULL\n    static T* create(Args&& ... 
args)\n    {\n        AllocatorType<T> alloc;\n        using AllocatorTraits = std::allocator_traits<AllocatorType<T>>;\n\n        auto deleter = [&](T * obj)\n        {\n            AllocatorTraits::deallocate(alloc, obj, 1);\n        };\n        std::unique_ptr<T, decltype(deleter)> obj(AllocatorTraits::allocate(alloc, 1), deleter);\n        AllocatorTraits::construct(alloc, obj.get(), std::forward<Args>(args)...);\n        JSON_ASSERT(obj != nullptr);\n        return obj.release();\n    }\n\n    ////////////////////////\n    // JSON value storage //\n    ////////////////////////\n\n  JSON_PRIVATE_UNLESS_TESTED:\n    /*!\n    @brief a JSON value\n\n    The actual storage for a JSON value of the @ref basic_json class. This\n    union combines the different storage types for the JSON value types\n    defined in @ref value_t.\n\n    JSON type | value_t type    | used type\n    --------- | --------------- | ------------------------\n    object    | object          | pointer to @ref object_t\n    array     | array           | pointer to @ref array_t\n    string    | string          | pointer to @ref string_t\n    boolean   | boolean         | @ref boolean_t\n    number    | number_integer  | @ref number_integer_t\n    number    | number_unsigned | @ref number_unsigned_t\n    number    | number_float    | @ref number_float_t\n    binary    | binary          | pointer to @ref binary_t\n    null      | null            | *no value is stored*\n\n    @note Variable-length types (objects, arrays, strings, and binary values)\n    are stored as pointers. 
The size of the union should not exceed 64 bits if the default\n    value types are used.\n\n    @since version 1.0.0\n    */\n    union json_value\n    {\n        /// object (stored with pointer to save storage)\n        object_t* object;\n        /// array (stored with pointer to save storage)\n        array_t* array;\n        /// string (stored with pointer to save storage)\n        string_t* string;\n        /// binary (stored with pointer to save storage)\n        binary_t* binary;\n        /// boolean\n        boolean_t boolean;\n        /// number (integer)\n        number_integer_t number_integer;\n        /// number (unsigned integer)\n        number_unsigned_t number_unsigned;\n        /// number (floating-point)\n        number_float_t number_float;\n\n        /// default constructor (for null values)\n        json_value() = default;\n        /// constructor for booleans\n        json_value(boolean_t v) noexcept : boolean(v) {}\n        /// constructor for numbers (integer)\n        json_value(number_integer_t v) noexcept : number_integer(v) {}\n        /// constructor for numbers (unsigned)\n        json_value(number_unsigned_t v) noexcept : number_unsigned(v) {}\n        /// constructor for numbers (floating-point)\n        json_value(number_float_t v) noexcept : number_float(v) {}\n        /// constructor for empty values of a given type\n        json_value(value_t t)\n        {\n            switch (t)\n            {\n                case value_t::object:\n                {\n                    object = create<object_t>();\n                    break;\n                }\n\n                case value_t::array:\n                {\n                    array = create<array_t>();\n                    break;\n                }\n\n                case value_t::string:\n                {\n                    string = create<string_t>(\"\");\n                    break;\n                }\n\n                case value_t::binary:\n                {\n              
      binary = create<binary_t>();\n                    break;\n                }\n\n                case value_t::boolean:\n                {\n                    boolean = static_cast<boolean_t>(false);\n                    break;\n                }\n\n                case value_t::number_integer:\n                {\n                    number_integer = static_cast<number_integer_t>(0);\n                    break;\n                }\n\n                case value_t::number_unsigned:\n                {\n                    number_unsigned = static_cast<number_unsigned_t>(0);\n                    break;\n                }\n\n                case value_t::number_float:\n                {\n                    number_float = static_cast<number_float_t>(0.0);\n                    break;\n                }\n\n                case value_t::null:\n                {\n                    object = nullptr;  // silence warning, see #821\n                    break;\n                }\n\n                case value_t::discarded:\n                default:\n                {\n                    object = nullptr;  // silence warning, see #821\n                    if (JSON_HEDLEY_UNLIKELY(t == value_t::null))\n                    {\n                        JSON_THROW(other_error::create(500, \"961c151d2e87f2686a955a9be24d316f1362bf21 3.11.3\", nullptr)); // LCOV_EXCL_LINE\n                    }\n                    break;\n                }\n            }\n        }\n\n        /// constructor for strings\n        json_value(const string_t& value) : string(create<string_t>(value)) {}\n\n        /// constructor for rvalue strings\n        json_value(string_t&& value) : string(create<string_t>(std::move(value))) {}\n\n        /// constructor for objects\n        json_value(const object_t& value) : object(create<object_t>(value)) {}\n\n        /// constructor for rvalue objects\n        json_value(object_t&& value) : object(create<object_t>(std::move(value))) {}\n\n        /// 
constructor for arrays\n        json_value(const array_t& value) : array(create<array_t>(value)) {}\n\n        /// constructor for rvalue arrays\n        json_value(array_t&& value) : array(create<array_t>(std::move(value))) {}\n\n        /// constructor for binary arrays\n        json_value(const typename binary_t::container_type& value) : binary(create<binary_t>(value)) {}\n\n        /// constructor for rvalue binary arrays\n        json_value(typename binary_t::container_type&& value) : binary(create<binary_t>(std::move(value))) {}\n\n        /// constructor for binary arrays (internal type)\n        json_value(const binary_t& value) : binary(create<binary_t>(value)) {}\n\n        /// constructor for rvalue binary arrays (internal type)\n        json_value(binary_t&& value) : binary(create<binary_t>(std::move(value))) {}\n\n        void destroy(value_t t)\n        {\n            if (\n                (t == value_t::object && object == nullptr) ||\n                (t == value_t::array && array == nullptr) ||\n                (t == value_t::string && string == nullptr) ||\n                (t == value_t::binary && binary == nullptr)\n            )\n            {\n                //not initialized (e.g. 
due to exception in the ctor)\n                return;\n            }\n            if (t == value_t::array || t == value_t::object)\n            {\n                // flatten the current json_value to a heap-allocated stack\n                std::vector<basic_json> stack;\n\n                // move the top-level items to the stack\n                if (t == value_t::array)\n                {\n                    stack.reserve(array->size());\n                    std::move(array->begin(), array->end(), std::back_inserter(stack));\n                }\n                else\n                {\n                    stack.reserve(object->size());\n                    for (auto&& it : *object)\n                    {\n                        stack.push_back(std::move(it.second));\n                    }\n                }\n\n                while (!stack.empty())\n                {\n                    // move the last item to a local variable to be processed\n                    basic_json current_item(std::move(stack.back()));\n                    stack.pop_back();\n\n                    // if current_item is an array/object, move\n                    // its children to the stack to be processed later\n                    if (current_item.is_array())\n                    {\n                        std::move(current_item.m_data.m_value.array->begin(), current_item.m_data.m_value.array->end(), std::back_inserter(stack));\n\n                        current_item.m_data.m_value.array->clear();\n                    }\n                    else if (current_item.is_object())\n                    {\n                        for (auto&& it : *current_item.m_data.m_value.object)\n                        {\n                            stack.push_back(std::move(it.second));\n                        }\n\n                        current_item.m_data.m_value.object->clear();\n                    }\n\n                    // it is now safe to destruct current_item\n                    // since it 
doesn't have any children\n                }\n            }\n\n            switch (t)\n            {\n                case value_t::object:\n                {\n                    AllocatorType<object_t> alloc;\n                    std::allocator_traits<decltype(alloc)>::destroy(alloc, object);\n                    std::allocator_traits<decltype(alloc)>::deallocate(alloc, object, 1);\n                    break;\n                }\n\n                case value_t::array:\n                {\n                    AllocatorType<array_t> alloc;\n                    std::allocator_traits<decltype(alloc)>::destroy(alloc, array);\n                    std::allocator_traits<decltype(alloc)>::deallocate(alloc, array, 1);\n                    break;\n                }\n\n                case value_t::string:\n                {\n                    AllocatorType<string_t> alloc;\n                    std::allocator_traits<decltype(alloc)>::destroy(alloc, string);\n                    std::allocator_traits<decltype(alloc)>::deallocate(alloc, string, 1);\n                    break;\n                }\n\n                case value_t::binary:\n                {\n                    AllocatorType<binary_t> alloc;\n                    std::allocator_traits<decltype(alloc)>::destroy(alloc, binary);\n                    std::allocator_traits<decltype(alloc)>::deallocate(alloc, binary, 1);\n                    break;\n                }\n\n                case value_t::null:\n                case value_t::boolean:\n                case value_t::number_integer:\n                case value_t::number_unsigned:\n                case value_t::number_float:\n                case value_t::discarded:\n                default:\n                {\n                    break;\n                }\n            }\n        }\n    };\n\n  private:\n    /*!\n    @brief checks the class invariants\n\n    This function asserts the class invariants. 
It needs to be called at the\n    end of every constructor to make sure that created objects respect the\n    invariant. Furthermore, it has to be called each time the type of a JSON\n    value is changed, because the invariant expresses a relationship between\n    @a m_data.m_type and @a m_data.m_value.\n\n    Furthermore, the parent relation is checked for arrays and objects: If\n    @a check_parents is true and the value is an array or object, then the\n    container's elements must have the current value as parent.\n\n    @param[in] check_parents  whether the parent relation should be checked.\n               The value is true by default and should only be set to false\n               during destruction of objects when the invariant does not\n               need to hold.\n    */\n    void assert_invariant(bool check_parents = true) const noexcept\n    {\n        JSON_ASSERT(m_data.m_type != value_t::object || m_data.m_value.object != nullptr);\n        JSON_ASSERT(m_data.m_type != value_t::array || m_data.m_value.array != nullptr);\n        JSON_ASSERT(m_data.m_type != value_t::string || m_data.m_value.string != nullptr);\n        JSON_ASSERT(m_data.m_type != value_t::binary || m_data.m_value.binary != nullptr);\n\n#if JSON_DIAGNOSTICS\n        JSON_TRY\n        {\n            // cppcheck-suppress assertWithSideEffect\n            JSON_ASSERT(!check_parents || !is_structured() || std::all_of(begin(), end(), [this](const basic_json & j)\n            {\n                return j.m_parent == this;\n            }));\n        }\n        JSON_CATCH(...) 
{} // LCOV_EXCL_LINE\n#endif\n        static_cast<void>(check_parents);\n    }\n\n    void set_parents()\n    {\n#if JSON_DIAGNOSTICS\n        switch (m_data.m_type)\n        {\n            case value_t::array:\n            {\n                for (auto& element : *m_data.m_value.array)\n                {\n                    element.m_parent = this;\n                }\n                break;\n            }\n\n            case value_t::object:\n            {\n                for (auto& element : *m_data.m_value.object)\n                {\n                    element.second.m_parent = this;\n                }\n                break;\n            }\n\n            case value_t::null:\n            case value_t::string:\n            case value_t::boolean:\n            case value_t::number_integer:\n            case value_t::number_unsigned:\n            case value_t::number_float:\n            case value_t::binary:\n            case value_t::discarded:\n            default:\n                break;\n        }\n#endif\n    }\n\n    iterator set_parents(iterator it, typename iterator::difference_type count_set_parents)\n    {\n#if JSON_DIAGNOSTICS\n        for (typename iterator::difference_type i = 0; i < count_set_parents; ++i)\n        {\n            (it + i)->m_parent = this;\n        }\n#else\n        static_cast<void>(count_set_parents);\n#endif\n        return it;\n    }\n\n    reference set_parent(reference j, std::size_t old_capacity = static_cast<std::size_t>(-1))\n    {\n#if JSON_DIAGNOSTICS\n        if (old_capacity != static_cast<std::size_t>(-1))\n        {\n            // see https://github.com/nlohmann/json/issues/2838\n            JSON_ASSERT(type() == value_t::array);\n            if (JSON_HEDLEY_UNLIKELY(m_data.m_value.array->capacity() != old_capacity))\n            {\n                // capacity has changed: update all parents\n                set_parents();\n                return j;\n            }\n        }\n\n        // ordered_json uses a vector 
internally, so pointers could have\n        // been invalidated; see https://github.com/nlohmann/json/issues/2962\n#ifdef JSON_HEDLEY_MSVC_VERSION\n#pragma warning(push )\n#pragma warning(disable : 4127) // ignore warning to replace if with if constexpr\n#endif\n        if (detail::is_ordered_map<object_t>::value)\n        {\n            set_parents();\n            return j;\n        }\n#ifdef JSON_HEDLEY_MSVC_VERSION\n#pragma warning( pop )\n#endif\n\n        j.m_parent = this;\n#else\n        static_cast<void>(j);\n        static_cast<void>(old_capacity);\n#endif\n        return j;\n    }\n\n  public:\n    //////////////////////////\n    // JSON parser callback //\n    //////////////////////////\n\n    /// @brief parser event types\n    /// @sa https://json.nlohmann.me/api/basic_json/parse_event_t/\n    using parse_event_t = detail::parse_event_t;\n\n    /// @brief per-element parser callback type\n    /// @sa https://json.nlohmann.me/api/basic_json/parser_callback_t/\n    using parser_callback_t = detail::parser_callback_t<basic_json>;\n\n    //////////////////\n    // constructors //\n    //////////////////\n\n    /// @name constructors and destructors\n    /// Constructors of class @ref basic_json, copy/move constructor, copy\n    /// assignment, static functions creating objects, and the destructor.\n    /// @{\n\n    /// @brief create an empty value with a given type\n    /// @sa https://json.nlohmann.me/api/basic_json/basic_json/\n    basic_json(const value_t v)\n        : m_data(v)\n    {\n        assert_invariant();\n    }\n\n    /// @brief create a null object\n    /// @sa https://json.nlohmann.me/api/basic_json/basic_json/\n    basic_json(std::nullptr_t = nullptr) noexcept // NOLINT(bugprone-exception-escape)\n        : basic_json(value_t::null)\n    {\n        assert_invariant();\n    }\n\n    /// @brief create a JSON value from compatible types\n    /// @sa https://json.nlohmann.me/api/basic_json/basic_json/\n    template < typename CompatibleType,\n  
             typename U = detail::uncvref_t<CompatibleType>,\n               detail::enable_if_t <\n                   !detail::is_basic_json<U>::value && detail::is_compatible_type<basic_json_t, U>::value, int > = 0 >\n    basic_json(CompatibleType && val) noexcept(noexcept( // NOLINT(bugprone-forwarding-reference-overload,bugprone-exception-escape)\n                JSONSerializer<U>::to_json(std::declval<basic_json_t&>(),\n                                           std::forward<CompatibleType>(val))))\n    {\n        JSONSerializer<U>::to_json(*this, std::forward<CompatibleType>(val));\n        set_parents();\n        assert_invariant();\n    }\n\n    /// @brief create a JSON value from an existing one\n    /// @sa https://json.nlohmann.me/api/basic_json/basic_json/\n    template < typename BasicJsonType,\n               detail::enable_if_t <\n                   detail::is_basic_json<BasicJsonType>::value&& !std::is_same<basic_json, BasicJsonType>::value, int > = 0 >\n    basic_json(const BasicJsonType& val)\n    {\n        using other_boolean_t = typename BasicJsonType::boolean_t;\n        using other_number_float_t = typename BasicJsonType::number_float_t;\n        using other_number_integer_t = typename BasicJsonType::number_integer_t;\n        using other_number_unsigned_t = typename BasicJsonType::number_unsigned_t;\n        using other_string_t = typename BasicJsonType::string_t;\n        using other_object_t = typename BasicJsonType::object_t;\n        using other_array_t = typename BasicJsonType::array_t;\n        using other_binary_t = typename BasicJsonType::binary_t;\n\n        switch (val.type())\n        {\n            case value_t::boolean:\n                JSONSerializer<other_boolean_t>::to_json(*this, val.template get<other_boolean_t>());\n                break;\n            case value_t::number_float:\n                JSONSerializer<other_number_float_t>::to_json(*this, val.template get<other_number_float_t>());\n                break;\n         
   case value_t::number_integer:\n                JSONSerializer<other_number_integer_t>::to_json(*this, val.template get<other_number_integer_t>());\n                break;\n            case value_t::number_unsigned:\n                JSONSerializer<other_number_unsigned_t>::to_json(*this, val.template get<other_number_unsigned_t>());\n                break;\n            case value_t::string:\n                JSONSerializer<other_string_t>::to_json(*this, val.template get_ref<const other_string_t&>());\n                break;\n            case value_t::object:\n                JSONSerializer<other_object_t>::to_json(*this, val.template get_ref<const other_object_t&>());\n                break;\n            case value_t::array:\n                JSONSerializer<other_array_t>::to_json(*this, val.template get_ref<const other_array_t&>());\n                break;\n            case value_t::binary:\n                JSONSerializer<other_binary_t>::to_json(*this, val.template get_ref<const other_binary_t&>());\n                break;\n            case value_t::null:\n                *this = nullptr;\n                break;\n            case value_t::discarded:\n                m_data.m_type = value_t::discarded;\n                break;\n            default:            // LCOV_EXCL_LINE\n                JSON_ASSERT(false); // NOLINT(cert-dcl03-c,hicpp-static-assert,misc-static-assert) LCOV_EXCL_LINE\n        }\n        JSON_ASSERT(m_data.m_type == val.type());\n        set_parents();\n        assert_invariant();\n    }\n\n    /// @brief create a container (array or object) from an initializer list\n    /// @sa https://json.nlohmann.me/api/basic_json/basic_json/\n    basic_json(initializer_list_t init,\n               bool type_deduction = true,\n               value_t manual_type = value_t::array)\n    {\n        // check if each element is an array with two elements whose first\n        // element is a string\n        bool is_an_object = std::all_of(init.begin(), 
init.end(),\n                                        [](const detail::json_ref<basic_json>& element_ref)\n        {\n            // The cast is to ensure op[size_type] is called, bearing in mind size_type may not be int;\n            // (many string types can be constructed from 0 via its null-pointer guise, so we get a\n            // broken call to op[key_type], the wrong semantics and a 4804 warning on Windows)\n            return element_ref->is_array() && element_ref->size() == 2 && (*element_ref)[static_cast<size_type>(0)].is_string();\n        });\n\n        // adjust type if type deduction is not wanted\n        if (!type_deduction)\n        {\n            // if array is wanted, do not create an object though possible\n            if (manual_type == value_t::array)\n            {\n                is_an_object = false;\n            }\n\n            // if object is wanted but impossible, throw an exception\n            if (JSON_HEDLEY_UNLIKELY(manual_type == value_t::object && !is_an_object))\n            {\n                JSON_THROW(type_error::create(301, \"cannot create object from initializer list\", nullptr));\n            }\n        }\n\n        if (is_an_object)\n        {\n            // the initializer list is a list of pairs -> create object\n            m_data.m_type = value_t::object;\n            m_data.m_value = value_t::object;\n\n            for (auto& element_ref : init)\n            {\n                auto element = element_ref.moved_or_copied();\n                m_data.m_value.object->emplace(\n                    std::move(*((*element.m_data.m_value.array)[0].m_data.m_value.string)),\n                    std::move((*element.m_data.m_value.array)[1]));\n            }\n        }\n        else\n        {\n            // the initializer list describes an array -> create array\n            m_data.m_type = value_t::array;\n            m_data.m_value.array = create<array_t>(init.begin(), init.end());\n        }\n\n        set_parents();\n        
assert_invariant();\n    }\n\n    /// @brief explicitly create a binary array (without subtype)\n    /// @sa https://json.nlohmann.me/api/basic_json/binary/\n    JSON_HEDLEY_WARN_UNUSED_RESULT\n    static basic_json binary(const typename binary_t::container_type& init)\n    {\n        auto res = basic_json();\n        res.m_data.m_type = value_t::binary;\n        res.m_data.m_value = init;\n        return res;\n    }\n\n    /// @brief explicitly create a binary array (with subtype)\n    /// @sa https://json.nlohmann.me/api/basic_json/binary/\n    JSON_HEDLEY_WARN_UNUSED_RESULT\n    static basic_json binary(const typename binary_t::container_type& init, typename binary_t::subtype_type subtype)\n    {\n        auto res = basic_json();\n        res.m_data.m_type = value_t::binary;\n        res.m_data.m_value = binary_t(init, subtype);\n        return res;\n    }\n\n    /// @brief explicitly create a binary array\n    /// @sa https://json.nlohmann.me/api/basic_json/binary/\n    JSON_HEDLEY_WARN_UNUSED_RESULT\n    static basic_json binary(typename binary_t::container_type&& init)\n    {\n        auto res = basic_json();\n        res.m_data.m_type = value_t::binary;\n        res.m_data.m_value = std::move(init);\n        return res;\n    }\n\n    /// @brief explicitly create a binary array (with subtype)\n    /// @sa https://json.nlohmann.me/api/basic_json/binary/\n    JSON_HEDLEY_WARN_UNUSED_RESULT\n    static basic_json binary(typename binary_t::container_type&& init, typename binary_t::subtype_type subtype)\n    {\n        auto res = basic_json();\n        res.m_data.m_type = value_t::binary;\n        res.m_data.m_value = binary_t(std::move(init), subtype);\n        return res;\n    }\n\n    /// @brief explicitly create an array from an initializer list\n    /// @sa https://json.nlohmann.me/api/basic_json/array/\n    JSON_HEDLEY_WARN_UNUSED_RESULT\n    static basic_json array(initializer_list_t init = {})\n    {\n        return basic_json(init, false, 
value_t::array);\n    }\n\n    /// @brief explicitly create an object from an initializer list\n    /// @sa https://json.nlohmann.me/api/basic_json/object/\n    JSON_HEDLEY_WARN_UNUSED_RESULT\n    static basic_json object(initializer_list_t init = {})\n    {\n        return basic_json(init, false, value_t::object);\n    }\n\n    /// @brief construct an array with count copies of given value\n    /// @sa https://json.nlohmann.me/api/basic_json/basic_json/\n    basic_json(size_type cnt, const basic_json& val):\n        m_data{cnt, val}\n    {\n        set_parents();\n        assert_invariant();\n    }\n\n    /// @brief construct a JSON container given an iterator range\n    /// @sa https://json.nlohmann.me/api/basic_json/basic_json/\n    template < class InputIT, typename std::enable_if <\n                   std::is_same<InputIT, typename basic_json_t::iterator>::value ||\n                   std::is_same<InputIT, typename basic_json_t::const_iterator>::value, int >::type = 0 >\n    basic_json(InputIT first, InputIT last)\n    {\n        JSON_ASSERT(first.m_object != nullptr);\n        JSON_ASSERT(last.m_object != nullptr);\n\n        // make sure iterator fits the current value\n        if (JSON_HEDLEY_UNLIKELY(first.m_object != last.m_object))\n        {\n            JSON_THROW(invalid_iterator::create(201, \"iterators are not compatible\", nullptr));\n        }\n\n        // copy type from first iterator\n        m_data.m_type = first.m_object->m_data.m_type;\n\n        // check if iterator range is complete for primitive values\n        switch (m_data.m_type)\n        {\n            case value_t::boolean:\n            case value_t::number_float:\n            case value_t::number_integer:\n            case value_t::number_unsigned:\n            case value_t::string:\n            {\n                if (JSON_HEDLEY_UNLIKELY(!first.m_it.primitive_iterator.is_begin()\n                                         || !last.m_it.primitive_iterator.is_end()))\n                
{\n                    JSON_THROW(invalid_iterator::create(204, \"iterators out of range\", first.m_object));\n                }\n                break;\n            }\n\n            case value_t::null:\n            case value_t::object:\n            case value_t::array:\n            case value_t::binary:\n            case value_t::discarded:\n            default:\n                break;\n        }\n\n        switch (m_data.m_type)\n        {\n            case value_t::number_integer:\n            {\n                m_data.m_value.number_integer = first.m_object->m_data.m_value.number_integer;\n                break;\n            }\n\n            case value_t::number_unsigned:\n            {\n                m_data.m_value.number_unsigned = first.m_object->m_data.m_value.number_unsigned;\n                break;\n            }\n\n            case value_t::number_float:\n            {\n                m_data.m_value.number_float = first.m_object->m_data.m_value.number_float;\n                break;\n            }\n\n            case value_t::boolean:\n            {\n                m_data.m_value.boolean = first.m_object->m_data.m_value.boolean;\n                break;\n            }\n\n            case value_t::string:\n            {\n                m_data.m_value = *first.m_object->m_data.m_value.string;\n                break;\n            }\n\n            case value_t::object:\n            {\n                m_data.m_value.object = create<object_t>(first.m_it.object_iterator,\n                                        last.m_it.object_iterator);\n                break;\n            }\n\n            case value_t::array:\n            {\n                m_data.m_value.array = create<array_t>(first.m_it.array_iterator,\n                                                       last.m_it.array_iterator);\n                break;\n            }\n\n            case value_t::binary:\n            {\n                m_data.m_value = *first.m_object->m_data.m_value.binary;\n     
           break;\n            }\n\n            case value_t::null:\n            case value_t::discarded:\n            default:\n                JSON_THROW(invalid_iterator::create(206, detail::concat(\"cannot construct with iterators from \", first.m_object->type_name()), first.m_object));\n        }\n\n        set_parents();\n        assert_invariant();\n    }\n\n    ///////////////////////////////////////\n    // other constructors and destructor //\n    ///////////////////////////////////////\n\n    template<typename JsonRef,\n             detail::enable_if_t<detail::conjunction<detail::is_json_ref<JsonRef>,\n                                 std::is_same<typename JsonRef::value_type, basic_json>>::value, int> = 0 >\n    basic_json(const JsonRef& ref) : basic_json(ref.moved_or_copied()) {}\n\n    /// @brief copy constructor\n    /// @sa https://json.nlohmann.me/api/basic_json/basic_json/\n    basic_json(const basic_json& other)\n        : json_base_class_t(other)\n    {\n        m_data.m_type = other.m_data.m_type;\n        // check that passed value is valid\n        other.assert_invariant();\n\n        switch (m_data.m_type)\n        {\n            case value_t::object:\n            {\n                m_data.m_value = *other.m_data.m_value.object;\n                break;\n            }\n\n            case value_t::array:\n            {\n                m_data.m_value = *other.m_data.m_value.array;\n                break;\n            }\n\n            case value_t::string:\n            {\n                m_data.m_value = *other.m_data.m_value.string;\n                break;\n            }\n\n            case value_t::boolean:\n            {\n                m_data.m_value = other.m_data.m_value.boolean;\n                break;\n            }\n\n            case value_t::number_integer:\n            {\n                m_data.m_value = other.m_data.m_value.number_integer;\n                break;\n            }\n\n            case value_t::number_unsigned:\n         
   {\n                m_data.m_value = other.m_data.m_value.number_unsigned;\n                break;\n            }\n\n            case value_t::number_float:\n            {\n                m_data.m_value = other.m_data.m_value.number_float;\n                break;\n            }\n\n            case value_t::binary:\n            {\n                m_data.m_value = *other.m_data.m_value.binary;\n                break;\n            }\n\n            case value_t::null:\n            case value_t::discarded:\n            default:\n                break;\n        }\n\n        set_parents();\n        assert_invariant();\n    }\n\n    /// @brief move constructor\n    /// @sa https://json.nlohmann.me/api/basic_json/basic_json/\n    basic_json(basic_json&& other) noexcept\n        : json_base_class_t(std::forward<json_base_class_t>(other)),\n          m_data(std::move(other.m_data))\n    {\n        // check that passed value is valid\n        other.assert_invariant(false);\n\n        // invalidate payload\n        other.m_data.m_type = value_t::null;\n        other.m_data.m_value = {};\n\n        set_parents();\n        assert_invariant();\n    }\n\n    /// @brief copy assignment\n    /// @sa https://json.nlohmann.me/api/basic_json/operator=/\n    basic_json& operator=(basic_json other) noexcept (\n        std::is_nothrow_move_constructible<value_t>::value&&\n        std::is_nothrow_move_assignable<value_t>::value&&\n        std::is_nothrow_move_constructible<json_value>::value&&\n        std::is_nothrow_move_assignable<json_value>::value&&\n        std::is_nothrow_move_assignable<json_base_class_t>::value\n    )\n    {\n        // check that passed value is valid\n        other.assert_invariant();\n\n        using std::swap;\n        swap(m_data.m_type, other.m_data.m_type);\n        swap(m_data.m_value, other.m_data.m_value);\n        json_base_class_t::operator=(std::move(other));\n\n        set_parents();\n        assert_invariant();\n        return *this;\n    }\n\n    
/// @brief destructor\n    /// @sa https://json.nlohmann.me/api/basic_json/~basic_json/\n    ~basic_json() noexcept\n    {\n        assert_invariant(false);\n    }\n\n    /// @}\n\n  public:\n    ///////////////////////\n    // object inspection //\n    ///////////////////////\n\n    /// @name object inspection\n    /// Functions to inspect the type of a JSON value.\n    /// @{\n\n    /// @brief serialization\n    /// @sa https://json.nlohmann.me/api/basic_json/dump/\n    string_t dump(const int indent = -1,\n                  const char indent_char = ' ',\n                  const bool ensure_ascii = false,\n                  const error_handler_t error_handler = error_handler_t::strict) const\n    {\n        string_t result;\n        serializer s(detail::output_adapter<char, string_t>(result), indent_char, error_handler);\n\n        if (indent >= 0)\n        {\n            s.dump(*this, true, ensure_ascii, static_cast<unsigned int>(indent));\n        }\n        else\n        {\n            s.dump(*this, false, ensure_ascii, 0);\n        }\n\n        return result;\n    }\n\n    /// @brief return the type of the JSON value (explicit)\n    /// @sa https://json.nlohmann.me/api/basic_json/type/\n    constexpr value_t type() const noexcept\n    {\n        return m_data.m_type;\n    }\n\n    /// @brief return whether type is primitive\n    /// @sa https://json.nlohmann.me/api/basic_json/is_primitive/\n    constexpr bool is_primitive() const noexcept\n    {\n        return is_null() || is_string() || is_boolean() || is_number() || is_binary();\n    }\n\n    /// @brief return whether type is structured\n    /// @sa https://json.nlohmann.me/api/basic_json/is_structured/\n    constexpr bool is_structured() const noexcept\n    {\n        return is_array() || is_object();\n    }\n\n    /// @brief return whether value is null\n    /// @sa https://json.nlohmann.me/api/basic_json/is_null/\n    constexpr bool is_null() const noexcept\n    {\n        return m_data.m_type == 
value_t::null;\n    }\n\n    /// @brief return whether value is a boolean\n    /// @sa https://json.nlohmann.me/api/basic_json/is_boolean/\n    constexpr bool is_boolean() const noexcept\n    {\n        return m_data.m_type == value_t::boolean;\n    }\n\n    /// @brief return whether value is a number\n    /// @sa https://json.nlohmann.me/api/basic_json/is_number/\n    constexpr bool is_number() const noexcept\n    {\n        return is_number_integer() || is_number_float();\n    }\n\n    /// @brief return whether value is an integer number\n    /// @sa https://json.nlohmann.me/api/basic_json/is_number_integer/\n    constexpr bool is_number_integer() const noexcept\n    {\n        return m_data.m_type == value_t::number_integer || m_data.m_type == value_t::number_unsigned;\n    }\n\n    /// @brief return whether value is an unsigned integer number\n    /// @sa https://json.nlohmann.me/api/basic_json/is_number_unsigned/\n    constexpr bool is_number_unsigned() const noexcept\n    {\n        return m_data.m_type == value_t::number_unsigned;\n    }\n\n    /// @brief return whether value is a floating-point number\n    /// @sa https://json.nlohmann.me/api/basic_json/is_number_float/\n    constexpr bool is_number_float() const noexcept\n    {\n        return m_data.m_type == value_t::number_float;\n    }\n\n    /// @brief return whether value is an object\n    /// @sa https://json.nlohmann.me/api/basic_json/is_object/\n    constexpr bool is_object() const noexcept\n    {\n        return m_data.m_type == value_t::object;\n    }\n\n    /// @brief return whether value is an array\n    /// @sa https://json.nlohmann.me/api/basic_json/is_array/\n    constexpr bool is_array() const noexcept\n    {\n        return m_data.m_type == value_t::array;\n    }\n\n    /// @brief return whether value is a string\n    /// @sa https://json.nlohmann.me/api/basic_json/is_string/\n    constexpr bool is_string() const noexcept\n    {\n        return m_data.m_type == value_t::string;\n    }\n\n 
   /// @brief return whether value is a binary array\n    /// @sa https://json.nlohmann.me/api/basic_json/is_binary/\n    constexpr bool is_binary() const noexcept\n    {\n        return m_data.m_type == value_t::binary;\n    }\n\n    /// @brief return whether value is discarded\n    /// @sa https://json.nlohmann.me/api/basic_json/is_discarded/\n    constexpr bool is_discarded() const noexcept\n    {\n        return m_data.m_type == value_t::discarded;\n    }\n\n    /// @brief return the type of the JSON value (implicit)\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_value_t/\n    constexpr operator value_t() const noexcept\n    {\n        return m_data.m_type;\n    }\n\n    /// @}\n\n  private:\n    //////////////////\n    // value access //\n    //////////////////\n\n    /// get a boolean (explicit)\n    boolean_t get_impl(boolean_t* /*unused*/) const\n    {\n        if (JSON_HEDLEY_LIKELY(is_boolean()))\n        {\n            return m_data.m_value.boolean;\n        }\n\n        JSON_THROW(type_error::create(302, detail::concat(\"type must be boolean, but is \", type_name()), this));\n    }\n\n    /// get a pointer to the value (object)\n    object_t* get_impl_ptr(object_t* /*unused*/) noexcept\n    {\n        return is_object() ? m_data.m_value.object : nullptr;\n    }\n\n    /// get a pointer to the value (object)\n    constexpr const object_t* get_impl_ptr(const object_t* /*unused*/) const noexcept\n    {\n        return is_object() ? m_data.m_value.object : nullptr;\n    }\n\n    /// get a pointer to the value (array)\n    array_t* get_impl_ptr(array_t* /*unused*/) noexcept\n    {\n        return is_array() ? m_data.m_value.array : nullptr;\n    }\n\n    /// get a pointer to the value (array)\n    constexpr const array_t* get_impl_ptr(const array_t* /*unused*/) const noexcept\n    {\n        return is_array() ? 
m_data.m_value.array : nullptr;\n    }\n\n    /// get a pointer to the value (string)\n    string_t* get_impl_ptr(string_t* /*unused*/) noexcept\n    {\n        return is_string() ? m_data.m_value.string : nullptr;\n    }\n\n    /// get a pointer to the value (string)\n    constexpr const string_t* get_impl_ptr(const string_t* /*unused*/) const noexcept\n    {\n        return is_string() ? m_data.m_value.string : nullptr;\n    }\n\n    /// get a pointer to the value (boolean)\n    boolean_t* get_impl_ptr(boolean_t* /*unused*/) noexcept\n    {\n        return is_boolean() ? &m_data.m_value.boolean : nullptr;\n    }\n\n    /// get a pointer to the value (boolean)\n    constexpr const boolean_t* get_impl_ptr(const boolean_t* /*unused*/) const noexcept\n    {\n        return is_boolean() ? &m_data.m_value.boolean : nullptr;\n    }\n\n    /// get a pointer to the value (integer number)\n    number_integer_t* get_impl_ptr(number_integer_t* /*unused*/) noexcept\n    {\n        return is_number_integer() ? &m_data.m_value.number_integer : nullptr;\n    }\n\n    /// get a pointer to the value (integer number)\n    constexpr const number_integer_t* get_impl_ptr(const number_integer_t* /*unused*/) const noexcept\n    {\n        return is_number_integer() ? &m_data.m_value.number_integer : nullptr;\n    }\n\n    /// get a pointer to the value (unsigned number)\n    number_unsigned_t* get_impl_ptr(number_unsigned_t* /*unused*/) noexcept\n    {\n        return is_number_unsigned() ? &m_data.m_value.number_unsigned : nullptr;\n    }\n\n    /// get a pointer to the value (unsigned number)\n    constexpr const number_unsigned_t* get_impl_ptr(const number_unsigned_t* /*unused*/) const noexcept\n    {\n        return is_number_unsigned() ? &m_data.m_value.number_unsigned : nullptr;\n    }\n\n    /// get a pointer to the value (floating-point number)\n    number_float_t* get_impl_ptr(number_float_t* /*unused*/) noexcept\n    {\n        return is_number_float() ? 
&m_data.m_value.number_float : nullptr;\n    }\n\n    /// get a pointer to the value (floating-point number)\n    constexpr const number_float_t* get_impl_ptr(const number_float_t* /*unused*/) const noexcept\n    {\n        return is_number_float() ? &m_data.m_value.number_float : nullptr;\n    }\n\n    /// get a pointer to the value (binary)\n    binary_t* get_impl_ptr(binary_t* /*unused*/) noexcept\n    {\n        return is_binary() ? m_data.m_value.binary : nullptr;\n    }\n\n    /// get a pointer to the value (binary)\n    constexpr const binary_t* get_impl_ptr(const binary_t* /*unused*/) const noexcept\n    {\n        return is_binary() ? m_data.m_value.binary : nullptr;\n    }\n\n    /*!\n    @brief helper function to implement get_ref()\n\n    This function helps to implement get_ref() without code duplication for\n    const and non-const overloads\n\n    @tparam ThisType will be deduced as `basic_json` or `const basic_json`\n\n    @throw type_error.303 if ReferenceType does not match underlying value\n    type of the current JSON\n    */\n    template<typename ReferenceType, typename ThisType>\n    static ReferenceType get_ref_impl(ThisType& obj)\n    {\n        // delegate the call to get_ptr<>()\n        auto* ptr = obj.template get_ptr<typename std::add_pointer<ReferenceType>::type>();\n\n        if (JSON_HEDLEY_LIKELY(ptr != nullptr))\n        {\n            return *ptr;\n        }\n\n        JSON_THROW(type_error::create(303, detail::concat(\"incompatible ReferenceType for get_ref, actual type is \", obj.type_name()), &obj));\n    }\n\n  public:\n    /// @name value access\n    /// Direct access to the stored value of a JSON value.\n    /// @{\n\n    /// @brief get a pointer value (implicit)\n    /// @sa https://json.nlohmann.me/api/basic_json/get_ptr/\n    template<typename PointerType, typename std::enable_if<\n                 std::is_pointer<PointerType>::value, int>::type = 0>\n    auto get_ptr() noexcept -> 
decltype(std::declval<basic_json_t&>().get_impl_ptr(std::declval<PointerType>()))\n    {\n        // delegate the call to get_impl_ptr<>()\n        return get_impl_ptr(static_cast<PointerType>(nullptr));\n    }\n\n    /// @brief get a pointer value (implicit)\n    /// @sa https://json.nlohmann.me/api/basic_json/get_ptr/\n    template < typename PointerType, typename std::enable_if <\n                   std::is_pointer<PointerType>::value&&\n                   std::is_const<typename std::remove_pointer<PointerType>::type>::value, int >::type = 0 >\n    constexpr auto get_ptr() const noexcept -> decltype(std::declval<const basic_json_t&>().get_impl_ptr(std::declval<PointerType>()))\n    {\n        // delegate the call to get_impl_ptr<>() const\n        return get_impl_ptr(static_cast<PointerType>(nullptr));\n    }\n\n  private:\n    /*!\n    @brief get a value (explicit)\n\n    Explicit type conversion between the JSON value and a compatible value\n    which is [CopyConstructible](https://en.cppreference.com/w/cpp/named_req/CopyConstructible)\n    and [DefaultConstructible](https://en.cppreference.com/w/cpp/named_req/DefaultConstructible).\n    The value is converted by calling the @ref json_serializer<ValueType>\n    `from_json()` method.\n\n    The function is equivalent to executing\n    @code {.cpp}\n    ValueType ret;\n    JSONSerializer<ValueType>::from_json(*this, ret);\n    return ret;\n    @endcode\n\n    This overload is chosen if:\n    - @a ValueType is not @ref basic_json,\n    - @ref json_serializer<ValueType> has a `from_json()` method of the form\n      `void from_json(const basic_json&, ValueType&)`, and\n    - @ref json_serializer<ValueType> does not have a `from_json()` method of\n      the form `ValueType from_json(const basic_json&)`\n\n    @tparam ValueType the returned value type\n\n    @return copy of the JSON value, converted to @a ValueType\n\n    @throw what @ref json_serializer<ValueType> `from_json()` method throws\n\n    @liveexample{The 
example below shows several conversions from JSON values\n    to other types. There are a few things to note: (1) Floating-point numbers can\n    be converted to integers\\, (2) A JSON array can be converted to a standard\n    `std::vector<short>`\\, (3) A JSON object can be converted to C++\n    associative containers such as `std::unordered_map<std::string\\,\n    json>`.,get__ValueType_const}\n\n    @since version 2.1.0\n    */\n    template < typename ValueType,\n               detail::enable_if_t <\n                   detail::is_default_constructible<ValueType>::value&&\n                   detail::has_from_json<basic_json_t, ValueType>::value,\n                   int > = 0 >\n    ValueType get_impl(detail::priority_tag<0> /*unused*/) const noexcept(noexcept(\n                JSONSerializer<ValueType>::from_json(std::declval<const basic_json_t&>(), std::declval<ValueType&>())))\n    {\n        auto ret = ValueType();\n        JSONSerializer<ValueType>::from_json(*this, ret);\n        return ret;\n    }\n\n    /*!\n    @brief get a value (explicit); special case\n\n    Explicit type conversion between the JSON value and a compatible value\n    which is **not** [CopyConstructible](https://en.cppreference.com/w/cpp/named_req/CopyConstructible)\n    and **not** [DefaultConstructible](https://en.cppreference.com/w/cpp/named_req/DefaultConstructible).\n    The value is converted by calling the @ref json_serializer<ValueType>\n    `from_json()` method.\n\n    The function is equivalent to executing\n    @code {.cpp}\n    return JSONSerializer<ValueType>::from_json(*this);\n    @endcode\n\n    This overload is chosen if:\n    - @a ValueType is not @ref basic_json and\n    - @ref json_serializer<ValueType> has a `from_json()` method of the form\n      `ValueType from_json(const basic_json&)`\n\n    @note If @ref json_serializer<ValueType> has both overloads of\n    `from_json()`, this one is chosen.\n\n    @tparam ValueType the returned value type\n\n    @return copy of 
the JSON value, converted to @a ValueType\n\n    @throw what @ref json_serializer<ValueType> `from_json()` method throws\n\n    @since version 2.1.0\n    */\n    template < typename ValueType,\n               detail::enable_if_t <\n                   detail::has_non_default_from_json<basic_json_t, ValueType>::value,\n                   int > = 0 >\n    ValueType get_impl(detail::priority_tag<1> /*unused*/) const noexcept(noexcept(\n                JSONSerializer<ValueType>::from_json(std::declval<const basic_json_t&>())))\n    {\n        return JSONSerializer<ValueType>::from_json(*this);\n    }\n\n    /*!\n    @brief get special-case overload\n\n    This overload converts the current @ref basic_json into a different\n    @ref basic_json type\n\n    @tparam BasicJsonType == @ref basic_json\n\n    @return a copy of *this, converted into @a BasicJsonType\n\n    @complexity Depending on the implementation of the called `from_json()`\n                method.\n\n    @since version 3.2.0\n    */\n    template < typename BasicJsonType,\n               detail::enable_if_t <\n                   detail::is_basic_json<BasicJsonType>::value,\n                   int > = 0 >\n    BasicJsonType get_impl(detail::priority_tag<2> /*unused*/) const\n    {\n        return *this;\n    }\n\n    /*!\n    @brief get special-case overload\n\n    This overload avoids a lot of template boilerplate; it can be seen as the\n    identity method\n\n    @tparam BasicJsonType == @ref basic_json\n\n    @return a copy of *this\n\n    @complexity Constant.\n\n    @since version 2.1.0\n    */\n    template<typename BasicJsonType,\n             detail::enable_if_t<\n                 std::is_same<BasicJsonType, basic_json_t>::value,\n                 int> = 0>\n    basic_json get_impl(detail::priority_tag<3> /*unused*/) const\n    {\n        return *this;\n    }\n\n    /*!\n    @brief get a pointer value (explicit)\n    @copydoc get()\n    */\n    template<typename PointerType,\n             
detail::enable_if_t<\n                 std::is_pointer<PointerType>::value,\n                 int> = 0>\n    constexpr auto get_impl(detail::priority_tag<4> /*unused*/) const noexcept\n    -> decltype(std::declval<const basic_json_t&>().template get_ptr<PointerType>())\n    {\n        // delegate the call to get_ptr\n        return get_ptr<PointerType>();\n    }\n\n  public:\n    /*!\n    @brief get a (pointer) value (explicit)\n\n    Performs explicit type conversion between the JSON value and a compatible value if required.\n\n    - If the requested type is a pointer to the internally stored JSON value that pointer is returned.\n    No copies are made.\n\n    - If the requested type is the current @ref basic_json, or a different @ref basic_json convertible\n    from the current @ref basic_json.\n\n    - Otherwise the value is converted by calling the @ref json_serializer<ValueType> `from_json()`\n    method.\n\n    @tparam ValueTypeCV the provided value type\n    @tparam ValueType the returned value type\n\n    @return copy of the JSON value, converted to @tparam ValueType if necessary\n\n    @throw what @ref json_serializer<ValueType> `from_json()` method throws if conversion is required\n\n    @since version 2.1.0\n    */\n    template < typename ValueTypeCV, typename ValueType = detail::uncvref_t<ValueTypeCV>>\n#if defined(JSON_HAS_CPP_14)\n    constexpr\n#endif\n    auto get() const noexcept(\n    noexcept(std::declval<const basic_json_t&>().template get_impl<ValueType>(detail::priority_tag<4> {})))\n    -> decltype(std::declval<const basic_json_t&>().template get_impl<ValueType>(detail::priority_tag<4> {}))\n    {\n        // we cannot static_assert on ValueTypeCV being non-const, because\n        // there is support for get<const basic_json_t>(), which is why we\n        // still need the uncvref\n        static_assert(!std::is_reference<ValueTypeCV>::value,\n                      \"get() cannot be used with reference types, you might want to use 
get_ref()\");\n        return get_impl<ValueType>(detail::priority_tag<4> {});\n    }\n\n    /*!\n    @brief get a pointer value (explicit)\n\n    Explicit pointer access to the internally stored JSON value. No copies are\n    made.\n\n    @warning The pointer becomes invalid if the underlying JSON object\n    changes.\n\n    @tparam PointerType pointer type; must be a pointer to @ref array_t, @ref\n    object_t, @ref string_t, @ref boolean_t, @ref number_integer_t,\n    @ref number_unsigned_t, or @ref number_float_t.\n\n    @return pointer to the internally stored JSON value if the requested\n    pointer type @a PointerType fits to the JSON value; `nullptr` otherwise\n\n    @complexity Constant.\n\n    @liveexample{The example below shows how pointers to internal values of a\n    JSON value can be requested. Note that no type conversions are made and a\n    `nullptr` is returned if the value and the requested pointer type do not\n    match.,get__PointerType}\n\n    @sa see @ref get_ptr() for explicit pointer-member access\n\n    @since version 1.0.0\n    */\n    template<typename PointerType, typename std::enable_if<\n                 std::is_pointer<PointerType>::value, int>::type = 0>\n    auto get() noexcept -> decltype(std::declval<basic_json_t&>().template get_ptr<PointerType>())\n    {\n        // delegate the call to get_ptr\n        return get_ptr<PointerType>();\n    }\n\n    /// @brief get a value (explicit)\n    /// @sa https://json.nlohmann.me/api/basic_json/get_to/\n    template < typename ValueType,\n               detail::enable_if_t <\n                   !detail::is_basic_json<ValueType>::value&&\n                   detail::has_from_json<basic_json_t, ValueType>::value,\n                   int > = 0 >\n    ValueType & get_to(ValueType& v) const noexcept(noexcept(\n                JSONSerializer<ValueType>::from_json(std::declval<const basic_json_t&>(), v)))\n    {\n        JSONSerializer<ValueType>::from_json(*this, v);\n        return v;\n    
}\n\n    // specialization to allow calling get_to with a basic_json value\n    // see https://github.com/nlohmann/json/issues/2175\n    template<typename ValueType,\n             detail::enable_if_t <\n                 detail::is_basic_json<ValueType>::value,\n                 int> = 0>\n    ValueType & get_to(ValueType& v) const\n    {\n        v = *this;\n        return v;\n    }\n\n    template <\n        typename T, std::size_t N,\n        typename Array = T (&)[N], // NOLINT(cppcoreguidelines-avoid-c-arrays,hicpp-avoid-c-arrays,modernize-avoid-c-arrays)\n        detail::enable_if_t <\n            detail::has_from_json<basic_json_t, Array>::value, int > = 0 >\n    Array get_to(T (&v)[N]) const // NOLINT(cppcoreguidelines-avoid-c-arrays,hicpp-avoid-c-arrays,modernize-avoid-c-arrays)\n    noexcept(noexcept(JSONSerializer<Array>::from_json(\n                          std::declval<const basic_json_t&>(), v)))\n    {\n        JSONSerializer<Array>::from_json(*this, v);\n        return v;\n    }\n\n    /// @brief get a reference value (implicit)\n    /// @sa https://json.nlohmann.me/api/basic_json/get_ref/\n    template<typename ReferenceType, typename std::enable_if<\n                 std::is_reference<ReferenceType>::value, int>::type = 0>\n    ReferenceType get_ref()\n    {\n        // delegate call to get_ref_impl\n        return get_ref_impl<ReferenceType>(*this);\n    }\n\n    /// @brief get a reference value (implicit)\n    /// @sa https://json.nlohmann.me/api/basic_json/get_ref/\n    template < typename ReferenceType, typename std::enable_if <\n                   std::is_reference<ReferenceType>::value&&\n                   std::is_const<typename std::remove_reference<ReferenceType>::type>::value, int >::type = 0 >\n    ReferenceType get_ref() const\n    {\n        // delegate call to get_ref_impl\n        return get_ref_impl<ReferenceType>(*this);\n    }\n\n    /*!\n    @brief get a value (implicit)\n\n    Implicit type conversion between the JSON value and 
a compatible value.\n    The call is realized by calling @ref get() const.\n\n    @tparam ValueType non-pointer type compatible to the JSON value, for\n    instance `int` for JSON integer numbers, `bool` for JSON booleans, or\n    `std::vector` types for JSON arrays. The character type of @ref string_t\n    as well as an initializer list of this type is excluded to avoid\n    ambiguities as these types implicitly convert to `std::string`.\n\n    @return copy of the JSON value, converted to type @a ValueType\n\n    @throw type_error.302 in case passed type @a ValueType is incompatible\n    to the JSON value type (e.g., the JSON value is of type boolean, but a\n    string is requested); see example below\n\n    @complexity Linear in the size of the JSON value.\n\n    @liveexample{The example below shows several conversions from JSON values\n    to other types. There are a few things to note: (1) Floating-point numbers can\n    be converted to integers\\, (2) A JSON array can be converted to a standard\n    `std::vector<short>`\\, (3) A JSON object can be converted to C++\n    associative containers such as `std::unordered_map<std::string\\,\n    json>`.,operator__ValueType}\n\n    @since version 1.0.0\n    */\n    template < typename ValueType, typename std::enable_if <\n                   detail::conjunction <\n                       detail::negation<std::is_pointer<ValueType>>,\n                       detail::negation<std::is_same<ValueType, std::nullptr_t>>,\n                       detail::negation<std::is_same<ValueType, detail::json_ref<basic_json>>>,\n                                        detail::negation<std::is_same<ValueType, typename string_t::value_type>>,\n                                        detail::negation<detail::is_basic_json<ValueType>>,\n                                        detail::negation<std::is_same<ValueType, std::initializer_list<typename string_t::value_type>>>,\n#if defined(JSON_HAS_CPP_17) && (defined(__GNUC__) || (defined(_MSC_VER) && 
_MSC_VER >= 1910 && _MSC_VER <= 1914))\n                                                detail::negation<std::is_same<ValueType, std::string_view>>,\n#endif\n#if defined(JSON_HAS_CPP_17) && JSON_HAS_STATIC_RTTI\n                                                detail::negation<std::is_same<ValueType, std::any>>,\n#endif\n                                                detail::is_detected_lazy<detail::get_template_function, const basic_json_t&, ValueType>\n                                                >::value, int >::type = 0 >\n                                        JSON_EXPLICIT operator ValueType() const\n    {\n        // delegate the call to get<>() const\n        return get<ValueType>();\n    }\n\n    /// @brief get a binary value\n    /// @sa https://json.nlohmann.me/api/basic_json/get_binary/\n    binary_t& get_binary()\n    {\n        if (!is_binary())\n        {\n            JSON_THROW(type_error::create(302, detail::concat(\"type must be binary, but is \", type_name()), this));\n        }\n\n        return *get_ptr<binary_t*>();\n    }\n\n    /// @brief get a binary value\n    /// @sa https://json.nlohmann.me/api/basic_json/get_binary/\n    const binary_t& get_binary() const\n    {\n        if (!is_binary())\n        {\n            JSON_THROW(type_error::create(302, detail::concat(\"type must be binary, but is \", type_name()), this));\n        }\n\n        return *get_ptr<const binary_t*>();\n    }\n\n    /// @}\n\n    ////////////////////\n    // element access //\n    ////////////////////\n\n    /// @name element access\n    /// Access to the JSON value.\n    /// @{\n\n    /// @brief access specified array element with bounds checking\n    /// @sa https://json.nlohmann.me/api/basic_json/at/\n    reference at(size_type idx)\n    {\n        // at only works for arrays\n        if (JSON_HEDLEY_LIKELY(is_array()))\n        {\n            JSON_TRY\n            {\n                return set_parent(m_data.m_value.array->at(idx));\n            }\n           
 JSON_CATCH (std::out_of_range&)\n            {\n                // create better exception explanation\n                JSON_THROW(out_of_range::create(401, detail::concat(\"array index \", std::to_string(idx), \" is out of range\"), this));\n            }\n        }\n        else\n        {\n            JSON_THROW(type_error::create(304, detail::concat(\"cannot use at() with \", type_name()), this));\n        }\n    }\n\n    /// @brief access specified array element with bounds checking\n    /// @sa https://json.nlohmann.me/api/basic_json/at/\n    const_reference at(size_type idx) const\n    {\n        // at only works for arrays\n        if (JSON_HEDLEY_LIKELY(is_array()))\n        {\n            JSON_TRY\n            {\n                return m_data.m_value.array->at(idx);\n            }\n            JSON_CATCH (std::out_of_range&)\n            {\n                // create better exception explanation\n                JSON_THROW(out_of_range::create(401, detail::concat(\"array index \", std::to_string(idx), \" is out of range\"), this));\n            }\n        }\n        else\n        {\n            JSON_THROW(type_error::create(304, detail::concat(\"cannot use at() with \", type_name()), this));\n        }\n    }\n\n    /// @brief access specified object element with bounds checking\n    /// @sa https://json.nlohmann.me/api/basic_json/at/\n    reference at(const typename object_t::key_type& key)\n    {\n        // at only works for objects\n        if (JSON_HEDLEY_UNLIKELY(!is_object()))\n        {\n            JSON_THROW(type_error::create(304, detail::concat(\"cannot use at() with \", type_name()), this));\n        }\n\n        auto it = m_data.m_value.object->find(key);\n        if (it == m_data.m_value.object->end())\n        {\n            JSON_THROW(out_of_range::create(403, detail::concat(\"key '\", key, \"' not found\"), this));\n        }\n        return set_parent(it->second);\n    }\n\n    /// @brief access specified object element with bounds 
checking\n    /// @sa https://json.nlohmann.me/api/basic_json/at/\n    template<class KeyType, detail::enable_if_t<\n                 detail::is_usable_as_basic_json_key_type<basic_json_t, KeyType>::value, int> = 0>\n    reference at(KeyType && key)\n    {\n        // at only works for objects\n        if (JSON_HEDLEY_UNLIKELY(!is_object()))\n        {\n            JSON_THROW(type_error::create(304, detail::concat(\"cannot use at() with \", type_name()), this));\n        }\n\n        auto it = m_data.m_value.object->find(std::forward<KeyType>(key));\n        if (it == m_data.m_value.object->end())\n        {\n            JSON_THROW(out_of_range::create(403, detail::concat(\"key '\", string_t(std::forward<KeyType>(key)), \"' not found\"), this));\n        }\n        return set_parent(it->second);\n    }\n\n    /// @brief access specified object element with bounds checking\n    /// @sa https://json.nlohmann.me/api/basic_json/at/\n    const_reference at(const typename object_t::key_type& key) const\n    {\n        // at only works for objects\n        if (JSON_HEDLEY_UNLIKELY(!is_object()))\n        {\n            JSON_THROW(type_error::create(304, detail::concat(\"cannot use at() with \", type_name()), this));\n        }\n\n        auto it = m_data.m_value.object->find(key);\n        if (it == m_data.m_value.object->end())\n        {\n            JSON_THROW(out_of_range::create(403, detail::concat(\"key '\", key, \"' not found\"), this));\n        }\n        return it->second;\n    }\n\n    /// @brief access specified object element with bounds checking\n    /// @sa https://json.nlohmann.me/api/basic_json/at/\n    template<class KeyType, detail::enable_if_t<\n                 detail::is_usable_as_basic_json_key_type<basic_json_t, KeyType>::value, int> = 0>\n    const_reference at(KeyType && key) const\n    {\n        // at only works for objects\n        if (JSON_HEDLEY_UNLIKELY(!is_object()))\n        {\n            JSON_THROW(type_error::create(304, 
detail::concat(\"cannot use at() with \", type_name()), this));\n        }\n\n        auto it = m_data.m_value.object->find(std::forward<KeyType>(key));\n        if (it == m_data.m_value.object->end())\n        {\n            JSON_THROW(out_of_range::create(403, detail::concat(\"key '\", string_t(std::forward<KeyType>(key)), \"' not found\"), this));\n        }\n        return it->second;\n    }\n\n    /// @brief access specified array element\n    /// @sa https://json.nlohmann.me/api/basic_json/operator%5B%5D/\n    reference operator[](size_type idx)\n    {\n        // implicitly convert null value to an empty array\n        if (is_null())\n        {\n            m_data.m_type = value_t::array;\n            m_data.m_value.array = create<array_t>();\n            assert_invariant();\n        }\n\n        // operator[] only works for arrays\n        if (JSON_HEDLEY_LIKELY(is_array()))\n        {\n            // fill up array with null values if given idx is outside range\n            if (idx >= m_data.m_value.array->size())\n            {\n#if JSON_DIAGNOSTICS\n                // remember array size & capacity before resizing\n                const auto old_size = m_data.m_value.array->size();\n                const auto old_capacity = m_data.m_value.array->capacity();\n#endif\n                m_data.m_value.array->resize(idx + 1);\n\n#if JSON_DIAGNOSTICS\n                if (JSON_HEDLEY_UNLIKELY(m_data.m_value.array->capacity() != old_capacity))\n                {\n                    // capacity has changed: update all parents\n                    set_parents();\n                }\n                else\n                {\n                    // set parent for values added above\n                    set_parents(begin() + static_cast<typename iterator::difference_type>(old_size), static_cast<typename iterator::difference_type>(idx + 1 - old_size));\n                }\n#endif\n                assert_invariant();\n            }\n\n            return 
m_data.m_value.array->operator[](idx);\n        }\n\n        JSON_THROW(type_error::create(305, detail::concat(\"cannot use operator[] with a numeric argument with \", type_name()), this));\n    }\n\n    /// @brief access specified array element\n    /// @sa https://json.nlohmann.me/api/basic_json/operator%5B%5D/\n    const_reference operator[](size_type idx) const\n    {\n        // const operator[] only works for arrays\n        if (JSON_HEDLEY_LIKELY(is_array()))\n        {\n            return m_data.m_value.array->operator[](idx);\n        }\n\n        JSON_THROW(type_error::create(305, detail::concat(\"cannot use operator[] with a numeric argument with \", type_name()), this));\n    }\n\n    /// @brief access specified object element\n    /// @sa https://json.nlohmann.me/api/basic_json/operator%5B%5D/\n    reference operator[](typename object_t::key_type key)\n    {\n        // implicitly convert null value to an empty object\n        if (is_null())\n        {\n            m_data.m_type = value_t::object;\n            m_data.m_value.object = create<object_t>();\n            assert_invariant();\n        }\n\n        // operator[] only works for objects\n        if (JSON_HEDLEY_LIKELY(is_object()))\n        {\n            auto result = m_data.m_value.object->emplace(std::move(key), nullptr);\n            return set_parent(result.first->second);\n        }\n\n        JSON_THROW(type_error::create(305, detail::concat(\"cannot use operator[] with a string argument with \", type_name()), this));\n    }\n\n    /// @brief access specified object element\n    /// @sa https://json.nlohmann.me/api/basic_json/operator%5B%5D/\n    const_reference operator[](const typename object_t::key_type& key) const\n    {\n        // const operator[] only works for objects\n        if (JSON_HEDLEY_LIKELY(is_object()))\n        {\n            auto it = m_data.m_value.object->find(key);\n            JSON_ASSERT(it != m_data.m_value.object->end());\n            return it->second;\n        
}\n\n        JSON_THROW(type_error::create(305, detail::concat(\"cannot use operator[] with a string argument with \", type_name()), this));\n    }\n\n    // these two functions resolve a (const) char * ambiguity affecting Clang and MSVC\n    // (they seemingly cannot be constrained to resolve the ambiguity)\n    template<typename T>\n    reference operator[](T* key)\n    {\n        return operator[](typename object_t::key_type(key));\n    }\n\n    template<typename T>\n    const_reference operator[](T* key) const\n    {\n        return operator[](typename object_t::key_type(key));\n    }\n\n    /// @brief access specified object element\n    /// @sa https://json.nlohmann.me/api/basic_json/operator%5B%5D/\n    template<class KeyType, detail::enable_if_t<\n                 detail::is_usable_as_basic_json_key_type<basic_json_t, KeyType>::value, int > = 0 >\n    reference operator[](KeyType && key)\n    {\n        // implicitly convert null value to an empty object\n        if (is_null())\n        {\n            m_data.m_type = value_t::object;\n            m_data.m_value.object = create<object_t>();\n            assert_invariant();\n        }\n\n        // operator[] only works for objects\n        if (JSON_HEDLEY_LIKELY(is_object()))\n        {\n            auto result = m_data.m_value.object->emplace(std::forward<KeyType>(key), nullptr);\n            return set_parent(result.first->second);\n        }\n\n        JSON_THROW(type_error::create(305, detail::concat(\"cannot use operator[] with a string argument with \", type_name()), this));\n    }\n\n    /// @brief access specified object element\n    /// @sa https://json.nlohmann.me/api/basic_json/operator%5B%5D/\n    template<class KeyType, detail::enable_if_t<\n                 detail::is_usable_as_basic_json_key_type<basic_json_t, KeyType>::value, int > = 0 >\n    const_reference operator[](KeyType && key) const\n    {\n        // const operator[] only works for objects\n        if 
(JSON_HEDLEY_LIKELY(is_object()))\n        {\n            auto it = m_data.m_value.object->find(std::forward<KeyType>(key));\n            JSON_ASSERT(it != m_data.m_value.object->end());\n            return it->second;\n        }\n\n        JSON_THROW(type_error::create(305, detail::concat(\"cannot use operator[] with a string argument with \", type_name()), this));\n    }\n\n  private:\n    template<typename KeyType>\n    using is_comparable_with_object_key = detail::is_comparable <\n        object_comparator_t, const typename object_t::key_type&, KeyType >;\n\n    template<typename ValueType>\n    using value_return_type = std::conditional <\n        detail::is_c_string_uncvref<ValueType>::value,\n        string_t, typename std::decay<ValueType>::type >;\n\n  public:\n    /// @brief access specified object element with default value\n    /// @sa https://json.nlohmann.me/api/basic_json/value/\n    template < class ValueType, detail::enable_if_t <\n                   !detail::is_transparent<object_comparator_t>::value\n                   && detail::is_getable<basic_json_t, ValueType>::value\n                   && !std::is_same<value_t, detail::uncvref_t<ValueType>>::value, int > = 0 >\n    ValueType value(const typename object_t::key_type& key, const ValueType& default_value) const\n    {\n        // value only works for objects\n        if (JSON_HEDLEY_LIKELY(is_object()))\n        {\n            // if key is found, return value and given default value otherwise\n            const auto it = find(key);\n            if (it != end())\n            {\n                return it->template get<ValueType>();\n            }\n\n            return default_value;\n        }\n\n        JSON_THROW(type_error::create(306, detail::concat(\"cannot use value() with \", type_name()), this));\n    }\n\n    /// @brief access specified object element with default value\n    /// @sa https://json.nlohmann.me/api/basic_json/value/\n    template < class ValueType, class ReturnType = 
typename value_return_type<ValueType>::type,\n               detail::enable_if_t <\n                   !detail::is_transparent<object_comparator_t>::value\n                   && detail::is_getable<basic_json_t, ReturnType>::value\n                   && !std::is_same<value_t, detail::uncvref_t<ValueType>>::value, int > = 0 >\n    ReturnType value(const typename object_t::key_type& key, ValueType && default_value) const\n    {\n        // value only works for objects\n        if (JSON_HEDLEY_LIKELY(is_object()))\n        {\n            // if key is found, return value and given default value otherwise\n            const auto it = find(key);\n            if (it != end())\n            {\n                return it->template get<ReturnType>();\n            }\n\n            return std::forward<ValueType>(default_value);\n        }\n\n        JSON_THROW(type_error::create(306, detail::concat(\"cannot use value() with \", type_name()), this));\n    }\n\n    /// @brief access specified object element with default value\n    /// @sa https://json.nlohmann.me/api/basic_json/value/\n    template < class ValueType, class KeyType, detail::enable_if_t <\n                   detail::is_transparent<object_comparator_t>::value\n                   && !detail::is_json_pointer<KeyType>::value\n                   && is_comparable_with_object_key<KeyType>::value\n                   && detail::is_getable<basic_json_t, ValueType>::value\n                   && !std::is_same<value_t, detail::uncvref_t<ValueType>>::value, int > = 0 >\n    ValueType value(KeyType && key, const ValueType& default_value) const\n    {\n        // value only works for objects\n        if (JSON_HEDLEY_LIKELY(is_object()))\n        {\n            // if key is found, return value and given default value otherwise\n            const auto it = find(std::forward<KeyType>(key));\n            if (it != end())\n            {\n                return it->template get<ValueType>();\n            }\n\n            return 
default_value;\n        }\n\n        JSON_THROW(type_error::create(306, detail::concat(\"cannot use value() with \", type_name()), this));\n    }\n\n    /// @brief access specified object element with default value\n    /// @sa https://json.nlohmann.me/api/basic_json/value/\n    template < class ValueType, class KeyType, class ReturnType = typename value_return_type<ValueType>::type,\n               detail::enable_if_t <\n                   detail::is_transparent<object_comparator_t>::value\n                   && !detail::is_json_pointer<KeyType>::value\n                   && is_comparable_with_object_key<KeyType>::value\n                   && detail::is_getable<basic_json_t, ReturnType>::value\n                   && !std::is_same<value_t, detail::uncvref_t<ValueType>>::value, int > = 0 >\n    ReturnType value(KeyType && key, ValueType && default_value) const\n    {\n        // value only works for objects\n        if (JSON_HEDLEY_LIKELY(is_object()))\n        {\n            // if key is found, return value and given default value otherwise\n            const auto it = find(std::forward<KeyType>(key));\n            if (it != end())\n            {\n                return it->template get<ReturnType>();\n            }\n\n            return std::forward<ValueType>(default_value);\n        }\n\n        JSON_THROW(type_error::create(306, detail::concat(\"cannot use value() with \", type_name()), this));\n    }\n\n    /// @brief access specified object element via JSON Pointer with default value\n    /// @sa https://json.nlohmann.me/api/basic_json/value/\n    template < class ValueType, detail::enable_if_t <\n                   detail::is_getable<basic_json_t, ValueType>::value\n                   && !std::is_same<value_t, detail::uncvref_t<ValueType>>::value, int > = 0 >\n    ValueType value(const json_pointer& ptr, const ValueType& default_value) const\n    {\n        // value only works for objects\n        if (JSON_HEDLEY_LIKELY(is_object()))\n       
 {\n            // if pointer resolves a value, return it or use default value\n            JSON_TRY\n            {\n                return ptr.get_checked(this).template get<ValueType>();\n            }\n            JSON_INTERNAL_CATCH (out_of_range&)\n            {\n                return default_value;\n            }\n        }\n\n        JSON_THROW(type_error::create(306, detail::concat(\"cannot use value() with \", type_name()), this));\n    }\n\n    /// @brief access specified object element via JSON Pointer with default value\n    /// @sa https://json.nlohmann.me/api/basic_json/value/\n    template < class ValueType, class ReturnType = typename value_return_type<ValueType>::type,\n               detail::enable_if_t <\n                   detail::is_getable<basic_json_t, ReturnType>::value\n                   && !std::is_same<value_t, detail::uncvref_t<ValueType>>::value, int > = 0 >\n    ReturnType value(const json_pointer& ptr, ValueType && default_value) const\n    {\n        // value only works for objects\n        if (JSON_HEDLEY_LIKELY(is_object()))\n        {\n            // if pointer resolves a value, return it or use default value\n            JSON_TRY\n            {\n                return ptr.get_checked(this).template get<ReturnType>();\n            }\n            JSON_INTERNAL_CATCH (out_of_range&)\n            {\n                return std::forward<ValueType>(default_value);\n            }\n        }\n\n        JSON_THROW(type_error::create(306, detail::concat(\"cannot use value() with \", type_name()), this));\n    }\n\n    template < class ValueType, class BasicJsonType, detail::enable_if_t <\n                   detail::is_basic_json<BasicJsonType>::value\n                   && detail::is_getable<basic_json_t, ValueType>::value\n                   && !std::is_same<value_t, detail::uncvref_t<ValueType>>::value, int > = 0 >\n    JSON_HEDLEY_DEPRECATED_FOR(3.11.0, basic_json::json_pointer or nlohmann::json_pointer<basic_json::string_t>) // 
NOLINT(readability/alt_tokens)\n    ValueType value(const ::nlohmann::json_pointer<BasicJsonType>& ptr, const ValueType& default_value) const\n    {\n        return value(ptr.convert(), default_value);\n    }\n\n    template < class ValueType, class BasicJsonType, class ReturnType = typename value_return_type<ValueType>::type,\n               detail::enable_if_t <\n                   detail::is_basic_json<BasicJsonType>::value\n                   && detail::is_getable<basic_json_t, ReturnType>::value\n                   && !std::is_same<value_t, detail::uncvref_t<ValueType>>::value, int > = 0 >\n    JSON_HEDLEY_DEPRECATED_FOR(3.11.0, basic_json::json_pointer or nlohmann::json_pointer<basic_json::string_t>) // NOLINT(readability/alt_tokens)\n    ReturnType value(const ::nlohmann::json_pointer<BasicJsonType>& ptr, ValueType && default_value) const\n    {\n        return value(ptr.convert(), std::forward<ValueType>(default_value));\n    }\n\n    /// @brief access the first element\n    /// @sa https://json.nlohmann.me/api/basic_json/front/\n    reference front()\n    {\n        return *begin();\n    }\n\n    /// @brief access the first element\n    /// @sa https://json.nlohmann.me/api/basic_json/front/\n    const_reference front() const\n    {\n        return *cbegin();\n    }\n\n    /// @brief access the last element\n    /// @sa https://json.nlohmann.me/api/basic_json/back/\n    reference back()\n    {\n        auto tmp = end();\n        --tmp;\n        return *tmp;\n    }\n\n    /// @brief access the last element\n    /// @sa https://json.nlohmann.me/api/basic_json/back/\n    const_reference back() const\n    {\n        auto tmp = cend();\n        --tmp;\n        return *tmp;\n    }\n\n    /// @brief remove element given an iterator\n    /// @sa https://json.nlohmann.me/api/basic_json/erase/\n    template < class IteratorType, detail::enable_if_t <\n                   std::is_same<IteratorType, typename basic_json_t::iterator>::value ||\n                   
std::is_same<IteratorType, typename basic_json_t::const_iterator>::value, int > = 0 >\n    IteratorType erase(IteratorType pos)\n    {\n        // make sure iterator fits the current value\n        if (JSON_HEDLEY_UNLIKELY(this != pos.m_object))\n        {\n            JSON_THROW(invalid_iterator::create(202, \"iterator does not fit current value\", this));\n        }\n\n        IteratorType result = end();\n\n        switch (m_data.m_type)\n        {\n            case value_t::boolean:\n            case value_t::number_float:\n            case value_t::number_integer:\n            case value_t::number_unsigned:\n            case value_t::string:\n            case value_t::binary:\n            {\n                if (JSON_HEDLEY_UNLIKELY(!pos.m_it.primitive_iterator.is_begin()))\n                {\n                    JSON_THROW(invalid_iterator::create(205, \"iterator out of range\", this));\n                }\n\n                if (is_string())\n                {\n                    AllocatorType<string_t> alloc;\n                    std::allocator_traits<decltype(alloc)>::destroy(alloc, m_data.m_value.string);\n                    std::allocator_traits<decltype(alloc)>::deallocate(alloc, m_data.m_value.string, 1);\n                    m_data.m_value.string = nullptr;\n                }\n                else if (is_binary())\n                {\n                    AllocatorType<binary_t> alloc;\n                    std::allocator_traits<decltype(alloc)>::destroy(alloc, m_data.m_value.binary);\n                    std::allocator_traits<decltype(alloc)>::deallocate(alloc, m_data.m_value.binary, 1);\n                    m_data.m_value.binary = nullptr;\n                }\n\n                m_data.m_type = value_t::null;\n                assert_invariant();\n                break;\n            }\n\n            case value_t::object:\n            {\n                result.m_it.object_iterator = m_data.m_value.object->erase(pos.m_it.object_iterator);\n                
break;\n            }\n\n            case value_t::array:\n            {\n                result.m_it.array_iterator = m_data.m_value.array->erase(pos.m_it.array_iterator);\n                break;\n            }\n\n            case value_t::null:\n            case value_t::discarded:\n            default:\n                JSON_THROW(type_error::create(307, detail::concat(\"cannot use erase() with \", type_name()), this));\n        }\n\n        return result;\n    }\n\n    /// @brief remove elements given an iterator range\n    /// @sa https://json.nlohmann.me/api/basic_json/erase/\n    template < class IteratorType, detail::enable_if_t <\n                   std::is_same<IteratorType, typename basic_json_t::iterator>::value ||\n                   std::is_same<IteratorType, typename basic_json_t::const_iterator>::value, int > = 0 >\n    IteratorType erase(IteratorType first, IteratorType last)\n    {\n        // make sure both iterators fit the current value\n        if (JSON_HEDLEY_UNLIKELY(this != first.m_object || this != last.m_object))\n        {\n            JSON_THROW(invalid_iterator::create(203, \"iterators do not fit current value\", this));\n        }\n\n        IteratorType result = end();\n\n        switch (m_data.m_type)\n        {\n            case value_t::boolean:\n            case value_t::number_float:\n            case value_t::number_integer:\n            case value_t::number_unsigned:\n            case value_t::string:\n            case value_t::binary:\n            {\n                if (JSON_HEDLEY_UNLIKELY(!first.m_it.primitive_iterator.is_begin()\n                                         || !last.m_it.primitive_iterator.is_end()))\n                {\n                    JSON_THROW(invalid_iterator::create(204, \"iterators out of range\", this));\n                }\n\n                if (is_string())\n                {\n                    AllocatorType<string_t> alloc;\n                    std::allocator_traits<decltype(alloc)>::destroy(alloc, 
m_data.m_value.string);\n                    std::allocator_traits<decltype(alloc)>::deallocate(alloc, m_data.m_value.string, 1);\n                    m_data.m_value.string = nullptr;\n                }\n                else if (is_binary())\n                {\n                    AllocatorType<binary_t> alloc;\n                    std::allocator_traits<decltype(alloc)>::destroy(alloc, m_data.m_value.binary);\n                    std::allocator_traits<decltype(alloc)>::deallocate(alloc, m_data.m_value.binary, 1);\n                    m_data.m_value.binary = nullptr;\n                }\n\n                m_data.m_type = value_t::null;\n                assert_invariant();\n                break;\n            }\n\n            case value_t::object:\n            {\n                result.m_it.object_iterator = m_data.m_value.object->erase(first.m_it.object_iterator,\n                                              last.m_it.object_iterator);\n                break;\n            }\n\n            case value_t::array:\n            {\n                result.m_it.array_iterator = m_data.m_value.array->erase(first.m_it.array_iterator,\n                                             last.m_it.array_iterator);\n                break;\n            }\n\n            case value_t::null:\n            case value_t::discarded:\n            default:\n                JSON_THROW(type_error::create(307, detail::concat(\"cannot use erase() with \", type_name()), this));\n        }\n\n        return result;\n    }\n\n  private:\n    template < typename KeyType, detail::enable_if_t <\n                   detail::has_erase_with_key_type<basic_json_t, KeyType>::value, int > = 0 >\n    size_type erase_internal(KeyType && key)\n    {\n        // this erase only works for objects\n        if (JSON_HEDLEY_UNLIKELY(!is_object()))\n        {\n            JSON_THROW(type_error::create(307, detail::concat(\"cannot use erase() with \", type_name()), this));\n        }\n\n        return 
m_data.m_value.object->erase(std::forward<KeyType>(key));\n    }\n\n    template < typename KeyType, detail::enable_if_t <\n                   !detail::has_erase_with_key_type<basic_json_t, KeyType>::value, int > = 0 >\n    size_type erase_internal(KeyType && key)\n    {\n        // this erase only works for objects\n        if (JSON_HEDLEY_UNLIKELY(!is_object()))\n        {\n            JSON_THROW(type_error::create(307, detail::concat(\"cannot use erase() with \", type_name()), this));\n        }\n\n        const auto it = m_data.m_value.object->find(std::forward<KeyType>(key));\n        if (it != m_data.m_value.object->end())\n        {\n            m_data.m_value.object->erase(it);\n            return 1;\n        }\n        return 0;\n    }\n\n  public:\n\n    /// @brief remove element from a JSON object given a key\n    /// @sa https://json.nlohmann.me/api/basic_json/erase/\n    size_type erase(const typename object_t::key_type& key)\n    {\n        // the indirection via erase_internal() is added to avoid making this\n        // function a template and thus de-rank it during overload resolution\n        return erase_internal(key);\n    }\n\n    /// @brief remove element from a JSON object given a key\n    /// @sa https://json.nlohmann.me/api/basic_json/erase/\n    template<class KeyType, detail::enable_if_t<\n                 detail::is_usable_as_basic_json_key_type<basic_json_t, KeyType>::value, int> = 0>\n    size_type erase(KeyType && key)\n    {\n        return erase_internal(std::forward<KeyType>(key));\n    }\n\n    /// @brief remove element from a JSON array given an index\n    /// @sa https://json.nlohmann.me/api/basic_json/erase/\n    void erase(const size_type idx)\n    {\n        // this erase only works for arrays\n        if (JSON_HEDLEY_LIKELY(is_array()))\n        {\n            if (JSON_HEDLEY_UNLIKELY(idx >= size()))\n            {\n                JSON_THROW(out_of_range::create(401, detail::concat(\"array index \", std::to_string(idx), \" 
is out of range\"), this));\n            }\n\n            m_data.m_value.array->erase(m_data.m_value.array->begin() + static_cast<difference_type>(idx));\n        }\n        else\n        {\n            JSON_THROW(type_error::create(307, detail::concat(\"cannot use erase() with \", type_name()), this));\n        }\n    }\n\n    /// @}\n\n    ////////////\n    // lookup //\n    ////////////\n\n    /// @name lookup\n    /// @{\n\n    /// @brief find an element in a JSON object\n    /// @sa https://json.nlohmann.me/api/basic_json/find/\n    iterator find(const typename object_t::key_type& key)\n    {\n        auto result = end();\n\n        if (is_object())\n        {\n            result.m_it.object_iterator = m_data.m_value.object->find(key);\n        }\n\n        return result;\n    }\n\n    /// @brief find an element in a JSON object\n    /// @sa https://json.nlohmann.me/api/basic_json/find/\n    const_iterator find(const typename object_t::key_type& key) const\n    {\n        auto result = cend();\n\n        if (is_object())\n        {\n            result.m_it.object_iterator = m_data.m_value.object->find(key);\n        }\n\n        return result;\n    }\n\n    /// @brief find an element in a JSON object\n    /// @sa https://json.nlohmann.me/api/basic_json/find/\n    template<class KeyType, detail::enable_if_t<\n                 detail::is_usable_as_basic_json_key_type<basic_json_t, KeyType>::value, int> = 0>\n    iterator find(KeyType && key)\n    {\n        auto result = end();\n\n        if (is_object())\n        {\n            result.m_it.object_iterator = m_data.m_value.object->find(std::forward<KeyType>(key));\n        }\n\n        return result;\n    }\n\n    /// @brief find an element in a JSON object\n    /// @sa https://json.nlohmann.me/api/basic_json/find/\n    template<class KeyType, detail::enable_if_t<\n                 detail::is_usable_as_basic_json_key_type<basic_json_t, KeyType>::value, int> = 0>\n    const_iterator find(KeyType && key) const\n   
 {\n        auto result = cend();\n\n        if (is_object())\n        {\n            result.m_it.object_iterator = m_data.m_value.object->find(std::forward<KeyType>(key));\n        }\n\n        return result;\n    }\n\n    /// @brief returns the number of occurrences of a key in a JSON object\n    /// @sa https://json.nlohmann.me/api/basic_json/count/\n    size_type count(const typename object_t::key_type& key) const\n    {\n        // return 0 for all nonobject types\n        return is_object() ? m_data.m_value.object->count(key) : 0;\n    }\n\n    /// @brief returns the number of occurrences of a key in a JSON object\n    /// @sa https://json.nlohmann.me/api/basic_json/count/\n    template<class KeyType, detail::enable_if_t<\n                 detail::is_usable_as_basic_json_key_type<basic_json_t, KeyType>::value, int> = 0>\n    size_type count(KeyType && key) const\n    {\n        // return 0 for all nonobject types\n        return is_object() ? m_data.m_value.object->count(std::forward<KeyType>(key)) : 0;\n    }\n\n    /// @brief check the existence of an element in a JSON object\n    /// @sa https://json.nlohmann.me/api/basic_json/contains/\n    bool contains(const typename object_t::key_type& key) const\n    {\n        return is_object() && m_data.m_value.object->find(key) != m_data.m_value.object->end();\n    }\n\n    /// @brief check the existence of an element in a JSON object\n    /// @sa https://json.nlohmann.me/api/basic_json/contains/\n    template<class KeyType, detail::enable_if_t<\n                 detail::is_usable_as_basic_json_key_type<basic_json_t, KeyType>::value, int> = 0>\n    bool contains(KeyType && key) const\n    {\n        return is_object() && m_data.m_value.object->find(std::forward<KeyType>(key)) != m_data.m_value.object->end();\n    }\n\n    /// @brief check the existence of an element in a JSON object given a JSON pointer\n    /// @sa https://json.nlohmann.me/api/basic_json/contains/\n    bool contains(const json_pointer& ptr) 
const\n    {\n        return ptr.contains(this);\n    }\n\n    template<typename BasicJsonType, detail::enable_if_t<detail::is_basic_json<BasicJsonType>::value, int> = 0>\n    JSON_HEDLEY_DEPRECATED_FOR(3.11.0, basic_json::json_pointer or nlohmann::json_pointer<basic_json::string_t>) // NOLINT(readability/alt_tokens)\n    bool contains(const typename ::nlohmann::json_pointer<BasicJsonType>& ptr) const\n    {\n        return ptr.contains(this);\n    }\n\n    /// @}\n\n    ///////////////\n    // iterators //\n    ///////////////\n\n    /// @name iterators\n    /// @{\n\n    /// @brief returns an iterator to the first element\n    /// @sa https://json.nlohmann.me/api/basic_json/begin/\n    iterator begin() noexcept\n    {\n        iterator result(this);\n        result.set_begin();\n        return result;\n    }\n\n    /// @brief returns an iterator to the first element\n    /// @sa https://json.nlohmann.me/api/basic_json/begin/\n    const_iterator begin() const noexcept\n    {\n        return cbegin();\n    }\n\n    /// @brief returns a const iterator to the first element\n    /// @sa https://json.nlohmann.me/api/basic_json/cbegin/\n    const_iterator cbegin() const noexcept\n    {\n        const_iterator result(this);\n        result.set_begin();\n        return result;\n    }\n\n    /// @brief returns an iterator to one past the last element\n    /// @sa https://json.nlohmann.me/api/basic_json/end/\n    iterator end() noexcept\n    {\n        iterator result(this);\n        result.set_end();\n        return result;\n    }\n\n    /// @brief returns an iterator to one past the last element\n    /// @sa https://json.nlohmann.me/api/basic_json/end/\n    const_iterator end() const noexcept\n    {\n        return cend();\n    }\n\n    /// @brief returns an iterator to one past the last element\n    /// @sa https://json.nlohmann.me/api/basic_json/cend/\n    const_iterator cend() const noexcept\n    {\n        const_iterator result(this);\n        result.set_end();\n      
  return result;\n    }\n\n    /// @brief returns an iterator to the reverse-beginning\n    /// @sa https://json.nlohmann.me/api/basic_json/rbegin/\n    reverse_iterator rbegin() noexcept\n    {\n        return reverse_iterator(end());\n    }\n\n    /// @brief returns an iterator to the reverse-beginning\n    /// @sa https://json.nlohmann.me/api/basic_json/rbegin/\n    const_reverse_iterator rbegin() const noexcept\n    {\n        return crbegin();\n    }\n\n    /// @brief returns an iterator to the reverse-end\n    /// @sa https://json.nlohmann.me/api/basic_json/rend/\n    reverse_iterator rend() noexcept\n    {\n        return reverse_iterator(begin());\n    }\n\n    /// @brief returns an iterator to the reverse-end\n    /// @sa https://json.nlohmann.me/api/basic_json/rend/\n    const_reverse_iterator rend() const noexcept\n    {\n        return crend();\n    }\n\n    /// @brief returns a const reverse iterator to the last element\n    /// @sa https://json.nlohmann.me/api/basic_json/crbegin/\n    const_reverse_iterator crbegin() const noexcept\n    {\n        return const_reverse_iterator(cend());\n    }\n\n    /// @brief returns a const reverse iterator to one before the first\n    /// @sa https://json.nlohmann.me/api/basic_json/crend/\n    const_reverse_iterator crend() const noexcept\n    {\n        return const_reverse_iterator(cbegin());\n    }\n\n  public:\n    /// @brief wrapper to access iterator member functions in range-based for\n    /// @sa https://json.nlohmann.me/api/basic_json/items/\n    /// @deprecated This function is deprecated since 3.1.0 and will be removed in\n    ///             version 4.0.0 of the library. 
Please use @ref items() instead;\n    ///             that is, replace `json::iterator_wrapper(j)` with `j.items()`.\n    JSON_HEDLEY_DEPRECATED_FOR(3.1.0, items())\n    static iteration_proxy<iterator> iterator_wrapper(reference ref) noexcept\n    {\n        return ref.items();\n    }\n\n    /// @brief wrapper to access iterator member functions in range-based for\n    /// @sa https://json.nlohmann.me/api/basic_json/items/\n    /// @deprecated This function is deprecated since 3.1.0 and will be removed in\n    ///         version 4.0.0 of the library. Please use @ref items() instead;\n    ///         that is, replace `json::iterator_wrapper(j)` with `j.items()`.\n    JSON_HEDLEY_DEPRECATED_FOR(3.1.0, items())\n    static iteration_proxy<const_iterator> iterator_wrapper(const_reference ref) noexcept\n    {\n        return ref.items();\n    }\n\n    /// @brief helper to access iterator member functions in range-based for\n    /// @sa https://json.nlohmann.me/api/basic_json/items/\n    iteration_proxy<iterator> items() noexcept\n    {\n        return iteration_proxy<iterator>(*this);\n    }\n\n    /// @brief helper to access iterator member functions in range-based for\n    /// @sa https://json.nlohmann.me/api/basic_json/items/\n    iteration_proxy<const_iterator> items() const noexcept\n    {\n        return iteration_proxy<const_iterator>(*this);\n    }\n\n    /// @}\n\n    //////////////\n    // capacity //\n    //////////////\n\n    /// @name capacity\n    /// @{\n\n    /// @brief checks whether the container is empty.\n    /// @sa https://json.nlohmann.me/api/basic_json/empty/\n    bool empty() const noexcept\n    {\n        switch (m_data.m_type)\n        {\n            case value_t::null:\n            {\n                // null values are empty\n                return true;\n            }\n\n            case value_t::array:\n            {\n                // delegate call to array_t::empty()\n                return m_data.m_value.array->empty();\n            
}\n\n            case value_t::object:\n            {\n                // delegate call to object_t::empty()\n                return m_data.m_value.object->empty();\n            }\n\n            case value_t::string:\n            case value_t::boolean:\n            case value_t::number_integer:\n            case value_t::number_unsigned:\n            case value_t::number_float:\n            case value_t::binary:\n            case value_t::discarded:\n            default:\n            {\n                // all other types are nonempty\n                return false;\n            }\n        }\n    }\n\n    /// @brief returns the number of elements\n    /// @sa https://json.nlohmann.me/api/basic_json/size/\n    size_type size() const noexcept\n    {\n        switch (m_data.m_type)\n        {\n            case value_t::null:\n            {\n                // null values are empty\n                return 0;\n            }\n\n            case value_t::array:\n            {\n                // delegate call to array_t::size()\n                return m_data.m_value.array->size();\n            }\n\n            case value_t::object:\n            {\n                // delegate call to object_t::size()\n                return m_data.m_value.object->size();\n            }\n\n            case value_t::string:\n            case value_t::boolean:\n            case value_t::number_integer:\n            case value_t::number_unsigned:\n            case value_t::number_float:\n            case value_t::binary:\n            case value_t::discarded:\n            default:\n            {\n                // all other types have size 1\n                return 1;\n            }\n        }\n    }\n\n    /// @brief returns the maximum possible number of elements\n    /// @sa https://json.nlohmann.me/api/basic_json/max_size/\n    size_type max_size() const noexcept\n    {\n        switch (m_data.m_type)\n        {\n            case value_t::array:\n            {\n                // delegate 
call to array_t::max_size()\n                return m_data.m_value.array->max_size();\n            }\n\n            case value_t::object:\n            {\n                // delegate call to object_t::max_size()\n                return m_data.m_value.object->max_size();\n            }\n\n            case value_t::null:\n            case value_t::string:\n            case value_t::boolean:\n            case value_t::number_integer:\n            case value_t::number_unsigned:\n            case value_t::number_float:\n            case value_t::binary:\n            case value_t::discarded:\n            default:\n            {\n                // all other types have max_size() == size()\n                return size();\n            }\n        }\n    }\n\n    /// @}\n\n    ///////////////\n    // modifiers //\n    ///////////////\n\n    /// @name modifiers\n    /// @{\n\n    /// @brief clears the contents\n    /// @sa https://json.nlohmann.me/api/basic_json/clear/\n    void clear() noexcept\n    {\n        switch (m_data.m_type)\n        {\n            case value_t::number_integer:\n            {\n                m_data.m_value.number_integer = 0;\n                break;\n            }\n\n            case value_t::number_unsigned:\n            {\n                m_data.m_value.number_unsigned = 0;\n                break;\n            }\n\n            case value_t::number_float:\n            {\n                m_data.m_value.number_float = 0.0;\n                break;\n            }\n\n            case value_t::boolean:\n            {\n                m_data.m_value.boolean = false;\n                break;\n            }\n\n            case value_t::string:\n            {\n                m_data.m_value.string->clear();\n                break;\n            }\n\n            case value_t::binary:\n            {\n                m_data.m_value.binary->clear();\n                break;\n            }\n\n            case value_t::array:\n            {\n                
m_data.m_value.array->clear();\n                break;\n            }\n\n            case value_t::object:\n            {\n                m_data.m_value.object->clear();\n                break;\n            }\n\n            case value_t::null:\n            case value_t::discarded:\n            default:\n                break;\n        }\n    }\n\n    /// @brief add an object to an array\n    /// @sa https://json.nlohmann.me/api/basic_json/push_back/\n    void push_back(basic_json&& val)\n    {\n        // push_back only works for null objects or arrays\n        if (JSON_HEDLEY_UNLIKELY(!(is_null() || is_array())))\n        {\n            JSON_THROW(type_error::create(308, detail::concat(\"cannot use push_back() with \", type_name()), this));\n        }\n\n        // transform null object into an array\n        if (is_null())\n        {\n            m_data.m_type = value_t::array;\n            m_data.m_value = value_t::array;\n            assert_invariant();\n        }\n\n        // add element to array (move semantics)\n        const auto old_capacity = m_data.m_value.array->capacity();\n        m_data.m_value.array->push_back(std::move(val));\n        set_parent(m_data.m_value.array->back(), old_capacity);\n        // if val is moved from, basic_json move constructor marks it null, so we do not call the destructor\n    }\n\n    /// @brief add an object to an array\n    /// @sa https://json.nlohmann.me/api/basic_json/operator+=/\n    reference operator+=(basic_json&& val)\n    {\n        push_back(std::move(val));\n        return *this;\n    }\n\n    /// @brief add an object to an array\n    /// @sa https://json.nlohmann.me/api/basic_json/push_back/\n    void push_back(const basic_json& val)\n    {\n        // push_back only works for null objects or arrays\n        if (JSON_HEDLEY_UNLIKELY(!(is_null() || is_array())))\n        {\n            JSON_THROW(type_error::create(308, detail::concat(\"cannot use push_back() with \", type_name()), this));\n        }\n\n    
    // transform null object into an array\n        if (is_null())\n        {\n            m_data.m_type = value_t::array;\n            m_data.m_value = value_t::array;\n            assert_invariant();\n        }\n\n        // add element to array\n        const auto old_capacity = m_data.m_value.array->capacity();\n        m_data.m_value.array->push_back(val);\n        set_parent(m_data.m_value.array->back(), old_capacity);\n    }\n\n    /// @brief add an object to an array\n    /// @sa https://json.nlohmann.me/api/basic_json/operator+=/\n    reference operator+=(const basic_json& val)\n    {\n        push_back(val);\n        return *this;\n    }\n\n    /// @brief add an object to an object\n    /// @sa https://json.nlohmann.me/api/basic_json/push_back/\n    void push_back(const typename object_t::value_type& val)\n    {\n        // push_back only works for null objects or objects\n        if (JSON_HEDLEY_UNLIKELY(!(is_null() || is_object())))\n        {\n            JSON_THROW(type_error::create(308, detail::concat(\"cannot use push_back() with \", type_name()), this));\n        }\n\n        // transform null object into an object\n        if (is_null())\n        {\n            m_data.m_type = value_t::object;\n            m_data.m_value = value_t::object;\n            assert_invariant();\n        }\n\n        // add element to object\n        auto res = m_data.m_value.object->insert(val);\n        set_parent(res.first->second);\n    }\n\n    /// @brief add an object to an object\n    /// @sa https://json.nlohmann.me/api/basic_json/operator+=/\n    reference operator+=(const typename object_t::value_type& val)\n    {\n        push_back(val);\n        return *this;\n    }\n\n    /// @brief add an object to an object\n    /// @sa https://json.nlohmann.me/api/basic_json/push_back/\n    void push_back(initializer_list_t init)\n    {\n        if (is_object() && init.size() == 2 && (*init.begin())->is_string())\n        {\n            basic_json&& key = 
init.begin()->moved_or_copied();\n            push_back(typename object_t::value_type(\n                          std::move(key.get_ref<string_t&>()), (init.begin() + 1)->moved_or_copied()));\n        }\n        else\n        {\n            push_back(basic_json(init));\n        }\n    }\n\n    /// @brief add an object to an object\n    /// @sa https://json.nlohmann.me/api/basic_json/operator+=/\n    reference operator+=(initializer_list_t init)\n    {\n        push_back(init);\n        return *this;\n    }\n\n    /// @brief add an object to an array\n    /// @sa https://json.nlohmann.me/api/basic_json/emplace_back/\n    template<class... Args>\n    reference emplace_back(Args&& ... args)\n    {\n        // emplace_back only works for null objects or arrays\n        if (JSON_HEDLEY_UNLIKELY(!(is_null() || is_array())))\n        {\n            JSON_THROW(type_error::create(311, detail::concat(\"cannot use emplace_back() with \", type_name()), this));\n        }\n\n        // transform null object into an array\n        if (is_null())\n        {\n            m_data.m_type = value_t::array;\n            m_data.m_value = value_t::array;\n            assert_invariant();\n        }\n\n        // add element to array (perfect forwarding)\n        const auto old_capacity = m_data.m_value.array->capacity();\n        m_data.m_value.array->emplace_back(std::forward<Args>(args)...);\n        return set_parent(m_data.m_value.array->back(), old_capacity);\n    }\n\n    /// @brief add an object to an object if key does not exist\n    /// @sa https://json.nlohmann.me/api/basic_json/emplace/\n    template<class... Args>\n    std::pair<iterator, bool> emplace(Args&& ... 
args)\n    {\n        // emplace only works for null objects or objects\n        if (JSON_HEDLEY_UNLIKELY(!(is_null() || is_object())))\n        {\n            JSON_THROW(type_error::create(311, detail::concat(\"cannot use emplace() with \", type_name()), this));\n        }\n\n        // transform null object into an object\n        if (is_null())\n        {\n            m_data.m_type = value_t::object;\n            m_data.m_value = value_t::object;\n            assert_invariant();\n        }\n\n        // add element to object (perfect forwarding)\n        auto res = m_data.m_value.object->emplace(std::forward<Args>(args)...);\n        set_parent(res.first->second);\n\n        // create result iterator and set iterator to the result of emplace\n        auto it = begin();\n        it.m_it.object_iterator = res.first;\n\n        // return pair of iterator and boolean\n        return {it, res.second};\n    }\n\n    /// Helper for insertion of an iterator\n    /// @note: This uses std::distance to support GCC 4.8,\n    ///        see https://github.com/nlohmann/json/pull/1257\n    template<typename... Args>\n    iterator insert_iterator(const_iterator pos, Args&& ... 
args)\n    {\n        iterator result(this);\n        JSON_ASSERT(m_data.m_value.array != nullptr);\n\n        auto insert_pos = std::distance(m_data.m_value.array->begin(), pos.m_it.array_iterator);\n        m_data.m_value.array->insert(pos.m_it.array_iterator, std::forward<Args>(args)...);\n        result.m_it.array_iterator = m_data.m_value.array->begin() + insert_pos;\n\n        // This could have been written as:\n        // result.m_it.array_iterator = m_data.m_value.array->insert(pos.m_it.array_iterator, cnt, val);\n        // but the return value of insert is missing in GCC 4.8, so it is written this way instead.\n\n        set_parents();\n        return result;\n    }\n\n    /// @brief inserts element into array\n    /// @sa https://json.nlohmann.me/api/basic_json/insert/\n    iterator insert(const_iterator pos, const basic_json& val)\n    {\n        // insert only works for arrays\n        if (JSON_HEDLEY_LIKELY(is_array()))\n        {\n            // check if iterator pos fits to this JSON value\n            if (JSON_HEDLEY_UNLIKELY(pos.m_object != this))\n            {\n                JSON_THROW(invalid_iterator::create(202, \"iterator does not fit current value\", this));\n            }\n\n            // insert to array and return iterator\n            return insert_iterator(pos, val);\n        }\n\n        JSON_THROW(type_error::create(309, detail::concat(\"cannot use insert() with \", type_name()), this));\n    }\n\n    /// @brief inserts element into array\n    /// @sa https://json.nlohmann.me/api/basic_json/insert/\n    iterator insert(const_iterator pos, basic_json&& val)\n    {\n        return insert(pos, val);\n    }\n\n    /// @brief inserts copies of element into array\n    /// @sa https://json.nlohmann.me/api/basic_json/insert/\n    iterator insert(const_iterator pos, size_type cnt, const basic_json& val)\n    {\n        // insert only works for arrays\n        if (JSON_HEDLEY_LIKELY(is_array()))\n        {\n            // check if iterator 
pos fits to this JSON value\n            if (JSON_HEDLEY_UNLIKELY(pos.m_object != this))\n            {\n                JSON_THROW(invalid_iterator::create(202, \"iterator does not fit current value\", this));\n            }\n\n            // insert to array and return iterator\n            return insert_iterator(pos, cnt, val);\n        }\n\n        JSON_THROW(type_error::create(309, detail::concat(\"cannot use insert() with \", type_name()), this));\n    }\n\n    /// @brief inserts range of elements into array\n    /// @sa https://json.nlohmann.me/api/basic_json/insert/\n    iterator insert(const_iterator pos, const_iterator first, const_iterator last)\n    {\n        // insert only works for arrays\n        if (JSON_HEDLEY_UNLIKELY(!is_array()))\n        {\n            JSON_THROW(type_error::create(309, detail::concat(\"cannot use insert() with \", type_name()), this));\n        }\n\n        // check if iterator pos fits to this JSON value\n        if (JSON_HEDLEY_UNLIKELY(pos.m_object != this))\n        {\n            JSON_THROW(invalid_iterator::create(202, \"iterator does not fit current value\", this));\n        }\n\n        // check if range iterators belong to the same JSON object\n        if (JSON_HEDLEY_UNLIKELY(first.m_object != last.m_object))\n        {\n            JSON_THROW(invalid_iterator::create(210, \"iterators do not fit\", this));\n        }\n\n        if (JSON_HEDLEY_UNLIKELY(first.m_object == this))\n        {\n            JSON_THROW(invalid_iterator::create(211, \"passed iterators may not belong to container\", this));\n        }\n\n        // insert to array and return iterator\n        return insert_iterator(pos, first.m_it.array_iterator, last.m_it.array_iterator);\n    }\n\n    /// @brief inserts elements from initializer list into array\n    /// @sa https://json.nlohmann.me/api/basic_json/insert/\n    iterator insert(const_iterator pos, initializer_list_t ilist)\n    {\n        // insert only works for arrays\n        if 
(JSON_HEDLEY_UNLIKELY(!is_array()))\n        {\n            JSON_THROW(type_error::create(309, detail::concat(\"cannot use insert() with \", type_name()), this));\n        }\n\n        // check if iterator pos fits to this JSON value\n        if (JSON_HEDLEY_UNLIKELY(pos.m_object != this))\n        {\n            JSON_THROW(invalid_iterator::create(202, \"iterator does not fit current value\", this));\n        }\n\n        // insert to array and return iterator\n        return insert_iterator(pos, ilist.begin(), ilist.end());\n    }\n\n    /// @brief inserts range of elements into object\n    /// @sa https://json.nlohmann.me/api/basic_json/insert/\n    void insert(const_iterator first, const_iterator last)\n    {\n        // insert only works for objects\n        if (JSON_HEDLEY_UNLIKELY(!is_object()))\n        {\n            JSON_THROW(type_error::create(309, detail::concat(\"cannot use insert() with \", type_name()), this));\n        }\n\n        // check if range iterators belong to the same JSON object\n        if (JSON_HEDLEY_UNLIKELY(first.m_object != last.m_object))\n        {\n            JSON_THROW(invalid_iterator::create(210, \"iterators do not fit\", this));\n        }\n\n        // passed iterators must belong to objects\n        if (JSON_HEDLEY_UNLIKELY(!first.m_object->is_object()))\n        {\n            JSON_THROW(invalid_iterator::create(202, \"iterators first and last must point to objects\", this));\n        }\n\n        m_data.m_value.object->insert(first.m_it.object_iterator, last.m_it.object_iterator);\n    }\n\n    /// @brief updates a JSON object from another object, overwriting existing keys\n    /// @sa https://json.nlohmann.me/api/basic_json/update/\n    void update(const_reference j, bool merge_objects = false)\n    {\n        update(j.begin(), j.end(), merge_objects);\n    }\n\n    /// @brief updates a JSON object from another object, overwriting existing keys\n    /// @sa https://json.nlohmann.me/api/basic_json/update/\n    void 
update(const_iterator first, const_iterator last, bool merge_objects = false)\n    {\n        // implicitly convert null value to an empty object\n        if (is_null())\n        {\n            m_data.m_type = value_t::object;\n            m_data.m_value.object = create<object_t>();\n            assert_invariant();\n        }\n\n        if (JSON_HEDLEY_UNLIKELY(!is_object()))\n        {\n            JSON_THROW(type_error::create(312, detail::concat(\"cannot use update() with \", type_name()), this));\n        }\n\n        // check if range iterators belong to the same JSON object\n        if (JSON_HEDLEY_UNLIKELY(first.m_object != last.m_object))\n        {\n            JSON_THROW(invalid_iterator::create(210, \"iterators do not fit\", this));\n        }\n\n        // passed iterators must belong to objects\n        if (JSON_HEDLEY_UNLIKELY(!first.m_object->is_object()))\n        {\n            JSON_THROW(type_error::create(312, detail::concat(\"cannot use update() with \", first.m_object->type_name()), first.m_object));\n        }\n\n        for (auto it = first; it != last; ++it)\n        {\n            if (merge_objects && it.value().is_object())\n            {\n                auto it2 = m_data.m_value.object->find(it.key());\n                if (it2 != m_data.m_value.object->end())\n                {\n                    it2->second.update(it.value(), true);\n                    continue;\n                }\n            }\n            m_data.m_value.object->operator[](it.key()) = it.value();\n#if JSON_DIAGNOSTICS\n            m_data.m_value.object->operator[](it.key()).m_parent = this;\n#endif\n        }\n    }\n\n    /// @brief exchanges the values\n    /// @sa https://json.nlohmann.me/api/basic_json/swap/\n    void swap(reference other) noexcept (\n        std::is_nothrow_move_constructible<value_t>::value&&\n        std::is_nothrow_move_assignable<value_t>::value&&\n        std::is_nothrow_move_constructible<json_value>::value&& // 
NOLINT(cppcoreguidelines-noexcept-swap,performance-noexcept-swap)\n        std::is_nothrow_move_assignable<json_value>::value\n    )\n    {\n        std::swap(m_data.m_type, other.m_data.m_type);\n        std::swap(m_data.m_value, other.m_data.m_value);\n\n        set_parents();\n        other.set_parents();\n        assert_invariant();\n    }\n\n    /// @brief exchanges the values\n    /// @sa https://json.nlohmann.me/api/basic_json/swap/\n    friend void swap(reference left, reference right) noexcept (\n        std::is_nothrow_move_constructible<value_t>::value&&\n        std::is_nothrow_move_assignable<value_t>::value&&\n        std::is_nothrow_move_constructible<json_value>::value&& // NOLINT(cppcoreguidelines-noexcept-swap,performance-noexcept-swap)\n        std::is_nothrow_move_assignable<json_value>::value\n    )\n    {\n        left.swap(right);\n    }\n\n    /// @brief exchanges the values\n    /// @sa https://json.nlohmann.me/api/basic_json/swap/\n    void swap(array_t& other) // NOLINT(bugprone-exception-escape,cppcoreguidelines-noexcept-swap,performance-noexcept-swap)\n    {\n        // swap only works for arrays\n        if (JSON_HEDLEY_LIKELY(is_array()))\n        {\n            using std::swap;\n            swap(*(m_data.m_value.array), other);\n        }\n        else\n        {\n            JSON_THROW(type_error::create(310, detail::concat(\"cannot use swap(array_t&) with \", type_name()), this));\n        }\n    }\n\n    /// @brief exchanges the values\n    /// @sa https://json.nlohmann.me/api/basic_json/swap/\n    void swap(object_t& other) // NOLINT(bugprone-exception-escape,cppcoreguidelines-noexcept-swap,performance-noexcept-swap)\n    {\n        // swap only works for objects\n        if (JSON_HEDLEY_LIKELY(is_object()))\n        {\n            using std::swap;\n            swap(*(m_data.m_value.object), other);\n        }\n        else\n        {\n            JSON_THROW(type_error::create(310, detail::concat(\"cannot use swap(object_t&) with 
\", type_name()), this));\n        }\n    }\n\n    /// @brief exchanges the values\n    /// @sa https://json.nlohmann.me/api/basic_json/swap/\n    void swap(string_t& other) // NOLINT(bugprone-exception-escape,cppcoreguidelines-noexcept-swap,performance-noexcept-swap)\n    {\n        // swap only works for strings\n        if (JSON_HEDLEY_LIKELY(is_string()))\n        {\n            using std::swap;\n            swap(*(m_data.m_value.string), other);\n        }\n        else\n        {\n            JSON_THROW(type_error::create(310, detail::concat(\"cannot use swap(string_t&) with \", type_name()), this));\n        }\n    }\n\n    /// @brief exchanges the values\n    /// @sa https://json.nlohmann.me/api/basic_json/swap/\n    void swap(binary_t& other) // NOLINT(bugprone-exception-escape,cppcoreguidelines-noexcept-swap,performance-noexcept-swap)\n    {\n        // swap only works for binary values\n        if (JSON_HEDLEY_LIKELY(is_binary()))\n        {\n            using std::swap;\n            swap(*(m_data.m_value.binary), other);\n        }\n        else\n        {\n            JSON_THROW(type_error::create(310, detail::concat(\"cannot use swap(binary_t&) with \", type_name()), this));\n        }\n    }\n\n    /// @brief exchanges the values\n    /// @sa https://json.nlohmann.me/api/basic_json/swap/\n    void swap(typename binary_t::container_type& other) // NOLINT(bugprone-exception-escape)\n    {\n        // swap only works for binary values\n        if (JSON_HEDLEY_LIKELY(is_binary()))\n        {\n            using std::swap;\n            swap(*(m_data.m_value.binary), other);\n        }\n        else\n        {\n            JSON_THROW(type_error::create(310, detail::concat(\"cannot use swap(binary_t::container_type&) with \", type_name()), this));\n        }\n    }\n\n    /// @}\n\n    //////////////////////////////////////////\n    // lexicographical comparison operators //\n    //////////////////////////////////////////\n\n    /// @name lexicographical comparison 
operators\n    /// @{\n\n    // note parentheses around operands are necessary; see\n    // https://github.com/nlohmann/json/issues/1530\n#define JSON_IMPLEMENT_OPERATOR(op, null_result, unordered_result, default_result)                       \\\n    const auto lhs_type = lhs.type();                                                                    \\\n    const auto rhs_type = rhs.type();                                                                    \\\n    \\\n    if (lhs_type == rhs_type) /* NOLINT(readability/braces) */                                           \\\n    {                                                                                                    \\\n        switch (lhs_type)                                                                                \\\n        {                                                                                                \\\n            case value_t::array:                                                                         \\\n                return (*lhs.m_data.m_value.array) op (*rhs.m_data.m_value.array);                                     \\\n                \\\n            case value_t::object:                                                                        \\\n                return (*lhs.m_data.m_value.object) op (*rhs.m_data.m_value.object);                                   \\\n                \\\n            case value_t::null:                                                                          \\\n                return (null_result);                                                                    \\\n                \\\n            case value_t::string:                                                                        \\\n                return (*lhs.m_data.m_value.string) op (*rhs.m_data.m_value.string);                                   \\\n                \\\n            case value_t::boolean:                                                               
        \\\n                return (lhs.m_data.m_value.boolean) op (rhs.m_data.m_value.boolean);                                   \\\n                \\\n            case value_t::number_integer:                                                                \\\n                return (lhs.m_data.m_value.number_integer) op (rhs.m_data.m_value.number_integer);                     \\\n                \\\n            case value_t::number_unsigned:                                                               \\\n                return (lhs.m_data.m_value.number_unsigned) op (rhs.m_data.m_value.number_unsigned);                   \\\n                \\\n            case value_t::number_float:                                                                  \\\n                return (lhs.m_data.m_value.number_float) op (rhs.m_data.m_value.number_float);                         \\\n                \\\n            case value_t::binary:                                                                        \\\n                return (*lhs.m_data.m_value.binary) op (*rhs.m_data.m_value.binary);                                   \\\n                \\\n            case value_t::discarded:                                                                     \\\n            default:                                                                                     \\\n                return (unordered_result);                                                               \\\n        }                                                                                                \\\n    }                                                                                                    \\\n    else if (lhs_type == value_t::number_integer && rhs_type == value_t::number_float)                   \\\n    {                                                                                                    \\\n        return 
static_cast<number_float_t>(lhs.m_data.m_value.number_integer) op rhs.m_data.m_value.number_float;      \\\n    }                                                                                                    \\\n    else if (lhs_type == value_t::number_float && rhs_type == value_t::number_integer)                   \\\n    {                                                                                                    \\\n        return lhs.m_data.m_value.number_float op static_cast<number_float_t>(rhs.m_data.m_value.number_integer);      \\\n    }                                                                                                    \\\n    else if (lhs_type == value_t::number_unsigned && rhs_type == value_t::number_float)                  \\\n    {                                                                                                    \\\n        return static_cast<number_float_t>(lhs.m_data.m_value.number_unsigned) op rhs.m_data.m_value.number_float;     \\\n    }                                                                                                    \\\n    else if (lhs_type == value_t::number_float && rhs_type == value_t::number_unsigned)                  \\\n    {                                                                                                    \\\n        return lhs.m_data.m_value.number_float op static_cast<number_float_t>(rhs.m_data.m_value.number_unsigned);     \\\n    }                                                                                                    \\\n    else if (lhs_type == value_t::number_unsigned && rhs_type == value_t::number_integer)                \\\n    {                                                                                                    \\\n        return static_cast<number_integer_t>(lhs.m_data.m_value.number_unsigned) op rhs.m_data.m_value.number_integer; \\\n    }                                                                                       
             \\\n    else if (lhs_type == value_t::number_integer && rhs_type == value_t::number_unsigned)                \\\n    {                                                                                                    \\\n        return lhs.m_data.m_value.number_integer op static_cast<number_integer_t>(rhs.m_data.m_value.number_unsigned); \\\n    }                                                                                                    \\\n    else if(compares_unordered(lhs, rhs))\\\n    {\\\n        return (unordered_result);\\\n    }\\\n    \\\n    return (default_result);\n\n  JSON_PRIVATE_UNLESS_TESTED:\n    // returns true if:\n    // - any operand is NaN and the other operand is of number type\n    // - any operand is discarded\n    // in legacy mode, discarded values are considered ordered if\n    // an operation is computed as an odd number of inverses of others\n    static bool compares_unordered(const_reference lhs, const_reference rhs, bool inverse = false) noexcept\n    {\n        if ((lhs.is_number_float() && std::isnan(lhs.m_data.m_value.number_float) && rhs.is_number())\n                || (rhs.is_number_float() && std::isnan(rhs.m_data.m_value.number_float) && lhs.is_number()))\n        {\n            return true;\n        }\n#if JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON\n        return (lhs.is_discarded() || rhs.is_discarded()) && !inverse;\n#else\n        static_cast<void>(inverse);\n        return lhs.is_discarded() || rhs.is_discarded();\n#endif\n    }\n\n  private:\n    bool compares_unordered(const_reference rhs, bool inverse = false) const noexcept\n    {\n        return compares_unordered(*this, rhs, inverse);\n    }\n\n  public:\n#if JSON_HAS_THREE_WAY_COMPARISON\n    /// @brief comparison: equal\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_eq/\n    bool operator==(const_reference rhs) const noexcept\n    {\n#ifdef __GNUC__\n#pragma GCC diagnostic push\n#pragma GCC diagnostic ignored 
\"-Wfloat-equal\"\n#endif\n        const_reference lhs = *this;\n        JSON_IMPLEMENT_OPERATOR( ==, true, false, false)\n#ifdef __GNUC__\n#pragma GCC diagnostic pop\n#endif\n    }\n\n    /// @brief comparison: equal\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_eq/\n    template<typename ScalarType>\n    requires std::is_scalar_v<ScalarType>\n    bool operator==(ScalarType rhs) const noexcept\n    {\n        return *this == basic_json(rhs);\n    }\n\n    /// @brief comparison: not equal\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_ne/\n    bool operator!=(const_reference rhs) const noexcept\n    {\n        if (compares_unordered(rhs, true))\n        {\n            return false;\n        }\n        return !operator==(rhs);\n    }\n\n    /// @brief comparison: 3-way\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_spaceship/\n    std::partial_ordering operator<=>(const_reference rhs) const noexcept // *NOPAD*\n    {\n        const_reference lhs = *this;\n        // default_result is used if we cannot compare values. 
In that case,\n        // we compare types.\n        JSON_IMPLEMENT_OPERATOR(<=>, // *NOPAD*\n                                std::partial_ordering::equivalent,\n                                std::partial_ordering::unordered,\n                                lhs_type <=> rhs_type) // *NOPAD*\n    }\n\n    /// @brief comparison: 3-way\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_spaceship/\n    template<typename ScalarType>\n    requires std::is_scalar_v<ScalarType>\n    std::partial_ordering operator<=>(ScalarType rhs) const noexcept // *NOPAD*\n    {\n        return *this <=> basic_json(rhs); // *NOPAD*\n    }\n\n#if JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON\n    // all operators that are computed as an odd number of inverses of others\n    // need to be overloaded to emulate the legacy comparison behavior\n\n    /// @brief comparison: less than or equal\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_le/\n    JSON_HEDLEY_DEPRECATED_FOR(3.11.0, undef JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON)\n    bool operator<=(const_reference rhs) const noexcept\n    {\n        if (compares_unordered(rhs, true))\n        {\n            return false;\n        }\n        return !(rhs < *this);\n    }\n\n    /// @brief comparison: less than or equal\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_le/\n    template<typename ScalarType>\n    requires std::is_scalar_v<ScalarType>\n    bool operator<=(ScalarType rhs) const noexcept\n    {\n        return *this <= basic_json(rhs);\n    }\n\n    /// @brief comparison: greater than or equal\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_ge/\n    JSON_HEDLEY_DEPRECATED_FOR(3.11.0, undef JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON)\n    bool operator>=(const_reference rhs) const noexcept\n    {\n        if (compares_unordered(rhs, true))\n        {\n            return false;\n        }\n        return !(*this < rhs);\n    }\n\n    /// @brief comparison: greater than or equal\n  
  /// @sa https://json.nlohmann.me/api/basic_json/operator_ge/\n    template<typename ScalarType>\n    requires std::is_scalar_v<ScalarType>\n    bool operator>=(ScalarType rhs) const noexcept\n    {\n        return *this >= basic_json(rhs);\n    }\n#endif\n#else\n    /// @brief comparison: equal\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_eq/\n    friend bool operator==(const_reference lhs, const_reference rhs) noexcept\n    {\n#ifdef __GNUC__\n#pragma GCC diagnostic push\n#pragma GCC diagnostic ignored \"-Wfloat-equal\"\n#endif\n        JSON_IMPLEMENT_OPERATOR( ==, true, false, false)\n#ifdef __GNUC__\n#pragma GCC diagnostic pop\n#endif\n    }\n\n    /// @brief comparison: equal\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_eq/\n    template<typename ScalarType, typename std::enable_if<\n                 std::is_scalar<ScalarType>::value, int>::type = 0>\n    friend bool operator==(const_reference lhs, ScalarType rhs) noexcept\n    {\n        return lhs == basic_json(rhs);\n    }\n\n    /// @brief comparison: equal\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_eq/\n    template<typename ScalarType, typename std::enable_if<\n                 std::is_scalar<ScalarType>::value, int>::type = 0>\n    friend bool operator==(ScalarType lhs, const_reference rhs) noexcept\n    {\n        return basic_json(lhs) == rhs;\n    }\n\n    /// @brief comparison: not equal\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_ne/\n    friend bool operator!=(const_reference lhs, const_reference rhs) noexcept\n    {\n        if (compares_unordered(lhs, rhs, true))\n        {\n            return false;\n        }\n        return !(lhs == rhs);\n    }\n\n    /// @brief comparison: not equal\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_ne/\n    template<typename ScalarType, typename std::enable_if<\n                 std::is_scalar<ScalarType>::value, int>::type = 0>\n    friend bool operator!=(const_reference 
lhs, ScalarType rhs) noexcept\n    {\n        return lhs != basic_json(rhs);\n    }\n\n    /// @brief comparison: not equal\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_ne/\n    template<typename ScalarType, typename std::enable_if<\n                 std::is_scalar<ScalarType>::value, int>::type = 0>\n    friend bool operator!=(ScalarType lhs, const_reference rhs) noexcept\n    {\n        return basic_json(lhs) != rhs;\n    }\n\n    /// @brief comparison: less than\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_lt/\n    friend bool operator<(const_reference lhs, const_reference rhs) noexcept\n    {\n        // default_result is used if we cannot compare values. In that case,\n        // we compare types. Note we have to call the operator explicitly,\n        // because MSVC has problems otherwise.\n        JSON_IMPLEMENT_OPERATOR( <, false, false, operator<(lhs_type, rhs_type))\n    }\n\n    /// @brief comparison: less than\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_lt/\n    template<typename ScalarType, typename std::enable_if<\n                 std::is_scalar<ScalarType>::value, int>::type = 0>\n    friend bool operator<(const_reference lhs, ScalarType rhs) noexcept\n    {\n        return lhs < basic_json(rhs);\n    }\n\n    /// @brief comparison: less than\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_lt/\n    template<typename ScalarType, typename std::enable_if<\n                 std::is_scalar<ScalarType>::value, int>::type = 0>\n    friend bool operator<(ScalarType lhs, const_reference rhs) noexcept\n    {\n        return basic_json(lhs) < rhs;\n    }\n\n    /// @brief comparison: less than or equal\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_le/\n    friend bool operator<=(const_reference lhs, const_reference rhs) noexcept\n    {\n        if (compares_unordered(lhs, rhs, true))\n        {\n            return false;\n        }\n        return !(rhs < lhs);\n    }\n\n    /// 
@brief comparison: less than or equal\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_le/\n    template<typename ScalarType, typename std::enable_if<\n                 std::is_scalar<ScalarType>::value, int>::type = 0>\n    friend bool operator<=(const_reference lhs, ScalarType rhs) noexcept\n    {\n        return lhs <= basic_json(rhs);\n    }\n\n    /// @brief comparison: less than or equal\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_le/\n    template<typename ScalarType, typename std::enable_if<\n                 std::is_scalar<ScalarType>::value, int>::type = 0>\n    friend bool operator<=(ScalarType lhs, const_reference rhs) noexcept\n    {\n        return basic_json(lhs) <= rhs;\n    }\n\n    /// @brief comparison: greater than\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_gt/\n    friend bool operator>(const_reference lhs, const_reference rhs) noexcept\n    {\n        // double inverse\n        if (compares_unordered(lhs, rhs))\n        {\n            return false;\n        }\n        return !(lhs <= rhs);\n    }\n\n    /// @brief comparison: greater than\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_gt/\n    template<typename ScalarType, typename std::enable_if<\n                 std::is_scalar<ScalarType>::value, int>::type = 0>\n    friend bool operator>(const_reference lhs, ScalarType rhs) noexcept\n    {\n        return lhs > basic_json(rhs);\n    }\n\n    /// @brief comparison: greater than\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_gt/\n    template<typename ScalarType, typename std::enable_if<\n                 std::is_scalar<ScalarType>::value, int>::type = 0>\n    friend bool operator>(ScalarType lhs, const_reference rhs) noexcept\n    {\n        return basic_json(lhs) > rhs;\n    }\n\n    /// @brief comparison: greater than or equal\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_ge/\n    friend bool operator>=(const_reference lhs, const_reference rhs) 
noexcept\n    {\n        if (compares_unordered(lhs, rhs, true))\n        {\n            return false;\n        }\n        return !(lhs < rhs);\n    }\n\n    /// @brief comparison: greater than or equal\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_ge/\n    template<typename ScalarType, typename std::enable_if<\n                 std::is_scalar<ScalarType>::value, int>::type = 0>\n    friend bool operator>=(const_reference lhs, ScalarType rhs) noexcept\n    {\n        return lhs >= basic_json(rhs);\n    }\n\n    /// @brief comparison: greater than or equal\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_ge/\n    template<typename ScalarType, typename std::enable_if<\n                 std::is_scalar<ScalarType>::value, int>::type = 0>\n    friend bool operator>=(ScalarType lhs, const_reference rhs) noexcept\n    {\n        return basic_json(lhs) >= rhs;\n    }\n#endif\n\n#undef JSON_IMPLEMENT_OPERATOR\n\n    /// @}\n\n    ///////////////////\n    // serialization //\n    ///////////////////\n\n    /// @name serialization\n    /// @{\n#ifndef JSON_NO_IO\n    /// @brief serialize to stream\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_ltlt/\n    friend std::ostream& operator<<(std::ostream& o, const basic_json& j)\n    {\n        // read width member and use it as indentation parameter if nonzero\n        const bool pretty_print = o.width() > 0;\n        const auto indentation = pretty_print ? 
o.width() : 0;\n\n        // reset width to 0 for subsequent calls to this stream\n        o.width(0);\n\n        // do the actual serialization\n        serializer s(detail::output_adapter<char>(o), o.fill());\n        s.dump(j, pretty_print, false, static_cast<unsigned int>(indentation));\n        return o;\n    }\n\n    /// @brief serialize to stream\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_ltlt/\n    /// @deprecated This function is deprecated since 3.0.0 and will be removed in\n    ///             version 4.0.0 of the library. Please use\n    ///             operator<<(std::ostream&, const basic_json&) instead; that is,\n    ///             replace calls like `j >> o;` with `o << j;`.\n    JSON_HEDLEY_DEPRECATED_FOR(3.0.0, operator<<(std::ostream&, const basic_json&))\n    friend std::ostream& operator>>(const basic_json& j, std::ostream& o)\n    {\n        return o << j;\n    }\n#endif  // JSON_NO_IO\n    /// @}\n\n    /////////////////////\n    // deserialization //\n    /////////////////////\n\n    /// @name deserialization\n    /// @{\n\n    /// @brief deserialize from a compatible input\n    /// @sa https://json.nlohmann.me/api/basic_json/parse/\n    template<typename InputType>\n    JSON_HEDLEY_WARN_UNUSED_RESULT\n    static basic_json parse(InputType&& i,\n                            const parser_callback_t cb = nullptr,\n                            const bool allow_exceptions = true,\n                            const bool ignore_comments = false)\n    {\n        basic_json result;\n        parser(detail::input_adapter(std::forward<InputType>(i)), cb, allow_exceptions, ignore_comments).parse(true, result);\n        return result;\n    }\n\n    /// @brief deserialize from a pair of character iterators\n    /// @sa https://json.nlohmann.me/api/basic_json/parse/\n    template<typename IteratorType>\n    JSON_HEDLEY_WARN_UNUSED_RESULT\n    static basic_json parse(IteratorType first,\n                            IteratorType last,\n     
                       const parser_callback_t cb = nullptr,\n                            const bool allow_exceptions = true,\n                            const bool ignore_comments = false)\n    {\n        basic_json result;\n        parser(detail::input_adapter(std::move(first), std::move(last)), cb, allow_exceptions, ignore_comments).parse(true, result);\n        return result;\n    }\n\n    JSON_HEDLEY_WARN_UNUSED_RESULT\n    JSON_HEDLEY_DEPRECATED_FOR(3.8.0, parse(ptr, ptr + len))\n    static basic_json parse(detail::span_input_adapter&& i,\n                            const parser_callback_t cb = nullptr,\n                            const bool allow_exceptions = true,\n                            const bool ignore_comments = false)\n    {\n        basic_json result;\n        parser(i.get(), cb, allow_exceptions, ignore_comments).parse(true, result);\n        return result;\n    }\n\n    /// @brief check if the input is valid JSON\n    /// @sa https://json.nlohmann.me/api/basic_json/accept/\n    template<typename InputType>\n    static bool accept(InputType&& i,\n                       const bool ignore_comments = false)\n    {\n        return parser(detail::input_adapter(std::forward<InputType>(i)), nullptr, false, ignore_comments).accept(true);\n    }\n\n    /// @brief check if the input is valid JSON\n    /// @sa https://json.nlohmann.me/api/basic_json/accept/\n    template<typename IteratorType>\n    static bool accept(IteratorType first, IteratorType last,\n                       const bool ignore_comments = false)\n    {\n        return parser(detail::input_adapter(std::move(first), std::move(last)), nullptr, false, ignore_comments).accept(true);\n    }\n\n    JSON_HEDLEY_WARN_UNUSED_RESULT\n    JSON_HEDLEY_DEPRECATED_FOR(3.8.0, accept(ptr, ptr + len))\n    static bool accept(detail::span_input_adapter&& i,\n                       const bool ignore_comments = false)\n    {\n        return parser(i.get(), nullptr, false, ignore_comments).accept(true);\n  
  }\n\n    /// @brief generate SAX events\n    /// @sa https://json.nlohmann.me/api/basic_json/sax_parse/\n    template <typename InputType, typename SAX>\n    JSON_HEDLEY_NON_NULL(2)\n    static bool sax_parse(InputType&& i, SAX* sax,\n                          input_format_t format = input_format_t::json,\n                          const bool strict = true,\n                          const bool ignore_comments = false)\n    {\n        auto ia = detail::input_adapter(std::forward<InputType>(i));\n        return format == input_format_t::json\n               ? parser(std::move(ia), nullptr, true, ignore_comments).sax_parse(sax, strict)\n               : detail::binary_reader<basic_json, decltype(ia), SAX>(std::move(ia), format).sax_parse(format, sax, strict);\n    }\n\n    /// @brief generate SAX events\n    /// @sa https://json.nlohmann.me/api/basic_json/sax_parse/\n    template<class IteratorType, class SAX>\n    JSON_HEDLEY_NON_NULL(3)\n    static bool sax_parse(IteratorType first, IteratorType last, SAX* sax,\n                          input_format_t format = input_format_t::json,\n                          const bool strict = true,\n                          const bool ignore_comments = false)\n    {\n        auto ia = detail::input_adapter(std::move(first), std::move(last));\n        return format == input_format_t::json\n               ? parser(std::move(ia), nullptr, true, ignore_comments).sax_parse(sax, strict)\n               : detail::binary_reader<basic_json, decltype(ia), SAX>(std::move(ia), format).sax_parse(format, sax, strict);\n    }\n\n    /// @brief generate SAX events\n    /// @sa https://json.nlohmann.me/api/basic_json/sax_parse/\n    /// @deprecated This function is deprecated since 3.8.0 and will be removed in\n    ///             version 4.0.0 of the library. 
Please use\n    ///             sax_parse(ptr, ptr + len) instead.\n    template <typename SAX>\n    JSON_HEDLEY_DEPRECATED_FOR(3.8.0, sax_parse(ptr, ptr + len, ...))\n    JSON_HEDLEY_NON_NULL(2)\n    static bool sax_parse(detail::span_input_adapter&& i, SAX* sax,\n                          input_format_t format = input_format_t::json,\n                          const bool strict = true,\n                          const bool ignore_comments = false)\n    {\n        auto ia = i.get();\n        return format == input_format_t::json\n               // NOLINTNEXTLINE(hicpp-move-const-arg,performance-move-const-arg)\n               ? parser(std::move(ia), nullptr, true, ignore_comments).sax_parse(sax, strict)\n               // NOLINTNEXTLINE(hicpp-move-const-arg,performance-move-const-arg)\n               : detail::binary_reader<basic_json, decltype(ia), SAX>(std::move(ia), format).sax_parse(format, sax, strict);\n    }\n#ifndef JSON_NO_IO\n    /// @brief deserialize from stream\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_gtgt/\n    /// @deprecated This stream operator is deprecated since 3.0.0 and will be removed in\n    ///             version 4.0.0 of the library. 
Please use\n    ///             operator>>(std::istream&, basic_json&) instead; that is,\n    ///             replace calls like `j << i;` with `i >> j;`.\n    JSON_HEDLEY_DEPRECATED_FOR(3.0.0, operator>>(std::istream&, basic_json&))\n    friend std::istream& operator<<(basic_json& j, std::istream& i)\n    {\n        return operator>>(i, j);\n    }\n\n    /// @brief deserialize from stream\n    /// @sa https://json.nlohmann.me/api/basic_json/operator_gtgt/\n    friend std::istream& operator>>(std::istream& i, basic_json& j)\n    {\n        parser(detail::input_adapter(i)).parse(false, j);\n        return i;\n    }\n#endif  // JSON_NO_IO\n    /// @}\n\n    ///////////////////////////\n    // convenience functions //\n    ///////////////////////////\n\n    /// @brief return the type as string\n    /// @sa https://json.nlohmann.me/api/basic_json/type_name/\n    JSON_HEDLEY_RETURNS_NON_NULL\n    const char* type_name() const noexcept\n    {\n        switch (m_data.m_type)\n        {\n            case value_t::null:\n                return \"null\";\n            case value_t::object:\n                return \"object\";\n            case value_t::array:\n                return \"array\";\n            case value_t::string:\n                return \"string\";\n            case value_t::boolean:\n                return \"boolean\";\n            case value_t::binary:\n                return \"binary\";\n            case value_t::discarded:\n                return \"discarded\";\n            case value_t::number_integer:\n            case value_t::number_unsigned:\n            case value_t::number_float:\n            default:\n                return \"number\";\n        }\n    }\n\n  JSON_PRIVATE_UNLESS_TESTED:\n    //////////////////////\n    // member variables //\n    //////////////////////\n\n    struct data\n    {\n        /// the type of the current element\n        value_t m_type = value_t::null;\n\n        /// the value of the current element\n        json_value 
m_value = {};\n\n        data(const value_t v)\n            : m_type(v), m_value(v)\n        {\n        }\n\n        data(size_type cnt, const basic_json& val)\n            : m_type(value_t::array)\n        {\n            m_value.array = create<array_t>(cnt, val);\n        }\n\n        data() noexcept = default;\n        data(data&&) noexcept = default;\n        data(const data&) noexcept = delete;\n        data& operator=(data&&) noexcept = delete;\n        data& operator=(const data&) noexcept = delete;\n\n        ~data() noexcept\n        {\n            m_value.destroy(m_type);\n        }\n    };\n\n    data m_data = {};\n\n#if JSON_DIAGNOSTICS\n    /// a pointer to a parent value (for debugging purposes)\n    basic_json* m_parent = nullptr;\n#endif\n\n    //////////////////////////////////////////\n    // binary serialization/deserialization //\n    //////////////////////////////////////////\n\n    /// @name binary serialization/deserialization support\n    /// @{\n\n  public:\n    /// @brief create a CBOR serialization of a given JSON value\n    /// @sa https://json.nlohmann.me/api/basic_json/to_cbor/\n    static std::vector<std::uint8_t> to_cbor(const basic_json& j)\n    {\n        std::vector<std::uint8_t> result;\n        to_cbor(j, result);\n        return result;\n    }\n\n    /// @brief create a CBOR serialization of a given JSON value\n    /// @sa https://json.nlohmann.me/api/basic_json/to_cbor/\n    static void to_cbor(const basic_json& j, detail::output_adapter<std::uint8_t> o)\n    {\n        binary_writer<std::uint8_t>(o).write_cbor(j);\n    }\n\n    /// @brief create a CBOR serialization of a given JSON value\n    /// @sa https://json.nlohmann.me/api/basic_json/to_cbor/\n    static void to_cbor(const basic_json& j, detail::output_adapter<char> o)\n    {\n        binary_writer<char>(o).write_cbor(j);\n    }\n\n    /// @brief create a MessagePack serialization of a given JSON value\n    /// @sa https://json.nlohmann.me/api/basic_json/to_msgpack/\n    
static std::vector<std::uint8_t> to_msgpack(const basic_json& j)\n    {\n        std::vector<std::uint8_t> result;\n        to_msgpack(j, result);\n        return result;\n    }\n\n    /// @brief create a MessagePack serialization of a given JSON value\n    /// @sa https://json.nlohmann.me/api/basic_json/to_msgpack/\n    static void to_msgpack(const basic_json& j, detail::output_adapter<std::uint8_t> o)\n    {\n        binary_writer<std::uint8_t>(o).write_msgpack(j);\n    }\n\n    /// @brief create a MessagePack serialization of a given JSON value\n    /// @sa https://json.nlohmann.me/api/basic_json/to_msgpack/\n    static void to_msgpack(const basic_json& j, detail::output_adapter<char> o)\n    {\n        binary_writer<char>(o).write_msgpack(j);\n    }\n\n    /// @brief create a UBJSON serialization of a given JSON value\n    /// @sa https://json.nlohmann.me/api/basic_json/to_ubjson/\n    static std::vector<std::uint8_t> to_ubjson(const basic_json& j,\n            const bool use_size = false,\n            const bool use_type = false)\n    {\n        std::vector<std::uint8_t> result;\n        to_ubjson(j, result, use_size, use_type);\n        return result;\n    }\n\n    /// @brief create a UBJSON serialization of a given JSON value\n    /// @sa https://json.nlohmann.me/api/basic_json/to_ubjson/\n    static void to_ubjson(const basic_json& j, detail::output_adapter<std::uint8_t> o,\n                          const bool use_size = false, const bool use_type = false)\n    {\n        binary_writer<std::uint8_t>(o).write_ubjson(j, use_size, use_type);\n    }\n\n    /// @brief create a UBJSON serialization of a given JSON value\n    /// @sa https://json.nlohmann.me/api/basic_json/to_ubjson/\n    static void to_ubjson(const basic_json& j, detail::output_adapter<char> o,\n                          const bool use_size = false, const bool use_type = false)\n    {\n        binary_writer<char>(o).write_ubjson(j, use_size, use_type);\n    }\n\n    /// @brief create a BJData 
serialization of a given JSON value\n    /// @sa https://json.nlohmann.me/api/basic_json/to_bjdata/\n    static std::vector<std::uint8_t> to_bjdata(const basic_json& j,\n            const bool use_size = false,\n            const bool use_type = false)\n    {\n        std::vector<std::uint8_t> result;\n        to_bjdata(j, result, use_size, use_type);\n        return result;\n    }\n\n    /// @brief create a BJData serialization of a given JSON value\n    /// @sa https://json.nlohmann.me/api/basic_json/to_bjdata/\n    static void to_bjdata(const basic_json& j, detail::output_adapter<std::uint8_t> o,\n                          const bool use_size = false, const bool use_type = false)\n    {\n        binary_writer<std::uint8_t>(o).write_ubjson(j, use_size, use_type, true, true);\n    }\n\n    /// @brief create a BJData serialization of a given JSON value\n    /// @sa https://json.nlohmann.me/api/basic_json/to_bjdata/\n    static void to_bjdata(const basic_json& j, detail::output_adapter<char> o,\n                          const bool use_size = false, const bool use_type = false)\n    {\n        binary_writer<char>(o).write_ubjson(j, use_size, use_type, true, true);\n    }\n\n    /// @brief create a BSON serialization of a given JSON value\n    /// @sa https://json.nlohmann.me/api/basic_json/to_bson/\n    static std::vector<std::uint8_t> to_bson(const basic_json& j)\n    {\n        std::vector<std::uint8_t> result;\n        to_bson(j, result);\n        return result;\n    }\n\n    /// @brief create a BSON serialization of a given JSON value\n    /// @sa https://json.nlohmann.me/api/basic_json/to_bson/\n    static void to_bson(const basic_json& j, detail::output_adapter<std::uint8_t> o)\n    {\n        binary_writer<std::uint8_t>(o).write_bson(j);\n    }\n\n    /// @brief create a BSON serialization of a given JSON value\n    /// @sa https://json.nlohmann.me/api/basic_json/to_bson/\n    static void to_bson(const basic_json& j, detail::output_adapter<char> o)\n    {\n   
     binary_writer<char>(o).write_bson(j);\n    }\n\n    /// @brief create a JSON value from an input in CBOR format\n    /// @sa https://json.nlohmann.me/api/basic_json/from_cbor/\n    template<typename InputType>\n    JSON_HEDLEY_WARN_UNUSED_RESULT\n    static basic_json from_cbor(InputType&& i,\n                                const bool strict = true,\n                                const bool allow_exceptions = true,\n                                const cbor_tag_handler_t tag_handler = cbor_tag_handler_t::error)\n    {\n        basic_json result;\n        detail::json_sax_dom_parser<basic_json> sdp(result, allow_exceptions);\n        auto ia = detail::input_adapter(std::forward<InputType>(i));\n        const bool res = binary_reader<decltype(ia)>(std::move(ia), input_format_t::cbor).sax_parse(input_format_t::cbor, &sdp, strict, tag_handler);\n        return res ? result : basic_json(value_t::discarded);\n    }\n\n    /// @brief create a JSON value from an input in CBOR format\n    /// @sa https://json.nlohmann.me/api/basic_json/from_cbor/\n    template<typename IteratorType>\n    JSON_HEDLEY_WARN_UNUSED_RESULT\n    static basic_json from_cbor(IteratorType first, IteratorType last,\n                                const bool strict = true,\n                                const bool allow_exceptions = true,\n                                const cbor_tag_handler_t tag_handler = cbor_tag_handler_t::error)\n    {\n        basic_json result;\n        detail::json_sax_dom_parser<basic_json> sdp(result, allow_exceptions);\n        auto ia = detail::input_adapter(std::move(first), std::move(last));\n        const bool res = binary_reader<decltype(ia)>(std::move(ia), input_format_t::cbor).sax_parse(input_format_t::cbor, &sdp, strict, tag_handler);\n        return res ? 
result : basic_json(value_t::discarded);\n    }\n\n    template<typename T>\n    JSON_HEDLEY_WARN_UNUSED_RESULT\n    JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_cbor(ptr, ptr + len))\n    static basic_json from_cbor(const T* ptr, std::size_t len,\n                                const bool strict = true,\n                                const bool allow_exceptions = true,\n                                const cbor_tag_handler_t tag_handler = cbor_tag_handler_t::error)\n    {\n        return from_cbor(ptr, ptr + len, strict, allow_exceptions, tag_handler);\n    }\n\n    JSON_HEDLEY_WARN_UNUSED_RESULT\n    JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_cbor(ptr, ptr + len))\n    static basic_json from_cbor(detail::span_input_adapter&& i,\n                                const bool strict = true,\n                                const bool allow_exceptions = true,\n                                const cbor_tag_handler_t tag_handler = cbor_tag_handler_t::error)\n    {\n        basic_json result;\n        detail::json_sax_dom_parser<basic_json> sdp(result, allow_exceptions);\n        auto ia = i.get();\n        // NOLINTNEXTLINE(hicpp-move-const-arg,performance-move-const-arg)\n        const bool res = binary_reader<decltype(ia)>(std::move(ia), input_format_t::cbor).sax_parse(input_format_t::cbor, &sdp, strict, tag_handler);\n        return res ? 
result : basic_json(value_t::discarded);\n    }\n\n    /// @brief create a JSON value from an input in MessagePack format\n    /// @sa https://json.nlohmann.me/api/basic_json/from_msgpack/\n    template<typename InputType>\n    JSON_HEDLEY_WARN_UNUSED_RESULT\n    static basic_json from_msgpack(InputType&& i,\n                                   const bool strict = true,\n                                   const bool allow_exceptions = true)\n    {\n        basic_json result;\n        detail::json_sax_dom_parser<basic_json> sdp(result, allow_exceptions);\n        auto ia = detail::input_adapter(std::forward<InputType>(i));\n        const bool res = binary_reader<decltype(ia)>(std::move(ia), input_format_t::msgpack).sax_parse(input_format_t::msgpack, &sdp, strict);\n        return res ? result : basic_json(value_t::discarded);\n    }\n\n    /// @brief create a JSON value from an input in MessagePack format\n    /// @sa https://json.nlohmann.me/api/basic_json/from_msgpack/\n    template<typename IteratorType>\n    JSON_HEDLEY_WARN_UNUSED_RESULT\n    static basic_json from_msgpack(IteratorType first, IteratorType last,\n                                   const bool strict = true,\n                                   const bool allow_exceptions = true)\n    {\n        basic_json result;\n        detail::json_sax_dom_parser<basic_json> sdp(result, allow_exceptions);\n        auto ia = detail::input_adapter(std::move(first), std::move(last));\n        const bool res = binary_reader<decltype(ia)>(std::move(ia), input_format_t::msgpack).sax_parse(input_format_t::msgpack, &sdp, strict);\n        return res ? 
result : basic_json(value_t::discarded);\n    }\n\n    template<typename T>\n    JSON_HEDLEY_WARN_UNUSED_RESULT\n    JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_msgpack(ptr, ptr + len))\n    static basic_json from_msgpack(const T* ptr, std::size_t len,\n                                   const bool strict = true,\n                                   const bool allow_exceptions = true)\n    {\n        return from_msgpack(ptr, ptr + len, strict, allow_exceptions);\n    }\n\n    JSON_HEDLEY_WARN_UNUSED_RESULT\n    JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_msgpack(ptr, ptr + len))\n    static basic_json from_msgpack(detail::span_input_adapter&& i,\n                                   const bool strict = true,\n                                   const bool allow_exceptions = true)\n    {\n        basic_json result;\n        detail::json_sax_dom_parser<basic_json> sdp(result, allow_exceptions);\n        auto ia = i.get();\n        // NOLINTNEXTLINE(hicpp-move-const-arg,performance-move-const-arg)\n        const bool res = binary_reader<decltype(ia)>(std::move(ia), input_format_t::msgpack).sax_parse(input_format_t::msgpack, &sdp, strict);\n        return res ? result : basic_json(value_t::discarded);\n    }\n\n    /// @brief create a JSON value from an input in UBJSON format\n    /// @sa https://json.nlohmann.me/api/basic_json/from_ubjson/\n    template<typename InputType>\n    JSON_HEDLEY_WARN_UNUSED_RESULT\n    static basic_json from_ubjson(InputType&& i,\n                                  const bool strict = true,\n                                  const bool allow_exceptions = true)\n    {\n        basic_json result;\n        detail::json_sax_dom_parser<basic_json> sdp(result, allow_exceptions);\n        auto ia = detail::input_adapter(std::forward<InputType>(i));\n        const bool res = binary_reader<decltype(ia)>(std::move(ia), input_format_t::ubjson).sax_parse(input_format_t::ubjson, &sdp, strict);\n        return res ? 
result : basic_json(value_t::discarded);\n    }\n\n    /// @brief create a JSON value from an input in UBJSON format\n    /// @sa https://json.nlohmann.me/api/basic_json/from_ubjson/\n    template<typename IteratorType>\n    JSON_HEDLEY_WARN_UNUSED_RESULT\n    static basic_json from_ubjson(IteratorType first, IteratorType last,\n                                  const bool strict = true,\n                                  const bool allow_exceptions = true)\n    {\n        basic_json result;\n        detail::json_sax_dom_parser<basic_json> sdp(result, allow_exceptions);\n        auto ia = detail::input_adapter(std::move(first), std::move(last));\n        const bool res = binary_reader<decltype(ia)>(std::move(ia), input_format_t::ubjson).sax_parse(input_format_t::ubjson, &sdp, strict);\n        return res ? result : basic_json(value_t::discarded);\n    }\n\n    template<typename T>\n    JSON_HEDLEY_WARN_UNUSED_RESULT\n    JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_ubjson(ptr, ptr + len))\n    static basic_json from_ubjson(const T* ptr, std::size_t len,\n                                  const bool strict = true,\n                                  const bool allow_exceptions = true)\n    {\n        return from_ubjson(ptr, ptr + len, strict, allow_exceptions);\n    }\n\n    JSON_HEDLEY_WARN_UNUSED_RESULT\n    JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_ubjson(ptr, ptr + len))\n    static basic_json from_ubjson(detail::span_input_adapter&& i,\n                                  const bool strict = true,\n                                  const bool allow_exceptions = true)\n    {\n        basic_json result;\n        detail::json_sax_dom_parser<basic_json> sdp(result, allow_exceptions);\n        auto ia = i.get();\n        // NOLINTNEXTLINE(hicpp-move-const-arg,performance-move-const-arg)\n        const bool res = binary_reader<decltype(ia)>(std::move(ia), input_format_t::ubjson).sax_parse(input_format_t::ubjson, &sdp, strict);\n        return res ? 
result : basic_json(value_t::discarded);\n    }\n\n    /// @brief create a JSON value from an input in BJData format\n    /// @sa https://json.nlohmann.me/api/basic_json/from_bjdata/\n    template<typename InputType>\n    JSON_HEDLEY_WARN_UNUSED_RESULT\n    static basic_json from_bjdata(InputType&& i,\n                                  const bool strict = true,\n                                  const bool allow_exceptions = true)\n    {\n        basic_json result;\n        detail::json_sax_dom_parser<basic_json> sdp(result, allow_exceptions);\n        auto ia = detail::input_adapter(std::forward<InputType>(i));\n        const bool res = binary_reader<decltype(ia)>(std::move(ia), input_format_t::bjdata).sax_parse(input_format_t::bjdata, &sdp, strict);\n        return res ? result : basic_json(value_t::discarded);\n    }\n\n    /// @brief create a JSON value from an input in BJData format\n    /// @sa https://json.nlohmann.me/api/basic_json/from_bjdata/\n    template<typename IteratorType>\n    JSON_HEDLEY_WARN_UNUSED_RESULT\n    static basic_json from_bjdata(IteratorType first, IteratorType last,\n                                  const bool strict = true,\n                                  const bool allow_exceptions = true)\n    {\n        basic_json result;\n        detail::json_sax_dom_parser<basic_json> sdp(result, allow_exceptions);\n        auto ia = detail::input_adapter(std::move(first), std::move(last));\n        const bool res = binary_reader<decltype(ia)>(std::move(ia), input_format_t::bjdata).sax_parse(input_format_t::bjdata, &sdp, strict);\n        return res ? 
result : basic_json(value_t::discarded);\n    }\n\n    /// @brief create a JSON value from an input in BSON format\n    /// @sa https://json.nlohmann.me/api/basic_json/from_bson/\n    template<typename InputType>\n    JSON_HEDLEY_WARN_UNUSED_RESULT\n    static basic_json from_bson(InputType&& i,\n                                const bool strict = true,\n                                const bool allow_exceptions = true)\n    {\n        basic_json result;\n        detail::json_sax_dom_parser<basic_json> sdp(result, allow_exceptions);\n        auto ia = detail::input_adapter(std::forward<InputType>(i));\n        const bool res = binary_reader<decltype(ia)>(std::move(ia), input_format_t::bson).sax_parse(input_format_t::bson, &sdp, strict);\n        return res ? result : basic_json(value_t::discarded);\n    }\n\n    /// @brief create a JSON value from an input in BSON format\n    /// @sa https://json.nlohmann.me/api/basic_json/from_bson/\n    template<typename IteratorType>\n    JSON_HEDLEY_WARN_UNUSED_RESULT\n    static basic_json from_bson(IteratorType first, IteratorType last,\n                                const bool strict = true,\n                                const bool allow_exceptions = true)\n    {\n        basic_json result;\n        detail::json_sax_dom_parser<basic_json> sdp(result, allow_exceptions);\n        auto ia = detail::input_adapter(std::move(first), std::move(last));\n        const bool res = binary_reader<decltype(ia)>(std::move(ia), input_format_t::bson).sax_parse(input_format_t::bson, &sdp, strict);\n        return res ? 
result : basic_json(value_t::discarded);\n    }\n\n    template<typename T>\n    JSON_HEDLEY_WARN_UNUSED_RESULT\n    JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_bson(ptr, ptr + len))\n    static basic_json from_bson(const T* ptr, std::size_t len,\n                                const bool strict = true,\n                                const bool allow_exceptions = true)\n    {\n        return from_bson(ptr, ptr + len, strict, allow_exceptions);\n    }\n\n    JSON_HEDLEY_WARN_UNUSED_RESULT\n    JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_bson(ptr, ptr + len))\n    static basic_json from_bson(detail::span_input_adapter&& i,\n                                const bool strict = true,\n                                const bool allow_exceptions = true)\n    {\n        basic_json result;\n        detail::json_sax_dom_parser<basic_json> sdp(result, allow_exceptions);\n        auto ia = i.get();\n        // NOLINTNEXTLINE(hicpp-move-const-arg,performance-move-const-arg)\n        const bool res = binary_reader<decltype(ia)>(std::move(ia), input_format_t::bson).sax_parse(input_format_t::bson, &sdp, strict);\n        return res ? 
result : basic_json(value_t::discarded);\n    }\n    /// @}\n\n    //////////////////////////\n    // JSON Pointer support //\n    //////////////////////////\n\n    /// @name JSON Pointer functions\n    /// @{\n\n    /// @brief access specified element via JSON Pointer\n    /// @sa https://json.nlohmann.me/api/basic_json/operator%5B%5D/\n    reference operator[](const json_pointer& ptr)\n    {\n        return ptr.get_unchecked(this);\n    }\n\n    template<typename BasicJsonType, detail::enable_if_t<detail::is_basic_json<BasicJsonType>::value, int> = 0>\n    JSON_HEDLEY_DEPRECATED_FOR(3.11.0, basic_json::json_pointer or nlohmann::json_pointer<basic_json::string_t>) // NOLINT(readability/alt_tokens)\n    reference operator[](const ::nlohmann::json_pointer<BasicJsonType>& ptr)\n    {\n        return ptr.get_unchecked(this);\n    }\n\n    /// @brief access specified element via JSON Pointer\n    /// @sa https://json.nlohmann.me/api/basic_json/operator%5B%5D/\n    const_reference operator[](const json_pointer& ptr) const\n    {\n        return ptr.get_unchecked(this);\n    }\n\n    template<typename BasicJsonType, detail::enable_if_t<detail::is_basic_json<BasicJsonType>::value, int> = 0>\n    JSON_HEDLEY_DEPRECATED_FOR(3.11.0, basic_json::json_pointer or nlohmann::json_pointer<basic_json::string_t>) // NOLINT(readability/alt_tokens)\n    const_reference operator[](const ::nlohmann::json_pointer<BasicJsonType>& ptr) const\n    {\n        return ptr.get_unchecked(this);\n    }\n\n    /// @brief access specified element via JSON Pointer\n    /// @sa https://json.nlohmann.me/api/basic_json/at/\n    reference at(const json_pointer& ptr)\n    {\n        return ptr.get_checked(this);\n    }\n\n    template<typename BasicJsonType, detail::enable_if_t<detail::is_basic_json<BasicJsonType>::value, int> = 0>\n    JSON_HEDLEY_DEPRECATED_FOR(3.11.0, basic_json::json_pointer or nlohmann::json_pointer<basic_json::string_t>) // NOLINT(readability/alt_tokens)\n    reference at(const 
::nlohmann::json_pointer<BasicJsonType>& ptr)\n    {\n        return ptr.get_checked(this);\n    }\n\n    /// @brief access specified element via JSON Pointer\n    /// @sa https://json.nlohmann.me/api/basic_json/at/\n    const_reference at(const json_pointer& ptr) const\n    {\n        return ptr.get_checked(this);\n    }\n\n    template<typename BasicJsonType, detail::enable_if_t<detail::is_basic_json<BasicJsonType>::value, int> = 0>\n    JSON_HEDLEY_DEPRECATED_FOR(3.11.0, basic_json::json_pointer or nlohmann::json_pointer<basic_json::string_t>) // NOLINT(readability/alt_tokens)\n    const_reference at(const ::nlohmann::json_pointer<BasicJsonType>& ptr) const\n    {\n        return ptr.get_checked(this);\n    }\n\n    /// @brief return flattened JSON value\n    /// @sa https://json.nlohmann.me/api/basic_json/flatten/\n    basic_json flatten() const\n    {\n        basic_json result(value_t::object);\n        json_pointer::flatten(\"\", *this, result);\n        return result;\n    }\n\n    /// @brief unflatten a previously flattened JSON value\n    /// @sa https://json.nlohmann.me/api/basic_json/unflatten/\n    basic_json unflatten() const\n    {\n        return json_pointer::unflatten(*this);\n    }\n\n    /// @}\n\n    //////////////////////////\n    // JSON Patch functions //\n    //////////////////////////\n\n    /// @name JSON Patch functions\n    /// @{\n\n    /// @brief applies a JSON patch in-place without copying the object\n    /// @sa https://json.nlohmann.me/api/basic_json/patch/\n    void patch_inplace(const basic_json& json_patch)\n    {\n        basic_json& result = *this;\n        // the valid JSON Patch operations\n        enum class patch_operations {add, remove, replace, move, copy, test, invalid};\n\n        const auto get_op = [](const std::string & op)\n        {\n            if (op == \"add\")\n            {\n                return patch_operations::add;\n            }\n            if (op == \"remove\")\n            {\n                return 
patch_operations::remove;\n            }\n            if (op == \"replace\")\n            {\n                return patch_operations::replace;\n            }\n            if (op == \"move\")\n            {\n                return patch_operations::move;\n            }\n            if (op == \"copy\")\n            {\n                return patch_operations::copy;\n            }\n            if (op == \"test\")\n            {\n                return patch_operations::test;\n            }\n\n            return patch_operations::invalid;\n        };\n\n        // wrapper for \"add\" operation; add value at ptr\n        const auto operation_add = [&result](json_pointer & ptr, basic_json val)\n        {\n            // adding to the root of the target document means replacing it\n            if (ptr.empty())\n            {\n                result = val;\n                return;\n            }\n\n            // make sure the top element of the pointer exists\n            json_pointer const top_pointer = ptr.top();\n            if (top_pointer != ptr)\n            {\n                result.at(top_pointer);\n            }\n\n            // get reference to parent of JSON pointer ptr\n            const auto last_path = ptr.back();\n            ptr.pop_back();\n            // parent must exist when performing patch add per RFC6902 specs\n            basic_json& parent = result.at(ptr);\n\n            switch (parent.m_data.m_type)\n            {\n                case value_t::null:\n                case value_t::object:\n                {\n                    // use operator[] to add value\n                    parent[last_path] = val;\n                    break;\n                }\n\n                case value_t::array:\n                {\n                    if (last_path == \"-\")\n                    {\n                        // special case: append to back\n                        parent.push_back(val);\n                    }\n                    else\n                    
{\n                        const auto idx = json_pointer::template array_index<basic_json_t>(last_path);\n                        if (JSON_HEDLEY_UNLIKELY(idx > parent.size()))\n                        {\n                            // avoid undefined behavior\n                            JSON_THROW(out_of_range::create(401, detail::concat(\"array index \", std::to_string(idx), \" is out of range\"), &parent));\n                        }\n\n                        // default case: insert add offset\n                        parent.insert(parent.begin() + static_cast<difference_type>(idx), val);\n                    }\n                    break;\n                }\n\n                // if there exists a parent it cannot be primitive\n                case value_t::string: // LCOV_EXCL_LINE\n                case value_t::boolean: // LCOV_EXCL_LINE\n                case value_t::number_integer: // LCOV_EXCL_LINE\n                case value_t::number_unsigned: // LCOV_EXCL_LINE\n                case value_t::number_float: // LCOV_EXCL_LINE\n                case value_t::binary: // LCOV_EXCL_LINE\n                case value_t::discarded: // LCOV_EXCL_LINE\n                default:            // LCOV_EXCL_LINE\n                    JSON_ASSERT(false); // NOLINT(cert-dcl03-c,hicpp-static-assert,misc-static-assert) LCOV_EXCL_LINE\n            }\n        };\n\n        // wrapper for \"remove\" operation; remove value at ptr\n        const auto operation_remove = [this, & result](json_pointer & ptr)\n        {\n            // get reference to parent of JSON pointer ptr\n            const auto last_path = ptr.back();\n            ptr.pop_back();\n            basic_json& parent = result.at(ptr);\n\n            // remove child\n            if (parent.is_object())\n            {\n                // perform range check\n                auto it = parent.find(last_path);\n                if (JSON_HEDLEY_LIKELY(it != parent.end()))\n                {\n                    
parent.erase(it);\n                }\n                else\n                {\n                    JSON_THROW(out_of_range::create(403, detail::concat(\"key '\", last_path, \"' not found\"), this));\n                }\n            }\n            else if (parent.is_array())\n            {\n                // note erase performs range check\n                parent.erase(json_pointer::template array_index<basic_json_t>(last_path));\n            }\n        };\n\n        // type check: top level value must be an array\n        if (JSON_HEDLEY_UNLIKELY(!json_patch.is_array()))\n        {\n            JSON_THROW(parse_error::create(104, 0, \"JSON patch must be an array of objects\", &json_patch));\n        }\n\n        // iterate and apply the operations\n        for (const auto& val : json_patch)\n        {\n            // wrapper to get a value for an operation\n            const auto get_value = [&val](const std::string & op,\n                                          const std::string & member,\n                                          bool string_type) -> basic_json &\n            {\n                // find value\n                auto it = val.m_data.m_value.object->find(member);\n\n                // context-sensitive error message\n                const auto error_msg = (op == \"op\") ? 
\"operation\" : detail::concat(\"operation '\", op, '\\''); // NOLINT(bugprone-unused-local-non-trivial-variable)\n\n                // check if desired value is present\n                if (JSON_HEDLEY_UNLIKELY(it == val.m_data.m_value.object->end()))\n                {\n                    // NOLINTNEXTLINE(performance-inefficient-string-concatenation)\n                    JSON_THROW(parse_error::create(105, 0, detail::concat(error_msg, \" must have member '\", member, \"'\"), &val));\n                }\n\n                // check if result is of type string\n                if (JSON_HEDLEY_UNLIKELY(string_type && !it->second.is_string()))\n                {\n                    // NOLINTNEXTLINE(performance-inefficient-string-concatenation)\n                    JSON_THROW(parse_error::create(105, 0, detail::concat(error_msg, \" must have string member '\", member, \"'\"), &val));\n                }\n\n                // no error: return value\n                return it->second;\n            };\n\n            // type check: every element of the array must be an object\n            if (JSON_HEDLEY_UNLIKELY(!val.is_object()))\n            {\n                JSON_THROW(parse_error::create(104, 0, \"JSON patch must be an array of objects\", &val));\n            }\n\n            // collect mandatory members\n            const auto op = get_value(\"op\", \"op\", true).template get<std::string>();\n            const auto path = get_value(op, \"path\", true).template get<std::string>();\n            json_pointer ptr(path);\n\n            switch (get_op(op))\n            {\n                case patch_operations::add:\n                {\n                    operation_add(ptr, get_value(\"add\", \"value\", false));\n                    break;\n                }\n\n                case patch_operations::remove:\n                {\n                    operation_remove(ptr);\n                    break;\n                }\n\n                case patch_operations::replace:\n     
           {\n                    // the \"path\" location must exist - use at()\n                    result.at(ptr) = get_value(\"replace\", \"value\", false);\n                    break;\n                }\n\n                case patch_operations::move:\n                {\n                    const auto from_path = get_value(\"move\", \"from\", true).template get<std::string>();\n                    json_pointer from_ptr(from_path);\n\n                    // the \"from\" location must exist - use at()\n                    basic_json const v = result.at(from_ptr);\n\n                    // The move operation is functionally identical to a\n                    // \"remove\" operation on the \"from\" location, followed\n                    // immediately by an \"add\" operation at the target\n                    // location with the value that was just removed.\n                    operation_remove(from_ptr);\n                    operation_add(ptr, v);\n                    break;\n                }\n\n                case patch_operations::copy:\n                {\n                    const auto from_path = get_value(\"copy\", \"from\", true).template get<std::string>();\n                    const json_pointer from_ptr(from_path);\n\n                    // the \"from\" location must exist - use at()\n                    basic_json const v = result.at(from_ptr);\n\n                    // The copy is functionally identical to an \"add\"\n                    // operation at the target location using the value\n                    // specified in the \"from\" member.\n                    operation_add(ptr, v);\n                    break;\n                }\n\n                case patch_operations::test:\n                {\n                    bool success = false;\n                    JSON_TRY\n                    {\n                        // check if \"value\" matches the one at \"path\"\n                        // the \"path\" location must exist - use at()\n         
               success = (result.at(ptr) == get_value(\"test\", \"value\", false));\n                    }\n                    JSON_INTERNAL_CATCH (out_of_range&)\n                    {\n                        // ignore out of range errors: success remains false\n                    }\n\n                    // throw an exception if test fails\n                    if (JSON_HEDLEY_UNLIKELY(!success))\n                    {\n                        JSON_THROW(other_error::create(501, detail::concat(\"unsuccessful: \", val.dump()), &val));\n                    }\n\n                    break;\n                }\n\n                case patch_operations::invalid:\n                default:\n                {\n                    // op must be \"add\", \"remove\", \"replace\", \"move\", \"copy\", or\n                    // \"test\"\n                    JSON_THROW(parse_error::create(105, 0, detail::concat(\"operation value '\", op, \"' is invalid\"), &val));\n                }\n            }\n        }\n    }\n\n    /// @brief applies a JSON patch to a copy of the current object\n    /// @sa https://json.nlohmann.me/api/basic_json/patch/\n    basic_json patch(const basic_json& json_patch) const\n    {\n        basic_json result = *this;\n        result.patch_inplace(json_patch);\n        return result;\n    }\n\n    /// @brief creates a diff as a JSON patch\n    /// @sa https://json.nlohmann.me/api/basic_json/diff/\n    JSON_HEDLEY_WARN_UNUSED_RESULT\n    static basic_json diff(const basic_json& source, const basic_json& target,\n                           const std::string& path = \"\")\n    {\n        // the patch\n        basic_json result(value_t::array);\n\n        // if the values are the same, return empty patch\n        if (source == target)\n        {\n            return result;\n        }\n\n        if (source.type() != target.type())\n        {\n            // different types: replace value\n            result.push_back(\n            {\n                {\"op\", 
\"replace\"}, {\"path\", path}, {\"value\", target}\n            });\n            return result;\n        }\n\n        switch (source.type())\n        {\n            case value_t::array:\n            {\n                // first pass: traverse common elements\n                std::size_t i = 0;\n                while (i < source.size() && i < target.size())\n                {\n                    // recursive call to compare array values at index i\n                    auto temp_diff = diff(source[i], target[i], detail::concat(path, '/', std::to_string(i)));\n                    result.insert(result.end(), temp_diff.begin(), temp_diff.end());\n                    ++i;\n                }\n\n                // We now reached the end of at least one array\n                // in a second pass, traverse the remaining elements\n\n                // remove my remaining elements\n                const auto end_index = static_cast<difference_type>(result.size());\n                while (i < source.size())\n                {\n                    // add operations in reverse order to avoid invalid\n                    // indices\n                    result.insert(result.begin() + end_index, object(\n                    {\n                        {\"op\", \"remove\"},\n                        {\"path\", detail::concat(path, '/', std::to_string(i))}\n                    }));\n                    ++i;\n                }\n\n                // add other remaining elements\n                while (i < target.size())\n                {\n                    result.push_back(\n                    {\n                        {\"op\", \"add\"},\n                        {\"path\", detail::concat(path, \"/-\")},\n                        {\"value\", target[i]}\n                    });\n                    ++i;\n                }\n\n                break;\n            }\n\n            case value_t::object:\n            {\n                // first pass: traverse this object's elements\n         
       for (auto it = source.cbegin(); it != source.cend(); ++it)\n                {\n                    // escape the key name to be used in a JSON patch\n                    const auto path_key = detail::concat(path, '/', detail::escape(it.key()));\n\n                    if (target.find(it.key()) != target.end())\n                    {\n                        // recursive call to compare object values at key it\n                        auto temp_diff = diff(it.value(), target[it.key()], path_key);\n                        result.insert(result.end(), temp_diff.begin(), temp_diff.end());\n                    }\n                    else\n                    {\n                        // found a key that is not in target -> remove it\n                        result.push_back(object(\n                        {\n                            {\"op\", \"remove\"}, {\"path\", path_key}\n                        }));\n                    }\n                }\n\n                // second pass: traverse target's elements\n                for (auto it = target.cbegin(); it != target.cend(); ++it)\n                {\n                    if (source.find(it.key()) == source.end())\n                    {\n                        // found a key that is not in source -> add it\n                        const auto path_key = detail::concat(path, '/', detail::escape(it.key()));\n                        result.push_back(\n                        {\n                            {\"op\", \"add\"}, {\"path\", path_key},\n                            {\"value\", it.value()}\n                        });\n                    }\n                }\n\n                break;\n            }\n\n            case value_t::null:\n            case value_t::string:\n            case value_t::boolean:\n            case value_t::number_integer:\n            case value_t::number_unsigned:\n            case value_t::number_float:\n            case value_t::binary:\n            case value_t::discarded:\n       
     default:\n            {\n                // both values are primitive: replace value\n                result.push_back(\n                {\n                    {\"op\", \"replace\"}, {\"path\", path}, {\"value\", target}\n                });\n                break;\n            }\n        }\n\n        return result;\n    }\n    /// @}\n\n    ////////////////////////////////\n    // JSON Merge Patch functions //\n    ////////////////////////////////\n\n    /// @name JSON Merge Patch functions\n    /// @{\n\n    /// @brief applies a JSON Merge Patch\n    /// @sa https://json.nlohmann.me/api/basic_json/merge_patch/\n    void merge_patch(const basic_json& apply_patch)\n    {\n        if (apply_patch.is_object())\n        {\n            if (!is_object())\n            {\n                *this = object();\n            }\n            for (auto it = apply_patch.begin(); it != apply_patch.end(); ++it)\n            {\n                if (it.value().is_null())\n                {\n                    erase(it.key());\n                }\n                else\n                {\n                    operator[](it.key()).merge_patch(it.value());\n                }\n            }\n        }\n        else\n        {\n            *this = apply_patch;\n        }\n    }\n\n    /// @}\n};\n\n/// @brief user-defined to_string function for JSON values\n/// @sa https://json.nlohmann.me/api/basic_json/to_string/\nNLOHMANN_BASIC_JSON_TPL_DECLARATION\nstd::string to_string(const NLOHMANN_BASIC_JSON_TPL& j)\n{\n    return j.dump();\n}\n\ninline namespace literals\n{\ninline namespace json_literals\n{\n\n/// @brief user-defined string literal for JSON values\n/// @sa https://json.nlohmann.me/api/basic_json/operator_literal_json/\nJSON_HEDLEY_NON_NULL(1)\n#if !defined(JSON_HEDLEY_GCC_VERSION) || JSON_HEDLEY_GCC_VERSION_CHECK(4,9,0)\n    inline nlohmann::json operator \"\"_json(const char* s, std::size_t n)\n#else\n    inline nlohmann::json operator \"\" _json(const char* s, std::size_t 
n)\n#endif\n{\n    return nlohmann::json::parse(s, s + n);\n}\n\n/// @brief user-defined string literal for JSON pointer\n/// @sa https://json.nlohmann.me/api/basic_json/operator_literal_json_pointer/\nJSON_HEDLEY_NON_NULL(1)\n#if !defined(JSON_HEDLEY_GCC_VERSION) || JSON_HEDLEY_GCC_VERSION_CHECK(4,9,0)\n    inline nlohmann::json::json_pointer operator \"\"_json_pointer(const char* s, std::size_t n)\n#else\n    inline nlohmann::json::json_pointer operator \"\" _json_pointer(const char* s, std::size_t n)\n#endif\n{\n    return nlohmann::json::json_pointer(std::string(s, n));\n}\n\n}  // namespace json_literals\n}  // namespace literals\nNLOHMANN_JSON_NAMESPACE_END\n\n///////////////////////\n// nonmember support //\n///////////////////////\n\nnamespace std // NOLINT(cert-dcl58-cpp)\n{\n\n/// @brief hash value for JSON objects\n/// @sa https://json.nlohmann.me/api/basic_json/std_hash/\nNLOHMANN_BASIC_JSON_TPL_DECLARATION\nstruct hash<nlohmann::NLOHMANN_BASIC_JSON_TPL> // NOLINT(cert-dcl58-cpp)\n{\n    std::size_t operator()(const nlohmann::NLOHMANN_BASIC_JSON_TPL& j) const\n    {\n        return nlohmann::detail::hash(j);\n    }\n};\n\n// specialization for std::less<value_t>\ntemplate<>\nstruct less< ::nlohmann::detail::value_t> // do not remove the space after '<', see https://github.com/nlohmann/json/pull/679\n{\n    /*!\n    @brief compare two value_t enum values\n    @since version 3.0.0\n    */\n    bool operator()(::nlohmann::detail::value_t lhs,\n                    ::nlohmann::detail::value_t rhs) const noexcept\n    {\n#if JSON_HAS_THREE_WAY_COMPARISON\n        return std::is_lt(lhs <=> rhs); // *NOPAD*\n#else\n        return ::nlohmann::detail::operator<(lhs, rhs);\n#endif\n    }\n};\n\n// C++20 prohibits function specializations in the std namespace.\n#ifndef JSON_HAS_CPP_20\n\n/// @brief exchanges the values of two JSON objects\n/// @sa https://json.nlohmann.me/api/basic_json/std_swap/\nNLOHMANN_BASIC_JSON_TPL_DECLARATION\ninline void 
swap(nlohmann::NLOHMANN_BASIC_JSON_TPL& j1, nlohmann::NLOHMANN_BASIC_JSON_TPL& j2) noexcept(  // NOLINT(readability-inconsistent-declaration-parameter-name, cert-dcl58-cpp)\n    is_nothrow_move_constructible<nlohmann::NLOHMANN_BASIC_JSON_TPL>::value&&                          // NOLINT(misc-redundant-expression,cppcoreguidelines-noexcept-swap,performance-noexcept-swap)\n    is_nothrow_move_assignable<nlohmann::NLOHMANN_BASIC_JSON_TPL>::value)\n{\n    j1.swap(j2);\n}\n\n#endif\n\n}  // namespace std\n\n#if JSON_USE_GLOBAL_UDLS\n    #if !defined(JSON_HEDLEY_GCC_VERSION) || JSON_HEDLEY_GCC_VERSION_CHECK(4,9,0)\n        using nlohmann::literals::json_literals::operator \"\"_json; // NOLINT(misc-unused-using-decls,google-global-names-in-headers)\n        using nlohmann::literals::json_literals::operator \"\"_json_pointer; //NOLINT(misc-unused-using-decls,google-global-names-in-headers)\n    #else\n        using nlohmann::literals::json_literals::operator \"\" _json; // NOLINT(misc-unused-using-decls,google-global-names-in-headers)\n        using nlohmann::literals::json_literals::operator \"\" _json_pointer; //NOLINT(misc-unused-using-decls,google-global-names-in-headers)\n    #endif\n#endif\n\n// #include <nlohmann/detail/macro_unscope.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n// restore clang diagnostic settings\n#if defined(__clang__)\n    #pragma clang diagnostic pop\n#endif\n\n// clean up\n#undef JSON_ASSERT\n#undef JSON_INTERNAL_CATCH\n#undef JSON_THROW\n#undef JSON_PRIVATE_UNLESS_TESTED\n#undef NLOHMANN_BASIC_JSON_TPL_DECLARATION\n#undef NLOHMANN_BASIC_JSON_TPL\n#undef JSON_EXPLICIT\n#undef NLOHMANN_CAN_CALL_STD_FUNC_IMPL\n#undef JSON_INLINE_VARIABLE\n#undef JSON_NO_UNIQUE_ADDRESS\n#undef 
JSON_DISABLE_ENUM_SERIALIZATION\n#undef JSON_USE_GLOBAL_UDLS\n\n#ifndef JSON_TEST_KEEP_MACROS\n    #undef JSON_CATCH\n    #undef JSON_TRY\n    #undef JSON_HAS_CPP_11\n    #undef JSON_HAS_CPP_14\n    #undef JSON_HAS_CPP_17\n    #undef JSON_HAS_CPP_20\n    #undef JSON_HAS_FILESYSTEM\n    #undef JSON_HAS_EXPERIMENTAL_FILESYSTEM\n    #undef JSON_HAS_THREE_WAY_COMPARISON\n    #undef JSON_HAS_RANGES\n    #undef JSON_HAS_STATIC_RTTI\n    #undef JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON\n#endif\n\n// #include <nlohmann/thirdparty/hedley/hedley_undef.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n#undef JSON_HEDLEY_ALWAYS_INLINE\n#undef JSON_HEDLEY_ARM_VERSION\n#undef JSON_HEDLEY_ARM_VERSION_CHECK\n#undef JSON_HEDLEY_ARRAY_PARAM\n#undef JSON_HEDLEY_ASSUME\n#undef JSON_HEDLEY_BEGIN_C_DECLS\n#undef JSON_HEDLEY_CLANG_HAS_ATTRIBUTE\n#undef JSON_HEDLEY_CLANG_HAS_BUILTIN\n#undef JSON_HEDLEY_CLANG_HAS_CPP_ATTRIBUTE\n#undef JSON_HEDLEY_CLANG_HAS_DECLSPEC_DECLSPEC_ATTRIBUTE\n#undef JSON_HEDLEY_CLANG_HAS_EXTENSION\n#undef JSON_HEDLEY_CLANG_HAS_FEATURE\n#undef JSON_HEDLEY_CLANG_HAS_WARNING\n#undef JSON_HEDLEY_COMPCERT_VERSION\n#undef JSON_HEDLEY_COMPCERT_VERSION_CHECK\n#undef JSON_HEDLEY_CONCAT\n#undef JSON_HEDLEY_CONCAT3\n#undef JSON_HEDLEY_CONCAT3_EX\n#undef JSON_HEDLEY_CONCAT_EX\n#undef JSON_HEDLEY_CONST\n#undef JSON_HEDLEY_CONSTEXPR\n#undef JSON_HEDLEY_CONST_CAST\n#undef JSON_HEDLEY_CPP_CAST\n#undef JSON_HEDLEY_CRAY_VERSION\n#undef JSON_HEDLEY_CRAY_VERSION_CHECK\n#undef JSON_HEDLEY_C_DECL\n#undef JSON_HEDLEY_DEPRECATED\n#undef JSON_HEDLEY_DEPRECATED_FOR\n#undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL\n#undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_\n#undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED\n#undef 
JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES\n#undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS\n#undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION\n#undef JSON_HEDLEY_DIAGNOSTIC_POP\n#undef JSON_HEDLEY_DIAGNOSTIC_PUSH\n#undef JSON_HEDLEY_DMC_VERSION\n#undef JSON_HEDLEY_DMC_VERSION_CHECK\n#undef JSON_HEDLEY_EMPTY_BASES\n#undef JSON_HEDLEY_EMSCRIPTEN_VERSION\n#undef JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK\n#undef JSON_HEDLEY_END_C_DECLS\n#undef JSON_HEDLEY_FLAGS\n#undef JSON_HEDLEY_FLAGS_CAST\n#undef JSON_HEDLEY_GCC_HAS_ATTRIBUTE\n#undef JSON_HEDLEY_GCC_HAS_BUILTIN\n#undef JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE\n#undef JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE\n#undef JSON_HEDLEY_GCC_HAS_EXTENSION\n#undef JSON_HEDLEY_GCC_HAS_FEATURE\n#undef JSON_HEDLEY_GCC_HAS_WARNING\n#undef JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK\n#undef JSON_HEDLEY_GCC_VERSION\n#undef JSON_HEDLEY_GCC_VERSION_CHECK\n#undef JSON_HEDLEY_GNUC_HAS_ATTRIBUTE\n#undef JSON_HEDLEY_GNUC_HAS_BUILTIN\n#undef JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE\n#undef JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE\n#undef JSON_HEDLEY_GNUC_HAS_EXTENSION\n#undef JSON_HEDLEY_GNUC_HAS_FEATURE\n#undef JSON_HEDLEY_GNUC_HAS_WARNING\n#undef JSON_HEDLEY_GNUC_VERSION\n#undef JSON_HEDLEY_GNUC_VERSION_CHECK\n#undef JSON_HEDLEY_HAS_ATTRIBUTE\n#undef JSON_HEDLEY_HAS_BUILTIN\n#undef JSON_HEDLEY_HAS_CPP_ATTRIBUTE\n#undef JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS\n#undef JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE\n#undef JSON_HEDLEY_HAS_EXTENSION\n#undef JSON_HEDLEY_HAS_FEATURE\n#undef JSON_HEDLEY_HAS_WARNING\n#undef JSON_HEDLEY_IAR_VERSION\n#undef JSON_HEDLEY_IAR_VERSION_CHECK\n#undef JSON_HEDLEY_IBM_VERSION\n#undef JSON_HEDLEY_IBM_VERSION_CHECK\n#undef JSON_HEDLEY_IMPORT\n#undef JSON_HEDLEY_INLINE\n#undef JSON_HEDLEY_INTEL_CL_VERSION\n#undef JSON_HEDLEY_INTEL_CL_VERSION_CHECK\n#undef JSON_HEDLEY_INTEL_VERSION\n#undef JSON_HEDLEY_INTEL_VERSION_CHECK\n#undef JSON_HEDLEY_IS_CONSTANT\n#undef JSON_HEDLEY_IS_CONSTEXPR_\n#undef JSON_HEDLEY_LIKELY\n#undef 
JSON_HEDLEY_MALLOC\n#undef JSON_HEDLEY_MCST_LCC_VERSION\n#undef JSON_HEDLEY_MCST_LCC_VERSION_CHECK\n#undef JSON_HEDLEY_MESSAGE\n#undef JSON_HEDLEY_MSVC_VERSION\n#undef JSON_HEDLEY_MSVC_VERSION_CHECK\n#undef JSON_HEDLEY_NEVER_INLINE\n#undef JSON_HEDLEY_NON_NULL\n#undef JSON_HEDLEY_NO_ESCAPE\n#undef JSON_HEDLEY_NO_RETURN\n#undef JSON_HEDLEY_NO_THROW\n#undef JSON_HEDLEY_NULL\n#undef JSON_HEDLEY_PELLES_VERSION\n#undef JSON_HEDLEY_PELLES_VERSION_CHECK\n#undef JSON_HEDLEY_PGI_VERSION\n#undef JSON_HEDLEY_PGI_VERSION_CHECK\n#undef JSON_HEDLEY_PREDICT\n#undef JSON_HEDLEY_PRINTF_FORMAT\n#undef JSON_HEDLEY_PRIVATE\n#undef JSON_HEDLEY_PUBLIC\n#undef JSON_HEDLEY_PURE\n#undef JSON_HEDLEY_REINTERPRET_CAST\n#undef JSON_HEDLEY_REQUIRE\n#undef JSON_HEDLEY_REQUIRE_CONSTEXPR\n#undef JSON_HEDLEY_REQUIRE_MSG\n#undef JSON_HEDLEY_RESTRICT\n#undef JSON_HEDLEY_RETURNS_NON_NULL\n#undef JSON_HEDLEY_SENTINEL\n#undef JSON_HEDLEY_STATIC_ASSERT\n#undef JSON_HEDLEY_STATIC_CAST\n#undef JSON_HEDLEY_STRINGIFY\n#undef JSON_HEDLEY_STRINGIFY_EX\n#undef JSON_HEDLEY_SUNPRO_VERSION\n#undef JSON_HEDLEY_SUNPRO_VERSION_CHECK\n#undef JSON_HEDLEY_TINYC_VERSION\n#undef JSON_HEDLEY_TINYC_VERSION_CHECK\n#undef JSON_HEDLEY_TI_ARMCL_VERSION\n#undef JSON_HEDLEY_TI_ARMCL_VERSION_CHECK\n#undef JSON_HEDLEY_TI_CL2000_VERSION\n#undef JSON_HEDLEY_TI_CL2000_VERSION_CHECK\n#undef JSON_HEDLEY_TI_CL430_VERSION\n#undef JSON_HEDLEY_TI_CL430_VERSION_CHECK\n#undef JSON_HEDLEY_TI_CL6X_VERSION\n#undef JSON_HEDLEY_TI_CL6X_VERSION_CHECK\n#undef JSON_HEDLEY_TI_CL7X_VERSION\n#undef JSON_HEDLEY_TI_CL7X_VERSION_CHECK\n#undef JSON_HEDLEY_TI_CLPRU_VERSION\n#undef JSON_HEDLEY_TI_CLPRU_VERSION_CHECK\n#undef JSON_HEDLEY_TI_VERSION\n#undef JSON_HEDLEY_TI_VERSION_CHECK\n#undef JSON_HEDLEY_UNAVAILABLE\n#undef JSON_HEDLEY_UNLIKELY\n#undef JSON_HEDLEY_UNPREDICTABLE\n#undef JSON_HEDLEY_UNREACHABLE\n#undef JSON_HEDLEY_UNREACHABLE_RETURN\n#undef JSON_HEDLEY_VERSION\n#undef JSON_HEDLEY_VERSION_DECODE_MAJOR\n#undef 
JSON_HEDLEY_VERSION_DECODE_MINOR\n#undef JSON_HEDLEY_VERSION_DECODE_REVISION\n#undef JSON_HEDLEY_VERSION_ENCODE\n#undef JSON_HEDLEY_WARNING\n#undef JSON_HEDLEY_WARN_UNUSED_RESULT\n#undef JSON_HEDLEY_WARN_UNUSED_RESULT_MSG\n#undef JSON_HEDLEY_FALL_THROUGH\n\n\n#endif  // INCLUDE_NLOHMANN_JSON_HPP_\n"
  },
  {
    "path": "archive/third_party/nlohmann/json_fwd.hpp",
    "content": "//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n#ifndef INCLUDE_NLOHMANN_JSON_FWD_HPP_\n#define INCLUDE_NLOHMANN_JSON_FWD_HPP_\n\n#include <cstdint> // int64_t, uint64_t\n#include <map> // map\n#include <memory> // allocator\n#include <string> // string\n#include <vector> // vector\n\n// #include <nlohmann/detail/abi_macros.hpp>\n//     __ _____ _____ _____\n//  __|  |   __|     |   | |  JSON for Modern C++\n// |  |  |__   |  |  | | | |  version 3.11.3\n// |_____|_____|_____|_|___|  https://github.com/nlohmann/json\n//\n// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann <https://nlohmann.me>\n// SPDX-License-Identifier: MIT\n\n\n// This file contains all macro definitions affecting or depending on the ABI\n\n#ifndef JSON_SKIP_LIBRARY_VERSION_CHECK\n    #if defined(NLOHMANN_JSON_VERSION_MAJOR) && defined(NLOHMANN_JSON_VERSION_MINOR) && defined(NLOHMANN_JSON_VERSION_PATCH)\n        #if NLOHMANN_JSON_VERSION_MAJOR != 3 || NLOHMANN_JSON_VERSION_MINOR != 11 || NLOHMANN_JSON_VERSION_PATCH != 3\n            #warning \"Already included a different version of the library!\"\n        #endif\n    #endif\n#endif\n\n#define NLOHMANN_JSON_VERSION_MAJOR 3   // NOLINT(modernize-macro-to-enum)\n#define NLOHMANN_JSON_VERSION_MINOR 11  // NOLINT(modernize-macro-to-enum)\n#define NLOHMANN_JSON_VERSION_PATCH 3   // NOLINT(modernize-macro-to-enum)\n\n#ifndef JSON_DIAGNOSTICS\n    #define JSON_DIAGNOSTICS 0\n#endif\n\n#ifndef JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON\n    #define JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON 0\n#endif\n\n#if JSON_DIAGNOSTICS\n    #define NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS _diag\n#else\n    #define NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS\n#endif\n\n#if JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON\n   
 #define NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON _ldvcmp\n#else\n    #define NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON\n#endif\n\n#ifndef NLOHMANN_JSON_NAMESPACE_NO_VERSION\n    #define NLOHMANN_JSON_NAMESPACE_NO_VERSION 0\n#endif\n\n// Construct the namespace ABI tags component\n#define NLOHMANN_JSON_ABI_TAGS_CONCAT_EX(a, b) json_abi ## a ## b\n#define NLOHMANN_JSON_ABI_TAGS_CONCAT(a, b) \\\n    NLOHMANN_JSON_ABI_TAGS_CONCAT_EX(a, b)\n\n#define NLOHMANN_JSON_ABI_TAGS                                       \\\n    NLOHMANN_JSON_ABI_TAGS_CONCAT(                                   \\\n            NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS,                       \\\n            NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON)\n\n// Construct the namespace version component\n#define NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT_EX(major, minor, patch) \\\n    _v ## major ## _ ## minor ## _ ## patch\n#define NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT(major, minor, patch) \\\n    NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT_EX(major, minor, patch)\n\n#if NLOHMANN_JSON_NAMESPACE_NO_VERSION\n#define NLOHMANN_JSON_NAMESPACE_VERSION\n#else\n#define NLOHMANN_JSON_NAMESPACE_VERSION                                 \\\n    NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT(NLOHMANN_JSON_VERSION_MAJOR, \\\n                                           NLOHMANN_JSON_VERSION_MINOR, \\\n                                           NLOHMANN_JSON_VERSION_PATCH)\n#endif\n\n// Combine namespace components\n#define NLOHMANN_JSON_NAMESPACE_CONCAT_EX(a, b) a ## b\n#define NLOHMANN_JSON_NAMESPACE_CONCAT(a, b) \\\n    NLOHMANN_JSON_NAMESPACE_CONCAT_EX(a, b)\n\n#ifndef NLOHMANN_JSON_NAMESPACE\n#define NLOHMANN_JSON_NAMESPACE               \\\n    nlohmann::NLOHMANN_JSON_NAMESPACE_CONCAT( \\\n            NLOHMANN_JSON_ABI_TAGS,           \\\n            NLOHMANN_JSON_NAMESPACE_VERSION)\n#endif\n\n#ifndef NLOHMANN_JSON_NAMESPACE_BEGIN\n#define NLOHMANN_JSON_NAMESPACE_BEGIN                \\\n    
namespace nlohmann                               \\\n    {                                                \\\n    inline namespace NLOHMANN_JSON_NAMESPACE_CONCAT( \\\n                NLOHMANN_JSON_ABI_TAGS,              \\\n                NLOHMANN_JSON_NAMESPACE_VERSION)     \\\n    {\n#endif\n\n#ifndef NLOHMANN_JSON_NAMESPACE_END\n#define NLOHMANN_JSON_NAMESPACE_END                                     \\\n    }  /* namespace (inline namespace) NOLINT(readability/namespace) */ \\\n    }  // namespace nlohmann\n#endif\n\n\n/*!\n@brief namespace for Niels Lohmann\n@see https://github.com/nlohmann\n@since version 1.0.0\n*/\nNLOHMANN_JSON_NAMESPACE_BEGIN\n\n/*!\n@brief default JSONSerializer template argument\n\nThis serializer ignores the template arguments and uses ADL\n([argument-dependent lookup](https://en.cppreference.com/w/cpp/language/adl))\nfor serialization.\n*/\ntemplate<typename T = void, typename SFINAE = void>\nstruct adl_serializer;\n\n/// a class to store JSON values\n/// @sa https://json.nlohmann.me/api/basic_json/\ntemplate<template<typename U, typename V, typename... Args> class ObjectType =\n         std::map,\n         template<typename U, typename... 
Args> class ArrayType = std::vector,\n         class StringType = std::string, class BooleanType = bool,\n         class NumberIntegerType = std::int64_t,\n         class NumberUnsignedType = std::uint64_t,\n         class NumberFloatType = double,\n         template<typename U> class AllocatorType = std::allocator,\n         template<typename T, typename SFINAE = void> class JSONSerializer =\n         adl_serializer,\n         class BinaryType = std::vector<std::uint8_t>, // cppcheck-suppress syntaxError\n         class CustomBaseClass = void>\nclass basic_json;\n\n/// @brief JSON Pointer defines a string syntax for identifying a specific value within a JSON document\n/// @sa https://json.nlohmann.me/api/json_pointer/\ntemplate<typename RefStringType>\nclass json_pointer;\n\n/*!\n@brief default specialization\n@sa https://json.nlohmann.me/api/json/\n*/\nusing json = basic_json<>;\n\n/// @brief a minimal map-like container that preserves insertion order\n/// @sa https://json.nlohmann.me/api/ordered_map/\ntemplate<class Key, class T, class IgnoredLess, class Allocator>\nstruct ordered_map;\n\n/// @brief specialization that maintains the insertion order of object keys\n/// @sa https://json.nlohmann.me/api/ordered_json/\nusing ordered_json = basic_json<nlohmann::ordered_map>;\n\nNLOHMANN_JSON_NAMESPACE_END\n\n#endif  // INCLUDE_NLOHMANN_JSON_FWD_HPP_\n"
  },
  {
    "path": "book.toml",
    "content": "[book]\nauthors = [\"kvcache-ai\"]\nlanguage = \"zh-CN\"\ntitle = \"Ktransformers\"\nsrc = \"doc\"\n\n[output.html]\ngit-repository-url = \"https://github.com/kvcache-ai/ktransformers\"\nedit-url-template = \"https://github.com/kvcache-ai/ktransformers/edit/main/{path}\"\n\n[output.html.playground]\neditable = true\ncopy-js = true\n# line-numbers = true\n\n[output.html.fold]\nenable = true\nlevel = 0"
  },
  {
    "path": "doc/SUMMARY.md",
    "content": "# Ktransformers\n\n[Introduction](./README.md)\n# Install & Usage\n- [For kt-kernel](en/kt-kernel/kt-kernel_intro.md)\n- [For kt-sft](en/SFT/KTransformers-Fine-Tuning_User-Guide.md)\n\n# Tutorial \n- [kt-sft part](en/SFT/README.md)\n  - [Injection Tutorial](en/SFT/injection_tutorial.md)\n  - [kt-sft developer tech notes](en/SFT/KTransformers-Fine-Tuning_Developer-Technical-Notes.md)\n  - [DPO tutorial](en/SFT/DPO_tutorial.md)\n  <!-- - [Multi-GPU Tutorial](en/multi-gpu-tutorial.md) -->\n  <!-- - [Use FP8 GPU Kernel](en/fp8_kernel.md) -->\n  <!-- - [Use AMD GPU](en/ROCm.md) -->\n<!-- - [Deepseek-R1/V3 Show Case/Tutorial](en/DeepseekR1_V3_tutorial.md) -->\n<!-- - [Why KTransformers So Fast](en/deepseek-v2-injection.md) -->\n<!-- # For Developer\n- [Makefile Usage](en/makefile_usage.md) -->\n- [kt-kernel part](en/kt-kernel/README.md)\n  - [kt-cli](en/kt-kernel/kt-cli.md)\n# FAQ\n- [FAQ](en/FAQ.md)\n<!-- # V3 Reproduction\n- [Success List](en/V3-success.md)\n# Benchmark\n- [Benchmark](en/benchmark.md) -->\n"
  },
  {
    "path": "doc/basic/note1.md",
    "content": "# basic-first20\n"
  },
  {
    "path": "doc/basic/note2.md",
    "content": "# basic-data_structure\n"
  },
  {
    "path": "doc/en/AMX.md",
    "content": "# Qwen 3 + KTransformers 0.3 (+AMX) = AI Workstation/PC\nFollowing DeepSeek-V3/R1, LLaMa-4, and Kimi-VL, Qwen has also released an impressive MoE model—undoubtedly, this year belongs to MoE. As a low-barrier inference system for running MoE models in local heterogeneous environments, KTransformers naturally joins the party. Thanks to the support of the Qwen team, we completed Day 0 support for the entire Qwen 3 series of MoE models. At the same time, we took this opportunity to open-source the long-awaited preliminary version of our AMX high-performance operators (BF16, Int8; an int4 variant is coming soon), officially advancing to version 0.3.\n\nWhat excites me most about Qwen3MoE is that, unlike the 671 B “giant” model, its two configurations: 235B-A22 and 30B-A3B, **hit the performance sweet spots for both local workstations and consumer-grade PCs**. Accordingly, we ran benchmarks in two typical setups:\n\nServer CPU (Xeon 4) + RTX 4090\n\nConsumer-grade CPU (Core i9-14900KF + dual-channel DDR5-4000 MT/s) + RTX 4090\n\nNote: Because the PC's memory has a low frequency, large capacity, and multiple sticks, it downclocks severely and only operates at 4000MT. Using higher - frequency memory can boost performance.\n\nThe results are as follows:\n\nhttps://github.com/user-attachments/assets/fafe8aec-4e22-49a8-8553-59fb5c6b00a2\n\n\n![Image](https://github.com/user-attachments/assets/62567aad-353b-4c6f-ab87-2ea283ff2ba2)\n\nYou can see that, thanks to the AMX instruction optimizations, we achieve up to 347 tokens/s prefill performance in the workstation scenario. On consumer-grade CPUs, we’re able to run the large model (235B-A22) and deliver smooth performance on the smaller 30B-A3B. Even in terms of resource overhead, it appears that a high-end gaming laptop can handle 30B-A3B smoothly. 
After talking about the concept of AIPC for so long, we can finally see its feasibility.\n\nHere is the Qwen3MoE startup command:\n\n``` python\n# llamafile backend\npython ktransformers/server/main.py --architectures Qwen3MoeForCausalLM --model_path <model_dir> --gguf_path <gguf_dir> --optimize_config_path ktransformers/optimize/optimize_rules/Qwen3Moe-serve.yaml --backend_type balance_serve\n# AMX backend\npython ktransformers/server/main.py --architectures Qwen3MoeForCausalLM --model_path <model_dir> --gguf_path <gguf_dir> --optimize_config_path ktransformers/optimize/optimize_rules/Qwen3Moe-serve-amx.yaml --backend_type balance_serve\n```\n\n**Note: At present, Qwen3MoE running with AMX can only read BF16 GGUF; support for loading from safetensor will be added later.**\n\nTo make it easier for everyone to understand the AMX optimizations we’ve open-sourced, we’ve prepared a brief document. We also extend our gratitude to Intel for their assistance.\n\n# Introduction to AMX Instruction Set\n\nIntel Advanced Matrix Extensions (AMX) are a set of specialized instruction extensions introduced for the x86 architecture starting with Sapphire Rapids (4th generation Xeon Scalable processors) and onward. AMX accelerates large-scale matrix computations at the hardware level, particularly for the compute-intensive parts of deep learning inference and machine learning workloads. By introducing the concept of Tile registers, it loads 2D sub-matrices into dedicated Tile registers and performs matrix multiply-accumulate operations at the register level, significantly improving throughput and energy efficiency.\n\nEach CPU core contains 8 dedicated registers (tmm0–tmm7), with each register capable of holding up to 16 rows × 64 bytes of data to store 2D sub-matrices. 
Additionally, there is a 64-byte configuration register (TILECFG) used to describe each tmm register's number of rows, columns, and row stride.\n\nThe main AMX instructions are summarized as follows:\n\n| Instruction Category | Instruction Names | Description |\n|:---|:---|:---|\n| Configuration Instructions | LDTILECFG, STTILECFG, TILERELEASE, TILEZERO | Configure/reset Tile registers and metadata |\n| Load/Store Instructions | TILELOADD, TILELOADDT1, TILESTORED | Transfer data between memory and Tile registers |\n| INT8 Computation Instructions | TDPBSSD, TDPBUSD, TDPBUUD, TDPBSUD | Perform multiply and accumulate operations on int8 sub-matrices within Tiles |\n| BF16 Computation Instructions | TDPBF16PS | Perform multiply and accumulate operations on bfloat16 sub-matrices within Tiles |\n\nTo simplify development, Intel provides corresponding intrinsics, allowing C/C++ developers to leverage AMX's performance benefits without writing lengthy assembly code. For example:\n\n```C++\n#include <immintrin.h>\n\n_tile_loadconfig(cfg_ptr);\n_tile_loadd(tmm0, A_ptr, lda);\n_tile_loadd(tmm1, B_ptr, ldb);\n_tile_zero(tmm2);\n_tile_dpbf16ps(tmm2, tmm0, tmm1);\n_tile_stored(tmm2, C_ptr, ldc);\n_tile_release();\n```\n\nThe above code copies sub-matrices from memory (A_ptr, B_ptr) to Tile registers, zeroes the accumulator Tile, calls the AMX BF16 compute instruction to multiply the two sub-matrices, and then copies the result to memory (C_ptr).\n\nTaking INT8 as an example, AMX can perform the multiplication of two 16×64 sub-matrices (32,768 multiply/add operations) with a single instruction in 16 CPU cycles, enabling each core to complete 2048 multiply/add operations per cycle, 8 times the performance of AVX-512. 
On an Intel Xeon 4 CPU, a single core can theoretically provide 4 TOPS of compute power, making it highly suitable for compute-intensive tasks on the CPU.\n\n<p align=\"center\">\n  <picture>\n    <img alt=\"amx_intro\" src=\"../assets/amx_intro.png\" width=60%>\n  </picture>\n</p>\n\n\n# AMX Kernel in KTransformers\n\nBefore version v0.3, KTransformers performed CPU matrix multiplications based on operators provided by llamafile. Unfortunately, llamafile's implementation had not yet been optimized for the AMX instruction set. This resulted in performance bottlenecks, even in strong hardware environments (such as Xeon 4th Gen + 4090), where inference speeds for large models like DeepSeek-V3 reached only 91 tokens/s during the prefill phase. The CPU thus remained a significant bottleneck. In long prompt scenarios, such performance is clearly unsatisfactory. To fully unleash CPU potential, we introduced a brand-new AMX optimization path along with multiple technical improvements in v0.3.\n\n## 1. AMX Tiling-aware Memory Layout\n\nAMX provides a high-throughput Tile register computation model, reducing instruction count and boosting theoretical throughput through coarse-grained matrix operations. However, to truly exploit AMX's potential, memory access efficiency is critical: because AMX transfers entire Tiles at once, misaligned Tiles and chaotic access patterns can cause severe cache misses, nullifying throughput gains.\n\nThus, in v0.3, we stopped directly memory-mapping GGUF-format files and introduced AMX Tiling-aware memory preprocessing during model loading. Specifically, expert weight matrices in MoE models are pre-rearranged into Tile-friendly sub-matrices whose shapes precisely match AMX Tile register dimensions, eliminating dynamic transposition overhead during inference. 
During rearrangement, we strictly align each sub-matrix's start address to 64 bytes to avoid cache line splits, and arrange sub-matrices sequentially according to computation access patterns, maximizing L1/L2 cache hit rates using compiler and hardware sequential prefetch capabilities.\n\nFor Int8 quantized formats, we adopted Symmetric Group-wise Quantization, with each column forming a group sharing a scale factor stored separately to maintain memory alignment for Tile data.\n\nThis AMX Tiling-aware memory layout design reduces memory latency while providing optimal input conditions for downstream computation kernels.\n\n## 2. Cache-friendly AMX Kernel\n\nDuring inference, we designed around the CPU’s multi-level cache hierarchy to perform computations in-place in high-speed caches, minimizing DRAM access frequency and overhead.\n\n<p align=\"center\">\n  <picture>\n    <img alt=\"amx\" src=\"../assets/amx.png\" width=60%>\n  </picture>\n</p>\n\nAs shown in the figure, \n- ① Expert weight matrices are first column-wise partitioned into multiple tasks dynamically scheduled across threads. 
Input activations are shared among tasks and typically reside in the shared L3 cache due to locality.\n- ② Within each task, expert weights are row-wise partitioned into blocks, with block sizes finely tuned to ensure input activations, weights, and intermediate results stay within L2 cache, avoiding DRAM access.\n- ③ ④ ⑤ Each block is treated as a set of sub-matrices matching AMX Tile registers, and during Tile-level computation, input Tiles (tmm0–tmm1) and expert Tiles (tmm2–tmm3) are loaded, and four AMX multiplication instructions directly generate and accumulate products into Tile registers (tmm4–tmm7), with output activations accumulated in Tile registers or L1 cache, avoiding additional data movement.\n\nIn short, we leveraged the cache hierarchy: every data element of expert weights and output activations accesses DRAM only once, with the other accesses hitting L2 or higher caches; input activations are accessed from DRAM only once and later hit in L3 or higher caches. This significantly reduces main memory traffic and improves overall execution efficiency.\n\n## 3. AVX-512 Kernel Adaptation for Low Arithmetic Intensity Scenarios\n\nAlthough AMX is highly efficient for large-scale matrix multiplication, it performs poorly under low arithmetic intensity, such as vector-matrix operations in the decode phase. This is because dispatching AMX Tiles involves fixed instruction overhead, which becomes wasteful when the data volume is insufficient to fill a Tile, causing reduced throughput.\n\n<p align=\"center\">\n  <picture>\n    <img alt=\"amx_avx\" src=\"../assets/amx_avx.png\" width=60%>\n  </picture>\n</p>\n\nTo address this, we introduced a lightweight AVX-512 kernel as a complement. 
This kernel follows the same memory layout as the AMX kernel but replaces heavy AMX matrix-matrix multiplications with fine-grained AVX-512 vector-matrix multiplications, lowering latency for small matrices.\n\nKTransformers dynamically selects between AMX and AVX-512 kernels at runtime based on arithmetic intensity: AMX kernels are automatically selected during long prompt prefill phases (where each expert handles more than 4 tokens on average), while short prompt prefill and decode phases dynamically switch to AVX-512 kernels. This ensures optimal efficiency under different arithmetic intensity conditions.\n\n## 4. MoE Operator Fusion and Dynamic Scheduling\n\nMoE models have many experts per layer, each requiring three matrix multiplications (Gate, Up, Down projections), leading to many small matrix multiplication tasks. Independently scheduling each small task would cause massive synchronization overhead between threads, dragging down overall inference speed.\n\nThus, we fused the same type of matrix computations for all experts in a layer into large unified tasks. Furthermore, as there are no data dependencies between Gate and Up projections, their computations can also be fused, ultimately consolidating a layer’s matrix multiplications into two major tasks, greatly reducing scheduling overhead.\n\nTo address load imbalance — especially during the prefill phase where expert activations can be highly skewed — we introduced a dynamic task scheduling strategy. Each matrix multiplication task is further split into multiple fine-grained sub-tasks, evenly distributed among CPU threads initially. Once a thread completes its assigned tasks, it atomically \"steals\" tasks from others, greatly mitigating load imbalance and achieving near-optimal CPU resource utilization.\n\nThanks to these optimizations, our kernel can achieve 21 TFLOPS of BF16 throughput and 35 TOPS of Int8 throughput on Xeon4 CPUs — about 4× faster than PyTorch’s general AMX kernel. 
For DeepSeek-V3, pairing a Xeon4 CPU with a single RTX 4090 GPU achieves 418 tokens/s end-to-end throughput, close to the performance of multi-machine, multi-GPU setups. KTransformers’ AMX kernel is the first AMX kernel specifically designed for MoE inference scenarios, significantly lowering the hardware barrier for large model deployment and enabling more developers to enjoy GPU cluster level inference experiences at lower cost.\n\n<p align=\"center\">\n  <picture>\n    <img alt=\"onednn_1\" src=\"../assets/onednn_1.png\" width=60%>\n  </picture>\n</p>\n\n# Usage\n\n## Checking AMX Support\n\nBefore enabling the AMX-optimized kernels, it is important to verify whether your CPU supports the AMX instruction set. You can check AMX availability with the following command:\n\n```bash\nlscpu | grep -i amx\n```\n\nIf your system supports AMX, you should see output similar to:\n\n```bash\nFlags: ... amx-bf16 amx-int8 amx-tile ...\n```\n\nIf no amx-related flags are found, your CPU may not support AMX, or AMX may be disabled in BIOS settings. In that case, please ensure that:\n- You are using a Sapphire Rapids (Xeon 4th Gen) or newer CPU.\n- AMX support is enabled in your system BIOS under CPU feature settings.\n\n## Enabling AMX in KTransformers\n\nKTransformers allows users to easily switch between different backends through simple YAML configuration modifications. 
To enable AMX, modify the injection configuration of your experts by specifying backend as AMXInt8 or AMXBF16:\n\n```YAML\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts    # custom MoE Kernel with expert parallelism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n      backend: \"AMXInt8\"  # or \"AMXBF16\" or \"llamafile\" (default)\n```\n\n**Note:** Currently, using AMXInt8 requires reading weights from a BF16 GGUF file and performing online quantization during model loading. This may cause slightly slower load times. Future versions will provide pre-quantized weights to eliminate this overhead.\n\n![Image](https://github.com/user-attachments/assets/7c33c410-3af9-456f-aa67-5b24e19ba680)\n"
  },
  {
    "path": "doc/en/DeepseekR1_V3_tutorial.md",
    "content": "<!-- omit in toc -->\n\n# GPT-4/o1-level Local VSCode Copilot on a Desktop with only 24GB VRAM\n\n- [SUMMARY](#summary)\n  - [Show Case Environment](#show-case-environment)\n  - [Bench Result](#bench-result)\n    - [V0.2.1](#v021)\n      - [Memory consumption:](#memory-consumption)\n      - [Change Log](#change-log)\n      - [Benchmark Results](#benchmark-results)\n    - [V0.2](#v02)\n      - [Settings](#settings)\n      - [Memory consumption:](#memory-consumption-1)\n      - [Benchmark Results](#benchmark-results-1)\n    - [V0.3-Preview](#v03-preview)\n      - [Settings](#settings-1)\n      - [Memory consumptions:](#memory-consumptions)\n      - [Benchmark results](#benchmark-results-2)\n  - [How to Run](#how-to-run)\n    - [v0.2.2 \\& v0.2.3 longer context \\& FP8 kernel](#v022--v023-longer-context--fp8-kernel)\n      - [longer context](#longer-context)\n      - [FP8 kernel](#fp8-kernel)\n    - [V0.2 \\& V0.2.1 Showcase](#v02--v021-showcase)\n      - [Single socket version (32 cores)](#single-socket-version-32-cores)\n      - [Dual socket version (64 cores)](#dual-socket-version-64-cores)\n    - [V0.3 Showcase](#v03-showcase)\n      - [Dual socket version (64 cores)](#dual-socket-version-64-cores-1)\n  - [Some Explanations](#some-explanations)\n  - [Next](#next)\n    - [Faster](#faster)\n    - [Easier](#easier)\n  - [FAQ](#faq)\n    - [R1 No Thinking](#r1-no-thinking)\n    - [More FAQ](#more-faq)\n\n# SUMMARY\n\n> **Feb 10, 2025**: Support DeepseekR1 and V3 on single (24GB VRAM)/multi gpu and 382G DRAM, up to 3~28x speedup.<br>\n\nHi, we're the KTransformers team (formerly known for our local CPU/GPU hybrid inference open source project with DeepSeek-V2).\n\nWe've heard your requests for DeepSeek-R1/V3 support—and we're excited to finally deliver!\nApologies for the wait, but we've been cooking up something truly amazing!\n\nToday, we're proud to announce that we not only support DeepSeek-R1/V3, as showcased in the video 
below:\n\nhttps://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285\n\n</p>\n\n- **[NEW!!!] Local 671B DeepSeek-Coder-V3/R1:** Running its Q4_K_M version using only 14GB VRAM and 382GB DRAM.\n  - Prefill Speed (tokens/s):\n    - KTransformers: 54.21 (32 cores) → 74.362 (dual-socket, 2×32 cores) → 255.26 (optimized AMX-based MoE kernel, V0.3 only) → 286.55 (selectively using 6 experts, V0.3 only)\n    - Compared to 10.31 tokens/s in llama.cpp with 2×32 cores, achieving up to **27.79× speedup**.\n  - Decode Speed (tokens/s):\n    - KTransformers: 8.73 (32 cores) → 11.26 (dual-socket, 2×32 cores) → 13.69 (selectively using 6 experts, V0.3 only)\n    - Compared to 4.51 tokens/s in llama.cpp with 2×32 cores, achieving up to **3.03× speedup**.\n\nWe also preview our upcoming optimizations, including an Intel AMX-accelerated kernel and a selective expert activation method, which will significantly enhance performance. With V0.3-preview, we achieve up to 286 tokens/s for prefill, making it up to **28× faster than llama.cpp** for local inference.\nThe binary distribution is available now and the source code will come ASAP! Check out the wheel package [here](https://github.com/kvcache-ai/ktransformers/releases/download/v0.1.4/ktransformers-0.3.0rc0+cu126torch26fancy-cp311-cp311-linux_x86_64.whl)\n\n> **Feb 15, 2025**: KTransformers V0.2.1: Longer Context (from 4K to 8K for 24GB VRAM) & Slightly Faster Speed (+15%) (Up to 16 Tokens/s), update docs [here](./doc/en/DeepseekR1_V3_tutorial.md) and [online books](https://kvcache-ai.github.io/ktransformers/).\n\nWe sped up decode and prefill a little bit. The improvement is limited mainly because the inference process is still constrained by the CPU's computational speed and memory bandwidth. 
The MLA part handled by the GPU accounts for a relatively small proportion.\n\nBesides the improvements in speed, we've also significantly updated the documentation to enhance usability, including:<br>\n\n- Added a Multi-GPU configuration tutorial.\n- Consolidated the installation guide.\n- Added a detailed tutorial on registering extra GPU memory with ExpertMarlin.\n\n## Show Case Environment\n\nWe run our best performance tests (V0.2) on <br>\nCPU: Intel (R) Xeon (R) Gold 6454S 1T DRAM (2 NUMA nodes) <br>\nGPU: 4090D 24G VRAM <br>\nMemory: standard DDR5-4800 server DRAM (1 TB), each socket with 8×DDR5-4800\n\n## Bench Result\n\n### V0.2.1\n\n- Model: DeepseekV3-q4km (int4)<br>\n- CPU: cpu_model_name: Intel (R) Xeon (R) Gold 6454S, 32 cores per socket, 2 sockets, 2 numa nodes\n- GPU: 4090 24G VRAM\n- We test after enough warm up\n\n#### Memory consumption:\n\n- Single socket: 382G DRAM, at least 14GB VRAM\n- Dual socket: 1T DRAM, at least 14GB VRAM\n\n#### Change Log\n\n- Longer Context (from 4K to 8K for 24GB VRAM) and Slightly Faster Speed (+15%):<br>\n  Integrated the highly efficient Triton MLA Kernel from the fantastic sglang project, enabling much longer context length and slightly faster prefill/decode speed\n- We suspect that some of the improvements come from the change of hardware platform (4090D->4090)\n\n#### Benchmark Results\n\n\"6 experts\" case is part of V0.3's preview\n\n\n| Prompt               | hi (2)   | 1K (969)  | 2K (1930) | 4K (3846)               | 8K (7678) |\n| -------------------- | -------- | --------- | --------- | ----------------------- | --------- |\n| Output length        | 10tokens | 300tokens | 300tokens | 300tokens               | 300tokens |\n| **6 experts V0.2.0** |          |           |           |                         |           |\n| Prefill token/s      | 13       | 105       | 102       | 88                      | CUDA OOM  |\n| decode token/s       | 16.8     | 15.4      | 14.2      | 13.0                    | CUDA OOM  
|\n| **6 experts V0.2.1** |          |           |           |                         |           |\n| Prefill token/s      | 13       | 111       | 112.5     | 102**(1.16x speedup)**  | 101       |\n| decode token/s       | 16.8     | 15.9      | 15.4      | 14.9**(1.15x speedup)** | 13.9      |\n| **8 experts V0.2.1** |          |           |           |                         |           |\n| Prefill token/s      | 12.2     | 88.2      | 88.5      | 81.9                    | 80        |\n| Decode token/s       | 13.4     | 13.5      | 13.4      | 13.2                    | 12.4      |\n\n### V0.2\n\n#### Settings\n\n- Model: DeepseekV3-q4km (int4)<br>\n- CPU: cpu_model_name: Intel (R) Xeon (R) Gold 6454S, 32 cores per socket, 2 sockets, 2 numa nodes\n- GPU: 4090D 24G VRAM\n- We test after enough warm up\n\n#### Memory consumption:\n\n- Single socket: 382G DRAM, at least 14GB VRAM\n- Dual socket: 1T DRAM, at least 14GB VRAM\n\n#### Benchmark Results\n\n\"6 experts\" case is part of V0.3's preview\n\n\n| Prompt<br>(500 tokens) | Dual socket Ktrans (6 experts) | Dual socket Ktrans (8 experts) | Single socket Ktrans (6 experts) | Single socket Ktrans (8 experts) | llama.cpp (8 experts) |\n| ---------------------- | ------------------------------ | ------------------------------ | -------------------------------- | -------------------------------- | --------------------- |\n| Prefill token/s        | 97.32                          | 82.94                          | 65.14                            | 54.21                            | 10.31                 |\n| Decode token/s         | 13.69                          | 12.208                         | 10.303                           | 8.73                             | 4.51                  |\n\n**The highest speedup reaches up to <u>3.03x</u> in decoding and <u>9.44x</u> in prefill.**\n\n### V0.3-Preview\n\n#### Settings\n\n- Model: DeepseekV3-BF16 (online quant into int8 for CPU and int4 for GPU)\n- CPU: 
cpu_model_name: Intel (R) Xeon (R) Gold 6454S, 32 cores per socket, 2 sockets, 2 numa nodes\n- GPU: (1~4)x 4090D 24G VRAM (requires more VRAM for longer prompt)\n\n#### Memory consumptions:\n\n- 644GB DRAM, at least 14GB VRAM\n\n#### Benchmark results\n\n\n| Prompt length                      | 1K     | 2K     | 4K     | 8K     |\n| ---------------------------------- | ------ | ------ | ------ | ------ |\n| KTrans (8 experts) Prefill token/s | 185.96 | 255.26 | 252.58 | 195.62 |\n| KTrans (6 experts) Prefill token/s | 203.70 | 286.55 | 271.08 | 207.20 |\n\n**The prefill of KTrans V0.3 is up to <u>3.45x</u> faster than KTrans V0.2, and up to <u>27.79x</u> faster than llama.cpp.**\n**The decoding speed is the same as KTrans V0.2 (6 experts version) so it is omitted**\n\nThe main acceleration comes from\n\n- Intel AMX instruction set and our specially designed cache-friendly memory layout\n- Expert selection strategy that selects fewer experts based on offline profile results of out-of-domain data\n\n*From our research on DeepSeekV2, DeepSeekV3 and DeepSeekR1,\nwhen we slightly decrease the number of activated experts in inference,\nthe output quality doesn't change, while decoding and prefill\nbecome faster, which is inspiring. 
So our showcase makes use of this finding.*\n\n## How to Run\n\n### v0.2.4 \nWe provide a server script that supports multi-concurrency as of version v0.2.4.\n\n```\npython ktransformers/server/main.py --model_path /mnt/data/models/DeepSeek-V3 --gguf_path /mnt/data/models/DeepSeek-V3-GGUF/DeepSeek-V3-Q4_K_M/ --cpu_infer 62 --optimize_config_path ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-serve.yaml --port 10002 --chunk_size 256 --max_new_tokens 1024 --max_batch_size 4 --cache_lens 32768 --backend_type balance_serve\n```\nIt features the following arguments:\n\n- `--chunk_size`: Maximum number of tokens processed in a single run by the engine.\n- `--cache_lens`: Total length of the KV cache allocated by the scheduler. With the value above, all requests share a KV cache space of 32768 tokens, and the occupied space is released after the requests complete.\n- `--backend_type`: `balance_serve` is a multi-concurrency backend engine introduced in version v0.2.4. The original single-concurrency engine is `ktransformers`.\n- `--max_batch_size`: Maximum number of requests (prefill + decode) processed in a single run by the engine. (Supported only by `balance_serve`)\n\n### v0.2.2 & v0.2.3 longer context & FP8 kernel\n\n#### longer context\n\nTo use this feature, [install flashinfer](https://github.com/flashinfer-ai/flashinfer) first.\n\nNote: The latest MLA kernel in FlashInfer still has a few minor issues, which are being fixed continuously on the main branch. If you are using FlashInfer, please install it from source on the main branch.\n\nIf you want to use long context (longer than 20K) for prefill, enable the matrix absorption MLA during the prefill phase, which will significantly reduce the size of the KV cache. 
Modify the YAML file like this:\n\n```\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      absorb_for_prefill: True # change this to True to enable long context (prefill may be slower).\n```\n\nIf the VRAM is still insufficient, try reducing the `chunk_size` parameter (default is 8192) to further decrease the intermediate results during chunk prefill.\n\n#### FP8 kernel\n\nThe DeepSeek-AI team provides FP8 safetensors for DeepSeek-R1/V3 models. We optimize performance through the following work:\n\n- **FP8 GPU Kernel Integration**: FP8 linear layer acceleration kernels integrated in KTransformers\n- **Hybrid Quantization Architecture**:\n  - Attention and Shared-Expert modules use FP8 precision (enhances computational accuracy)\n  - Expert modules retain GGML quantization (GGUF format, residing on the CPU to save GPU memory)\n\nThose who are pursuing the best performance can use the FP8 linear kernel for DeepSeek-V3/R1.\n\nThe detailed guide is [here](./fp8_kernel.md).\n\n### V0.2 & V0.2.1 Showcase\n\n#### Single socket version (32 cores)\n\nOur local_chat test command is:\n\n```shell\nnumactl -N 1 -m 1 python ./ktransformers/local_chat.py --model_path <your model path> --gguf_path <your gguf path>  --prompt_file <your prompt txt file>  --cpu_infer 33 --max_new_tokens 1000\n<when you see chat, then press enter to load the text prompt_file>\n```\n\n`<your model path>` can be a local path or an online Hugging Face repo such as deepseek-ai/DeepSeek-V3. If the online download runs into connection problems, try a mirror (hf-mirror.com) <br>\n`<your gguf path>` can also be online, but as it's large we recommend downloading it and quantizing the model to the format you want (note that it is a directory path) <br>\n`--max_new_tokens 1000` is the max output token length. 
If you find the answer is truncated, you\ncan increase this number for a longer answer (but be aware of OOM, and that increasing it will slow down generation).\n\nThe `numactl -N 1 -m 1` prefix avoids data transfer between NUMA nodes.<br>\nNote: when testing R1, the model may skip thinking. If it does, add the argument `--force_think true`. This is explained in the [FAQ](#faq) section.\n\n#### Dual socket version (64 cores)\n\nBefore you install (with install.sh or `make dev_install`), set the environment variable `USE_NUMA=1` via `export USE_NUMA=1` (if already installed, reinstall with this variable set). See the [installation doc](./install.md) for details. <br>\n\nTest Command:\n\n```shell\n# ---For those who have not installed ktransformers---\n# git clone https://github.com/kvcache-ai/ktransformers.git\n# cd ktransformers\n# git submodule init\n# git submodule update\n# export USE_NUMA=1\n# make dev_install # or sh ./install.sh\n# ----------------------------------------------------\npython ./ktransformers/local_chat.py --model_path <your model path> --gguf_path <your gguf path>  --prompt_file <your prompt txt file>  --cpu_infer 65 --max_new_tokens 1000\n<when you see chat, then press enter to load the text prompt_file>\n```\n\nThe parameters have the same meaning as in the single-socket case, but because we use dual sockets, we set cpu_infer to 65.\n\n### V0.3 Showcase\n\n#### Dual socket version (64 cores)\n\nOur local_chat test command is:\n\n```shell\nwget https://github.com/kvcache-ai/ktransformers/releases/download/v0.1.4/ktransformers-0.3.0rc0+cu126torch26fancy-cp311-cp311-linux_x86_64.whl\npip install ./ktransformers-0.3.0rc0+cu126torch26fancy-cp311-cp311-linux_x86_64.whl\npython -m ktransformers.local_chat --model_path <your model path> --gguf_path <your gguf path>  --prompt_file <your prompt txt file>  --cpu_infer 65 --max_new_tokens 1000\n<when you see chat, then press enter to load the text prompt_file>\n```\n\nThe parameters have the same meaning as in V0.2. 
But as we use dual sockets, we set cpu_infer to 65.\n\n## Some Explanations\n\n1. We want to make full use of both NUMA nodes on the Xeon Gold CPU.\n   To avoid the cost of data transfer between nodes, we \"copy\" the critical matrices to\n   both nodes, which consumes more memory but accelerates prefill and decoding.\n   This method uses a lot of memory and makes weight loading slow, so be patient when loading\n   and monitor the memory usage. We are going to optimize this memory overhead. Stay tuned~ <br>\n2. The argument `--cpu_infer 65` specifies how many cores to use (it is fine for it to exceed the physical core count,\n   but more is not always better; adjust it to slightly below your actual number of cores).<br>\n3. Why CPU/GPU Hybrid Inference?\n   DeepSeek's MLA operators are highly computationally intensive. While running everything on CPU is possible, offloading the heavy computations to the GPU results in a massive performance boost.\n4. Where Does the Speedup Come From?\n\n   - Expert Offload: Unlike traditional layer-based or KVCache offloading (as seen in llama.cpp), we offload the expert computation to the CPU and MLA/KVCache to the GPU, which aligns well with DeepSeek’s architecture.\n   - Intel AMX Optimization: Our AMX-accelerated kernel is carefully tuned and runs several times faster than existing llama.cpp implementations. We plan to open-source this kernel after cleanup and are considering upstream contributions to llama.cpp.\n5. 
Why Intel CPUs?\n   Intel is currently the only CPU vendor that supports AMX-like instructions, which deliver significantly better performance compared to AVX-only alternatives.\n\n## Next\n\n### Faster\n\n* The FlashInfer (https://github.com/flashinfer-ai/flashinfer) project is releasing an even more efficient fused MLA operator, promising further speedups\n* vLLM has explored multi-token prediction in DeepSeek-V3, and support is on our roadmap for even better performance\n* We are collaborating with Intel to enhance the AMX kernel (v0.3) and optimize for Xeon6/MRDIMM\n\n### Easier\n\n* Official Docker images to simplify installation\n* Fix the server integration for web API access\n* Fix local chat accepting only a single-line prompt (currently, entering \\n starts generation)\n* Support for more quantization types, including the highly requested dynamic quantization from unsloth\n\nStay tuned for more updates!\n\n## FAQ\n\n### R1 No Thinking\n\nNote: when testing R1, the model may skip thinking. If it does, add the argument `--force_think true`. Details are in the [FAQ](./FAQ.md). <br>\n\n### More FAQ\n\n[See the FAQ for details](./FAQ.md)\n"
  },
  {
    "path": "doc/en/Docker.md",
    "content": "# Docker\n\n## Prerequisites\n* Docker must be installed and running on your system.\n* Create a folder to store big models & intermediate files (ex. /mnt/models)\n\n## Images\nThere is a Docker image available for our project, you can pull the docker image by：\n```\ndocker pull approachingai/ktransformers:0.2.1\n```\n**Notice**: In this image, we compile the ktransformers in AVX512 instuction CPUs, if your cpu not support AVX512, it is suggested to recompile and install ktransformers in the /workspace/ktransformers directory within the container.\n\n## Building docker image locally\n - Download Dockerfile in [there](../../Dockerfile)\n\n - finish, execute\n   ```bash\n   docker build  -t approachingai/ktransformers:0.2.1 .\n   ```\n\n## Usage\n\nAssuming you have the [nvidia-container-toolkit](https://github.com/NVIDIA/nvidia-container-toolkit) that you can use the GPU in a Docker container.\n```\ndocker run --gpus all -v /path/to/models:/models --name ktransformers -itd approachingai/ktransformers:0.2.1\ndocker exec -it ktransformers /bin/bash\npython -m ktransformers.local_chat  --gguf_path /models/path/to/gguf_path --model_path /models/path/to/model_path --cpu_infer 33\n```\n\nMore operators you can see in the [readme](../../README.md)"
  },
  {
    "path": "doc/en/Docker_xpu.md",
    "content": "# Intel GPU Docker Guide (Beta)\n\n## Prerequisites\n\n* Docker must be installed and running on your system.\n* Create a folder to store big models & intermediate files (e.g., /mnt/models)\n* **Before proceeding, ensure the Intel GPU driver is installed correctly on your host:** [Installation Guide](./xpu.md#1-install-intel-gpu-driver)\n\n---\n\n## Building the Docker Image Locally\n\n1. Clone the repository and navigate to the project directory:\n\n   ```bash\n   git clone https://github.com/kvcache-ai/ktransformers.git\n   cd ktransformers\n   ```\n\n2. Build the Docker image using the XPU-specific [Dockerfile.xpu](../../Dockerfile.xpu):\n\n   ```bash\n   sudo http_proxy=$HTTP_PROXY \\\n        https_proxy=$HTTPS_PROXY \\\n        docker build \\\n          --build-arg http_proxy=$HTTP_PROXY \\\n          --build-arg https_proxy=$HTTPS_PROXY \\\n          -t kt_xpu:0.3.1 \\\n          -f Dockerfile.xpu \\\n          .\n   ```\n\n---\n\n## Running the Container\n\n### 1. Start the container\n\n```bash\nsudo docker run -td --privileged \\\n    --net=host \\\n    --device=/dev/dri \\\n    --shm-size=\"16g\" \\\n    -v /path/to/models:/models \\\n    -e http_proxy=$HTTP_PROXY \\\n    -e https_proxy=$HTTPS_PROXY \\\n    --name ktransformers_xpu \\\n    kt_xpu:0.3.1\n```\n\n**Note**: Replace `/path/to/models` with your actual model directory path (e.g., `/mnt/models`).\n\n---\n\n### 2. Access the container\n\n```bash\nsudo docker exec -it ktransformers_xpu /bin/bash\n```\n\n---\n\n### 3. Set required XPU environment variables (inside the container)\n\n```bash\nexport SYCL_CACHE_PERSISTENT=1\nexport ONEAPI_DEVICE_SELECTOR=level_zero:0\nexport SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1\n```\n\n---\n\n### 4. 
Run the sample script\n\n```bash\npython ktransformers/local_chat.py \\\n  --model_path deepseek-ai/DeepSeek-R1 \\\n  --gguf_path <path_to_gguf_files> \\\n  --optimize_config_path ktransformers/optimize/optimize_rules/xpu/DeepSeek-V3-Chat.yaml \\\n  --cpu_infer <cpu_cores + 1> \\\n  --device xpu \\\n  --max_new_tokens 200\n```\n\n**Note**:\n\n* Replace `<path_to_gguf_files>` with the path to your GGUF model files.\n* Replace `<cpu_cores + 1>` with the number of CPU cores you want to use plus one.\n\n---\n\n## Additional Information\n\nFor more configuration options and usage details, refer to the [project README](../../README.md). To run KTransformers natively on XPU (outside of Docker), please refer to [xpu.md](./xpu.md).\n"
  },
  {
    "path": "doc/en/FAQ.md",
    "content": "<!-- omit in toc -->\n# see the issue [FAQ page](https://github.com/kvcache-ai/ktransformers/issues/1608)"
  },
  {
    "path": "doc/en/Kimi-K2-Thinking.md",
    "content": "# KTransformers+SGLang Inference Deployment\nPlease Note This is Quantization Deployment. For Native Kimi K2 Thinking deployment please refer to [here](./Kimi-K2-Thinking-Native.md).\n\n## Installation\n\nStep 1: Install SGLang\n\nInstall the kvcache-ai fork of SGLang (one of):\n```bash\n# Option A: One-click install (from ktransformers root)\n./install.sh\n\n# Option B: pip install\npip install sglang-kt\n```\n\n> **Important:** Use `sglang-kt` (kvcache-ai fork), not the official `sglang` package. Run `pip uninstall sglang` first if you have the official version installed.\n\nStep 2: Install KTransformers CPU Kernels\n\nThe KTransformers CPU kernels (kt-kernel) provide AMX-optimized computation for hybrid inference, for detailed installation instructions and troubleshooting, refer to the official [kt-kernel installation guide](https://github.com/kvcache-ai/ktransformers/blob/main/kt-kernel/README.md).\n\n## Download Model\n\nDownload the official KIMI weights as GPU weights.\n\n* huggingface: https://huggingface.co/moonshotai/Kimi-K2-Thinking\n* modelscope: https://modelscope.cn/models/moonshotai/Kimi-K2-Thinking\n\nDownload the AMX INT4 quantized weights from https://huggingface.co/KVCache-ai/Kimi-K2-Thinking-CPU-weight as CPU weights.\n\n## How to start\n```\npython -m sglang.launch_server   --host 0.0.0.0   --port 60000   --model path/to/Kimi-K2-Thinking/   --kt-weight-path path/to/Kimi-K2-Instruct-CPU-weight/   --kt-cpuinfer 56   --kt-threadpool-count 2   --kt-num-gpu-experts 200   --kt-method AMXINT4   --attention-backend flashinfer   --trust-remote-code   --mem-fraction-static 0.98   --chunked-prefill-size 4096   --max-running-requests 37   --max-total-tokens 37000   --enable-mixed-chunk   --tensor-parallel-size 8   --enable-p2p-check   --disable-shared-experts-fusion\n```\ntips:\n\n`--kt-cpuinfer`: is recommended to be set to (number of physical CPU cores - 8 (number of GPUs)).\n\n`--kt-num-gpu-experts`: refers to the number of experts 
retained on GPUs, which should be adjusted according to your available GPU memory and expected KV cache space.\n\n## Test\n\nWhen testing, you need to add `--disable-radix-cache` and `--disable-chunked-prefix-cache` when starting the server.\n\n### bench prefill\n```\npython -m sglang.bench_serving   --backend sglang   --host 127.0.0.1   --port 60000   --num-prompts 37 --random-input-len 1024 --random-output-len 1 --random-range-ratio 1.0 --dataset-name random\n```\n\n### bench decode\n```\npython -m sglang.bench_serving   --backend sglang   --host 127.0.0.1   --port 60000   --num-prompts 37 --random-input-len 10 --random-output-len 512 --random-range-ratio 1.0 --dataset-name random\n```\n\n## Performance\n\n### System Configuration:\n\n- GPUs: 8× NVIDIA L20\n- CPU: Intel(R) Xeon(R) Gold 6454S\n\n### Bench prefill\n```\n============ Serving Benchmark Result ============\nBackend:                                 sglang\nTraffic request rate:                    inf\nMax request concurrency:                 not set\nSuccessful requests:                     37\nBenchmark duration (s):                  65.58\nTotal input tokens:                      37888\nTotal input text tokens:                 37888\nTotal input vision tokens:               0\nTotal generated tokens:                  37\nTotal generated tokens (retokenized):    37\nRequest throughput (req/s):              0.56\nInput token throughput (tok/s):          577.74\nOutput token throughput (tok/s):         0.56\nTotal token throughput (tok/s):          578.30\nConcurrency:                             23.31\n----------------End-to-End Latency----------------\nMean E2E Latency (ms):                   41316.50\nMedian E2E Latency (ms):                 41500.35\n---------------Time to First Token----------------\nMean TTFT (ms):                          41316.48\nMedian TTFT (ms):                        41500.35\nP99 TTFT (ms):                           65336.31\n---------------Inter-Token 
Latency----------------\nMean ITL (ms):                           0.00\nMedian ITL (ms):                         0.00\nP95 ITL (ms):                            0.00\nP99 ITL (ms):                            0.00\nMax ITL (ms):                            0.00\n==================================================\n```\n\n### Bench decode\n\n```\n============ Serving Benchmark Result ============\nBackend:                                 sglang\nTraffic request rate:                    inf\nMax request concurrency:                 not set\nSuccessful requests:                     37\nBenchmark duration (s):                  412.66\nTotal input tokens:                      370\nTotal input text tokens:                 370\nTotal input vision tokens:               0\nTotal generated tokens:                  18944\nTotal generated tokens (retokenized):    18618\nRequest throughput (req/s):              0.09\nInput token throughput (tok/s):          0.90\nOutput token throughput (tok/s):         45.91\nTotal token throughput (tok/s):          46.80\nConcurrency:                             37.00\n----------------End-to-End Latency----------------\nMean E2E Latency (ms):                   412620.35\nMedian E2E Latency (ms):                 412640.56\n---------------Time to First Token----------------\nMean TTFT (ms):                          3551.87\nMedian TTFT (ms):                        3633.59\nP99 TTFT (ms):                           3637.37\n---------------Inter-Token Latency----------------\nMean ITL (ms):                           800.53\nMedian ITL (ms):                         797.89\nP95 ITL (ms):                            840.06\nP99 ITL (ms):                            864.96\nMax ITL (ms):                            3044.56\n==================================================\n```\n"
  },
  {
    "path": "doc/en/Kimi-K2.5.md",
    "content": "# Running Kimi-K2.5 with SGLang and KT-Kernel\n\nThis tutorial demonstrates how to run Kimi-K2.5 model inference using SGLang integrated with KT-Kernel for CPU-GPU heterogeneous inference. This setup enables efficient deployment of large MoE models by offloading experts to CPU.\n\n## Table of Contents\n\n- [Hardware Requirements](#hardware-requirements)\n- [Prerequisites](#prerequisites)\n- [Step 1: Download Model Weights](#step-1-download-model-weights)\n- [Step 2: Launch SGLang Server](#step-2-launch-sglang-server)\n- [Step 3: Send Inference Requests](#step-3-send-inference-requests)\n\n## Hardware Requirements\n\n**Minimum Configuration:**\n- **GPU**: NVIDIA RTX 2x4090 48GB (or equivalent with at least total 48GB VRAM available)\n- **CPU**: x86 CPU with AVX512F support (e.g., Intel Sapphire Rapids)\n- **RAM**: At least 600GB system memory\n- **Storage**: ~600GB for model weights (native INT4 weight, same weight folder for CPU and GPU)\n\n## Prerequisites\n\nBefore starting, ensure you have:\n\n1. **KT-Kernel installed**:\n\n   Note: Latest KTransformers' EPLB feature for Kimi-K2.5 will be supported soon.\n\n```\ngit clone https://github.com/kvcache-ai/ktransformers.git\ngit submodule update --init --recursive\ncd kt-kernel && ./install.sh\n```\n\n2. **SGLang installed** - Install the kvcache-ai fork of SGLang (one of):\n\n```bash\n# Option A: One-click install (from ktransformers root)\n./install.sh\n\n# Option B: pip install\npip install sglang-kt\n```\n\n> Note: You may need to reinstall cudnn: `pip install nvidia-cudnn-cu12==9.16.0.29`\n\n3. **CUDA toolkit** - Compatible with your GPU (CUDA 12.8+ recommended)\n4. 
**Hugging Face CLI** - For downloading models:\n   \n   ```bash\n   pip install huggingface-hub\n   ```\n\n## Step 1: Download Model Weights\n\n```bash\n# Create a directory for models\nmkdir -p /path/to/models\ncd /path/to/models\n\n# Download Kimi-K2.5 (RAW-INT4 for both CPU and GPU)\nhuggingface-cli download moonshotai/Kimi-K2.5 \\\n  --local-dir /path/to/kimi-k2.5\n```\n\n**Note:** Replace `/path/to/models` with your actual storage path throughout this tutorial.\n\n## Step 2: Launch SGLang Server\n\nStart the SGLang server with KT-Kernel integration for CPU-GPU heterogeneous inference.\n\n\n### Launch Command (4x RTX 4090 Example)\n\n```bash\npython -m sglang.launch_server \\\n  --host 0.0.0.0 \\\n  --port 31245 \\\n  --model /path/to/kimi-k2.5 \\\n  --kt-weight-path /path/to/kimi-k2.5 \\\n  --kt-cpuinfer 96 \\\n  --kt-threadpool-count 2 \\\n  --kt-num-gpu-experts 30 \\\n  --kt-method RAWINT4 \\\n  --kt-gpu-prefill-token-threshold 400 \\\n  --trust-remote-code \\\n  --mem-fraction-static 0.94 \\\n  --served-model-name Kimi-K2.5 \\\n  --enable-mixed-chunk \\\n  --tensor-parallel-size 4 \\\n  --enable-p2p-check \\\n  --disable-shared-experts-fusion \\\n  --chunked-prefill-size 32658 \\\n  --max-total-tokens 50000 \\\n  --attention-backend flashinfer\n```\n\nIt takes about 2~3 minutes to start the server.\n\nSee [KT-Kernel Parameters](https://github.com/kvcache-ai/ktransformers/tree/main/kt-kernel#kt-kernel-parameters) for detailed parameter tuning guidelines.\n\n## Step 3: Send Inference Requests\n\nOnce the server is running, you can send inference requests using the OpenAI-compatible API.\n\n### Basic Chat Completion Request\n\n```bash\ncurl -s http://localhost:31245/v1/chat/completions \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"model\": \"Kimi-K2.5\",\n    \"stream\": false,\n    \"messages\": [\n      {\"role\": \"user\", \"content\": \"hi, who are you?\"}\n    ]\n  }'\n```\n\n### Example Response\n\n```json\n{\n    \"id\": 
\"2a4e83f8a79b4b57b103b0f298fbaa7d\",\n    \"object\": \"chat.completion\",\n    \"created\": 1769333912,\n    \"model\": \"Kimi-K2.5\",\n    \"choices\": [\n        {\n            \"index\": 0,\n            \"message\": {\n                \"role\": \"assistant\",\n                \"content\": \" The user is asking \\\"hi, who are you?\\\" which is a simple greeting and identity question. I need to respond appropriately by introducing myself clearly and concisely.\\n\\nI am Kimi, a large language model trained by Moonshot AI. I should state my name, my nature (AI assistant), and my developer (Moonshot AI). I should keep it friendly and helpful.\\n\\nKey points to include:\\n- Greet them back (\\\"hi\\\" or \\\"hello\\\")\\n- State my name: Kimi\\n- State what I am: an AI assistant/language model\\n- Mention my developer: Moonshot AI\\n- Briefly describe my purpose: to help answer questions, provide information, and assist with various tasks\\n- Keep it concise but informative\\n- Use a friendly, professional tone\\n\\nI should avoid overly technical jargon while being accurate. The response should be welcoming and set the stage for further interaction.\\n\\nPossible response:\\n\\\"Hi! I'm Kimi, an AI assistant created by Moonshot AI. I'm designed to help answer questions, provide information, and assist with a wide range of tasks. How can I help you today?\\\"\\n\\nThis covers all the necessary points and invites the user to continue the conversation. </think> Hi! I'm Kimi, an AI assistant created by Moonshot AI. I'm designed to help answer questions, provide information, and assist with a wide range of tasks. 
How can I help you today?\",\n                \"reasoning_content\": null,\n                \"tool_calls\": null\n            },\n            \"logprobs\": null,\n            \"finish_reason\": \"stop\",\n            \"matched_stop\": 163586\n        }\n    ],\n    \"usage\": {\n        \"prompt_tokens\": 32,\n        \"total_tokens\": 317,\n        \"completion_tokens\": 285,\n        \"prompt_tokens_details\": null,\n        \"reasoning_tokens\": 0\n    },\n    \"metadata\": {\n        \"weight_version\": \"default\"\n    }\n}\n```\n"
  },
  {
    "path": "doc/en/Kimi-K2.md",
    "content": "# Kimi-K2 Support for KTransformers\n\n## Introduction\n\n### Overview\nWe are very pleased to announce that Ktransformers now supports Kimi-K2 and Kimi-K2-0905.\n\nOn a single-socket CPU with one consumer-grade GPU, running the Q4_K_M model yields roughly 10 TPS and requires about 600 GB of DRAM.  \nWith a dual-socket CPU and sufficient system memory, enabling NUMA optimizations increases performance to about 14 TPS.\n\n### Model & Resource Links\n\n- Official Kimi-K2 Release: \n  - https://huggingface.co/collections/moonshotai/kimi-k2-6871243b990f2af5ba60617d\n- GGUF Format(quantized models):\n  - https://huggingface.co/KVCache-ai/Kimi-K2-Instruct-GGUF\n- Official Kimi-K2-0905 Release:\n  - https://huggingface.co/moonshotai/Kimi-K2-Instruct-0905\n- GGUF Format(quantized models):\n  - https://huggingface.co/KVCache-ai/Kimi-K2-Instruct-0905-GGUF\n\n## Installation Guide\n\n### 1. Resource Requirements\n\nThe model running with 384 Experts requires approximately 600 GB of memory and 14 GB of GPU memory.\n\n### 2. Prepare Models\n\n```bash\n# download gguf\nhuggingface-cli download --resume-download KVCache-ai/Kimi-K2-Instruct-GGUF\n\n```\n\n### 3. Install ktransformers\n\nTo install KTransformers, follow the official [Installation Guide](https://kvcache-ai.github.io/ktransformers/en/install.html).\n\n### 4. Run Kimi-K2 Inference Server\n\n```bash\npython ktransformers/server/main.py \\\n  --port 10002 \\\n  --model_path <path_to_safetensor_config> \\\n  --gguf_path <path_to_gguf_files> \\\n  --optimize_config_path ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-serve.yaml \\\n  --max_new_tokens 1024 \\\n  --cache_lens 32768 \\\n  --chunk_size 256 \\\n  --max_batch_size 4 \\\n  --backend_type balance_serve \\\n```\n\n### 5. 
Access server\n\n```\ncurl -X POST http://localhost:10002/v1/chat/completions \\\n  -H \"accept: application/json\" \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"messages\": [\n      {\"role\": \"user\", \"content\": \"hello\"}\n    ],\n    \"model\": \"Kimi-K2\",\n    \"temperature\": 0.3,\n    \"top_p\": 1.0,\n    \"stream\": true\n  }'\n```\n"
  },
  {
    "path": "doc/en/Kllama_tutorial_DeepSeekV2Lite.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"6201cdec-70f7-4c22-b988-b23ece31979d\",\n   \"metadata\": {},\n   \"source\": [\n    \"<div align=\\\"center\\\">\\n\",\n    \"  <!-- <h1>KTransformers</h1> -->\\n\",\n    \"  <p align=\\\"center\\\">\\n\",\n    \"\\n\",\n    \"<picture>\\n\",\n    \"    <img alt=\\\"KTransformers\\\" src=\\\"https://github.com/user-attachments/assets/d5a2492f-a415-4456-af99-4ab102f13f8b\\\" width=50%>\\n\",\n    \"\\n\",\n    \"</picture>\\n\",\n    \"\\n\",\n    \"</p>\\n\",\n    \"\\n\",\n    \"</div>\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"5dcfddc6-d51b-4aa8-b887-f7c817492316\",\n   \"metadata\": {\n    \"jp-MarkdownHeadingCollapsed\": true\n   },\n   \"source\": [\n    \"# **Introduction**\\n\",\n    \"[KTransformers](https://github.com/kvcache-ai/ktransformers), is designed to enhance the 🤗 Transformers experience through advanced kernel optimizations and placement/parallelism strategies. \\n\",\n    \"<br/> <br/>\\n\",\n    \"This tutorial serves as a guide for KTransformers-ft, aiming to to give resource-constrained researchers a **local path to explore fine-tuning ultra-large models (e.g., 671B/1000B)**, and also a fast way to customize smaller models (e.g., 14B/30B) for specific scenarios. 
We validate the setup using representative tasks such as stylized dialogue, Westernized translation tone, and medical Q&A, demonstrating that personalized adaptation can be achieved within hours.\\n\",\n    \"<br/> <br/>\\n\",\n    \"This tutorial takes DeepSeek-V2-Lite as a code example; for more details, refer to [KTransformers-Fine-Tuning_User-Guide](https://github.com/kvcache-ai/ktransformers/blob/main/doc/en/KTransformers-Fine-Tuning_User-Guide.md) and [KTransformers-Fine-Tuning_Developer-Technical-Notes](https://github.com/kvcache-ai/ktransformers/blob/main/doc/en/KTransformers-Fine-Tuning_Developer-Technical-Notes.md).\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"b4167684-81f4-4e2b-a486-c33ec3bc92f0\",\n   \"metadata\": {},\n   \"source\": [\n    \"# **Installation**\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"5548a7f8-20d6-4ae4-a575-a3ef7a0ea5f8\",\n   \"metadata\": {},\n   \"source\": [\n    \"### **1. Install torch and clone the repo**\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"6f39051d-eb14-44fa-af82-9ded23144985\",\n   \"metadata\": {\n    \"collapsed\": true,\n    \"jupyter\": {\n     \"outputs_hidden\": true\n    }\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"!git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git\\n\",\n    \"!cd LLaMA-Factory\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"e7dd351f-9102-4d7d-951c-4306df9f4cd7\",\n   \"metadata\": {},\n   \"source\": [\n    \"**(Optional)** If you want to choose your version of torch and cuda, please install separately.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"8a5afa0c-1ed0-4190-ab50-967e553d6fd2\",\n   \"metadata\": {\n    \"collapsed\": true,\n    \"jupyter\": {\n     \"outputs_hidden\": true\n    }\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"!pip install torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 
--index-url https://download.pytorch.org/whl/cu118\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"711dcc79-056f-4483-a2e1-7e780af1def1\",\n   \"metadata\": {},\n   \"source\": [\n    \"### **2. Install LLaMA-Factory**\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"id\": \"42f09df9-7db8-46e3-b11d-2946a57d2933\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"os.chdir(\\\"LLaMA-Factory\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"3a6a5532-e5cc-463b-bdf8-030e547287fc\",\n   \"metadata\": {\n    \"collapsed\": true,\n    \"jupyter\": {\n     \"outputs_hidden\": true\n    }\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"!pip install -e \\\".[torch,metrics]\\\" --no-build-isolation\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"48c19762-70a7-402c-94f9-a71b277eb932\",\n   \"metadata\": {},\n   \"source\": [\n    \"### **3. Install dependency libraries for GCC and CUDA**\\n\",\n    \"You need to install system-level dependency libraries. `libstdcxx-ng` and `gcc_impl_linux-64` ensure compilation compatibility, while `cuda-runtime` provides a GPU-accelerated runtime environment. **Please do NOT IGNORE these two commands! 
`nvidia/label/cuda-11.8.0 cuda-runtime` should be installed for every version of cuda for KT whl.**\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"202e672a-b30a-4bde-92d5-27500f435b30\",\n   \"metadata\": {\n    \"collapsed\": true,\n    \"jupyter\": {\n     \"outputs_hidden\": true\n    },\n    \"scrolled\": true\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"!conda install -y -c conda-forge libstdcxx-ng gcc_impl_linux-64\\n\",\n    \"!conda install -y -c nvidia/label/cuda-11.8.0 cuda-runtime\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"94e6448f-1e27-4f16-885c-27738c2089dc\",\n   \"metadata\": {},\n   \"source\": [\n    \"### **4. Install ktransformers and flash-attention**\\n\",\n    \"You need to download the corresponding version of python, cuda and torch from [downloading ktransformers whl](https://github.com/kvcache-ai/ktransformers/releases/tag/v0.4.1) and [downloading flash-attention whl](https://github.com/Dao-AILab/flash-attention/releases).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"id\": \"7c4a5e82-ae9f-490f-9f90-441cdd98041e\",\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"True\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"import torch\\n\",\n    \"print(torch._C._GLIBCXX_USE_CXX11_ABI)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"837a2240-818d-499f-a1b5-641fa5c45339\",\n   \"metadata\": {\n    \"collapsed\": true,\n    \"jupyter\": {\n     \"outputs_hidden\": true\n    },\n    \"scrolled\": true\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"!pip install ../ktransformers-0.4.1+cu128torch27fancy-cp312-cp312-linux_x86_64.whl\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"e3c78d9e-26e0-4f85-94ff-d6b028b194ac\",\n   \"metadata\": {\n    
\"collapsed\": true,\n    \"jupyter\": {\n     \"outputs_hidden\": true\n    }\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"!pip install ../flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp312-cp312-linux_x86_64.whl\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"2593e2cb-5fbd-4d66-94fc-d2d74c4d8f65\",\n   \"metadata\": {},\n   \"source\": [\n    \"# **How to Start**\\n\",\n    \"## Fine-tuning the Model with LoRA\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"f7db3349-8cdb-48cd-8b63-0ea70fe4af6f\",\n   \"metadata\": {},\n   \"source\": [\n    \"LoRA (Low-Rank Adaptation) fine-tuning trains only small \\\"adapter\\\" weights on top of a large model. However, under traditional frameworks it still needs more than 1400GB of GPU VRAM, far more than a machine with RTX 4090s can provide. **KTransformers**, as a high-performance backend engine, provides a solution for hybrid GPU/CPU devices that further cuts GPU memory usage and speeds up training. As shown below, we compare KTransformers (ours) with other common LoRA fine-tuning backends (HuggingFace and Unsloth). KTransformers is the **only workable 4090-class solution** for ultra-large MoE models (e.g., 671B) and also delivers higher fine-tuning throughput. <br/>\\n\",\n    \"<div style=\\\"text-align: center;\\\">\\n\",\n    \"<img src=\\\"https://typora-tuchuang-jimmy.oss-cn-beijing.aliyuncs.com/img/按照模型划分的对比图_02.png\\\" alt=\\\"kt_unsloth_huggingface_compare\\\" width=\\\"70%\\\" height=\\\"auto\\\">\\n\",\n    \"</div>\\n\",\n    \"\\n\",\n    \"To make fine-tuning with KTransformers easier to use, we cooperate with [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory/), an easy-to-use and efficient model fine-tuning framework. As shown below, LLaMA-Factory is the unified configuration layer for the whole fine-tuning workflow. 
**KTransformers** acts as a high-performance backend that takes over core operators like Attention/MoE under the same training configs, enabling efficient **GPU+CPU heterogeneous cooperation**. <br/>\\n\",\n    \"<div style=\\\"text-align: center;\\\">\\n\",\n    \"<img src=\\\"https://typora-tuchuang-jimmy.oss-cn-beijing.aliyuncs.com/img/image-20251011010558909.png\\\" alt=\\\"image-20251011010558909\\\" width=\\\"70%\\\" height=\\\"auto\\\">\\n\",\n    \"</div>\\n\",\n    \"\\n\",\n    \"This combination lets you fine-tune big models (like 671B/1000B) on consumer level GPUs (2-4 RTX 4090s) — no need for expensive hardware. Here’s the training command:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"baf5b8fc-e910-4531-9f00-a2076c698eff\",\n   \"metadata\": {\n    \"scrolled\": true\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"!USE_KT=1 llamafactory-cli train examples/train_lora/deepseek2_lora_sft_kt.yaml\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"dc80b189-17ac-47a7-9889-b77e7a9d5304\",\n   \"metadata\": {},\n   \"source\": [\n    \"Let’s break down the training command (`USE_KT=1 llamafactory-cli train examples/train_lora/deepseek2_lora_sft_kt.yaml`):\\n\",\n    \"- `USE_KT=1`: The \\\"switch\\\" to enable KTransformers optimization.  \\n\",\n    \"- `llamafactory-cli train`: The core command to start LLaMA-Factory’s fine-tuning tool.\\n\",\n    \"- `examples/train_lora/deepseek2_lora_sft_kt.yaml`: The configuration file that controls model, data, training rules and KTransformers settings — we’ll detail this next.\\n\",\n    \"\\n\",\n    \"**The LLaMA-Factory yaml (e.g. `deepseek2_lora_sft_kt.yaml`) is where you define how the fine-tuning works.** Below is a simplified version, you can use this directly for basic tasks like style transfer or domain Q&A. 
We’ll explain each section’s purpose, and why the values are set this way, in the customization section later in this notebook.\\n\",\n    \"```yaml\\n\",\n    \"### model\\n\",\n    \"model_name_or_path: deepseek-ai/DeepSeek-V2-Lite\\n\",\n    \"\\n\",\n    \"### method\\n\",\n    \"finetuning_type: lora\\n\",\n    \"lora_rank: 8\\n\",\n    \"lora_target: all\\n\",\n    \"\\n\",\n    \"### dataset\\n\",\n    \"dataset: identity\\n\",\n    \"template: deepseek\\n\",\n    \"cutoff_len: 2048\\n\",\n    \"max_samples: 100000\\n\",\n    \"\\n\",\n    \"### output\\n\",\n    \"output_dir: saves/Kllama_deepseekV2\\n\",\n    \"logging_steps: 10\\n\",\n    \"save_steps: 500\\n\",\n    \"\\n\",\n    \"### train\\n\",\n    \"per_device_train_batch_size: 1\\n\",\n    \"gradient_accumulation_steps: 8\\n\",\n    \"learning_rate: 1.0e-4\\n\",\n    \"num_train_epochs: 3.0\\n\",\n    \"\\n\",\n    \"### ktransformers\\n\",\n    \"use_kt: true # use KTransformers as LoRA sft backend\\n\",\n    \"kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml\\n\",\n    \"cpu_infer: 32\\n\",\n    \"chunk_size: 8192\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"dac7722d-89dd-40b1-ac27-7ca64e80fe47\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Chat with the Fine-tuned Model: Test Your Customized AI\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"9af428c6-4fce-4320-b3d3-af59726ab9ce\",\n   \"metadata\": {},\n   \"source\": [\n    \"After finishing fine-tuning with KTransformers, **the next step is to chat with your model and verify the results!** This step loads the original base model plus the fine-tuned \\\"custom plugin\\\" (LoRA adapter) you saved earlier, letting you interact with the model in real time.  \\n\",\n    \"\\n\",\n    \"We’ll use LLaMA-Factory’s `chat` command to launch the interactive interface. 
The core is the LLaMA-Factory YAML configuration file — it tells the tool which model to load, how to optimize inference, and what style of dialogue to use. We take one of the example as follows.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"37191db1-a97c-407c-9626-af9fde6dd94f\",\n   \"metadata\": {\n    \"scrolled\": true\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"!llamafactory-cli chat examples/inference/deepseek2_lora_sft_kt.yaml\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"06c18255-66d0-4189-a714-6050160a0637\",\n   \"metadata\": {},\n   \"source\": [\n    \"To know exactly what you’re running, we break down the full command (`llamafactory-cli chat examples/inference/deepseek2_lora_sft_kt.yaml`):\\n\",\n    \"- `llamafactory-cli chat`: The core command to launch LLaMA-Factory’s interactive chat tool.\\n\",\n    \"- `examples/inference/deepseek2_lora_sft_kt.yaml`: The configuration file for inference (controls model loading, optimization, and dialogue settings).\\n\",\n    \"- No need for `USE_KT=1` here — we’ll enable KTransformers directly in the YAML (but it still needs to match the training settings!).\\n\",\n    \"\\n\",\n    \"**The LLaMA-Factory configuration file for inference (`examples/inference/deepseek2_lora_sft_kt.yaml`) controls the generate config for specific tasks.** Below is a simplified version, you can use this directly to chat with your fine-tuned model. 
Most settings are linked to your training config; we’ll explain the details in the next part.\\n\",\n    \"```yaml\\n\",\n    \"model_name_or_path: deepseek-ai/DeepSeek-V2-Lite\\n\",\n    \"adapter_name_or_path: saves/Kllama_deepseekV2\\n\",\n    \"template: deepseek\\n\",\n    \"infer_backend: ktransformers  # choices: [huggingface, vllm, sglang, ktransformers]\\n\",\n    \"trust_remote_code: true\\n\",\n    \"\\n\",\n    \"use_kt: true # use KTransformers as the inference backend for the LoRA SFT model\\n\",\n    \"kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml\\n\",\n    \"cpu_infer: 32\\n\",\n    \"chunk_size: 8192\\n\",\n    \"```\\n\",\n    \"`kt_optimize_rule` must be the same as the `kt_optimize_rule` used for LoRA fine-tuning.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"18814c5c-3b73-44cc-a608-505c1e870437\",\n   \"metadata\": {},\n   \"source\": [\n    \"# **Customize Your KTransformers-FineTuning + LLaMA-Factory**\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"8072427f-46d4-41fb-8850-e33a2446e031\",\n   \"metadata\": {},\n   \"source\": [\n    \"Once you’ve got the basic fine-tuning workflow down, you’ll likely want to **adapt the process to your specific needs**, whether that’s training on your own data, squeezing more performance out of limited GPU memory, or speeding up training for large datasets. Below is a hands-on guide to customizing every part of the process, with clear explanations of why each setting matters and how to tweak it.\\n\",\n    \"\\n\",\n    \"## 1. Fine-tuning Customization: Tailor Training to Your Needs  \\n\",\n    \"To start customizing, you’ll still use the core training command: `USE_KT=1 llamafactory-cli train examples/train_lora/deepseek2_lora_sft_kt.yaml`. A configuration adapted to your data and hardware can perform better than the default setup. 
<br/>\\n\",\n    \"### Full example **LLaMA-Factory YAML** for DeepSeek-V2-Lite\\n\",\n    \"```yaml\\n\",\n    \"### model\\n\",\n    \"model_name_or_path: deepseek-ai/DeepSeek-V2-Lite\\n\",\n    \"trust_remote_code: true\\n\",\n    \"\\n\",\n    \"### method\\n\",\n    \"stage: sft\\n\",\n    \"do_train: true\\n\",\n    \"finetuning_type: lora\\n\",\n    \"lora_rank: 8\\n\",\n    \"lora_target: all\\n\",\n    \"\\n\",\n    \"### dataset\\n\",\n    \"dataset: identity\\n\",\n    \"template: deepseek\\n\",\n    \"cutoff_len: 2048\\n\",\n    \"max_samples: 100000\\n\",\n    \"overwrite_cache: true\\n\",\n    \"preprocessing_num_workers: 16\\n\",\n    \"dataloader_num_workers: 4\\n\",\n    \"\\n\",\n    \"### output\\n\",\n    \"output_dir: saves/Kllama_deepseekV2Lite\\n\",\n    \"logging_steps: 10\\n\",\n    \"save_steps: 500\\n\",\n    \"plot_loss: true\\n\",\n    \"overwrite_output_dir: true\\n\",\n    \"save_only_model: false\\n\",\n    \"report_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]\\n\",\n    \"\\n\",\n    \"### train\\n\",\n    \"per_device_train_batch_size: 1\\n\",\n    \"gradient_accumulation_steps: 8\\n\",\n    \"learning_rate: 1.0e-4\\n\",\n    \"num_train_epochs: 3.0\\n\",\n    \"lr_scheduler_type: cosine\\n\",\n    \"warmup_ratio: 0.1\\n\",\n    \"bf16: true\\n\",\n    \"ddp_timeout: 180000000\\n\",\n    \"resume_from_checkpoint: null\\n\",\n    \"\\n\",\n    \"### ktransformers\\n\",\n    \"use_kt: true # use KTransformers as LoRA sft backend\\n\",\n    \"kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V2-Chat-sft-amx.yaml\\n\",\n    \"cpu_infer: 32\\n\",\n    \"chunk_size: 8192\\n\",\n    \"```\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"6abc1968-6208-4344-9c82-335d7fe1d27c\",\n   \"metadata\": {},\n   \"source\": [\n    \"---\\n\",\n    \"### A. Pick & Prepare Your Model\\n\",\n    \"The first step in customization is choosing the right base model, and ensuring it works with KTransformers. 
The `model_name_or_path` setting (shown in LLaMA-Factory YAML before) controls this, and getting it right avoids common errors.\\n\",\n    \"- **Use a public model**: Directly set to Hugging Face Hub names (e.g., `deepseek-ai/DeepSeek-V2-Lite`, `Qwen/Qwen2-MoE-72B`).  \\n\",\n    \"- **Use a local model**: Replace with your local folder path (e.g., `/mnt/data/models/DeepSeek-V2-Lite`).\\n\",\n    \"\\n\",\n    \"**Critical Requirement**: The model must be in **BF16 format**.  \\n\",\n    \"  - FP8 models (like DeepSeek-V3’s default release) aren’t compatible with KTransformers’ optimization.  \\n\",\n    \"  - Fix: Convert FP8 to BF16 with **[this official script](https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py)**.\\n\",\n    \"\\n\",\n    \"---\\n\",\n    \"\\n\",\n    \"### B. Tune LoRA: Balance Fitting Capability & Memory  \\n\",\n    \"LoRA trains tiny \\\"adapter\\\" weights instead of the entire model. Tweaking these two settings in LLaMA-Factory YAML (`lora_rank`, `lora_target`) lets you balance how well the model learns your data and how much GPU memory it uses:\\n\",\n    \"\\n\",\n    \"| Setting         | What it does                                                                 | Scenario & Recommendation                                                                 |\\n\",\n    \"|-----------------|-----------------------------------------------------------------------------|-------------------------------------------------------------------------------------------|\\n\",\n    \"| `lora_rank`     | Controls the \\\"power\\\" of LoRA adapters (higher = more fitting, more memory). | - Small dataset (≤5k samples) or limited GPU: 4-8 (balances speed/memory).<br>- Large dataset (≥20k samples): 16-32 (better fits custom data). |\\n\",\n    \"| `lora_target`   | Which layers get LoRA (applies only to linear layers).                      
| - Quick fine-tuning (e.g., style transfer): `q_proj,v_proj` (only attention layers, so faster).<br>- Deep customization (e.g., medical Q&A): `all` (all linear layers, so more accurate). |\\n\",\n    \"\\n\",\n    \"**Tip**: Pair `lora_rank=8` with `lora_alpha=32` (alpha = 4× rank) for stable training. This ratio is tested to work well for most tasks, from chatbots to domain Q&A.  \\n\",\n    \"\\n\",\n    \"---\\n\",\n    \"\\n\",\n    \"### C. Use Your Own Dataset\\n\",\n    \"Fine-tuning’s value lies in training on your own data, such as company documents, customer support logs, or domain-specific Q&A. Below is how to replace the default (identity) dataset with yours:  \\n\",\n    \"\\n\",\n    \"1. **Add a custom dataset**:  \\n\",\n    \"   - Step 1: Organize your data into LLaMA-Factory’s format (e.g., JSON with `instruction`, `input`, `output` fields; see [dataset examples](https://github.com/hiyouga/LLaMA-Factory/tree/main/data)).  \\n\",\n    \"   - Step 2: Register your dataset in [LLaMA-Factory/data/dataset_info.json](https://github.com/hiyouga/LLaMA-Factory/blob/main/data/dataset_info.json) (copy the format of the built-in datasets and add your dataset name and file path).\\n\",\n    \"     For example,\\n\",\n    \"     ```json\\n\",\n    \"     \\\"niko\\\": {\\n\",\n    \"        \\\"file_name\\\": \\\"../niko_train.json\\\"\\n\",\n    \"      },\\n\",\n    \"     ```\\n\",\n    \"   - Step 3: Replace `dataset: identity` in the LLaMA-Factory YAML with your dataset name (e.g., `dataset: niko`).\\n\",\n    \"2. **Tweak dataset settings for better results**:  \\n\",\n    \"   - `cutoff_len`: Truncates long texts (e.g., set to 4096 for long documents, 2048 for short dialogues; never exceed `model_max_length`).  \\n\",\n    \"   - `max_samples`: Caps the number of samples used (use 100 for debugging, or leave it unset (`None`) to train on the full dataset).  
\\n\",\n    \"   - `template`: Must match your model (e.g., `deepseek` for DeepSeek, `llama3` for LLaMA3, more refer to [supported-models](https://github.com/hiyouga/LLaMA-Factory/tree/main?tab=readme-ov-file#supported-models))—mismatched templates break response formatting!  \\n\",\n    \"\\n\",\n    \"---\\n\",\n    \"\\n\",\n    \"### D. Save GPU Memory & Speed Up Training  \\n\",\n    \"If you’re hitting GPU memory limits or waiting too long for training, adjust these settings in LLaMA-Factory YAML:  \\n\",\n    \"\\n\",\n    \"| Challenge               | Setting to Tweak                          | How to Adjust                                                                 |\\n\",\n    \"|-------------------------|-------------------------------------------|--------------------------------------------------------------------------------|\\n\",\n    \"| GPU memory is tight     | `per_device_train_batch_size` + `gradient_accumulation_steps` | Set `per_device_train_batch_size=1` (smallest batch) + `gradient_accumulation_steps=16` (simulates a batch of 16—no memory penalty!). |\\n\",\n    \"| Model overfits (bad generalization) | `lora_dropout` + `num_train_epochs` | Add `lora_dropout: 0.1` (prevents overfitting) + reduce `num_train_epochs` to 2 (3 is default—overtraining hurts!). |\\n\",\n    \"\\n\",\n    \"**Key Train Configs Recap**:  \\n\",\n    \"- `learning_rate`: 1e-4~2e-4 for LoRA (stick to this range—too high = unstable, too low = slow learning).  \\n\",\n    \"- `save_steps`: Save checkpoints every 100-500 steps (frequent saves = safe, but don’t overdo it—each checkpoint takes storage!).  \\n\",\n    \"- `output_dir`: Customize the save path (e.g., `saves/medical_qa_deepseek` instead of the default—keeps your projects organized!).  \\n\",\n    \"\\n\",\n    \"---\\n\",\n    \"\\n\",\n    \"### E. 
KTransformers Optimization: Unlock Maximum Performance  \\n\",\n    \"KTransformers is what makes fine-tuning large models (like 671B-parameter DeepSeek-V3) possible on modest hardware. These settings control how it optimizes layer placement (GPU vs. CPU) and computation speed:\\n\",\n    \"\\n\",\n    \"| Setting               | What it does                                                                 | How to Customize                                                                 |\\n\",\n    \"|-----------------------|-----------------------------------------------------------------------------|----------------------------------------------------------------------------------|\\n\",\n    \"| `use_kt`              | Enables KTransformers backend (must be `true`—otherwise, no optimization!). | Leave as `true`—this is what makes 671B models trainable on 2×4090s!             |\\n\",\n    \"| `cpu_infer`           | Number of CPU threads for MoE/linear computations.                          | Set to half your CPU cores (e.g., 32 for a 64-core CPU—too many threads = bottlenecks!). |\\n\",\n    \"| `chunk_size`          | Block size for long text processing (affects memory and speed).             | Default 8192 works for most tasks; increase to 16384 for extra-long texts (e.g., book summaries). |\\n\",\n    \"| `kt_optimize_rule`    | Defines where layers run (GPU/CPU) and which kernels to use (core of KT!).  | - Use the pre-built rule for your model (e.g., `DeepSeek-V2-Lite-Chat-sft-amx.yaml`).<br>- For faster speed: Use `AMXInt8`/`AMXBF16` as backend (if your CPU supports AMX—check with `lscpu | grep amx`).<br>- For compatibility: Fall back to `llamafile` if AMX isn’t supported. |\\n\",\n    \"\\n\",\n    \"#### Example Custom `kt_optimize_rule` (shown in the table above)  \\n\",\n    \"This rule tells KTransformers to offload heavy MoE layers to the CPU (saving GPU memory) and use AMX for fast CPU computation. 
Use it as a template for your own model: (Details tutorial could be seen in **[here](https://github.com/kvcache-ai/ktransformers/blob/main/doc/en/injection_tutorial.md)**)\\n\",\n    \"```yaml\\n\",\n    \"- match:\\n\",\n    \"    name: \\\"^model\\\\\\\\.layers\\\\\\\\..*\\\\\\\\.mlp\\\\\\\\.experts$\\\"  # Target all MoE expert layers\\n\",\n    \"  replace:\\n\",\n    \"    class: ktransformers.operators.experts.KTransformersExperts  # KT's optimized MoE kernel\\n\",\n    \"    kwargs:\\n\",\n    \"      prefill_device: \\\"cuda\\\"  # Fast pre-processing on GPU\\n\",\n    \"      prefill_op: \\\"KExpertsTorch\\\"\\n\",\n    \"      generate_device: \\\"cpu\\\"  # Heavy MoE compute on CPU (saves GPU memory)\\n\",\n    \"      generate_op: \\\"KSFTExpertsCPU\\\"  # KT's SFT-optimized MoE operator\\n\",\n    \"      out_device: \\\"cuda\\\"  # Send results back to GPU for next steps\\n\",\n    \"      backend: \\\"AMXInt8\\\"  # Options: AMXInt8 (fastest) > AMXBF16 > llamafile (default)\\n\",\n    \"```\\n\",\n    \"**Alert:** Never mix KLinearMarlin with LoRA fine-tuning—replace it with KLinearTorch (as in the example) to avoid compatibility issues!\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"93840117-084b-44fa-8b2e-6389e4a52bf0\",\n   \"metadata\": {\n    \"scrolled\": true\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"!USE_KT=1 llamafactory-cli train examples/train_lora/deepseek2_lora_sft_kt.yaml\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"c6d0b4db-65f7-4683-88d0-3269c962224c\",\n   \"metadata\": {},\n   \"source\": [\n    \"## 2. Chat with the Fine-tuned Model\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"fdbc5e95-9567-4b8a-94d7-eec410d94a6b\",\n   \"metadata\": {},\n   \"source\": [\n    \"After completing fine-tuning, the next critical step is to test your customized model through real-time interaction. 
Running `llamafactory-cli chat examples/inference/deepseek2_lora_sft_kt.yaml` loads the base model and your fine-tuned LoRA adapter. Below’s a detailed guide to customizing the chat process, with clear explanations of each setting’s role and how to fit it to your specific tasks.\\n\",\n    \"\\n\",\n    \"### Full example LLaMA-Factory YAML for inference\\n\",\n    \"```yaml\\n\",\n    \"model_name_or_path: deepseek-ai/DeepSeek-V2-Lite\\n\",\n    \"adapter_name_or_path: saves/Kllama_deepseekV2Lite\\n\",\n    \"template: deepseek\\n\",\n    \"infer_backend: ktransformers  # choices: [huggingface, vllm, sglang, ktransformers]\\n\",\n    \"trust_remote_code: true\\n\",\n    \"\\n\",\n    \"use_kt: true # use KTransformers as LoRA sft backend to inference\\n\",\n    \"kt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V2-Chat-sft-amx.yaml\\n\",\n    \"cpu_infer: 32\\n\",\n    \"chunk_size: 8192\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"---\\n\",\n    \"\\n\",\n    \"### A. Load Your Fine-Tuned Adapter (Two Supported Formats)  \\n\",\n    \"The `adapter_name_or_path` setting in LLaMA-Factory YAML points to your trained LoRA weights. Two formats are supported:  \\n\",\n    \"- **Folder Format (Default)**: If training saved a folder (e.g., `saves/Kllama_deepseekV2`) with `.safetensors` files, set it directly (e.g., `adapter_name_or_path: saves/Kllama_deepseekV2`).  \\n\",\n    \"- **GGUF Format (Single File)**: If you exported the adapter to a `.gguf` file (for portability), set the full path (e.g., `adapter_name_or_path: saves/my_adapter.gguf`).  \\n\",\n    \"\\n\",\n    \"---\\n\",\n    \"\\n\",\n    \"### B. Tweak Response Quality (Generation Configs)  \\n\",\n    \"Optional generation parameters let you adjust the model’s responses to fit specific use cases, whether you need factual accuracy, creative expression, or concise answers. 
Add these to your YAML and modify based on your needs:\\n\",\n    \"```yaml\\n\",\n    \"# Optional generation configs (add to your inference YAML)\\n\",\n    \"max_new_tokens: 1024  # Max length of responses (512 = short, 2048 = long)\\n\",\n    \"temperature: 0.7      # Randomness (0.1 = factual/consistent, 1.0 = creative/diverse)\\n\",\n    \"top_p: 0.9            # Focus (0.8-0.95 = avoids irrelevant content)\\n\",\n    \"repetition_penalty: 1.1  # Reduces repetition (1.0 = no penalty, 1.2 = strict)\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"---\\n\",\n    \"\\n\",\n    \"### C. KTransformers Inference Backend  \\n\",\n    \"The KTransformers-related settings directly impact inference performance. They must align with your training configuration to keep the optimization effects (e.g., low memory usage, fast speed):\\n\",\n    \"- `infer_backend`: Determines how the model generates responses. Choose `ktransformers` if you LoRA fine-tuned the model with KTransformers.\\n\",\n    \"- `use_kt: true`: Must match training; setting it to `false` disables KT optimization (slower inference!).  \\n\",\n    \"- `kt_optimize_rule`: Use the **exact same file** as training (e.g., `DeepSeek-V2-Lite-Chat-sft-amx.yaml`) so that layers map correctly.  \\n\",\n    \"\\n\",\n    \"---\\n\",\n    \"\\n\",\n    \"### How to Verify Inference Works\\n\",\n    \"After launching the chat command, check the logs for these key messages to confirm the model is running correctly:\\n\",\n    \"1. `Loaded adapter weight: XXX -> XXX`: LoRA adapter is loaded correctly.  \\n\",\n    \"2. `KTransformers inference enabled`: KT optimization is active.  \\n\",\n    \"3. `Backend: AMXInt8`: AMX acceleration is working (if supported).  
\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"c08b31f7-32a4-4d51-b6c0-d063d7785371\",\n   \"metadata\": {\n    \"scrolled\": true\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"!llamafactory-cli chat examples/inference/deepseek2_lora_sft_kt.yaml\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"KNllama\",\n   \"language\": \"python\",\n   \"name\": \"knllama\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.12.12\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "doc/en/MiniMax-M2.5.md",
    "content": "# Running MiniMax-M2.5 with SGLang and KT-Kernel\n\nThis tutorial demonstrates how to run MiniMax-M2.5 model inference using SGLang integrated with KT-Kernel for CPU-GPU heterogeneous inference. This setup enables efficient deployment of large MoE models by offloading experts to CPU.\n\n## Table of Contents\n\n- [Hardware Requirements](#hardware-requirements)\n- [Prerequisites](#prerequisites)\n- [Step 1: Download Model Weights](#step-1-download-model-weights)\n- [Step 2: Launch SGLang Server](#step-2-launch-sglang-server)\n- [Step 3: Send Inference Requests](#step-3-send-inference-requests)\n\n## Hardware Requirements\n\n**Minimum Configuration:**\n- **GPU**: 2x NVIDIA RTX 4090 (or equivalent with at least 48GB total VRAM available)\n- **CPU**: x86 CPU with AVX512BF16 support (e.g., Intel Sapphire Rapids)\n- **RAM**: At least 200GB system memory\n- **Storage**: ~200GB for model weights (FP8 weights; the same weight folder is used for CPU and GPU)\n\n## Prerequisites\n\nBefore starting, ensure you have:\n\n1. **KT-Kernel installed**:\n\n```bash\ngit clone https://github.com/kvcache-ai/ktransformers.git\ncd ktransformers\ngit submodule update --init --recursive\ncd kt-kernel && ./install.sh\n```\n\n2. **SGLang installed** - Install the kvcache-ai fork of SGLang (one of):\n\n```bash\n# Option A: One-click install (from ktransformers root)\n./install.sh\n\n# Option B: pip install\npip install sglang-kt\n```\n\n> Note: You may need to reinstall cudnn: `pip install nvidia-cudnn-cu12==9.16.0.29`\n\n3. **CUDA toolkit** - Compatible with your GPU (CUDA 12.8+ recommended)\n4. 
**Hugging Face CLI** - For downloading models:\n\n   ```bash\n   pip install huggingface-hub\n   ```\n\n## Step 1: Download Model Weights\n\n```bash\n# Create a directory for models\nmkdir -p /path/to/models\ncd /path/to/models\n\n# Download MiniMax-M2.5 (FP8 for both CPU and GPU)\nhuggingface-cli download MiniMaxAI/MiniMax-M2.5 \\\n  --local-dir /path/to/minimax-m2.5\n```\n\n**Note:** Replace `/path/to/models` with your actual storage path throughout this tutorial.\n\n## Step 2: Launch SGLang Server\n\nStart the SGLang server with KT-Kernel integration for CPU-GPU heterogeneous inference.\n\n\n### Launch Command (4x RTX 4090 Example)\n\n```bash\npython -m sglang.launch_server \\\n  --host 0.0.0.0 \\\n  --port 30005 \\\n  --model /path/to/minimax-m2.5 \\\n  --kt-weight-path /path/to/minimax-m2.5 \\\n  --kt-cpuinfer 96 \\\n  --kt-threadpool-count 2 \\\n  --kt-num-gpu-experts 30 \\\n  --kt-method FP8 \\\n  --kt-gpu-prefill-token-threshold 400 \\\n  --trust-remote-code \\\n  --mem-fraction-static 0.94 \\\n  --served-model-name MiniMax-M2.5 \\\n  --enable-mixed-chunk \\\n  --tensor-parallel-size 4 \\\n  --enable-p2p-check \\\n  --disable-shared-experts-fusion \\\n  --chunked-prefill-size 32658 \\\n  --max-total-tokens 50000 \\\n  --attention-backend flashinfer\n```\n\nIt takes about 2~3 minutes to start the server.\n\nSee [KT-Kernel Parameters](https://github.com/kvcache-ai/ktransformers/tree/main/kt-kernel#kt-kernel-parameters) for detailed parameter tuning guidelines.\n\n## Step 3: Send Inference Requests\n\nOnce the server is running, you can send inference requests using the OpenAI-compatible API.\n\n### Basic Chat Completion Request\n\n```bash\ncurl -s http://localhost:30005/v1/chat/completions \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"model\": \"MiniMax-M2.5\",\n    \"stream\": false,\n    \"messages\": [\n      {\"role\": \"user\", \"content\": \"hi, who are you?\"}\n    ]\n  }'\n```\n\n### Example Response\n\n```json\n{\n    \"id\": 
\"e82360a51dd4465281a2b954d5237a06\",\n    \"object\": \"chat.completion\",\n    \"created\": 1770980318,\n    \"model\": \"MiniMax-M2.5\",\n    \"choices\": [\n        {\n            \"index\": 0,\n            \"message\": {\n                \"role\": \"assistant\",\n                \"content\": \"The user is asking who I am. I should give a brief, friendly introduction about myself.\\n</think>\\n\\nHi there! I'm MiniMax-M2.5, an AI assistant created by MiniMax. I'm here to help you with a wide range of tasks, including:\\n\\n- Answering questions\\n- Writing and editing code\\n- Explaining concepts\\n- Brainstorming ideas\\n- And much more!\\n\\nHow can I help you today?\",\n                \"reasoning_content\": null,\n                \"tool_calls\": null\n            },\n            \"logprobs\": null,\n            \"finish_reason\": \"stop\",\n            \"matched_stop\": 200020\n        }\n    ],\n    \"usage\": {\n        \"prompt_tokens\": 44,\n        \"total_tokens\": 138,\n        \"completion_tokens\": 94,\n        \"prompt_tokens_details\": null,\n        \"reasoning_tokens\": 0\n    },\n    \"metadata\": {\n        \"weight_version\": \"default\"\n    }\n}\n```\n"
  },
  {
    "path": "doc/en/Qwen3-Next.md",
    "content": "# Qwen3-Next Support for KTransformers\n\n## Introduction\n\n### Overview\nWe are very pleased to announce that KTransformers now supports Qwen3-Next-80B-A3B-Thinking and Qwen3-Next-80B-A3B-Instruct.\n\n### Model & Resource Links\n\n- Official Qwen3-Next-80B-A3B-Thinking Release:\n  - https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Thinking\n\n- Official Qwen3-Next-80B-A3B-Instruct Release:\n  - https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct\n\n\n## Installation Guide\n\n### 1. Resource Requirements\n\nRunning the model with 512 experts requires approximately 320 GB of system memory and 6 GB of GPU memory.\n\n### 2. Prepare Models\n\n```bash\n# download the model weights\nhuggingface-cli download --resume-download Qwen/Qwen3-Next-80B-A3B-Instruct\n```\n\n### 3. Install KTransformers\n\nTo install KTransformers, follow the official [Installation Guide](https://kvcache-ai.github.io/ktransformers/en/install.html).\n\n### 4. Run Qwen3-Next Inference Server\n\n```bash\npython ktransformers/server/main.py \\\n  --port 10021 \\\n  --model_path path-to-Qwen3-Next-80B-A3B-Thinking \\\n  --gguf_path path-to-Qwen3-Next-80B-A3B-Thinking \\\n  --model_name Qwen3NextForCausalLM \\\n  --optimize_config_path <local_path>/ktransformers/optimize/optimize_rules/Qwen3Next-serve.yaml \\\n  --max_new_tokens 1024 \\\n  --cache_lens 32768 \\\n  --chunk_size 256 \\\n  --max_batch_size 4 \\\n  --no-use_cuda_graph \\\n  --backend_type balance_serve\n```\n\n### 5. Access the Server\n\n```bash\ncurl -X POST http://localhost:10021/v1/chat/completions \\\n  -H \"accept: application/json\" \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"messages\": [\n      {\"role\": \"user\", \"content\": \"hello\"}\n    ],\n    \"model\": \"Qwen3-Next-80B-A3B-Instruct\",\n    \"temperature\": 0.3,\n    \"top_p\": 1.0,\n    \"stream\": true\n  }'\n```\n\n### 6. Notes\n\nBecause Qwen3-Next uses linear attention, CUDA Graph optimization is not yet supported, but it is coming soon! 🚀"
  },
  {
    "path": "doc/en/Qwen3.5.md",
    "content": "# Running Qwen3.5 with SGLang and KT-Kernel\n\nThis tutorial demonstrates how to run Qwen3.5 (MoE-400B) model inference using SGLang integrated with KT-Kernel for CPU-GPU heterogeneous inference. This setup enables efficient deployment of large MoE models by offloading experts to CPU.\n\n## Table of Contents\n\n- [Running Qwen3.5 with SGLang and KT-Kernel](#running-qwen35-with-sglang-and-kt-kernel)\n  - [Table of Contents](#table-of-contents)\n  - [Hardware Requirements](#hardware-requirements)\n  - [Prerequisites](#prerequisites)\n  - [Step 1: Download Model Weights](#step-1-download-model-weights)\n  - [Step 2: Launch SGLang Server](#step-2-launch-sglang-server)\n    - [Launch Command (4x RTX 4090 Example)](#launch-command-4x-rtx-4090-example)\n  - [Step 3: Send Inference Requests](#step-3-send-inference-requests)\n    - [Basic Chat Completion Request](#basic-chat-completion-request)\n    - [Example Response](#example-response)\n\n## Hardware Requirements\n\n**Minimum Configuration:**\n- **GPU**: 4x NVIDIA RTX 4090 (or equivalent with at least 96GB total VRAM available)\n- **CPU**: x86 CPU with AVX512F support (e.g., Intel Sapphire Rapids)\n- **RAM**: At least 800GB system memory\n- **Storage**: ~800GB for model weights (BF16)\n\n## Prerequisites\n\nBefore starting, ensure you have:\n\n1. **KT-Kernel installed**:\n\n```bash\ngit clone https://github.com/kvcache-ai/ktransformers.git\ncd ktransformers\ngit checkout qwen3.5\ngit submodule update --init --recursive\ncd kt-kernel && ./install.sh\n```\n\n2. **SGLang installed** - Install the kvcache-ai fork of SGLang (one of):\n\n```bash\n# Option A: One-click install (from ktransformers root)\n./install.sh\n\n# Option B: pip install\npip install sglang-kt\n```\n\n> Note: You may need to reinstall cudnn: `pip install nvidia-cudnn-cu12==9.16.0.29`\n\n3. **CUDA toolkit** - Compatible with your GPU (CUDA 12.8+ recommended)\n4. 
**Hugging Face CLI** - For downloading models:\n\n   ```bash\n   pip install huggingface-hub\n   ```\n\n## Step 1: Download Model Weights\n\n```bash\n# Create a directory for models\nmkdir -p /path/to/models\ncd /path/to/models\n\n# Download Qwen3.5 (BF16)\nhuggingface-cli download Qwen/Qwen3.5 \\\n  --local-dir /path/to/qwen3.5\n```\n\n**Note:** Replace `/path/to/models` with your actual storage path throughout this tutorial.\n\n## Step 2: Launch SGLang Server\n\nStart the SGLang server with KT-Kernel integration for CPU-GPU heterogeneous inference.\n\n### Launch Command (4x RTX 4090 Example)\n\n```bash\npython -m sglang.launch_server \\\n  --host 0.0.0.0 \\\n  --port 30005 \\\n  --model /path/to/qwen3.5 \\\n  --kt-weight-path /path/to/qwen3.5 \\\n  --kt-cpuinfer 60 \\\n  --kt-threadpool-count 2 \\\n  --kt-num-gpu-experts 1 \\\n  --kt-method BF16 \\\n  --attention-backend triton \\\n  --trust-remote-code \\\n  --mem-fraction-static 0.98 \\\n  --chunked-prefill-size 4096 \\\n  --max-running-requests 32 \\\n  --max-total-tokens 32000 \\\n  --served-model-name qwen3.5 \\\n  --enable-mixed-chunk \\\n  --tensor-parallel-size 4 \\\n  --enable-p2p-check \\\n  --disable-shared-experts-fusion \\\n  --disable-custom-all-reduce\n```\n\nSee [KT-Kernel Parameters](https://github.com/kvcache-ai/ktransformers/tree/main/kt-kernel#kt-kernel-parameters) for detailed parameter tuning guidelines.\n\n## Step 3: Send Inference Requests\n\nOnce the server is running, you can send inference requests using the OpenAI-compatible API.\n\n### Basic Chat Completion Request\n\n```bash\ncurl -s http://localhost:30005/v1/chat/completions \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"model\": \"qwen3.5\",\n    \"stream\": false,\n    \"messages\": [\n      {\"role\": \"user\", \"content\": \"hi, who are you?\"}\n    ]\n  }'\n```\n\n### Example Response\n\n```json\n{\n    \"id\": \"c79f6d63e04f4874acb8853d218e1bf1\",\n    \"object\": \"chat.completion\",\n    \"created\": 
1770880035,\n    \"model\": \"qwen3.5\",\n    \"choices\": [\n        {\n            \"index\": 0,\n            \"message\": {\n                \"role\": \"assistant\",\n                \"content\": \"Hello! I'm **Qwen**, a large language model developed by **Alibaba Cloud**. I'm designed to provide helpful, accurate, and safe information across a wide range of topics—whether you have questions, need help with writing, coding, analysis, or just want to explore ideas together.\\n\\nHow can I assist *you* today?\",\n                \"reasoning_content\": null,\n                \"tool_calls\": null\n            },\n            \"logprobs\": null,\n            \"finish_reason\": \"stop\",\n            \"matched_stop\": 248046\n        }\n    ],\n    \"usage\": {\n        \"prompt_tokens\": 16,\n        \"total_tokens\": 527,\n        \"completion_tokens\": 511,\n        \"prompt_tokens_details\": null,\n        \"reasoning_tokens\": 0\n    },\n    \"metadata\": {\n        \"weight_version\": \"default\"\n    }\n}\n```\n"
  },
  {
    "path": "doc/en/ROCm.md",
    "content": "# ROCm Support for ktransformers (Beta)\n\n## Introduction\n\n### Overview\nIn our effort to expand GPU architecture support beyond NVIDIA, we are excited to introduce **AMD GPU support through ROCm** in ktransformers (Beta release). This implementation has been tested and developed using EPYC 9274F processors and AMD Radeon 7900xtx GPUs.\n\n## Installation Guide\n\n### 1. Install ROCm Driver\nBegin by installing the ROCm drivers for your AMD GPU:\n- [Official ROCm Installation Guide for Radeon GPUs](https://rocm.docs.amd.com/projects/radeon/en/latest/docs/install/native_linux/install-radeon.html)\n\n### 2. Set Up Conda Environment\nWe recommend using Miniconda3/Anaconda3 for environment management:\n\n```bash\n# Download Miniconda\nwget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh\n\n# Create environment\nconda create --name ktransformers python=3.11\nconda activate ktransformers\n\n# Install required libraries\nconda install -c conda-forge libstdcxx-ng\n\n# Verify GLIBCXX version (should include 3.4.32)\nstrings ~/anaconda3/envs/ktransformers/lib/libstdc++.so.6 | grep GLIBCXX\n```\n\n> **Note:** Adjust the Anaconda path if your installation directory differs from `~/anaconda3`\n\n### 3. Install PyTorch for ROCm\nInstall PyTorch with ROCm 6.2.4 support:\n\n```bash\npip3 install torch torchvision torchaudio \\\n  --index-url https://download.pytorch.org/whl/rocm6.2.4\npip3 install packaging ninja cpufeature numpy\n```\n\n> **Tip:** For other ROCm versions, visit [PyTorch Previous Versions](https://pytorch.org/get-started/previous-versions/)\n\n### 4. 
Build ktransformers\n\n```bash\n# Clone repository\ngit clone https://github.com/kvcache-ai/ktransformers.git\ncd ktransformers\ngit submodule update --init\n\n# Optional: Compile web interface\n# See: api/server/website.md\n\n# Install dependencies\nbash install.sh\n```\n\n## Running DeepSeek-R1 Models\n\n### Configuration for 24GB VRAM GPUs\nUse our optimized configuration for constrained VRAM:\n\n```bash\npython ktransformers/local_chat.py \\\n  --model_path deepseek-ai/DeepSeek-R1 \\\n  --gguf_path <path_to_gguf_files> \\\n  --optimize_config_path ktransformers/optimize/optimize_rules/rocm/DeepSeek-V3-Chat.yaml \\\n  --cpu_infer <cpu_cores + 1>\n```\n\n> **Beta Note:** The current Q8 linear implementation (Marlin alternative) shows suboptimal performance. Expect optimizations in future releases.\n\n### Configuration for 40GB+ VRAM GPUs\nFor better performance on high-VRAM GPUs:\n\n1. Modify `DeepSeek-V3-Chat.yaml`:\n   ```yaml\n   # Replace all instances of:\n   KLinearMarlin → KLinearTorch\n   ```\n\n2. Execute with:\n   ```bash\n   python ktransformers/local_chat.py \\\n     --model_path deepseek-ai/DeepSeek-R1 \\\n     --gguf_path <path_to_gguf_files> \\\n     --optimize_config_path <modified_yaml_path> \\\n     --cpu_infer <cpu_cores + 1>\n   ```\n\n> **Tip:** If you have 2 × 24GB AMD GPUs, you can make the same modification to `ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml` and run with that file instead.\n\n## Known Limitations\n- Marlin operations are not supported on the ROCm platform\n- The current Q8 linear implementation shows reduced performance (Beta limitation)\n"
  },
  {
    "path": "doc/en/SFT/DPO_tutorial.md",
    "content": "# DPO Training with LLaMA-Factory\n\nThis tutorial demonstrates how to use Direct Preference Optimization (DPO) to fine-tune a language model using the LLaMA-Factory framework. DPO is a method for training models based on human preferences, allowing for more aligned and user-centric outputs.\n\n## Installation\n\n### Step 1: Create a conda environment and suit it for KTransformers\n\n```Bash\nconda create -n Kllama python=3.12 # choose from : [3.11, 3.12, 3.13]\nconda install -y -c conda-forge libstdcxx-ng gcc_impl_linux-64\nconda install -y -c nvidia/label/cuda-12.8.0 cuda-runtime\n```\n\n### Step 2: Install the LLaMA-Factory environment\n\n```Bash\ngit clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git\ncd LLaMA-Factory\npip install -e \".[torch,metrics]\" --no-build-isolation\n```\n\n\n### Step 3: Install KTransformers\n#### Option 1: Install the KTransformers wheel that matches your Torch and Python versions, from https://github.com/kvcache-ai/ktransformers/releases/tag/v0.4.4\n\n(Note: The CUDA version can differ from that in the wheel filename.)\n\n```Bash\npip install ktransformers-0.4.4+cu128torch28fancy-cp312-cp312-linux_x86_64.whl\n```\n\n#### Option 2: Install KTransformers from source\n\n```Bash\ngit clone --depth 1 https://github.com/kvcache-ai/ktransformers.git\ncd ktransformers/kt-sft\nexport TORCH_CUDA_ARCH_LIST=\"8.0;8.9;9.0\" # set according to your GPU\n\npip install -r \"requirements-sft.txt\"\nKTRANSFORMERS_FORCE_BUILD=TRUE pip install -v . 
--no-build-isolation\n```\n\n### Step 4: Install the Flash-attention wheel that matches your Torch and Python versions, from: https://github.com/Dao-AILab/flash-attention/releases\n\n```Bash\n# To find the cxx11abi value (TRUE/FALSE) for the wheel filename, run in Python:\n# import torch\n# print(torch._C._GLIBCXX_USE_CXX11_ABI)\n\npip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiTRUE-cp312-cp312-linux_x86_64.whl\n```\n\n### Step 5: (Optional) Install flash_infer (otherwise the default is triton)\n\n```Bash\ngit clone https://github.com/kvcache-ai/custom_flashinfer.git\npip install custom_flashinfer/\n```\n\n## Prepare Models\n\nWe use `deepseek-ai/DeepSeek-V2-Lite` as an example here. You can replace it with other models such as Kimi K2.\n\n## How to start\n\n```Bash\n# LoRA DPO training\nUSE_KT=1 llamafactory-cli train examples/train_lora/deepseek2_lora_dpo_kt.yaml\n# Chat with the model after LoRA DPO training\nllamafactory-cli chat examples/inference/deepseek2_lora_dpo_kt.yaml\n# Serve the model through the API after LoRA DPO training\nllamafactory-cli api examples/inference/deepseek2_lora_dpo_kt.yaml\n```\n\nFor example, we provide the following YAML file:\n\n(1) examples/train_lora/deepseek2_lora_dpo_kt.yaml\n\n```YAML\n### model\nmodel_name_or_path: deepseek-ai/DeepSeek-V2-Lite\ntrust_remote_code: true\n\n### method\nstage: dpo\ndo_train: true\nfinetuning_type: lora\nlora_rank: 8\nlora_target: all\npref_beta: 0.1\npref_loss: sigmoid  # choices: [sigmoid (dpo), orpo, simpo]\n\n### dataset\ndataset: dpo_en_demo\ntemplate: llama3\ncutoff_len: 2048\nmax_samples: 1000\noverwrite_cache: true\npreprocessing_num_workers: 16\ndataloader_num_workers: 4\n\n### output\noutput_dir: saves/Kllama_deepseekV2_DPO\nlogging_steps: 10\nsave_steps: 500\nplot_loss: true\noverwrite_output_dir: true\nsave_only_model: false\nreport_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]\n\n### train\nper_device_train_batch_size: 1\ngradient_accumulation_steps: 
8\nlearning_rate: 5.0e-6\nnum_train_epochs: 3\nlr_scheduler_type: cosine\nwarmup_ratio: 0.1\nbf16: true\nddp_timeout: 180000000\nresume_from_checkpoint: null\n\n### ktransformers\nuse_kt: true # use KTransformers as the LoRA fine-tuning backend\nkt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml\ncpu_infer: 64\nchunk_size: 8192\n```\n\nFor more details about `kt_optimize_rule`, please refer to https://github.com/kvcache-ai/ktransformers/blob/main/doc/en/KTransformers-Fine-Tuning_User-Guide.md\n\nYou can then use the LoRA adapter saved in `saves/Kllama_deepseekV2_DPO` for inference, in the same way as after SFT training. For example:\n\n```YAML\nmodel_name_or_path: DeepSeek-V2-Lite-Chat\nadapter_name_or_path: saves/Kllama_deepseekV2_DPO\ntemplate: deepseek\ninfer_backend: ktransformers  # choices: [huggingface, vllm, sglang, ktransformers]\ntrust_remote_code: true\n\nuse_kt: true # use KTransformers as the inference backend for the LoRA model\nkt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml\ncpu_infer: 32\nchunk_size: 8192\n```\n"
  },
  {
    "path": "doc/en/SFT/KTransformers-Fine-Tuning_Developer-Technical-Notes.md",
    "content": "- [Introduction](#introduction)\n- [Overall View of the KT Fine-Tuning Framework](#overall-view-of-the-kt-fine-tuning-framework)\n  - [Attention (LoRA + KT coexist)](#attention-lora--kt-coexist)\n  - [MoE (operator encapsulation + backward)](#moe-operator-encapsulation--backward)\n  - [Multi-GPU Loading/Training: Placement strategy instead of DataParallel](#multi-gpu-loadingtraining-placement-strategy-instead-of-dataparallel)\n- [KT-LoRA Fine-Tuning Evaluation](#kt-lora-fine-tuning-evaluation)\n  - [Setup](#setup)\n  - [Results](#results)\n  - [Speed Tests](#speed-tests)\n  - [Memory Footprint](#memory-footprint)\n- [Conclusion](#conclusion)\n\n\n# KTransformers Fine-Tuning × LLaMA-Factory Integration – Developer Technical Notes\n\n**MadSys Lab, KVCache-AI Team, Approaching AI, LLaMA-Factory Team**\n\n## Introduction\n\nRecent open-source LLMs—from DeepSeek-V3/R1 to Qwen-MoE and Kimi-K2—have surged in performance and scale. Yet due to **compute and memory constraints**, it is difficult for typical researchers to fine-tune trillion-parameter-class models. We therefore integrate **KTransformers** with **LLaMA-Factory** so that, with **2–4 RTX 4090 GPUs** and sufficient CPU memory, one can fine-tune ultra-large Mixture-of-Experts (MoE) models such as DeepSeek-671B.\n\nThis architecture bridges resource gaps, enabling **local fine-tuning of ultra-large models**, while also supporting **efficient scenario customization** at 14B/30B scales. 
We validate on stylized dialogue, Westernized translation tone, and medical Q&A, achieving rapid adaptation within hours.\n\nArchitecturally, LLaMA-Factory orchestrates data/config/training, LoRA insertion, and inference; KTransformers is a pluggable, high-performance operator backend that takes over Attention and MoE under the same training code, enabling **GPU+CPU heterogeneity** to accelerate training and reduce GPU memory.\n\n![image-20251011010558909](../../assets/image-20251011010558909.png)\n\nWe evaluated LoRA fine-tuning with HuggingFace default, Unsloth, and KTransformers backends (same settings and data). **KTransformers** is currently the only solution feasible on **2–4×24GB 4090s** for **671B-scale MoE**, and also shows higher throughput and lower GPU memory for 14B MoEs.\n\n| Under LoRA (BF16) + [NekoQA-10K stylized dialogue](https://github.com/mindsRiverPonder/LLM-practice) | HuggingFace Backend                      | Unsloth Backend                      | KTransformers Backend |\n| ------------------------------------------------------------ | ---------------------------------------- | ------------------------------------ | --------------------- |\n| [14B-DeepSeekV2-Lite] LoRA fine-tuning throughput            | 303.58 token/s                           | 455.37 token/s                       | 530.38 token/s        |\n| [14B-DeepSeekV2-Lite] GPU memory                             | 32.12 GB                                 | 9.64 GB                              | 6.08 GB               |\n| [671B-DeepSeekV3] LoRA fine-tuning throughput                | <font color='red'>Too Huge to run</font> | <font color='red'>NOT SUPPORT</font> | 40.35 token/s         |\n| [671B-DeepSeekV3] GPU memory (sum across GPUs)               | theoretical 1400 GB †                    | <font color='red'>NOT SUPPORT</font> | 70 GB †               |\n\n† The **1400 GB** is the **theoretical** FP16 full-resident footprint (not runnable). 
**70 GB** is the **measured peak** with KT (Attention on GPU + layered MoE offload).\n\nThe table above shows that, for the 14B model, the KTransformers backend achieves approximately 75% higher throughput than the default HuggingFace solution (530.38 vs. 303.58 token/s) while using only about one-fifth of the GPU memory (6.08 GB vs. 32.12 GB). For the 671B model, neither HuggingFace nor Unsloth can run on 2–4 RTX 4090 GPUs, whereas KTransformers performs LoRA fine-tuning at about 40 token/s while keeping total GPU memory usage within 70 GB.\n\n![Comparison chart by model](../../assets/image-compare_model.png)\n\n\n\n## Overall View of the KT Fine-Tuning Framework\n\nWe detail how KTransformers takes over core operators in LLaMA-Factory’s fine-tuning framework to optimize Attention and MoE.\n\nDeepSeek-V3/V2 MoE models comprise a small-parameter dense Attention part and a large-parameter sparse MoE part. For illustration, consider layer 2 of DeepSeek-V2-Lite-Chat (from this layer onward, each layer includes both Attention and MoE). Attention compute and KV cache mainly reside on the GPU; the heavyweight MoE part is primarily executed on the CPU. We first cover **Attention replacement and inheritance**, then **MoE encapsulation and backend interfacing**, and finally **multi-GPU placement**.\n\n### Attention (LoRA + KT coexist)\n\nKTransformers provides operator injection (`BaseInjectedModule`), and PEFT provides LoRA layer insertion. 
For fine-tuning, we design `KTransformersLinearLora`, inheriting from both `KTransformersLinear` and `LoraLayer`:\n\n- **Inheritance:** `KTransformersLinearLora` retains KT’s high-performance paths (`prefill_linear`/`generate_linear`) while accepting LoRA parameters (`lora_A/lora_B`).\n- **Replacement:** During preparation, we replace original `KTransformersLinear` layers (Q/K/V/O) with `KTransformersLinearLora`, preserving KT optimizations while enabling LoRA trainability.\n\n![image-20251016182810716](../../assets/image-20251016182810716.png)\n\nAfter replacement, LoRA is inserted at Q/K/V/O linear transforms (left), and `KTransformersLinearLora` contains both KT fast paths and LoRA matrices (right).\n\n![image-20251016182920722](../../assets/image-20251016182920722.png)\n\n### MoE (operator encapsulation + backward)\n\n#### Encapsulation\n\nGiven large parameters and sparse compute, we encapsulate the expert computation as a **differentiable black-box operator**—transparent upstream, replaceable downstream.\n\n- **Upstream (PyTorch graph):** we register a custom Autograd Function so the MoE layer appears as **a single node**. In the left figure (red box), only `KSFTExpertsCPU` is visible; on the right, the unencapsulated graph expands routing, dispatch, and FFN experts. Encapsulation makes the MoE layer behave like a standard `nn.Module` with gradients.\n- **Downstream (backend):** inside the Autograd Function, pybind11 calls C++ extensions for forward/backward. Multiple **pluggable backends** exist (AMX BF16/INT8; **llamafile**). The backend can be switched via YAML (e.g., `\"backend\": \"AMXBF16\"` vs. `\"llamafile\"`).\n\n![image-20250801174623919](../../assets/image-20250801174623919.png)\n\n#### Backward (CPU)\n\nMoE backward frequently needs the transposed weights $W^\\top$. To avoid repeated runtime transposes, we **precompute/cache** $W^\\top$ at load time (blue box). 
We also **cache necessary intermediate activations** (e.g., expert projections, red box) to reuse in backward and reduce recomputation. We provide backward implementations for **llamafile** and **AMX (INT8/BF16)**, with NUMA-aware optimizations.\n\n<img src=\"../../assets/image-20251016182942726.png\" alt=\"image-20251016182942726\" style=\"zoom:33%;\" />\n\n### Multi-GPU Loading/Training: Placement strategy instead of DataParallel\n\nTo lower **per-GPU memory peaks** on 2–4 GPUs, we use **model parallelism + explicit placement**, not DataParallel (which duplicates the whole model on each GPU).\n\nKey changes:\n\n1. **KTrainer:** takes over `.to(device)` to prevent “move whole model to a single GPU”. Using KT’s optimize-rule YAML, each layer declares `device: cuda:0/cuda:1/...` and is **constructed directly on the target GPU** (no extra copies).\n2. **Disable automatic DataParallel:** when `USE_KT=1`, we disable automatic DP wrappers from LLaMA-Factory/HF Trainer to avoid duplication and keep full control over sharding.\n3. **Gradient aggregation:** gradients are reduced to `cuda:0`. Intermediate activations stay local; only necessary tensors are transferred, cutting communication/activation overhead.\n\nThus, we keep KT placement strategies under multi-GPU fine-tuning. Users choose a `kt_optimize_rule` with `multi-gpu`. For DeepSeek-671B, `DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml` is a typical 2-GPU plan: KV/attention parts on each GPU; MoE experts sharded on CPU; both GPUs share the workload.\n\n\n\n## KT-LoRA Fine-Tuning Evaluation\n\n### Setup\n\nLLaMA-Factory orchestration, KTransformers backend, LoRA (rank=8, α=32, dropout=0.1, BF16), `GAS=16`, `qlen=512`, with the same KT optimize rule as training. We evaluate (a) stylized dialogue transfer and (b) two **small-scale representative** benchmarks: Translational-Style (generative) and AfriMed-QA (medical vertical; **SAQ** and **MCQ**). 
AMX is enabled; GPUs: 2×48GB RTX 4090; CPU: Intel Xeon Platinum 8488C.\n\n### Results\n\n#### Stylized Dialogue (CatGirl tone)\n\nDataset: [NekoQA-10K](https://zhuanlan.zhihu.com/p/1934983798233231689). The fine-tuned model consistently exhibits the target style (red boxes) versus neutral/rational base (blue). This shows **KT-LoRA injects style features** into the generation distribution with low GPU cost.\n\n![image-20251016175848143](../../assets/image-20251016175848143.png)\n\n#### Translational-Style benchmark (generative)\n\nDataset: [Translational-Style-ChatLLM](https://github.com/Benson114/Translational-Style-ChatLLM). Metrics: BLEU-1/2/3/4, ROUGE-1/2/L.\n\n| Translational-Style dataset    | BLEU-1    | BLEU-2    | BLEU-3    | BLEU-4    | ROUGE-1   | ROUGE-2   | ROUGE-L   |\n| ------------------------------ | --------- | --------- | --------- | --------- | --------- | --------- | --------- |\n| V2-Lite (no LoRA)              | 20.66     | 8.33      | 4.54      | 2.89      | 22.71     | 4.52      | 19.19     |\n| **KT-LoRA fine-tuned V2-Lite** | **35.41** | **22.44** | **15.42** | **11.18** | **42.03** | **18.38** | **33.10** |\n| V3 base (no LoRA)              | 8.49      | 3.34      | 1.62      | 0.96      | 15.91     | 2.55      | 10.07     |\n| **KT-LoRA fine-tuned V3**      | **37.02** | **23.70** | **16.21** | **11.49** | **43.43** | **18.96** | **34.54** |\n\nAs shown by the test results in the tables above, under a unified workflow and placement strategy, **both model scales exhibit consistent gains after fine-tuning**, supporting the usability and effectiveness of the “KT backend + LoRA fine-tuning” combination for generative style control. 
At the same time, this indicates that KT’s heterogeneous placement and operator optimizations can stably support small-sample adaptation in the style domain.\n\n#### Medical Vertical Benchmark (AfriMed-SAQ/MCQ)\n\nThe dataset adopts [AfriMed-QA](https://aclanthology.org/2025.acl-long.96/) (ACL 2025), a domain-specific dataset for the medical field in Africa with strong scenario customization characteristics, comprising two formats—multiple-choice questions (MCQ) and short-answer questions (SAQ)—which in this case serve as the evaluation for vertical-domain fine-tuning. In terms of evaluation criteria, BLEU/ROUGE are used for SAQ, and Accuracy is used for MCQ.\n\n| AfriMed-QA (SAQ)               | BLEU-1    | BLEU-2    | BLEU-3    | BLEU-4    | ROUGE-1   | ROUGE-2   | ROUGE-L   |\n| ------------------------------ | --------- | --------- | --------- | --------- | --------- | --------- | --------- |\n| V2-Lite (no LoRA)              | 13.58     | 11.12     | 9.10      | 7.23      | 22.48     | 7.81      | 11.73     |\n| **KT-LoRA fine-tuned V2-Lite** | **35.90** | **27.63** | **22.99** | **19.15** | **35.25** | **17.50** | **28.44** |\n| V3 base (no LoRA)              | 12.75     | 10.27     | 8.05      | 5.99      | 20.33     | 5.65      | 10.11     |\n| **KT-LoRA fine-tuned V3**      | **42.42** | **34.12** | **28.95** | **24.54** | **41.97** | **22.37** | **33.28** |\n\n| AfriMed-QA (MCQ)               | Accuracy   |\n| ------------------------------ | ---------- |\n| V2-Lite (no LoRA)              | 0.0645     |\n| **KT-LoRA fine-tuned V2-Lite** | **0.4812** |\n| V3 base (no LoRA)              | 0.5833     |\n| **KT-LoRA fine-tuned V3**      | **0.7930** |\n\nAs shown in the tables above, (1) DeepSeek-V3 (671B) after KT-LoRA fine-tuning achieves clearly higher performance than the fine-tuned DeepSeek-V2-Lite (14B) on both MCQ and SAQ, and it also surpasses the V3 base model. 
Within our small-scale setting, this preliminarily indicates that KT-LoRA fine-tuning of ultra-large-parameter models has practical significance in vertical domains.\n\n(2) Across both SAQ/MCQ sub-tasks, KT-LoRA delivers consistent gains, indicating that—with KT’s heterogeneous placement and backend operator support—LoRA fine-tuning can effectively inject the key knowledge points of vertical domains such as medicine into the model.\n\n#### Limitations\n\nAt present, most of our testing is conducted on **single datasets** and at **small scale** (≤ 20k examples), with the goal of providing **existence evidence of system effectiveness for KT-LoRA fine-tuning**, rather than drawing generalized conclusions about algorithmic generalization or scaling laws. Our report primarily presents representative figures; to support stronger algorithmic claims, larger sample sizes, multi-lingual/multi-domain datasets, and multi-seed repeated experiments would be required—these are beyond the scope of this work.\n\n**We also warmly welcome everyone to join the open-source LLaMA-Factory KT fine-tuning project. 
If you have additional test results, we especially welcome you to record them in the shared spreadsheet below, and to include the corresponding `kt_optimize_rule` files, dataset examples, training/evaluation YAMLs, and detailed GPU-memory and CPU configurations for community reference and reproducibility!**\n\n\n\n### Speed Tests\n\n#### End-to-End Performance\n\n**Definitions**\n\n`step_time`: time per optimization step (tensor movement + Attention + MoE + others).\n\n`tokens_per_step = GAS × qlen`; `token/s = tokens_per_step / step_time`. We use `GAS=16`, `qlen=512` → `tokens_per_step=8192`.\n\n**Measured**\n\n| Model                | step_time (s) | tokens/step | token/s   |\n| -------------------- | ------------- | ----------- | --------- |\n| DeepSeek-V3-671B     | 203           | 8192        | **40.35** |\n| DeepSeek-V2-Lite-14B | 36            | 8192        | **227.6** |\n\n#### MoE Compute (DeepSeek-V3-671B)\n\n**Theory**\n\n- MoE per-layer, per-token FLOPs (forward+backward) approx.:\n  $$\n  \\text{FLOPs}_{\\text{per-layer, per-token}} \\approx c \\cdot k \\cdot H \\cdot I\n  $$\n\n  with $k = 8$ (Top-k), $H = 7168$ (hidden size), $I = 2048$ (intermediate size), $c\\approx16$ (≈6 forward + ≈10 backward matmuls).\n\n- Per-step across all MoE layers:\n  $$\n  \\text{FLOPs}_{\\text{per-step}} \\approx c \\cdot qlen \\cdot k \\cdot H \\cdot I \\cdot L_{\\text{MoE}}\n  $$\n\n  Plugging in $c=16$, $qlen=512$, $k=8$, $H=7168$, $I=2048$, $L_{\\text{MoE}}=58$ gives $\\text{FLOPs}_{\\text{per-step}} \\approx 55.8\\ \\text{TFLOPs}$.\n\n**Measured (MoE TFLOPS on CPU)**\n\nIf the **MoE-only** time per step is `t_moe` (seconds), $\\text{TFLOPS} = \\text{FLOPs}_{\\text{per-step}} / (10^{12} \\cdot t_{\\text{moe}}).$\n\nUse MoE-phase time, not full `step_time`, to get MoE throughput.\n\n| TFLOPS  | Forward | Backward |\n| ------- | ------- | -------- |\n| Average | 17.55   | 18.41    |\n\n### Memory Footprint\n\n- DeepSeek-V3 (671B; 58 MoE layers out of 61): ~**70 GB** total GPU, ~**1.2–1.3 TB** host 
memory.\n- DeepSeek-V2-Lite (14B; 26 MoE layers out of 27): ~**5 GB** GPU, ~**30 GB** host memory.\n\n\n\n## Conclusion\n\nIntegrating **KTransformers LoRA** with **LLaMA-Factory** provides a practical path to efficiently train and deploy MoE LLMs. KT contributes placement strategies and operator optimizations (DeepSeek/Qwen/Kimi support with AMX-accelerated kernels), and LoRA enables customization with very low GPU memory; LLaMA-Factory supplies a coherent user-level interface.\n\nThis means even tens-to-hundreds-of-billion-parameter MoE models can be fine-tuned and served with low latency on ordinary hardware. The approach balances **memory savings**, **speed**, and **usability**, turning ultra-large models into tools that developers can actually wield."
  },
  {
    "path": "doc/en/SFT/KTransformers-Fine-Tuning_User-Guide.md",
    "content": "- [Introduction](#introduction)\n  - [Fine-Tuning Results (Examples)](#fine-tuning-results-examples)\n- [Quick to Start](#quick-to-start)\n  - [Environment Setup](#environment-setup)\n  - [Core Feature 1: Use KTransformers backend to fine-tune ultra-large MoE models](#core-feature-1-use-ktransformers-backend-to-fine-tune-ultra-large-moe-models)\n  - [Core Feature 2: Chat with the fine-tuned model (base + LoRA adapter)](#core-feature-2-chat-with-the-fine-tuned-model-base--lora-adapter)\n  - [Core Feature 3: Batch inference + metrics (base + LoRA adapter)](#core-feature-3-batch-inference--metrics-base--lora-adapter)\n- [KT Fine-Tuning Speed (User-Side View)](#kt-fine-tuning-speed-user-side-view)\n  - [End-to-End Performance](#end-to-end-performance)\n  - [GPU/CPU Memory Footprint](#gpucpu-memory-footprint)\n- [Conclusion](#conclusion)\n\n\n# KTransformers Fine-Tuning × LLaMA-Factory Integration – User Guide\n\n**MadSys Lab, KVCache-AI Team, Approaching AI, LLaMA-Factory Team**\n\n## Introduction\n\nFrom **DeepSeek-V3/R1** to **Qwen3-MoE** and **Kimi-K2**, each wave of open-sourced large models brings leaps in performance and scale. However, many researchers and developers are constrained by expensive GPUs and models with tens or even hundreds of billions of parameters, making it **hard to fine-tune very large models under limited resources**. To bridge this gap, we propose a practical approach: combining **KTransformers** with **LLaMA-Factory**. With just **2–4 RTX 4090s** and a high-memory CPU, you can fine-tune ultra-large MoE models like DeepSeek-671B.\n\nOur goal is to give resource-constrained researchers a **local path to explore fine-tuning ultra-large models**, and also a fast way to customize smaller models (e.g., 14B/30B) for specific scenarios. 
We validate the setup using **stylized dialogue**, **Westernized translation tone**, and **medical Q&A** as representative tasks, showing that **personalized adaptation can be achieved within hours**.\n\nAs shown below, LLaMA-Factory is the unified orchestration/configuration layer for the whole fine-tuning workflow—handling data, training scheduling, LoRA injection, and inference interfaces. **KTransformers** acts as a pluggable high-performance backend that takes over core operators like Attention/MoE under the same training configs, enabling efficient **GPU+CPU heterogeneous cooperation**.\n\n![image-20251011010558909](../../assets/image-20251011010558909.png)\n\nWithin LLaMA-Factory, we compared LoRA fine-tuning with **HuggingFace**, **Unsloth**, and **KTransformers** backends. KTransformers is the **only workable 4090-class solution** for ultra-large MoE models (e.g., 671B) and also delivers higher throughput and lower GPU memory on smaller MoE models (e.g., DeepSeek-14B).\n\n| Under LoRA (BF16) + [NekoQA-10K stylized dialogue](https://github.com/mindsRiverPonder/LLM-practice) | HuggingFace Backend                      | Unsloth Backend                      | KTransformers Backend |\n| ------------------------------------------------------------ | ---------------------------------------- | ------------------------------------ | --------------------- |\n| [14B-DeepSeekV2-Lite] LoRA fine-tuning throughput            | 303.58 token/s                           | 455.37 token/s                       | 530.38 token/s        |\n| [14B-DeepSeekV2-Lite] GPU memory                             | 32.12 GB                                 | 9.64 GB                              | 6.08 GB               |\n| [671B-DeepSeekV3] LoRA fine-tuning throughput                | <font color='red'>Too Huge to run</font> | <font color='red'>NOT SUPPORT</font> | 40.35 token/s         |\n| [671B-DeepSeekV3] GPU memory (sum across GPUs)               | theoretical 1400 GB †                  
  | <font color='red'>NOT SUPPORT</font> | 70 GB †               |\n\n† **1400 GB** is a **theoretical** FP16 full-parameter resident footprint (not runnable). **70 GB** is the **measured peak** with KT strategy (Attention on GPU + layered MoE offload).\n\n![按照模型划分的对比图_02](../../assets/image-compare_model.png)\n\n### Fine-Tuning Results (Examples)\n\n#### Stylized Dialogue (CatGirl tone)\n\nDataset: [NekoQA-10K](https://zhuanlan.zhihu.com/p/1934983798233231689). Goal: improve style consistency and recognizability.\n\nThe figure compares responses from the base vs. fine-tuned models. The fine-tuned model maintains the target tone and address terms more consistently (red boxes), validating the effectiveness of **style-transfer fine-tuning**.\n\n![image-20251016175046882](../../assets/image-20251016175046882.png)\n\n#### Benchmarks\n\nWe use:\n\n(1) [Translational-Style-ChatLLM](https://github.com/Benson114/Translational-Style-ChatLLM), which asks for an exaggerated, Westernized translation tone—clear, stylized customization.\n\n(2) [AfriMed-QA](https://aclanthology.org/2025.acl-long.96/) (ACL 2025), a medical dataset for African contexts with strong domain specificity, including multiple-choice and short-answer sub-tasks—well-suited for vertical fine-tuning evaluation.\n\nThe tables show metrics before vs. after LoRA fine-tuning. 
We observe **large improvements** across metrics, verifying fine-tuning effectiveness:\n\n| Translational-Style dataset    | BLEU-1    | BLEU-2    | BLEU-3    | BLEU-4    | ROUGE-1   | ROUGE-2   | ROUGE-L   |\n| ------------------------------ | --------- | --------- | --------- | --------- | --------- | --------- | --------- |\n| V2-Lite (no LoRA)              | 20.66     | 8.33      | 4.54      | 2.89      | 22.71     | 4.52      | 19.19     |\n| **KT-LoRA fine-tuned V2-Lite** | **35.41** | **22.44** | **15.42** | **11.18** | **42.03** | **18.38** | **33.10** |\n| V3 base (no LoRA)              | 8.49      | 3.34      | 1.62      | 0.96      | 15.91     | 2.55      | 10.07     |\n| **KT-LoRA fine-tuned V3**      | **37.02** | **23.70** | **16.21** | **11.49** | **43.43** | **18.96** | **34.54** |\n\n| AfriMed-QA (short answer)      | BLEU-1    | BLEU-2    | BLEU-3    | BLEU-4    | ROUGE-1   | ROUGE-2   | ROUGE-L   |\n| ------------------------------ | --------- | --------- | --------- | --------- | --------- | --------- | --------- |\n| V2-Lite (no LoRA)              | 13.58     | 11.12     | 9.10      | 7.23      | 22.48     | 7.81      | 11.73     |\n| **KT-LoRA fine-tuned V2-Lite** | **35.90** | **27.63** | **22.99** | **19.15** | **35.25** | **17.50** | **28.44** |\n| V3 base (no LoRA)              | 12.75     | 10.27     | 8.05      | 5.99      | 20.33     | 5.65      | 10.11     |\n| **KT-LoRA fine-tuned V3**      | **42.42** | **34.12** | **28.95** | **24.54** | **41.97** | **22.37** | **33.28** |\n\n| AfriMed-QA (multiple choice)   | Accuracy   |\n| ------------------------------ | ---------- |\n| V2-Lite (no LoRA)              | 0.0645     |\n| **KT-LoRA fine-tuned V2-Lite** | **0.4812** |\n| V3 base (no LoRA)              | 0.5833     |\n| **KT-LoRA fine-tuned V3**      | **0.7930** |\n\nEven for ultra-large MoE models, **KTransformers-backed fine-tuning** achieves strong task performance quickly.\n\n\n\n## Quick to Start\n\nThis section shows how to 
install and use **LLaMA-Factory + KTransformers** for fine-tuning and inference:\n\n- Environment setup\n- Fine-tune ultra-large MoE models with the KTransformers backend\n- Load the fine-tuned model (base + LoRA adapter) for chat/inference\n- Batch inference and metric evaluation\n\n### Environment Setup\n\nInstall the **KTransformers** and **LLaMA-Factory** environments together by following the example below. To simplify installation, KTransformers is provided as a prebuilt wheel, so no local compilation is needed. Make sure your local **Python version**, **Torch version**, **CUDA version**, and the **KTransformers wheel filename** match.\n\n```shell\n# 1. Create and activate a conda environment\nconda create -n Kllama python=3.12 # choose from : [3.10, 3.11, 3.12, 3.13]\nconda activate Kllama\nconda install -y -c conda-forge libstdcxx-ng gcc_impl_linux-64\nconda install -y -c nvidia/label/cuda-11.8.0 cuda-runtime\n\n# 2. Install the LLaMA-Factory environment\ngit clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git\ncd LLaMA-Factory\npip install -e \".[torch,metrics]\" --no-build-isolation\n\n# 3. Install the KTransformers wheel that matches your Torch and Python versions, from https://github.com/kvcache-ai/ktransformers/releases/tag/v0.4.1 (Note: The CUDA version can differ from that in the wheel filename.)\npip install ktransformers-0.4.1+cu128torch27fancy-cp312-cp312-linux_x86_64.whl\n\n# 4. Install flash-attention; download the wheel that matches your Python and Torch versions from: https://github.com/Dao-AILab/flash-attention/releases\npip install flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp312-cp312-linux_x86_64.whl\n# To find out whether you need abiTRUE or abiFALSE, run the following in Python:\n# import torch\n# print(torch._C._GLIBCXX_USE_CXX11_ABI)\n\n# 5. 
(Optional) If you want to use flash_infer (otherwise it defaults to triton)\ngit clone https://github.com/kvcache-ai/custom_flashinfer.git\npip install custom_flashinfer/\n```\n\n**Usage tip:** In LLaMA-Factory YAML, set `use_kt: true` and pick a `kt_optimize_rule` file to have KTransformers handle the core compute. The features below show typical configs.\n\n### Core Feature 1: Use KTransformers backend to fine-tune ultra-large MoE models\n\nRun the command: `USE_KT=1 llamafactory-cli train examples/train_lora/deepseek3_lora_sft_kt.yaml`.\n\nNote: You **must** provide a **BF16** model. DeepSeek-V3-671B is released in FP8 by default; convert with [DeepSeek-V3/inference/fp8_cast_bf16.py](https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py).\n\n```yaml\n### model\nmodel_name_or_path: opensourcerelease/DeepSeek-V3-bf16\ntrust_remote_code: true\n\n### method\nstage: sft\ndo_train: true\nfinetuning_type: lora\nlora_rank: 8\nlora_target: all\n\n### dataset\ndataset: identity\ntemplate: deepseek\ncutoff_len: 2048\nmax_samples: 100000\noverwrite_cache: true\npreprocessing_num_workers: 16\ndataloader_num_workers: 4\n\n### output\noutput_dir: saves/Kllama_deepseekV3\nlogging_steps: 10\nsave_steps: 500\nplot_loss: true\noverwrite_output_dir: true\nsave_only_model: false\nreport_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]\n\n### train\nper_device_train_batch_size: 1\ngradient_accumulation_steps: 8\nlearning_rate: 1.0e-4\nnum_train_epochs: 3.0\nlr_scheduler_type: cosine\nwarmup_ratio: 0.1\nbf16: true\nddp_timeout: 180000000\nresume_from_checkpoint: null\n\n### ktransformers\nuse_kt: true # use KTransformers as LoRA sft backend\nkt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml\ncpu_infer: 32\nchunk_size: 8192\n```\n\n`kt_optimize_rule` controls **placement strategy**. 
See also [ktransformers/optimize_rules](https://github.com/kvcache-ai/ktransformers/tree/main/ktransformers/optimize/optimize_rules). Naming hints (`*` = wildcard):\n\n| Pattern                                      | Meaning                                               |\n| -------------------------------------------- | ----------------------------------------------------- |\n| DeepSeek-V2-Lite-Chat-* / DeepSeek-V3-Chat-* | Target model variants                                 |\n| *-sft-*                                      | Strategy for fine-tuning; others are for inference    |\n| *-amx-*                                      | Use AMX on CPU; otherwise use **llamafile**           |\n| *-multi-gpu-X*                               | Model parallel on X GPUs (X omitted → default 2 GPUs) |\n\nExample: `DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml` = V3-Chat fine-tuning with AMX and 2-GPU model parallel.\n\nWe recommend **AMX acceleration** where available (`lscpu | grep amx`). AMX supports BF16/INT8. 
Example:\n\n```yaml\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert parallelism\n    kwargs:\n      prefill_device: \"cpu\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KSFTExpertsCPU\"\n      out_device: \"cuda\"\n      backend: \"AMXInt8\" # or \"AMXBF16\" or \"llamafile\" (default)\n```\n\nOutputs go to `output_dir` in safetensors format plus adapter metadata for later loading.\n\n![image-20251016171537997](../../assets/image-20251016171537997.png)\n\n### Core Feature 2: Chat with the fine-tuned model (base + LoRA adapter)\n\nRun the command: `llamafactory-cli chat examples/inference/deepseek3_lora_sft_kt.yaml`.\n\nUse the safetensors adapter trained with KT for inference.\n\n```yaml\nmodel_name_or_path: opensourcerelease/DeepSeek-V3-bf16\nadapter_name_or_path: saves/Kllama_deepseekV3\ntemplate: deepseek\ninfer_backend: ktransformers  # choices: [huggingface, vllm, sglang, ktransformers]\ntrust_remote_code: true\n\nuse_kt: true # use KTransformers as LoRA sft backend to inference\nkt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml\ncpu_infer: 32\nchunk_size: 8192\n```\n\nWe also support **GGUF** adapters: for safetensors, set the **directory**; for GGUF, set the **file path** in `adapter_name_or_path`.\n\nDuring loading, LLaMA-Factory maps layer names to KT’s naming. 
You’ll see logs like `Loaded adapter weight: XXX -> XXX`:\n\n![image-20251016171526210](../../assets/image-20251016171526210.png)\n\n### Core Feature 3: Batch inference + metrics (base + LoRA adapter)\n\nRun the command: `API_PORT=8000 llamafactory-cli api examples/inference/deepseek3_lora_sft_kt.yaml`.\n\nThis serves the KT fine-tuned adapter through the API; all other API usage is the same as in native LLaMA-Factory.\n\n```yaml\nmodel_name_or_path: opensourcerelease/DeepSeek-V3-bf16\nadapter_name_or_path: saves/Kllama_deepseekV3\ntemplate: deepseek\ninfer_backend: ktransformers  # choices: [huggingface, vllm, sglang, ktransformers]\ntrust_remote_code: true\n\nuse_kt: true # use KTransformers as the inference backend for the LoRA SFT adapter\nkt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml\ncpu_infer: 32\nchunk_size: 8192\n```\n\n\n\n## KT Fine-Tuning Speed (User-Side View)\n\n### End-to-End Performance\n\n**Definitions**\n\n- `step_time`: wall-clock time for a full optimization step (tensor movement + Attention + MoE + other compute).\n- `tokens_per_step = GAS × qlen`, where `GAS` is the number of gradient-accumulation steps and `qlen` is the sequence length in tokens; `token/s = tokens_per_step / step_time`.\n\n**Settings:** `GAS=16`, `qlen=512` (→ `tokens_per_step = 8192`); LoRA (`r=8, alpha=32, dropout=0.1`); **AMX** enabled; GPU: RTX 4090, CPU: Intel Xeon Platinum 8488C.\n\n**Measured**\n\n- **DeepSeek-V3-671B:** `step_time = 203 s` → `token/s ≈ 8192 / 203 ≈ 40.35`\n- **DeepSeek-V2-Lite-14B:** `step_time = 36 s` → `token/s ≈ 8192 / 36 ≈ 227.6`\n\n### GPU/CPU Memory Footprint\n\n- DeepSeek-V3 (671B; 61 layers with 58 MoE): ~**70 GB** total GPU VRAM (multi-GPU), ~**1.2–1.3 TB** CPU RAM.\n- DeepSeek-V2-Lite (14B; 27 layers with 26 MoE): ~**5.5 GB** GPU VRAM, ~**30 GB** CPU RAM.\n\n## Conclusion\n\nBy integrating **KTransformers LoRA fine-tuning** into **LLaMA-Factory**, we provide a practical guide for efficient training and deployment of MoE LLMs. 
KT brings cutting-edge optimizations (DeepSeek/Qwen/Kimi support with AMX-accelerated kernels), and LoRA enables customization under very low GPU memory. LLaMA-Factory offers a friendly, unified interface.\n\nThis integration (akin to Unsloth-style speedups) means even models with tens to hundreds of billions of parameters can be fine-tuned and deployed with low latency on commodity hardware. You get **memory savings, speed-ups, and usability** together. We encourage you to try LLaMA-Factory + KT for your next MoE project and follow this guide. Feedback is welcome!\n"
  },
  {
    "path": "doc/en/SFT/README.md",
    "content": "# kt-sft Docs"
  },
  {
    "path": "doc/en/SFT/injection_tutorial.md",
    "content": "# Tutorial: Inject Operator Step by Step\n\n> Author: Azure-Tang\n\n## TL;DR\nThis tutorial will guide you through the process of injecting custom operators into a model using the KTransformers framework. We will use the DeepSeekV2-Chat model as an example to demonstrate how to inject custom operators into the model step by step. The tutorial will cover the following topics:\n- [TL;DR](#tldr)\n- [How to Write Injection Rules](#how-to-write-injection-rules)\n- [Understanding Model Structure](#understanding-model-structure)\n- [Matrix Absorption-based MLA Injection](#matrix-absorption-based-mla-injection)\n- [Injection of Routed Experts](#injection-of-routed-experts)\n- [Injection of Linear Layers](#injection-of-linear-layers)\n- [Injection of Modules with Pre-calculated Buffers](#injection-of-modules-with-pre-calculated-buffers)\n- [Specifying Running Devices for Modules](#specifying-running-devices-for-modules)\n- [Muti-GPU](#muti-gpu)\n- [How to Write a New Operator and Inject into the Model](#how-to-write-a-new-operator-and-inject-into-the-model)\n\n## How to Write Injection Rules\nThe basic form of the injection rules for the Inject framework is as follows:\n```yaml\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.*$\"  # Target module name\n    class: torch.nn.Linear  # Target module\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:0\"\n      # your_op_param_1: 1234\n      # your_op_param_2: 5678\n  recursive: True\n```\n* match: This field marks the matching rules, which can appear in two forms, name and class. These two matching rules can appear together or separately; they only match when both criteria are met.\n* replace:\n\t* class: Python class that can be imported to replace the target module. 
If no replacement is desired, set it to `default`.\n\t* kwargs: List of parameters needed for module initialization.\n\t    * generate_device: The device for this module; it can be set to “cpu”, “cuda”, “cuda:1”, etc.\n* recursive: Whether to recursively inject this module’s submodules; the default is True.\n\nFor the recursive field: some modules contain multiple submodules. For example, a self-attention module typically contains four linear modules (q/k/v/o). If we replace the self-attention module but do not want its internal linear modules to be covered by other rules, set `recursive` to False in this rule.\n\n## Understanding Model Structure\nUsing [deepseek-ai/DeepSeek-V2-Lite-Chat](https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat) as an example, we can follow the above rules step by step to inject our custom module and run it. KTransformers offers a high degree of flexibility, allowing you to replace/experiment with basic operators. However, it also requires users to clearly understand the structure of the model they are running.\n\nFortunately, finding out the structure of a model is simple. 
Open the file list on the [deepseek-ai/DeepSeek-V2-Lite-Chat](https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/tree/main) homepage, and you can see the following files:\n<p align=\"center\">\n  <picture>\n    <img alt=\"Inject-Struction\" src=\"../../assets/model_structure_guild.png\" width=60%>\n  </picture>\n</p>\n\nFrom the `.safetensors` file, we can see the name of each layer’s weights, corresponding to the match.name attribute in the injection rules.\nFrom the `modeling_deepseek.py` file, we can see the specific implementation of each module class, with the class name corresponding to the match.class attribute in the injection rules.\n\nThe structure of the DeepSeekV2 model from the `.safetensors` and `modeling_deepseek.py` files is as follows:\n<p align=\"center\">\n  <picture>\n    <img alt=\"Inject-Struction\" src=\"../../assets/deepseekv2_structure.png\" width=60%>\n  </picture>\n</p>\n\nSupported operators and their corresponding classes are as follows:\n\n| match     | replace                | backends                | descriptions         |\n| --------- | ---------------------- | ----------------------- | -------------------- |\n| Linear    | KTransformersLinear    | KLinearMarlin           | Marlin as backend    |\n|           |                        | KLinearTorch            | pytorch as backend   |\n|           |                        | KLinearCPUInfer         | llamafile as backend |\n|           |                        | KLinearFP8         | Triton fp8_gemm kernel. 
Requires a GPU that can compute on FP8 data |\n| experts   | KTransformersExperts   | KExpertsTorch           | pytorch as backend   |\n|           |                        | KExpertsMarlin          | Marlin as backend    |\n|           |                        | KExpertsCPU             | llamafile as backend |\n| Attention | KDeepseekV2Attention   | KDeepseekV2Attention    | MLA implementation   |\n| MoE       | KMistralSparseMoEBlock | KQwen2MoeSparseMoeBlock | MoE for Qwen2        |\n|           | KDeepseekV2MoE         | KDeepseekV2MoE          | MoE for DeepseekV2   |\n| Model     | KQwen2MoeModel         | KQwen2MoeModel          | Model for Qwen2      |\n|           | KDeepseekV2Model       | KDeepseekV2Model        | Model for DeepseekV2 |\n| RoPE      | RotaryEmbedding        | RotaryEmbedding         | RoPE module          |\n|           | YarnRotaryEmbedding    | YarnRotaryEmbedding     | RoPE module          |\n\nNext, we inject the custom modules step by step. Our targets are:\n\n* Replace the linear module with a custom Marlin linear module.\n* Replace the self-attention module with a custom Absorption-based MLA module.\n* Replace the experts module with a custom Experts module.\n* Replace the MoE module with a custom MoE module.\n* Replace the RoPE module with a custom RoPE module.\n* Set the running device for each module.\n\nThe full implementation of the injection rules can be found [here](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml).\n\n## Matrix Absorption-based MLA Injection\n\nTo inject the Attention module, we only need to use a regular expression to match the module names used in transformers and replace them with our own MLA module implementation. 
The YAML injection rule is as follows:\n```yaml\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"  # Regular expression\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # Optimized MLA implementation\n```\nAs you can see, each rule in the YAML file has two parts: match and replace. The match part specifies the module to be replaced, and the replace part specifies the module to be injected into the model along with the initialization keywords.\n\n## Injection of Routed Experts\nFor Routed Experts (corresponding to the exps in the diagram), the module we inject is CPUInfer, which is wrapped in the wrapper module KTransformersExperts. KTransformersExperts has multiple implementations, and we need to specify keywords to tell the wrapper module which implementation we want to use and how we plan to use it.\n\nIn the source code of the transformer, MoE is implemented using nn.ModuleList. We do not want KTransformers to traverse all submodules in the list and inject them one by one, so in this rule, we set recursive: False to prevent recursive injection into the submodules of this module. The YAML rule is as follows:\n\n```yaml\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # Custom MoE kernel with expert parallelism\n    kwargs:\n      generate_device: \"cpu\"\n      generate_op: \"MLPCPUExperts\"\n      out_device: \"cuda\"\n  recursive: False # Don't recursively inject submodules of this module\n```\n\nIf we inject Routed Experts as a custom module, we cannot use the interfaces in the original `nn.ModuleList`. Therefore, it is necessary to modify the forward function in the FFN module. 
The simplest method is to implement a new module with a custom forward function and inject it.\n```yaml\n- match:\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # MLP module with custom forward function\n```\n\n## Injection of Linear Layers\n\nFor the remaining linear layer modules, we aim to use quantized operators to save storage space while improving performance. Since there is no current research on using MLA and quantization together, we do not want to inject linear into the MLA operator. Therefore, we can modify the regular expression and add a type check in the match part of the rule. Only modules that match both the name and class simultaneously will be injected. We also need to pass some keywords similar to the injection of Routed Experts. The YAML rule is as follows:\n\n```yaml\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn).*$\"  # Regular expression\n    class: torch.nn.Linear  # Only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # Optimized kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      generate_op: \"QuantizedLinearMarlin\"\n```\n## Injection of Modules with Pre-calculated Buffers\n\nTo avoid occupying resources when initializing the injected original model, we use torch’s meta device to initialize the original model. The RoPE module pre-calculates some buffers during initialization, but no calculations are performed when using the meta device. Therefore, we need to compensate for the calculation of the buffer when loading the model. Simply, we inject a custom module into the rotary embedding module, which performs pre-calculation during loading. 
The YAML rule is as follows:\n```yaml\n- match:\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n```\n\n## Specifying Running Devices for Modules\n\nFinally, we set `generate_device` as a fallback attribute for all modules:\n```yaml\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.|^lm_head\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda\"\n  \n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n        generate_device: \"cpu\"\n```\nThrough these two rules, we place all previously unmatched layers (and their submodules) and lm_head on cuda, and the embedding on cpu. Note that the properties of a module are determined by the first rule it matches. For example, if you later set a new replace.kwargs.generate_device in an injected module, the device set earlier will take precedence. If your machine has multiple GPUs, you can also spread the model across them.\n\n\n## Muti-GPU\n\nIf you have multiple GPUs, you can assign the modules to different GPUs. \nDeepseekV2-Chat has 60 layers; with 2 GPUs, we can allocate 30 layers to each GPU. Complete multi-GPU rule examples are available [here](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml).\n\n\n<p align=\"center\">\n  <picture>\n    <img alt=\"Inject-Struction\" src=\"../../assets/multi_gpu.png\" width=60%>\n  </picture>\n</p>\n\nFirst, for multi-GPU, we have to inject a new operator, `KDeepseekV2Model`, and divide the layers between the GPUs. 
In our case, we set the `transfer_map` in the `KDeepseekV2Model` operator as follows:\n\n```yaml\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      transfer_map: \n        30: \"cuda:1\"\n```\n\nWe also have to set the device for each module in the model. \n\nFor example, for routed experts, the YAML for one GPU is:\n```yaml\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # Custom MoE kernel with expert parallelism\n    kwargs:\n      generate_device: \"cuda:0\"\n      generate_op: \"MLPCUDAExperts\"\n      out_device: \"cuda:0\"\n  recursive: False # Don't recursively inject submodules of this module\n```\nFor two GPUs, we split the rule by layer range so that each range sends its output to a different GPU. \n\n```yaml\n# allocate the out_device of layers 0-29 to cuda:0\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE kernel with expert parallelism\n    kwargs:\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:0\"\n  recursive: False # don't recursively inject submodules of this module\n\n# allocate the out_device of layers 30-59 to cuda:1\n- match:\n    name: \"^model\\\\.layers\\\\.([345][0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE kernel with expert parallelism\n    kwargs:\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:1\"\n  recursive: False # don't recursively inject submodules of this module\n```\nFor other modules, we can set the device in the same way.\n\n## How to Write a New Operator and Inject into the Model\n\nIn this section, we will explain how to write an operator that can be injected, using the 
implementation of a new linear as an example.\n\nFirst, all injectable operators need to inherit from the BaseInjectedModule class, which inherits some attributes required by our injection framework. Its initialization function needs to meet the following basic format:\n\n```python\nclass LinearTorchInject(BaseInjectedModule):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module = None,\n        generate_device: str = \"cuda\",\n        **kwargs,\n    ):\n        super().__init__(key, gguf_loader, config, orig_module, generate_device, **kwargs)\n```\nIf users have other parameters that need to be passed to this class, they can also be included in the init function and re-passed in the kwargs parameter in the yaml file. For example, if our operator wants to pass a parameter `my_param`, the init function can be written as:\n```python\nclass LinearTorchInject(BaseInjectedModule):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module = None,\n        generate_device: str = \"cuda\",\n        my_param: bool = True,\n        **kwargs,\n    ):\n        super().__init__(key, gguf_loader, config, orig_module, generate_device, **kwargs)\n        self.my_param = my_param\n```\nThen our injection rule can be written as:\n```yaml\n- match: \n    name: \"^model\\\\.layers\\\\..*$\"  # Regular expression matches the module name.\n    class: torch.nn.Linear  # Type restrictions can be added.\n  replace:\n    class: ktransformers.operators.linear.LinearTorchInject  # Inject module path\n    kwargs: # Extra parameters\n      generate_device: \"cuda\"\n      my_param: True\n```\nFor the linear module, it is also necessary to read weights from a gguf file. We provide the `KLinearBase` class to help users read weights from gguf files. 
Users only need to inherit and implement the load, unload, and forward functions. Therefore, a fully injectable linear class would look like this:\n```python\nclass LinearTorchInject(BaseInjectedModule, KLinearBase):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module = None,\n        generate_device: str = \"cuda\",\n        **kwargs,\n    ):\n        super().__init__(key, gguf_loader, config, orig_module, generate_device, **kwargs)\n        KLinearBase.__init__(self)\n        self.has_bias = False\n        self.dtype = torch.get_default_dtype()\n        self.w = None\n        self.has_bias = False\n    \n    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):\n        if device is None: device = self.device\n        if w is None: w = self.load_weight(device=device)\n\n        if isinstance(w, nn.Parameter):\n            self.w = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T\n            self.has_bias = False\n        elif isinstance(w, tuple):\n            self.w = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T\n            self.bias = w[1].to(dtype=self.dtype)\n            self.has_bias = True\n        else:\n            raise ValueError(\"Invalid weight type\")\n        self.w = self.w.to(device)\n        if self.has_bias:\n            self.bias = self.bias.to(device)\n\n    def unload(self):\n        if self.w is not None:\n            self.w = None\n        if self.has_bias:\n            self.bias = None\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        dtype = x.dtype\n        out_device = x.device\n        x = x.to(device=self.device, dtype=self.dtype)\n        x = x @ self.w\n        if self.has_bias:\n            x = x + self.bias\n        x = x.to(dtype=dtype, device=out_device)\n        return x\n```\nNote that the `self.load_weight` function is provided by 
the KLinearBase class to help users load weights from a gguf file into the module. The implementation details of KLinearBase can be found on [GITHUB](https://github.com/kvcache-ai/ktransformers/blob/44f57270c9514d79fab224186d90ccf61059331a/ktransformers/operators/linear.py#L31).\n"
  },
  {
    "path": "doc/en/SFT_Installation_Guide_KimiK2.5.md",
    "content": "# Kimi-K2.5 LoRA SFT Tutorial\n\nThis tutorial demonstrates how to perform **LoRA Supervised Fine-Tuning (SFT)** on **Kimi-K2.5** using **LlamaFactory** with **KTransformers** as the backend, and then serve the fine-tuned model using **SGLang**.\n\nThe workflow is:\n\n```txt\nKTransformers + LlamaFactory LoRA SFT → (Optional) LlamaFactory Verification → SGLang Serving\n```\n\n## Table of Contents\n\n- [Hardware Requirements](https://chatgpt.com/c/6975bb7f-52e0-839c-a727-ec4b5d6723b5#hardware-requirements)\n- [Prerequisites](https://chatgpt.com/c/6975bb7f-52e0-839c-a727-ec4b5d6723b5#prerequisites)\n- [Step 0: Environment Setup (Method 1: Source Install)](https://chatgpt.com/c/6975bb7f-52e0-839c-a727-ec4b5d6723b5#step-0-environment-setup-method-1-source-install)\n- [Step 1: Prepare Model Weights (BF16 for SFT)](https://chatgpt.com/c/6975bb7f-52e0-839c-a727-ec4b5d6723b5#step-1-prepare-model-weights-bf16-for-sft)\n- [Step 2: Prepare YAML for LoRA SFT (KTransformers Backend)](https://chatgpt.com/c/6975bb7f-52e0-839c-a727-ec4b5d6723b5#step-2-prepare-yaml-for-lora-sft-ktransformers-backend)\n- [Step 3: Run LoRA SFT](https://chatgpt.com/c/6975bb7f-52e0-839c-a727-ec4b5d6723b5#step-3-run-lora-sft)\n- [Step 4: Post-SFT Quick Verification with LlamaFactory (Optional)](https://chatgpt.com/c/6975bb7f-52e0-839c-a727-ec4b5d6723b5#step-4-post-sft-quick-verification-with-LlamaFactory-optional)\n- [Step 5: SGLang Serving with LoRA (Recommended Delivery Path)](https://chatgpt.com/c/6975bb7f-52e0-839c-a727-ec4b5d6723b5#step-5-sglang-serving-with-lora-recommended-delivery-path)\n\n## Hardware Requirements\n\n### Training (LoRA SFT)\n\n- **LlamaFactory + KTransformers**\n- **GPU**: 4 * NVIDIA RTX 4090 24GB (or equivalent with at least total 48GB VRAM available)\n- **CPU**: x86 CPU with AMX support\n- **RAM**: At least 2TGB system memory\n- Swap can be used if CPU memory is insufficient\n\n### Inference (LoRA Adapter + Original Model)\n\n- **SGLang + KTransformers**\n- 
**GPU**: 2 * NVIDIA RTX 4090 24GB (or equivalent, with at least 48GB of total VRAM available)\n- **CPU**: x86 CPU with AVX512F support (e.g., Intel Sapphire Rapids)\n- **RAM**: At least 600GB system memory\n- **Storage**: ~600GB for model weights (native INT4 weights, same weight directory for CPU and GPU)\n\n\n\n## Step 0: Environment Setup\n\nWe recommend using **two separate conda environments**:\n\n| Environment | Purpose                                             |\n| ----------- | --------------------------------------------------- |\n| `kt-kernel` | Inference & serving (KTransformers + SGLang)        |\n| `kt-sft`    | Training (LlamaFactory + KTransformers SFT backend) |\n\n### 0.1 Inference Environment: `kt-kernel`\n\n```bash\nconda create -n kt-kernel python=3.11\nconda activate kt-kernel\n\ngit clone https://github.com/kvcache-ai/ktransformers.git\ncd ktransformers\ngit checkout kimi_k2.5\ngit submodule update --init --recursive\ncd kt-kernel && ./install.sh\n```\n\n### 0.2 Install SGLang (Inference / Serving)\n\n**Recommended for Kimi-K2.5:**\n\n```bash\n# Option A: One-click install (from ktransformers root, installs sglang + kt-kernel)\n./install.sh\n\n# Option B: pip install\npip install sglang-kt\n```\n\n### 0.3 Training Environment: `kt-sft`\n\n```bash\nconda create -n kt-sft python=3.11\nconda activate kt-sft\n\ngit clone https://github.com/hiyouga/LlamaFactory.git\ncd LlamaFactory\npip install -e .\n```\n\n### 0.4 Install KTransformers SFT Dependencies\n\n```bash\nconda install -y -c conda-forge libstdcxx-ng gcc_impl_linux-64\nconda install -y -c nvidia/label/cuda-11.8.0 cuda-runtime\n\n# Install matching wheels (recommended), from https://github.com/kvcache-ai/ktransformers/releases\npip install ktransformers-<matching-version>.whl\npip install flash_attn-<matching-version>.whl\n```\n\n## Step 1: Prepare Model Weights (BF16 for SFT)\n\n### 1.1 Download INT4 Weights\n\nKTransformers **requires BF16 weights for SFT**, so first download the released INT4 weights and then convert them in step 1.2.\n\n```bash\n# Download Kimi-K2.5 (RAW-INT4 for both CPU and 
GPU)\nhuggingface-cli download moonshotai/Kimi-K2.5 \\\n  --local-dir /path/to/kimi-k2.5\n```\n\n### 1.2 Convert INT4 → BF16\n\nThe Kimi-K2.5 base model is in **INT4** format; convert it to **BF16** before SFT.\n\n## Step 2: Prepare YAML for LoRA SFT (KTransformers Backend)\n\n### 2.1 Training YAML (LoRA SFT)\n\nExample file:\n`examples/train_lora/kimik2_lora_sft_kt.yaml`\n\nRequired fields:\n\n```yaml\nstage: sft\nfinetuning_type: lora\nbf16: true\n\nuse_kt: true\nkt_optimize_rule: <rule.yaml>\ncpu_infer: 32\nchunk_size: 8192\n```\n\nOther fields (dataset, output_dir, learning rate, epochs) can be adjusted as usual.\n\n### 2.2 Inference YAML (LlamaFactory Verification)\n\nKey requirements:\n\n- `adapter_name_or_path`: LoRA output directory\n- `infer_backend: ktransformers`\n- **Same `use_kt` and `kt_optimize_rule` as training**\n\nThis YAML is used only for **quick verification**, not production serving.\n\n## Step 3: Run LoRA SFT\n\n```bash\nconda activate kt-sft\ncd LlamaFactory\n\nUSE_KT=1 llamafactory-cli train examples/train_lora/kimik2_lora_sft_kt.yaml\n```\n\nAfter training, the LoRA adapter is saved to `output_dir`.\n\n## Step 4: Post-SFT Quick Verification with LlamaFactory (Optional)\n\nBefore production deployment, we recommend a **lightweight sanity check**.\n\n```bash\nconda activate kt-sft\ncd LlamaFactory\n\nllamafactory-cli chat examples/inference/kimik2_lora_sft_kt.yaml\n```\n\nPurpose:\n\n- Validate LoRA correctness\n- Ensure reproducibility\n- Not for throughput benchmarking\n\n## Step 5: SGLang Serving with LoRA (Recommended Delivery Path)\n\nThis is the **major runtime update** in this guide.\n\n### 5.1 Convert LoRA for SGLang\n\n```bash\npython ktransformers/kt-kernel/scripts/convert_lora.py \\\n  --base_path /path/to/kimi-base-model \\\n  --lora_path /path/to/llamafactory/output_dir \\\n  --output_path /path/to/lora_converted\n```\n\n### 5.2 (Optional) Convert CPU Weights to INT8\n\nTo reduce CPU memory 
usage:\n\n```bash\npython ktransformers/kt-kernel/scripts/convert_cpu_weights.py \\\n  --base_path /path/to/kimi-base-model \\\n  --output_dir /path/to/kimi-base-model-int8\n```\n\nThis produces:\n\n```text\n/path/to/kimi-base-model-int8/int8\n```\n\n### 5.3 Launch SGLang Server with LoRA\n\n```bash\nconda activate kt-kernel\n\npython -m sglang.launch_server \\\n  --enable-lora \\\n  --lora-paths lora1=/path/to/lora_converted \\\n  --lora-backend triton \\\n  --model-path /path/to/kimi-base-model \\\n  --tp 1 \\\n  --trust-remote-code \\\n  --context-length 4096 \\\n  --kt-weight-path /path/to/kimi-base-model-int8/int8 \\\n  --mem-fraction-static 0.9\n```\n\nNotes:\n\n- `--kt-weight-path` points to CPU INT8 weights\n- Adjust `tp`, `context-length`, and memory parameters per machine\n- RAWINT4 inference paths can follow **Kimi-K2.5-Native** directly"
  },
  {
    "path": "doc/en/SFT_Installation_Guide_KimiK2.md",
    "content": "## Installation\r\n\r\n### Step 1: Create a conda environment and suit it for KTransformers\r\n\r\n```Bash\r\nconda create -n Kllama python=3.10 # choose from : [3.10, 3.11, 3.12, 3.13]\r\nconda install -y -c conda-forge libstdcxx-ng gcc_impl_linux-64\r\nconda install -y -c nvidia/label/cuda-11.8.0 cuda-runtime\r\n```\r\n\r\n### Step 2: Install the LLaMA-Factory environment\r\n\r\n```Bash\r\ngit clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git\r\ncd LLaMA-Factory\r\npip install -e \".[torch,metrics]\" --no-build-isolation\r\n```\r\n\r\n### Step 3: Install the KTransformers wheel that matches your Torch and Python versions, from https://github.com/kvcache-ai/ktransformers/releases/tag/v0.4.1\r\n\r\n(Note: The CUDA version can differ from that in the wheel filename.)\r\n\r\n```Bash\r\npip install ktransformers-0.4.1+cu128torch28fancy-cp310-cp310-linux_x86_64.whl\r\n```\r\n\r\n### Step 4: Install the Flash-attention wheel that matches your Torch and Python versions, from: https://github.com/Dao-AILab/flash-attention/releases\r\n\r\n```Bash\r\npip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiTRUE-cp310-cp310-linux_x86_64.whl\r\n# abi=True/False can find from below\r\n# import torch\r\n# print(torch._C._GLIBCXX_USE_CXX11_ABI)\r\n```\r\n\r\n### Step 5: (Optional) If you want to use flash_infer (otherwise it defaults to triton)\r\n\r\n```Bash\r\ngit clone https://github.com/kvcache-ai/custom_flashinfer.git\r\npip install custom_flashinfer/\r\n```\r\n\r\n## Download Model\r\n\r\nDownload the official KIMI weights. 
If the weights are in FP8 format, please refer to [convert_kimi_k2_fp8_to_bf16_cpu.py](https://github.com/kvcache-ai/ktransformers/blob/main/kt-kernel/scripts/convert_kimi_k2_fp8_to_bf16_cpu.py) to convert them to BF16 weights.\r\n\r\n## How to start\r\n\r\n```Bash\r\n# For LoRA SFT\r\nUSE_KT=1 llamafactory-cli train examples/train_lora/kimik2_lora_sft_kt.yaml\r\n# For Chat with model after LoRA SFT\r\nllamafactory-cli chat examples/inference/kimik2_lora_sft_kt.yaml\r\n# For API with model after LoRA SFT\r\nllamafactory-cli api examples/inference/kimik2_lora_sft_kt.yaml\r\n```\r\n\r\n**If your CPU memory does not exceed the 2 TB needed to support Kimi K2, you can additionally use swap space:**\r\n\r\n```Bash\r\nsudo fallocate -l 200G /data/swapfile\r\nsudo chmod 600 /data/swapfile\r\nsudo mkswap /data/swapfile\r\nsudo swapon /data/swapfile\r\n```\r\n\r\nFor example, we provide the following YAML files. (Because the structures of Kimi and DeepSeek are relatively similar, we use the DeepSeek template in LLaMA-Factory.)\r\n\r\n(1) examples/train_lora/kimik2_lora_sft_kt.yaml\r\n\r\n```YAML\r\n### model\r\nmodel_name_or_path: KimiK2-model\r\ntrust_remote_code: true\r\n\r\n### method\r\nstage: sft\r\ndo_train: true\r\nfinetuning_type: lora\r\nlora_rank: 8\r\nlora_target: all\r\n\r\n### dataset\r\ndataset: identity\r\ntemplate: deepseek\r\ncutoff_len: 2048\r\nmax_samples: 100000\r\noverwrite_cache: true\r\npreprocessing_num_workers: 16\r\ndataloader_num_workers: 4\r\n\r\n### output\r\noutput_dir: saves/Kllama_kimik2\r\nlogging_steps: 10\r\nsave_steps: 500\r\nplot_loss: true\r\noverwrite_output_dir: true\r\nsave_only_model: false\r\nreport_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]\r\n\r\n### train\r\nper_device_train_batch_size: 1\r\ngradient_accumulation_steps: 8\r\nlearning_rate: 1.0e-4\r\nnum_train_epochs: 3.0\r\nlr_scheduler_type: cosine\r\nwarmup_ratio: 0.1\r\nbf16: true\r\nddp_timeout: 180000000\r\nresume_from_checkpoint: 
null\r\n\r\n### ktransformers\r\nuse_kt: true # use KTransformers as LoRA sft backend\r\nkt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml\r\ncpu_infer: 32\r\nchunk_size: 8192\r\n```\r\n\r\nFor more details about `kt_optimize_rule`, please refer to https://github.com/kvcache-ai/ktransformers/blob/main/doc/en/KTransformers-Fine-Tuning_User-Guide.md\r\n\r\n(2) examples/inference/kimik2_lora_sft_kt.yaml\r\n\r\n```YAML\r\nmodel_name_or_path: KimiK2-model  # same BF16 base model as in the training YAML\r\nadapter_name_or_path: saves/Kllama_kimik2  # output_dir of the training YAML\r\ntemplate: deepseek\r\ninfer_backend: ktransformers  # choices: [huggingface, vllm, sglang, ktransformers]\r\ntrust_remote_code: true\r\n\r\nuse_kt: true # use KTransformers as the backend for LoRA inference\r\nkt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml\r\ncpu_infer: 32\r\nchunk_size: 8192\r\n\r\n```\r\n"
  },
  {
    "path": "doc/en/SmallThinker_and_Glm4moe.md",
    "content": "# SmallThinker & GLM-4-MoE Support for KTransformers\n\n## Introduction\n\n### Overview\nWe are excited to announce that **KTransformers now supports both SmallThinker and GLM-4-MoE**.\n\n- **SmallThinker-21BA3B-Instruct (bf16)**: ~26 TPS **on a dual-socket CPU with one consumer-grade GPU**, requiring ~84 GB DRAM.  \n- **GLM-4.5-Air (bf16)**: ~11 TPS **on a dual-socket CPU with one consumer-grade GPU**, requiring ~440 GB DRAM.\n- **GLM-4.5-Air (AMX INT8)**: prefill ~309 TPS / decode ~16 TPS **on a dual-socket CPU with one consumer-grade GPU**, requiring ~220 GB DRAM.\n\n### Model & Resource Links\n- **SmallThinker-21BA3B-Instruct**\n  - *[SmallThinker-21BA3B-Instruct](https://huggingface.co/PowerInfer/SmallThinker-21BA3B-Instruct)*\n- **GLM-4.5-Air 110B**\n  - [*GLM-4.5-Air*](https://huggingface.co/zai-org/GLM-4.5-Air)\n\n---\n\n## Installation Guide\n\n### 1. Resource Requirements\n\n| Model                     | Precision  | Experts | DRAM Needed | GPU Memory Needed\\* | TPS (approx.)                   |\n| ------------------------- | ---------- | ------- | ----------- | ------------------- | --------------------------------------- |\n| SmallThinker-21B-Instruct          | bf16       | 32      | \\~42 GB     | 14 GB               | \\~26 TPS                    |\n| GLM-4.5-Air            | bf16       | 128     | \\~220 GB    | 14 GB               | \\~11 TPS                    |\n| GLM-4.5-Air (AMX INT8) | int8       | 128     | \\~220 GB    | 14 GB               |  \\~16 TPS\n\n\n\\* Exact GPU memory depends on sequence length, batch size, and kernels used.  \n\n### 2. 
Prepare Models\n\n```bash\n# Example: download original safetensors (adjust local paths to your setup)\n# huggingface-cli expects a repo ID (owner/name), not a full URL\n\n# SmallThinker-21B\nhuggingface-cli download --resume-download PowerInfer/SmallThinker-21BA3B-Instruct \\\n  --local-dir ./SmallThinker-21BA3B-Instruct\n\n# GLM-4-MoE 110B\nhuggingface-cli download --resume-download zai-org/GLM-4.5-Air \\\n  --local-dir ./GLM-4.5-Air\n```\n\n\n### 3. Install KTransformers\n\nFollow the official [Installation Guide](https://kvcache-ai.github.io/ktransformers/en/install.html).\n\n```bash\npip install ktransformers  # or from source if you need bleeding-edge features\n```\n\n### 4. Run SmallThinker-21B Inference Server\n\n```bash\npython ktransformers/server/main.py \\\n  --port 10021 \\\n  --model_path /abs/path/to/SmallThinker-21B-bf16 \\\n  --model_name SmallThinkerForCausalLM \\\n  --optimize_config_path ktransformers/optimize/optimize_rules/SmallThinker-serve.yaml \\\n  --max_new_tokens 1024 \\\n  --cache_lens 32768 \\\n  --chunk_size 256 \\\n  --max_batch_size 4 \\\n  --backend_type balance_serve\n```\n\n### 5. Run GLM-4-MoE 110B Inference Server\n\n```bash\npython ktransformers/server/main.py \\\n  --port 10110 \\\n  --model_name Glm4MoeForCausalLM \\\n  --model_path /abs/path/to/GLM-4-MoE-110B-bf16 \\\n  --optimize_config_path ktransformers/optimize/optimize_rules/Glm4Moe-serve.yaml \\\n  --max_new_tokens 1024 \\\n  --cache_lens 32768 \\\n  --chunk_size 256 \\\n  --max_batch_size 4 \\\n  --backend_type balance_serve\n```\n\n### 6. 
Access Server\n\n```bash\ncurl -X POST http://localhost:10021/v1/chat/completions \\\n  -H \"accept: application/json\" \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"messages\": [\n      {\"role\": \"user\", \"content\": \"hello\"}\n    ],\n    \"model\": \"SmallThinker-21BA3B-Instruct\",\n    \"temperature\": 0.3,\n    \"top_p\": 1.0,\n    \"stream\": true\n  }'\n```\n\n```bash\ncurl -X POST http://localhost:10110/v1/chat/completions \\\n  -H \"accept: application/json\" \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"messages\": [\n      {\"role\": \"user\", \"content\": \"hello\"}\n    ],\n    \"model\": \"GLM-4.5-Air\",\n    \"temperature\": 0.3,\n    \"top_p\": 1.0,\n    \"stream\": true\n  }'\n```\n"
  },
  {
    "path": "doc/en/V3-success.md",
    "content": "## Hello everyone, here is the successfully reproduced environment configuration for your reference:\n### Case 1\n- Configuration: l40s 48G + 9654 x2 (192 cores) + 768G DDR5 12-channel\n- Performance: prefill 108 tokens/s, decode 10.8 tokens/s\n- Used version: main source code compiled \n### Case 2\n- Configuration: Dual Xeon 6430 32C processors, totaling 64 cores and 128 threads, 480GB DDR5 memory, single 4090 24G graphics card\n- Performance: Running speed approximately 6-8 tokens per second \n## NOTE\nIf there are any other configurations that have been successfully run, please feel free to let us know. We will keep updating for everyone to refer to when reproducing. (It has been found that it also works on 2080, AMD, etc. (doge : )\n[click here](https://docs.qq.com/smartsheet/form/AVxgQOYhhNfl%2FBB08J2%2Fv3rnnq?tab=BB08J2)"
  },
  {
    "path": "doc/en/api/server/api.md",
    "content": "# API\n\n- [OpenAI ChatCompletion](#openai-chatcompletion)\n- [Ollama ChatCompletion](#ollama-chatcompletion)\n- [OpenAI Assistant](#openai-assistant)\n\n## OpenAI ChatCompletion\n```bash\nPOST /v1/chat/completions\n\n```\nGenerate responses based on the selected model.\n\n### Parameters\n- `messages`: An array of `message` representing all historical messages. A `message` can be from a user or model (assistant) and includes:\n\n  - `role`: Either `user` or `assistant`, indicating the creator of this message.\n  - `content`: The message from the user or model.\n- `model`: The name of the selected model\n- `stream`: Either true or false. Indicates whether to use streaming response. If true, model inference results are returned via HTTP event stream.\n\n### Response\n- Streaming response: An event stream, each event contains a `chat.completion.chunk`. `chunk.choices[0].delta.content` is the incremental output returned by the model each time.\n- Non-streaming response: Not supported yet.\n\n\n\n### Example\n\n```bash\ncurl -X 'POST' \\\n  'http://localhost:9112/v1/chat/completions' \\\n  -H 'accept: application/json' \\\n  -H 'Content-Type: application/json' \\\n  -d '{\n  \"messages\": [\n    {\n      \"content\": \"tell a joke\",\n      \"role\": \"user\"\n    }\n  ],\n  \"model\": \"Meta-Llama-3-8B-Instruct\",\n  \"stream\": true\n}'\n```\n\n```bash\ndata:{\"id\":\"c30445e8-1061-4149-a101-39b8222e79e1\",\"object\":\"chat.completion.chunk\",\"created\":1720511671,\"model\":\"not implmented\",\"system_fingerprint\":\"not implmented\",\"usage\":null,\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Why \",\"role\":\"assistant\",\"name\":null},\"logprobs\":null,\"finish_reason\":null}]}\n\ndata:{\"id\":\"c30445e8-1061-4149-a101-39b8222e79e1\",\"object\":\"chat.completion.chunk\",\"created\":1720511671,\"model\":\"not implmented\",\"system_fingerprint\":\"not 
implmented\",\"usage\":null,\"choices\":[{\"index\":0,\"delta\":{\"content\":\"\",\"role\":\"assistant\",\"name\":null},\"logprobs\":null,\"finish_reason\":null}]}\n\ndata:{\"id\":\"c30445e8-1061-4149-a101-39b8222e79e1\",\"object\":\"chat.completion.chunk\",\"created\":1720511671,\"model\":\"not implmented\",\"system_fingerprint\":\"not implmented\",\"usage\":null,\"choices\":[{\"index\":0,\"delta\":{\"content\":\"couldn't \",\"role\":\"assistant\",\"name\":null},\"logprobs\":null,\"finish_reason\":null}]}\n\n...\n\ndata:{\"id\":\"c30445e8-1061-4149-a101-39b8222e79e1\",\"object\":\"chat.completion.chunk\",\"created\":1720511671,\"model\":\"not implmented\",\"system_fingerprint\":\"not implmented\",\"usage\":null,\"choices\":[{\"index\":0,\"delta\":{\"content\":\"two-tired!\",\"role\":\"assistant\",\"name\":null},\"logprobs\":null,\"finish_reason\":null}]}\n\nevent: done\ndata: [DONE]\n```\n\n\n\n## Ollama ChatCompletion\n\n```bash\nPOST /api/generate\n```\n\nGenerate responses using the selected model.\n\n### Parameters\n- `prompt`: A string representing the input prompt.\n- `model`: The name of the selected model\n- `stream`: Either true or false. Indicates whether to use streaming responses. 
If true, returns the model inference results in the form of an HTTP event stream.\n\n### Response\n- Streaming response: A stream of JSON responses, one JSON object per line.\n  - `response`: The incremental result of the model completion.\n  - `done`: Whether the inference has finished.\n- Non-streaming response: Not yet supported.\n\n### Example\n\n```bash\ncurl -X 'POST' \\\n  'http://localhost:9112/api/generate' \\\n  -H 'accept: application/json' \\\n  -H 'Content-Type: application/json' \\\n  -d '{\n  \"model\": \"Meta-Llama-3-8B-Instruct\",\n  \"prompt\": \"tell me a joke\",\n  \"stream\": true\n}'\n```\n\n```bash\n{\"model\":\"Meta-Llama-3-8B-Instruct\",\"created_at\":\"2024-07-09 08:13:11.686513\",\"response\":\"I'll \",\"done\":false}\n{\"model\":\"Meta-Llama-3-8B-Instruct\",\"created_at\":\"2024-07-09 08:13:11.729214\",\"response\":\"give \",\"done\":false}\n\n...\n\n{\"model\":\"Meta-Llama-3-8B-Instruct\",\"created_at\":\"2024-07-09 08:13:33.955475\",\"response\":\"for\",\"done\":false}\n{\"model\":\"Meta-Llama-3-8B-Instruct\",\"created_at\":\"2024-07-09 08:13:33.956795\",\"response\":\"\",\"done\":true}\n```\n\n\n\n"
  },
  {
    "path": "doc/en/api/server/server.md",
    "content": "# Backend Services (Server)\nThe Server offers fast heterogeneous inference capabilities of ktransformers through an API for external usage.\n\n<img src=\"server-arch.png\" height=\"600\" alt=\"Server architecture\">\n\n## API\n\nThe Server provides model inference services externally through a RESTful API, with two methods of interaction: ChatCompletion and Assistant.\n\n- The ChatCompletion interface requires users to provide all historical dialogues at once, after which the model responds. AI service providers (such as [OpenAI](https://platform.openai.com/docs/api-reference/chat/create)) and local inference frameworks (such as [Ollama](https://github.com/ollama/ollama/blob/main/docs/api.md)) both offer the ChatCompletion interface. To ensure compatibility with OpenAI and Ollama, the Server offers APIs that are consistent with theirs. Therefore, applications currently using OpenAI and Ollama can seamlessly switch to our Server. For example: [How to use Tabby and ktransformers locally with a 236B model for code completion?](tabby.md).\n- The Assistant is suitable for applications that need to reuse a series of resources and call the model. For instance, in educational applications, developers can create an Assistant named \"Second Grade Math Teacher\" and set an initial prompt (\"You are an experienced second-grade math teacher...\"), and upload relevant materials (second grade math textbooks). After creating the Assistant, the application needs to create a Thread to store the dialogues between the user and the model (Message). When calling the model, the application creates a Run to obtain the Assistant's response. Compared to ChatCompletion, the Assistant-enabled Server handles the reuse of conversational contexts and multi-turn dialogues, making model calls in complex scenarios more convenient. 
The [OpenAI Assistant API](https://platform.openai.com/docs/api-reference/assistants/createAssistant) introduces such an Assistant interface, and the Server provides a consistent API.\n\nThese API definitions are located in `server/api`, and their specific usage can be seen [here](api.md).\n\n## Integrating Model Inference Frameworks\n\nThe Server uses ktransformers for model calling and inference. It also supports other inference frameworks, such as the already supported [transformers](https://huggingface.co/docs/transformers/index), and plans to support [exllamav2](https://github.com/turboderp/exllamav2). These functionalities are implemented in `server/backend`.\n\nThe model inference functionalities of the frameworks are abstracted into a base class `BackendInterfaceBase`. This class includes a function: inference. It takes historical dialogue information messages as input and returns the text result from the model. The inference function adopts an async generator design, allowing the Server to return model responses in a streaming manner.\n\n```python\nclass BackendInterfaceBase:\n  async def inference(self, messages, **kwargs)->AsyncIterator[str]:\n    ...\n```\n\nThis inference function naturally implements the functionality of ChatCompletion because its inputs and outputs are historical dialogues and model responses, respectively. Thus, the ChatCompletion API can directly call the inference function to complete model inference.\n\nAssistant is more complex than ChatCompletion, requiring the Server to store the related state of the Assistant and call the inference function appropriately. The Server maintains a set of Assistant logic in the database, storing the Assistants, Threads, and Messages created by applications. In memory, the Server maintains a `ThreadContext` for each Thread, gathering information related to each Thread's Assistant, etc. 
When a user sends a new Message, the Server calls the get_local_messages function of ThreadContext to obtain messages and then calls the inference function to get the inference results.\n\n```python\nclass MyThreadContext(ThreadContext):\n    def get_local_messages(self):\n      ...\n```\n\nSince different model inference frameworks have different historical dialogue input formats, `ThreadContext` and `BackendInterface` need to be used in pairs. Besides its own ktransformers, the Server also supports transformers. For integrating other model inference frameworks, refer to the implementations of `TransformersInterface` and `TransformersThreadContext` in [transformers.py](https://github.com/kvcache-ai/ktransformers-dev/blob/main/ktransformers/server/backend/interfaces/transformers.py). "
  },
  {
    "path": "doc/en/api/server/tabby.md",
    "content": "# How to Use Tabby and ktransformers Locally with 236B Large Models for Code Completion?\n\n[Tabby](https://tabby.tabbyml.com/docs/welcome/) is an open-source code assistant that allows users to manually configure the backend framework and model, and use it across multiple IDEs/editors, such as VSCode and IntelliJ. Since Tabby can interface with Ollama on the framework side, and the ktransformers server provides a consistent API with Ollama, we can connect Tabby to the ktransformers server. This setup allows us to experience fast, heterogeneous inference in code completion scenarios.\n\n1. Start ktransformers.\n```bash\n./ktransformers --port 9112\n```\n2. Install Tabby: Follow the official tutorial to install Tabby on a Linux server or Windows PC with an NVIDIA GPU [here](https://tabby.tabbyml.com/docs/quick-start/installation/linux/).\n3. Configure Tabby: Create `~/.tabby/config.toml` and add the following configuration.\n```toml\n[model.completion.http]\nkind = \"ollama/completion\"\napi_endpoint = \"http://127.0.0.1:9112/\"\nmodel_name = \"DeepSeek-Coder-V2-Instruct\"\nprompt_template = \"<｜fim▁begin｜>{prefix}<｜fim▁hole｜>{suffix}<｜fim▁end｜>\" # Prompt Template\n```\n\nIn this configuration, `kind` specifies that ktransformers uses the standard Ollama API to serve Tabby; `api_endpoint` matches the interface bound when launching ktransformers; `model_name` is set to the model used by ktransformers, here `DeepSeek-Coder-V2-Instruct` is the backend inference model; `prompt_template` is the model's prompt template, which requires a corresponding template for different models to use the Fill In the Middle feature properly.\nHere we demonstrate the relevant configuration for Tabby using the Ollama API to provide the Completion feature. For configuration information about other functions available in Tabby, refer to [here](https://tabby.tabbyml.com/docs/administration/model/).\n\n\n4. 
Start the Tabby service: `./tabby serve`.\n<img src=\"run-tabby.png\" alt=\"image-20240709112329577\" style=\"zoom:50%;\" />\n\n   After launching, you should see requests to the `/api/tags` interface in the ktransformers command line (in version v0.13.0 of Tabby, these are requests to the `/api/show/` interface instead).\n<img src=\"visit-api-tags.png\" alt=\"image-20240709111648215\" style=\"zoom:67%;\" />\n\n5. Register a Tabby account and obtain a Token: After starting the Tabby service, open the corresponding link in a browser (as shown above at 0.0.0.0:8080), and follow the [tutorial](https://tabby.tabbyml.com/docs/quick-start/register-account/) to create a user and get a Token.\n\n6. Start VSCode, install the Tabby extension, and use the Token obtained in the previous step to connect to the Tabby Server, following [here](https://tabby.tabbyml.com/docs/extensions/installation/vscode/).\n\n7. Open any code file and experience the fast heterogeneous inference of ktransformers."
  },
  {
    "path": "doc/en/api/server/website.md",
    "content": "# Start with website\n\nThis document provides the necessary steps to set up and run the web service for this project.\n\n## 1. Starting the Web Service\n\n### 1.1. Compiling the Web Code\n\nBefore you can compile the web code, make sure you have installed [Node.js](https://nodejs.org) version 18.3 or higher\n\nNote: The version of Node.js in the Ubuntu or Debian GNU/Linux software repository is too low, causing compilation errors. Users can also install Node.js through the Nodesource repository, provided they uninstall the outdated version first.\n\n```bash\n\n  # sudo apt-get remove nodejs npm -y && sudo apt-get autoremove -y\n  sudo apt-get update -y && sudo apt-get install -y apt-transport-https ca-certificates curl gnupg\n  curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | sudo gpg --dearmor -o /usr/share/keyrings/nodesource.gpg\n  sudo chmod 644 /usr/share/keyrings/nodesource.gpg\n  echo \"deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/nodesource.gpg] https://deb.nodesource.com/node_23.x nodistro main\" | sudo tee /etc/apt/sources.list.d/nodesource.list\n  sudo apt-get update -y\n  sudo apt-get install nodejs -y\n\n```\n\nOnce npm is installed, navigate to the `ktransformers/website` directory:\n\n```bash\ncd ktransformers/website\n```\n\nNext, install the Vue CLI with the following command:\n\n```bash\nnpm install @vue/cli\n```\n\nNow you can build the project:\n\n```bash\nnpm run build\n```\nFinally you can build ktransformers with website:\n```\ncd ../../\npip install .\n```\n"
  },
  {
    "path": "doc/en/balance-serve.md",
    "content": "# Balance Serve backend (multi-concurrency) for ktransformers\n\n## KTransformers v0.2.4 Release Notes\n\nWe are excited to announce the official release of the long-awaited **KTransformers v0.2.4**!\nIn this version, we’ve added highly desired **multi-concurrency** support to the community through a major refactor of the whole architecture, updating more than 10,000 lines of code.\nBy drawing inspiration from the excellent architecture of sglang, we have implemented high-performance asynchronous concurrent scheduling in C++, including features like continuous batching, chunked prefill, and more. Thanks to GPU sharing in concurrent scenarios, overall throughput is also improved to a certain extent. The following is a demonstration:\n\nhttps://github.com/user-attachments/assets/faa3bda2-928b-45a7-b44f-21e12ec84b8a\n\n</p>\n\n### 🚀 Key Updates\n\n1. Multi-Concurrency Support\n   - Added capability to handle multiple concurrent inference requests. Supports receiving and executing multiple tasks simultaneously.\n   - We implemented [custom_flashinfer](https://github.com/kvcache-ai/custom_flashinfer/tree/fix-precision-mla-merge-main) based on the high-performance and highly flexible operator library [flashinfer](https://github.com/flashinfer-ai/flashinfer/), and achieved a variable batch size CUDA Graph, which further enhances flexibility while reducing memory and padding overhead.\n   - In our benchmarks, overall throughput improved by approximately 130% under 4-way concurrency.\n   - With support from Intel, we tested KTransformers v0.2.4 on the latest Xeon6 + MRDIMM-8800 platform. By increasing concurrency, the total output throughput increased from 17 tokens/s to 40 tokens/s. We observed that the bottleneck has now shifted to the GPU. Using a higher-end GPU than the 4090D could further improve performance.\n2. 
Engine Architecture Optimization\n   ![image](https://github.com/user-attachments/assets/f5f001fa-dca7-4377-a01a-32192902aa47)\n   Inspired by the scheduling framework of sglang, we refactored KTransformers with a clearer three-layer architecture through an update of 11,000 lines of code, now supporting full multi-concurrency:\n   - Server: Handles user requests and serves the OpenAI-compatible API.\n   - Inference Engine: Executes model inference and supports chunked prefill.\n   - Scheduler: Manages task scheduling and request orchestration. Supports continuous batching by organizing queued requests into batches in an FCFS manner and sending them to the inference engine.\n3. Project Structure Reorganization\n   All C/C++ code is now centralized under the /csrc directory.\n4. Parameter Adjustments\n   Removed some legacy and deprecated launch parameters for a cleaner configuration experience.\n   We plan to provide a complete parameter list and detailed documentation in future releases to facilitate flexible configuration and debugging.\n\n### 📚 Upgrade Notes\n\n- Due to parameter changes, users who have installed previous versions are advised to delete the ~/.ktransformers directory and reinitialize.\n- To enable multi-concurrency, please refer to the latest documentation for configuration examples.\n\n### What's Changed\n\n- Implemented **custom_flashinfer** @Atream @ovowei @qiyuxinlin\n- Implemented **balance_serve** engine based on **FlashInfer** @qiyuxinlin @ovowei\n- Implemented a **continuous batching** scheduler in C++ @ErvinXie\n- release: bump version v0.2.4 by @Atream @Azure-Tang @ErvinXie  @qiyuxinlin @ovowei @KMSorSMS @SkqLiao\n\n## Download the Docker image for testing v0.2.4\nVisit the [link](https://hub.docker.com/r/approachingai/ktransformers/tags) to pull the image, using `v0.2.4-AVX512` as an example.\n\n```bash\ndocker pull approachingai/ktransformers:v0.2.4-AVX512\ndocker run -it --gpus all --privileged --shm-size 64g --name ktrans --network=host -v 
/mnt:/mnt approachingai/ktransformers:v0.2.4-AVX512 /bin/bash\n# Open a new terminal\ndocker exec -it ktrans bash\n```\n\n## Installation Guide\n\n⚠️ Please note that installing this project will replace flashinfer in your environment. It is strongly recommended to create a new conda environment!!!\n\n### 1. Set Up Conda Environment\n\nWe recommend using Miniconda3/Anaconda3 for environment management:\n\n```bash\n# Download Miniconda\nwget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh\n\n# Create environment\nconda create --name ktransformers python=3.11\nconda activate ktransformers\n\n# Install required libraries\nconda install -c conda-forge libstdcxx-ng\n\n# Verify GLIBCXX version (should include 3.4.32)\nstrings ~/anaconda3/envs/ktransformers/lib/libstdc++.so.6 | grep GLIBCXX\n```\n\n> **Note:** Adjust the Anaconda path if your installation directory differs from `~/anaconda3`\n\n### 2. Install dependencies\n\n```bash\nsudo apt install libtbb-dev libssl-dev libcurl4-openssl-dev libaio1 libaio-dev libfmt-dev libgflags-dev zlib1g-dev patchelf\npip3 install packaging ninja cpufeature numpy openai\npip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126\n\n```\n\n### 3. Build ktransformers\n\n```bash\n# Clone repository\ngit clone https://github.com/kvcache-ai/ktransformers.git\ncd ktransformers\ngit submodule update --init --recursive\n\n\n# Install for a single NUMA node\nUSE_BALANCE_SERVE=1  bash ./install.sh\n# For machines with two CPUs and 1 TB of RAM (dual NUMA):\nUSE_BALANCE_SERVE=1 USE_NUMA=1 bash ./install.sh\n```\n\n## Running DeepSeek-R1-Q4KM Models\n\n### 1. 
Run for 24GB VRAM GPUs\n\nUse our optimized configuration for constrained VRAM:\n\n```bash\npython ktransformers/server/main.py \\\n  --port 10002 \\\n  --model_path <path_to_safetensor_config> \\\n  --gguf_path <path_to_gguf_files> \\\n  --optimize_config_path ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-serve.yaml \\\n  --max_new_tokens 1024 \\\n  --cache_lens 32768 \\\n  --chunk_size 256 \\\n  --max_batch_size 4 \\\n  --backend_type balance_serve \\\n  --force_think # useful for R1\n```\n\nIt features the following arguments:\n\n- `--max_new_tokens`: Maximum number of tokens generated per request.\n- `--cache_lens`: Total length of the kvcache allocated by the scheduler. All requests share one kvcache space, here corresponding to 32768 tokens, and the space a request occupies is released after the request completes.\n- `--max_batch_size`: Maximum number of requests (prefill + decode) processed in a single run by the engine. (Supported only by `balance_serve`)\n- `--chunk_size`: Maximum number of tokens processed in a single run by the engine.\n- `--backend_type`: `balance_serve` is a multi-concurrency backend engine introduced in version v0.2.4. The original single-concurrency engine is `ktransformers`.\n- `--model_path`: Path to the safetensors config directory (only the config is required, not the model safetensors).  \n  Please note that, since `ver 0.2.4`, the last segment of the `${model_path}` path **MUST** be a local directory that contains the model's configuration files. Hugging Face links (e.g., deepseek-ai/DeepSeek-R1) are not supported at the moment.\n- `--force_think`: Force the response to include the reasoning tag of `DeepSeek R1`.\n\nThe relationship between `max_batch_size`, `cache_lens`, and `max_new_tokens` should satisfy:\n`cache_lens > max_batch_size * max_new_tokens`, otherwise the concurrency will decrease.\n\n### 2. 
Access the server\n\n```bash\ncurl -X POST http://localhost:10002/v1/chat/completions \\\n  -H \"accept: application/json\" \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"messages\": [\n      {\"role\": \"user\", \"content\": \"hello\"}\n    ],\n    \"model\": \"DeepSeek-R1\",\n    \"temperature\": 0.3,\n    \"top_p\": 1.0,\n    \"stream\": true\n  }'\n```\n"
  },
  {
    "path": "doc/en/benchmark.md",
    "content": "## Benchmark\n\nTo conduct a quick and convenient check, we have employed a simple Python script available [here](https://github.com/kvcache-ai/ktransformers/tree/main/ktransformers/tests) to assess the precision of our **[ktransformers](https://github.com/kvcache-ai/ktransformers)** project. For this evaluation, we utilized the same dataset, which was shuffled in a consistent manner and limited to the first 1,000 data points, to test our implementation across a variety of CPU kernels, MLA kernels, and quantization formats.\n\nWe selected the DeepSeek-V3 model in its bf16, int8, and q4km versions for this test. The MMLU dataset, which can be found [here](https://huggingface.co/datasets/cais/mmlu), was used (we selected all datasets and shuffled them with a fixed random seed).\n\n**!!! However, we skipped the few-shot part and only chose the first 1,000 data points for a quick check.** Please note that this approach may result in results that are not consistent with the technical report of DeepSeek-V3. And the test of R1 and further more tests are on going.\n\nTo verify our results, we chose [cloud service platform](https://cloud.siliconflow.cn/models) as baseline. All tests were conducted using the same script and datasets, allowing us to make a preliminary assessment of our project's precision.\n\nWe set the argument `temperature=0.6`, and to simplify the test process, we skipped the few-shot part and used the following prompt: `There is a single choice question. Answer the question by replying A, B, C, D. No other answers are accepted. Just the letter. \\nQuestion: {question}\\nA. {option_a}\\nB. {option_b}\\nC. {option_c}\\nD. {option_d}\\nAnswer: '`. For more details, please refer to the [script](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/tests/mmlu_test.py).\n\nGiven that we have only tested 1,000 cases, which provides only a preliminary judgment, some fluctuations in the results are reasonable. 
We selected all datasets and shuffled them with a fixed random seed to ensure consistency.\n\n## Some Details\n\n- The bf16 model of DeepSeek-V3 is available [here](https://huggingface.co/opensourcerelease/DeepSeek-V3-bf16/tree/main) (you may convert it to gguf by llama.cpp). The q4km model can be found [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M).\n    \n- The optimization YAML file is located [here](https://github.com/kvcache-ai/ktransformers/tree/main/ktransformers/optimize/optimize_rules). For the GEMM Kernel, you can change `KLinearMarlin` to `KLinearTorch`.\n    \n- To switch the MLA Kernel from Triton to Torch, you can check and modify [this file](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/operators/attention.py), specifically by using the `forward_windows` method.\n    \n- When attempting to conduct the bf16 test (both CPU Weight and GPU Weight), you may encounter issues stemming from older versions of g++ and as, particularly when using Ubuntu 20 or earlier versions. To facilitate a smoother experience and enable you to reproduce our results, we have provided a development container. This container offers a pre-configured environment tailored for this purpose. However, please note that the container does not have the ktrans package installed. 
Therefore, you may still need to manually install certain packages to ensure everything runs smoothly.\n    \n    - You can configure the model mount directory in `devcontainer/devcontainer.json`; check the `\"mounts\":` setting.\n\n\n## The Result Table\n\nAll cases use the DeepSeek-V3 model unless marked as DeepSeek-R1.\n\n| DataSet | CPU Weight Format | CPU Kernel | GPU Weight Format | GEMM Kernel | MLA Kernel | [Siliconflow](https://cloud.siliconflow.cn/models) | Ktrans Point |\n| --- | --- | --- | --- | --- | --- | --- | --- |\n| MMLU<br>(shuffle 1k) |  |  |  |  |  |  |  |\n| 1 | bf16 | cpuinfer | bf16 | torch | torch | 81.6 | 81.9 |\n| 2 | q8_0 | cpuinfer | bf16 | torch | torch | 81.6 | 83.1 |\n| 3 | q4km | cpuinfer | bf16 | torch | triton | 81.6 | 81.4 |\n| 4 | q4km | cpuinfer | q4km->marlin 8 | marlin | triton | 81.6 | 81.1 |\n| 5 | q4km | cpuinfer | q4km->marlin 4 | marlin | triton | 81.6 | 81 |\n| 6 | q4km | cpuinfer | fp8 | fp8gemm | triton | 81.6 | 81.5 |\n| 7 (DeepSeek-R1) | iq1 | cpuinfer | fp8 | fp8gemm | triton | 78.6 | 83.6 |\n| MMLU-pro<br>(shuffle 1k) |  |  |  |  |  |  |  |\n| 1 | q4km | cpuinfer | fp8 | fp8gemm | triton | 57.7 | 57.6 |\n| 2 | q4km | cpuinfer | q4km->marlin 4 | marlin | triton | 57.7 | 57.5 |\n| 3 (DeepSeek-R1) | iq1 | cpuinfer | fp8 | fp8gemm | triton | 71.9 | tbd |\n| HumanEval | tbd | tbd | tbd | tbd | tbd | tbd | tbd |\n| GSM8K | tbd | tbd | tbd | tbd | tbd | tbd | tbd |\n\n**The details for each case are listed below**:\n\nBy default, the MLA kernel uses Triton on Linux and Torch on Windows. To test Torch on Linux, we manually modified this [file](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/operators/attention.py#L592): remove all the `if` branches and force it to use `self.forward_windows`.\n\n- MMLU test\n  1. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml): change every `KLinearMarlin` to `KLinearTorch` (search for all usages in this file). 
The source weight comes from [there](https://huggingface.co/opensourcerelease/DeepSeek-V3-bf16) (you need to use llama.cpp to convert it to gguf)\n  2. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml). You need to modify the code to separately load cpu's expert weight. We leave this as comment in these places: [1](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/operators/experts.py#L122), [2](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/operators/experts.py#L136), [3](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/operators/experts.py#L137) (note in 3, change the path to your local weight file path). The weight file for q8_0 is [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q8_0)\n  3. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml). You need to modify the code to separately load cpu's expert weight. We leave this as comment in these places: [1](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/operators/experts.py#L122), [2](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/operators/experts.py#L136), [3](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/operators/experts.py#L137) (note in 3, change the path to your local weight file path). The weight file for q4km is [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M)\n  4. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml). You don't need to change the source code as they both use q4km. 
But note the yaml file [here](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml#L29) and [here](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml#L18), below these lines you need to add `num_bits: 8` (in other words: add this kwargs to all that use `KLinearMarlin`). The weight file for q4km is [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M)\n  5. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml). No need to change yaml, just use the default. The weight file for q4km is [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M)\n  6. You should check the [doc](./fp8_kernel.md) to learn how to test this case. This is a mixture tensor case.\n  7. You should check the [doc](./fp8_kernel.md) to learn how to test this case. This is a mixture tensor case.\n- MMLU-pro test\n  1. You should check the [doc](./fp8_kernel.md) to learn how to test this case. This is a mixture tensor case. \n  2. [v3-chat_yaml](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml). No need to change yaml, just use the default. The weight file for q4km is [here](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M)\n  3. You should check the [doc](./fp8_kernel.md) to learn how to test this case. This is a mixture tensor case."
  },
  {
    "path": "doc/en/deepseek-v2-injection.md",
    "content": "# Tutorial: Heterogeneous and Local MoE Inference\n\nDeepSeek-(Code)-V2 is a series of strong mixture-of-experts (MoE) models, featuring a total of 236 billion parameters, with 21 billion parameters activated per token. This model has demonstrated remarkable reasoning capabilities across various benchmarks, positioning it as one of the SOTA open models and nearly comparable in performance to GPT-4. DeepSeek-R1 uses a similar architecture to DeepSeek-V2, but with a bigger number of parameters.\n\n<p align=\"center\">\n  <picture>\n    <img alt=\"DeepSeek-Coder-V2 Score\" src=\"../assets/BigCodeBench.png\" width=80%>\n  </picture>\n</p>\n\nMoreover, unlike previous models that employed traditional attention mechanisms like Grouped-Query Attention (GQA), DeepSeek-V2 incorporates a novel Multi-head Latent Attention (MLA). This innovation significantly reduces the size of the KV cache required during inference, enhancing efficiency.\n\n\nHowever, despite its efficiency, the practicality of running such a large model on personal computing setups seems impractical. Official documentation for DeepSeek-V2 indicates that eight 80GB GPUs are necessary for standard inference operations, and even the scaled-down Q4_k_m version requires at least two 80GB GPUs. These requirements are beyond the reach of most individual researchers and small teams.\n\n\nNonetheless, by employing several cutting-edge optimization techniques, we have successfully operated this colossal model on a desktop computer with only 21GB of VRAM and 136GB of DRAM. In this document, we outline the specific optimizations utilized and provide a detailed tutorial on how to implement these strategies using KTransformers.\n\n## Applied Optimizations\n\n### Optimized MLA Operator\n\nThe following figure provides a brief overview of DeepSeek-V2 architecture. 
At the heart of its attention layer, DeepSeek-V2 introduces a novel MLA operator that represents the heads of key-value pairs using a common, joint compressed representation, which holds significant potential for efficiency improvements. However, the official open-source implementation of the MLA operator explicitly decompresses this compressed representation and caches the decompressed key-value pairs. This process not only enlarges the KV cache size but also diminishes inference performance.\n\n<p align=\"center\">\n  <picture>\n    <img alt=\"DeepSeek on KTransformers\" src=\"../assets/DeepSeek-on-KTransformers.png\" width=80%>\n  </picture>\n</p>\n\nTo truly capitalize on the benefits of MLA, we have implemented an optimized version for inference. According to its original paper, we absorb the decompression matrices directly into the q_proj and out_proj weights. Consequently, the compressed representation does not need to be decompressed to compute the attention. This adjustment significantly reduces the KV cache size and increases the arithmetic intensity of this operator, which greatly optimizes the utilization of GPU computational power.\n\n### Advanced Quantization Kernels\n\nThe original DeepSeek-V2 model stores its parameters in BF16 format, consuming approximately 470GB of raw storage. This exceeds the RAM capacity available on mainstream desktop computers. To address this, we leverage the well-established GGUF community's quantized weights to simplify the process for users.\nHowever, quantized data types are not typically supported by highly-optimized BLAS packages. As a result, the original HuggingFace Transformers' Torch implementation must dequantize these tensors to supported data types before processing, which introduces unnecessary computational overhead and increases memory traffic. 
To overcome this, we have incorporated advanced kernels that operate directly on quantized data types, thereby optimizing inference performance.\n\n\nIn the current version of KTransformers, we use Marlin for GPU kernels and llamafile for CPU kernels. These kernels are designed to benefit from modern GPU architectures and modern CPU instruction extensions such as AVX512-BF16 (AMD Zen4 or newer) and AVX-VNNI (Intel Alder Lake or newer), which are tailored for quantized data types and machine learning workloads. We also apply expert parallelism and other optimizations for MoE inference on the CPU based on llamafile, and call the result CPUInfer. As demonstrated in Figure 2 (cited from Marlin), Marlin can achieve a near-ideal 3.87x speedup compared to the corresponding Torch counterparts. As demonstrated in the following figure, our micro-benchmarks show that inference using CPUInfer is several times faster than Torch for low-bit representations. Note that in practical inference, such as with Transformers, the Torch baseline uses BF16 or FP16 linear weights and therefore occupies more memory, or it is slower due to dequantization when using quantized weights.\n\n<p align=\"center\">\n  <picture>\n    <img alt=\"CPUInfer Performance\" src=\"../assets/cpuinfer.png\" width=80%>\n  </picture>\n</p>\n<p align=\"center\">\n  <picture>\n    <img alt=\"marlin performance\" src=\"https://github.com/IST-DASLab/marlin/blob/master/assets/sustained.png?raw=true\" width=80%>\n  </picture>\n</p>\n\n### Arithmetic Intensity Guided Offloading\n\nStoring all 236 billion parameters of a model in GPU VRAM is clearly impractical for local users. Therefore, we strategically store only the most computationally intensive parameters on the GPU. For instance, after our optimizations, the MLA operator, which contains 128 heads with a shared compressed key-value representation, shows an arithmetic intensity of 512. 
This makes it the most intensive operator, particularly during smaller inference batch sizes. Hence, it is allocated to the GPU to leverage the power of tensor cores.\n\n\nOn the other hand, as shown in Figure 1, each transformer block in DeepSeek-V2 includes 160 mixture-of-experts (MoE) experts, comprising 96% of the total parameters. However, the MoE router activates only 6 out of these 160 experts for each token, which means that only 3.75% of the MoE parameters are utilized during the decoding phase. With a batch size of one, the arithmetic intensity of the MoE operation is roughly 0.075. This operation, primarily involving a batched General Matrix-Vector Multiplication (GEMV), can thus be efficiently handled by the CPU.\n\n\nFollowing this principle of arranging all operators by their arithmetic intensity and placing the most intensive ones in the GPU as much as possible, we prioritize positioning the MoE parameters and word embeddings computations on the CPU side to utilize its larger memory capacity. Meanwhile, the remaining parameters, including shared experts, projections in the attention module, and MLA, are stored in the GPU VRAM. As these parameters are accessed by every token, their placement on the GPU maximizes the benefits of high memory bandwidth. This configuration leads to approximately 20.7 GB of VRAM usage and 136GB DRAM memory requests if the Q4_K_M version is used, which is feasible even on a local desktop. Additionally, the placement can be adjusted according to the actual configuration, adhering to the same principle.\n\n\nMoreover, as an extensible framework, KTransformers is set to support more advanced operators in future releases, continually enhancing its capability to handle diverse workloads efficiently.\n\n## YAML Template\n\nTo implement the above optimizations in KTransformers, users need to write a YAML file containing the optimized rules. 
\nKTransformers will iterate through all sub-modules of the model, match rules specified in the YAML rule file, and replace them with advanced modules as specified.\n\n<p align=\"center\">\n  <picture>\n    <img alt=\"Inject-Struction\" src=\"../assets/InjectStruction.png\" width=80%>\n  </picture>\n</p>\n\nSpecifically, the following rules are used:\n\n- Replace the Attention module with our [optimized MLA Operator](#mla).\n- Replace routed experts with [CPUInfer kernels](#experts) that use Llamafile.\n- Replace all Linear modules not belonging to attention with [Marlin](#linear) kernels.\n\n\n\n<h3 id=\"mla\">MLA</h3>\n\nFor attention module injection, we only need to match the module name used in Transformers using a regular expression and replace it with our pre-implemented module. \nThe YAML rule is listed below.\n\n```yaml\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\" # regular expression\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n```\n\nAs we can see, each rule in the YAML file has two parts: `match` and `replace`. \nThe match part specifies which module should be replaced, and the replace part specifies the module to be injected into the model along with the initialization keywords. \n\n<h3 id=\"experts\">Routed Experts</h3>\n\nFor routed experts, the module we inject is a wrapper of CPUInfer, KTransformersExperts. There are several implementations within the wrapper, and we need to specify keywords to tell the wrapper which implementation we want to use and how we intend to use it.\n\nIn KTransformers, some modules behave differently during prefilling and generation for better performance. KTransformersExperts is one of them. All these special modules have a `device` keyword describing which device the module should be initialized on. Other keywords specify the behaviors during prefilling and generation and may differ between injection modules. 
Here, we specify which implementation on which device we want to use during prefilling and generation, and which device the output should be on.\nNote that we only use these parameters when layer-wise prefilling is enabled; otherwise, prefilling is conducted with the same configuration as generation.\n\nIn the original implementation of Transformers, MoE is implemented using `nn.ModuleList`. We don't want KTransformers to iterate through all the sub-modules in the list, so we set `recursive: False` in this rule to prevent recursive injection into submodules of the current module. Here is the YAML rule:\n\n```yaml\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert parallelism\n    device: \"cpu\"   # device to load this module on initialization\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n```\n\nIf we inject the expert list as a custom module, we can't use the interface in `nn.ModuleList` as default. We need to change the forward function in the FFN module. The simplest way is implementing a new module using custom forward function and inject it. We have implemented the new module, and the injection can be done by simply adding an injection rule. We can use the `class` instead of `name` to match a module that will be replaced. Here is the YAML rule:\n\n```yaml\n- match:\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # MLP module with custom forward function\n```\n\n<h3 id=\"linear\">Other Linear Modules</h3>\n\nFor the remained linear modules, we want to use our quantization kernels. 
However, we don't want to inject linear in the MLA operator because we currently don't know the effect of using quantization in MLA. \nSo, we can change our regular expression and add a class check in the match part of the rule. Only modules matching both name and class simultaneously will be injected. \nWe also need to transfer some keywords similar to the injection of experts. Here is the YAML rule:\n\n```yaml\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n```\n\n<h3 id=\"Pre-compute Buffers\">Pre-compute Buffers </h3>\n\nThe original model is initialized on the meta device. The rotary embedding module pre-computes some buffers when initializing, which has no effect and doesn't compute anything when using the meta device. Therefore, we need to compute the buffers when loading the model. For convenience, we inject the rotary embedding module with our custom module, which performs pre-computations when loading. Here is the YAML rule:\n\n```yaml\n- match:\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n```\n\n## Wrap Your Custom Module\n\nWe have implemented some modules, but you may need to inject your custom module using KTransformers. \nThe only thing you need to do is wrap your custom module and write YAML files. We provide a base operator specifying interfaces an injection module should have. 
You only need to inherit from that module and change the `__init__`, `forward`, or `load` function as needed.\n\n- The `__init__` function of the base operator maintains the necessary information for injection and execution of the KTransformers framework. To override this function, subclass modules need to call the base operator's `__init__` function in their own initializer.\n- The `forward` function is a function in torch that will be called during inference, where the module author has the freedom to achieve higher performance.\n- The `load` function is used to load all parameters of this module. The default implementation is to call the `load` function of all submodules. You can modify this function to customize its loading method and explicitly control the loading of its submodules.\n\n"
  },
  {
    "path": "doc/en/fp8_kernel.md",
    "content": "# FP8 Linear Kernel for DeepSeek-V3/R1\n\n## Overview\nThe DeepSeek-AI team provides FP8 safetensors for DeepSeek-R1/V3 models. We achieve performance optimization through the following works:\n- **FP8 GPU Kernel Integration**: FP8 linear layer acceleration kernels integrated in KTransformers\n- **Hybrid Quantization Architecture**:\n  - Attention and Shared-Expert modules use FP8 precision (enhances computational accuracy)\n  - Experts modules retain GGML quantization (GGUF format, reside in CPU to save GPU memory)\n\nSo those who are persuing the best performance can use the FP8 linear kernel for DeepSeek-V3/R1.\n\n## Key Features\n\n✅ Hybrid Precision Architecture (FP8 + GGML)<br>\n✅ Memory Optimization (~19GB VRAM usage)\n\n## Quick Start\n### Using Pre-Merged Weights\n\nPre-merged weights are available on Hugging Face:<br>\n[KVCache-ai/DeepSeek-V3-GGML-FP8-Hybrid](https://huggingface.co/KVCache-ai/DeepSeek-V3)<br>\n[KVCache-ai/DeepSeek-R1-GGML-FP8-Hybrid](https://huggingface.co/KVCache-ai/DeepSeek-R1)\n\n> Please confirm the weights are fully uploaded before downloading. 
The large file size may extend Hugging Face upload time.\n\n\nDownload the pre-merged weights:\n```shell\npip install -U huggingface_hub\n\n# Optional: use the HF mirror for faster downloads in some regions.\n# export HF_ENDPOINT=https://hf-mirror.com\n\nhuggingface-cli download --resume-download KVCache-ai/DeepSeek-V3-GGML-FP8-Hybrid --local-dir <local_dir>\n```\n### Using merge scripts\nIf you have local DeepSeek-R1/V3 FP8 safetensors and GGUF weights (e.g., q4km), you can merge them using the following script.\n\n```shell\npython merge_tensors/merge_safetensor_gguf.py \\\n  --safetensor_path <fp8_safetensor_path> \\\n  --gguf_path <gguf_folder_path> \\\n  --output_path <merged_output_path>\n```\n\n* `--safetensor_path`: input path of the safetensor file ([Download](https://huggingface.co/deepseek-ai/DeepSeek-V3/tree/main)).\n* `--gguf_path`: input path of the GGUF folder ([Download](https://huggingface.co/unsloth/DeepSeek-V3-GGUF/tree/main/DeepSeek-V3-Q4_K_M)).\n* `--output_path`: output path of the merged file.\n\n\n### Execution Notes\n\nLaunch local_chat.py with custom quantized experts:\n```shell\npython ktransformers/local_chat.py \\\n  --model_path deepseek-ai/DeepSeek-V3 \\\n  --gguf_path <merged_weights_folder> \\\n  --optimize_config_path ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml \\\n  --cpu_infer <cpu_cores + 1>\n```\n\n\n## Notes\n\n⚠️ Hardware Requirements<br>\n* At least 19GB of available VRAM is recommended for the FP8 kernel.\n* Requires a GPU with FP8 support (e.g., 4090).\n\n⏳ First-Run Optimization<br>\nJIT compilation causes longer initial execution (subsequent runs retain optimized speed).\n\n🔄 Temporary Interface<br>\nThe current weight loading implementation is provisional and will be refined in future versions.\n\n📁 Path Specification<br>\nDespite hybrid quantization, merged weights are stored as .safetensors files; pass the containing folder path to `--gguf_path`."
  },
  {
    "path": "doc/en/install.md",
    "content": "<!-- omit in toc -->\n\n# How to Run DeepSeek-R1\n\n- [How to Run DeepSeek-R1](#how-to-run-deepseek-r1)\n  - [Preparation](#preparation)\n  - [Installation](#installation)\n    - [Attention](#attention)\n    - [Supported models include](#supported-models-include)\n    - [Support quantize format](#support-quantize-format)\n\nIn this document, we will show you how to install and run KTransformers on your local machine. There are two versions:\n\n* V0.2 is the current main branch.\n* V0.3 is a preview version only provides binary distribution for now.\n* To reproduce our DeepSeek-R1/V3 results, please refer to [Deepseek-R1/V3 Tutorial](./DeepseekR1_V3_tutorial.md) for more detail settings after installation.\n\n## Preparation\n\nSome preparation:\n\n- CUDA 12.1 and above, if you didn't have it yet, you may install from [here](https://developer.nvidia.com/cuda-downloads).\n\n  ```sh\n  # Adding CUDA to PATH\n  if [ -d \"/usr/local/cuda/bin\" ]; then\n      export PATH=$PATH:/usr/local/cuda/bin\n  fi\n\n  if [ -d \"/usr/local/cuda/lib64\" ]; then\n      export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib64\n      # Or you can add it to /etc/ld.so.conf and run ldconfig as root:\n      # echo \"/usr/local/cuda-12.x/lib64\" | sudo tee -a /etc/ld.so.conf\n      # sudo ldconfig\n  fi\n\n  if [ -d \"/usr/local/cuda\" ]; then\n      export CUDA_PATH=$CUDA_PATH:/usr/local/cuda\n  fi\n  ```\n- Linux-x86_64 with gcc, g++>=11 and cmake>=3.25 (using Ubuntu as an example)\n- **Note**: The default CMake version in Ubuntu 22.04 LTS or higher may not support newer CUDA language dialects (e.g., CUDA 20). This can cause errors such as Target \"cmTC_xxxxxx\" requires the language dialect \"CUDA20\", but CMake does not know the compile flags to use to enable it. 
To resolve this, install a newer CMake version, for instance, by adding the Kitware APT repository.\n\n  ```sh\n  sudo apt-get update \n  sudo apt-get install build-essential cmake ninja-build patchelf\n  ```\n- We recommend using [Miniconda3](https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh) or [Anaconda3](https://repo.anaconda.com/archive/Anaconda3-2024.10-1-Linux-x86_64.sh) to create a virtual environment with Python=3.11 to run our program. Assuming your Anaconda installation directory is `~/anaconda3`, you should ensure that the version identifier of the GNU C++standard library used by Anaconda includes `GLIBCXX_3.4.32`\n\n  ```sh\n  conda create --name ktransformers python=3.11\n  conda activate ktransformers # you may need to run ‘conda init’ and reopen shell first\n\n  conda install -c conda-forge libstdcxx-ng # Anaconda provides a package called `libstdcxx-ng` that includes a newer version of `libstdc++`, which can be installed via `conda-forge`.\n\n  strings ~/anaconda3/envs/ktransformers/lib/libstdc++.so.6 | grep GLIBCXX\n  ```\n- Make sure that PyTorch, packaging, ninja is installed You can also [install previous versions of PyTorch](https://pytorch.org/get-started/previous-versions/)\n\n  ```\n  pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126\n  pip3 install packaging ninja cpufeature numpy\n  ```\n- At the same time, you should download and install the corresponding version of flash-attention from https://github.com/Dao-AILab/flash-attention/releases.\n\n## Installation\n\n### Attention\n\nIf you want to use numa support, not only do you need to set USE_NUMA=1, but you also need to make sure you have installed the libnuma-dev (`sudo apt-get install libnuma-dev` may help you).\n\n[Optional] If you want to use the multi-concurrent version, please install the following dependencies.\n\n```\nsudo apt install libtbb-dev libssl-dev libcurl4-openssl-dev libaio1 libaio-dev libgflags-dev 
zlib1g-dev libfmt-dev\n```\n\n<!-- 1. ~~Use a Docker image, see [documentation for Docker](./doc/en/Docker.md)~~\n   \n   >We are working on the latest docker image, please wait for a while.\n\n2. ~~You can install using Pypi (for linux):~~\n    > We are working on the latest pypi package, please wait for a while.\n   \n   ```\n   pip install ktransformers --no-build-isolation\n   ```\n   \n   for windows we prepare a pre compiled whl package on [ktransformers-0.2.0+cu125torch24avx2-cp312-cp312-win_amd64.whl](https://github.com/kvcache-ai/ktransformers/releases/download/v0.2.0/ktransformers-0.2.0+cu125torch24avx2-cp312-cp312-win_amd64.whl), which require cuda-12.5, torch-2.4, python-3.11, more pre compiled package are being produced.  -->\n\nDownload the source code and compile:\n\n- Initialize the source code\n\n  ```sh\n  git clone https://github.com/kvcache-ai/ktransformers.git\n  cd ktransformers\n  git submodule update --init --recursive\n  ```\n- [Optional] If you want to run with the website, please [compile the website](./api/server/website.md) before executing ``bash install.sh``\n- For Linux\n\n  - For a simple install:\n\n    ```shell\n    bash install.sh\n    ```\n  - For systems with two CPUs and 1T RAM:\n\n    ```shell\n    # Make sure your system has dual sockets and RAM twice the model's size (e.g. 1T RAM for a 512G model)\n    apt install libnuma-dev\n    export USE_NUMA=1\n    bash install.sh # or: make dev_install\n    ```\n  - For multi-concurrency with 500G RAM:\n\n    ```shell\n    USE_BALANCE_SERVE=1 bash ./install.sh\n    ```\n  - For multi-concurrency with two CPUs and 1T RAM:\n\n    ```shell\n    USE_BALANCE_SERVE=1 USE_NUMA=1 bash ./install.sh\n    ```\n- For Windows (Windows native temporarily deprecated, please try WSL)\n\n  ```shell\n  install.bat\n  ```\n\n* If you are a developer, you can make use of the makefile to compile and format the code. 
<br> The detailed usage of the makefile is [here](./makefile_usage.md)\n\n<h3>Local Chat</h3>\nWe provide a simple command-line local chat Python script that you can run for testing.\n\n> Note: this is a very simple test tool that only supports single-round chat, with no memory of the previous input. If you want to try the full ability of the model, go to [RESTful API and Web UI](#id_666).\n\n<h4>Run Example</h4>\n\n```shell\n# Begin from the root of your cloned repo!\n\n# Download mzwing/DeepSeek-V2-Lite-Chat-GGUF from huggingface\nmkdir DeepSeek-V2-Lite-Chat-GGUF\ncd DeepSeek-V2-Lite-Chat-GGUF\n\nwget https://huggingface.co/mradermacher/DeepSeek-V2-Lite-GGUF/resolve/main/DeepSeek-V2-Lite.Q4_K_M.gguf -O DeepSeek-V2-Lite-Chat.Q4_K_M.gguf\n\ncd .. # Move to repo's root dir\n\n# Start local chat\npython -m ktransformers.local_chat --model_path deepseek-ai/DeepSeek-V2-Lite-Chat --gguf_path ./DeepSeek-V2-Lite-Chat-GGUF\n\n# If you see “OSError: We couldn't connect to 'https://huggingface.co' to load this file”, try:\n# GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite\n# python -m ktransformers.local_chat --model_path ./DeepSeek-V2-Lite --gguf_path ./DeepSeek-V2-Lite-Chat-GGUF\n```\n\nIt features the following arguments:\n\n- `--model_path` (required): Name of the model (such as \"deepseek-ai/DeepSeek-V2-Lite-Chat\", which will automatically download configs from [Hugging Face](https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite)). If you already have local files, you may directly use that path to initialize the model.\n\n  > Note: <strong>.safetensors</strong> files are not required in the directory. We only need config files to build the model and tokenizer.\n  >\n- `--gguf_path` (required): Path of a directory containing GGUF files, which can be downloaded from [Hugging Face](https://huggingface.co/mzwing/DeepSeek-V2-Lite-Chat-GGUF/tree/main). 
Note that the directory should only contain the GGUF files of the current model, which means you need one separate directory for each model.\n- `--optimize_config_path` (required except for Qwen2Moe and DeepSeek-V2): Path of the YAML file containing optimize rules. There are two rule files pre-written in the [ktransformers/optimize/optimize_rules](https://github.com/kvcache-ai/ktransformers/tree/main/ktransformers/optimize/optimize_rules) directory for optimizing DeepSeek-V2 and Qwen2-57B-A14B, two SOTA MoE models.\n- `--max_new_tokens`: Int (default=1000). Maximum number of new tokens to generate.\n- `--cpu_infer`: Int (default=10). The number of CPUs used for inference. Should ideally be set to the (total number of cores - 2).\n\n<h3>Start Server</h3>\nWe provide a server script, which supports multi-concurrency functionality in version v0.2.4.\n\n```\npython ktransformers/server/main.py --model_path /mnt/data/models/DeepSeek-V3 --gguf_path /mnt/data/models/DeepSeek-V3-GGUF/DeepSeek-V3-Q4_K_M/ --cpu_infer 62 --optimize_config_path ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-serve.yaml --port 10002 --chunk_size 256 --max_new_tokens 1024 --max_batch_size 4 --cache_lens 32768 --backend_type balance_serve\n```\n\nIt features the following arguments:\n\n- `--chunk_size`: Maximum number of tokens processed in a single run by the engine.\n- `--cache_lens`: Total length of kvcache allocated by the scheduler. All requests share one kvcache space (32768 tokens in the command above), and the space a request occupies is released after the request completes.\n- `--backend_type`: `balance_serve` is a multi-concurrency backend engine introduced in version v0.2.4. The original single-concurrency engine is `ktransformers`.\n- `--max_batch_size`: Maximum number of requests (prefill + decode) processed in a single run by the engine. 
(Supported only by `balance_serve`)\n\n<details>\n<summary>Supported Models/quantization</summary>\n\n### Supported models include\n\n\n| ✅**Supported Models** | ❌**Deprecated Models**    |\n| ---------------------- | -------------------------- |\n| DeepSeek-R1            | ~~InternLM2.5-7B-Chat-1M~~ |\n| DeepSeek-V3            |                            |\n| DeepSeek-V2            |                            |\n| DeepSeek-V2.5          |                            |\n| Qwen2-57B              |                            |\n| DeepSeek-V2-Lite       |                            |\n| Mixtral-8x7B           |                            |\n| Mixtral-8x22B          |                            |\n\n### Support quantize format\n\n\n| ✅**Supported Formats** | ❌**Deprecated Formats** |\n| ----------------------- | ------------------------ |\n| IQ1_S                   | ~~IQ2_XXS~~              |\n| IQ2_XXS                 |                          |\n| Q2_K_L                  |                          |\n| Q2_K_XS                 |                          |\n| Q3_K_M                  |                          |\n| Q4_K_M                  |                          |\n| Q5_K_M                  |                          |\n| Q6_K                    |                          |\n| Q8_0                    |                          |\n\n</details>\n\n<details>\n<summary>Suggested Model</summary>\n\n\n| Model Name                     | Model Size | VRAM  | Minimum DRAM    | Recommended DRAM  |\n| ------------------------------ | ---------- | ----- | --------------- | ----------------- |\n| DeepSeek-R1-q4_k_m             | 377G       | 14G   | 382G            | 512G              |\n| DeepSeek-V3-q4_k_m             | 377G       | 14G   | 382G            | 512G              |\n| DeepSeek-V2-q4_k_m             | 133G       | 11G   | 136G            | 192G              |\n| DeepSeek-V2.5-q4_k_m           | 133G       | 11G   | 136G            | 192G              |\n| 
DeepSeek-V2.5-IQ4_XS           | 117G       | 10G   | 107G            | 128G              |\n| Qwen2-57B-A14B-Instruct-q4_k_m | 33G        | 8G    | 34G             | 64G               |\n| DeepSeek-V2-Lite-q4_k_m        | 9.7G       | 3G    | 13G             | 16G               |\n| Mixtral-8x7B-q4_k_m            | 25G        | 1.6G  | 51G             | 64G               |\n| Mixtral-8x22B-q4_k_m           | 80G        | 4G    | 86.1G           | 96G               |\n| InternLM2.5-7B-Chat-1M         | 15.5G      | 15.5G | 8G(32K context) | 150G (1M context) |\n\nMore will come soon. Please let us know which models you are most interested in.\n\nBe aware that your use of these models is subject to their corresponding model licenses: [DeepSeek](https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/main/LICENSE) and [QWen](https://huggingface.co/Qwen/Qwen2-72B-Instruct/blob/main/LICENSE).\n\n</details>\n\n<details>\n  <summary>Click To Show how to run other examples</summary>\n\n* Qwen2-57B\n\n  ```sh\n  pip install flash_attn # For Qwen2\n\n  mkdir Qwen2-57B-GGUF && cd Qwen2-57B-GGUF\n\n  wget https://huggingface.co/Qwen/Qwen2-57B-A14B-Instruct-GGUF/resolve/main/qwen2-57b-a14b-instruct-q4_k_m.gguf?download=true -O qwen2-57b-a14b-instruct-q4_k_m.gguf\n\n  cd ..\n\n  python -m ktransformers.local_chat --model_path Qwen/Qwen2-57B-A14B-Instruct --gguf_path ./Qwen2-57B-GGUF\n\n  # If you see “OSError: We couldn't connect to 'https://huggingface.co' to load this file”, try:\n  # GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Qwen/Qwen2-57B-A14B-Instruct\n  # python ktransformers/local_chat.py --model_path ./Qwen2-57B-A14B-Instruct --gguf_path ./Qwen2-57B-GGUF\n  ```\n* Deepseek-V2\n\n  ```sh\n  mkdir DeepSeek-V2-Chat-0628-GGUF && cd DeepSeek-V2-Chat-0628-GGUF\n  # Download weights (-O sets the output file name; lowercase -o would write a log file instead)\n  wget https://huggingface.co/bartowski/DeepSeek-V2-Chat-0628-GGUF/resolve/main/DeepSeek-V2-Chat-0628-Q4_K_M/DeepSeek-V2-Chat-0628-Q4_K_M-00001-of-00004.gguf -O 
DeepSeek-V2-Chat-0628-Q4_K_M-00001-of-00004.gguf\n  wget https://huggingface.co/bartowski/DeepSeek-V2-Chat-0628-GGUF/resolve/main/DeepSeek-V2-Chat-0628-Q4_K_M/DeepSeek-V2-Chat-0628-Q4_K_M-00002-of-00004.gguf -O DeepSeek-V2-Chat-0628-Q4_K_M-00002-of-00004.gguf\n  wget https://huggingface.co/bartowski/DeepSeek-V2-Chat-0628-GGUF/resolve/main/DeepSeek-V2-Chat-0628-Q4_K_M/DeepSeek-V2-Chat-0628-Q4_K_M-00003-of-00004.gguf -O DeepSeek-V2-Chat-0628-Q4_K_M-00003-of-00004.gguf\n  wget https://huggingface.co/bartowski/DeepSeek-V2-Chat-0628-GGUF/resolve/main/DeepSeek-V2-Chat-0628-Q4_K_M/DeepSeek-V2-Chat-0628-Q4_K_M-00004-of-00004.gguf -O DeepSeek-V2-Chat-0628-Q4_K_M-00004-of-00004.gguf\n\n  cd ..\n\n  python -m ktransformers.local_chat --model_path deepseek-ai/DeepSeek-V2-Chat-0628 --gguf_path ./DeepSeek-V2-Chat-0628-GGUF\n\n  # If you see “OSError: We couldn't connect to 'https://huggingface.co' to load this file”, try:\n\n  # GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/deepseek-ai/DeepSeek-V2-Chat-0628\n\n  # python -m ktransformers.local_chat --model_path ./DeepSeek-V2-Chat-0628 --gguf_path ./DeepSeek-V2-Chat-0628-GGUF\n  ```\n\n\n| model name       | weights download link                                                                                                 |\n| ---------------- | --------------------------------------------------------------------------------------------------------------------- |\n| Qwen2-57B        | [Qwen2-57B-A14B-gguf-Q4K-M](https://huggingface.co/Qwen/Qwen2-57B-A14B-Instruct-GGUF/tree/main)                       |\n| DeepseekV2-coder | [DeepSeek-Coder-V2-Instruct-gguf-Q4K-M](https://huggingface.co/LoneStriker/DeepSeek-Coder-V2-Instruct-GGUF/tree/main) |\n| DeepseekV2-chat  | [DeepSeek-V2-Chat-gguf-Q4K-M](https://huggingface.co/bullerwins/DeepSeek-V2-Chat-0628-GGUF/tree/main)                 |\n| DeepseekV2-lite  | [DeepSeek-V2-Lite-Chat-GGUF-Q4K-M](https://huggingface.co/mzwing/DeepSeek-V2-Lite-Chat-GGUF/tree/main)                
|\n| DeepSeek-R1      | [DeepSeek-R1-gguf-Q4K-M](https://huggingface.co/unsloth/DeepSeek-R1-GGUF/tree/main/DeepSeek-R1-Q4_K_M)                |\n\n</details>\n\n<!-- pin block for jump -->\n\n<span id='id_666'>\n\n<h3>RESTful API and Web UI  </h3>\n\nStart without website:\n\n```sh\nktransformers --model_path deepseek-ai/DeepSeek-V2-Lite-Chat --gguf_path /path/to/DeepSeek-V2-Lite-Chat-GGUF --port 10002\n```\n\nStart with website:\n\n```sh\nktransformers --model_path deepseek-ai/DeepSeek-V2-Lite-Chat --gguf_path /path/to/DeepSeek-V2-Lite-Chat-GGUF  --port 10002 --web True\n```\n\nOr, if you want to start the server with transformers, the model_path should include safetensors:\n\n```bash\nktransformers --type transformers --model_path /mnt/data/model/Qwen2-0.5B-Instruct --port 10002 --web True\n```\n\nAccess the website at [http://localhost:10002/web/index.html#/chat](http://localhost:10002/web/index.html#/chat):\n\n<p align=\"center\">\n  <picture>\n    <img alt=\"Web UI\" src=\"https://github.com/user-attachments/assets/615dca9b-a08c-4183-bbd3-ad1362680faf\" width=90%>\n  </picture>\n</p>\n\nMore information about the RESTful API server can be found [here](doc/en/api/server/server.md). You can also find an example of integrating with Tabby [here](doc/en/api/server/tabby.md).\n"
  },
  {
    "path": "doc/en/kt-kernel/GLM-5-Tutorial.md",
    "content": "# Running GLM-5 with SGLang and KT-Kernel\n\nThis tutorial demonstrates how to run GLM-5 model inference using SGLang integrated with KT-Kernel for CPU-GPU heterogeneous inference. This setup enables efficient deployment of large MoE models by offloading experts to CPU. KT-Kernel supports both BF16 and FP8 precision backends, allowing you to choose between maximum quality and reduced memory footprint.\n\n## Table of Contents\n\n- [Table of Contents](#table-of-contents)\n- [Prerequisites](#prerequisites)\n- [Step 1: Download Model Weights](#step-1-download-model-weights)\n- [Step 2: Launch SGLang Server](#step-2-launch-sglang-server)\n- [Step 3: Send Inference Requests](#step-3-send-inference-requests)\n  - [Option A: Interactive Chat with KT CLI](#option-a-interactive-chat-with-kt-cli)\n  - [Option B: OpenAI-Compatible API](#option-b-openai-compatible-api)\n- [Additional Resources](#additional-resources)\n\n## Prerequisites\n\nBefore starting, ensure you have:\n\n1. **SGLang installed**\n\n    Install the kvcache-ai fork of SGLang (one of):\n\n    ```bash\n    # Option A: One-click install (from ktransformers root)\n    ./install.sh\n\n    # Option B: pip install\n    pip install sglang-kt\n    ```\n\n2. **KT-Kernel installed**\n\n    ```bash\n    git clone https://github.com/kvcache-ai/ktransformers.git\n    git submodule update --init --recursive\n    cd kt-kernel && ./install.sh\n    ```\n\n3. **transformers reinstalled**\n\n    ```bash\n    pip install git+https://github.com/huggingface/transformers.git\n    ```\n\n4. **CUDA toolkit** - CUDA 12.0+ recommended (12.8+ for best FP8 support)\n5. 
**Hugging Face CLI** - For downloading models:\n   ```bash\n   pip install -U huggingface-hub\n   ```\n\n## Step 1: Download Model Weights\n\nDownload the GLM-5 weights from Hugging Face.\n\n```bash\n# FP8\nhf download zai-org/GLM-5-FP8 \\\n  --local-dir /path/to/GLM-5-FP8\n\n# BF16\nhf download zai-org/GLM-5 \\\n  --local-dir /path/to/GLM-5\n```\n\n**Note:** Replace `/path/to/` with your actual storage path throughout this tutorial.\n\n## Step 2: Launch SGLang Server\n\nStart the SGLang server with KT-Kernel integration for CPU-GPU heterogeneous inference.\n\n```bash\n# FP8 Precision\nexport PYTORCH_ALLOC_CONF=expandable_segments:True\nexport SGLANG_ENABLE_JIT_DEEPGEMM=0\n\npython -m sglang.launch_server \\\n  --host 0.0.0.0 \\\n  --port 30000 \\\n  --model /path/to/GLM-5-FP8 \\\n  --kt-weight-path /path/to/GLM-5-FP8 \\\n  --kt-cpuinfer 96 \\\n  --kt-threadpool-count 2 \\\n  --kt-num-gpu-experts 30 \\\n  --kt-method FP8 \\\n  --kt-gpu-prefill-token-threshold 1024 \\\n  --kt-enable-dynamic-expert-update \\\n  --kt-expert-placement-strategy uniform \\\n  --trust-remote-code \\\n  --mem-fraction-static 0.75 \\\n  --served-model-name GLM5 \\\n  --enable-mixed-chunk \\\n  --tensor-parallel-size 8 \\\n  --enable-p2p-check \\\n  --disable-shared-experts-fusion \\\n  --chunked-prefill-size 16384 \\\n  --max-running-requests 4 \\\n  --max-total-tokens 128000 \\\n  --attention-backend flashinfer \\\n  --fp8-gemm-backend cutlass \\\n  --kv-cache-dtype bf16 \\\n  --tool-call-parser glm47 \\\n  --reasoning-parser glm45 \\\n  --watchdog-timeout 3000\n\n# BF16 Precision\nexport PYTORCH_ALLOC_CONF=expandable_segments:True\nexport SGLANG_ENABLE_JIT_DEEPGEMM=0\n\npython -m sglang.launch_server \\\n  --host 0.0.0.0 \\\n  --port 30000 \\\n  --model /path/to/GLM-5 \\\n  --kt-weight-path /path/to/GLM-5 \\\n  --kt-cpuinfer 96 \\\n  --kt-threadpool-count 2 \\\n  --kt-num-gpu-experts 10 \\\n  --kt-method BF16 \\\n  --kt-gpu-prefill-token-threshold 1024 \\\n  
--kt-enable-dynamic-expert-update \\\n  --kt-expert-placement-strategy uniform \\\n  --trust-remote-code \\\n  --mem-fraction-static 0.75 \\\n  --served-model-name GLM5 \\\n  --enable-mixed-chunk \\\n  --tensor-parallel-size 8 \\\n  --enable-p2p-check \\\n  --disable-shared-experts-fusion \\\n  --chunked-prefill-size 16384 \\\n  --max-running-requests 4 \\\n  --max-total-tokens 128000 \\\n  --attention-backend flashinfer \\\n  --tool-call-parser glm47 \\\n  --reasoning-parser glm45 \\\n  --watchdog-timeout 3000\n```\n\nLayerwise prefill requires one extra MoE layer's worth of VRAM.\n\nIf you encounter OOM, adjust `--kt-num-gpu-experts`, `--chunked-prefill-size`, `--mem-fraction-static` and `--max-total-tokens` when launching the server.\n\nIf you encounter other issues, try `kt doctor` to diagnose your setup.\n\nSee [KT-Kernel Parameters](https://github.com/kvcache-ai/ktransformers/tree/main/kt-kernel#kt-kernel-parameters) for detailed parameter tuning guidelines.\n\n## Step 3: Send Inference Requests\n\nOnce the server is running (default: `http://localhost:30000`), you can interact with the model in several ways:\n\n### Option A: Interactive Chat with KT CLI\n\nThe easiest way to chat with the model:\n\n```bash\nkt chat\n```\n\nThis opens an interactive terminal chat session. Type your messages and press Enter to send. 
Use `Ctrl+C` to exit.\n\n### Option B: OpenAI-Compatible API\n\nThe server exposes an OpenAI-compatible API at `http://localhost:30000/v1`.\n\n**curl example (streaming):**\n\n```bash\ncurl http://localhost:30000/v1/chat/completions \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"model\": \"GLM5\",\n    \"messages\": [{\"role\": \"user\", \"content\": \"hi, who are you?\"}],\n    \"stream\": true\n  }'\n```\n\n## Additional Resources\n\n- [GLM-5 Model Card](https://huggingface.co/zai-org/GLM-5)\n- [KT-Kernel Documentation](../../../kt-kernel/README.md)\n- [SGLang GitHub](https://github.com/sgl-project/sglang)\n- [KT-Kernel Parameters Reference](../../../kt-kernel/README.md#kt-kernel-parameters)\n"
  },
  {
    "path": "doc/en/kt-kernel/Kimi-K2-Thinking-Native.md",
    "content": "# Running Kimi-K2-Thinking with SGLang and KT-Kernel\n\nThis tutorial demonstrates how to run Kimi-K2 model inference using SGLang integrated with KT-Kernel for CPU-GPU heterogeneous inference. This setup enables efficient deployment of large MoE models by offloading experts to CPU.\n\n## Table of Contents\n\n- [Hardware Requirements](#hardware-requirements)\n- [Prerequisites](#prerequisites)\n- [Step 1: Download Model Weights](#step-1-download-model-weights)\n- [Step 2: Launch SGLang Server](#step-2-launch-sglang-server)\n- [Step 3: Send Inference Requests](#step-3-send-inference-requests)\n\n## Hardware Requirements\n\n**Minimum Configuration:**\n- **GPU**: NVIDIA RTX 4090 48GB (or equivalent with at least 48GB VRAM available)\n- **CPU**: x86 CPU with AVX512 support (e.g., Sapphire Rapids)\n- **RAM**: At least 650GB system memory\n- **Storage**: ~600GB for model weights (native INT4 weight, same weight dir for CPU and GPU)\n\n**Tested Configuration:**\n\n- **GPU**: 1/2/4/8x NVIDIA RTX 4090/L20 48GB\n- **CPU**: 2x Intel(R) Xeon(R) Platinum 8488C\n- **RAM**: 2TB DDR5 4800MHz\n- **OS**: Linux (Ubuntu 20.04+ recommended)\n\n## Prerequisites\n\nBefore starting, ensure you have:\n\n1. **KT-Kernel installed** - Follow the [installation guide](./kt-kernel_intro.md#installation)\n2. **SGLang installed** - Install the kvcache-ai fork of SGLang (one of):\n\n```bash\n# Option A: One-click install (from ktransformers root)\n./install.sh\n\n# Option B: pip install\npip install sglang-kt\n```\n\n3. **CUDA toolkit** - Compatible with your GPU (CUDA 11.8+ recommended)\n4. 
**Hugging Face CLI** - For downloading models:\n   ```bash\n   pip install huggingface-hub\n   ```\n\n## Step 1: Download Model Weights\n\n```bash\n# Create a directory for models\nmkdir -p /path/to/models\ncd /path/to/models\n\n# Download Kimi-K2-Thinking (INT4 for both CPU and GPU)\nhuggingface-cli download moonshotai/Kimi-K2-Thinking \\\n  --local-dir /path/to/kimi-k2-thinking\n```\n\n**Note:** Replace `/path/to/models` with your actual storage path throughout this tutorial.\n\n## Step 2: Launch SGLang Server\n\nStart the SGLang server with KT-Kernel integration for CPU-GPU heterogeneous inference.\n\n\n### Launch Command (2x RTX 4090 Example)\n\n```bash\npython -m sglang.launch_server \\\n  --host 0.0.0.0 \\\n  --port 30001 \\\n  --model /path/to/kimi-k2-thinking \\\n  --kt-weight-path /path/to/kimi-k2-thinking \\\n  --kt-cpuinfer 96 \\\n  --kt-threadpool-count 2 \\\n  --kt-num-gpu-experts 8 \\\n  --kt-method RAWINT4 \\\n  --kt-gpu-prefill-token-threshold 400 \\\n  --kt-max-deferred-experts-per-token 1 \\\n  --trust-remote-code \\\n  --mem-fraction-static 0.94 \\\n  --served-model-name Kimi-K2-Thinking \\\n  --enable-mixed-chunk \\\n  --tensor-parallel-size 2 \\\n  --enable-p2p-check \\\n  --disable-shared-experts-fusion \\\n  --chunked-prefill-size 65536 \\\n  --max-total-tokens 65536 \\\n  --attention-backend flashinfer\n```\n\nIt takes about 2~3 minutes to start the server.\n\nSee [KT-Kernel Parameters](https://github.com/kvcache-ai/ktransformers/tree/main/kt-kernel#kt-kernel-parameters) for detailed parameter tuning guidelines.\n\n### Key Parameters\n\n| Parameter | Description |\n|-----------|-------------|\n| `--kt-method RAWINT4` | CPU and GPU use the same INT4 weight. Set `--model` and `--kt-weight-path` to the same directory. |\n| `--kt-num-gpu-experts` | Number of experts kept on GPU for decoding. |\n| `--kt-gpu-prefill-token-threshold` | Token count threshold for prefill strategy. Below: hybrid CPU+GPU. Above: layerwise GPU prefill. 
|\n| `--chunked-prefill-size` | Maximum tokens per prefill batch. |\n| `--max-total-tokens` | Maximum total tokens in KV cache. |\n\n### About `--kt-gpu-prefill-token-threshold`\n\nThis parameter controls the prefill strategy:\n\n- **$\\leq$ threshold**: Uses hybrid CPU+GPU prefill. No extra VRAM needed, but performance degrades slowly as token count increases.\n- **> threshold**: Uses layerwise GPU prefill. Performance scales near-exponentially until reaching the bottleneck, but requires 9GB+ extra VRAM.\n\n### Troubleshooting OOM\n\nLayerwise prefill requires extra VRAM (~9GB + incremental cost with prefill length). If you encounter OOM, adjust these parameters based on your use case and hardware (refer to the recommended parameters table below):\n\n| Parameter | VRAM Impact |\n|-----------|-------------|\n| `--kt-num-gpu-experts` | Reduces expert weight VRAM usage |\n| `--chunked-prefill-size` | Reduces prefill extra VRAM allocation |\n| `--max-total-tokens` | Reduces KV cache VRAM usage |\n\n**Tip:** Test with an input of length `chunked-prefill-size` to verify your configuration won't OOM during prefill.\n\n\n### Recommended Parameters\n\n| GPU Config | `kt-num-gpu-experts` | `max-total-tokens` | `chunked-prefill-size` |\n|------------|----------------------|---------------------|------------------------|\n| 1x RTX 4090 (48GB) | 0 | 30000 | 30000 |\n| 2x RTX 4090 (48GB) | 8 | 65536 | 65536 |\n| 4x RTX 4090 (48GB) | 30 | 80000 | 65536 |\n| 8x RTX 4090 (48GB) | 80 | 100000 | 65536 |\n\n**Tip:** If your prefill and total length requirements are low (e.g., processing short texts), you can reduce `max-total-tokens` and `chunked-prefill-size` to free up VRAM for a larger `kt-num-gpu-experts`, which improves decode performance.\n\n### Performance\n\nThe following prefill throughput (tokens/s) benchmarks were measured with single concurrency:\n\n| GPU Config | 2048 tokens | 8192 tokens | 32768 tokens |\n|------------|-------------|-------------|--------------|\n| 1x 
RTX 4090 (48GB) | 53 | 184 | 290* |\n| 2x RTX 4090 (48GB) | 85 | 294 | 529 |\n| 4x RTX 4090 (48GB) | 118 | 415 | 818 |\n| 8x RTX 4090 (48GB) | 130 | 435 | 1055 |\n\n* Note: 1x RTX 4090 with layerwise prefill OOMs at 32768 tokens, so the 290 tokens/s is measured with qlen=30000.\n\n## Step 3: Send Inference Requests\n\nOnce the server is running, you can send inference requests using the OpenAI-compatible API.\n\n### Basic Chat Completion Request\n\n```bash\ncurl -s http://localhost:30001/v1/chat/completions \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"model\": \"Kimi-K2-Thinking\",\n    \"stream\": false,\n    \"messages\": [\n      {\"role\": \"user\", \"content\": \"hi\"}\n    ]\n  }'\n```\n\n### Example Response\n\n```json\n{\n    \"id\": \"cd0905562bf44513947284f80cc5634b\",\n    \"object\": \"chat.completion\",\n    \"created\": 1764921457,\n    \"model\": \"Kimi-K2-Thinking\",\n    \"choices\": [\n        {\n            \"index\": 0,\n            \"message\": {\n                \"role\": \"assistant\",\n                \"content\": \" <think> The user says \\\"hi\\\". This is a very simple greeting. I should respond in a friendly and helpful manner. Since I'm an AI assistant, I should be professional but approachable.\\n\\nPossible responses:\\n1. \\\"Hello! How can I help you today?\\\"\\n2. \\\"Hi there! What can I do for you?\\\"\\n3. \\\"Hello! It's nice to hear from you. What would you like to talk about?\\\"\\n4. \\\"Hi! I'm here to assist you with any questions you might have.\\\"\\n\\nI think option 1 is the most standard and professional. It's direct, friendly, and opens the door for the user to ask their question. I should keep it concise.\\n\\nLet me go with: \\\"Hello! How can I help you today?\\\" </think> Hello! 
How can I help you today?\",\n                \"reasoning_content\": null,\n                \"tool_calls\": null\n            },\n            \"logprobs\": null,\n            \"finish_reason\": \"stop\",\n            \"matched_stop\": 163586\n        }\n    ],\n    \"usage\": {\n        \"prompt_tokens\": 26,\n        \"total_tokens\": 189,\n        \"completion_tokens\": 163,\n        \"prompt_tokens_details\": null,\n        \"reasoning_tokens\": 0\n    },\n    \"metadata\": {\n        \"weight_version\": \"default\"\n    }\n}\n```\n\n## Advanced Use Case: Running Claude Code with Native Kimi-K2-Thinking Local Backend\n\nAdd the following parameters to the SGLang launch command above to enable tool calling support:\n\n```bash\n--tool-call-parser kimi_k2 --reasoning-parser kimi_k2\n```\n\nWith these parameters enabled, you can use [claude-code-router](https://github.com/musistudio/claude-code-router) to connect Kimi-K2-Thinking as a local backend for [Claude Code](https://github.com/anthropics/claude-code).\n\n## Additional Resources\n\n- [KT-Kernel Documentation](../../../kt-kernel/README.md)\n- [SGLang GitHub](https://github.com/sgl-project/sglang)\n- [Claude Code Router](https://github.com/musistudio/claude-code-router) - Route Claude Code to custom backends\n"
  },
  {
    "path": "doc/en/kt-kernel/MiniMax-M2.1-Tutorial.md",
    "content": "# Running MiniMax-M2.1 with Native Precision using SGLang and KT-Kernel\n\nThis tutorial demonstrates how to run MiniMax-M2.1 model inference using SGLang integrated with KT-Kernel. MiniMax-M2.1 provides native FP8 weights, enabling efficient GPU inference with reduced memory footprint while maintaining high accuracy.\n\n## Table of Contents\n\n- [Table of Contents](#table-of-contents)\n- [Hardware Requirements](#hardware-requirements)\n- [Prerequisites](#prerequisites)\n- [Step 1: Download Model Weights](#step-1-download-model-weights)\n- [Step 2: Launch Server with KT CLI](#step-2-launch-server-with-kt-cli)\n  - [Advanced Options](#advanced-options)\n  - [Dry Run](#dry-run)\n  - [Key Parameters](#key-parameters)\n- [Step 3: Send Inference Requests](#step-3-send-inference-requests)\n  - [Option A: Interactive Chat with KT CLI](#option-a-interactive-chat-with-kt-cli)\n  - [Option B: OpenAI-Compatible API](#option-b-openai-compatible-api)\n- [Performance](#performance)\n  - [Throughput (tokens/s)](#throughput-tokenss)\n  - [Comparison with llama.cpp](#comparison-with-llamacpp)\n- [Troubleshooting](#troubleshooting)\n  - [OOM (Out of Memory) Issues](#oom-out-of-memory-issues)\n- [Advanced Use Case: Running Claude Code with MiniMax-M2.1 Local Backend](#advanced-use-case-running-claude-code-with-minimax-m21-local-backend)\n- [Additional Resources](#additional-resources)\n\n## Hardware Requirements\n\n**Minimum Configuration:**\n- **GPU**: NVIDIA RTX 5090 32 GB (or equivalent with at least 32GB VRAM available)\n- **CPU**: x86 CPU with AVX512 support (e.g., Intel Sapphire Rapids, AMD EPYC)\n- **RAM**: At least 256GB system memory\n- **Storage**: >220 GB for model weights (same weight dir for GPU and CPU)\n\n**Tested Configuration:**\n\n- **GPU**: 1/2 x NVIDIA GeForce RTX 5090 (32 GB)\n- **CPU**: 2 x AMD EPYC 9355 32-Core Processor (128 threads)\n- **RAM**: 1TB DDR5 5600MT/s ECC\n- **OS**: Linux (Ubuntu 20.04+ recommended)\n\n## Prerequisites\n\nBefore 
starting, ensure you have:\n\n1. **SGLang installed**\n\n    Install the kvcache-ai fork of SGLang (one of):\n\n    ```bash\n    # Option A: One-click install (from ktransformers root)\n    ./install.sh\n\n    # Option B: pip install\n    pip install sglang-kt\n    ```\n\n2. **KT-Kernel installed**\n\n    Please follow [kt-kernel](https://github.com/kvcache-ai/ktransformers/blob/main/kt-kernel/README.md)\n\n    After installation, verify the CLI is working:\n\n    ```bash\n    kt version\n    ```\n\n3. **CUDA toolkit** - CUDA 12.0+ recommended for FP8 support\n4. **Hugging Face CLI** - For downloading models:\n   ```bash\n   pip install -U huggingface-hub\n   ```\n\n## Step 1: Download Model Weights\n\nDownload the official MiniMax-M2.1 weights.\n\n* huggingface: https://huggingface.co/MiniMaxAI/MiniMax-M2.1\n\n    ```bash\n    hf download MiniMaxAI/MiniMax-M2.1 --local-dir /path/to/minimax-m2.1\n    ```\n\n## Step 2: Launch Server with KT CLI\n\nThe simplest way to start the MiniMax-M2.1 server is using the `kt` CLI:\n\n```bash\nkt run m2.1\n```\n\nThe CLI will automatically detect your hardware configuration and apply optimal parameters for your system.\n\n### Advanced Options\n\nFor custom configurations, you can specify additional parameters:\n\n```bash\n# Use specific number of GPUs (tensor parallel)\nkt run m2.1 --tensor-parallel-size 2\n\n# Custom CPU threads and NUMA configuration\nkt run m2.1 --cpu-threads 64 --numa-nodes 2\n```\n\n### Dry Run\n\nTo preview the command without executing:\n\n```bash\nkt run m2.1 --dry-run\n```\n\nSee [KT-Kernel Parameters](https://github.com/kvcache-ai/ktransformers/tree/main/kt-kernel#kt-kernel-parameters) for detailed parameter tuning guidelines.\n\n### Key Parameters\n\n| Parameter | Description |\n|-----------|-------------|\n| `--kt-method FP8` | Enable FP8 inference mode for MiniMax-M2.1 native FP8 weights. |\n| `--kt-cpuinfer` | Number of CPU inference threads. Set to physical CPU cores (not hyperthreads). 
|\n| `--kt-threadpool-count` | Number of thread pools. Set to NUMA node count. |\n| `--kt-num-gpu-experts` | Number of experts kept on GPU for decoding. |\n| `--chunked-prefill-size` | Maximum tokens per prefill batch. |\n| `--max-total-tokens` | Maximum total tokens in KV cache. |\n| `--kt-gpu-prefill-token-threshold` | Token threshold for layerwise prefill strategy. |\n\n## Step 3: Send Inference Requests\n\nOnce the server is running (default: `http://localhost:30000`), you can interact with the model in several ways:\n\n### Option A: Interactive Chat with KT CLI\n\nThe easiest way to chat with the model:\n\n```bash\nkt chat\n```\n\nThis opens an interactive terminal chat session. Type your messages and press Enter to send. Use `Ctrl+C` to exit.\n\n### Option B: OpenAI-Compatible API\n\nThe server exposes an OpenAI-compatible API at `http://localhost:30000/v1`.\n\n**curl example (streaming):**\n\n```bash\ncurl http://localhost:30000/v1/chat/completions \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"model\": \"MiniMax-M2.1\",\n    \"messages\": [{\"role\": \"user\", \"content\": \"Hello!\"}],\n    \"stream\": true\n  }'\n```\n\n\n## Performance\n\n### Throughput (tokens/s)\n\nThe following benchmarks were measured with single concurrency (Prefill tps / Decode tps):\n\n| GPU  | CPU  | PCIe |  2048 tokens | 8192 tokens | 32768 tokens |\n|------------|-------------|-------------|-------------|-------------|--------------|\n| 1 x RTX 4090 (48 GB) | 2 x Intel Xeon Platinum 8488C| PCIe 4.0 | 129 / 21.8 | 669 / 20.9 | 1385 / 18.5 |\n| 2 x RTX 4090 (48 GB) | 2 x Intel Xeon Platinum 8488C| PCIe 4.0 | 139 / 23.6 | 1013 / 23.3 | 2269 / 21.6 |\n| 1 x RTX 5090 (32 GB) | 2 x AMD EPYC 9355 | PCIe 5.0 | 408 / 32.1 | 1196 / 31.4 | 2540 / 27.6 |\n| 2 x RTX 5090 (32 GB) | 2 x AMD EPYC 9355 | PCIe 5.0 | 414 / 35.9 | 1847 / 35.5 | 4007 / 33.1 |\n\n![Throughput in 2 x RTX 5090](../../assets/MiniMax-M2_speed.png)\n\n### Comparison with llama.cpp\n\nWe benchmarked 
KT-Kernel + SGLang against llama.cpp to demonstrate the performance advantages of our CPU-GPU heterogeneous inference approach.\n\n- **Weight formats**: KT-Kernel uses native unquantized FP8 weights from MiniMax-M2, while llama.cpp only supports quantized weights, so we used Q8_0 quantization for the llama.cpp benchmarks.\n\n- **Test environment**: 2 x RTX 5090 (32 GB) with AMD EPYC 9355 CPUs, input tokens=32768, output tokens=512. We made our best effort to optimize llama.cpp performance, but we could not achieve optimal prefill and decode with a single command, so we used separate configurations for prefill and decode measurements.\n\n![Performance Comparison with llama.cpp](../../assets/MiniMax-M2_comparison.png)\n\nAs shown in the chart, KT-Kernel achieves more than **4.5x faster prefill** and **30% faster decode** compared to llama.cpp on the same hardware.\n\n## Troubleshooting\n\n### OOM (Out of Memory) Issues\n\nLayerwise prefill requires extra VRAM (~3.6GB plus an incremental cost that grows with prefill length). 
If you encounter OOM, adjust these parameters when launching the server:\n\n| Parameter | VRAM Impact |\n|-----------|-------------|\n| `--kt-num-gpu-experts` | Reduces expert weight VRAM usage |\n| `--chunked-prefill-size` | Reduces prefill extra VRAM allocation |\n| `--max-total-tokens` | Reduces KV cache VRAM usage |\n\n**Tip:** Test with an input of length `chunked-prefill-size` to verify your configuration won't OOM during prefill.\n\n## Advanced Use Case: Running Claude Code with MiniMax-M2.1 Local Backend\n\n```bash\nkt run m2.1 --tool-call-parser minimax-m2 --reasoning-parser minimax-append-think\n```\n\nWith the above command, you can use [claude-code-router](https://github.com/musistudio/claude-code-router) to connect MiniMax-M2.1 as a local backend for [Claude Code](https://github.com/anthropics/claude-code).\n\n## Additional Resources\n\n- [KT-Kernel Documentation](../../../kt-kernel/README.md)\n- [SGLang GitHub](https://github.com/sgl-project/sglang)\n- [KT-Kernel Parameters Reference](../../../kt-kernel/README.md#kt-kernel-parameters)"
  },
  {
    "path": "doc/en/kt-kernel/Native-Precision-Tutorial.md",
    "content": "# Running Native Precision Models with SGLang and KT-Kernel\n\nThis tutorial demonstrates how to run native precision MoE model inference using SGLang integrated with KT-Kernel. KTransformers v0.5.1+ supports multiple native precision formats, enabling efficient inference across various model architectures.\n\n## Table of Contents\n\n- [Supported Precision Formats](#supported-precision-formats)\n- [Supported Models](#supported-models)\n- [Hardware Requirements](#hardware-requirements)\n- [Prerequisites](#prerequisites)\n- [Launch Server](#launch-server)\n  - [Example Configurations](#example-configurations)\n  - [Key Parameters Reference](#key-parameters-reference)\n- [Send Inference Requests](#send-inference-requests)\n- [Technical Highlights](#technical-highlights)\n  - [Experts Scheduling](#experts-scheduling)\n  - [Dual Prefill Mechanism](#dual-prefill-mechanism)\n- [Troubleshooting](#troubleshooting)\n- [Additional Resources](#additional-resources)\n\n## Supported Precision Formats\n\nKTransformers supports multiple native precision formats via the `--kt-method` parameter:\n\n| kt-method | Precision Format | Description | Instruction Set |\n|-----------|-----------------|-------------|-----------------|\n| `BF16` | BF16 Native | Zero precision loss, original weights | AMX + AVX512 |\n| `FP8` | FP8 Blockwise | Block-wise scale quantization | AVX512 |\n| `FP8_PERCHANNEL` | FP8 Per-Channel | Per-channel scale quantization | AVX512 |\n| `RAWINT4` | INT4 Native | Same INT4 weights for CPU and GPU | AVX512 |\n\n## Supported Models\n\n| Model (sorted lexicographically) | kt-method | Precision |\n|-------|-----------|------------|\n| **DeepSeek-V3/R1/V3.2** | `FP8` | FP8 |\n| **GLM-4.7** | `FP8_PERCHANNEL`, `BF16` | FP8, BF16 |\n| **Kimi-K2-Thinking** | `RAWINT4` | INT4 Native |\n| **MiniMax-M2/M2.1** | `FP8` | FP8 |\n| **Qwen3-235B-A22B** | `FP8`, `BF16` | FP8, BF16 |\n| **Qwen3-30B-A3B** | `FP8`, `BF16` | FP8, BF16 |\n| **Qwen3-Next-80B-A3B** | 
`FP8`, `BF16` | FP8, BF16 |\n\n## Hardware Requirements\n\n**Minimum Configuration:**\n- **GPU**: 1-2 x NVIDIA GPU with at least 24GB VRAM (RTX 4090/5090 or equivalent, depending on model)\n- **CPU**: x86 CPU with AVX512 support (Intel Sapphire Rapids+, AMD EPYC)\n  - BF16 additionally benefits from AMX support\n- **RAM**: At least as much RAM as model size (e.g., 256GB+ for MiniMax-M2.1)\n- **Storage**: Sufficient space for model weights (varies by model)\n\n**Recommended Configuration:**\n- **GPU**: 1-8 x NVIDIA RTX 5090 (32 GB) or equivalent\n- **CPU**: 2 x AMD EPYC 9355 32-Core / Intel Xeon Platinum 8488C\n- **RAM**: 1TB DDR5 5600MT/s ECC\n- **PCIe**: PCIe 5.0 for optimal CPU-GPU data transfer\n- **OS**: Linux (Ubuntu 20.04+ recommended)\n\n## Prerequisites\n\nBefore starting, ensure you have:\n\n1. **SGLang installed**\n\n    Install the kvcache-ai fork of SGLang (one of):\n\n    ```bash\n    # Option A: One-click install (from ktransformers root)\n    ./install.sh\n\n    # Option B: pip install\n    pip install sglang-kt\n    ```\n\n2. **KT-Kernel installed**\n\n    Follow the [kt-kernel installation guide](https://github.com/kvcache-ai/ktransformers/blob/main/kt-kernel/README.md):\n\n    ```bash\n    git clone https://github.com/kvcache-ai/ktransformers.git\n    cd ktransformers/kt-kernel\n    ./install.sh\n    ```\n\n    Verify the installation:\n\n    ```bash\n    kt version\n    ```\n\n3. **CUDA toolkit** - CUDA 12.0+ recommended\n4. 
**Hugging Face CLI** - For downloading models:\n   ```bash\n   pip install -U huggingface-hub\n   ```\n\n## Launch Server\n\n### Example Configurations\n\nCurrently, only `MiniMax-M2/M2.1`, `DeepSeek-V3/R1-0528/V3.2`, and `Kimi-K2-Thinking` can be launched with the `kt` CLI; the other examples below call `python -m sglang.launch_server` directly.\n\n**DeepSeek-V3.2**\n\n```bash\nkt run V3.2 --kt-enable-dynamic-expert-update\n```\n\n**GLM-4.7**\n\n```bash\npython -m sglang.launch_server \\\n    --host 0.0.0.0 \\\n    --port 30000 \\\n    --model /path/to/GLM-4.7/ \\\n    --kt-weight-path /path/to/GLM-4.7/ \\\n    --kt-cpuinfer 100 \\\n    --kt-threadpool-count 2 \\\n    --kt-num-gpu-experts 15 \\\n    --kt-method BF16 \\\n    --kt-enable-dynamic-expert-update \\\n    --attention-backend flashinfer \\\n    --mem-fraction-static 0.80 \\\n    --chunked-prefill-size 16384 \\\n    --max-running-requests 2 \\\n    --max-total-tokens 32768 \\\n    --trust-remote-code \\\n    --served-model-name GLM-4.7 \\\n    --enable-mixed-chunk \\\n    --tensor-parallel-size 8 \\\n    --enable-p2p-check \\\n    --disable-shared-experts-fusion \\\n    --tool-call-parser glm47 \\\n    --reasoning-parser glm45 \\\n    --watchdog-timeout 3000 \\\n    --kt-gpu-prefill-token-threshold 1024\n```\n\n**GLM-4.7-FP8**\n\n```bash\npython -m sglang.launch_server \\\n    --host 0.0.0.0 \\\n    --port 30000 \\\n    --model /path/to/GLM-4.7-FP8/ \\\n    --kt-weight-path /path/to/GLM-4.7-FP8/ \\\n    --kt-cpuinfer 100 \\\n    --kt-threadpool-count 2 \\\n    --kt-num-gpu-experts 80 \\\n    --kt-method FP8_PERCHANNEL \\\n    --kt-enable-dynamic-expert-update \\\n    --attention-backend flashinfer \\\n    --mem-fraction-static 0.75 \\\n    --chunked-prefill-size 16384 \\\n    --max-running-requests 4 \\\n    --max-total-tokens 100000 \\\n    --trust-remote-code \\\n    --served-model-name GLM-4.7 \\\n    --enable-mixed-chunk \\\n    --tensor-parallel-size 8 \\\n    --enable-p2p-check \\\n    --disable-shared-experts-fusion \\\n    --watchdog-timeout 3000 \\\n    --fp8-gemm-backend triton \\\n    
--kt-gpu-prefill-token-threshold 2048\n```\n\n**Qwen3-235B-A22B**\n\n```bash\npython -m sglang.launch_server \\\n    --host 0.0.0.0 \\\n    --port 30000 \\\n    --model /path/to/Qwen3-235B-A22B \\\n    --kt-weight-path /path/to/Qwen3-235B-A22B \\\n    --kt-cpuinfer 100 \\\n    --kt-threadpool-count 2 \\\n    --kt-num-gpu-experts 20 \\\n    --kt-method FP8 \\\n    --kt-enable-dynamic-expert-update \\\n    --kt-expert-placement-strategy uniform \\\n    --attention-backend flashinfer \\\n    --mem-fraction-static 0.80 \\\n    --chunked-prefill-size 16384 \\\n    --max-running-requests 4 \\\n    --max-total-tokens 100000 \\\n    --trust-remote-code \\\n    --served-model-name Qwen3-235B-A22B \\\n    --enable-mixed-chunk \\\n    --tensor-parallel-size 8 \\\n    --enable-p2p-check \\\n    --kt-gpu-prefill-token-threshold 2048\n```\n\n### Key Parameters Reference\n\n| Parameter | Description |\n|-----------|-------------|\n| `--kt-method` | Precision format: `BF16`, `FP8_PERCHANNEL`, `FP8`, `RAWINT4`, `AMXINT4` |\n| `--kt-cpuinfer` | Number of CPU inference threads (set to ~90% of physical cores) |\n| `--kt-threadpool-count` | Number of thread pools (set to NUMA node count) |\n| `--kt-num-gpu-experts` | Number of experts kept on GPU per layer |\n| `--kt-enable-dynamic-expert-update` | Enable dynamic expert placement updates during Layerwise Prefill |\n| `--kt-expert-placement-strategy` | Expert placement strategy |\n| `--kt-gpu-prefill-token-threshold` | Token threshold for triggering Layerwise Prefill |\n| `--chunked-prefill-size` | Maximum tokens per prefill batch |\n| `--max-total-tokens` | Maximum total tokens in KV cache |\n\n## Send Inference Requests\n\nOnce the server is running (default: `http://localhost:30000`), you can interact with the model:\n\n### Option A: Interactive Chat with KT CLI\n\n```bash\nkt chat\n```\n\nThis opens an interactive terminal chat session. Type your messages and press Enter to send. 
Use `Ctrl+C` to exit.\n\n### Option B: OpenAI-Compatible API\n\nThe server exposes an OpenAI-compatible API at `http://localhost:30000/v1`.\n\n**curl example (streaming):**\n\n```bash\ncurl http://localhost:30000/v1/chat/completions \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"model\": \"MODEL_NAME\",\n    \"messages\": [{\"role\": \"user\", \"content\": \"Hello! What can you help me with?\"}],\n    \"stream\": true\n  }'\n```\n\n**Python example:**\n\n```python\nfrom openai import OpenAI\n\nclient = OpenAI(base_url=\"http://localhost:30000/v1\", api_key=\"none\")\n\nresponse = client.chat.completions.create(\n    model=\"MODEL_NAME\",\n    messages=[{\"role\": \"user\", \"content\": \"Explain quantum computing in simple terms.\"}],\n    stream=True\n)\n\nfor chunk in response:\n    if chunk.choices[0].delta.content:\n        print(chunk.choices[0].delta.content, end=\"\")\n```\n\n## Technical Highlights\n\n### Experts Scheduling\n\nSee [CPU-GPU Expert Scheduling Tutorial](./experts-sched-Tutorial.md) for details.\n\n### Dual Prefill Mechanism\n\nKTransformers implements an adaptive dual prefill mechanism that selects a mode based on the input token count:\n\n| Mode | Trigger Condition | Computation |\n|------|-------------------|-------------|\n| **CPU-GPU Hybrid** | num_tokens < threshold | GPU + CPU |\n| **Layerwise Prefill** | num_tokens >= threshold | GPU (CPU weights transferred to GPU) |\n\nThe threshold is set with the `--kt-gpu-prefill-token-threshold` parameter; tune it to your workload for best performance.\n\n## Troubleshooting\n\n### OOM (Out of Memory) Issues\n\nLayerwise prefill requires extra VRAM. 
If you encounter OOM, adjust these parameters:\n\n| Parameter | VRAM Impact |\n|-----------|-------------|\n| `--kt-num-gpu-experts` | Reduces expert weight VRAM usage |\n| `--chunked-prefill-size` | Reduces prefill extra VRAM allocation |\n| `--max-total-tokens` | Reduces KV cache VRAM usage |\n| `--mem-fraction-static` | Adjusts static memory fraction |\n\n**Tips:**\n- Test with an input of length `chunked-prefill-size` to verify configuration\n- Reduce `--kt-num-gpu-experts` if GPU memory is limited\n- For multi-GPU setups, ensure `--enable-p2p-check` is enabled\n- For FP8 models, `--fp8-gemm-backend triton` may be required\n\n## Additional Resources\n\n- [KT-Kernel Documentation](../../../kt-kernel/README.md)\n- [SGLang GitHub](https://github.com/sgl-project/sglang)\n- [MiniMax-M2.1 Tutorial](./MiniMax-M2.1-Tutorial.md) - Detailed guide for MiniMax-M2.1 and other FP8 models\n- [Kimi-K2-Thinking Tutorial](./Kimi-K2-Thinking-Native.md) - Detailed guide for Kimi-K2-Thinking\n"
  },
  {
    "path": "doc/en/kt-kernel/Qwen3-Coder-Next-Tutorial.md",
    "content": "# Running Qwen3-Coder-Next with SGLang and KT-Kernel\n\nThis tutorial demonstrates how to run Qwen3-Coder-Next (80B-A3B) model inference using SGLang integrated with KT-Kernel for CPU-GPU heterogeneous inference. Qwen3-Coder-Next is a Mixture-of-Experts code generation model. KT-Kernel supports both BF16 and FP8 precision backends, allowing you to choose between maximum quality and reduced memory footprint.\n\n## Table of Contents\n\n- [Table of Contents](#table-of-contents)\n- [Hardware Requirements](#hardware-requirements)\n- [Prerequisites](#prerequisites)\n- [Step 1: Download Model Weights](#step-1-download-model-weights)\n- [Step 2: Launch SGLang Server](#step-2-launch-sglang-server)\n  - [Key Parameters](#key-parameters)\n- [Step 3: Send Inference Requests](#step-3-send-inference-requests)\n  - [Option A: Interactive Chat with KT CLI](#option-a-interactive-chat-with-kt-cli)\n  - [Option B: OpenAI-Compatible API](#option-b-openai-compatible-api)\n- [Performance](#performance)\n- [Troubleshooting](#troubleshooting)\n  - [OOM (Out of Memory) Issues](#oom-out-of-memory-issues)\n- [Additional Resources](#additional-resources)\n\n## Hardware Requirements\n\n**Recommended Configuration:**\n- **GPU**: 1 x NVIDIA RTX 4090 24 GB\n- **CPU**: x86 CPU with AVX512 support (e.g., Intel Sapphire Rapids, AMD EPYC)\n- **RAM**: At least 100GB system memory for FP8 model weights\n- **Storage**: >85 GB for FP8 model weights (80.4 GB)\n\n## Prerequisites\n\nBefore starting, ensure you have:\n\n1. **SGLang installed**\n\n    Install the kvcache-ai fork of SGLang (one of):\n\n    ```bash\n    # Option A: One-click install (from ktransformers root)\n    ./install.sh\n\n    # Option B: pip install\n    pip install sglang-kt\n    ```\n\n2. **KT-Kernel installed**\n\n    Please follow [kt-kernel](https://github.com/kvcache-ai/ktransformers/blob/main/kt-kernel/README.md)\n\n    After installation, verify the CLI is working:\n\n    ```bash\n    kt version\n    ```\n\n3. 
**CUDA toolkit** - CUDA 12.0+ recommended (12.8+ for best FP8 support)\n4. **Hugging Face CLI** - For downloading models:\n   ```bash\n   pip install -U huggingface-hub\n   ```\n\n## Step 1: Download Model Weights\n\nDownload the Qwen3-Coder-Next weights from Hugging Face.\n\n```bash\n# FP8\nhf download Qwen/Qwen3-Coder-Next-FP8 \\\n  --local-dir /path/to/Qwen3-Coder-Next-FP8\n\n# BF16\nhf download Qwen/Qwen3-Coder-Next \\\n  --local-dir /path/to/Qwen3-Coder-Next\n```\n\n**Note:** Replace `/path/to/` with your actual storage path throughout this tutorial.\n\n## Step 2: Launch SGLang Server\n\nStart the SGLang server with KT-Kernel integration for CPU-GPU heterogeneous inference.\n\n```bash\n# FP8 Precision\npython -m sglang.launch_server \\\n  --host 0.0.0.0 \\\n  --port 30000 \\\n  --model /path/to/Qwen3-Coder-Next-FP8 \\\n  --kt-weight-path /path/to/Qwen3-Coder-Next-FP8 \\\n  --kt-cpuinfer 96 \\\n  --kt-threadpool-count 2 \\\n  --kt-num-gpu-experts 100 \\\n  --kt-method FP8 \\\n  --kt-gpu-prefill-token-threshold 2048 \\\n  --attention-backend triton \\\n  --trust-remote-code \\\n  --mem-fraction-static 0.80 \\\n  --chunked-prefill-size 16384 \\\n  --max-running-requests 4 \\\n  --max-total-tokens 256000 \\\n  --served-model-name Qwen3-Coder-Next \\\n  --enable-mixed-chunk \\\n  --tensor-parallel-size 1 \\\n  --enable-p2p-check \\\n  --disable-shared-experts-fusion \\\n  --fp8-gemm-backend cutlass \\\n  --tool-call-parser qwen3_coder \\\n  --kt-enable-dynamic-expert-update\n\n# BF16 Precision\npython -m sglang.launch_server \\\n  --host 0.0.0.0 \\\n  --port 30000 \\\n  --model /path/to/Qwen3-Coder-Next \\\n  --kt-weight-path /path/to/Qwen3-Coder-Next \\\n  --kt-cpuinfer 96 \\\n  --kt-threadpool-count 2 \\\n  --kt-num-gpu-experts 60 \\\n  --kt-method BF16 \\\n  --kt-gpu-prefill-token-threshold 2048 \\\n  --attention-backend triton \\\n  --trust-remote-code \\\n  --mem-fraction-static 0.80 \\\n  --chunked-prefill-size 16384 \\\n  --max-running-requests 4 \\\n  
--max-total-tokens 256000 \\\n  --served-model-name Qwen3-Coder-Next \\\n  --enable-mixed-chunk \\\n  --tensor-parallel-size 1 \\\n  --enable-p2p-check \\\n  --disable-shared-experts-fusion \\\n  --tool-call-parser qwen3_coder \\\n  --kt-enable-dynamic-expert-update\n```\n\nSee [KT-Kernel Parameters](https://github.com/kvcache-ai/ktransformers/tree/main/kt-kernel#kt-kernel-parameters) for detailed parameter tuning guidelines.\n\n### Key Parameters\n\n| Parameter | Description |\n|-----------|-------------|\n| `--kt-method FP8 / BF16` | Inference precision mode. FP8 halves weight memory; BF16 uses full precision. |\n| `--kt-cpuinfer` | Number of CPU inference threads. |\n| `--kt-threadpool-count` | Number of thread pools. Set to NUMA node count. |\n| `--kt-num-gpu-experts` | Number of experts kept on GPU for decoding. |\n| `--kt-gpu-prefill-token-threshold` | Token threshold for layerwise prefill strategy. |\n| `--kt-enable-dynamic-expert-update` | Enable dynamic expert placement on GPU based on routing statistics. |\n| `--kt-expert-placement-strategy` | Expert placement strategy. Default: `uniform`. See [Expert Scheduling Tutorial](experts-sched-Tutorial.md) for other options. |\n| `--chunked-prefill-size` | Maximum tokens per prefill batch. |\n| `--max-total-tokens` | Maximum total tokens in KV cache. |\n| `--tool-call-parser` | Tool call parser for function calling support (use `qwen3_coder`). |\n| `--fp8-gemm-backend` | GEMM backend for FP8 computation. |\n\n## Step 3: Send Inference Requests\n\nOnce the server is running (default: `http://localhost:30000`), you can interact with the model in several ways:\n\n### Option A: Interactive Chat with KT CLI\n\nThe easiest way to chat with the model:\n\n```bash\nkt chat\n```\n\nThis opens an interactive terminal chat session. Type your messages and press Enter to send. 
Use `Ctrl+C` to exit.\n\n### Option B: OpenAI-Compatible API\n\nThe server exposes an OpenAI-compatible API at `http://localhost:30000/v1`.\n\n**curl example (streaming):**\n\n```bash\ncurl http://localhost:30000/v1/chat/completions \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"model\": \"Qwen3-Coder-Next\",\n    \"messages\": [{\"role\": \"user\", \"content\": \"Write a Python function to compute the Fibonacci sequence.\"}],\n    \"stream\": true\n  }'\n```\n\n**curl example (non-streaming):**\n\n```bash\ncurl -s http://localhost:30000/v1/chat/completions \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"model\": \"Qwen3-Coder-Next\",\n    \"messages\": [{\"role\": \"user\", \"content\": \"Hello! What can you help me with?\"}],\n    \"stream\": false\n  }'\n```\n\n## Performance\n\nThe following benchmarks were measured with single concurrency (Prefill tps / Decode tps):\n\n| GPU | CPU | PCIe | Precision | 64 tokens | 2048 tokens | 8192 tokens | 32768 tokens |\n|-----|-----|------|-----------|-------------|-------------|-------------|--------------|\n| 1 x RTX 5090 (32 GB) | 2 x AMD EPYC 9355 | PCIe 5.0 | FP8  | 362 / 75.9 | 1746 / 75.6 | 2407 / 69.1 | 6233 / 51.7 | \n\n## Troubleshooting\n\n### OOM (Out of Memory) Issues\n\nLayerwise prefill requires extra VRAM. 
If you encounter OOM, adjust these parameters when launching the server:\n\n| Parameter | VRAM Impact |\n|-----------|-------------|\n| `--kt-num-gpu-experts` | Reduces expert weight VRAM usage |\n| `--chunked-prefill-size` | Reduces prefill extra VRAM allocation |\n| `--max-total-tokens` | Reduces KV cache VRAM usage |\n| `--mem-fraction-static` | Lower values reserve more VRAM headroom (default: 0.80) |\n\n**Tip:** Test with an input of length `chunked-prefill-size` to verify your configuration won't OOM during prefill.\n\n## Additional Resources\n\n- [Qwen3-Coder-Next Model Card](https://huggingface.co/Qwen/Qwen3-Coder-Next)\n- [KT-Kernel Documentation](../../../kt-kernel/README.md)\n- [SGLang GitHub](https://github.com/sgl-project/sglang)\n- [KT-Kernel Parameters Reference](../../../kt-kernel/README.md#kt-kernel-parameters)\n"
  },
  {
    "path": "doc/en/kt-kernel/README.md",
    "content": "# kt-kernel Docs"
  },
  {
    "path": "doc/en/kt-kernel/amd_blis.md",
    "content": "\n### USAGE\n1. Use the MOE_INT8 method (i.e. `--kt-method MOE_INT8`).\n2. Build and install the correct AMD BLIS library as described in the Motivation section below.\n3. Before installing, set `export CPUINFER_ENABLE_BLIS=ON` to enable this feature.\n\n### Motivation\n\nThis feature accelerates prefill on AMD CPUs. It is based on the https://github.com/amd/blis repo and requires a build with LPGEMM support. For background on LPGEMM, see: https://www.cs.utexas.edu/~flame/BLISRetreat2024/slides/Bhaskar_BLIS_Retreat_2024_AMD_LPGEMM_0.pdf\nThe code follows this API guide: https://docs.amd.com/r/en-US/57404-AOCL-user-guide/AOCL-BLAS?section=lpgemm-in-aocl-blas\nTo use LPGEMM, see the doc here: \nhttps://www.amd.com/content/dam/amd/en/documents/developer/version-4-1-documents/aocl/aocl-4-1-user-guide.pdf\n<img width=\"2134\" height=\"1240\" alt=\"Image\" src=\"https://github.com/user-attachments/assets/d4008736-c1c7-422e-a747-155fc2eb4141\" />\nIn short, you only need to enable the aocl_gemm add-on when building BLIS. Examples are here: https://github.com/amd/blis/blob/master/docs/CMakeBuildSystem.md\n\n<img width=\"2222\" height=\"702\" alt=\"Image\" src=\"https://github.com/user-attachments/assets/bf924b69-e01d-460d-b4cd-122e77ec982d\" />\nYou can see how to install it.\n"
  },
  {
    "path": "doc/en/kt-kernel/deepseek-v3.2-sglang-tutorial.md",
    "content": "# Running DeepSeek V3.2 with SGLang and KT-Kernel\n\nThis tutorial demonstrates how to run DeepSeek V3.2 model inference using SGLang integrated with KT-Kernel for CPU-GPU heterogeneous inference. This setup enables efficient deployment of large MoE models by offloading experts to CPU.\n\n## Table of Contents\n\n- [Hardware Requirements](#hardware-requirements)\n- [Prerequisites](#prerequisites)\n- [Step 1: Download Model Weights](#step-1-download-model-weights)\n- [Step 2: Quantize CPU Weights](#step-2-quantize-cpu-weights)\n- [Step 3: Launch SGLang Server](#step-3-launch-sglang-server)\n- [Step 4: Send Inference Requests](#step-4-send-inference-requests)\n\n## Hardware Requirements\n\n**Minimum Configuration:**\n- **GPU**: NVIDIA L20 48GB (or equivalent with at least 27GB VRAM available)\n- **CPU**: Intel Xeon with AMX support (e.g., Sapphire Rapids)\n- **RAM**: At least 350GB system memory for INT4 quantization\n- **Storage**: ~1TB for model weights (FP8 + INT4 quantized)\n\n**Tested Configuration:**\n- **GPU**: NVIDIA L20 48GB\n- **CPU**: Intel(R) Xeon(R) Platinum 8488C\n- **RAM**: 2TB DDR5\n- **OS**: Linux (Ubuntu 20.04+ recommended)\n\n## Prerequisites\n\nBefore starting, ensure you have:\n\n1. **KT-Kernel installed** - Follow the [installation guide](./kt-kernel_intro.md#installation)\n2. **SGLang installed** - Install the kvcache-ai fork: `pip install sglang-kt` or run `./install.sh` from the ktransformers root\n3. **CUDA toolkit** - Compatible with your GPU (CUDA 11.8+ recommended)\n4. **Hugging Face CLI** - For downloading models:\n   ```bash\n   pip install huggingface-hub\n   ```\n\n## Step 1: Download Model Weights\n\nDeepSeek V3.2 requires downloading model repositories:\n\n1. **DeepSeek-V3.2**\n2. 
**DeepSeek-V3.2-Speciale**\n\n```bash\n# Create a directory for models\nmkdir -p /path/to/models\ncd /path/to/models\n\n# Download DeepSeek-V3.2 (FP8 weights for GPU)\nhuggingface-cli download deepseek-ai/DeepSeek-V3.2 \\\n  --local-dir /path/to/deepseek-v3.2\n\n# Download DeepSeek-V3.2-Speciale (if needed)\nhuggingface-cli download deepseek-ai/DeepSeek-V3.2-Speciale \\\n  --local-dir /path/to/deepseek-v3.2-speciale\n```\n\n**Note:** Replace `/path/to/models` with your actual storage path throughout this tutorial.\n\n## Step 2: Quantize CPU Weights\n\nConvert the FP8 GPU weights to INT4 quantized CPU weights using the provided conversion script.\n\n### Conversion Command\n\nFor a 2-NUMA system with 60 physical cores:\n\n```bash\ncd /path/to/ktransformers/kt-kernel\n\npython scripts/convert_cpu_weights.py \\\n  --input-path /path/to/deepseek-v3.2 \\\n  --input-type fp8 \\\n  --output /path/to/deepseek-v3.2-INT4 \\\n  --quant-method int4 \\\n  --cpuinfer-threads 60 \\\n  --threadpool-count 2 \\\n  --no-merge-safetensor\n```\n\n## Step 3: Launch SGLang Server\n\nStart the SGLang server with KT-Kernel integration for CPU-GPU heterogeneous inference.\n\n### Launch Command\n\nFor single NVIDIA L20 48GB + 2-NUMA CPU system:\n\n```bash\npython -m sglang.launch_server \\\n  --host 0.0.0.0 \\\n  --port 30000 \\\n  --model /path/to/deepseek-v3.2 \\\n  --kt-weight-path /path/to/deepseek-v3.2-INT4 \\\n  --kt-cpuinfer 60 \\\n  --kt-threadpool-count 2 \\\n  --kt-num-gpu-experts 1 \\\n  --attention-backend triton \\\n  --trust-remote-code \\\n  --mem-fraction-static 0.98 \\\n  --chunked-prefill-size 4096 \\\n  --max-running-requests 32 \\\n  --max-total-tokens 40000 \\\n  --served-model-name DeepSeek-V3.2 \\\n  --enable-mixed-chunk \\\n  --tensor-parallel-size 1 \\\n  --enable-p2p-check \\\n  --disable-shared-experts-fusion \\\n  --kt-method AMXINT4\n```\n\n### Resource Usage\n\n- **GPU VRAM:** ~27GB (for 1 GPU expert per layer + attention)\n- **System RAM:** ~350GB (for INT4 
quantized CPU experts)\n\n## Step 4: Send Inference Requests\n\nOnce the server is running, you can send inference requests using the OpenAI-compatible API.\n\n### Basic Chat Completion Request\n\n```bash\ncurl -s http://localhost:30000/v1/chat/completions \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"model\": \"DeepSeek-V3.2\",\n    \"stream\": false,\n    \"messages\": [\n      {\"role\": \"user\", \"content\": \"hi\"}\n    ]\n  }'\n```\n\n### Example Response\n\n```json\n{\n  \"id\": \"adbb44f6aafb4b58b167e42fbbb1eed3\",\n  \"object\": \"chat.completion\",\n  \"created\": 1764675126,\n  \"model\": \"DeepSeek-V3.2\",\n  \"choices\": [\n    {\n      \"index\": 0,\n      \"message\": {\n        \"role\": \"assistant\",\n        \"content\": \"Hi there! 👋 \\n\\nThanks for stopping by! How can I help you today? Feel free to ask me anything - I'm here to assist with questions, explanations, conversations, or whatever you need! 😊\\n\\nIs there something specific on your mind, or would you like to know more about what I can do?\",\n        \"reasoning_content\": null,\n        \"tool_calls\": null\n      },\n      \"logprobs\": null,\n      \"finish_reason\": \"stop\",\n      \"matched_stop\": 1\n    }\n  ],\n  \"usage\": {\n    \"prompt_tokens\": 5,\n    \"total_tokens\": 72,\n    \"completion_tokens\": 67,\n    \"prompt_tokens_details\": null,\n    \"reasoning_tokens\": 0\n  },\n  \"metadata\": {\n    \"weight_version\": \"default\"\n  }\n}\n```\n\n## Additional Resources\n\n- [KT-Kernel Documentation](../../../kt-kernel/README.md)\n- [DeepSeek V3.2 Model Card](https://huggingface.co/deepseek-ai/DeepSeek-V3.2)\n- [SGLang GitHub](https://github.com/sgl-project/sglang)"
  },
  {
    "path": "doc/en/kt-kernel/experts-sched-Tutorial.md",
    "content": "# CPU-GPU Expert Scheduling Tutorial\n\nThis tutorial demonstrates how to use the CPU-GPU expert scheduling feature in KTransformers with SGLang. This feature introduces a flexible GPU expert mask system that allows intelligent placement of MoE experts across CPU and GPU, optimizing inference performance based on workload patterns.\n\n## Table of Contents\n\n- [Table of Contents](#table-of-contents)\n- [Hardware Requirements](#hardware-requirements)\n- [Prerequisites](#prerequisites)\n- [Step 1: Download Model Weights](#step-1-download-model-weights)\n- [Step 2: Launch Server with Expert Scheduling](#step-2-launch-server-with-expert-scheduling)\n  - [Basic Usage](#basic-usage)\n  - [Expert Placement Strategies](#expert-placement-strategies)\n  - [Key Parameters](#key-parameters)\n- [Step 3: Send Inference Requests](#step-3-send-inference-requests)\n  - [Option A: Interactive Chat with KT CLI](#option-a-interactive-chat-with-kt-cli)\n  - [Option B: OpenAI-Compatible API](#option-b-openai-compatible-api)\n- [Performance](#performance)\n- [Troubleshooting](#troubleshooting)\n- [Additional Resources](#additional-resources)\n\n## Hardware Requirements\n\n**Minimum Configuration:**\n- **GPU**: NVIDIA RTX 4090 24 GB (or equivalent with at least 24GB VRAM available)\n- **CPU**: x86 CPU with AVX512 support (e.g., Intel Sapphire Rapids, AMD EPYC)\n- **RAM**: At least 256GB system memory\n- **Storage**: Sufficient space for model weights\n\n**Tested Configuration:**\n\n- **GPU**: 4 x NVIDIA GeForce RTX 4090 (24 GB)\n- **CPU**: Intel Xeon Gold 6454S\n- **RAM**: 512GB DDR5\n- **OS**: Linux (Ubuntu 20.04+ recommended)\n\n## Prerequisites\n\nBefore starting, ensure you have:\n\n1. **SGLang installed**\n\n    Install the kvcache-ai fork of SGLang (one of):\n\n    ```bash\n    # Option A: One-click install (from ktransformers root)\n    ./install.sh\n\n    # Option B: pip install\n    pip install sglang-kt\n    ```\n\n2. 
**KTransformers installed**\n\n    ```bash\n    git clone https://github.com/kvcache-ai/ktransformers.git\n    cd ktransformers/kt-kernel\n    bash ./install.sh\n    ```\n\n    After installation, verify the CLI is working:\n\n    ```bash\n    kt version\n    ```\n\n3. **CUDA toolkit** - CUDA 12.0+ recommended\n4. **Hugging Face CLI** - For downloading models:\n   ```bash\n   pip install -U huggingface-hub\n   ```\n\n## Step 1: Download Model Weights\n\nDownload your preferred MoE model weights. This feature supports various MoE models including:\n\n* **Qwen3-Next-80B-A3B-Instruct-FP8**\n\n    ```bash\n    huggingface-cli download Qwen/Qwen3-Next-80B-A3B-Instruct-FP8 --local-dir /path/to/qwen3-next-80b\n    ```\n\n## Step 2: Launch Server with Expert Scheduling\n\n### Basic Usage\n\nThe simplest way to start the server with expert scheduling:\n\n```bash\npython -m sglang.launch_server \\\n    --model /path/to/model \\\n    --kt-num-gpu-experts 8 \\\n    --kt-expert-placement-strategy uniform\n```\n\n### Expert Placement Strategies\n\nThe system provides four expert placement strategies:\n\n| Strategy | Description | Use Case |\n|----------|-------------|----------|\n| `uniform` | Distributes GPU experts evenly across all MoE layers | Default, no prior statistics needed |\n| `frequency` | Places most frequently activated experts on GPU | Best performance when activation statistics are available |\n| `front-loading` | Fills GPU experts from the first layer onwards | Testing or specific workload patterns |\n| `random` | Randomly selects experts with fixed seed (42) | Baseline comparison |\n\n**Using Frequency Strategy (Recommended for best performance):**\n\n```bash\npython -m sglang.launch_server \\\n    --model /path/to/model \\\n    --kt-num-gpu-experts 8 \\\n    --kt-expert-placement-strategy frequency \\\n    --init-expert-location /path/to/activation_stats.pt\n```\n\n**Using Dynamic Expert Update:**\n\n```bash\npython -m sglang.launch_server \\\n    --model 
/path/to/model \\\n    --kt-num-gpu-experts 8 \\\n    --kt-expert-placement-strategy frequency \\\n    --init-expert-location /path/to/activation_stats.pt \\\n    --kt-enable-dynamic-expert-update \\\n    --kt-gpu-prefill-token-threshold 512\n```\n\n### Key Parameters\n\n| Parameter | Description |\n|-----------|-------------|\n| `--kt-num-gpu-experts` | Number of GPU experts per MoE layer. Internally multiplied by the number of MoE layers to get the total GPU experts. Ignored if `--kt-gpu-experts-ratio` is set. |\n| `--kt-gpu-experts-ratio` | Ratio of total experts to place on GPU (0.0-1.0). If set, overrides `--kt-num-gpu-experts`. Example: 0.1 means 10% of all experts across all layers will be on GPU. |\n| `--kt-expert-placement-strategy` | Expert placement strategy: `frequency`, `uniform`, `front-loading`, or `random`. Default: `uniform`. |\n| `--init-expert-location` | Path to activation statistics file (`.pt`) for `frequency` strategy. |\n| `--kt-enable-dynamic-expert-update` | Enable dynamic expert update during inference. |\n| `--kt-gpu-prefill-token-threshold` | Token threshold for triggering dynamic expert redistribution during prefill. |\n| `--record-kt-gpu-expert-distribution` | Enable recording of GPU expert distribution for analysis. |\n| `--expert-distribution-recorder-mode` | Recording mode: `stat` (default), `stat_approx`, `per_pass`, or `per_token`. |\n\n## Step 3: Send Inference Requests\n\nOnce the server is running (default: `http://localhost:30000`), you can interact with the model in several ways:\n\n### Option A: Interactive Chat with KT CLI\n\nThe easiest way to chat with the model:\n\n```bash\nkt chat\n```\n\nThis opens an interactive terminal chat session. Type your messages and press Enter to send. 
Use `Ctrl+C` to exit.\n\n### Option B: OpenAI-Compatible API\n\nThe server exposes an OpenAI-compatible API at `http://localhost:30000/v1`.\n\n**curl example (streaming):**\n\n```bash\ncurl http://localhost:30000/v1/chat/completions \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"model\": \"model-name\",\n    \"messages\": [{\"role\": \"user\", \"content\": \"Hello!\"}],\n    \"stream\": true\n  }'\n```\n\n## Performance\n\n### Throughput (tokens/s)\n\nThe following benchmarks were measured on Qwen3-Next-80B-A3B-Instruct-FP8 with 4 x RTX 4090, Intel Xeon Gold 6454S, tensor parallel size 4, using ShareGPT dataset:\n\n| GPU Expert Ratio | random | uniform | front-loading | frequency | dynamic-expert-update |\n|------------------|--------|---------|---------------|-----------|----------------------|\n| 0% | 53.01 | 52.96 | 54.18 | 52.72 | 53.37 |\n| 10% | 56.63 | 56.57 | 57.18 | 58.60 | 70.22 |\n| 20% | 58.75 | 60.28 | 58.82 | 61.92 | 74.73 |\n| 30% | 62.86 | 62.08 | 63.87 | 66.50 | 75.55 |\n| 40% | 66.81 | 66.82 | 67.45 | 72.78 | 80.98 |\n| 50% | 70.38 | 65.25 | 73.65 | 76.19 | 81.17 |\n| 60% | 71.33 | 72.80 | 77.95 | 82.33 | 82.30 |\n| 70% | 74.40 | 76.17 | 81.59 | 89.37 | 88.70 |\n| 80% | 79.71 | 79.20 | 89.20 | 100.67 | 92.31 |\n| 90% | 88.82 | 81.06 | 98.14 | 107.15 | 95.04 |\n| 100% | 112.61 | 112.32 | 111.82 | 114.26 | 112.99 |\n\nThe `frequency` and `dynamic-expert-update` strategies show significant performance improvements over baseline strategies, especially at lower GPU expert ratios.\n\n## Troubleshooting\n\n### OOM (Out of Memory) Issues\n\nIf you encounter OOM, adjust these parameters when launching the server:\n\n| Parameter | VRAM Impact |\n|-----------|-------------|\n| `--kt-num-gpu-experts` / `--kt-gpu-experts-ratio` | Reduces expert weight VRAM usage |\n| `--chunked-prefill-size` | Reduces prefill extra VRAM allocation |\n| `--max-total-tokens` | Reduces KV cache VRAM usage |\n\n### Dynamic Expert Update Not Triggering\n\nEnsure all 
conditions are met:\n1. `--kt-enable-dynamic-expert-update` is enabled\n2. `--kt-gpu-prefill-token-threshold` is set\n3. Prefill length >= threshold value\n\n### Statistics Recording\n\nTo save expert distribution statistics to a custom path, set the environment variable:\n\n```bash\nexport SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR=/path/to/output\n```\n\n## Additional Resources\n\n- [KT-Kernel Documentation](../../../kt-kernel/README.md)\n- [SGLang GitHub](https://github.com/sgl-project/sglang)\n- [KTransformers GitHub](https://github.com/kvcache-ai/ktransformers)\n"
  },
  {
    "path": "doc/en/kt-kernel/kt-cli.md",
    "content": "# KT-CLI\n\n> ⚠️ **Note:** This feature is currently under active development. Many functionalities are not yet complete and are being improved. Please stay tuned for updates.\n\n## Design Philosophy\n\nKT-CLI is designed to **minimize the burden of reading documentation**. Instead of requiring users to read lengthy docs, the CLI provides:\n\n- **Interactive Mode**: Run commands without arguments to get step-by-step guided prompts\n- **Direct Mode**: Pass arguments directly for automation and scripting\n    > 💡 **Tip:** The arguments are fully compatible with the previous SGLang + KTransformers approach, so you can migrate seamlessly.\n\nSimply run a command, and the CLI will interactively guide you through the process!\n\n## Usage\n\nYou can check the usage by `kt --help`\n\n```\nkt [OPTIONS] COMMAND [ARGS]...\n```\n\nKTransformers CLI - A unified command-line interface for KTransformers.\n\n## Options\n\n| Option | Description |\n|--------|-------------|\n| `--help` | Show this message and exit. |\n\n## Commands\n\n| Command | Description |\n|---------|-------------|\n| `version` | Show version information |\n| `chat` | Interactive chat with running model |\n| `quant` | Quantize model weights |\n| `bench` | Run full benchmark |\n| `microbench` | Run micro-benchmark |\n| `doctor` | Diagnose environment issues |\n| `model` | Manage models and storage paths |\n| `config` | Manage configuration |\n| `sft` | Fine-tuning with LlamaFactory |\n"
  },
  {
    "path": "doc/en/llama4.md",
    "content": "# 🦙 Tutorial: LLaMA 4 Multi-Concurrency Support with KTransformers (Balance Serve Backend)\n\n## 📌 Overview\n\nWe are pleased to announce that **KTransformers** now provides **experimental support for LLaMA 4 models** through the powerful `balance_serve` backend introduced in **v0.2.4**. This update is available under the dedicated development branch: [`support-llama4`](https://github.com/kvcache-ai/ktransformers/tree/support-llama4), specifically targeting the newly released **Meta LLaMA 4** model architecture.\n\n⚠️ This support is currently **not available on the main branch** due to dependencies on newer versions of `transformers`, and **compatibility limitations with inference of currently supported models**. Work is underway to integrate this into the mainline once broader stability and compatibility are validated.\n\n💡 **If you already have an environment based on the main branch**, it is **strongly recommended to create a new environment** to avoid potential dependency conflicts.\n\n------\n\n## 🔗 Model & Resource Links\n\n- 🔥 Official LLaMA 4 Release: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct\n   (Note: LLaMA 4 models are served through the Meta repository. Make sure to **agree to terms** before downloading.)\n- 🧠 GGUF Format (quantized models):\n  - https://huggingface.co/unsloth/Llama-4-Scout-17B-16E-Instruct-GGUF\n\n------\n\n## 🧪 Demo\n\nhttps://github.com/user-attachments/assets/449706f1-784b-4931-b2ba-07687c1aca54\n\n------\n\n## Resource Requirements\n\nThe Scout model running with 16 Experts requires approximately 65 GB of memory and 10 GB of GPU memory, while the Maverick model with 128 Experts requires approximately 270 GB of memory and 12 GB of GPU memory.\n\n------\n\n## ⚙️ Usage Instructions\n\n### 1. Clone `support-llama4` Branch\n\n```bash\ngit clone https://github.com/kvcache-ai/ktransformers.git\ncd ktransformers\ngit checkout support-llama4\ngit submodule update --init --recursive\n```\n\n### 2. 
Set Up Environment\n\n```bash\n# Download and install Miniconda\nwget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh\nbash Miniconda3-latest-Linux-x86_64.sh\n\n# Create environment\nconda create --name ktransformers python=3.11\nconda activate ktransformers\n\n# Install required libraries\nconda install -c conda-forge libstdcxx-ng\n\n# Verify GLIBCXX version (should include 3.4.32)\n# CONDA_PREFIX points at the active environment, wherever conda is installed\nstrings $CONDA_PREFIX/lib/libstdc++.so.6 | grep GLIBCXX\n\nsudo apt install libtbb-dev libssl-dev libcurl4-openssl-dev libaio1 libaio-dev libfmt-dev libgflags-dev zlib1g-dev patchelf\npip3 install packaging ninja cpufeature numpy openai\npip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126\n```\n\n### 3. Build with Balance Serve Support\n\nRun one of the following, depending on your hardware:\n\n```bash\n# Single-NUMA build\nUSE_BALANCE_SERVE=1 bash ./install.sh\n# Dual-NUMA build (two CPU sockets and 1 TB of RAM):\nUSE_BALANCE_SERVE=1 USE_NUMA=1 bash ./install.sh\n```\n\n### 4. Use our custom config.json\n\nCurrently, you need to copy the content of our custom config file into the `config.json` under your `--model_path`.  \n- Use [scout_config.json](https://github.com/kvcache-ai/ktransformers/blob/support-llama4/doc/en/scout_config.json) for the Llama-4-Scout-17B-16E model  \n- Use [maverick_config.json](https://github.com/kvcache-ai/ktransformers/blob/support-llama4/doc/en/maverick_config.json) for the Llama-4-Maverick-17B-128E model  \n\nMake sure the content of `config.json` is replaced with the file that matches your model.\n\n### 5. 
Run LLaMA 4 Inference Server\n\nMake sure you have:\n\n- `--model_path` pointing to a local config directory (not a Hugging Face name).\n- `--gguf_path` pointing to the folder containing quantized `.gguf` weights.\n\n```bash\npython ktransformers/server/main.py \\\n  --port 10002 \\\n  --model_path <path_to_safetensor_config> \\\n  --gguf_path <path_to_gguf_files> \\\n  --optimize_config_path ktransformers/optimize/optimize_rules/Llama4-serve.yaml \\\n  --max_new_tokens 1024 \\\n  --cache_lens 32768 \\\n  --chunk_size 256 \\\n  --max_batch_size 4 \\\n  --backend_type balance_serve\n```\n\n### 6. Access server\n\n```bash\ncurl -X POST http://localhost:10002/v1/chat/completions \\\n  -H \"accept: application/json\" \\\n  -H \"Content-Type: application/json\" \\\n  -d '{\n    \"messages\": [\n      {\"role\": \"user\", \"content\": \"hello\"}\n    ],\n    \"model\": \"Llama4\",\n    \"temperature\": 0.3,\n    \"top_p\": 1.0,\n    \"stream\": true\n  }'\n```\n\n------\n\n## 📌 Limitations\n\n- ✅ **Only `balance_serve` backend is supported** for LLaMA 4 models in this version.\n- ⚠️ Requires **`transformers>=4.51.0`**. Due to potential compatibility issues with older toolchains, we have **not merged this branch to main yet**.\n- ❌ Multimodal models are not supported yet in this version. Support will be added in future releases.\n"
  },
  {
    "path": "doc/en/long_context_introduction.md",
    "content": "# KVCache Long Context\n\n## TL;DR\n\nTraining larger models and supporting longer text sequences are currently the two most widely agreed-upon directions toward achieving AGI. After lowering the barrier for local inference with trillion-parameter MoE models, the second showcase scenario for KTransformers is reducing the inference barrier for ultra-long context sequences. Recently, both ChatGLM and InternLM have released open-source models supporting 1M tokens of context. This article will use InternLM2.5-7B-Chat-1M as an example to introduce a method that leverages the sparsity of attention to accelerate long-text inference on heterogeneous CPU/GPU systems.\n\nAfter optimization, KTransformers has achieved native-precision inference for 128K and even 1M tokens of context on a single 24GB GPU with CPU/DRAM support. In the 128K context scenario, the generation speed is 7.1 times faster than llama.cpp, while also achieving 100% accuracy on relatively simple test sets like \"needle in haystack\" and \"passkey\". On the more challenging dataset kvretrieval, through flexible framework configurations, we can achieve a **6.22x speedup** during inference while obtaining even higher scores than running the original model directly (**21.2 -> 24.4**). In the 1M context scenario on a single 24GB GPU, KTransformers can similarly achieve a 16 tokens/s inference speed, nearly 10 times faster than llama.cpp under the same conditions, with the \"needle in haystack\" evaluation score even surpassing the original model (**89.31 -> 92.88**).\n\nProject url: https://github.com/kvcache-ai/ktransformers\n\n## Mathematical Principle: The computational overhead of long-text inference and the sparsity in Attention caused by Softmax.\n\nAs the demand for longer context windows increases, not only have commercial large models like Kimi and Claude/Gemini started supporting increasingly longer context windows, but open-source models have also begun to catch up. 
Notably, both ChatGLM 4 and InternLM 2.5 have released versions that are under 10 billion parameters but support up to 1 million tokens of context. However, despite the relatively small size of these models, the enormous KVCache required for such ultra-long contexts still prevents local users from practically running these models. As shown in the figure below, while the InternLM2.5-7B-Chat-1M model weights only require 15.49GB of GPU memory, an additional 145.49GB is needed to store the entire 1M-token KVCache, which is clearly beyond the memory capacity of local users. Even when using the KVCache Offload feature of llama.cpp to offload the KVCache to CPU/DRAM, barely making the model runnable, performance remains unacceptable due to the need to fully scan the entire KVCache each time a single token is generated.\n\n| <img title=\"\" src=\"../assets/internlm_memory.png\" alt=\"internlm_memory\" width=\"882\"> | <img src=\"../assets/SparQ_attention.png\" title=\"\" alt=\"sparQ\" width=\"691\"> |\n| ------------------------------------------------------------------------------------ | -------------------------------------------------------------------------- |\n\nFortunately, many studies have noticed that attention distribution during the inference phase tends to be **sparse**. For example, the right figure shows SparQ's experimental statistics based on LLaMa 7B, where less than 1% of tokens in a 3k context have relatively high attention scores. Similar conclusions are not only reflected in many other papers, such as H2O, Quest, InfLLM, and SnapKV, but we have also further validated this through long-text experiments with InternLM 2.5-7B-1M. 
Although the proportion is not as extreme as 1%, the softmax operation in attention inherently concentrates weight on a small set of tokens. In theory, if we could identify in advance which tokens have high attention scores, scanning less than 5% of the tokens would suffice to essentially reproduce the original result.\n\nThus, the problem narrows down to how to quickly identify these tokens with high attention scores without scanning them all. In the following sections, we first briefly survey several key related papers, then summarize them and propose a general framework we designed and implemented within KTransformers: a highly efficient sparse attention operator for CPUs.\n\n## Related Papers and Conclusions\n\n### Prune or Retrieve?\n\nBased on the aforementioned points, we studied papers from recent years related to sparse selection in KVCache. The earliest of these is the paper H2O, which suggested that the attention distribution during inference is sparse and that only 5% of the KVCache is needed during inference. Following this, a series of works built on H2O's approach by designing more complex methods for selecting tokens that perform better in different scenarios. These methods are quite reasonable for one-off inference. However, as we previously explored in the Mooncake project, **we believe that the future trend is to precompute reusable KVCache as much as possible, and then use it to answer different questions.** This \"compute once, use many\" approach aims to reduce computational costs. With this goal in mind, we prefer not to delete any tokens from the KVCache, or at least not to remove a significant portion of them, so that different questions can focus on different parts of the context in the future.\n\n![InfLLM Framework](../assets/InfLLM_framework.png)\n\nWe further investigated related research, among which InfLLM proposed a very promising framework. 
Not only does it recognize that attention is sparse, but it also suggests that overly long contexts can cause attention to be dispersed into irrelevant noise, thereby reducing the model's ability to focus on key information. To address this issue, InfLLM introduces an external memory module (Memory Units) to store the context's KVCache. In each computation step, the most relevant semantic information is retrieved from this external memory module to participate in the calculation, thus enhancing the model's ability to handle long-context inference.\n\nSpecifically, InfLLM organizes the external memory module using semantic blocks composed of neighboring tokens and employs a sliding window mechanism during computation. In each step, it selects only the semantic blocks at the head of the context (Initial Tokens), the blocks near the current token (Local Tokens), and a few blocks with the highest semantic similarity to the current token to participate in the attention calculation. As shown in equation 1, to efficiently retrieve the blocks with the highest similarity, InfLLM selects a few representative tokens whose scores $$r_m$$ are the highest within each block. Use Equation 2 to calculate the semantic similarity between the current token and each semantic block.\n\n![InfLLM Equation](../assets/InfLLM_equation.jpg)\n\nCompared to the previously mentioned H2O, the differences in InfLLM are as follows:\n\n1. The KVCache is not discarded but stored in memory and dynamically loaded onto the GPU during inference.\n\n2. KVCache is managed at the granularity of blocks rather than tokens, with each block selecting a few tokens as its representative index tokens.\n\nInfLLM's proposed method aligns with our \"compute once, use many\" approach of reusing KVCache. The external memory units in this method can be offloaded to CPU/DRAM or even SSD storage, allowing different parts to be selected for computation based on the specific question. 
This significantly improves the efficiency of attention computation.\n\n### Other Improvements\n\nSimilarly, after InfLLM, Quest also manages tokens at the granularity of blocks. Quest analyzed the recall rate of key tokens in H2O and full attention, finding that the Top-10 attention score token recall rate for the H2O algorithm is around 50%, which indicates that too much key information was lost. To improve the recall rate of key tokens, Quest chooses two \"representative tokens\" from each block for retrieval. In the prefill stage, each KVCache block records the maximum and minimum values for each channel, as shown in the figure below under \"Reduced Keys,\" which contains the element-wise min key and element-wise max key.\n\nDuring the attention computation stage, the dot product is computed between the current query vector and the max key and min key of each KVCache block, respectively. Then, for each channel, the maximum value between the two resulting product vectors is selected and summed to serve as the upper bound of the relevance score for that KVCache block, as shown in stage 1 of the diagram. Based on the relevance scores, the top-k KVCache blocks are selected to participate in the attention computation, as illustrated in stage 2 of the diagram.\n\n![Quest Framework](../assets/Quest_framework.png)\n\nCompared to InfLLM, Quest does not take heterogeneous architectures into account. Instead, it assumes that all KVCache can still fit into memory, simply leveraging sparse attention to accelerate the inference process. Ultimately, Quest achieves a 7.03x speedup in attention computation and a 2.23x improvement in end-to-end inference latency.\n\nGoing further, SnapKV proposes retaining two parts of the tokens during the prefill stage, as shown in the diagram below with the orange and green segments. The difference from InfLLM lies only in the method of selecting the middle tokens. 
SnapKV selects tokens at the token level rather than the block level, with the score calculation being similar to H2O, i.e., $$softmax(\\frac{qk^T}{\\sqrt{d_k}})$$. However, when summing across columns, only the rows within the final green window are selected for computation, corresponding to the Local Tokens section in InfLLM. Additionally, SnapKV introduces a pooling operation on top of attention, which the paper explains as ensuring that the recalled tokens retain more complete semantic information.\n\nThis approach in SnapKV involves a one-time selection during the inference phase, after which only the selected tokens are used for attention computation, while the rest of the KVCache is discarded.\n\n![SnapKV Framework](../assets/SnapKV_framework.png)\n\n\nOther related papers include PyramidKV, which observed that attention scores exhibit a pyramid-shaped distribution across attention layers. In lower attention layers, attention is widely distributed, while in higher layers, the attention scores for a few key tokens become increasingly prominent. Therefore, PyramidKV allocates more KVCache storage space to lower layers and less space to higher layers.\n\nMagicPiG, based on Locality-Sensitive Hashing (LSH), proposes a dynamic KVCache management strategy. First, it uses SnapKV to select a portion of important tokens to be stored in the GPU, while the KVCache of other tokens is placed in memory. By leveraging the high efficiency of LSH in high-dimensional space searches and the multithreading capabilities of CPUs, MagicPiG retrieves KVCache from memory that is similar to the current query and loads it into memory for inference. Compared to the earlier methods like InfLLM, Quest, and SnapKV, MagicPiG does not need to scan all representative tokens and select the top-k KVCache. 
Instead, it utilizes the mathematical properties of LSH, which not only simulates attention scores but also allows for identifying important KVCache with low overhead and high speed.\n\nThe above are just descriptions of some key points. For more detailed explanations, you can refer to the existing articles on Zhihu in Chinese:\n\n- https://zhuanlan.zhihu.com/p/701580870\n\n- https://zhuanlan.zhihu.com/p/714288577\n\n## KTransformers CPU Sparse Attn Framework\n\n### Framework Prototype\n\nBased on the introduction of the above papers, we have distilled the following key points:\n\n- The distribution of attention weights is sparse, and useless KVCache may introduce noise, which could actually reduce performance during the inference stage.\n\n- For the KVCache eviction strategy during the inference stage, the common approach is to retain the tokens from the beginning and the end of the prompt, while designing algorithms to select the tokens from the middle portion. One of the main factors affecting the model's performance is the ability to accurately identify the key tokens.\n\n- Managing the middle portion of tokens in blocks can improve memory swapping and attention computation efficiency, and smaller blocks do not seem to perform worse than token-level granularity.\n\n- The tokens that each attention layer focuses on during inference differ, and even the allocated KVCache capacity for different layers should vary.\n\nBased on these insights and inspirations, we developed a general framework for implementing sparse CPU attention operators during the inference phase. In the prefill stage, we use chunked prefill, loading only one layer of KVCache into GPU memory at a time for computation. Once completed, the KVCache is stored on CPU/DRAM. In the subsequent decode stage, instead of swapping KVCache in and out, the sparse attention operator runs directly on the CPU. 
**This significantly reduces the minimum GPU memory requirements, making local 128K or even 1M token contexts possible.**\n\nFor the generation phase specifically, we implemented the entire framework as shown in the diagram below.\n\n![KTransformers long context v1](../assets/KTransformers_long_context_v1.png)\n\nWe organized the KVCache in units of blocks. Specifically:\n\n- **KVCache Partitioning:** A complete input prompt is divided into three configurable parts: Initial, Context, and Local. During computation, the Initial and Local parts are fully attended to, while the Context part is sparsely retrieved. This approach is based on findings from many papers (such as StreamingLLM and MInference) that report the existence of \"attention sinks,\" where higher attention weights are often found at the beginning and the end of the sequence.\n\n- **Context Block Partitioning:** For the middle Context, we follow the InfLLM approach by dividing it into blocks of a configurable fixed number of tokens. Each block can select 1 to k tokens as its representative tokens. During the actual inference phase, the Context blocks that require attention are selected based on these representative tokens.\n  \n  - Specifically, we have implemented the following methods for selecting representative tokens, based on the approaches outlined in various papers.\n    \n    - Max: The per-channel maximum over the tokens in a block forms the representative token for that block.\n    \n    - Mean: The per-channel mean over the tokens in a block forms the representative token for that block.\n    \n    - Quest: The per-channel maximum and the per-channel minimum over the tokens in a block are taken as the representative tokens for the block. 
Under this method, the number of representative tokens is fixed at 2.\n    \n    - Dynamic: By calculating the cumulative attention score for each token using a specific method, each block selects the top-k tokens with the highest scores as the representative tokens for the block. This is similar to InfLLM but with some simplifications.\n    \n    - Fix: Select tokens at fixed intervals within the block.\n  \n  - Once the representative tokens for each block are determined, use Equation 2 from InfLLM to calculate the similarity between the input X and the k representative tokens of each block B, and only select the top $$r_k$$ blocks for attention computation, where $$l_P$$ represents the length of the historical tokens.\n\nSince InfLLM requires calculating a representative score for each token during the prefill stage and then selecting a representative token for each block based on these scores, this operation involves invasive modifications to the prefill implementation, making it difficult to integrate with other methods. Furthermore, in actual testing, we found that in most scenarios, similar or even better results can be achieved through a combination of other methods. Therefore, we ultimately decided not to integrate this method into the framework.\n\n## Further Optimizations\n\nAfter implementing the above framework, we conducted a series of evaluations based on LongBench and InfiniteBench.\n\nAt the beginning of the experiment, we designed the architecture so that for each inference token, the most relevant KVCache blocks would be reselected. On the one hand, this strategy incurred significant overhead during the retrieval process. On the other hand, we found that in some scenarios, **frequently changing the selection of retrieved blocks did not lead to better results**. For example, in the kvretrieval dataset, we observed that the model's responses were often correct in the first half but incorrect in the second half. 
Since the answers to kvretrieval questions consist of long and meaningless strings, this indicates that the correct KVCache blocks were selected during the inference of the earlier tokens but incorrect blocks were chosen during the later stages of inference.\n\nTo address this issue, we further integrated the method proposed in SnapKV. Before starting the inference, we preselect relevant KVCache blocks by analyzing the attention scores of the context tokens, based on the question. During the subsequent inference stages, the selection of KVCache blocks is restricted to this preselected range. This approach allowed us to select the block containing the correct answer 100% of the time in the kvretrieval dataset.\n\nHowever, it should be noted that this method strictly relies on the structure of the Benchmark Prompt and **does not necessarily guarantee optimal performance in other scenarios, such as complex document understanding and generation tasks.** Therefore, we have integrated it into our framework as an optional module. 
The final framework and configurable parameters are as follows:\n\n![KTransformers long context v2](../assets/KTransformers_long_context_v2.png)\n\n\nConfiguration:\n\n- **threads_num:** Number of CPU Threads\n\n- **block_size:** KVCache Block Size\n\n- **local_windows_len:** Prompt End Window Size\n\n- **preselect_block_count:** Number of Preselected Blocks\n\n- **second_block_count:** Number of Blocks Selected After Preselection\n\n- **preselect_block:** Whether to Enable Preselection\n\n- **token_step:** Interval Between Token Selections for KVCache\n\n- **layer_step:** Interval Between Layer Selections for KVCache\n\n- **dense_layer_num:** Number of Initial Layers That Skip KVCache Selection and Use the Full KVCache\n\n- **head_select_mode:** **SEPARATE** (in the GQA scenario, each kv_head is selected separately) / **SHARED** (all kv_heads are selected together)\n\n- **representative_type:** Method of Selecting Representative Tokens\n\n- **representative_num:** Number of Representative Tokens\n\nBy modifying configuration options, various KVCache eviction or compression methods can be easily reproduced within our framework. 
For example:\n\n- Setting `block_size` to 1 and `preselect_block` to True results in a version of SnapKV without the pooling operation.\n\n- Setting `representative_type` to Quest, `preselect_block` to False, and `head_select_mode` to SEPARATE replicates the Quest method.\n\nBelow is the pseudocode for the framework:\n\n```python\nimport math\n\nimport torch\nfrom torch import nn\n\n# Pseudocode: pool1d, find_representative_tokens_block, retrieval_kvcache, attn,\n# model, config and the size constants are assumed to be defined elsewhere.\n\ndef preselect_block(local_q, kvcache):\n    # Score every cached key against the queries of the local window.\n    key_states = kvcache.keycache\n    attn_scores = torch.matmul(\n        local_q, key_states.transpose(2, 3)\n    ) / math.sqrt(head_dim)\n    attn_scores += attn_mask\n    attn_scores = nn.functional.softmax(\n        attn_scores, dim=-1, dtype=torch.float32\n    ).to(local_q.dtype)\n    # Sum over the local-window queries (rows) to get one vote per key in the\n    # middle Context part (columns between the Initial and Local parts).\n    vote = attn_scores[..., initial_size:-local_size].sum(dim=-2)\n    pool_vote = pool1d(vote, kernel_size=kernel_size, padding=kernel_size // 2, stride=1)\n    indices = pool_vote.topk(max_capacity_prompt - local_size, dim=-1).indices\n    kv_cache_block_indices = find_representative_tokens_block(indices)\n    kvcache_after_preselected = kvcache[kv_cache_block_indices]\n    ...\n    return kvcache_after_preselected\n\ndef get_representative_tokens(kvcache, representative_type):\n    # Calculate the representative tokens for each block based on representative_type.\n    return ...\n\ndef decode_attention(query, key, value):\n    # Select once every token_steps tokens.\n    token_steps = 4\n    # Select once every layer_steps layers.\n    layer_steps = 4\n    # Most recent selection for each group of layer_steps layers.\n    history_kvcache_after_retrieval = {}\n    for token_idx in range(max_new_tokens):\n        for layer_idx in range(config.num_hidden_layers):\n            if token_idx % token_steps != 0 or layer_idx % layer_steps != 0:\n                # No reselection for this layer in this round: reuse the stored selection.\n                kvcache_after_retrieval = history_kvcache_after_retrieval[layer_idx // layer_steps]\n            else:\n                # Otherwise, use the query of the current layer in this round to reselect the kvcache.\n                kvcache_after_retrieval = retrieval_kvcache(query, kvcache)\n                # Save it to the kvcache historical selection results.\n                history_kvcache_after_retrieval[layer_idx // layer_steps] = kvcache_after_retrieval\n            # Calculate attention over the selected blocks only.\n            output = attn(query, kvcache_after_retrieval)\n            yield output\n\n# Model prefill. If preselection is required, local_q still needs to be saved.\nlocal_q, kvcache = model.prefill(input_ids)\nif config.preselect_block:\n    # Preselection round\n    kvcache = preselect_block(local_q, kvcache)\n# Find the representative tokens for each block.\nblock_representative_tokens = get_representative_tokens(\n    kvcache,\n    config.representative_type,\n)\n\n# Model generation: decode_attention is a generator, so it has to be iterated.\nfor output in decode_attention(query, key, value):\n    ...\n```\n\n## Experiment\n\nWe start testing with the following basic configuration, which is then tuned further using the extended framework.\n\n```yaml\nmax_seq_len: 256000 # KVCache length\nblock_size: 128 # KVCache block size\nlocal_windows_len: 4096 # The KVCache of length local_windows_len is stored on the GPU.\nsecond_block_count: 96 # Number of KVCache blocks selected at each step after preselection. 
If >= preselect_block_count, use the preselected blocks.\nthreads_num: 64 # CPU thread num\nrepresentative_type: DYNAMIC # KVCache block representative token selection method.\nkv_type: FP16\ndense_layer_num: 0 # The first few layers do not need to fill or select KVCache\nrepresentative_num: 1 # The number of representative tokens within a KVCache block.\npreselect_block: False # Whether to preselect.\nhead_select_mode: SHARED # All kv_heads jointly select.\npreselect_block_count: 0 # Number of preselected blocks.\nlayer_step: 1 # Select every few layers.\ntoken_step: 1 # Select every few tokens.\n```\n\nUnder our framework, the comparison between the original model and the accelerated KTransformers on datasets such as 128K Big Needle-in-a-Haystack, passkey, and kvretrieval is as follows. The passkey dataset involves inserting a small segment of numbers at varying depths within a redundant text. kvretrieval is about finding a matching item in randomly generated key-value pairs. All tests were conducted under the opencompass framework:\n\n![needle_128K.png](../assets/needle_128K.png)\n\n|  | Single needle retrieval zh 128k | passkey | kvretrieval |\n| --- | --- | --- | --- |\n| Original model | 99.89 | 100 | 21.0 |\n| KTransformers (reselect KVCache blocks for each generation) | 100 | 100 | 15.40 |\n\nWe can see that both the original model and the accelerated KTransformers achieve perfect or near-perfect scores on the relatively simpler datasets, such as Single Needle Retrieval and passkey. 
At the same time, generation speed improves significantly, from 4.86 tokens/s with llama.cpp to 27.49 tokens/s with KTransformers, a speedup of up to 5.65x. The current configuration does show a noticeable drop on the more challenging kvretrieval dataset; in the next section, we address this with a better selection strategy that matches or even surpasses the original model's accuracy.\n\nAdditionally, we tested how well the KTransformers-based configuration framework reproduces the results of Quest. However, since InternLM2.5-7B-Chat-1M uses GQA (Grouped Query Attention) while the Quest paper primarily focuses on optimizing MHA (Multi-Head Attention) models, the actual test results were not particularly favorable. The official team also mentioned that further support for GQA models is needed, so we will not discuss this in detail for now.\n\n### Further improve performance\n\nStarting from the reproduction and modifying certain configurations within our flexible framework, **we can actually achieve better results than those reported in the previous paper,** as shown in the figure below:\n\n![](../assets/Framework_effect.png)\n\nAs mentioned earlier, the goal of the kvretrieval dataset is to find a matching key-value pair within a long sequence of semantically meaningless pairs. If the KVCache blocks are reselected based on the current query for every generated token, the likelihood of deviation increases as the text grows, and the selected KVCache blocks drift away from previous selections. To address this, we introduced a preselection mechanism that uses SnapKV's method for computing representative tokens to preselect a portion of the KVCache blocks. During the subsequent inference process, the selection is limited to these blocks. 
After one round of preselection, the score increased from 15.4 to 24.2, **surpassing the original model + full attention's performance of 21 points.** Further research indicates that the sparsity effect of the KVCache in the first few layers of LLMs is not as significant. Therefore, we set the first two layers to fully reuse the KVCache, ultimately achieving a score of **24.4**.\n\nSimilarly, when testing the needle-in-a-haystack task on the 1M dataset, we not only reproduced the original model's reported score but also further improved accuracy (**from 89.31 to 92.88**) by using the KTransformers CPU Sparse Attn Framework to selectively compute only certain KVCache blocks. Additionally, the inference speed **reached nearly 10 times that of llama.cpp**.\n\n![needle 1M.png](../assets/needle_1M.png)\n\n### More comparisons\n\nAs shown in the two figures below, using the Single Needle Retrieval dataset as an example, we set llama.cpp to store the KVCache on CPU/DRAM while performing all computations on the GPU. On a 4090D server, we compared the KTransformers CPU Sparse Attn Framework with llama.cpp. While maintaining **100% answer accuracy**, we achieved a 20.6 to 94.1 times prefill speed increase and a **1.2 to 7.1 times inference speed boost**.\n\n| ![long context prefill.png](../assets/long_context_prefill.png) | ![long context generate.png](../assets/long_context_generate.png) |\n| --------------------------------------------------------------- | ----------------------------------------------------------------- |\n\nThe main reason for the significant gap in prefill speed is that after enabling KVCache offload, llama.cpp performs the attention (attn) computation on the CPU. In long-text scenarios, attention not only requires heavy computation but also takes up the majority of the computation time. In contrast, KTransformers leverages a flexible template injection framework to implement GPU Chunk Prefill layer by layer. 
Moving forward, we plan to integrate high-performance sparse prefill methods such as MInference to boost speed even further.\n\nAdditionally, as a key focus of this article, the right-hand graph shows that as the prompt length increases, the inference speed of KTransformers remains stable, hovering near a horizontal line. In contrast, llama.cpp slows down as the prompt length increases. By selecting only the most important KVCache blocks (16K tokens in total) to participate in the inference computation, KTransformers maintains a consistent inference speed comparable to llama.cpp when processing a 16K prompt, without any performance degradation (at least on these test datasets).\n\n## How to Use\n\nCurrently, long context is only supported by our **local_chat.py** interface; integration with the server interface is under development.\n\nTo make things easier to manage, we have uploaded the model config, gguf, and tokenizer to a single repo. URL: https://huggingface.co/nilv234/internlm2_5_to_llama_1m/tree/main\n\nBy setting the model_path and gguf_path in the local_chat function to **/path/to/repo** and setting the mode to **\"long_context\"**, you can use the InternLM2.5-7B-Chat-1M model with its 1M context on a GPU with 24GB of VRAM.\n\nAfter running local_chat.py for the first time, a config.yaml file will be automatically created under **~/.ktransformers**. The relevant configurations for long context are as follows:\n\n```yaml\nchunk_size: 4096 # prefill chunk size\nmax_seq_len: 100000 # KVCache length\nblock_size: 128 # KVCache block size\nlocal_windows_len: 4096 # The KVCache of length local_windows_len is stored on the GPU.\nsecond_select_num: 96 # After preselection, each time select the number of KVCache blocks. 
If >= preselect_block_count, use the preselected blocks.\nthreads_num: 64 # CPU thread num\nanchor_type: DYNAMIC # KVCache block representative token selection method.\nkv_type: FP16\ndense_layer_num: 0 # The first few layers do not need to fill or select KVCache\nanchor_num: 1 # The number of representative tokens within a KVCache block.\npreselect_block: False # Whether to preselect.\nhead_select_mode: SHARED # All kv_heads jointly select.\npreselect_block_count: 96 # Number of preselected blocks.\nlayer_step: 1 # Select every few layers.\ntoken_step: 1 # Select every few tokens.\n```\n\nThe memory required for different context lengths is shown in the table below:\n\n|                | 4K  | 32K  | 64K  | 128K | 512K | 1M     |\n| -------------- | --- | ---- | ---- | ---- | ---- | ------ |\n| DRAM Size (GB) | 0.5 | 4.29 | 8.58 | 17.1 | 68.7 | 145.49 |\n\nPlease choose an appropriate max_seq_len based on your DRAM size.\nFor example:\n```bash\npython local_chat.py --model_path=\"/data/model/internlm2_5_to_llama_1m\"  --gguf_path=\"/data/model/internlm2_5_to_llama_1m\" --max_new_tokens=500 --cpu_infer=10  --use_cuda_graph=True  --mode=\"long_context\" --prompt_file=\"/path/to/file\"\n```\n\nIf you've already specified the input text via the prompt_file, just press Enter when the terminal displays `chat:` to begin.\n"
  },
  {
    "path": "doc/en/long_context_tutorial.md",
    "content": "## How to use ktransformers long context framework\n\nCurrently, long context is only supported by our **local_chat.py** interface, and the integration with the server interface is under development.\n\nTo facilitate user management, we have uploaded the model config, gguf, and tokenizer to a repo. URL: https://huggingface.co/nilv234/internlm2_5_to_llama_1m/tree/main\n\nBy setting the model_path and gguf_path in the local_chat function to **/path/to/repo** and setting the mode to **\"long_context\"**, you can use the InternLM2.5-7B-Chat-1M model with 1m functionality on a 24G VRAM.\n\nAfter running local_chat.py for the first time, a config.yaml file will be automatically created under ** ~/.ktransformers**. The relevant configurations for long context are as follows:\n\n```python\nchunk_size: 4096 # prefill chunk size\nmax_seq_len: 100000 # KVCache length\nblock_size: 128 # KVCache block size\nlocal_windows_len: 4096 # The KVCache of length local_windows_len is stored on the GPU.\nsecond_select_num: 96 # After preselection, each time select the number of KVCache blocks. 
If >= preselect_block_count, use the preselected blocks.\nthreads_num: 64 # CPU thread num\nanchor_type: DYNAMIC # KVCache block representative token selection method.\nkv_type: FP16\ndense_layer_num: 0 # The first few layers do not need to fill or select KVCache\nanchor_num: 1 # The number of representative tokens within a KVCache block.\npreselect_block: False # Whether to preselect.\nhead_select_mode: SHARED # All kv_heads jointly select.\npreselect_block_count: 96 # Number of preselected blocks.\nlayer_step: 1 # Select every few layers.\ntoken_step: 1 # Select every few tokens.\n```\n\nThe memory required for different context lengths is shown in the table below:\n\n|                | 4K  | 32K  | 64K  | 128K | 512K | 1M     |\n| -------------- | --- | ---- | ---- | ---- | ---- | ------ |\n| DRAM Size (GB) | 0.5 | 4.29 | 8.58 | 17.1 | 68.7 | 145.49 |\n\nPlease choose an appropriate max_seq_len based on your DRAM size.\nFor example:\n```bash\npython local_chat.py --model_path=\"/data/model/internlm2_5_to_llama_1m\"  --gguf_path=\"/data/model/internlm2_5_to_llama_1m\" --max_new_tokens=500 --cpu_infer=10  --use_cuda_graph=True  --mode=\"long_context\" --prompt_file=\"/path/to/file\"\n```\n\nIf you've already specified the input text via the prompt_file, just press Enter when the terminal displays `chat:` to begin."
  },
  {
    "path": "doc/en/makefile_usage.md",
    "content": "# Makefile\n## Target\n### flake_find:\n```bash\nmake flake_find\n```\nfind all the python files under ./ktransformers dir and find the Error, Warning, Fatal... (their codes) into a list that are not consistent with the pep8 standard. For now we have get all this list in the .flake8 file's extend-ignore section in order to let flakes8 ignore them temporarily.(we may improve them in the future)\n### format:\n```bash\nmake format\n```\nwe use black to format all the python files under ./ktransformers dir. It obeys the pep8 standard \nbut we modify the line length to 120 by add \n```toml\n[tool.black]\nline-length = 120\npreview = true\nunstable = true\n```\nin the pyproject.toml file.\n\n### dev_install:\n```bash\nmake dev_install\n```\ninstall the package in the development mode. It means that the package is installed in the editable mode. So if you modify the code, you don't need to reinstall the package. We recommend the developer to use this method to install the package."
  },
  {
    "path": "doc/en/multi-gpu-tutorial.md",
    "content": "\n# Muti-GPU\n\nAssume you have read the [Injection Tutorial](./injection_tutorial.md) and have a basic understanding of how to inject a model. In this tutorial, we will show you how to use KTransformers to run a model on multiple GPUs.\n\nIf you have multiple GPUs, you can set the device for each module to different GPUs. \nDeepseekV2-Chat got 60 layers, if we got 2 GPUs, we can allocate 30 layers to each GPU. Complete multi GPU rule examples [here](https://github.com/kvcache-ai/ktransformers/blob/main/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml).\n\n\n<p align=\"center\">\n  <picture>\n    <img alt=\"Inject-Struction\" src=\"../assets/multi_gpu.png\" width=60%>\n  </picture>\n</p>\n\nFirst of all, for multi-GPU, we have to inject an new operator `KDeepseekV2Model`. And set division of the layers to different GPUs. For our case, we have to set the `transfer_map` in the `KDeepseekV2Model` operatoras as follows:\n\n```yaml\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      transfer_map: \n        30: \"cuda:1\"\n```\n\nAnd we have to set the device for each module in the model. \n\nFor example, for `routed experts`, the yaml for one GPU is:\n```yaml\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # Custom MoE kernel with expert parallelism\n    kwargs:\n      generate_device: \"cuda:0\"\n      generate_op: \"MLPCUDAExperts\"\n      out_device: \"cuda:0\"\n  recursive: False # Don't recursively inject submodules of this module\n```\nBut for two GPUs, we need to set the device for each module in the model. 
\n\n```yaml\n# allocate the out_device of layers 0-29 to cuda:0\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE kernel with expert parallelism\n    kwargs:\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:0\"\n  recursive: False # don't recursively inject submodules of this module\n\n# allocate the out_device of layers 30-59 to cuda:1\n- match:\n    name: \"^model\\\\.layers\\\\.([345][0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE kernel with expert parallelism\n    kwargs:\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:1\"\n  recursive: False # don't recursively inject submodules of this module\n```\nFor other modules, we can set the device in the same way.\n\n# How to fully utilize multi-GPU's VRAM\n\nWhen you have multiple GPUs, you can fully utilize the VRAM of each GPU by moving more weights to the GPU.\n\nFor example, for DeepSeekV2-Chat, we can move the weights of the experts to the GPU. \n\nSuppose the yaml for two GPUs is:\n```yaml\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:0\"\n  recursive: False\n```\n\nIf we have an extra 60GB of VRAM on cuda:0, we can move the experts in layers 4~8 to cuda:0. 
\n\n```yaml\n# Add the new rule before the old rule.\n- match:\n    name: \"^model\\\\.layers\\\\.([4-8])\\\\.mlp\\\\.experts$\" # inject the experts in layers 4~8 as Marlin experts\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts  \n    kwargs:\n      generate_device: \"cuda:0\"\n      generate_op:  \"KExpertsMarlin\"\n  recursive: False\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     \n    kwargs:\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:0\"\n  recursive: False \n```\n\nAdjust the layer range as needed. Note that:\n* Loading becomes significantly slower with each expert moved to the GPU.\n* You have to disable CUDA graph if you want to move the experts to the GPU.\n* For DeepSeek-R1/V3, each expert moved to the GPU will consume approximately 6GB of VRAM.\n* Only the first matching rule in the yaml is applied. For example, if two rules match the same layer, only the first rule's replacement takes effect.\n\n\n"
  },
  {
    "path": "doc/en/operators/llamafile.md",
    "content": "# Llamafile Operators Documentation\n\n## Llamafile Sgemm\n\nThe Llamafile Sgemm module is an efficient implementation of general matrix multiplication (GEMM) extracted from the great [Llamafile project](https://github.com/Mozilla-Ocho/llamafile/blob/main/llamafile/sgemm.cpp). \nThis module optimizes performance by utilizing various processor-specific instruction sets. For instance, it checks for different x86 instruction sets such as AVX, FMA, and AVX512, leveraging these advanced instructions to accelerate computation. \nAdditionally, the Llamafile Sgemm module supports multiple quantization types, including q8_0, q6_k, and q5_k, among others. This adaptability to different hardware capabilities ensures the most advanced instructions are used in any given computing environment, achieving high computational efficiency. For more information, you can view the [Llamafile Sgemm module](https://github.com/Mozilla-Ocho/llamafile/blob/main/llamafile/sgemm.cpp) on GitHub.\n\n\n## CPUInfer\nTo power Llamafile and many future CPU kernels without the original GGML framework, we developed a simple CPUInfer multi-threaded execution framework. It currently leverages the Llamafile Sgemm module to implement key operators such as linear layers, MLP, and MoE, and will be extended to support many other operators. These operators are fundamental components for building large models. CPUInfer features a backend work-stealing thread pool and asynchronous task queue execution logic to efficiently offload parts of model parameters to the CPU, thereby maintaining high inference performance. It supports adjustments based on hardware capabilities or user configurations, providing enhanced inference performance and making it an ideal tool for running deep learning models on CPUs.\n\n## Expert-Parallel MoE\n\nThe MoE module's performance can be enhanced by using custom kernels that utilize **expert parallelism**. 
Since the routed experts are independently computable, we can utilize this inherent parallelism to speed up MoE computations. Specifically, we can allocate each expert MLP to a separate thread group, allowing for the simultaneous computation of all routed experts. This approach of expert parallelism significantly boosts MoE performance by minimizing the frequency of global synchronizations and reducing kernel launch overhead compared to sequential expert computation.\n\n## Microbenchmark\n\nOur evaluations were conducted on an Intel(R) Xeon(R) Gold 6454S processor, utilizing real parameters from the DeepSeek-Coder-V2-Instruct model.\n\n### Linear Projection\n\nThe performance of the linear layer was assessed using an Attention Output Projection with dimensions of [5120, 16384]. Here, the input was a vector of 16384 dimensions, and the output was a vector of 5120 dimensions.\n\n![Linear_projection_time](Linear_projection_time.png)\n\nAs we can see, in half-precision floating-point formats (fp16 and bf16), CPUInfer's performance exceeded that of Torch by 1.7 and 1.5 times, respectively. For 8-bit quantization, CPUInfer (supporting q8_0) and Torch (supporting qint8) demonstrated nearly equivalent performance. However, CPUInfer employs a more refined scaling approach, using different factors for each group (in q8_0 quantization, every 32 numbers form one group), whereas Torch uses a basic per-tensor quantization, potentially leading to significant precision loss. Furthermore, CPUInfer’s capability to use lower-bit quantization enhances inference speed in specific scenarios.\n\n### MoE\n\nIn the MoE module, each token selected 6 experts out of 160 for computation, with input and output dimensions of 5120, and an intermediate dimension of 1536.\n\n![Combined_MoE_time_per_layer](Combined_MoE_time_per_layer.png)\n\nFor half-precision floating points and 8-bit quantization formats, CPUInfer's generation performance was 2.5 and 3.2 times better than Torch, respectively. 
Moreover, using the 8-bit quantization format, CPUInfer achieved faster prefill speeds compared to Torch, with shorter prompts highlighting a more pronounced performance difference.\n"
  },
  {
    "path": "doc/en/prefix_cache.md",
    "content": "## Enabling Prefix Cache Mode in KTransformers\n\nBalance serve now supports prefix cache reuse! To enable **Prefix Cache Mode** in KTransformers, you need to modify the configuration file and recompile the project. \n\n### Step 1: Modify the Configuration File\n\nEdit the `./ktransformers/configs/config.yaml` file with the following content (you can adjust the values according to your needs):\n\n```yaml\nattn:\n  page_size: 16 # Size of a page in KV Cache.\n  chunk_size: 256\nkvc2:\n  gpu_only: false # Set to false to enable prefix cache mode (Disk + CPU + GPU KV storage)\n  utilization_percentage: 1.0\n  cpu_memory_size_GB: 500 # Amount of CPU memory allocated for KV Cache\n  disk_path: /mnt/data/kvc # Path to store KV Cache on disk\n```\n\n### Step 2: Update Submodules and Recompile\n\nIf this is your first time using prefix cache mode, please update the submodules first:\n\n```bash\ngit submodule update --init --recursive # Update PhotonLibOS submodule\n```\n\nThen recompile the project:\n\n```bash\n# Install single NUMA dependencies\nUSE_BALANCE_SERVE=1  bash ./install.sh\n# For those who have two cpu and 1T RAM（Dual NUMA）:\nUSE_BALANCE_SERVE=1 USE_NUMA=1 bash ./install.sh\n```\n\n## Note\nBalance serve utilizes a 3-layer (GPU-CPU-Disk) scheme to store and reuse KVCache. Deleting KVCache is not supported now. If you have too much KVCache, you can simply delete them by remove kvcache files. \n\n"
  },
  {
    "path": "doc/en/xpu.md",
    "content": "# Intel GPU Support for KTransformers (Beta)\n\n## Introduction\n\n### Overview\nWe are excited to introduce **Intel GPU support** in KTransformers (Beta release). This implementation has been tested and developed using Intel Xeon Scalable processors and Intel Arc GPUs (such as A770 and B580).\n\n## Installation Guide\n\n### 1. Install Intel GPU Driver\nBegin by installing the GPU drivers for your Intel GPU:\n- [Official GPU Installation Guide for Intel GPUs](https://dgpu-docs.intel.com/driver/overview.html)\n\nTo verify that the kernel and compute drivers are installed and functional:\n\n```bash\nclinfo --list | grep Device\n `-- Device #0: 13th Gen Intel(R) Core(TM) i9-13900K\n `-- Device #0: Intel(R) Arc(TM) A770 Graphics\n `-- Device #0: Intel(R) UHD Graphics 770\n```\n\n> [!Important]\n> Ensure that **Resizable BAR** is enabled in your system's BIOS before proceeding. This is essential for optimal GPU performance and to avoid potential issues such as `Bus error (core dumped)`. For detailed steps, please refer to the official guidance [here](https://www.intel.com/content/www/us/en/support/articles/000090831/graphics.html).\n\n### 2. Set Up Conda Environment\nWe recommend using Miniconda3/Anaconda3 for environment management:\n\n```bash\n# Download Miniconda\nwget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh\n\n# Create environment\nconda create --name ktransformers python=3.11\nconda activate ktransformers\n\n# Install required libraries\nconda install -c conda-forge libstdcxx-ng\n\n# Verify GLIBCXX version (should include 3.4.32)\nstrings ~/anaconda3/envs/ktransformers/lib/libstdc++.so.6 | grep GLIBCXX\n```\n\n> **Note:** Adjust the Anaconda path if your installation directory differs from `~/anaconda3`\n\n### 3. 
Install PyTorch and IPEX-LLM\nInstall PyTorch with XPU backend support and [IPEX-LLM](https://github.com/intel/ipex-llm):\n\n```bash\npip install ipex-llm[xpu_2.6]==2.3.0b20250518 --extra-index-url https://download.pytorch.org/whl/xpu\npip uninstall torch torchvision torchaudio\npip install torch==2.7+xpu torchvision torchaudio --index-url https://download.pytorch.org/whl/test/xpu # install torch2.7\npip uninstall intel-opencl-rt dpcpp-cpp-rt\n```\n\n### 4. Build ktransformers\n\n```bash\n# Clone repository\ngit clone https://github.com/kvcache-ai/ktransformers.git\ncd ktransformers\ngit submodule update --init\n\n# Install dependencies\nbash install.sh --dev xpu\n```\n\n## Running DeepSeek-R1 Models\n\n### Configuration for 16GB VRAM GPUs\nUse our optimized configuration for constrained VRAM:\n\n```bash\nexport SYCL_CACHE_PERSISTENT=1\nexport ONEAPI_DEVICE_SELECTOR=level_zero:0\nexport SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1\n\npython ktransformers/local_chat.py \\\n  --model_path deepseek-ai/DeepSeek-R1 \\\n  --gguf_path <path_to_gguf_files> \\\n  --optimize_config_path ktransformers/optimize/optimize_rules/xpu/DeepSeek-V3-Chat.yaml \\\n  --cpu_infer <cpu_cores + 1> \\\n  --device xpu \\\n  --max_new_tokens 200\n```\n\n## Known Limitations\n- The serving function is not yet supported on the Intel GPU platform\n\n## Troubleshooting\n1. Best Known Config (BKC) to obtain best performance\n\nTo obtain the best performance on the Intel GPU platform, we recommend locking the GPU frequency and setting the CPU to performance mode with the settings below.\n```bash\necho \"performance\" | sudo tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor\necho 0 | sudo tee /sys/devices/system/cpu/cpu*/power/energy_perf_bias\n# 2400 is the max frequency for Arc A770\nsudo xpu-smi config -d 0 -t 0 --frequencyrange 2400,2400\n# 2850 is the max frequency for Arc B580\n# sudo xpu-smi config -d 0 -t 0 --frequencyrange 2850,2850\n```\n\n2. Runtime error like `xpu/sycl/TensorCompareKernels.cpp:163: xxx. 
Aborted (core dumped)`\n\nThis error is usually caused by the GPU driver. If you encounter it, update `intel-level-zero-gpu` to `1.3.29735.27-914~22.04` (a version we have verified) with the commands below.\n```bash\nwget -qO - https://repositories.intel.com/gpu/intel-graphics.key | \\\nsudo gpg --dearmor --output /usr/share/keyrings/intel-graphics.gpg\necho \"deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/gpu/ubuntu jammy client\" | \\\nsudo tee /etc/apt/sources.list.d/intel-gpu-jammy.list\nsudo apt update\n# or sudo apt update --allow-insecure-repositories\nsudo apt install intel-level-zero-gpu=1.3.29735.27-914~22.04\n```\n\n3. `ImportError: cannot import name 'intel' from 'triton._C.libtriton'`\n\nInstalling Triton causes pytorch-triton-xpu to stop working. You can resolve the issue with the following commands:\n```bash\npip uninstall triton pytorch-triton-xpu\n# Reinstall the correct version of pytorch-triton-xpu\npip install pytorch-triton-xpu==3.3.0 --index-url  https://download.pytorch.org/whl/xpu\n```\n\n4. `ValueError: Unsupported backend: CUDA_HOME ROCM_HOME MUSA_HOME are not set and XPU is not available.`\n\nEnsure you have permission to access /dev/dri/renderD*. This typically requires your user to be in the render group:\n```bash\nsudo gpasswd -a ${USER} render\nnewgrp render\n```\n\n## Additional Information\nTo run KTransformers on XPU with Docker, please refer to [Docker_xpu.md](./Docker_xpu.md).\n"
  },
  {
    "path": "doc/zh/DeepseekR1_V3_tutorial_zh.md",
    "content": "<!-- omit in toc -->\n\n# GPT-4/o1 级别本地 VSCode Copilot 在仅 24GB 显存的台式机上的表现\n\n- [摘要](#摘要)\n  - [先决条件](#先决条件)\n  - [基准测试结果](#基准测试结果)\n    - [V0.2](#v02)\n      - [设置](#设置)\n      - [内存占用](#内存占用)\n      - [基准测试结果](#基准测试结果)\n    - [V0.3-Preview](#V0.3-Preview)\n      - [设置](#设置-1)\n      - [内存占用](#内存占用-1)\n      - [基准测试结果](#基准测试结果-1)\n  - [如何运行](#如何运行)\n    - [V0.2 展示](#v02-展示)\n      - [单插槽版本 (32 核心)](#单插槽版本（32 核心）)\n      - [双插槽版本 (64 核心)](#双插槽版本（64 核心）)\n    - [V0.3 展示](#v03-展示)\n      - [双插槽版本 (64 核心)](#双插槽版本（64 核心）-1)\n  - [一些解释](#一些解释)\n  - [常见问题解答](#常见问题解答)\n    - [R1 不思考](#R1 不返回思考过程)\n    - [更多常见问题解答](#更多常见问题解答)\n\n# 摘要\n\n> **2025年2月10日**: 支持在单个（24GB 显存）/多个 GPU 和 382GB 内存上运行 DeepseekR1 和 V3，速度提升高达 3~28 倍。<br>\n\n嗨，我们是 KTransformers 团队（以前因本地 CPU/GPU 混合推理开源项目 DeepSeek-V2 而闻名）。\n\n我们听到了您对 DeepSeek-R1/V3 支持的请求——我们很高兴终于可以交付了！很抱歉让您久等了，但我们一直在酝酿一些真正令人惊叹的东西！\n\n今天，我们自豪地宣布，我们不仅支持 DeepSeek-R1/V3，如下视频所示：\n\nhttps://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285\n\n</p>\n\n- **[NEW!!!] 
本地 671B DeepSeek-Coder-V3/R1:** 仅使用 14GB 显存和 382GB 内存运行其 Q4_K_M 版本。\n  - 预填充(Prefill)速度 (tokens/s):\n    - KTransformers: 54.21 (32 核心) → 74.362 (双插槽，2×32 核心) → 255.26 (优化的 AMX 基 MoE 内核，仅 V0.3) → 286.55 (选择性使用 6 个专家，仅 V0.3)\n    - 与 llama.cpp 在 2×32 核心下 10.31 tokens/s 相比，速度提升高达 **27.79 倍**\n  - 解码(Decode)速度 (tokens/s):\n    - KTransformers: 8.73 (32 核心) → 11.26 (双插槽， 2×32 核心) → 13.69 (选择性使用 6 个专家，仅 V0.3)\n    - 与 llama.cpp 在 2×32 核心下 4.51 tokens/s 相比，速度提升高达 **3.03 倍**\n\n我们还提供了即将推出的优化预览，包括英特尔 AMX 加速内核和选择性专家激活方法，这将显著提升性能。通过 V0.3 预览版，我们在预填充方面实现了高达 286 tokens/s 的速度，比本地推理的 llama.cpp **快 28 倍**。二进制发行版现已可用，源代码即将推出！请查看 wheel 包 [此处](https://github.com/kvcache-ai/ktransformers/releases/download/v0.1.4/ktransformers-0.3.0rc0+cu126torch26fancy-cp311-cp311-linux_x86_64.whl) 。\n\n## 先决条件\n\n我们在以下配置下进行了最佳性能测试（V0.2）： <br>\nCPU: Intel (R) Xeon (R) Gold 6454S 1T 内存 (2 NUMA 节点) <br>\nGPU: 4090D 24G 显存 <br>\n内存: 标准 DDR5-4800 服务器内存 (1 TB)\n\n## 基准测试结果\n\n### V0.2\n\n#### 设置\n\n- Model: DeepseekV3-q4km (int4)<br>\n- CPU: cpu_model_name: Intel (R) Xeon (R) Gold 6454S，每个插槽 32 核心，2 个插槽，2 个 NUMA 节点\n- GPU: 4090D 24G 显存\n- 我们在充分预热后进行测试\n\n#### 内存占用:\n\n- 单插槽: 382G 内存，至少 14GB 显存\n- 双插槽: 1T 内存，至少 14GB 显存\n\n#### 基准测试结果\n\n“6 个专家” 情况是 V0.3 预览版中内容\n\n\n| Prompt<br>(500 tokens)  | 双插槽 Ktrans (6 个专家) | 双插槽 Ktrans (8 个专家) | Single socket Ktrans (6 个专家) | Single socket Ktrans (8 个专家) | llama.cpp (8 个专家) |\n| ----------------------- | ------------------------ | ------------------------ | ------------------------------- | ------------------------------- | -------------------- |\n| 预填充(Prefill) token/s | 97.32                    | 82.94                    | 65.14                           | 54.21                           | 10.31                |\n| 解码(Decode) token/s    | 13.69                    | 12.208                   | 10.303                          | 8.73                            | 4.51                 |\n\n**最高加速比在解码方面达到 <u>3.03x</u> 倍，在预填充方面达到 <u>9.44x</u> 倍。**\n\n### V0.3-Preview\n\n#### 
设置\n\n- Model: DeepseekV3-BF16 (在线量化为 CPU 的 int8 和 GPU 的 int4)\n- CPU: cpu_model_name: Intel (R) Xeon (R) Gold 6454S，每个插槽 32 核心，2 个插槽，2 个 NUMA 节点\n- GPU: (1~4)x 4090D 24G 显存 (更长的 prompt 需要更多显存)\n\n#### 内存占用:\n\n- 644GB 内存，至少 14GB 显存\n\n#### 基准测试结果\n\n\n| Prompt length                     | 1K     | 2K     | 4K     | 8K     |\n| --------------------------------- | ------ | ------ | ------ | ------ |\n| KTrans (8 个专家) Prefill token/s | 185.96 | 255.26 | 252.58 | 195.62 |\n| KTrans (6 个专家) Prefill token/s | 203.70 | 286.55 | 271.08 | 207.20 |\n\n**KTrans V0.3 的预填充速度比 KTrans V0.2 快 <u>3.45x</u> 倍，比 llama.cpp 快 <u>27.79x</u> 倍。**\n**解码速度与 KTrans V0.2（6 个专家版本）相同，因此省略。**\n\n主要加速来自于\n\n- 英特尔 AMX 指令集和我们专门设计的缓存友好内存布局\n- 专家选择策略，根据离线配置文件结果选择更少的专家\n\n*从我们对 DeepSeekV2、DeepSeekV3 和 DeepSeekR1 的研究中，当我们略微减少推理中的激活专家数量时，输出质量没有变化。但解码和预填充的速度加快了，这令人鼓舞。因此，我们的展示利用了这一发现。*\n\n## 如何运行\n\n### 多并发展示\n\n多并发需要额外编译调度器 c++ 代码\n\n```shell\nsudo apt install libtbb-dev libssl-dev libcurl4-openssl-dev libaio1 libaio-dev libfmt-dev\nsudo apt-get install libgflags-dev zlib1g-dev patchelf\ngit clone https://github.com/kvcache-ai/ktransformers.git\ncd ktransformers\ngit submodule update --init --recursive\n# 如果使用双 numa 版本\nUSE_BALANCE_SERVE=1 USE_NUMA=1 bash ./install.sh\n# 如果使用单 numa 版本\nUSE_BALANCE_SERVE=1 bash ./install.sh\n# 启动命令\npython ktransformers/server/main.py --model_path <your model path> --gguf_path <your gguf path> --cpu_infer 62 --optimize_config_path <inject rule path> --port 10002 --chunk_size 256 --max_new_tokens 1024 --max_batch_size 4 --port 10002 --cache_lens 32768 --backend_type balance_serve\n```\n\n`<your model path>` 可以是本地路径，也可以是在线路径，例如 deepseek-ai/DeepSeek-V3。如果在线连接出现问题，可以尝试使用镜像（hf-mirror.com） <br>\n`<your gguf path>` 也可以是在线路径，但由于其体积较大，我们建议您下载并量化模型（注意这是目录路径）\n\n`<inject rule path>` 注入规则 yaml 文件地址，我们在 `ktransformers/optimize/optimize_rules/ ` 目录下提供了 `DeepSeek-V3-Chat-serve.yaml` 和 `DeepSeek-V3-Chat-fp8-linear-ggml-experts-serve.yaml` 分别对应 
[`DeepSeek-V3/R1-q4km`](https://huggingface.co/unsloth/DeepSeek-R1-GGUF/tree/main/DeepSeek-R1-Q4_K_M) 和 [`DeepSeek-V3/R1-hybrid`](https://huggingface.co/KVCache-ai/DeepSeek-R1-GGML-FP8-Hybrid/tree/main)\n\n`--max_new_tokens 1000` 是最大输出 token 长度。如果发现答案被截断，可以增加此数字以获得更长的答案（但要注意内存不足问题，增加此数字会降低生成速度）.\n\n`--chunk_size 256` 引擎单次运行最大 token 个数\n\n`--cache_lens 32768`  调度器申请 kvcache 的总长度。所有请求共享 32768 个 tokens 对应 kvcache 空间，请求完成后会释放其所占用的 kvcache 空间。\n\n`--backend_type balance_serve` `balance_serve`是 v0.2.4新增的后端引擎，原本的单并发引擎为`ktransformers`\n\n`--max_batch_size 4` 引擎单次运行最多处理 4 个请求(prefill + decode),(仅用于`balance_serve`)\n\n<br>命令 numactl -N 1 -m 1 的目的是避免 NUMA 节点之间的数据传输<br>\n注意！如果测试 R1 可能会跳过思考。因此，可以添加参数：`--force_think`，这在 [常见问题解答](#常见问题解答) 部分中解释。\n\n### V0.2 展示\n\n#### 单插槽版本（32 核心）\n\n我们的 local_chat 测试命令是:\n\n```shell\ngit clone https://github.com/kvcache-ai/ktransformers.git\ncd ktransformers\ngit submodule init\ngit submodule update\nnumactl -N 1 -m 1 python ./ktransformers/local_chat.py --model_path <your model path> --gguf_path <your gguf path>  --prompt_file <your prompt txt file>  --cpu_infer 33 --max_new_tokens 1000\n<当您看到聊天时，按回车键加载文本提示文件>\n```\n\n#### 双插槽版本（64 核心）\n\n在安装之前（使用 install.sh 或 `make dev_install`），请确保设置环境变量 `USE_NUMA=1`，方法是 `export USE_NUMA=1`（如果已经安装，请重新安装并设置此环境变量） <br>\n我们的 local_chat 测试命令是：\n\n```shell\ngit clone https://github.com/kvcache-ai/ktransformers.git\ncd ktransformers\ngit submodule init\ngit submodule update\nexport USE_NUMA=1\nmake dev_install # or sh ./install.sh\npython ./ktransformers/local_chat.py --model_path <your model path> --gguf_path <your gguf path>  --prompt_file <your prompt txt file>  --cpu_infer 65 --max_new_tokens 1000\n<当您看到聊天时，按回车键加载文本提示文件>\n```\n\n参数的含义相同。但因为我们使用双插槽，所以将 cpu_infer 设置为 65。\n\n### V0.3 展示\n\n#### 双插槽版本（64 核心）\n\n我们的 local_chat 测试命令是：\n\n```shell\nwget https://github.com/kvcache-ai/ktransformers/releases/download/v0.1.4/ktransformers-0.3.0rc0+cu126torch26fancy-cp311-cp311-linux_x86_64.whl\npip install 
./ktransformers-0.3.0rc0+cu126torch26fancy-cp311-cp311-linux_x86_64.whl\npython -m ktransformers.local_chat --model_path <your model path> --gguf_path <your gguf path>  --prompt_file <your prompt txt file>  --cpu_infer 65 --max_new_tokens 1000\n<当您看到聊天时，按回车键加载文本提示文件>\n```\n\n参数的含义与 V0.2 相同。但因为我们使用双插槽，所以将 cpu_infer 设置为 65。\n\n## 一些解释\n\n1. 我们还想进一步利用 Xeon Gold CPU 上的两个 NUMA 节点。为了避免节点之间的数据传输成本，我们在两个节点上 \"copy\" 了关键矩阵，这会增加内存占用，但会加速预填充和解码过程。但这种方法占用大量内存，加载权重时速度较慢，因此加载时请耐心等待并监控内存使用情况。我们计划优化这一巨大的内存开销。敬请期待。\n2. 命令参数 `--cpu_infer 65` 指定使用多少核心（超过物理核心数量是可以的，但并不是越多越好。根据实际核心数量适当降低此值）。<br>\n3. 为什么使用 CPU/GPU 混合推理？\n   DeepSeek 的 MLA 操作符计算密集。虽然全部在 CPU 上运行是可行的，但将繁重的计算任务卸载到 GPU 上能带来巨大的性能提升。\n4. 加速来自哪里？\n\n   - 专家卸载：与传统的基于层或 KVCache 卸载（如 llama.cpp 中的）不同，我们将专家计算卸载到 CPU，将 MLA/KVCache 卸载到 GPU，与 DeepSeek 的架构完美对齐，实现最佳效率。\n   - 英特尔 AMX 优化 – 我们的 AMX 加速内核经过精心调优，运行速度是现有 llama.cpp 实现的数倍。我们计划在清理后开源此内核，并考虑向 llama.cpp 上游贡献代码。\n5. 为什么选择英特尔 CPU？\n   英特尔目前是唯一支持 AMX 类似指令的 CPU 供应商，与仅支持 AVX 的替代方案相比，性能显著更好。\n\n## 常见问题解答\n\n### R1 不返回思考过程\n\n注意！如果测试 R1 可能会跳过思考。因此，可以添加参数：`--force_think true`。详细信息在 [常见问题解答](./FAQ.md) 部分中。 <br>\n\n## 问题\n\n* 修复服务器集成功能以实现网络API访问支持\n* 修复本地聊天功能仅支持单行提示输入的问题（目前输入换行符(\\n)即开始生成提示）\n\n### 更多常见问题解答\n\n[详见](./FAQ.md)\n"
  },
  {
    "path": "doc/zh/DeepseekR1_V3_tutorial_zh_for_Ascend_NPU.md",
    "content": "# 基准测试结果\n\n在 Batchsize=4、输出长度为 1024 的条件下，性能测试结果如下：\n\n| Prompt length                     | 1K     | 2K     | 4K     |\n| --------------------------------- | ------ | ------ | ------ |\n| KTrans Prefill token/s | 174.68 | 169.52 | 167.15 |\n| KTrans Decode token/s | 16.07 | 16.12 | 16.48 |\n\n## 先决条件\n我们在以下配置下进行了Deepseek-R1最佳性能测试：\n- 服务器型号：Atlas 2UP\n- NPU：Atlas 300I A2\n- CPU: HUAWEI Kunpeng 920 7270Z\n- 内存: DDR5服务器内存（1TB）\n\n# 部署\n\n## 物理机安装\n\n部署满血版Deepseek-R1/V3，需要机器物理内存能够存放下全部路由专家的权重，约400GB。\n\n目前支持的NPU型号：**300I A2**。\n\n在技术人员的支持下完成硬件安装。\n\n## 系统安装\n\n根据网页[昇腾兼容性查询助手](https://www.hiascend.com/hardware/compatibility)查询，选用系统Ubuntu 22.04 for aarch64，内核5.15.0-25-generic，并禁止系统自动更新。系统镜像获取链接：[ubuntu-old-releases](https://mirrors.aliyun.com/oldubuntu-releases/releases/22.04)。\n\n## HDK安装\n\n选择[Ascend HDK 25.3.RC1](https://www.hiascend.com/hardware/firmware-drivers/community?product=4&model=32&cann=8.3.RC1&driver=Ascend+HDK+25.3.RC1)进行安装，安装方式参考[昇腾社区HDK安装指导](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/83RC1alpha003/softwareinst/instg/instg_0005.html?Mode=PmIns&OS=Ubuntu&Software=cannToolKit)。\n\n\n## 镜像部署\n\n建议使用昇腾MindIE镜像[昇腾社区镜像下载](https://www.hiascend.com/developer/ascendhub/detail/af85b724a7e5469ebd7ea13c3439d48f)部署开发环境，选择2.2.RC1-800I-A2-py311-openeuler24.03-lts下载。\n\n下载完成镜像后，执行以下命令启动容器：\n\n```bash\ndocker run -it -d --net=host --shm-size=500g \\\n       --name <container-name> \\\n       -w /workspace \\\n       --device=/dev/davinci_manager \\\n       --device=/dev/hisi_hdc \\\n       --device=/dev/devmm_svm \\\n       -v /usr/local/Ascend/driver:/usr/local/Ascend/driver:ro \\\n       -v /usr/local/dcmi:/usr/local/dcmi:ro \\\n       -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi:ro \\\n       -v /usr/local/sbin/:/usr/local/sbin:ro \\\n       -v <path_to_your_project>:/workspace \\\n       mindie:2.2.RC1-800I-A2-py311-openeuler24.03-lts bash\n```\n\n进入容器\n\n```bash\ndocker exec -it <container-name> 
/bin/bash\n```\n\n部署Python环境：\n\n```bash\nyum install zlib1g-dev libtbb-dev libssl-dev libaio-dev libcurl4-openssl-dev\npip3 install numpy==1.26.4  # 适配torch/torch_npu\npip3 install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cpu\npip3 install packaging ninja fire protobuf attrs decorator cloudpickle ml-dtypes scipy tornado absl-py psutil\npip3 install sqlalchemy\npip3 install transformers==4.57.1 #此处注意运行时transformers版本要求4.57.1(其他版本未验证)\n#pip3 install cpufeature  # only for x86\n```\n\n## CANN安装\n\n选择[CANN 8.3.RC1.alpha003](https://www.hiascend.com/developer/download/community/result?cann=8.3.RC1.alpha003&product=4&model=32)进行安装，安装方式参考[昇腾社区CANN安装指导](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/83RC1alpha003/softwareinst/instg/instg_quick.html?Mode=PmIns&OS=Ubuntu&Software=cannToolKit)。\n\n需要安装ToolKit，Kernel和NNAL。\n\n## torch_npu安装\n\n获取最新的仓库代码：[torch_npu Gitcode](https://gitcode.com/Ascend/pytorch)\n\n由于涉及新增算子，公网pypi内提供的torch_npu暂时无法直接使用，可以下载代码仓库编译，当前适配分支为v2.5.1，编译命令可以参考仓库内文档。\n编译过程需要保证访问github，gitcode等平台网络畅通并设置如下环境变量：\n\n```bash\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh  # 以实际CANN安装路径为准\nsource /usr/local/Ascend/nnal/atb/set_env.sh  # 以实际NNAL安装路径为准\n```\n由于环境对于torch_npu版本号有特定要求，使用编译后的torch_npu包需要手动移除版本信息中的哈希后缀，操作如下：\n使用文本编辑器打开`/usr/local/lib/python3.11/site-packages/torch_npu/version.py`(不同环境python路径可能不同，可以使用`pip show torch_npu`查看安装的python路径)\n将`__version__ = '2.5.1.post4+git69550dfc'`改为`__version__ = '2.5.1.post4'`\n\n\n## 权重准备\n\n目前，为了满足性能和精度的要求，我们需要准备两份权重，并使用提供的权重合并脚本对权重进行合并，最终只会使用合并后的权重。\n\nQ4权重：[DeepSeek-R1-Q4_K_M](https://modelscope.cn/models/unsloth/DeepSeek-R1-GGUF/files)\n\nW8A8权重：[DeepSeek-R1-W8A8](https://modelers.cn/models/State_Cloud/DeepSeek-R1-W8A8)\n\n使用[merge_safetensor_gguf.py](../../merge_tensors/merge_safetensor_gguf.py)来合并Q4和W8A8权重：\n\n```bash\npython merge_safetensor_gguf.py --safetensor_path /mnt/weights/DeepSeek-R1-Q4_K_M --gguf_path 
/mnt/weights/DeepSeek-R1-W8A8 --output_path /mnt/weights/DeepSeek-R1-q4km-w8a8\n```\n\n## 图下沉部署\n\n开启图下沉功能，需要添加如下环境变量：\n\n```bash\nexport TASK_QUEUE_ENABLE=0  # 保证算子下发顺序有序\n```\n\n\n## kTransformers部署\n\n将项目文件部署到机器上：\n\n- 初始化third_party。由于此过程耗时较多，且容易受网络影响导致仓库克隆失败，建议初始化一次后，将相关文件进行打包，以便后续直接解压使用。\n  ```bash\n  git clone https://github.com/kvcache-ai/ktransformers.git\n  cd ktransformers\n  git submodule update --init --recursive\n  ```\n- 对于arm平台，注释掉`./third_party/llamafile/iqk_mul_mat_arm82.cpp`中的\n  ```cpp\n  #define iqk_mul_mat iqk_mul_mat_arm82\n  #define iqk_mul_mat_moe iqk_mul_mat_moe_arm82\n  ```\n- 执行`source /usr/local/Ascend/ascend-toolkit/set_env.sh`（以实际CANN-TOOLKIT安装路径为准）。\n- 执行`apt install cmake libhwloc-dev pkg-config`安装依赖。\n- 修改项目目录下 /ktransformers/config/config.yaml 中attn部分的page_size: 128  chunk_size: 16384\n- 执行`USE_BALANCE_SERVE=1 USE_NUMA=1 bash ./install.sh`，等待安装完成。\n\n此处给出示例balance_serve的启动脚本（由于使用了相对路径，需将该脚本放至项目的根路径下）：\n\n```bash\n#!/bin/bash\nexport USE_MERGE=0\nexport INF_NAN_MODE_FORCE_DISABLE=1\nexport TASK_QUEUE_ENABLE=0\nexport RANK=0\nexport LOCAL_WORLD_SIZE=1\n#export PROF_DECODE=1\n#export PROF_PREFILL=1\n\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\n\npython ktransformers/server/main.py \\\n--port 10002 \\\n--model_path <your model path> \\\n--gguf_path <your model path> \\\n--model_name DeepSeekV3ForCausalLM \\\n--cpu_infer 100 \\\n--optimize_config_path  ./ktransformers/optimize/optimize_rules/npu/DeepSeek-V3-Chat-300IA2-npu-serve.yaml \\\n--max_new_tokens 1024 \\\n--cache_lens 20480 \\\n--max_batch_size 4 \\\n--use_cuda_graph \\\n--tp 1 \\\n--backend_type balance_serve\n```\n\n相关参数说明：\n\n- `--model_path`：kTransformers原生参数，str，此处用来指定合并后的模型文件路径\n- `--gguf_path`：kTransformers原生参数，str，此处用来指定合并后的模型文件路径\n- `--cpu_infer`：kTransformers原生参数，int，用来控制CPU侧实际worker线程数，非必选\n- `--optimize_config_path`：kTransformers原生参数，str，用来指定所用的模型优化配置文件，需要注意相对路径的使用，此处为**必选**\n- `--cache_lens`：调度器申请 kvcache 
的总长度。所有请求共享指定数量（例如 `20480`）的 tokens 对应的 kvcache 空间，请求完成后会释放其所占用的 kvcache 空间，非必选\n- `--use_cuda_graph`：kTransformers原生参数，bool，为True表示开启图下沉，为False表示关闭图下沉，非必选\n- `--max_new_tokens`：kTransformers原生参数，int，当统计到输出的tokens数量达到该值时，会直接中止输出，非必选\n- `--tp`：新增参数，int，用于开启tensor model parallel功能，目前local_chat只支持tp大小与ws大小相同（不支持local_chat使用多dp），非必选\n\n\n# 其他问题\n\n## 可能存在的其他依赖问题\n\nImportError: libhccl.so: cannot open shared object file: No such file or directory\n\n```bash\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh  # 以实际CANN安装路径为准\n```\n\nImportError: libascend_hal.so: cannot open shared object file: No such file or directory\n\n```bash\nexport LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64/driver:$LD_LIBRARY_PATH  # 以实际Driver安装路径为准\n```\n"
  },
  {
    "path": "doc/zh/KTransformers-Fine-Tuning_Developer-Technical-Notes_zh.md",
    "content": "- [KTransformers 微调 × LLaMA-Factory 集成 – 开发技术篇](#ktransformers-微调-x-llama-factory-集成-–-开发技术篇)\n- [Introduction](#introduction)\n\n- [KT微调框架整体性描述](#kt微调框架整体性描述)\n  - [Attention 部分（LoRA + KT 特性并存）](#attention-部分lora--kt-特性并存)\n    - [继承关系](#继承关系)\n    - [替换策略](#替换策略)\n  - [MoE 部分（算子封装+backward实现）](#moe-部分算子封装backward实现)\n    - [MoE算子封装](#moe算子封装)\n    - [MoE 反向优化 (CPU 实现)](#moe-反向优化-cpu-实现)\n  - [多卡加载与训练：用“放置策略”而不是 DataParallel](#多卡加载与训练用放置策略而不是-dataparallel)\n\n- [KT-LoRA微调测试](#kt-lora微调测试)\n  - [实验设置](#实验设置)\n  - [效果测试](#效果测试)\n    - [风格化对话测试（CatGirl语气）](#风格化对话测试catgirl语气)\n    - [生成式翻译风格基准测试](#生成式翻译风格基准测试)\n    - [医疗垂直领域基准（AfriMed-SAQ/MCQ）](#医疗垂直领域基准afrimed-saqmcq)\n    - [局限性说明](#局限性说明)\n\n- [速度测试](#速度测试)\n  - [端到端性能](#端到端性能)\n  - [MoE部分的计算性能（DeepSeek-V3-671B）](#moe部分的计算性能deepseek-v3-671b)\n\n- [显存/内存性能](#显存内存性能)\n\n- [结论](#结论)\n\n# KTransformers 微调 × LLaMA-Factory 集成 – 开发技术篇\n\n**MadSys实验室, KVCache-AI团队, 趋境科技, LLaMA-Factory团队**\n\n## Introduction\n\n当今的开源大模型（从 DeepSeek-V3/R1 到 Qwen-MoE 系列以及 Kimi-K2 等）在性能和规模上突飞猛进。然而，受限于**计算资源和显存**，普通研究者难以对这些上千亿乃至更大规模的模型进行微调。为此，我们设计了 **KTransformers** 与 **LLaMA-Factory** 集成的方案，使得仅需 **2～4 张 RTX 4090 GPU** 加上足够的 CPU 内存，就能微调 DeepSeek-671B 这样的超大规模 Mixture-of-Experts (MoE) 模型。\n\n这一架构旨在桥接资源鸿沟，让更多人能够**在本地探索超大模型微调**的可能；同时在相对小一些的模型（如 14B/30B 参数量级）上，也能提供**更高效的场景化定制**途径。我们通过风格化对话、西式翻译语气、医学问答等任务验证了该方案，仅用数小时即可实现模型风格和专业领域的**快速适配**。\n\n从系统架构上看，如下图所示，**LLaMA-Factory** 扮演微调流程的调度中枢，负责统一配置数据和训练流程、插入 LoRA 模块以及管理推理接口；**KTransformers** 则作为可插拔的高性能算子后端，在相同的训练代码下接管底层 **Attention** 和 **MoE** 运算，实现 **GPU+CPU 异构协同**，加速训练并降低显存占用。\n\n![image-20251011010558909](../assets/image-20251011010558909.png)\n\n为评估该集成的性能优势，我们使用 LLaMA-Factory 分别调用了 HuggingFace 默认后端、Unsloth 后端以及 KTransformers 后端进行 LoRA 微调的对比测试（在相同设置和数据集下）。结果表明，**KTransformers** 是目前唯一能在 2～4 张 24GB 4090卡上微调 **671B 规模 MoE 模型** 的方案；同时在 14B 规模的 MoE 模型上，相比另两种方案也具有**更高的吞吐速率**和**更低的 GPU 显存占用**。\n\n| Under LoRA (BF16)+[NekoQA-10K-风格化对话数据集](https://github.com/mindsRiverPonder/LLM-practice) | 
HuggingFace Backend                      | Unsloth Backend                      | KTransformers Backend |\n| ------------------------------------------------------------ | ---------------------------------------- | ------------------------------------ | --------------------- |\n| [14B-DeepSeekV2-Lite] LoRA Fine-tuning throughput            | 303.58 token/s                           | 455.37 token/s                       | 530.38 token/s        |\n| [14B-DeepSeekV2-Lite] GPU Memory                             | 32.12 GB                                 | 9.64 GB                              | 6.08 GB               |\n| [671B-DeepSeekV3] LoRA Fine-tuning throughput                | <font color='red'>Too Huge to run</font> | <font color='red'>NOT SUPPORT</font> | 40.35 token/s         |\n| [671B-DeepSeekV3] GPU Memory（多卡总和）                     | 理论值1400 GB †                          | <font color='red'>NOT SUPPORT</font> | 70 GB †               |\n\n† **1400 GB** 为**理论显存**（FP16 全参数常驻，非可运行配置）；**70 GB** 为 KT 策略（Attention 驻 GPU + MoE分层 offload）下的**实测峰值**。\n\n上表中可以看出，对于 14B 模型，KTransformers 后端的吞吐量相比 HuggingFace 默认方案提升了约 75%，而显存占用仅为其约 1/5。对于 671B 模型，HuggingFace 和 Unsloth 在单台4090环境下无法运行，而 KTransformers 能以 **40 tokens/s** 的速度LoRA微调，并将 GPU 显存需求控制在 70 GB。\n\n![按照模型划分的对比图_02](../assets/image-compare_model.png)\n\n\n\n## KT微调框架整体性描述\n\n下面详细展示的是在 LLaMA-Factory 的微调框架中，KTransformers 后端如何接管底层算子并实现 Attention / MoE 的优化结构。\n\nDeepSeek-V3/V2等MoE模型主要包括小参数、密集矩阵的Attention部分和大参数、稀疏矩阵的MoE部分。为了直观说明，我们以 DeepSeek-V2-Lite-Chat 的第 2 层为例（从该层起，每层包含 Attention 与 MoE 两个子模块），其中Attention由GPU承担主要计算与缓存（KV），剩下的大参数量MoE主要由CPU承担 。下文将先介绍 **Attention 部分的替换与继承关系**，再介绍 **MoE 部分的封装与后端对接**，最后说明**多卡放置等特性支持**。\n\n### Attention 部分（LoRA + KT 特性并存）\n\nKTransformers 提供了算子模块的注入机制（`BaseInjectedModule`），而 PEFT 库提供了 LoRA 微调的层插入机制。为了在**微调阶段**同时兼容两者，我们设计了 `KTransformersLinearLora` 类，使其同时继承自 KTransformers 的线性层 (`KTransformersLinear`) 和 LoRA 的层基类 (`LoraLayer`)。如下图所示：\n\n- **继承关系**：如下图所示，`KTransformersLinearLora` 同时继承 
`KTransformersLinear` 与 `LoraLayer`，既保留 **KT 的高性能算子**（如 `prefill_linear` / `generate_linear`），又能**加载 LoRA参数**（如 `lora_A`、`lora_B` 等矩阵）；\n\n- **替换策略**：在微调准备阶段，用 `KTransformersLinearLora` **逐一替换** 原 `KTransformersLinear`层（如下图右侧所示，主要包含Q/K/V/O 等线性层），从而在不破坏 KT 优化的前提下，将 LoRA 注入到了模型中，使其参数可训练。\n\n![image-20250911184023795](../assets/image-20250911184023795.png)\n\n替换完成后，如下图（左）所示，在计算图中相当于在原模型的 Q/K/V/O 四个矩阵乘法位置都插入了 LoRA。下图（右）展示了 `KTransformersLinearLora` 的内部，它同时包含了 KT 模块的高性能计算接口（prefill 和 generate 阶段的方法）以及 LoRA 的 A、B 矩阵等参数。\n\n![image-20250801174517784](../assets/image-20250801174517784.png)\n\n### MoE 部分（算子封装+backward实现）\n\n#### MoE算子封装\n\n考虑到 MoE 参数量大且计算稀疏，我们采用“封装成黑盒算子”的策略处理：将 MoE 专家计算封装为一个**对上游而言透明（单节点）、对下游可替换（多实现）**的可微算子。\n\n- **上游（PyTorch 计算图）**：我们注册自定义 Autograd Function，整个 MoE 专家层在计算图中呈现为**一个节点**。如下左图红框所示，封装后计算图中只有 `KSFTExpertsCPU` 这样一个算子节点；而右图红框为未封装时的细粒度计算图——路由、专家选择以及 FFN 计算都完整展开在计算图中。封装后，对微调过程来说，MoE层就等同于一个普通 `nn.Module`，前向计算可求梯度，反向梯度也由我们来自定义算子返回。\n- **下游（后端实现）**：在这个 Autograd Function 内部，我们通过 pybind11 调用了 C++ 扩展实现具体的前向和反向计算。这里我们提供了多个**可插拔后端实现**，如 AMX 指令集版本（支持 BF16/INT8 算子优化）和 llamafile 版本。只要遵循同样的接口，即可灵活切换后端。例如在 YAML 优化规则里指定使用 `\"backend\": \"AMXBF16\"`，就会调用 AMX 后端；改成 `\"llamafile\"` 则使用默认后端。\n\n![image-20250801174623919](../assets/image-20250801174623919.png)\n\n#### MoE 反向优化 (CPU 实现)\n\n在实现 MoE 自定义算子的反向传播时，我们特别优化了大矩阵的梯度计算开销。MoE反向计算需要频繁访问权重转置`Wᵀ`，为避免运行时反复转置带来的开销，我们在加载参数时**预备一份权重转置`Wᵀ` 便于复用**（如下图蓝框）。同时，**缓存必要的中间激活**（例如专家层中间投影结果，见下图红框），以便在反向阶段复用，减少重复计算。基于这些缓存，当前已提供 llamafile 与 AMX（INT8/BF16） 的MoE反向计算实现，并针对 NUMA 架构优化内存访问。\n\n<img src=\"../assets/image-20250911184455749.png\" alt=\"image-20250911184455749\" style=\"zoom: 33%;\" />\n\n### 多卡加载与训练：用“放置策略”而不是 DataParallel\n\n为了在使用 2～4 张 GPU 时进一步降低**单卡显存压力**，KTransformers 结合模型并行技术实现了**多卡协同微调**。与常规的 DataParallel 不同，我们没有简单地将整层模型复制到每张卡（那样显存需求会翻倍），而是采用**模型并行 + 显式算子放置**的策略，让不同 GPU 各自承载模型的一部分层。\n\n具体而言，我们对 Transformers Trainer 做了以下改动：\n\n1. 
**自定义训练器 (KTrainer)**：接管模型加载到设备的逻辑，采用显式层放置。默认情况下 `transformers` 会在初始化时将模型 `.to(device)` 全部搬移到单块 GPU，我们通过自定义 KTrainer 阻止这一行为。利用 KTransformers 的优化规则 YAML，我们可以在每一层声明 `device: cuda:0/cuda:1/...` 来指定该层所在的设备。这样初始化模型时，各层就直接构建在目标 GPU 上，不需要额外拷贝。\n\n2. **禁用自动 DataParallel**：当设置环境变量 `USE_KT=1` 时，我们暂时禁用了 LLaMA-Factory 和 HuggingFace Trainer 原本自动启动的多卡 DataParallel 封装。这避免了框架层面对模型的重复拷贝，使我们能够完全掌控模型的分片方案。\n\n3. **梯度回传与汇总**：由于模型各部分分散在不同 GPU 上，我们采取梯度汇总到 `cuda:0` 的方式。具体做法是：在反向传播时，仅将所需的梯度张量在设备间传输，而不传输整个模型的中间激活；各 GPU 计算各自部分的梯度，最终在0号卡汇总计算 loss。这种方式减少了不必要的通讯开销和激活冗余。\n\n通过上述手段，我们实现了**多 GPU 下依然遵循 KTransformers 放置策略**的训练方案。用户只需选择合适的 `kt_optimize_rule` 配置文件（例如带有 `multi-gpu` 的 YAML），即可启用默认的模型分片方案。在 DeepSeek-671B 微调中，我们提供的 `DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml` 就是一个两卡模型并行的典型策略：Attention 模块的 KV缓存和部分计算放在每张卡上，MoE 专家层在 CPU 上分片处理，两张卡共同承担全模型的计算。\n\n\n\n## KT-LoRA微调测试\n\n### 实验设置\n\n实验均采用 LLaMA-Factory 调度、KTransformers 后端、LoRA 轻量微调范式（超参数：rank = 8、α = 32、dropout = 0.1，BF16，`gradient_accumulation_steps=16`、`qlen=512`）以及与微调阶段一致的 KT 优化规则。我们分别评测了（a）风格化对话的迁移效果，以及（b）两类具有代表性的**定量基准**：西式翻译腔（生成式）与 AfriMed-QA（医疗垂直领域，含**简答生成**与**单项选择**两种子任务）。固定使用AMX指令集优化；GPU选取2张 48G VRAM 的 RTX 4090，CPU选取 Intel Xeon Platinum 8488C。\n\n### 效果测试\n\n#### 风格化对话测试（CatGirl语气）\n\n数据集采用[NekoQA-10K](https://zhuanlan.zhihu.com/p/1934983798233231689)进行风格迁移微调，目标是提升语气一致性与可辨识度。\n\n下图展示了原模型与微调后模型的对比。微调后回答在称谓、语气标记与修饰语上更稳定地保持了目标风格（红框），相较原模型的中性与理性表达（蓝框）具有更强的风格可辨识性，说明KT-LoRA 能以较低 GPU 成本，将特定风格特征有效注入到大模型生成分布。\n\n![风格化数据集模型输出对比_01](../assets/风格化数据集模型输出对比_01.png)\n\n#### 生成式翻译风格基准测试\n\n数据集采用了[西式翻译腔数据集](https://github.com/Benson114/Translational-Style-ChatLLM)，要求模型采用夸张的“西式翻译腔”，属生成式风格控制任务，评价指标采用生成任务常见的 BLEU-1/2/3/4 与 ROUGE-1/2/L。\n\n| 西式翻译腔数据集                | BLEU-1    | BLEU-2    | BLEU-3    | BLEU-4    | ROUGE-1   | ROUGE-2   | ROUGE-L   |\n| ------------------------------- | --------- | --------- | --------- | --------- | --------- | --------- | --------- |\n| V2-Lite原模型（不LoRA微调）     | 20.66     | 8.33      | 4.54      | 2.89      | 22.71     | 4.52     
 | 19.19     |\n| **KT-LoRA微调DeepSeek-V2-Lite** | **35.41** | **22.44** | **15.42** | **11.18** | **42.03** | **18.38** | **33.10** |\n| V3原模型（不LoRA微调）          | 8.49      | 3.34      | 1.62      | 0.96      | 15.91     | 2.55      | 10.07     |\n| **KT-LoRA微调DeepSeek-V3**      | **37.02** | **23.70** | **16.21** | **11.49** | **43.43** | **18.96** | **34.54** |\n\n如上表测试结果所示，在统一流程与放置策略下，**两种规模的模型在微调后均出现一致性增益**，支持“KT 后端 + LoRA 微调”组合在生成式风格控制上的可用性与有效性。同时，说明 KT 的异构放置与算子优化能够稳定支撑风格域的小样本适配。\n\n#### 医疗垂直领域基准（AfriMed-SAQ/MCQ）\n\n数据集采用了[AfriMed-QA](https://aclanthology.org/2025.acl-long.96/)数据集（ACL-2025），作为非洲地区医疗领域的专用数据集，具有很强的场景定制特征，包含单选题（MCQ）和简答题（SAQ）两种形式，在本案例中作为垂直领域微调的评估。评估标准上，SAQ 用 BLEU/ROUGE；MCQ 用 Accuracy。\n\n| AfriMed-QA数据集（简答任务SAQ） | BLEU-1    | BLEU-2    | BLEU-3    | BLEU-4    | ROUGE-1   | ROUGE-2   | ROUGE-L   |\n| ------------------------------- | --------- | --------- | --------- | --------- | --------- | --------- | --------- |\n| V2-Lite原模型（不LoRA微调）     | 13.58     | 11.12     | 9.10      | 7.23      | 22.48     | 7.81      | 11.73     |\n| **KT-LoRA微调DeepSeek-V2-Lite** | **35.90** | **27.63** | **22.99** | **19.15** | **35.25** | **17.50** | **28.44** |\n| V3原模型（不LoRA微调）          | 12.75     | 10.27     | 8.05      | 5.99      | 20.33     | 5.65      | 10.11     |\n| **KT-LoRA微调DeepSeek-V3**      | **42.42** | **34.12** | **28.95** | **24.54** | **41.97** | **22.37** | **33.28** |\n\n| AfriMed-QA数据集（单选任务MCQ） | Accuracy   |\n| ------------------------------- | ---------- |\n| V2-Lite原模型（不LoRA微调）     | 0.0645     |\n| **KT-LoRA微调DeepSeek-V2-Lite** | **0.4812** |\n| V3原模型（不LoRA微调）          | 0.5833     |\n| **KT-LoRA微调DeepSeek-V3**      | **0.7930** |\n\n如上表所示，（1）DeepSeek-V3（671B）经 KT-LoRA 微调后在MCQ和SAQ任务上均明显高于微调后的 DeepSeek-V2-Lite（14B），并且超过 V3 原模型。在我们的小规模设置中，初步说明了KT-LoRA微调巨大参数模型，在垂直领域中具有实际意义。\n\n（2）在 SAQ/MCQ 两类子任务上，KT-LoRA 均带来一致增益，说明在 KT 的异构放置与后端算子支持下，LoRA 微调能够把“医疗等垂直领域的知识要点”有效注入模型。\n\n#### 
局限性说明\n\n目前我们的测试多基于单数据集、小规模（2 万条及以下）数据，旨在提供**KT-LoRA微调系统有效性的“存在性证据”**，而非对算法泛化或规模规律的概括性结论。我们报告中主要给出的是代表性数值；若要支持更强的算法结论，需要更大样本、跨语种/跨域多数据集与多随机种子重复实验，本文不作展开。\n\n**我们也特别欢迎大家加入LLaMA-Factory KT微调的开源项目中，如果大家有更多的测试结果，也特别欢迎写在下面的共享表格中，并补充好`kt_optimize_rule` 文件、数据集example、训练/评测 YAML、具体显存与 CPU 配置等，以便大家参考、复现~！**\n\n\n\n### 速度测试\n\n#### 端到端性能\n\n**测试定义：**\n\n`step_time`：一次优化步的总耗时（含张量搬运、Attention、MoE 等全部计算）。\n\n`tokens_per_step = GAS × qlen`；`token/s = tokens_per_step / step_time`。 本节统一采用 `GAS=16`、`qlen=512`，因此 `tokens_per_step = 8192`。\n\n**实测结果：**\n\n| 模型                 | step_time (s) | tokens/step | token/s   |\n| -------------------- | ------------- | ----------- | --------- |\n| DeepSeek-V3-671B     | 203           | 8192        | **40.35** |\n| DeepSeek-V2-Lite-14B | 36            | 8192        | **227.6** |\n\n#### MoE部分的计算性能（DeepSeek-V3-671B）\n\n**理论估算**\n\n- MoE 每层、每token的前/反向浮点计算总量 (FLOPs) 可近似：\n  $$\n  \\text{FLOPs}_{\\text{per-layer, per-token}} \\approx c \\cdot k \\cdot H \\cdot I\n  $$\n\n  其中：$k = 8$（Top-k 专家数），$H = 7168$（hidden size），$I = 2048$（intermediate size），常数 $c\\approx16$（折合前向=6、反向=10 的矩阵乘总系数）。\n\n- 每步（全 MoE 层）FLOPs 近似：\n  $$\n  \\text{FLOPs}_{\\text{per-step}} \\approx c \\cdot qlen \\cdot k \\cdot H \\cdot I \\cdot L_{\\text{MoE}}\n  $$\n\n  代入 $c=16, qlen=512, k=8, H=7168, I=2048, L_{\\text{MoE}}=58$，得 $\\text{FLOPs}_{\\text{per-step}} \\approx 55.8\\ \\text{TFLOPs}$。\n\n**实测情况**\n\nMoE 部分在 CPU 上的性能情况：每秒浮点计算量 $\\text{TFLOPS} = \\text{FLOPs}_{\\text{per-step}} / \\text{seconds\\_per\\_step}$，其中 $\\text{seconds\\_per\\_step}$ 为 MoE 部分每步的耗时（秒）。\n\n| TFLOPS                 | Forward | Backward |\n| ---------------------- | ------- | -------- |\n| 平均值（单位：TFLOPS） | 17.55   | 18.41    |\n\n### 显存/内存性能\n\nDeepSeek-V3（671B，61层，其中58层有MoE）占用显存大约70GB（多卡总量）、内存占用约1.2-1.3TB。\n\nDeepSeek-V2-lite（14B，27层，其中26层有MoE）占用显存大约5GB、内存占用约30GB。\n\n\n\n## 结论\n\n通过将 KTransformers LoRA 微调集成到 LLaMA‑Factory，我们为希望高效训练和部署 MoE 大模型的用户提供了一条可行路径。KT 提供新的放置策略和算子优化（支持 DeepSeek、Qwen、Kimi 等模型，并结合 AMX 指令加速关键内核），配合 LoRA 微调实现了在极低 GPU 显存占用下的模型定制化训练；而 
LLaMA‑Factory 则提供了友好的上层接口与配置管理，让这一切变得易于使用。\n\n这种集成意味着即便是拥有数百亿乃至上万亿参数的 MoE 模型，也能够在相对普通的硬件上完成微调，并进行低延迟的推理部署。**显存节省**、**速度提升**和**易用性**在这套方案中达到了一定的平衡。我们期待社区在未来的 MoE 项目中尝试使用 LLaMA‑Factory 与 KTransformers 的组合，并欢迎参考本文档提供的指南进行操作。通过这一方案，超大模型不再是“无法企及”的存在，而成为每个开发者都可能驾驭的工具。"
  },
  {
    "path": "doc/zh/KTransformers-Fine-Tuning_User-Guide_zh.md",
    "content": "- [KTransformers 微调 × LLaMA-Factory 集成 – 用户指南](#ktransformers-微调-x-llama-factory-集成-–-用户指南)\n- [Introduction](#introduction)\n\n- [Quick to Start](#quick-to-start)\n  - [快速上手](#快速上手)\n  - [环境安装](#环境安装)\n  - [核心功能1：使用KTransformers作为backend，微调超大规模MoE模型](#核心功能1使用ktransformers作为backend微调超大规模moe模型)\n  - [核心功能2：与微调后模型（即原模型+LoRA Adapter）聊天，用于交互](#核心功能2与微调后模型即原模型lora-adapter聊天用于交互)\n  - [核心功能3：生成微调后模型（即原模型+LoRA Adapter）的API，用于批量生成并评测指标](#核心功能3生成微调后模型即原模型lora-adapter的api用于批量生成并评测指标)\n\n- [KT微调速度性能测试：用户侧](#kt微调速度性能测试用户侧)\n  - [端到端性能](#端到端性能)\n  - [显存/内存性能](#显存内存性能)\n\n- [结论](#结论)\n\n# KTransformers 微调 × LLaMA-Factory 集成 – 用户指南\n\n**MadSys实验室, KVCache-AI团队, 趋境科技, LLaMA-Factory团队**\n\n## Introduction\n\n从 **DeepSeek-V3/R1** 到 **Qwen3-MoE、Kimi-K2**，每一次超大模型的开源都带来性能与规模上的巨大跃升。然而，多数研究者与开发者受限于昂贵的显卡与动辄数千亿参数的模型，**难以在资源受限条件下微调超大模型**。面对这种差距，我们提出了一种更具可行性的方案：通过 **KTransformers 与 LLaMA-Factory 的结合**，仅需2~4张RTX 4090与较高内存CPU，便可微调DeepSeek-671B等超大规模的MoE模型。\n\n该架构的核心目标是为资源受限下的研究者提供 **在本地探索超大规模模型微调的可能性**。同时，也在较小规模（如 14B/30B）提供快速定制特定场景的路径。我们以**风格化对话、西式腔调翻译、医学问答**作为代表任务，验证架构的可行性，并展示在**数小时内达成个性化适配**的可操作性。\n\n\n\n如下图所示，LLaMA-Factory 是整个微调流程的统一调度与配置框架，负责数据处理、训练调度、LoRA 插入与推理接口管理； KTransformers 则作为其可插拔的高性能后端，在相同的训练配置下接管 Attention / MoE 等核心算子，实现异构设备（GPU+CPU）的高效协同。\n\n![image-20251011010558909](../assets/image-20251011010558909.png)\n\n我们在 LLaMA-Factory 框架下，对比评测了 **HuggingFace**、**Unsloth**、**KTransformers** 三种后端的 LoRA 微调方案。结果显示，KTransformers为超大规模的MoE模型（671B等）提供了**4090 级别**的唯一可行方案，并在较小规模的MoE模型（DeepSeek-14B）上面也展现了更高的吞吐和更低的显存占用。\n\n| Under LoRA (BF16)+[NekoQA-10K-风格化对话数据集](https://github.com/mindsRiverPonder/LLM-practice) | HuggingFace Backend                      | Unsloth Backend                      | KTransformers Backend |\n| ------------------------------------------------------------ | ---------------------------------------- | ------------------------------------ | --------------------- |\n| [14B-DeepSeekV2-Lite] LoRA Fine-tuning throughput            | 303.58 token/s                 
          | 455.37 token/s                       | 530.38 token/s        |\n| [14B-DeepSeekV2-Lite] GPU Memory                             | 32.12 GB                                 | 9.64 GB                              | 6.08 GB               |\n| [671B-DeepSeekV3] LoRA Fine-tuning throughput                | <font color='red'>Too Huge to run</font> | <font color='red'>NOT SUPPORT</font> | 40.35 token/s         |\n| [671B-DeepSeekV3] GPU Memory（多卡总和）                     | 理论值1400 GB †                          | <font color='red'>NOT SUPPORT</font> | 70 GB †               |\n\n† **1400 GB** 为**理论显存**（FP16 全参数常驻，非可运行配置）；**70 GB** 为 KT 策略（Attention 驻 GPU + MoE分层 offload）下的**实测峰值**。\n\n![按照模型划分的对比图_02](../assets/image-compare_model.png)\n\n### 微调效果示例\n\n#### 风格化对话测试（CatGirl风格语气）\n\n数据集：[NekoQA-10K: 面向猫娘语言建模的对话数据集](https://zhuanlan.zhihu.com/p/1934983798233231689)，目标是提升风格一致性与可辨识度。\n\n下图对比了原始模型和微调模型的回答，可以看到微调后模型在语气和称谓上更加稳定地保持了猫娘风格（红框部分），验证了**风格迁移微调**的有效性。\n\n![风格化数据集模型输出对比_01](../assets/风格化数据集模型输出对比_01.png)\n\n#### Benchmark测试\n\n数据集选取：\n\n（1）采用了[西式翻译腔数据集](https://github.com/Benson114/Translational-Style-ChatLLM)，该数据集要求模型按西式表达习惯进行夸张的翻译，有明确的定制化风格需求。\n\n（2）采用了[AfriMed-QA](https://aclanthology.org/2025.acl-long.96/)数据集（ACL-2025），作为非洲地区医疗领域的专用数据集，具有很强的场景定制特征，包含选择题和简答题两种形式，非常适合作为垂直领域微调的评估。针对单选和简答形式，我们分别进行测试，结果如下。\n\n下表显示了微调前后模型在这些数据集上的指标变化。可以看到经过 LoRA 微调后，各项指标**大幅提升**，验证了微调的有效性：\n\n| 西式翻译腔数据集                | BLEU-1    | BLEU-2    | BLEU-3    | BLEU-4    | ROUGE-1   | ROUGE-2   | ROUGE-L   |\n| ------------------------------- | --------- | --------- | --------- | --------- | --------- | --------- | --------- |\n| V2-Lite原模型（不LoRA微调）     | 20.66     | 8.33      | 4.54      | 2.89      | 22.71     | 4.52      | 19.19     |\n| **KT-LoRA微调DeepSeek-V2-Lite** | **35.41** | **22.44** | **15.42** | **11.18** | **42.03** | **18.38** | **33.10** |\n| V3原模型（不LoRA微调）          | 8.49      | 3.34      | 1.62      | 0.96      | 15.91     | 2.55      | 10.07     |\n| 
**KT-LoRA微调DeepSeek-V3**      | **37.02** | **23.70** | **16.21** | **11.49** | **43.43** | **18.96** | **34.54** |\n\n| AfriMed-QA数据集（简答任务）    | BLEU-1    | BLEU-2    | BLEU-3    | BLEU-4    | ROUGE-1   | ROUGE-2   | ROUGE-L   |\n| ------------------------------- | --------- | --------- | --------- | --------- | --------- | --------- | --------- |\n| V2-Lite原模型（不LoRA微调）     | 13.58     | 11.12     | 9.10      | 7.23      | 22.48     | 7.81      | 11.73     |\n| **KT-LoRA微调DeepSeek-V2-Lite** | **35.90** | **27.63** | **22.99** | **19.15** | **35.25** | **17.50** | **28.44** |\n| V3原模型（不LoRA微调）          | 12.75     | 10.27     | 8.05      | 5.99      | 20.33     | 5.65      | 10.11     |\n| **KT-LoRA微调DeepSeek-V3**      | **42.42** | **34.12** | **28.95** | **24.54** | **41.97** | **22.37** | **33.28** |\n\n| AfriMed-QA数据集（单选任务）    | Accuracy   |\n| ------------------------------- | ---------- |\n| V2-Lite原模型（不LoRA微调）     | 0.0645     |\n| **KT-LoRA微调DeepSeek-V2-Lite** | **0.4812** |\n| V3原模型（不LoRA微调）          | 0.5833     |\n| **KT-LoRA微调DeepSeek-V3**      | **0.7930** |\n\n从以上测试可以看出，即使是参数量巨大的 MoE 模型，通过 KTransformers 后端的高效微调，**也能在特定任务上快速达到理想效果**。\n\n\n\n## Quick to Start\n\n### 快速上手\n\n本节将指导您如何安装环境并使用 **LLaMA-Factory + KTransformers** 完成微调和推理。我们将涵盖以下内容：\n\n- 环境依赖的安装配置\n- 使用 KTransformers 作为后端微调超大规模 MoE 模型\n- 加载微调后的模型（原模型 + LoRA 适配器）进行对话/推理\n- 批量推理微调模型并评测指标\n\n### 环境安装\n\n根据下面示例，同时安装KTransformers和LLaMA-Factory环境，这次为了简化KTransformers的安装流程，我们特意封装了wheel包避免本地编译，具体安装步骤如下：（注意对应好本地的python版本、torch版本、cuda版本和不同文件名的KTransformers包）\n\n```shell\n# 1. 安装conda环境\nconda create -n Kllama python=3.12 # choose from : [3.11, 3.12, 3.13]\nconda install -y -c conda-forge libstdcxx-ng gcc_impl_linux-64\nconda install -y -c nvidia/label/cuda-11.8.0 cuda-runtime\n\n# 2. 安装llamafactory环境\ngit clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git\ncd LLaMA-Factory\npip install -e \".[torch,metrics]\" --no-build-isolation\n\n# 3. 
安装对应torch和python版本的KTransformers（CUDA版本可以跟whl命名的不一致），从https://github.com/kvcache-ai/ktransformers/releases/tag/v0.4.1\npip install ktransformers-0.4.1+cu128torch27fancy-cp312-cp312-linux_x86_64.whl\n\n# 4. 安装flash-attention，参照python版本和torch版本，从https://github.com/Dao-AILab/flash-attention/releases下载\npip install flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp312-cp312-linux_x86_64.whl\n# abi=True/False可以用下面代码查看\n# import torch\n# print(torch._C._GLIBCXX_USE_CXX11_ABI)\n\n# 5. （可选）如果你想使用flash_infer的话（不然默认triton）\ngit clone https://github.com/kvcache-ai/custom_flashinfer.git\npip install custom_flashinfer/\n```\n\n\n\n**使用要点**：在 LLaMA-Factory 的配置 YAML 文件中启用 KTransformers 后端，只需设置 `use_kt: true`，并指定相应的 `kt_optimize_rule` YAML 文件，即可切换到底层由 KTransformers 接管计算。下面我们将通过具体功能来说明如何设置这些配置。\n\n### 核心功能1：使用KTransformers作为backend，微调超大规模MoE模型\n\n运行命令：`USE_KT=1 llamafactory-cli train examples/train_lora/deepseek3_lora_sft_kt.yaml`。\n\n需要注意的是，必须提供BF16格式模型文件，DeepSeek-V3-671B默认下载是FP8格式，需要通过 [DeepSeek-V3/inference/fp8_cast_bf16.py](https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py) 转换。\n\n```yaml\n### model\nmodel_name_or_path: opensourcerelease/DeepSeek-V3-bf16\ntrust_remote_code: true\n\n### method\nstage: sft\ndo_train: true\nfinetuning_type: lora\nlora_rank: 8\nlora_target: all\n\n### dataset\ndataset: identity\ntemplate: deepseek\ncutoff_len: 2048\nmax_samples: 100000\noverwrite_cache: true\npreprocessing_num_workers: 16\ndataloader_num_workers: 4\n\n### output\noutput_dir: saves/Kllama_deepseekV3\nlogging_steps: 10\nsave_steps: 500\nplot_loss: true\noverwrite_output_dir: true\nsave_only_model: false\nreport_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]\n\n### train\nper_device_train_batch_size: 1\ngradient_accumulation_steps: 8\nlearning_rate: 1.0e-4\nnum_train_epochs: 3.0\nlr_scheduler_type: cosine\nwarmup_ratio: 0.1\nbf16: true\nddp_timeout: 180000000\nresume_from_checkpoint: null\n\n### ktransformers\nuse_kt: true # use KTransformers 
as LoRA sft backend\nkt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml\ncpu_infer: 32\nchunk_size: 8192\n```\n\n其中，`kt_optimize_rule`提供了大量默认的YAML文件来控制**KTransformers的放置策略**，下面针对YAML文件名和功能对照特别说明，也可以参考[ktransformers/optimize_rules](https://github.com/kvcache-ai/ktransformers/tree/main/ktransformers/optimize/optimize_rules)：（\\*指通配符）\n\n| 文件名字段                                    | 功能特征                                           |\n| --------------------------------------------- | -------------------------------------------------- |\n| DeepSeek-V2-Lite-Chat-\\*或DeepSeek-V3-Chat-\\* | 对应的不同模型                                     |\n| \\*-sft-\\*                                     | 微调所用的放置策略，其他为推理所用                 |\n| \\*-amx-\\*                                     | 使用AMX指令集进行CPU运算，其他为llamafile          |\n| \\*-multi-gpu-X\\*                              | 使用X张GPU进行模型并行（显存共担），X为空默认是2张 |\n\n例如：`examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml`为DeepSeek-V3-Chat模型用AMX指令集进行微调，并调用两卡模型并行。\n\n对于微调任务，我们推荐使用**AMX指令集加速**，可以使用`lscpu | grep amx`查看CPU是否支持AMX指令集，AMX精度支持BF16/Int8，修改方式如下：\n\n```yaml\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert parallelism\n    kwargs:\n      prefill_device: \"cpu\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KSFTExpertsCPU\"\n      out_device: \"cuda\"\n      backend: \"AMXInt8\" # or \"AMXBF16\" or \"llamafile\" (default)\n```\n\n输出会保存在`output_dir`里面，默认为safetensor格式，并且保留adapter.json等配套内容以便后续加载。\n\n![演示文稿1_01](../assets/演示文稿1_01.png)\n\n\n\n### 核心功能2：与微调后模型（即原模型+LoRA Adapter）聊天，用于交互\n\n运行命令：`llamafactory-cli chat examples/inference/deepseek3_lora_sft_kt.yaml`。\n\n调用KT微调的adapter (safetensor格式) 推理对话。\n\n```yaml\nmodel_name_or_path: opensourcerelease/DeepSeek-V3-bf16\nadapter_name_or_path: 
saves/Kllama_deepseekV3\ntemplate: deepseek\ninfer_backend: ktransformers  # choices: [huggingface, vllm, sglang, ktransformers]\ntrust_remote_code: true\n\nuse_kt: true # 调用KTransformers backend\nkt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml # 请选择和LoRA微调的时候保持一致的YAML文件\ncpu_infer: 32\nchunk_size: 8192\n```\n\n同时，我们也支持GGUF格式的adapter进行推理（如果您已经使用了上述LLaMA-Factory+KTransformers的微调方案，就不用管啦~）。\n\nsafetensors 场景填**文件所在目录**，GGUF 场景填**文件路径**，也就是说您需要把`adapter_name_or_path`设为具体的GGUF格式文件。\n\n加载过程中会处理 KT 每层命名与 torch.save 保存的常规命名之间的差异；映射正常时会打印日志 `Loaded adapter weight: XXX -> XXX`，如下图所示。\n\n![image-20250801165752484](../assets/image-20250801165752484.png)\n\n\n\n### 核心功能3：生成微调后模型（即原模型+LoRA Adapter）的API，用于批量生成并评测指标\n\n运行命令：`API_PORT=8000 llamafactory-cli api examples/inference/deepseek3_lora_sft_kt.yaml`。\n\n调用KT微调的adapter给出API，其他API使用逻辑和llamafactory原生方式一致。\n\n```yaml\nmodel_name_or_path: opensourcerelease/DeepSeek-V3-bf16\nadapter_name_or_path: saves/Kllama_deepseekV3\ntemplate: deepseek\ninfer_backend: ktransformers  # choices: [huggingface, vllm, sglang, ktransformers]\ntrust_remote_code: true\n\nuse_kt: true # use KTransformers as LoRA sft backend to inference\nkt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml\ncpu_infer: 32\nchunk_size: 8192\n```\n\n\n\n## KT微调速度性能测试：用户侧\n\n### 端到端性能\n\n**测试定义：**\n\n`step_time`：一次优化步（包含 `gradient_accumulation_steps (GAS)` 次累积）的总时间，涵盖 **PyTorch 张量搬运 + Attention + MoE + 其他计算等**。\n\n`tokens_per_step = GAS × qlen`；`token/s = tokens_per_step / step_time`。\n\n**测试设置：**`GAS=16`，`qlen=512`（即每步 8192 tokens）；LoRA（`r=8, alpha=32, dropout=0.1`）；使用AMX指令集优化；GPU选取RTX 4090，CPU选取Intel Xeon Platinum 8488C。\n\n**实测结果：**\n\n**DeepSeek-V3-671B：**`step_time = 203 s` → `token/s ≈ 8192 / 203` **≈ 40.35 token/s**\n\n**DeepSeek-V2-Lite-14B：**`step_time = 36 s` → `token/s ≈ 8192 / 36` **≈ 227.6 token/s**\n\n### 
显存/内存性能\n\nDeepSeek-V3（671B，61层，其中58层有MoE）占用显存（多卡总量）大约**70GB**、内存占用约1.2-1.3TB。\n\nDeepSeek-V2-lite（14B，27层，其中26层有MoE）占用显存大约**5.5GB**、内存占用约150GB。\n\n\n\n## 结论\n\n通过开发 KTransformers LoRA微调并将其集成到 LLaMA‑Factory，我们为希望高效训练与部署 MoE 大模型的用户提供了可行指南。KT 带来最尖端的优化（支持 DeepSeek、Qwen、Kimi 等，配合 AMX 加速 kernel），同时通过 LoRA 微调在极低 GPU 显存下实现定制化。LLaMA‑Factory 则提供友好的统一界面，更广的用户支持。\n\n该集成（类似 Unsloth 补丁所带来的提速）意味着即便是数百亿乃至万亿总参数量的 MoE 模型，也可在普通硬件上完成微调并低延迟部署。**显存节省、速度提升、易用性** 三者兼得。我们鼓励用户在下一次 MoE 项目中尝试 LLaMA‑Factory 的 KT 集成，并参考本文档进行操作。也欢迎提出任何问题和建议！\n"
  },
  {
    "path": "doc/zh/Qwen3-MoE_tutorial_zh_for_Ascend_NPU.md",
    "content": "# Benchmark results (output length set to 1K tokens in all cases, single concurrent request)\n\n| Prompt length                     | 1K     | 2K     | 4K     |\n| --------------------------------- | ------ | ------ | ------ |\n| KTrans Prefill token/s | 134.11 | 141.60 |  143.42 |\n| KTrans Decode token/s | 11.05 | 10.74 | 10.68 |\n\n## Prerequisites\nWe ran the Qwen3-235B-A22B MoE best-performance tests on the following configuration:\n- Server model: Atlas 2UP\n- NPU: Atlas 300I A2\n- CPU: HUAWEI Kunpeng 920 7270Z\n- Memory: DDR5 server memory (1TB)\n\n# Deployment\n\n***This README describes only the deployment steps that differ from `DeepseekR1_V3_tutorial_zh_for_Ascend_NPU.md` in the same directory.***\n\n## Physical machine setup\n\nDeploying the full Qwen3-MoE model requires enough physical memory to hold the weights of all routed experts, about 200GB.\n\nCurrently supported NPU model: **300I A2**.\n\nComplete the hardware installation with the help of technical staff.\n\n\n## Weight preparation\n\nTo meet both performance and accuracy requirements, two sets of weights currently need to be prepared and merged with the provided merge script; only the merged weights are used afterwards.\n\nQ4 weights: [Qwen3-235B-A22B-Instruct-2507-GGUF](https://modelscope.cn/models/unsloth/Qwen3-235B-A22B-Instruct-2507-GGUF/files)\n\nW8A8 weights: [Qwen3-235B-A22B-w8a8](https://modelers.cn/models/Modelers_Park/Qwen3-235B-A22B-w8a8)\n\nUse [merge_safetensor_gguf_for_qwen3.py](../../merge_tensors/merge_safetensor_gguf_for_qwen3.py) to merge the Q4 and W8A8 weights:\n\n```bash\npython merge_safetensor_gguf_for_qwen3.py --safetensor_path /mnt/weights/Qwen3-235B-A22B-Q4_K_M --gguf_path /mnt/weights/Qwen3-235B-A22B-W8A8 --output_path /mnt/weights/Qwen3-235B-A22B-q4km-w8a8\n```\n\n## kTransformers deployment\n\nDeploy the project files to the machine:\n\n- Initialize third_party. This step is slow, and network problems can make the repository clone fail, so after initializing once we recommend archiving the files so that they can simply be extracted later.\n  ```bash\n  git clone https://github.com/kvcache-ai/ktransformers.git\n  cd ktransformers\n  git submodule update --init --recursive\n  ```\n- On ARM platforms, comment out the following lines in `./third_party/llamafile/iqk_mul_mat_arm82.cpp`:\n  ```cpp\n  #define iqk_mul_mat iqk_mul_mat_arm82\n  #define iqk_mul_mat_moe iqk_mul_mat_moe_arm82\n  ```\n- Run `source /usr/local/Ascend/ascend-toolkit/set_env.sh` (adjust to your actual CANN-TOOLKIT installation path).\n- Run `apt install cmake libhwloc-dev pkg-config` to install the dependencies.\n- In `./ktransformers/configs/config.yaml`, set `page_size: 128` and `chunk_size: 16384` in the attn section.\n- Run `USE_BALANCE_SERVE=1 USE_NUMA=1 bash ./install.sh` and wait for the installation to finish.\n    
***Before running the install command, set the page size to 128 in `./ktransformers/configs/config.yaml` (the attention operator `torch_npu.npu_fused_infer_attention_score` supports page_size=16/128).***\n\nAn example balance_serve launch script follows (it uses relative paths, so place it in the project root):\n\n```bash\n#!/bin/bash\nexport USE_MERGE=0\nexport INF_NAN_MODE_FORCE_DISABLE=1\nexport TASK_QUEUE_ENABLE=0\nexport RANK=0\nexport LOCAL_WORLD_SIZE=1\n#export PROF_DECODE=1\n#export PROF_PREFILL=1\n\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh\nsource /usr/local/Ascend/nnal/atb/set_env.sh\n\npython ktransformers/server/main.py \\\n--port 10002 \\\n--model_path <your model path> \\\n--gguf_path <your model path> \\\n--cpu_infer 48 \\\n--optimize_config_path  ./ktransformers/optimize/optimize_rules/npu/Qwen3-Chat-300IA2-npu-serve.yaml \\\n--max_new_tokens 1024 \\\n--cache_lens 16384 \\\n--max_batch_size 4 \\\n--use_cuda_graph \\\n--tp 1 \\\n--backend_type balance_serve\n```\n\nParameter descriptions:\n\n- `--model_path`: native kTransformers parameter, str; here it specifies the path to the merged model files\n- `--gguf_path`: native kTransformers parameter, str; here it specifies the path to the merged model files\n- `--cpu_infer`: native kTransformers parameter, int; controls the number of worker threads actually used on the CPU side; optional\n- `--optimize_config_path`: native kTransformers parameter, str; specifies the model optimization config file to use; mind the relative path; **required** here\n- `--cache_lens`: total kvcache length requested by the scheduler. All requests share the kvcache space for the specified number of tokens (for example `20480`), and a request's kvcache space is released when it completes; optional\n- `--use_cuda_graph`: native kTransformers parameter, bool; True enables graph sinking and False disables it; optional\n- `--max_new_tokens`: native kTransformers parameter, int; output stops as soon as the number of generated tokens reaches this value; optional\n- `--tp`: new parameter, int; enables tensor model parallelism. Currently local_chat only supports a tp size equal to the world size (ws); multiple dp is not supported with local_chat; optional\n\n\n# Other issues\n\n## Possible dependency problems\n\nImportError: libhccl.so: cannot open shared object file: No such file or directory\n\n```bash\nsource /usr/local/Ascend/ascend-toolkit/set_env.sh  # adjust to your actual CANN installation path\n```\n\nImportError: libascend_hal.so: cannot open shared object file: No such file or directory\n\n```bash\nexport LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64/driver:$LD_LIBRARY_PATH  # adjust to your actual driver installation path\n```\n"
  },
  {
    "path": "doc/zh/api/server/api.md",
    "content": "# API\n\n\n- [OpenAI ChatCompletion](#openai-chatcompletion)\n- [Ollama ChatCompletion](#ollama-chatcompletion)\n- [OpenAI Assistant](#openai-assistant)\n\n\n## OpenAI ChatCompletion\n```bash\nPOST /v1/chat/completions\n```\nGenerates a reply using the selected model.\n\n### Parameters\n\n\n- `messages`: an array of `message` objects containing the full conversation history. A `message` represents a message from the user (user) or the model (assistant) and contains:\n\n  - `role`: `user` or `assistant`, indicating who created the message.\n  - `content`: the message text from the user or the model.\n\n- `model`: name of the selected model.\n- `stream`: true or false; whether to stream the response. If true, the inference result is returned as an HTTP event stream.\n\n### Response\n\n- Streaming: an event stream in which each event contains a `chat.completion.chunk`. `chunk.choices[0].delta.content` is the incremental output returned by the model each time.\n- Non-streaming: not yet supported.\n\n### Example\n\n```bash\ncurl -X 'POST' \\\n  'http://localhost:9112/v1/chat/completions' \\\n  -H 'accept: application/json' \\\n  -H 'Content-Type: application/json' \\\n  -d '{\n  \"messages\": [\n    {\n      \"content\": \"tell a joke\",\n      \"role\": \"user\"\n    }\n  ],\n  \"model\": \"Meta-Llama-3-8B-Instruct\",\n  \"stream\": true\n}'\n```\n\n```bash\ndata:{\"id\":\"c30445e8-1061-4149-a101-39b8222e79e1\",\"object\":\"chat.completion.chunk\",\"created\":1720511671,\"model\":\"not implmented\",\"system_fingerprint\":\"not implmented\",\"usage\":null,\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Why \",\"role\":\"assistant\",\"name\":null},\"logprobs\":null,\"finish_reason\":null}]}\n\ndata:{\"id\":\"c30445e8-1061-4149-a101-39b8222e79e1\",\"object\":\"chat.completion.chunk\",\"created\":1720511671,\"model\":\"not implmented\",\"system_fingerprint\":\"not implmented\",\"usage\":null,\"choices\":[{\"index\":0,\"delta\":{\"content\":\"\",\"role\":\"assistant\",\"name\":null},\"logprobs\":null,\"finish_reason\":null}]}\n\ndata:{\"id\":\"c30445e8-1061-4149-a101-39b8222e79e1\",\"object\":\"chat.completion.chunk\",\"created\":1720511671,\"model\":\"not implmented\",\"system_fingerprint\":\"not implmented\",\"usage\":null,\"choices\":[{\"index\":0,\"delta\":{\"content\":\"couldn't 
\",\"role\":\"assistant\",\"name\":null},\"logprobs\":null,\"finish_reason\":null}]}\n\n...\n\ndata:{\"id\":\"c30445e8-1061-4149-a101-39b8222e79e1\",\"object\":\"chat.completion.chunk\",\"created\":1720511671,\"model\":\"not implmented\",\"system_fingerprint\":\"not implmented\",\"usage\":null,\"choices\":[{\"index\":0,\"delta\":{\"content\":\"two-tired!\",\"role\":\"assistant\",\"name\":null},\"logprobs\":null,\"finish_reason\":null}]}\n\nevent: done\ndata: [DONE]\n```\n\n\n\n## Ollama ChatCompletion\n\n```bash\nPOST /api/generate\n```\n\nGenerates a reply using the selected model.\n\n### Parameters\n\n\n- `prompt`: a string containing the input prompt.\n- `model`: name of the selected model.\n- `stream`: true or false; whether to stream the response. If true, the inference result is returned as an HTTP event stream.\n\n### Response\n\n- Streaming: a streamed JSON response in which each line is one JSON object.\n  - `response`: the incremental completion from the model.\n  - `done`: whether inference has finished.\n\n- Non-streaming: not yet supported.\n\n### Example\n\n```bash\ncurl -X 'POST' \\\n  'http://localhost:9112/api/generate' \\\n  -H 'accept: application/json' \\\n  -H 'Content-Type: application/json' \\\n  -d '{\n  \"model\": \"Meta-Llama-3-8B-Instruct\",\n  \"prompt\": \"tell me a joke\",\n  \"stream\": true\n}'\n```\n\n```bash\n{\"model\":\"Meta-Llama-3-8B-Instruct\",\"created_at\":\"2024-07-09 08:13:11.686513\",\"response\":\"I'll \",\"done\":false}\n{\"model\":\"Meta-Llama-3-8B-Instruct\",\"created_at\":\"2024-07-09 08:13:11.729214\",\"response\":\"give \",\"done\":false}\n\n...\n\n{\"model\":\"Meta-Llama-3-8B-Instruct\",\"created_at\":\"2024-07-09 08:13:33.955475\",\"response\":\"for\",\"done\":false}\n{\"model\":\"Meta-Llama-3-8B-Instruct\",\"created_at\":\"2024-07-09 08:13:33.956795\",\"response\":\"\",\"done\":true}\n```\n\n\n\n"
  },
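  {
    "path": "doc/zh/api/server/server.md",
    "content": "# Backend service (Server)\nThe Server exposes ktransformers' fast heterogeneous inference to external callers through an API.\n\n<img src=\"server-arch.png\" height=\"600\" alt=\"Server architecture\">\n\n## API\n\nThe Server provides model inference through a RESTful API, with two calling styles: ChatCompletion and Assistant.\n\n- The ChatCompletion interface requires the caller to supply the entire conversation history in each request and returns the model's reply. Both AI service providers (for example [OpenAI](https://platform.openai.com/docs/api-reference/chat/create)) and local inference frameworks (for example [Ollama](https://github.com/ollama/ollama/blob/main/docs/api.md)) offer ChatCompletion interfaces. For compatibility with OpenAI and Ollama, the Server provides APIs consistent with each of them, so applications that currently use OpenAI or Ollama can switch to the Server seamlessly. For example: [How to use Tabby and ktransformers for local code completion with a 236B model](tabby.md).\n- Assistant suits applications that need to reuse a set of resources when calling the model. For example, in an education scenario, a developer can create an Assistant named “second-grade math teacher”, set an initial prompt (“You are an experienced second-grade math teacher...”), and upload related materials (a second-grade math textbook). After creating the Assistant, the application creates a Thread to store the messages (Message) exchanged between the user and the model. To call the model, the application creates a Run to obtain the Assistant's reply. Compared with ChatCompletion, a Server that implements Assistant handles conversation-context reuse and multi-turn dialogue on the application's behalf, which makes model calls in complex scenarios more convenient. The [OpenAI Assistant API](https://platform.openai.com/docs/api-reference/assistants/createAssistant) defines such an Assistant interface, and the Server provides an API consistent with it.\n\nThese APIs are defined in `server/api`; for usage details see [here](api.md).\n\n\n## Integrating model inference frameworks\n\nThe Server calls models and runs inference through ktransformers. It also supports other inference frameworks: [transformers](https://huggingface.co/docs/transformers/index) is already supported, and support for [exllamav2](https://github.com/turboderp/exllamav2) is planned. This functionality is implemented in `server/backend`.\n\nThe Server abstracts a framework's inference capability into a base class, `BackendInterfaceBase`, which has a single function: inference. Its input is the conversation history, messages, and its output is the text returned by the model. inference is designed as an async generator, which lets the Server stream the model's reply.\n\n```python\nfrom typing import AsyncIterator\n\n\nclass BackendInterfaceBase:\n    async def inference(self, messages, **kwargs) -> AsyncIterator[str]:\n        ...\n```\n\nBecause inference takes the conversation history as input and yields the model's reply as output, it naturally implements ChatCompletion. The ChatCompletion API can therefore call inference directly to perform model inference.\n\nAssistant is considerably more complex than ChatCompletion: the Server must store Assistant-related state and call inference in the appropriate way. The Server maintains the Assistant logic in a database, storing the Assistants, Threads, and Messages created by applications. In memory, the Server keeps a `ThreadContext` for each Thread, which gathers the Assistant and other information related to that Thread. When a user sends a new Message, 
the Server calls the ThreadContext's get_local_messages function to obtain messages, then calls inference to get the inference result.\n\n```python\nclass MyThreadContext(ThreadContext):\n    def get_local_messages(self):\n        ...\n```\n\nBecause different inference frameworks expect the conversation history in different formats, a `ThreadContext` and a `BackendInterface` must be used as a pair. Besides its own ktransformers, the Server also supports transformers. To integrate another inference framework, refer to the implementations of `TransformersInterface` and `TransformersThreadContext` in [transformers.py](https://github.com/kvcache-ai/ktransformers-dev/blob/main/ktransformers/server/backend/interfaces/transformers.py).\n\n\n\n"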
  },
  {
    "path": "doc/zh/api/server/tabby.md",
    "content": "# How to use Tabby and ktransformers for local code completion with a 236B model\n\n[Tabby](https://tabby.tabbyml.com/docs/welcome/) is an open-source coding assistant. Users can manually configure the backend framework and model and use it from several IDEs and editors, such as VSCode and IntelliJ. Because Tabby can connect to Ollama on the framework side, and the ktransformers server provides an API consistent with Ollama's, Tabby can be connected to the ktransformers server to get ktransformers' fast heterogeneous inference for code completion.\n\n1. Start ktransformers.\n```bash\n./ktransformers --port 9112\n```\n2. Install Tabby: follow the official guide to [install Tabby](https://tabby.tabbyml.com/docs/quick-start/installation/linux/) on a Linux server or Windows PC with an NVIDIA GPU.\n3. Configure Tabby: create `~/.tabby/config.toml` and add the following configuration.\n```toml\n[model.completion.http]\nkind = \"ollama/completion\"\napi_endpoint = \"http://127.0.0.1:9112/\"\nmodel_name = \"DeepSeek-Coder-V2-Instruct\"\nprompt_template = \"<｜fim▁begin｜>{prefix}<｜fim▁hole｜>{suffix}<｜fim▁end｜>\" # Prompt Template\n```\n\nIn this configuration, `kind` specifies that ktransformers serves Tabby through the standard Ollama API; `api_endpoint` must match the address and port that ktransformers binds to at startup; `model_name` is the model used by ktransformers, here `DeepSeek-Coder-V2-Instruct` as the backend inference model; `prompt_template` is the model's prompt template, and the template matching the model must be used for its Fill In the Middle feature to work properly.\nThis example shows the configuration for Tabby's Completion feature through the Ollama API; for configuring Tabby's other optional features, see [here](https://tabby.tabbyml.com/docs/administration/model/).\n\n\n4. Start the Tabby service: `./tabby serve`.\n<img src=\"run-tabby.png\" alt=\"image-20240709112329577\" style=\"zoom:50%;\" />\n\n\tAfter startup, you should see requests to the `/api/tags` endpoint in the ktransformers console (in the newer Tabby version v0.13.0 this becomes requests to the `/api/show/` endpoint).\n<img src=\"visit-api-tags.png\" alt=\"image-20240709111648215\" style=\"zoom:67%;\" />\n\n5. Register a Tabby account and obtain a token: after starting the Tabby service, open the corresponding URL in a browser (0.0.0.0:8080 in the screenshot above) and follow the [guide](https://tabby.tabbyml.com/docs/quick-start/register-account/) to create a user and get a token.\n\n6. Launch VSCode, install the Tabby extension, and when prompted connect to the Tabby server with the token from the previous step; see [here](https://tabby.tabbyml.com/docs/extensions/installation/vscode/).\n\n7. Open any code file and try out ktransformers' fast heterogeneous inference.\n\n"
  },
  {
    "path": "doc/zh/api/server/website.md",
    "content": "# Start with website\n\nThis document provides the necessary steps to set up and run the web service for this project.\n\n## 1. Starting the Web Service\n\n### 1.1. Compiling the Web Code\n\nBefore compiling the web code, make sure [Node.js](https://nodejs.org) version 18.3 or higher is installed.\n\nOnce Node.js (and the npm that ships with it) is installed, navigate to the `ktransformers/website` directory:\n\n```bash\ncd ktransformers/website\n```\n\nNext, install the Vue CLI with the following command:\n\n```bash\nnpm install @vue/cli\n```\n\nNow you can build the project:\n\n```bash\nnpm run build\n```\n\nFinally, build ktransformers with the website included:\n\n```bash\ncd ../../\npip install .\n```\n"
  },
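  {
    "path": "doc/zh/clawdbot_integration_guide.md",
    "content": "# KTransformers + Clawdbot: a locally deployed AI assistant\n\n> **Use KTransformers' CPU-GPU hybrid inference together with Kimi-K2.5's high-quality reasoning to give Clawdbot a high-performance local inference backend**\n\n---\n\n## What is Clawdbot?\n\n[Clawdbot](https://github.com/openclaw/openclaw) is an open-source personal AI agent. It can be used through chat platforms such as Telegram, Discord, Signal, and WhatsApp, automates tasks such as schedule management, sending email, and data lookup, and stores data entirely locally so that privacy stays under your control.\n\n> **Note**: Clawdbot does not include a Feishu channel by default; a community plugin must be installed separately. See the Feishu integration section below.\n\n---\n\n## Why use KTransformers as the inference backend?\n\n**KTransformers** uses a CPU-GPU hybrid inference architecture:\n\n- **CPU-GPU cooperation**: the GPU handles the high-value inference paths while the CPU (AMX quantization) handles the expert modules, maximizing resource utilization\n- **Native MoE support**: supports MoE models in several native precisions\n- **SGLang high-performance engine**: OpenAI API compatible, with multi-GPU tensor parallelism\n- **Full-stack CLI tools**: `kt run` for one-command launch, `kt model` for model management, `kt quant` for quantization, `kt bench` for benchmarking, `kt doctor` for environment diagnostics\n\n---\n\n## Supported models\n\nBuilding on [native-precision model support](../en/kt-kernel/Native-Precision-Tutorial.md) for models such as Kimi K2 Thinking, we added day-0 support for [Kimi K2.5](../en/Kimi-K2.5.md). MoE models including Kimi K2.5, MiniMax, DeepSeek, Qwen3, and GLM are now supported at native precision and can be deployed with only 24-48GB of GPU memory.\n\n---\n\n## Deployment architecture\n\n```\n[User] → [Telegram / Discord / Signal / Feishu] → [Clawdbot Gateway]\n                                                        ↓\n                                                  [KTransformers]\n                                                   (SGLang API)\n                                                        ↓\n                                               [Multi-GPU inference]\n```\n\nClawdbot connects to KTransformers through the OpenAI-compatible API; no additional API key is needed, and local inference incurs no API fees.\n\n---\n\n## Deployment steps\n\n### Step 1: Install and start KTransformers\n\nSee the [Kimi K2.5 guide](../en/Kimi-K2.5.md) and the [kt-kernel deployment guide](https://github.com/kvcache-ai/ktransformers/tree/main/kt-kernel).\n\nOnce started, KTransformers serves an OpenAI-compatible API at `http://<host>:30000/v1`.\n\n### Step 2: Install Clawdbot\n\n```bash\nnpm install -g openclaw@latest\n\nopenclaw onboard --install-daemon\n```\n\n> For detailed Clawdbot installation and configuration, see the [official Clawdbot documentation](https://openclaw.ai) and the [GitHub repository](https://github.com/openclaw/openclaw).\n\n### Step 3: Configure KTransformers as the inference backend\n\nEdit the Clawdbot configuration file (usually at `~/.openclaw/openclaw.json`, or through the web UI at 
`http://127.0.0.1:18789/config`) and point the model provider at the local KTransformers service:\n\n```json\n{\n  \"models\": {\n    \"providers\": {\n      \"synthetic\": {\n        \"baseUrl\": \"http://127.0.0.1:30000/v1\",\n        \"apiKey\": \"EMPTY\",\n        \"api\": \"openai-completions\",\n        \"models\": [\n          {\n            \"id\": \"kimi-k2.5\",\n            \"name\": \"kimi-k2.5\",\n            \"contextWindow\": 200000,\n            \"maxTokens\": 16384\n          }\n        ]\n      }\n    },\n    \"routing\": {\n      \"default\": {\n        \"provider\": \"synthetic\",\n        \"modelId\": \"kimi-k2.5\"\n      }\n    }\n  }\n}\n```\n\nKey settings:\n- `baseUrl`: address of the KTransformers SGLang service\n- `apiKey`: set to `\"EMPTY\"`; the local service needs no key\n- `models`: adjust `id` and `contextWindow` to match the model actually running\n\n### Step 4: Start the Clawdbot Gateway\n\n```bash\nopenclaw gateway --port 18789\n```\n\n### Step 5: Configure message channels\n\nClawdbot natively supports channels such as Telegram, Discord, and Signal:\n\n```bash\n# Telegram\nopenclaw channels login --channel telegram\n\n# Signal\nopenclaw channels login --channel signal\n```\n\n---\n\n## Feishu integration\n\nClawdbot does not ship with a Feishu channel; connect it through the community-developed Feishu bridge plugin.\n\nMain steps:\n1. Create a custom enterprise app on the [Feishu Open Platform](https://open.feishu.cn/) and add the \"Bot\" (机器人) capability\n2. Install the Feishu bridge plugin (community project: [clawdbot-feishu](https://github.com/m1heng/clawdbot-feishu))\n3. Configure the Feishu app credentials such as `appId` and `appSecret`\n4. 
Add the \"Receive messages\" (接收消息) event and publish the app version\n\nFor detailed tutorials, see:\n- [Step-by-step guide to connecting Clawdbot to Feishu](https://mp.weixin.qq.com/s/_i1fgNbeDrBR5wurEmJf0A)\n- [Tencent Cloud: step-by-step guide to connecting Moltbot to Feishu](https://cloud.tencent.com/developer/article/2625073)\n\n---\n\n## Reference hardware configuration\n\nA reference configuration for an 8-GPU deployment:\n\n| Component | Configuration |\n|------|------|\n| GPU | 8 × NVIDIA RTX 5090 (32GB VRAM) |\n| CPU | Dual-socket, high-core-count processors (must support at least the AVX-512 instruction set) |\n| Memory | 512GB+ |\n| Model | Kimi K2.5 / DeepSeek-V3 / GLM-4.7, etc. |\n\n```bash\n# Launch example\nkt run kimi-k2.5\n```\n\n---\n\n## KTransformers vs. conventional deployment\n\n| Feature | KTransformers | Conventional deployment |\n|------|---------------|----------|\n| GPU memory requirement | Small | Full model size |\n| MoE support | CPU-GPU dynamic scheduling | None |\n| CPU-GPU hybrid | NUMA-optimized | None |\n| Management tooling | Full-stack kt CLI | Manual |\n| Troubleshooting | Automatic checks with `kt doctor` | Manual debugging |\n\n---\n\n## Use cases\n\n- **Enterprise deployment**: customer-service automation, document Q&A, workflow automation\n- **R&D teams**: rapid model validation, performance benchmarking, experiment environment setup\n- **Individual users**: low-cost local AI assistant with private data under your control\n\n---\n\n## Related links\n\n- [KTransformers GitHub](https://github.com/KTransformers/ktransformers)\n- [Clawdbot website](https://openclaw.ai/)\n- [Clawdbot GitHub](https://github.com/clawdbot/clawdbot)\n- [Feishu bridge plugin](https://github.com/m1heng/clawdbot-feishu)\n"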
  },
  {
    "path": "docker/Dockerfile",
    "content": "ARG CUDA_VERSION=12.8.1\nFROM docker.1ms.run/nvidia/cuda:${CUDA_VERSION}-cudnn-devel-ubuntu24.04 AS base\n\nARG TARGETARCH\nARG GRACE_BLACKWELL=0\nARG HOPPER_SBO=0\nARG CPU_VARIANT=x86-intel-multi\nARG BUILD_ALL_CPU_VARIANTS=1\n\n# Proxy settings for build-time network access\nARG HTTP_PROXY\nARG HTTPS_PROXY\nARG http_proxy\nARG https_proxy\nENV HTTP_PROXY=${HTTP_PROXY} \\\n    HTTPS_PROXY=${HTTPS_PROXY} \\\n    http_proxy=${http_proxy} \\\n    https_proxy=${https_proxy}\n\nARG GRACE_BLACKWELL_DEEPEP_BRANCH=gb200_blog_part_2\nARG HOPPER_SBO_DEEPEP_COMMIT=9f2fc4b3182a51044ae7ecb6610f7c9c3258c4d6\nARG DEEPEP_COMMIT=9af0e0d0e74f3577af1979c9b9e1ac2cad0104ee\nARG BUILD_AND_DOWNLOAD_PARALLEL=8\nARG SGL_KERNEL_VERSION=0.3.19\nARG SGL_VERSION=0.5.6.post1\nARG USE_LATEST_SGLANG=0\nARG GDRCOPY_VERSION=2.5.1\nARG UBUNTU_MIRROR\nARG GITHUB_ARTIFACTORY=github.com\nARG FLASHINFER_VERSION=0.5.3\n\n# ktransformers wheel version (cu128torch28 for CUDA 12.8 + PyTorch 2.8)\nARG KTRANSFORMERS_VERSION=0.4.2\nARG KTRANSFORMERS_WHEEL=ktransformers-0.4.2+cu128torch28fancy-cp312-cp312-linux_x86_64.whl\n\n# flash_attn wheel for fine-tune env\nARG FLASH_ATTN_WHEEL=flash_attn-2.8.3+cu12torch2.8cxx11abiTRUE-cp312-cp312-linux_x86_64.whl\n\nENV DEBIAN_FRONTEND=noninteractive \\\n    CUDA_HOME=/usr/local/cuda \\\n    GDRCOPY_HOME=/usr/src/gdrdrv-${GDRCOPY_VERSION}/ \\\n    FLASHINFER_VERSION=${FLASHINFER_VERSION}\n\n# Add GKE default lib and bin locations\nENV PATH=\"${PATH}:/usr/local/nvidia/bin\" \\\n    LD_LIBRARY_PATH=\"${LD_LIBRARY_PATH}:/usr/local/nvidia/lib:/usr/local/nvidia/lib64\"\n\n# Replace Ubuntu sources with Tsinghua mirror for Ubuntu 24.04 (noble)\nRUN if [ -n \"$UBUNTU_MIRROR\" ]; then \\\n    echo \"deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ noble main restricted universe multiverse\" > /etc/apt/sources.list && \\\n    echo \"deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ noble-updates main restricted universe multiverse\" >> /etc/apt/sources.list && \\\n  
  echo \"deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ noble-backports main restricted universe multiverse\" >> /etc/apt/sources.list && \\\n    echo \"deb http://security.ubuntu.com/ubuntu/ noble-security main restricted universe multiverse\" >> /etc/apt/sources.list && \\\n    rm -f /etc/apt/sources.list.d/ubuntu.sources; \\\nfi\n\n# Install system dependencies (organized by category for better caching)\nRUN --mount=type=cache,target=/var/cache/apt,id=base-apt \\\n    echo 'tzdata tzdata/Areas select America' | debconf-set-selections \\\n    && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \\\n    && apt-get update && apt-get install -y --no-install-recommends --allow-change-held-packages \\\n    # Core system utilities\n    tzdata \\\n    ca-certificates \\\n    software-properties-common \\\n    netcat-openbsd \\\n    kmod \\\n    unzip \\\n    openssh-server \\\n    curl \\\n    wget \\\n    lsof \\\n    locales \\\n    # Build essentials\n    build-essential \\\n    cmake \\\n    perl \\\n    patchelf \\\n    ccache \\\n    git \\\n    git-lfs \\\n    # MPI and NUMA\n    libopenmpi-dev \\\n    libnuma1 \\\n    libnuma-dev \\\n    numactl \\\n    # transformers multimodal VLM\n    ffmpeg \\\n    # InfiniBand/RDMA\n    libibverbs-dev \\\n    libibverbs1 \\\n    libibumad3 \\\n    librdmacm1 \\\n    libnl-3-200 \\\n    libnl-route-3-200 \\\n    libnl-route-3-dev \\\n    libnl-3-dev \\\n    ibverbs-providers \\\n    infiniband-diags \\\n    perftest \\\n    # Development libraries\n    libgoogle-glog-dev \\\n    libgtest-dev \\\n    libjsoncpp-dev \\\n    libunwind-dev \\\n    libboost-all-dev \\\n    libssl-dev \\\n    libgrpc-dev \\\n    libgrpc++-dev \\\n    libprotobuf-dev \\\n    protobuf-compiler \\\n    protobuf-compiler-grpc \\\n    pybind11-dev \\\n    libhiredis-dev \\\n    libcurl4-openssl-dev \\\n    libczmq4 \\\n    libczmq-dev \\\n    libfabric-dev \\\n    # Package building tools\n    devscripts \\\n    debhelper 
\\\n    fakeroot \\\n    dkms \\\n    check \\\n    libsubunit0 \\\n    libsubunit-dev \\\n    # Development tools\n    gdb \\\n    ninja-build \\\n    vim \\\n    tmux \\\n    htop \\\n    zsh \\\n    tree \\\n    less \\\n    rdma-core \\\n    # NCCL\n    libnccl2 \\\n    libnccl-dev \\\n    && rm -rf /var/lib/apt/lists/* \\\n    && apt-get clean\n\n# GDRCopy installation\nRUN mkdir -p /tmp/gdrcopy && cd /tmp \\\n    && curl --retry 3 --retry-delay 2 -fsSL -o v${GDRCOPY_VERSION}.tar.gz \\\n        https://${GITHUB_ARTIFACTORY}/NVIDIA/gdrcopy/archive/refs/tags/v${GDRCOPY_VERSION}.tar.gz \\\n    && tar -xzf v${GDRCOPY_VERSION}.tar.gz && rm v${GDRCOPY_VERSION}.tar.gz \\\n    && cd gdrcopy-${GDRCOPY_VERSION}/packages \\\n    && CUDA=/usr/local/cuda ./build-deb-packages.sh \\\n    && dpkg -i gdrdrv-dkms_*.deb libgdrapi_*.deb gdrcopy-tests_*.deb gdrcopy_*.deb \\\n    && cd / && rm -rf /tmp/gdrcopy\n\n# Fix DeepEP IBGDA symlink\nRUN ln -sf /usr/lib/$(uname -m)-linux-gnu/libmlx5.so.1 /usr/lib/$(uname -m)-linux-gnu/libmlx5.so\n\n# Set up locale\nRUN locale-gen en_US.UTF-8\nENV LANG=en_US.UTF-8 \\\n    LANGUAGE=en_US:en \\\n    LC_ALL=en_US.UTF-8\n\n########################################################\n########## Install Miniconda ###########################\n########################################################\n\nRUN mkdir -p /opt/miniconda3 \\\n    && wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /opt/miniconda3/miniconda.sh \\\n    && bash /opt/miniconda3/miniconda.sh -b -u -p /opt/miniconda3 \\\n    && rm /opt/miniconda3/miniconda.sh\n\n# Add conda to PATH\nENV PATH=\"/opt/miniconda3/bin:${PATH}\"\n\n# Accept conda TOS\nRUN conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \\\n    && conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r\n\n# Configure conda to use Tsinghua mirror\nRUN conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main \\\n 
   && conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free \\\n    && conda config --set show_channel_urls yes\n\n########################################################\n########## Dual Conda Environment Setup ################\n########################################################\n\nFROM base AS framework\n\nARG CUDA_VERSION\nARG BUILD_AND_DOWNLOAD_PARALLEL\nARG SGL_KERNEL_VERSION\nARG SGL_VERSION\nARG USE_LATEST_SGLANG\nARG FLASHINFER_VERSION\nARG GRACE_BLACKWELL\nARG GRACE_BLACKWELL_DEEPEP_BRANCH\nARG HOPPER_SBO\nARG HOPPER_SBO_DEEPEP_COMMIT\nARG DEEPEP_COMMIT\nARG GITHUB_ARTIFACTORY\nARG KTRANSFORMERS_VERSION\nARG KTRANSFORMERS_WHEEL\nARG FLASH_ATTN_WHEEL\nARG FUNCTIONALITY=sft\n\nWORKDIR /workspace\n\n# Create conda environments (fine-tune only needed for sft mode)\nRUN conda create -n serve python=3.12 -y \\\n    && if [ \"$FUNCTIONALITY\" = \"sft\" ]; then conda create -n fine-tune python=3.12 -y; fi\n\n# Set pip mirror for conda envs\nRUN /opt/miniconda3/envs/serve/bin/pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple \\\n    && if [ \"$FUNCTIONALITY\" = \"sft\" ]; then \\\n        /opt/miniconda3/envs/fine-tune/bin/pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple; \\\n    fi\n\n# Clone repositories (sglang is included as a submodule in ktransformers)\nRUN git clone --depth 1 https://${GITHUB_ARTIFACTORY}/kvcache-ai/ktransformers.git /workspace/ktransformers \\\n    && cd /workspace/ktransformers && git submodule update --init --recursive \\\n    && ln -s /workspace/ktransformers/third_party/sglang /workspace/sglang \\\n    && if [ \"$FUNCTIONALITY\" = \"sft\" ]; then \\\n        git clone --depth 1 https://${GITHUB_ARTIFACTORY}/hiyouga/LLaMA-Factory.git /workspace/LLaMA-Factory; \\\n    fi\n\n# Download ktransformers wheel and flash_attn wheel for fine-tune env (sft mode only)\nRUN if [ \"$FUNCTIONALITY\" = \"sft\" ]; then \\\n        curl --retry 3 
--retry-delay 2 -fsSL -o /workspace/${KTRANSFORMERS_WHEEL} \\\n            https://${GITHUB_ARTIFACTORY}/kvcache-ai/ktransformers/releases/download/v${KTRANSFORMERS_VERSION}/${KTRANSFORMERS_WHEEL} \\\n        && curl --retry 3 --retry-delay 2 -fsSL -o /workspace/${FLASH_ATTN_WHEEL} \\\n            https://${GITHUB_ARTIFACTORY}/Dao-AILab/flash-attention/releases/download/v2.8.3/${FLASH_ATTN_WHEEL}; \\\n    fi\n\n########################################################\n# Environment 1: serve (sglang + kt-kernel)\n########################################################\n\n# Upgrade pip and install basic tools in serve env\nRUN --mount=type=cache,target=/root/.cache/pip \\\n    /opt/miniconda3/envs/serve/bin/pip install --upgrade pip setuptools wheel html5lib six\n\n# Install sgl-kernel\nRUN --mount=type=cache,target=/root/.cache/pip \\\n    case \"$CUDA_VERSION\" in \\\n        12.6.1) CUINDEX=126 ;; \\\n        12.8.1) CUINDEX=128 ;; \\\n        12.9.1) CUINDEX=129 ;; \\\n        13.0.1) CUINDEX=130 ;; \\\n        *) echo \"Unsupported CUDA version: $CUDA_VERSION\" && exit 1 ;; \\\n    esac \\\n    && if [ \"$CUDA_VERSION\" = \"12.6.1\" ]; then \\\n        /opt/miniconda3/envs/serve/bin/pip install https://${GITHUB_ARTIFACTORY}/sgl-project/whl/releases/download/v${SGL_KERNEL_VERSION}/sgl_kernel-${SGL_KERNEL_VERSION}+cu124-cp310-abi3-manylinux2014_$(uname -m).whl --force-reinstall --no-deps \\\n    ; \\\n    elif [ \"$CUDA_VERSION\" = \"12.8.1\" ] || [ \"$CUDA_VERSION\" = \"12.9.1\" ]; then \\\n        /opt/miniconda3/envs/serve/bin/pip install sgl-kernel==${SGL_KERNEL_VERSION} \\\n    ; \\\n    elif [ \"$CUDA_VERSION\" = \"13.0.1\" ]; then \\\n        /opt/miniconda3/envs/serve/bin/pip install https://github.com/sgl-project/whl/releases/download/v${SGL_KERNEL_VERSION}/sgl_kernel-${SGL_KERNEL_VERSION}+cu130-cp310-abi3-manylinux2014_$(uname -m).whl --force-reinstall --no-deps \\\n    ; \\\n    fi\n\n# Install SGLang in serve env (version aligned with 
ktransformers)\nRUN --mount=type=cache,target=/root/.cache/pip \\\n    case \"$CUDA_VERSION\" in \\\n        12.6.1) CUINDEX=126 ;; \\\n        12.8.1) CUINDEX=128 ;; \\\n        12.9.1) CUINDEX=129 ;; \\\n        13.0.1) CUINDEX=130 ;; \\\n    esac \\\n    && export SGLANG_KT_VERSION=$(python3 -c \"exec(open('/workspace/ktransformers/version.py').read()); print(__version__)\") \\\n    && echo \"Installing sglang-kt v${SGLANG_KT_VERSION}\" \\\n    && cd /workspace/sglang \\\n    && /opt/miniconda3/envs/serve/bin/pip install -e \"python[all]\" --extra-index-url https://download.pytorch.org/whl/cu${CUINDEX}\n\n# Download FlashInfer cubin for serve env\nRUN --mount=type=cache,target=/root/.cache/pip \\\n    FLASHINFER_CUBIN_DOWNLOAD_THREADS=${BUILD_AND_DOWNLOAD_PARALLEL} FLASHINFER_LOGGING_LEVEL=warning \\\n    /opt/miniconda3/envs/serve/bin/python -m flashinfer --download-cubin\n\n# Install DeepEP in serve env\nRUN set -eux; \\\n    if [ \"$GRACE_BLACKWELL\" = \"1\" ]; then \\\n      git clone https://github.com/fzyzcjy/DeepEP.git /workspace/DeepEP && \\\n      cd /workspace/DeepEP && \\\n      git checkout ${GRACE_BLACKWELL_DEEPEP_BRANCH} && \\\n      sed -i 's/#define NUM_CPU_TIMEOUT_SECS 100/#define NUM_CPU_TIMEOUT_SECS 1000/' csrc/kernels/configs.cuh; \\\n    elif [ \"$HOPPER_SBO\" = \"1\" ]; then \\\n      git clone https://github.com/deepseek-ai/DeepEP.git -b antgroup-opt /workspace/DeepEP && \\\n      cd /workspace/DeepEP && \\\n      git checkout ${HOPPER_SBO_DEEPEP_COMMIT} && \\\n      sed -i 's/#define NUM_CPU_TIMEOUT_SECS 100/#define NUM_CPU_TIMEOUT_SECS 1000/' csrc/kernels/configs.cuh; \\\n    else \\\n      curl --retry 3 --retry-delay 2 -fsSL -o /tmp/${DEEPEP_COMMIT}.zip \\\n          https://${GITHUB_ARTIFACTORY}/deepseek-ai/DeepEP/archive/${DEEPEP_COMMIT}.zip && \\\n      unzip -q /tmp/${DEEPEP_COMMIT}.zip -d /tmp && rm /tmp/${DEEPEP_COMMIT}.zip && \\\n      mv /tmp/DeepEP-${DEEPEP_COMMIT} /workspace/DeepEP && \\\n      cd /workspace/DeepEP && \\\n    
  sed -i 's/#define NUM_CPU_TIMEOUT_SECS 100/#define NUM_CPU_TIMEOUT_SECS 1000/' csrc/kernels/configs.cuh; \\\n    fi\n\nRUN --mount=type=cache,target=/root/.cache/pip \\\n    cd /workspace/DeepEP && \\\n    case \"$CUDA_VERSION\" in \\\n        12.6.1) CHOSEN_TORCH_CUDA_ARCH_LIST='9.0' ;; \\\n        12.8.1) CHOSEN_TORCH_CUDA_ARCH_LIST='9.0;10.0' ;; \\\n        12.9.1|13.0.1) CHOSEN_TORCH_CUDA_ARCH_LIST='9.0;10.0;10.3' ;; \\\n        *) echo \"Unsupported CUDA version: $CUDA_VERSION\" && exit 1 ;; \\\n    esac && \\\n    . /opt/miniconda3/etc/profile.d/conda.sh && conda activate serve && \\\n    TORCH_CUDA_ARCH_LIST=\"${CHOSEN_TORCH_CUDA_ARCH_LIST}\" MAX_JOBS=${BUILD_AND_DOWNLOAD_PARALLEL} \\\n    pip install --no-build-isolation .\n\n# Install NCCL for serve env\nRUN --mount=type=cache,target=/root/.cache/pip \\\n    if [ \"${CUDA_VERSION%%.*}\" = \"12\" ]; then \\\n        /opt/miniconda3/envs/serve/bin/pip install nvidia-nccl-cu12==2.28.3 --force-reinstall --no-deps ; \\\n    elif [ \"${CUDA_VERSION%%.*}\" = \"13\" ]; then \\\n        /opt/miniconda3/envs/serve/bin/pip install nvidia-nccl-cu13==2.28.3 --force-reinstall --no-deps ; \\\n    fi\n\n# Install kt-kernel in serve env with all CPU variants\nRUN . 
/opt/miniconda3/etc/profile.d/conda.sh && conda activate serve \\\n    && cd /workspace/ktransformers/kt-kernel \\\n    && CPUINFER_BUILD_ALL_VARIANTS=1 ./install.sh build\n\n########################################################\n# Environment 2: fine-tune (LLaMA-Factory + ktransformers) - sft mode only\n########################################################\n\n# Install dependency libraries for ktransformers (CUDA 11.8 runtime required)\nRUN if [ \"$FUNCTIONALITY\" = \"sft\" ]; then \\\n        conda install -n fine-tune -y -c conda-forge libstdcxx-ng gcc_impl_linux-64 \\\n        && conda install -n fine-tune -y -c nvidia/label/cuda-11.8.0 cuda-runtime; \\\n    fi\n\n# Install PyTorch 2.8 in fine-tune env\nRUN --mount=type=cache,target=/root/.cache/pip \\\n    if [ \"$FUNCTIONALITY\" = \"sft\" ]; then \\\n        case \"$CUDA_VERSION\" in \\\n            12.6.1) CUINDEX=126 ;; \\\n            12.8.1) CUINDEX=128 ;; \\\n            12.9.1) CUINDEX=129 ;; \\\n            13.0.1) CUINDEX=130 ;; \\\n        esac \\\n        && /opt/miniconda3/envs/fine-tune/bin/pip install --upgrade pip setuptools wheel hatchling \\\n        && /opt/miniconda3/envs/fine-tune/bin/pip install \\\n            torch==2.8.0 \\\n            torchvision \\\n            torchaudio \\\n            --extra-index-url https://download.pytorch.org/whl/cu${CUINDEX}; \\\n    fi\n\n# Install LLaMA-Factory in fine-tune env\nRUN --mount=type=cache,target=/root/.cache/pip \\\n    if [ \"$FUNCTIONALITY\" = \"sft\" ]; then \\\n        cd /workspace/LLaMA-Factory \\\n        && /opt/miniconda3/envs/fine-tune/bin/pip install -e \".[torch,metrics]\" --no-build-isolation; \\\n    fi\n\n# Install ktransformers wheel in fine-tune env\nRUN --mount=type=cache,target=/root/.cache/pip \\\n    if [ \"$FUNCTIONALITY\" = \"sft\" ]; then \\\n        /opt/miniconda3/envs/fine-tune/bin/pip install /workspace/${KTRANSFORMERS_WHEEL}; \\\n    fi\n\n# Install flash_attn wheel in fine-tune env\nRUN 
--mount=type=cache,target=/root/.cache/pip \\\n    if [ \"$FUNCTIONALITY\" = \"sft\" ]; then \\\n        /opt/miniconda3/envs/fine-tune/bin/pip install /workspace/${FLASH_ATTN_WHEEL}; \\\n    fi\n\n# Install NCCL for fine-tune env\nRUN --mount=type=cache,target=/root/.cache/pip \\\n    if [ \"$FUNCTIONALITY\" = \"sft\" ]; then \\\n        if [ \"${CUDA_VERSION%%.*}\" = \"12\" ]; then \\\n            /opt/miniconda3/envs/fine-tune/bin/pip install nvidia-nccl-cu12==2.28.3 --force-reinstall --no-deps ; \\\n        elif [ \"${CUDA_VERSION%%.*}\" = \"13\" ]; then \\\n            /opt/miniconda3/envs/fine-tune/bin/pip install nvidia-nccl-cu13==2.28.3 --force-reinstall --no-deps ; \\\n        fi; \\\n    fi\n\n########################################################\n# Cleanup and final setup\n########################################################\n\n# Clean up downloaded wheels\nRUN if [ \"$FUNCTIONALITY\" = \"sft\" ]; then \\\n        rm -f /workspace/${KTRANSFORMERS_WHEEL} /workspace/${FLASH_ATTN_WHEEL}; \\\n    fi\n\n# Initialize conda for bash\nRUN /opt/miniconda3/bin/conda init bash\n\n# Create shell aliases for convenience\nRUN echo '\\n# Conda environment aliases\\nalias serve=\"conda activate serve\"' >> /root/.bashrc \\\n    && if [ \"$FUNCTIONALITY\" = \"sft\" ]; then \\\n        echo 'alias finetune=\"conda activate fine-tune\"' >> /root/.bashrc; \\\n    fi\n\n########################################################\n# Extract version information for image naming\n########################################################\n\n# Extract versions from each component and save to versions.env\nRUN set -x && \\\n    # KTransformers version (single source of truth for both kt-kernel and sglang-kt)\n    cd /workspace/ktransformers && \\\n    KTRANSFORMERS_VERSION=$(python3 -c \"exec(open('version.py').read()); print(__version__)\" 2>/dev/null || echo \"unknown\") && \\\n    echo \"KTRANSFORMERS_VERSION=$KTRANSFORMERS_VERSION\" > /workspace/versions.env && \\\n    echo 
\"Extracted KTransformers version: $KTRANSFORMERS_VERSION\" && \\\n    \\\n    # sglang-kt version = ktransformers version (aligned)\n    echo \"SGLANG_KT_VERSION=$KTRANSFORMERS_VERSION\" >> /workspace/versions.env && \\\n    echo \"sglang-kt version (aligned): $KTRANSFORMERS_VERSION\" && \\\n    \\\n    # LLaMA-Factory version (from fine-tune environment, sft mode only)\n    if [ \"$FUNCTIONALITY\" = \"sft\" ]; then \\\n        . /opt/miniconda3/etc/profile.d/conda.sh && conda activate fine-tune && \\\n        cd /workspace/LLaMA-Factory && \\\n        LLAMAFACTORY_VERSION=$(python -c \"import sys; sys.path.insert(0, 'src'); from llamafactory import __version__; print(__version__)\" 2>/dev/null || echo \"unknown\") && \\\n        echo \"LLAMAFACTORY_VERSION=$LLAMAFACTORY_VERSION\" >> /workspace/versions.env && \\\n        echo \"Extracted LLaMA-Factory version: $LLAMAFACTORY_VERSION\"; \\\n    else \\\n        echo \"LLAMAFACTORY_VERSION=none\" >> /workspace/versions.env && \\\n        echo \"LLaMA-Factory not installed (infer mode)\"; \\\n    fi && \\\n    \\\n    # Display all versions\n    echo \"=== Version Summary ===\" && \\\n    cat /workspace/versions.env\n\nWORKDIR /workspace\n\nCMD [\"/bin/bash\"]\n"
  },
  {
    "path": "docker/README-packaging.md",
    "content": "# KTransformers Docker Packaging Guide\n\nThis directory contains scripts for building and distributing KTransformers Docker images with standardized naming conventions.\n\n## Overview\n\nThe packaging system provides:\n\n- **Automated version detection** from sglang, ktransformers, and LLaMA-Factory\n- **Multi-CPU variant support** (AMX, AVX512, AVX2) with runtime auto-detection\n- **Standardized naming convention** for easy identification and management\n- **Two distribution methods**:\n  - Local tar file export for offline distribution\n  - DockerHub publishing for online distribution\n\n## Naming Convention\n\nDocker images follow this naming pattern:\n\n```\nsglang-v{sglang版本}_ktransformers-v{ktransformers版本}_{cpu信息}_{gpu信息}_{功能模式}_{时间戳}\n```\n\n### Example Names\n\n**Tar file:**\n```\nsglang-v0.5.6_ktransformers-v0.4.3_x86-intel-multi_cu128_sft_llamafactory-v0.9.3_20241212143022.tar\n```\n\n**DockerHub tags:**\n```\nFull tag:\nkvcache/ktransformers:sglang-v0.5.6_ktransformers-v0.4.3_x86-intel-multi_cu128_sft_llamafactory-v0.9.3_20241212143022\n\nSimplified tag:\nkvcache/ktransformers:v0.4.3-cu128\n```\n\n### Name Components\n\n| Component | Description | Example |\n|-----------|-------------|---------|\n| sglang version | SGLang package version | `v0.5.6` |\n| ktransformers version | KTransformers version | `v0.4.3` |\n| cpu info | CPU instruction set support | `x86-intel-multi` (includes AMX/AVX512/AVX2) |\n| gpu info | CUDA version | `cu128` (CUDA 12.8) |\n| functionality | Feature mode | `sft_llamafactory-v0.9.3` or `infer` |\n| timestamp | Build time (Beijing/UTC+8) | `20241212143022` |\n\n## Files\n\n| File | Purpose |\n|------|---------|\n| `Dockerfile` | Main Dockerfile with multi-CPU build and version extraction |\n| `docker-utils.sh` | Shared utility functions for both scripts |\n| `build-docker-tar.sh` | Build and export Docker image to tar file |\n| `push-to-dockerhub.sh` | Build and push Docker image to DockerHub |\n\n## 
Prerequisites\n\n- Docker installed and running\n- For DockerHub push: Docker Hub account and login (`docker login`)\n- Sufficient disk space (at least 20GB recommended)\n- Internet access (or local mirrors configured)\n\n## Quick Start\n\n### Build Local Tar File\n\n```bash\ncd docker\n\n# Basic build\n./build-docker-tar.sh\n\n# With specific CUDA version and mirror\n./build-docker-tar.sh \\\n  --cuda-version 12.8.1 \\\n  --ubuntu-mirror 1\n\n# With proxy\n./build-docker-tar.sh \\\n  --cuda-version 12.8.1 \\\n  --ubuntu-mirror 1 \\\n  --http-proxy \"http://127.0.0.1:16981\" \\\n  --https-proxy \"http://127.0.0.1:16981\" \\\n  --output-dir /path/to/output\n```\n\n### Push to DockerHub\n\n```bash\ncd docker\n\n# Basic push (requires --repository)\n./push-to-dockerhub.sh \\\n  --repository kvcache/ktransformers\n\n# With simplified tag\n./push-to-dockerhub.sh \\\n  --cuda-version 12.8.1 \\\n  --repository kvcache/ktransformers \\\n  --also-push-simplified\n\n# Skip build if image exists\n./push-to-dockerhub.sh \\\n  --repository kvcache/ktransformers \\\n  --skip-build\n```\n\n## Script Options\n\n### build-docker-tar.sh\n\n```\nBuild Configuration:\n  --cuda-version VERSION       CUDA version (default: 12.8.1)\n  --ubuntu-mirror 0|1         Use Tsinghua mirror (default: 0)\n  --http-proxy URL            HTTP proxy URL\n  --https-proxy URL           HTTPS proxy URL\n  --cpu-variant VARIANT       CPU variant (default: x86-intel-multi)\n  --functionality TYPE        Mode: sft or infer (default: sft)\n\nPaths:\n  --dockerfile PATH           Path to Dockerfile (default: ./Dockerfile)\n  --context-dir PATH          Build context directory (default: .)\n  --output-dir PATH           Output directory for tar (default: .)\n\nOptions:\n  --dry-run                   Preview without building\n  --keep-image                Keep Docker image after export\n  --build-arg KEY=VALUE       Additional build arguments\n  -h, --help                  Show help message\n```\n\n### 
push-to-dockerhub.sh\n\n```\nAll options from build-docker-tar.sh, plus:\n\nRegistry Settings:\n  --registry REGISTRY         Docker registry (default: docker.io)\n  --repository REPO           Repository name (REQUIRED)\n\nOptions:\n  --skip-build                Skip build if image exists\n  --also-push-simplified      Also push simplified tag\n  --max-retries N             Max push retries (default: 3)\n  --retry-delay SECONDS       Delay between retries (default: 5)\n```\n\n## Usage Examples\n\n### Example 1: Local Development Build\n\nFor testing on your local machine:\n\n```bash\n./build-docker-tar.sh \\\n  --cuda-version 12.8.1 \\\n  --output-dir ./builds \\\n  --keep-image\n```\n\nThis will:\n1. Build the Docker image\n2. Export to tar in `./builds/` directory\n3. Keep the Docker image for local testing\n\n### Example 2: Production Build for Distribution\n\nFor creating a production build with mirrors and proxy:\n\n```bash\n./build-docker-tar.sh \\\n  --cuda-version 12.8.1 \\\n  --ubuntu-mirror 1 \\\n  --http-proxy \"http://127.0.0.1:16981\" \\\n  --https-proxy \"http://127.0.0.1:16981\" \\\n  --output-dir /mnt/data/releases\n```\n\n### Example 3: Publish to DockerHub\n\nFor publishing to DockerHub:\n\n```bash\n# First, login to Docker Hub\ndocker login\n\n# Then push\n./push-to-dockerhub.sh \\\n  --cuda-version 12.8.1 \\\n  --repository kvcache/ktransformers \\\n  --also-push-simplified\n```\n\nThis creates two tags:\n- Full: `kvcache/ktransformers:sglang-v0.5.6_ktransformers-v0.4.3_x86-intel-multi_cu128_sft_llamafactory-v0.9.3_20241212143022`\n- Simplified: `kvcache/ktransformers:v0.4.3-cu128`\n\n### Example 4: Dry Run\n\nPreview the build without actually building:\n\n```bash\n./build-docker-tar.sh --cuda-version 12.8.1 --dry-run\n```\n\n### Example 5: Custom Build Arguments\n\nPass additional Docker build arguments:\n\n```bash\n./build-docker-tar.sh \\\n  --cuda-version 12.8.1 \\\n  --build-arg SGL_VERSION=0.5.7 \\\n  --build-arg 
FLASHINFER_VERSION=0.5.4\n```\n\n## Using the Built Images\n\n### Load from Tar File\n\n```bash\n# Load the image\ndocker load -i sglang-v0.5.6_ktransformers-v0.4.3_x86-intel-multi_cu128_sft_llamafactory-v0.9.3_20241212143022.tar\n\n# Run the container\ndocker run -it --rm \\\n  --gpus all \\\n  sglang-v0.5.6_ktransformers-v0.4.3_x86-intel-multi_cu128_sft_llamafactory-v0.9.3_20241212143022 \\\n  /bin/bash\n```\n\n### Pull from DockerHub\n\n```bash\n# Pull with full tag\ndocker pull kvcache/ktransformers:sglang-v0.5.6_ktransformers-v0.4.3_x86-intel-multi_cu128_sft_llamafactory-v0.9.3_20241212143022\n\n# Or pull with simplified tag\ndocker pull kvcache/ktransformers:v0.4.3-cu128\n\n# Run the container\ndocker run -it --rm \\\n  --gpus all \\\n  kvcache/ktransformers:v0.4.3-cu128 \\\n  /bin/bash\n```\n\n### Inside the Container\n\nThe image contains two conda environments:\n\n```bash\n# Activate serve environment (for inference with sglang)\nconda activate serve\n# or use the alias:\nserve\n\n# Activate fine-tune environment (for training with LLaMA-Factory)\nconda activate fine-tune\n# or use the alias:\nfinetune\n```\n\n## Multi-CPU Variant Support\n\nThe Docker image includes all three CPU variants:\n- **AMX** - For Intel Sapphire Rapids and newer (4th Gen Xeon+)\n- **AVX512** - For Intel Skylake-X, Ice Lake, Cascade Lake\n- **AVX2** - Maximum compatibility for older CPUs\n\nThe runtime automatically detects your CPU and loads the appropriate variant. 
To override:\n\n```bash\n# Force use of AVX2 variant\nexport KT_KERNEL_CPU_VARIANT=avx2\npython your_script.py\n\n# Enable debug output to see which variant is loaded\nexport KT_KERNEL_DEBUG=1\npython your_script.py\n```\n\n## Version Extraction\n\nVersions are automatically extracted during Docker build from:\n\n- **SGLang**: From `sglang.__version__` in serve environment\n- **KTransformers**: From `version.py` in ktransformers repository\n- **LLaMA-Factory**: From `llamafactory.__version__` in fine-tune environment\n\nThe versions are saved to `/workspace/versions.env` in the image:\n\n```bash\n# View versions in running container\ncat /workspace/versions.env\n\n# Output:\nSGLANG_VERSION=0.5.6\nKTRANSFORMERS_VERSION=0.4.3\nLLAMAFACTORY_VERSION=0.9.3\n```\n\n## Troubleshooting\n\n### Build Fails with Out of Disk Space\n\nCheck available disk space:\n```bash\ndf -h\n```\n\nThe build requires approximately 15-20GB of disk space. Clean up Docker:\n```bash\ndocker system prune -a\n```\n\n### Version Extraction Fails\n\nIf version extraction fails (shows \"unknown\"), check:\n\n1. The cloned repositories have the correct branches\n2. Python packages are properly installed in conda environments\n3. Version files exist in expected locations\n\nYou can manually verify by running:\n```bash\ndocker run --rm <image> /bin/bash -c \"\n  source /opt/miniconda3/etc/profile.d/conda.sh &&\n  conda activate serve &&\n  python -c 'import sglang; print(sglang.__version__)'\n\"\n```\n\n### Push to DockerHub Fails\n\n1. **Check login**: `docker login`\n2. **Check repository name**: Must include namespace (e.g., `kvcache/ktransformers`, not just `ktransformers`)\n3. **Network issues**: Use `--max-retries` and `--retry-delay` options\n4. 
**Rate limiting**: DockerHub has pull/push rate limits for free accounts\n\n## Advanced Topics\n\n### Custom Dockerfile Location\n\n```bash\n./build-docker-tar.sh \\\n  --dockerfile /path/to/custom/Dockerfile \\\n  --context-dir /path/to/build/context\n```\n\n### Building Only Inference Image\n\nPass `--functionality infer` to either script to build an inference-only image. The Dockerfile runs the fine-tune installation steps (PyTorch, LLaMA-Factory, and the ktransformers and flash_attn wheels for the fine-tune environment) only when `FUNCTIONALITY` is `sft`, which is the default.\n\n### Customizing CPU Variants\n\nTo build only specific CPU variants, modify `kt-kernel/install.sh` or set environment variables in the Dockerfile.\n\n### CI/CD Integration\n\nThe scripts are designed for manual execution but can be integrated into CI/CD pipelines:\n\n```yaml\n# Example GitHub Actions workflow\n- name: Build and push Docker image\n  run: |\n    cd docker\n    ./push-to-dockerhub.sh \\\n      --cuda-version ${{ matrix.cuda_version }} \\\n      --repository ${{ secrets.DOCKER_REPOSITORY }} \\\n      --also-push-simplified\n```\n\n## Support\n\nFor issues and questions:\n- File an issue at: https://github.com/kvcache-ai/ktransformers/issues\n- Check documentation: https://github.com/kvcache-ai/ktransformers\n\n## License\n\nThis packaging system is part of KTransformers and follows the same license.\n"
  },
  {
    "path": "docker/docker-utils.sh",
    "content": "#!/usr/bin/env bash\n#\n# docker-utils.sh - Shared utility functions for Docker image build and publish scripts\n#\n# This script provides common functions for:\n# - Timestamp generation (Beijing timezone)\n# - Version extraction from Docker images\n# - Image name generation following naming conventions\n# - Colored logging\n# - Validation and error handling\n#\n# Usage: source docker-utils.sh\n\nset -euo pipefail\n\n# Color codes for logging\nCOLOR_RED='\\033[0;31m'\nCOLOR_GREEN='\\033[0;32m'\nCOLOR_YELLOW='\\033[1;33m'\nCOLOR_BLUE='\\033[0;34m'\nCOLOR_CYAN='\\033[0;36m'\nCOLOR_RESET='\\033[0m'\n\n################################################################################\n# Logging Functions\n################################################################################\n\nlog_info() {\n    echo -e \"${COLOR_BLUE}[INFO]${COLOR_RESET} $*\"\n}\n\nlog_success() {\n    echo -e \"${COLOR_GREEN}[SUCCESS]${COLOR_RESET} $*\"\n}\n\nlog_warning() {\n    echo -e \"${COLOR_YELLOW}[WARNING]${COLOR_RESET} $*\"\n}\n\nlog_error() {\n    echo -e \"${COLOR_RED}[ERROR]${COLOR_RESET} $*\" >&2\n}\n\nlog_step() {\n    echo -e \"\\n${COLOR_CYAN}==>${COLOR_RESET} $*\"\n}\n\n################################################################################\n# Timestamp Functions\n################################################################################\n\n# Generate timestamp in Beijing timezone (UTC+8)\n# Format: YYYYMMDDHHMMSS\n# Example: 20241212143022\nget_beijing_timestamp() {\n    # Try to use TZ environment variable approach\n    if date --version &>/dev/null 2>&1; then\n        # GNU date (Linux)\n        TZ='Asia/Shanghai' date '+%Y%m%d%H%M%S'\n    else\n        # BSD date (macOS)\n        TZ='Asia/Shanghai' date '+%Y%m%d%H%M%S'\n    fi\n}\n\n################################################################################\n# CUDA Version Parsing\n################################################################################\n\n# Parse CUDA version 
to short format\n# Input: 12.8.1 or 12.8 or 13.0.1\n# Output: cu128 or cu130\nparse_cuda_short_version() {\n    local cuda_version=\"$1\"\n\n    # Extract major and minor version\n    local major minor\n    major=$(echo \"$cuda_version\" | cut -d. -f1)\n    minor=$(echo \"$cuda_version\" | cut -d. -f2)\n\n    # Validate\n    if [[ ! \"$major\" =~ ^[0-9]+$ ]] || [[ ! \"$minor\" =~ ^[0-9]+$ ]]; then\n        log_error \"Invalid CUDA version format: $cuda_version\"\n        log_error \"Expected format: X.Y.Z (e.g., 12.8.1)\"\n        return 1\n    fi\n\n    echo \"cu${major}${minor}\"\n}\n\n################################################################################\n# Version Extraction\n################################################################################\n\n# Extract versions from built Docker image\n# Input: image tag (e.g., ktransformers:temp-build-20241212)\n# Output: Sets environment variables or prints to stdout\n#   SGLANG_VERSION=x.y.z\n#   KTRANSFORMERS_VERSION=x.y.z\n#   LLAMAFACTORY_VERSION=x.y.z\nextract_versions_from_image() {\n    local image_tag=\"$1\"\n\n    log_step \"Extracting versions from image: $image_tag\"\n\n    # Check if image exists\n    if ! 
docker image inspect \"$image_tag\" &>/dev/null; then\n        log_error \"Image not found: $image_tag\"\n        return 1\n    fi\n\n    # Extract versions.env file from the image\n    local versions_content\n    versions_content=$(docker run --rm \"$image_tag\" cat /workspace/versions.env 2>/dev/null)\n\n    if [ -z \"$versions_content\" ]; then\n        log_error \"Failed to extract versions from image\"\n        log_error \"The /workspace/versions.env file may not exist in the image\"\n        return 1\n    fi\n\n    # Parse and display versions\n    log_info \"Extracted versions:\"\n    echo \"$versions_content\" | while IFS= read -r line; do\n        log_info \"  $line\"\n    done\n\n    # Output the content (caller can parse this or eval it)\n    echo \"$versions_content\"\n}\n\n# Validate that all required versions were extracted\n# Input: versions string (output from extract_versions_from_image)\nvalidate_versions() {\n    local versions=\"$1\"\n    local all_valid=true\n\n    # Check each required version\n    for var in SGLANG_VERSION KTRANSFORMERS_VERSION LLAMAFACTORY_VERSION; do\n        local value\n        value=$(echo \"$versions\" | grep \"^${var}=\" | cut -d= -f2)\n\n        if [ -z \"$value\" ]; then\n            log_error \"Missing version: $var\"\n            all_valid=false\n        elif [ \"$value\" = \"unknown\" ]; then\n            log_warning \"Version is 'unknown': $var\"\n            # Don't fail, but warn user\n        fi\n    done\n\n    if [ \"$all_valid\" = false ]; then\n        return 1\n    fi\n\n    return 0\n}\n\n################################################################################\n# Image Naming\n################################################################################\n\n# Generate standardized image name\n# Input:\n#   $1: versions string (from extract_versions_from_image)\n#   $2: cuda_version (e.g., 12.8.1)\n#   $3: cpu_variant (e.g., x86-intel-multi)\n#   $4: functionality (e.g., sft_llamafactory or 
infer)\n#   $5: timestamp (optional, will generate if not provided)\n# Output: Standardized image name\n# Format: sglang-v{ver}_ktransformers-v{ver}_{cpu}_{gpu}_{func}_{timestamp}\ngenerate_image_name() {\n    local versions=\"$1\"\n    local cuda_version=\"$2\"\n    local cpu_variant=\"$3\"\n    local functionality=\"$4\"\n    local timestamp=\"${5:-$(get_beijing_timestamp)}\"\n\n    # Parse versions from the versions string\n    local sglang_ver ktrans_ver llama_ver\n    sglang_ver=$(echo \"$versions\" | grep \"^SGLANG_VERSION=\" | cut -d= -f2)\n    ktrans_ver=$(echo \"$versions\" | grep \"^KTRANSFORMERS_VERSION=\" | cut -d= -f2)\n    llama_ver=$(echo \"$versions\" | grep \"^LLAMAFACTORY_VERSION=\" | cut -d= -f2)\n\n    # Validate versions were extracted\n    if [ -z \"$sglang_ver\" ] || [ -z \"$ktrans_ver\" ]; then\n        log_error \"Failed to parse versions from input\"\n        return 1\n    fi\n\n    # Parse CUDA short version\n    local cuda_short\n    cuda_short=$(parse_cuda_short_version \"$cuda_version\")\n\n    # Build functionality string\n    local func_str\n    if [ \"$functionality\" = \"sft\" ]; then\n        func_str=\"sft_llamafactory-v${llama_ver}\"\n    else\n        func_str=\"infer\"\n    fi\n\n    # Generate full image name\n    # Format: sglang-v{ver}_ktransformers-v{ver}_{cpu}_{gpu}_{func}_{timestamp}\n    local image_name\n    image_name=\"sglang-v${sglang_ver}_ktransformers-v${ktrans_ver}_${cpu_variant}_${cuda_short}_${func_str}_${timestamp}\"\n\n    echo \"$image_name\"\n}\n\n# Generate simplified tag for DockerHub\n# Input:\n#   $1: ktransformers_version (e.g., 0.4.3)\n#   $2: cuda_version (e.g., 12.8.1)\n# Output: Simplified tag (e.g., v0.4.3-cu128)\ngenerate_simplified_tag() {\n    local ktrans_ver=\"$1\"\n    local cuda_version=\"$2\"\n\n    local cuda_short\n    cuda_short=$(parse_cuda_short_version \"$cuda_version\")\n\n    echo 
\"v${ktrans_ver}-${cuda_short}\"\n}\n\n################################################################################\n# Validation Functions\n################################################################################\n\n# Check if Docker daemon is running\ncheck_docker_running() {\n    if ! docker info &>/dev/null; then\n        log_error \"Docker daemon is not running\"\n        log_error \"Please start Docker and try again\"\n        return 1\n    fi\n    return 0\n}\n\n# Check if user is logged into Docker registry\n# Input: registry (optional, default: docker.io)\ncheck_docker_login() {\n    local registry=\"${1:-docker.io}\"\n\n    # Try to check auth by attempting a trivial operation\n    if ! docker login --help &>/dev/null; then\n        log_error \"Docker CLI is not available\"\n        return 1\n    fi\n\n    # Note: This is a best-effort check\n    # docker login status is not always easy to check programmatically\n    log_info \"Assuming Docker login is configured\"\n    log_info \"If push fails, please run: docker login $registry\"\n\n    return 0\n}\n\n# Validate CUDA version format\nvalidate_cuda_version() {\n    local cuda_version=\"$1\"\n\n    if [[ ! 
\"$cuda_version\" =~ ^[0-9]+\\.[0-9]+(\\.[0-9]+)?$ ]]; then\n        log_error \"Invalid CUDA version format: $cuda_version\"\n        log_error \"Expected format: X.Y or X.Y.Z (e.g., 12.8 or 12.8.1)\"\n        return 1\n    fi\n\n    return 0\n}\n\n# Check available disk space\n# Input: required space in GB\ncheck_disk_space() {\n    local required_gb=\"$1\"\n    local output_dir=\"${2:-.}\"\n\n    # Get available space in GB (works on Linux and macOS)\n    local available_kb\n    if df -k \"$output_dir\" &>/dev/null; then\n        available_kb=$(df -k \"$output_dir\" | tail -1 | awk '{print $4}')\n        local available_gb=$((available_kb / 1024 / 1024))\n\n        log_info \"Available disk space: ${available_gb}GB\"\n\n        if [ \"$available_gb\" -lt \"$required_gb\" ]; then\n            log_warning \"Low disk space: ${available_gb}GB available, ${required_gb}GB recommended\"\n            return 1\n        fi\n    else\n        log_warning \"Unable to check disk space\"\n    fi\n\n    return 0\n}\n\n# Check if file/directory exists and is writable\ncheck_writable() {\n    local path=\"$1\"\n\n    if [ -e \"$path\" ]; then\n        if [ ! -w \"$path\" ]; then\n            log_error \"Path exists but is not writable: $path\"\n            return 1\n        fi\n    else\n        # Try to create parent directory to test writability\n        local parent_dir\n        parent_dir=$(dirname \"$path\")\n        if [ ! 
-w \"$parent_dir\" ]; then\n            log_error \"Parent directory is not writable: $parent_dir\"\n            return 1\n        fi\n    fi\n\n    return 0\n}\n\n################################################################################\n# Cleanup Functions\n################################################################################\n\n# Remove intermediate Docker images\ncleanup_temp_images() {\n    local image_tag=\"$1\"\n\n    log_step \"Cleaning up temporary image: $image_tag\"\n\n    if docker image inspect \"$image_tag\" &>/dev/null; then\n        docker rmi \"$image_tag\" &>/dev/null || true\n        log_success \"Cleaned up temporary image\"\n    fi\n}\n\n################################################################################\n# Display Functions\n################################################################################\n\n# Display a summary box\ndisplay_summary() {\n    local title=\"$1\"\n    shift\n    local lines=(\"$@\")\n\n    local width=80\n    local border=$(printf '=%.0s' $(seq 1 $width))\n\n    echo \"\"\n    echo \"$border\"\n    echo \"  $title\"\n    echo \"$border\"\n    for line in \"${lines[@]}\"; do\n        echo \"  $line\"\n    done\n    echo \"$border\"\n    echo \"\"\n}\n\n################################################################################\n# Export functions\n################################################################################\n\n# Export all functions so they can be used by scripts that source this file\nexport -f log_info log_success log_warning log_error log_step\nexport -f get_beijing_timestamp\nexport -f parse_cuda_short_version\nexport -f extract_versions_from_image validate_versions\nexport -f generate_image_name generate_simplified_tag\nexport -f check_docker_running check_docker_login validate_cuda_version\nexport -f check_disk_space check_writable\nexport -f cleanup_temp_images\nexport -f display_summary\n"
  },
  {
    "path": "docker/push-to-dockerhub.sh",
    "content": "#!/usr/bin/env bash\n#\n# push-to-dockerhub.sh - Build and push Docker image to DockerHub\n#\n# This script builds a Docker image for ktransformers with standardized naming\n# and pushes it to DockerHub with both full and simplified tags.\n#\n# Features:\n# - Automatic version detection\n# - Standardized naming convention\n# - Multi-CPU variant support (AMX/AVX512/AVX2)\n# - Full and simplified tag support\n# - Retry logic for network failures\n# - Comprehensive error handling\n#\n# Usage:\n#   ./push-to-dockerhub.sh [OPTIONS]\n#\n# Example:\n#   ./push-to-dockerhub.sh \\\n#     --cuda-version 12.8.1 \\\n#     --repository kvcache/ktransformers \\\n#     --also-push-simplified\n\nset -euo pipefail\n\n# Get script directory\nSCRIPT_DIR=\"$(cd \"$(dirname \"${BASH_SOURCE[0]}\")\" && pwd)\"\n\n# Source utility functions\n# shellcheck source=docker-utils.sh\nsource \"$SCRIPT_DIR/docker-utils.sh\"\n\n################################################################################\n# Default Configuration\n################################################################################\n\n# Build parameters\nCUDA_VERSION=\"12.8.1\"\nUBUNTU_MIRROR=\"0\"\nHTTP_PROXY=\"\"\nHTTPS_PROXY=\"\"\nCPU_VARIANT=\"x86-intel-multi\"\nFUNCTIONALITY=\"sft\"\n\n# Paths\nDOCKERFILE=\"$SCRIPT_DIR/Dockerfile\"\nCONTEXT_DIR=\"$SCRIPT_DIR\"\n\n# Registry settings\nREGISTRY=\"docker.io\"\nREPOSITORY=\"\"  # Must be provided by user\n\n# Options\nDRY_RUN=false\nSKIP_BUILD=false\nALSO_PUSH_SIMPLIFIED=false\nMAX_RETRIES=3\nRETRY_DELAY=5\nEXTRA_BUILD_ARGS=()\n\n################################################################################\n# Help Message\n################################################################################\n\nusage() {\n    cat <<EOF\nUsage: $0 [OPTIONS]\n\nBuild and push Docker image to DockerHub with standardized naming.\n\nOPTIONS:\n    Build Configuration:\n        --cuda-version VERSION      CUDA version (default: 12.8.1)\n                        
           Examples: 12.8.1, 12.6.1, 13.0.1\n\n        --ubuntu-mirror 0|1         Use Tsinghua mirror for Ubuntu packages\n                                   (default: 0)\n\n        --http-proxy URL           HTTP proxy URL\n                                   Example: http://127.0.0.1:16981\n\n        --https-proxy URL          HTTPS proxy URL\n                                   Example: http://127.0.0.1:16981\n\n        --cpu-variant VARIANT      CPU variant identifier\n                                   (default: x86-intel-multi)\n\n        --functionality TYPE       Functionality mode: sft or infer\n                                   (default: sft, includes LLaMA-Factory)\n\n    Paths:\n        --dockerfile PATH          Path to Dockerfile\n                                   (default: ./Dockerfile)\n\n        --context-dir PATH         Docker build context directory\n                                   (default: .)\n\n    Registry Settings:\n        --registry REGISTRY        Docker registry (default: docker.io)\n                                   Examples: docker.io, ghcr.io\n\n        --repository REPO          Repository name (REQUIRED)\n                                   Example: kvcache/ktransformers\n\n    Options:\n        --skip-build               Skip build if image exists locally\n        --also-push-simplified     Also push simplified tag (v{ver}-{cuda})\n        --max-retries N            Maximum push retries (default: 3)\n        --retry-delay SECONDS      Delay between retries (default: 5)\n        --dry-run                  Preview commands without executing\n        --build-arg KEY=VALUE      Additional build arguments (can be repeated)\n        -h, --help                 Show this help message\n\nEXAMPLES:\n    # Basic push\n    $0 --repository kvcache/ktransformers\n\n    # Push with simplified tag\n    $0 \\\\\n        --repository kvcache/ktransformers \\\\\n        --cuda-version 12.8.1 \\\\\n        --also-push-simplified\n\n    # Skip 
build if image exists\n    $0 \\\\\n        --repository kvcache/ktransformers \\\\\n        --skip-build\n\n    # Dry run to preview\n    $0 --repository kvcache/ktransformers --dry-run\n\nOUTPUT:\n    The image will be pushed with tags:\n\n    Full tag:\n      {registry}/{repository}:sglang-v{ver}_ktransformers-v{ver}_{cpu}_{gpu}_{func}_{timestamp}\n\n    Example:\n      docker.io/kvcache/ktransformers:sglang-v0.5.6_ktransformers-v0.4.3_x86-intel-multi_cu128_sft_llamafactory-v0.9.3_20241212143022\n\n    Simplified tag (if --also-push-simplified):\n      {registry}/{repository}:v{ktransformers-ver}-{cuda}\n\n    Example:\n      docker.io/kvcache/ktransformers:v0.4.3-cu128\n\nEOF\n    exit 0\n}\n\n################################################################################\n# Argument Parsing\n################################################################################\n\nparse_args() {\n    while [[ $# -gt 0 ]]; do\n        case \"$1\" in\n            --cuda-version)\n                CUDA_VERSION=\"$2\"\n                shift 2\n                ;;\n            --ubuntu-mirror)\n                UBUNTU_MIRROR=\"$2\"\n                shift 2\n                ;;\n            --http-proxy)\n                HTTP_PROXY=\"$2\"\n                shift 2\n                ;;\n            --https-proxy)\n                HTTPS_PROXY=\"$2\"\n                shift 2\n                ;;\n            --cpu-variant)\n                CPU_VARIANT=\"$2\"\n                shift 2\n                ;;\n            --functionality)\n                FUNCTIONALITY=\"$2\"\n                shift 2\n                ;;\n            --dockerfile)\n                DOCKERFILE=\"$2\"\n                shift 2\n                ;;\n            --context-dir)\n                CONTEXT_DIR=\"$2\"\n                shift 2\n                ;;\n            --registry)\n                REGISTRY=\"$2\"\n                shift 2\n                ;;\n            --repository)\n                
REPOSITORY=\"$2\"\n                shift 2\n                ;;\n            --skip-build)\n                SKIP_BUILD=true\n                shift\n                ;;\n            --also-push-simplified)\n                ALSO_PUSH_SIMPLIFIED=true\n                shift\n                ;;\n            --max-retries)\n                MAX_RETRIES=\"$2\"\n                shift 2\n                ;;\n            --retry-delay)\n                RETRY_DELAY=\"$2\"\n                shift 2\n                ;;\n            --dry-run)\n                DRY_RUN=true\n                shift\n                ;;\n            --build-arg)\n                EXTRA_BUILD_ARGS+=(\"--build-arg\" \"$2\")\n                shift 2\n                ;;\n            -h|--help)\n                usage\n                ;;\n            *)\n                log_error \"Unknown option: $1\"\n                echo \"Use -h or --help for usage information\"\n                exit 1\n                ;;\n        esac\n    done\n}\n\n################################################################################\n# Validation\n################################################################################\n\nvalidate_config() {\n    log_step \"Validating configuration\"\n\n    # Check Docker is running\n    check_docker_running || exit 1\n\n    # Check Docker login\n    check_docker_login \"$REGISTRY\" || exit 1\n\n    # Validate CUDA version\n    validate_cuda_version \"$CUDA_VERSION\" || exit 1\n\n    # Check repository is provided\n    if [ -z \"$REPOSITORY\" ]; then\n        log_error \"Repository name is required\"\n        log_error \"Use --repository to specify (e.g., kvcache/ktransformers)\"\n        exit 1\n    fi\n    log_info \"Target repository: $REGISTRY/$REPOSITORY\"\n\n    # Check Dockerfile exists\n    if [ ! 
-f \"$DOCKERFILE\" ]; then\n        log_error \"Dockerfile not found: $DOCKERFILE\"\n        exit 1\n    fi\n    log_info \"Using Dockerfile: $DOCKERFILE\"\n\n    # Check context directory exists\n    if [ ! -d \"$CONTEXT_DIR\" ]; then\n        log_error \"Context directory not found: $CONTEXT_DIR\"\n        exit 1\n    fi\n    log_info \"Using context directory: $CONTEXT_DIR\"\n\n    # Validate functionality mode\n    if [[ \"$FUNCTIONALITY\" != \"sft\" && \"$FUNCTIONALITY\" != \"infer\" ]]; then\n        log_error \"Invalid functionality mode: $FUNCTIONALITY\"\n        log_error \"Must be 'sft' or 'infer'\"\n        exit 1\n    fi\n\n    log_success \"Configuration validated\"\n}\n\n################################################################################\n# Build Docker Image\n################################################################################\n\nbuild_image() {\n    local temp_tag=\"ktransformers:temp-push-$(get_beijing_timestamp)\"\n\n    # Check if we should skip build\n    # NOTE: stdout of this function is captured by the caller as the image tag,\n    # so all log output must go to stderr.\n    if [ \"$SKIP_BUILD\" = true ]; then\n        log_info \"Checking for existing local image...\" >&2\n        # Try to find an existing image\n        # This is a best-effort search for recent builds\n        local existing_image\n        existing_image=$(docker images --format \"{{.Repository}}:{{.Tag}}\" | grep \"ktransformers:temp-\" | head -1 || echo \"\")\n\n        if [ -n \"$existing_image\" ]; then\n            log_info \"Found existing image: $existing_image\" >&2\n            echo \"$existing_image\"\n            return 0\n        else\n            log_warning \"No existing image found, will build\" >&2\n        fi\n    fi\n\n    log_step \"Building Docker image\" >&2\n    log_info \"Temporary tag: $temp_tag\" >&2\n\n    # Prepare build arguments\n    local build_args=()\n    build_args+=(\"--build-arg\" \"CUDA_VERSION=$CUDA_VERSION\")\n    build_args+=(\"--build-arg\" \"UBUNTU_MIRROR=$UBUNTU_MIRROR\")\n    build_args+=(\"--build-arg\" \"CPU_VARIANT=$CPU_VARIANT\")\n    
build_args+=(\"--build-arg\" \"BUILD_ALL_CPU_VARIANTS=1\")\n    build_args+=(\"--build-arg\" \"FUNCTIONALITY=$FUNCTIONALITY\")\n\n    # Add proxy settings if provided\n    if [ -n \"$HTTP_PROXY\" ]; then\n        build_args+=(\"--build-arg\" \"HTTP_PROXY=$HTTP_PROXY\")\n    fi\n    if [ -n \"$HTTPS_PROXY\" ]; then\n        build_args+=(\"--build-arg\" \"HTTPS_PROXY=$HTTPS_PROXY\")\n    fi\n\n    # Add extra build args\n    build_args+=(\"${EXTRA_BUILD_ARGS[@]}\")\n\n    # Add network host\n    build_args+=(\"--network\" \"host\")\n\n    # Build command\n    local build_cmd=(\n        docker build\n        -f \"$DOCKERFILE\"\n        \"${build_args[@]}\"\n        -t \"$temp_tag\"\n        \"$CONTEXT_DIR\"\n    )\n\n    # Display build command\n    {\n        log_info \"Build command:\"\n        echo \"  ${build_cmd[*]}\"\n    } >&2\n\n    if [ \"$DRY_RUN\" = true ]; then\n        log_warning \"DRY RUN: Skipping actual build\" >&2\n        return 0\n    fi\n\n    # Execute build\n    log_info \"Starting Docker build (this may take 30-60 minutes)...\" >&2\n    if \"${build_cmd[@]}\" >&2; then\n        log_success \"Docker image built successfully\" >&2\n        echo \"$temp_tag\"\n    else\n        log_error \"Docker build failed\" >&2\n        exit 1\n    fi\n}\n\n################################################################################\n# Generate Tags\n################################################################################\n\ngenerate_tags() {\n    local image_tag=\"$1\"\n    local timestamp=\"$2\"\n\n    if [ \"$DRY_RUN\" = true ]; then\n        log_warning \"DRY RUN: Using placeholder versions\"\n        # Use placeholder versions for dry run\n        local versions=\"SGLANG_VERSION=0.5.6\nKTRANSFORMERS_VERSION=0.4.3\nLLAMAFACTORY_VERSION=0.9.3\"\n    else\n        # Extract versions from image\n        local versions\n        versions=$(extract_versions_from_image \"$image_tag\")\n\n        if [ $? 
-ne 0 ]; then\n            log_error \"Failed to extract versions from image\"\n            exit 1\n        fi\n\n        # Validate versions\n        if ! validate_versions \"$versions\"; then\n            log_error \"Version validation failed\"\n            exit 1\n        fi\n    fi\n\n    # Generate full tag\n    local full_tag\n    full_tag=$(generate_image_name \"$versions\" \"$CUDA_VERSION\" \"$CPU_VARIANT\" \"$FUNCTIONALITY\" \"$timestamp\")\n\n    if [ -z \"$full_tag\" ]; then\n        log_error \"Failed to generate image name\"\n        exit 1\n    fi\n\n    echo \"FULL_TAG=$full_tag\"\n\n    # Generate simplified tag if requested\n    if [ \"$ALSO_PUSH_SIMPLIFIED\" = true ]; then\n        local ktrans_ver\n        ktrans_ver=$(echo \"$versions\" | grep \"^KTRANSFORMERS_VERSION=\" | cut -d= -f2)\n\n        local simplified_tag\n        simplified_tag=$(generate_simplified_tag \"$ktrans_ver\" \"$CUDA_VERSION\")\n\n        echo \"SIMPLIFIED_TAG=$simplified_tag\"\n    fi\n}\n\n################################################################################\n# Push to Registry\n################################################################################\n\npush_image_with_retry() {\n    local source_tag=\"$1\"\n    local target_tag=\"$2\"\n    local attempt=1\n\n    log_step \"Pushing image: $target_tag\"\n\n    if [ \"$DRY_RUN\" = true ]; then\n        log_warning \"DRY RUN: Skipping actual push\"\n        log_info \"Would execute:\"\n        echo \"  docker tag $source_tag $target_tag\"\n        echo \"  docker push $target_tag\"\n        return 0\n    fi\n\n    # Tag the image\n    log_info \"Tagging image...\"\n    if ! 
docker tag \"$source_tag\" \"$target_tag\"; then\n        log_error \"Failed to tag image\"\n        return 1\n    fi\n\n    # Push with retry logic\n    while [ $attempt -le \"$MAX_RETRIES\" ]; do\n        log_info \"Push attempt $attempt/$MAX_RETRIES...\"\n\n        if docker push \"$target_tag\"; then\n            log_success \"Successfully pushed: $target_tag\"\n            return 0\n        else\n            log_warning \"Push failed (attempt $attempt/$MAX_RETRIES)\"\n\n            if [ $attempt -lt \"$MAX_RETRIES\" ]; then\n                log_info \"Retrying in ${RETRY_DELAY} seconds...\"\n                sleep \"$RETRY_DELAY\"\n            fi\n\n            ((attempt++))\n        fi\n    done\n\n    log_error \"Failed to push after $MAX_RETRIES attempts\"\n    return 1\n}\n\n################################################################################\n# Main\n################################################################################\n\nmain() {\n    log_step \"KTransformers Docker Image Build and Push\"\n\n    # Parse arguments\n    parse_args \"$@\"\n\n    # Validate configuration\n    validate_config\n\n    # Generate timestamp\n    TIMESTAMP=$(get_beijing_timestamp)\n    log_info \"Build timestamp: $TIMESTAMP\"\n\n    # Display configuration\n    display_summary \"Push Configuration\" \\\n        \"CUDA Version: $CUDA_VERSION\" \\\n        \"Ubuntu Mirror: $UBUNTU_MIRROR\" \\\n        \"CPU Variant: $CPU_VARIANT\" \\\n        \"Functionality: $FUNCTIONALITY\" \\\n        \"Registry: $REGISTRY\" \\\n        \"Repository: $REPOSITORY\" \\\n        \"Push Simplified: $ALSO_PUSH_SIMPLIFIED\" \\\n        \"Skip Build: $SKIP_BUILD\" \\\n        \"HTTP Proxy: ${HTTP_PROXY:-<not set>}\" \\\n        \"HTTPS Proxy: ${HTTPS_PROXY:-<not set>}\" \\\n        \"Dockerfile: $DOCKERFILE\" \\\n        \"Context Dir: $CONTEXT_DIR\" \\\n        \"Timestamp: $TIMESTAMP\" \\\n        \"Dry Run: $DRY_RUN\"\n\n    # Build image\n    TEMP_TAG=$(build_image)\n\n    if [ 
\"$DRY_RUN\" = true ]; then\n        TEMP_TAG=\"ktransformers:temp-dryrun\"\n    fi\n\n    # Generate tags\n    log_step \"Generating tags\"\n    TAG_INFO=$(generate_tags \"$TEMP_TAG\" \"$TIMESTAMP\")\n\n    # Parse tag info\n    FULL_TAG=$(echo \"$TAG_INFO\" | grep \"^FULL_TAG=\" | cut -d= -f2)\n    SIMPLIFIED_TAG=$(echo \"$TAG_INFO\" | grep \"^SIMPLIFIED_TAG=\" | cut -d= -f2 || echo \"\")\n\n    log_info \"Full tag: $FULL_TAG\"\n    if [ -n \"$SIMPLIFIED_TAG\" ]; then\n        log_info \"Simplified tag: $SIMPLIFIED_TAG\"\n    fi\n\n    # Push full tag\n    FULL_IMAGE=\"$REGISTRY/$REPOSITORY:$FULL_TAG\"\n    if ! push_image_with_retry \"$TEMP_TAG\" \"$FULL_IMAGE\"; then\n        log_error \"Failed to push full tag\"\n        exit 1\n    fi\n\n    # Push simplified tag if requested\n    if [ -n \"$SIMPLIFIED_TAG\" ]; then\n        SIMPLIFIED_IMAGE=\"$REGISTRY/$REPOSITORY:$SIMPLIFIED_TAG\"\n        if ! push_image_with_retry \"$TEMP_TAG\" \"$SIMPLIFIED_IMAGE\"; then\n            log_warning \"Failed to push simplified tag, but continuing...\"\n        fi\n    fi\n\n    # Cleanup temporary image\n    if [ \"$DRY_RUN\" = false ]; then\n        log_step \"Cleaning up temporary image\"\n        cleanup_temp_images \"$TEMP_TAG\"\n    fi\n\n    # Display summary\n    local summary_lines=(\n        \"Successfully pushed images:\"\n        \"\"\n        \"Full tag:\"\n        \"  $FULL_IMAGE\"\n        \"\"\n    )\n\n    if [ -n \"$SIMPLIFIED_TAG\" ]; then\n        summary_lines+=(\n            \"Simplified tag:\"\n            \"  $SIMPLIFIED_IMAGE\"\n            \"\"\n        )\n    fi\n\n    summary_lines+=(\n        \"To pull the image:\"\n        \"  docker pull $FULL_IMAGE\"\n        \"\"\n        \"To run the container:\"\n        \"  docker run -it --rm $FULL_IMAGE /bin/bash\"\n    )\n\n    display_summary \"Push Complete\" \"${summary_lines[@]}\"\n\n    log_success \"All done!\"\n}\n\n# Run main function\nmain \"$@\"\n"
  },
  {
    "path": "install.sh",
    "content": "#!/usr/bin/env bash\nset -euo pipefail\n\n# Resolve the repository root (directory containing this script)\nREPO_ROOT=\"$(cd \"$(dirname \"${BASH_SOURCE[0]}\")\" && pwd)\"\n\nusage() {\n  cat <<EOF\nUsage: $0 [SUBCOMMAND] [OPTIONS]\n\nOne-click installer for ktransformers (sglang + kt-kernel).\n\nSUBCOMMANDS:\n  all             Full install: submodules → sglang → kt-kernel (default)\n  sglang          Install sglang only\n  kt-kernel       Install kt-kernel only\n  deps            Install system dependencies only\n  -h, --help      Show this help message\n\nOPTIONS:\n  --skip-sglang       Skip sglang installation (for \"all\" subcommand)\n  --skip-kt-kernel    Skip kt-kernel installation (for \"all\" subcommand)\n  --editable          Install sglang in editable/dev mode (-e)\n  --manual            Pass through to kt-kernel (manual CPU config)\n  --no-clean          Pass through to kt-kernel (skip build clean)\n\nEXAMPLES:\n  # Full install (recommended)\n  $0\n\n  # Install everything in editable mode for development\n  $0 all --editable\n\n  # Install sglang only\n  $0 sglang\n\n  # Install kt-kernel only (manual CPU config)\n  $0 kt-kernel --manual\n\n  # Full install, skip sglang (already installed)\n  $0 all --skip-sglang\n\nEOF\n  exit 1\n}\n\n# ─── Helpers ───────────────────────────────────────────────────────────────────\n\nlog_step() {\n  echo \"\"\n  echo \"==========================================\"\n  echo \"  $1\"\n  echo \"==========================================\"\n  echo \"\"\n}\n\nlog_info() {\n  echo \"[INFO] $1\"\n}\n\nlog_warn() {\n  echo \"[WARN] $1\"\n}\n\nlog_error() {\n  echo \"[ERROR] $1\" >&2\n}\n\n# Read ktransformers version from version.py and export for sglang-kt\nread_kt_version() {\n  local version_file=\"$REPO_ROOT/version.py\"\n  if [ -f \"$version_file\" ]; then\n    KT_VERSION=$(python3 -c \"exec(open('$version_file').read()); print(__version__)\")\n    export SGLANG_KT_VERSION=\"$KT_VERSION\"\n    log_info 
\"ktransformers version: $KT_VERSION (will be used for sglang-kt)\"\n  else\n    log_warn \"version.py not found; sglang-kt will use its default version\"\n  fi\n}\n\n# ─── Submodule init ────────────────────────────────────────────────────────────\n\ninit_submodules() {\n  log_step \"Initializing git submodules\"\n\n  if [ ! -d \"$REPO_ROOT/.git\" ]; then\n    log_warn \"Not a git repository. Skipping submodule init.\"\n    log_warn \"If you need sglang, clone with: git clone --recursive https://github.com/kvcache-ai/ktransformers.git\"\n    return 0\n  fi\n\n  cd \"$REPO_ROOT\"\n  git submodule update --init --recursive\n  log_info \"Submodules initialized successfully.\"\n}\n\n# ─── sglang install ───────────────────────────────────────────────────────────\n\ninstall_sglang() {\n  local editable=\"${1:-0}\"\n\n  log_step \"Installing sglang (kvcache-ai fork)\"\n\n  local sglang_dir=\"$REPO_ROOT/third_party/sglang\"\n  local pyproject=\"$sglang_dir/python/pyproject.toml\"\n\n  if [ ! -f \"$pyproject\" ]; then\n    log_error \"sglang source not found at $sglang_dir\"\n    log_error \"Run 'git submodule update --init --recursive' first, or clone with --recursive.\"\n    exit 1\n  fi\n\n  cd \"$sglang_dir\"\n\n  if [ \"$editable\" = \"1\" ]; then\n    log_info \"Installing sglang in editable mode...\"\n    pip install -e \"./python[all]\"\n  else\n    log_info \"Installing sglang...\"\n    pip install \"./python[all]\"\n  fi\n\n  log_info \"sglang installed successfully.\"\n}\n\n# ─── kt-kernel install ────────────────────────────────────────────────────────\n\ninstall_kt_kernel() {\n  # Forward all remaining args to kt-kernel/install.sh\n  local kt_args=(\"$@\")\n\n  log_step \"Installing kt-kernel\"\n\n  local kt_install=\"$REPO_ROOT/kt-kernel/install.sh\"\n\n  if [ ! 
-f \"$kt_install\" ]; then\n    log_error \"kt-kernel/install.sh not found at $kt_install\"\n    exit 1\n  fi\n\n  cd \"$REPO_ROOT/kt-kernel\"\n  bash ./install.sh build \"${kt_args[@]}\"\n}\n\n# ─── deps install ─────────────────────────────────────────────────────────────\n\ninstall_deps() {\n  log_step \"Installing system dependencies\"\n\n  local kt_install=\"$REPO_ROOT/kt-kernel/install.sh\"\n\n  if [ ! -f \"$kt_install\" ]; then\n    log_error \"kt-kernel/install.sh not found at $kt_install\"\n    exit 1\n  fi\n\n  cd \"$REPO_ROOT/kt-kernel\"\n  bash ./install.sh deps\n}\n\n# ─── \"all\" subcommand ─────────────────────────────────────────────────────────\n\ninstall_all() {\n  local skip_sglang=0\n  local skip_kt_kernel=0\n  local editable=0\n  local kt_args=()\n\n  while [[ $# -gt 0 ]]; do\n    case \"$1\" in\n      --skip-sglang)    skip_sglang=1; shift ;;\n      --skip-kt-kernel) skip_kt_kernel=1; shift ;;\n      --editable)       editable=1; shift ;;\n      --manual)         kt_args+=(\"--manual\"); shift ;;\n      --no-clean)       kt_args+=(\"--no-clean\"); shift ;;\n      -h|--help)        usage ;;\n      *)\n        log_error \"Unknown option: $1\"\n        usage\n        ;;\n    esac\n  done\n\n  # 1. Init submodules\n  init_submodules\n\n  # 2. System dependencies\n  install_deps\n\n  # 3. Read version for sglang-kt\n  read_kt_version\n\n  # 4. Install sglang\n  if [ \"$skip_sglang\" = \"0\" ]; then\n    install_sglang \"$editable\"\n  else\n    log_info \"Skipping sglang installation (--skip-sglang).\"\n  fi\n\n  # 5. 
Build & install kt-kernel\n  if [ \"$skip_kt_kernel\" = \"0\" ]; then\n    install_kt_kernel \"${kt_args[@]}\"\n  else\n    log_info \"Skipping kt-kernel installation (--skip-kt-kernel).\"\n  fi\n\n  log_step \"Installation complete!\"\n  echo \"  Verify with: kt doctor\"\n  echo \"\"\n}\n\n# ─── Subcommand dispatcher ────────────────────────────────────────────────────\n\nSUBCMD=\"all\"\nif [[ $# -gt 0 ]]; then\n  case \"$1\" in\n    all|sglang|kt-kernel|deps)\n      SUBCMD=\"$1\"\n      shift\n      ;;\n    -h|--help)\n      usage\n      ;;\n    -*)\n      # Flags without subcommand → default to \"all\"\n      SUBCMD=\"all\"\n      ;;\n    *)\n      log_error \"Unknown subcommand: $1\"\n      usage\n      ;;\n  esac\nfi\n\ncase \"$SUBCMD\" in\n  all)\n    install_all \"$@\"\n    ;;\n  sglang)\n    # Parse sglang-specific options\n    editable=0\n    while [[ $# -gt 0 ]]; do\n      case \"$1\" in\n        --editable) editable=1; shift ;;\n        -h|--help) usage ;;\n        *) log_error \"Unknown option for sglang: $1\"; usage ;;\n      esac\n    done\n    init_submodules\n    read_kt_version\n    install_sglang \"$editable\"\n    ;;\n  kt-kernel)\n    install_kt_kernel \"$@\"\n    ;;\n  deps)\n    install_deps\n    ;;\nesac\n"
  },
  {
    "path": "kt-kernel/.clang-format",
    "content": "---\nLanguage:        Cpp\n# BasedOnStyle:  Google\nAccessModifierOffset: -1\nAlignAfterOpenBracket: Align\nAlignArrayOfStructures: None\nAlignConsecutiveAssignments:\n  Enabled:         false\n  AcrossEmptyLines: false\n  AcrossComments:  false\n  AlignCompound:   false\n  AlignFunctionPointers: false\n  PadOperators:    true\nAlignConsecutiveBitFields:\n  Enabled:         false\n  AcrossEmptyLines: false\n  AcrossComments:  false\n  AlignCompound:   false\n  AlignFunctionPointers: false\n  PadOperators:    false\nAlignConsecutiveDeclarations:\n  Enabled:         false\n  AcrossEmptyLines: false\n  AcrossComments:  false\n  AlignCompound:   false\n  AlignFunctionPointers: false\n  PadOperators:    false\nAlignConsecutiveMacros:\n  Enabled:         false\n  AcrossEmptyLines: false\n  AcrossComments:  false\n  AlignCompound:   false\n  AlignFunctionPointers: false\n  PadOperators:    false\nAlignConsecutiveShortCaseStatements:\n  Enabled:         false\n  AcrossEmptyLines: false\n  AcrossComments:  false\n  AlignCaseColons: false\nAlignEscapedNewlines: Left\nAlignOperands:   Align\nAlignTrailingComments:\n  Kind:            Always\n  OverEmptyLines:  0\nAllowAllArgumentsOnNextLine: true\nAllowAllParametersOfDeclarationOnNextLine: true\nAllowBreakBeforeNoexceptSpecifier: Never\nAllowShortBlocksOnASingleLine: Never\nAllowShortCaseLabelsOnASingleLine: false\nAllowShortCompoundRequirementOnASingleLine: true\nAllowShortEnumsOnASingleLine: true\nAllowShortFunctionsOnASingleLine: All\nAllowShortIfStatementsOnASingleLine: WithoutElse\nAllowShortLambdasOnASingleLine: All\nAllowShortLoopsOnASingleLine: true\nAlwaysBreakAfterDefinitionReturnType: None\nAlwaysBreakAfterReturnType: None\nAlwaysBreakBeforeMultilineStrings: true\nAlwaysBreakTemplateDeclarations: Yes\nAttributeMacros:\n  - __capability\nBinPackArguments: true\nBinPackParameters: true\nBitFieldColonSpacing: Both\nBraceWrapping:\n  AfterCaseLabel:  false\n  AfterClass:      false\n  
AfterControlStatement: Never\n  AfterEnum:       false\n  AfterExternBlock: false\n  AfterFunction:   false\n  AfterNamespace:  false\n  AfterObjCDeclaration: false\n  AfterStruct:     false\n  AfterUnion:      false\n  BeforeCatch:     false\n  BeforeElse:      false\n  BeforeLambdaBody: false\n  BeforeWhile:     false\n  IndentBraces:    false\n  SplitEmptyFunction: true\n  SplitEmptyRecord: true\n  SplitEmptyNamespace: true\nBreakAdjacentStringLiterals: true\nBreakAfterAttributes: Leave\nBreakAfterJavaFieldAnnotations: false\nBreakArrays:     true\nBreakBeforeBinaryOperators: None\nBreakBeforeConceptDeclarations: Always\nBreakBeforeBraces: Attach\nBreakBeforeInlineASMColon: OnlyMultiline\nBreakBeforeTernaryOperators: true\nBreakConstructorInitializers: BeforeColon\nBreakInheritanceList: BeforeColon\nBreakStringLiterals: true\nColumnLimit:     120\nCommentPragmas:  '^ IWYU pragma:'\nCompactNamespaces: false\nConstructorInitializerIndentWidth: 4\nContinuationIndentWidth: 4\nCpp11BracedListStyle: true\nDerivePointerAlignment: true\nDisableFormat:   false\nEmptyLineAfterAccessModifier: Never\nEmptyLineBeforeAccessModifier: LogicalBlock\nExperimentalAutoDetectBinPacking: false\nFixNamespaceComments: true\nForEachMacros:\n  - foreach\n  - Q_FOREACH\n  - BOOST_FOREACH\nIfMacros:\n  - KJ_IF_MAYBE\nIncludeBlocks:   Regroup\nIncludeCategories:\n  - Regex:           '^<ext/.*\\.h>'\n    Priority:        2\n    SortPriority:    0\n    CaseSensitive:   false\n  - Regex:           '^<.*\\.h>'\n    Priority:        1\n    SortPriority:    0\n    CaseSensitive:   false\n  - Regex:           '^<.*'\n    Priority:        2\n    SortPriority:    0\n    CaseSensitive:   false\n  - Regex:           '.*'\n    Priority:        3\n    SortPriority:    0\n    CaseSensitive:   false\nIncludeIsMainRegex: '([-_](test|unittest))?$'\nIncludeIsMainSourceRegex: ''\nIndentAccessModifiers: false\nIndentCaseBlocks: false\nIndentCaseLabels: true\nIndentExternBlock: 
AfterExternBlock\nIndentGotoLabels: true\nIndentPPDirectives: None\nIndentRequiresClause: true\nIndentWidth:     2\nIndentWrappedFunctionNames: false\nInsertBraces:    false\nInsertNewlineAtEOF: false\nInsertTrailingCommas: None\nIntegerLiteralSeparator:\n  Binary:          0\n  BinaryMinDigits: 0\n  Decimal:         0\n  DecimalMinDigits: 0\n  Hex:             0\n  HexMinDigits:    0\nJavaScriptQuotes: Leave\nJavaScriptWrapImports: true\nKeepEmptyLinesAtTheStartOfBlocks: false\nKeepEmptyLinesAtEOF: false\nLambdaBodyIndentation: Signature\nLineEnding:      DeriveLF\nMacroBlockBegin: ''\nMacroBlockEnd:   ''\nMaxEmptyLinesToKeep: 1\nNamespaceIndentation: None\nObjCBinPackProtocolList: Never\nObjCBlockIndentWidth: 2\nObjCBreakBeforeNestedBlockParam: true\nObjCSpaceAfterProperty: false\nObjCSpaceBeforeProtocolList: true\nPackConstructorInitializers: NextLine\nPenaltyBreakAssignment: 2\nPenaltyBreakBeforeFirstCallParameter: 1\nPenaltyBreakComment: 300\nPenaltyBreakFirstLessLess: 120\nPenaltyBreakOpenParenthesis: 0\nPenaltyBreakScopeResolution: 500\nPenaltyBreakString: 1000\nPenaltyBreakTemplateDeclaration: 10\nPenaltyExcessCharacter: 1000000\nPenaltyIndentedWhitespace: 0\nPenaltyReturnTypeOnItsOwnLine: 200\nPointerAlignment: Left\nPPIndentWidth:   -1\nQualifierAlignment: Leave\nRawStringFormats:\n  - Language:        Cpp\n    Delimiters:\n      - cc\n      - CC\n      - cpp\n      - Cpp\n      - CPP\n      - 'c++'\n      - 'C++'\n    CanonicalDelimiter: ''\n    BasedOnStyle:    google\n  - Language:        TextProto\n    Delimiters:\n      - pb\n      - PB\n      - proto\n      - PROTO\n    EnclosingFunctions:\n      - EqualsProto\n      - EquivToProto\n      - PARSE_PARTIAL_TEXT_PROTO\n      - PARSE_TEST_PROTO\n      - PARSE_TEXT_PROTO\n      - ParseTextOrDie\n      - ParseTextProtoOrDie\n      - ParseTestProto\n      - ParsePartialTestProto\n    CanonicalDelimiter: pb\n    BasedOnStyle:    google\nReferenceAlignment: Pointer\nReflowComments:  true\nRemoveBracesLLVM: 
false\nRemoveParentheses: Leave\nRemoveSemicolon: false\nRequiresClausePosition: OwnLine\nRequiresExpressionIndentation: OuterScope\nSeparateDefinitionBlocks: Leave\nShortNamespaceLines: 1\nSkipMacroDefinitionBody: false\nSortIncludes:    CaseSensitive\nSortJavaStaticImport: Before\nSortUsingDeclarations: LexicographicNumeric\nSpaceAfterCStyleCast: false\nSpaceAfterLogicalNot: false\nSpaceAfterTemplateKeyword: true\nSpaceAroundPointerQualifiers: Default\nSpaceBeforeAssignmentOperators: true\nSpaceBeforeCaseColon: false\nSpaceBeforeCpp11BracedList: false\nSpaceBeforeCtorInitializerColon: true\nSpaceBeforeInheritanceColon: true\nSpaceBeforeJsonColon: false\nSpaceBeforeParens: ControlStatements\nSpaceBeforeParensOptions:\n  AfterControlStatements: true\n  AfterForeachMacros: true\n  AfterFunctionDefinitionName: false\n  AfterFunctionDeclarationName: false\n  AfterIfMacros:   true\n  AfterOverloadedOperator: false\n  AfterPlacementOperator: true\n  AfterRequiresInClause: false\n  AfterRequiresInExpression: false\n  BeforeNonEmptyParentheses: false\nSpaceBeforeRangeBasedForLoopColon: true\nSpaceBeforeSquareBrackets: false\nSpaceInEmptyBlock: false\nSpacesBeforeTrailingComments: 2\nSpacesInAngles:  Never\nSpacesInContainerLiterals: true\nSpacesInLineCommentPrefix:\n  Minimum:         1\n  Maximum:         -1\nSpacesInParens:  Never\nSpacesInParensOptions:\n  InCStyleCasts:   false\n  InConditionalStatements: false\n  InEmptyParentheses: false\n  Other:           false\nSpacesInSquareBrackets: false\nStandard:        Auto\nStatementAttributeLikeMacros:\n  - Q_EMIT\nStatementMacros:\n  - Q_UNUSED\n  - QT_REQUIRE_VERSION\nTabWidth:        2\nUseTab:          Never\nVerilogBreakBetweenInstancePorts: true\nWhitespaceSensitiveMacros:\n  - BOOST_PP_STRINGIZE\n  - CF_SWIFT_NAME\n  - NS_SWIFT_NAME\n  - PP_STRINGIZE\n  - STRINGIZE\n...\n"
  },
  {
    "path": "kt-kernel/.githooks/commit-msg",
    "content": "#!/bin/sh\n# commit-msg hook to enforce Conventional Commits (https://www.conventionalcommits.org/)\n# This script checks the commit message subject (first line) for a conventional commit format.\n# If the message does not conform, the hook exits non-zero to block the commit.\n\n# Read the commit message (first line)\nif [ -z \"$1\" ]; then\n  echo \"commit-msg hook: no message file provided\" >&2\n  exit 0\nfi\n\nMSG_FILE=\"$1\"\nread -r FIRST_LINE < \"$MSG_FILE\" || FIRST_LINE=\"\"\n\n# Trim leading/trailing whitespace\nFIRST_LINE=\"$(echo \"$FIRST_LINE\" | sed -e 's/^[ \\t]*//' -e 's/[ \\t]*$//')\"\n\n# Allow empty message (let git handle it), or allow merges/reverts\ncase \"$FIRST_LINE\" in\n  Merge:*|merge:*|Revert:*|revert:*)\n    exit 0\n    ;;\nesac\n\n# Conventional Commit regex (POSIX ERE)\n# [type](scope)!?: subject\n# types: feat|fix|docs|style|refactor|perf|test|build|ci|chore|revert|wip\n# scope: any chars except )\n\nregex='^\\[(feat|fix|docs|style|refactor|perf|test|build|ci|chore|revert|wip)\\](\\([^\\)]+\\))?(!)?: .+'\n\nprintf \"%s\" \"$FIRST_LINE\" | grep -E \"$regex\" >/dev/null 2>&1\nif [ $? -eq 0 ]; then\n  exit 0\nfi\n\ncat <<'EOF' >&2\nERROR: Commit message does not follow Conventional Commits.\n\nExpected format:\n  [type](scope)?: subject\n\nExamples:\n  [feat]: add new feature\n  [fix(parser)]: handle edge case\n  [docs]!: update API docs (breaking change)\n\nAllowed types: feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert, wip\n\nYou can bypass this hook locally by running:\n  git commit --no-verify\nEOF\n\nexit 1\n"
  },
  {
    "path": "kt-kernel/.githooks/pre-commit",
    "content": "#!/usr/bin/bash\n# Pre-commit hook: run clang-format via kt-kernel's CMake 'format' target and Black for Python\n# before allowing commit. If formatting makes changes, stage them and abort so user can review.\nset -euo pipefail\n\nREPO_ROOT=\"$(git rev-parse --show-toplevel)\"\n# kt-kernel project directory within the monorepo\nKERNEL_DIR=\"$REPO_ROOT/kt-kernel\"\n# Relative path for matching staged files under repo root\nREL_KERNEL_DIR=\"kt-kernel\"\nBUILD_DIR=\"$KERNEL_DIR/build\"\nFORMAT_TARGET=\"format\"\nCLANG_FORMAT_BIN=\"${CLANG_FORMAT_BIN:-clang-format}\"\nBLACK_BIN=\"${BLACK_BIN:-black}\"\n\n# Simple check clang-format present (optional)\n# clang-format optional: if missing, skip C/C++ formatting\nif ! command -v \"$CLANG_FORMAT_BIN\" >/dev/null 2>&1; then\n  echo \"[pre-commit] clang-format not found (looked for $CLANG_FORMAT_BIN). Skipping C/C++ format.\" >&2\nfi\n\n# black optional: if missing, skip Python formatting\nif ! command -v \"$BLACK_BIN\" >/dev/null 2>&1; then\n  echo \"[pre-commit] black not found (looked for $BLACK_BIN). 
Skipping Python format.\" >&2\nfi\n\n## Format only staged changes within kt-kernel\n# Collect staged files (Added/Modified/Copied/Renamed)\nmapfile -d '' STAGED < <(git diff --cached --name-only -z --diff-filter=AMCR)\n\nPY_CHANGED=()\nCPP_CHANGED=()\n\nfor f in \"${STAGED[@]}\"; do\n  case \"$f\" in\n    \"$REL_KERNEL_DIR\"/*)\n      ext=\"${f##*.}\"\n      case \"$ext\" in\n        py)\n          PY_CHANGED+=(\"$f\")\n          ;;\n        c|cc|cpp|cxx|h|hh|hpp|hxx|cu|cuh)\n          CPP_CHANGED+=(\"$f\")\n          ;;\n      esac\n      ;;\n  esac\ndone\n\n# Run clang-format only on staged C/C++ files\nif command -v \"$CLANG_FORMAT_BIN\" >/dev/null 2>&1 && [ ${#CPP_CHANGED[@]} -gt 0 ]; then\n  echo \"[pre-commit] clang-format on ${#CPP_CHANGED[@]} files\" >&2\n  for f in \"${CPP_CHANGED[@]}\"; do\n    \"$CLANG_FORMAT_BIN\" -i \"$f\"\n  done\nfi\n\n## Run black only on staged Python files\nif command -v \"$BLACK_BIN\" >/dev/null 2>&1 && [ ${#PY_CHANGED[@]} -gt 0 ]; then\n  echo \"[pre-commit] black on ${#PY_CHANGED[@]} files\" >&2\n  \"$BLACK_BIN\" \"${PY_CHANGED[@]}\"\nfi\n\n# Stage any formatting changes for tracked, formatted files only\nFMT_FILES=(\"${PY_CHANGED[@]}\" \"${CPP_CHANGED[@]}\")\nif [ ${#FMT_FILES[@]} -gt 0 ] && ! git diff --quiet --exit-code -- \"${FMT_FILES[@]}\"; then\n  echo \"[pre-commit] Formatting applied; updating index.\" >&2\n  git add \"${FMT_FILES[@]}\"\n  echo \"[pre-commit] Re-run git commit to proceed after reviewing changes.\" >&2\n  exit 1\nfi\n\necho \"[pre-commit] format OK.\" >&2\nexit 0\n"
  },
  {
    "path": "kt-kernel/.gitignore",
    "content": "debug/\ndebug_prefill/\ndebug_decode/\ndebug1/\ndebug2/\n.gdbinit\nbp.gdb\n.gdb_history\nbuild/\n# local git hooks installer and hooks\n.clangd\n.cache\ntmp*\n.vscode/\n*.egg-info/\n*.pyc\n*.so\nsparse_logs/\nbuild-cm/\n*.so\nsparse_logs/"
  },
  {
    "path": "kt-kernel/.gitmodules",
    "content": "[submodule \"pybind11\"]\n\tpath = third_party/pybind11\n\turl = https://github.com/pybind/pybind11.git\n[submodule \"llama.cpp\"]\n\tpath = third_party/llama.cpp\n\turl = https://github.com/ggerganov/llama.cpp.git\n"
  },
  {
    "path": "kt-kernel/CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.16)\n\n# Toggle: default to system compilers; optionally use conda toolchain\noption(USE_CONDA_TOOLCHAIN \"Use C/C++ compilers and libraries from active conda env\" OFF)\noption(LLAMA_NATIVE \"llama: enable -march=native flag\" OFF)\noption(LLAMA_AVX \"llama: enable AVX\" OFF)\noption(LLAMA_AVX2 \"llama: enable AVX2\" OFF)\n# AVX512 options will be auto-detected by cmake/DetectCPU.cmake\n# Users can override with -DLLAMA_AVX512=OFF etc.\noption(LLAMA_FMA \"llama: enable FMA\" OFF)\n# in MSVC F16C is implied with AVX2/AVX512\nif(NOT MSVC)\n    option(LLAMA_F16C \"llama: enable F16C\" OFF)\nendif()\noption(LLAMA_AVX512_FANCY_SIMD \"llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI\" OFF)\noption(KTRANSFORMERS_USE_CUDA \"ktransformers: use CUDA\" OFF)\noption(KTRANSFORMERS_USE_MUSA \"ktransformers: use MUSA\" OFF)\noption(KTRANSFORMERS_USE_ROCM \"ktransformers: use ROCM\" OFF)\noption(KTRANSFORMERS_CUDA_STATIC_RUNTIME \"ktransformers: statically link CUDA runtime\" ON)\noption(KTRANSFORMERS_CPU_USE_KML \"ktransformers: CPU use KML\" OFF)\noption(KTRANSFORMERS_CPU_USE_AMX_AVX512 \"ktransformers: CPU use AMX or AVX512\" OFF)\noption(KTRANSFORMERS_CPU_USE_AMX \"ktransformers: CPU use AMX\" OFF)\noption(KTRANSFORMERS_CPU_DEBUG \"ktransformers: DEBUG CPU use AMX\" OFF)\noption(KTRANSFORMERS_CPU_MLA \"ktransformers: CPU use MLA\" OFF)\noption(KTRANSFORMERS_CPU_MOE_KERNEL \"ktransformers: CPU use moe kernel\" OFF)\noption(KTRANSFORMERS_CPU_MOE_AMD \"ktransformers: CPU use moe kernel for amd\" OFF)\n# LTO control\noption(CPUINFER_ENABLE_LTO \"Enable link time optimization (IPO)\" OFF)\n\nproject(kt_kernel_ext VERSION 0.5.0)\n\n# Auto-detect CPU features early (unless building with LLAMA_NATIVE)\nif(NOT LLAMA_NATIVE AND NOT MSVC)\n    include(cmake/DetectCPU.cmake)\nendif()\n\n# Choose compilers BEFORE project() so CMake honors them\nif(USE_CONDA_TOOLCHAIN)\n    if(NOT DEFINED ENV{CONDA_PREFIX} OR NOT EXISTS 
\"$ENV{CONDA_PREFIX}\")\n        message(FATAL_ERROR \"USE_CONDA_TOOLCHAIN=ON but CONDA_PREFIX is not set. Activate your conda env or pass -DCONDA_PREFIX=/path\")\n    endif()\n    # Locate conda GCC wrappers\n    find_program(CONDA_CC NAMES x86_64-conda-linux-gnu-cc HINTS \"$ENV{CONDA_PREFIX}/bin\")\n    find_program(CONDA_CXX NAMES x86_64-conda-linux-gnu-c++ HINTS \"$ENV{CONDA_PREFIX}/bin\")\n    if(NOT CONDA_CC OR NOT CONDA_CXX)\n        message(FATAL_ERROR \"Conda compilers not found in $ENV{CONDA_PREFIX}/bin (expected x86_64-conda-linux-gnu-cc/c++).\")\n    endif()\n    set(CMAKE_C_COMPILER   ${CONDA_CC}  CACHE FILEPATH \"C compiler\" FORCE)\n    set(CMAKE_CXX_COMPILER ${CONDA_CXX} CACHE FILEPATH \"C++ compiler\" FORCE)\nelse()\n    # Prefer system compilers explicitly to avoid accidentally picking conda wrappers from PATH\n    if(EXISTS \"/usr/bin/gcc\" AND EXISTS \"/usr/bin/g++\")\n        set(CMAKE_C_COMPILER   \"/usr/bin/gcc\" CACHE FILEPATH \"C compiler\" FORCE)\n        set(CMAKE_CXX_COMPILER \"/usr/bin/g++\" CACHE FILEPATH \"C++ compiler\" FORCE)\n    endif()\nendif()\n\n\n# If explicitly using conda toolchain, prefer its libraries/headers and RPATH\nif(USE_CONDA_TOOLCHAIN)\n    message(STATUS \"Conda prefix detected: $ENV{CONDA_PREFIX}; prioritizing it for search paths and RPATH\")\n    # Make conda come first for CMake package discovery\n    list(PREPEND CMAKE_PREFIX_PATH\n        \"$ENV{CONDA_PREFIX}\"\n        \"$ENV{CONDA_PREFIX}/lib/cmake\"\n        \"$ENV{CONDA_PREFIX}/share/cmake\"\n    )\n    # Also hint direct include/lib searches\n    list(PREPEND CMAKE_LIBRARY_PATH \"$ENV{CONDA_PREFIX}/lib\")\n    list(PREPEND CMAKE_INCLUDE_PATH \"$ENV{CONDA_PREFIX}/include\")\n\n    # Ensure pkg-config prefers conda .pc files\n    set(ENV{PKG_CONFIG_PATH} \"$ENV{CONDA_PREFIX}/lib/pkgconfig:$ENV{CONDA_PREFIX}/share/pkgconfig:$ENV{PKG_CONFIG_PATH}\")\n    # Make FindPkgConfig also search CMAKE_PREFIX_PATH\n    set(PKG_CONFIG_USE_CMAKE_PREFIX_PATH ON)\n\n    # 
Configure RPATH so the built extension prefers conda's shared libs at runtime\n    # Use install RPATH during build to avoid mixing with implicit system paths\n    set(CMAKE_SKIP_BUILD_RPATH FALSE)\n    set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE)\n    set(CMAKE_BUILD_RPATH \"$ENV{CONDA_PREFIX}/lib\")\n    set(CMAKE_INSTALL_RPATH \"$ENV{CONDA_PREFIX}/lib\")\n    # Do not auto-append link directories to RPATH; we want only conda path here\n    set(CMAKE_INSTALL_RPATH_USE_LINK_PATH OFF)\nendif()\n\n## Ensure git hooks are installed when configuring the project (monorepo-aware)\n# If we are inside a git worktree (repo root is outside kt-kernel now), invoke the installer\n# which will link kt-kernel/.githooks into the top-level .git/hooks. Otherwise, skip.\nfind_program(GIT_BIN git)\nif(GIT_BIN)\n    execute_process(\n        COMMAND \"${GIT_BIN}\" rev-parse --show-toplevel\n        WORKING_DIRECTORY \"${CMAKE_SOURCE_DIR}\"\n        OUTPUT_VARIABLE _GIT_TOP\n        RESULT_VARIABLE _GIT_RV\n        OUTPUT_STRIP_TRAILING_WHITESPACE\n        ERROR_QUIET\n    )\n    if(_GIT_RV EQUAL 0 AND EXISTS \"${_GIT_TOP}/.git\" AND IS_DIRECTORY \"${_GIT_TOP}/.git\")\n        if(EXISTS \"${CMAKE_SOURCE_DIR}/scripts/install-git-hooks.sh\")\n            message(STATUS \"Detected git worktree at ${_GIT_TOP}; installing hooks from kt-kernel/.githooks\")\n            execute_process(\n                COMMAND sh \"${CMAKE_SOURCE_DIR}/scripts/install-git-hooks.sh\"\n                WORKING_DIRECTORY \"${CMAKE_SOURCE_DIR}\"\n                RESULT_VARIABLE _INSTALL_GIT_HOOKS_RESULT\n                OUTPUT_VARIABLE _INSTALL_GIT_HOOKS_OUT\n                ERROR_VARIABLE _INSTALL_GIT_HOOKS_ERR\n            )\n            if(NOT _INSTALL_GIT_HOOKS_RESULT EQUAL 0)\n                message(FATAL_ERROR \"Installing git hooks failed (exit ${_INSTALL_GIT_HOOKS_RESULT}).\\nOutput:\\n${_INSTALL_GIT_HOOKS_OUT}\\nError:\\n${_INSTALL_GIT_HOOKS_ERR}\")\n            endif()\n        else()\n            
message(FATAL_ERROR \"Required script 'scripts/install-git-hooks.sh' not found in kt-kernel; cannot install hooks.\")\n        endif()\n    else()\n        message(STATUS \"No git worktree detected; skipping git hooks installation\")\n    endif()\nelse()\n    message(STATUS \"git not found; skipping git hooks installation\")\nendif()\n\nset(CMAKE_CXX_STANDARD 20)\nset(CMAKE_CXX_STANDARD_REQUIRED ON)\n\n# Use header-only fmt to avoid needing to link libfmt (fix undefined symbol vprint)\nadd_compile_definitions(FMT_HEADER_ONLY)\n\nset(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -O3 -ffast-math\")\n# set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -g -fsanitize=address -fno-omit-frame-pointer\")\n# set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -g -O0\")\nset(CMAKE_EXPORT_COMPILE_COMMANDS ON)\nfind_package(OpenMP REQUIRED)\nmessage(STATUS \"CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}\")\n\n\ninclude(CheckCXXCompilerFlag)\nset(CMAKE_POSITION_INDEPENDENT_CODE ON)\n\n\n\n# instruction set specific\nif(LLAMA_NATIVE)\n    set(INS_ENB OFF)\nelse()\n    set(INS_ENB ON)\nendif()\n# Architecture specific\n# TODO: probably these flags need to be tweaked on some architectures\n#       feel free to update the Makefile for your architecture and send a pull request or issue\nmessage(STATUS \"CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}\")\n\nset(ARCH_FLAGS \"\")\n\nif(CMAKE_OSX_ARCHITECTURES STREQUAL \"arm64\" OR CMAKE_GENERATOR_PLATFORM_LWR STREQUAL \"arm64\" OR\n    (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND\n         CMAKE_SYSTEM_PROCESSOR MATCHES \"^(aarch64|arm.*|ARM64)$\"))\n    message(STATUS \"ARM detected\")\n    if(MSVC)\n        add_compile_definitions(__aarch64__) # MSVC defines _M_ARM64 instead\n        add_compile_definitions(__ARM_NEON)\n        add_compile_definitions(__ARM_FEATURE_FMA)\n\n        set(CMAKE_REQUIRED_FLAGS_PREV ${CMAKE_REQUIRED_FLAGS})\n        string(JOIN \" \" CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS} \"/arch:armv8.2\")\n        
check_cxx_source_compiles(\"#include <arm_neon.h>\\nint main() { int8x16_t _a, _b; int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }\" GGML_COMPILER_SUPPORT_DOTPROD)\n        if(GGML_COMPILER_SUPPORT_DOTPROD)\n            add_compile_definitions(__ARM_FEATURE_DOTPROD)\n        endif()\n        check_cxx_source_compiles(\"#include <arm_neon.h>\\nint main() { float16_t _a; float16x8_t _s = vdupq_n_f16(_a); return 0; }\" GGML_COMPILER_SUPPORT_FP16_VECTOR_ARITHMETIC)\n        if(GGML_COMPILER_SUPPORT_FP16_VECTOR_ARITHMETIC)\n            add_compile_definitions(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)\n        endif()\n        set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_PREV})\n    else()\n        check_cxx_compiler_flag(-mfp16-format=ieee COMPILER_SUPPORTS_FP16_FORMAT_I3E)\n        if(NOT \"${COMPILER_SUPPORTS_FP16_FORMAT_I3E}\" STREQUAL \"\")\n            list(APPEND ARCH_FLAGS -mfp16-format=ieee)\n        endif()\n        if(${CMAKE_SYSTEM_PROCESSOR} MATCHES \"armv6\")\n            # Raspberry Pi 1, Zero\n            list(APPEND ARCH_FLAGS -mfpu=neon-fp-armv8 -mno-unaligned-access)\n        endif()\n        if(${CMAKE_SYSTEM_PROCESSOR} MATCHES \"armv7\")\n            if(\"${CMAKE_SYSTEM_NAME}\" STREQUAL \"Android\")\n                # Android armeabi-v7a\n                list(APPEND ARCH_FLAGS -mfpu=neon-vfpv4 -mno-unaligned-access -funsafe-math-optimizations)\n            else()\n                # Raspberry Pi 2\n                list(APPEND ARCH_FLAGS -mfpu=neon-fp-armv8 -mno-unaligned-access -funsafe-math-optimizations)\n            endif()\n        endif()\n        if(${CMAKE_SYSTEM_PROCESSOR} MATCHES \"armv8\")\n            # Android arm64-v8a\n            # Raspberry Pi 3, 4, Zero 2 (32-bit)\n            list(APPEND ARCH_FLAGS -mno-unaligned-access)\n        endif()\n        # add_compile_definitions(__ARM_NEON)\n        # list(APPEND ARCH_FLAGS -march=armv8.2-a+fp16+dotprod)\n        # add_compile_definitions(__ARM_FEATURE_DOTPROD)\n        # 
add_compile_definitions(__aarch64__)\n\n        # add_compile_definitions(__ARM_NEON)\n        list(APPEND ARCH_FLAGS -march=armv8.2-a+fp16+dotprod+sve+bf16)\n        # list(APPEND ARCH_FLAGS -march=armv8-a+dotprod+sha3+sm4+fp16fml+sve+rng+sb+ssbs+i8mm+bf16+flagm+pauth)\n        # add_compile_definitions(__ARM_FEATURE_DOTPROD)\n        # add_compile_definitions(__ARM_FEATURE_SVE)\n        # add_compile_definitions(__ARM_FEATURE_MATMUL_INT8)\n        # add_compile_definitions(__aarch64__)\n    endif()\nelseif(CMAKE_OSX_ARCHITECTURES STREQUAL \"x86_64\" OR CMAKE_GENERATOR_PLATFORM_LWR MATCHES \"^(x86_64|i686|amd64|x64|win32)$\" OR\n    (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND\n         CMAKE_SYSTEM_PROCESSOR MATCHES \"^(x86_64|i686|AMD64)$\"))\n    message(STATUS \"x86 detected\")\n    set(HOST_IS_X86 TRUE)\n    add_compile_definitions(__x86_64__)\n    if(MSVC)\n        # instruction set detection for MSVC only\n        if(LLAMA_NATIVE)\n            include(cmake/FindSIMD.cmake)\n        endif()\n        if(LLAMA_AVX512)\n            list(APPEND ARCH_FLAGS /arch:AVX512)\n            # MSVC has no compile-time flags enabling specific\n            # AVX512 extensions, neither it defines the\n            # macros corresponding to the extensions.\n            # Do it manually.\n            if(LLAMA_AVX512_VBMI)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VBMI__>)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VBMI__>)\n            endif()\n            if(LLAMA_AVX512_VNNI)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)\n            endif()\n            if(LLAMA_AVX512_FANCY_SIMD)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VL__>)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VL__>)\n                
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512BW__>)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512BW__>)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512DQ__>)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512DQ__>)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)\n            endif()\n            if(LLAMA_AVX512_BF16)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512BF16__>)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512BF16__>)\n            endif()\n        elseif(LLAMA_AVX2)\n            list(APPEND ARCH_FLAGS /arch:AVX2)\n        elseif(LLAMA_AVX)\n            list(APPEND ARCH_FLAGS /arch:AVX)\n        endif()\n    else()\n        if(LLAMA_NATIVE)\n            list(APPEND ARCH_FLAGS -mfma -mavx -mavx2)\n            list(APPEND ARCH_FLAGS -march=native)\n        endif()\n        if(LLAMA_F16C)\n            list(APPEND ARCH_FLAGS -mf16c)\n        endif()\n        if(LLAMA_FMA)\n            list(APPEND ARCH_FLAGS -mfma)\n        endif()\n        if(LLAMA_AVX)\n            list(APPEND ARCH_FLAGS -mavx -mfma -msse3 -mf16c)\n            message(WARNING \"Pure AVX is not supported; at least AVX2 is required\")\n        endif()\n        if(LLAMA_AVX2)\n            list(APPEND ARCH_FLAGS -mavx2 -mfma -msse3 -mf16c)\n        endif()\n        if(LLAMA_AVX512)\n            list(APPEND ARCH_FLAGS -mavx512f -mavx512bw -mavx512dq -mfma -mf16c -msse3)\n        endif()\n        if(LLAMA_AVX512_VBMI)\n            list(APPEND ARCH_FLAGS -mavx512vbmi)\n        endif()\n        if(LLAMA_AVX512_VNNI)\n            list(APPEND ARCH_FLAGS -mavx512vnni)\n        endif()\n        if(LLAMA_AVX512_FANCY_SIMD)\n            message(STATUS \"AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI enabled\")\n            list(APPEND ARCH_FLAGS -mavx512vl)\n     
       list(APPEND ARCH_FLAGS -mavx512bw)\n            list(APPEND ARCH_FLAGS -mavx512dq)\n            list(APPEND ARCH_FLAGS -mavx512vnni)\n            list(APPEND ARCH_FLAGS -mavx512vpopcntdq)\n        endif()\n        if(LLAMA_AVX512_BF16)\n            list(APPEND ARCH_FLAGS -mavx512bf16)\n        endif()\n    endif()\nelseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES \"ppc64\")\n    message(STATUS \"PowerPC detected\")\n    if(${CMAKE_SYSTEM_PROCESSOR} MATCHES \"ppc64le\")\n        list(APPEND ARCH_FLAGS -mcpu=powerpc64le)\n    else()\n        list(APPEND ARCH_FLAGS -mcpu=native -mtune=native)\n        #TODO: Add  targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be)\n    endif()\nelse()\n    message(STATUS \"Unknown architecture\")\nendif()\n\nif(NOT EXISTS $ENV{ROCM_PATH})\n    if(NOT EXISTS /opt/rocm)\n        set(ROCM_PATH /usr)\n    else()\n        set(ROCM_PATH /opt/rocm)\n    endif()\nelse()\n    set(ROCM_PATH $ENV{ROCM_PATH})\nendif()\n\nlist(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH})\nlist(APPEND CMAKE_PREFIX_PATH \"${ROCM_PATH}/lib64/cmake\")\n\nif(NOT EXISTS $ENV{MUSA_PATH})\n    if(NOT EXISTS /opt/musa)\n        set(MUSA_PATH /usr/local/musa)\n    else()\n        set(MUSA_PATH /opt/musa)\n    endif()\nelse()\n    set(MUSA_PATH $ENV{MUSA_PATH})\nendif()\n\nlist(APPEND CMAKE_MODULE_PATH \"${MUSA_PATH}/cmake\")\n\nif(KTRANSFORMERS_CPU_MOE_AMD)\n    set(BLIS_ROOT \"\" CACHE PATH \"Root directory of BLIS installation\")\n    set(_BLIS_SEARCH_DIRS)\n    if(BLIS_ROOT)\n        list(APPEND _BLIS_SEARCH_DIRS \"${BLIS_ROOT}\")\n    endif()\n    list(APPEND _BLIS_SEARCH_DIRS \"/usr/local\" \"/usr\")\n\n    find_path(BLIS_INCLUDE_DIR\n        NAMES blis.h\n        HINTS ${_BLIS_SEARCH_DIRS}\n        PATH_SUFFIXES include include/blis\n    )\n    find_library(BLIS_LIBRARY\n        NAMES blis\n        HINTS ${_BLIS_SEARCH_DIRS}\n        PATH_SUFFIXES lib lib64\n    )\n\n    if(NOT BLIS_INCLUDE_DIR OR NOT BLIS_LIBRARY)\n        
message(WARNING \"BLIS not found; set BLIS_ROOT or specify BLIS_INCLUDE_DIR/BLIS_LIBRARY\")\n    else()\n        message(STATUS \"Found BLIS include at ${BLIS_INCLUDE_DIR}\")\n        message(STATUS \"Found BLIS library ${BLIS_LIBRARY}\")\n        set(_KT_BLIS_INCLUDE_DIR ${BLIS_INCLUDE_DIR})\n        set(_KT_BLIS_LIBRARY ${BLIS_LIBRARY})\n    endif()\n    # The Python extension target (${PROJECT_NAME}) is created later by\n    # pybind11_add_module(). Calling target_include_directories/target_link_libraries\n    # here would fail because the target doesn't exist yet. Save the discovered\n    # BLIS paths and apply them after the module target is created.\nendif()\n\n\nif(HOST_IS_X86)\n    if(KTRANSFORMERS_CPU_USE_AMX_AVX512)\n        add_compile_definitions(USE_AMX_AVX_KERNEL=1)\n        if(KTRANSFORMERS_CPU_USE_AMX)\n            add_compile_definitions(HAVE_AMX=1)\n            list(APPEND ARCH_FLAGS -mamx-tile -mamx-bf16 -mamx-int8)\n            message(STATUS \"AMX enabled\")\n        endif()\n        # add_executable(amx-test ${CMAKE_CURRENT_SOURCE_DIR}/operators/amx/amx-test.cpp)\n        # target_link_libraries(amx-test llama)\n        if(KTRANSFORMERS_CPU_DEBUG)\n            file(GLOB AMX_TEST_SOURCES \"${CMAKE_CURRENT_SOURCE_DIR}/operators/amx/test/*.cpp\")\n            foreach(test_src ${AMX_TEST_SOURCES})\n                # 获取不带扩展名的文件名作为 target 名\n                get_filename_component(test_name ${test_src} NAME_WE)\n                add_executable(${test_name} ${test_src} ${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend/shared_mem_buffer.cpp)\n                target_link_libraries(${test_name} llama OpenMP::OpenMP_CXX numa)\n            endforeach()\n        endif()\n\n        # AVX512 extensions are auto-detected by cmake/DetectCPU.cmake\n        # Users can override with -DLLAMA_AVX512_BF16=OFF etc.\n        # Only add -mf16c if LLAMA_F16C is not already enabled.\n        if(NOT LLAMA_F16C)\n            list(APPEND ARCH_FLAGS -mf16c)\n        endif()\n        
message(STATUS \"AVX512 extensions: F=${LLAMA_AVX512}, BF16=${LLAMA_AVX512_BF16}, VNNI=${LLAMA_AVX512_VNNI}, VBMI=${LLAMA_AVX512_VBMI}\")\n    endif()\nendif()\n\nmessage(STATUS \"ARCH_FLAGS: ${ARCH_FLAGS}\")\n\nadd_compile_options(\"$<$<COMPILE_LANGUAGE:CXX>:${ARCH_FLAGS}>\")\nadd_compile_options(\"$<$<COMPILE_LANGUAGE:C>:${ARCH_FLAGS}>\")\n\nadd_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../third_party/pybind11 ${CMAKE_CURRENT_BINARY_DIR}/third_party/pybind11)\nadd_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../third_party/llama.cpp ${CMAKE_CURRENT_BINARY_DIR}/third_party/llama.cpp)\n\ninclude_directories(${CMAKE_CURRENT_SOURCE_DIR}/../third_party)\nif(KTRANSFORMERS_USE_CUDA)\n    include(CheckLanguage)\n    check_language(CUDA)\n    if(CMAKE_CUDA_COMPILER)\n        message(STATUS \"CUDA detected\")\n        find_package(CUDAToolkit REQUIRED)\n        include_directories(${CUDAToolkit_INCLUDE_DIRS})\n    else()\n        message(FATAL_ERROR \"KTRANSFORMERS_USE_CUDA=ON but CUDA compiler not found\")\n    endif()\n    message(STATUS \"enabling CUDA\")\n    enable_language(CUDA)\n    add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)\n\n    # Set default CUDA architectures if not specified\n    # Target: SM 80/86 (Ampere), 89 (Ada), 90 (Hopper)\n    if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)\n        set(CMAKE_CUDA_ARCHITECTURES \"80;86;89;90\" CACHE STRING \"CUDA architectures\" FORCE)\n        message(STATUS \"CUDA architectures (default): ${CMAKE_CUDA_ARCHITECTURES}\")\n    else()\n        message(STATUS \"CUDA architectures (user): ${CMAKE_CUDA_ARCHITECTURES}\")\n    endif()\n\n    # Optimization flags\n    set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -O3 --use_fast_math\")\n    set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr\")\n    set(CMAKE_CUDA_STANDARD 17)\n    set(CMAKE_CUDA_STANDARD_REQUIRED ON)\n\n    message(STATUS \"CUDA compiler: ${CMAKE_CUDA_COMPILER}\")\n    message(STATUS \"CUDA toolkit: ${CUDAToolkit_VERSION}\")\n    message(STATUS \"CUDA 
flags: ${CMAKE_CUDA_FLAGS}\")\nelseif(KTRANSFORMERS_USE_ROCM)\n    find_package(HIP REQUIRED)\n    if(HIP_FOUND)\n        include_directories(\"${HIP_INCLUDE_DIRS}\")\n        add_compile_definitions(KTRANSFORMERS_USE_ROCM=1)\n    endif()\nelseif(KTRANSFORMERS_USE_MUSA)\n    if(NOT EXISTS $ENV{MUSA_PATH})\n        if(NOT EXISTS /opt/musa)\n            set(MUSA_PATH /usr/local/musa)\n        else()\n            set(MUSA_PATH /opt/musa)\n        endif()\n    else()\n        set(MUSA_PATH $ENV{MUSA_PATH})\n    endif()\n\n    list(APPEND CMAKE_MODULE_PATH \"${MUSA_PATH}/cmake\")\n\n    find_package(MUSAToolkit)\n    if(MUSAToolkit_FOUND)\n        message(STATUS \"MUSA Toolkit found\")\n        add_compile_definitions(KTRANSFORMERS_USE_MUSA=1)\n    endif()\nelseif(KTRANSFORMERS_CPU_USE_KML)\n    message(STATUS \"KML CPU detected\")\nelse()\n    message(STATUS \"No GPU support enabled, building for CPU only\")\n    add_compile_definitions(KTRANSFORMERS_CPU_ONLY=1)\nendif()\n\naux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1)\naux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend SOURCE_DIR2)\naux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/llamafile SOURCE_DIR3)\naux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/../third_party/llamafile SOURCE_DIR4)\naux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/kvcache SOURCE_DIR5)\n# message(STATUS \"SOURCE_DIR3: ${SOURCE_DIR3}\")\n\n# arm64\nif(NOT HOST_IS_X86 AND KTRANSFORMERS_CPU_USE_KML)\n    aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/kml SOURCE_DIR6)\n    if(NOT KTRANSFORMERS_CPU_MLA)\n        list(REMOVE_ITEM SOURCE_DIR6 ${CMAKE_CURRENT_SOURCE_DIR}/operators/kml/mla/)\n    endif()\nendif()\n# message(STATUS \"SOURCE_DIR6: ${SOURCE_DIR6}\")\n\nif(KTRANSFORMERS_CPU_MOE_KERNEL)\n    aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/moe_kernel/la SOURCE_DIR7)\n    if(KTRANSFORMERS_CPU_MOE_AMD)\n        
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/moe_kernel/mat_kernel/aocl_kernel SOURCE_DIR7_KERNEL)\n        add_compile_definitions(USE_MOE_KERNEL_AMD=1)\n    elseif(NOT HOST_IS_X86 AND KTRANSFORMERS_CPU_USE_KML)\n        aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/moe_kernel/mat_kernel/kml_kernel SOURCE_DIR7_KERNEL)\n    endif()\n    list(APPEND SOURCE_DIR7 ${SOURCE_DIR7_KERNEL})\n    if(NOT KTRANSFORMERS_CPU_MLA)\n        list(REMOVE_ITEM SOURCE_DIR7 ${CMAKE_CURRENT_SOURCE_DIR}/operators/moe_kernel/mla/)\n    endif()\n    add_compile_definitions(USE_MOE_KERNEL=1)\nendif()\nmessage(STATUS \"SOURCE_DIR7: ${SOURCE_DIR7}\")\n\n\n\nset(ALL_SOURCES ${SOURCE_DIR1} ${SOURCE_DIR2} ${SOURCE_DIR3} ${SOURCE_DIR4} ${SOURCE_DIR5} ${SOURCE_DIR6} ${SOURCE_DIR7})\n\nfile(GLOB_RECURSE FMT_SOURCES\n    \"${CMAKE_CURRENT_SOURCE_DIR}/*.cpp\"\n    \"${CMAKE_CURRENT_SOURCE_DIR}/*.hpp\"\n    \"${CMAKE_CURRENT_SOURCE_DIR}/*.h\"\n)\n# Exclude third_party directory\nlist(FILTER FMT_SOURCES EXCLUDE REGEX \"/third_party/\")\n\n## Locate a specific clang-format executable to avoid version drift\n## Prefer newer versions first to support modern .clang-format keys\n## You can override by passing -DCLANG_FORMAT_BIN=/full/path/to/clang-format\nif(NOT DEFINED CLANG_FORMAT_BIN)\n    set(_CF_HINTS\n        $ENV{CONDA_PREFIX}/bin\n        $ENV{MAMBA_ROOT_PREFIX}/envs/$ENV{CONDA_DEFAULT_ENV}/bin\n        $ENV{VIRTUAL_ENV}/bin\n        $ENV{HOME}/.local/bin\n    )\n    find_program(CLANG_FORMAT_BIN\n        NAMES clang-format-20 clang-format-19 clang-format-18 clang-format-17 clang-format-16 clang-format-15 clang-format\n        HINTS ${_CF_HINTS}\n    )\nendif()\nif(NOT CLANG_FORMAT_BIN)\n    message(WARNING \"ONLY for developer: clang-format not found. 
Please install clang-format (>=18) or pass -DCLANG_FORMAT_BIN=/full/path and reconfigure.\")\nelse()\n    execute_process(\n        COMMAND ${CLANG_FORMAT_BIN} --version\n        OUTPUT_VARIABLE _CLANG_FORMAT_VER\n        OUTPUT_STRIP_TRAILING_WHITESPACE\n    )\n    # message(STATUS \"CMake PATH: $ENV{PATH}\")\n    # Parse version string, e.g. \"Ubuntu clang-format version 19.1.0\" or \"clang-format version 18.1.8\"\n    string(REGEX MATCH \"version[ ]+([0-9]+(\\\\.[0-9]+)*)\" _CF_VER_MATCH \"${_CLANG_FORMAT_VER}\")\n    if(NOT _CF_VER_MATCH)\n        message(WARNING \"Failed to parse clang-format version from: ${_CLANG_FORMAT_VER}\")\n    endif()\n    set(CLANG_FORMAT_VERSION \"${CMAKE_MATCH_1}\")\n    message(STATUS \"Using clang-format ${CLANG_FORMAT_VERSION} at ${CLANG_FORMAT_BIN}\")\n    if(CLANG_FORMAT_VERSION VERSION_LESS \"18.0.0\")\n        message(WARNING \"clang-format >=18.0.0 required (found ${CLANG_FORMAT_VERSION} at ${CLANG_FORMAT_BIN}).\\n\"\n                            \"Tip: Ensure your desired clang-format (e.g., conda's ${CONDA_PREFIX}/bin/clang-format) is earlier in PATH when running CMake,\\n\"\n                            \"or pass -DCLANG_FORMAT_BIN=/full/path/to/clang-format.\")\n    endif()\n    add_custom_target(\n        format\n        COMMAND ${CLANG_FORMAT_BIN}\n                -i\n                -style=file\n                -fallback-style=none\n                ${FMT_SOURCES}\n        COMMENT \"Running clang-format on all source files\"\n    )\n\n    # Optional: target to check formatting without modifying files (CI-friendly)\n    add_custom_target(\n        format-check\n        COMMAND ${CLANG_FORMAT_BIN}\n                -n --Werror\n                -style=file\n                -fallback-style=none\n                ${FMT_SOURCES}\n        COMMENT \"Checking clang-format on all source files\"\n    )\nendif()\n\ninclude(FindPkgConfig)\nif(PKG_CONFIG_FOUND)\n    pkg_search_module(HWLOC REQUIRED IMPORTED_TARGET 
hwloc)\nelse(PKG_CONFIG_FOUND)\n    message(FATAL_ERROR \"FindHWLOC needs pkg-config program and PKG_CONFIG_PATH must contain the path to hwloc.pc file.\")\nendif(PKG_CONFIG_FOUND)\n\n\nadd_library(llamafile STATIC ${SOURCE_DIR4})\n\n\nif(CPUINFER_ENABLE_LTO)\n    set(CMAKE_INTERPROCEDURAL_OPTIMIZATION ON)\n    # Use THIN_LTO keyword only if supported compiler (Clang). GCC ignores it.\n    pybind11_add_module(${PROJECT_NAME} MODULE THIN_LTO ${ALL_SOURCES})\n    message(STATUS \"LTO: enabled\")\nelse()\n    set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF)\n    pybind11_add_module(${PROJECT_NAME} MODULE ${ALL_SOURCES})\n    message(STATUS \"LTO: disabled\")\nendif()\n\n# If BLIS was detected earlier, apply its include directory and library to the\n# created Python extension target. We only do this after the module target\n# (${PROJECT_NAME}) has been created by pybind11_add_module().\nif(DEFINED _KT_BLIS_INCLUDE_DIR AND DEFINED _KT_BLIS_LIBRARY)\n    if(TARGET ${PROJECT_NAME})\n        target_include_directories(${PROJECT_NAME} PRIVATE ${_KT_BLIS_INCLUDE_DIR})\n        target_link_libraries(${PROJECT_NAME} PRIVATE ${_KT_BLIS_LIBRARY})\n    else()\n        message(WARNING \"BLIS was detected earlier but ${PROJECT_NAME} target was not found when attempting to apply BLIS link/include settings.\")\n    endif()\nendif()\n\n# Ensure the module target also has correct RPATH when conda is active\nif(TARGET ${PROJECT_NAME} AND DEFINED ENV{CONDA_PREFIX} AND EXISTS \"$ENV{CONDA_PREFIX}\")\n    set_target_properties(${PROJECT_NAME} PROPERTIES\n        BUILD_RPATH \"$ENV{CONDA_PREFIX}/lib\"\n        INSTALL_RPATH \"$ENV{CONDA_PREFIX}/lib\"\n        SKIP_BUILD_RPATH OFF\n    )\nendif()\nif(NOT HOST_IS_X86 AND KTRANSFORMERS_CPU_USE_KML)\n    message(STATUS \"KML CPU detected\")\n\n    add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/operators/moe_kernel/mat_kernel/kml_kernel/prefillgemm)\n    target_link_libraries(${PROJECT_NAME} PRIVATE prefillint8gemm)\n    
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/operators/moe_kernel/mat_kernel/kml_kernel/prefillgemm_int4)\n    target_link_libraries(${PROJECT_NAME} PRIVATE prefillint4gemm)\n\n    set(DECODE_GEMM_SOURCES\n        ${CMAKE_CURRENT_SOURCE_DIR}/operators/moe_kernel/mat_kernel/kml_kernel/batch_gemm.cpp\n        ${CMAKE_CURRENT_SOURCE_DIR}/operators/moe_kernel/mat_kernel/kml_kernel/batch_gemm_kernels.cpp\n    )\n    add_library(decode_gemm SHARED ${DECODE_GEMM_SOURCES})\n    target_link_libraries(${PROJECT_NAME} PRIVATE decode_gemm)\n    if(KTRANSFORMERS_CPU_MLA)\n        target_link_libraries(${PROJECT_NAME} PRIVATE kml_rt)\n    endif()\n    target_compile_definitions(${PROJECT_NAME} PRIVATE CPU_USE_KML)\nendif()\ntarget_link_libraries(${PROJECT_NAME} PRIVATE llama PkgConfig::HWLOC OpenMP::OpenMP_CXX)\nif(NOT HOST_IS_X86 AND KTRANSFORMERS_CPU_USE_KML)\n    if(KTRANSFORMERS_CPU_DEBUG)\n        # add_executable(convert-test ${CMAKE_CURRENT_SOURCE_DIR}/operators/kml/convert-test.cpp)\n        # target_link_libraries(convert-test llama)\n        file(GLOB KML_TEST_SOURCES \"${CMAKE_CURRENT_SOURCE_DIR}/operators/kml/test/*.cpp\")\n        foreach(test_src ${KML_TEST_SOURCES})\n            # Use the file name without its extension as the target name\n            get_filename_component(test_name ${test_src} NAME_WE)\n            add_executable(${test_name} ${test_src} ${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend/shared_mem_buffer.cpp)\n            if(KTRANSFORMERS_CPU_MLA)\n                target_link_libraries(${test_name} llama OpenMP::OpenMP_CXX numa kml_rt)\n            endif()\n        endforeach()\n    endif()\nendif()\n\n\n\nif(KTRANSFORMERS_USE_CUDA)\n    # Link CUDA runtime (static or dynamic)\n    if(KTRANSFORMERS_CUDA_STATIC_RUNTIME)\n        # Platform-aware static library path\n        if(WIN32)\n            set(CUDART_STATIC_LIB \"${CUDAToolkit_LIBRARY_DIR}/cudart_static.lib\")\n        else()\n            set(CUDART_STATIC_LIB \"${CUDAToolkit_LIBRARY_DIR}/libcudart_static.a\")\n        endif()\n\n   
     if(EXISTS \"${CUDART_STATIC_LIB}\")\n            target_link_libraries(${PROJECT_NAME} PRIVATE \"${CUDART_STATIC_LIB}\")\n            message(STATUS \"CUDA runtime: static (${CUDART_STATIC_LIB})\")\n\n            # Linux needs additional libs for static cudart\n            if(UNIX AND NOT APPLE)\n                target_link_libraries(${PROJECT_NAME} PRIVATE rt pthread dl)\n            endif()\n        else()\n            message(WARNING \"Static CUDA runtime not found, using dynamic\")\n            target_link_libraries(${PROJECT_NAME} PRIVATE CUDA::cudart)\n        endif()\n    else()\n        # Dynamic linking\n        target_link_libraries(${PROJECT_NAME} PRIVATE CUDA::cudart)\n        message(STATUS \"CUDA runtime: dynamic\")\n    endif()\nendif()\nif(KTRANSFORMERS_USE_ROCM)\n    add_compile_definitions(USE_HIP=1)\n    target_link_libraries(${PROJECT_NAME} PRIVATE \"${ROCM_PATH}/lib/libamdhip64.so\")\n    message(STATUS \"Building for HIP\")\nendif()\nif(KTRANSFORMERS_USE_MUSA)\n    target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)\nendif()\n\n\n\nfind_library(NUMA_LIBRARY NAMES numa)\nif(NUMA_LIBRARY)\n    message(STATUS \"NUMA library found: ${NUMA_LIBRARY} - enabling NUMA support\")\n    target_link_libraries(${PROJECT_NAME} PRIVATE ${NUMA_LIBRARY})\nelse()\n    message(FATAL_ERROR \"NUMA library not found, please install NUMA, sudo apt install libnuma-dev\")\nendif()\n\n"
  },
  {
    "path": "kt-kernel/CMakePresets.json",
    "content": "{\n  \"version\": 3,\n  \"cmakeMinimumRequired\": {\n    \"major\": 3,\n    \"minor\": 19,\n    \"patch\": 0\n  },\n  \"configurePresets\": [\n    {\n      \"name\": \"avx512\",\n      \"displayName\": \"avx512_platform\",\n      \"description\": \"for avx512 platform\",\n      \"cacheVariables\": {\n        \"KTRANSFORMERS_CPU_USE_AMX\": \"OFF\",\n        \"LLAMA_AVX512\": \"OFF\",\n        \"LLAMA_AVX2\": \"OFF\",\n        \"KTRANSFORMERS_CPU_USE_AMX_AVX512\": \"ON\",\n        \"KTRANSFORMERS_USE_CUDA\": \"ON\"\n      }\n    },\n    {\n      \"name\": \"avx\",\n      \"displayName\": \"avx_platform\",\n      \"description\": \"for avx platform\",\n      \"cacheVariables\": {\n        \"KTRANSFORMERS_CPU_USE_AMX\": \"OFF\",\n        \"LLAMA_AVX2\": \"ON\",\n        \"KTRANSFORMERS_USE_CUDA\": \"ON\"\n      }\n    },\n    {\n      \"name\": \"amx\",\n      \"displayName\": \"amx_platform\",\n      \"description\": \"for amx platform\",\n      \"cacheVariables\": {\n        \"KTRANSFORMERS_CPU_USE_AMX\": \"ON\",\n        \"LLAMA_AVX512\": \"OFF\",\n        \"LLAMA_AVX2\": \"OFF\",\n        \"KTRANSFORMERS_CPU_USE_AMX_AVX512\": \"ON\",\n        \"KTRANSFORMERS_USE_CUDA\": \"ON\"\n      }\n    },\n    {\n      \"name\": \"amd\",\n      \"displayName\": \"amd_platform\",\n      \"description\": \"for amd platform\",\n      \"cacheVariables\": {\n        \"KTRANSFORMERS_CPU_USE_AMX\": \"OFF\",\n        \"LLAMA_AVX512\": \"OFF\",\n        \"LLAMA_AVX2\": \"ON\",\n        \"KTRANSFORMERS_CPU_USE_AMX_AVX512\": \"OFF\",\n        \"KTRANSFORMERS_USE_CUDA\": \"ON\",\n        \"KTRANSFORMERS_CPU_MOE_AMD\": \"ON\",\n        \"KTRANSFORMERS_CPU_MOE_KERNEL\": \"ON\"\n      }\n    }\n\n  ]\n}\n\n\n"
  },
  {
    "path": "kt-kernel/MANIFEST.in",
    "content": "# MANIFEST.in for kt-kernel\n# Ensures source distribution includes all necessary files for building from source\n\n# Core build files\ninclude CMakeLists.txt\ninclude CMakePresets.json\ninclude setup.py\ninclude pyproject.toml\ninclude requirements.txt\ninclude README.md\ninclude LICENSE\n\n# CMake modules and configuration\nrecursive-include cmake *.cmake *.in\n\n# C++ source files\nrecursive-include cpu_backend *.h *.hpp *.cpp *.c *.cc\nrecursive-include operators *.h *.hpp *.cpp *.c *.cc\ninclude ext_bindings.cpp\n\n# Python package\nrecursive-include python *.py\n\n# Third-party dependencies (vendored)\nrecursive-include third_party *\n\n# Exclude compiled and cache files\nglobal-exclude *.pyc\nglobal-exclude *.pyo\nglobal-exclude __pycache__\nglobal-exclude .git*\nglobal-exclude *.so\nglobal-exclude *.o\nglobal-exclude *.a\nglobal-exclude build\nglobal-exclude dist\nglobal-exclude *.egg-info\n"
  },
  {
    "path": "kt-kernel/README.md",
    "content": "# KT-Kernel\n\nHigh-performance kernel operations for KTransformers, featuring CPU-optimized MoE inference with AMX, AVX, KML and blis (amd library) support.\n\n- [Note](#note)\n- [Features](#features)\n- [Installation](#installation)\n  - [Option 1: Install from PyPI (Recommended for Most Users)](#option-1-install-from-pypi-recommended-for-most-users)\n  - [Option 2: Install from Source (For Local Use or Custom Builds)](#option-2-install-from-source-for-local-use-or-custom-builds)\n- [Verification](#verification)\n- [KT CLI Overview](#kt-cli-overview)\n- [Integration with SGLang](#integration-with-sglang)\n  - [Installation Steps](#installation-steps)\n  - [Complete Example: Qwen3-30B-A3B](#complete-example-qwen3-30b-a3b)\n  - [KT-Kernel Parameters](#kt-kernel-parameters)\n- [Direct Python API Usage](#direct-python-api-usage)\n  - [Advanced Options](#advanced-options)\n  - [Manual Configuration (Advanced)](#manual-configuration-advanced)\n- [Build Configuration](#build-configuration)\n  - [Manual Installation (Without install.sh)](#manual-installation-without-installsh)\n- [Error Troubleshooting](#error-troubleshooting)\n  - [CUDA Not Found](#cuda-not-found)\n  - [hwloc Not Found](#hwloc-not-found)\n- [Weight Quantization](#weight-quantization)\n- [Before Commit!](#before-commit)\n\n## Note\n\n**Current Support Status:**\n- ✅ **Native Precision with AVX512/AMX**: Supported with AVX512 CPUs in `FP8`, `BF16` and `RAWINT4` format - [Guide](https://github.com/kvcache-ai/ktransformers/blob/main/doc/en/kt-kernel/Native-Precision-Tutorial.md)\n- ✅ **Intel CPUs with AMX**: Fully supported (using weights converted to INT4/INT8 format)\n- ✅ **Universal CPU (llamafile backend)**: Supported (using GGUF-format weights)\n- ✅ **AMD CPUs with BLIS**: Supported (for int8 prefill & decode) - [Guide](https://github.com/kvcache-ai/ktransformers/blob/main/doc/en/kt-kernel/amd_blis.md)\n\n**KT-CLI**\n\nWe are developing a simpler way to use KTransformers. 
Check out the [KT-CLI Guide](https://github.com/kvcache-ai/ktransformers/blob/main/doc/en/kt-kernel/kt-cli.md) for more details.\n\n## Features\n\n- **CPU-Optimized MoE Kernels**: High-throughput MoE expert kernels optimized for instruction sets.\n- **AVX512 Native Precision Backend**: FP8 / BF16 / INT4 native MoE backend for AVX512-capable servers.\n- **AMX INT4/INT8 Backend**: INT4 / INT8 quantized expert inference backend for AMX-capable servers.\n- **Llamafile CPU Backend**: AVX2/AVX512-based MoE backend built on Llamafile for universal CPU deployment.\n- **NUMA-Aware Execution**: Thread pool and memory layout designed for multi-socket / multi-NUMA machines.\n\n## Installation\n\n### Option 1: Install from PyPI (Recommended for Most Users)\n\nInstall the latest version with a single command:\n\n```bash\npip install kt-kernel\n```\n\n> **Note**: Check the [latest version on PyPI](https://pypi.org/project/kt-kernel/#history)\n\n**Features:**\n- ✅ **Automatic CPU detection**: Detects your CPU and loads the optimal kernel variant\n- ✅ **CPU multi-variant support**: Includes AMX, AVX512 (Base/VNNI/VBMI/BF16), and AVX2 variants\n- ✅ **CUDA support included**: GPU acceleration for NVIDIA GPUs (SM 80, 86, 89, 90)\n- ✅ **No compilation needed**: Pre-built wheels for Python 3.10, 3.11, 3.12\n- ✅ **Static CUDA runtime**: No CUDA toolkit installation required\n- ✅ **Works on CPU-only systems**: CUDA features automatically disabled when GPU not available\n\n**Requirements:**\n- Python 3.10, 3.11, or 3.12\n- Linux x86-64 (manylinux_2_17 compatible)\n- CPU with AVX2 support (Intel Haswell 2013+, AMD Zen+)\n- Optional: NVIDIA GPU with compute capability 8.0+ for CUDA features\n\n#### CUDA Installation (GPU Acceleration)\n\nFor NVIDIA GPU-accelerated inference:\n\n```bash\npip install kt-kernel-cuda\n```\n\n**Features:**\n- ✅ **Multi-architecture support**: Single wheel supports SM 80/86/89/90 (Ampere, Ada, Hopper)\n- ✅ **Static CUDA runtime**: No CUDA toolkit installation 
required\n- ✅ **Broad compatibility**: Works with CUDA 11.8+ and 12.x drivers\n- ✅ **PyTorch compatible**: Works with any PyTorch CUDA variant (cu118, cu121, cu124)\n\n**Requirements:**\n- Python 3.10, 3.11, or 3.12\n- Linux x86-64 (manylinux_2_17 compatible)\n- NVIDIA GPU with compute capability 8.0+ (Ampere or newer)\n  - ✅ Supported: A100, RTX 3000/4000 series, H100\n  - ❌ Not supported: V100, P100, GTX 1000 / RTX 2000 series (too old)\n- NVIDIA driver with CUDA 11.8+ or 12.x support (no CUDA toolkit needed)\n\n**GPU Compatibility Matrix:**\n\n| GPU Architecture | Compute Capability | Supported | Example GPUs |\n|-----------------|-------------------|-----------|-------------|\n| Hopper | 9.0 | ✅ | H100, H200 |\n| Ada Lovelace | 8.9 | ✅ | RTX 4090, 4080, 4070 |\n| Ampere | 8.6 | ✅ | RTX 3090, 3080, 3070, 3060 |\n| Ampere | 8.0 | ✅ | A100, A30 |\n| Turing | 7.5 | ❌ | RTX 2080, T4 |\n| Volta | 7.0 | ❌ | V100 |\n\n**CUDA Driver Compatibility (for GPU features):**\n- CUDA 11.8, 12.0-12.6+: Full support\n- CUDA 11.0-11.7: Not supported (upgrade driver or use CPU-only)\n\n**CPU Variants Included:**\n\nThe wheel includes 6 optimized variants that are **automatically selected at runtime** based on your CPU:\n\n| Variant | CPU Support | Performance | Auto-Selected When |\n|---------|-------------|-------------|-------------------|\n| **AMX** | Intel Sapphire Rapids+ (2023+) | ⚡⚡⚡ Best | AMX instructions detected |\n| **AVX512+BF16** | Ice Lake server, Zen 4+ (2021+) | ⚡⚡⚡ Excellent | AVX512 + BF16 detected |\n| **AVX512+VBMI** | Ice Lake client (2019+) | ⚡⚡ Great | AVX512 + VBMI detected |\n| **AVX512+VNNI** | Cascade Lake+ (2019+) | ⚡⚡ Great | AVX512 + VNNI detected |\n| **AVX512 Base** | Skylake-X+ (2017+) | ⚡⚡ Good | AVX512 base detected |\n| **AVX2** | Haswell+ (2013+), AMD Zen+ | ⚡ Good | Fallback for maximum compatibility |\n\n**Verify installation:**\n```python\nimport kt_kernel\n\n# Check which CPU variant was loaded\nprint(f\"CPU variant: 
{kt_kernel.__cpu_variant__}\")\nprint(f\"Version: {kt_kernel.__version__}\")\n\n# Check CUDA support\nfrom kt_kernel import kt_kernel_ext\ncpu_infer = kt_kernel_ext.CPUInfer(4)\nhas_cuda = hasattr(cpu_infer, 'submit_with_cuda_stream')\nprint(f\"CUDA support: {has_cuda}\")\n\nprint(\"✓ kt-kernel installed successfully!\")\n```\n\n**Environment Variables:**\n```bash\n# Override automatic CPU detection (for testing or debugging)\nexport KT_KERNEL_CPU_VARIANT=avx2  # Force specific variant\n\n# Enable debug output to see detection process\nexport KT_KERNEL_DEBUG=1\npython -c \"import kt_kernel\"\n```\n\n---\n\n### Option 2: Install from Source (For Local Use or Custom Builds)\n\nBuild from source for local installation or when you need AMD (BLIS), ARM (KML), or custom CUDA versions.\n\n#### Prerequisites\n\nFirst, initialize git submodules and create a conda environment:\n```bash\ngit submodule update --init --recursive\nconda create -n kt-kernel python=3.11 -y\nconda activate kt-kernel\n```\n\n#### Quick Installation (Recommended)\n\nSimply run the install script - it will auto-detect your CPU and optimize for best performance:\n\n```bash\n./install.sh\n```\n\n**What happens automatically:**\n- Auto-detects CPU capabilities (AMX, AVX512_VNNI, AVX512_BF16)\n- Installs system dependencies (`cmake`, `libhwloc-dev`, `pkg-config`)\n- Builds optimized binary for **your CPU only** (using `-march=native`)\n- **Software fallbacks**: Automatically enabled for CPUs without VNNI/BF16\n\n**Optional: Two-step installation**\n```bash\n./install.sh deps   # Install dependencies only\n./install.sh build  # Build and install kt-kernel\n```\n\n**CPU Requirements by Backend:**\n\n| Backend | Minimum CPU Requirement | Example CPUs | Notes |\n|---------|-------------------------|--------------|-------|\n| **LLAMAFILE** | AVX2 | Intel Haswell (2013+), AMD Zen+ | Universal compatibility |\n| **RAWINT4** | AVX512F + AVX512BW | Intel Skylake-X (2017+), Ice Lake, Cascade Lake | Software 
fallbacks for VNNI/BF16 |\n| **AMXINT4/INT8** | AMX | Intel Sapphire Rapids (2023+) | Best performance, requires AMX hardware |\n| **FP8** | AVX512F + AVX512BW + AVX512_BF16 + AVX512_VBMI | Intel Cooper Lake (2020+), Sapphire Rapids (2023+); AMD Zen 4+ (e.g., EPYC 9355) | Native Precision (e.g., DeepSeek V3.2, MiniMax M2.1) |\n| **BF16** | AVX512F + AVX512BW + AVX512_BF16 | Intel Cooper Lake (2020+), Sapphire Rapids (2023+); AMD Zen 4+ (e.g., EPYC 9355) | Native Precision (e.g., Qwen3-235B-A22B, GLM-4.7) |\n\n**Software Fallback Support (AVX512 backends):**\n- ✅ VNNI fallback: Uses AVX512BW instructions\n- ✅ BF16 fallback: Uses AVX512F instructions\n- ✅ Older AVX512 CPUs (Skylake-X, Cascade Lake) can run RAWINT4 with fallbacks\n\n⚠️ **Portability Note:** The default build is optimized for your specific CPU and may not work on different/older CPUs. For portable builds or binary distribution, see [Manual Configuration](#manual-configuration-advanced) below.\n\n⚠️ **AMD BLIS backend users:** See [installation guide](https://github.com/kvcache-ai/ktransformers/issues/1601) for AMD-specific setup.\n\n## Verification\n\nAfter installation, verify that the CLI is working:\n\n```bash\nkt version\n```\n\nExpected output:\n```\nKTransformers CLI v0.x.x\n\n  Python:        3.11.x\n  Platform:      Linux 5.15.0-xxx-generic\n  CUDA:          12.x\n  kt-kernel:     0.x.x (amx)\n  sglang:        0.x.x\n```\n\nYou can also verify the Python module directly:\n\n```bash\npython -c \"from kt_kernel import KTMoEWrapper; print('✓ kt-kernel installed successfully')\"\n```\n\n## KT CLI Overview\n\nThe `kt` command-line tool provides a unified interface for running and managing KTransformers models:\n\n| Command | Description |\n|---------|-------------|\n| `kt run <model>` | Start model inference server with auto-optimized parameters |\n| `kt chat` | Interactive chat with a running model server |\n| `kt model` | Manage models and storage paths |\n| `kt doctor` | Diagnose environment 
issues and check system compatibility |\n| `kt config` | Manage CLI configuration |\n| `kt version` | Show version information |\n\n**Quick Start Example:**\n\n```bash\n# Start a model server (auto-detects hardware and applies optimal settings)\nkt run m2\n\n# In another terminal, chat with the model\nkt chat\n\n# Check system compatibility\nkt doctor\n```\n\nRun `kt --help` for more options, or `kt <command> --help` for command-specific help.\n\n## Integration with SGLang\n\nKT-Kernel can be used standalone via [Direct Python API](#direct-python-api-usage) or integrated with SGLang for production deployment. This section describes SGLang integration to enable CPU-GPU heterogeneous inference, where \"hot\" experts run on GPU and \"cold\" experts run on CPU for optimal resource utilization.\n\n### Installation Steps\n\n#### 1. Install SGLang\n\nInstall the kvcache-ai fork of SGLang (required for kt-kernel support):\n\n```bash\n# Option A: One-click install (from ktransformers root, installs sglang + kt-kernel)\n./install.sh\n\n# Option B: pip install\npip install sglang-kt\n\n# Option C: From source (editable mode)\ngit clone --recursive https://github.com/kvcache-ai/ktransformers.git\ncd ktransformers\npip install -e \"third_party/sglang/python[all]\"\n```\n\n> **Important:** Use `sglang-kt` (kvcache-ai fork), not the official `sglang` package. If you have the official version installed, uninstall it first: `pip uninstall sglang -y`\n\n#### 2. Prepare Weights\n\nYou need both GPU weights and CPU-side expert weights for heterogeneous inference. 
The exact format depends on the backend:\n\n**GPU Weights (for all backends):**  \nUse the model weights required by SGLang for GPU inference (for example, the original or already-quantized model directory from Hugging Face).\n\n**CPU Weights (AMX backend: `AMXINT4` / `AMXINT8`):**\nQuantize weights to AMX-optimized INT4/INT8 format using the provided script:\n\n```bash\npython scripts/convert_cpu_weights.py \\\n  --input-path /path/to/model \\\n  --input-type bf16 \\\n  --output /path/to/cpu-weights \\\n  --quant-method int8  # or int4, or moe_int8 (currently for AMD)\n```\n\n- `--input-path`: Path to GPU-side original weights\n- `--input-type`: Depends on your GPU weights type (`fp8`, `fp16`, or `bf16`)\n\nIn SGLang integration, `--kt-weight-path` should point to this converted CPU weights directory.\n\n**Supported input formats:** FP8, FP16, BF16 → INT4/INT8.\n\n**CPU Weights (LLAMAFILE backend: `LLAMAFILE`):**\nLLAMAFILE uses pre-quantized **GGUF** weights on the CPU side directly, without running `convert_cpu_weights.py`. You need to:\n\n- Download a GGUF model directly from the web (e.g., GGUF repos on Hugging Face / ModelScope);\n- In SGLang integration, use that GGUF directory as `--kt-weight-path`.\n  KT-Kernel supports multiple GGUF quantization formats such as `Q4_K_M`, `Q4_K`, `Q5_K`, etc. Choose based on your latency and accuracy requirements.\n\n#### 3. 
Launch SGLang Server\n\nStart the SGLang server with your normal SGLang parameters, and add the following KT-Kernel specific parameters to enable CPU-GPU heterogeneous inference:\n\n**KT-Kernel Parameters to Add:**\n- `--kt-method`: Backend method (AMXINT4, AMXINT8, or LLAMAFILE)\n- `--kt-weight-path`: Path to the converted CPU weights\n- `--kt-cpuinfer`: Number of CPU inference threads (set to physical cores)\n- `--kt-threadpool-count`: Number of thread pools (set to NUMA node count)\n- `--kt-num-gpu-experts`: Number of experts to keep on GPU\n- `--kt-max-deferred-experts-per-token`: Deferred experts for pipelined execution\n\nExample:\n```bash\npython -m sglang.launch_server \\\n  [your normal SGLang parameters...] \\\n  --kt-method AMXINT8 \\\n  --kt-weight-path /path/to/cpu-weights \\\n  --kt-cpuinfer 64 \\\n  --kt-threadpool-count 2 \\\n  --kt-num-gpu-experts 32 \\\n  --kt-max-deferred-experts-per-token 2\n```\n\nSee [KT-Kernel Parameters](#kt-kernel-parameters) section below for detailed parameter tuning guidelines.\n\n### Complete Example: Qwen3-30B-A3B\n\nThis example demonstrates the full workflow from downloading weights to launching the server, showing **Native backend**, **AMX backend** and **LLAMAFILE backend** options.\n\n**Hardware Configuration:**\n- **GPU**: NVIDIA RTX 4090 24GB\n- **CPU**: 2x Intel Xeon Gold 6454S (64 physical cores total, 128 threads, 2 NUMA nodes)\n- **Model**: [Qwen3-30B-A3B](https://huggingface.co/Qwen/Qwen3-30B-A3B)\n\n**How to verify your system configuration:**\n```bash\n# Check CPU configuration\nlscpu | grep -E \"^CPU\\(s\\)|Thread\\(s\\) per core|Socket\\(s\\)|NUMA node\\(s\\)\"\n# Expected output example:\nCPU(s):                                  128\nThread(s) per core:                      2\nSocket(s):                               2\nNUMA node(s):                            2\n# → Physical cores = CPU(s) / Thread(s) per core = 128 / 2 = 64\n```\n\n**Parameter Rationale:**\n- `--kt-cpuinfer 64`: Set to physical cores 
(64), not hyperthreads (128)\n- `--kt-threadpool-count 2`: 2 NUMA nodes detected (dual-socket system)\n- `--kt-num-gpu-experts 32`: With 24GB GPU memory, we can fit ~32 experts on GPU for this model (varies by model architecture and actual memory usage)\n- `--kt-max-deferred-experts-per-token 2`: Enable pipelined execution; allows CPU to process next batch while GPU completes current batch\n- `--kt-gpu-prefill-token-threshold 2048`: Use layerwise prefill strategy when token count exceeds 2048 (for native backends only)\n\n---\n\n#### Option A: Native Backend (BF16)\n\nFor AVX512 CPUs with BF16 support.\n\n**Step 1: Download model weights**\n\n```bash\n# Install huggingface-cli if not already installed\npip install huggingface-hub\n# Download model from Hugging Face  \nhuggingface-cli download Qwen/Qwen3-30B-A3B --local-dir /mnt/data/models/Qwen3-30B-A3B\n```\n\n**Step 2: Launch SGLang server**\n\n```bash\npython -m sglang.launch_server \\\n    --host 0.0.0.0 \\\n    --port 30000 \\\n    --model /mnt/data/models/Qwen3-30B-A3B \\\n    --kt-weight-path /mnt/data/models/Qwen3-30B-A3B \\\n    --kt-cpuinfer 64 \\\n    --kt-threadpool-count 2 \\\n    --kt-num-gpu-experts 32 \\\n    --kt-method BF16 \\\n    --attention-backend flashinfer \\\n    --trust-remote-code \\\n    --mem-fraction-static 0.80 \\\n    --chunked-prefill-size 16384 \\\n    --max-running-requests 4 \\\n    --served-model-name Qwen3 \\\n    --enable-mixed-chunk \\\n    --tensor-parallel-size 1 \\\n    --enable-p2p-check \\\n    --disable-shared-experts-fusion \\\n    --kt-gpu-prefill-token-threshold 4096 \\\n    --kt-enable-dynamic-expert-update\n```\n\n---\n\n#### Option B: AMX Backend (AMXINT8)\n\nFor Intel CPUs with AMX instruction set support.\n\n**Step 1: Download model weights**\n\n```bash\n# Install huggingface-cli if not already installed\npip install huggingface-hub\n\n# Download model from Hugging Face\nhuggingface-cli download Qwen/Qwen3-30B-A3B --local-dir 
/mnt/data/models/Qwen3-30B-A3B\n```\n\n**Step 2: Convert to CPU weights (AMXINT8)**\n\n```bash\npython scripts/convert_cpu_weights.py \\\n  --input-path /mnt/data/models/Qwen3-30B-A3B \\\n  --input-type bf16 \\\n  --output /mnt/data/models/Qwen3-30B-A3B-INT8 \\\n  --quant-method int8\n```\n\n**Step 3: Launch SGLang server**\n\n```bash\npython -m sglang.launch_server \\\n  --host 0.0.0.0 \\\n  --port 8000 \\\n  --model /mnt/data/models/Qwen3-30B-A3B \\\n  --trust-remote-code \\\n  --mem-fraction-static 0.92 \\\n  --chunked-prefill-size 4096 \\\n  --served-model-name Qwen3-30B-A3B \\\n  --enable-mixed-chunk \\\n  --kt-method AMXINT8 \\\n  --kt-weight-path /mnt/data/models/Qwen3-30B-A3B-INT8 \\\n  --kt-cpuinfer 64 \\\n  --kt-threadpool-count 2 \\\n  --kt-num-gpu-experts 32 \\\n  --kt-max-deferred-experts-per-token 2\n```\n\n---\n\n#### Option C: LLAMAFILE Backend (GGUF)\n\nFor universal CPUs (no AMX required), using pre-quantized GGUF weights directly.\n\n**Step 1: Download GPU weights (original model)**\n\n```bash\npip install huggingface-hub\n\nhuggingface-cli download Qwen/Qwen3-30B-A3B --local-dir /mnt/data/models/Qwen3-30B-A3B\n```\n\n**Step 2: Download CPU weights (GGUF format)**\n\n```bash\nhuggingface-cli download Qwen/Qwen3-30B-A3B-GGUF Qwen3-30B-A3B-Q4_K_M.gguf \\\n  --local-dir /mnt/data/models/Qwen3-30B-A3B-Q4_K_M\n```\n\n**Step 3: Launch SGLang server**\n\n```bash\npython -m sglang.launch_server \\\n  --host 0.0.0.0 \\\n  --port 8000 \\\n  --model /mnt/data/models/Qwen3-30B-A3B \\\n  --trust-remote-code \\\n  --mem-fraction-static 0.92 \\\n  --chunked-prefill-size 4096 \\\n  --served-model-name Qwen3-30B-A3B \\\n  --enable-mixed-chunk \\\n  --kt-method LLAMAFILE \\\n  --kt-weight-path /mnt/data/models/Qwen3-30B-A3B-Q4_K_M \\\n  --kt-cpuinfer 64 \\\n  --kt-threadpool-count 2 \\\n  --kt-num-gpu-experts 32 \\\n  --kt-max-deferred-experts-per-token 2\n```\n\n### KT-Kernel Parameters\n\n| Parameter | Description | Example Value 
|\n|-----------|-------------|---------------|\n| `--kt-method` | CPU inference backend method | `AMXINT4`, `AMXINT8`, `RAWINT4`, `FP8`, `FP8_PERCHANNEL`, `BF16` or `LLAMAFILE` |\n| `--kt-weight-path` | Path to quantized CPU weights | `/path/to/cpu-weights` |\n| `--kt-cpuinfer` | Number of CPU inference threads | `64` (adjust based on CPU cores) |\n| `--kt-threadpool-count` | Number of thread pools for parallel execution | `2` (typically 1-4) |\n| `--kt-num-gpu-experts` | Number of experts to keep on GPU | `32` (remaining experts go to CPU) |\n| `--kt-max-deferred-experts-per-token` | Number of experts per token to defer for pipelined execution | `2` (0 to disable, 1-4 recommended) |\n| `--kt-gpu-prefill-token-threshold` | Token count threshold for prefill strategy (native backend only) | ~`1024-4096` |\n| `--kt-enable-dynamic-expert-update` | Enable dynamic expert placement updates during prefill based on actual routing statistics | (flag, no value needed) |\n| `--kt-expert-placement-strategy` | Strategy for initial GPU expert placement | `uniform`, `frequency`, `front-loading`, or `random` |\n\n**Parameter Guidelines:**\n\n- **`kt-method`**: Choose based on your CPU and weight format:\n  - `AMXINT4`: Best performance on AMX CPUs with INT4 quantized weights (May cause huge accuracy drop for some models, e.g., Qwen3-30B-A3B)\n  - `AMXINT8`: Higher accuracy with INT8 quantized weights on AMX CPUs\n  - `RAWINT4`: Native INT4 weights shared by CPU and GPU (currently supports Kimi-K2-Thinking model). 
See [Kimi-K2-Thinking Native Tutorial](../doc/en/Kimi-K2-Thinking-Native.md) for details.\n  - `FP8`, `FP8_PERCHANNEL`: FP8 weights shared by CPU and GPU\n  - `BF16`: BF16 weights shared by CPU and GPU\n  - `LLAMAFILE`: GGUF-based backend\n\n- **`kt-cpuinfer`**: Set to the number of **physical CPU cores** (not hyperthreads).\n  - Check physical cores: `lscpu | grep -E \"^CPU\\(s\\)|Thread\\(s\\) per core\"`\n  - Physical cores = CPU(s) / Thread(s) per core\n  - Example: If CPU(s)=128 and Thread(s) per core=2, then physical cores = 64\n  - **Important**: Do NOT set to hyperthread count - this will degrade performance\n\n- **`kt-threadpool-count`**: Set to the number of **NUMA nodes**.\n  - Check NUMA count: `lscpu | grep \"NUMA node(s)\"`\n  - Or use: `numactl --hardware | grep \"available\"`\n  - **Note**: NUMA node count is NOT necessarily the number of physical CPUs\n    - It represents memory domains, which may be divided within a single CPU or across multiple CPUs\n    - Use the NUMA node count from `lscpu`, regardless of physical CPU count\n  - Typical values: 1-2 for single-socket, 2-4 for dual-socket systems\n  - This enables better memory bandwidth utilization across NUMA domains\n\n- **`kt-num-gpu-experts`**: Determine based on GPU memory and profiling:\n  - More GPU experts = lower latency but higher GPU memory usage (May cause OOM)\n\n- **`kt-max-deferred-experts-per-token`**: Enables pipelined execution:\n  - `0`: Synchronous execution (simpler, higher latency)\n  - `1-4`: Deferred execution (recommended range; good latency/quality balance, requires tuning)\n  - `5-7`: Highest latency reduction but may introduce noticeable accuracy loss; use with care\n\n- **`kt-gpu-prefill-token-threshold`** (FP8 and RAWINT4 only): Controls prefill strategy for native FP8 and INT4 inference:\n  - **≤ threshold**: Uses hybrid CPU+GPU prefill. No extra VRAM needed, but performance degrades slowly as token count increases.\n  - **> threshold**: Uses layerwise GPU prefill. 
Performance scales better with longer sequences, but requires extra VRAM for one MoE layer (e.g., ~9GB+ for Kimi-K2-Thinking and ~3.6GB for MiniMax-M2.1).\n  - Only applicable when `--kt-method RAWINT4` or `--kt-method FP8` is used.\n\n- **`kt-enable-dynamic-expert-update`**: Enables dynamic expert placement updates during inference.\n  - During layerwise prefill, the system collects actual routing statistics and redistributes GPU experts accordingly.\n  - Requires `--kt-gpu-prefill-token-threshold` to be set, and prefill length must be ≥ the threshold value.\n  - Particularly effective at lower GPU expert ratios (10%-70%), where it can significantly outperform static strategies.\n  - See [Expert Scheduling Tutorial](../doc/en/kt-kernel/experts-sched-Tutorial.md) for benchmarks and details.\n\n- **`kt-expert-placement-strategy`**: Determines which experts are placed on GPU at server startup.\n  - `uniform`: Distributes GPU experts evenly across all MoE layers. Default option, no prior statistics needed.\n  - `frequency`: Places the most frequently activated experts on GPU. 
Best performance when activation statistics are available; requires `--init-expert-location` pointing to a `.pt` statistics file.\n  - `front-loading`: Fills GPU experts from the first MoE layer onwards.\n  - `random`: Randomly selects experts with a fixed seed (42).\n  - See [Expert Scheduling Tutorial](../doc/en/kt-kernel/experts-sched-Tutorial.md) for strategy comparison.\n\n## Direct Python API Usage\n\nFor standalone usage without SGLang, you can use KT-Kernel directly via Python API:\n\n```python\nfrom kt_kernel import KTMoEWrapper\n\n# Initialize the MoE wrapper\nwrapper = KTMoEWrapper(\n    layer_idx=0,\n    num_experts=8,\n    num_experts_per_tok=2,\n    hidden_size=4096,\n    moe_intermediate_size=14336,\n    num_gpu_experts=2,\n    cpuinfer_threads=32,\n    threadpool_count=2,\n    weight_path=\"/path/to/weights\",\n    chunked_prefill_size=512,\n    method=\"AMXINT4\"  # Options: \"AMXINT4\", \"AMXINT8\", \"LLAMAFILE\"\n)\n\n# Load weights (from disk - pre-quantized)\nwrapper.load_weights(physical_to_logical_map)\n\n# Or load weights from tensors (online quantization)\nwrapper.load_weights_from_tensors(gate_proj, up_proj, down_proj, physical_to_logical_map)\n\n# Run inference\noutput = wrapper.forward(hidden_states, topk_ids, topk_weights, cuda_stream)\n\n# Or use async API for better performance\nwrapper.submit_forward(hidden_states, topk_ids, topk_weights, cuda_stream)\n# ... 
do other work ...\noutput = wrapper.sync_forward(hidden_states, cuda_stream)\n```\n\n### Advanced Options\n\n```python\n# Initialize with additional options\nwrapper = KTMoEWrapper(\n    layer_idx=0,\n    num_experts=8,\n    num_experts_per_tok=2,\n    hidden_size=4096,\n    moe_intermediate_size=14336,\n    num_gpu_experts=2,\n    cpuinfer_threads=32,\n    threadpool_count=2,\n    weight_path=\"/path/to/weights\",\n    chunked_prefill_size=512,\n    method=\"AMXINT4\",\n    cpu_save=False,  # Keep weights in CPU memory after loading\n    max_deferred_experts_per_token=0  # Number of experts to defer (for pipelined execution)\n)\n\n# Pre-allocate buffers for specific batch sizes (improves performance)\nKTMoEWrapper.set_capture_batch_sizes([1, 2, 4, 8, 16])\n\n# Query captured batch sizes\nbatch_sizes = KTMoEWrapper.get_capture_batch_sizes()\n\n# Clear buffer cache to free memory\nKTMoEWrapper.clear_buffer_cache()\n```\n\n### Manual Configuration (Advanced)\n\nFor portable builds, binary distribution, or cross-machine deployment, you need to manually specify target instruction sets:\n\n```bash\n# General distribution (works on any AVX512 CPU from 2017+)\nexport CPUINFER_CPU_INSTRUCT=AVX512\nexport CPUINFER_ENABLE_AMX=OFF\n./install.sh build --manual\n\n# Maximum compatibility (works on any CPU from 2013+)\nexport CPUINFER_CPU_INSTRUCT=AVX2\nexport CPUINFER_ENABLE_AMX=OFF\n./install.sh build --manual\n\n# Modern CPUs only (Ice Lake+, Zen 4+)\nexport CPUINFER_CPU_INSTRUCT=FANCY\nexport CPUINFER_ENABLE_AMX=OFF\n./install.sh build --manual\n```\n\n**Optional: Override VNNI/BF16 detection**\n```bash\n# Force enable/disable VNNI and BF16 (for testing fallbacks)\nexport CPUINFER_ENABLE_AVX512_VNNI=OFF\nexport CPUINFER_ENABLE_AVX512_BF16=OFF\n./install.sh\n```\n\nSee `./install.sh --help` for all available options.\n\n---\n\n## Build Configuration\n\n### Manual Installation (Without install.sh)\n\nIf you prefer manual installation without the `install.sh` script:\n\n#### 1. 
Install System Dependencies\n\n**Prerequisites:**\n- `cmake` (recommended: `conda install -y cmake`)\n- `libhwloc-dev` and `pkg-config`\n\n#### 2. Set Build Configuration\n\n**Core Options:**\n\n| Variable | Options | Description |\n|----------|---------|-------------|\n| `CPUINFER_CPU_INSTRUCT` | `NATIVE`, `AVX512`, `AVX2`, `FANCY` | CPU instruction set to use |\n| `CPUINFER_ENABLE_AMX` | `ON`, `OFF` | Enable Intel AMX support |\n| `CPUINFER_BUILD_TYPE` | `Release`, `Debug`, `RelWithDebInfo` | Build type (default: `Release`) |\n| `CPUINFER_PARALLEL` | Number | Parallel build jobs (default: auto-detect) |\n| `CPUINFER_VERBOSE` | `0`, `1` | Verbose build output (default: `0`) |\n\n**Instruction Set Details:**\n\n| Option | Target CPUs | Use Case |\n|--------|-------------|----------|\n| **`NATIVE`** | Your specific CPU only | Local builds (best performance, **default**) |\n| **`AVX512`** | Skylake-X, Ice Lake, Cascade Lake, Zen 4+ | General distribution |\n| **`AVX2`** | Haswell (2013) and newer | Maximum compatibility |\n| **`FANCY`** | Ice Lake+, Zen 4+ | Modern CPUs with full AVX512 extensions |\n\n**Example Configurations:**\n\n```bash\n# Local use - maximum performance (default behavior)\nexport CPUINFER_CPU_INSTRUCT=NATIVE\nexport CPUINFER_ENABLE_AMX=ON  # or OFF\n\n# Distribution build - works on any AVX512 CPU\nexport CPUINFER_CPU_INSTRUCT=AVX512\nexport CPUINFER_ENABLE_AMX=OFF\n\n# Maximum compatibility - works on CPUs since 2013\nexport CPUINFER_CPU_INSTRUCT=AVX2\nexport CPUINFER_ENABLE_AMX=OFF\n\n# Debug build\nexport CPUINFER_BUILD_TYPE=Debug\nexport CPUINFER_VERBOSE=1\n```\n\n#### 3. 
Build and Install\n\n```bash\n# Editable installation (for development)\npip install -e .\n\n# Standard installation\npip install .\n```\n\n## Error Troubleshooting\n\n### CUDA Not Found\n\n```\n -- Looking for a CUDA compiler - NOTFOUND\n  CMake Error at CMakeLists.txt:389 (message):\n    KTRANSFORMERS_USE_CUDA=ON but CUDA compiler not found\n```\n\nMake sure the CUDA toolkit is installed and `nvcc` is on your system PATH.\n\nThen run `export CMAKE_ARGS=\"-D CMAKE_CUDA_COMPILER=$(which nvcc)\"` and reinstall.\n\n### hwloc Not Found\n\nOn a Debian-based system, run `sudo apt install libhwloc-dev`. Otherwise, build hwloc from source (https://www.open-mpi.org/projects/hwloc/):\n\n```bash\nwget https://download.open-mpi.org/release/hwloc/v2.12/hwloc-2.12.2.tar.gz\ntar -xzf hwloc-2.12.2.tar.gz\ncd hwloc-2.12.2\n./configure\nmake\nsudo make install\n```\n\n## Weight Quantization\n\nFor AMX backends (`AMXINT4` / `AMXINT8`), CPU-side experts must be converted to AMX-friendly INT4/INT8 format using the provided script:\n\n```bash\npython scripts/convert_cpu_weights.py \\\n  --input-path /path/to/model \\\n  --input-type bf16 \\\n  --output /path/to/output \\\n  --quant-method int4\n```\n\n**Supported formats:** FP8, FP16, BF16 → INT4/INT8\n\nFor the LLAMAFILE backend (`LLAMAFILE`), CPU-side experts are loaded directly from **GGUF** weights. You do **not** need to run the AMX conversion script; instead, download a GGUF model (e.g., from a GGUF repo on Hugging Face) and point `weight_path` / SGLang `--kt-weight-path` (or `--model` when appropriate) to that GGUF directory. 
KT-Kernel supports multiple GGUF quantization types such as `Q4_K_M`, `Q4_K`, `Q5_K`, etc.\n\n---\n\nFor detailed documentation, advanced options, and low-memory mode, see [scripts/README.md](scripts/README.md).\n\n## Before Commit!\n\nCommit messages should follow the Conventional Commits specification: https://www.conventionalcommits.org/\n\nPlease format your code before committing:\n\n```shell\ncmake -B build\ncd build\nmake format\n```\n\nYou may need a newer clang-format (at least version 18). In a conda environment:\n\n```shell\nconda install -c conda-forge clang-format=18\nrm -rf build\n```\n\nIt's also recommended to install black for Python code formatting:\n\n```shell\nconda install black\n```\n"
  },
  {
    "path": "kt-kernel/README_zh.md",
    "content": "# KT-Kernel\n\n高性能 KTransformers 内核库，提供面向 CPU 的高效 MoE 推理内核，支持 AMX 和 AVX 等后端。\n\n- [KT-Kernel](#kt-kernel)\n  - [说明](#说明)\n  - [特性](#特性)\n  - [安装](#安装)\n    - [先决条件](#先决条件)\n    - [快速安装（推荐）](#快速安装推荐)\n    - [手动配置（进阶）](#手动配置进阶)\n  - [验证安装](#验证安装)\n  - [与 SGLang 集成](#与-sglang-集成)\n    - [安装步骤](#安装步骤)\n      - [1. 安装 SGLang](#1-安装-sglang)\n      - [2. 准备权重](#2-准备权重)\n      - [3. 启动 SGLang Server](#3-启动-sglang-server)\n    - [完整示例：Qwen3-30B-A3B](#完整示例qwen3-30b-a3b)\n      - [方案 A：AMX 后端（AMXINT8）](#方案-aamx-后端amxint8)\n      - [方案 B：LLAMAFILE 后端（GGUF）](#方案-bllamafile-后端gguf)\n    - [KT-Kernel 参数](#kt-kernel-参数)\n  - [直接使用 Python API](#直接使用-python-api)\n    - [高级选项](#高级选项)\n  - [构建配置](#构建配置)\n    - [手动安装](#手动安装)\n      - [1. 安装系统依赖](#1-安装系统依赖)\n      - [2. 配置构建参数](#2-配置构建参数)\n      - [3. 构建并安装](#3-构建并安装)\n  - [错误排查](#错误排查)\n    - [找不到 CUDA](#找不到-cuda)\n    - [找不到 hwloc](#找不到-hwloc)\n  - [权重量化](#权重量化)\n  - [提交前必读](#提交前必读)\n\n## 说明\n\n**当前支持状态：**\n- ✅ **带 AMX 的 Intel CPU**：已支持（基于转换为 INT4/INT8 格式的权重）\n- ✅ **通用 CPU（llamafile 后端）**：已支持（基于 GGUF 格式的权重）\n- ✅ **带 BLIS 的 AMD CPU**：已支持（int8 的 prefill 和 decode）\n- ✅ **Kimi-K2 原生 INT4（RAWINT4）**：支持 AVX512 CPU（CPU-GPU 共享 INT4 权重）- [使用指南](../doc/en/Kimi-K2-Thinking-Native.md)\n\n## 特性\n\n- **CPU 友好的 MoE 内核**：针对指令集优化的高吞吐 MoE 专家内核。\n- **AMX INT4/INT8 后端**：面向支持 AMX 的服务器提供 INT4 / INT8 量化专家推理后端。\n- **Llamafile CPU 后端**：基于 Llamafile 的 AVX2/AVX512 MoE 后端，适用于通用 CPU 部署。\n- **NUMA 感知执行**：为多路 / 多 NUMA 机器设计的线程池和内存布局。\n\n\n## 安装\n\n### 从源码安装（本机使用或自定义构建）\n\n适用于本地安装，或需要 AMD (BLIS)、ARM (KML) 或自定义 CUDA 版本的场景。\n\n#### 先决条件\n\n首先初始化子模块并创建 conda 环境：\n```bash\ngit submodule update --init --recursive\nconda create -n kt-kernel python=3.11 -y\nconda activate kt-kernel\n```\n\n#### 快速安装（推荐）\n\n只需运行安装脚本，它会自动检测 CPU 并优化性能：\n\n```bash\n./install.sh\n```\n\n**自动完成的操作：**\n- 自动检测 CPU 能力（AMX、AVX512_VNNI、AVX512_BF16）\n- 安装系统依赖（`cmake`、`libhwloc-dev`、`pkg-config`）\n- 为**你的 CPU** 构建优化二进制（使用 `-march=native`）\n- **软件回退机制**：为不支持 VNNI/BF16 的 CPU 
自动启用\n\n**可选：分步安装**\n```bash\n./install.sh deps   # 仅安装依赖\n./install.sh build  # 构建并安装 kt-kernel\n```\n\n**不同后端的 CPU 要求：**\n\n| 后端 | 最低 CPU 要求 | 示例 CPU | 说明 |\n|------|---------------|----------|------|\n| **LLAMAFILE** | AVX2 | Intel Haswell (2013+)、AMD Zen+ | 通用兼容性 |\n| **RAWINT4** | AVX512F + AVX512BW | Intel Skylake-X (2017+)、Ice Lake、Cascade Lake | 支持 VNNI/BF16 软件回退 |\n| **AMXINT4/INT8** | AMX | Intel Sapphire Rapids (2023+) | 最佳性能，需要 AMX 硬件 |\n\n**软件回退支持（AVX512 后端）：**\n- ✅ VNNI 回退：使用 AVX512BW 指令\n- ✅ BF16 回退：使用 AVX512F 指令\n- ✅ 老的 AVX512 CPU（Skylake-X、Cascade Lake）可以运行 RAWINT4（使用回退）\n\n⚠️ **可移植性说明：** 默认构建针对你的特定 CPU 优化，可能无法在不同/更老的 CPU 上运行。如需打包分发或跨机器部署，请参见下方的 [手动配置](#手动配置进阶)。\n\n⚠️ **AMD BLIS 后端用户：** 请参见 [安装指南](https://github.com/kvcache-ai/ktransformers/issues/1601) 了解 AMD 专用配置。\n\n## 验证安装\n\n```bash\npython -c \"from kt_kernel import KTMoEWrapper; print('✓ kt-kernel installed successfully')\"\n```\n\n## 与 SGLang 集成\n\nKT-Kernel 可以单独通过 [Python API](#直接使用-python-api) 使用，也可以集成到 SGLang 中用于生产部署。  \n本节描述如何与 SGLang 集成，实现 CPU-GPU 混合（异构）推理：将“热” experts 放在 GPU 上，“冷” experts 放在 CPU 上，以达到资源利用和性价比的平衡。\n\n### 安装步骤\n\n#### 1. 安装 SGLang\n\n安装 kvcache-ai 分支的 SGLang（kt-kernel 需要此分支）：\n\n```bash\n# 方式 A: 一键安装（从 ktransformers 根目录，同时安装 sglang + kt-kernel）\n./install.sh\n\n# 方式 B: pip 安装\npip install sglang-kt\n\n# 方式 C: 从源码安装（可编辑模式）\ngit clone --recursive https://github.com/kvcache-ai/ktransformers.git\ncd ktransformers\npip install -e \"third_party/sglang/python[all]\"\n```\n\n> **重要:** 请使用 `sglang-kt`（kvcache-ai 分支），而非官方 `sglang` 包。如已安装官方版本，请先卸载：`pip uninstall sglang -y`\n\n#### 2. 
准备权重\n\n要进行异构推理，需要同时准备 GPU 权重和 CPU 侧 experts 对应的权重，具体格式取决于后端类型：\n\n**GPU 权重：**  \n使用 SGLang 所需的模型权重（例如 Hugging Face 上的原始模型目录或已量化好的 GPU 权重）。\n\n**CPU 权重（AMX 后端：`AMXINT4` / `AMXINT8`）：**  \n通过提供的脚本将权重量化为适配 AMX 的 INT4/INT8 格式：\n\n```bash\npython scripts/convert_cpu_weights.py \\\n  --input-path /path/to/model \\\n  --input-type bf16 \\\n  --output /path/to/cpu-weights \\\n  --quant-method int8  # 或 int4 或 moe_int8（用于 amd 的）\n```\n\n- `--input-path`：GPU 侧原始权重路径\n- `--input-type`：取决于 GPU 侧权重类型（`fp8`、`fp16` 或 `bf16`）\n\n在 SGLang 集成中，`--kt-weight-path` 应指向该转换后的 CPU 权重目录。\n\n**支持的输入格式：** FP8、FP16、BF16 → INT4/INT8。\n\n**CPU 权重（LLAMAFILE 后端：`LLAMAFILE`）：**  \nLLAMAFILE 在 CPU 侧直接使用预量化的 **GGUF** 权重，无需运行 `convert_cpu_weights.py`。你需要：\n\n- 直接从互联网上下载 GGUF 模型（例如 Hugging Face / Modelscope 上的 GGUF 仓库）；\n- 在 SGLang 集成中，将该 GGUF 目录作为 `--kt-weight-path`。\n  KT-Kernel 支持多种 GGUF 量化格式，例如 `Q4_KM`、`Q4_K`、`Q5_K` 等，可根据延迟和效果需求选择。\n\n#### 3. 启动 SGLang Server\n\n在通常的 SGLang 启动参数基础上，增加如下 KT-Kernel 相关参数，以启用 CPU-GPU 异构推理：\n\n**需要增加的 KT-Kernel 参数：**\n- `--kt-method`：后端类型（AMXINT4、AMXINT8、或 LLAMAFILE）\n- `--kt-weight-path`：转换后的 CPU 权重路径\n- `--kt-cpuinfer`：CPU 推理线程数（建议设为物理核数）\n- `--kt-threadpool-count`：线程池数量（建议设为 NUMA 节点个数）\n- `--kt-num-gpu-experts`：留在 GPU 上的 experts 数量\n- `--kt-max-deferred-experts-per-token`：每个 token 延迟到 CPU 的 experts 数量，用于流水线执行\n\n示例：\n```bash\npython -m sglang.launch_server \\\n  [your normal SGLang parameters...] 
\\\n  --kt-method AMXINT8 \\\n  --kt-weight-path /path/to/cpu-weights \\\n  --kt-cpuinfer 64 \\\n  --kt-threadpool-count 2 \\\n  --kt-num-gpu-experts 32 \\\n  --kt-max-deferred-experts-per-token 2\n```\n\n更多调优建议见 [KT-Kernel 参数](#kt-kernel-参数) 一节。\n\n### 完整示例：Qwen3-30B-A3B\n\n该示例展示从下载权重到启动服务的完整流程，分别演示 **AMX 后端** 和 **LLAMAFILE 后端** 两种方案。\n\n**硬件配置：**\n- **GPU**：NVIDIA RTX 4090 24GB\n- **CPU**：2x Intel Xeon Gold 6454S（共 64 个物理核，128 线程，2 个 NUMA 节点）\n- **模型**：[Qwen3-30B-A3B](https://huggingface.co/Qwen/Qwen3-30B-A3B)\n\n**如何检查系统配置：**\n```bash\n# 查看 CPU 配置\nlscpu | grep -E \"^CPU\\(s\\)|Thread\\(s\\) per core|Socket\\(s\\)|NUMA node\\(s\\)\"\n# 期望输出示例:\nCPU(s):                                  128\nThread(s) per core:                      2\nSocket(s):                               2\nNUMA node(s):                            2\n# → 物理核数 = CPU(s) / Thread(s) per core = 128 / 2 = 64\n```\n\n**参数选型说明：**\n- `--kt-cpuinfer 64`：设为物理核数（64），而不是 128 线程\n- `--kt-threadpool-count 2`：检测到 2 个 NUMA 节点（双路系统）\n- `--kt-num-gpu-experts 32`：在 24GB 显存下，对该模型可以大约放 32 个 experts 在 GPU 上（具体取决于模型结构和实际内存占用）\n- `--kt-max-deferred-experts-per-token 2`：启用流水线执行；允许 CPU 处理下一批 token 的同时，GPU 完成当前批次\n\n---\n\n#### 方案 A：AMX 后端（AMXINT8）\n\n适用于支持 AMX 指令集的 Intel CPU。\n\n**步骤 1：下载模型权重**\n\n```bash\n# 如未安装 huggingface-cli，请先安装\npip install huggingface-hub\n\n# 从 Hugging Face 下载模型\nhuggingface-cli download Qwen/Qwen3-30B-A3B --local-dir /mnt/data/models/Qwen3-30B-A3B\n```\n\n**步骤 2：转换为 CPU 权重（AMXINT8）**\n\n```bash\npython scripts/convert_cpu_weights.py \\\n  --input-path /mnt/data/models/Qwen3-30B-A3B \\\n  --input-type bf16 \\\n  --output /mnt/data/models/Qwen3-30B-A3B-INT8 \\\n  --quant-method int8\n```\n\n**步骤 3：启动 SGLang 服务**\n\n```bash\npython -m sglang.launch_server \\\n  --host 0.0.0.0 \\\n  --port 8000 \\\n  --model /mnt/data/models/Qwen3-30B-A3B \\\n  --trust-remote-code \\\n  --mem-fraction-static 0.92 \\\n  --chunked-prefill-size 4096 \\\n  --served-model-name Qwen3-30B-A3B \\\n  --enable-mixed-chunk 
\\\n  --kt-method AMXINT8 \\\n  --kt-weight-path /mnt/data/models/Qwen3-30B-A3B-INT8 \\\n  --kt-cpuinfer 64 \\\n  --kt-threadpool-count 2 \\\n  --kt-num-gpu-experts 32 \\\n  --kt-max-deferred-experts-per-token 2\n```\n\n---\n\n#### 方案 B：LLAMAFILE 后端（GGUF）\n\n适用于通用 CPU（无需 AMX 支持），直接使用预量化的 GGUF 权重。\n\n**步骤 1：下载 GPU 权重（原始模型）**\n\n```bash\npip install huggingface-hub\n\nhuggingface-cli download Qwen/Qwen3-30B-A3B --local-dir /mnt/data/models/Qwen3-30B-A3B\n```\n\n**步骤 2：下载 CPU 权重（GGUF 格式）**\n\n```bash\nhuggingface-cli download Qwen/Qwen3-30B-A3B-GGUF Qwen3-30B-A3B-Q4_K_M.gguf \\\n  --local-dir /mnt/data/models/Qwen3-30B-A3B-Q4_K_M\n```\n\n**步骤 3：启动 SGLang 服务**\n\n```bash\npython -m sglang.launch_server \\\n  --host 0.0.0.0 \\\n  --port 8000 \\\n  --model /mnt/data/models/Qwen3-30B-A3B \\\n  --trust-remote-code \\\n  --mem-fraction-static 0.92 \\\n  --chunked-prefill-size 4096 \\\n  --served-model-name Qwen3-30B-A3B \\\n  --enable-mixed-chunk \\\n  --kt-method LLAMAFILE \\\n  --kt-weight-path /mnt/data/models/Qwen3-30B-A3B-Q4_K_M \\\n  --kt-cpuinfer 64 \\\n  --kt-threadpool-count 2 \\\n  --kt-num-gpu-experts 32 \\\n  --kt-max-deferred-experts-per-token 2\n```\n\n### KT-Kernel 参数\n\n| 参数 | 描述 | 示例值 |\n|------|------|--------|\n| `--kt-method` | CPU 推理后端类型 | `AMXINT4`、`AMXINT8`、`RAWINT4` 或 `LLAMAFILE` |\n| `--kt-weight-path` | 量化后的 CPU 权重路径 | `/path/to/cpu-weights` |\n| `--kt-cpuinfer` | CPU 推理线程数 | `64`（根据 CPU 核心数调整） |\n| `--kt-threadpool-count` | 并行执行的线程池数量 | `2`（通常为 1–4） |\n| `--kt-num-gpu-experts` | 保留在 GPU 上的 experts 数量 | `32`（其余 experts 由 CPU 承担） |\n| `--kt-max-deferred-experts-per-token` | 每个 token 延迟到 CPU 的 experts 数量（用于流水线执行） | `2`（0 关闭，1–4 推荐） |\n| `--kt-gpu-prefill-token-threshold` | Prefill 策略的 token 数量阈值（仅 RAWINT4） | ~`400` |\n\n**参数建议：**\n\n- **`kt-method`**：根据 CPU 能力和权重格式选择：\n  - `AMXINT4`：在 AMX CPU 上 INT4 量化时具有最佳性能（但可能对某些模型有较大精度影响，例如 Qwen3-30B-A3B）\n  - `AMXINT8`：在 AMX CPU 上提供更高精度的 INT8 量化方案\n  - `RAWINT4`：CPU 和 GPU 共享原生 INT4 权重（仅限 AMX 后端，目前仅支持 
Kimi-K2-Thinking 模型）。详见 [Kimi-K2-Thinking 原生推理教程](../doc/en/Kimi-K2-Thinking-Native.md)。\n  - `LLAMAFILE`：基于 AVX2/AVX512 的通用 CPU 后端，性能较 AMX 略低，但适用范围更广\n\n- **`kt-cpuinfer`**：设置为 **物理核数**（不是线程数）。\n  - 查看物理核数：`lscpu | grep -E \"^CPU\\(s\\)|Thread\\(s\\) per core\"`\n  - 计算方式：物理核数 = CPU(s) / Thread(s) per core\n  - 例：若 CPU(s)=128 且 Thread(s) per core=2，则物理核数=64\n  - **重要**：不要设置为超线程总数，否则会降低性能\n\n- **`kt-threadpool-count`**：设置为 **NUMA 节点数**。\n  - 查看 NUMA 数：`lscpu | grep \"NUMA node(s)\"`\n  - 或：`numactl --hardware | grep \"available\"`\n  - **注意**：NUMA 节点数不等同于物理 CPU 数量：\n    - 它表示内存域，可能在单颗 CPU 内被拆分，也可能跨多颗 CPU。\n    - 请以 `lscpu` 输出的 NUMA 节点数为准。\n  - 常见配置：单路 1–2，双路 2–4\n  - 正确设置有助于充分利用跨 NUMA 域的内存带宽。\n\n- **`kt-num-gpu-experts`**：根据 GPU 显存和实际性能测试决定：\n  - GPU 上的 experts 越多 → 延迟越低，但显存占用越高（可能 OOM）\n\n- **`kt-max-deferred-experts-per-token`**：用于开启 CPU-GPU 流水线：\n  - `0`：完全同步执行（简单但延迟较高）\n  - `1–4`：推荐范围，一部分 experts 延迟到 CPU，在延迟和质量之间取得较好平衡（需要按模型调参）\n  - `5–7`：可以获得更低延迟，但存在明显精度下降风险，请谨慎使用\n\n- **`kt-gpu-prefill-token-threshold`**（仅 RAWINT4）：控制原生 INT4 推理的 prefill 策略：\n  - **≤ 阈值**：使用 CPU+GPU 混合 prefill。无需额外显存，但随着 token 数量增加性能会缓慢下降。\n  - **> 阈值**：使用分层 GPU prefill。长序列性能更好，但需要约 9GB+ 额外显存。\n  - 仅在使用 `--kt-method RAWINT4` 时生效。目前仅支持 Kimi-K2-Thinking 模型。\n\n## 直接使用 Python API\n\n如果不集成 SGLang，也可以直接通过 Python API 单独使用 KT-Kernel：\n\n```python\nfrom kt_kernel import KTMoEWrapper\n\n# 初始化 MoE 包装器\nwrapper = KTMoEWrapper(\n    layer_idx=0,\n    num_experts=8,\n    num_experts_per_tok=2,\n    hidden_size=4096,\n    moe_intermediate_size=14336,\n    num_gpu_experts=2,\n    cpuinfer_threads=32,\n    threadpool_count=2,\n    weight_path=\"/path/to/weights\",\n    chunked_prefill_size=512,\n    method=\"AMXINT4\"  # 选项: \"AMXINT4\", \"AMXINT8\", \"LLAMAFILE\"\n)\n\n# 从磁盘加载权重（预先量化好）\nwrapper.load_weights(physical_to_logical_map)\n\n# 或者从张量加载权重（在线量化）\nwrapper.load_weights_from_tensors(gate_proj, up_proj, down_proj, physical_to_logical_map)\n\n# 执行推理\noutput = wrapper.forward(hidden_states, topk_ids, 
topk_weights, cuda_stream)\n\n# 或使用异步 API 获取更好的流水线效果\nwrapper.submit_forward(hidden_states, topk_ids, topk_weights, cuda_stream)\n# ... 做一些其他工作 ...\noutput = wrapper.sync_forward(hidden_states, cuda_stream)\n```\n\n### 高级选项\n\n```python\n# 使用更多高级选项初始化\nwrapper = KTMoEWrapper(\n    layer_idx=0,\n    num_experts=8,\n    num_experts_per_tok=2,\n    hidden_size=4096,\n    moe_intermediate_size=14336,\n    num_gpu_experts=2,\n    cpuinfer_threads=32,\n    threadpool_count=2,\n    weight_path=\"/path/to/weights\",\n    chunked_prefill_size=512,\n    method=\"AMXINT4\",\n    cpu_save=False,  # 加载后是否将权重常驻 CPU 内存\n    max_deferred_experts_per_token=0  # 每个 token 延迟的 experts 数量（用于流水线）\n)\n\n# 为特定 batch size 预分配缓冲区（提升性能）\nKTMoEWrapper.set_capture_batch_sizes([1, 2, 4, 8, 16])\n\n# 查看当前捕获的 batch size\nbatch_sizes = KTMoEWrapper.get_capture_batch_sizes()\n\n# 清理缓冲区缓存以释放内存\nKTMoEWrapper.clear_buffer_cache()\n```\n\n### 手动配置（进阶）\n\n如需打包分发、跨机器部署或构建可移植二进制，需要手动指定目标指令集：\n\n```bash\n# 通用分发版（适用于 2017+ 的任何 AVX512 CPU）\nexport CPUINFER_CPU_INSTRUCT=AVX512\nexport CPUINFER_ENABLE_AMX=OFF\n./install.sh build --manual\n\n# 最大兼容性（适用于 2013+ 的任何 CPU）\nexport CPUINFER_CPU_INSTRUCT=AVX2\nexport CPUINFER_ENABLE_AMX=OFF\n./install.sh build --manual\n\n# 仅限现代 CPU（Ice Lake+、Zen 4+）\nexport CPUINFER_CPU_INSTRUCT=FANCY\nexport CPUINFER_ENABLE_AMX=OFF\n./install.sh build --manual\n```\n\n**可选：覆盖 VNNI/BF16 检测**\n```bash\n# 强制启用/禁用 VNNI 和 BF16（用于测试回退）\nexport CPUINFER_ENABLE_AVX512_VNNI=OFF\nexport CPUINFER_ENABLE_AVX512_BF16=OFF\n./install.sh\n```\n\n运行 `./install.sh --help` 查看所有可用选项。\n\n---\n\n## 构建配置\n\n### 手动安装（不使用 install.sh）\n\n如果你不想使用 `install.sh` 脚本：\n\n#### 1. 安装系统依赖\n\n**前置依赖：**\n- `cmake`（推荐：`conda install -y cmake`）\n- `libhwloc-dev` 和 `pkg-config`\n\n#### 2. 
配置构建参数\n\n**核心选项：**\n\n| 变量 | 取值 | 描述 |\n|------|------|------|\n| `CPUINFER_CPU_INSTRUCT` | `NATIVE`, `AVX512`, `AVX2`, `FANCY` | 使用的 CPU 指令集 |\n| `CPUINFER_ENABLE_AMX` | `ON`, `OFF` | 是否启用 Intel AMX 支持 |\n| `CPUINFER_BUILD_TYPE` | `Release`, `Debug`, `RelWithDebInfo` | 构建类型（默认：`Release`） |\n| `CPUINFER_PARALLEL` | 数值 | 并行构建的 Job 数（默认：自动检测） |\n| `CPUINFER_VERBOSE` | `0`, `1` | 是否启用详细构建日志（默认：`0`） |\n\n**指令集说明：**\n\n| 选项 | 目标 CPU | 使用场景 |\n|------|----------|----------|\n| **`NATIVE`** | 仅限你的特定 CPU | 本地构建（最佳性能，**默认**） |\n| **`AVX512`** | Skylake-X、Ice Lake、Cascade Lake、Zen 4+ | 通用分发 |\n| **`AVX2`** | Haswell (2013) 及更新 | 最大兼容性 |\n| **`FANCY`** | Ice Lake+、Zen 4+ | 具有完整 AVX512 扩展的现代 CPU |\n\n**配置示例：**\n\n```bash\n# 本地使用 - 最高性能（默认行为）\nexport CPUINFER_CPU_INSTRUCT=NATIVE\nexport CPUINFER_ENABLE_AMX=ON  # 或 OFF\n\n# 分发构建 - 适用于任何 AVX512 CPU\nexport CPUINFER_CPU_INSTRUCT=AVX512\nexport CPUINFER_ENABLE_AMX=OFF\n\n# 最大兼容性 - 适用于 2013 年以来的 CPU\nexport CPUINFER_CPU_INSTRUCT=AVX2\nexport CPUINFER_ENABLE_AMX=OFF\n\n# 调试构建\nexport CPUINFER_BUILD_TYPE=Debug\nexport CPUINFER_VERBOSE=1\n```\n\n#### 3. 
构建并安装\n\n```bash\n# 开发模式（可编辑安装）\npip install -e .\n\n# 普通安装\npip install .\n```\n\n## 错误排查\n\n### 找不到 CUDA\n\n```\n -- Looking for a CUDA compiler - NOTFOUND\n  CMake Error at CMakeLists.txt:389 (message):\n    KTRANSFORMERS_USE_CUDA=ON but CUDA compiler not found\n```\n\n请确认已安装 CUDA Toolkit 且 `nvcc` 在系统 PATH 中。\n\n可以尝试：\n\n```bash\nexport CMAKE_ARGS=\"-D CMAKE_CUDA_COMPILER=$(which nvcc)\"\npip install .\n```\n\n然后重新安装。\n\n### 找不到 hwloc\n\n在 Debian 系发行版上可以直接：\n\n```bash\nsudo apt install libhwloc-dev\n```\n\n或从源码构建：https://www.open-mpi.org/projects/hwloc/\n\n```bash\nwget https://download.open-mpi.org/release/hwloc/v2.12/hwloc-2.12.2.tar.gz\ntar -xzf hwloc-2.12.2.tar.gz\ncd hwloc-2.12.2\n./configure\nmake\nsudo make install\n```\n\n## 权重量化\n\n对于 AMX 后端（`AMXINT4` / `AMXINT8`），CPU 侧 experts 需要通过提供的脚本转换为适配 AMX 的 INT4/INT8 格式：\n\n```bash\npython scripts/convert_cpu_weights.py \\\n  --input-path /path/to/model \\\n  --input-type bf16 \\\n  --output /path/to/output \\\n  --quant-method int4\n```\n\n**支持的格式：** FP8、FP16、BF16 → INT4/INT8\n\n对于 LLAMAFILE 后端（`LLAMAFILE`），CPU 侧 experts 直接从 **GGUF** 权重中加载。  \n你**不需要**运行 AMX 转换脚本；只需从互联网上下载 GGUF 模型（例如 Hugging Face 上的 GGUF 仓库），并在 `weight_path` 或 SGLang 的 `--kt-weight-path` / `--model` 中指向该 GGUF 目录即可。KT-Kernel 支持多种 GGUF 量化格式，如 `Q4_KM`、`Q4_K`、`Q5_K` 等。\n\n---\n\n更多详细文档、高级参数和低显存模式，请参见 [scripts/README.md](scripts/README.md)。\n\n## 提交前必读\n\n提交信息应符合 Conventional Commits 规范：https://www.conventionalcommits.org/  \n在提交前请先格式化代码：\n\n```shell\ncmake -B build\ncd build\nmake format\n```\n\n你可能需要一个较新的 clang-format（至少 18），在 conda 环境中可以：\n\n```shell\nconda install -c conda-forge clang-format=18\nrm -rf build\n```\n\n并且建议安装 black 用于 Python 代码格式化：\n\n```shell\nconda install black\n```\n"
  },
  {
    "path": "kt-kernel/bench/.gitignore",
    "content": "*.jsonl\n*.json"
  },
  {
    "path": "kt-kernel/bench/Makefile",
    "content": "# test bench_moe_kernel_tiling.py\nkernel_tiling:\n\tpython3 bench_moe_kernel_tiling.py \\\n\t--hidden_size 7168 \\\n\t--intermediate_size 2048 \\\n\t--num_experts_per_tok 8 \\\n\t--expert_num 256 \\\n\t--max_len 51200 \\\n\t--layer_num 1 \\\n\t--qlen 1024 \\\n\t--quant int8 \\\n\t--warm_up_iter 500 \\\n\t--test_iter 1000 \\\n\t--threads 160 \\\n\t--m_block 320 \\\n\t\n# \t--n_block_up_gate 256 \\\n# \t--n_block_down 128 \\\n# \t--n_block_up_gate_prefi 256 \\\n# \t--n_block_down_prefi 128 \\\n\n# \t--n_block_up_gate 256 \\\n# \t--n_block_down 512 \\"
  },
  {
    "path": "kt-kernel/bench/bench_attention.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nDescription  :\nAuthor       : Jianwei Dong\nDate         : 2024-08-28 10:32:05\nVersion      : 1.0.0\nLastEditors  : Jianwei Dong\nLastEditTime : 2024-08-28 10:32:05\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved.\n\"\"\"\nimport os, sys\nimport time\n\nsys.path.append(os.path.dirname(__file__) + \"/../build\")\nfrom kt_kernel import kt_kernel_ext\nimport torch\n\nlayer_num = 10\nkv_head_num = 8\nq_head_num = 32\nhead_dim = 128\nblock_len = 128\nanchor_num = 1\n\nanchor_type = kt_kernel_ext.kvcache.AnchorType.DYNAMIC\nkv_type = kt_kernel_ext.kvcache.ggml_type.FP16\nretrieval_type = kt_kernel_ext.kvcache.RetrievalType.LAYER\nlayer_step: int = 1\ntoken_step: int = 1\nlayer_offset: int = 0\nmax_thread_num: int = 64\nmax_batch_size: int = 1\nmax_block_num: int = 1024\nCPUInfer = kt_kernel_ext.CPUInfer(max_thread_num)\n\nwarm_up_iter = 1000\ntest_iter = 10000\n\n\ndef bench_linear(cache_seqlen: int):\n    with torch.inference_mode(mode=True):\n        cache_seqlens = torch.tensor([cache_seqlen], dtype=torch.int32, device=\"cpu\")\n        seqlens_zero = torch.zeros((1,), dtype=torch.int32, device=\"cpu\")\n\n        config = kt_kernel_ext.kvcache.KVCacheConfig(\n            layer_num,\n            kv_head_num,\n            q_head_num,\n            head_dim,\n            block_len,\n            anchor_num,\n            anchor_type,\n            kv_type,\n            retrieval_type,\n            layer_step,\n            token_step,\n            layer_offset,\n            max_block_num,\n            max_batch_size,\n            max_thread_num,\n        )\n        local_kvcache = kt_kernel_ext.kvcache.KVCache(config)\n        block_table = torch.arange(max_block_num, dtype=torch.int32, device=\"cpu\").contiguous().view(1, -1)\n\n        for layer_idx in range(layer_num):\n            k_cache = torch.randn(\n                (1, cache_seqlen, kv_head_num, head_dim),\n                dtype=torch.float16,\n 
               device=\"cpu\",\n            ).contiguous()\n            v_cache = torch.randn(\n                (1, cache_seqlen, kv_head_num, head_dim),\n                dtype=torch.float16,\n                device=\"cpu\",\n            ).contiguous()\n\n            CPUInfer.submit(\n                local_kvcache.update_kvcache_fp16(\n                    k_cache.data_ptr(),\n                    v_cache.data_ptr(),\n                    layer_idx,\n                    block_table.data_ptr(),\n                    1,\n                    max_block_num,\n                    seqlens_zero.data_ptr(),\n                    cache_seqlen,\n                )\n            )\n            CPUInfer.sync()\n\n        input = torch.randn((1, 1, q_head_num, head_dim), dtype=torch.float16, device=\"cpu\").contiguous()\n        output = torch.empty((1, 1, q_head_num, head_dim), dtype=torch.float16, device=\"cpu\").contiguous()\n\n        # attn_lse: (bsz, q_len, q_head_num)\n        attn_lse = torch.empty((1, 1, q_head_num), dtype=torch.float32, device=\"cpu\").contiguous()\n        input = input / 100\n\n        # warm up\n        for i in range(warm_up_iter):\n            CPUInfer.submit(\n                local_kvcache.attn(\n                    input.data_ptr(),\n                    output.data_ptr(),\n                    attn_lse.data_ptr(),\n                    i % layer_num,\n                    0,\n                    1,\n                    1,\n                    max_block_num,\n                    block_table.data_ptr(),\n                    cache_seqlens.data_ptr(),\n                    -1,\n                    -1,\n                    -1,\n                )\n            )\n            CPUInfer.sync()\n\n        # test\n        start = time.perf_counter()\n        for i in range(test_iter):\n            CPUInfer.submit(\n                local_kvcache.attn(\n                    input.data_ptr(),\n                    output.data_ptr(),\n                    
attn_lse.data_ptr(),\n                    i % layer_num,\n                    0,\n                    1,\n                    1,\n                    max_block_num,\n                    block_table.data_ptr(),\n                    cache_seqlens.data_ptr(),\n                    -1,\n                    -1,\n                    -1,\n                )\n            )\n            CPUInfer.sync()\n        end = time.perf_counter()\n        total_time = end - start\n        print(\"cache sequence length: \", cache_seqlen)\n        print(\"Time(s): \", total_time)\n        print(\"Iteration: \", test_iter)\n        print(\"Time(us) per iteration: \", total_time / test_iter * 1000000)\n        print(\n            \"Bandwidth: \",\n            cache_seqlen * kv_head_num * head_dim * 2 * 2 * test_iter / total_time / 1000 / 1000 / 1000,\n            \"GB/s\",\n        )\n        print(\"\")\n\n\nbench_linear(1024)\nbench_linear(4096)\nbench_linear(16384)\nbench_linear(32768)\nbench_linear(65536)\n"
  },
  {
    "path": "kt-kernel/bench/bench_attention_torch.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nDescription  :\nAuthor       : Jianwei Dong\nDate         : 2024-08-28 10:32:05\nVersion      : 1.0.0\nLastEditors  : Jianwei Dong\nLastEditTime : 2024-08-28 10:32:05\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved.\n\"\"\"\nimport os, sys\nimport time\n\nsys.path.append(os.path.dirname(__file__) + \"/../build\")\nfrom kt_kernel import kt_kernel_ext\nimport torch\n\nlayer_num = 10\nkv_head_num = 8\nq_head_num = 32\nhead_dim = 128\nblock_len = 128\nanchor_num = 1\nwarm_up_iter = 1000\ntest_iter = 10000\n\n\ndef bench_linear(cache_seqlen: int, device):\n    with torch.inference_mode(mode=True):\n\n        kvcaches = []\n\n        for layer_idx in range(layer_num):\n            k_cache = torch.randn(\n                (1, 32, cache_seqlen, head_dim),\n                dtype=torch.float16,\n                device=device,\n            ).contiguous()\n            v_cache = torch.randn(\n                (1, 32, cache_seqlen, head_dim),\n                dtype=torch.float16,\n                device=device,\n            ).contiguous()\n\n            kvcaches.append((k_cache, v_cache))\n\n        input = torch.randn((1, q_head_num, 1, head_dim), dtype=torch.float16, device=device).contiguous()\n        input = input / 100\n\n        # warm up\n        for i in range(warm_up_iter):\n            k_cache = kvcaches[i % layer_num][0]\n            v_cache = kvcaches[i % layer_num][1]\n            torch.nn.functional.scaled_dot_product_attention(input, k_cache, v_cache)\n\n        # test\n        start = time.perf_counter()\n        for i in range(test_iter):\n            k_cache = kvcaches[i % layer_num][0]\n            v_cache = kvcaches[i % layer_num][1]\n            torch.nn.functional.scaled_dot_product_attention(input, k_cache, v_cache)\n        end = time.perf_counter()\n        total_time = end - start\n        print(\"cache sequence length: \", cache_seqlen)\n        print(\"Time(s): \", total_time)\n      
  print(\"Iteration: \", test_iter)\n        print(\"Time(us) per iteration: \", total_time / test_iter * 1000000)\n        print(\n            \"Bandwidth: \",\n            cache_seqlen * q_head_num * head_dim * 2 * 2 * test_iter / total_time / 1000 / 1000 / 1000,\n            \"GB/s\",\n        )\n        print(\"\")\n\n\nbench_linear(1024, \"cpu\")\nbench_linear(4096, \"cpu\")\nbench_linear(1024, \"cuda\")\nbench_linear(4096, \"cuda\")\nbench_linear(16384, \"cuda\")\nbench_linear(32768, \"cuda\")\nbench_linear(65536, \"cuda\")\n"
  },
  {
    "path": "kt-kernel/bench/bench_bf16_moe.py",
    "content": "\"\"\"\nPerformance benchmark for native BF16 MoE kernel (AMX implementation).\n\nThis benchmark measures the performance of the BF16 MoE operator with:\n- Native BF16 weights (no quantization)\n- BF16 activations\n- AMX BF16 DPBF16PS compute path\n\"\"\"\n\nimport os\nimport sys\nimport time\nimport json\nimport subprocess\nimport platform\n\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\", \"build\"))\n\nimport torch\nfrom kt_kernel import kt_kernel_ext\nfrom tqdm import tqdm\n\n# Test parameters\nexpert_num = 256\nhidden_size = 7168\nintermediate_size = 2048\nnum_experts_per_tok = 8\nmax_len = 25600\n\nlayer_num = 5\nqlen = 1\nwarm_up_iter = 100\ntest_iter = 3000\nCPUINFER_PARAM = 80\n\nCPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)\n\n# Result file path\nscript_path = os.path.abspath(__file__)\nscript_dir = os.path.dirname(script_path)\njson_path = os.path.join(script_dir, \"bench_bf16_moe.jsonl\")\n\n\ndef get_git_commit():\n    \"\"\"Get current git commit info\"\"\"\n    result = {}\n    try:\n        commit = subprocess.check_output([\"git\", \"rev-parse\", \"HEAD\"]).decode(\"utf-8\").strip()\n        commit_msg = subprocess.check_output([\"git\", \"log\", \"-1\", \"--pretty=%B\"]).decode(\"utf-8\").strip()\n        result[\"commit\"] = commit\n        result[\"commit_message\"] = commit_msg\n        dirty_output = subprocess.check_output([\"git\", \"status\", \"--porcelain\"]).decode(\"utf-8\").strip()\n        result[\"dirty\"] = bool(dirty_output)\n        if dirty_output:\n            result[\"dirty_files\"] = dirty_output.splitlines()\n    except Exception as e:\n        result[\"commit\"] = None\n        result[\"error\"] = str(e)\n    return result\n\n\ndef get_system_info():\n    \"\"\"Get system information\"\"\"\n    info = {}\n    uname = platform.uname()\n    info[\"system_name\"] = uname.system\n    info[\"node_name\"] = uname.node\n\n    cpu_model = None\n    if os.path.exists(\"/proc/cpuinfo\"):\n        
try:\n            with open(\"/proc/cpuinfo\", \"r\") as f:\n                for line in f:\n                    if \"model name\" in line:\n                        cpu_model = line.split(\":\", 1)[1].strip()\n                        break\n        except Exception:\n            pass\n    info[\"cpu_model\"] = cpu_model\n    info[\"cpu_core_count\"] = os.cpu_count()\n    return info\n\n\ndef record_results(result, filename=json_path):\n    \"\"\"Append result to JSON file\"\"\"\n    with open(filename, \"a\") as f:\n        f.write(json.dumps(result) + \"\\n\")\n\n\ndef generate_bf16_weights(shape: tuple):\n    \"\"\"\n    Generate random BF16 weights.\n\n    Args:\n        shape: (expert_num, n, k) - weight tensor shape\n\n    Returns:\n        bf16_weights: bfloat16 tensor with random values\n    \"\"\"\n    # Generate random BF16 weights with small values to avoid overflow\n    weights = (torch.randn(shape, dtype=torch.float32, device=\"cuda\") / 100.0).to(torch.bfloat16).to(\"cpu\").contiguous()\n    return weights\n\n\ndef bench_bf16_moe():\n    \"\"\"Benchmark native BF16 MoE performance\"\"\"\n    with torch.inference_mode():\n        print(\"=\" * 70)\n        print(\"Native BF16 MoE Kernel Performance Benchmark\")\n        print(\"=\" * 70)\n\n        # Generate BF16 weights\n        print(\"\\nGenerating BF16 weights...\")\n        torch.manual_seed(42)\n        gate_proj = generate_bf16_weights((expert_num, intermediate_size, hidden_size))\n        up_proj = generate_bf16_weights((expert_num, intermediate_size, hidden_size))\n        down_proj = generate_bf16_weights((expert_num, hidden_size, intermediate_size))\n\n        physical_to_logical_map = torch.tensor(range(expert_num), device=\"cpu\", dtype=torch.int64).contiguous()\n\n        # Build MoE layers\n        print(\"Building BF16 MoE layers...\")\n        moes = []\n        for _ in tqdm(range(layer_num), desc=\"Initializing MOEs\"):\n            config = kt_kernel_ext.moe.MOEConfig(expert_num, 
num_experts_per_tok, hidden_size, intermediate_size, 0)\n            config.max_len = max_len\n\n            # Set BF16 weight pointers (no scales needed)\n            config.gate_proj = gate_proj.data_ptr()\n            config.up_proj = up_proj.data_ptr()\n            config.down_proj = down_proj.data_ptr()\n\n            # No scales for BF16\n            config.gate_scale = 0\n            config.up_scale = 0\n            config.down_scale = 0\n            config.pool = CPUInfer.backend_\n\n            moe = kt_kernel_ext.moe.AMXBF16_MOE(config)\n            CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n            CPUInfer.sync()\n            moes.append(moe)\n\n        # Generate input data\n        print(\"Generating input data...\")\n        gen_iter = 1000\n        expert_ids = (\n            torch.rand(gen_iter * qlen, expert_num, device=\"cpu\")\n            .argsort(dim=-1)[:, :num_experts_per_tok]\n            .reshape(gen_iter, qlen * num_experts_per_tok)\n            .contiguous()\n        )\n        weights = torch.rand((gen_iter, qlen, num_experts_per_tok), dtype=torch.float32, device=\"cpu\").contiguous()\n        input_tensor = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device=\"cpu\").contiguous()\n        output_tensor = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device=\"cpu\").contiguous()\n        qlen_tensor = torch.tensor([qlen], dtype=torch.int32)\n\n        # Warmup\n        print(f\"Warming up ({warm_up_iter} iterations)...\")\n        for i in tqdm(range(warm_up_iter), desc=\"Warm-up\"):\n            CPUInfer.submit(\n                moes[i % layer_num].forward_task(\n                    qlen_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids[i % gen_iter].data_ptr(),\n                    weights[i % gen_iter].data_ptr(),\n                    input_tensor[i % layer_num].data_ptr(),\n                    output_tensor[i % 
layer_num].data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n\n        # Benchmark\n        print(f\"Running benchmark ({test_iter} iterations)...\")\n        start = time.perf_counter()\n        for i in tqdm(range(test_iter), desc=\"Testing\"):\n            CPUInfer.submit(\n                moes[i % layer_num].forward_task(\n                    qlen_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids[i % gen_iter].data_ptr(),\n                    weights[i % gen_iter].data_ptr(),\n                    input_tensor[i % layer_num].data_ptr(),\n                    output_tensor[i % layer_num].data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n        end = time.perf_counter()\n        total_time = end - start\n\n        # Calculate metrics\n        time_per_iter_us = total_time / test_iter * 1e6\n\n        # FLOPS calculation:\n        # Each expert performs: gate(intermediate x hidden) + up(intermediate x hidden) + down(hidden x intermediate)\n        # GEMM/GEMV: 2 * m * n * k flops (multiply + accumulate = 2 ops per element)\n        flops_per_expert = (\n            2 * intermediate_size * hidden_size  # gate\n            + 2 * intermediate_size * hidden_size  # up\n            + 2 * hidden_size * intermediate_size  # down\n        )\n        total_flops = qlen * num_experts_per_tok * flops_per_expert * test_iter\n        tflops = total_flops / total_time / 1e12\n\n        # Bandwidth calculation (BF16 = 2 bytes per element)\n        bytes_per_elem = 2.0\n        # Weight memory: gate + up + down per expert\n        bandwidth = (\n            hidden_size\n            * intermediate_size\n            * 3\n            * num_experts_per_tok\n            * (1 / num_experts_per_tok * expert_num * (1 - (1 - num_experts_per_tok / expert_num) ** qlen))\n            * bytes_per_elem\n            * test_iter\n            / 
total_time\n            / 1e9\n        )  # GB/s\n\n        # Print results\n        print(\"\\n\" + \"=\" * 70)\n        print(\"Benchmark Results\")\n        print(\"=\" * 70)\n        print(f\"Quant mode: Native BF16 (no quantization)\")\n        print(f\"Total time: {total_time:.4f} s\")\n        print(f\"Iterations: {test_iter}\")\n        print(f\"Time per iteration: {time_per_iter_us:.2f} us\")\n        print(f\"Bandwidth: {bandwidth:.2f} GB/s\")\n        print(f\"TFLOPS: {tflops:.4f}\")\n        print(\"\")\n\n        # Record results\n        result = {\n            \"test_name\": os.path.basename(__file__),\n            \"quant_mode\": \"bf16_native\",\n            \"total_time_seconds\": total_time,\n            \"iterations\": test_iter,\n            \"time_per_iteration_us\": time_per_iter_us,\n            \"bandwidth_GBs\": bandwidth,\n            \"flops_TFLOPS\": tflops,\n            \"timestamp\": time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime()),\n            \"test_parameters\": {\n                \"expert_num\": expert_num,\n                \"hidden_size\": hidden_size,\n                \"intermediate_size\": intermediate_size,\n                \"num_experts_per_tok\": num_experts_per_tok,\n                \"layer_num\": layer_num,\n                \"qlen\": qlen,\n                \"warm_up_iter\": warm_up_iter,\n                \"test_iter\": test_iter,\n                \"CPUInfer_parameter\": CPUINFER_PARAM,\n            },\n        }\n        result.update(get_git_commit())\n        result.update(get_system_info())\n        record_results(result)\n\n        return tflops, bandwidth\n\n\nif __name__ == \"__main__\":\n    bench_bf16_moe()\n"
  },
  {
    "path": "kt-kernel/bench/bench_fp8_moe.py",
    "content": "\"\"\"\nPerformance benchmark for FP8 MoE kernel (AVX implementation).\n\nThis benchmark measures the performance of the FP8 MoE operator with:\n- FP8 (E4M3) weights with 128x128 block-wise scaling\n- BF16 activations\n- AVX-512 DPBF16 compute path\n\"\"\"\n\nimport os\nimport sys\nimport time\nimport json\nimport subprocess\nimport platform\n\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\", \"build\"))\n\nimport torch\nfrom kt_kernel import kt_kernel_ext\nfrom tqdm import tqdm\n\n# Test parameters\nexpert_num = 256\nhidden_size = 7168\nintermediate_size = 2048\nnum_experts_per_tok = 8\nfp8_group_size = 128\nmax_len = 25600\n\nlayer_num = 2\nqlen = 1\nwarm_up_iter = 1000\ntest_iter = 3000\nCPUINFER_PARAM = 80\n\nCPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)\n\n# Result file path\nscript_path = os.path.abspath(__file__)\nscript_dir = os.path.dirname(script_path)\njson_path = os.path.join(script_dir, \"bench_results.jsonl\")\n\n\ndef get_git_commit():\n    \"\"\"Get current git commit info\"\"\"\n    result = {}\n    try:\n        commit = subprocess.check_output([\"git\", \"rev-parse\", \"HEAD\"]).decode(\"utf-8\").strip()\n        commit_msg = subprocess.check_output([\"git\", \"log\", \"-1\", \"--pretty=%B\"]).decode(\"utf-8\").strip()\n        result[\"commit\"] = commit\n        result[\"commit_message\"] = commit_msg\n        dirty_output = subprocess.check_output([\"git\", \"status\", \"--porcelain\"]).decode(\"utf-8\").strip()\n        result[\"dirty\"] = bool(dirty_output)\n        if dirty_output:\n            result[\"dirty_files\"] = dirty_output.splitlines()\n    except Exception as e:\n        result[\"commit\"] = None\n        result[\"error\"] = str(e)\n    return result\n\n\ndef get_system_info():\n    \"\"\"Get system information\"\"\"\n    info = {}\n    uname = platform.uname()\n    info[\"system_name\"] = uname.system\n    info[\"node_name\"] = uname.node\n\n    cpu_model = None\n    if 
os.path.exists(\"/proc/cpuinfo\"):\n        try:\n            with open(\"/proc/cpuinfo\", \"r\") as f:\n                for line in f:\n                    if \"model name\" in line:\n                        cpu_model = line.split(\":\", 1)[1].strip()\n                        break\n        except Exception:\n            pass\n    info[\"cpu_model\"] = cpu_model\n    info[\"cpu_core_count\"] = os.cpu_count()\n    return info\n\n\ndef record_results(result, filename=json_path):\n    \"\"\"Append result to JSON file\"\"\"\n    with open(filename, \"a\") as f:\n        f.write(json.dumps(result) + \"\\n\")\n\n\ndef generate_fp8_weights_direct(shape: tuple, group_size: int = 128):\n    \"\"\"\n    Directly generate random FP8 weights and e8m0 format scale_inv.\n\n    Args:\n        shape: (expert_num, n, k) - weight tensor shape\n        group_size: block size for scaling (128x128 blocks)\n\n    Returns:\n        fp8_weights: uint8 tensor with random FP8 E4M3 values\n        scale_inv: fp32 tensor with e8m0 format (powers of 2)\n    \"\"\"\n    e, n, k = shape\n    n_blocks = n // group_size\n    k_blocks = k // group_size\n\n    # Directly generate random FP8 weights as uint8\n    # FP8 E4M3 format: 1 sign + 4 exp + 3 mantissa\n    # Valid range for normal numbers: exp 1-14 (0 is subnormal, 15 is special)\n    fp8_weights = torch.randint(0, 256, (e, n, k), dtype=torch.uint8, device=\"cuda\").to(\"cpu\").contiguous()\n\n    # Generate e8m0 format scale_inv (powers of 2)\n    # e8m0: 8-bit exponent only, no mantissa, bias = 127\n    # Generate random exponents in a reasonable range (e.g., -8 to 8)\n    exponents = torch.randint(-8, 9, (e, n_blocks, k_blocks), dtype=torch.int32, device=\"cuda\").to(\"cpu\").contiguous()\n    scale_inv = (2.0 ** exponents.float()).to(torch.float32).contiguous()\n\n    return fp8_weights, scale_inv\n\n\ndef bench_fp8_moe():\n    \"\"\"Benchmark FP8 MoE performance\"\"\"\n    with torch.inference_mode():\n        print(\"=\" * 70)\n        
print(\"FP8 MoE Kernel Performance Benchmark\")\n        print(\"=\" * 70)\n\n        # Generate FP8 weights directly (no quantization from fp32)\n        print(\"\\nGenerating FP8 weights directly...\")\n        torch.manual_seed(42)\n        gate_fp8, gate_scales = generate_fp8_weights_direct(\n            (expert_num, intermediate_size, hidden_size), fp8_group_size\n        )\n        up_fp8, up_scales = generate_fp8_weights_direct((expert_num, intermediate_size, hidden_size), fp8_group_size)\n        down_fp8, down_scales = generate_fp8_weights_direct(\n            (expert_num, hidden_size, intermediate_size), fp8_group_size\n        )\n\n        physical_to_logical_map = torch.tensor(range(expert_num), device=\"cpu\", dtype=torch.int64).contiguous()\n\n        # Build MoE layers\n        print(\"Building FP8 MoE layers...\")\n        moes = []\n        for _ in tqdm(range(layer_num), desc=\"Initializing MOEs\"):\n            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)\n            config.max_len = max_len\n            config.quant_config.bits = 8\n            config.quant_config.group_size = fp8_group_size\n            config.quant_config.zero_point = False\n\n            config.gate_proj = gate_fp8.data_ptr()\n            config.up_proj = up_fp8.data_ptr()\n            config.down_proj = down_fp8.data_ptr()\n            config.gate_scale = gate_scales.data_ptr()\n            config.up_scale = up_scales.data_ptr()\n            config.down_scale = down_scales.data_ptr()\n            config.pool = CPUInfer.backend_\n\n            moe = kt_kernel_ext.moe.AMXFP8_MOE(config)\n            CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n            CPUInfer.sync()\n            moes.append(moe)\n\n        # Generate input data\n        print(\"Generating input data...\")\n        gen_iter = 1000\n        expert_ids = (\n            torch.rand(gen_iter * qlen, expert_num, 
device=\"cpu\")\n            .argsort(dim=-1)[:, :num_experts_per_tok]\n            .reshape(gen_iter, qlen * num_experts_per_tok)\n            .contiguous()\n        )\n        weights = torch.rand((gen_iter, qlen, num_experts_per_tok), dtype=torch.float32, device=\"cpu\").contiguous()\n        input_tensor = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device=\"cpu\").contiguous()\n        output_tensor = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device=\"cpu\").contiguous()\n        qlen_tensor = torch.tensor([qlen], dtype=torch.int32)\n\n        # Warmup\n        print(f\"Warming up ({warm_up_iter} iterations)...\")\n        for i in tqdm(range(warm_up_iter), desc=\"Warm-up\"):\n            CPUInfer.submit(\n                moes[i % layer_num].forward_task(\n                    qlen_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids[i % gen_iter].data_ptr(),\n                    weights[i % gen_iter].data_ptr(),\n                    input_tensor[i % layer_num].data_ptr(),\n                    output_tensor[i % layer_num].data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n\n        # Benchmark\n        print(f\"Running benchmark ({test_iter} iterations)...\")\n        start = time.perf_counter()\n        for i in tqdm(range(test_iter), desc=\"Testing\"):\n            CPUInfer.submit(\n                moes[i % layer_num].forward_task(\n                    qlen_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids[i % gen_iter].data_ptr(),\n                    weights[i % gen_iter].data_ptr(),\n                    input_tensor[i % layer_num].data_ptr(),\n                    output_tensor[i % layer_num].data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n        end = time.perf_counter()\n        total_time = end - start\n\n        # 
Calculate metrics\n        time_per_iter_us = total_time / test_iter * 1e6\n\n        # FLOPS calculation:\n        # Each expert performs: gate(intermediate x hidden) + up(intermediate x hidden) + down(hidden x intermediate)\n        # GEMM/GEMV: 2 * m * n * k flops (multiply + accumulate = 2 ops per element)\n        # For vector-matrix multiply (qlen=1): 2 * n * k per matrix\n        flops_per_expert = (\n            2 * intermediate_size * hidden_size  # gate\n            + 2 * intermediate_size * hidden_size  # up\n            + 2 * hidden_size * intermediate_size  # down\n        )\n        total_flops = qlen * num_experts_per_tok * flops_per_expert * test_iter\n        tflops = total_flops / total_time / 1e12\n\n        # Bandwidth calculation (FP8 = 1 byte per element)\n        bytes_per_elem = 1.0\n        # Weight memory: gate + up + down per expert\n        bandwidth = (\n            hidden_size\n            * intermediate_size\n            * 3\n            * num_experts_per_tok\n            * (1 / num_experts_per_tok * expert_num * (1 - (1 - num_experts_per_tok / expert_num) ** qlen))\n            * bytes_per_elem\n            * test_iter\n            / total_time\n            / 1e9\n        )  # Unit: GB/s\n\n        # Print results\n        print(\"\\n\" + \"=\" * 70)\n        print(\"Benchmark Results\")\n        print(\"=\" * 70)\n        print(f\"Quant mode: FP8 (E4M3) with {fp8_group_size}x{fp8_group_size} block scaling\")\n        print(f\"Total time: {total_time:.4f} s\")\n        print(f\"Iterations: {test_iter}\")\n        print(f\"Time per iteration: {time_per_iter_us:.2f} us\")\n        print(f\"Bandwidth: {bandwidth:.2f} GB/s\")\n        print(f\"TFLOPS: {tflops:.4f}\")\n        print(\"\")\n\n        # Record results\n        result = {\n            \"test_name\": os.path.basename(__file__),\n            \"quant_mode\": \"fp8_e4m3\",\n            \"total_time_seconds\": total_time,\n            \"iterations\": test_iter,\n            
\"time_per_iteration_us\": time_per_iter_us,\n            \"bandwidth_GBs\": bandwidth,\n            \"flops_TFLOPS\": tflops,\n            \"timestamp\": time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime()),\n            \"test_parameters\": {\n                \"expert_num\": expert_num,\n                \"hidden_size\": hidden_size,\n                \"intermediate_size\": intermediate_size,\n                \"num_experts_per_tok\": num_experts_per_tok,\n                \"fp8_group_size\": fp8_group_size,\n                \"layer_num\": layer_num,\n                \"qlen\": qlen,\n                \"warm_up_iter\": warm_up_iter,\n                \"test_iter\": test_iter,\n                \"CPUInfer_parameter\": CPUINFER_PARAM,\n            },\n        }\n        result.update(get_git_commit())\n        result.update(get_system_info())\n        record_results(result)\n\n        return tflops, bandwidth\n\n\nif __name__ == \"__main__\":\n    bench_fp8_moe()\n"
  },
  {
    "path": "kt-kernel/bench/bench_fp8_perchannel_moe.py",
    "content": "\"\"\"\nPerformance benchmark for FP8 Per-Channel MoE kernel (GLM-4.7-FP8 style).\n\nThis benchmark measures the performance of the FP8 Per-Channel MoE operator with:\n- FP8 (E4M3) weights with per-channel scaling (one scale per output row)\n- BF16 activations\n- AVX-512 DPBF16 compute path\n\"\"\"\n\nimport os\nimport sys\nimport time\nimport json\nimport subprocess\nimport platform\n\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\", \"build\"))\n\nimport torch\nfrom kt_kernel import kt_kernel_ext\nfrom tqdm import tqdm\n\n# Test parameters\nexpert_num = 256\nhidden_size = 7168\nintermediate_size = 2048\nnum_experts_per_tok = 8\nmax_len = 25600\n\nlayer_num = 2\nqlen = 1\nwarm_up_iter = 1000\ntest_iter = 3000\nCPUINFER_PARAM = 80\n\nCPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)\n\n# Result file path\nscript_path = os.path.abspath(__file__)\nscript_dir = os.path.dirname(script_path)\njson_path = os.path.join(script_dir, \"bench_results.jsonl\")\n\n\ndef get_git_commit():\n    \"\"\"Get current git commit info\"\"\"\n    result = {}\n    try:\n        commit = subprocess.check_output([\"git\", \"rev-parse\", \"HEAD\"]).decode(\"utf-8\").strip()\n        commit_msg = subprocess.check_output([\"git\", \"log\", \"-1\", \"--pretty=%B\"]).decode(\"utf-8\").strip()\n        result[\"commit\"] = commit\n        result[\"commit_message\"] = commit_msg\n        dirty_output = subprocess.check_output([\"git\", \"status\", \"--porcelain\"]).decode(\"utf-8\").strip()\n        result[\"dirty\"] = bool(dirty_output)\n        if dirty_output:\n            result[\"dirty_files\"] = dirty_output.splitlines()\n    except Exception as e:\n        result[\"commit\"] = None\n        result[\"error\"] = str(e)\n    return result\n\n\ndef get_system_info():\n    \"\"\"Get system information\"\"\"\n    info = {}\n    uname = platform.uname()\n    info[\"system_name\"] = uname.system\n    info[\"node_name\"] = uname.node\n\n    cpu_model = None\n    if 
os.path.exists(\"/proc/cpuinfo\"):\n        try:\n            with open(\"/proc/cpuinfo\", \"r\") as f:\n                for line in f:\n                    if \"model name\" in line:\n                        cpu_model = line.split(\":\", 1)[1].strip()\n                        break\n        except Exception:\n            pass\n    info[\"cpu_model\"] = cpu_model\n    info[\"cpu_core_count\"] = os.cpu_count()\n    return info\n\n\ndef record_results(result, filename=json_path):\n    \"\"\"Append result to JSON file\"\"\"\n    with open(filename, \"a\") as f:\n        f.write(json.dumps(result) + \"\\n\")\n\n\ndef generate_fp8_perchannel_weights_direct(shape: tuple):\n    \"\"\"\n    Directly generate random FP8 weights and per-channel scales.\n\n    Args:\n        shape: (expert_num, n, k) - weight tensor shape\n\n    Returns:\n        fp8_weights: uint8 tensor with random FP8 E4M3 values\n        scales: fp32 tensor with per-channel scales, shape [expert_num, n]\n    \"\"\"\n    e, n, k = shape\n\n    # Directly generate random FP8 weights as uint8\n    # FP8 E4M3 format: 1 sign + 4 exp + 3 mantissa\n    fp8_weights = torch.randint(0, 256, (e, n, k), dtype=torch.uint8, device=\"cuda\").to(\"cpu\").contiguous()\n\n    # Generate random per-channel scales (one per output row)\n    # Use reasonable scale range (e.g., 2^-8 to 2^8)\n    exponents = torch.randint(-8, 9, (e, n), dtype=torch.int32, device=\"cuda\").to(\"cpu\").contiguous()\n    scales = (2.0 ** exponents.float()).to(torch.float32).contiguous()\n\n    return fp8_weights, scales\n\n\ndef bench_fp8_perchannel_moe():\n    \"\"\"Benchmark FP8 Per-Channel MoE performance\"\"\"\n    with torch.inference_mode():\n        print(\"=\" * 70)\n        print(\"FP8 Per-Channel MoE Kernel Performance Benchmark\")\n        print(\"=\" * 70)\n\n        # Generate FP8 weights with per-channel scales\n        print(\"\\nGenerating FP8 weights with per-channel scales...\")\n        torch.manual_seed(42)\n        gate_fp8, 
gate_scales = generate_fp8_perchannel_weights_direct((expert_num, intermediate_size, hidden_size))\n        up_fp8, up_scales = generate_fp8_perchannel_weights_direct((expert_num, intermediate_size, hidden_size))\n        down_fp8, down_scales = generate_fp8_perchannel_weights_direct((expert_num, hidden_size, intermediate_size))\n\n        physical_to_logical_map = torch.tensor(range(expert_num), device=\"cpu\", dtype=torch.int64).contiguous()\n\n        # Build MoE layers\n        print(\"Building FP8 Per-Channel MoE layers...\")\n        moes = []\n        for _ in tqdm(range(layer_num), desc=\"Initializing MOEs\"):\n            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)\n            config.max_len = max_len\n            config.quant_config.bits = 8\n            config.quant_config.group_size = 0  # Not used for per-channel\n            config.quant_config.zero_point = False\n            config.quant_config.per_channel = True  # Enable per-channel mode\n\n            config.gate_proj = gate_fp8.data_ptr()\n            config.up_proj = up_fp8.data_ptr()\n            config.down_proj = down_fp8.data_ptr()\n            config.gate_scale = gate_scales.data_ptr()\n            config.up_scale = up_scales.data_ptr()\n            config.down_scale = down_scales.data_ptr()\n            config.pool = CPUInfer.backend_\n\n            moe = kt_kernel_ext.moe.AMXFP8PerChannel_MOE(config)\n            CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n            CPUInfer.sync()\n            moes.append(moe)\n\n        # Generate input data\n        print(\"Generating input data...\")\n        gen_iter = 1000\n        expert_ids = (\n            torch.rand(gen_iter * qlen, expert_num, device=\"cpu\")\n            .argsort(dim=-1)[:, :num_experts_per_tok]\n            .reshape(gen_iter, qlen * num_experts_per_tok)\n            .contiguous()\n        )\n        weights = torch.rand((gen_iter, 
qlen, num_experts_per_tok), dtype=torch.float32, device=\"cpu\").contiguous()\n        input_tensor = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device=\"cpu\").contiguous()\n        output_tensor = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device=\"cpu\").contiguous()\n        qlen_tensor = torch.tensor([qlen], dtype=torch.int32)\n\n        # Warmup\n        print(f\"Warming up ({warm_up_iter} iterations)...\")\n        for i in tqdm(range(warm_up_iter), desc=\"Warm-up\"):\n            CPUInfer.submit(\n                moes[i % layer_num].forward_task(\n                    qlen_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids[i % gen_iter].data_ptr(),\n                    weights[i % gen_iter].data_ptr(),\n                    input_tensor[i % layer_num].data_ptr(),\n                    output_tensor[i % layer_num].data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n\n        # Benchmark\n        print(f\"Running benchmark ({test_iter} iterations)...\")\n        start = time.perf_counter()\n        for i in tqdm(range(test_iter), desc=\"Testing\"):\n            CPUInfer.submit(\n                moes[i % layer_num].forward_task(\n                    qlen_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids[i % gen_iter].data_ptr(),\n                    weights[i % gen_iter].data_ptr(),\n                    input_tensor[i % layer_num].data_ptr(),\n                    output_tensor[i % layer_num].data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n        end = time.perf_counter()\n        total_time = end - start\n\n        # Calculate metrics\n        time_per_iter_us = total_time / test_iter * 1e6\n\n        # FLOPS calculation:\n        # Each expert performs: gate(intermediate x hidden) + up(intermediate x hidden) + down(hidden x 
intermediate)\n        # GEMM/GEMV: 2 * m * n * k flops (multiply + accumulate = 2 ops per element)\n        # For vector-matrix multiply (qlen=1): 2 * n * k per matrix\n        flops_per_expert = (\n            2 * intermediate_size * hidden_size  # gate\n            + 2 * intermediate_size * hidden_size  # up\n            + 2 * hidden_size * intermediate_size  # down\n        )\n        total_flops = qlen * num_experts_per_tok * flops_per_expert * test_iter\n        tflops = total_flops / total_time / 1e12\n\n        # Bandwidth calculation (FP8 = 1 byte per element)\n        bytes_per_elem = 1.0\n        # Weight memory: gate + up + down per expert\n        bandwidth = (\n            hidden_size\n            * intermediate_size\n            * 3\n            * num_experts_per_tok\n            * (1 / num_experts_per_tok * expert_num * (1 - (1 - num_experts_per_tok / expert_num) ** qlen))\n            * bytes_per_elem\n            * test_iter\n            / total_time\n            / 1e9\n        )\n\n        # Print results\n        print(\"\\n\" + \"=\" * 70)\n        print(\"Benchmark Results\")\n        print(\"=\" * 70)\n        print(f\"Quant mode: FP8 (E4M3) with per-channel scaling\")\n        print(f\"Total time: {total_time:.4f} s\")\n        print(f\"Iterations: {test_iter}\")\n        print(f\"Time per iteration: {time_per_iter_us:.2f} us\")\n        print(f\"Bandwidth: {bandwidth:.2f} GB/s\")\n        print(f\"TFLOPS: {tflops:.4f}\")\n        print(\"\")\n\n        # Record results\n        result = {\n            \"test_name\": os.path.basename(__file__),\n            \"quant_mode\": \"fp8_e4m3_perchannel\",\n            \"total_time_seconds\": total_time,\n            \"iterations\": test_iter,\n            \"time_per_iteration_us\": time_per_iter_us,\n            \"bandwidth_GBs\": bandwidth,\n            \"flops_TFLOPS\": tflops,\n            \"timestamp\": time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime()),\n            \"test_parameters\": {\n 
               \"expert_num\": expert_num,\n                \"hidden_size\": hidden_size,\n                \"intermediate_size\": intermediate_size,\n                \"num_experts_per_tok\": num_experts_per_tok,\n                \"quant_type\": \"per_channel\",\n                \"layer_num\": layer_num,\n                \"qlen\": qlen,\n                \"warm_up_iter\": warm_up_iter,\n                \"test_iter\": test_iter,\n                \"CPUInfer_parameter\": CPUINFER_PARAM,\n            },\n        }\n        result.update(get_git_commit())\n        result.update(get_system_info())\n        record_results(result)\n\n        return tflops, bandwidth\n\n\nif __name__ == \"__main__\":\n    bench_fp8_perchannel_moe()\n"
  },
  {
    "path": "kt-kernel/bench/bench_k2_moe_amx.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nBenchmark AMX_K2_MOE_TP int4 path with packed weights and BF16 scales.\n\"\"\"\nimport json\nimport math\nimport os\nimport platform\nimport subprocess\nimport sys\nimport time\n\nfrom tqdm import tqdm\n\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\", \"build\"))\n\nfrom kt_kernel import kt_kernel_ext\nimport torch\n\n# Benchmark parameters (single MoE, no layer loop)\nexpert_num = 384\nhidden_size = 7168\nintermediate_size = 2048\nmax_len = 25600\nnum_experts_per_tok = 8\nqlen = 1\nwarm_up_iter = 1000\ntest_iter = 5000\nk_group_size = 32\n\nphysical_to_logical_map = torch.tensor(data=range(expert_num), device=\"cpu\", dtype=torch.int64).contiguous()\n\nworker_config = kt_kernel_ext.WorkerPoolConfig()\nworker_config.subpool_count = 2\nworker_config.subpool_numa_map = [0, 1]\nworker_config.subpool_thread_count = [40, 40]\nCPUInfer = kt_kernel_ext.CPUInfer(worker_config)\n\n\ndef get_git_commit():\n    result = {}\n    try:\n        commit = subprocess.check_output([\"git\", \"rev-parse\", \"HEAD\"]).decode(\"utf-8\").strip()\n        commit_msg = subprocess.check_output([\"git\", \"log\", \"-1\", \"--pretty=%B\"]).decode(\"utf-8\").strip()\n        result[\"commit\"] = commit\n        result[\"commit_message\"] = commit_msg\n\n        dirty_output = subprocess.check_output([\"git\", \"status\", \"--porcelain\"]).decode(\"utf-8\").strip()\n        if dirty_output:\n            result[\"dirty\"] = True\n            result[\"dirty_files\"] = dirty_output.splitlines()\n        else:\n            result[\"dirty\"] = False\n    except Exception as e:\n        result[\"commit\"] = None\n        result[\"commit_message\"] = None\n        result[\"dirty\"] = None\n        result[\"error\"] = str(e)\n    return result\n\n\ndef get_system_info():\n    info = {}\n    uname = platform.uname()\n    info[\"system_name\"] = uname.system\n    info[\"node_name\"] = uname.node\n\n    cpu_model = None\n  
  if os.path.exists(\"/proc/cpuinfo\"):\n        try:\n            with open(\"/proc/cpuinfo\", \"r\") as f:\n                for line in f:\n                    if \"model name\" in line:\n                        cpu_model = line.split(\":\", 1)[1].strip()\n                        break\n        except Exception as e:\n            cpu_model = f\"Error: {e}\"\n    info[\"cpu_model\"] = cpu_model\n\n    mem_total_gb = None\n    if os.path.exists(\"/proc/meminfo\"):\n        try:\n            with open(\"/proc/meminfo\", \"r\") as f:\n                for line in f:\n                    if \"MemTotal\" in line:\n                        mem_kb = float(line.split(\":\", 1)[1].split()[0])\n                        mem_total_gb = round(mem_kb / (1024 * 1024), 2)\n                        break\n        except Exception as e:\n            mem_total_gb = f\"Error: {e}\"\n    info[\"memory_size_GB\"] = mem_total_gb\n\n    info[\"cpu_core_count\"] = os.cpu_count()\n\n    sockets = set()\n    if os.path.exists(\"/proc/cpuinfo\"):\n        try:\n            with open(\"/proc/cpuinfo\", \"r\") as f:\n                for line in f:\n                    if \"physical id\" in line:\n                        sockets.add(line.split(\":\", 1)[1].strip())\n        except Exception:\n            sockets = set()\n    info[\"cpu_socket_count\"] = len(sockets) if len(sockets) > 0 else 1\n\n    return info\n\n\nscript_path = os.path.abspath(__file__)\nscript_dir = os.path.dirname(script_path)\nscript_name = os.path.splitext(os.path.basename(script_path))[0]\njson_path = os.path.join(script_dir, script_name + \".jsonl\")\n\n\ndef record_results(result, filename=json_path):\n    with open(filename, \"a\") as f:\n        f.write(json.dumps(result) + \"\\n\")\n\n\ndef pack_to_int32(value: torch.Tensor, num_bits: int, packed_dim: int = 1) -> torch.Tensor:\n    if value.dtype is not torch.int8:\n        raise ValueError(\"Tensor must be torch.int8 before packing\")\n    if not (1 <= num_bits <= 
8):\n        raise ValueError(f\"num_bits must be in [1, 8], got {num_bits}\")\n\n    offset = 1 << (num_bits - 1)\n    value = (value + offset).to(torch.uint8)\n    device = value.device\n\n    pack_factor = 32 // num_bits\n\n    if packed_dim == 0:\n        value = value.transpose(0, 1)\n\n    rows, cols = value.shape\n    padded_cols = math.ceil(cols / pack_factor) * pack_factor\n    pad_len = padded_cols - cols\n\n    if pad_len > 0:\n        value = torch.nn.functional.pad(value, (0, pad_len))\n\n    num_groups = padded_cols // pack_factor\n    reshaped = value.view(rows, num_groups, pack_factor).to(torch.int32)\n    bit_shifts = torch.arange(pack_factor, device=device, dtype=torch.int32) * num_bits\n    packed = (reshaped << bit_shifts).sum(dim=2, dtype=torch.int32)\n\n    if packed_dim == 0:\n        packed = packed.transpose(0, 1)\n\n    return packed\n\n\ndef pack_tensor_per_row(q: torch.Tensor, num_bits: int) -> torch.Tensor:\n    e, rows, cols = q.shape\n    flat = q.view(e * rows, cols)\n    packed = pack_to_int32(flat, num_bits)\n    return packed.view(e, rows, -1).contiguous()\n\n\ndef quantize_k2_tensor(weights: torch.Tensor, group_size: int):\n    \"\"\"\n    K2 int4 quantization producing int32-packed weights (8 int4s each) and BF16 scales.\n    \"\"\"\n    weights_f32 = weights.to(torch.float32)\n    e, rows, cols = weights_f32.shape\n    if cols % group_size != 0 or cols % 2 != 0:\n        raise ValueError(f\"cols ({cols}) must be divisible by group_size ({group_size}) and 2\")\n\n    reshaped = weights_f32.view(e, rows, cols // group_size, group_size)\n    max_abs = reshaped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)\n    scales = (max_abs / 7.0).squeeze(-1)\n    q = torch.round(reshaped / scales.unsqueeze(-1)).clamp(-8, 7).to(torch.int8)\n    q = q.view(e, rows, cols)\n    packed = pack_tensor_per_row(q, num_bits=4).view(e, rows, cols // 8).contiguous()\n    scales = scales.to(torch.bfloat16).contiguous().view(e, rows, cols // 
group_size).contiguous()\n    return packed, scales\n\n\ndef build_quantized_layer_weights():\n    gate_proj = torch.randn(\n        (expert_num, intermediate_size, hidden_size),\n        dtype=torch.float32,\n        device=\"cpu\",\n    ).contiguous()\n    up_proj = torch.randn(\n        (expert_num, intermediate_size, hidden_size),\n        dtype=torch.float32,\n        device=\"cpu\",\n    ).contiguous()\n    down_proj = torch.randn(\n        (expert_num, hidden_size, intermediate_size),\n        dtype=torch.float32,\n        device=\"cpu\",\n    ).contiguous()\n\n    gate_q, gate_scales = quantize_k2_tensor(gate_proj, k_group_size)\n    up_q, up_scales = quantize_k2_tensor(up_proj, k_group_size)\n    down_q, down_scales = quantize_k2_tensor(down_proj, k_group_size)\n\n    return {\n        \"gate_qweight\": gate_q,\n        \"up_qweight\": up_q,\n        \"down_qweight\": down_q,\n        \"gate_scales\": gate_scales,\n        \"up_scales\": up_scales,\n        \"down_scales\": down_scales,\n    }\n\n\ndef bench_k2_moe():\n    with torch.inference_mode():\n        bytes_per_elem = 0.5 + 2.0 / k_group_size\n\n        quant_data = build_quantized_layer_weights()\n        config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)\n        config.max_len = max_len\n        config.quant_config.bits = 4\n        config.quant_config.group_size = k_group_size\n        config.quant_config.zero_point = False\n\n        config.gate_proj = quant_data[\"gate_qweight\"].data_ptr()\n        config.up_proj = quant_data[\"up_qweight\"].data_ptr()\n        config.down_proj = quant_data[\"down_qweight\"].data_ptr()\n\n        config.gate_scale = quant_data[\"gate_scales\"].data_ptr()\n        config.up_scale = quant_data[\"up_scales\"].data_ptr()\n        config.down_scale = quant_data[\"down_scales\"].data_ptr()\n        config.pool = CPUInfer.backend_\n\n        moe = kt_kernel_ext.moe.AMXInt4_KGroup_MOE(config)\n        
CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n        CPUInfer.sync()\n\n        gen_iter = 3000\n        expert_ids = (\n            torch.rand(gen_iter * qlen, expert_num, device=\"cpu\")\n            .argsort(dim=-1)[:, :num_experts_per_tok]\n            .reshape(gen_iter, qlen * num_experts_per_tok)\n            .contiguous()\n        )\n        weights = torch.rand((gen_iter, qlen, num_experts_per_tok), dtype=torch.float32, device=\"cpu\").contiguous()\n        input_tensor = torch.randn((qlen, hidden_size), dtype=torch.bfloat16, device=\"cpu\").contiguous()\n        output_tensor = torch.empty_like(input_tensor)\n        bsz_tensor = torch.tensor([qlen], device=\"cpu\")\n\n        for i in tqdm(range(warm_up_iter), desc=\"Warm-up\"):\n            CPUInfer.submit(\n                moe.forward_task(\n                    bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids[i % gen_iter].data_ptr(),\n                    weights[i % gen_iter].data_ptr(),\n                    input_tensor.data_ptr(),\n                    output_tensor.data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n\n        start = time.perf_counter()\n        for i in tqdm(range(test_iter), desc=\"Testing\"):\n            CPUInfer.submit(\n                moe.forward_task(\n                    bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids[i % gen_iter].data_ptr(),\n                    weights[i % gen_iter].data_ptr(),\n                    input_tensor.data_ptr(),\n                    output_tensor.data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n        end = time.perf_counter()\n        total_time = end - start\n\n        time_per_iter_us = total_time / test_iter * 1e6\n        bandwidth = (\n            hidden_size\n            * intermediate_size\n            * 
3\n            * num_experts_per_tok\n            * (1 / num_experts_per_tok * expert_num * (1 - (1 - num_experts_per_tok / expert_num) ** qlen))\n            * bytes_per_elem\n            * test_iter\n            / total_time\n            / 1e9\n        )\n        flops = hidden_size * intermediate_size * qlen * 3 * num_experts_per_tok * 2 * test_iter / total_time / 1e12\n\n        print(\"Quant mode: int4_k2\")\n        print(\"Time(s): \", total_time)\n        print(\"Iteration: \", test_iter)\n        print(\"Time(us) per iteration: \", time_per_iter_us)\n        print(\"Bandwidth: \", bandwidth, \"GB/s\")\n        print(\"Flops: \", flops, \"TFLOPS\")\n        print(\"\")\n\n        result = {\n            \"quant_mode\": \"int4_k2\",\n            \"total_time_seconds\": total_time,\n            \"iterations\": test_iter,\n            \"time_per_iteration_us\": time_per_iter_us,\n            \"bandwidth_GBs\": bandwidth,\n            \"flops_TFLOPS\": flops,\n            \"timestamp\": time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime()),\n            \"test_parameters\": {\n                \"expert_num\": expert_num,\n                \"hidden_size\": hidden_size,\n                \"intermediate_size\": intermediate_size,\n                \"max_len\": max_len,\n                \"num_experts_per_tok\": num_experts_per_tok,\n                \"qlen\": qlen,\n                \"warm_up_iter\": warm_up_iter,\n                \"test_iter\": test_iter,\n                \"k_group_size\": k_group_size,\n                \"bytes_per_elem\": bytes_per_elem,\n            },\n        }\n        result.update(get_git_commit())\n        result.update(get_system_info())\n        record_results(result)\n\n\nif __name__ == \"__main__\":\n    bench_k2_moe()\n"
  },
  {
    "path": "kt-kernel/bench/bench_k2_write_buffer.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nBenchmark write_weight_scale_to_buffer for AMX_K2_MOE_TP (int4 packed weights + bf16 scales).\n\nUses two MOE instances that alternate writing to simulate realistic multi-layer scenarios.\n\"\"\"\nimport json\nimport os\nimport platform\nimport subprocess\nimport sys\nimport time\n\nfrom tqdm import tqdm\n\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\", \"build\"))\n\nfrom kt_kernel import kt_kernel_ext\nimport torch\n\n# Benchmark parameters\nexpert_num = 384\nnum_experts_per_tok = expert_num\ngpu_tp_count = 4\n\nwarm_up_iter = 3\ntest_iter = 7\n\ngpu_experts_num = expert_num\n\nhidden_size = 7168\nintermediate_size = 2048\ngroup_size = 32\nmax_len = 1\n\nphysical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device=\"cpu\").contiguous()\nCPUInfer = kt_kernel_ext.CPUInfer(80)\n\n\ndef get_git_commit():\n    result = {}\n    try:\n        commit = subprocess.check_output([\"git\", \"rev-parse\", \"HEAD\"]).decode(\"utf-8\").strip()\n        commit_msg = subprocess.check_output([\"git\", \"log\", \"-1\", \"--pretty=%B\"]).decode(\"utf-8\").strip()\n        result[\"commit\"] = commit\n        result[\"commit_message\"] = commit_msg\n\n        dirty_output = subprocess.check_output([\"git\", \"status\", \"--porcelain\"]).decode(\"utf-8\").strip()\n        if dirty_output:\n            result[\"dirty\"] = True\n            result[\"dirty_files\"] = dirty_output.splitlines()\n        else:\n            result[\"dirty\"] = False\n    except Exception as e:\n        result[\"commit\"] = None\n        result[\"commit_message\"] = None\n        result[\"dirty\"] = None\n        result[\"error\"] = str(e)\n    return result\n\n\ndef get_system_info():\n    info = {}\n    uname = platform.uname()\n    info[\"system_name\"] = uname.system\n    info[\"node_name\"] = uname.node\n\n    cpu_model = None\n    if os.path.exists(\"/proc/cpuinfo\"):\n        try:\n            with 
open(\"/proc/cpuinfo\", \"r\") as f:\n                for line in f:\n                    if \"model name\" in line:\n                        cpu_model = line.split(\":\", 1)[1].strip()\n                        break\n        except Exception as e:\n            cpu_model = f\"Error: {e}\"\n    info[\"cpu_model\"] = cpu_model\n\n    mem_total_gb = None\n    if os.path.exists(\"/proc/meminfo\"):\n        try:\n            with open(\"/proc/meminfo\", \"r\") as f:\n                for line in f:\n                    if \"MemTotal\" in line:\n                        mem_kb = float(line.split(\":\", 1)[1].split()[0])\n                        mem_total_gb = round(mem_kb / (1024 * 1024), 2)\n                        break\n        except Exception as e:\n            mem_total_gb = f\"Error: {e}\"\n    info[\"memory_size_GB\"] = mem_total_gb\n\n    info[\"cpu_core_count\"] = os.cpu_count()\n\n    sockets = set()\n    if os.path.exists(\"/proc/cpuinfo\"):\n        try:\n            with open(\"/proc/cpuinfo\", \"r\") as f:\n                for line in f:\n                    if \"physical id\" in line:\n                        sockets.add(line.split(\":\", 1)[1].strip())\n        except Exception:\n            sockets = set()\n    info[\"cpu_socket_count\"] = len(sockets) if len(sockets) > 0 else 1\n\n    return info\n\n\nscript_path = os.path.abspath(__file__)\nscript_dir = os.path.dirname(script_path)\nscript_name = os.path.splitext(os.path.basename(script_path))[0]\njson_path = os.path.join(script_dir, script_name + \".jsonl\")\n\n\ndef record_results(result, filename=json_path):\n    with open(filename, \"a\") as f:\n        f.write(json.dumps(result) + \"\\n\")\n\n\ndef allocate_weights():\n    per_mat_weight_bytes = (hidden_size * intermediate_size) // 2\n    per_mat_scale_elems = (hidden_size * intermediate_size) // group_size\n\n    gate_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)\n    up_q = torch.randint(0, 256, (expert_num * 
per_mat_weight_bytes,), dtype=torch.uint8)\n    down_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)\n\n    gate_scale = torch.randn(expert_num * per_mat_scale_elems, dtype=torch.bfloat16)\n    up_scale = torch.randn(expert_num * per_mat_scale_elems, dtype=torch.bfloat16)\n    down_scale = torch.randn(expert_num * per_mat_scale_elems, dtype=torch.bfloat16)\n\n    return (\n        gate_q.contiguous(),\n        up_q.contiguous(),\n        down_q.contiguous(),\n        gate_scale.contiguous(),\n        up_scale.contiguous(),\n        down_scale.contiguous(),\n        per_mat_weight_bytes,\n        per_mat_scale_elems,\n    )\n\n\ndef build_moe(layer_idx=0):\n    \"\"\"Build a single MOE instance with the given layer_idx.\"\"\"\n    (\n        gate_q,\n        up_q,\n        down_q,\n        gate_scale,\n        up_scale,\n        down_scale,\n        per_mat_weight_bytes,\n        per_mat_scale_elems,\n    ) = allocate_weights()\n\n    config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)\n    config.max_len = max_len\n    config.layer_idx = layer_idx\n    config.quant_config.bits = 4\n    config.quant_config.group_size = group_size\n    config.quant_config.zero_point = False\n    config.pool = CPUInfer.backend_\n\n    config.gate_proj = gate_q.data_ptr()\n    config.up_proj = up_q.data_ptr()\n    config.down_proj = down_q.data_ptr()\n    config.gate_scale = gate_scale.data_ptr()\n    config.up_scale = up_scale.data_ptr()\n    config.down_scale = down_scale.data_ptr()\n\n    moe = kt_kernel_ext.moe.AMXInt4_KGroup_MOE(config)\n    CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n    CPUInfer.sync()\n\n    keep_tensors = {\n        \"gate_q\": gate_q,\n        \"up_q\": up_q,\n        \"down_q\": down_q,\n        \"gate_scale\": gate_scale,\n        \"up_scale\": up_scale,\n        \"down_scale\": down_scale,\n    }\n\n    buffer_shapes = {\n        
\"per_mat_weight_bytes\": per_mat_weight_bytes,\n        \"per_mat_scale_elems\": per_mat_scale_elems,\n    }\n\n    return moe, buffer_shapes, keep_tensors\n\n\ndef allocate_buffers(buffer_shapes):\n    \"\"\"Allocate shared output buffers for single expert.\"\"\"\n    per_mat_weight_bytes = buffer_shapes[\"per_mat_weight_bytes\"]\n    per_mat_scale_elems = buffer_shapes[\"per_mat_scale_elems\"]\n\n    weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count\n    scale_elems_per_expert_per_tp = per_mat_scale_elems // gpu_tp_count\n\n    # Each buffer stores data for a single expert\n    w13_weight_bufs = [torch.empty(2 * weight_bytes_per_expert_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]\n    w13_scale_bufs = [torch.empty(2 * scale_elems_per_expert_per_tp, dtype=torch.bfloat16) for _ in range(gpu_tp_count)]\n    w2_weight_bufs = [torch.empty(weight_bytes_per_expert_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]\n    w2_scale_bufs = [torch.empty(scale_elems_per_expert_per_tp, dtype=torch.bfloat16) for _ in range(gpu_tp_count)]\n\n    buffer_ptrs = {\n        \"w13_weight_ptrs\": [buf.data_ptr() for buf in w13_weight_bufs],\n        \"w13_scale_ptrs\": [buf.data_ptr() for buf in w13_scale_bufs],\n        \"w2_weight_ptrs\": [buf.data_ptr() for buf in w2_weight_bufs],\n        \"w2_scale_ptrs\": [buf.data_ptr() for buf in w2_scale_bufs],\n    }\n\n    keep_tensors = {\n        \"w13_weight_bufs\": w13_weight_bufs,\n        \"w13_scale_bufs\": w13_scale_bufs,\n        \"w2_weight_bufs\": w2_weight_bufs,\n        \"w2_scale_bufs\": w2_scale_bufs,\n    }\n\n    return buffer_ptrs, keep_tensors\n\n\ndef bench_write_buffer():\n    # Build two MOE instances with different layer_idx\n    moe_0, buffer_shapes, keep_tensors_0 = build_moe(layer_idx=0)\n    moe_1, _, keep_tensors_1 = build_moe(layer_idx=1)\n    moes = [moe_0, moe_1]\n\n    # Allocate shared buffers\n    buffer_ptrs, buffer_keep_tensors = allocate_buffers(buffer_shapes)\n\n    
total_weights = hidden_size * intermediate_size * expert_num * 3\n    # Throughput accounting: scale bytes (bf16) + weight bytes (int4 packed)\n    bytes_per_call = total_weights // group_size * 2 + total_weights // 2\n\n    # Warm-up: alternate between two MOEs\n    for _ in tqdm(range(warm_up_iter), desc=\"Warm-up\"):\n        for moe_idx, moe in enumerate(moes):\n            for expert_id in range(gpu_experts_num):\n                CPUInfer.submit(\n                    moe.write_weight_scale_to_buffer_task(\n                        gpu_tp_count=gpu_tp_count,\n                        expert_id=expert_id,\n                        **buffer_ptrs,\n                    )\n                )\n                CPUInfer.sync()\n\n    total_time = 0\n    for iter_idx in tqdm(range(test_iter), desc=\"Testing\"):\n        start = time.perf_counter()\n        # Alternate between two MOEs\n        for moe_idx, moe in enumerate(moes):\n            for expert_id in range(gpu_experts_num):\n                CPUInfer.submit(\n                    moe.write_weight_scale_to_buffer_task(\n                        gpu_tp_count=gpu_tp_count,\n                        expert_id=expert_id,\n                        **buffer_ptrs,\n                    )\n                )\n                CPUInfer.sync()\n        end = time.perf_counter()\n        iter_time = end - start\n        total_time += iter_time\n        print(f\"Iter {iter_idx}: {iter_time*1000:.2f} ms\")\n        time.sleep(0.3)\n\n    # bytes_per_call is for one MOE, we have 2 MOEs\n    bytes_per_iter = bytes_per_call * 2\n    time_per_iter_ms = total_time / test_iter * 1000\n    bandwidth_gbs = bytes_per_iter * test_iter / total_time / 1e9\n\n    print(f\"\\n{'='*60}\")\n    print(\"K2 write_weight_scale_to_buffer benchmark (2 MOEs alternating)\")\n    print(f\"{'='*60}\")\n    print(f\"Time per iteration: {time_per_iter_ms:.2f} ms\")\n    print(f\"Bandwidth: {bandwidth_gbs:.2f} GB/s\")\n    print(f\"Experts per MOE: 
{gpu_experts_num}, MOEs: 2\")\n    print(f\"Time per expert: {time_per_iter_ms/(gpu_experts_num*2)*1000:.2f} us\")\n\n    result = {\n        \"op\": \"write_weight_scale_to_buffer_k2\",\n        \"time_per_iteration_ms\": time_per_iter_ms,\n        \"bandwidth_GBs\": bandwidth_gbs,\n        \"timestamp\": time.strftime(\"%Y-%m-%d %H:%M:%S\"),\n        \"test_parameters\": {\n            \"expert_num\": expert_num,\n            \"hidden_size\": hidden_size,\n            \"intermediate_size\": intermediate_size,\n            \"group_size\": group_size,\n            \"gpu_tp_count\": gpu_tp_count,\n            \"bytes_per_iter\": bytes_per_iter,\n            \"num_moes\": 2,\n        },\n    }\n    result.update(get_git_commit())\n    result.update(get_system_info())\n    record_results(result)\n\n\nif __name__ == \"__main__\":\n    bench_write_buffer()\n"
  },
  {
    "path": "kt-kernel/bench/bench_linear.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nDescription  :\nAuthor       : chenht2022\nDate         : 2024-07-25 10:31:59\nVersion      : 1.0.0\nLastEditors  : chenht2022\nLastEditTime : 2024-08-06 10:35:35\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved.\n\"\"\"\nimport os, sys\nimport time\n\nsys.path.append(os.path.dirname(__file__) + \"/../build\")\nfrom kt_kernel import kt_kernel_ext\nimport torch\n\ninput_size = 16384\noutput_size = 5120\nstride = 16\ngroup_max_len = 1024\nlayer_num = 10\nqlen = 1\nCPUInfer = kt_kernel_ext.CPUInfer(64)\nwarm_up_iter = 1000\ntest_iter = 10000\n\n\ndef bench_linear(quant_mode: str):\n    with torch.inference_mode(mode=True):\n\n        hidden_type = 30  # ggml_type::GGML_TYPE_BF16\n        if quant_mode == \"fp32\":\n            proj_type = 0  # ggml_type::GGML_TYPE_F32\n            bytes_per_elem = 4.000000\n        elif quant_mode == \"fp16\":\n            proj_type = 1  # ggml_type::GGML_TYPE_F16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"bf16\":\n            proj_type = 30  # ggml_type::GGML_TYPE_BF16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"q8_0\":\n            proj_type = 8  # ggml_type::GGML_TYPE_Q8_0\n            bytes_per_elem = 1.062500\n        elif quant_mode == \"q6_k\":\n            proj_type = 14  # ggml_type::GGML_TYPE_Q6_K\n            bytes_per_elem = 0.820312\n        elif quant_mode == \"q5_k_m\":\n            proj_type = 13  # ggml_type::GGML_TYPE_Q5_K\n            bytes_per_elem = 0.687500\n        elif quant_mode == \"q4_k_m\":\n            proj_type = 12  # ggml_type::GGML_TYPE_Q4_K\n            bytes_per_elem = 0.562500\n        elif quant_mode == \"q3_k_m\":\n            proj_type = 11  # ggml_type::GGML_TYPE_Q3_K\n            bytes_per_elem = 0.429688\n        elif quant_mode == \"q2_k\":\n            proj_type = 10  # ggml_type::GGML_TYPE_Q2_K\n            bytes_per_elem = 0.328125\n        elif quant_mode == 
\"iq3_xs\":\n            proj_type = 21  # ggml_type::GGML_TYPE_IQ3_S\n            bytes_per_elem = 0.429688\n        elif quant_mode == \"iq2_xxs\":\n            proj_type = 16  # ggml_type::GGML_TYPE_IQ2_XXS\n            bytes_per_elem = 0.257812\n        else:\n            assert False\n\n        linears = []\n        projs = []\n        for _ in range(layer_num):\n            proj = torch.randn((output_size, input_size), dtype=torch.float32, device=\"cuda\").to(\"cpu\").contiguous()\n            config = kt_kernel_ext.linear.LinearConfig(\n                input_size, output_size, stride, group_max_len, proj.data_ptr(), proj_type, hidden_type\n            )\n            linear = kt_kernel_ext.linear.Linear(config)\n            projs.append(proj)\n            linears.append(linear)\n        input = torch.randn((layer_num, qlen, input_size), dtype=torch.bfloat16, device=\"cuda\").to(\"cpu\").contiguous()\n        output = torch.empty((layer_num, qlen, output_size), dtype=torch.bfloat16, device=\"cuda\").to(\"cpu\").contiguous()\n\n        # warm up\n        for i in range(warm_up_iter):\n            CPUInfer.submit(\n                linears[i % layer_num].forward(qlen, input[i % layer_num].data_ptr(), output[i % layer_num].data_ptr())\n            )\n            CPUInfer.sync()\n\n        # test\n        start = time.perf_counter()\n        for i in range(test_iter):\n            CPUInfer.submit(\n                linears[i % layer_num].forward(qlen, input[i % layer_num].data_ptr(), output[i % layer_num].data_ptr())\n            )\n            CPUInfer.sync()\n        end = time.perf_counter()\n        total_time = end - start\n        print(\"Quant mode: \", quant_mode)\n        print(\"Time(s): \", total_time)\n        print(\"Iteration: \", test_iter)\n        print(\"Time(us) per iteration: \", total_time / test_iter * 1000000)\n        print(\n            \"Bandwidth: \",\n            input_size * output_size * bytes_per_elem * test_iter / total_time / 1000 / 
1000 / 1000,\n            \"GB/s\",\n        )\n        print(\"\")\n\n\nbench_linear(\"fp32\")\nbench_linear(\"fp16\")\nbench_linear(\"bf16\")\nbench_linear(\"q8_0\")\nbench_linear(\"q6_k\")\nbench_linear(\"q5_k_m\")\nbench_linear(\"q4_k_m\")\nbench_linear(\"q3_k_m\")\nbench_linear(\"q2_k\")\n# Not supported on __x86_64__\n# bench_linear(\"iq3_xs\")\n# bench_linear(\"iq2_xxs\")\n"
  },
  {
    "path": "kt-kernel/bench/bench_linear_torch.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : chenht2022\nDate         : 2024-07-25 10:31:59\nVersion      : 1.0.0\nLastEditors  : chenht2022 \nLastEditTime : 2024-07-25 10:32:48\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport os, sys\nimport time\nimport torch\nimport torch.nn.quantized as nnq\n\nscale, zero_point = 0.1, 0  # Adjust scale and zero_point based on your dataset\n\ninput_size = 16384\noutput_size = 5120\nlayer_num = 10\nqlen = 1\nwarm_up_iter = 1000\ntest_iter = 10000\n\ndef bench_linear(quant_mode: str):\n    with torch.inference_mode(mode=True):\n        if quant_mode == \"fp32\":\n            proj_type = torch.float32\n            bytes_per_elem = 4.000000\n        elif quant_mode == \"fp16\":\n            proj_type = torch.float16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"bf16\":\n            proj_type = torch.bfloat16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"qint8\":\n            proj_type = torch.qint8\n            bytes_per_elem = 1.000000\n        else:\n            assert(False)\n\n        projs = []\n        for _ in range(layer_num):\n            proj = torch.randn((output_size, input_size), dtype = torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            if quant_mode == \"qint8\":\n                proj_q = torch.quantize_per_tensor(proj, scale, zero_point, torch.qint8)\n                quantized_layer = nnq.Linear(input_size, output_size)\n                quantized_layer.set_weight_bias(proj_q, None)\n                projs.append(quantized_layer)\n            else:\n                projs.append(proj.to(proj_type))\n        input = torch.randn((layer_num, qlen, input_size), dtype=torch.bfloat16, device = \"cuda\").to(\"cpu\").contiguous()\n\n        # warm up\n        for i in range(warm_up_iter):\n            if isinstance(projs[i % layer_num], nnq.Linear):\n                input_q = 
torch.quantize_per_tensor(input[i % layer_num].to(torch.float32), scale, zero_point, torch.quint8)\n                t_output = projs[i % layer_num](input_q)\n            else:\n                t_output = torch.mm(input[i % layer_num].to(proj_type), projs[i % layer_num].t())\n\n        # test\n        start = time.perf_counter()\n        for i in range(test_iter):\n            if isinstance(projs[i % layer_num], nnq.Linear):\n                input_q = torch.quantize_per_tensor(input[i % layer_num].to(torch.float32), scale, zero_point, torch.quint8)\n                t_output = projs[i % layer_num](input_q)\n            else:\n                t_output = torch.mm(input[i % layer_num].to(proj_type), projs[i % layer_num].t())\n        end = time.perf_counter()\n        total_time = end - start\n        print('Quant mode: ', quant_mode)\n        print('Time(s): ', total_time)\n        print('Iteration: ', test_iter) \n        print('Time(us) per iteration: ', total_time / test_iter * 1000000)\n        print('Bandwidth: ', input_size * output_size * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')\n        print('')\n\nbench_linear(\"fp32\")\nbench_linear(\"fp16\")\nbench_linear(\"bf16\")\nbench_linear(\"qint8\")\n"
  },
  {
    "path": "kt-kernel/bench/bench_mla.py",
    "content": "import os, sys\nimport time\nimport subprocess\nimport platform\nimport json\n\nos.environ[\"BLAS_NUM_THREADS\"] = \"1\"\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\", \"build\"))\nfrom kt_kernel import kt_kernel_ext\nfrom kt_kernel_ext.kvcache import ggml_type\nimport torch\nfrom torch import inf, nn\nfrom torch.nn import init\n\nfrom tqdm import tqdm\n\nqlen = 4096\nkvlen = 0\npage_table = list(range(20))\npage_size = 256\npages_count = 200\n\n\nhidden_size = 7168\nnum_heads = 128\nkv_lora_rank = 512\nq_lora_rank = 512\nnope_size = 128\nrope_size = 64\npage_size = 512\nlayer_num = 10\n\n\nrope_theta = 10000\nmax_qlen = qlen + kvlen\nmax_kvlen = 4096\nmax_position_embeddings = 163840\n\nrope_scaling = {\n    \"beta_fast\": 32,\n    \"beta_slow\": 1,\n    \"factor\": 40,\n    \"mscale\": 1.0,\n    \"mscale_all_dim\": 1.0,\n    \"original_max_position_embeddings\": 4096,\n    \"type\": \"yarn\",\n}\n\nCPUINFER_PARAM = 304\n# Initialize CPUInfer (the original constructor is used here; adjust the configuration parameter as needed)\nCPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)\n\n\nwarm_up_iter = 20\ntest_iter = 100\n\n\n# Gather script information used to build the name of the results file\nscript_path = os.path.abspath(__file__)\nscript_dir = os.path.dirname(script_path)\nscript_name = os.path.splitext(os.path.basename(script_path))[0]\njson_path = os.path.join(script_dir, \"bench_results \" + \".jsonl\")\n\n\ndef get_git_commit():\n    \"\"\"\n    Get the current git commit (commit hash and commit message)\n    and check whether there are uncommitted changes (dirty).\n    \"\"\"\n    result = {}\n    try:\n        commit = subprocess.check_output([\"git\", \"rev-parse\", \"HEAD\"]).decode(\"utf-8\").strip()\n        commit_msg = subprocess.check_output([\"git\", \"log\", \"-1\", \"--pretty=%B\"]).decode(\"utf-8\").strip()\n        result[\"commit\"] = commit\n        result[\"commit_message\"] = commit_msg\n\n        # Check whether there are uncommitted changes\n        dirty_output = subprocess.check_output([\"git\", \"status\", \"--porcelain\"]).decode(\"utf-8\").strip()\n        if dirty_output:\n            result[\"dirty\"] 
= True\n            result[\"dirty_files\"] = dirty_output.splitlines()\n        else:\n            result[\"dirty\"] = False\n    except Exception as e:\n        result[\"commit\"] = None\n        result[\"commit_message\"] = None\n        result[\"dirty\"] = None\n        result[\"error\"] = str(e)\n    return result\n\n\ndef get_system_info():\n    \"\"\"\n    Get system information: system name, CPU model, memory size (GB), CPU core count, and socket count.\n    \"\"\"\n    info = {}\n    uname = platform.uname()\n    info[\"system_name\"] = uname.system\n    info[\"node_name\"] = uname.node\n\n    # Get the CPU model (Linux only)\n    cpu_model = None\n    if os.path.exists(\"/proc/cpuinfo\"):\n        try:\n            with open(\"/proc/cpuinfo\", \"r\") as f:\n                for line in f:\n                    if \"model name\" in line:\n                        cpu_model = line.split(\":\", 1)[1].strip()\n                        break\n        except Exception as e:\n            cpu_model = f\"Error: {e}\"\n    info[\"cpu_model\"] = cpu_model\n\n    # Get the memory size in GB (Linux only)\n    mem_total_gb = None\n    if os.path.exists(\"/proc/meminfo\"):\n        try:\n            with open(\"/proc/meminfo\", \"r\") as f:\n                for line in f:\n                    if \"MemTotal\" in line:\n                        mem_kb = float(line.split(\":\", 1)[1].split()[0])\n                        mem_total_gb = round(mem_kb / (1024 * 1024), 2)\n                        break\n        except Exception as e:\n            mem_total_gb = f\"Error: {e}\"\n    info[\"memory_size_GB\"] = mem_total_gb\n\n    info[\"cpu_core_count\"] = os.cpu_count()\n\n    # Parse /proc/cpuinfo to get the socket count\n    sockets = set()\n    if os.path.exists(\"/proc/cpuinfo\"):\n        try:\n            with open(\"/proc/cpuinfo\", \"r\") as f:\n                for line in f:\n                    if \"physical id\" in line:\n                        sockets.add(line.split(\":\", 1)[1].strip())\n        except Exception as e:\n            sockets = set()\n    
info[\"cpu_socket_count\"] = len(sockets) if len(sockets) > 0 else 1\n\n    return info\n\n\ndef record_results(result, filename=json_path):\n    \"\"\"\n    Append the result to the file in JSON format.\n    \"\"\"\n    with open(filename, \"a\") as f:\n        f.write(json.dumps(result) + \"\\n\")\n\n\ndef bench_mla(quant_mode: str):\n    \"\"\"\n    Benchmark the performance of the MLA model.\n    \"\"\"\n    with torch.inference_mode():\n        # The concrete MLA implementation and test code can be added here\n        hidden_type = 1  # ggml_type::GGML_TYPE_F16 (fixed)\n        if quant_mode == \"fp32\":\n            q_a_proj_type = 0  # ggml_type::GGML_TYPE_F32\n            q_b_proj_type = 0\n            kv_a_proj_with_mqa_type = 0\n            kv_b_proj_type = 0\n            w_o_type = 0\n            bytes_per_elem = 4.000000\n        elif quant_mode == \"fp16\":\n            q_a_proj_type = 1  # ggml_type::GGML_TYPE_F16\n            q_b_proj_type = 1\n            kv_a_proj_with_mqa_type = 1\n            kv_b_proj_type = 1\n            w_o_type = 1\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"q4_k_m\":\n            q_a_proj_type = 12  # ggml_type::GGML_TYPE_Q4_K\n            q_b_proj_type = 12\n            kv_a_proj_with_mqa_type = 12  # ggml_type::GGML_TYPE_Q4_K\n            kv_b_proj_type = 12\n            w_o_type = 12\n            bytes_per_elem = 0.5625\n        else:\n            raise ValueError(\"Unsupported quantization mode\")\n\n        # Build the input data for each layer's MLA model\n        mlas = []\n        for i in tqdm(range(layer_num)):\n            q_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=False, dtype=torch.float16)\n            q_b_proj = nn.Linear(q_lora_rank, num_heads * (nope_size + rope_size), bias=False, dtype=torch.float16)\n            kv_a_proj_with_mqa = nn.Linear(hidden_size, kv_lora_rank + rope_size, bias=False, dtype=torch.float16)\n            kv_b_proj = nn.Linear(num_heads * (nope_size + nope_size), kv_lora_rank, bias=False, dtype=torch.float16)\n            o_proj = nn.Linear(num_heads * nope_size, hidden_size, bias=False, 
dtype=torch.float16)\n\n            init.normal_(q_a_proj.weight, mean=0.0, std=0.02)\n            init.normal_(q_b_proj.weight, mean=0.0, std=0.02)\n            init.normal_(kv_a_proj_with_mqa.weight, mean=0.0, std=0.02)\n            init.normal_(kv_b_proj.weight, mean=0.0, std=0.02)\n            init.normal_(o_proj.weight, mean=0.0, std=0.02)\n            q_a_proj_weight = q_a_proj.weight.to(torch.float16).to(\"cpu\").contiguous()\n            q_b_proj_weight = q_b_proj.weight.to(torch.float16).to(\"cpu\").contiguous()\n            kv_a_proj_with_mqa_weight = kv_a_proj_with_mqa.weight.to(\"cpu\").to(torch.float16).contiguous()\n            kv_b_proj_weight = kv_b_proj.weight.to(torch.float16).to(\"cpu\").contiguous()\n            o_proj_weight = o_proj.weight.to(torch.float16).to(\"cpu\").contiguous()\n\n            config = kt_kernel_ext.mla.MLAConfig(\n                hidden_size,\n                q_lora_rank,\n                kv_lora_rank,\n                num_heads,\n                nope_size,\n                rope_size,\n            )\n            config.max_qlen = max_qlen\n            config.max_kvlen = max_kvlen\n            config.max_position_embeddings = max_position_embeddings\n            config.rope_scaling_factor = rope_scaling[\"factor\"]\n            config.rope_theta = rope_theta\n            config.rope_scaling_beta_fast = rope_scaling[\"beta_fast\"]\n            config.rope_scaling_beta_slow = rope_scaling[\"beta_slow\"]\n            config.rope_scaling_mscale = rope_scaling[\"mscale\"]\n            config.rope_scaling_mscale_all_dim = rope_scaling[\"mscale_all_dim\"]\n            config.rope_scaling_original_max_position_embeddings = rope_scaling[\"original_max_position_embeddings\"]\n\n            config.q_a_proj = q_a_proj_weight.data_ptr()\n            config.q_b_proj = q_b_proj_weight.data_ptr()\n            config.kv_a_proj_with_mqa = kv_a_proj_with_mqa_weight.data_ptr()\n            config.kv_b_proj = kv_b_proj_weight.data_ptr()\n       
     config.o_proj = o_proj_weight.data_ptr()\n\n            config.q_a_proj_type = ggml_type.FP16\n            config.q_b_proj_type = ggml_type.FP16\n            config.kv_a_proj_with_mqa_type = ggml_type.FP16\n            config.kv_b_proj_type = ggml_type.FP16\n            config.w_o_type = ggml_type.FP16\n\n            config.pool = CPUInfer.backend_\n\n            mla = kt_kernel_ext.mla.MLA(config)\n            mla.load_weights()\n            mla.set_local_pages(pages_count)\n            mlas.append(mla)\n\n        print(\"Generating data...\")\n        input_tensor = (\n            torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device=\"cpu\").to(\"cpu\").contiguous()\n        )\n        output_tensor = (\n            torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device=\"cpu\").to(\"cpu\").contiguous()\n        )\n\n        print(\"Warming up...\")\n\n        for i in tqdm(range(warm_up_iter)):\n            mlas[i % layer_num].forward(\n                [qlen],\n                [page_table],\n                [kvlen],\n                input_tensor[i % layer_num].data_ptr(),\n                output_tensor[i % layer_num].data_ptr(),\n            )\n\n        print(\"Start testing...\")\n\n        start = time.perf_counter()\n        for i in tqdm(range(test_iter)):\n            mlas[i % layer_num].forward(\n                [qlen],\n                [page_table],\n                [kvlen],\n                input_tensor[i % layer_num].data_ptr(),\n                output_tensor[i % layer_num].data_ptr(),\n            )\n\n        end = time.perf_counter()\n        total_time = end - start\n\n        time_per_iter_us = (total_time * 1e6) / test_iter\n        bandwidth = (\n            bytes_per_elem\n            * (\n                q_lora_rank * hidden_size\n                + (kv_lora_rank + rope_size) * hidden_size\n                + (nope_size + rope_size) * q_lora_rank * num_heads\n                + (nope_size + nope_size) * 
kv_lora_rank * num_heads\n                + hidden_size * nope_size * num_heads\n                + hidden_size * qlen\n            )\n            * test_iter\n            / (total_time * 1e9)\n        )\n        flops = (\n            2\n            * (\n                q_lora_rank * hidden_size * qlen\n                + kv_lora_rank * hidden_size * qlen\n                + num_heads * (nope_size + rope_size) * q_lora_rank * qlen\n                + num_heads * qlen * nope_size * kv_lora_rank\n                + num_heads * (kvlen + qlen) * kv_lora_rank * qlen\n                + num_heads * rope_size * qlen * (qlen + kvlen)\n                + num_heads * kv_lora_rank * (qlen + kvlen) * qlen\n                + num_heads * nope_size * kv_lora_rank * qlen\n                + hidden_size * num_heads * nope_size * qlen\n            )\n            * test_iter\n            / (total_time * 1e12)\n        )\n\n        print(\"Quant mode:\", quant_mode)\n        print(\"Time(s):\", total_time)\n        print(\"Iteration:\", test_iter)\n        print(\"Time(us) per iteration:\", time_per_iter_us)\n        print(\"Bandwidth:\", bandwidth, \"GB/s\")\n        print(\"TFLOPS:\", flops)\n        print(\"\")\n\n        # Collect the test results\n        result = {\n            \"test_name\": os.path.basename(__file__),\n            \"quant_mode\": quant_mode,\n            \"total_time_seconds\": total_time,\n            \"iterations\": test_iter,\n            \"time_per_iteration_us\": time_per_iter_us,\n            \"bandwidth_GBs\": bandwidth,\n            \"flops_TFLOPS\": flops,\n            \"timestamp\": time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime()),\n            \"test_parameters\": {\n                \"qlen\": qlen,\n                \"kvlen\": kvlen,\n                \"page_table\": page_table,\n                \"page_size\": page_size,\n                \"pages_count\": pages_count,\n                \"hidden_size\": hidden_size,\n                \"num_heads\": num_heads,\n              
  \"kv_lora_rank\": kv_lora_rank,\n                \"q_lora_rank\": q_lora_rank,\n                \"nope_size\": nope_size,\n                \"rope_size\": rope_size,\n                \"layer_num\": layer_num,\n                \"rope_theta\": rope_theta,\n                \"max_qlen\": max_qlen,\n                \"max_kvlen\": max_kvlen,\n                \"max_position_embeddings\": max_position_embeddings,\n                \"rope_scaling\": rope_scaling,\n                \"warm_up_iter\": warm_up_iter,\n                \"test_iter\": test_iter,\n                \"CPUInfer_parameter\": CPUINFER_PARAM,\n            },\n        }\n        # Add git and system information\n        result.update(get_git_commit())\n        result.update(get_system_info())\n        # Record the results to the JSON file\n        print(result)\n        record_results(result)\n\n\nbench_mla(\"fp16\")\n"
  },
  {
    "path": "kt-kernel/bench/bench_mlp.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nDescription  :\nAuthor       : chenht2022\nDate         : 2024-07-16 10:43:18\nVersion      : 1.0.0\nLastEditors  : chenht2022\nLastEditTime : 2024-08-06 10:36:04\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved.\n\"\"\"\nimport os, sys\nimport time\n\nsys.path.append(os.path.dirname(__file__) + \"/../build\")\nfrom kt_kernel import kt_kernel_ext\nimport torch\n\nhidden_size = 5120\nintermediate_size = 3072\nstride = 16\ngroup_max_len = 1024\nlayer_num = 10\nqlen = 1\nCPUInfer = kt_kernel_ext.CPUInfer(64)\nwarm_up_iter = 1000\ntest_iter = 10000\n\n\ndef bench_mlp(quant_mode: str):\n    with torch.inference_mode(mode=True):\n\n        hidden_type = 30  # ggml_type::GGML_TYPE_BF16\n        if quant_mode == \"fp32\":\n            gate_type = 0  # ggml_type::GGML_TYPE_F32\n            up_type = 0  # ggml_type::GGML_TYPE_F32\n            down_type = 0  # ggml_type::GGML_TYPE_F32\n            bytes_per_elem = 4.000000\n        elif quant_mode == \"fp16\":\n            gate_type = 1  # ggml_type::GGML_TYPE_F16\n            up_type = 1  # ggml_type::GGML_TYPE_F16\n            down_type = 1  # ggml_type::GGML_TYPE_F16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"bf16\":\n            gate_type = 30  # ggml_type::GGML_TYPE_BF16\n            up_type = 30  # ggml_type::GGML_TYPE_BF16\n            down_type = 30  # ggml_type::GGML_TYPE_BF16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"q8_0\":\n            gate_type = 8  # ggml_type::GGML_TYPE_Q8_0\n            up_type = 8  # ggml_type::GGML_TYPE_Q8_0\n            down_type = 8  # ggml_type::GGML_TYPE_Q8_0\n            bytes_per_elem = 1.062500\n        elif quant_mode == \"q6_k\":\n            gate_type = 14  # ggml_type::GGML_TYPE_Q6_K\n            up_type = 14  # ggml_type::GGML_TYPE_Q6_K\n            down_type = 14  # ggml_type::GGML_TYPE_Q6_K\n            bytes_per_elem = 0.820312\n        elif quant_mode == 
\"q5_k_m\":\n            gate_type = 13  # ggml_type::GGML_TYPE_Q5_K\n            up_type = 13  # ggml_type::GGML_TYPE_Q5_K\n            down_type = 14  # ggml_type::GGML_TYPE_Q6_K\n            bytes_per_elem = 0.731771\n        elif quant_mode == \"q4_k_m\":\n            gate_type = 12  # ggml_type::GGML_TYPE_Q4_K\n            up_type = 12  # ggml_type::GGML_TYPE_Q4_K\n            down_type = 14  # ggml_type::GGML_TYPE_Q6_K\n            bytes_per_elem = 0.648437\n        elif quant_mode == \"q3_k_m\":\n            gate_type = 11  # ggml_type::GGML_TYPE_Q3_K\n            up_type = 11  # ggml_type::GGML_TYPE_Q3_K\n            down_type = 13  # ggml_type::GGML_TYPE_Q5_K\n            bytes_per_elem = 0.515625\n        elif quant_mode == \"q2_k\":\n            gate_type = 10  # ggml_type::GGML_TYPE_Q2_K\n            up_type = 10  # ggml_type::GGML_TYPE_Q2_K\n            down_type = 11  # ggml_type::GGML_TYPE_Q3_K\n            bytes_per_elem = 0.328125\n        elif quant_mode == \"iq3_xs\":\n            gate_type = 21  # ggml_type::GGML_TYPE_IQ3_S\n            up_type = 21  # ggml_type::GGML_TYPE_IQ3_S\n            down_type = 21  # ggml_type::GGML_TYPE_IQ3_S\n            bytes_per_elem = 0.429688\n        elif quant_mode == \"iq2_xxs\":\n            gate_type = 16  # ggml_type::GGML_TYPE_IQ2_XXS\n            up_type = 16  # ggml_type::GGML_TYPE_IQ2_XXS\n            down_type = 16  # ggml_type::GGML_TYPE_IQ2_XXS\n            bytes_per_elem = 0.257812\n        else:\n            assert False\n\n        mlps = []\n        gate_projs = []\n        up_projs = []\n        down_projs = []\n        for _ in range(layer_num):\n            gate_proj = (\n                torch.randn((intermediate_size, hidden_size), dtype=torch.float32, device=\"cuda\").to(\"cpu\").contiguous()\n            )\n            up_proj = (\n                torch.randn((intermediate_size, hidden_size), dtype=torch.float32, device=\"cuda\").to(\"cpu\").contiguous()\n            )\n            down_proj 
= (\n                torch.randn((hidden_size, intermediate_size), dtype=torch.float32, device=\"cuda\").to(\"cpu\").contiguous()\n            )\n            config = kt_kernel_ext.mlp.MLPConfig(\n                hidden_size,\n                intermediate_size,\n                stride,\n                group_max_len,\n                gate_proj.data_ptr(),\n                up_proj.data_ptr(),\n                down_proj.data_ptr(),\n                gate_type,\n                up_type,\n                down_type,\n                hidden_type,\n            )\n            mlp = kt_kernel_ext.mlp.MLP(config)\n            gate_projs.append(gate_proj)\n            up_projs.append(up_proj)\n            down_projs.append(down_proj)\n            mlps.append(mlp)\n        input = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device=\"cuda\").to(\"cpu\").contiguous()\n        output = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device=\"cuda\").to(\"cpu\").contiguous()\n\n        # warm up\n        for i in range(warm_up_iter):\n            CPUInfer.submit(\n                mlps[i % layer_num].forward(qlen, input[i % layer_num].data_ptr(), output[i % layer_num].data_ptr())\n            )\n            CPUInfer.sync()\n\n        # test\n        start = time.perf_counter()\n        for i in range(test_iter):\n            CPUInfer.submit(\n                mlps[i % layer_num].forward(qlen, input[i % layer_num].data_ptr(), output[i % layer_num].data_ptr())\n            )\n            CPUInfer.sync()\n        end = time.perf_counter()\n        total_time = end - start\n        print(\"Quant mode: \", quant_mode)\n        print(\"Time(s): \", total_time)\n        print(\"Iteration: \", test_iter)\n        print(\"Time(us) per iteration: \", total_time / test_iter * 1000000)\n        print(\n            \"Bandwidth: \",\n            hidden_size * intermediate_size * 3 * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000,\n            
\"GB/s\",\n        )\n        print(\"\")\n\n\nbench_mlp(\"fp32\")\nbench_mlp(\"fp16\")\nbench_mlp(\"bf16\")\nbench_mlp(\"q8_0\")\nbench_mlp(\"q6_k\")\nbench_mlp(\"q5_k_m\")\nbench_mlp(\"q4_k_m\")\nbench_mlp(\"q3_k_m\")\nbench_mlp(\"q2_k\")\n# Not supported on __x86_64__\n# bench_mlp(\"iq3_xs\")\n# bench_mlp(\"iq2_xxs\")\n"
  },
  {
    "path": "kt-kernel/bench/bench_mlp_torch.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : chenht2022\nDate         : 2024-07-16 10:43:18\nVersion      : 1.0.0\nLastEditors  : chenht2022 \nLastEditTime : 2024-07-25 10:32:53\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport os, sys\nimport time\nimport torch\nimport torch.nn.quantized as nnq\n\nscale, zero_point = 0.1, 0  # Adjust scale and zero_point based on your dataset\n\nhidden_size = 5120\nintermediate_size = 3072\nlayer_num = 10\nqlen = 1\nwarm_up_iter = 1000\ntest_iter = 10000\n\ndef act_fn(x):\n    return x / (1.0 + torch.exp(-x))\n\ndef mlp_torch(input, gate_proj, up_proj, down_proj):\n    if isinstance(gate_proj, nnq.Linear):\n        input_q = torch.quantize_per_tensor(input.to(torch.float32), scale, zero_point, torch.quint8)\n        gate_buf = gate_proj(input_q)\n        up_buf = up_proj(input_q)\n        gate_buf = gate_buf.dequantize()\n        up_buf = up_buf.dequantize()\n        intermediate = act_fn(gate_buf) * up_buf\n        intermediate_q = torch.quantize_per_tensor(intermediate, scale, zero_point, torch.quint8)\n        expert_output = down_proj(intermediate_q)\n        ret = expert_output.dequantize()\n    else:\n        gate_buf = torch.mm(input.to(gate_proj.dtype), gate_proj.t())\n        up_buf = torch.mm(input.to(up_proj.dtype), up_proj.t())\n        intermediate = act_fn(gate_buf) * up_buf\n        ret = torch.mm(intermediate.to(down_proj.dtype), down_proj.t())\n    return ret\n\ndef bench_mlp(quant_mode: str):\n    with torch.inference_mode(mode=True):\n        if quant_mode == \"fp32\":\n            proj_type = torch.float32\n            bytes_per_elem = 4.000000\n        elif quant_mode == \"fp16\":\n            proj_type = torch.float16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"bf16\":\n            proj_type = torch.bfloat16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"qint8\":\n            proj_type = 
torch.qint8\n            bytes_per_elem = 1.000000\n        else:\n            assert(False)\n\n        gate_projs = []\n        up_projs = []\n        down_projs = []\n        for _ in range(layer_num):\n            gate_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            up_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            down_proj = torch.randn((hidden_size, intermediate_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            if quant_mode == \"qint8\":\n                gate_proj_q = torch.quantize_per_tensor(gate_proj, scale, zero_point, torch.qint8)\n                quantized_gate = nnq.Linear(hidden_size, intermediate_size)\n                quantized_gate.set_weight_bias(gate_proj_q, None)\n                up_proj_q = torch.quantize_per_tensor(up_proj, scale, zero_point, torch.qint8)\n                quantized_up = nnq.Linear(hidden_size, intermediate_size)\n                quantized_up.set_weight_bias(up_proj_q, None)\n                down_proj_q = torch.quantize_per_tensor(down_proj, scale, zero_point, torch.qint8)\n                quantized_down = nnq.Linear(intermediate_size, hidden_size)\n                quantized_down.set_weight_bias(down_proj_q, None)\n                gate_projs.append(quantized_gate)\n                up_projs.append(quantized_up)\n                down_projs.append(quantized_down)\n            else:\n                gate_projs.append(gate_proj.to(proj_type))\n                up_projs.append(up_proj.to(proj_type))\n                down_projs.append(down_proj.to(proj_type))\n        input = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = \"cuda\").to(\"cpu\").contiguous()\n\n        # warm up\n        for i in range(warm_up_iter):\n            mlp_torch(input[i % layer_num], gate_projs[i % layer_num], up_projs[i % 
layer_num], down_projs[i % layer_num])\n\n        # test\n        start = time.perf_counter()\n        for i in range(test_iter):\n            mlp_torch(input[i % layer_num], gate_projs[i % layer_num], up_projs[i % layer_num], down_projs[i % layer_num])\n        end = time.perf_counter()\n        total_time = end - start\n        print('Quant mode: ', quant_mode)\n        print('Time(s): ', total_time)\n        print('Iteration: ', test_iter) \n        print('Time(us) per iteration: ', total_time / test_iter * 1000000)\n        print('Bandwidth: ', hidden_size * intermediate_size * 3 * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')\n        print('')\n\nbench_mlp(\"fp32\")\nbench_mlp(\"fp16\")\nbench_mlp(\"bf16\")\nbench_mlp(\"qint8\")\n"
  },
  {
    "path": "kt-kernel/bench/bench_moe.py",
    "content": "import os\nimport sys\nimport time\nimport json\nimport subprocess\nimport platform\n\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\", \"build\"))\nfrom kt_kernel import kt_kernel_ext\nimport torch\nfrom tqdm import tqdm\n\n# Test parameter settings\nexpert_num = 256\nhidden_size = 7168\nintermediate_size = 2048\nm_block = 1\ngroup_min_len = 10\ngroup_max_len = 1024\nnum_experts_per_tok = 8\n# layer_num = 5  # number of layers used during testing\n# qlen = 1\n# warm_up_iter = 100\n# test_iter = 10000\n\nlayer_num = 1  # number of layers used during testing\nqlen = 1024\nwarm_up_iter = 100\ntest_iter = 10000\nCPUINFER_PARAM = 304\n# Initialize CPUInfer (the original constructor is used here; adjust the configuration parameter as needed)\nCPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)\n\n# Gather script information used to build the name of the results file\nscript_path = os.path.abspath(__file__)\nscript_dir = os.path.dirname(script_path)\nscript_name = os.path.splitext(os.path.basename(script_path))[0]\njson_path = os.path.join(script_dir, \"bench_results \" + \".jsonl\")\n\n\ndef get_git_commit():\n    \"\"\"\n    Get the current git commit (commit hash and commit message)\n    and check whether there are uncommitted changes (dirty).\n    \"\"\"\n    result = {}\n    try:\n        commit = subprocess.check_output([\"git\", \"rev-parse\", \"HEAD\"]).decode(\"utf-8\").strip()\n        commit_msg = subprocess.check_output([\"git\", \"log\", \"-1\", \"--pretty=%B\"]).decode(\"utf-8\").strip()\n        result[\"commit\"] = commit\n        result[\"commit_message\"] = commit_msg\n\n        # Check whether there are uncommitted changes\n        dirty_output = subprocess.check_output([\"git\", \"status\", \"--porcelain\"]).decode(\"utf-8\").strip()\n        if dirty_output:\n            result[\"dirty\"] = True\n            result[\"dirty_files\"] = dirty_output.splitlines()\n        else:\n            result[\"dirty\"] = False\n    except Exception as e:\n        result[\"commit\"] = None\n        result[\"commit_message\"] = None\n        result[\"dirty\"] = None\n        result[\"error\"] = str(e)\n    return result\n\n\ndef get_system_info():\n    \"\"\"\n    Get system information: system name, CPU model, memory size (GB), CPU core count, and socket count.\n 
   \"\"\"\n    info = {}\n    uname = platform.uname()\n    info[\"system_name\"] = uname.system\n    info[\"node_name\"] = uname.node\n\n    # Get the CPU model (Linux only)\n    cpu_model = None\n    if os.path.exists(\"/proc/cpuinfo\"):\n        try:\n            with open(\"/proc/cpuinfo\", \"r\") as f:\n                for line in f:\n                    if \"model name\" in line:\n                        cpu_model = line.split(\":\", 1)[1].strip()\n                        break\n        except Exception as e:\n            cpu_model = f\"Error: {e}\"\n    info[\"cpu_model\"] = cpu_model\n\n    # Get the memory size in GB (Linux only)\n    mem_total_gb = None\n    if os.path.exists(\"/proc/meminfo\"):\n        try:\n            with open(\"/proc/meminfo\", \"r\") as f:\n                for line in f:\n                    if \"MemTotal\" in line:\n                        mem_kb = float(line.split(\":\", 1)[1].split()[0])\n                        mem_total_gb = round(mem_kb / (1024 * 1024), 2)\n                        break\n        except Exception as e:\n            mem_total_gb = f\"Error: {e}\"\n    info[\"memory_size_GB\"] = mem_total_gb\n\n    info[\"cpu_core_count\"] = os.cpu_count()\n\n    # Parse /proc/cpuinfo to get the socket count\n    sockets = set()\n    if os.path.exists(\"/proc/cpuinfo\"):\n        try:\n            with open(\"/proc/cpuinfo\", \"r\") as f:\n                for line in f:\n                    if \"physical id\" in line:\n                        sockets.add(line.split(\":\", 1)[1].strip())\n        except Exception as e:\n            sockets = set()\n    info[\"cpu_socket_count\"] = len(sockets) if len(sockets) > 0 else 1\n\n    return info\n\n\ndef record_results(result, filename=json_path):\n    \"\"\"\n    Append the result to the file in JSON format.\n    \"\"\"\n    with open(filename, \"a\") as f:\n        f.write(json.dumps(result) + \"\\n\")\n\n\ndef bench_moe(quant_mode: str):\n    \"\"\"\n    Benchmark MoE performance for the given quantization mode, with a warm-up phase and a test phase.\n    \"\"\"\n    with torch.inference_mode():\n        # 
Set the data types and bytes_per_elem according to the quantization mode\n        hidden_type = 30  # ggml_type::GGML_TYPE_BF16 (fixed)\n        if quant_mode == \"fp32\":\n            gate_type = 0  # ggml_type::GGML_TYPE_F32\n            up_type = 0\n            down_type = 0\n            bytes_per_elem = 4.0\n        elif quant_mode == \"fp16\":\n            gate_type = 1  # ggml_type::GGML_TYPE_F16\n            up_type = 1\n            down_type = 1\n            bytes_per_elem = 2.0\n        elif quant_mode == \"bf16\":\n            gate_type = 30  # ggml_type::GGML_TYPE_BF16\n            up_type = 30\n            down_type = 30\n            bytes_per_elem = 2.0\n        elif quant_mode == \"q8_0\":\n            gate_type = 8  # ggml_type::GGML_TYPE_Q8_0\n            up_type = 8\n            down_type = 8\n            bytes_per_elem = 1.062500\n        elif quant_mode == \"q6_k\":\n            gate_type = 14  # ggml_type::GGML_TYPE_Q6_K\n            up_type = 14\n            down_type = 14\n            bytes_per_elem = 0.820312\n        elif quant_mode == \"q5_k_m\":\n            gate_type = 13  # ggml_type::GGML_TYPE_Q5_K\n            up_type = 13\n            down_type = 14  # ggml_type::GGML_TYPE_Q6_K\n            bytes_per_elem = 0.731771\n        elif quant_mode == \"q4_k_m\":\n            gate_type = 12  # ggml_type::GGML_TYPE_Q4_K\n            up_type = 12\n            down_type = 14  # ggml_type::GGML_TYPE_Q6_K\n            bytes_per_elem = 0.648437\n        elif quant_mode == \"q3_k_m\":\n            gate_type = 11  # ggml_type::GGML_TYPE_Q3_K\n            up_type = 11\n            down_type = 13  # ggml_type::GGML_TYPE_Q5_K\n            bytes_per_elem = 0.515625\n        elif quant_mode == \"q2_k\":\n            gate_type = 10  # ggml_type::GGML_TYPE_Q2_K\n            up_type = 10\n            down_type = 11  # ggml_type::GGML_TYPE_Q3_K\n            bytes_per_elem = 0.328125\n        elif quant_mode == \"iq3_xs\":\n            gate_type = 21  # ggml_type::GGML_TYPE_IQ3_S\n            up_type = 21\n 
           down_type = 21\n            bytes_per_elem = 0.429688\n        elif quant_mode == \"iq2_xxs\":\n            gate_type = 16  # ggml_type::GGML_TYPE_IQ2_XXS\n            up_type = 16\n            down_type = 16\n            bytes_per_elem = 0.257812\n        else:\n            raise ValueError(\"Unsupported quantization mode\")\n\n        # Build the MoE module for each layer\n        moes = []\n        for _ in tqdm(range(layer_num), desc=\"Initializing MOEs\"):\n            gate_proj = (\n                torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float16, device=\"cpu\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            up_proj = (\n                torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float16, device=\"cpu\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            down_proj = (\n                torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float16, device=\"cpu\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n\n            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)\n            config.pool = CPUInfer.backend_\n            config.m_block = m_block\n            config.group_min_len = group_min_len\n            config.group_max_len = group_max_len\n            config.gate_proj = gate_proj.data_ptr()\n            config.up_proj = up_proj.data_ptr()\n            config.down_proj = down_proj.data_ptr()\n            config.gate_type = gate_type\n            config.up_type = up_type\n            config.down_type = down_type\n            config.hidden_type = hidden_type\n\n            moe = kt_kernel_ext.moe.MOE(config)\n            CPUInfer.submit(moe.load_weights_task())\n            CPUInfer.sync()\n            moes.append(moe)\n\n        # Generate input data\n        print(\"Generating data...\")\n        # Expert routing indices and weights, one set per layer\n        gen_iter = 1000\n        expert_ids = (\n   
         torch.rand(gen_iter * qlen, expert_num, device=\"cpu\")\n            .argsort(dim=-1)[:, :num_experts_per_tok]\n            .reshape(gen_iter, qlen * num_experts_per_tok)\n            .contiguous()\n        )\n        weights = torch.rand((gen_iter, qlen, num_experts_per_tok), dtype=torch.float32, device=\"cpu\").contiguous()\n        input_tensor = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device=\"cpu\").contiguous()\n        output_tensor = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device=\"cpu\").contiguous()\n        # Wrap qlen in a tensor so its address can be passed to the forward call\n        qlen_tensor = torch.tensor([qlen], dtype=torch.int32)\n\n        # Warm-up phase\n        print(\"Warming up...\")\n        for i in tqdm(range(warm_up_iter), desc=\"Warm-up\"):\n            CPUInfer.submit(\n                moes[i % layer_num].forward_task(\n                    qlen_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids[i % gen_iter].data_ptr(),\n                    weights[i % gen_iter].data_ptr(),\n                    input_tensor[i % layer_num].data_ptr(),\n                    output_tensor[i % layer_num].data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n\n        # Benchmark phase\n        print(\"Start testing...\")\n        start = time.perf_counter()\n        for i in tqdm(range(test_iter), desc=\"Testing\"):\n            CPUInfer.submit(\n                moes[i % layer_num].forward_task(\n                    qlen_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids[i % gen_iter].data_ptr(),\n                    weights[i % gen_iter].data_ptr(),\n                    input_tensor[i % layer_num].data_ptr(),\n                    output_tensor[i % layer_num].data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n        end = time.perf_counter()\n        total_time = end 
- start\n\n        # 计算性能指标\n        time_per_iter_us = total_time / test_iter * 1e6\n        bandwidth = (\n            hidden_size\n            * intermediate_size\n            * 3\n            * num_experts_per_tok\n            * (1 / 8 * 256 * (1 - (31 / 32) ** qlen))\n            * bytes_per_elem\n            * test_iter\n            / total_time\n            / 1e9\n        )  # 单位：GB/s\n        flops = (\n            hidden_size * intermediate_size * qlen * 3 * num_experts_per_tok * 2 * test_iter / total_time / 1e12\n        )  # 单位：TFLOPS\n\n        # 打印结果\n        print(\"Quant mode:\", quant_mode)\n        print(\"Time(s):\", total_time)\n        print(\"Iteration:\", test_iter)\n        print(\"Time(us) per iteration:\", time_per_iter_us)\n        print(\"Bandwidth:\", bandwidth, \"GB/s\")\n        print(\"TFLOPS:\", flops)\n        print(\"\")\n\n        # 整理测试结果\n        result = {\n            \"test_name\": os.path.basename(__file__),\n            \"quant_mode\": quant_mode,\n            \"total_time_seconds\": total_time,\n            \"iterations\": test_iter,\n            \"time_per_iteration_us\": time_per_iter_us,\n            \"bandwidth_GBs\": bandwidth,\n            \"flops_TFLOPS\": flops,\n            \"timestamp\": time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime()),\n            \"test_parameters\": {\n                \"expert_num\": expert_num,\n                \"hidden_size\": hidden_size,\n                \"intermediate_size\": intermediate_size,\n                \"m_block\": m_block,\n                \"group_min_len\": group_min_len,\n                \"group_max_len\": group_max_len,\n                \"num_experts_per_tok\": num_experts_per_tok,\n                \"layer_num\": layer_num,\n                \"qlen\": qlen,\n                \"warm_up_iter\": warm_up_iter,\n                \"test_iter\": test_iter,\n                \"CPUInfer_parameter\": CPUINFER_PARAM,\n            },\n        }\n        # 添加 git 与系统信息\n        
result.update(get_git_commit())\n        result.update(get_system_info())\n        # Append the result to the JSON Lines file\n        record_results(result)\n\n\nif __name__ == \"__main__\":\n    # Select the quantization mode as needed; q4_k_m is currently run for each layer count in the layer_nums list\n    bench_moe(\"q4_k_m\")\n    # Uncomment the calls below to benchmark other quantization modes\n    # bench_moe(\"fp32\", layer_num)\n    # bench_moe(\"fp16\", layer_num)\n    # bench_moe(\"bf16\", layer_num)\n    # bench_moe(\"q8_0\")\n    # bench_moe(\"q6_k\", layer_num)\n    # bench_moe(\"q5_k_m\", layer_num)\n    # bench_moe(\"q3_k_m\", layer_num)\n    # bench_moe(\"q2_k\", layer_num)\n    # bench_moe(\"iq3_xs\", layer_num)\n    # bench_moe(\"iq2_xxs\", layer_num)\n"
  },
  {
    "path": "kt-kernel/bench/bench_moe_amx.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nDescription  :\nAuthor       : chenht2022\nDate         : 2024-07-25 10:32:05\nVersion      : 1.0.0\nLastEditors  : chenht2022\nLastEditTime : 2024-08-06 10:41:28\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved.\n\"\"\"\nimport os, sys, time, json, subprocess, platform\n\nfrom tqdm import tqdm\n\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\", \"build\"))\nimport torch\nfrom kt_kernel import kt_kernel_ext\nimport numpy as np\n\n# 测试参数设置\nexpert_num = 16\nhidden_size = 7168\nintermediate_size = 2048\nmax_len = 25600\nnum_experts_per_tok = 8\nlayer_num = 2\n\nqlen = 2048\nwarm_up_iter = 1000\ntest_iter = 2000\nphysical_to_logical_map = torch.tensor(data=range(expert_num), device=\"cpu\", dtype=torch.int64).contiguous()\n\n# 将 CPUInfer 参数设为变量\n# CPUINFER_PARAM = 257\n# CPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)\n\nworker_config = kt_kernel_ext.WorkerPoolConfig()\nworker_config.subpool_count = 2\nworker_config.subpool_numa_map = [0, 1]\nworker_config.subpool_thread_count = [80, 80]\nCPUINFER_PARAM = 160\nCPUInfer = kt_kernel_ext.CPUInfer(worker_config)\n\n\ndef get_git_commit():\n    \"\"\"\n    获取当前 git 提交记录（commit hash 和提交信息），\n    并检查是否存在未提交的更改（dirty）\n    \"\"\"\n    result = {}\n    try:\n        commit = subprocess.check_output([\"git\", \"rev-parse\", \"HEAD\"]).decode(\"utf-8\").strip()\n        commit_msg = subprocess.check_output([\"git\", \"log\", \"-1\", \"--pretty=%B\"]).decode(\"utf-8\").strip()\n        result[\"commit\"] = commit\n        result[\"commit_message\"] = commit_msg\n\n        # 检查是否存在未提交的更改\n        dirty_output = subprocess.check_output([\"git\", \"status\", \"--porcelain\"]).decode(\"utf-8\").strip()\n        if dirty_output:\n            result[\"dirty\"] = True\n            result[\"dirty_files\"] = dirty_output.splitlines()\n        else:\n            result[\"dirty\"] = False\n    except Exception as e:\n        result[\"commit\"] = None\n  
      result[\"commit_message\"] = None\n        result[\"dirty\"] = None\n        result[\"error\"] = str(e)\n    return result\n\n\ndef get_system_info():\n    \"\"\"\n    获取系统信息，包括系统名称、CPU 型号、内存大小（GB）、CPU 核数及 socket 数量\n    \"\"\"\n    info = {}\n    # 系统名称及主机名\n    uname = platform.uname()\n    info[\"system_name\"] = uname.system  # 如 Linux, Windows 等\n    info[\"node_name\"] = uname.node  # 主机名称\n\n    # 获取 CPU 型号（仅 Linux 支持）\n    cpu_model = None\n    if os.path.exists(\"/proc/cpuinfo\"):\n        try:\n            with open(\"/proc/cpuinfo\", \"r\") as f:\n                for line in f:\n                    if \"model name\" in line:\n                        cpu_model = line.split(\":\", 1)[1].strip()\n                        break\n        except Exception as e:\n            cpu_model = f\"Error: {e}\"\n    info[\"cpu_model\"] = cpu_model\n\n    # 获取内存大小（单位：GB），仅 Linux 支持\n    mem_total_gb = None\n    if os.path.exists(\"/proc/meminfo\"):\n        try:\n            with open(\"/proc/meminfo\", \"r\") as f:\n                for line in f:\n                    if \"MemTotal\" in line:\n                        mem_kb = float(line.split(\":\", 1)[1].split()[0])\n                        mem_total_gb = round(mem_kb / (1024 * 1024), 2)\n                        break\n        except Exception as e:\n            mem_total_gb = f\"Error: {e}\"\n    info[\"memory_size_GB\"] = mem_total_gb\n\n    # 获取 CPU 核数（逻辑核数）\n    info[\"cpu_core_count\"] = os.cpu_count()\n\n    # 解析 /proc/cpuinfo 获取 socket 数量\n    sockets = set()\n    if os.path.exists(\"/proc/cpuinfo\"):\n        try:\n            with open(\"/proc/cpuinfo\", \"r\") as f:\n                for line in f:\n                    if \"physical id\" in line:\n                        sockets.add(line.split(\":\", 1)[1].strip())\n        except Exception as e:\n            sockets = set()\n    # 如果没有解析到 socket 信息，则默认至少有 1 个 socket\n    info[\"cpu_socket_count\"] = len(sockets) if len(sockets) > 0 else 1\n\n    return 
info\n\n\nscript_path = os.path.abspath(__file__)\nscript_dir = os.path.dirname(script_path)\nscript_name = os.path.splitext(os.path.basename(script_path))[0]\njson_path = os.path.join(script_dir, script_name + \".jsonl\")\n\n\ndef record_results(result, filename=json_path):\n    \"\"\"\n    将结果以 JSON 格式追加到文件中\n    \"\"\"\n    with open(filename, \"a\") as f:\n        f.write(json.dumps(result) + \"\\n\")\n\n\ndef bench_moe(quant_mode: str):\n    with torch.inference_mode():\n        if quant_mode == \"bf16\":\n            bytes_per_elem = 2.0\n        elif quant_mode == \"int8\":\n            bytes_per_elem = 1.0\n        elif quant_mode == \"int4\":\n            bytes_per_elem = 0.5\n        else:\n            raise ValueError(\"不支持的量化模式\")\n\n        moes = []\n        gate_projs = []\n        up_projs = []\n        down_projs = []\n        for layer_index in range(layer_num):\n            gate_proj = (\n                torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device=\"cuda\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            up_proj = (\n                torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device=\"cuda\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            down_proj = (\n                torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32, device=\"cuda\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)\n            config.max_len = max_len\n            config.gate_proj = gate_proj.data_ptr()\n            config.up_proj = up_proj.data_ptr()\n            config.down_proj = down_proj.data_ptr()\n            config.pool = CPUInfer.backend_\n            if quant_mode == \"bf16\":\n                moe = kt_kernel_ext.moe.AMXBF16_MOE(config)\n   
         elif quant_mode == \"int8\":\n                moe = kt_kernel_ext.moe.AMXInt8_MOE(config)\n            elif quant_mode == \"int4\":\n                moe = kt_kernel_ext.moe.AMXInt4_MOE(config)\n            CPUInfer.submit(moe.load_weights_task())\n            CPUInfer.sync()\n            gate_projs.append(gate_proj)\n            up_projs.append(up_proj)\n            down_projs.append(down_proj)\n            moes.append(moe)\n        gen_iter = 3000\n        expert_ids = (\n            torch.rand(gen_iter * qlen, expert_num, device=\"cpu\")\n            .argsort(dim=-1)[:, :num_experts_per_tok]\n            .reshape(gen_iter, qlen * num_experts_per_tok)\n            .to(\"cpu\")\n            .contiguous()\n        )\n        weights = (\n            torch.rand((gen_iter, qlen, num_experts_per_tok), dtype=torch.float32, device=\"cpu\").to(\"cpu\").contiguous()\n        )\n        input_tensor = (\n            torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device=\"cuda\").to(\"cpu\").contiguous()\n        )\n        output_tensor = (\n            torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device=\"cuda\").to(\"cpu\").contiguous()\n        )\n        bsz_tensor = torch.tensor([qlen], device=\"cpu\")\n\n        # 预热迭代\n        for i in tqdm(range(warm_up_iter), desc=\"Warm-up\"):\n            # start_it = time.time_ns()\n            CPUInfer.submit(\n                moes[i % layer_num].forward_task(\n                    bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids[i % gen_iter].data_ptr(),\n                    weights[i % gen_iter].data_ptr(),\n                    input_tensor[i % layer_num].data_ptr(),\n                    output_tensor[i % layer_num].data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n            # end_it = time.time_ns()\n            # print('python Time(ns): ', end_it - start_it)\n\n        # 
Benchmark iterations\n        start = time.perf_counter()\n        for i in tqdm(range(test_iter), desc=\"Testing\"):\n            # print(f'test iteration {i}')\n            # start_it = time.time_ns()\n            CPUInfer.submit(\n                moes[i % layer_num].forward_task(\n                    bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids[i % gen_iter].data_ptr(),\n                    weights[i % gen_iter].data_ptr(),\n                    input_tensor[i % layer_num].data_ptr(),\n                    output_tensor[i % layer_num].data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n            # end_it = time.time_ns()\n            # print('python Time(ns): ', end_it - start_it)\n        end = time.perf_counter()\n        total_time = end - start\n\n        # Compute performance metrics\n        time_per_iter_us = total_time / test_iter * 1e6\n        bandwidth = (\n            hidden_size\n            * intermediate_size\n            * 3\n            * num_experts_per_tok\n            # With num_experts_per_tok, this is the expected number of distinct experts hit by qlen tokens;\n            # it matches the former hard-coded 1 / 8 * 256 * (1 - (31 / 32) ** qlen) for 256 experts, 8 per token\n            * (expert_num / num_experts_per_tok * (1 - (1 - num_experts_per_tok / expert_num) ** qlen))\n            * bytes_per_elem\n            * test_iter\n            / total_time\n            / 1e9\n        )  # Unit: GB/s\n        flops = (\n            hidden_size * intermediate_size * qlen * 3 * num_experts_per_tok * 2 * test_iter / total_time / 1e12\n        )  # Unit: TFLOPS\n\n        print(\"Quant mode: \", quant_mode)\n        print(\"Time(s): \", total_time)\n        print(\"Iteration: \", test_iter)\n        print(\"Time(us) per iteration: \", time_per_iter_us)\n        print(\"Bandwidth: \", bandwidth, \"GB/s\")\n        print(\"Flops: \", flops, \"TFLOPS\")\n        print(\"\")\n\n        # Assemble the result record, including the test parameters\n        result = {\n            \"quant_mode\": quant_mode,\n            \"total_time_seconds\": total_time,\n            \"iterations\": test_iter,\n            \"time_per_iteration_us\": time_per_iter_us,\n            \"bandwidth_GBs\": bandwidth,\n           
 \"flops_TFLOPS\": flops,\n            \"timestamp\": time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime()),\n            \"test_parameters\": {\n                \"expert_num\": expert_num,\n                \"hidden_size\": hidden_size,\n                \"intermediate_size\": intermediate_size,\n                \"max_len\": max_len,\n                \"num_experts_per_tok\": num_experts_per_tok,\n                \"layer_num\": layer_num,\n                \"qlen\": qlen,\n                \"warm_up_iter\": warm_up_iter,\n                \"test_iter\": test_iter,\n                \"CPUInfer_parameter\": CPUINFER_PARAM,\n            },\n        }\n        # 添加 git 提交记录信息\n        result.update(get_git_commit())\n        # 添加系统信息（包括 CPU 核数和 socket 数量）\n        result.update(get_system_info())\n        # 将结果以 JSON 形式追加到文件中\n        record_results(result)\n\n\nif __name__ == \"__main__\":\n    # 选择需要测试的量化模式\n    # bench_moe(\"bf16\")\n    bench_moe(\"int8\")\n    # bench_moe(\"int4\")\n"
  },
  {
    "path": "kt-kernel/bench/bench_moe_amx_k.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nDescription  :\nAuthor       : chenht2022\nDate         : 2024-07-25 10:32:05\nVersion      : 1.0.0\nLastEditors  : chenht2022\nLastEditTime : 2024-08-06 10:41:28\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved.\n\"\"\"\nimport os, sys, time, json, subprocess, platform\n\nfrom tqdm import tqdm\n\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\", \"build\"))\nfrom kt_kernel import kt_kernel_ext\nimport torch\nimport numpy as np\n\n# 测试参数设置\nexpert_num = 256\nhidden_size = 7168\nintermediate_size = 2048\nmax_len = 25600\nnum_experts_per_tok = 8\nlayer_num = 4\nqlen = 1024\n# qlen = 1\nwarm_up_iter = 1000\ntest_iter = 5000\nk_group_size = 128\n\nphysical_to_logical_map = torch.tensor(data=range(expert_num), device=\"cpu\", dtype=torch.int64).contiguous()\n# 将 CPUInfer 参数设为变量\n# CPUINFER_PARAM = 257\n# CPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)\n\nworker_config = kt_kernel_ext.WorkerPoolConfig()\nworker_config.subpool_count = 2\nworker_config.subpool_numa_map = [0, 1]\nworker_config.subpool_thread_count = [40, 40]\nCPUINFER_PARAM = 80\nCPUInfer = kt_kernel_ext.CPUInfer(worker_config)\n\n\ndef get_git_commit():\n    \"\"\"\n    获取当前 git 提交记录（commit hash 和提交信息），\n    并检查是否存在未提交的更改（dirty）\n    \"\"\"\n    result = {}\n    try:\n        commit = subprocess.check_output([\"git\", \"rev-parse\", \"HEAD\"]).decode(\"utf-8\").strip()\n        commit_msg = subprocess.check_output([\"git\", \"log\", \"-1\", \"--pretty=%B\"]).decode(\"utf-8\").strip()\n        result[\"commit\"] = commit\n        result[\"commit_message\"] = commit_msg\n\n        # 检查是否存在未提交的更改\n        dirty_output = subprocess.check_output([\"git\", \"status\", \"--porcelain\"]).decode(\"utf-8\").strip()\n        if dirty_output:\n            result[\"dirty\"] = True\n            result[\"dirty_files\"] = dirty_output.splitlines()\n        else:\n            result[\"dirty\"] = False\n    except Exception as e:\n       
 result[\"commit\"] = None\n        result[\"commit_message\"] = None\n        result[\"dirty\"] = None\n        result[\"error\"] = str(e)\n    return result\n\n\ndef get_system_info():\n    \"\"\"\n    获取系统信息，包括系统名称、CPU 型号、内存大小（GB）、CPU 核数及 socket 数量\n    \"\"\"\n    info = {}\n    # 系统名称及主机名\n    uname = platform.uname()\n    info[\"system_name\"] = uname.system  # 如 Linux, Windows 等\n    info[\"node_name\"] = uname.node  # 主机名称\n\n    # 获取 CPU 型号（仅 Linux 支持）\n    cpu_model = None\n    if os.path.exists(\"/proc/cpuinfo\"):\n        try:\n            with open(\"/proc/cpuinfo\", \"r\") as f:\n                for line in f:\n                    if \"model name\" in line:\n                        cpu_model = line.split(\":\", 1)[1].strip()\n                        break\n        except Exception as e:\n            cpu_model = f\"Error: {e}\"\n    info[\"cpu_model\"] = cpu_model\n\n    # 获取内存大小（单位：GB），仅 Linux 支持\n    mem_total_gb = None\n    if os.path.exists(\"/proc/meminfo\"):\n        try:\n            with open(\"/proc/meminfo\", \"r\") as f:\n                for line in f:\n                    if \"MemTotal\" in line:\n                        mem_kb = float(line.split(\":\", 1)[1].split()[0])\n                        mem_total_gb = round(mem_kb / (1024 * 1024), 2)\n                        break\n        except Exception as e:\n            mem_total_gb = f\"Error: {e}\"\n    info[\"memory_size_GB\"] = mem_total_gb\n\n    # 获取 CPU 核数（逻辑核数）\n    info[\"cpu_core_count\"] = os.cpu_count()\n\n    # 解析 /proc/cpuinfo 获取 socket 数量\n    sockets = set()\n    if os.path.exists(\"/proc/cpuinfo\"):\n        try:\n            with open(\"/proc/cpuinfo\", \"r\") as f:\n                for line in f:\n                    if \"physical id\" in line:\n                        sockets.add(line.split(\":\", 1)[1].strip())\n        except Exception as e:\n            sockets = set()\n    # 如果没有解析到 socket 信息，则默认至少有 1 个 socket\n    info[\"cpu_socket_count\"] = len(sockets) if 
len(sockets) > 0 else 1\n\n    return info\n\n\nscript_path = os.path.abspath(__file__)\nscript_dir = os.path.dirname(script_path)\nscript_name = os.path.splitext(os.path.basename(script_path))[0]\njson_path = os.path.join(script_dir, script_name + \".jsonl\")\n\n\ndef record_results(result, filename=json_path):\n    \"\"\"\n    Append the result to the file as one JSON line.\n    \"\"\"\n    with open(filename, \"a\") as f:\n        f.write(json.dumps(result) + \"\\n\")\n\n\ndef bench_moe(quant_mode: str):\n    with torch.inference_mode():\n        if quant_mode == \"bf16\":\n            bytes_per_elem = 2.0\n        elif quant_mode == \"int8\":\n            bytes_per_elem = 1.0\n        elif quant_mode == \"int4\":\n            bytes_per_elem = 0.5\n        elif quant_mode == \"int4_1k\":\n            bytes_per_elem = 0.5\n        else:\n            raise ValueError(f\"Unsupported quantization mode: {quant_mode}\")\n\n        moes = []\n        gate_projs = []\n        up_projs = []\n        down_projs = []\n        for layer_index in range(layer_num):\n            gate_proj = (\n                torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device=\"cuda\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            up_proj = (\n                torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device=\"cuda\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            down_proj = (\n                torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32, device=\"cuda\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)\n            config.max_len = max_len\n            config.gate_proj = gate_proj.data_ptr()\n            config.up_proj = up_proj.data_ptr()\n            config.down_proj = down_proj.data_ptr()\n            config.pool = 
CPUInfer.backend_\n            if quant_mode == \"bf16\":\n                moe = kt_kernel_ext.moe.AMXBF16_MOE(config)\n            elif quant_mode == \"int8\":\n                moe = kt_kernel_ext.moe.AMXInt8_MOE(config)\n            elif quant_mode == \"int4\":\n                moe = kt_kernel_ext.moe.AMXInt4_MOE(config)\n            elif quant_mode == \"int4_1k\":\n                config.quant_config.bits = 4\n                config.quant_config.group_size = k_group_size\n                config.quant_config.zero_point = True\n                config.gate_scale = 0\n                moe = kt_kernel_ext.moe.AMXInt4_1KGroup_MOE(config)\n            CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n            CPUInfer.sync()\n            # Keep the source tensors referenced; the config only holds their raw pointers\n            gate_projs.append(gate_proj)\n            up_projs.append(up_proj)\n            down_projs.append(down_proj)\n            moes.append(moe)\n        gen_iter = 3000\n        expert_ids = (\n            torch.rand(gen_iter * qlen, expert_num, device=\"cpu\")\n            .argsort(dim=-1)[:, :num_experts_per_tok]\n            .reshape(gen_iter, qlen * num_experts_per_tok)\n            .to(\"cpu\")\n            .contiguous()\n        )\n        weights = (\n            torch.rand((gen_iter, qlen, num_experts_per_tok), dtype=torch.float32, device=\"cpu\").to(\"cpu\").contiguous()\n        )\n        input_tensor = (\n            torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device=\"cuda\").to(\"cpu\").contiguous()\n        )\n        output_tensor = (\n            torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device=\"cuda\").to(\"cpu\").contiguous()\n        )\n        bsz_tensor = torch.tensor([qlen], device=\"cpu\")\n\n        # Warm-up iterations\n        # for i in range(warm_up_iter):\n        for i in tqdm(range(warm_up_iter), desc=\"Warm-up\"):\n            # start_it = time.time_ns()\n            CPUInfer.submit(\n                moes[i % layer_num].forward_task(\n              
      bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids[i % gen_iter].data_ptr(),\n                    weights[i % gen_iter].data_ptr(),\n                    input_tensor[i % layer_num].data_ptr(),\n                    output_tensor[i % layer_num].data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n            # end_it = time.time_ns()\n            # print('python Time(ns): ', end_it - start_it)\n\n        # 测试迭代\n        start = time.perf_counter()\n        # for i in range(test_iter):\n        for i in tqdm(range(test_iter), desc=\"Testing\"):\n            # print(f'test iteration {i}')\n            # start_it = time.time_ns()\n            CPUInfer.submit(\n                moes[i % layer_num].forward_task(\n                    bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids[i % gen_iter].data_ptr(),\n                    weights[i % gen_iter].data_ptr(),\n                    input_tensor[i % layer_num].data_ptr(),\n                    output_tensor[i % layer_num].data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n            # end_it = time.time_ns()\n            # print('python Time(ns): ', end_it - start_it)\n        end = time.perf_counter()\n        total_time = end - start\n\n        # 计算性能指标\n        time_per_iter_us = total_time / test_iter * 1e6\n        bandwidth = (\n            hidden_size\n            * intermediate_size\n            * 3\n            * num_experts_per_tok\n            * (1 / 8 * 256 * (1 - (31 / 32) ** qlen))\n            * bytes_per_elem\n            * test_iter\n            / total_time\n            / 1e9\n        )  # 单位：GB/s\n        flops = (\n            hidden_size * intermediate_size * qlen * 3 * num_experts_per_tok * 2 * test_iter / total_time / 1e12\n        )  # 单位：TFLOPS\n\n        print(\"Quant mode: \", quant_mode)\n        
print(\"Time(s): \", total_time)\n        print(\"Iteration: \", test_iter)\n        print(\"Time(us) per iteration: \", time_per_iter_us)\n        print(\"Bandwidth: \", bandwidth, \"GB/s\")\n        print(\"Flops: \", flops, \"TFLOPS\")\n        print(\"\")\n\n        # Assemble the result record, including the test parameters\n        result = {\n            \"quant_mode\": quant_mode,\n            \"total_time_seconds\": total_time,\n            \"iterations\": test_iter,\n            \"time_per_iteration_us\": time_per_iter_us,\n            \"bandwidth_GBs\": bandwidth,\n            \"flops_TFLOPS\": flops,\n            \"timestamp\": time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime()),\n            \"test_parameters\": {\n                \"expert_num\": expert_num,\n                \"hidden_size\": hidden_size,\n                \"intermediate_size\": intermediate_size,\n                \"max_len\": max_len,\n                \"num_experts_per_tok\": num_experts_per_tok,\n                \"layer_num\": layer_num,\n                \"qlen\": qlen,\n                \"warm_up_iter\": warm_up_iter,\n                \"test_iter\": test_iter,\n                \"CPUInfer_parameter\": CPUINFER_PARAM,\n                \"k_group_size\": k_group_size,\n            },\n        }\n        # Add git commit information\n        result.update(get_git_commit())\n        # Add system information (including CPU core count and socket count)\n        result.update(get_system_info())\n        # Append the result to the file as one JSON line\n        record_results(result)\n\n\nif __name__ == \"__main__\":\n    # Select the quantization mode to benchmark\n    # bench_moe(\"bf16\")\n    # bench_moe(\"int8\")\n    # bench_moe(\"int4\")\n    bench_moe(\"int4_1k\")\n"
  },
  {
    "path": "kt-kernel/bench/bench_moe_kernel.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nDescription  :\nAuthor       : chenht2022\nDate         : 2024-07-25 10:32:05\nVersion      : 1.0.0\nLastEditors  : chenht2022\nLastEditTime : 2024-08-06 10:41:28\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved.\n\"\"\"\nimport os, sys, time, json, subprocess, platform\n\nos.environ[\"BLAS_NUM_THREADS\"] = \"1\"\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\", \"build\"))\nimport torch\nfrom kt_kernel import kt_kernel_ext\nimport numpy as np\nfrom tqdm import tqdm\n\n\n# 测试参数设置\nexpert_num = 256\nhidden_size = 7168\nintermediate_size = 2048\nmax_len = 51200\nnum_experts_per_tok = 8\nlayer_num = 1\nm_block = 320\nn_block_up_gate = 32\nn_block_down = 64\nn_block_up_gate_prefi = 32\nn_block_down_prefi = 64\nqlen = 2048\nwarm_up_iter = 1000\ntest_iter = 1000\n\n# 将 CPUInfer 参数设为变量\nCPUINFER_PARAM = 160\nCPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)\nphysical_to_logical_map = torch.tensor(data=range(expert_num), device=\"cpu\", dtype=torch.int64).contiguous()\n\n\n# worker_config = kt_kernel_ext.WorkerPoolConfig()\n# worker_config.subpool_count = 4\n# worker_config.subpool_numa_map= [0,1,2,3]\n# worker_config.subpool_thread_count = [36,36,36,36]\n# worker_config.subpool_thread_count = [39,39,39,39]\n# CPUINFER_PARAM = 156\n# CPUInfer = kt_kernel_ext.CPUInfer(worker_config)\n\n\ndef get_git_commit():\n    \"\"\"\n    获取当前 git 提交记录（commit hash 和提交信息），\n    并检查是否存在未提交的更改（dirty）\n    \"\"\"\n    result = {}\n    try:\n        commit = subprocess.check_output([\"git\", \"rev-parse\", \"HEAD\"]).decode(\"utf-8\").strip()\n        commit_msg = subprocess.check_output([\"git\", \"log\", \"-1\", \"--pretty=%B\"]).decode(\"utf-8\").strip()\n        result[\"commit\"] = commit\n        result[\"commit_message\"] = commit_msg\n\n        # 检查是否存在未提交的更改\n        dirty_output = subprocess.check_output([\"git\", \"status\", \"--porcelain\"]).decode(\"utf-8\").strip()\n        if dirty_output:\n 
           result[\"dirty\"] = True\n            result[\"dirty_files\"] = dirty_output.splitlines()\n        else:\n            result[\"dirty\"] = False\n    except Exception as e:\n        result[\"commit\"] = None\n        result[\"commit_message\"] = None\n        result[\"dirty\"] = None\n        result[\"error\"] = str(e)\n    return result\n\n\ndef get_system_info():\n    \"\"\"\n    获取系统信息，包括系统名称、CPU 型号、内存大小（GB）、CPU 核数及 socket 数量\n    \"\"\"\n    info = {}\n    # 系统名称及主机名\n    uname = platform.uname()\n    info[\"system_name\"] = uname.system  # 如 Linux, Windows 等\n    info[\"node_name\"] = uname.node  # 主机名称\n\n    # 获取 CPU 型号（仅 Linux 支持）\n    cpu_model = None\n    if os.path.exists(\"/proc/cpuinfo\"):\n        try:\n            with open(\"/proc/cpuinfo\", \"r\") as f:\n                for line in f:\n                    if \"model name\" in line:\n                        cpu_model = line.split(\":\", 1)[1].strip()\n                        break\n        except Exception as e:\n            cpu_model = f\"Error: {e}\"\n    info[\"cpu_model\"] = cpu_model\n\n    # 获取内存大小（单位：GB），仅 Linux 支持\n    mem_total_gb = None\n    if os.path.exists(\"/proc/meminfo\"):\n        try:\n            with open(\"/proc/meminfo\", \"r\") as f:\n                for line in f:\n                    if \"MemTotal\" in line:\n                        mem_kb = float(line.split(\":\", 1)[1].split()[0])\n                        mem_total_gb = round(mem_kb / (1024 * 1024), 2)\n                        break\n        except Exception as e:\n            mem_total_gb = f\"Error: {e}\"\n    info[\"memory_size_GB\"] = mem_total_gb\n\n    # 获取 CPU 核数（逻辑核数）\n    info[\"cpu_core_count\"] = os.cpu_count()\n\n    # 解析 /proc/cpuinfo 获取 socket 数量\n    sockets = set()\n    if os.path.exists(\"/proc/cpuinfo\"):\n        try:\n            with open(\"/proc/cpuinfo\", \"r\") as f:\n                for line in f:\n                    if \"physical id\" in line:\n                        
sockets.add(line.split(\":\", 1)[1].strip())\n        except Exception as e:\n            sockets = set()\n    # If no socket information could be parsed, assume at least 1 socket\n    info[\"cpu_socket_count\"] = len(sockets) if len(sockets) > 0 else 1\n\n    return info\n\n\nscript_path = os.path.abspath(__file__)\nscript_dir = os.path.dirname(script_path)\nscript_name = os.path.splitext(os.path.basename(script_path))[0]\njson_path = os.path.join(script_dir, \"bench_results.jsonl\")\n\n\ndef record_results(result, filename=json_path):\n    \"\"\"\n    Append the result to the file as one JSON line.\n    \"\"\"\n    with open(filename, \"a\") as f:\n        f.write(json.dumps(result) + \"\\n\")\n\n\ndef bench_moe(quant_mode: str):\n    with torch.inference_mode():\n        if quant_mode == \"int8\":\n            bytes_per_elem = 1.0\n        elif quant_mode == \"int4\":\n            bytes_per_elem = 0.5\n        else:\n            raise ValueError(f\"Unsupported quantization mode: {quant_mode}\")\n\n        moes = []\n        gate_projs = []\n        up_projs = []\n        down_projs = []\n        for layer_index in range(layer_num):\n            gate_proj = (\n                torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device=\"cuda\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            up_proj = (\n                torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device=\"cuda\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            down_proj = (\n                torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32, device=\"cuda\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)\n            config.max_len = max_len\n            config.gate_proj = gate_proj.data_ptr()\n            config.up_proj = up_proj.data_ptr()\n            config.down_proj 
= down_proj.data_ptr()\n            config.pool = CPUInfer.backend_\n            if quant_mode == \"int8\":\n                d = kt_kernel_ext.moe.tiling.get_int8()\n                nbug_prefi = n_block_up_gate_prefi\n                nbd_prefi = n_block_down_prefi\n                kb = d[\"k_block\"]\n                nb = d[\"n_block\"]\n                mb = m_block\n                nbug = n_block_up_gate\n                nbd = n_block_down\n                print(\n                    f\"Int8 Tiling: nbug {nbug}, nbd {nbd}, nb {nb}, mb {mb}, kb {kb}, nbug_prefi {nbug_prefi}, nbd_prefi {nbd_prefi}\"\n                )\n                kt_kernel_ext.moe.tiling.set_int8(nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi)\n                moe = kt_kernel_ext.moe.Int8_KERNEL_MOE(config)\n            elif quant_mode == \"int4\":\n                moe = kt_kernel_ext.moe.Int4_KERNEL_MOE(config)\n            else:\n                raise ValueError(f\"Unsupported quantization mode: {quant_mode}\")\n            CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n            CPUInfer.sync()\n            gate_projs.append(gate_proj)\n            up_projs.append(up_proj)\n            down_projs.append(down_proj)\n            moes.append(moe)\n\n        expert_ids = (\n            torch.rand(test_iter * qlen, expert_num, device=\"cuda\")\n            .argsort(dim=-1)[:, :num_experts_per_tok]\n            .reshape(test_iter, qlen * num_experts_per_tok)\n            .to(\"cpu\")\n            .contiguous()\n        )\n        weights = (\n            torch.rand((test_iter, qlen, num_experts_per_tok), dtype=torch.float32, device=\"cuda\")\n            .to(\"cpu\")\n            .contiguous()\n        )\n        input_tensor = (\n            torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device=\"cuda\").to(\"cpu\").contiguous()\n        )\n        output_tensor = (\n            torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, 
device=\"cuda\").to(\"cpu\").contiguous()\n        )\n        bsz_tensor = torch.tensor([qlen], device=\"cuda\").to(\"cpu\").contiguous()\n\n        # Warm-up iterations\n        for i in tqdm(range(warm_up_iter), desc=\"Warm-up\"):\n            # print(f'warmup iteration {i}')\n            # start_it = time.time_ns()\n            CPUInfer.submit(\n                moes[i % layer_num].forward_task(\n                    bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids[i].data_ptr(),\n                    weights[i].data_ptr(),\n                    input_tensor[i % layer_num].data_ptr(),\n                    output_tensor[i % layer_num].data_ptr(),\n                    # False,\n                )\n            )\n            CPUInfer.sync()\n            # end_it = time.time_ns()\n            # print('python Time(ns): ', end_it - start_it)\n\n        # Timed test iterations\n        start = time.perf_counter()\n        for i in tqdm(range(test_iter), desc=\"Testing\"):\n            # print(f'test iteration {i}')\n            # start_it = time.time_ns()\n            CPUInfer.submit(\n                moes[i % layer_num].forward_task(\n                    bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids[i].data_ptr(),\n                    weights[i].data_ptr(),\n                    input_tensor[i % layer_num].data_ptr(),\n                    output_tensor[i % layer_num].data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n            # end_it = time.time_ns()\n            # print('python Time(ns): ', end_it - start_it)\n        end = time.perf_counter()\n        total_time = end - start\n\n        # Compute performance metrics\n        time_per_iter_us = total_time / test_iter * 1e6\n        bandwidth = (\n            hidden_size\n            * intermediate_size\n            * 3\n            * num_experts_per_tok\n            # * (1 / 8 * 256 * (1 - (31 / 32) ** qlen))\n           
 * qlen\n            * bytes_per_elem\n            * test_iter\n            / total_time\n            / 1e9\n        )  # Unit: GB/s\n        flops = (\n            hidden_size * intermediate_size * qlen * 3 * num_experts_per_tok * 2 * test_iter / total_time / 1e12\n        )  # Unit: TFLOPS\n\n        print(\"Quant mode: \", quant_mode)\n        print(\"Time(s): \", total_time)\n        print(\"Iteration: \", test_iter)\n        print(\"Time(us) per iteration: \", time_per_iter_us)\n        print(\"Bandwidth: \", bandwidth, \"GB/s\")\n        print(\"Flops: \", flops, \"TFLOPS\")\n        print(\"\")\n\n        # Assemble the result record, including the test parameters\n        result = {\n            \"test_name\": os.path.basename(__file__),\n            \"quant_mode\": quant_mode,\n            \"total_time_seconds\": total_time,\n            \"iterations\": test_iter,\n            \"time_per_iteration_us\": time_per_iter_us,\n            \"bandwidth_GBs\": bandwidth,\n            \"flops_TFLOPS\": flops,\n            \"timestamp\": time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime()),\n            \"test_parameters\": {\n                \"expert_num\": expert_num,\n                \"hidden_size\": hidden_size,\n                \"intermediate_size\": intermediate_size,\n                \"max_len\": max_len,\n                \"num_experts_per_tok\": num_experts_per_tok,\n                \"layer_num\": layer_num,\n                \"qlen\": qlen,\n                \"warm_up_iter\": warm_up_iter,\n                \"test_iter\": test_iter,\n                \"CPUInfer_parameter\": CPUINFER_PARAM,\n            },\n        }\n        # Add git commit information\n        result.update(get_git_commit())\n        # Add system information (including CPU core count and socket count)\n        result.update(get_system_info())\n        # Append the result to the file as JSON\n        record_results(result)\n\n\nif __name__ == \"__main__\":\n    # Select the quantization mode to benchmark\n    bench_moe(\"int8\")\n    # bench_moe(\"int4\")\n"
  },
  {
    "path": "kt-kernel/bench/bench_moe_kernel_tiling.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nBench MOE kernel with runtime tiling params (N_BLOCK_UP_GATE, N_BLOCK_DOWN, N_BLOCK, M_BLOCK, K_BLOCK)\n- Demonstrates how to get/set tiling params from Python via kt_kernel_ext.moe.tiling\n- Runs a small benchmark similar to bench_moe_kernel.py\n\nUsage examples:\n  # 1) Just run with defaults (int8)\n  python bench_moe_kernel_tiling.py --quant int8\n\n  # 2) Override tiling params for INT8\n  python bench_moe_kernel_tiling.py --quant int8 \\\n    --n_block_up_gate 32 --n_block_down 64 --n_block 64 --m_block 320 --k_block 7168\n\n  # 3) Set both INT8 and INT4 tiling params (if INT4 kernel is available on your platform)\n  python bench_moe_kernel_tiling.py --quant int4 --set_all \\\n    --n_block_up_gate 256 --n_block_down 1024 --n_block 64 --m_block 320 --k_block 7168\n\"\"\"\nimport os\nimport sys\nimport time\nimport argparse\n\nos.environ.setdefault(\"BLAS_NUM_THREADS\", \"1\")\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\", \"build\"))\n\nimport torch  # noqa: E402\nfrom kt_kernel import kt_kernel_ext as ce  # noqa: E402\nfrom tqdm import tqdm  # noqa: E402\n\n\ndef maybe_get_class(module, name):\n    return getattr(module, name) if hasattr(module, name) else None\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--quant\", choices=[\"int8\", \"int4\"], default=\"int8\")\n    parser.add_argument(\"--expert_num\", type=int, default=256)\n    parser.add_argument(\"--hidden_size\", type=int, default=7168)\n    parser.add_argument(\"--intermediate_size\", type=int, default=2048)\n    parser.add_argument(\"--num_experts_per_tok\", type=int, default=8)\n    parser.add_argument(\"--max_len\", type=int, default=25600)\n    parser.add_argument(\"--layer_num\", type=int, default=1)\n    parser.add_argument(\"--qlen\", type=int, default=1024)\n    parser.add_argument(\"--warm_up_iter\", type=int, default=200)\n    parser.add_argument(\"--test_iter\", 
type=int, default=500)\n    parser.add_argument(\"--threads\", type=int, default=160, help=\"CPUInfer initialization param\")\n\n    # Tiling params\n    parser.add_argument(\"--set_all\", action=\"store_true\", help=\"Apply tiling to both INT8 and INT4 kernels\")\n    parser.add_argument(\"--n_block_up_gate\", type=int, default=None)\n    parser.add_argument(\"--n_block_down\", type=int, default=None)\n    parser.add_argument(\"--n_block\", type=int, default=None)\n    parser.add_argument(\"--m_block\", type=int, default=None)\n    parser.add_argument(\"--k_block\", type=int, default=None)\n    parser.add_argument(\"--n_block_up_gate_prefi\", type=int, default=None)\n    parser.add_argument(\"--n_block_down_prefi\", type=int, default=None)\n\n    args = parser.parse_args()\n\n    # Show current tiling defaults\n    if args.quant == \"int8\":\n        print(\"[tiling] default int8:\", ce.moe.tiling.get_int8())\n    if hasattr(ce.moe.tiling, \"get_int4\") and args.quant == \"int4\":\n        print(\"[tiling] default int4:\", ce.moe.tiling.get_int4())\n\n    # Apply overrides if provided\n    if any(v is not None for v in [args.n_block_up_gate, args.n_block_down, args.n_block, args.m_block, args.k_block]):\n        # Fill missing values with current defaults to avoid overwriting unrelated params\n        def fill_defaults(getter):\n            cur = getter()\n            return (\n                args.n_block_up_gate if args.n_block_up_gate is not None else int(cur[\"n_block_up_gate\"]),\n                args.n_block_down if args.n_block_down is not None else int(cur[\"n_block_down\"]),\n                args.n_block if args.n_block is not None else int(cur[\"n_block\"]),\n                args.m_block if args.m_block is not None else int(cur[\"m_block\"]),\n                args.k_block if args.k_block is not None else int(cur[\"k_block\"]),\n                (\n                    args.n_block_up_gate_prefi\n                    if args.n_block_up_gate_prefi is not 
None\n                    else int(cur[\"n_block_up_gate_prefi\"])\n                ),\n                args.n_block_down_prefi if args.n_block_down_prefi is not None else int(cur[\"n_block_down_prefi\"]),\n            )\n\n        if args.set_all and hasattr(ce.moe.tiling, \"set_all\"):\n            nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi = fill_defaults(ce.moe.tiling.get_int8)\n            ce.moe.tiling.set_all(nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi)\n            print(\"[tiling] set_all ->\", ce.moe.tiling.get_int8())\n            if hasattr(ce.moe.tiling, \"get_int4\"):\n                print(\"[tiling] set_all -> int4:\", ce.moe.tiling.get_int4())\n        else:\n            if args.quant == \"int8\":\n                nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi = fill_defaults(ce.moe.tiling.get_int8)\n                ce.moe.tiling.set_int8(nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi)\n                print(\"[tiling] set_int8 ->\", ce.moe.tiling.get_int8())\n            elif args.quant == \"int4\" and hasattr(ce.moe.tiling, \"set_int4\"):\n                nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi = fill_defaults(ce.moe.tiling.get_int4)\n                ce.moe.tiling.set_int4(nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi)\n                print(\"[tiling] set_int4 ->\", ce.moe.tiling.get_int4())\n\n    # Warn about divisibility expectations; kernels assume specific blocking\n    # - Some helpers assert n % N_BLOCK == 0, etc. 
Ensure your dims/tiles align.\n    print(\"[note] Ensure your selected tiling parameters are compatible with hidden/intermediate sizes and blocking.\")\n\n    # Initialize CPUInfer\n    CPUInfer = ce.CPUInfer(args.threads)\n\n    # Select MOE kernel\n    moe_cls = None\n    if args.quant == \"int8\":\n        moe_cls = maybe_get_class(ce.moe, \"Int8_KERNEL_MOE\")\n        if moe_cls is None:\n            raise RuntimeError(\"Int8 kernel binding 'Int8_KERNEL_MOE' not found.\")\n        bytes_per_elem = 1.0\n    else:\n        moe_cls = maybe_get_class(ce.moe, \"Int4_KERNEL_MOE\")\n        if moe_cls is None:\n            raise RuntimeError(\"Int4 kernel binding 'Int4_KERNEL_MOE' not available on this platform.\")\n        bytes_per_elem = 0.5\n\n    # Prepare config/weights\n    expert_num = args.expert_num\n    hidden_size = args.hidden_size\n    intermediate_size = args.intermediate_size\n    num_experts_per_tok = args.num_experts_per_tok\n    layer_num = args.layer_num\n    max_len = args.max_len\n\n    physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device=\"cpu\").contiguous()\n\n    moes = []\n    gate_projs, up_projs, down_projs = [], [], []\n\n    for layer_idx in range(layer_num):\n        gate_proj = torch.randn(\n            (expert_num, intermediate_size, hidden_size), dtype=torch.float32, device=\"cpu\"\n        ).contiguous()\n        up_proj = torch.randn(\n            (expert_num, intermediate_size, hidden_size), dtype=torch.float32, device=\"cpu\"\n        ).contiguous()\n        down_proj = torch.randn(\n            (expert_num, hidden_size, intermediate_size), dtype=torch.float32, device=\"cpu\"\n        ).contiguous()\n\n        cfg = ce.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)\n        cfg.max_len = max_len\n        cfg.gate_proj = gate_proj.data_ptr()\n        cfg.up_proj = up_proj.data_ptr()\n        cfg.down_proj = down_proj.data_ptr()\n        cfg.pool = CPUInfer.backend_\n\n     
   moe = moe_cls(cfg)\n        CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n        CPUInfer.sync()\n\n        gate_projs.append(gate_proj)\n        up_projs.append(up_proj)\n        down_projs.append(down_proj)\n        moes.append(moe)\n\n    qlen = args.qlen\n    warm_up_iter = args.warm_up_iter\n    test_iter = args.test_iter\n\n    expert_ids = (\n        torch.rand(test_iter * qlen, expert_num)\n        .argsort(dim=-1)[:, :num_experts_per_tok]\n        .reshape(test_iter, qlen * num_experts_per_tok)\n        .to(\"cpu\")\n        .contiguous()\n    )\n    weights = torch.rand((test_iter, qlen, num_experts_per_tok), dtype=torch.float32).to(\"cpu\").contiguous()\n    input_tensor = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16).to(\"cpu\").contiguous()\n    output_tensor = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16).to(\"cpu\").contiguous()\n    bsz_tensor = torch.tensor([qlen], dtype=torch.int32).to(\"cpu\").contiguous()\n\n    # Warmup\n    for i in tqdm(range(warm_up_iter), desc=\"Warm-up\"):\n        CPUInfer.submit(\n            moes[i % layer_num].forward_task(\n                bsz_tensor.data_ptr(),\n                num_experts_per_tok,\n                expert_ids[i].data_ptr(),\n                weights[i].data_ptr(),\n                input_tensor[i % layer_num].data_ptr(),\n                output_tensor[i % layer_num].data_ptr(),\n            )\n        )\n        CPUInfer.sync()\n\n    # Measure\n    start = time.perf_counter()\n    for i in tqdm(range(test_iter), desc=\"Testing\"):\n        CPUInfer.submit(\n            moes[i % layer_num].forward_task(\n                bsz_tensor.data_ptr(),\n                num_experts_per_tok,\n                expert_ids[i].data_ptr(),\n                weights[i].data_ptr(),\n                input_tensor[i % layer_num].data_ptr(),\n                output_tensor[i % layer_num].data_ptr(),\n                False,\n            )\n        )\n      
  CPUInfer.sync()\n    end = time.perf_counter()\n\n    total_time = end - start\n    time_per_iter_us = total_time / test_iter * 1e6\n    bandwidth_gbs = (\n        hidden_size * intermediate_size * 3 * num_experts_per_tok * qlen * bytes_per_elem * test_iter / total_time / 1e9\n    )\n    flops_tflops = hidden_size * intermediate_size * qlen * 3 * num_experts_per_tok * 2 * test_iter / total_time / 1e12\n\n    print(\"\\n=== Results ===\")\n    print(\"quant:\", args.quant)\n    if hasattr(ce.moe.tiling, \"get_int8\") and args.quant == \"int8\":\n        print(\"tiling int8:\", ce.moe.tiling.get_int8())\n    if hasattr(ce.moe.tiling, \"get_int4\") and args.quant == \"int4\":\n        print(\"tiling int4:\", ce.moe.tiling.get_int4())\n    print(\"time (s):\", total_time)\n    print(\"iter:\", test_iter)\n    print(\"time per iter (us):\", time_per_iter_us)\n    print(\"bandwidth (GB/s):\", bandwidth_gbs)\n    print(\"TFLOPS:\", flops_tflops)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "kt-kernel/bench/bench_moe_kml.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nDescription  :\nAuthor       : chenht2022\nDate         : 2024-07-25 10:32:05\nVersion      : 1.0.0\nLastEditors  : chenht2022\nLastEditTime : 2024-08-06 10:41:28\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved.\n\"\"\"\nimport os, sys, time, json, subprocess, platform\n\nos.environ[\"BLAS_NUM_THREADS\"] = \"1\"\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\", \"build\"))\nfrom kt_kernel import kt_kernel_ext\nimport torch\nimport numpy as np\nfrom tqdm import tqdm\n\n\n# 测试参数设置\nexpert_num = 16\nhidden_size = 7168\nintermediate_size = 2048\nmax_len = 25600\nnum_experts_per_tok = 8\nlayer_num = 1\n\nqlen = 1\nwarm_up_iter = 1000\ntest_iter = 10000\n\n# 将 CPUInfer 参数设为变量\nCPUINFER_PARAM = 112\nCPUInfer = kt_kernel_ext.CPUInfer(CPUINFER_PARAM)\n\n# worker_config = kt_kernel_ext.WorkerPoolConfig()\n# worker_config.subpool_count = 4\n# worker_config.subpool_numa_map= [0,1,2,3]\n# worker_config.subpool_thread_count = [36,36,36,36]\n# worker_config.subpool_thread_count = [39,39,39,39]\n# CPUINFER_PARAM = 156\n# CPUInfer = kt_kernel_ext.CPUInfer(worker_config)\n\n\ndef get_git_commit():\n    \"\"\"\n    获取当前 git 提交记录（commit hash 和提交信息），\n    并检查是否存在未提交的更改（dirty）\n    \"\"\"\n    result = {}\n    try:\n        commit = subprocess.check_output([\"git\", \"rev-parse\", \"HEAD\"]).decode(\"utf-8\").strip()\n        commit_msg = subprocess.check_output([\"git\", \"log\", \"-1\", \"--pretty=%B\"]).decode(\"utf-8\").strip()\n        result[\"commit\"] = commit\n        result[\"commit_message\"] = commit_msg\n\n        # 检查是否存在未提交的更改\n        dirty_output = subprocess.check_output([\"git\", \"status\", \"--porcelain\"]).decode(\"utf-8\").strip()\n        if dirty_output:\n            result[\"dirty\"] = True\n            result[\"dirty_files\"] = dirty_output.splitlines()\n        else:\n            result[\"dirty\"] = False\n    except Exception as e:\n        result[\"commit\"] = None\n     
   result[\"commit_message\"] = None\n        result[\"dirty\"] = None\n        result[\"error\"] = str(e)\n    return result\n\n\ndef get_system_info():\n    \"\"\"\n    Get system information: system name, CPU model, memory size (GB), CPU core count, and socket count\n    \"\"\"\n    info = {}\n    # System name and host name\n    uname = platform.uname()\n    info[\"system_name\"] = uname.system  # e.g. Linux, Windows\n    info[\"node_name\"] = uname.node  # Host name\n\n    # Get the CPU model (Linux only)\n    cpu_model = None\n    if os.path.exists(\"/proc/cpuinfo\"):\n        try:\n            with open(\"/proc/cpuinfo\", \"r\") as f:\n                for line in f:\n                    if \"model name\" in line:\n                        cpu_model = line.split(\":\", 1)[1].strip()\n                        break\n        except Exception as e:\n            cpu_model = f\"Error: {e}\"\n    info[\"cpu_model\"] = cpu_model\n\n    # Get the memory size in GB (Linux only)\n    mem_total_gb = None\n    if os.path.exists(\"/proc/meminfo\"):\n        try:\n            with open(\"/proc/meminfo\", \"r\") as f:\n                for line in f:\n                    if \"MemTotal\" in line:\n                        mem_kb = float(line.split(\":\", 1)[1].split()[0])\n                        mem_total_gb = round(mem_kb / (1024 * 1024), 2)\n                        break\n        except Exception as e:\n            mem_total_gb = f\"Error: {e}\"\n    info[\"memory_size_GB\"] = mem_total_gb\n\n    # Get the CPU core count (logical cores)\n    info[\"cpu_core_count\"] = os.cpu_count()\n\n    # Parse /proc/cpuinfo to get the socket count\n    sockets = set()\n    if os.path.exists(\"/proc/cpuinfo\"):\n        try:\n            with open(\"/proc/cpuinfo\", \"r\") as f:\n                for line in f:\n                    if \"physical id\" in line:\n                        sockets.add(line.split(\":\", 1)[1].strip())\n        except Exception as e:\n            sockets = set()\n    # If no socket information was parsed, assume at least 1 socket\n    info[\"cpu_socket_count\"] = len(sockets) if len(sockets) > 0 else 1\n\n    return 
info\n\n\nscript_path = os.path.abspath(__file__)\nscript_dir = os.path.dirname(script_path)\nscript_name = os.path.splitext(os.path.basename(script_path))[0]\njson_path = os.path.join(script_dir, \"bench_results \" + \".jsonl\")\n\n\ndef record_results(result, filename=json_path):\n    \"\"\"\n    Append the result to the file as one JSON line\n    \"\"\"\n    with open(filename, \"a\") as f:\n        f.write(json.dumps(result) + \"\\n\")\n\n\ndef bench_moe(quant_mode: str):\n    with torch.inference_mode():\n        if quant_mode == \"bf16\":\n            bytes_per_elem = 2.0\n        elif quant_mode == \"int8\":\n            bytes_per_elem = 1.0\n        elif quant_mode == \"int4\":\n            bytes_per_elem = 0.5\n        else:\n            raise ValueError(f\"Unsupported quantization mode: {quant_mode}\")\n\n        moes = []\n        gate_projs = []\n        up_projs = []\n        down_projs = []\n        for layer_index in range(layer_num):\n            gate_proj = torch.randn(\n                (expert_num, intermediate_size, hidden_size), dtype=torch.float32, device=\"cpu\"\n            ).contiguous()\n            up_proj = torch.randn(\n                (expert_num, intermediate_size, hidden_size), dtype=torch.float32, device=\"cpu\"\n            ).contiguous()\n            down_proj = torch.randn(\n                (expert_num, hidden_size, intermediate_size), dtype=torch.float32, device=\"cpu\"\n            ).contiguous()\n            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)\n            config.max_len = max_len\n            config.gate_proj = gate_proj.data_ptr()\n            config.up_proj = up_proj.data_ptr()\n            config.down_proj = down_proj.data_ptr()\n            config.pool = CPUInfer.backend_\n            if quant_mode == \"int8\":\n                moe = kt_kernel_ext.moe.KMLInt8_MOE(config)\n            elif quant_mode == \"int4\":\n                moe = kt_kernel_ext.moe.KMLInt4_MOE(config)\n            else:\n                raise 
ValueError(f\"Unsupported quantization mode: {quant_mode}\")\n            CPUInfer.submit(moe.load_weights_task())\n            CPUInfer.sync()\n            gate_projs.append(gate_proj)\n            up_projs.append(up_proj)\n            down_projs.append(down_proj)\n            moes.append(moe)\n\n        expert_ids = (\n            torch.rand(test_iter * qlen, expert_num, device=\"cpu\")\n            .argsort(dim=-1)[:, :num_experts_per_tok]\n            .reshape(test_iter, qlen * num_experts_per_tok)\n            .contiguous()\n        )\n        weights = torch.rand((test_iter, qlen, num_experts_per_tok), dtype=torch.float32, device=\"cpu\").contiguous()\n        input_tensor = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device=\"cpu\").contiguous()\n        output_tensor = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device=\"cpu\").contiguous()\n        bsz_tensor = torch.tensor([qlen], device=\"cpu\")\n\n        # 预热迭代\n        for i in tqdm(range(warm_up_iter), desc=\"Warm-up\"):\n            # print(f'warmup iteration {i}')\n            # start_it = time.time_ns()\n            CPUInfer.submit(\n                moes[i % layer_num].forward_task(\n                    bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids[i].data_ptr(),\n                    weights[i].data_ptr(),\n                    input_tensor[i % layer_num].data_ptr(),\n                    output_tensor[i % layer_num].data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n            # end_it = time.time_ns()\n            # print('python Time(ns): ', end_it - start_it)\n\n        # 测试迭代\n        start = time.perf_counter()\n        for i in tqdm(range(test_iter), desc=\"Testing\"):\n            # print(f'test iteration {i}')\n            # start_it = time.time_ns()\n            CPUInfer.submit(\n                moes[i % layer_num].forward_task(\n              
      bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids[i].data_ptr(),\n                    weights[i].data_ptr(),\n                    input_tensor[i % layer_num].data_ptr(),\n                    output_tensor[i % layer_num].data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n            # end_it = time.time_ns()\n            # print('python Time(ns): ', end_it - start_it)\n        end = time.perf_counter()\n        total_time = end - start\n\n        # Compute performance metrics\n        time_per_iter_us = total_time / test_iter * 1e6\n        bandwidth = (\n            hidden_size\n            * intermediate_size\n            * 3\n            * num_experts_per_tok\n            * (1 / 8 * 256 * (1 - (31 / 32) ** qlen))\n            * bytes_per_elem\n            * test_iter\n            / total_time\n            / 1e9\n        )  # Unit: GB/s\n        flops = (\n            hidden_size * intermediate_size * qlen * 3 * num_experts_per_tok * 2 * test_iter / total_time / 1e12\n        )  # Unit: TFLOPS\n\n        print(\"Quant mode: \", quant_mode)\n        print(\"Time(s): \", total_time)\n        print(\"Iteration: \", test_iter)\n        print(\"Time(us) per iteration: \", time_per_iter_us)\n        print(\"Bandwidth: \", bandwidth, \"GB/s\")\n        print(\"Flops: \", flops, \"TFLOPS\")\n        print(\"\")\n\n        # Assemble the result record, including the test parameters\n        result = {\n            \"test_name\": os.path.basename(__file__),\n            \"quant_mode\": quant_mode,\n            \"total_time_seconds\": total_time,\n            \"iterations\": test_iter,\n            \"time_per_iteration_us\": time_per_iter_us,\n            \"bandwidth_GBs\": bandwidth,\n            \"flops_TFLOPS\": flops,\n            \"timestamp\": time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime()),\n            \"test_parameters\": {\n                \"expert_num\": expert_num,\n                \"hidden_size\": hidden_size,\n          
      \"intermediate_size\": intermediate_size,\n                \"max_len\": max_len,\n                \"num_experts_per_tok\": num_experts_per_tok,\n                \"layer_num\": layer_num,\n                \"qlen\": qlen,\n                \"warm_up_iter\": warm_up_iter,\n                \"test_iter\": test_iter,\n                \"CPUInfer_parameter\": CPUINFER_PARAM,\n            },\n        }\n        # Add git commit information\n        result.update(get_git_commit())\n        # Add system information (including CPU core count and socket count)\n        result.update(get_system_info())\n        # Append the result to the file as JSON\n        record_results(result)\n\n\nif __name__ == \"__main__\":\n    # Select the quantization mode to benchmark\n    # bench_moe(\"bf16\")\n    # bench_moe(\"int8\")\n    bench_moe(\"int4\")\n"
  },
  {
    "path": "kt-kernel/bench/bench_moe_torch.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : chenht2022\nDate         : 2024-07-25 10:32:05\nVersion      : 1.0.0\nLastEditors  : chenht2022 \nLastEditTime : 2024-07-25 10:32:57\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport os, sys\nimport time\nimport torch\nimport torch.nn.quantized as nnq\n\nscale, zero_point = 0.1, 0  # Adjust scale and zero_point based on your dataset\n\nexpert_num = 160\nhidden_size = 5120\nintermediate_size = 1536\nnum_experts_per_tok = 6\nlayer_num = 10\nqlen = 1\nwarm_up_iter = 1000\ntest_iter = 10000\n\ndef act_fn(x):\n    return x / (1.0 + torch.exp(-x))\n\ndef mlp_torch(input, gate_proj, up_proj, down_proj):\n    if isinstance(gate_proj, nnq.Linear):\n        input_q = torch.quantize_per_tensor(input.to(torch.float32), scale, zero_point, torch.quint8)\n        gate_buf = gate_proj(input_q)\n        up_buf = up_proj(input_q)\n        gate_buf = gate_buf.dequantize()\n        up_buf = up_buf.dequantize()\n        intermediate = act_fn(gate_buf) * up_buf\n        intermediate_q = torch.quantize_per_tensor(intermediate, scale, zero_point, torch.quint8)\n        expert_output = down_proj(intermediate_q)\n        ret = expert_output.dequantize()\n    else:\n        gate_buf = torch.mm(input.to(gate_proj.dtype), gate_proj.t())\n        up_buf = torch.mm(input.to(up_proj.dtype), up_proj.t())\n        intermediate = act_fn(gate_buf) * up_buf\n        ret = torch.mm(intermediate.to(down_proj.dtype), down_proj.t())\n    return ret\n\ndef moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):\n    cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))\n    cnts.scatter_(1, expert_ids, 1)\n    tokens_per_expert = cnts.sum(dim=0)\n    idxs = expert_ids.view(-1).argsort()\n    sorted_tokens = input[idxs // expert_ids.shape[1]]\n\n    outputs = []\n    start_idx = 0\n    for i, num_tokens in enumerate(tokens_per_expert):\n        end_idx = start_idx + 
num_tokens\n        if num_tokens == 0:\n            continue\n        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n        expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])\n        outputs.append(expert_out)\n        start_idx = end_idx\n\n    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n    new_x = torch.empty_like(outs)\n    new_x[idxs] = outs\n    t_output = (\n        new_x.view(*expert_ids.shape, -1)\n        .type(weights.dtype)\n        .mul_(weights.unsqueeze(dim=-1))\n        .sum(dim=1)\n        .type(new_x.dtype)\n    )\n    return t_output\n\ndef bench_moe(quant_mode: str):\n    with torch.inference_mode(mode=True):\n        if quant_mode == \"fp32\":\n            proj_type = torch.float32\n            bytes_per_elem = 4.000000\n        elif quant_mode == \"fp16\":\n            proj_type = torch.float16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"bf16\":\n            proj_type = torch.bfloat16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"qint8\":\n            proj_type = torch.qint8\n            bytes_per_elem = 1.000000\n        else:\n            assert(False)\n\n        gate_projs = []\n        up_projs = []\n        down_projs = []\n        for _ in range(layer_num):\n            gate_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            up_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            down_proj = torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            if quant_mode == \"qint8\":\n                quantized_gate_proj = []\n                quantized_up_proj = []\n                quantized_down_proj = []\n                for i in range(expert_num):\n    
                gate_proj_q = torch.quantize_per_tensor(gate_proj[i], scale, zero_point, torch.qint8)\n                    quantized_gate = nnq.Linear(hidden_size, intermediate_size)\n                    quantized_gate.set_weight_bias(gate_proj_q, None)\n                    quantized_gate_proj.append(quantized_gate)\n                    up_proj_q = torch.quantize_per_tensor(up_proj[i], scale, zero_point, torch.qint8)\n                    quantized_up = nnq.Linear(hidden_size, intermediate_size)\n                    quantized_up.set_weight_bias(up_proj_q, None)\n                    quantized_up_proj.append(quantized_up)\n                    down_proj_q = torch.quantize_per_tensor(down_proj[i], scale, zero_point, torch.qint8)\n                    quantized_down = nnq.Linear(intermediate_size, hidden_size)\n                    quantized_down.set_weight_bias(down_proj_q, None)\n                    quantized_down_proj.append(quantized_down)\n                gate_projs.append(quantized_gate_proj)\n                up_projs.append(quantized_up_proj)\n                down_projs.append(quantized_down_proj)\n            else:\n                gate_projs.append(gate_proj.to(proj_type))\n                up_projs.append(up_proj.to(proj_type))\n                down_projs.append(down_proj.to(proj_type))\n        expert_ids = torch.stack([torch.stack([torch.randperm(expert_num, dtype=torch.int64, device = \"cuda\")[:num_experts_per_tok] for _ in range(qlen)]) for _ in range(layer_num)]).to(\"cpu\").contiguous()\n        weights = torch.rand((layer_num, qlen, num_experts_per_tok), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n        input = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = \"cuda\").to(\"cpu\").contiguous()\n\n        # warm up\n        for i in range(warm_up_iter):\n            moe_torch(input[i % layer_num], expert_ids[i % layer_num], weights[i % layer_num], gate_projs[i % layer_num], up_projs[i % layer_num], 
down_projs[i % layer_num])\n\n        # test\n        start = time.perf_counter()\n        for i in range(test_iter):\n            moe_torch(input[i % layer_num], expert_ids[i % layer_num], weights[i % layer_num], gate_projs[i % layer_num], up_projs[i % layer_num], down_projs[i % layer_num])\n        end = time.perf_counter()\n        total_time = end - start\n        print('Quant mode: ', quant_mode)\n        print('Time(s): ', total_time)\n        print('Iteration: ', test_iter)\n        print('Time(us) per iteration: ', total_time / test_iter * 1000000)\n        print('Bandwidth: ', hidden_size * intermediate_size * 3 * num_experts_per_tok * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')\n        print('')\n\nbench_moe(\"fp32\")\nbench_moe(\"fp16\")\nbench_moe(\"bf16\")\nbench_moe(\"qint8\")\n"
  },
  {
    "path": "kt-kernel/bench/bench_write_buffer.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nBenchmark write_weight_scale_to_buffer for AMX MOE operators.\n\nSupports:\n- FP8: FP8 weights (1 byte) + float32 scales (block-wise)\n- FP8_PERCHANNEL: FP8 weights (1 byte) + float32 per-channel scales\n- BF16: Native BF16 weights (2 bytes), no scales\n\nUsage:\n    python bench_write_buffer.py          # Run all modes\n    python bench_write_buffer.py fp8      # Run FP8 only\n    python bench_write_buffer.py fp8_perchannel  # Run FP8 per-channel only\n    python bench_write_buffer.py bf16     # Run BF16 only\n\"\"\"\nimport json\nimport os\nimport platform\nimport subprocess\nimport sys\nimport time\n\nfrom tqdm import tqdm\n\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\", \"build\"))\n\nfrom kt_kernel import kt_kernel_ext\nimport torch\n\n# Benchmark parameters\nexpert_num = 256\nnum_experts_per_tok = 8\ngpu_tp_count = 2\n\nwarm_up_iter = 30\ntest_iter = 70\n\ngpu_experts_num = expert_num\n\nhidden_size = 7168\nintermediate_size = 2048\ngroup_size = 128  # FP8 uses 128x128 block-wise scales\nmax_len = 1\n\nphysical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device=\"cpu\").contiguous()\nCPUInfer = kt_kernel_ext.CPUInfer(80)\n\n\ndef get_git_commit():\n    result = {}\n    try:\n        commit = subprocess.check_output([\"git\", \"rev-parse\", \"HEAD\"]).decode(\"utf-8\").strip()\n        commit_msg = subprocess.check_output([\"git\", \"log\", \"-1\", \"--pretty=%B\"]).decode(\"utf-8\").strip()\n        result[\"commit\"] = commit\n        result[\"commit_message\"] = commit_msg\n        dirty_output = subprocess.check_output([\"git\", \"status\", \"--porcelain\"]).decode(\"utf-8\").strip()\n        result[\"dirty\"] = bool(dirty_output)\n        if dirty_output:\n            result[\"dirty_files\"] = dirty_output.splitlines()\n    except Exception as e:\n        result[\"error\"] = str(e)\n    return result\n\n\ndef get_system_info():\n    info = {}\n    
info[\"system_name\"] = platform.uname().system\n    info[\"node_name\"] = platform.uname().node\n    info[\"cpu_core_count\"] = os.cpu_count()\n    if os.path.exists(\"/proc/cpuinfo\"):\n        with open(\"/proc/cpuinfo\", \"r\") as f:\n            for line in f:\n                if \"model name\" in line:\n                    info[\"cpu_model\"] = line.split(\":\", 1)[1].strip()\n                    break\n    if os.path.exists(\"/proc/meminfo\"):\n        with open(\"/proc/meminfo\", \"r\") as f:\n            for line in f:\n                if \"MemTotal\" in line:\n                    mem_kb = float(line.split(\":\", 1)[1].split()[0])\n                    info[\"memory_size_GB\"] = round(mem_kb / (1024 * 1024), 2)\n                    break\n    return info\n\n\nscript_path = os.path.abspath(__file__)\nscript_dir = os.path.dirname(script_path)\nscript_name = os.path.splitext(os.path.basename(script_path))[0]\njson_path = os.path.join(script_dir, script_name + \".jsonl\")\n\n\ndef record_results(result, filename=json_path):\n    with open(filename, \"a\") as f:\n        f.write(json.dumps(result) + \"\\n\")\n\n\ndef div_up(a, b):\n    return (a + b - 1) // b\n\n\n# ==============================================================================\n# FP8 Functions\n# ==============================================================================\n\n\ndef allocate_weights_fp8():\n    per_mat_weight_bytes = hidden_size * intermediate_size\n    n_blocks_n_gate_up = div_up(intermediate_size, group_size)\n    n_blocks_k = div_up(hidden_size, group_size)\n    per_mat_scale_elems_gate_up = n_blocks_n_gate_up * n_blocks_k\n    per_mat_scale_elems_down = n_blocks_k * n_blocks_n_gate_up\n\n    gate_q = (\n        torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8, device=\"cuda\")\n        .to(\"cpu\")\n        .contiguous()\n    )\n    up_q = (\n        torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8, device=\"cuda\")\n 
       .to(\"cpu\")\n        .contiguous()\n    )\n    down_q = (\n        torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8, device=\"cuda\")\n        .to(\"cpu\")\n        .contiguous()\n    )\n    gate_scale = (\n        torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32, device=\"cuda\").to(\"cpu\").contiguous()\n    )\n    up_scale = (\n        torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32, device=\"cuda\").to(\"cpu\").contiguous()\n    )\n    down_scale = (\n        torch.randn(expert_num * per_mat_scale_elems_down, dtype=torch.float32, device=\"cuda\").to(\"cpu\").contiguous()\n    )\n\n    return {\n        \"gate_q\": gate_q,\n        \"up_q\": up_q,\n        \"down_q\": down_q,\n        \"gate_scale\": gate_scale,\n        \"up_scale\": up_scale,\n        \"down_scale\": down_scale,\n        \"per_mat_weight_bytes\": per_mat_weight_bytes,\n        \"per_mat_scale_elems_gate_up\": per_mat_scale_elems_gate_up,\n        \"per_mat_scale_elems_down\": per_mat_scale_elems_down,\n    }\n\n\ndef allocate_weights_fp8_perchannel():\n    per_mat_weight_bytes = hidden_size * intermediate_size\n    per_mat_scale_elems_gate_up = intermediate_size\n    per_mat_scale_elems_down = hidden_size\n\n    gate_q = (\n        torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8, device=\"cuda\")\n        .to(\"cpu\")\n        .contiguous()\n    )\n    up_q = (\n        torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8, device=\"cuda\")\n        .to(\"cpu\")\n        .contiguous()\n    )\n    down_q = (\n        torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8, device=\"cuda\")\n        .to(\"cpu\")\n        .contiguous()\n    )\n    gate_scale = (\n        torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32, device=\"cuda\").to(\"cpu\").contiguous()\n    )\n    up_scale = (\n        
torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32, device=\"cuda\").to(\"cpu\").contiguous()\n    )\n    down_scale = (\n        torch.randn(expert_num * per_mat_scale_elems_down, dtype=torch.float32, device=\"cuda\").to(\"cpu\").contiguous()\n    )\n\n    return {\n        \"gate_q\": gate_q,\n        \"up_q\": up_q,\n        \"down_q\": down_q,\n        \"gate_scale\": gate_scale,\n        \"up_scale\": up_scale,\n        \"down_scale\": down_scale,\n        \"per_mat_weight_bytes\": per_mat_weight_bytes,\n        \"per_mat_scale_elems_gate_up\": per_mat_scale_elems_gate_up,\n        \"per_mat_scale_elems_down\": per_mat_scale_elems_down,\n    }\n\n\ndef build_moe_fp8(layer_idx=0):\n    \"\"\"Build a single FP8 MOE instance.\"\"\"\n    weights = allocate_weights_fp8()\n\n    config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)\n    config.max_len = max_len\n    config.layer_idx = layer_idx\n    config.quant_config.bits = 8\n    config.quant_config.group_size = group_size\n    config.quant_config.zero_point = False\n    config.pool = CPUInfer.backend_\n    config.gate_proj = weights[\"gate_q\"].data_ptr()\n    config.up_proj = weights[\"up_q\"].data_ptr()\n    config.down_proj = weights[\"down_q\"].data_ptr()\n    config.gate_scale = weights[\"gate_scale\"].data_ptr()\n    config.up_scale = weights[\"up_scale\"].data_ptr()\n    config.down_scale = weights[\"down_scale\"].data_ptr()\n\n    moe = kt_kernel_ext.moe.AMXFP8_MOE(config)\n    CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n    CPUInfer.sync()\n\n    buffer_shapes = {\n        \"per_mat_weight_bytes\": weights[\"per_mat_weight_bytes\"],\n        \"per_mat_scale_elems_gate_up\": weights[\"per_mat_scale_elems_gate_up\"],\n        \"per_mat_scale_elems_down\": weights[\"per_mat_scale_elems_down\"],\n    }\n\n    return moe, buffer_shapes, weights\n\n\ndef build_moe_fp8_perchannel(layer_idx=0):\n    
\"\"\"Build a single FP8 per-channel MOE instance.\"\"\"\n    weights = allocate_weights_fp8_perchannel()\n\n    config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)\n    config.max_len = max_len\n    config.layer_idx = layer_idx\n    config.quant_config.bits = 8\n    config.quant_config.group_size = 0\n    config.quant_config.zero_point = False\n    config.quant_config.per_channel = True\n    config.pool = CPUInfer.backend_\n    config.gate_proj = weights[\"gate_q\"].data_ptr()\n    config.up_proj = weights[\"up_q\"].data_ptr()\n    config.down_proj = weights[\"down_q\"].data_ptr()\n    config.gate_scale = weights[\"gate_scale\"].data_ptr()\n    config.up_scale = weights[\"up_scale\"].data_ptr()\n    config.down_scale = weights[\"down_scale\"].data_ptr()\n\n    moe = kt_kernel_ext.moe.AMXFP8PerChannel_MOE(config)\n    CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n    CPUInfer.sync()\n\n    buffer_shapes = {\n        \"per_mat_weight_bytes\": weights[\"per_mat_weight_bytes\"],\n        \"per_mat_scale_elems_gate_up\": weights[\"per_mat_scale_elems_gate_up\"],\n        \"per_mat_scale_elems_down\": weights[\"per_mat_scale_elems_down\"],\n    }\n\n    return moe, buffer_shapes, weights\n\n\ndef allocate_buffers_fp8(buffer_shapes):\n    \"\"\"Allocate output buffers for FP8 single expert.\"\"\"\n    per_mat_weight_bytes = buffer_shapes[\"per_mat_weight_bytes\"]\n    per_mat_scale_elems_gate_up = buffer_shapes[\"per_mat_scale_elems_gate_up\"]\n    per_mat_scale_elems_down = buffer_shapes[\"per_mat_scale_elems_down\"]\n\n    weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count\n    scale_elems_per_expert_per_tp_gate_up = per_mat_scale_elems_gate_up // gpu_tp_count\n    scale_elems_per_expert_per_tp_down = per_mat_scale_elems_down // gpu_tp_count\n\n    w13_weight_bufs = [torch.empty(2 * weight_bytes_per_expert_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]\n    
w13_scale_bufs = [\n        torch.empty(2 * scale_elems_per_expert_per_tp_gate_up, dtype=torch.float32) for _ in range(gpu_tp_count)\n    ]\n    w2_weight_bufs = [torch.empty(weight_bytes_per_expert_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]\n    w2_scale_bufs = [torch.empty(scale_elems_per_expert_per_tp_down, dtype=torch.float32) for _ in range(gpu_tp_count)]\n\n    buffer_ptrs = {\n        \"w13_weight_ptrs\": [buf.data_ptr() for buf in w13_weight_bufs],\n        \"w13_scale_ptrs\": [buf.data_ptr() for buf in w13_scale_bufs],\n        \"w2_weight_ptrs\": [buf.data_ptr() for buf in w2_weight_bufs],\n        \"w2_scale_ptrs\": [buf.data_ptr() for buf in w2_scale_bufs],\n    }\n\n    keep_tensors = {\n        \"w13_weight_bufs\": w13_weight_bufs,\n        \"w13_scale_bufs\": w13_scale_bufs,\n        \"w2_weight_bufs\": w2_weight_bufs,\n        \"w2_scale_bufs\": w2_scale_bufs,\n    }\n\n    return buffer_ptrs, keep_tensors\n\n\ndef allocate_buffers_fp8_perchannel(buffer_shapes):\n    \"\"\"Allocate output buffers for FP8 per-channel single expert.\"\"\"\n    per_mat_weight_bytes = buffer_shapes[\"per_mat_weight_bytes\"]\n    per_mat_scale_elems_gate_up = buffer_shapes[\"per_mat_scale_elems_gate_up\"]\n    per_mat_scale_elems_down = buffer_shapes[\"per_mat_scale_elems_down\"]\n\n    weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count\n    scale_elems_per_expert_per_tp_gate_up = per_mat_scale_elems_gate_up // gpu_tp_count\n    scale_elems_per_expert_per_tp_down = per_mat_scale_elems_down\n\n    w13_weight_bufs = [torch.empty(2 * weight_bytes_per_expert_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]\n    w13_scale_bufs = [\n        torch.empty(2 * scale_elems_per_expert_per_tp_gate_up, dtype=torch.float32) for _ in range(gpu_tp_count)\n    ]\n    w2_weight_bufs = [torch.empty(weight_bytes_per_expert_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]\n    w2_scale_bufs = [torch.empty(scale_elems_per_expert_per_tp_down, 
dtype=torch.float32) for _ in range(gpu_tp_count)]\n\n    buffer_ptrs = {\n        \"w13_weight_ptrs\": [buf.data_ptr() for buf in w13_weight_bufs],\n        \"w13_scale_ptrs\": [buf.data_ptr() for buf in w13_scale_bufs],\n        \"w2_weight_ptrs\": [buf.data_ptr() for buf in w2_weight_bufs],\n        \"w2_scale_ptrs\": [buf.data_ptr() for buf in w2_scale_bufs],\n    }\n\n    keep_tensors = {\n        \"w13_weight_bufs\": w13_weight_bufs,\n        \"w13_scale_bufs\": w13_scale_bufs,\n        \"w2_weight_bufs\": w2_weight_bufs,\n        \"w2_scale_bufs\": w2_scale_bufs,\n    }\n\n    return buffer_ptrs, keep_tensors\n\n\n# ==============================================================================\n# BF16 Functions\n# ==============================================================================\n\n\ndef allocate_weights_bf16():\n    per_mat_weight_elems = hidden_size * intermediate_size\n    per_mat_weight_bytes = per_mat_weight_elems * 2  # BF16 = 2 bytes\n\n    gate_proj = (\n        torch.randn(expert_num * per_mat_weight_elems, dtype=torch.bfloat16, device=\"cuda\").to(\"cpu\").contiguous()\n    )\n    up_proj = torch.randn(expert_num * per_mat_weight_elems, dtype=torch.bfloat16, device=\"cuda\").to(\"cpu\").contiguous()\n    down_proj = (\n        torch.randn(expert_num * per_mat_weight_elems, dtype=torch.bfloat16, device=\"cuda\").to(\"cpu\").contiguous()\n    )\n\n    return {\n        \"gate_proj\": gate_proj,\n        \"up_proj\": up_proj,\n        \"down_proj\": down_proj,\n        \"per_mat_weight_bytes\": per_mat_weight_bytes,\n        \"per_mat_weight_elems\": per_mat_weight_elems,\n    }\n\n\ndef build_moe_bf16(layer_idx=0):\n    \"\"\"Build a single BF16 MOE instance.\"\"\"\n    weights = allocate_weights_bf16()\n\n    config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)\n    config.max_len = max_len\n    config.layer_idx = layer_idx\n    config.pool = CPUInfer.backend_\n    config.gate_proj = 
weights[\"gate_proj\"].data_ptr()\n    config.up_proj = weights[\"up_proj\"].data_ptr()\n    config.down_proj = weights[\"down_proj\"].data_ptr()\n    config.gate_scale = 0\n    config.up_scale = 0\n    config.down_scale = 0\n\n    moe = kt_kernel_ext.moe.AMXBF16_MOE(config)\n    CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n    CPUInfer.sync()\n\n    buffer_shapes = {\n        \"per_mat_weight_bytes\": weights[\"per_mat_weight_bytes\"],\n        \"per_mat_weight_elems\": weights[\"per_mat_weight_elems\"],\n    }\n\n    return moe, buffer_shapes, weights\n\n\ndef allocate_buffers_bf16(buffer_shapes):\n    \"\"\"Allocate output buffers for BF16 single expert (no scales).\"\"\"\n    per_mat_weight_bytes = buffer_shapes[\"per_mat_weight_bytes\"]\n\n    weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count\n\n    w13_weight_bufs = [torch.empty(2 * weight_bytes_per_expert_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]\n    w2_weight_bufs = [torch.empty(weight_bytes_per_expert_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]\n    # Dummy scale buffers (not used for BF16 but needed for interface)\n    w13_scale_bufs = [torch.empty(1, dtype=torch.float32) for _ in range(gpu_tp_count)]\n    w2_scale_bufs = [torch.empty(1, dtype=torch.float32) for _ in range(gpu_tp_count)]\n\n    buffer_ptrs = {\n        \"w13_weight_ptrs\": [buf.data_ptr() for buf in w13_weight_bufs],\n        \"w13_scale_ptrs\": [buf.data_ptr() for buf in w13_scale_bufs],\n        \"w2_weight_ptrs\": [buf.data_ptr() for buf in w2_weight_bufs],\n        \"w2_scale_ptrs\": [buf.data_ptr() for buf in w2_scale_bufs],\n    }\n\n    keep_tensors = {\n        \"w13_weight_bufs\": w13_weight_bufs,\n        \"w13_scale_bufs\": w13_scale_bufs,\n        \"w2_weight_bufs\": w2_weight_bufs,\n        \"w2_scale_bufs\": w2_scale_bufs,\n    }\n\n    return buffer_ptrs, keep_tensors\n\n\n# 
==============================================================================\n# Benchmark Functions\n# ==============================================================================\n\n\ndef bench_write_buffer(quant_mode: str):\n    \"\"\"Benchmark write_weight_scale_to_buffer for specified quant mode.\"\"\"\n    print(f\"\\n{'='*60}\")\n    print(f\"{quant_mode.upper()} write_weight_scale_to_buffer benchmark\")\n    print(f\"{'='*60}\")\n\n    if quant_mode == \"fp8\":\n        bytes_per_elem = 1.0\n        moe_0, buffer_shapes, keep_tensors_0 = build_moe_fp8(layer_idx=0)\n        moe_1, _, keep_tensors_1 = build_moe_fp8(layer_idx=1)\n        buffer_ptrs, buffer_keep = allocate_buffers_fp8(buffer_shapes)\n\n        # Calculate total bytes including scales\n        total_weights = hidden_size * intermediate_size * expert_num * 3\n        total_scale_bytes = (\n            (buffer_shapes[\"per_mat_scale_elems_gate_up\"] * 2 + buffer_shapes[\"per_mat_scale_elems_down\"])\n            * expert_num\n            * 4\n        )\n        bytes_per_call = total_weights + total_scale_bytes\n\n    elif quant_mode == \"fp8_perchannel\":\n        bytes_per_elem = 1.0\n        moe_0, buffer_shapes, keep_tensors_0 = build_moe_fp8_perchannel(layer_idx=0)\n        moe_1, _, keep_tensors_1 = build_moe_fp8_perchannel(layer_idx=1)\n        buffer_ptrs, buffer_keep = allocate_buffers_fp8_perchannel(buffer_shapes)\n\n        total_weights = hidden_size * intermediate_size * expert_num * 3\n        total_scale_bytes = (\n            (buffer_shapes[\"per_mat_scale_elems_gate_up\"] * 2 + buffer_shapes[\"per_mat_scale_elems_down\"])\n            * expert_num\n            * 4\n        )\n        bytes_per_call = total_weights + total_scale_bytes\n\n    elif quant_mode == \"bf16\":\n        bytes_per_elem = 2.0\n        moe_0, buffer_shapes, keep_tensors_0 = build_moe_bf16(layer_idx=0)\n        moe_1, _, keep_tensors_1 = build_moe_bf16(layer_idx=1)\n        buffer_ptrs, buffer_keep = 
allocate_buffers_bf16(buffer_shapes)\n\n        # BF16: only weights, no scales\n        bytes_per_call = hidden_size * intermediate_size * expert_num * 3 * 2  # BF16 = 2 bytes\n\n    else:\n        raise ValueError(f\"Unsupported quant_mode: {quant_mode}\")\n\n    moes = [moe_0, moe_1]\n\n    # Warm-up\n    for _ in tqdm(range(warm_up_iter), desc=f\"[{quant_mode.upper()}] Warm-up\"):\n        for moe_idx, moe in enumerate(moes):\n            for expert_id in range(gpu_experts_num):\n                CPUInfer.submit(\n                    moe.write_weight_scale_to_buffer_task(gpu_tp_count=gpu_tp_count, expert_id=expert_id, **buffer_ptrs)\n                )\n                CPUInfer.sync()\n\n    # Benchmark\n    total_time = 0\n    for iter_idx in tqdm(range(test_iter), desc=f\"[{quant_mode.upper()}] Testing\"):\n        start = time.perf_counter()\n        for moe_idx, moe in enumerate(moes):\n            for expert_id in range(gpu_experts_num):\n                CPUInfer.submit(\n                    moe.write_weight_scale_to_buffer_task(gpu_tp_count=gpu_tp_count, expert_id=expert_id, **buffer_ptrs)\n                )\n                CPUInfer.sync()\n        end = time.perf_counter()\n        iter_time = end - start\n        total_time += iter_time\n        # print(f\"  Iter {iter_idx}: {iter_time*1000:.2f} ms\")\n        time.sleep(0.3)\n\n    # bytes_per_call is for one MOE, we have 2 MOEs\n    bytes_per_iter = bytes_per_call * 2\n    time_per_iter_ms = total_time / test_iter * 1000\n    bandwidth_gbs = bytes_per_iter * test_iter / total_time / 1e9\n\n    print(f\"\\n{'='*60}\")\n    print(f\"{quant_mode.upper()} write_weight_scale_to_buffer Results (2 MOEs alternating)\")\n    print(f\"{'='*60}\")\n    print(f\"Time per iteration: {time_per_iter_ms:.2f} ms\")\n    print(f\"Bandwidth: {bandwidth_gbs:.2f} GB/s\")\n    print(f\"Experts per MOE: {gpu_experts_num}, MOEs: 2\")\n    print(f\"Time per expert: {time_per_iter_ms/(gpu_experts_num*2)*1000:.2f} us\")\n\n    
result = {\n        \"op\": f\"write_weight_scale_to_buffer_{quant_mode}\",\n        \"quant_mode\": quant_mode,\n        \"time_per_iteration_ms\": time_per_iter_ms,\n        \"bandwidth_GBs\": bandwidth_gbs,\n        \"timestamp\": time.strftime(\"%Y-%m-%d %H:%M:%S\"),\n        \"test_parameters\": {\n            \"expert_num\": expert_num,\n            \"hidden_size\": hidden_size,\n            \"intermediate_size\": intermediate_size,\n            \"gpu_tp_count\": gpu_tp_count,\n            \"bytes_per_iter\": bytes_per_iter,\n            \"num_moes\": 2,\n        },\n    }\n    if quant_mode == \"fp8\":\n        result[\"test_parameters\"][\"group_size\"] = group_size\n\n    result.update(get_git_commit())\n    result.update(get_system_info())\n    record_results(result)\n\n    return bandwidth_gbs\n\n\ndef main(quant_modes=None):\n    \"\"\"Run benchmarks for specified quant modes.\"\"\"\n    if quant_modes is None:\n        quant_modes = [\"fp8\", \"fp8_perchannel\", \"bf16\"]\n\n    results = {}\n    for mode in quant_modes:\n        try:\n            bandwidth = bench_write_buffer(mode)\n            results[mode] = f\"PASSED ({bandwidth:.2f} GB/s)\"\n        except Exception as e:\n            results[mode] = f\"FAILED: {e}\"\n            import traceback\n\n            traceback.print_exc()\n\n    print(\"\\n\" + \"=\" * 60)\n    print(\"SUMMARY\")\n    print(\"=\" * 60)\n    for mode, result in results.items():\n        print(f\"  {mode.upper()}: {result}\")\n\n\nif __name__ == \"__main__\":\n    if len(sys.argv) > 1:\n        mode = sys.argv[1].lower()\n        if mode in [\"fp8\", \"fp8_perchannel\", \"bf16\"]:\n            main([mode])\n        else:\n            print(f\"Unknown mode: {mode}. Use 'fp8', 'fp8_perchannel' or 'bf16'\")\n            sys.exit(1)\n    else:\n        main()\n"
  },
  {
    "path": "kt-kernel/bench/compare_moe_performance.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nMoE Performance Comparison Script\nCompares performance between KTransformers AMX MoE and SGL CPU MoE implementations\n\"\"\"\nimport os\nimport sys\nimport time\nimport json\nimport platform\nimport subprocess\nimport argparse\nimport logging\nimport signal\nfrom datetime import datetime\nfrom typing import Dict, List, Optional, Tuple\nfrom dataclasses import dataclass, asdict\nfrom pathlib import Path\n\n# Configure logging\nlogging.basicConfig(\n    level=logging.INFO,\n    format='%(asctime)s - %(levelname)s - %(message)s'\n)\nlogger = logging.getLogger(__name__)\n\n# Environment configuration\n@dataclass\nclass EnvironmentConfig:\n    malloc_conf: str = \"oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1\"\n    jemalloc_path: str = \"/home/xwy/Projects/jemalloc/lib/libjemalloc.so\"\n    \n    def apply(self):\n        os.environ['MALLOC_CONF'] = self.malloc_conf\n        if os.path.exists(self.jemalloc_path):\n            os.environ['LD_PRELOAD'] = self.jemalloc_path\n        else:\n            logger.warning(f\"jemalloc not found at {self.jemalloc_path}\")\n\n# Apply environment configuration\nenv_config = EnvironmentConfig()\nenv_config.apply()\n\n# Add paths for both implementations\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'build'))\nsys.path.insert(0, '/home/xwy/Projects/sgl-cpu-tests')\n\nimport torch\n\n# Try importing both implementations\ntry:\n    import kt_kernel_ext\n    KTRANSFORMERS_AVAILABLE = True\n    logger.info(\"KTransformers kt_kernel_ext loaded successfully\")\nexcept ImportError as e:\n    KTRANSFORMERS_AVAILABLE = False\n    logger.warning(f\"KTransformers kt_kernel_ext not available: {e}\")\n\ntry:\n    from sgl_kernel.common_ops import fused_experts_cpu\n    from sgl_kernel.common_ops import convert_weight_packed\n    SGL_AVAILABLE = True\n    logger.info(\"SGL kernel loaded successfully\")\nexcept 
ImportError as e:\n    SGL_AVAILABLE = False\n    logger.warning(f\"SGL kernel not available: {e}\")\n\n# Try importing int4 support\ntry:\n    # For SGL INT4, we'll check if the sglang-jianan directory exists\n    import os\n    sglang_path = \"/home/xwy/Projects/sglang-jianan\"\n    if os.path.exists(sglang_path) and os.path.exists(os.path.join(sglang_path, \"benchmark/kernels/int4_moe/benchmark_int4_moe.py\")):\n        SGL_INT4_AVAILABLE = True\n        logger.info(\"SGL INT4 support available (via sglang-jianan)\")\n    else:\n        SGL_INT4_AVAILABLE = False\n        logger.warning(\"SGL INT4 support not available: sglang-jianan directory not found\")\nexcept Exception as e:\n    SGL_INT4_AVAILABLE = False\n    logger.warning(f\"SGL INT4 support not available: {e}\")\n\ndef get_cpu_count() -> int:\n    \"\"\"Get logical CPU core count (including hyperthreading)\"\"\"\n    cpu_count = None\n    \n    # Method 1: os.cpu_count()\n    try:\n        cpu_count = os.cpu_count()\n        if cpu_count and cpu_count > 0:\n            logger.info(f\"Detected {cpu_count} logical CPU cores via os.cpu_count()\")\n            return cpu_count\n    except Exception as e:\n        logger.debug(f\"os.cpu_count() failed: {e}\")\n    \n    # Method 2: Check /proc/cpuinfo\n    try:\n        with open('/proc/cpuinfo', 'r') as f:\n            cpu_count = sum(1 for line in f if line.strip().startswith('processor'))\n        if cpu_count > 0:\n            logger.info(f\"Detected {cpu_count} logical CPU cores via /proc/cpuinfo\")\n            return cpu_count\n    except Exception as e:\n        logger.debug(f\"Failed to read /proc/cpuinfo: {e}\")\n    \n    # Default fallback\n    logger.warning(\"Could not detect CPU count, defaulting to 32\")\n    return 32\n\ndef get_physical_cpu_count() -> int:\n    \"\"\"Get physical CPU core count (excluding hyperthreading)\"\"\"\n    \n    # Method 1: Try lscpu command\n    try:\n        result = subprocess.run(['lscpu'], 
capture_output=True, text=True, timeout=5)\n        if result.returncode == 0:\n            cores_per_socket = None\n            sockets = None\n            for line in result.stdout.split('\\n'):\n                if 'Core(s) per socket:' in line:\n                    cores_per_socket = int(line.split(':')[1].strip())\n                elif 'Socket(s):' in line:\n                    sockets = int(line.split(':')[1].strip())\n            \n            if cores_per_socket and sockets:\n                physical_cores = cores_per_socket * sockets\n                logger.info(f\"Detected {physical_cores} physical CPU cores via lscpu\")\n                return physical_cores\n    except Exception as e:\n        logger.debug(f\"lscpu failed: {e}\")\n    \n    # Method 2: Check /sys/devices/system/cpu/\n    try:\n        cpu_path = '/sys/devices/system/cpu/'\n        if os.path.exists(cpu_path):\n            # Count unique physical core IDs\n            physical_cores = set()\n            for cpu_dir in os.listdir(cpu_path):\n                if cpu_dir.startswith('cpu') and cpu_dir[3:].isdigit():\n                    core_id_path = os.path.join(cpu_path, cpu_dir, 'topology/core_id')\n                    if os.path.exists(core_id_path):\n                        with open(core_id_path, 'r') as f:\n                            core_id = f.read().strip()\n                            physical_cores.add(core_id)\n            \n            if physical_cores:\n                count = len(physical_cores)\n                logger.info(f\"Detected {count} physical CPU cores via sysfs\")\n                return count\n    except Exception as e:\n        logger.debug(f\"Failed to check sysfs: {e}\")\n    \n    # Method 3: Parse /proc/cpuinfo for unique core ids\n    try:\n        with open('/proc/cpuinfo', 'r') as f:\n            content = f.read()\n            cores = set()\n            current_physical_id = None\n            \n            for line in content.split('\\n'):\n              
  if line.startswith('physical id'):\n                    current_physical_id = line.split(':')[1].strip()\n                elif line.startswith('core id') and current_physical_id is not None:\n                    core_id = line.split(':')[1].strip()\n                    cores.add(f\"{current_physical_id}:{core_id}\")\n            \n            if cores:\n                count = len(cores)\n                logger.info(f\"Detected {count} physical CPU cores via /proc/cpuinfo\")\n                return count\n    except Exception as e:\n        logger.debug(f\"Failed to parse /proc/cpuinfo: {e}\")\n    \n    # Fallback: assume hyperthreading is enabled and divide logical cores by 2\n    try:\n        logical_count = get_cpu_count()\n        if logical_count > 0:\n            # Assume hyperthreading, so physical cores = logical cores / 2\n            physical_count = logical_count // 2\n            logger.warning(f\"Could not detect physical cores directly. Assuming hyperthreading enabled: {logical_count} logical cores -> {physical_count} physical cores\")\n            return physical_count\n    except:\n        pass\n    \n    # Default fallback\n    logger.warning(\"Could not detect physical CPU count, defaulting to 32\")\n    return 32\n\n# Test configuration dataclass\n@dataclass\nclass TestConfig:\n    expert_num: int = 256\n    hidden_size: int = 7168\n    intermediate_size: int = 2048\n    max_len: int = 25600\n    num_experts_per_tok: int = 8\n    layer_num: int = 5\n    warm_up_iter: int = 100\n    test_iter: int = 10000\n    qlen_values: List[int] = None\n    thread_count_values: List[int] = None\n    \n    def __post_init__(self):\n        if self.qlen_values is None:\n            self.qlen_values = [1, 4, 16, 64, 256, 1024, 2048]\n        if self.thread_count_values is None:\n            # Default to physical CPU core count\n            physical_cores = get_physical_cpu_count()\n            self.thread_count_values = [physical_cores]\n    \n    @property\n 
   def total_configurations(self) -> int:\n        return len(self.qlen_values) * len(self.thread_count_values)\n\ndef get_numa_count() -> int:\n    \"\"\"Get NUMA node count from system with multiple fallback methods\"\"\"\n    # Method 1: Try numactl\n    try:\n        result = subprocess.run(['numactl', '--hardware'], \n                              capture_output=True, text=True, timeout=5)\n        if result.returncode == 0:\n            for line in result.stdout.split('\\n'):\n                if 'available:' in line and 'nodes' in line:\n                    parts = line.split()\n                    if len(parts) >= 2 and parts[1].isdigit():\n                        numa_count = int(parts[1])\n                        logger.info(f\"Detected {numa_count} NUMA nodes via numactl\")\n                        return numa_count\n    except (subprocess.TimeoutExpired, FileNotFoundError) as e:\n        logger.debug(f\"numactl not available: {e}\")\n    \n    # Method 2: Check /sys/devices/system/node/\n    try:\n        node_path = '/sys/devices/system/node/'\n        if os.path.exists(node_path):\n            numa_dirs = [d for d in os.listdir(node_path) if d.startswith('node')]\n            if numa_dirs:\n                numa_count = len(numa_dirs)\n                logger.info(f\"Detected {numa_count} NUMA nodes via sysfs\")\n                return numa_count\n    except Exception as e:\n        logger.debug(f\"Failed to check sysfs: {e}\")\n    \n    # Default fallback\n    logger.warning(\"Could not detect NUMA configuration, defaulting to 2 nodes\")\n    return 2\n\n# System configuration\n@dataclass\nclass SystemConfig:\n    numa_count: int = 0\n    cpu_cores: int = 0\n    \n    def __post_init__(self):\n        if self.numa_count == 0:\n            self.numa_count = get_numa_count()\n        if self.cpu_cores == 0:\n            self.cpu_cores = get_cpu_count()\n\nsys_config = SystemConfig()\n\n@dataclass\nclass ThreadConfig:\n    thread_count: int\n    
threads_per_numa: int\n    sgl_thread_count: int\n    numa_prefix: str\n    \n    @classmethod\n    def from_thread_count(cls, thread_count: int, numa_count: int, cpu_cores: int) -> 'ThreadConfig':\n        \"\"\"Create thread configuration for a specific thread count\"\"\"\n        # Validate thread count\n        if thread_count > cpu_cores:\n            logger.warning(f\"thread_count ({thread_count}) > cpu_cores ({cpu_cores}), using all cores\")\n            thread_count = cpu_cores\n        \n        threads_per_numa = thread_count // numa_count\n        sgl_thread_count = threads_per_numa\n        last_core = sgl_thread_count - 1\n        numa_prefix = f\"numactl --physcpubind=0-{last_core} --membind=0\"\n        \n        return cls(\n            thread_count=thread_count,\n            threads_per_numa=threads_per_numa,\n            sgl_thread_count=sgl_thread_count,\n            numa_prefix=numa_prefix\n        )\n\ndef get_system_info() -> Dict[str, any]:\n    \"\"\"Get comprehensive system information\"\"\"\n    info = {}\n    \n    # Basic system info\n    uname = platform.uname()\n    info[\"system_name\"] = uname.system\n    info[\"node_name\"] = uname.node\n    info[\"release\"] = uname.release\n    info[\"machine\"] = uname.machine\n    info[\"cpu_count\"] = sys_config.cpu_cores\n    info[\"numa_nodes\"] = sys_config.numa_count\n    \n    # CPU model information\n    if os.path.exists('/proc/cpuinfo'):\n        try:\n            with open('/proc/cpuinfo', 'r') as f:\n                cpu_info = f.read()\n                for line in cpu_info.split('\\n'):\n                    if \"model name\" in line:\n                        info[\"cpu_model\"] = line.split(\":\", 1)[1].strip()\n                        break\n                # Check for CPU features\n                if \"flags\" in cpu_info:\n                    flags_line = next(line for line in cpu_info.split('\\n') if \"flags\" in line)\n                    flags = flags_line.split(\":\", 
1)[1].strip().split()\n                    info[\"cpu_features\"] = {\n                        \"avx2\": \"avx2\" in flags,\n                        \"avx512\": any(f.startswith(\"avx512\") for f in flags),\n                        \"amx\": any(\"amx\" in f for f in flags)\n                    }\n        except Exception as e:\n            logger.debug(f\"Failed to read CPU info: {e}\")\n    \n    # Memory information\n    try:\n        import psutil\n        mem = psutil.virtual_memory()\n        info[\"total_memory_gb\"] = round(mem.total / (1024**3), 2)\n        info[\"available_memory_gb\"] = round(mem.available / (1024**3), 2)\n    except ImportError:\n        pass\n    \n    # Python and PyTorch versions\n    info[\"python_version\"] = sys.version.split()[0]\n    info[\"torch_version\"] = torch.__version__\n    info[\"cuda_available\"] = torch.cuda.is_available()\n    if torch.cuda.is_available():\n        info[\"cuda_version\"] = torch.version.cuda\n    \n    return info\n\n@dataclass\nclass BenchmarkResult:\n    implementation: str\n    quant_mode: str\n    qlen: int\n    thread_count: int\n    total_time: float\n    time_per_iter_us: float\n    bandwidth_gbs: float\n    tflops: float\n    iterations: int\n    \n    def to_dict(self) -> Dict:\n        return asdict(self)\n\n@dataclass\nclass CheckpointState:\n    \"\"\"State information for checkpoint/resume functionality\"\"\"\n    test_config: TestConfig\n    completed_configs: List[Tuple[int, int, str, str]]  # (thread_count, qlen, implementation, quant_mode)\n    results: List[BenchmarkResult]\n    start_time: str\n    last_update: str\n    \n    def to_dict(self) -> Dict:\n        return {\n            'test_config': asdict(self.test_config),\n            'completed_configs': self.completed_configs,\n            'results': [r.to_dict() for r in self.results],\n            'start_time': self.start_time,\n            'last_update': self.last_update\n        }\n    \n    @classmethod\n    def 
from_dict(cls, data: Dict) -> 'CheckpointState':\n        test_config = TestConfig(**data['test_config'])\n        results = [BenchmarkResult(**r) for r in data['results']]\n        return cls(\n            test_config=test_config,\n            completed_configs=data['completed_configs'],\n            results=results,\n            start_time=data['start_time'],\n            last_update=data['last_update']\n        )\n\nclass CheckpointManager:\n    \"\"\"Manages checkpoint saving and loading\"\"\"\n    def __init__(self, checkpoint_dir: str = None):\n        self.checkpoint_dir = Path(checkpoint_dir) if checkpoint_dir else Path.cwd() / \"checkpoints\"\n        self.checkpoint_dir.mkdir(exist_ok=True)\n        self.checkpoint_file = self.checkpoint_dir / \"moe_benchmark_checkpoint.json\"\n        self.interrupted = False\n        \n        # Set up signal handler for graceful shutdown\n        signal.signal(signal.SIGINT, self._signal_handler)\n        signal.signal(signal.SIGTERM, self._signal_handler)\n    \n    def _signal_handler(self, signum, frame):\n        logger.warning(f\"Received signal {signum}, will save checkpoint after current test...\")\n        self.interrupted = True\n    \n    def save_checkpoint(self, state: CheckpointState):\n        \"\"\"Save checkpoint to file\"\"\"\n        state.last_update = datetime.now().isoformat()\n        \n        # Save to temporary file first for atomicity\n        temp_file = self.checkpoint_file.with_suffix('.tmp')\n        try:\n            with open(temp_file, 'w') as f:\n                json.dump(state.to_dict(), f, indent=2)\n            \n            # Atomically rename\n            temp_file.replace(self.checkpoint_file)\n            logger.info(f\"Checkpoint saved: {len(state.results)} results, {len(state.completed_configs)} configs completed\")\n        except Exception as e:\n            logger.error(f\"Failed to save checkpoint: {e}\")\n            if temp_file.exists():\n                
temp_file.unlink()\n    \n    def load_checkpoint(self) -> Optional[CheckpointState]:\n        \"\"\"Load checkpoint from file if exists\"\"\"\n        if not self.checkpoint_file.exists():\n            return None\n        \n        try:\n            with open(self.checkpoint_file, 'r') as f:\n                data = json.load(f)\n            state = CheckpointState.from_dict(data)\n            logger.info(f\"Loaded checkpoint: {len(state.results)} results, {len(state.completed_configs)} configs completed\")\n            logger.info(f\"Checkpoint started at {state.start_time}, last updated {state.last_update}\")\n            return state\n        except Exception as e:\n            logger.error(f\"Failed to load checkpoint: {e}\")\n            return None\n    \n    def clear_checkpoint(self):\n        \"\"\"Remove checkpoint file\"\"\"\n        if self.checkpoint_file.exists():\n            self.checkpoint_file.unlink()\n            logger.info(\"Checkpoint cleared\")\n\ndef bench_ktransformers_moe(test_config: TestConfig, quant_mode: str, qlen: int, \n                           thread_config: ThreadConfig) -> Optional[BenchmarkResult]:\n    \"\"\"Benchmark KTransformers AMX MoE implementation\"\"\"\n    if not KTRANSFORMERS_AVAILABLE:\n        logger.error(\"KTransformers not available, skipping benchmark\")\n        return None\n    \n    # Adjust iterations based on qlen to maintain reasonable runtime\n    adjusted_iterations = test_config.test_iter\n    adjusted_warmup = test_config.warm_up_iter\n    if qlen >= 1024:\n        adjusted_iterations = max(10, test_config.test_iter // 100)\n        adjusted_warmup = max(5, test_config.warm_up_iter // 20)\n    elif qlen >= 256:\n        adjusted_iterations = max(50, test_config.test_iter // 20)\n        adjusted_warmup = max(10, test_config.warm_up_iter // 10)\n    elif qlen >= 64:\n        adjusted_iterations = max(100, test_config.test_iter // 10)\n        adjusted_warmup = max(20, test_config.warm_up_iter // 5)\n 
   elif qlen >= 16:\n        adjusted_iterations = max(200, test_config.test_iter // 5)\n        adjusted_warmup = max(40, test_config.warm_up_iter // 2)\n    \n    logger.info(f\"Testing KTransformers MoE: quant={quant_mode}, qlen={qlen}, threads={thread_config.thread_count}, \"\n                f\"iterations={adjusted_iterations} (warmup={adjusted_warmup})\")\n    \n    # Set thread count for this test\n    os.environ['OMP_NUM_THREADS'] = str(thread_config.thread_count)\n    \n    try:\n        with torch.inference_mode():\n            # Setup worker config with consistent threads per NUMA\n            worker_config = kt_kernel_ext.WorkerPoolConfig()\n            worker_config.subpool_count = sys_config.numa_count\n            worker_config.subpool_numa_map = list(range(sys_config.numa_count))\n            worker_config.subpool_thread_count = [thread_config.threads_per_numa] * sys_config.numa_count\n            CPUInfer = kt_kernel_ext.CPUInfer(worker_config)\n        \n            # Create MoE layers\n            moes = []\n            gate_projs = []\n            up_projs = []\n            down_projs = []\n            \n            logger.debug(f\"Creating {test_config.layer_num} MoE layers...\")\n            for i in range(test_config.layer_num):\n                gate_proj = torch.randn((test_config.expert_num, test_config.intermediate_size, test_config.hidden_size), \n                                      dtype=torch.float32).contiguous()\n                up_proj = torch.randn((test_config.expert_num, test_config.intermediate_size, test_config.hidden_size), \n                                    dtype=torch.float32).contiguous()\n                down_proj = torch.randn((test_config.expert_num, test_config.hidden_size, test_config.intermediate_size), \n                                      dtype=torch.float32).contiguous()\n            \n                config = kt_kernel_ext.moe.MOEConfig(\n                    test_config.expert_num, 
test_config.num_experts_per_tok, \n                    test_config.hidden_size, test_config.intermediate_size)\n                config.max_len = test_config.max_len\n                config.gate_proj = gate_proj.data_ptr()\n                config.up_proj = up_proj.data_ptr()\n                config.down_proj = down_proj.data_ptr()\n                config.pool = CPUInfer.backend_\n            \n                if quant_mode == \"bf16\":\n                    moe = kt_kernel_ext.moe.AMXBF16_MOE(config)\n                elif quant_mode == \"int8\":\n                    moe = kt_kernel_ext.moe.AMXInt8_MOE(config)\n                elif quant_mode == \"int4\":\n                    moe = kt_kernel_ext.moe.AMXInt4_MOE(config)\n                else:\n                    raise ValueError(f\"Unsupported quantization mode: {quant_mode}\")\n                \n                CPUInfer.submit(moe.load_weights_task())\n                CPUInfer.sync()\n                gate_projs.append(gate_proj)\n                up_projs.append(up_proj)\n                down_projs.append(down_proj)\n                moes.append(moe)\n        \n            # Prepare test data\n            logger.debug(\"Preparing test data...\")\n            gen_iter = 1000\n            expert_ids = torch.rand(gen_iter * qlen, test_config.expert_num).argsort(dim=-1)[\n                :, :test_config.num_experts_per_tok\n            ].reshape(gen_iter, qlen * test_config.num_experts_per_tok).contiguous()\n            \n            weights = torch.rand((gen_iter, qlen, test_config.num_experts_per_tok), \n                               dtype=torch.float32).contiguous()\n            input_tensor = torch.randn((test_config.layer_num, qlen, test_config.hidden_size), \n                                     dtype=torch.bfloat16).contiguous()\n            output_tensor = torch.empty((test_config.layer_num, qlen, test_config.hidden_size), \n                                      dtype=torch.bfloat16).contiguous()\n            
bsz_tensor = torch.tensor([qlen], dtype=torch.int32)\n        \n            # Warmup\n            logger.debug(f\"Running {adjusted_warmup} warmup iterations...\")\n            for i in range(adjusted_warmup):\n                layer_idx = i % test_config.layer_num\n                gen_idx = i % gen_iter\n                CPUInfer.submit(\n                    moes[layer_idx].forward_task(\n                        bsz_tensor.data_ptr(),\n                        test_config.num_experts_per_tok,\n                        expert_ids[gen_idx].data_ptr(),\n                        weights[gen_idx].data_ptr(),\n                        input_tensor[layer_idx].data_ptr(),\n                        output_tensor[layer_idx].data_ptr(),\n                        False,\n                    )\n                )\n                CPUInfer.sync()\n        \n            # Benchmark\n            logger.debug(f\"Running {adjusted_iterations} benchmark iterations...\")\n            start = time.perf_counter()\n            for i in range(adjusted_iterations):\n                layer_idx = i % test_config.layer_num\n                gen_idx = i % gen_iter\n                CPUInfer.submit(\n                    moes[layer_idx].forward_task(\n                        bsz_tensor.data_ptr(),\n                        test_config.num_experts_per_tok,\n                        expert_ids[gen_idx].data_ptr(),\n                        weights[gen_idx].data_ptr(),\n                        input_tensor[layer_idx].data_ptr(),\n                        output_tensor[layer_idx].data_ptr(),\n                        False,\n                    )\n                )\n                CPUInfer.sync()\n            end = time.perf_counter()\n        \n            # Calculate metrics\n            total_time = end - start\n            time_per_iter_us = total_time / adjusted_iterations * 1e6\n            \n            # Bytes per element based on quantization\n            bytes_per_elem = {\n                \"bf16\": 
2.0,\n                \"int8\": 1.0,\n                \"int4\": 0.5\n            }.get(quant_mode, 2.0)\n            \n            # Memory bandwidth calculation (GB/s)\n            memory_per_iter = (\n                test_config.hidden_size * test_config.intermediate_size * 3 * \n                test_config.num_experts_per_tok * \n                (1/8 * test_config.expert_num * (1-(31/32)**qlen)) * bytes_per_elem\n            )\n            bandwidth_gbs = memory_per_iter * adjusted_iterations / total_time / 1e9\n            \n            # FLOPS calculation (TFLOPS)\n            flops_per_iter = (\n                test_config.hidden_size * test_config.intermediate_size * qlen * 3 * \n                test_config.num_experts_per_tok * 2\n            )\n            tflops = flops_per_iter * adjusted_iterations / total_time / 1e12\n            \n            logger.info(f\"Results - Time: {total_time:.4f}s, Per-iter: {time_per_iter_us:.2f}μs, \"\n                       f\"BW: {bandwidth_gbs:.2f} GB/s, TFLOPS: {tflops:.2f}\")\n            \n            return BenchmarkResult(\n                implementation=\"KTransformers\",\n                quant_mode=quant_mode,\n                qlen=qlen,\n                thread_count=thread_config.thread_count,\n                total_time=total_time,\n                time_per_iter_us=time_per_iter_us,\n                bandwidth_gbs=bandwidth_gbs,\n                tflops=tflops,\n                iterations=adjusted_iterations\n            )\n            \n    except Exception as e:\n        logger.error(f\"KTransformers benchmark failed: {e}\", exc_info=True)\n        return None\n\ndef run_sgl_int4_with_numactl(test_config: TestConfig, qlen: int, \n                             thread_config: ThreadConfig) -> Optional[BenchmarkResult]:\n    \"\"\"Run SGL INT4 benchmark with numactl in subprocess\"\"\"\n    if not SGL_INT4_AVAILABLE:\n        logger.error(\"SGL INT4 not available, skipping benchmark\")\n        return None\n    \n  
  # Calculate SGL intermediate size (divided by NUMA nodes)\n    sgl_intermediate_size = test_config.intermediate_size // sys_config.numa_count\n    \n    # Adjust iterations based on qlen to maintain reasonable runtime\n    adjusted_iterations = test_config.test_iter\n    adjusted_warmup = test_config.warm_up_iter\n    if qlen >= 1024:\n        adjusted_iterations = max(10, test_config.test_iter // 100)\n        adjusted_warmup = max(5, test_config.warm_up_iter // 20)\n    elif qlen >= 256:\n        adjusted_iterations = max(50, test_config.test_iter // 20)\n        adjusted_warmup = max(10, test_config.warm_up_iter // 10)\n    elif qlen >= 64:\n        adjusted_iterations = max(100, test_config.test_iter // 10)\n        adjusted_warmup = max(20, test_config.warm_up_iter // 5)\n    elif qlen >= 16:\n        adjusted_iterations = max(200, test_config.test_iter // 5)\n        adjusted_warmup = max(40, test_config.warm_up_iter // 2)\n    \n    logger.info(f\"Testing SGL INT4: qlen={qlen}, iterations={adjusted_iterations} (warmup={adjusted_warmup}), \"\n                f\"threads per NUMA: {thread_config.sgl_thread_count}\")\n    \n    script_content = f'''\nimport sys\nsys.path.insert(0, '/home/xwy/Projects/sglang-jianan')\nsys.path.insert(0, '/home/xwy/Projects/sglang-jianan/test')\n\nimport os\nimport torch\nimport numpy as np\nimport sgl_kernel\nfrom srt.cpu.utils import autoawq_to_int4pack\nimport time\n\ntorch.manual_seed(1111)\nM, N, K, E, topk = {qlen}, {sgl_intermediate_size}, {test_config.hidden_size}, {test_config.expert_num}, {test_config.num_experts_per_tok}\nlayer_num = {test_config.layer_num}\ngroup_size = 128\nkernel = torch.ops.sgl_kernel\n\n# Prepare int4 data\ndtype = torch.bfloat16\ndevice = \"cpu\"\n\n# Generate input activations for all layers\ninput_tensors = [torch.rand(M, K, dtype=dtype, device=device) / np.sqrt(K) for _ in range(layer_num)]\n\n# Generate weights and pack for each layer\nall_awq_w13_weight_pack = []\nall_awq_w13_zero_pack = 
[]\nall_awq_w13_scales_pack = []\nall_awq_w2_weight_pack = []\nall_awq_w2_zero_pack = []\nall_awq_w2_scales_pack = []\n\n# Generate expert routing scores (different for each iteration)\ngen_iter = 1000\nall_topk_weights = []\nall_topk_ids = []\n\nfor gen_idx in range(gen_iter):\n    score = torch.rand(M, E, dtype=dtype, device=device)\n    score = torch.softmax(score, dim=-1, dtype=torch.float32)\n    topk_weight, topk_ids = torch.topk(score, topk)\n    all_topk_weights.append(topk_weight)\n    all_topk_ids.append(topk_ids.to(torch.int32))\n\nprint(\"Creating \" + str(layer_num) + \" MoE layers...\")\nfor layer_idx in range(layer_num):\n    # Generate INT4 quantized weights for each expert\n    # w1: gate and up projection (K -> 2*N)\n    awq_w13_weight = torch.randint(-127, 128, (E, K, 2 * N // 8), device=device).to(torch.int)\n    awq_w13_zero = torch.randint(0, 10, (E, K // group_size, 2 * N // 8), device=device).to(torch.int)\n    awq_w13_scales = torch.rand(E, K // group_size, 2 * N, dtype=dtype, device=device)\n    \n    # w2: down projection (N -> K)  \n    awq_w2_weight = torch.randint(-127, 128, (E, N, K // 8), device=device).to(torch.int)\n    awq_w2_zero = torch.randint(0, 10, (E, N // group_size, K // 8), device=device).to(torch.int)\n    awq_w2_scales = torch.rand(E, N // group_size, K, dtype=dtype, device=device)\n    \n    # Pack weights for optimized kernel\n    awq_w13_weight_pack = []\n    awq_w13_zero_pack = []\n    awq_w13_scales_pack = []\n    awq_w2_weight_pack = []\n    awq_w2_zero_pack = []\n    awq_w2_scales_pack = []\n    \n    for i in range(E):\n        packed_weight_13, packed_zero_13, packed_scales_13 = autoawq_to_int4pack(\n            awq_w13_weight[i], awq_w13_zero[i], awq_w13_scales[i], False\n        )\n        awq_w13_weight_pack.append(packed_weight_13)\n        awq_w13_zero_pack.append(packed_zero_13)\n        awq_w13_scales_pack.append(packed_scales_13)\n        \n        packed_weight_2, packed_zero_2, packed_scales_2 = 
autoawq_to_int4pack(\n            awq_w2_weight[i], awq_w2_zero[i], awq_w2_scales[i], False\n        )\n        awq_w2_weight_pack.append(packed_weight_2)\n        awq_w2_zero_pack.append(packed_zero_2)\n        awq_w2_scales_pack.append(packed_scales_2)\n    \n    all_awq_w13_weight_pack.append(torch.stack(awq_w13_weight_pack).detach())\n    all_awq_w13_zero_pack.append(torch.stack(awq_w13_zero_pack).detach())\n    all_awq_w13_scales_pack.append(torch.stack(awq_w13_scales_pack).detach())\n    all_awq_w2_weight_pack.append(torch.stack(awq_w2_weight_pack).detach())\n    all_awq_w2_zero_pack.append(torch.stack(awq_w2_zero_pack).detach())\n    all_awq_w2_scales_pack.append(torch.stack(awq_w2_scales_pack).detach())\n\n# Warmup\nprint(\"Running \" + str({adjusted_warmup}) + \" warmup iterations...\")\nfor i in range({adjusted_warmup}):\n    layer_idx = i % layer_num\n    gen_idx = i % gen_iter\n    out = kernel.fused_experts_cpu(\n        input_tensors[layer_idx],\n        all_awq_w13_weight_pack[layer_idx],\n        all_awq_w2_weight_pack[layer_idx],\n        all_topk_weights[gen_idx],\n        all_topk_ids[gen_idx],\n        False,  # inplace\n        False,  # use_int8_w8a8\n        False,  # use_fp8_w8a16\n        True,   # use_int4_w4a16\n        all_awq_w13_scales_pack[layer_idx],\n        all_awq_w2_scales_pack[layer_idx],\n        all_awq_w13_zero_pack[layer_idx],\n        all_awq_w2_zero_pack[layer_idx],\n        None,   # block_size\n        None,   # a1_scale\n        None,   # a2_scale\n        True,   # is_vnni\n    )\n\n# Benchmark\nprint(\"Running \" + str({adjusted_iterations}) + \" benchmark iterations...\")\nstart = time.perf_counter()\nfor i in range({adjusted_iterations}):\n    layer_idx = i % layer_num\n    gen_idx = i % gen_iter\n    out = kernel.fused_experts_cpu(\n        input_tensors[layer_idx],\n        all_awq_w13_weight_pack[layer_idx],\n        all_awq_w2_weight_pack[layer_idx],\n        all_topk_weights[gen_idx],\n        
all_topk_ids[gen_idx],\n        False,\n        False,\n        False,\n        True,\n        all_awq_w13_scales_pack[layer_idx],\n        all_awq_w2_scales_pack[layer_idx],\n        all_awq_w13_zero_pack[layer_idx],\n        all_awq_w2_zero_pack[layer_idx],\n        None,\n        None,\n        None,\n        True,\n    )\nend = time.perf_counter()\n\ntotal_time = end - start\ntime_per_iter_us = total_time / {adjusted_iterations} * 1e6\n\n# Calculate performance metrics for int4\nbytes_per_elem = 0.5  # int4\nmemory_per_iter = (\n    {test_config.hidden_size} * {sgl_intermediate_size} * 3 * {test_config.num_experts_per_tok} * \n    (1/8 * {test_config.expert_num} * (1-(31/32)**{qlen})) * bytes_per_elem\n)\nbandwidth_gbs = memory_per_iter * {adjusted_iterations} / total_time / 1e9\n\n# FLOPS calculation \nflops_per_iter = {test_config.hidden_size} * {sgl_intermediate_size} * {qlen} * 3 * {test_config.num_experts_per_tok} * 2\ntflops = flops_per_iter * {adjusted_iterations} / total_time / 1e12\n\nprint(f\"SGL_RESULT:{{total_time}},{{time_per_iter_us}},{{bandwidth_gbs}},{{tflops}}\")\n'''\n    \n    # Create temporary script in sglang-jianan directory\n    sglang_path = \"/home/xwy/Projects/sglang-jianan\"\n    temp_script = f\"{sglang_path}/temp_sgl_int4_bench_{os.getpid()}_{qlen}.py\"\n    \n    try:\n        with open(temp_script, 'w') as f:\n            f.write(script_content)\n        \n        # Setup environment\n        env = os.environ.copy()\n        env['MALLOC_CONF'] = env_config.malloc_conf\n        if os.path.exists(env_config.jemalloc_path):\n            env['LD_PRELOAD'] = env_config.jemalloc_path\n        env['OMP_NUM_THREADS'] = str(thread_config.sgl_thread_count)\n        \n        # Run with numactl from the sglang-jianan directory\n        cmd = f\"cd {sglang_path} && {thread_config.numa_prefix} python3 {temp_script}\"\n        logger.debug(f\"Running SGL INT4 command: {cmd}\")\n        \n        result = subprocess.run(cmd, shell=True, 
capture_output=True, text=True, env=env, timeout=300)\n        \n        if result.returncode == 0:\n            # Parse result\n            for line in result.stdout.split('\\n'):\n                if line.startswith('SGL_RESULT:'):\n                    parts = line.replace('SGL_RESULT:', '').split(',')\n                    if len(parts) >= 4:\n                        try:\n                            total_time = float(parts[0])\n                            time_per_iter_us = float(parts[1])\n                            bandwidth_gbs = float(parts[2])\n                            tflops = float(parts[3])\n                            \n                            logger.info(f\"SGL INT4 Results - Time: {total_time:.4f}s, Per-iter: {time_per_iter_us:.2f}μs, \"\n                                       f\"BW: {bandwidth_gbs:.2f} GB/s, TFLOPS: {tflops:.2f}\")\n                            \n                            return BenchmarkResult(\n                                implementation=\"SGL\",\n                                quant_mode=\"int4\",\n                                qlen=qlen,\n                                thread_count=thread_config.thread_count,\n                                total_time=total_time,\n                                time_per_iter_us=time_per_iter_us,\n                                bandwidth_gbs=bandwidth_gbs,\n                                tflops=tflops,\n                                iterations=adjusted_iterations\n                            )\n                        except ValueError as e:\n                            logger.error(f\"Failed to parse SGL INT4 results: {e}\")\n        else:\n            logger.error(f\"SGL INT4 subprocess failed with code {result.returncode}\")\n            logger.error(f\"STDOUT: {result.stdout}\")\n            logger.error(f\"STDERR: {result.stderr}\")\n            \n    except subprocess.TimeoutExpired:\n        logger.error(\"SGL INT4 benchmark timed out\")\n    except Exception as e:\n   
     logger.error(f\"SGL INT4 benchmark error: {e}\", exc_info=True)\n    finally:\n        # Clean up\n        if os.path.exists(temp_script):\n            try:\n                os.remove(temp_script)\n            except:\n                pass\n    \n    return None\n\ndef run_sgl_with_numactl(test_config: TestConfig, qlen: int, \n                        thread_config: ThreadConfig) -> Optional[BenchmarkResult]:\n    \"\"\"Run SGL benchmark with numactl in subprocess\"\"\"\n    if not SGL_AVAILABLE:\n        logger.error(\"SGL not available, skipping benchmark\")\n        return None\n    \n    # Calculate SGL intermediate size (divided by NUMA nodes)\n    sgl_intermediate_size = test_config.intermediate_size // sys_config.numa_count\n    \n    # Adjust iterations based on qlen to maintain reasonable runtime\n    adjusted_iterations = test_config.test_iter\n    adjusted_warmup = test_config.warm_up_iter\n    if qlen >= 1024:\n        adjusted_iterations = max(10, test_config.test_iter // 100)\n        adjusted_warmup = max(5, test_config.warm_up_iter // 20)\n    elif qlen >= 256:\n        adjusted_iterations = max(50, test_config.test_iter // 20)\n        adjusted_warmup = max(10, test_config.warm_up_iter // 10)\n    elif qlen >= 64:\n        adjusted_iterations = max(100, test_config.test_iter // 10)\n        adjusted_warmup = max(20, test_config.warm_up_iter // 5)\n    elif qlen >= 16:\n        adjusted_iterations = max(200, test_config.test_iter // 5)\n        adjusted_warmup = max(40, test_config.warm_up_iter // 2)\n    \n    logger.info(f\"Testing SGL INT8: qlen={qlen}, iterations={adjusted_iterations} (warmup={adjusted_warmup}), \"\n                f\"threads per NUMA: {thread_config.sgl_thread_count}\")\n    \n    script_content = f'''\nimport sys\nsys.path.insert(0, \"/home/xwy/Projects/sgl-cpu-tests\")\n\nimport os\nimport torch\nfrom sgl_kernel.common_ops import fused_experts_cpu as fused_experts\nfrom sgl_kernel.common_ops import 
convert_weight_packed\nimport time\n\ntorch.manual_seed(1111)\nM, N, K, E, topk = {qlen}, {sgl_intermediate_size}, {test_config.hidden_size}, {test_config.expert_num}, {test_config.num_experts_per_tok}\nlayer_num = {test_config.layer_num}\n\n# Generate expert routing scores (different for each iteration)\ngen_iter = 1000\nall_topk_weights = []\nall_topk_ids = []\n\nfor gen_idx in range(gen_iter):\n    score = torch.randn(M, E).to(dtype=torch.bfloat16)\n    score = torch.softmax(score, dim=-1, dtype=torch.float32)\n    topk_weight, topk_ids = torch.topk(score, topk)\n    all_topk_weights.append(topk_weight)\n    all_topk_ids.append(topk_ids.to(torch.int32))\n\nprepack = True\ninplace = True\nuse_int4_w4a16 = False\n\n# Create multiple layers\nprint(\"Creating \" + str(layer_num) + \" MoE layers...\")\ninputs = []\npacked_w1s_int8 = []\npacked_w2s_int8 = []\nw1_s_list = []\nw2_s_list = []\n\nfor layer_idx in range(layer_num):\n    input_tensor = torch.randn(M, K).to(dtype=torch.bfloat16)\n    \n    # int8 weights\n    w1_int8 = torch.randn(E, 2 * N, K).to(dtype=torch.int8)\n    w2_int8 = torch.randn(E, K, N).to(dtype=torch.int8)\n    packed_w1_int8 = convert_weight_packed(w1_int8)\n    packed_w2_int8 = convert_weight_packed(w2_int8)\n    w1_s = torch.rand(E, 2 * N)\n    w2_s = torch.rand(E, K)\n    \n    inputs.append(input_tensor)\n    packed_w1s_int8.append(packed_w1_int8)\n    packed_w2s_int8.append(packed_w2_int8)\n    w1_s_list.append(w1_s)\n    w2_s_list.append(w2_s)\n\n# Warmup\nprint(\"Running \" + str({adjusted_warmup}) + \" warmup iterations...\")\nfor i in range({adjusted_warmup}):\n    layer_idx = i % layer_num\n    gen_idx = i % gen_iter\n    fused_experts(inputs[layer_idx], packed_w1s_int8[layer_idx], packed_w2s_int8[layer_idx], \n                 all_topk_weights[gen_idx], all_topk_ids[gen_idx],\n                 inplace, True, False, use_int4_w4a16, w1_s_list[layer_idx], w2_s_list[layer_idx], \n                 None, None, None, None, None, 
prepack)\n\n# Benchmark\nprint(\"Running \" + str({adjusted_iterations}) + \" benchmark iterations...\")\nstart = time.perf_counter()\nfor i in range({adjusted_iterations}):\n    layer_idx = i % layer_num\n    gen_idx = i % gen_iter\n    fused_experts(inputs[layer_idx], packed_w1s_int8[layer_idx], packed_w2s_int8[layer_idx], \n                 all_topk_weights[gen_idx], all_topk_ids[gen_idx],\n                 inplace, True, False, use_int4_w4a16, w1_s_list[layer_idx], w2_s_list[layer_idx], \n                 None, None, None, None, None, prepack)\nend = time.perf_counter()\n\ntotal_time = end - start\ntime_per_iter_us = total_time / {adjusted_iterations} * 1e6\n\n# Calculate performance metrics for int8\nbytes_per_elem = 1.0  # int8\nmemory_per_iter = (\n    {test_config.hidden_size} * {sgl_intermediate_size} * 3 * {test_config.num_experts_per_tok} * \n    (1/8 * {test_config.expert_num} * (1-(31/32)**{qlen})) * bytes_per_elem\n)\nbandwidth_gbs = memory_per_iter * {adjusted_iterations} / total_time / 1e9\n\n# FLOPS calculation \nflops_per_iter = {test_config.hidden_size} * {sgl_intermediate_size} * {qlen} * 3 * {test_config.num_experts_per_tok} * 2\ntflops = flops_per_iter * {adjusted_iterations} / total_time / 1e12\n\nprint(f\"SGL_RESULT:{{total_time}},{{time_per_iter_us}},{{bandwidth_gbs}},{{tflops}}\")\n'''\n    \n    # Create temporary script\n    temp_script = f\"/tmp/sgl_bench_{os.getpid()}_{qlen}.py\"\n    \n    try:\n        with open(temp_script, 'w') as f:\n            f.write(script_content)\n        \n        # Setup environment\n        env = os.environ.copy()\n        env['MALLOC_CONF'] = env_config.malloc_conf\n        if os.path.exists(env_config.jemalloc_path):\n            env['LD_PRELOAD'] = env_config.jemalloc_path\n        env['OMP_NUM_THREADS'] = str(thread_config.sgl_thread_count)\n        \n        # Run with numactl\n        cmd = f\"{thread_config.numa_prefix} python3 {temp_script}\"\n        logger.debug(f\"Running SGL command: 
{cmd}\")\n        \n        result = subprocess.run(cmd, shell=True, capture_output=True, text=True, env=env, timeout=300)\n        \n        if result.returncode == 0:\n            # Parse result\n            for line in result.stdout.split('\\n'):\n                if line.startswith('SGL_RESULT:'):\n                    parts = line.replace('SGL_RESULT:', '').split(',')\n                    if len(parts) >= 4:\n                        try:\n                            total_time = float(parts[0])\n                            time_per_iter_us = float(parts[1])\n                            bandwidth_gbs = float(parts[2])\n                            tflops = float(parts[3])\n                            \n                            logger.info(f\"SGL Results - Time: {total_time:.4f}s, Per-iter: {time_per_iter_us:.2f}μs, \"\n                                       f\"BW: {bandwidth_gbs:.2f} GB/s, TFLOPS: {tflops:.2f}\")\n                            \n                            return BenchmarkResult(\n                                implementation=\"SGL\",\n                                quant_mode=\"int8\",\n                                qlen=qlen,\n                                thread_count=thread_config.thread_count,\n                                total_time=total_time,\n                                time_per_iter_us=time_per_iter_us,\n                                bandwidth_gbs=bandwidth_gbs,\n                                tflops=tflops,\n                                iterations=adjusted_iterations\n                            )\n                        except ValueError as e:\n                            logger.error(f\"Failed to parse SGL results: {e}\")\n        else:\n            logger.error(f\"SGL subprocess failed with code {result.returncode}: {result.stderr}\")\n            \n    except subprocess.TimeoutExpired:\n        logger.error(\"SGL benchmark timed out\")\n    except Exception as e:\n        logger.error(f\"SGL benchmark error: 
{e}\", exc_info=True)\n    finally:\n        # Clean up\n        if os.path.exists(temp_script):\n            try:\n                os.remove(temp_script)\n            except:\n                pass\n    \n    return None\n\ndef save_results(results: List[BenchmarkResult], test_config: TestConfig, filename: str = None) -> str:\n    \"\"\"Save benchmark results to JSON file\"\"\"\n    if not filename:\n        timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n        filename = f\"moe_comparison_{timestamp}.json\"\n    \n    output_data = {\n        \"timestamp\": datetime.now().isoformat(),\n        \"test_configuration\": asdict(test_config),\n        \"system_info\": get_system_info(),\n        \"results\": [r.to_dict() for r in results],\n        \"summary\": {\n            \"total_benchmarks\": len(results),\n            \"implementations_tested\": list(set(r.implementation for r in results)),\n            \"quantization_modes\": list(set(r.quant_mode for r in results)),\n            \"qlen_values_tested\": sorted(set(r.qlen for r in results)),\n            \"thread_counts_tested\": sorted(set(r.thread_count for r in results))\n        }\n    }\n    \n    with open(filename, 'w') as f:\n        json.dump(output_data, f, indent=2)\n    \n    logger.info(f\"Results saved to: {filename}\")\n    return filename\n\ndef print_summary_table(results: List[BenchmarkResult]):\n    \"\"\"Print formatted summary table of results\"\"\"\n    if not results:\n        return\n    \n    print(\"\\n\" + \"=\" * 100)\n    print(\"PERFORMANCE SUMMARY\")\n    print(\"=\" * 100)\n    print(f\"{'Implementation':<15} {'Quant':<6} {'Threads':<8} {'QLen':<8} {'Time(μs)':<12} {'BW(GB/s)':<12} {'TFLOPS':<10} {'Speedup':<10}\")\n    print(\"-\" * 100)\n    \n    # Group by configuration for better comparison\n    baseline_times = {}\n    \n    for result in sorted(results, key=lambda r: (r.thread_count, r.qlen, r.implementation, r.quant_mode)):\n        key = (result.thread_count, 
result.qlen)\n        \n        if key not in baseline_times:\n            baseline_times[key] = result.time_per_iter_us\n            speedup = \"1.00x\"\n        else:\n            speedup = f\"{baseline_times[key]/result.time_per_iter_us:.2f}x\"\n        \n        print(f\"{result.implementation:<15} {result.quant_mode:<6} {result.thread_count:<8} \"\n              f\"{result.qlen:<8} {result.time_per_iter_us:<12.2f} {result.bandwidth_gbs:<12.2f} \"\n              f\"{result.tflops:<10.2f} {speedup:<10}\")\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Compare MoE performance between KTransformers and SGL\")\n    parser.add_argument(\"--qlen\", type=int, nargs=\"+\", help=\"Sequence lengths to test\")\n    parser.add_argument(\"--threads\", type=int, nargs=\"+\", help=\"Thread counts to test\")\n    parser.add_argument(\"--iterations\", type=int, help=\"Number of test iterations\")\n    parser.add_argument(\"--warmup\", type=int, help=\"Number of warmup iterations\")\n    parser.add_argument(\"--output\", type=str, help=\"Output filename for results\")\n    parser.add_argument(\"--verbose\", \"-v\", action=\"store_true\", help=\"Enable verbose logging\")\n    parser.add_argument(\"--resume\", action=\"store_true\", help=\"Resume from checkpoint if available\")\n    parser.add_argument(\"--checkpoint-dir\", type=str, help=\"Directory for checkpoint files\")\n    parser.add_argument(\"--no-checkpoint\", action=\"store_true\", help=\"Disable checkpoint saving\")\n    parser.add_argument(\"--framework\", choices=[\"all\", \"ktransformers\", \"sgl\"], default=\"all\",\n                        help=\"Framework to test (default: all)\")\n    parser.add_argument(\"--precision\", choices=[\"all\", \"int8\", \"int4\"], default=\"all\",\n                        help=\"Precision to test (default: all)\")\n    \n    args = parser.parse_args()\n    \n    # Configure logging level\n    if args.verbose:\n        logging.getLogger().setLevel(logging.DEBUG)\n   
 \n    # Create test configuration\n    test_config = TestConfig()\n    if args.qlen:\n        test_config.qlen_values = args.qlen\n    if args.threads:\n        test_config.thread_count_values = args.threads\n    if args.iterations:\n        test_config.test_iter = args.iterations\n    if args.warmup:\n        test_config.warm_up_iter = args.warmup\n    \n    # Determine which frameworks and precisions to test\n    test_ktransformers = args.framework in [\"all\", \"ktransformers\"] and KTRANSFORMERS_AVAILABLE\n    test_sgl = args.framework in [\"all\", \"sgl\"] and (SGL_AVAILABLE or SGL_INT4_AVAILABLE)\n    \n    # Determine which precisions to test\n    test_precisions = []\n    if args.precision == \"all\":\n        test_precisions = [\"int8\", \"int4\"]\n    else:\n        test_precisions = [args.precision]\n    \n    # Print configuration\n    logger.info(\"MoE Performance Comparison\")\n    logger.info(\"=\" * 60)\n    logger.info(f\"System configuration:\")\n    logger.info(f\"  CPU cores: {sys_config.cpu_cores}\")\n    logger.info(f\"  NUMA nodes: {sys_config.numa_count}\")\n    logger.info(f\"Test parameters:\")\n    logger.info(f\"  Expert count: {test_config.expert_num}\")\n    logger.info(f\"  Hidden size: {test_config.hidden_size}\")\n    logger.info(f\"  Intermediate size: {test_config.intermediate_size}\")\n    logger.info(f\"  Experts per token: {test_config.num_experts_per_tok}\")\n    logger.info(f\"  Test iterations: {test_config.test_iter}\")\n    logger.info(f\"  Warmup iterations: {test_config.warm_up_iter}\")\n    logger.info(f\"Testing configurations:\")\n    logger.info(f\"  QLEN values: {test_config.qlen_values}\")\n    logger.info(f\"  Thread counts: {test_config.thread_count_values}\")\n    logger.info(f\"  Frameworks: {args.framework}\")\n    logger.info(f\"  Precisions: {args.precision}\")\n    logger.info(f\"  Total configs: {test_config.total_configurations}\")\n    print()\n    \n    # Check availability\n    if not 
KTRANSFORMERS_AVAILABLE and not SGL_AVAILABLE:\n        logger.error(\"Neither KTransformers nor SGL is available. Cannot run benchmarks.\")\n        return 1\n    \n    # Initialize checkpoint manager\n    checkpoint_mgr = CheckpointManager(args.checkpoint_dir) if not args.no_checkpoint else None\n    \n    # Load checkpoint if resuming\n    checkpoint_state = None\n    completed_configs = set()\n    all_results = []\n    start_time = datetime.now().isoformat()\n    \n    if args.resume and checkpoint_mgr:\n        checkpoint_state = checkpoint_mgr.load_checkpoint()\n        if checkpoint_state:\n            # Verify configuration matches\n            if (checkpoint_state.test_config.qlen_values != test_config.qlen_values or\n                checkpoint_state.test_config.thread_count_values != test_config.thread_count_values):\n                logger.warning(\"Checkpoint configuration doesn't match current configuration\")\n                response = input(\"Continue with checkpoint anyway? 
(y/n): \")\n                if response.lower() != 'y':\n                    logger.info(\"Starting fresh run\")\n                    checkpoint_state = None\n            \n            if checkpoint_state:\n                all_results = checkpoint_state.results\n                completed_configs = set(checkpoint_state.completed_configs)\n                start_time = checkpoint_state.start_time\n                logger.info(f\"Resuming from checkpoint with {len(all_results)} results\")\n    \n    # Create checkpoint state if not loaded\n    if not checkpoint_state and checkpoint_mgr:\n        checkpoint_state = CheckpointState(\n            test_config=test_config,\n            completed_configs=[],\n            results=[],\n            start_time=start_time,\n            last_update=start_time\n        )\n    \n    config_count = 0\n    total_configs_to_run = 0\n    \n    # Calculate total configs to run\n    for thread_count in test_config.thread_count_values:\n        for qlen in test_config.qlen_values:\n            if test_ktransformers:\n                for quant_mode in test_precisions:\n                    if (thread_count, qlen, \"KTransformers\", quant_mode) not in completed_configs:\n                        total_configs_to_run += 1\n            if test_sgl:\n                if \"int8\" in test_precisions and SGL_AVAILABLE:\n                    if (thread_count, qlen, \"SGL\", \"int8\") not in completed_configs:\n                        total_configs_to_run += 1\n                if \"int4\" in test_precisions and SGL_INT4_AVAILABLE:\n                    if (thread_count, qlen, \"SGL\", \"int4\") not in completed_configs:\n                        total_configs_to_run += 1\n    \n    logger.info(f\"Total configurations to run: {total_configs_to_run}\")\n    \n    # Test all combinations\n    for thread_count in test_config.thread_count_values:\n        thread_config = ThreadConfig.from_thread_count(thread_count, sys_config.numa_count, sys_config.cpu_cores)\n 
       logger.info(f\"\\nThread Configuration: {thread_count} total ({thread_config.threads_per_numa} per NUMA)\")\n        \n        for qlen in test_config.qlen_values:\n            # Check for interrupt\n            if checkpoint_mgr and checkpoint_mgr.interrupted:\n                logger.warning(\"Interrupt detected, saving checkpoint and exiting...\")\n                if checkpoint_state:\n                    checkpoint_state.results = all_results\n                    checkpoint_state.completed_configs = list(completed_configs)\n                    checkpoint_mgr.save_checkpoint(checkpoint_state)\n                return 2\n            \n            logger.info(f\"\\n--- Configuration: threads={thread_count}, qlen={qlen} ---\")\n            \n            # Test KTransformers\n            if test_ktransformers:\n                for quant_mode in test_precisions:\n                    config_key = (thread_count, qlen, \"KTransformers\", quant_mode)\n                    if config_key in completed_configs:\n                        logger.info(f\"Skipping already completed: KTransformers-{quant_mode}\")\n                        continue\n                    \n                    config_count += 1\n                    logger.info(f\"Progress: {config_count}/{total_configs_to_run}\")\n                    \n                    result = bench_ktransformers_moe(test_config, quant_mode, qlen, thread_config)\n                    if result:\n                        all_results.append(result)\n                        completed_configs.add(config_key)\n                        \n                        # Save checkpoint after each successful test\n                        if checkpoint_mgr and checkpoint_state:\n                            checkpoint_state.results = all_results\n                            checkpoint_state.completed_configs = list(completed_configs)\n                            checkpoint_mgr.save_checkpoint(checkpoint_state)\n            \n            # Test 
SGL int8\n            if test_sgl and \"int8\" in test_precisions and SGL_AVAILABLE:\n                config_key = (thread_count, qlen, \"SGL\", \"int8\")\n                if config_key in completed_configs:\n                    # No `continue` here: it would jump to the next qlen and skip the SGL int4 test below\n                    logger.info(\"Skipping already completed: SGL-int8\")\n                else:\n                    config_count += 1\n                    logger.info(f\"Progress: {config_count}/{total_configs_to_run}\")\n                    \n                    logger.info(f\"Testing SGL MoE (int8): qlen={qlen}, threads={thread_count}\")\n                    sgl_intermediate = test_config.intermediate_size // sys_config.numa_count\n                    sgl_threads_per_numa = thread_config.sgl_thread_count\n                    logger.info(f\"Using NUMA TP: intermediate_size {test_config.intermediate_size} -> \"\n                               f\"{sgl_intermediate} (/{sys_config.numa_count}), threads per NUMA: {sgl_threads_per_numa}\")\n                    \n                    result = run_sgl_with_numactl(test_config, qlen, thread_config)\n                    if result:\n                        all_results.append(result)\n                        completed_configs.add(config_key)\n                        \n                        # Save checkpoint after each successful test\n                        if checkpoint_mgr and checkpoint_state:\n                            checkpoint_state.results = all_results\n                            checkpoint_state.completed_configs = list(completed_configs)\n                            checkpoint_mgr.save_checkpoint(checkpoint_state)\n            \n            # Test SGL int4\n            if test_sgl and \"int4\" in test_precisions and SGL_INT4_AVAILABLE:\n                config_key = (thread_count, qlen, \"SGL\", \"int4\")\n                if config_key in completed_configs:\n                    logger.info(\"Skipping already completed: SGL-int4\")\n                    continue\n                \n                config_count += 1\n                logger.info(f\"Progress: 
{config_count}/{total_configs_to_run}\")\n                \n                logger.info(f\"Testing SGL MoE (int4): qlen={qlen}, threads={thread_count}\")\n                sgl_intermediate = test_config.intermediate_size // sys_config.numa_count\n                sgl_threads_per_numa = thread_config.sgl_thread_count\n                logger.info(f\"Using NUMA TP: intermediate_size {test_config.intermediate_size} -> \"\n                           f\"{sgl_intermediate} (/{sys_config.numa_count}), threads per NUMA: {sgl_threads_per_numa}\")\n                \n                result = run_sgl_int4_with_numactl(test_config, qlen, thread_config)\n                if result:\n                    all_results.append(result)\n                    completed_configs.add(config_key)\n                    \n                    # Save checkpoint after each successful test\n                    if checkpoint_mgr and checkpoint_state:\n                        checkpoint_state.results = all_results\n                        checkpoint_state.completed_configs = list(completed_configs)\n                        checkpoint_mgr.save_checkpoint(checkpoint_state)\n    \n    # Final summary\n    if all_results:\n        print_summary_table(all_results)\n        \n        # Save results\n        output_file = save_results(all_results, test_config, args.output)\n        \n        print(f\"\\nTotal benchmarks completed: {len(all_results)}\")\n        print(f\"Results saved to: {output_file}\")\n        \n        # Clear checkpoint on successful completion\n        if checkpoint_mgr and config_count == total_configs_to_run:\n            checkpoint_mgr.clear_checkpoint()\n            logger.info(\"All tests completed successfully, checkpoint cleared\")\n        elif checkpoint_mgr and config_count < total_configs_to_run:\n            logger.warning(f\"Only {config_count}/{total_configs_to_run} configurations completed\")\n            logger.info(\"Checkpoint preserved for resuming\")\n        \n        
# Print best performers per configuration\n        print(\"\\nBest performers by configuration:\")\n        from itertools import groupby\n        \n        sorted_results = sorted(all_results, key=lambda r: (r.qlen, r.thread_count, r.time_per_iter_us))\n        for key, group in groupby(sorted_results, key=lambda r: (r.qlen, r.thread_count)):\n            qlen, threads = key\n            best = next(group)\n            print(f\"  QLen={qlen}, Threads={threads}: {best.implementation}-{best.quant_mode} \"\n                  f\"({best.time_per_iter_us:.2f}μs, {best.tflops:.2f} TFLOPS)\")\n    else:\n        logger.error(\"No successful benchmarks completed.\")\n        return 1\n    \n    return 0\n\nif __name__ == \"__main__\":\n    sys.exit(main())"
  },
  {
    "path": "kt-kernel/bench/multi_bench_moe.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\n自动展开 list 参数的 benchmark 脚本。\n只要将所有测试参数放在 all_params 字典中，凡是值为 list 的键都会被自动展开，\n生成参数组合后依次调用 bench_moe/bench_moe_amx 运行测试。\n\"\"\"\n\nimport os\nimport sys\nimport itertools\nfrom collections.abc import Sequence\n\n# 将当前目录加入搜索路径\nsys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))\n\n#####################################################################\n# 1. 在此处一次性写好所有测试参数\n#####################################################################\nall_params = {\n    # 固定参数\n    \"test_operator_type\": \"llamafile\",   # \"llamafile\" 或 \"amx\" \"kml\"\n    \"expert_num\": 256,\n    \"num_experts_per_tok\": 8,\n    \"hidden_size\": 7168,\n    \"intermediate_size\": 2048,\n    \"max_len\": 25600,         # amx 专用，llamafile 可保留不使用\n    \"group_max_len\": 1024,    # llamafile 专用\n    \"group_min_len\": 10,      # llamafile 专用\n    \"m_block\": [256],             # llamafile 专用\n     \"qlen\": range(1,11,1),\n    \"layer_num\": 3,\n    \"warm_up_iter\": 100,\n    \"test_iter\": 10000,\n\n    # ↓↓↓ 下面这些值是 list，会被自动展开 ↓↓↓\n    \"CPUINFER_PARAM\": [304],\n    # \"CPUINFER_PARAM\": [144], # Kunpeng 920 7280Z\n    \"quant_mode\": \"q4_k_m\", # llamafile\n    # \"quant_mode\": [\"int4\", \"int8\"], # amx\n    # \"quant_mode\": \"int8\", # amx\n}\n#####################################################################\n\n\ndef expand_param_dict(param_dict):\n    \"\"\"对值为 list 的键做笛卡儿积展开\"\"\"\n    vary_keys, vary_values, fixed_items = [], [], {}\n    for k, v in param_dict.items():\n        if isinstance(v, Sequence) and not isinstance(v, (str, bytes)):\n            vary_keys.append(k)\n            vary_values.append(v)\n        else:\n            fixed_items[k] = v\n\n    if not vary_keys:\n        yield param_dict\n        return\n\n    for combo in itertools.product(*vary_values):\n        params = fixed_items.copy()\n        params.update(dict(zip(vary_keys, combo)))\n        yield params\n\n\n# 根据 operator 
类型动态导入 bench 模块\nif all_params[\"test_operator_type\"] == \"llamafile\":\n    import bench_moe as bench\nelif all_params[\"test_operator_type\"] == \"amx\":\n    import bench_moe_amx as bench\nelif all_params[\"test_operator_type\"] == \"kml\":\n    import bench_moe_kml as bench\nelse:\n    raise ValueError(f\"Unknown test_operator_type: {all_params['test_operator_type']}\")\n\n\ndef update_bench_parameters(params):\n    \"\"\"同步参数到 bench 模块并重新初始化 CPUInfer\"\"\"\n    bench.expert_num = params[\"expert_num\"]\n    bench.hidden_size = params[\"hidden_size\"]\n    bench.intermediate_size = params[\"intermediate_size\"]\n    bench.max_len = params[\"max_len\"]\n    bench.group_max_len = params[\"group_max_len\"]\n    bench.group_min_len = params[\"group_min_len\"]\n    bench.m_block = params[\"m_block\"]\n    bench.num_experts_per_tok = params[\"num_experts_per_tok\"]\n    bench.layer_num = params[\"layer_num\"]\n    bench.qlen = params[\"qlen\"]\n    bench.warm_up_iter = params[\"warm_up_iter\"]\n    bench.test_iter = params[\"test_iter\"]\n    bench.CPUINFER_PARAM = params[\"CPUINFER_PARAM\"]\n    # 重新初始化 CPUInfer 对象\n    bench.CPUInfer = bench.kt_kernel_ext.CPUInfer(bench.CPUINFER_PARAM)\n\n\ndef main():\n    for params in expand_param_dict(all_params):\n        print(\"=\" * 60)\n        print(\"开始测试参数集:\", params)\n        update_bench_parameters(params)\n        bench.bench_moe(params[\"quant_mode\"])\n        print(\"完成测试，量化模式:\", params[\"quant_mode\"])\n        print(\"=\" * 60, \"\\n\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "kt-kernel/bench/upload-bench-json.py",
    "content": "from pymongo import MongoClient, errors\nimport json\nimport os\n\nscript_path = os.path.abspath(__file__)\nscript_dir = os.path.dirname(script_path)\n\n# === 加载 secrets.json 文件 ===\nwith open(os.path.join(script_dir,\"mongo.json\")) as f:\n    secrets = json.load(f)\n\nMONGO_URI = secrets[\"mongo_uri\"]\nDB_NAME = secrets[\"db_name\"]\nCOLLECTION_NAME = secrets[\"collection_name\"]\n\n# === 连接 MongoDB ===\nclient = MongoClient(MONGO_URI)\ndb = client[DB_NAME]\ncollection = db[COLLECTION_NAME]\n\n# 创建唯一索引（只需执行一次）\ncollection.create_index(\n    [(\"timestamp\", 1), (\"test_parameters.CPUInfer_parameter\", 1)],\n    unique=True\n)\n\n# === 插入函数 ===\ndef insert_jsonl_file(file_path):\n    total_inserted = 0\n    total_skipped = 0\n\n    with open(file_path, \"r\") as f:\n        docs = [json.loads(line) for line in f if line.strip()]\n        try:\n            result = collection.insert_many(docs, ordered=False)\n            inserted = len(result.inserted_ids)\n            total_inserted += inserted\n            print(f\"[✓] {file_path} 插入 {inserted} 条记录\")\n        except errors.BulkWriteError as e:\n            inserted = len(e.details.get(\"writeErrors\", []))\n            skipped = len(docs) - inserted\n            total_inserted += inserted\n            total_skipped += skipped\n            print(f\"[!] {file_path} 插入 {inserted} 条，跳过重复后 {skipped} 条\")\n    \n    return total_inserted, total_skipped\n\n\n\ninsert_jsonl_file( os.path.join(script_dir, \"bench_results.jsonl\"))\n"
  },
  {
    "path": "kt-kernel/cmake/DetectCPU.cmake",
    "content": "# CPU Feature Detection for kt-kernel\n# Detects CPU capabilities and sets appropriate compiler flags\n\nfunction(detect_cpu_features)\n    set(HAS_AVX2 OFF PARENT_SCOPE)\n    set(HAS_AVX512F OFF PARENT_SCOPE)\n    set(HAS_AVX512_VNNI OFF PARENT_SCOPE)\n    set(HAS_AVX512_BF16 OFF PARENT_SCOPE)\n    set(HAS_AVX512_VBMI OFF PARENT_SCOPE)\n    set(HAS_AMX OFF PARENT_SCOPE)\n\n    if(NOT EXISTS \"/proc/cpuinfo\")\n        message(STATUS \"CPU detection: /proc/cpuinfo not found, skipping auto-detection\")\n        return()\n    endif()\n\n    # Read CPU flags from /proc/cpuinfo\n    file(READ \"/proc/cpuinfo\" CPUINFO_CONTENT)\n    string(REGEX MATCH \"flags[ \\t]*:[ \\t]*([^\\n]*)\" FLAGS_LINE \"${CPUINFO_CONTENT}\")\n    if(NOT CMAKE_MATCH_1)\n        message(STATUS \"CPU detection: Could not parse CPU flags\")\n        return()\n    endif()\n\n    set(CPU_FLAGS \"${CMAKE_MATCH_1}\")\n    string(REPLACE \" \" \";\" CPU_FLAGS_LIST \"${CPU_FLAGS}\")\n\n    # Check for each feature\n    if(\"avx2\" IN_LIST CPU_FLAGS_LIST)\n        set(HAS_AVX2 ON PARENT_SCOPE)\n    endif()\n\n    if(\"avx512f\" IN_LIST CPU_FLAGS_LIST)\n        set(HAS_AVX512F ON PARENT_SCOPE)\n    endif()\n\n    if(\"avx512_vnni\" IN_LIST CPU_FLAGS_LIST OR \"avx512vnni\" IN_LIST CPU_FLAGS_LIST)\n        set(HAS_AVX512_VNNI ON PARENT_SCOPE)\n    endif()\n\n    if(\"avx512_bf16\" IN_LIST CPU_FLAGS_LIST OR \"avx512bf16\" IN_LIST CPU_FLAGS_LIST)\n        set(HAS_AVX512_BF16 ON PARENT_SCOPE)\n    endif()\n\n    if(\"avx512_vbmi\" IN_LIST CPU_FLAGS_LIST OR \"avx512vbmi\" IN_LIST CPU_FLAGS_LIST)\n        set(HAS_AVX512_VBMI ON PARENT_SCOPE)\n    endif()\n\n    # Check for AMX (need all three)\n    set(AMX_COUNT 0)\n    foreach(flag \"amx_tile\" \"amx_int8\" \"amx_bf16\")\n        if(\"${flag}\" IN_LIST CPU_FLAGS_LIST)\n            math(EXPR AMX_COUNT \"${AMX_COUNT} + 1\")\n        endif()\n    endforeach()\n    if(AMX_COUNT EQUAL 3)\n        set(HAS_AMX ON PARENT_SCOPE)\n    endif()\n\n    # 
Get CPU model name for display\n    string(REGEX MATCH \"model name[ \\t]*:[ \\t]*([^\\n]*)\" MODEL_LINE \"${CPUINFO_CONTENT}\")\n    if(CMAKE_MATCH_1)\n        set(CPU_MODEL \"${CMAKE_MATCH_1}\" PARENT_SCOPE)\n    endif()\nendfunction()\n\n# Main detection and configuration\nmessage(STATUS \"\")\nmessage(STATUS \"========================================\")\nmessage(STATUS \"CPU Feature Detection (CMake)\")\nmessage(STATUS \"========================================\")\n\n# Check if variables were already set by install.sh/setup.py\nset(FROM_INSTALL_SH OFF)\nif(DEFINED LLAMA_AVX512_VNNI OR DEFINED LLAMA_AVX512_BF16 OR DEFINED LLAMA_AVX512_VBMI)\n    set(FROM_INSTALL_SH ON)\n    message(STATUS \"Detected configuration from install.sh/setup.py\")\n    message(STATUS \"  LLAMA_AVX512:      ${LLAMA_AVX512}\")\n    message(STATUS \"  LLAMA_AVX512_VNNI: ${LLAMA_AVX512_VNNI}\")\n    message(STATUS \"  LLAMA_AVX512_BF16: ${LLAMA_AVX512_BF16}\")\n    message(STATUS \"  LLAMA_AVX512_VBMI: ${LLAMA_AVX512_VBMI}\")\n    message(STATUS \"\")\n    message(STATUS \"Skipping auto-detection (using install.sh settings)\")\n    message(STATUS \"========================================\")\n    message(STATUS \"\")\n    return()\nendif()\n\n# Detect CPU features (only if not set by install.sh)\ndetect_cpu_features()\n\nif(CPU_MODEL)\n    message(STATUS \"CPU Model: ${CPU_MODEL}\")\nendif()\n\nmessage(STATUS \"\")\nmessage(STATUS \"Detected features:\")\nmessage(STATUS \"  AVX2:         ${HAS_AVX2}\")\nmessage(STATUS \"  AVX512F:      ${HAS_AVX512F}\")\nmessage(STATUS \"  AVX512_VNNI:  ${HAS_AVX512_VNNI}\")\nmessage(STATUS \"  AVX512_BF16:  ${HAS_AVX512_BF16}\")\nmessage(STATUS \"  AVX512_VBMI:  ${HAS_AVX512_VBMI}\")\nmessage(STATUS \"  AMX:          ${HAS_AMX}\")\nmessage(STATUS \"\")\n\n# Auto-enable features based on detection\n# Only set if not already defined by user via -D flags\nif(NOT DEFINED LLAMA_AVX2 AND HAS_AVX2)\n    set(LLAMA_AVX2 ON CACHE BOOL \"Enable AVX2\" FORCE)\n    
message(STATUS \"Auto-enabled: AVX2\")\nendif()\n\nif(NOT DEFINED LLAMA_AVX512 AND HAS_AVX512F)\n    set(LLAMA_AVX512 ON CACHE BOOL \"Enable AVX512F\" FORCE)\n    message(STATUS \"Auto-enabled: AVX512F\")\nendif()\n\nif(NOT DEFINED LLAMA_AVX512_VNNI AND HAS_AVX512_VNNI)\n    set(LLAMA_AVX512_VNNI ON CACHE BOOL \"Enable AVX512_VNNI\" FORCE)\n    message(STATUS \"Auto-enabled: AVX512_VNNI\")\nendif()\n\nif(NOT DEFINED LLAMA_AVX512_BF16 AND HAS_AVX512_BF16)\n    set(LLAMA_AVX512_BF16 ON CACHE BOOL \"Enable AVX512_BF16\" FORCE)\n    message(STATUS \"Auto-enabled: AVX512_BF16\")\nendif()\n\nif(NOT DEFINED LLAMA_AVX512_VBMI AND HAS_AVX512_VBMI)\n    set(LLAMA_AVX512_VBMI ON CACHE BOOL \"Enable AVX512_VBMI\" FORCE)\n    message(STATUS \"Auto-enabled: AVX512_VBMI\")\nendif()\n\nif(NOT DEFINED KTRANSFORMERS_CPU_USE_AMX AND HAS_AMX)\n    set(KTRANSFORMERS_CPU_USE_AMX ON CACHE BOOL \"Enable AMX\" FORCE)\n    message(STATUS \"Auto-enabled: AMX\")\nendif()\n\nmessage(STATUS \"\")\nmessage(STATUS \"Note: You can override by passing -DLLAMA_AVX512_BF16=OFF etc.\")\nmessage(STATUS \"Note: Or use install.sh with environment variables\")\nmessage(STATUS \"========================================\")\nmessage(STATUS \"\")\n"
  },
  {
    "path": "kt-kernel/cmake/FindSIMD.cmake",
    "content": "include(CheckCSourceRuns)\n\nset(AVX_CODE \"\n    #include <immintrin.h>\n    int main()\n    {\n        __m256 a;\n        a = _mm256_set1_ps(0);\n        return 0;\n    }\n\")\n\nset(AVX512_CODE \"\n    #include <immintrin.h>\n    int main()\n    {\n        __m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0,\n                                    0, 0, 0, 0, 0, 0, 0, 0,\n                                    0, 0, 0, 0, 0, 0, 0, 0,\n                                    0, 0, 0, 0, 0, 0, 0, 0,\n                                    0, 0, 0, 0, 0, 0, 0, 0,\n                                    0, 0, 0, 0, 0, 0, 0, 0,\n                                    0, 0, 0, 0, 0, 0, 0, 0,\n                                    0, 0, 0, 0, 0, 0, 0, 0);\n        __m512i b = a;\n        __mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ);\n        return 0;\n    }\n\")\n\nset(AVX2_CODE \"\n    #include <immintrin.h>\n    int main()\n    {\n        __m256i a = {0};\n        a = _mm256_abs_epi16(a);\n        __m256i x;\n        _mm256_extract_epi64(x, 0); // we rely on this in our AVX2 code\n        return 0;\n    }\n\")\n\nset(FMA_CODE \"\n    #include <immintrin.h>\n    int main()\n    {\n        __m256 acc = _mm256_setzero_ps();\n        const __m256 d = _mm256_setzero_ps();\n        const __m256 p = _mm256_setzero_ps();\n        acc = _mm256_fmadd_ps( d, p, acc );\n        return 0;\n    }\n\")\n\nmacro(check_sse type flags)\n    set(__FLAG_I 1)\n    set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})\n    foreach (__FLAG ${flags})\n        if (NOT ${type}_FOUND)\n            set(CMAKE_REQUIRED_FLAGS ${__FLAG})\n            check_c_source_runs(\"${${type}_CODE}\" HAS_${type}_${__FLAG_I})\n            if (HAS_${type}_${__FLAG_I})\n                set(${type}_FOUND TRUE CACHE BOOL \"${type} support\")\n                set(${type}_FLAGS \"${__FLAG}\" CACHE STRING \"${type} flags\")\n            endif()\n            math(EXPR __FLAG_I \"${__FLAG_I}+1\")\n   
     endif()\n    endforeach()\n    set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})\n\n    if (NOT ${type}_FOUND)\n        set(${type}_FOUND FALSE CACHE BOOL \"${type} support\")\n        set(${type}_FLAGS \"\" CACHE STRING \"${type} flags\")\n    endif()\n\n    mark_as_advanced(${type}_FOUND ${type}_FLAGS)\nendmacro()\n\n# flags are for MSVC only!\ncheck_sse(\"AVX\" \" ;/arch:AVX\")\nif (NOT ${AVX_FOUND})\n    set(LLAMA_AVX OFF)\nelse()\n    set(LLAMA_AVX ON)\nendif()\n\ncheck_sse(\"AVX2\" \" ;/arch:AVX2\")\ncheck_sse(\"FMA\" \" ;/arch:AVX2\")\nif ((NOT ${AVX2_FOUND}) OR (NOT ${FMA_FOUND}))\n    set(LLAMA_AVX2 OFF)\nelse()\n    set(LLAMA_AVX2 ON)\nendif()\n\ncheck_sse(\"AVX512\" \" ;/arch:AVX512\")\nif (NOT ${AVX512_FOUND})\n    set(LLAMA_AVX512 OFF)\nelse()\n    set(LLAMA_AVX512 ON)\nendif()\n"
  },
  {
    "path": "kt-kernel/cpu_backend/cpuinfer.h",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-16 10:43:18\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2024-08-07 09:47:43\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#ifndef CPUINFER_CPUINFER_H\n#define CPUINFER_CPUINFER_H\n\n#include <atomic>\n#include <condition_variable>\n#include <functional>\n#include <mutex>\n#include <queue>\n#include <thread>\n#include <vector>\n#ifdef KTRANSFORMERS_USE_CUDA\n#include \"vendors/cuda.h\"\n#elif KTRANSFORMERS_USE_MUSA\n#include \"vendors/musa.h\"\n#elif KTRANSFORMERS_USE_ROCM\n#define __HIP_PLATFORM_AMD__\n#include \"vendors/hip.h\"\n#endif\n\n#include \"./vendors/vendor.h\"\n#include \"llama.cpp/ggml-impl.h\"\n#include \"task_queue.h\"\n#include \"worker_pool.h\"\n\nclass CPUInfer {\n public:\n  CPUInfer(int thread_num) {\n    printf(\"CPUInfer[0x%lx]: Hello\\n\", (intptr_t)this);\n    backend_ = new WorkerPool(thread_num);\n    task_queue_ = new TaskQueue();\n    for (int i = 0; i < (1 << 16); ++i) {\n      ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(i);\n    }\n  }\n  CPUInfer(int thread_num, int numa_id) {\n    printf(\"CPUInfer[0x%lx]: Hello\\n\", (intptr_t)this);\n    backend_ = new WorkerPool(thread_num, numa_id);\n    task_queue_ = new TaskQueue();\n    for (int i = 0; i < (1 << 16); ++i) {\n      ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(i);\n    }\n  }\n\n  CPUInfer(WorkerPoolConfig config) {\n    printf(\"CPUInfer[0x%lx]: Hello\\n\", (intptr_t)this);\n    backend_ = new WorkerPool(config);\n    task_queue_ = new TaskQueue();\n    for (int i = 0; i < (1 << 16); ++i) {\n      ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(i);\n    }\n  }\n\n  ~CPUInfer() {\n    printf(\"CPUInfer[0x%lx]: Goodbye\\n\", (intptr_t)this);\n    delete backend_;\n    delete task_queue_;\n  }\n\n  CPUInfer(const CPUInfer&) = delete;\n  CPUInfer& operator=(const CPUInfer&) = delete;\n  CPUInfer(CPUInfer&&) = delete;\n  
CPUInfer& operator=(CPUInfer&&) = delete;\n\n  template <typename Func, typename Obj, typename... Args>\n  void enqueue(Func f, Obj* obj, Args... args) {\n    task_queue_->enqueue([=]() { std::invoke(f, *obj, args...); });\n  }\n\n  void submit(std::pair<intptr_t, intptr_t> params) {\n    void (*func)(void*) = (void (*)(void*))params.first;\n    void* args = (void*)params.second;\n    *((CPUInfer**)args) = this;\n    func(args);\n  }\n#ifndef KTRANSFORMERS_CPU_ONLY\n  void submit_with_cuda_stream(intptr_t user_cuda_stream, std::pair<intptr_t, intptr_t> params) {\n#if defined(KTRANSFORMERS_USE_CUDA)\n    void (*func)(void*) = (void (*)(void*))params.first;\n    void* args = (void*)params.second;\n    *((CPUInfer**)args) = this;\n    cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args);\n#endif\n  }\n#endif\n\n  struct SyncArgs {\n    CPUInfer* cpuinfer;\n    size_t allow_n_pending;\n  };\n\n  static void sync_(void* sync_args) {\n    SyncArgs* args = (SyncArgs*)sync_args;\n    args->cpuinfer->task_queue_->sync(args->allow_n_pending);\n  }\n\n  void sync(size_t allow_n_pending = 0) {\n    // sync_ runs synchronously here, so stack storage suffices and nothing is leaked\n    SyncArgs args{this, allow_n_pending};\n    sync_(&args);\n  }\n#ifndef KTRANSFORMERS_CPU_ONLY\n  void sync_with_cuda_stream(intptr_t user_cuda_stream, size_t allow_n_pending = 0) {\n#if defined(KTRANSFORMERS_USE_CUDA)\n    SyncArgs* args = new SyncArgs{this, allow_n_pending};\n    cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void*)args);\n#endif\n  }\n#endif\n public:\n  WorkerPool* backend_;\n  TaskQueue* task_queue_;\n};\n\n#endif"
  },
  {
    "path": "kt-kernel/cpu_backend/shared_mem_buffer.cpp",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-08-05 04:49:08\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2024-08-05 09:21:29\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#include \"shared_mem_buffer.h\"\n\n#include <errno.h>\n#include <numa.h>\n\n#include <cstdio>\n\nsize_t MemoryRequest::total_size() {\n  size_t total = 0;\n  for (size_t i = 0; i < sizes.size(); ++i) {\n    total += sizes[i];\n  }\n  return total;\n}\n\nvoid MemoryRequest::update_base_ptr(void* base) {\n  size_t total_offset = 0;\n  for (size_t i = 0; i < funcs.size(); ++i) {\n    funcs[i]((uint8_t*)base + total_offset);\n    total_offset += sizes[i];\n  }\n}\n\nvoid MemoryRequest::append_function(std::function<void(void*)> func, size_t size) {\n  funcs.push_back(func);\n  sizes.push_back(size);\n}\n\nSharedMemBuffer::SharedMemBuffer() {\n  buffer = nullptr;\n  size = 0;\n}\n\nSharedMemBuffer::~SharedMemBuffer() {\n  if (buffer) {\n    free(buffer);\n  }\n}\n\nvoid SharedMemBuffer::alloc(void* object, MemoryRequest requests) {\n  size_t total_size = requests.total_size();\n  object_requests.push_back(requests);\n\n  if (total_size > size) {\n    if (buffer) {\n      free(buffer);\n    }\n    void* newbuf = nullptr;\n    int rc = posix_memalign(&newbuf, 64, total_size);\n    if (rc != 0 || !newbuf) {\n      errno = rc;  // posix_memalign returns error code instead of setting errno\n      printf(\"cannot aligned alloc %zu bytes (align=%d)\\n\", (size_t)total_size, 64);\n      perror(\"posix_memalign\");  // ENOMEM/EINVAL\n      exit(1);\n    }\n    buffer = newbuf;\n    size = total_size;\n    for (auto& req : object_requests) {\n      req.update_base_ptr(buffer);\n    }\n  } else {\n    requests.update_base_ptr(buffer);\n  }\n}\n\nvoid SharedMemBufferNuma::alloc(int numa, void* object, MemoryRequest requests) {\n  std::lock_guard<std::mutex> guard(lock);\n  if (numa != 
numa_node_of_cpu(sched_getcpu())) {\n    printf(\"alloc %d from other numa for %lx\\n\", numa, reinterpret_cast<intptr_t>(object));\n  }\n  if (numa_mem.count(numa) == 0) {\n    numa_mem[numa] = std::unique_ptr<SharedMemBuffer>(new SharedMemBuffer());\n  }\n  // printf(\"numa %d alloc for %lx\\n\", numa,reinterpret_cast<intptr_t> (object));\n  numa_mem.at(numa)->alloc(object, requests);\n}\n"
  },
  {
    "path": "kt-kernel/cpu_backend/shared_mem_buffer.h",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-08-05 04:49:08\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2024-08-05 06:36:41\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n\n#ifndef CPUINFER_SHAREDMEMBUFFER_H\n#define CPUINFER_SHAREDMEMBUFFER_H\n\n#include <cstdint>\n#include <cstdlib>\n#include <functional>\n#include <map>\n#include <memory>\n#include <mutex>\n#include <variant>\n#include <vector>\n\nstruct MemoryRequest {\n  std::vector<std::function<void(void*)>> funcs;\n  std::vector<size_t> sizes;\n\n  size_t total_size();\n  void update_base_ptr(void* base);\n\n  template <typename T>\n  void append_pointer(T** ptr, size_t size) {\n    append_function([ptr](void* base) { *ptr = reinterpret_cast<T*>(base); }, size);\n  }\n  void append_function(std::function<void(void*)> func, size_t size);\n};\n\nclass SharedMemBuffer {\n public:\n  SharedMemBuffer();\n  ~SharedMemBuffer();\n\n  void alloc(void* object, MemoryRequest requests);\n\n private:\n  void* buffer;\n  uint64_t size;\n  std::vector<MemoryRequest> object_requests;\n};\n\nstatic SharedMemBuffer shared_mem_buffer;\nstatic SharedMemBuffer shared_mem_buffer_for_decoder_layer;\n\nclass SharedMemBufferNuma {\n  std::mutex lock;\n  std::map<size_t, std::unique_ptr<SharedMemBuffer>> numa_mem;\n\n public:\n  void alloc(int numa, void* object, MemoryRequest requests);\n};\n\nstatic SharedMemBufferNuma shared_mem_buffer_numa;\n\n#endif"
  },
  {
    "path": "kt-kernel/cpu_backend/task_queue.cpp",
    "content": "/**\n * @Description :\n * @Author    : chenht2022\n * @Date     : 2024-07-17 12:25:51\n * @Version   : 1.0.0\n * @LastEditors : chenht2022\n * @LastEditTime : 2024-10-09 11:08:10\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#include \"task_queue.h\"\n\n#include <pthread.h>\n#include <sched.h>\n\n#include <chrono>\n#include <iostream>\n#include <thread>\n\nTaskQueue::TaskQueue() : done(false), pending(0) {\n  Node* dummy = new Node();\n  head.store(dummy, std::memory_order_relaxed);\n  tail.store(dummy, std::memory_order_relaxed);\n  workerThread = std::thread(&TaskQueue::worker, this);\n}\n\nTaskQueue::~TaskQueue() {\n  done.store(true, std::memory_order_release);\n  if (workerThread.joinable()) workerThread.join();\n\n  Node* node = head.load(std::memory_order_relaxed);\n  while (node) {\n    Node* next = node->next.load(std::memory_order_relaxed);\n    delete node;\n    node = next;\n  }\n}\n\nvoid TaskQueue::enqueue(std::function<void()> task) {\n  pending.fetch_add(1, std::memory_order_acq_rel);\n  Node* node = new Node(task);\n  Node* prev = tail.exchange(node, std::memory_order_acq_rel);\n  prev->next.store(node, std::memory_order_release);\n}\n\nvoid TaskQueue::sync(size_t allow_n_pending) {\n  // Spin until the pending task count drops to the allowed threshold.\n  while (pending.load(std::memory_order_acquire) > allow_n_pending);\n}\n\nvoid TaskQueue::worker() {\n  Node* curr = head.load(std::memory_order_relaxed);\n  while (!done.load(std::memory_order_acquire)) {\n    Node* next = curr->next.load(std::memory_order_acquire);\n    if (next) {\n      if (next->task) {\n        next->task();\n      }\n      delete curr;\n      curr = next;\n      head.store(curr, std::memory_order_release);\n      pending.fetch_sub(1, std::memory_order_acq_rel);\n    }\n  }\n}"
  },
  {
    "path": "kt-kernel/cpu_backend/task_queue.h",
    "content": "/**\n * @Description :\n * @Author    : chenht2022\n * @Date     : 2024-07-16 10:43:18\n * @Version   : 1.0.0\n * @LastEditors : chenht\n * @LastEditTime : 2024-10-09 11:08:07\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#ifndef CPUINFER_TASKQUEUE_H\n#define CPUINFER_TASKQUEUE_H\n\n#include <atomic>\n#include <condition_variable>\n#include <functional>\n#include <mutex>\n#include <queue>\n#include <thread>\n#include <vector>\n\nclass TaskQueue {\n public:\n  TaskQueue();\n  ~TaskQueue();\n\n  void enqueue(std::function<void()>);\n\n  void sync(size_t allow_n_pending);\n\n private:\n  struct Node {\n    std::function<void()> task;\n    std::atomic<Node*> next;\n    Node() : task(nullptr), next(nullptr) {}\n    Node(const std::function<void()>& t) : task(t), next(nullptr) {}\n  };\n\n  std::atomic<Node*> head;\n  std::atomic<Node*> tail;\n  std::atomic<bool> done;\n  std::atomic<size_t> pending;\n  std::thread workerThread;\n\n  void worker();\n};\n\n#endif"
  },
  {
    "path": "kt-kernel/cpu_backend/vendors/README.md",
    "content": "## TODO\n\nThis directory can be removed after updating the version of `llama.cpp`."
  },
  {
    "path": "kt-kernel/cpu_backend/vendors/cuda.h",
    "content": "#pragma once\n\n#include <cublas_v2.h>\n#include <cuda.h>\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n\n#if CUDART_VERSION < 11020\n#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED\n#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH\n#define CUBLAS_COMPUTE_16F CUDA_R_16F\n#define CUBLAS_COMPUTE_32F CUDA_R_32F\n#define cublasComputeType_t cudaDataType_t\n#endif  // CUDART_VERSION < 11020\n"
  },
  {
    "path": "kt-kernel/cpu_backend/vendors/hip.h",
    "content": "#pragma once\n\n#define HIP_ENABLE_WARP_SYNC_BUILTINS 1\n#include <hip/hip_bfloat16.h>\n#include <hip/hip_fp16.h>\n#include <hip/hip_runtime.h>\n#include <hipblas/hipblas.h>\n#ifdef __HIP_PLATFORM_AMD__\n// for rocblas_initialize()\n#include \"rocblas/rocblas.h\"\n#endif  // __HIP_PLATFORM_AMD__\n\n#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F\n#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F\n#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F\n#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT\n#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT\n#define CUBLAS_OP_N HIPBLAS_OP_N\n#define CUBLAS_OP_T HIPBLAS_OP_T\n#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS\n#define CUBLAS_TF32_TENSOR_OP_MATH 0\n#define CUDA_R_16F HIPBLAS_R_16F\n#define CUDA_R_32F HIPBLAS_R_32F\n#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported\n#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended\n#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned\n#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice\n#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite\n#define CU_CHECK(fn)                                              \\\n  {                                                               \\\n    hipError_t err = fn;                                          \\\n    if (err != hipSuccess) {                                      \\\n      GGML_ABORT(\"HipVMM Failure: %s\\n\", hipGetErrorString(err)); \\\n    }                                                             \\\n  }\n#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)\n#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)\n#define cublasComputeType_t hipblasDatatype_t  // deprecated, new hipblasComputeType_t not in 5.6\n#define cublasCreate hipblasCreate\n#define cublasDestroy hipblasDestroy\n#define cublasGemmEx 
hipblasGemmEx\n#define cublasGemmBatchedEx hipblasGemmBatchedEx\n#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx\n#define cublasHandle_t hipblasHandle_t\n#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS\n#define cublasSetStream hipblasSetStream\n#define cublasSgemm hipblasSgemm\n#define cublasStatus_t hipblasStatus_t\n#define cublasOperation_t hipblasOperation_t\n#define cudaDataType_t hipblasDatatype_t  // deprecated, new hipblasDatatype not in 5.6\n#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer\n#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess\n#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess\n#define cudaDeviceProp hipDeviceProp_t\n#define cudaDeviceSynchronize hipDeviceSynchronize\n#define cudaError_t hipError_t\n#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled\n#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled\n#define cudaEventCreateWithFlags hipEventCreateWithFlags\n#define cudaEventDisableTiming hipEventDisableTiming\n#define cudaEventRecord hipEventRecord\n#define cudaEventSynchronize hipEventSynchronize\n#define cudaEvent_t hipEvent_t\n#define cudaEventDestroy hipEventDestroy\n#define cudaFree hipFree\n#define cudaFreeHost hipHostFree\n#define cudaGetDevice hipGetDevice\n#define cudaGetDeviceCount hipGetDeviceCount\n#define cudaGetDeviceProperties hipGetDeviceProperties\n#define cudaGetErrorString hipGetErrorString\n#define cudaGetLastError hipGetLastError\n#define cudaHostRegister hipHostRegister\n#define cudaHostRegisterPortable hipHostRegisterPortable\n#define cudaHostRegisterReadOnly hipHostRegisterReadOnly\n#define cudaHostUnregister hipHostUnregister\n#define cudaLaunchHostFunc hipLaunchHostFunc\n#define cudaMalloc hipMalloc\n#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)\n#define cudaMemcpy hipMemcpy\n#define cudaMemcpyAsync hipMemcpyAsync\n#define cudaMemcpyPeerAsync hipMemcpyPeerAsync\n#define 
cudaMemcpy2DAsync hipMemcpy2DAsync\n#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice\n#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost\n#define cudaMemcpyHostToDevice hipMemcpyHostToDevice\n#define cudaMemcpyKind hipMemcpyKind\n#define cudaMemset hipMemset\n#define cudaMemsetAsync hipMemsetAsync\n#define cudaMemGetInfo hipMemGetInfo\n#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize\n#define cudaSetDevice hipSetDevice\n#define cuDeviceGet hipDeviceGet\n#define CUdevice hipDevice_t\n#define CUdeviceptr hipDeviceptr_t\n#define cuMemUnmap hipMemUnmap\n#define CUmemAccessDesc hipMemAccessDesc\n#define cuMemAddressFree hipMemAddressFree\n#define cuMemRelease hipMemRelease\n#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t\n#define cuMemCreate hipMemCreate\n#define cuMemAddressReserve hipMemAddressReserve\n#define cuMemMap hipMemMap\n#define cuMemSetAccess hipMemSetAccess\n#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity\n#define CUmemAllocationProp hipMemAllocationProp\n#define cuDeviceGetAttribute hipDeviceGetAttribute\n#define cudaStreamCreateWithFlags hipStreamCreateWithFlags\n#define cudaStreamDestroy hipStreamDestroy\n#define cudaStreamFireAndForget hipStreamFireAndForget\n#define cudaStreamNonBlocking hipStreamNonBlocking\n#define cudaStreamPerThread hipStreamPerThread\n#define cudaStreamSynchronize hipStreamSynchronize\n#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)\n#define cudaGraphExec_t hipGraphExec_t\n#define cudaGraphNode_t hipGraphNode_t\n#define cudaKernelNodeParams hipKernelNodeParams\n#define cudaKernelNodeParams hipKernelNodeParams\n#define cudaGraphExecDestroy hipGraphExecDestroy\n#define cudaGraphLaunch hipGraphLaunch\n#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure\n#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult\n#define cudaGraphNodeType hipGraphNodeType\n#define cudaGraphNodeTypeKernel 
hipGraphNodeTypeKernel\n#define cudaGraphInstantiate hipGraphInstantiate\n#define cudaStreamEndCapture hipStreamEndCapture\n#define cudaGraphDestroy hipGraphDestroy\n#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams\n#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction\n#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams\n#define cudaGraphNodeGetType hipGraphNodeGetType\n#define cudaGraphGetNodes hipGraphGetNodes\n#define cudaGraphExecUpdate hipGraphExecUpdate\n#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed\n#define cudaStreamBeginCapture hipStreamBeginCapture\n#define cudaGraph_t hipGraph_t\n#define cudaStream_t hipStream_t\n#define cudaSuccess hipSuccess\n#define cudaHostFn_t hipHostFn_t\n#define __trap()             \\\n  do {                       \\\n    abort();                 \\\n    __builtin_unreachable(); \\\n  } while (0)\n#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS\n#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED\n#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED\n#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE\n#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH\n#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR\n#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED\n#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR\n#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED\n\n#define __CUDA_ARCH__ 1300\n\n#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)\n#define GCN\n#endif\n\n#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)\n#define CDNA\n#endif\n\n#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \\\n    defined(__gfx1150__) || defined(__gfx1151__)\n#define RDNA3\n#endif\n\n#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) 
|| \\\n    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)\n#define RDNA2\n#endif\n\n#if defined(__gfx1010__) || defined(__gfx1012__)\n#define RDNA1\n#endif\n\n#ifndef __has_builtin\n#define __has_builtin(x) 0\n#endif\n\ntypedef hip_bfloat16 nv_bfloat16;\n"
  },
  {
    "path": "kt-kernel/cpu_backend/vendors/musa.h",
    "content": "#pragma once\n\n#include <mublas.h>\n#include <musa.h>\n#include <musa_bf16.h>\n#include <musa_fp16.h>\n#include <musa_runtime.h>\n#define CUBLAS_COMPUTE_16F CUDA_R_16F\n#define CUBLAS_COMPUTE_32F CUDA_R_32F\n#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F\n#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT\n#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT\n#define CUBLAS_OP_N MUBLAS_OP_N\n#define CUBLAS_OP_T MUBLAS_OP_T\n#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS\n#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT\n#define CUDA_R_16F MUSA_R_16F\n#define CUDA_R_32F MUSA_R_32F\n#define cublasComputeType_t cudaDataType_t\n#define cublasCreate mublasCreate\n#define cublasDestroy mublasDestroy\n#define cublasGemmEx mublasGemmEx\n#define cublasGemmBatchedEx mublasGemmBatchedEx\n#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx\n#define cublasHandle_t mublasHandle_t\n#define cublasSetMathMode mublasSetMathMode\n#define cublasSetStream mublasSetStream\n#define cublasSgemm mublasSgemm\n#define cublasStatus_t mublasStatus_t\n#define cublasOperation_t mublasOperation_t\n#define cublasGetStatusString mublasStatus_to_string\n#define cudaDataType_t musaDataType_t\n#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer\n#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess\n#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess\n#define cudaDeviceProp musaDeviceProp\n#define cudaDeviceSynchronize musaDeviceSynchronize\n#define cudaError_t musaError_t\n#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled\n#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled\n#define cudaEventCreateWithFlags musaEventCreateWithFlags\n#define cudaEventDisableTiming musaEventDisableTiming\n#define cudaEventRecord musaEventRecord\n#define cudaEventSynchronize musaEventSynchronize\n#define cudaEvent_t musaEvent_t\n#define cudaEventDestroy musaEventDestroy\n#define cudaFree 
musaFree\n#define cudaFreeHost musaFreeHost\n#define cudaGetDevice musaGetDevice\n#define cudaGetDeviceCount musaGetDeviceCount\n#define cudaGetDeviceProperties musaGetDeviceProperties\n#define cudaGetErrorString musaGetErrorString\n#define cudaGetLastError musaGetLastError\n#define cudaHostRegister musaHostRegister\n#define cudaHostRegisterPortable musaHostRegisterPortable\n#define cudaHostRegisterReadOnly musaHostRegisterReadOnly\n#define cudaHostUnregister musaHostUnregister\n#define cudaLaunchHostFunc musaLaunchHostFunc\n#define cudaMalloc musaMalloc\n#define cudaMallocHost musaMallocHost\n#define cudaMallocManaged musaMallocManaged\n#define cudaMemcpy musaMemcpy\n#define cudaMemcpyAsync musaMemcpyAsync\n#define cudaMemcpyPeerAsync musaMemcpyPeerAsync\n#define cudaMemcpy2DAsync musaMemcpy2DAsync\n#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice\n#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost\n#define cudaMemcpyHostToDevice musaMemcpyHostToDevice\n#define cudaMemcpyKind musaMemcpyKind\n#define cudaMemset musaMemset\n#define cudaMemsetAsync musaMemsetAsync\n#define cudaMemGetInfo musaMemGetInfo\n#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize\n#define cudaSetDevice musaSetDevice\n#define cudaStreamCreateWithFlags musaStreamCreateWithFlags\n#define cudaStreamDestroy musaStreamDestroy\n#define cudaStreamFireAndForget musaStreamFireAndForget\n#define cudaStreamNonBlocking musaStreamNonBlocking\n#define cudaStreamPerThread musaStreamPerThread\n#define cudaStreamSynchronize musaStreamSynchronize\n#define cudaStreamWaitEvent musaStreamWaitEvent\n#define cudaStream_t musaStream_t\n#define cudaHostFn_t musaHostFn_t\n#define nv_bfloat16 mt_bfloat16\n#define cudaSuccess musaSuccess\n\n// Additional mappings for MUSA virtual memory pool\n#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED\n#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE 
MU_MEM_ACCESS_FLAGS_PROT_READWRITE\n#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED\n#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED\n#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE\n#define CUdevice MUdevice\n#define CUdeviceptr MUdeviceptr\n#define CUmemAccessDesc MUmemAccessDesc\n#define CUmemAllocationProp MUmemAllocationProp\n#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle\n#define cuDeviceGet muDeviceGet\n#define cuDeviceGetAttribute muDeviceGetAttribute\n#define cuMemAddressFree muMemAddressFree\n#define cuMemAddressReserve muMemAddressReserve\n#define cuMemCreate muMemCreate\n#define cuMemGetAllocationGranularity muMemGetAllocationGranularity\n#define cuMemMap muMemMap\n#define cuMemRelease muMemRelease\n#define cuMemSetAccess muMemSetAccess\n#define cuMemUnmap muMemUnmap\n#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize\n#define cudaFuncSetAttribute musaFuncSetAttribute\n#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms\n#define make_cudaExtent make_musaExtent\n#define make_cudaPitchedPtr make_musaPitchedPtr\n\n// Additional mappings for MUSA graphs\n#define CUDA_SUCCESS MUSA_SUCCESS\n#define CUresult MUresult\n#define cuGetErrorString muGetErrorString\n#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure\n#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction\n#define cudaGraphDestroy musaGraphDestroy\n#define cudaGraphExecDestroy musaGraphExecDestroy\n#define cudaGraphExec_t musaGraphExec_t\n#define cudaGraphExecUpdate musaGraphExecUpdate\n#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult\n#define cudaGraphGetNodes musaGraphGetNodes\n#define cudaGraphInstantiate musaGraphInstantiate\n#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams\n#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams\n#define cudaGraphLaunch musaGraphLaunch\n#define 
cudaGraphNodeGetType musaGraphNodeGetType\n#define cudaGraphNode_t musaGraphNode_t\n#define cudaGraphNodeType musaGraphNodeType\n#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel\n#define cudaGraph_t musaGraph_t\n#define cudaKernelNodeParams musaKernelNodeParams\n#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed\n#define cudaStreamEndCapture musaStreamEndCapture\n\ntypedef mt_bfloat16 nv_bfloat16;\n"
  },
  {
    "path": "kt-kernel/cpu_backend/vendors/vendor.h",
    "content": "#ifndef CPUINFER_VENDOR_VENDOR_H\n#define CPUINFER_VENDOR_VENDOR_H\n\n#ifdef USE_CUDA\n#include \"cuda.h\"\n#elif USE_HIP\n#define __HIP_PLATFORM_AMD__\n#include \"hip.h\"\n#elif USE_MUSA\n#include \"musa.h\"\n#endif\n\n#endif  // CPUINFER_VENDOR_VENDOR_H"
  },
  {
    "path": "kt-kernel/cpu_backend/worker_pool.cpp",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-22 02:03:05\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2024-07-25 10:33:34\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n\n#include \"worker_pool.h\"\n\n#include <hwloc/bitmap.h>\n#include <numa.h>\n#include <numaif.h>\n\n#include <algorithm>\n#include <cassert>\n#include <chrono>\n#include <cstdio>\n#include <stdexcept>\n\n#include \"hwloc.h\"\n\nthread_local int WorkerPool::thread_local_id = -1;\n\nInNumaPool::InNumaPool(int max_thread_num) {\n  printf(\"In Numa Worker Pool at NUMA %d, %d threads\\n\", numa_node_of_cpu(sched_getcpu()), max_thread_num);\n  total_worker_count = max_thread_num;\n  set_restricted_worker_count(total_worker_count);\n  thread_state_ = std::unique_ptr<ThreadState[]>(new ThreadState[max_thread_num]);\n  for (int i = 0; i < total_worker_count; i++) {\n    thread_state_[i].status.store(ThreadStatus::WAITING, std::memory_order_release);\n  }\n  workers_.resize(total_worker_count);\n  for (int i = 1; i < total_worker_count; i++) {\n    workers_[i] = std::thread(&InNumaPool::worker_thread, this, i, -1);\n  }\n}\n\nInNumaPool::InNumaPool(int max_thread_num, int numa_id, int threads_id_start) {\n  printf(\"===========In NumaPool============\\n\");\n  hwloc_topology_t topology;\n  hwloc_obj_t numa_obj, core_obj;\n  hwloc_bitmap_t cpuset;\n  hwloc_topology_init(&topology);\n  hwloc_topology_load(topology);\n  printf(\"In Numa Worker Pool at NUMA %d, %d threads\\n\", numa_node_of_cpu(sched_getcpu()), max_thread_num);\n  total_worker_count = max_thread_num;\n  set_restricted_worker_count(total_worker_count);\n  thread_state_ = std::unique_ptr<ThreadState[]>(new ThreadState[max_thread_num]);\n  for (int i = 0; i < total_worker_count; i++) {\n    thread_state_[i].status.store(ThreadStatus::WAITING, std::memory_order_release);\n  }\n  workers_.resize(total_worker_count);\n  for (int i = 1; i < 
total_worker_count; i++) {\n    workers_[i] = std::thread(&InNumaPool::worker_thread, this, i, numa_id);\n    // set the thread name as: \"numa_(numa_id)_t_(i+threads_id_start)\"\n    std::string thread_name = \"numa_\" + std::to_string(numa_id) + \"_t_\" + std::to_string(i + threads_id_start);\n    pthread_t native_handle = workers_[i].native_handle();\n    auto res_set_name = pthread_setname_np(native_handle, thread_name.c_str());\n    if (res_set_name != 0) {\n      fprintf(stderr, \"Failed to set thread name: %s\\n\", strerror(res_set_name));\n    }\n    // Check whether the thread was named successfully\n    char name[16];\n    pthread_getname_np(native_handle, name, sizeof(name));\n    if (strcmp(name, thread_name.c_str()) == 0) {\n      // printf(\"Thread name set successfully: %s\\n\", name);\n    } else {\n      // printf(\"Failed to set thread name: %s\\n\", name);\n    }\n    // Set the thread affinity to the specified NUMA node's CPU\n    numa_obj = hwloc_get_obj_by_type(topology, HWLOC_OBJ_NUMANODE, numa_id);\n    if (!numa_obj) {\n      fprintf(stderr, \"NUMA node %d not found\\n\", numa_id);\n      // throw std::runtime_error(\"NUMA node not found\");\n      continue;\n    }\n    core_obj = hwloc_get_obj_inside_cpuset_by_type(topology, numa_obj->cpuset, HWLOC_OBJ_CORE, i + threads_id_start);\n    if (!core_obj) {\n      fprintf(stderr, \"Core %d inside NUMA node %d not found\\n\", i, numa_id);\n      // throw std::runtime_error(\"Core not found inside NUMA node\");\n      continue;\n    }\n    cpuset = hwloc_bitmap_alloc();\n    hwloc_bitmap_copy(cpuset, core_obj->cpuset);\n    hwloc_bitmap_singlify(cpuset);\n    auto res = hwloc_set_thread_cpubind(topology, native_handle, cpuset, HWLOC_CPUBIND_STRICT);\n    if (res != 0) {\n      fprintf(stderr, \"Failed to set thread CPU binding: %s\\n\", strerror(errno));\n    }\n  }\n}\n\nInNumaPool::~InNumaPool() {\n  for (int i = 0; i < total_worker_count; i++) {\n    thread_state_[i].status.store(ThreadStatus::EXIT, std::memory_order_release);\n  
}\n  for (int i = 0; i < total_worker_count; i++) {\n    if (workers_[i].joinable()) {\n      workers_[i].join();\n    }\n  }\n}\n\nint InNumaPool::get_thread_num() {\n  throw std::runtime_error(\"Deprecated\");\n  return total_worker_count;\n}\n\nvoid InNumaPool::set_restricted_worker_count(int count) { restricted_worker_count = count; }\n\nvoid InNumaPool::wait() {\n  for (int i = 0; i < worker_count; i++) {\n    while (thread_state_[i].status.load(std::memory_order_acquire) == ThreadStatus::WORKING) {\n    }\n  }\n\n#ifdef PROFILE_BALANCE\n  size_t max_time = 0;\n  size_t min_time = thread_state_[0].finish_ns;\n  size_t sum = 0;\n  for (int i = 0; i < worker_count; i++) {\n    sum += thread_state_[i].finish_ns;\n    max_time = std::max(max_time, thread_state_[i].finish_ns);\n    min_time = std::min(min_time, thread_state_[i].finish_ns);\n  }\n  double balance = 1.0 * sum / (max_time * worker_count);\n  printf(\"max_time: %ld, min_time: %ld, sum_time: %ld, balance: %f\\n\", max_time, min_time, sum, balance);\n\n#endif\n}\n\nvoid InNumaPool::do_work_stealing_job(int task_num, std::function<void(int)> compute_func) {\n  do_work_stealing_job(task_num, nullptr, compute_func, nullptr);\n}\n\nvoid InNumaPool::do_work_stealing_job(int task_num, std::function<void(int)> init_func,\n                                      std::function<void(int)> compute_func, std::function<void(int)> finalize_func) {\n  do_work_stealing_job_async(task_num, init_func, compute_func, finalize_func);\n  wait();\n}\n\nvoid InNumaPool::do_work_stealing_job_async(int task_num, std::function<void(int)> init_func,\n                                            std::function<void(int)> compute_func,\n                                            std::function<void(int)> finalize_func) {\n  init_func_ = init_func;\n  compute_func_ = compute_func;\n  finalize_func_ = finalize_func;\n  worker_count = std::min(restricted_worker_count, task_num);\n  curr_.store(0, std::memory_order_release);\n  end_ = 
task_num;\n  for (int i = 0; i < worker_count; i++) {\n    thread_state_[i].status.store(ThreadStatus::WORKING, std::memory_order_release);\n  }\n  WorkerPool::thread_local_id = 0;\n  process_tasks(0);\n}\n\nvoid InNumaPool::process_tasks(int thread_id) {\n#ifdef PROFILE_BALANCE\n  auto start = std::chrono::high_resolution_clock::now();\n#endif\n  auto& s = thread_state_[thread_id];\n  if (init_func_ != nullptr) {\n    init_func_(thread_id);\n  }\n\n  // omp-guided-style work scheduling\n  while (true) {\n    int old = curr_.load(std::memory_order_relaxed);\n    int rem = end_ - old;\n    if (rem <= 0) {\n      break;\n    }\n\n    int block = (rem + worker_count - 1) / worker_count;\n    block = 1;\n    int task_id = curr_.fetch_add(block, std::memory_order_acq_rel);\n    if (task_id >= end_) {\n      break;\n    }\n\n    for (int i = 0; i < block; i++) {\n      if (task_id + i >= end_) {\n        break;\n      }\n      compute_func_(task_id + i);\n    }\n  }\n\n  if (finalize_func_ != nullptr) {\n    finalize_func_(thread_id);\n  }\n\n  s.status.store(ThreadStatus::WAITING, std::memory_order_release);\n#ifdef PROFILE_BALANCE\n  s.finish_ns =\n      std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::high_resolution_clock::now() - start).count();\n#endif\n}\n\nvoid InNumaPool::worker_thread(int thread_id, int numa_id) {\n  if (numa_id >= 0) {\n    set_memory_to_numa(numa_id);\n  }\n  auto start = std::chrono::high_resolution_clock::now();\n  WorkerPool::thread_local_id = thread_id;  // Set the thread-local ID\n  while (true) {\n    ThreadStatus status = thread_state_[thread_id].status.load(std::memory_order_acquire);\n    if (status == ThreadStatus::WORKING) {\n      process_tasks(thread_id);\n      start = std::chrono::high_resolution_clock::now();\n    } else if (status == ThreadStatus::WAITING) {\n      auto now = std::chrono::high_resolution_clock::now();\n      auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(now - start).count();\n      
if (duration > 50) {\n        std::this_thread::sleep_for(std::chrono::milliseconds(1));\n      }\n    } else if (status == ThreadStatus::EXIT) {\n      return;\n    }\n  }\n}\n\nNumaJobDistributor::NumaJobDistributor(int numa_count) {\n  std::vector<int> numa_ids;\n  for (int i = 0; i < numa_count; i++) {\n    numa_ids.push_back(i);\n  }\n  init(numa_ids);\n}\n\nNumaJobDistributor::NumaJobDistributor(std::vector<int> numa_ids) { init(numa_ids); }\nNumaJobDistributor::NumaJobDistributor(std::vector<int> numa_ids, std::vector<int> thread_count) {\n  init(numa_ids, thread_count);\n}\n\nvoid NumaJobDistributor::init(std::vector<int> numa_ids) {\n  this->numa_count = numa_ids.size();\n  this->ready_bar = std::unique_ptr<std::barrier<>>(new std::barrier<>(numa_count + 1));\n  this->numa_ids = numa_ids;\n  for (size_t i = 0; i < numa_count; i++) {\n    status.push_back(nullptr);\n  }\n\n  workers.resize(numa_count);\n  for (int i = 0; i < numa_count; i++) {\n    std::thread([this, i]() { workers[i] = std::thread(&NumaJobDistributor::worker_thread, this, i); }).join();\n  }\n  ready_bar->arrive_and_wait();\n}\n\nvoid NumaJobDistributor::init(std::vector<int> numa_ids, std::vector<int> thread_count) {\n  hwloc_topology_t topology;\n  hwloc_obj_t numa_obj, core_obj;\n  hwloc_bitmap_t cpuset;\n  hwloc_topology_init(&topology);\n  hwloc_topology_load(topology);\n\n  this->numa_count = numa_ids.size();\n  this->ready_bar = std::unique_ptr<std::barrier<>>(new std::barrier<>(numa_count + 1));\n  this->numa_ids = numa_ids;\n  for (size_t i = 0; i < numa_count; i++) {\n    status.push_back(nullptr);\n  }\n\n  workers.resize(numa_count);\n  std::vector<int> numa_threads_count(numa_count, 0);\n  for (int i = 0; i < numa_count; i++) {\n    workers[i] = std::thread(&NumaJobDistributor::worker_thread, this, i);\n    auto this_numa = numa_ids[i];\n    auto start_id = numa_threads_count[this_numa];\n    // set the thread name as: \"worker_numa_(numa_id)_main_start_id(0)\"\n    // 
printf(\"numa_id %d, start_id %d\\n\", this_numa, start_id);\n    std::string thread_name = \"numa_\" + std::to_string(numa_ids[i]) + \"_m_\" + std::to_string(start_id);\n    pthread_t native_handle = workers[i].native_handle();\n    pthread_setname_np(native_handle, thread_name.c_str());\n    // Set the thread affinity to the specified NUMA node's CPU (0)\n    numa_obj = hwloc_get_obj_by_type(topology, HWLOC_OBJ_NUMANODE, this_numa);\n    if (!numa_obj) {\n      fprintf(stderr, \"NUMA node %d not found\\n\", this_numa);\n      // throw std::runtime_error(\"NUMA node not found\");\n      continue;\n    }\n    core_obj = hwloc_get_obj_inside_cpuset_by_type(topology, numa_obj->cpuset, HWLOC_OBJ_CORE, start_id);\n    if (!core_obj) {\n      fprintf(stderr, \"Core %d inside NUMA node %d not found\\n\", start_id, this_numa);\n      // throw std::runtime_error(\"Core not found inside NUMA node\");\n      continue;\n    }\n    // Reduce the cpuset to a single CPU\n    auto cpuset_simple = hwloc_bitmap_alloc();\n    hwloc_bitmap_copy(cpuset_simple, core_obj->cpuset);\n    hwloc_bitmap_singlify(cpuset_simple);\n    // Print the physical index of the CPU the thread is bound to\n    unsigned long i_in;\n    // hwloc_bitmap_foreach_begin(i_in, cpuset_simple) { printf(\"Thread %d bound to CPU %ld\\n\", start_id, i_in); }\n    // hwloc_bitmap_foreach_end();\n    auto res = hwloc_set_thread_cpubind(topology, native_handle, cpuset_simple, HWLOC_CPUBIND_STRICT);\n    if (res != 0) {\n      fprintf(stderr, \"Failed to set thread CPU binding: %s\\n\", strerror(errno));\n    }\n    // Check whether the thread is bound to the specified core\n    hwloc_cpuset_t cpuset = hwloc_bitmap_alloc();\n    hwloc_get_thread_cpubind(topology, native_handle, cpuset, HWLOC_CPUBIND_THREAD);\n    // hwloc_bitmap_foreach_begin(i_in, cpuset) { printf(\"Thread %d is bound to CPU %ld\\n\", start_id, i_in); }\n    // hwloc_bitmap_foreach_end();\n\n    numa_threads_count[this_numa] += thread_count[i];\n  }\n  ready_bar->arrive_and_wait();\n}\n\nNumaJobDistributor::~NumaJobDistributor() {\n  for (int i = 0; i < 
numa_count; i++) {\n    status[i]->store(ThreadStatus::EXIT, std::memory_order_release);\n  }\n  for (int i = 0; i < numa_count; i++) {\n    if (workers[i].joinable()) {\n      workers[i].join();\n    }\n  }\n}\n\n#ifdef USE_NUMA_JOB_DIRECT_WORK\n\nvoid NumaJobDistributor::do_numa_job(std::function<void(int)> compute_func) {\n  this->compute_func = compute_func;\n  auto me_numa = numa_node_of_cpu(sched_getcpu());\n  for (int i = 0; i < numa_count; i++) {\n    if (i == me_numa) continue;\n\n    status[i]->store(ThreadStatus::WORKING, std::memory_order_release);\n  }\n  compute_func(me_numa);\n  for (int i = 0; i < numa_count; i++) {\n    if (i == me_numa) continue;\n\n    while (status[i]->load(std::memory_order_acquire) == ThreadStatus::WORKING) {\n    }\n  }\n}\n#else\nvoid NumaJobDistributor::do_numa_job(std::function<void(int)> compute_func) {\n  this->compute_func = compute_func;\n  for (int i = 0; i < numa_count; i++) {\n    status[i]->store(ThreadStatus::WORKING, std::memory_order_release);\n  }\n  for (int i = 0; i < numa_count; i++) {\n    while (status[i]->load(std::memory_order_acquire) == ThreadStatus::WORKING) {\n    }\n  }\n}\n#endif\n\nvoid NumaJobDistributor::worker_thread(int numa_id) {\n  auto start = std::chrono::high_resolution_clock::now();\n  set_memory_to_numa(numa_id);\n  status[numa_id] =\n      std::move(std::unique_ptr<std::atomic<ThreadStatus>>(new std::atomic<ThreadStatus>(ThreadStatus::WAITING)));\n  ready_bar->arrive_and_wait();\n  while (true) {\n    auto stat = status[numa_id]->load(std::memory_order_acquire);\n    if (stat == ThreadStatus::WORKING) {\n      auto me_numa = numa_node_of_cpu(sched_getcpu());\n      // printf(\"numa work on %d, me %d\\n\", numa_id, me_numa);\n      compute_func(numa_id);\n      status[numa_id]->store(ThreadStatus::WAITING, std::memory_order_release);\n      start = std::chrono::high_resolution_clock::now();\n    } else if (stat == ThreadStatus::WAITING) {\n      auto now = 
std::chrono::high_resolution_clock::now();\n      auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(now - start).count();\n      if (duration > 50) {\n        std::this_thread::sleep_for(std::chrono::milliseconds(1));\n      }\n    } else if (stat == ThreadStatus::EXIT) {\n      return;\n    }\n  }\n}\n\nvoid WorkerPool::init(WorkerPoolConfig config) {\n  printf(\"WorkerPool[0x%lx] %d subpools, [numa:threads]\", (intptr_t)this, config.subpool_count);\n  for (int i = 0; i < config.subpool_count; i++) {\n    printf(\"[%d:%d] \", config.subpool_numa_map[i], config.subpool_thread_count[i]);\n  }\n  printf(\"\\n\");\n\n  for (int i = 0; i < config.subpool_count; i++) {\n    numa_worker_pools.push_back(nullptr);\n  }\n  std::vector<int> numa_threads_count(config.subpool_count, 0);\n  for (int i = 0; i < config.subpool_count; i++) {\n    auto this_numa = config.subpool_numa_map[i];\n    auto this_thread_count = config.subpool_thread_count[i];\n    auto this_thread_id_start = numa_threads_count[this_numa];\n    std::thread([this, i, this_numa, this_thread_count, this_thread_id_start]() {\n      set_to_numa(this_numa);\n      numa_worker_pools[i] =\n          std::move(std::unique_ptr<InNumaPool>(new InNumaPool(this_thread_count, this_numa, this_thread_id_start)));\n      // numa_worker_pools[i] = std::move(std::unique_ptr<InNumaPool>(new InNumaPool(this_thread_count)));\n    }).join();\n    numa_threads_count[this_numa] += this_thread_count;\n  }\n\n  distributor = std::move(std::unique_ptr<NumaJobDistributor>(\n      new NumaJobDistributor(config.subpool_numa_map, config.subpool_thread_count)));\n  // distributor = std::move(std::unique_ptr<NumaJobDistributor>(new NumaJobDistributor(config.subpool_numa_map)));\n}\n\nWorkerPool::WorkerPool(WorkerPoolConfig config) : config(config) { init(config); }\n\nWorkerPool::WorkerPool(int total_threads) {\n  config.subpool_count = numa_num_configured_nodes();\n  
config.subpool_numa_map.resize(config.subpool_count);\n  config.subpool_thread_count.resize(config.subpool_count);\n  for (int i = 0; i < config.subpool_count; i++) {\n    config.subpool_numa_map[i] = i;\n    config.subpool_thread_count[i] = total_threads / config.subpool_count;\n  }\n  init(config);\n}\n\nWorkerPool::WorkerPool(int total_threads, int single_numa_id) {\n  set_to_numa(single_numa_id);\n  config.subpool_count = numa_num_configured_nodes();\n  config.subpool_numa_map.resize(config.subpool_count);\n  config.subpool_thread_count.resize(config.subpool_count);\n  for (int i = 0; i < config.subpool_count; i++) {\n    config.subpool_numa_map[i] = single_numa_id;\n    config.subpool_thread_count[i] = total_threads / config.subpool_count;\n  }\n  init(config);\n}\n\nWorkerPool::~WorkerPool() {}\n\nint WorkerPool::get_thread_num() { return total_thread_count; }\n\nvoid WorkerPool::set_restricted_worker_count(int count) {\n  for (int i = 0; i < numa_count; i++) {\n    numa_worker_pools[i]->set_restricted_worker_count(threads_per_numa);\n  }\n}\n\nInNumaPool* WorkerPool::get_subpool(int numa_id) { return numa_worker_pools[numa_id].get(); }\n\nNumaJobDistributor* WorkerPool::dispense_backend() { return distributor.get(); }\n\nvoid WorkerPool::do_work_stealing_job(int task_num, std::function<void(int)> init_func,\n                                      std::function<void(int)> compute_func, std::function<void(int)> finalize_func) {\n  numa_worker_pools[0]->do_work_stealing_job(task_num, init_func, compute_func, finalize_func);\n}\n\nvoid WorkerPool::do_work_stealing_job(int task_num, std::function<void(int)> compute_func) {\n  do_work_stealing_job(task_num, nullptr, compute_func, nullptr);\n}\n"
  },
  {
    "path": "kt-kernel/cpu_backend/worker_pool.h",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-22 02:03:05\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2024-07-25 10:33:38\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#ifndef CPUINFER_BACKEND_H\n#define CPUINFER_BACKEND_H\n\n#include <hwloc.h>\n#include <numa.h>\n\n#include <atomic>\n#include <barrier>\n#include <condition_variable>\n#include <cstdio>\n#include <functional>\n#include <memory>\n#include <mutex>\n#include <thread>\n#include <vector>\n\n// #define PROFILE_BALANCE\n\ninline void set_to_numa(int this_numa) {\n  struct bitmask* mask = numa_bitmask_alloc(numa_num_configured_nodes());\n  numa_bitmask_setbit(mask, this_numa);\n  numa_bind(mask);\n  numa_bitmask_free(mask);\n}\n\ninline void set_memory_to_numa(int this_numa) {\n  // printf(\"Set memory to NUMA %d\\n\", this_numa);\n  hwloc_topology_t topology;\n  hwloc_topology_init(&topology);\n  hwloc_topology_load(topology);\n\n  hwloc_obj_t obj = hwloc_get_obj_by_type(topology, HWLOC_OBJ_NUMANODE, this_numa);\n  if (!obj) {\n    fprintf(stderr, \"NUMA node %d not found.\\n\", this_numa);\n    hwloc_topology_destroy(topology);\n    return;\n  }\n\n  auto ret = hwloc_set_membind(topology, obj->nodeset, HWLOC_MEMBIND_BIND,\n                               HWLOC_MEMBIND_THREAD | HWLOC_MEMBIND_STRICT | HWLOC_MEMBIND_BYNODESET);\n  if (ret != 0) {\n    perror(\"hwloc_set_membind_nodeset\");\n  }\n\n  hwloc_topology_destroy(topology);\n}\n\nenum ThreadStatus {\n  WORKING,\n  WAITING,\n  EXIT,\n};\n\nstruct alignas(64) ThreadState {\n  std::atomic<ThreadStatus> status;\n#ifdef PROFILE_BALANCE\n  size_t finish_ns;\n#endif\n};\n\nclass InNumaPool {\n public:\n  InNumaPool(int thread_count);\n  InNumaPool(int max_thread_num, int numa_id, int threads_id_start);\n  ~InNumaPool();\n  int get_thread_num();\n  void set_restricted_worker_count(int count);\n\n  void do_work_stealing_job_async(int, 
std::function<void(int)>, std::function<void(int)>, std::function<void(int)>);\n  void wait();\n\n  void do_work_stealing_job(int, std::function<void(int)>, std::function<void(int)>, std::function<void(int)>);\n  void do_work_stealing_job(int, std::function<void(int)>);\n\n private:\n  int worker_count;\n  int total_worker_count;\n\n  std::unique_ptr<ThreadState[]> thread_state_;  // [thread_num]\n  std::vector<std::thread> workers_;\n\n  // changed every time do_work_stealing_job_async is called\n  int restricted_worker_count;\n  std::function<void(int)> init_func_;\n  std::function<void(int)> compute_func_;\n  std::function<void(int)> finalize_func_;\n  std::atomic<int> curr_;\n  int end_;\n\n  void process_tasks(int);\n  void worker_thread(int, int);\n};\n\nclass NumaJobDistributor {\n public:\n  NumaJobDistributor(int numa_count);\n  NumaJobDistributor(std::vector<int> numa_ids);\n  NumaJobDistributor(std::vector<int> numa_ids, std::vector<int> thread_count);\n\n  ~NumaJobDistributor();\n\n  void do_numa_job(std::function<void(int)>);\n\n private:\n  void init(std::vector<int> numa_ids);\n  void init(std::vector<int> numa_ids, std::vector<int> thread_count);\n\n  std::unique_ptr<std::barrier<>> ready_bar;\n\n  int numa_count;\n  std::vector<int> numa_ids;\n  std::vector<std::unique_ptr<std::atomic<ThreadStatus>>> status;\n  std::function<void(int)> compute_func;\n  std::vector<std::thread> workers;\n\n  void worker_thread(int);\n};\n\nstruct WorkerPoolConfig {\n  int subpool_count;\n  std::vector<int> subpool_numa_map;\n  std::vector<int> subpool_thread_count;\n};\n\nclass WorkerPool {\n public:\n  WorkerPool(int total_thread_count);\n  WorkerPool(int total_thread_count, int single_numa_id);\n  WorkerPool(WorkerPoolConfig config);\n  ~WorkerPool();\n  int get_thread_num();\n  void set_restricted_worker_count(int count);\n\n  static thread_local int thread_local_id;\n\n  NumaJobDistributor* dispense_backend();\n\n  InNumaPool* get_subpool(int numa_id);\n\n  void 
do_work_stealing_job(int, std::function<void(int)>, std::function<void(int)>, std::function<void(int)>);\n  void do_work_stealing_job(int, std::function<void(int)>);\n\n  WorkerPoolConfig config;\n\n private:\n  void init(WorkerPoolConfig config);\n\n  int total_thread_count;\n  int numa_count;\n  int threads_per_numa;\n  std::unique_ptr<NumaJobDistributor> distributor;\n\n  std::vector<std::unique_ptr<InNumaPool>> numa_worker_pools;\n};\n\n#endif\n"
  },
  {
    "path": "kt-kernel/cuda/binding.cpp",
    "content": "/**\n * @Description  :\n * @Author       : Azure-Tang, Boxin Zhang\n * @Date         : 2024-07-25 13:38:30\n * @Version      : 0.2.2\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n\n#include \"custom_gguf/ops.h\"\n#ifdef KTRANSFORMERS_USE_CUDA\n#include \"gptq_marlin/ops.h\"\n#include \"moe/ops.h\"\n#endif\n// Python bindings\n#include <pybind11/pybind11.h>\n#include <pybind11/stl.h>\n#include <torch/extension.h>\n#include <torch/library.h>\n#include <torch/torch.h>\n// namespace py = pybind11;\n\nPYBIND11_MODULE(KTransformersOps, m) {\n  m.def(\n      \"dequantize_q8_0\",\n      [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device,\n         py::object target_dtype) {\n        torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);\n        return dequantize_q8_0((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);\n      },\n      \"Function to dequantize q8_0 data.\", py::arg(\"data\"), py::arg(\"num_bytes\"), py::arg(\"blk_size\"),\n      py::arg(\"ele_per_blk\"), py::arg(\"device\"), py::arg(\"target_dtype\"));\n\n  m.def(\n      \"dequantize_q6_k\",\n      [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device,\n         py::object target_dtype) {\n        torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);\n        return dequantize_q6_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);\n      },\n      \"Function to dequantize q6_k data.\", py::arg(\"data\"), py::arg(\"num_bytes\"), py::arg(\"blk_size\"),\n      py::arg(\"ele_per_blk\"), py::arg(\"device\"), py::arg(\"target_dtype\"));\n\n  m.def(\n      \"dequantize_q5_k\",\n      [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device,\n         py::object target_dtype) {\n        torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);\n        
return dequantize_q5_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);\n      },\n      \"Function to dequantize q5_k data.\", py::arg(\"data\"), py::arg(\"num_bytes\"), py::arg(\"blk_size\"),\n      py::arg(\"ele_per_blk\"), py::arg(\"device\"), py::arg(\"target_dtype\"));\n\n  m.def(\n      \"dequantize_q4_k\",\n      [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device,\n         py::object target_dtype) {\n        torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);\n        return dequantize_q4_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);\n      },\n      \"Function to dequantize q4_k data.\", py::arg(\"data\"), py::arg(\"num_bytes\"), py::arg(\"blk_size\"),\n      py::arg(\"ele_per_blk\"), py::arg(\"device\"), py::arg(\"target_dtype\"));\n\n  m.def(\n      \"dequantize_q3_k\",\n      [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device,\n         py::object target_dtype) {\n        torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);\n        return dequantize_q3_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);\n      },\n      \"Function to dequantize q3_k data.\", py::arg(\"data\"), py::arg(\"num_bytes\"), py::arg(\"blk_size\"),\n      py::arg(\"ele_per_blk\"), py::arg(\"device\"), py::arg(\"target_dtype\"));\n\n  m.def(\n      \"dequantize_q2_k\",\n      [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device,\n         py::object target_dtype) {\n        torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);\n        return dequantize_q2_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);\n      },\n      \"Function to dequantize q2_k data.\", py::arg(\"data\"), py::arg(\"num_bytes\"), py::arg(\"blk_size\"),\n      py::arg(\"ele_per_blk\"), py::arg(\"device\"), 
py::arg(\"target_dtype\"));\n\n  m.def(\n      \"dequantize_iq4_xs\",\n      [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device,\n         py::object target_dtype) {\n        torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);\n        return dequantize_iq4_xs((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);\n      },\n      \"Function to dequantize iq4_xs data.\", py::arg(\"data\"), py::arg(\"num_bytes\"), py::arg(\"blk_size\"),\n      py::arg(\"ele_per_blk\"), py::arg(\"device\"), py::arg(\"target_dtype\"));\n\n#ifdef KTRANSFORMERS_USE_CUDA\n  m.def(\"gptq_marlin_gemm\", &gptq_marlin_gemm, \"Function to perform GEMM using Marlin quantization.\", py::arg(\"a\"),\n        py::arg(\"b_q_weight\"), py::arg(\"b_scales\"), py::arg(\"g_idx\"), py::arg(\"perm\"), py::arg(\"workspace\"),\n        py::arg(\"num_bits\"), py::arg(\"size_m\"), py::arg(\"size_n\"), py::arg(\"size_k\"), py::arg(\"is_k_full\"));\n  m.def(\"topk_softmax\", &topk_softmax, \"Function to perform topk_softmax.\", py::arg(\"topk_weights\"),\n        py::arg(\"topk_indices\"), py::arg(\"token_expert_indices\"), py::arg(\"gating_output\"));\n#endif\n}\n"
  },
  {
    "path": "kt-kernel/cuda/custom_gguf/dequant.cu",
    "content": "/*\n * @Description  :  \n * @Author       : Azure-Tang, Boxin Zhang\n * @Date         : 2024-07-25 13:38:30\n * @Version      : 0.2.2\n * Adapted from https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c\n * Copyright (c) 2023-2024 The ggml authors\n * Copyright (c) 2024 by KVCache.AI, All Rights Reserved. \n */\n#include <cuda_runtime.h>\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n#include <torch/library.h>\n#include <torch/extension.h>\n#include <torch/torch.h>\n#include <cstdint>\n#include <c10/cuda/CUDAGuard.h>\n\n#ifdef __HIP_PLATFORM_AMD__\ntypedef hip_bfloat16 nv_bfloat16;\n#endif\n\n__global__ void dequantize_q8_0_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){\n        float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);\n        const int8_t* cur_block = data + block_id * blk_size;\n        float scale = __half2float(*((half*)cur_block));\n        cur_block += 2;\n        for (int i = 0; i < ele_per_blk; i++){\n            output_blk[i] = scale * cur_block[i];\n        }\n    }\n}\n\n__global__ void dequantize_q8_0_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) {\n        __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);\n        const int8_t* cur_block = data + block_id * blk_size;\n        float scale = __half2float(*((half*)cur_block));\n        cur_block += 2;\n        for (int i = 0; i < ele_per_blk; i++) {\n            output_blk[i] = __float2half(scale * 
cur_block[i]);\n        }\n    }\n}\n\n__global__ void dequantize_q8_0_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) {\n        nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);\n        const int8_t* cur_block = data + block_id * blk_size;\n        float scale = __half2float(*((half*)cur_block));\n        cur_block += 2;\n        for (int i = 0; i < ele_per_blk; i++) {\n            output_blk[i] = __float2bfloat16(scale * cur_block[i]);\n        }\n    }\n}\n\n// __device__ void get_scale_min_k4(int j, const uint8_t * __restrict__ q, uint8_t * __restrict__ d, uint8_t * __restrict__ m) {\n__device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t * __restrict__ d, uint8_t * __restrict__ m) {\n    if (j < 4) {\n        *d = q[j] & 63; *m = q[j + 4] & 63;\n    } else {\n        *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);\n        *m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);\n    }\n}\n\n__global__ void dequantize_q2_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){\n        float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);\n\n        const float d   = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 80)));\n        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 82)));\n\n        const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16);\n\n        int is = 0;\n        float dl, ml;\n\n        for (int n = 0; n < 256; n += 128) {\n            int 
shift = 0;\n            for (int j = 0; j < 4; ++j) {\n                uint8_t* scales = (uint8_t*)(data + block_id * blk_size + (is++));\n                uint8_t sc = *scales;\n                dl = d * (sc & 0xF); ml = min * (sc >> 4);\n                for (int l = 0; l < 16; ++l) *output_blk++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;\n\n                scales = (uint8_t*)(data + block_id * blk_size + (is++));\n                sc = *scales;\n\n                dl = d * (sc & 0xF); ml = min * (sc >> 4);\n                for (int l = 0; l < 16; ++l) *output_blk++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;\n\n                shift += 2;\n            }\n            q += 32;\n        }\n    }\n}\n\n__global__ void dequantize_q2_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){\n        __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);\n\n        const float d   = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 80)));\n        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 82)));\n\n        const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16);\n\n        int is = 0;\n        float dl, ml;\n\n        for (int n = 0; n < 256; n += 128) {\n            int shift = 0;\n            for (int j = 0; j < 4; ++j) {\n                uint8_t* scales = (uint8_t*)(data + block_id * blk_size + (is++));\n                uint8_t sc = *scales;\n                dl = d * (sc & 0xF); ml = min * (sc >> 4);\n                for (int l = 0; l < 16; ++l) *output_blk++ = __float2half(dl * ((int8_t)((q[l] >> shift) & 3)) - ml);\n\n                scales = (uint8_t*)(data + block_id * blk_size + (is++));\n                sc = *scales;\n\n    
            dl = d * (sc & 0xF); ml = min * (sc >> 4);\n                for (int l = 0; l < 16; ++l) *output_blk++ = __float2half(dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml);\n\n                shift += 2;\n            }\n            q += 32;\n        }\n    }\n}\n\n__global__ void dequantize_q2_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){\n        nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);\n\n        const float d   = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 80)));\n        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 82)));\n\n        const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16);\n\n        int is = 0;\n        float dl, ml;\n\n        for (int n = 0; n < 256; n += 128) {\n            int shift = 0;\n            for (int j = 0; j < 4; ++j) {\n                uint8_t* scales = (uint8_t*)(data + block_id * blk_size + (is++));\n                uint8_t sc = *scales;\n                dl = d * (sc & 0xF); ml = min * (sc >> 4);\n                for (int l = 0; l < 16; ++l) *output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l] >> shift) & 3)) - ml);\n\n                scales = (uint8_t*)(data + block_id * blk_size + (is++));\n                sc = *scales;\n\n                dl = d * (sc & 0xF); ml = min * (sc >> 4);\n                for (int l = 0; l < 16; ++l) *output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml);\n\n                shift += 2;\n            }\n            q += 32;\n        }\n    }\n}\n\n__global__ void dequantize_q3_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    
\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;    \n    const uint32_t kmask1 = 0x03030303;\n    const uint32_t kmask2 = 0x0f0f0f0f;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){\n        float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);\n\n        uint32_t aux[4];\n        const int8_t * scales = (const int8_t*)aux;\n        const float d_all = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 108)));\n\n        const uint8_t * __restrict__ q  = (uint8_t*)(data + block_id * blk_size + 32);\n        const uint8_t * __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0);\n        uint8_t m = 1;\n\n\n        uint8_t* block_scales = (uint8_t*)(data + block_id * blk_size + 96);\n\n        for (int i = 0; i < 3; i++) {  \n            aux[i] = 0;  \n            for (int j = 0; j < 4; j++) {  \n                aux[i] |= ((uint32_t)block_scales[i * 4 + j]) << (j * 8);\n            }\n        }\n\n        uint32_t tmp = aux[2];\n        aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);\n        aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);\n        aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);\n        aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);\n\n        int is = 0;\n        float dl;\n        for (int n = 0; n < 256; n += 128) {\n            int shift = 0;\n            for (int j = 0; j < 4; ++j) {\n\n                dl = d_all * (scales[is++] - 32);\n                for (int l = 0; l < 16; ++l) {\n                    *output_blk++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4));\n                }\n\n                dl = d_all * (scales[is++] - 32);\n                for (int l = 0; l < 16; ++l) {\n                    *output_blk++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 
0 : 4));\n                }\n\n                shift += 2;\n                m <<= 1;\n            }\n            q += 32;\n        }\n    }\n}\n\n__global__ void dequantize_q3_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    \n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;    \n    const uint32_t kmask1 = 0x03030303;\n    const uint32_t kmask2 = 0x0f0f0f0f;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){\n        __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);\n\n        uint32_t aux[4];\n        const int8_t * scales = (const int8_t*)aux;\n        const float d_all = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 108)));\n\n        const uint8_t * __restrict__ q  = (uint8_t*)(data + block_id * blk_size + 32);\n        const uint8_t * __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0);\n        uint8_t m = 1;\n\n\n        uint8_t* block_scales = (uint8_t*)(data + block_id * blk_size + 96);\n\n        for (int i = 0; i < 3; i++) {  \n            aux[i] = 0;  \n            for (int j = 0; j < 4; j++) {  \n                aux[i] |= ((uint32_t)block_scales[i * 4 + j]) << (j * 8);\n            }\n        }\n\n        uint32_t tmp = aux[2];\n        aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);\n        aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);\n        aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);\n        aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);\n\n        int is = 0;\n        float dl;\n        for (int n = 0; n < 256; n += 128) {\n            int shift = 0;\n            for (int j = 0; j < 4; ++j) {\n\n                dl = d_all * (scales[is++] - 32);\n                for (int l = 0; l < 16; ++l) {\n                    *output_blk++ = __float2half(dl * ((int8_t)((q[l+ 0] >> shift) & 3) - 
((hm[l+ 0] & m) ? 0 : 4)));\n                }\n\n                dl = d_all * (scales[is++] - 32);\n                for (int l = 0; l < 16; ++l) {\n                    *output_blk++ = __float2half(dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4)));\n                }\n\n                shift += 2;\n                m <<= 1;\n            }\n            q += 32;\n        }\n    }\n}\n\n__global__ void dequantize_q3_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    \n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;    \n    const uint32_t kmask1 = 0x03030303;\n    const uint32_t kmask2 = 0x0f0f0f0f;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){\n        nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);\n\n        uint32_t aux[4];\n        const int8_t * scales = (const int8_t*)aux;\n        const float d_all = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 108)));\n\n        const uint8_t * __restrict__ q  = (uint8_t*)(data + block_id * blk_size + 32);\n        const uint8_t * __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0);\n        uint8_t m = 1;\n\n\n        uint8_t* block_scales = (uint8_t*)(data + block_id * blk_size + 96);\n\n        for (int i = 0; i < 3; i++) {  \n            aux[i] = 0;  \n            for (int j = 0; j < 4; j++) {  \n                aux[i] |= ((uint32_t)block_scales[i * 4 + j]) << (j * 8);\n            }\n        }\n\n        uint32_t tmp = aux[2];\n        aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);\n        aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);\n        aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);\n        aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);\n\n        int is = 0;\n        float dl;\n        for (int n = 0; n < 256; n 
+= 128) {\n            int shift = 0;\n            for (int j = 0; j < 4; ++j) {\n\n                dl = d_all * (scales[is++] - 32);\n                for (int l = 0; l < 16; ++l) {\n                    *output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4)));\n                }\n\n                dl = d_all * (scales[is++] - 32);\n                for (int l = 0; l < 16; ++l) {\n                    *output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4)));\n                }\n\n                shift += 2;\n                m <<= 1;\n            }\n            q += 32;\n        }\n    }\n}\n\n\n__global__ void dequantize_q4_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x){\n        float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);\n        // const uint8_t * q = data[i].qs;\n        const uint8_t * q = (uint8_t*)(data + block_id * 144 + 16);\n\n        const float d   = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 0)));\n        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 2)));\n        int is = 0;\n        uint8_t sc, m;\n        for (int j = 0; j < ele_per_blk; j += 64) {\n            uint8_t* scales = (uint8_t*)(data + block_id * 144 + 4);\n            get_scale_min_k4(is + 0, scales, &sc, &m);\n            const float d1 = d * sc; const float m1 = min * m;\n            get_scale_min_k4(is + 1, scales, &sc, &m);\n            const float d2 = d * sc; const float m2 = min * m;\n            for (int l = 0; l < 32; ++l) *output_blk++ = d1 * (q[l] & 0xF) - m1;\n            for (int l = 0; l < 32; ++l) *output_blk++ = d2 * (q[l]  >> 4) - m2;\n            q += 32; is 
+= 2;\n        }\n    }\n}\n\n__global__ void dequantize_q4_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x){\n        __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);\n        // const uint8_t * q = data[i].qs;\n        const uint8_t * q = (uint8_t*)(data + block_id * 144 + 16);\n\n        const float d   = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 0)));\n        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 2)));\n        int is = 0;\n        uint8_t sc, m;\n        for (int j = 0; j < ele_per_blk; j += 64) {\n            uint8_t* scales = (uint8_t*)(data + block_id * 144 + 4);\n            get_scale_min_k4(is + 0, scales, &sc, &m);\n            const float d1 = d * sc; const float m1 = min * m;\n            get_scale_min_k4(is + 1, scales, &sc, &m);\n            const float d2 = d * sc; const float m2 = min * m;\n            for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d1 * (q[l] & 0xF) - m1);\n            for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d2 * (q[l]  >> 4) - m2);\n            q += 32; is += 2;\n        }\n    }\n}\n\n__global__ void dequantize_q4_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x){\n        nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);\n        // const uint8_t * q = data[i].qs;\n        const uint8_t * q = (uint8_t*)(data + block_id * 144 + 16);\n\n        const float d   = __half2float(*(reinterpret_cast<const 
half*>(data + block_id * 144 + 0)));\n        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * 144 + 2)));\n        int is = 0;\n        uint8_t sc, m;\n        for (int j = 0; j < ele_per_blk; j += 64) {\n            uint8_t* scales = (uint8_t*)(data + block_id * 144 + 4);\n            get_scale_min_k4(is + 0, scales, &sc, &m);\n            const float d1 = d * sc; const float m1 = min * m;\n            get_scale_min_k4(is + 1, scales, &sc, &m);\n            const float d2 = d * sc; const float m2 = min * m;\n            for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d1 * (q[l] & 0xF) - m1);\n            for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d2 * (q[l]  >> 4) - m2);\n            q += 32; is += 2;\n        }\n    }\n}\n\n__global__ void dequantize_q5_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){\n        float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);\n\n        const float d   = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 0)));\n        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 2)));\n\n        const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16);\n        const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48);\n\n        int is = 0;\n        uint8_t sc, m;\n        uint8_t u1 = 1, u2 = 2;\n        uint8_t* scales = (uint8_t*)(data + block_id * blk_size + 4);\n\n        for (int j = 0; j < 256; j += 64) {\n            get_scale_min_k4(is + 0, scales, &sc, &m);\n            const float d1 = d * sc; const float m1 = min * m;\n            get_scale_min_k4(is + 1, scales, &sc, &m);\n            const float d2 
= d * sc; const float m2 = min * m;\n            for (int l = 0; l < 32; ++l) *output_blk++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;\n            for (int l = 0; l < 32; ++l) *output_blk++ = d2 * ((ql[l]  >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;\n            ql += 32; is += 2;\n            u1 <<= 2; u2 <<= 2;\n        }\n    }\n}\n\n__global__ void dequantize_q5_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){\n        __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);\n\n        const float d   = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 0)));\n        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 2)));\n\n        const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16);\n        const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48);\n\n        int is = 0;\n        uint8_t sc, m;\n        uint8_t u1 = 1, u2 = 2;\n        uint8_t* scales = (uint8_t*)(data + block_id * blk_size + 4);\n\n        for (int j = 0; j < 256; j += 64) {\n            get_scale_min_k4(is + 0, scales, &sc, &m);\n            const float d1 = d * sc; const float m1 = min * m;\n            get_scale_min_k4(is + 1, scales, &sc, &m);\n            const float d2 = d * sc; const float m2 = min * m;\n            for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1);\n            for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d2 * ((ql[l]  >> 4) + (qh[l] & u2 ? 
16 : 0)) - m2);\n            ql += 32; is += 2;\n            u1 <<= 2; u2 <<= 2;\n        }\n    }\n}\n\n__global__ void dequantize_q5_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){\n        nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);\n\n        const float d   = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 0)));\n        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 2)));\n\n        const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16);\n        const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48);\n\n        int is = 0;\n        uint8_t sc, m;\n        uint8_t u1 = 1, u2 = 2;\n        uint8_t* scales = (uint8_t*)(data + block_id * blk_size + 4);\n\n        for (int j = 0; j < 256; j += 64) {\n            get_scale_min_k4(is + 0, scales, &sc, &m);\n            const float d1 = d * sc; const float m1 = min * m;\n            get_scale_min_k4(is + 1, scales, &sc, &m);\n            const float d2 = d * sc; const float m2 = min * m;\n            for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1);\n            for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d2 * ((ql[l]  >> 4) + (qh[l] & u2 ? 
16 : 0)) - m2);\n            ql += 32; is += 2;\n            u1 <<= 2; u2 <<= 2;\n        }\n    }\n}\n\n__global__ void dequantize_q6_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long  block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){\n        float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);\n        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 208)));\n\n        const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size);\n        const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 128);\n        const int8_t  * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192);\n\n\n        for (int n = 0; n < ele_per_blk; n += 128) {\n            for (int l = 0; l < 32; ++l) {\n                int is = l/16;\n                const int8_t q1 = (int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;\n                const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;\n                const int8_t q3 = (int8_t)((ql[l +  0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;\n                const int8_t q4 = (int8_t)((ql[l + 32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;\n                output_blk[l +  0] = d * sc[is + 0] * q1;\n                output_blk[l + 32] = d * sc[is + 2] * q2;\n                output_blk[l + 64] = d * sc[is + 4] * q3;\n                output_blk[l + 96] = d * sc[is + 6] * q4;\n            }\n            output_blk += 128;\n            ql += 64;\n            qh += 32;\n            sc += 8;\n        }\n    }\n}\n\n__global__ void dequantize_q6_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long 
long  block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){\n        __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);\n        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 208)));\n\n        const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size);\n        const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 128);\n        const int8_t  * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192);\n\n\n        for (int n = 0; n < ele_per_blk; n += 128) {\n            for (int l = 0; l < 32; ++l) {\n                int is = l/16;\n                const int8_t q1 = (int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;\n                const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;\n                const int8_t q3 = (int8_t)((ql[l +  0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;\n                const int8_t q4 = (int8_t)((ql[l + 32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;\n                output_blk[l +  0] = __float2half(d * sc[is + 0] * q1);\n                output_blk[l + 32] = __float2half(d * sc[is + 2] * q2);\n                output_blk[l + 64] = __float2half(d * sc[is + 4] * q3);\n                output_blk[l + 96] = __float2half(d * sc[is + 6] * q4);\n            }\n            output_blk += 128;\n            ql += 64;\n            qh += 32;\n            sc += 8;\n        }\n    }\n}\n\n__global__ void dequantize_q6_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long  block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){\n        nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);\n        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * 
blk_size + 208)));\n\n        const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size);\n        const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 128);\n        const int8_t  * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192);\n\n\n        for (int n = 0; n < ele_per_blk; n += 128) {\n            for (int l = 0; l < 32; ++l) {\n                int is = l/16;\n                const int8_t q1 = (int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;\n                const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;\n                const int8_t q3 = (int8_t)((ql[l +  0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;\n                const int8_t q4 = (int8_t)((ql[l + 32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;\n                output_blk[l +  0] = __float2bfloat16(d * sc[is + 0] * q1);\n                output_blk[l + 32] = __float2bfloat16(d * sc[is + 2] * q2);\n                output_blk[l + 64] = __float2bfloat16(d * sc[is + 4] * q3);\n                output_blk[l + 96] = __float2bfloat16(d * sc[is + 6] * q4);\n            }\n            output_blk += 128;\n            ql += 64;\n            qh += 32;\n            sc += 8;\n        }\n    }\n}\n\nstatic constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};\n\n__global__ void dequantize_iq4_xs_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) {\n        float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);\n        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size)));\n        const uint16_t scales_h = *(reinterpret_cast<const uint16_t*>(data + block_id * blk_size + 2));\n       
 const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2);\n        const uint8_t* qs = (uint8_t*)(data + block_id * blk_size + 2 + 2 + 4);\n\n        for (int ib = 0; ib < 8; ++ib) {\n            const int ls = ((scales_l[ib / 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h >> 2 * ib) & 3) << 4);\n            const float dl = d * (ls - 32);\n            for (int j = 0; j < 16; ++j) {\n                output_blk[j + 0] = dl * kvalues_iq4nl[qs[j] & 0xf];\n                output_blk[j + 16] = dl * kvalues_iq4nl[qs[j] >> 4];\n            }\n            output_blk += 32;\n            qs += 16;\n        }\n    }\n}\n\n__global__ void dequantize_iq4_xs_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) {\n        __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);\n        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size)));\n        const uint16_t scales_h = *(reinterpret_cast<const uint16_t*>(data + block_id * blk_size + 2));\n        const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2);\n        const uint8_t* qs = (uint8_t*)(data + block_id * blk_size + 2 + 2 + 4);\n\n        for (int ib = 0; ib < 8; ++ib) {\n            const int ls = ((scales_l[ib / 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h >> 2 * ib) & 3) << 4);\n            const float dl = d * (ls - 32);\n            for (int j = 0; j < 16; ++j) {\n                output_blk[j + 0] = __float2half(dl * kvalues_iq4nl[qs[j] & 0xf]);\n                output_blk[j + 16] = __float2half(dl * kvalues_iq4nl[qs[j] >> 4]);\n            }\n            output_blk += 32;\n            qs += 16;\n        }\n    }\n}\n\n__global__ void dequantize_iq4_xs_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int 
blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) {\n        nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);\n        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size)));\n        const uint16_t scales_h = *(reinterpret_cast<const uint16_t*>(data + block_id * blk_size + 2));\n        const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2);\n        const uint8_t* qs = (uint8_t*)(data + block_id * blk_size + 2 + 2 + 4);\n\n        for (int ib = 0; ib < 8; ++ib) {\n            const int ls = ((scales_l[ib / 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h >> 2 * ib) & 3) << 4);\n            const float dl = d * (ls - 32);\n            for (int j = 0; j < 16; ++j) {\n                output_blk[j + 0] = __float2bfloat16(dl * kvalues_iq4nl[qs[j] & 0xf]);\n                output_blk[j + 16] = __float2bfloat16(dl * kvalues_iq4nl[qs[j] >> 4]);\n            }\n            output_blk += 32;\n            qs += 16;\n        }\n    }\n}\n\ntorch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {\n    int num_blocks = num_bytes / blk_size;\n    const at::cuda::OptionalCUDAGuard device_guard(device);\n\n    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);\n    auto data_gpu = torch::empty({ num_bytes }, options);\n\n    cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);\n    //data_gpu.copy_(data, false);\n\n    // Create output tensor\n    auto output = torch::zeros({ num_blocks, ele_per_blk }, torch::dtype(target_dtype).device(device));\n\n    switch (target_dtype) {\n        case torch::kFloat16:\n        
    dequantize_q8_0_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kBFloat16:\n            dequantize_q8_0_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kFloat32:\n            dequantize_q8_0_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);\n            break;\n        default:\n            printf(\"target type not supported\\n\");\n            exit(0);\n    }\n\n    cudaDeviceSynchronize();\n    return output;\n}\n\n\ntorch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {\n    // data.numel%blk_size should be 0, else raise err\n    int num_blocks = num_bytes / blk_size;\n\n    const at::cuda::OptionalCUDAGuard device_guard(device);\n    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);\n    auto data_gpu = torch::empty({num_bytes}, options);\n\n    cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);\n    //data_gpu.copy_(data, false);\n\n    // Create output tensor\n    auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));\n\n    switch (target_dtype) {\n        case torch::kFloat16:\n            dequantize_q6_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kBFloat16:\n            dequantize_q6_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kFloat32:\n            
dequantize_q6_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);\n            break;\n        default:\n            printf(\"target type not supported\\n\");\n            exit(0);\n    }\n    cudaDeviceSynchronize();\n    return output;\n}\n\ntorch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {\n    int num_blocks = num_bytes / blk_size;\n    const at::cuda::OptionalCUDAGuard device_guard(device);\n\n    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);\n    auto data_gpu = torch::empty({num_bytes}, options);\n\n    cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);\n    //data_gpu.copy_(data, false);\n\n    // Create output tensor\n    auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));\n\n    switch (target_dtype) {\n        case torch::kFloat16:\n            dequantize_q5_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kBFloat16:\n            dequantize_q5_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kFloat32:\n            dequantize_q5_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);\n            break;\n        default:\n            printf(\"target type not supported\\n\");\n            exit(0);\n    }\n    cudaDeviceSynchronize();\n    return output;\n}\n\ntorch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) 
{\n    // data.numel%blk_size should be 0, else raise err\n    int num_blocks = num_bytes / blk_size;\n    const at::cuda::OptionalCUDAGuard device_guard(device);\n\n    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);\n    auto data_gpu = torch::empty({num_bytes}, options);\n\n    cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);\n    //data_gpu.copy_(data, false);\n\n    // Create output tensor\n    auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));\n\n    switch (target_dtype) {\n        case torch::kFloat16:\n            dequantize_q4_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kBFloat16:\n            dequantize_q4_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kFloat32:\n            dequantize_q4_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);\n            break;\n        default:\n            printf(\"target type not supported\\n\");\n            exit(0);\n    }\n    cudaDeviceSynchronize();\n    return output;\n}\n\ntorch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {\n    int num_blocks = num_bytes / blk_size;\n    const at::cuda::OptionalCUDAGuard device_guard(device);\n\n    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);\n    auto data_gpu = torch::empty({num_bytes}, options);\n\n    cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);\n    //data_gpu.copy_(data, false);\n\n    // 
Create output tensor\n    auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));\n\n    switch (target_dtype) {\n        case torch::kFloat16:\n            dequantize_q3_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kBFloat16:\n            dequantize_q3_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kFloat32:\n            dequantize_q3_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);\n            break;\n        default:\n            printf(\"target type not supported\\n\");\n            exit(0);\n    }\n    cudaDeviceSynchronize();\n    return output;\n}\n\ntorch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {\n    int num_blocks = num_bytes / blk_size;\n    const at::cuda::OptionalCUDAGuard device_guard(device);\n\n    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);\n    auto data_gpu = torch::empty({num_bytes}, options);\n\n    cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);\n    //data_gpu.copy_(data, false);\n\n    // Create output tensor\n    auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));\n\n    switch (target_dtype) {\n        case torch::kFloat16:\n            dequantize_q2_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kBFloat16:\n            dequantize_q2_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), 
(nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kFloat32:\n            dequantize_q2_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);\n            break;\n        default:\n            printf(\"target type not supported\\n\");\n            exit(0);\n    }\n    cudaDeviceSynchronize();\n    return output;\n}\n\ntorch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {\n    int num_blocks = num_bytes / blk_size;\n    const at::cuda::OptionalCUDAGuard device_guard(device);\n\n    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);\n    auto data_gpu = torch::empty({num_bytes}, options);\n\n    cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);\n    //data_gpu.copy_(data, false);\n\n    // Create output tensor\n    auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));\n\n    switch (target_dtype) {\n        case torch::kFloat16:\n            dequantize_iq4_xs_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kBFloat16:\n            dequantize_iq4_xs_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kFloat32:\n            dequantize_iq4_xs_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);\n            break;\n        default:\n            printf(\"target type not supported\\n\");\n            exit(0);\n    }\n    cudaDeviceSynchronize();\n    return output;\n}"
  },
  {
    "path": "kt-kernel/cuda/custom_gguf/ops.h",
    "content": "/**\n * @Description  :\n * @Author       : Azure-Tang\n * @Date         : 2024-07-22 09:27:55\n * @Version      : 1.0.0\n * @LastEditors  : kkk1nak0\n * @LastEditTime : 2024-08-12 03:48:46\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#pragma once\n\n#include <torch/extension.h>\n#include <torch/library.h>\n#include <torch/torch.h>\n\ntorch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk,\n                              const torch::Device device, const torch::Dtype target_dtype);\ntorch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk,\n                              const torch::Device device, const torch::Dtype target_dtype);\ntorch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk,\n                              const torch::Device device, const torch::Dtype target_dtype);\ntorch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk,\n                              const torch::Device device, const torch::Dtype target_dtype);\ntorch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk,\n                              const torch::Device device, const torch::Dtype target_dtype);\ntorch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk,\n                              const torch::Device device, const torch::Dtype target_dtype);\ntorch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk,\n                                const torch::Device device, const torch::Dtype target_dtype);\n"
  },
  {
    "path": "kt-kernel/cuda/gptq_marlin/gptq_marlin.cu",
    "content": "/*\n * Modified by Neural Magic\n * Copyright (C) Marlin.2024 Elias Frantar\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *         http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n/*\n * Adapted from https://github.com/IST-DASLab/marlin\n */\n/*\n * Adapted from  https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin\n */\n#include \"gptq_marlin.cuh\"\n#include \"gptq_marlin_dtypes.cuh\"\n#include <c10/cuda/CUDAGuard.h>\n#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t)               \\\n  static_assert(std::is_same<scalar_t, half>::value ||          \\\n                    std::is_same<scalar_t, nv_bfloat16>::value, \\\n                \"only float16 and bfloat16 are supported\");\n\ntemplate <typename T>\ninline std::string str(T x) {\n  return std::to_string(x);\n}\n\nnamespace gptq_marlin {\n\n#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined(__HIP_PLATFORM_AMD__)\n\n__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,\n                                    int const* __restrict__ perm_int_ptr,\n                                    int4* __restrict__ out_int4_ptr, int size_m,\n                                    int size_k, int block_rows) {}\n\ntemplate <typename scalar_t,          // compute dtype, half or nv_bfloat16\n          const int num_bits,         // number of bits used for weights\n          const int threads,          // number of threads in a threadblock\n          const int thread_m_blocks,  // number of 
16x16 blocks in the m\n                                      // dimension (batchsize) of the\n                                      // threadblock\n          const int thread_n_blocks,  // same for n dimension (output)\n          const int thread_k_blocks,  // same for k dimension (reduction)\n          const int stages,  // number of stages for the async global->shared\n                             // fetch pipeline\n          const bool has_act_order,    // whether act_order is enabled\n          const int group_blocks = -1  // number of consecutive 16x16 blocks\n                                       // with a separate quantization scale\n          >\n__global__ void Marlin(\n    const int4* __restrict__ A,  // fp16 input matrix of shape mxk\n    const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn\n    int4* __restrict__ C,        // fp16 output buffer of shape mxn\n    const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape\n                                          // (k/groupsize)xn\n    const int* __restrict__ g_idx,        // int32 group indices of shape k\n    int num_groups,  // number of scale groups per output channel\n    int prob_m,      // batch dimension m\n    int prob_n,      // output dimension n\n    int prob_k,      // reduction dimension k\n    int* locks       // extra global storage for barrier synchronization\n) {}\n\n}  // namespace gptq_marlin\n\ntorch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,\n                               torch::Tensor& b_scales, torch::Tensor& g_idx,\n                               torch::Tensor& perm, torch::Tensor& workspace,\n                               int64_t num_bits, int64_t size_m, int64_t size_n,\n                               int64_t size_k, bool is_k_full) {\n  TORCH_CHECK_NOT_IMPLEMENTED(false,\n                              \"marlin_gemm(..) 
requires CUDA_ARCH >= 8.0\");\n  return torch::empty({1, 1});\n}\n\n#else\n\n// m16n8k16 tensor core mma instruction with fp16 inputs and fp32\n// output/accumulation.\ntemplate <typename scalar_t>\n__device__ inline void mma(const typename ScalarType<scalar_t>::FragA& a_frag,\n                           const typename ScalarType<scalar_t>::FragB& frag_b,\n                           typename ScalarType<scalar_t>::FragC& frag_c) {\n  const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);\n  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);\n  float* c = reinterpret_cast<float*>(&frag_c);\n  if constexpr (std::is_same<scalar_t, half>::value) {\n    asm volatile(\n        \"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \"\n        \"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\\n\"\n        : \"=f\"(c[0]), \"=f\"(c[1]), \"=f\"(c[2]), \"=f\"(c[3])\n        : \"r\"(a[0]), \"r\"(a[1]), \"r\"(a[2]), \"r\"(a[3]), \"r\"(b[0]), \"r\"(b[1]),\n          \"f\"(c[0]), \"f\"(c[1]), \"f\"(c[2]), \"f\"(c[3]));\n  } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {\n    asm volatile(\n        \"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \"\n        \"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\\n\"\n        : \"=f\"(c[0]), \"=f\"(c[1]), \"=f\"(c[2]), \"=f\"(c[3])\n        : \"r\"(a[0]), \"r\"(a[1]), \"r\"(a[2]), \"r\"(a[3]), \"r\"(b[0]), \"r\"(b[1]),\n          \"f\"(c[0]), \"f\"(c[1]), \"f\"(c[2]), \"f\"(c[3]));\n  } else {\n    STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);\n  }\n}\n\n// Instruction for loading a full 16x16 matrix fragment of operand A from shared\n// memory, directly in tensor core layout.\ntemplate <typename scalar_t>\n__device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA& frag_a,\n                             const void* smem_ptr) {\n  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);\n  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));\n  asm 
volatile(\"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\\n\"\n               : \"=r\"(a[0]), \"=r\"(a[1]), \"=r\"(a[2]), \"=r\"(a[3])\n               : \"r\"(smem));\n}\n\n// Lookup-table based 3-input logical operation; explicitly used for\n// dequantization as the compiler does not seem to automatically recognize it in\n// all cases.\ntemplate <int lut>\n__device__ inline int lop3(int a, int b, int c) {\n  int res;\n  asm volatile(\"lop3.b32 %0, %1, %2, %3, %4;\\n\"\n               : \"=r\"(res)\n               : \"r\"(a), \"r\"(b), \"r\"(c), \"n\"(lut));\n  return res;\n}\n\n// Constructs destination register by taking bytes from 2 sources (based on\n// mask)\ntemplate <int start_byte, int mask>\n__device__ inline uint32_t prmt(uint32_t a) {\n  uint32_t res;\n  asm volatile(\"prmt.b32 %0, %1, %2, %3;\\n\"\n               : \"=r\"(res)\n               : \"r\"(a), \"n\"(start_byte), \"n\"(mask));\n  return res;\n}\n\n// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16\n// values. 
We mostly follow the strategy in the link below, with some small\n// changes:\n// - FP16:\n// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287\n// - BF16:\n// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385\ntemplate <typename scalar_t>\n__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit(int q) {\n  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);\n}\n\ntemplate <>\n__device__ inline typename ScalarType<half>::FragB dequant_4bit<half>(int q) {\n  const int LO = 0x000f000f;\n  const int HI = 0x00f000f0;\n  const int EX = 0x64006400;\n  // Guarantee that the `(a & b) | c` operations are LOP3s.\n  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);\n  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);\n  // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point\n  // directly into `SUB` and `ADD`.\n  const int SUB = 0x64086408;\n  const int MUL = 0x2c002c00;\n  const int ADD = 0xd480d480;\n  typename ScalarType<half>::FragB frag_b;\n  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),\n                      *reinterpret_cast<const half2*>(&SUB));\n  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),\n                      *reinterpret_cast<const half2*>(&MUL),\n                      *reinterpret_cast<const half2*>(&ADD));\n  return frag_b;\n}\n\ntemplate <>\n__device__ inline typename ScalarType<nv_bfloat16>::FragB\ndequant_4bit<nv_bfloat16>(int q) {\n  static constexpr uint32_t MASK = 0x000f000f;\n  static constexpr uint32_t EX = 0x43004300;\n\n  // Guarantee that the `(a & b) | c` operations are LOP3s.\n\n  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);\n  q >>= 4;\n  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);\n\n  typename ScalarType<nv_bfloat16>::FragB frag_b;\n  static 
constexpr uint32_t MUL = 0x3F803F80;\n  static constexpr uint32_t ADD = 0xC308C308;\n\n  frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),\n                      *reinterpret_cast<const nv_bfloat162*>(&MUL),\n                      *reinterpret_cast<const nv_bfloat162*>(&ADD));\n  frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),\n                      *reinterpret_cast<const nv_bfloat162*>(&MUL),\n                      *reinterpret_cast<const nv_bfloat162*>(&ADD));\n  return frag_b;\n}\n\n// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or\n// bf16 Reference:\n// - FP16:\n// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85\n// - BF16:\n// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175\ntemplate <typename scalar_t>\n__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit(int q) {\n  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);\n}\n\ntemplate <>\n__device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) {\n  static constexpr uint32_t mask_for_elt_01 = 0x5250;\n  static constexpr uint32_t mask_for_elt_23 = 0x5351;\n  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;\n\n  uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);\n  uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);\n\n  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;\n\n  typename ScalarType<half>::FragB frag_b;\n  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),\n                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));\n  frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),\n                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));\n  return frag_b;\n}\n\ntemplate <>\n__device__ inline 
typename ScalarType<nv_bfloat16>::FragB\ndequant_8bit<nv_bfloat16>(int q) {\n  typename ScalarType<nv_bfloat16>::FragB frag_b;\n\n  float fp32_intermediates[4];\n  uint32_t* fp32_intermediates_casted =\n      reinterpret_cast<uint32_t*>(fp32_intermediates);\n\n  static constexpr uint32_t fp32_base = 0x4B000000;\n  fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);\n  fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);\n  fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);\n  fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);\n\n  fp32_intermediates[0] -= 8388736.f;\n  fp32_intermediates[1] -= 8388736.f;\n  fp32_intermediates[2] -= 8388736.f;\n  fp32_intermediates[3] -= 8388736.f;\n\n  uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);\n  bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],\n                                   fp32_intermediates_casted[1], 0x7632);\n  bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],\n                                   fp32_intermediates_casted[3], 0x7632);\n\n  return frag_b;\n}\n\n// Multiply dequantized values by the corresponding quantization scale; used\n// only for grouped quantization.\ntemplate <typename scalar_t>\n__device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,\n                             typename ScalarType<scalar_t>::FragS& frag_s,\n                             int i) {\n  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;\n  scalar_t2 s =\n      ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_s)[i]);\n  frag_b[0] = __hmul2(frag_b[0], s);\n  frag_b[1] = __hmul2(frag_b[1], s);\n}\n\n// Same as above, but for act_order (each K is multiplied individually)\ntemplate <typename scalar_t>\n__device__ inline void scale4(typename ScalarType<scalar_t>::FragB& frag_b,\n                              typename ScalarType<scalar_t>::FragS& frag_s_1,\n                              typename 
ScalarType<scalar_t>::FragS& frag_s_2,\n                              typename ScalarType<scalar_t>::FragS& frag_s_3,\n                              typename ScalarType<scalar_t>::FragS& frag_s_4,\n                              int i) {\n  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;\n  scalar_t2 s_val_1_2;\n  s_val_1_2.x = reinterpret_cast<scalar_t*>(&frag_s_1)[i];\n  s_val_1_2.y = reinterpret_cast<scalar_t*>(&frag_s_2)[i];\n\n  scalar_t2 s_val_3_4;\n  s_val_3_4.x = reinterpret_cast<scalar_t*>(&frag_s_3)[i];\n  s_val_3_4.y = reinterpret_cast<scalar_t*>(&frag_s_4)[i];\n\n  frag_b[0] = __hmul2(frag_b[0], s_val_1_2);\n  frag_b[1] = __hmul2(frag_b[1], s_val_3_4);\n}\n\n// Given 2 floats multiply by 2 scales (halves)\ntemplate <typename scalar_t>\n__device__ inline void scale_float(float* c,\n                                   typename ScalarType<scalar_t>::FragS& s) {\n  scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s);\n  c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0]));\n  c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1]));\n}\n\n// Wait until barrier reaches `count`, then lock for current threadblock.\n__device__ inline void barrier_acquire(int* lock, int count) {\n  if (threadIdx.x == 0) {\n    int state = -1;\n    do\n      // Guarantee that subsequent writes by this threadblock will be visible\n      // globally.\n      asm volatile(\"ld.global.acquire.gpu.b32 %0, [%1];\\n\"\n                   : \"=r\"(state)\n                   : \"l\"(lock));\n    while (state != count);\n  }\n  __syncthreads();\n}\n\n// Release barrier and increment visitation count.\n__device__ inline void barrier_release(int* lock, bool reset = false) {\n  __syncthreads();\n  if (threadIdx.x == 0) {\n    if (reset) {\n      lock[0] = 0;\n      return;\n    }\n    int val = 1;\n    // Make sure that all writes since acquiring this barrier are visible\n    // globally, while releasing the barrier.\n    asm volatile(\"fence.acq_rel.gpu;\\n\");\n 
   asm volatile(\"red.relaxed.gpu.global.add.s32 [%0], %1;\\n\"\n                 :\n                 : \"l\"(lock), \"r\"(val));\n  }\n}\n\n// For a given \"a\" of size [M,K] performs a permutation of the K columns based\n// on the given \"perm\" indices.\n__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,\n                                    int const* __restrict__ perm_int_ptr,\n                                    int4* __restrict__ out_int4_ptr, int size_m,\n                                    int size_k, int block_rows) {\n  int start_row = block_rows * blockIdx.x;\n  int finish_row = start_row + block_rows;\n  if (finish_row > size_m) {\n    finish_row = size_m;\n  }\n  int cur_block_rows = finish_row - start_row;\n\n  int row_stride = size_k * sizeof(half) / 16;\n\n  auto permute_row = [&](int row) {\n    int iters = size_k / default_threads;\n    int rest = size_k % default_threads;\n\n    int offset = row * row_stride;\n\n    half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);\n    half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);\n\n    int base_k = 0;\n\n    for (int i = 0; i < iters; i++) {\n      int cur_k = base_k + threadIdx.x;\n      int src_pos = perm_int_ptr[cur_k];\n\n      out_half[cur_k] = a_row_half[src_pos];\n\n      base_k += default_threads;\n    }\n\n    if (rest) {\n      if (threadIdx.x < rest) {\n        int cur_k = base_k + threadIdx.x;\n        int src_pos = perm_int_ptr[cur_k];\n\n        out_half[cur_k] = a_row_half[src_pos];\n      }\n    }\n  };\n\n  for (int i = 0; i < cur_block_rows; i++) {\n    int cur_row = start_row + i;\n    if (cur_row < size_m) {\n      permute_row(cur_row);\n    }\n  }\n}\n\ntemplate <typename scalar_t,          // compute dtype, half or nv_bfloat16\n          const int num_bits,         // number of bits used for weights\n          const int threads,          // number of threads in a threadblock\n          const int thread_m_blocks,  // number 
of 16x16 blocks in the m\n                                      // dimension (batchsize) of the\n                                      // threadblock\n          const int thread_n_blocks,  // same for n dimension (output)\n          const int thread_k_blocks,  // same for k dimension (reduction)\n          const int stages,  // number of stages for the async global->shared\n                             // fetch pipeline\n          const bool has_act_order,    // whether act_order is enabled\n          const int group_blocks = -1  // number of consecutive 16x16 blocks\n                                       // with a separate quantization scale\n          >\n__global__ void Marlin(\n    const int4* __restrict__ A,  // fp16 input matrix of shape mxk\n    const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn\n    int4* __restrict__ C,        // fp16 output buffer of shape mxn\n    const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape\n                                          // (k/groupsize)xn\n    const int* __restrict__ g_idx,        // int32 group indices of shape k\n    int num_groups,  // number of scale groups per output channel\n    int prob_m,      // batch dimension m\n    int prob_n,      // output dimension n\n    int prob_k,      // reduction dimension k\n    int* locks       // extra global storage for barrier synchronization\n) {\n  // Each threadblock processes one \"stripe\" of the B matrix with (roughly) the\n  // same size, which might involve multiple column \"slices\" (of width 16 *\n  // `thread_n_blocks`). 
Stripes are defined as shown in the 3x3 matrix 5 SM\n  // example:\n  //   0 1 3\n  //   0 2 3\n  //   1 2 4\n  // While this kind of partitioning makes things somewhat more complicated, it\n  // ensures good utilization of all SMs for many kinds of shapes and GPU\n  // configurations, while requiring as few slow global cross-threadblock\n  // reductions as possible.\n  using Dtype = ScalarType<scalar_t>;\n  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;\n  using FragA = typename ScalarType<scalar_t>::FragA;\n  using FragB = typename ScalarType<scalar_t>::FragB;\n  using FragC = typename ScalarType<scalar_t>::FragC;\n  using FragS = typename ScalarType<scalar_t>::FragS;\n\n  constexpr int pack_factor = 32 / num_bits;\n\n  // For larger GEMMs we run multiple batchsize 64 versions in parallel for a\n  // better partitioning with fewer reductions\n  int parallel = 1;\n  if (prob_m > 16 * thread_m_blocks) {\n    parallel = prob_m / (16 * thread_m_blocks);\n    prob_m = 16 * thread_m_blocks;\n  }\n\n  int k_tiles = prob_k / 16 / thread_k_blocks;\n  int n_tiles = prob_n / 16 / thread_n_blocks;\n  int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x);\n\n  if constexpr (!has_act_order && group_blocks != -1) {\n    if (group_blocks >= thread_k_blocks) {\n      // Ensure that the number of tiles in each stripe is a multiple of the\n      // groupsize; this avoids an annoying special case where a stripe starts\n      // in the middle of a group.\n      iters = (group_blocks / thread_k_blocks) *\n              div_ceil(iters, (group_blocks / thread_k_blocks));\n    }\n  }\n\n  int slice_row = (iters * blockIdx.x) % k_tiles;\n  int slice_col_par = (iters * blockIdx.x) / k_tiles;\n  int slice_col = slice_col_par;\n  int slice_iters;  // number of threadblock tiles in the current slice\n  int slice_count =\n      0;          // total number of active threadblocks in the current slice\n  int slice_idx;  // index of threadblock in current slice; numbered bottom 
to\n                  // top\n\n  // We can easily implement parallel problem execution by just remapping\n  // indices and advancing global pointers\n  if (slice_col_par >= n_tiles) {\n    A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;\n    C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;\n    locks += (slice_col_par / n_tiles) * n_tiles;\n    slice_col = slice_col_par % n_tiles;\n  }\n\n  // Compute all information about the current slice which is required for\n  // synchronization.\n  auto init_slice = [&]() {\n    slice_iters =\n        iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);\n    if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;\n    if (slice_iters == 0) return;\n    if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;\n    slice_count = 1;\n    slice_idx = 0;\n    int col_first = iters * div_ceil(k_tiles * slice_col_par, iters);\n    if (col_first <= k_tiles * (slice_col_par + 1)) {\n      int col_off = col_first - k_tiles * slice_col_par;\n      slice_count = div_ceil(k_tiles - col_off, iters);\n      if (col_off > 0) slice_count++;\n      int delta_first = iters * blockIdx.x - col_first;\n      if (delta_first < 0 || (col_off == 0 && delta_first == 0))\n        slice_idx = slice_count - 1;\n      else {\n        slice_idx = slice_count - 1 - delta_first / iters;\n        if (col_off > 0) slice_idx--;\n      }\n    }\n    if (slice_col == n_tiles) {\n      A += 16 * thread_m_blocks * prob_k / 8;\n      C += 16 * thread_m_blocks * prob_n / 8;\n      locks += n_tiles;\n      slice_col = 0;\n    }\n  };\n  init_slice();\n\n  // A sizes/strides\n\n  // stride of the A matrix in global memory\n  int a_gl_stride = prob_k / 8;\n  // stride of an A matrix tile in shared memory\n  constexpr int a_sh_stride = 16 * thread_k_blocks / 8;\n  // delta between subsequent A tiles in global memory\n  constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8;\n  // 
between subsequent accesses within a tile\n  int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);\n  // between shared memory writes\n  constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);\n  // between shared memory tile reads\n  constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));\n  // within a shared memory tile\n  constexpr int a_sh_rd_delta_i = a_sh_stride * 16;\n  // overall size of a tile\n  constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);\n  // number of shared write iterations for a tile\n  constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta);\n\n  // B sizes/strides\n  int b_gl_stride = 16 * prob_n / (pack_factor * 4);\n  constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;\n  constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2;\n  constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;\n\n  int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;\n  int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);\n  constexpr int b_sh_wr_delta = threads * b_thread_vecs;\n  constexpr int b_sh_rd_delta = threads * b_thread_vecs;\n  constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;\n  constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;\n\n  // Scale sizes/strides without act_order\n  int s_gl_stride = prob_n / 8;\n  constexpr int s_sh_stride = 16 * thread_n_blocks / 8;\n  constexpr int s_tb_groups =\n      !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks\n          ? thread_k_blocks / group_blocks\n          : 1;\n  constexpr int s_sh_stage = s_tb_groups * s_sh_stride;\n  int s_gl_rd_delta = s_gl_stride;\n\n  // Scale size/strides with act_order\n  constexpr int tb_k = 16 * thread_k_blocks;\n  constexpr int g_idx_stage = has_act_order ? 
(tb_k * sizeof(int)) / 16 : 0;\n  // constexpr int act_s_row_stride      = 1;\n  // int           act_s_col_stride      = act_s_row_stride * num_groups;\n  int act_s_col_stride = 1;\n  int act_s_col_warp_stride = act_s_col_stride * 8;\n  int tb_n_warps = thread_n_blocks / 4;\n  int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;\n\n  // Global A read index of current thread.\n  int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +\n                (threadIdx.x % a_gl_rd_delta_o);\n  a_gl_rd += a_gl_rd_delta_o * slice_row;\n  // Shared write index of current thread.\n  int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +\n                (threadIdx.x % a_gl_rd_delta_o);\n  // Shared read index.\n  int a_sh_rd =\n      a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;\n  a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));\n\n  int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +\n                (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;\n  b_gl_rd += b_sh_stride * slice_col;\n  b_gl_rd += b_gl_rd_delta_o * slice_row;\n  int b_sh_wr = threadIdx.x * b_thread_vecs;\n  int b_sh_rd = threadIdx.x * b_thread_vecs;\n\n  // For act_order\n  constexpr int k_iter_size = tb_k / b_sh_wr_iters;\n  int slice_k_start = tb_k * slice_row;\n  int slice_k_finish = slice_k_start + tb_k * slice_iters;\n  int slice_k_start_shared_fetch = slice_k_start;\n  int slice_n_offset = act_s_col_tb_stride * slice_col;\n\n  // No act_order\n  int s_gl_rd;\n  if constexpr (!has_act_order) {\n    if constexpr (group_blocks == -1) {\n      s_gl_rd = s_sh_stride * slice_col + threadIdx.x;\n    } else {\n      s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +\n                s_sh_stride * slice_col + threadIdx.x;\n    }\n  }\n  int s_sh_wr = threadIdx.x;\n  bool s_sh_wr_pred = threadIdx.x < s_sh_stride;\n\n  // We use a different scale layout for grouped and column-wise quantization as\n  // we scale a 
`half2` tile in column-major layout in the former and in\n  // row-major in the latter case.\n  int s_sh_rd;\n  if constexpr (group_blocks != -1)\n    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +\n              (threadIdx.x % 32) / 4;\n  else\n    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +\n              (threadIdx.x % 32) % 4;\n\n  // Precompute which thread should not read memory in which iterations; this is\n  // needed if there are more threads than required for a certain tilesize or\n  // when the batchsize is not a multiple of 16.\n  bool a_sh_wr_pred[a_sh_wr_iters];\n  #pragma unroll\n  for (int i = 0; i < a_sh_wr_iters; i++)\n    a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;\n\n  // To ensure that writing and reading A tiles to/from shared memory, the\n  // latter in fragment format, is fully bank conflict free, we need to use a\n  // rather fancy XOR-based layout. The key here is that neither reads nor\n  // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the\n  // same shared memory banks. 
Further, it seems (based on NSight-Compute) that\n  // each warp must also write a consecutive memory segment?\n  auto transform_a = [&](int i) {\n    int row = i / a_gl_rd_delta_o;\n    return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;\n  };\n  // Since the computation of this remapping is non-trivial and, due to our main\n  // loop unrolls, all shared memory accesses are static, we simply precompute\n  // both transformed reads and writes.\n  int a_sh_wr_trans[a_sh_wr_iters];\n  #pragma unroll\n  for (int i = 0; i < a_sh_wr_iters; i++)\n    a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);\n  int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];\n  #pragma unroll\n  for (int i = 0; i < b_sh_wr_iters; i++) {\n  #pragma unroll\n    for (int j = 0; j < thread_m_blocks; j++)\n      a_sh_rd_trans[i][j] =\n          transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);\n  }\n\n  // Since B-accesses have non-constant stride they have to be computed at\n  // runtime; we break dependencies between subsequent accesses within a tile by\n  // maintaining multiple pointers (we have enough registers), a tiny\n  // optimization.\n  const int4* B_ptr[b_sh_wr_iters];\n  #pragma unroll\n  for (int i = 0; i < b_sh_wr_iters; i++)\n    B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;\n\n  extern __shared__ int4 sh[];\n  // Shared memory storage for global fetch pipelines.\n  int4* sh_a = sh;\n  int4* sh_b = sh_a + (stages * a_sh_stage);\n  int4* sh_g_idx = sh_b + (stages * b_sh_stage);\n  int4* sh_s = sh_g_idx + (stages * g_idx_stage);\n\n  // Register storage for double buffer of shared memory reads.\n  FragA frag_a[2][thread_m_blocks];\n  I4 frag_b_quant[2][b_thread_vecs];\n  FragC frag_c[thread_m_blocks][4][2];\n  FragS frag_s[2][4];         // No act-order\n  FragS act_frag_s[2][4][4];  // For act-order\n\n  // Zero accumulators.\n  auto zero_accums = [&]() {\n  #pragma unroll\n    for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)\n      
reinterpret_cast<float*>(frag_c)[i] = 0;\n  };\n\n  int sh_first_group_id = -1;\n  int sh_num_groups = -1;\n  constexpr int sh_max_num_groups = 32;\n\n  auto fetch_scales_to_shared = [&](bool is_async, int first_group_id,\n                                    int last_group_id) {\n    sh_first_group_id = first_group_id;\n    sh_num_groups = last_group_id - first_group_id + 1;\n\n    if (sh_num_groups < sh_max_num_groups) {\n      sh_num_groups = sh_max_num_groups;\n    }\n\n    if (sh_first_group_id + sh_num_groups > num_groups) {\n      sh_num_groups = num_groups - sh_first_group_id;\n    }\n\n    int row_offset = first_group_id * s_gl_stride;\n\n    if (is_async) {\n      for (int i = 0; i < sh_num_groups; i++) {\n        if (threadIdx.x < s_sh_stride) {\n          cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x],\n                         &scales_ptr[row_offset + (i * s_gl_stride) +\n                                     slice_n_offset + threadIdx.x]);\n        }\n      }\n    } else {\n      for (int i = 0; i < sh_num_groups; i++) {\n        if (threadIdx.x < s_sh_stride) {\n          sh_s[(i * s_sh_stride) + threadIdx.x] =\n              scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset +\n                         threadIdx.x];\n        }\n      }\n    }\n  };\n  // Asynchronously fetch the next A, B and s tile from global to the next\n  // shared memory pipeline location.\n  auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {\n    if (pred) {\n      int4* sh_a_stage = sh_a + a_sh_stage * pipe;\n  #pragma unroll\n      for (int i = 0; i < a_sh_wr_iters; i++) {\n        cp_async4_pred(\n            &sh_a_stage[a_sh_wr_trans[i]],\n            &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],\n            a_sh_wr_pred[i]);\n      }\n      int4* sh_b_stage = sh_b + b_sh_stage * pipe;\n  #pragma unroll\n      for (int i = 0; i < b_sh_wr_iters; i++) {\n  #pragma unroll\n        for (int j = 0; j < b_thread_vecs; j++) {\n        
  cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);\n        }\n\n        B_ptr[i] += b_gl_rd_delta_o;\n      }\n\n      if constexpr (has_act_order) {\n        // Fetch g_idx thread-block portion\n        int full_pipe = a_off;\n        int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe;\n        if (cur_k < prob_k && cur_k < slice_k_finish) {\n          int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;\n\n          int4 const* cur_g_idx_stage_ptr =\n              reinterpret_cast<int4 const*>(&g_idx[cur_k]);\n\n          if (threadIdx.x < g_idx_stage) {\n            cp_async4_pred(&sh_g_idx_stage[threadIdx.x],\n                           &cur_g_idx_stage_ptr[threadIdx.x]);\n          }\n        }\n      } else {\n        if constexpr (group_blocks != -1) {\n          int4* sh_s_stage = sh_s + s_sh_stage * pipe;\n\n          if constexpr (group_blocks >= thread_k_blocks) {\n            // Only fetch scales if this tile starts a new group\n            if (pipe % (group_blocks / thread_k_blocks) == 0) {\n              if (s_sh_wr_pred) {\n                cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);\n              }\n              s_gl_rd += s_gl_rd_delta;\n            }\n          } else {\n            for (int i = 0; i < s_tb_groups; i++) {\n              if (s_sh_wr_pred) {\n                cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr],\n                          &scales_ptr[s_gl_rd]);\n              }\n              s_gl_rd += s_gl_rd_delta;\n            }\n          }\n        }\n      }\n    }\n    // Insert a fence even when we are winding down the pipeline to ensure that\n    // waiting is also correct at this point.\n    cp_async_fence();\n  };\n\n  // Wait until the next thread tile has been loaded to shared memory.\n  auto wait_for_stage = [&]() {\n    // We only have `stages - 2` active fetches since we are double buffering\n    // and can only issue the next fetch when it is guaranteed that the previous\n    
// shared memory load is fully complete (as it may otherwise be\n    // overwritten).\n    cp_async_wait<stages - 2>();\n    __syncthreads();\n  };\n\n  // Load the next sub-tile from the current location in the shared memory pipe\n  // into the current register buffer.\n  auto fetch_to_registers = [&](int k, int pipe) {\n    int4* sh_a_stage = sh_a + a_sh_stage * pipe;\n  #pragma unroll\n    for (int i = 0; i < thread_m_blocks; i++)\n      ldsm4<scalar_t>(frag_a[k % 2][i],\n                      &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);\n    int4* sh_b_stage = sh_b + b_sh_stage * pipe;\n\n  #pragma unroll\n    for (int i = 0; i < b_thread_vecs; i++) {\n      frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(\n          &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);\n    }\n  };\n\n  bool is_same_group[stages];\n  int same_group_id[stages];\n\n  auto init_same_group = [&](int pipe) {\n    if constexpr (!has_act_order) {\n      is_same_group[pipe] = false;\n      same_group_id[pipe] = 0;\n      return;\n    }\n\n    int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;\n    int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);\n\n    int group_id_1 = sh_g_idx_int_ptr[0];\n    int group_id_2 = sh_g_idx_int_ptr[tb_k - 1];\n\n    is_same_group[pipe] = group_id_1 == group_id_2;\n    same_group_id[pipe] = group_id_1;\n  };\n\n  auto fetch_scales_to_registers = [&](int k, int full_pipe) {\n    int pipe = full_pipe % stages;\n\n    if constexpr (!has_act_order) {\n      // No act-order case\n      if constexpr (group_blocks != -1) {\n        if constexpr (group_blocks >= thread_k_blocks) {\n          int4* sh_s_stage =\n              sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *\n                                   (pipe / (group_blocks / thread_k_blocks)));\n          reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];\n        } else {\n          int warp_id = threadIdx.x / 32;\n          int n_warps = 
thread_n_blocks / 4;\n\n          int warp_row = warp_id / n_warps;\n\n          int cur_k = warp_row * 16;\n          cur_k += k_iter_size * (k % b_sh_wr_iters);\n\n          int k_blocks = cur_k / 16;\n          int cur_group_id = k_blocks / group_blocks;\n\n          int4* sh_s_stage = sh_s + s_sh_stage * pipe;\n\n          reinterpret_cast<int4*>(&frag_s[k % 2])[0] =\n              sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];\n        }\n      }\n\n      return;\n    }\n\n    // Act-order case\n\n    // Determine K of the \"current\" thread-block\n    int cur_k = slice_k_start + tb_k * full_pipe;\n    if (cur_k >= prob_k || cur_k >= slice_k_finish) {\n      return;\n    }\n\n    // Reset (to current thread-block) since we read g_idx portion from the\n    // shared memory\n    cur_k = 0;\n\n    // Progress to current iteration\n    cur_k += k_iter_size * (k % b_sh_wr_iters);\n\n    // Determine \"position\" inside the thread-block (based on warp and\n    // thread-id)\n    int warp_id = threadIdx.x / 32;\n    int n_warps =\n        thread_n_blocks / 4;  // Each warp processes 4 16-size tiles over N\n\n    int warp_row = warp_id / n_warps;\n    int warp_col = warp_id % n_warps;\n\n    cur_k += warp_row * 16;\n\n    int th_id = threadIdx.x % 32;\n    cur_k += (th_id % 4) * 2;  // Due to tensor-core layout for fp16 B matrix\n\n    int s_col_shift =\n        /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) +\n        (th_id / 4) * act_s_col_stride;\n\n    if (is_same_group[pipe]) {\n      if (k % 2 == 0) {\n        *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =\n            sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride +\n                 s_col_shift];\n      } else {\n        *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =\n            *(reinterpret_cast<int4*>(&(act_frag_s[(k - 1) % 2][0][0])));\n      }\n\n      for (int i = 1; i < 4; i++) {\n        *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =\n           
 *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0])));\n      }\n      return;\n    }\n\n    int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;\n    int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);\n\n    constexpr int k_frag_offsets[4] = {0, 1, 8,\n                                       9};  // Tensor core offsets per thread\n\n  #pragma unroll\n    for (int i = 0; i < 4; i++) {\n      int actual_k = cur_k + k_frag_offsets[i];\n\n      int group_id = sh_g_idx_int_ptr[actual_k];\n      int rel_group_id = group_id - sh_first_group_id;\n\n      *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =\n          sh_s[rel_group_id * s_sh_stride + s_col_shift];\n    }\n  };\n\n  // Execute the actual tensor core matmul of a sub-tile.\n  auto matmul = [&](int k) {\n  // We have the m dimension as the inner loop in order to encourage overlapping\n  // dequantization and matmul operations.\n  #pragma unroll\n    for (int j = 0; j < 4; j++) {\n      FragB frag_b0;\n      FragB frag_b1;\n      if constexpr (num_bits == 4) {\n        int b_quant = frag_b_quant[k % 2][0][j];\n        int b_quant_shift = b_quant >> 8;\n\n        frag_b0 = dequant_4bit<scalar_t>(b_quant);\n        frag_b1 = dequant_4bit<scalar_t>(b_quant_shift);\n\n      } else {\n        int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);\n        int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];\n        int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];\n\n        frag_b0 = dequant_8bit<scalar_t>(b_quant_0);\n        frag_b1 = dequant_8bit<scalar_t>(b_quant_1);\n      }\n\n      // Apply scale to frag_b0\n      if constexpr (has_act_order) {\n        scale4<scalar_t>(frag_b0, act_frag_s[k % 2][0][j],\n                         act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j],\n                         act_frag_s[k % 2][3][j], 0);\n      } else {\n        if constexpr (group_blocks != -1) {\n          scale<scalar_t>(frag_b0, frag_s[k % 2][j], 0);\n        }\n      }\n\n      // Apply 
scale to frag_b1\n      if constexpr (has_act_order) {\n        scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j],\n                         act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j],\n                         act_frag_s[k % 2][3][j], 1);\n\n      } else {\n        if constexpr (group_blocks != -1) {\n          scale<scalar_t>(frag_b1, frag_s[k % 2][j], 1);\n        }\n      }\n\n  #pragma unroll\n      for (int i = 0; i < thread_m_blocks; i++) {\n        mma<scalar_t>(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);\n        mma<scalar_t>(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);\n      }\n    }\n  };\n\n  // Since we slice across the k dimension of a tile in order to increase the\n  // number of warps while keeping the n dimension of a tile reasonable, we have\n  // multiple warps that accumulate their partial sums of the same output\n  // location, which we must reduce over at the end. We do so in shared memory.\n  auto thread_block_reduce = [&]() {\n    constexpr int red_off = threads / b_sh_stride_threads / 2;\n    if (red_off >= 1) {\n      int red_idx = threadIdx.x / b_sh_stride_threads;\n      constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;\n      constexpr int red_sh_delta = b_sh_stride_threads;\n      int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +\n                      (threadIdx.x % b_sh_stride_threads);\n\n      // Parallel logarithmic shared memory reduction. 
We make sure to avoid any\n      // unnecessary read or write iterations, e.g., for two warps we write only\n      // once by warp 1 and read only once by warp 0.\n\n  #pragma unroll\n      for (int m_block = 0; m_block < thread_m_blocks; m_block++) {\n  #pragma unroll\n        for (int i = red_off; i > 0; i /= 2) {\n          if (i <= red_idx && red_idx < 2 * i) {\n  #pragma unroll\n            for (int j = 0; j < 4 * 2; j++) {\n              int red_sh_wr =\n                  red_sh_delta * j + (red_sh_rd - red_sh_stride * i);\n              if (i < red_off) {\n                float* c_rd =\n                    reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);\n                float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);\n  #pragma unroll\n                for (int k = 0; k < 4; k++)\n                  reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=\n                      c_rd[k] + c_wr[k];\n              }\n              sh[red_sh_wr] =\n                  reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];\n            }\n          }\n          __syncthreads();\n        }\n        if (red_idx == 0) {\n  #pragma unroll\n          for (int i = 0; i < 4 * 2; i++) {\n            float* c_rd =\n                reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);\n  #pragma unroll\n            for (int j = 0; j < 4; j++)\n              reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=\n                  c_rd[j];\n          }\n        }\n        __syncthreads();\n      }\n    }\n  };\n\n  // Since multiple threadblocks may process parts of the same column slice, we\n  // finally have to globally reduce over the results. 
As the striped\n  // partitioning minimizes the number of such reductions and our outputs are\n  // usually rather small, we perform this reduction serially in L2 cache.\n  auto global_reduce = [&](bool first = false, bool last = false) {\n    // We are very careful here to reduce directly in the output buffer to\n    // maximize L2 cache utilization in this step. To do this, we write out\n    // results in FP16 (but still reduce with FP32 compute).\n    constexpr int active_threads = 32 * thread_n_blocks / 4;\n    if (threadIdx.x < active_threads) {\n      int c_gl_stride = prob_n / 8;\n      int c_gl_wr_delta_o = 8 * c_gl_stride;\n      int c_gl_wr_delta_i = 4 * (active_threads / 32);\n      int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +\n                    4 * (threadIdx.x / 32) + threadIdx.x % 4;\n      c_gl_wr += (2 * thread_n_blocks) * slice_col;\n      constexpr int c_sh_wr_delta = active_threads;\n      int c_sh_wr = threadIdx.x;\n\n      int row = (threadIdx.x % 32) / 4;\n\n      if (!first) {\n  // Interestingly, doing direct global accesses here really seems to mess up\n  // the compiler and lead to slowdowns, hence we also use async-copies even\n  // though these fetches are not actually asynchronous.\n  #pragma unroll\n        for (int i = 0; i < thread_m_blocks * 4; i++) {\n          cp_async4_pred(\n              &sh[c_sh_wr + c_sh_wr_delta * i],\n              &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +\n                 c_gl_wr_delta_i * (i % 2)],\n              i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);\n        }\n        cp_async_fence();\n        cp_async_wait<0>();\n      }\n\n  #pragma unroll\n      for (int i = 0; i < thread_m_blocks * 4; i++) {\n        if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {\n          if (!first) {\n            int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];\n  #pragma unroll\n            for (int j = 0; j < 2 * 4; j++) {\n              reinterpret_cast<float*>(\n          
        &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=\n                  Dtype::num2float(reinterpret_cast<scalar_t*>(&c_red)[j]);\n            }\n          }\n          if (!last) {\n            int4 c;\n  #pragma unroll\n            for (int j = 0; j < 2 * 4; j++) {\n              reinterpret_cast<scalar_t*>(&c)[j] =\n                  Dtype::float2num(reinterpret_cast<float*>(\n                      &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);\n            }\n            C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =\n                c;\n          }\n        }\n      }\n    }\n  };\n\n  // Write out the reduced final result in the correct layout. We only actually\n  // reshuffle matrix fragments in this step, the reduction above is performed\n  // in fragment layout.\n  auto write_result = [&]() {\n    int c_gl_stride = prob_n / 8;\n    constexpr int c_sh_stride = 2 * thread_n_blocks + 1;\n    int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));\n    constexpr int c_sh_rd_delta =\n        c_sh_stride * (threads / (2 * thread_n_blocks));\n\n    int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +\n                  (threadIdx.x % (2 * thread_n_blocks));\n    c_gl_wr += (2 * thread_n_blocks) * slice_col;\n    int c_sh_wr =\n        (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;\n    c_sh_wr += 32 * (threadIdx.x / 32);\n    int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +\n                  (threadIdx.x % (2 * thread_n_blocks));\n\n    int c_gl_wr_end = c_gl_stride * prob_m;\n\n    // We first reorder in shared memory to guarantee the most efficient final\n    // global write patterns\n    auto write = [&](int idx, float c0, float c1, FragS& s) {\n      scalar_t2 res =\n          Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));\n\n      // For per-column quantization we finally apply the scale here (only for\n      // 4-bit)\n      if constexpr 
(!has_act_order && group_blocks == -1 && num_bits == 4) {\n        res = __hmul2(res, s[0]);\n      }\n\n      ((scalar_t2*)sh)[idx] = res;\n    };\n\n    if (threadIdx.x / 32 < thread_n_blocks / 4) {\n  #pragma unroll\n      for (int i = 0; i < thread_m_blocks; i++) {\n  #pragma unroll\n        for (int j = 0; j < 4; j++) {\n          int wr = c_sh_wr + 8 * j;\n          write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],\n                frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);\n          write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],\n                frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);\n          write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],\n                frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);\n          write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],\n                frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);\n        }\n        c_sh_wr += 16 * (4 * c_sh_stride);\n      }\n    }\n    __syncthreads();\n\n  #pragma unroll\n    for (int i = 0;\n         i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));\n         i++) {\n      if (c_gl_wr < c_gl_wr_end) {\n        C[c_gl_wr] = sh[c_sh_rd];\n        c_gl_wr += c_gl_wr_delta;\n        c_sh_rd += c_sh_rd_delta;\n      }\n    }\n  };\n\n  // Start global fetch and register load pipelines.\n  auto start_pipes = [&]() {\n\n  #pragma unroll\n    for (int i = 0; i < stages - 1; i++) {\n      if (has_act_order && i == 0) {\n        int last_g_idx = slice_k_start + stages * tb_k * 2;\n        if (last_g_idx >= prob_k) {\n          last_g_idx = prob_k - 1;\n        }\n        fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);\n      }\n      fetch_to_shared(i, i, i < slice_iters);\n    }\n\n    zero_accums();\n    wait_for_stage();\n    init_same_group(0);\n    fetch_to_registers(0, 0);\n    fetch_scales_to_registers(0, 0);\n    a_gl_rd += a_gl_rd_delta_o * (stages - 1);\n    
slice_k_start_shared_fetch += tb_k * (stages - 1);\n  };\n  if (slice_iters) {\n    start_pipes();\n  }\n\n  // Main loop.\n  while (slice_iters) {\n    // We unroll over both the global fetch and the register load pipeline to\n    // ensure all shared memory accesses are static. Note that both pipelines\n    // have even length meaning that the next iteration will always start at\n    // index 0.\n\n  #pragma unroll\n    for (int pipe = 0; pipe < stages;) {\n  #pragma unroll\n      for (int k = 0; k < b_sh_wr_iters; k++) {\n        fetch_to_registers(k + 1, pipe % stages);\n        fetch_scales_to_registers(k + 1, pipe);\n        if (k == b_sh_wr_iters - 2) {\n          fetch_to_shared((pipe + stages - 1) % stages, pipe,\n                          slice_iters >= stages);\n          pipe++;\n          wait_for_stage();\n          init_same_group(pipe % stages);\n        }\n        matmul(k);\n      }\n      slice_iters--;\n      if (slice_iters == 0) {\n        break;\n      }\n    }\n\n    a_gl_rd += a_gl_rd_delta_o * stages;\n    slice_k_start += tb_k * stages;\n    slice_k_start_shared_fetch += tb_k * stages;\n\n    if constexpr (has_act_order) {\n      int first_group_id = g_idx[slice_k_start];\n      int last_g_idx = slice_k_start + stages * tb_k * 2;\n      if (last_g_idx >= prob_k) {\n        last_g_idx = prob_k - 1;\n      }\n      int last_group_id = g_idx[last_g_idx];\n      if (last_group_id >= sh_first_group_id + sh_num_groups) {\n        fetch_scales_to_shared(false, first_group_id, last_group_id);\n        __syncthreads();\n      }\n    }\n\n    // Process results and, if necessary, proceed to the next column slice.\n    // While this pattern may not be the most readable, other ways of writing\n    // the loop seemed to yield noticeably worse performance after compilation.\n    if (slice_iters == 0) {\n      cp_async_wait<0>();\n      bool last = slice_idx == slice_count - 1;\n      // For per-column scales, we only fetch them here in the final step 
before\n      // write-out\n      if constexpr (!has_act_order && group_blocks == -1) {\n        if constexpr (num_bits == 8) {\n          if (s_sh_wr_pred) {\n            cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);\n          }\n          cp_async_fence();\n        } else {\n          if (last) {\n            if (s_sh_wr_pred) {\n              cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);\n            }\n            cp_async_fence();\n          }\n        }\n      }\n\n      thread_block_reduce();\n      if constexpr (!has_act_order && group_blocks == -1) {\n        if constexpr (num_bits == 8) {\n          cp_async_wait<0>();\n          __syncthreads();\n          if (threadIdx.x / 32 < thread_n_blocks / 4) {\n            reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];\n            reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];\n          }\n\n        } else {\n          if (last) {\n            cp_async_wait<0>();\n            __syncthreads();\n            if (threadIdx.x / 32 < thread_n_blocks / 4) {\n              reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];\n              reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];\n            }\n          }\n        }\n      }\n\n      // For 8-bit channelwise, we apply the scale before the global reduction\n      // that converts the fp32 results to fp16 (so that we avoid possible\n      // overflow in fp16)\n      if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) {\n        if (threadIdx.x / 32 < thread_n_blocks / 4) {\n  #pragma unroll\n          for (int i = 0; i < thread_m_blocks; i++) {\n  #pragma unroll\n            for (int j = 0; j < 4; j++) {\n              scale_float<scalar_t>(\n                  reinterpret_cast<float*>(&frag_c[i][j][0][0]),\n                  frag_s[j / 2][2 * (j % 2) + 0]);\n              scale_float<scalar_t>(\n                  reinterpret_cast<float*>(&frag_c[i][j][0][2]),\n                  frag_s[j / 2][2 * (j % 2) + 
0]);\n\n              scale_float<scalar_t>(\n                  reinterpret_cast<float*>(&frag_c[i][j][1][0]),\n                  frag_s[j / 2][2 * (j % 2) + 1]);\n              scale_float<scalar_t>(\n                  reinterpret_cast<float*>(&frag_c[i][j][1][2]),\n                  frag_s[j / 2][2 * (j % 2) + 1]);\n            }\n          }\n        }\n      }\n\n      if (slice_count > 1) {  // only globally reduce if there is more than one\n                              // block in a slice\n        barrier_acquire(&locks[slice_col], slice_idx);\n        global_reduce(slice_idx == 0, last);\n        barrier_release(&locks[slice_col], last);\n      }\n      if (last)  // only the last block in a slice actually writes the result\n        write_result();\n      slice_row = 0;\n      slice_col_par++;\n      slice_col++;\n      init_slice();\n      if (slice_iters) {\n        a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +\n                  (threadIdx.x % a_gl_rd_delta_o);\n  #pragma unroll\n        for (int i = 0; i < b_sh_wr_iters; i++)\n          B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;\n        if (slice_col == 0) {\n  #pragma unroll\n          for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;\n        }\n\n        // Update slice k/n for scales loading\n        if constexpr (has_act_order) {\n          slice_k_start = tb_k * slice_row;\n          slice_k_finish = slice_k_start + tb_k * slice_iters;\n          slice_k_start_shared_fetch = slice_k_start;\n          slice_n_offset = act_s_col_tb_stride * slice_col;\n\n        } else {\n          s_gl_rd = s_sh_stride * slice_col + threadIdx.x;\n        }\n\n        start_pipes();\n      }\n    }\n  }\n}\n\n  #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS,                \\\n                    THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \\\n    else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS &&     \\\n             
thread_n_blocks == THREAD_N_BLOCKS &&                             \\\n             thread_k_blocks == THREAD_K_BLOCKS &&                             \\\n             has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \\\n             num_threads == NUM_THREADS) {                                     \\\n      cudaFuncSetAttribute(                                                    \\\n          Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS,             \\\n                 THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \\\n                 GROUP_BLOCKS>,                                                \\\n          cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);        \\\n      Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS,                 \\\n             THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER,     \\\n             GROUP_BLOCKS><<<blocks, NUM_THREADS, max_shared_mem, stream>>>(   \\\n          A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n,   \\\n          prob_k, locks);                                                      \\\n    }\n\ntypedef struct {\n  int thread_k;\n  int thread_n;\n  int num_threads;\n} thread_config_t;\n\ntypedef struct {\n  int max_m_blocks;\n  thread_config_t tb_cfg;\n} exec_config_t;\n\nthread_config_t small_batch_thread_configs[] = {\n    // Ordered by priority\n\n    // thread_k, thread_n, num_threads\n    {128, 128, 256},\n    {64, 128, 128},\n    {128, 64, 128},\n};\n\nthread_config_t large_batch_thread_configs[] = {\n    // Ordered by priority\n\n    // thread_k, thread_n, num_threads\n    {64, 256, 256},\n    {64, 128, 128},\n    {128, 64, 128},\n\n};\n\nint get_scales_cache_size(thread_config_t const& th_config, int prob_m,\n                          int prob_n, int prob_k, int num_bits, int group_size,\n                          bool has_act_order, bool is_k_full) {\n  bool cache_scales_chunk = has_act_order && !is_k_full;\n\n  
int tb_n = th_config.thread_n;\n  int tb_k = th_config.thread_k;\n\n  // Get max scale groups per thread-block\n  int tb_groups;\n  if (group_size == -1) {\n    tb_groups = 1;\n  } else if (group_size == 0) {\n    tb_groups = div_ceil(tb_k, 32);  // Worst case is 32 group size\n  } else {\n    tb_groups = div_ceil(tb_k, group_size);\n  }\n\n  if (cache_scales_chunk) {\n    int load_groups =\n        tb_groups * pipe_stages * 2;     // Chunk size is 2x pipeline over dim K\n    load_groups = max(load_groups, 32);  // We load at least 32 scale groups\n    return load_groups * tb_n * 2;\n\n  } else {\n    int tb_scales = tb_groups * tb_n * 2;\n\n    return tb_scales * pipe_stages;\n  }\n}\n\nbool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,\n                         int prob_m, int prob_n, int prob_k, int num_bits,\n                         int scales_cache_size, int max_shared_mem) {\n  int pack_factor = 32 / num_bits;\n\n  // Get B size\n  int tb_k = th_config.thread_k;\n  int tb_n = th_config.thread_n;\n\n  int b_size = (tb_k * tb_n / pack_factor) * 4;\n\n  // Get A size\n  int m_blocks = div_ceil(prob_m, 16);\n  int tb_max_m = 16;\n\n  while (true) {\n    if (m_blocks >= max_m_blocks) {\n      tb_max_m *= max_m_blocks;\n      break;\n    }\n\n    max_m_blocks--;\n    if (max_m_blocks == 0) {\n      TORCH_CHECK(false, \"Unexpected m_blocks = \", m_blocks);\n    }\n  }\n\n  int a_size = (tb_max_m * tb_k) * 2;\n\n  float pipe_size = (a_size + b_size) * pipe_stages;\n\n  TORCH_CHECK(max_shared_mem / 2 > scales_cache_size);  // Sanity\n\n  return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);\n}\n\nbool is_valid_config(thread_config_t const& th_config, int max_m_blocks,\n                     int prob_m, int prob_n, int prob_k, int num_bits,\n                     int group_size, bool has_act_order, bool is_k_full,\n                     int max_shared_mem) {\n  // Sanity\n  if (th_config.thread_k == -1 || th_config.thread_n == -1 ||\n 
     th_config.num_threads == -1) {\n    return false;\n  }\n\n  // Verify K/N are divisible by thread K/N\n  if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {\n    return false;\n  }\n\n  // Verify min for thread K/N\n  if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {\n    return false;\n  }\n\n  // num_threads must be at least 128 (= 4 warps)\n  if (th_config.num_threads < 128) {\n    return false;\n  }\n\n  //  Determine cache for scales\n  int scales_cache_size =\n      get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,\n                            group_size, has_act_order, is_k_full);\n\n  // Check that pipeline fits into cache\n  if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,\n                           num_bits, scales_cache_size, max_shared_mem)) {\n    return false;\n  }\n\n  return true;\n}\n\nexec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,\n                                      int num_bits, int group_size,\n                                      bool has_act_order, bool is_k_full,\n                                      int max_shared_mem) {\n  int max_m_blocks = 4;\n  while (max_m_blocks > 0) {\n    if (prob_m <= 16) {\n      for (auto th_config : small_batch_thread_configs) {\n        if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,\n                            num_bits, group_size, has_act_order, is_k_full,\n                            max_shared_mem)) {\n          return exec_config_t{max_m_blocks, th_config};\n        }\n      }\n    } else {\n      for (auto th_config : large_batch_thread_configs) {\n        if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,\n                            num_bits, group_size, has_act_order, is_k_full,\n                            max_shared_mem)) {\n          return exec_config_t{max_m_blocks, th_config};\n        }\n      }\n    }\n\n    max_m_blocks--;  
// Process fewer M blocks per invocation to reduce cache\n                     // usage\n  }\n\n  return exec_config_t{0, {-1, -1, -1}};\n}\n\n  #define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS)           \\\n    __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)   \\\n    __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)   \\\n    __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)   \\\n    __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)   \\\n                                                                       \\\n    __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \\\n    __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)  \\\n    __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)  \\\n    __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)  \\\n                                                                       \\\n    __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \\\n    __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)  \\\n    __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)  \\\n    __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)  \\\n                                                                       \\\n    __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \\\n    __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)  \\\n    __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)  \\\n    __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)  \\\n                                                                       \\\n    __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \\\n    __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)  \\\n    __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)  \\\n    __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, 
NUM_THREADS)\n\ntemplate <typename scalar_t>\nvoid marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,\n                     void* g_idx, void* perm, void* a_tmp, int prob_m,\n                     int prob_n, int prob_k, void* workspace, int num_bits,\n                     bool has_act_order, bool is_k_full, int num_groups,\n                     int group_size, int dev, cudaStream_t stream, int thread_k,\n                     int thread_n, int sms, int max_par) {\n  TORCH_CHECK(num_bits == 4 || num_bits == 8,\n              \"num_bits must be 4 or 8. Got = \", num_bits);\n  TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, \"Invalid MNK = [\", prob_m,\n              \", \", prob_n, \", \", prob_k, \"]\");\n\n  int tot_m = prob_m;\n  int tot_m_blocks = div_ceil(tot_m, 16);\n  int pad = 16 * tot_m_blocks - tot_m;\n\n  if (sms == -1) {\n    cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);\n  }\n\n  int max_shared_mem = 0;\n  cudaDeviceGetAttribute(&max_shared_mem,\n                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);\n  TORCH_CHECK(max_shared_mem > 0);\n\n  // Set thread config\n  exec_config_t exec_cfg;\n  if (thread_k != -1 && thread_n != -1) {\n    // User-defined config\n    exec_cfg =\n        exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}};\n  } else {\n    // Auto config\n    exec_cfg =\n        determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size,\n                                has_act_order, is_k_full, max_shared_mem);\n  }\n\n  TORCH_CHECK(exec_cfg.max_m_blocks > 0 &&\n                  is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks,\n                                  prob_m, prob_n, prob_k, num_bits, group_size,\n                                  has_act_order, is_k_full, max_shared_mem),\n              \"Invalid thread config: max_m_blocks = \", exec_cfg.max_m_blocks,\n              \", thread_k = \", exec_cfg.tb_cfg.thread_k,\n              \", thread_n = 
\", exec_cfg.tb_cfg.thread_n,\n              \", num_threads = \", exec_cfg.tb_cfg.num_threads, \" for MKN = [\",\n              prob_m, \", \", prob_k, \", \", prob_n, \"] and num_bits = \", num_bits,\n              \", group_size = \", group_size,\n              \", has_act_order = \", has_act_order, \", is_k_full = \", is_k_full,\n              \", max_shared_mem = \", max_shared_mem);\n\n  int num_threads = exec_cfg.tb_cfg.num_threads;\n  thread_k = exec_cfg.tb_cfg.thread_k;\n  thread_n = exec_cfg.tb_cfg.thread_n;\n\n  int thread_k_blocks = thread_k / 16;\n  int thread_n_blocks = thread_n / 16;\n\n  int blocks = sms;\n\n  TORCH_CHECK(prob_n % thread_n == 0, \"prob_n = \", prob_n,\n              \" is not divisible by thread_n = \", thread_n);\n  TORCH_CHECK(prob_k % thread_k == 0, \"prob_k = \", prob_k,\n              \" is not divisible by thread_k = \", thread_k);\n\n  int group_blocks = 0;\n  if (has_act_order) {\n    if (is_k_full) {\n      TORCH_CHECK(group_size != -1);\n      group_blocks = group_size / 16;\n      TORCH_CHECK(prob_k % group_blocks == 0, \"prob_k = \", prob_k,\n                  \" is not divisible by group_blocks = \", group_blocks);\n    } else {\n      TORCH_CHECK(group_size == 0);\n      group_blocks = 0;\n    }\n\n  } else {\n    if (group_size == -1) {\n      group_blocks = -1;\n    } else {\n      group_blocks = group_size / 16;\n      TORCH_CHECK(prob_k % group_blocks == 0, \"prob_k = \", prob_k,\n                  \" is not divisible by group_blocks = \", group_blocks);\n    }\n  }\n\n  const int4* A_ptr = (const int4*)A;\n  const int4* B_ptr = (const int4*)B;\n  int4* C_ptr = (int4*)C;\n  const int4* s_ptr = (const int4*)s;\n  const int* g_idx_ptr = (const int*)g_idx;\n  const int* perm_ptr = (const int*)perm;\n  int4* a_tmp_ptr = (int4*)a_tmp;\n\n  int* locks = (int*)workspace;\n\n  if (has_act_order) {\n    // Permute A columns\n    int block_rows = div_ceil(prob_m, blocks);\n    permute_cols_kernel<<<blocks, default_threads, 
0, stream>>>(\n        A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows);\n    A_ptr = a_tmp_ptr;\n  }\n\n  // If we have a full K, then we can run the non-act-order version of Marlin\n  // (since the weight rows are reordered by increasing group ids, and by having\n  // a full K, we have full original groups)\n  if (is_k_full) {\n    has_act_order = false;\n  }\n\n  // Main loop\n  for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) {\n    int thread_m_blocks = tot_m_blocks - i;\n    prob_m = tot_m - 16 * i;\n    int par = 1;\n    if (thread_m_blocks > exec_cfg.max_m_blocks) {\n      // Note that parallel > 1 currently only works for inputs without any\n      // padding\n      par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks);\n      if (par > max_par) par = max_par;\n      prob_m = (16 * exec_cfg.max_m_blocks) * par;\n      i += exec_cfg.max_m_blocks * (par - 1);\n      thread_m_blocks = exec_cfg.max_m_blocks;\n    }\n\n\n\n    // Define kernel configurations\n#define undefined_error TORCH_CHECK(false, \"Unsupported shapes: MNK = [\" + str(prob_m) + \", \" + \\\n    str(prob_n) + \", \" + str(prob_k) + \"]\" + \\\n        \", has_act_order = \" + str(has_act_order) + \\\n        \", num_groups = \" + str(num_groups) + \\\n        \", group_size = \" + str(group_size) + \\\n        \", thread_m_blocks = \" + str(thread_m_blocks) + \\\n        \", thread_n_blocks = \" + str(thread_n_blocks) + \\\n        \", thread_k_blocks = \" + str(thread_k_blocks));\n\n\n    if (num_bits == 4 && num_threads == 256)\n    {\n        if (false) {\n        }\n        CALL_IF(4, 32, 2, 256)\n        CALL_IF(4, 16, 4, 256)\n        CALL_IF(4, 8, 8, 256)\n        else {\n            undefined_error\n        }\n    }\n    else if (num_bits == 4 && num_threads == 128)\n    {\n        if (false) {\n        }\n        CALL_IF(4, 8, 4, 128)\n        CALL_IF(4, 4, 8, 128)\n        else {\n            undefined_error\n        }\n    }\n    else if (num_bits == 
8 && num_threads == 256)\n    {\n        if (false) {\n        }\n        CALL_IF(8, 32, 2, 256)\n        CALL_IF(8, 16, 4, 256)\n        CALL_IF(8, 8, 8, 256)\n        else {\n            undefined_error\n        }\n    }\n    else if (num_bits == 8 && num_threads == 128)\n    {\n        if (false) {\n        }\n        CALL_IF(8, 8, 4, 128)\n        CALL_IF(8, 4, 8, 128)\n        else {\n            undefined_error\n        }\n    }\n    else {\n        undefined_error\n    }\n\n    A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;\n    C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;\n  }\n}\n\n}  // namespace gptq_marlin\n\ntorch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,\n                               torch::Tensor& b_scales, torch::Tensor& g_idx,\n                               torch::Tensor& perm, torch::Tensor& workspace,\n                               int64_t num_bits, int64_t size_m, int64_t size_n,\n                               int64_t size_k, bool is_k_full) {\n  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));\n  // Verify num_bits\n  TORCH_CHECK(num_bits == 4 || num_bits == 8,\n              \"num_bits must be 4 or 8. 
Got = \", num_bits);\n  int pack_factor = 32 / num_bits;\n\n  // Verify A\n  TORCH_CHECK(a.size(0) == size_m, \"Shape mismatch: a.size(0) = \", a.size(0),\n              \", size_m = \", size_m);\n  TORCH_CHECK(a.size(1) == size_k, \"Shape mismatch: a.size(1) = \", a.size(1),\n              \", size_k = \", size_k);\n\n  // Verify B\n  TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, \"size_k = \", size_k,\n              \" is not divisible by tile_size = \", gptq_marlin::tile_size);\n  TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),\n              \"Shape mismatch: b_q_weight.size(0) = \", b_q_weight.size(0),\n              \", size_k = \", size_k, \", tile_size = \", gptq_marlin::tile_size);\n  TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0,\n              \"b_q_weight.size(1) = \", b_q_weight.size(1),\n              \" is not divisible by tile_size = \", gptq_marlin::tile_size);\n  int actual_size_n =\n      (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;\n  TORCH_CHECK(size_n == actual_size_n, \"size_n = \", size_n,\n              \", actual_size_n = \", actual_size_n);\n\n  // Verify device and strides\n  TORCH_CHECK(a.device().is_cuda(), \"A is not on GPU\");\n  TORCH_CHECK(a.is_contiguous(), \"A is not contiguous\");\n\n  TORCH_CHECK(b_q_weight.device().is_cuda(), \"b_q_weight is not on GPU\");\n  TORCH_CHECK(b_q_weight.is_contiguous(), \"b_q_weight is not contiguous\");\n\n  TORCH_CHECK(b_scales.device().is_cuda(), \"b_scales is not on GPU\");\n  TORCH_CHECK(b_scales.is_contiguous(), \"b_scales is not contiguous\");\n\n  TORCH_CHECK(g_idx.device().is_cuda(), \"g_idx is not on GPU\");\n  TORCH_CHECK(g_idx.is_contiguous(), \"g_idx is not contiguous\");\n\n  TORCH_CHECK(perm.device().is_cuda(), \"perm is not on GPU\");\n  TORCH_CHECK(perm.is_contiguous(), \"perm is not contiguous\");\n\n  // Alloc buffers\n  auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());\n  torch::Tensor c = 
torch::empty({size_m, size_n}, options);\n  torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);\n\n  // thread_k: `k` size of a thread_tile in `weights` (can usually be left as\n  // auto -1)\n  int thread_k = -1;\n  // thread_n: `n` size of a thread_tile in `weights` (can usually be left as\n  // auto -1)\n  int thread_n = -1;\n  // sms: number of SMs to use for the kernel (can usually be left as auto -1)\n  int sms = -1;\n\n  // Verify g_idx and perm\n  TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) ||\n                  (g_idx.size(0) == size_k && perm.size(0) == size_k),\n              \"Unexpected g_idx.size(0) = \", g_idx.size(0),\n              \" and perm.size(0) = \", perm.size(0),\n              \", where size_k = \", size_k);\n\n  // Detect groupsize and act_order\n  int num_groups = -1;\n  int group_size = -1;\n  bool has_act_order = g_idx.size(0) != 0;\n\n  int b_rank = b_scales.sizes().size();\n  TORCH_CHECK(b_rank == 2, \"b_scales rank = \", b_rank, \" is not 2\");\n  TORCH_CHECK(b_scales.size(1) == size_n, \"b_scales dim 1 = \", b_scales.size(1),\n              \" is not size_n = \", size_n);\n  num_groups = b_scales.size(0);\n\n  if (has_act_order) {\n    if (is_k_full) {\n      TORCH_CHECK(num_groups > 1, \"For act_order, num_groups must be > 1\");\n      TORCH_CHECK(size_k % num_groups == 0, \"size_k = \", size_k,\n                  \", is not divisible by num_groups = \", num_groups);\n      group_size = size_k / num_groups;\n    } else {\n      group_size = 0;\n    }\n\n  } else {\n    if (num_groups > 1) {\n      TORCH_CHECK(\n          size_k % num_groups == 0, \"size_k = \", size_k,\n          \", is not divisible by b_scales.size(0) = \", b_scales.size(0));\n      group_size = size_k / num_groups;\n    } else {\n      group_size = -1;\n    }\n  }\n\n  // Verify workspace size\n  TORCH_CHECK(\n      size_n % gptq_marlin::min_thread_n == 0, \"size_n = \", size_n,\n      \", is not divisible by min_thread_n = \", 
gptq_marlin::min_thread_n);\n  int min_workspace_size =\n      (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;\n  TORCH_CHECK(workspace.numel() >= min_workspace_size,\n              \"workspace.numel = \", workspace.numel(),\n              \" is below min_workspace_size = \", min_workspace_size);\n\n  int dev = a.get_device();\n  if (a.scalar_type() == at::ScalarType::Half) {\n    gptq_marlin::marlin_mm_f16i4<half>(\n        a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),\n        b_scales.data_ptr<at::Half>(), g_idx.data_ptr(), perm.data_ptr(),\n        a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,\n        workspace.data_ptr(), num_bits, has_act_order, is_k_full, num_groups,\n        group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,\n        thread_n, sms, gptq_marlin::max_par);\n  } else if (a.scalar_type() == at::ScalarType::BFloat16) {\n    gptq_marlin::marlin_mm_f16i4<nv_bfloat16>(\n        a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),\n        c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),\n        g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),\n        size_m, size_n, size_k, workspace.data_ptr(), num_bits, has_act_order,\n        is_k_full, num_groups, group_size, dev,\n        at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,\n        gptq_marlin::max_par);\n  } else {\n    TORCH_CHECK(false, \"gptq_marlin_gemm only supports bfloat16 and float16\");\n  }\n\n  return c;\n}\n\n#endif\n"
  },
  {
    "path": "kt-kernel/cuda/gptq_marlin/gptq_marlin.cuh",
    "content": "// Adapted from\n// https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin\n// Copyrigth 2024 The vLLM team.\n// Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n#pragma once\n\n#include <torch/all.h>\n\n#include <ATen/cuda/CUDAContext.h>\n#include <c10/cuda/CUDAGuard.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n#include <iostream>\n\nnamespace gptq_marlin {\n\n// 8 warps are a good choice since every SM has 4 schedulers and having more\n// than 1 warp per schedule allows some more latency hiding. At the same time,\n// we want relatively few warps to have many registers per warp and small tiles.\nstatic constexpr int default_threads = 256;\n\nstatic constexpr int pipe_stages =\n    4;  // 4 pipeline stages fit into shared memory\n\nstatic constexpr int min_thread_n = 64;\nstatic constexpr int min_thread_k = 64;\n\nstatic constexpr int tile_size = 16;\nstatic constexpr int max_par = 16;\n\ntemplate <typename T, int n>\nstruct Vec {\n  T elems[n];\n  __device__ T& operator[](int i) { return elems[i]; }\n};\n\nusing I4 = Vec<int, 4>;\n\nconstexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }\n\n#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined (__HIP_PLATFORM_AMD__)\n// No support for async\n#else\n\n__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,\n                                      bool pred = true) {\n  const int BYTES = 16;\n  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));\n  asm volatile(\n      \"{\\n\"\n      \"   .reg .pred p;\\n\"\n      \"   setp.ne.b32 p, %0, 0;\\n\"\n      \"   @p cp.async.cg.shared.global [%1], [%2], %3;\\n\"\n      \"}\\n\" ::\"r\"((int)pred),\n      \"r\"(smem), \"l\"(glob_ptr), \"n\"(BYTES));\n}\n\n__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {\n  const int BYTES = 16;\n  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));\n  asm 
volatile(\n      \"{\\n\"\n      \"   cp.async.cg.shared.global [%0], [%1], %2;\\n\"\n      \"}\\n\" ::\"r\"(smem),\n      \"l\"(glob_ptr), \"n\"(BYTES));\n}\n\n__device__ inline void cp_async_fence() {\n  asm volatile(\"cp.async.commit_group;\\n\" ::);\n}\n\ntemplate <int n>\n__device__ inline void cp_async_wait() {\n  asm volatile(\"cp.async.wait_group %0;\\n\" ::\"n\"(n));\n}\n\n#endif\n\n}  // namespace gptq_marlin\n"
  },
  {
    "path": "kt-kernel/cuda/gptq_marlin/gptq_marlin_dtypes.cuh",
    "content": "// Adapted from\n// https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin\n// Copyrigth 2024 The vLLM team.\n// Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n#ifndef _data_types_cuh\n#define _data_types_cuh\n#include \"gptq_marlin.cuh\"\n#include <cuda_fp16.h>\n#include <cuda_bf16.h>\n\n#ifdef __HIP_PLATFORM_AMD__\ntypedef __hip_bfloat16 nv_bfloat16;\ntypedef __hip_bfloat162 nv_bfloat162;\n#endif\n\nnamespace gptq_marlin {\n\ntemplate <typename scalar_t>\nclass ScalarType {};\n\ntemplate <>\nclass ScalarType<half> {\n public:\n  using scalar_t = half;\n  using scalar_t2 = half2;\n\n  // Matrix fragments for tensor core instructions; their precise layout is\n  // documented here:\n  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type\n  using FragA = Vec<half2, 4>;\n  using FragB = Vec<half2, 2>;\n  using FragC = Vec<float, 4>;\n  using FragS = Vec<half2, 1>;\n\n  static __device__ float inline num2float(const half x) {\n    return __half2float(x);\n  }\n\n  static __device__ half2 inline num2num2(const half x) {\n    return __half2half2(x);\n  }\n\n  static __device__ half2 inline nums2num2(const half x1, const half x2) {\n    return __halves2half2(x1, x2);\n  }\n\n  static __host__ __device__ half inline float2num(const float x) {\n    return __float2half(x);\n  }\n};\n\ntemplate <>\nclass ScalarType<nv_bfloat16> {\n public:\n  using scalar_t = nv_bfloat16;\n  using scalar_t2 = nv_bfloat162;\n\n  using FragA = Vec<nv_bfloat162, 4>;\n  using FragB = Vec<nv_bfloat162, 2>;\n  using FragC = Vec<float, 4>;\n  using FragS = Vec<nv_bfloat162, 1>;\n\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n  static __device__ float inline num2float(const nv_bfloat16 x) {\n    return __bfloat162float(x);\n  }\n\n  static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {\n    return __bfloat162bfloat162(x);\n  }\n\n  static __device__ 
nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,\n                                                  const nv_bfloat16 x2) {\n    return __halves2bfloat162(x1, x2);\n  }\n\n  static __host__ __device__ nv_bfloat16 inline float2num(const float x) {\n    return __float2bfloat16(x);\n  }\n#endif\n};\n\n}  // namespace gptq_marlin\n\n#endif\n"
  },
  {
    "path": "kt-kernel/cuda/gptq_marlin/ops.h",
    "content": "/**\n * @Description  :\n * @Author       : Azure\n * @Date         : 2024-07-22 09:27:55\n * @Version      : 1.0.0\n * @LastEditors  : Azure\n * @LastEditTime : 2024-07-26 08:35:00\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#pragma once\n\n#include <torch/extension.h>\n#include <torch/library.h>\n#include <torch/torch.h>\n\ntorch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales,\n                               torch::Tensor& g_idx, torch::Tensor& perm, torch::Tensor& workspace, int64_t num_bits,\n                               int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full);\n\n// torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,\n//                                  int64_t size_k, int64_t size_n,\n//                                  int64_t num_bits);"
  },
  {
    "path": "kt-kernel/cuda/moe/moe_topk_softmax_kernels.cu",
    "content": "// Adapt from https://github.com/vllm-project/vllm/blob/v0.7.3/csrc/moe/topk_softmax_kernels.cu\n// which is originally adapted from\n// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu\n\n#include <ATen/cuda/CUDAContext.h>\n#include <c10/cuda/CUDAGuard.h>\n#include <torch/all.h>\n\n#ifndef USE_ROCM\n#include <cub/cub.cuh>\n#include <cub/util_type.cuh>\n#else\n#include <hipcub/hipcub.hpp>\n#include <hipcub/util_type.hpp>\n#endif\n\n#include \"utils.h\"\n\n#define MAX(a, b) ((a) > (b) ? (a) : (b))\n#define MIN(a, b) ((a) < (b) ? (a) : (b))\n\n/// Aligned array type\ntemplate <\n    typename T,\n    /// Number of elements in the array\n    int N,\n    /// Alignment requirement in bytes\n    int Alignment = sizeof(T) * N>\nclass alignas(Alignment) AlignedArray {\n  float data[N];\n};\n\n// ====================== Softmax things ===============================\n// We have our own implementation of softmax here so we can support transposing the output\n// in the softmax kernel when we extend this module to support expert-choice routing.\ntemplate <int TPB>\n__launch_bounds__(TPB) __global__\n    void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols) {\n  using BlockReduce = cub::BlockReduce<float, TPB>;\n  __shared__ typename BlockReduce::TempStorage tmpStorage;\n\n  __shared__ float normalizing_factor;\n  __shared__ float float_max;\n\n  const int thread_row_offset = blockIdx.x * num_cols;\n\n  cub::Sum sum;\n  float threadData(-FLT_MAX);\n\n  // Don't touch finished rows.\n  if ((finished != nullptr) && finished[blockIdx.x]) {\n    return;\n  }\n\n  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {\n    const int idx = thread_row_offset + ii;\n    threadData = max(static_cast<float>(input[idx]), threadData);\n  }\n\n  const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());\n\n  if (threadIdx.x == 0) {\n    float_max = maxElem;\n  
}\n  __syncthreads();\n\n  threadData = 0;\n\n  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {\n    const int idx = thread_row_offset + ii;\n    threadData += exp((static_cast<float>(input[idx]) - float_max));\n  }\n\n  const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);\n\n  if (threadIdx.x == 0) {\n    normalizing_factor = 1.f / Z;\n  }\n  __syncthreads();\n\n  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {\n    const int idx = thread_row_offset + ii;\n    const float val = exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;\n    output[idx] = val;\n  }\n}\n\ntemplate <int TPB>\n__launch_bounds__(TPB) __global__ void moeTopK(\n    const float* inputs_after_softmax,\n    const bool* finished,\n    float* output,\n    int* indices,\n    int* source_rows,\n    const int num_experts,\n    const int k,\n    const int start_expert,\n    const int end_expert) {\n  using cub_kvp = cub::KeyValuePair<int, float>;\n  using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;\n  __shared__ typename BlockReduce::TempStorage tmpStorage;\n\n  cub_kvp thread_kvp;\n  cub::ArgMax arg_max;\n\n  const int num_rows = gridDim.x;\n  const int block_row = blockIdx.x;\n\n  const bool row_is_active = finished ? 
!finished[block_row] : true;\n  const int thread_read_offset = blockIdx.x * num_experts;\n  for (int k_idx = 0; k_idx < k; ++k_idx) {\n    thread_kvp.key = 0;\n    thread_kvp.value = -1.f;  // This is OK because inputs are probabilities\n\n    cub_kvp inp_kvp;\n    for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {\n      const int idx = thread_read_offset + expert;\n      inp_kvp.key = expert;\n      inp_kvp.value = inputs_after_softmax[idx];\n\n      for (int prior_k = 0; prior_k < k_idx; ++prior_k) {\n        const int prior_winning_expert = indices[k * block_row + prior_k];\n\n        if (prior_winning_expert == expert) {\n          inp_kvp = thread_kvp;\n        }\n      }\n\n      thread_kvp = arg_max(inp_kvp, thread_kvp);\n    }\n\n    const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);\n    if (threadIdx.x == 0) {\n      // Ignore experts the node isn't responsible for with expert parallelism\n      const int expert = result_kvp.key;\n      const bool node_uses_expert = expert >= start_expert && expert < end_expert;\n      const bool should_process_row = row_is_active && node_uses_expert;\n\n      const int idx = k * block_row + k_idx;\n      output[idx] = result_kvp.value;\n      indices[idx] = should_process_row ? (expert - start_expert) : num_experts;\n      assert(indices[idx] >= 0);\n      source_rows[idx] = k_idx * num_rows + block_row;\n    }\n    __syncthreads();\n  }\n}\n\n// ====================== TopK softmax things ===============================\n\n/*\n  A Top-K gating softmax written to exploit when the number of experts in the MoE layers\n  are a small power of 2. 
This allows us to cleanly share the rows among the threads in\n  a single warp and eliminate communication between warps (so no need to use shared mem).\n\n  It fuses the softmax, max and argmax into a single kernel.\n\n  Limitations:\n  1) This implementation is intended for when the number of experts is a small power of 2.\n  2) This implementation assumes k is small, but will work for any k.\n*/\n\ntemplate <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>\n__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax(\n    const float* input,\n    const bool* finished,\n    float* output,\n    const int num_rows,\n    int* indices,\n    int* source_rows,\n    const int k,\n    const int start_expert,\n    const int end_expert) {\n  // We begin by enforcing compile time assertions and setting up compile time constants.\n  static_assert(VPT == (VPT & -VPT), \"VPT must be power of 2\");\n  static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), \"NUM_EXPERTS must be power of 2\");\n  static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), \"BYTES_PER_LDG must be power of 2\");\n  static_assert(BYTES_PER_LDG <= 16, \"BYTES_PER_LDG must be leq 16\");\n\n  // Number of bytes each thread pulls in per load\n  static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);\n  static constexpr int ELTS_PER_ROW = NUM_EXPERTS;\n  static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;\n  static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;\n\n  // Restrictions based on previous section.\n  static_assert(VPT % ELTS_PER_LDG == 0, \"The elements per thread must be a multiple of the elements per ldg\");\n  static_assert(WARP_SIZE % THREADS_PER_ROW == 0, \"The threads per row must cleanly divide the threads per warp\");\n  static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), \"THREADS_PER_ROW must be power of 2\");\n  static_assert(THREADS_PER_ROW <= WARP_SIZE, \"THREADS_PER_ROW can be at most warp 
size\");\n\n  // We have NUM_EXPERTS elements per row. We specialize for small #experts\n  static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;\n  static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;\n  static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;\n\n  // Restrictions for previous section.\n  static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, \"The elts per row must cleanly divide the total elt per warp\");\n\n  // ===================== From this point, we finally start computing run-time variables. ========================\n\n  // Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps.\n  // Thus, each block processes a chunk of rows. We start by computing the start row for each block.\n  const int cta_base_row = blockIdx.x * ROWS_PER_CTA;\n\n  // Now, using the base row per thread block, we compute the base row per warp.\n  const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;\n\n  // The threads in a warp are split into sub-groups that will work on a row.\n  // We compute row offset for each thread sub-group\n  const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;\n  const int thread_row = warp_base_row + thread_row_in_warp;\n\n  // Threads with indices out of bounds should early exit here.\n  if (thread_row >= num_rows) {\n    return;\n  }\n  const bool row_is_active = finished ? !finished[thread_row] : true;\n\n  // We finally start setting up the read pointers for each thread. 
First, each thread jumps to the start of the\n  // row it will read.\n  const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW;\n\n  // Now, we compute the group each thread belongs to in order to determine the first column to start loads.\n  const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;\n  const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;\n  const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;\n\n  // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory,\n  // this can support all powers of 2 up to 16.\n  // NOTE(woosuk): The original implementation uses CUTLASS aligned array here.\n  // We defined our own aligned array and use it here to avoid the dependency on CUTLASS.\n  using AccessType = AlignedArray<float, ELTS_PER_LDG>;\n\n  // Finally, we pull in the data from global mem\n  float row_chunk[VPT];\n  AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk);\n  const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);\n#pragma unroll\n  for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {\n    row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];\n  }\n\n  // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just\n  // convert to float afterwards for the exp + sum reduction.\n  float thread_max = row_chunk[0];\n#pragma unroll\n  for (int ii = 1; ii < VPT; ++ii) {\n    thread_max = max(thread_max, row_chunk[ii]);\n  }\n\n// Now, we find the max within the thread group and distribute among the threads. 
We use a butterfly reduce.\n#pragma unroll\n  for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {\n    thread_max = max(thread_max, SGLANG_SHFL_XOR_SYNC_WIDTH(0xffffffff, thread_max, mask, THREADS_PER_ROW));\n  }\n\n  // From this point, thread_max in every thread holds the max within the row.\n  // Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum.\n  float row_sum = 0;\n#pragma unroll\n  for (int ii = 0; ii < VPT; ++ii) {\n    row_chunk[ii] = expf(row_chunk[ii] - thread_max);\n    row_sum += row_chunk[ii];\n  }\n\n// Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a butterfly pattern.\n#pragma unroll\n  for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {\n    row_sum += SGLANG_SHFL_XOR_SYNC_WIDTH(0xffffffff, row_sum, mask, THREADS_PER_ROW);\n  }\n\n  // From this point, all threads have the max and the sum for their rows in the thread_max and row_sum variables\n  // respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to\n  // compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row.\n  // However, this kernel will likely not be a bottleneck and it seems better to match torch more closely and find the\n  // argmax after computing the softmax.\n  const float reciprocal_row_sum = 1.f / row_sum;\n\n#pragma unroll\n  for (int ii = 0; ii < VPT; ++ii) {\n    row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;\n  }\n\n  // Now, row_chunk contains the softmax of the row chunk. 
Now, I want to find the topk elements in each row, along\n  // with the max index.\n  int start_col = first_elt_read_by_thread;\n  static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;\n\n  for (int k_idx = 0; k_idx < k; ++k_idx) {\n    // First, each thread does the local argmax\n    float max_val = row_chunk[0];\n    int expert = start_col;\n#pragma unroll\n    for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG) {\n#pragma unroll\n      for (int ii = 0; ii < ELTS_PER_LDG; ++ii) {\n        float val = row_chunk[ldg * ELTS_PER_LDG + ii];\n\n        // No check on the experts here since columns with the smallest index are processed first and only\n        // updated if > (not >=)\n        if (val > max_val) {\n          max_val = val;\n          expert = col + ii;\n        }\n      }\n    }\n\n// Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max.\n// This will be useful for K > 1 so that the threads can agree on \"who\" had the max value. 
That thread can\n// then blank out their max with -inf and the warp can run more iterations...\n#pragma unroll\n    for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {\n      float other_max = SGLANG_SHFL_XOR_SYNC_WIDTH(0xffffffff, max_val, mask, THREADS_PER_ROW);\n      int other_expert = SGLANG_SHFL_XOR_SYNC_WIDTH(0xffffffff, expert, mask, THREADS_PER_ROW);\n\n      // We want lower indices to \"win\" in every thread so we break ties this way\n      if (other_max > max_val || (other_max == max_val && other_expert < expert)) {\n        max_val = other_max;\n        expert = other_expert;\n      }\n    }\n\n    // Write the max for this k iteration to global memory.\n    if (thread_group_idx == 0) {\n      // Add a guard to ignore experts not included by this node\n      const bool node_uses_expert = expert >= start_expert && expert < end_expert;\n      const bool should_process_row = row_is_active && node_uses_expert;\n\n      // The lead thread from each sub-group will write out the final results to global memory. (This will be a\n      // single) thread per row of the input/output matrices.\n      const int idx = k * thread_row + k_idx;\n      output[idx] = max_val;\n      indices[idx] = should_process_row ? 
(expert - start_expert) : NUM_EXPERTS;\n      source_rows[idx] = k_idx * num_rows + thread_row;\n    }\n\n    // Finally, we clear the value in the thread with the current max if there is another iteration to run.\n    if (k_idx + 1 < k) {\n      const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;\n      const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW;\n\n      // Only the thread in the group which produced the max will reset the \"winning\" value to -inf.\n      if (thread_group_idx == thread_to_clear_in_group) {\n        const int offset_for_expert = expert % ELTS_PER_LDG;\n        // Safe to set to any negative value since row_chunk values must be between 0 and 1.\n        row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f;\n      }\n    }\n  }\n}\n\nnamespace detail {\n// Constructs some constants needed to partition the work across threads at compile time.\ntemplate <int EXPERTS, int BYTES_PER_LDG>\nstruct TopkConstants {\n  static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);\n  static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, \"\");\n  static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));\n  static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;\n  static constexpr int THREADS_PER_ROW = EXPERTS / VPT;\n  static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;\n};\n}  // namespace detail\n\ntemplate <int EXPERTS, int WARPS_PER_TB>\nvoid topkGatingSoftmaxLauncherHelper(\n    const float* input,\n    const bool* finished,\n    float* output,\n    int* indices,\n    int* source_row,\n    const int num_rows,\n    const int k,\n    const int start_expert,\n    const int end_expert,\n    cudaStream_t stream) {\n  static constexpr std::size_t MAX_BYTES_PER_LDG = 16;\n\n  static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);\n  using Constants = 
detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;\n  static constexpr int VPT = Constants::VPT;\n  static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;\n  const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;\n  const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;\n\n  dim3 block_dim(WARP_SIZE, WARPS_PER_TB);\n  topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(\n      input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);\n}\n\n#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB)             \\\n  topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>( \\\n      gating_output,                                          \\\n      nullptr,                                                \\\n      topk_weights,                                           \\\n      topk_indices,                                           \\\n      token_expert_indices,                                   \\\n      num_tokens,                                             \\\n      topk,                                                   \\\n      0,                                                      \\\n      num_experts,                                            \\\n      stream);\n\nvoid topkGatingSoftmaxKernelLauncher(\n    const float* gating_output,\n    float* topk_weights,\n    int* topk_indices,\n    int* token_expert_indices,\n    float* softmax_workspace,\n    const int num_tokens,\n    const int num_experts,\n    const int topk,\n    cudaStream_t stream) {\n  static constexpr int WARPS_PER_TB = 4;\n  switch (num_experts) {\n    case 1:\n      LAUNCH_SOFTMAX(1, WARPS_PER_TB);\n      break;\n    case 2:\n      LAUNCH_SOFTMAX(2, WARPS_PER_TB);\n      break;\n    case 4:\n      LAUNCH_SOFTMAX(4, WARPS_PER_TB);\n      break;\n    case 8:\n      LAUNCH_SOFTMAX(8, WARPS_PER_TB);\n      break;\n    case 16:\n      LAUNCH_SOFTMAX(16, WARPS_PER_TB);\n      break;\n    case 
32:\n      LAUNCH_SOFTMAX(32, WARPS_PER_TB);\n      break;\n    case 64:\n      LAUNCH_SOFTMAX(64, WARPS_PER_TB);\n      break;\n    case 128:\n      LAUNCH_SOFTMAX(128, WARPS_PER_TB);\n      break;\n    case 256:\n      LAUNCH_SOFTMAX(256, WARPS_PER_TB);\n      break;\n    default: {\n      TORCH_CHECK(\n          softmax_workspace != nullptr,\n          \"softmax_workspace must be provided for num_experts that are not a power of 2.\");\n      static constexpr int TPB = 256;\n      moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(gating_output, nullptr, softmax_workspace, num_experts);\n      moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(\n          softmax_workspace,\n          nullptr,\n          topk_weights,\n          topk_indices,\n          token_expert_indices,\n          num_experts,\n          topk,\n          0,\n          num_experts);\n    }\n  }\n}\n\nvoid topk_softmax(\n    torch::Tensor& topk_weights,          // [num_tokens, topk]\n    torch::Tensor& topk_indices,          // [num_tokens, topk]\n    torch::Tensor& token_expert_indices,  // [num_tokens, topk]\n    torch::Tensor& gating_output)         // [num_tokens, num_experts]\n{\n  const int num_experts = gating_output.size(-1);\n  const int num_tokens = gating_output.numel() / num_experts;\n  const int topk = topk_weights.size(-1);\n\n  const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);\n  const bool needs_workspace = !is_pow_2 || num_experts > 256;\n  const int64_t workspace_size = needs_workspace ? 
num_tokens * num_experts : 0;\n\n  const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));\n  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n  torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());\n  topkGatingSoftmaxKernelLauncher(\n      gating_output.data_ptr<float>(),\n      topk_weights.data_ptr<float>(),\n      topk_indices.data_ptr<int>(),\n      token_expert_indices.data_ptr<int>(),\n      softmax_workspace.data_ptr<float>(),\n      num_tokens,\n      num_experts,\n      topk,\n      stream);\n}\n"
  },
  {
    "path": "kt-kernel/cuda/moe/ops.h",
    "content": "#pragma once\n\n#include <torch/extension.h>\n#include <torch/library.h>\n#include <torch/torch.h>\n\nvoid topk_softmax(torch::Tensor& topk_weights,          // [num_tokens, topk]\n                  torch::Tensor& topk_indices,          // [num_tokens, topk]\n                  torch::Tensor& token_expert_indices,  // [num_tokens, topk]\n                  torch::Tensor& gating_output);"
  },
  {
    "path": "kt-kernel/cuda/moe/utils.h",
    "content": "#pragma once\n\n#include <ATen/Tensor.h>\n#include <cuda_runtime.h>\n#include <torch/all.h>\n\n#include <sstream>\n\n#ifndef USE_ROCM\n// Adapt from FlashInfer\n#ifdef FLASHINFER_ENABLE_F16\n#define _DISPATCH_CASE_F16(c_type, ...) \\\n  case at::ScalarType::Half: {          \\\n    using c_type = nv_half;             \\\n    return __VA_ARGS__();               \\\n  }\n#else\n#define _DISPATCH_CASE_F16(c_type, ...)\n#endif\n\n#ifdef FLASHINFER_ENABLE_BF16\n#define _DISPATCH_CASE_BF16(c_type, ...) \\\n  case at::ScalarType::BFloat16: {       \\\n    using c_type = nv_bfloat16;          \\\n    return __VA_ARGS__();                \\\n  }\n#else\n#define _DISPATCH_CASE_BF16(c_type, ...)\n#endif\n\n#ifdef FLASHINFER_ENABLE_FP8_E4M3\n#define _DISPATCH_CASE_FP8_E4M3(c_type, ...) \\\n  case at::ScalarType::Float8_e4m3fn: {      \\\n    using c_type = __nv_fp8_e4m3;            \\\n    return __VA_ARGS__();                    \\\n  }\n#else\n#define _DISPATCH_CASE_FP8_E4M3(c_type, ...)\n#endif\n\n#ifdef FLASHINFER_ENABLE_FP8_E5M2\n#define _DISPATCH_CASE_FP8_E5M2(c_type, ...) \\\n  case at::ScalarType::Float8_e5m2: {        \\\n    using c_type = __nv_fp8_e5m2;            \\\n    return __VA_ARGS__();                    \\\n  }\n#else\n#define _DISPATCH_CASE_FP8_E5M2(c_type, ...)\n#endif\n\n#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...)                 
\\\n  [&]() -> bool {                                                                        \\\n    switch (pytorch_dtype) {                                                             \\\n      _DISPATCH_CASE_F16(c_type, __VA_ARGS__)                                            \\\n      _DISPATCH_CASE_BF16(c_type, __VA_ARGS__)                                           \\\n      default:                                                                           \\\n        std::ostringstream oss;                                                          \\\n        oss << __PRETTY_FUNCTION__ << \" failed to dispatch data type \" << pytorch_dtype; \\\n        TORCH_CHECK(false, oss.str());                                                   \\\n        return false;                                                                    \\\n    }                                                                                    \\\n  }()\n\n#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(pytorch_dtype, c_type, ...)                      
\\\n  [&]() -> bool {                                                                            \\\n    switch (pytorch_dtype) {                                                                 \\\n      _DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__)                                           \\\n      _DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__)                                           \\\n      default:                                                                               \\\n        std::ostringstream oss;                                                              \\\n        oss << __PRETTY_FUNCTION__ << \" failed to dispatch fp8 data type \" << pytorch_dtype; \\\n        TORCH_CHECK(false, oss.str());                                                       \\\n        return false;                                                                        \\\n    }                                                                                        \\\n  }()\n\n#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...)                      
\\\n  [&]() -> bool {                                                                        \\\n    switch (pytorch_dtype) {                                                             \\\n      _DISPATCH_CASE_F16(c_type, __VA_ARGS__)                                            \\\n      _DISPATCH_CASE_BF16(c_type, __VA_ARGS__)                                           \\\n      _DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__)                                       \\\n      _DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__)                                       \\\n      default:                                                                           \\\n        std::ostringstream oss;                                                          \\\n        oss << __PRETTY_FUNCTION__ << \" failed to dispatch data type \" << pytorch_dtype; \\\n        TORCH_CHECK(false, oss.str());                                                   \\\n        return false;                                                                    \\\n    }                                                                                    \\\n  }()\n\n#define _DISPATCH_SWITCH(var_name, cond, ...)                                           
\\\n  [&]() -> bool {                                                                       \\\n    switch (cond) {                                                                     \\\n      __VA_ARGS__                                                                       \\\n      default:                                                                          \\\n        std::ostringstream oss;                                                         \\\n        oss << __PRETTY_FUNCTION__ << \" failed to dispatch \" var_name \" \" << int(cond); \\\n        TORCH_CHECK(false, oss.str());                                                  \\\n        return false;                                                                   \\\n    }                                                                                   \\\n  }()\n\n#define _DISPATCH_SWITCH_U16x2(var1_name, var2_name, cond1, cond2, ...)                                             \\\n  [&]() -> bool {                                                                                                   \\\n    switch (pack_u16(cond1, cond2)) {                                                                               \\\n      __VA_ARGS__                                                                                                   \\\n      default:                                                                                                      \\\n        std::ostringstream oss;                                                                                     \\\n        oss << __PRETTY_FUNCTION__ << \" failed to dispatch (\" var1_name \", \" var2_name \"): (\" << int(cond1) << \", \" \\\n            << int(cond2) << \")\";                                                                                   \\\n        TORCH_CHECK(false, oss.str());                                                                              \\\n        return false;                                            
                                                   \\\n    }                                                                                                               \\\n  }()\n\n#define _DISPATCH_CASE(case_expr, case_var, ...) \\\n  case case_expr: {                              \\\n    constexpr auto case_var = case_expr;         \\\n    return __VA_ARGS__();                        \\\n  }\n\n#define _DISPATCH_CASE_U16x2(case_expr1, case_expr2, case_var1, case_var2, ...) \\\n  case pack_u16(case_expr1, case_expr2): {                                      \\\n    constexpr auto case_var1 = case_expr1;                                      \\\n    constexpr auto case_var2 = case_expr2;                                      \\\n    return __VA_ARGS__();                                                       \\\n  }\n\n#define DISPATCH_BOOL(expr, const_expr, ...) \\\n  [&]() -> bool {                            \\\n    if (expr) {                              \\\n      constexpr bool const_expr = true;      \\\n      return __VA_ARGS__();                  \\\n    } else {                                 \\\n      constexpr bool const_expr = false;     \\\n      return __VA_ARGS__();                  \\\n    }                                        \\\n  }()\n\ninline void check_shape(const at::Tensor& a, const at::Tensor& b, const char* a_name, const char* b_name) {\n  TORCH_CHECK(a.dim() == b.dim(), a_name, \".dim() != \", b_name, \".dim(). 
\", a.dim(), \" vs \", b.dim());\n  for (int i = 0; i < a.dim(); ++i) {\n    TORCH_CHECK(a.size(i) == b.size(i), a_name, \".size(\", i, \") != \", b_name, \".size(\", i, \")\");\n  }\n}\n\ninline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { return (uint32_t(a) << 16) | uint32_t(b); }\n\n#define CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads)                                                           \\\n  TORCH_CHECK(num_qo_heads % num_kv_heads == 0, \"num_qo_heads(\", num_qo_heads, \") must be divisible by num_kv_heads(\", \\\n              num_kv_heads, \")\")\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x \" must be a CUDA tensor\")\n\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_LAST_DIM_CONTIGUOUS(x) \\\n  TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x \"must be contiguous at last dimension\")\n\n#define CHECK_INPUT(x) \\\n  CHECK_CUDA(x);       \\\n  CHECK_CONTIGUOUS(x)\n#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \\\n  CHECK_CUDA(x);                           \\\n  CHECK_LAST_DIM_CONTIGUOUS(x)\n\n#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x \" must be a \" #d \"D tensor\")\n\n#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)\n\n#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), \"CHECK_EQ(\" #a \", \" #b \") failed. \", a, \" vs \", b)\n\n#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), \"CHECK_GE(\" #a \", \" #b \") failed. 
\", a, \" vs \", b)\n\ninline bool is_float8_tensor(const at::Tensor& tensor) {\n  return tensor.scalar_type() == at::ScalarType::Float8_e4m3fn || tensor.scalar_type() == at::ScalarType::Float8_e5m2;\n}\n#endif\n\nstruct cuda_error : public std::runtime_error {\n  /**\n   * @brief Constructs a `cuda_error` object with the given `message`.\n   *\n   * @param message The error char array used to construct `cuda_error`\n   */\n  cuda_error(const char* message) : std::runtime_error(message) {}\n  /**\n   * @brief Constructs a `cuda_error` object with the given `message` string.\n   *\n   * @param message The `std::string` used to construct `cuda_error`\n   */\n  cuda_error(std::string const& message) : cuda_error{message.c_str()} {}\n};\n\n#define CHECK_CUDA_SUCCESS(cmd)                                         \\\n  do {                                                                  \\\n    cudaError_t e = cmd;                                                \\\n    if (e != cudaSuccess) {                                             \\\n      std::stringstream _message;                                       \\\n      auto s = cudaGetErrorString(e);                                   \\\n      _message << std::string(s) + \"\\n\" << __FILE__ << ':' << __LINE__; \\\n      throw cuda_error(_message.str());                                 \\\n    }                                                                   \\\n  } while (0)\n\n#define CHECK_IS_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_CUDA_INPUT(x) \\\n  CHECK_IS_CUDA(x);         \\\n  CHECK_IS_CONTIGUOUS(x)\n\ninline int getSMVersion() {\n  int device{-1};\n  CHECK_CUDA_SUCCESS(cudaGetDevice(&device));\n  int sm_major = 0;\n  int sm_minor = 0;\n  CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device));\n  
CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device));\n  return sm_major * 10 + sm_minor;\n}\n\n// SGLANG_SHFL_XOR_* adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/csrc/cuda_compat.h#L19-L28\n#ifndef USE_ROCM\n#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor_sync((mask), (var), (lane_mask))\n#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor_sync((mask), (var), (lane_mask), (width))\n#else\n#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor((var), (lane_mask))\n#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor((var), (lane_mask), (width))\n#endif\n\n#ifndef USE_ROCM\n#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...)           \\\n  [&]() -> bool {                                                                        \\\n    switch (pytorch_dtype) {                                                             \\\n      case at::ScalarType::Float: {                                                      \\\n        using c_type = float;                                                            \\\n        return __VA_ARGS__();                                                            \\\n      }                                                                                  \\\n        _DISPATCH_CASE_F16(c_type, __VA_ARGS__)                                          \\\n        _DISPATCH_CASE_BF16(c_type, __VA_ARGS__)                                         \\\n      default:                                                                           \\\n        std::ostringstream oss;                                                          \\\n        oss << __PRETTY_FUNCTION__ << \" failed to dispatch data type \" << pytorch_dtype; \\\n        TORCH_CHECK(false, oss.str());                                                   \\\n        return false;                                                              
      \\\n    }                                                                                    \\\n  }()\n#endif\n\n#define DISPATCH_CASE_INTEGRAL_TYPES(...)              \\\n  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)  \\\n  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)  \\\n  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \\\n  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)   \\\n  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)\n\n#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \\\n  AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))\n\n#define CEILDIV(x, y) (((x) + (y) - 1) / (y))\n#define WARP_SIZE 32\n\n#ifndef USE_ROCM\n#include <c10/util/Float8_e4m3fn.h>\nusing FP8_TYPE = c10::Float8_e4m3fn;\nC10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();\n#else\n#include <c10/util/Float8_e4m3fnuz.h>\n\nusing FP8_TYPE = c10::Float8_e4m3fnuz;\nconstexpr auto FP8_E4M3_MAX = 224.0f;\n#endif\n\n#ifndef USE_ROCM\n__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {\n  float old;\n  old = (value >= 0) ? 
__int_as_float(atomicMax((int*)addr, __float_as_int(value)))\n                     : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));\n  return old;\n}\n\n__device__ __forceinline__ float warpReduceMax(float max_value) {\n  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 16));\n  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 8));\n  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 4));\n  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 2));\n  max_value = fmaxf(max_value, SGLANG_SHFL_XOR_SYNC(0xffffffff, max_value, 1));\n  return max_value;\n}\n\n__device__ __forceinline__ float blockReduceMax(float max_value) {\n  static __shared__ float warpLevelMaxs[WARP_SIZE];\n  const int laneId = threadIdx.x % WARP_SIZE;\n  const int warpId = threadIdx.x / WARP_SIZE;\n\n  max_value = warpReduceMax(max_value);\n\n  if (laneId == 0) warpLevelMaxs[warpId] = max_value;\n  __syncthreads();\n\n  max_value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0;\n  if (warpId == 0) max_value = warpReduceMax(max_value);\n\n  return max_value;\n}\n#endif\n\n// Pads to a multiple of `alignment` rows.\ninline torch::Tensor pad_tensor(const torch::Tensor& tensor, int64_t alignment = 4, bool is_column_major = false) {\n  int64_t rows = tensor.size(0);\n  int64_t cols = tensor.size(1);\n  int64_t pad_rows = (alignment - (rows % alignment)) % alignment;  // Compute padding size\n\n  if (pad_rows == 0) {\n    return tensor;  // Already aligned\n  }\n\n  torch::Tensor padding = torch::zeros({pad_rows, cols}, tensor.options());\n  torch::Tensor tensor_padded = torch::cat({tensor, padding}, 0);  // Pad along rows\n\n  // Ensure column-major layout\n  if (is_column_major) {\n    return tensor_padded.t().contiguous().t();\n  }\n  return tensor_padded;\n}\n"
  },
  {
    "path": "kt-kernel/cuda/setup.py",
    "content": "\nfrom setuptools import setup, Extension\nfrom torch.utils import cpp_extension\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\nsetup(\n    name='KTransformersOps',\n    ext_modules=[\n        CUDAExtension(\n            'KTransformersOps', [\n                'custom_gguf/dequant.cu',\n                'binding.cpp',\n                'gptq_marlin/gptq_marlin.cu',\n                'moe/moe_topk_softmax_kernels.cu',\n                # 'gptq_marlin_repack.cu',\n            ],\n            extra_compile_args={\n                'cxx': ['-O3'],\n                'nvcc': [\n                    '-O3',\n                    '--use_fast_math',\n                    '-Xcompiler', '-fPIC',\n                ]\n            },\n        )\n    ],\n    cmdclass={'build_ext': BuildExtension}\n)"
  },
  {
    "path": "kt-kernel/cuda/test_dequant.py",
    "content": "import os\nimport sys\nsys.path.insert(0,\"/home/zbx/ktransformers\")\nfrom ktransformers.util.custom_loader import GGUFLoader\nimport torch\n\ngguf_loader_1 = GGUFLoader(\"/mnt/data/model/DeepseekV3-q4km-gguf\")\ngguf_loader_2 = GGUFLoader(\"/mnt/data/chenht/model/gguf_for_ktransformers/DeepSeek-V3-bf16/\")\n\ntorch.set_default_dtype(torch.bfloat16)\n\ntensor_1 = gguf_loader_1.load_gguf_tensor(\"blk.0.attn_kv_a_mqa.weight\", \"cuda\")\ntensor_2 = gguf_loader_2.load_gguf_tensor(\"blk.0.attn_kv_a_mqa.weight\", \"cuda\")\n\nprint(tensor_1[0, -64:])\nprint(tensor_2[0, -64:])"
  },
  {
    "path": "kt-kernel/demo/.gitignore",
    "content": "test.out\nfp16-test"
  },
  {
    "path": "kt-kernel/demo/Makefile",
    "content": "# CFLAGS += $(shell pkg-config --cflags hwloc)\n# CFLAGS += -march=armv8.2-a+fp16+dotprod+sve+bf16 -I/home/test/kt-code/HPCKit_25.0.0_Linux-aarch64/package/KunpengHPCKit-kml.25.0.0/include\nCFLAGS += -O3\nCFLAGS += -I/usr/local/include/blis/ -fopenmp\nLDLIBS += -L/usr/local/lib -lblis\n# LDLIBS += $(shell pkg-config --libs hwloc) -lkml_rt\n\nCXX = /usr/bin/g++\n\n# i8_cal: i8_cal.cpp\n# $(CXX) i8_cal.cpp $(CFLAGS) -o i8_cal $(LDLIBS)\n# run: i8_cal\n# ./i8_cal\n\nsimple_test_build: simple_test.cpp\n\trm -f simple_test\n\tBLAS_NUM_THREADS=1 $(CXX) simple_test.cpp $(CFLAGS) -o simple_test $(LDLIBS)\n\nsimple_aocl_build: build simple_test_aocl.cpp\n\t$(CXX) simple_test_aocl.cpp $(CFLAGS) -o build/simple_test_aocl $(LDLIBS)\n\nfp16_test_build: fp16-test.cpp\n\trm -f fp16-test\n\t$(CXX) fp16-test.cpp $(CFLAGS) -o fp16-test $(LDLIBS)\nbf16_test_build: bf16-test.cpp\n\trm -f bf16-test\n\t$(CXX) bf16-test.cpp $(CFLAGS) -o bf16-test $(LDLIBS)\n# Create the output directory; a target must not list itself as a prerequisite.\nbuild:\n\tmkdir -p build\nbandwidth_build: build bench_reorder_bandwidth.cpp\n\t$(CXX) bench_reorder_bandwidth.cpp $(CFLAGS) -o build/bench_reorder_bandwidth $(LDLIBS)\nrun: simple_aocl_build\n\tLD_LIBRARY_PATH=/usr/local/lib:$$LD_LIBRARY_PATH  ./build/simple_test_aocl\nrun_bandwidth: bandwidth_build\n\tLD_LIBRARY_PATH=/usr/local/lib:$$LD_LIBRARY_PATH  ./build/bench_reorder_bandwidth"
  },
  {
    "path": "kt-kernel/demo/bench_reorder_bandwidth.cpp",
    "content": "#include <blis.h>\n\n#include <chrono>\n#include <cmath>\n#include <cstdint>\n#include <cstdio>\n#include <cstdlib>\n#include <cstring>\n\nnamespace {\nconstexpr int kM = 1;\nconstexpr int kK = 7168;\nconstexpr int kN = 512;\nconstexpr int kIters = 10000;\n\nvoid fill_random(int8_t* ptr, size_t count) {\n  std::srand(47);\n  for (size_t i = 0; i < count; ++i) {\n    ptr[i] = static_cast<int8_t>(std::rand() % 30);\n  }\n}\n\nvoid fill_zero(int32_t* ptr, size_t count) { std::memset(ptr, 0, count * sizeof(int32_t)); }\n\nbool verify(const int8_t* a, const int8_t* b, const int32_t* c) {\n  for (int m = 0; m < kM; ++m) {\n    for (int n = 0; n < kN; ++n) {\n      int32_t ref = 0;\n      for (int k = 0; k < kK; ++k) {\n        ref += static_cast<int32_t>(a[m * kK + k]) * static_cast<int32_t>(b[n * kK + k]);\n      }\n      if (ref != c[m * kN + n]) {\n        std::printf(\"Mismatch at (%d, %d): got %d, expect %d\\n\", m, n, c[m * kN + n], ref);\n        return false;\n      }\n    }\n  }\n  return true;\n}\n}  // namespace\n\nint main() {\n  int8_t* a = static_cast<int8_t*>(std::aligned_alloc(64, kM * kK));\n  int8_t* b = static_cast<int8_t*>(std::aligned_alloc(64, kK * kN));\n  int32_t* c = static_cast<int32_t*>(std::aligned_alloc(64, kM * kN * sizeof(int32_t)));\n  int32_t* c_tmp = static_cast<int32_t*>(std::aligned_alloc(64, kM * kN * sizeof(int32_t)));\n\n  if (!a || !b || !c || !c_tmp) {\n    std::fprintf(stderr, \"Allocation failed.\\n\");\n    std::free(a);\n    std::free(b);\n    std::free(c);\n    std::free(c_tmp);\n    return EXIT_FAILURE;\n  }\n\n  fill_random(a, kM * kK);\n  fill_random(b, kK * kN);\n  fill_zero(c, kM * kN);\n  fill_zero(c_tmp, kM * kN);\n\n  const dim_t reorder_size = aocl_get_reorder_buf_size_s8s8s32os32('r', 't', 'B', kK, kN);\n  int8_t* b_reordered = static_cast<int8_t*>(std::aligned_alloc(64, reorder_size));\n  if (!b_reordered) {\n    std::fprintf(stderr, \"Reorder buffer allocation failed.\\n\");\n    std::free(a);\n    
std::free(b);\n    std::free(c);\n    std::free(c_tmp);\n    return EXIT_FAILURE;\n  }\n\n  aocl_reorder_s8s8s32os32('r', 't', 'B', b, b_reordered, kK, kN, kK);\n\n  // Warm-up GEMM to load kernels.\n  aocl_gemm_s8s8s32os32('r', 'n', 't', kM, kN, kK, 1, a, kK, 'n', b_reordered, kK, 'r', 0, c_tmp, kN, nullptr);\n  fill_zero(c, kM * kN);\n\n  const double bytes_per_mul = static_cast<double>(kM) * kK * sizeof(int8_t) +  // A matrix read\n                               static_cast<double>(kK) * kN * sizeof(int8_t);   // original B read\n\n  auto start = std::chrono::high_resolution_clock::now();\n  for (int iter = 0; iter < kIters; ++iter) {\n    aocl_gemm_s8s8s32os32('r', 'n', 't', kM, kN, kK, 1, a, kK, 'n', b_reordered, kK, 'r', 0, c, kN, nullptr);\n  }\n  auto end = std::chrono::high_resolution_clock::now();\n\n  const double elapsed_seconds = std::chrono::duration_cast<std::chrono::duration<double>>(end - start).count();\n  const double total_bytes = bytes_per_mul * kIters;\n  const double bandwidth_gbps = total_bytes / elapsed_seconds / 1e9;\n  const double ops_per_mul = static_cast<double>(kM) * kN * kK * 2.0;\n  // Integer operations per second, in units of 1e12 (tera-ops).\n  const double tops = (ops_per_mul * kIters) / elapsed_seconds / 1e12;\n\n  std::printf(\"Reorder buffer size: %ld bytes\\n\", static_cast<long>(reorder_size));\n  std::printf(\"Iterations: %d\\n\", kIters);\n  std::printf(\"Elapsed time: %.4f s\\n\", elapsed_seconds);\n  std::printf(\"Effective bandwidth: %.2f GB/s\\n\", bandwidth_gbps);\n  // tops * 1e3 is giga-ops per second, so label it GOPS.\n  std::printf(\"Int8 GEMM throughput: %.2f GOPS\\n\", tops * 1e3);\n\n  if (!verify(a, b, c)) {\n    std::fprintf(stderr, \"Verification failed.\\n\");\n  } else {\n    std::puts(\"Verification passed.\");\n  }\n\n  std::free(a);\n  std::free(b);\n  std::free(b_reordered);\n  std::free(c);\n  std::free(c_tmp);\n  return 0;\n}\n"
  },
  {
    "path": "kt-kernel/demo/bf16-test.cpp",
    "content": "#define BGEMM\n\n#include <arm_sve.h>\n#include <dlfcn.h>\n#include <kblas.h>\n#include <unistd.h>\n\n#include <chrono>\n#include <cmath>\n#include <cstdint>\n#include <cstdio>\n#include <cstdlib>\n#include <cstring>\n#include <random>\n\nint main() {\n  // Matrix dimensions (row-major): M = 512, K = 7168, N = 512\n  int M = 512;         // in row-major layout, each row of A has K elements\n  const int K = 7168;  // each row of B has N elements\n  const int N = 512;   // each row of C has N elements\n  const int iter = 1;  // number of iterations\n  // int M = 10;        // in row-major layout, each row of A has K elements\n  // const int K = 10; // each row of B has N elements\n  // const int N = 10;  // each row of C has N elements\n\n  // Allocate the matrices\n  bfloat16_t* A = new bfloat16_t[M * K];\n  bfloat16_t* B = new bfloat16_t[K * N];\n  bfloat16_t* C = new bfloat16_t[M * N];\n  srand(123);\n\n  // Initialize the random seed\n  // std::mt19937 rng(124);\n  // std::uniform_real_distribution <float> dist(0.0, 1.0);\n\n  for (int j = 0; j < M * K; j++) {\n    A[j] = static_cast<float>(std::rand()) / static_cast<float>(RAND_MAX);\n    // A[j] = dist(rng);\n    // A[j] = j;\n  }\n  for (int j = 0; j < K * N; j++) {\n    B[j] = static_cast<float>(std::rand()) / static_cast<float>(RAND_MAX);\n    // B[j] = dist(rng);\n    // B[j] = j;\n  }\n  for (int j = 0; j < M * N; j++) {\n    C[j] = 0.0;\n  }\n\n  // Set the GEMM scaling parameters\n  float alpha = 1.0f;\n  float beta = 0.0f;\n\n  // Print matrices A and B\n  // printf(\"A=\\n\");\n  // for (int i = 0; i < M; i++) {\n  //   for (int j = 0; j < K; j++) {\n  //     printf(\"%f \", A[i * K + j]);\n  //   }\n  //   printf(\"\\n\");\n  // }\n  // printf(\"B=\\n\");\n  // for (int i = 0; i < N; i++) {\n  //   for (int j = 0; j < K; j++) {\n  //     printf(\"%f \", B[i * K + j]);\n  //   }\n  //   printf(\"\\n\");\n  // }\n  // cblas_shgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, alpha, A, K, B, N, beta, C, N);\n  // // Print the result\n  // printf(\"C=\\n\");\n  // for (int i = 0; i < M; i++) {\n  //   for (int j = 0; j < N; j++) {\n  //     printf(\"%f \", C[i * N + j]);\n  //   }\n  //   printf(\"\\n\");\n  // }\n  // return 
0;\n\n  auto fout = fopen(\"test.out\", \"w\");\n  int stride = 16;\n  for (int n = stride; n <= N; n += stride)\n    for (int m = stride; m <= M; m += stride) {\n      // Record the start time\n      auto start = std::chrono::high_resolution_clock::now();\n      // #pragma GCC unroll 8\n      for (int i = 0; i < iter; i++) {\n        cblas_bgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, K, alpha, A, K, B, N, beta, C, N);\n        // cblas_gemm_s8s8s32(CblasRowMajor, CblasNoTrans, CblasTrans, CblasFixOffset, m, N, K, alpha, A, K, oa, B, K,\n        // ob,\n        //  beta, C, N, &oc);\n      }\n\n      // Print the result\n      // printf(\"result:\\n\");\n      // for (int i = 0; i < M; i++) {\n      //   for (int j = 0; j < N; j++) {\n      //     printf(\"%f \", C[i * N + j]);\n      //   }\n      //   printf(\"\\n\");\n      // }\n      // return 0;\n\n      // Record the end time\n      auto end = std::chrono::high_resolution_clock::now();\n\n      // Compute the elapsed time (seconds)\n      auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);\n      double time_sec = duration.count() / 1e6;  // convert to seconds\n\n      // Compute the theoretical FLOP count and convert to TFLOPS\n      double ops = iter * 2.0 * m * n * K;\n      double tflops = ops / (duration.count() * 1e6);  // convert to TFLOPS\n\n      // Report the results\n      printf(\"execute end time %f us, m n:%d %d\\n\", time_sec * 1e6, m, n);\n      // printf(\"Elapsed time: %.4f s\\n\", time_sec);\n      printf(\"Performance: %.4f TFLOPS\\n\", tflops);\n      printf(\"\\n\");\n\n      fprintf(fout, \"%d %d %f\\n\", m, n, tflops);\n    }\n\n  // Release resources (arrays allocated with new[] must be freed with delete[])\n  fclose(fout);\n  delete[] A;\n  delete[] B;\n  delete[] C;\n  return 0;\n}"
  },
  {
    "path": "kt-kernel/demo/fp16-test.cpp",
    "content": "#include <arm_sve.h>\n#include <dlfcn.h>\n#include <kblas.h>\n#include <unistd.h>\n\n#include <chrono>\n#include <cmath>\n#include <cstdint>\n#include <cstdio>\n#include <cstdlib>\n#include <cstring>\n#include <random>\n\nint main() {\n  // Matrix dimensions (row-major): M = 5, K = 10, N = 7\n  int M = 5;           // in row-major layout, each row of A has K elements\n  const int K = 10;    // each row of B has N elements\n  const int N = 7;     // each row of C has N elements\n  const int iter = 1;  // number of iterations\n  // int M = 10;        // in row-major layout, each row of A has K elements\n  // const int K = 10; // each row of B has N elements\n  // const int N = 10;  // each row of C has N elements\n\n  // Allocate the matrices\n  float16_t* A = new float16_t[M * K];\n  float16_t* B = new float16_t[K * N];\n  float16_t* C = new float16_t[M * N];\n  float16_t* Cc = new float16_t[M * N];\n  srand(123);\n\n  // Initialize the random seed\n  // std::mt19937 rng(124);\n  // std::uniform_real_distribution <float> dist(0.0, 1.0);\n\n  for (int j = 0; j < M * K; j++) {\n    A[j] = static_cast<float>(std::rand()) / static_cast<float>(RAND_MAX) / 1.0;\n    // A[j] = dist(rng);\n    // A[j] = j;\n  }\n  for (int j = 0; j < K * N; j++) {\n    B[j] = static_cast<float>(std::rand()) / static_cast<float>(RAND_MAX) / 1.0;\n    // B[j] = dist(rng);\n    // B[j] = j;\n  }\n  for (int j = 0; j < M * N; j++) {\n    C[j] = 10;\n    Cc[j] = 10;\n  }\n\n  // Reference result, stored column-major to match the cblas_hgemm call below\n  for (int i = 0; i < M; i++) {\n    for (int j = 0; j < N; j++) {\n      for (int k = 0; k < K; k++) {\n        Cc[j * M + i] += A[i * K + k] * B[k * N + j];\n      }\n    }\n  }\n\n  // Set the GEMM scaling parameters\n  float alpha = 1.0f;\n  float beta = 1.0f;\n\n  // Print matrices A and B\n  printf(\"A=\\n\");\n  for (int i = 0; i < M; i++) {\n    for (int j = 0; j < K; j++) {\n      printf(\"%f \", A[i * K + j]);\n    }\n    printf(\"\\n\");\n  }\n  printf(\"B=\\n\");\n  for (int i = 0; i < K; i++) {\n    for (int j = 0; j < N; j++) {\n      printf(\"%f \", B[i * N + j]);\n    }\n    printf(\"\\n\");\n  }\n  cblas_hgemm(CblasColMajor, CblasTrans, CblasTrans, M, N, K, alpha, A, K, B, N, beta, C, M);\n  // Print the result\n  
printf(\"C=\\n\");\n  for (int i = 0; i < M; i++) {\n    for (int j = 0; j < N; j++) {\n      printf(\"%f \", C[j * M + i]);\n    }\n    printf(\"\\n\");\n  }\n\n  // Print the absolute difference between the library result and the reference\n  printf(\"|C - Cc|=\\n\");\n  for (int i = 0; i < M; i++) {\n    for (int j = 0; j < N; j++) {\n      printf(\"%f \", fabs(C[j * M + i] - Cc[j * M + i]));\n    }\n    printf(\"\\n\");\n  }\n  // Early exit: the benchmark below is currently unreachable.\n  return 0;\n\n  auto fout = fopen(\"test.out\", \"w\");\n  int stride = 16;\n  for (int n = stride; n <= N; n += stride)\n    for (int m = stride; m <= M; m += stride) {\n      // Record the start time\n      auto start = std::chrono::high_resolution_clock::now();\n      // #pragma GCC unroll 8\n      for (int i = 0; i < iter; i++) {\n        cblas_hgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, K, alpha, A, K, B, N, beta, C, N);\n        // cblas_gemm_s8s8s32(CblasRowMajor, CblasNoTrans, CblasTrans, CblasFixOffset, m, N, K, alpha, A, K, oa, B, K,\n        // ob,\n        //  beta, C, N, &oc);\n      }\n\n      // Print the result\n      // printf(\"result:\\n\");\n      // for (int i = 0; i < M; i++) {\n      //   for (int j = 0; j < N; j++) {\n      //     printf(\"%f \", C[i * N + j]);\n      //   }\n      //   printf(\"\\n\");\n      // }\n      // return 0;\n\n      // Record the end time\n      auto end = std::chrono::high_resolution_clock::now();\n\n      // Compute the elapsed time (seconds)\n      auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);\n      double time_sec = duration.count() / 1e6;  // convert to seconds\n\n      // Compute the theoretical FLOP count and convert to TFLOPS\n      double ops = iter * 2.0 * m * n * K;\n      double tflops = ops / (duration.count() * 1e6);  // convert to TFLOPS\n\n      // Report the results\n      printf(\"execute end time %f us, m n:%d %d\\n\", time_sec * 1e6, m, n);\n      // printf(\"Elapsed time: %.4f s\\n\", time_sec);\n      printf(\"Performance: %.4f TFLOPS\\n\", tflops);\n      printf(\"\\n\");\n\n      fprintf(fout, \"%d %d %f\\n\", m, n, tflops);\n    }\n\n  // Release resources (arrays allocated with new[] must be freed with delete[])\n  fclose(fout);\n  delete[] A;\n  delete[] B;\n  delete[] C;\n  delete[] Cc;\n  return 0;\n}"
  },
  {
    "path": "kt-kernel/demo/plot.py",
    "content": "import matplotlib.pyplot as plt\nimport re\n\n# 原始数据字符串\ndata_str = \"\"\"\nexecute end,m is:2\n计算性能: 0.0068 TFLOPS\n\nexecute end,m is:4\n计算性能: 0.0143 TFLOPS\n\nexecute end,m is:6\n计算性能: 0.0206 TFLOPS\n\nexecute end,m is:8\n计算性能: 0.0273 TFLOPS\n\nexecute end,m is:10\n计算性能: 0.0330 TFLOPS\n\nexecute end,m is:12\n计算性能: 0.0390 TFLOPS\n\nexecute end,m is:14\n计算性能: 0.0442 TFLOPS\n\nexecute end,m is:16\n计算性能: 0.0495 TFLOPS\n\nexecute end,m is:18\n计算性能: 0.0543 TFLOPS\n\nexecute end,m is:20\n计算性能: 0.0595 TFLOPS\n\nexecute end,m is:22\n计算性能: 0.0637 TFLOPS\n\nexecute end,m is:24\n计算性能: 0.0684 TFLOPS\n\nexecute end,m is:26\n计算性能: 0.0720 TFLOPS\n\nexecute end,m is:28\n计算性能: 0.0769 TFLOPS\n\nexecute end,m is:30\n计算性能: 0.0802 TFLOPS\n\nexecute end,m is:32\n计算性能: 0.0843 TFLOPS\n\nexecute end,m is:34\n计算性能: 0.0874 TFLOPS\n\nexecute end,m is:36\n计算性能: 0.0916 TFLOPS\n\nexecute end,m is:38\n计算性能: 0.0942 TFLOPS\n\nexecute end,m is:40\n计算性能: 0.0977 TFLOPS\n\nexecute end,m is:42\n计算性能: 0.1003 TFLOPS\n\nexecute end,m is:44\n计算性能: 0.1043 TFLOPS\n\nexecute end,m is:46\n计算性能: 0.1059 TFLOPS\n\nexecute end,m is:48\n计算性能: 0.1103 TFLOPS\n\nexecute end,m is:50\n计算性能: 0.1119 TFLOPS\n\nexecute end,m is:52\n计算性能: 0.1153 TFLOPS\n\nexecute end,m is:54\n计算性能: 0.1172 TFLOPS\n\nexecute end,m is:56\n计算性能: 0.1202 TFLOPS\n\nexecute end,m is:58\n计算性能: 0.1219 TFLOPS\n\nexecute end,m is:60\n计算性能: 0.1251 TFLOPS\n\nexecute end,m is:62\n计算性能: 0.1268 TFLOPS\n\nexecute end,m is:64\n计算性能: 0.1286 TFLOPS\n\nexecute end,m is:66\n计算性能: 0.1307 TFLOPS\n\nexecute end,m is:68\n计算性能: 0.1342 TFLOPS\n\nexecute end,m is:70\n计算性能: 0.1347 TFLOPS\n\nexecute end,m is:72\n计算性能: 0.1383 TFLOPS\n\nexecute end,m is:74\n计算性能: 0.1389 TFLOPS\n\nexecute end,m is:76\n计算性能: 0.1416 TFLOPS\n\nexecute end,m is:78\n计算性能: 0.1429 TFLOPS\n\nexecute end,m is:80\n计算性能: 0.1451 TFLOPS\n\nexecute end,m is:82\n计算性能: 0.1471 TFLOPS\n\nexecute end,m is:84\n计算性能: 0.1489 TFLOPS\n\nexecute end,m is:86\n计算性能: 0.1499 TFLOPS\n\nexecute end,m 
is:88\n计算性能: 0.1519 TFLOPS\n\nexecute end,m is:90\n计算性能: 0.1525 TFLOPS\n\nexecute end,m is:92\n计算性能: 0.1544 TFLOPS\n\nexecute end,m is:94\n计算性能: 0.1560 TFLOPS\n\nexecute end,m is:96\n计算性能: 0.1583 TFLOPS\n\nexecute end,m is:98\n计算性能: 0.1579 TFLOPS\n\nexecute end,m is:100\n计算性能: 0.1600 TFLOPS\n\nexecute end,m is:102\n计算性能: 0.1611 TFLOPS\n\nexecute end,m is:104\n计算性能: 0.1630 TFLOPS\n\nexecute end,m is:106\n计算性能: 0.1644 TFLOPS\n\nexecute end,m is:108\n计算性能: 0.1669 TFLOPS\n\nexecute end,m is:110\n计算性能: 0.1667 TFLOPS\n\nexecute end,m is:112\n计算性能: 0.1687 TFLOPS\n\nexecute end,m is:114\n计算性能: 0.1685 TFLOPS\n\nexecute end,m is:116\n计算性能: 0.1712 TFLOPS\n\nexecute end,m is:118\n计算性能: 0.1712 TFLOPS\n\nexecute end,m is:120\n计算性能: 0.1733 TFLOPS\n\nexecute end,m is:122\n计算性能: 0.1730 TFLOPS\n\nexecute end,m is:124\n计算性能: 0.1753 TFLOPS\n\nexecute end,m is:126\n计算性能: 0.1757 TFLOPS\n\nexecute end,m is:128\n计算性能: 0.1767 TFLOPS\n\nexecute end,m is:130\n计算性能: 0.1783 TFLOPS\n\nexecute end,m is:132\n计算性能: 0.1792 TFLOPS\n\nexecute end,m is:134\n计算性能: 0.1794 TFLOPS\n\nexecute end,m is:136\n计算性能: 0.1821 TFLOPS\n\nexecute end,m is:138\n计算性能: 0.1810 TFLOPS\n\nexecute end,m is:140\n计算性能: 0.1844 TFLOPS\n\nexecute end,m is:142\n计算性能: 0.1840 TFLOPS\n\nexecute end,m is:144\n计算性能: 0.1853 TFLOPS\n\nexecute end,m is:146\n计算性能: 0.1860 TFLOPS\n\nexecute end,m is:148\n计算性能: 0.1867 TFLOPS\n\nexecute end,m is:150\n计算性能: 0.1868 TFLOPS\n\nexecute end,m is:152\n计算性能: 0.1882 TFLOPS\n\nexecute end,m is:154\n计算性能: 0.1880 TFLOPS\n\nexecute end,m is:156\n计算性能: 0.1900 TFLOPS\n\nexecute end,m is:158\n计算性能: 0.1895 TFLOPS\n\nexecute end,m is:160\n计算性能: 0.1921 TFLOPS\n\nexecute end,m is:162\n计算性能: 0.1922 TFLOPS\n\nexecute end,m is:164\n计算性能: 0.1937 TFLOPS\n\nexecute end,m is:166\n计算性能: 0.1935 TFLOPS\n\nexecute end,m is:168\n计算性能: 0.1934 TFLOPS\n\nexecute end,m is:170\n计算性能: 0.1945 TFLOPS\n\nexecute end,m is:172\n计算性能: 0.1961 TFLOPS\n\nexecute end,m is:174\n计算性能: 0.1952 TFLOPS\n\nexecute end,m is:176\n计算性能: 0.1962 
TFLOPS\n\nexecute end,m is:178\n计算性能: 0.1977 TFLOPS\n\nexecute end,m is:180\n计算性能: 0.1980 TFLOPS\n\nexecute end,m is:182\n计算性能: 0.1985 TFLOPS\n\nexecute end,m is:184\n计算性能: 0.1993 TFLOPS\n\nexecute end,m is:186\n计算性能: 0.1995 TFLOPS\n\nexecute end,m is:188\n计算性能: 0.2007 TFLOPS\n\nexecute end,m is:190\n计算性能: 0.2012 TFLOPS\n\nexecute end,m is:192\n计算性能: 0.2024 TFLOPS\n\nexecute end,m is:194\n计算性能: 0.2011 TFLOPS\n\nexecute end,m is:196\n计算性能: 0.2037 TFLOPS\n\nexecute end,m is:198\n计算性能: 0.2026 TFLOPS\n\nexecute end,m is:200\n计算性能: 0.2044 TFLOPS\n\nexecute end,m is:202\n计算性能: 0.2044 TFLOPS\n\nexecute end,m is:204\n计算性能: 0.2052 TFLOPS\n\nexecute end,m is:206\n计算性能: 0.2057 TFLOPS\n\nexecute end,m is:208\n计算性能: 0.2061 TFLOPS\n\nexecute end,m is:210\n计算性能: 0.2064 TFLOPS\n\nexecute end,m is:212\n计算性能: 0.2074 TFLOPS\n\nexecute end,m is:214\n计算性能: 0.2075 TFLOPS\n\nexecute end,m is:216\n计算性能: 0.2082 TFLOPS\n\nexecute end,m is:218\n计算性能: 0.2083 TFLOPS\n\nexecute end,m is:220\n计算性能: 0.2091 TFLOPS\n\nexecute end,m is:222\n计算性能: 0.2096 TFLOPS\n\nexecute end,m is:224\n计算性能: 0.2097 TFLOPS\n\nexecute end,m is:226\n计算性能: 0.2098 TFLOPS\n\nexecute end,m is:228\n计算性能: 0.2107 TFLOPS\n\nexecute end,m is:230\n计算性能: 0.2104 TFLOPS\n\nexecute end,m is:232\n计算性能: 0.2118 TFLOPS\n\nexecute end,m is:234\n计算性能: 0.2121 TFLOPS\n\nexecute end,m is:236\n计算性能: 0.2125 TFLOPS\n\nexecute end,m is:238\n计算性能: 0.2128 TFLOPS\n\nexecute end,m is:240\n计算性能: 0.2133 TFLOPS\n\nexecute end,m is:242\n计算性能: 0.2136 TFLOPS\n\nexecute end,m is:244\n计算性能: 0.2137 TFLOPS\n\nexecute end,m is:246\n计算性能: 0.2139 TFLOPS\n\nexecute end,m is:248\n计算性能: 0.2150 TFLOPS\n\nexecute end,m is:250\n计算性能: 0.2153 TFLOPS\n\nexecute end,m is:252\n计算性能: 0.2160 TFLOPS\n\nexecute end,m is:254\n计算性能: 0.2156 TFLOPS\n\nexecute end,m is:256\n计算性能: 0.2169 TFLOPS\n\nexecute end,m is:258\n计算性能: 0.2161 TFLOPS\n\nexecute end,m is:260\n计算性能: 0.2175 TFLOPS\n\nexecute end,m is:262\n计算性能: 0.2172 TFLOPS\n\nexecute end,m is:264\n计算性能: 0.2175 TFLOPS\n\nexecute 
end,m is:266\n计算性能: 0.2181 TFLOPS\n\nexecute end,m is:268\n计算性能: 0.2189 TFLOPS\n\nexecute end,m is:270\n计算性能: 0.2193 TFLOPS\n\nexecute end,m is:272\n计算性能: 0.2201 TFLOPS\n\nexecute end,m is:274\n计算性能: 0.2198 TFLOPS\n\nexecute end,m is:276\n计算性能: 0.2195 TFLOPS\n\nexecute end,m is:278\n计算性能: 0.2205 TFLOPS\n\nexecute end,m is:280\n计算性能: 0.2212 TFLOPS\n\nexecute end,m is:282\n计算性能: 0.2210 TFLOPS\n\nexecute end,m is:284\n计算性能: 0.2210 TFLOPS\n\nexecute end,m is:286\n计算性能: 0.2215 TFLOPS\n\nexecute end,m is:288\n计算性能: 0.2225 TFLOPS\n\nexecute end,m is:290\n计算性能: 0.2227 TFLOPS\n\nexecute end,m is:292\n计算性能: 0.2234 TFLOPS\n\nexecute end,m is:294\n计算性能: 0.2227 TFLOPS\n\nexecute end,m is:296\n计算性能: 0.2242 TFLOPS\n\nexecute end,m is:298\n计算性能: 0.2230 TFLOPS\n\nexecute end,m is:300\n计算性能: 0.2232 TFLOPS\n\nexecute end,m is:302\n计算性能: 0.2227 TFLOPS\n\nexecute end,m is:304\n计算性能: 0.2234 TFLOPS\n\nexecute end,m is:306\n计算性能: 0.2226 TFLOPS\n\nexecute end,m is:308\n计算性能: 0.2239 TFLOPS\n\nexecute end,m is:310\n计算性能: 0.2239 TFLOPS\n\nexecute end,m is:312\n计算性能: 0.2249 TFLOPS\n\nexecute end,m is:314\n计算性能: 0.2245 TFLOPS\n\nexecute end,m is:316\n计算性能: 0.2254 TFLOPS\n\nexecute end,m is:318\n计算性能: 0.2251 TFLOPS\n\nexecute end,m is:320\n计算性能: 0.2262 TFLOPS\n\nexecute end,m is:322\n计算性能: 0.2256 TFLOPS\n\nexecute end,m is:324\n计算性能: 0.2262 TFLOPS\n\nexecute end,m is:326\n计算性能: 0.2259 TFLOPS\n\nexecute end,m is:328\n计算性能: 0.2265 TFLOPS\n\nexecute end,m is:330\n计算性能: 0.2266 TFLOPS\n\nexecute end,m is:332\n计算性能: 0.2275 TFLOPS\n\nexecute end,m is:334\n计算性能: 0.2275 TFLOPS\n\nexecute end,m is:336\n计算性能: 0.2280 TFLOPS\n\nexecute end,m is:338\n计算性能: 0.2275 TFLOPS\n\nexecute end,m is:340\n计算性能: 0.2281 TFLOPS\n\nexecute end,m is:342\n计算性能: 0.2284 TFLOPS\n\nexecute end,m is:344\n计算性能: 0.2288 TFLOPS\n\nexecute end,m is:346\n计算性能: 0.2288 TFLOPS\n\nexecute end,m is:348\n计算性能: 0.2295 TFLOPS\n\nexecute end,m is:350\n计算性能: 0.2292 TFLOPS\n\nexecute end,m is:352\n计算性能: 0.2300 TFLOPS\n\nexecute end,m is:354\n计算性能: 
0.2299 TFLOPS\n\nexecute end,m is:356\n计算性能: 0.2303 TFLOPS\n\nexecute end,m is:358\n计算性能: 0.2301 TFLOPS\n\nexecute end,m is:360\n计算性能: 0.2307 TFLOPS\n\nexecute end,m is:362\n计算性能: 0.2303 TFLOPS\n\nexecute end,m is:364\n计算性能: 0.2312 TFLOPS\n\nexecute end,m is:366\n计算性能: 0.2307 TFLOPS\n\nexecute end,m is:368\n计算性能: 0.2316 TFLOPS\n\nexecute end,m is:370\n计算性能: 0.2310 TFLOPS\n\nexecute end,m is:372\n计算性能: 0.2318 TFLOPS\n\nexecute end,m is:374\n计算性能: 0.2319 TFLOPS\n\nexecute end,m is:376\n计算性能: 0.2320 TFLOPS\n\nexecute end,m is:378\n计算性能: 0.2323 TFLOPS\n\nexecute end,m is:380\n计算性能: 0.2328 TFLOPS\n\nexecute end,m is:382\n计算性能: 0.2326 TFLOPS\n\nexecute end,m is:384\n计算性能: 0.2328 TFLOPS\n\nexecute end,m is:386\n计算性能: 0.2330 TFLOPS\n\nexecute end,m is:388\n计算性能: 0.2334 TFLOPS\n\nexecute end,m is:390\n计算性能: 0.2337 TFLOPS\n\nexecute end,m is:392\n计算性能: 0.2336 TFLOPS\n\nexecute end,m is:394\n计算性能: 0.2332 TFLOPS\n\nexecute end,m is:396\n计算性能: 0.2341 TFLOPS\n\nexecute end,m is:398\n计算性能: 0.2334 TFLOPS\n\nexecute end,m is:400\n计算性能: 0.2347 TFLOPS\n\nexecute end,m is:402\n计算性能: 0.2349 TFLOPS\n\nexecute end,m is:404\n计算性能: 0.2350 TFLOPS\n\nexecute end,m is:406\n计算性能: 0.2347 TFLOPS\n\nexecute end,m is:408\n计算性能: 0.2353 TFLOPS\n\nexecute end,m is:410\n计算性能: 0.2350 TFLOPS\n\nexecute end,m is:412\n计算性能: 0.2356 TFLOPS\n\nexecute end,m is:414\n计算性能: 0.2354 TFLOPS\n\nexecute end,m is:416\n计算性能: 0.2357 TFLOPS\n\nexecute end,m is:418\n计算性能: 0.2357 TFLOPS\n\nexecute end,m is:420\n计算性能: 0.2361 TFLOPS\n\nexecute end,m is:422\n计算性能: 0.2361 TFLOPS\n\nexecute end,m is:424\n计算性能: 0.2364 TFLOPS\n\nexecute end,m is:426\n计算性能: 0.2360 TFLOPS\n\nexecute end,m is:428\n计算性能: 0.2372 TFLOPS\n\nexecute end,m is:430\n计算性能: 0.2364 TFLOPS\n\nexecute end,m is:432\n计算性能: 0.2369 TFLOPS\n\nexecute end,m is:434\n计算性能: 0.2369 TFLOPS\n\nexecute end,m is:436\n计算性能: 0.2372 TFLOPS\n\nexecute end,m is:438\n计算性能: 0.2370 TFLOPS\n\nexecute end,m is:440\n计算性能: 0.2377 TFLOPS\n\nexecute end,m is:442\n计算性能: 0.2374 
TFLOPS\n\nexecute end,m is:444\n计算性能: 0.2382 TFLOPS\n\nexecute end,m is:446\n计算性能: 0.2379 TFLOPS\n\nexecute end,m is:448\n计算性能: 0.2385 TFLOPS\n\nexecute end,m is:450\n计算性能: 0.2377 TFLOPS\n\nexecute end,m is:452\n计算性能: 0.2385 TFLOPS\n\nexecute end,m is:454\n计算性能: 0.2384 TFLOPS\n\nexecute end,m is:456\n计算性能: 0.2389 TFLOPS\n\nexecute end,m is:458\n计算性能: 0.2319 TFLOPS\n\nexecute end,m is:460\n计算性能: 0.2386 TFLOPS\n\nexecute end,m is:462\n计算性能: 0.2386 TFLOPS\n\nexecute end,m is:464\n计算性能: 0.2389 TFLOPS\n\nexecute end,m is:466\n计算性能: 0.2393 TFLOPS\n\nexecute end,m is:468\n计算性能: 0.2393 TFLOPS\n\nexecute end,m is:470\n计算性能: 0.2389 TFLOPS\n\nexecute end,m is:472\n计算性能: 0.2393 TFLOPS\n\nexecute end,m is:474\n计算性能: 0.2395 TFLOPS\n\nexecute end,m is:476\n计算性能: 0.2399 TFLOPS\n\nexecute end,m is:478\n计算性能: 0.2400 TFLOPS\n\nexecute end,m is:480\n计算性能: 0.2400 TFLOPS\n\nexecute end,m is:482\n计算性能: 0.2397 TFLOPS\n\nexecute end,m is:484\n计算性能: 0.2407 TFLOPS\n\nexecute end,m is:486\n计算性能: 0.2400 TFLOPS\n\nexecute end,m is:488\n计算性能: 0.2407 TFLOPS\n\nexecute end,m is:490\n计算性能: 0.2404 TFLOPS\n\nexecute end,m is:492\n计算性能: 0.2411 TFLOPS\n\nexecute end,m is:494\n计算性能: 0.2409 TFLOPS\n\nexecute end,m is:496\n计算性能: 0.2407 TFLOPS\n\nexecute end,m is:498\n计算性能: 0.2412 TFLOPS\n\nexecute end,m is:500\n计算性能: 0.2418 TFLOPS\n\nexecute end,m is:502\n计算性能: 0.2416 TFLOPS\n\nexecute end,m is:504\n计算性能: 0.2418 TFLOPS\n\nexecute end,m is:506\n计算性能: 0.2416 TFLOPS\n\nexecute end,m is:508\n计算性能: 0.2421 TFLOPS\n\nexecute end,m is:510\n计算性能: 0.2419 TFLOPS\n\nexecute end,m is:512\n计算性能: 0.2423 TFLOPS\n\"\"\"\n\n# 使用正则表达式提取 m 和 TFLOPS 值\nm_values = list(map(int, re.findall(r'm is:(\\d+)', data_str)))\ntflops_values = list(map(float, re.findall(r'计算性能: ([\\d.]+) TFLOPS', data_str)))\n\n# 绘图\nplt.figure(figsize=(10, 6))\nplt.plot(m_values, tflops_values, marker='o', linestyle='-', color='blue')\nplt.title('m * k with k * n (k=7168 n=512) 
')\nplt.xlabel('m')\nplt.ylabel('TFLOPS')\nplt.grid(True)\nplt.tight_layout()\n\n# Save the chart to a file\nplt.savefig('performance_plot.png')  # Save as PNG\n# plt.savefig('performance_plot.pdf')  # Or save as PDF\n"
  },
  {
    "path": "kt-kernel/demo/simple_test.cpp",
    "content": "#include <dlfcn.h>\n#include <kblas.h>\n#include <unistd.h>\n\n#include <chrono>\n#include <cmath>\n#include <cstdint>\n#include <cstdio>\n#include <cstdlib>\n#include <cstring>\n\nint main() {\n  // 矩阵维度 M 是 1024，K 是 1024，N 是 1024（行主序）\n  int M = 1024;        // 行主序时，A 的行长度为 K\n  const int K = 1024;  // B 的行长度为 N\n  const int N = 1024;  // C 的行长度为 N\n  const int iter = 1;  // 迭代次数\n\n  // 分配矩阵内存\n  int8_t* A = (int8_t*)malloc(M * K * sizeof(int8_t));\n  int8_t* B = (int8_t*)malloc(K * N * sizeof(int8_t));\n  int32_t* C = (int32_t*)malloc(M * N * sizeof(int32_t));\n\n  // 初始化随机种子\n  srand((unsigned)time(NULL));\n\n  // 随机初始化 A（范围 0 到 255）和 B（范围 -128 到 127）\n  // 初始化矩阵 A 和 B\n  for (int j = 0; j < M * K; j++) {\n    // A[j] = rand() % 256;\n    A[j] = j;\n  }\n  for (int j = 0; j < K * N; j++) {\n    // B[j] = rand() % 256;\n    B[j] = j;\n  }\n  // 初始化矩阵 C\n  for (int j = 0; j < M * N; j++) {\n    C[j] = 0;\n  }\n\n  // 设置 cblas_gemm_s8u8s32 的参数\n  float alpha = 1.0f;\n  float beta = 0.0f;\n  int8_t oa = 0, ob = 0;\n  int32_t oc = 0;\n\n  // 打印矩阵 A、B\n  // printf(\"A=\\n\");\n  // for (int i = 0; i < M; i++) {\n  //   for (int j = 0; j < K; j++) {\n  //     printf(\"%d \", A[i * K + j]);\n  //   }\n  //   printf(\"\\n\");\n  // }\n  // printf(\"B=\\n\");\n  // for (int i = 0; i < N; i++) {\n  //   for (int j = 0; j < K; j++) {\n  //     printf(\"%d \", B[i * K + j]);\n  //   }\n  //   printf(\"\\n\");\n  // }\n\n  // printf(\"format: 'generate end'\\n\");\n  // 调用 cblas_gemm_s8u8s32 执行矩阵乘法：C = i1(A+ao)(B+bo) + 0*C + oc\n  // 从m=10～256 都测一遍速度，步长是 stride\n  int stride = 2;\n  int start_m = M;\n  for (int m = start_m; m <= M; m += stride) {\n    // 记录开始时间\n    auto start = std::chrono::high_resolution_clock::now();\n#pragma GCC unroll 8\n    for (int i = 0; i < iter; i++) {\n      cblas_gemm_s8s8s32(CblasRowMajor, CblasNoTrans, CblasTrans, CblasFixOffset, m, N / 2, K, alpha, A, K, oa, B, K,\n                         ob, beta, C, N, &oc);\n      int8_t* 
B_high = B + K * N / 2;\n      int32_t* C_high = C + N / 2;\n      cblas_gemm_s8s8s32(CblasRowMajor, CblasNoTrans, CblasTrans, CblasFixOffset, m, N / 2, K, alpha, A, K, oa, B_high,\n                         K, ob, beta, C_high, N, &oc);\n    }\n\n    // Print the result\n    // printf(\"result:\\n\");\n    // for (int i = 0; i < M; i++) {\n    //   for (int j = 0; j < N; j++) {\n    //     printf(\"%d \", C[i * N + j]);\n    //   }\n    //   printf(\"\\n\");\n    // }\n\n    // Record the end time\n    auto end = std::chrono::high_resolution_clock::now();\n\n    // Compute the total elapsed time (seconds)\n    auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);\n    double time_sec = duration.count() / 1e6;  // convert microseconds to seconds\n\n    // Compute the theoretical operation count and convert to TFLOPS\n    double ops = iter * 2.0 * m * N * K;\n    double tflops = ops / (duration.count() * 1e6);  // ops / (microseconds * 1e6) = ops / (seconds * 1e12)\n\n    // Print the measurements (plot.py parses these exact strings)\n    printf(\"execute end,m is:%d\\n\", m);\n    // printf(\"执行时间: %.4f 秒\\n\", time_sec);\n    printf(\"计算性能: %.4f TFLOPS\\n\", tflops);\n    printf(\"\\n\");\n  }\n\n  // Release resources\n  free(A);\n  free(B);\n  free(C);\n  return 0;\n}"
  },
  {
    "path": "kt-kernel/demo/simple_test_aocl.cpp",
    "content": "#include <blis.h>\n\n#include <cmath>\n#include <cstdint>\n#include <cstdio>\n#include <cstdlib>\n#include <cstring>\n// #define CHECK\nnamespace {\n// B matrix is in col-major order\nconstexpr int kM = 3;\nconstexpr int kK = 7168;\nconstexpr int kN = 2048;\nvoid fill_inputs(int8_t* a, int8_t* b) {\n  srand(static_cast<unsigned>(time(nullptr)));\n  for (int i = 0; i < kM * kK; ++i) {\n    a[i] = static_cast<int8_t>(rand() % 127);\n  }\n  for (int i = 0; i < kK * kN; ++i) {\n    b[i] = static_cast<int8_t>(rand() % 127);\n  }\n}\n\nvoid compute_reference(const int8_t* a, const int8_t* b, int32_t* ref) {\n  for (int m = 0; m < kM; ++m) {\n    for (int n = 0; n < kN; ++n) {\n      int32_t acc = 0;\n      for (int k = 0; k < kK; ++k) {\n        acc += static_cast<int32_t>(a[m * kK + k]) * static_cast<int32_t>(b[k * kN + n]);\n      }\n      ref[m * kN + n] = acc;\n    }\n  }\n}\n\nbool check_result(const int32_t* got, const int32_t* ref) {\n  for (int idx = 0; idx < kM * kN; ++idx) {\n    if (got[idx] != ref[idx]) {\n      std::printf(\"Mismatch at %d: got %d, expected %d\\n\", idx, got[idx], ref[idx]);\n      return false;\n    }\n  }\n  return true;\n}\n}  // namespace\n\nint main() {\n  err_t err = BLIS_SUCCESS;\n  int8_t* a = static_cast<int8_t*>(bli_malloc_user(kM * kK, &err));\n  int8_t* b = static_cast<int8_t*>(bli_malloc_user(kK * kN, &err));\n  int8_t* b_rowmajor = static_cast<int8_t*>(bli_malloc_user(kK * kN, &err));\n  int8_t* b_reordered = nullptr;\n  int32_t* c = static_cast<int32_t*>(bli_malloc_user(kM * kN * sizeof(int32_t), &err));\n  int32_t* c_unp = static_cast<int32_t*>(bli_malloc_user(kM * kN * sizeof(int32_t), &err));\n  int32_t* ref = static_cast<int32_t*>(bli_malloc_user(kM * kN * sizeof(int32_t), &err));\n\n  if (!a || !b || !c || !ref || !c_unp) {\n    std::fprintf(stderr, \"Allocation failed\\n\");\n    bli_free_user(a);\n    bli_free_user(b);\n    bli_free_user(c);\n    bli_free_user(ref);\n    bli_free_user(c_unp);\n    return 
EXIT_FAILURE;\n  }\n\n  fill_inputs(a, b);\n  // transform B from col-major to row-major\n  for (int k = 0; k < kK; ++k) {\n    for (int n = 0; n < kN; ++n) {\n      // original B is in col-major: b[n * ld + k], here ld = kK\n      int8_t val = b[n * kK + k];\n      // target row-major: row index = k, col index = n\n      b_rowmajor[k * kN + n] = val;\n    }\n  }\n#ifdef CHECK\n  // CHECK: printf inputs\n  std::puts(\"\\nMatrix A:\\n\");\n  for (int m = 0; m < kM; ++m) {\n    for (int k = 0; k < kK; ++k) {\n      std::printf(\"%4d \", a[m * kK + k]);\n    }\n    std::puts(\"\");\n  }\n  std::puts(\"\\nMatrix B:\\n\");\n  for (int k = 0; k < kK; ++k) {\n    for (int n = 0; n < kN; ++n) {\n      std::printf(\"%4d \", b[n * kK + k]);\n    }\n    std::puts(\"\");\n  }\n#endif\n  std::memset(c, 0, kM * kN * sizeof(int32_t));\n  std::memset(c_unp, 0, kM * kN * sizeof(int32_t));\n  std::memset(ref, 0, kM * kN * sizeof(int32_t));\n  compute_reference(a, b_rowmajor, ref);\n#ifdef CHECK\n  // CHECK: printf reference\n  std::puts(\"\\nReference result:\\n\");\n  for (int m = 0; m < kM; ++m) {\n    for (int n = 0; n < kN; ++n) {\n      std::printf(\"%6d \", ref[m * kN + n]);\n    }\n    std::puts(\"\");\n  }\n#endif\n  const dim_t reorder_size = aocl_get_reorder_buf_size_s8s8s32os32('c', 'n', 'B', kK, kN);\n  b_reordered = static_cast<int8_t*>(bli_malloc_user(reorder_size, &err));\n  if (!b_reordered) {\n    std::fprintf(stderr, \"Reorder buffer allocation failed\\n\");\n    // Release every buffer allocated so far, including b_rowmajor and c_unp\n    bli_free_user(a);\n    bli_free_user(b);\n    bli_free_user(b_rowmajor);\n    bli_free_user(c);\n    bli_free_user(c_unp);\n    bli_free_user(ref);\n    return EXIT_FAILURE;\n  }\n  aocl_reorder_s8s8s32os32('c', 'n', 'B', b, b_reordered, kK, kN, kK);\n#ifdef CHECK\n  // CHECK: printf reordered B\n  std::puts(\"\\nReordered Matrix B:\\n\");\n  for (int k = 0; k < kK; ++k) {\n    for (int n = 0; n < kN; ++n) {\n      std::printf(\"%4d \", b_reordered[k * kN + n]);\n    }\n    std::puts(\"\");\n  }\n  std::printf(\"\\nReorder buffer size: %zu bytes\\n\", 
static_cast<size_t>(reorder_size));\n#endif\n\n  const int32_t alpha = 1;\n  const int32_t beta = 0;\n  aocl_gemm_s8s8s32os32('r', 'n', 't', kM, kN, kK, alpha, a, kK, 'n', b_reordered, kK, 'r', beta, c, kN, nullptr);\n  aocl_gemm_s8s8s32os32('r', 'n', 't', kM, kN, kK, alpha, a, kK, 'n', b, kK, 'n', beta, c_unp, kN, nullptr);\n#ifdef CHECK\n  // CHECK: printf AOCL result\n  std::puts(\"\\nAOCL GEMM result (with reordered B):\\n\");\n  for (int m = 0; m < kM; ++m) {\n    for (int n = 0; n < kN; ++n) {\n      std::printf(\"%6d \", c[m * kN + n]);\n    }\n    std::puts(\"\");\n  }\n  std::puts(\"\\nAOCL GEMM result (without reordered B):\\n\");\n  for (int m = 0; m < kM; ++m) {\n    for (int n = 0; n < kN; ++n) {\n      std::printf(\"%6d \", c_unp[m * kN + n]);\n    }\n    std::puts(\"\");\n  }\n#endif\n\n  if (check_result(c, ref)) {\n    std::puts(\"AOCL GEMM output (reordered B) matches reference.\");\n  } else {\n    std::puts(\"AOCL GEMM output (reordered B) mismatch detected.\");\n  }\n\n  if (check_result(c_unp, ref)) {\n    std::puts(\"AOCL GEMM output (unreordered B) matches reference.\");\n  } else {\n    std::puts(\"AOCL GEMM output (unreordered B) mismatch detected.\");\n  }\n\n  bli_free_user(a);\n  bli_free_user(b);\n  bli_free_user(b_rowmajor);\n  bli_free_user(b_reordered);\n  bli_free_user(c);\n  bli_free_user(c_unp);\n  bli_free_user(ref);\n  return 0;\n}"
  },
  {
    "path": "kt-kernel/demo/tflops.py",
    "content": "import pandas as pd\nimport matplotlib.pyplot as plt\nimport seaborn as sns\n\n# 读取数据\nfile_path = 'data.txt'  # 替换为你的文件路径\ndf = pd.read_csv(file_path, sep=r'\\s+', names=['m', 'n', 'tflops'])\n\n# 创建数据透视表，行为 m，列为 n，值为 tflops\npivot_table = df.pivot_table(index='m', columns='n', values='tflops')\n\n# 画热力图\nplt.figure(figsize=(10, 8))\nsns.heatmap(pivot_table, annot=True, fmt=\".2f\", cmap='viridis')\nplt.title('TFLOPS Heatmap')\nplt.xlabel('n')\nplt.ylabel('m')\nplt.tight_layout()\nplt.show()\n"
  },
  {
    "path": "kt-kernel/examples/.gitignore",
    "content": "debug"
  },
  {
    "path": "kt-kernel/examples/bench_moe_amx_int8.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nAMX INT8 MoE Benchmark Script\n\nBenchmarks performance of AMX-accelerated INT8 MOE operations with configurable parameters.\nSupports uniform workload distribution across experts and optional CUDA stream mode.\n\nUsage:\n    python bench_moe_amx_int8.py [options]\n\nExamples:\n    # Default parameters\n    python bench_moe_amx_int8.py\n\n    # Custom parameters\n    python bench_moe_amx_int8.py --layer_num 4 --expert_num 256 --workload 8 --use_cuda_stream\n\n    # Full configuration\n    python bench_moe_amx_int8.py --layer_num 2 --expert_num 128 --num_experts_per_tok 8 \\\n        --workload 4 --hidden_size 7168 --intermediate_size 2048 \\\n        --warmup_iter 100 --test_iter 1000 --use_cuda_stream\n\"\"\"\n\nimport os\nimport sys\nimport time\nimport argparse\n\n# Add build path for development\nsys.path.insert(0, os.path.dirname(__file__) + \"/../build\")\n\nimport torch\n\ntry:\n    from kt_kernel import kt_kernel_ext\n\n    HAS_KT_KERNEL = True\nexcept ImportError as e:\n    HAS_KT_KERNEL = False\n    import_error = str(e)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description=\"AMX INT8 MoE Benchmark\", formatter_class=argparse.ArgumentDefaultsHelpFormatter\n    )\n\n    # Model parameters\n    parser.add_argument(\"--layer_num\", type=int, default=2, help=\"Number of MoE layers\")\n    parser.add_argument(\"--expert_num\", type=int, default=256, help=\"Number of experts per layer\")\n    parser.add_argument(\n        \"--num_experts_per_tok\", type=int, default=8, help=\"Number of experts selected per token (top-k)\"\n    )\n    parser.add_argument(\"--hidden_size\", type=int, default=7168, help=\"Hidden dimension size\")\n    parser.add_argument(\"--intermediate_size\", type=int, default=2048, help=\"Intermediate dimension size\")\n\n    # Workload parameters\n    parser.add_argument(\"--workload\", type=int, default=1, help=\"Workload (qlen, number of tokens)\")\n 
   parser.add_argument(\"--max_len\", type=int, default=25600, help=\"Maximum sequence length for buffer allocation\")\n\n    # Benchmark parameters\n    parser.add_argument(\"--warmup_iter\", type=int, default=100, help=\"Number of warmup iterations\")\n    parser.add_argument(\"--test_iter\", type=int, default=1000, help=\"Number of test iterations\")\n\n    # Execution mode\n    parser.add_argument(\"--use_cuda_stream\", action=\"store_true\", help=\"Use CUDA stream mode (submit_with_cuda_stream)\")\n    parser.add_argument(\"--profile\", action=\"store_true\", help=\"Enable PyTorch profiler and export trace.json\")\n    parser.add_argument(\"--profile_path\", type=str, default=\"./trace.json\", help=\"Path to save profile trace\")\n\n    # Worker configuration\n    parser.add_argument(\"--cpuinfer_threads\", type=int, default=60, help=\"Total CPU inference threads\")\n    parser.add_argument(\"--numa_count\", type=int, default=2, help=\"Number of NUMA nodes\")\n    parser.add_argument(\n        \"--num_gpu_experts\", type=int, default=0, help=\"Number of experts to place on GPU (first N experts)\"\n    )\n\n    return parser.parse_args()\n\n\ndef generate_uniform_workload(expert_num, num_experts_per_tok, workload):\n    \"\"\"\n    Generate expert_ids and weights with uniform workload distribution.\n\n    workload = qlen (number of tokens)\n    Each token selects num_experts_per_tok experts.\n    Total expert calls = workload * num_experts_per_tok\n    \"\"\"\n    qlen = workload\n\n    # Randomly select num_experts_per_tok experts (uniform, no duplicates)\n    # All tokens will use the same expert combination\n    selected_experts = torch.randperm(expert_num)[:num_experts_per_tok].tolist()\n\n    # Create expert_ids: all tokens use the same expert combination\n    expert_ids = [selected_experts for _ in range(qlen)]\n\n    # Create on GPU then copy to CPU (faster)\n    expert_ids = torch.tensor(expert_ids, dtype=torch.long, 
device=\"cuda\").to(\"cpu\").contiguous()\n    print(f\"Selected experts (all tokens use same): {selected_experts}\")\n    print(f\"Expert IDs shape: {expert_ids.shape}\")\n\n    # Uniform weights (normalized) - create on GPU then copy\n    weights = torch.ones((qlen, num_experts_per_tok), dtype=torch.float32, device=\"cuda\") / num_experts_per_tok\n    weights = weights.to(\"cpu\").contiguous()\n\n    return expert_ids, weights, qlen\n\n\ndef run_benchmark(args):\n    \"\"\"Run the AMX INT8 MoE benchmark.\"\"\"\n\n    print(\"=\" * 60)\n    print(\"AMX INT8 MoE Benchmark\")\n    print(\"=\" * 60)\n    print(f\"\\nConfiguration:\")\n    print(f\"  Layers:              {args.layer_num}\")\n    print(f\"  Experts per layer:   {args.expert_num}\")\n    print(f\"  Experts per token:   {args.num_experts_per_tok}\")\n    print(f\"  Hidden size:         {args.hidden_size}\")\n    print(f\"  Intermediate size:   {args.intermediate_size}\")\n    print(f\"  Workload (qlen):     {args.workload}\")\n    print(f\"  Use CUDA stream:     {args.use_cuda_stream}\")\n    print(f\"  Warmup iterations:   {args.warmup_iter}\")\n    print(f\"  Test iterations:     {args.test_iter}\")\n    print(f\"  CPU threads:         {args.cpuinfer_threads}\")\n    print(f\"  NUMA nodes:          {args.numa_count}\")\n\n    # Generate uniform workload\n    expert_ids, weights, qlen = generate_uniform_workload(args.expert_num, args.num_experts_per_tok, args.workload)\n    print(f\"\\nActual qlen:           {qlen}\")\n    print(f\"Total expert calls:    {qlen * args.num_experts_per_tok}\")\n\n    with torch.inference_mode():\n        # Initialize CPUInfer\n        if args.numa_count > 1:\n            worker_config = kt_kernel_ext.WorkerPoolConfig()\n            worker_config.subpool_count = args.numa_count\n            worker_config.subpool_numa_map = list(range(args.numa_count))\n            threads_per_numa = args.cpuinfer_threads // args.numa_count\n            worker_config.subpool_thread_count = 
[threads_per_numa] * args.numa_count\n            cpu_infer = kt_kernel_ext.CPUInfer(worker_config)\n        else:\n            cpu_infer = kt_kernel_ext.CPUInfer(args.cpuinfer_threads)\n\n        # Physical to logical mapping (identity)\n        physical_to_logical_map = torch.arange(args.expert_num, dtype=torch.int64, device=\"cpu\").contiguous()\n\n        # GPU experts mask - set first num_gpu_experts to True if specified\n        gpu_experts_mask = torch.zeros(args.expert_num, dtype=torch.bool, device=\"cpu\")\n        if args.num_gpu_experts > 0:\n            num_gpu = min(args.num_gpu_experts, args.expert_num)\n            gpu_experts_mask[:num_gpu] = True\n            print(f\"  GPU experts: {num_gpu} (experts 0-{num_gpu-1})\")\n\n        # Initialize MoE layers\n        print(\"\\nInitializing MoE layers...\")\n        moes = []\n        for layer_idx in range(args.layer_num):\n            # Create random weights on GPU then copy to CPU (faster)\n            gate_proj = (\n                torch.randn(\n                    (args.expert_num, args.intermediate_size, args.hidden_size), dtype=torch.bfloat16, device=\"cuda\"\n                )\n                .to(\"cpu\")\n                .contiguous()\n            )\n            up_proj = (\n                torch.randn(\n                    (args.expert_num, args.intermediate_size, args.hidden_size), dtype=torch.bfloat16, device=\"cuda\"\n                )\n                .to(\"cpu\")\n                .contiguous()\n            )\n            down_proj = (\n                torch.randn(\n                    (args.expert_num, args.hidden_size, args.intermediate_size), dtype=torch.bfloat16, device=\"cuda\"\n                )\n                .to(\"cpu\")\n                .contiguous()\n            )\n\n            # Configure MoE\n            config = kt_kernel_ext.moe.MOEConfig(\n                args.expert_num,\n                args.num_experts_per_tok,\n                args.hidden_size,\n                
args.intermediate_size,\n                gpu_experts_mask.data_ptr(),\n            )\n            config.max_len = args.max_len\n            config.gate_proj = gate_proj.data_ptr()\n            config.up_proj = up_proj.data_ptr()\n            config.down_proj = down_proj.data_ptr()\n            config.pool = cpu_infer.backend_\n\n            moe = kt_kernel_ext.moe.AMXInt8_MOE(config)\n            cpu_infer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n            cpu_infer.sync()\n\n            moes.append(moe)\n            print(f\"  Layer {layer_idx} initialized\")\n\n        # Prepare input/output tensors\n        input_tensor = torch.randn((qlen, args.hidden_size), dtype=torch.bfloat16, device=\"cpu\").contiguous()\n        output_tensor = torch.zeros((qlen, args.hidden_size), dtype=torch.bfloat16, device=\"cpu\").contiguous()\n        bsz_tensor = torch.tensor([qlen], dtype=torch.int32, device=\"cpu\")\n\n        # CUDA stream setup (if enabled)\n        cuda_stream = None\n        if args.use_cuda_stream:\n            if not torch.cuda.is_available():\n                print(\"\\nWarning: CUDA not available, falling back to non-stream mode\")\n                args.use_cuda_stream = False\n            else:\n                cuda_stream = torch.cuda.current_stream().cuda_stream\n                print(f\"\\nUsing CUDA stream: {cuda_stream}\")\n\n        # Warmup\n        print(f\"\\nWarmup ({args.warmup_iter} iterations)...\")\n        for i in range(args.warmup_iter):\n            moe = moes[i % args.layer_num]\n            task = moe.forward_task(\n                bsz_tensor.data_ptr(),\n                args.num_experts_per_tok,\n                expert_ids.data_ptr(),\n                weights.data_ptr(),\n                input_tensor.data_ptr(),\n                output_tensor.data_ptr(),\n                False,  # incremental\n            )\n\n            if args.use_cuda_stream:\n                
cpu_infer.submit_with_cuda_stream(cuda_stream, task)\n                cpu_infer.sync_with_cuda_stream(cuda_stream)\n            else:\n                cpu_infer.submit(task)\n                cpu_infer.sync()\n\n        # Benchmark\n        print(f\"Benchmarking ({args.test_iter} iterations)...\")\n\n        if args.use_cuda_stream:\n            torch.cuda.synchronize()\n\n        # Setup profiler if enabled\n        profiler = None\n        if args.profile:\n            profiler = torch.profiler.profile(\n                activities=[\n                    torch.profiler.ProfilerActivity.CPU,\n                    torch.profiler.ProfilerActivity.CUDA,\n                ],\n                record_shapes=False,\n                with_stack=False,\n            )\n            profiler.__enter__()\n\n        start_time = time.perf_counter()\n\n        for i in range(args.test_iter):\n            moe = moes[i % args.layer_num]\n\n            if args.profile:\n                torch.cuda.nvtx.range_push(f\"iter_{i}\")\n\n            task = moe.forward_task(\n                bsz_tensor.data_ptr(),\n                args.num_experts_per_tok,\n                expert_ids.data_ptr(),\n                weights.data_ptr(),\n                input_tensor.data_ptr(),\n                output_tensor.data_ptr(),\n                False,\n            )\n\n            if args.use_cuda_stream:\n                if args.profile:\n                    torch.cuda.nvtx.range_push(\"submit\")\n                cpu_infer.submit_with_cuda_stream(cuda_stream, task)\n                if args.profile:\n                    torch.cuda.nvtx.range_pop()\n                    torch.cuda.nvtx.range_push(\"sync\")\n                cpu_infer.sync_with_cuda_stream(cuda_stream)\n                if args.profile:\n                    torch.cuda.nvtx.range_pop()\n            else:\n                cpu_infer.submit(task)\n                cpu_infer.sync()\n\n            if args.profile:\n                
torch.cuda.nvtx.range_pop()\n\n        if args.use_cuda_stream:\n            torch.cuda.synchronize()\n\n        end_time = time.perf_counter()\n        total_time = end_time - start_time\n\n        # Export profiler trace\n        if profiler:\n            profiler.__exit__(None, None, None)\n            profiler.export_chrome_trace(args.profile_path)\n            print(f\"\\nProfile trace saved to: {args.profile_path}\")\n\n        # Calculate metrics\n        # Note: each iteration processes ONE layer (round-robin: moe = moes[i % layer_num])\n        time_per_iter_us = total_time / args.test_iter * 1e6\n\n        # Bandwidth calculation\n        # Weight size per expert: 3 * hidden_size * intermediate_size * bytes_per_elem\n        bytes_per_elem = 1.0  # INT8\n        weight_bytes_per_expert = 3 * args.hidden_size * args.intermediate_size * bytes_per_elem\n\n        # Total weight bytes accessed per iteration (one layer per iteration)\n        # Each token activates num_experts_per_tok experts\n        total_experts_activated = qlen * args.num_experts_per_tok\n        weight_bytes_per_iter = total_experts_activated * weight_bytes_per_expert\n\n        bandwidth_gbs = weight_bytes_per_iter * args.test_iter / total_time / 1e9\n\n        # FLOPS calculation\n        # Per expert: 3 * hidden * intermediate * 2 (multiply-add)\n        flops_per_expert = 3 * args.hidden_size * args.intermediate_size * 2\n        total_flops = total_experts_activated * flops_per_expert * args.test_iter\n        tflops = total_flops / total_time / 1e12\n\n        # Results\n        print(\"\\n\" + \"=\" * 60)\n        print(\"Results\")\n        print(\"=\" * 60)\n        print(f\"  Total time:           {total_time:.3f} s\")\n        print(f\"  Time per iteration:   {time_per_iter_us:.2f} us  (= time per layer)\")\n        print(f\"  Memory bandwidth:     {bandwidth_gbs:.2f} GB/s\")\n        print(f\"  Compute throughput:   {tflops:.3f} TFLOPS\")\n        print(\"=\" * 60)\n\n        
return {\n            \"total_time_s\": total_time,\n            \"time_per_iter_us\": time_per_iter_us,\n            \"bandwidth_gbs\": bandwidth_gbs,\n            \"tflops\": tflops,\n        }\n\n\ndef main():\n    args = parse_args()\n\n    if not HAS_KT_KERNEL:\n        print(f\"Error: kt_kernel not available: {import_error}\")\n        sys.exit(1)\n\n    run_benchmark(args)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "kt-kernel/examples/configuration_deepseek_v3.py",
    "content": "from transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import logging\n\nlogger = logging.get_logger(__name__)\n\nDEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}\n\n\nclass DeepseekV3Config(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate a DeepSeek\n    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the\n    defaults will yield a similar configuration to that of the DeepSeek-V3.\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the\n    documentation from [`PretrainedConfig`] for more information.\n    Args:\n        vocab_size (`int`, *optional*, defaults to 129280):\n            Vocabulary size of the DeepSeek model. Defines the number of different tokens that can be represented by the\n            `input_ids` passed when calling [`DeepseekV3Model`]\n        hidden_size (`int`, *optional*, defaults to 7168):\n            Dimension of the hidden representations.\n        intermediate_size (`int`, *optional*, defaults to 18432):\n            Dimension of the MLP representations.\n        moe_intermediate_size (`int`, *optional*, defaults to 2048):\n            Dimension of the MoE representations.\n        num_hidden_layers (`int`, *optional*, defaults to 61):\n            Number of hidden layers in the Transformer decoder.\n        num_nextn_predict_layers (`int`, *optional*, defaults to 1):\n            Number of nextn predict layers in the DeepSeekV3 Model.\n        num_attention_heads (`int`, *optional*, defaults to 128):\n            Number of attention heads for each attention layer in the Transformer decoder.\n        n_shared_experts (`int`, *optional*, defaults to 1):\n            Number of shared experts, None means dense model.\n        n_routed_experts (`int`, *optional*, defaults to 
256):\n            Number of routed experts, None means dense model.\n        routed_scaling_factor (`float`, *optional*, defaults to 2.5):\n            Scaling factor for routed experts.\n        topk_method (`str`, *optional*, defaults to `noaux_tc`):\n            Topk method used in routed gate.\n        n_group (`int`, *optional*, defaults to 8):\n            Number of groups for routed experts.\n        topk_group (`int`, *optional*, defaults to 4):\n            Number of selected groups for each token (the selected experts are restricted to `topk_group` groups).\n        num_experts_per_tok (`int`, *optional*, defaults to 8):\n            Number of selected experts, None means dense model.\n        moe_layer_freq (`int`, *optional*, defaults to 1):\n            The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.\n        first_k_dense_replace (`int`, *optional*, defaults to 3):\n            Number of dense layers in shallow layers (embed->dense->dense->...->dense->moe->moe...->lm_head).\n                                                            \\--k dense layers--/\n        norm_topk_prob (`bool`, *optional*, defaults to True):\n            Whether to normalize the weights of the routed experts.\n        scoring_func (`str`, *optional*, defaults to 'sigmoid'):\n            Method of computing expert weights.\n        aux_loss_alpha (`float`, *optional*, defaults to 0.001):\n            Auxiliary loss weight coefficient.\n        seq_aux (`bool`, *optional*, defaults to True):\n            Whether to compute the auxiliary loss for each individual sample.\n        num_key_value_heads (`int`, *optional*):\n            This is the number of key_value heads that should be used to implement Grouped Query Attention. 
If\n            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if\n            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When\n            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed\n            by mean-pooling all the original heads within that group. For more details check out [this\n            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to\n            `num_attention_heads`.\n        hidden_act (`str` or `function`, *optional*, defaults to `\"silu\"`):\n            The non-linear activation function (function or string) in the decoder.\n        max_position_embeddings (`int`, *optional*, defaults to 4096):\n            The maximum sequence length that this model might ever be used with.\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        rms_norm_eps (`float`, *optional*, defaults to 1e-06):\n            The epsilon used by the rms normalization layers.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values attentions (not used by all models). 
Only\n            relevant if `config.is_decoder=True`.\n        pad_token_id (`int`, *optional*):\n            Padding token id.\n        bos_token_id (`int`, *optional*, defaults to 0):\n            Beginning of stream token id.\n        eos_token_id (`int`, *optional*, defaults to 1):\n            End of stream token id.\n        tie_word_embeddings (`bool`, *optional*, defaults to `False`):\n            Whether to tie weight embeddings.\n        rope_theta (`float`, *optional*, defaults to 10000.0):\n            The base period of the RoPE embeddings.\n        rope_scaling (`Dict`, *optional*):\n            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling\n            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is\n            `{\"type\": strategy name, \"factor\": scaling factor}`. When using this flag, don't update\n            `max_position_embeddings` to the expected new maximum.\n        attention_bias (`bool`, *optional*, defaults to `False`):\n            Whether to use a bias in the query, key, value and output projection layers during self-attention.\n        attention_dropout (`float`, *optional*, defaults to 0.0):\n            The dropout ratio for the attention probabilities.\n    ```python\n    >>> from transformers import DeepseekV3Model, DeepseekV3Config\n    >>> # Initializing a Deepseek-V3 style configuration\n    >>> configuration = DeepseekV3Config()\n    >>> # Initializing a model from the configuration\n    >>> model = DeepseekV3Model(configuration)\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\"\"\"\n\n    model_type = \"deepseek_v3\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        vocab_size=129280,\n        hidden_size=7168,\n        intermediate_size=18432,\n        moe_intermediate_size=2048,\n        num_hidden_layers=61,\n        num_nextn_predict_layers=1,\n        num_attention_heads=128,\n        
num_key_value_heads=128,\n        n_shared_experts = 1,\n        n_routed_experts = 256,\n        ep_size = 1,\n        routed_scaling_factor = 2.5,\n        kv_lora_rank = 512,\n        q_lora_rank = 1536,\n        qk_rope_head_dim = 64,\n        v_head_dim = 128,\n        qk_nope_head_dim = 128,\n        topk_method = 'noaux_tc',\n        n_group = 8,\n        topk_group = 4,\n        num_experts_per_tok = 8,\n        moe_layer_freq = 1,\n        first_k_dense_replace = 3,\n        norm_topk_prob = True,\n        scoring_func = 'sigmoid',\n        hidden_act=\"silu\",\n        max_position_embeddings=4096,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        pad_token_id=None,\n        bos_token_id=0,\n        eos_token_id=1,\n        tie_word_embeddings=False,\n        rope_theta=10000.0,\n        rope_scaling=None,\n        attention_bias=False,\n        attention_dropout=0.0,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.moe_intermediate_size = moe_intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_nextn_predict_layers = num_nextn_predict_layers\n        self.num_attention_heads = num_attention_heads\n        self.n_shared_experts = n_shared_experts\n        self.n_routed_experts = n_routed_experts\n        self.ep_size = ep_size\n        self.routed_scaling_factor = routed_scaling_factor\n        self.kv_lora_rank = kv_lora_rank\n        self.q_lora_rank = q_lora_rank\n        self.qk_rope_head_dim = qk_rope_head_dim\n        self.v_head_dim = v_head_dim\n        self.qk_nope_head_dim = qk_nope_head_dim\n        self.topk_method = topk_method\n        self.n_group = n_group\n        self.topk_group = topk_group\n        self.num_experts_per_tok = num_experts_per_tok\n        self.moe_layer_freq = 
moe_layer_freq\n        self.first_k_dense_replace = first_k_dense_replace\n        self.norm_topk_prob = norm_topk_prob\n        self.scoring_func = scoring_func\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )"
  },
  {
    "path": "kt-kernel/examples/modeling_deepseek_v3.py",
    "content": "# coding=utf-8\n# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\" PyTorch DeepSeek model.\"\"\"\nimport math\nimport warnings\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n\nfrom transformers.activations import ACT2FN\nfrom transformers.cache_utils import Cache, DynamicCache, StaticCache\nfrom transformers.modeling_attn_mask_utils import (\n    AttentionMaskConverter,\n    _prepare_4d_attention_mask,\n    _prepare_4d_causal_attention_mask,\n)\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPast,\n    CausalLMOutputWithPast,\n    SequenceClassifierOutputWithPast,\n)\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.pytorch_utils import (\n    ALL_LAYERNORM_LAYERS,\n    is_torch_greater_or_equal_than_1_13,\n)\nfrom transformers.utils import (\n    add_start_docstrings,\n    add_start_docstrings_to_model_forward,\n    is_flash_attn_2_available,\n    
is_flash_attn_greater_or_equal_2_10,\n    logging,\n    replace_return_docstrings,\n)\nfrom transformers.utils.import_utils import is_torch_fx_available\nfrom configuration_deepseek_v3 import DeepseekV3Config\nimport torch.distributed as dist\nimport numpy as np\n\nif is_flash_attn_2_available():\n    from flash_attn import flash_attn_func, flash_attn_varlen_func\n    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa\n\n\n# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.\n# It means that the function will not be traced through and simply appear as a node in the graph.\nif is_torch_fx_available():\n    if not is_torch_greater_or_equal_than_1_13:\n        import torch.fx\n\n    _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)\n\n\nlogger = logging.get_logger(__name__)\n\n_CONFIG_FOR_DOC = \"DeepseekV3Config\"\n\n\ndef _get_unpad_data(attention_mask):\n    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)\n    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()\n    max_seqlen_in_batch = seqlens_in_batch.max().item()\n    cu_seqlens = F.pad(\n        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)\n    )\n    return (\n        indices,\n        cu_seqlens,\n        max_seqlen_in_batch,\n    )\n\n\nclass DeepseekV3RMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        DeepseekV3RMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n        self.hidden_size = hidden_size\n\n    def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + 
self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n\nALL_LAYERNORM_LAYERS.append(DeepseekV3RMSNorm)\n\n\nclass DeepseekV3RotaryEmbedding(nn.Module):\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):\n        super().__init__()\n\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        inv_freq = 1.0 / (\n            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)\n        )\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        # Build here to make `torch.jit.trace` work.\n        self._set_cos_sin_cache(\n            seq_len=max_position_embeddings,\n            device=self.inv_freq.device,\n            dtype=torch.get_default_dtype(),\n        )\n        self.max_seq_len_cached = None\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n        t = torch.arange(\n            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype\n        )\n\n        freqs = torch.outer(t, self.inv_freq.to(t.device))\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n    def forward(self, x, seq_len=None):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:\n            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)\n\n        return (\n            self.cos_cached[:seq_len].to(dtype=x.dtype),\n            self.sin_cached[:seq_len].to(dtype=x.dtype),\n        )\n\n\n# Copied from 
transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV3\nclass DeepseekV3LinearScalingRotaryEmbedding(DeepseekV3RotaryEmbedding):\n    \"\"\"DeepseekV3RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev\"\"\"\n\n    def __init__(\n        self,\n        dim,\n        max_position_embeddings=2048,\n        base=10000,\n        device=None,\n        scaling_factor=1.0,\n    ):\n        self.scaling_factor = scaling_factor\n        super().__init__(dim, max_position_embeddings, base, device)\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n        t = torch.arange(\n            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype\n        )\n        t = t / self.scaling_factor\n\n        freqs = torch.outer(t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n\n# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV3\nclass DeepseekV3DynamicNTKScalingRotaryEmbedding(DeepseekV3RotaryEmbedding):\n    \"\"\"DeepseekV3RotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla\"\"\"\n\n    def __init__(\n        self,\n        dim,\n        max_position_embeddings=2048,\n        base=10000,\n        device=None,\n        scaling_factor=1.0,\n    ):\n        self.scaling_factor = scaling_factor\n        super().__init__(dim, max_position_embeddings, base, device)\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n\n        if seq_len > self.max_position_embeddings:\n            base = self.base * (\n                (self.scaling_factor * seq_len / self.max_position_embeddings)\n                - (self.scaling_factor - 1)\n            ) ** (self.dim / (self.dim - 2))\n            inv_freq = 1.0 / (\n                base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)\n            )\n            self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        t = torch.arange(\n            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype\n        )\n\n        freqs = torch.outer(t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n\n# Inverse dim formula to find dim based on number of rotations\ndef yarn_find_correction_dim(\n    num_rotations, dim, base=10000, max_position_embeddings=2048\n):\n    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (\n        2 * math.log(base)\n    )\n\n\n# Find dim range bounds based on rotations\ndef yarn_find_correction_range(\n    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048\n):\n    low = math.floor(\n        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)\n    )\n    high = math.ceil(\n     
   yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)\n    )\n    return max(low, 0), min(high, dim - 1)  # Clamp values just in case\n\n\ndef yarn_get_mscale(scale=1, mscale=1):\n    if scale <= 1:\n        return 1.0\n    return 0.1 * mscale * math.log(scale) + 1.0\n\n\ndef yarn_linear_ramp_mask(min, max, dim):\n    if min == max:\n        max += 0.001  # Prevent singularity\n\n    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)\n    ramp_func = torch.clamp(linear_func, 0, 1)\n    return ramp_func\n\n\nclass DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding):\n\n    def __init__(\n        self,\n        dim,\n        max_position_embeddings=2048,\n        base=10000,\n        device=None,\n        scaling_factor=1.0,\n        original_max_position_embeddings=4096,\n        beta_fast=32,\n        beta_slow=1,\n        mscale=1,\n        mscale_all_dim=0,\n    ):\n        self.scaling_factor = scaling_factor\n        self.original_max_position_embeddings = original_max_position_embeddings\n        self.beta_fast = beta_fast\n        self.beta_slow = beta_slow\n        self.mscale = mscale\n        self.mscale_all_dim = mscale_all_dim\n        super().__init__(dim, max_position_embeddings, base, device)\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n        dim = self.dim\n\n        freq_extra = 1.0 / (\n            self.base\n            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)\n        )\n        freq_inter = 1.0 / (\n            self.scaling_factor\n            * self.base\n            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)\n        )\n\n        low, high = yarn_find_correction_range(\n            self.beta_fast,\n            self.beta_slow,\n            dim,\n            self.base,\n            self.original_max_position_embeddings,\n        )\n        inv_freq_mask = 1.0 - 
yarn_linear_ramp_mask(low, high, dim // 2).to(\n            device=device, dtype=torch.float32\n        )\n        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        t = torch.arange(seq_len, device=device, dtype=torch.float32)\n\n        freqs = torch.outer(t, inv_freq)\n\n        _mscale = float(\n            yarn_get_mscale(self.scaling_factor, self.mscale)\n            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)\n        )\n\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\n            \"cos_cached\", (emb.cos() * _mscale).to(dtype), persistent=False\n        )\n        self.register_buffer(\n            \"sin_cached\", (emb.sin() * _mscale).to(dtype), persistent=False\n        )\n\n\n# Copied from transformers.models.llama.modeling_llama.rotate_half\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\n# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n    Args:\n        q (`torch.Tensor`): The query tensor.\n        k (`torch.Tensor`): The key tensor.\n        cos (`torch.Tensor`): The cosine part of the rotary embedding.\n        sin (`torch.Tensor`): The sine part of the rotary embedding.\n        position_ids (`torch.Tensor`):\n            The position indices of the tokens corresponding to the query and key tensors. 
For example, this can be\n            used to pass offsetted position ids when working with a KV-cache.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos[position_ids].unsqueeze(unsqueeze_dim)\n    sin = sin[position_ids].unsqueeze(unsqueeze_dim)\n\n    b, h, s, d = q.shape\n    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)\n\n    b, h, s, d = k.shape\n    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)\n\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\nclass DeepseekV3MLP(nn.Module):\n    def __init__(self, config, hidden_size=None, intermediate_size=None):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size\n        self.intermediate_size = (\n            config.intermediate_size if intermediate_size is None else intermediate_size\n        )\n\n        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n        return down_proj\n\n\nclass MoEGate(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.top_k = config.num_experts_per_tok\n        self.n_routed_experts = config.n_routed_experts\n        self.routed_scaling_factor = config.routed_scaling_factor\n        self.scoring_func = config.scoring_func\n        self.topk_method = config.topk_method\n        self.n_group = config.n_group\n        self.topk_group = config.topk_group\n\n        # topk selection algorithm\n        self.norm_topk_prob = config.norm_topk_prob\n        self.gating_dim = config.hidden_size\n        self.weight = nn.Parameter(\n            torch.empty((self.n_routed_experts, self.gating_dim))\n        )\n        if self.topk_method == \"noaux_tc\":\n            self.e_score_correction_bias = nn.Parameter(\n                torch.empty((self.n_routed_experts))\n            )\n        self.reset_parameters()\n\n    def reset_parameters(self) -> None:\n        import torch.nn.init as init\n\n        init.kaiming_uniform_(self.weight, a=math.sqrt(5))\n\n    def forward(self, hidden_states):\n        bsz, seq_len, h = hidden_states.shape\n        ### compute gating score\n        hidden_states = hidden_states.view(-1, h)\n        logits = F.linear(\n            hidden_states.type(torch.float32), self.weight.type(torch.float32), None\n        )\n        if self.scoring_func == \"sigmoid\":\n            scores = logits.sigmoid()\n        else:\n            raise NotImplementedError(\n                f\"unsupported scoring function for MoE gating: {self.scoring_func}\"\n            )\n\n        ### select top-k experts\n        if self.topk_method == \"noaux_tc\":\n            # assert not 
self.training\n            scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)\n            group_scores = (\n                scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim = -1)\n            )  # [n, n_group]\n            group_idx = torch.topk(\n                group_scores, k=self.topk_group, dim=-1, sorted=False\n            )[\n                1\n            ]  # [n, top_k_group]\n            group_mask = torch.zeros_like(group_scores)  # [n, n_group]\n            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]\n            score_mask = (\n                group_mask.unsqueeze(-1)\n                .expand(\n                    bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group\n                )\n                .reshape(bsz * seq_len, -1)\n            )  # [n, e]\n            tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), float(\"-inf\"))  # [n, e]\n            _, topk_idx = torch.topk(\n                tmp_scores, k=self.top_k, dim=-1, sorted=False\n            )\n            topk_weight = scores.gather(1, topk_idx)\n        else:\n            raise NotImplementedError(\n                f\"unsupported top-k method for MoE gating: {self.topk_method}\"\n            )\n\n        ### norm gate to sum 1\n        if self.top_k > 1 and self.norm_topk_prob:\n            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20\n            topk_weight = topk_weight / denominator\n        topk_weight = topk_weight * self.routed_scaling_factor # must multiply the scaling factor\n\n        return topk_idx, topk_weight\n\nclass DeepseekV3MoE(nn.Module):\n    \"\"\"\n    A mixed expert module containing shared experts.\n    \"\"\"\n\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.num_experts_per_tok = config.num_experts_per_tok\n\n        if hasattr(config, \"ep_size\") and config.ep_size > 1:\n   
         assert config.ep_size == dist.get_world_size()\n            self.ep_size = config.ep_size\n            self.experts_per_rank = config.n_routed_experts // config.ep_size\n            self.ep_rank = dist.get_rank()\n            self.experts = nn.ModuleList(\n                [\n                    (\n                        DeepseekV3MLP(\n                            config, intermediate_size=config.moe_intermediate_size\n                        )\n                        if i >= self.ep_rank * self.experts_per_rank\n                        and i < (self.ep_rank + 1) * self.experts_per_rank\n                        else None\n                    )\n                    for i in range(config.n_routed_experts)\n                ]\n            )\n        else:\n            self.ep_size = 1\n            self.experts_per_rank = config.n_routed_experts\n            self.ep_rank = 0\n            self.experts = nn.ModuleList(\n                [\n                    DeepseekV3MLP(\n                        config, intermediate_size=config.moe_intermediate_size\n                    )\n                    for i in range(config.n_routed_experts)\n                ]\n            )\n        self.gate = MoEGate(config)\n        if config.n_shared_experts is not None:\n            intermediate_size = config.moe_intermediate_size * config.n_shared_experts\n            self.shared_experts = DeepseekV3MLP(\n                config=config, intermediate_size=intermediate_size\n            )\n\n    def forward(self, hidden_states):\n        identity = hidden_states\n        orig_shape = hidden_states.shape\n        topk_idx, topk_weight = self.gate(hidden_states)\n        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])\n        flat_topk_idx = topk_idx.view(-1)\n        if not self.training:\n            y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)\n        if self.config.n_shared_experts is not None:\n            y = y + 
self.shared_experts(identity)\n        return y\n\n    @torch.no_grad()\n    def moe_infer(self, x, topk_ids, topk_weight):\n        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))\n        cnts.scatter_(1, topk_ids, 1)\n        tokens_per_expert = cnts.sum(dim=0)\n        idxs = topk_ids.view(-1).argsort()\n        sorted_tokens = x[idxs // topk_ids.shape[1]]\n        sorted_tokens_shape = sorted_tokens.shape\n        if self.ep_size > 1:\n            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)\n            tokens_per_expert_group = tokens_per_expert.new_empty(\n                tokens_per_expert.shape[0]\n            )\n            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)\n            output_splits = (\n                tokens_per_expert_group.view(self.ep_size, -1)\n                .sum(1)\n                .cpu()\n                .numpy()\n                .tolist()\n            )\n            gathered_tokens = sorted_tokens.new_empty(\n                tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1]\n            )\n            input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()\n            dist.all_to_all(\n                list(gathered_tokens.split(output_splits)),\n                list(sorted_tokens.split(input_split_sizes)),\n            )\n            tokens_per_expert_post_gather = tokens_per_expert_group.view(\n                self.ep_size, self.experts_per_rank\n            ).sum(dim=0)\n            gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)\n            s = 0\n            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):\n                gatherd_idxs[s : s + k] = i % self.experts_per_rank\n                s += k\n            gatherd_idxs = gatherd_idxs.argsort()\n            sorted_tokens = gathered_tokens[gatherd_idxs]\n            tokens_per_expert = tokens_per_expert_post_gather\n        tokens_per_expert 
= tokens_per_expert.cpu().numpy()\n\n        outputs = []\n        start_idx = 0\n        for i, num_tokens in enumerate(tokens_per_expert):\n            end_idx = start_idx + num_tokens\n            if num_tokens == 0:\n                continue\n            expert = self.experts[i + self.ep_rank * self.experts_per_rank]\n            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n            expert_out = expert(tokens_for_this_expert)\n            outputs.append(expert_out)\n            start_idx = end_idx\n\n        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n        if self.ep_size > 1:\n            new_x = torch.empty_like(outs)\n            new_x[gatherd_idxs] = outs\n            gathered_tokens = new_x.new_empty(*sorted_tokens_shape)\n            dist.all_to_all(\n                list(gathered_tokens.split(input_split_sizes)),\n                list(new_x.split(output_splits)),\n            )\n            outs = gathered_tokens\n\n        new_x = torch.empty_like(outs)\n        new_x[idxs] = outs\n        final_out = (\n            new_x.view(*topk_ids.shape, -1)\n            .type(topk_weight.dtype)\n            .mul_(topk_weight.unsqueeze(dim=-1))\n            .sum(dim=1)\n            .type(new_x.dtype)\n        )\n        return final_out\n\n\n# Copied from transformers.models.llama.modeling_llama.repeat_kv\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(\n        batch, num_key_value_heads, n_rep, slen, head_dim\n    )\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\n# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV3\nclass DeepseekV3Attention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config: DeepseekV3Config, layer_idx: Optional[int] = None):\n        super().__init__()\n        self.config = config\n        self.layer_idx = layer_idx\n        if layer_idx is None:\n            logger.warning_once(\n                f\"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will \"\n                \"lead to errors during the forward call if caching is used. 
Please make sure to provide a `layer_idx` \"\n                \"when creating this class.\"\n            )\n\n        self.attention_dropout = config.attention_dropout\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n\n        self.max_position_embeddings = config.max_position_embeddings\n        self.rope_theta = config.rope_theta\n        self.q_lora_rank = config.q_lora_rank\n        self.qk_rope_head_dim = config.qk_rope_head_dim\n        self.kv_lora_rank = config.kv_lora_rank\n        self.v_head_dim = config.v_head_dim\n        self.qk_nope_head_dim = config.qk_nope_head_dim\n        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim\n\n        self.is_causal = True\n\n        if self.q_lora_rank is None:\n            self.q_proj = nn.Linear(\n                self.hidden_size, self.num_heads * self.q_head_dim, bias=False\n            )\n        else:\n            self.q_a_proj = nn.Linear(\n                self.hidden_size, config.q_lora_rank, bias=config.attention_bias\n            )\n            self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)\n            self.q_b_proj = nn.Linear(\n                config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False\n            )\n\n        self.kv_a_proj_with_mqa = nn.Linear(\n            self.hidden_size,\n            config.kv_lora_rank + config.qk_rope_head_dim,\n            bias=config.attention_bias,\n        )\n        self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank)\n        self.kv_b_proj = nn.Linear(\n            config.kv_lora_rank,\n            self.num_heads\n            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),\n            bias=False,\n        )\n\n        self.o_proj = nn.Linear(\n            self.num_heads * self.v_head_dim,\n            self.hidden_size,\n            bias=config.attention_bias,\n        )\n        self._init_rope()\n\n        self.softmax_scale = self.q_head_dim 
** (-0.5)\n        if self.config.rope_scaling is not None:\n            mscale_all_dim = self.config.rope_scaling.get(\"mscale_all_dim\", 0)\n            scaling_factor = self.config.rope_scaling[\"factor\"]\n            if mscale_all_dim:\n                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)\n                self.softmax_scale = self.softmax_scale * mscale * mscale\n\n    def _init_rope(self):\n        if self.config.rope_scaling is None:\n            self.rotary_emb = DeepseekV3RotaryEmbedding(\n                self.qk_rope_head_dim,\n                max_position_embeddings=self.max_position_embeddings,\n                base=self.rope_theta,\n            )\n        else:\n            scaling_type = self.config.rope_scaling[\"type\"]\n            scaling_factor = self.config.rope_scaling[\"factor\"]\n            if scaling_type == \"linear\":\n                self.rotary_emb = DeepseekV3LinearScalingRotaryEmbedding(\n                    self.qk_rope_head_dim,\n                    max_position_embeddings=self.max_position_embeddings,\n                    scaling_factor=scaling_factor,\n                    base=self.rope_theta,\n                )\n            elif scaling_type == \"dynamic\":\n                self.rotary_emb = DeepseekV3DynamicNTKScalingRotaryEmbedding(\n                    self.qk_rope_head_dim,\n                    max_position_embeddings=self.max_position_embeddings,\n                    scaling_factor=scaling_factor,\n                    base=self.rope_theta,\n                )\n            elif scaling_type == \"yarn\":\n                kwargs = {\n                    key: self.config.rope_scaling[key]\n                    for key in [\n                        \"original_max_position_embeddings\",\n                        \"beta_fast\",\n                        \"beta_slow\",\n                        \"mscale\",\n                        \"mscale_all_dim\",\n                    ]\n                    if key in 
self.config.rope_scaling\n                }\n                self.rotary_emb = DeepseekV3YarnRotaryEmbedding(\n                    self.qk_rope_head_dim,\n                    max_position_embeddings=self.max_position_embeddings,\n                    scaling_factor=scaling_factor,\n                    base=self.rope_theta,\n                    **kwargs,\n                )\n            else:\n                raise ValueError(f\"Unknown RoPE scaling type {scaling_type}\")\n\n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return (\n            tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim)\n            .transpose(1, 2)\n            .contiguous()\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        if \"padding_mask\" in kwargs:\n            warnings.warn(\n                \"Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure to use `attention_mask` instead.\"\n            )\n        bsz, q_len, _ = hidden_states.size()\n\n        if self.q_lora_rank is None:\n            q = self.q_proj(hidden_states)\n        else:\n            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))\n        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)\n        q_nope, q_pe = torch.split(\n            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1\n        )\n\n        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)\n        compressed_kv, k_pe = torch.split(\n            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n        )\n        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)\n        kv = (\n            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))\n            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)\n            .transpose(1, 2)\n        )\n\n        k_nope, value_states = torch.split(\n            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1\n        )\n        kv_seq_len = value_states.shape[-2]\n        if past_key_value is not None:\n            if self.layer_idx is None:\n                raise ValueError(\n                    f\"The cache structure has changed since version v4.36. 
If you are using {self.__class__.__name__} \"\n                    \"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class \"\n                    \"with a layer index.\"\n                )\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\n        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)\n\n        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)\n\n        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)\n        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope\n        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe\n\n        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)\n        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope\n        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe\n        if past_key_value is not None:\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n            key_states, value_states = past_key_value.update(\n                key_states, value_states, self.layer_idx, cache_kwargs\n            )\n\n        attn_weights = (\n            torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale\n        )\n\n        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is\"\n                f\" {attn_weights.size()}\"\n            )\n        assert attention_mask is not None\n        if attention_mask is not None:\n            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}\"\n                )\n            attn_weights = attn_weights + 
attention_mask\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(\n            attn_weights, dim=-1, dtype=torch.float32\n        ).to(query_states.dtype)\n        attn_weights = nn.functional.dropout(\n            attn_weights, p=self.attention_dropout, training=self.training\n        )\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n\n        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)\n\n        attn_output = self.o_proj(attn_output)\n\n        if not output_attentions:\n            attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n\n# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV3\nclass DeepseekV3FlashAttention2(DeepseekV3Attention):\n    \"\"\"\n    DeepseekV3 flash attention module. This module inherits from `DeepseekV3Attention` as the weights of the module stay\n    untouched. The only required change is in the forward pass, which needs to correctly call the public API of\n    flash attention and deal with padding tokens in case the input contains any of them.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n        # TODO: Should be removed once Flash Attention for ROCm is bumped to 2.1.\n        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. 
Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.\n        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).\n        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.LongTensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        # DeepseekV3FlashAttention2 attention does not support output_attentions\n        if \"padding_mask\" in kwargs:\n            warnings.warn(\n                \"Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure to use `attention_mask` instead.\"\n            )\n\n            # overwrite attention_mask with padding_mask\n            attention_mask = kwargs.pop(\"padding_mask\")\n\n        output_attentions = False\n\n        bsz, q_len, _ = hidden_states.size()\n\n        if self.q_lora_rank is None:\n            q = self.q_proj(hidden_states)\n        else:\n            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))\n        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)\n        q_nope, q_pe = torch.split(\n            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1\n        )\n\n        # Flash attention requires the input to have the shape\n        # batch_size x seq_length x num_heads x head_dim\n        # therefore we just need to keep the original shape\n        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)\n        compressed_kv, k_pe = torch.split(\n            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n        )\n        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)\n        kv = (\n            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))\n            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)\n            .transpose(1, 2)\n        )\n\n        k_nope, value_states = torch.split(\n            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1\n        )\n\n        kv_seq_len = value_states.shape[-2]\n        if past_key_value is not None:\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\n\n        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)\n        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)\n\n        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)\n        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope\n        query_states[:, :, :, 
self.qk_nope_head_dim :] = q_pe\n\n        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)\n        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope\n        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe\n\n        if self.q_head_dim != self.v_head_dim:\n            value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])\n\n        if past_key_value is not None:\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n            key_states, value_states = past_key_value.update(\n                key_states, value_states, self.layer_idx, cache_kwargs\n            )\n\n        # TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache\n        # to be able to avoid many of these transpose/reshape/view calls.\n        query_states = query_states.transpose(1, 2)\n        key_states = key_states.transpose(1, 2)\n        value_states = value_states.transpose(1, 2)\n\n        dropout_rate = self.attention_dropout if self.training else 0.0\n\n        # In PEFT, we usually cast the layer norms to float32 for training stability reasons,\n        # therefore the input hidden states get silently cast to float32. Hence, we need to\n        # cast them back to the correct dtype just to be sure everything works as expected.\n        # This might slow down training & inference so it is recommended to not cast the LayerNorms\n        # to fp32. 
(DeepseekV3RMSNorm handles it correctly)\n\n        input_dtype = query_states.dtype\n        if input_dtype == torch.float32:\n            # Handle the case where the model is quantized\n            if hasattr(self.config, \"_pre_quantization_dtype\"):\n                target_dtype = self.config._pre_quantization_dtype\n            elif torch.is_autocast_enabled():\n                target_dtype = torch.get_autocast_gpu_dtype()\n            else:\n                target_dtype = (\n                    self.q_proj.weight.dtype\n                    if self.q_lora_rank is None\n                    else self.q_a_proj.weight.dtype\n                )\n\n            logger.warning_once(\n                f\"The input hidden states seems to be silently casted in float32, this might be related to\"\n                f\" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in\"\n                f\" {target_dtype}.\"\n            )\n\n            query_states = query_states.to(target_dtype)\n            key_states = key_states.to(target_dtype)\n            value_states = value_states.to(target_dtype)\n\n        attn_output = self._flash_attention_forward(\n            query_states,\n            key_states,\n            value_states,\n            attention_mask,\n            q_len,\n            dropout=dropout_rate,\n            softmax_scale=self.softmax_scale,\n        )\n        if self.q_head_dim != self.v_head_dim:\n            attn_output = attn_output[:, :, :, : self.v_head_dim]\n\n        attn_output = attn_output.reshape(\n            bsz, q_len, self.num_heads * self.v_head_dim\n        ).contiguous()\n        attn_output = self.o_proj(attn_output)\n\n        if not output_attentions:\n            attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n    def _flash_attention_forward(\n        self,\n        query_states,\n        key_states,\n        value_states,\n        attention_mask,\n        
query_length,\n        dropout=0.0,\n        softmax_scale=None,\n    ):\n        \"\"\"\n        Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token,\n        first unpads the input, then computes the attention scores and pads the final attention scores.\n\n        Args:\n            query_states (`torch.Tensor`):\n                Input query states to be passed to Flash Attention API\n            key_states (`torch.Tensor`):\n                Input key states to be passed to Flash Attention API\n            value_states (`torch.Tensor`):\n                Input value states to be passed to Flash Attention API\n            attention_mask (`torch.Tensor`):\n                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the\n                position of padding tokens and 1 for the position of non-padding tokens.\n            dropout (`float`, *optional*):\n                Attention dropout\n            softmax_scale (`float`, *optional*):\n                The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)\n        \"\"\"\n        if not self._flash_attn_uses_top_left_mask:\n            causal = self.is_causal\n        else:\n            # TODO: Remove the `query_length != 1` check once Flash Attention for ROCm is bumped to 2.1. 
For details, please see the comment in DeepseekV3FlashAttention2 __init__.\n            causal = self.is_causal and query_length != 1\n\n        # Contains at least one padding token in the sequence\n        if attention_mask is not None:\n            batch_size = query_states.shape[0]\n            (\n                query_states,\n                key_states,\n                value_states,\n                indices_q,\n                cu_seq_lens,\n                max_seq_lens,\n            ) = self._upad_input(\n                query_states, key_states, value_states, attention_mask, query_length\n            )\n\n            cu_seqlens_q, cu_seqlens_k = cu_seq_lens\n            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens\n\n            attn_output_unpad = flash_attn_varlen_func(\n                query_states,\n                key_states,\n                value_states,\n                cu_seqlens_q=cu_seqlens_q,\n                cu_seqlens_k=cu_seqlens_k,\n                max_seqlen_q=max_seqlen_in_batch_q,\n                max_seqlen_k=max_seqlen_in_batch_k,\n                dropout_p=dropout,\n                softmax_scale=softmax_scale,\n                causal=causal,\n            )\n\n            attn_output = pad_input(\n                attn_output_unpad, indices_q, batch_size, query_length\n            )\n        else:\n            attn_output = flash_attn_func(\n                query_states,\n                key_states,\n                value_states,\n                dropout,\n                softmax_scale=softmax_scale,\n                causal=causal,\n            )\n\n        return attn_output\n\n    def _upad_input(\n        self, query_layer, key_layer, value_layer, attention_mask, query_length\n    ):\n        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)\n        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape\n\n        key_layer = index_first_axis(\n            
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),\n            indices_k,\n        )\n        value_layer = index_first_axis(\n            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),\n            indices_k,\n        )\n        if query_length == kv_seq_len:\n            query_layer = index_first_axis(\n                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),\n                indices_k,\n            )\n            cu_seqlens_q = cu_seqlens_k\n            max_seqlen_in_batch_q = max_seqlen_in_batch_k\n            indices_q = indices_k\n        elif query_length == 1:\n            max_seqlen_in_batch_q = 1\n            cu_seqlens_q = torch.arange(\n                batch_size + 1, dtype=torch.int32, device=query_layer.device\n            )  # There is a memcpy here, that is very bad.\n            indices_q = cu_seqlens_q[:-1]\n            query_layer = query_layer.squeeze(1)\n        else:\n            # The -q_len: slice assumes left padding.\n            attention_mask = attention_mask[:, -query_length:]\n            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(\n                query_layer, attention_mask\n            )\n\n        return (\n            query_layer,\n            key_layer,\n            value_layer,\n            indices_q,\n            (cu_seqlens_q, cu_seqlens_k),\n            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),\n        )\n\n\nATTENTION_CLASSES = {\n    \"eager\": DeepseekV3Attention,\n    \"flash_attention_2\": DeepseekV3FlashAttention2,\n}\n\n\nclass DeepseekV3DecoderLayer(nn.Module):\n    def __init__(self, config: DeepseekV3Config, layer_idx: int):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n\n        self.self_attn = ATTENTION_CLASSES[config._attn_implementation](\n            config=config, layer_idx=layer_idx\n        )\n\n        self.mlp = (\n            DeepseekV3MoE(config)\n     
       if (\n                config.n_routed_experts is not None\n                and layer_idx >= config.first_k_dense_replace\n                and layer_idx % config.moe_layer_freq == 0\n            )\n            else DeepseekV3MLP(config)\n        )\n        self.input_layernorm = DeepseekV3RMSNorm(\n            config.hidden_size, eps=config.rms_norm_eps\n        )\n        self.post_attention_layernorm = DeepseekV3RMSNorm(\n            config.hidden_size, eps=config.rms_norm_eps\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        output_attentions: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs,\n    ) -> Tuple[\n        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]\n    ]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*):\n                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,\n                query_sequence_length, key_sequence_length)` if default attention is used.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under\n                returned tensors for more detail.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n        \"\"\"\n        if \"padding_mask\" in kwargs:\n            warnings.warn(\n                \"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead.\"\n            )\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_value=past_key_value,\n            output_attentions=output_attentions,\n            use_cache=use_cache,\n            cache_position=cache_position,\n            **kwargs,\n        )\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (self_attn_weights,)\n\n        if use_cache:\n            outputs += (present_key_value,)\n\n        return outputs\n\n\nDeepseekV3_START_DOCSTRING = r\"\"\"\n    This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the\n    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,\n    etc.)\n\n    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage\n    and behavior.\n\n    Parameters:\n        config ([`DeepseekV3Config`]):\n            Model configuration class with all the parameters of the model. Initializing with a config file does not\n            load the weights associated with the model, only the configuration. Check out the\n            [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.\",\n    DeepseekV3_START_DOCSTRING,\n)\nclass DeepseekV3PreTrainedModel(PreTrainedModel):\n    config_class = DeepseekV3Config\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"DeepseekV3DecoderLayer\"]\n    _skip_keys_device_placement = \"past_key_values\"\n    _supports_flash_attn_2 = True\n    _supports_cache_class = True\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n\n\nDeepseekV3_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in 
the vocabulary. Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see\n            `past_key_values`).\n\n            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]\n            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more\n            information on the default strategy.\n\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,\n            config.n_positions - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):\n            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention\n            blocks) that can be used to speed up sequential decoding. 
This typically consists in the `past_key_values`\n            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.\n\n            Two formats are allowed:\n            - a [`~cache_utils.Cache`] instance;\n            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of\n            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy\n            cache format.\n\n            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the\n            legacy cache format will be returned.\n\n            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't\n            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`\n            of shape `(batch_size, sequence_length)`.\n        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This\n            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n            model's internal embedding lookup matrix.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n            `past_key_values`).\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors for\n            more detail.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.\",\n    DeepseekV3_START_DOCSTRING,\n)\nclass DeepseekV3Model(DeepseekV3PreTrainedModel):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV3DecoderLayer`]\n\n    Args:\n        config: DeepseekV3Config\n    \"\"\"\n\n    def __init__(self, config: DeepseekV3Config):\n        super().__init__(config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(\n            config.vocab_size, config.hidden_size, self.padding_idx\n        )\n        self.layers = nn.ModuleList(\n            [\n                DeepseekV3DecoderLayer(config, layer_idx)\n                for layer_idx in range(config.num_hidden_layers)\n            ]\n        )\n        self._use_flash_attention_2 = config._attn_implementation == \"flash_attention_2\"\n        self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n        self.gradient_checkpointing = False\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: 
Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        # retrieve input_ids and inputs_embeds\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\n                \"You cannot specify both input_ids and inputs_embeds at the same time\"\n            )\n        elif input_ids is not None:\n            batch_size, seq_length = input_ids.shape[:2]\n        elif inputs_embeds is not None:\n            batch_size, seq_length = inputs_embeds.shape[:2]\n        else:\n            raise ValueError(\"You have to specify either input_ids or inputs_embeds\")\n\n        past_key_values_length = 0\n        if use_cache:\n            use_legacy_cache = not isinstance(past_key_values, Cache)\n            if use_legacy_cache:\n                past_key_values = DynamicCache.from_legacy_cache(past_key_values)\n            past_key_values_length = past_key_values.get_usable_length(seq_length)\n\n        if position_ids is None:\n            device = input_ids.device if input_ids is not None else inputs_embeds.device\n            position_ids = torch.arange(\n                past_key_values_length,\n                seq_length + past_key_values_length,\n                
dtype=torch.long,\n                device=device,\n            )\n            position_ids = position_ids.unsqueeze(0)\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        if self._use_flash_attention_2:\n            # 2d mask is passed through the layers\n            attention_mask = (\n                attention_mask\n                if (attention_mask is not None and 0 in attention_mask)\n                else None\n            )\n        else:\n            # 4d mask is passed through the layers\n            attention_mask = _prepare_4d_causal_attention_mask(\n                attention_mask,\n                (batch_size, seq_length),\n                inputs_embeds,\n                past_key_values_length,\n            )\n\n        # embed positions\n        hidden_states = inputs_embeds\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        next_decoder_cache = None\n\n        for decoder_layer in self.layers:\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            layer_outputs = decoder_layer(\n                hidden_states,\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n                past_key_value=past_key_values,\n                output_attentions=output_attentions,\n                use_cache=use_cache,\n                cache_position=cache_position,\n            )\n\n            hidden_states = layer_outputs[0]\n\n            if use_cache:\n                next_decoder_cache = layer_outputs[2 if output_attentions else 1]\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n        hidden_states = self.norm(hidden_states)\n\n        # add hidden states from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += 
(hidden_states,)\n\n        next_cache = None\n        if use_cache:\n            next_cache = (\n                next_decoder_cache.to_legacy_cache()\n                if use_legacy_cache\n                else next_decoder_cache\n            )\n        if not return_dict:\n            return tuple(\n                v\n                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]\n                if v is not None\n            )\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=next_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n        )\n\n    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask\n    def _update_causal_mask(\n        self,\n        attention_mask: torch.Tensor,\n        input_tensor: torch.Tensor,\n        cache_position: torch.Tensor,\n        past_key_values: Cache,\n        output_attentions: bool,\n    ):\n        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static\n        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.\n        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using\n        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114\n\n        if self.config._attn_implementation == \"flash_attention_2\":\n            if attention_mask is not None and 0.0 in attention_mask:\n                return attention_mask\n            return None\n\n        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in\n        # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail\n        # to infer the attention mask.\n        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n        using_static_cache = isinstance(past_key_values, StaticCache)\n\n        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward\n        if self.config._attn_implementation == \"sdpa\" and not using_static_cache and not output_attentions:\n            if AttentionMaskConverter._ignore_causal_mask_sdpa(\n                attention_mask,\n                inputs_embeds=input_tensor,\n                past_key_values_length=past_seen_tokens,\n                is_training=self.training,\n            ):\n                return None\n\n        dtype, device = input_tensor.dtype, input_tensor.device\n        min_dtype = torch.finfo(dtype).min\n        sequence_length = input_tensor.shape[1]\n        if using_static_cache:\n            target_length = past_key_values.get_max_length()\n        else:\n            target_length = (\n                attention_mask.shape[-1]\n                if isinstance(attention_mask, torch.Tensor)\n                else past_seen_tokens + sequence_length + 1\n            )\n\n        if attention_mask is not None and attention_mask.dim() == 4:\n            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing\n            if attention_mask.max() != 0:\n                raise ValueError(\"Custom 4D attention mask should be passed in inverted form with max==0`\")\n            causal_mask = attention_mask\n        else:\n            causal_mask = torch.full(\n                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device\n            )\n            if sequence_length != 1:\n                causal_mask = torch.triu(causal_mask, diagonal=1)\n            causal_mask *= torch.arange(target_length, 
device=device) > cache_position.reshape(-1, 1)\n            causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)\n            if attention_mask is not None:\n                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit\n                mask_length = attention_mask.shape[-1]\n                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]\n                padding_mask = padding_mask == 0\n                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(\n                    padding_mask, min_dtype\n                )\n        if (\n            self.config._attn_implementation == \"sdpa\"\n            and attention_mask is not None\n            and attention_mask.device.type == \"cuda\"\n            and not output_attentions\n        ):\n            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when\n            # using left padding. 
This is required by F.scaled_dot_product_attention memory-efficient attention path.\n            # Details: https://github.com/pytorch/pytorch/issues/110213\n            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)\n\n        return causal_mask\n\nclass DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):\n    _tied_weights_keys = [\"lm_head.weight\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = DeepseekV3Model(config)\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def set_decoder(self, decoder):\n        self.model = decoder\n\n    def get_decoder(self):\n        return self.model\n\n    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)\n    @replace_return_docstrings(\n        output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC\n    )\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> 
Union[Tuple, CausalLMOutputWithPast]:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM\n\n        >>> model = DeepseekV3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)\n        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)\n\n        >>> prompt = \"Hey, are you conscious? Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you conscious? 
Can you talk to me?\\nI'm not conscious, but I can talk to you.\"\n        ```\"\"\"\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            cache_position=cache_position,\n        )\n\n        hidden_states = outputs[0]\n        logits = self.lm_head(hidden_states[:,-1:,:])\n        logits = logits.float()\n\n        loss = None\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            shift_logits = shift_logits.view(-1, self.config.vocab_size)\n            shift_labels = shift_labels.view(-1)\n            # Enable model parallelism\n            shift_labels = shift_labels.to(shift_logits.device)\n            loss = loss_fct(shift_logits, shift_labels)\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            return (loss,) + output if loss is not None else output\n\n        return 
CausalLMOutputWithPast(\n            loss=loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n    def prepare_inputs_for_generation(\n        self,\n        input_ids,\n        past_key_values=None,\n        attention_mask=None,\n        inputs_embeds=None,\n        **kwargs,\n    ):\n        if past_key_values is not None:\n            if isinstance(past_key_values, Cache):\n                cache_length = past_key_values.get_seq_length()\n                past_length = past_key_values.seen_tokens\n                max_cache_length = past_key_values.get_max_length()\n            else:\n                cache_length = past_length = past_key_values[0][0].shape[2]\n                max_cache_length = None\n\n            # Keep only the unprocessed tokens:\n            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where\n            # some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as\n            # input)\n            if (\n                attention_mask is not None\n                and attention_mask.shape[1] > input_ids.shape[1]\n            ):\n                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]\n            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. 
We can discard\n            # input_ids based on the past_length.\n            elif past_length < input_ids.shape[1]:\n                input_ids = input_ids[:, past_length:]\n            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.\n\n            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.\n            if (\n                max_cache_length is not None\n                and attention_mask is not None\n                and cache_length + input_ids.shape[1] > max_cache_length\n            ):\n                attention_mask = attention_mask[:, -max_cache_length:]\n\n        position_ids = kwargs.get(\"position_ids\", None)\n        if attention_mask is not None and position_ids is None:\n            # create position_ids on the fly for batch generation\n            position_ids = attention_mask.long().cumsum(-1) - 1\n            position_ids.masked_fill_(attention_mask == 0, 1)\n            if past_key_values:\n                position_ids = position_ids[:, -input_ids.shape[1] :]\n\n        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step\n        if inputs_embeds is not None and past_key_values is None:\n            model_inputs = {\"inputs_embeds\": inputs_embeds}\n        else:\n            model_inputs = {\"input_ids\": input_ids}\n\n        model_inputs.update(\n            {\n                \"position_ids\": position_ids,\n                \"past_key_values\": past_key_values,\n                \"use_cache\": kwargs.get(\"use_cache\"),\n                \"attention_mask\": attention_mask,\n            }\n        )\n        return model_inputs\n\n    @staticmethod\n    def _reorder_cache(past_key_values, beam_idx):\n        reordered_past = ()\n        for layer_past in past_key_values:\n            reordered_past += (\n                tuple(\n                    past_state.index_select(0, 
beam_idx.to(past_state.device))\n                    for past_state in layer_past\n                ),\n            )\n        return reordered_past\n\n\n@add_start_docstrings(\n    \"\"\"\n    The DeepseekV3 Model transformer with a sequence classification head on top (linear layer).\n\n    [`DeepseekV3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models\n    (e.g. GPT-2) do.\n\n    Since it does classification on the last token, it requires to know the position of the last token. If a\n    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If\n    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the\n    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in\n    each row of the batch).\n    \"\"\",\n    DeepseekV3_START_DOCSTRING,\n)\nclass DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.model = DeepseekV3Model(config)\n        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        
use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the sequence classification/regression loss. Indices should be in `[0, transformers.,\n            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        transformer_outputs = self.model(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = transformer_outputs[0]\n        logits = self.score(hidden_states)\n\n        if input_ids is not None:\n            batch_size = input_ids.shape[0]\n        else:\n            batch_size = inputs_embeds.shape[0]\n\n        if self.config.pad_token_id is None and batch_size != 1:\n            raise ValueError(\n                \"Cannot handle batch sizes > 1 if no padding token is defined.\"\n            )\n        if self.config.pad_token_id is None:\n            sequence_lengths = -1\n        else:\n            if input_ids is not None:\n                sequence_lengths = (\n                    torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1\n                ).to(logits.device)\n            else:\n                
sequence_lengths = -1\n\n        pooled_logits = logits[\n            torch.arange(batch_size, device=logits.device), sequence_lengths\n        ]\n\n        loss = None\n        if labels is not None:\n            labels = labels.to(logits.device)\n            if self.config.problem_type is None:\n                if self.num_labels == 1:\n                    self.config.problem_type = \"regression\"\n                elif self.num_labels > 1 and (\n                    labels.dtype == torch.long or labels.dtype == torch.int\n                ):\n                    self.config.problem_type = \"single_label_classification\"\n                else:\n                    self.config.problem_type = \"multi_label_classification\"\n\n            if self.config.problem_type == \"regression\":\n                loss_fct = MSELoss()\n                if self.num_labels == 1:\n                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())\n                else:\n                    loss = loss_fct(pooled_logits, labels)\n            elif self.config.problem_type == \"single_label_classification\":\n                loss_fct = CrossEntropyLoss()\n                loss = loss_fct(\n                    pooled_logits.view(-1, self.num_labels), labels.view(-1)\n                )\n            elif self.config.problem_type == \"multi_label_classification\":\n                loss_fct = BCEWithLogitsLoss()\n                loss = loss_fct(pooled_logits, labels)\n        if not return_dict:\n            output = (pooled_logits,) + transformer_outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return SequenceClassifierOutputWithPast(\n            loss=loss,\n            logits=pooled_logits,\n            past_key_values=transformer_outputs.past_key_values,\n            hidden_states=transformer_outputs.hidden_states,\n            attentions=transformer_outputs.attentions,\n        )\n"
  },
  {
    "path": "kt-kernel/examples/repro_llamafile_re.py",
    "content": "#!/usr/bin/env python3\n\"\"\"\nMinimal LLAMAFILE repro harness to catch intermittent RuntimeError/RE.\n\nRequirements:\n- kt_kernel_ext built with LLAMAFILE (and CUDA stream integration)\n- Valid GGUF weights directory (WEIGHT_PATH)\n\nUsage:\n  WEIGHT_PATH=/path/to/gguf python examples/repro_llamafile_re.py\n\nOptional env:\n  DEVICE=cuda|cpu           # default: auto (cuda if available)\n  N_ITERS=1000              # iterations\n  BATCH=4                   # batch size\n  H=2048                    # hidden size\n  EXPERTS=128               # total experts\n  TOPK=8                    # experts per token\n  INTER=768                 # intermediate size (must be divisible by 256)\n  GPU_EXPERTS=100           # num experts on GPU side\n  TP=2                      # threadpool_count\n  CPU_THREADS=32            # cpuinfer_threads\n  MAX_DEFER=2               # max_deferred_experts_per_token\n  MODE=split|forward        # split=submit+sync, forward=wrapper.forward\n  SEED=1                    # random seed\n\nDebug tips:\n  - Set CUDA_LAUNCH_BLOCKING=1 to catch async errors deterministically.\n  - Try varying N_ITERS, BATCH, TOPK, MAX_DEFER.\n  - Capture stdout/stderr for failure iteration index.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport os\nimport sys\nimport faulthandler\nimport torch\n\nfrom kt_kernel import KTMoEWrapper\n\n\ndef getenv_int(name: str, default: int) -> int:\n    try:\n        return int(os.environ.get(name, default))\n    except Exception:\n        return default\n\n\ndef get_stream_for(device: torch.device | str):\n    device = torch.device(device)\n    if device.type == \"cuda\" and torch.cuda.is_available():\n        return torch.cuda.current_stream(device).cuda_stream\n    return 0\n\n\ndef main() -> int:\n    faulthandler.enable()\n\n    weight_path = (os.environ.get(\"WEIGHT_PATH\") or \"\").strip()\n    if not weight_path:\n        print(\"ERROR: WEIGHT_PATH env is required.\")\n        return 2\n    if not 
os.path.exists(weight_path):\n        print(f\"ERROR: WEIGHT_PATH does not exist: {weight_path}\")\n        return 2\n\n    device_str = os.environ.get(\"DEVICE\") or (\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    device = torch.device(device_str)\n\n    n_iters = getenv_int(\"N_ITERS\", 1000)\n    batch = getenv_int(\"BATCH\", 4)\n    hidden = getenv_int(\"H\", 2048)\n    experts = getenv_int(\"EXPERTS\", 128)\n    topk = getenv_int(\"TOPK\", 8)\n    inter = getenv_int(\"INTER\", 768)\n    gpu_experts = getenv_int(\"GPU_EXPERTS\", 100)\n    tp = getenv_int(\"TP\", 2)\n    cpu_threads = getenv_int(\"CPU_THREADS\", 32)\n    max_defer = getenv_int(\"MAX_DEFER\", 2)\n    seed = getenv_int(\"SEED\", 1)\n    mode = (os.environ.get(\"MODE\") or \"split\").lower()\n\n    if inter % 256 != 0:\n        print(f\"ERROR: INTER must be divisible by 256 for LLAMAFILE (got {inter}).\")\n        return 2\n\n    print(\n        f\"LLAMAFILE Repro: device={device}, iters={n_iters}, batch={batch}, H={hidden}, topk={topk}, E={experts}, inter={inter}, TP={tp}, CPU_THREADS={cpu_threads}, mode={mode}\"\n    )\n    print(f\"Weights: {weight_path}\")\n\n    torch.manual_seed(seed)\n\n    # Create wrapper and load weights once\n    wrapper = KTMoEWrapper(\n        layer_idx=0,\n        num_experts=experts,\n        num_experts_per_tok=topk,\n        hidden_size=hidden,\n        moe_intermediate_size=inter,\n        num_gpu_experts=gpu_experts,\n        cpuinfer_threads=cpu_threads,\n        threadpool_count=tp,\n        weight_path=weight_path,\n        chunked_prefill_size=512,\n        method=\"LLAMAFILE\",\n        max_deferred_experts_per_token=max_defer,\n    )\n    wrapper.load_weights()\n\n    # Optional capture of small batch sizes\n    KTMoEWrapper.set_capture_batch_sizes([1, 2, 4, 8, 16])\n\n    stream = get_stream_for(device)\n\n    # Allocate once and reuse to reduce allocator noise\n    hidden_states = torch.empty(batch, hidden, dtype=torch.bfloat16, device=device)\n 
   topk_ids = torch.empty(batch, topk, dtype=torch.long, device=device)\n    topk_weights = torch.empty(batch, topk, dtype=torch.float32, device=device)\n\n    def fill_random():\n        hidden_states.normal_(mean=0.0, std=1.0)\n        topk_ids.random_(0, experts)\n        topk_weights.uniform_()\n        topk_weights.div_(topk_weights.sum(dim=-1, keepdim=True) + 1e-6)\n\n    # Warmup\n    fill_random()\n    _ = wrapper.forward(hidden_states, topk_ids, topk_weights, stream)\n    if device.type == \"cuda\":\n        torch.cuda.synchronize(device)\n\n    # Main loop\n    for i in range(n_iters):\n        try:\n            fill_random()\n            if mode == \"forward\":\n                _ = wrapper.forward(hidden_states, topk_ids, topk_weights, stream)\n            else:\n                wrapper.submit_forward(hidden_states, topk_ids, topk_weights, stream)\n                # Optional small GPU op to put work on the same stream\n                if device.type == \"cuda\":\n                    hidden_states.add_(0)  # no-op but enqueued on current stream\n                _ = wrapper.sync_forward(hidden_states, stream)\n\n            if (i + 1) % 50 == 0:\n                print(f\"ok: iter {i + 1}/{n_iters}\")\n                if device.type == \"cuda\":\n                    torch.cuda.synchronize(device)\n\n        except Exception as e:\n            print(f\"FAIL at iter {i}: {repr(e)}\")\n            # Flush GPU work for better diagnostics\n            if device.type == \"cuda\":\n                try:\n                    torch.cuda.synchronize(device)\n                except Exception as _:\n                    pass\n            return 1\n\n    print(\"All iterations completed without error.\")\n    return 0\n\n\nif __name__ == \"__main__\":\n    sys.exit(main())\n\n"
  },
  {
    "path": "kt-kernel/examples/test-debug.py",
    "content": "import os\nimport sys\n\nsys.path.insert(0, os.path.dirname(__file__) + \"/../build\")\nimport torch\nimport ctypes\nfrom kt_kernel import kt_kernel_ext\nfrom kt_kernel_ext.moe import MOEConfig, MOE, AMXBF16_MOE, AMXInt8_MOE, AMXInt4_MOE, AMXInt4_1_MOE\n\nintermediate_size_full = 2048\nmoe_intermediate_size = 3072\nhidden_size = 7168\nexperts_num = 256\nnum_experts_per_tok = 8\ncpu_infer = kt_kernel_ext.CPUInfer(97)\n\nup = torch.empty(experts_num, intermediate_size_full, hidden_size, dtype=torch.bfloat16, device=\"cpu\")\n\ngate = torch.empty(experts_num, intermediate_size_full, hidden_size, dtype=torch.bfloat16, device=\"cpu\")\n\ndown = torch.empty(experts_num, hidden_size, intermediate_size_full, dtype=torch.bfloat16, device=\"cpu\")\n\ngate_ptr = ctypes.addressof(ctypes.cast(gate.data_ptr(), ctypes.POINTER(ctypes.c_uint64)).contents)\nup_ptr = ctypes.addressof(ctypes.cast(up.data_ptr(), ctypes.POINTER(ctypes.c_uint64)).contents)\ndown_ptr = ctypes.addressof(ctypes.cast(down.data_ptr(), ctypes.POINTER(ctypes.c_uint64)).contents)\nmoe_config = MOEConfig(\n    experts_num,\n    num_experts_per_tok,\n    hidden_size,\n    moe_intermediate_size,\n)\nmoe_config.layer_idx = 45\nmoe_config.pool = cpu_infer.backend_\nmoe_config.max_len = 1024  # TODO(zbx): multi cuda graph\nmoe_config.gate_proj = gate_ptr\nmoe_config.up_proj = up_ptr\nmoe_config.down_proj = down_ptr\nmoe_config.path = \"\"\nmoe = AMXInt4_MOE(moe_config)\n"
  },
  {
    "path": "kt-kernel/examples/test_apply_rope.py",
    "content": "import torch\n\n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\ndef apply_rotary_pos_emb(q, cos, sin, position_ids=None, unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query tensor.\n\n    Args:\n        q (`torch.Tensor`): The query tensor.\n        cos (`torch.Tensor`): The cosine part of the rotary embedding.\n        sin (`torch.Tensor`): The sine part of the rotary embedding.\n        position_ids (`torch.Tensor`):\n            Deprecated and unused.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos and sin so that they\n            can be properly broadcast to the dimensions of q. For example, note that cos and sin have the shape\n            [batch_size, seq_len, head_dim]. Then, if q has the shape [batch_size, heads, seq_len, head_dim],\n            setting unsqueeze_dim=1 makes cos and sin broadcastable to the shape of q. 
Similarly, if q has\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `torch.Tensor`: the query tensor rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos.unsqueeze(unsqueeze_dim)\n    sin = sin.unsqueeze(unsqueeze_dim)\n    b, h, s, d = q.shape\n    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    return q_embed\n\ndef my_apply(q,cos,sin):\n    \n    qa = q[:,:,range(0,64,2)]\n    qb = q[:,:,range(1,65,2)]\n    q1 = (qa * cos - qb * sin)\n    q2 = (qb*cos + qa*sin)\n    return torch.cat((q1,q2),-1)\n\n\nnum_heads = 128\nseq_len = 1024\nrope_size = 64\n\n# theta = torch.randn(, dtype=torch.float32)\n\n\n"
  },
  {
    "path": "kt-kernel/examples/test_attention.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nDescription  :\nAuthor       : Jianwei Dong\nDate         : 2024-08-28 10:32:05\nVersion      : 1.0.0\nLastEditors  : chenht2022\nLastEditTime : 2024-08-28 10:32:05\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved.\n\"\"\"\nimport os, sys\nimport time\n\nsys.path.append(os.path.dirname(__file__) + \"/../build\")\nfrom kt_kernel import kt_kernel_ext\nfrom flash_attn import flash_attn_with_kvcache\nimport torch\n\nlayer_num = 10\nkv_head_num = 8\nq_head_num = 32\nhead_dim = 128\nblock_len = 128\nanchor_num = 1\ncache_seqlen = 8192\ncache_seqlens = torch.tensor([cache_seqlen], dtype=torch.int32, device=\"cpu\")\nseqlens_zero = torch.zeros((1,), dtype=torch.int32, device=\"cpu\")\nanchor_type = kt_kernel_ext.kvcache.AnchorType.DYNAMIC\nkv_type = kt_kernel_ext.kvcache.ggml_type.FP16\nretrieval_type = kt_kernel_ext.kvcache.RetrievalType.LAYER\nlayer_step: int = 1\ntoken_step: int = 1\nlayer_offset: int = 0\nmax_thread_num: int = 2\nmax_batch_size: int = 1\nmax_block_num: int = 512\nCPUInfer = kt_kernel_ext.CPUInfer(max_thread_num)\nvalidation_iter = 100\n\nwith torch.inference_mode(mode=True):\n    config = kt_kernel_ext.kvcache.KVCacheConfig(\n        layer_num,\n        kv_head_num,\n        q_head_num,\n        head_dim,\n        block_len,\n        anchor_num,\n        anchor_type,\n        kv_type,\n        retrieval_type,\n        layer_step,\n        token_step,\n        layer_offset,\n        max_block_num,\n        max_batch_size,\n        max_thread_num,\n    )\n    local_kvcache = kt_kernel_ext.kvcache.KVCache(config)\n\n    kvcaches = []\n    block_table = torch.arange(max_block_num, dtype=torch.int32, device=\"cpu\").contiguous().view(1, -1)\n\n    for layer_idx in range(layer_num):\n        k_cache = torch.randn((1, cache_seqlen, kv_head_num, head_dim), dtype=torch.float16, device=\"cpu\").contiguous()\n        v_cache = torch.randn((1, cache_seqlen, kv_head_num, head_dim), 
dtype=torch.float16, device=\"cpu\").contiguous()\n\n        CPUInfer.submit(\n            local_kvcache.update_kvcache_fp16(\n                k_cache.data_ptr(),\n                v_cache.data_ptr(),\n                layer_idx,\n                block_table.data_ptr(),\n                1,\n                max_block_num,\n                seqlens_zero.data_ptr(),\n                cache_seqlen,\n            )\n        )\n        CPUInfer.sync()\n\n        kvcaches.append((k_cache.to(\"cuda\"), v_cache.to(\"cuda\")))\n\n    # validation\n    for i in range(validation_iter):\n\n        k_cache = kvcaches[i % layer_num][0]\n        v_cache = kvcaches[i % layer_num][1]\n        input = torch.randn((1, 1, q_head_num, head_dim), dtype=torch.float16, device=\"cpu\").contiguous()\n        output = torch.empty((1, 1, q_head_num, head_dim), dtype=torch.float16, device=\"cpu\").contiguous()\n\n        # attn_lse: (bsz, q_len, q_head_num)\n        attn_lse = torch.empty((1, 1, q_head_num), dtype=torch.float32, device=\"cpu\").contiguous()\n        input = input / 100\n\n        CPUInfer.submit(\n            local_kvcache.attn(\n                input.data_ptr(),\n                output.data_ptr(),\n                attn_lse.data_ptr(),\n                i % layer_num,\n                0,\n                1,\n                1,\n                max_block_num,\n                block_table.data_ptr(),\n                cache_seqlens.data_ptr(),\n                -1,\n                -1,\n                -1,\n            )\n        )\n        CPUInfer.sync()\n        # print(\"cpuinfer output\", output)\n\n        t_output = flash_attn_with_kvcache(\n            q=input.to(\"cuda\"),\n            k_cache=k_cache,\n            v_cache=v_cache,\n            cache_seqlens=cache_seqlens.to(\"cuda\"),\n        )\n        # print(\"torch output\", t_output)\n\n        diff = torch.mean(torch.abs(output.to(\"cuda\") - t_output)) / torch.mean(torch.abs(t_output))\n        print(\"diff = \", 
diff)\n        assert diff < 0.001\n"
  },
  {
    "path": "kt-kernel/examples/test_awq_moe_amx.py",
    "content": "import os, sys\n\nsys.path.insert(0, os.path.dirname(__file__) + \"/../build\")\n\nfrom kt_kernel import kt_kernel_ext\nimport torch\n\n# Set fixed seed for reproducible results\ntorch.manual_seed(42)\n\n# Constants for 4-bit packing\nQ_BITS = 4\nSTORAGE_BITS = 32\nPACK_NUM = STORAGE_BITS // Q_BITS  # 8\n\n\ndef pack(imatrix: torch.Tensor, direction: str = \"row\"):\n    \"\"\"\n    Packs a 4-bit integer matrix into a packed 32-bit integer matrix.\n    Packing order: 7 6 5 4 3 2 1 0 (MSB to LSB, original order)\n    Args:\n        imatrix (torch.Tensor): matrix of integers\n\n    Returns:\n        qmatrix (torch.Tensor): packed matrix of integers\n    \"\"\"\n    shifts = torch.arange(0, STORAGE_BITS, Q_BITS, device=imatrix.device)\n\n    imatrix = imatrix.to(torch.int8)\n    imatrix = torch.bitwise_and(imatrix, 0x0F)  # eventually correct overflow\n\n    if direction == \"column\":\n        imatrix = imatrix.view(-1, imatrix.shape[1] // PACK_NUM, PACK_NUM)\n        qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1)\n\n    elif direction == \"row\":\n        imatrix = imatrix.view(imatrix.shape[0] // PACK_NUM, PACK_NUM, -1)\n        qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1)\n\n    qmatrix = qmatrix.to(torch.int32)\n\n    return qmatrix\n\n\nexpert_num = 16\nhidden_size = 7168\nintermediate_size = 2048\nmax_len = 25600\nnum_experts_per_tok = 8\nqlen = 1\nlayer_num = 1\nCPUInfer = kt_kernel_ext.CPUInfer(40)\nvalidation_iter = 10\nk_group_size = 64\ndebug_print_count = 16\n\nphysical_to_logical_map = torch.tensor(data=range(expert_num), device=\"cpu\", dtype=torch.int64).contiguous()\n\n\ndef act_fn(x):\n    return x / (1.0 + torch.exp(-x))\n\n\ndef generate_original_weights():\n    \"\"\"Generate original FP16/BF16 weights for online quantization testing\"\"\"\n    # Set seed to ensure consistency between online and offline quantization\n    torch.manual_seed(42)\n\n    # Generate weights 
in the same format as test_moe_amx.py (bfloat16)\n    gate_proj_bf16 = (\n        torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.bfloat16, device=\"cuda\")\n        .to(\"cpu\")\n        .contiguous()\n    )\n    up_proj_bf16 = (\n        torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.bfloat16, device=\"cuda\")\n        .to(\"cpu\")\n        .contiguous()\n    )\n    down_proj_bf16 = (\n        torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.bfloat16, device=\"cuda\")\n        .to(\"cpu\")\n        .contiguous()\n    )\n\n    # Print first row of gate_proj for expert 0 (first debug_print_count elements)\n    print(\n        f\"[DEBUG] Online quantization gate_proj expert 0, row 0, first {debug_print_count} elements: {gate_proj_bf16[0, 0, :debug_print_count]}\"\n    )\n\n    return gate_proj_bf16, up_proj_bf16, down_proj_bf16\n\n\ndef generate_awq_quantized_weights():\n    \"\"\"Generate AWQ quantized weights (qweight, scales, qzeros) for testing\"\"\"\n    # Reset seed to ensure same weights as online quantization\n    torch.manual_seed(42)\n\n    # Generate original FP16 weights (convert from same random values as online version)\n    gate_proj_fp16 = (\n        torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.bfloat16, device=\"cuda\")\n        .to(\"cpu\")\n        .to(torch.float16)\n        .contiguous()\n    )\n    up_proj_fp16 = (\n        torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.bfloat16, device=\"cuda\")\n        .to(\"cpu\")\n        .to(torch.float16)\n        .contiguous()\n    )\n    down_proj_fp16 = (\n        torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.bfloat16, device=\"cuda\")\n        .to(\"cpu\")\n        .to(torch.float16)\n        .contiguous()\n    )\n\n    # Print first row of gate_proj for expert 0 (first debug_print_count elements)\n    print(\n        f\"[DEBUG] Offline AWQ gate_proj expert 0, row 
0, first {debug_print_count} elements: {gate_proj_fp16[0, 0, :debug_print_count]}\"\n    )\n\n    # Calculate quantization parameters per group\n    def quantize_tensor_awq(weight, group_size=128):\n        \"\"\"Simple AWQ-style quantization simulation with interleaving\"\"\"\n\n        w_orig_shape = weight.shape\n        expert_num, col, row = weight.shape\n        group_num = (row + group_size - 1) // group_size\n\n        # 1. reshape into groups along row dimension\n        weight_grouped = weight.view(expert_num, col, group_num, group_size)  # [E, col, group_num, group_size]\n\n        # 2. calculate asymmetric scales and zero points per group ((max - min) / 15 for 4-bit unsigned)\n        max_val = torch.max(weight_grouped, dim=3).values\n        min_val = torch.min(weight_grouped, dim=3).values\n        scales = (max_val - min_val).clamp(min=1e-5) / 15.0  # [E, col, group_num]\n        zeros = (-torch.round(min_val / scales)).clamp_(0, 15).to(torch.int8)\n\n        # 3. quantize weights\n        qweight_int = torch.clamp(\n            torch.round((weight_grouped - min_val.unsqueeze(-1)) / scales.unsqueeze(-1)), 0, 15\n        ).to(torch.int8)\n\n        qweight_int = qweight_int.view(w_orig_shape)\n\n        # 4. pack qweight along row (group_size) using helper\n        qweight_packed_list = []\n        for e in range(expert_num):\n            packed = pack(qweight_int[e], direction=\"column\")  # [col, row // PACK_NUM]\n            qweight_packed_list.append(packed)\n        qweight_packed = torch.stack(qweight_packed_list, dim=0)  # [E, col, row // PACK_NUM]\n\n        # 5. 
pack zeros along the col dimension (after transposing to [group_num, col]) using helper\n        zeros_packed_list = []\n        for e in range(expert_num):\n            zeros_packed_list.append(pack(zeros[e].transpose(0, 1), direction=\"column\"))  # [group_num, col // PACK_NUM]\n        qzeros_packed = torch.stack(zeros_packed_list, dim=0)\n\n        scales = scales.transpose(1, 2).to(torch.float16)\n        print(scales.shape)\n        scales = scales.flatten().contiguous()\n\n        min_val = min_val.transpose(1, 2).to(torch.float16).flatten().contiguous()\n\n        zeros = zeros.transpose(1, 2).flatten().contiguous()\n\n        qzeros_packed = qzeros_packed.flatten().contiguous()\n\n        qweight_packed = qweight_packed.flatten().contiguous()\n\n        return {\n            \"qweight\": qweight_packed,  # Same for both torch and AWQ-MoE\n            \"scales\": scales,  # Same for both torch and AWQ-MoE\n            \"qzeros\": qzeros_packed,  # Same for both torch and AWQ-MoE\n            \"mins\": min_val,  # per-group minimum (approximately -scales * zeros), for comparison\n        }\n\n    # Quantize each projection\n    gate_data = quantize_tensor_awq(gate_proj_fp16, k_group_size)\n    up_data = quantize_tensor_awq(up_proj_fp16, k_group_size)\n    down_data = quantize_tensor_awq(down_proj_fp16, k_group_size)\n\n    return {\n        # Data for both torch and AWQ-MoE (no interleaving)\n        \"gate_qweight\": gate_data[\"qweight\"],\n        \"gate_scales\": gate_data[\"scales\"],\n        \"gate_qzeros\": gate_data[\"qzeros\"],\n        \"gate_mins\": gate_data[\"mins\"],\n        \"up_qweight\": up_data[\"qweight\"],\n        \"up_scales\": up_data[\"scales\"],\n        \"up_qzeros\": up_data[\"qzeros\"],\n        \"up_mins\": up_data[\"mins\"],\n        \"down_qweight\": down_data[\"qweight\"],\n        \"down_scales\": down_data[\"scales\"],\n        \"down_qzeros\": down_data[\"qzeros\"],\n        \"down_mins\": down_data[\"mins\"],\n        \"original_fp16\": {\"gate_proj\": gate_proj_fp16, \"up_proj\": up_proj_fp16, 
\"down_proj\": down_proj_fp16},\n    }\n\n\ndef mlp_torch(input, gate_proj, up_proj, down_proj, debug_expert_id=None, debug_print=False):\n    gate_buf = torch.mm(input, gate_proj.t())\n    up_buf = torch.mm(input, up_proj.t())\n\n    if debug_print and debug_expert_id is not None:\n        print(f\"[TORCH FP16 DEBUG] Expert {debug_expert_id}:\")\n        print(f\"  gate_buf[:{debug_print_count}] = {gate_buf.flatten()[:debug_print_count]}\")\n        print(f\"  up_buf[:{debug_print_count}] = {up_buf.flatten()[:debug_print_count]}\")\n\n    intermediate = act_fn(gate_buf) * up_buf\n\n    if debug_print and debug_expert_id is not None:\n        print(f\"  intermediate[:{debug_print_count}] = {intermediate.flatten()[:debug_print_count]}\")\n\n    ret = torch.mm(intermediate, down_proj.t())\n\n    if debug_print and debug_expert_id is not None:\n        print(f\"  down_output[:{debug_print_count}] = {ret.flatten()[:debug_print_count]}\")\n\n    return ret\n\n\ndef moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj, debug_print=False):\n    cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))\n    cnts.scatter_(1, expert_ids, 1)\n    tokens_per_expert = cnts.sum(dim=0)\n    idxs = expert_ids.view(-1).argsort()\n    sorted_tokens = input[idxs // expert_ids.shape[1]]\n\n    # Get the first expert from expert_ids array to match AWQ-MoE behavior\n    target_debug_expert = expert_ids[0, 0].item()  # First expert in expert_ids array\n\n    outputs = []\n    start_idx = 0\n    activated_experts = []\n\n    for i, num_tokens in enumerate(tokens_per_expert):\n        end_idx = start_idx + num_tokens\n        if num_tokens == 0:\n            continue\n        activated_experts.append(i)\n        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n        # Only debug the target expert that matches AWQ-MoE's first expert\n        should_debug = debug_print and i == target_debug_expert\n        if gate_proj[i].dtype == torch.float16:\n            
expert_out = mlp_torch(\n                tokens_for_this_expert.to(torch.float16),\n                gate_proj[i],\n                up_proj[i],\n                down_proj[i],\n                debug_expert_id=i,\n                debug_print=should_debug,\n            )\n        else:\n            expert_out = mlp_torch(\n                tokens_for_this_expert,\n                gate_proj[i],\n                up_proj[i],\n                down_proj[i],\n                debug_expert_id=i,\n                debug_print=should_debug,\n            )\n        outputs.append(expert_out)\n        start_idx = end_idx\n\n    if debug_print:\n        print(f\"[TORCH DEBUG] Processing activated experts: {activated_experts}\")\n        print(f\"[TORCH DEBUG] Target debug expert (matches AWQ): {target_debug_expert}\")\n\n    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n    new_x = torch.empty_like(outs)\n    new_x[idxs] = outs\n    t_output = (\n        new_x.view(*expert_ids.shape, -1)\n        .type(weights.dtype)\n        .mul_(weights.unsqueeze(dim=-1))\n        .sum(dim=1)\n        .type(new_x.dtype)\n    )\n\n    if debug_print:\n        print(f\"[TORCH DEBUG] Final MoE output[:{debug_print_count}] = {t_output.flatten()[:debug_print_count]}\")\n\n    return t_output\n\n\ndef test_online_int4_kgroup_moe():\n    \"\"\"Test online Int4LowKGroup quantization (reference implementation)\"\"\"\n    print(\"Testing Online Int4LowKGroup quantization (reference)...\")\n\n    # Generate original weights for online quantization\n    gate_proj, up_proj, down_proj = generate_original_weights()\n\n    with torch.inference_mode(mode=True):\n        moes = []\n        gate_projs = []\n        up_projs = []\n        down_projs = []\n\n        for _ in range(layer_num):\n            # Create Int4LowKGroup configuration (online quantization)\n            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)\n 
           config.max_len = max_len\n            config.gate_proj = gate_proj.data_ptr()\n            config.up_proj = up_proj.data_ptr()\n            config.down_proj = down_proj.data_ptr()\n            config.gate_scale = 0\n            config.pool = CPUInfer.backend_\n\n            # Set quantization config for Int4LowKGroup (matches test_moe_amx.py)\n            config.quant_config.bits = 4\n            config.quant_config.group_size = k_group_size\n            config.quant_config.zero_point = True\n\n            # Enable weight dumping for comparison\n            config.save = True\n            config.path = \"./awq_dump_online\"\n\n            # Create Int4LowKGroup MoE (online quantization during load_weights)\n            moe = kt_kernel_ext.moe.AMXInt4_1KGroup_MOE(config)\n\n            # Load weights (performs online quantization)\n            print(f\"Physical Map: {physical_to_logical_map.data_ptr()}\")\n            CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n            CPUInfer.sync()\n\n            # Warm up\n            CPUInfer.submit(moe.warm_up_task())\n            CPUInfer.sync()\n\n            gate_projs.append(gate_proj)\n            up_projs.append(up_proj)\n            down_projs.append(down_proj)\n            moes.append(moe)\n\n        print(\"Online Int4LowKGroup MoE created and loaded successfully!\")\n\n        # Run validation tests\n        results_online = []\n        for i in range(validation_iter):\n            # Reset seed for reproducible expert_ids and weights\n            torch.manual_seed(100 + i)  # Different seed to avoid same random values\n\n            bsz_tensor = torch.tensor([qlen], device=\"cpu\")\n            expert_ids = torch.stack(\n                [torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]\n            ).contiguous()\n            weights = torch.rand((qlen, num_experts_per_tok), dtype=torch.float32).contiguous()\n            output = torch.empty((qlen, 
hidden_size), dtype=torch.bfloat16).contiguous()\n            # input = torch.tensor(\n            #     data=torch.cat([torch.ones(qlen, 1), torch.zeros(qlen, hidden_size - 1)], dim=1),\n            #     dtype=torch.bfloat16\n            # )\n            input = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous() / 100\n\n            moe = moes[i % layer_num]\n\n            # Enable debug for first few iterations\n            enable_debug = i < 2\n            if enable_debug:\n                print(f\"\\n=== Online Int4LowKGroup Test Iteration {i} ===\")\n                print(f\"input[:{debug_print_count}] = {input.flatten()[:debug_print_count]}\")\n                print(f\"expert_ids = {expert_ids}\")\n                print(f\"weights = {weights}\")\n\n            # Run online quantized MoE forward\n            CPUInfer.submit(\n                moe.forward_task(\n                    bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids.data_ptr(),\n                    weights.data_ptr(),\n                    input.data_ptr(),\n                    output.data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n\n            if enable_debug:\n                print(f\"[ONLINE DEBUG] AMX output[:{debug_print_count}] = {output.flatten()[:debug_print_count]}\")\n\n            # Compare with FP16 reference\n            gate_proj_ref = gate_projs[i % layer_num]\n            up_proj_ref = up_projs[i % layer_num]\n            down_proj_ref = down_projs[i % layer_num]\n\n            t_output_online = moe_torch(\n                input, expert_ids, weights, gate_proj_ref, up_proj_ref, down_proj_ref, debug_print=enable_debug\n            )\n\n            # Calculate differences\n            diff_online = torch.mean(torch.abs(output - t_output_online)) / torch.mean(torch.abs(t_output_online))\n            results_online.append(output.clone())\n\n            
print(f\"Online Iteration {i}: Int4LowKGroup vs FP16 = {diff_online:.6f}\")\n\n            if enable_debug:\n                abs_diff_online = torch.abs(output - t_output_online)\n                print(f\"[COMPARE] Online Int4LowKGroup vs FP16:\")\n                print(f\"  Max abs diff = {torch.max(abs_diff_online):.6f}\")\n                print(f\"  Mean abs diff = {torch.mean(abs_diff_online):.6f}\")\n                print(f\"  Relative diff = {diff_online:.6f}\")\n                print(\"=\" * 70)\n\n        print(\"\\n✅ Online Int4LowKGroup tests passed!\")\n        return results_online\n\n\ndef test_awq_moe():\n    print(\"Testing AWQ MoE with Int4_1LowKGroup quantization...\")\n\n    # Generate AWQ quantized weights\n    awq_data = generate_awq_quantized_weights()\n\n    with torch.inference_mode(mode=True):\n        moes = []\n\n        for _ in range(layer_num):\n            # Create AWQ MoE configuration\n            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)\n            config.max_len = max_len\n\n            # Set quantization config for Int4_1LowKGroup\n            config.quant_config.bits = 4\n            config.quant_config.group_size = k_group_size\n            config.quant_config.zero_point = True\n\n            # Enable weight dumping for comparison\n            config.save = True\n            config.path = \"./awq_dump_offline\"\n\n            # Set pointers to AWQ quantized data (no interleaving)\n            config.gate_proj = awq_data[\"gate_qweight\"].data_ptr()\n            config.up_proj = awq_data[\"up_qweight\"].data_ptr()\n            config.down_proj = awq_data[\"down_qweight\"].data_ptr()\n\n            config.gate_scale = awq_data[\"gate_scales\"].data_ptr()\n            config.up_scale = awq_data[\"up_scales\"].data_ptr()\n            config.down_scale = awq_data[\"down_scales\"].data_ptr()\n\n            config.gate_zeros = awq_data[\"gate_qzeros\"].data_ptr()\n        
    config.up_zeros = awq_data[\"up_qzeros\"].data_ptr()\n            config.down_zeros = awq_data[\"down_qzeros\"].data_ptr()\n\n            config.pool = CPUInfer.backend_\n\n            # Create Int4_1LowKGroup MoE\n            moe = kt_kernel_ext.moe.AMXInt4_1KGroup_MOE(config)\n\n            # Load weights\n            CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n            CPUInfer.sync()\n\n            # Warm up\n            CPUInfer.submit(moe.warm_up_task())\n            CPUInfer.sync()\n\n            moes.append(moe)\n\n        print(\"AWQ MoE Int4_1LowKGroup created and loaded successfully!\")\n\n        # Run validation tests\n        results_awq = []\n        for i in range(validation_iter):\n            # Reset seed for reproducible expert_ids and weights (same as online test)\n            torch.manual_seed(100 + i)\n\n            bsz_tensor = torch.tensor([qlen], device=\"cpu\")\n            expert_ids = torch.stack(\n                [torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]\n            ).contiguous()\n            weights = torch.rand((qlen, num_experts_per_tok), dtype=torch.float32).contiguous()\n            output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()\n            # input = torch.tensor(\n            #     data=torch.cat([torch.ones(qlen, 1), torch.zeros(qlen, hidden_size - 1)], dim=1),\n            #     dtype=torch.bfloat16\n            # )\n            input = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous() / 100\n\n            moe = moes[i % layer_num]\n\n            # Enable debug for first few iterations\n            enable_debug = i < 2\n            if enable_debug:\n                print(f\"\\n=== AWQ MoE Int4_1LowKGroup Test Iteration {i} ===\")\n                print(f\"input[:{debug_print_count}] = {input.flatten()[:debug_print_count]}\")\n                print(f\"expert_ids = {expert_ids}\")\n                
print(f\"weights = {weights}\")\n\n                # Print which experts will be activated\n                activated_experts = []\n                for token in range(expert_ids.shape[0]):\n                    for expert_idx in range(expert_ids.shape[1]):\n                        expert_id = expert_ids[token][expert_idx].item()\n                        if expert_id not in activated_experts:\n                            activated_experts.append(expert_id)\n                print(f\"[TORCH DEBUG] Activated experts: {sorted(activated_experts)}\")\n                print(f\"[TORCH DEBUG] First expert from expert_ids array: {expert_ids[0, 0].item()}\")\n\n            # Run AWQ MoE forward\n            CPUInfer.submit(\n                moe.forward_task(\n                    bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids.data_ptr(),\n                    weights.data_ptr(),\n                    input.data_ptr(),\n                    output.data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n\n            if enable_debug:\n                print(f\"[AWQ-MoE DEBUG] AMX output[:{debug_print_count}] = {output.flatten()[:debug_print_count]}\")\n\n            # Compare with FP16 reference\n            original_weights = awq_data[\"original_fp16\"]\n            gate_proj = original_weights[\"gate_proj\"].to(torch.float16)\n            up_proj = original_weights[\"up_proj\"].to(torch.float16)\n            down_proj = original_weights[\"down_proj\"].to(torch.float16)\n\n            t_output_fp16 = moe_torch(\n                input, expert_ids, weights, gate_proj, up_proj, down_proj, debug_print=enable_debug\n            )\n\n            # Calculate differences\n            diff_fp16 = torch.mean(torch.abs(output - t_output_fp16)) / torch.mean(torch.abs(t_output_fp16))\n            results_awq.append(output.clone())\n\n            print(f\"AWQ Iteration {i}: AWQ-MoE vs FP16 = 
{diff_fp16:.6f}\")\n\n            if enable_debug:\n                abs_diff_fp16 = torch.abs(output - t_output_fp16)\n                print(f\"[COMPARE] AWQ-MoE vs FP16:\")\n                print(f\"  Max abs diff = {torch.max(abs_diff_fp16):.6f}\")\n                print(f\"  Mean abs diff = {torch.mean(abs_diff_fp16):.6f}\")\n                print(f\"  Relative diff = {diff_fp16:.6f}\")\n                print(\"=\" * 70)\n\n            # AWQ quantization typically has higher error tolerance due to 4-bit quantization vs FP16\n            # assert(diff_fp16 < 0.5), f\"AWQ-MoE vs FP16 error too large: {diff_fp16:.6f}\"\n\n        print(\"\\n✅ All AWQ MoE tests passed!\")\n        return results_awq\n\n\ndef compare_quantization_methods():\n    \"\"\"Compare online and offline quantization methods\"\"\"\n    print(\"=\" * 70)\n    print(\"Comparing Online vs Offline Quantization Methods\")\n    print(\"=\" * 70)\n\n    # Run online quantization test (reference)\n    print(\"\\n\" + \"=\" * 70)\n    print(\"PHASE 1: Online Int4LowKGroup Quantization (Reference)\")\n    print(\"=\" * 70)\n    results_online = test_online_int4_kgroup_moe()\n\n    # Run offline AWQ quantization test\n    print(\"\\n\" + \"=\" * 70)\n    print(\"PHASE 2: Offline AWQ Int4_1LowKGroup Quantization\")\n    print(\"=\" * 70)\n    results_awq = test_awq_moe()\n\n    # Compare the results\n    print(\"\\n\" + \"=\" * 70)\n    print(\"PHASE 3: Comparison Results\")\n    print(\"=\" * 70)\n\n    if len(results_online) != len(results_awq):\n        print(f\"❌ Different number of results: Online={len(results_online)}, AWQ={len(results_awq)}\")\n        return\n\n    print(\"Comparing Online Int4LowKGroup vs Offline AWQ results:\")\n    total_diff = 0.0\n    max_diff = 0.0\n\n    for i in range(len(results_online)):\n        diff = torch.mean(torch.abs(results_online[i] - results_awq[i]))\n        rel_diff = diff / torch.mean(torch.abs(results_online[i]))\n        total_diff += rel_diff\n        
max_diff = max(max_diff, diff.item())\n\n        if i < 3:  # Show detailed comparison for first 3 iterations\n            print(f\"  Iteration {i}:\")\n            print(f\"    Absolute diff: {diff:.6f}\")\n            print(f\"    Relative diff: {rel_diff:.6f}\")\n            print(f\"    Online output[:{debug_print_count//2}]:  {results_online[i].flatten()[:debug_print_count//2]}\")\n            print(f\"    AWQ output[:{debug_print_count//2}]:     {results_awq[i].flatten()[:debug_print_count//2]}\")\n        else:\n            print(f\"  Iteration {i}: Relative diff = {rel_diff:.6f}\")\n\n    avg_diff = total_diff / len(results_online)\n    print(f\"\\nOverall comparison:\")\n    print(f\"  Average relative difference: {avg_diff:.6f}\")\n    print(f\"  Maximum absolute difference: {max_diff:.6f}\")\n\n    # Determine if results match within acceptable tolerance\n    tolerance = 0.01  # 1% tolerance\n    if avg_diff < tolerance:\n        print(f\"✅ Results match within {tolerance:.1%} tolerance!\")\n        print(\"   Your offline AWQ quantization implementation appears to be correct.\")\n    else:\n        print(f\"❌ Results differ by more than {tolerance:.1%} tolerance.\")\n        print(\"   There may be differences between online and offline quantization.\")\n\n\nif __name__ == \"__main__\":\n    print(\"=\" * 70)\n    print(\"AWQ MoE AMX Test - Online vs Offline Quantization Comparison\")\n    print(\"=\" * 70)\n\n    compare_quantization_methods()\n\n    print(\"\\n\" + \"=\" * 70)\n    print(\"Test completed successfully!\")\n    print(\"=\" * 70)\n"
  },
  {
    "path": "kt-kernel/examples/test_bf16_moe.py",
    "content": "\"\"\"\nTest script for AMX_BF16_MOE_TP (native BF16 MoE) kernel validation.\n\nThis script:\n1. Generates random BF16 weights\n2. Runs the BF16 MoE kernel\n3. Compares results with PyTorch reference\n\nBF16 format notes:\n- Weight: BF16 stored as ggml_bf16_t, shape [expert_num, n, k]\n- No scales needed (native BF16 precision)\n\"\"\"\n\nimport os\nimport sys\n\nsys.path.insert(0, os.path.dirname(__file__) + \"/../build\")\n\nimport torch\nfrom kt_kernel import kt_kernel_ext\n\ntorch.manual_seed(42)\n\n# Model config\nhidden_size = 2048\nintermediate_size = 768\nmax_len = 25600\n\nexpert_num = 128\nnum_experts_per_tok = 8\n\nqlen = 1\nlayer_num = 5\nCPUInfer = kt_kernel_ext.CPUInfer(3)\nvalidation_iter = 5\ndebug_print_count = 16\n\nphysical_to_logical_map = torch.tensor(data=range(expert_num), device=\"cpu\", dtype=torch.int64).contiguous()\n\n\ndef act_fn(x):\n    \"\"\"SiLU activation function\"\"\"\n    return x / (1.0 + torch.exp(-x))\n\n\ndef mlp_torch(input, gate_proj, up_proj, down_proj):\n    \"\"\"Reference MLP computation in PyTorch\"\"\"\n    gate_buf = torch.mm(input, gate_proj.t())\n    up_buf = torch.mm(input, up_proj.t())\n    intermediate = act_fn(gate_buf) * up_buf\n    ret = torch.mm(intermediate, down_proj.t())\n    return ret\n\n\ndef moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):\n    \"\"\"Reference MoE computation in PyTorch\"\"\"\n    cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))\n    cnts.scatter_(1, expert_ids, 1)\n    tokens_per_expert = cnts.sum(dim=0)\n    idxs = expert_ids.view(-1).argsort()\n    sorted_tokens = input[idxs // expert_ids.shape[1]]\n\n    outputs = []\n    start_idx = 0\n\n    for i, num_tokens in enumerate(tokens_per_expert):\n        end_idx = start_idx + num_tokens\n        if num_tokens == 0:\n            continue\n        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n        expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], 
down_proj[i])\n        outputs.append(expert_out)\n        start_idx = end_idx\n\n    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n    new_x = torch.empty_like(outs)\n    new_x[idxs] = outs\n    t_output = (\n        new_x.view(*expert_ids.shape, -1)\n        .type(weights.dtype)\n        .mul_(weights.unsqueeze(dim=-1))\n        .sum(dim=1)\n        .type(new_x.dtype)\n    )\n    return t_output\n\n\ndef build_bf16_weights():\n    \"\"\"\n    Generate random BF16 weights.\n\n    Returns:\n        dict with BF16 weights for gate, up, down projections\n    \"\"\"\n    torch.manual_seed(42)\n\n    # Generate random BF16 weights with small values\n    gate_proj = (\n        (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / 100.0)\n        .to(torch.bfloat16)\n        .contiguous()\n    )\n    up_proj = (\n        (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / 100.0)\n        .to(torch.bfloat16)\n        .contiguous()\n    )\n    down_proj = (\n        (torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32) / 100.0)\n        .to(torch.bfloat16)\n        .contiguous()\n    )\n\n    print(f\"BF16 weights shape: gate={gate_proj.shape}, up={up_proj.shape}, down={down_proj.shape}\")\n\n    # Debug: Print BF16 weight info for expert 0\n    print(\"\\n=== DEBUG: BF16 Weight Info (Expert 0) ===\")\n    print(f\"gate_proj[0] first 8 values: {gate_proj[0, 0, :8]}\")\n    print(f\"gate_proj[0] stats: min={gate_proj[0].min()}, max={gate_proj[0].max()}\")\n    print(f\"up_proj[0] first 8 values: {up_proj[0, 0, :8]}\")\n    print(f\"down_proj[0] first 8 values: {down_proj[0, 0, :8]}\")\n\n    return {\n        \"gate_proj\": gate_proj,\n        \"up_proj\": up_proj,\n        \"down_proj\": down_proj,\n    }\n\n\ndef build_moes_from_bf16_data(bf16_data: dict):\n    \"\"\"\n    Build BF16 MoE modules from BF16 weight data.\n    \"\"\"\n    moes = []\n    
with torch.inference_mode(mode=True):\n        for _ in range(layer_num):\n            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)\n            config.max_len = max_len\n\n            # Set BF16 weight pointers (no scales needed)\n            config.gate_proj = bf16_data[\"gate_proj\"].data_ptr()\n            config.up_proj = bf16_data[\"up_proj\"].data_ptr()\n            config.down_proj = bf16_data[\"down_proj\"].data_ptr()\n\n            # No scales for BF16\n            config.gate_scale = 0\n            config.up_scale = 0\n            config.down_scale = 0\n            config.pool = CPUInfer.backend_\n\n            moe = kt_kernel_ext.moe.AMXBF16_MOE(config)\n            CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n            CPUInfer.sync()\n            moes.append(moe)\n    return moes\n\n\ndef run_bf16_moe_test():\n    \"\"\"\n    Run BF16 MoE validation test.\n    \"\"\"\n    print(\"\\n\" + \"=\" * 70)\n    print(\"BF16 MoE Kernel Validation Test\")\n    print(\"=\" * 70)\n\n    # Build BF16 weights\n    print(\"\\nGenerating BF16 weights...\")\n    bf16_data = build_bf16_weights()\n\n    # Build MoE modules\n    print(\"\\nBuilding BF16 MoE modules...\")\n    moes = build_moes_from_bf16_data(bf16_data)\n\n    # Get weights for reference computation\n    gate_proj = bf16_data[\"gate_proj\"]\n    up_proj = bf16_data[\"up_proj\"]\n    down_proj = bf16_data[\"down_proj\"]\n\n    diffs = []\n    with torch.inference_mode(mode=True):\n        for i in range(validation_iter):\n            torch.manual_seed(114514 + i)\n            bsz_tensor = torch.tensor([qlen], device=\"cpu\")\n            expert_ids = torch.stack(\n                [torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]\n            ).contiguous()\n            weights = torch.randn((qlen, num_experts_per_tok), dtype=torch.float32).contiguous() / 10\n            input_tensor = 
torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous() * 3\n            output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()\n\n            moe = moes[i % layer_num]\n            CPUInfer.submit(\n                moe.forward_task(\n                    bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids.data_ptr(),\n                    weights.data_ptr(),\n                    input_tensor.data_ptr(),\n                    output.data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n\n            assert not torch.isnan(output).any(), \"NaN values detected in CPU expert output.\"\n            assert not torch.isinf(output).any(), \"Inf values detected in CPU expert output.\"\n\n            # Reference computation using BF16 weights\n            t_output = moe_torch(input_tensor, expert_ids, weights, gate_proj, up_proj, down_proj)\n\n            t_output_flat = t_output.flatten()\n            output_flat = output.flatten()\n\n            diff = torch.mean(torch.abs(output_flat - t_output_flat)) / (torch.mean(torch.abs(t_output_flat)) + 1e-12)\n            diffs.append(diff.item())\n            print(f\"Iteration {i}: relative L1 diff = {diff:.6f}\")\n\n            if i < 3:  # Print detailed output for first few iterations\n                print(f\"  kernel output: {output_flat[:debug_print_count]}\")\n                print(f\"  torch output:  {t_output_flat[:debug_print_count]}\")\n\n    mean_diff = float(sum(diffs) / len(diffs))\n    max_diff = float(max(diffs))\n    min_diff = float(min(diffs))\n\n    print(\"\\n\" + \"=\" * 70)\n    print(\"BF16 MoE Test Results\")\n    print(\"=\" * 70)\n    print(f\"Mean relative L1 diff: {mean_diff*100:.4f}%\")\n    print(f\"Max relative L1 diff:  {max_diff*100:.4f}%\")\n    print(f\"Min relative L1 diff:  {min_diff*100:.4f}%\")\n\n    # Pass/Fail criteria (BF16 should be very accurate, <5% error)\n  
  threshold = 5.0\n    if mean_diff * 100 < threshold:\n        print(f\"\\nPASS: Mean error {mean_diff*100:.4f}% < {threshold}% threshold\")\n    else:\n        print(f\"\\nFAIL: Mean error {mean_diff*100:.4f}% >= {threshold}% threshold\")\n\n    return {\"mean\": mean_diff, \"max\": max_diff, \"min\": min_diff}\n\n\nif __name__ == \"__main__\":\n    run_bf16_moe_test()\n"
  },
  {
    "path": "kt-kernel/examples/test_deepseekv3.py",
    "content": "import os, sys\nimport time\n\nos.environ[\"BLAS_NUM_THREADS\"] = \"1\"\nsys.path.insert(0, os.path.dirname(__file__) + \"/../build\")\nfrom kt_kernel import kt_kernel_ext\nfrom kt_kernel_ext.kvcache import ggml_type\nimport torch\nimport logging\nimport sys\nimport json\nfrom pathlib import Path\nfrom transformers import (\n    AutoTokenizer,\n    AutoConfig,\n    AutoModelForCausalLM,\n    GenerationConfig,\n    TextStreamer,\n)\n\nlogger = logging.getLogger(\"reader\")\n\nfrom gguf.gguf_reader import GGUFReader\n\n# load_layers = 6\nload_layers = None\nCPUInfer = kt_kernel_ext.CPUInfer(304)\nmax_qlen = 4096\nmax_kvlen = 4096\npage_size = 256\npages_count = 200\n\n\ndef read_gguf_file(gguf_file_path):\n    \"\"\"\n    Reads and prints key-value pairs and tensor information from a GGUF file in an improved format.\n\n    Parameters:\n    - gguf_file_path: Path to the GGUF file.\n    \"\"\"\n\n    reader = GGUFReader(gguf_file_path)\n\n    # List all key-value pairs in a columnized format\n    # print(\"Key-Value Pairs:\") # noqa: NP100\n    # max_key_length = max(len(key) for key in reader.fields.keys())\n    for key, field in reader.fields.items():\n        value = field.parts[field.data[0]]\n        # print(f\"{key:{max_key_length}} : {value}\") # noqa: NP100\n    # print(\"----\") # noqa: NP100\n\n    # List all tensors\n    # print(\"Tensors:\") # noqa: NP100\n    # tensor_info_format = \"{:<30} | Shape: {:<15} | Size: {:<12} | Quantization: {}\"\n    # print(tensor_info_format.format(\"Tensor Name\", \"Shape\", \"Size\", \"Quantization\")) # noqa: NP100\n    # print(\"-\" * 80) # noqa: NP100\n    re = []\n    for tensor in reader.tensors:\n        shape_str = \"x\".join(map(str, tensor.shape))\n        size_str = str(tensor.n_elements)\n        quantization_str = tensor.tensor_type.name\n        # print(tensor_info_format.format(tensor.name, shape_str, size_str, quantization_str)) # noqa: NP100\n        re.append(tensor)\n    return re\n\n\ndef 
read_gguf_directory(directory):\n    \"\"\"\n    Reads all GGUF files in a directory and returns their tensors as a dict keyed by tensor name.\n\n    Parameters:\n    - directory: Path to the directory containing GGUF files.\n    \"\"\"\n    if not os.path.isdir(directory):\n        logger.error(f\"Directory {directory} does not exist.\")\n        return\n\n    # List all GGUF files in the directory\n    files = [f for f in os.listdir(directory) if f.endswith(\".gguf\")]\n    if not files:\n        logger.info(f\"No GGUF files found in {directory}.\")\n        return\n\n    re = []\n    for file in files:\n        file_path = os.path.join(directory, file)\n        # print(f\"Reading {file_path}:\") # noqa: NP100\n        # print(\"\\n\") # noqa: NP100\n        re.extend(read_gguf_file(file_path))\n    re = {r.name: r for r in re}\n    return re\n\n\ndef find_weights(name, weights):\n    \"\"\"\n    Finds and returns the weights for a given name from the list of weights.\n\n    Parameters:\n    - name: The name of the weights to find.\n    - weights: List of weight tensors.\n\n    Returns:\n    - The weight tensor if found.\n\n    Raises:\n    - ValueError: If no weight with the given name is in the list.\n    \"\"\"\n    for weight in weights:\n        if weight.name == name:\n            return weight\n    raise ValueError(f\"Weight with name {name} not found in the provided weights list.\")\n\n\ndef get_torch_tensor_from_gguf(gguf_weights, name):\n    return torch.from_numpy(gguf_weights[name].data).contiguous()\n\n\ndef get_torch_tensor_and_type_from_gguf(gguf_weights, name):\n    return torch.from_numpy(gguf_weights[name].data).contiguous(), gguf_weights[name].tensor_type.name\n\n\ndef type_to_ggml_type(type):\n    if type == \"F32\":\n        return ggml_type.FP32\n    elif type == \"F16\":\n        return ggml_type.FP16\n    elif type == \"BF16\":\n        return ggml_type.BF16\n    else:\n        raise ValueError(f\"Unsupported data type: {type}\")\n\n\ndef build_mla(layer_idx, json_config, gguf_weights):\n    hidden_size = json_config[\"hidden_size\"]\n 
   num_heads = json_config[\"num_attention_heads\"]\n    q_lora_rank = json_config[\"q_lora_rank\"]\n    kv_lora_rank = json_config[\"kv_lora_rank\"]\n    nope_size = json_config[\"qk_nope_head_dim\"]\n    rope_size = json_config[\"qk_rope_head_dim\"]\n    max_position_embeddings = json_config[\"max_position_embeddings\"]\n    rope_theta = json_config[\"rope_theta\"]\n    rope_scaling = json_config[\"rope_scaling\"]\n\n    config = kt_kernel_ext.mla.MLAConfig(\n        hidden_size,\n        q_lora_rank,\n        kv_lora_rank,\n        num_heads,\n        nope_size,\n        rope_size,\n    )\n    config.max_qlen = max_qlen\n    config.max_kvlen = max_kvlen\n    config.max_position_embeddings = max_position_embeddings\n    config.rope_scaling_factor = rope_scaling[\"factor\"]\n    config.rope_theta = rope_theta\n    config.rope_scaling_beta_fast = rope_scaling[\"beta_fast\"]\n    config.rope_scaling_beta_slow = rope_scaling[\"beta_slow\"]\n    config.rope_scaling_mscale = rope_scaling[\"mscale\"]\n    config.rope_scaling_mscale_all_dim = rope_scaling[\"mscale_all_dim\"]\n    config.rope_scaling_original_max_position_embeddings = rope_scaling[\"original_max_position_embeddings\"]\n\n    q_a_proj_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.attn_q_a.weight\")\n    config.q_a_proj = q_a_proj_weight.data_ptr()\n    config.q_a_proj_type = type_to_ggml_type(type)\n    q_a_type = type\n\n    q_a_norm_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.attn_q_a_norm.weight\")\n    config.q_a_norm = q_a_norm_weight.data_ptr()\n    config.q_a_norm_type = type_to_ggml_type(type)\n\n    q_b_proj_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.attn_q_b.weight\")\n    config.q_b_proj = q_b_proj_weight.data_ptr()\n    config.q_b_proj_type = type_to_ggml_type(type)\n\n    kv_a_proj_with_mqa_weight, type = get_torch_tensor_and_type_from_gguf(\n        gguf_weights, 
f\"blk.{layer_idx}.attn_kv_a_mqa.weight\"\n    )\n    config.kv_a_proj_with_mqa = kv_a_proj_with_mqa_weight.data_ptr()\n    config.kv_a_proj_with_mqa_type = type_to_ggml_type(type)\n\n    kv_a_norm_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.attn_kv_a_norm.weight\")\n    config.kv_a_norm = kv_a_norm_weight.data_ptr()\n    config.kv_a_norm_type = type_to_ggml_type(type)\n\n    kv_b_proj_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.attn_kv_b.weight\")\n    config.kv_b_proj = kv_b_proj_weight.data_ptr()\n    config.kv_b_proj_type = type_to_ggml_type(type)\n\n    o_proj_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.attn_output.weight\")\n    config.o_proj = o_proj_weight.data_ptr()\n    config.w_o_type = type_to_ggml_type(type)\n\n    config.layer_idx = layer_idx\n    config.pool = CPUInfer.backend_\n    config.page_count = pages_count\n\n    if q_a_type == \"F32\":\n        mla = kt_kernel_ext.mla.MLA_F32(config)\n    elif q_a_type == \"F16\":\n        mla = kt_kernel_ext.mla.MLA_F16(config)\n    elif q_a_type == \"BF16\":\n        # mla = kt_kernel_ext.mla.MLA_F32(config)\n        mla = kt_kernel_ext.mla.MLA_QUAN_F32(config)\n    else:\n        raise ValueError(f\"Unsupported data type: {q_a_type}\")\n\n    mla.load_weights()\n    mla.set_local_pages(pages_count)\n    return mla\n\n\ndef build_ffn(layer_idx, json_config, gguf_weights):\n    if f\"blk.{layer_idx}.ffn_gate.weight\" in gguf_weights:  # dense\n        config = kt_kernel_ext.moe.MOEConfig(\n            json_config[\"num_experts_per_tok\"] + json_config[\"n_shared_experts\"],\n            json_config[\"num_experts_per_tok\"] + json_config[\"n_shared_experts\"],\n            json_config[\"hidden_size\"],\n            json_config[\"moe_intermediate_size\"],\n        )\n        config.layer_idx = layer_idx\n        config.max_len = max_qlen\n        config.pool = CPUInfer.backend_\n        
gate, gate_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.ffn_gate.weight\")\n        up, up_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.ffn_up.weight\")\n        down, down_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.ffn_down.weight\")\n\n        config.gate_proj = gate.data_ptr()\n        config.gate_type = type_to_ggml_type(gate_type)\n        config.up_proj = up.data_ptr()\n        config.up_type = type_to_ggml_type(up_type)\n        config.down_proj = down.data_ptr()\n        config.down_type = type_to_ggml_type(down_type)\n\n        moe = kt_kernel_ext.moe.KMLInt8_MOE(config)\n        moe.load_weights()\n        return moe\n\n    elif f\"blk.{layer_idx}.ffn_gate_exps.weight\" in gguf_weights:\n        config = kt_kernel_ext.moe.MOEConfig(\n            json_config[\"n_routed_experts\"] + json_config[\"n_shared_experts\"],\n            json_config[\"num_experts_per_tok\"] + json_config[\"n_shared_experts\"],\n            json_config[\"hidden_size\"],\n            json_config[\"moe_intermediate_size\"],\n        )\n        config.layer_idx = layer_idx\n        config.max_len = max_qlen\n        config.pool = CPUInfer.backend_\n        gate, gate_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.ffn_gate_exps.weight\")\n        up, up_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.ffn_up_exps.weight\")\n        down, down_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.ffn_down_exps.weight\")\n\n        gate_sh, gate_sh_type = get_torch_tensor_and_type_from_gguf(\n            gguf_weights, f\"blk.{layer_idx}.ffn_gate_shexp.weight\"\n        )\n        up_sh, up_sh_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.ffn_up_shexp.weight\")\n        down_sh, down_sh_type = get_torch_tensor_and_type_from_gguf(\n            gguf_weights, 
f\"blk.{layer_idx}.ffn_down_shexp.weight\"\n        )\n\n        gate_sh_expanded = gate_sh.unsqueeze(0)\n        gate = torch.cat([gate, gate_sh_expanded], dim=0).contiguous()\n        up_sh_expanded = up_sh.unsqueeze(0)\n        up = torch.cat([up, up_sh_expanded], dim=0).contiguous()\n        down_sh_expanded = down_sh.unsqueeze(0)\n        down = torch.cat([down, down_sh_expanded], dim=0).contiguous()\n\n        config.gate_proj = gate.data_ptr()\n        config.gate_type = type_to_ggml_type(gate_type)\n        config.up_proj = up.data_ptr()\n        config.up_type = type_to_ggml_type(up_type)\n        config.down_proj = down.data_ptr()\n        config.down_type = type_to_ggml_type(down_type)\n\n        moe = kt_kernel_ext.moe.KMLInt8_MOE(config)\n        moe.load_weights()\n        return moe\n\n    else:\n        raise ValueError(f\"Unsupported FFN type for layer {layer_idx}\")\n\n\ndef build_moegate(layer_idx, json_config, gguf_weights):\n    config = kt_kernel_ext.gate.GateConfig(\n        json_config[\"hidden_size\"],\n        json_config[\"num_experts_per_tok\"],\n        json_config[\"n_routed_experts\"],\n        json_config[\"n_group\"],\n        json_config[\"topk_group\"],\n    )\n\n    config.routed_scaling_factor = json_config[\"routed_scaling_factor\"]\n\n    config.pool = CPUInfer.backend_\n\n    weight, weight_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.ffn_gate_inp.weight\")\n    config.weight = weight.data_ptr()\n    config.weight_type = type_to_ggml_type(weight_type)\n\n    bias, bias_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.exp_probs_b.bias\")\n    config.e_score_correction_bias = bias.data_ptr()\n    config.e_score_correction_bias_type = type_to_ggml_type(bias_type)\n\n    gate = kt_kernel_ext.gate.MoEGate(config)\n\n    return gate\n\n\ndef build_llm(json_config, gguf_weights):\n\n    general_config = kt_kernel_ext.GeneralConfig()\n    general_config.vocab_size = 
json_config[\"vocab_size\"]\n    general_config.hidden_size = json_config[\"hidden_size\"]\n    general_config.num_experts_per_tok = json_config[\"num_experts_per_tok\"]\n    general_config.n_routed_experts = json_config[\"n_routed_experts\"]\n    general_config.n_shared_experts = json_config[\"n_shared_experts\"]\n    general_config.max_qlen = max_qlen\n\n    lm_heads, lm_heads_type = get_torch_tensor_and_type_from_gguf(gguf_weights, \"output.weight\")\n    general_config.lm_heads_ptr = lm_heads.data_ptr()\n    general_config.lm_heads_type = type_to_ggml_type(lm_heads_type)\n\n    output_norm, output_norm_type = get_torch_tensor_and_type_from_gguf(gguf_weights, \"output_norm.weight\")\n    general_config.norm_weights_ptr = output_norm.data_ptr()\n    general_config.norm_weights_type = type_to_ggml_type(output_norm_type)\n\n    token_embd, token_embd_type = get_torch_tensor_and_type_from_gguf(gguf_weights, \"token_embd.weight\")\n    general_config.token_embd_ptr = token_embd.data_ptr()\n    general_config.token_embd_type = type_to_ggml_type(token_embd_type)\n\n    general_config.pool = CPUInfer.backend_\n\n    llm = kt_kernel_ext.DeepseekV3ForCausalLM(general_config)\n    model = kt_kernel_ext.DeepseekV3Model(general_config)\n    llm.model = model\n\n    decoder_layers = []\n    real_load_layers = json_config[\"num_hidden_layers\"] if load_layers is None else load_layers\n\n    for i in range(real_load_layers):\n        layer = kt_kernel_ext.DeepseekV3DecoderLayer(general_config, i)\n        attn_norm, attn_norm_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{i}.attn_norm.weight\")\n        ffn_norm, ffn_norm_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{i}.ffn_norm.weight\")\n\n        layer.load_norm(\n            attn_norm.data_ptr(),\n            type_to_ggml_type(attn_norm_type),\n            ffn_norm.data_ptr(),\n            type_to_ggml_type(ffn_norm_type),\n        )\n        layer.self_attn = build_mla(i, json_config, 
gguf_weights)\n        if f\"blk.{i}.ffn_gate_inp.weight\" in gguf_weights:\n            layer.gate = build_moegate(i, json_config, gguf_weights)\n        layer.ffn = build_ffn(i, json_config, gguf_weights)\n        decoder_layers.append(layer)\n\n    model.layers = decoder_layers\n    return llm\n\n\nsafetensor_path = \"/home/bd/models/DeepSeek-R1\"\njson_path = os.path.join(safetensor_path, \"config.json\")\njson_config = json.load(open(json_path, \"r\"))\nprint(json_config)\n\ngguf_path = \"/home/bd/models/DeepSeek-R1-BF16\"\nweights = read_gguf_directory(gguf_path)\nweights = dict(sorted(weights.items()))\n\n\nfor name, t in weights.items():\n    # if not name.startswith(\"blk\"):\n    # if name.startswith(\"blk.10.\"):\n    # if \"ffn_gate.\" in name:\n    # print(f\"Found weight: {t.name}, Shape: {t.shape}, Type: {t.tensor_type.name}, Size: {t.n_elements}\")\n    print(f\"Found weight: {t.name}, Shape: {t.shape}, Type: {t.tensor_type.name}, Size: {t.n_elements}\")\n\nprint(\"Building LLM ...\")\nload_start_time = time.perf_counter()\nllm = build_llm(json_config, weights)\nload_end_time = time.perf_counter()\nprint(f\"Load time: {load_end_time - load_start_time:.4f} seconds\")\n\nprint(\"Release Weight Tensors ...\")\nweights = None\nprint(\"Loading Configs ...\")\n\n\ntokenizer = AutoTokenizer.from_pretrained(safetensor_path, trust_remote_code=True)\nconfig = AutoConfig.from_pretrained(safetensor_path, trust_remote_code=True)\n\nforce_think = False\n\n\noutput_logits = torch.zeros((max_qlen, json_config[\"vocab_size\"]), dtype=torch.float32)\n\n\ndef start_chat(content=None):\n    if content is None:\n        content = input(\"Chat: \")\n\n    messages = [{\"role\": \"user\", \"content\": content}]\n    input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors=\"pt\")\n    if force_think:\n        token_thinks = torch.tensor(\n            [tokenizer.encode(\"<think>\\\\n\", add_special_tokens=False)], 
device=input_tensor.device\n        )\n        input_tensor = torch.cat([input_tensor, token_thinks], dim=1)\n    input_tensor = input_tensor.squeeze(0)  # Remove batch dimension\n\n    print(f\"Input tensor: {input_tensor}, type {input_tensor.dtype}, shape {input_tensor.shape}\")\n    kvlen = 0\n    step = 2\n    while True or step > 0:\n        step -= 1\n        stream = TextStreamer(tokenizer)\n\n        qlen = input_tensor.shape[0]\n        qlens = [qlen - kvlen]\n        kvlens = [kvlen]\n        page_tables = [list(range(pages_count))]\n        start_time = time.perf_counter()\n        llm.forward(qlens, page_tables, kvlens, input_tensor[kvlen:].data_ptr(), output_logits.data_ptr())\n        end_time = time.perf_counter()\n        print(\n            f\"Forward time: {end_time - start_time:.4f} seconds, tps: {qlens[0] / (end_time - start_time)} tokens/sec\"\n        )\n\n        logits = output_logits[0]\n        # print(logits)\n        # sample\n        next_token = torch.argmax(logits).item()\n        # print(f\"Next token: {next_token}, {tokenizer.decode(next_token)}\")\n        kvlen = input_tensor.shape[0]\n        input_tensor = torch.cat((input_tensor, torch.tensor([next_token])), dim=-1)\n\n        if next_token == tokenizer.eos_token_id or tokenizer.decode(next_token) == \"<|im_end|>\":\n            stream.end()\n            break\n        else:\n            stream.put(torch.tensor([next_token]))\n\n\njob_id = 0\nwhile True:\n    try:\n        # ---------- Let the user decide whether to continue ----------\n        choice = input(\"\\n[Enter] start a chat | enter 1 to read a file | enter q/quit/exit to quit: \").strip().lower()\n        if choice in {\"q\", \"quit\", \"exit\"}:\n            print(\"Exit command received; exiting.\")\n            break\n        elif choice == \"1\":\n            file_path = input(\"Enter the path of the file to read: \").strip()\n            if not Path(file_path).is_file():\n                print(f\"File {file_path} does not exist; please check the path.\")\n                continue\n            with open(file_path, \"r\", encoding=\"utf-8\") as 
file:\n                content = file.read()\n            print(f\"Read content:\\n{content}\\n\")\n            start_chat(content)\n        else:\n            start_chat()\n\n    except KeyboardInterrupt:\n        # Ctrl-C at any time: abandon the current task and restart\n        print(f\"\\nCtrl-C detected; conversation #{job_id} terminated, restarting…\")\n    except Exception as e:\n        # Any other exception: print the error and restart\n        print(f\"\\nError: {e}\\nConversation #{job_id} terminated, restarting…\")\n        logger.error(f\"Error in job {job_id}: {e}\", exc_info=True)\n    finally:\n        job_id += 1  # Give the next task a new number whether or not this one was interrupted\n"
  },
  {
    "path": "kt-kernel/examples/test_deepseekv3_prefill.py",
    "content": "import os, sys\nimport time\n\nos.environ[\"BLAS_NUM_THREADS\"] = \"1\"\nsys.path.insert(0, os.path.dirname(__file__) + \"/../build\")\nfrom kt_kernel import kt_kernel_ext\nfrom kt_kernel_ext.kvcache import ggml_type\nimport torch\nimport logging\nimport sys\nimport json\nfrom pathlib import Path\nfrom transformers import (\n    AutoTokenizer,\n    AutoConfig,\n    AutoModelForCausalLM,\n    GenerationConfig,\n    TextStreamer,\n)\n\nlogger = logging.getLogger(\"reader\")\n\nfrom gguf.gguf_reader import GGUFReader\n\nCPUInfer = kt_kernel_ext.CPUInfer(304)\nmax_qlen = 4096\nmax_kvlen = 4096\npage_size = 256\npages_count = 200\n\n\ndef read_gguf_file(gguf_file_path):\n    \"\"\"\n    Reads and prints key-value pairs and tensor information from a GGUF file in an improved format.\n\n    Parameters:\n    - gguf_file_path: Path to the GGUF file.\n    \"\"\"\n\n    reader = GGUFReader(gguf_file_path)\n\n    # List all key-value pairs in a columnized format\n    # print(\"Key-Value Pairs:\") # noqa: NP100\n    # max_key_length = max(len(key) for key in reader.fields.keys())\n    for key, field in reader.fields.items():\n        value = field.parts[field.data[0]]\n        # print(f\"{key:{max_key_length}} : {value}\") # noqa: NP100\n    # print(\"----\") # noqa: NP100\n\n    # List all tensors\n    # print(\"Tensors:\") # noqa: NP100\n    # tensor_info_format = \"{:<30} | Shape: {:<15} | Size: {:<12} | Quantization: {}\"\n    # print(tensor_info_format.format(\"Tensor Name\", \"Shape\", \"Size\", \"Quantization\")) # noqa: NP100\n    # print(\"-\" * 80) # noqa: NP100\n    re = []\n    for tensor in reader.tensors:\n        shape_str = \"x\".join(map(str, tensor.shape))\n        size_str = str(tensor.n_elements)\n        quantization_str = tensor.tensor_type.name\n        # print(tensor_info_format.format(tensor.name, shape_str, size_str, quantization_str)) # noqa: NP100\n        re.append(tensor)\n    return re\n\n\ndef read_gguf_directory(directory):\n    
\"\"\"\n    Reads all GGUF files in a directory and prints their contents.\n\n    Parameters:\n    - directory: Path to the directory containing GGUF files.\n    \"\"\"\n    if not os.path.isdir(directory):\n        logger.error(f\"Directory {directory} does not exist.\")\n        return\n\n    # List all GGUF files in the directory\n    files = [f for f in os.listdir(directory) if f.endswith(\".gguf\")]\n    if not files:\n        logger.info(f\"No GGUF files found in {directory}.\")\n        return\n\n    re = []\n    for file in files:\n        file_path = os.path.join(directory, file)\n        # print(f\"Reading {file_path}:\") # noqa: NP100\n        # print(\"\\n\") # noqa: NP100\n        re.extend(read_gguf_file(file_path))\n    re = {r.name: r for r in re}\n    return re\n\n\ndef find_weights(name, weights):\n    \"\"\"\n    Finds and returns the weights for a given name from the list of weights.\n\n    Parameters:\n    - name: The name of the weights to find.\n    - weights: List of weight tensors.\n\n    Returns:\n    - The weight tensor if found, otherwise None.\n    \"\"\"\n    for weight in weights:\n        if weight.name == name:\n            return weight\n    raise ValueError(f\"Weight with name {name} not found in the provided weights list.\")\n\n\ndef get_torch_tensor_from_gguf(gguf_weights, name):\n    return torch.from_numpy(gguf_weights[name].data).contiguous()\n\n\ndef get_torch_tensor_and_type_from_gguf(gguf_weights, name):\n    return torch.from_numpy(gguf_weights[name].data).contiguous(), gguf_weights[name].tensor_type.name\n\n\ndef type_to_ggml_type(type):\n    if type == \"F32\":\n        return ggml_type.FP32\n    elif type == \"F16\":\n        return ggml_type.FP16\n    elif type == \"BF16\":\n        return ggml_type.BF16\n    else:\n        raise ValueError(f\"Unsupported data type: {type}\")\n\n\ndef build_mla(layer_idx, json_config, gguf_weights):\n    hidden_size = json_config[\"hidden_size\"]\n    num_heads = 
json_config[\"num_attention_heads\"]\n    q_lora_rank = json_config[\"q_lora_rank\"]\n    kv_lora_rank = json_config[\"kv_lora_rank\"]\n    nope_size = json_config[\"qk_nope_head_dim\"]\n    rope_size = json_config[\"qk_rope_head_dim\"]\n    max_position_embeddings = json_config[\"max_position_embeddings\"]\n    rope_theta = json_config[\"rope_theta\"]\n    rope_scaling = json_config[\"rope_scaling\"]\n\n    config = kt_kernel_ext.mla.MLAConfig(\n        hidden_size,\n        q_lora_rank,\n        kv_lora_rank,\n        num_heads,\n        nope_size,\n        rope_size,\n    )\n    config.max_qlen = max_qlen\n    config.max_kvlen = max_kvlen\n    config.max_position_embeddings = max_position_embeddings\n    config.rope_scaling_factor = rope_scaling[\"factor\"]\n    config.rope_theta = rope_theta\n    config.rope_scaling_beta_fast = rope_scaling[\"beta_fast\"]\n    config.rope_scaling_beta_slow = rope_scaling[\"beta_slow\"]\n    config.rope_scaling_mscale = rope_scaling[\"mscale\"]\n    config.rope_scaling_mscale_all_dim = rope_scaling[\"mscale_all_dim\"]\n    config.rope_scaling_original_max_position_embeddings = rope_scaling[\"original_max_position_embeddings\"]\n\n    q_a_proj_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.attn_q_a.weight\")\n    config.q_a_proj = q_a_proj_weight.data_ptr()\n    config.q_a_proj_type = type_to_ggml_type(type)\n    q_a_type = type\n\n    q_a_norm_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.attn_q_a_norm.weight\")\n    config.q_a_norm = q_a_norm_weight.data_ptr()\n    config.q_a_norm_type = type_to_ggml_type(type)\n\n    q_b_proj_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.attn_q_b.weight\")\n    config.q_b_proj = q_b_proj_weight.data_ptr()\n    config.q_b_proj_type = type_to_ggml_type(type)\n\n    kv_a_proj_with_mqa_weight, type = get_torch_tensor_and_type_from_gguf(\n        gguf_weights, 
f\"blk.{layer_idx}.attn_kv_a_mqa.weight\"\n    )\n    config.kv_a_proj_with_mqa = kv_a_proj_with_mqa_weight.data_ptr()\n    config.kv_a_proj_with_mqa_type = type_to_ggml_type(type)\n\n    kv_a_norm_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.attn_kv_a_norm.weight\")\n    config.kv_a_norm = kv_a_norm_weight.data_ptr()\n    config.kv_a_norm_type = type_to_ggml_type(type)\n\n    kv_b_proj_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.attn_kv_b.weight\")\n    config.kv_b_proj = kv_b_proj_weight.data_ptr()\n    config.kv_b_proj_type = type_to_ggml_type(type)\n\n    o_proj_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.attn_output.weight\")\n    config.o_proj = o_proj_weight.data_ptr()\n    config.w_o_type = type_to_ggml_type(type)\n\n    config.layer_idx = layer_idx\n    config.pool = CPUInfer.backend_\n    config.page_count = pages_count\n\n    if q_a_type == \"F32\":\n        mla = kt_kernel_ext.mla.MLA_F32(config)\n    elif q_a_type == \"F16\":\n        mla = kt_kernel_ext.mla.MLA_F16(config)\n    elif q_a_type == \"BF16\":\n        mla = kt_kernel_ext.mla.MLA_QUAN_F32(config)\n        # mla = kt_kernel_ext.mla.MLA_F32(config)\n    else:\n        raise ValueError(f\"Unsupported data type: {q_a_type}\")\n\n    mla.load_weights()\n    mla.set_local_pages(pages_count)\n    return mla\n\n\ndef build_ffn(layer_idx, json_config, gguf_weights):\n    if f\"blk.{layer_idx}.ffn_gate.weight\" in gguf_weights:  # dense\n        config = kt_kernel_ext.moe.MOEConfig(\n            json_config[\"num_experts_per_tok\"] + json_config[\"n_shared_experts\"],\n            json_config[\"num_experts_per_tok\"] + json_config[\"n_shared_experts\"],\n            json_config[\"hidden_size\"],\n            json_config[\"moe_intermediate_size\"],\n        )\n        config.layer_idx = layer_idx\n        config.max_len = max_qlen\n        config.pool = CPUInfer.backend_\n        
gate, gate_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.ffn_gate.weight\")\n        up, up_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.ffn_up.weight\")\n        down, down_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.ffn_down.weight\")\n\n        config.gate_proj = gate.data_ptr()\n        config.gate_type = type_to_ggml_type(gate_type)\n        config.up_proj = up.data_ptr()\n        config.up_type = type_to_ggml_type(up_type)\n        config.down_proj = down.data_ptr()\n        config.down_type = type_to_ggml_type(down_type)\n\n        moe = kt_kernel_ext.moe.KMLInt8_MOE(config)\n        moe.load_weights()\n        return moe\n\n    elif f\"blk.{layer_idx}.ffn_gate_exps.weight\" in gguf_weights:\n        config = kt_kernel_ext.moe.MOEConfig(\n            json_config[\"n_routed_experts\"] + json_config[\"n_shared_experts\"],\n            json_config[\"num_experts_per_tok\"] + json_config[\"n_shared_experts\"],\n            json_config[\"hidden_size\"],\n            json_config[\"moe_intermediate_size\"],\n        )\n        config.layer_idx = layer_idx\n        config.max_len = max_qlen\n        config.pool = CPUInfer.backend_\n        gate, gate_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.ffn_gate_exps.weight\")\n        up, up_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.ffn_up_exps.weight\")\n        down, down_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.ffn_down_exps.weight\")\n\n        gate_sh, gate_sh_type = get_torch_tensor_and_type_from_gguf(\n            gguf_weights, f\"blk.{layer_idx}.ffn_gate_shexp.weight\"\n        )\n        up_sh, up_sh_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.ffn_up_shexp.weight\")\n        down_sh, down_sh_type = get_torch_tensor_and_type_from_gguf(\n            gguf_weights, 
f\"blk.{layer_idx}.ffn_down_shexp.weight\"\n        )\n\n        gate_sh_expanded = gate_sh.unsqueeze(0)\n        gate = torch.cat([gate, gate_sh_expanded], dim=0).contiguous()\n        up_sh_expanded = up_sh.unsqueeze(0)\n        up = torch.cat([up, up_sh_expanded], dim=0).contiguous()\n        down_sh_expanded = down_sh.unsqueeze(0)\n        down = torch.cat([down, down_sh_expanded], dim=0).contiguous()\n\n        config.gate_proj = gate.data_ptr()\n        config.gate_type = type_to_ggml_type(gate_type)\n        config.up_proj = up.data_ptr()\n        config.up_type = type_to_ggml_type(up_type)\n        config.down_proj = down.data_ptr()\n        config.down_type = type_to_ggml_type(down_type)\n\n        moe = kt_kernel_ext.moe.KMLInt8_MOE(config)\n        moe.load_weights()\n        return moe\n\n    else:\n        raise ValueError(f\"Unsupported FFN type for layer {layer_idx}\")\n\n\ndef build_moegate(layer_idx, json_config, gguf_weights):\n    config = kt_kernel_ext.gate.GateConfig(\n        json_config[\"hidden_size\"],\n        json_config[\"num_experts_per_tok\"],\n        json_config[\"n_routed_experts\"],\n        json_config[\"n_group\"],\n        json_config[\"topk_group\"],\n    )\n\n    config.routed_scaling_factor = json_config[\"routed_scaling_factor\"]\n\n    config.pool = CPUInfer.backend_\n\n    weight, weight_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.ffn_gate_inp.weight\")\n    config.weight = weight.data_ptr()\n    config.weight_type = type_to_ggml_type(weight_type)\n\n    bias, bias_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.exp_probs_b.bias\")\n    config.e_score_correction_bias = bias.data_ptr()\n    config.e_score_correction_bias_type = type_to_ggml_type(bias_type)\n\n    gate = kt_kernel_ext.gate.MoEGate(config)\n\n    return gate\n\n\ndef build_llm(json_config, gguf_weights):\n\n    general_config = kt_kernel_ext.GeneralConfig()\n    general_config.vocab_size = 
json_config[\"vocab_size\"]\n    general_config.hidden_size = json_config[\"hidden_size\"]\n    general_config.num_experts_per_tok = json_config[\"num_experts_per_tok\"]\n    general_config.n_routed_experts = json_config[\"n_routed_experts\"]\n    general_config.n_shared_experts = json_config[\"n_shared_experts\"]\n    general_config.max_qlen = max_qlen\n\n    lm_heads, lm_heads_type = get_torch_tensor_and_type_from_gguf(gguf_weights, \"output.weight\")\n    general_config.lm_heads_ptr = lm_heads.data_ptr()\n    general_config.lm_heads_type = type_to_ggml_type(lm_heads_type)\n\n    output_norm, output_norm_type = get_torch_tensor_and_type_from_gguf(gguf_weights, \"output_norm.weight\")\n    general_config.norm_weights_ptr = output_norm.data_ptr()\n    general_config.norm_weights_type = type_to_ggml_type(output_norm_type)\n\n    token_embd, token_embd_type = get_torch_tensor_and_type_from_gguf(gguf_weights, \"token_embd.weight\")\n    general_config.token_embd_ptr = token_embd.data_ptr()\n    general_config.token_embd_type = type_to_ggml_type(token_embd_type)\n\n    general_config.pool = CPUInfer.backend_\n\n    llm = kt_kernel_ext.DeepseekV3ForCausalLM(general_config)\n    model = kt_kernel_ext.DeepseekV3Model(general_config)\n    llm.model = model\n\n    decoder_layers = []\n    for i in range(json_config[\"num_hidden_layers\"]):\n        # for i in range(6):\n        # for i in [0,1,2,3,4,5,6,7,8,9,10]:\n        layer = kt_kernel_ext.DeepseekV3DecoderLayer(general_config, i)\n        attn_norm, attn_norm_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{i}.attn_norm.weight\")\n        ffn_norm, ffn_norm_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{i}.ffn_norm.weight\")\n\n        layer.load_norm(\n            attn_norm.data_ptr(),\n            type_to_ggml_type(attn_norm_type),\n            ffn_norm.data_ptr(),\n            type_to_ggml_type(ffn_norm_type),\n        )\n        layer.self_attn = build_mla(i, json_config, 
gguf_weights)\n        if f\"blk.{i}.ffn_gate_inp.weight\" in gguf_weights:\n            layer.gate = build_moegate(i, json_config, gguf_weights)\n        layer.ffn = build_ffn(i, json_config, gguf_weights)\n        decoder_layers.append(layer)\n\n    model.layers = decoder_layers\n    return llm\n\n\nsafetensor_path = \"/home/bd/models/DeepSeek-R1\"\njson_path = os.path.join(safetensor_path, \"config.json\")\njson_config = json.load(open(json_path, \"r\"))\nprint(json_config)\n\ngguf_path = \"/home/bd/models/DeepSeek-R1-BF16\"\nweights = read_gguf_directory(gguf_path)\nweights = dict(sorted(weights.items()))\n\n\nfor name, t in weights.items():\n    # if not name.startswith(\"blk\"):\n    # if name.startswith(\"blk.10.\"):\n    # if \"ffn_gate.\" in name:\n    # print(f\"Found weight: {t.name}, Shape: {t.shape}, Type: {t.tensor_type.name}, Size: {t.n_elements}\")\n    print(f\"Found weight: {t.name}, Shape: {t.shape}, Type: {t.tensor_type.name}, Size: {t.n_elements}\")\nprint(\"Building LLM ...\")\nllm = build_llm(json_config, weights)\nprint(\"Release Weight Tensors ...\")\nweights = None\nprint(\"Loading Configs ...\")\n\n\ntokenizer = AutoTokenizer.from_pretrained(safetensor_path, trust_remote_code=True)\nconfig = AutoConfig.from_pretrained(safetensor_path, trust_remote_code=True)\nprompt_file = None\nforce_think = False\n\n\noutput_logits = torch.zeros((max_qlen, json_config[\"vocab_size\"]), dtype=torch.float32)\n\n\ndef start_chat():\n    while True:\n        content = input(\"Chat: \")\n        if content.startswith('\"\"\"'):  # prefix \"\"\"\n            # multi lines input\n            content = content[3:] + \"\\n\"\n            while True:\n                line = input(\"\")\n                if line.endswith('\"\"\"'):\n                    # end multi lines input\n                    line = line[:-3]  # suffix \"\"\"\n                    if line:\n                        content += line + \"\\n\"\n                    break\n                else:\n      
              content += line + \"\\n\"\n\n        if content == \"\":\n            if prompt_file is not None:\n                content = open(prompt_file, \"r\").read()\n            else:\n                content = \"Please write a piece of quicksort code in C++.\"\n        elif os.path.isfile(content):\n            content = open(content, \"r\").read()\n\n        messages = [{\"role\": \"user\", \"content\": content}]\n        input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors=\"pt\")\n        if force_think:\n            token_thinks = torch.tensor(\n                [tokenizer.encode(\"<think>\\\\n\", add_special_tokens=False)], device=input_tensor.device\n            )\n            input_tensor = torch.cat([input_tensor, token_thinks], dim=1)\n        input_tensor = input_tensor.squeeze(0)  # Remove batch dimension\n\n        print(f\"Input tensor: {input_tensor}, type {input_tensor.dtype}, shape {input_tensor.shape}\")\n        while True:\n            stream = TextStreamer(tokenizer)\n\n            qlen = input_tensor.shape[0]\n            qlens = [qlen]\n            kvlens = [0]\n            page_tables = [list(range(pages_count))]\n            llm.forward(qlens, page_tables, kvlens, input_tensor.data_ptr(), output_logits.data_ptr())\n\n            logits = output_logits[0]\n            # print(logits)\n            # sample\n            next_token = torch.argmax(logits).item()\n            # print(f\"Next token: {next_token}, {tokenizer.decode(next_token)}\")\n            input_tensor = torch.cat((input_tensor, torch.tensor([next_token])), dim=-1)\n\n            if next_token == tokenizer.eos_token_id or tokenizer.decode(next_token) == \"<|im_end|>\":\n                print(stream.end(), end=\"\", flush=True)\n                break\n            else:\n                print(stream.put(torch.tensor([next_token])), end=\"\", flush=True)\n\n\njob_id = 0\nwhile True:\n    try:\n        # ---------- Let the user decide whether to continue ----------\n   
     choice = input(\"\\n[Enter] start chat | type q/quit/exit to quit: \").strip().lower()\n        if choice in {\"q\", \"quit\", \"exit\"}:\n            print(\"Exit command received; exiting.\")\n            break\n\n        # ----------------------------------------\n\n        start_chat()  # Start a chat session\n    except KeyboardInterrupt:\n        # Ctrl-C at any time: abandon the current job and restart\n        print(f\"\\nCtrl-C detected; chat #{job_id} terminated, restarting...\")\n    finally:\n        job_id += 1  # Interrupted or not, give the next job a new number\n"
  },
  {
    "path": "kt-kernel/examples/test_deepseekv3_prefill_speed.py",
    "content": "import os, sys\nimport time\n\nos.environ[\"BLAS_NUM_THREADS\"] = \"1\"\nsys.path.insert(0, os.path.dirname(__file__) + \"/../build\")\nfrom kt_kernel import kt_kernel_ext\nfrom kt_kernel_ext.kvcache import ggml_type\nimport torch\nimport logging\nimport sys\nimport json\nfrom pathlib import Path\nfrom transformers import (\n    AutoTokenizer,\n    AutoConfig,\n    AutoModelForCausalLM,\n    GenerationConfig,\n    TextStreamer,\n)\n\nlogger = logging.getLogger(\"reader\")\n\nfrom gguf.gguf_reader import GGUFReader\n\n# load_layers = 3\nload_layers = None\nworker_config = kt_kernel_ext.WorkerPoolConfig()\nworker_config.subpool_count = 2\nworker_config.subpool_numa_map = [0, 1]\nworker_config.subpool_thread_count = [72, 72]\nCPUInfer = kt_kernel_ext.CPUInfer(worker_config)\n\nmax_qlen = 4096\nmax_kvlen = 4096\npage_size = 256\npages_count = 200\n\n\ndef read_gguf_file(gguf_file_path):\n    \"\"\"\n    Reads and prints key-value pairs and tensor information from a GGUF file in an improved format.\n\n    Parameters:\n    - gguf_file_path: Path to the GGUF file.\n    \"\"\"\n\n    reader = GGUFReader(gguf_file_path)\n\n    # List all key-value pairs in a columnized format\n    # print(\"Key-Value Pairs:\") # noqa: NP100\n    # max_key_length = max(len(key) for key in reader.fields.keys())\n    for key, field in reader.fields.items():\n        value = field.parts[field.data[0]]\n        # print(f\"{key:{max_key_length}} : {value}\") # noqa: NP100\n    # print(\"----\") # noqa: NP100\n\n    # List all tensors\n    # print(\"Tensors:\") # noqa: NP100\n    # tensor_info_format = \"{:<30} | Shape: {:<15} | Size: {:<12} | Quantization: {}\"\n    # print(tensor_info_format.format(\"Tensor Name\", \"Shape\", \"Size\", \"Quantization\")) # noqa: NP100\n    # print(\"-\" * 80) # noqa: NP100\n    re = []\n    for tensor in reader.tensors:\n        shape_str = \"x\".join(map(str, tensor.shape))\n        size_str = str(tensor.n_elements)\n        quantization_str = 
tensor.tensor_type.name\n        # print(tensor_info_format.format(tensor.name, shape_str, size_str, quantization_str)) # noqa: NP100\n        re.append(tensor)\n    return re\n\n\ndef read_gguf_directory(directory):\n    \"\"\"\n    Reads all GGUF files in a directory and prints their contents.\n\n    Parameters:\n    - directory: Path to the directory containing GGUF files.\n    \"\"\"\n    if not os.path.isdir(directory):\n        logger.error(f\"Directory {directory} does not exist.\")\n        return\n\n    # List all GGUF files in the directory\n    files = [f for f in os.listdir(directory) if f.endswith(\".gguf\")]\n    if not files:\n        logger.info(f\"No GGUF files found in {directory}.\")\n        return\n\n    re = []\n    for file in files:\n        file_path = os.path.join(directory, file)\n        # print(f\"Reading {file_path}:\") # noqa: NP100\n        # print(\"\\n\") # noqa: NP100\n        re.extend(read_gguf_file(file_path))\n    re = {r.name: r for r in re}\n    return re\n\n\ndef find_weights(name, weights):\n    \"\"\"\n    Finds and returns the weights for a given name from the list of weights.\n\n    Parameters:\n    - name: The name of the weights to find.\n    - weights: List of weight tensors.\n\n    Returns:\n    - The weight tensor if found, otherwise None.\n    \"\"\"\n    for weight in weights:\n        if weight.name == name:\n            return weight\n    raise ValueError(f\"Weight with name {name} not found in the provided weights list.\")\n\n\ndef get_torch_tensor_from_gguf(gguf_weights, name):\n    return torch.from_numpy(gguf_weights[name].data).contiguous()\n\n\ndef get_torch_tensor_and_type_from_gguf(gguf_weights, name):\n    return torch.from_numpy(gguf_weights[name].data).contiguous(), gguf_weights[name].tensor_type.name\n\n\ndef type_to_ggml_type(type):\n    if type == \"F32\":\n        return ggml_type.FP32\n    elif type == \"F16\":\n        return ggml_type.FP16\n    elif type == \"BF16\":\n        return 
ggml_type.BF16\n    else:\n        raise ValueError(f\"Unsupported data type: {type}\")\n\n\ndef build_mla(layer_idx, json_config, gguf_weights):\n    hidden_size = json_config[\"hidden_size\"]\n    num_heads = json_config[\"num_attention_heads\"]\n    q_lora_rank = json_config[\"q_lora_rank\"]\n    kv_lora_rank = json_config[\"kv_lora_rank\"]\n    nope_size = json_config[\"qk_nope_head_dim\"]\n    rope_size = json_config[\"qk_rope_head_dim\"]\n    max_position_embeddings = json_config[\"max_position_embeddings\"]\n    rope_theta = json_config[\"rope_theta\"]\n    rope_scaling = json_config[\"rope_scaling\"]\n\n    config = kt_kernel_ext.mla.MLAConfig(\n        hidden_size,\n        q_lora_rank,\n        kv_lora_rank,\n        num_heads,\n        nope_size,\n        rope_size,\n    )\n    config.max_qlen = max_qlen\n    config.max_kvlen = max_kvlen\n    config.max_position_embeddings = max_position_embeddings\n    config.rope_scaling_factor = rope_scaling[\"factor\"]\n    config.rope_theta = rope_theta\n    config.rope_scaling_beta_fast = rope_scaling[\"beta_fast\"]\n    config.rope_scaling_beta_slow = rope_scaling[\"beta_slow\"]\n    config.rope_scaling_mscale = rope_scaling[\"mscale\"]\n    config.rope_scaling_mscale_all_dim = rope_scaling[\"mscale_all_dim\"]\n    config.rope_scaling_original_max_position_embeddings = rope_scaling[\"original_max_position_embeddings\"]\n\n    q_a_proj_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.attn_q_a.weight\")\n    config.q_a_proj = q_a_proj_weight.data_ptr()\n    config.q_a_proj_type = type_to_ggml_type(type)\n    q_a_type = type\n\n    q_a_norm_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.attn_q_a_norm.weight\")\n    config.q_a_norm = q_a_norm_weight.data_ptr()\n    config.q_a_norm_type = type_to_ggml_type(type)\n\n    q_b_proj_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.attn_q_b.weight\")\n    
config.q_b_proj = q_b_proj_weight.data_ptr()\n    config.q_b_proj_type = type_to_ggml_type(type)\n\n    kv_a_proj_with_mqa_weight, type = get_torch_tensor_and_type_from_gguf(\n        gguf_weights, f\"blk.{layer_idx}.attn_kv_a_mqa.weight\"\n    )\n    config.kv_a_proj_with_mqa = kv_a_proj_with_mqa_weight.data_ptr()\n    config.kv_a_proj_with_mqa_type = type_to_ggml_type(type)\n\n    kv_a_norm_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.attn_kv_a_norm.weight\")\n    config.kv_a_norm = kv_a_norm_weight.data_ptr()\n    config.kv_a_norm_type = type_to_ggml_type(type)\n\n    kv_b_proj_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.attn_kv_b.weight\")\n    config.kv_b_proj = kv_b_proj_weight.data_ptr()\n    config.kv_b_proj_type = type_to_ggml_type(type)\n\n    o_proj_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.attn_output.weight\")\n    config.o_proj = o_proj_weight.data_ptr()\n    config.w_o_type = type_to_ggml_type(type)\n\n    config.layer_idx = layer_idx\n    config.pool = CPUInfer.backend_\n    config.page_count = pages_count\n\n    if q_a_type == \"F32\":\n        mla = kt_kernel_ext.mla.MLA_F32(config)\n    elif q_a_type == \"F16\":\n        mla = kt_kernel_ext.mla.MLA_F16(config)\n    elif q_a_type == \"BF16\":\n        # mla = kt_kernel_ext.mla.MLA_F32(config)\n        mla = kt_kernel_ext.mla.MLA_QUAN_F32(config)\n    else:\n        raise ValueError(f\"Unsupported data type: {q_a_type}\")\n\n    mla.load_weights()\n    mla.set_local_pages(pages_count)\n    return mla\n\n\ndef build_ffn(layer_idx, json_config, gguf_weights):\n    if f\"blk.{layer_idx}.ffn_gate.weight\" in gguf_weights:  # dense\n        config = kt_kernel_ext.moe.MOEConfig(\n            json_config[\"num_experts_per_tok\"] + json_config[\"n_shared_experts\"],\n            json_config[\"num_experts_per_tok\"] + json_config[\"n_shared_experts\"],\n            
json_config[\"hidden_size\"],\n            json_config[\"moe_intermediate_size\"],\n        )\n        config.layer_idx = layer_idx\n        config.max_len = max_qlen\n        config.pool = CPUInfer.backend_\n        gate, gate_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.ffn_gate.weight\")\n        up, up_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.ffn_up.weight\")\n        down, down_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.ffn_down.weight\")\n\n        config.gate_proj = gate.data_ptr()\n        config.gate_type = type_to_ggml_type(gate_type)\n        config.up_proj = up.data_ptr()\n        config.up_type = type_to_ggml_type(up_type)\n        config.down_proj = down.data_ptr()\n        config.down_type = type_to_ggml_type(down_type)\n\n        moe = kt_kernel_ext.moe.KMLInt8_MOE(config)\n        moe.load_weights()\n        return moe\n\n    elif f\"blk.{layer_idx}.ffn_gate_exps.weight\" in gguf_weights:\n        config = kt_kernel_ext.moe.MOEConfig(\n            json_config[\"n_routed_experts\"] + json_config[\"n_shared_experts\"],\n            json_config[\"num_experts_per_tok\"] + json_config[\"n_shared_experts\"],\n            json_config[\"hidden_size\"],\n            json_config[\"moe_intermediate_size\"],\n        )\n        config.layer_idx = layer_idx\n        config.max_len = max_qlen\n        config.pool = CPUInfer.backend_\n        gate, gate_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.ffn_gate_exps.weight\")\n        up, up_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.ffn_up_exps.weight\")\n        down, down_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.ffn_down_exps.weight\")\n\n        gate_sh, gate_sh_type = get_torch_tensor_and_type_from_gguf(\n            gguf_weights, f\"blk.{layer_idx}.ffn_gate_shexp.weight\"\n        )\n        up_sh, up_sh_type = 
get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.ffn_up_shexp.weight\")\n        down_sh, down_sh_type = get_torch_tensor_and_type_from_gguf(\n            gguf_weights, f\"blk.{layer_idx}.ffn_down_shexp.weight\"\n        )\n\n        gate_sh_expanded = gate_sh.unsqueeze(0)\n        gate = torch.cat([gate, gate_sh_expanded], dim=0).contiguous()\n        up_sh_expanded = up_sh.unsqueeze(0)\n        up = torch.cat([up, up_sh_expanded], dim=0).contiguous()\n        down_sh_expanded = down_sh.unsqueeze(0)\n        down = torch.cat([down, down_sh_expanded], dim=0).contiguous()\n\n        config.gate_proj = gate.data_ptr()\n        config.gate_type = type_to_ggml_type(gate_type)\n        config.up_proj = up.data_ptr()\n        config.up_type = type_to_ggml_type(up_type)\n        config.down_proj = down.data_ptr()\n        config.down_type = type_to_ggml_type(down_type)\n\n        moe = kt_kernel_ext.moe.KMLInt8_MOE(config)\n        moe.load_weights()\n        return moe\n\n    else:\n        raise ValueError(f\"Unsupported FFN type for layer {layer_idx}\")\n\n\ndef build_moegate(layer_idx, json_config, gguf_weights):\n    config = kt_kernel_ext.gate.GateConfig(\n        json_config[\"hidden_size\"],\n        json_config[\"num_experts_per_tok\"],\n        json_config[\"n_routed_experts\"],\n        json_config[\"n_group\"],\n        json_config[\"topk_group\"],\n    )\n\n    config.routed_scaling_factor = json_config[\"routed_scaling_factor\"]\n\n    config.pool = CPUInfer.backend_\n\n    weight, weight_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.ffn_gate_inp.weight\")\n    config.weight = weight.data_ptr()\n    config.weight_type = type_to_ggml_type(weight_type)\n\n    bias, bias_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.exp_probs_b.bias\")\n    config.e_score_correction_bias = bias.data_ptr()\n    config.e_score_correction_bias_type = type_to_ggml_type(bias_type)\n\n    gate = 
kt_kernel_ext.gate.MoEGate(config)\n\n    return gate\n\n\ndef build_llm(json_config, gguf_weights):\n\n    general_config = kt_kernel_ext.GeneralConfig()\n    general_config.vocab_size = json_config[\"vocab_size\"]\n    general_config.hidden_size = json_config[\"hidden_size\"]\n    general_config.num_experts_per_tok = json_config[\"num_experts_per_tok\"]\n    general_config.n_routed_experts = json_config[\"n_routed_experts\"]\n    general_config.n_shared_experts = json_config[\"n_shared_experts\"]\n    general_config.max_qlen = max_qlen\n\n    lm_heads, lm_heads_type = get_torch_tensor_and_type_from_gguf(gguf_weights, \"output.weight\")\n    general_config.lm_heads_ptr = lm_heads.data_ptr()\n    general_config.lm_heads_type = type_to_ggml_type(lm_heads_type)\n\n    output_norm, output_norm_type = get_torch_tensor_and_type_from_gguf(gguf_weights, \"output_norm.weight\")\n    general_config.norm_weights_ptr = output_norm.data_ptr()\n    general_config.norm_weights_type = type_to_ggml_type(output_norm_type)\n\n    token_embd, token_embd_type = get_torch_tensor_and_type_from_gguf(gguf_weights, \"token_embd.weight\")\n    general_config.token_embd_ptr = token_embd.data_ptr()\n    general_config.token_embd_type = type_to_ggml_type(token_embd_type)\n\n    general_config.pool = CPUInfer.backend_\n\n    llm = kt_kernel_ext.DeepseekV3ForCausalLM(general_config)\n    model = kt_kernel_ext.DeepseekV3Model(general_config)\n    llm.model = model\n\n    decoder_layers = []\n    real_load_layers = json_config[\"num_hidden_layers\"] if load_layers is None else load_layers\n\n    for i in range(real_load_layers):\n        # for i in [2,3]:\n        layer = kt_kernel_ext.DeepseekV3DecoderLayer(general_config, i)\n        attn_norm, attn_norm_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{i}.attn_norm.weight\")\n        ffn_norm, ffn_norm_type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{i}.ffn_norm.weight\")\n\n        layer.load_norm(\n            
attn_norm.data_ptr(),\n            type_to_ggml_type(attn_norm_type),\n            ffn_norm.data_ptr(),\n            type_to_ggml_type(ffn_norm_type),\n        )\n        layer.self_attn = build_mla(i, json_config, gguf_weights)\n        if f\"blk.{i}.ffn_gate_inp.weight\" in gguf_weights:\n            layer.gate = build_moegate(i, json_config, gguf_weights)\n        layer.ffn = build_ffn(i, json_config, gguf_weights)\n        decoder_layers.append(layer)\n\n    model.layers = decoder_layers\n    return llm\n\n\nsafetensor_path = \"/home/bd/models/DeepSeek-R1\"\njson_path = os.path.join(safetensor_path, \"config.json\")\njson_config = json.load(open(json_path, \"r\"))\nprint(json_config)\n\ngguf_path = \"/home/bd/models/DeepSeek-R1-BF16\"\nweights = read_gguf_directory(gguf_path)\nweights = dict(sorted(weights.items()))\n\n\n# for name, t in weights.items():\n# if not name.startswith(\"blk\"):\n# if name.startswith(\"blk.10.\"):\n# if \"ffn_gate.\" in name:\n# print(f\"Found weight: {t.name}, Shape: {t.shape}, Type: {t.tensor_type.name}, Size: {t.n_elements}\")\n# print(f\"Found weight: {t.name}, Shape: {t.shape}, Type: {t.tensor_type.name}, Size: {t.n_elements}\")\n\nprint(\"Building LLM ...\")\nload_start_time = time.perf_counter()\nllm = build_llm(json_config, weights)\nload_end_time = time.perf_counter()\nprint(f\"Load time: {load_end_time - load_start_time:.4f} seconds\")\n\nprint(\"Release Weight Tensors ...\")\nweights = None\nprint(\"Loading Configs ...\")\n\n\ntokenizer = AutoTokenizer.from_pretrained(safetensor_path, trust_remote_code=True)\nconfig = AutoConfig.from_pretrained(safetensor_path, trust_remote_code=True)\n\nforce_think = False\n\n\noutput_logits = torch.zeros((max_qlen, json_config[\"vocab_size\"]), dtype=torch.float32)\n\n\ndef start_chat(content=None):\n    if content is None:\n        content = input(\"Chat: \")\n\n    messages = [{\"role\": \"user\", \"content\": content}]\n    input_tensor = tokenizer.apply_chat_template(messages, 
add_generation_prompt=True, return_tensors=\"pt\")\n    if force_think:\n        token_thinks = torch.tensor(\n            [tokenizer.encode(\"<think>\\\\n\", add_special_tokens=False)], device=input_tensor.device\n        )\n        input_tensor = torch.cat([input_tensor, token_thinks], dim=1)\n    input_tensor = input_tensor.squeeze(0)  # Remove batch dimension\n\n    print(f\"Input tensor: {input_tensor}, type {input_tensor.dtype}, shape {input_tensor.shape}\")\n    kvlen = 0\n    step = 2\n    while True:  # The step limit is disabled; generation runs until an end token\n        step -= 1\n        stream = TextStreamer(tokenizer)\n\n        qlen = input_tensor.shape[0]\n        qlens = [qlen]\n        kvlens = [0]\n        page_tables = [list(range(pages_count))]\n        start_time = time.perf_counter()\n        llm.forward(qlens, page_tables, kvlens, input_tensor.data_ptr(), output_logits.data_ptr())\n        end_time = time.perf_counter()\n        print(\n            f\"Forward time: {end_time - start_time:.4f} seconds, tps: {qlens[0] / (end_time - start_time)} tokens/sec\"\n        )\n\n        logits = output_logits[0]\n        # print(logits)\n        # sample\n        next_token = torch.argmax(logits).item()\n        # print(f\"Next token: {next_token}, {tokenizer.decode(next_token)}\")\n        # kvlen = input_tensor.shape[0]\n        input_tensor = torch.cat((input_tensor, torch.tensor([next_token])), dim=-1)\n\n        if next_token == tokenizer.eos_token_id or tokenizer.decode(next_token) == \"<|im_end|>\":\n            stream.end()\n            break\n        else:\n            stream.put(torch.tensor([next_token]))\n\n\njob_id = 0\nwhile True:\n    try:\n        # ---------- Let the user decide whether to continue ----------\n        choice = input(\"\\n[Enter] start chat | type 1 to read a file | type q/quit/exit to quit: \").strip().lower()\n        if choice in {\"q\", \"quit\", \"exit\"}:\n            print(\"Exit command received; exiting.\")\n            break\n        elif choice == \"1\":\n            file_path = input(\"Enter the path of the file to read: \").strip()\n            if not 
Path(file_path).is_file():\n                print(f\"File {file_path} does not exist; please check the path.\")\n                continue\n            with open(file_path, \"r\", encoding=\"utf-8\") as file:\n                content = file.read()\n            print(f\"Content read:\\n{content}\\n\")\n            start_chat(content)\n        else:\n            start_chat()\n\n    except KeyboardInterrupt:\n        # Ctrl-C at any time: abandon the current job and restart\n        print(f\"\\nCtrl-C detected; chat #{job_id} terminated, restarting...\")\n    except Exception as e:\n        # Any other exception: print the error and restart\n        print(f\"\\nError: {e}\\nChat #{job_id} terminated, restarting...\")\n        logger.error(f\"Error in job {job_id}: {e}\", exc_info=True)\n    finally:\n        job_id += 1  # Interrupted or not, give the next job a new number\n"
  },
  {
    "path": "kt-kernel/examples/test_fp8_moe.py",
    "content": "\"\"\"\nTest script for GemmKernel224FP8 (FP8 MoE) kernel validation.\n\nThis script:\n1. Generates random BF16 weights\n2. Quantizes them to FP8 format with 128x128 block-wise scales\n3. Runs the FP8 MoE kernel\n4. Compares results with PyTorch reference using dequantized BF16 weights\n\nFP8 format notes:\n- Weight: FP8 (E4M3) stored as uint8, shape [expert_num, n, k]\n- Scale: FP32, shape [expert_num, n // group_size, k // group_size], group_size=128\n\"\"\"\n\nimport os\nimport sys\n\nsys.path.insert(0, os.path.dirname(__file__) + \"/../build\")\n\nimport torch\nimport kt_kernel\nfrom kt_kernel import kt_kernel_ext\n\ntorch.manual_seed(42)\n\n# Model config\nhidden_size = 3072\nintermediate_size = 1536\nmax_len = 25600\n\nexpert_num = 16\nnum_experts_per_tok = 8\n\nqlen = 100\nlayer_num = 1\nCPUInfer = kt_kernel_ext.CPUInfer(40)\nvalidation_iter = 1\nfp8_group_size = 128  # FP8 uses 128x128 block quantization\ndebug_print_count = 16\n\nphysical_to_logical_map = torch.tensor(data=range(expert_num), device=\"cpu\", dtype=torch.int64).contiguous()\n\n\ndef act_fn(x):\n    \"\"\"SiLU activation function\"\"\"\n    return x / (1.0 + torch.exp(-x))\n\n\ndef mlp_torch(input, gate_proj, up_proj, down_proj):\n    \"\"\"Reference MLP computation in PyTorch\"\"\"\n    gate_buf = torch.mm(input, gate_proj.t())\n    up_buf = torch.mm(input, up_proj.t())\n    intermediate = act_fn(gate_buf) * up_buf\n    ret = torch.mm(intermediate, down_proj.t())\n    return ret\n\n\ndef moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):\n    \"\"\"Reference MoE computation in PyTorch\"\"\"\n    cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))\n    cnts.scatter_(1, expert_ids, 1)\n    tokens_per_expert = cnts.sum(dim=0)\n    idxs = expert_ids.view(-1).argsort()\n    sorted_tokens = input[idxs // expert_ids.shape[1]]\n\n    outputs = []\n    start_idx = 0\n\n    for i, num_tokens in enumerate(tokens_per_expert):\n        end_idx = start_idx + 
num_tokens\n        if num_tokens == 0:\n            continue\n        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n        expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])\n        outputs.append(expert_out)\n        start_idx = end_idx\n\n    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n    new_x = torch.empty_like(outs)\n    new_x[idxs] = outs\n    t_output = (\n        new_x.view(*expert_ids.shape, -1)\n        .type(weights.dtype)\n        .mul_(weights.unsqueeze(dim=-1))\n        .sum(dim=1)\n        .type(new_x.dtype)\n    )\n    return t_output\n\n\n# FP8 E4M3 constants\nFP8_E4M3_MAX = 448.0  # Maximum representable value in FP8 E4M3\n\n\ndef fp8_e4m3_to_float(fp8_val: int) -> float:\n    \"\"\"\n    Convert FP8 E4M3 value to float.\n    FP8 E4M3 format: 1 sign bit, 4 exponent bits, 3 mantissa bits\n    \"\"\"\n    sign = (fp8_val >> 7) & 1\n    exp = (fp8_val >> 3) & 0xF\n    mant = fp8_val & 0x7\n\n    if exp == 0:\n        # Subnormal or zero\n        if mant == 0:\n            return -0.0 if sign else 0.0\n        # Subnormal: value = (-1)^sign * 2^(-6) * (0.mant)\n        return ((-1) ** sign) * (2**-6) * (mant / 8.0)\n    elif exp == 15 and mant == 7:\n        # NaN (FP8 E4M3 has no Inf; only exp=15 with mant=7 encodes NaN)\n        return float(\"nan\")\n    else:\n        # Normal: value = (-1)^sign * 2^(exp-7) * (1.mant)\n        return ((-1) ** sign) * (2 ** (exp - 7)) * (1.0 + mant / 8.0)\n\n\ndef float_to_fp8_e4m3(val: float) -> int:\n    \"\"\"\n    Convert float to FP8 E4M3 value.\n    \"\"\"\n    if val != val:  # NaN\n        return 0x7F  # NaN representation\n\n    sign = 1 if val < 0 else 0\n    val = abs(val)\n\n    if val == 0:\n        return sign << 7\n\n    # Clamp to max representable value\n    val = min(val, FP8_E4M3_MAX)\n\n    # Find exponent\n    import math\n\n    if val < 2**-6:  # Subnormal threshold (smallest normal value is 2^-6)\n        # Subnormal\n        mant = int(round(val / 
(2**-9)))\n        mant = min(mant, 7)\n        return (sign << 7) | mant\n\n    exp = int(math.floor(math.log2(val))) + 7\n    exp = max(1, min(exp, 14))  # Clamp exponent to valid range\n\n    # Calculate mantissa\n    mant = int(round((val / (2 ** (exp - 7)) - 1.0) * 8))\n    mant = max(0, min(mant, 7))\n\n    # Handle overflow to next exponent\n    if mant > 7:\n        mant = 0\n        exp += 1\n        if exp > 14:\n            exp = 14\n            mant = 7\n\n    return (sign << 7) | (exp << 3) | mant\n\n\ndef quantize_to_fp8_blockwise(weights: torch.Tensor, group_size: int = 128):\n    \"\"\"\n    Quantize BF16/FP32 weights to FP8 with block-wise scaling.\n\n    Args:\n        weights: [expert_num, n, k] tensor in BF16/FP32\n        group_size: Block size for quantization (default 128 for DeepSeek)\n\n    Returns:\n        fp8_weights: [expert_num, n, k] uint8 tensor\n        scales: [expert_num, n // group_size, k // group_size] FP32 tensor (scale_inv)\n    \"\"\"\n    weights_f32 = weights.to(torch.float32)\n    e, n, k = weights_f32.shape\n\n    assert n % group_size == 0, f\"n ({n}) must be divisible by group_size ({group_size})\"\n    assert k % group_size == 0, f\"k ({k}) must be divisible by group_size ({group_size})\"\n\n    n_blocks = n // group_size\n    k_blocks = k // group_size\n\n    # Reshape to [e, n_blocks, group_size, k_blocks, group_size]\n    reshaped = weights_f32.view(e, n_blocks, group_size, k_blocks, group_size)\n    # Move to [e, n_blocks, k_blocks, group_size, group_size] for block processing\n    reshaped = reshaped.permute(0, 1, 3, 2, 4)\n\n    # Calculate max abs per block\n    max_abs = reshaped.abs().amax(dim=(-2, -1), keepdim=True)\n    max_abs = torch.clamp(max_abs, min=1e-12)\n\n    # Scale to FP8 range: scale = max_abs / FP8_MAX\n    # We store scale_inv = scale (for dequantization: fp8 * scale)\n    scales = (max_abs / FP8_E4M3_MAX).squeeze(-1).squeeze(-1)  # [e, n_blocks, k_blocks]\n\n    # Quantize: q = round(val / 
scale)\n    scaled = reshaped / (scales.unsqueeze(-1).unsqueeze(-1) + 1e-12)\n\n    # Convert to FP8 E4M3 using vectorized approach\n    # Clamp to FP8 representable range\n    scaled = scaled.clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)\n\n    # Simple quantization: round to the nearest FP8 value by computing the\n    # exponent and mantissa fields directly. The exponent is clamped to 14,\n    # so magnitudes above 240 saturate at 240.\n    fp8_q = torch.zeros_like(scaled, dtype=torch.uint8)\n\n    # Vectorized FP8 quantization\n    sign_mask = (scaled < 0).to(torch.uint8) << 7\n    abs_scaled = scaled.abs()\n\n    # Handle different ranges\n    # Subnormal: 0 < |x| < 2^-6\n    subnormal_mask = (abs_scaled > 0) & (abs_scaled < 2**-6)\n    subnormal_mant = (abs_scaled / (2**-9)).round().clamp(0, 7).to(torch.uint8)\n\n    # Normal values\n    normal_mask = abs_scaled >= 2**-6\n    log2_val = torch.log2(abs_scaled.clamp(min=2**-9))\n    exp = (log2_val.floor() + 7).clamp(1, 14).to(torch.int32)\n    mant = ((abs_scaled / (2.0 ** (exp.float() - 7)) - 1.0) * 8).round().clamp(0, 7).to(torch.uint8)\n\n    # Combine\n    fp8_q = torch.where(subnormal_mask, sign_mask | subnormal_mant, fp8_q)\n    fp8_q = torch.where(normal_mask, sign_mask | (exp.to(torch.uint8) << 3) | mant, fp8_q)\n\n    # Reshape back to [e, n, k]\n    fp8_q = fp8_q.permute(0, 1, 3, 2, 4).reshape(e, n, k)\n\n    # Scales keep shape [e, n_blocks, k_blocks] and are stored as FP32\n    scales_fp32 = scales.to(torch.float32).contiguous()\n\n    return fp8_q.contiguous(), scales_fp32\n\n\ndef dequantize_fp8_blockwise(fp8_weights: torch.Tensor, scales: torch.Tensor, group_size: int = 128):\n    \"\"\"\n    Dequantize FP8 weights back to BF16 for reference computation.\n\n    Args:\n        fp8_weights: [expert_num, n, k] uint8 tensor\n        scales: [expert_num, n // group_size, k // group_size] FP32 tensor\n        group_size: Block size\n\n    Returns:\n        dequantized: [expert_num, n, k] BF16 tensor\n    \"\"\"\n    e, n, k = fp8_weights.shape\n    n_blocks = n // group_size\n    
k_blocks = k // group_size\n\n    # Convert FP8 to float\n    # Build lookup table for FP8 E4M3 -> float\n    fp8_lut = torch.tensor([fp8_e4m3_to_float(i) for i in range(256)], dtype=torch.float32)\n\n    # Use lookup table\n    fp8_float = fp8_lut[fp8_weights.to(torch.int64)]\n\n    # Reshape for block-wise scaling\n    fp8_reshaped = fp8_float.view(e, n_blocks, group_size, k_blocks, group_size)\n    fp8_reshaped = fp8_reshaped.permute(0, 1, 3, 2, 4)  # [e, n_blocks, k_blocks, group_size, group_size]\n\n    # Apply scales\n    scales_f32 = scales.to(torch.float32).unsqueeze(-1).unsqueeze(-1)  # [e, n_blocks, k_blocks, 1, 1]\n    dequantized = fp8_reshaped * scales_f32\n\n    # Reshape back\n    dequantized = dequantized.permute(0, 1, 3, 2, 4).reshape(e, n, k)\n\n    return dequantized.to(torch.bfloat16).contiguous()\n\n\ndef build_random_fp8_weights():\n    \"\"\"\n    Generate random BF16 weights and quantize to FP8.\n\n    Returns:\n        dict with fp8 weights, scales, and original bf16 for reference\n    \"\"\"\n    torch.manual_seed(42)\n\n    # Generate random BF16 weights with small values\n    gate_proj = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / 100.0).to(\n        torch.bfloat16\n    )\n    up_proj = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / 100.0).to(\n        torch.bfloat16\n    )\n    down_proj = (torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32) / 100.0).to(\n        torch.bfloat16\n    )\n\n    # Quantize to FP8\n    gate_fp8, gate_scales = quantize_to_fp8_blockwise(gate_proj, fp8_group_size)\n    up_fp8, up_scales = quantize_to_fp8_blockwise(up_proj, fp8_group_size)\n    down_fp8, down_scales = quantize_to_fp8_blockwise(down_proj, fp8_group_size)\n\n    # Dequantize for reference computation\n    gate_deq = dequantize_fp8_blockwise(gate_fp8, gate_scales, fp8_group_size)\n    up_deq = dequantize_fp8_blockwise(up_fp8, up_scales, 
fp8_group_size)\n    down_deq = dequantize_fp8_blockwise(down_fp8, down_scales, fp8_group_size)\n\n    print(f\"FP8 weights shape: gate={gate_fp8.shape}, up={up_fp8.shape}, down={down_fp8.shape}\")\n    print(f\"Scales shape: gate={gate_scales.shape}, up={up_scales.shape}, down={down_scales.shape}\")\n\n    # Debug: Print FP8 weight and scale info for expert 0\n    print(\"\\n=== DEBUG: FP8 Weight and Scale Info (Expert 0) ===\")\n    print(f\"gate_fp8[0] first 8x8 block:\")\n    for i in range(8):\n        print(f\"  row {i}: {gate_fp8[0, i, :8].numpy().tobytes().hex(' ')}\")\n    print(f\"gate_fp8[0] stats: min={gate_fp8[0].min()}, max={gate_fp8[0].max()}\")\n    print(f\"gate_scales[0] first 4x4 block:\\n{gate_scales[0, :4, :4]}\")\n    print(f\"gate_scales[0] stats: min={gate_scales[0].min()}, max={gate_scales[0].max()}\")\n\n    print(f\"\\nup_fp8[0] first 8x8 block:\")\n    for i in range(8):\n        print(f\"  row {i}: {up_fp8[0, i, :8].numpy().tobytes().hex(' ')}\")\n    print(f\"up_fp8[0] stats: min={up_fp8[0].min()}, max={up_fp8[0].max()}\")\n    print(f\"up_scales[0] first 4x4 block:\\n{up_scales[0, :4, :4]}\")\n    print(f\"up_scales[0] stats: min={up_scales[0].min()}, max={up_scales[0].max()}\")\n\n    print(f\"\\ndown_fp8[0] first 8x8 block:\")\n    for i in range(8):\n        print(f\"  row {i}: {down_fp8[0, i, :8].numpy().tobytes().hex(' ')}\")\n    print(f\"down_fp8[0] stats: min={down_fp8[0].min()}, max={down_fp8[0].max()}\")\n    print(f\"down_scales[0] first 4x4 block:\\n{down_scales[0, :4, :4]}\")\n    print(f\"down_scales[0] stats: min={down_scales[0].min()}, max={down_scales[0].max()}\")\n\n    return {\n        \"gate_fp8\": gate_fp8.contiguous(),\n        \"up_fp8\": up_fp8.contiguous(),\n        \"down_fp8\": down_fp8.contiguous(),\n        \"gate_scales\": gate_scales.contiguous(),\n        \"up_scales\": up_scales.contiguous(),\n        \"down_scales\": down_scales.contiguous(),\n        \"gate_deq\": gate_deq.contiguous(),\n        
\"up_deq\": up_deq.contiguous(),\n        \"down_deq\": down_deq.contiguous(),\n    }\n\n\ndef build_moes_from_fp8_data(fp8_data: dict):\n    \"\"\"\n    Build FP8 MoE modules from quantized data.\n    \"\"\"\n    moes = []\n    with torch.inference_mode(mode=True):\n        for _ in range(layer_num):\n            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)\n            config.max_len = max_len\n            config.quant_config.bits = 8\n            config.quant_config.group_size = fp8_group_size\n            config.quant_config.zero_point = False\n\n            # Set FP8 weight pointers\n            config.gate_proj = fp8_data[\"gate_fp8\"].data_ptr()\n            config.up_proj = fp8_data[\"up_fp8\"].data_ptr()\n            config.down_proj = fp8_data[\"down_fp8\"].data_ptr()\n\n            # Set scale pointers\n            config.gate_scale = fp8_data[\"gate_scales\"].data_ptr()\n            config.up_scale = fp8_data[\"up_scales\"].data_ptr()\n            config.down_scale = fp8_data[\"down_scales\"].data_ptr()\n            config.pool = CPUInfer.backend_\n\n            moe = kt_kernel_ext.moe.AMXFP8_MOE(config)\n            CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n            CPUInfer.sync()\n            moes.append(moe)\n    return moes\n\n\ndef run_fp8_moe_test():\n    \"\"\"\n    Run FP8 MoE validation test.\n    \"\"\"\n    print(\"\\n\" + \"=\" * 70)\n    print(\"FP8 MoE Kernel Validation Test\")\n    print(\"=\" * 70)\n\n    # Build FP8 weights\n    print(\"\\nGenerating and quantizing weights...\")\n    fp8_data = build_random_fp8_weights()\n\n    # Build MoE modules\n    print(\"\\nBuilding FP8 MoE modules...\")\n    moes = build_moes_from_fp8_data(fp8_data)\n\n    # Get dequantized weights for reference\n    gate_deq = fp8_data[\"gate_deq\"]\n    up_deq = fp8_data[\"up_deq\"]\n    down_deq = fp8_data[\"down_deq\"]\n\n    diffs = []\n    with 
torch.inference_mode(mode=True):\n        for i in range(validation_iter):\n            torch.manual_seed(100 + i)\n            bsz_tensor = torch.tensor([qlen], device=\"cpu\")\n            expert_ids = torch.stack(\n                [torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]\n            ).contiguous()\n            weights = torch.randn((qlen, num_experts_per_tok), dtype=torch.float32).contiguous() / 100\n            input_tensor = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous() * 1.5\n            output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()\n\n            moe = moes[i % layer_num]\n            CPUInfer.submit(\n                moe.forward_task(\n                    bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids.data_ptr(),\n                    weights.data_ptr(),\n                    input_tensor.data_ptr(),\n                    output.data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n\n            assert not torch.isnan(output).any(), \"NaN values detected in CPU expert output.\"\n            assert not torch.isinf(output).any(), \"Inf values detected in CPU expert output.\"\n\n            # Reference computation using dequantized weights\n            t_output = moe_torch(input_tensor, expert_ids, weights, gate_deq, up_deq, down_deq)\n\n            t_output_flat = t_output.flatten()\n            output_flat = output.flatten()\n\n            diff = torch.mean(torch.abs(output_flat - t_output_flat)) / (torch.mean(torch.abs(t_output_flat)) + 1e-12)\n            diffs.append(diff.item())\n            print(f\"Iteration {i}: relative L1 diff = {diff:.6f}\")\n\n            if i < 3:  # Print detailed output for first few iterations\n                print(f\"  kernel output: {output_flat[:debug_print_count]}\")\n                print(f\"  torch output:  
{t_output_flat[:debug_print_count]}\")\n\n    mean_diff = float(sum(diffs) / len(diffs))\n    max_diff = float(max(diffs))\n    min_diff = float(min(diffs))\n\n    print(\"\\n\" + \"=\" * 70)\n    print(\"FP8 MoE Test Results\")\n    print(\"=\" * 70)\n    print(f\"Mean relative L1 diff: {mean_diff*100:.4f}%\")\n    print(f\"Max relative L1 diff:  {max_diff*100:.4f}%\")\n    print(f\"Min relative L1 diff:  {min_diff*100:.4f}%\")\n\n    # Pass/Fail criteria\n    threshold = 15.0  # 15% relative error threshold for FP8\n    if mean_diff * 100 < threshold:\n        print(f\"\\nPASS: Mean error {mean_diff*100:.4f}% < {threshold}% threshold\")\n    else:\n        print(f\"\\nFAIL: Mean error {mean_diff*100:.4f}% >= {threshold}% threshold\")\n\n    return {\"mean\": mean_diff, \"max\": max_diff, \"min\": min_diff}\n\n\nif __name__ == \"__main__\":\n    run_fp8_moe_test()\n"
  },
  {
    "path": "kt-kernel/examples/test_fp8_perchannel_moe.py",
    "content": "\"\"\"\nTest script for FP8 Per-Channel MoE kernel validation (GLM-4.7-FP8 style).\n\nThis script:\n1. Generates random BF16 weights\n2. Quantizes them to FP8 format with per-channel scales (one scale per output channel)\n3. Runs the FP8 Per-Channel MoE kernel\n4. Compares results with PyTorch reference using dequantized BF16 weights\n\nFP8 Per-Channel format notes:\n- Weight: FP8 (E4M3) stored as uint8, shape [expert_num, n, k]\n- Scale: FP32, shape [expert_num, n] (one scale per output row)\n\"\"\"\n\nimport os\nimport sys\n\nsys.path.insert(0, os.path.dirname(__file__) + \"/../build\")\n\nimport torch\nfrom kt_kernel import kt_kernel_ext\n\ntorch.manual_seed(42)\n\n# Model config\nhidden_size = 3072\nintermediate_size = 1536\nmax_len = 25600\n\nexpert_num = 16\nnum_experts_per_tok = 8\n\nqlen = 100\nlayer_num = 1\nCPUInfer = kt_kernel_ext.CPUInfer(40)\nvalidation_iter = 1\ndebug_print_count = 16\n\nphysical_to_logical_map = torch.tensor(data=range(expert_num), device=\"cpu\", dtype=torch.int64).contiguous()\n\n\ndef act_fn(x):\n    \"\"\"SiLU activation function\"\"\"\n    return x / (1.0 + torch.exp(-x))\n\n\ndef mlp_torch(input, gate_proj, up_proj, down_proj):\n    \"\"\"Reference MLP computation in PyTorch\"\"\"\n    gate_buf = torch.mm(input, gate_proj.t())\n    up_buf = torch.mm(input, up_proj.t())\n    intermediate = act_fn(gate_buf) * up_buf\n    ret = torch.mm(intermediate, down_proj.t())\n    return ret\n\n\ndef moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):\n    \"\"\"Reference MoE computation in PyTorch\"\"\"\n    cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))\n    cnts.scatter_(1, expert_ids, 1)\n    tokens_per_expert = cnts.sum(dim=0)\n    idxs = expert_ids.view(-1).argsort()\n    sorted_tokens = input[idxs // expert_ids.shape[1]]\n\n    outputs = []\n    start_idx = 0\n\n    for i, num_tokens in enumerate(tokens_per_expert):\n        end_idx = start_idx + num_tokens\n        if num_tokens == 
0:\n            continue\n        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n        expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])\n        outputs.append(expert_out)\n        start_idx = end_idx\n\n    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n    new_x = torch.empty_like(outs)\n    new_x[idxs] = outs\n    t_output = (\n        new_x.view(*expert_ids.shape, -1)\n        .type(weights.dtype)\n        .mul_(weights.unsqueeze(dim=-1))\n        .sum(dim=1)\n        .type(new_x.dtype)\n    )\n    return t_output\n\n\n# FP8 E4M3 constants\nFP8_E4M3_MAX = 448.0  # Maximum representable value in FP8 E4M3\n\n\ndef fp8_e4m3_to_float(fp8_val: int) -> float:\n    \"\"\"\n    Convert FP8 E4M3 value to float.\n    FP8 E4M3 format: 1 sign bit, 4 exponent bits, 3 mantissa bits\n    \"\"\"\n    sign = (fp8_val >> 7) & 1\n    exp = (fp8_val >> 3) & 0xF\n    mant = fp8_val & 0x7\n\n    if exp == 0:\n        # Subnormal or zero\n        if mant == 0:\n            return -0.0 if sign else 0.0\n        # Subnormal: value = (-1)^sign * 2^(-6) * (0.mant)\n        return ((-1) ** sign) * (2**-6) * (mant / 8.0)\n    elif exp == 15 and mant == 7:\n        # NaN (FP8 E4M3 has no Inf; only exp=15 with mant=7 encodes NaN)\n        return float(\"nan\")\n    else:\n        # Normal: value = (-1)^sign * 2^(exp-7) * (1.mant)\n        return ((-1) ** sign) * (2 ** (exp - 7)) * (1.0 + mant / 8.0)\n\n\ndef float_to_fp8_e4m3(val: float) -> int:\n    \"\"\"\n    Convert float to FP8 E4M3 value.\n    \"\"\"\n    if val != val:  # NaN\n        return 0x7F  # NaN representation\n\n    sign = 1 if val < 0 else 0\n    val = abs(val)\n\n    if val == 0:\n        return sign << 7\n\n    # Clamp to max representable value\n    val = min(val, FP8_E4M3_MAX)\n\n    # Find exponent\n    import math\n\n    if val < 2**-6:  # Subnormal threshold (smallest normal value is 2^-6)\n        # Subnormal\n        mant = int(round(val / (2**-9)))\n        mant = min(mant, 7)\n      
  return (sign << 7) | mant\n\n    exp = int(math.floor(math.log2(val))) + 7\n    exp = max(1, min(exp, 14))  # Clamp exponent to valid range\n\n    # Calculate mantissa\n    mant = int(round((val / (2 ** (exp - 7)) - 1.0) * 8))\n    mant = max(0, min(mant, 7))\n\n    # Handle overflow to next exponent\n    if mant > 7:\n        mant = 0\n        exp += 1\n        if exp > 14:\n            exp = 14\n            mant = 7\n\n    return (sign << 7) | (exp << 3) | mant\n\n\ndef quantize_to_fp8_perchannel(weights: torch.Tensor):\n    \"\"\"\n    Quantize BF16/FP32 weights to FP8 with per-channel scaling.\n\n    Args:\n        weights: [expert_num, n, k] tensor in BF16/FP32\n\n    Returns:\n        fp8_weights: [expert_num, n, k] uint8 tensor\n        scales: [expert_num, n] FP32 tensor (one scale per output row)\n    \"\"\"\n    weights_f32 = weights.to(torch.float32)\n    e, n, k = weights_f32.shape\n\n    # Calculate max abs per row (per output channel)\n    max_abs = weights_f32.abs().amax(dim=-1, keepdim=True)  # [e, n, 1]\n    max_abs = torch.clamp(max_abs, min=1e-12)\n\n    # Scale to FP8 range: scale = max_abs / FP8_MAX\n    scales = (max_abs / FP8_E4M3_MAX).squeeze(-1)  # [e, n]\n\n    # Quantize: q = round(val / scale)\n    scaled = weights_f32 / (scales.unsqueeze(-1) + 1e-12)\n\n    # Clamp to FP8 representable range\n    scaled = scaled.clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)\n\n    # Vectorized FP8 quantization\n    fp8_q = torch.zeros_like(scaled, dtype=torch.uint8)\n\n    sign_mask = (scaled < 0).to(torch.uint8) << 7\n    abs_scaled = scaled.abs()\n\n    # Handle different ranges\n    # Subnormal: 0 < |x| < 2^-6\n    subnormal_mask = (abs_scaled > 0) & (abs_scaled < 2**-6)\n    subnormal_mant = (abs_scaled / (2**-9)).round().clamp(0, 7).to(torch.uint8)\n\n    # Normal values\n    normal_mask = abs_scaled >= 2**-6\n    log2_val = torch.log2(abs_scaled.clamp(min=2**-9))\n    exp = (log2_val.floor() + 7).clamp(1, 14).to(torch.int32)\n    mant = ((abs_scaled / (2.0 
** (exp.float() - 7)) - 1.0) * 8).round().clamp(0, 7).to(torch.uint8)\n\n    # Combine\n    fp8_q = torch.where(subnormal_mask, sign_mask | subnormal_mant, fp8_q)\n    fp8_q = torch.where(normal_mask, sign_mask | (exp.to(torch.uint8) << 3) | mant, fp8_q)\n\n    return fp8_q.contiguous(), scales.to(torch.float32).contiguous()\n\n\ndef dequantize_fp8_perchannel(fp8_weights: torch.Tensor, scales: torch.Tensor):\n    \"\"\"\n    Dequantize FP8 weights back to BF16 for reference computation.\n\n    Args:\n        fp8_weights: [expert_num, n, k] uint8 tensor\n        scales: [expert_num, n] FP32 tensor\n\n    Returns:\n        dequantized: [expert_num, n, k] BF16 tensor\n    \"\"\"\n    # Build lookup table for FP8 E4M3 -> float\n    fp8_lut = torch.tensor([fp8_e4m3_to_float(i) for i in range(256)], dtype=torch.float32)\n\n    # Use lookup table\n    fp8_float = fp8_lut[fp8_weights.to(torch.int64)]\n\n    # Apply per-channel scales\n    scales_expanded = scales.unsqueeze(-1)  # [e, n, 1]\n    dequantized = fp8_float * scales_expanded\n\n    return dequantized.to(torch.bfloat16).contiguous()\n\n\ndef build_random_fp8_perchannel_weights():\n    \"\"\"\n    Generate random BF16 weights and quantize to FP8 with per-channel scales.\n\n    Returns:\n        dict with fp8 weights, scales, and original bf16 for reference\n    \"\"\"\n    torch.manual_seed(42)\n\n    # Generate random BF16 weights with small values\n    gate_proj = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / 100.0).to(\n        torch.bfloat16\n    )\n    up_proj = (torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32) / 100.0).to(\n        torch.bfloat16\n    )\n    down_proj = (torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32) / 100.0).to(\n        torch.bfloat16\n    )\n\n    # Quantize to FP8 with per-channel scales\n    gate_fp8, gate_scales = quantize_to_fp8_perchannel(gate_proj)\n    up_fp8, up_scales = 
quantize_to_fp8_perchannel(up_proj)\n    down_fp8, down_scales = quantize_to_fp8_perchannel(down_proj)\n\n    # Dequantize for reference computation\n    gate_deq = dequantize_fp8_perchannel(gate_fp8, gate_scales)\n    up_deq = dequantize_fp8_perchannel(up_fp8, up_scales)\n    down_deq = dequantize_fp8_perchannel(down_fp8, down_scales)\n\n    print(f\"FP8 Per-Channel weights shape: gate={gate_fp8.shape}, up={up_fp8.shape}, down={down_fp8.shape}\")\n    print(f\"Per-Channel scales shape: gate={gate_scales.shape}, up={up_scales.shape}, down={down_scales.shape}\")\n\n    # Debug: Print FP8 weight and scale info for expert 0\n    print(\"\\n=== DEBUG: FP8 Per-Channel Weight and Scale Info (Expert 0) ===\")\n    print(f\"gate_fp8[0] first 8x8 block:\")\n    for i in range(8):\n        print(f\"  row {i}: {gate_fp8[0, i, :8].numpy().tobytes().hex(' ')}\")\n    print(f\"gate_fp8[0] stats: min={gate_fp8[0].min()}, max={gate_fp8[0].max()}\")\n    print(f\"gate_scales[0] first 8 channels: {gate_scales[0, :8]}\")\n    print(f\"gate_scales[0] stats: min={gate_scales[0].min():.6f}, max={gate_scales[0].max():.6f}\")\n\n    return {\n        \"gate_fp8\": gate_fp8.contiguous(),\n        \"up_fp8\": up_fp8.contiguous(),\n        \"down_fp8\": down_fp8.contiguous(),\n        \"gate_scales\": gate_scales.contiguous(),\n        \"up_scales\": up_scales.contiguous(),\n        \"down_scales\": down_scales.contiguous(),\n        \"gate_deq\": gate_deq.contiguous(),\n        \"up_deq\": up_deq.contiguous(),\n        \"down_deq\": down_deq.contiguous(),\n    }\n\n\ndef build_moes_from_fp8_perchannel_data(fp8_data: dict):\n    \"\"\"\n    Build FP8 Per-Channel MoE modules from quantized data.\n    \"\"\"\n    moes = []\n    with torch.inference_mode(mode=True):\n        for _ in range(layer_num):\n            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)\n            config.max_len = max_len\n            config.quant_config.bits = 
8\n            config.quant_config.group_size = 0  # Not used for per-channel\n            config.quant_config.zero_point = False\n            config.quant_config.per_channel = True  # Enable per-channel mode\n\n            # Set FP8 weight pointers\n            config.gate_proj = fp8_data[\"gate_fp8\"].data_ptr()\n            config.up_proj = fp8_data[\"up_fp8\"].data_ptr()\n            config.down_proj = fp8_data[\"down_fp8\"].data_ptr()\n\n            # Set per-channel scale pointers\n            config.gate_scale = fp8_data[\"gate_scales\"].data_ptr()\n            config.up_scale = fp8_data[\"up_scales\"].data_ptr()\n            config.down_scale = fp8_data[\"down_scales\"].data_ptr()\n            config.pool = CPUInfer.backend_\n\n            moe = kt_kernel_ext.moe.AMXFP8PerChannel_MOE(config)\n            CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n            CPUInfer.sync()\n            moes.append(moe)\n    return moes\n\n\ndef run_fp8_perchannel_moe_test():\n    \"\"\"\n    Run FP8 Per-Channel MoE validation test.\n    \"\"\"\n    print(\"\\n\" + \"=\" * 70)\n    print(\"FP8 Per-Channel MoE Kernel Validation Test\")\n    print(\"=\" * 70)\n\n    # Build FP8 per-channel weights\n    print(\"\\nGenerating and quantizing weights with per-channel scales...\")\n    fp8_data = build_random_fp8_perchannel_weights()\n\n    # Build MoE modules\n    print(\"\\nBuilding FP8 Per-Channel MoE modules...\")\n    moes = build_moes_from_fp8_perchannel_data(fp8_data)\n\n    # Get dequantized weights for reference\n    gate_deq = fp8_data[\"gate_deq\"]\n    up_deq = fp8_data[\"up_deq\"]\n    down_deq = fp8_data[\"down_deq\"]\n\n    diffs = []\n    with torch.inference_mode(mode=True):\n        for i in range(validation_iter):\n            torch.manual_seed(100 + i)\n            bsz_tensor = torch.tensor([qlen], device=\"cpu\")\n            expert_ids = torch.stack(\n                [torch.randperm(expert_num)[:num_experts_per_tok] for _ in 
range(qlen)]\n            ).contiguous()\n            weights = torch.randn((qlen, num_experts_per_tok), dtype=torch.float32).contiguous() / 100\n            input_tensor = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous() * 1.5\n            output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()\n\n            moe = moes[i % layer_num]\n            CPUInfer.submit(\n                moe.forward_task(\n                    bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids.data_ptr(),\n                    weights.data_ptr(),\n                    input_tensor.data_ptr(),\n                    output.data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n\n            assert not torch.isnan(output).any(), \"NaN values detected in CPU expert output.\"\n            assert not torch.isinf(output).any(), \"Inf values detected in CPU expert output.\"\n\n            # Reference computation using dequantized weights\n            t_output = moe_torch(input_tensor, expert_ids, weights, gate_deq, up_deq, down_deq)\n\n            t_output_flat = t_output.flatten()\n            output_flat = output.flatten()\n\n            diff = torch.mean(torch.abs(output_flat - t_output_flat)) / (torch.mean(torch.abs(t_output_flat)) + 1e-12)\n            diffs.append(diff.item())\n            print(f\"Iteration {i}: relative L1 diff = {diff:.6f}\")\n\n            if i < 3:  # Print detailed output for first few iterations\n                print(f\"  kernel output: {output_flat[:debug_print_count]}\")\n                print(f\"  torch output:  {t_output_flat[:debug_print_count]}\")\n\n    mean_diff = float(sum(diffs) / len(diffs))\n    max_diff = float(max(diffs))\n    min_diff = float(min(diffs))\n\n    print(\"\\n\" + \"=\" * 70)\n    print(\"FP8 Per-Channel MoE Test Results\")\n    print(\"=\" * 70)\n    print(f\"Mean relative L1 diff: {mean_diff*100:.4f}%\")\n  
  print(f\"Max relative L1 diff:  {max_diff*100:.4f}%\")\n    print(f\"Min relative L1 diff:  {min_diff*100:.4f}%\")\n\n    # Pass/Fail criteria\n    threshold = 15.0  # 15% relative error threshold for FP8\n    if mean_diff * 100 < threshold:\n        print(f\"\\nPASS: Mean error {mean_diff*100:.4f}% < {threshold}% threshold\")\n    else:\n        print(f\"\\nFAIL: Mean error {mean_diff*100:.4f}% >= {threshold}% threshold\")\n\n    return {\"mean\": mean_diff, \"max\": max_diff, \"min\": min_diff}\n\n\nif __name__ == \"__main__\":\n    run_fp8_perchannel_moe_test()\n"
  },
  {
    "path": "kt-kernel/examples/test_gate.py",
    "content": "import math\nimport os, sys\nimport time\nfrom typing import Optional\n\nos.environ[\"BLAS_NUM_THREADS\"] = \"1\"\nsys.path.insert(0, os.path.dirname(__file__) + \"/../build\")\nfrom kt_kernel import kt_kernel_ext\nfrom kt_kernel_ext.kvcache import ggml_type\n\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\n\n# from modeling_deepseek_v3 import MoEGate\nfrom configuration_deepseek_v3 import DeepseekV3Config\n\nseed = 42  # 你可以选择任何整数作为种子\ntorch.manual_seed(seed)\ntorch.cuda.manual_seed_all(seed)\n\nseqlen = 64\n\nconfig = DeepseekV3Config()\n\nhidden_size = config.hidden_size\nnum_experts_per_token = config.num_experts_per_tok\nn_routed_experts = config.n_routed_experts\nn_group = config.n_group\ntopk_group = config.topk_group\nrouted_scaling_factor = config.routed_scaling_factor\n\nweights = torch.randn((n_routed_experts, hidden_size), dtype=torch.float32).to(\"cpu\").contiguous()\nbias = torch.randn((n_routed_experts,), dtype=torch.float32).to(\"cpu\").contiguous()\n\n\n# weights = torch.randn((n_routed_experts, hidden_size), dtype=torch.float16).to('cpu').contiguous  ()\ndef load_fp32_tensor(file_path, shape):\n    return torch.zeros(shape, dtype=torch.float32).to(\"cpu\").contiguous()\n    with open(file_path, \"rb\") as f:\n        raw_data = f.read()\n    tensor = torch.frombuffer(raw_data, dtype=torch.float32)\n    tensor = tensor.view(shape)  # 根据你的 shape reshape\n    return tensor\n\n\nclass MoEGate(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.top_k = config.num_experts_per_tok\n        self.n_routed_experts = config.n_routed_experts\n        self.routed_scaling_factor = config.routed_scaling_factor\n        self.scoring_func = config.scoring_func\n        self.topk_method = config.topk_method\n        self.n_group = config.n_group\n        self.topk_group = config.topk_group\n\n        # topk selection algorithm\n        self.norm_topk_prob = 
config.norm_topk_prob\n        self.gating_dim = config.hidden_size\n        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))\n        if self.topk_method == \"noaux_tc\":\n            self.e_score_correction_bias = nn.Parameter(torch.empty((self.n_routed_experts)))\n        self.reset_parameters()\n\n    def reset_parameters(self) -> None:\n        import torch.nn.init as init\n\n        init.kaiming_uniform_(self.weight, a=math.sqrt(5))\n\n    def forward(self, hidden_states):\n        bsz, seq_len, h = hidden_states.shape\n        ### compute gating score\n        hidden_states = hidden_states.view(-1, h)\n\n        h_to_check = load_fp32_tensor(\n            \"/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug/gate_input\", (seq_len, h)\n        )\n        diff = (h_to_check - hidden_states).abs().max()\n        # print(\"hidden_states diff:\", diff)\n        # assert diff<0.02\n\n        bias_to_check = load_fp32_tensor(\n            \"/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug/bias\", (n_routed_experts)\n        )\n        diff = (bias - bias_to_check).abs().max()\n        # print('bias diff:',diff)\n        # assert diff < 0.02\n\n        logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None)\n\n        logits_to_check = load_fp32_tensor(\n            \"/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug/gate_logits\",\n            (seq_len, n_routed_experts),\n        )\n        diff = (logits_to_check - logits).abs().max()\n        # print(\"logits diff:\", diff)\n        # assert diff < 0.02\n\n        if self.scoring_func == \"sigmoid\":\n            scores = logits.sigmoid()\n        else:\n            raise NotImplementedError(f\"unsupported scoring function for MoE gating: {self.scoring_func}\")\n\n        ### select top-k experts\n        if self.topk_method == \"noaux_tc\":\n            # 
assert not self.training\n            scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)\n\n            scores_to_check = load_fp32_tensor(\n                \"/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug/scores_to_choice\",\n                (seq_len, n_routed_experts),\n            )\n            diff = (scores_for_choice - scores_to_check).abs().max()\n            print(f\"score for choice diff = {diff}\")\n\n            group_scores = (\n                scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)\n            )  # [n, n_group]\n\n            group_scores_to_check = load_fp32_tensor(\n                \"/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug/group_scores\",\n                (seq_len, n_group),\n            )\n            diff = (group_scores - group_scores_to_check).abs().max()\n            print(f\"group scores diff = {diff}\")\n\n            group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]  # [n, top_k_group]\n            group_mask = torch.zeros_like(group_scores)  # [n, n_group]\n            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]\n            score_mask = (\n                group_mask.unsqueeze(-1)\n                .expand(bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group)\n                .reshape(bsz * seq_len, -1)\n            )  # [n, e]\n            tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), float(\"-inf\"))  # [n, e]\n            tmp_scores_to_check = load_fp32_tensor(\n                \"/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug/gate_logits_toped\",\n                (seq_len, n_routed_experts),\n            )\n            is_close = torch.isclose(tmp_scores, tmp_scores_to_check, rtol=1e-2, atol=1e-2, equal_nan=True)\n            print(f\"tmp_score ok {is_close.all()}\")\n\n       
     _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)\n            topk_weight = scores.gather(1, topk_idx)\n        else:\n            raise NotImplementedError(f\"insupportable TopK function for MoE gating: {self.topk_method}\")\n\n        ### norm gate to sum 1\n        if self.top_k > 1 and self.norm_topk_prob:\n            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20\n            topk_weight = topk_weight / denominator\n        topk_weight = topk_weight * self.routed_scaling_factor  # must multiply the scaling factor\n\n        return topk_idx, topk_weight\n\n\ndef torch_gate(hidden_states):\n    hidden_states.unsqueeze_(0)\n    gate = MoEGate(config)\n    gate.weight.data = weights\n    gate.e_score_correction_bias.data = bias\n    y = gate(hidden_states)\n    # print(y)\n    return y\n\n\ndef cpuinfer_gate(hidden_states):\n    config = kt_kernel_ext.gate.GateConfig(\n        hidden_size,\n        num_experts_per_token,\n        n_routed_experts,\n        n_group,\n        topk_group,\n    )\n\n    CPUInfer = kt_kernel_ext.CPUInfer(64)\n    config.routed_scaling_factor = routed_scaling_factor\n\n    config.pool = CPUInfer.backend_\n    config.weight = weights.data_ptr()\n    config.weight_type = ggml_type.FP32\n    config.e_score_correction_bias = bias.data_ptr()\n    config.e_score_correction_bias_type = ggml_type.FP32\n\n    gate = kt_kernel_ext.gate.MoEGate(config)\n\n    expert_ids = torch.zeros((seqlen, num_experts_per_token), dtype=torch.int64).to(\"cpu\").contiguous()\n    expert_weights = torch.zeros((seqlen, num_experts_per_token), dtype=torch.float32).to(\"cpu\").contiguous()\n\n    gate.forward(seqlen, hidden_states.data_ptr(), expert_ids.data_ptr(), expert_weights.data_ptr())\n\n    # print(expert_ids,expert_weights)\n    return expert_ids, expert_weights\n\n\ninput = torch.randn(seqlen, hidden_size, dtype=torch.float32).to(\"cpu\").contiguous()\n# print(input)\nids, we = cpuinfer_gate(input)\nidx = 
torch.argsort(ids, dim=-1, descending=True)\nids = torch.gather(ids, dim=-1, index=idx)\nwe = torch.gather(we, dim=-1, index=idx)\n\n\nstd_ids, std_we = torch_gate(input)\nidx = torch.argsort(std_ids, dim=-1, descending=True)\nstd_we = torch.gather(std_we, dim=-1, index=idx)\nstd_ids = torch.gather(std_ids, dim=-1, index=idx)\n\n\n# print(\"ids diff:\", torch.abs(std_ids - ids).max())\n# print(\"weights diff:\", torch.abs(std_we - we).max())\nassert torch.abs(std_ids - ids).max() == 0, \"Expert IDs do not match!\"\nassert torch.abs(std_we - we).max() < 1e-2, \"Expert Weights do not match!\"\nprint(\"Expert IDs and Weights match successfully!\")\n"
  },
  {
    "path": "kt-kernel/examples/test_k2_moe_amx.py",
    "content": "import math\nimport os\nimport sys\nfrom typing import Dict, Literal\n\nsys.path.insert(0, os.path.dirname(__file__) + \"/../build\")\n\nimport torch\nfrom kt_kernel import kt_kernel_ext\n\ntorch.manual_seed(42)\n\nhidden_size = 7168\nintermediate_size = 2048\nmax_len = 25600\n\nexpert_num = 16\nnum_experts_per_tok = 8\n\nqlen = 1\nlayer_num = 1\nCPUInfer = kt_kernel_ext.CPUInfer(40)\nvalidation_iter = 10\nk_group_size = 32\ndebug_print_count = 16\n\nphysical_to_logical_map = torch.tensor(data=range(expert_num), device=\"cpu\", dtype=torch.int64).contiguous()\n\n\ndef _pattern_uniform(groups: int) -> torch.Tensor:\n    return torch.full((groups,), 0.02, dtype=torch.float32)\n\n\ndef _pattern_alternating(groups: int) -> torch.Tensor:\n    vals = torch.full((groups,), 0.015, dtype=torch.float32)\n    vals[1::2] = 0.03\n    return vals\n\n\ndef _pattern_ramp(groups: int) -> torch.Tensor:\n    return torch.linspace(0.005, 0.04, steps=groups, dtype=torch.float32)\n\n\nWEIGHT_PATTERNS = {\n    \"uniform_scale\": (\"All k-groups share the same abs max / scale\", _pattern_uniform),\n    \"alternating_scale\": (\"Alternate small / large abs max per k-group\", _pattern_alternating),\n    \"ramp_scale\": (\"Linearly increasing abs max per k-group\", _pattern_ramp),\n    \"random\": (\"Random bf16 weights (baseline)\", None),\n}\n\n\ndef act_fn(x):\n    return x / (1.0 + torch.exp(-x))\n\n\ndef mlp_torch(input, gate_proj, up_proj, down_proj):\n    gate_buf = torch.mm(input, gate_proj.t())\n    up_buf = torch.mm(input, up_proj.t())\n    print(f\"gate_buf: {gate_buf}\")\n    print(f\"up_buf: {up_buf}\")\n    intermediate = act_fn(gate_buf) * up_buf\n    ret = torch.mm(intermediate, down_proj.t())\n    print(f\"intermediate: {intermediate}\")\n    print(f\"mlp output: {ret}\")\n    return ret\n\n\ndef moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):\n    cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))\n    cnts.scatter_(1, 
expert_ids, 1)\n    tokens_per_expert = cnts.sum(dim=0)\n    idxs = expert_ids.view(-1).argsort()\n    sorted_tokens = input[idxs // expert_ids.shape[1]]\n\n    outputs = []\n    start_idx = 0\n\n    for i, num_tokens in enumerate(tokens_per_expert):\n        end_idx = start_idx + num_tokens\n        if num_tokens == 0:\n            continue\n        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n        expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])\n        outputs.append(expert_out)\n        start_idx = end_idx\n\n    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n    new_x = torch.empty_like(outs)\n    new_x[idxs] = outs\n    t_output = (\n        new_x.view(*expert_ids.shape, -1)\n        .type(weights.dtype)\n        .mul_(weights.unsqueeze(dim=-1))\n        .sum(dim=1)\n        .type(new_x.dtype)\n    )\n    return t_output\n\n\ndef pack_to_int32(value: torch.Tensor, num_bits: int, packed_dim: Literal[0, 1] = 1) -> torch.Tensor:\n    if value.dtype is not torch.int8:\n        raise ValueError(\"Tensor must be torch.int8 before packing\")\n    if not (1 <= num_bits <= 8):\n        raise ValueError(f\"num_bits must be in [1, 8], got {num_bits}\")\n\n    offset = 1 << (num_bits - 1)\n    value = (value + offset).to(torch.uint8)\n    device = value.device\n\n    pack_factor = 32 // num_bits\n\n    if packed_dim == 0:\n        value = value.transpose(0, 1)\n\n    rows, cols = value.shape\n    padded_cols = math.ceil(cols / pack_factor) * pack_factor\n    pad_len = padded_cols - cols\n\n    if pad_len > 0:\n        value = torch.nn.functional.pad(value, (0, pad_len))\n\n    num_groups = padded_cols // pack_factor\n\n    # Use int32 here\n    reshaped = value.view(rows, num_groups, pack_factor).to(torch.int32)\n    bit_shifts = torch.arange(pack_factor, device=device, dtype=torch.int32) * num_bits\n    packed = (reshaped << bit_shifts).sum(dim=2, dtype=torch.int32)\n\n    if 
packed_dim == 0:\n        packed = packed.transpose(0, 1)\n\n    return packed\n\n\ndef pack_tensor_per_row(q: torch.Tensor, num_bits: int) -> torch.Tensor:\n    e, rows, cols = q.shape\n    flat = q.view(e * rows, cols)\n    packed = pack_to_int32(flat, num_bits)\n    return packed.view(e, rows, -1).contiguous()\n\n\ndef quantize_k2_tensor(weights: torch.Tensor, group_size: int):\n    \"\"\"\n    Symmetric max-abs/7 quantization per k-group following compressed_tensors packing.\n    Args:\n        weights: [expert_num, rows (N), cols (K)]\n    Returns:\n        packed_q: int32 tensor storing 8 int4s per element with shape [expert_num, rows * (cols // 8)]\n        scales: bfloat16 tensor with shape [expert_num, rows * (cols // group_size)]\n    \"\"\"\n    weights_f32 = weights.to(torch.float32)\n    e, rows, cols = weights_f32.shape\n    if cols % group_size != 0 or cols % 2 != 0:\n        raise ValueError(f\"cols ({cols}) must be divisible by group_size ({group_size}) and 2\")\n\n    reshaped = weights_f32.view(e, rows, cols // group_size, group_size)\n    max_abs = reshaped.abs().amax(dim=-1, keepdim=True)\n    max_abs = torch.clamp(max_abs, min=1e-8)\n    scales = (max_abs / 7.0).squeeze(-1)\n    q = torch.round(reshaped / scales.unsqueeze(-1)).clamp(-8, 7).to(torch.int8)\n    q = q.view(e, rows, cols)\n    packed = pack_tensor_per_row(q, num_bits=4).view(e, rows, cols // 8).contiguous()\n    scales = scales.to(torch.bfloat16).contiguous().view(e, rows, cols // group_size).contiguous()\n\n    print(f\"Quantized weights: {packed.shape}, scales: {scales.shape}\")\n    print(f\"Quantized tensors: \\n{packed},\\n {scales}\")\n    return packed, scales\n\n\ndef build_structured_tensor(shape: torch.Size, pattern: str) -> torch.Tensor:\n    if pattern == \"random\":\n        torch.manual_seed(42)\n        return (torch.randn(shape, dtype=torch.bfloat16, device=\"cpu\") / 100.0).contiguous()\n\n    e, rows, cols = shape\n    groups = cols // k_group_size\n    
group_builder = WEIGHT_PATTERNS[pattern][1]\n    group_vals = group_builder(groups).to(torch.float32)\n    block = group_vals.view(1, 1, groups, 1).expand(e, rows, groups, k_group_size).clone()\n    row_signs = torch.where(\n        (torch.arange(rows) % 2 == 0),\n        torch.ones(rows, dtype=torch.float32),\n        -torch.ones(rows, dtype=torch.float32),\n    ).view(1, rows, 1, 1)\n    col_offsets = torch.linspace(-0.0005, 0.0005, steps=k_group_size, dtype=torch.float32).view(1, 1, 1, k_group_size)\n    block = block * row_signs + col_offsets\n    return block.reshape(shape).to(torch.bfloat16).contiguous()\n\n\ndef prepare_k2_quantized_weights(pattern: str) -> Dict[str, torch.Tensor]:\n    if pattern not in WEIGHT_PATTERNS:\n        raise ValueError(f\"Unknown weight pattern: {pattern}\")\n\n    gate_proj = build_structured_tensor((expert_num, intermediate_size, hidden_size), pattern)\n    up_proj = build_structured_tensor((expert_num, intermediate_size, hidden_size), pattern)\n    down_proj = build_structured_tensor((expert_num, hidden_size, intermediate_size), pattern)\n\n    gate_q, gate_scales = quantize_k2_tensor(gate_proj, k_group_size)\n    up_q, up_scales = quantize_k2_tensor(up_proj, k_group_size)\n    down_q, down_scales = quantize_k2_tensor(down_proj, k_group_size)\n\n    return {\n        \"gate_qweight\": gate_q.contiguous(),\n        \"up_qweight\": up_q.contiguous(),\n        \"down_qweight\": down_q.contiguous(),\n        \"gate_scales\": gate_scales.contiguous(),\n        \"up_scales\": up_scales.contiguous(),\n        \"down_scales\": down_scales.contiguous(),\n        \"original_fp16\": {\n            \"gate_proj\": gate_proj.to(torch.float16).contiguous(),\n            \"up_proj\": up_proj.to(torch.float16).contiguous(),\n            \"down_proj\": down_proj.to(torch.float16).contiguous(),\n        },\n    }\n\n\ndef build_moes_from_quantized_data(quant_data: Dict[str, torch.Tensor]):\n    moes = []\n    with 
torch.inference_mode(mode=True):\n        for _ in range(layer_num):\n            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)\n            config.max_len = max_len\n            config.quant_config.bits = 4\n            config.quant_config.group_size = k_group_size\n            config.quant_config.zero_point = False\n\n            config.gate_proj = quant_data[\"gate_qweight\"].data_ptr()\n            config.up_proj = quant_data[\"up_qweight\"].data_ptr()\n            config.down_proj = quant_data[\"down_qweight\"].data_ptr()\n\n            config.gate_scale = quant_data[\"gate_scales\"].data_ptr()\n            config.up_scale = quant_data[\"up_scales\"].data_ptr()\n            config.down_scale = quant_data[\"down_scales\"].data_ptr()\n            config.pool = CPUInfer.backend_\n\n            moe = kt_kernel_ext.moe.AMXInt4_KGroup_MOE(config)\n            CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n            CPUInfer.sync()\n            # CPUInfer.submit(moe.warm_up_task())\n            # CPUInfer.sync()\n            moes.append(moe)\n    return moes\n\n\ndef run_case(pattern: str) -> Dict[str, float]:\n    print(\"\\n\" + \"=\" * 70)\n    desc = WEIGHT_PATTERNS[pattern][0]\n    print(f\"Running case: {pattern} -> {desc}\")\n    print(\"=\" * 70)\n\n    quant_data = prepare_k2_quantized_weights(pattern)\n    moes = build_moes_from_quantized_data(quant_data)\n\n    original_weights = quant_data[\"original_fp16\"]\n    gate_fp16 = original_weights[\"gate_proj\"]\n    up_fp16 = original_weights[\"up_proj\"]\n    down_fp16 = original_weights[\"down_proj\"]\n\n    diffs = []\n    with torch.inference_mode(mode=True):\n        for i in range(validation_iter):\n            torch.manual_seed(100 + i)\n            bsz_tensor = torch.tensor([qlen], device=\"cpu\")\n            expert_ids = torch.stack(\n                [torch.randperm(expert_num)[:num_experts_per_tok] for _ in 
range(qlen)]\n            ).contiguous()\n            weights = torch.randn((qlen, num_experts_per_tok), dtype=torch.float32).contiguous()\n            input_tensor = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous() / 100\n            output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()\n\n            moe = moes[i % layer_num]\n            CPUInfer.submit(\n                moe.forward_task(\n                    bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids.data_ptr(),\n                    weights.data_ptr(),\n                    input_tensor.data_ptr(),\n                    output.data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n\n            input_tensor_fp16 = input_tensor.to(torch.float16)\n            t_output = moe_torch(input_tensor_fp16, expert_ids, weights, gate_fp16, up_fp16, down_fp16).to(\n                torch.bfloat16\n            )\n\n            t_output = t_output.flatten()\n            output = output.flatten()\n\n            diff = torch.mean(torch.abs(output - t_output)) / (torch.mean(torch.abs(t_output)) + 1e-12)\n            diffs.append(diff.item())\n            print(f\"[{pattern}] Iteration {i}: relative L1 diff = {diff:.4f}\")\n            print(f\"           output   {output}\")\n            print(f\"           t_output {t_output}\")\n\n    mean_diff = float(sum(diffs) / len(diffs))\n    max_diff = float(max(diffs))\n    min_diff = float(min(diffs))\n    return {\"case\": pattern, \"description\": desc, \"mean\": mean_diff, \"max\": max_diff, \"min\": min_diff}\n\n\ndef run_k2_moe_test():\n    summary_rows = []\n    for case_name in WEIGHT_PATTERNS.keys():\n        results = run_case(case_name)\n        summary_rows.append(results)\n        # break\n\n    print(\"\\n=== Case vs. 
Relative Error Summary ===\")\n    print(f\"{'Case':<20} {'Mean':>10} {'Max':>10} {'Min':>10}\")\n    for row in summary_rows:\n        print(f\"{row['case']:<20} {row['mean']*100:9.2f}% {row['max']*100:9.2f}% {row['min']*100:9.2f}%\")\n\n\nif __name__ == \"__main__\":\n    run_k2_moe_test()\n"
  },
  {
    "path": "kt-kernel/examples/test_k2_write_buffer.py",
    "content": "import os\nimport sys\nimport time\n\nimport torch\nimport numpy as np\n\n\nfrom kt_kernel import kt_kernel_ext\nfrom kt_kernel_ext import CPUInfer\n\n\ndef make_cpu_infer(thread_num=80):\n    return CPUInfer(thread_num)\n\n\ndef build_config(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size, group_size):\n    cfg = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)\n    cfg.max_len = 1\n    cfg.quant_config.bits = 4\n    cfg.quant_config.group_size = group_size\n    cfg.quant_config.zero_point = False\n    cfg.pool = cpuinfer.backend_\n    return cfg\n\n\ndef allocate_weights(expert_num, hidden_size, intermediate_size, group_size):\n    # packed int4 weights: 2 values per byte\n    per_mat_weight_bytes = (hidden_size * intermediate_size) // 2\n    per_mat_scale_elems = (hidden_size * intermediate_size) // group_size\n\n    gate_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)\n    up_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)\n    down_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)\n\n    gate_scale = torch.randn(expert_num * per_mat_scale_elems, dtype=torch.bfloat16)\n    up_scale = torch.randn(expert_num * per_mat_scale_elems, dtype=torch.bfloat16)\n    down_scale = torch.randn(expert_num * per_mat_scale_elems, dtype=torch.bfloat16)\n\n    return (\n        gate_q,\n        up_q,\n        down_q,\n        gate_scale,\n        up_scale,\n        down_scale,\n        per_mat_weight_bytes,\n        per_mat_scale_elems,\n    )\n\n\ndef test_with_tp(gpu_tp_count):\n    \"\"\"Test write_weight_scale_to_buffer with a specific gpu_tp_count\"\"\"\n    torch.manual_seed(123)\n\n    expert_num = 8  # Reduced for faster testing\n    gpu_experts = expert_num  # Number of experts on GPU\n\n    num_experts_per_tok = 8\n    hidden_size = 7168\n    intermediate_size = 2048\n    group_size = 
32\n\n    cpuinfer = make_cpu_infer()\n    cfg = build_config(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size, group_size)\n\n    (\n        gate_q,\n        up_q,\n        down_q,\n        gate_scale,\n        up_scale,\n        down_scale,\n        per_mat_weight_bytes,\n        per_mat_scale_elems,\n    ) = allocate_weights(expert_num, hidden_size, intermediate_size, group_size)\n\n    cfg.gate_proj = gate_q.data_ptr()\n    cfg.up_proj = up_q.data_ptr()\n    cfg.down_proj = down_q.data_ptr()\n    cfg.gate_scale = gate_scale.data_ptr()\n    cfg.up_scale = up_scale.data_ptr()\n    cfg.down_scale = down_scale.data_ptr()\n\n    moe = kt_kernel_ext.moe.AMXInt4_KGroup_MOE(cfg)\n\n    physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device=\"cpu\").contiguous()\n    cpuinfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n    cpuinfer.sync()\n\n    # TP configuration\n    # Calculate sizes per TP part (per expert)\n    weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count\n    scale_elems_per_expert_per_tp = per_mat_scale_elems // gpu_tp_count\n\n    # Total sizes for all gpu_experts\n    total_weight_bytes_per_tp = gpu_experts * weight_bytes_per_expert_per_tp\n    total_scale_elems_per_tp = gpu_experts * scale_elems_per_expert_per_tp\n\n    # Create buffer lists for w13 (gate+up) and w2 (down)\n    # These hold all experts' data for each GPU TP\n    w13_weight_bufs = []\n    w13_scale_bufs = []\n    w2_weight_bufs = []\n    w2_scale_bufs = []\n\n    for tp_idx in range(gpu_tp_count):\n        # w13 combines gate and up, so needs 2x the size per expert\n        w13_weight_bufs.append(torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8))\n        w13_scale_bufs.append(torch.empty(2 * total_scale_elems_per_tp, dtype=torch.bfloat16))\n        w2_weight_bufs.append(torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8))\n        
w2_scale_bufs.append(torch.empty(total_scale_elems_per_tp, dtype=torch.bfloat16))\n\n    print(f\"Total experts: {expert_num}, GPU experts: {gpu_experts}\")\n    print(f\"GPU TP count: {gpu_tp_count}\")\n    print(f\"Original per matrix weight bytes: {per_mat_weight_bytes}\")\n    print(f\"Original per matrix scale elements: {per_mat_scale_elems}\")\n    print(f\"Weight bytes per expert per TP: {weight_bytes_per_expert_per_tp}\")\n    print(f\"Scale elements per expert per TP: {scale_elems_per_expert_per_tp}\")\n    print(f\"Total weight bytes per TP (w13): {2 * total_weight_bytes_per_tp}\")\n    print(f\"Total weight bytes per TP (w2): {total_weight_bytes_per_tp}\")\n\n    # Helper function to get pointers with expert offset\n    # K2 write_weights_to_buffer writes one expert at a time, so we need to pass\n    # pointers that already point to the correct location for each expert\n    def get_expert_ptrs(expert_id):\n        w13_weight_ptrs = []\n        w13_scale_ptrs = []\n        w2_weight_ptrs = []\n        w2_scale_ptrs = []\n\n        for tp_idx in range(gpu_tp_count):\n            # Calculate byte offsets for this expert\n            # w13: gate_weight + up_weight interleaved by expert\n            # Layout: [expert0_gate, expert0_up, expert1_gate, expert1_up, ...]\n            w13_weight_expert_offset = expert_id * 2 * weight_bytes_per_expert_per_tp\n            w13_scale_expert_offset = expert_id * 2 * scale_elems_per_expert_per_tp\n            w2_weight_expert_offset = expert_id * weight_bytes_per_expert_per_tp\n            w2_scale_expert_offset = expert_id * scale_elems_per_expert_per_tp\n\n            w13_weight_ptrs.append(w13_weight_bufs[tp_idx].data_ptr() + w13_weight_expert_offset)\n            w13_scale_ptrs.append(w13_scale_bufs[tp_idx].data_ptr() + w13_scale_expert_offset * 2)  # bf16 = 2 bytes\n            w2_weight_ptrs.append(w2_weight_bufs[tp_idx].data_ptr() + w2_weight_expert_offset)\n            
w2_scale_ptrs.append(w2_scale_bufs[tp_idx].data_ptr() + w2_scale_expert_offset * 2)  # bf16 = 2 bytes\n\n        return w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs\n\n    # Warm up\n    for i in range(2):\n        for expert_id in range(gpu_experts):\n            w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)\n            cpuinfer.submit(\n                moe.write_weight_scale_to_buffer_task(\n                    gpu_tp_count=gpu_tp_count,\n                    expert_id=expert_id,\n                    w13_weight_ptrs=w13_weight_ptrs,\n                    w13_scale_ptrs=w13_scale_ptrs,\n                    w2_weight_ptrs=w2_weight_ptrs,\n                    w2_scale_ptrs=w2_scale_ptrs,\n                )\n            )\n            cpuinfer.sync()\n\n    # Timing\n    begin_time = time.perf_counter_ns()\n    for expert_id in range(gpu_experts):\n        w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)\n        cpuinfer.submit(\n            moe.write_weight_scale_to_buffer_task(\n                gpu_tp_count=gpu_tp_count,\n                expert_id=expert_id,\n                w13_weight_ptrs=w13_weight_ptrs,\n                w13_scale_ptrs=w13_scale_ptrs,\n                w2_weight_ptrs=w2_weight_ptrs,\n                w2_scale_ptrs=w2_scale_ptrs,\n            )\n        )\n        cpuinfer.sync()\n    end_time = time.perf_counter_ns()\n    elapsed_ms = (end_time - begin_time) / 1000000\n    total_weights = hidden_size * intermediate_size * gpu_experts * 3\n    total_bytes = total_weights // group_size * 2 + total_weights // 2  # scale (bf16) + weight (int4)\n    print(f\"write_weight_scale_to_buffer time: {elapsed_ms:.2f} ms\")\n    print(f\"Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s\")\n\n    def split_expert_tensor(tensor, chunk):\n        \"\"\"Split tensor by experts\"\"\"\n        return [tensor[i * chunk : (i + 1) * chunk] for i in 
range(expert_num)]\n\n    # Split by experts first\n    gate_q_experts = split_expert_tensor(gate_q, per_mat_weight_bytes)\n    up_q_experts = split_expert_tensor(up_q, per_mat_weight_bytes)\n    down_q_experts = split_expert_tensor(down_q, per_mat_weight_bytes)\n\n    gate_scale_experts = split_expert_tensor(gate_scale, per_mat_scale_elems)\n    up_scale_experts = split_expert_tensor(up_scale, per_mat_scale_elems)\n    down_scale_experts = split_expert_tensor(down_scale, per_mat_scale_elems)\n\n    # Verify buffers for each TP part\n    for tp_idx in range(gpu_tp_count):\n        expected_w13_weights = []\n        expected_w13_scales = []\n        expected_w2_weights = []\n        expected_w2_scales = []\n\n        weight13_per_tp = per_mat_weight_bytes // gpu_tp_count\n        scale13_per_tp = per_mat_scale_elems // gpu_tp_count\n\n        # Process each GPU expert\n        for expert_id in range(gpu_experts):\n            # For w13 (gate and up), the slicing is straightforward\n            start_weight = tp_idx * weight13_per_tp\n            end_weight = (tp_idx + 1) * weight13_per_tp\n            start_scale = tp_idx * scale13_per_tp\n            end_scale = (tp_idx + 1) * scale13_per_tp\n\n            # Gate\n            gate_weight_tp = gate_q_experts[expert_id][start_weight:end_weight]\n            gate_scale_tp = gate_scale_experts[expert_id][start_scale:end_scale]\n\n            # Up\n            up_weight_tp = up_q_experts[expert_id][start_weight:end_weight]\n            up_scale_tp = up_scale_experts[expert_id][start_scale:end_scale]\n\n            # Down matrix needs special handling because it's sliced column-wise\n            # We need to reconstruct it from column slices\n            down_weight_tp_parts = []\n            down_scale_tp_parts = []\n\n            # Iterate through each column to extract the corresponding parts\n            for col_idx in range(hidden_size):\n                col_weight_start = col_idx * (intermediate_size // 2)\n        
        col_scale_start = col_idx * (intermediate_size // group_size)\n\n                # Direct mapping: each CPU TP corresponds to a GPU TP\n                tp_slice_weight_size = (intermediate_size // gpu_tp_count) // 2\n                tp_slice_scale_size = (intermediate_size // gpu_tp_count) // group_size\n\n                tp_weight_offset = col_weight_start + tp_idx * tp_slice_weight_size\n                tp_scale_offset = col_scale_start + tp_idx * tp_slice_scale_size\n\n                down_weight_tp_parts.append(\n                    down_q_experts[expert_id][tp_weight_offset : tp_weight_offset + tp_slice_weight_size]\n                )\n                down_scale_tp_parts.append(\n                    down_scale_experts[expert_id][tp_scale_offset : tp_scale_offset + tp_slice_scale_size]\n                )\n\n            # Concatenate all column slices for this TP\n            down_weight_tp = torch.cat(down_weight_tp_parts)\n            down_scale_tp = torch.cat(down_scale_tp_parts)\n\n            # Append to expected lists - interleaved by expert: [gate0, up0, gate1, up1, ...]\n            expected_w13_weights.append(gate_weight_tp)\n            expected_w13_weights.append(up_weight_tp)\n            expected_w13_scales.append(gate_scale_tp)\n            expected_w13_scales.append(up_scale_tp)\n            expected_w2_weights.append(down_weight_tp)\n            expected_w2_scales.append(down_scale_tp)\n\n        # Concatenate all experts for this TP part\n        expected_w13_weight = torch.cat(expected_w13_weights)\n        expected_w13_scale = torch.cat(expected_w13_scales)\n        expected_w2_weight = torch.cat(expected_w2_weights)\n        expected_w2_scale = torch.cat(expected_w2_scales)\n\n        print(f\"=== Checking TP part {tp_idx} ===\")\n        print(f\"  w13 weight shape: actual={w13_weight_bufs[tp_idx].shape}, expected={expected_w13_weight.shape}\")\n        print(f\"  w13 scale shape: actual={w13_scale_bufs[tp_idx].shape}, 
expected={expected_w13_scale.shape}\")\n        print(f\"  w2 weight shape: actual={w2_weight_bufs[tp_idx].shape}, expected={expected_w2_weight.shape}\")\n        print(f\"  w2 scale shape: actual={w2_scale_bufs[tp_idx].shape}, expected={expected_w2_scale.shape}\")\n\n        # Assert all checks pass\n        if not torch.equal(w13_weight_bufs[tp_idx], expected_w13_weight):\n            diff_mask = w13_weight_bufs[tp_idx] != expected_w13_weight\n            first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1\n            print(f\"  w13 weight mismatch at index {first_diff_idx}\")\n            print(f\"    actual: {w13_weight_bufs[tp_idx][first_diff_idx:first_diff_idx+10]}\")\n            print(f\"    expected: {expected_w13_weight[first_diff_idx:first_diff_idx+10]}\")\n            raise AssertionError(f\"w13 weight bytes mismatch for TP {tp_idx}\")\n\n        if not torch.allclose(w13_scale_bufs[tp_idx], expected_w13_scale):\n            diff = torch.abs(w13_scale_bufs[tp_idx].float() - expected_w13_scale.float())\n            max_diff_idx = diff.argmax().item()\n            print(f\"  w13 scale mismatch, max diff at index {max_diff_idx}\")\n            print(f\"    actual: {w13_scale_bufs[tp_idx][max_diff_idx]}\")\n            print(f\"    expected: {expected_w13_scale[max_diff_idx]}\")\n            raise AssertionError(f\"w13 scale values mismatch for TP {tp_idx}\")\n\n        if not torch.equal(w2_weight_bufs[tp_idx], expected_w2_weight):\n            diff_mask = w2_weight_bufs[tp_idx] != expected_w2_weight\n            first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1\n            print(f\"  w2 weight mismatch at index {first_diff_idx}\")\n            print(f\"    actual: {w2_weight_bufs[tp_idx][first_diff_idx:first_diff_idx+10]}\")\n            print(f\"    expected: {expected_w2_weight[first_diff_idx:first_diff_idx+10]}\")\n            raise AssertionError(f\"w2 weight bytes mismatch for TP {tp_idx}\")\n\n        
if not torch.allclose(w2_scale_bufs[tp_idx], expected_w2_scale):\n            diff = torch.abs(w2_scale_bufs[tp_idx].float() - expected_w2_scale.float())\n            max_diff_idx = diff.argmax().item()\n            print(f\"  w2 scale mismatch, max diff at index {max_diff_idx}\")\n            print(f\"    actual: {w2_scale_bufs[tp_idx][max_diff_idx]}\")\n            print(f\"    expected: {expected_w2_scale[max_diff_idx]}\")\n            raise AssertionError(f\"w2 scale values mismatch for TP {tp_idx}\")\n\n    print(\n        f\"\\n✓ write_weight_scale_to_buffer passed: extracted {gpu_experts} GPU experts across {gpu_tp_count} TP parts\"\n    )\n    return True\n\n\ndef main():\n    \"\"\"Run tests for all gpu_tp_count values: 1, 2, 4, 8\"\"\"\n    tp_values = [1, 2, 4, 8]\n    all_passed = True\n    results = {}\n\n    print(\"=\" * 60)\n    print(\"Testing K2 write_weight_scale_to_buffer for TP = 1, 2, 4, 8\")\n    print(\"=\" * 60)\n\n    for tp in tp_values:\n        print(f\"\\n{'='*60}\")\n        print(f\"Testing with gpu_tp_count = {tp}\")\n        print(f\"{'='*60}\")\n        try:\n            test_with_tp(tp)\n            results[tp] = \"PASSED\"\n            print(f\"✓ TP={tp} PASSED\")\n        except Exception as e:\n            results[tp] = f\"FAILED: {e}\"\n            all_passed = False\n            print(f\"✗ TP={tp} FAILED: {e}\")\n\n    print(\"\\n\" + \"=\" * 60)\n    print(\"SUMMARY\")\n    print(\"=\" * 60)\n    for tp, result in results.items():\n        status = \"✓\" if \"PASSED\" in result else \"✗\"\n        print(f\"  {status} TP={tp}: {result}\")\n\n    if all_passed:\n        print(\"\\n✓ ALL TESTS PASSED\")\n    else:\n        print(\"\\n✗ SOME TESTS FAILED\")\n        sys.exit(1)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "kt-kernel/examples/test_linear.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nDescription  :\nAuthor       : chenht2022\nDate         : 2024-07-25 10:32:05\nVersion      : 1.0.0\nLastEditors  : chenht2022\nLastEditTime : 2024-08-06 10:36:59\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved.\n\"\"\"\nimport os, sys\nimport time\n\nsys.path.append(os.path.dirname(__file__) + \"/../build\")\nfrom kt_kernel import kt_kernel_ext\nimport torch\n\ninput_size = 16384\noutput_size = 5120\nstride = 32\ngroup_max_len = 1024\nproj_type = 1  # ggml_type::GGML_TYPE_F16\nhidden_type = 1  # ggml_type::GGML_TYPE_F16\nqlen = 30\nlayer_num = 10\nCPUInfer = kt_kernel_ext.CPUInfer(48)\nvalidation_iter = 100\n\nwith torch.inference_mode(mode=True):\n    linears = []\n    projs = []\n    for _ in range(layer_num):\n        proj = torch.randn((output_size, input_size), dtype=torch.float16, device=\"cuda\").to(\"cpu\").contiguous()\n        config = kt_kernel_ext.linear.LinearConfig(\n            input_size, output_size, stride, group_max_len, proj.data_ptr(), proj_type, hidden_type\n        )\n        linear = kt_kernel_ext.linear.Linear(config)\n        projs.append(proj)\n        linears.append(linear)\n\n    # validation\n    for i in range(validation_iter):\n        linear = linears[i % layer_num]\n        input = torch.randn((qlen, input_size), dtype=torch.float16).contiguous()\n        output = torch.empty((qlen, output_size), dtype=torch.float16).contiguous()\n        input = input / 100\n\n        CPUInfer.submit(linear.forward(qlen, input.data_ptr(), output.data_ptr()))\n        CPUInfer.sync()\n        # print('cpuinfer output', output)\n\n        proj = projs[i % layer_num]\n        t_output = torch.mm(input, proj.t())\n        # print('torch output', t_output)\n\n        diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))\n        print(\"diff = \", diff)\n        assert diff < 0.001\n"
  },
  {
    "path": "kt-kernel/examples/test_mla.py",
    "content": "import logging\nimport os, sys\nimport time\nfrom typing import Optional\n\nos.environ[\"BLAS_NUM_THREADS\"] = \"1\"\nsys.path.insert(0, os.path.dirname(__file__) + \"/../build\")\nfrom kt_kernel import kt_kernel_ext\nfrom kt_kernel_ext.kvcache import ggml_type\nimport torch\nfrom torch import inf, nn\nfrom torch.nn import init\nfrom torch_attention import apply_rotary_pos_emb, DeepseekV2RMSNorm, KDeepSeekV3Cache, DeepseekV3YarnRotaryEmbedding\n\nlogger = logging.getLogger(\"reader\")\n\nfrom gguf.gguf_reader import GGUFReader\n\n\ndef read_gguf_file(gguf_file_path):\n    \"\"\"\n    Reads and prints key-value pairs and tensor information from a GGUF file in an improved format.\n\n    Parameters:\n    - gguf_file_path: Path to the GGUF file.\n    \"\"\"\n\n    reader = GGUFReader(gguf_file_path)\n\n    # List all key-value pairs in a columnized format\n    # print(\"Key-Value Pairs:\") # noqa: NP100\n    # max_key_length = max(len(key) for key in reader.fields.keys())\n    for key, field in reader.fields.items():\n        value = field.parts[field.data[0]]\n        # print(f\"{key:{max_key_length}} : {value}\") # noqa: NP100\n    # print(\"----\") # noqa: NP100\n\n    # List all tensors\n    # print(\"Tensors:\") # noqa: NP100\n    # tensor_info_format = \"{:<30} | Shape: {:<15} | Size: {:<12} | Quantization: {}\"\n    # print(tensor_info_format.format(\"Tensor Name\", \"Shape\", \"Size\", \"Quantization\")) # noqa: NP100\n    # print(\"-\" * 80) # noqa: NP100\n    re = []\n    for tensor in reader.tensors:\n        shape_str = \"x\".join(map(str, tensor.shape))\n        size_str = str(tensor.n_elements)\n        quantization_str = tensor.tensor_type.name\n        # print(tensor_info_format.format(tensor.name, shape_str, size_str, quantization_str)) # noqa: NP100\n        re.append(tensor)\n    return re\n\n\ndef get_torch_tensor_from_gguf(gguf_weights, name):\n    return torch.from_numpy(gguf_weights[name].data).contiguous()\n\n\ndef 
get_torch_tensor_and_type_from_gguf(gguf_weights, name):\n    return torch.from_numpy(gguf_weights[name].data).contiguous(), gguf_weights[name].tensor_type.name\n\n\ndef type_to_ggml_type(type):\n    if type == \"F32\":\n        return ggml_type.FP32\n    elif type == \"F16\":\n        return ggml_type.FP16\n    elif type == \"BF16\":\n        return ggml_type.BF16\n    else:\n        raise ValueError(f\"Unsupported data type: {type}\")\n\n\nuse_real_weights = True\ngguf_path = \"/home/bd/models/DeepSeek-R1-BF16\"\n\nseed = 42  # any integer can be used as the seed\ntorch.manual_seed(seed)\ntorch.cuda.manual_seed_all(seed)\n\nqlen = 3212\nkvlen = 0\n\n\npage_table = range(20)\nbsz_tensors = torch.tensor([1])\n\n\npage_size = 256\npages_count = 200\ntp_count = 4\n\n\nhidden_size = 7168\nq_lora_rank = 1536\nkv_lora_rank = 512\nnum_heads = 128\nnope_size = 128\nrope_size = 64\n\nrope_theta = 10000\nmax_qlen = 4096\nmax_kvlen = 4096\n\nmax_position_embeddings = 163840\n\n\nrope_scaling = {\n    \"beta_fast\": 32,\n    \"beta_slow\": 1,\n    \"factor\": 40,\n    \"mscale\": 1.0,\n    \"mscale_all_dim\": 1.0,\n    \"original_max_position_embeddings\": 4096,\n    \"type\": \"yarn\",\n}\n\n\nCPUInfer = kt_kernel_ext.CPUInfer(30)\nvalidation_iter = 100\n\n\n# data_type = torch.float32\nweight_type = torch.bfloat16\n# weight_type = torch.float16\n\n\ninput_type = {\n    torch.float32: torch.float32,\n    torch.float16: torch.float16,\n    torch.bfloat16: torch.float32,\n}[weight_type]\n\nq_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=False, dtype=weight_type)\nq_b_proj = nn.Linear(q_lora_rank, num_heads * (nope_size + rope_size), bias=False, dtype=weight_type)\nkv_a_proj_with_mqa = nn.Linear(hidden_size, kv_lora_rank + rope_size, bias=False, dtype=weight_type)\nkv_b_proj = nn.Linear(num_heads * (nope_size + nope_size), kv_lora_rank, bias=False, dtype=weight_type)\no_proj = nn.Linear(num_heads * nope_size, hidden_size, bias=False, dtype=weight_type)\nq_a_norm = torch.ones(hidden_size, 
dtype=torch.float32)\nkv_a_norm = torch.ones(hidden_size, dtype=torch.float32)\n\n\ndef read_gguf_directory(directory):\n    \"\"\"\n    Reads all GGUF files in a directory and prints their contents.\n\n    Parameters:\n    - directory: Path to the directory containing GGUF files.\n    \"\"\"\n    if not os.path.isdir(directory):\n        logger.error(f\"Directory {directory} does not exist.\")\n        return\n\n    # List all GGUF files in the directory\n    files = [f for f in os.listdir(directory) if f.endswith(\".gguf\")]\n    if not files:\n        logger.info(f\"No GGUF files found in {directory}.\")\n        return\n\n    re = []\n    for file in files:\n        file_path = os.path.join(directory, file)\n        # print(f\"Reading {file_path}:\") # noqa: NP100\n        # print(\"\\n\") # noqa: NP100\n        re.extend(read_gguf_file(file_path))\n    re = {r.name: r for r in re}\n    return re\n\n\nif use_real_weights := True:\n    gguf_weights = read_gguf_directory(gguf_path)\n    layer_idx = 0\n    q_a_proj_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.attn_q_a.weight\")\n    q_a_proj.weight = nn.Parameter(q_a_proj_weight.view(torch.bfloat16), requires_grad=False)\n    q_a_type = type\n\n    q_a_norm_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.attn_q_a_norm.weight\")\n    q_a_norm = q_a_norm_weight.view(torch.float32)\n    # config.q_a_norm = q_a_norm_weight.data_ptr()\n    # config.q_a_norm_type = type_to_ggml_type(type)\n\n    q_b_proj_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.attn_q_b.weight\")\n    q_b_proj.weight = nn.Parameter(q_b_proj_weight.view(torch.bfloat16), requires_grad=False)\n\n    kv_a_proj_with_mqa_weight, type = get_torch_tensor_and_type_from_gguf(\n        gguf_weights, f\"blk.{layer_idx}.attn_kv_a_mqa.weight\"\n    )\n    kv_a_proj_with_mqa.weight = nn.Parameter(kv_a_proj_with_mqa_weight.view(torch.bfloat16), 
requires_grad=False)\n\n    kv_a_norm_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.attn_kv_a_norm.weight\")\n    kv_a_norm = kv_a_norm_weight.view(torch.float32)\n    # config.kv_a_norm = kv_a_norm_weight.data_ptr()\n    # config.kv_a_norm_type = type_to_ggml_type(type)\n\n    kv_b_proj_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.attn_kv_b.weight\")\n    kv_b_proj.weight = nn.Parameter(kv_b_proj_weight.view(torch.bfloat16), requires_grad=False)\n\n    o_proj_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.attn_output.weight\")\n    o_proj.weight = nn.Parameter(o_proj_weight.view(torch.bfloat16), requires_grad=False)\n\nelse:\n    init.normal_(q_a_proj.weight, mean=0.0, std=0.02)\n    init.normal_(q_b_proj.weight, mean=0.0, std=0.02)\n    init.normal_(kv_a_proj_with_mqa.weight, mean=0.0, std=0.02)\n    init.normal_(kv_b_proj.weight, mean=0.0, std=0.02)\n    init.normal_(o_proj.weight, mean=0.0, std=0.02)\n\nx_reshaped = kv_b_proj.weight.view(num_heads, 2, nope_size, kv_lora_rank)\nq_absorb = x_reshaped[:, 0]\nout_absorb = x_reshaped[:, 1]\n\n\nhidden_states = torch.randn((qlen, hidden_size), dtype=input_type).to(\"cpu\").contiguous()\n\n\ndef test_cpu_mla():\n    os.environ[\"BLAS_NUM_THREADS\"] = \"1\"\n    q_a_proj_weight = q_a_proj.weight.to(weight_type).to(\"cpu\").contiguous()\n    q_b_proj_weight = q_b_proj.weight.to(weight_type).to(\"cpu\").contiguous()\n    kv_a_proj_with_mqa_weight = kv_a_proj_with_mqa.weight.to(\"cpu\").to(weight_type).contiguous()\n    kv_b_proj_weight = kv_b_proj.weight.to(weight_type).to(\"cpu\").contiguous()\n    o_proj_weight = o_proj.weight.to(weight_type).to(\"cpu\").contiguous()\n\n    config = kt_kernel_ext.mla.MLAConfig(\n        hidden_size,\n        q_lora_rank,\n        kv_lora_rank,\n        num_heads,\n        nope_size,\n        rope_size,\n    )\n    config.max_qlen = max_qlen\n    config.max_kvlen = 
max_kvlen\n    config.max_position_embeddings = max_position_embeddings\n    config.rope_scaling_factor = rope_scaling[\"factor\"]\n    config.rope_theta = rope_theta\n    config.rope_scaling_beta_fast = rope_scaling[\"beta_fast\"]\n    config.rope_scaling_beta_slow = rope_scaling[\"beta_slow\"]\n    config.rope_scaling_mscale = rope_scaling[\"mscale\"]\n    config.rope_scaling_mscale_all_dim = rope_scaling[\"mscale_all_dim\"]\n    config.rope_scaling_original_max_position_embeddings = rope_scaling[\"original_max_position_embeddings\"]\n\n    config.q_a_proj = q_a_proj_weight.data_ptr()\n    config.q_b_proj = q_b_proj_weight.data_ptr()\n    config.kv_a_proj_with_mqa = kv_a_proj_with_mqa_weight.data_ptr()\n    config.kv_b_proj = kv_b_proj_weight.data_ptr()\n    config.o_proj = o_proj_weight.data_ptr()\n\n    config.q_a_norm = q_a_norm.data_ptr()\n    config.q_a_norm_type = ggml_type.FP32\n    config.kv_a_norm = kv_a_norm.data_ptr()\n    config.kv_a_norm_type = ggml_type.FP32\n    config.page_count = pages_count\n\n    if weight_type == torch.float32:\n        config.q_a_proj_type = ggml_type.FP32\n        config.q_b_proj_type = ggml_type.FP32\n        config.kv_a_proj_with_mqa_type = ggml_type.FP32\n        config.kv_b_proj_type = ggml_type.FP32\n        config.w_o_type = ggml_type.FP32\n    elif weight_type == torch.float16:\n        config.q_a_proj_type = ggml_type.FP16\n        config.q_b_proj_type = ggml_type.FP16\n        config.kv_a_proj_with_mqa_type = ggml_type.FP16\n        config.kv_b_proj_type = ggml_type.FP16\n        config.w_o_type = ggml_type.FP16\n    elif weight_type == torch.bfloat16:\n        config.q_a_proj_type = ggml_type.BF16\n        config.q_b_proj_type = ggml_type.BF16\n        config.kv_a_proj_with_mqa_type = ggml_type.BF16\n        config.kv_b_proj_type = ggml_type.BF16\n        config.w_o_type = ggml_type.BF16\n    else:\n        raise ValueError(f\"Unsupported data type: {weight_type}\")\n\n    config.pool = CPUInfer.backend_\n\n    if 
weight_type == torch.float32:\n        mla = kt_kernel_ext.mla.MLA_F32(config)\n    elif weight_type == torch.float16:\n        mla = kt_kernel_ext.mla.MLA_F16(config)\n    elif weight_type == torch.bfloat16:\n        # mla = kt_kernel_ext.mla.MLA_F32(config)\n        mla = kt_kernel_ext.mla.MLA_QUAN_F32(config)\n    else:\n        raise ValueError(f\"Unsupported data type: {weight_type}\")\n\n    mla.load_weights()\n    mla.set_local_pages(pages_count)\n\n    output = torch.zeros((qlen, hidden_size), dtype=input_type).to(\"cpu\").contiguous()\n    mla.forward([qlen], [page_table], [kvlen], hidden_states.data_ptr(), output.data_ptr())\n    print(\"CPU MLA Output: \", output)\n    return output\n\n\ndef load_fp16_tensor(file_path, shape):\n    # return load_fp32_tensor(file_path, shape)\n    # NOTE: debug stub; the file read below is unreachable while this early return is in place\n    return torch.zeros(shape)\n    with open(file_path, \"rb\") as f:\n        raw_data = f.read()\n    tensor = torch.frombuffer(raw_data, dtype=weight_type)\n    tensor = tensor.view(shape)  # reshape to the requested shape\n    return tensor\n\n\ndef load_fp32_tensor(file_path, shape):\n    # NOTE: debug stub; the file read below is unreachable while this early return is in place\n    return torch.zeros(shape)\n    with open(file_path, \"rb\") as f:\n        raw_data = f.read()\n    tensor = torch.frombuffer(raw_data, dtype=torch.float32)\n    tensor = tensor.view(shape)  # reshape to the requested shape\n    return tensor\n\n\ndef test_torch():\n    torch.set_grad_enabled(False)\n\n    softmax_scale = (nope_size + rope_size) ** -0.5\n    # the 1 is the number of heads in the compressed KV\n    k_caches = torch.randn(1, pages_count, page_size, 1, kv_lora_rank + rope_size).to(weight_type)\n    kv_cache = KDeepSeekV3Cache(page_size=page_size, kv_lora_rank=kv_lora_rank, k_caches=k_caches)\n\n    q_a_layernorm = DeepseekV2RMSNorm(q_lora_rank)\n    q_a_layernorm.weight = nn.Parameter(q_a_norm, requires_grad=False)\n\n    x = torch.randn(q_lora_rank, dtype=weight_type) * 100\n    print(x)\n    print(q_a_layernorm(x))\n\n    kv_a_layernorm = DeepseekV2RMSNorm(kv_lora_rank)\n    kv_a_layernorm.weight = nn.Parameter(kv_a_norm, 
requires_grad=False)\n\n    # Step 3: split into two tensors\n    # q_absorb, out_absorb = x_permuted[:, 0], x_permuted[:, 1]  # both are (num_heads, nope_size, kv_lora_rank)\n    # q_absorb = kv_b_proj[:, ] # torch.randn(num_heads, nope_size, kv_lora_rank, dtype=data_type)\n    # out_absorb = kv_b_proj # torch.randn(num_heads, nope_size, kv_lora_rank, dtype=data_type)\n\n    rotary_emb = DeepseekV3YarnRotaryEmbedding(\n        rope_size,\n        max_position_embeddings=max_position_embeddings,\n        scaling_factor=rope_scaling[\"factor\"],\n        base=rope_theta,\n        beta_fast=rope_scaling[\"beta_fast\"],\n        beta_slow=rope_scaling[\"beta_slow\"],\n        mscale=rope_scaling[\"mscale\"],\n        mscale_all_dim=rope_scaling[\"mscale_all_dim\"],\n        original_max_position_embeddings=rope_scaling[\"original_max_position_embeddings\"],\n    )\n    # Build an input hidden_states of length qlen; the corresponding history kv_indptr is [0:bsz]\n    # kv_indices is [0:bsz], page_idx=[0:bsz], page_offset=[kvlen:qlen+kvlen]\n    # last_page_len = [qlen+kvlen,...] 
layer_idx = 1\n    # position_ids = [kvlen:qlen+kvlen]\n    q_indptr = torch.tensor([0, qlen]).to(torch.int32)\n\n    kv_indptr = torch.tensor([0, (qlen + kvlen + page_size - 1) // page_size]).to(torch.int32)\n    kv_indices = torch.tensor(range(pages_count)).to(torch.int32)\n\n    page_idx = torch.tensor([i // page_size for i in range(kvlen, kvlen + qlen)]).to(torch.int32)\n    page_offset = torch.tensor([i % page_size for i in range(kvlen, kvlen + qlen)]).to(torch.int32)\n\n    last_page_len = torch.tensor([256], device=hidden_states.device)\n    position_ids = torch.tensor(range(kvlen, kvlen + qlen)).to(torch.int32)\n\n    # Build the causal mask row by row: [qlen, kvlen+qlen]\n    attention_masks = torch.zeros((max_qlen, max_kvlen), dtype=weight_type)\n    for i in range(max_qlen):\n        attention_masks[i, i + kvlen + 1 :] = -inf\n\n    def torch_attn(\n        hidden_states_i: torch.Tensor,\n        kv_cache: KDeepSeekV3Cache,\n        position_ids: torch.Tensor,\n        page_idx: torch.Tensor,\n        page_offset: torch.Tensor,\n        attention_masks: Optional[list[torch.Tensor]] = None,\n        q_indptr: Optional[torch.Tensor] = None,\n        kv_indices: Optional[torch.Tensor] = None,\n        kv_indptr: Optional[torch.Tensor] = None,\n        bsz_tensors: Optional[torch.Tensor] = None,\n        last_page_len: Optional[torch.Tensor] = None,\n        layer_idx: Optional[int] = None,\n    ):\n        global out_absorb\n        global q_absorb\n        hidden_states = hidden_states_i.to(weight_type)\n        # range bsz_tensors\n        final_attention_output = torch.tensor([], device=hidden_states.device)\n        for i in range(bsz_tensors[0]):\n            batch_num_tokens_tensors = q_indptr[i + 1] - q_indptr[i]\n            batch_last_page_len = last_page_len[i]\n            # kv_total_len is kv_len, batch_compressed_kv is compressed_kv, batch_k_pe is k_pe\n            batch_page_idx = page_idx[q_indptr[i] : q_indptr[i + 1]]\n            batch_page_offset = 
page_offset[q_indptr[i] : q_indptr[i + 1]]\n            # kv_page_nums is the number of pages for the current batch\n            kv_page_nums = kv_indptr[i + 1] - kv_indptr[i]\n            # kv_total_len is the total length of the kv cache for the current batch (kv_len for algorithm)\n            kv_total_len = kv_page_nums * page_size\n            if batch_last_page_len is not None:\n                kv_total_len = kv_total_len - (page_size - batch_last_page_len)\n            # print(f\"kv_total_len's shape {kv_total_len.shape}\")\n            # kv_index is the index of the kv cache pages for the current batch\n            kv_index = kv_indices[kv_indptr[i] : kv_indptr[i + 1]]\n            # we can index [kv_index, page_offset_indices] to get the kv cache for the current batch\n            # from q_indptr[i] to q_indptr[i+1] is the range of the current batch\n            batch_hidden_states = hidden_states[q_indptr[i] : q_indptr[i + 1]]\n            batch_position_ids = position_ids[q_indptr[i] : q_indptr[i + 1]]\n            qlen, _ = batch_hidden_states.size()\n            # print(\"qlen -> \", qlen)\n\n            hidden_states_to_check = load_fp16_tensor(\"./debug/query_0_tp_0_input.bin\", batch_hidden_states.shape)\n            diff = torch.abs(batch_hidden_states - hidden_states_to_check).max()\n            print(\"hidden_states diff -> \", diff)\n\n            q_lora = q_a_proj(batch_hidden_states)\n            # q_lora_to_check = load_fp16_tensor('./debug/query_0_tp_0_qlora.bin', q_lora.shape)\n            # q_lora_to_check_test = load_fp16_tensor('./debug/query_0_tp_0_qlora_test.bin', q_lora.shape)\n            # diff = torch.abs(q_lora - q_lora_to_check).max()\n            # diff_test = torch.abs(q_lora - q_lora_to_check_test).max()\n            # print(\"q_lora max diff -> \", diff)\n            # print(\"q_lora max diff test -> \", diff_test)\n            # mae =  torch.mean(torch.abs(q_lora - q_lora_to_check))\n            # mae_test =  
torch.mean(torch.abs(q_lora - q_lora_to_check_test))\n            # print(\"q_lora mae -> \", mae)\n            # print(\"q_lora mae test -> \", mae_test)\n\n            q_lora_norm = q_a_layernorm(q_lora)\n            # q_lora_norm_to_check = load_fp16_tensor('./debug/query_0_tp_0_qlora_norm.bin', q_lora_norm.shape)\n            # q_lora_norm_to_check_test = load_fp16_tensor('./debug/query_0_tp_0_qlora_norm_test.bin', q_lora_norm.shape)\n            # diff = torch.abs(q_lora_norm - q_lora_norm_to_check).max()\n            # mae =  torch.mean(torch.abs(q_lora_norm - q_lora_norm_to_check))\n            # diff_test = torch.abs(q_lora_norm - q_lora_norm_to_check_test).max()\n            # mae_test =  torch.mean(torch.abs(q_lora_norm - q_lora_norm_to_check_test))\n            # print(\"q_lora_norm diff -> \", diff)\n            # print(\"q_lora_norm mae -> \", mae)\n            # print(\"q_lora_norm diff test -> \", diff_test)\n            # print(\"q_lora_norm mae test -> \", mae_test)\n\n            q = q_b_proj(q_lora_norm)\n            # for v3, bsz, qlen, num_heads(128), qk_head_dim(192=128(nope)+64(rope))\n            q = q.view(qlen, num_heads, nope_size + rope_size)\n            # q_nope is [qlen, num_heads(128), qk_nope_head_dim(128)]\n            # q_pe is [qlen, num_heads(128), qk_rope_head_dim(64)]\n            q_nope, q_pe = torch.split(q, [nope_size, rope_size], dim=-1)\n\n            # compressed_kv is [qlen, kv_lora_rank(512) + rope(64)]\n            compressed_kv = kv_a_proj_with_mqa(batch_hidden_states)\n            # compressed_kv is [qlen, kv_lora_rank(512)], k_pe is [qlen, rope(64)]\n            compressed_kv, k_pe = torch.split(compressed_kv, [kv_lora_rank, rope_size], dim=-1)\n            compressed_kv = compressed_kv.contiguous()\n\n            # compressed_kv_page_0 = compressed_kv[0:page_size, :]\n            # compressed_kv_to_check = load_fp16_tensor('./debug/query_0_tp_0_page_0_kv_lora_rank',\n            #                                   
        compressed_kv_page_0.shape)\n            # diff = torch.abs(compressed_kv_page_0 - compressed_kv_to_check).max()\n            # mae =  torch.mean(torch.abs(compressed_kv_page_0 - compressed_kv_to_check))\n            # print(\"compressed_kv diff -> \", diff)\n            # print(\"compressed_kv mae -> \", mae)\n\n            compressed_kv = kv_a_layernorm(compressed_kv)\n            # k_pe is [qlen, 1, qk_rope_head_dim(64)]\n\n            # compressed_kv_page_0 = compressed_kv[0:page_size, :]\n            # compressed_kv_to_check = load_fp16_tensor('./debug/query_0_tp_0_page_0_kv_lora_rank_norm',\n            #                                           compressed_kv_page_0.shape)\n            # diff = torch.abs(compressed_kv_page_0 - compressed_kv_to_check).max()\n            # mae =  torch.mean(torch.abs(compressed_kv_page_0 - compressed_kv_to_check))\n            # print(\"compressed_kv diff norm -> \", diff)\n            # print(\"compressed_kv mae norm -> \", mae)\n\n            k_pe = k_pe.view(qlen, 1, rope_size)\n            # compressed_kv is [qlen, 1, kv_lora_rank(512)]\n            compressed_kv = compressed_kv.view(qlen, 1, kv_lora_rank)\n\n            cos, sin = rotary_emb(q_pe, batch_position_ids)\n\n            # q_nope_check = q_nope.transpose(0, 1) # qlen is 1, no GPU overhead, same below\n\n            # q_nope_0_to_check = load_fp16_tensor('./debug/query_0_tp_0_q_nope', q_nope_check[0].shape)\n            # q_nope_0_to_check_test = load_fp16_tensor('./debug/query_0_tp_0_q_nope_test', q_nope_check[0].shape)\n            # diff = torch.abs(q_nope_check[0] - q_nope_0_to_check).max()\n            # mae =  torch.mean(torch.abs(q_nope_check[0] - q_nope_0_to_check))\n            # diff_test = torch.abs(q_nope_check[0] - q_nope_0_to_check_test).max()\n            # mae_test =  torch.mean(torch.abs(q_nope_check[0] - q_nope_0_to_check_test))\n            # print(\"q_nope[0] diff -> \", diff)\n            # print(\"q_nope[0] mae -> \", mae)\n         
   # print(\"q_nope[0] diff test -> \", diff_test)\n            # print(\"q_nope[0] mae test -> \", mae_test)\n\n            q_pe_nope = q_pe.transpose(0, 1)\n            # q_pe_0_to_check = load_fp16_tensor('./debug/query_0_tp_0_q_rope', q_pe_nope[0].shape)\n            # q_pe_0_to_check = load_fp16_tensor('./debug/query_0_tp_0_q_rope_no_rope', q_pe_nope[0].shape)\n            # q_pe_0_to_check_test = load_fp16_tensor('./debug/query_0_tp_0_q_rope_no_rope_test', q_pe_nope[0].shape)\n            # diff = torch.abs(q_pe_nope[0] - q_pe_0_to_check).max()\n            # mae =  torch.mean(torch.abs(q_pe_nope[0] - q_pe_0_to_check))\n            # diff_test = torch.abs(q_pe_nope[0] - q_pe_0_to_check_test).max()\n            # mae_test =  torch.mean(torch.abs(q_pe_nope[0] - q_pe_0_to_check_test))\n            # print(\"q_pe nope[0] diff -> \", diff)\n            # print(\"q_pe nope[0] mae -> \", mae)\n            # print(\"q_pe nope[0] diff test -> \", diff_test)\n            # print(\"q_pe nope[0] mae test -> \", mae_test)\n\n            # cos_to_check = load_fp32_tensor('./debug/query_0_tp_0_rope_cos', (qlen,32))\n            # diff = torch.abs(cos[:,:32]-cos_to_check).max()\n            # mae =  torch.mean(torch.abs(cos[:,:32]-cos_to_check))\n            # print(\"cos diff -> \", diff)\n            # print(\"cos mae -> \", mae)\n            # sin_to_check = load_fp32_tensor('./debug/query_0_tp_0_rope_sin', (qlen,32))\n            # diff = torch.abs(sin[:,:32]-sin_to_check).max()\n            # mae =  torch.mean(torch.abs(sin[:,:32]-sin_to_check))\n            # print(\"sin diff -> \", diff)\n            # print(\"sin mae -> \", mae)\n\n            # new_q_pe = q_pe.transpose(0, 1)\n            # qa = new_q_pe[:,:,range(0,64,2)]\n            # qb = new_q_pe[:,:,range(1,65,2)]\n            # # q1 = (qa * cos[:,:32] - qb * sin[:,:32])\n            # # q2 = (qb*cos[:,:32] + qa*sin[:,:32])\n            # q1 = (qa * cos_to_check - qb * sin_to_check)\n            # q2 = 
(qb*cos_to_check + qa*sin_to_check)\n            # q_new = torch.cat((q1,q2), dim=-1)\n            # print(f\"q_pe shape{q_pe.shape}, k_pe shape {k_pe.shape}\")\n            # new_q_pe = torch.zeros_like(q_pe)\n            # new_q_pe[:,:,range(0,64,2)] = 1\n            # new_q_pe[:,:,range(1,65,2)] = 10\n            q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=1)\n            q_pe = q_pe.squeeze(0)\n            # q_pe is [num_heads(128), qlen, qk_rope_head_dim(64)]\n            q_pe.transpose_(0, 1)\n\n            # diff = torch.abs(q_pe - q_new).max()\n            # print(\"q_pe diff -> \", diff)\n\n            # q_pe_0_to_check = load_fp16_tensor('./debug/query_0_tp_0_q_rope', q_pe[0].shape)\n            # diff = torch.abs(q_pe[0] - q_pe_0_to_check).max()\n            # mae =  torch.mean(torch.abs(q_pe[0] - q_pe_0_to_check))\n            # print(\"q_pe[0] diff -> \", diff)\n            # print(\"q_pe[0] mae -> \", mae)\n\n            # diff = torch.abs(q_pe_0_to_check - q_new[0]).max()\n            # mae =  torch.mean(torch.abs(q_pe_0_to_check - q_new[0]))\n            # print(\"q_pe[0] 2  diff -> \", diff)\n            # print(\"q_pe[0] 2 mae -> \", mae)\n\n            if kv_cache is not None:\n                cache_kwargs = {\n                    \"sin\": sin,\n                    \"cos\": cos,\n                    \"page_idx\": batch_page_idx,\n                    \"page_offset\": batch_page_offset,\n                }  # Specific to RoPE models\n                compressed_kv_with_k_pe = kv_cache.update(\n                    compressed_kv.unsqueeze(0), k_pe, layer_idx, batch_page_idx, batch_page_offset, cache_kwargs\n                )\n                compressed_kv = compressed_kv_with_k_pe[:, :, :, :kv_lora_rank].view(-1, page_size, kv_lora_rank)\n                k_pe = compressed_kv_with_k_pe[:, :, :, kv_lora_rank:].view(-1, page_size, rope_size)\n            # q_absorb is [num_heads(128), 
qk_nope_head_dim(128), kv_lora_rank(512)]\n            # out_absorb is [num_heads(128), kv_lora_rank(512), v_head_dim(128)] v_head_dim is also the nope dim\n            # q_absorb, out_absorb = get_absorbed()\n            # q_nope is [num_heads(128), qlen, qk_nope_head_dim(128)]\n            q_nope = q_nope.transpose(0, 1)  # qlen is 1, no GPU overhead, same below\n\n            # q_nope_0_to_check = load_fp16_tensor('./debug/query_0_tp_0_q_nope', q_nope[0].shape)\n            # diff = torch.abs(q_nope[0] - q_nope_0_to_check).max()\n            # mae =  torch.mean(torch.abs(q_nope[0] - q_nope_0_to_check))\n            # print(\"q_nope[0] diff -> \", diff)\n\n            # q_nope is [num_heads(128), qlen, kv_lora_rank(512)]\n            q_nope = torch.matmul(q_nope, q_absorb)  # batched MM\n\n            # k_b_proj_check = load_fp16_tensor('./debug/query_0_tp_0_k_b_lora', (nope_size,kv_lora_rank))\n            # diff = torch.abs(q_absorb[0] - k_b_proj_check).max()\n            # print(\"kv b lora weight[0] diff -> \", diff)\n\n            # q_absorb_check = load_fp16_tensor('./debug/query_0_tp_0_q_absorb', (kv_lora_rank,1024))\n            # q_absorb_check = q_absorb_check[:,0:qlen].transpose(0,1)\n            # diff = torch.abs(q_nope[0] - q_absorb_check).max()\n            # mae =  torch.mean(torch.abs(q_nope[0] - q_absorb_check))\n            # print(\"q_nope absorb diff -> \", diff)\n            # print(\"q_nope absorb mae -> \", mae)\n\n            # # q_nope is [qlen, num_heads(128), kv_lora_rank(512)]\n            # q_nope = q_nope.transpose(0, 1)\n\n            # we need to index out the compressed_kv and k_pe for the current batch\n            batch_compressed_kv = None\n            batch_k_pe = None\n            for page_index in kv_index:\n                if kv_total_len > page_size:\n                    tmp_compressed_kv = compressed_kv[page_index, 0:page_size, :]\n                    tmp_k_pe = k_pe[page_index, 0:page_size, :]\n                    if 
batch_compressed_kv is None or batch_k_pe is None:\n                        batch_compressed_kv = tmp_compressed_kv\n                        batch_k_pe = tmp_k_pe\n                    else:\n                        batch_compressed_kv = torch.cat((batch_compressed_kv, tmp_compressed_kv), dim=0)\n                        batch_k_pe = torch.cat((batch_k_pe, tmp_k_pe), dim=0)\n                    kv_total_len -= page_size\n                else:\n                    tmp_compressed_kv = compressed_kv[page_index, 0:kv_total_len, :]\n                    tmp_k_pe = k_pe[page_index, 0:kv_total_len, :]\n                    if batch_compressed_kv is None or batch_k_pe is None:\n                        batch_compressed_kv = tmp_compressed_kv\n                        batch_k_pe = tmp_k_pe\n                    else:\n                        batch_compressed_kv = torch.cat((batch_compressed_kv, tmp_compressed_kv), dim=0)\n                        batch_k_pe = torch.cat((batch_k_pe, tmp_k_pe), dim=0)\n                    break\n            # batch_compressed_kv is [kv_total_len(k_len), kv_lora_rank(512)]\n            # batch_k_pe is [kv_total_len(k_len), qk_rope_head_dim(64)]\n\n            # k_pe_to_check = load_fp16_tensor('./debug/query_0_tp_0_page_0_k_rope', (256,64))\n            # diff = torch.abs(batch_k_pe[:256] - k_pe_to_check).max()\n            # mae =  torch.mean(torch.abs(batch_k_pe[:256] - k_pe_to_check))\n            # print(\"k_pe diff -> \", diff)\n            # print(\"k_pe mae -> \", mae)\n\n            pe_weights = torch.matmul(q_pe, batch_k_pe.mT)\n            kv_total_len = kv_page_nums * page_size\n            # pe_weights_0 = load_fp16_tensor('./debug/query_0_tp_0_pe_attention_weights', (1024,4096))\n            # pe_weights_0 = pe_weights_0[0:qlen, 0:kv_total_len]\n            # diff = torch.abs(pe_weights[0] - pe_weights_0).max()\n            # print(\"pe_weights[0] diff -> \", diff)\n\n            attention_weights = pe_weights + torch.matmul(q_nope, 
batch_compressed_kv.mT)\n\n            # raw_weights = load_fp16_tensor('./debug/query_0_tp_0_raw_attention_weights', (1024, 4096))\n            # raw_weights = raw_weights[0:qlen, 0:kv_total_len]\n            # diff = torch.abs(attention_weights[0] - raw_weights).max()\n            # print(\"raw attention_weights[0] diff -> \", diff)\n\n            attention_weights = attention_weights * softmax_scale\n            # attention_weights is [num_heads(128), qlen, k_len]\n\n            # attention_weights = attention_weights.transpose(0,1).unsqueeze(0).squeeze(-1).expand(qlen,-1,-1).transpose(0,1)\n\n            # attention_masks[i] is [qlen, k_len]\n\n            print(attention_weights.shape)\n            print(attention_masks.shape)\n            attention_weights = (\n                attention_weights + attention_masks[: attention_weights.shape[1], : attention_weights.shape[2]]\n            )\n            # attention_weights shape is [num_heads(128), qlen, k_len]\n\n            attention_weights = nn.functional.softmax(attention_weights, dim=-1, dtype=weight_type).to(q_pe.dtype)\n\n            # attention_weights_0 = load_fp16_tensor('./debug/query_0_tp_0_attention_weights', (1024, 4096))\n            # attention_weights_0 = attention_weights_0[0:qlen, 0:kv_total_len]\n            # diff = torch.abs(attention_weights[0] - attention_weights_0).max()\n            # print(\"attention_weights[0] diff -> \", diff)\n\n            attn_output = torch.matmul(attention_weights, batch_compressed_kv)  # [num_heads(128),qlen, lora_rank(512)]\n            # out_absorb shape is [num_heads(128), kv_lora_rank(512), v_head_dim(128)]\n\n            # o_absorb_check = load_fp16_tensor('./debug/query_0_tp_0_o_absorb', (qlen,kv_lora_rank))\n            # diff = torch.abs(attn_output[0] - o_absorb_check).max()\n            # print(\"o absorb[0] diff -> \", diff)\n\n            out_absorb = out_absorb.transpose(1, 2)  # [qlen, num_heads(128), v_head_dim(128)]\n            # q for qlen, n 
for num_heads, h for v_head_dim, v for kv_lora_rank\n            attn_output = torch.matmul(attn_output, out_absorb)  # [num_heads(128), qlen, v_head_dim(128)]\n\n            # attn_output_check_0 = load_fp16_tensor('./debug/query_0_tp_0_attention_output', (qlen, nope_size))\n            # diff = torch.abs(attn_output[0] - attn_output_check_0).max()\n            # print(\"attn_output[0] diff -> \", diff)\n\n            attn_output = attn_output.transpose(0, 1)  # [qlen, num_heads(128), v_head_dim(128)]\n            attn_output = attn_output.reshape(qlen, num_heads * nope_size)\n\n            w_o = o_proj.weight.view([hidden_size, num_heads * nope_size])\n            output = torch.matmul(attn_output, w_o.transpose(0, 1))\n            output = output.view(qlen, hidden_size)\n\n            # output_0_check = load_fp16_tensor('./debug/query_0_tp_0_qlen_output', (qlen, hidden_size))\n            # h1_o = w_o[:,:128]\n            # local_o_check = load_fp16_tensor('./debug/query_0_tp_0_local_w_o', (hidden_size, 128))\n            # diff = torch.abs(local_o_check - h1_o).max()\n            # print(\"local w_o diff -> \", diff)\n\n            # h1_output = torch.matmul(attn_output[:,:128],h1_o.transpose(0,1))\n            # diff = torch.abs(h1_output - output_0_check).max()\n            # print(\"h1_output diff -> \", diff)\n\n            # output_check = load_fp16_tensor('./debug/output.bin', output.shape)\n            # diff = torch.abs(output - output_check).max()\n            # mae =   torch.mean(torch.abs(output - output_check))\n            # print(\"output diff -> \", diff)\n\n            final_attention_output = torch.cat((final_attention_output, output), dim=0)\n        return final_attention_output\n\n    torch_output = torch_attn(\n        hidden_states,\n        kv_cache,\n        position_ids,\n        page_idx,\n        page_offset,\n        attention_masks=attention_masks,\n        q_indptr=q_indptr,\n        kv_indices=kv_indices,\n        
kv_indptr=kv_indptr,\n        bsz_tensors=bsz_tensors,\n        last_page_len=last_page_len,\n        layer_idx=0,\n    )\n    print(\"Torch Output: \", torch_output)\n    return torch_output\n\n\ntorch.set_printoptions(sci_mode=False, precision=5)\noutput_cpu = test_cpu_mla()\noutput_torch = test_torch()\nprint(\"Output CPU: \", output_cpu)\nprint(\"Output Torch: \", output_torch)\ndiff = (output_cpu - output_torch).abs()\n# Compute the relative error\ndiff_relative = diff / (output_cpu.abs())\n# Replace NaNs in diff_relative with 0\ndiff_relative = torch.where(torch.isnan(diff_relative), torch.zeros_like(diff_relative), diff_relative)\ndiff_relative_mean = torch.mean(torch.abs(output_cpu - output_torch)) / torch.mean(torch.abs(output_torch))\n\nprint(\n    f\"Diff: ave:{diff.mean()}, max:{diff.max()}, min:{diff.min()},  relative_mean:{diff_relative_mean}, relative_max:{diff_relative.max()}, relative_min:{diff_relative.min()}\"\n)\nassert diff_relative_mean < 2e-1, \"CPU and Torch outputs are not close enough!\"\n"
  },
  {
    "path": "kt-kernel/examples/test_mla_qlen.py",
    "content": "import logging\nimport os, sys\nimport time\nfrom typing import Optional\n\nos.environ[\"BLAS_NUM_THREADS\"] = \"1\"\nsys.path.insert(0, os.path.dirname(__file__) + \"/../build\")\nfrom kt_kernel import kt_kernel_ext\nfrom kt_kernel_ext.kvcache import ggml_type\nimport torch\nfrom torch import inf, nn\nfrom torch.nn import init\nfrom torch_attention import apply_rotary_pos_emb, DeepseekV2RMSNorm, KDeepSeekV3Cache, DeepseekV3YarnRotaryEmbedding\n\nlogger = logging.getLogger(\"reader\")\n\nfrom gguf.gguf_reader import GGUFReader\n\n\ndef read_gguf_file(gguf_file_path):\n    \"\"\"\n    Reads and prints key-value pairs and tensor information from a GGUF file in an improved format.\n\n    Parameters:\n    - gguf_file_path: Path to the GGUF file.\n    \"\"\"\n\n    reader = GGUFReader(gguf_file_path)\n\n    # List all key-value pairs in a columnized format\n    # print(\"Key-Value Pairs:\") # noqa: NP100\n    # max_key_length = max(len(key) for key in reader.fields.keys())\n    for key, field in reader.fields.items():\n        value = field.parts[field.data[0]]\n        # print(f\"{key:{max_key_length}} : {value}\") # noqa: NP100\n    # print(\"----\") # noqa: NP100\n\n    # List all tensors\n    # print(\"Tensors:\") # noqa: NP100\n    # tensor_info_format = \"{:<30} | Shape: {:<15} | Size: {:<12} | Quantization: {}\"\n    # print(tensor_info_format.format(\"Tensor Name\", \"Shape\", \"Size\", \"Quantization\")) # noqa: NP100\n    # print(\"-\" * 80) # noqa: NP100\n    re = []\n    for tensor in reader.tensors:\n        shape_str = \"x\".join(map(str, tensor.shape))\n        size_str = str(tensor.n_elements)\n        quantization_str = tensor.tensor_type.name\n        # print(tensor_info_format.format(tensor.name, shape_str, size_str, quantization_str)) # noqa: NP100\n        re.append(tensor)\n    return re\n\n\ndef get_torch_tensor_from_gguf(gguf_weights, name):\n    return torch.from_numpy(gguf_weights[name].data).contiguous()\n\n\ndef 
get_torch_tensor_and_type_from_gguf(gguf_weights, name):\n    return torch.from_numpy(gguf_weights[name].data).contiguous(), gguf_weights[name].tensor_type.name\n\n\ndef type_to_ggml_type(type):\n    if type == \"F32\":\n        return ggml_type.FP32\n    elif type == \"F16\":\n        return ggml_type.FP16\n    elif type == \"BF16\":\n        return ggml_type.BF16\n    else:\n        raise ValueError(f\"Unsupported data type: {type}\")\n\n\nuse_real_weights = True\ngguf_path = \"/home/bd/models/DeepSeek-R1-BF16\"\n\nseed = 42  # any integer can be used as the seed\ntorch.manual_seed(seed)\ntorch.cuda.manual_seed_all(seed)\n\nqlen = 1024\nkvlen = 0\n\n\npage_table = range(20)\nbsz_tensors = torch.tensor([1])\n\n\npage_size = 256\npages_count = 200\ntp_count = 4\n\n\nhidden_size = 7168\nq_lora_rank = 1536\nkv_lora_rank = 512\nnum_heads = 128\nnope_size = 128\nrope_size = 64\n\nrope_theta = 10000\nmax_qlen = 1024\nmax_kvlen = 4096\n\nmax_position_embeddings = 163840\n\n\nrope_scaling = {\n    \"beta_fast\": 32,\n    \"beta_slow\": 1,\n    \"factor\": 40,\n    \"mscale\": 1.0,\n    \"mscale_all_dim\": 1.0,\n    \"original_max_position_embeddings\": 4096,\n    \"type\": \"yarn\",\n}\n\n\nCPUInfer = kt_kernel_ext.CPUInfer(64)\nvalidation_iter = 100\n\n\n# data_type = torch.float32\nweight_type = torch.bfloat16\n# weight_type = torch.float16\n\n\ninput_type = {\n    torch.float32: torch.float32,\n    torch.float16: torch.float16,\n    torch.bfloat16: torch.float32,\n}[weight_type]\n\nq_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=False, dtype=weight_type)\nq_b_proj = nn.Linear(q_lora_rank, num_heads * (nope_size + rope_size), bias=False, dtype=weight_type)\nkv_a_proj_with_mqa = nn.Linear(hidden_size, kv_lora_rank + rope_size, bias=False, dtype=weight_type)\nkv_b_proj = nn.Linear(num_heads * (nope_size + nope_size), kv_lora_rank, bias=False, dtype=weight_type)\no_proj = nn.Linear(num_heads * nope_size, hidden_size, bias=False, dtype=weight_type)\nq_a_norm = torch.ones(hidden_size, 
dtype=torch.float32)\nkv_a_norm = torch.ones(hidden_size, dtype=torch.float32)\n\n\ndef read_gguf_directory(directory):\n    \"\"\"\n    Reads all GGUF files in a directory and returns their tensors as a dict keyed by tensor name.\n\n    Parameters:\n    - directory: Path to the directory containing GGUF files.\n    \"\"\"\n    if not os.path.isdir(directory):\n        logger.error(f\"Directory {directory} does not exist.\")\n        return\n\n    # List all GGUF files in the directory\n    files = [f for f in os.listdir(directory) if f.endswith(\".gguf\")]\n    if not files:\n        logger.info(f\"No GGUF files found in {directory}.\")\n        return\n\n    re = []\n    for file in files:\n        file_path = os.path.join(directory, file)\n        # print(f\"Reading {file_path}:\") # noqa: NP100\n        # print(\"\\n\") # noqa: NP100\n        re.extend(read_gguf_file(file_path))\n    re = {r.name: r for r in re}\n    return re\n\n\nif use_real_weights:\n    gguf_weights = read_gguf_directory(gguf_path)\n    layer_idx = 0\n    q_a_proj_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.attn_q_a.weight\")\n    q_a_proj.weight = nn.Parameter(q_a_proj_weight.view(torch.bfloat16), requires_grad=False)\n    q_a_type = type\n\n    q_a_norm_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.attn_q_a_norm.weight\")\n    q_a_norm = q_a_norm_weight.view(torch.float32)\n    # config.q_a_norm = q_a_norm_weight.data_ptr()\n    # config.q_a_norm_type = type_to_ggml_type(type)\n\n    q_b_proj_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.attn_q_b.weight\")\n    q_b_proj.weight = nn.Parameter(q_b_proj_weight.view(torch.bfloat16), requires_grad=False)\n\n    kv_a_proj_with_mqa_weight, type = get_torch_tensor_and_type_from_gguf(\n        gguf_weights, f\"blk.{layer_idx}.attn_kv_a_mqa.weight\"\n    )\n    kv_a_proj_with_mqa.weight = nn.Parameter(kv_a_proj_with_mqa_weight.view(torch.bfloat16), 
requires_grad=False)\n\n    kv_a_norm_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.attn_kv_a_norm.weight\")\n    kv_a_norm = kv_a_norm_weight.view(torch.float32)\n    # config.kv_a_norm = kv_a_norm_weight.data_ptr()\n    # config.kv_a_norm_type = type_to_ggml_type(type)\n\n    kv_b_proj_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.attn_kv_b.weight\")\n    kv_b_proj.weight = nn.Parameter(kv_b_proj_weight.view(torch.bfloat16), requires_grad=False)\n\n    o_proj_weight, type = get_torch_tensor_and_type_from_gguf(gguf_weights, f\"blk.{layer_idx}.attn_output.weight\")\n    o_proj.weight = nn.Parameter(o_proj_weight.view(torch.bfloat16), requires_grad=False)\n\nelse:\n    init.normal_(q_a_proj.weight, mean=0.0, std=0.02)\n    init.normal_(q_b_proj.weight, mean=0.0, std=0.02)\n    init.normal_(kv_a_proj_with_mqa.weight, mean=0.0, std=0.02)\n    init.normal_(kv_b_proj.weight, mean=0.0, std=0.02)\n    init.normal_(o_proj.weight, mean=0.0, std=0.02)\n\nx_reshaped = kv_b_proj.weight.view(num_heads, 2, nope_size, kv_lora_rank)\nq_absorb = x_reshaped[:, 0]\nout_absorb = x_reshaped[:, 1]\n\n\nhidden_states = torch.randn((qlen, hidden_size), dtype=input_type).to(\"cpu\").contiguous()\n\n\ndef build_mla():\n    os.environ[\"BLAS_NUM_THREADS\"] = \"1\"\n    q_a_proj_weight = q_a_proj.weight.to(weight_type).to(\"cpu\").contiguous()\n    q_b_proj_weight = q_b_proj.weight.to(weight_type).to(\"cpu\").contiguous()\n    kv_a_proj_with_mqa_weight = kv_a_proj_with_mqa.weight.to(\"cpu\").to(weight_type).contiguous()\n    kv_b_proj_weight = kv_b_proj.weight.to(weight_type).to(\"cpu\").contiguous()\n    o_proj_weight = o_proj.weight.to(weight_type).to(\"cpu\").contiguous()\n\n    config = kt_kernel_ext.mla.MLAConfig(\n        hidden_size,\n        q_lora_rank,\n        kv_lora_rank,\n        num_heads,\n        nope_size,\n        rope_size,\n    )\n    config.max_qlen = max_qlen\n    config.max_kvlen = max_kvlen\n  
  config.max_position_embeddings = max_position_embeddings\n    config.rope_scaling_factor = rope_scaling[\"factor\"]\n    config.rope_theta = rope_theta\n    config.rope_scaling_beta_fast = rope_scaling[\"beta_fast\"]\n    config.rope_scaling_beta_slow = rope_scaling[\"beta_slow\"]\n    config.rope_scaling_mscale = rope_scaling[\"mscale\"]\n    config.rope_scaling_mscale_all_dim = rope_scaling[\"mscale_all_dim\"]\n    config.rope_scaling_original_max_position_embeddings = rope_scaling[\"original_max_position_embeddings\"]\n\n    config.q_a_proj = q_a_proj_weight.data_ptr()\n    config.q_b_proj = q_b_proj_weight.data_ptr()\n    config.kv_a_proj_with_mqa = kv_a_proj_with_mqa_weight.data_ptr()\n    config.kv_b_proj = kv_b_proj_weight.data_ptr()\n    config.o_proj = o_proj_weight.data_ptr()\n\n    config.q_a_norm = q_a_norm.data_ptr()\n    config.q_a_norm_type = ggml_type.FP32\n    config.kv_a_norm = kv_a_norm.data_ptr()\n    config.kv_a_norm_type = ggml_type.FP32\n\n    if weight_type == torch.float32:\n        config.q_a_proj_type = ggml_type.FP32\n        config.q_b_proj_type = ggml_type.FP32\n        config.kv_a_proj_with_mqa_type = ggml_type.FP32\n        config.kv_b_proj_type = ggml_type.FP32\n        config.w_o_type = ggml_type.FP32\n    elif weight_type == torch.float16:\n        config.q_a_proj_type = ggml_type.FP16\n        config.q_b_proj_type = ggml_type.FP16\n        config.kv_a_proj_with_mqa_type = ggml_type.FP16\n        config.kv_b_proj_type = ggml_type.FP16\n        config.w_o_type = ggml_type.FP16\n    elif weight_type == torch.bfloat16:\n        config.q_a_proj_type = ggml_type.BF16\n        config.q_b_proj_type = ggml_type.BF16\n        config.kv_a_proj_with_mqa_type = ggml_type.BF16\n        config.kv_b_proj_type = ggml_type.BF16\n        config.w_o_type = ggml_type.BF16\n    else:\n        raise ValueError(f\"Unsupported data type: {weight_type}\")\n\n    config.pool = CPUInfer.backend_\n\n    if weight_type == torch.float32:\n        mla = 
kt_kernel_ext.mla.MLA_F32(config)\n    elif weight_type == torch.float16:\n        mla = kt_kernel_ext.mla.MLA_F16(config)\n    elif weight_type == torch.bfloat16:\n        mla = kt_kernel_ext.mla.MLA_F32(config)\n    else:\n        raise ValueError(f\"Unsupported data type: {weight_type}\")\n\n    mla.load_weights()\n    mla.set_local_pages(pages_count)\n    return mla\n\n\ndef load_fp32_tensor(file_path, shape):\n    with open(file_path, \"rb\") as f:\n        raw_data = f.read()\n    tensor = torch.frombuffer(raw_data, dtype=torch.float32)\n    tensor = tensor.view(shape)  # reshape to the requested shape\n    return tensor\n\n\n# page3 = load_fp32_tensor('/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug1/query_0_tp_0_page_3_kv_lora_rank_norm.f32',(page_size,kv_lora_rank))\n# page3_2 = load_fp32_tensor('/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug2/query_0_tp_0_page_3_kv_lora_rank_norm.f32',(page_size,kv_lora_rank))\n\n# diff = torch.abs(page3 - page3_2)\n# print(f'Diff: ave:{diff.mean()}, max:{diff.max()}')\n\n# q_pe_1 = load_fp32_tensor('/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug1/query_0_tp_0_q_rope.f32',(1, rope_size))\n# q_pe_2 = load_fp32_tensor('/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug2/query_0_tp_0_q_rope.f32',(qlen, rope_size))\n# diff = torch.abs(q_pe_1 - q_pe_2[-1])\n# print(f'Q PE Diff: ave:{diff.mean()}, max:{diff.max()}')\n\n# q_nope_1 = load_fp32_tensor('/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug1/query_0_tp_0_q_nope.f32',(1, nope_size))\n# q_nope_2 = load_fp32_tensor('/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug2/query_0_tp_0_q_nope.f32',(qlen, nope_size))\n# diff = torch.abs(q_nope_1 - q_nope_2[-1])\n# print(f'Q Nope Diff: ave:{diff.mean()}, max:{diff.max()}')\n\n\n# pe_attn_w_1 = 
load_fp32_tensor('/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug1/query_0_tp_0_pe_attention_weights.f32',(1,max_kvlen))\n# pe_attn_w_2 = load_fp32_tensor('/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug2/query_0_tp_0_pe_attention_weights.f32',(qlen,max_kvlen))\n# diff = torch.abs(pe_attn_w_1 - pe_attn_w_2[-1])\n# print(f'PE Attention Weights Diff: ave:{diff.mean()}, max:{diff.max()}')\n\n\n# raw_attn_w_1 = load_fp32_tensor('/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug1/query_0_tp_0_raw_attention_weights.f32',(1,max_kvlen))\n# raw_attn_w_2 = load_fp32_tensor('/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug2/query_0_tp_0_raw_attention_weights.f32',(qlen,max_kvlen))\n# diff = torch.abs(raw_attn_w_1 - raw_attn_w_2[-1])\n# print(f'Raw Attention Weights Diff: ave:{diff.mean()}, max:{diff.max()}')\n\n\n# output_1 = load_fp32_tensor('/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug1/output.bin.f32',shape=(1, hidden_size))\n# output_2 = load_fp32_tensor('/home/yzw/xwy/Projects/ktransformers-dev/csrc/ktransformers_ext/examples/debug2/output.bin.f32',shape=(qlen, hidden_size))\n\n# diff = torch.abs(output_1 - output_2[-1])\n# print(f'Output Diff: ave:{diff.mean()}, max:{diff.max()}')\n\n\nmla = build_mla()\noutput = torch.zeros((qlen, hidden_size), dtype=input_type).to(\"cpu\").contiguous()\nmla.forward([qlen], [page_table], [kvlen], hidden_states.data_ptr(), output.data_ptr())\nprint(\"CPU MLA Output: \", output[-1])\n\n\noutput_2 = torch.zeros((1, hidden_size), dtype=input_type).to(\"cpu\").contiguous()\nmla.forward([1], [page_table], [qlen - 1], hidden_states[-1].data_ptr(), output_2.data_ptr())\nprint(\"CPU MLA Output 2: \", output_2[-1])\n\ndiff = torch.abs(output[-1] - output_2[-1])\nprint(f\"Diff: ave:{diff.mean()}, max:{diff.max()}\")\nassert diff.max() < 1e-1, \"Prefill and decode outputs are not close enough!\"\n"
  },
  {
    "path": "kt-kernel/examples/test_mla_quant.py",
    "content": "import logging\nimport os, sys\nimport time\nfrom typing import Optional\n\nos.environ[\"BLAS_NUM_THREADS\"] = \"1\"\nsys.path.insert(0, os.path.dirname(__file__) + \"/../build\")\nfrom kt_kernel import kt_kernel_ext\nfrom kt_kernel_ext.kvcache import ggml_type\nimport torch\nfrom torch import inf, nn\nfrom torch.nn import init\nfrom torch_attention import apply_rotary_pos_emb, DeepseekV2RMSNorm, KDeepSeekV3Cache, DeepseekV3YarnRotaryEmbedding\n\nlogger = logging.getLogger(\"reader\")\n\nfrom gguf.gguf_reader import GGUFReader\n\n\ndef load_fp32_tensor_raw(file_path):\n    # return torch.zeros(shape)\n    with open(file_path, \"rb\") as f:\n        raw_data = f.read()\n    tensor = torch.frombuffer(raw_data, dtype=torch.float32)\n    return tensor\n\n\ndef load_fp16_tensor(file_path, shape=None):\n    # Currently reads the dump as a flat fp32 tensor; `shape` is ignored.\n    # return load_fp32_tensor(file_path, shape)\n    return load_fp32_tensor_raw(file_path)\n\n\ndef load_fp32_tensor(file_path, shape):\n    # return torch.zeros(shape)\n    with open(file_path, \"rb\") as f:\n        raw_data = f.read()\n    tensor = torch.frombuffer(raw_data, dtype=torch.float32)\n    tensor = tensor.view(shape)  # reshape to the requested shape\n    return tensor\n\n\ndef test_torch():\n    torch.set_grad_enabled(False)\n\n    hidden_states_to_check_decode = load_fp16_tensor(\"./debug_decode/query_0_tp_0_input.bin\")\n    hidden_states_to_check_prefill = load_fp16_tensor(\"./debug_prefill/query_0_tp_0_input.bin\")\n    # diff = torch.abs(hidden_states_to_check_prefill - hidden_states_to_check_decode).max()\n    # print(\"hidden_states diff -> \", diff)\n\n    q_lora_to_check_decode = load_fp16_tensor(\"./debug_decode/query_0_tp_0_qlora.bin\")\n    q_lora_to_check_test_decode = 
load_fp16_tensor(\"./debug_decode/query_0_tp_0_qlora_test.bin\")\n    q_lora_to_check_prefill = load_fp16_tensor(\"./debug_prefill/query_0_tp_0_qlora.bin\")\n    q_lora_to_check_test_prefill = load_fp16_tensor(\"./debug_prefill/query_0_tp_0_qlora_test.bin\")\n    # diff = torch.abs(q_lora_to_check_prefill - q_lora_to_check_decode).max()\n    # diff_test = torch.abs(q_lora_to_check_prefill - q_lora_to_check_decode).max()\n    # print(\"q_lora max diff -> \", diff)\n    # print(\"q_lora max diff test -> \", diff_test)\n    # mae =  torch.mean(torch.abs(q_lora_to_check_prefill - q_lora_to_check_decode))\n    # mae_test =  torch.mean(torch.abs(q_lora_to_check_prefill - q_lora_to_check_decode))\n    # print(\"q_lora mae -> \", mae)\n    # print(\"q_lora mae test -> \", mae_test)\n\n    # q_lora_norm = q_a_layernorm(q_lora)\n    # q_lora_norm_to_check = load_fp16_tensor('./debug/query_0_tp_0_qlora_norm.bin', q_lora_norm.shape)\n    # q_lora_norm_to_check_test = load_fp16_tensor('./debug/query_0_tp_0_qlora_norm_test.bin', q_lora_norm.shape)\n    # diff = torch.abs(q_lora_norm - q_lora_norm_to_check).max()\n    # mae =  torch.mean(torch.abs(q_lora_norm - q_lora_norm_to_check))\n    # diff_test = torch.abs(q_lora_norm - q_lora_norm_to_check_test).max()\n    # mae_test =  torch.mean(torch.abs(q_lora_norm - q_lora_norm_to_check_test))\n    # print(\"q_lora_norm diff -> \", diff)\n    # print(\"q_lora_norm mae -> \", mae)\n    # print(\"q_lora_norm diff test -> \", diff_test)\n    # print(\"q_lora_norm mae test -> \", mae_test)\n\n    # q = q_b_proj(q_lora_norm)\n    # for v3, bsz, qlen, num_heads(128), qk_head_dim(192=128(nope)+64(rope))\n    # q = q.view(qlen, num_heads, nope_size+rope_size)\n    # q_nope is [qlen, num_heads(128), qk_nope_head_dim(128)]\n    # q_pe is [qlen, num_heads(128), qk_rope_head_dim(64)]\n    # q_nope, q_pe = torch.split(\n    #     q, [nope_size, rope_size], dim=-1\n    # )\n\n    # compressed_kv is [qlen, kv_lora_rank(512) + rope(64)]\n    # 
compressed_kv = kv_a_proj_with_mqa(batch_hidden_states)\n    # compressed_kv is [qlen, kv_lora_rank(512)], k_pe is [qlen, rope(64)]\n    # compressed_kv, k_pe = torch.split(\n    #     compressed_kv, [kv_lora_rank, rope_size], dim=-1\n    # )\n    # compressed_kv = compressed_kv.contiguous()\n\n    # compressed_kv_page_0 = compressed_kv[0:page_size, :]\n    compressed_kv_to_check_decode = load_fp16_tensor(\"./debug_decode/query_0_tp_0_page_0_kv_lora_rank\")\n    compressed_kv_to_check_prefill = load_fp16_tensor(\"./debug_prefill/query_0_tp_0_page_0_kv_lora_rank\")\n    # diff = torch.abs(compressed_kv_to_check_prefill - compressed_kv_to_check_decode).max()\n    # mae =  torch.mean(torch.abs(compressed_kv_to_check_prefill - compressed_kv_to_check_decode))\n    # print(\"compressed_kv diff -> \", diff)\n    # print(\"compressed_kv mae -> \", mae)\n\n    # compressed_kv = kv_a_layernorm(compressed_kv)\n    # k_pe is [qlen, 1, qk_rope_head_dim(64)]\n\n    # compressed_kv_page_0 = compressed_kv[0:page_size, :]\n    compressed_kv_to_check_decode = load_fp16_tensor(\"./debug_decode/query_0_tp_0_page_0_kv_lora_rank_norm\")\n    compressed_kv_to_check_prefill = load_fp16_tensor(\"./debug_prefill/query_0_tp_0_page_0_kv_lora_rank_norm\")\n    # diff = torch.abs(compressed_kv_page_0 - compressed_kv_to_check).max()\n    # mae =  torch.mean(torch.abs(compressed_kv_page_0 - compressed_kv_to_check))\n    # print(\"compressed_kv diff norm -> \", diff)\n    # print(\"compressed_kv mae norm -> \", mae)\n\n    # k_pe = k_pe.view(qlen, 1, rope_size)\n    # compressed_kv is [qlen, 1, kv_lora_rank(512)]\n    # compressed_kv = compressed_kv.view(qlen, 1, kv_lora_rank)\n\n    # cos, sin = rotary_emb(q_pe, batch_position_ids)\n\n    # q_nope_check = q_nope.transpose(0, 1) # qlen is 1, no GPU overhead, same below\n\n    # q_nope_0_to_check = load_fp16_tensor('./debug/query_0_tp_0_q_nope', q_nope_check[0].shape)\n    # q_nope_0_to_check_test = 
load_fp16_tensor('./debug/query_0_tp_0_q_nope_test', q_nope_check[0].shape)\n    # diff = torch.abs(q_nope_check[0] - q_nope_0_to_check).max()\n    # mae =  torch.mean(torch.abs(q_nope_check[0] - q_nope_0_to_check))\n    # diff_test = torch.abs(q_nope_check[0] - q_nope_0_to_check_test).max()\n    # mae_test =  torch.mean(torch.abs(q_nope_check[0] - q_nope_0_to_check_test))\n    # print(\"q_nope[0] diff -> \", diff)\n    # print(\"q_nope[0] mae -> \", mae)\n    # print(\"q_nope[0] diff test -> \", diff_test)\n    # print(\"q_nope[0] mae test -> \", mae_test)\n\n    # q_pe_nope = q_pe.transpose(0,1)\n    q_pe_0_to_check_decode = load_fp16_tensor(\"./debug_decode/query_0_tp_0_q_rope\")\n    q_pe_0_to_check_prefill = load_fp16_tensor(\"./debug_prefill/query_0_tp_0_q_rope\")\n\n    # q_pe_0_to_check_decode_test = load_fp16_tensor('./debug_decode/query_0_tp_0_q_rope_test')\n    # q_pe_0_to_check_prefill_test = load_fp16_tensor('./debug_prefill/query_0_tp_0_q_rope_test')\n\n    # q_pe_0_to_check = load_fp16_tensor('./debug/query_0_tp_0_q_rope_no_rope', q_pe_nope[0].shape)\n    # q_pe_0_to_check_test = load_fp16_tensor('./debug/query_0_tp_0_q_rope_no_rope_test', q_pe_nope[0].shape)\n    # diff = torch.abs(q_pe_nope[0] - q_pe_0_to_check).max()\n    # mae =  torch.mean(torch.abs(q_pe_nope[0] - q_pe_0_to_check))\n    # diff_test = torch.abs(q_pe_nope[0] - q_pe_0_to_check_test).max()\n    # mae_test =  torch.mean(torch.abs(q_pe_nope[0] - q_pe_0_to_check_test))\n    # print(\"q_pe nope[0] diff -> \", diff)\n    # print(\"q_pe nope[0] mae -> \", mae)\n    # print(\"q_pe nope[0] diff test -> \", diff_test)\n    # print(\"q_pe nope[0] mae test -> \", mae_test)\n\n    # cos_to_check = load_fp32_tensor('./debug/query_0_tp_0_rope_cos', (qlen,32))\n    # diff = torch.abs(cos[:,:32]-cos_to_check).max()\n    # mae =  torch.mean(torch.abs(cos[:,:32]-cos_to_check))\n    # print(\"cos diff -> \", diff)\n    # print(\"cos mae -> \", mae)\n    # sin_to_check = 
load_fp32_tensor('./debug/query_0_tp_0_rope_sin', (qlen,32))\n    # diff = torch.abs(sin[:,:32]-sin_to_check).max()\n    # mae =  torch.mean(torch.abs(sin[:,:32]-sin_to_check))\n    # print(\"sin diff -> \", diff)\n    # print(\"sin mae -> \", mae)\n\n    # new_q_pe = q_pe.transpose(0, 1)\n    # qa = new_q_pe[:,:,range(0,64,2)]\n    # qb = new_q_pe[:,:,range(1,65,2)]\n    # q1 = (qa * cos[:,:32] - qb * sin[:,:32])\n    # q2 = (qb*cos[:,:32] + qa*sin[:,:32])\n    # q1 = (qa * cos_to_check - qb * sin_to_check)\n    # q2 = (qb*cos_to_check + qa*sin_to_check)\n    # q_new = torch.cat((q1,q2), dim=-1)\n    # print(f\"q_pe shape{q_pe.shape}, k_pe shape {k_pe.shape}\")\n    # new_q_pe = torch.zeros_like(q_pe)\n    # new_q_pe[:,:,range(0,64,2)] = 1\n    # new_q_pe[:,:,range(1,65,2)] = 10\n    # q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=1)\n    # q_pe = q_pe.squeeze(0)\n    # q_pe is [num_heads(128), qlen, qk_rope_head_dim(64)]\n    # q_pe.transpose_(0, 1)\n\n    # diff = torch.abs(q_pe - q_new).max()\n    # print(\"q_pe diff -> \", diff)\n\n    # q_pe_0_to_check = load_fp16_tensor('./debug/query_0_tp_0_q_rope', q_pe[0].shape)\n    # diff = torch.abs(q_pe[0] - q_pe_0_to_check).max()\n    # mae =  torch.mean(torch.abs(q_pe[0] - q_pe_0_to_check))\n    # print(\"q_pe[0] diff -> \", diff)\n    # print(\"q_pe[0] mae -> \", mae)\n\n    # diff = torch.abs(q_pe_0_to_check - q_new[0]).max()\n    # mae =  torch.mean(torch.abs(q_pe_0_to_check - q_new[0]))\n    # print(\"q_pe[0] 2  diff -> \", diff)\n    # print(\"q_pe[0] 2 mae -> \", mae)\n\n    # if kv_cache is not None:\n    #     cache_kwargs = {\"sin\": sin, \"cos\": cos, \"page_idx\": batch_page_idx, \"page_offset\": batch_page_offset}  # Specific to RoPE models\n    #     compressed_kv_with_k_pe = kv_cache.update(compressed_kv.unsqueeze(0), k_pe, layer_idx, batch_page_idx, batch_page_offset, cache_kwargs)\n    #     compressed_kv = compressed_kv_with_k_pe [:, :, :, 
:kv_lora_rank].view(-1, page_size, kv_lora_rank)\n    #     k_pe = compressed_kv_with_k_pe [:, :, :, kv_lora_rank:].view(-1, page_size, rope_size)\n    # # q_absorb is [num_heads(128), qk_nope_head_dim(128), kv_lora_rank(512)]\n    # # out_absorb is [num_heads(128), kv_lora_rank(512), v_head_dim(128)] v_head_dim is also the nope dim\n    # # q_absorb, out_absorb = get_absorbed()\n    # # q_nope is [num_heads(128), qlen, qk_nope_head_dim(128)]\n    # q_nope = q_nope.transpose(0, 1) # qlen is 1, no GPU overhead, same below\n\n    # q_nope_0_to_check = load_fp16_tensor('./debug/query_0_tp_0_q_nope', q_nope[0].shape)\n    # diff = torch.abs(q_nope[0] - q_nope_0_to_check).max()\n    # mae =  torch.mean(torch.abs(q_nope[0] - q_nope_0_to_check))\n    # print(\"q_nope[0] diff -> \", diff)\n\n    # # q_nope is [num_heads(128), qlen, kv_lora_rank(512)]\n    # q_nope = torch.matmul(q_nope, q_absorb) # batched MM\n\n    # k_b_proj_check = load_fp16_tensor('./debug/query_0_tp_0_k_b_lora', (nope_size,kv_lora_rank))\n    # diff = torch.abs(q_absorb[0] - k_b_proj_check).max()\n    # print(\"kv b lora weight[0] diff -> \", diff)\n\n    # q_absorb_check = load_fp16_tensor('./debug/query_0_tp_0_q_absorb', (kv_lora_rank,1024))\n    # q_absorb_check = q_absorb_check[:,0:qlen].transpose(0,1)\n    # diff = torch.abs(q_nope[0] - q_absorb_check).max()\n    # mae =  torch.mean(torch.abs(q_nope[0] - q_absorb_check))\n    # print(\"q_nope absorb diff -> \", diff)\n    # print(\"q_nope absorb mae -> \", mae)\n\n    # # q_nope is [qlen, num_heads(128), kv_lora_rank(512)]\n    # q_nope = q_nope.transpose(0, 1)\n\n    # we need to index out the compressed_kv and k_pe for the current batch\n    # batch_compressed_kv = None\n    # batch_k_pe = None\n    # for page_index in kv_index:\n    #     if kv_total_len > page_size:\n    #         tmp_compressed_kv = compressed_kv[page_index, 0:page_size, :]\n    #         tmp_k_pe = k_pe[page_index, 0:page_size, :]\n    #         if batch_compressed_kv is 
None or batch_k_pe is None:\n    #             batch_compressed_kv = tmp_compressed_kv\n    #             batch_k_pe = tmp_k_pe\n    #         else:\n    #             batch_compressed_kv = torch.cat((batch_compressed_kv, tmp_compressed_kv), dim=0)\n    #             batch_k_pe = torch.cat((batch_k_pe, tmp_k_pe), dim=0)\n    #         kv_total_len -= page_size\n    #     else:\n    #         tmp_compressed_kv = compressed_kv[page_index, 0:kv_total_len, :]\n    #         tmp_k_pe = k_pe[page_index, 0:kv_total_len, :]\n    #         if batch_compressed_kv is None or batch_k_pe is None:\n    #             batch_compressed_kv = tmp_compressed_kv\n    #             batch_k_pe = tmp_k_pe\n    #         else:\n    #             batch_compressed_kv = torch.cat((batch_compressed_kv, tmp_compressed_kv), dim=0)\n    #             batch_k_pe = torch.cat((batch_k_pe, tmp_k_pe), dim=0)\n    #         break\n    # batch_compressed_kv is [kv_total_len(k_len), kv_lora_rank(512)]\n    # batch_k_pe is [kv_total_len(k_len), qk_rope_head_dim(64)]\n\n    k_pe_to_check_decode = load_fp16_tensor(\"./debug_decode/query_0_tp_0_page_0_k_rope\", (256, 64))\n    k_pe_to_check_prefill = load_fp16_tensor(\"./debug_prefill/query_0_tp_0_page_0_k_rope\", (256, 64))\n    # diff = torch.abs(k_pe_to_check_prefill - k_pe_to_check_decode).max()\n    # mae =  torch.mean(k_pe_to_check_prefill - k_pe_to_check_decode)\n    # print(\"k_pe diff -> \", diff)\n    # print(\"k_pe mae -> \", mae)\n\n    # pe_weights = torch.matmul(q_pe,batch_k_pe.mT)\n    # kv_total_len = kv_page_nums * page_size\n    pe_weights_0_decode = load_fp16_tensor(\"./debug_decode/query_0_tp_0_pe_attention_weights\", (1024, 4096))\n    pe_weights_0_prefill = load_fp16_tensor(\"./debug_prefill/query_0_tp_0_pe_attention_weights\", (1024, 4096))\n\n    # diff = torch.abs(pe_weights[0] - pe_weights_0).max()\n    # print(\"pe_weights[0] diff -> \", diff)\n\n    # attention_weights = (pe_weights + torch.matmul(q_nope, 
batch_compressed_kv.mT))\n\n    # raw_weights = load_fp16_tensor('./debug/query_0_tp_0_raw_attention_weights', (1024, 4096))\n    # raw_weights = raw_weights[0:qlen, 0:kv_total_len]\n    # diff = torch.abs(attention_weights[0] - raw_weights).max()\n    # print(\"raw attention_weights[0] diff -> \", diff)\n\n    # attention_weights = attention_weights * softmax_scale\n    # attention_weights is [num_heads(128), qlen, k_len]\n\n    # attention_weights = attention_weights.transpose(0,1).unsqueeze(0).squeeze(-1).expand(qlen,-1,-1).transpose(0,1)\n\n    # attention_masks[i] is [qlen, k_len]\n\n    # attention_weights = (attention_weights + attention_masks)\n    # attention_weights shape is [num_heads(128), qlen, k_len]\n\n    # attention_weights = nn.functional.softmax(attention_weights,dim=-1,dtype=weight_type).to(q_pe.dtype)\n\n    attention_weights_0_decode = load_fp16_tensor(\"./debug_decode/query_0_tp_0_attention_weights\", (1024, 4096))\n    attention_weights_0_prefill = load_fp16_tensor(\"./debug_prefill/query_0_tp_0_attention_weights\", (1024, 4096))\n\n    # attention_weights_0 = attention_weights_0[0:qlen, 0:kv_total_len]\n    # diff = torch.abs(attention_weights[0] - attention_weights_0).max()\n    # print(\"attention_weights[0] diff -> \", diff)\n\n    # attn_output = torch.matmul(attention_weights, batch_compressed_kv) # [num_heads(128),qlen, lora_rank(512)]\n    # out_absorb shape is [num_heads(128), kv_lora_rank(512), v_head_dim(128)]\n\n    # o_absorb_check = load_fp16_tensor('./debug/query_0_tp_0_o_absorb', (qlen,kv_lora_rank))\n    # diff = torch.abs(attn_output[0] - o_absorb_check).max()\n    # print(\"o absorb[0] diff -> \", diff)\n\n    # out_absorb = out_absorb.transpose(1, 2) # [qlen, num_heads(128), v_head_dim(128)]\n    # # q for qlen, n for num_heads, h for v_head_dim, v for kv_lora_rank\n    # attn_output = torch.matmul(attn_output, out_absorb) # [num_heads(128), qlen, v_head_dim(128)]\n\n    # attn_output_check_0 = 
load_fp16_tensor('./debug/query_0_tp_0_attention_output', (qlen, nope_size))\n    # diff = torch.abs(attn_output[0] - attn_output_check_0).max()\n    # print(\"attn_output[0] diff -> \", diff)\n\n    # attn_output = attn_output.transpose(0, 1) # [qlen, num_heads(128), v_head_dim(128)]\n    # attn_output = attn_output.reshape(qlen, num_heads * nope_size)\n\n    # w_o = o_proj.weight.view([hidden_size,num_heads * nope_size])\n    # output = torch.matmul(attn_output,w_o.transpose(0,1))\n    # output = output.view(qlen, hidden_size)\n\n    # output_0_check = load_fp16_tensor('./debug/query_0_tp_0_qlen_output', (qlen, hidden_size))\n    # h1_o = w_o[:,:128]\n    # local_o_check = load_fp16_tensor('./debug/query_0_tp_0_local_w_o', (hidden_size, 128))\n    # diff = torch.abs(local_o_check - h1_o).max()\n    # print(\"local w_o diff -> \", diff)\n\n    # h1_output = torch.matmul(attn_output[:,:128],h1_o.transpose(0,1))\n    # diff = torch.abs(h1_output - output_0_check).max()\n    # print(\"h1_output diff -> \", diff)\n\n    output_check_decode = load_fp16_tensor(\"./debug_decode/output.bin\")\n    output_check_prefill = load_fp16_tensor(\"./debug_prefill/output.bin\")\n    # diff = torch.abs(output - output_check).max()\n    # mae =   torch.mean(torch.abs(output - output_check))\n    # print(\"output diff -> \", diff)\n\n    return None\n\n\ntorch.set_printoptions(sci_mode=False, precision=5)\n# output_cpu = test_cpu_mla()\n# output_cpu_quant = test_cpu_mla_quant()\noutput_torch = test_torch()\n# print(\"Output CPU: \", output_cpu)\n# print(\"Output CPU: \", output_cpu_quant)\n# print(\"Output Torch: \", output_torch)\n# diff = (output_cpu - output_torch).abs()\n# # Compute the relative error\n# diff_relative = diff / (output_cpu.abs())\n# # Replace NaN values in diff_relative with 0\n# diff_relative = torch.where(torch.isnan(diff_relative), torch.zeros_like(diff_relative), diff_relative)\n# diff_relative_mean = torch.mean(torch.abs(output_cpu-output_torch)) / torch.mean(torch.abs(output_torch))\n\n# 
print(f'Diff: ave:{diff.mean()}, max:{diff.max()}, min:{diff.min()},  relative_mean:{diff_relative_mean}, relative_max:{diff_relative.max()}, relative_min:{diff_relative.min()}')\n# assert diff_relative_mean < 2e-1, \"CPU and Torch outputs are not close enough!\"\n"
  },
  {
    "path": "kt-kernel/examples/test_mla_simple.py",
    "content": "import math\nimport random\nimport os, sys\nimport time\nimport subprocess\nimport platform\nimport json\nfrom typing import Any, Dict, Optional, Tuple\nimport numpy as np\nimport torch.nn.init as init\nfrom torch_attention import apply_rotary_pos_emb,DeepseekV2RMSNorm,KDeepSeekV3Cache,DeepseekV3YarnRotaryEmbedding\n\nimport torch\nfrom tqdm import tqdm\nfrom torch import nn\n\"\"\"\n\"rope_scaling\": {\n    \"beta_fast\": 32,\n    \"beta_slow\": 1,\n    \"factor\": 40,\n    \"mscale\": 1.0,\n    \"mscale_all_dim\": 1.0,\n    \"original_max_position_embeddings\": 4096,\n    \"type\": \"yarn\"\n  },\n\"\"\" \n\nrope_scaling = {\n    \"beta_fast\": 32,\n    \"beta_slow\": 1,\n    \"factor\": 40,\n    \"mscale\": 1.0,\n    \"mscale_all_dim\": 1.0,\n    \"original_max_position_embeddings\": 4096,\n    \"type\": \"yarn\"\n}\nseed = 42  # 你可以选择任何整数作为种子\ntorch.manual_seed(seed)\ntorch.cuda.manual_seed_all(seed)\nnp.random.seed(seed)\nrandom.seed(seed)\n\n# \"rope_theta\": 10000\nrope_theta = 10000\n\n\nhidden_size = 7168\nnum_heads = 128\nkv_lora_rank = 512\nq_lora_rank = 512\nnope_size = 128\nrope_size = 64\n\n# page 的个数\npage_nums = 10\npage_size = 512\nlayer_num = 10\nmax_position_embeddings =  163840\n\n\nwarm_up_iter = 1000\ntest_iter = 1000\n\nq_len = 200\nhis_kv_len = 128\n\nbsz_tensors=torch.tensor([1])\n\nsoftmax_scale = (nope_size + rope_size) ** -0.5\n# 1代表的是压缩的kv的头数\nk_caches = torch.randn(layer_num,page_nums, page_size,1, kv_lora_rank + rope_size).to(torch.float16)\nkv_cache = KDeepSeekV3Cache(page_size=page_size, kv_lora_rank=kv_lora_rank, k_caches=k_caches)\n\nq_a_layernorm = DeepseekV2RMSNorm(q_lora_rank)\n\nx = torch.randn(q_lora_rank, dtype=torch.float16)*100\nprint(x)\nprint(q_a_layernorm(x))\n\nkv_a_layernorm = DeepseekV2RMSNorm(kv_lora_rank)\n\nq_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=False, dtype=torch.float16)\nq_b_proj = nn.Linear(q_lora_rank, num_heads * (nope_size+rope_size) , bias=False, 
dtype=torch.float16)\nkv_a_proj_with_mqa = nn.Linear(hidden_size, kv_lora_rank + rope_size, bias=False, dtype=torch.float16)\nkv_b_proj = nn.Linear(kv_lora_rank, num_heads * (nope_size + nope_size), bias=False, dtype=torch.float16)\no_proj = nn.Linear(num_heads * nope_size, hidden_size, bias=False, dtype=torch.float16)\n\ninit.normal_(q_a_proj.weight, mean=0.0, std=0.02)\ninit.normal_(q_b_proj.weight, mean=0.0, std=0.02)\ninit.normal_(kv_a_proj_with_mqa.weight, mean=0.0, std=0.02)\ninit.normal_(kv_b_proj.weight, mean=0.0, std=0.02)\ninit.normal_(o_proj.weight, mean=0.0, std=0.02)\n# # The weight initialization below is for testing\n# # Set all weights to 1\n# with torch.no_grad():\n#     q_a_proj.weight.fill_(1.0)\n#     q_b_proj.weight.fill_(1.0)\n#     kv_a_proj_with_mqa.weight.fill_(1.0)\n#     kv_b_proj.weight.fill_(1.0)\n#     o_proj.weight.fill_(1.0)\n\nq_absorb = torch.randn(num_heads, nope_size, kv_lora_rank, dtype=torch.float16)\nout_absorb = torch.randn(num_heads, nope_size, kv_lora_rank, dtype=torch.float16)\n\nrotary_emb = DeepseekV3YarnRotaryEmbedding(\n    rope_size,\n    max_position_embeddings=max_position_embeddings,\n    scaling_factor=rope_scaling[\"factor\"],\n    base=rope_theta,\n    beta_fast=rope_scaling[\"beta_fast\"],\n    beta_slow=rope_scaling[\"beta_slow\"],\n    mscale=rope_scaling[\"mscale\"],\n    mscale_all_dim=rope_scaling[\"mscale_all_dim\"],\n    original_max_position_embeddings=rope_scaling[\"original_max_position_embeddings\"],\n)\n# Build an input hidden_states of length q_len; the corresponding history kv_indptr is [0:bsz]\n# kv_indices is [0:bsz], page_idx=[0:bsz], page_offset=[his_kv_len:q_len+his_kv_len]\n# last_page_len = [q_len+his_kv_len,...] 
layer_idx = 1\n# position_ids = [his_kv_len:q_len+his_kv_len]\nhidden_states = torch.randn(q_len, hidden_size, dtype=torch.float16)\nq_indptr = torch.tensor([0,q_len]).to(torch.int32)\nkv_indptr = torch.tensor(range(0, bsz_tensors[0] + 1)).to(torch.int32)\nkv_indices = torch.tensor(range(0, bsz_tensors[0])).to(torch.int32)\npage_idx = torch.tensor(range(0, bsz_tensors[0])).to(torch.int32)\npage_offset = torch.tensor(range(his_kv_len, his_kv_len + q_len)).to(torch.int32)\nlast_page_len = torch.tensor([q_len+his_kv_len]*bsz_tensors[0], device=hidden_states.device)\nposition_ids = torch.tensor(range(his_kv_len, his_kv_len + q_len)).to(torch.int32)\n\n\n# Build the causal mask row by row: [q_len, his_kv_len+q_len]\nattention_masks = torch.zeros((q_len, his_kv_len + q_len), dtype=torch.float16)\nfor i in range(q_len):\n    attention_masks[i, i + his_kv_len + 1: i + his_kv_len + q_len] = -65504.0\n\n\ndef torch_attn(hidden_states: torch.Tensor,\n                kv_cache: KDeepSeekV3Cache,\n                position_ids: torch.Tensor,\n                page_idx: torch.Tensor,\n                page_offset: torch.Tensor,\n                attention_masks: Optional[list[torch.Tensor]] = None,\n                q_indptr: Optional[torch.Tensor] = None,\n                kv_indices: Optional[torch.Tensor] = None,\n                kv_indptr: Optional[torch.Tensor] = None,\n                bsz_tensors: Optional[torch.Tensor] = None,\n                last_page_len: Optional[torch.Tensor] = None,\n                layer_idx: Optional[int] = None,\n                ):\n    global out_absorb\n    global q_absorb\n    # range bsz_tensors\n    final_attention_output = torch.tensor([], device=hidden_states.device)\n    for i in range(bsz_tensors[0]):\n        print(\"page_idx\", page_idx)\n        print(\"page_offset\", page_offset)\n        print(\"q_indptr\", q_indptr)\n        print(\"kv_indices\", kv_indices)\n        print(\"kv_indptr\", kv_indptr)\n\n        batch_num_tokens_tensors = q_indptr[i+1] - 
q_indptr[i]\n        batch_last_page_len = last_page_len[i]\n        # kv_total_len is kv_len, batch_compressed_kv is compressed_kv, batch_k_pe is k_pe\n        batch_page_idx = page_idx[q_indptr[i]:q_indptr[i+1]]\n        print('batch_page_idx',batch_page_idx)\n        batch_page_offset = page_offset[q_indptr[i]:q_indptr[i+1]]\n        # kv_page_nums is the number of pages for the current batch\n        kv_page_nums = kv_indptr[i+1] - kv_indptr[i]\n        # kv_total_len is the total length of the kv cache for the current batch (kv_len for algorithm)\n        kv_total_len = kv_page_nums * page_size\n        if batch_last_page_len is not None:\n            kv_total_len = kv_total_len - (page_size - batch_last_page_len)\n        # print(f\"kv_total_len's shape {kv_total_len.shape}\")\n        # kv_index is the index of the kv cache pages for the current batch\n        kv_index = kv_indices[kv_indptr[i]:kv_indptr[i+1]]\n        # we can index [kv_index, page_offset_indices] to get the kv cache for the current batch\n        # from q_indptr[i] to q_indptr[i+1] is the range of the current batch\n        batch_hidden_states = hidden_states[q_indptr[i]:q_indptr[i+1]]\n        batch_position_ids = position_ids[q_indptr[i]:q_indptr[i+1]]\n        q_len, _ = batch_hidden_states.size()\n        # print(\"q_len -> \", q_len)\n        q_lora = q_a_proj(batch_hidden_states)\n        print('q_a_proj',q_a_proj.weight)\n        print('q_lora',q_lora)\n        \n        q = q_b_proj(q_a_layernorm(q_lora))\n        print('q_b_proj',q_b_proj.weight)\n        # for v3, bsz, q_len, num_heads(128), qk_head_dim(192=128(nope)+64(rope))\n        q = q.view(q_len, num_heads, nope_size+rope_size)\n        # q_nope is [q_len, num_heads(128), qk_nope_head_dim(128)]\n        # q_pe is [q_len, num_heads(128), qk_rope_head_dim(64)]\n        q_nope, q_pe = torch.split(\n            q, [nope_size, rope_size], dim=-1\n        )\n        print('q_nope',q_nope)\n        print('q_pe',q_pe)\n        # 
compressed_kv is [q_len, kv_lora_rank(512) + rope(64)]\n        compressed_kv = kv_a_proj_with_mqa(batch_hidden_states)\n        # compressed_kv is [q_len, kv_lora_rank(512)], k_pe is [q_len, rope(64)]\n        compressed_kv, k_pe = torch.split(\n            compressed_kv, [kv_lora_rank, rope_size], dim=-1\n        )\n        compressed_kv = compressed_kv.contiguous()\n        compressed_kv = kv_a_layernorm(compressed_kv)\n        # k_pe is [q_len, 1, qk_rope_head_dim(64)]\n        print('compressed_kv ',compressed_kv)\n        print('k_pe ',k_pe)\n        k_pe = k_pe.view(q_len, 1, rope_size)\n        # compressed_kv is [q_len, 1, kv_lora_rank(512)]\n        compressed_kv = compressed_kv.view(q_len, 1, kv_lora_rank)\n        \n        cos, sin = rotary_emb(q_pe, batch_position_ids)\n        # print(f\"q_pe shape{q_pe.shape}, k_pe shape {k_pe.shape}\")\n        q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=1)\n        q_pe = q_pe.squeeze(0)\n        # q_pe is [num_heads(128), q_len, qk_rope_head_dim(64)]\n        q_pe.transpose_(0, 1)            \n        if kv_cache is not None:\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"page_idx\": batch_page_idx, \"page_offset\": batch_page_offset}  # Specific to RoPE models\n            compressed_kv_with_k_pe = kv_cache.update(compressed_kv.unsqueeze(0), k_pe, layer_idx, batch_page_idx, batch_page_offset, cache_kwargs)\n            compressed_kv = compressed_kv_with_k_pe [:, :, :, :kv_lora_rank].view(-1, page_size, kv_lora_rank)\n            k_pe = compressed_kv_with_k_pe [:, :, :, kv_lora_rank:].view(-1, page_size, rope_size)\n        # q_absorb is [num_heads(128), qk_nope_head_dim(128), kv_lora_rank(512)]\n        # out_absorb is [num_heads(128), kv_lora_rank(512), v_head_dim(128)] v_head_dim is also the nope dim\n        # q_absorb, out_absorb = get_absorbed()\n        # q_nope is [num_heads(128), q_len, qk_nope_head_dim(128)]\n        q_nope = 
q_nope.transpose(0, 1) # q_len is 1, no GPU overhead, same below\n        # q_nope is [num_heads(128), q_len, kv_lora_rank(512)]\n        q_nope = torch.matmul(q_nope, q_absorb) # batched MM\n\n        # # q_nope is [q_len, num_heads(128), kv_lora_rank(512)]\n        # q_nope = q_nope.transpose(0, 1)\n\n        # we need to index out the compressed_kv and k_pe for the current batch\n        batch_compressed_kv = None\n        batch_k_pe = None\n        for page_index in kv_index:\n            if kv_total_len > page_size:\n                tmp_compressed_kv = compressed_kv[page_index, 0:page_size, :]\n                tmp_k_pe = k_pe[page_index, 0:page_size, :]\n                if batch_compressed_kv is None or batch_k_pe is None:\n                    batch_compressed_kv = tmp_compressed_kv\n                    batch_k_pe = tmp_k_pe\n                else: \n                    batch_compressed_kv = torch.cat((batch_compressed_kv, tmp_compressed_kv), dim=0)\n                    batch_k_pe = torch.cat((batch_k_pe, tmp_k_pe), dim=0)\n                kv_total_len -= page_size\n            else:\n                tmp_compressed_kv = compressed_kv[page_index, 0:kv_total_len, :]\n                tmp_k_pe = k_pe[page_index, 0:kv_total_len, :]\n                if batch_compressed_kv is None or batch_k_pe is None:\n                    batch_compressed_kv = tmp_compressed_kv\n                    batch_k_pe = tmp_k_pe\n                else: \n                    batch_compressed_kv = torch.cat((batch_compressed_kv, tmp_compressed_kv), dim=0)\n                    batch_k_pe = torch.cat((batch_k_pe, tmp_k_pe), dim=0)\n                break\n        # batch_compressed_kv is [kv_total_len(k_len), kv_lora_rank(512)]\n        # batch_k_pe is [kv_total_len(k_len), qk_rope_head_dim(64)]\n        pe_weights = torch.matmul(q_pe,batch_k_pe.mT)\n        print('pe_weights',pe_weights)\n        attention_weights = (pe_weights + torch.matmul(q_nope, batch_compressed_kv.mT)) * softmax_scale\n     
   # attention_weights is [num_heads(128), q_len, k_len]\n        \n        # attention_weights = attention_weights.transpose(0,1).unsqueeze(0).squeeze(-1).expand(q_len,-1,-1).transpose(0,1)\n        \n        # attention_masks[i] is [q_len, k_len]\n        \n        attention_weights = (attention_weights + attention_masks[i])\n        # attention_weights shape is [num_heads(128), q_len, k_len]\n        attention_weights = nn.functional.softmax(attention_weights,dim=-1,dtype=torch.float16).to(q_pe.dtype)\n        attn_output = torch.matmul(attention_weights, batch_compressed_kv) # [num_heads(128),q_len, lora_rank(512)]\n        # out_absorb shape is [num_heads(128), kv_lora_rank(512), v_head_dim(128)]\n        # use a local transposed view so the global out_absorb is not flipped on every batch iteration\n        out_absorb_t = out_absorb.transpose(1,2)\n        # q for q_len, n for num_heads, h for v_head_dim, v for kv_lora_rank\n        attn_output = torch.matmul(attn_output, out_absorb_t) # [num_heads(128), q_len, v_head_dim(128)]\n        attn_output = attn_output.transpose(0, 1) # [q_len, num_heads(128), v_head_dim(128)]\n        attn_output = attn_output.reshape(q_len, num_heads * nope_size)\n        attn_output = o_proj(attn_output)\n        final_attention_output = torch.cat((final_attention_output, attn_output), dim=0)\n    return final_attention_output\n\n\n\ndef torch_attn_for_test(hidden_states,kv_cache,):\n    pass\n\ndef test_mla_simple():\n    result = torch_attn(\n        hidden_states,\n        kv_cache,\n        position_ids,\n        page_idx,\n        page_offset,\n        # one [q_len, k_len] mask per batch entry; torch_attn indexes attention_masks[i]\n        attention_masks=[attention_masks],\n        q_indptr=q_indptr,\n        kv_indices=kv_indices,\n        kv_indptr=kv_indptr,\n        bsz_tensors=bsz_tensors,\n        last_page_len=last_page_len,\n        layer_idx=1\n    )\n    print(result.shape)\n    print(result)\n    \ntest_mla_simple()"
  },
  {
    "path": "kt-kernel/examples/test_mla_torch.py",
    "content": "import os, sys\nimport time\nfrom typing import Optional\n\nsys.path.insert(0, os.path.dirname(__file__) + \"/../build\")\nfrom kt_kernel import kt_kernel_ext\nfrom kt_kernel_ext.kvcache import ggml_type\nimport torch\nfrom torch import nn\nfrom torch.nn import init\nfrom torch_attention import apply_rotary_pos_emb, DeepseekV2RMSNorm, KDeepSeekV3Cache, DeepseekV3YarnRotaryEmbedding\n\n\nseed = 42  # 你可以选择任何整数作为种子\ntorch.manual_seed(seed)\ntorch.cuda.manual_seed_all(seed)\n\nqlen = 1024\nkvlen = 0\n\n\npage_table = range(20)\nbsz_tensors = torch.tensor([1])\n\n\npage_size = 256\npages_count = 200\ntp_count = 4\n\n\nhidden_size = 7168\nq_lora_rank = 1536\nkv_lora_rank = 512\nnum_heads = 128\nnope_size = 128\nrope_size = 64\n\nrope_theta = 10000\nmax_qlen = 1024\nmax_kvlen = 4096\n\nmax_position_embeddings = 163840\n\n\nrope_scaling = {\n    \"beta_fast\": 32,\n    \"beta_slow\": 1,\n    \"factor\": 40,\n    \"mscale\": 1.0,\n    \"mscale_all_dim\": 1.0,\n    \"original_max_position_embeddings\": 4096,\n    \"type\": \"yarn\",\n}\n\n\nCPUInfer = kt_kernel_ext.CPUInfer(64)\nvalidation_iter = 100\n\n\nq_a_proj = nn.Linear(hidden_size, q_lora_rank, bias=False, dtype=torch.float16)\nq_b_proj = nn.Linear(q_lora_rank, num_heads * (nope_size + rope_size), bias=False, dtype=torch.float16)\nkv_a_proj_with_mqa = nn.Linear(hidden_size, kv_lora_rank + rope_size, bias=False, dtype=torch.float16)\nkv_b_proj = nn.Linear(kv_lora_rank, num_heads * (nope_size + nope_size), bias=False, dtype=torch.float16)\no_proj = nn.Linear(num_heads * nope_size, hidden_size, bias=False, dtype=torch.float16)\n\ninit.normal_(q_a_proj.weight, mean=0.0, std=0.02)\ninit.normal_(q_b_proj.weight, mean=0.0, std=0.02)\ninit.normal_(kv_a_proj_with_mqa.weight, mean=0.0, std=0.02)\ninit.normal_(kv_b_proj.weight, mean=0.0, std=0.02)\ninit.normal_(o_proj.weight, mean=0.0, std=0.02)\n\nq_a_proj_weight = q_a_proj.weight.to(torch.float16).to(\"cpu\").contiguous()\nq_b_proj_weight = 
q_b_proj.weight.to(torch.float16).to(\"cpu\").contiguous()\nkv_a_proj_with_mqa_weight = kv_a_proj_with_mqa.weight.to(\"cpu\").to(torch.float16).contiguous()\nkv_b_proj_weight = kv_b_proj.weight.to(torch.float16).to(\"cpu\").contiguous()\no_proj_weight = o_proj.weight.to(torch.float16).to(\"cpu\").contiguous()\n\n\nconfig = kt_kernel_ext.mla.MLAConfig(\n    hidden_size,\n    q_lora_rank,\n    kv_lora_rank,\n    num_heads,\n    nope_size,\n    rope_size,\n)\nconfig.max_qlen = max_qlen\nconfig.max_kvlen = max_kvlen\nconfig.max_position_embeddings = max_position_embeddings\nconfig.rope_scaling_factor = rope_scaling[\"factor\"]\nconfig.rope_theta = rope_theta\nconfig.rope_scaling_beta_fast = rope_scaling[\"beta_fast\"]\nconfig.rope_scaling_beta_slow = rope_scaling[\"beta_slow\"]\nconfig.rope_scaling_mscale = rope_scaling[\"mscale\"]\nconfig.rope_scaling_mscale_all_dim = rope_scaling[\"mscale_all_dim\"]\nconfig.rope_scaling_original_max_position_embeddings = rope_scaling[\"original_max_position_embeddings\"]\n\nconfig.q_a_proj = q_a_proj_weight.data_ptr()\nconfig.q_b_proj = q_b_proj_weight.data_ptr()\nconfig.kv_a_proj_with_mqa = kv_a_proj_with_mqa_weight.data_ptr()\nconfig.kv_b_proj = kv_b_proj_weight.data_ptr()\nconfig.o_proj = o_proj_weight.data_ptr()\n\nconfig.q_a_proj_type = ggml_type.FP16\nconfig.q_b_proj_type = ggml_type.FP16\nconfig.kv_a_proj_with_mqa_type = ggml_type.FP16\nconfig.kv_b_proj_type = ggml_type.FP16\nconfig.w_o_type = ggml_type.FP16\n\n\nconfig.pool = CPUInfer.backend_\n\n\nmla = kt_kernel_ext.mla.MLA(config)\nmla.load_weights()\nmla.set_local_pages(pages_count)\n\n\ninput = torch.randn((qlen, hidden_size), dtype=torch.float16).to(\"cpu\").contiguous()\n\n\noutput = torch.zeros((qlen, hidden_size), dtype=torch.float16).to(\"cpu\").contiguous()\nmla.forward([qlen], [page_table], [kvlen], input.data_ptr(), output.data_ptr())\nprint(\"CPU MLA Output: \", output)\n\n\nsoftmax_scale = (nope_size + rope_size) ** -0.5\n# the 1 is the number of heads of the compressed KV\nk_caches = 
torch.randn(1, pages_count, page_size, 1, kv_lora_rank + rope_size).to(torch.float16)\nkv_cache = KDeepSeekV3Cache(page_size=page_size, kv_lora_rank=kv_lora_rank, k_caches=k_caches)\n\nq_a_layernorm = DeepseekV2RMSNorm(q_lora_rank)\n\nx = torch.randn(q_lora_rank, dtype=torch.float16) * 100\nprint(x)\nprint(q_a_layernorm(x))\n\nkv_a_layernorm = DeepseekV2RMSNorm(kv_lora_rank)\n\n\nq_absorb = torch.randn(num_heads, nope_size, kv_lora_rank, dtype=torch.float16)\nout_absorb = torch.randn(num_heads, nope_size, kv_lora_rank, dtype=torch.float16)\n\nrotary_emb = DeepseekV3YarnRotaryEmbedding(\n    rope_size,\n    max_position_embeddings=max_position_embeddings,\n    scaling_factor=rope_scaling[\"factor\"],\n    base=rope_theta,\n    beta_fast=rope_scaling[\"beta_fast\"],\n    beta_slow=rope_scaling[\"beta_slow\"],\n    mscale=rope_scaling[\"mscale\"],\n    mscale_all_dim=rope_scaling[\"mscale_all_dim\"],\n    original_max_position_embeddings=rope_scaling[\"original_max_position_embeddings\"],\n)\n# Build an input hidden_states of length qlen; the corresponding history kv_indptr is [0:bsz]\n# kv_indices is [0:bsz], page_idx=[0:bsz], page_offset=[kvlen:qlen+kvlen]\n# last_page_len = [qlen+kvlen,...] 
layer_idx = 1\n# position_ids = [kvlen:qlen+kvlen]\nhidden_states = torch.randn(qlen, hidden_size, dtype=torch.float16)\nq_indptr = torch.tensor([0, qlen]).to(torch.int32)\n\nkv_indptr = torch.tensor([0, (qlen + kvlen + page_size - 1) // page_size]).to(torch.int32)\nkv_indices = torch.tensor(range(pages_count)).to(torch.int32)\n\npage_idx = torch.tensor([i // page_size for i in range(kvlen, kvlen + qlen)]).to(torch.int32)\npage_offset = torch.tensor([i % page_size for i in range(kvlen, kvlen + qlen)]).to(torch.int32)\n\n# number of valid tokens in the last page, in [1, page_size]; a plain modulo would give 0 for a full last page\nlast_page_len = torch.tensor([(qlen + kvlen - 1) % page_size + 1], device=hidden_states.device)\nposition_ids = torch.tensor(range(kvlen, kvlen + qlen)).to(torch.int32)\n\n\n# Build the causal mask row by row: [qlen, kvlen+qlen]\nattention_masks = torch.zeros((qlen, kvlen + qlen), dtype=torch.float16)\nfor i in range(qlen):\n    attention_masks[i, i + kvlen + 1 : i + kvlen + qlen] = -65504.0\n\n\ndef torch_attn(\n    hidden_states: torch.Tensor,\n    kv_cache: KDeepSeekV3Cache,\n    position_ids: torch.Tensor,\n    page_idx: torch.Tensor,\n    page_offset: torch.Tensor,\n    attention_masks: Optional[list[torch.Tensor]] = None,\n    q_indptr: Optional[torch.Tensor] = None,\n    kv_indices: Optional[torch.Tensor] = None,\n    kv_indptr: Optional[torch.Tensor] = None,\n    bsz_tensors: Optional[torch.Tensor] = None,\n    last_page_len: Optional[torch.Tensor] = None,\n    layer_idx: Optional[int] = None,\n):\n    global out_absorb\n    global q_absorb\n    # range bsz_tensors\n    final_attention_output = torch.tensor([], device=hidden_states.device)\n    for i in range(bsz_tensors[0]):\n        batch_num_tokens_tensors = q_indptr[i + 1] - q_indptr[i]\n        batch_last_page_len = last_page_len[i]\n        # kv_total_len is kv_len, batch_compressed_kv is compressed_kv, batch_k_pe is k_pe\n        batch_page_idx = page_idx[q_indptr[i] : q_indptr[i + 1]]\n        batch_page_offset = page_offset[q_indptr[i] : q_indptr[i + 1]]\n        # kv_page_nums is the number of pages for the 
current batch\n        kv_page_nums = kv_indptr[i + 1] - kv_indptr[i]\n        # kv_total_len is the total length of the kv cache for the current batch (kv_len for algorithm)\n        kv_total_len = kv_page_nums * page_size\n        if batch_last_page_len is not None:\n            kv_total_len = kv_total_len - (page_size - batch_last_page_len)\n        # print(f\"kv_total_len's shape {kv_total_len.shape}\")\n        # kv_index is the index of the kv cache pages for the current batch\n        kv_index = kv_indices[kv_indptr[i] : kv_indptr[i + 1]]\n        # we can index [kv_index, page_offset_indices] to get the kv cache for the current batch\n        # from q_indptr[i] to q_indptr[i+1] is the range of the current batch\n        batch_hidden_states = hidden_states[q_indptr[i] : q_indptr[i + 1]]\n        batch_position_ids = position_ids[q_indptr[i] : q_indptr[i + 1]]\n        qlen, _ = batch_hidden_states.size()\n        # print(\"qlen -> \", qlen)\n        q_lora = q_a_proj(batch_hidden_states)\n        print(\"q_a_proj\", q_a_proj.weight)\n        print(\"q_lora\", q_lora)\n\n        q = q_b_proj(q_a_layernorm(q_lora))\n        print(\"q_b_proj\", q_b_proj.weight)\n        # for v3, bsz, qlen, num_heads(128), qk_head_dim(192=128(nope)+64(rope))\n        q = q.view(qlen, num_heads, nope_size + rope_size)\n        # q_nope is [qlen, num_heads(128), qk_nope_head_dim(128)]\n        # q_pe is [qlen, num_heads(128), qk_rope_head_dim(64)]\n        q_nope, q_pe = torch.split(q, [nope_size, rope_size], dim=-1)\n        print(\"q_nope\", q_nope)\n        print(\"q_pe\", q_pe)\n        # compressed_kv is [qlen, kv_lora_rank(512) + rope(64)]\n        compressed_kv = kv_a_proj_with_mqa(batch_hidden_states)\n        # compressed_kv is [qlen, kv_lora_rank(512)], k_pe is [qlen, rope(64)]\n        compressed_kv, k_pe = torch.split(compressed_kv, [kv_lora_rank, rope_size], dim=-1)\n        compressed_kv = compressed_kv.contiguous()\n        compressed_kv = 
kv_a_layernorm(compressed_kv)\n        # k_pe is [qlen, 1, qk_rope_head_dim(64)]\n        print(\"compressed_kv \", compressed_kv)\n        print(\"k_pe \", k_pe)\n        k_pe = k_pe.view(qlen, 1, rope_size)\n        # compressed_kv is [qlen, 1, kv_lora_rank(512)]\n        compressed_kv = compressed_kv.view(qlen, 1, kv_lora_rank)\n\n        cos, sin = rotary_emb(q_pe, batch_position_ids)\n        # print(f\"q_pe shape{q_pe.shape}, k_pe shape {k_pe.shape}\")\n        q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=1)\n        q_pe = q_pe.squeeze(0)\n        # q_pe is [num_heads(128), qlen, qk_rope_head_dim(64)]\n        q_pe.transpose_(0, 1)\n        if kv_cache is not None:\n            cache_kwargs = {\n                \"sin\": sin,\n                \"cos\": cos,\n                \"page_idx\": batch_page_idx,\n                \"page_offset\": batch_page_offset,\n            }  # Specific to RoPE models\n            compressed_kv_with_k_pe = kv_cache.update(\n                compressed_kv.unsqueeze(0), k_pe, layer_idx, batch_page_idx, batch_page_offset, cache_kwargs\n            )\n            compressed_kv = compressed_kv_with_k_pe[:, :, :, :kv_lora_rank].view(-1, page_size, kv_lora_rank)\n            k_pe = compressed_kv_with_k_pe[:, :, :, kv_lora_rank:].view(-1, page_size, rope_size)\n        # q_absorb is [num_heads(128), qk_nope_head_dim(128), kv_lora_rank(512)]\n        # out_absorb is [num_heads(128), kv_lora_rank(512), v_head_dim(128)] v_head_dim is also the nope dim\n        # q_absorb, out_absorb = get_absorbed()\n        # q_nope is [num_heads(128), qlen, qk_nope_head_dim(128)]\n        q_nope = q_nope.transpose(0, 1)  # qlen is 1, no GPU overhead, same below\n        # q_nope is [num_heads(128), qlen, kv_lora_rank(512)]\n        q_nope = torch.matmul(q_nope, q_absorb)  # batched MM\n\n        # # q_nope is [qlen, num_heads(128), kv_lora_rank(512)]\n        # q_nope = q_nope.transpose(0, 1)\n\n        # we 
need to index out the compressed_kv and k_pe for the current batch\n        batch_compressed_kv = None\n        batch_k_pe = None\n        for page_index in kv_index:\n            if kv_total_len > page_size:\n                tmp_compressed_kv = compressed_kv[page_index, 0:page_size, :]\n                tmp_k_pe = k_pe[page_index, 0:page_size, :]\n                if batch_compressed_kv is None or batch_k_pe is None:\n                    batch_compressed_kv = tmp_compressed_kv\n                    batch_k_pe = tmp_k_pe\n                else:\n                    batch_compressed_kv = torch.cat((batch_compressed_kv, tmp_compressed_kv), dim=0)\n                    batch_k_pe = torch.cat((batch_k_pe, tmp_k_pe), dim=0)\n                kv_total_len -= page_size\n            else:\n                tmp_compressed_kv = compressed_kv[page_index, 0:kv_total_len, :]\n                tmp_k_pe = k_pe[page_index, 0:kv_total_len, :]\n                if batch_compressed_kv is None or batch_k_pe is None:\n                    batch_compressed_kv = tmp_compressed_kv\n                    batch_k_pe = tmp_k_pe\n                else:\n                    batch_compressed_kv = torch.cat((batch_compressed_kv, tmp_compressed_kv), dim=0)\n                    batch_k_pe = torch.cat((batch_k_pe, tmp_k_pe), dim=0)\n                break\n        # batch_compressed_kv is [kv_total_len(k_len), kv_lora_rank(512)]\n        # batch_k_pe is [kv_total_len(k_len), qk_rope_head_dim(64)]\n        pe_weights = torch.matmul(q_pe, batch_k_pe.mT)\n        print(\"pe_weights\", pe_weights)\n        attention_weights = (pe_weights + torch.matmul(q_nope, batch_compressed_kv.mT)) * softmax_scale\n        # attention_weights is [num_heads(128), qlen, k_len]\n\n        # attention_weights = attention_weights.transpose(0,1).unsqueeze(0).squeeze(-1).expand(qlen,-1,-1).transpose(0,1)\n\n        # attention_masks[i] is [qlen, k_len]\n\n        attention_weights = attention_weights + attention_masks[i]\n        # 
attention_weights shape is [num_heads(128), qlen, k_len]\n        attention_weights = nn.functional.softmax(attention_weights, dim=-1, dtype=torch.float16).to(q_pe.dtype)\n        attn_output = torch.matmul(attention_weights, batch_compressed_kv)  # [num_heads(128),qlen, lora_rank(512)]\n        # out_absorb shape is [num_heads(128), kv_lora_rank(512), v_head_dim(128)]\n        # use a local transposed view so the global out_absorb is not flipped on every batch iteration\n        out_absorb_t = out_absorb.transpose(1, 2)\n        # q for qlen, n for num_heads, h for v_head_dim, v for kv_lora_rank\n        attn_output = torch.matmul(attn_output, out_absorb_t)  # [num_heads(128), qlen, v_head_dim(128)]\n        attn_output = attn_output.transpose(0, 1)  # [qlen, num_heads(128), v_head_dim(128)]\n        attn_output = attn_output.reshape(qlen, num_heads * nope_size)\n        attn_output = o_proj(attn_output)\n        final_attention_output = torch.cat((final_attention_output, attn_output), dim=0)\n    return final_attention_output\n\n\ntorch_output = torch_attn(\n    input,\n    kv_cache,\n    position_ids,\n    page_idx,\n    page_offset,\n    # one [qlen, k_len] mask per batch entry; torch_attn indexes attention_masks[i]\n    attention_masks=[attention_masks],\n    q_indptr=q_indptr,\n    kv_indices=kv_indices,\n    kv_indptr=kv_indptr,\n    bsz_tensors=bsz_tensors,\n    last_page_len=last_page_len,\n    layer_idx=0,\n)\nprint(\"Torch Output: \", torch_output)\n"
  },
  {
    "path": "kt-kernel/examples/test_mlp.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nDescription  :\nAuthor       : chenht2022\nDate         : 2024-07-25 10:32:05\nVersion      : 1.0.0\nLastEditors  : chenht2022\nLastEditTime : 2024-08-06 10:37:28\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved.\n\"\"\"\nimport os, sys\nimport time\n\nsys.path.append(os.path.dirname(__file__) + \"/../build\")\nfrom kt_kernel import kt_kernel_ext\nimport torch\n\nhidden_size = 5120\nintermediate_size = 3072\nstride = 32\ngroup_max_len = 1024\ngate_type = 1  # ggml_type::GGML_TYPE_F16\nup_type = 1  # ggml_type::GGML_TYPE_F16\ndown_type = 1  # ggml_type::GGML_TYPE_F16\nhidden_type = 1  # ggml_type::GGML_TYPE_F16\nqlen = 30\nlayer_num = 10\nCPUInfer = kt_kernel_ext.CPUInfer(48)\nvalidation_iter = 100\n\n\ndef act_fn(x):\n    return x / (1.0 + torch.exp(-x))\n\n\ndef mlp_torch(input, gate_proj, up_proj, down_proj):\n    gate_buf = torch.mm(input, gate_proj.t())\n    up_buf = torch.mm(input, up_proj.t())\n    intermediate = act_fn(gate_buf) * up_buf\n    ret = torch.mm(intermediate, down_proj.t())\n    return ret\n\n\nwith torch.inference_mode(mode=True):\n    mlps = []\n    gate_projs = []\n    up_projs = []\n    down_projs = []\n    for _ in range(layer_num):\n        gate_proj = (\n            torch.randn((intermediate_size, hidden_size), dtype=torch.float16, device=\"cuda\").to(\"cpu\").contiguous()\n        )\n        up_proj = (\n            torch.randn((intermediate_size, hidden_size), dtype=torch.float16, device=\"cuda\").to(\"cpu\").contiguous()\n        )\n        down_proj = (\n            torch.randn((hidden_size, intermediate_size), dtype=torch.float16, device=\"cuda\").to(\"cpu\").contiguous()\n        )\n        config = kt_kernel_ext.mlp.MLPConfig(\n            hidden_size,\n            intermediate_size,\n            stride,\n            group_max_len,\n            gate_proj.data_ptr(),\n            up_proj.data_ptr(),\n            down_proj.data_ptr(),\n            gate_type,\n        
    up_type,\n            down_type,\n            hidden_type,\n        )\n        mlp = kt_kernel_ext.mlp.MLP(config)\n        gate_projs.append(gate_proj)\n        up_projs.append(up_proj)\n        down_projs.append(down_proj)\n        mlps.append(mlp)\n\n    # validation\n    for i in range(validation_iter):\n        mlp = mlps[i % layer_num]\n        input = torch.randn((qlen, hidden_size), dtype=torch.float16).contiguous()\n        output = torch.empty((qlen, hidden_size), dtype=torch.float16).contiguous()\n        input = input / 100\n\n        CPUInfer.submit(mlp.forward(qlen, input.data_ptr(), output.data_ptr()))\n        CPUInfer.sync()\n        # print('cpuinfer output', output)\n\n        gate_proj = gate_projs[i % layer_num]\n        up_proj = up_projs[i % layer_num]\n        down_proj = down_projs[i % layer_num]\n        t_output = mlp_torch(input, gate_proj, up_proj, down_proj)\n        # print('torch output', t_output)\n\n        diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))\n        print(\"diff = \", diff)\n        assert diff < 0.001\n"
  },
  {
    "path": "kt-kernel/examples/test_moe.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nDescription  :\nAuthor       : chenht2022\nDate         : 2024-07-25 10:32:05\nVersion      : 1.0.0\nLastEditors  : SkqLiao\nLastEditTime : 2025-03-13 11:38:05\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved.\n\"\"\"\nimport os, sys\nimport time\n\nsys.path.insert(0, os.path.dirname(__file__) + \"/../build\")\nfrom kt_kernel import kt_kernel_ext\nimport torch\nfrom tqdm import tqdm\nfrom kt_kernel_ext.kvcache import ggml_type\n\ntorch.manual_seed(0)\n\nexpert_num = 8\nhidden_size = 2048  # 7168\nintermediate_size = 2048\nstride = 32\ngroup_min_len = 10\ngroup_max_len = 2560\nnum_experts_per_tok = 8\nlayer_num = 1\n# expert_num = 8\n# hidden_size = 7168\n# intermediate_size = 2048\n# stride = 32\n# group_min_len = 10\n# group_max_len = 10240\n# num_experts_per_tok = 8\n# qlen = 1024\n# layer_num = 1\nCPUInfer = kt_kernel_ext.CPUInfer(64)\nvalidation_iter = 10\n\n\ndef act_fn(x):\n    return x / (1.0 + torch.exp(-x))\n\n\ndef mlp_torch(input, gate_proj, up_proj, down_proj):\n    gate_buf = torch.mm(input, gate_proj.t())\n    up_buf = torch.mm(input, up_proj.t())\n    intermediate = act_fn(gate_buf) * up_buf\n    ret = torch.mm(intermediate, down_proj.t())\n    return ret\n\n\ndef moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):\n    cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))\n    cnts.scatter_(1, expert_ids, 1)\n    tokens_per_expert = cnts.sum(dim=0)\n    idxs = expert_ids.view(-1).argsort()\n    sorted_tokens = input[idxs // expert_ids.shape[1]]\n\n    outputs = []\n    start_idx = 0\n    for i, num_tokens in enumerate(tokens_per_expert):\n        end_idx = start_idx + num_tokens\n        if num_tokens == 0:\n            continue\n        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n        expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])\n        outputs.append(expert_out)\n        start_idx = 
end_idx\n\n    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n    new_x = torch.empty_like(outs)\n    new_x[idxs] = outs\n    t_output = (\n        new_x.view(*expert_ids.shape, -1)\n        .type(weights.dtype)\n        .mul_(weights.unsqueeze(dim=-1))\n        .sum(dim=1)\n        .type(new_x.dtype)\n    )\n    return t_output\n\n\ndef to_cpuinfer_tensor(tensor, type):\n    size = torch.prod(torch.tensor(tensor.shape, dtype=torch.int32)).item()\n    return kt_kernel_ext.utils.from_float(tensor.data_ptr(), size, type)\n\n\ndef from_cpuinfer_tensor(tensor, size, type):\n    return kt_kernel_ext.utils.to_float(tensor.data_ptr(), size, type)\n\n\nqlens = [1, 64]  # [64, 512, 2048, 8192, 16384]\n# gate_types = [ggml_type.FP32, ggml_type.FP16, ggml_type.Q8_0, ggml_type.Q6_K, ggml_type.Q5_K, ggml_type.Q4_K, ggml_type.Q3_K]\n# up_types = [ggml_type.FP32, ggml_type.FP16, ggml_type.Q8_0, ggml_type.Q6_K, ggml_type.Q5_K, ggml_type.Q4_K, ggml_type.Q3_K]\n# down_types = [ggml_type.FP32, ggml_type.FP16, ggml_type.Q8_0, ggml_type.Q6_K, ggml_type.Q6_K, ggml_type.Q6_K, ggml_type.Q5_K]\ngate_types = [ggml_type.Q4_K]\nup_types = [ggml_type.Q4_K]\ndown_types = [ggml_type.Q6_K]\nhidden_type = ggml_type.BF16\nprint(f\"Parameters: expert_num: {expert_num} hidden_size: {hidden_size} intermediate_size: {intermediate_size}\")\nprint(f\"group_max_len: \", group_max_len)\n\nfor qlen in qlens:\n    for gate_type, up_type, down_type in zip(gate_types, up_types, down_types):\n        with torch.inference_mode(mode=True):\n            moes = []\n            gate_projs = []\n            up_projs = []\n            down_projs = []\n            print(\"Preparing data...\")\n            converted_tensors = []\n            for _ in range(layer_num):\n                size = expert_num * intermediate_size * hidden_size\n                gate_proj = (\n                    torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device=\"cuda\")\n  
                  .to(\"cpu\")\n                    .contiguous()\n                )\n                up_proj = (\n                    torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device=\"cuda\")\n                    .to(\"cpu\")\n                    .contiguous()\n                )\n                down_proj = (\n                    torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32, device=\"cuda\")\n                    .to(\"cpu\")\n                    .contiguous()\n                )\n\n                gate_tensor = to_cpuinfer_tensor(gate_proj, gate_type)\n                up_tensor = to_cpuinfer_tensor(up_proj, up_type)\n                down_tensor = to_cpuinfer_tensor(down_proj, down_type)\n\n                config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)\n                config.pool = CPUInfer.backend_\n                config.stride = stride\n                config.group_min_len = group_min_len\n                config.group_max_len = group_max_len\n                config.gate_proj = gate_tensor.data_ptr()\n                config.up_proj = up_tensor.data_ptr()\n                config.down_proj = down_tensor.data_ptr()\n                config.gate_type = gate_type\n                config.up_type = up_type\n                config.down_type = down_type\n                config.hidden_type = hidden_type\n\n                moe = kt_kernel_ext.moe.MOE(config)\n                gate_projs.append(gate_proj)\n                up_projs.append(up_proj)\n                down_projs.append(down_proj)\n                CPUInfer.submit(moe.load_weights_task())\n                CPUInfer.sync()\n                moes.append(moe)\n                converted_tensors.append((gate_tensor, up_tensor, down_tensor))\n            print(\"Finished initialization!\")\n\n            CPUInfer.submit(moes[0].warm_up_task())\n            CPUInfer.sync()\n            print(\"Warm up 
finished!\")\n\n            # validation\n            progress_bar = tqdm(range(validation_iter), desc=\"Starting\")\n            total_diff = 0\n\n            for i in progress_bar:\n                progress_bar.set_description(\"Round: {}/{}\".format(i + 1, validation_iter))\n                expert_ids = torch.stack(\n                    [torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]\n                ).contiguous()\n                weights = torch.rand((qlen, num_experts_per_tok), dtype=torch.float32).contiguous()\n                input_proj = torch.randn((qlen, hidden_size), dtype=torch.float32).contiguous() / 100\n                output_proj = torch.empty((qlen, hidden_size), dtype=torch.float32).contiguous()\n\n                input_tensor = to_cpuinfer_tensor(input_proj, hidden_type)\n                output_tensor = to_cpuinfer_tensor(output_proj, hidden_type)\n\n                qlen_tensor = torch.tensor([qlen], dtype=torch.int32)\n                moe = moes[i % layer_num]\n                CPUInfer.submit(\n                    moe.forward_task(\n                        qlen_tensor.data_ptr(),\n                        num_experts_per_tok,\n                        expert_ids.data_ptr(),\n                        weights.data_ptr(),\n                        input_tensor.data_ptr(),\n                        output_tensor.data_ptr(),\n                    )\n                )\n                CPUInfer.sync()\n                cpu_output = from_cpuinfer_tensor(output_tensor, qlen * hidden_size, hidden_type)\n\n                gate_proj = gate_projs[i % layer_num]\n                up_proj = up_projs[i % layer_num]\n                down_proj = down_projs[i % layer_num]\n                t_output = moe_torch(input_proj, expert_ids, weights, gate_proj, up_proj, down_proj)\n                print(\"cpuinfer output\", cpu_output)\n                print(\"torch output\", t_output)\n                diff = torch.mean(torch.abs(cpu_output.flatten() 
- t_output.flatten())) / torch.mean(\n                    torch.abs(t_output.flatten())\n                )\n                assert diff < 0.5\n                total_diff += diff\n\n            print(f\"gate_type: {gate_type}, up_type: {up_type}, down_type: {down_type}\")\n            print(f\"Average diff: {total_diff / validation_iter:.4f}\")\n"
  },
  {
    "path": "kt-kernel/examples/test_moe_amx.py",
    "content": "import os, sys\n\nsys.path.insert(0, os.path.dirname(__file__) + \"/../build\")\nprint(\"sys.path:\", sys.path)\n\nimport torch\nfrom kt_kernel import kt_kernel_ext\n\nexpert_num = 256\nhidden_size = 7168\nintermediate_size = 2048\nmax_len = 25600\nnum_experts_per_tok = 8\nqlen = 1\n# qlen = 640\nlayer_num = 1\nCPUInfer = kt_kernel_ext.CPUInfer(90)\n# validation_iter = 10000\nvalidation_iter = 2\nk_group_size = 64\ndebug_print_count = 16  # Number of values to print in debug output\nphysical_to_logical_map = torch.tensor(data=range(expert_num), device=\"cpu\", dtype=torch.int64).contiguous()\n\n\ndef act_fn(x):\n    return x / (1.0 + torch.exp(-x))\n\n\ndef mlp_torch(input, gate_proj, up_proj, down_proj, debug_expert_id=None, debug_print=False):\n    gate_buf = torch.mm(input, gate_proj.t())\n    up_buf = torch.mm(input, up_proj.t())\n\n    if debug_print and debug_expert_id is not None:\n        print(f\"[TORCH DEBUG] Expert {debug_expert_id}:\")\n        print(f\"  gate_buf[:{debug_print_count}] = {gate_buf.flatten()[:debug_print_count]}\")\n        print(f\"  up_buf[:{debug_print_count}] = {up_buf.flatten()[:debug_print_count]}\")\n\n    intermediate = act_fn(gate_buf) * up_buf\n\n    if debug_print and debug_expert_id is not None:\n        print(f\"  intermediate[:{debug_print_count}] = {intermediate.flatten()[:debug_print_count]}\")\n\n    ret = torch.mm(intermediate, down_proj.t())\n\n    if debug_print and debug_expert_id is not None:\n        print(f\"  down_output[:{debug_print_count}] = {ret.flatten()[:debug_print_count]}\")\n\n    return ret\n\n\ndef moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj, debug_print=False):\n    cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))\n    cnts.scatter_(1, expert_ids, 1)\n    tokens_per_expert = cnts.sum(dim=0)\n    idxs = expert_ids.view(-1).argsort()\n    sorted_tokens = input[idxs // expert_ids.shape[1]]\n\n    # Get the first expert from expert_ids array to match 
AWQ-MoE behavior\n    target_debug_expert = expert_ids[0, 0].item()  # First expert in expert_ids array\n\n    outputs = []\n    start_idx = 0\n    activated_experts = []\n\n    for i, num_tokens in enumerate(tokens_per_expert):\n        end_idx = start_idx + num_tokens\n        if num_tokens == 0:\n            continue\n        activated_experts.append(i)\n        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n        # Only debug the target expert that matches AWQ-MoE's first expert\n        should_debug = debug_print and i == target_debug_expert\n        expert_out = mlp_torch(\n            tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i], debug_expert_id=i, debug_print=should_debug\n        )\n        outputs.append(expert_out)\n        start_idx = end_idx\n\n    if debug_print:\n        print(f\"[TORCH DEBUG] Processing activated experts: {activated_experts}\")\n        print(f\"[TORCH DEBUG] Target debug expert (matches AWQ): {target_debug_expert}\")\n\n    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n    new_x = torch.empty_like(outs)\n    new_x[idxs] = outs\n    t_output = (\n        new_x.view(*expert_ids.shape, -1)\n        .type(weights.dtype)\n        .mul_(weights.unsqueeze(dim=-1))\n        .sum(dim=1)\n        .type(new_x.dtype)\n    )\n\n    if debug_print:\n        print(f\"[TORCH DEBUG] Final MoE output[:{debug_print_count}] = {t_output.flatten()[:debug_print_count]}\")\n\n    return t_output\n\n\ndef test_moe(quant_mode: str):\n    assert (\n        quant_mode == \"bf16\"\n        or quant_mode == \"int8\"\n        or quant_mode == \"int4\"\n        or quant_mode == \"int4_1\"\n        or quant_mode == \"int4_1k\"\n    )\n    with torch.inference_mode(mode=True):\n        moes = []\n        gate_projs = []\n        up_projs = []\n        down_projs = []\n        for _ in range(layer_num):\n            gate_proj = (\n                torch.randn((expert_num, intermediate_size, 
hidden_size), dtype=torch.bfloat16, device=\"cuda\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            up_proj = (\n                torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.bfloat16, device=\"cuda\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            down_proj = (\n                torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.bfloat16, device=\"cuda\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)\n            config.max_len = max_len\n            config.gate_proj = gate_proj.data_ptr()\n            config.up_proj = up_proj.data_ptr()\n            config.down_proj = down_proj.data_ptr()\n            config.gate_scale = 0\n            config.pool = CPUInfer.backend_\n            if quant_mode == \"bf16\":\n                moe = kt_kernel_ext.moe.AMXBF16_MOE(config)\n                CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n                CPUInfer.sync()\n                CPUInfer.submit(moe.warm_up_task())\n                CPUInfer.sync()\n            elif quant_mode == \"int8\":\n                moe = kt_kernel_ext.moe.AMXInt8_MOE(config)\n                CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n                CPUInfer.sync()\n                # CPUInfer.submit(moe.warm_up_task())\n                # CPUInfer.sync()\n            elif quant_mode == \"int4\":\n                moe = kt_kernel_ext.moe.AMXInt4_MOE(config)\n                CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n                CPUInfer.sync()\n                CPUInfer.submit(moe.warm_up_task())\n                CPUInfer.sync()\n            elif quant_mode == \"int4_1\":\n                moe = 
kt_kernel_ext.moe.AMXInt4_1_MOE(config)\n                CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n                CPUInfer.sync()\n                CPUInfer.submit(moe.warm_up_task())\n                CPUInfer.sync()\n            elif quant_mode == \"int4_1k\":\n                config.quant_config.bits = 4\n                config.quant_config.group_size = k_group_size\n                config.quant_config.zero_point = True\n                moe = kt_kernel_ext.moe.AMXInt4_1KGroup_MOE(config)\n                # import debugpy\n                # debugpy.listen((\"127.0.0.1\", 5678))\n                # debugpy.wait_for_client()\n                # debugpy.breakpoint()\n                print(f\"the physical_logical map:{physical_to_logical_map.data_ptr()}\")\n                CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n                CPUInfer.sync()\n                # CPUInfer.submit(moe.warm_up_task())\n                # CPUInfer.sync()\n            gate_projs.append(gate_proj)\n            up_projs.append(up_proj)\n            down_projs.append(down_proj)\n            moes.append(moe)\n\n        # validation\n        for i in range(validation_iter):\n            bsz_tensor = torch.tensor([qlen], device=\"cpu\")\n            expert_ids = torch.stack(\n                [torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]\n            ).contiguous()\n            weights = torch.rand((qlen, num_experts_per_tok), dtype=torch.float32).contiguous()\n            input = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous()\n            output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()\n            input = input / 100\n            moe = moes[i % layer_num]\n\n            # Enable debug for first few iterations\n            enable_debug = i < 2\n            enable_debug = False\n            if enable_debug:\n                print(f\"\\n=== Iteration {i} 
Debug Info ===\")\n                print(f\"input[:{debug_print_count}] = {input.flatten()[:debug_print_count]}\")\n                print(f\"expert_ids = {expert_ids}\")\n                print(f\"weights = {weights}\")\n                # Print which experts will be activated for comparison\n                activated_experts = []\n                for token in range(expert_ids.shape[0]):\n                    for expert_idx in range(expert_ids.shape[1]):\n                        expert_id = expert_ids[token][expert_idx].item()\n                        if expert_id not in activated_experts:\n                            activated_experts.append(expert_id)\n                print(f\"[TORCH DEBUG] Activated experts: {sorted(activated_experts)}\")\n                print(f\"[TORCH DEBUG] First expert from expert_ids array: {expert_ids[0, 0].item()}\")\n            print(f\"expert_ids = {expert_ids}\")\n            # print('expert ids:',expert_ids)\n            CPUInfer.submit(\n                moe.forward_task(\n                    bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids.data_ptr(),\n                    weights.data_ptr(),\n                    input.data_ptr(),\n                    output.data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n\n            if enable_debug:\n                print(f\"[AWQ-MOE DEBUG] AMX output[:{debug_print_count}] = {output.flatten()[:debug_print_count]}\")\n\n            gate_proj = gate_projs[i % layer_num]\n            up_proj = up_projs[i % layer_num]\n            down_proj = down_projs[i % layer_num]\n            t_output = moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj, debug_print=enable_debug)\n            print(\"torch output\", t_output)\n            print(\"amx output\", output)\n\n            # print(output - t_output)\n            # print(torch.abs(output - t_output))\n            diff = 
torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))\n            # print(f'output_shape:{output.shape}, t_output_shape:{t_output.shape}\\n')\n            print(f\"Iteration {i}, diff = {diff:.6f}\")\n\n            if enable_debug:\n                abs_diff = torch.abs(output - t_output)\n                print(f\"[COMPARE] Max abs diff = {torch.max(abs_diff):.6f}\")\n                print(f\"[COMPARE] Mean abs diff = {torch.mean(abs_diff):.6f}\")\n                print(f\"[COMPARE] Relative diff = {diff:.6f}\")\n                print(\"=\" * 50)\n\n            if quant_mode == \"int4\" or quant_mode == \"int4_1\" or quant_mode == \"int4_1k\":\n                assert diff < 0.35\n            else:\n                assert diff < 0.05\n\n\n# only turn on 1 at a time\n\n# Debug mode is currently forced off (enable_debug = False in the validation loop).\n# When enabled, it covers the first 2 iterations and compares intermediate results\n# between the torch implementation and the AWQ-MoE implementation.\n# The debug output shows:\n# 1. Input values and expert assignments\n# 2. Gate and up projection results\n# 3. Intermediate values after activation function\n# 4. Down projection results\n# 5. Final output comparison\n\n# test_moe(\"bf16\")\ntest_moe(\"int8\")\ntest_moe(\"int4\")\ntest_moe(\"int4_1\")\ntest_moe(\"int4_1k\")\n"
  },
  {
    "path": "kt-kernel/examples/test_moe_kernel.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nDescription  :\nAuthor       : chenht2022\nDate         : 2024-07-25 10:32:05\nVersion      : 1.0.0\nLastEditors  : chenht2022\nLastEditTime : 2024-08-06 10:38:05\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved.\n\"\"\"\nimport os, sys\nimport time\n\nsys.path.insert(0, os.path.dirname(__file__) + \"/../build\")\nos.environ[\"BLAS_NUM_THREADS\"] = \"1\"\nimport torch\nfrom kt_kernel import kt_kernel_ext\n\n\nexpert_num = 16\nhidden_size = 7168\nintermediate_size = 2048\nmax_len = 4096\nnum_experts_per_tok = 8\nm_block = 320\nn_block_up_gate = 32\nn_block_down = 64\nn_block_up_gate_prefi = 32\nn_block_down_prefi = 64\n# qlen = 1\nqlen = 1024\nlayer_num = 1\nCPUInfer = kt_kernel_ext.CPUInfer(160)\n# validation_iter = 10000\nvalidation_iter = 1\n\n\ndef act_fn(x):\n    return x / (1.0 + torch.exp(-x))\n\n\ndef mlp_torch(input, gate_proj, up_proj, down_proj):\n    gate_buf = torch.mm(input, gate_proj.t())\n    up_buf = torch.mm(input, up_proj.t())\n    intermediate = act_fn(gate_buf) * up_buf\n    ret = torch.mm(intermediate, down_proj.t())\n    return ret\n\n\ndef moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):\n    cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))\n    cnts.scatter_(1, expert_ids, 1)\n    tokens_per_expert = cnts.sum(dim=0)\n    idxs = expert_ids.view(-1).argsort()\n    sorted_tokens = input[idxs // expert_ids.shape[1]]\n\n    outputs = []\n    start_idx = 0\n    for i, num_tokens in enumerate(tokens_per_expert):\n        end_idx = start_idx + num_tokens\n        if num_tokens == 0:\n            continue\n        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n        expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])\n        outputs.append(expert_out)\n        start_idx = end_idx\n\n    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n    new_x = 
torch.empty_like(outs)\n    new_x[idxs] = outs\n    t_output = (\n        new_x.view(*expert_ids.shape, -1)\n        .type(weights.dtype)\n        .mul_(weights.unsqueeze(dim=-1))\n        .sum(dim=1)\n        .type(new_x.dtype)\n    )\n    return t_output\n\n\ndef test_moe(quant_mode: str):\n    assert quant_mode == \"int8\" or quant_mode == \"int4\" or quant_mode == \"int4_1\"\n    with torch.inference_mode(mode=True):\n        moes = []\n        gate_projs = []\n        up_projs = []\n        down_projs = []\n        for _ in range(layer_num):\n            gate_proj = (\n                torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.bfloat16, device=\"cpu\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            up_proj = (\n                torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.bfloat16, device=\"cpu\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            down_proj = (\n                torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.bfloat16, device=\"cpu\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)\n            config.max_len = max_len\n            config.gate_proj = gate_proj.data_ptr()\n            config.up_proj = up_proj.data_ptr()\n            config.down_proj = down_proj.data_ptr()\n            config.pool = CPUInfer.backend_\n            if quant_mode == \"int8\":\n                d = kt_kernel_ext.moe.tiling.get_int8()\n                nbug_prefi = n_block_up_gate_prefi\n                nbd_prefi = n_block_down_prefi\n                kb = d[\"k_block\"]\n                nb = d[\"n_block\"]\n                mb = m_block\n                nbug = n_block_up_gate\n                nbd = n_block_down\n                print(\n                    f\"Int8 Tiling: nbug 
{nbug}, nbd {nbd}, nb {nb}, mb {mb}, kb {kb}, nbug_prefi {nbug_prefi}, nbd_prefi {nbd_prefi}\"\n                )\n                kt_kernel_ext.moe.tiling.set_int8(nbug, nbd, nb, mb, kb, nbug_prefi, nbd_prefi)\n                moe = kt_kernel_ext.moe.Int8_KERNEL_MOE(config)\n                CPUInfer.submit(moe.load_weights_task())\n                CPUInfer.sync()\n                # CPUInfer.submit(moe.warm_up_task())\n                # CPUInfer.sync()\n            elif quant_mode == \"int4\":\n                moe = kt_kernel_ext.moe.Int4_KERNEL_MOE(config)\n                CPUInfer.submit(moe.load_weights_task())\n                CPUInfer.sync()\n                CPUInfer.submit(moe.warm_up_task())\n                CPUInfer.sync()\n            else:\n                raise ValueError(f\"Unsupported quantization mode: {quant_mode}\")\n            gate_projs.append(gate_proj)\n            up_projs.append(up_proj)\n            down_projs.append(down_proj)\n            moes.append(moe)\n\n        # validation\n        for i in range(validation_iter):\n            bsz_tensor = torch.tensor([qlen], device=\"cpu\")\n            expert_ids = torch.stack(\n                [torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]\n            ).contiguous()\n            weights = torch.rand((qlen, num_experts_per_tok), dtype=torch.float32).contiguous()\n            input = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous()\n            output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()\n            input = input / 100\n            # Print the contents of input\n            print(\"input:\", input)\n            moe = moes[i % layer_num]\n            # print('expert ids:',expert_ids)\n            CPUInfer.submit(\n                moe.forward_task(\n                    bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids.data_ptr(),\n                    weights.data_ptr(),\n                    
input.data_ptr(),\n                    output.data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n            print(\"cpuinfer output\", output)\n\n            gate_proj = gate_projs[i % layer_num]\n            up_proj = up_projs[i % layer_num]\n            down_proj = down_projs[i % layer_num]\n            t_output = moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj)\n            print(\"torch output\", t_output)\n\n            # print(output - t_output)\n            diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))\n            print(\"diff = \", diff)\n            if quant_mode == \"int4\":\n                assert diff < 0.35\n            else:\n                assert diff < 0.05\n\n\ntest_moe(\"int8\")\n# test_moe(\"int4\")\n"
  },
  {
    "path": "kt-kernel/examples/test_moe_kml.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nDescription  :\nAuthor       : chenht2022\nDate         : 2024-07-25 10:32:05\nVersion      : 1.0.0\nLastEditors  : chenht2022\nLastEditTime : 2024-08-06 10:38:05\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved.\n\"\"\"\nimport os, sys\nimport time\n\nsys.path.insert(0, os.path.dirname(__file__) + \"/../build\")\nos.environ[\"BLAS_NUM_THREADS\"] = \"1\"\nfrom kt_kernel import kt_kernel_ext\nimport torch\n\nexpert_num = 16\nhidden_size = 7168\nintermediate_size = 2048\nmax_len = 4096\nnum_experts_per_tok = 8\nqlen = 512\n# qlen = 640\nlayer_num = 1\nCPUInfer = kt_kernel_ext.CPUInfer(112)\n# validation_iter = 10000\nvalidation_iter = 1\n\n\ndef act_fn(x):\n    return x / (1.0 + torch.exp(-x))\n\n\ndef mlp_torch(input, gate_proj, up_proj, down_proj):\n    gate_buf = torch.mm(input, gate_proj.t())\n    up_buf = torch.mm(input, up_proj.t())\n    intermediate = act_fn(gate_buf) * up_buf\n    ret = torch.mm(intermediate, down_proj.t())\n    return ret\n\n\ndef moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):\n    cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))\n    cnts.scatter_(1, expert_ids, 1)\n    tokens_per_expert = cnts.sum(dim=0)\n    idxs = expert_ids.view(-1).argsort()\n    sorted_tokens = input[idxs // expert_ids.shape[1]]\n\n    outputs = []\n    start_idx = 0\n    for i, num_tokens in enumerate(tokens_per_expert):\n        end_idx = start_idx + num_tokens\n        if num_tokens == 0:\n            continue\n        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n        expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])\n        outputs.append(expert_out)\n        start_idx = end_idx\n\n    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n    new_x = torch.empty_like(outs)\n    new_x[idxs] = outs\n    t_output = (\n        new_x.view(*expert_ids.shape, -1)\n        
.type(weights.dtype)\n        .mul_(weights.unsqueeze(dim=-1))\n        .sum(dim=1)\n        .type(new_x.dtype)\n    )\n    return t_output\n\n\ndef test_moe(quant_mode: str):\n    assert quant_mode == \"bf16\" or quant_mode == \"int8\" or quant_mode == \"int4\" or quant_mode == \"int4_1\"\n    with torch.inference_mode(mode=True):\n        moes = []\n        gate_projs = []\n        up_projs = []\n        down_projs = []\n        for _ in range(layer_num):\n            gate_proj = (\n                torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.bfloat16, device=\"cpu\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            up_proj = (\n                torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.bfloat16, device=\"cpu\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            down_proj = (\n                torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.bfloat16, device=\"cpu\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)\n            config.max_len = max_len\n            config.gate_proj = gate_proj.data_ptr()\n            config.up_proj = up_proj.data_ptr()\n            config.down_proj = down_proj.data_ptr()\n            config.pool = CPUInfer.backend_\n            if quant_mode == \"bf16\":\n                moe = kt_kernel_ext.moe.AMXBF16_MOE(config)\n                CPUInfer.submit(moe.load_weights_task())\n                CPUInfer.sync()\n                CPUInfer.submit(moe.warm_up_task())\n                CPUInfer.sync()\n            elif quant_mode == \"int8\":\n                moe = kt_kernel_ext.moe.KMLInt8_MOE(config)\n                CPUInfer.submit(moe.load_weights_task())\n                CPUInfer.sync()\n                # CPUInfer.submit(moe.warm_up_task())\n           
     # CPUInfer.sync()\n            elif quant_mode == \"int4\":\n                moe = kt_kernel_ext.moe.KMLInt4_MOE(config)\n                CPUInfer.submit(moe.load_weights_task())\n                CPUInfer.sync()\n                CPUInfer.submit(moe.warm_up_task())\n                CPUInfer.sync()\n            else:\n                raise ValueError(f\"Unsupported quantization mode: {quant_mode}\")\n            gate_projs.append(gate_proj)\n            up_projs.append(up_proj)\n            down_projs.append(down_proj)\n            moes.append(moe)\n\n        # validation\n        for i in range(validation_iter):\n            bsz_tensor = torch.tensor([qlen], device=\"cpu\")\n            expert_ids = torch.stack(\n                [torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]\n            ).contiguous()\n            weights = torch.rand((qlen, num_experts_per_tok), dtype=torch.float32).contiguous()\n            input = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous()\n            output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()\n            input = input / 100\n            # Print the contents of input\n            print(\"input:\", input)\n            moe = moes[i % layer_num]\n            # print('expert ids:',expert_ids)\n            CPUInfer.submit(\n                moe.forward_task(\n                    bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids.data_ptr(),\n                    weights.data_ptr(),\n                    input.data_ptr(),\n                    output.data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n            print(\"cpuinfer output\", output)\n\n            gate_proj = gate_projs[i % layer_num]\n            up_proj = up_projs[i % layer_num]\n            down_proj = down_projs[i % layer_num]\n            t_output = moe_torch(input, expert_ids, weights, gate_proj, up_proj, 
down_proj)\n            print(\"torch output\", t_output)\n\n            # print(output - t_output)\n            diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))\n            print(\"diff = \", diff)\n            if quant_mode == \"int4\":\n                assert diff < 0.35\n            else:\n                assert diff < 0.05\n\n\n# test_moe(\"bf16\")\n# test_moe(\"int8\")\ntest_moe(\"int4\")\n"
  },
  {
    "path": "kt-kernel/examples/test_rope.cpp",
    "content": "#include <cassert>\n#include <iostream>\n#include <random>\n#include <vector>\n\n#include \"../operators/rope.hpp\"\n\nstd::vector<float> create_random_vector(size_t total_size, std::vector<size_t> shape, unsigned int seed = 0) {\n  std::vector<float> vec(total_size);\n  std::mt19937 gen(seed == 0 ? std::random_device{}() : seed);\n  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);\n  // for (size_t i = 0; i < total_size; ++i) {\n  //   vec[i] = 1; // dist(gen);\n  // }\n  for (size_t i = 0; i < shape[0]; ++i) {\n    size_t offset_i = i * shape[1] * shape[2] * shape[3];\n    for (size_t j = 0; j < shape[1]; ++j) {\n      size_t offset_j = j * shape[2] * shape[3];\n      for (size_t k = 0; k < shape[2]; ++k) {\n        size_t offset_k = k * shape[3];\n        for (size_t a = 0; a < shape[3]; ++a) {\n          vec[offset_i + offset_j + offset_k + a] = a;\n        }\n      }\n    }\n  }\n  return vec;\n}\n\nvoid print_vector_to_file(const std::vector<float>& vec, const char* filename) {\n  FILE* fp = fopen(filename, \"w\");\n  for (auto x : vec) {\n    fprintf(fp, \"%.2f \", x);\n  }\n  fclose(fp);\n}\n\nstd::pair<std::vector<float>, std::vector<float>> cpp_torch_rope_with_apply_single(\n    const std::vector<float>& q_in_const, const std::vector<float>& k_in_const,\n    DeepseekV3YarnRotaryEmbedding<float>& rotary_emb, size_t B, size_t H, size_t S, size_t D_rope) {\n  rotary_emb.init(S);\n\n  const float* full_cos_cache_ptr = rotary_emb.cos();\n  const float* full_sin_cache_ptr = rotary_emb.sin();\n\n  std::vector<float> q_out = q_in_const;\n  std::vector<float> k_out = k_in_const;\n\n  size_t stride_head = S * D_rope;\n  size_t stride_batch = H * stride_head;\n\n  for (size_t b = 0; b < B; ++b) {\n    for (size_t h = 0; h < H; ++h) {\n      float* current_k_head_ptr = k_out.data() + b * stride_batch + h * stride_head;\n      Rope<DeepseekV3YarnRotaryEmbedding<float>, float>::apply_multiple(rotary_emb, current_k_head_ptr,\n                     
                                                   static_cast<int>(D_rope), 0, S);\n      for (size_t s = 0; s < S; ++s) {\n        float* current_q_head_ptr = q_out.data() + b * stride_batch + h * stride_head + s * D_rope;\n\n        Rope<DeepseekV3YarnRotaryEmbedding<float>, float>::apply_single(rotary_emb, current_q_head_ptr,\n                                                                        static_cast<int>(D_rope), s);\n      }\n    }\n  }\n\n  return {q_out, k_out};\n}\n\nint main() {\n  size_t batch_size = 2;\n  size_t num_heads = 16;\n  size_t seq_len = 32;\n  size_t rope_size = 16;\n  float theta = 10000.0f;\n\n  float beta_fast_cfg = 32.0f;\n  float beta_slow_cfg = 1.0f;\n  float factor_cfg = 40.0f;\n  float mscale_cfg = 1.0f;\n  float mscale_all_dim_cfg = 1.0f;\n  size_t original_max_pos_embeddings_cfg = 4096;\n\n  std::cout << \"--- Test Parameters ---\" << std::endl;\n  std::cout << \"Batch Size: \" << batch_size << std::endl;\n  std::cout << \"Num Heads: \" << num_heads << std::endl;\n  std::cout << \"Seq Len: \" << seq_len << std::endl;\n  std::cout << \"Rope Size (dim): \" << rope_size << std::endl;\n  std::cout << \"Theta (base): \" << theta << std::endl;\n  std::cout << \"Scaling Factor: \" << factor_cfg << std::endl;\n  std::cout << \"Original Max Pos Embeddings: \" << original_max_pos_embeddings_cfg << std::endl;\n  std::cout << \"-----------------------\" << std::endl << std::endl;\n\n  DeepseekV3YarnRotaryEmbedding<float> rotary_emb(rope_size, original_max_pos_embeddings_cfg, theta, factor_cfg,\n                                                  original_max_pos_embeddings_cfg, beta_fast_cfg, beta_slow_cfg,\n                                                  mscale_cfg, mscale_all_dim_cfg);\n  std::cout << \"DeepseekV3YarnRotaryEmbedding instantiated.\" << std::endl;\n\n  size_t total_elements_per_tensor = batch_size * num_heads * seq_len * rope_size;\n\n  unsigned int q_seed = 123;\n  unsigned int k_seed = 456;\n  std::vector<float> 
q_pe_vec =\n      create_random_vector(total_elements_per_tensor, {batch_size, num_heads, seq_len, rope_size}, q_seed);\n  std::vector<float> k_pe_vec =\n      create_random_vector(total_elements_per_tensor, {batch_size, num_heads, seq_len, rope_size}, k_seed);\n\n  std::cout << \"Input Q_PE and K_PE vectors created. Total elements per tensor: \" << total_elements_per_tensor\n            << std::endl;\n\n  std::cout << std::endl;\n\n  std::cout << \"Applying RoPE using cpp_torch_rope_with_apply_single...\" << std::endl;\n  auto [q2_vec, k2_vec] =\n      cpp_torch_rope_with_apply_single(q_pe_vec, k_pe_vec, rotary_emb, batch_size, num_heads, seq_len, rope_size);\n  std::cout << \"RoPE application finished.\" << std::endl << std::endl;\n\n  std::cout << std::endl << \"test_rope.cpp finished successfully.\" << std::endl;\n\n  print_vector_to_file(q2_vec, \"q_cpp.out\");\n  print_vector_to_file(k2_vec, \"k_cpp.out\");\n\n  return 0;\n}"
  },
  {
    "path": "kt-kernel/examples/test_rope.py",
    "content": "import torch\nfrom torch_attention import apply_rotary_pos_emb, DeepseekV3YarnRotaryEmbedding, DeepseekV3RotaryEmbedding\n\nbatch_size  = 1\nnum_heads   = 1\nseq_len     = 1024\nrope_size   = 64\ntheta       = 10000\n\nmax_position_embeddings =  163840\n\nscaling_cfg = {\n    \"beta_fast\": 32,\n    \"beta_slow\": 1,\n    \"factor\": 40,\n    \"mscale\": 1.0,\n    \"mscale_all_dim\": 1.0,\n    \"original_max_position_embeddings\": 4096,\n    \"type\": \"yarn\"\n}\n\nrotary_emb = DeepseekV3YarnRotaryEmbedding(\n    rope_size,\n    max_position_embeddings=max_position_embeddings,\n    scaling_factor=scaling_cfg[\"factor\"],\n    base=theta,\n    beta_fast=scaling_cfg[\"beta_fast\"],\n    beta_slow=scaling_cfg[\"beta_slow\"],\n    mscale=scaling_cfg[\"mscale\"],\n    mscale_all_dim=scaling_cfg[\"mscale_all_dim\"],\n    original_max_position_embeddings=scaling_cfg[\"original_max_position_embeddings\"],\n)\n\n\ndef load_fp16_tensor(file_path, shape):\n    with open(file_path, 'rb') as f:\n        raw_data = f.read()\n    tensor = torch.frombuffer(raw_data, dtype=torch.float16)\n    tensor = tensor.view(shape)  # reshape to the given shape\n    return tensor\n\ndef load_fp32_tensor(file_path, shape):\n    with open(file_path, 'rb') as f:\n        raw_data = f.read()\n    tensor = torch.frombuffer(raw_data, dtype=torch.float32)\n    tensor = tensor.view(shape)  # reshape to the given shape\n    return tensor\n\n#q_pe = torch.randn(batch_size, num_heads, seq_len, rope_size, dtype=torch.float32)\n#k_pe = torch.randn_like(q_pe)\n\nq_pe = load_fp16_tensor(\"csrc/ktransformers_ext/build/before_rope\",(batch_size, num_heads, seq_len, rope_size)) \n# k_pe = torch.ones_like(q_pe) \nk_pe = load_fp16_tensor(\"csrc/ktransformers_ext/build/before_rope\",(batch_size, num_heads, seq_len, rope_size)) \nprint(q_pe)\n\ncheck = load_fp16_tensor(\"csrc/ktransformers_ext/build/after_rope\",(batch_size, num_heads, seq_len, rope_size))\n\n\n\n\ndef torch_rope(q, k):\n    cos, sin = 
rotary_emb(q, seq_len=seq_len)\n\n    cos_to_check = load_fp32_tensor(\"csrc/ktransformers_ext/build/cos\",(seq_len, rope_size//2))\n    sin_to_check = load_fp32_tensor(\"csrc/ktransformers_ext/build/sin\",(seq_len, rope_size//2))\n\n\n    \n\n    sin = sin.unsqueeze(0)\n    cos = cos.unsqueeze(0)\n    q2, k2 = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1)\n    return q2, k2\n\nq2, k2 = torch_rope(q_pe, k_pe)\nprint(q2,k2)\nprint(check)\n\ndiff = torch.abs(q2 - check).max()\n\n\nprint(diff)\n\n# print(q2,k2)\n\n# print_tensor(q2, 'q_py.out')\n# print_tensor(k2, 'k_py.out')\n\n"
  },
  {
    "path": "kt-kernel/examples/test_softmax.py",
    "content": "\nimport torch\nfrom torch import nn\n\n\ndef load_fp16_tensor(file_path, shape):\n    with open(file_path, 'rb') as f:\n        raw_data = f.read()\n    tensor = torch.frombuffer(raw_data, dtype=torch.float16)\n    tensor = tensor.view(shape)  # reshape to the given shape\n    return tensor\n\na = load_fp16_tensor(\"csrc/ktransformers_ext/build/before_softmax\", (64,1024))\ncheck = load_fp16_tensor(\"csrc/ktransformers_ext/build/after_softmax\", (64,1024))\n\n\na = nn.functional.softmax(a, dim=-1, dtype=torch.float16)\ndiff = torch.abs(a - check).max()\n\nprint(a)\nprint(check)\nprint(diff)\n\n\n"
  },
  {
    "path": "kt-kernel/examples/test_write_buffer.py",
    "content": "\"\"\"\nTest write_weight_scale_to_buffer for AMX MOE operators.\n\nSupports:\n- FP8: FP8 weights (1 byte) + float32 scales (block-wise)\n- FP8_PERCHANNEL: FP8 weights (1 byte) + float32 per-channel scales\n- BF16: Native BF16 weights (2 bytes), no scales\n\nUsage:\n    python test_write_buffer.py          # Run all modes\n    python test_write_buffer.py fp8      # Run FP8 only\n    python test_write_buffer.py fp8_perchannel  # Run FP8 per-channel only\n    python test_write_buffer.py bf16     # Run BF16 only\n\"\"\"\n\nimport os\nimport sys\nimport time\n\nimport torch\n\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\", \"build\"))\n\nfrom kt_kernel import kt_kernel_ext\nfrom kt_kernel_ext import CPUInfer\n\n\ndef make_cpu_infer(thread_num=80):\n    return CPUInfer(thread_num)\n\n\ndef div_up(a, b):\n    return (a + b - 1) // b\n\n\ndef build_config_fp8(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size, group_size):\n    cfg = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)\n    cfg.max_len = 1\n    cfg.quant_config.bits = 8  # FP8\n    cfg.quant_config.group_size = group_size\n    cfg.quant_config.zero_point = False\n    cfg.pool = cpuinfer.backend_\n    return cfg\n\n\ndef build_config_fp8_perchannel(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size):\n    cfg = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)\n    cfg.max_len = 1\n    cfg.quant_config.bits = 8  # FP8\n    cfg.quant_config.group_size = 0  # Not used for per-channel\n    cfg.quant_config.zero_point = False\n    cfg.quant_config.per_channel = True\n    cfg.pool = cpuinfer.backend_\n    return cfg\n\n\ndef build_config_bf16(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size):\n    cfg = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size)\n    cfg.max_len = 1\n    
cfg.pool = cpuinfer.backend_\n    return cfg\n\n\ndef allocate_weights_fp8(expert_num, hidden_size, intermediate_size, group_size):\n    \"\"\"Allocate FP8 weights and scales for testing\"\"\"\n    # FP8 weights: 1 byte per element\n    per_mat_weight_bytes = hidden_size * intermediate_size\n    # FP8 scales: block-wise (group_size x group_size blocks), stored as float32\n    n_blocks_n_gate_up = div_up(intermediate_size, group_size)\n    n_blocks_k = div_up(hidden_size, group_size)\n    per_mat_scale_elems_gate_up = n_blocks_n_gate_up * n_blocks_k\n\n    # For down: n=hidden_size, k=intermediate_size\n    n_blocks_n_down = n_blocks_k\n    n_blocks_k_down = n_blocks_n_gate_up\n    per_mat_scale_elems_down = n_blocks_n_down * n_blocks_k_down\n\n    gate_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)\n    up_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)\n    down_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)\n\n    gate_scale = torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32)\n    up_scale = torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32)\n    down_scale = torch.randn(expert_num * per_mat_scale_elems_down, dtype=torch.float32)\n\n    return {\n        \"gate_q\": gate_q,\n        \"up_q\": up_q,\n        \"down_q\": down_q,\n        \"gate_scale\": gate_scale,\n        \"up_scale\": up_scale,\n        \"down_scale\": down_scale,\n        \"per_mat_weight_bytes\": per_mat_weight_bytes,\n        \"per_mat_scale_elems_gate_up\": per_mat_scale_elems_gate_up,\n        \"per_mat_scale_elems_down\": per_mat_scale_elems_down,\n    }\n\n\ndef allocate_weights_fp8_perchannel(expert_num, hidden_size, intermediate_size):\n    \"\"\"Allocate FP8 per-channel weights and scales for testing\"\"\"\n    per_mat_weight_bytes = hidden_size * intermediate_size\n    per_mat_scale_elems_gate_up = intermediate_size  # one scale 
per output channel\n    per_mat_scale_elems_down = hidden_size\n\n    gate_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)\n    up_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)\n    down_q = torch.randint(0, 256, (expert_num * per_mat_weight_bytes,), dtype=torch.uint8)\n\n    gate_scale = torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32)\n    up_scale = torch.randn(expert_num * per_mat_scale_elems_gate_up, dtype=torch.float32)\n    down_scale = torch.randn(expert_num * per_mat_scale_elems_down, dtype=torch.float32)\n\n    return {\n        \"gate_q\": gate_q,\n        \"up_q\": up_q,\n        \"down_q\": down_q,\n        \"gate_scale\": gate_scale,\n        \"up_scale\": up_scale,\n        \"down_scale\": down_scale,\n        \"per_mat_weight_bytes\": per_mat_weight_bytes,\n        \"per_mat_scale_elems_gate_up\": per_mat_scale_elems_gate_up,\n        \"per_mat_scale_elems_down\": per_mat_scale_elems_down,\n    }\n\n\ndef allocate_weights_bf16(expert_num, hidden_size, intermediate_size):\n    \"\"\"Allocate BF16 weights for testing (no scales)\"\"\"\n    # BF16 weights: 2 bytes per element\n    per_mat_weight_elems = hidden_size * intermediate_size\n    per_mat_weight_bytes = per_mat_weight_elems * 2  # BF16 = 2 bytes\n\n    gate_proj = torch.randn(expert_num * per_mat_weight_elems, dtype=torch.bfloat16)\n    up_proj = torch.randn(expert_num * per_mat_weight_elems, dtype=torch.bfloat16)\n    down_proj = torch.randn(expert_num * per_mat_weight_elems, dtype=torch.bfloat16)\n\n    return {\n        \"gate_proj\": gate_proj,\n        \"up_proj\": up_proj,\n        \"down_proj\": down_proj,\n        \"per_mat_weight_bytes\": per_mat_weight_bytes,\n        \"per_mat_weight_elems\": per_mat_weight_elems,\n    }\n\n\ndef test_fp8_write_buffer(gpu_tp_count):\n    \"\"\"Test write_weight_scale_to_buffer with FP8 weights\"\"\"\n    torch.manual_seed(123)\n\n    expert_num = 256\n  
  gpu_experts = expert_num\n    num_experts_per_tok = 8\n    hidden_size = 3072\n    intermediate_size = 1536\n    group_size = 128\n\n    cpuinfer = make_cpu_infer()\n    cfg = build_config_fp8(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size, group_size)\n    weights = allocate_weights_fp8(expert_num, hidden_size, intermediate_size, group_size)\n\n    cfg.gate_proj = weights[\"gate_q\"].data_ptr()\n    cfg.up_proj = weights[\"up_q\"].data_ptr()\n    cfg.down_proj = weights[\"down_q\"].data_ptr()\n    cfg.gate_scale = weights[\"gate_scale\"].data_ptr()\n    cfg.up_scale = weights[\"up_scale\"].data_ptr()\n    cfg.down_scale = weights[\"down_scale\"].data_ptr()\n\n    moe = kt_kernel_ext.moe.AMXFP8_MOE(cfg)\n\n    physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device=\"cpu\").contiguous()\n    cpuinfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n    cpuinfer.sync()\n\n    per_mat_weight_bytes = weights[\"per_mat_weight_bytes\"]\n    per_mat_scale_elems_gate_up = weights[\"per_mat_scale_elems_gate_up\"]\n    per_mat_scale_elems_down = weights[\"per_mat_scale_elems_down\"]\n\n    # Calculate sizes per TP part\n    weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count\n    gpu_n_w13 = intermediate_size // gpu_tp_count\n    gpu_k_w13 = hidden_size\n    scale_elems_per_expert_per_tp_gate_up = div_up(gpu_n_w13, group_size) * div_up(gpu_k_w13, group_size)\n    gpu_n_w2 = hidden_size\n    gpu_k_w2 = intermediate_size // gpu_tp_count\n    scale_elems_per_expert_per_tp_down = div_up(gpu_n_w2, group_size) * div_up(gpu_k_w2, group_size)\n\n    total_weight_bytes_per_tp = gpu_experts * weight_bytes_per_expert_per_tp\n    total_scale_elems_per_tp_gate_up = gpu_experts * scale_elems_per_expert_per_tp_gate_up\n    total_scale_elems_per_tp_down = gpu_experts * scale_elems_per_expert_per_tp_down\n\n    # Create buffer lists\n    w13_weight_bufs = [torch.empty(2 * total_weight_bytes_per_tp, 
dtype=torch.uint8) for _ in range(gpu_tp_count)]\n    w13_scale_bufs = [\n        torch.empty(2 * total_scale_elems_per_tp_gate_up, dtype=torch.float32) for _ in range(gpu_tp_count)\n    ]\n    w2_weight_bufs = [torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]\n    w2_scale_bufs = [torch.empty(total_scale_elems_per_tp_down, dtype=torch.float32) for _ in range(gpu_tp_count)]\n\n    print(f\"[FP8] GPU TP count: {gpu_tp_count}, Experts: {expert_num}\")\n    print(f\"[FP8] Weight bytes per expert per TP: {weight_bytes_per_expert_per_tp}\")\n    print(f\"[FP8] Scale elements per expert per TP (gate/up): {scale_elems_per_expert_per_tp_gate_up}\")\n\n    def get_expert_ptrs(expert_id):\n        w13_weight_ptrs = []\n        w13_scale_ptrs = []\n        w2_weight_ptrs = []\n        w2_scale_ptrs = []\n        for tp_idx in range(gpu_tp_count):\n            w13_weight_expert_offset = expert_id * 2 * weight_bytes_per_expert_per_tp\n            w13_scale_expert_offset = expert_id * 2 * scale_elems_per_expert_per_tp_gate_up\n            w2_weight_expert_offset = expert_id * weight_bytes_per_expert_per_tp\n            w2_scale_expert_offset = expert_id * scale_elems_per_expert_per_tp_down\n\n            w13_weight_ptrs.append(w13_weight_bufs[tp_idx].data_ptr() + w13_weight_expert_offset)\n            w13_scale_ptrs.append(w13_scale_bufs[tp_idx].data_ptr() + w13_scale_expert_offset * 4)\n            w2_weight_ptrs.append(w2_weight_bufs[tp_idx].data_ptr() + w2_weight_expert_offset)\n            w2_scale_ptrs.append(w2_scale_bufs[tp_idx].data_ptr() + w2_scale_expert_offset * 4)\n        return w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs\n\n    # Warm up\n    for _ in range(2):\n        for expert_id in range(gpu_experts):\n            w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)\n            cpuinfer.submit(\n                moe.write_weight_scale_to_buffer_task(\n              
      gpu_tp_count=gpu_tp_count,\n                    expert_id=expert_id,\n                    w13_weight_ptrs=w13_weight_ptrs,\n                    w13_scale_ptrs=w13_scale_ptrs,\n                    w2_weight_ptrs=w2_weight_ptrs,\n                    w2_scale_ptrs=w2_scale_ptrs,\n                )\n            )\n            cpuinfer.sync()\n\n    # Timing\n    begin_time = time.perf_counter_ns()\n    for expert_id in range(gpu_experts):\n        w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)\n        cpuinfer.submit(\n            moe.write_weight_scale_to_buffer_task(\n                gpu_tp_count=gpu_tp_count,\n                expert_id=expert_id,\n                w13_weight_ptrs=w13_weight_ptrs,\n                w13_scale_ptrs=w13_scale_ptrs,\n                w2_weight_ptrs=w2_weight_ptrs,\n                w2_scale_ptrs=w2_scale_ptrs,\n            )\n        )\n        cpuinfer.sync()\n    end_time = time.perf_counter_ns()\n    elapsed_ms = (end_time - begin_time) / 1e6\n\n    total_bytes = (\n        hidden_size * intermediate_size * gpu_experts * 3\n        + (per_mat_scale_elems_gate_up * 2 + per_mat_scale_elems_down) * gpu_experts * 4\n    )\n    print(f\"[FP8] write_weight_scale_to_buffer time: {elapsed_ms:.2f} ms\")\n    print(f\"[FP8] Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s\")\n\n    # Verify correctness\n    def split_expert_tensor(tensor, chunk):\n        return [tensor[i * chunk : (i + 1) * chunk] for i in range(expert_num)]\n\n    gate_q = weights[\"gate_q\"]\n    up_q = weights[\"up_q\"]\n    down_q = weights[\"down_q\"]\n    gate_scale = weights[\"gate_scale\"]\n    up_scale = weights[\"up_scale\"]\n    down_scale = weights[\"down_scale\"]\n\n    gate_q_experts = split_expert_tensor(gate_q, per_mat_weight_bytes)\n    up_q_experts = split_expert_tensor(up_q, per_mat_weight_bytes)\n    down_q_experts = split_expert_tensor(down_q, per_mat_weight_bytes)\n    gate_scale_experts = 
split_expert_tensor(gate_scale, per_mat_scale_elems_gate_up)\n    up_scale_experts = split_expert_tensor(up_scale, per_mat_scale_elems_gate_up)\n    down_scale_experts = split_expert_tensor(down_scale, per_mat_scale_elems_down)\n\n    n_blocks_n = div_up(hidden_size, group_size)\n    n_blocks_k = div_up(intermediate_size, group_size)\n    n_blocks_k_per_tp = n_blocks_k // gpu_tp_count\n\n    for tp_idx in range(gpu_tp_count):\n        expected_w13_weights = []\n        expected_w13_scales = []\n        expected_w2_weights = []\n        expected_w2_scales = []\n\n        weight13_per_tp = per_mat_weight_bytes // gpu_tp_count\n        scale13_per_tp = per_mat_scale_elems_gate_up // gpu_tp_count\n\n        for expert_id in range(gpu_experts):\n            start_weight = tp_idx * weight13_per_tp\n            end_weight = (tp_idx + 1) * weight13_per_tp\n            start_scale = tp_idx * scale13_per_tp\n            end_scale = (tp_idx + 1) * scale13_per_tp\n\n            gate_weight_tp = gate_q_experts[expert_id][start_weight:end_weight]\n            gate_scale_tp = gate_scale_experts[expert_id][start_scale:end_scale]\n            up_weight_tp = up_q_experts[expert_id][start_weight:end_weight]\n            up_scale_tp = up_scale_experts[expert_id][start_scale:end_scale]\n\n            down_weight_tp_parts = []\n            down_scale_tp_parts = []\n            tp_slice_weight_size = intermediate_size // gpu_tp_count\n\n            for row_idx in range(hidden_size):\n                row_weight_start = row_idx * intermediate_size\n                tp_weight_offset = row_weight_start + tp_idx * tp_slice_weight_size\n                down_weight_tp_parts.append(\n                    down_q_experts[expert_id][tp_weight_offset : tp_weight_offset + tp_slice_weight_size]\n                )\n\n            for bn in range(n_blocks_n):\n                row_scale_start = bn * n_blocks_k\n                tp_scale_offset = row_scale_start + tp_idx * n_blocks_k_per_tp\n                
down_scale_tp_parts.append(\n                    down_scale_experts[expert_id][tp_scale_offset : tp_scale_offset + n_blocks_k_per_tp]\n                )\n\n            down_weight_tp = torch.cat(down_weight_tp_parts)\n            down_scale_tp = torch.cat(down_scale_tp_parts)\n\n            expected_w13_weights.append(gate_weight_tp)\n            expected_w13_weights.append(up_weight_tp)\n            expected_w13_scales.append(gate_scale_tp)\n            expected_w13_scales.append(up_scale_tp)\n            expected_w2_weights.append(down_weight_tp)\n            expected_w2_scales.append(down_scale_tp)\n\n        expected_w13_weight = torch.cat(expected_w13_weights)\n        expected_w13_scale = torch.cat(expected_w13_scales)\n        expected_w2_weight = torch.cat(expected_w2_weights)\n        expected_w2_scale = torch.cat(expected_w2_scales)\n\n        if not torch.equal(w13_weight_bufs[tp_idx], expected_w13_weight):\n            diff_mask = w13_weight_bufs[tp_idx] != expected_w13_weight\n            first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1\n            raise AssertionError(f\"[FP8] w13 weight mismatch for TP {tp_idx} at index {first_diff_idx}\")\n\n        if not torch.allclose(w13_scale_bufs[tp_idx], expected_w13_scale):\n            raise AssertionError(f\"[FP8] w13 scale mismatch for TP {tp_idx}\")\n\n        if not torch.equal(w2_weight_bufs[tp_idx], expected_w2_weight):\n            diff_mask = w2_weight_bufs[tp_idx] != expected_w2_weight\n            first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1\n            raise AssertionError(f\"[FP8] w2 weight mismatch for TP {tp_idx} at index {first_diff_idx}\")\n\n        if not torch.allclose(w2_scale_bufs[tp_idx], expected_w2_scale):\n            raise AssertionError(f\"[FP8] w2 scale mismatch for TP {tp_idx}\")\n\n    print(f\"[FP8] TP={gpu_tp_count} PASSED (verified {gpu_experts} experts across {gpu_tp_count} TP parts)\")\n    return True\n\n\ndef 
test_fp8_perchannel_write_buffer(gpu_tp_count):\n    \"\"\"Test write_weight_scale_to_buffer with FP8 per-channel weights\"\"\"\n    torch.manual_seed(123)\n\n    expert_num = 256\n    gpu_experts = expert_num\n    num_experts_per_tok = 8\n    hidden_size = 3072\n    intermediate_size = 1536\n\n    cpuinfer = make_cpu_infer()\n    cfg = build_config_fp8_perchannel(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size)\n    weights = allocate_weights_fp8_perchannel(expert_num, hidden_size, intermediate_size)\n\n    cfg.gate_proj = weights[\"gate_q\"].data_ptr()\n    cfg.up_proj = weights[\"up_q\"].data_ptr()\n    cfg.down_proj = weights[\"down_q\"].data_ptr()\n    cfg.gate_scale = weights[\"gate_scale\"].data_ptr()\n    cfg.up_scale = weights[\"up_scale\"].data_ptr()\n    cfg.down_scale = weights[\"down_scale\"].data_ptr()\n\n    moe = kt_kernel_ext.moe.AMXFP8PerChannel_MOE(cfg)\n\n    physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device=\"cpu\").contiguous()\n    cpuinfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n    cpuinfer.sync()\n\n    per_mat_weight_bytes = weights[\"per_mat_weight_bytes\"]\n    per_mat_scale_elems_gate_up = weights[\"per_mat_scale_elems_gate_up\"]\n    per_mat_scale_elems_down = weights[\"per_mat_scale_elems_down\"]\n\n    weight_bytes_per_expert_per_tp = per_mat_weight_bytes // gpu_tp_count\n    gpu_n_w13 = intermediate_size // gpu_tp_count\n    scale_elems_per_expert_per_tp_gate_up = gpu_n_w13\n    scale_elems_per_expert_per_tp_down = per_mat_scale_elems_down\n\n    total_weight_bytes_per_tp = gpu_experts * weight_bytes_per_expert_per_tp\n    total_scale_elems_per_tp_gate_up = gpu_experts * scale_elems_per_expert_per_tp_gate_up\n    total_scale_elems_per_tp_down = gpu_experts * scale_elems_per_expert_per_tp_down\n\n    w13_weight_bufs = [torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]\n    w13_scale_bufs = [\n        
torch.empty(2 * total_scale_elems_per_tp_gate_up, dtype=torch.float32) for _ in range(gpu_tp_count)\n    ]\n    w2_weight_bufs = [torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]\n    w2_scale_bufs = [torch.empty(total_scale_elems_per_tp_down, dtype=torch.float32) for _ in range(gpu_tp_count)]\n\n    print(f\"[FP8_PERCHANNEL] GPU TP count: {gpu_tp_count}, Experts: {expert_num}\")\n    print(f\"[FP8_PERCHANNEL] Weight bytes per expert per TP: {weight_bytes_per_expert_per_tp}\")\n    print(f\"[FP8_PERCHANNEL] Scale elements per expert per TP (gate/up): {scale_elems_per_expert_per_tp_gate_up}\")\n\n    def get_expert_ptrs(expert_id):\n        w13_weight_ptrs = []\n        w13_scale_ptrs = []\n        w2_weight_ptrs = []\n        w2_scale_ptrs = []\n        for tp_idx in range(gpu_tp_count):\n            w13_weight_expert_offset = expert_id * 2 * weight_bytes_per_expert_per_tp\n            w13_scale_expert_offset = expert_id * 2 * scale_elems_per_expert_per_tp_gate_up\n            w2_weight_expert_offset = expert_id * weight_bytes_per_expert_per_tp\n            w2_scale_expert_offset = expert_id * scale_elems_per_expert_per_tp_down\n\n            w13_weight_ptrs.append(w13_weight_bufs[tp_idx].data_ptr() + w13_weight_expert_offset)\n            w13_scale_ptrs.append(w13_scale_bufs[tp_idx].data_ptr() + w13_scale_expert_offset * 4)\n            w2_weight_ptrs.append(w2_weight_bufs[tp_idx].data_ptr() + w2_weight_expert_offset)\n            w2_scale_ptrs.append(w2_scale_bufs[tp_idx].data_ptr() + w2_scale_expert_offset * 4)\n        return w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs\n\n    for _ in range(2):\n        for expert_id in range(gpu_experts):\n            w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)\n            cpuinfer.submit(\n                moe.write_weight_scale_to_buffer_task(\n                    gpu_tp_count=gpu_tp_count,\n                    
expert_id=expert_id,\n                    w13_weight_ptrs=w13_weight_ptrs,\n                    w13_scale_ptrs=w13_scale_ptrs,\n                    w2_weight_ptrs=w2_weight_ptrs,\n                    w2_scale_ptrs=w2_scale_ptrs,\n                )\n            )\n            cpuinfer.sync()\n\n    begin_time = time.perf_counter_ns()\n    for expert_id in range(gpu_experts):\n        w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)\n        cpuinfer.submit(\n            moe.write_weight_scale_to_buffer_task(\n                gpu_tp_count=gpu_tp_count,\n                expert_id=expert_id,\n                w13_weight_ptrs=w13_weight_ptrs,\n                w13_scale_ptrs=w13_scale_ptrs,\n                w2_weight_ptrs=w2_weight_ptrs,\n                w2_scale_ptrs=w2_scale_ptrs,\n            )\n        )\n        cpuinfer.sync()\n    end_time = time.perf_counter_ns()\n    elapsed_ms = (end_time - begin_time) / 1e6\n\n    total_bytes = (\n        hidden_size * intermediate_size * gpu_experts * 3\n        + (per_mat_scale_elems_gate_up * 2 + per_mat_scale_elems_down) * gpu_experts * 4\n    )\n    print(f\"[FP8_PERCHANNEL] write_weight_scale_to_buffer time: {elapsed_ms:.2f} ms\")\n    print(f\"[FP8_PERCHANNEL] Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s\")\n\n    def split_expert_tensor(tensor, chunk):\n        return [tensor[i * chunk : (i + 1) * chunk] for i in range(expert_num)]\n\n    gate_q = weights[\"gate_q\"]\n    up_q = weights[\"up_q\"]\n    down_q = weights[\"down_q\"]\n    gate_scale = weights[\"gate_scale\"]\n    up_scale = weights[\"up_scale\"]\n    down_scale = weights[\"down_scale\"]\n\n    gate_q_experts = split_expert_tensor(gate_q, per_mat_weight_bytes)\n    up_q_experts = split_expert_tensor(up_q, per_mat_weight_bytes)\n    down_q_experts = split_expert_tensor(down_q, per_mat_weight_bytes)\n    gate_scale_experts = split_expert_tensor(gate_scale, per_mat_scale_elems_gate_up)\n    up_scale_experts 
= split_expert_tensor(up_scale, per_mat_scale_elems_gate_up)\n    down_scale_experts = split_expert_tensor(down_scale, per_mat_scale_elems_down)\n\n    for tp_idx in range(gpu_tp_count):\n        expected_w13_weights = []\n        expected_w13_scales = []\n        expected_w2_weights = []\n        expected_w2_scales = []\n\n        weight13_per_tp = per_mat_weight_bytes // gpu_tp_count\n        scale13_per_tp = per_mat_scale_elems_gate_up // gpu_tp_count\n\n        for expert_id in range(gpu_experts):\n            start_weight = tp_idx * weight13_per_tp\n            end_weight = (tp_idx + 1) * weight13_per_tp\n            start_scale = tp_idx * scale13_per_tp\n            end_scale = (tp_idx + 1) * scale13_per_tp\n\n            gate_weight_tp = gate_q_experts[expert_id][start_weight:end_weight]\n            gate_scale_tp = gate_scale_experts[expert_id][start_scale:end_scale]\n            up_weight_tp = up_q_experts[expert_id][start_weight:end_weight]\n            up_scale_tp = up_scale_experts[expert_id][start_scale:end_scale]\n\n            down_weight_tp_parts = []\n            tp_slice_weight_size = intermediate_size // gpu_tp_count\n\n            for row_idx in range(hidden_size):\n                row_weight_start = row_idx * intermediate_size\n                tp_weight_offset = row_weight_start + tp_idx * tp_slice_weight_size\n                down_weight_tp_parts.append(\n                    down_q_experts[expert_id][tp_weight_offset : tp_weight_offset + tp_slice_weight_size]\n                )\n\n            down_weight_tp = torch.cat(down_weight_tp_parts)\n            down_scale_tp = down_scale_experts[expert_id]\n\n            expected_w13_weights.append(gate_weight_tp)\n            expected_w13_weights.append(up_weight_tp)\n            expected_w13_scales.append(gate_scale_tp)\n            expected_w13_scales.append(up_scale_tp)\n            expected_w2_weights.append(down_weight_tp)\n            expected_w2_scales.append(down_scale_tp)\n\n        
expected_w13_weight = torch.cat(expected_w13_weights)\n        expected_w13_scale = torch.cat(expected_w13_scales)\n        expected_w2_weight = torch.cat(expected_w2_weights)\n        expected_w2_scale = torch.cat(expected_w2_scales)\n\n        if not torch.equal(w13_weight_bufs[tp_idx], expected_w13_weight):\n            diff_mask = w13_weight_bufs[tp_idx] != expected_w13_weight\n            first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1\n            raise AssertionError(f\"[FP8_PERCHANNEL] w13 weight mismatch for TP {tp_idx} at index {first_diff_idx}\")\n\n        if not torch.allclose(w13_scale_bufs[tp_idx], expected_w13_scale):\n            raise AssertionError(f\"[FP8_PERCHANNEL] w13 scale mismatch for TP {tp_idx}\")\n\n        if not torch.equal(w2_weight_bufs[tp_idx], expected_w2_weight):\n            diff_mask = w2_weight_bufs[tp_idx] != expected_w2_weight\n            first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1\n            raise AssertionError(f\"[FP8_PERCHANNEL] w2 weight mismatch for TP {tp_idx} at index {first_diff_idx}\")\n\n        if not torch.allclose(w2_scale_bufs[tp_idx], expected_w2_scale):\n            raise AssertionError(f\"[FP8_PERCHANNEL] w2 scale mismatch for TP {tp_idx}\")\n\n    print(f\"[FP8_PERCHANNEL] TP={gpu_tp_count} PASSED (verified {gpu_experts} experts across {gpu_tp_count} TP parts)\")\n    return True\n\n\ndef test_bf16_write_buffer(gpu_tp_count):\n    \"\"\"Test write_weight_scale_to_buffer with BF16 weights (no scales)\"\"\"\n    torch.manual_seed(123)\n\n    expert_num = 16\n    gpu_experts = expert_num\n    num_experts_per_tok = 8\n    hidden_size = 3072\n    intermediate_size = 1536\n\n    cpuinfer = make_cpu_infer()\n    cfg = build_config_bf16(cpuinfer, expert_num, num_experts_per_tok, hidden_size, intermediate_size)\n    weights = allocate_weights_bf16(expert_num, hidden_size, intermediate_size)\n\n    cfg.gate_proj = weights[\"gate_proj\"].data_ptr()\n    
cfg.up_proj = weights[\"up_proj\"].data_ptr()\n    cfg.down_proj = weights[\"down_proj\"].data_ptr()\n    cfg.gate_scale = 0\n    cfg.up_scale = 0\n    cfg.down_scale = 0\n\n    moe = kt_kernel_ext.moe.AMXBF16_MOE(cfg)\n\n    physical_to_logical_map = torch.arange(expert_num, dtype=torch.int64, device=\"cpu\").contiguous()\n    cpuinfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n    cpuinfer.sync()\n\n    per_mat_weight_elems = weights[\"per_mat_weight_elems\"]\n\n    # Calculate sizes per TP part (BF16 = 2 bytes per element)\n    weight_elems_per_expert_per_tp = per_mat_weight_elems // gpu_tp_count\n    weight_bytes_per_expert_per_tp = weight_elems_per_expert_per_tp * 2\n\n    total_weight_bytes_per_tp = gpu_experts * weight_bytes_per_expert_per_tp\n\n    # Create buffer lists (BF16: weights only, no scales)\n    w13_weight_bufs = [torch.empty(2 * total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]\n    w2_weight_bufs = [torch.empty(total_weight_bytes_per_tp, dtype=torch.uint8) for _ in range(gpu_tp_count)]\n    # Empty scale buffers (not used for BF16 but needed for interface)\n    w13_scale_bufs = [torch.empty(1, dtype=torch.float32) for _ in range(gpu_tp_count)]\n    w2_scale_bufs = [torch.empty(1, dtype=torch.float32) for _ in range(gpu_tp_count)]\n\n    print(f\"[BF16] GPU TP count: {gpu_tp_count}, Experts: {expert_num}\")\n    print(f\"[BF16] Weight bytes per expert per TP: {weight_bytes_per_expert_per_tp}\")\n\n    def get_expert_ptrs(expert_id):\n        w13_weight_ptrs = []\n        w13_scale_ptrs = []\n        w2_weight_ptrs = []\n        w2_scale_ptrs = []\n        for tp_idx in range(gpu_tp_count):\n            w13_weight_expert_offset = expert_id * 2 * weight_bytes_per_expert_per_tp\n            w2_weight_expert_offset = expert_id * weight_bytes_per_expert_per_tp\n\n            w13_weight_ptrs.append(w13_weight_bufs[tp_idx].data_ptr() + w13_weight_expert_offset)\n            
w13_scale_ptrs.append(w13_scale_bufs[tp_idx].data_ptr())  # Not used\n            w2_weight_ptrs.append(w2_weight_bufs[tp_idx].data_ptr() + w2_weight_expert_offset)\n            w2_scale_ptrs.append(w2_scale_bufs[tp_idx].data_ptr())  # Not used\n        return w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs\n\n    # Warm up\n    for _ in range(2):\n        for expert_id in range(gpu_experts):\n            w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)\n            cpuinfer.submit(\n                moe.write_weight_scale_to_buffer_task(\n                    gpu_tp_count=gpu_tp_count,\n                    expert_id=expert_id,\n                    w13_weight_ptrs=w13_weight_ptrs,\n                    w13_scale_ptrs=w13_scale_ptrs,\n                    w2_weight_ptrs=w2_weight_ptrs,\n                    w2_scale_ptrs=w2_scale_ptrs,\n                )\n            )\n            cpuinfer.sync()\n\n    # Timing\n    begin_time = time.perf_counter_ns()\n    for expert_id in range(gpu_experts):\n        w13_weight_ptrs, w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs = get_expert_ptrs(expert_id)\n        cpuinfer.submit(\n            moe.write_weight_scale_to_buffer_task(\n                gpu_tp_count=gpu_tp_count,\n                expert_id=expert_id,\n                w13_weight_ptrs=w13_weight_ptrs,\n                w13_scale_ptrs=w13_scale_ptrs,\n                w2_weight_ptrs=w2_weight_ptrs,\n                w2_scale_ptrs=w2_scale_ptrs,\n            )\n        )\n        cpuinfer.sync()\n    end_time = time.perf_counter_ns()\n    elapsed_ms = (end_time - begin_time) / 1e6\n\n    total_bytes = hidden_size * intermediate_size * gpu_experts * 3 * 2  # BF16 = 2 bytes\n    print(f\"[BF16] write_weight_scale_to_buffer time: {elapsed_ms:.2f} ms\")\n    print(f\"[BF16] Throughput: {total_bytes / (elapsed_ms * 1e6):.2f} GB/s\")\n\n    # Verify correctness (BF16: weights only, no scales)\n    def 
split_expert_tensor(tensor, chunk):\n        return [tensor[i * chunk : (i + 1) * chunk] for i in range(expert_num)]\n\n    gate_proj = weights[\"gate_proj\"]\n    up_proj = weights[\"up_proj\"]\n    down_proj = weights[\"down_proj\"]\n\n    # View BF16 as uint8 for byte-level comparison\n    gate_bytes = gate_proj.view(torch.uint8)\n    up_bytes = up_proj.view(torch.uint8)\n    down_bytes = down_proj.view(torch.uint8)\n\n    per_mat_bytes = per_mat_weight_elems * 2  # BF16 = 2 bytes\n    gate_experts = split_expert_tensor(gate_bytes, per_mat_bytes)\n    up_experts = split_expert_tensor(up_bytes, per_mat_bytes)\n    down_experts = split_expert_tensor(down_bytes, per_mat_bytes)\n\n    for tp_idx in range(gpu_tp_count):\n        expected_w13_weights = []\n        expected_w2_weights = []\n\n        weight_bytes_per_tp = per_mat_bytes // gpu_tp_count\n\n        for expert_id in range(gpu_experts):\n            start_weight = tp_idx * weight_bytes_per_tp\n            end_weight = (tp_idx + 1) * weight_bytes_per_tp\n\n            gate_weight_tp = gate_experts[expert_id][start_weight:end_weight]\n            up_weight_tp = up_experts[expert_id][start_weight:end_weight]\n\n            # Down matrix: sliced column-wise (BF16 = 2 bytes per element)\n            down_weight_tp_parts = []\n            tp_slice_elems = intermediate_size // gpu_tp_count\n            tp_slice_bytes = tp_slice_elems * 2\n\n            for row_idx in range(hidden_size):\n                row_byte_start = row_idx * intermediate_size * 2\n                tp_byte_offset = row_byte_start + tp_idx * tp_slice_bytes\n                down_weight_tp_parts.append(down_experts[expert_id][tp_byte_offset : tp_byte_offset + tp_slice_bytes])\n\n            down_weight_tp = torch.cat(down_weight_tp_parts)\n\n            expected_w13_weights.append(gate_weight_tp)\n            expected_w13_weights.append(up_weight_tp)\n            expected_w2_weights.append(down_weight_tp)\n\n        expected_w13_weight = 
torch.cat(expected_w13_weights)\n        expected_w2_weight = torch.cat(expected_w2_weights)\n\n        if not torch.equal(w13_weight_bufs[tp_idx], expected_w13_weight):\n            diff_mask = w13_weight_bufs[tp_idx] != expected_w13_weight\n            first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1\n            raise AssertionError(f\"[BF16] w13 weight mismatch for TP {tp_idx} at index {first_diff_idx}\")\n\n        if not torch.equal(w2_weight_bufs[tp_idx], expected_w2_weight):\n            diff_mask = w2_weight_bufs[tp_idx] != expected_w2_weight\n            first_diff_idx = diff_mask.nonzero()[0].item() if diff_mask.any() else -1\n            raise AssertionError(f\"[BF16] w2 weight mismatch for TP {tp_idx} at index {first_diff_idx}\")\n\n    print(f\"[BF16] TP={gpu_tp_count} PASSED (verified {gpu_experts} experts across {gpu_tp_count} TP parts)\")\n    return True\n\n\ndef test_with_tp(quant_mode: str, gpu_tp_count: int):\n    \"\"\"Test write_weight_scale_to_buffer with specified mode and TP count\"\"\"\n    if quant_mode == \"fp8\":\n        return test_fp8_write_buffer(gpu_tp_count)\n    elif quant_mode == \"fp8_perchannel\":\n        return test_fp8_perchannel_write_buffer(gpu_tp_count)\n    elif quant_mode == \"bf16\":\n        return test_bf16_write_buffer(gpu_tp_count)\n    else:\n        raise ValueError(f\"Unsupported quant_mode: {quant_mode}\")\n\n\ndef main(quant_modes=None):\n    \"\"\"Run tests for specified quant modes\"\"\"\n    if quant_modes is None:\n        quant_modes = [\"fp8\", \"fp8_perchannel\", \"bf16\"]\n\n    tp_values = [1, 2, 4]\n    all_passed = True\n    results = {}\n\n    for quant_mode in quant_modes:\n        print(\"\\n\" + \"=\" * 60)\n        print(f\"Testing {quant_mode.upper()} write_weight_scale_to_buffer\")\n        print(\"=\" * 60)\n\n        for tp in tp_values:\n            print(f\"\\n--- Testing {quant_mode.upper()} with gpu_tp_count = {tp} ---\")\n            try:\n                
test_with_tp(quant_mode, tp)\n                results[(quant_mode, tp)] = \"PASSED\"\n            except Exception as e:\n                results[(quant_mode, tp)] = f\"FAILED: {e}\"\n                all_passed = False\n                print(f\"[{quant_mode.upper()}] TP={tp} FAILED: {e}\")\n\n    print(\"\\n\" + \"=\" * 60)\n    print(\"SUMMARY\")\n    print(\"=\" * 60)\n    for (mode, tp), result in results.items():\n        status = \"PASS\" if \"PASSED\" in result else \"FAIL\"\n        print(f\"  [{status}] {mode.upper()} TP={tp}: {result}\")\n\n    if all_passed:\n        print(\"\\nALL TESTS PASSED\")\n    else:\n        print(\"\\nSOME TESTS FAILED\")\n        sys.exit(1)\n\n\nif __name__ == \"__main__\":\n    if len(sys.argv) > 1:\n        mode = sys.argv[1].lower()\n        if mode in [\"fp8\", \"fp8_perchannel\", \"bf16\"]:\n            main([mode])\n        else:\n            print(f\"Unknown mode: {mode}. Use 'fp8', 'fp8_perchannel' or 'bf16'\")\n            sys.exit(1)\n    else:\n        main()\n"
  },
  {
    "path": "kt-kernel/examples/torch_attention.py",
    "content": "\nimport math\nimport os, sys\nimport time\nimport subprocess\nimport platform\nimport json\nfrom typing import Any, Dict, Optional, Tuple\nimport torch\nimport torch.nn.init as init\nfrom torch import nn\n\nclass KDeepSeekV3Cache(nn.Module):\n    def __init__(\n        self,\n        # config: PretrainedConfig,\n        page_size: int = 256,\n        kv_lora_rank: int = 128,\n        k_caches: Optional[torch.Tensor] = None,\n        dtype=torch.bfloat16,\n        device=torch.device(\"cuda:0\"),\n        \n    ):\n        super().__init__()\n        # self.config = config\n        self.dtype = dtype\n        self.device = device\n        self.kv_lora_rank = kv_lora_rank\n        self.page_size = page_size\n        self.v_caches = []\n        self.k_caches = k_caches\n\n    def update(\n        self,\n        key_states: torch.Tensor,\n        value_states: torch.Tensor,\n        layer_idx: int,\n\n        page_idx: torch.Tensor,\n        page_offset: torch.Tensor,\n\n        cache_kwargs: Optional[Dict[str, Any]] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.\n        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.\n\n        Parameters:\n            key_states (`torch.Tensor`):\n                The new key states to cache.\n            value_states (`torch.Tensor`):\n                The new value states to cache.\n            layer_idx (`int`):\n                The index of the layer to cache the states for.\n            cache_kwargs (`Dict[str, Any]`, `optional`):\n                Additional arguments for the cache subclass. 
The `StaticCache` needs the `cache_position` input\n                to know where to write in the cache.\n\n        Return:\n            The updated cache tensor for `layer_idx`, with key states stored in the first `kv_lora_rank` channels and value states in the remaining channels.\n        \"\"\"\n        k_out = self.k_caches[layer_idx]\n\n        k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states.reshape(-1, *key_states.shape[2:])\n        k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states.reshape(-1, *value_states.shape[2:])\n        return k_out\n\n        \n    def get_page_table(self, cache_position: torch.Tensor, q_indptr: torch.Tensor, kv_indptr: torch.Tensor, kv_indices: torch.Tensor, bsz_tensors: torch.Tensor):\n        page_offset = cache_position % self.page_size  \n        page_idx_local = cache_position // self.page_size  \n        query_ids = torch.zeros_like(cache_position)\n        for i in range(len(q_indptr) - 1):\n            start_idx = q_indptr[i]\n            end_idx = q_indptr[i + 1]\n            query_ids[start_idx:end_idx] = i\n        page_idx = torch.zeros_like(page_idx_local)\n        for i in range(bsz_tensors[0]):\n            query_id = query_ids[i]\n            local_block = page_idx_local[i]\n            start_block = kv_indptr[query_id]\n            if local_block < kv_indptr[query_id + 1] - kv_indptr[query_id]:\n                page_idx[i] = kv_indices[start_block + local_block]\n        \n        return page_idx, page_offset\n\n\n\n# Copied from transformers.models.llama.modeling_llama.rotate_half\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n    Args:\n        q (`torch.Tensor`): The query tensor.\n        k (`torch.Tensor`): The key tensor.\n        cos (`torch.Tensor`): The 
cosine part of the rotary embedding.\n        sin (`torch.Tensor`): The sine part of the rotary embedding.\n        position_ids (`torch.Tensor`):\n            Deprecated and unused.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos.unsqueeze(unsqueeze_dim)\n    sin = sin.unsqueeze(unsqueeze_dim)\n    b, h, s, d = q.shape\n    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)\n    b, h, s, d = k.shape\n    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\nclass DeepseekV2RMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        DeepseekV2RMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = 
hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return (self.weight * hidden_states).to(input_dtype)\n\n\nclass DeepseekV2RotaryEmbedding(nn.Module):\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):\n        super().__init__()\n        self.scaling_factor = scaling_factor\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        # For BC we register cos and sin cached\n        self.max_seq_len_cached = max_position_embeddings\n\n    @torch.no_grad()\n    def forward(self, x, position_ids):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)\n        position_ids_expanded = position_ids[:, None, :].float()\n        # Force float32 since bfloat16 loses precision on long contexts\n        # See https://github.com/huggingface/transformers/pull/29285\n        device_type = x.device.type\n        device_type = device_type if isinstance(device_type, str) and device_type != \"mps\" else \"cpu\"\n        with torch.autocast(device_type=device_type, enabled=False):\n            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos()\n            sin = emb.sin()\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)   \n\nclass DeepseekV3RotaryEmbedding(nn.Module):\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):\n        super().__init__()\n\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        inv_freq = 1.0 / (\n     
       self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)\n        )\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        # Build here to make `torch.jit.trace` work.\n        self._set_cos_sin_cache(\n            seq_len=max_position_embeddings,\n            device=self.inv_freq.device,\n            dtype=torch.get_default_dtype(),\n        )\n        # self.max_seq_len_cached = None\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n        t = torch.arange(\n            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype\n        )\n\n        freqs = torch.outer(t, self.inv_freq.to(t.device))\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        print(\"emb\", emb.shape)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n    def forward(self, x, seq_len=None):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        if self.max_seq_len_cached is None: # or seq_len[-1] > self.max_seq_len_cached:\n            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)\n\n        return (\n            self.cos_cached[seq_len].to(dtype=x.dtype),\n            self.sin_cached[seq_len].to(dtype=x.dtype),\n        )\n\n# Inverse dim formula to find dim based on number of rotations\ndef yarn_find_correction_dim(\n    num_rotations, dim, base=10000, max_position_embeddings=2048\n):\n    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (\n        2 * math.log(base)\n    )\n\n\n# Find dim range bounds based on rotations\ndef yarn_find_correction_range(\n    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048\n):\n    low = math.floor(\n        
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)\n    )\n    high = math.ceil(\n        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)\n    )\n    return max(low, 0), min(high, dim - 1)  # Clamp values just in case\n\ndef yarn_linear_ramp_mask(min, max, dim):\n    if min == max:\n        max += 0.001  # Prevent singularity\n\n    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)\n    ramp_func = torch.clamp(linear_func, 0, 1)\n    return ramp_func\n\ndef yarn_get_mscale(scale=1, mscale=1):\n    if scale <= 1:\n        return 1.0\n    return 0.1 * mscale * math.log(scale) + 1.0\n\nclass DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding):\n\n    def __init__(\n        self,\n        dim,\n        max_position_embeddings=2048,\n        base=10000,\n        device=None,\n        scaling_factor=1.0,\n        original_max_position_embeddings=4096,\n        beta_fast=32,\n        beta_slow=1,\n        mscale=1,\n        mscale_all_dim=0,\n    ):\n        self.scaling_factor = scaling_factor\n        self.original_max_position_embeddings = original_max_position_embeddings\n        self.beta_fast = beta_fast\n        self.beta_slow = beta_slow\n        self.mscale = mscale\n        self.mscale_all_dim = mscale_all_dim\n        super().__init__(dim, max_position_embeddings, base, device)\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n        dim = self.dim\n\n        freq_extra = 1.0 / (\n            self.base\n            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)\n        )\n        freq_inter = 1.0 / (\n            self.scaling_factor\n            * self.base\n            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)\n        )\n\n        low, high = yarn_find_correction_range(\n            self.beta_fast,\n            self.beta_slow,\n            dim,\n            self.base,\n            
self.original_max_position_embeddings,\n        )\n        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(\n            device=device, dtype=torch.float32\n        )\n        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        # Check whether seq_len is already a tensor of positions\n        if isinstance(seq_len, torch.Tensor):\n            t = seq_len\n        else:\n            t = torch.arange(seq_len, device=device, dtype=torch.float32)\n\n        freqs = torch.outer(t, inv_freq)\n\n        _mscale = float(\n            yarn_get_mscale(self.scaling_factor, self.mscale)\n            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)\n        )\n\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\n            \"cos_cached\", (emb.cos() * _mscale).to(dtype), persistent=False\n        )\n        self.register_buffer(\n            \"sin_cached\", (emb.sin() * _mscale).to(dtype), persistent=False\n        )\n\n"
  },
  {
    "path": "kt-kernel/ext_bindings.cpp",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022, Jianwei Dong\n * @Date         : 2024-07-22 02:03:22\n * @Version      : 1.0.0\n * @LastEditors  : Jianwei Dong\n * @LastEditTime : 2024-08-26 22:47:06\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n// Python bindings\n#include <sys/types.h>\n\n#include <cstddef>\n\n#include \"cpu_backend/cpuinfer.h\"\n#include \"cpu_backend/worker_pool.h\"\n#include \"operators/common.hpp\"\n\n#if defined(USE_MOE_KERNEL)\n#include \"operators/moe_kernel/la/kernel.hpp\"\n#include \"operators/moe_kernel/moe.hpp\"\n#endif\n\n#if defined(__aarch64__) && defined(CPU_USE_KML)\n#if defined(KTRANSFORMERS_CPU_MLA)\n#include \"operators/kml/deepseekv3.hpp\"\n#include \"operators/kml/gate.hpp\"\n#include \"operators/kml/mla.hpp\"\n#include \"operators/kml/mla_int8.hpp\"\n#endif\n#include \"operators/kml/moe.hpp\"\nstatic const bool _is_plain_ = true;\n#else\nstatic const bool _is_plain_ = false;\n#endif\n\n#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)\n#include \"operators/amx/awq-moe.hpp\"\n#if defined(__AVX512BF16__)\n#include \"operators/amx/bf16-moe.hpp\"            // Native BF16 MoE using CRTP pattern\n#include \"operators/amx/fp8-moe.hpp\"             // FP8 MoE requires AVX512 BF16 support\n#include \"operators/amx/fp8-perchannel-moe.hpp\"  // FP8 Per-Channel MoE for GLM-4.7-FP8\n#endif\n#include \"operators/amx/k2-moe.hpp\"\n#include \"operators/amx/la/amx_kernels.hpp\"\n#include \"operators/amx/moe.hpp\"\n#endif\n#include <pybind11/stl.h>  // std::vector/std::pair/std::string conversions\n\n#include <cstdint>\n#include <memory>\n#include <type_traits>\n\n#include \"operators/kvcache/kvcache.h\"\n#include \"operators/llamafile/linear.h\"\n#include \"operators/llamafile/mla.hpp\"\n#include \"operators/llamafile/mlp.h\"\n#include \"operators/llamafile/moe.hpp\"\n#include \"pybind11/pybind11.h\"\n\nnamespace py = pybind11;\nusing namespace pybind11::literals;\n\npy::object 
to_float_ptr(uintptr_t input_ptr, int size, ggml_type type) {\n  if (type < 0 || type >= GGML_TYPE_COUNT) {\n    PyErr_SetString(PyExc_ValueError, \"Invalid ggml_type\");\n    throw py::error_already_set();\n  }\n\n  py::module torch = py::module::import(\"torch\");\n  py::dict kwargs;\n  kwargs[\"dtype\"] = torch.attr(\"float32\");\n  py::object tensor = torch.attr(\"empty\")(size, **kwargs);\n\n  uintptr_t output_ptr = tensor.attr(\"data_ptr\")().cast<uintptr_t>();\n  float* output_float_ptr = reinterpret_cast<float*>(output_ptr);\n\n  try {\n    to_float(reinterpret_cast<void*>(input_ptr), output_float_ptr, size, type);\n  } catch (const std::exception& e) {\n    PyErr_SetString(PyExc_RuntimeError, e.what());\n    throw py::error_already_set();\n  }\n\n  return tensor;\n}\n\npy::object from_float_ptr(uintptr_t input_ptr, int size, ggml_type type) {\n  if (type < 0 || type >= GGML_TYPE_COUNT) {\n    PyErr_SetString(PyExc_ValueError, \"Invalid ggml_type\");\n    throw py::error_already_set();\n  }\n\n  py::module torch = py::module::import(\"torch\");\n\n  size_t output_elem_bytes = ggml_type_size(type);\n  size_t output_elem_count = (size + ggml_blck_size(type) - 1) / ggml_blck_size(type);\n  size_t total_bytes = output_elem_count * output_elem_bytes;\n\n  py::dict kwargs;\n  kwargs[\"dtype\"] = torch.attr(\"uint8\");\n  py::object tensor = torch.attr(\"empty\")(total_bytes, **kwargs);\n\n  uintptr_t output_ptr = tensor.attr(\"data_ptr\")().cast<uintptr_t>();\n  void* output_void_ptr = reinterpret_cast<void*>(output_ptr);\n\n  try {\n    from_float(reinterpret_cast<float*>(input_ptr), output_void_ptr, size, type);\n  } catch (const std::exception& e) {\n    PyErr_SetString(PyExc_RuntimeError, e.what());\n    throw py::error_already_set();\n  }\n\n  return tensor;\n}\n\ntemplate <typename T>\nstd::vector<std::vector<uintptr_t>> void_ptr_nested_to_uint(const std::vector<std::vector<T*>>& input) {\n  std::vector<std::vector<uintptr_t>> result;\n  for (const auto& 
row : input) {\n    std::vector<uintptr_t> new_row;\n    for (auto ptr : row) {\n      new_row.push_back(reinterpret_cast<uintptr_t>(ptr));\n    }\n    result.push_back(std::move(new_row));\n  }\n  return result;\n}\n\ntemplate <typename T>\nstd::vector<std::vector<T*>> uint_to_void_ptr_nested(const std::vector<std::vector<uintptr_t>>& input) {\n  std::vector<std::vector<T*>> result;\n  for (const auto& row : input) {\n    std::vector<T*> new_row;\n    for (auto val : row) {\n      new_row.push_back(reinterpret_cast<T*>(val));\n    }\n    result.push_back(std::move(new_row));\n  }\n  return result;\n}\n\n#define DEF_PTR_PROPERTY(cls, name)                                                  \\\n  def_property(                                                                      \\\n      #name, [](const cls& self) { return reinterpret_cast<uintptr_t>(self.name); }, \\\n      [](cls& self, uintptr_t val) { self.name = reinterpret_cast<void*>(val); })\n\n#define DEF_PTR_2D_PROPERTY(cls, name)                                                 \\\n  def_property(                                                                        \\\n      #name, [](const cls& self) { return void_ptr_nested_to_uint<void>(self.name); }, \\\n      [](cls& self, const std::vector<std::vector<uintptr_t>>& val) {                  \\\n        self.name = uint_to_void_ptr_nested<void>(val);                                \\\n      })\n\ntemplate <class T>\nclass MOEBindings {\n public:\n  class WarmUpBindings {\n   public:\n    struct Args {\n      CPUInfer* cpuinfer;\n      TP_MOE<T>* moe;\n    };\n    static void inner(void* args) {\n      Args* args_ = (Args*)args;\n      args_->cpuinfer->enqueue(&TP_MOE<T>::warm_up, args_->moe);\n    }\n    static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<TP_MOE<T>> moe) {\n      Args* args = new Args{nullptr, moe.get()};\n      return std::make_pair((intptr_t)&inner, (intptr_t)args);\n    }\n  };\n  class LoadWeightsBindings {\n   
public:\n    struct Args {\n      CPUInfer* cpuinfer;\n      TP_MOE<T>* moe;\n    };\n    static void inner(void* args) {\n      Args* args_ = (Args*)args;\n      args_->cpuinfer->enqueue(&TP_MOE<T>::load_weights, args_->moe);\n    }\n    static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<TP_MOE<T>> moe,\n                                                            const uintptr_t physical_to_logical_map = 0) {\n      Args* args = new Args{nullptr, moe.get()};\n      if (physical_to_logical_map) {\n        // printf(\"debug physical_to_logical_map in arg:%lu\\n\", physical_to_logical_map);\n        moe->config.physical_to_logical_map = reinterpret_cast<void*>(physical_to_logical_map);\n        // printf(\"moe ptr:%p,confirm: moe->config.physical_to_logical_map:%lu\\n\", reinterpret_cast<void*>(moe.get()),\n        //  reinterpret_cast<uintptr_t>(moe->config.physical_to_logical_map));\n      }\n      return std::make_pair((intptr_t)&inner, (intptr_t)args);\n    }\n    static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<TP_MOE<T>> moe) {\n      return cpuinfer_interface(moe, 0);\n    }\n  };\n  class ForwardBindings {\n   public:\n    struct Args {\n      CPUInfer* cpuinfer;\n      TP_MOE<T>* moe;\n      intptr_t qlen;\n      int k;\n      intptr_t expert_ids;\n      intptr_t weights;\n      intptr_t input;\n      intptr_t output;\n      bool incremental;\n    };\n    static void inner(void* args) {\n      Args* args_ = (Args*)args;\n      args_->cpuinfer->enqueue(&TP_MOE<T>::forward_binding, args_->moe, args_->qlen, args_->k, args_->expert_ids,\n                               args_->weights, args_->input, args_->output, args_->incremental);\n    }\n    static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<TP_MOE<T>> moe, intptr_t qlen, int k,\n                                                            intptr_t expert_ids, intptr_t weights, intptr_t input,\n                                                     
       intptr_t output, bool incremental = false) {\n      Args* args = new Args{nullptr, moe.get(), qlen, k, expert_ids, weights, input, output, incremental};\n      return std::make_pair((intptr_t)&inner, (intptr_t)args);\n    }\n    static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<TP_MOE<T>> moe, intptr_t qlen, int k,\n                                                            intptr_t expert_ids, intptr_t weights, intptr_t input,\n                                                            intptr_t output) {\n      return cpuinfer_interface(moe, qlen, k, expert_ids, weights, input, output, false);\n    }\n  };\n};\n\ntemplate <typename MoeTP>\nvoid bind_moe_module(py::module_& moe_module, const char* name) {\n  using MoeClass = TP_MOE<MoeTP>;\n  using MoeBindings = MOEBindings<MoeTP>;\n\n  auto moe_cls = py::class_<MoeClass, MoE_Interface, std::shared_ptr<MoeClass>>(moe_module, name);\n\n  moe_cls.def(py::init<GeneralMOEConfig>())\n      .def(\"warm_up_task\", &MoeBindings::WarmUpBindings::cpuinfer_interface)\n      .def(\"load_weights_task\",\n           py::overload_cast<std::shared_ptr<MoeClass>>(&MoeBindings::LoadWeightsBindings::cpuinfer_interface))\n      .def(\"load_weights_task\",\n           py::overload_cast<std::shared_ptr<MoeClass>, const uintptr_t>(\n               &MoeBindings::LoadWeightsBindings::cpuinfer_interface),\n           py::arg(\"physical_to_logical_map\"))\n      // .def(\"forward_task\", &MoeBindings::ForwardBindings::cpuinfer_interface)\n      .def(\"forward_task\",\n           py::overload_cast<std::shared_ptr<MoeClass>, intptr_t, int, intptr_t, intptr_t, intptr_t, intptr_t>(\n               &MoeBindings::ForwardBindings::cpuinfer_interface))\n      .def(\"forward_task\",\n           py::overload_cast<std::shared_ptr<MoeClass>, intptr_t, int, intptr_t, intptr_t, intptr_t, intptr_t, bool>(\n               &MoeBindings::ForwardBindings::cpuinfer_interface))\n      .def(\"warm_up\", &MoeClass::warm_up)\n      
.def(\"load_weights\", &MoeClass::load_weights)\n      .def(\"forward\", &MoeClass::forward_binding);\n\n  // Bind write_weight_scale_to_buffer_task for MoE types that support it\n  // Uses a C++20 requires-expression to detect whether MoeClass has a write_weight_scale_to_buffer method\n  if constexpr (requires { &MoeClass::write_weight_scale_to_buffer; }) {\n    struct WriteWeightScaleToBufferBindings {\n      struct Args {\n        CPUInfer* cpuinfer;\n        MoeClass* moe;\n        int gpu_tp_count;\n        int expert_id;\n        std::vector<uintptr_t> w13_weight_ptrs;\n        std::vector<uintptr_t> w13_scale_ptrs;\n        std::vector<uintptr_t> w2_weight_ptrs;\n        std::vector<uintptr_t> w2_scale_ptrs;\n      };\n\n      static void inner(void* args) {\n        Args* args_ = (Args*)args;\n        args_->cpuinfer->enqueue(&MoeClass::write_weight_scale_to_buffer, args_->moe, args_->gpu_tp_count,\n                                 args_->expert_id, args_->w13_weight_ptrs, args_->w13_scale_ptrs, args_->w2_weight_ptrs,\n                                 args_->w2_scale_ptrs);\n      }\n\n      static std::pair<intptr_t, intptr_t> cpuinfer_interface(std::shared_ptr<MoeClass> moe, int gpu_tp_count,\n                                                              int expert_id, py::list w13_weight_ptrs,\n                                                              py::list w13_scale_ptrs, py::list w2_weight_ptrs,\n                                                              py::list w2_scale_ptrs) {\n        // Convert Python lists to std::vector<uintptr_t>\n        std::vector<uintptr_t> w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec;\n\n        for (auto item : w13_weight_ptrs) w13_weight_vec.push_back(py::cast<uintptr_t>(item));\n        for (auto item : w13_scale_ptrs) w13_scale_vec.push_back(py::cast<uintptr_t>(item));\n        for (auto item : w2_weight_ptrs) w2_weight_vec.push_back(py::cast<uintptr_t>(item));\n        for (auto item : w2_scale_ptrs) 
w2_scale_vec.push_back(py::cast<uintptr_t>(item));\n\n        Args* args = new Args{nullptr,        moe.get(),     gpu_tp_count,  expert_id,\n                              w13_weight_vec, w13_scale_vec, w2_weight_vec, w2_scale_vec};\n        return std::make_pair((intptr_t)&inner, (intptr_t)args);\n      }\n    };\n\n    moe_cls.def(\"write_weight_scale_to_buffer_task\", &WriteWeightScaleToBufferBindings::cpuinfer_interface,\n                py::arg(\"gpu_tp_count\"), py::arg(\"expert_id\"), py::arg(\"w13_weight_ptrs\"), py::arg(\"w13_scale_ptrs\"),\n                py::arg(\"w2_weight_ptrs\"), py::arg(\"w2_scale_ptrs\"));\n  }\n}\n\nPYBIND11_MODULE(kt_kernel_ext, m) {\n  py::class_<WorkerPool>(m, \"WorkerPool\").def(py::init<int>());\n  py::class_<WorkerPoolConfig>(m, \"WorkerPoolConfig\")\n      .def(py::init<>())\n      .def_readwrite(\"subpool_count\", &WorkerPoolConfig::subpool_count)\n      .def_readwrite(\"subpool_numa_map\", &WorkerPoolConfig::subpool_numa_map)\n      .def_readwrite(\"subpool_thread_count\", &WorkerPoolConfig::subpool_thread_count);\n\n  py::class_<CPUInfer>(m, \"CPUInfer\")\n      .def(py::init<int>())\n      .def(py::init<WorkerPoolConfig>())\n      .def(\"submit\", &CPUInfer::submit)\n      .def(\"sync\", &CPUInfer::sync, py::arg(\"allow_n_pending\") = 0)\n      .def_readwrite(\"backend_\", &CPUInfer::backend_)\n#ifndef KTRANSFORMERS_CPU_ONLY\n      .def(\"sync_with_cuda_stream\", &CPUInfer::sync_with_cuda_stream, py::arg(\"user_cuda_stream\"),\n           py::arg(\"allow_n_pending\") = 0)\n      .def(\"submit_with_cuda_stream\", &CPUInfer::submit_with_cuda_stream)\n#endif\n      ;\n\n  auto linear_module = m.def_submodule(\"linear\");\n  py::class_<LinearConfig>(linear_module, \"LinearConfig\")\n      .def(py::init([](int hidden_size, int intermediate_size, int stride, int group_max_len, intptr_t proj,\n                       int proj_type, int hidden_type) {\n        return LinearConfig(hidden_size, intermediate_size, stride, 
group_max_len, (void*)proj, (ggml_type)proj_type,\n                            (ggml_type)hidden_type);\n      }));\n  // py::class_<Linear>(linear_module, \"Linear\")\n  //     .def(py::init<LinearConfig>())\n  //     .def(\"warm_up\", &LinearBindings::WarmUpBindings::cpuinfer_interface)\n  //     .def(\"forward\", &LinearBindings::ForwardBindings::cpuinfer_interface);\n\n  auto mlp_module = m.def_submodule(\"mlp\");\n  py::class_<MLPConfig>(mlp_module, \"MLPConfig\")\n      .def(py::init([](int hidden_size, int intermediate_size, int stride, int group_max_len, intptr_t gate_proj,\n                       intptr_t up_proj, intptr_t down_proj, int gate_type, int up_type, int down_type,\n                       int hidden_type) {\n        return MLPConfig(hidden_size, intermediate_size, stride, group_max_len, (void*)gate_proj, (void*)up_proj,\n                         (void*)down_proj, (ggml_type)gate_type, (ggml_type)up_type, (ggml_type)down_type,\n                         (ggml_type)hidden_type);\n      }));\n  // py::class_<MLP>(mlp_module, \"MLP\")\n  //     .def(py::init<MLPConfig>())\n  //     .def(\"warm_up\", &MLPBindings::WarmUpBindings::cpuinfer_interface)\n  //     .def(\"forward\", &MLPBindings::ForwardBindings::cpuinfer_interface);\n\n  py::class_<GeneralConfig>(m, \"GeneralConfig\")\n      .def(py::init<>())\n      .def_readwrite(\"vocab_size\", &GeneralConfig::vocab_size)\n      .def_readwrite(\"hidden_size\", &GeneralConfig::hidden_size)\n      .def_readwrite(\"num_experts_per_tok\", &GeneralConfig::num_experts_per_tok)\n      .def_readwrite(\"n_routed_experts\", &GeneralConfig::n_routed_experts)\n      .def_readwrite(\"n_shared_experts\", &GeneralConfig::n_shared_experts)\n      .def_readwrite(\"max_qlen\", &GeneralConfig::max_qlen)\n      .DEF_PTR_PROPERTY(GeneralConfig, lm_heads_ptr)\n      .def_readwrite(\"lm_heads_type\", &GeneralConfig::lm_heads_type)\n      .DEF_PTR_PROPERTY(GeneralConfig, norm_weights_ptr)\n      
.def_readwrite(\"norm_weights_type\", &GeneralConfig::norm_weights_type)\n      .DEF_PTR_PROPERTY(GeneralConfig, token_embd_ptr)\n      .def_readwrite(\"token_embd_type\", &GeneralConfig::token_embd_type)\n      .def_readwrite(\"pool\", &GeneralConfig::pool);\n#if defined(__aarch64__) && defined(CPU_USE_KML) && defined(KTRANSFORMERS_CPU_MLA)\n  py::class_<DeepseekV3ForCausalLM, std::shared_ptr<DeepseekV3ForCausalLM>>(m, \"DeepseekV3ForCausalLM\")\n      .def(py::init([](GeneralConfig config) { return std::make_shared<DeepseekV3ForCausalLM>(config); }))\n      .def_readwrite(\"model\", &DeepseekV3ForCausalLM::model)\n      .def(\"forward\", &DeepseekV3ForCausalLM::forward_binding);\n\n  py::class_<DeepseekV3Model, std::shared_ptr<DeepseekV3Model>>(m, \"DeepseekV3Model\")\n      .def(py::init([](GeneralConfig config) { return std::make_shared<DeepseekV3Model>(config); }))\n      .def_readwrite(\"layers\", &DeepseekV3Model::layers);\n\n  py::class_<DeepseekV3DecoderLayer, std::shared_ptr<DeepseekV3DecoderLayer>>(m, \"DeepseekV3DecoderLayer\")\n      .def(py::init([](GeneralConfig config, size_t layer_idx) {\n        return std::make_shared<DeepseekV3DecoderLayer>(config, layer_idx);\n      }))\n      .def(\"load_norm\", &DeepseekV3DecoderLayer::load_norm_binding)\n      .def_readwrite(\"self_attn\", &DeepseekV3DecoderLayer::self_attn)\n      .def_readwrite(\"gate\", &DeepseekV3DecoderLayer::gate)\n      .def_readwrite(\"ffn\", &DeepseekV3DecoderLayer::ffn);\n#endif\n  auto mla_module = m.def_submodule(\"mla\");\n  py::class_<GeneralMLAConfig>(mla_module, \"MLAConfig\")\n      .def(py::init([](size_t hidden_size, size_t q_lora_rank, size_t num_heads, size_t nope_size, size_t rope_size,\n                       size_t kv_lora_rank) {\n        return GeneralMLAConfig(hidden_size, q_lora_rank, num_heads, nope_size, rope_size, kv_lora_rank);\n      }))\n      .def_readwrite(\"layer_idx\", &GeneralMLAConfig::layer_idx)\n      .def_readwrite(\"pool\", 
&GeneralMLAConfig::pool)\n      .def_readwrite(\"token_count_in_page\", &GeneralMLAConfig::token_count_in_page)\n      .def_readwrite(\"max_qlen\", &GeneralMLAConfig::max_qlen)\n      .def_readwrite(\"max_kvlen\", &GeneralMLAConfig::max_kvlen)\n\n      .def_readwrite(\"max_position_embeddings\", &GeneralMLAConfig::max_position_embeddings)\n      .def_readwrite(\"rope_scaling_factor\", &GeneralMLAConfig::rope_scaling_factor)\n      .def_readwrite(\"rope_theta\", &GeneralMLAConfig::rope_theta)\n      .def_readwrite(\"rope_scaling_beta_fast\", &GeneralMLAConfig::rope_scaling_beta_fast)\n      .def_readwrite(\"rope_scaling_beta_slow\", &GeneralMLAConfig::rope_scaling_beta_slow)\n      .def_readwrite(\"rope_scaling_mscale\", &GeneralMLAConfig::rope_scaling_mscale)\n      .def_readwrite(\"rope_scaling_mscale_all_dim\", &GeneralMLAConfig::rope_scaling_mscale_all_dim)\n      .def_readwrite(\"rope_scaling_original_max_position_embeddings\",\n                     &GeneralMLAConfig::rope_scaling_original_max_position_embeddings)\n\n      .DEF_PTR_PROPERTY(GeneralMLAConfig, q_a_proj)\n      .DEF_PTR_PROPERTY(GeneralMLAConfig, q_a_norm)\n      .DEF_PTR_PROPERTY(GeneralMLAConfig, q_b_proj)\n      .DEF_PTR_PROPERTY(GeneralMLAConfig, kv_a_proj_with_mqa)\n      .DEF_PTR_PROPERTY(GeneralMLAConfig, kv_a_norm)\n      .DEF_PTR_PROPERTY(GeneralMLAConfig, kv_b_proj)\n      .DEF_PTR_PROPERTY(GeneralMLAConfig, o_proj)\n\n      .def_readwrite(\"q_a_proj_type\", &GeneralMLAConfig::q_a_proj_type)\n      .def_readwrite(\"q_a_norm_type\", &GeneralMLAConfig::q_a_norm_type)\n      .def_readwrite(\"q_b_proj_type\", &GeneralMLAConfig::q_b_proj_type)\n      .def_readwrite(\"kv_a_proj_with_mqa_type\", &GeneralMLAConfig::kv_a_proj_with_mqa_type)\n      .def_readwrite(\"kv_a_norm_type\", &GeneralMLAConfig::kv_a_norm_type)\n      .def_readwrite(\"kv_b_proj_type\", &GeneralMLAConfig::kv_b_proj_type)\n      .def_readwrite(\"w_o_type\", &GeneralMLAConfig::w_o_type)\n      .def_readwrite(\"page_count\", 
&GeneralMLAConfig::page_count)\n\n      ;\n  py::class_<MLA_Interface, std::shared_ptr<MLA_Interface>>(mla_module, \"MLA_Interface\");\n#if defined(__aarch64__) && defined(CPU_USE_KML) && defined(KTRANSFORMERS_CPU_MLA)\n  py::class_<TP_MLA<KML_MLA_TP<float16_t>>, MLA_Interface, std::shared_ptr<TP_MLA<KML_MLA_TP<float16_t>>>>(mla_module,\n                                                                                                           \"MLA_F16\")\n      .def(py::init<GeneralMLAConfig>())\n      .def(\"load_weights\", &TP_MLA<KML_MLA_TP<float16_t>>::load_weights)\n      .def(\"forward\",\n           [](TP_MLA<KML_MLA_TP<float16_t>>& op, std::vector<int> qlens, std::vector<std::vector<int>> page_tables,\n              std::vector<int> kvlens, intptr_t input,\n              intptr_t output) { op.forward(qlens, page_tables, kvlens, (const void*)input, (void*)output); })\n      .def(\"set_local_pages\", &TP_MLA<KML_MLA_TP<float16_t>>::set_local_pages)\n      .def(\"set_pages\", [](TP_MLA<KML_MLA_TP<float16_t>>& op, std::vector<std::vector<intptr_t>> nope_pages,\n                           std::vector<std::vector<intptr_t>> rope_pages) {\n        std::vector<std::vector<void*>> nope_pages_ptr;\n        std::vector<std::vector<void*>> rope_pages_ptr;\n        op.set_pages(nope_pages_ptr, rope_pages_ptr);\n      });\n\n  py::class_<TP_MLA<KML_MLA_TP<float>>, MLA_Interface, std::shared_ptr<TP_MLA<KML_MLA_TP<float>>>>(mla_module,\n                                                                                                   \"MLA_F32\")\n      .def(py::init<GeneralMLAConfig>())\n      .def(\"load_weights\", &TP_MLA<KML_MLA_TP<float>>::load_weights)\n      .def(\"forward\",\n           [](TP_MLA<KML_MLA_TP<float>>& op, std::vector<int> qlens, std::vector<std::vector<int>> page_tables,\n              std::vector<int> kvlens, intptr_t input,\n              intptr_t output) { op.forward(qlens, page_tables, kvlens, (const void*)input, (void*)output); })\n      
.def(\"set_local_pages\", &TP_MLA<KML_MLA_TP<float>>::set_local_pages)\n      .def(\"set_pages\", [](TP_MLA<KML_MLA_TP<float>>& op, std::vector<std::vector<intptr_t>> nope_pages,\n                           std::vector<std::vector<intptr_t>> rope_pages) {\n        std::vector<std::vector<void*>> nope_pages_ptr;\n        std::vector<std::vector<void*>> rope_pages_ptr;\n        op.set_pages(nope_pages_ptr, rope_pages_ptr);\n      });\n  py::class_<TP_MLA<KML_MLA_TP_QUAN<float>>, MLA_Interface, std::shared_ptr<TP_MLA<KML_MLA_TP_QUAN<float>>>>(\n      mla_module, \"MLA_QUAN_F32\")\n      .def(py::init<GeneralMLAConfig>())\n      .def(\"load_weights\", &TP_MLA<KML_MLA_TP_QUAN<float>>::load_weights)\n      .def(\"forward\",\n           [](TP_MLA<KML_MLA_TP_QUAN<float>>& op, std::vector<int> qlens, std::vector<std::vector<int>> page_tables,\n              std::vector<int> kvlens, intptr_t input,\n              intptr_t output) { op.forward(qlens, page_tables, kvlens, (const void*)input, (void*)output); })\n      .def(\"set_local_pages\", &TP_MLA<KML_MLA_TP_QUAN<float>>::set_local_pages)\n      .def(\"set_pages\", [](TP_MLA<KML_MLA_TP_QUAN<float>>& op, std::vector<std::vector<intptr_t>> nope_pages,\n                           std::vector<std::vector<intptr_t>> rope_pages) {\n        std::vector<std::vector<void*>> nope_pages_ptr;\n        std::vector<std::vector<void*>> rope_pages_ptr;\n        op.set_pages(nope_pages_ptr, rope_pages_ptr);\n      });\n\n  auto gate_module = m.def_submodule(\"gate\");\n  py::class_<GeneralGateConfig>(gate_module, \"GateConfig\")\n      .def(py::init([](int hidden_size, int num_experts_per_tok, int n_routed_experts, int n_group, int topk_group) {\n        return GeneralGateConfig(hidden_size, num_experts_per_tok, n_routed_experts, n_group, topk_group);\n      }))\n      .def_readwrite(\"routed_scaling_factor\", &GeneralGateConfig::routed_scaling_factor)\n\n      .def_readwrite(\"layer_idx\", &GeneralGateConfig::layer_idx)\n      
.def_readwrite(\"pool\", &GeneralGateConfig::pool)\n      .DEF_PTR_PROPERTY(GeneralGateConfig, weight)\n      .def_readwrite(\"weight_type\", &GeneralGateConfig::weight_type)\n      .DEF_PTR_PROPERTY(GeneralGateConfig, e_score_correction_bias)\n      .def_readwrite(\"e_score_correction_bias_type\", &GeneralGateConfig::e_score_correction_bias_type)\n\n      ;\n  py::class_<MoEGate, std::shared_ptr<MoEGate>>(gate_module, \"MoEGate\")\n      .def(py::init<GeneralGateConfig>())\n      .def(\"forward\", &MoEGate::forward_binding);\n#endif\n\n  py::class_<QuantConfig>(m, \"QuantConfig\")\n      .def(py::init<>())\n      .def_readwrite(\"quant_method\", &QuantConfig::quant_method)\n      .def_readwrite(\"bits\", &QuantConfig::bits)\n      .def_readwrite(\"group_size\", &QuantConfig::group_size)\n      .def_readwrite(\"zero_point\", &QuantConfig::zero_point)\n      .def_readwrite(\"per_channel\", &QuantConfig::per_channel);\n\n  auto moe_module = m.def_submodule(\"moe\");\n\n  py::class_<GeneralMOEConfig>(moe_module, \"MOEConfig\")\n      .def(py::init([](int expert_num, int routed_expert_num, int hidden_size, int intermediate_size) {\n        return GeneralMOEConfig(expert_num, routed_expert_num, hidden_size, intermediate_size);\n      }))\n      .def(py::init([](int expert_num, int routed_expert_num, int hidden_size, int intermediate_size,\n                       uintptr_t gpu_experts_mask_ptr) {\n        GeneralMOEConfig cfg(expert_num, routed_expert_num, hidden_size, intermediate_size);\n        cfg.gpu_experts_mask = reinterpret_cast<uint8_t*>(gpu_experts_mask_ptr);\n        cfg.compute_num_gpu_experts();\n        return cfg;\n      }))\n      .def_readwrite(\"layer_idx\", &GeneralMOEConfig::layer_idx)\n      .def_readwrite(\"pool\", &GeneralMOEConfig::pool)\n\n      .def_readonly(\"num_gpu_experts\", &GeneralMOEConfig::num_gpu_experts)\n      .def_property(\n          \"gpu_experts_mask\",\n          [](const GeneralMOEConfig& self) { return 
reinterpret_cast<uintptr_t>(self.gpu_experts_mask); },\n          [](GeneralMOEConfig& self, uintptr_t val) { self.gpu_experts_mask = reinterpret_cast<uint8_t*>(val); })\n      .DEF_PTR_PROPERTY(GeneralMOEConfig, physical_to_logical_map)\n\n      .DEF_PTR_PROPERTY(GeneralMOEConfig, gate_proj)\n      .DEF_PTR_PROPERTY(GeneralMOEConfig, up_proj)\n      .DEF_PTR_PROPERTY(GeneralMOEConfig, down_proj)\n\n      .DEF_PTR_PROPERTY(GeneralMOEConfig, gate_scale)\n      .DEF_PTR_PROPERTY(GeneralMOEConfig, up_scale)\n      .DEF_PTR_PROPERTY(GeneralMOEConfig, down_scale)\n\n      .DEF_PTR_PROPERTY(GeneralMOEConfig, gate_zero)\n      .DEF_PTR_PROPERTY(GeneralMOEConfig, up_zero)\n      .DEF_PTR_PROPERTY(GeneralMOEConfig, down_zero)\n\n      .def_readwrite(\"quant_config\", &GeneralMOEConfig::quant_config)\n\n      .def_readwrite(\"max_len\", &GeneralMOEConfig::max_len)\n\n      .DEF_PTR_2D_PROPERTY(GeneralMOEConfig, gate_projs)\n      .DEF_PTR_2D_PROPERTY(GeneralMOEConfig, up_projs)\n      .DEF_PTR_2D_PROPERTY(GeneralMOEConfig, down_projs)\n\n      .DEF_PTR_2D_PROPERTY(GeneralMOEConfig, gate_scales)\n      .DEF_PTR_2D_PROPERTY(GeneralMOEConfig, up_scales)\n      .DEF_PTR_2D_PROPERTY(GeneralMOEConfig, down_scales)\n\n      .DEF_PTR_2D_PROPERTY(GeneralMOEConfig, gate_zeros)\n      .DEF_PTR_2D_PROPERTY(GeneralMOEConfig, up_zeros)\n      .DEF_PTR_2D_PROPERTY(GeneralMOEConfig, down_zeros)\n\n      .def_readwrite(\"path\", &GeneralMOEConfig::path)\n      .def_readwrite(\"save\", &GeneralMOEConfig::save)\n      .def_readwrite(\"load\", &GeneralMOEConfig::load)\n      .def_readwrite(\"m_block\", &GeneralMOEConfig::m_block)\n      .def_readwrite(\"group_min_len\", &GeneralMOEConfig::group_min_len)\n      .def_readwrite(\"group_max_len\", &GeneralMOEConfig::group_max_len)\n\n      .def_readwrite(\"gate_type\", &GeneralMOEConfig::gate_type)\n      .def_readwrite(\"up_type\", &GeneralMOEConfig::up_type)\n      .def_readwrite(\"down_type\", &GeneralMOEConfig::down_type)\n      
.def_readwrite(\"hidden_type\", &GeneralMOEConfig::hidden_type)\n\n      ;\n\n  py::class_<MoE_Interface, std::shared_ptr<MoE_Interface>>(moe_module, \"MoE_Interface\");\n\n  bind_moe_module<LLAMA_MOE_TP>(moe_module, \"MOE\");\n\n#if defined(__x86_64__) && defined(USE_AMX_AVX_KERNEL)\n  bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int8>>(moe_module, \"AMXInt8_MOE\");\n  bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int4>>(moe_module, \"AMXInt4_MOE\");\n  bind_moe_module<AMX_MOE_TP<amx::GemmKernel224Int4_1>>(moe_module, \"AMXInt4_1_MOE\");\n  bind_moe_module<AMX_AWQ_MOE_TP<amx::GemmKernel224Int4_1_LowKGroup>>(moe_module, \"AMXInt4_1KGroup_MOE\");\n  bind_moe_module<AMX_K2_MOE_TP<amx::GemmKernel224Int4SmallKGroup>>(moe_module, \"AMXInt4_KGroup_MOE\");\n#if defined(__AVX512BF16__)\n  bind_moe_module<AMX_BF16_MOE_TP<amx::GemmKernel224BF16>>(moe_module, \"AMXBF16_MOE\");\n  bind_moe_module<AMX_FP8_MOE_TP<amx::GemmKernel224FP8>>(moe_module, \"AMXFP8_MOE\");\n  bind_moe_module<AMX_FP8_PERCHANNEL_MOE_TP<amx::GemmKernel224FP8PerChannel>>(moe_module, \"AMXFP8PerChannel_MOE\");\n#endif\n#endif\n#if defined(USE_MOE_KERNEL)\n  bind_moe_module<MOE_KERNEL_TP<moe_kernel::GemmKernelInt8, _is_plain_>>(moe_module, \"Int8_KERNEL_MOE\");\n#if defined(__aarch64__) && defined(CPU_USE_KML)\n  // The INT4 kernel is not yet implemented for AMD, so it is bound only on aarch64 with KML\n  bind_moe_module<MOE_KERNEL_TP<moe_kernel::GemmKernelInt4, _is_plain_>>(moe_module, \"Int4_KERNEL_MOE\");\n#endif\n#endif\n\n  // Expose kernel tiling/runtime parameters so Python can modify them at runtime\n  {\n    auto tiling_module = moe_module.def_submodule(\"tiling\");\n#if defined(USE_MOE_KERNEL)\n    tiling_module.def(\n        \"get_int8\",\n        []() {\n          auto t = moe_kernel::GemmKernelInt8::get_tiling();\n          py::dict d;\n          d[\"n_block_up_gate\"] = std::get<0>(t);\n          d[\"n_block_down\"] = std::get<1>(t);\n          d[\"n_block\"] = std::get<2>(t);\n          d[\"m_block\"] = std::get<3>(t);\n          
d[\"k_block\"] = std::get<4>(t);\n          d[\"n_block_up_gate_prefi\"] = std::get<5>(t);\n          d[\"n_block_down_prefi\"] = std::get<6>(t);\n          return d;\n        },\n        \"Get current tiling parameters for INT8 kernel\");\n    tiling_module.def(\n        \"set_int8\",\n        [](int n_block_up_gate, int n_block_down, int n_block, int m_block, int k_block, int n_block_up_gate_prefi,\n           int n_block_down_prefi) {\n          moe_kernel::GemmKernelInt8::set_tiling(n_block_up_gate, n_block_down, n_block, m_block, k_block,\n                                                 n_block_up_gate_prefi, n_block_down_prefi);\n        },\n        py::arg(\"n_block_up_gate\"), py::arg(\"n_block_down\"), py::arg(\"n_block\"), py::arg(\"m_block\"), py::arg(\"k_block\"),\n        py::arg(\"n_block_up_gate_prefi\"), py::arg(\"n_block_down_prefi\"), \"Set tiling parameters for INT8 kernel\");\n\n    tiling_module.def(\n        \"get_int4\",\n        []() {\n          auto t = moe_kernel::GemmKernelInt4::get_tiling();\n          py::dict d;\n          d[\"n_block_up_gate\"] = std::get<0>(t);\n          d[\"n_block_down\"] = std::get<1>(t);\n          d[\"n_block\"] = std::get<2>(t);\n          d[\"m_block\"] = std::get<3>(t);\n          d[\"k_block\"] = std::get<4>(t);\n          d[\"n_block_up_gate_prefi\"] = std::get<5>(t);\n          d[\"n_block_down_prefi\"] = std::get<6>(t);\n          return d;\n        },\n        \"Get current tiling parameters for INT4 kernel\");\n    tiling_module.def(\n        \"set_int4\",\n        [](int n_block_up_gate, int n_block_down, int n_block, int m_block, int k_block, int n_block_up_gate_prefi,\n           int n_block_down_prefi) {\n          moe_kernel::GemmKernelInt4::set_tiling(n_block_up_gate, n_block_down, n_block, m_block, k_block,\n                                                 n_block_up_gate_prefi, n_block_down_prefi);\n        },\n        py::arg(\"n_block_up_gate\"), py::arg(\"n_block_down\"), 
py::arg(\"n_block\"), py::arg(\"m_block\"), py::arg(\"k_block\"),\n        py::arg(\"n_block_up_gate_prefi\"), py::arg(\"n_block_down_prefi\"), \"Set tiling parameters for INT4 kernel\");\n\n    // Convenience: set both\n    tiling_module.def(\n        \"set_all\",\n        [](int n_block_up_gate, int n_block_down, int n_block, int m_block, int k_block, int n_block_up_gate_prefi,\n           int n_block_down_prefi) {\n          moe_kernel::GemmKernelInt8::set_tiling(n_block_up_gate, n_block_down, n_block, m_block, k_block,\n                                                 n_block_up_gate_prefi, n_block_down_prefi);\n          moe_kernel::GemmKernelInt4::set_tiling(n_block_up_gate, n_block_down, n_block, m_block, k_block,\n                                                 n_block_up_gate_prefi, n_block_down_prefi);\n        },\n        py::arg(\"n_block_up_gate\"), py::arg(\"n_block_down\"), py::arg(\"n_block\"), py::arg(\"m_block\"), py::arg(\"k_block\"),\n        py::arg(\"n_block_up_gate_prefi\"), py::arg(\"n_block_down_prefi\"),\n        \"Set tiling parameters for both INT8 and INT4 kernels\");\n#endif\n  }\n\n  auto kvcache_module = m.def_submodule(\"kvcache\");\n\n  py::enum_<AnchorType>(kvcache_module, \"AnchorType\")\n      .value(\"FIXED\", AnchorType::FIXED_ANCHOR)\n      .value(\"DYNAMIC\", AnchorType::DYNAMIC)\n      .value(\"QUEST\", AnchorType::QUEST)\n      .value(\"BLOCK_MAX\", AnchorType::BLOCK_MAX)\n      .value(\"BLOCK_MEAN\", AnchorType::BLOCK_MEAN);\n  py::enum_<ggml_type>(kvcache_module, \"ggml_type\")\n      // .value(\"FP16\", ggml_type::GGML_TYPE_F16)\n      // .value(\"FP32\", ggml_type::GGML_TYPE_F32)\n      // .value(\"Q4_0\", ggml_type::GGML_TYPE_Q4_0)\n      // .value(\"Q8_0\", ggml_type::GGML_TYPE_Q8_0)\n      .value(\"FP32\", GGML_TYPE_F32)\n      .value(\"FP16\", GGML_TYPE_F16)\n      .value(\"Q4_0\", GGML_TYPE_Q4_0)\n      .value(\"Q4_1\", GGML_TYPE_Q4_1)\n      .value(\"Q5_0\", GGML_TYPE_Q5_0)\n      .value(\"Q5_1\", 
GGML_TYPE_Q5_1)\n      .value(\"Q8_0\", GGML_TYPE_Q8_0)\n      .value(\"Q8_1\", GGML_TYPE_Q8_1)\n      .value(\"Q2_K\", GGML_TYPE_Q2_K)\n      .value(\"Q3_K\", GGML_TYPE_Q3_K)\n      .value(\"Q4_K\", GGML_TYPE_Q4_K)\n      .value(\"Q5_K\", GGML_TYPE_Q5_K)\n      .value(\"Q6_K\", GGML_TYPE_Q6_K)\n      .value(\"Q8_K\", GGML_TYPE_Q8_K)\n      .value(\"IQ2_XXS\", GGML_TYPE_IQ2_XXS)\n      .value(\"IQ2_XS\", GGML_TYPE_IQ2_XS)\n      .value(\"IQ3_XXS\", GGML_TYPE_IQ3_XXS)\n      .value(\"IQ1_S\", GGML_TYPE_IQ1_S)\n      .value(\"IQ4_NL\", GGML_TYPE_IQ4_NL)\n      .value(\"IQ3_S\", GGML_TYPE_IQ3_S)\n      .value(\"IQ2_S\", GGML_TYPE_IQ2_S)\n      .value(\"IQ4_XS\", GGML_TYPE_IQ4_XS)\n      .value(\"I8\", GGML_TYPE_I8)\n      .value(\"I16\", GGML_TYPE_I16)\n      .value(\"I32\", GGML_TYPE_I32)\n      .value(\"I64\", GGML_TYPE_I64)\n      .value(\"F64\", GGML_TYPE_F64)\n      .value(\"IQ1_M\", GGML_TYPE_IQ1_M)\n      .value(\"BF16\", GGML_TYPE_BF16)\n      .export_values();\n\n  py::enum_<RetrievalType>(kvcache_module, \"RetrievalType\")\n      .value(\"LAYER\", RetrievalType::LAYER)\n      .value(\"KVHEAD\", RetrievalType::KVHEAD)\n      .value(\"QHEAD\", RetrievalType::QHEAD);\n\n  py::class_<KVCacheConfig>(kvcache_module, \"KVCacheConfig\")\n      .def(py::init<int, int, int, int, int, int, AnchorType, ggml_type, RetrievalType, int, int, int, int, int, int>())\n      .def_readwrite(\"layer_num\", &KVCacheConfig::layer_num)\n      .def_readwrite(\"kv_head_num\", &KVCacheConfig::kv_head_num)\n      .def_readwrite(\"q_head_num\", &KVCacheConfig::q_head_num)\n      .def_readwrite(\"head_dim\", &KVCacheConfig::head_dim)\n      .def_readwrite(\"block_len\", &KVCacheConfig::block_len)\n      .def_readwrite(\"anchor_num\", &KVCacheConfig::anchor_num)\n      .def_readwrite(\"anchor_type\", &KVCacheConfig::anchor_type)\n      .def_readwrite(\"kv_type\", &KVCacheConfig::kv_type)\n      .def_readwrite(\"retrieval_type\", &KVCacheConfig::retrieval_type)\n      
.def_readwrite(\"layer_step\", &KVCacheConfig::layer_step)\n      .def_readwrite(\"token_step\", &KVCacheConfig::token_step)\n      .def_readwrite(\"layer_offset\", &KVCacheConfig::layer_offset)\n      .def_readwrite(\"max_block_num\", &KVCacheConfig::max_block_num)\n      .def_readwrite(\"max_batch_size\", &KVCacheConfig::max_batch_size)\n      .def_readwrite(\"max_thread_num\", &KVCacheConfig::max_thread_num);\n  py::class_<KVCache>(kvcache_module, \"KVCache\")\n      .def(py::init<KVCacheConfig>())\n      .def(\"get_cache_total_len\", &KVCache::get_cache_total_len)\n      .def(\"update_cache_total_len\",\n           [](KVCache& kvcache, int cache_total_len) { kvcache.update_cache_total_len(cache_total_len); });\n\n  auto utils = m.def_submodule(\"utils\");\n\n  // Register the conversion functions\n  utils.def(\"to_float\", &to_float_ptr, \"Convert tensor from any GGML type to float32\", py::arg(\"input\"),\n            py::arg(\"size\"), py::arg(\"type\"));\n\n  utils.def(\"from_float\", &from_float_ptr, \"Convert tensor from float32 to any GGML type\", py::arg(\"input\"),\n            py::arg(\"size\"), py::arg(\"type\"));\n}\n"
  },
  {
    "path": "kt-kernel/install.sh",
    "content": "#!/usr/bin/env bash\nset -euo pipefail\n\nusage() {\n  cat <<EOF\nUsage: $0 [SUBCOMMAND] [BUILD_OPTIONS]\n\nTwo-step installation in one file. Choose a subcommand:\n\nSUBCOMMANDS:\n  deps            Install system prerequisites only\n  build           Build and install kt-kernel (no dependency install)\n  all             Run deps then build (default when no subcommand)\n  -h, --help      Show this help message\n\nBUILD_OPTIONS (for \"build\" or \"all\"):\n  (none)          Auto-detect CPU and configure automatically (recommended)\n  --manual        Skip auto-detection, use manual configuration (see below)\n  --no-clean      Do not delete local build/ before building (default cleans)\n\nAUTO-DETECTION (Default):\n  The script will automatically detect your CPU and use ALL available features:\n  - CPUINFER_CPU_INSTRUCT = NATIVE (uses -march=native)\n  - CPUINFER_ENABLE_AMX   = ON/OFF (based on detection)\n  - CPUINFER_ENABLE_AVX512_VNNI = ON/OFF (with fallback if OFF)\n  - CPUINFER_ENABLE_AVX512_BF16 = ON/OFF (with fallback if OFF)\n  - CPUINFER_ENABLE_AVX512_VBMI = ON/OFF (required for FP8 MoE)\n\n  ✓ Best performance on YOUR machine\n  ✗ Binary may NOT work on different/older CPUs\n\n  Use this when: Installing for local use only\n\nMANUAL CONFIGURATION:\n  Use --manual flag when building for DISTRIBUTION or different machines.\n  Set these environment variables before running:\n\n  CPUINFER_CPU_INSTRUCT   - Target CPU instruction set\n                            Options: AVX512, AVX2, FANCY, NATIVE\n  CPUINFER_ENABLE_AMX     - Enable Intel AMX support\n                            Options: ON, OFF\n\nDistribution examples (portable binaries):\n\n┌──────────────────────────────────────────────────────────────────────────┐\n│ Configuration          │ Target CPUs              │ Use Case             │\n├────────────────────────┼──────────────────────────┼──────────────────────┤\n│ AVX512 + AMX=OFF       │ Skylake-X, Ice Lake,     │ General distribution 
│\n│                        │ Cascade Lake, Zen 4      │ (recommended)        │\n├────────────────────────┼──────────────────────────┼──────────────────────┤\n│ AVX2 + AMX=OFF         │ Haswell (2013) and newer │ Maximum compatibility│\n├────────────────────────┼──────────────────────────┼──────────────────────┤\n│ FANCY + AMX=OFF        │ Ice Lake+, Zen 4+        │ Modern CPUs only     │\n│                        │ (with full AVX512 ext)   │                      │\n└────────────────────────┴──────────────────────────┴──────────────────────┘\n\n  Use this when: Building Docker images, PyPI packages, or deploying to clusters\n\n  Example: Build for general distribution\n    export CPUINFER_CPU_INSTRUCT=AVX512\n    export CPUINFER_ENABLE_AMX=OFF\n    $0 build --manual\n    # Result: Works on any CPU with AVX512 (2017+)\n\n  Example: Build for maximum compatibility\n    export CPUINFER_CPU_INSTRUCT=AVX2\n    export CPUINFER_ENABLE_AMX=OFF\n    $0 build --manual\n    # Result: Works on any CPU with AVX2 (2013+)\n\nOptional variables (with defaults):\n  CPUINFER_BUILD_TYPE=Release           Build type (Debug/RelWithDebInfo/Release)\n  CPUINFER_PARALLEL=8                   Number of parallel build jobs\n  CPUINFER_VERBOSE=1                    Verbose build output (0/1)\n  CPUINFER_ENABLE_AVX512_VNNI=ON/OFF    Override VNNI detection (auto if unset)\n  CPUINFER_ENABLE_AVX512_BF16=ON/OFF    Override BF16 detection (auto if unset)\n  CPUINFER_ENABLE_AVX512_VBMI=ON/OFF    Override VBMI detection (auto if unset)\n\nSoftware Fallback Support:\n  ✓ If VNNI not available: Uses AVX512BW fallback (2-3x slower but works)\n  ✓ If BF16 not available: Uses AVX512F fallback (5-10x slower but works)\n  → Old CPUs with only AVX512F+BW can run all code (slower but functional)\n\nEOF\n  exit 1\n}\n\ninstall_dependencies() {\n  echo \"Checking and installing system dependencies...\"\n\n  # Determine if we need to use sudo\n  SUDO=\"\"\n  if [ \"${EUID:-0}\" -ne 0 ]; then\n    if command -v 
sudo &> /dev/null; then\n      SUDO=\"sudo\"\n    else\n      echo \"Warning: Not running as root and sudo not found. Package installation may fail.\"\n      echo \"Please run as root or install sudo.\"\n    fi\n  fi\n\n  if command -v conda &> /dev/null; then\n    echo \"Installing cmake via conda...\"\n    conda install -y cmake\n  else\n    echo \"Warning: conda not found. Skipping cmake installation via conda.\"\n    echo \"Please install conda or manually install cmake.\"\n  fi\n\n  # Detect OS type\n  if [ -f /etc/os-release ]; then\n    . /etc/os-release\n    OS=$ID\n  elif [ -f /etc/debian_version ]; then\n    OS=\"debian\"\n  elif [ -f /etc/redhat-release ]; then\n    OS=\"rhel\"\n  else\n    echo \"Warning: Unable to detect OS type. Skipping dependency installation.\"\n    return 0\n  fi\n\n  # Install dependencies based on OS\n  case \"$OS\" in\n    debian|ubuntu|linuxmint|pop)\n      echo \"Detected Debian-based system. Installing libhwloc-dev and pkg-config...\"\n      $SUDO apt update\n      $SUDO apt install -y libhwloc-dev pkg-config\n      ;;\n    fedora|rhel|centos|rocky|almalinux)\n      echo \"Detected Red Hat-based system. Installing hwloc-devel and pkgconfig...\"\n      $SUDO dnf install -y hwloc-devel pkgconfig || $SUDO yum install -y hwloc-devel pkgconfig\n      ;;\n    arch|manjaro)\n      echo \"Detected Arch-based system. Installing hwloc and pkgconf...\"\n      $SUDO pacman -S --noconfirm hwloc pkgconf\n      ;;\n    opensuse*|sles)\n      echo \"Detected openSUSE-based system. Installing hwloc-devel and pkg-config...\"\n      $SUDO zypper install -y hwloc-devel pkg-config\n      ;;\n    *)\n      echo \"Warning: Unsupported OS '$OS'. 
Please manually install libhwloc-dev and pkg-config.\"\n      ;;\n  esac\n}\n\n# Function to detect CPU features\n# Returns: \"has_amx has_avx512f has_avx512_vnni has_avx512_bf16 has_avx512_vbmi\" (space-separated 0/1 values)\ndetect_cpu_features() {\n  local has_amx=0\n  local has_avx512f=0\n  local has_avx512_vnni=0\n  local has_avx512_bf16=0\n  local has_avx512_vbmi=0\n\n  if [ -f /proc/cpuinfo ]; then\n    local cpu_flags\n    cpu_flags=$(grep -m1 \"^flags\" /proc/cpuinfo | tr ' ' '\\n')\n\n    # Check for AMX support on Linux\n    if echo \"$cpu_flags\" | grep -qE \"amx_tile|amx_int8|amx_bf16\"; then\n      has_amx=1\n    fi\n\n    # Check for AVX512F (foundation)\n    if echo \"$cpu_flags\" | grep -qE \"avx512f\"; then\n      has_avx512f=1\n    fi\n\n    # Check for AVX512_VNNI support\n    if echo \"$cpu_flags\" | grep -qE \"avx512_vnni|avx512vnni\"; then\n      has_avx512_vnni=1\n    fi\n\n    # Check for AVX512_BF16 support\n    if echo \"$cpu_flags\" | grep -qE \"avx512_bf16|avx512bf16\"; then\n      has_avx512_bf16=1\n    fi\n\n    # Check for AVX512_VBMI support\n    if echo \"$cpu_flags\" | grep -qE \"avx512_vbmi|avx512vbmi\"; then\n      has_avx512_vbmi=1\n    fi\n  elif [ \"$(uname)\" = \"Darwin\" ]; then\n    # macOS doesn't have AMX (ARM or Intel without AMX)\n    has_amx=0\n    has_avx512f=0\n    has_avx512_vnni=0\n    has_avx512_bf16=0\n    has_avx512_vbmi=0\n  fi\n\n  echo \"$has_amx $has_avx512f $has_avx512_vnni $has_avx512_bf16 $has_avx512_vbmi\"\n}\n\nbuild_step() {\n  # Parse build-only flags from arguments to this function\n  local MANUAL_MODE=0\n  local CLEAN_BUILD=1\n  while [[ $# -gt 0 ]]; do\n    case \"$1\" in\n      --manual) MANUAL_MODE=1; shift ;;\n      --no-clean) CLEAN_BUILD=0; shift ;;\n      -h|--help) usage ;;\n      *) break ;;\n    esac\n  done\n\n  # Clean local build directory to ensure a fresh CMake/configure\n  local REPO_ROOT\n  REPO_ROOT=\"$(cd \"$(dirname \"${BASH_SOURCE[0]}\")\" && pwd)\"\n  if [[ \"$CLEAN_BUILD\" 
-eq 1 ]]; then\n    if [[ -d \"$REPO_ROOT/build\" ]]; then\n      echo \"Cleaning previous build directory: $REPO_ROOT/build\"\n      rm -rf \"$REPO_ROOT/build\"\n    fi\n  else\n    echo \"Skipping clean of $REPO_ROOT/build (requested by --no-clean)\"\n  fi\n\n  if [ \"$MANUAL_MODE\" = \"0\" ]; then\n  # Auto-detection mode\n  echo \"==========================================\"\n  echo \"Auto-detecting CPU capabilities...\"\n  echo \"==========================================\"\n  echo \"\"\n\n  # detect_cpu_features returns \"has_amx has_avx512f has_avx512_vnni has_avx512_bf16 has_avx512_vbmi\"\n  CPU_FEATURES=$(detect_cpu_features)\n  HAS_AMX=$(echo \"$CPU_FEATURES\" | cut -d' ' -f1)\n  HAS_AVX512F=$(echo \"$CPU_FEATURES\" | cut -d' ' -f2)\n  HAS_AVX512_VNNI=$(echo \"$CPU_FEATURES\" | cut -d' ' -f3)\n  HAS_AVX512_BF16=$(echo \"$CPU_FEATURES\" | cut -d' ' -f4)\n  HAS_AVX512_VBMI=$(echo \"$CPU_FEATURES\" | cut -d' ' -f5)\n\n  export CPUINFER_CPU_INSTRUCT=NATIVE\n\n  if [ \"$HAS_AMX\" = \"1\" ]; then\n    echo \"✓ AMX instructions detected\"\n    export CPUINFER_ENABLE_AMX=ON\n    echo \"\"\n    echo \"Configuration: NATIVE + AMX=ON\"\n    echo \"  ✓ Best performance on this machine\"\n    echo \"  ✗ Binary requires Sapphire Rapids or newer CPU\"\n  else\n    echo \"ℹ AMX instructions not detected\"\n    export CPUINFER_ENABLE_AMX=OFF\n    echo \"\"\n    echo \"Configuration: NATIVE + AMX=OFF\"\n    echo \"  ✓ Using AVX512/AVX2 instructions\"\n  fi\n\n  echo \"\"\n  echo \"  ⚠️  IMPORTANT: This binary is optimized for THIS CPU only\"\n  echo \"     To build portable binaries for distribution, use:\"\n  echo \"       export CPUINFER_CPU_INSTRUCT=AVX512  # or AVX2\"\n  echo \"       export CPUINFER_ENABLE_AMX=OFF\"\n  echo \"       ./install.sh build --manual\"\n\n  # Fine-grained AVX512 subset detection (with fallback support)\n  echo \"\"\n  echo \"AVX512 Feature Detection:\"\n\n  # AVX512F: Foundation (required for all AVX512 variants)\n  if [ \"$HAS_AVX512F\" = 
\"1\" ]; then\n    echo \"  AVX512F: ✓ Detected (foundation)\"\n  else\n    echo \"  AVX512F: ✗ Not detected (AVX512 not available)\"\n  fi\n\n  # VNNI: Check if user manually set it, otherwise auto-detect\n  if [ -n \"${CPUINFER_ENABLE_AVX512_VNNI:-}\" ]; then\n    echo \"  VNNI: User override = $CPUINFER_ENABLE_AVX512_VNNI\"\n  else\n    if [ \"$HAS_AVX512_VNNI\" = \"1\" ]; then\n      echo \"  VNNI: ✓ Detected (hardware acceleration enabled)\"\n      export CPUINFER_ENABLE_AVX512_VNNI=ON\n    else\n      echo \"  VNNI: ✗ Not detected (will use software fallback, 2-3x slower)\"\n      export CPUINFER_ENABLE_AVX512_VNNI=OFF\n    fi\n  fi\n\n  # BF16: Check if user manually set it, otherwise auto-detect\n  if [ -n \"${CPUINFER_ENABLE_AVX512_BF16:-}\" ]; then\n    echo \"  BF16: User override = $CPUINFER_ENABLE_AVX512_BF16\"\n  else\n    if [ \"$HAS_AVX512_BF16\" = \"1\" ]; then\n      echo \"  BF16: ✓ Detected (hardware acceleration enabled)\"\n      export CPUINFER_ENABLE_AVX512_BF16=ON\n    else\n      echo \"  BF16: ✗ Not detected (will use software fallback, 5-10x slower)\"\n      export CPUINFER_ENABLE_AVX512_BF16=OFF\n    fi\n  fi\n\n  # VBMI: Check if user manually set it, otherwise auto-detect\n  if [ -n \"${CPUINFER_ENABLE_AVX512_VBMI:-}\" ]; then\n    echo \"  VBMI: User override = $CPUINFER_ENABLE_AVX512_VBMI\"\n  else\n    if [ \"$HAS_AVX512_VBMI\" = \"1\" ]; then\n      echo \"  VBMI: ✓ Detected (byte permutation enabled)\"\n      export CPUINFER_ENABLE_AVX512_VBMI=ON\n    else\n      echo \"  VBMI: ✗ Not detected (FP8 MoE may not work)\"\n      export CPUINFER_ENABLE_AVX512_VBMI=OFF\n    fi\n  fi\n\n  echo \"\"\n  echo \"  Note: Software fallbacks ensure all code works on older CPUs\"\n  echo \"  Note: FP8 MoE requires AVX512F + BF16 + VNNI + VBMI\"\n  echo \"  Tip: Override with CPUINFER_ENABLE_AVX512_[VNNI|BF16|VBMI]=ON/OFF\"\n\n  echo \"\"\n  echo \"To use manual configuration instead, run: $0 build --manual\"\n  echo \"\"\n  else\n  # Manual mode 
- validate user configuration (no exports)\n  if [ -z \"$CPUINFER_CPU_INSTRUCT\" ] || [ -z \"$CPUINFER_ENABLE_AMX\" ]; then\n    echo \"Error: Manual mode requires CPUINFER_CPU_INSTRUCT and CPUINFER_ENABLE_AMX to be set.\"\n    echo \"\"\n    usage\n  fi\n\n  # Validate CPUINFER_CPU_INSTRUCT\n  case \"$CPUINFER_CPU_INSTRUCT\" in\n    NATIVE|FANCY|AVX512|AVX2)\n      ;;\n    *)\n      echo \"Error: Invalid CPUINFER_CPU_INSTRUCT='$CPUINFER_CPU_INSTRUCT'\"\n      echo \"Must be one of: NATIVE, FANCY, AVX512, AVX2\"\n      exit 1\n      ;;\n  esac\n\n  # Validate CPUINFER_ENABLE_AMX\n  case \"$CPUINFER_ENABLE_AMX\" in\n    ON|OFF)\n      ;;\n    *)\n      echo \"Error: Invalid CPUINFER_ENABLE_AMX='$CPUINFER_ENABLE_AMX'\"\n      echo \"Must be either: ON or OFF\"\n      exit 1\n      ;;\n  esac\n\n  # Warn about problematic configuration\n  if [ \"$CPUINFER_CPU_INSTRUCT\" = \"NATIVE\" ] && [ \"$CPUINFER_ENABLE_AMX\" = \"OFF\" ]; then\n    CPU_FEATURES=$(detect_cpu_features)\n    HAS_AMX=$(echo \"$CPU_FEATURES\" | cut -d' ' -f1)\n    if [ \"$HAS_AMX\" = \"1\" ]; then\n      echo \"==========================================\"\n      echo \"⚠️  WARNING: Risky Configuration\"\n      echo \"==========================================\"\n      echo \"\"\n      echo \"Your configuration:\"\n      echo \"  CPUINFER_CPU_INSTRUCT = NATIVE\"\n      echo \"  CPUINFER_ENABLE_AMX   = OFF\"\n      echo \"\"\n      echo \"Your CPU HAS AMX support!\"\n      echo \"\"\n      echo \"Problem:\"\n      echo \"  • NATIVE uses -march=native which auto-enables ALL CPU features\"\n      echo \"  • This may IGNORE your AMX=OFF setting\"\n      echo \"  • The binary may still contain AMX instructions\"\n      echo \"\"\n      echo \"Recommended fixes:\"\n      echo \"  1) For portable build (recommended for distribution):\"\n      echo \"       export CPUINFER_CPU_INSTRUCT=AVX512\"\n      echo \"\"\n      echo \"  2) If you want best performance on this CPU:\"\n      echo \"       export 
CPUINFER_ENABLE_AMX=ON\"\n      echo \"\"\n      read -p \"Continue with risky configuration? (y/N) \" -n 1 -r\n      echo\n      if [[ ! $REPLY =~ ^[Yy]$ ]]; then\n        exit 1\n      fi\n    fi\n  fi\n\n# Close MANUAL_MODE conditional\n  fi\n\necho \"==========================================\"\necho \"Building kt-kernel with configuration:\"\necho \"==========================================\"\necho \"  CPUINFER_CPU_INSTRUCT        = $CPUINFER_CPU_INSTRUCT\"\necho \"  CPUINFER_ENABLE_AMX          = $CPUINFER_ENABLE_AMX\"\necho \"  CPUINFER_ENABLE_AVX512_VNNI  = ${CPUINFER_ENABLE_AVX512_VNNI:-AUTO}\"\necho \"  CPUINFER_ENABLE_AVX512_BF16  = ${CPUINFER_ENABLE_AVX512_BF16:-AUTO}\"\necho \"  CPUINFER_ENABLE_AVX512_VBMI  = ${CPUINFER_ENABLE_AVX512_VBMI:-AUTO}\"\necho \"  CPUINFER_BUILD_TYPE          = ${CPUINFER_BUILD_TYPE:-Release}\"\necho \"  CPUINFER_PARALLEL            = ${CPUINFER_PARALLEL:-AUTO}\"\necho \"  CPUINFER_VERBOSE             = ${CPUINFER_VERBOSE:-1}\"\necho \"\"\n\n# Quote the expansion so an empty or multi-word value cannot break the test\nif [ \"${CPUINFER_VERBOSE:-1}\" = \"0\" ]; then\n  python3 -m pip install .\nelse\n  python3 -m pip install . -v\nfi\n}\n\n# Subcommand dispatcher: default to \"all\"\nSUBCMD=\"all\"\nif [[ $# -gt 0 ]]; then\n  case \"$1\" in\n    deps|build|all) SUBCMD=\"$1\"; shift ;;\n    -h|--help) usage ;;\n    *) SUBCMD=\"build\" ;; # backward compatibility: flags-only => build\n  esac\nfi\n\ncase \"$SUBCMD\" in\n  deps)\n    install_dependencies\n    ;;\n  build)\n    build_step \"$@\"\n    ;;\n  all)\n    install_dependencies\n    build_step \"$@\"\n    ;;\nesac\n"
  },
  {
    "path": "kt-kernel/operators/amx/awq-moe.hpp",
    "content": "/**\n * @Description  : AWQ Int4 AMX MoE operator with KGroup quantization and zero-point support\n * @Author       : chenht2022, oql\n * @Date         : 2024-07-22 02:03:22\n * @Version      : 2.0.0\n * @LastEditors  : oql\n * @LastEditTime : 2025-12-10\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n *\n * This file implements AWQ Int4 MoE using CRTP pattern, inheriting from moe_base.hpp.\n * AWQ weights are stored with group-wise scales and zero-points (KGroup Int4 with zeros).\n **/\n#ifndef CPUINFER_OPERATOR_AMX_AWQ_MOE_H\n#define CPUINFER_OPERATOR_AMX_AWQ_MOE_H\n\n// #define CHECK\n\n#include \"moe_base.hpp\"\n\n/**\n * @brief AWQ Int4 MoE operator using CRTP pattern\n * @tparam T Kernel type for AWQ quantization\n *\n * This class provides AWQ-specific implementations:\n * - do_gate_up_gemm: Int4 weight with KGroup scale + zeros + AMX GEMM\n * - do_down_gemm: Same Int4 KGroup GEMM\n * - load_weights: Load Int4 weights with group-wise scales and zero-points\n */\ntemplate <class T>\nclass AMX_AWQ_MOE_TP : public AMX_MOE_BASE<T, AMX_AWQ_MOE_TP<T>> {\n private:\n  using Base = AMX_MOE_BASE<T, AMX_AWQ_MOE_TP<T>>;\n  using Base::config_;\n  using Base::down_ba_;\n  using Base::down_bb_;\n  using Base::down_bc_;\n  using Base::gate_bb_;\n  using Base::gate_bc_;\n  using Base::gate_up_ba_;\n  using Base::m_local_num_;\n  using Base::tp_part_idx;\n  using Base::up_bb_;\n  using Base::up_bc_;\n\n#ifdef CHECK\n  char verify_bb[100000000];\n  char check_bb[100000000];\n  uint8_t compare_expers = 3;\n#endif\n\n  inline void write_weights(std::filesystem::path prefix, std::string mat_class, char* bb, int expert_idx, size_t size,\n                            size_t scale_size) {\n    std::ofstream of(prefix / (T::name() + mat_class + std::to_string(expert_idx) + \"_\" +\n                               std::to_string(size - scale_size) + \"Byte\" + \"_quant_\" + \".kt\"));\n    if (of.is_open() == false) {\n      printf(\"Failed to open weights 
file for writing\\n\");\n      return;\n    }\n    of.write((char*)bb, size - scale_size);\n    of.close();\n\n    of.open(prefix / (T::name() + mat_class + std::to_string(expert_idx) + \"_\" + std::to_string(scale_size) + \"Byte\" +\n                      \"_scale_\" + \".kt\"));\n    if (of.is_open() == false) {\n      printf(\"Failed to open scales file for writing\\n\");\n      return;\n    }\n    of.write(((char*)bb) + size - scale_size, scale_size);\n    of.close();\n  }\n\n  // Enhanced version that writes all data including mins for complete comparison\n  inline void write_weights(std::filesystem::path prefix, std::string mat_class, typename T::BufferB* buffer,\n                            int expert_idx, const std::string& quantization_type = \"\") {\n    auto& quant_config = config_.quant_config;\n    int& group_size = quant_config.group_size;\n\n    // Calculate dimensions based on matrix type\n    int rows, cols, num_groups;\n    size_t scale_elem_count;\n    std::string matrix_type = mat_class.substr(1, mat_class.length() - 2);  // Remove leading/trailing underscore\n    if (matrix_type == \"gate\" || matrix_type == \"up\") {\n      rows = config_.intermediate_size;\n      cols = config_.hidden_size;\n      num_groups = cols / group_size;\n      scale_elem_count = num_groups * rows;\n    } else {  // down\n      rows = config_.hidden_size;\n      cols = config_.intermediate_size;\n      num_groups = cols / group_size;\n      scale_elem_count = num_groups * rows;\n    }\n\n    size_t weight_size = (rows * cols) / 2;  // INT4 packed\n    size_t scale_size = scale_elem_count * sizeof(float);\n\n    // Create filename prefix\n    std::string filename_base = T::name() + mat_class + std::to_string(expert_idx);\n    if (!quantization_type.empty()) {\n      filename_base += \"_\" + quantization_type;\n    }\n\n    // Write quantized weights\n    std::ofstream of(prefix / (filename_base + \"_\" + std::to_string(weight_size) + \"Byte_quant.kt\"));\n    if 
(of.is_open()) {\n      of.write((char*)buffer->b, weight_size);\n      of.close();\n    }\n\n    // Write scales\n    of.open(prefix / (filename_base + \"_\" + std::to_string(scale_size) + \"Byte_scale.kt\"));\n    if (of.is_open()) {\n      of.write((char*)buffer->d, scale_size);\n      of.close();\n    }\n\n    // Write mins if available\n    if (quant_config.zero_point && buffer->mins) {\n      of.open(prefix / (filename_base + \"_\" + std::to_string(scale_size) + \"Byte_mins.kt\"));\n      if (of.is_open()) {\n        of.write((char*)buffer->mins, scale_size);\n        of.close();\n      }\n    }\n  }\n\n  inline void read_weights(std::filesystem::path prefix, std::string mat_class, char* bb, int expert_idx, size_t size,\n                           size_t scale_size, uint8_t mat_split, uint8_t mat_split_idex) {\n    std::ifstream f(prefix / (T::name() + mat_class + std::to_string(expert_idx) + \"_\" +\n                              std::to_string(size - scale_size) + \"Byte\" + \"_quant_\" + \".kt\"));\n    if (f.is_open() == false) {\n      printf(\"Failed to open quantized weights file for reading\\n\");\n      return;\n    }\n    f.seekg(mat_split_idex * (size - scale_size) / mat_split);\n    f.read(((char*)bb) + mat_split_idex * (size - scale_size) / mat_split, (size - scale_size) / mat_split);\n    f.close();\n\n    f.open(prefix / (T::name() + mat_class + std::to_string(expert_idx) + \"_\" + std::to_string(scale_size) + \"Byte\" +\n                     \"_scale_\" + \".kt\"));\n    if (f.is_open() == false) {\n      printf(\"Failed to open scales file for reading\\n\");\n      return;\n    }\n    f.seekg(mat_split_idex * scale_size / mat_split);\n    f.read((((char*)bb) + size - scale_size) + mat_split_idex * scale_size / mat_split, scale_size / mat_split);\n    f.close();\n  }\n\n  // Enhanced version that reads all data including mins\n  inline bool read_weights(std::filesystem::path prefix, std::string mat_class, typename T::BufferB* buffer,\n         
                  int expert_idx, const std::string& quantization_type = \"\") {\n    auto& quant_config = config_.quant_config;\n    int& group_size = quant_config.group_size;\n\n    // Calculate dimensions based on matrix type\n    int rows, cols, num_groups;\n    size_t scale_elem_count;\n    std::string matrix_type = mat_class.substr(1, mat_class.length() - 2);  // Remove leading/trailing underscore\n    if (matrix_type == \"gate\" || matrix_type == \"up\") {\n      rows = config_.intermediate_size;\n      cols = config_.hidden_size;\n      num_groups = cols / group_size;\n      scale_elem_count = num_groups * rows;\n    } else {  // down\n      rows = config_.hidden_size;\n      cols = config_.intermediate_size;\n      num_groups = cols / group_size;\n      scale_elem_count = num_groups * rows;\n    }\n\n    size_t weight_size = (rows * cols) / 2;  // INT4 packed\n    size_t scale_size = scale_elem_count * sizeof(float);\n\n    // Create filename prefix\n    std::string filename_base = T::name() + mat_class + std::to_string(expert_idx);\n    if (!quantization_type.empty()) {\n      filename_base += \"_\" + quantization_type;\n    }\n\n    // Read quantized weights\n    std::ifstream f(prefix / (filename_base + \"_\" + std::to_string(weight_size) + \"Byte_quant.kt\"));\n    if (!f.is_open()) {\n      return false;\n    }\n    f.read((char*)buffer->b, weight_size);\n    f.close();\n\n    // Read scales\n    f.open(prefix / (filename_base + \"_\" + std::to_string(scale_size) + \"Byte_scale.kt\"));\n    if (!f.is_open()) {\n      return false;\n    }\n    f.read((char*)buffer->d, scale_size);\n    f.close();\n\n    // Read mins if available and buffer supports it\n    if (quant_config.zero_point && buffer->mins) {\n      f.open(prefix / (filename_base + \"_\" + std::to_string(scale_size) + \"Byte_mins.kt\"));\n      if (f.is_open()) {\n        f.read((char*)buffer->mins, scale_size);\n        f.close();\n      }\n    }\n\n    return true;\n  }\n\n  // AWQ-specific 
function to read quantized weights, scales and zeros from files\n  inline void read_awq_weights(std::filesystem::path prefix, std::string proj_name, int expert_idx, char* weights_buf,\n                               float* scales_buf, uint8_t* zeros_buf, size_t weights_size, size_t scales_size,\n                               size_t zeros_size, uint8_t mat_split, uint8_t mat_split_idx) {\n    // Read qweights (quantized weights)\n    std::string weights_filename = proj_name + \".qweight.\" + std::to_string(expert_idx) + \".bin\";\n    std::ifstream weights_file(prefix / weights_filename, std::ios::binary);\n    if (!weights_file.is_open()) {\n      printf(\"Failed to open weights file: %s\\n\", (prefix / weights_filename).c_str());\n      throw std::runtime_error(\"Failed to open weights file: \" + weights_filename);\n    }\n\n    weights_file.seekg(mat_split_idx * weights_size / mat_split);\n    weights_file.read(weights_buf + mat_split_idx * weights_size / mat_split, weights_size / mat_split);\n    weights_file.close();\n\n    // Read scales\n    std::string scales_filename = proj_name + \".scales.\" + std::to_string(expert_idx) + \".bin\";\n    std::ifstream scales_file(prefix / scales_filename, std::ios::binary);\n    if (!scales_file.is_open()) {\n      printf(\"Failed to open scales file: %s\\n\", (prefix / scales_filename).c_str());\n      throw std::runtime_error(\"Failed to open scales file: \" + scales_filename);\n    }\n\n    scales_file.seekg(mat_split_idx * scales_size / mat_split);\n    scales_file.read(reinterpret_cast<char*>(scales_buf) + mat_split_idx * scales_size / mat_split,\n                     scales_size / mat_split);\n    scales_file.close();\n\n    // Read qzeros (quantized zeros)\n    std::string zeros_filename = proj_name + \".qzeros.\" + std::to_string(expert_idx) + \".bin\";\n    std::ifstream zeros_file(prefix / zeros_filename, std::ios::binary);\n    if (!zeros_file.is_open()) {\n      printf(\"Failed to open zeros file: %s\\n\", 
(prefix / zeros_filename).c_str());\n      throw std::runtime_error(\"Failed to open zeros file: \" + zeros_filename);\n    }\n\n    zeros_file.seekg(mat_split_idx * zeros_size / mat_split);\n    zeros_file.read(reinterpret_cast<char*>(zeros_buf) + mat_split_idx * zeros_size / mat_split,\n                    zeros_size / mat_split);\n    zeros_file.close();\n  }\n\n#ifdef CHECK\n  inline void load_check() {\n    memcpy(check_bb, (char*)down_bb_[compare_expers]->b,\n           T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, config_.quant_config.group_size));\n  }\n\n  void verify_load_right() {\n    memcpy(verify_bb, (char*)down_bb_[compare_expers]->b,\n           T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, config_.quant_config.group_size));\n    if (memcmp(verify_bb, check_bb,\n               T::BufferB::required_size(config_.hidden_size, config_.intermediate_size,\n                                         config_.quant_config.group_size)) != 0) {\n      printf(\"verify error\\n\");\n      for (size_t i = 0; i < T::BufferB::required_size(config_.hidden_size, config_.intermediate_size,\n                                                       config_.quant_config.group_size);\n           ++i) {\n        if (verify_bb[i] != check_bb[i]) {\n          printf(\"Difference at byte %zu: verify_bb_%d[%zu] = %02x, check_bb[%zu] = %02x\\n\", i, compare_expers, i,\n                 (unsigned char)verify_bb[i], i, (unsigned char)check_bb[i]);\n          break;\n        }\n      }\n      assert(0);\n    } else {\n      printf(\"pass verify\\n\");\n      printf(\"numa %d, verify_bb_%d:\\n\", tp_part_idx, compare_expers);\n      size_t size =\n          T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, config_.quant_config.group_size);\n      size_t scale_size = config_.hidden_size * sizeof(float);\n      for (size_t i = size - scale_size; i < size - scale_size + 50; ++i) {\n        printf(\"%02x \", 
(unsigned char)verify_bb[i]);\n      }\n      printf(\"\\n\");\n    }\n  }\n#endif\n\n  // Function to dump Buffer B data for debugging quantization results\n  inline void dump_buffer_b(const std::string& quantization_type, int expert_idx, const std::string& matrix_type,\n                            typename T::BufferB* buffer) {\n    auto& quant_config = config_.quant_config;\n    int& group_size = quant_config.group_size;\n\n    printf(\"[DUMP_BUFFER_B] TP%d %s Expert%d %s:\\n\", tp_part_idx, quantization_type.c_str(), expert_idx,\n           matrix_type.c_str());\n\n    // Calculate dimensions based on matrix type\n    int rows, cols, num_groups;\n    size_t scale_elem_count;\n    if (matrix_type == \"gate\" || matrix_type == \"up\") {\n      rows = config_.intermediate_size;\n      cols = config_.hidden_size;\n      num_groups = cols / group_size;\n      scale_elem_count = num_groups * rows;\n    } else {  // down\n      rows = config_.hidden_size;\n      cols = config_.intermediate_size;\n      num_groups = cols / group_size;\n      scale_elem_count = num_groups * rows;\n    }\n\n    // Dump scales (as float)\n    printf(\"  Scales[first 16]: \");\n    for (int i = 0; i < std::min(16, (int)scale_elem_count); i++) {\n      printf(\"%.6f \", buffer->d[i]);\n    }\n    printf(\"\\n\");\n\n    if (scale_elem_count > 16) {\n      printf(\"  Scales[last 16]: \");\n      int start_idx = std::max(0, (int)scale_elem_count - 16);\n      for (int i = start_idx; i < (int)scale_elem_count; i++) {\n        printf(\"%.6f \", buffer->d[i]);\n      }\n      printf(\"\\n\");\n    }\n\n    // Dump mins (as float) if available\n    if (quant_config.zero_point && buffer->mins) {\n      printf(\"  Mins[first 16]: \");\n      for (int i = 0; i < std::min(16, (int)scale_elem_count); i++) {\n        printf(\"%.6f \", buffer->mins[i]);\n      }\n      printf(\"\\n\");\n\n      if (scale_elem_count > 16) {\n        printf(\"  Mins[last 16]: \");\n        int start_idx = std::max(0, 
(int)scale_elem_count - 16);\n        for (int i = start_idx; i < (int)scale_elem_count; i++) {\n          printf(\"%.6f \", buffer->mins[i]);\n        }\n        printf(\"\\n\");\n      }\n    }\n\n    // Dump quantized weights (as hex uint8)\n    size_t weight_size = (rows * cols) / 2;  // INT4 packed\n    uint8_t* weight_ptr = (uint8_t*)buffer->b;\n\n    printf(\"  Weights[first 32 bytes]: \");\n    for (int i = 0; i < std::min(32, (int)weight_size); i++) {\n      printf(\"%02x \", weight_ptr[i]);\n    }\n    printf(\"\\n\");\n\n    if (weight_size > 32) {\n      printf(\"  Weights[last 32 bytes]: \");\n      int start_idx = std::max(32, (int)weight_size - 32);\n      for (int i = start_idx; i < (int)weight_size; i++) {\n        printf(\"%02x \", weight_ptr[i]);\n      }\n      printf(\"\\n\");\n    }\n\n    printf(\"  Matrix dimensions: %dx%d, Groups: %d, Group size: %d, Scale elements: %zu\\n\", rows, cols, num_groups,\n           group_size, scale_elem_count);\n    printf(\"\\n\");\n  }\n\n  // Convert packed INT4 zeros to float mins (scalar loop, despite the _avx suffix)\n  // mins = -(zeros * scales), element-wise; scales are FP32.\n  // num_elements must be a multiple of 8: each uint32_t packs 8 zeros, low nibble first.\n  inline void convert_zeros_to_mins_avx(const uint32_t* zeros_int4_packed, const float* scales, float* mins,\n                                        size_t num_elements) {\n    constexpr size_t simd_width = 8;  // decode 8 int4 values per iteration\n\n    for (size_t i = 0; i < num_elements; i += simd_width) {\n      uint32_t packed_vals = zeros_int4_packed[i / 8];\n\n      for (int j = 0; j < 8; j++) {\n        int v = packed_vals & 0xF;  // extract the low 4 bits\n        mins[i + j] = -(scales[i + j] * v);\n        packed_vals = packed_vals >> 4;\n      }\n    }\n  }\n\n public:\n  using typename Base::input_t;\n  using typename Base::output_t;\n\n  AMX_AWQ_MOE_TP() = default;\n\n  AMX_AWQ_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) {}\n\n  void derived_init() {\n    auto& quant_config = config_.quant_config;\n    if 
(quant_config.group_size == 0 || !quant_config.zero_point) {\n      throw std::runtime_error(\"AWQ-Quantization AMX MoE only supports KGroup Int4_1\");\n    }\n\n    printf(\"Creating AMX_AWQ_MOE_TP %d at numa %d\\n\", tp_part_idx, numa_node_of_cpu(sched_getcpu()));\n\n    auto& load = config_.load;\n    auto& save = config_.save;\n\n    std::filesystem::path prefix = config_.path;\n    prefix = prefix / (\"_layer_\" + std::to_string(config_.layer_idx)) / (\"_numa_\" + std::to_string(tp_part_idx));\n    if (save) {\n      std::cout << \"Creating \" << prefix << std::endl;\n      std::filesystem::create_directories(prefix);\n    }\n    if (load) {\n      if (std::filesystem::exists(prefix)) {\n        std::cout << \"Loading from \" << prefix << std::endl;\n      } else {\n        throw std::runtime_error(\"Path not found: \" + prefix.string());\n      }\n    }\n  }\n\n  ~AMX_AWQ_MOE_TP() = default;\n\n  // ============================================================================\n  // CRTP buffer creation - with group_size (AWQ uses zero-point)\n  // ============================================================================\n\n  size_t buffer_a_required_size_impl(size_t m, size_t k) const {\n    return T::BufferA::required_size(m, k, config_.quant_config.group_size);\n  }\n  size_t buffer_b_required_size_impl(size_t n, size_t k) const {\n    return T::BufferB::required_size(n, k, config_.quant_config.group_size);\n  }\n  size_t buffer_c_required_size_impl(size_t m, size_t n) const { return T::BufferC::required_size(m, n); }\n\n  std::shared_ptr<typename T::BufferA> make_buffer_a_impl(size_t m, size_t k, void* data) const {\n    return std::make_shared<typename T::BufferA>(m, k, config_.quant_config.group_size, data);\n  }\n  std::shared_ptr<typename T::BufferB> make_buffer_b_impl(size_t n, size_t k, void* data) const {\n    return std::make_shared<typename T::BufferB>(n, k, config_.quant_config.group_size, data);\n  }\n  std::shared_ptr<typename T::BufferC> 
make_buffer_c_impl(size_t m, size_t n, void* data) const {\n    return std::make_shared<typename T::BufferC>(m, n, data);\n  }\n\n  // ============================================================================\n  // CRTP virtual points - GEMM dispatch (uses kgroup with zeros)\n  // ============================================================================\n\n  void do_gate_up_gemm(bool do_up, int expert_idx, int ith, int nth, int qlen) {\n    auto& group_size = config_.quant_config.group_size;\n    int m = m_local_num_[expert_idx];\n    auto& ba = gate_up_ba_[expert_idx];\n    auto& bb = do_up ? up_bb_[expert_idx] : gate_bb_[expert_idx];\n    auto& bc = do_up ? up_bc_[expert_idx] : gate_bc_[expert_idx];\n\n    // Dispatch based on qlen threshold\n    if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) {\n      amx::mat_mul_kgroup(m, config_.intermediate_size, config_.hidden_size, group_size, ba, bb, bc, ith, nth);\n    } else {\n      amx::vec_mul_kgroup(m, config_.intermediate_size, config_.hidden_size, group_size, ba, bb, bc, ith, nth);\n    }\n  }\n\n  void do_down_gemm(int expert_idx, int ith, int nth, int qlen) {\n    auto& group_size = config_.quant_config.group_size;\n    int m = m_local_num_[expert_idx];\n\n    if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) {\n      amx::mat_mul_kgroup(m, config_.hidden_size, config_.intermediate_size, group_size, down_ba_[expert_idx],\n                          down_bb_[expert_idx], down_bc_[expert_idx], ith, nth);\n    } else {\n      amx::vec_mul_kgroup(m, config_.hidden_size, config_.intermediate_size, group_size, down_ba_[expert_idx],\n                          down_bb_[expert_idx], down_bc_[expert_idx], ith, nth);\n    }\n  }\n\n  /**\n   * @brief Load Int4 weights with scales and zero-points\n   *\n   * AWQ weights include:\n   * - Quantized INT4 weights\n   * - FP16 scales (converted to FP32)\n   * - INT4 zeros (converted to FP32 mins = -scale * zero)\n   */\n  void load_weights() 
{\n    auto& quant_config = config_.quant_config;\n    int& group_size = quant_config.group_size;\n    const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map;\n    if (quant_config.group_size == 0 || !quant_config.zero_point) {\n      throw std::runtime_error(\"AWQ-Quantization AMX MoE only supports KGroup Int4_1\");\n    }\n\n    auto pool = config_.pool->get_subpool(tp_part_idx);\n    if (config_.gate_projs.size()) {\n      throw std::runtime_error(\"AMX load weights from gate_projs is not supported\");\n    } else {\n      int nth = T::recommended_nth(config_.intermediate_size);\n      std::filesystem::path prefix = config_.path;\n      prefix = prefix / (\"_layer_\" + std::to_string(config_.layer_idx)) / (\"_numa_\" + std::to_string(tp_part_idx));\n\n      if (config_.load) {\n        throw std::runtime_error(\"AMX load weights from file is not supported\");\n      }\n#ifdef CHECK\n      load_check();\n#endif\n#ifndef CHECK\n      else if (config_.gate_scale != nullptr)\n#endif\n      {\n        // Loading quantized weights with scales and zeros\n        pool->do_work_stealing_job(\n            nth * config_.expert_num, nullptr,\n            [this, nth, physical_to_logical_map](int task_id) {\n              uint64_t expert_idx = task_id / nth;\n              uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);\n              int ith = task_id % nth;\n              // gate part\n              gate_bb_[expert_idx]->from_raw_mat(\n                  (uint8_t*)config_.gate_proj +\n                      ((logical_expert_id * config_.intermediate_size * config_.hidden_size) >> 1),\n                  ith, nth);\n              // up part\n              up_bb_[expert_idx]->from_raw_mat(\n                  (uint8_t*)config_.up_proj +\n                      ((logical_expert_id * config_.intermediate_size * config_.hidden_size) >> 1),\n                  ith, nth);\n            },\n            nullptr);\n\n        
nth = T::recommended_nth(config_.hidden_size);\n        pool->do_work_stealing_job(\n            nth * config_.expert_num, nullptr,\n            [this, nth, physical_to_logical_map](int task_id) {\n              uint64_t expert_idx = task_id / nth;\n              uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);\n              int ith = task_id % nth;\n              // down part\n              down_bb_[expert_idx]->from_raw_mat(\n                  (uint8_t*)config_.down_proj +\n                      ((logical_expert_id * config_.hidden_size * config_.intermediate_size) >> 1),\n                  ith, nth);\n            },\n            nullptr);\n\n        pool->do_work_stealing_job(\n            config_.expert_num, nullptr,\n            [this, physical_to_logical_map](int task_id) {\n              uint64_t expert_idx = task_id;\n              uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);\n              size_t scale_elem_count =\n                  (config_.hidden_size * config_.intermediate_size) / config_.quant_config.group_size;\n\n              // convert scales from FP16 to FP32\n              convert_or_copy(gate_bb_[expert_idx]->d,\n                              (ggml_fp16_t*)config_.gate_scale + (logical_expert_id * scale_elem_count),\n                              scale_elem_count);\n              convert_or_copy(up_bb_[expert_idx]->d,\n                              (ggml_fp16_t*)config_.up_scale + (logical_expert_id * scale_elem_count),\n                              scale_elem_count);\n              convert_or_copy(down_bb_[expert_idx]->d,\n                              (ggml_fp16_t*)config_.down_scale + (logical_expert_id * scale_elem_count),\n                              scale_elem_count);\n\n              // Convert INT4 zeros to FP32 mins: mins = -(scale * zero)\n              convert_zeros_to_mins_avx(\n                  (const uint32_t*)((uint8_t*)config_.gate_zero + ((logical_expert_id * 
scale_elem_count) >> 1)),\n                  gate_bb_[expert_idx]->d, gate_bb_[expert_idx]->mins, scale_elem_count);\n              convert_zeros_to_mins_avx(\n                  (const uint32_t*)((uint8_t*)config_.up_zero + ((logical_expert_id * scale_elem_count) >> 1)),\n                  up_bb_[expert_idx]->d, up_bb_[expert_idx]->mins, scale_elem_count);\n              convert_zeros_to_mins_avx(\n                  (const uint32_t*)((uint8_t*)config_.down_zero + ((logical_expert_id * scale_elem_count) >> 1)),\n                  down_bb_[expert_idx]->d, down_bb_[expert_idx]->mins, scale_elem_count);\n            },\n            nullptr);\n\n        // Save offline quantization data if requested\n        if (config_.save) {\n          for (int expert_idx = 0; expert_idx < config_.expert_num; expert_idx++) {\n            write_weights(prefix, \"_gate_\", gate_bb_[expert_idx].get(), expert_idx, \"OFFLINE\");\n            write_weights(prefix, \"_up_\", up_bb_[expert_idx].get(), expert_idx, \"OFFLINE\");\n            write_weights(prefix, \"_down_\", down_bb_[expert_idx].get(), expert_idx, \"OFFLINE\");\n          }\n        }\n      }\n      else {\n        // Online Quantization from BF16\n        assert(config_.gate_proj != nullptr);\n\n        pool->do_work_stealing_job(\n            nth * config_.expert_num, nullptr,\n            [this, nth, physical_to_logical_map](int task_id) {\n              int64_t expert_idx = task_id / nth;\n              uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);\n              int ith = task_id % nth;\n              // gate part\n              gate_bb_[logical_expert_id]->from_mat(\n                  (ggml_bf16_t*)config_.gate_proj +\n                      (logical_expert_id * config_.intermediate_size * config_.hidden_size),\n                  ith, nth);\n              // up part\n              up_bb_[logical_expert_id]->from_mat(\n                  (ggml_bf16_t*)config_.up_proj + (logical_expert_id * 
config_.intermediate_size * config_.hidden_size),\n                  ith, nth);\n            },\n            nullptr);\n\n        nth = T::recommended_nth(config_.hidden_size);\n        pool->do_work_stealing_job(\n            nth * config_.expert_num, nullptr,\n            [this, nth, physical_to_logical_map](int task_id) {\n              int64_t expert_idx = task_id / nth;\n              uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);\n              int ith = task_id % nth;\n              // down part\n              down_bb_[logical_expert_id]->from_mat(\n                  (ggml_bf16_t*)config_.down_proj +\n                      (logical_expert_id * config_.hidden_size * config_.intermediate_size),\n                  ith, nth);\n            },\n            nullptr);\n\n        // Save online quantization data if requested\n        if (config_.save) {\n          for (int expert_idx = 0; expert_idx < config_.expert_num; expert_idx++) {\n            write_weights(prefix, \"_gate_\", gate_bb_[expert_idx].get(), expert_idx, \"ONLINE\");\n            write_weights(prefix, \"_up_\", up_bb_[expert_idx].get(), expert_idx, \"ONLINE\");\n            write_weights(prefix, \"_down_\", down_bb_[expert_idx].get(), expert_idx, \"ONLINE\");\n          }\n        }\n      }\n#ifdef CHECK\n      verify_load_right();\n#endif\n    }\n  }\n\n  // forward, forward_prefill, forward_decode, warm_up are inherited from Base\n};\n\n// ============================================================================\n// TP_MOE specialization for AMX_AWQ_MOE_TP\n// Inherits from TP_MOE<AMX_MOE_BASE<...>> to reuse merge_results implementation\n// ============================================================================\n\ntemplate <typename K>\nclass TP_MOE<AMX_AWQ_MOE_TP<K>> : public TP_MOE<AMX_MOE_BASE<K, AMX_AWQ_MOE_TP<K>>> {\n public:\n  using Base = TP_MOE<AMX_MOE_BASE<K, AMX_AWQ_MOE_TP<K>>>;\n  using Base::Base;\n\n  void load_weights() override {\n    auto& 
config = this->config;\n    auto& tps = this->tps;\n    auto& tp_count = this->tp_count;\n    auto pool = config.pool;\n    const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;\n    if (config.gate_projs.empty() == false) {\n      printf(\"TP Load from loader\\n\");\n      DO_TPS_LOAD_WEIGHTS(pool);\n      this->weights_loaded = true;\n    } else if (config.gate_scale != nullptr) {\n      printf(\"From Packed Int4 with KGroup Scale and Zeros\\n\");\n      int& group_size = config.quant_config.group_size;\n      for (auto i = 0; i < tp_count; i++) {\n        auto& tpc = tps[i]->config_;\n        size_t weight_elem_count = tpc.intermediate_size * tpc.hidden_size;\n        tpc.gate_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];\n        tpc.up_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];\n        tpc.down_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];\n\n        size_t scales_elem_count = (tpc.hidden_size / group_size) * tpc.intermediate_size;\n\n        tpc.gate_scale = new ggml_fp16_t[(tpc.expert_num * scales_elem_count)];\n        tpc.up_scale = new ggml_fp16_t[(tpc.expert_num * scales_elem_count)];\n        tpc.down_scale = new ggml_fp16_t[(tpc.expert_num * scales_elem_count)];\n\n        tpc.gate_zero = new uint8_t[(tpc.expert_num * scales_elem_count) / 2];\n        tpc.up_zero = new uint8_t[(tpc.expert_num * scales_elem_count) / 2];\n        tpc.down_zero = new uint8_t[(tpc.expert_num * scales_elem_count) / 2];\n        if (tps[i]->config_.load == false) {\n          pool->get_subpool(i)->do_work_stealing_job(\n              tpc.expert_num, nullptr,\n              [&](int expert_id_) {\n                size_t expert_id = expert_map(physical_to_logical_map, expert_id_);\n\n                // weight TP-slicing\n                memcpy((uint8_t*)tpc.gate_proj + ((expert_id * weight_elem_count) >> 1),\n                       (uint8_t*)config.gate_proj +\n                           
((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1),\n                       ((sizeof(uint8_t) * weight_elem_count) >> 1));\n\n                memcpy((uint8_t*)tpc.up_proj + ((expert_id * weight_elem_count) >> 1),\n                       (uint8_t*)config.up_proj +\n                           ((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1),\n                       ((sizeof(uint8_t) * weight_elem_count) >> 1));\n\n                // down scales and zeros TP-slicing\n                memcpy((ggml_fp16_t*)tpc.down_scale + (expert_id * scales_elem_count),\n                       (ggml_fp16_t*)config.down_scale +\n                           (expert_id * (config.intermediate_size / group_size) * config.hidden_size +\n                            i * scales_elem_count),\n                       sizeof(ggml_fp16_t) * scales_elem_count);\n\n                memcpy((uint8_t*)tpc.down_zero + ((expert_id * scales_elem_count) >> 1),\n                       (uint8_t*)config.down_zero +\n                           ((expert_id * (config.intermediate_size / group_size) * config.hidden_size +\n                             i * scales_elem_count) >>\n                            1),\n                       (sizeof(uint8_t) * scales_elem_count) >> 1);\n\n                for (size_t kg = 0; kg < config.hidden_size / group_size; kg++) {\n                  // copy gate/up scales\n                  memcpy((ggml_fp16_t*)tpc.gate_scale + (expert_id * scales_elem_count) + kg * tpc.intermediate_size,\n                         (ggml_fp16_t*)config.gate_scale +\n                             (expert_id * ((config.hidden_size / group_size) * config.intermediate_size) +\n                              kg * config.intermediate_size + i * tpc.intermediate_size),\n                         (sizeof(ggml_fp16_t) * tpc.intermediate_size));\n\n                  memcpy((ggml_fp16_t*)tpc.up_scale + (expert_id * scales_elem_count) + kg 
* tpc.intermediate_size,\n                         (ggml_fp16_t*)config.up_scale +\n                             (expert_id * ((config.hidden_size / group_size) * config.intermediate_size) +\n                              kg * config.intermediate_size + i * tpc.intermediate_size),\n                         (sizeof(ggml_fp16_t) * tpc.intermediate_size));\n\n                  // copy gate/up zeros TP-slicing\n                  memcpy(\n                      (uint8_t*)tpc.gate_zero + (((expert_id * scales_elem_count) + kg * tpc.intermediate_size) >> 1),\n                      (uint8_t*)config.gate_zero +\n                          ((expert_id * ((config.hidden_size / group_size) * config.intermediate_size) +\n                            kg * config.intermediate_size + i * tpc.intermediate_size) >>\n                           1),\n                      ((sizeof(uint8_t) * tpc.intermediate_size) >> 1));\n\n                  memcpy((uint8_t*)tpc.up_zero + (((expert_id * scales_elem_count) + kg * tpc.intermediate_size) >> 1),\n                         (uint8_t*)config.up_zero +\n                             ((expert_id * ((config.hidden_size / group_size) * config.intermediate_size) +\n                               kg * config.intermediate_size + i * tpc.intermediate_size) >>\n                              1),\n                         ((sizeof(uint8_t) * tpc.intermediate_size) >> 1));\n                }\n\n                // down weights TP-slicing (column-wise)\n                for (size_t col = 0; col < config.hidden_size; col++) {\n                  memcpy((uint8_t*)tpc.down_proj + ((expert_id * weight_elem_count + col * tpc.intermediate_size) >> 1),\n                         (uint8_t*)config.down_proj + ((expert_id * config.intermediate_size * config.hidden_size +\n                                                        col * config.intermediate_size + i * tpc.intermediate_size) >>\n                                                       1),\n                         
(sizeof(uint8_t) * tpc.intermediate_size) >> 1);\n                }\n              },\n              nullptr);\n        }\n      }\n\n      DO_TPS_LOAD_WEIGHTS(pool);\n\n      for (auto i = 0; i < tp_count; i++) {\n        auto& tpc = tps[i]->config_;\n        delete[] (uint8_t*)(tpc.gate_proj);\n        delete[] (uint8_t*)(tpc.up_proj);\n        delete[] (uint8_t*)(tpc.down_proj);\n\n        delete[] (ggml_fp16_t*)(tpc.gate_scale);\n        delete[] (ggml_fp16_t*)(tpc.up_scale);\n        delete[] (ggml_fp16_t*)(tpc.down_scale);\n\n        delete[] (uint8_t*)(tpc.gate_zero);\n        delete[] (uint8_t*)(tpc.up_zero);\n        delete[] (uint8_t*)(tpc.down_zero);\n      }\n\n      this->weights_loaded = true;\n    } else if (config.gate_proj != nullptr) {\n      printf(\"From BF16 Online Quantization.\\n\");\n      fflush(stdout);\n      for (auto i = 0; i < tp_count; i++) {\n        auto& tpc = tps[i]->config_;\n        size_t gate_up_elcount = tpc.intermediate_size * tpc.hidden_size;\n        tpc.gate_proj = new ggml_bf16_t[tpc.expert_num * gate_up_elcount];\n        tpc.up_proj = new ggml_bf16_t[tpc.expert_num * gate_up_elcount];\n        tpc.down_proj = new ggml_bf16_t[tpc.expert_num * gate_up_elcount];\n        if (tps[i]->config_.load == false) {\n          pool->get_subpool(i)->do_work_stealing_job(\n              tpc.expert_num, nullptr,\n              [&](int expert_id_) {\n                size_t expert_id = expert_map(physical_to_logical_map, expert_id_);\n                memcpy((ggml_bf16_t*)tpc.gate_proj + expert_id * gate_up_elcount,\n                       (ggml_bf16_t*)config.gate_proj + expert_id * config.intermediate_size * config.hidden_size +\n                           i * gate_up_elcount,\n                       sizeof(ggml_bf16_t) * gate_up_elcount);\n                memcpy((ggml_bf16_t*)tpc.up_proj + expert_id * gate_up_elcount,\n                       (ggml_bf16_t*)config.up_proj + expert_id * config.intermediate_size * config.hidden_size +\n  
                         i * gate_up_elcount,\n                       sizeof(ggml_bf16_t) * gate_up_elcount);\n                // down weights TP-slicing: for each of the hidden_size rows, copy the\n                // intermediate_size-wide slice owned by TP part i\n                for (size_t col = 0; col < config.hidden_size; col++) {\n                  memcpy((ggml_bf16_t*)tpc.down_proj + expert_id * tpc.hidden_size * tpc.intermediate_size +\n                             col * tpc.intermediate_size,\n                         (ggml_bf16_t*)config.down_proj + expert_id * config.intermediate_size * config.hidden_size +\n                             col * config.intermediate_size + i * tpc.intermediate_size,\n                         sizeof(ggml_bf16_t) * tpc.intermediate_size);\n                }\n              },\n              nullptr);\n        }\n      }\n\n      DO_TPS_LOAD_WEIGHTS(pool);\n\n      for (auto i = 0; i < tp_count; i++) {\n        auto& tpc = tps[i]->config_;\n        delete[] (ggml_bf16_t*)(tpc.gate_proj);\n        delete[] (ggml_bf16_t*)(tpc.up_proj);\n        delete[] (ggml_bf16_t*)(tpc.down_proj);\n      }\n\n      this->weights_loaded = true;\n    } else if (config.path != \"\") {\n      printf(\"TP Load from file\\n\");\n      DO_TPS_LOAD_WEIGHTS(pool);\n      this->weights_loaded = true;\n    } else {\n      throw std::runtime_error(\"no weight source\");\n    }\n  }\n\n  // merge_results is inherited from TP_MOE<AMX_MOE_BASE<K, AMX_AWQ_MOE_TP<K>>>\n};\n\n#endif\n"
  },
  {
    "path": "kt-kernel/operators/amx/bf16-moe.hpp",
    "content": "/**\n * @Description  : BF16 AMX MoE operator for native BF16 inference\n * @Author       : oql, Codex and Claude\n * @Date         : 2026-01-06\n * @Version      : 1.0.0\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n *\n * This file implements BF16 MoE using CRTP pattern, inheriting from moe_base.hpp.\n * BF16 weights are stored without quantization (no scales).\n **/\n#ifndef CPUINFER_OPERATOR_AMX_BF16_MOE_H\n#define CPUINFER_OPERATOR_AMX_BF16_MOE_H\n\n// #define DEBUG_BF16_MOE\n\n#include \"la/amx_kernels.hpp\"  // For vec_mul/mat_mul\n#include \"la/amx_raw_buffers.hpp\"\n#include \"la/amx_raw_kernels.hpp\"\n#include \"la/amx_utils.hpp\"  // For transpose_16x16_32bit\n#include \"moe_base.hpp\"\n\n/**\n * @brief BF16 MoE operator using CRTP pattern\n * @tparam T Kernel type, defaults to GemmKernel224BF16\n *\n * This class provides BF16-specific implementations:\n * - do_gate_up_gemm, do_down_gemm: BF16 weight mat mul (no quantization)\n * - load_weights: Load native BF16 weights (no scales)\n */\ntemplate <class T = amx::GemmKernel224BF16>\nclass AMX_BF16_MOE_TP : public AMX_MOE_BASE<T, AMX_BF16_MOE_TP<T>> {\n  using Base = AMX_MOE_BASE<T, AMX_BF16_MOE_TP<T>>;\n  using Base::config_;\n  using Base::down_ba_;\n  using Base::down_bb_;\n  using Base::down_bc_;\n  using Base::gate_bb_;\n  using Base::gate_bc_;\n  using Base::gate_up_ba_;\n  using Base::m_local_num_;\n  using Base::tp_part_idx;\n  using Base::up_bb_;\n  using Base::up_bc_;\n\n public:\n  using typename Base::input_t;\n  using typename Base::output_t;\n\n  AMX_BF16_MOE_TP() = default;\n\n  AMX_BF16_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) {\n    // Initialization now happens in derived_init() which is called by base constructor\n  }\n\n  void derived_init() {\n    // BF16 has no quantization, no need to check quant_config\n    printf(\"Created AMX_BF16_MOE_TP %d at numa %d\\n\", tp_part_idx, numa_node_of_cpu(sched_getcpu()));\n  
}\n\n  ~AMX_BF16_MOE_TP() = default;\n\n  // ============================================================================\n  // CRTP buffer creation - without group_size\n  // ============================================================================\n\n  size_t buffer_a_required_size_impl(size_t m, size_t k) const { return T::BufferA::required_size(m, k); }\n\n  size_t buffer_b_required_size_impl(size_t n, size_t k) const {\n    return T::BufferB::required_size(n, k);  // 2 parameters - no group_size\n  }\n\n  size_t buffer_c_required_size_impl(size_t m, size_t n) const { return T::BufferC::required_size(m, n); }\n\n  std::shared_ptr<typename T::BufferA> make_buffer_a_impl(size_t m, size_t k, void* data) const {\n    return std::make_shared<typename T::BufferA>(m, k, data);\n  }\n\n  std::shared_ptr<typename T::BufferB> make_buffer_b_impl(size_t n, size_t k, void* data) const {\n    return std::make_shared<typename T::BufferB>(n, k, data);  // 2 parameters - no group_size\n  }\n\n  std::shared_ptr<typename T::BufferC> make_buffer_c_impl(size_t m, size_t n, void* data) const {\n    return std::make_shared<typename T::BufferC>(m, n, data);\n  }\n\n  // ============================================================================\n  // CRTP virtual points - GEMM dispatch\n  // ============================================================================\n\n  void do_gate_up_gemm(bool do_up, int expert_idx, int ith, int nth, int qlen) {\n    int m = m_local_num_[expert_idx];\n    auto& ba = gate_up_ba_[expert_idx];\n    auto& bb = do_up ? up_bb_[expert_idx] : gate_bb_[expert_idx];\n    auto& bc = do_up ? 
up_bc_[expert_idx] : gate_bc_[expert_idx];\n\n    // Use vec_mul/mat_mul (no group_size)\n    if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) {\n      amx::mat_mul(m, config_.intermediate_size, config_.hidden_size, ba, bb, bc, ith, nth);\n    } else {\n      amx::vec_mul(m, config_.intermediate_size, config_.hidden_size, ba, bb, bc, ith, nth);\n    }\n  }\n\n  void do_down_gemm(int expert_idx, int ith, int nth, int qlen) {\n    int m = m_local_num_[expert_idx];\n\n    if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) {\n      amx::mat_mul(m, config_.hidden_size, config_.intermediate_size, down_ba_[expert_idx], down_bb_[expert_idx],\n                   down_bc_[expert_idx], ith, nth);\n    } else {\n      amx::vec_mul(m, config_.hidden_size, config_.intermediate_size, down_ba_[expert_idx], down_bb_[expert_idx],\n                   down_bc_[expert_idx], ith, nth);\n    }\n  }\n\n#ifdef DEBUG_BF16_MOE\n  // Function to dump Buffer B data for debugging\n  inline void dump_buffer_b(int expert_idx, const std::string& matrix_type, typename T::BufferB* buffer) {\n    printf(\"[DUMP_BUFFER_B] TP%d BF16 Expert%d %s:\\n\", tp_part_idx, expert_idx, matrix_type.c_str());\n\n    // Calculate dimensions based on matrix type\n    int rows, cols;\n    if (matrix_type == \"gate\" || matrix_type == \"up\") {\n      rows = config_.intermediate_size;\n      cols = config_.hidden_size;\n    } else {  // down\n      rows = config_.hidden_size;\n      cols = config_.intermediate_size;\n    }\n\n    // Dump BF16 weights\n    size_t weight_size = (size_t)rows * cols;\n    ggml_bf16_t* weight_ptr = buffer->b;\n\n    printf(\"  BF16 Weights[first 16]: \");\n    for (int i = 0; i < std::min(16, (int)weight_size); i++) {\n      printf(\"%.6f \", ggml_bf16_to_fp32(weight_ptr[i]));\n    }\n    printf(\"\\n\");\n\n    if (weight_size > 16) {\n      printf(\"  BF16 Weights[last 16]: \");\n      int start_idx = std::max(0, (int)weight_size - 16);\n      for (int i = 
start_idx; i < (int)weight_size; i++) {\n        printf(\"%.6f \", ggml_bf16_to_fp32(weight_ptr[i]));\n      }\n      printf(\"\\n\");\n    }\n\n    printf(\"  Matrix dimensions: %dx%d (n x k)\\n\", rows, cols);\n  }\n#endif\n\n  /**\n   * @brief Load BF16 weights from contiguous memory layout\n   *\n   * Loads weights from config_.gate_proj, up_proj, down_proj (no scales).\n   */\n  void load_weights() {\n    const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map;\n    auto pool = config_.pool->get_subpool(tp_part_idx);\n\n    if (config_.gate_proj == nullptr) {\n      throw std::runtime_error(\"BF16 MOE requires native BF16 weight.\");\n    }\n\n    // Load gate + up weights\n    int nth = T::recommended_nth(config_.intermediate_size);\n    pool->do_work_stealing_job(\n        nth * config_.expert_num, nullptr,\n        [this, nth, physical_to_logical_map](int task_id) {\n          uint64_t expert_idx = task_id / nth;\n          uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);\n          int ith = task_id % nth;\n\n          // Gate: from BF16 data (no scale)\n          gate_bb_[expert_idx]->from_mat(\n              (ggml_bf16_t*)config_.gate_proj + (logical_expert_id * config_.intermediate_size * config_.hidden_size),\n              ith, nth);  // 3 parameters: (bf16*, ith, nth)\n\n          // Up: same\n          up_bb_[expert_idx]->from_mat(\n              (ggml_bf16_t*)config_.up_proj + (logical_expert_id * config_.intermediate_size * config_.hidden_size),\n              ith, nth);\n        },\n        nullptr);\n\n    // Load down weights\n    nth = T::recommended_nth(config_.hidden_size);\n    pool->do_work_stealing_job(\n        nth * config_.expert_num, nullptr,\n        [this, nth, physical_to_logical_map](int task_id) {\n          uint64_t expert_idx = task_id / nth;\n          uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);\n          int ith = task_id % 
nth;\n\n          // Down\n          down_bb_[expert_idx]->from_mat(\n              (ggml_bf16_t*)config_.down_proj + (logical_expert_id * config_.intermediate_size * config_.hidden_size),\n              ith, nth);\n        },\n        nullptr);\n\n#ifdef DEBUG_BF16_MOE\n    dump_buffer_b(0, \"gate\", gate_bb_[0].get());\n    dump_buffer_b(0, \"down\", down_bb_[0].get());\n#endif\n  }\n\n  // Fast 64-byte (512-bit) memcpy using AVX512\n  static inline void fast_memcpy_64(void* __restrict dst, const void* __restrict src) {\n    __m512i data = _mm512_loadu_si512(src);\n    _mm512_storeu_si512(dst, data);\n  }\n\n  // Fast 64-byte non-temporal store (bypass cache for write-only patterns)\n  static inline void fast_stream_64(void* __restrict dst, const void* __restrict src) {\n    __m512i data = _mm512_loadu_si512(src);\n    _mm512_stream_si512((__m512i*)dst, data);\n  }\n\n  // Fast memcpy for arbitrary sizes using AVX512\n  static inline void fast_memcpy(void* __restrict dst, const void* __restrict src, size_t bytes) {\n    uint8_t* d = (uint8_t*)dst;\n    const uint8_t* s = (const uint8_t*)src;\n    size_t chunks = bytes / 64;\n    for (size_t i = 0; i < chunks; i++) {\n      fast_memcpy_64(d, s);\n      d += 64;\n      s += 64;\n    }\n    bytes -= chunks * 64;\n    if (bytes > 0) {\n      std::memcpy(d, s, bytes);\n    }\n  }\n\n  /**\n   * @brief Unpack a single N_STEP x K_STEP block from packed BufferB format to n-major format (BF16 version)\n   *\n   * This is the inverse of the packing done in BufferBBF16Impl::from_mat.\n   * BF16 elements are 2 bytes, and the packed format includes 16x16 32-bit transpose.\n   *\n   * @param src Pointer to packed data (N_STEP * K_STEP * 2 bytes in packed layout)\n   * @param dst Pointer to destination in n-major layout\n   * @param dst_row_stride Row stride in destination buffer (number of BF16 elements per row)\n   */\n  static inline void unpack_nk_block_bf16(const ggml_bf16_t* src, ggml_bf16_t* dst, size_t dst_row_stride) 
{\n    constexpr int N_STEP = T::N_STEP;  // 32\n    constexpr int K_STEP = T::K_STEP;  // 32\n    constexpr int TILE_N = T::TILE_N;  // 16\n\n    // The packed format has two 16x16 blocks (32-bit view) that were transposed\n    // We need to reverse the transpose first, then copy to n-major layout\n\n    // Create aligned temporary buffers for transpose\n    alignas(64) __m512i temp_block1[TILE_N];\n    alignas(64) __m512i temp_block2[TILE_N];\n\n    // Copy source data to temporary buffers\n    const __m512i* src_vec = reinterpret_cast<const __m512i*>(src);\n    for (int i = 0; i < TILE_N; i++) {\n      temp_block1[i] = src_vec[i];\n      temp_block2[i] = src_vec[TILE_N + i];\n    }\n\n    // Reverse transpose (transpose is self-inverse)\n    amx::transpose_16x16_32bit(temp_block1);\n    amx::transpose_16x16_32bit(temp_block2);\n\n    // Copy transposed data to destination in n-major layout using non-temporal stores\n    // First 16 rows (block 1)\n    for (int i = 0; i < TILE_N; i++) {\n      fast_stream_64(dst + i * dst_row_stride, &temp_block1[i]);\n    }\n\n    // Next 16 rows (block 2)\n    for (int i = 0; i < TILE_N; i++) {\n      fast_stream_64(dst + (TILE_N + i) * dst_row_stride, &temp_block2[i]);\n    }\n\n    // Ensure all stores complete before returning\n    _mm_sfence();\n  }\n\n  /**\n   * @brief Reconstruct weights for a single expert to the output buffers\n   *\n   * Directly unpacks from packed BufferB format to n-major GPU buffers without intermediate storage.\n   * BF16 version - no scales needed.\n   *\n   * @param gpu_tp_count Number of GPU TP parts (1, 2, 4, or 8)\n   * @param cpu_tp_count Number of CPU TP parts\n   * @param expert_id Expert index to process\n   * @param full_config Full configuration (before CPU TP split)\n   * @param w13_weight_ptrs Pointers to gate+up weight buffers (one per GPU TP)\n   * @param w13_scale_ptrs Pointers to gate+up scale buffers (unused for BF16, kept for interface compatibility)\n   * @param w2_weight_ptrs 
Pointers to down weight buffers (one per GPU TP)\n   * @param w2_scale_ptrs Pointers to down scale buffers (unused for BF16, kept for interface compatibility)\n   */\n  void write_weights_to_buffer(int gpu_tp_count, [[maybe_unused]] int cpu_tp_count, int expert_id,\n                               const GeneralMOEConfig& full_config, const std::vector<uintptr_t>& w13_weight_ptrs,\n                               [[maybe_unused]] const std::vector<uintptr_t>& w13_scale_ptrs,\n                               const std::vector<uintptr_t>& w2_weight_ptrs,\n                               [[maybe_unused]] const std::vector<uintptr_t>& w2_scale_ptrs) const {\n    auto& config = config_;\n    auto pool = config.pool->get_subpool(tp_part_idx);\n\n    constexpr int N_STEP = T::N_STEP;\n    constexpr int K_STEP = T::K_STEP;\n    constexpr int N_BLOCK = T::N_BLOCK;\n    constexpr int K_BLOCK = T::K_BLOCK;\n\n    // ========= W13 (gate+up): Shape [intermediate, hidden], split by N only =========\n    const int cpu_n_w13 = config.intermediate_size;\n    const int cpu_k_w13 = config.hidden_size;\n    const int gpu_n_w13 = full_config.intermediate_size / gpu_tp_count;\n    const int gpu_k_w13 = full_config.hidden_size;\n    const int global_n_offset_w13 = tp_part_idx * cpu_n_w13;\n\n    const size_t gpu_w13_weight_per_mat = (size_t)gpu_n_w13 * gpu_k_w13;\n\n    // ========= W2 (down): Shape [hidden, intermediate], split by K =========\n    const int cpu_n_w2 = config.hidden_size;\n    const int cpu_k_w2 = config.intermediate_size;\n    const int gpu_n_w2 = full_config.hidden_size;\n    const int gpu_k_w2 = full_config.intermediate_size / gpu_tp_count;\n    const int global_k_offset_w2 = tp_part_idx * cpu_k_w2;\n\n    // ========= Optimized job layout =========\n    constexpr int NUM_W13_TASKS = 32;  // Per matrix (gate or up), total 64 for w13\n    constexpr int NUM_W2_TASKS = 32;   // For down matrix\n\n    const int total_tasks = NUM_W13_TASKS * 2 + NUM_W2_TASKS;\n\n    // 
Calculate N_STEP blocks per task\n    const int w13_n_steps = div_up(cpu_n_w13, N_STEP);\n    const int w13_steps_per_task = div_up(w13_n_steps, NUM_W13_TASKS);\n    const int w2_n_steps = div_up(cpu_n_w2, N_STEP);\n    const int w2_steps_per_task = div_up(w2_n_steps, NUM_W2_TASKS);\n\n    pool->do_work_stealing_job(\n        total_tasks, nullptr,\n        [=, &w13_weight_ptrs, &w2_weight_ptrs, this](int task_id) {\n          if (task_id < NUM_W13_TASKS * 2) {\n            // ========= W13 weight task: process chunk of rows x full K =========\n            const bool is_up = task_id >= NUM_W13_TASKS;\n            const int chunk_idx = task_id % NUM_W13_TASKS;\n            const auto& bb = is_up ? up_bb_[expert_id] : gate_bb_[expert_id];\n\n            const int step_start = chunk_idx * w13_steps_per_task;\n            const int step_end = std::min(step_start + w13_steps_per_task, w13_n_steps);\n            if (step_start >= w13_n_steps) return;\n            const int chunk_n_start = step_start * N_STEP;\n            const int chunk_n_end = std::min(step_end * N_STEP, cpu_n_w13);\n\n            for (int local_n_start = chunk_n_start; local_n_start < chunk_n_end; local_n_start += N_STEP) {\n              const int global_n = global_n_offset_w13 + local_n_start;\n              const int target_gpu = global_n / gpu_n_w13;\n              const int n_in_gpu = global_n % gpu_n_w13;\n\n              ggml_bf16_t* weight_base = (ggml_bf16_t*)w13_weight_ptrs[target_gpu];\n              const size_t expert_weight_off = is_up ? 
gpu_w13_weight_per_mat : 0;\n\n              const int n_block_idx = local_n_start / N_BLOCK;\n              const int n_block_begin = n_block_idx * N_BLOCK;\n              const int n_block_size = std::min(N_BLOCK, cpu_n_w13 - n_block_begin);\n              const int n_in_block = local_n_start - n_block_begin;\n\n              for (int k_block_begin = 0; k_block_begin < cpu_k_w13; k_block_begin += K_BLOCK) {\n                const int k_block_size = std::min(K_BLOCK, cpu_k_w13 - k_block_begin);\n\n                for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {\n                  const ggml_bf16_t* src = bb->b + (size_t)n_block_begin * cpu_k_w13 +\n                                           (size_t)k_block_begin * n_block_size + (size_t)n_in_block * k_block_size +\n                                           (size_t)k_begin * N_STEP;\n                  ggml_bf16_t* dst =\n                      weight_base + expert_weight_off + (size_t)n_in_gpu * gpu_k_w13 + k_block_begin + k_begin;\n                  unpack_nk_block_bf16(src, dst, gpu_k_w13);\n                }\n              }\n            }\n\n          } else {\n            // ========= W2 weight task: process chunk of rows x all K slices =========\n            const int chunk_idx = task_id - NUM_W13_TASKS * 2;\n            const auto& bb = down_bb_[expert_id];\n\n            const int step_start = chunk_idx * w2_steps_per_task;\n            const int step_end = std::min(step_start + w2_steps_per_task, w2_n_steps);\n            if (step_start >= w2_n_steps) return;\n            const int chunk_n_start = step_start * N_STEP;\n            const int chunk_n_end = std::min(step_end * N_STEP, cpu_n_w2);\n\n            for (int local_n_start = chunk_n_start; local_n_start < chunk_n_end; local_n_start += N_STEP) {\n              const int n_block_idx = local_n_start / N_BLOCK;\n              const int n_block_begin = n_block_idx * N_BLOCK;\n              const int n_block_size = std::min(N_BLOCK, 
cpu_n_w2 - n_block_begin);\n              const int n_in_block = local_n_start - n_block_begin;\n\n              for (int k_slice_start = 0; k_slice_start < cpu_k_w2; k_slice_start += gpu_k_w2) {\n                const int k_slice_end = std::min(k_slice_start + gpu_k_w2, cpu_k_w2);\n\n                const int global_k_start = global_k_offset_w2 + k_slice_start;\n                const int target_gpu = global_k_start / gpu_k_w2;\n                const int k_in_gpu_base = global_k_start % gpu_k_w2;\n\n                ggml_bf16_t* weight_base = (ggml_bf16_t*)w2_weight_ptrs[target_gpu];\n\n                for (int k_abs = k_slice_start; k_abs < k_slice_end; k_abs += K_STEP) {\n                  const int k_block_idx = k_abs / K_BLOCK;\n                  const int k_block_begin = k_block_idx * K_BLOCK;\n                  const int k_block_size = std::min(K_BLOCK, cpu_k_w2 - k_block_begin);\n                  const int k_in_block = k_abs - k_block_begin;\n                  const int k_in_gpu = k_in_gpu_base + (k_abs - k_slice_start);\n\n                  const ggml_bf16_t* src = bb->b + (size_t)n_block_begin * cpu_k_w2 +\n                                           (size_t)k_block_begin * n_block_size + (size_t)n_in_block * k_block_size +\n                                           (size_t)k_in_block * N_STEP;\n                  ggml_bf16_t* dst = weight_base + (size_t)local_n_start * gpu_k_w2 + k_in_gpu;\n                  unpack_nk_block_bf16(src, dst, gpu_k_w2);\n                }\n              }\n            }\n          }\n        },\n        nullptr);\n  }\n};\n\ntemplate <typename K>\nclass TP_MOE<AMX_BF16_MOE_TP<K>> : public TP_MOE<AMX_MOE_BASE<K, AMX_BF16_MOE_TP<K>>> {\n public:\n  using Base = TP_MOE<AMX_MOE_BASE<K, AMX_BF16_MOE_TP<K>>>;\n  using Base::Base;\n\n  void load_weights() override {\n    auto& config = this->config;\n    auto& tps = this->tps;\n    auto& tp_count = this->tp_count;\n    auto pool = config.pool;\n    const uint64_t* 
physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;\n\n    // BF16 has no quantization check needed\n    if (config.gate_projs.empty() && config.gate_proj == nullptr) {\n      throw std::runtime_error(\"no weight source\");\n    }\n\n    const bool use_per_expert_ptrs = !config.gate_projs.empty();\n    const size_t full_weight_elems = (size_t)config.intermediate_size * config.hidden_size;\n\n    pool->dispense_backend()->do_numa_job([&, this](int i) {\n      auto& tpc = tps[i]->config_;\n      const size_t tp_weight_elems = (size_t)tpc.intermediate_size * tpc.hidden_size;\n\n      // Allocate BF16 weights (2 bytes/element)\n      tpc.gate_proj = new ggml_bf16_t[tpc.expert_num * tp_weight_elems];\n      tpc.up_proj = new ggml_bf16_t[tpc.expert_num * tp_weight_elems];\n      tpc.down_proj = new ggml_bf16_t[tpc.expert_num * tp_weight_elems];\n\n      const size_t tp_idx = (size_t)i;\n      const size_t gate_up_weight_src_offset = i * tp_weight_elems;\n      const size_t down_weight_src_col_offset = i * (size_t)tpc.intermediate_size;\n\n      pool->get_subpool(i)->do_work_stealing_job(\n          tpc.expert_num, nullptr,\n          [&, &tpc](int expert_id_) {\n            const size_t expert_id = expert_map(physical_to_logical_map, expert_id_);\n\n            ggml_bf16_t* gate_dst = (ggml_bf16_t*)tpc.gate_proj + expert_id * tp_weight_elems;\n            ggml_bf16_t* up_dst = (ggml_bf16_t*)tpc.up_proj + expert_id * tp_weight_elems;\n            ggml_bf16_t* down_dst = (ggml_bf16_t*)tpc.down_proj + expert_id * tp_weight_elems;\n\n            const ggml_bf16_t* gate_src;\n            const ggml_bf16_t* up_src;\n            const ggml_bf16_t* down_src;\n\n            if (use_per_expert_ptrs) {\n              gate_src = (const ggml_bf16_t*)config.gate_projs[0][expert_id] + gate_up_weight_src_offset;\n              up_src = (const ggml_bf16_t*)config.up_projs[0][expert_id] + gate_up_weight_src_offset;\n              down_src = (const 
ggml_bf16_t*)config.down_projs[0][expert_id];\n            } else {\n              gate_src =\n                  (const ggml_bf16_t*)config.gate_proj + expert_id * full_weight_elems + gate_up_weight_src_offset;\n              up_src = (const ggml_bf16_t*)config.up_proj + expert_id * full_weight_elems + gate_up_weight_src_offset;\n              down_src = (const ggml_bf16_t*)config.down_proj + expert_id * full_weight_elems;\n            }\n\n            // Copy gate and up weights\n            std::memcpy(gate_dst, gate_src, tp_weight_elems * sizeof(ggml_bf16_t));\n            std::memcpy(up_dst, up_src, tp_weight_elems * sizeof(ggml_bf16_t));\n\n            // Copy down weights (row-wise split)\n            for (int row = 0; row < config.hidden_size; row++) {\n              const size_t src_row_offset = (size_t)row * (size_t)config.intermediate_size + down_weight_src_col_offset;\n              const size_t dst_row_offset = (size_t)row * (size_t)tpc.intermediate_size;\n              std::memcpy(down_dst + dst_row_offset, down_src + src_row_offset,\n                          (size_t)tpc.intermediate_size * sizeof(ggml_bf16_t));\n            }\n          },\n          nullptr);\n    });\n\n    DO_TPS_LOAD_WEIGHTS(pool);\n\n    pool->dispense_backend()->do_numa_job([&, this](int i) {\n      auto& tpc = tps[i]->config_;\n      delete[] (ggml_bf16_t*)tpc.gate_proj;\n      delete[] (ggml_bf16_t*)tpc.up_proj;\n      delete[] (ggml_bf16_t*)tpc.down_proj;\n    });\n\n    this->weights_loaded = true;\n  }\n\n  /**\n   * @brief Write weights to GPU buffer for all TP parts\n   *\n   * BF16 version - no scales needed, scale_ptrs parameters are kept for interface compatibility.\n   */\n  void write_weight_scale_to_buffer(int gpu_tp_count, int expert_id, const std::vector<uintptr_t>& w13_weight_ptrs,\n                                    const std::vector<uintptr_t>& w13_scale_ptrs,\n                                    const std::vector<uintptr_t>& w2_weight_ptrs,\n                 
                   const std::vector<uintptr_t>& w2_scale_ptrs) {\n    if (this->weights_loaded == false) {\n      throw std::runtime_error(\"Not Loaded\");\n    }\n    if (this->tps.empty()) {\n      throw std::runtime_error(\"No TP parts initialized\");\n    }\n    if ((int)w13_weight_ptrs.size() != gpu_tp_count || (int)w2_weight_ptrs.size() != gpu_tp_count) {\n      throw std::runtime_error(\"Weight pointer arrays size must match gpu_tp_count\");\n    }\n\n    this->config.pool->dispense_backend()->do_numa_job([&, this](int i) {\n      this->tps[i]->write_weights_to_buffer(gpu_tp_count, this->tp_count, expert_id, this->config, w13_weight_ptrs,\n                                            w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs);\n    });\n  }\n};\n\n#endif  // CPUINFER_OPERATOR_AMX_BF16_MOE_H\n"
  },
  {
    "path": "kt-kernel/operators/amx/fp8-moe.hpp",
    "content": "/**\n * @Description  : FP8 AMX MoE operator for DeepSeek V3.2 native inference\n * @Author       : oql, Codex and Claude\n * @Date         : 2025-12-09\n * @Version      : 1.0.0\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n *\n * This file implements FP8 MoE using CRTP pattern, inheriting from moe_base.hpp.\n * FP8 weights are stored with 128x128 block-wise scales.\n **/\n#ifndef CPUINFER_OPERATOR_AMX_FP8_MOE_H\n#define CPUINFER_OPERATOR_AMX_FP8_MOE_H\n\n// #define DEBUG_FP8_MOE\n\n#include \"la/amx_raw_buffers.hpp\"\n#include \"la/amx_raw_kernels.hpp\"\n#include \"moe_base.hpp\"\n\n/**\n * @brief FP8 MoE operator using CRTP pattern\n * @tparam T Kernel type, defaults to GemmKernel224FP8\n *\n * This class provides FP8-specific implementations:\n * - do_gate_up_gemm, do_down_gemm : FP8 weight -> BF16 conversion mat mul\n * - load_weights: Load FP8 weights with 128x128 block scales\n */\ntemplate <class T = amx::GemmKernel224FP8>\nclass AMX_FP8_MOE_TP : public AMX_MOE_BASE<T, AMX_FP8_MOE_TP<T>> {\n  using Base = AMX_MOE_BASE<T, AMX_FP8_MOE_TP<T>>;\n  using Base::config_;\n  using Base::down_ba_;\n  using Base::down_bb_;\n  using Base::down_bc_;\n  using Base::gate_bb_;\n  using Base::gate_bc_;\n  using Base::gate_up_ba_;\n  using Base::m_local_num_;\n  using Base::tp_part_idx;\n  using Base::up_bb_;\n  using Base::up_bc_;\n\n public:\n  using typename Base::input_t;\n  using typename Base::output_t;\n\n  AMX_FP8_MOE_TP() = default;\n\n  AMX_FP8_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) {\n    // Initialization now happens in derived_init() which is called by base constructor\n  }\n\n  void derived_init() {\n    auto& quant_config = config_.quant_config;\n    if (quant_config.group_size == 0 || quant_config.zero_point) {\n      throw std::runtime_error(\"KT-Kernel fp8 MoE only supports block-wise FP8\");\n    }\n    printf(\"Created AMX_FP8_MOE_TP %d at numa %d\\n\", tp_part_idx, 
numa_node_of_cpu(sched_getcpu()));\n  }\n\n  ~AMX_FP8_MOE_TP() = default;\n  // ============================================================================\n  // CRTP buffer creation - with group_size\n  // ============================================================================\n\n  size_t buffer_a_required_size_impl(size_t m, size_t k) const { return T::BufferA::required_size(m, k); }\n  size_t buffer_b_required_size_impl(size_t n, size_t k) const {\n    return T::BufferB::required_size(n, k, config_.quant_config.group_size);\n  }\n  size_t buffer_c_required_size_impl(size_t m, size_t n) const { return T::BufferC::required_size(m, n); }\n\n  std::shared_ptr<typename T::BufferA> make_buffer_a_impl(size_t m, size_t k, void* data) const {\n    return std::make_shared<typename T::BufferA>(m, k, data);\n  }\n  std::shared_ptr<typename T::BufferB> make_buffer_b_impl(size_t n, size_t k, void* data) const {\n    return std::make_shared<typename T::BufferB>(n, k, config_.quant_config.group_size, data);\n  }\n  std::shared_ptr<typename T::BufferC> make_buffer_c_impl(size_t m, size_t n, void* data) const {\n    return std::make_shared<typename T::BufferC>(m, n, data);\n  }\n\n  // ============================================================================\n  // CRTP virtual points - GEMM dispatch\n  // ============================================================================\n\n  void do_gate_up_gemm(bool do_up, int expert_idx, int ith, int nth, int qlen) {\n    auto& group_size = config_.quant_config.group_size;\n    int m = m_local_num_[expert_idx];\n    auto& ba = gate_up_ba_[expert_idx];\n    auto& bb = do_up ? up_bb_[expert_idx] : gate_bb_[expert_idx];\n    auto& bc = do_up ? 
up_bc_[expert_idx] : gate_bc_[expert_idx];\n\n    amx::vec_mul_kgroup(m, config_.intermediate_size, config_.hidden_size, group_size, ba, bb, bc, ith, nth);\n  }\n  void do_down_gemm(int expert_idx, int ith, int nth, int qlen) {\n    auto& group_size = config_.quant_config.group_size;\n    int m = m_local_num_[expert_idx];\n\n    amx::vec_mul_kgroup(m, config_.hidden_size, config_.intermediate_size, group_size, down_ba_[expert_idx],\n                        down_bb_[expert_idx], down_bc_[expert_idx], ith, nth);\n  }\n\n#ifdef DEBUG_FP8_MOE\n  // Function to dump Buffer B data for debugging FP8 quantization results\n  inline void dump_buffer_b(const std::string& quantization_type, int expert_idx, const std::string& matrix_type,\n                            typename T::BufferB* buffer) {\n    auto& quant_config = config_.quant_config;\n    int& group_size = quant_config.group_size;\n\n    printf(\"[DUMP_BUFFER_B] TP%d %s Expert%d %s:\\n\", tp_part_idx, quantization_type.c_str(), expert_idx,\n           matrix_type.c_str());\n\n    // Calculate dimensions based on matrix type\n    int rows, cols;\n    size_t scale_elem_count;\n    if (matrix_type == \"gate\" || matrix_type == \"up\") {\n      rows = config_.intermediate_size;\n      cols = config_.hidden_size;\n    } else {  // down\n      rows = config_.hidden_size;\n      cols = config_.intermediate_size;\n    }\n    int n_blocks_n = (rows + group_size - 1) / group_size;\n    int n_blocks_k = (cols + group_size - 1) / group_size;\n    scale_elem_count = n_blocks_n * n_blocks_k;\n\n    // Dump scales (as BF16 converted to float)\n    printf(\"  Scales[first 16]: \");\n    for (int i = 0; i < std::min(16, (int)scale_elem_count); i++) {\n      printf(\"%.6f \", buffer->d[i]);\n    }\n    printf(\"\\n\");\n\n    if (scale_elem_count > 16) {\n      printf(\"  Scales[last 16]: \");\n      int start_idx = std::max(0, (int)scale_elem_count - 16);\n      for (int i = start_idx; i < (int)scale_elem_count; i++) {\n        
printf(\"%.6f \", buffer->d[i]);\n      }\n      printf(\"\\n\");\n    }\n\n    // Dump FP8 weights (as hex uint8)\n    size_t weight_size = (size_t)rows * cols;  // FP8 is 1 byte per element\n    uint8_t* weight_ptr = (uint8_t*)buffer->b;\n\n    printf(\"  FP8 Weights[first 32 bytes]: \");\n    for (int i = 0; i < std::min(32, (int)weight_size); i++) {\n      printf(\"%02x \", weight_ptr[i]);\n    }\n    printf(\"\\n\");\n\n    if (weight_size > 32) {\n      printf(\"  FP8 Weights[last 32 bytes]: \");\n      int start_idx = std::max(32, (int)weight_size - 32);\n      for (int i = start_idx; i < (int)weight_size; i++) {\n        printf(\"%02x \", weight_ptr[i]);\n      }\n      printf(\"\\n\");\n    }\n\n    printf(\"  Matrix dimensions: %dx%d (n x k), Scale blocks: %dx%d, Group size: %d, Scale elements: %zu\\n\", rows, cols,\n           n_blocks_n, n_blocks_k, group_size, scale_elem_count);\n  }\n#endif\n\n  /**\n   * @brief Load FP8 weights from contiguous memory layout\n   *\n   * Loads weights from config_.gate_proj, up_proj, down_proj with scales\n   * from config_.gate_scale, up_scale, down_scale.\n   */\n  void load_weights() {\n    auto& quant_config = config_.quant_config;\n    int& group_size = quant_config.group_size;\n    const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map;\n    auto pool = config_.pool->get_subpool(tp_part_idx);\n\n    if (config_.gate_scale == nullptr) {\n      throw std::runtime_error(\"FP8 AMX MoE only supports native weights.\");\n    }\n\n    // load weight\n    int nth = T::recommended_nth(config_.intermediate_size);\n    pool->do_work_stealing_job(\n        nth * config_.expert_num, nullptr,\n        [this, nth, physical_to_logical_map, group_size](int task_id) {\n          uint64_t expert_idx = task_id / nth;\n          uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);\n          int ith = task_id % nth;\n          // gate part\n          
gate_bb_[expert_idx]->from_mat(\n              (uint8_t*)config_.gate_proj + (logical_expert_id * config_.intermediate_size * config_.hidden_size),\n              (float*)config_.gate_scale +\n                  (logical_expert_id * (config_.hidden_size / group_size) * (config_.intermediate_size / group_size)),\n              ith, nth);\n          // up part\n          up_bb_[expert_idx]->from_mat(\n              (uint8_t*)config_.up_proj + (logical_expert_id * config_.intermediate_size * config_.hidden_size),\n              (float*)config_.up_scale +\n                  (logical_expert_id * (config_.hidden_size / group_size) * (config_.intermediate_size / group_size)),\n              ith, nth);\n        },\n        nullptr);\n\n    nth = T::recommended_nth(config_.hidden_size);\n    pool->do_work_stealing_job(\n        nth * config_.expert_num, nullptr,\n        [this, nth, physical_to_logical_map, group_size](int task_id) {\n          uint64_t expert_idx = task_id / nth;\n          uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);\n          int ith = task_id % nth;\n          // down part\n          down_bb_[expert_idx]->from_mat(\n              (uint8_t*)config_.down_proj + (logical_expert_id * config_.intermediate_size * config_.hidden_size),\n              (float*)config_.down_scale +\n                  (logical_expert_id * (config_.hidden_size / group_size) * (config_.intermediate_size / group_size)),\n              ith, nth);\n        },\n        nullptr);\n#ifdef DEBUG_FP8_MOE\n    dump_buffer_b(\"Native FP8\", 0, \"gate\", gate_bb_[0].get());\n    dump_buffer_b(\"Native FP8\", 0, \"down\", down_bb_[0].get());\n#endif\n  }\n\n  // Fast 64-byte (512-bit) memcpy using AVX512\n  static inline void fast_memcpy_64(void* __restrict dst, const void* __restrict src) {\n    __m512i data = _mm512_loadu_si512(src);\n    _mm512_storeu_si512(dst, data);\n  }\n\n  // Fast memcpy for arbitrary sizes using AVX512\n  static inline void 
fast_memcpy(void* __restrict dst, const void* __restrict src, size_t bytes) {\n    uint8_t* d = (uint8_t*)dst;\n    const uint8_t* s = (const uint8_t*)src;\n    size_t chunks = bytes / 64;\n    for (size_t i = 0; i < chunks; i++) {\n      fast_memcpy_64(d, s);\n      d += 64;\n      s += 64;\n    }\n    bytes -= chunks * 64;\n    if (bytes > 0) {\n      std::memcpy(d, s, bytes);\n    }\n  }\n\n  /**\n   * @brief Unpack a single N_STEP x K_STEP block from packed BufferB format to n-major format\n   *\n   * This is the inverse of the packing done in BufferBFP8Impl::from_mat.\n   * Optimized with AVX512 gather for efficient non-contiguous reads.\n   *\n   * @param src Pointer to packed data (N_STEP * K_STEP bytes in packed layout)\n   * @param dst Pointer to destination in n-major layout\n   * @param dst_row_stride Row stride in destination buffer (number of columns in full matrix)\n   */\n  static inline void unpack_nk_block(const uint8_t* src, uint8_t* dst, size_t dst_row_stride) {\n    // row_map[packed_i] gives the base row for packed index packed_i\n    static constexpr int row_map[8] = {0, 16, 4, 20, 8, 24, 12, 28};\n    const uint64_t* src64 = reinterpret_cast<const uint64_t*>(src);\n\n    // Gather indices: src64[8*j + packed_i] for j = 0..7\n    // Offsets in uint64 units: 0, 8, 16, 24, 32, 40, 48, 56 (+ packed_i for each group)\n    const __m512i gather_offsets = _mm512_set_epi64(56, 48, 40, 32, 24, 16, 8, 0);\n\n    // Process each packed group (8 groups of 4 rows each = 32 rows total)\n    for (int packed_i = 0; packed_i < 8; packed_i++) {\n      const int base_row = row_map[packed_i];\n      const uint64_t* base_src = src64 + packed_i;\n\n      // Gather 8 values for j=0..7 and j=8..15\n      __m512i vals_0_7 = _mm512_i64gather_epi64(gather_offsets, base_src, 8);\n      __m512i vals_8_15 = _mm512_i64gather_epi64(gather_offsets, base_src + 64, 8);\n\n      // Extract 4 rows from each set of 8 values\n      // Row 0: bits 0-15\n      __m128i row0_lo = 
_mm512_cvtepi64_epi16(_mm512_and_si512(vals_0_7, _mm512_set1_epi64(0xFFFF)));\n      __m128i row0_hi = _mm512_cvtepi64_epi16(_mm512_and_si512(vals_8_15, _mm512_set1_epi64(0xFFFF)));\n      // Row 1: bits 16-31\n      __m128i row1_lo =\n          _mm512_cvtepi64_epi16(_mm512_and_si512(_mm512_srli_epi64(vals_0_7, 16), _mm512_set1_epi64(0xFFFF)));\n      __m128i row1_hi =\n          _mm512_cvtepi64_epi16(_mm512_and_si512(_mm512_srli_epi64(vals_8_15, 16), _mm512_set1_epi64(0xFFFF)));\n      // Row 2: bits 32-47\n      __m128i row2_lo =\n          _mm512_cvtepi64_epi16(_mm512_and_si512(_mm512_srli_epi64(vals_0_7, 32), _mm512_set1_epi64(0xFFFF)));\n      __m128i row2_hi =\n          _mm512_cvtepi64_epi16(_mm512_and_si512(_mm512_srli_epi64(vals_8_15, 32), _mm512_set1_epi64(0xFFFF)));\n      // Row 3: bits 48-63\n      __m128i row3_lo = _mm512_cvtepi64_epi16(_mm512_srli_epi64(vals_0_7, 48));\n      __m128i row3_hi = _mm512_cvtepi64_epi16(_mm512_srli_epi64(vals_8_15, 48));\n\n      // Store 32 bytes (16 x uint16) to each row\n      // Combine two 128-bit values into 256-bit for more efficient stores\n      uint8_t* row0_dst = dst + (size_t)base_row * dst_row_stride;\n      uint8_t* row1_dst = dst + (size_t)(base_row + 1) * dst_row_stride;\n      uint8_t* row2_dst = dst + (size_t)(base_row + 2) * dst_row_stride;\n      uint8_t* row3_dst = dst + (size_t)(base_row + 3) * dst_row_stride;\n\n      // Combine lo and hi into 256-bit and store\n      __m256i row0_256 = _mm256_set_m128i(row0_hi, row0_lo);\n      __m256i row1_256 = _mm256_set_m128i(row1_hi, row1_lo);\n      __m256i row2_256 = _mm256_set_m128i(row2_hi, row2_lo);\n      __m256i row3_256 = _mm256_set_m128i(row3_hi, row3_lo);\n\n      _mm256_storeu_si256((__m256i*)row0_dst, row0_256);\n      _mm256_storeu_si256((__m256i*)row1_dst, row1_256);\n      _mm256_storeu_si256((__m256i*)row2_dst, row2_256);\n      _mm256_storeu_si256((__m256i*)row3_dst, row3_256);\n    }\n  }\n\n  /**\n   * @brief Unpack 4 consecutive N_STEP x 
K_STEP blocks to maximize cache line utilization\n   *\n   * Processing 4 blocks together means each row write is 128 bytes = 2 cache lines,\n   * which greatly improves write efficiency compared to 32 bytes per row.\n   *\n   * @param src Array of 4 source pointers (each pointing to a 32x32 packed block)\n   * @param dst Destination pointer in n-major layout\n   * @param dst_row_stride Row stride in destination buffer\n   */\n  static inline void unpack_4nk_blocks(const uint8_t* src[4], uint8_t* dst, size_t dst_row_stride) {\n    static constexpr int row_map[8] = {0, 16, 4, 20, 8, 24, 12, 28};\n    constexpr int K_STEP = T::K_STEP;  // 32\n\n    // Reinterpret as uint64 arrays for efficient access\n    const uint64_t* src0 = reinterpret_cast<const uint64_t*>(src[0]);\n    const uint64_t* src1 = reinterpret_cast<const uint64_t*>(src[1]);\n    const uint64_t* src2 = reinterpret_cast<const uint64_t*>(src[2]);\n    const uint64_t* src3 = reinterpret_cast<const uint64_t*>(src[3]);\n\n    // Process all 32 rows, writing 128 bytes (4 x 32) per row\n    for (int packed_i = 0; packed_i < 8; packed_i++) {\n      const int base_row = row_map[packed_i];\n\n      // Process 4 rows at a time\n      for (int r = 0; r < 4; r++) {\n        uint16_t* row_dst = reinterpret_cast<uint16_t*>(dst + (size_t)(base_row + r) * dst_row_stride);\n        const int shift = r * 16;\n\n        // Unroll: process all 4 blocks x 16 columns = 64 uint16 values\n        // Block 0: columns 0-15\n        for (int j = 0; j < 16; j++) {\n          row_dst[j] = static_cast<uint16_t>(src0[8 * j + packed_i] >> shift);\n        }\n        // Block 1: columns 16-31\n        for (int j = 0; j < 16; j++) {\n          row_dst[16 + j] = static_cast<uint16_t>(src1[8 * j + packed_i] >> shift);\n        }\n        // Block 2: columns 32-47\n        for (int j = 0; j < 16; j++) {\n          row_dst[32 + j] = static_cast<uint16_t>(src2[8 * j + packed_i] >> shift);\n        }\n        // Block 3: columns 48-63\n       
 for (int j = 0; j < 16; j++) {\n          row_dst[48 + j] = static_cast<uint16_t>(src3[8 * j + packed_i] >> shift);\n        }\n      }\n    }\n  }\n\n  /**\n   * @brief Reconstruct weights for a single expert to the output buffers (no temp buffer version)\n   *\n   * Directly unpacks from packed BufferB format to n-major GPU buffers without intermediate storage.\n   * Optimized version with coarse-grained task splitting for better cache utilization.\n   *\n   * Key optimizations:\n   * - Reduced task count (~40 vs ~350) to minimize scheduling overhead\n   * - Larger chunks per task for better cache line utilization\n   * - Process multiple N_STEPs per task for better write locality\n   *\n   * @param gpu_tp_count Number of GPU TP parts (1, 2, 4, or 8)\n   * @param cpu_tp_count Number of CPU TP parts\n   * @param expert_id Expert index to process\n   * @param full_config Full configuration (before CPU TP split)\n   * @param w13_weight_ptrs Pointers to gate+up weight buffers (one per GPU TP)\n   * @param w13_scale_ptrs Pointers to gate+up scale buffers (one per GPU TP)\n   * @param w2_weight_ptrs Pointers to down weight buffers (one per GPU TP)\n   * @param w2_scale_ptrs Pointers to down scale buffers (one per GPU TP)\n   */\n  void write_weights_to_buffer(int gpu_tp_count, [[maybe_unused]] int cpu_tp_count, int expert_id,\n                               const GeneralMOEConfig& full_config, const std::vector<uintptr_t>& w13_weight_ptrs,\n                               const std::vector<uintptr_t>& w13_scale_ptrs,\n                               const std::vector<uintptr_t>& w2_weight_ptrs,\n                               const std::vector<uintptr_t>& w2_scale_ptrs) const {\n    auto& config = config_;\n    const int group_size = config.quant_config.group_size;\n    auto pool = config.pool->get_subpool(tp_part_idx);\n\n    constexpr int N_STEP = T::N_STEP;\n    constexpr int K_STEP = T::K_STEP;\n    constexpr int N_BLOCK = T::N_BLOCK;\n    constexpr int K_BLOCK = 
T::K_BLOCK;\n\n    // ========= W13 (gate+up): Shape [intermediate, hidden], split by N only =========\n    const int cpu_n_w13 = config.intermediate_size;\n    const int cpu_k_w13 = config.hidden_size;\n    const int gpu_n_w13 = full_config.intermediate_size / gpu_tp_count;\n    const int gpu_k_w13 = full_config.hidden_size;\n    const int global_n_offset_w13 = tp_part_idx * cpu_n_w13;\n\n    const size_t gpu_w13_weight_per_mat = (size_t)gpu_n_w13 * gpu_k_w13;\n    const size_t gpu_w13_scale_per_mat = (size_t)div_up(gpu_n_w13, group_size) * div_up(gpu_k_w13, group_size);\n    const int cpu_scale_k_blocks_w13 = div_up(cpu_k_w13, group_size);\n    const int gpu_scale_k_blocks_w13 = div_up(gpu_k_w13, group_size);\n\n    // ========= W2 (down): Shape [hidden, intermediate], split by K =========\n    const int cpu_n_w2 = config.hidden_size;\n    const int cpu_k_w2 = config.intermediate_size;\n    const int gpu_n_w2 = full_config.hidden_size;\n    const int gpu_k_w2 = full_config.intermediate_size / gpu_tp_count;\n    const int global_k_offset_w2 = tp_part_idx * cpu_k_w2;\n\n    const size_t gpu_w2_weight_per_mat = (size_t)gpu_n_w2 * gpu_k_w2;\n    const size_t gpu_w2_scale_per_mat = (size_t)div_up(gpu_n_w2, group_size) * div_up(gpu_k_w2, group_size);\n    const int cpu_scale_k_blocks_w2 = div_up(cpu_k_w2, group_size);\n    const int gpu_scale_k_blocks_w2 = div_up(gpu_k_w2, group_size);\n\n    // ========= Scale dimensions =========\n    const int cpu_scale_n_blocks_w13 = div_up(cpu_n_w13, group_size);\n    const int gpu_scale_n_blocks_w13 = div_up(gpu_n_w13, group_size);\n    const int cpu_scale_n_blocks_w2 = div_up(cpu_n_w2, group_size);\n\n    // ========= Optimized job layout =========\n    // Use task count slightly above CPU core count for good work stealing\n    // For 80-core system, ~100 tasks provides good balance\n    constexpr int NUM_W13_TASKS = 32;  // Per matrix (gate or up), total 64 for w13\n    constexpr int NUM_W2_TASKS = 32;   // For down matrix\n    
constexpr int SCALE_TASKS = 3;     // gate_scale, up_scale, down_scale\n\n    const int total_tasks = NUM_W13_TASKS * 2 + NUM_W2_TASKS + SCALE_TASKS;\n\n    // Calculate N_STEP blocks per task (must be N_STEP aligned for correct BufferB addressing)\n    const int w13_n_steps = div_up(cpu_n_w13, N_STEP);\n    const int w13_steps_per_task = div_up(w13_n_steps, NUM_W13_TASKS);\n    const int w2_n_steps = div_up(cpu_n_w2, N_STEP);\n    const int w2_steps_per_task = div_up(w2_n_steps, NUM_W2_TASKS);\n\n    pool->do_work_stealing_job(\n        total_tasks, nullptr,\n        [=, &w13_weight_ptrs, &w13_scale_ptrs, &w2_weight_ptrs, &w2_scale_ptrs, this](int task_id) {\n          if (task_id < NUM_W13_TASKS * 2) {\n            // ========= W13 weight task: process chunk of rows x full K =========\n            const bool is_up = task_id >= NUM_W13_TASKS;\n            const int chunk_idx = task_id % NUM_W13_TASKS;\n            const auto& bb = is_up ? up_bb_[expert_id] : gate_bb_[expert_id];\n\n            // Calculate row range for this task (N_STEP aligned)\n            const int step_start = chunk_idx * w13_steps_per_task;\n            const int step_end = std::min(step_start + w13_steps_per_task, w13_n_steps);\n            if (step_start >= w13_n_steps) return;\n            const int chunk_n_start = step_start * N_STEP;\n            const int chunk_n_end = std::min(step_end * N_STEP, cpu_n_w13);\n\n            // Process each N_STEP within this chunk\n            for (int local_n_start = chunk_n_start; local_n_start < chunk_n_end; local_n_start += N_STEP) {\n              // Calculate GPU target and offset for each N_STEP (may cross GPU TP boundaries)\n              const int global_n = global_n_offset_w13 + local_n_start;\n              const int target_gpu = global_n / gpu_n_w13;\n              const int n_in_gpu = global_n % gpu_n_w13;\n\n              uint8_t* weight_base = (uint8_t*)w13_weight_ptrs[target_gpu];\n              // Pointer already points to current 
expert's location, only add offset for up matrix\n              const size_t expert_weight_off = is_up ? gpu_w13_weight_per_mat : 0;\n\n              // Calculate N_BLOCK info for source addressing\n              const int n_block_idx = local_n_start / N_BLOCK;\n              const int n_block_begin = n_block_idx * N_BLOCK;\n              const int n_block_size = std::min(N_BLOCK, cpu_n_w13 - n_block_begin);\n              const int n_in_block = local_n_start - n_block_begin;\n\n              // Process all K in groups of 4 K_STEPs when possible for cache efficiency\n              for (int k_block_begin = 0; k_block_begin < cpu_k_w13; k_block_begin += K_BLOCK) {\n                const int k_block_size = std::min(K_BLOCK, cpu_k_w13 - k_block_begin);\n\n                // Try to process 4 K_STEPs at once (128 columns = 2 cache lines per row)\n                int k_begin = 0;\n                for (; k_begin + 4 * K_STEP <= k_block_size; k_begin += 4 * K_STEP) {\n                  const uint8_t* src_ptrs[4];\n                  for (int i = 0; i < 4; i++) {\n                    src_ptrs[i] = bb->b + (size_t)n_block_begin * cpu_k_w13 + (size_t)k_block_begin * n_block_size +\n                                  (size_t)n_in_block * k_block_size + (size_t)(k_begin + i * K_STEP) * N_STEP;\n                  }\n                  uint8_t* dst =\n                      weight_base + expert_weight_off + (size_t)n_in_gpu * gpu_k_w13 + k_block_begin + k_begin;\n                  unpack_4nk_blocks(src_ptrs, dst, gpu_k_w13);\n                }\n\n                // Handle remaining K_STEPs one by one\n                for (; k_begin < k_block_size; k_begin += K_STEP) {\n                  const uint8_t* src = bb->b + (size_t)n_block_begin * cpu_k_w13 +\n                                       (size_t)k_block_begin * n_block_size + (size_t)n_in_block * k_block_size +\n                                       (size_t)k_begin * N_STEP;\n                  uint8_t* dst =\n                      
weight_base + expert_weight_off + (size_t)n_in_gpu * gpu_k_w13 + k_block_begin + k_begin;\n                  unpack_nk_block(src, dst, gpu_k_w13);\n                }\n              }\n            }\n\n          } else if (task_id < NUM_W13_TASKS * 2 + NUM_W2_TASKS) {\n            // ========= W2 weight task: process chunk of rows x all K slices =========\n            const int chunk_idx = task_id - NUM_W13_TASKS * 2;\n            const auto& bb = down_bb_[expert_id];\n\n            // Calculate row range for this task (N_STEP aligned)\n            const int step_start = chunk_idx * w2_steps_per_task;\n            const int step_end = std::min(step_start + w2_steps_per_task, w2_n_steps);\n            if (step_start >= w2_n_steps) return;\n            const int chunk_n_start = step_start * N_STEP;\n            const int chunk_n_end = std::min(step_end * N_STEP, cpu_n_w2);\n\n            // Process each N_STEP within this chunk\n            for (int local_n_start = chunk_n_start; local_n_start < chunk_n_end; local_n_start += N_STEP) {\n              // Calculate N_BLOCK info for source addressing\n              const int n_block_idx = local_n_start / N_BLOCK;\n              const int n_block_begin = n_block_idx * N_BLOCK;\n              const int n_block_size = std::min(N_BLOCK, cpu_n_w2 - n_block_begin);\n              const int n_in_block = local_n_start - n_block_begin;\n\n              // Process all K slices (each slice goes to a different GPU TP)\n              for (int k_slice_start = 0; k_slice_start < cpu_k_w2; k_slice_start += gpu_k_w2) {\n                const int k_slice_end = std::min(k_slice_start + gpu_k_w2, cpu_k_w2);\n\n                const int global_k_start = global_k_offset_w2 + k_slice_start;\n                const int target_gpu = global_k_start / gpu_k_w2;\n                const int k_in_gpu_base = global_k_start % gpu_k_w2;\n\n                uint8_t* weight_base = (uint8_t*)w2_weight_ptrs[target_gpu];\n                // Pointer already 
points to current expert's location\n                const size_t expert_weight_off = 0;\n\n                // Process K within this slice, trying 4 K_STEPs at once when aligned\n                for (int k_abs = k_slice_start; k_abs < k_slice_end;) {\n                  const int k_block_idx = k_abs / K_BLOCK;\n                  const int k_block_begin = k_block_idx * K_BLOCK;\n                  const int k_block_size = std::min(K_BLOCK, cpu_k_w2 - k_block_begin);\n                  const int k_in_block = k_abs - k_block_begin;\n                  const int k_in_gpu = k_in_gpu_base + (k_abs - k_slice_start);\n\n                  // Check if we can process 4 K_STEPs at once\n                  const int remaining_in_block = k_block_size - k_in_block;\n                  const int remaining_in_slice = k_slice_end - k_abs;\n\n                  if (remaining_in_block >= 4 * K_STEP && remaining_in_slice >= 4 * K_STEP) {\n                    const uint8_t* src_ptrs[4];\n                    for (int i = 0; i < 4; i++) {\n                      src_ptrs[i] = bb->b + (size_t)n_block_begin * cpu_k_w2 + (size_t)k_block_begin * n_block_size +\n                                    (size_t)n_in_block * k_block_size + (size_t)(k_in_block + i * K_STEP) * N_STEP;\n                    }\n                    uint8_t* dst = weight_base + expert_weight_off + (size_t)local_n_start * gpu_k_w2 + k_in_gpu;\n                    unpack_4nk_blocks(src_ptrs, dst, gpu_k_w2);\n                    k_abs += 4 * K_STEP;\n                  } else {\n                    const uint8_t* src = bb->b + (size_t)n_block_begin * cpu_k_w2 +\n                                         (size_t)k_block_begin * n_block_size + (size_t)n_in_block * k_block_size +\n                                         (size_t)k_in_block * N_STEP;\n                    uint8_t* dst = weight_base + expert_weight_off + (size_t)local_n_start * gpu_k_w2 + k_in_gpu;\n                    unpack_nk_block(src, dst, gpu_k_w2);\n                   
 k_abs += K_STEP;\n                  }\n                }\n              }\n            }\n\n          } else {\n            // ========= Scale copy task: simple linear copy with fast_memcpy =========\n            const int scale_task_id = task_id - NUM_W13_TASKS * 2 - NUM_W2_TASKS;\n\n            if (scale_task_id < 2) {\n              // Gate (0) or Up (1) scale copy\n              const bool is_up = scale_task_id == 1;\n              const auto& bb = is_up ? up_bb_[expert_id] : gate_bb_[expert_id];\n\n              // W13 scales: copy N blocks corresponding to this CPU TP\n              // Note: when gpu_tp > cpu_tp, scale blocks may span multiple GPU TPs\n              const int bn_start_global = global_n_offset_w13 / group_size;\n\n              for (int bn = 0; bn < cpu_scale_n_blocks_w13; bn++) {\n                const int global_bn = bn_start_global + bn;\n                const int target_gpu = global_bn / gpu_scale_n_blocks_w13;\n                const int gpu_bn = global_bn % gpu_scale_n_blocks_w13;\n\n                float* scale_dst = (float*)w13_scale_ptrs[target_gpu];\n                // Pointer already points to current expert's location, only add offset for up matrix\n                const size_t expert_scale_off = is_up ? 
gpu_w13_scale_per_mat : 0;\n\n                fast_memcpy(scale_dst + expert_scale_off + (size_t)gpu_bn * gpu_scale_k_blocks_w13,\n                            bb->d + (size_t)bn * cpu_scale_k_blocks_w13, cpu_scale_k_blocks_w13 * sizeof(float));\n              }\n            } else {\n              // Down scale copy (scale_task_id == 2)\n              const auto& bb = down_bb_[expert_id];\n\n              // W2 scales: K dimension is split, copy to each GPU TP\n              for (int k_slice_idx = 0; k_slice_idx < div_up(cpu_k_w2, gpu_k_w2); k_slice_idx++) {\n                const int k_slice_start = k_slice_idx * gpu_k_w2;\n                const int k_slice_end = std::min(k_slice_start + gpu_k_w2, cpu_k_w2);\n\n                const int global_k_start = global_k_offset_w2 + k_slice_start;\n                const int target_gpu = global_k_start / gpu_k_w2;\n                const int bk_gpu_base = (global_k_start % gpu_k_w2) / group_size;\n\n                float* scale_dst = (float*)w2_scale_ptrs[target_gpu];\n                // Pointer already points to current expert's location\n                const size_t expert_scale_off = 0;\n\n                const int bk_start = k_slice_start / group_size;\n                const int bk_end = div_up(k_slice_end, group_size);\n                const int bk_count = bk_end - bk_start;\n\n                for (int bn = 0; bn < cpu_scale_n_blocks_w2; bn++) {\n                  fast_memcpy(scale_dst + expert_scale_off + (size_t)bn * gpu_scale_k_blocks_w2 + bk_gpu_base,\n                              bb->d + (size_t)bn * cpu_scale_k_blocks_w2 + bk_start, bk_count * sizeof(float));\n                }\n              }\n            }\n          }\n        },\n        nullptr);\n  }\n};\n\ntemplate <typename K>\nclass TP_MOE<AMX_FP8_MOE_TP<K>> : public TP_MOE<AMX_MOE_BASE<K, AMX_FP8_MOE_TP<K>>> {\n public:\n  using Base = TP_MOE<AMX_MOE_BASE<K, AMX_FP8_MOE_TP<K>>>;\n  using Base::Base;\n\n  void load_weights() override {\n    auto& config 
= this->config;\n    auto& tps = this->tps;\n    auto& tp_count = this->tp_count;\n    auto pool = config.pool;\n    const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;\n\n    const int group_size = config.quant_config.group_size;\n    if (group_size == 0 || config.quant_config.zero_point) {\n      throw std::runtime_error(\"FP8 MoE requires a nonzero group_size and zero_point=false\");\n    }\n\n    if (config.gate_projs.empty() && config.gate_proj == nullptr) {\n      throw std::runtime_error(\"no weight source\");\n    }\n    const bool use_per_expert_ptrs = !config.gate_projs.empty();\n\n    const size_t full_weight_elems = (size_t)config.intermediate_size * config.hidden_size;\n    const size_t full_scale_elems =\n        (size_t)div_up(config.hidden_size, group_size) * div_up(config.intermediate_size, group_size);\n\n    pool->dispense_backend()->do_numa_job([&, this](int i) {\n      auto& tpc = tps[i]->config_;\n      const size_t tp_weight_elems = (size_t)tpc.intermediate_size * tpc.hidden_size;\n      const size_t tp_scale_elems =\n          (size_t)div_up(tpc.intermediate_size, group_size) * div_up(tpc.hidden_size, group_size);\n\n      tpc.gate_proj = new uint8_t[tpc.expert_num * tp_weight_elems];\n      tpc.up_proj = new uint8_t[tpc.expert_num * tp_weight_elems];\n      tpc.down_proj = new uint8_t[tpc.expert_num * tp_weight_elems];\n\n      tpc.gate_scale = new float[tpc.expert_num * tp_scale_elems];\n      tpc.up_scale = new float[tpc.expert_num * tp_scale_elems];\n      tpc.down_scale = new float[tpc.expert_num * tp_scale_elems];\n\n      const size_t tp_idx = (size_t)i;\n      const size_t gate_up_weight_src_offset = i * tp_weight_elems;\n      const size_t gate_up_scale_src_offset = i * tp_scale_elems;\n\n      const size_t down_weight_src_col_offset = i * (size_t)tpc.intermediate_size;\n      const size_t down_scale_src_block_k_offset = down_weight_src_col_offset / (size_t)group_size;\n\n      
pool->get_subpool(i)->do_work_stealing_job(\n          tpc.expert_num, nullptr,\n          [&](int expert_id_) {\n            const size_t expert_id = expert_map(physical_to_logical_map, expert_id_);\n\n            uint8_t* gate_dst = (uint8_t*)tpc.gate_proj + expert_id * tp_weight_elems;\n            uint8_t* up_dst = (uint8_t*)tpc.up_proj + expert_id * tp_weight_elems;\n            uint8_t* down_dst = (uint8_t*)tpc.down_proj + expert_id * tp_weight_elems;\n\n            float* gate_scale_dst = (float*)tpc.gate_scale + expert_id * tp_scale_elems;\n            float* up_scale_dst = (float*)tpc.up_scale + expert_id * tp_scale_elems;\n            float* down_scale_dst = (float*)tpc.down_scale + expert_id * tp_scale_elems;\n\n            const uint8_t* gate_src;\n            const uint8_t* up_src;\n            const uint8_t* down_src;\n            const float* gate_scale_src;\n            const float* up_scale_src;\n            const float* down_scale_src;\n\n            if (use_per_expert_ptrs) {\n              gate_src = (const uint8_t*)config.gate_projs[0][expert_id] + gate_up_weight_src_offset;\n              up_src = (const uint8_t*)config.up_projs[0][expert_id] + gate_up_weight_src_offset;\n              down_src = (const uint8_t*)config.down_projs[0][expert_id];\n\n              gate_scale_src = (const float*)config.gate_scales[0][expert_id] + gate_up_scale_src_offset;\n              up_scale_src = (const float*)config.up_scales[0][expert_id] + gate_up_scale_src_offset;\n              down_scale_src = (const float*)config.down_scales[0][expert_id];\n            } else {\n              gate_src = (const uint8_t*)config.gate_proj + expert_id * full_weight_elems + gate_up_weight_src_offset;\n              up_src = (const uint8_t*)config.up_proj + expert_id * full_weight_elems + gate_up_weight_src_offset;\n              down_src = (const uint8_t*)config.down_proj + expert_id * full_weight_elems;\n\n              gate_scale_src =\n                  (const 
float*)config.gate_scale + expert_id * full_scale_elems + gate_up_scale_src_offset;\n              up_scale_src = (const float*)config.up_scale + expert_id * full_scale_elems + gate_up_scale_src_offset;\n              down_scale_src = (const float*)config.down_scale + expert_id * full_scale_elems;\n            }\n\n            std::memcpy(gate_dst, gate_src, tp_weight_elems);\n            std::memcpy(up_dst, up_src, tp_weight_elems);\n            std::memcpy(gate_scale_dst, gate_scale_src, sizeof(float) * tp_scale_elems);\n            std::memcpy(up_scale_dst, up_scale_src, sizeof(float) * tp_scale_elems);\n\n            for (int row = 0; row < config.hidden_size; row++) {\n              const size_t src_row_offset = (size_t)row * (size_t)config.intermediate_size + down_weight_src_col_offset;\n              const size_t dst_row_offset = (size_t)row * (size_t)tpc.intermediate_size;\n              std::memcpy(down_dst + dst_row_offset, down_src + src_row_offset, (size_t)tpc.intermediate_size);\n            }\n\n            const int n_blocks_n = div_up(config.hidden_size, group_size);\n            const int full_n_blocks_k = div_up(config.intermediate_size, group_size);\n            const int tp_n_blocks_k = div_up(tpc.intermediate_size, group_size);\n            for (int bn = 0; bn < n_blocks_n; bn++) {\n              const float* src = down_scale_src + (size_t)bn * (size_t)full_n_blocks_k + down_scale_src_block_k_offset;\n              float* dst = down_scale_dst + (size_t)bn * (size_t)tp_n_blocks_k;\n              std::memcpy(dst, src, sizeof(float) * (size_t)tp_n_blocks_k);\n            }\n          },\n          nullptr);\n    });\n\n    DO_TPS_LOAD_WEIGHTS(pool);\n\n    pool->dispense_backend()->do_numa_job([&, this](int i) {\n      auto& tpc = tps[i]->config_;\n      delete[] (uint8_t*)tpc.gate_proj;\n      delete[] (uint8_t*)tpc.up_proj;\n      delete[] (uint8_t*)tpc.down_proj;\n      delete[] (float*)tpc.gate_scale;\n      delete[] (float*)tpc.up_scale;\n    
  delete[] (float*)tpc.down_scale;\n    });\n\n    this->weights_loaded = true;\n  }\n\n  void write_weight_scale_to_buffer(int gpu_tp_count, int expert_id, const std::vector<uintptr_t>& w13_weight_ptrs,\n                                    const std::vector<uintptr_t>& w13_scale_ptrs,\n                                    const std::vector<uintptr_t>& w2_weight_ptrs,\n                                    const std::vector<uintptr_t>& w2_scale_ptrs) {\n    if (!this->weights_loaded) {\n      throw std::runtime_error(\"Weights not loaded; call load_weights() first\");\n    }\n    if (this->tps.empty()) {\n      throw std::runtime_error(\"No TP parts initialized\");\n    }\n    if ((int)w13_weight_ptrs.size() != gpu_tp_count || (int)w13_scale_ptrs.size() != gpu_tp_count ||\n        (int)w2_weight_ptrs.size() != gpu_tp_count || (int)w2_scale_ptrs.size() != gpu_tp_count) {\n      throw std::runtime_error(\"Pointer array sizes must match gpu_tp_count\");\n    }\n\n    this->config.pool->dispense_backend()->do_numa_job([&, this](int i) {\n      this->tps[i]->write_weights_to_buffer(gpu_tp_count, this->tp_count, expert_id, this->config, w13_weight_ptrs,\n                                            w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs);\n    });\n  }\n};\n\n#endif  // CPUINFER_OPERATOR_AMX_FP8_MOE_H\n"
  },
  {
    "path": "kt-kernel/operators/amx/fp8-perchannel-moe.hpp",
    "content": "/**\n * @Description  : FP8 Per-Channel AMX MoE operator for GLM-4.7-FP8 native inference\n * @Author       : Claude\n * @Date         : 2025-01-12\n * @Version      : 1.0.0\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n *\n * This file implements FP8 MoE with per-channel quantization using CRTP pattern.\n * Per-channel quantization: each output channel (row) has one scale factor.\n * This is different from block-wise quantization where each 128x128 block has one scale.\n **/\n#ifndef CPUINFER_OPERATOR_AMX_FP8_PERCHANNEL_MOE_H\n#define CPUINFER_OPERATOR_AMX_FP8_PERCHANNEL_MOE_H\n\n#include \"la/amx_raw_buffers.hpp\"\n#include \"la/amx_raw_kernels.hpp\"\n#include \"moe_base.hpp\"\n\n/**\n * @brief FP8 Per-Channel MoE operator using CRTP pattern\n * @tparam T Kernel type, defaults to GemmKernel224FP8PerChannel\n *\n * This class provides FP8 per-channel specific implementations:\n * - do_gate_up_gemm, do_down_gemm : FP8 weight -> BF16 conversion mat mul with per-channel scale\n * - load_weights: Load FP8 weights with per-channel scales (shape: [n])\n */\ntemplate <class T = amx::GemmKernel224FP8PerChannel>\nclass AMX_FP8_PERCHANNEL_MOE_TP : public AMX_MOE_BASE<T, AMX_FP8_PERCHANNEL_MOE_TP<T>> {\n  using Base = AMX_MOE_BASE<T, AMX_FP8_PERCHANNEL_MOE_TP<T>>;\n  using Base::config_;\n  using Base::down_ba_;\n  using Base::down_bb_;\n  using Base::down_bc_;\n  using Base::gate_bb_;\n  using Base::gate_bc_;\n  using Base::gate_up_ba_;\n  using Base::m_local_num_;\n  using Base::tp_part_idx;\n  using Base::up_bb_;\n  using Base::up_bc_;\n\n public:\n  using typename Base::input_t;\n  using typename Base::output_t;\n\n  AMX_FP8_PERCHANNEL_MOE_TP() = default;\n\n  AMX_FP8_PERCHANNEL_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) {\n    // Initialization now happens in derived_init() which is called by base constructor\n  }\n\n  void derived_init() {\n    auto& quant_config = config_.quant_config;\n    if 
(!quant_config.per_channel) {\n      throw std::runtime_error(\"KT-Kernel FP8 Per-Channel MoE requires per_channel=true\");\n    }\n    printf(\"Created AMX_FP8_PERCHANNEL_MOE_TP %d at numa %d\\n\", tp_part_idx, numa_node_of_cpu(sched_getcpu()));\n  }\n\n  ~AMX_FP8_PERCHANNEL_MOE_TP() = default;\n\n  // ============================================================================\n  // CRTP buffer creation - per-channel (no group_size needed)\n  // ============================================================================\n\n  size_t buffer_a_required_size_impl(size_t m, size_t k) const { return T::BufferA::required_size(m, k); }\n  size_t buffer_b_required_size_impl(size_t n, size_t k) const {\n    // Per-channel: weight size + n scales (no group_size)\n    return T::BufferB::required_size(n, k);\n  }\n  size_t buffer_c_required_size_impl(size_t m, size_t n) const { return T::BufferC::required_size(m, n); }\n\n  std::shared_ptr<typename T::BufferA> make_buffer_a_impl(size_t m, size_t k, void* data) const {\n    return std::make_shared<typename T::BufferA>(m, k, data);\n  }\n  std::shared_ptr<typename T::BufferB> make_buffer_b_impl(size_t n, size_t k, void* data) const {\n    // Per-channel BufferB doesn't need group_size\n    return std::make_shared<typename T::BufferB>(n, k, data);\n  }\n  std::shared_ptr<typename T::BufferC> make_buffer_c_impl(size_t m, size_t n, void* data) const {\n    return std::make_shared<typename T::BufferC>(m, n, data);\n  }\n\n  // ============================================================================\n  // CRTP virtual points - GEMM dispatch (per-channel)\n  // ============================================================================\n\n  void do_gate_up_gemm(bool do_up, int expert_idx, int ith, int nth, int qlen) {\n    int m = m_local_num_[expert_idx];\n    auto& ba = gate_up_ba_[expert_idx];\n    auto& bb = do_up ? up_bb_[expert_idx] : gate_bb_[expert_idx];\n    auto& bc = do_up ? 
up_bc_[expert_idx] : gate_bc_[expert_idx];\n\n    // Per-channel: dispatch to float_mat_vec_perchannel (one scale per output channel, no K-group scales)\n    amx::float_mat_vec_perchannel<T>(m, config_.intermediate_size, config_.hidden_size, ba.get(), bb.get(), bc.get(),\n                                     ith, nth);\n  }\n\n  void do_down_gemm(int expert_idx, int ith, int nth, int qlen) {\n    int m = m_local_num_[expert_idx];\n\n    amx::float_mat_vec_perchannel<T>(m, config_.hidden_size, config_.intermediate_size, down_ba_[expert_idx].get(),\n                                     down_bb_[expert_idx].get(), down_bc_[expert_idx].get(), ith, nth);\n  }\n\n  // Fast 64-byte (512-bit) memcpy using AVX512\n  static inline void fast_memcpy_64(void* __restrict dst, const void* __restrict src) {\n    __m512i data = _mm512_loadu_si512(src);\n    _mm512_storeu_si512(dst, data);\n  }\n\n  // Fast memcpy for arbitrary sizes using AVX512\n  static inline void fast_memcpy(void* __restrict dst, const void* __restrict src, size_t bytes) {\n    uint8_t* d = (uint8_t*)dst;\n    const uint8_t* s = (const uint8_t*)src;\n    size_t chunks = bytes / 64;\n    for (size_t i = 0; i < chunks; i++) {\n      fast_memcpy_64(d, s);\n      d += 64;\n      s += 64;\n    }\n    bytes -= chunks * 64;\n    if (bytes > 0) {\n      std::memcpy(d, s, bytes);\n    }\n  }\n\n  /**\n   * @brief Unpack a single N_STEP x K_STEP block from packed BufferB format to n-major format\n   *\n   * This is the inverse of the packing done in BufferBFP8PerChannelImpl::from_mat.\n   * Optimized with AVX512 gather for efficient non-contiguous reads.\n   *\n   * @param src Pointer to packed data (N_STEP * K_STEP bytes in packed layout)\n   * @param dst Pointer to destination in n-major layout\n   * @param dst_row_stride Row stride in destination buffer (number of columns in full matrix)\n   */\n  static inline void unpack_nk_block(const uint8_t* src, uint8_t* dst, size_t dst_row_stride) {\n    // row_map[packed_i] gives the base row for packed index 
packed_i\n    static constexpr int row_map[8] = {0, 16, 4, 20, 8, 24, 12, 28};\n    const uint64_t* src64 = reinterpret_cast<const uint64_t*>(src);\n\n    // Gather indices: src64[8*j + packed_i] for j = 0..7\n    // Offsets in uint64 units: 0, 8, 16, 24, 32, 40, 48, 56 (+ packed_i for each group)\n    const __m512i gather_offsets = _mm512_set_epi64(56, 48, 40, 32, 24, 16, 8, 0);\n\n    // Process each packed group (8 groups of 4 rows each = 32 rows total)\n    for (int packed_i = 0; packed_i < 8; packed_i++) {\n      const int base_row = row_map[packed_i];\n      const uint64_t* base_src = src64 + packed_i;\n\n      // Gather 8 values for j=0..7 and j=8..15\n      __m512i vals_0_7 = _mm512_i64gather_epi64(gather_offsets, base_src, 8);\n      __m512i vals_8_15 = _mm512_i64gather_epi64(gather_offsets, base_src + 64, 8);\n\n      // Extract 4 rows from each set of 8 values\n      // Row 0: bits 0-15\n      __m128i row0_lo = _mm512_cvtepi64_epi16(_mm512_and_si512(vals_0_7, _mm512_set1_epi64(0xFFFF)));\n      __m128i row0_hi = _mm512_cvtepi64_epi16(_mm512_and_si512(vals_8_15, _mm512_set1_epi64(0xFFFF)));\n      // Row 1: bits 16-31\n      __m128i row1_lo =\n          _mm512_cvtepi64_epi16(_mm512_and_si512(_mm512_srli_epi64(vals_0_7, 16), _mm512_set1_epi64(0xFFFF)));\n      __m128i row1_hi =\n          _mm512_cvtepi64_epi16(_mm512_and_si512(_mm512_srli_epi64(vals_8_15, 16), _mm512_set1_epi64(0xFFFF)));\n      // Row 2: bits 32-47\n      __m128i row2_lo =\n          _mm512_cvtepi64_epi16(_mm512_and_si512(_mm512_srli_epi64(vals_0_7, 32), _mm512_set1_epi64(0xFFFF)));\n      __m128i row2_hi =\n          _mm512_cvtepi64_epi16(_mm512_and_si512(_mm512_srli_epi64(vals_8_15, 32), _mm512_set1_epi64(0xFFFF)));\n      // Row 3: bits 48-63\n      __m128i row3_lo = _mm512_cvtepi64_epi16(_mm512_srli_epi64(vals_0_7, 48));\n      __m128i row3_hi = _mm512_cvtepi64_epi16(_mm512_srli_epi64(vals_8_15, 48));\n\n      // Store 32 bytes (16 x uint16) to each row\n      // Combine two 128-bit 
values into 256-bit for more efficient stores\n      uint8_t* row0_dst = dst + (size_t)base_row * dst_row_stride;\n      uint8_t* row1_dst = dst + (size_t)(base_row + 1) * dst_row_stride;\n      uint8_t* row2_dst = dst + (size_t)(base_row + 2) * dst_row_stride;\n      uint8_t* row3_dst = dst + (size_t)(base_row + 3) * dst_row_stride;\n\n      // Combine lo and hi into 256-bit and store\n      __m256i row0_256 = _mm256_set_m128i(row0_hi, row0_lo);\n      __m256i row1_256 = _mm256_set_m128i(row1_hi, row1_lo);\n      __m256i row2_256 = _mm256_set_m128i(row2_hi, row2_lo);\n      __m256i row3_256 = _mm256_set_m128i(row3_hi, row3_lo);\n\n      _mm256_storeu_si256((__m256i*)row0_dst, row0_256);\n      _mm256_storeu_si256((__m256i*)row1_dst, row1_256);\n      _mm256_storeu_si256((__m256i*)row2_dst, row2_256);\n      _mm256_storeu_si256((__m256i*)row3_dst, row3_256);\n    }\n  }\n\n  /**\n   * @brief Unpack 4 consecutive N_STEP x K_STEP blocks to maximize cache line utilization\n   *\n   * Processing 4 blocks together means each row write is 128 bytes = 2 cache lines,\n   * which greatly improves write efficiency compared to 32 bytes per row.\n   *\n   * @param src Array of 4 source pointers (each pointing to a 32x32 packed block)\n   * @param dst Destination pointer in n-major layout\n   * @param dst_row_stride Row stride in destination buffer\n   */\n  static inline void unpack_4nk_blocks(const uint8_t* src[4], uint8_t* dst, size_t dst_row_stride) {\n    static constexpr int row_map[8] = {0, 16, 4, 20, 8, 24, 12, 28};\n    constexpr int K_STEP = T::K_STEP;  // 32\n\n    // Reinterpret as uint64 arrays for efficient access\n    const uint64_t* src0 = reinterpret_cast<const uint64_t*>(src[0]);\n    const uint64_t* src1 = reinterpret_cast<const uint64_t*>(src[1]);\n    const uint64_t* src2 = reinterpret_cast<const uint64_t*>(src[2]);\n    const uint64_t* src3 = reinterpret_cast<const uint64_t*>(src[3]);\n\n    // Process all 32 rows, writing 128 bytes (4 x 32) per row\n    
for (int packed_i = 0; packed_i < 8; packed_i++) {\n      const int base_row = row_map[packed_i];\n\n      // Process 4 rows at a time\n      for (int r = 0; r < 4; r++) {\n        uint16_t* row_dst = reinterpret_cast<uint16_t*>(dst + (size_t)(base_row + r) * dst_row_stride);\n        const int shift = r * 16;\n\n        // Unroll: process all 4 blocks x 16 columns = 64 uint16 values\n        // Block 0: columns 0-15\n        for (int j = 0; j < 16; j++) {\n          row_dst[j] = static_cast<uint16_t>(src0[8 * j + packed_i] >> shift);\n        }\n        // Block 1: columns 16-31\n        for (int j = 0; j < 16; j++) {\n          row_dst[16 + j] = static_cast<uint16_t>(src1[8 * j + packed_i] >> shift);\n        }\n        // Block 2: columns 32-47\n        for (int j = 0; j < 16; j++) {\n          row_dst[32 + j] = static_cast<uint16_t>(src2[8 * j + packed_i] >> shift);\n        }\n        // Block 3: columns 48-63\n        for (int j = 0; j < 16; j++) {\n          row_dst[48 + j] = static_cast<uint16_t>(src3[8 * j + packed_i] >> shift);\n        }\n      }\n    }\n  }\n\n  /**\n   * @brief Reconstruct weights for a single expert to the output buffers (per-channel version)\n   *\n   * Directly unpacks from packed BufferB format to n-major GPU buffers without intermediate storage.\n   * Scale handling is simplified for per-channel quantization (linear copy instead of block-wise).\n   *\n   * @param gpu_tp_count Number of GPU TP parts (1, 2, 4, or 8)\n   * @param cpu_tp_count Number of CPU TP parts\n   * @param expert_id Expert index to process\n   * @param full_config Full configuration (before CPU TP split)\n   * @param w13_weight_ptrs Pointers to gate+up weight buffers (one per GPU TP)\n   * @param w13_scale_ptrs Pointers to gate+up scale buffers (one per GPU TP)\n   * @param w2_weight_ptrs Pointers to down weight buffers (one per GPU TP)\n   * @param w2_scale_ptrs Pointers to down scale buffers (one per GPU TP)\n   */\n  void write_weights_to_buffer(int 
gpu_tp_count, [[maybe_unused]] int cpu_tp_count, int expert_id,\n                               const GeneralMOEConfig& full_config, const std::vector<uintptr_t>& w13_weight_ptrs,\n                               const std::vector<uintptr_t>& w13_scale_ptrs,\n                               const std::vector<uintptr_t>& w2_weight_ptrs,\n                               const std::vector<uintptr_t>& w2_scale_ptrs) const {\n    auto& config = config_;\n    auto pool = config.pool->get_subpool(tp_part_idx);\n\n    constexpr int N_STEP = T::N_STEP;\n    constexpr int K_STEP = T::K_STEP;\n    constexpr int N_BLOCK = T::N_BLOCK;\n    constexpr int K_BLOCK = T::K_BLOCK;\n\n    // ========= W13 (gate+up): Shape [intermediate, hidden], split by N only =========\n    const int cpu_n_w13 = config.intermediate_size;\n    const int cpu_k_w13 = config.hidden_size;\n    const int gpu_n_w13 = full_config.intermediate_size / gpu_tp_count;\n    const int gpu_k_w13 = full_config.hidden_size;\n    const int global_n_offset_w13 = tp_part_idx * cpu_n_w13;\n\n    const size_t gpu_w13_weight_per_mat = (size_t)gpu_n_w13 * gpu_k_w13;\n    // Per-channel scale: shape [n] for each matrix\n    const size_t gpu_w13_scale_per_mat = (size_t)gpu_n_w13;\n\n    // ========= W2 (down): Shape [hidden, intermediate], split by K =========\n    const int cpu_n_w2 = config.hidden_size;\n    const int cpu_k_w2 = config.intermediate_size;\n    const int gpu_n_w2 = full_config.hidden_size;\n    const int gpu_k_w2 = full_config.intermediate_size / gpu_tp_count;\n    const int global_k_offset_w2 = tp_part_idx * cpu_k_w2;\n\n    const size_t gpu_w2_weight_per_mat = (size_t)gpu_n_w2 * gpu_k_w2;\n    // Per-channel scale for down: shape [hidden_size] - not split by K\n    const size_t gpu_w2_scale_per_mat = (size_t)gpu_n_w2;\n\n    // ========= Optimized job layout =========\n    constexpr int NUM_W13_TASKS = 32;  // Per matrix (gate or up), total 64 for w13\n    constexpr int NUM_W2_TASKS = 32;   // For down 
matrix\n    constexpr int SCALE_TASKS = 3;     // gate_scale, up_scale, down_scale\n\n    const int total_tasks = NUM_W13_TASKS * 2 + NUM_W2_TASKS + SCALE_TASKS;\n\n    // Calculate N_STEP blocks per task (must be N_STEP aligned for correct BufferB addressing)\n    const int w13_n_steps = div_up(cpu_n_w13, N_STEP);\n    const int w13_steps_per_task = div_up(w13_n_steps, NUM_W13_TASKS);\n    const int w2_n_steps = div_up(cpu_n_w2, N_STEP);\n    const int w2_steps_per_task = div_up(w2_n_steps, NUM_W2_TASKS);\n\n    pool->do_work_stealing_job(\n        total_tasks, nullptr,\n        [=, &w13_weight_ptrs, &w13_scale_ptrs, &w2_weight_ptrs, &w2_scale_ptrs, this](int task_id) {\n          if (task_id < NUM_W13_TASKS * 2) {\n            // ========= W13 weight task: process chunk of rows x full K =========\n            const bool is_up = task_id >= NUM_W13_TASKS;\n            const int chunk_idx = task_id % NUM_W13_TASKS;\n            const auto& bb = is_up ? up_bb_[expert_id] : gate_bb_[expert_id];\n\n            // Calculate row range for this task (N_STEP aligned)\n            const int step_start = chunk_idx * w13_steps_per_task;\n            const int step_end = std::min(step_start + w13_steps_per_task, w13_n_steps);\n            if (step_start >= w13_n_steps) return;\n            const int chunk_n_start = step_start * N_STEP;\n            const int chunk_n_end = std::min(step_end * N_STEP, cpu_n_w13);\n\n            // Process each N_STEP within this chunk\n            for (int local_n_start = chunk_n_start; local_n_start < chunk_n_end; local_n_start += N_STEP) {\n              // Calculate GPU target and offset for each N_STEP (may cross GPU TP boundaries)\n              const int global_n = global_n_offset_w13 + local_n_start;\n              const int target_gpu = global_n / gpu_n_w13;\n              const int n_in_gpu = global_n % gpu_n_w13;\n\n              uint8_t* weight_base = (uint8_t*)w13_weight_ptrs[target_gpu];\n              // Pointer already points to 
current expert's location, only add offset for up matrix\n              const size_t expert_weight_off = is_up ? gpu_w13_weight_per_mat : 0;\n\n              // Calculate N_BLOCK info for source addressing\n              const int n_block_idx = local_n_start / N_BLOCK;\n              const int n_block_begin = n_block_idx * N_BLOCK;\n              const int n_block_size = std::min(N_BLOCK, cpu_n_w13 - n_block_begin);\n              const int n_in_block = local_n_start - n_block_begin;\n\n              // Process all K in groups of 4 K_STEPs when possible for cache efficiency\n              for (int k_block_begin = 0; k_block_begin < cpu_k_w13; k_block_begin += K_BLOCK) {\n                const int k_block_size = std::min(K_BLOCK, cpu_k_w13 - k_block_begin);\n\n                // Try to process 4 K_STEPs at once (128 columns = 2 cache lines per row)\n                int k_begin = 0;\n                for (; k_begin + 4 * K_STEP <= k_block_size; k_begin += 4 * K_STEP) {\n                  const uint8_t* src_ptrs[4];\n                  for (int i = 0; i < 4; i++) {\n                    src_ptrs[i] = bb->b + (size_t)n_block_begin * cpu_k_w13 + (size_t)k_block_begin * n_block_size +\n                                  (size_t)n_in_block * k_block_size + (size_t)(k_begin + i * K_STEP) * N_STEP;\n                  }\n                  uint8_t* dst =\n                      weight_base + expert_weight_off + (size_t)n_in_gpu * gpu_k_w13 + k_block_begin + k_begin;\n                  unpack_4nk_blocks(src_ptrs, dst, gpu_k_w13);\n                }\n\n                // Handle remaining K_STEPs one by one\n                for (; k_begin < k_block_size; k_begin += K_STEP) {\n                  const uint8_t* src = bb->b + (size_t)n_block_begin * cpu_k_w13 +\n                                       (size_t)k_block_begin * n_block_size + (size_t)n_in_block * k_block_size +\n                                       (size_t)k_begin * N_STEP;\n                  uint8_t* dst =\n               
       weight_base + expert_weight_off + (size_t)n_in_gpu * gpu_k_w13 + k_block_begin + k_begin;\n                  unpack_nk_block(src, dst, gpu_k_w13);\n                }\n              }\n            }\n\n          } else if (task_id < NUM_W13_TASKS * 2 + NUM_W2_TASKS) {\n            // ========= W2 weight task: process chunk of rows x all K slices =========\n            const int chunk_idx = task_id - NUM_W13_TASKS * 2;\n            const auto& bb = down_bb_[expert_id];\n\n            // Calculate row range for this task (N_STEP aligned)\n            const int step_start = chunk_idx * w2_steps_per_task;\n            const int step_end = std::min(step_start + w2_steps_per_task, w2_n_steps);\n            if (step_start >= w2_n_steps) return;\n            const int chunk_n_start = step_start * N_STEP;\n            const int chunk_n_end = std::min(step_end * N_STEP, cpu_n_w2);\n\n            // Process each N_STEP within this chunk\n            for (int local_n_start = chunk_n_start; local_n_start < chunk_n_end; local_n_start += N_STEP) {\n              // Calculate N_BLOCK info for source addressing\n              const int n_block_idx = local_n_start / N_BLOCK;\n              const int n_block_begin = n_block_idx * N_BLOCK;\n              const int n_block_size = std::min(N_BLOCK, cpu_n_w2 - n_block_begin);\n              const int n_in_block = local_n_start - n_block_begin;\n\n              // Process all K slices (each slice goes to a different GPU TP)\n              for (int k_slice_start = 0; k_slice_start < cpu_k_w2; k_slice_start += gpu_k_w2) {\n                const int k_slice_end = std::min(k_slice_start + gpu_k_w2, cpu_k_w2);\n\n                const int global_k_start = global_k_offset_w2 + k_slice_start;\n                const int target_gpu = global_k_start / gpu_k_w2;\n                const int k_in_gpu_base = global_k_start % gpu_k_w2;\n\n                uint8_t* weight_base = (uint8_t*)w2_weight_ptrs[target_gpu];\n                // Pointer 
already points to current expert's location\n                const size_t expert_weight_off = 0;\n\n                // Process K within this slice, trying 4 K_STEPs at once when aligned\n                for (int k_abs = k_slice_start; k_abs < k_slice_end;) {\n                  const int k_block_idx = k_abs / K_BLOCK;\n                  const int k_block_begin = k_block_idx * K_BLOCK;\n                  const int k_block_size = std::min(K_BLOCK, cpu_k_w2 - k_block_begin);\n                  const int k_in_block = k_abs - k_block_begin;\n                  const int k_in_gpu = k_in_gpu_base + (k_abs - k_slice_start);\n\n                  // Check if we can process 4 K_STEPs at once\n                  const int remaining_in_block = k_block_size - k_in_block;\n                  const int remaining_in_slice = k_slice_end - k_abs;\n\n                  if (remaining_in_block >= 4 * K_STEP && remaining_in_slice >= 4 * K_STEP) {\n                    const uint8_t* src_ptrs[4];\n                    for (int i = 0; i < 4; i++) {\n                      src_ptrs[i] = bb->b + (size_t)n_block_begin * cpu_k_w2 + (size_t)k_block_begin * n_block_size +\n                                    (size_t)n_in_block * k_block_size + (size_t)(k_in_block + i * K_STEP) * N_STEP;\n                    }\n                    uint8_t* dst = weight_base + expert_weight_off + (size_t)local_n_start * gpu_k_w2 + k_in_gpu;\n                    unpack_4nk_blocks(src_ptrs, dst, gpu_k_w2);\n                    k_abs += 4 * K_STEP;\n                  } else {\n                    const uint8_t* src = bb->b + (size_t)n_block_begin * cpu_k_w2 +\n                                         (size_t)k_block_begin * n_block_size + (size_t)n_in_block * k_block_size +\n                                         (size_t)k_in_block * N_STEP;\n                    uint8_t* dst = weight_base + expert_weight_off + (size_t)local_n_start * gpu_k_w2 + k_in_gpu;\n                    unpack_nk_block(src, dst, gpu_k_w2);\n           
         k_abs += K_STEP;\n                  }\n                }\n              }\n            }\n\n          } else {\n            // ========= Scale copy task: per-channel (simple linear copy) =========\n            const int scale_task_id = task_id - NUM_W13_TASKS * 2 - NUM_W2_TASKS;\n\n            if (scale_task_id < 2) {\n              // Gate (0) or Up (1) scale copy - per-channel: [intermediate_size]\n              const bool is_up = scale_task_id == 1;\n              const auto& bb = is_up ? up_bb_[expert_id] : gate_bb_[expert_id];\n\n              // W13 per-channel scales: copy N range corresponding to this CPU TP\n              // Each GPU TP gets [gpu_n_w13] scales\n              const int n_start_global = global_n_offset_w13;\n\n              for (int local_n = 0; local_n < cpu_n_w13;) {\n                const int global_n = n_start_global + local_n;\n                const int target_gpu = global_n / gpu_n_w13;\n                const int n_in_gpu = global_n % gpu_n_w13;\n\n                // Calculate how many scales to copy to this GPU TP\n                const int remaining_in_gpu = gpu_n_w13 - n_in_gpu;\n                const int remaining_local = cpu_n_w13 - local_n;\n                const int copy_count = std::min(remaining_in_gpu, remaining_local);\n\n                float* scale_dst = (float*)w13_scale_ptrs[target_gpu];\n                // Pointer already points to current expert's location, only add offset for up matrix\n                const size_t expert_scale_off = is_up ? 
gpu_w13_scale_per_mat : 0;\n\n                fast_memcpy(scale_dst + expert_scale_off + n_in_gpu, bb->d + local_n, copy_count * sizeof(float));\n\n                local_n += copy_count;\n              }\n            } else {\n              // Down scale copy (scale_task_id == 2) - per-channel: [hidden_size]\n              const auto& bb = down_bb_[expert_id];\n\n              // W2 per-channel scales: shape [hidden_size], not split by K\n              // All GPU TPs get the same scales (full hidden_size)\n              // However, since K is split, we need to write to each GPU TP\n              for (int gpu_idx = 0; gpu_idx < gpu_tp_count; gpu_idx++) {\n                // Check if this CPU TP contributes to this GPU TP's K range\n                const int gpu_k_start = gpu_idx * gpu_k_w2;\n                const int gpu_k_end = gpu_k_start + gpu_k_w2;\n                const int cpu_k_start = global_k_offset_w2;\n                const int cpu_k_end = cpu_k_start + cpu_k_w2;\n\n                // Check for overlap\n                if (cpu_k_start < gpu_k_end && cpu_k_end > gpu_k_start) {\n                  // This CPU TP contributes to this GPU TP\n                  // Only the first CPU TP for this GPU should write scales\n                  if (cpu_k_start == gpu_k_start || cpu_k_start % gpu_k_w2 == 0) {\n                    float* scale_dst = (float*)w2_scale_ptrs[gpu_idx];\n                    // Pointer already points to current expert's location\n                    fast_memcpy(scale_dst, bb->d, cpu_n_w2 * sizeof(float));\n                  }\n                }\n              }\n            }\n          }\n        },\n        nullptr);\n  }\n\n  /**\n   * @brief Load FP8 weights from contiguous memory layout with per-channel scales\n   *\n   * Loads weights from config_.gate_proj, up_proj, down_proj with scales\n   * from config_.gate_scale, up_scale, down_scale.\n   *\n   * Per-channel scale shape: [n] (one scale per output channel)\n   */\n  void 
load_weights() {\n    const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map;\n    auto pool = config_.pool->get_subpool(tp_part_idx);\n\n    if (config_.gate_scale == nullptr) {\n      throw std::runtime_error(\"FP8 Per-Channel MoE requires scale pointers.\");\n    }\n\n    // load gate and up weights\n    int nth = T::recommended_nth(config_.intermediate_size);\n    pool->do_work_stealing_job(\n        nth * config_.expert_num, nullptr,\n        [this, nth, physical_to_logical_map](int task_id) {\n          uint64_t expert_idx = task_id / nth;\n          uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);\n          int ith = task_id % nth;\n\n          // Per-channel scale: shape [intermediate_size] for gate/up\n          const size_t weight_offset = logical_expert_id * config_.intermediate_size * config_.hidden_size;\n          const size_t scale_offset = logical_expert_id * config_.intermediate_size;\n\n          // gate part\n          gate_bb_[expert_idx]->from_mat((uint8_t*)config_.gate_proj + weight_offset,\n                                         (float*)config_.gate_scale + scale_offset, ith, nth);\n          // up part\n          up_bb_[expert_idx]->from_mat((uint8_t*)config_.up_proj + weight_offset,\n                                       (float*)config_.up_scale + scale_offset, ith, nth);\n        },\n        nullptr);\n\n    // load down weights\n    nth = T::recommended_nth(config_.hidden_size);\n    pool->do_work_stealing_job(\n        nth * config_.expert_num, nullptr,\n        [this, nth, physical_to_logical_map](int task_id) {\n          uint64_t expert_idx = task_id / nth;\n          uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);\n          int ith = task_id % nth;\n\n          // Per-channel scale: shape [hidden_size] for down\n          const size_t weight_offset = logical_expert_id * config_.intermediate_size * config_.hidden_size;\n          const 
size_t scale_offset = logical_expert_id * config_.hidden_size;\n\n          // down part\n          down_bb_[expert_idx]->from_mat((uint8_t*)config_.down_proj + weight_offset,\n                                         (float*)config_.down_scale + scale_offset, ith, nth);\n        },\n        nullptr);\n  }\n};\n\n/**\n * @brief TP_MOE specialization for FP8 Per-Channel MoE\n */\ntemplate <typename K>\nclass TP_MOE<AMX_FP8_PERCHANNEL_MOE_TP<K>> : public TP_MOE<AMX_MOE_BASE<K, AMX_FP8_PERCHANNEL_MOE_TP<K>>> {\n public:\n  using Base = TP_MOE<AMX_MOE_BASE<K, AMX_FP8_PERCHANNEL_MOE_TP<K>>>;\n  using Base::Base;\n\n  /**\n   * @brief Write weights and scales to GPU buffer for a single expert\n   *\n   * This method coordinates all CPU TP parts to write their portions\n   * of weights and scales to the GPU buffers.\n   *\n   * @param gpu_tp_count Number of GPU TP parts\n   * @param expert_id Expert index to write\n   * @param w13_weight_ptrs Pointers to gate+up weight buffers (one per GPU TP)\n   * @param w13_scale_ptrs Pointers to gate+up scale buffers (one per GPU TP)\n   * @param w2_weight_ptrs Pointers to down weight buffers (one per GPU TP)\n   * @param w2_scale_ptrs Pointers to down scale buffers (one per GPU TP)\n   */\n  void write_weight_scale_to_buffer(int gpu_tp_count, int expert_id, const std::vector<uintptr_t>& w13_weight_ptrs,\n                                    const std::vector<uintptr_t>& w13_scale_ptrs,\n                                    const std::vector<uintptr_t>& w2_weight_ptrs,\n                                    const std::vector<uintptr_t>& w2_scale_ptrs) {\n    if (this->weights_loaded == false) {\n      throw std::runtime_error(\"Not Loaded\");\n    }\n    if (this->tps.empty()) {\n      throw std::runtime_error(\"No TP parts initialized\");\n    }\n    if ((int)w13_weight_ptrs.size() != gpu_tp_count || (int)w13_scale_ptrs.size() != gpu_tp_count ||\n        (int)w2_weight_ptrs.size() != gpu_tp_count || (int)w2_scale_ptrs.size() != 
gpu_tp_count) {\n      throw std::runtime_error(\"Pointer arrays size must match gpu_tp_count\");\n    }\n\n    this->config.pool->dispense_backend()->do_numa_job([&, this](int i) {\n      this->tps[i]->write_weights_to_buffer(gpu_tp_count, this->tp_count, expert_id, this->config, w13_weight_ptrs,\n                                            w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs);\n    });\n  }\n\n  void load_weights() override {\n    auto& config = this->config;\n    auto& tps = this->tps;\n    auto& tp_count = this->tp_count;\n    auto pool = config.pool;\n    const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;\n\n    if (!config.quant_config.per_channel) {\n      throw std::runtime_error(\"FP8 Per-Channel MoE requires per_channel=true\");\n    }\n\n    if (config.gate_projs.empty() && config.gate_proj == nullptr) {\n      throw std::runtime_error(\"no weight source\");\n    }\n    const bool use_per_expert_ptrs = !config.gate_projs.empty();\n\n    const size_t full_weight_elems = (size_t)config.intermediate_size * config.hidden_size;\n    // Per-channel: scale count = output dimension\n    const size_t gate_up_scale_elems = (size_t)config.intermediate_size;\n    const size_t down_scale_elems = (size_t)config.hidden_size;\n\n    pool->dispense_backend()->do_numa_job([&, this](int i) {\n      auto& tpc = tps[i]->config_;\n      const size_t tp_weight_elems = (size_t)tpc.intermediate_size * tpc.hidden_size;\n      // Per-channel scales for TP part\n      const size_t tp_gate_up_scale_elems = (size_t)tpc.intermediate_size;\n      const size_t tp_down_scale_elems = (size_t)tpc.hidden_size;\n\n      tpc.gate_proj = new uint8_t[tpc.expert_num * tp_weight_elems];\n      tpc.up_proj = new uint8_t[tpc.expert_num * tp_weight_elems];\n      tpc.down_proj = new uint8_t[tpc.expert_num * tp_weight_elems];\n\n      tpc.gate_scale = new float[tpc.expert_num * tp_gate_up_scale_elems];\n      tpc.up_scale = new float[tpc.expert_num * 
tp_gate_up_scale_elems];\n      tpc.down_scale = new float[tpc.expert_num * tp_down_scale_elems];\n\n      const size_t tp_idx = (size_t)i;\n      // gate/up: split by N (intermediate_size)\n      const size_t gate_up_weight_src_offset = i * tp_weight_elems;\n      const size_t gate_up_scale_src_offset = i * tp_gate_up_scale_elems;\n\n      // down: split by K (intermediate_size)\n      const size_t down_weight_src_col_offset = i * (size_t)tpc.intermediate_size;\n\n      pool->get_subpool(i)->do_work_stealing_job(\n          tpc.expert_num, nullptr,\n          // tpc is captured by the & default; naming it again as &tpc is ill-formed C++\n          [&](int expert_id_) {\n            const size_t expert_id = expert_map(physical_to_logical_map, expert_id_);\n\n            uint8_t* gate_dst = (uint8_t*)tpc.gate_proj + expert_id * tp_weight_elems;\n            uint8_t* up_dst = (uint8_t*)tpc.up_proj + expert_id * tp_weight_elems;\n            uint8_t* down_dst = (uint8_t*)tpc.down_proj + expert_id * tp_weight_elems;\n\n            float* gate_scale_dst = (float*)tpc.gate_scale + expert_id * tp_gate_up_scale_elems;\n            float* up_scale_dst = (float*)tpc.up_scale + expert_id * tp_gate_up_scale_elems;\n            float* down_scale_dst = (float*)tpc.down_scale + expert_id * tp_down_scale_elems;\n\n            const uint8_t* gate_src;\n            const uint8_t* up_src;\n            const uint8_t* down_src;\n            const float* gate_scale_src;\n            const float* up_scale_src;\n            const float* down_scale_src;\n\n            if (use_per_expert_ptrs) {\n              gate_src = (const uint8_t*)config.gate_projs[0][expert_id] + gate_up_weight_src_offset;\n              up_src = (const uint8_t*)config.up_projs[0][expert_id] + gate_up_weight_src_offset;\n              down_src = (const uint8_t*)config.down_projs[0][expert_id];\n\n              gate_scale_src = (const float*)config.gate_scales[0][expert_id] + gate_up_scale_src_offset;\n              up_scale_src = (const float*)config.up_scales[0][expert_id] + 
gate_up_scale_src_offset;\n              down_scale_src = (const float*)config.down_scales[0][expert_id];\n            } else {\n              gate_src = (const uint8_t*)config.gate_proj + expert_id * full_weight_elems + gate_up_weight_src_offset;\n              up_src = (const uint8_t*)config.up_proj + expert_id * full_weight_elems + gate_up_weight_src_offset;\n              down_src = (const uint8_t*)config.down_proj + expert_id * full_weight_elems;\n\n              gate_scale_src =\n                  (const float*)config.gate_scale + expert_id * gate_up_scale_elems + gate_up_scale_src_offset;\n              up_scale_src = (const float*)config.up_scale + expert_id * gate_up_scale_elems + gate_up_scale_src_offset;\n              down_scale_src = (const float*)config.down_scale + expert_id * down_scale_elems;\n            }\n\n            // Copy gate/up weights and scales (N dimension split)\n            std::memcpy(gate_dst, gate_src, tp_weight_elems);\n            std::memcpy(up_dst, up_src, tp_weight_elems);\n            std::memcpy(gate_scale_dst, gate_scale_src, sizeof(float) * tp_gate_up_scale_elems);\n            std::memcpy(up_scale_dst, up_scale_src, sizeof(float) * tp_gate_up_scale_elems);\n\n            // Copy down weights (K dimension split) - row by row\n            for (int row = 0; row < config.hidden_size; row++) {\n              const size_t src_row_offset = (size_t)row * (size_t)config.intermediate_size + down_weight_src_col_offset;\n              const size_t dst_row_offset = (size_t)row * (size_t)tpc.intermediate_size;\n              std::memcpy(down_dst + dst_row_offset, down_src + src_row_offset, (size_t)tpc.intermediate_size);\n            }\n\n            // Copy down scales (N dimension = hidden_size, full copy for each TP)\n            std::memcpy(down_scale_dst, down_scale_src, sizeof(float) * tp_down_scale_elems);\n          },\n          nullptr);\n    });\n\n    DO_TPS_LOAD_WEIGHTS(pool);\n\n    
pool->dispense_backend()->do_numa_job([&, this](int i) {\n      auto& tpc = tps[i]->config_;\n      delete[] (uint8_t*)tpc.gate_proj;\n      delete[] (uint8_t*)tpc.up_proj;\n      delete[] (uint8_t*)tpc.down_proj;\n      delete[] (float*)tpc.gate_scale;\n      delete[] (float*)tpc.up_scale;\n      delete[] (float*)tpc.down_scale;\n    });\n\n    this->weights_loaded = true;\n  }\n};\n\n#endif  // CPUINFER_OPERATOR_AMX_FP8_PERCHANNEL_MOE_H\n"
  },
  {
    "path": "kt-kernel/operators/amx/k2-moe.hpp",
    "content": "/**\n * @Description  : K2 AMX MoE operator for Kimi-K2 native inference\n * @Author       : oql, Codex and Claude\n * @Date         : 2025-12-09\n * @Version      : 1.0.0\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n *\n * This file implements K2 Int4 MoE using CRTP pattern, inheriting from moe_base.hpp.\n * K2 weights are stored with group-wise scales (KGroup Int4).\n **/\n#ifndef CPUINFER_OPERATOR_AMX_K2_MOE_H\n#define CPUINFER_OPERATOR_AMX_K2_MOE_H\n\n// #define LOAD_TIME_PROFILE\n\n#include \"moe_base.hpp\"\n\n/**\n * @brief K2 Int4 MoE operator using CRTP pattern\n * @tparam T Kernel type, defaults to amx::GemmKernel224Int4SmallKGroup\n *\n * This class provides K2-specific GEMM implementations:\n * - do_gate_up_gemm: Int4 weight with KGroup scale + AMX GEMM\n * - do_down_gemm: Same Int4 KGroup GEMM\n * - load_weights: Load Int4 weights with group-wise scales\n */\ntemplate <class T = amx::GemmKernel224Int4SmallKGroup>\nclass AMX_K2_MOE_TP : public AMX_MOE_BASE<T, AMX_K2_MOE_TP<T>> {\n  using Base = AMX_MOE_BASE<T, AMX_K2_MOE_TP<T>>;\n  using Base::config_;\n  using Base::down_ba_;\n  using Base::down_bb_;\n  using Base::down_bc_;\n  using Base::gate_bb_;\n  using Base::gate_bc_;\n  using Base::gate_up_ba_;\n  using Base::m_local_num_;\n  using Base::tp_part_idx;\n  using Base::up_bb_;\n  using Base::up_bc_;\n\n public:\n  using typename Base::input_t;\n  using typename Base::output_t;\n\n  AMX_K2_MOE_TP() = default;\n\n  AMX_K2_MOE_TP(GeneralMOEConfig config, int tp_part_idx_ = 0) : Base(config, tp_part_idx_) {}\n\n  void derived_init() {\n    auto& quant_config = config_.quant_config;\n    if (quant_config.group_size == 0 || quant_config.zero_point) {\n      throw std::runtime_error(\"Kimi-K2 MoE only support KGroup Int4\");\n    }\n    printf(\"Creating AMX_K2_MOE_TP %d at numa %d\\n\", tp_part_idx, numa_node_of_cpu(sched_getcpu()));\n  }\n\n  ~AMX_K2_MOE_TP() = default;\n\n  // 
============================================================================\n  // CRTP buffer creation - with group_size\n  // ============================================================================\n\n  size_t buffer_a_required_size_impl(size_t m, size_t k) const {\n    return T::BufferA::required_size(m, k, config_.quant_config.group_size);\n  }\n  size_t buffer_b_required_size_impl(size_t n, size_t k) const {\n    return T::BufferB::required_size(n, k, config_.quant_config.group_size);\n  }\n  size_t buffer_c_required_size_impl(size_t m, size_t n) const { return T::BufferC::required_size(m, n); }\n\n  std::shared_ptr<typename T::BufferA> make_buffer_a_impl(size_t m, size_t k, void* data) const {\n    return std::make_shared<typename T::BufferA>(m, k, config_.quant_config.group_size, data);\n  }\n  std::shared_ptr<typename T::BufferB> make_buffer_b_impl(size_t n, size_t k, void* data) const {\n    return std::make_shared<typename T::BufferB>(n, k, config_.quant_config.group_size, data);\n  }\n  std::shared_ptr<typename T::BufferC> make_buffer_c_impl(size_t m, size_t n, void* data) const {\n    return std::make_shared<typename T::BufferC>(m, n, data);\n  }\n\n  // ============================================================================\n  // CRTP virtual points - GEMM dispatch\n  // ============================================================================\n\n  void do_gate_up_gemm(bool do_up, int expert_idx, int ith, int nth, int qlen) {\n    auto& group_size = config_.quant_config.group_size;\n    int m = m_local_num_[expert_idx];\n    auto& ba = gate_up_ba_[expert_idx];\n    auto& bb = do_up ? up_bb_[expert_idx] : gate_bb_[expert_idx];\n    auto& bc = do_up ? 
up_bc_[expert_idx] : gate_bc_[expert_idx];\n\n    // Dispatch based on qlen threshold\n    if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) {\n      amx::mat_mul_kgroup(m, config_.intermediate_size, config_.hidden_size, group_size, ba, bb, bc, ith, nth);\n    } else {\n      amx::vec_mul_kgroup(m, config_.intermediate_size, config_.hidden_size, group_size, ba, bb, bc, ith, nth);\n    }\n  }\n\n  void do_down_gemm(int expert_idx, int ith, int nth, int qlen) {\n    auto& group_size = config_.quant_config.group_size;\n    int m = m_local_num_[expert_idx];\n\n    if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) {\n      amx::mat_mul_kgroup(m, config_.hidden_size, config_.intermediate_size, group_size, down_ba_[expert_idx],\n                          down_bb_[expert_idx], down_bc_[expert_idx], ith, nth);\n    } else {\n      amx::vec_mul_kgroup(m, config_.hidden_size, config_.intermediate_size, group_size, down_ba_[expert_idx],\n                          down_bb_[expert_idx], down_bc_[expert_idx], ith, nth);\n    }\n  }\n\n  /**\n   * @brief Load Int4 weights from contiguous memory layout\n   *\n   * Loads weights from config_.gate_proj, up_proj, down_proj with scales\n   * from config_.gate_scale, up_scale, down_scale.\n   */\n  void load_weights() {\n    auto& quant_config = config_.quant_config;\n    int& group_size = quant_config.group_size;\n    const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map;\n    auto pool = config_.pool->get_subpool(tp_part_idx);\n\n    if (quant_config.group_size == 0 || quant_config.zero_point) {\n      throw std::runtime_error(\"Kimi AVX MOE only supports KGroup Int4.\");\n    }\n    if (config_.gate_scale == nullptr) {\n      throw std::runtime_error(\"Kimi AVX MOE only supports loading native weights.\");\n    }\n\n    // load gate and up weights\n    int nth = T::recommended_nth(config_.intermediate_size);\n    pool->do_work_stealing_job(\n        nth * config_.expert_num, nullptr,\n  
      [this, nth, physical_to_logical_map](int task_id) {\n          uint64_t expert_idx = task_id / nth;\n          uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);\n          int ith = task_id % nth;\n          // gate part\n          gate_bb_[expert_idx]->from_raw_mat(\n              (uint8_t*)config_.gate_proj +\n                  ((logical_expert_id * config_.intermediate_size * config_.hidden_size) >> 1),\n              ith, nth);\n          // up part\n          up_bb_[expert_idx]->from_raw_mat(\n              (uint8_t*)config_.up_proj + ((logical_expert_id * config_.intermediate_size * config_.hidden_size) >> 1),\n              ith, nth);\n        },\n        nullptr);\n\n    nth = T::recommended_nth(config_.hidden_size);\n    pool->do_work_stealing_job(\n        nth * config_.expert_num, nullptr,\n        [this, nth, physical_to_logical_map](int task_id) {\n          uint64_t expert_idx = task_id / nth;\n          uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);\n          int ith = task_id % nth;\n          // down part\n          down_bb_[expert_idx]->from_raw_mat(\n              (uint8_t*)config_.down_proj +\n                  ((logical_expert_id * config_.hidden_size * config_.intermediate_size) >> 1),\n              ith, nth);\n        },\n        nullptr);\n\n    pool->do_work_stealing_job(\n        config_.expert_num, nullptr,\n        [this, physical_to_logical_map](int task_id) {\n          uint64_t expert_idx = task_id;\n          uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);\n          size_t scale_elem_count = (config_.hidden_size * config_.intermediate_size) / config_.quant_config.group_size;\n\n          // convert scales from BF16 to FP32\n          convert_or_copy(gate_bb_[expert_idx]->d,\n                          (ggml_bf16_t*)config_.gate_scale + (logical_expert_id * scale_elem_count), scale_elem_count);\n          
convert_or_copy(up_bb_[expert_idx]->d,\n                          (ggml_bf16_t*)config_.up_scale + (logical_expert_id * scale_elem_count), scale_elem_count);\n          convert_or_copy(down_bb_[expert_idx]->d,\n                          (ggml_bf16_t*)config_.down_scale + (logical_expert_id * scale_elem_count), scale_elem_count);\n        },\n        nullptr);\n#ifdef DEBUG_K2_MOE\n    dump_buffer_b(\"native\", 0, \"down\", down_bb_[0].get());\n#endif\n  }\n\n  static inline void fast_memcpy(void* __restrict dst, const void* __restrict src, size_t bytes) {\n    uint8_t* d = (uint8_t*)dst;\n    const uint8_t* s = (const uint8_t*)src;\n\n    // Main loop: 512-bit (64-byte) SIMD copies\n    size_t chunks = bytes / 64;\n    for (size_t i = 0; i < chunks; i++) {\n      __m512i data = _mm512_loadu_si512((__m512i*)s);\n      _mm512_storeu_si512((__m512i*)d, data);\n      d += 64;\n      s += 64;\n    }\n    bytes -= chunks * 64;\n\n    // Handle remaining bytes\n    if (bytes > 0) {\n      std::memcpy(d, s, bytes);\n    }\n  }\n\n  // Optimized SIMD float32 to bf16 conversion\n  static inline void fast_fp32_to_bf16(ggml_bf16_t* __restrict dst, const float* __restrict src, size_t count) {\n    size_t i = 0;\n\n    // Process 32 elements at a time (2x __m512, output 1x __m512i = 32 bf16)\n    for (; i + 32 <= count; i += 32) {\n      __m512 v0 = _mm512_loadu_ps(src + i);\n      __m512 v1 = _mm512_loadu_ps(src + i + 16);\n\n      // Convert to bf16 using truncation (shift right 16 bits)\n      __m512i i0 = _mm512_srli_epi32(_mm512_castps_si512(v0), 16);\n      __m512i i1 = _mm512_srli_epi32(_mm512_castps_si512(v1), 16);\n\n      // Pack 32-bit values to 16-bit\n      __m512i packed = _mm512_packus_epi32(i0, i1);\n\n      // Reorder due to packus lane behavior:\n      // packus outputs interleaved: [i0[0-3], i1[0-3], i0[4-7], i1[4-7], i0[8-11], i1[8-11], i0[12-15], i1[12-15]]\n      // We need sequential: [i0[0-15], i1[0-15]] = [i0[0-3], i0[4-7], i0[8-11], i0[12-15], i1[0-3], 
i1[4-7], i1[8-11],\n      // i1[12-15]] Permutation: [0, 2, 4, 6, 1, 3, 5, 7] (qword indices)\n      __m512i permuted = _mm512_permutexvar_epi64(_mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0), packed);\n\n      _mm512_storeu_si512((__m512i*)(dst + i), permuted);\n    }\n\n    // Handle remaining elements with scalar conversion\n    for (; i < count; i++) {\n      dst[i] = ggml_fp32_to_bf16(src[i]);\n    }\n  }\n\n  // Write a single expert's weights to the output buffers\n  // The caller provides pointers that already point to the target expert's location (no offset needed)\n  // expert_id: the index of the expert to write\n  // Optimized for maximum memory bandwidth using streaming stores\n  void write_weights_to_buffer(int gpu_tp_count, int cpu_tp_count, int expert_id, const GeneralMOEConfig& full_config,\n                               const std::vector<uintptr_t>& w13_weight_ptrs,\n                               const std::vector<uintptr_t>& w13_scale_ptrs,\n                               const std::vector<uintptr_t>& w2_weight_ptrs,\n                               const std::vector<uintptr_t>& w2_scale_ptrs) const {\n    const int group_size = config_.quant_config.group_size;\n    auto pool = config_.pool->get_subpool(tp_part_idx);\n\n    // Calculate sizes for CPU TP part (this instance)\n    size_t cpu_tp_weight_elem_count = (size_t)config_.intermediate_size * config_.hidden_size;\n    size_t cpu_tp_weight_bytes = cpu_tp_weight_elem_count / 2;  // int4 packing\n    size_t cpu_tp_scale_elem_count = cpu_tp_weight_elem_count / group_size;\n\n    // Calculate sizes for GPU TP part\n    size_t gpu_tp_weight_elem_count = (size_t)full_config.intermediate_size * full_config.hidden_size / gpu_tp_count;\n    size_t gpu_tp_weight_bytes = gpu_tp_weight_elem_count / 2;  // int4 packing\n    size_t gpu_tp_scale_elem_count = gpu_tp_weight_elem_count / group_size;\n\n    // Determine mapping: which GPU TP parts should this CPU TP part write to?\n    // Since weights are col-major 
and we slice directly by memory order:\n    // - If cpu_tp_count >= gpu_tp_count: multiple(or one) CPU TPs write to one GPU TP\n    // - If cpu_tp_count < gpu_tp_count: one CPU TP writes to multiple GPU TPs\n    if (cpu_tp_count >= gpu_tp_count) {\n      // Multiple CPU TPs map to one GPU TP\n      int target_gpu_tp = tp_part_idx / (cpu_tp_count / gpu_tp_count);\n      int local_idx = tp_part_idx % (cpu_tp_count / gpu_tp_count);\n\n      // Get pointers for this GPU TP part (already pointing to target expert's location)\n      uint8_t* w13_weight_dst = (uint8_t*)w13_weight_ptrs[target_gpu_tp];\n      ggml_bf16_t* w13_scale_dst = (ggml_bf16_t*)w13_scale_ptrs[target_gpu_tp];\n      uint8_t* w2_weight_dst = (uint8_t*)w2_weight_ptrs[target_gpu_tp];\n      ggml_bf16_t* w2_scale_dst = (ggml_bf16_t*)w2_scale_ptrs[target_gpu_tp];\n\n      // Calculate offset within the GPU TP buffer (for CPU TP slice within GPU TP)\n      size_t offset_in_gpu_weight = local_idx * cpu_tp_weight_bytes;\n      size_t offset_in_gpu_scale = local_idx * cpu_tp_scale_elem_count;\n\n      // Optimized task layout for maximum bandwidth:\n      // - Larger chunks to reduce task overhead\n      // - Separate large contiguous copies (gate_w, up_w) from strided copies (down)\n      // - Scale conversions are relatively small, merge with weight tasks\n\n      // Use fewer, larger tasks for better efficiency\n      constexpr int NUM_WEIGHT_TASKS = 8;  // Fewer tasks, larger chunks\n      constexpr int MIN_COLS_PER_TASK = 128;\n      int num_down_tasks = std::max(1, (int)config_.hidden_size / MIN_COLS_PER_TASK);\n      num_down_tasks = std::min(num_down_tasks, 32);\n\n      // Total tasks: gate_weight + up_weight + down_weight_scale + gate_scale + up_scale\n      int total_tasks = NUM_WEIGHT_TASKS * 2 + num_down_tasks + 2;\n\n      size_t weight_chunk_size = (cpu_tp_weight_bytes + NUM_WEIGHT_TASKS - 1) / NUM_WEIGHT_TASKS;\n      // Align chunk size to 64 bytes for optimal streaming stores\n      
weight_chunk_size = (weight_chunk_size + 63) & ~63ULL;\n\n      pool->do_work_stealing_job(\n          total_tasks, nullptr,\n          [&, this, num_down_tasks, expert_id, weight_chunk_size](int task_id) {\n            if (task_id < NUM_WEIGHT_TASKS) {\n              // Gate weight copy - chunked\n              int chunk_idx = task_id;\n              size_t start = chunk_idx * weight_chunk_size;\n              size_t end = std::min(start + weight_chunk_size, cpu_tp_weight_bytes);\n              if (start < end) {\n                uint8_t* gate_weight_src = (uint8_t*)gate_bb_[expert_id]->b;\n                fast_memcpy(w13_weight_dst + offset_in_gpu_weight + start, gate_weight_src + start, end - start);\n              }\n            } else if (task_id < NUM_WEIGHT_TASKS * 2) {\n              // Up weight copy - chunked\n              int chunk_idx = task_id - NUM_WEIGHT_TASKS;\n              size_t start = chunk_idx * weight_chunk_size;\n              size_t end = std::min(start + weight_chunk_size, cpu_tp_weight_bytes);\n              if (start < end) {\n                uint8_t* up_weight_src = (uint8_t*)up_bb_[expert_id]->b;\n                fast_memcpy(w13_weight_dst + offset_in_gpu_weight + gpu_tp_weight_bytes + start, up_weight_src + start,\n                            end - start);\n              }\n            } else if (task_id < NUM_WEIGHT_TASKS * 2 + num_down_tasks) {\n              // Down columns - split by column chunks\n              // Each task handles multiple consecutive columns for better cache locality\n              int chunk_idx = task_id - NUM_WEIGHT_TASKS * 2;\n              size_t cols_per_chunk = (config_.hidden_size + num_down_tasks - 1) / num_down_tasks;\n              size_t col_start = chunk_idx * cols_per_chunk;\n              size_t col_end = std::min(col_start + cols_per_chunk, (size_t)config_.hidden_size);\n\n              size_t weight_per_col = config_.intermediate_size >> 1;\n              size_t scale_per_col = 
config_.intermediate_size / group_size;\n              size_t gpu_weight_stride = (full_config.intermediate_size / gpu_tp_count) >> 1;\n              size_t gpu_scale_stride = (full_config.intermediate_size / gpu_tp_count) / group_size;\n              size_t gpu_weight_slice_offset = local_idx * weight_per_col;\n              size_t gpu_scale_slice_offset = local_idx * scale_per_col;\n\n              for (size_t col = col_start; col < col_end; col++) {\n                fast_memcpy(w2_weight_dst + col * gpu_weight_stride + gpu_weight_slice_offset,\n                            (uint8_t*)down_bb_[expert_id]->b + col * weight_per_col, weight_per_col);\n\n                fast_fp32_to_bf16(w2_scale_dst + col * gpu_scale_stride + gpu_scale_slice_offset,\n                                  down_bb_[expert_id]->d + col * scale_per_col, scale_per_col);\n              }\n            } else if (task_id == NUM_WEIGHT_TASKS * 2 + num_down_tasks) {\n              // Gate scale convert\n              float* gate_scale_src = gate_bb_[expert_id]->d;\n              fast_fp32_to_bf16(w13_scale_dst + offset_in_gpu_scale, gate_scale_src, cpu_tp_scale_elem_count);\n            } else {\n              // Up scale convert\n              float* up_scale_src = up_bb_[expert_id]->d;\n              fast_fp32_to_bf16(w13_scale_dst + offset_in_gpu_scale + gpu_tp_scale_elem_count, up_scale_src,\n                                cpu_tp_scale_elem_count);\n            }\n          },\n          nullptr);\n    } else {\n      // cpu_tp_count < gpu_tp_count: one CPU TP writes to multiple GPU TPs\n      int gpu_tps_per_cpu_tp = gpu_tp_count / cpu_tp_count;\n      int start_gpu_tp = tp_part_idx * gpu_tps_per_cpu_tp;\n\n      // Size of data per GPU TP within this CPU TP\n      size_t data_per_gpu_tp_weight = cpu_tp_weight_bytes / gpu_tps_per_cpu_tp;\n      size_t data_per_gpu_tp_scale = cpu_tp_scale_elem_count / gpu_tps_per_cpu_tp;\n\n      // Optimized task layout\n      constexpr int NUM_WEIGHT_TASKS = 
8;\n      constexpr int MIN_COLS_PER_TASK = 128;\n      int num_down_tasks = std::max(1, (int)config_.hidden_size / MIN_COLS_PER_TASK);\n      num_down_tasks = std::min(num_down_tasks, 32);\n\n      int tasks_per_gpu_tp = NUM_WEIGHT_TASKS * 2 + num_down_tasks + 2;\n      int total_tasks = tasks_per_gpu_tp * gpu_tps_per_cpu_tp;\n\n      size_t weight_chunk_size = (data_per_gpu_tp_weight + NUM_WEIGHT_TASKS - 1) / NUM_WEIGHT_TASKS;\n      weight_chunk_size = (weight_chunk_size + 63) & ~63ULL;\n\n      pool->do_work_stealing_job(\n          total_tasks, nullptr,\n          [&, this, gpu_tps_per_cpu_tp, start_gpu_tp, data_per_gpu_tp_weight, data_per_gpu_tp_scale, num_down_tasks,\n           tasks_per_gpu_tp, expert_id, weight_chunk_size](int task_id) {\n            int local_gpu_idx = task_id / tasks_per_gpu_tp;\n            int task_type = task_id % tasks_per_gpu_tp;\n            int gpu_tp_idx = start_gpu_tp + local_gpu_idx;\n\n            // Get pointers for this GPU TP part\n            uint8_t* w13_weight_dst = (uint8_t*)w13_weight_ptrs[gpu_tp_idx];\n            ggml_bf16_t* w13_scale_dst = (ggml_bf16_t*)w13_scale_ptrs[gpu_tp_idx];\n            uint8_t* w2_weight_dst = (uint8_t*)w2_weight_ptrs[gpu_tp_idx];\n            ggml_bf16_t* w2_scale_dst = (ggml_bf16_t*)w2_scale_ptrs[gpu_tp_idx];\n\n            // Calculate offsets within CPU TP buffers\n            size_t cpu_offset_weight = local_gpu_idx * data_per_gpu_tp_weight;\n            size_t cpu_offset_scale = local_gpu_idx * data_per_gpu_tp_scale;\n\n            if (task_type < NUM_WEIGHT_TASKS) {\n              // Gate weight copy - chunked\n              int chunk_idx = task_type;\n              size_t start = chunk_idx * weight_chunk_size;\n              size_t end = std::min(start + weight_chunk_size, data_per_gpu_tp_weight);\n              if (start < end) {\n                uint8_t* gate_weight_src = (uint8_t*)gate_bb_[expert_id]->b + cpu_offset_weight;\n                fast_memcpy(w13_weight_dst + start, 
gate_weight_src + start, end - start);\n              }\n            } else if (task_type < NUM_WEIGHT_TASKS * 2) {\n              // Up weight copy - chunked\n              int chunk_idx = task_type - NUM_WEIGHT_TASKS;\n              size_t start = chunk_idx * weight_chunk_size;\n              size_t end = std::min(start + weight_chunk_size, data_per_gpu_tp_weight);\n              if (start < end) {\n                uint8_t* up_weight_src = (uint8_t*)up_bb_[expert_id]->b + cpu_offset_weight;\n                fast_memcpy(w13_weight_dst + gpu_tp_weight_bytes + start, up_weight_src + start, end - start);\n              }\n            } else if (task_type < NUM_WEIGHT_TASKS * 2 + num_down_tasks) {\n              // Down columns - split by column chunks\n              int chunk_idx = task_type - NUM_WEIGHT_TASKS * 2;\n              size_t cols_per_chunk = (config_.hidden_size + num_down_tasks - 1) / num_down_tasks;\n              size_t col_start = chunk_idx * cols_per_chunk;\n              size_t col_end = std::min(col_start + cols_per_chunk, (size_t)config_.hidden_size);\n\n              size_t weight_per_gpu_col = (config_.intermediate_size / gpu_tps_per_cpu_tp) >> 1;\n              size_t scale_per_gpu_col = (config_.intermediate_size / gpu_tps_per_cpu_tp) / group_size;\n\n              for (size_t col = col_start; col < col_end; col++) {\n                size_t col_offset_weight = (col * config_.intermediate_size / 2) +\n                                           (local_gpu_idx * data_per_gpu_tp_weight / config_.hidden_size);\n                size_t col_offset_scale = (col * (config_.intermediate_size / group_size)) +\n                                          (local_gpu_idx * data_per_gpu_tp_scale / config_.hidden_size);\n\n                fast_memcpy(w2_weight_dst + col * weight_per_gpu_col,\n                            (uint8_t*)down_bb_[expert_id]->b + col_offset_weight, weight_per_gpu_col);\n\n                fast_fp32_to_bf16(w2_scale_dst + col * 
scale_per_gpu_col, down_bb_[expert_id]->d + col_offset_scale,\n                                  scale_per_gpu_col);\n              }\n            } else if (task_type == NUM_WEIGHT_TASKS * 2 + num_down_tasks) {\n              // Gate scale convert\n              float* gate_scale_src = gate_bb_[expert_id]->d + cpu_offset_scale;\n              fast_fp32_to_bf16(w13_scale_dst, gate_scale_src, data_per_gpu_tp_scale);\n            } else {\n              // Up scale convert\n              float* up_scale_src = up_bb_[expert_id]->d + cpu_offset_scale;\n              fast_fp32_to_bf16(w13_scale_dst + gpu_tp_scale_elem_count, up_scale_src, data_per_gpu_tp_scale);\n            }\n          },\n          nullptr);\n    }\n  }\n};\n\n// ============================================================================\n// TP_MOE specialization for AMX_K2_MOE_TP\n// Inherits from TP_MOE<AMX_MOE_BASE<...>> to reuse merge_results implementation\n// ============================================================================\n\ntemplate <typename K>\nclass TP_MOE<AMX_K2_MOE_TP<K>> : public TP_MOE<AMX_MOE_BASE<K, AMX_K2_MOE_TP<K>>> {\n public:\n  using Base = TP_MOE<AMX_MOE_BASE<K, AMX_K2_MOE_TP<K>>>;\n  using Base::Base;\n\n  void load_weights() override {\n    auto& config = this->config;\n    auto& tps = this->tps;\n    auto& tp_count = this->tp_count;\n    auto pool = config.pool;\n    const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;\n\n#ifdef LOAD_TIME_PROFILE\n    auto load_start_time = std::chrono::high_resolution_clock::now();\n    auto load_last = load_start_time;\n    long alloc_and_tp_slice_time = 0, tps_load_time = 0, cleanup_time = 0;\n#endif\n\n    bool use_per_expert_ptrs = !config.gate_projs.empty();\n\n    if (config.gate_projs.empty() && config.gate_scale == nullptr) {\n      throw std::runtime_error(\"K2 MoE only supports Packed Int4 with KGroup Scale\");\n    }\n\n    if (use_per_expert_ptrs) {\n      printf(\"From 
per-expert pointers (gate_projs)\\n\");\n    } else {\n      printf(\"From Packed Int4 with KGroup Scale\\n\");\n    }\n\n    int& group_size = config.quant_config.group_size;\n\n    pool->dispense_backend()->do_numa_job([&, this](int i) {\n      auto& tpc = tps[i]->config_;\n      size_t weight_elem_count = tpc.intermediate_size * tpc.hidden_size;\n      size_t scales_elem_count = (tpc.hidden_size / group_size) * tpc.intermediate_size;\n\n      tpc.gate_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];\n      tpc.up_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];\n      tpc.down_proj = new uint8_t[(tpc.expert_num * weight_elem_count) / 2];\n      tpc.gate_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];\n      tpc.up_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];\n      tpc.down_scale = new ggml_bf16_t[(tpc.expert_num * scales_elem_count)];\n\n      if (use_per_expert_ptrs) {\n        pool->get_subpool(i)->do_work_stealing_job(\n            tpc.expert_num, nullptr,\n            [&, i](int expert_id_) {\n              size_t expert_id = expert_map(physical_to_logical_map, expert_id_);\n\n              uint8_t* src_gate = (uint8_t*)config.gate_projs[0][expert_id];\n              uint8_t* src_up = (uint8_t*)config.up_projs[0][expert_id];\n              uint8_t* src_down = (uint8_t*)config.down_projs[0][expert_id];\n              ggml_bf16_t* src_gate_scale = (ggml_bf16_t*)config.gate_scales[0][expert_id];\n              ggml_bf16_t* src_up_scale = (ggml_bf16_t*)config.up_scales[0][expert_id];\n              ggml_bf16_t* src_down_scale = (ggml_bf16_t*)config.down_scales[0][expert_id];\n\n              memcpy((uint8_t*)tpc.gate_proj + ((expert_id * weight_elem_count) >> 1),\n                     src_gate + ((i * weight_elem_count) >> 1), (weight_elem_count >> 1));\n\n              memcpy((uint8_t*)tpc.up_proj + ((expert_id * weight_elem_count) >> 1),\n                     src_up + ((i * weight_elem_count) >> 1), 
(weight_elem_count >> 1));\n\n              memcpy((ggml_bf16_t*)tpc.gate_scale + (expert_id * scales_elem_count),\n                     src_gate_scale + (i * scales_elem_count), sizeof(ggml_bf16_t) * scales_elem_count);\n\n              memcpy((ggml_bf16_t*)tpc.up_scale + (expert_id * scales_elem_count),\n                     src_up_scale + (i * scales_elem_count), sizeof(ggml_bf16_t) * scales_elem_count);\n\n              for (size_t col = 0; col < config.hidden_size; col++) {\n                memcpy((uint8_t*)tpc.down_proj + ((expert_id * weight_elem_count + col * tpc.intermediate_size) >> 1),\n                       src_down + ((col * config.intermediate_size + i * tpc.intermediate_size) >> 1),\n                       (tpc.intermediate_size >> 1));\n                memcpy((ggml_bf16_t*)tpc.down_scale +\n                           (expert_id * scales_elem_count + col * (tpc.intermediate_size / group_size)),\n                       src_down_scale +\n                           (col * (config.intermediate_size / group_size) + i * (tpc.intermediate_size / group_size)),\n                       sizeof(ggml_bf16_t) * (tpc.intermediate_size / group_size));\n              }\n            },\n            nullptr);\n      } else {\n        if (tpc.load == false) {\n          pool->get_subpool(i)->do_work_stealing_job(\n              tpc.expert_num, nullptr,\n              [&, i](int expert_id_) {\n                size_t expert_id = expert_map(physical_to_logical_map, expert_id_);\n\n                memcpy((uint8_t*)tpc.gate_proj + ((expert_id * weight_elem_count) >> 1),\n                       (uint8_t*)config.gate_proj +\n                           ((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1),\n                       ((sizeof(uint8_t) * weight_elem_count) >> 1));\n\n                memcpy((uint8_t*)tpc.up_proj + ((expert_id * weight_elem_count) >> 1),\n                       (uint8_t*)config.up_proj +\n                          
 ((expert_id * config.intermediate_size * config.hidden_size + i * weight_elem_count) >> 1),\n                       ((sizeof(uint8_t) * weight_elem_count) >> 1));\n\n                memcpy((ggml_bf16_t*)tpc.gate_scale + (expert_id * scales_elem_count),\n                       (ggml_bf16_t*)config.gate_scale +\n                           (expert_id * (config.hidden_size / group_size) * config.intermediate_size +\n                            i * scales_elem_count),\n                       sizeof(ggml_bf16_t) * scales_elem_count);\n\n                memcpy((ggml_bf16_t*)tpc.up_scale + (expert_id * scales_elem_count),\n                       (ggml_bf16_t*)config.up_scale +\n                           (expert_id * (config.hidden_size / group_size) * config.intermediate_size +\n                            i * scales_elem_count),\n                       sizeof(ggml_bf16_t) * scales_elem_count);\n\n                for (size_t col = 0; col < config.hidden_size; col++) {\n                  memcpy((uint8_t*)tpc.down_proj + ((expert_id * weight_elem_count + col * tpc.intermediate_size) >> 1),\n                         (uint8_t*)config.down_proj + ((expert_id * config.intermediate_size * config.hidden_size +\n                                                        col * config.intermediate_size + i * tpc.intermediate_size) >>\n                                                       1),\n                         (sizeof(uint8_t) * tpc.intermediate_size) >> 1);\n                  memcpy((ggml_bf16_t*)tpc.down_scale +\n                             (expert_id * scales_elem_count + col * (tpc.intermediate_size / group_size)),\n                         (ggml_bf16_t*)config.down_scale +\n                             ((expert_id * (config.intermediate_size / group_size) * config.hidden_size) +\n                              col * (config.intermediate_size / group_size) + i * (tpc.intermediate_size / group_size)),\n                         sizeof(ggml_bf16_t) * (tpc.intermediate_size / 
group_size));\n                }\n              },\n              nullptr);\n        }\n      }\n      printf(\"TP %d load weight done.\\n\", i);\n    });\n\n#ifdef LOAD_TIME_PROFILE\n    {\n      auto load_now_time = std::chrono::high_resolution_clock::now();\n      alloc_and_tp_slice_time =\n          std::chrono::duration_cast<std::chrono::microseconds>(load_now_time - load_last).count();\n      load_last = load_now_time;\n    }\n#endif\n\n    DO_TPS_LOAD_WEIGHTS(pool);\n\n#ifdef LOAD_TIME_PROFILE\n    {\n      auto load_now_time = std::chrono::high_resolution_clock::now();\n      tps_load_time = std::chrono::duration_cast<std::chrono::microseconds>(load_now_time - load_last).count();\n      load_last = load_now_time;\n    }\n#endif\n\n    pool->dispense_backend()->do_numa_job([&, this](int i) {\n      auto& tpc = tps[i]->config_;\n      delete[] (uint8_t*)(tpc.gate_proj);\n      delete[] (uint8_t*)(tpc.up_proj);\n      delete[] (uint8_t*)(tpc.down_proj);\n\n      delete[] (ggml_bf16_t*)(tpc.gate_scale);\n      delete[] (ggml_bf16_t*)(tpc.up_scale);\n      delete[] (ggml_bf16_t*)(tpc.down_scale);\n    });\n\n#ifdef LOAD_TIME_PROFILE\n    {\n      auto load_now_time = std::chrono::high_resolution_clock::now();\n      cleanup_time = std::chrono::duration_cast<std::chrono::microseconds>(load_now_time - load_last).count();\n    }\n    auto load_end_time = std::chrono::high_resolution_clock::now();\n    auto load_total_time =\n        std::chrono::duration_cast<std::chrono::microseconds>(load_end_time - load_start_time).count();\n    printf(\n        \"[K2 MoE Load Weights] tp_count: %d, alloc_and_tp_slice: %ld us, tps_load_weights: %ld us, cleanup: %ld us, \"\n        \"total: %ld us\\n\",\n        tp_count, alloc_and_tp_slice_time, tps_load_time, cleanup_time, load_total_time);\n#endif\n\n    this->weights_loaded = true;\n  }\n\n  void write_weight_scale_to_buffer(int gpu_tp_count, int expert_id, const std::vector<uintptr_t>& w13_weight_ptrs,\n                      
              const std::vector<uintptr_t>& w13_scale_ptrs,\n                                    const std::vector<uintptr_t>& w2_weight_ptrs,\n                                    const std::vector<uintptr_t>& w2_scale_ptrs) {\n    if (this->weights_loaded == false) {\n      throw std::runtime_error(\"Weights not loaded; call load_weights() first\");\n    }\n    if (this->tps.empty()) {\n      throw std::runtime_error(\"No TP parts initialized\");\n    }\n\n    if (w13_weight_ptrs.size() != gpu_tp_count || w13_scale_ptrs.size() != gpu_tp_count ||\n        w2_weight_ptrs.size() != gpu_tp_count || w2_scale_ptrs.size() != gpu_tp_count) {\n      throw std::runtime_error(\"Pointer array sizes must match gpu_tp_count\");\n    }\n\n    this->config.pool->dispense_backend()->do_numa_job([&, this](int i) {\n      this->tps[i]->write_weights_to_buffer(gpu_tp_count, this->tp_count, expert_id, this->config, w13_weight_ptrs,\n                                            w13_scale_ptrs, w2_weight_ptrs, w2_scale_ptrs);\n    });\n  }\n\n  // merge_results is inherited from TP_MOE<AMX_MOE_BASE<K, AMX_K2_MOE_TP<K>>>\n};\n\n#endif  // CPUINFER_OPERATOR_AMX_K2_MOE_H\n"
  },
  {
    "path": "kt-kernel/operators/amx/la/amx-example.cpp",
    "content": "#include <random>\n#include <stdexcept>\n\n#include \"amx.hpp\"\n#include \"llama.cpp/ggml-impl.h\"\n#include \"llama.cpp/ggml-quants.h\"\n\nint main() {\n  // init GGML\n  struct ggml_init_params params = {\n      0,\n      NULL,\n      true,\n  };\n\n  auto ctx_eval = ggml_init(params);\n\n  if (!ctx_eval) {\n    throw std::runtime_error(\"Failed to create ggml context\");\n  }\n\n  // Allocate Memory\n  int m = 1000, n = 8, k = 512;\n  float* a = new float[m * k];  // m x k, Row Major\n  float* b = new float[k * n];  // k x n, Column Major\n  size_t c_row_size = n * sizeof(float);\n  c_row_size = (c_row_size + 63) / 64 * 64;  // pad C row\n  float* c = new (std::align_val_t(64)) float[m * c_row_size];\n  memset(c, 0, m * c_row_size * sizeof(float));\n  size_t ldc = c_row_size * sizeof(float);\n\n  std::mt19937 gen(123);\n  std::uniform_real_distribution<float> dis(0, 16);\n  for (int i = 0; i < m * k; i++) {\n    a[i] = dis(gen);\n  }\n  for (int i = 0; i < k * n; i++) {\n    b[i] = dis(gen);\n  }\n\n  // Convert to BF16\n  // QA and QB must be aligned to 64 for BF16\n  // k is a multiple of 32, so no need for padding\n  ggml_bf16_t* qa = new (std::align_val_t(64)) ggml_bf16_t[m * k];\n  size_t lda = k * sizeof(ggml_bf16_t);\n  ggml_bf16_t* qb = new (std::align_val_t(64)) ggml_bf16_t[k * n];\n  size_t ldb = k * sizeof(ggml_bf16_t);\n  ggml_fp32_to_bf16_row(a, qa, m * k);\n  ggml_fp32_to_bf16_row(b, qb, k * n);\n\n  // AMX Computation\n  amx::init_tile(GGML_TYPE_BF16, GGML_TYPE_BF16, GGML_TYPE_F32);\n  int nth = amx::recommended_nth(m, n, k, GGML_TYPE_BF16, GGML_TYPE_BF16, GGML_TYPE_F32);\n\n#pragma omp parallel for\n  for (int ith = 0; ith < nth; ith++) {\n    amx::gemm(m, n, k, qa, lda, GGML_TYPE_BF16, qb, ldb, GGML_TYPE_BF16, c, ldc, GGML_TYPE_F32, ith, nth);\n  }\n\n  // Check\n  float* d = new float[m * n];\n  memset(d, 0, m * n * sizeof(float));\n  for (int i = 0; i < m; i++) {\n    for (int j = 0; j < n; j++) {\n      for (int kk = 0; kk < 
k; kk++) {\n        d[i * n + j] += a[i * k + kk] * b[j * k + kk];\n      }\n    }\n  }\n\n  float max_error = 0;\n  for (int i = 0; i < m; i++) {\n    for (int j = 0; j < n; j++) {\n      max_error = std::max(max_error, std::abs(d[i * n + j] - c[i * c_row_size + j]) / std::abs(d[i * n + j]));\n      // printf(\"%.2f \",c[i*c_row_size+j]);\n    }\n    // printf(\"\\n\");\n  }\n  printf(\"Max Error %f%%\\n\", max_error * 100);\n\n  return 0;\n}\n"
  },
  {
    "path": "kt-kernel/operators/amx/la/amx.hpp",
    "content": "#ifndef AMX_HPP\n#define AMX_HPP\n#include <emmintrin.h>\n#include <immintrin.h>\n#include <stdlib.h>\n#include <sys/syscall.h>\n#include <tmmintrin.h>\n#include <unistd.h>\n\n#include <cassert>\n#include <cstdio>\n#include <stdexcept>\n\n#include \"llama.cpp/ggml-quants.h\"\n\n// Include the split AMX headers\n#include \"amx_config.hpp\"\n#include \"amx_kernels.hpp\"\n\nnamespace amx {\n\nstatic inline __m512 exp_avx512(__m512 x) {\n  const __m512 log2e = _mm512_set1_ps(1.44269504089f);\n  const __m512 c1 = _mm512_set1_ps(0.69314718056f);\n\n  __m512 y = _mm512_mul_ps(x, log2e);\n  __m512i int_part = _mm512_cvtps_epi32(y);\n  __m512 frac_part = _mm512_sub_ps(y, _mm512_cvtepi32_ps(int_part));\n\n  const __m512 poly_1 = _mm512_set1_ps(0.9999999995f);\n  const __m512 poly_2 = _mm512_set1_ps(0.6931471805f);\n  const __m512 poly_3 = _mm512_set1_ps(0.2402265069f);\n  const __m512 poly_4 = _mm512_set1_ps(0.0555041087f);\n  const __m512 poly_5 = _mm512_set1_ps(0.0096181291f);\n  const __m512 poly_6 = _mm512_set1_ps(0.0013333558f);\n\n  __m512 frac_exp = _mm512_fmadd_ps(\n      _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(poly_6, frac_part, poly_5), frac_part, poly_4),\n                                      frac_part, poly_3),\n                      frac_part, poly_2),\n      frac_part, poly_1);\n\n  __m512 two_pow_i = _mm512_scalef_ps(_mm512_set1_ps(1.0f), _mm512_cvtepi32_ps(int_part));\n  return _mm512_mul_ps(two_pow_i, frac_exp);\n}\n\nstatic inline __m512 act_fn(__m512 gate_val, __m512 up_val) {\n  __m512 neg_gate_val = _mm512_sub_ps(_mm512_setzero_ps(), gate_val);\n  // Clamp neg_gate_val to avoid exp overflow (exp(88) overflows for float32)\n  const __m512 max_exp_input = _mm512_set1_ps(88.0f);\n  neg_gate_val = _mm512_min_ps(neg_gate_val, max_exp_input);\n  __m512 exp_neg_gate = exp_avx512(neg_gate_val);\n  __m512 denom = _mm512_add_ps(_mm512_set1_ps(1.0f), exp_neg_gate);\n  __m512 act_val = _mm512_div_ps(gate_val, denom);\n\n  
return _mm512_mul_ps(act_val, up_val);\n}\n\n#define AMX_DISPATCH_QTYPES(QA, QB, ...)                                 \\\n  [&] {                                                                  \\\n    switch (QB) {                                                        \\\n      case GGML_TYPE_Q8_0: {                                             \\\n        using qb = block_q8_0;                                           \\\n        switch (QA) {                                                    \\\n          case GGML_TYPE_Q4_0: {                                         \\\n            using qa = block_q4_0;                                       \\\n            return __VA_ARGS__();                                        \\\n          }                                                              \\\n          case GGML_TYPE_Q8_0: {                                         \\\n            using qa = block_q8_0;                                       \\\n            return __VA_ARGS__();                                        \\\n          }                                                              \\\n          default:                                                       \\\n            throw std::runtime_error(\"Unsupported quantized data type\"); \\\n        }                                                                \\\n      }                                                                  \\\n      case GGML_TYPE_Q8_K: {                                             \\\n        using qb = block_q8_K;                                           \\\n        switch (QA) {                                                    \\\n          case GGML_TYPE_Q4_K: {                                         \\\n            using qa = block_q4_K;                                       \\\n            return __VA_ARGS__();                                        \\\n          }                                                              \\\n          default:          
                                             \\\n            throw std::runtime_error(\"Unsupported quantized data type\"); \\\n        }                                                                \\\n      }                                                                  \\\n      case GGML_TYPE_BF16: {                                             \\\n        using qb = ggml_bf16_t;                                          \\\n        switch (QA) {                                                    \\\n          case GGML_TYPE_BF16: {                                         \\\n            using qa = ggml_bf16_t;                                      \\\n            return __VA_ARGS__();                                        \\\n          }                                                              \\\n          default:                                                       \\\n            throw std::runtime_error(\"Unsupported quantized data type\"); \\\n        }                                                                \\\n      }                                                                  \\\n      default:                                                           \\\n        throw std::runtime_error(\"Unsupported quantized data type\");     \\\n    }                                                                    \\\n  }()\n\ninline void gemm(int m, int n, int k, const void* a, size_t lda, int type_a, const void* b, size_t ldb, int type_b,\n                 void* c, size_t ldc, int type_c, int ith, int nth) {\n  assert(reinterpret_cast<intptr_t>(c) % 64 == 0);\n  assert(ldc % 64 == 0);\n  assert(type_c == GGML_TYPE_F32);\n  float* cs = (float*)c;\n  AMX_DISPATCH_QTYPES(type_a, type_b, [&]() { mat_mul(m, n, k, (qa*)a, lda, (qb*)b, ldb, cs, ldc, ith, nth); });\n}\n\ninline void init_tile(int type_a, int type_b, int type_c) {\n#ifdef HAVE_AMX\n  enable_amx();\n  assert(type_c == GGML_TYPE_F32);\n  AMX_DISPATCH_QTYPES(type_a, type_b, []() { 
return GemmKernel<qa, qb, float>::type::config(); });\n#endif\n}\n\ninline int recommended_nth(int m, int n, int k, int type_a, int type_b, int type_c) {\n  assert(type_c == GGML_TYPE_F32);\n  return AMX_DISPATCH_QTYPES(type_a, type_b, [&]() { return GemmKernel<qa, qb, float>::type::recommended_nth(m); });\n}\n\n}  // namespace amx\n\n#endif  // AMX_HPP"
  },
  {
    "path": "kt-kernel/operators/amx/la/amx_buffers.hpp",
    "content": "#ifndef AMX_BUFFERS_HPP\n#define AMX_BUFFERS_HPP\n#include <algorithm>\n#include <cassert>\n#include <cstdint>\n#include <cstdio>\n#include <cstring>\n#include <limits>\n#include <vector>\n\n#include \"amx_config.hpp\"\n#include \"amx_utils.hpp\"\n#include \"llama.cpp/ggml-impl.h\"\n#include \"pack.hpp\"\n#include \"utils.hpp\"\n\nnamespace amx {\n\ntemplate <typename K>\nstruct BufferAImpl {\n  int8_t* a;\n  float* d;\n  int max_m, k;\n\n  static constexpr int M_STEP = K::M_STEP;\n  static constexpr int K_STEP = K::K_STEP;\n  static constexpr int K_BLOCK = K::K_BLOCK;\n\n  static size_t required_size(int max_m, int k) { return sizeof(int8_t) * max_m * k + sizeof(float) * max_m; }\n\n  BufferAImpl(int max_m, int k, void* ptr) : max_m(max_m), k(k) {\n    assert(max_m % M_STEP == 0);\n    assert(k % K_STEP == 0);\n    if (max_m % M_STEP || k % K_STEP) {\n      printf(\"max_m = %d, k = %d, M_STEP = %d, K_STEP = %d\\n\", max_m, k, M_STEP, K_STEP);\n      throw std::runtime_error(\"BufferAImpl: max_m and k must be multiple of M_STEP and K_STEP\");\n    }\n    set_data(ptr);\n  }\n\n  void set_data(void* ptr) {\n    assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n    a = reinterpret_cast<int8_t*>(ptr);\n    d = reinterpret_cast<float*>(a + max_m * k);\n  }\n\n  void from_mat(int m, ggml_bf16_t* src, int ith, int nth) {\n    assert(m <= max_m);\n    assert(ith == 0 && nth == 1);\n    for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {\n      for (int i = 0; i < M_STEP && m_begin + i < m; i++) {\n        __m512 amax_v0 = _mm512_setzero_ps();\n        __m512 amax_v1 = _mm512_setzero_ps();\n        __m512 amax_v2 = _mm512_setzero_ps();\n        __m512 amax_v3 = _mm512_setzero_ps();\n        __m512 amax_v4 = _mm512_setzero_ps();\n        __m512 amax_v5 = _mm512_setzero_ps();\n        __m512 amax_v6 = _mm512_setzero_ps();\n        __m512 amax_v7 = _mm512_setzero_ps();\n        for (int j = 0; j < k; j += 128) {\n          __m512 f0, f1, f2, f3, f4, f5, 
f6, f7;\n          avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_begin + i) * k + j + 0), &f0, &f1);\n          avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_begin + i) * k + j + 32), &f2, &f3);\n          avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_begin + i) * k + j + 64), &f4, &f5);\n          avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_begin + i) * k + j + 96), &f6, &f7);\n          amax_v0 = vector_abs_max(amax_v0, f0);\n          amax_v1 = vector_abs_max(amax_v1, f1);\n          amax_v2 = vector_abs_max(amax_v2, f2);\n          amax_v3 = vector_abs_max(amax_v3, f3);\n          amax_v4 = vector_abs_max(amax_v4, f4);\n          amax_v5 = vector_abs_max(amax_v5, f5);\n          amax_v6 = vector_abs_max(amax_v6, f6);\n          amax_v7 = vector_abs_max(amax_v7, f7);\n        }\n        amax_v0 = vector_abs_max(amax_v0, amax_v1);\n        amax_v2 = vector_abs_max(amax_v2, amax_v3);\n        amax_v4 = vector_abs_max(amax_v4, amax_v5);\n        amax_v6 = vector_abs_max(amax_v6, amax_v7);\n        amax_v0 = vector_abs_max(amax_v0, amax_v2);\n        amax_v4 = vector_abs_max(amax_v4, amax_v6);\n        amax_v0 = vector_abs_max(amax_v0, amax_v4);\n        float amax = _mm512_reduce_max_ps(amax_v0);\n        d[m_begin + i] = amax / ((1 << 7) - 1);\n      }\n    }\n    int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n    for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {\n      for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {\n        int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n        for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {\n          for (int i = 0; i < M_STEP && m_begin + i < m; i++) {\n            __m512 id = _mm512_set1_ps(d[m_begin + i] ? 
1.0f / d[m_begin + i] : 0.0f);\n            int8_t* dst = a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP + i * K_STEP;\n            __m512 f0, f1, f2, f3;\n            avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_begin + i) * k + k_block_begin + k_begin), &f0, &f1);\n            avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_begin + i) * k + k_block_begin + k_begin) + 1, &f2, &f3);\n            __m512i i0 = _mm512_cvtps_epi32(_mm512_mul_ps(f0, id));\n            __m512i i1 = _mm512_cvtps_epi32(_mm512_mul_ps(f1, id));\n            __m512i i2 = _mm512_cvtps_epi32(_mm512_mul_ps(f2, id));\n            __m512i i3 = _mm512_cvtps_epi32(_mm512_mul_ps(f3, id));\n            __m128i s0 = _mm512_cvtsepi32_epi8(i0);\n            __m128i s1 = _mm512_cvtsepi32_epi8(i1);\n            __m128i s2 = _mm512_cvtsepi32_epi8(i2);\n            __m128i s3 = _mm512_cvtsepi32_epi8(i3);\n            _mm_store_si128((__m128i*)dst, s0);\n            _mm_store_si128((__m128i*)(dst + 16), s1);\n            _mm_store_si128((__m128i*)(dst + 32), s2);\n            _mm_store_si128((__m128i*)(dst + 48), s3);\n          }\n        }\n      }\n    }\n  }\n\n  int8_t* get_submat(int m, int k, int m_begin, int k_begin) {\n    int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n    int k_block_begin = k_begin / K_BLOCK * K_BLOCK;\n    k_begin -= k_block_begin;\n    int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n    return a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP;\n  }\n\n  float* get_scale(int m, int m_begin) { return d + m_begin; }\n};\n\ntemplate <typename K>\nstruct BufferAWithSumImpl {\n  int8_t* a;\n  float* d;\n  float* sum;\n  int max_m, k;\n\n  static constexpr int M_STEP = K::M_STEP;\n  static constexpr int K_STEP = K::K_STEP;\n  static constexpr int K_BLOCK = K::K_BLOCK;\n\n  static size_t required_size(int max_m, int k) { return sizeof(int8_t) * max_m * k + sizeof(float) * max_m * 2; }\n\n  BufferAWithSumImpl(int 
max_m, int k, void* ptr) : max_m(max_m), k(k) {\n    assert(max_m % M_STEP == 0);\n    assert(k % K_STEP == 0);\n    if (max_m % M_STEP || k % K_STEP) {\n      printf(\"max_m = %d, k = %d, M_STEP = %d, K_STEP = %d\\n\", max_m, k, M_STEP, K_STEP);\n      throw std::runtime_error(\"BufferAWithSumImpl: max_m and k must be multiple of M_STEP and K_STEP\");\n    }\n    set_data(ptr);\n  }\n\n  void set_data(void* ptr) {\n    assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n    a = reinterpret_cast<int8_t*>(ptr);\n    d = reinterpret_cast<float*>(a + max_m * k);\n    sum = d + max_m;\n  }\n\n  void from_mat(int m, ggml_bf16_t* src, int ith, int nth) {\n    assert(m <= max_m);\n    assert(ith == 0 && nth == 1);\n    for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {\n      for (int i = 0; i < M_STEP && m_begin + i < m; i++) {\n        float amax = 0.0f;\n        float row_sum = 0.0f;\n        for (int j = 0; j < k; j += 32) {\n          __m512 f0, f1;\n          avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_begin + i) * k + j), &f0, &f1);\n          amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f0)));\n          amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f1)));\n          row_sum += _mm512_reduce_add_ps(f0);\n          row_sum += _mm512_reduce_add_ps(f1);\n        }\n        d[m_begin + i] = amax / ((1 << 7) - 1);\n        sum[m_begin + i] = row_sum;\n      }\n    }\n    int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n    for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {\n      for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {\n        int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n        for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {\n          for (int i = 0; i < M_STEP && m_begin + i < m; i++) {\n            __m512 id = _mm512_set1_ps(d[m_begin + i] ? 
1.0f / d[m_begin + i] : 0.0f);\n            int8_t* dst = a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP + i * K_STEP;\n            __m512 f0, f1, f2, f3;\n            avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_begin + i) * k + k_block_begin + k_begin), &f0, &f1);\n            avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_begin + i) * k + k_block_begin + k_begin) + 1, &f2, &f3);\n            __m512i i0 = _mm512_cvtps_epi32(_mm512_mul_ps(f0, id));\n            __m512i i1 = _mm512_cvtps_epi32(_mm512_mul_ps(f1, id));\n            __m512i i2 = _mm512_cvtps_epi32(_mm512_mul_ps(f2, id));\n            __m512i i3 = _mm512_cvtps_epi32(_mm512_mul_ps(f3, id));\n            __m128i s0 = _mm512_cvtsepi32_epi8(i0);\n            __m128i s1 = _mm512_cvtsepi32_epi8(i1);\n            __m128i s2 = _mm512_cvtsepi32_epi8(i2);\n            __m128i s3 = _mm512_cvtsepi32_epi8(i3);\n            _mm_store_si128((__m128i*)dst, s0);\n            _mm_store_si128((__m128i*)(dst + 16), s1);\n            _mm_store_si128((__m128i*)(dst + 32), s2);\n            _mm_store_si128((__m128i*)(dst + 48), s3);\n          }\n        }\n      }\n    }\n  }\n\n  int8_t* get_submat(int m, int k, int m_begin, int k_begin) {\n    int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n    int k_block_begin = k_begin / K_BLOCK * K_BLOCK;\n    k_begin -= k_block_begin;\n    int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n    return a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP;\n  }\n\n  float* get_scale(int m, int m_begin) { return d + m_begin; }\n  float* get_sum(int m, int m_begin) { return sum + m_begin; }\n};\n\ntemplate <typename K>\nstruct BufferAWithSumKGroupImpl {\n  int8_t* a;\n  float* d;\n  float* sum;\n  int max_m, k, k_group_size, k_group_count;\n\n  static constexpr int M_STEP = K::M_STEP;\n  static constexpr int K_STEP = K::K_STEP;\n  static constexpr int K_BLOCK = K::K_BLOCK;\n\n  static size_t required_size(int max_m, int 
k, int k_group_size) {\n    return sizeof(int8_t) * max_m * k + sizeof(float) * max_m * (k / k_group_size) * 2;\n  }\n\n  BufferAWithSumKGroupImpl(int max_m, int k, int k_group_size, void* ptr)\n      : max_m(max_m), k(k), k_group_size(k_group_size) {\n    if (max_m % M_STEP || k % K_STEP || k % k_group_size) {\n      printf(\"max_m = %d, k = %d, M_STEP = %d, K_STEP = %d, k_group_size = %d\\n\", max_m, k, M_STEP, K_STEP,\n             k_group_size);\n      throw std::runtime_error(\"BufferAWithSumImpl: max_m and k must be multiple of M_STEP and K_STEP\");\n    }\n    k_group_count = k / k_group_size;\n    set_data(ptr);\n  }\n\n  void set_data(void* ptr) {\n    assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n    a = reinterpret_cast<int8_t*>(ptr);\n    d = reinterpret_cast<float*>(a + max_m * k);\n    sum = d + max_m * k_group_count;\n  }\n\n  void from_mat(int m, ggml_bf16_t* src, int ith, int nth) {\n    assert(m <= max_m);\n    assert(ith == 0 && nth == 1);\n    // for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {\n    //   for (int i = 0; i < M_STEP && m_begin + i < m; i++) {\n    //     for(int kg = 0; kg < k_group_count; kg++){\n    //       float amax = 0.0f;\n    //       float row_sum = 0.0f;\n    //       for (int j = 0; j < k; j += 32) {\n    //         __m512 f0, f1;\n    //         avx512_32xbf16_to_32xfp32((__m512i *)(src + (m_begin + i) * k + j), &f0, &f1);\n    //         amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f0)));\n    //         amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f1)));\n    //         row_sum += _mm512_reduce_add_ps(f0);\n    //         row_sum += _mm512_reduce_add_ps(f1);\n    //       }\n    //       d[(m_begin + i) * k_group_count + kg] = amax / ((1 << 7) - 1);\n    //       sum[(m_begin + i) * k_group_count + kg] = row_sum;\n    //     }\n    //   }\n    // }\n    for (int m_idx = 0; m_idx < m; m_idx++) {\n      for (int kg = 0; kg < k_group_count; kg++) {\n        float amax = 0.0f;\n        float 
row_sum = 0.0f;\n        int k_start = kg * k_group_size;\n        int k_end = k_start + k_group_size;\n        for (int j = k_start; j < k_end; j += 32) {\n          __m512 f0, f1;\n          avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_idx)*k + j), &f0, &f1);\n          amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f0)));\n          amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f1)));\n          row_sum += _mm512_reduce_add_ps(f0);\n          row_sum += _mm512_reduce_add_ps(f1);\n        }\n        d[(m_idx)*k_group_count + kg] = amax / ((1 << 7) - 1);\n        sum[(m_idx)*k_group_count + kg] = row_sum;\n      }\n    }\n    int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n    for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {\n      for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {\n        int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n        for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {\n          for (int i = 0; i < M_STEP && m_begin + i < m; i++) {\n            int k_group_idx = (k_block_begin + k_begin) / k_group_size;\n            float scale = d[(m_begin + i) * k_group_count + k_group_idx];\n            __m512 id = _mm512_set1_ps(scale ? 1.0f / scale : 0.0f);\n            // __m512 id = _mm512_set1_ps(d[m_begin + i] ? 
1.0f / d[m_begin + i] : 0.0f);\n            int8_t* dst = a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP + i * K_STEP;\n            __m512 f0, f1, f2, f3;\n            avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_begin + i) * k + k_block_begin + k_begin), &f0, &f1);\n            avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_begin + i) * k + k_block_begin + k_begin) + 1, &f2, &f3);\n            __m512i i0 = _mm512_cvtps_epi32(_mm512_mul_ps(f0, id));\n            __m512i i1 = _mm512_cvtps_epi32(_mm512_mul_ps(f1, id));\n            __m512i i2 = _mm512_cvtps_epi32(_mm512_mul_ps(f2, id));\n            __m512i i3 = _mm512_cvtps_epi32(_mm512_mul_ps(f3, id));\n            __m128i s0 = _mm512_cvtsepi32_epi8(i0);\n            __m128i s1 = _mm512_cvtsepi32_epi8(i1);\n            __m128i s2 = _mm512_cvtsepi32_epi8(i2);\n            __m128i s3 = _mm512_cvtsepi32_epi8(i3);\n            _mm_store_si128((__m128i*)dst, s0);\n            _mm_store_si128((__m128i*)(dst + 16), s1);\n            _mm_store_si128((__m128i*)(dst + 32), s2);\n            _mm_store_si128((__m128i*)(dst + 48), s3);\n          }\n        }\n      }\n    }\n  }\n\n  int8_t* get_submat(int m, int k, int m_begin, int k_begin) {\n    int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n    int k_block_begin = k_begin / K_BLOCK * K_BLOCK;\n    k_begin -= k_block_begin;\n    int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n    return a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP;\n  }\n\n  float* get_scale(int m, int m_begin, int k, int k_begin) {\n    int k_group_idx = k_begin / k_group_size;\n    return d + m_begin * k_group_count + k_group_idx;\n  }\n  float* get_sum(int m, int m_begin, int k, int k_begin) {\n    int k_group_idx = k_begin / k_group_size;\n    return sum + m_begin * k_group_count + k_group_idx;\n  }\n};\n\ntemplate <typename K>\nstruct BufferAKGroupImpl {\n  int8_t* a;\n  float* d;\n  int max_m, k, k_group_size, 
k_group_count;\n\n  static constexpr int M_STEP = K::M_STEP;\n  static constexpr int K_STEP = K::K_STEP;\n  static constexpr int K_BLOCK = K::K_BLOCK;\n\n  static size_t required_size(int max_m, int k, int k_group_size) {\n    ASSERT_RELEASE(k % k_group_size == 0, \"k must be multiple of k_group_size\");\n    return sizeof(int8_t) * max_m * k + sizeof(float) * max_m * (k / k_group_size);\n  }\n\n  BufferAKGroupImpl(int max_m, int k, int k_group_size, void* ptr) : max_m(max_m), k(k), k_group_size(k_group_size) {\n    ASSERT_RELEASE(k % k_group_size == 0, \"k must be multiple of k_group_size\");\n    ASSERT_RELEASE(max_m % M_STEP == 0, \"max_m must be multiple of M_STEP\");\n    ASSERT_RELEASE(k % K_STEP == 0, \"k must be multiple of K_STEP\");\n    ASSERT_RELEASE(K_BLOCK % k_group_size == 0, \"K_BLOCK must be multiple of k_group_size\");\n    // ASSERT_RELEASE(k % K_BLOCK == 0, \"k must be multiple of K_BLOCK\");\n    k_group_count = k / k_group_size;\n\n    set_data(ptr);\n  }\n\n  void set_data(void* ptr) {\n    assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n    a = reinterpret_cast<int8_t*>(ptr);\n    d = reinterpret_cast<float*>(a + max_m * k);\n  }\n\n  int8_t* get_submat(int m, int k, int m_begin, int k_begin) {\n    // Follow BufferAImpl pattern\n    int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n    int k_block_begin = k_begin / K_BLOCK * K_BLOCK;\n    k_begin -= k_block_begin;\n    int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n    return a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP;\n  }\n\n  void from_mat(int m, ggml_bf16_t* src, int ith, int nth) {\n    assert(m <= max_m);\n    assert(ith == 0 && nth == 1);\n\n    // Compute the scale for each k_group\n    for (int m_idx = 0; m_idx < m; m_idx++) {\n      for (int kg = 0; kg < k_group_count; kg++) {\n        float amax = 0.0f;\n        int k_start = kg * k_group_size;\n        int k_end = k_start + k_group_size;\n        // 32 -> M_STEP\n        for (int j = 
k_start; j < k_end; j += 32) {\n          __m512 f0, f1;\n          avx512_32xbf16_to_32xfp32((__m512i*)(src + m_idx * k + j), &f0, &f1);\n          amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f0)));\n          amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f1)));\n        }\n        d[m_idx * k_group_count + kg] = amax / ((1 << 7) - 1);\n      }\n    }\n\n    // Simplified quantization following BufferAImpl pattern but with k-group support\n    int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n    for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {\n      for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {\n        int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n        for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {\n          for (int i = 0; i < M_STEP && m_begin + i < m; i++) {\n            // Get the scale for this k_group\n            int k_group_idx = (k_block_begin + k_begin) / k_group_size;\n            float scale = d[(m_begin + i) * k_group_count + k_group_idx];\n            __m512 id = _mm512_set1_ps(scale ? 
1.0f / scale : 0.0f);\n\n            // Calculate destination similar to BufferAImpl but accounting for k-groups\n            int8_t* dst = a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP + i * K_STEP;\n\n            __m512 f0, f1, f2, f3;\n            avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_begin + i) * k + k_block_begin + k_begin), &f0, &f1);\n            avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_begin + i) * k + k_block_begin + k_begin) + 1, &f2, &f3);\n            __m512i i0 = _mm512_cvtps_epi32(_mm512_mul_ps(f0, id));\n            __m512i i1 = _mm512_cvtps_epi32(_mm512_mul_ps(f1, id));\n            __m512i i2 = _mm512_cvtps_epi32(_mm512_mul_ps(f2, id));\n            __m512i i3 = _mm512_cvtps_epi32(_mm512_mul_ps(f3, id));\n            __m128i s0 = _mm512_cvtsepi32_epi8(i0);\n            __m128i s1 = _mm512_cvtsepi32_epi8(i1);\n            __m128i s2 = _mm512_cvtsepi32_epi8(i2);\n            __m128i s3 = _mm512_cvtsepi32_epi8(i3);\n            _mm_store_si128((__m128i*)dst, s0);\n            _mm_store_si128((__m128i*)(dst + 16), s1);\n            _mm_store_si128((__m128i*)(dst + 32), s2);\n            _mm_store_si128((__m128i*)(dst + 48), s3);\n          }\n        }\n      }\n    }\n  }\n\n  float* get_scale(int m, int m_begin, int k, int k_begin) {\n    int k_group_idx = k_begin / k_group_size;\n    return d + m_begin * k_group_count + k_group_idx;\n  }\n};\n\n// BufferASmallKGroupImpl: For kernels with K_STEP=32 (e.g., GemmKernel224Int4SmallKGroup)\n// This fixes the buffer overflow issue where the base class writes 64 bytes per K_STEP iteration\n// but the buffer is only sized for 32-byte steps.\ntemplate <typename K>\nstruct BufferASmallKGroupImpl : public BufferAKGroupImpl<K> {\n  using Base = BufferAKGroupImpl<K>;\n  using Base::a;\n  using Base::d;\n  using Base::k;\n  using Base::k_group_count;\n  using Base::k_group_size;\n  using Base::max_m;\n\n  static constexpr int M_STEP = K::M_STEP;\n  static constexpr int 
K_STEP = K::K_STEP;\n  static constexpr int K_BLOCK = K::K_BLOCK;\n\n  BufferASmallKGroupImpl(int max_m, int k, int k_group_size, void* ptr) : Base(max_m, k, k_group_size, ptr) {}\n\n  // Override from_mat to write only 32 bytes per K_STEP iteration\n  void from_mat(int m, ggml_bf16_t* src, int ith, int nth) {\n    assert(m <= max_m);\n    assert(ith == 0 && nth == 1);\n\n    // Calculate scale for each k_group (same as base class)\n    for (int m_idx = 0; m_idx < m; m_idx++) {\n      for (int kg = 0; kg < k_group_count; kg++) {\n        float amax = 0.0f;\n        int k_start = kg * k_group_size;\n        int k_end = k_start + k_group_size;\n        for (int j = k_start; j < k_end; j += 32) {\n          __m512 f0, f1;\n          avx512_32xbf16_to_32xfp32((__m512i*)(src + m_idx * k + j), &f0, &f1);\n          amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f0)));\n          amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f1)));\n        }\n        d[m_idx * k_group_count + kg] = amax / ((1 << 7) - 1);\n      }\n    }\n\n    // Quantization with 32-byte writes per K_STEP iteration\n    int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n    for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {\n      for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {\n        int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n        for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {\n          for (int i = 0; i < M_STEP && m_begin + i < m; i++) {\n            // Get the scale for this k_group\n            int k_group_idx = (k_block_begin + k_begin) / k_group_size;\n            float scale = d[(m_begin + i) * k_group_count + k_group_idx];\n            __m512 id = _mm512_set1_ps(scale ? 
1.0f / scale : 0.0f);\n\n            // Calculate destination - writes K_STEP (32) bytes\n            int8_t* dst = a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP + i * K_STEP;\n\n            // Only process 32 bytes (2 x __m512 -> 2 x __m128i) per iteration\n            __m512 f0, f1;\n            avx512_32xbf16_to_32xfp32((__m512i*)(src + (m_begin + i) * k + k_block_begin + k_begin), &f0, &f1);\n            __m512i i0 = _mm512_cvtps_epi32(_mm512_mul_ps(f0, id));\n            __m512i i1 = _mm512_cvtps_epi32(_mm512_mul_ps(f1, id));\n            __m128i s0 = _mm512_cvtsepi32_epi8(i0);\n            __m128i s1 = _mm512_cvtsepi32_epi8(i1);\n            _mm_store_si128((__m128i*)dst, s0);\n            _mm_store_si128((__m128i*)(dst + 16), s1);\n          }\n        }\n      }\n    }\n  }\n};\n\ntemplate <typename K>\nstruct BufferBInt4Impl {\n  using dt = typename K::dt;\n  dt* b;\n  float* d;\n  int n, k;\n\n  static constexpr int N_STEP = K::N_STEP;\n  static constexpr int N_BLOCK = K::N_BLOCK;\n  static constexpr int TILE_N = K::TILE_N;\n\n  static constexpr int K_STEP = K::K_STEP;\n  static constexpr int K_BLOCK = K::K_BLOCK;\n  static const int B_K_STEP = 2 * K_STEP;\n  static constexpr bool SCALE = true;\n\n  static size_t required_size(int n, int k) { return sizeof(int8_t) * n * k / 2 + sizeof(float) * n; }\n\n  BufferBInt4Impl(int n, int k, void* ptr) : n(n), k(k) {\n    assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n    assert(n % N_STEP == 0);\n    assert(k % B_K_STEP == 0);\n    if (n % N_STEP || k % B_K_STEP) {\n      printf(\"n: %d, k: %d, N_STEP: %d, B_K_STEP: %d\\n\", n, k, N_STEP, B_K_STEP);\n      throw std::runtime_error(\"n or k is not aligned to N_STEP or B_K_STEP\");\n    }\n    b = reinterpret_cast<dt*>(ptr);\n    d = reinterpret_cast<float*>(offset_pointer(b, n * k / 2));\n  }\n\n  static __m128i round_4bit_s8(__m128i x) {\n    __m128i s = _mm_and_si128(x, _mm_set1_epi8(0x80));\n    s = _mm_or_si128(s, 
_mm_srai_epi16(s, 1));\n    s = _mm_or_si128(s, _mm_srai_epi16(s, 2));\n    s = _mm_or_si128(s, _mm_srai_epi16(s, 4));\n\n    x = _mm_abs_epi8(x);\n    x = _mm_add_epi8(x, _mm_set1_epi8(0x08));\n    x = _mm_and_si128(x, _mm_set1_epi8(0xF0));\n    x = _mm_xor_si128(x, s);\n    x = _mm_sub_epi8(x, s);\n    return x;\n  }\n\n  void from_mat(ggml_bf16_t* src, int ith, int nth) {\n    auto [n_start, n_end] = K::split_range_n(n, ith, nth);\n    int n_block_begin = n_start;\n    int n_block_size = n_end - n_block_begin;\n    for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n      for (int i = 0; i < N_STEP; i++) {\n        float amax = 0.0f;\n        for (int j = 0; j < k; j += 32) {\n          __m512 f0, f1;\n          avx512_32xbf16_to_32xfp32((__m512i*)(src + (n_block_begin + n_begin + i) * k + j), &f0, &f1);\n          amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f0)));\n          amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f1)));\n        }\n        d[n_block_begin + n_begin + i] = amax / 112.0;  // 7*16\n      }\n    }\n    for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n      for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {\n        int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n        for (int k_begin = 0; k_begin < k_block_size; k_begin += B_K_STEP) {\n          for (int i = 0; i < N_STEP; i++) {\n            __m512 id = _mm512_set1_ps(d[n_block_begin + n_begin + i] ? 
1.0f / d[n_block_begin + n_begin + i] : 0.0f);\n            dt* dst = offset_pointer(b, (n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size +\n                                         k_begin * N_STEP + i * B_K_STEP) /\n                                            2);\n            {\n              __m512 f0, f1, f2, f3;\n              avx512_32xbf16_to_32xfp32((__m512i*)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin),\n                                        &f0, &f1);\n              avx512_32xbf16_to_32xfp32(\n                  (__m512i*)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin) + 1, &f2, &f3);\n              __m512i i0 = _mm512_cvtps_epi32(_mm512_mul_ps(f0, id));\n              __m512i i1 = _mm512_cvtps_epi32(_mm512_mul_ps(f1, id));\n              __m512i i2 = _mm512_cvtps_epi32(_mm512_mul_ps(f2, id));\n              __m512i i3 = _mm512_cvtps_epi32(_mm512_mul_ps(f3, id));\n              __m128i s0 = _mm512_cvtsepi32_epi8(i0);\n              __m128i s1 = _mm512_cvtsepi32_epi8(i1);\n              __m128i s2 = _mm512_cvtsepi32_epi8(i2);\n              __m128i s3 = _mm512_cvtsepi32_epi8(i3);\n              s0 = _mm_srli_epi16(round_4bit_s8(s0), 4);\n              s1 = _mm_srli_epi16(round_4bit_s8(s1), 4);\n              s2 = _mm_srli_epi16(round_4bit_s8(s2), 4);\n              s3 = _mm_srli_epi16(round_4bit_s8(s3), 4);\n              // s0 = _mm_or_si128(round_up4(s0), _mm_srli_epi16(round_up4(s1), 4));\n              // s2 = _mm_or_si128(round_up4(s2), _mm_srli_epi16(round_up4(s3), 4));\n              _mm_store_si128((__m128i*)dst, s0);\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 16)), s1);\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 32)), s2);\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 48)), s3);\n            }\n\n            {\n              __m512 f0, f1, f2, f3;\n              avx512_32xbf16_to_32xfp32(\n                  (__m512i*)(src + 
(n_block_begin + n_begin + i) * k + k_block_begin + k_begin) + 2, &f0, &f1);\n              avx512_32xbf16_to_32xfp32(\n                  (__m512i*)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin) + 3, &f2, &f3);\n              __m512i i0 = _mm512_cvtps_epi32(_mm512_mul_ps(f0, id));\n              __m512i i1 = _mm512_cvtps_epi32(_mm512_mul_ps(f1, id));\n              __m512i i2 = _mm512_cvtps_epi32(_mm512_mul_ps(f2, id));\n              __m512i i3 = _mm512_cvtps_epi32(_mm512_mul_ps(f3, id));\n              __m128i s0 = _mm512_cvtsepi32_epi8(i0);\n              __m128i s1 = _mm512_cvtsepi32_epi8(i1);\n              __m128i s2 = _mm512_cvtsepi32_epi8(i2);\n              __m128i s3 = _mm512_cvtsepi32_epi8(i3);\n              s0 = round_4bit_s8(s0);\n              s1 = round_4bit_s8(s1);\n              s2 = round_4bit_s8(s2);\n              s3 = round_4bit_s8(s3);\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 0)),\n                              _mm_or_si128(_mm_loadu_si128((__m128i*)(offset_pointer(dst, 0))), s0));\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 16)),\n                              _mm_or_si128(_mm_loadu_si128((__m128i*)(offset_pointer(dst, 16))), s1));\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 32)),\n                              _mm_or_si128(_mm_loadu_si128((__m128i*)(offset_pointer(dst, 32))), s2));\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 48)),\n                              _mm_or_si128(_mm_loadu_si128((__m128i*)(offset_pointer(dst, 48))), s3));\n            }\n          }\n          transpose_16x16_32bit((__m512i*)(offset_pointer(\n              b, (n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP) / 2)));\n          transpose_16x16_32bit(\n              (__m512i*)(offset_pointer(b, (n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size +\n                                            k_begin * 
N_STEP + TILE_N * B_K_STEP) /\n                                               2)));\n        }\n      }\n    }\n  }\n\n  dt* get_submat(int n, int k, int n_begin, int k_begin) {\n    int n_block_begin = n_begin / N_BLOCK * N_BLOCK;\n    n_begin -= n_block_begin;\n    int n_block_size = std::min(N_BLOCK, n - n_block_begin);\n    int k_block_begin = k_begin / K_BLOCK * K_BLOCK;\n    k_begin -= k_block_begin;\n    int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n    return offset_pointer(\n        b, (n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP) / 2);\n  }\n\n  float* get_scale(int n, int n_begin) { return d + n_begin; }\n};\n\ntemplate <typename K>\nstruct BufferBKGroupImpl {\n  using dt = typename K::dt;\n  dt* b;\n  float* d;\n  int n, k, k_group_size, k_group_count;\n\n  static constexpr int N_STEP = K::N_STEP;\n  static constexpr int N_BLOCK = K::N_BLOCK;\n  static constexpr int TILE_N = K::TILE_N;\n\n  static constexpr int K_STEP = K::K_STEP;\n  static constexpr int K_BLOCK = K::K_BLOCK;\n  static const int B_K_STEP = 2 * K_STEP;\n  static constexpr bool SCALE = true;\n\n  static size_t required_size(int n, int k, int k_group_size) {\n    ASSERT_RELEASE(k % k_group_size == 0, \"k must be multiple of k_group_size\");\n    return sizeof(int8_t) * n * k / 2 + sizeof(float) * n * (k / k_group_size);\n  }\n\n  BufferBKGroupImpl(int n, int k, int k_group_size, void* ptr) : n(n), k(k), k_group_size(k_group_size) {\n    assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n    assert(n % N_STEP == 0);\n    assert(k % B_K_STEP == 0);\n    ASSERT_RELEASE(k % k_group_size == 0, \"k must be multiple of k_group_size\");\n    ASSERT_RELEASE(K_BLOCK % k_group_size == 0, \"K_BLOCK must be multiple of k_group_size\");\n    if (n % N_STEP || k % B_K_STEP) {\n      printf(\"n: %d, k: %d, N_STEP: %d, B_K_STEP: %d\\n\", n, k, N_STEP, B_K_STEP);\n      throw std::runtime_error(\"n or k is not aligned to N_STEP or B_K_STEP\");\n  
  }\n    k_group_count = k / k_group_size;\n    b = reinterpret_cast<dt*>(ptr);\n    d = reinterpret_cast<float*>(offset_pointer(b, n * k / 2));\n  }\n\n  static __m128i round_4bit_s8(__m128i x) {\n    __m128i s = _mm_and_si128(x, _mm_set1_epi8(0x80));\n    s = _mm_or_si128(s, _mm_srai_epi16(s, 1));\n    s = _mm_or_si128(s, _mm_srai_epi16(s, 2));\n    s = _mm_or_si128(s, _mm_srai_epi16(s, 4));\n\n    x = _mm_abs_epi8(x);\n    x = _mm_add_epi8(x, _mm_set1_epi8(0x08));\n    x = _mm_and_si128(x, _mm_set1_epi8(0xF0));\n    x = _mm_xor_si128(x, s);\n    x = _mm_sub_epi8(x, s);\n    return x;\n  }\n\n  void from_mat(ggml_bf16_t* src, int ith, int nth) {\n    auto [n_start, n_end] = K::split_range_n(n, ith, nth);\n    int n_block_begin = n_start;\n    int n_block_size = n_end - n_block_begin;\n\n    // Compute scales per k-group for each n\n    for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n      for (int i = 0; i < N_STEP; i++) {\n        for (int kg = 0; kg < k_group_count; kg++) {\n          float amax = 0.0f;\n          int k_start = kg * k_group_size;\n          int k_end = k_start + k_group_size;\n\n          for (int j = k_start; j < k_end; j += 32) {\n            __m512 f0, f1;\n            avx512_32xbf16_to_32xfp32((__m512i*)(src + (n_block_begin + n_begin + i) * k + j), &f0, &f1);\n            amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f0)));\n            amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f1)));\n          }\n          d[kg * n + (n_block_begin + n_begin + i)] = amax / 112.0;  // 7*16\n          // d[(n_block_begin + n_begin + i) * k_group_count + kg] = amax / 112.0; // 7*16\n        }\n      }\n    }\n\n    // Quantize with per k-group scaling\n    for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n      for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {\n        int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n        for (int k_begin = 0; k_begin < k_block_size; 
k_begin += B_K_STEP) {\n          for (int i = 0; i < N_STEP; i++) {\n            // Get the scale for this k_group\n            int k_group_idx0 = (k_block_begin + k_begin) / k_group_size;\n            int k_group_idx1 = (k_block_begin + k_begin + K_STEP) / k_group_size;\n            float scale0 = d[k_group_idx0 * n + (n_block_begin + n_begin + i)];\n            float scale1 = d[k_group_idx1 * n + (n_block_begin + n_begin + i)];\n            // float scale = d[(n_block_begin + n_begin + i) * k_group_count + k_group_idx];\n            __m512 id0 = _mm512_set1_ps(scale0 ? 1.0f / scale0 : 0.0f);\n            __m512 id1 = _mm512_set1_ps(scale1 ? 1.0f / scale1 : 0.0f);\n\n            dt* dst = offset_pointer(b, (n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size +\n                                         k_begin * N_STEP + i * B_K_STEP) /\n                                            2);\n            {\n              __m512 f0, f1, f2, f3;\n              avx512_32xbf16_to_32xfp32((__m512i*)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin),\n                                        &f0, &f1);\n              avx512_32xbf16_to_32xfp32(\n                  (__m512i*)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin) + 1, &f2, &f3);\n              __m512i i0 = _mm512_cvtps_epi32(_mm512_mul_ps(f0, id0));\n              __m512i i1 = _mm512_cvtps_epi32(_mm512_mul_ps(f1, id0));\n              __m512i i2 = _mm512_cvtps_epi32(_mm512_mul_ps(f2, id0));\n              __m512i i3 = _mm512_cvtps_epi32(_mm512_mul_ps(f3, id0));\n              __m128i s0 = _mm512_cvtsepi32_epi8(i0);\n              __m128i s1 = _mm512_cvtsepi32_epi8(i1);\n              __m128i s2 = _mm512_cvtsepi32_epi8(i2);\n              __m128i s3 = _mm512_cvtsepi32_epi8(i3);\n              s0 = _mm_srli_epi16(round_4bit_s8(s0), 4);\n              s1 = _mm_srli_epi16(round_4bit_s8(s1), 4);\n              s2 = _mm_srli_epi16(round_4bit_s8(s2), 4);\n              
s3 = _mm_srli_epi16(round_4bit_s8(s3), 4);\n              _mm_store_si128((__m128i*)dst, s0);\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 16)), s1);\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 32)), s2);\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 48)), s3);\n            }\n\n            {\n              __m512 f0, f1, f2, f3;\n              avx512_32xbf16_to_32xfp32(\n                  (__m512i*)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin) + 2, &f0, &f1);\n              avx512_32xbf16_to_32xfp32(\n                  (__m512i*)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin) + 3, &f2, &f3);\n              __m512i i0 = _mm512_cvtps_epi32(_mm512_mul_ps(f0, id1));\n              __m512i i1 = _mm512_cvtps_epi32(_mm512_mul_ps(f1, id1));\n              __m512i i2 = _mm512_cvtps_epi32(_mm512_mul_ps(f2, id1));\n              __m512i i3 = _mm512_cvtps_epi32(_mm512_mul_ps(f3, id1));\n              __m128i s0 = _mm512_cvtsepi32_epi8(i0);\n              __m128i s1 = _mm512_cvtsepi32_epi8(i1);\n              __m128i s2 = _mm512_cvtsepi32_epi8(i2);\n              __m128i s3 = _mm512_cvtsepi32_epi8(i3);\n              s0 = round_4bit_s8(s0);\n              s1 = round_4bit_s8(s1);\n              s2 = round_4bit_s8(s2);\n              s3 = round_4bit_s8(s3);\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 0)),\n                              _mm_or_si128(_mm_loadu_si128((__m128i*)(offset_pointer(dst, 0))), s0));\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 16)),\n                              _mm_or_si128(_mm_loadu_si128((__m128i*)(offset_pointer(dst, 16))), s1));\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 32)),\n                              _mm_or_si128(_mm_loadu_si128((__m128i*)(offset_pointer(dst, 32))), s2));\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 48)),\n                              
_mm_or_si128(_mm_loadu_si128((__m128i*)(offset_pointer(dst, 48))), s3));\n            }\n          }\n          transpose_16x16_32bit((__m512i*)(offset_pointer(\n              b, (n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP) / 2)));\n          transpose_16x16_32bit(\n              (__m512i*)(offset_pointer(b, (n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size +\n                                            k_begin * N_STEP + TILE_N * B_K_STEP) /\n                                               2)));\n        }\n      }\n    }\n  }\n\n  dt* get_submat(int n, int k, int n_begin, int k_begin) {\n    int n_block_begin = n_begin / N_BLOCK * N_BLOCK;\n    n_begin -= n_block_begin;\n    int n_block_size = std::min(N_BLOCK, n - n_block_begin);\n    int k_block_begin = k_begin / K_BLOCK * K_BLOCK;\n    k_begin -= k_block_begin;\n    int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n    return offset_pointer(\n        b, (n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP) / 2);\n  }\n\n  float* get_scale(int n, int n_begin, int k, int k_begin) {\n    int k_group_idx = k_begin / k_group_size;\n    return d + k_group_idx * n + n_begin;\n    // return d + n_begin * k_group_count + k_group_idx;\n  }\n};\n\ntemplate <typename K>\nstruct BufferBInt4WithZeroImpl {\n  using dt = typename K::dt;\n  dt* b;\n  float *d, *mins;  // scale, mins\n  int n, k;\n\n  static constexpr int N_STEP = K::N_STEP;\n  static constexpr int N_BLOCK = K::N_BLOCK;\n  static constexpr int TILE_N = K::TILE_N;\n\n  static constexpr int K_STEP = K::K_STEP;\n  static constexpr int K_BLOCK = K::K_BLOCK;\n  static const int B_K_STEP = 2 * K_STEP;\n  static constexpr bool SCALE = true;\n\n  static size_t required_size(int n, int k) { return sizeof(int8_t) * n * k / 2 + sizeof(float) * n * 2; }\n\n  BufferBInt4WithZeroImpl(int n, int k, void* ptr) : n(n), k(k) {\n    
assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n    assert(n % N_STEP == 0);\n    assert(k % B_K_STEP == 0);\n    if (n % N_STEP || k % B_K_STEP) {\n      printf(\"n: %d, k: %d, N_STEP: %d, B_K_STEP: %d\\n\", n, k, N_STEP, B_K_STEP);\n      throw std::runtime_error(\"n or k is not aligned to N_STEP or B_K_STEP\");\n    }\n    b = reinterpret_cast<dt*>(ptr);\n    d = reinterpret_cast<float*>(offset_pointer(b, n * k / 2));\n    mins = d + n;\n  }\n\n  // Round each uint8_t lane to the nearest multiple of 16\n  static __m128i round_4bit_u8(__m128i x) {\n    // Add 8 to round to nearest; use unsigned saturating addition because the lanes hold uint8_t values\n    x = _mm_adds_epu8(x, _mm_set1_epi8(0x08));\n    // Clear the low 4 bits (i.e. align to a multiple of 16)\n    x = _mm_and_si128(x, _mm_set1_epi8(0xF0));\n    return x;\n  }\n\n  void from_mat(ggml_bf16_t* src, int ith, int nth) {\n    auto [n_start, n_end] = K::split_range_n(n, ith, nth);\n    int n_block_begin = n_start;\n    int n_block_size = n_end - n_block_begin;\n    for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n      for (int i = 0; i < N_STEP; i++) {\n        float amax = std::numeric_limits<float>::lowest();\n        float amin = std::numeric_limits<float>::max();\n        for (int j = 0; j < k; j += 32) {\n          __m512 f0, f1;\n          avx512_32xbf16_to_32xfp32((__m512i*)(src + (n_block_begin + n_begin + i) * k + j), &f0, &f1);\n          amax = MAX(amax, _mm512_reduce_max_ps(f0));\n          amax = MAX(amax, _mm512_reduce_max_ps(f1));\n          amin = MIN(amin, _mm512_reduce_min_ps(f0));\n          amin = MIN(amin, _mm512_reduce_min_ps(f1));\n        }\n        d[n_block_begin + n_begin + i] = (amax - amin) / 240.0;  // 15*16\n        mins[n_block_begin + n_begin + i] = amin;\n      }\n    }\n    for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n      for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {\n        int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n        for (int k_begin = 0; k_begin < k_block_size; k_begin += B_K_STEP) {\n          for (int 
i = 0; i < N_STEP; i++) {\n            __m512 id = _mm512_set1_ps(d[n_block_begin + n_begin + i] ? 1.0f / d[n_block_begin + n_begin + i] : 0.0f);\n            __m512 zps = _mm512_set1_ps(-mins[n_block_begin + n_begin + i]);\n            dt* dst = offset_pointer(b, (n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size +\n                                         k_begin * N_STEP + i * B_K_STEP) /\n                                            2);\n            {\n              __m512 f0, f1, f2, f3;\n              avx512_32xbf16_to_32xfp32((__m512i*)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin),\n                                        &f0, &f1);\n              avx512_32xbf16_to_32xfp32(\n                  (__m512i*)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin) + 1, &f2, &f3);\n              __m512i i0 = _mm512_cvtps_epu32(_mm512_mul_ps(_mm512_add_ps(f0, zps), id));\n              __m512i i1 = _mm512_cvtps_epu32(_mm512_mul_ps(_mm512_add_ps(f1, zps), id));\n              __m512i i2 = _mm512_cvtps_epu32(_mm512_mul_ps(_mm512_add_ps(f2, zps), id));\n              __m512i i3 = _mm512_cvtps_epu32(_mm512_mul_ps(_mm512_add_ps(f3, zps), id));\n              __m128i s0 = _mm512_cvtusepi32_epi8(i0);\n              __m128i s1 = _mm512_cvtusepi32_epi8(i1);\n              __m128i s2 = _mm512_cvtusepi32_epi8(i2);\n              __m128i s3 = _mm512_cvtusepi32_epi8(i3);\n              s0 = _mm_srli_epi16(round_4bit_u8(s0), 4);\n              s1 = _mm_srli_epi16(round_4bit_u8(s1), 4);\n              s2 = _mm_srli_epi16(round_4bit_u8(s2), 4);\n              s3 = _mm_srli_epi16(round_4bit_u8(s3), 4);\n              // s0 = _mm_or_si128(round_up4(s0), _mm_srli_epi16(round_up4(s1), 4));\n              // s2 = _mm_or_si128(round_up4(s2), _mm_srli_epi16(round_up4(s3), 4));\n              _mm_store_si128((__m128i*)dst, s0);\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 16)), s1);\n              
_mm_store_si128((__m128i*)(offset_pointer(dst, 32)), s2);\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 48)), s3);\n            }\n\n            {\n              __m512 f0, f1, f2, f3;\n              avx512_32xbf16_to_32xfp32(\n                  (__m512i*)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin) + 2, &f0, &f1);\n              avx512_32xbf16_to_32xfp32(\n                  (__m512i*)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin) + 3, &f2, &f3);\n              __m512i i0 = _mm512_cvtps_epu32(_mm512_mul_ps(_mm512_add_ps(f0, zps), id));\n              __m512i i1 = _mm512_cvtps_epu32(_mm512_mul_ps(_mm512_add_ps(f1, zps), id));\n              __m512i i2 = _mm512_cvtps_epu32(_mm512_mul_ps(_mm512_add_ps(f2, zps), id));\n              __m512i i3 = _mm512_cvtps_epu32(_mm512_mul_ps(_mm512_add_ps(f3, zps), id));\n              __m128i s0 = _mm512_cvtusepi32_epi8(i0);\n              __m128i s1 = _mm512_cvtusepi32_epi8(i1);\n              __m128i s2 = _mm512_cvtusepi32_epi8(i2);\n              __m128i s3 = _mm512_cvtusepi32_epi8(i3);\n              s0 = round_4bit_u8(s0);\n              s1 = round_4bit_u8(s1);\n              s2 = round_4bit_u8(s2);\n              s3 = round_4bit_u8(s3);\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 0)),\n                              _mm_or_si128(_mm_loadu_si128((__m128i*)(offset_pointer(dst, 0))), s0));\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 16)),\n                              _mm_or_si128(_mm_loadu_si128((__m128i*)(offset_pointer(dst, 16))), s1));\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 32)),\n                              _mm_or_si128(_mm_loadu_si128((__m128i*)(offset_pointer(dst, 32))), s2));\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 48)),\n                              _mm_or_si128(_mm_loadu_si128((__m128i*)(offset_pointer(dst, 48))), s3));\n            }\n          }\n          
transpose_16x16_32bit((__m512i*)(offset_pointer(\n              b, (n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP) / 2)));\n          transpose_16x16_32bit(\n              (__m512i*)(offset_pointer(b, (n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size +\n                                            k_begin * N_STEP + TILE_N * B_K_STEP) /\n                                               2)));\n        }\n      }\n    }\n  }\n\n  dt* get_submat(int n, int k, int n_begin, int k_begin) {\n    int n_block_begin = n_begin / N_BLOCK * N_BLOCK;\n    n_begin -= n_block_begin;\n    int n_block_size = std::min(N_BLOCK, n - n_block_begin);\n    int k_block_begin = k_begin / K_BLOCK * K_BLOCK;\n    k_begin -= k_block_begin;\n    int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n    return offset_pointer(\n        b, (n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP) / 2);\n  }\n\n  float* get_scale(int n, int n_begin) { return d + n_begin; }\n  float* get_min(int n, int n_begin) { return mins + n_begin; }\n};\n\n// BufferB for Signed Int4 with KGroup Scale (no zero point)\n// Used for K2 MoE - signed int4 range: [-8, 7]\ntemplate <typename K>\nstruct BufferBInt4KGroupImpl {\n  using dt = typename K::dt;\n  dt* b;     // packed signed int4 weights, col majored\n  float* d;  // scales only (no mins/zero-points), row majored\n  int n, k, k_group_size, k_group_count;\n\n  static constexpr int N_STEP = K::N_STEP;\n  static constexpr int K_STEP = K::K_STEP;\n  static constexpr bool SCALE = true;\n\n  // Size calculation: packed int4 weights + scales (NO mins)\n  static size_t required_size(int n, int k, int k_group_size) {\n    return sizeof(int8_t) * n * k / 2 + sizeof(float) * n * (k / k_group_size);\n  }\n\n  BufferBInt4KGroupImpl(int n, int k, int k_group_size, void* ptr) : n(n), k(k), k_group_size(k_group_size) {\n    assert(reinterpret_cast<intptr_t>(ptr) % 
64 == 0);\n    assert(n % N_STEP == 0);\n    assert(k % K_STEP == 0);\n    if (n % N_STEP || k % K_STEP || k % k_group_size) {\n      printf(\"BufferBInt4KGroupImpl: n: %d, k: %d, N_STEP: %d, K_STEP: %d, k_group_size: %d\\n\", n, k, N_STEP, K_STEP,\n             k_group_size);\n      throw std::runtime_error(\"n or k is not aligned to N_STEP or K_STEP\");\n    }\n    k_group_count = k / k_group_size;\n    b = reinterpret_cast<dt*>(ptr);\n    d = reinterpret_cast<float*>(offset_pointer(b, n * k / 2));\n  }\n\n  // Load from packed signed int4 format\n  // Input: proj is packed int4 weights (2 int4 values per byte)\n  // Each int4 value is in range [-8, 7] (signed)\n  void from_raw_mat(uint8_t* proj, int ith, int nth) {\n    auto [n_start, n_end] = K::split_range_n(n, ith, nth);\n    if (n_start >= n_end) {\n      return;\n    }\n    const size_t row_bytes = static_cast<size_t>(k) / 2;\n    const size_t rows = static_cast<size_t>(n_end - n_start);\n    uint8_t* dst_weights = reinterpret_cast<uint8_t*>(b) + n_start * row_bytes;\n    const uint8_t* src_weights = proj + n_start * row_bytes;\n    std::memcpy(dst_weights, src_weights, rows * row_bytes);\n  }\n\n  // Get pointer to submatrix for computation\n  dt* get_submat(int n, int k, int n_begin, int k_begin) {\n    const size_t row_bytes = static_cast<size_t>(k) / 2;\n    const size_t row_offset = static_cast<size_t>(n_begin) * row_bytes;\n    const size_t col_offset = static_cast<size_t>(k_begin) / 2;\n    return reinterpret_cast<dt*>(reinterpret_cast<uint8_t*>(b) + row_offset + col_offset);\n  }\n\n  // Get scale pointer for a specific row and k_group\n  float* get_scale(int n, int n_begin, int k, int k_begin) {\n    int k_group_idx = k_begin / k_group_size;\n    return d + n_begin * (k / k_group_size) + k_group_idx;\n  }\n\n  // Split range for parallel processing\n  static std::pair<int, int> split_range_n(int n, int ith, int nth) {\n    int n_per_thread = (n + nth - 1) / nth;\n    n_per_thread = (n_per_thread + 
N_STEP - 1) / N_STEP * N_STEP;\n    int n_start = std::min(ith * n_per_thread, n);\n    int n_end = std::min(n_start + n_per_thread, n);\n    return {n_start, n_end};\n  }\n};\n\ntemplate <typename K>\nstruct BufferBInt4WithZeroKGroupImpl {\n  using dt = typename K::dt;\n  dt* b;\n  float *d, *mins;  // scale, mins\n  int n, k, k_group_size, k_group_count;\n\n  static constexpr int N_STEP = K::N_STEP;\n  static constexpr int N_BLOCK = K::N_BLOCK;\n  static constexpr int TILE_N = K::TILE_N;\n\n  static constexpr int K_STEP = K::K_STEP;\n  static constexpr int K_BLOCK = K::K_BLOCK;\n  static const int B_K_STEP = 2 * K_STEP;\n  static constexpr bool SCALE = true;\n\n  static size_t required_size(int n, int k, int k_group_size) {\n    return sizeof(int8_t) * n * k / 2 + sizeof(float) * n * (k / k_group_size) * 2;\n  }\n\n  BufferBInt4WithZeroKGroupImpl(int n, int k, int k_group_size, void* ptr) : n(n), k(k), k_group_size(k_group_size) {\n    assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n    assert(n % N_STEP == 0);\n    assert(k % B_K_STEP == 0);\n    if (n % N_STEP || k % B_K_STEP || k % k_group_size) {\n      printf(\"n: %d, k: %d, N_STEP: %d, B_K_STEP: %d, k_group_size: %d\\n\", n, k, N_STEP, B_K_STEP, k_group_size);\n      throw std::runtime_error(\"n or k is not aligned to N_STEP or B_K_STEP\");\n    }\n    k_group_count = k / k_group_size;\n    b = reinterpret_cast<dt*>(ptr);\n    d = reinterpret_cast<float*>(offset_pointer(b, n * k / 2));\n    mins = d + n * k_group_count;\n  }\n\n  // Round each uint8_t lane to the nearest multiple of 16\n  static __m128i round_4bit_u8(__m128i x) {\n    // Add 8 for round-to-nearest; the lanes are unsigned, so use the unsigned saturating add\n    x = _mm_adds_epu8(x, _mm_set1_epi8(0x08));\n    // Clear the low 4 bits (i.e. align down to a multiple of 16)\n    x = _mm_and_si128(x, _mm_set1_epi8(0xF0));\n    return x;\n  }\n\n  void from_raw_mat(uint8_t* proj, int ith, int nth) {\n    auto [n_start, n_end] = K::split_range_n(n, ith, nth);\n    int n_block_begin = n_start;\n    int n_block_size = n_end - n_block_begin;\n\n    for (int n_begin = 0; 
n_begin < n_block_size; n_begin += N_STEP) {\n      for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {\n        int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n        for (int k_begin = 0; k_begin < k_block_size; k_begin += B_K_STEP) {\n          for (int i = 0; i < N_STEP; i++) {\n            uint8_t* dst = (uint8_t*)offset_pointer(b, (n_block_begin * k + k_block_begin * n_block_size +\n                                                        n_begin * k_block_size + k_begin * N_STEP + i * B_K_STEP) >>\n                                                           1);\n            uint32_t* src =\n                (uint32_t*)offset_pointer(proj, ((n_block_begin + n_begin + i) * k + k_block_begin + k_begin) >> 1);\n            for (int a0 = 0; a0 < 8; a0++) {\n              uint32_t src0 = src[a0], src1 = src[a0 + 8];\n              for (int a1 = 0; a1 < 8; a1++) {\n                uint8_t cur_src0 = src0 & 0x0F, cur_src1 = src1 & 0x0F;\n                dst[(a0 * 8) + a1] = (cur_src0 | (cur_src1 << 4));\n                src0 = src0 >> 4;\n                src1 = src1 >> 4;\n              }\n            }\n          }\n          transpose_16x16_32bit((__m512i*)(offset_pointer(\n              b, (n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP) / 2)));\n          transpose_16x16_32bit(\n              (__m512i*)(offset_pointer(b, (n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size +\n                                            k_begin * N_STEP + TILE_N * B_K_STEP) /\n                                               2)));\n        }\n      }\n    }\n  }\n\n  void from_mat(ggml_bf16_t* src, int ith, int nth) {\n    auto [n_start, n_end] = K::split_range_n(n, ith, nth);\n    int n_block_begin = n_start;\n    int n_block_size = n_end - n_block_begin;\n    for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n      for (int i = 0; i < N_STEP; i++) {\n        for 
(int kg = 0; kg < k_group_count; kg++) {\n          int k_start = kg * k_group_size;\n          int k_end = k_start + k_group_size;\n\n          float amax = std::numeric_limits<float>::lowest();\n          float amin = std::numeric_limits<float>::max();\n          for (int j = k_start; j < k_end; j += 32) {\n            __m512 f0, f1;\n            avx512_32xbf16_to_32xfp32((__m512i*)(src + (n_block_begin + n_begin + i) * k + j), &f0, &f1);\n            amax = MAX(amax, _mm512_reduce_max_ps(f0));\n            amax = MAX(amax, _mm512_reduce_max_ps(f1));\n            amin = MIN(amin, _mm512_reduce_min_ps(f0));\n            amin = MIN(amin, _mm512_reduce_min_ps(f1));\n          }\n          d[kg * n + n_block_begin + n_begin + i] = (amax - amin) / 240.0;  // 15*16\n          // d[n_block_begin + n_begin + i] = (amax - amin) / 240.0; // 15*16\n          mins[kg * n + n_block_begin + n_begin + i] = amin;\n        }\n      }\n    }\n    for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n      for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {\n        int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n        for (int k_begin = 0; k_begin < k_block_size; k_begin += B_K_STEP) {\n          for (int i = 0; i < N_STEP; i++) {\n            int k_group_idx0 = (k_block_begin + k_begin) / k_group_size;\n            int k_group_idx1 = (k_block_begin + k_begin + K_STEP) / k_group_size;\n            float scale0 = d[k_group_idx0 * n + n_block_begin + n_begin + i];\n            float scale1 = d[k_group_idx1 * n + n_block_begin + n_begin + i];\n            __m512 id0 = _mm512_set1_ps(scale0 ? 1.0f / scale0 : 0.0f);\n            __m512 id1 = _mm512_set1_ps(scale1 ? 
1.0f / scale1 : 0.0f);\n            __m512 zps0 = _mm512_set1_ps(-mins[k_group_idx0 * n + n_block_begin + n_begin + i]);\n            __m512 zps1 = _mm512_set1_ps(-mins[k_group_idx1 * n + n_block_begin + n_begin + i]);\n            // __m512 zps = _mm512_set1_ps(-mins[n_block_begin + n_begin + i]);\n            dt* dst = offset_pointer(b, (n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size +\n                                         k_begin * N_STEP + i * B_K_STEP) /\n                                            2);\n            {\n              __m512 f0, f1, f2, f3;\n              avx512_32xbf16_to_32xfp32((__m512i*)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin),\n                                        &f0, &f1);\n              avx512_32xbf16_to_32xfp32(\n                  (__m512i*)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin) + 1, &f2, &f3);\n              __m512i i0 = _mm512_cvtps_epu32(_mm512_mul_ps(_mm512_add_ps(f0, zps0), id0));\n              __m512i i1 = _mm512_cvtps_epu32(_mm512_mul_ps(_mm512_add_ps(f1, zps0), id0));\n              __m512i i2 = _mm512_cvtps_epu32(_mm512_mul_ps(_mm512_add_ps(f2, zps0), id0));\n              __m512i i3 = _mm512_cvtps_epu32(_mm512_mul_ps(_mm512_add_ps(f3, zps0), id0));\n              __m128i s0 = _mm512_cvtusepi32_epi8(i0);\n              __m128i s1 = _mm512_cvtusepi32_epi8(i1);\n              __m128i s2 = _mm512_cvtusepi32_epi8(i2);\n              __m128i s3 = _mm512_cvtusepi32_epi8(i3);\n              s0 = _mm_srli_epi16(round_4bit_u8(s0), 4);\n              s1 = _mm_srli_epi16(round_4bit_u8(s1), 4);\n              s2 = _mm_srli_epi16(round_4bit_u8(s2), 4);\n              s3 = _mm_srli_epi16(round_4bit_u8(s3), 4);\n              // s0 = _mm_or_si128(round_up4(s0), _mm_srli_epi16(round_up4(s1), 4));\n              // s2 = _mm_or_si128(round_up4(s2), _mm_srli_epi16(round_up4(s3), 4));\n              _mm_store_si128((__m128i*)dst, s0);\n              
_mm_store_si128((__m128i*)(offset_pointer(dst, 16)), s1);\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 32)), s2);\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 48)), s3);\n            }\n\n            {\n              __m512 f0, f1, f2, f3;\n              avx512_32xbf16_to_32xfp32(\n                  (__m512i*)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin) + 2, &f0, &f1);\n              avx512_32xbf16_to_32xfp32(\n                  (__m512i*)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin) + 3, &f2, &f3);\n              __m512i i0 = _mm512_cvtps_epu32(_mm512_mul_ps(_mm512_add_ps(f0, zps1), id1));\n              __m512i i1 = _mm512_cvtps_epu32(_mm512_mul_ps(_mm512_add_ps(f1, zps1), id1));\n              __m512i i2 = _mm512_cvtps_epu32(_mm512_mul_ps(_mm512_add_ps(f2, zps1), id1));\n              __m512i i3 = _mm512_cvtps_epu32(_mm512_mul_ps(_mm512_add_ps(f3, zps1), id1));\n              __m128i s0 = _mm512_cvtusepi32_epi8(i0);\n              __m128i s1 = _mm512_cvtusepi32_epi8(i1);\n              __m128i s2 = _mm512_cvtusepi32_epi8(i2);\n              __m128i s3 = _mm512_cvtusepi32_epi8(i3);\n              s0 = round_4bit_u8(s0);\n              s1 = round_4bit_u8(s1);\n              s2 = round_4bit_u8(s2);\n              s3 = round_4bit_u8(s3);\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 0)),\n                              _mm_or_si128(_mm_loadu_si128((__m128i*)(offset_pointer(dst, 0))), s0));\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 16)),\n                              _mm_or_si128(_mm_loadu_si128((__m128i*)(offset_pointer(dst, 16))), s1));\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 32)),\n                              _mm_or_si128(_mm_loadu_si128((__m128i*)(offset_pointer(dst, 32))), s2));\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 48)),\n                              
_mm_or_si128(_mm_loadu_si128((__m128i*)(offset_pointer(dst, 48))), s3));\n            }\n          }\n          transpose_16x16_32bit((__m512i*)(offset_pointer(\n              b, (n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP) / 2)));\n          transpose_16x16_32bit(\n              (__m512i*)(offset_pointer(b, (n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size +\n                                            k_begin * N_STEP + TILE_N * B_K_STEP) /\n                                               2)));\n        }\n      }\n    }\n  }\n\n  dt* get_submat(int n, int k, int n_begin, int k_begin) {\n    int n_block_begin = n_begin / N_BLOCK * N_BLOCK;\n    n_begin -= n_block_begin;\n    int n_block_size = std::min(N_BLOCK, n - n_block_begin);\n    int k_block_begin = k_begin / K_BLOCK * K_BLOCK;\n    k_begin -= k_block_begin;\n    int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n    return offset_pointer(\n        b, (n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP) / 2);\n  }\n\n  float* get_scale(int n, int n_begin, int k, int k_begin) {\n    int k_group_idx = k_begin / k_group_size;\n    return d + k_group_idx * n + n_begin;\n  }\n  float* get_min(int n, int n_begin, int k, int k_begin) {\n    int k_group_idx = k_begin / k_group_size;\n    return mins + k_group_idx * n + n_begin;\n  }\n};\n\ntemplate <typename K>\nstruct BufferBInt4WithZeroLowKGroupImpl {\n  using dt = typename K::dt;\n  dt* b;\n  float *d, *mins;  // scale, mins\n  int n, k, k_group_size, k_group_count;\n\n  static constexpr int N_STEP = K::N_STEP;\n  static constexpr int N_BLOCK = K::N_BLOCK;\n  static constexpr int TILE_N = K::TILE_N;\n\n  static constexpr int K_STEP = K::K_STEP;\n  static constexpr int K_BLOCK = K::K_BLOCK;\n  static const int B_K_STEP = 2 * K_STEP;\n  static constexpr bool SCALE = true;\n\n  static size_t required_size(int n, int k, int k_group_size) {\n    
return sizeof(int8_t) * n * k / 2 + sizeof(float) * n * (k / k_group_size) * 2;\n  }\n\n  BufferBInt4WithZeroLowKGroupImpl(int n, int k, int k_group_size, void* ptr) : n(n), k(k), k_group_size(k_group_size) {\n    assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n    assert(n % N_STEP == 0);\n    assert(k % B_K_STEP == 0);\n    if (n % N_STEP || k % B_K_STEP || k % k_group_size) {\n      printf(\"n: %d, k: %d, N_STEP: %d, B_K_STEP: %d, k_group_size: %d\\n\", n, k, N_STEP, B_K_STEP, k_group_size);\n      throw std::runtime_error(\"n or k is not aligned to N_STEP or B_K_STEP\");\n    }\n    k_group_count = k / k_group_size;\n    b = reinterpret_cast<dt*>(ptr);\n    d = reinterpret_cast<float*>(offset_pointer(b, n * k / 2));\n    mins = d + n * k_group_count;\n  }\n\n  // Round each uint8_t lane to the nearest multiple of 16\n  static __m128i round_4bit_u8(__m128i x) {\n    // Add 8 for round-to-nearest; the lanes are unsigned, so use the unsigned saturating add\n    x = _mm_adds_epu8(x, _mm_set1_epi8(0x08));\n    // Clear the low 4 bits (i.e. align down to a multiple of 16)\n    x = _mm_and_si128(x, _mm_set1_epi8(0xF0));\n    return x;\n  }\n\n  void from_raw_mat(uint8_t* proj, int ith, int nth) {\n    auto [n_start, n_end] = K::split_range_n(n, ith, nth);\n    int n_block_begin = n_start;\n    int n_block_size = n_end - n_block_begin;\n\n    for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n      for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {\n        int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n        for (int k_begin = 0; k_begin < k_block_size; k_begin += B_K_STEP) {\n          for (int i = 0; i < N_STEP; i++) {\n            uint8_t* dst = (uint8_t*)offset_pointer(b, (n_block_begin * k + k_block_begin * n_block_size +\n                                                        n_begin * k_block_size + k_begin * N_STEP + i * B_K_STEP) >>\n                                                           1);\n            uint32_t* src =\n                (uint32_t*)offset_pointer(proj, ((n_block_begin + n_begin + i) * k + k_block_begin + 
k_begin) >> 1);\n            for (int a0 = 0; a0 < 8; a0++) {\n              uint32_t src0 = src[a0], src1 = src[a0 + 8];\n              for (int a1 = 0; a1 < 8; a1++) {\n                uint8_t cur_src0 = src0 & 0x0F, cur_src1 = src1 & 0x0F;\n                dst[(a0 * 8) + a1] = (cur_src0 | (cur_src1 << 4));\n                src0 = src0 >> 4;\n                src1 = src1 >> 4;\n              }\n            }\n          }\n          transpose_16x16_32bit((__m512i*)(offset_pointer(\n              b, (n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP) / 2)));\n          transpose_16x16_32bit(\n              (__m512i*)(offset_pointer(b, (n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size +\n                                            k_begin * N_STEP + TILE_N * B_K_STEP) /\n                                               2)));\n        }\n      }\n    }\n  }\n\n  void from_mat(ggml_bf16_t* src, int ith, int nth) {\n    auto [n_start, n_end] = K::split_range_n(n, ith, nth);\n    int n_block_begin = n_start;\n    int n_block_size = n_end - n_block_begin;\n    for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n      for (int i = 0; i < N_STEP; i++) {\n        for (int kg = 0; kg < k_group_count; kg++) {\n          int k_start = kg * k_group_size;\n          int k_end = k_start + k_group_size;\n\n          float amax = std::numeric_limits<float>::lowest();\n          float amin = std::numeric_limits<float>::max();\n          for (int j = k_start; j < k_end; j += 32) {\n            __m512 f0, f1;\n            avx512_32xbf16_to_32xfp32((__m512i*)(src + (n_block_begin + n_begin + i) * k + j), &f0, &f1);\n            amax = MAX(amax, _mm512_reduce_max_ps(f0));\n            amax = MAX(amax, _mm512_reduce_max_ps(f1));\n            amin = MIN(amin, _mm512_reduce_min_ps(f0));\n            amin = MIN(amin, _mm512_reduce_min_ps(f1));\n          }\n          d[kg * n + n_block_begin + n_begin + i] = 
(amax - amin) / 15.0;  // 15 steps: values quantize directly to [0, 15]\n          // d[n_block_begin + n_begin + i] = (amax - amin) / 240.0; // 15*16\n          mins[kg * n + n_block_begin + n_begin + i] = amin;\n        }\n      }\n    }\n\n    for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n      for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {\n        int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n        for (int k_begin = 0; k_begin < k_block_size; k_begin += B_K_STEP) {\n          for (int i = 0; i < N_STEP; i++) {\n            int k_group_idx0 = (k_block_begin + k_begin) / k_group_size;\n            int k_group_idx1 = (k_block_begin + k_begin + K_STEP) / k_group_size;\n            float scale0 = d[k_group_idx0 * n + n_block_begin + n_begin + i];\n            float scale1 = d[k_group_idx1 * n + n_block_begin + n_begin + i];\n            __m512 id0 = _mm512_set1_ps(scale0 ? 1.0f / scale0 : 0.0f);\n            __m512 id1 = _mm512_set1_ps(scale1 ? 1.0f / scale1 : 0.0f);\n            __m512 zps0 = _mm512_set1_ps(-mins[k_group_idx0 * n + n_block_begin + n_begin + i]);\n            __m512 zps1 = _mm512_set1_ps(-mins[k_group_idx1 * n + n_block_begin + n_begin + i]);\n            // __m512 zps = _mm512_set1_ps(-mins[n_block_begin + n_begin + i]);\n            dt* dst = offset_pointer(b, (n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size +\n                                         k_begin * N_STEP + i * B_K_STEP) /\n                                            2);\n            {\n              __m512 f0, f1, f2, f3;\n              avx512_32xbf16_to_32xfp32((__m512i*)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin),\n                                        &f0, &f1);\n              avx512_32xbf16_to_32xfp32(\n                  (__m512i*)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin) + 1, &f2, &f3);\n              __m512i i0 = _mm512_cvt_roundps_epu32(_mm512_mul_ps(_mm512_add_ps(f0, zps0), 
id0),\n                                                    _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);\n              __m512i i1 = _mm512_cvt_roundps_epu32(_mm512_mul_ps(_mm512_add_ps(f1, zps0), id0),\n                                                    _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);\n              __m512i i2 = _mm512_cvt_roundps_epu32(_mm512_mul_ps(_mm512_add_ps(f2, zps0), id0),\n                                                    _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);\n              __m512i i3 = _mm512_cvt_roundps_epu32(_mm512_mul_ps(_mm512_add_ps(f3, zps0), id0),\n                                                    _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);\n              __m128i s0 = _mm512_cvtusepi32_epi8(i0);\n              __m128i s1 = _mm512_cvtusepi32_epi8(i1);\n              __m128i s2 = _mm512_cvtusepi32_epi8(i2);\n              __m128i s3 = _mm512_cvtusepi32_epi8(i3);\n              // s0 = _mm_srli_epi16(s0, 4);\n              // s1 = _mm_srli_epi16(s1, 4);\n              // s2 = _mm_srli_epi16(s2, 4);\n              // s3 = _mm_srli_epi16(s3, 4);\n              // s0 = _mm_or_si128(round_up4(s0), _mm_srli_epi16(round_up4(s1), 4));\n              // s2 = _mm_or_si128(round_up4(s2), _mm_srli_epi16(round_up4(s3), 4));\n              _mm_store_si128((__m128i*)dst, s0);\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 16)), s1);\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 32)), s2);\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 48)), s3);\n            }\n\n            {\n              __m512 f0, f1, f2, f3;\n              avx512_32xbf16_to_32xfp32(\n                  (__m512i*)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin) + 2, &f0, &f1);\n              avx512_32xbf16_to_32xfp32(\n                  (__m512i*)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin) + 3, &f2, &f3);\n              __m512i i0 = 
_mm512_cvt_roundps_epu32(_mm512_mul_ps(_mm512_add_ps(f0, zps1), id1),\n                                                    _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);\n              __m512i i1 = _mm512_cvt_roundps_epu32(_mm512_mul_ps(_mm512_add_ps(f1, zps1), id1),\n                                                    _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);\n              __m512i i2 = _mm512_cvt_roundps_epu32(_mm512_mul_ps(_mm512_add_ps(f2, zps1), id1),\n                                                    _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);\n              __m512i i3 = _mm512_cvt_roundps_epu32(_mm512_mul_ps(_mm512_add_ps(f3, zps1), id1),\n                                                    _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);\n              __m128i s0 = _mm512_cvtusepi32_epi8(i0);\n              __m128i s1 = _mm512_cvtusepi32_epi8(i1);\n              __m128i s2 = _mm512_cvtusepi32_epi8(i2);\n              __m128i s3 = _mm512_cvtusepi32_epi8(i3);\n              s0 = _mm_slli_epi16(s0, 4);\n              s1 = _mm_slli_epi16(s1, 4);\n              s2 = _mm_slli_epi16(s2, 4);\n              s3 = _mm_slli_epi16(s3, 4);\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 0)),\n                              _mm_or_si128(_mm_loadu_si128((__m128i*)(offset_pointer(dst, 0))), s0));\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 16)),\n                              _mm_or_si128(_mm_loadu_si128((__m128i*)(offset_pointer(dst, 16))), s1));\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 32)),\n                              _mm_or_si128(_mm_loadu_si128((__m128i*)(offset_pointer(dst, 32))), s2));\n              _mm_store_si128((__m128i*)(offset_pointer(dst, 48)),\n                              _mm_or_si128(_mm_loadu_si128((__m128i*)(offset_pointer(dst, 48))), s3));\n            }\n          }\n          transpose_16x16_32bit((__m512i*)(offset_pointer(\n              b, (n_block_begin * k + k_block_begin * 
n_block_size + n_begin * k_block_size + k_begin * N_STEP) / 2)));\n          transpose_16x16_32bit(\n              (__m512i*)(offset_pointer(b, (n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size +\n                                            k_begin * N_STEP + TILE_N * B_K_STEP) /\n                                               2)));\n        }\n      }\n    }\n  }\n\n  dt* get_submat(int n, int k, int n_begin, int k_begin) {\n    int n_block_begin = n_begin / N_BLOCK * N_BLOCK;\n    n_begin -= n_block_begin;\n    int n_block_size = std::min(N_BLOCK, n - n_block_begin);\n    int k_block_begin = k_begin / K_BLOCK * K_BLOCK;\n    k_begin -= k_block_begin;\n    int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n    return offset_pointer(\n        b, (n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP) / 2);\n  }\n\n  float* get_scale(int n, int n_begin, int k, int k_begin) {\n    int k_group_idx = k_begin / k_group_size;\n    return d + k_group_idx * n + n_begin;\n  }\n  float* get_min(int n, int n_begin, int k, int k_begin) {\n    int k_group_idx = k_begin / k_group_size;\n    return mins + k_group_idx * n + n_begin;\n  }\n};\n\ntemplate <typename K>\nstruct BufferCImpl {\n  float* c;\n  int max_m, n;\n\n  static constexpr int M_STEP = K::M_STEP;\n  static constexpr int N_STEP = K::N_STEP;\n  static constexpr int N_BLOCK = K::N_BLOCK;\n\n  static size_t required_size(int max_m, int n) { return sizeof(float) * max_m * n; }\n\n  BufferCImpl(int max_m, int n, void* ptr) : max_m(max_m), n(n) {\n    assert(max_m % M_STEP == 0);\n    assert(n % N_STEP == 0);\n    if (max_m % M_STEP || n % N_STEP) {\n      printf(\"max_m = %d, n = %d, M_STEP = %d, N_STEP = %d\\n\", max_m, n, M_STEP, N_STEP);\n      throw std::runtime_error(\"BufferCImpl: max_m and n must be multiple of M_STEP and N_STEP\");\n    }\n    set_data(ptr);\n  }\n\n  void set_data(void* ptr) {\n    assert(reinterpret_cast<intptr_t>(ptr) % 64 
== 0);\n    c = reinterpret_cast<float*>(ptr);\n  }\n\n  void to_mat(int m, ggml_bf16_t* dst, int ith, int nth) {\n    assert(m <= max_m);\n    auto [n_start, n_end] = K::split_range_n(n, ith, nth);\n    int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n    int n_block_begin = n_start;\n    int n_block_size = n_end - n_block_begin;\n    for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {\n      for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n        for (int i = 0; i < M_STEP && m_begin + i < m; i++) {\n          __m512* x0 =\n              (__m512*)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP);\n          __m512* x1 =\n              (__m512*)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP + 16);\n          avx512_32xfp32_to_32xbf16(x0, x1, (__m512i*)(dst + (m_begin + i) * n + n_block_begin + n_begin));\n        }\n      }\n    }\n  }\n\n  float* get_submat(int m, int n, int m_begin, int n_begin) {\n    int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n    int n_block_begin = n_begin / N_BLOCK * N_BLOCK;\n    int n_block_size = std::min(N_BLOCK, n - n_block_begin);\n    n_begin -= n_block_begin;\n    return c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP;\n  }\n};\n\ntemplate <typename K>\nstruct BufferCReduceImpl {\n  float* c;\n  int32_t* int_c;  // Additional int32_t buffer, same size as c\n  int max_m, n;\n\n  static constexpr int M_STEP = K::M_STEP;\n  static constexpr int N_STEP = K::N_STEP;\n  static constexpr int N_BLOCK = K::N_BLOCK;\n\n  static size_t required_size(int max_m, int n) {\n    // Need space for both float* c and int32_t* int_c\n    return sizeof(float) * max_m * n + sizeof(int32_t) * max_m * n;\n  }\n\n  BufferCReduceImpl(int max_m, int n, void* ptr) : max_m(max_m), n(n) {\n    assert(max_m % M_STEP == 0);\n    assert(n % N_STEP == 0);\n    if (max_m % M_STEP || n % N_STEP) {\n      
printf(\"max_m = %d, n = %d, M_STEP = %d, N_STEP = %d\\n\", max_m, n, M_STEP, N_STEP);\n      throw std::runtime_error(\"BufferCReduceImpl: max_m and n must be multiple of M_STEP and N_STEP\");\n    }\n    set_data(ptr);\n  }\n\n  void set_data(void* ptr) {\n    assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n    c = reinterpret_cast<float*>(ptr);\n    // int_c starts after the float buffer\n    int_c = reinterpret_cast<int32_t*>(c + max_m * n);\n  }\n\n  void to_mat(int m, ggml_bf16_t* dst, int ith, int nth) {\n    assert(m <= max_m);\n    auto [n_start, n_end] = K::split_range_n(n, ith, nth);\n    int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n    int n_block_begin = n_start;\n    int n_block_size = n_end - n_block_begin;\n    for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {\n      for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n        for (int i = 0; i < M_STEP && m_begin + i < m; i++) {\n          __m512* x0 =\n              (__m512*)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP);\n          __m512* x1 =\n              (__m512*)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP + 16);\n          avx512_32xfp32_to_32xbf16(x0, x1, (__m512i*)(dst + (m_begin + i) * n + n_block_begin + n_begin));\n        }\n      }\n    }\n  }\n\n  float* get_submat(int m, int n, int m_begin, int n_begin) {\n    int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n    int n_block_begin = n_begin / N_BLOCK * N_BLOCK;\n    int n_block_size = std::min(N_BLOCK, n - n_block_begin);\n    n_begin -= n_block_begin;\n    return c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP;\n  }\n\n  int32_t* get_int_submat(int m, int n, int m_begin, int n_begin) {\n    int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n    int n_block_begin = n_begin / N_BLOCK * N_BLOCK;\n    int n_block_size = std::min(N_BLOCK, n - n_block_begin);\n    n_begin 
-= n_block_begin;\n    return int_c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP;\n  }\n\n  // Clear the int_c buffer\n  void clear_int_buffer() { std::memset(int_c, 0, sizeof(int32_t) * max_m * n); }\n\n  // Convert int32_t results to float\n  void convert_int_to_float(int m) {\n    assert(m <= max_m);\n    for (int i = 0; i < m * n; i++) {\n      c[i] = static_cast<float>(int_c[i]);\n    }\n  }\n};\n\n}  // namespace amx\n\n#endif  // AMX_BUFFERS_HPP"
  },
  {
    "path": "kt-kernel/operators/amx/la/amx_config.hpp",
    "content": "#ifndef AMX_CONFIG_HPP\n#define AMX_CONFIG_HPP\n#if (defined(_WIN32) || defined(_WIN64))\n#define RESTRICT __restrict\n#else\n#define RESTRICT __restrict__\n#endif\n\n#if (defined(_WIN32) || defined(_WIN64))\n#define ALWAYS_INLINE __forceinline\n#elif __has_attribute(always_inline) || defined(__GNUC__)\n#define ALWAYS_INLINE __attribute__((__always_inline__)) inline\n#else\n#define ALWAYS_INLINE inline\n#endif\n#include <immintrin.h>\n#if defined(__AMX__) || defined(__AMXINT8__) || defined(__AMXBF16__) || defined(__AMX_TILE__) || defined(HAVE_AMX)\n#ifndef HAVE_AMX\n#define HAVE_AMX\n#endif\n#include <emmintrin.h>\n#include <stdlib.h>\n#include <sys/syscall.h>\n#include <tmmintrin.h>\n#include <unistd.h>\n\n#include <array>\n#include <cassert>\n#include <cstdint>\n#include <cstdio>\n#include <stdexcept>\n\nnamespace amx {\n\n#define ARCH_GET_XCOMP_SUPP 0x1021\n#define ARCH_GET_XCOMP_PERM 0x1022\n#define ARCH_REQ_XCOMP_PERM 0x1023\n#define XFEATURE_XTILECFG 17\n#define XFEATURE_XTILEDATA 18\n#define XFEATURE_MASK_XTILECFG (1 << XFEATURE_XTILECFG)\n#define XFEATURE_MASK_XTILEDATA (1 << XFEATURE_XTILEDATA)\n#define XFEATURE_MASK_XTILE ((1 << XFEATURE_XTILECFG) | (1 << XFEATURE_XTILEDATA))\n\nconst int TMMCount = 8;\nconst int MaxTileHeight = 16;\nconst int MaxTileWidth = 64;\n\nconst int AMX_BLK_SIZE = 32;\n\n#define TMM0 0\n#define TMM1 1\n#define TMM2 2\n#define TMM3 3\n#define TMM4 4\n#define TMM5 5\n#define TMM6 6\n#define TMM7 7\n\ninline bool enable_amx() {\n  // CHECK: whether this can be removed?\n  // static thread_local bool initialized = false;\n  // if (initialized) {\n  //   return true;\n  // }\n  // initialized = true;\n\n  // if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {\n  //   printf(\"\\n Fail to do XFEATURE_XTILEDATA \\n\\n\");\n  //   return false;\n  // } else {\n  //   // printf(\"\\n TILE DATA USE SET - OK \\n\\n\");\n  //   return true;\n  // }\n  // return true;\n  unsigned long features;\n  long 
rc;\n  rc = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_SUPP, &features);\n\n  if (!rc && (features & XFEATURE_MASK_XTILE) == XFEATURE_MASK_XTILE) {\n    unsigned long bitmask = 0;\n    long status = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask);\n    if (0 != status) return false;\n    if (bitmask & XFEATURE_MASK_XTILEDATA) return true;\n\n    status = syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA);\n    if (0 != status) return false;  // XFEATURE_XTILEDATA setup is failed, TMUL usage is not allowed\n    status = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask);\n\n    // XFEATURE_XTILEDATA setup is failed, can't use TMUL\n    if (0 != status || !(bitmask & XFEATURE_MASK_XTILEDATA)) return false;\n\n    // XFEATURE_XTILEDATA set successfully, TMUL usage is allowed\n    // printf(\"\\n TILE DATA USE SET - OK \\n\\n\");\n    return true;\n  }\n  return false;\n}\n\nstruct alignas(64) TileConfig {\n  uint8_t palette;\n  uint8_t start_row;\n  std::array<uint8_t, 14> __0 = {};\n  std::array<uint16_t, 8> colsb;\n  std::array<uint8_t, 16> __1 = {};\n  std::array<uint8_t, 8> rows;\n  std::array<uint8_t, 8> __2 = {};\n\n  TileConfig() {\n    palette = 1;\n    start_row = 0;\n    for (int i = 0; i < 8; i++) {\n      set_row_col(i, 0, 0);\n    }\n  }\n\n  void set_row_col(int i, uint8_t row, uint16_t col) {\n    colsb[i] = col;\n    rows[i] = row;\n  }\n\n  void set_config() { _tile_loadconfig(this); }\n\n  static void load_data(int to, void* from, size_t stride) {\n    switch (to) {\n      case 0:\n        _tile_loadd(0, from, stride);\n        break;\n      case 1:\n        _tile_loadd(1, from, stride);\n        break;\n      case 2:\n        _tile_loadd(2, from, stride);\n        break;\n      case 3:\n        _tile_loadd(3, from, stride);\n        break;\n      case 4:\n        _tile_loadd(4, from, stride);\n        break;\n      case 5:\n        _tile_loadd(5, from, stride);\n        break;\n      case 6:\n        _tile_loadd(6, from, 
stride);\n        break;\n      case 7:\n        _tile_loadd(7, from, stride);\n        break;\n      default:\n        throw std::runtime_error(\"no such tile\");\n    }\n  }\n\n  static void store_data(int from, void* to, size_t stride) {\n    switch (from) {\n      case 0:\n        _tile_stored(0, to, stride);\n        break;\n      case 1:\n        _tile_stored(1, to, stride);\n        break;\n      case 2:\n        _tile_stored(2, to, stride);\n        break;\n      case 3:\n        _tile_stored(3, to, stride);\n        break;\n      case 4:\n        _tile_stored(4, to, stride);\n        break;\n      case 5:\n        _tile_stored(5, to, stride);\n        break;\n      case 6:\n        _tile_stored(6, to, stride);\n        break;\n      case 7:\n        _tile_stored(7, to, stride);\n        break;\n      default:\n        throw std::runtime_error(\"no such tile\");\n    }\n  }\n};\n\nstatic_assert(sizeof(TileConfig) == 64);\n\n}  // namespace amx\n#endif  // defined(__AMX__)\n#endif  // AMX_CONFIG_HPP"
  },
  {
    "path": "kt-kernel/operators/amx/la/amx_kernels.hpp",
    "content": "#ifndef AMX_KERNELS_HPP\n#define AMX_KERNELS_HPP\n#include <algorithm>\n#include <chrono>\n#include <cstdint>\n#include <cstdio>\n#include <memory>\n\n#include \"amx_buffers.hpp\"\n#include \"amx_config.hpp\"\n#include \"amx_quantization.hpp\"\n#include \"amx_utils.hpp\"\n#include \"llama.cpp/ggml-impl.h\"\n#include \"llama.cpp/ggml-quants.h\"\n#include \"llamafile/sgemm.h\"\n#include \"utils.hpp\"\n\nnamespace amx {\n\n// Compile-time detection: true when AMX intrinsics are available\n#if defined(__AMX__) || defined(__AMXINT8__) || defined(__AMXBF16__) || defined(__AMX_TILE__) || defined(HAVE_AMX)\ninline constexpr bool AMX_AVAILABLE = true;\n#ifndef HAVE_AMX\n#define HAVE_AMX\n#endif\n#else\ninline constexpr bool AMX_AVAILABLE = false;\n#endif\n\n/*\nWe use 1-3-3\n C = A x B\n\n\nA is a row major matrix of size M x K, usually an Linear Layer weight matrix\nB is a col major vector of size K x N, usually an input vector, N is usually\nquite small\n\n   B\n A C\n A C\n A C\n\n  TMM 0-2: A\n  TMM 3: B\n  TMM 4-6: C\n\n   3\n 0 4\n 1 5\n 2 6\n*/\n\ntemplate <class, class>\nstruct dpb133 {\n  static void run();\n};\n\ntemplate <>\ninline void dpb133<int8_t, int8_t>::run() {\n  _tile_dpbssd(4, 0, 3);\n  _tile_dpbssd(5, 1, 3);\n  _tile_dpbssd(6, 2, 3);\n}\n\ntemplate <>\ninline void dpb133<int8_t, uint8_t>::run() {\n  _tile_dpbsud(4, 0, 3);\n  _tile_dpbsud(5, 1, 3);\n  _tile_dpbsud(6, 2, 3);\n}\n\ntemplate <>\ninline void dpb133<uint8_t, int8_t>::run() {\n  _tile_dpbusd(4, 0, 3);\n  _tile_dpbusd(5, 1, 3);\n  _tile_dpbusd(6, 2, 3);\n}\n\ntemplate <>\ninline void dpb133<uint8_t, uint8_t>::run() {\n  _tile_dpbuud(4, 0, 3);\n  _tile_dpbuud(5, 1, 3);\n  _tile_dpbuud(6, 2, 3);\n}\n\ntemplate <int TILE_K = 32>\nstruct GemmKernel133 {\n  static const int TILE_M = 16;\n  static const int TILE_N = 16;\n  static const int VNNI_BLK = 4;\n  static const int OUTPUT_T_SIZE = 4;\n\n  static const int M_STEP = TILE_M * 3;\n  static const int N_STEP = TILE_N;\n  static 
const int K_STEP = TILE_K;\n\n  static int recommended_nth(int m) { return (m + M_STEP - 1) / M_STEP; }\n\n  static void config() {\n#ifdef HAVE_AMX\n    TileConfig tile_config;\n\n    for (int i = 0; i < 3; i++) tile_config.set_row_col(i, TILE_M, TILE_K);\n\n    tile_config.set_row_col(3, TILE_K / VNNI_BLK, TILE_N * VNNI_BLK);\n\n    for (int i = 4; i < 7; i++) tile_config.set_row_col(i, TILE_M, TILE_N * OUTPUT_T_SIZE);\n\n    tile_config.set_config();\n#endif\n  }\n\n  template <typename TA, typename TB, typename TC>\n  static void run_full_tile(const TA* a, size_t lda, const TB* b, size_t ldb, TC* c, size_t ldc) {\n#ifdef HAVE_AMX\n    _tile_loadd(0, a, lda);\n    _tile_loadd(1, offset_pointer(a, lda * TILE_M), lda);\n    _tile_loadd(2, offset_pointer(a, lda * TILE_M * 2), lda);\n\n    _tile_loadd(3, b, ldb);\n\n    _tile_loadd(4, c, ldc);\n    _tile_loadd(5, offset_pointer(c, ldc * TILE_N), ldc);\n    _tile_loadd(6, offset_pointer(c, ldc * TILE_N * 2), ldc);\n\n    dpb133<TA, TB>::run();\n\n    _tile_stored(4, c, ldc);\n    _tile_stored(5, offset_pointer(c, ldc * TILE_N), ldc);\n    _tile_stored(6, offset_pointer(c, ldc * TILE_N * 2), ldc);\n#endif\n  }\n\n  template <typename TA, typename TB, typename TC>\n  static void run_full_tile_zero(const TA* a, size_t lda, const TB* b, size_t ldb, TC* c, size_t ldc) {\n#ifdef HAVE_AMX\n    _tile_loadd(0, a, lda);\n    _tile_loadd(1, offset_pointer(a, lda * TILE_M), lda);\n    _tile_loadd(2, offset_pointer(a, lda * TILE_M * 2), lda);\n\n    _tile_loadd(3, b, ldb);\n\n    _tile_zero(4);\n    _tile_zero(5);\n    _tile_zero(6);\n\n    dpb133<TA, TB>::run();\n\n    // debug_tiles(7);\n\n    _tile_stored(4, c, ldc);\n    _tile_stored(5, offset_pointer(c, ldc * TILE_N), ldc);\n    _tile_stored(6, offset_pointer(c, ldc * TILE_N * 2), ldc);\n#endif\n  }\n\n  static void convert_full_tile_b_to_vnni_inplace(void* b) { transpose_16x8_32bit((__m256i*)b); }\n\n  template <typename TA>\n  struct ATile {\n    TA v[3 * TILE_M * 
TILE_K];\n    void partial_load(TA* a, int m, int k, size_t lda) {\n      // memset(v, 0, sizeof(TA) * 3 * TILE_M * TILE_K);\n      for (int i = 0; i < m; i++) {\n        for (int j = 0; j < k; j++) {\n          v[i * TILE_K + j] = a[i * lda + j];\n        }\n      }\n    }\n\n    void partial_load_quant(block_q4_0* a, int m, int k, size_t lda) {\n      assert(k == 32);\n      // memset(v, 0, sizeof(TA) * 3 * TILE_M * TILE_K);\n      __m256i* vv = (__m256i*)v;\n      for (int i = 0; i < m; i++) {\n        vv[i] = dequant4x32(offset_pointer(a, lda * i)->qs);\n        vv[i] = _mm256_sub_epi8(vv[i], _mm256_set1_epi8(8));\n      }\n    }\n\n    void partial_load_quant(block_q8_0* a, int m, int k, size_t lda) {\n      assert(k == 32);\n      // memset(v, 0, sizeof(TA) * 3 * TILE_M * TILE_K);\n      __m256i* vv = (__m256i*)v;\n      for (int i = 0; i < m; i++) {\n        vv[i] = unaligned_copy8x32(offset_pointer(a, lda * i)->qs);\n      }\n    }\n\n    template <typename QA>\n    void partial_load_quant(TA* a, int m, size_t lda) {\n      // memset(v, 0, sizeof(TA) * 3 * TILE_M * TILE_K);\n      if constexpr (std::is_same_v<QA, blocks_aligned_q8_0_ref>) {\n        __m512i* vv = (__m512i*)v;\n        for (int i = 0; i < m; i++) {\n          vv[i] = copy8x64(offset_pointer(a, lda * i));\n        }\n      } else if constexpr (std::is_same_v<QA, blocks_aligned_q4_0_ref>) {\n        assert(0);\n      } else {\n        assert(0);\n      }\n    }\n\n    void partial_load_quant(block_q4_K* a, int m, int inner_block_idx, size_t lda) {\n      // memset(v, 0, sizeof(TA) * 3 * TILE_M * TILE_K);\n      __m256i* vv = (__m256i*)v;\n\n      size_t qs_offset = inner_block_idx / 2 * 32;\n      for (int i = 0; i < m; i++) {\n        block_q4_K* spa = offset_pointer_row_major(a, i, 0, lda);\n        if (inner_block_idx % 2 == 0) {\n          vv[i] = lo4bit(spa->qs + qs_offset);\n        } else {\n          vv[i] = hi4bit(spa->qs + qs_offset);\n        }\n      }\n    }\n\n    void 
partial_load_quant(blocks_aligned_q8_0_ref a, int m, int k, int blck_stride) {\n      // memset(v, 0, sizeof(TA) * 3 * TILE_M * TILE_K);\n      __m512i* vv = (__m512i*)v;\n      for (int i = 0; i < m; i++) {\n        vv[i] = copy8x64(a.offset(blck_stride * i).qs);\n      }\n    }\n  };\n\n  template <typename TB>\n  struct alignas(64) BTile {\n    TB v[TILE_N * TILE_K];\n    __m512 scale = {};\n\n    void partial_load(TB* b, int n, int k, size_t ldb) {\n      for (int i = 0; i < n; i++) {\n        for (int j = 0; j < k; j++) {\n          v[i * TILE_K + j] = b[i * ldb + j];\n        }\n      }\n      transpose_16x8_32bit((__m256i*)v);\n    }\n\n    void partial_load_quant(block_q8_0* b, int n, int k, size_t ldb) {\n      assert(k == 32);\n      memset(v, 0, sizeof(TB) * TILE_K * TILE_N);\n      __m256i* vv = (__m256i*)v;\n      float* bss = reinterpret_cast<float*>(&scale);\n      for (int i = 0; i < n; i++) {\n        vv[i] = unaligned_copy8x32(offset_pointer(b, ldb * i)->qs);\n        float sb = GGML_FP16_TO_FP32(offset_pointer_col_major(b, 0, i, ldb)->d);\n        bss[i] = sb;\n      }\n\n      transpose_16x8_32bit(vv);\n    }\n\n    void partial_load_quant(blocks_aligned_q8_0_ref b, int n, int k, int blck_stride) {\n      assert(k == 64);\n      memset(v, 0, sizeof(TB) * TILE_K * TILE_N);\n      __m512i* vv = (__m512i*)v;\n      float* vs = reinterpret_cast<float*>(&scale);\n      for (int i = 0; i < n; i++) {\n        auto ref = b.offset(blck_stride * i);\n        vv[i] = copy8x64(ref.qs);\n        float sb = GGML_FP16_TO_FP32(*ref.d);\n        vs[i] = sb;\n      }\n      transpose_16x16_32bit(vv);\n    }\n\n    void load_from(TB* b, size_t ldb) {\n      __m256i* vb = (__m256i*)b;\n      __m256i* vo = (__m256i*)v;\n      for (int i = 0; i < 16; i++) {\n        vo[i] = *offset_pointer(vb, ldb * i);\n      }\n      transpose_16x8_32bit(vo);\n    }\n\n    template <typename TA, typename TC>\n    void run_full_ac(TA* a, size_t lda, TC* c, size_t ldc) {\n      
run_full_tile(a, lda, v, TILE_N * VNNI_BLK, c, ldc);\n    }\n  };\n\n  template <typename TB>\n  struct alignas(64) BTileSum {\n    TB v[TILE_N * TILE_K];\n    __m512 scale = {};\n    __m512 sum = {};\n    void partial_load_quant(block_q8_K* b, int n, int inner_block_idx, size_t ldb) {\n      memset(v, 0, sizeof(TB) * TILE_K * TILE_N);\n      __m256i* vv = (__m256i*)v;\n      float* scale_s = reinterpret_cast<float*>(&scale);\n      float* sum_s = reinterpret_cast<float*>(&sum);\n      for (int i = 0; i < n; i++) {\n        block_q8_K* spb = offset_pointer_col_major(b, 0, i, ldb);\n        vv[i] = unaligned_copy8x32(spb->qs + inner_block_idx * 32);\n        scale_s[i] = spb->d;\n        sum_s[i] =\n            spb->bsums[inner_block_idx * 2] + spb->bsums[inner_block_idx * 2 + 1];  // TODO: this may be slow\n        // printf(\"scale[%d] = %f, sum_s[%d] = %f\\n\", i, scale_s[i], i,\n        // sum_s[i]);\n      }\n      transpose_16x8_32bit(vv);\n    }\n  };\n  template <typename TC>\n  struct alignas(64) CTile {\n    static_assert(sizeof(TC) == 4);\n    TC v[3 * TILE_M * TILE_N] = {};\n\n    void partial_load(TC* c, int m, int n, size_t ldc) {\n      for (int i = 0; i < m; i++) {\n        for (int j = 0; j < n; j++) {\n          v[i * TILE_N + j] = offset_pointer(c, ldc * i)[j];\n        }\n      }\n    }\n\n    void partial_store(TC* c, int m, int n, size_t ldc) {\n      for (int i = 0; i < m; i++) {\n        for (int j = 0; j < n; j++) {\n          offset_pointer(c, ldc * i)[j] = v[i * TILE_N + j];\n        }\n      }\n    }\n\n    void to_fp32() {\n      __m512i* vv = (__m512i*)v;\n      __m512* vf = (__m512*)v;\n      for (int i = 0; i < 3 * TILE_M; i++) {\n        vf[i] = _mm512_cvtepi32_ps(vv[i]);\n      }\n    }\n  };\n\n  template <typename TA, typename TB, typename TC>\n  struct PartialTiles {\n    ATile<TA> ta;\n    BTile<TB> tb;\n    CTile<TC> tc;\n    void partial_run(int m, int n, int k, TA* a, size_t lda, TB* b, size_t ldb, TC* c, size_t ldc) {\n      
ta.partial_load(a, m, k, lda);\n      tb.partial_load(b, n, k, ldb);\n      tc.partial_load(c, m, n, ldc);\n      run_full_tile(ta.v, TILE_K, tb.v, TILE_N * VNNI_BLK, tc.v, TILE_N * OUTPUT_T_SIZE);\n      tc.partial_store(c, m, n, ldc);\n    }\n\n    template <typename QA>\n    void partial_run_quant(int m, int n, int k, QA* a, size_t lda, block_q8_0* b, size_t ldb, float* c, size_t ldc) {\n      assert(QK4_0 == 32);\n      assert(QK8_0 == 32);\n\n      ta.partial_load_quant(a, m, k, lda);\n      tb.partial_load_quant(b, n, k, ldb);\n\n      run_full_tile_zero(ta.v, TILE_K, tb.v, TILE_N * VNNI_BLK, tc.v, TILE_N * OUTPUT_T_SIZE);\n\n      __m512i* cs = (__m512i*)tc.v;\n      for (int i = 0; i < m; i++) {\n        __m512 as = _mm512_set1_ps(GGML_FP16_TO_FP32(offset_pointer_row_major(a, i, 0, lda)->d));\n        __m512* now = reinterpret_cast<__m512*>(offset_pointer_row_major(c, i, 0, ldc));\n        *now = _mm512_fmadd_ps(_mm512_mul_ps(as, tb.scale), _mm512_cvtepi32_ps(cs[i]), *now);\n      }\n    }\n\n    template <typename QA>\n    void partial_run_quant_ac(int m, int n, int k, QA* a, size_t lda, float* c, size_t ldc) {\n      assert(QK4_0 == 32);\n      assert(QK8_0 == 32);\n\n      ta.partial_load_quant(a, m, k, lda);\n\n      run_full_tile_zero(ta.v, TILE_K, tb.v, TILE_N * VNNI_BLK, tc.v, TILE_N * OUTPUT_T_SIZE);\n\n      __m512i* cs = (__m512i*)tc.v;\n      for (int i = 0; i < m; i++) {\n        __m512 as = _mm512_set1_ps(GGML_FP16_TO_FP32(offset_pointer_row_major(a, i, 0, lda)->d));\n        __m512* now = reinterpret_cast<__m512*>(offset_pointer_row_major(c, i, 0, ldc));\n        *now = _mm512_fmadd_ps(_mm512_mul_ps(as, tb.scale), _mm512_cvtepi32_ps(cs[i]), *now);\n      }\n    }\n\n    template <typename AQA>\n    void partial_run_quant_ac(int m, int n, int k, AQA a, int a_blck_stride, float* c, size_t ldc) {\n      assert(AQA::block_size == 64);\n\n      ta.partial_load_quant(a, m, k, a_blck_stride);\n\n      run_full_tile_zero(ta.v, TILE_K, tb.v, TILE_N * 
VNNI_BLK, tc.v, TILE_N * OUTPUT_T_SIZE);\n\n      __m512i* cs = (__m512i*)tc.v;\n      for (int i = 0; i < m; i++) {\n        __m512 as = _mm512_set1_ps(GGML_FP16_TO_FP32(*a.offset(i * a_blck_stride).d));\n        // printf(\"%f\\n\", GGML_FP16_TO_FP32(*a.offset(i * a_blck_stride).d));\n        __m512* now = reinterpret_cast<__m512*>(offset_pointer_row_major(c, i, 0, ldc));\n        *now = _mm512_fmadd_ps(_mm512_mul_ps(as, tb.scale), _mm512_cvtepi32_ps(cs[i]), *now);\n      }\n    }\n  };\n\n  template <typename TA, typename TB, typename TC>\n  struct PartialTilesSum {\n    ATile<TA> ta;\n    BTileSum<TB> tb;\n    CTile<TC> tc;\n\n    void partial_run_quant_ac(int m, int n, int inner_block_idx, block_q4_K* a, size_t lda, float* c, size_t ldc,\n                              float a_scale, float a_min) {\n      ta.partial_load_quant(a, m, inner_block_idx, lda);\n\n      run_full_tile_zero(ta.v, TILE_K, tb.v, TILE_N * VNNI_BLK, tc.v, TILE_N * OUTPUT_T_SIZE);\n\n      __m512i* cs = (__m512i*)tc.v;\n      for (int i = 0; i < m; i++) {\n        __m512* now = reinterpret_cast<__m512*>(offset_pointer_row_major(c, i, 0, ldc));\n        *now = _mm512_fmadd_ps(_mm512_sub_ps(_mm512_mul_ps(_mm512_cvtepi32_ps(cs[i]), _mm512_set1_ps(a_scale)),\n                                             _mm512_mul_ps(tb.sum, _mm512_set1_ps(a_min))),\n                               tb.scale, *now);\n        // C += Bscale * (Ascale * dp - Amin * Bsum)\n      }\n    }\n  };\n};\n\nstruct GemmKernel133BF {\n  using dt = ggml_bf16_t;\n  using output_t = float;\n  static const int TILE_M = 16;\n  static const int TILE_K = 32;\n  static const int TILE_N = 16;\n  static const int VNNI_BLK = 2;\n\n  static const int M_STEP = TILE_M * 3;\n  static const int N_STEP = TILE_N;\n  static const int K_STEP = TILE_K;\n\n  static int recommended_nth(int m) { return (m + M_STEP - 1) / M_STEP; }\n  static void config() {\n#ifdef HAVE_AMX\n    enable_amx();\n    TileConfig tile_config;\n\n    // size is 16 x 32\n  
  for (int i = 0; i < 3; i++) tile_config.set_row_col(i, TILE_M, TILE_K * sizeof(dt));\n\n    // size is 16 rows x 64 bytes\n    tile_config.set_row_col(3, TILE_K / VNNI_BLK, TILE_N * VNNI_BLK * sizeof(dt));\n\n    // size is 16 rows x 64 bytes\n    for (int i = 4; i < 7; i++) tile_config.set_row_col(i, TILE_M, TILE_N * sizeof(output_t));\n\n    tile_config.set_config();\n#endif\n  }\n\n  static void run_full_tile(const dt* a, size_t lda, const dt* b, size_t ldb, output_t* c, size_t ldc) {\n#ifdef HAVE_AMX\n    _tile_loadd(0, a, lda);\n    _tile_loadd(1, offset_pointer(a, lda * TILE_M), lda);\n    _tile_loadd(2, offset_pointer(a, lda * TILE_M * 2), lda);\n\n    _tile_loadd(3, b, ldb);\n\n    _tile_loadd(4, c, ldc);\n    _tile_loadd(5, offset_pointer(c, ldc * TILE_N), ldc);\n    _tile_loadd(6, offset_pointer(c, ldc * TILE_N * 2), ldc);\n\n    _tile_dpbf16ps(4, 0, 3);\n    _tile_dpbf16ps(5, 1, 3);\n    _tile_dpbf16ps(6, 2, 3);\n\n    // debug_tiles(7);\n\n    _tile_stored(4, c, ldc);\n    _tile_stored(5, offset_pointer(c, ldc * TILE_N), ldc);\n    _tile_stored(6, offset_pointer(c, ldc * TILE_N * 2), ldc);\n#endif\n  }\n\n  struct ATile {\n    dt v[3 * TILE_M * TILE_K];\n\n    void partial_load(dt* a, int m, int k, size_t lda) {\n      assert(k == TILE_K);\n      __m512* vv = (__m512*)v;\n      __m512* va = (__m512*)a;\n      for (int i = 0; i < m; i++) {\n        vv[i] = *offset_pointer_row_major(va, i, 0, lda);\n      }\n    }\n  };\n\n  struct alignas(64) BTile {\n    dt v[TILE_N * TILE_K];\n\n    void full_load(dt* b, size_t ldb) { partial_load(b, TILE_N, TILE_K, ldb); }\n\n    void partial_load(dt* b, int n, int k, size_t ldb) {\n      __m512* vv = (__m512*)v;\n      __m512* vb = (__m512*)b;\n      for (int i = 0; i < n; i++) {\n        vv[i] = *offset_pointer_col_major(vb, 0, i, ldb);\n      }\n      transpose_16x16_32bit((__m512i*)v);\n    }\n\n    template <typename TA, typename TC>\n    void run_full_ac(TA* a, size_t lda, TC* c, size_t ldc) {\n      run_full_tile(a, lda, v, TILE_N 
* VNNI_BLK * sizeof(dt), c, ldc);\n    }\n  };\n\n  struct alignas(64) CTile {\n    output_t v[3 * TILE_M * TILE_N];\n    // c must be 64 aligned, ldc must be 64 aligned\n    void partial_load(float* c, int m, int n, size_t ldc) {\n      assert(n <= TILE_N);\n      __m512* vv = (__m512*)v;\n      __m512* vc = (__m512*)c;\n      for (int i = 0; i < m; i++) {\n        vv[i] = *offset_pointer_row_major(vc, i, 0, ldc);\n      }\n    }\n\n    void partial_store(float* c, int m, int n, size_t ldc) {\n      assert(n <= TILE_N);\n      __m512* vv = (__m512*)v;\n      __m512* vc = (__m512*)c;\n      for (int i = 0; i < m; i++) {\n        *offset_pointer_row_major(vc, i, 0, ldc) = vv[i];\n      }\n    }\n  };\n\n  struct PartialTiles {\n    ATile ta;\n    BTile tb;\n    CTile tc;\n    void partial_run(int m, int n, int k, dt* a, size_t lda, dt* b, size_t ldb, output_t* c, size_t ldc) {\n      ta.partial_load(a, m, k, lda);\n      tb.partial_load(b, n, k, ldb);\n      tc.partial_load(c, m, n, ldc);\n      run_full_tile(ta.v, TILE_K * sizeof(dt), tb.v, TILE_N * VNNI_BLK * sizeof(dt), tc.v, TILE_N * sizeof(output_t));\n      tc.partial_store(c, m, n, ldc);\n    }\n  };\n};\n\ntemplate <typename T1, typename T2>\nconstexpr T2 convert_to(const T1& value) {\n  if constexpr (std::is_same<T1, T2>::value) {\n    return value;\n  } else if constexpr (std::is_same<T1, ggml_bf16_t>::value && std::is_same<T2, float>::value) {\n    return GGML_BF16_TO_FP32(value);\n  } else if constexpr (std::is_same<T1, float>::value && std::is_same<T2, ggml_bf16_t>::value) {\n    return GGML_FP32_TO_BF16(value);\n  }\n}\n\nstruct GemmKernel224BF {\n  using dt = ggml_bf16_t;\n  using output_t = float;\n  static constexpr double ELEMENT_SIZE = 2;\n  static const int TILE_M = 16;\n  static const int TILE_K = 32;\n  static const int TILE_N = 16;\n  static const int VNNI_BLK = 2;\n\n  static const int M_STEP = TILE_M * 2;\n  static const int N_STEP = TILE_N * 2;\n  static const int K_STEP = TILE_K;\n\n  
static inline const int N_BLOCK = 256;\n  static inline const int K_BLOCK = 1792;\n  static std::string name() { return \"BF16\"; }\n\n  static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }\n\n  static std::pair<int, int> split_range_n(int n, int ith, int nth) {\n    int n_start = N_BLOCK * ith;\n    int n_end = std::min(n, N_BLOCK * (ith + 1));\n    return {n_start, n_end};\n  }\n\n  static void config() {\n#ifdef HAVE_AMX\n    enable_amx();\n    TileConfig tile_config;\n\n    // size is 16 x 32\n    for (int i = 0; i < 2; i++) tile_config.set_row_col(i, TILE_M, TILE_K * sizeof(dt));\n\n    // size is 16 x 32\n    for (int i = 2; i < 4; i++) tile_config.set_row_col(i, TILE_K / VNNI_BLK, TILE_N * VNNI_BLK * sizeof(dt));\n\n    // size is 16 x 16\n    for (int i = 4; i < 8; i++) tile_config.set_row_col(i, TILE_M, TILE_N * sizeof(output_t));\n\n    tile_config.set_config();\n#endif\n  }\n\n  static void load_a(dt* a, size_t lda) {\n#ifdef HAVE_AMX\n    _tile_loadd(0, a, lda);\n    _tile_loadd(1, offset_pointer(a, lda * TILE_M), lda);\n#else\n    (void)a;\n    (void)lda;\n#endif\n  }\n\n  static void load_b(dt* b, size_t ldb) {\n#ifdef HAVE_AMX\n    _tile_loadd(2, b, ldb);\n    _tile_loadd(3, offset_pointer(b, ldb * TILE_N), ldb);\n#else\n    (void)b;\n    (void)ldb;\n#endif\n  }\n\n  static void clean_c() {\n#ifdef HAVE_AMX\n    _tile_zero(4);\n    _tile_zero(5);\n    _tile_zero(6);\n    _tile_zero(7);\n#endif\n  }\n\n  static void load_c(output_t* c, size_t ldc) {\n#ifdef HAVE_AMX\n    _tile_loadd(4, c, ldc);\n    _tile_loadd(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);\n    _tile_loadd(6, offset_pointer(c, ldc * TILE_M), ldc);\n    _tile_loadd(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc);\n#else\n    (void)c;\n    (void)ldc;\n#endif\n  }\n\n  static void store_c(output_t* c, size_t ldc) {\n#ifdef HAVE_AMX\n    _tile_stored(4, c, ldc);\n    _tile_stored(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);\n    
_tile_stored(6, offset_pointer(c, ldc * TILE_M), ldc);\n    _tile_stored(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc);\n#else\n    (void)c;\n    (void)ldc;\n#endif\n  }\n\n  static void run_tile() {\n#ifdef HAVE_AMX\n    _tile_dpbf16ps(4, 0, 2);\n    _tile_dpbf16ps(5, 0, 3);\n    _tile_dpbf16ps(6, 1, 2);\n    _tile_dpbf16ps(7, 1, 3);\n#endif\n  }\n\n  struct BufferA {\n    ggml_bf16_t* a;\n    int max_m, k;\n\n    static size_t required_size(int max_m, int k) { return sizeof(ggml_bf16_t) * max_m * k; }\n\n    BufferA(int max_m, int k, void* ptr) : max_m(max_m), k(k) {\n      assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n      assert(max_m % M_STEP == 0);\n      assert(k % K_STEP == 0);\n      a = reinterpret_cast<ggml_bf16_t*>(ptr);\n    }\n\n    void set_data(void* new_ptr) { a = reinterpret_cast<ggml_bf16_t*>(new_ptr); }\n\n    void from_mat(int m, ggml_bf16_t* src, int ith, int nth) {\n      assert(m <= max_m);\n      assert(ith == 0 && nth == 1);\n      int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n      for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {\n        for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {\n          int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n          for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {\n            for (int i = 0; i < M_STEP && m_begin + i < m; i++) {\n              __m512i* s = (__m512i*)(src + (m_begin + i) * k + k_block_begin + k_begin);\n              __m512i* d =\n                  (__m512i*)(a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP + i * K_STEP);\n              avx512_copy_32xbf16(s, d);\n            }\n          }\n        }\n      }\n    }\n\n    ggml_bf16_t* get_submat(int m, int k, int m_begin, int k_begin) {\n      int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n      int k_block_begin = k_begin / K_BLOCK * K_BLOCK;\n      k_begin -= k_block_begin;\n      int 
k_block_size = std::min(K_BLOCK, k - k_block_begin);\n      return a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP;\n    }\n  };\n\n  struct BufferB {\n    ggml_bf16_t* b;\n    int n, k;\n    static constexpr bool SCALE = false;\n\n    static size_t required_size(int n, int k) { return sizeof(ggml_bf16_t) * n * k; }\n\n    BufferB(int n, int k, void* ptr) : n(n), k(k) {\n      assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n      assert(n % N_STEP == 0);\n      assert(k % K_STEP == 0);\n      b = reinterpret_cast<ggml_bf16_t*>(ptr);\n    }\n\n    void set_data(void* new_ptr) { b = reinterpret_cast<ggml_bf16_t*>(new_ptr); }\n\n    void from_mat(ggml_bf16_t* src, int ith, int nth) {\n      auto [n_start, n_end] = split_range_n(n, ith, nth);\n      int n_block_begin = n_start;\n      int n_block_size = n_end - n_block_begin;\n      for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n        for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {\n          int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n          for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {\n            for (int i = 0; i < N_STEP; i++) {\n              __m512i* s = (__m512i*)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin);\n              __m512i* d = (__m512i*)(b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size +\n                                      k_begin * N_STEP + i * K_STEP);\n              avx512_copy_32xbf16(s, d);\n            }\n            transpose_16x16_32bit((__m512i*)(b + n_block_begin * k + k_block_begin * n_block_size +\n                                             n_begin * k_block_size + k_begin * N_STEP));\n            transpose_16x16_32bit((__m512i*)(b + n_block_begin * k + k_block_begin * n_block_size +\n                                             n_begin * k_block_size + k_begin * N_STEP + TILE_N * K_STEP));\n          }\n        
}\n      }\n    }\n\n    ggml_bf16_t* get_submat(int n, int k, int n_begin, int k_begin) {\n      int n_block_begin = n_begin / N_BLOCK * N_BLOCK;\n      n_begin -= n_block_begin;\n      int n_block_size = std::min(N_BLOCK, n - n_block_begin);\n      int k_block_begin = k_begin / K_BLOCK * K_BLOCK;\n      k_begin -= k_block_begin;\n      int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n      return b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP;\n    }\n  };\n\n  struct BufferC {\n    float* c;\n    int max_m, n;\n    // Physical layout (in units of float elements):\n    // The logical matrix C is (max_m, n) row-major; max_m is a multiple of M_STEP,\n    // and n is partitioned into blocks of N_BLOCK.\n    // Storage order:\n    //   n_block (N_BLOCK columns) -> m_block (M_STEP rows) -> n_step (N_STEP columns) -> row-major (M_STEP x N_STEP) tile.\n    // It can therefore be viewed as 5D:\n    //   c[n_blocks][m_blocks][n_steps][M_STEP][N_STEP],\n    //   n_blocks = ceil(n / N_BLOCK), m_blocks = max_m / M_STEP,\n    //   n_steps = N_BLOCK / N_STEP (the tail block may be smaller).\n    // get_submat(m_begin, n_begin) returns the start address of a contiguous (M_STEP x N_STEP) tile.\n\n    static size_t required_size(int max_m, int n) { return sizeof(float) * max_m * n; }\n\n    BufferC(int max_m, int n, void* ptr) : max_m(max_m), n(n) {\n      assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n      assert(max_m % M_STEP == 0);\n      assert(n % N_STEP == 0);\n      c = reinterpret_cast<float*>(ptr);\n    }\n\n    void set_data(void* new_ptr) { c = reinterpret_cast<float*>(new_ptr); }\n\n    void to_mat(int m, ggml_bf16_t* dst, int ith, int nth) {\n      assert(m <= max_m);\n      auto [n_start, n_end] = split_range_n(n, ith, nth);\n      int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n      int n_block_begin = n_start;\n      int n_block_size = n_end - n_block_begin;\n      for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {\n        for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n          for (int i = 0; i < M_STEP && m_begin + i < m; i++) {\n            __m512* x0 =\n           
     (__m512*)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP);\n            __m512* x1 = (__m512*)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP +\n                                   i * N_STEP + 16);\n            avx512_32xfp32_to_32xbf16(x0, x1, (__m512i*)(dst + (m_begin + i) * n + n_block_begin + n_begin));\n          }\n        }\n      }\n    }\n\n    float* get_submat(int m, int n, int m_begin, int n_begin) {\n      int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n      int n_block_begin = n_begin / N_BLOCK * N_BLOCK;\n      int n_block_size = std::min(N_BLOCK, n - n_block_begin);\n      n_begin -= n_block_begin;\n      return c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP;\n    }\n  };\n};\n\nstruct GemmKernel224Int8 {\n  using dt = int8_t;\n  using output_t = int32_t;\n  static constexpr double ELEMENT_SIZE = 1;\n  static const int TILE_M = 16;\n  static const int TILE_K = 64;\n  static const int TILE_N = 16;\n  static const int VNNI_BLK = 4;\n\n  static const int M_STEP = TILE_M * 2;\n  static const int N_STEP = TILE_N * 2;\n  static const int K_STEP = TILE_K;\n\n  // static inline const int N_BLOCK = 256;\n  static inline const int N_BLOCK = 64;\n  // static inline const int N_BLOCK = 32;\n  static inline const int K_BLOCK = 3584;\n  static std::string name() { return \"INT8\"; }\n\n  static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }\n\n  static std::pair<int, int> split_range_n(int n, int ith, int nth) {\n    int n_start = N_BLOCK * ith;\n    int n_end = std::min(n, N_BLOCK * (ith + 1));\n    return {n_start, n_end};\n  }\n\n  static void config() {\n#ifdef HAVE_AMX\n    enable_amx();\n    TileConfig tile_config;\n\n    // size is 16 x 64\n    for (int i = 0; i < 2; i++) tile_config.set_row_col(i, TILE_M, TILE_K * sizeof(dt));\n\n    // size is 16 x 64\n    for (int i = 2; i < 4; i++) tile_config.set_row_col(i, TILE_K / 
VNNI_BLK, TILE_N * VNNI_BLK * sizeof(dt));\n\n    // size is 16 x 16\n    for (int i = 4; i < 8; i++) tile_config.set_row_col(i, TILE_M, TILE_N * sizeof(output_t));\n\n    tile_config.set_config();\n#endif\n  }\n\n  static void load_a(dt* a, size_t lda) {\n#ifdef HAVE_AMX\n    _tile_loadd(0, a, lda);\n    _tile_loadd(1, offset_pointer(a, lda * TILE_M), lda);\n#else\n    (void)a;\n    (void)lda;\n#endif\n  }\n\n  static void load_b(dt* b, size_t ldb) {\n#ifdef HAVE_AMX\n    _tile_loadd(2, b, ldb);\n    _tile_loadd(3, offset_pointer(b, ldb * TILE_N), ldb);\n#else\n    (void)b;\n    (void)ldb;\n#endif\n  }\n\n  static void clean_c() {\n#ifdef HAVE_AMX\n    _tile_zero(4);\n    _tile_zero(5);\n    _tile_zero(6);\n    _tile_zero(7);\n#endif\n  }\n\n  static void load_c(output_t* c, size_t ldc) {\n#ifdef HAVE_AMX\n    _tile_loadd(4, c, ldc);\n    _tile_loadd(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);\n    _tile_loadd(6, offset_pointer(c, ldc * TILE_M), ldc);\n    _tile_loadd(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc);\n#else\n    (void)c;\n    (void)ldc;\n#endif\n  }\n\n  static void store_c(output_t* c, size_t ldc) {\n#ifdef HAVE_AMX\n    _tile_stored(4, c, ldc);\n    _tile_stored(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);\n    _tile_stored(6, offset_pointer(c, ldc * TILE_M), ldc);\n    _tile_stored(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc);\n#else\n    (void)c;\n    (void)ldc;\n#endif\n  }\n\n  static void run_tile() {\n#ifdef HAVE_AMX\n    _tile_dpbssd(4, 0, 2);\n    _tile_dpbssd(5, 0, 3);\n    _tile_dpbssd(6, 1, 2);\n    _tile_dpbssd(7, 1, 3);\n#endif\n  }\n\n  using BufferA = BufferAImpl<GemmKernel224Int8>;\n  using BufferC = BufferCImpl<GemmKernel224Int8>;\n\n  struct BufferB {\n    int8_t* b;\n    float* d;\n    int n, k;\n    static constexpr bool SCALE = true;\n\n    static size_t required_size(int n, int k) { return sizeof(int8_t) * n * k + sizeof(float) * n; }\n\n    BufferB(int n, int 
k, void* ptr) : n(n), k(k) {\n      assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n      assert(n % N_STEP == 0);\n      assert(k % K_STEP == 0);\n      if (n % N_STEP || k % K_STEP) {\n        printf(\"n: %d, k: %d, N_STEP: %d, K_STEP: %d\\n\", n, k, N_STEP, K_STEP);\n        throw std::runtime_error(\"BufferB: n and k must be multiples of N_STEP and K_STEP\");\n      }\n      b = reinterpret_cast<int8_t*>(ptr);\n      d = reinterpret_cast<float*>(b + n * k);\n    }\n\n    void from_mat(ggml_bf16_t* src, int ith, int nth) {  // CHECK: nth has no usage\n      auto [n_start, n_end] = split_range_n(n, ith, nth);\n      int n_block_begin = n_start;\n      int n_block_size = n_end - n_block_begin;\n      for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n        for (int i = 0; i < N_STEP; i++) {\n          float amax = 0.0f;\n          for (int j = 0; j < k; j += 32) {\n            __m512 f0, f1;\n            avx512_32xbf16_to_32xfp32((__m512i*)(src + (n_block_begin + n_begin + i) * k + j), &f0, &f1);\n            amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f0)));\n            amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f1)));\n          }\n          d[n_block_begin + n_begin + i] = amax / ((1 << 7) - 1);\n        }\n      }\n      for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n        for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {\n          int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n          for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {\n            for (int i = 0; i < N_STEP; i++) {\n              __m512 id = _mm512_set1_ps(d[n_block_begin + n_begin + i] ? 
1.0f / d[n_block_begin + n_begin + i] : 0.0f);\n              int8_t* dst = b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size +\n                            k_begin * N_STEP + i * K_STEP;\n              __m512 f0, f1, f2, f3;\n              avx512_32xbf16_to_32xfp32((__m512i*)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin),\n                                        &f0, &f1);\n              avx512_32xbf16_to_32xfp32(\n                  (__m512i*)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin) + 1, &f2, &f3);\n              __m512i i0 = _mm512_cvtps_epi32(_mm512_mul_ps(f0, id));\n              __m512i i1 = _mm512_cvtps_epi32(_mm512_mul_ps(f1, id));\n              __m512i i2 = _mm512_cvtps_epi32(_mm512_mul_ps(f2, id));\n              __m512i i3 = _mm512_cvtps_epi32(_mm512_mul_ps(f3, id));\n              __m128i s0 = _mm512_cvtsepi32_epi8(i0);\n              __m128i s1 = _mm512_cvtsepi32_epi8(i1);\n              __m128i s2 = _mm512_cvtsepi32_epi8(i2);\n              __m128i s3 = _mm512_cvtsepi32_epi8(i3);\n              _mm_store_si128((__m128i*)dst, s0);\n              _mm_store_si128((__m128i*)(dst + 16), s1);\n              _mm_store_si128((__m128i*)(dst + 32), s2);\n              _mm_store_si128((__m128i*)(dst + 48), s3);\n            }\n            transpose_16x16_32bit((__m512i*)(b + n_block_begin * k + k_block_begin * n_block_size +\n                                             n_begin * k_block_size + k_begin * N_STEP));\n            transpose_16x16_32bit((__m512i*)(b + n_block_begin * k + k_block_begin * n_block_size +\n                                             n_begin * k_block_size + k_begin * N_STEP + TILE_N * K_STEP));\n          }\n        }\n      }\n    }\n\n    int8_t* get_submat(int n, int k, int n_begin, int k_begin) {\n      int n_block_begin = n_begin / N_BLOCK * N_BLOCK;\n      n_begin -= n_block_begin;\n      int n_block_size = std::min(N_BLOCK, n - n_block_begin);\n      int 
k_block_begin = k_begin / K_BLOCK * K_BLOCK;\n      k_begin -= k_block_begin;\n      int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n      return b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP;\n    }\n\n    float* get_scale(int n, int n_begin) { return d + n_begin; }\n  };\n\n  static void amx_kernel(int m, int n, int k, int m_begin, int n_begin, int k_block_begin, float* c, BufferA* ba,\n                         BufferB* bb) {\n    using K = GemmKernel224Int8;\n    if (k_block_begin == 0) {\n      K::clean_c();\n    } else {\n      K::load_c((int32_t*)c, K::N_STEP * sizeof(int32_t));\n    }\n    for (int k_begin = 0; k_begin < K::K_BLOCK && k_block_begin + k_begin < k; k_begin += K::K_STEP) {\n      K::load_a(ba->get_submat(m, k, m_begin, k_block_begin + k_begin), K::K_STEP * sizeof(int8_t));\n      K::load_b(bb->get_submat(n, k, n_begin, k_block_begin + k_begin), K::K_STEP * sizeof(int8_t));\n      K::run_tile();\n    }\n    K::store_c((int32_t*)c, K::N_STEP * sizeof(int32_t));\n  }\n  static void avx_kernel(int m, int n, int k, int m_begin, int n_begin, int k_block_begin, float* c, BufferA* ba,\n                         BufferB* bb) {\n    __m512i* c512 = (__m512i*)c;\n    int m_block_end = std::min(m - m_begin, M_STEP);\n    if (k_block_begin == 0) {\n      for (int m_i = 0; m_i < m_block_end; m_i++) {\n        c512[m_i * 2] = _mm512_setzero_si512();\n        c512[m_i * 2 + 1] = _mm512_setzero_si512();\n      }\n    }\n\n    for (int k_begin = 0; k_begin < K_BLOCK && k_block_begin + k_begin < k; k_begin += K_STEP) {\n      static_assert(K_STEP * sizeof(int8_t) == sizeof(__m512i));\n      static_assert(N_STEP / TILE_N == 2, \"Must be like this\");\n\n      int32_t* a32 = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);\n      __m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);\n      for (int m_i = 0; m_i < m_block_end; m_i++) {\n        for (int k_i = 0; 
k_i < 16; k_i++) {\n          __m512i ma = _mm512_set1_epi32(a32[m_i * 16 + k_i]);\n          for (int n_i = 0; n_i < 2; n_i++) {\n            c512[m_i * 2 + n_i] = _mm512_dpbssd_epi32(c512[m_i * 2 + n_i], ma, b512[n_i * 16 + k_i]);\n          }\n        }\n      }\n    }\n  }\n\n  static void apply_scale(int m, int n, int m_begin, int n_begin, float* c, BufferA* ba, BufferB* bb) {\n    using K = GemmKernel224Int8;\n    int to = m - m_begin;\n    if (m - m_begin > K::M_STEP) {\n      to = K::M_STEP;\n    }\n    for (int i = 0; i < to; i++) {\n      __m512 as = _mm512_set1_ps(*ba->get_scale(m, m_begin + i));\n      __m512 bs = _mm512_load_ps(bb->get_scale(n, n_begin));\n      __m512i now = _mm512_load_si512((__m512i*)(c + i * K::N_STEP));\n      __m512 result = _mm512_mul_ps(_mm512_mul_ps(as, bs), _mm512_cvtepi32_ps(now));\n      _mm512_store_ps((__m512*)(c + i * K::N_STEP), result);\n      bs = _mm512_load_ps(bb->get_scale(n, n_begin) + K::TILE_N);\n      now = _mm512_load_si512((__m512i*)(c + i * K::N_STEP + K::TILE_N));\n      result = _mm512_mul_ps(_mm512_mul_ps(as, bs), _mm512_cvtepi32_ps(now));\n      _mm512_store_ps((__m512*)(c + i * K::N_STEP + K::TILE_N), result);\n    }\n  }\n};\n\nstruct GemmKernel224Int4 {\n  using dt = void;\n  using output_t = int32_t;\n  static constexpr double ELEMENT_SIZE = 0.5;\n  static const int TILE_M = 16;\n  static const int TILE_K = 64;\n  static const int TILE_N = 16;\n  static const int VNNI_BLK = 4;\n\n  static const int M_STEP = TILE_M * 2;\n  static const int N_STEP = TILE_N * 2;\n  static const int K_STEP = TILE_K;\n\n  // static inline const int N_BLOCK = 256;\n  static inline const int N_BLOCK = 128;\n  // static inline const int N_BLOCK = 64;\n  // static inline const int K_BLOCK = 7168;\n  static inline const int K_BLOCK = 3584;\n  // static inline const int K_BLOCK = 2560;\n\n  static std::string name() { return \"INT4\"; }\n\n  static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }\n\n  static 
std::pair<int, int> split_range_n(int n, int ith, int nth) {\n    int n_start = N_BLOCK * ith;\n    int n_end = std::min(n, N_BLOCK * (ith + 1));\n    return {n_start, n_end};\n  }\n\n  static void config() {\n#ifdef HAVE_AMX\n    enable_amx();\n    TileConfig tile_config;\n\n    // size is 16 x 64\n    for (int i = 0; i < 2; i++) tile_config.set_row_col(i, TILE_M, TILE_K);\n\n    // size is 16 x 64\n    for (int i = 2; i < 4; i++) tile_config.set_row_col(i, TILE_K / VNNI_BLK, TILE_N * VNNI_BLK);\n\n    // size is 16 x 16\n    for (int i = 4; i < 8; i++) tile_config.set_row_col(i, TILE_M, TILE_N * sizeof(output_t));\n\n    tile_config.set_config();\n#endif\n  }\n\n  alignas(64) static constexpr uint8_t hi_mask_arr[64] = {\n      0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0,\n      0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0,\n      0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0,\n      0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0};\n\n  alignas(64) static constexpr uint8_t lo_mask_arr[64] = {\n      0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,\n      0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,\n      0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,\n      0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F};\n\n  alignas(64) static constexpr uint8_t sign_mask_arr[64] = {\n      0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,\n      0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,\n      0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,\n      
0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,\n  };\n\n  static __m512i hi_mask() { return *((__m512i*)(&hi_mask_arr[0])); }\n  static __m128i hi_mask_128() { return *((__m128i*)(&hi_mask_arr[0])); }\n  static __m512i lo_mask() { return *((__m512i*)(&lo_mask_arr[0])); }\n  static __m128i lo_mask_128() { return *((__m128i*)(&lo_mask_arr[0])); }\n  static __m128i si_mask_128() { return *((__m128i*)(&sign_mask_arr[0])); }\n\n  static void load_b_hi(dt* b, size_t ldb) {\n#ifdef HAVE_AMX\n    // Allocate a local (stack) aligned buffer inside the function\n    alignas(64) int8_t local_buffer[TILE_N * TILE_K];\n    __m512i* db = reinterpret_cast<__m512i*>(local_buffer);\n\n    for (size_t i = 0; i < TILE_N; i++) {\n      db[i] = _mm512_and_si512(hi_mask(), *static_cast<__m512i*>(offset_pointer(b, ldb * i)));\n    }\n    asm volatile(\"\" ::: \"memory\");\n    _tile_loadd(2, db, TILE_K);\n\n    for (size_t i = 0; i < TILE_N; i++) {\n      db[i] = _mm512_and_si512(hi_mask(), *static_cast<__m512i*>(offset_pointer(b, ldb * (i + TILE_N))));\n    }\n    asm volatile(\"\" ::: \"memory\");\n    _tile_loadd(3, db, TILE_K);\n#else\n    (void)b;\n    (void)ldb;\n#endif\n  }\n\n  static void load_b_lo(dt* b, size_t ldb) {\n#ifdef HAVE_AMX\n    // Allocate a local (stack) aligned buffer inside the function\n    alignas(64) int8_t local_buffer[TILE_N * TILE_K];\n    __m512i* db = reinterpret_cast<__m512i*>(local_buffer);\n\n    for (size_t i = 0; i < TILE_N; i++) {\n      db[i] = _mm512_slli_epi32(_mm512_and_si512(lo_mask(), *static_cast<__m512i*>(offset_pointer(b, ldb * i))), 4);\n    }\n    asm volatile(\"\" ::: \"memory\");\n    _tile_loadd(2, db, TILE_K);\n\n    for (size_t i = 0; i < TILE_N; i++) {\n      db[i] = _mm512_slli_epi32(\n          _mm512_and_si512(lo_mask(), *static_cast<__m512i*>(offset_pointer(b, ldb * (i + TILE_N)))), 4);\n    }\n    asm volatile(\"\" ::: \"memory\");\n    _tile_loadd(3, db, TILE_K);\n#else\n    (void)b;\n    (void)ldb;\n#endif\n  }\n\n  static void load_a(dt* a, size_t lda) 
{\n#ifdef HAVE_AMX\n    _tile_stream_loadd(0, a, lda);\n    _tile_stream_loadd(1, offset_pointer(a, lda * TILE_M), lda);\n#else\n    (void)a;\n    (void)lda;\n#endif\n  }\n\n  static void clean_c() {\n#ifdef HAVE_AMX\n    _tile_zero(4);\n    _tile_zero(5);\n    _tile_zero(6);\n    _tile_zero(7);\n#endif\n  }\n\n  static void load_c(output_t* c, size_t ldc) {\n#ifdef HAVE_AMX\n    _tile_loadd(4, c, ldc);\n    _tile_loadd(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);\n    _tile_loadd(6, offset_pointer(c, ldc * TILE_M), ldc);\n    _tile_loadd(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc);\n#else\n    (void)c;\n    (void)ldc;\n#endif\n  }\n\n  static void store_c(output_t* c, size_t ldc) {\n#ifdef HAVE_AMX\n    _tile_stored(4, c, ldc);\n    _tile_stored(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);\n    _tile_stored(6, offset_pointer(c, ldc * TILE_M), ldc);\n    _tile_stored(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc);\n#else\n    (void)c;\n    (void)ldc;\n#endif\n  }\n\n  static void run_tile() {\n#ifdef HAVE_AMX\n    _tile_dpbssd(4, 0, 2);\n    _tile_dpbssd(5, 0, 3);\n    _tile_dpbssd(6, 1, 2);\n    _tile_dpbssd(7, 1, 3);\n#endif\n  }\n\n  using BufferA = BufferAImpl<GemmKernel224Int4>;\n  using BufferB = BufferBInt4Impl<GemmKernel224Int4>;\n  using BufferC = BufferCImpl<GemmKernel224Int4>;\n\n  static void avx_kernel(int m, int n, int k, int m_begin, int n_begin, int k_block_begin, float* c, BufferA* ba,\n                         BufferB* bb) {\n    using K = GemmKernel224Int4;\n    __m512i* c512 = (__m512i*)c;\n    int m_block_end = std::min(m - m_begin, M_STEP);\n    if (k_block_begin == 0) {\n      for (int m_i = 0; m_i < m_block_end; m_i++) {\n        c512[m_i * 2] = _mm512_setzero_si512();\n        c512[m_i * 2 + 1] = _mm512_setzero_si512();\n      }\n    }\n\n    for (int k_begin = 0; k_begin < K::K_BLOCK && k_block_begin + k_begin < k; k_begin += K::BufferB::B_K_STEP) {\n      int32_t* a32_lo = 
(int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);\n      int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin + K::K_STEP);\n      __m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);\n      for (int m_i = 0; m_i < m_block_end; m_i++) {\n        for (int k_i = 0; k_i < 16; k_i++) {\n          __m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);\n          __m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);\n          for (int n_i = 0; n_i < 2; n_i++) {\n            __m512i b512_lo = _mm512_slli_epi32(_mm512_and_si512(K::lo_mask(), b512[n_i * 16 + k_i]), 4);\n            c512[m_i * 2 + n_i] = _mm512_dpbssd_epi32(c512[m_i * 2 + n_i], ma_lo, b512_lo);\n            __m512i b512_hi = _mm512_and_si512(K::hi_mask(), b512[n_i * 16 + k_i]);\n            c512[m_i * 2 + n_i] = _mm512_dpbssd_epi32(c512[m_i * 2 + n_i], ma_hi, b512_hi);\n          }\n        }\n      }\n    }\n  }\n  static void amx_kernel(int m, int n, int k, int m_begin, int n_begin, int k_block_begin, float* c, BufferA* ba,\n                         BufferB* bb) {\n    using K = GemmKernel224Int4;\n    if (k_block_begin == 0) {\n      K::clean_c();\n    } else {\n      // printf(\"load from c int4\\n\");\n      K::load_c((int32_t*)c, K::N_STEP * sizeof(int32_t));\n    }\n    for (int k_begin = 0; k_begin < K::K_BLOCK && k_block_begin + k_begin < k; k_begin += K::BufferB::B_K_STEP) {\n      K::load_a(ba->get_submat(m, k, m_begin, k_block_begin + k_begin), K::K_STEP * sizeof(int8_t));\n      K::load_b_lo(bb->get_submat(n, k, n_begin, k_block_begin + k_begin), K::BufferB::B_K_STEP / 2);\n      K::run_tile();\n      // DEBUG\n      // if(m_begin == 0 && n_begin == 0 && k_begin==0){\n      //   int8_t *ba_ptr = ba->get_submat(m, k, m_begin, k_block_begin + k_begin);\n      //   int8_t *bb_ptr = (int8_t *)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);\n      //   
printf(\"k_begin:%d,k_block_begin:%d\\n\",k_begin,k_block_begin);\n      //   for(int j=0;j<4096;j++){\n      //     printf(\"a[%d]: %d \", j, ba_ptr[j]);\n      //   }\n      //   printf(\"\\n\");\n      //   for(int j=0;j<4096;j++){\n      //     printf(\"b[%d]: %d \", j, bb_ptr[j]);\n      //   }\n      //   printf(\"\\n\");\n      // }\n\n      K::load_a(ba->get_submat(m, k, m_begin, k_block_begin + k_begin + K::K_STEP), K::K_STEP * sizeof(int8_t));\n      K::load_b_hi(bb->get_submat(n, k, n_begin, k_block_begin + k_begin), K::BufferB::B_K_STEP / 2);\n      K::run_tile();\n    }\n\n    // debug_tiles_224();\n    K::store_c((int32_t*)c, K::N_STEP * sizeof(int32_t));\n    // DEBUG: values of c, first 30 columns of the first row\n    // printf(\"\\nint4, m_begin:%d,n_begin:%d,k_block_begin:%d\\n\",m_begin,n_begin,k_block_begin);\n    // for(int j=0;j<30;j++){\n    //   printf(\"c[%d]: %d \", j, ((int32_t *)c)[j]);\n    // }\n    // printf(\"\\n\");\n  }\n\n  static void apply_scale(int m, int n, int m_begin, int n_begin, float* c, BufferA* ba, BufferB* bb) {\n    using K = GemmKernel224Int4;\n    int to = m - m_begin;\n    if (m - m_begin > K::M_STEP) {\n      to = K::M_STEP;\n    }\n    for (int i = 0; i < to; i++) {\n      __m512 as = _mm512_set1_ps(*ba->get_scale(m, m_begin + i));\n      __m512 bs = _mm512_load_ps(bb->get_scale(n, n_begin));\n      __m512i now = _mm512_load_epi32((__m512i*)(c + i * K::N_STEP));\n      __m512 result = _mm512_mul_ps(_mm512_mul_ps(as, bs), _mm512_cvtepi32_ps(now));\n      // if(i==0){\n      //   printf(\"\\nnormal\\n\");\n      //   printf(\"m_begin:%d,n_begin:%d\\n\", m_begin, n_begin);\n      //   // print the result: 16 float values\n      //   for(int j = 0; j < 16; j++) {\n      //     float val = *((float *) &result + j);\n      //     int32_t now_val = *((int32_t *) &now + j);\n      //     printf(\"result[%d]: %f,now:%d \", j, val, now_val);\n      //   }\n      //   printf(\"\\n\");\n      // }\n      _mm512_store_ps((__m512*)(c + i * K::N_STEP), result);\n   
   bs = _mm512_load_ps(bb->get_scale(n, n_begin) + K::TILE_N);\n      now = _mm512_load_si512((__m512i*)(c + i * K::N_STEP + K::TILE_N));\n      result = _mm512_mul_ps(_mm512_mul_ps(as, bs), _mm512_cvtepi32_ps(now));\n      // if(i==0){\n      //   printf(\"\\nnormal\\n\");\n      //   printf(\"m_begin:%d,n_begin:%d\\n\", m_begin, n_begin);\n      //   // print the result: 16 float values\n      //   for(int j = 0; j < 16; j++) {\n      //     float val = *((float *) &result + j);\n      //     int32_t now_val = *((int32_t *) &now + j);\n      //     printf(\"result[%d]: %f,now:%d \", j+16, val, now_val);\n      //   }\n      //   printf(\"\\n\");\n      // }\n      _mm512_store_ps((__m512*)(c + i * K::N_STEP + K::TILE_N), result);\n    }\n  }\n};\n\nstruct GemmKernel224Int4_1 {\n  using dt = void;\n  using output_t = int32_t;\n  static constexpr double ELEMENT_SIZE = 0.5;\n  static const int TILE_M = 16;\n  static const int TILE_K = 64;\n  static const int TILE_N = 16;\n  static const int VNNI_BLK = 4;\n\n  static const int M_STEP = TILE_M * 2;\n  static const int N_STEP = TILE_N * 2;\n  static const int K_STEP = TILE_K;\n\n  static inline const int N_BLOCK = 256;\n  // static inline const int K_BLOCK = 7168;\n  static inline const int K_BLOCK = 3584;\n  // static inline const int K_BLOCK = 2560;\n  static std::string name() { return \"INT4_1\"; }\n\n  static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }\n\n  static std::pair<int, int> split_range_n(int n, int ith, int nth) {\n    int n_start = N_BLOCK * ith;\n    int n_end = std::min(n, N_BLOCK * (ith + 1));\n    return {n_start, n_end};\n  }\n\n  static void config() {\n#ifdef HAVE_AMX\n    enable_amx();\n    TileConfig tile_config;\n\n    // size is 16 x 64\n    for (int i = 0; i < 2; i++) tile_config.set_row_col(i, TILE_M, TILE_K);\n\n    // size is 16 x 64\n    for (int i = 2; i < 4; i++) tile_config.set_row_col(i, TILE_K / VNNI_BLK, TILE_N * VNNI_BLK);\n\n    // size is 16 x 16\n    for (int i 
= 4; i < 8; i++) tile_config.set_row_col(i, TILE_M, TILE_N * sizeof(output_t));\n\n    tile_config.set_config();\n#endif\n  }\n\n  alignas(64) static constexpr uint8_t hi_mask_arr[64] = {\n      0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0,\n      0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0,\n      0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0,\n      0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0};\n\n  alignas(64) static constexpr uint8_t lo_mask_arr[64] = {\n      0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,\n      0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,\n      0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,\n      0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F};\n\n  alignas(64) static constexpr uint8_t sign_mask_arr[64] = {\n      0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,\n      0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,\n      0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,\n      0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,\n  };\n\n  static __m512i hi_mask() { return *((__m512i*)(&hi_mask_arr[0])); }\n  static __m128i hi_mask_128() { return *((__m128i*)(&hi_mask_arr[0])); }\n  static __m512i lo_mask() { return *((__m512i*)(&lo_mask_arr[0])); }\n  static __m128i lo_mask_128() { return *((__m128i*)(&lo_mask_arr[0])); }\n  static __m128i si_mask_128() { return *((__m128i*)(&sign_mask_arr[0])); }\n\n  static void load_b_hi(dt* b, size_t ldb) {\n#ifdef HAVE_AMX\n    
// Allocate a local (stack) aligned buffer inside the function\n    alignas(64) int8_t local_buffer[TILE_N * TILE_K];\n    __m512i* db = reinterpret_cast<__m512i*>(local_buffer);\n\n    for (size_t i = 0; i < TILE_N; i++) {\n      db[i] = _mm512_and_si512(hi_mask(), *static_cast<__m512i*>(offset_pointer(b, ldb * i)));\n    }\n    asm volatile(\"\" ::: \"memory\");\n    _tile_loadd(2, db, TILE_K);\n\n    for (size_t i = 0; i < TILE_N; i++) {\n      db[i] = _mm512_and_si512(hi_mask(), *static_cast<__m512i*>(offset_pointer(b, ldb * (i + TILE_N))));\n    }\n    asm volatile(\"\" ::: \"memory\");\n    _tile_loadd(3, db, TILE_K);\n#else\n    (void)b;\n    (void)ldb;\n#endif\n  }\n\n  static void load_b_lo(dt* b, size_t ldb) {\n#ifdef HAVE_AMX\n    // Allocate a local (stack) aligned buffer inside the function\n    alignas(64) int8_t local_buffer[TILE_N * TILE_K];\n    __m512i* db = reinterpret_cast<__m512i*>(local_buffer);\n\n    for (size_t i = 0; i < TILE_N; i++) {\n      db[i] = _mm512_slli_epi32(_mm512_and_si512(lo_mask(), *static_cast<__m512i*>(offset_pointer(b, ldb * i))), 4);\n    }\n    asm volatile(\"\" ::: \"memory\");\n    _tile_loadd(2, db, TILE_K);\n\n    for (size_t i = 0; i < TILE_N; i++) {\n      db[i] = _mm512_slli_epi32(\n          _mm512_and_si512(lo_mask(), *static_cast<__m512i*>(offset_pointer(b, ldb * (i + TILE_N)))), 4);\n    }\n    asm volatile(\"\" ::: \"memory\");\n    _tile_loadd(3, db, TILE_K);\n#else\n    (void)b;\n    (void)ldb;\n#endif\n  }\n\n  static void load_a(dt* a, size_t lda) {\n#ifdef HAVE_AMX\n    _tile_loadd(0, a, lda);\n    _tile_loadd(1, offset_pointer(a, lda * TILE_M), lda);\n#else\n    (void)a;\n    (void)lda;\n#endif\n  }\n\n  // static void load_b(dt* b, size_t ldb) {\n  //   _tile_loadd(2, b, ldb);\n  //   _tile_loadd(3, offset_pointer(b, ldb * TILE_N), ldb);\n  // }\n\n  static void clean_c() {\n#ifdef HAVE_AMX\n    _tile_zero(4);\n    _tile_zero(5);\n    _tile_zero(6);\n    _tile_zero(7);\n#endif\n  }\n\n  static void load_c(output_t* c, size_t ldc) {\n#ifdef HAVE_AMX\n    _tile_loadd(4, c, ldc);\n 
   _tile_loadd(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);\n    _tile_loadd(6, offset_pointer(c, ldc * TILE_M), ldc);\n    _tile_loadd(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc);\n#else\n    (void)c;\n    (void)ldc;\n#endif\n  }\n\n  static void store_c(output_t* c, size_t ldc) {\n#ifdef HAVE_AMX\n    _tile_stored(4, c, ldc);\n    _tile_stored(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);\n    _tile_stored(6, offset_pointer(c, ldc * TILE_M), ldc);\n    _tile_stored(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc);\n#else\n    (void)c;\n    (void)ldc;\n#endif\n  }\n\n  static void run_tile() {\n#ifdef HAVE_AMX\n    _tile_dpbsud(4, 0, 2);\n    _tile_dpbsud(5, 0, 3);\n    _tile_dpbsud(6, 1, 2);\n    _tile_dpbsud(7, 1, 3);\n#endif\n  }\n\n  using BufferA = BufferAWithSumImpl<GemmKernel224Int4_1>;\n\n  using BufferB = BufferBInt4WithZeroImpl<GemmKernel224Int4_1>;\n\n  using BufferC = BufferCImpl<GemmKernel224Int4_1>;\n\n  static void avx_kernel(int m, int n, int k, int m_begin, int n_begin, int k_block_begin, float* c, BufferA* ba,\n                         BufferB* bb) {\n    using K = GemmKernel224Int4_1;\n    __m512i* c512 = (__m512i*)c;\n    int m_block_end = std::min(m - m_begin, M_STEP);\n    if (k_block_begin == 0) {\n      for (int m_i = 0; m_i < m_block_end; m_i++) {\n        c512[m_i * 2] = _mm512_setzero_si512();\n        c512[m_i * 2 + 1] = _mm512_setzero_si512();\n      }\n    }\n    for (int k_begin = 0; k_begin < K::K_BLOCK && k_block_begin + k_begin < k; k_begin += K::BufferB::B_K_STEP) {\n      int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);\n      int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin + K::K_STEP);\n      __m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);\n      for (int m_i = 0; m_i < m_block_end; m_i++) {\n        for (int k_i = 0; k_i < 16; k_i++) {\n          __m512i ma_lo = 
_mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);\n          __m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);\n          for (int n_i = 0; n_i < 2; n_i++) {\n            __m512i b512_lo = _mm512_slli_epi32(_mm512_and_si512(K::lo_mask(), b512[n_i * 16 + k_i]), 4);\n            c512[m_i * 2 + n_i] = _mm512_dpbusd_epi32_compat(c512[m_i * 2 + n_i], b512_lo, ma_lo);\n            __m512i b512_hi = _mm512_and_si512(K::hi_mask(), b512[n_i * 16 + k_i]);\n            c512[m_i * 2 + n_i] = _mm512_dpbusd_epi32_compat(c512[m_i * 2 + n_i], b512_hi, ma_hi);\n          }\n        }\n      }\n    }\n  }\n  static void amx_kernel(int m, int n, int k, int m_begin, int n_begin, int k_block_begin, float* c, BufferA* ba,\n                         BufferB* bb) {\n    using K = GemmKernel224Int4_1;\n    if (k_block_begin == 0) {\n      K::clean_c();\n    } else {\n      K::load_c((int32_t*)c, K::N_STEP * sizeof(int32_t));\n    }\n    for (int k_begin = 0; k_begin < K::K_BLOCK && k_block_begin + k_begin < k; k_begin += K::BufferB::B_K_STEP) {\n      // printf(\"offset a %ld\\n\", pointer_offset(ba->get_submat(m, k, m_begin, k_block_begin + k_begin),\n      // ba->a)); printf(\"offset b %ld\\n\", pointer_offset(bb->get_submat(n, k, n_begin, k_block_begin +\n      // k_begin), bb->b));\n      K::load_a(ba->get_submat(m, k, m_begin, k_block_begin + k_begin), K::K_STEP * sizeof(int8_t));\n      K::load_b_lo(bb->get_submat(n, k, n_begin, k_block_begin + k_begin), K::BufferB::B_K_STEP / 2);\n      K::run_tile();\n      // DEBUG\n      // if(m_begin == 0 && n_begin == 0 && k_begin==0){\n      //   int8_t *ba_ptr = ba->get_submat(m, k, m_begin, k_block_begin + k_begin);\n      //   int8_t *bb_ptr = (int8_t *)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);\n      //   printf(\"k_begin:%d,k_block_begin:%d\\n\",k_begin,k_block_begin);\n      //   for(int j=0;j<2048;j++){\n      //     printf(\"a[%d]: %d \", j, ba_ptr[j]);\n      //   }\n      //   printf(\"\\n\");\n      //   for(int 
j=0;j<2048;j++){\n      //     printf(\"b[%d]: %d \", j, bb_ptr[j]);\n      //   }\n      //   printf(\"\\n\");\n      // }\n      K::load_a(ba->get_submat(m, k, m_begin, k_block_begin + k_begin + K::K_STEP), K::K_STEP * sizeof(int8_t));\n      K::load_b_hi(bb->get_submat(n, k, n_begin, k_block_begin + k_begin), K::BufferB::B_K_STEP / 2);\n      K::run_tile();\n    }\n\n    // debug_tiles_224();\n    K::store_c((int32_t*)c, K::N_STEP * sizeof(int32_t));\n    // DEBUG: values of c, first 30 columns of the first row\n    // printf(\"\\nint4_1, m_begin:%d,n_begin:%d,k_block_begin:%d\\n\",m_begin,n_begin,k_block_begin);\n    // for(int j=0;j<30;j++){\n    //   printf(\"c[%d]: %d \", j, ((int32_t *)c)[j]);\n    // }\n    // printf(\"\\n\");\n  }\n\n  static void apply_scale(int m, int n, int m_begin, int n_begin, float* c, BufferA* ba, BufferB* bb) {\n    using K = GemmKernel224Int4_1;\n    int to = m - m_begin;\n    if (m - m_begin > K::M_STEP) {\n      to = K::M_STEP;\n    }\n    for (int i = 0; i < to; i++) {\n      __m512 as = _mm512_set1_ps(*ba->get_scale(m, m_begin + i));\n      __m512 asum = _mm512_set1_ps(*ba->get_sum(m, m_begin + i));\n\n      __m512 bs = _mm512_load_ps(bb->get_scale(n, n_begin));\n      __m512 b_mins = _mm512_load_ps(bb->get_min(n, n_begin));\n      __m512i now = _mm512_load_epi32((__m512i*)(c + i * K::N_STEP));\n      __m512 result = _mm512_mul_ps(_mm512_mul_ps(as, bs), _mm512_cvtepi32_ps(now));\n      result = _mm512_add_ps(result, _mm512_mul_ps(asum, b_mins));\n      _mm512_store_ps((__m512*)(c + i * K::N_STEP), result);\n\n      bs = _mm512_load_ps(bb->get_scale(n, n_begin) + K::TILE_N);\n      b_mins = _mm512_load_ps(bb->get_min(n, n_begin) + K::TILE_N);\n      now = _mm512_load_si512((__m512i*)(c + i * K::N_STEP + K::TILE_N));\n      result = _mm512_mul_ps(_mm512_mul_ps(as, bs), _mm512_cvtepi32_ps(now));\n      result = _mm512_add_ps(result, _mm512_mul_ps(asum, b_mins));\n      _mm512_store_ps((__m512*)(c + i * K::N_STEP + K::TILE_N), result);\n    }\n  
}\n};\n\ntemplate <typename TA, typename TB, typename TC>\nvoid mat_mul_single(int m, int n, int k, TA* a, size_t lda, TB* b, size_t ldb, TC* c, size_t ldc);\ntemplate <>\ninline void mat_mul_single(int m, int n, int k, int8_t* a, size_t lda, int8_t* b, size_t ldb, int32_t* c, size_t ldc) {\n  using Kernel = GemmKernel133<32>;\n  for (int m_begin = 0; m_begin < m; m_begin += GemmKernel133<32>::M_STEP) {\n    int m_end = std::min(m_begin + GemmKernel133<32>::M_STEP, m);\n    for (int n_begin = 0; n_begin < n; n_begin += GemmKernel133<32>::N_STEP) {\n      int n_end = std::min(n_begin + GemmKernel133<32>::N_STEP, n);\n      for (int k_begin = 0; k_begin < k; k_begin += GemmKernel133<32>::K_STEP) {\n        int k_end = std::min(k_begin + GemmKernel133<32>::K_STEP, k);\n        int8_t* as = offset_pointer_row_major(a, m_begin, k_begin, lda);\n        int8_t* bs = offset_pointer_col_major(b, k_begin, n_begin, ldb);\n        int32_t* cs = offset_pointer_row_major(c, m_begin, n_begin, ldc);\n        GemmKernel133<32>::BTile<int8_t> tb;\n        if (n_end - n_begin == GemmKernel133<32>::N_STEP && k_end - k_begin == GemmKernel133<32>::K_STEP) {\n          tb.load_from(bs, ldb);\n        } else {\n          tb.partial_load(bs, n_end - n_begin, k_end - k_begin, ldb);\n        }\n        if (m_end - m_begin == GemmKernel133<32>::M_STEP && k_end - k_begin == GemmKernel133<32>::K_STEP) {\n          // printf(\"sub mat mul, full tile: (%d,%d),(%d,%d),(%d,%d)\\n\",\n          // m_begin, m_end, n_begin, n_end, k_begin, k_end);\n          tb.run_full_ac(as, lda, cs, ldc);\n        } else {\n          // printf(\"sub mat mul, partial tile: (%d,%d),(%d,%d),(%d,%d)\\n\",\n          // m_begin, m_end, n_begin, n_end, k_begin, k_end);\n          GemmKernel133<32>::PartialTiles<int8_t, int8_t, int32_t> p;\n          p.partial_run(m_end - m_begin, n_end - n_begin, k_end - k_begin, as, lda, bs, ldb, cs, ldc);\n        }\n      }\n    }\n  }\n}\n\ntemplate <>\ninline void mat_mul_single(int 
m, int n, int k, ggml_bf16_t* a, size_t lda, ggml_bf16_t* b, size_t ldb, float* c,\n                           size_t ldc) {\n  // // GemmKernel133BF::config();\n\n  // for (int m_begin = 0; m_begin < m; m_begin += GemmKernel133BF::M_STEP) {\n  //   int m_end = std::min(m_begin + GemmKernel133BF::M_STEP, m);\n  //   for (int n_begin = 0; n_begin < n; n_begin += GemmKernel133BF::N_STEP) {\n  //     int n_end = std::min(n_begin + GemmKernel133BF::N_STEP, n);\n\n  //     for (int k_begin = 0; k_begin < k; k_begin += GemmKernel133BF::K_STEP)\n  //     {\n  //       int k_end = std::min(k_begin + GemmKernel133BF::K_STEP, k);\n\n  //       ggml_bf16_t* as = offset_pointer_row_major(a, m_begin, k_begin, lda);\n  //       ggml_bf16_t* bs = offset_pointer_col_major(b, k_begin, n_begin, ldb);\n  //       GemmKernel133BF::BTile tb;\n  //       if (n_end - n_begin == GemmKernel133BF::N_STEP && k_end - k_begin ==\n  //       GemmKernel133BF::K_STEP) {\n  //         tb.full_load(bs, ldb);\n  //       } else {\n  //         tb.partial_load(bs, n_end - n_begin, k_end - k_begin, ldb);\n  //       }\n  //       float* cs = offset_pointer_row_major(c, m_begin, n_begin, ldc);\n\n  //       if (m_end - m_begin == GemmKernel133<32>::M_STEP && k_end - k_begin\n  //       == GemmKernel133<32>::K_STEP) {\n  //         // printf(\"sub mat mul, full tile: (%d,%d),(%d,%d),(%d,%d)\\n\",\n  //         m_begin, m_end, n_begin, n_end, k_begin,\n  //         // k_end);\n  //         tb.run_full_ac(as, lda, cs, ldc);\n  //       } else {\n  //         // printf(\"sub mat mul, partial tile: (%d,%d),(%d,%d),(%d,%d)\\n\",\n  //         m_begin, m_end, n_begin, n_end, k_begin,\n  //         // k_end);\n  //         GemmKernel133BF::PartialTiles p;\n  //         p.partial_run(m_end - m_begin, n_end - n_begin, k_end - k_begin,\n  //         as, lda, bs, ldb, cs, ldc);\n  //       }\n  //     }\n  //   }\n  // }\n}\n\ntemplate <typename QA>\nvoid mat_mul_single(int m, int n, int k, QA* a, size_t lda, 
block_q8_0* b, size_t ldb, float* c, size_t ldc) {\n  // amx::init();\n  assert(QK8_0 == 32);\n  assert(QK4_0 == 32);\n  assert(GemmKernel133<32>::K_STEP == 32);\n  // assert(reinterpret_cast<intptr_t>(c) % 64 == 0);\n  assert(ldc % 64 == 0);\n\n  // GemmKernel133::config();\n  for (int n_begin = 0; n_begin < n; n_begin += GemmKernel133<32>::N_STEP) {\n    int n_end = std::min(n_begin + GemmKernel133<32>::N_STEP, n);\n\n    for (int k_begin = 0; k_begin < k; k_begin += GemmKernel133<32>::K_STEP) {\n      int k_end = std::min(k_begin + GemmKernel133<32>::K_STEP, k);\n      int kb = k_begin / GemmKernel133<32>::K_STEP;\n      block_q8_0* bs = offset_pointer_col_major(b, kb, n_begin, ldb);\n      GemmKernel133<32>::PartialTiles<int8_t, int8_t, int32_t> p;\n      p.tb.partial_load_quant(bs, n_end - n_begin, k_end - k_begin, ldb);\n      for (int m_begin = 0; m_begin < m; m_begin += GemmKernel133<32>::M_STEP) {\n        int m_end = std::min(m_begin + GemmKernel133<32>::M_STEP, m);\n        QA* as = offset_pointer_row_major(a, m_begin, kb, lda);\n\n        float* cs = offset_pointer_row_major(c, m_begin, n_begin, ldc);\n        // printf(\"sub mat mul: (%d,%d),(%d,%d),(%d,%d) %ld %ld\\n\", m_begin,\n        // m_end, n_begin, n_end, k_begin, k_end,as-a,bs-b);\n\n        // p.partial_run_quant(m_end - m_begin, n_end - n_begin, k_end -\n        // k_begin, as, lda, bs, ldb, cs, ldc);\n        p.partial_run_quant_ac(m_end - m_begin, n_end - n_begin, k_end - k_begin, as, lda, cs, ldc);\n      }\n    }\n  }\n}\n\ninline void mat_mul_single(int m, int n, int k, block_q4_K* a, size_t lda, block_q8_K* b, size_t ldb, float* c,\n                           size_t ldc) {\n  assert(QK_K == 256);\n  assert(k % QK_K == 0);\n  assert(QK_K % GemmKernel133<32>::K_STEP == 0);\n  assert(GemmKernel133<32>::K_STEP == 32);\n  assert(ldc % 64 == 0);\n\n  for (int m_begin = 0; m_begin < m; m_begin += GemmKernel133<32>::M_STEP) {\n    int m_end = std::min(m_begin + GemmKernel133<32>::M_STEP, 
m);\n    for (int n_begin = 0; n_begin < n; n_begin += GemmKernel133<32>::N_STEP) {\n      int n_end = std::min(n_begin + GemmKernel133<32>::N_STEP, n);\n      float* cs = offset_pointer_row_major(c, m_begin, n_begin, ldc);\n      for (int k_bigstart = 0; k_bigstart < k; k_bigstart += QK_K) {\n        int k_bigend = k_bigstart + QK_K;\n        int super_block_index = k_bigstart / QK_K;\n\n        block_q8_K* super_bs = offset_pointer_col_major(b, super_block_index, n_begin, ldb);\n\n        block_q4_K* super_as = offset_pointer_row_major(a, m_begin, super_block_index, lda);\n        float super_scale = GGML_FP16_TO_FP32(super_as->d);\n        float super_min = GGML_FP16_TO_FP32(super_as->dmin);\n        __m512 a_sm = _mm512_mul_ps(\n            _mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(make_q4K_scale_and_min(super_as->scales))),\n            _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set1_ps(super_scale)), _mm256_set1_ps(super_min), 1));\n        float* a_scale = reinterpret_cast<float*>(&a_sm);\n        float* a_min = a_scale + 8;\n\n        for (int inner_idx = 0; inner_idx < 256 / 32; inner_idx++) {\n          amx::GemmKernel133<32>::PartialTilesSum<uint8_t, int8_t, float> t;\n          // printf(\"sub mat mul: (%d,%d),(%d,%d),(%d,%d) %d\\n\", m_begin, m_end,\n          // n_begin, n_end, k_bigstart,\n          //        k_bigend,inner_idx);\n          t.tb.partial_load_quant(super_bs, n_end - n_begin, inner_idx, ldb);\n          t.partial_run_quant_ac(m_end - m_begin, n_end - n_begin, inner_idx, super_as, lda, cs, ldc,\n                                 a_scale[inner_idx], a_min[inner_idx]);\n        }\n      }\n    }\n  }\n}\n\ninline void mat_mul_single(int m, int n, int k, blocks_aligned_q8_0_ref a, int a_blck_stride, blocks_aligned_q8_0_ref b,\n                           int b_blck_stride, float* c, size_t ldc) {\n  using Kernel = GemmKernel133<64>;\n  using TA = uint8_t;\n  using TB = int8_t;\n\n  for (int m_begin = 0; m_begin < m; m_begin += 
Kernel::M_STEP) {\n    int m_end = std::min(m_begin + Kernel::M_STEP, m);\n    for (int n_begin = 0; n_begin < n; n_begin += Kernel::N_STEP) {\n      int n_end = std::min(n_begin + Kernel::N_STEP, n);\n      for (int k_begin = 0; k_begin < k; k_begin += Kernel::K_STEP) {\n        int k_end = std::min(k_begin + Kernel::K_STEP, k);\n\n        int k_block = k_begin / Kernel::K_STEP;\n\n        auto as = a.offset(m_begin * a_blck_stride + k_block);\n        auto bs = b.offset(n_begin * b_blck_stride + k_block);\n        auto cs = offset_pointer_row_major(c, m_begin, n_begin, ldc);\n\n        // printf(\"sub mat mul: (%d,%d),(%d,%d),(%d,%d) %ld %ld\\n\", m_begin,\n        // m_end, n_begin, n_end, k_begin, k_end,as.d-a.d,bs.d-b.d);\n\n        Kernel::PartialTiles<TA, TB, int32_t> t;\n        t.tb.partial_load_quant(bs, n_end - n_begin, k_end - k_begin, b_blck_stride);\n        t.partial_run_quant_ac(m_end - m_begin, n_end - n_begin, k_end - k_begin, as, a_blck_stride, cs, ldc);\n      }\n    }\n  }\n}\n\ninline void merge_mat(int d0, int d1, float* a, float* b, size_t ld) {\n  __m512* va = (__m512*)a;\n  __m512* vb = (__m512*)b;\n\n  size_t d1v = (d1 + 15) / 16;\n\n  for (int i = 0; i < d0; i++) {\n    auto ta = offset_pointer_row_major(va, i, 0, ld);\n    auto tb = offset_pointer_row_major(vb, i, 0, ld);\n    for (int j = 0; j < d1v; j++) {\n      ta[j] = _mm512_add_ps(ta[j], tb[j]);\n    }\n  }\n}\n\ninline void merge_mats(int d0, int d1, int cnt, float** data, size_t ld) {\n  for (int i = 0; i < cnt; i++) {\n    assert((intptr_t)data[i] % 64 == 0);\n    assert(ld % 64 == 0);\n  }\n\n  while (cnt > 1) {\n    int new_cnt = (cnt + 1) / 2;\n    for (int i = 0; i < new_cnt; i++) {\n      int j = new_cnt + i;\n      if (j < cnt) {\n        // printf(\"merge %d %d\\n\", i, j);\n        merge_mat(d0, d1, data[i], data[j], ld);\n      }\n    }\n    cnt = new_cnt;\n  }\n}\n\ntemplate <typename TA, typename TB, typename TC>\nstruct GemmKernel {\n  static_assert(sizeof(TA) == 
-1, \"No associated type defined for this type.\");\n  using type = GemmKernel224BF;\n};\n\ntemplate <typename TB>\nstruct GemmKernel<uint8_t, TB, float> {\n  using type = GemmKernel133<32>;\n};\n\ntemplate <typename TB>\nstruct GemmKernel<int8_t, TB, float> {\n  using type = GemmKernel133<32>;\n};\n\ntemplate <>\nstruct GemmKernel<block_q4_0, block_q8_0, float> {\n  using type = GemmKernel133<32>;\n};\n\ntemplate <>\nstruct GemmKernel<block_q8_0, block_q8_0, float> {\n  using type = GemmKernel133<32>;\n};\n\ntemplate <>\nstruct GemmKernel<block_q4_K, block_q8_K, float> {\n  using type = GemmKernel133<32>;\n};\n\ntemplate <>\nstruct GemmKernel<ggml_bf16_t, ggml_bf16_t, float> {\n  // using type = GemmKernel133BF;\n  using type = GemmKernel224BF;\n};\n\n// template <typename TA, typename TB, typename TC>\n// void mat_mul(int m, int n, int k, TA* a, size_t lda, TB* b, size_t ldb, TC*\n// c, size_t ldc, int ith, int nth) {\n//   using K = typename GemmKernel<TA, TB, TC>::type;\n\n//   int m_partition_count = (m + K::M_STEP - 1) / K::M_STEP;\n//   int partition_count_per_thread = (m_partition_count + nth - 1) / nth;\n//   int partition_start = ith * partition_count_per_thread;\n//   int partition_end = std::min(partition_start + partition_count_per_thread,\n//   m_partition_count); int m_start = partition_start * K::M_STEP; int m_end =\n//   std::min(m, partition_end * K::M_STEP);\n\n//   mat_mul_single(m_end - m_start, n, k, offset_pointer(a, m_start * lda),\n//   lda, b, ldb, offset_pointer(c, m_start * ldc),\n//                  ldc);\n// }\n\ntemplate <typename TA, typename TB, typename TC>\nvoid mat_mul(int m, int n, int k, TA* a, size_t lda, TB* b, size_t ldb, TC* c, size_t ldc, int ith, int nth) {\n  using K = typename GemmKernel<TA, TB, TC>::type;\n\n  int n_partition_count = (n + K::N_STEP - 1) / K::N_STEP;\n  int partition_count_per_thread = (n_partition_count + nth - 1) / nth;\n  int partition_start = ith * partition_count_per_thread;\n  int partition_end = 
std::min(partition_start + partition_count_per_thread, n_partition_count);\n  int n_start = partition_start * K::N_STEP;\n  int n_end = std::min(n, partition_end * K::N_STEP);\n\n  mat_mul_single(m, n_end - n_start, k, a, lda, offset_pointer_col_major(b, 0, n_start, ldb), ldb,\n                 offset_pointer_row_major(c, 0, n_start, ldc), ldc);\n}\n\ninline void mat_mul(int m, int n, int k, std::shared_ptr<GemmKernel224BF::BufferA> ba,\n                    std::shared_ptr<GemmKernel224BF::BufferB> bb, std::shared_ptr<GemmKernel224BF::BufferC> bc, int ith,\n                    int nth) {\n  using K = GemmKernel224BF;\n  assert(n % K::N_STEP == 0);\n  assert(k % K::K_STEP == 0);\n\n  auto [n_start, n_end] = K::split_range_n(n, ith, nth);\n\n  // printf(\"n_start %d n_end %d\\n\", n_start, n_end);\n  for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K::K_BLOCK) {\n    for (int m_begin = 0; m_begin < m; m_begin += K::M_STEP) {\n      for (int n_begin = n_start; n_begin < n_end; n_begin += K::N_STEP) {\n        float* c = bc->get_submat(m, n, m_begin, n_begin);\n        // if (m - m_begin == 1) {\n        if (false) {\n          // if(k_block_begin==0&&m_begin==0&&n_begin==n_start)\n          // printf(\"AVX\");\n          __m512* c512 = (__m512*)c;\n          if (k_block_begin == 0) {\n            for (int m_i = 0; m_i < m; m_i++) {\n              c512[m_i * 2] = _mm512_setzero_ps();\n              c512[m_i * 2 + 1] = _mm512_setzero_ps();\n            }\n          }\n\n          for (int k_begin = 0; k_begin < K::K_BLOCK && k_block_begin + k_begin < k; k_begin += K::K_STEP) {\n            int32_t* a32 = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);\n            __m512bh* b512 = (__m512bh*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);\n            for (int m_i = 0; m_i < m; m_i++) {\n              for (int k_i = 0; k_i < 16; k_i++) {\n                __m512bh ma = (__m512bh)_mm512_set1_epi32(a32[m_i * 16 + k_i]);\n            
    for (int n_i = 0; n_i < 2; n_i++) {\n                  c512[m_i * 2 + n_i] = _mm512_dpbf16_ps(c512[m_i * 2 + n_i], ma, b512[n_i * 16 + k_i]);\n                }\n              }\n            }\n          }\n\n        } else {\n          if (k_block_begin == 0) {\n            K::clean_c();\n          } else {\n            K::load_c(c, K::N_STEP * sizeof(float));\n          }\n          for (int k_begin = 0; k_begin < K::K_BLOCK && k_block_begin + k_begin < k; k_begin += K::K_STEP) {\n            K::load_a(ba->get_submat(m, k, m_begin, k_block_begin + k_begin), K::K_STEP * sizeof(ggml_bf16_t));\n            K::load_b(bb->get_submat(n, k, n_begin, k_block_begin + k_begin), K::K_STEP * sizeof(ggml_bf16_t));\n            K::run_tile();\n          }\n          K::store_c(c, K::N_STEP * sizeof(float));\n        }\n      }\n    }\n  }\n}\n\ninline void vec_mul(int m, int n, int k, std::shared_ptr<GemmKernel224BF::BufferA> ba,\n                    std::shared_ptr<GemmKernel224BF::BufferB> bb, std::shared_ptr<GemmKernel224BF::BufferC> bc, int ith,\n                    int nth) {\n  mat_mul(m, n, k, ba, bb, bc, ith, nth);\n}\n\ntemplate <typename K, bool amx_or_avx = true>\nvoid integer_mat_mul(int m, int n, int k, typename K::BufferA* ba, typename K::BufferB* bb, typename K::BufferC* bc,\n                     int ith, int nth) {\n  assert(n % K::N_STEP == 0);\n  assert(k % K::K_STEP == 0);\n\n  auto [n_start, n_end] = K::split_range_n(n, ith, nth);\n\n  for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K::K_BLOCK) {\n    for (int m_begin = 0; m_begin < m; m_begin += K::M_STEP) {\n      for (int n_begin = n_start; n_begin < n_end; n_begin += K::N_STEP) {\n        float* c = bc->get_submat(m, n, m_begin, n_begin);\n        if constexpr (amx_or_avx && AMX_AVAILABLE) {\n          K::amx_kernel(m, n, k, m_begin, n_begin, k_block_begin, c, ba, bb);\n        } else {\n          K::avx_kernel(m, n, k, m_begin, n_begin, k_block_begin, c, ba, bb);\n        }\n\n      
  if (k_block_begin + K::K_BLOCK >= k) {\n          K::apply_scale(m, n, m_begin, n_begin, c, ba, bb);\n        }\n      }\n    }\n  }\n}\n\ninline void vec_mul(int m, int n, int k, std::shared_ptr<GemmKernel224Int8::BufferA> ba,\n                    std::shared_ptr<GemmKernel224Int8::BufferB> bb, std::shared_ptr<GemmKernel224Int8::BufferC> bc,\n                    int ith, int nth) {\n  integer_mat_mul<GemmKernel224Int8, false>(m, n, k, ba.get(), bb.get(), bc.get(), ith, nth);\n}\n\ninline void mat_mul(int m, int n, int k, std::shared_ptr<GemmKernel224Int8::BufferA> ba,\n                    std::shared_ptr<GemmKernel224Int8::BufferB> bb, std::shared_ptr<GemmKernel224Int8::BufferC> bc,\n                    int ith, int nth) {\n  integer_mat_mul<GemmKernel224Int8, true>(m, n, k, ba.get(), bb.get(), bc.get(), ith, nth);\n}\n\ninline void vec_mul(int m, int n, int k, std::shared_ptr<GemmKernel224Int4::BufferA> ba,\n                    std::shared_ptr<GemmKernel224Int4::BufferB> bb, std::shared_ptr<GemmKernel224Int4::BufferC> bc,\n                    int ith, int nth) {\n  integer_mat_mul<GemmKernel224Int4, false>(m, n, k, ba.get(), bb.get(), bc.get(), ith, nth);\n}\n\ninline void mat_mul(int m, int n, int k, std::shared_ptr<GemmKernel224Int4::BufferA> ba,\n                    std::shared_ptr<GemmKernel224Int4::BufferB> bb, std::shared_ptr<GemmKernel224Int4::BufferC> bc,\n                    int ith, int nth) {\n  integer_mat_mul<GemmKernel224Int4, true>(m, n, k, ba.get(), bb.get(), bc.get(), ith, nth);\n}\n\ninline void vec_mul(int m, int n, int k, std::shared_ptr<GemmKernel224Int4_1::BufferA> ba,\n                    std::shared_ptr<GemmKernel224Int4_1::BufferB> bb, std::shared_ptr<GemmKernel224Int4_1::BufferC> bc,\n                    int ith, int nth) {\n  integer_mat_mul<GemmKernel224Int4_1, false>(m, n, k, ba.get(), bb.get(), bc.get(), ith, nth);\n}\n\ninline void mat_mul(int m, int n, int k, std::shared_ptr<GemmKernel224Int4_1::BufferA> ba,\n                    
std::shared_ptr<GemmKernel224Int4_1::BufferB> bb, std::shared_ptr<GemmKernel224Int4_1::BufferC> bc,\n                    int ith, int nth) {\n  integer_mat_mul<GemmKernel224Int4_1, true>(m, n, k, ba.get(), bb.get(), bc.get(), ith, nth);\n}\n\ninline void mat_mul(int m, int n, int k, blocks_aligned_q8_0_ref aref, int a_blck_stride, blocks_aligned_q8_0_ref bref,\n                    int b_blck_stride, float* c, size_t ldc, int ith, int nth) {\n  using K = GemmKernel133<64>;\n\n  int m_partition_count = (m + K::M_STEP - 1) / K::M_STEP;\n  int partition_count_per_thread = (m_partition_count + nth - 1) / nth;\n  int partition_start = ith * partition_count_per_thread;\n  int partition_end = std::min(partition_start + partition_count_per_thread, m_partition_count);\n  int m_start = partition_start * K::M_STEP;\n  int m_end = std::min(m, partition_end * K::M_STEP);\n\n  mat_mul_single(m_end - m_start, n, k, aref.offset(m_start * a_blck_stride), a_blck_stride, bref, b_blck_stride,\n                 offset_pointer(c, m_start * ldc), ldc);\n}\n\n// K-group quantization kernel with intermediate int32 accumulation\nstruct GemmKernel224Int4KGroup {\n  using dt = void;\n  using output_t = int32_t;\n  static constexpr double ELEMENT_SIZE = 0.5;\n  static const int TILE_M = 16;\n  static const int TILE_K = 64;\n  static const int TILE_N = 16;\n  static const int VNNI_BLK = 4;\n  static const int M_STEP = TILE_M * 2;\n  static const int N_STEP = TILE_N * 2;\n  static const int K_STEP = TILE_K;\n  static inline const int N_BLOCK = 256;\n  // K_BLOCK should match k_group_size for proper scaling\n  static inline const int K_BLOCK = 7168;  // Will be overridden by k_group_size\n\n  static std::string name() { return \"INT4_KGROUP\"; }\n  static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }\n  static std::pair<int, int> split_range_n(int n, int ith, int nth) {\n    int n_start = N_BLOCK * ith;\n    int n_end = std::min(n, N_BLOCK * (ith + 1));\n    return {n_start, 
n_end};\n  }\n\n  static void config() {\n#ifdef HAVE_AMX\n    enable_amx();\n    TileConfig tile_config;\n    // size is 16 x 64\n    for (int i = 0; i < 2; i++) tile_config.set_row_col(i, TILE_M, TILE_K);\n    // size is 16 x 64\n    for (int i = 2; i < 4; i++) tile_config.set_row_col(i, TILE_K / VNNI_BLK, TILE_N * VNNI_BLK);\n    // size is 16 x 16\n    for (int i = 4; i < 8; i++) tile_config.set_row_col(i, TILE_M, TILE_N * sizeof(output_t));\n    tile_config.set_config();\n#endif\n  }\n\n  alignas(64) static constexpr uint8_t hi_mask_arr[64] = {\n      0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0,\n      0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0,\n      0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0,\n      0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0};\n\n  alignas(64) static constexpr uint8_t lo_mask_arr[64] = {\n      0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,\n      0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,\n      0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,\n      0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F};\n\n  static __m512i hi_mask() { return *((__m512i*)(&hi_mask_arr[0])); }\n  static __m512i lo_mask() { return *((__m512i*)(&lo_mask_arr[0])); }\n\n  static void clean_c() {\n    _tile_zero(4);\n    _tile_zero(5);\n    _tile_zero(6);\n    _tile_zero(7);\n  }\n\n  static void load_c(output_t* c, size_t ldc) {\n    _tile_loadd(4, c, ldc);\n    _tile_loadd(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);\n    _tile_loadd(6, offset_pointer(c, ldc * TILE_M), ldc);\n    _tile_loadd(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), 
ldc);\n  }\n\n  static void store_c(output_t* c, size_t ldc) {\n    _tile_stored(4, c, ldc);\n    _tile_stored(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);\n    _tile_stored(6, offset_pointer(c, ldc * TILE_M), ldc);\n    _tile_stored(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc);\n  }\n\n  static void load_a(dt* a, size_t lda) {\n    _tile_loadd(0, a, lda);\n    _tile_loadd(1, offset_pointer(a, lda * TILE_M), lda);\n  }\n\n  static void load_b_lo(dt* b, size_t ldb) {\n    alignas(64) int8_t local_buffer[TILE_N * TILE_K];\n    __m512i* db = reinterpret_cast<__m512i*>(local_buffer);\n\n    for (size_t i = 0; i < TILE_N; i++) {\n      // __m512i temp = _mm512_and_si512(lo_mask(), *static_cast<__m512i *>(offset_pointer(b, ldb * i)));\n      // db[i] = _mm512_slli_epi32(temp, 4);\n      db[i] = _mm512_slli_epi32(_mm512_and_si512(lo_mask(), *static_cast<__m512i*>(offset_pointer(b, ldb * i))), 4);\n    }\n    asm volatile(\"\" ::: \"memory\");\n    _tile_loadd(2, db, TILE_K);\n\n    for (size_t i = 0; i < TILE_N; i++) {\n      // __m512i temp = _mm512_and_si512(lo_mask(), *static_cast<__m512i *>(offset_pointer(b, ldb * (i + TILE_N))));\n      // db[i] = _mm512_slli_epi32(temp, 4);\n      db[i] = _mm512_slli_epi32(\n          _mm512_and_si512(lo_mask(), *static_cast<__m512i*>(offset_pointer(b, ldb * (i + TILE_N)))), 4);\n    }\n    asm volatile(\"\" ::: \"memory\");\n    _tile_loadd(3, db, TILE_K);\n  }\n\n  static void load_b_hi(dt* b, size_t ldb) {\n    alignas(64) int8_t local_buffer[TILE_N * TILE_K];\n    __m512i* db = reinterpret_cast<__m512i*>(local_buffer);\n\n    for (size_t i = 0; i < TILE_N; i++) {\n      db[i] = _mm512_and_si512(hi_mask(), *static_cast<__m512i*>(offset_pointer(b, ldb * i)));\n    }\n    asm volatile(\"\" ::: \"memory\");\n    _tile_loadd(2, db, TILE_K);\n\n    for (size_t i = 0; i < TILE_N; i++) {\n      db[i] = _mm512_and_si512(hi_mask(), *static_cast<__m512i*>(offset_pointer(b, ldb * (i + TILE_N))));\n    }\n   
 asm volatile(\"\" ::: \"memory\");\n    _tile_loadd(3, db, TILE_K);\n  }\n\n  static void run_tile() {\n#ifdef HAVE_AMX\n    _tile_dpbssd(4, 0, 2);\n    _tile_dpbssd(5, 0, 3);\n    _tile_dpbssd(6, 1, 2);\n    _tile_dpbssd(7, 1, 3);\n#endif\n  }\n\n  using BufferA = BufferAKGroupImpl<GemmKernel224Int4KGroup>;\n  using BufferB = BufferBKGroupImpl<GemmKernel224Int4KGroup>;\n  using BufferC = BufferCReduceImpl<GemmKernel224Int4KGroup>;\n\n  // K-group aware AVX kernel - processes a single B_K_STEP chunk\n  static void avx_kernel(int m, int n, int k, int m_begin, int n_begin, int k_block_begin, int32_t* int_c, BufferA* ba,\n                         BufferB* bb, int k_group_size) {\n    using K = GemmKernel224Int4KGroup;\n    __m512i* c512 = (__m512i*)int_c;\n    int m_block_end = std::min(m - m_begin, M_STEP);\n\n    // Initialize int_c to zero at the start of k_group\n    if (k_block_begin % k_group_size == 0) {\n      for (int m_i = 0; m_i < m_block_end; m_i++) {\n        c512[m_i * 2] = _mm512_setzero_si512();\n        c512[m_i * 2 + 1] = _mm512_setzero_si512();\n      }\n    }\n    int k_offset = k_block_begin % K::BufferB::B_K_STEP;\n    if (k_offset == 0) {\n      int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);\n      __m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin);\n      for (int m_i = 0; m_i < m_block_end; m_i++) {\n        for (int k_i = 0; k_i < 16; k_i++) {\n          __m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);\n          for (int n_i = 0; n_i < 2; n_i++) {\n            __m512i b512_lo = _mm512_slli_epi32(_mm512_and_si512(K::lo_mask(), b512[n_i * 16 + k_i]), 4);\n            c512[m_i * 2 + n_i] = _mm512_dpbssd_epi32(c512[m_i * 2 + n_i], ma_lo, b512_lo);\n          }\n        }\n      }\n    } else {\n      int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);\n      __m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin - K::K_STEP);\n      for (int m_i = 0; 
m_i < m_block_end; m_i++) {\n        for (int k_i = 0; k_i < 16; k_i++) {\n          __m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);\n          for (int n_i = 0; n_i < 2; n_i++) {\n            __m512i b512_hi = _mm512_and_si512(K::hi_mask(), b512[n_i * 16 + k_i]);\n            c512[m_i * 2 + n_i] = _mm512_dpbssd_epi32(c512[m_i * 2 + n_i], ma_hi, b512_hi);\n          }\n        }\n      }\n    }\n  }\n\n  // K-group aware AMX kernel - processes a single K_STEP chunk (lo or hi nibble)\n  static void amx_kernel(int m, int n, int k, int m_begin, int n_begin, int k_block_begin, int32_t* int_c, BufferA* ba,\n                         BufferB* bb, int k_group_size) {\n    using K = GemmKernel224Int4KGroup;\n    // Initialize or load int_c at start of k_group\n    if (k_block_begin % k_group_size == 0) {\n      K::clean_c();\n    } else {\n      K::load_c(int_c, K::N_STEP * sizeof(int32_t));\n    }\n\n    // Determine if we're processing lo or hi nibble based on position within B_K_STEP\n    int k_offset = k_block_begin % K::BufferB::B_K_STEP;\n    if (k_offset == 0) {\n      // Process lo nibble\n      K::load_a(ba->get_submat(m, k, m_begin, k_block_begin), K::K_STEP * sizeof(int8_t));\n      K::load_b_lo(bb->get_submat(n, k, n_begin, k_block_begin), K::BufferB::B_K_STEP / 2);\n      K::run_tile();\n    } else {\n      // Process hi nibble (k_offset == K_STEP)\n      K::load_a(ba->get_submat(m, k, m_begin, k_block_begin), K::K_STEP * sizeof(int8_t));\n      K::load_b_hi(bb->get_submat(n, k, n_begin, k_block_begin - K::K_STEP), K::BufferB::B_K_STEP / 2);\n      K::run_tile();\n    }\n\n    K::store_c(int_c, K::N_STEP * sizeof(int32_t));\n  }\n\n  // K-group aware scale application\n  static void apply_scale_kgroup(int m, int n, int m_begin, int n_begin, int k_begin, float* c, int32_t* int_c,\n                                 BufferA* ba, BufferB* bb, int k, int k_group_size) {\n    using K = GemmKernel224Int4KGroup;\n    int to = m - m_begin;\n    if (m - m_begin 
> K::M_STEP) {\n      to = K::M_STEP;\n    }\n\n    for (int i = 0; i < to; i++) {\n      // Get scale for this k_group\n      __m512 as = _mm512_set1_ps(*ba->get_scale(m, m_begin + i, k, k_begin));\n      __m512 bs = _mm512_load_ps(bb->get_scale(n, n_begin, k, k_begin));\n      __m512i now = _mm512_load_epi32((__m512i*)(int_c + i * K::N_STEP));\n      __m512 result = _mm512_mul_ps(_mm512_mul_ps(as, bs), _mm512_cvtepi32_ps(now));\n      // Load existing float value from c and add\n      __m512 existing = _mm512_load_ps((__m512*)(c + i * K::N_STEP));\n      result = _mm512_add_ps(existing, result);\n      _mm512_store_ps((__m512*)(c + i * K::N_STEP), result);\n\n      // Second half\n      bs = _mm512_load_ps(bb->get_scale(n, n_begin, k, k_begin) + K::TILE_N);\n      now = _mm512_load_si512((__m512i*)(int_c + i * K::N_STEP + K::TILE_N));\n      result = _mm512_mul_ps(_mm512_mul_ps(as, bs), _mm512_cvtepi32_ps(now));\n      existing = _mm512_load_ps((__m512*)(c + i * K::N_STEP + K::TILE_N));\n      result = _mm512_add_ps(existing, result);\n      _mm512_store_ps((__m512*)(c + i * K::N_STEP + K::TILE_N), result);\n    }\n  }\n};\nstruct GemmKernel224Int4_1KGroup {\n  using dt = void;\n  using output_t = int32_t;\n  static constexpr double ELEMENT_SIZE = 0.5;\n  static const int TILE_M = 16;\n  static const int TILE_K = 64;\n  static const int TILE_N = 16;\n  static const int VNNI_BLK = 4;\n\n  static const int M_STEP = TILE_M * 2;\n  static const int N_STEP = TILE_N * 2;\n  static const int K_STEP = TILE_K;\n\n  static inline const int N_BLOCK = 256;\n  // static inline const int K_BLOCK = 7168;\n  static inline const int K_BLOCK = 3584;\n  // static inline const int K_BLOCK = 2560;\n  static std::string name() { return \"INT4_1K\"; }\n\n  static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }\n\n  static std::pair<int, int> split_range_n(int n, int ith, int nth) {\n    int n_start = N_BLOCK * ith;\n    int n_end = std::min(n, N_BLOCK * (ith + 1));\n 
   return {n_start, n_end};\n  }\n\n  static void config() {\n#ifdef HAVE_AMX\n    enable_amx();\n    TileConfig tile_config;\n\n    // size is 16 x 64\n    for (int i = 0; i < 2; i++) tile_config.set_row_col(i, TILE_M, TILE_K);\n\n    // size is 16 x 64\n    for (int i = 2; i < 4; i++) tile_config.set_row_col(i, TILE_K / VNNI_BLK, TILE_N * VNNI_BLK);\n\n    // size is 16 x 16\n    for (int i = 4; i < 8; i++) tile_config.set_row_col(i, TILE_M, TILE_N * sizeof(output_t));\n\n    tile_config.set_config();\n#endif\n  }\n\n  alignas(64) static constexpr uint8_t hi_mask_arr[64] = {\n      0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0,\n      0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0,\n      0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0,\n      0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0};\n\n  alignas(64) static constexpr uint8_t lo_mask_arr[64] = {\n      0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,\n      0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,\n      0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,\n      0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F};\n\n  alignas(64) static constexpr uint8_t sign_mask_arr[64] = {\n      0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,\n      0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,\n      0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,\n      0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,\n  };\n\n  static __m512i hi_mask() { return 
*((__m512i*)(&hi_mask_arr[0])); }\n  static __m128i hi_mask_128() { return *((__m128i*)(&hi_mask_arr[0])); }\n  static __m512i lo_mask() { return *((__m512i*)(&lo_mask_arr[0])); }\n  static __m128i lo_mask_128() { return *((__m128i*)(&lo_mask_arr[0])); }\n  static __m128i si_mask_128() { return *((__m128i*)(&sign_mask_arr[0])); }\n\n  static void load_b_hi(dt* b, size_t ldb) {\n#ifdef HAVE_AMX\n    // Allocate a local (stack-resident) aligned buffer inside the function\n    alignas(64) int8_t local_buffer[TILE_N * TILE_K];\n    __m512i* db = reinterpret_cast<__m512i*>(local_buffer);\n\n    for (size_t i = 0; i < TILE_N; i++) {\n      db[i] = _mm512_and_si512(hi_mask(), *static_cast<__m512i*>(offset_pointer(b, ldb * i)));\n    }\n    asm volatile(\"\" ::: \"memory\");\n    _tile_loadd(2, db, TILE_K);\n\n    for (size_t i = 0; i < TILE_N; i++) {\n      db[i] = _mm512_and_si512(hi_mask(), *static_cast<__m512i*>(offset_pointer(b, ldb * (i + TILE_N))));\n    }\n    asm volatile(\"\" ::: \"memory\");\n    _tile_loadd(3, db, TILE_K);\n#else\n    (void)b;\n    (void)ldb;\n#endif\n  }\n\n  static void load_b_lo(dt* b, size_t ldb) {\n#ifdef HAVE_AMX\n    // Allocate a local (stack-resident) aligned buffer inside the function\n    alignas(64) int8_t local_buffer[TILE_N * TILE_K];\n    __m512i* db = reinterpret_cast<__m512i*>(local_buffer);\n\n    for (size_t i = 0; i < TILE_N; i++) {\n      db[i] = _mm512_slli_epi32(_mm512_and_si512(lo_mask(), *static_cast<__m512i*>(offset_pointer(b, ldb * i))), 4);\n    }\n    asm volatile(\"\" ::: \"memory\");\n    _tile_loadd(2, db, TILE_K);\n\n    for (size_t i = 0; i < TILE_N; i++) {\n      db[i] = _mm512_slli_epi32(\n          _mm512_and_si512(lo_mask(), *static_cast<__m512i*>(offset_pointer(b, ldb * (i + TILE_N)))), 4);\n    }\n    asm volatile(\"\" ::: \"memory\");\n    _tile_loadd(3, db, TILE_K);\n#else\n    (void)b;\n    (void)ldb;\n#endif\n  }\n\n  static void load_a(dt* a, size_t lda) {\n#ifdef HAVE_AMX\n    _tile_loadd(0, a, lda);\n    _tile_loadd(1, offset_pointer(a, lda * TILE_M), lda);\n#else\n    (void)a;\n    
(void)lda;\n#endif\n  }\n\n  // static void load_b(dt* b, size_t ldb) {\n  //   _tile_loadd(2, b, ldb);\n  //   _tile_loadd(3, offset_pointer(b, ldb * TILE_N), ldb);\n  // }\n\n  static void clean_c() {\n#ifdef HAVE_AMX\n    _tile_zero(4);\n    _tile_zero(5);\n    _tile_zero(6);\n    _tile_zero(7);\n#endif\n  }\n\n  static void load_c(output_t* c, size_t ldc) {\n#ifdef HAVE_AMX\n    _tile_loadd(4, c, ldc);\n    _tile_loadd(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);\n    _tile_loadd(6, offset_pointer(c, ldc * TILE_M), ldc);\n    _tile_loadd(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc);\n#else\n    (void)c;\n    (void)ldc;\n#endif\n  }\n\n  static void store_c(output_t* c, size_t ldc) {\n#ifdef HAVE_AMX\n    _tile_stored(4, c, ldc);\n    _tile_stored(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);\n    _tile_stored(6, offset_pointer(c, ldc * TILE_M), ldc);\n    _tile_stored(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc);\n#else\n    (void)c;\n    (void)ldc;\n#endif\n  }\n\n  static void run_tile() {\n#ifdef HAVE_AMX\n    _tile_dpbsud(4, 0, 2);\n    _tile_dpbsud(5, 0, 3);\n    _tile_dpbsud(6, 1, 2);\n    _tile_dpbsud(7, 1, 3);\n#endif\n  }\n\n  using BufferA = BufferAWithSumKGroupImpl<GemmKernel224Int4_1KGroup>;\n\n  using BufferB = BufferBInt4WithZeroKGroupImpl<GemmKernel224Int4_1KGroup>;\n\n  using BufferC = BufferCReduceImpl<GemmKernel224Int4_1KGroup>;\n\n  static void avx_kernel(int m, int n, int k, int m_begin, int n_begin, int k_block_begin, int32_t* int_c, BufferA* ba,\n                         BufferB* bb, int k_group_size) {\n    using K = GemmKernel224Int4_1KGroup;\n    __m512i* c512 = (__m512i*)int_c;\n    int m_block_end = std::min(m - m_begin, M_STEP);\n    if (k_block_begin % k_group_size == 0) {\n      for (int m_i = 0; m_i < m_block_end; m_i++) {\n        c512[m_i * 2] = _mm512_setzero_si512();\n        c512[m_i * 2 + 1] = _mm512_setzero_si512();\n      }\n    }\n    int k_offset = 
k_block_begin % K::BufferB::B_K_STEP;\n    if (k_offset == 0) {\n      int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);\n      __m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin);\n      for (int m_i = 0; m_i < m_block_end; m_i++) {\n        for (int k_i = 0; k_i < 16; k_i++) {\n          __m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);\n          for (int n_i = 0; n_i < 2; n_i++) {\n            __m512i b512_lo = _mm512_slli_epi32(_mm512_and_si512(K::lo_mask(), b512[n_i * 16 + k_i]), 4);\n            c512[m_i * 2 + n_i] = _mm512_dpbusd_epi32_compat(c512[m_i * 2 + n_i], b512_lo, ma_lo);\n          }\n        }\n      }\n    } else {\n      int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);\n      __m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin - K::K_STEP);\n      for (int m_i = 0; m_i < m_block_end; m_i++) {\n        for (int k_i = 0; k_i < 16; k_i++) {\n          __m512i ma_hi = _mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);\n          for (int n_i = 0; n_i < 2; n_i++) {\n            __m512i b512_hi = _mm512_and_si512(K::hi_mask(), b512[n_i * 16 + k_i]);\n            c512[m_i * 2 + n_i] = _mm512_dpbusd_epi32_compat(c512[m_i * 2 + n_i], b512_hi, ma_hi);\n          }\n        }\n      }\n    }\n  }\n  static void amx_kernel(int m, int n, int k, int m_begin, int n_begin, int k_block_begin, int32_t* int_c, BufferA* ba,\n                         BufferB* bb, int k_group_size) {\n    using K = GemmKernel224Int4_1KGroup;\n    if (k_block_begin % k_group_size == 0) {\n      K::clean_c();\n    } else {\n      K::load_c(int_c, K::N_STEP * sizeof(int32_t));\n    }\n\n    // Determine if we're processing lo or hi nibble based on position within B_K_STEP\n    int k_offset = k_block_begin % K::BufferB::B_K_STEP;\n    if (k_offset == 0) {\n      // Process lo nibble\n      K::load_a(ba->get_submat(m, k, m_begin, k_block_begin), K::K_STEP * sizeof(int8_t));\n      
K::load_b_lo(bb->get_submat(n, k, n_begin, k_block_begin), K::BufferB::B_K_STEP / 2);\n      K::run_tile();\n    } else {\n      // Process hi nibble (k_offset == K_STEP)\n      K::load_a(ba->get_submat(m, k, m_begin, k_block_begin), K::K_STEP * sizeof(int8_t));\n      K::load_b_hi(bb->get_submat(n, k, n_begin, k_block_begin - K::K_STEP), K::BufferB::B_K_STEP / 2);\n      K::run_tile();\n    }\n\n    K::store_c(int_c, K::N_STEP * sizeof(int32_t));\n  }\n\n  static void apply_scale_kgroup(int m, int n, int m_begin, int n_begin, int k_begin, float* c, int32_t* int_c,\n                                 BufferA* ba, BufferB* bb, int k, int k_group_size) {\n    using K = GemmKernel224Int4_1KGroup;\n    int to = m - m_begin;\n    if (m - m_begin > K::M_STEP) {\n      to = K::M_STEP;\n    }\n    for (int i = 0; i < to; i++) {\n      __m512 as = _mm512_set1_ps(*ba->get_scale(m, m_begin + i, k, k_begin));\n      __m512 asum = _mm512_set1_ps(*ba->get_sum(m, m_begin + i, k, k_begin));\n\n      __m512 bs = _mm512_load_ps(bb->get_scale(n, n_begin, k, k_begin));\n      __m512 b_mins = _mm512_load_ps(bb->get_min(n, n_begin, k, k_begin));\n      __m512i now = _mm512_load_epi32((__m512i*)(int_c + i * K::N_STEP));\n      __m512 result = _mm512_mul_ps(_mm512_mul_ps(as, bs), _mm512_cvtepi32_ps(now));\n      result = _mm512_add_ps(result, _mm512_mul_ps(asum, b_mins));\n      __m512 existing = _mm512_load_ps((__m512*)(c + i * K::N_STEP));\n      result = _mm512_add_ps(result, existing);\n      _mm512_store_ps((__m512*)(c + i * K::N_STEP), result);\n\n      bs = _mm512_load_ps(bb->get_scale(n, n_begin, k, k_begin) + K::TILE_N);\n      b_mins = _mm512_load_ps(bb->get_min(n, n_begin, k, k_begin) + K::TILE_N);\n      now = _mm512_load_si512((__m512i*)(int_c + i * K::N_STEP + K::TILE_N));\n      result = _mm512_mul_ps(_mm512_mul_ps(as, bs), _mm512_cvtepi32_ps(now));\n      result = _mm512_add_ps(result, _mm512_mul_ps(asum, b_mins));\n      existing = _mm512_load_ps((__m512*)(c + i * K::N_STEP 
+ K::TILE_N));\n      result = _mm512_add_ps(result, existing);\n      _mm512_store_ps((__m512*)(c + i * K::N_STEP + K::TILE_N), result);\n    }\n  }\n};\n\nstruct GemmKernel224Int4_1_LowKGroup {\n  using dt = void;\n  using output_t = int32_t;\n  static constexpr double ELEMENT_SIZE = 0.5;\n  static const int TILE_M = 16;\n  static const int TILE_K = 64;\n  static const int TILE_N = 16;\n  static const int VNNI_BLK = 4;\n\n  static const int M_STEP = TILE_M * 2;\n  static const int N_STEP = TILE_N * 2;\n  static const int K_STEP = TILE_K;\n\n  static inline const int N_BLOCK = 256;\n  // static inline const int K_BLOCK = 7168;\n  static inline const int K_BLOCK = 3584;\n  // static inline const int K_BLOCK = 2560;\n  static std::string name() { return \"INT4_1K\"; }\n\n  static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }\n\n  static std::pair<int, int> split_range_n(int n, int ith, int nth) {\n    int n_start = N_BLOCK * ith;\n    int n_end = std::min(n, N_BLOCK * (ith + 1));\n    return {n_start, n_end};\n  }\n\n  static void config() {\n#ifdef HAVE_AMX\n    enable_amx();\n    TileConfig tile_config;\n\n    // size is 16 x 64\n    for (int i = 0; i < 2; i++) tile_config.set_row_col(i, TILE_M, TILE_K);\n\n    // size is 16 x 64\n    for (int i = 2; i < 4; i++) tile_config.set_row_col(i, TILE_K / VNNI_BLK, TILE_N * VNNI_BLK);\n\n    // size is 16 x 16\n    for (int i = 4; i < 8; i++) tile_config.set_row_col(i, TILE_M, TILE_N * sizeof(output_t));\n\n    tile_config.set_config();\n#endif\n  }\n\n  alignas(64) static constexpr uint8_t hi_mask_arr[64] = {\n      0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0,\n      0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0,\n      0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0,\n      0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 
0xF0, 0xF0, 0xF0};\n\n  alignas(64) static constexpr uint8_t lo_mask_arr[64] = {\n      0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,\n      0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,\n      0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,\n      0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F};\n\n  alignas(64) static constexpr uint8_t sign_mask_arr[64] = {\n      0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,\n      0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,\n      0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,\n      0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80,\n  };\n\n  static __m512i hi_mask() { return *((__m512i*)(&hi_mask_arr[0])); }\n  static __m128i hi_mask_128() { return *((__m128i*)(&hi_mask_arr[0])); }\n  static __m512i lo_mask() { return *((__m512i*)(&lo_mask_arr[0])); }\n  static __m128i lo_mask_128() { return *((__m128i*)(&lo_mask_arr[0])); }\n  static __m128i si_mask_128() { return *((__m128i*)(&sign_mask_arr[0])); }\n\n  static void load_b_hi(dt* b, size_t ldb) {\n#ifdef HAVE_AMX\n    // Allocate a local (on-stack) aligned buffer inside the function\n    alignas(64) int8_t local_buffer[TILE_N * TILE_K];\n    __m512i* db = reinterpret_cast<__m512i*>(local_buffer);\n\n    for (size_t i = 0; i < TILE_N; i++) {\n      db[i] = _mm512_srli_epi32(_mm512_and_si512(hi_mask(), *static_cast<__m512i*>(offset_pointer(b, ldb * i))), 4);\n    }\n    asm volatile(\"\" ::: \"memory\");\n    _tile_loadd(2, db, TILE_K);\n\n    for (size_t i = 0; i < TILE_N; i++) {\n      db[i] = _mm512_srli_epi32(\n          _mm512_and_si512(hi_mask(), *static_cast<__m512i*>(offset_pointer(b, ldb * (i + TILE_N)))), 4);\n    }\n    
asm volatile(\"\" ::: \"memory\");\n    _tile_loadd(3, db, TILE_K);\n#else\n    (void)b;\n    (void)ldb;\n#endif\n  }\n\n  static void load_b_lo(dt* b, size_t ldb) {\n#ifdef HAVE_AMX\n    // Allocate a local (on-stack) aligned buffer inside the function\n    alignas(64) int8_t local_buffer[TILE_N * TILE_K];\n    __m512i* db = reinterpret_cast<__m512i*>(local_buffer);\n\n    for (size_t i = 0; i < TILE_N; i++) {\n      db[i] = _mm512_and_si512(lo_mask(), *static_cast<__m512i*>(offset_pointer(b, ldb * i)));\n    }\n    asm volatile(\"\" ::: \"memory\");\n    _tile_loadd(2, db, TILE_K);\n\n    for (size_t i = 0; i < TILE_N; i++) {\n      db[i] = _mm512_and_si512(lo_mask(), *static_cast<__m512i*>(offset_pointer(b, ldb * (i + TILE_N))));\n    }\n    asm volatile(\"\" ::: \"memory\");\n    _tile_loadd(3, db, TILE_K);\n#else\n    (void)b;\n    (void)ldb;\n#endif\n  }\n\n  static void load_a(dt* a, size_t lda) {\n#ifdef HAVE_AMX\n    _tile_loadd(0, a, lda);\n    _tile_loadd(1, offset_pointer(a, lda * TILE_M), lda);\n#else\n    (void)a;\n    (void)lda;\n#endif\n  }\n\n  // static void load_b(dt* b, size_t ldb) {\n  //   _tile_loadd(2, b, ldb);\n  //   _tile_loadd(3, offset_pointer(b, ldb * TILE_N), ldb);\n  // }\n\n  static void clean_c() {\n#ifdef HAVE_AMX\n    _tile_zero(4);\n    _tile_zero(5);\n    _tile_zero(6);\n    _tile_zero(7);\n#endif\n  }\n\n  static void load_c(output_t* c, size_t ldc) {\n#ifdef HAVE_AMX\n    _tile_loadd(4, c, ldc);\n    _tile_loadd(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);\n    _tile_loadd(6, offset_pointer(c, ldc * TILE_M), ldc);\n    _tile_loadd(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc);\n#else\n    (void)c;\n    (void)ldc;\n#endif\n  }\n\n  static void store_c(output_t* c, size_t ldc) {\n#ifdef HAVE_AMX\n    _tile_stored(4, c, ldc);\n    _tile_stored(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);\n    _tile_stored(6, offset_pointer(c, ldc * TILE_M), ldc);\n    _tile_stored(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), 
ldc);\n#else\n    (void)c;\n    (void)ldc;\n#endif\n  }\n\n  static void run_tile() {\n#ifdef HAVE_AMX\n    _tile_dpbsud(4, 0, 2);\n    _tile_dpbsud(5, 0, 3);\n    _tile_dpbsud(6, 1, 2);\n    _tile_dpbsud(7, 1, 3);\n#endif\n  }\n\n  using BufferA = BufferAWithSumKGroupImpl<GemmKernel224Int4_1_LowKGroup>;\n\n  using BufferB = BufferBInt4WithZeroLowKGroupImpl<GemmKernel224Int4_1_LowKGroup>;\n\n  using BufferC = BufferCReduceImpl<GemmKernel224Int4_1_LowKGroup>;\n\n  static void avx_kernel(int m, int n, int k, int m_begin, int n_begin, int k_block_begin, int32_t* int_c, BufferA* ba,\n                         BufferB* bb, int k_group_size) {\n    using K = GemmKernel224Int4_1_LowKGroup;\n    __m512i* c512 = (__m512i*)int_c;\n    int m_block_end = std::min(m - m_begin, M_STEP);\n    if (k_block_begin % k_group_size == 0) {\n      for (int m_i = 0; m_i < m_block_end; m_i++) {\n        c512[m_i * 2] = _mm512_setzero_si512();\n        c512[m_i * 2 + 1] = _mm512_setzero_si512();\n      }\n    }\n    int k_offset = k_block_begin % K::BufferB::B_K_STEP;\n    if (k_offset == 0) {\n      int32_t* a32_lo = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);\n      __m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin);\n      for (int m_i = 0; m_i < m_block_end; m_i++) {\n        for (int k_i = 0; k_i < 16; k_i++) {\n          __m512i ma_lo = _mm512_set1_epi32(a32_lo[m_i * 16 + k_i]);\n          for (int n_i = 0; n_i < 2; n_i++) {\n            __m512i b512_lo = _mm512_and_si512(K::lo_mask(), b512[n_i * 16 + k_i]);\n            c512[m_i * 2 + n_i] = _mm512_dpbusd_epi32_compat(c512[m_i * 2 + n_i], b512_lo, ma_lo);\n          }\n        }\n      }\n    } else {\n      int32_t* a32_hi = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin);\n      __m512i* b512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin - K::K_STEP);\n      for (int m_i = 0; m_i < m_block_end; m_i++) {\n        for (int k_i = 0; k_i < 16; k_i++) {\n          __m512i ma_hi = 
_mm512_set1_epi32(a32_hi[m_i * 16 + k_i]);\n          for (int n_i = 0; n_i < 2; n_i++) {\n            __m512i b512_hi = _mm512_srli_epi32(_mm512_and_si512(K::hi_mask(), b512[n_i * 16 + k_i]), 4);\n            c512[m_i * 2 + n_i] = _mm512_dpbusd_epi32_compat(c512[m_i * 2 + n_i], b512_hi, ma_hi);\n          }\n        }\n      }\n    }\n  }\n  static void amx_kernel(int m, int n, int k, int m_begin, int n_begin, int k_block_begin, int32_t* int_c, BufferA* ba,\n                         BufferB* bb, int k_group_size) {\n    using K = GemmKernel224Int4_1_LowKGroup;\n    if (k_block_begin % k_group_size == 0) {\n      K::clean_c();\n    } else {\n      K::load_c(int_c, K::N_STEP * sizeof(int32_t));\n    }\n\n    // Determine if we're processing lo or hi nibble based on position within B_K_STEP\n    int k_offset = k_block_begin % K::BufferB::B_K_STEP;\n    if (k_offset == 0) {\n      // Process lo nibble\n      K::load_a(ba->get_submat(m, k, m_begin, k_block_begin), K::K_STEP * sizeof(int8_t));\n      K::load_b_lo(bb->get_submat(n, k, n_begin, k_block_begin), K::BufferB::B_K_STEP / 2);\n      K::run_tile();\n    } else {\n      // Process hi nibble (k_offset == K_STEP)\n      K::load_a(ba->get_submat(m, k, m_begin, k_block_begin), K::K_STEP * sizeof(int8_t));\n      K::load_b_hi(bb->get_submat(n, k, n_begin, k_block_begin - K::K_STEP), K::BufferB::B_K_STEP / 2);\n      K::run_tile();\n    }\n\n    K::store_c(int_c, K::N_STEP * sizeof(int32_t));\n  }\n\n  static void apply_scale_kgroup(int m, int n, int m_begin, int n_begin, int k_begin, float* c, int32_t* int_c,\n                                 BufferA* ba, BufferB* bb, int k, int k_group_size) {\n    using K = GemmKernel224Int4_1_LowKGroup;\n    int to = m - m_begin;\n    if (m - m_begin > K::M_STEP) {\n      to = K::M_STEP;\n    }\n    for (int i = 0; i < to; i++) {\n      __m512 as = _mm512_set1_ps(*ba->get_scale(m, m_begin + i, k, k_begin));\n      __m512 asum = _mm512_set1_ps(*ba->get_sum(m, m_begin + i, k, 
k_begin));\n\n      __m512 bs = _mm512_load_ps(bb->get_scale(n, n_begin, k, k_begin));\n      __m512 b_mins = _mm512_load_ps(bb->get_min(n, n_begin, k, k_begin));\n      __m512i now = _mm512_load_epi32((__m512i*)(int_c + i * K::N_STEP));\n      __m512 result = _mm512_mul_ps(_mm512_mul_ps(as, bs), _mm512_cvtepi32_ps(now));\n      result = _mm512_add_ps(result, _mm512_mul_ps(asum, b_mins));\n      __m512 existing = _mm512_load_ps((__m512*)(c + i * K::N_STEP));\n      result = _mm512_add_ps(result, existing);\n      _mm512_store_ps((__m512*)(c + i * K::N_STEP), result);\n\n      bs = _mm512_load_ps(bb->get_scale(n, n_begin, k, k_begin) + K::TILE_N);\n      b_mins = _mm512_load_ps(bb->get_min(n, n_begin, k, k_begin) + K::TILE_N);\n      now = _mm512_load_si512((__m512i*)(int_c + i * K::N_STEP + K::TILE_N));\n      result = _mm512_mul_ps(_mm512_mul_ps(as, bs), _mm512_cvtepi32_ps(now));\n      result = _mm512_add_ps(result, _mm512_mul_ps(asum, b_mins));\n      existing = _mm512_load_ps((__m512*)(c + i * K::N_STEP + K::TILE_N));\n      result = _mm512_add_ps(result, existing);\n      _mm512_store_ps((__m512*)(c + i * K::N_STEP + K::TILE_N), result);\n    }\n  }\n};\n\n// K2 Signed Int4 K-group quantization kernel (AVX only, no AMX)\n// For K2 MoE - signed int4 range: [-8, 7]\nstruct GemmKernel224Int4SmallKGroup {\n  using dt = uint8_t;  // packed int4 type\n  using output_t = int32_t;\n  static constexpr double ELEMENT_SIZE = 0.5;\n  static const int VNNI_BLK = 4;\n\n  static const int M_STEP = 1;\n  static const int N_STEP = 32;\n  static const int K_STEP = 32;\n\n  static inline const int N_BLOCK = 256;\n  // K_BLOCK should match k_group_size for proper scaling\n  static inline const int K_BLOCK = 7168;  // Will be overridden by k_group_size\n\n  static std::string name() { return \"K2_INT4_KGROUP\"; }\n  static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }\n  static std::pair<int, int> split_range_n(int n, int ith, int nth) {\n    int n_start = 
N_BLOCK * ith;\n    int n_end = std::min(n, N_BLOCK * (ith + 1));\n    return {n_start, n_end};\n  }\n  static void config() {}\n\n  alignas(64) static constexpr uint8_t hi_mask_arr[32] = {\n      0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0,\n      0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0, 0xF0};\n\n  alignas(64) static constexpr uint8_t lo_mask_arr[32] = {\n      0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F,\n      0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F, 0x0F};\n\n  alignas(64) static constexpr uint8_t sign_xor_arr[32] = {\n      0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88,\n      0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88, 0x88};\n  static __m256i hi_mask() { return *((__m256i*)(&hi_mask_arr[0])); }\n  static __m256i lo_mask() { return *((__m256i*)(&lo_mask_arr[0])); }\n  static __m256i sign_xor_mask() { return *((__m256i*)(&sign_xor_arr[0])); }\n\n  using BufferA = BufferASmallKGroupImpl<GemmKernel224Int4SmallKGroup>;\n  using BufferB = BufferBInt4KGroupImpl<GemmKernel224Int4SmallKGroup>;  // Use new signed int4 buffer\n  using BufferC = BufferCReduceImpl<GemmKernel224Int4SmallKGroup>;\n\n  // K-group aware AVX kernel for signed int4\n  static inline __m512i compressed_int4_to_int8_avx512(__m256i b256) {\n    b256 = _mm256_xor_si256(b256, sign_xor_mask());\n    __m256i b_hi = _mm256_and_si256(b256, hi_mask());\n    __m256i b_lo = _mm256_slli_epi16(_mm256_andnot_si256(hi_mask(), b256), 4);\n\n    __m256i unpack_lo = _mm256_unpacklo_epi8(b_lo, b_hi);\n    __m256i unpack_hi = _mm256_unpackhi_epi8(b_lo, b_hi);\n    __m512i result = _mm512_inserti64x4(_mm512_castsi256_si512(unpack_lo), unpack_hi, 1);\n    const __m512i lane_shuffle = _mm512_set_epi64(7, 6, 3, 2, 5, 4, 1, 0);\n 
   return _mm512_permutexvar_epi64(lane_shuffle, result);\n  }\n  static inline void integer_mat_vec_kgroup(int m, int n, int k, int k_group_size, BufferA* ba, BufferB* bb,\n                                            BufferC* bc, int ith, int nth) {\n    auto [n_start, n_end] = split_range_n(n, ith, nth);\n    for (int m_begin = 0; m_begin < m; m_begin++) {\n      float* c = bc->get_submat(m, n, m_begin, n_start);\n      __m512i* a512 = (__m512i*)ba->get_submat(m, k, m_begin, 0);\n\n      for (int n_block_begin = n_start; n_block_begin < n_end; n_block_begin++) {\n        __m256i* b256 = (__m256i*)bb->get_submat(n, k, n_block_begin, 0);\n        float* as = (float*)ba->get_scale(m, m_begin, k, 0);\n        float* bs = (float*)bb->get_scale(n, n_block_begin, k, 0);\n\n        __m512 sum = _mm512_setzero_ps();\n#define WORK_K_BLOCK(k_block)                                                                     \\\n  {                                                                                               \\\n    __m256 abscale0 = _mm256_set1_ps(as[(k_block) * 2] * bs[(k_block) * 2]);                      \\\n    __m256 abscale1 = _mm256_set1_ps(as[(k_block) * 2 + 1] * bs[(k_block) * 2 + 1]);              \\\n    __m512 abscale = _mm512_insertf32x8(_mm512_castps256_ps512(abscale0), abscale1, 1);           \\\n    __m512i mul = _mm512_setzero_si512();                                                         \\\n    mul = _mm512_dpbssd_epi32(mul, a512[k_block], compressed_int4_to_int8_avx512(b256[k_block])); \\\n    sum = _mm512_add_ps(sum, _mm512_mul_ps(abscale, _mm512_cvtepi32_ps(mul)));                    \\\n  }\n\n        for (int k_block = 0; k_block < k / 64; k_block += 2) {\n          WORK_K_BLOCK(k_block);\n          WORK_K_BLOCK(k_block + 1);\n        }\n\n        c[n_block_begin - n_start] = _mm512_reduce_add_ps(sum) / 16;\n      }\n    }\n  }\n};\n\ninline void vec_mul_kgroup(int m, int n, int k, int k_group_size,\n                           
std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferA> ba,\n                           std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferB> bb,\n                           std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferC> bc, int ith, int nth) {\n  GemmKernel224Int4SmallKGroup::integer_mat_vec_kgroup(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth);\n}\n\ninline void mat_mul_kgroup(int m, int n, int k, int k_group_size,\n                           std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferA> ba,\n                           std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferB> bb,\n                           std::shared_ptr<GemmKernel224Int4SmallKGroup::BufferC> bc, int ith, int nth) {\n  GemmKernel224Int4SmallKGroup::integer_mat_vec_kgroup(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth);\n}\n\n// New k-group aware matrix multiplication function\ntemplate <typename K, bool amx_or_avx = true>\nvoid integer_mat_mul_kgroup(int m, int n, int k, int k_group_size, typename K::BufferA* ba, typename K::BufferB* bb,\n                            typename K::BufferC* bc, int ith, int nth) {\n  assert(n % K::N_STEP == 0);\n  assert(k % K::K_STEP == 0);\n  assert(k % k_group_size == 0);\n\n  auto [n_start, n_end] = K::split_range_n(n, ith, nth);\n  // Process by k_groups\n  for (int k_group_begin = 0; k_group_begin < k; k_group_begin += k_group_size) {\n    for (int m_begin = 0; m_begin < m; m_begin += K::M_STEP) {\n      for (int n_begin = n_start; n_begin < n_end; n_begin += K::N_STEP) {\n        float* c = bc->get_submat(m, n, m_begin, n_begin);\n        int32_t* int_c = bc->get_int_submat(m, n, m_begin, n_begin);\n\n        // Initialize float c to zero at the very beginning\n        if (k_group_begin == 0) {\n          for (int i = 0; i < K::M_STEP && m_begin + i < m; i++) {\n            for (int j = 0; j < K::N_STEP; j++) {\n              c[i * K::N_STEP + j] = 0.0f;\n            }\n          }\n        }\n        for (int 
k_begin = k_group_begin; k_begin < std::min(k, k_group_begin + k_group_size); k_begin += K::K_STEP) {\n          if constexpr (amx_or_avx && AMX_AVAILABLE) {\n            K::amx_kernel(m, n, k, m_begin, n_begin, k_begin, int_c, ba, bb, k_group_size);\n          } else {\n            K::avx_kernel(m, n, k, m_begin, n_begin, k_begin, int_c, ba, bb, k_group_size);\n          }\n        }\n        // }\n\n        // Apply scale and accumulate to float buffer at end of k_group\n        K::apply_scale_kgroup(m, n, m_begin, n_begin, k_group_begin, c, int_c, ba, bb, k, k_group_size);\n      }\n    }\n  }\n}\n\n// Convenience functions for k-group kernels\ninline void vec_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_ptr<GemmKernel224Int4KGroup::BufferA> ba,\n                           std::shared_ptr<GemmKernel224Int4KGroup::BufferB> bb,\n                           std::shared_ptr<GemmKernel224Int4KGroup::BufferC> bc, int ith, int nth) {\n  integer_mat_mul_kgroup<GemmKernel224Int4KGroup, false>(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth);\n}\n\ninline void mat_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_ptr<GemmKernel224Int4KGroup::BufferA> ba,\n                           std::shared_ptr<GemmKernel224Int4KGroup::BufferB> bb,\n                           std::shared_ptr<GemmKernel224Int4KGroup::BufferC> bc, int ith, int nth) {\n  integer_mat_mul_kgroup<GemmKernel224Int4KGroup, true>(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth);\n}\n\n// Convenience functions for k-group kernels\ninline void vec_mul_kgroup(int m, int n, int k, int k_group_size,\n                           std::shared_ptr<GemmKernel224Int4_1KGroup::BufferA> ba,\n                           std::shared_ptr<GemmKernel224Int4_1KGroup::BufferB> bb,\n                           std::shared_ptr<GemmKernel224Int4_1KGroup::BufferC> bc, int ith, int nth) {\n  integer_mat_mul_kgroup<GemmKernel224Int4_1KGroup, false>(m, n, k, k_group_size, ba.get(), 
bb.get(), bc.get(), ith,\n                                                           nth);\n}\n\ninline void mat_mul_kgroup(int m, int n, int k, int k_group_size,\n                           std::shared_ptr<GemmKernel224Int4_1KGroup::BufferA> ba,\n                           std::shared_ptr<GemmKernel224Int4_1KGroup::BufferB> bb,\n                           std::shared_ptr<GemmKernel224Int4_1KGroup::BufferC> bc, int ith, int nth) {\n  integer_mat_mul_kgroup<GemmKernel224Int4_1KGroup, true>(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith,\n                                                          nth);\n}\n\n// Convenience functions for k-group kernels\ninline void vec_mul_kgroup(int m, int n, int k, int k_group_size,\n                           std::shared_ptr<GemmKernel224Int4_1_LowKGroup::BufferA> ba,\n                           std::shared_ptr<GemmKernel224Int4_1_LowKGroup::BufferB> bb,\n                           std::shared_ptr<GemmKernel224Int4_1_LowKGroup::BufferC> bc, int ith, int nth) {\n  integer_mat_mul_kgroup<GemmKernel224Int4_1_LowKGroup, false>(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith,\n                                                               nth);\n}\n\ninline void mat_mul_kgroup(int m, int n, int k, int k_group_size,\n                           std::shared_ptr<GemmKernel224Int4_1_LowKGroup::BufferA> ba,\n                           std::shared_ptr<GemmKernel224Int4_1_LowKGroup::BufferB> bb,\n                           std::shared_ptr<GemmKernel224Int4_1_LowKGroup::BufferC> bc, int ith, int nth) {\n  integer_mat_mul_kgroup<GemmKernel224Int4_1_LowKGroup, true>(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith,\n                                                              nth);\n}\n\n}  // namespace amx\n\n#endif  // AMX_KERNELS_HPP\n"
  },
  {
    "path": "kt-kernel/operators/amx/la/amx_quantization.hpp",
    "content": "#ifndef AMX_QUANTIZATION_HPP\n#define AMX_QUANTIZATION_HPP\n#include <algorithm>\n#include <cmath>\n\n#include \"amx_config.hpp\"\n#include \"llama.cpp/ggml-impl.h\"\n#include \"llama.cpp/ggml-quants.h\"\n#include \"utils.hpp\"\n\nnamespace amx {\n\nstruct blocks_aligned_q4_0_ref {\n  static constexpr int block_size = 64;\n  static constexpr double bytes_per_element = double(sizeof(ggml_half) + double(block_size) / 2) / block_size;\n\n  ggml_half* d;\n  uint8_t* qs;\n\n  blocks_aligned_q4_0_ref offset(size_t blck_cnt) const {\n    blocks_aligned_q4_0_ref re;\n    re.d = &d[blck_cnt];\n    re.qs = &qs[blck_cnt * block_size / 2];\n    return re;\n  }\n\n  static size_t expected_data_size(int64_t k) {\n    assert(k % block_size == 0);\n    return (sizeof(ggml_half) + block_size / 2) * (k / block_size);\n  }\n\n  uint8_t* get_qs(int block_idx) { return offset_pointer(qs, block_idx * (block_size / 2)); }\n\n  static blocks_aligned_q4_0_ref quantize(const float* RESTRICT x, void* RESTRICT data, int64_t k) {\n    assert(reinterpret_cast<intptr_t>(data) % 64 == 0);\n\n    blocks_aligned_q4_0_ref re;\n    re.qs = reinterpret_cast<uint8_t*>(data);\n    re.d = reinterpret_cast<ggml_half*>(offset_pointer(re.qs, k / 2));\n\n    static const int qk = block_size;\n\n    assert(k % qk == 0);\n\n    const int nb = k / qk;\n\n    for (int i = 0; i < nb; i++) {\n      float amax = 0.0f;  // absolute max\n      float max = 0.0f;\n\n      for (int j = 0; j < qk; j++) {\n        const float v = x[i * qk + j];\n        if (amax < fabsf(v)) {\n          amax = fabsf(v);\n          max = v;\n        }\n      }\n\n      const float d = max / -8;\n      const float id = d ? 
1.0f / d : 0.0f;\n\n      re.d[i] = GGML_FP32_TO_FP16(d);\n\n      for (int j = 0; j < qk / 2; ++j) {\n        const float x0 = x[i * qk + 0 + j] * id;\n        const float x1 = x[i * qk + qk / 2 + j] * id;\n\n        const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));\n        const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));\n\n        re.get_qs(i)[j] = xi0;\n        re.get_qs(i)[j] |= xi1 << 4;\n      }\n    }\n    return re;\n  }\n\n  void dequantize(float* y, int64_t k) {\n    static const int qk = block_size;\n    assert(k % qk == 0);\n\n    const int nb = k / qk;\n\n    for (int i = 0; i < nb; i++) {\n      const float d = GGML_FP16_TO_FP32(this->d[i]);\n\n      for (int j = 0; j < qk / 2; ++j) {\n        const int x0 = (get_qs(i)[j] & 0x0F) - 8;\n        const int x1 = (get_qs(i)[j] >> 4) - 8;\n\n        y[i * qk + j + 0] = x0 * d;\n        y[i * qk + j + qk / 2] = x1 * d;\n      }\n    }\n  }\n};\n\nstruct blocks_aligned_q8_0_ref {\n  static constexpr int block_size = 64;\n  static constexpr double bytes_per_element = double(sizeof(ggml_half) + block_size) / block_size;\n\n  ggml_half* d;\n  int8_t* qs;\n\n  blocks_aligned_q8_0_ref offset(size_t blck_cnt) const {\n    blocks_aligned_q8_0_ref re;\n    re.d = &d[blck_cnt];\n    re.qs = &qs[blck_cnt * block_size];\n    return re;\n  }\n\n  static size_t expected_data_size(int64_t k) {\n    assert(k % block_size == 0);\n    return (sizeof(ggml_half) + block_size) * (k / block_size);\n  }\n  int8_t* get_qs(int block_idx) { return offset_pointer(qs, block_idx * block_size); }\n\n  static blocks_aligned_q8_0_ref quantize(const float* RESTRICT x, void* RESTRICT data, int64_t k) {\n    assert(k % block_size == 0);\n    assert(reinterpret_cast<intptr_t>(data) % 64 == 0);\n\n    blocks_aligned_q8_0_ref re;\n    re.qs = reinterpret_cast<int8_t*>(data);\n    re.d = reinterpret_cast<ggml_half*>(offset_pointer(re.qs, k));\n    const int nb = k / block_size;\n\n    for (int i = 0; i < nb; i++) {\n      float amax = 0.0f;  
// absolute max\n\n      for (int j = 0; j < block_size; j++) {\n        const float v = x[i * block_size + j];\n        amax = MAX(amax, fabsf(v));\n      }\n\n      const float d = amax / ((1 << 7) - 1);\n      const float id = d ? 1.0f / d : 0.0f;\n\n      re.d[i] = GGML_FP32_TO_FP16(d);\n\n      for (int j = 0; j < block_size; ++j) {\n        const float x0 = x[i * block_size + j] * id;\n        re.get_qs(i)[j] = roundf(x0);\n      }\n    }\n    return re;\n  }\n\n  void dequantize(float* y, int64_t k) {\n    static const int qk = block_size;\n\n    assert(k % qk == 0);\n\n    const int nb = k / qk;\n\n    for (int i = 0; i < nb; i++) {\n      const float d = GGML_FP16_TO_FP32(this->d[i]);\n\n      for (int j = 0; j < qk; ++j) {\n        y[i * qk + j] = get_qs(i)[j] * d;\n      }\n    }\n  }\n};\n\n#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)\n\ntemplate <typename Block>\nstruct Dequantizer {};\n\nconst __m256i MASK256_LO = _mm256_set1_epi8(0x0f);\nconst __m256i MASK256_4HI = _mm256_set1_epi8(0xf0);\nconst __m256i MASK256_8 = _mm256_set1_epi8(8);\n\nconst __m512i MASK512_LO = _mm512_set1_epi8(0x0f);\nconst __m512i MASK512_4HI = _mm512_set1_epi8(0xf0);\nconst __m512i MASK512_8 = _mm512_set1_epi8(8);\n\ninline __m256i dequant4x32(const uint8_t* qs) {\n  const __m128i aux128 = _mm_loadu_si128((const __m128i*)qs);\n  return _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128), MASK256_LO);\n}\n\ninline __m256i unaligned_copy8x32(const int8_t* qs) { return _mm256_loadu_si256((const __m256i*)qs); }\n\ninline __m512i copy8x64(const int8_t* qs) { return _mm512_load_si512((const __m512i*)qs); }\n\ninline __m256i lo4bit(const uint8_t* qs) {\n  return _mm256_and_si256(_mm256_loadu_si256((const __m256i*)qs), MASK256_LO);\n}\ninline __m256i hi4bit(const uint8_t* qs) {\n  return _mm256_srli_epi16(_mm256_and_si256(_mm256_loadu_si256((const __m256i*)qs), MASK256_4HI), 4);\n}\n\ninline __m128i 
make_q4K_scale_and_min(const uint8_t* scales8) {\n  __m128i re;\n  uint32_t* aux32 = (uint32_t*)&re;\n  const uint16_t* scales = (const uint16_t*)scales8;\n  const uint32_t a0 = scales[0] | (scales[1] << 16);\n  const uint32_t a1 = scales[2] | (scales[3] << 16);\n  const uint32_t a2 = scales[4] | (scales[5] << 16);\n  aux32[3] = ((a2 >> 4) & 0x0f0f0f0f) | ((a1 >> 2) & 0x30303030);\n  aux32[1] = ((a2 >> 0) & 0x0f0f0f0f) | ((a0 >> 2) & 0x30303030);\n  aux32[2] = a1 & 0x3f3f3f3f;\n  aux32[0] = a0 & 0x3f3f3f3f;\n  // aux32[1:0] is scale\n  // aux32[3:2] is min\n  return re;\n}\n\ninline __m256i merge_q8K_bsum(block_q8_K* b) {\n  return _mm256_madd_epi16(_mm256_loadu_si256((__m256i*)b->bsums), _mm256_set1_epi16(1));\n}\n\ninline __m512i _mm512_dpbusd_epi32_compat(__m512i src, __m512i a, __m512i b) {\n#if defined(__AVX512VNNI__)\n  return _mm512_dpbusd_epi32(src, a, b);\n#else\n  const __m512i mask_lo = _mm512_set1_epi16(0x00FF);\n  const __m512i ones16 = _mm512_set1_epi16(1);\n\n  __m512i a_even = _mm512_and_si512(a, mask_lo);\n  __m512i b_even = _mm512_srai_epi16(_mm512_slli_epi16(b, 8), 8);\n\n  __m512i a_odd = _mm512_srli_epi16(a, 8);\n  __m512i b_odd = _mm512_srai_epi16(b, 8);\n\n  __m512i prod_even = _mm512_mullo_epi16(a_even, b_even);\n  __m512i prod_odd = _mm512_mullo_epi16(a_odd, b_odd);\n\n  __m512i sum_even = _mm512_madd_epi16(prod_even, ones16);\n  __m512i sum_odd = _mm512_madd_epi16(prod_odd, ones16);\n\n  return _mm512_add_epi32(src, _mm512_add_epi32(sum_even, sum_odd));\n#endif\n}\n\ninline __m512i _mm512_dpbssd_epi32(__m512i src, __m512i a, __m512i b) {\n  __m256i a_lo = _mm512_extracti64x4_epi64(a, 0);\n  __m256i a_hi = _mm512_extracti64x4_epi64(a, 1);\n  __m256i b_lo = _mm512_extracti64x4_epi64(b, 0);\n  __m256i b_hi = _mm512_extracti64x4_epi64(b, 1);\n\n  b_lo = _mm256_sign_epi8(b_lo, a_lo);\n  b_hi = _mm256_sign_epi8(b_hi, a_hi);\n\n  b = _mm512_inserti64x4(b, b_lo, 0);\n  b = _mm512_inserti64x4(b, b_hi, 1);\n\n  a = _mm512_abs_epi8(a);\n\n  return 
_mm512_dpbusd_epi32_compat(src, a, b);\n}\n\n}  // namespace amx\n\n#endif  // AMX_QUANTIZATION_HPP"
  },
  {
    "path": "kt-kernel/operators/amx/la/amx_raw_buffers.hpp",
    "content": "#ifndef AMX_RAW_BUFFERS_HPP\n#define AMX_RAW_BUFFERS_HPP\n\n/**\n * @file amx_raw_buffers.hpp\n * @brief Raw data format buffer management (FP8, BF16, etc.)\n *\n * 本文件实现原精度格式的缓冲区管理，用于 DeepSeek V3.2 等原精度推理。\n *\n * 缓冲区类型：\n * - BufferAFP8Impl: 输入激活缓冲区，支持动态 FP8 量化\n * - BufferBFP8Impl: 权重缓冲区，FP8 格式 + 128x128 块缩放\n * - BufferBFP8BlockImpl: 优化的块量化权重缓冲区\n *\n * 内存布局：\n * - FP8 数据：1 字节/元素\n * - Scale：4 字节/块（BufferB 每 128x128 块一个，BufferA 每 128 行一个）\n */\n\n#include <algorithm>\n#include <cassert>\n#include <cstdint>\n#include <cstdio>\n#include <cstring>\n#include <limits>\n#include <vector>\n\n#include \"amx_config.hpp\"\n#include \"amx_utils.hpp\"\n#include \"llama.cpp/ggml-impl.h\"\n#include \"pack.hpp\"\n#include \"utils.hpp\"\n\nnamespace amx {\n\n// ============================================================================\n// BufferAFP8Impl: FP8 激活缓冲区（支持动态量化）\n// ============================================================================\n/* 物理布局(按 bf16 元素数)：\n * 逻辑矩阵 A 为 (m, k) 行主序，m pad 到 max_m(=m_block_size，M_STEP 的倍数)。\n * 存储顺序：\n *   k_block(K_BLOCK 列) → m_block(M_STEP 行) → k_step(K_STEP 列) → (M_STEP×K_STEP) 行主序 tile。\n * 因此可视为 5D：\n *   a[k_blocks][m_blocks][k_steps][M_STEP][K_STEP]，\n *   k_blocks = ceil(k / K_BLOCK)，m_blocks = max_m / M_STEP，\n *   k_steps = K_BLOCK / K_STEP（最后一个 k_block 可能更小）。\n * get_submat(m_begin, k_begin) 返回连续的 (M_STEP×K_STEP) tile。\n */\ntemplate <typename K>\nstruct BufferABF16Impl {\n  ggml_bf16_t* a;\n  int max_m, k;\n  static constexpr int M_STEP = K::M_STEP;\n  static constexpr int K_STEP = K::K_STEP;\n  static constexpr int K_BLOCK = K::K_BLOCK;\n\n  static size_t required_size(int max_m, int k) { return sizeof(ggml_bf16_t) * max_m * k; }\n\n  BufferABF16Impl(int max_m, int k, void* ptr) : max_m(max_m), k(k) {\n    assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n    assert(max_m % M_STEP == 0);\n    assert(k % K_STEP == 0);\n    a = reinterpret_cast<ggml_bf16_t*>(ptr);\n  }\n\n  void set_data(void* 
new_ptr) { a = reinterpret_cast<ggml_bf16_t*>(new_ptr); }\n\n  void from_mat(int m, ggml_bf16_t* src, int ith, int nth) {\n    assert(m <= max_m);\n    assert(ith == 0 && nth == 1);\n    int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n    for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {\n      for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {\n        int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n        for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {\n          for (int i = 0; i < M_STEP && m_begin + i < m; i++) {\n            __m512i* s = (__m512i*)(src + (m_begin + i) * k + k_block_begin + k_begin);\n            __m512i* d =\n                (__m512i*)(a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP + i * K_STEP);\n            avx512_copy_32xbf16(s, d);\n          }\n        }\n      }\n    }\n  }\n\n  ggml_bf16_t* get_submat(int m, int k, int m_begin, int k_begin) {\n    int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n    int k_block_begin = k_begin / K_BLOCK * K_BLOCK;\n    k_begin -= k_block_begin;\n    int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n    return a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP;\n  }\n};\n\n// ============================================================================\n// BufferB\n// ============================================================================\n\n/**\n * @brief BF16 BufferB\n * Physical layout (in units of bf16 elements):\n * The logical matrix B is (n, k) row-major (for NT GEMM); n is partitioned by N_BLOCK.\n * Storage order:\n *   n_block (N_BLOCK rows) → k_block (K_BLOCK columns) → n_step (N_STEP rows) → k_step (K_STEP columns)\n *   → (N_STEP×K_STEP) tile; within each tile, the two 16×16 sub-blocks are additionally transposed\n *   to match the VNNI layout of the AMX BTile (TILE_K/VNNI_BLK × TILE_N*VNNI_BLK).\n * It can therefore be viewed as 6D:\n *   b[n_blocks][k_blocks][n_steps][k_steps][N_STEP][K_STEP],\n *   n_blocks = ceil(n / N_BLOCK), k_blocks = ceil(k / K_BLOCK),\n *   n_steps = N_BLOCK / N_STEP, k_steps = K_BLOCK / 
K_STEP (tail blocks may be smaller).\n * get_submat(n_begin, k_begin) returns the start address of a contiguous (N_STEP×K_STEP) tile.\n * @tparam K Kernel type\n */\n\ntemplate <typename K>\nstruct BufferBBF16Impl {\n  ggml_bf16_t* b;\n  int n, k;\n  static constexpr bool SCALE = false;\n  static constexpr int N_STEP = K::N_STEP;\n  static constexpr int K_STEP = K::K_STEP;\n  static constexpr int N_BLOCK = K::N_BLOCK;\n  static constexpr int K_BLOCK = K::K_BLOCK;\n  static constexpr int TILE_N = K::TILE_N;\n  static size_t required_size(int n, int k) { return sizeof(ggml_bf16_t) * n * k; }\n\n  BufferBBF16Impl(int n, int k, void* ptr) : n(n), k(k) {\n    assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n    assert(n % N_STEP == 0);\n    assert(k % K_STEP == 0);\n    b = reinterpret_cast<ggml_bf16_t*>(ptr);\n  }\n  void set_data(void* new_ptr) { b = reinterpret_cast<ggml_bf16_t*>(new_ptr); }\n\n  void from_mat(ggml_bf16_t* src, int ith, int nth) {\n    auto [n_start, n_end] = K::split_range_n(n, ith, nth);\n    int n_block_begin = n_start;\n    int n_block_size = n_end - n_block_begin;\n    for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n      for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {\n        int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n        for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {\n          for (int i = 0; i < N_STEP; i++) {\n            __m512i* s = (__m512i*)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin);\n            __m512i* d = (__m512i*)(b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size +\n                                    k_begin * N_STEP + i * K_STEP);\n            avx512_copy_32xbf16(s, d);\n          }\n          transpose_16x16_32bit((__m512i*)(b + n_block_begin * k + k_block_begin * n_block_size +\n                                           n_begin * k_block_size + k_begin * N_STEP));\n          transpose_16x16_32bit((__m512i*)(b + n_block_begin * k + 
k_block_begin * n_block_size +\n                                           n_begin * k_block_size + k_begin * N_STEP + TILE_N * K_STEP));\n        }\n      }\n    }\n  }\n  ggml_bf16_t* get_submat(int n, int k, int n_begin, int k_begin) {\n    int n_block_begin = n_begin / N_BLOCK * N_BLOCK;\n    n_begin -= n_block_begin;\n    int n_block_size = std::min(N_BLOCK, n - n_block_begin);\n    int k_block_begin = k_begin / K_BLOCK * K_BLOCK;\n    k_begin -= k_block_begin;\n    int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n    return b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP;\n  }\n};\n\n/**\n * @brief FP8 weight buffer\n *\n * Stores a weight matrix in FP8 format, with one scaling factor per 128x128 block.\n * This matches the native-precision format of DeepSeek V3.2.\n *\n * @tparam K Kernel type\n */\ntemplate <typename K>\nstruct BufferBFP8Impl {\n  uint8_t* b;              // FP8 weight\n  float* d;                // scale_inv [n / k_group_size, k / k_group_size]\n  int n, k, k_group_size;  // k_group_size = 128 in DeepSeek\n\n  static constexpr int N_STEP = K::N_STEP;\n  static constexpr int K_STEP = K::K_STEP;\n  static constexpr int N_BLOCK = K::N_BLOCK;\n  static constexpr int K_BLOCK = K::K_BLOCK;\n  static constexpr bool SCALE = true;\n\n  /**\n   * @brief Compute the required memory size\n   */\n  static size_t required_size(int n, int k, int k_group_size) {\n    int n_blocks_n = (n + k_group_size - 1) / k_group_size;\n    int n_blocks_k = (k + k_group_size - 1) / k_group_size;\n    return sizeof(uint8_t) * n * k + sizeof(float) * n_blocks_n * n_blocks_k;\n  }\n\n  /**\n   * @brief Constructor\n   */\n  BufferBFP8Impl(int n, int k, int k_group_size, void* ptr) : n(n), k(k), k_group_size(k_group_size) { set_data(ptr); }\n\n  void set_data(void* ptr) {\n    assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n    b = reinterpret_cast<uint8_t*>(ptr);\n    d = reinterpret_cast<float*>(b + (size_t)n * k);\n  }\n\n  static constexpr int mat_offset[8] = {0, 2, 4, 6, 1, 3, 5, 7};  // fp8 matrix offset for reordering\n  
/**\n   * @brief Load from raw FP8 weights (already in quantized format)\n   *\n   * @param b_src FP8 weight source data (n-major, n×k)\n   * @param d_src FP32 scale_inv source data (n-major, ceil(n/128)×ceil(k/128))\n   */\n  void from_mat(const uint8_t* b_src, const float* d_src, int ith, int nth) {\n    assert(b != nullptr && d != nullptr);\n    assert(N_STEP == 32 && K_STEP == 32);  // from mat block copy assumes this\n\n    // Copy scales (per 128x128 block). Each thread copies its own n-block range.\n    const int n_blocks_k = (k + k_group_size - 1) / k_group_size;\n    if (d_src != nullptr) {\n      auto [n_start, n_end] = K::split_range_n(n, ith, nth);\n      int bn_start = n_start / k_group_size;\n      int bn_end = (n_end + k_group_size - 1) / k_group_size;\n      memcpy(d + bn_start * n_blocks_k, d_src + bn_start * n_blocks_k,\n             sizeof(float) * (bn_end - bn_start) * n_blocks_k);\n    }\n\n    // Reorder FP8 weights into KT block-major layout (same panel->tile order as BF16 BufferB).\n    auto [n_start, n_end] = K::split_range_n(n, ith, nth);\n    int n_block_begin = n_start;\n    int n_block_size = n_end - n_block_begin;\n    for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n      int n_step_size = std::min(N_STEP, n_block_size - n_begin);\n      for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {\n        int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n        for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {\n          int k_step_size = std::min(K_STEP, k_block_size - k_begin);\n          // [k_step_size, n_step_size] block copy\n          const uint8_t* block_b_src = b_src + (size_t)(n_block_begin + n_begin) * k + k_block_begin + k_begin;\n          uint64_t* block_b_dst =\n              reinterpret_cast<uint64_t*>(b + (size_t)n_block_begin * k + (size_t)k_block_begin * n_block_size +\n                                          (size_t)n_begin * k_block_size + (size_t)k_begin * N_STEP);\n          for (int i = 0; i < 8; i++) 
{\n            const uint16_t* s = reinterpret_cast<const uint16_t*>(block_b_src + (size_t)i * k * 4);\n            for (int j = 0; j < 16; j++) {\n              uint64_t val = (((uint64_t)s[j])) | (((uint64_t)s[j + (k / 2) * 1]) << 16) |\n                             (((uint64_t)s[j + (k / 2) * 2]) << 32) | (((uint64_t)s[j + (k / 2) * 3]) << 48);\n              block_b_dst[8 * j + mat_offset[i]] = val;\n            }\n          }\n        }\n      }\n    }\n  }\n\n  /**\n   * @brief get scale_inv\n   */\n  float* get_scale(int n, int n_begin, int k, int k_begin) {\n    int n_blocks_k = (k + k_group_size - 1) / k_group_size;\n    int bn = n_begin / k_group_size;\n    int bk = k_begin / k_group_size;\n    return d + bn * n_blocks_k + bk;\n  }\n\n  /**\n   * @brief Get the sub-matrix pointer\n   */\n  uint8_t* get_submat(int n, int k, int n_begin, int k_begin) {\n    int n_block_begin = n_begin / N_BLOCK * N_BLOCK;\n    n_begin -= n_block_begin;\n    int n_block_size = std::min(N_BLOCK, n - n_block_begin);\n    int k_block_begin = k_begin / K_BLOCK * K_BLOCK;\n    k_begin -= k_block_begin;\n    int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n    return b + (size_t)n_block_begin * k + (size_t)k_block_begin * n_block_size + (size_t)n_begin * k_block_size +\n           (size_t)k_begin * N_STEP;\n  }\n\n  /**\n   * @brief Inverse mapping for mat_offset used in to_mat\n   * mat_offset = {0, 2, 4, 6, 1, 3, 5, 7}\n   * inv_mat_offset[mat_offset[i]] = i\n   */\n  static constexpr int inv_mat_offset[8] = {0, 4, 1, 5, 2, 6, 3, 7};\n\n  /**\n   * @brief Unpack FP8 weights from KT block-major layout back to n-major layout\n   *\n   * This is the inverse operation of from_mat.\n   *\n   * @param b_dst FP8 output buffer (n-major, n×k)\n   * @param d_dst FP32 scale_inv output buffer (n-major, ceil(n/128)×ceil(k/128))\n   * @param ith Thread index\n   * @param nth Total number of threads\n   */\n  void to_mat(uint8_t* b_dst, float* d_dst, int ith, int nth) const {\n    assert(b != nullptr && d != 
nullptr);\n    assert(N_STEP == 32 && K_STEP == 32);\n\n    // Calculate N_BLOCK range for this thread\n    // Unlike split_range_n which gives one N_BLOCK per thread, we need to handle\n    // the case where nth < n/N_BLOCK (fewer threads than blocks)\n    int total_n_blocks = (n + N_BLOCK - 1) / N_BLOCK;\n    int blocks_per_thread = (total_n_blocks + nth - 1) / nth;\n    int start_n_block_idx = ith * blocks_per_thread;\n    int end_n_block_idx = std::min((ith + 1) * blocks_per_thread, total_n_blocks);\n\n    // Copy scales (per 128x128 block). Each thread copies its own n-block range.\n    const int n_blocks_k = (k + k_group_size - 1) / k_group_size;\n    if (d_dst != nullptr) {\n      int bn_start = start_n_block_idx;\n      int bn_end = end_n_block_idx;\n      memcpy(d_dst + bn_start * n_blocks_k, d + bn_start * n_blocks_k,\n             sizeof(float) * (bn_end - bn_start) * n_blocks_k);\n    }\n\n    // Reorder FP8 weights back to n-major layout (inverse of from_mat)\n    // Process each N_BLOCK assigned to this thread\n    for (int n_block_idx = start_n_block_idx; n_block_idx < end_n_block_idx; n_block_idx++) {\n      int n_block_begin = n_block_idx * N_BLOCK;\n      int n_block_size = std::min(N_BLOCK, n - n_block_begin);\n\n      for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n        for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {\n          int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n          for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {\n            // Source: packed layout (KT block-major)\n            const uint64_t* block_b_src =\n                reinterpret_cast<const uint64_t*>(b + (size_t)n_block_begin * k + (size_t)k_block_begin * n_block_size +\n                                                  (size_t)n_begin * k_block_size + (size_t)k_begin * N_STEP);\n\n            // Destination: n-major layout\n            uint8_t* block_b_dst = b_dst + (size_t)(n_block_begin 
+ n_begin) * k + k_block_begin + k_begin;\n\n            // Inverse of from_mat transformation\n            for (int packed_i = 0; packed_i < 8; packed_i++) {\n              int i = inv_mat_offset[packed_i];\n              uint16_t* d_row = reinterpret_cast<uint16_t*>(block_b_dst + (size_t)i * k * 4);\n              for (int j = 0; j < 16; j++) {\n                uint64_t val = block_b_src[8 * j + packed_i];\n                d_row[j] = (uint16_t)(val & 0xFFFF);\n                d_row[j + (k / 2) * 1] = (uint16_t)((val >> 16) & 0xFFFF);\n                d_row[j + (k / 2) * 2] = (uint16_t)((val >> 32) & 0xFFFF);\n                d_row[j + (k / 2) * 3] = (uint16_t)((val >> 48) & 0xFFFF);\n              }\n            }\n          }\n        }\n      }\n    }\n  }\n};\n\n// ============================================================================\n// BufferCFP32Impl: FP32 output buffer\n// ============================================================================\n\n/**\n * @brief FP32 output buffer\n *\n * Stores accumulators in FP32 format and supports conversion to BF16 output.\n *\n * @tparam K Kernel type\n */\ntemplate <typename K>\nstruct BufferCFP32Impl {\n  float* c;\n  int max_m, n;\n  static constexpr int M_STEP = K::M_STEP;\n  static constexpr int N_STEP = K::N_STEP;\n  static constexpr int N_BLOCK = K::N_BLOCK;\n  // Physical layout (in units of float elements):\n  // The logical matrix C is (max_m, n) row-major; max_m is a multiple of M_STEP,\n  // and n is partitioned by N_BLOCK.\n  // Storage order:\n  //   n_block (N_BLOCK columns) → m_block (M_STEP rows) → n_step (N_STEP columns) → (M_STEP×N_STEP) row-major tile.\n  // It can therefore be viewed as 5D:\n  //   c[n_blocks][m_blocks][n_steps][M_STEP][N_STEP],\n  //   n_blocks = ceil(n / N_BLOCK), m_blocks = max_m / M_STEP,\n  //   n_steps = N_BLOCK / N_STEP (tail blocks may be smaller).\n  // get_submat(m_begin, n_begin) returns the start address of a contiguous (M_STEP×N_STEP) tile.\n\n  static size_t required_size(int max_m, int n) { return sizeof(float) * max_m * n; }\n\n  BufferCFP32Impl(int max_m, int n, void* ptr) : max_m(max_m), n(n) {\n    assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n    assert(max_m % M_STEP == 0);\n    assert(n 
% N_STEP == 0);\n    c = reinterpret_cast<float*>(ptr);\n  }\n\n  void set_data(void* new_ptr) { c = reinterpret_cast<float*>(new_ptr); }\n\n  void to_mat(int m, ggml_bf16_t* dst, int ith, int nth) {\n    assert(m <= max_m);\n    auto [n_start, n_end] = K::split_range_n(n, ith, nth);\n    int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n    int n_block_begin = n_start;\n    int n_block_size = n_end - n_block_begin;\n    for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {\n      for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n        for (int i = 0; i < M_STEP && m_begin + i < m; i++) {\n          __m512* x0 =\n              (__m512*)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP);\n          __m512* x1 =\n              (__m512*)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP + 16);\n          avx512_32xfp32_to_32xbf16(x0, x1, (__m512i*)(dst + (m_begin + i) * n + n_block_begin + n_begin));\n        }\n      }\n    }\n  }\n\n  float* get_submat(int m, int n, int m_begin, int n_begin) {\n    int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n    int n_block_begin = n_begin / N_BLOCK * N_BLOCK;\n    int n_block_size = std::min(N_BLOCK, n - n_block_begin);\n    n_begin -= n_block_begin;\n    return c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP;\n  }\n};\n\ntemplate <typename K>\nstruct BufferCFP32ReduceImpl {\n  float* c;\n  float* reduce_buf;\n  int max_m, n;\n\n  static constexpr int M_STEP = K::M_STEP;\n  static constexpr int N_STEP = K::N_STEP;\n  static constexpr int N_BLOCK = K::N_BLOCK;\n\n  static size_t required_size(int max_m, int n) { return sizeof(float) * (size_t)max_m * n * 2; }\n\n  BufferCFP32ReduceImpl(int max_m, int n, void* ptr) : max_m(max_m), n(n) {\n    assert(max_m % M_STEP == 0);\n    assert(n % N_STEP == 0);\n    set_data(ptr);\n  }\n\n  void set_data(void* ptr) {\n    
assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n    c = reinterpret_cast<float*>(ptr);\n    reduce_buf = c + (size_t)max_m * n;\n  }\n\n  void to_mat(int m, ggml_bf16_t* dst, int ith, int nth) {\n    assert(m <= max_m);\n    auto [n_start, n_end] = K::split_range_n(n, ith, nth);\n    int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n    int n_block_begin = n_start;\n    int n_block_size = n_end - n_block_begin;\n    for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {\n      for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n        for (int i = 0; i < M_STEP && m_begin + i < m; i++) {\n          __m512* x0 =\n              (__m512*)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP);\n          __m512* x1 =\n              (__m512*)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP + 16);\n          avx512_32xfp32_to_32xbf16(x0, x1, (__m512i*)(dst + (m_begin + i) * n + n_block_begin + n_begin));\n        }\n      }\n    }\n  }\n\n  float* get_submat(int m, int n, int m_begin, int n_begin) {\n    int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n    int n_block_begin = n_begin / N_BLOCK * N_BLOCK;\n    int n_block_size = std::min(N_BLOCK, n - n_block_begin);\n    n_begin -= n_block_begin;\n    return c + (size_t)m_block_size * n_block_begin + (size_t)m_begin * n_block_size + (size_t)n_begin * M_STEP;\n  }\n\n  float* get_reduce_submat(int m, int n, int m_begin, int n_begin) {\n    int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n    int n_block_begin = n_begin / N_BLOCK * N_BLOCK;\n    int n_block_size = std::min(N_BLOCK, n - n_block_begin);\n    n_begin -= n_block_begin;\n    return reduce_buf + (size_t)m_block_size * n_block_begin + (size_t)m_begin * n_block_size +\n           (size_t)n_begin * M_STEP;\n  }\n};\n\n// ============================================================================\n// BufferBFP8PerChannelImpl: FP8 weight buffer (per-channel 
quantization)\n// ============================================================================\n\n/**\n * @brief FP8 per-channel weight buffer\n *\n * Stores a weight matrix in FP8 format, with one scaling factor per output channel (row).\n * This matches the per-channel quantization format of GLM-4.7-FP8.\n *\n * Differences from BufferBFP8Impl (block-wise):\n * - Block-wise: scale shape = [n/128, k/128], one scale per 128x128 block\n * - Per-channel: scale shape = [n], one scale per row\n *\n * @tparam K Kernel type\n */\ntemplate <typename K>\nstruct BufferBFP8PerChannelImpl {\n  uint8_t* b;  // FP8 weight [n, k]\n  float* d;    // per-channel scale [n]\n  int n, k;\n\n  static constexpr int N_STEP = K::N_STEP;\n  static constexpr int K_STEP = K::K_STEP;\n  static constexpr int N_BLOCK = K::N_BLOCK;\n  static constexpr int K_BLOCK = K::K_BLOCK;\n  static constexpr bool SCALE = true;\n  static constexpr bool PER_CHANNEL = true;\n\n  /**\n   * @brief Compute the required memory size\n   * weight: n * k bytes (FP8)\n   * scale: n * sizeof(float) bytes\n   */\n  static size_t required_size(int n, int k) { return sizeof(uint8_t) * n * k + sizeof(float) * n; }\n\n  /**\n   * @brief Constructor\n   */\n  BufferBFP8PerChannelImpl(int n, int k, void* ptr) : n(n), k(k) { set_data(ptr); }\n\n  void set_data(void* ptr) {\n    assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n    b = reinterpret_cast<uint8_t*>(ptr);\n    d = reinterpret_cast<float*>(b + (size_t)n * k);\n  }\n\n  static constexpr int mat_offset[8] = {0, 2, 4, 6, 1, 3, 5, 7};  // fp8 matrix offset for reordering\n\n  /**\n   * @brief Load from raw FP8 weights (per-channel quantized format)\n   *\n   * @param b_src FP8 weight source data (n-major, n×k)\n   * @param d_src FP32 per-channel scale source data (shape: [n] or [n, 1])\n   */\n  void from_mat(const uint8_t* b_src, const float* d_src, int ith, int nth) {\n    assert(b != nullptr && d != nullptr);\n    assert(N_STEP == 32 && K_STEP == 32);\n\n    // Copy per-channel scales. 
Each thread copies its own n-block range.\n    if (d_src != nullptr) {\n      auto [n_start, n_end] = K::split_range_n(n, ith, nth);\n      memcpy(d + n_start, d_src + n_start, sizeof(float) * (n_end - n_start));\n    }\n\n    // Reorder FP8 weights into KT block-major layout (same as BufferBFP8Impl)\n    auto [n_start, n_end] = K::split_range_n(n, ith, nth);\n    int n_block_begin = n_start;\n    int n_block_size = n_end - n_block_begin;\n    for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n      int n_step_size = std::min(N_STEP, n_block_size - n_begin);\n      for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {\n        int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n        for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {\n          int k_step_size = std::min(K_STEP, k_block_size - k_begin);\n          // [k_step_size, n_step_size] block copy\n          const uint8_t* block_b_src = b_src + (size_t)(n_block_begin + n_begin) * k + k_block_begin + k_begin;\n          uint64_t* block_b_dst =\n              reinterpret_cast<uint64_t*>(b + (size_t)n_block_begin * k + (size_t)k_block_begin * n_block_size +\n                                          (size_t)n_begin * k_block_size + (size_t)k_begin * N_STEP);\n          for (int i = 0; i < 8; i++) {\n            const uint16_t* s = reinterpret_cast<const uint16_t*>(block_b_src + (size_t)i * k * 4);\n            for (int j = 0; j < 16; j++) {\n              uint64_t val = (((uint64_t)s[j])) | (((uint64_t)s[j + (k / 2) * 1]) << 16) |\n                             (((uint64_t)s[j + (k / 2) * 2]) << 32) | (((uint64_t)s[j + (k / 2) * 3]) << 48);\n              block_b_dst[8 * j + mat_offset[i]] = val;\n            }\n          }\n        }\n      }\n    }\n  }\n\n  /**\n   * @brief Get a pointer to the per-channel scales starting at row n_begin\n   */\n  float* get_scale(int n_begin) { return d + n_begin; }\n\n  /**\n   * @brief Get the sub-matrix pointer\n   */\n  uint8_t* get_submat(int n, int k, 
int n_begin, int k_begin) {\n    int n_block_begin = n_begin / N_BLOCK * N_BLOCK;\n    n_begin -= n_block_begin;\n    int n_block_size = std::min(N_BLOCK, n - n_block_begin);\n    int k_block_begin = k_begin / K_BLOCK * K_BLOCK;\n    k_begin -= k_block_begin;\n    int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n    return b + (size_t)n_block_begin * k + (size_t)k_block_begin * n_block_size + (size_t)n_begin * k_block_size +\n           (size_t)k_begin * N_STEP;\n  }\n};\n\n}  // namespace amx\n\n#endif  // AMX_RAW_BUFFERS_HPP\n"
  },
  {
    "path": "kt-kernel/operators/amx/la/amx_raw_kernels.hpp",
    "content": "#ifndef AMX_RAW_KERNELS_HPP\n#define AMX_RAW_KERNELS_HPP\n\n#include <algorithm>\n#include <cassert>\n#include <cmath>\n#include <cstdint>\n#include <string>\n\n#include \"amx_config.hpp\"\n#include \"amx_raw_buffers.hpp\"\n#include \"amx_utils.hpp\"\n#include \"llama.cpp/ggml-impl.h\"\n\nnamespace amx {\n\nstruct GemmKernel224BF16 {\n  using dt = ggml_bf16_t;\n  using output_t = float;\n  static constexpr double ELEMENT_SIZE = 2;\n  static const int TILE_M = 16;\n  static const int TILE_K = 32;\n  static const int TILE_N = 16;\n  static const int VNNI_BLK = 2;\n\n  static const int M_STEP = TILE_M * 2;\n  static const int N_STEP = TILE_N * 2;\n  static const int K_STEP = TILE_K;\n\n  static inline const int N_BLOCK = 256;\n  static inline const int K_BLOCK = 1792;\n  static std::string name() { return \"BF16\"; }\n\n  static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }\n\n  static std::pair<int, int> split_range_n(int n, int ith, int nth) {\n    int n_start = N_BLOCK * ith;\n    int n_end = std::min(n, N_BLOCK * (ith + 1));\n    return {n_start, n_end};\n  }\n\n  static void config() {\n#ifdef HAVE_AMX\n    enable_amx();\n    TileConfig tile_config;\n\n    // size is 16 x 32\n    for (int i = 0; i < 2; i++) tile_config.set_row_col(i, TILE_M, TILE_K * sizeof(dt));\n\n    // size is 16 x 32\n    for (int i = 2; i < 4; i++) tile_config.set_row_col(i, TILE_K / VNNI_BLK, TILE_N * VNNI_BLK * sizeof(dt));\n\n    // size is 16 x 16\n    for (int i = 4; i < 8; i++) tile_config.set_row_col(i, TILE_M, TILE_N * sizeof(output_t));\n\n    tile_config.set_config();\n#endif\n  }\n\n  static void load_a(dt* a, size_t lda) {\n#ifdef HAVE_AMX\n    _tile_loadd(0, a, lda);\n    _tile_loadd(1, offset_pointer(a, lda * TILE_M), lda);\n#else\n    (void)a;\n    (void)lda;\n#endif\n  }\n\n  static void load_b(dt* b, size_t ldb) {\n#ifdef HAVE_AMX\n    _tile_loadd(2, b, ldb);\n    _tile_loadd(3, offset_pointer(b, ldb * TILE_N), ldb);\n#else\n    
(void)b;\n    (void)ldb;\n#endif\n  }\n\n  static void clean_c() {\n#ifdef HAVE_AMX\n    _tile_zero(4);\n    _tile_zero(5);\n    _tile_zero(6);\n    _tile_zero(7);\n#endif\n  }\n\n  static void load_c(output_t* c, size_t ldc) {\n#ifdef HAVE_AMX\n    _tile_loadd(4, c, ldc);\n    _tile_loadd(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);\n    _tile_loadd(6, offset_pointer(c, ldc * TILE_M), ldc);\n    _tile_loadd(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc);\n#else\n    (void)c;\n    (void)ldc;\n#endif\n  }\n\n  static void store_c(output_t* c, size_t ldc) {\n#ifdef HAVE_AMX\n    _tile_stored(4, c, ldc);\n    _tile_stored(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);\n    _tile_stored(6, offset_pointer(c, ldc * TILE_M), ldc);\n    _tile_stored(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc);\n#else\n    (void)c;\n    (void)ldc;\n#endif\n  }\n\n  static void run_tile() {\n#ifdef HAVE_AMX\n    _tile_dpbf16ps(4, 0, 2);\n    _tile_dpbf16ps(5, 0, 3);\n    _tile_dpbf16ps(6, 1, 2);\n    _tile_dpbf16ps(7, 1, 3);\n#endif\n  }\n  using BufferA = BufferABF16Impl<GemmKernel224BF16>;\n  using BufferB = BufferBBF16Impl<GemmKernel224BF16>;\n  using BufferC = BufferCFP32Impl<GemmKernel224BF16>;\n\n  // Basic AVX kernel for BF16: process entire K_BLOCK\n  static void avx_kernel(int m, int n, int k, int m_begin, int n_begin, int k_block_begin, float* c, BufferA* ba,\n                         BufferB* bb) {\n    __m512* c512 = (__m512*)c;\n    int m_block_end = std::min(m - m_begin, M_STEP);\n\n    // Zero out accumulator at the start of k_block\n    if (k_block_begin == 0) {\n      for (int m_i = 0; m_i < m_block_end; m_i++) {\n        c512[m_i * 2] = _mm512_setzero_ps();\n        c512[m_i * 2 + 1] = _mm512_setzero_ps();\n      }\n    }\n\n    // Process entire K_BLOCK\n    for (int k_begin = 0; k_begin < K_BLOCK && k_block_begin + k_begin < k; k_begin += K_STEP) {\n      int32_t* a32 = (int32_t*)ba->get_submat(m, k, 
m_begin, k_block_begin + k_begin);\n      __m512bh* b512 = (__m512bh*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);\n\n      for (int m_i = 0; m_i < m_block_end; m_i++) {\n        for (int k_i = 0; k_i < 16; k_i++) {\n          __m512bh ma = (__m512bh)_mm512_set1_epi32(a32[m_i * 16 + k_i]);\n          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma, b512[k_i]);\n          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma, b512[16 + k_i]);\n        }\n      }\n    }\n  }\n\n  // Optimized AVX kernel: process 4 k_i at once, unroll m rows by 2\n  static void avx_kernel_4(int m, int n, int k, int m_begin, int n_begin, int k_block_begin, float* c, BufferA* ba,\n                           BufferB* bb) {\n    __m512* c512 = (__m512*)c;\n    int m_block_end = std::min(m - m_begin, M_STEP);\n\n    // Zero out accumulator at the start of k_block\n    if (k_block_begin == 0) {\n      for (int m_i = 0; m_i < m_block_end; m_i++) {\n        c512[m_i * 2] = _mm512_setzero_ps();\n        c512[m_i * 2 + 1] = _mm512_setzero_ps();\n      }\n    }\n\n    // Process entire K_BLOCK\n    for (int k_begin = 0; k_begin < K_BLOCK && k_block_begin + k_begin < k; k_begin += K_STEP) {\n      int32_t* a32 = (int32_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);\n      __m512bh* b512 = (__m512bh*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);\n\n      // Process 4 k_i at once - load B vectors and reuse across all m rows\n      for (int k_i = 0; k_i < 16; k_i += 4) {\n        // Load 4 B vector pairs (lo and hi for each k_i)\n        __m512bh b0_lo = b512[k_i];\n        __m512bh b0_hi = b512[16 + k_i];\n        __m512bh b1_lo = b512[k_i + 1];\n        __m512bh b1_hi = b512[16 + k_i + 1];\n        __m512bh b2_lo = b512[k_i + 2];\n        __m512bh b2_hi = b512[16 + k_i + 2];\n        __m512bh b3_lo = b512[k_i + 3];\n        __m512bh b3_hi = b512[16 + k_i + 3];\n\n        // Process m rows - unroll by 2 for better ILP\n        int m_i = 0;\n        for (; 
m_i + 1 < m_block_end; m_i += 2) {\n          // Load A values for 2 rows, 4 k_i each\n          __m512bh ma0_0 = (__m512bh)_mm512_set1_epi32(a32[m_i * 16 + k_i]);\n          __m512bh ma1_0 = (__m512bh)_mm512_set1_epi32(a32[m_i * 16 + k_i + 1]);\n          __m512bh ma2_0 = (__m512bh)_mm512_set1_epi32(a32[m_i * 16 + k_i + 2]);\n          __m512bh ma3_0 = (__m512bh)_mm512_set1_epi32(a32[m_i * 16 + k_i + 3]);\n          __m512bh ma0_1 = (__m512bh)_mm512_set1_epi32(a32[(m_i + 1) * 16 + k_i]);\n          __m512bh ma1_1 = (__m512bh)_mm512_set1_epi32(a32[(m_i + 1) * 16 + k_i + 1]);\n          __m512bh ma2_1 = (__m512bh)_mm512_set1_epi32(a32[(m_i + 1) * 16 + k_i + 2]);\n          __m512bh ma3_1 = (__m512bh)_mm512_set1_epi32(a32[(m_i + 1) * 16 + k_i + 3]);\n\n          // Process row 0\n          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma0_0, b0_lo);\n          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma0_0, b0_hi);\n          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma1_0, b1_lo);\n          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma1_0, b1_hi);\n          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma2_0, b2_lo);\n          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma2_0, b2_hi);\n          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma3_0, b3_lo);\n          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma3_0, b3_hi);\n\n          // Process row 1\n          c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma0_1, b0_lo);\n          c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma0_1, b0_hi);\n          c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma1_1, b1_lo);\n          c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma1_1, b1_hi);\n          c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma2_1, b2_lo);\n          c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma2_1, b2_hi);\n          
c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma3_1, b3_lo);\n          c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma3_1, b3_hi);\n        }\n        // Handle remaining row\n        for (; m_i < m_block_end; m_i++) {\n          __m512bh ma0 = (__m512bh)_mm512_set1_epi32(a32[m_i * 16 + k_i]);\n          __m512bh ma1 = (__m512bh)_mm512_set1_epi32(a32[m_i * 16 + k_i + 1]);\n          __m512bh ma2 = (__m512bh)_mm512_set1_epi32(a32[m_i * 16 + k_i + 2]);\n          __m512bh ma3 = (__m512bh)_mm512_set1_epi32(a32[m_i * 16 + k_i + 3]);\n\n          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma0, b0_lo);\n          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma0, b0_hi);\n          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma1, b1_lo);\n          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma1, b1_hi);\n          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma2, b2_lo);\n          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma2, b2_hi);\n          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma3, b3_lo);\n          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma3, b3_hi);\n        }\n      }\n    }\n  }\n\n  // AMX kernel for BF16: process entire K_BLOCK using AMX tiles\n  static void amx_kernel(int m, int n, int k, int m_begin, int n_begin, int k_block_begin, float* c, BufferA* ba,\n                         BufferB* bb) {\n    if (k_block_begin == 0) {\n      clean_c();\n    } else {\n      load_c(c, N_STEP * sizeof(float));\n    }\n\n    for (int k_begin = 0; k_begin < K_BLOCK && k_block_begin + k_begin < k; k_begin += K_STEP) {\n      load_a(ba->get_submat(m, k, m_begin, k_block_begin + k_begin), K_STEP * sizeof(ggml_bf16_t));\n      load_b(bb->get_submat(n, k, n_begin, k_block_begin + k_begin), K_STEP * sizeof(ggml_bf16_t));\n      run_tile();\n    }\n\n    store_c(c, N_STEP * sizeof(float));\n  }\n};\n\n// FP8 (e4m3) AMX kernel that mirrors the GemmKernel224BF16 
interface.\nstruct GemmKernel224FP8 {\n  using fp8_t = uint8_t;\n  using output_t = float;\n\n  static constexpr double ELEMENT_SIZE = 1.0;\n  static const int TILE_M = 16;\n  static const int TILE_K = 32;\n  static const int TILE_N = 16;\n  static const int VNNI_BLK = 2;\n\n  static const int M_STEP = TILE_M * 2;\n  static const int N_STEP = TILE_N * 2;\n  static const int K_STEP = TILE_K;\n\n  static inline const int BLOCK_SIZE = 128;  // 128 x 128 block quantization\n  static inline const int N_BLOCK = 128;\n  static inline const int K_BLOCK = 7168;\n\n  static std::string name() { return \"FP8\"; }\n\n  static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }\n\n  static std::pair<int, int> split_range_n(int n, int ith, int nth) {\n    int n_start = N_BLOCK * ith;\n    int n_end = std::min(n, N_BLOCK * (ith + 1));\n    return {n_start, n_end};\n  }\n\n  static void config() {}\n\n  // FP8->BF16 conversion lookup tables (public for reuse by GemmKernel224FP8PerChannel)\n  alignas(64) static constexpr uint8_t bf16_hi_0_val[64] = {\n      0x00, 0x3b, 0x3b, 0x3b, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c, 0x3c,\n      0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d, 0x3d,\n      0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e, 0x3e,\n      0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f,\n  };\n  alignas(64) static constexpr uint8_t bf16_hi_1_val[64] = {\n      0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40, 0x40,\n      0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,\n      0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42,\n      0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43, 0x43,\n  };\n  alignas(64) static constexpr 
uint8_t bf16_lo_0_val[64] = {\n      0x00, 0x00, 0x80, 0xc0, 0x00, 0x20, 0x40, 0x60, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0,\n      0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0,\n      0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0,\n      0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0,\n  };\n  alignas(64) static constexpr uint8_t bf16_lo_1_val[64] = {\n      0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0,\n      0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0,\n      0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0,\n      0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xa0, 0xb0, 0xc0, 0xd0, 0xe0, 0xf0,\n  };\n  // _mm512_set1_epi8 is not constexpr; keep it as a static cached value\n  alignas(64) static const __m512i sign_mask_val;\n  static inline __m512i bf16_hi_0_mask() { return _mm512_load_si512((__m512i const*)bf16_hi_0_val); }\n  static inline __m512i bf16_hi_1_mask() { return _mm512_load_si512((__m512i const*)bf16_hi_1_val); }\n  static inline __m512i bf16_lo_0_mask() { return _mm512_load_si512((__m512i const*)bf16_lo_0_val); }\n  static inline __m512i bf16_lo_1_mask() { return _mm512_load_si512((__m512i const*)bf16_lo_1_val); }\n  static inline __m512i sign_mask() { return _mm512_set1_epi8(0x80); }\n  using BufferA = BufferABF16Impl<GemmKernel224FP8>;\n  using BufferB = BufferBFP8Impl<GemmKernel224FP8>;\n  using BufferC = BufferCFP32ReduceImpl<GemmKernel224FP8>;\n\n  static inline std::pair<__m512i, __m512i> fp8x64_to_bf16x64(__m512i bfp8_512) {\n    // fp8->bf16\n    __m512i b_hi = _mm512_permutex2var_epi8(bf16_hi_0_mask(), bfp8_512, bf16_hi_1_mask());\n    __m512i b_lo = _mm512_permutex2var_epi8(bf16_lo_0_mask(), bfp8_512, bf16_lo_1_mask());\n    
b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask(), bfp8_512), b_hi);\n    __m512i bbf16_0 = _mm512_unpacklo_epi8(b_lo, b_hi);\n    __m512i bbf16_1 = _mm512_unpackhi_epi8(b_lo, b_hi);\n    return {bbf16_0, bbf16_1};\n  }\n  // Optimized AVX kernel: process entire k_group_size\n  // Load all data first, then convert all, then compute all\n  // This gives compiler more freedom to schedule instructions\n  static void avx_kernel(int m, int n, int k, int m_begin, int n_begin, int k_group_begin, float* c, BufferA* ba,\n                         BufferB* bb, int k_group_size) {\n    const __m512i bf16_hi_0_val = bf16_hi_0_mask();\n    const __m512i bf16_hi_1_val = bf16_hi_1_mask();\n    const __m512i bf16_lo_0_val = bf16_lo_0_mask();\n    const __m512i bf16_lo_1_val = bf16_lo_1_mask();\n    const __m512i sign_mask_val = sign_mask();\n\n    __m512* c512 = (__m512*)c;\n    int m_block_end = std::min(m - m_begin, M_STEP);\n\n    // Zero out accumulator at the start\n    for (int m_i = 0; m_i < m_block_end; m_i++) {\n      c512[m_i * 2] = _mm512_setzero_ps();\n      c512[m_i * 2 + 1] = _mm512_setzero_ps();\n    }\n\n    // Process entire k_group_size\n    for (int k_begin = 0; k_begin < k_group_size && k_group_begin + k_begin < k; k_begin += K_STEP) {\n      ggml_bf16_t* abf16 = (ggml_bf16_t*)ba->get_submat(m, k, m_begin, k_group_begin + k_begin);\n      __m512i* bfp8_512 = (__m512i*)bb->get_submat(n, k, n_begin, k_group_begin + k_begin);\n\n      for (int m_i = 0; m_i < m_block_end; m_i++) {\n        // Process 2 k_i per iteration\n        for (int k_i = 0; k_i < 16; k_i += 2) {\n          // Load A vectors\n          __m512bh ma0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + k_i * 2]);\n          __m512bh ma1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 1) * 2]);\n\n          // Load B matrices\n          __m512i bfp8_0 = bfp8_512[k_i];\n          __m512i bfp8_1 = bfp8_512[k_i + 1];\n\n          // Convert FP8 -> BF16 for all\n      
    __m512i b_hi_0 = _mm512_permutex2var_epi8(bf16_hi_0_val, bfp8_0, bf16_hi_1_val);\n          __m512i b_lo_0 = _mm512_permutex2var_epi8(bf16_lo_0_val, bfp8_0, bf16_lo_1_val);\n          b_hi_0 = _mm512_or_si512(_mm512_and_si512(sign_mask_val, bfp8_0), b_hi_0);\n\n          __m512i b_hi_1 = _mm512_permutex2var_epi8(bf16_hi_0_val, bfp8_1, bf16_hi_1_val);\n          __m512i b_lo_1 = _mm512_permutex2var_epi8(bf16_lo_0_val, bfp8_1, bf16_lo_1_val);\n          b_hi_1 = _mm512_or_si512(_mm512_and_si512(sign_mask_val, bfp8_1), b_hi_1);\n\n          // Compute dpbf16 for all\n          __m512bh bbf16_0_0 = (__m512bh)_mm512_unpacklo_epi8(b_lo_0, b_hi_0);\n          __m512bh bbf16_1_0 = (__m512bh)_mm512_unpackhi_epi8(b_lo_0, b_hi_0);\n          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma0, bbf16_0_0);\n          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma0, bbf16_1_0);\n\n          __m512bh bbf16_0_1 = (__m512bh)_mm512_unpacklo_epi8(b_lo_1, b_hi_1);\n          __m512bh bbf16_1_1 = (__m512bh)_mm512_unpackhi_epi8(b_lo_1, b_hi_1);\n          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma1, bbf16_0_1);\n          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma1, bbf16_1_1);\n        }\n      }\n    }\n  }\n\n  // Optimized AVX kernel: process 4 k_i at once, convert B once and reuse for all m rows\n  // This version achieved ~493 GB/s - restoring as baseline for further optimization\n  static void avx_kernel_4(int m, int n, int k, int m_begin, int n_begin, int k_group_begin, float* c, BufferA* ba,\n                           BufferB* bb, int k_group_size) {\n    const __m512i bf16_hi_0 = bf16_hi_0_mask();\n    const __m512i bf16_hi_1 = bf16_hi_1_mask();\n    const __m512i bf16_lo_0 = bf16_lo_0_mask();\n    const __m512i bf16_lo_1 = bf16_lo_1_mask();\n    const __m512i sign_mask_v = sign_mask();\n\n    __m512* c512 = (__m512*)c;\n    int m_block_end = std::min(m - m_begin, M_STEP);\n\n    // Zero out accumulator\n    for (int m_i = 0; m_i < 
m_block_end; m_i++) {\n      c512[m_i * 2] = _mm512_setzero_ps();\n      c512[m_i * 2 + 1] = _mm512_setzero_ps();\n    }\n\n    // Process entire k_group_size\n    for (int k_begin = 0; k_begin < k_group_size && k_group_begin + k_begin < k; k_begin += K_STEP) {\n      ggml_bf16_t* abf16 = (ggml_bf16_t*)ba->get_submat(m, k, m_begin, k_group_begin + k_begin);\n      __m512i* bfp8_512 = (__m512i*)bb->get_submat(n, k, n_begin, k_group_begin + k_begin);\n\n      // Process 4 k_i at once - convert B and reuse across all m rows\n      for (int k_i = 0; k_i < 16; k_i += 4) {\n        // Load 4 B vectors\n        __m512i bfp8_0 = bfp8_512[k_i];\n        __m512i bfp8_1 = bfp8_512[k_i + 1];\n        __m512i bfp8_2 = bfp8_512[k_i + 2];\n        __m512i bfp8_3 = bfp8_512[k_i + 3];\n\n        // Convert all 4 FP8 -> BF16\n        __m512i b_hi, b_lo;\n\n        b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask_v, bfp8_0),\n                               _mm512_permutex2var_epi8(bf16_hi_0, bfp8_0, bf16_hi_1));\n        b_lo = _mm512_permutex2var_epi8(bf16_lo_0, bfp8_0, bf16_lo_1);\n        __m512bh bbf16_0_lo = (__m512bh)_mm512_unpacklo_epi8(b_lo, b_hi);\n        __m512bh bbf16_0_hi = (__m512bh)_mm512_unpackhi_epi8(b_lo, b_hi);\n\n        b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask_v, bfp8_1),\n                               _mm512_permutex2var_epi8(bf16_hi_0, bfp8_1, bf16_hi_1));\n        b_lo = _mm512_permutex2var_epi8(bf16_lo_0, bfp8_1, bf16_lo_1);\n        __m512bh bbf16_1_lo = (__m512bh)_mm512_unpacklo_epi8(b_lo, b_hi);\n        __m512bh bbf16_1_hi = (__m512bh)_mm512_unpackhi_epi8(b_lo, b_hi);\n\n        b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask_v, bfp8_2),\n                               _mm512_permutex2var_epi8(bf16_hi_0, bfp8_2, bf16_hi_1));\n        b_lo = _mm512_permutex2var_epi8(bf16_lo_0, bfp8_2, bf16_lo_1);\n        __m512bh bbf16_2_lo = (__m512bh)_mm512_unpacklo_epi8(b_lo, b_hi);\n        __m512bh bbf16_2_hi = (__m512bh)_mm512_unpackhi_epi8(b_lo, 
b_hi);\n\n        b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask_v, bfp8_3),\n                               _mm512_permutex2var_epi8(bf16_hi_0, bfp8_3, bf16_hi_1));\n        b_lo = _mm512_permutex2var_epi8(bf16_lo_0, bfp8_3, bf16_lo_1);\n        __m512bh bbf16_3_lo = (__m512bh)_mm512_unpacklo_epi8(b_lo, b_hi);\n        __m512bh bbf16_3_hi = (__m512bh)_mm512_unpackhi_epi8(b_lo, b_hi);\n\n        // Process m rows - unroll by 2 for better ILP\n        int m_i = 0;\n        for (; m_i + 1 < m_block_end; m_i += 2) {\n          // Load A values for 2 rows\n          __m512bh ma0_0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + k_i * 2]);\n          __m512bh ma1_0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 1) * 2]);\n          __m512bh ma2_0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 2) * 2]);\n          __m512bh ma3_0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 3) * 2]);\n          __m512bh ma0_1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[(m_i + 1) * K_STEP + k_i * 2]);\n          __m512bh ma1_1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[(m_i + 1) * K_STEP + (k_i + 1) * 2]);\n          __m512bh ma2_1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[(m_i + 1) * K_STEP + (k_i + 2) * 2]);\n          __m512bh ma3_1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[(m_i + 1) * K_STEP + (k_i + 3) * 2]);\n\n          // Process row 0, then row 1 - sequential to avoid dependencies\n          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma0_0, bbf16_0_lo);\n          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma0_0, bbf16_0_hi);\n          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma1_0, bbf16_1_lo);\n          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma1_0, bbf16_1_hi);\n          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma2_0, bbf16_2_lo);\n          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma2_0, bbf16_2_hi);\n   
       c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma3_0, bbf16_3_lo);\n          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma3_0, bbf16_3_hi);\n\n          c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma0_1, bbf16_0_lo);\n          c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma0_1, bbf16_0_hi);\n          c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma1_1, bbf16_1_lo);\n          c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma1_1, bbf16_1_hi);\n          c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma2_1, bbf16_2_lo);\n          c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma2_1, bbf16_2_hi);\n          c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma3_1, bbf16_3_lo);\n          c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma3_1, bbf16_3_hi);\n        }\n        // Handle remaining row\n        for (; m_i < m_block_end; m_i++) {\n          __m512bh ma0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + k_i * 2]);\n          __m512bh ma1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 1) * 2]);\n          __m512bh ma2 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 2) * 2]);\n          __m512bh ma3 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 3) * 2]);\n\n          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma0, bbf16_0_lo);\n          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma0, bbf16_0_hi);\n          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma1, bbf16_1_lo);\n          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma1, bbf16_1_hi);\n          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma2, bbf16_2_lo);\n          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma2, bbf16_2_hi);\n          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma3, bbf16_3_lo);\n          
c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma3, bbf16_3_hi);\n        }\n      }\n    }\n  }\n\n  static void apply_scale_kgroup(int m, int n, int m_begin, int n_begin, int k_block_begin, float* c, float* reduce_c,\n                                 BufferA* ba, BufferB* bb, int k, int k_group_size) {\n    using K = GemmKernel224FP8;\n    int to = std::min(m - m_begin, K::M_STEP);\n\n    for (int i = 0; i < to; i++) {\n      // Get scale for this k_group\n      __m512 bs = _mm512_set1_ps(*bb->get_scale(n, n_begin, k, k_block_begin));\n      __m512 now = _mm512_load_ps(reduce_c + i * K::N_STEP);\n      __m512 result = _mm512_mul_ps(now, bs);\n      __m512 existing = _mm512_load_ps(c + i * K::N_STEP);\n      result = _mm512_add_ps(result, existing);\n      _mm512_store_ps(c + i * K::N_STEP, result);\n\n      now = _mm512_load_ps(reduce_c + i * K::N_STEP + K::TILE_N);\n      result = _mm512_mul_ps(now, bs);\n      existing = _mm512_load_ps(c + i * K::N_STEP + K::TILE_N);\n      result = _mm512_add_ps(result, existing);\n      _mm512_store_ps(c + i * K::N_STEP + K::TILE_N, result);\n    }\n  }\n};\n\n// M_STEP, N_STEP and K_STEP are all 32 for GemmKernel224FP8\ntemplate <typename K, bool amx_or_avx = false>\nvoid float_mat_vec_kgroup(int m, int n, int k, int k_group_size, typename K::BufferA* ba, typename K::BufferB* bb,\n                          typename K::BufferC* bc, int ith, int nth) {\n  assert(n % K::N_STEP == 0);\n  assert(k % k_group_size == 0);\n  assert(k_group_size % K::K_STEP == 0);\n\n  auto [n_start, n_end] = K::split_range_n(n, ith, nth);\n\n  // Process by k_groups\n  for (int k_group_begin = 0; k_group_begin < k; k_group_begin += k_group_size) {\n    for (int m_begin = 0; m_begin < m; m_begin += K::M_STEP) {\n      for (int n_begin = n_start; n_begin < n_end; n_begin += K::N_STEP) {\n        float* c = bc->get_submat(m, n, m_begin, n_begin);\n        float* reduce_c = bc->get_reduce_submat(m, n, m_begin, n_begin);\n\n        if (k_group_begin == 0) {\n          for (int i = 0; i < 
K::M_STEP && m_begin + i < m; i++) {\n            for (int j = 0; j < K::N_STEP; j++) {\n              c[i * K::N_STEP + j] = 0.0f;\n            }\n          }\n        }\n\n        // K::avx_kernel processes the entire k_group_size internally (like INT8's avx_kernel)\n        if constexpr (amx_or_avx && AMX_AVAILABLE) {\n          for (int k_begin = k_group_begin; k_begin < std::min(k, k_group_begin + k_group_size); k_begin += K::K_STEP) {\n            K::amx_kernel(m, n, k, m_begin, n_begin, k_begin, reduce_c, ba, bb, k_group_size);\n          }\n        } else {\n          // Single call processes entire k_group\n          K::avx_kernel(m, n, k, m_begin, n_begin, k_group_begin, reduce_c, ba, bb, k_group_size);\n        }\n        K::apply_scale_kgroup(m, n, m_begin, n_begin, k_group_begin, c, reduce_c, ba, bb, k, k_group_size);\n      }\n    }\n  }\n}\n\n// ============================================================================\n// GemmKernel224BF16 vec_mul/mat_mul\n// ============================================================================\n\n// Template function for BF16 mat_mul/vec_mul with AMX or AVX backend\ntemplate <typename K, bool amx_or_avx = true>\nvoid float_mat_vec(int m, int n, int k, typename K::BufferA* ba, typename K::BufferB* bb, typename K::BufferC* bc,\n                   int ith, int nth) {\n  assert(n % K::N_STEP == 0);\n  assert(k % K::K_STEP == 0);\n\n  auto [n_start, n_end] = K::split_range_n(n, ith, nth);\n\n  for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K::K_BLOCK) {\n    for (int m_begin = 0; m_begin < m; m_begin += K::M_STEP) {\n      for (int n_begin = n_start; n_begin < n_end; n_begin += K::N_STEP) {\n        float* c = bc->get_submat(m, n, m_begin, n_begin);\n\n        if constexpr (amx_or_avx && AMX_AVAILABLE) {\n          K::amx_kernel(m, n, k, m_begin, n_begin, k_block_begin, c, ba, bb);\n        } else {\n          K::avx_kernel_4(m, n, k, m_begin, n_begin, k_block_begin, c, ba, bb);\n        }\n     
 }\n    }\n  }\n}\n\ninline void mat_mul(int m, int n, int k, std::shared_ptr<GemmKernel224BF16::BufferA> ba,\n                    std::shared_ptr<GemmKernel224BF16::BufferB> bb, std::shared_ptr<GemmKernel224BF16::BufferC> bc,\n                    int ith, int nth) {\n  float_mat_vec<GemmKernel224BF16, true>(m, n, k, ba.get(), bb.get(), bc.get(), ith, nth);\n}\n\ninline void vec_mul(int m, int n, int k, std::shared_ptr<GemmKernel224BF16::BufferA> ba,\n                    std::shared_ptr<GemmKernel224BF16::BufferB> bb, std::shared_ptr<GemmKernel224BF16::BufferC> bc,\n                    int ith, int nth) {\n  float_mat_vec<GemmKernel224BF16, false>(m, n, k, ba.get(), bb.get(), bc.get(), ith, nth);\n}\n\ninline void vec_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_ptr<GemmKernel224FP8::BufferA> ba,\n                           std::shared_ptr<GemmKernel224FP8::BufferB> bb, std::shared_ptr<GemmKernel224FP8::BufferC> bc,\n                           int ith, int nth) {\n  float_mat_vec_kgroup<GemmKernel224FP8, false>(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth);\n}\n\ninline void mat_mul_kgroup(int m, int n, int k, int k_group_size, std::shared_ptr<GemmKernel224FP8::BufferA> ba,\n                           std::shared_ptr<GemmKernel224FP8::BufferB> bb, std::shared_ptr<GemmKernel224FP8::BufferC> bc,\n                           int ith, int nth) {\n  float_mat_vec_kgroup<GemmKernel224FP8, false>(m, n, k, k_group_size, ba.get(), bb.get(), bc.get(), ith, nth);\n}\n\n// ============================================================================\n// Per-Channel FP8 GEMM (for GLM-4.7-FP8 style quantization)\n// ============================================================================\n\n/**\n * @brief FP8 Per-Channel Kernel\n *\n * Similar to GemmKernel224FP8 but with per-channel scaling instead of block-wise scaling.\n * - Block-wise: scale shape = [n/128, k/128], one scale per 128x128 block\n * - Per-channel: scale shape = [n], one scale 
per output channel (a row of the weight matrix, i.e. a column of the result)\n */\nstruct GemmKernel224FP8PerChannel {\n  using fp8_t = uint8_t;\n  using output_t = float;\n\n  static constexpr double ELEMENT_SIZE = 1.0;\n  static const int TILE_M = 16;\n  static const int TILE_K = 32;\n  static const int TILE_N = 16;\n  static const int VNNI_BLK = 2;\n\n  static const int M_STEP = TILE_M * 2;\n  static const int N_STEP = TILE_N * 2;\n  static const int K_STEP = TILE_K;\n\n  // Use smaller N_BLOCK for per-channel to allow efficient scale application\n  static inline const int N_BLOCK = 128;\n  static inline const int K_BLOCK = 7168;\n\n  static std::string name() { return \"FP8PerChannel\"; }\n\n  static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }\n\n  static std::pair<int, int> split_range_n(int n, int ith, int nth) {\n    int n_start = N_BLOCK * ith;\n    int n_end = std::min(n, N_BLOCK * (ith + 1));\n    return {n_start, n_end};\n  }\n\n  static void config() {}\n\n  using BufferA = BufferABF16Impl<GemmKernel224FP8PerChannel>;\n  using BufferB = BufferBFP8PerChannelImpl<GemmKernel224FP8PerChannel>;\n  using BufferC = BufferCFP32Impl<GemmKernel224FP8PerChannel>;\n\n  // Reuse FP8->BF16 conversion from GemmKernel224FP8\n  static inline std::pair<__m512i, __m512i> fp8x64_to_bf16x64(__m512i bfp8_512) {\n    return GemmKernel224FP8::fp8x64_to_bf16x64(bfp8_512);\n  }\n\n  /**\n   * @brief Apply per-channel scale to result\n   *\n   * Unlike block-wise scaling, per-channel scaling applies a different scale to each column\n   * of the result (each output channel).\n   *\n   * @param m Total rows\n   * @param n Total columns\n   * @param m_begin Starting row\n   * @param n_begin Starting column\n   * @param c Output buffer (M_STEP x N_STEP)\n   * @param bb BufferB containing per-channel scales\n   */\n  static void apply_scale_perchannel(int m, [[maybe_unused]] int n, int m_begin, int n_begin, float* c, BufferB* bb) {\n    int to = std::min(m - m_begin, M_STEP);\n\n    // Load N_STEP per-channel scales (32 
floats)\n    __m512 bs_lo = _mm512_loadu_ps(bb->get_scale(n_begin));           // scale[n_begin..n_begin+15]\n    __m512 bs_hi = _mm512_loadu_ps(bb->get_scale(n_begin + TILE_N));  // scale[n_begin+16..n_begin+31]\n\n    for (int i = 0; i < to; i++) {\n      // Each row gets multiplied by the same set of per-channel scales\n      __m512 c_lo = _mm512_load_ps(c + i * N_STEP);\n      __m512 c_hi = _mm512_load_ps(c + i * N_STEP + TILE_N);\n      _mm512_store_ps(c + i * N_STEP, _mm512_mul_ps(c_lo, bs_lo));\n      _mm512_store_ps(c + i * N_STEP + TILE_N, _mm512_mul_ps(c_hi, bs_hi));\n    }\n  }\n\n  // AVX kernel for per-channel FP8 GEMM - processes entire K dimension\n  static void avx_kernel_4(int m, int n, int k, int m_begin, int n_begin, int k_block_begin, float* c, BufferA* ba,\n                           BufferB* bb) {\n    const __m512i bf16_hi_0 = GemmKernel224FP8::bf16_hi_0_mask();\n    const __m512i bf16_hi_1 = GemmKernel224FP8::bf16_hi_1_mask();\n    const __m512i bf16_lo_0 = GemmKernel224FP8::bf16_lo_0_mask();\n    const __m512i bf16_lo_1 = GemmKernel224FP8::bf16_lo_1_mask();\n    const __m512i sign_mask_v = GemmKernel224FP8::sign_mask();\n\n    __m512* c512 = (__m512*)c;\n    int m_block_end = std::min(m - m_begin, M_STEP);\n\n    // Zero out accumulator at start of K_BLOCK\n    if (k_block_begin == 0) {\n      for (int m_i = 0; m_i < m_block_end; m_i++) {\n        c512[m_i * 2] = _mm512_setzero_ps();\n        c512[m_i * 2 + 1] = _mm512_setzero_ps();\n      }\n    }\n\n    // Process K_BLOCK\n    for (int k_begin = 0; k_begin < K_BLOCK && k_block_begin + k_begin < k; k_begin += K_STEP) {\n      ggml_bf16_t* abf16 = (ggml_bf16_t*)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);\n      __m512i* bfp8_512 = (__m512i*)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);\n\n      // Process 4 k_i at once\n      for (int k_i = 0; k_i < 16; k_i += 4) {\n        // Load 4 B vectors\n        __m512i bfp8_0 = bfp8_512[k_i];\n        __m512i bfp8_1 = 
bfp8_512[k_i + 1];\n        __m512i bfp8_2 = bfp8_512[k_i + 2];\n        __m512i bfp8_3 = bfp8_512[k_i + 3];\n\n        // Convert all 4 FP8 -> BF16\n        __m512i b_hi, b_lo;\n\n        b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask_v, bfp8_0),\n                               _mm512_permutex2var_epi8(bf16_hi_0, bfp8_0, bf16_hi_1));\n        b_lo = _mm512_permutex2var_epi8(bf16_lo_0, bfp8_0, bf16_lo_1);\n        __m512bh bbf16_0_lo = (__m512bh)_mm512_unpacklo_epi8(b_lo, b_hi);\n        __m512bh bbf16_0_hi = (__m512bh)_mm512_unpackhi_epi8(b_lo, b_hi);\n\n        b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask_v, bfp8_1),\n                               _mm512_permutex2var_epi8(bf16_hi_0, bfp8_1, bf16_hi_1));\n        b_lo = _mm512_permutex2var_epi8(bf16_lo_0, bfp8_1, bf16_lo_1);\n        __m512bh bbf16_1_lo = (__m512bh)_mm512_unpacklo_epi8(b_lo, b_hi);\n        __m512bh bbf16_1_hi = (__m512bh)_mm512_unpackhi_epi8(b_lo, b_hi);\n\n        b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask_v, bfp8_2),\n                               _mm512_permutex2var_epi8(bf16_hi_0, bfp8_2, bf16_hi_1));\n        b_lo = _mm512_permutex2var_epi8(bf16_lo_0, bfp8_2, bf16_lo_1);\n        __m512bh bbf16_2_lo = (__m512bh)_mm512_unpacklo_epi8(b_lo, b_hi);\n        __m512bh bbf16_2_hi = (__m512bh)_mm512_unpackhi_epi8(b_lo, b_hi);\n\n        b_hi = _mm512_or_si512(_mm512_and_si512(sign_mask_v, bfp8_3),\n                               _mm512_permutex2var_epi8(bf16_hi_0, bfp8_3, bf16_hi_1));\n        b_lo = _mm512_permutex2var_epi8(bf16_lo_0, bfp8_3, bf16_lo_1);\n        __m512bh bbf16_3_lo = (__m512bh)_mm512_unpacklo_epi8(b_lo, b_hi);\n        __m512bh bbf16_3_hi = (__m512bh)_mm512_unpackhi_epi8(b_lo, b_hi);\n\n        // Process m rows\n        int m_i = 0;\n        for (; m_i + 1 < m_block_end; m_i += 2) {\n          __m512bh ma0_0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + k_i * 2]);\n          __m512bh ma1_0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * 
K_STEP + (k_i + 1) * 2]);\n          __m512bh ma2_0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 2) * 2]);\n          __m512bh ma3_0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 3) * 2]);\n          __m512bh ma0_1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[(m_i + 1) * K_STEP + k_i * 2]);\n          __m512bh ma1_1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[(m_i + 1) * K_STEP + (k_i + 1) * 2]);\n          __m512bh ma2_1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[(m_i + 1) * K_STEP + (k_i + 2) * 2]);\n          __m512bh ma3_1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[(m_i + 1) * K_STEP + (k_i + 3) * 2]);\n\n          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma0_0, bbf16_0_lo);\n          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma0_0, bbf16_0_hi);\n          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma1_0, bbf16_1_lo);\n          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma1_0, bbf16_1_hi);\n          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma2_0, bbf16_2_lo);\n          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma2_0, bbf16_2_hi);\n          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma3_0, bbf16_3_lo);\n          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma3_0, bbf16_3_hi);\n\n          c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma0_1, bbf16_0_lo);\n          c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma0_1, bbf16_0_hi);\n          c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma1_1, bbf16_1_lo);\n          c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma1_1, bbf16_1_hi);\n          c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma2_1, bbf16_2_lo);\n          c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma2_1, bbf16_2_hi);\n          c512[(m_i + 1) * 2] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2], ma3_1, 
bbf16_3_lo);\n          c512[(m_i + 1) * 2 + 1] = _mm512_dpbf16_ps(c512[(m_i + 1) * 2 + 1], ma3_1, bbf16_3_hi);\n        }\n        // Handle remaining row\n        for (; m_i < m_block_end; m_i++) {\n          __m512bh ma0 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + k_i * 2]);\n          __m512bh ma1 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 1) * 2]);\n          __m512bh ma2 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 2) * 2]);\n          __m512bh ma3 = (__m512bh)_mm512_set1_epi32(*(int32_t*)&abf16[m_i * K_STEP + (k_i + 3) * 2]);\n\n          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma0, bbf16_0_lo);\n          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma0, bbf16_0_hi);\n          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma1, bbf16_1_lo);\n          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma1, bbf16_1_hi);\n          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma2, bbf16_2_lo);\n          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma2, bbf16_2_hi);\n          c512[m_i * 2] = _mm512_dpbf16_ps(c512[m_i * 2], ma3, bbf16_3_lo);\n          c512[m_i * 2 + 1] = _mm512_dpbf16_ps(c512[m_i * 2 + 1], ma3, bbf16_3_hi);\n        }\n      }\n    }\n  }\n};\n\n/**\n * @brief Per-channel FP8 GEMM function\n *\n * Unlike block-wise FP8 which applies scale per 128x128 block during computation,\n * per-channel FP8 processes entire K dimension first, then applies per-channel scale at the end.\n */\ntemplate <typename K>\nvoid float_mat_vec_perchannel(int m, int n, int k, typename K::BufferA* ba, typename K::BufferB* bb,\n                              typename K::BufferC* bc, int ith, int nth) {\n  assert(n % K::N_STEP == 0);\n  assert(k % K::K_STEP == 0);\n\n  auto [n_start, n_end] = K::split_range_n(n, ith, nth);\n\n  for (int m_begin = 0; m_begin < m; m_begin += K::M_STEP) {\n    for (int n_begin = n_start; n_begin < n_end; n_begin += K::N_STEP) {\n  
    float* c = bc->get_submat(m, n, m_begin, n_begin);\n\n      // Process entire K dimension with K_BLOCKs\n      for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K::K_BLOCK) {\n        K::avx_kernel_4(m, n, k, m_begin, n_begin, k_block_begin, c, ba, bb);\n      }\n\n      // Apply per-channel scale once after all K is processed\n      K::apply_scale_perchannel(m, n, m_begin, n_begin, c, bb);\n    }\n  }\n}\n\ninline void vec_mul_perchannel(int m, int n, int k, std::shared_ptr<GemmKernel224FP8PerChannel::BufferA> ba,\n                               std::shared_ptr<GemmKernel224FP8PerChannel::BufferB> bb,\n                               std::shared_ptr<GemmKernel224FP8PerChannel::BufferC> bc, int ith, int nth) {\n  float_mat_vec_perchannel<GemmKernel224FP8PerChannel>(m, n, k, ba.get(), bb.get(), bc.get(), ith, nth);\n}\n\n}  // namespace amx\n\n#endif  // AMX_RAW_KERNELS_HPP\n"
  },
  {
    "path": "kt-kernel/operators/amx/la/amx_utils.hpp",
    "content": "#ifndef AMX_UTILS_HPP\n#define AMX_UTILS_HPP\n\n#include <cstdio>\n#include <iostream>\n\n#include \"../../common.hpp\"\n#include \"amx_config.hpp\"\n\nnamespace amx {\n#if defined(HAVE_AMX)\n// Debug functions\ninline void debug_tile(int t) {\n  printf(\"Tile %d\\n\", t);\n  int8_t data[16][64] = {};\n  TileConfig::store_data(t, data, 64);\n  for (int i = 0; i < 16; i++) {\n    for (int j = 0; j < 64; j++) {\n      printf(\"%4d \", data[i][j]);\n    }\n    printf(\"\\n\");\n  }\n  printf(\"\\n\");\n}\n\ninline void debug_tile_int32(int t) {\n  printf(\"Tile %d\\n\", t);\n  int32_t data[16][16] = {};\n  TileConfig::store_data(t, data, 64);\n  for (int i = 0; i < 16; i++) {\n    for (int j = 0; j < 16; j++) {\n      printf(\"%10d \", data[i][j]);\n    }\n    printf(\"\\n\");\n  }\n  printf(\"\\n\");\n}\n\ninline void debug_tiles(int to = 8) {\n  for (int i = 0; i < to; i++) {\n    debug_tile(i);\n  }\n}\n\ninline void debug_tiles_int32(int to = 8) {\n  for (int i = 0; i < to; i++) {\n    debug_tile_int32(i);\n  }\n}\n\ninline void debug_tiles_224() {\n  for (int i = 0; i < 4; i++) {\n    debug_tile(i);\n  }\n  for (int i = 4; i < 8; i++) {\n    debug_tile_int32(i);\n  }\n}\n\ninline void debug_m512(__m512 x) {\n  float data[16];\n  _mm512_storeu_ps(data, x);\n  for (int i = 0; i < 16; i++) {\n    printf(\"%f \", data[i]);\n  }\n  printf(\"\\n\");\n}\n\ninline void debug_m512i(__m512i x) {\n  int32_t data[16];\n  _mm512_storeu_epi32(data, x);\n  for (int i = 0; i < 16; i++) {\n    printf(\"0x%08x \", data[i]);\n  }\n  printf(\"\\n\");\n}\n\ninline void debug_m128i(__m128i x) {\n  int32_t data[16];\n  _mm_storeu_epi32(data, x);\n  for (int i = 0; i < 4; i++) {\n    printf(\"0x%08x \", data[i]);\n  }\n  printf(\"\\n\");\n}\n#endif\n// transpose utils\n#define SHUFFLE_EPI32(a, b, mask) \\\n  _mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), mask))\n\ninline void transpose_8x8_32bit(__m256i* v, __m256i* v1) {\n  // 
unpacking and 32-bit elements\n  v1[0] = _mm256_unpacklo_epi32(v[0], v[1]);\n  v1[1] = _mm256_unpackhi_epi32(v[0], v[1]);\n  v1[2] = _mm256_unpacklo_epi32(v[2], v[3]);\n  v1[3] = _mm256_unpackhi_epi32(v[2], v[3]);\n  v1[4] = _mm256_unpacklo_epi32(v[4], v[5]);\n  v1[5] = _mm256_unpackhi_epi32(v[4], v[5]);\n  v1[6] = _mm256_unpacklo_epi32(v[6], v[7]);\n  v1[7] = _mm256_unpackhi_epi32(v[6], v[7]);\n\n  // shuffling the 32-bit elements\n  v[0] = SHUFFLE_EPI32(v1[0], v1[2], 0x44);\n  v[1] = SHUFFLE_EPI32(v1[0], v1[2], 0xee);\n  v[2] = SHUFFLE_EPI32(v1[4], v1[6], 0x44);\n  v[3] = SHUFFLE_EPI32(v1[4], v1[6], 0xee);\n  v[4] = SHUFFLE_EPI32(v1[1], v1[3], 0x44);\n  v[5] = SHUFFLE_EPI32(v1[1], v1[3], 0xee);\n  v[6] = SHUFFLE_EPI32(v1[5], v1[7], 0x44);\n  v[7] = SHUFFLE_EPI32(v1[5], v1[7], 0xee);\n\n  // shuffling 128-bit elements\n  v1[0] = _mm256_permute2f128_si256(v[2], v[0], 0x02);\n  v1[1] = _mm256_permute2f128_si256(v[3], v[1], 0x02);\n  v1[2] = _mm256_permute2f128_si256(v[6], v[4], 0x02);\n  v1[3] = _mm256_permute2f128_si256(v[7], v[5], 0x02);\n  v1[4] = _mm256_permute2f128_si256(v[2], v[0], 0x13);\n  v1[5] = _mm256_permute2f128_si256(v[3], v[1], 0x13);\n  v1[6] = _mm256_permute2f128_si256(v[6], v[4], 0x13);\n  v1[7] = _mm256_permute2f128_si256(v[7], v[5], 0x13);\n}\n\ninline void transpose_8x8_32bit(__m256i* v) {\n  __m256i v1[8];\n  transpose_8x8_32bit(v, v1);\n\n  v[0] = v1[0];\n  v[1] = v1[1];\n  v[2] = v1[2];\n  v[3] = v1[3];\n  v[4] = v1[4];\n  v[5] = v1[5];\n  v[6] = v1[6];\n  v[7] = v1[7];\n}\n\ninline void transpose_16x4_32bit(__m512i* r, __m512i* d) {\n  static const __m512i index1 =\n      _mm512_set_epi32(0x0f, 0x0b, 0x07, 0x03, 0x0e, 0x0a, 0x06, 0x02, 0x0d, 0x09, 0x05, 0x01, 0x0c, 0x08, 0x04, 0x00);\n\n  d[0] = _mm512_permutexvar_epi32(index1, r[0]);\n  d[1] = _mm512_permutexvar_epi32(index1, r[1]);\n  d[2] = _mm512_permutexvar_epi32(index1, r[2]);\n  d[3] = _mm512_permutexvar_epi32(index1, r[3]);\n\n  r[0] = _mm512_shuffle_i32x4(d[0], d[1], 0x44);\n  r[1] 
= _mm512_shuffle_i32x4(d[0], d[1], 0xee);\n  r[2] = _mm512_shuffle_i32x4(d[2], d[3], 0x44);\n  r[3] = _mm512_shuffle_i32x4(d[2], d[3], 0xee);\n\n  d[0] = _mm512_shuffle_i32x4(r[0], r[2], 0x88);\n  d[1] = _mm512_shuffle_i32x4(r[0], r[2], 0xdd);\n  d[2] = _mm512_shuffle_i32x4(r[1], r[3], 0x88);\n  d[3] = _mm512_shuffle_i32x4(r[1], r[3], 0xdd);\n}\n\ninline void transpose_16x16_32bit(__m512i* v) {\n  __m512i v1[16];\n  v1[0] = _mm512_unpacklo_epi32(v[0], v[1]);\n  v1[1] = _mm512_unpackhi_epi32(v[0], v[1]);\n  v1[2] = _mm512_unpacklo_epi32(v[2], v[3]);\n  v1[3] = _mm512_unpackhi_epi32(v[2], v[3]);\n  v1[4] = _mm512_unpacklo_epi32(v[4], v[5]);\n  v1[5] = _mm512_unpackhi_epi32(v[4], v[5]);\n  v1[6] = _mm512_unpacklo_epi32(v[6], v[7]);\n  v1[7] = _mm512_unpackhi_epi32(v[6], v[7]);\n  v1[8] = _mm512_unpacklo_epi32(v[8], v[9]);\n  v1[9] = _mm512_unpackhi_epi32(v[8], v[9]);\n  v1[10] = _mm512_unpacklo_epi32(v[10], v[11]);\n  v1[11] = _mm512_unpackhi_epi32(v[10], v[11]);\n  v1[12] = _mm512_unpacklo_epi32(v[12], v[13]);\n  v1[13] = _mm512_unpackhi_epi32(v[12], v[13]);\n  v1[14] = _mm512_unpacklo_epi32(v[14], v[15]);\n  v1[15] = _mm512_unpackhi_epi32(v[14], v[15]);\n\n  v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]);\n  v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]);\n  v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]);\n  v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]);\n  v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]);\n  v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]);\n  v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]);\n  v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]);\n  v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]);\n  v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]);\n  v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]);\n  v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]);\n  v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]);\n  v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]);\n  v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]);\n  v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]);\n\n  v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88);\n  
v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88);\n  v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88);\n  v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88);\n  v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd);\n  v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd);\n  v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd);\n  v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd);\n  v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88);\n  v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88);\n  v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88);\n  v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88);\n  v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd);\n  v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd);\n  v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd);\n  v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd);\n\n  v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88);\n  v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88);\n  v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88);\n  v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88);\n  v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88);\n  v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88);\n  v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88);\n  v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88);\n  v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd);\n  v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd);\n  v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd);\n  v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd);\n  v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd);\n  v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd);\n  v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd);\n  v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd);\n}\n\ninline void transpose_16x8_32bit(__m256i* v) {\n  transpose_8x8_32bit(v);\n  transpose_8x8_32bit(v + 8);\n  __m256i v1[16];\n  for (int i = 0; i < 16; i++) v1[i] = v[i];\n\n  for (int i = 0; i < 8; i++) {\n    v[i * 2] = v1[i];\n    v[i * 2 + 1] = v1[8 + i];\n  }\n}\n\n/*\n  Transpose 16x16 32-bit elements\n  Note that v must be 64 byte aligned\n*/\ninline 
void transpose_16x16_32bit(__m512i* v, size_t stride) {\n  assert(reinterpret_cast<intptr_t>(v) % 64 == 0 && \"v must be 64 aligned\");\n\n  auto stride_v = [=](int i) { return offset_pointer(v, i * stride); };\n  __m512i v1[16];\n\n  v1[0] = _mm512_unpacklo_epi32(*stride_v(0), *stride_v(1));\n  v1[1] = _mm512_unpackhi_epi32(*stride_v(0), *stride_v(1));\n  v1[2] = _mm512_unpacklo_epi32(*stride_v(2), *stride_v(3));\n  v1[3] = _mm512_unpackhi_epi32(*stride_v(2), *stride_v(3));\n  v1[4] = _mm512_unpacklo_epi32(*stride_v(4), *stride_v(5));\n  v1[5] = _mm512_unpackhi_epi32(*stride_v(4), *stride_v(5));\n  v1[6] = _mm512_unpacklo_epi32(*stride_v(6), *stride_v(7));\n  v1[7] = _mm512_unpackhi_epi32(*stride_v(6), *stride_v(7));\n  v1[8] = _mm512_unpacklo_epi32(*stride_v(8), *stride_v(9));\n  v1[9] = _mm512_unpackhi_epi32(*stride_v(8), *stride_v(9));\n  v1[10] = _mm512_unpacklo_epi32(*stride_v(10), *stride_v(11));\n  v1[11] = _mm512_unpackhi_epi32(*stride_v(10), *stride_v(11));\n  v1[12] = _mm512_unpacklo_epi32(*stride_v(12), *stride_v(13));\n  v1[13] = _mm512_unpackhi_epi32(*stride_v(12), *stride_v(13));\n  v1[14] = _mm512_unpacklo_epi32(*stride_v(14), *stride_v(15));\n  v1[15] = _mm512_unpackhi_epi32(*stride_v(14), *stride_v(15));\n\n  *stride_v(0) = _mm512_unpacklo_epi64(v1[0], v1[2]);\n  *stride_v(1) = _mm512_unpackhi_epi64(v1[0], v1[2]);\n  *stride_v(2) = _mm512_unpacklo_epi64(v1[1], v1[3]);\n  *stride_v(3) = _mm512_unpackhi_epi64(v1[1], v1[3]);\n  *stride_v(4) = _mm512_unpacklo_epi64(v1[4], v1[6]);\n  *stride_v(5) = _mm512_unpackhi_epi64(v1[4], v1[6]);\n  *stride_v(6) = _mm512_unpacklo_epi64(v1[5], v1[7]);\n  *stride_v(7) = _mm512_unpackhi_epi64(v1[5], v1[7]);\n  *stride_v(8) = _mm512_unpacklo_epi64(v1[8], v1[10]);\n  *stride_v(9) = _mm512_unpackhi_epi64(v1[8], v1[10]);\n  *stride_v(10) = _mm512_unpacklo_epi64(v1[9], v1[11]);\n  *stride_v(11) = _mm512_unpackhi_epi64(v1[9], v1[11]);\n  *stride_v(12) = _mm512_unpacklo_epi64(v1[12], v1[14]);\n  *stride_v(13) = 
_mm512_unpackhi_epi64(v1[12], v1[14]);\n  *stride_v(14) = _mm512_unpacklo_epi64(v1[13], v1[15]);\n  *stride_v(15) = _mm512_unpackhi_epi64(v1[13], v1[15]);\n\n  v1[0] = _mm512_shuffle_i32x4(*stride_v(0), *stride_v(4), 0x88);\n  v1[1] = _mm512_shuffle_i32x4(*stride_v(1), *stride_v(5), 0x88);\n  v1[2] = _mm512_shuffle_i32x4(*stride_v(2), *stride_v(6), 0x88);\n  v1[3] = _mm512_shuffle_i32x4(*stride_v(3), *stride_v(7), 0x88);\n  v1[4] = _mm512_shuffle_i32x4(*stride_v(0), *stride_v(4), 0xdd);\n  v1[5] = _mm512_shuffle_i32x4(*stride_v(1), *stride_v(5), 0xdd);\n  v1[6] = _mm512_shuffle_i32x4(*stride_v(2), *stride_v(6), 0xdd);\n  v1[7] = _mm512_shuffle_i32x4(*stride_v(3), *stride_v(7), 0xdd);\n  v1[8] = _mm512_shuffle_i32x4(*stride_v(8), *stride_v(12), 0x88);\n  v1[9] = _mm512_shuffle_i32x4(*stride_v(9), *stride_v(13), 0x88);\n  v1[10] = _mm512_shuffle_i32x4(*stride_v(10), *stride_v(14), 0x88);\n  v1[11] = _mm512_shuffle_i32x4(*stride_v(11), *stride_v(15), 0x88);\n  v1[12] = _mm512_shuffle_i32x4(*stride_v(8), *stride_v(12), 0xdd);\n  v1[13] = _mm512_shuffle_i32x4(*stride_v(9), *stride_v(13), 0xdd);\n  v1[14] = _mm512_shuffle_i32x4(*stride_v(10), *stride_v(14), 0xdd);\n  v1[15] = _mm512_shuffle_i32x4(*stride_v(11), *stride_v(15), 0xdd);\n\n  *stride_v(0) = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88);\n  *stride_v(1) = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88);\n  *stride_v(2) = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88);\n  *stride_v(3) = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88);\n  *stride_v(4) = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88);\n  *stride_v(5) = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88);\n  *stride_v(6) = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88);\n  *stride_v(7) = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88);\n  *stride_v(8) = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd);\n  *stride_v(9) = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd);\n  *stride_v(10) = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd);\n  *stride_v(11) = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd);\n  *stride_v(12) = 
_mm512_shuffle_i32x4(v1[4], v1[12], 0xdd);\n  *stride_v(13) = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd);\n  *stride_v(14) = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd);\n  *stride_v(15) = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd);\n}\n\n}  // namespace amx\n\n#endif  // AMX_UTILS_HPP"
  },
  {
    "path": "kt-kernel/operators/amx/la/pack.hpp",
    "content": "#ifndef PACK_HPP\n#define PACK_HPP\n\n#pragma once\n#include <cassert>\n#include <cstddef>\n#include <iostream>\n#include <numeric>\n#include <stdexcept>\n#include <string>\n#include <utility>\n#include <vector>\n\nclass Packed2DLayout {\n public:\n  using index_t = std::size_t;\n\n  struct Dim {\n    index_t size;  // > 0\n    char dir;      // 'r' or 'c'\n  };\n\n  // Constructor: dims must be listed from the lowest dimension to the highest\n  explicit Packed2DLayout(std::vector<Dim> dims) : dims_(std::move(dims)) {\n    if (dims_.empty()) throw std::invalid_argument(\"dims must not be empty\");\n    rows_ = 1;\n    cols_ = 1;\n\n    // Precompute row/column strides (mixed-radix weights)\n    r_stride_for_dim_.assign(dims_.size(), 0);\n    c_stride_for_dim_.assign(dims_.size(), 0);\n\n    index_t r_stride = 1, c_stride = 1;\n    for (index_t i = 0; i < dims_.size(); ++i) {\n      const auto& d = dims_[i];\n      if (d.size == 0) throw std::invalid_argument(\"dim size must be > 0\");\n      if (d.dir == 'r') {\n        r_stride_for_dim_[i] = r_stride;\n        r_stride *= d.size;\n        rows_ *= d.size;\n      } else if (d.dir == 'c') {\n        c_stride_for_dim_[i] = c_stride;\n        c_stride *= d.size;\n        cols_ *= d.size;\n      } else {\n        throw std::invalid_argument(\"dim dir must be 'r' or 'c'\");\n      }\n    }\n    numel_ = rows_ * cols_;\n  }\n\n  // Basic information\n  index_t dims() const { return static_cast<index_t>(dims_.size()); }\n  index_t rows() const { return rows_; }\n  index_t cols() const { return cols_; }\n  index_t numel() const { return numel_; }\n  const std::vector<Dim>& spec() const { return dims_; }\n  const std::vector<index_t>& r_strides() const { return r_stride_for_dim_; }\n  const std::vector<index_t>& c_strides() const { return c_stride_for_dim_; }\n\n  // ---------- High-dimensional index <-> 2D ----------\n  std::pair<index_t, index_t> hd_to_rc(const std::vector<index_t>& hd_idx) const {\n    check_hd_index(hd_idx);\n    index_t row = 0, col = 0;\n    for (index_t i = 0; i < dims(); ++i) {\n      const auto& 
d = dims_[i];\n      auto v = hd_idx[i];\n      if (v >= d.size) throw std::out_of_range(err_dim(i, v, d.size));\n      if (d.dir == 'r')\n        row += v * r_stride_for_dim_[i];\n      else\n        col += v * c_stride_for_dim_[i];\n    }\n    return {row, col};\n  }\n\n  std::vector<index_t> rc_to_hd(index_t row, index_t col) const {\n    if (row >= rows_ || col >= cols_)\n      throw std::out_of_range(\"rc out of range: (\" + std::to_string(row) + \",\" + std::to_string(col) +\n                              \"), expect rows<\" + std::to_string(rows_) + \", cols<\" + std::to_string(cols_) + \")\");\n    std::vector<index_t> hd_idx(dims(), 0);\n    for (index_t i = 0; i < dims(); ++i) {\n      const auto& d = dims_[i];\n      if (d.dir == 'r') {\n        auto stride = r_stride_for_dim_[i];\n        hd_idx[i] = (row / stride) % d.size;\n      } else {\n        auto stride = c_stride_for_dim_[i];\n        hd_idx[i] = (col / stride) % d.size;\n      }\n    }\n    return hd_idx;\n  }\n\n  // ---------- 2D <-> offset (row-major), with optional custom ld ----------\n  index_t rc_to_offset(index_t row, index_t col, index_t ld = 0) const {\n    if (ld == 0) ld = cols_;\n    if (row >= rows_ || col >= cols_) throw std::out_of_range(\"rc out of range for rc_to_offset\");\n    return row * ld + col;\n  }\n\n  std::pair<index_t, index_t> offset_to_rc(index_t offset, index_t ld = 0) const {\n    if (ld == 0) ld = cols_;\n    index_t row = offset / ld;\n    index_t col = offset % ld;\n    if (row >= rows_ || col >= cols_) throw std::out_of_range(\"offset out of range for given ld\");\n    return {row, col};\n  }\n\n  // ---------- High-dimensional index <-> offset (compose/decompose) ----------\n  index_t hd_to_offset(const std::vector<index_t>& hd_idx, index_t ld = 0) const {\n    auto [r, c] = hd_to_rc(hd_idx);\n    return rc_to_offset(r, c, ld);\n  }\n\n  std::vector<index_t> offset_to_hd(index_t offset, index_t ld = 0) const {\n    auto [r, c] = offset_to_rc(offset, ld);\n    return rc_to_hd(r, c);\n  }\n\n  // ---------- 
Utilities: mixed-radix decomposition/composition over the 'r' or 'c' dims ----------\n  // Given a row coordinate, decompose it into digits for all 'r' dims (lowest dimension first)\n  std::vector<index_t> decompose_row(index_t row) const {\n    if (row >= rows_) throw std::out_of_range(\"row out of range in decompose_row\");\n    std::vector<index_t> res(dims(), 0);\n    for (index_t i = 0; i < dims(); ++i) {\n      if (dims_[i].dir == 'r') {\n        auto stride = r_stride_for_dim_[i];\n        res[i] = (row / stride) % dims_[i].size;\n      }\n    }\n    return res;  // only the positions of 'r' dims hold valid digits\n  }\n  // Given a column coordinate, decompose it into digits for all 'c' dims (lowest dimension first)\n  std::vector<index_t> decompose_col(index_t col) const {\n    if (col >= cols_) throw std::out_of_range(\"col out of range in decompose_col\");\n    std::vector<index_t> res(dims(), 0);\n    for (index_t i = 0; i < dims(); ++i) {\n      if (dims_[i].dir == 'c') {\n        auto stride = c_stride_for_dim_[i];\n        res[i] = (col / stride) % dims_[i].size;\n      }\n    }\n    return res;  // only the positions of 'c' dims hold valid digits\n  }\n  // Compose the row coordinate (reads only the positions of 'r' dims)\n  index_t compose_row(const std::vector<index_t>& digits) const {\n    if (digits.size() != dims()) throw std::invalid_argument(\"digits dim mismatch\");\n    index_t row = 0;\n    for (index_t i = 0; i < dims(); ++i)\n      if (dims_[i].dir == 'r') {\n        if (digits[i] >= dims_[i].size) throw std::out_of_range(err_dim(i, digits[i], dims_[i].size));\n        row += digits[i] * r_stride_for_dim_[i];\n      }\n    return row;\n  }\n  // Compose the column coordinate (reads only the positions of 'c' dims)\n  index_t compose_col(const std::vector<index_t>& digits) const {\n    if (digits.size() != dims()) throw std::invalid_argument(\"digits dim mismatch\");\n    index_t col = 0;\n    for (index_t i = 0; i < dims(); ++i)\n      if (dims_[i].dir == 'c') {\n        if (digits[i] >= dims_[i].size) throw std::out_of_range(err_dim(i, digits[i], dims_[i].size));\n        col += digits[i] * c_stride_for_dim_[i];\n      }\n    return col;\n  }\n\n private:\n  void check_hd_index(const std::vector<index_t>& hd_idx) const {\n    if (hd_idx.size() 
!= dims())\n      throw std::invalid_argument(\"hd index dim mismatch: got \" + std::to_string(hd_idx.size()) + \", expect \" +\n                                  std::to_string(dims()));\n  }\n  static std::string err_dim(index_t i, index_t v, index_t sz) {\n    return \"hd index out of range at dim \" + std::to_string(i) + \": got \" + std::to_string(v) + \", expect < \" +\n           std::to_string(sz);\n  }\n\n  std::vector<Dim> dims_;\n  std::vector<index_t> r_stride_for_dim_;\n  std::vector<index_t> c_stride_for_dim_;\n  index_t rows_{1}, cols_{1}, numel_{0};\n};\n\n// ===== Example and self-test (optional) =====\n// g++ -O2 test.cpp -DPACKED2D_DEMO && ./a.out\n#ifdef PACKED2D_DEMO\nint main() {\n  // Any number and order of r/c dims; low -> high\n  Packed2DLayout p({\n      {4, 'r'}, {8, 'c'}, {2, 'r'}, {3, 'c'}  // rows=4*2=8, cols=8*3=24, numel=192\n  });\n\n  std::cout << \"rows=\" << p.rows() << \" cols=\" << p.cols() << \" numel=\" << p.numel() << \"\\n\";\n\n  // High-dimensional index -> rc -> offset\n  std::vector<std::size_t> hd = {3, 5, 1, 2};\n  auto [r, c] = p.hd_to_rc(hd);\n  auto off = p.hd_to_offset(hd);\n  std::cout << \"hd -> rc=(\" << r << \",\" << c << \"), off=\" << off << \"\\n\";\n\n  // Reverse direction\n  auto hd2 = p.offset_to_hd(off);\n  std::cout << \"offset->hd: \";\n  for (auto v : hd2) std::cout << v << \" \";\n  std::cout << \"\\n\";\n\n  // Decompose/compose the row and column only\n  auto rdigits = p.decompose_row(r);\n  auto cdigits = p.decompose_col(c);\n  auto r2 = p.compose_row(rdigits);\n  auto c2 = p.compose_col(cdigits);\n  std::cout << \"compose row=\" << r2 << \" col=\" << c2 << \"\\n\";\n  return 0;\n}\n#endif\n\n#endif"
  },
  {
    "path": "kt-kernel/operators/amx/la/utils.hpp",
    "content": "#ifndef UTILS_HPP\n#define UTILS_HPP\n#include <immintrin.h>\n\n#include <cstddef>\n#include <cstdint>\n\nstatic inline void avx512_copy_32xbf16(__m512i* src, __m512i* dst) {\n  _mm512_storeu_si512(dst, _mm512_loadu_si512(src));\n}\n\n// FP32 to BF16 conversion (32 floats -> 32 bf16)\n// This requires AVX512BF16 for the fast path, with a fallback for CPUs without it\nstatic inline void avx512_32xfp32_to_32xbf16(__m512* src0, __m512* src1, __m512i* dst) {\n#if defined(HAVE_AVX512BF16) || defined(__AVX512BF16__)\n  // Fast path: use native AVX512BF16 instruction\n  _mm512_storeu_si512(dst, __m512i(_mm512_cvtne2ps_pbh(*src1, *src0)));\n#else\n  // Fallback: manual BF16 conversion using bit manipulation\n  // BF16 is the upper 16 bits of FP32 (with rounding)\n  __m512i i0 = _mm512_castps_si512(*src0);\n  __m512i i1 = _mm512_castps_si512(*src1);\n\n  // Round to nearest even: add 0x7FFF + ((val >> 16) & 1)\n  __m512i round0 =\n      _mm512_add_epi32(_mm512_set1_epi32(0x7FFF), _mm512_and_epi32(_mm512_srli_epi32(i0, 16), _mm512_set1_epi32(1)));\n  __m512i round1 =\n      _mm512_add_epi32(_mm512_set1_epi32(0x7FFF), _mm512_and_epi32(_mm512_srli_epi32(i1, 16), _mm512_set1_epi32(1)));\n\n  i0 = _mm512_add_epi32(i0, round0);\n  i1 = _mm512_add_epi32(i1, round1);\n\n  // Extract upper 16 bits (BF16)\n  i0 = _mm512_srli_epi32(i0, 16);\n  i1 = _mm512_srli_epi32(i1, 16);\n\n  // Pack 32-bit values to 16-bit\n  __m512i result = _mm512_packus_epi32(i0, i1);\n  // Fix the interleaving from packus\n  result = _mm512_permutexvar_epi64(_mm512_setr_epi64(0, 2, 4, 6, 1, 3, 5, 7), result);\n\n  _mm512_storeu_si512(dst, result);\n#endif\n}\n\n// BF16 to FP32 conversion (32 bf16 -> 32 floats)\n// This does NOT require AVX512BF16 - uses basic AVX512 bit manipulation\nstatic inline void avx512_32xbf16_to_32xfp32(__m512i* src, __m512* dst0, __m512* dst1) {\n  _mm512_storeu_ps(dst0, _mm512_castsi512_ps(\n                             
_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i*)(src))), 16)));\n  _mm512_storeu_ps(dst1, _mm512_castsi512_ps(_mm512_slli_epi32(\n                             _mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i*)(src) + 1)), 16)));\n}\n\nstatic inline __m512 vector_abs_max(__m512 a, __m512 b) {\n  __m512 a_abs = _mm512_abs_ps(a);\n  __m512 b_abs = _mm512_abs_ps(b);\n\n  __mmask16 mask = _mm512_cmp_ps_mask(a_abs, b_abs, _CMP_GT_OS);\n\n  return _mm512_mask_blend_ps(mask, b_abs, a_abs);\n}\n\n#endif  // UTILS_HPP"
  },
  {
    "path": "kt-kernel/operators/amx/moe.hpp",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-22 02:03:22\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2024-07-25 10:35:10\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#ifndef CPUINFER_OPERATOR_AMX_MOE_H\n#define CPUINFER_OPERATOR_AMX_MOE_H\n\n// #define CHECK\n// #define FORWARD_TIME_PROFILE\n// #define FORWARD_TIME_REPORT\n\n#include \"moe_base.hpp\"\n\ntemplate <class T>\nclass AMX_MOE_TP : public AMX_MOE_BASE<T, AMX_MOE_TP<T>> {\n private:\n  using Base = AMX_MOE_BASE<T, AMX_MOE_TP<T>>;\n  using Base::config_;\n  using Base::down_ba_;\n  using Base::down_bb_;\n  using Base::down_bc_;\n  using Base::gate_bb_;\n  using Base::gate_bc_;\n  using Base::gate_up_ba_;\n  using Base::m_local_num_;\n  using Base::tp_part_idx;\n  using Base::up_bb_;\n  using Base::up_bc_;\n\n#ifdef CHECK\n  char verify_bb[100000000];\n  char check_bb[100000000];\n  uint8_t compare_expers = 3;\n#endif\n\n  inline void write_weights(std::filesystem::path prefix, std::string mat_class, char* bb, int expert_idx, size_t size,\n                            size_t scale_size) {\n    // printf(\"expert %d, size %ld, scale size %ld\\n\", expert_idx, size, scale_size);\n    // std::ofstream of(prefix / (T::name() + mat_class + std::to_string(expert_idx)  + \"_quant_\" + \".kt\"));\n    std::ofstream of(prefix / (T::name() + mat_class + std::to_string(expert_idx) + \"_\" +\n                               std::to_string(size - scale_size) + \"Byte\" + \"_quant_\" + \".kt\"));\n    if (of.is_open() == false) {\n      printf(\"no such file: %s\", (prefix / (T::name() + mat_class + std::to_string(expert_idx) + \"_\" +\n                                            std::to_string(size - scale_size) + \"Byte\" + \"_quant_\" + \".kt\"))\n                                     .c_str());\n      // throw std::runtime_error(\"No such file\");\n    }\n    of.write((char*)bb, size - scale_size);\n    
of.close();\n    // of.open(prefix / (T::name() + mat_class + std::to_string(expert_idx) + \"_scale_\" + \".kt\"));\n    of.open(prefix / (T::name() + mat_class + std::to_string(expert_idx) + \"_\" + std::to_string(scale_size) + \"Byte\" +\n                      \"_scale_\" + \".kt\"));\n    if (of.is_open() == false) {\n      printf(\"no such file\\n\");\n      // throw std::runtime_error(\"No such file\");\n    }\n    of.write(((char*)bb) + size - scale_size, scale_size);\n  }\n\n  inline void read_weights(std::filesystem::path prefix, std::string mat_class, char* bb, int expert_idx, size_t size,\n                           size_t scale_size, uint8_t mat_split, uint8_t mat_split_idex) {\n    // std::ifstream f(prefix / (T::name() + mat_class + std::to_string(expert_idx)  + \"_quant_\" + \".kt\"));\n    std::ifstream f(prefix / (T::name() + mat_class + std::to_string(expert_idx) + \"_\" +\n                              std::to_string(size - scale_size) + \"Byte\" + \"_quant_\" + \".kt\"));\n    if (f.is_open() == false) {\n      printf(\"no such file: %s\\n\", (prefix / (T::name() + mat_class + std::to_string(expert_idx) + \"_\" +\n                                              std::to_string(size - scale_size) + \"Byte\" + \"_quant_\" + \".kt\"))\n                                       .c_str());\n      // throw std::runtime_error(\"No such file\");\n    }\n    f.seekg(mat_split_idex * (size - scale_size) / mat_split);\n    f.read(((char*)bb) + mat_split_idex * (size - scale_size) / mat_split, (size - scale_size) / mat_split);\n    f.close();\n    // f.open(prefix / (T::name() + mat_class + std::to_string(expert_idx) + \"_scale_\" + \".kt\"));\n    f.open(prefix / (T::name() + mat_class + std::to_string(expert_idx) + \"_\" + std::to_string(scale_size) + \"Byte\" +\n                     \"_scale_\" + \".kt\"));\n    if (f.is_open() == false) {\n      printf(\"no such file: %s\\n\", (prefix / (T::name() + mat_class + std::to_string(expert_idx) + \"_\" +\n            
                                  std::to_string(scale_size) + \"Byte\" + \"_scale_\" + \".kt\"))\n                                       .c_str());\n      // throw std::runtime_error(\"No such file\");\n    }\n    f.seekg(mat_split_idex * scale_size / mat_split);\n    f.read((((char*)bb) + size - scale_size) + mat_split_idex * scale_size / mat_split, scale_size / mat_split);\n  }\n#ifdef CHECK\n  inline void load_check() {\n    memcpy(check_bb, (char*)down_bb_[compare_expers]->b,\n           T::BufferB::required_size(config_.hidden_size, config_.intermediate_size));\n  }\n\n  void verify_load_right() {\n    // printf(\"varify down bb_0 %d\\n\", tp_part_idx);\n    memcpy(verify_bb, (char*)down_bb_[compare_expers]->b,\n           T::BufferB::required_size(config_.hidden_size, config_.intermediate_size));\n    // check if verify_bb_0 equal to check_bb_0\n    if (memcmp(verify_bb, check_bb, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size)) != 0) {\n      printf(\"verify error\\n\");\n      for (size_t i = 0; i < T::BufferB::required_size(config_.hidden_size, config_.intermediate_size); ++i) {\n        if (verify_bb[i] != check_bb[i]) {\n          printf(\"Difference at byte %zu: verify_bb_%d[%zu] = %02x, check_bb[%zu] = %02x\\n\", i, compare_expers, i,\n                 (unsigned char)verify_bb[i], i, (unsigned char)check_bb[i]);\n          break;  // find the first difference and exit\n        }\n      }\n      assert(0);\n    } else {\n      printf(\"pass verify\\n\");\n      // pick out the 100th~150th byte of scale to see\n      printf(\"numa %d, verify_bb_%d:\\n\", tp_part_idx, compare_expers);\n      size_t size = T::BufferB::required_size(config_.hidden_size, config_.intermediate_size);\n      size_t scale_size = config_.hidden_size * sizeof(float);\n      for (size_t i = size - scale_size; i < size - scale_size + 50; ++i) {\n        printf(\"%02x \", (unsigned char)verify_bb[i]);\n      }\n      printf(\"\\n\");\n    }\n  
}\n#endif\n\n#ifdef FORWARD_TIME_REPORT\n  std::chrono::time_point<std::chrono::high_resolution_clock> last_now;\n#endif\n\n public:\n  AMX_MOE_TP() = default;\n\n  AMX_MOE_TP(GeneralMOEConfig config, int tp_part_idx = 0) : Base(config, tp_part_idx) {\n    // Initialization now happens in derived_init() which is called by base constructor\n  }\n\n  void derived_init() {\n    printf(\"Creating AMX_MOE_TP %d at numa %d\\n\", tp_part_idx, numa_node_of_cpu(sched_getcpu()));\n    auto& load = config_.load;\n    auto& save = config_.save;\n\n    std::filesystem::path prefix = config_.path;\n    prefix = prefix / (\"_layer_\" + std::to_string(config_.layer_idx)) / (\"_numa_\" + std::to_string(tp_part_idx));\n    if (save) {\n      std::cout << \"Creating \" << prefix << std::endl;\n      std::filesystem::create_directories(prefix);\n    }\n    if (load) {\n      if (std::filesystem::exists(prefix)) {\n        std::cout << \"Loading from \" << prefix << std::endl;\n      } else {\n        throw std::runtime_error(\"Path not found: \" + prefix.string());\n      }\n    }\n  }\n\n  ~AMX_MOE_TP() = default;\n\n  // ============================================================================\n  // CRTP buffer creation - no group_size\n  // ============================================================================\n\n  size_t buffer_a_required_size_impl(size_t m, size_t k) const { return T::BufferA::required_size(m, k); }\n  size_t buffer_b_required_size_impl(size_t n, size_t k) const { return T::BufferB::required_size(n, k); }\n  size_t buffer_c_required_size_impl(size_t m, size_t n) const { return T::BufferC::required_size(m, n); }\n\n  std::shared_ptr<typename T::BufferA> make_buffer_a_impl(size_t m, size_t k, void* data) const {\n    return std::make_shared<typename T::BufferA>(m, k, data);\n  }\n  std::shared_ptr<typename T::BufferB> make_buffer_b_impl(size_t n, size_t k, void* data) const {\n    return std::make_shared<typename T::BufferB>(n, k, data);\n  }\n  
std::shared_ptr<typename T::BufferC> make_buffer_c_impl(size_t m, size_t n, void* data) const {\n    return std::make_shared<typename T::BufferC>(m, n, data);\n  }\n\n  // ============================================================================\n  // CRTP virtual points - GEMM dispatch\n  // ============================================================================\n\n  void do_gate_up_gemm(bool do_up, int expert_idx, int ith, int nth, int qlen) {\n    int m = m_local_num_[expert_idx];\n    auto& ba = gate_up_ba_[expert_idx];\n    auto& bb = do_up ? up_bb_[expert_idx] : gate_bb_[expert_idx];\n    auto& bc = do_up ? up_bc_[expert_idx] : gate_bc_[expert_idx];\n\n    if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) {\n      amx::mat_mul(m, config_.intermediate_size, config_.hidden_size, ba, bb, bc, ith, nth);\n    } else {\n      amx::vec_mul(m, config_.intermediate_size, config_.hidden_size, ba, bb, bc, ith, nth);\n    }\n  }\n\n  void do_down_gemm(int expert_idx, int ith, int nth, int qlen) {\n    int m = m_local_num_[expert_idx];\n    auto& ba = down_ba_[expert_idx];\n    auto& bb = down_bb_[expert_idx];\n    auto& bc = down_bc_[expert_idx];\n\n    if (qlen > 4 * config_.expert_num / config_.num_experts_per_tok) {\n      amx::mat_mul(m, config_.hidden_size, config_.intermediate_size, ba, bb, bc, ith, nth);\n    } else {\n      amx::vec_mul(m, config_.hidden_size, config_.intermediate_size, ba, bb, bc, ith, nth);\n    }\n  }\n  void load_weights() {\n    auto pool = config_.pool->get_subpool(tp_part_idx);\n    const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map;\n    if (config_.gate_projs.size()) {\n      pool->do_work_stealing_job(\n          config_.expert_num, nullptr,\n          [this, physical_to_logical_map](int expert_id) {\n            // printf(\"Load layer %d [%d/%d]\\n\", config_.layer_idx, expert_id, config_.expert_num);\n            uint64_t logical_expert_id = 
expert_map(physical_to_logical_map, expert_id);\n            {\n              size_t scale_size = config_.intermediate_size * sizeof(float);\n              size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size) - scale_size;\n\n              memcpy(gate_bb_[expert_id]->b, config_.gate_projs[tp_part_idx][logical_expert_id], size);\n\n              if constexpr (T::BufferB::SCALE) {\n                memcpy(gate_bb_[expert_id]->d, config_.gate_scales[tp_part_idx][logical_expert_id], scale_size);\n              }\n\n              memcpy(up_bb_[expert_id]->b, config_.up_projs[tp_part_idx][logical_expert_id], size);\n\n              if constexpr (T::BufferB::SCALE) {\n                memcpy(up_bb_[expert_id]->d, config_.up_scales[tp_part_idx][logical_expert_id], scale_size);\n              }\n            }\n\n            {\n              size_t scale_size = config_.hidden_size * sizeof(float);\n              size_t size = T::BufferB::required_size(config_.hidden_size, config_.intermediate_size) - scale_size;\n\n              memcpy(down_bb_[expert_id]->b, config_.down_projs[tp_part_idx][logical_expert_id], size);\n\n              if constexpr (T::BufferB::SCALE) {\n                memcpy(down_bb_[expert_id]->d, config_.down_scales[tp_part_idx][logical_expert_id], scale_size);\n              }\n            }\n          },\n          nullptr);\n\n    } else {\n      int nth = T::recommended_nth(config_.intermediate_size);\n      static uint8_t mat_type_all = 3, mat_split = 1;\n      std::filesystem::path prefix = config_.path;\n      prefix = prefix / (\"_layer_\" + std::to_string(config_.layer_idx)) / (\"_numa_\" + std::to_string(tp_part_idx));\n\n      if (config_.load) {\n        std::cout << \"Loading from \" << prefix << std::endl;\n        for (int task_id = 0; task_id < config_.expert_num * mat_type_all * mat_split; task_id++) {\n          int64_t expert_idx = task_id / (mat_type_all * mat_split);\n          uint64_t 
logical_expert_id = expert_map(physical_to_logical_map, expert_idx);\n          uint8_t mat_class = (task_id % (mat_type_all * mat_split)) / mat_split;\n          uint8_t mat_split_idex = task_id % mat_split;\n          if (mat_class == 0) {  // the up matrix\n            size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size);\n            size_t scale_size = config_.intermediate_size * sizeof(float);\n            read_weights(prefix, \"_up_\", (char*)up_bb_[expert_idx]->b, logical_expert_id, size, scale_size, mat_split,\n                         mat_split_idex);\n          } else if (mat_class == 1) {\n            size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size);\n            size_t scale_size = config_.intermediate_size * sizeof(float);\n            read_weights(prefix, \"_gate_\", (char*)gate_bb_[expert_idx]->b, logical_expert_id, size, scale_size,\n                         mat_split, mat_split_idex);\n          } else {\n            size_t size = T::BufferB::required_size(config_.hidden_size, config_.intermediate_size);\n            size_t scale_size = config_.hidden_size * sizeof(float);\n            read_weights(prefix, \"_down_\", (char*)down_bb_[expert_idx]->b, logical_expert_id, size, scale_size,\n                         mat_split, mat_split_idex);\n          }\n        }\n      }\n// check process, store down matrix to check\n#ifdef CHECK\n      load_check();\n#endif\n#ifndef CHECK\n      else\n#endif\n      {\n        if (tp_part_idx == 0) {\n          std::cout << \"  online quant from bf16\" << std::endl;\n        }\n        pool->do_work_stealing_job(\n            nth * config_.expert_num, nullptr,\n            [this, nth, physical_to_logical_map](int task_id) {\n              int64_t expert_idx = task_id / nth;\n              uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);\n              int ith = task_id % nth;\n              // gate part\n         
     gate_bb_[logical_expert_id]->from_mat(\n                  (ggml_bf16_t*)config_.gate_proj + logical_expert_id * config_.intermediate_size * config_.hidden_size,\n                  ith, nth);\n              // up part\n              up_bb_[logical_expert_id]->from_mat(\n                  (ggml_bf16_t*)config_.up_proj + logical_expert_id * config_.intermediate_size * config_.hidden_size,\n                  ith, nth);\n            },\n            nullptr);\n\n        nth = T::recommended_nth(config_.hidden_size);\n        pool->do_work_stealing_job(\n            nth * config_.expert_num, nullptr,\n            [this, nth, physical_to_logical_map](int task_id) {\n              int64_t expert_idx = task_id / nth;\n              uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);\n              int ith = task_id % nth;\n              // down part\n              down_bb_[logical_expert_id]->from_mat(\n                  (ggml_bf16_t*)config_.down_proj + logical_expert_id * config_.hidden_size * config_.intermediate_size,\n                  ith, nth);\n              // printf(\"load idown, expert %ld, ith %d, total nth %d\\n\", expert_idx, ith, nth);\n            },\n            nullptr);\n      }\n#ifdef CHECK\n      verify_load_right();\n#endif\n      // save process\n      if (config_.save) {\n        pool->do_work_stealing_job(\n            config_.expert_num * mat_type_all, nullptr,\n            [this, physical_to_logical_map, prefix](int task_id) {\n              int64_t expert_idx = task_id / mat_type_all;\n              expert_idx = expert_map(physical_to_logical_map, expert_idx);\n              uint8_t mat_class = task_id % mat_type_all;\n              if (mat_class == 0) {  // the up matrix\n                size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size);\n                size_t scale_size = config_.intermediate_size * sizeof(float);\n                write_weights(prefix, \"_up_\", 
(char*)up_bb_[expert_idx]->b, expert_idx, size, scale_size);\n              } else if (mat_class == 1) {\n                size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size);\n                size_t scale_size = config_.intermediate_size * sizeof(float);\n                write_weights(prefix, \"_gate_\", (char*)gate_bb_[expert_idx]->b, expert_idx, size, scale_size);\n              } else if (mat_class == 2) {\n                size_t size = T::BufferB::required_size(config_.hidden_size, config_.intermediate_size);\n                size_t scale_size = config_.hidden_size * sizeof(float);\n                write_weights(prefix, \"_down_\", (char*)down_bb_[expert_idx]->b, expert_idx, size, scale_size);\n              }\n            },\n            nullptr);\n      }\n    }\n  }\n\n  // forward, forward_prefill, forward_decode, warm_up are inherited from Base\n};\n\n// ============================================================================\n// TP_MOE specialization for AMX_MOE_TP\n// Inherits from TP_MOE<AMX_MOE_BASE<...>> to reuse merge_results implementation\n// ============================================================================\n\ntemplate <typename K>\nclass TP_MOE<AMX_MOE_TP<K>> : public TP_MOE<AMX_MOE_BASE<K, AMX_MOE_TP<K>>> {\n public:\n  using Base = TP_MOE<AMX_MOE_BASE<K, AMX_MOE_TP<K>>>;\n  using Base::Base;\n\n  void load_weights() override {\n    auto& config = this->config;\n    auto& tps = this->tps;\n    auto& tp_count = this->tp_count;\n    auto pool = config.pool;\n    const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;\n    if (config.gate_projs.empty() == false) {\n      printf(\"TP Load from loader\\n\");\n      DO_TPS_LOAD_WEIGHTS(pool);\n      this->weights_loaded = true;\n    } else if (config.gate_proj != nullptr) {\n      printf(\"From BF16\\n\");\n      for (auto i = 0; i < tp_count; i++) {\n        auto& tpc = tps[i]->config_;\n        size_t 
gate_up_elcount = tpc.intermediate_size * tpc.hidden_size;\n        tpc.gate_proj = new ggml_bf16_t[tpc.expert_num * gate_up_elcount];\n        tpc.up_proj = new ggml_bf16_t[tpc.expert_num * gate_up_elcount];\n        tpc.down_proj = new ggml_bf16_t[tpc.expert_num * gate_up_elcount];\n        if (tps[i]->config_.load == false) {\n          pool->get_subpool(i)->do_work_stealing_job(\n              tpc.expert_num, nullptr,\n              [&](int expert_id_) {\n                size_t expert_id = expert_map(physical_to_logical_map, expert_id_);\n                memcpy((ggml_bf16_t*)tpc.gate_proj + expert_id * gate_up_elcount,\n                       (ggml_bf16_t*)config.gate_proj + expert_id * config.intermediate_size * config.hidden_size +\n                           i * gate_up_elcount,\n                       sizeof(ggml_bf16_t) * gate_up_elcount);\n                memcpy((ggml_bf16_t*)tpc.up_proj + expert_id * gate_up_elcount,\n                       (ggml_bf16_t*)config.up_proj + expert_id * config.intermediate_size * config.hidden_size +\n                           i * gate_up_elcount,\n                       sizeof(ggml_bf16_t) * gate_up_elcount);\n                for (size_t col = 0; col < config.hidden_size; col++) {\n                  memcpy((ggml_bf16_t*)tpc.down_proj + expert_id * tpc.hidden_size * tpc.intermediate_size +\n                             col * tpc.intermediate_size,\n                         (ggml_bf16_t*)config.down_proj + expert_id * config.intermediate_size * config.hidden_size +\n                             col * config.intermediate_size + i * tpc.intermediate_size,\n                         sizeof(ggml_bf16_t) * tpc.intermediate_size);\n                }\n              },\n              nullptr);\n        }\n      }\n\n      DO_TPS_LOAD_WEIGHTS(pool);\n\n      for (auto i = 0; i < tp_count; i++) {\n        auto& tpc = tps[i]->config_;\n        delete[] (ggml_bf16_t*)(tpc.gate_proj);\n        delete[] (ggml_bf16_t*)(tpc.up_proj);\n       
 delete[] (ggml_bf16_t*)(tpc.down_proj);\n      }\n\n      this->weights_loaded = true;\n    } else if (config.path != \"\") {\n      printf(\"TP Load from file %s\\n\", config.path.c_str());\n      DO_TPS_LOAD_WEIGHTS(pool);\n      this->weights_loaded = true;\n    } else {\n      throw std::runtime_error(\"no weight source\");\n    }\n  }\n\n  // merge_results is inherited from TP_MOE<AMX_MOE_BASE<K, AMX_MOE_TP<K>>>\n};\n\n#endif\n"
  },
  {
    "path": "kt-kernel/operators/amx/moe_base.hpp",
    "content": "/**\n * @Description  : Common AMX MoE base class extracted from K2 implementation.\n * @Author       : oql, Codex and Claude\n * @Date         : 2025-12-09\n * @Version      : 0.1.0\n * @LastEditors  : oql, Codex and Claude\n * @LastEditTime : 2025-12-09\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#ifndef CPUINFER_OPERATOR_AMX_MOE_BASE_H\n#define CPUINFER_OPERATOR_AMX_MOE_BASE_H\n\n// #define FORWARD_TIME_PROFILE\n\n#include <immintrin.h>\n\n#include <algorithm>\n#include <cassert>\n#include <chrono>\n#include <cmath>\n#include <cstddef>\n#include <cstdint>\n#include <cstdio>\n#include <cstdlib>\n#include <cstring>\n#include <filesystem>\n#include <fstream>\n#include <memory>\n#include <string>\n#include <utility>\n#include <vector>\n\n#include \"../../cpu_backend/shared_mem_buffer.h\"\n#include \"../../cpu_backend/worker_pool.h\"\n#include \"../common.hpp\"\n#include \"../moe-tp.hpp\"\n#include \"la/amx.hpp\"\n#include \"llama.cpp/ggml.h\"\n\ntemplate <class T, class Derived>\nclass AMX_MOE_BASE {\n public:\n  int tp_part_idx = 0;\n\n  ggml_bf16_t* m_local_input_ = nullptr;\n  ggml_bf16_t* m_local_gate_output_ = nullptr;\n  ggml_bf16_t* m_local_up_output_ = nullptr;\n  ggml_bf16_t* m_local_down_output_ = nullptr;\n\n  std::vector<std::vector<int>> m_local_pos_;\n  std::vector<int> m_local_num_;\n  std::vector<int> m_expert_id_map_;\n  std::vector<ggml_bf16_t*> m_local_input_ptr_;\n  std::vector<ggml_bf16_t*> m_local_gate_output_ptr_;\n  std::vector<ggml_bf16_t*> m_local_up_output_ptr_;\n  std::vector<ggml_bf16_t*> m_local_down_output_ptr_;\n\n  std::vector<std::shared_ptr<typename T::BufferA>> gate_up_ba_;\n  std::vector<std::shared_ptr<typename T::BufferB>> gate_bb_;\n  std::vector<std::shared_ptr<typename T::BufferC>> gate_bc_;\n  std::vector<std::shared_ptr<typename T::BufferB>> up_bb_;\n  std::vector<std::shared_ptr<typename T::BufferC>> up_bc_;\n  std::vector<std::shared_ptr<typename T::BufferA>> down_ba_;\n  std::vector<std::shared_ptr<typename T::BufferB>> 
down_bb_;\n  std::vector<std::shared_ptr<typename T::BufferC>> down_bc_;\n\n  size_t pool_count_ = 0;\n  size_t gate_up_ba_pool_bytes_ = 0;\n  size_t gate_bc_pool_bytes_ = 0;\n  size_t up_bc_pool_bytes_ = 0;\n  size_t down_ba_pool_bytes_ = 0;\n  size_t down_bc_pool_bytes_ = 0;\n  void* gate_up_ba_pool_ = nullptr;\n  void* gate_bc_pool_ = nullptr;\n  void* up_bc_pool_ = nullptr;\n  void* down_ba_pool_ = nullptr;\n  void* down_bc_pool_ = nullptr;\n\n  GeneralMOEConfig config_;\n  using input_t = ggml_bf16_t;\n  using output_t = float;\n  static constexpr double ELEMENT_SIZE = T::ELEMENT_SIZE;\n\n  AMX_MOE_BASE(GeneralMOEConfig config, int tp_part_idx_) : tp_part_idx(tp_part_idx_), config_(config) {\n    init();\n    derived()->derived_init();\n  }\n\n  void init() {\n    if (config_.load && config_.path == \"\") {\n      config_.load = false;\n    }\n\n    MemoryRequest mem_requests;\n    mem_requests.append_pointer(\n        &m_local_input_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok * config_.max_len * config_.hidden_size);\n    mem_requests.append_pointer(&m_local_gate_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *\n                                                           config_.max_len * config_.intermediate_size);\n    mem_requests.append_pointer(&m_local_up_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *\n                                                         config_.max_len * config_.intermediate_size);\n    mem_requests.append_pointer(&m_local_down_output_, sizeof(ggml_bf16_t) * config_.num_experts_per_tok *\n                                                           config_.max_len * config_.hidden_size);\n\n    m_local_pos_.resize(config_.max_len);\n    for (int i = 0; i < config_.max_len; i++) {\n      m_local_pos_[i].resize(config_.num_experts_per_tok);\n    }\n    m_expert_id_map_.resize(config_.expert_num);\n    m_local_num_.resize(config_.expert_num);\n    m_local_input_ptr_.resize(config_.expert_num);\n    
m_local_gate_output_ptr_.resize(config_.expert_num);\n    m_local_up_output_ptr_.resize(config_.expert_num);\n    m_local_down_output_ptr_.resize(config_.expert_num);\n\n    for (size_t i = 0; i < config_.expert_num; i++) {\n      gate_up_ba_.push_back(make_buffer_a(config_.max_len, config_.hidden_size, nullptr));\n      gate_bc_.push_back(make_buffer_c(config_.max_len, config_.intermediate_size, nullptr));\n      up_bc_.push_back(make_buffer_c(config_.max_len, config_.intermediate_size, nullptr));\n      down_ba_.push_back(make_buffer_a(config_.max_len, config_.intermediate_size, nullptr));\n      down_bc_.push_back(make_buffer_c(config_.max_len, config_.hidden_size, nullptr));\n\n      void* gate_bb_ptr =\n          std::aligned_alloc(64, buffer_b_required_size(config_.intermediate_size, config_.hidden_size));\n      gate_bb_.push_back(make_buffer_b(config_.intermediate_size, config_.hidden_size, gate_bb_ptr));\n\n      void* up_bb_ptr = std::aligned_alloc(64, buffer_b_required_size(config_.intermediate_size, config_.hidden_size));\n      up_bb_.push_back(make_buffer_b(config_.intermediate_size, config_.hidden_size, up_bb_ptr));\n\n      void* down_bb_ptr =\n          std::aligned_alloc(64, buffer_b_required_size(config_.hidden_size, config_.intermediate_size));\n      down_bb_.push_back(make_buffer_b(config_.hidden_size, config_.intermediate_size, down_bb_ptr));\n    }\n    // TODO: need update to all *.hpp\n    // (config_.expert_num * T::M_STEP) in pool_count_ is to ensure padding for each experts.\n    pool_count_ = config_.max_len * config_.num_experts_per_tok + config_.expert_num * T::M_STEP;\n\n    gate_up_ba_pool_bytes_ = buffer_a_required_size(pool_count_, config_.hidden_size) + pool_count_ * 64;\n    gate_bc_pool_bytes_ = buffer_c_required_size(pool_count_, config_.intermediate_size) + pool_count_ * 64;\n    up_bc_pool_bytes_ = buffer_c_required_size(pool_count_, config_.intermediate_size) + pool_count_ * 64;\n    down_ba_pool_bytes_ = 
buffer_a_required_size(pool_count_, config_.intermediate_size) + pool_count_ * 64;\n    down_bc_pool_bytes_ = buffer_c_required_size(pool_count_, config_.hidden_size) + pool_count_ * 64;\n\n    mem_requests.append_pointer(&gate_up_ba_pool_, gate_up_ba_pool_bytes_);\n    mem_requests.append_pointer(&gate_bc_pool_, gate_bc_pool_bytes_);\n    mem_requests.append_pointer(&up_bc_pool_, up_bc_pool_bytes_);\n    mem_requests.append_pointer(&down_ba_pool_, down_ba_pool_bytes_);\n    mem_requests.append_pointer(&down_bc_pool_, down_bc_pool_bytes_);\n\n    shared_mem_buffer_numa.alloc(tp_part_idx, this, mem_requests);\n  }\n\n  ~AMX_MOE_BASE() = default;\n\n  void warm_up() {\n    int qlen = config_.max_len;\n    std::vector<uint8_t> input(sizeof(ggml_bf16_t) * qlen * config_.hidden_size);\n    std::vector<uint8_t> output(sizeof(ggml_bf16_t) * qlen * config_.hidden_size);\n    std::vector<int64_t> expert_ids(qlen * config_.num_experts_per_tok);\n    std::vector<float> weights(qlen * config_.num_experts_per_tok);\n    for (int i = 0; i < qlen * config_.num_experts_per_tok; i++) {\n      expert_ids[i] = i % config_.expert_num;\n      weights[i] = 0.01;\n    }\n    forward(qlen, config_.num_experts_per_tok, expert_ids.data(), weights.data(), input.data(), output.data());\n  }\n\n  void forward(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {\n    if (qlen > 1) {\n      forward_prefill(qlen, k, expert_ids, weights, input, output);\n    } else {\n      forward_decode(k, expert_ids, weights, input, output);\n    }\n  }\n\n  template <typename... Args>\n  void load_weights(Args&&... args) {\n    derived()->load_weights(std::forward<Args>(args)...);\n  }\n\n  template <typename... Args>\n  void write_weights_to_buffer(Args&&... 
args) const {\n    derived_const()->write_weights_to_buffer(std::forward<Args>(args)...);\n  }\n\n  void forward_prefill(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input,\n                       void* output) {\n    auto pool = config_.pool->get_subpool(tp_part_idx);\n#ifdef FORWARD_TIME_PROFILE\n    auto start_time = std::chrono::high_resolution_clock::now();\n    auto last = start_time;\n    long prepare_time = 0, cpy_input_time = 0, q_input_time = 0, up_gate_time = 0;\n    long act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0;\n    int max_local_num = 0;\n#endif\n\n    int activated_expert = 0;\n    std::fill(m_local_num_.begin(), m_local_num_.end(), 0);\n    for (int i = 0; i < qlen; i++) {\n      for (int j = 0; j < k; j++) {\n        if (config_.should_skip_expert(expert_ids[i * k + j])) {\n          continue;\n        }\n        m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;\n      }\n    }\n\n    for (int i = 0; i < config_.expert_num; i++) {\n      if (m_local_num_[i] > 0) {\n#ifdef FORWARD_TIME_PROFILE\n        max_local_num = std::max(max_local_num, m_local_num_[i]);\n#endif\n        m_expert_id_map_[activated_expert] = i;\n        activated_expert++;\n      }\n    }\n\n    size_t offset = 0;\n    void* gate_up_ba_pool_ptr = gate_up_ba_pool_;\n    void* gate_bc_pool_ptr = gate_bc_pool_;\n    void* up_bc_pool_ptr = up_bc_pool_;\n    void* down_ba_pool_ptr = down_ba_pool_;\n    void* down_bc_pool_ptr = down_bc_pool_;\n    constexpr size_t M_STEP = T::M_STEP;\n    auto align64 = [](size_t v) { return (v + 63) & (~(size_t)63); };\n    size_t used_pool_m = 0;\n    size_t used_pool_bytes_a = 0, used_pool_bytes_bc_gate = 0, used_pool_bytes_bc_up = 0, used_pool_bytes_ba_down = 0,\n           used_pool_bytes_bc_down = 0;\n\n    for (int i = 0; i < config_.expert_num; i++) {\n      m_local_input_ptr_[i] = m_local_input_ + offset * config_.hidden_size;\n      m_local_gate_output_ptr_[i] = 
m_local_gate_output_ + offset * config_.intermediate_size;\n      m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;\n      m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;\n      offset += m_local_num_[i];\n\n      if (m_local_num_[i] == 0) {\n        continue;\n      }\n\n      size_t max_m = (m_local_num_[i] + M_STEP - 1) / M_STEP * M_STEP;\n      gate_up_ba_[i]->max_m = max_m;\n      gate_up_ba_[i]->set_data(gate_up_ba_pool_ptr);\n      size_t ba_size = align64(buffer_a_required_size(max_m, config_.hidden_size));\n      gate_up_ba_pool_ptr = (void*)((uintptr_t)gate_up_ba_pool_ptr + ba_size);\n\n      gate_bc_[i]->max_m = max_m;\n      gate_bc_[i]->set_data(gate_bc_pool_ptr);\n      size_t bc_gate_size = align64(buffer_c_required_size(max_m, config_.intermediate_size));\n      gate_bc_pool_ptr = (void*)((uintptr_t)gate_bc_pool_ptr + bc_gate_size);\n\n      up_bc_[i]->max_m = max_m;\n      up_bc_[i]->set_data(up_bc_pool_ptr);\n      size_t bc_up_size = align64(buffer_c_required_size(max_m, config_.intermediate_size));\n      up_bc_pool_ptr = (void*)((uintptr_t)up_bc_pool_ptr + bc_up_size);\n\n      down_ba_[i]->max_m = max_m;\n      down_ba_[i]->set_data(down_ba_pool_ptr);\n      size_t ba_down_size = align64(buffer_a_required_size(max_m, config_.intermediate_size));\n      down_ba_pool_ptr = (void*)((uintptr_t)down_ba_pool_ptr + ba_down_size);\n\n      down_bc_[i]->max_m = max_m;\n      down_bc_[i]->set_data(down_bc_pool_ptr);\n      size_t bc_down_size = align64(buffer_c_required_size(max_m, config_.hidden_size));\n      down_bc_pool_ptr = (void*)((uintptr_t)down_bc_pool_ptr + bc_down_size);\n\n      used_pool_m += max_m;\n      used_pool_bytes_a += ba_size;\n      used_pool_bytes_bc_gate += bc_gate_size;\n      used_pool_bytes_bc_up += bc_up_size;\n      used_pool_bytes_ba_down += ba_down_size;\n      used_pool_bytes_bc_down += bc_down_size;\n    }\n\n    assert(used_pool_m <= pool_count_);\n  
  assert(used_pool_bytes_a <= gate_up_ba_pool_bytes_);\n    assert(used_pool_bytes_bc_gate <= gate_bc_pool_bytes_);\n    assert(used_pool_bytes_bc_up <= up_bc_pool_bytes_);\n    assert(used_pool_bytes_ba_down <= down_ba_pool_bytes_);\n    assert(used_pool_bytes_bc_down <= down_bc_pool_bytes_);\n\n#ifdef FORWARD_TIME_PROFILE\n    {\n      auto now_time = std::chrono::high_resolution_clock::now();\n      prepare_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();\n      last = now_time;\n    }\n#endif\n\n    auto direct_or_pool = [&](int count, auto&& fn) {\n      if (qlen < 10) {\n        for (int i = 0; i < count; i++) {\n          fn(i);\n        }\n      } else {\n        pool->do_work_stealing_job(count, nullptr, fn, nullptr);\n      }\n    };\n\n    direct_or_pool(qlen, [&](int i) {\n      for (int j = 0; j < k; j++) {\n        if (config_.should_skip_expert(expert_ids[i * k + j])) {\n          continue;\n        }\n        memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size,\n               (ggml_bf16_t*)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size);\n      }\n    });\n\n#ifdef FORWARD_TIME_PROFILE\n    {\n      auto now_time = std::chrono::high_resolution_clock::now();\n      cpy_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();\n      last = now_time;\n    }\n#endif\n\n    direct_or_pool(activated_expert, [this](int task_id) {\n      int expert_idx = m_expert_id_map_[task_id];\n      gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1);\n    });\n\n#ifdef FORWARD_TIME_PROFILE\n    {\n      auto now_time = std::chrono::high_resolution_clock::now();\n      q_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();\n      last = now_time;\n    }\n#endif\n\n    int nth = T::recommended_nth(config_.intermediate_size);\n    
pool->do_work_stealing_job(\n        nth * activated_expert * 2, [](int _) { T::config(); },\n        [this, nth, qlen](int task_id2) {\n          int task_id = task_id2 / 2;\n          bool do_up = task_id2 % 2;\n          int expert_idx = m_expert_id_map_[task_id / nth];\n\n          int ith = task_id % nth;\n          derived()->do_gate_up_gemm(do_up, expert_idx, ith, nth, qlen);\n          if (do_up) {\n            up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth);\n          } else {\n            gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);\n          }\n        },\n        nullptr);\n\n#ifdef FORWARD_TIME_PROFILE\n    {\n      auto now_time = std::chrono::high_resolution_clock::now();\n      up_gate_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();\n      last = now_time;\n    }\n#endif\n\n    apply_activation(activated_expert, nth, qlen);\n\n#ifdef FORWARD_TIME_PROFILE\n    {\n      auto now_time = std::chrono::high_resolution_clock::now();\n      act_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();\n      last = now_time;\n    }\n#endif\n\n    pool->do_work_stealing_job(\n        activated_expert, nullptr,\n        [this](int task_id) {\n          int expert_idx = m_expert_id_map_[task_id];\n          down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1);\n        },\n        nullptr);\n\n#ifdef FORWARD_TIME_PROFILE\n    {\n      auto now_time = std::chrono::high_resolution_clock::now();\n      q_down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();\n      last = now_time;\n    }\n#endif\n\n    nth = T::recommended_nth(config_.hidden_size);\n    pool->do_work_stealing_job(\n        nth * activated_expert, [](int _) { T::config(); },\n        [this, nth, qlen](int task_id) {\n          int 
expert_idx = m_expert_id_map_[task_id / nth];\n          int ith = task_id % nth;\n          derived()->do_down_gemm(expert_idx, ith, nth, qlen);\n          down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_output_ptr_[expert_idx], ith, nth);\n        },\n        nullptr);\n\n#ifdef FORWARD_TIME_PROFILE\n    {\n      auto now_time = std::chrono::high_resolution_clock::now();\n      down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();\n      last = now_time;\n    }\n#endif\n\n    pool->do_work_stealing_job(\n        qlen, nullptr,\n        [this, output, k, expert_ids, weights](int i) {\n          for (int e = 0; e < config_.hidden_size; e += 32) {\n            __m512 x0 = _mm512_setzero_ps();\n            __m512 x1 = _mm512_setzero_ps();\n            for (int j = 0; j < k; j++) {\n              if (config_.should_skip_expert(expert_ids[i * k + j])) {\n                continue;\n              }\n              __m512 weight = _mm512_set1_ps(weights[i * k + j]);\n              __m512 down_output0, down_output1;\n              avx512_32xbf16_to_32xfp32((__m512i*)(m_local_down_output_ptr_[expert_ids[i * k + j]] +\n                                                   m_local_pos_[i][j] * config_.hidden_size + e),\n                                        &down_output0, &down_output1);\n              x0 = _mm512_fmadd_ps(down_output0, weight, x0);\n              x1 = _mm512_fmadd_ps(down_output1, weight, x1);\n            }\n            auto f32out = (__m512*)((float*)output + i * config_.hidden_size + e);\n            f32out[0] = x0;\n            f32out[1] = x1;\n          }\n        },\n        nullptr);\n\n#ifdef FORWARD_TIME_PROFILE\n    {\n      auto now_time = std::chrono::high_resolution_clock::now();\n      weight_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();\n      last = now_time;\n    }\n    auto end_time = std::chrono::high_resolution_clock::now();\n    auto 
forward_total_time = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();\n    printf(\n        \"Profiling Results (numa[%d]): activated_expert: %d, prepare: %ld us, cpy_input: %ld us, q_input: %ld us, \"\n        \"up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us, max_local_num: \"\n        \"%d, qlen: %d\\n\",\n        tp_part_idx, activated_expert, prepare_time, cpy_input_time, q_input_time, up_gate_time, act_time, q_down_time,\n        down_time, weight_time, forward_total_time, max_local_num, qlen);\n#endif\n  }\n\n  void forward_decode(int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {\n    int qlen = 1;\n    auto pool = config_.pool->get_subpool(tp_part_idx);\n#ifdef FORWARD_TIME_PROFILE\n    auto start_time = std::chrono::high_resolution_clock::now();\n    auto last = start_time;\n    long q_input_time = 0, up_gate_time = 0, act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0;\n#endif\n\n    int activated_expert = 0;\n    std::fill(m_local_num_.begin(), m_local_num_.end(), 0);\n    for (int i = 0; i < k; i++) {\n      if (config_.should_skip_expert(expert_ids[i])) {\n        continue;\n      }\n      m_expert_id_map_[activated_expert] = expert_ids[i];\n      m_local_pos_[0][i] = 0;\n      m_local_num_[expert_ids[i]] = qlen;\n      activated_expert++;\n    }\n\n    size_t offset = 0;\n    for (int i = 0; i < activated_expert; i++) {\n      auto expert_idx = m_expert_id_map_[i];\n      m_local_gate_output_ptr_[expert_idx] = m_local_gate_output_ + offset * config_.intermediate_size;\n      m_local_up_output_ptr_[expert_idx] = m_local_up_output_ + offset * config_.intermediate_size;\n      m_local_down_output_ptr_[expert_idx] = m_local_down_output_ + offset * config_.hidden_size;\n      offset += qlen;\n    }\n\n    void* gate_bc_pool_ptr = gate_bc_pool_;\n    void* up_bc_pool_ptr = up_bc_pool_;\n    void* down_ba_pool_ptr = 
down_ba_pool_;\n    void* down_bc_pool_ptr = down_bc_pool_;\n    constexpr size_t M_STEP = T::M_STEP;\n    auto align64 = [](size_t v) { return (v + 63) & (~(size_t)63); };\n    size_t used_pool_m = 0;\n    size_t used_pool_bytes_bc_gate = 0, used_pool_bytes_bc_up = 0, used_pool_bytes_ba_down = 0,\n           used_pool_bytes_bc_down = 0;\n    for (int i = 0; i < activated_expert; i++) {\n      auto expert_idx = m_expert_id_map_[i];\n      size_t max_m = (qlen + M_STEP - 1) / M_STEP * M_STEP;\n\n      gate_bc_[expert_idx]->max_m = max_m;\n      gate_bc_[expert_idx]->set_data(gate_bc_pool_ptr);\n      size_t bc_gate_size = align64(buffer_c_required_size(max_m, config_.intermediate_size));\n      gate_bc_pool_ptr = (void*)((uintptr_t)gate_bc_pool_ptr + bc_gate_size);\n\n      up_bc_[expert_idx]->max_m = max_m;\n      up_bc_[expert_idx]->set_data(up_bc_pool_ptr);\n      size_t bc_up_size = align64(buffer_c_required_size(max_m, config_.intermediate_size));\n      up_bc_pool_ptr = (void*)((uintptr_t)up_bc_pool_ptr + bc_up_size);\n\n      down_ba_[expert_idx]->max_m = max_m;\n      down_ba_[expert_idx]->set_data(down_ba_pool_ptr);\n      size_t ba_down_size = align64(buffer_a_required_size(max_m, config_.intermediate_size));\n      down_ba_pool_ptr = (void*)((uintptr_t)down_ba_pool_ptr + ba_down_size);\n\n      down_bc_[expert_idx]->max_m = max_m;\n      down_bc_[expert_idx]->set_data(down_bc_pool_ptr);\n      size_t bc_down_size = align64(buffer_c_required_size(max_m, config_.hidden_size));\n      down_bc_pool_ptr = (void*)((uintptr_t)down_bc_pool_ptr + bc_down_size);\n\n      used_pool_m += max_m;\n      used_pool_bytes_bc_gate += bc_gate_size;\n      used_pool_bytes_bc_up += bc_up_size;\n      used_pool_bytes_ba_down += ba_down_size;\n      used_pool_bytes_bc_down += bc_down_size;\n    }\n    assert(used_pool_m <= pool_count_);\n    assert(used_pool_bytes_bc_gate <= gate_bc_pool_bytes_);\n    assert(used_pool_bytes_bc_up <= up_bc_pool_bytes_);\n    
assert(used_pool_bytes_ba_down <= down_ba_pool_bytes_);\n    assert(used_pool_bytes_bc_down <= down_bc_pool_bytes_);\n\n    void* gate_up_ba_pool_ptr = gate_up_ba_pool_;\n    for (int i = 0; i < activated_expert; i++) {\n      auto expert_idx = m_expert_id_map_[i];\n      size_t max_m = (qlen + M_STEP - 1) / M_STEP * M_STEP;\n      gate_up_ba_[expert_idx]->max_m = max_m;\n      gate_up_ba_[expert_idx]->set_data(gate_up_ba_pool_ptr);\n      size_t ba_size = align64(buffer_a_required_size(max_m, config_.hidden_size));\n      gate_up_ba_pool_ptr = (void*)((uintptr_t)gate_up_ba_pool_ptr + ba_size);\n      gate_up_ba_[expert_idx]->from_mat(qlen, (ggml_bf16_t*)input, 0, 1);\n    }\n\n#ifdef FORWARD_TIME_PROFILE\n    {\n      auto now_time = std::chrono::high_resolution_clock::now();\n      q_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();\n      last = now_time;\n    }\n#endif\n\n    int nth = T::recommended_nth(config_.intermediate_size);\n    pool->do_work_stealing_job(\n        nth * activated_expert * 2, [](int _) { T::config(); },\n        [this, nth, qlen](int task_id2) {\n          int task_id = task_id2 / 2;\n          bool do_up = task_id2 % 2;\n          int expert_idx = m_expert_id_map_[task_id / nth];\n\n          int ith = task_id % nth;\n          derived()->do_gate_up_gemm(do_up, expert_idx, ith, nth, qlen);\n          if (do_up) {\n            up_bc_[expert_idx]->to_mat(qlen, m_local_up_output_ptr_[expert_idx], ith, nth);\n          } else {\n            gate_bc_[expert_idx]->to_mat(qlen, m_local_gate_output_ptr_[expert_idx], ith, nth);\n          }\n        },\n        nullptr);\n\n#ifdef FORWARD_TIME_PROFILE\n    {\n      auto now_time = std::chrono::high_resolution_clock::now();\n      up_gate_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();\n      last = now_time;\n    }\n#endif\n\n    apply_activation(activated_expert, nth, qlen);\n\n#ifdef FORWARD_TIME_PROFILE\n    
{\n      auto now_time = std::chrono::high_resolution_clock::now();\n      act_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();\n      last = now_time;\n    }\n#endif\n\n    pool->do_work_stealing_job(\n        activated_expert, nullptr,\n        [this, qlen](int task_id) {\n          int expert_idx = m_expert_id_map_[task_id];\n          down_ba_[expert_idx]->from_mat(qlen, m_local_gate_output_ptr_[expert_idx], 0, 1);\n        },\n        nullptr);\n\n#ifdef FORWARD_TIME_PROFILE\n    {\n      auto now_time = std::chrono::high_resolution_clock::now();\n      q_down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();\n      last = now_time;\n    }\n#endif\n\n    nth = T::recommended_nth(config_.hidden_size);\n    pool->do_work_stealing_job(\n        nth * activated_expert, [](int _) { T::config(); },\n        [this, nth, qlen](int task_id) {\n          int expert_idx = m_expert_id_map_[task_id / nth];\n          int ith = task_id % nth;\n          derived()->do_down_gemm(expert_idx, ith, nth, qlen);\n          down_bc_[expert_idx]->to_mat(qlen, m_local_down_output_ptr_[expert_idx], ith, nth);\n        },\n        nullptr);\n\n#ifdef FORWARD_TIME_PROFILE\n    {\n      auto now_time = std::chrono::high_resolution_clock::now();\n      down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();\n      last = now_time;\n    }\n#endif\n\n    for (int e = 0; e < config_.hidden_size; e += 32) {\n      __m512 x0 = _mm512_setzero_ps();\n      __m512 x1 = _mm512_setzero_ps();\n      for (int j = 0; j < k; j++) {\n        if (config_.should_skip_expert(expert_ids[j])) {\n          continue;\n        }\n        __m512 weight = _mm512_set1_ps(weights[j]);\n        __m512 down_output0, down_output1;\n        avx512_32xbf16_to_32xfp32(\n            (__m512i*)(m_local_down_output_ptr_[expert_ids[j]] + m_local_pos_[0][j] * config_.hidden_size + e),\n            &down_output0, 
&down_output1);\n        x0 = _mm512_fmadd_ps(down_output0, weight, x0);\n        x1 = _mm512_fmadd_ps(down_output1, weight, x1);\n      }\n      auto f32out = (__m512*)((float*)output + e);\n      f32out[0] = x0;\n      f32out[1] = x1;\n    }\n\n#ifdef FORWARD_TIME_PROFILE\n    {\n      auto now_time = std::chrono::high_resolution_clock::now();\n      weight_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();\n      last = now_time;\n    }\n    auto end_time = std::chrono::high_resolution_clock::now();\n    auto forward_total_time = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();\n    printf(\n        \"Profiling Results (numa[%d]): activated_expert: %d, q_input: %ld us, \"\n        \"up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us\\n\",\n        tp_part_idx, activated_expert, q_input_time, up_gate_time, act_time, q_down_time, down_time, weight_time,\n        forward_total_time);\n#endif\n  }\n\n protected:\n  Derived* derived() { return static_cast<Derived*>(this); }\n  const Derived* derived_const() const { return static_cast<const Derived*>(this); }\n\n  // ============================================================================\n  // Derived class initialization hook\n  // Called after base class init() completes, allows derived classes to perform\n  // their own initialization that depends on base class being fully initialized\n  // ============================================================================\n  void derived_init() {\n    // Default implementation does nothing - derived classes can override\n  }\n\n  // ============================================================================\n  // Virtual points for buffer creation and size calculation\n  // Default implementations use group_size (for KGroup quantization like K2)\n  // Derived classes (like moe.hpp) can override to not use group_size\n  // 
============================================================================\n\n  size_t buffer_a_required_size(size_t m, size_t k) const { return derived_const()->buffer_a_required_size_impl(m, k); }\n  size_t buffer_b_required_size(size_t n, size_t k) const { return derived_const()->buffer_b_required_size_impl(n, k); }\n  size_t buffer_c_required_size(size_t m, size_t n) const { return derived_const()->buffer_c_required_size_impl(m, n); }\n\n  std::shared_ptr<typename T::BufferA> make_buffer_a(size_t m, size_t k, void* data) const {\n    return derived_const()->make_buffer_a_impl(m, k, data);\n  }\n  std::shared_ptr<typename T::BufferB> make_buffer_b(size_t n, size_t k, void* data) const {\n    return derived_const()->make_buffer_b_impl(n, k, data);\n  }\n  std::shared_ptr<typename T::BufferC> make_buffer_c(size_t m, size_t n, void* data) const {\n    return derived_const()->make_buffer_c_impl(m, n, data);\n  }\n\n  void apply_activation(int activated_expert, int nth, int qlen) {\n    auto pool = config_.pool->get_subpool(tp_part_idx);\n    auto fn = [this, nth](int task_id) {\n      int expert_idx = m_expert_id_map_[task_id / nth];\n      int ith = task_id % nth;\n      auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);\n      for (int i = 0; i < m_local_num_[expert_idx]; i++) {\n        ggml_bf16_t* gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];\n        ggml_bf16_t* up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];\n        for (int j = n_start; j < n_end; j += 32) {\n          __m512 gate_val0, gate_val1, up_val0, up_val1;\n          avx512_32xbf16_to_32xfp32((__m512i*)(gate_output_ptr + j), &gate_val0, &gate_val1);\n          avx512_32xbf16_to_32xfp32((__m512i*)(up_output_ptr + j), &up_val0, &up_val1);\n          __m512 result0 = amx::act_fn(gate_val0, up_val0);\n          __m512 result1 = amx::act_fn(gate_val1, up_val1);\n          
avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i*)(gate_output_ptr + j));\n        }\n      }\n    };\n\n    if (activated_expert == 0) {\n      return;\n    }\n\n    if (qlen < 10) {\n      for (int task_id = 0; task_id < nth * activated_expert; task_id++) {\n        fn(task_id);\n      }\n    } else {\n      pool->do_work_stealing_job(nth * activated_expert, nullptr, fn, nullptr);\n    }\n  }\n};\n\n// ============================================================================\n// TP_MOE specialization for AMX_MOE_BASE derived classes\n// ============================================================================\n\ntemplate <class T, class Derived>\nclass TP_MOE<AMX_MOE_BASE<T, Derived>> : public TP_MOE_Common<AMX_MOE_BASE<T, Derived>> {\n public:\n  using TP_MOE_Common<AMX_MOE_BASE<T, Derived>>::TP_MOE_Common;\n\n  // Default load_weights implementation - can be overridden by derived TP_MOE classes\n  void load_weights() override { throw std::runtime_error(\"Not Implemented\"); }\n\n  void write_weight_scale_to_buffer(int gpu_tp_count, int gpu_experts_num,\n                                    const std::vector<uintptr_t>& w13_weight_ptrs,\n                                    const std::vector<uintptr_t>& w13_scale_ptrs,\n                                    const std::vector<uintptr_t>& w2_weight_ptrs,\n                                    const std::vector<uintptr_t>& w2_scale_ptrs) {\n    throw std::runtime_error(\"Not Implemented\");\n  }\n\n  void merge_results(int qlen, void* output, bool incremental) override {\n    auto& config = this->config;\n    auto& tp_count = this->tp_count;\n    auto& local_output_numa = this->local_output_numa;\n    auto& tp_configs = this->tp_configs;\n\n    auto merge_fn = [this, output, incremental, &config, &tp_count, &local_output_numa, &tp_configs](int token_nth) {\n      float* merge_to = local_output_numa[0] + token_nth * tp_configs[0].hidden_size;\n      if (incremental) {\n        for (int e = 0; e < 
config.hidden_size; e += 32) {\n          __m512 x0, x1;\n          avx512_32xbf16_to_32xfp32((__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e), &x0, &x1);\n          *((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), x0);\n          *((__m512*)(merge_to + e + 16)) = _mm512_add_ps(*((__m512*)(merge_to + e + 16)), x1);\n        }\n      }\n      for (int i = 1; i < tp_count; i++) {\n        float* merge_from = local_output_numa[i] + token_nth * tp_configs[i].hidden_size;\n        for (int e = 0; e < tp_configs[i].hidden_size; e += 16) {\n          *((__m512*)(merge_to + e)) = _mm512_add_ps(*((__m512*)(merge_to + e)), *((__m512*)(merge_from + e)));\n        }\n      }\n      for (int e = 0; e < config.hidden_size; e += 32) {\n        __m512 x0 = *(__m512*)(merge_to + e);\n        __m512 x1 = *(__m512*)(merge_to + e + 16);\n        avx512_32xfp32_to_32xbf16(&x0, &x1, (__m512i*)((ggml_bf16_t*)output + token_nth * config.hidden_size + e));\n      }\n    };\n\n    auto pool = config.pool;\n\n    auto direct_or_pool = [&](int count, auto&& fn) {\n      if (qlen < 10) {\n        for (int i = 0; i < count; i++) {\n          fn(i);\n        }\n      } else {\n        pool->do_work_stealing_job(count, nullptr, fn, nullptr);\n      }\n    };\n\n    direct_or_pool(qlen, merge_fn);\n  }\n\n  void merge_results(int qlen, void* output) override { merge_results(qlen, output, false); }\n};\n\n#endif  // CPUINFER_OPERATOR_AMX_MOE_BASE_H\n"
  },
  {
    "path": "kt-kernel/operators/amx/test/amx-bkgroup-test.cpp",
    "content": "#include <omp.h>\n\n#include \"../la/amx.hpp\"\n#define FMT_HEADER_ONLY\n#include <fmt/core.h>\n\n#include <algorithm>\n#include <cmath>\n#include <cstdlib>\n#include <cstring>\n#include <iostream>\n#include <memory>\n#include <random>\n#include <utility>\n#include <vector>\n\n// Test kernel configuration for k-group testing\nstruct TestKernelKGroupB {\n  static constexpr int M_STEP = 32;\n  static constexpr int K_STEP = 64;\n  static constexpr int K_BLOCK = 512;\n  static constexpr int N_STEP = 32;\n  static constexpr int N_BLOCK = 512;\n  static constexpr int TILE_N = 16;\n  using dt = int8_t;\n\n  static std::pair<int, int> split_range_n(int n, int ith, int nth) {\n    int n_per_thread = (n + nth - 1) / nth;\n    int n_start = ith * n_per_thread;\n    int n_end = std::min(n_start + n_per_thread, n);\n    return {n_start, n_end};\n  }\n};\n\nvoid test_buffer_bkgroup_basic() {\n  std::cout << \"=== Testing BufferBKGroupImpl Basic Functionality ===\" << std::endl;\n\n  // Test parameters\n  const int k = 2048;            // Must be multiple of K_STEP and K_BLOCK\n  const int n = 1024;            // Must be multiple of TILE_N\n  const int k_group_size = 128;  // Must divide K_BLOCK evenly\n\n  std::cout << fmt::format(\"Parameters: k={}, n={}, k_group_size={}\\n\", k, n, k_group_size);\n\n  // Calculate and allocate buffer\n  size_t buffer_size = amx::BufferBKGroupImpl<TestKernelKGroupB>::required_size(k, n, k_group_size);\n  void* buffer = std::aligned_alloc(64, buffer_size);\n  std::memset(buffer, 0, buffer_size);\n\n  std::cout << fmt::format(\"Buffer size: {} bytes\\n\", buffer_size);\n\n  // Create BufferBKGroupImpl instance\n  auto buf = std::make_unique<amx::BufferBKGroupImpl<TestKernelKGroupB>>(k, n, k_group_size, buffer);\n\n  // Create test input data (bf16)\n  std::vector<ggml_bf16_t> input(k * n);\n  std::mt19937 gen(42);\n  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);\n\n  for (int i = 0; i < k * n; i++) {\n    float val = dist(gen);\n    input[i] = ggml_compute_fp32_to_bf16(val);\n  }\n\n  // Test from_mat\n  std::cout << \"Testing 
from_mat...\" << std::endl;\n  buf->from_mat(input.data(), 0, 1);\n  std::cout << \"✓ from_mat completed successfully\" << std::endl;\n\n  // Test get_submat\n  std::cout << \"Testing get_submat...\" << std::endl;\n  for (int k_begin = 0; k_begin < k; k_begin += TestKernelKGroupB::K_STEP) {\n    for (int n_begin = 0; n_begin < n; n_begin += TestKernelKGroupB::TILE_N) {\n      int8_t* submat = buf->get_submat(k, n, k_begin, n_begin);\n      if (submat == nullptr) {\n        std::cerr << fmt::format(\"ERROR: get_submat returned null for k_begin={}, n_begin={}\\n\", k_begin, n_begin);\n        free(buffer);\n        return;\n      }\n    }\n  }\n  std::cout << \"✓ get_submat tested for all valid positions\" << std::endl;\n\n  // Test get_scale\n  std::cout << \"Testing get_scale...\" << std::endl;\n  int k_group_count = k / k_group_size;\n  for (int n_idx = 0; n_idx < n; n_idx++) {\n    for (int kg_idx = 0; kg_idx < k_group_count; kg_idx++) {\n      float* scale = buf->get_scale(n, n_idx, k, kg_idx * k_group_size);\n      if (scale == nullptr) {\n        std::cerr << fmt::format(\"ERROR: get_scale returned null for n_idx={}, k_group={}\\n\", n_idx, kg_idx);\n        free(buffer);\n        return;\n      }\n      // Verify scale is non-zero (should be set by from_mat)\n      if (*scale == 0.0f) {\n        std::cerr << fmt::format(\"WARNING: scale is zero for n_idx={}, k_group={}\\n\", n_idx, kg_idx);\n      }\n    }\n  }\n  std::cout << \"✓ get_scale tested for all k-groups\" << std::endl;\n\n  // Print some scale values for verification\n  std::cout << \"\\nSample scale values:\" << std::endl;\n  for (int kg = 0; kg < std::min(4, k_group_count); kg++) {\n    float* scale = buf->get_scale(n, 0, k, kg * k_group_size);\n    std::cout << fmt::format(\"  k_group[{}] (k={}): scale = {:.6f}\\n\", kg, kg * k_group_size, *scale);\n  }\n\n  // Clean up\n  free(buffer);\n  std::cout << \"\\n✓ All basic tests passed!\" << std::endl;\n}\n\nvoid test_buffer_bkgroup_correctness() 
{\n  std::cout << \"\\n=== Testing BufferBKGroupImpl Quantization Correctness ===\" << std::endl;\n\n  const int k = 512;\n  const int n = 256;\n  const int k_group_size = 128;\n\n  size_t buffer_size = amx::BufferBKGroupImpl<TestKernelKGroupB>::required_size(k, n, k_group_size);\n  void* buffer = std::aligned_alloc(64, buffer_size);\n\n  auto buf = std::make_unique<amx::BufferBKGroupImpl<TestKernelKGroupB>>(k, n, k_group_size, buffer);\n\n  // Create test input matrix with known patterns\n  std::vector<float> original(k * n);\n  std::vector<ggml_bf16_t> input(k * n);\n\n  // Fill with different patterns for each k-group to test group-wise quantization\n  for (int k_idx = 0; k_idx < k; k_idx++) {\n    for (int n_idx = 0; n_idx < n; n_idx++) {\n      int kg = k_idx / k_group_size;\n      // Different magnitude for each k-group\n      float base_val = (kg + 1) * 0.1f;\n      float val = base_val * std::sin(k_idx * 0.01f + n_idx * 0.1f);\n      original[k_idx * n + n_idx] = val;\n      input[k_idx * n + n_idx] = ggml_compute_fp32_to_bf16(val);\n    }\n  }\n\n  // Quantize\n  buf->from_mat(input.data(), 0, 1);\n\n  // Calculate quantization error statistics\n  float max_error = 0.0f;\n  float total_error = 0.0f;\n  float avg_magnitude = 0.0f;\n\n  for (int i = 0; i < k * n; i++) {\n    avg_magnitude += std::abs(original[i]);\n  }\n  avg_magnitude /= (k * n);\n\n  // Since we're using 4-bit quantization, expect higher error than int8\n  // Just verify that scales are being computed correctly\n  std::cout << fmt::format(\"Quantization Analysis:\\n\");\n  std::cout << fmt::format(\"  Average magnitude: {:.6f}\\n\", avg_magnitude);\n  std::cout << fmt::format(\"  Using 4-bit quantization (INT4)\\n\");\n\n  // Test that different k-groups have different scales\n  std::cout << \"\\nVerifying k-group scales are computed independently:\" << std::endl;\n  bool scales_differ = false;\n  for (int n_idx = 0; n_idx < std::min(4, n); n_idx++) {\n    float* scale0 = buf->get_scale(n, 
n_idx, k, 0);\n    for (int kg = 1; kg < k / k_group_size; kg++) {\n      float* scale_kg = buf->get_scale(n, n_idx, k, kg * k_group_size);\n      if (std::abs(*scale0 - *scale_kg) > 1e-6f) {\n        scales_differ = true;\n        break;\n      }\n    }\n    if (scales_differ) break;\n  }\n\n  if (scales_differ) {\n    std::cout << \"✓ Different k-groups have independent scales\" << std::endl;\n  } else {\n    std::cout << \"✗ Warning: All k-groups have the same scale (might be correct for uniform data)\" << std::endl;\n  }\n\n  free(buffer);\n}\n\nvoid test_buffer_bkgroup_comparison() {\n  std::cout << \"\\n=== Comparing BufferBInt4Impl vs BufferBKGroupImpl ===\" << std::endl;\n\n  const int k = 2048;\n  const int n = 512;\n  const int k_group_size = 256;\n\n  // Create test data\n  std::vector<ggml_bf16_t> input(k * n);\n  std::mt19937 gen(456);\n  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);\n  for (int i = 0; i < k * n; i++) {\n    input[i] = ggml_compute_fp32_to_bf16(dist(gen));\n  }\n\n  // Test original BufferBInt4Impl\n  {\n    size_t buffer_size = amx::BufferBInt4Impl<TestKernelKGroupB>::required_size(k, n);\n    void* buffer = std::aligned_alloc(64, buffer_size);\n    auto buf_b = std::make_unique<amx::BufferBInt4Impl<TestKernelKGroupB>>(k, n, buffer);\n\n    buf_b->from_mat(input.data(), 0, 1);\n\n    // Print some scales\n    std::cout << \"BufferBInt4Impl scales (per-column):\" << std::endl;\n    for (int n_idx = 0; n_idx < std::min(4, n); n_idx++) {\n      float* scale = buf_b->get_scale(n, n_idx);\n      std::cout << fmt::format(\"  col[{}]: scale = {:.6f}\\n\", n_idx, *scale);\n    }\n\n    free(buffer);\n  }\n\n  // Test BufferBKGroupImpl\n  {\n    size_t buffer_size = amx::BufferBKGroupImpl<TestKernelKGroupB>::required_size(k, n, k_group_size);\n    void* buffer = std::aligned_alloc(64, buffer_size);\n    auto buf_kg = std::make_unique<amx::BufferBKGroupImpl<TestKernelKGroupB>>(k, n, k_group_size, buffer);\n\n    
buf_kg->from_mat(input.data(), 0, 1);\n\n    // Print some scales\n    std::cout << \"\\nBufferBKGroupImpl scales (per k-group):\" << std::endl;\n    for (int n_idx = 0; n_idx < std::min(2, n); n_idx++) {\n      std::cout << fmt::format(\"  col[{}]:\\n\", n_idx);\n      for (int kg = 0; kg < std::min(4, k / k_group_size); kg++) {\n        float* scale = buf_kg->get_scale(n, n_idx, k, kg * k_group_size);\n        std::cout << fmt::format(\"    k_group[{}]: scale = {:.6f}\\n\", kg, *scale);\n      }\n    }\n\n    free(buffer);\n  }\n\n  std::cout << \"\\n✓ Comparison test completed\" << std::endl;\n}\n\nint main(int argc, char** argv) {\n  std::cout << \"Starting BufferBKGroupImpl Tests\\n\" << std::endl;\n\n  try {\n    // Run basic functionality tests\n    test_buffer_bkgroup_basic();\n\n    // Run correctness tests\n    test_buffer_bkgroup_correctness();\n\n    // Run comparison tests\n    test_buffer_bkgroup_comparison();\n\n    std::cout << \"\\n=== All tests completed successfully! ===\" << std::endl;\n  } catch (const std::exception& e) {\n    std::cerr << \"Test failed with exception: \" << e.what() << std::endl;\n    return 1;\n  }\n\n  return 0;\n}"
  },
  {
    "path": "kt-kernel/operators/amx/test/amx-c-reduce-test.cpp",
    "content": "#include <omp.h>\n\n#include \"../la/amx.hpp\"\n#define FMT_HEADER_ONLY\n#include <fmt/core.h>\n\n#include <algorithm>\n#include <chrono>\n#include <cmath>\n#include <cstdlib>\n#include <cstring>\n#include <iostream>\n#include <memory>\n#include <random>\n#include <utility>\n#include <vector>\n\n// Test kernel configuration\nstruct TestKernelC {\n  static constexpr int M_STEP = 32;\n  static constexpr int K_STEP = 64;\n  static constexpr int K_BLOCK = 512;\n  static constexpr int N_STEP = 32;\n  static constexpr int N_BLOCK = 512;\n  static constexpr int TILE_N = 16;\n  using dt = int8_t;\n\n  static std::pair<int, int> split_range_n(int n, int ith, int nth) {\n    int n_per_thread = (n + nth - 1) / nth;\n    int n_start = ith * n_per_thread;\n    int n_end = std::min(n_start + n_per_thread, n);\n    return {n_start, n_end};\n  }\n};\n\nvoid test_buffer_c_reduce_basic() {\n  std::cout << \"=== Testing BufferCReduceImpl Basic Functionality ===\" << std::endl;\n\n  // Test parameters\n  const int max_m = 64;  // Must be multiple of M_STEP\n  const int n = 512;     // Must be multiple of N_STEP\n\n  std::cout << fmt::format(\"Parameters: max_m={}, n={}\\n\", max_m, n);\n\n  // Calculate and allocate buffer for BufferCReduceImpl\n  size_t buffer_size = amx::BufferCReduceImpl<TestKernelC>::required_size(max_m, n);\n  void* buffer = std::aligned_alloc(64, buffer_size);\n  std::memset(buffer, 0, buffer_size);\n\n  std::cout << fmt::format(\"Buffer size: {} bytes\\n\", buffer_size);\n  std::cout << fmt::format(\"  Float buffer: {} bytes\\n\", sizeof(float) * max_m * n);\n  std::cout << fmt::format(\"  Int32 buffer: {} bytes\\n\", sizeof(int32_t) * max_m * n);\n\n  // Create BufferCReduceImpl instance\n  auto buf = std::make_unique<amx::BufferCReduceImpl<TestKernelC>>(max_m, n, buffer);\n\n  // Test 1: Verify buffer pointers are set correctly\n  std::cout << \"\\nTest 1: Buffer pointer verification\" << std::endl;\n  if (buf->c == nullptr) {\n    std::cerr << \"ERROR: Float buffer pointer is null\" << std::endl;\n    free(buffer);\n    return;\n  }\n  if (buf->int_c == 
nullptr) {\n    std::cerr << \"ERROR: Int32 buffer pointer is null\" << std::endl;\n    free(buffer);\n    return;\n  }\n\n  // Verify int_c starts after c\n  size_t expected_offset = max_m * n;\n  size_t actual_offset = buf->int_c - reinterpret_cast<int32_t*>(buf->c);\n  if (actual_offset != expected_offset) {\n    std::cerr << fmt::format(\"ERROR: int_c offset incorrect. Expected: {}, Got: {}\\n\", expected_offset, actual_offset)\n              << std::endl;\n    free(buffer);\n    return;\n  }\n  std::cout << \"✓ Buffer pointers are correctly set\" << std::endl;\n\n  // Test 2: Write to float buffer and verify\n  std::cout << \"\\nTest 2: Float buffer write/read\" << std::endl;\n  std::mt19937 gen(42);\n  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);\n\n  // Fill float buffer with test data\n  for (int i = 0; i < max_m * n; i++) {\n    buf->c[i] = dist(gen);\n  }\n\n  // Verify get_submat works\n  for (int m_begin = 0; m_begin < max_m; m_begin += TestKernelC::M_STEP) {\n    for (int n_begin = 0; n_begin < n; n_begin += TestKernelC::N_STEP) {\n      float* submat = buf->get_submat(max_m, n, m_begin, n_begin);\n      if (submat == nullptr) {\n        std::cerr << fmt::format(\"ERROR: get_submat returned null for m_begin={}, n_begin={}\\n\", m_begin, n_begin)\n                  << std::endl;\n        free(buffer);\n        return;\n      }\n    }\n  }\n  std::cout << \"✓ Float buffer read/write works correctly\" << std::endl;\n\n  // Test 3: Write to int32 buffer and verify\n  std::cout << \"\\nTest 3: Int32 buffer write/read\" << std::endl;\n  std::uniform_int_distribution<int32_t> int_dist(-1000, 1000);\n\n  // Fill int32 buffer with test data\n  for (int i = 0; i < max_m * n; i++) {\n    buf->int_c[i] = int_dist(gen);\n  }\n\n  // Verify get_int_submat works\n  for (int m_begin = 0; m_begin < max_m; m_begin += TestKernelC::M_STEP) {\n    for (int n_begin = 0; n_begin < n; n_begin += TestKernelC::N_STEP) {\n      int32_t* submat = 
buf->get_int_submat(max_m, n, m_begin, n_begin);\n      if (submat == nullptr) {\n        std::cerr << fmt::format(\"ERROR: get_int_submat returned null for m_begin={}, n_begin={}\\n\", m_begin, n_begin)\n                  << std::endl;\n        free(buffer);\n        return;\n      }\n    }\n  }\n  std::cout << \"✓ Int32 buffer read/write works correctly\" << std::endl;\n\n  // Test 4: Clear int buffer\n  std::cout << \"\\nTest 4: Clear int buffer\" << std::endl;\n  buf->clear_int_buffer();\n  bool all_zero = true;\n  for (int i = 0; i < max_m * n; i++) {\n    if (buf->int_c[i] != 0) {\n      all_zero = false;\n      break;\n    }\n  }\n  if (!all_zero) {\n    std::cerr << \"ERROR: clear_int_buffer failed to zero the buffer\" << std::endl;\n    free(buffer);\n    return;\n  }\n  std::cout << \"✓ clear_int_buffer works correctly\" << std::endl;\n\n  // Test 5: Convert int to float\n  std::cout << \"\\nTest 5: Convert int32 to float\" << std::endl;\n  // Set some test values in int buffer\n  for (int i = 0; i < max_m * n; i++) {\n    buf->int_c[i] = i % 100 - 50;  // Values from -50 to 49\n  }\n\n  // Convert\n  buf->convert_int_to_float(max_m);\n\n  // Verify conversion\n  bool conversion_correct = true;\n  for (int i = 0; i < max_m * n; i++) {\n    float expected = static_cast<float>(i % 100 - 50);\n    if (std::abs(buf->c[i] - expected) > 1e-6) {\n      std::cerr << fmt::format(\"ERROR: Conversion mismatch at index {}. 
Expected: {}, Got: {}\\n\", i, expected,\n                               buf->c[i])\n                << std::endl;\n      conversion_correct = false;\n      break;\n    }\n  }\n  if (!conversion_correct) {\n    free(buffer);\n    return;\n  }\n  std::cout << \"✓ convert_int_to_float works correctly\" << std::endl;\n\n  // Test 6: to_mat functionality\n  std::cout << \"\\nTest 6: to_mat conversion\" << std::endl;\n  // Fill buffer using proper blocked layout via get_submat\n  for (int m_idx = 0; m_idx < max_m; m_idx += TestKernelC::M_STEP) {\n    for (int n_idx = 0; n_idx < n; n_idx += TestKernelC::N_STEP) {\n      float* submat = buf->get_submat(max_m, n, m_idx, n_idx);\n      // Fill this submat block\n      for (int i = 0; i < TestKernelC::M_STEP && m_idx + i < max_m; i++) {\n        for (int j = 0; j < TestKernelC::N_STEP && n_idx + j < n; j++) {\n          submat[i * TestKernelC::N_STEP + j] = (m_idx + i) * 0.1f + (n_idx + j) * 0.01f;\n        }\n      }\n    }\n  }\n\n  // Convert to bf16\n  std::vector<ggml_bf16_t> output(max_m * n);\n  buf->to_mat(max_m, output.data(), 0, 1);\n\n  // Verify some values\n  bool to_mat_correct = true;\n  for (int i = 0; i < std::min(10, max_m); i++) {\n    for (int j = 0; j < std::min(10, n); j++) {\n      float original = i * 0.1f + j * 0.01f;\n      float converted = ggml_compute_bf16_to_fp32(output[i * n + j]);\n      // BF16 has limited precision, allow for some error\n      if (std::abs(original - converted) > 0.02f) {  // Increased tolerance for BF16\n        std::cerr << fmt::format(\"ERROR: to_mat mismatch at ({},{}). 
Original: {}, Converted: {}\\n\", i, j, original,\n                                 converted)\n                  << std::endl;\n        to_mat_correct = false;\n        break;\n      }\n    }\n    if (!to_mat_correct) break;\n  }\n\n  if (!to_mat_correct) {\n    free(buffer);\n    return;\n  }\n  std::cout << \"✓ to_mat works correctly\" << std::endl;\n\n  // Clean up\n  free(buffer);\n  std::cout << \"\\n✓ All basic tests passed!\" << std::endl;\n}\n\nvoid test_buffer_c_reduce_comparison() {\n  std::cout << \"\\n=== Comparing BufferCImpl vs BufferCReduceImpl ===\" << std::endl;\n\n  const int max_m = 128;\n  const int n = 1024;\n\n  // Test original BufferCImpl\n  {\n    size_t buffer_size = amx::BufferCImpl<TestKernelC>::required_size(max_m, n);\n    void* buffer = std::aligned_alloc(64, buffer_size);\n    auto buf_c = std::make_unique<amx::BufferCImpl<TestKernelC>>(max_m, n, buffer);\n\n    std::cout << fmt::format(\"BufferCImpl size: {} bytes\\n\", buffer_size);\n\n    // Fill with test data\n    for (int i = 0; i < max_m * n; i++) {\n      buf_c->c[i] = static_cast<float>(i % 1000) / 100.0f;\n    }\n\n    // Test to_mat\n    std::vector<ggml_bf16_t> output(max_m * n);\n    buf_c->to_mat(max_m, output.data(), 0, 1);\n\n    std::cout << \"  Sample values from BufferCImpl:\" << std::endl;\n    for (int i = 0; i < 3; i++) {\n      std::cout << fmt::format(\"    c[{}] = {:.4f}\\n\", i, buf_c->c[i]);\n    }\n\n    free(buffer);\n  }\n\n  // Test BufferCReduceImpl\n  {\n    size_t buffer_size = amx::BufferCReduceImpl<TestKernelC>::required_size(max_m, n);\n    void* buffer = std::aligned_alloc(64, buffer_size);\n    auto buf_cr = std::make_unique<amx::BufferCReduceImpl<TestKernelC>>(max_m, n, buffer);\n\n    std::cout << fmt::format(\"\\nBufferCReduceImpl size: {} bytes ({}x larger)\\n\", buffer_size,\n                             buffer_size / (sizeof(float) * max_m * n));\n\n    // Fill float buffer\n    for (int i = 0; i < max_m * n; i++) {\n      buf_cr->c[i] = 
static_cast<float>(i % 1000) / 100.0f;\n    }\n\n    // Fill int buffer\n    for (int i = 0; i < max_m * n; i++) {\n      buf_cr->int_c[i] = i % 1000;\n    }\n\n    // Test to_mat\n    std::vector<ggml_bf16_t> output(max_m * n);\n    buf_cr->to_mat(max_m, output.data(), 0, 1);\n\n    std::cout << \"  Sample values from BufferCReduceImpl:\" << std::endl;\n    for (int i = 0; i < 3; i++) {\n      std::cout << fmt::format(\"    c[{}] = {:.4f}, int_c[{}] = {}\\n\", i, buf_cr->c[i], i, buf_cr->int_c[i]);\n    }\n\n    free(buffer);\n  }\n\n  std::cout << \"\\n✓ Comparison test completed\" << std::endl;\n}\n\nvoid test_buffer_c_reduce_performance() {\n  std::cout << \"\\n=== Testing BufferCReduceImpl Performance Characteristics ===\" << std::endl;\n\n  const int max_m = 256;\n  const int n = 2048;\n  const int iterations = 1000;\n\n  size_t buffer_size = amx::BufferCReduceImpl<TestKernelC>::required_size(max_m, n);\n  void* buffer = std::aligned_alloc(64, buffer_size);\n  auto buf = std::make_unique<amx::BufferCReduceImpl<TestKernelC>>(max_m, n, buffer);\n\n  std::cout << fmt::format(\"Testing with max_m={}, n={}\\n\", max_m, n);\n  std::cout << fmt::format(\"Total elements: {}\\n\", max_m * n);\n  std::cout << fmt::format(\"Buffer size: {:.2f} MB\\n\", buffer_size / (1024.0 * 1024.0));\n\n  // Test clear_int_buffer performance\n  std::cout << \"\\nTesting clear_int_buffer...\" << std::endl;\n  auto start = std::chrono::high_resolution_clock::now();\n  for (int i = 0; i < iterations; i++) {\n    buf->clear_int_buffer();\n  }\n  auto end = std::chrono::high_resolution_clock::now();\n  auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();\n  std::cout << fmt::format(\"  Average time: {:.3f} us\\n\", duration / (double)iterations);\n\n  // Test convert_int_to_float performance\n  std::cout << \"\\nTesting convert_int_to_float...\" << std::endl;\n  // Fill int buffer with test data\n  for (int i = 0; i < max_m * n; i++) {\n    
buf->int_c[i] = i;\n  }\n\n  start = std::chrono::high_resolution_clock::now();\n  for (int i = 0; i < iterations; i++) {\n    buf->convert_int_to_float(max_m);\n  }\n  end = std::chrono::high_resolution_clock::now();\n  duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();\n  std::cout << fmt::format(\"  Average time: {:.3f} us\\n\", duration / (double)iterations);\n\n  free(buffer);\n  std::cout << \"\\n✓ Performance tests completed\" << std::endl;\n}\n\nint main(int argc, char** argv) {\n  std::cout << \"Starting BufferCReduceImpl Tests\\n\" << std::endl;\n\n  try {\n    // Run basic functionality tests\n    test_buffer_c_reduce_basic();\n\n    // Run comparison tests\n    test_buffer_c_reduce_comparison();\n\n    // Run performance tests\n    test_buffer_c_reduce_performance();\n\n    std::cout << \"\\n=== All tests completed successfully! ===\" << std::endl;\n  } catch (const std::exception& e) {\n    std::cerr << \"Test failed with exception: \" << e.what() << std::endl;\n    return 1;\n  }\n\n  return 0;\n}"
  },
  {
    "path": "kt-kernel/operators/amx/test/amx-kgroup-test.cpp",
    "content": "#include <omp.h>\n\n#include \"../la/amx.hpp\"\n#define FMT_HEADER_ONLY\n#include <fmt/core.h>\n\n#include <cmath>\n#include <iostream>\n#include <memory>\n\n// Test kernel configuration for k-group testing\nstruct TestKernelKGroup {\n  static constexpr int M_STEP = 32;\n  static constexpr int K_STEP = 64;\n  static constexpr int K_BLOCK = 512;\n  static constexpr int N_STEP = 32;\n  static constexpr int N_BLOCK = 512;\n  static constexpr int TILE_N = 16;\n  using dt = int8_t;\n\n  static std::pair<int, int> split_range_n(int n, int ith, int nth) {\n    int n_per_thread = (n + nth - 1) / nth;\n    int n_start = ith * n_per_thread;\n    int n_end = std::min(n_start + n_per_thread, n);\n    return {n_start, n_end};\n  }\n};\n\nvoid test_buffer_kgroup_basic() {\n  std::cout << \"=== Testing BufferAKGroupImpl Basic Functionality ===\" << std::endl;\n\n  // Test parameters\n  const int max_m = 64;          // Must be multiple of M_STEP\n  const int k = 2048;            // Must be multiple of K_STEP and K_BLOCK\n  const int k_group_size = 128;  // Must divide K_BLOCK evenly\n\n  std::cout << fmt::format(\"Parameters: max_m={}, k={}, k_group_size={}\\n\", max_m, k, k_group_size);\n\n  // Calculate and allocate buffer\n  size_t buffer_size = amx::BufferAKGroupImpl<TestKernelKGroup>::required_size(max_m, k, k_group_size);\n  void* buffer = std::aligned_alloc(64, buffer_size);\n  std::memset(buffer, 0, buffer_size);\n\n  std::cout << fmt::format(\"Buffer size: {} bytes\\n\", buffer_size);\n\n  // Create BufferAKGroupImpl instance\n  auto buf = std::make_unique<amx::BufferAKGroupImpl<TestKernelKGroup>>(max_m, k, k_group_size, buffer);\n\n  // Create test input data (bf16)\n  std::vector<ggml_bf16_t> input(max_m * k);\n  std::mt19937 gen(42);\n  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);\n\n  for (int i = 0; i < max_m * k; i++) {\n    float val = dist(gen);\n    input[i] = ggml_compute_fp32_to_bf16(val);\n  }\n\n  // Test from_mat\n  std::cout << 
\"Testing from_mat...\" << std::endl;\n  buf->from_mat(max_m, input.data(), 0, 1);\n  std::cout << \"✓ from_mat completed successfully\" << std::endl;\n\n  // Test get_submat\n  std::cout << \"Testing get_submat...\" << std::endl;\n  for (int m_begin = 0; m_begin < max_m; m_begin += TestKernelKGroup::M_STEP) {\n    for (int k_begin = 0; k_begin < k; k_begin += TestKernelKGroup::K_STEP) {\n      int8_t* submat = buf->get_submat(max_m, k, m_begin, k_begin);\n      if (submat == nullptr) {\n        std::cerr << fmt::format(\"ERROR: get_submat returned null for m_begin={}, k_begin={}\\n\", m_begin, k_begin);\n        free(buffer);\n        return;\n      }\n    }\n  }\n  std::cout << \"✓ get_submat tested for all valid positions\" << std::endl;\n\n  // Test get_scale\n  std::cout << \"Testing get_scale...\" << std::endl;\n  int k_group_count = k / k_group_size;\n  for (int m_idx = 0; m_idx < max_m; m_idx++) {\n    for (int kg_idx = 0; kg_idx < k_group_count; kg_idx++) {\n      float* scale = buf->get_scale(max_m, m_idx, k, kg_idx * k_group_size);\n      if (scale == nullptr) {\n        std::cerr << fmt::format(\"ERROR: get_scale returned null for m_idx={}, k_group={}\\n\", m_idx, kg_idx);\n        free(buffer);\n        return;\n      }\n      // Verify scale is non-zero (should be set by from_mat)\n      if (*scale == 0.0f) {\n        std::cerr << fmt::format(\"WARNING: scale is zero for m_idx={}, k_group={}\\n\", m_idx, kg_idx);\n      }\n    }\n  }\n  std::cout << \"✓ get_scale tested for all k-groups\" << std::endl;\n\n  // Print some scale values for verification\n  std::cout << \"\\nSample scale values:\" << std::endl;\n  for (int kg = 0; kg < std::min(4, k_group_count); kg++) {\n    float* scale = buf->get_scale(max_m, 0, k, kg * k_group_size);\n    std::cout << fmt::format(\"  k_group[{}] (k={}): scale = {:.6f}\\n\", kg, kg * k_group_size, *scale);\n  }\n\n  // Clean up\n  free(buffer);\n  std::cout << \"\\n✓ All basic tests passed!\" << std::endl;\n}\n\nvoid 
test_buffer_kgroup_correctness() {\n  std::cout << \"\\n=== Testing BufferAKGroupImpl Quantization Correctness ===\" << std::endl;\n\n  const int max_m = 32;\n  const int k = 512;\n  const int k_group_size = 128;\n\n  size_t buffer_size = amx::BufferAKGroupImpl<TestKernelKGroup>::required_size(max_m, k, k_group_size);\n  void* buffer = std::aligned_alloc(64, buffer_size);\n\n  auto buf = std::make_unique<amx::BufferAKGroupImpl<TestKernelKGroup>>(max_m, k, k_group_size, buffer);\n\n  // Create test input matrix with known patterns\n  std::vector<float> original(max_m * k);\n  std::vector<ggml_bf16_t> input(max_m * k);\n\n  // Fill with different patterns for each k-group to test group-wise quantization\n  for (int m = 0; m < max_m; m++) {\n    for (int k_idx = 0; k_idx < k; k_idx++) {\n      int kg = k_idx / k_group_size;\n      // Different magnitude for each k-group\n      float base_val = (kg + 1) * 0.1f;\n      float val = base_val * std::sin(m * 0.1f + k_idx * 0.01f);\n      original[m * k + k_idx] = val;\n      input[m * k + k_idx] = ggml_compute_fp32_to_bf16(val);\n    }\n  }\n\n  // Quantize\n  buf->from_mat(max_m, input.data(), 0, 1);\n\n  // Dequantize and check error\n  std::vector<float> dequantized(max_m * k);\n  float max_error = 0.0f;\n  float total_error = 0.0f;\n  int num_elements = 0;\n\n  for (int m = 0; m < max_m; m++) {\n    for (int k_idx = 0; k_idx < k; k_idx++) {\n      int kg = k_idx / k_group_size;\n\n      // Get the scale for this k-group\n      float* scale_ptr = buf->get_scale(max_m, m, k, kg * k_group_size);\n      float scale = *scale_ptr;\n\n      // Get quantized value (simplified access for testing)\n      // In real use, this would go through get_submat\n      int m_block_size = (max_m + TestKernelKGroup::M_STEP - 1) / TestKernelKGroup::M_STEP * TestKernelKGroup::M_STEP;\n      int k_block_begin = (k_idx / TestKernelKGroup::K_BLOCK) * TestKernelKGroup::K_BLOCK;\n      int k_in_block = k_idx - k_block_begin;\n      int k_block_size 
= std::min(TestKernelKGroup::K_BLOCK, k - k_block_begin);\n\n      // Locate the quantized data\n      int m_step_idx = m / TestKernelKGroup::M_STEP;\n      int m_in_step = m % TestKernelKGroup::M_STEP;\n      int k_step_idx = k_in_block / TestKernelKGroup::K_STEP;\n      int k_in_step = k_in_block % TestKernelKGroup::K_STEP;\n\n      int8_t* base = buf->a + k_block_begin * m_block_size + m_step_idx * TestKernelKGroup::M_STEP * k_block_size +\n                     k_step_idx * TestKernelKGroup::K_STEP * TestKernelKGroup::M_STEP +\n                     m_in_step * TestKernelKGroup::K_STEP + k_in_step;\n\n      int8_t quantized_val = *base;\n\n      // Dequantize\n      float deq = quantized_val * scale;\n      dequantized[m * k + k_idx] = deq;\n\n      // Calculate error\n      float error = std::abs(original[m * k + k_idx] - deq);\n      max_error = std::max(max_error, error);\n      total_error += error;\n      num_elements++;\n    }\n  }\n\n  float avg_error = total_error / num_elements;\n  float avg_magnitude = 0.0f;\n  for (int i = 0; i < max_m * k; i++) {\n    avg_magnitude += std::abs(original[i]);\n  }\n  avg_magnitude /= (max_m * k);\n\n  float relative_error = avg_error / (avg_magnitude + 1e-8f);\n\n  std::cout << fmt::format(\"Quantization Error Analysis:\\n\");\n  std::cout << fmt::format(\"  Max absolute error: {:.6f}\\n\", max_error);\n  std::cout << fmt::format(\"  Average absolute error: {:.6f}\\n\", avg_error);\n  std::cout << fmt::format(\"  Average magnitude: {:.6f}\\n\", avg_magnitude);\n  std::cout << fmt::format(\"  Relative error: {:.2f}%\\n\", relative_error * 100);\n\n  // Check that relative error is reasonable (typically < 5% for int8 quantization)\n  if (relative_error < 0.05f) {\n    std::cout << \"✓ Quantization error is within acceptable range\" << std::endl;\n  } else {\n    std::cerr << \"WARNING: Quantization error is higher than expected!\" << std::endl;\n  }\n\n  // Test that different k-groups have different scales\n  std::cout 
<< \"\\nVerifying k-group scales are computed independently:\" << std::endl;\n  bool scales_differ = false;\n  for (int m = 0; m < std::min(4, max_m); m++) {\n    float* scale0 = buf->get_scale(max_m, m, k, 0);\n    for (int kg = 1; kg < k / k_group_size; kg++) {\n      float* scale_kg = buf->get_scale(max_m, m, k, kg * k_group_size);\n      if (std::abs(*scale0 - *scale_kg) > 1e-6f) {\n        scales_differ = true;\n        break;\n      }\n    }\n    if (scales_differ) break;\n  }\n\n  if (scales_differ) {\n    std::cout << \"✓ Different k-groups have independent scales\" << std::endl;\n  } else {\n    std::cout << \"✗ Warning: All k-groups have the same scale (might be correct for uniform data)\" << std::endl;\n  }\n\n  free(buffer);\n}\n\nvoid test_buffer_kgroup_comparison() {\n  std::cout << \"\\n=== Comparing BufferAImpl vs BufferAKGroupImpl ===\" << std::endl;\n\n  const int max_m = 128;\n  const int k = 2048;\n  const int k_group_size = 256;\n\n  // Create test data\n  std::vector<ggml_bf16_t> input(max_m * k);\n  std::mt19937 gen(456);\n  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);\n  for (int i = 0; i < max_m * k; i++) {\n    input[i] = ggml_compute_fp32_to_bf16(dist(gen));\n  }\n\n  // Test original BufferAImpl\n  {\n    size_t buffer_size = amx::BufferAImpl<TestKernelKGroup>::required_size(max_m, k);\n    void* buffer = std::aligned_alloc(64, buffer_size);\n    auto buf_a = std::make_unique<amx::BufferAImpl<TestKernelKGroup>>(max_m, k, buffer);\n\n    buf_a->from_mat(max_m, input.data(), 0, 1);\n\n    // Print some scales\n    std::cout << \"BufferAImpl scales (per-row):\" << std::endl;\n    for (int m = 0; m < std::min(4, max_m); m++) {\n      float* scale = buf_a->get_scale(max_m, m);\n      std::cout << fmt::format(\"  row[{}]: scale = {:.6f}\\n\", m, *scale);\n    }\n\n    free(buffer);\n  }\n\n  // Test BufferAKGroupImpl\n  {\n    size_t buffer_size = amx::BufferAKGroupImpl<TestKernelKGroup>::required_size(max_m, k, k_group_size);\n    
void* buffer = std::aligned_alloc(64, buffer_size);\n    auto buf_kg = std::make_unique<amx::BufferAKGroupImpl<TestKernelKGroup>>(max_m, k, k_group_size, buffer);\n\n    buf_kg->from_mat(max_m, input.data(), 0, 1);\n\n    // Print some scales\n    std::cout << \"\\nBufferAKGroupImpl scales (per k-group):\" << std::endl;\n    for (int m = 0; m < std::min(2, max_m); m++) {\n      std::cout << fmt::format(\"  row[{}]:\\n\", m);\n      for (int kg = 0; kg < std::min(4, k / k_group_size); kg++) {\n        float* scale = buf_kg->get_scale(max_m, m, k, kg * k_group_size);\n        std::cout << fmt::format(\"    k_group[{}]: scale = {:.6f}\\n\", kg, *scale);\n      }\n    }\n\n    free(buffer);\n  }\n\n  std::cout << \"\\n✓ Comparison test completed\" << std::endl;\n}\n\nint main() {\n  std::cout << \"Starting BufferAKGroupImpl Tests\\n\" << std::endl;\n\n  try {\n    // Run basic functionality tests\n    test_buffer_kgroup_basic();\n\n    // Run correctness tests\n    test_buffer_kgroup_correctness();\n\n    // Run comparison tests\n    test_buffer_kgroup_comparison();\n\n    std::cout << \"\\n=== All tests completed successfully! ===\" << std::endl;\n  } catch (const std::exception& e) {\n    std::cerr << \"Test failed with exception: \" << e.what() << std::endl;\n    return 1;\n  }\n\n  return 0;\n}"
  },
  {
    "path": "kt-kernel/operators/amx/test/amx-test.cpp",
    "content": "#include \"../la/amx.hpp\"\n\n#include <omp.h>\n\n#include \"mat-test.hpp\"\n#define FMT_HEADER_ONLY\n#include <fmt/core.h>\n\nconst int test_iter = 100;\nconst bool mt = true;\nconst bool cache_hit = false;\n\nvoid q_latency_test_bf16(int m, int n, int k, ggml_bf16_t* qa, ggml_bf16_t* qb) {\n  int nth = amx::GemmKernel224BF::recommended_nth(n);\n  int m_ = (m + 31) / 32 * 32;\n  Mat<float> d(m_, n, Layout::RowMajor);\n  {\n    int repeat = 100;\n    std::vector<ggml_bf16_t*> vec_a;\n    std::vector<ggml_bf16_t*> vec_b;\n    std::vector<float*> vec_c;\n    std::vector<std::shared_ptr<amx::GemmKernel224BF::BufferA>> vec_ba;\n    std::vector<std::shared_ptr<amx::GemmKernel224BF::BufferB>> vec_bb;\n    std::vector<std::shared_ptr<amx::GemmKernel224BF::BufferC>> vec_bc;\n    for (int i = 0; i < repeat * 2; i++) {\n      ggml_bf16_t* a = (ggml_bf16_t*)std::aligned_alloc(64, amx::GemmKernel224BF::BufferA::required_size(m_, k));\n      std::shared_ptr<amx::GemmKernel224BF::BufferA> ba = std::make_shared<amx::GemmKernel224BF::BufferA>(m_, k, a);\n      ggml_bf16_t* b = (ggml_bf16_t*)std::aligned_alloc(64, amx::GemmKernel224BF::BufferB::required_size(n, k));\n      std::shared_ptr<amx::GemmKernel224BF::BufferB> bb = std::make_shared<amx::GemmKernel224BF::BufferB>(n, k, b);\n      float* c = (float*)std::aligned_alloc(64, amx::GemmKernel224BF::BufferC::required_size(m_, n));\n      std::shared_ptr<amx::GemmKernel224BF::BufferC> bc = std::make_shared<amx::GemmKernel224BF::BufferC>(m_, n, c);\n      ba->from_mat(m, qa, 0, 1);\n      int nth = amx::GemmKernel224BF::recommended_nth(n);\n      for (int i = 0; i < nth; i++) {\n        bb->from_mat(qb, i, nth);\n      }\n      vec_a.push_back(a);\n      vec_b.push_back(b);\n      vec_c.push_back(c);\n      vec_ba.push_back(ba);\n      vec_bb.push_back(bb);\n      vec_bc.push_back(bc);\n    }\n    Timer t(fmt::format(\"m:{} n:{} k:{} t:{} repeat:{}, latency\", m, n, k, test_iter, repeat));\n    for (int t = 0; t < 
test_iter; t++) {\n#pragma omp parallel for schedule(dynamic, 1)\n      for (int ti = 0; ti < nth * repeat; ti++) {\n        int mat_id = ti / nth + repeat * (t % 2);\n        int ith = ti % nth;\n        if (cache_hit) {\n          mat_id = 0;\n        }\n        amx::mat_mul(m, n, k, vec_ba[mat_id], vec_bb[mat_id], vec_bc[mat_id], ith, nth);\n      }\n    }\n    for (int i = 0; i < repeat * 2; i++) {\n      free(vec_a[i]);\n      free(vec_b[i]);\n      free(vec_c[i]);\n    }\n  }\n  d.dealloc();\n}\n\nvoid group_q_latency_test_bf16(int n_max, int k_max) {\n  amx::GemmKernel224BF::config();\n\n  int m_max = 1024;\n  int m_start = 32;\n  int m_step = 32;\n\n  Mat<float> a(m_max, k_max, Layout::RowMajor), b(k_max, n_max, Layout::ColumnMajor);\n  std::mt19937 gen(123);\n  a.random(gen);\n  b.random(gen);\n  a.quant(GGML_TYPE_BF16);\n  b.quant(GGML_TYPE_BF16);\n\n  std::string method_name = \"BF16\";\n  if (mt) {\n    method_name += fmt::format(\"_mt{}\", omp_get_max_threads());\n  }\n  if (cache_hit) {\n    method_name += \"-cache-hit\";\n  }\n\n  auto output = fmt::format(\"{}-m:{}:{}:{}-n:{}-k:{}-x{}x{}.txt\", method_name, m_start, m_max, m_step, n_max, k_max,\n                            amx::GemmKernel224BF::N_BLOCK, amx::GemmKernel224BF::K_BLOCK);\n  // std::cout << \"Output to: \" << output << std::endl;\n  auto x = freopen(output.c_str(), \"w\", stdout);\n  assert(x);\n\n  for (int m = m_start; m <= m_max; m *= 2) {\n    q_latency_test_bf16(m, n_max, k_max, a.quant_data<ggml_bf16_t>(), b.quant_data<ggml_bf16_t>());\n  }\n}\n\nvoid q_latency_test_int8(int m, int n, int k, ggml_bf16_t* qa, ggml_bf16_t* qb) {\n  int nth = amx::GemmKernel224Int8::recommended_nth(n);\n  int m_ = (m + 31) / 32 * 32;\n  Mat<float> d(m_, n, Layout::RowMajor);\n  {\n    int repeat = 100;\n    std::vector<int8_t*> vec_a;\n    std::vector<int8_t*> vec_b;\n    std::vector<float*> vec_c;\n    std::vector<std::shared_ptr<amx::GemmKernel224Int8::BufferA>> vec_ba;\n    
std::vector<std::shared_ptr<amx::GemmKernel224Int8::BufferB>> vec_bb;\n    std::vector<std::shared_ptr<amx::GemmKernel224Int8::BufferC>> vec_bc;\n    for (int i = 0; i < repeat * 2; i++) {\n      int8_t* a = (int8_t*)std::aligned_alloc(64, amx::GemmKernel224Int8::BufferA::required_size(m_, k));\n      std::shared_ptr<amx::GemmKernel224Int8::BufferA> ba = std::make_shared<amx::GemmKernel224Int8::BufferA>(m_, k, a);\n      int8_t* b = (int8_t*)std::aligned_alloc(64, amx::GemmKernel224Int8::BufferB::required_size(n, k));\n      std::shared_ptr<amx::GemmKernel224Int8::BufferB> bb = std::make_shared<amx::GemmKernel224Int8::BufferB>(n, k, b);\n      float* c = (float*)std::aligned_alloc(64, amx::GemmKernel224Int8::BufferC::required_size(m_, n));\n      std::shared_ptr<amx::GemmKernel224Int8::BufferC> bc = std::make_shared<amx::GemmKernel224Int8::BufferC>(m_, n, c);\n      ba->from_mat(m, qa, 0, 1);\n      int nth = amx::GemmKernel224Int8::recommended_nth(n);\n      for (int i = 0; i < nth; i++) {\n        bb->from_mat(qb, i, nth);\n      }\n      vec_a.push_back(a);\n      vec_b.push_back(b);\n      vec_c.push_back(c);\n      vec_ba.push_back(ba);\n      vec_bb.push_back(bb);\n      vec_bc.push_back(bc);\n    }\n    Timer t(fmt::format(\"m:{} n:{} k:{} t:{} repeat:{}, latency\", m, n, k, test_iter, repeat));\n    for (int t = 0; t < test_iter; t++) {\n#pragma omp parallel for schedule(dynamic, 1)\n      for (int ti = 0; ti < nth * repeat; ti++) {\n        int mat_id = ti / nth + repeat * (t % 2);\n        int ith = ti % nth;\n        if (cache_hit) {\n          mat_id = 0;\n        }\n        amx::mat_mul(m, n, k, vec_ba[mat_id], vec_bb[mat_id], vec_bc[mat_id], ith, nth);\n      }\n    }\n    for (int i = 0; i < repeat * 2; i++) {\n      free(vec_a[i]);\n      free(vec_b[i]);\n      free(vec_c[i]);\n    }\n  }\n  d.dealloc();\n}\n\nvoid group_q_latency_test_int8(int n_max, int k_max) {\n  amx::GemmKernel224Int8::config();\n\n  int m_max = 1024;\n  int m_start = 32;\n  
int m_step = 32;\n\n  Mat<float> a(m_max, k_max, Layout::RowMajor), b(k_max, n_max, Layout::ColumnMajor);\n  std::mt19937 gen(123);\n  a.random(gen);\n  b.random(gen);\n  a.quant(GGML_TYPE_BF16);\n  b.quant(GGML_TYPE_BF16);\n\n  std::string method_name = \"INT8\";\n  if (mt) {\n    method_name += fmt::format(\"_mt{}\", omp_get_max_threads());\n  }\n  if (cache_hit) {\n    method_name += \"-cache-hit\";\n  }\n\n  auto output = fmt::format(\"{}-m:{}:{}:{}-n:{}-k:{}-x{}x{}.txt\", method_name, m_start, m_max, m_step, n_max, k_max,\n                            amx::GemmKernel224Int8::N_BLOCK, amx::GemmKernel224Int8::K_BLOCK);\n  // std::cout << \"Output to: \" << output << std::endl;\n  auto x = freopen(output.c_str(), \"w\", stdout);\n  assert(x);\n  for (int m = m_start; m <= m_max; m *= 2) {\n    q_latency_test_int8(m, n_max, k_max, a.quant_data<ggml_bf16_t>(), b.quant_data<ggml_bf16_t>());\n  }\n}\n\nvoid correction_test_int4(int m, int n, int k) {\n  amx::GemmKernel224Int4::config();\n\n  int m_max = 1024;\n  int m_start = 32;\n  int m_step = 32;\n\n  Mat<float> ma(m, k, Layout::RowMajor), mb(k, n, Layout::ColumnMajor);\n  // std::mt19937 gen(123);\n\n  // for(size_t i=0;i<m;i++){\n  //   for(size_t j=0;j<k;j++){\n  //     // ma.at(i,j) = std::max(int(-i+j),0);\n  //     ma.at(i,j) = (i+j)%25/25.0;\n  //   }\n  // }\n  // for (size_t i = 0; i < k; i++) {\n  //   for (size_t j = 0; j < n; j++) {\n  //     // mb.at(i,j) = std::max(int(-i+j),0);\n  //     mb.at(i,j) = (i+j)%25/25.0;\n  //   }\n  // }\n  std::mt19937 gena(123);\n  std::mt19937 genb(312);\n  ma.random(gena);\n  mb.random(genb);\n  // ma.random(gen);\n  // mb.random(gen);\n\n  auto mc = ma.mul_check(mb);\n  // ma.print();\n  // mb.print();\n\n  ma.quant(GGML_TYPE_BF16);\n  mb.quant(GGML_TYPE_BF16);\n\n  using K = amx::GemmKernel224Int4;\n  int8_t* a = (int8_t*)std::aligned_alloc(64, K::BufferA::required_size(m, k));\n  std::shared_ptr<K::BufferA> ba = std::make_shared<K::BufferA>(m, k, a);\n  int8_t* b = 
(int8_t*)std::aligned_alloc(64, K::BufferB::required_size(n, k));\n  std::shared_ptr<K::BufferB> bb = std::make_shared<K::BufferB>(n, k, b);\n  float* c = (float*)std::aligned_alloc(64, K::BufferC::required_size(m, n));\n  std::shared_ptr<K::BufferC> bc = std::make_shared<K::BufferC>(m, n, c);\n\n  ba->from_mat(m, ma.quant_data<ggml_bf16_t>(), 0, 1);\n  // printf(\"%d\\n\",amx::GemmKernel224Int4::BufferA::required_size(m, k));\n  // for(size_t i=0;i<amx::GemmKernel224Int4::BufferA::required_size(m, k);i++){\n  //   if((i*2)%k==0)\n  //     printf(\"\\n\");\n\n  //   printf(\"%02x \", (unsigned char)(a[i]));\n  // }\n  // printf(\"\\n\");\n\n  // int nth = amx::GemmKernel224Int4::recommended_nth(n);\n  bb->from_mat(mb.quant_data<ggml_bf16_t>(), 0, 1);\n\n  // for(size_t i=0;i<amx::GemmKernel224Int4::BufferB::required_size(n, k);i++){\n  //  if((i*2)%k==0)\n  //     printf(\"\\n\");\n\n  //  printf(\"%02x \", (unsigned char)(b[i]));\n  // }\n  // printf(\"\\n\");\n\n  amx::mat_mul(m, n, k, ba, bb, bc, 0, 1);\n\n  // for(size_t i=0;i<m;i++){\n  //   for(size_t j=0;j<n;j++){\n  //     printf(\"%.2f \",c[i*n+j]);\n  //   }\n  //   printf(\"\\n\");\n  // }\n\n  // printf(\"\\n\");\n  Mat<float> tc(m, n, Layout::RowMajor);\n  tc.data = c;\n  // std::cout<<\"AMX OUTPUT:\"<<std::endl;\n  // tc.print_all();\n  // std::cout<<\"STD OUTPUT:\"<<std::endl;\n  // mc.print_all();\n\n  mc.cmp(tc);\n\n  // for(size_t i=0;i<m/32;i++){\n  //   for(size_t j=0;j<n/32;j++){\n  //     Mat<float> stdre(32,32,Layout::RowMajor);\n  //     Mat<float> amxre(32,32,Layout::RowMajor);\n  //     for(size_t ii=i*32;ii<i*32+32;ii++){\n  //       for(size_t jj=j*32;jj<j*32+32;jj++){\n  //         stdre.at(ii-i*32,jj-j*32) = mc.at(ii,jj);\n  //         amxre.at(ii-i*32,jj-j*32) = tc.at(ii,jj);\n  //       }\n  //     }\n  //     printf(\"%d %d \",i,j);\n  //     stdre.cmp(amxre);\n  //     // if(i==0&&j==0){\n  //       std::cout<<\"STD\"<<std::endl;\n  //       stdre.print_all();\n  //       
std::cout<<\"AMX\"<<std::endl;\n  //       amxre.print_all();\n  //     // }\n  //   }\n  // }\n}\n\nvoid correction_test_int4_1(int m, int n, int k) {\n  using K = amx::GemmKernel224Int4_1;\n  K::config();\n\n  int m_max = 1024;\n  int m_start = 32;\n  int m_step = 32;\n\n  Mat<float> ma(m, k, Layout::RowMajor), mb(k, n, Layout::ColumnMajor);\n  // std::mt19937 gen(123);\n\n  // for(size_t i=0;i<m;i++){\n  //   for(size_t j=0;j<k;j++){\n  //     // ma.at(i,j) = std::max(int(-i+j),0);\n  //     ma.at(i,j) = (i+j)%25/25.0;\n  //   }\n  // }\n  // for (size_t i = 0; i < k; i++) {\n  //   for (size_t j = 0; j < n; j++) {\n  //     // mb.at(i,j) = std::max(int(-i+j),0);\n  //     mb.at(i,j) = (i+j)%25/25.0;\n  //   }\n  // }\n  std::mt19937 gena(123);\n  std::mt19937 genb(312);\n  ma.random(gena);\n  mb.random(genb);\n  // ma.random(gen);\n  // mb.random(gen);\n\n  auto mc = ma.mul_check(mb);\n  // ma.print();\n  // mb.print();\n\n  ma.quant(GGML_TYPE_BF16);\n  mb.quant(GGML_TYPE_BF16);\n\n  int8_t* a = (int8_t*)std::aligned_alloc(64, K::BufferA::required_size(m, k));\n  std::shared_ptr<K::BufferA> ba = std::make_shared<K::BufferA>(m, k, a);\n  int8_t* b = (int8_t*)std::aligned_alloc(64, K::BufferB::required_size(n, k));\n  std::shared_ptr<K::BufferB> bb = std::make_shared<K::BufferB>(n, k, b);\n  float* c = (float*)std::aligned_alloc(64, K::BufferC::required_size(m, n));\n  std::shared_ptr<K::BufferC> bc = std::make_shared<K::BufferC>(m, n, c);\n\n  ba->from_mat(m, ma.quant_data<ggml_bf16_t>(), 0, 1);\n  // printf(\"%d\\n\",amx::GemmKernel224Int4::BufferA::required_size(m, k));\n  // for(size_t i=0;i<amx::GemmKernel224Int4::BufferA::required_size(m, k);i++){\n  //   if((i*2)%k==0)\n  //     printf(\"\\n\");\n\n  //   printf(\"%02x \", (unsigned char)(a[i]));\n  // }\n  // printf(\"\\n\");\n\n  // int nth = amx::GemmKernel224Int4::recommended_nth(n);\n  bb->from_mat(mb.quant_data<ggml_bf16_t>(), 0, 1);\n\n  // for(size_t 
i=0;i<amx::GemmKernel224Int4::BufferB::required_size(n, k);i++){\n  //  if((i*2)%k==0)\n  //     printf(\"\\n\");\n\n  //  printf(\"%02x \", (unsigned char)(b[i]));\n  // }\n  // printf(\"\\n\");\n\n  amx::mat_mul(m, n, k, ba, bb, bc, 0, 1);\n\n  // for(size_t i=0;i<m;i++){\n  //   for(size_t j=0;j<n;j++){\n  //     printf(\"%.2f \",c[i*n+j]);\n  //   }\n  //   printf(\"\\n\");\n  // }\n\n  // printf(\"\\n\");\n  Mat<float> tc(m, n, Layout::RowMajor);\n  tc.data = c;\n  std::cout << \"AMX OUTPUT:\" << std::endl;\n  tc.print_all();\n  std::cout << \"STD OUTPUT:\" << std::endl;\n  mc.print_all();\n\n  mc.cmp(tc);\n\n  // for(size_t i=0;i<m/32;i++){\n  //   for(size_t j=0;j<n/32;j++){\n  //     Mat<float> stdre(32,32,Layout::RowMajor);\n  //     Mat<float> amxre(32,32,Layout::RowMajor);\n  //     for(size_t ii=i*32;ii<i*32+32;ii++){\n  //       for(size_t jj=j*32;jj<j*32+32;jj++){\n  //         stdre.at(ii-i*32,jj-j*32) = mc.at(ii,jj);\n  //         amxre.at(ii-i*32,jj-j*32) = tc.at(ii,jj);\n  //       }\n  //     }\n  //     printf(\"%d %d \",i,j);\n  //     stdre.cmp(amxre);\n  //     // if(i==0&&j==0){\n  //       std::cout<<\"STD\"<<std::endl;\n  //       stdre.print_all();\n  //       std::cout<<\"AMX\"<<std::endl;\n  //       amxre.print_all();\n  //     // }\n  //   }\n  // }\n}\n\nvoid q_latency_test_int4(int m, int n, int k, ggml_bf16_t* qa, ggml_bf16_t* qb) {\n  int nth = amx::GemmKernel224Int4::recommended_nth(n);\n  int m_ = (m + 31) / 32 * 32;\n  Mat<float> d(m_, n, Layout::RowMajor);\n  {\n    int repeat = 100;\n    std::vector<int8_t*> vec_a;\n    std::vector<int8_t*> vec_b;\n    std::vector<float*> vec_c;\n    std::vector<std::shared_ptr<amx::GemmKernel224Int4::BufferA>> vec_ba;\n    std::vector<std::shared_ptr<amx::GemmKernel224Int4::BufferB>> vec_bb;\n    std::vector<std::shared_ptr<amx::GemmKernel224Int4::BufferC>> vec_bc;\n    for (int i = 0; i < repeat * 2; i++) {\n      int8_t* a = (int8_t*)std::aligned_alloc(64, 
amx::GemmKernel224Int4::BufferA::required_size(m_, k));\n      std::shared_ptr<amx::GemmKernel224Int4::BufferA> ba = std::make_shared<amx::GemmKernel224Int4::BufferA>(m_, k, a);\n      int8_t* b = (int8_t*)std::aligned_alloc(64, amx::GemmKernel224Int4::BufferB::required_size(n, k));\n      std::shared_ptr<amx::GemmKernel224Int4::BufferB> bb = std::make_shared<amx::GemmKernel224Int4::BufferB>(n, k, b);\n      float* c = (float*)std::aligned_alloc(64, amx::GemmKernel224Int4::BufferC::required_size(m_, n));\n      std::shared_ptr<amx::GemmKernel224Int4::BufferC> bc = std::make_shared<amx::GemmKernel224Int4::BufferC>(m_, n, c);\n      ba->from_mat(m, qa, 0, 1);\n      int nth = amx::GemmKernel224Int4::recommended_nth(n);\n      for (int i = 0; i < nth; i++) {\n        bb->from_mat(qb, i, nth);\n      }\n      vec_a.push_back(a);\n      vec_b.push_back(b);\n      vec_c.push_back(c);\n      vec_ba.push_back(ba);\n      vec_bb.push_back(bb);\n      vec_bc.push_back(bc);\n    }\n    Timer t(fmt::format(\"m:{} n:{} k:{} t:{} repeat:{}, latency\", m, n, k, test_iter, repeat));\n    for (int t = 0; t < test_iter; t++) {\n#pragma omp parallel for schedule(dynamic, 1)\n      for (int ti = 0; ti < nth * repeat; ti++) {\n        int mat_id = ti / nth + repeat * (t % 2);\n        int ith = ti % nth;\n        if (cache_hit) {\n          mat_id = 0;\n        }\n        amx::mat_mul(m, n, k, vec_ba[mat_id], vec_bb[mat_id], vec_bc[mat_id], ith, nth);\n      }\n    }\n    for (int i = 0; i < repeat * 2; i++) {\n      free(vec_a[i]);\n      free(vec_b[i]);\n      free(vec_c[i]);\n    }\n  }\n  d.dealloc();\n}\n\nvoid group_q_latency_test_int4(int n_max, int k_max) {\n  amx::GemmKernel224Int4::config();\n\n  int m_max = 1024;\n  int m_start = 32;\n  int m_step = 32;\n\n  Mat<float> a(m_max, k_max, Layout::RowMajor), b(k_max, n_max, Layout::ColumnMajor);\n  std::mt19937 gen(123);\n  a.random(gen);\n  b.random(gen);\n  a.quant(GGML_TYPE_BF16);\n  b.quant(GGML_TYPE_BF16);\n\n  std::string 
method_name = \"INT4\";\n  if (mt) {\n    method_name += fmt::format(\"_mt{}\", omp_get_max_threads());\n  }\n  if (cache_hit) {\n    method_name += \"-cache-hit\";\n  }\n\n  auto output = fmt::format(\"{}-m:{}:{}:{}-n:{}-k:{}-x{}x{}.txt\", method_name, m_start, m_max, m_step, n_max, k_max,\n                            amx::GemmKernel224Int4::N_BLOCK, amx::GemmKernel224Int4::K_BLOCK);\n  // std::cout << \"Output to: \" << output << std::endl;\n  auto x = freopen(output.c_str(), \"w\", stdout);\n  assert(x);\n\n  for (int m = m_start; m <= m_max; m *= 2) {\n    q_latency_test_int4(m, n_max, k_max, a.quant_data<ggml_bf16_t>(), b.quant_data<ggml_bf16_t>());\n  }\n}\n\nvoid q_latency_test_int4_1(int m, int n, int k, ggml_bf16_t* qa, ggml_bf16_t* qb) {\n  int nth = amx::GemmKernel224Int4_1::recommended_nth(n);\n  int m_ = (m + 31) / 32 * 32;\n  Mat<float> d(m_, n, Layout::RowMajor);\n  {\n    int repeat = 100;\n    std::vector<int8_t*> vec_a;\n    std::vector<int8_t*> vec_b;\n    std::vector<float*> vec_c;\n    std::vector<std::shared_ptr<amx::GemmKernel224Int4_1::BufferA>> vec_ba;\n    std::vector<std::shared_ptr<amx::GemmKernel224Int4_1::BufferB>> vec_bb;\n    std::vector<std::shared_ptr<amx::GemmKernel224Int4_1::BufferC>> vec_bc;\n    for (int i = 0; i < repeat * 2; i++) {\n      int8_t* a = (int8_t*)std::aligned_alloc(64, amx::GemmKernel224Int4_1::BufferA::required_size(m_, k));\n      std::shared_ptr<amx::GemmKernel224Int4_1::BufferA> ba =\n          std::make_shared<amx::GemmKernel224Int4_1::BufferA>(m_, k, a);\n      int8_t* b = (int8_t*)std::aligned_alloc(64, amx::GemmKernel224Int4_1::BufferB::required_size(n, k));\n      std::shared_ptr<amx::GemmKernel224Int4_1::BufferB> bb =\n          std::make_shared<amx::GemmKernel224Int4_1::BufferB>(n, k, b);\n      float* c = (float*)std::aligned_alloc(64, amx::GemmKernel224Int4_1::BufferC::required_size(m_, n));\n      std::shared_ptr<amx::GemmKernel224Int4_1::BufferC> bc =\n          
std::make_shared<amx::GemmKernel224Int4_1::BufferC>(m_, n, c);\n      ba->from_mat(m, qa, 0, 1);\n      int nth = amx::GemmKernel224Int4_1::recommended_nth(n);\n      for (int i = 0; i < nth; i++) {\n        bb->from_mat(qb, i, nth);\n      }\n      vec_a.push_back(a);\n      vec_b.push_back(b);\n      vec_c.push_back(c);\n      vec_ba.push_back(ba);\n      vec_bb.push_back(bb);\n      vec_bc.push_back(bc);\n    }\n    Timer t(fmt::format(\"m:{} n:{} k:{} t:{} repeat:{}, latency\", m, n, k, test_iter, repeat));\n    for (int t = 0; t < test_iter; t++) {\n#pragma omp parallel for schedule(dynamic, 1)\n      for (int ti = 0; ti < nth * repeat; ti++) {\n        int mat_id = ti / nth + repeat * (t % 2);\n        int ith = ti % nth;\n        if (cache_hit) {\n          mat_id = 0;\n        }\n        amx::mat_mul(m, n, k, vec_ba[mat_id], vec_bb[mat_id], vec_bc[mat_id], ith, nth);\n      }\n    }\n    for (int i = 0; i < repeat * 2; i++) {\n      free(vec_a[i]);\n      free(vec_b[i]);\n      free(vec_c[i]);\n    }\n  }\n  d.dealloc();\n}\n\nvoid group_q_latency_test_int4_1(int n_max, int k_max) {\n  amx::GemmKernel224Int4_1::config();\n\n  int m_max = 1024;\n  int m_start = 32;\n  int m_step = 32;\n\n  Mat<float> a(m_max, k_max, Layout::RowMajor), b(k_max, n_max, Layout::ColumnMajor);\n  std::mt19937 gen(123);\n  a.random(gen);\n  b.random(gen);\n  a.quant(GGML_TYPE_BF16);\n  b.quant(GGML_TYPE_BF16);\n\n  std::string method_name = \"INT4_1\";\n  if (mt) {\n    method_name += fmt::format(\"_mt{}\", omp_get_max_threads());\n  }\n  if (cache_hit) {\n    method_name += \"-cache-hit\";\n  }\n\n  auto output = fmt::format(\"{}-m:{}:{}:{}-n:{}-k:{}-x{}x{}.txt\", method_name, m_start, m_max, m_step, n_max, k_max,\n                            amx::GemmKernel224Int4_1::N_BLOCK, amx::GemmKernel224Int4_1::K_BLOCK);\n  // std::cout << \"Output to: \" << output << std::endl;\n  auto x = freopen(output.c_str(), \"w\", stdout);\n  assert(x);\n\n  for (int m = m_start; m <= m_max; m *= 
2) {\n    q_latency_test_int4_1(m, n_max, k_max, a.quant_data<ggml_bf16_t>(), b.quant_data<ggml_bf16_t>());\n  }\n}\n\nint main() {\n  amx::enable_amx();\n  init();\n\n  // group_q_latency_test_bf16(5120, 1536);\n  // group_q_latency_test_bf16(3584, 2560);\n  // group_q_latency_test_bf16(2560, 3584);\n  // group_q_latency_test_bf16(1536, 5120);\n  // group_q_latency_test_bf16(7168, 2048);\n  // group_q_latency_test_bf16(2048, 7168);\n\n  // group_q_latency_test_int8(5120, 1536);\n  // group_q_latency_test_int8(3584, 2560);\n  // group_q_latency_test_int8(2560, 3584);\n  // group_q_latency_test_int8(1536, 5120);\n  // group_q_latency_test_int8(7168, 2048);\n  // group_q_latency_test_int8(2048, 7168);\n\n  group_q_latency_test_int4(5120, 1536);\n  group_q_latency_test_int4(3584, 2560);\n  group_q_latency_test_int4(2560, 3584);\n  group_q_latency_test_int4(1536, 5120);\n  group_q_latency_test_int4(7168, 2048);\n  group_q_latency_test_int4(2048, 7168);\n\n  // group_q_latency_test_int4_1(5120, 1536);\n  // group_q_latency_test_int4_1(3584, 2560);\n  // group_q_latency_test_int4_1(2560, 3584);\n  // group_q_latency_test_int4_1(1536, 5120);\n  // group_q_latency_test_int4_1(7168, 2048);\n  // group_q_latency_test_int4_1(2048, 7168);\n\n  // int k = 2048;\n  // correction_test_int4_1(32, 32, k);\n  // correction_test_int4(256, 256, 2048);\n  // correction_test_int4(32, 32, 4096);\n  // correction_test_int4(256, 256, 4096);\n  // correction_test_int4(32, 32, k);\n  // correction_test_int4(256, 32, 128);\n  // correction_test_int4(32, 64, 128);\n  // correction_test_int4(64, 32, 128);\n  // correction_test_int4(256, 256, 128);\n}\n"
  },
  {
    "path": "kt-kernel/operators/amx/test/analyze-error.cpp",
    "content": "#include <cmath>\n#include <iostream>\n#include <memory>\n#include <random>\n#include <vector>\n\n#include \"../la/amx.hpp\"\n\nvoid analyze_error_patterns() {\n  std::cout << \"=== Analyzing Error Patterns in K-Group Quantization ===\" << std::endl;\n\n  const int m = 32;\n  const int n = 32;\n  const int k = 512;\n  const int k_group_size = 128;\n\n  using Kernel = amx::GemmKernel224Int4KGroup;\n  using BufferA = Kernel::BufferA;\n  using BufferB = Kernel::BufferB;\n  using BufferC = Kernel::BufferC;\n\n  void* buffer_a = std::aligned_alloc(64, BufferA::required_size(m, k, k_group_size));\n  void* buffer_b = std::aligned_alloc(64, BufferB::required_size(n, k, k_group_size));\n  void* buffer_c = std::aligned_alloc(64, BufferC::required_size(m, n));\n\n  auto ba = std::make_shared<BufferA>(m, k, k_group_size, buffer_a);\n  auto bb = std::make_shared<BufferB>(n, k, k_group_size, buffer_b);\n  auto bc = std::make_shared<BufferC>(m, n, buffer_c);\n\n  Kernel::config();\n\n  std::cout << \"\\n1. 
Testing with very small values (prone to quantization loss):\" << std::endl;\n  {\n    std::vector<ggml_bf16_t> input_a(m * k);\n    std::vector<ggml_bf16_t> input_b(k * n);\n\n    // Very small values - will mostly quantize to 0\n    for (int i = 0; i < m * k; i++) {\n      input_a[i] = ggml_compute_fp32_to_bf16(0.0001f * (i % 10));\n    }\n    for (int i = 0; i < k * n; i++) {\n      input_b[i] = ggml_compute_fp32_to_bf16(0.0001f * (i % 10));\n    }\n\n    ba->from_mat(m, input_a.data(), 0, 1);\n    bb->from_mat(input_b.data(), 0, 1);\n\n    // Check scales\n    float a_scale = *ba->get_scale(m, 0, k, 0);\n    float b_scale = *bb->get_scale(n, 0, k, 0);\n    std::cout << \"  A scale: \" << a_scale << \", B scale: \" << b_scale << std::endl;\n\n    amx::mat_mul_kgroup(m, n, k, k_group_size, ba, bb, bc, 0, 1);\n\n    std::vector<ggml_bf16_t> output(m * n);\n    bc->to_mat(m, output.data(), 0, 1);\n\n    float first_val = ggml_compute_bf16_to_fp32(output[0]);\n    std::cout << \"  Result[0,0]: \" << first_val << std::endl;\n  }\n\n  std::cout << \"\\n2. 
Testing with values near quantization boundaries:\" << std::endl;\n  {\n    std::vector<ggml_bf16_t> input_a(m * k);\n    std::vector<ggml_bf16_t> input_b(k * n);\n\n    // Values on a 1/127 grid (multiples of the INT8 step), cycling through 16 values since INT4 has 16 levels\n    for (int i = 0; i < m * k; i++) {\n      float val = (i % 16) / 127.0f;  // 0/127 .. 15/127\n      input_a[i] = ggml_compute_fp32_to_bf16(val);\n    }\n    for (int i = 0; i < k * n; i++) {\n      float val = (i % 16) / 127.0f;\n      input_b[i] = ggml_compute_fp32_to_bf16(val);\n    }\n\n    ba->from_mat(m, input_a.data(), 0, 1);\n    bb->from_mat(input_b.data(), 0, 1);\n    amx::mat_mul_kgroup(m, n, k, k_group_size, ba, bb, bc, 0, 1);\n\n    std::vector<ggml_bf16_t> output(m * n);\n    bc->to_mat(m, output.data(), 0, 1);\n\n    std::cout << \"  First row results: \";\n    for (int j = 0; j < 5; j++) {\n      float val = ggml_compute_bf16_to_fp32(output[j]);\n      std::cout << val << \" \";\n    }\n    std::cout << std::endl;\n  }\n\n  std::cout << \"\\n3. 
Testing with different scale ranges per k-group:\" << std::endl;\n  {\n    std::vector<ggml_bf16_t> input_a(m * k);\n    std::vector<ggml_bf16_t> input_b(k * n);\n\n    // Different magnitude for each k-group\n    for (int i = 0; i < m; i++) {\n      for (int j = 0; j < k; j++) {\n        int kg = j / k_group_size;\n        float scale = std::pow(10.0f, -kg);  // 1.0, 0.1, 0.01, 0.001\n        input_a[i * k + j] = ggml_compute_fp32_to_bf16(scale * 0.5f);\n      }\n    }\n\n    for (int i = 0; i < k; i++) {\n      for (int j = 0; j < n; j++) {\n        int kg = i / k_group_size;\n        float scale = std::pow(10.0f, -kg);\n        input_b[i * n + j] = ggml_compute_fp32_to_bf16(scale * 0.5f);\n      }\n    }\n\n    ba->from_mat(m, input_a.data(), 0, 1);\n    bb->from_mat(input_b.data(), 0, 1);\n\n    // Print scales for each k-group\n    std::cout << \"  A scales per k-group: \";\n    for (int kg = 0; kg < k / k_group_size; kg++) {\n      float scale = *ba->get_scale(m, 0, k, kg * k_group_size);\n      std::cout << scale << \" \";\n    }\n    std::cout << std::endl;\n\n    std::cout << \"  B scales per k-group: \";\n    for (int kg = 0; kg < k / k_group_size; kg++) {\n      float scale = *bb->get_scale(n, 0, k, kg * k_group_size);\n      std::cout << scale << \" \";\n    }\n    std::cout << std::endl;\n\n    amx::mat_mul_kgroup(m, n, k, k_group_size, ba, bb, bc, 0, 1);\n\n    std::vector<ggml_bf16_t> output(m * n);\n    bc->to_mat(m, output.data(), 0, 1);\n\n    // Compute reference\n    float ref = 0.0f;\n    for (int kg = 0; kg < k / k_group_size; kg++) {\n      float scale = std::pow(10.0f, -kg);\n      ref += k_group_size * scale * scale * 0.25f;  // 0.5 * 0.5\n    }\n\n    float actual = ggml_compute_bf16_to_fp32(output[0]);\n    std::cout << \"  Expected: \" << ref << \", Actual: \" << actual << std::endl;\n    std::cout << \"  Error: \" << std::abs(ref - actual) / ref * 100 << \"%\" << std::endl;\n  }\n\n  std::cout << \"\\n4. 
Testing with sparse patterns (many zeros):\" << std::endl;\n  {\n    std::vector<ggml_bf16_t> input_a(m * k);\n    std::vector<ggml_bf16_t> input_b(k * n);\n\n    // Sparse pattern - 90% zeros\n    std::mt19937 gen(42);\n    std::uniform_real_distribution<float> dist(0.0f, 1.0f);\n\n    for (int i = 0; i < m * k; i++) {\n      float val = (dist(gen) < 0.1f) ? 0.5f : 0.0f;\n      input_a[i] = ggml_compute_fp32_to_bf16(val);\n    }\n    for (int i = 0; i < k * n; i++) {\n      float val = (dist(gen) < 0.1f) ? 0.5f : 0.0f;\n      input_b[i] = ggml_compute_fp32_to_bf16(val);\n    }\n\n    ba->from_mat(m, input_a.data(), 0, 1);\n    bb->from_mat(input_b.data(), 0, 1);\n    amx::mat_mul_kgroup(m, n, k, k_group_size, ba, bb, bc, 0, 1);\n\n    std::vector<ggml_bf16_t> output(m * n);\n    bc->to_mat(m, output.data(), 0, 1);\n\n    // Compute statistics\n    float max_val = 0.0f;\n    float avg_val = 0.0f;\n    int non_zero = 0;\n\n    for (int i = 0; i < m * n; i++) {\n      float val = std::abs(ggml_compute_bf16_to_fp32(output[i]));\n      max_val = std::max(max_val, val);\n      avg_val += val;\n      if (val > 1e-6) non_zero++;\n    }\n    avg_val /= (m * n);\n\n    std::cout << \"  Max value: \" << max_val << std::endl;\n    std::cout << \"  Avg value: \" << avg_val << std::endl;\n    std::cout << \"  Non-zero outputs: \" << non_zero << \"/\" << m * n << std::endl;\n  }\n\n  std::cout << \"\\n5. 
Testing with gradual value changes (worst case for k-group):\" << std::endl;\n  {\n    std::vector<ggml_bf16_t> input_a(m * k);\n    std::vector<ggml_bf16_t> input_b(k * n);\n\n    // Gradual increase across k dimension - worst case for k-group quantization\n    for (int i = 0; i < m; i++) {\n      for (int j = 0; j < k; j++) {\n        float val = j * 0.001f;  // Gradual increase\n        input_a[i * k + j] = ggml_compute_fp32_to_bf16(val);\n      }\n    }\n\n    for (int i = 0; i < k; i++) {\n      for (int j = 0; j < n; j++) {\n        float val = 0.1f;  // Constant\n        input_b[i * n + j] = ggml_compute_fp32_to_bf16(val);\n      }\n    }\n\n    ba->from_mat(m, input_a.data(), 0, 1);\n    bb->from_mat(input_b.data(), 0, 1);\n\n    // Check how scales vary\n    std::cout << \"  A scales (should increase): \";\n    for (int kg = 0; kg < k / k_group_size; kg++) {\n      float scale = *ba->get_scale(m, 0, k, kg * k_group_size);\n      std::cout << scale << \" \";\n    }\n    std::cout << std::endl;\n\n    amx::mat_mul_kgroup(m, n, k, k_group_size, ba, bb, bc, 0, 1);\n\n    std::vector<ggml_bf16_t> output(m * n);\n    bc->to_mat(m, output.data(), 0, 1);\n\n    // Reference calculation\n    float ref = 0.0f;\n    for (int j = 0; j < k; j++) {\n      ref += j * 0.001f * 0.1f;\n    }\n\n    float actual = ggml_compute_bf16_to_fp32(output[0]);\n    std::cout << \"  Expected: \" << ref << \", Actual: \" << actual << std::endl;\n    std::cout << \"  Error: \" << std::abs(ref - actual) / ref * 100 << \"%\" << std::endl;\n  }\n\n  free(buffer_a);\n  free(buffer_b);\n  free(buffer_c);\n}\n\nint main() {\n  analyze_error_patterns();\n  return 0;\n}"
  },
  {
    "path": "kt-kernel/operators/amx/test/avx-test.cpp",
    "content": "\n#include <immintrin.h>\n#include <omp.h>\n\n#include <chrono>\n#include <cstdlib>\n#include <iostream>\n#include <random>\n\nconstexpr size_t DATA_SIZE = 100ULL * 1024 * 1024 * 1024;  // 100 GB\nconstexpr size_t ALIGNMENT = 64;                           // alignment for AVX-512\nconstexpr int TEST_ITERATIONS = 100;\nconstexpr int INNER_TEST_ITERATIONS = 100;\n\nvoid generate_data(uint8_t* data, size_t size) {\n  size_t size_int64 = size / sizeof(int64_t);\n\n#pragma omp parallel\n  {\n    std::mt19937_64 engine(omp_get_thread_num());\n    std::uniform_int_distribution<int64_t> dist;\n\n    int64_t* data64 = reinterpret_cast<int64_t*>(data);\n\n#pragma omp for\n    for (size_t i = 0; i < size_int64; ++i) {\n      data64[i] = dist(engine);\n    }\n  }\n}\n\nvoid dpbusd_test(const uint8_t* data_a, const uint8_t* data_b, int32_t* result, size_t size) {\n  constexpr size_t simd_width = 64;  // 512 bits = 64 bytes\n  size_t vec_count = size / simd_width;\n\n#pragma omp parallel for\n  for (size_t x = 0; x < vec_count * INNER_TEST_ITERATIONS; ++x) {\n    auto i = x % vec_count;\n    __m512i va = _mm512_load_si512(reinterpret_cast<const __m512i*>(data_a + i * simd_width));\n    __m512i vb = _mm512_load_si512(reinterpret_cast<const __m512i*>(data_b + i * simd_width));\n    __m512i vc = _mm512_setzero_si512();\n\n    vc = _mm512_dpbusd_epi32(vc, va, vb);\n\n    _mm512_store_si512(reinterpret_cast<__m512i*>(result + i * (simd_width / 4)), vc);\n  }\n}\n\nint main() {\n  std::cout << \"Allocating aligned memory...\\n\";\n  uint8_t* data_a = reinterpret_cast<uint8_t*>(aligned_alloc(ALIGNMENT, DATA_SIZE));\n  uint8_t* data_b = reinterpret_cast<uint8_t*>(aligned_alloc(ALIGNMENT, DATA_SIZE));\n  int32_t* result = reinterpret_cast<int32_t*>(aligned_alloc(ALIGNMENT, DATA_SIZE));\n\n  std::cout << \"Generating random data...\\n\";\n  generate_data(data_a, DATA_SIZE);\n  generate_data(data_b, DATA_SIZE);\n\n  for (int iter = 0; iter < TEST_ITERATIONS; ++iter) {\n    
std::cout << \"Starting computation iteration \" << iter + 1 << \"...\\n\";\n    auto start = std::chrono::high_resolution_clock::now();\n\n    dpbusd_test(data_a, data_b, result, DATA_SIZE);\n\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> diff = end - start;\n\n    double bandwidth = (3 * DATA_SIZE * INNER_TEST_ITERATIONS) / (1e9) / diff.count();  // GB/s\n\n    std::cout << \"Iteration \" << iter + 1 << \" execution time: \" << diff.count() << \" s\\n\";\n    std::cout << \"Iteration \" << iter + 1 << \" estimated memory bandwidth: \" << bandwidth << \" GB/s\\n\";\n  }\n\n  free(data_a);\n  free(data_b);\n  free(result);\n\n  return 0;\n}\n"
  },
  {
    "path": "kt-kernel/operators/amx/test/debug-kgroup-details.cpp",
    "content": "#include <cmath>\n#include <iostream>\n#include <memory>\n#include <vector>\n\n#include \"../la/amx.hpp\"\n\nvoid debug_kgroup_details() {\n  std::cout << \"=== Debugging K-Group Details ===\\n\" << std::endl;\n\n  const int m = 32;  // Minimum size for AMX\n  const int n = 32;\n  const int k = 512;  // 4 k-groups, must be >= K_BLOCK\n  const int k_group_size = 128;\n\n  using Kernel = amx::GemmKernel224Int4KGroup;\n  using BufferA = Kernel::BufferA;\n  using BufferB = Kernel::BufferB;\n  using BufferC = Kernel::BufferC;\n\n  void* buffer_a = std::aligned_alloc(64, BufferA::required_size(m, k, k_group_size));\n  void* buffer_b = std::aligned_alloc(64, BufferB::required_size(n, k, k_group_size));\n  void* buffer_c = std::aligned_alloc(64, BufferC::required_size(m, n));\n\n  auto ba = std::make_shared<BufferA>(m, k, k_group_size, buffer_a);\n  auto bb = std::make_shared<BufferB>(n, k, k_group_size, buffer_b);\n  auto bc = std::make_shared<BufferC>(m, n, buffer_c);\n\n  // Test with specific values to debug quantization\n  std::cout << \"Test: Specific values with normal distribution\\n\" << std::endl;\n\n  std::vector<ggml_bf16_t> input_a(m * k);\n  std::vector<ggml_bf16_t> input_b(k * n);\n\n  std::mt19937 gen(42);\n  std::normal_distribution<float> dist(0.0f, 0.1f);\n\n  // Fill with random normal values and print some\n  std::cout << \"Sample A values (first 8):\" << std::endl;\n  for (int i = 0; i < 8; i++) {\n    float val = dist(gen);\n    input_a[i] = ggml_compute_fp32_to_bf16(val);\n    std::cout << \"  A[\" << i << \"] = \" << val << std::endl;\n  }\n\n  // Fill rest of A\n  for (int i = 8; i < m * k; i++) {\n    input_a[i] = ggml_compute_fp32_to_bf16(dist(gen));\n  }\n\n  std::cout << \"\\nSample B values (first 8):\" << std::endl;\n  for (int i = 0; i < 8; i++) {\n    float val = dist(gen);\n    input_b[i] = ggml_compute_fp32_to_bf16(val);\n    std::cout << \"  B[\" << i << \"] = \" << val << std::endl;\n  }\n\n  // Fill rest of B\n  for 
(int i = 8; i < k * n; i++) {\n    input_b[i] = ggml_compute_fp32_to_bf16(dist(gen));\n  }\n\n  // Quantize\n  ba->from_mat(m, input_a.data(), 0, 1);\n  bb->from_mat(input_b.data(), 0, 1);\n\n  // Print scales for debugging\n  std::cout << \"\\nA scales (per k-group):\" << std::endl;\n  for (int row = 0; row < m; row++) {\n    std::cout << \"  Row \" << row << \": \";\n    for (int kg = 0; kg < k / k_group_size; kg++) {\n      float scale = *ba->get_scale(m, row, k, kg * k_group_size);\n      std::cout << \"kg\" << kg << \"=\" << scale << \" \";\n    }\n    std::cout << std::endl;\n  }\n\n  std::cout << \"\\nB scales (per k-group):\" << std::endl;\n  for (int col = 0; col < n; col++) {\n    std::cout << \"  Col \" << col << \": \";\n    for (int kg = 0; kg < k / k_group_size; kg++) {\n      float scale = *bb->get_scale(n, col, k, kg * k_group_size);\n      std::cout << \"kg\" << kg << \"=\" << scale << \" \";\n    }\n    std::cout << std::endl;\n  }\n\n  // Test dequantization to check if quantization is working\n  std::cout << \"\\nDequantization test (first row of A):\" << std::endl;\n  // We need to manually dequantize to check\n  // Get quantized values and scale\n  int8_t* a_data = (int8_t*)ba->get_submat(m, k, 0, 0);\n  float scale0 = *ba->get_scale(m, 0, k, 0);\n\n  std::cout << \"  First 8 quantized values: \";\n  for (int i = 0; i < 8; i++) {\n    std::cout << (int)a_data[i] << \" \";\n  }\n  std::cout << std::endl;\n\n  std::cout << \"  Dequantized (q * scale): \";\n  for (int i = 0; i < 8; i++) {\n    float dequant = a_data[i] * scale0;\n    float original = ggml_compute_bf16_to_fp32(input_a[i]);\n    std::cout << dequant << \" (orig=\" << original << \") \";\n  }\n  std::cout << std::endl;\n\n  // Compute reference\n  std::cout << \"\\nComputing reference result...\" << std::endl;\n  std::vector<float> ref_result(m * n, 0.0f);\n  for (int i = 0; i < m; i++) {\n    for (int j = 0; j < n; j++) {\n      float sum = 0.0f;\n      for (int l = 0; l < k; l++) 
{\n        float a_val = ggml_compute_bf16_to_fp32(input_a[i * k + l]);\n        float b_val = ggml_compute_bf16_to_fp32(input_b[l * n + j]);\n        sum += a_val * b_val;\n      }\n      ref_result[i * n + j] = sum;\n    }\n  }\n\n  // Run k-group multiplication\n  Kernel::config();\n  amx::mat_mul_kgroup(m, n, k, k_group_size, ba, bb, bc, 0, 1);\n\n  std::vector<ggml_bf16_t> output(m * n);\n  bc->to_mat(m, output.data(), 0, 1);\n\n  // Compare results\n  std::cout << \"\\nResults comparison:\" << std::endl;\n  for (int i = 0; i < m; i++) {\n    for (int j = 0; j < n; j++) {\n      int idx = i * n + j;\n      float actual = ggml_compute_bf16_to_fp32(output[idx]);\n      float ref = ref_result[idx];\n      float error = std::abs(actual - ref) / (std::abs(ref) + 1e-8) * 100;\n      std::cout << \"  [\" << i << \",\" << j << \"]: actual=\" << actual << \", ref=\" << ref << \", error=\" << error << \"%\"\n                << std::endl;\n    }\n  }\n\n  // Test a simple case to verify the mechanism\n  std::cout << \"\\n--- Simple test with k_group boundaries ---\" << std::endl;\n\n  // Clear buffers\n  for (int i = 0; i < m * k; i++) {\n    input_a[i] = ggml_compute_fp32_to_bf16(0.0f);\n  }\n  for (int i = 0; i < k * n; i++) {\n    input_b[i] = ggml_compute_fp32_to_bf16(0.0f);\n  }\n\n  // Set specific values for each k-group\n  for (int i = 0; i < m; i++) {\n    // First k-group (0-127): value = 0.5\n    for (int j = 0; j < 128; j++) {\n      input_a[i * k + j] = ggml_compute_fp32_to_bf16(0.5f);\n    }\n    // Second k-group (128-255): value = 0.25\n    for (int j = 128; j < 256; j++) {\n      input_a[i * k + j] = ggml_compute_fp32_to_bf16(0.25f);\n    }\n    // Remaining k-groups: value = 0.1\n    for (int j = 256; j < k; j++) {\n      input_a[i * k + j] = ggml_compute_fp32_to_bf16(0.1f);\n    }\n  }\n\n  // B matrix: all 0.4\n  for (int i = 0; i < k * n; i++) {\n    input_b[i] = ggml_compute_fp32_to_bf16(0.4f);\n  }\n\n  ba->from_mat(m, input_a.data(), 0, 1);\n  
bb->from_mat(input_b.data(), 0, 1);\n\n  // Expected: 0.5 * 0.4 * 128 + 0.25 * 0.4 * 128 + 0.1 * 0.4 * 256 = 25.6 + 12.8 + 10.24 = 48.64\n  float expected = 0.5f * 0.4f * 128 + 0.25f * 0.4f * 128 + 0.1f * 0.4f * 256;\n  std::cout << \"Expected value: \" << expected << std::endl;\n\n  amx::mat_mul_kgroup(m, n, k, k_group_size, ba, bb, bc, 0, 1);\n  bc->to_mat(m, output.data(), 0, 1);\n\n  float actual = ggml_compute_bf16_to_fp32(output[0]);\n  std::cout << \"Actual value: \" << actual << std::endl;\n  std::cout << \"Error: \" << std::abs(actual - expected) / expected * 100 << \"%\" << std::endl;\n\n  free(buffer_a);\n  free(buffer_b);\n  free(buffer_c);\n}\n\nint main() {\n  debug_kgroup_details();\n  return 0;\n}"
  },
  {
    "path": "kt-kernel/operators/amx/test/debug-kgroup.cpp",
    "content": "#include <omp.h>\n\n#include \"../la/amx.hpp\"\n#define FMT_HEADER_ONLY\n#include <fmt/core.h>\n\n#include <cmath>\n#include <iostream>\n#include <memory>\n#include <random>\n\nvoid debug_simple_multiplication() {\n  std::cout << \"=== Debug Simple K-Group Multiplication ===\" << std::endl;\n\n  // Very small test case for debugging\n  const int m = 32;   // 1 M_STEP\n  const int n = 32;   // 1 N_STEP\n  const int k = 512;  // Must be at least K_BLOCK (512)\n  const int k_group_size = 128;\n\n  std::cout << fmt::format(\"Parameters: m={}, n={}, k={}, k_group_size={}\\n\", m, n, k, k_group_size);\n\n  using Kernel = amx::GemmKernel224Int4KGroup;\n  using BufferA = Kernel::BufferA;\n  using BufferB = Kernel::BufferB;\n  using BufferC = Kernel::BufferC;\n\n  // Allocate buffers\n  void* buffer_a = std::aligned_alloc(64, BufferA::required_size(m, k, k_group_size));\n  void* buffer_b = std::aligned_alloc(64, BufferB::required_size(n, k, k_group_size));\n  void* buffer_c = std::aligned_alloc(64, BufferC::required_size(m, n));\n\n  auto ba = std::make_shared<BufferA>(m, k, k_group_size, buffer_a);\n  auto bb = std::make_shared<BufferB>(n, k, k_group_size, buffer_b);\n  auto bc = std::make_shared<BufferC>(m, n, buffer_c);\n\n  // Create identity-like matrices for easy verification\n  std::vector<ggml_bf16_t> input_a(m * k);\n  std::vector<ggml_bf16_t> input_b(k * n);\n\n  // Initialize A as mostly zeros with a few ones\n  for (int i = 0; i < m * k; i++) {\n    input_a[i] = ggml_compute_fp32_to_bf16(0.0f);\n  }\n  // Set A[0,0] = 1\n  input_a[0] = ggml_compute_fp32_to_bf16(1.0f);\n\n  // Initialize B as mostly zeros with a few ones\n  for (int i = 0; i < k * n; i++) {\n    input_b[i] = ggml_compute_fp32_to_bf16(0.0f);\n  }\n  // Set B[0,0] = 1\n  input_b[0] = ggml_compute_fp32_to_bf16(1.0f);\n\n  // Expected result: C[0,0] = 1*1 = 1, rest = 0\n  std::cout << \"\\nExpected result: C[0,0] = 1.0, rest = 0.0\\n\" << std::endl;\n\n  // Quantize inputs\n  
ba->from_mat(m, input_a.data(), 0, 1);\n  bb->from_mat(input_b.data(), 0, 1);\n\n  // Print scales for debugging\n  std::cout << \"BufferA scales for row 0:\" << std::endl;\n  for (int kg = 0; kg < k / k_group_size; kg++) {\n    float scale = *ba->get_scale(m, 0, k, kg * k_group_size);\n    std::cout << fmt::format(\"  k_group[{}]: scale = {:.6f}\\n\", kg, scale);\n  }\n\n  std::cout << \"\\nBufferB scales for col 0:\" << std::endl;\n  for (int kg = 0; kg < k / k_group_size; kg++) {\n    float scale = *bb->get_scale(n, 0, k, kg * k_group_size);\n    std::cout << fmt::format(\"  k_group[{}]: scale = {:.6f}\\n\", kg, scale);\n  }\n\n  // Configure AMX\n  Kernel::config();\n\n  // Run matrix multiplication\n  amx::mat_mul_kgroup(m, n, k, k_group_size, ba, bb, bc, 0, 1);\n\n  // Get output\n  std::vector<ggml_bf16_t> output(m * n);\n  bc->to_mat(m, output.data(), 0, 1);\n\n  // Print results\n  std::cout << \"\\nActual result (first 5x5):\" << std::endl;\n  for (int i = 0; i < std::min(5, m); i++) {\n    for (int j = 0; j < std::min(5, n); j++) {\n      float val = ggml_compute_bf16_to_fp32(output[i * n + j]);\n      std::cout << fmt::format(\"{:8.4f} \", val);\n    }\n    std::cout << std::endl;\n  }\n\n  free(buffer_a);\n  free(buffer_b);\n  free(buffer_c);\n}\n\nvoid debug_pattern_multiplication() {\n  std::cout << \"\\n=== Debug Pattern Multiplication ===\" << std::endl;\n\n  const int m = 32;\n  const int n = 32;\n  const int k = 512;  // Must be at least K_BLOCK (512)\n  const int k_group_size = 128;\n\n  using Kernel = amx::GemmKernel224Int4KGroup;\n  using BufferA = Kernel::BufferA;\n  using BufferB = Kernel::BufferB;\n  using BufferC = Kernel::BufferC;\n\n  void* buffer_a = std::aligned_alloc(64, BufferA::required_size(m, k, k_group_size));\n  void* buffer_b = std::aligned_alloc(64, BufferB::required_size(n, k, k_group_size));\n  void* buffer_c = std::aligned_alloc(64, BufferC::required_size(m, n));\n\n  auto ba = std::make_shared<BufferA>(m, k, k_group_size, 
buffer_a);\n  auto bb = std::make_shared<BufferB>(n, k, k_group_size, buffer_b);\n  auto bc = std::make_shared<BufferC>(m, n, buffer_c);\n\n  // Create constant matrices\n  std::vector<ggml_bf16_t> input_a(m * k);\n  std::vector<ggml_bf16_t> input_b(k * n);\n\n  // Fill A with 0.1 and B with 0.1\n  for (int i = 0; i < m * k; i++) {\n    input_a[i] = ggml_compute_fp32_to_bf16(0.1f);\n  }\n  for (int i = 0; i < k * n; i++) {\n    input_b[i] = ggml_compute_fp32_to_bf16(0.1f);\n  }\n\n  // Expected: Each element should be 0.1 * 0.1 * k = 0.01 * 512 = 5.12\n  float expected = 0.1f * 0.1f * k;\n  std::cout << fmt::format(\"\\nExpected result: all elements = {:.4f}\\n\", expected);\n\n  // Quantize\n  ba->from_mat(m, input_a.data(), 0, 1);\n  bb->from_mat(input_b.data(), 0, 1);\n\n  // Run\n  Kernel::config();\n  amx::mat_mul_kgroup(m, n, k, k_group_size, ba, bb, bc, 0, 1);\n\n  // Get output\n  std::vector<ggml_bf16_t> output(m * n);\n  bc->to_mat(m, output.data(), 0, 1);\n\n  // Check results\n  float max_error = 0.0f;\n  float avg_error = 0.0f;\n  for (int i = 0; i < m * n; i++) {\n    float actual = ggml_compute_bf16_to_fp32(output[i]);\n    float error = std::abs(actual - expected);\n    max_error = std::max(max_error, error);\n    avg_error += error;\n  }\n  avg_error /= (m * n);\n\n  std::cout << fmt::format(\"Max error: {:.6f}\\n\", max_error);\n  std::cout << fmt::format(\"Avg error: {:.6f}\\n\", avg_error);\n  std::cout << fmt::format(\"Relative error: {:.2f}%\\n\", (max_error / expected) * 100);\n\n  // Print sample values\n  std::cout << \"\\nSample values (first 5x5):\" << std::endl;\n  for (int i = 0; i < std::min(5, m); i++) {\n    for (int j = 0; j < std::min(5, n); j++) {\n      float val = ggml_compute_bf16_to_fp32(output[i * n + j]);\n      std::cout << fmt::format(\"{:8.4f} \", val);\n    }\n    std::cout << std::endl;\n  }\n\n  free(buffer_a);\n  free(buffer_b);\n  free(buffer_c);\n}\n\nvoid compare_with_regular_int4() {\n  std::cout << \"\\n=== 
Compare K-Group vs Regular INT4 ===\" << std::endl;\n\n  const int m = 32;\n  const int n = 32;\n  const int k = 512;\n  const int k_group_size = 128;\n\n  // Create test data\n  std::vector<ggml_bf16_t> input_a(m * k);\n  std::vector<ggml_bf16_t> input_b(k * n);\n\n  std::mt19937 gen(42);\n  std::uniform_real_distribution<float> dist(-0.1f, 0.1f);\n\n  for (int i = 0; i < m * k; i++) {\n    input_a[i] = ggml_compute_fp32_to_bf16(dist(gen));\n  }\n  for (int i = 0; i < k * n; i++) {\n    input_b[i] = ggml_compute_fp32_to_bf16(dist(gen));\n  }\n\n  // Test with regular INT4\n  {\n    using Kernel = amx::GemmKernel224Int4;\n    using BufferA = Kernel::BufferA;\n    using BufferB = Kernel::BufferB;\n    using BufferC = Kernel::BufferC;\n\n    void* buffer_a = std::aligned_alloc(64, BufferA::required_size(m, k));\n    void* buffer_b = std::aligned_alloc(64, BufferB::required_size(n, k));  // Fixed: n, k not k, n\n    void* buffer_c = std::aligned_alloc(64, BufferC::required_size(m, n));\n\n    auto ba = std::make_shared<BufferA>(m, k, buffer_a);\n    auto bb = std::make_shared<BufferB>(n, k, buffer_b);  // Fixed: n, k not k, n\n    auto bc = std::make_shared<BufferC>(m, n, buffer_c);\n\n    ba->from_mat(m, input_a.data(), 0, 1);\n    bb->from_mat(input_b.data(), 0, 1);\n\n    Kernel::config();\n    amx::mat_mul(m, n, k, ba, bb, bc, 0, 1);\n\n    std::vector<ggml_bf16_t> output_regular(m * n);\n    bc->to_mat(m, output_regular.data(), 0, 1);\n\n    std::cout << \"Regular INT4 results (first 3x3):\" << std::endl;\n    for (int i = 0; i < 3; i++) {\n      for (int j = 0; j < 3; j++) {\n        float val = ggml_compute_bf16_to_fp32(output_regular[i * n + j]);\n        std::cout << fmt::format(\"{:8.4f} \", val);\n      }\n      std::cout << std::endl;\n    }\n\n    free(buffer_a);\n    free(buffer_b);\n    free(buffer_c);\n  }\n\n  // Test with K-Group INT4\n  {\n    using Kernel = amx::GemmKernel224Int4KGroup;\n    using BufferA = Kernel::BufferA;\n    using BufferB = 
Kernel::BufferB;\n    using BufferC = Kernel::BufferC;\n\n    void* buffer_a = std::aligned_alloc(64, BufferA::required_size(m, k, k_group_size));\n    void* buffer_b = std::aligned_alloc(64, BufferB::required_size(n, k, k_group_size));\n    void* buffer_c = std::aligned_alloc(64, BufferC::required_size(m, n));\n\n    auto ba = std::make_shared<BufferA>(m, k, k_group_size, buffer_a);\n    auto bb = std::make_shared<BufferB>(n, k, k_group_size, buffer_b);\n    auto bc = std::make_shared<BufferC>(m, n, buffer_c);\n\n    ba->from_mat(m, input_a.data(), 0, 1);\n    bb->from_mat(input_b.data(), 0, 1);\n\n    Kernel::config();\n    amx::mat_mul_kgroup(m, n, k, k_group_size, ba, bb, bc, 0, 1);\n\n    std::vector<ggml_bf16_t> output_kgroup(m * n);\n    bc->to_mat(m, output_kgroup.data(), 0, 1);\n\n    std::cout << \"\\nK-Group INT4 results (first 3x3):\" << std::endl;\n    for (int i = 0; i < 3; i++) {\n      for (int j = 0; j < 3; j++) {\n        float val = ggml_compute_bf16_to_fp32(output_kgroup[i * n + j]);\n        std::cout << fmt::format(\"{:8.4f} \", val);\n      }\n      std::cout << std::endl;\n    }\n\n    free(buffer_a);\n    free(buffer_b);\n    free(buffer_c);\n  }\n}\n\nint main() {\n  std::cout << \"Starting K-Group Debugging\\n\" << std::endl;\n\n  debug_simple_multiplication();\n  debug_pattern_multiplication();\n  compare_with_regular_int4();\n\n  return 0;\n}"
  },
  {
    "path": "kt-kernel/operators/amx/test/debug-specific-dims.cpp",
    "content": "#include <cmath>\n#include <iostream>\n#include <memory>\n#include <vector>\n\n#include \"../la/amx.hpp\"\n\nvoid debug_specific_dimensions() {\n  std::cout << \"=== Debugging Specific Dimensions Issue ===\\n\" << std::endl;\n\n  const int m_original = 200;\n  const int n = 2048;\n  const int k = 7168;\n  const int k_group_size = 128;\n\n  const int M_STEP = 32;\n  const int m = ((m_original + M_STEP - 1) / M_STEP) * M_STEP;  // Round up to 224\n\n  std::cout << \"Original dimensions: \" << m_original << \" x \" << n << \" x \" << k << std::endl;\n  std::cout << \"Padded dimensions: \" << m << \" x \" << n << \" x \" << k << std::endl;\n  std::cout << \"K-group size: \" << k_group_size << std::endl;\n  std::cout << \"Number of k-groups: \" << k / k_group_size << std::endl;\n\n  using Kernel = amx::GemmKernel224Int4KGroup;\n  using BufferA = Kernel::BufferA;\n  using BufferB = Kernel::BufferB;\n  using BufferC = Kernel::BufferC;\n\n  void* buffer_a = std::aligned_alloc(64, BufferA::required_size(m, k, k_group_size));\n  void* buffer_b = std::aligned_alloc(64, BufferB::required_size(n, k, k_group_size));\n  void* buffer_c = std::aligned_alloc(64, BufferC::required_size(m, n));\n\n  auto ba = std::make_shared<BufferA>(m, k, k_group_size, buffer_a);\n  auto bb = std::make_shared<BufferB>(n, k, k_group_size, buffer_b);\n  auto bc = std::make_shared<BufferC>(m, n, buffer_c);\n\n  // Test 1: Simple pattern - all ones\n  std::cout << \"\\n--- Test 1: All ones (should give k = 7168) ---\" << std::endl;\n  {\n    std::vector<ggml_bf16_t> input_a(m * k);\n    std::vector<ggml_bf16_t> input_b(k * n);\n\n    for (int i = 0; i < m * k; i++) {\n      input_a[i] = ggml_compute_fp32_to_bf16(1.0f);\n    }\n    for (int i = 0; i < k * n; i++) {\n      input_b[i] = ggml_compute_fp32_to_bf16(1.0f);\n    }\n\n    ba->from_mat(m, input_a.data(), 0, 1);\n    bb->from_mat(input_b.data(), 0, 1);\n\n    // Check some scales\n    std::cout << \"A scales (first 3 k-groups): 
\";\n    for (int kg = 0; kg < 3; kg++) {\n      float scale = *ba->get_scale(m, 0, k, kg * k_group_size);\n      std::cout << scale << \" \";\n    }\n    std::cout << std::endl;\n\n    std::cout << \"B scales (first 3 k-groups): \";\n    for (int kg = 0; kg < 3; kg++) {\n      float scale = *bb->get_scale(n, 0, k, kg * k_group_size);\n      std::cout << scale << \" \";\n    }\n    std::cout << std::endl;\n\n    Kernel::config();\n    amx::mat_mul_kgroup(m, n, k, k_group_size, ba, bb, bc, 0, 1);\n\n    std::vector<ggml_bf16_t> output(m * n);\n    bc->to_mat(m, output.data(), 0, 1);\n\n    float expected = 7168.0f;\n    float actual = ggml_compute_bf16_to_fp32(output[0]);\n    std::cout << \"Expected: \" << expected << \", Actual: \" << actual << std::endl;\n    std::cout << \"Error: \" << std::abs(actual - expected) / expected * 100 << \"%\" << std::endl;\n  }\n\n  // Test 2: Small values\n  std::cout << \"\\n--- Test 2: Small values (0.01) ---\" << std::endl;\n  {\n    std::vector<ggml_bf16_t> input_a(m * k);\n    std::vector<ggml_bf16_t> input_b(k * n);\n\n    for (int i = 0; i < m * k; i++) {\n      input_a[i] = ggml_compute_fp32_to_bf16(0.01f);\n    }\n    for (int i = 0; i < k * n; i++) {\n      input_b[i] = ggml_compute_fp32_to_bf16(0.01f);\n    }\n\n    ba->from_mat(m, input_a.data(), 0, 1);\n    bb->from_mat(input_b.data(), 0, 1);\n\n    Kernel::config();\n    amx::mat_mul_kgroup(m, n, k, k_group_size, ba, bb, bc, 0, 1);\n\n    std::vector<ggml_bf16_t> output(m * n);\n    bc->to_mat(m, output.data(), 0, 1);\n\n    float expected = 0.01f * 0.01f * 7168.0f;  // 0.7168\n    float actual = ggml_compute_bf16_to_fp32(output[0]);\n    std::cout << \"Expected: \" << expected << \", Actual: \" << actual << std::endl;\n    std::cout << \"Error: \" << std::abs(actual - expected) / expected * 100 << \"%\" << std::endl;\n  }\n\n  // Test 3: Identity-like pattern\n  std::cout << \"\\n--- Test 3: Identity pattern ---\" << std::endl;\n  {\n    std::vector<ggml_bf16_t> 
input_a(m * k);\n    std::vector<ggml_bf16_t> input_b(k * n);\n\n    // Initialize to zeros\n    for (int i = 0; i < m * k; i++) {\n      input_a[i] = ggml_compute_fp32_to_bf16(0.0f);\n    }\n    for (int i = 0; i < k * n; i++) {\n      input_b[i] = ggml_compute_fp32_to_bf16(0.0f);\n    }\n\n    // Set diagonal to 1\n    int min_dim = std::min(std::min(m, n), k);\n    for (int i = 0; i < min_dim; i++) {\n      input_a[i * k + i] = ggml_compute_fp32_to_bf16(1.0f);\n      input_b[i * n + i] = ggml_compute_fp32_to_bf16(1.0f);\n    }\n\n    ba->from_mat(m, input_a.data(), 0, 1);\n    bb->from_mat(input_b.data(), 0, 1);\n\n    Kernel::config();\n    amx::mat_mul_kgroup(m, n, k, k_group_size, ba, bb, bc, 0, 1);\n\n    std::vector<ggml_bf16_t> output(m * n);\n    bc->to_mat(m, output.data(), 0, 1);\n\n    // Check diagonal elements\n    std::cout << \"Diagonal elements (should be 1): \";\n    for (int i = 0; i < std::min(5, min_dim); i++) {\n      float val = ggml_compute_bf16_to_fp32(output[i * n + i]);\n      std::cout << val << \" \";\n    }\n    std::cout << std::endl;\n  }\n\n  // Test 4: Pattern with different values per k-group\n  std::cout << \"\\n--- Test 4: Different values per k-group ---\" << std::endl;\n  {\n    std::vector<ggml_bf16_t> input_a(m * k);\n    std::vector<ggml_bf16_t> input_b(k * n);\n\n    // Each k-group has different value\n    for (int i = 0; i < m; i++) {\n      for (int j = 0; j < k; j++) {\n        int kg = j / k_group_size;\n        float val = (kg + 1) * 0.1f;  // 0.1, 0.2, 0.3, ...\n        input_a[i * k + j] = ggml_compute_fp32_to_bf16(val);\n      }\n    }\n\n    for (int i = 0; i < k; i++) {\n      for (int j = 0; j < n; j++) {\n        input_b[i * n + j] = ggml_compute_fp32_to_bf16(0.1f);\n      }\n    }\n\n    ba->from_mat(m, input_a.data(), 0, 1);\n    bb->from_mat(input_b.data(), 0, 1);\n\n    // Check scales for different k-groups\n    std::cout << \"A scales (first 5 k-groups): \";\n    for (int kg = 0; kg < std::min(5, k / 
k_group_size); kg++) {\n      float scale = *ba->get_scale(m, 0, k, kg * k_group_size);\n      std::cout << scale << \" \";\n    }\n    std::cout << std::endl;\n\n    Kernel::config();\n    amx::mat_mul_kgroup(m, n, k, k_group_size, ba, bb, bc, 0, 1);\n\n    std::vector<ggml_bf16_t> output(m * n);\n    bc->to_mat(m, output.data(), 0, 1);\n\n    // Expected: sum of (kg+1)*0.1 * 0.1 * k_group_size for all k-groups\n    float expected = 0.0f;\n    for (int kg = 0; kg < k / k_group_size; kg++) {\n      expected += (kg + 1) * 0.1f * 0.1f * k_group_size;\n    }\n\n    float actual = ggml_compute_bf16_to_fp32(output[0]);\n    std::cout << \"Expected: \" << expected << \", Actual: \" << actual << std::endl;\n    std::cout << \"Error: \" << std::abs(actual - expected) / expected * 100 << \"%\" << std::endl;\n  }\n\n  free(buffer_a);\n  free(buffer_b);\n  free(buffer_c);\n}\n\nint main() {\n  debug_specific_dimensions();\n  return 0;\n}"
  },
  {
    "path": "kt-kernel/operators/amx/test/mat-test.hpp",
    "content": "#ifndef AMX_MAT_TEST_HPP\n#define AMX_MAT_TEST_HPP\n\n#include <cassert>\n#include <iostream>\n#include <limits>\n#include <random>\n\n#include \"../../common.hpp\"\n#include \"../la/utils.hpp\"\n#include \"llama.cpp/ggml-impl.h\"\n#include \"llama.cpp/ggml-quants.h\"\n#include \"timer.hh\"\n\ntemplate <typename T>\nstruct DotProductImpl {\n  static_assert(sizeof(T) == -1, \"No associated type defined for this type.\");\n  using type = void;\n};\n\ntemplate <typename T>\nusing DotProductType = typename DotProductImpl<T>::type;\n\ntemplate <>\nstruct DotProductImpl<uint8_t> {\n  using type = uint32_t;\n};\ntemplate <>\nstruct DotProductImpl<int8_t> {\n  using type = int32_t;\n};\ntemplate <>\nstruct DotProductImpl<uint32_t> {\n  using type = uint32_t;\n};\ntemplate <>\nstruct DotProductImpl<int32_t> {\n  using type = int32_t;\n};\n\ntemplate <>\nstruct DotProductImpl<float> {\n  using type = float;\n};\n\nenum class Layout {\n  RowMajor,\n  ColumnMajor,\n  VNNIColumnMajor,\n};\n\ntemplate <typename T>\nstruct Mat {\n  int rows, cols;\n  size_t size() { return rows * cols; }\n  T* data;\n  size_t stride_in_bytes;\n\n  void* qdata = nullptr;\n  ggml_type q_type;\n  size_t q_stride;\n\n  Layout layout = Layout::RowMajor;\n\n  Mat() {};\n\n  Mat(int rows, int cols, Layout layout) : rows(rows), cols(cols), layout(layout) {\n    size_t total_size;\n    if (layout == Layout::RowMajor) {\n      stride_in_bytes = cols * sizeof(T);\n      stride_in_bytes = (stride_in_bytes + 63) / 64 * 64;\n      total_size = stride_in_bytes * rows;\n    } else if (layout == Layout::ColumnMajor) {\n      stride_in_bytes = rows * sizeof(T);\n      stride_in_bytes = (stride_in_bytes + 63) / 64 * 64;\n      total_size = stride_in_bytes * cols;\n    } else {\n      assert(0);\n    }\n\n    // data = new(std::align_val_t(64)) T[rows * cols];\n    data = reinterpret_cast<T*>(aligned_alloc(64, total_size));\n    memset(data, 0, total_size);\n  }\n\n  Mat<T> sub_mat(int r, int c) {\n  
  Mat<T> re;\n    re.rows = r;\n    re.cols = c;\n    re.data = data;\n    re.layout = layout;\n    re.stride_in_bytes = stride_in_bytes;\n    re.qdata = qdata;\n    re.q_stride = q_stride;\n    re.q_type = q_type;\n    // The view shares storage with this matrix; only the owner should call dealloc().\n    return re;\n  }\n\n  void dealloc() {\n    // data comes from aligned_alloc(), so it must be released with free().\n    free(data);\n    if (qdata) {\n      // qdata comes from new (std::align_val_t(512)) char[...]; use the matching aligned delete.\n      ::operator delete[](qdata, std::align_val_t(512));\n    }\n  }\n\n  void row_major_increase() {\n    int x = 0;\n    for (int i = 0; i < rows; i++) {\n      for (int j = 0; j < cols; j++) {\n        at(i, j) = x++;\n      }\n    }\n  }\n\n  void dis_to_00() {\n    for (int i = 0; i < rows; i++) {\n      for (int j = 0; j < cols; j++) {\n        at(i, j) = i + j;\n      }\n    }\n  }\n\n  void random(std::mt19937& gen) {\n    if constexpr (std::is_integral_v<T>) {\n      std::uniform_int_distribution<T> dist(0, 100);\n      for (int i = 0; i < rows; i++) {\n        for (int j = 0; j < cols; j++) {\n          at(i, j) = dist(gen);\n        }\n      }\n    } else if constexpr (std::is_floating_point_v<T>) {\n      std::uniform_real_distribution<T> dist(-1.0, 1.0);\n      for (int i = 0; i < rows; i++) {\n        std::mt19937 gen_row(gen());\n        for (int j = 0; j < cols; j++) {\n          at(i, j) = dist(gen_row);\n        }\n      }\n    } else {\n      throw std::runtime_error(\"Unsupported type\");\n    }\n  }\n\n  size_t stride() { return stride_in_bytes; }\n\n  int line_element_count() {\n    if (layout == Layout::RowMajor) {\n      return cols;\n    } else if (layout == Layout::ColumnMajor) {\n      return rows;\n    } else {\n      assert(0);\n    }\n    assert(0);\n    return 0;\n  }\n\n  T& at(int r, int c) {\n    switch (layout) {\n      case Layout::RowMajor:\n        return *offset_pointer_row_major(data, r, c, stride());\n      case Layout::ColumnMajor:\n        return *offset_pointer_col_major(data, r, c, stride());\n      // case Layout::VNNIColumnMajor:\n      // return data[c*rows+r];\n      default: {\n        assert(0);\n      }\n    }\n    throw 
std::runtime_error(\"Unsupported layout\");\n    // assert(0);\n  }\n\n  void print() {\n    int limit = 10;      // truncate output when a dimension exceeds this threshold\n    int print_rows = 3;  // number of rows and columns printed at the start and at the end\n\n    for (int i = 0; i < rows; i++) {\n      // When there are too many rows, skip the middle rows\n      if (rows > limit && (i >= print_rows && i < rows - print_rows)) {\n        if (i == print_rows) {\n          std::cout << \"...\\n...\\n\";\n        }\n        continue;\n      }\n\n      for (int j = 0; j < cols; j++) {\n        // When there are too many columns, skip the middle columns\n        if (cols > limit && (j >= print_rows && j < cols - print_rows)) {\n          if (j == print_rows) {\n            std::cout << \"... \";\n          }\n          continue;\n        }\n\n        if constexpr (std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>) {\n          std::cout << (int)at(i, j) << \" \";\n        } else {\n          std::cout << at(i, j) << \" \";\n        }\n      }\n      std::cout << std::endl;\n    }\n    std::cout << std::endl;\n  }\n\n  void print_all() {\n    for (int i = 0; i < rows; i++) {\n      for (int j = 0; j < cols; j++) {\n        if constexpr (std::is_same_v<T, uint8_t> || std::is_same_v<T, int8_t>) {\n          std::cout << (int)at(i, j) << \" \";\n        } else if constexpr (std::is_floating_point_v<T>) {\n          // std::cout << std::setw(6) << std::scientific << std::setprecision(2) << at(i, j) << \"  \";\n          printf(\"%6.2f \", at(i, j));\n        } else {\n          std::cout << at(i, j) << \" \";\n        }\n      }\n      std::cout << std::endl;\n    }\n    std::cout << std::endl;\n  }\n\n  Mat<DotProductType<T>> mul_check(Mat<T>& b) {\n    assert(cols == b.rows);\n    Mat<DotProductType<T>> c(rows, b.cols, Layout::RowMajor);\n    for (int i = 0; i < rows; i++) {\n      for (int j = 0; j < b.cols; j++) {\n        c.at(i, j) = 0;\n        for (int k = 0; k < cols; k++) {\n          c.at(i, j) += static_cast<DotProductType<T>>(at(i, k)) * static_cast<DotProductType<T>>(b.at(k, j));\n        }\n      }\n    }\n    return c;\n  
}\n\n  bool cmp(Mat<T>& b) {\n    if constexpr (std::is_integral_v<T>) {\n      assert(rows == b.rows && cols == b.cols);\n      for (int i = 0; i < rows; i++) {\n        for (int j = 0; j < cols; j++) {\n          if (at(i, j) != b.at(i, j)) {\n            std::cout << \"Error at \" << i << \" \" << j << \" \" << at(i, j) << \", \" << b.at(i, j) << std::endl;\n            // std::cout << \"Error at \" << i << \" \" << j << std::endl;\n            // std::cout << \"Other: \" << b.at(i, j) << std::endl;\n            // std::cout << \"Me: \" << at(i, j) << std::endl;\n            // assert(0);\n            // break;\n            // return false;\n          }\n        }\n      }\n      std::cout << \"Check passed\" << std::endl;\n      return true;\n    }\n\n    if constexpr (std::is_floating_point_v<T>) {\n      T rel_error_sum = 0;\n      T error_sum = 0;\n      T max_error = 0;\n      T max_rel_error = 0;\n      int max_i = 0, max_j = 0;\n      assert(rows == b.rows && cols == b.cols);\n      for (int i = 0; i < rows; i++) {\n        for (int j = 0; j < cols; j++) {\n          T error = std::abs(at(i, j) - b.at(i, j));\n          error_sum += error;\n          rel_error_sum += error / std::abs(at(i, j));\n          if (error / std::abs(at(i, j)) > max_rel_error) {\n            max_rel_error = error / std::abs(at(i, j));\n          }\n          if (error > max_error) {\n            max_i = i;\n            max_j = j;\n            max_error = error;\n          }\n        }\n      }\n      if (rel_error_sum / size() > 1e-2 || max_error / at(max_i, max_j) > 1e-2) {\n        std::cout << \"Max Error: \" << std::fixed << max_error << \"(\" << max_error / at(max_i, max_j) << \")\"\n                  << \" at \" << max_i << \" \" << max_j << \", Max Rel Error \" << max_rel_error\n                  << \", Average Relative: \" << rel_error_sum / size() << \", Average Error: \" << error_sum / size()\n                  << std::endl;\n      } else {\n        std::cout << \"Error 
Less Than 1%\" << std::endl;\n      }\n\n      return true;\n    }\n  }\n\n  void quant(ggml_type to) {\n    if constexpr (std::is_same<T, float>::value == false) {\n      throw std::runtime_error(\"Quantization only supported for f32 matrices\");\n    }\n    // Timer t(std::string(\"to \") + ggml_type_name(to));\n    assert(line_element_count() * sizeof(T) == stride());\n    assert(line_element_count() % ggml_blck_size(to) == 0);\n    int blck_cnt_per_row = line_element_count() / ggml_blck_size(to);\n    q_stride = blck_cnt_per_row * ggml_type_size(to);\n\n    size_t qdata_size = size() * ggml_type_size(to) / ggml_blck_size(to);\n    qdata_size += 512 - q_stride % 512;\n\n    qdata = new (std::align_val_t(512)) char[qdata_size];\n    q_type = to;\n\n    switch (to) {\n      case GGML_TYPE_F32: {\n        return;\n      }\n      case GGML_TYPE_F16: {\n        ggml_fp32_to_fp16_row(data, reinterpret_cast<ggml_fp16_t*>(qdata), size());\n        return;\n      }\n      case GGML_TYPE_BF16: {\n        ggml_fp32_to_bf16_row(data, reinterpret_cast<ggml_bf16_t*>(qdata), size());\n        return;\n      }\n      case GGML_TYPE_Q4_0: {\n        quantize_row_q4_0(data, reinterpret_cast<block_q4_0*>(qdata), size());\n        return;\n      }\n      case GGML_TYPE_Q4_1: {\n        quantize_row_q4_1(data, reinterpret_cast<block_q4_1*>(qdata), size());\n        return;\n      }\n      case GGML_TYPE_Q5_0: {\n        quantize_row_q5_0(data, reinterpret_cast<block_q5_0*>(qdata), size());\n        return;\n      }\n      case GGML_TYPE_Q5_1: {\n        quantize_row_q5_1(data, reinterpret_cast<block_q5_1*>(qdata), size());\n        return;\n      }\n      case GGML_TYPE_Q8_0: {\n        quantize_row_q8_0(data, reinterpret_cast<block_q8_0*>(qdata), size());\n        return;\n      }\n      case GGML_TYPE_Q8_1: {\n        quantize_row_q8_1(data, reinterpret_cast<block_q8_1*>(qdata), size());\n        return;\n      }\n      case GGML_TYPE_Q2_K: {\n        quantize_row_q2_K(data, 
reinterpret_cast<block_q2_K*>(qdata), size());\n        return;\n      }\n      case GGML_TYPE_Q3_K: {\n        quantize_row_q3_K(data, reinterpret_cast<block_q3_K*>(qdata), size());\n        return;\n      }\n      case GGML_TYPE_Q4_K: {\n        quantize_row_q4_K(data, reinterpret_cast<block_q4_K*>(qdata), size());\n        return;\n      }\n      case GGML_TYPE_Q5_K: {\n        quantize_row_q5_K(data, reinterpret_cast<block_q5_K*>(qdata), size());\n        return;\n      }\n      case GGML_TYPE_Q6_K: {\n        quantize_row_q6_K(data, reinterpret_cast<block_q6_K*>(qdata), size());\n        return;\n      }\n      case GGML_TYPE_Q8_K: {\n        quantize_row_q8_K(data, reinterpret_cast<block_q8_K*>(qdata), size());\n        return;\n      }\n      case GGML_TYPE_IQ2_XXS:\n      case GGML_TYPE_IQ2_XS:\n      case GGML_TYPE_IQ3_XXS:\n      case GGML_TYPE_IQ1_S:\n      case GGML_TYPE_IQ4_NL:\n      case GGML_TYPE_IQ3_S:\n      case GGML_TYPE_IQ2_S:\n      case GGML_TYPE_IQ4_XS:\n      case GGML_TYPE_I8:\n      case GGML_TYPE_I16:\n      case GGML_TYPE_I32:\n      case GGML_TYPE_I64:\n      case GGML_TYPE_F64:\n      case GGML_TYPE_IQ1_M:\n      case GGML_TYPE_COUNT:\n      default:\n        throw std::runtime_error(\"Unsupported quantization type\");\n    }\n    throw std::runtime_error(\"Unsupported quantization type\");\n  }\n\n  template <typename Block>\n  Block* quant_data() {\n    return reinterpret_cast<Block*>(qdata);\n  }\n\n  void dequant() {\n    auto x = q_type;\n    switch (x) {\n      case GGML_TYPE_F32: {\n        return;\n      }\n      case GGML_TYPE_F16: {\n        ggml_fp16_to_fp32_row(reinterpret_cast<ggml_fp16_t*>(qdata), data, size());\n        return;\n      }\n      case GGML_TYPE_Q4_0: {\n        dequantize_row_q4_0(reinterpret_cast<block_q4_0*>(qdata), data, size());\n        return;\n      }\n      case GGML_TYPE_Q4_1: {\n        dequantize_row_q4_1(reinterpret_cast<block_q4_1*>(qdata), data, size());\n\n        return;\n      }\n      
case GGML_TYPE_Q5_0: {\n        dequantize_row_q5_0(reinterpret_cast<block_q5_0*>(qdata), data, size());\n        return;\n      }\n      case GGML_TYPE_Q5_1: {\n        dequantize_row_q5_1(reinterpret_cast<block_q5_1*>(qdata), data, size());\n        return;\n      }\n      case GGML_TYPE_Q8_0: {\n        dequantize_row_q8_0(reinterpret_cast<block_q8_0*>(qdata), data, size());\n        return;\n      }\n      case GGML_TYPE_Q8_1: {\n        throw std::runtime_error(\"not supported\");\n      }\n      case GGML_TYPE_Q2_K: {\n        dequantize_row_q2_K(reinterpret_cast<block_q2_K*>(qdata), data, size());\n        return;\n      }\n      case GGML_TYPE_Q3_K: {\n        dequantize_row_q3_K(reinterpret_cast<block_q3_K*>(qdata), data, size());\n        return;\n      }\n      case GGML_TYPE_Q4_K: {\n        dequantize_row_q4_K(reinterpret_cast<block_q4_K*>(qdata), data, size());\n        return;\n      }\n      case GGML_TYPE_Q5_K: {\n        dequantize_row_q5_K(reinterpret_cast<block_q5_K*>(qdata), data, size());\n        return;\n      }\n      case GGML_TYPE_Q6_K: {\n        dequantize_row_q6_K(reinterpret_cast<block_q6_K*>(qdata), data, size());\n        return;\n      }\n      case GGML_TYPE_Q8_K: {\n        dequantize_row_q8_K(reinterpret_cast<block_q8_K*>(qdata), data, size());\n        return;\n      }\n      case GGML_TYPE_BF16: {\n        ggml_bf16_to_fp32_row(reinterpret_cast<ggml_bf16_t*>(qdata), data, size());\n        return;\n      }\n      // The types below are not supported; they must not fall through into the BF16 path.\n      case GGML_TYPE_IQ2_XXS:\n      case GGML_TYPE_IQ2_XS:\n      case GGML_TYPE_IQ3_XXS:\n      case GGML_TYPE_IQ1_S:\n      case GGML_TYPE_IQ4_NL:\n      case GGML_TYPE_IQ3_S:\n      case GGML_TYPE_IQ2_S:\n      case GGML_TYPE_IQ4_XS:\n      case GGML_TYPE_I8:\n      case GGML_TYPE_I16:\n      case GGML_TYPE_I32:\n      case GGML_TYPE_I64:\n      case GGML_TYPE_F64:\n      case GGML_TYPE_IQ1_M:\n      case GGML_TYPE_COUNT:\n      default:\n        throw std::runtime_error(\"Unsupported quantization type\");\n    }\n    
throw std::runtime_error(\"Unsupported quantization type\");\n  }\n};\n\ninline void init() {\n  struct ggml_init_params params = {\n      0,\n      NULL,\n      true,\n  };\n\n  auto ctx_eval = ggml_init(params);\n\n  if (!ctx_eval) {\n    throw std::runtime_error(\"Failed to create ggml context\");\n  }\n}\n#endif"
  },
  {
    "path": "kt-kernel/operators/amx/test/mmq-test.cpp",
    "content": "\n#if defined(__GNUC__)\n#pragma GCC diagnostic ignored \"-Wpedantic\"\n#pragma GCC diagnostic ignored \"-Wunused-local-typedefs\"\n#endif\n\n#include \"mmq.h\"\n\n#include <algorithm>\n#include <type_traits>\n\n#include \"ggml-impl.h\"\n#include \"ggml-quants.h\"\n#include \"mat-test.hpp\"\n\n#if defined(__gnu_linux__)\n#include <sys/syscall.h>\n#include <unistd.h>\n#endif\n\n#if defined(_OPENMP)\n#include <omp.h>\n#endif\n\n#if (defined(_WIN32) || defined(_WIN64))\n#define RESTRICT __restrict\n#else\n#define RESTRICT __restrict__\n#endif\n\n#if (defined(_WIN32) || defined(_WIN64))\n#define ALWAYS_INLINE __forceinline\n#elif __has_attribute(always_inline) || defined(__GNUC__)\n#define ALWAYS_INLINE __attribute__((__always_inline__)) inline\n#else\n#define ALWAYS_INLINE inline\n#endif\n\n#if defined(__AMX_INT8__)\n\nnamespace {\n\n#define TILE_M 16\n#define TILE_N 16\n#define TILE_K 32\n#define VNNI_BLK 4\n\n#define AMX_BLK_SIZE 32\n\n#define TMM0 0\n#define TMM1 1\n#define TMM2 2\n#define TMM3 3\n#define TMM4 4\n#define TMM5 5\n#define TMM6 6\n#define TMM7 7\n\n// parallel routines\n// template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0> inline T div_up(T x, T y) {\n//   return (x + y - 1) / y;\n// }\n\ntemplate <typename T>\nvoid balance211(T n, T nth, T ith, T& n_start, T& n_end) {\n#if 0\n  // onednn partition pattern\n  T& n_my = n_end;\n  if (nth <= 1 || n == 0) {\n    n_start = 0;\n    n_my = n;\n  } else {\n    T n1 = div_up(n, nth);\n    T n2 = n1 - 1;\n    T T1 = n - n2 * nth;\n    n_my = ith < T1 ? n1 : n2;\n    n_start = ith <= T1 ? 
ith*n1 : T1 * n1 + (ith - T1) * n2;\n  }\n  n_end += n_start;\n#else\n  // pytorch aten partition pattern\n  T n_my = div_up(n, nth);\n  n_start = ith * n_my;\n  n_end = std::min(n_start + n_my, n);\n#endif\n}\n\ntemplate <typename func_t>\ninline void parallel_for(int nth, int ith, int n, const func_t& f) {\n  // int nth = omp_get_num_threads();\n  // int ith = omp_get_thread_num();\n  int tbegin, tend;\n  balance211(n, nth, ith, tbegin, tend);\n  f(tbegin, tend);\n}\n\n// Forced unrolling\ntemplate <int n>\nstruct Unroll {\n  template <typename Func, typename... Args>\n  ALWAYS_INLINE void operator()(const Func& f, Args... args) const {\n    Unroll<n - 1>{}(f, args...);\n    f(std::integral_constant<int, n - 1>{}, args...);\n  }\n};\n\ntemplate <>\nstruct Unroll<1> {\n  template <typename Func, typename... Args>\n  ALWAYS_INLINE void operator()(const Func& f, Args... args) const {\n    f(std::integral_constant<int, 0>{}, args...);\n  }\n};\n\n// type traits\ntemplate <typename T>\nstruct PackedTypes {};\ntemplate <>\nstruct PackedTypes<block_q4_0> {\n  using type = int8_t;\n};\ntemplate <>\nstruct PackedTypes<block_q4_1> {\n  using type = uint8_t;\n};\ntemplate <>\nstruct PackedTypes<block_q8_0> {\n  using type = int8_t;\n};\ntemplate <typename T>\nusing packed_B_type = typename PackedTypes<T>::type;\n\ntemplate <typename T>\nstruct do_compensate : std::integral_constant<bool, std::is_same<T, block_q8_0>::value> {};\n\ntemplate <typename T>\nstruct do_unpack\n    : std::integral_constant<bool, std::is_same<T, block_q4_0>::value || std::is_same<T, block_q4_1>::value> {};\n\ntemplate <typename T>\nstruct is_type_qkk\n    : std::integral_constant<bool, std::is_same<T, block_q4_K>::value || std::is_same<T, block_q5_K>::value ||\n                                       std::is_same<T, block_q6_K>::value || std::is_same<T, block_iq4_xs>::value> {};\n\n#define GGML_DISPATCH_FLOATING_TYPES(TYPE, ...)              
\\\n  [&] {                                                      \\\n    switch (TYPE) {                                          \\\n      case GGML_TYPE_F16: {                                  \\\n        using type = ggml_fp16_t;                            \\\n        constexpr int blck_size = 16;                        \\\n        return __VA_ARGS__();                                \\\n      }                                                      \\\n      case GGML_TYPE_BF16: {                                 \\\n        using type = ggml_bf16_t;                            \\\n        constexpr int blck_size = 32;                        \\\n        return __VA_ARGS__();                                \\\n      }                                                      \\\n      default:                                               \\\n        fprintf(stderr, \"Unsupported floating data type\\n\"); \\\n    }                                                        \\\n  }()\n\n#define GGML_DISPATCH_QTYPES(QT, ...)                         
\\\n  [&] {                                                       \\\n    switch (QT) {                                             \\\n      case GGML_TYPE_Q4_0: {                                  \\\n        using type = block_q4_0;                              \\\n        using vec_dot_type = block_q8_0;                      \\\n        constexpr int blck_size = QK4_0;                      \\\n        return __VA_ARGS__();                                 \\\n      }                                                       \\\n      case GGML_TYPE_Q4_1: {                                  \\\n        using type = block_q4_1;                              \\\n        using vec_dot_type = block_q8_1;                      \\\n        constexpr int blck_size = QK4_1;                      \\\n        return __VA_ARGS__();                                 \\\n      }                                                       \\\n      case GGML_TYPE_Q8_0: {                                  \\\n        using type = block_q8_0;                              \\\n        using vec_dot_type = block_q8_0;                      \\\n        constexpr int blck_size = QK8_0;                      \\\n        return __VA_ARGS__();                                 \\\n      }                                                       \\\n      case GGML_TYPE_Q4_K: {                                  \\\n        using type = block_q4_K;                              \\\n        using vec_dot_type = block_q8_K;                      \\\n        constexpr int blck_size = QK_K;                       \\\n        return __VA_ARGS__();                                 \\\n      }                                                       \\\n      case GGML_TYPE_Q5_K: {                                  \\\n        using type = block_q5_K;                              \\\n        using vec_dot_type = block_q8_K;                      \\\n        constexpr int blck_size = QK_K;                       \\\n        return 
__VA_ARGS__();                                 \\\n      }                                                       \\\n      case GGML_TYPE_Q6_K: {                                  \\\n        using type = block_q6_K;                              \\\n        using vec_dot_type = block_q8_K;                      \\\n        constexpr int blck_size = QK_K;                       \\\n        return __VA_ARGS__();                                 \\\n      }                                                       \\\n      case GGML_TYPE_IQ4_XS: {                                \\\n        using type = block_iq4_xs;                            \\\n        using vec_dot_type = block_q8_K;                      \\\n        constexpr int blck_size = QK_K;                       \\\n        return __VA_ARGS__();                                 \\\n      }                                                       \\\n      default:                                                \\\n        fprintf(stderr, \"Unsupported quantized data type\\n\"); \\\n    }                                                         \\\n  }()\n\n#define GGML_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...) 
\\\n  [&] {                                            \\\n    if (BOOL_V) {                                  \\\n      constexpr bool BOOL_NAME = true;             \\\n      return __VA_ARGS__();                        \\\n    } else {                                       \\\n      constexpr bool BOOL_NAME = false;            \\\n      return __VA_ARGS__();                        \\\n    }                                              \\\n  }()\n\n// define amx tile config data structure\nstruct tile_config_t {\n  uint8_t palette_id = 0;\n  uint8_t start_row = 0;\n  uint8_t reserved_0[14] = {0};\n  uint16_t colsb[16] = {0};\n  uint8_t rows[16] = {0};\n};\n\n// Notes: amx tile config\n//\n// Typically, TMUL multiplies A and B tiles of size 16 x 64 containing INT8 values\n// and accumulates the result into a 16 x 16 matrix C containing INT32 values.\n//\n// Since many GGUF quantized types have a `block_size` of 32, a 16-16-32 config is used\n// instead of the usual 16-16-64 config.\n//\n//   Block A: {16, 32}, dtype = int8_t\n//   Block B: {16, 32}, dtype = uint8_t/int8_t\n//   Block C: {16, 16}, dtype = int32_t\n//\n// Block B needs to be prepacked to vnni format before feeding into TMUL:\n//   packed_B: from {n, k} to {k/vnni_blk, n, vnni_blk}, viewed in 2d, we get {8, 64}\n//\n// Therefore, we get tileconfig (colsb is in bytes, so C is 16 x sizeof(int32_t) = 64):\n//             A    B    C\n//    rows    16    8   16\n//    colsb   32   64   64\n//\n// For tile distribution, follow a 2-2-4 pattern, e.g. 
A used TMM2-TMM3, B used TMM0-TMM1,\n// C used TMM4-TMM7:\n//            B TMM0  B TMM1\n//    A TMM2  C TMM4  C TMM6\n//    A TMM3  C TMM5  C TMM7\n//\n// Each `amx` kernel handles 4 blocks at a time: 2MB * 2NB, when m < 2 * BLOCK_M, unpack A\n// will be needed.\n//\n// Here another commonly used pattern 1-3-3 is skipped, as it is mostly used when m <= 16;\n// and the single batch gemm (m=1) has a special fast path with `avx512-vnni`.\n//\n// ref: https://www.intel.com/content/www/us/en/developer/articles/code-sample/\n//   advanced-matrix-extensions-intrinsics-functions.html\n//\n\n#define TC_CONFIG_TILE(i, r, cb) \\\n  tc.rows[i] = r;                \\\n  tc.colsb[i] = cb\nvoid ggml_tile_config_init(void) {\n  static thread_local tile_config_t tc;\n  tile_config_t current_tc;\n  _tile_storeconfig(&current_tc);\n\n  // load only when config changes\n  if (tc.palette_id == 0 || (memcmp(&current_tc.colsb, &tc.colsb, sizeof(uint16_t) * 8) != 0 &&\n                             memcmp(&current_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) {\n    tc.palette_id = 1;\n    tc.start_row = 0;\n    TC_CONFIG_TILE(TMM0, 8, 64);\n    TC_CONFIG_TILE(TMM1, 8, 64);\n    TC_CONFIG_TILE(TMM2, 16, 32);\n    TC_CONFIG_TILE(TMM3, 16, 32);\n    TC_CONFIG_TILE(TMM4, 16, 64);\n    TC_CONFIG_TILE(TMM5, 16, 64);\n    TC_CONFIG_TILE(TMM6, 16, 64);\n    TC_CONFIG_TILE(TMM7, 16, 64);\n    _tile_loadconfig(&tc);\n  }\n}\n\n// we need an extra 16 * 4B (TILE_N * int32_t) for each NB/KB block for compensation.\n// See the notes `s8s8 igemm compensation in avx512-vnni` for detail.\ntemplate <typename TB>\nint get_tile_size() {\n  int tile_size = TILE_N * sizeof(TB);\n  if (do_compensate<TB>::value) {\n    tile_size += TILE_N * sizeof(int32_t);\n  }\n  if (std::is_same<TB, block_q4_K>::value || std::is_same<TB, block_q5_K>::value) {\n    tile_size += TILE_N * 4;\n  }\n  if (std::is_same<TB, block_iq4_xs>::value) {\n    tile_size += TILE_N * 2;\n  }\n  return tile_size;\n}\n\ntemplate <typename TB, 
int BLOCK_K>\nint get_row_size(int K) {\n  int KB = K / BLOCK_K;\n  int row_size = KB * sizeof(TB);\n  if (do_compensate<TB>::value) {\n    row_size += KB * sizeof(int32_t);\n  }\n  if (std::is_same<TB, block_q4_K>::value || std::is_same<TB, block_q5_K>::value) {\n    row_size += KB * 4;\n  }\n  if (std::is_same<TB, block_iq4_xs>::value) {\n    row_size += KB * 2;\n  }\n  return row_size;\n}\n\n// vectorized dtype conversion\ninline float FP16_TO_FP32(ggml_half val) {\n  __m256i v = _mm256_setr_epi16(val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);\n  __m512 o = _mm512_cvtph_ps(v);\n  return _mm512_cvtss_f32(o);\n}\n\ninline __m512 FP16_TO_FP32_VEC(ggml_half val) {\n  __m256i v = _mm256_set1_epi16(val);\n  return _mm512_cvtph_ps(v);\n}\n\n// horizontal reduce\ninline float _mm512_reduce_max_ps(const __m512 x) {\n  __m512 v = x;\n  __m512 v1 = _mm512_shuffle_f32x4(v, v, 0x4E);\n  v = _mm512_max_ps(v, v1);\n  v1 = _mm512_shuffle_f32x4(v, v, 0xB1);\n  v = _mm512_max_ps(v, v1);\n  v1 = _mm512_shuffle_ps(v, v, 0x4E);\n  v = _mm512_max_ps(v, v1);\n  v1 = _mm512_shuffle_ps(v, v, 0xB1);\n  v = _mm512_max_ps(v, v1);\n  return _mm512_cvtss_f32(v);\n}\n\n// transpose utils\n#define SHUFFLE_EPI32(a, b, mask) \\\n  _mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), mask))\n\n// transpose 8x8 32-bit element from v to v1\ninline void transpose_8x8_32bit(__m256i* v, __m256i* v1) {\n  // unpacking and 32-bit elements\n  v1[0] = _mm256_unpacklo_epi32(v[0], v[1]);\n  v1[1] = _mm256_unpackhi_epi32(v[0], v[1]);\n  v1[2] = _mm256_unpacklo_epi32(v[2], v[3]);\n  v1[3] = _mm256_unpackhi_epi32(v[2], v[3]);\n  v1[4] = _mm256_unpacklo_epi32(v[4], v[5]);\n  v1[5] = _mm256_unpackhi_epi32(v[4], v[5]);\n  v1[6] = _mm256_unpacklo_epi32(v[6], v[7]);\n  v1[7] = _mm256_unpackhi_epi32(v[6], v[7]);\n\n  // shuffling the 32-bit elements\n  v[0] = SHUFFLE_EPI32(v1[0], v1[2], 0x44);\n  v[1] = SHUFFLE_EPI32(v1[0], v1[2], 0xee);\n  v[2] = SHUFFLE_EPI32(v1[4], v1[6], 
0x44);\n  v[3] = SHUFFLE_EPI32(v1[4], v1[6], 0xee);\n  v[4] = SHUFFLE_EPI32(v1[1], v1[3], 0x44);\n  v[5] = SHUFFLE_EPI32(v1[1], v1[3], 0xee);\n  v[6] = SHUFFLE_EPI32(v1[5], v1[7], 0x44);\n  v[7] = SHUFFLE_EPI32(v1[5], v1[7], 0xee);\n\n  // shuffling 128-bit elements\n  v1[0] = _mm256_permute2f128_si256(v[2], v[0], 0x02);\n  v1[1] = _mm256_permute2f128_si256(v[3], v[1], 0x02);\n  v1[2] = _mm256_permute2f128_si256(v[6], v[4], 0x02);\n  v1[3] = _mm256_permute2f128_si256(v[7], v[5], 0x02);\n  v1[4] = _mm256_permute2f128_si256(v[2], v[0], 0x13);\n  v1[5] = _mm256_permute2f128_si256(v[3], v[1], 0x13);\n  v1[6] = _mm256_permute2f128_si256(v[6], v[4], 0x13);\n  v1[7] = _mm256_permute2f128_si256(v[7], v[5], 0x13);\n}\n\n// transpose 16x4 32-bit element to 4x16 from r to d\ninline void transpose_16x4_32bit(__m512i* r, __m512i* d) {\n  static const __m512i index1 =\n      _mm512_set_epi32(0x0f, 0x0b, 0x07, 0x03, 0x0e, 0x0a, 0x06, 0x02, 0x0d, 0x09, 0x05, 0x01, 0x0c, 0x08, 0x04, 0x00);\n\n  d[0] = _mm512_permutexvar_epi32(index1, r[0]);\n  d[1] = _mm512_permutexvar_epi32(index1, r[1]);\n  d[2] = _mm512_permutexvar_epi32(index1, r[2]);\n  d[3] = _mm512_permutexvar_epi32(index1, r[3]);\n\n  r[0] = _mm512_shuffle_i32x4(d[0], d[1], 0x44);\n  r[1] = _mm512_shuffle_i32x4(d[0], d[1], 0xee);\n  r[2] = _mm512_shuffle_i32x4(d[2], d[3], 0x44);\n  r[3] = _mm512_shuffle_i32x4(d[2], d[3], 0xee);\n\n  d[0] = _mm512_shuffle_i32x4(r[0], r[2], 0x88);\n  d[1] = _mm512_shuffle_i32x4(r[0], r[2], 0xdd);\n  d[2] = _mm512_shuffle_i32x4(r[1], r[3], 0x88);\n  d[3] = _mm512_shuffle_i32x4(r[1], r[3], 0xdd);\n}\n\n// transpose 16x16 32-bit element in place\ninline void transpose_16x16_32bit(__m512i* v) {\n  __m512i v1[16];\n  v1[0] = _mm512_unpacklo_epi32(v[0], v[1]);\n  v1[1] = _mm512_unpackhi_epi32(v[0], v[1]);\n  v1[2] = _mm512_unpacklo_epi32(v[2], v[3]);\n  v1[3] = _mm512_unpackhi_epi32(v[2], v[3]);\n  v1[4] = _mm512_unpacklo_epi32(v[4], v[5]);\n  v1[5] = _mm512_unpackhi_epi32(v[4], v[5]);\n  v1[6] = 
_mm512_unpacklo_epi32(v[6], v[7]);\n  v1[7] = _mm512_unpackhi_epi32(v[6], v[7]);\n  v1[8] = _mm512_unpacklo_epi32(v[8], v[9]);\n  v1[9] = _mm512_unpackhi_epi32(v[8], v[9]);\n  v1[10] = _mm512_unpacklo_epi32(v[10], v[11]);\n  v1[11] = _mm512_unpackhi_epi32(v[10], v[11]);\n  v1[12] = _mm512_unpacklo_epi32(v[12], v[13]);\n  v1[13] = _mm512_unpackhi_epi32(v[12], v[13]);\n  v1[14] = _mm512_unpacklo_epi32(v[14], v[15]);\n  v1[15] = _mm512_unpackhi_epi32(v[14], v[15]);\n\n  v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]);\n  v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]);\n  v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]);\n  v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]);\n  v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]);\n  v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]);\n  v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]);\n  v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]);\n  v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]);\n  v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]);\n  v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]);\n  v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]);\n  v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]);\n  v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]);\n  v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]);\n  v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]);\n\n  v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88);\n  v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88);\n  v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88);\n  v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88);\n  v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd);\n  v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd);\n  v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd);\n  v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd);\n  v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88);\n  v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88);\n  v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88);\n  v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88);\n  v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd);\n  v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd);\n  v1[14] = 
_mm512_shuffle_i32x4(v[10], v[14], 0xdd);\n  v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd);\n\n  v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88);\n  v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88);\n  v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88);\n  v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88);\n  v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88);\n  v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88);\n  v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88);\n  v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88);\n  v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd);\n  v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd);\n  v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd);\n  v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd);\n  v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd);\n  v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd);\n  v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd);\n  v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd);\n}\n\nvoid quantize_row_q8_K_vnni(const float* RESTRICT x, void* RESTRICT vy, int64_t k) {\n  assert(k % QK_K == 0);\n  const int KB = k / QK_K;\n  constexpr int kVecs = QK_K / 16;\n\n  block_q8_K* y = reinterpret_cast<block_q8_K*>(vy);\n\n  // hold 16 float vecs from x\n  __m512 v[kVecs];\n\n  // hold the quants vecs\n  __m512i vq[kVecs / 4];\n\n  // hold the packed quants vecs\n  __m512i vq_packed[kVecs / 4];\n\n  const __m512 signBit = _mm512_set1_ps(-0.f);\n\n  for (int i = 0; i < KB; ++i) {\n    // Compute max(abs(e)) for the block\n    __m512 vamax = _mm512_set1_ps(0.f);\n    for (int j = 0; j < kVecs; ++j) {\n      v[j] = _mm512_loadu_ps(x);\n      x += 16;\n      vamax = _mm512_max_ps(vamax, _mm512_andnot_ps(signBit, v[j]));\n    }\n    const float amax = _mm512_reduce_max_ps(vamax);\n\n    // Quantize these floats\n    const float iscale = 127.f / amax;\n    y[i].d = GGML_FP32_TO_FP16(1 / iscale);\n    const float id = (amax != 0.0f) ? 
iscale : 0.f;\n    const __m512 vscale = _mm512_set1_ps(id);\n\n    // Apply multiplier and round to nearest integer\n    for (int j = 0; j < kVecs; ++j) {\n      v[j] = _mm512_mul_ps(v[j], vscale);\n      v[j] = _mm512_roundscale_ps(v[j], (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\n    }\n\n    // Pack to epi8 vecs\n    for (int j = 0; j < kVecs / 4; ++j) {\n      __m128i q8_0 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 0]));\n      __m128i q8_1 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 1]));\n      __m128i q8_2 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 2]));\n      __m128i q8_3 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 3]));\n\n      __m256i q8_01 = _mm256_insertf128_si256(_mm256_castsi128_si256(q8_0), (q8_1), 1);\n      __m256i q8_23 = _mm256_insertf128_si256(_mm256_castsi128_si256(q8_2), (q8_3), 1);\n\n      vq[j] = _mm512_inserti32x8(_mm512_castsi256_si512(q8_01), q8_23, 1);\n      _mm512_storeu_si512((__m512i*)(y[i].qs + j * 64), vq[j]);\n    }\n\n    // Compute the bsums with vnni\n    transpose_16x4_32bit(vq, vq_packed);\n\n    const __m512i one = _mm512_set1_epi8(1);\n    __m512i sum = _mm512_setzero_si512();\n    for (int k = 0; k < 4; ++k) {\n      sum = _mm512_dpbusd_epi32(sum, one, vq_packed[k]);\n    }\n    _mm256_storeu_si256((__m256i*)(y[i].bsums), _mm512_cvtepi32_epi16(sum));\n  }\n}\n\n// quantize A from float to `vec_dot_type`\ntemplate <typename T>\ninline void from_float(const float* x, char* vy, int64_t k);\n\ntemplate <>\ninline void from_float<block_q8_0>(const float* x, char* vy, int64_t k) {\n  quantize_row_q8_0(x, vy, k);\n}\n\ntemplate <>\ninline void from_float<block_q8_1>(const float* x, char* vy, int64_t k) {\n  quantize_row_q8_1(x, vy, k);\n}\n\ntemplate <>\ninline void from_float<block_q8_K>(const float* x, char* vy, int64_t k) {\n#if 1\n  // TODO: this is reference impl!\n  quantize_row_q8_K(x, vy, k);\n#else\n  quantize_row_q8_K_vnni(x, vy, k);\n#endif\n}\n\n// load A from memory to 
array when nrows cannot fill a whole tile\nvoid unpack_A(int8_t* RESTRICT tile, const block_q8_0* RESTRICT A, int lda, int nr) {\n  assert(nr != TILE_M);\n  for (int m = 0; m < nr; ++m) {\n    const __m256i v = _mm256_loadu_si256((const __m256i*)(A[m * lda].qs));\n    _mm256_storeu_si256((__m256i*)(tile + m * TILE_K), v);\n  }\n}\n\nvoid unpack_A(int8_t* RESTRICT tile, const block_q8_1* RESTRICT A, int lda, int nr) {\n  assert(nr != TILE_M);\n  for (int m = 0; m < nr; ++m) {\n    const __m256i v = _mm256_loadu_si256((const __m256i*)(A[m * lda].qs));\n    _mm256_storeu_si256((__m256i*)(tile + m * TILE_K), v);\n  }\n}\n\ntemplate <typename TB>\nvoid unpack_A(int8_t* RESTRICT tile, const block_q8_K* RESTRICT A, int lda, int k, int nr) {\n  assert(nr <= TILE_M);\n  for (int m = 0; m < nr; ++m) {\n    const __m256i v = _mm256_loadu_si256((const __m256i*)(A[m * lda].qs + k * 32));\n    _mm256_storeu_si256((__m256i*)(tile + m * TILE_K), v);\n  }\n}\n\ntemplate <>\nvoid unpack_A<block_q6_K>(int8_t* RESTRICT tile, const block_q8_K* RESTRICT A, int lda, int k, int nr) {\n  assert(nr <= TILE_M);\n  // zero padding k from 16 to 32, so that we don't have to re-config amx\n  const __m128i zero = _mm_setzero_si128();\n  for (int m = 0; m < nr; ++m) {\n    const __m128i v = _mm_loadu_si128((const __m128i*)(A[m * lda].qs + k * 16));\n    const __m256i r = _mm256_insertf128_si256(_mm256_castsi128_si256(v), zero, 1);\n    _mm256_storeu_si256((__m256i*)(tile + m * TILE_K), r);\n  }\n}\n\n#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)\ninline __m256i bytes_from_nibbles_32(const uint8_t* rsi) {\n  const __m128i tmp = _mm_loadu_si128((const __m128i*)rsi);\n  const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);\n  const __m256i lowMask = _mm256_set1_epi8(0xF);\n  return _mm256_and_si256(lowMask, bytes);\n}\n\n// used for block_q4_K\ninline __m512i bytes_from_nibbles_64(const uint8_t* rsi) {\n  const __m256i tmp = 
_mm256_loadu_si256((const __m256i*)rsi);\n  const __m256i lowMask = _mm256_set1_epi8(0xF);\n  const __m256i q4l = _mm256_and_si256(tmp, lowMask);\n  const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(tmp, 4), lowMask);\n  return _mm512_inserti32x8(_mm512_castsi256_si512(q4l), q4h, 1);\n}\n\n// used for block_q5_K\ninline __m512i bytes_from_nibbles_64(const uint8_t* qs, const uint8_t* qh, int k) {\n  const __m256i lowMask = _mm256_set1_epi8(0xF);\n  __m256i hmask = _mm256_set1_epi8(1);\n  hmask = _mm256_slli_epi16(hmask, k);\n\n  const __m256i q5bits = _mm256_loadu_si256((const __m256i*)qs);\n  const __m256i hbits = _mm256_loadu_si256((const __m256i*)qh);\n\n  const __m256i q5l_0 = _mm256_and_si256(q5bits, lowMask);\n  const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), k + 0), 4);\n  const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0);\n  hmask = _mm256_slli_epi16(hmask, 1);\n\n  const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), lowMask);\n  const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), k + 1), 4);\n  const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1);\n\n  return _mm512_inserti32x8(_mm512_castsi256_si512(q5_0), q5_1, 1);\n}\n\n// used for block_q6_K\ninline void bytes_from_nibbles_128(__m512i& r0, __m512i& r1, const uint8_t* qs, const uint8_t* qh) {\n  const __m256i m4 = _mm256_set1_epi8(0xF);\n  const __m256i m2 = _mm256_set1_epi8(0x3);\n\n  const __m256i q6bits1 = _mm256_loadu_si256((const __m256i*)qs);\n  const __m256i q6bits2 = _mm256_loadu_si256((const __m256i*)(qs + 32));\n  const __m256i q6bitsH = _mm256_loadu_si256((const __m256i*)qh);\n\n  const __m256i q6h_0 = _mm256_slli_epi16(_mm256_and_si256(q6bitsH, m2), 4);\n  const __m256i q6h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 2), m2), 4);\n  const __m256i q6h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 4), m2), 4);\n  const __m256i q6h_3 = 
_mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 6), m2), 4);\n\n  const __m256i q6_0 = _mm256_or_si256(_mm256_and_si256(q6bits1, m4), q6h_0);\n  const __m256i q6_1 = _mm256_or_si256(_mm256_and_si256(q6bits2, m4), q6h_1);\n  const __m256i q6_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q6bits1, 4), m4), q6h_2);\n  const __m256i q6_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q6bits2, 4), m4), q6h_3);\n\n  r0 = _mm512_inserti32x8(_mm512_castsi256_si512(q6_0), q6_1, 1);\n  r1 = _mm512_inserti32x8(_mm512_castsi256_si512(q6_2), q6_3, 1);\n}\n\ninline __m512i packNibbles(__m512i r0, __m512i r1) { return _mm512_or_si512(r0, _mm512_slli_epi16(r1, 4)); }\n\ntemplate <typename TB>\ninline void pack_qs(void* RESTRICT packed_B, const TB* RESTRICT B, int KB) {\n  int8_t tmp[8 * 64];\n  __m256i v[8], v2[8];\n  for (int n = 0; n < 8; ++n) {\n    v[n] = bytes_from_nibbles_32(B[n * KB].qs);\n  }\n  transpose_8x8_32bit(v, v2);\n  for (int n = 0; n < 8; ++n) {\n    _mm256_storeu_si256((__m256i*)(tmp + n * 64), v2[n]);\n  }\n  for (int n = 0; n < 8; ++n) {\n    v[n] = bytes_from_nibbles_32(B[(n + 8) * KB].qs);\n  }\n  transpose_8x8_32bit(v, v2);\n  for (int n = 0; n < 8; ++n) {\n    _mm256_storeu_si256((__m256i*)(tmp + n * 64 + 32), v2[n]);\n  }\n\n  // pack again with 128 to fully utilize vector length\n  for (int n = 0; n < 8; n += 2) {\n    __m512i r0 = _mm512_loadu_si512((const __m512i*)(tmp + n * 64));\n    __m512i r1 = _mm512_loadu_si512((const __m512i*)(tmp + n * 64 + 64));\n    __m512i r1r0 = packNibbles(r0, r1);\n    _mm512_storeu_si512((__m512i*)((char*)packed_B + n * 32), r1r0);\n  }\n}\n\ntemplate <>\ninline void pack_qs<block_q8_0>(void* RESTRICT packed_B, const block_q8_0* RESTRICT B, int KB) {\n  __m256i v[8], v2[8];\n  for (int n = 0; n < 8; ++n) {\n    v[n] = _mm256_loadu_si256((const __m256i*)(B[n * KB].qs));\n  }\n  transpose_8x8_32bit(v, v2);\n  for (int n = 0; n < 8; ++n) {\n    _mm256_storeu_si256((__m256i*)((char*)packed_B + n * 
64), v2[n]);\n  }\n  for (int n = 0; n < 8; ++n) {\n    v[n] = _mm256_loadu_si256((const __m256i*)(B[(n + 8) * KB].qs));\n  }\n  transpose_8x8_32bit(v, v2);\n  for (int n = 0; n < 8; ++n) {\n    _mm256_storeu_si256((__m256i*)((char*)packed_B + n * 64 + 32), v2[n]);\n  }\n}\n\ntemplate <>\ninline void pack_qs<block_q4_K>(void* RESTRICT packed_B, const block_q4_K* RESTRICT B, int KB) {\n  __m512i v[16];\n  // QK_K 256 with 8 groups, handle 2 groups at a time\n  char* pb = (char*)packed_B;\n  for (int k = 0; k < QK_K / 64; ++k) {\n    // pack 2 groups { n, g,  k} to {g, k/4, 4n}\n    //          e.g. {16, 2, 32} to {2,   8, 64}\n    for (int n = 0; n < TILE_N; ++n) {\n      v[n] = bytes_from_nibbles_64(B[n * KB].qs + k * 32);\n    }\n\n    transpose_16x16_32bit(v);\n\n    // pack again with 128 to fully utilize vector length\n    for (int n = 0; n < TILE_N; n += 2) {\n      _mm512_storeu_si512((__m512i*)pb, packNibbles(v[n], v[n + 1]));\n      pb += 64;\n    }\n  }\n}\n\ntemplate <>\ninline void pack_qs<block_q5_K>(void* RESTRICT packed_B, const block_q5_K* RESTRICT B, int KB) {\n  __m512i v[16];\n  const __m512i lowMask = _mm512_set1_epi8(0xF);\n  // QK_K 256 with 8 groups, handle 2 groups at a time\n  char* pb = (char*)packed_B;\n  char* ph = (char*)packed_B + (QK_K / 2) * TILE_N;\n  for (int k = 0; k < QK_K / 64; ++k) {\n    // pack 2 groups { n, g,  k} to {g, k/4, 4n}\n    //          e.g. {16, 2, 32} to {2,   8, 64}\n    for (int n = 0; n < TILE_N; ++n) {\n      v[n] = bytes_from_nibbles_64(B[n * KB].qs + k * 32, B[n * KB].qh, /* group */ 2 * k);\n    }\n\n    transpose_16x16_32bit(v);\n\n    // 1. pack lower 4bits with 2 groups\n    for (int n = 0; n < TILE_N; n += 2) {\n      // get lower 4 bits\n      const __m512i r0 = _mm512_and_si512(v[n], lowMask);\n      const __m512i r1 = _mm512_and_si512(v[n + 1], lowMask);\n      _mm512_storeu_si512((__m512i*)pb, packNibbles(r0, r1));\n      pb += 64;\n    }\n\n    // 2. 
pack higher 1bit with 2 groups\n    const __m512i hmask = _mm512_set1_epi8(0x10);\n    for (int g = 0; g < 2; ++g) {\n      __m512i hbits = _mm512_setzero_si512();\n      hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 0], hmask), 4));\n      hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 1], hmask), 3));\n      hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 2], hmask), 2));\n      hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 3], hmask), 1));\n      hbits = _mm512_add_epi8(hbits, _mm512_and_si512(v[g * 8 + 4], hmask));\n      hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 5], hmask), 1));\n      hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 6], hmask), 2));\n      hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 7], hmask), 3));\n      _mm512_storeu_si512((__m512i*)ph, hbits);\n      ph += 64;\n    }\n  }\n}\n\ntemplate <>\ninline void pack_qs<block_q6_K>(void* RESTRICT packed_B, const block_q6_K* RESTRICT B, int KB) {\n  __m512i v[32];\n  const __m512i lowMask = _mm512_set1_epi8(0xF);\n  // QK_K 256 with 8 groups, handle 4 groups at a time\n  char* pb = (char*)packed_B;\n  char* ph = (char*)packed_B + (QK_K / 2) * TILE_N;\n  for (int k = 0; k < QK_K / 128; ++k) {\n    for (int n = 0; n < TILE_N; ++n) {\n      bytes_from_nibbles_128(v[n], v[n + 16], B[n * KB].ql + k * 64, B[n * KB].qh + k * 32);\n    }\n\n    // top half: group 0,1 or 4,5; bottom half: group 2,3 or 6,7\n    transpose_16x16_32bit(v);\n    transpose_16x16_32bit(v + 16);\n\n    // 1. pack lower 4bits with 4 groups\n    for (int n = 0; n < 32; n += 2) {\n      const __m512i r0 = _mm512_and_si512(v[n], lowMask);\n      const __m512i r1 = _mm512_and_si512(v[n + 1], lowMask);\n      _mm512_storeu_si512((__m512i*)pb, packNibbles(r0, r1));\n      pb += 64;\n    }\n\n    // 2. 
pack higher 2bit with 4 groups\n    const __m512i hmask = _mm512_set1_epi8(0x30);\n    for (int g = 0; g < 8; ++g) {\n      __m512i hbits = _mm512_setzero_si512();\n      hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 4 + 0], hmask), 4));\n      hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 4 + 1], hmask), 2));\n      hbits = _mm512_add_epi8(hbits, _mm512_and_si512(v[g * 4 + 2], hmask));\n      hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 4 + 3], hmask), 2));\n      _mm512_storeu_si512((__m512i*)ph, hbits);\n      ph += 64;\n    }\n  }\n}\n\ntemplate <>\ninline void pack_qs<block_iq4_xs>(void* RESTRICT packed_B, const block_iq4_xs* RESTRICT B, int KB) {\n  __m512i v[16];\n  char* pb = (char*)packed_B;\n  for (int k = 0; k < QK_K / 64; ++k) {\n    for (int n = 0; n < TILE_N; ++n) {\n      __m256i r0 = bytes_from_nibbles_32(B[n * KB].qs + k * 32 + 0);\n      __m256i r1 = bytes_from_nibbles_32(B[n * KB].qs + k * 32 + 16);\n      v[n] = _mm512_inserti32x8(_mm512_castsi256_si512(r0), r1, 1);\n    }\n\n    transpose_16x16_32bit(v);\n\n    // pack again with 128 to fully utilize vector length\n    for (int n = 0; n < TILE_N; n += 2) {\n      _mm512_storeu_si512((__m512i*)pb, packNibbles(v[n], v[n + 1]));\n      pb += 64;\n    }\n  }\n}\n\n// pack B to vnni formats in 4bits or 8 bits\nvoid pack_B(void* RESTRICT packed_B, const block_q4_0* RESTRICT B, int KB) {\n  pack_qs(packed_B, B, KB);\n  ggml_half* d0 = reinterpret_cast<ggml_half*>((char*)packed_B + TILE_N * TILE_K / 2);\n  for (int n = 0; n < TILE_N; ++n) {\n    d0[n] = B[n * KB].d;\n  }\n}\n\nvoid pack_B(void* RESTRICT packed_B, const block_q4_1* RESTRICT B, int KB) {\n  pack_qs(packed_B, B, KB);\n  ggml_half* d0 = reinterpret_cast<ggml_half*>((char*)packed_B + TILE_N * TILE_K / 2);\n  ggml_half* m0 = d0 + TILE_N;\n  for (int n = 0; n < TILE_N; ++n) {\n    d0[n] = B[n * KB].d;\n    m0[n] = B[n * KB].m;\n  }\n}\n\ninline void 
s8s8_compensation(void* RESTRICT packed_B) {\n  // packed_B layout:\n  //   quants {TILE_N, TILE_K} int8_t\n  //   d0     {TILE_N}      ggml_half\n  //   comp   {TILE_N}        int32_t\n  const int offset = TILE_N * TILE_K + TILE_N * sizeof(ggml_half);\n  __m512i vcomp = _mm512_setzero_si512();\n  const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));\n  for (int k = 0; k < 8; ++k) {\n    __m512i vb = _mm512_loadu_si512((const __m512i*)((const char*)packed_B + k * 64));\n    vcomp = _mm512_dpbusd_epi32(vcomp, off, vb);\n  }\n  _mm512_storeu_si512((__m512i*)((char*)(packed_B) + offset), vcomp);\n}\n\nvoid pack_B(void* RESTRICT packed_B, const block_q8_0* RESTRICT B, int KB) {\n  pack_qs(packed_B, B, KB);\n  ggml_half* d0 = reinterpret_cast<ggml_half*>((char*)packed_B + TILE_N * TILE_K);\n  for (int n = 0; n < TILE_N; ++n) {\n    d0[n] = B[n * KB].d;\n  }\n  s8s8_compensation(packed_B);\n}\n\n// convert 8 * {min, scale} from int6 to int8\ninline void unpack_mins_and_scales(const uint8_t* scales, uint32_t* utmp) {\n  const uint32_t kmask1 = 0x3f3f3f3f;\n  const uint32_t kmask2 = 0x0f0f0f0f;\n  const uint32_t kmask3 = 0x03030303;\n\n  memcpy(utmp, scales, 12);\n  utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);\n  const uint32_t uaux = utmp[1] & kmask1;\n  utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);\n  utmp[2] = uaux;\n  utmp[0] &= kmask1;\n}\n\n// packed_B layout:\n//   quants {8, TILE_N, 16}  uint8\n//   scales {8, TILE_N}      uint8\n//   mins   {8, TILE_N}      uint8\n//   d      {TILE_N}     ggml_half\n//   dmin   {TILE_N}     ggml_half\nvoid pack_B(void* RESTRICT packed_B, const block_q4_K* RESTRICT B, int KB) {\n  pack_qs(packed_B, B, KB);\n\n  uint8_t* scales = reinterpret_cast<uint8_t*>((char*)packed_B + (QK_K / 2) * TILE_N);\n  uint8_t* mins = scales + 8 * TILE_N;\n  ggml_half* d = reinterpret_cast<ggml_half*>(mins + 8 * TILE_N);\n  ggml_half* dmin = d + TILE_N;\n\n  union {\n    uint32_t u32[4];\n    
uint8_t u8[16];\n  } s;\n\n  for (int n = 0; n < TILE_N; ++n) {\n    unpack_mins_and_scales(B[n * KB].scales, s.u32);\n    for (int k = 0; k < 8; ++k) {\n      scales[k * TILE_N + n] = s.u8[k];\n      mins[(k >> 1) * TILE_N * 2 + n * 2 + (k & 0x1)] = s.u8[k + 8];\n    }\n    d[n] = B[n * KB].d;\n    dmin[n] = B[n * KB].dmin;\n  }\n}\n\n// packed_B layout:\n//   quants {8, TILE_N, 16}  uint8\n//   qh     {8, TILE_N,  4}  uint8\n//   scales {8, TILE_N}      uint8\n//   mins   {8, TILE_N}      uint8\n//   d      {TILE_N}     ggml_half\n//   dmin   {TILE_N}     ggml_half\nvoid pack_B(void* RESTRICT packed_B, const block_q5_K* RESTRICT B, int KB) {\n  pack_qs(packed_B, B, KB);\n\n  uint8_t* scales = reinterpret_cast<uint8_t*>((char*)packed_B + (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N);\n  uint8_t* mins = scales + 8 * TILE_N;\n  ggml_half* d = reinterpret_cast<ggml_half*>(mins + 8 * TILE_N);\n  ggml_half* dmin = d + TILE_N;\n\n  union {\n    uint32_t u32[4];\n    uint8_t u8[16];\n  } s;\n\n  for (int n = 0; n < TILE_N; ++n) {\n    unpack_mins_and_scales(B[n * KB].scales, s.u32);\n    for (int k = 0; k < 8; ++k) {\n      scales[k * TILE_N + n] = s.u8[k];\n      mins[(k >> 1) * TILE_N * 2 + n * 2 + (k & 0x1)] = s.u8[k + 8];\n    }\n    d[n] = B[n * KB].d;\n    dmin[n] = B[n * KB].dmin;\n  }\n}\n\n// packed_B layout:\n//   quants {16, TILE_N, 8}  uint8\n//   qh     {16, TILE_N, 4}  uint8\n//   scales {16, TILE_N}      uint8\n//   d      {TILE_N}     ggml_half\nvoid pack_B(void* RESTRICT packed_B, const block_q6_K* RESTRICT B, int KB) {\n  pack_qs(packed_B, B, KB);\n\n  uint8_t* scales = reinterpret_cast<uint8_t*>((char*)packed_B + (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N);\n  ggml_half* d = reinterpret_cast<ggml_half*>(scales + 16 * TILE_N);\n  for (int n = 0; n < TILE_N; ++n) {\n    const int8_t* ps = B[n * KB].scales;\n    for (int k = 0; k < 16; ++k) {\n      scales[k * TILE_N + n] = ps[k];\n    }\n    d[n] = B[n * KB].d;\n  }\n}\n\n// packed_B layout:\n//   quants {8, 
TILE_N, 16}  uint8\n//   scales {8, TILE_N}       int8\n//   d      {TILE_N}     ggml_half\nvoid pack_B(void* RESTRICT packed_B, const block_iq4_xs* RESTRICT B, int KB) {\n  pack_qs(packed_B, B, KB);\n\n  int8_t* scales = reinterpret_cast<int8_t*>((char*)packed_B + (QK_K / 2) * TILE_N);\n  ggml_half* d = reinterpret_cast<ggml_half*>(scales + 8 * TILE_N);\n\n  // pack the scales\n  for (int n = 0; n < TILE_N; ++n) {\n    uint16_t sh = B[n * KB].scales_h;\n    for (int k = 0; k < 8; k += 2) {\n      const int16_t ls1 = ((B[n * KB].scales_l[k / 2] & 0xf) | ((sh << 4) & 0x30)) - 32;\n      const int16_t ls2 = ((B[n * KB].scales_l[k / 2] >> 4) | ((sh << 2) & 0x30)) - 32;\n      scales[(k + 0) * TILE_N + n] = ls1;\n      scales[(k + 1) * TILE_N + n] = ls2;\n      sh >>= 4;\n    }\n    d[n] = B[n * KB].d;\n  }\n}\n\ntemplate <typename TB, typename packed_B_t = packed_B_type<TB>>\nvoid unpack_B(packed_B_t* RESTRICT tile, const void* RESTRICT packed_B) {\n  GGML_UNUSED(tile);\n  GGML_UNUSED(packed_B);\n};\n\ntemplate <>\nvoid unpack_B<block_q4_0>(int8_t* RESTRICT tile, const void* RESTRICT packed_B) {\n  const __m512i off = _mm512_set1_epi8(8);\n  const __m512i lowMask = _mm512_set1_epi8(0xF);\n  for (int n = 0; n < 8; n += 2) {\n    __m512i bytes = _mm512_loadu_si512((const __m512i*)((const char*)packed_B + n * 32));\n    const __m512i r0 = _mm512_sub_epi8(_mm512_and_si512(bytes, lowMask), off);\n    const __m512i r1 = _mm512_sub_epi8(_mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask), off);\n    _mm512_storeu_si512((__m512i*)(tile + n * 64 + 0), r0);\n    _mm512_storeu_si512((__m512i*)(tile + n * 64 + 64), r1);\n  }\n}\n\ntemplate <>\nvoid unpack_B<block_q4_1>(uint8_t* RESTRICT tile, const void* RESTRICT packed_B) {\n  const __m512i lowMask = _mm512_set1_epi8(0xF);\n  for (int n = 0; n < 8; n += 2) {\n    __m512i bytes = _mm512_loadu_si512((const __m512i*)((const char*)packed_B + n * 32));\n    const __m512i r0 = _mm512_and_si512(bytes, lowMask);\n    const __m512i r1 
= _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n    _mm512_storeu_si512((__m512i*)(tile + n * 64 + 0), r0);\n    _mm512_storeu_si512((__m512i*)(tile + n * 64 + 64), r1);\n  }\n}\n\n// packed_B_t for QKK is int8_t\ntemplate <typename TB>\nvoid unpack_B(int8_t* RESTRICT tile, const void* RESTRICT packed_B, int k) {\n  const int packed_B_group_size = QK_K / 2 * TILE_N / 8;\n  const char* packed_B_group = (const char*)packed_B + k * packed_B_group_size;\n  const __m512i lowMask = _mm512_set1_epi8(0xF);\n  for (int n = 0; n < 8; n += 2) {\n    __m512i bytes = _mm512_loadu_si512(packed_B_group + n * 32);\n    const __m512i r0 = _mm512_and_si512(bytes, lowMask);\n    const __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n    _mm512_storeu_si512((__m512i*)(tile + n * 64 + 0), r0);\n    _mm512_storeu_si512((__m512i*)(tile + n * 64 + 64), r1);\n  }\n}\n\ntemplate <>\nvoid unpack_B<block_q5_K>(int8_t* RESTRICT tile, const void* RESTRICT packed_B, int k) {\n  // lower 4bits, stride 256 bytes\n  const int packed_l4_group_size = QK_K / 2 * TILE_N / 8;\n  const char* pb = (const char*)packed_B + k * packed_l4_group_size;\n\n  // higher 1bit, stride 64 bytes\n  const int packed_h1_group_size = QK_K / 8 * TILE_N / 8;\n  const char* ph = (const char*)packed_B + (QK_K / 2) * TILE_N + k * packed_h1_group_size;\n  const __m512i hbits = _mm512_loadu_si512(ph);\n\n  const __m512i lowMask = _mm512_set1_epi8(0xF);\n  __m512i hmask0 = _mm512_set1_epi8(0x1);\n  __m512i hmask1 = _mm512_set1_epi8(0x2);\n\n  for (int n = 0; n < 8; n += 2) {\n    __m512i bytes = _mm512_loadu_si512(pb + n * 32);\n    __m512i r0 = _mm512_and_si512(bytes, lowMask);\n    __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n    __m512i h0 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask0), n), 4);\n    __m512i h1 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), n + 1), 4);\n\n    hmask0 = _mm512_slli_epi16(hmask0, 2);\n    
hmask1 = _mm512_slli_epi16(hmask1, 2);\n    r0 = _mm512_add_epi8(r0, h0);\n    r1 = _mm512_add_epi8(r1, h1);\n    _mm512_storeu_si512((__m512i*)(tile + n * 64 + 0), r0);\n    _mm512_storeu_si512((__m512i*)(tile + n * 64 + 64), r1);\n  }\n}\n\ntemplate <>\nvoid unpack_B<block_q6_K>(int8_t* RESTRICT tile, const void* RESTRICT packed_B, int k) {\n  // lower 4bits, stride 128 bytes\n  const int packed_l4_group_size = QK_K / 2 * TILE_N / 16;\n  const char* pb = (const char*)packed_B + k * packed_l4_group_size;\n\n  // higher 2bits, stride 64 bytes\n  const int packed_h2_group_size = QK_K / 4 * TILE_N / 16;\n  const char* ph = (const char*)packed_B + (QK_K / 2) * TILE_N + k * packed_h2_group_size;\n  const __m512i hbits = _mm512_loadu_si512(ph);\n\n  const __m512i off = _mm512_set1_epi8(32);\n  const __m512i lowMask = _mm512_set1_epi8(0xF);\n  __m512i hmask0 = _mm512_set1_epi8(0x3);  // 0011\n  __m512i hmask1 = _mm512_set1_epi8(0xC);  // 1100\n\n  // notes: skip zero padding from row4 to row7 as we have done so in `unpack_A`\n  __m512i bytes = _mm512_loadu_si512(pb);\n  __m512i r0 = _mm512_and_si512(bytes, lowMask);\n  __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n  __m512i h0 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask0), 4);\n  __m512i h1 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask1), 2);\n  _mm512_storeu_si512((__m512i*)(tile + 0), _mm512_sub_epi8(_mm512_add_epi8(r0, h0), off));\n  _mm512_storeu_si512((__m512i*)(tile + 64), _mm512_sub_epi8(_mm512_add_epi8(r1, h1), off));\n\n  hmask0 = _mm512_slli_epi16(hmask0, 4);\n  hmask1 = _mm512_slli_epi16(hmask1, 4);\n\n  bytes = _mm512_loadu_si512(pb + 64);\n  r0 = _mm512_and_si512(bytes, lowMask);\n  r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n  h0 = _mm512_and_si512(hbits, hmask0);\n  h1 = _mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), 2);\n  _mm512_storeu_si512((__m512i*)(tile + 128), _mm512_sub_epi8(_mm512_add_epi8(r0, h0), off));\n  
_mm512_storeu_si512((__m512i*)(tile + 192), _mm512_sub_epi8(_mm512_add_epi8(r1, h1), off));\n}\n\ntemplate <>\nvoid unpack_B<block_iq4_xs>(int8_t* RESTRICT tile, const void* RESTRICT packed_B, int k) {\n  static const __m512i values128 = _mm512_set_epi8(\n      113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, 113, 89, 69, 53, 38, 25, 13, 1, -10,\n      -22, -35, -49, -65, -83, -104, -127, 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,\n      113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127);\n\n  const int packed_B_group_size = QK_K / 2 * TILE_N / 8;\n  const char* pb = (const char*)packed_B + k * packed_B_group_size;\n  const __m512i lowMask = _mm512_set1_epi8(0xF);\n\n  for (int n = 0; n < 8; n += 2) {\n    __m512i bytes = _mm512_loadu_si512(pb + n * 32);\n    const __m512i r0 = _mm512_shuffle_epi8(values128, _mm512_and_si512(bytes, lowMask));\n    const __m512i r1 = _mm512_shuffle_epi8(values128, _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask));\n    _mm512_storeu_si512((__m512i*)(tile + n * 64 + 0), r0);\n    _mm512_storeu_si512((__m512i*)(tile + n * 64 + 64), r1);\n  }\n}\n\ntemplate <typename TA, typename TB, bool is_acc>\nstruct acc_C {};\n\ntemplate <bool is_acc>\nstruct acc_C<block_q8_0, block_q4_0, is_acc> {\n  static void apply(float* RESTRICT C, int ldc, const int32_t* RESTRICT tile, const block_q8_0* A, int lda,\n                    const void* packed_B, int nr) {\n    const int offset = TILE_N * TILE_K / 2;\n    const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)((const char*)packed_B + offset)));\n\n    for (int m = 0; m < nr; ++m) {\n      const __m512 vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].d));\n      const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));\n\n      __m512 vsum;\n      if (is_acc) {\n        vsum = _mm512_loadu_ps(C + m * ldc);\n      } else {\n        vsum = _mm512_set1_ps(0.f);\n      }\n      
vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum);\n      _mm512_storeu_ps(C + m * ldc, vsum);\n    }\n  }\n};\n\ntemplate <bool is_acc>\nstruct acc_C<block_q8_1, block_q4_1, is_acc> {\n  static void apply(float* RESTRICT C, int ldc, const int32_t* RESTRICT tile, const block_q8_1* A, int lda,\n                    const void* packed_B, int nr) {\n    const int offset = TILE_N * TILE_K / 2;\n    const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)((const char*)packed_B + offset)));\n    const __m512 vm0 = _mm512_cvtph_ps(\n        _mm256_loadu_si256((const __m256i*)((const char*)packed_B + offset + TILE_N * sizeof(ggml_half))));\n\n    for (int m = 0; m < nr; ++m) {\n      const __m512 vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].d));\n      const __m512 vs1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].s));\n      const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));\n\n      __m512 vsum;\n      if (is_acc) {\n        vsum = _mm512_loadu_ps(C + m * ldc);\n      } else {\n        vsum = _mm512_set1_ps(0.f);\n      }\n      vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum);\n      vsum = _mm512_fmadd_ps(vm0, vs1, vsum);\n      _mm512_storeu_ps(C + m * ldc, vsum);\n    }\n  }\n};\n\ntemplate <bool is_acc>\nstruct acc_C<block_q8_0, block_q8_0, is_acc> {\n  static void apply(float* RESTRICT C, int ldc, const int32_t* RESTRICT tile, const block_q8_0* A, int lda,\n                    const void* packed_B, int nr) {\n    const int offset = TILE_N * TILE_K;\n    const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)((const char*)packed_B + offset)));\n\n    for (int m = 0; m < nr; ++m) {\n      const __m512 vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].d));\n      const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));\n\n      __m512 vsum;\n      if (is_acc) {\n        vsum = _mm512_loadu_ps(C + m * ldc);\n      } else {\n        vsum = _mm512_set1_ps(0.f);\n   
   }\n      vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum);\n      _mm512_storeu_ps(C + m * ldc, vsum);\n    }\n  }\n};\n\ntemplate <bool is_acc>\nstruct acc_C<block_q8_K, block_q4_K, is_acc> {\n  static void apply(float* RESTRICT C, int ldc, const int32_t* RESTRICT tile, const block_q8_K* A, int lda,\n                    const void* packed_B, int nr) {\n    const uint8_t* scales = reinterpret_cast<const uint8_t*>((const char*)packed_B + (QK_K / 2) * TILE_N);\n    const uint8_t* mins = scales + 8 * TILE_N;\n    const ggml_half* d0 = reinterpret_cast<const ggml_half*>(mins + 8 * TILE_N);\n    const ggml_half* dmin = d0 + TILE_N;\n\n    const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)d0));\n    const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)dmin));\n\n    for (int m = 0; m < nr; ++m) {\n      const float d1 = A[m * lda].d;\n      const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);\n      const __m512 vdm = _mm512_mul_ps(_mm512_set1_ps(-d1), vdmin);\n      const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));\n\n      __m512 vsum;\n      if (is_acc) {\n        vsum = _mm512_loadu_ps(C + m * ldc);\n      } else {\n        vsum = _mm512_set1_ps(0.f);\n      }\n\n      const __m256i q8sums = _mm256_loadu_si256((const __m256i*)A[m * lda].bsums);\n      const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));\n\n      __m512i acc_m = _mm512_setzero_si512();\n      for (int k = 0; k < 4; ++k) {\n        __m512i vmask = _mm512_set1_epi32(k);\n        __m512i va = _mm512_permutexvar_epi32(vmask, _mm512_castsi128_si512(q8s));\n        __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i*)(mins + k * 32)));\n        acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);\n      }\n\n      vsum = _mm512_fmadd_ps(vtile, vd, vsum);\n      vsum = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc_m), vdm, vsum);\n      _mm512_storeu_ps(C + m * 
ldc, vsum);\n    }\n  }\n};\n\ntemplate <bool is_acc>\nstruct acc_C<block_q8_K, block_q5_K, is_acc> {\n  static void apply(float* RESTRICT C, int ldc, const int32_t* RESTRICT tile, const block_q8_K* A, int lda,\n                    const void* packed_B, int nr) {\n    const uint8_t* scales =\n        reinterpret_cast<const uint8_t*>((const char*)packed_B + (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N);\n    const uint8_t* mins = scales + 8 * TILE_N;\n    const ggml_half* d0 = reinterpret_cast<const ggml_half*>(mins + 8 * TILE_N);\n    const ggml_half* dmin = d0 + TILE_N;\n\n    const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)d0));\n    const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)dmin));\n\n    for (int m = 0; m < nr; ++m) {\n      const float d1 = A[m * lda].d;\n      const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);\n      const __m512 vdm = _mm512_mul_ps(_mm512_set1_ps(-d1), vdmin);\n      const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));\n\n      __m512 vsum;\n      if (is_acc) {\n        vsum = _mm512_loadu_ps(C + m * ldc);\n      } else {\n        vsum = _mm512_set1_ps(0.f);\n      }\n\n      const __m256i q8sums = _mm256_loadu_si256((const __m256i*)A[m * lda].bsums);\n      const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));\n\n      __m512i acc_m = _mm512_setzero_si512();\n      for (int k = 0; k < 4; ++k) {\n        __m512i vmask = _mm512_set1_epi32(k);\n        __m512i va = _mm512_permutexvar_epi32(vmask, _mm512_castsi128_si512(q8s));\n        __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i*)(mins + k * 32)));\n        acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);\n      }\n\n      vsum = _mm512_fmadd_ps(vtile, vd, vsum);\n      vsum = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc_m), vdm, vsum);\n      _mm512_storeu_ps(C + m * ldc, vsum);\n    }\n  }\n};\n\ntemplate <bool is_acc>\nstruct 
acc_C<block_q8_K, block_q6_K, is_acc> {\n  static void apply(float* RESTRICT C, int ldc, const int32_t* RESTRICT tile, const block_q8_K* A, int lda,\n                    const void* packed_B, int nr) {\n    const uint8_t* scales =\n        reinterpret_cast<const uint8_t*>((const char*)packed_B + (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N);\n    const ggml_half* d0 = reinterpret_cast<const ggml_half*>(scales + 16 * TILE_N);\n\n    const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)d0));\n\n    for (int m = 0; m < nr; ++m) {\n      const float d1 = A[m * lda].d;\n      const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);\n      const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));\n\n      __m512 vsum;\n      if (is_acc) {\n        vsum = _mm512_loadu_ps(C + m * ldc);\n      } else {\n        vsum = _mm512_set1_ps(0.f);\n      }\n\n      vsum = _mm512_fmadd_ps(vtile, vd, vsum);\n      _mm512_storeu_ps(C + m * ldc, vsum);\n    }\n  }\n};\n\ntemplate <bool is_acc>\nstruct acc_C<block_q8_K, block_iq4_xs, is_acc> {\n  static void apply(float* RESTRICT C, int ldc, const int32_t* RESTRICT tile, const block_q8_K* A, int lda,\n                    const void* packed_B, int nr) {\n    const int8_t* scales = reinterpret_cast<const int8_t*>((const char*)packed_B + (QK_K / 2) * TILE_N);\n    const ggml_half* d0 = reinterpret_cast<const ggml_half*>(scales + 8 * TILE_N);\n\n    const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)d0));\n\n    for (int m = 0; m < nr; ++m) {\n      const float d1 = A[m * lda].d;\n      const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);\n      const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));\n\n      __m512 vsum;\n      if (is_acc) {\n        vsum = _mm512_loadu_ps(C + m * ldc);\n      } else {\n        vsum = _mm512_set1_ps(0.f);\n      }\n\n      vsum = _mm512_fmadd_ps(vtile, vd, vsum);\n      _mm512_storeu_ps(C + m * ldc, vsum);\n    }\n  
}\n};\n\ntemplate <typename TB>\nconstexpr int get_quants_size();\ntemplate <>\nconstexpr int get_quants_size<block_q4_K>() {\n  return (QK_K / 2) * TILE_N;\n}\ntemplate <>\nconstexpr int get_quants_size<block_q5_K>() {\n  return (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N;\n}\ntemplate <>\nconstexpr int get_quants_size<block_q6_K>() {\n  return (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N;\n}\ntemplate <>\nconstexpr int get_quants_size<block_iq4_xs>() {\n  return (QK_K / 2) * TILE_N;\n}\n\n// used for QKK format\ntemplate <typename TB, bool is_acc, typename std::enable_if<is_type_qkk<TB>::value, int>::type = 0>\ninline void scale_C(const int32_t* RESTRICT tile, int32_t* RESTRICT sumi, const void* packed_B, int k, int nr) {\n  const uint8_t* scales = reinterpret_cast<const uint8_t*>((const char*)packed_B + get_quants_size<TB>());\n  const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*)(scales + k * TILE_N)));\n\n  for (int m = 0; m < nr; ++m) {\n    __m512i vsumi;\n    if (is_acc) {\n      vsumi = _mm512_loadu_si512(sumi + m * TILE_N);\n    } else {\n      vsumi = _mm512_setzero_si512();\n    }\n    __m512i vtile = _mm512_loadu_si512(tile + m * TILE_N);\n    vsumi = _mm512_add_epi32(vsumi, _mm512_mullo_epi32(vtile, vscale));\n    _mm512_storeu_si512((__m512i*)(sumi + m * TILE_N), vsumi);\n  }\n}\n\ntemplate <typename TA, typename TB, typename TC, int BLOCK_M, int BLOCK_N, int BLOCK_K>\nstruct tinygemm_kernel_avx {\n  static void apply(int K, const TA* RESTRICT A, const TB* RESTRICT B, TC* RESTRICT C, int ldc) {\n    GGML_UNUSED(K);\n    GGML_UNUSED(A);\n    GGML_UNUSED(B);\n    GGML_UNUSED(C);\n    GGML_UNUSED(ldc);\n  }\n};\n\ntemplate <int BLOCK_M, int BLOCK_N, int BLOCK_K>\nstruct tinygemm_kernel_avx<float, ggml_fp16_t, float, BLOCK_M, BLOCK_N, BLOCK_K> {\n  static void apply(int K, const float* RESTRICT A, const ggml_fp16_t* RESTRICT B, float* RESTRICT C, int ldc) {\n    constexpr int ROWS = BLOCK_M;\n    constexpr int COLS = BLOCK_N;\n    
assert(BLOCK_K == 16);\n\n    __m512 va;\n    __m512 vb[COLS];\n    __m512 vc[ROWS * COLS];\n\n    auto loadc = [&](int idx) { vc[idx] = _mm512_setzero_ps(); };\n    Unroll<ROWS * COLS>{}(loadc);\n\n    auto compute = [&](int idx, int k) {\n      // TODO: use `constexpr` here to get rid of integer div\n      // when upgraded to C++17\n      const int row = idx / COLS;\n      const int col = idx % COLS;\n\n      if (col == 0) {\n        va = _mm512_loadu_ps(A + row * K + k);\n      }\n      if (row == 0) {\n        vb[col] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(B + col * K + k)));\n      }\n      vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]);\n    };\n\n    for (int k = 0; k < K; k += 16) {\n      Unroll<ROWS * COLS>{}(compute, k);\n    }\n\n    auto storec = [&](int idx) {\n      const int row = idx / COLS;\n      const int col = idx % COLS;\n      C[row * ldc + col] = _mm512_reduce_add_ps(vc[idx]);\n    };\n    Unroll<ROWS * COLS>{}(storec);\n  }\n};\n\n#define LAUNCH_TINYGEMM_KERNEL_AVX(MB_SIZE, NB_SIZE)                                      \\\n  tinygemm_kernel_avx<float, type, float, MB_SIZE, NB_SIZE, blck_size>::apply(            \\\n      K, (const float*)src1->data + mb_start * K, (const type*)src0->data + nb_start * K, \\\n      (float*)dst->data + mb_start * ldc + nb_start, ldc);\n\n// re-organize in the format {NB, KB, TILE_SIZE}:\n#define PACKED_INDEX(n, k, KB, tile_size) (n * KB + k) * tile_size\n\ntemplate <typename TB, int BLOCK_K>\nvoid convert_B_packed_format(void* RESTRICT packed_B, const TB* RESTRICT B, int N, int K) {\n  const int NB = N / TILE_N;\n  const int KB = K / BLOCK_K;\n  const int TILE_SIZE = get_tile_size<TB>();\n\n  // parallel on NB should be enough\n  parallel_for(1, 0, NB, [&](int begin, int end) {\n    for (int n = begin; n < end; ++n) {\n      for (int k = 0; k < KB; ++k) {\n        int n0 = n * TILE_N;\n        pack_B((char*)packed_B + PACKED_INDEX(n, k, KB, TILE_SIZE), &B[n0 * KB + k], KB);\n      }\n    }\n  
});\n}\n\ntemplate <typename TA, typename TB, typename TC, int BLOCK_M, int BLOCK_N, int BLOCK_K>\nstruct tinygemm_kernel_vnni {};\n\ntemplate <int BLOCK_M, int BLOCK_N, int BLOCK_K>\nstruct tinygemm_kernel_vnni<block_q8_0, block_q4_0, float, BLOCK_M, BLOCK_N, BLOCK_K> {\n  static void apply(int KB, const void* RESTRICT _A, const void* RESTRICT _B, float* RESTRICT C, int ldc) {\n    constexpr int COLS = BLOCK_N / 16;\n    const int TILE_SIZE = TILE_N * sizeof(block_q4_0);\n\n    const block_q8_0* RESTRICT A = static_cast<const block_q8_0*>(_A);\n    const char* RESTRICT B = static_cast<const char*>(_B);\n\n    __m512i va[8];\n    __m512 vc[COLS];\n    __m512 vd1;\n\n    // sum of offsets, shared across COLS\n    //\n    // avx512-vnni does not have `_mm512_dpbssd_epi32`,\n    // need to transform ss to us:\n    //   a * (b - 8) is equivalent to b * a - 8 * a\n    //   s    u   u                   u   s   u   s\n    //\n    __m512i vcomp;\n\n    const __m512i off = _mm512_set1_epi8(8);\n    const __m512i lowMask = _mm512_set1_epi8(0xF);\n\n    auto loadc = [&](int col) { vc[col] = _mm512_setzero_ps(); };\n    Unroll<COLS>{}(loadc);\n\n    auto compute = [&](int col, int i) {\n      // load a and compute compensation\n      if (col == 0) {\n        const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A[0 * KB + i].qs);\n        vcomp = _mm512_setzero_si512();\n        for (int k = 0; k < 8; ++k) {\n          va[k] = _mm512_set1_epi32(a_ptr[k]);\n          vcomp = _mm512_dpbusd_epi32(vcomp, off, va[k]);\n        }\n        vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].d));\n      }\n\n      // load b\n      __m512i vsum = _mm512_setzero_si512();\n      const char* b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);\n      for (int k = 0; k < 8; k += 2) {\n        __m512i bytes = _mm512_loadu_si512((const __m512i*)(b_ptr + k * 32));\n        __m512i vb0 = _mm512_and_si512(bytes, lowMask);\n        vsum = _mm512_dpbusd_epi32(vsum, vb0, va[k + 0]);\n        
__m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n        vsum = _mm512_dpbusd_epi32(vsum, vb1, va[k + 1]);\n      }\n      const int offset = TILE_N * TILE_K / 2;\n      const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(b_ptr + offset)));\n      vsum = _mm512_sub_epi32(vsum, vcomp);\n\n      vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]);\n    };\n\n    for (int i = 0; i < KB; ++i) {\n      Unroll<COLS>{}(compute, i);\n    }\n\n    // store to C\n    auto storec = [&](int col) { _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); };\n    Unroll<COLS>{}(storec);\n  }\n};\n\ntemplate <int BLOCK_N, int BLOCK_K>\nstruct tinygemm_kernel_vnni<block_q8_1, block_q4_1, float, 1, BLOCK_N, BLOCK_K> {\n  static void apply(int KB, const void* RESTRICT _A, const void* RESTRICT _B, float* RESTRICT C, int ldc) {\n    constexpr int COLS = BLOCK_N / 16;\n    const int TILE_SIZE = TILE_N * sizeof(block_q4_1);\n\n    const block_q8_1* RESTRICT A = static_cast<const block_q8_1*>(_A);\n    const char* RESTRICT B = static_cast<const char*>(_B);\n\n    __m512i va[8];\n    __m512i vb[8];\n    __m512 vc[COLS];\n    __m512 vd1, vs1;\n\n    const __m512i lowMask = _mm512_set1_epi8(0xF);\n\n    auto loadc = [&](int col) { vc[col] = _mm512_setzero_ps(); };\n    Unroll<COLS>{}(loadc);\n\n    auto compute = [&](int col, int i) {\n      // load a\n      if (col == 0) {\n        const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A[0 * KB + i].qs);\n        for (int k = 0; k < 8; ++k) {\n          va[k] = _mm512_set1_epi32(a_ptr[k]);\n        }\n        vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].d));\n        vs1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].s));\n      }\n\n      // load b\n      const char* b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);\n      for (int k = 0; k < 8; k += 2) {\n        __m512i bytes = _mm512_loadu_si512((const __m512i*)(b_ptr + k * 32));\n        vb[k + 
0] = _mm512_and_si512(bytes, lowMask);\n        vb[k + 1] = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n      }\n      const int offset = TILE_N * TILE_K / 2;\n      const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(b_ptr + offset)));\n      const __m512 vm0 =\n          _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(b_ptr + offset + TILE_N * sizeof(ggml_half))));\n\n      __m512i vsum = _mm512_setzero_si512();\n      for (int k = 0; k < 8; ++k) {\n        vsum = _mm512_dpbusd_epi32(vsum, vb[k], va[k]);\n      }\n\n      vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]);\n      vc[col] = _mm512_fmadd_ps(vm0, vs1, vc[col]);\n    };\n\n    for (int i = 0; i < KB; ++i) {\n      Unroll<COLS>{}(compute, i);\n    }\n\n    // store to C\n    auto storec = [&](int col) { _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); };\n    Unroll<COLS>{}(storec);\n  }\n};\n\ntemplate <int BLOCK_M, int BLOCK_N, int BLOCK_K>\nstruct tinygemm_kernel_vnni<block_q8_0, block_q8_0, float, BLOCK_M, BLOCK_N, BLOCK_K> {\n  static void apply(int KB, const void* RESTRICT _A, const void* RESTRICT _B, float* RESTRICT C, int ldc) {\n    constexpr int COLS = BLOCK_N / 16;\n    const int TILE_SIZE = TILE_N * sizeof(block_q8_0) + TILE_N * sizeof(int32_t);\n\n    const block_q8_0* RESTRICT A = static_cast<const block_q8_0*>(_A);\n    const char* RESTRICT B = static_cast<const char*>(_B);\n\n    __m512i va[8];\n    __m512i vb[8];\n    __m512 vc[COLS];\n    __m512 vd1;\n\n    // Notes: s8s8 igemm compensation in avx512-vnni\n    // change s8s8 to u8s8 with compensation\n    //   a * b = (a + 128) * b - 128 * b\n    //   s   s       u       s    u    s\n    //\n    // (128 * b is pre-computed when packing B to vnni formats)\n    //\n    const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));\n\n    auto loadc = [&](int col) { vc[col] = _mm512_setzero_ps(); };\n    Unroll<COLS>{}(loadc);\n\n    auto compute = 
[&](int col, int i) {\n      // load a and add offset 128\n      if (col == 0) {\n        const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A[0 * KB + i].qs);\n        for (int k = 0; k < 8; ++k) {\n          va[k] = _mm512_set1_epi32(a_ptr[k]);\n          va[k] = _mm512_add_epi8(va[k], off);\n        }\n        vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].d));\n      }\n\n      // load b\n      const char* b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);\n      for (int k = 0; k < 8; ++k) {\n        vb[k] = _mm512_loadu_si512((const __m512i*)(b_ptr + k * 64));\n      }\n      const int offset = TILE_N * TILE_K;\n      const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(b_ptr + offset)));\n      const int offset2 = TILE_N * TILE_K + TILE_N * sizeof(ggml_half);\n      const __m512i vcomp = _mm512_loadu_si512((const __m512i*)(b_ptr + offset2));\n\n      __m512i vsum = _mm512_setzero_si512();\n      for (int k = 0; k < 8; ++k) {\n        vsum = _mm512_dpbusd_epi32(vsum, va[k], vb[k]);\n      }\n      vsum = _mm512_sub_epi32(vsum, vcomp);\n\n      vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]);\n    };\n\n    for (int i = 0; i < KB; ++i) {\n      Unroll<COLS>{}(compute, i);\n    }\n\n    // store to C\n    auto storec = [&](int col) { _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); };\n    Unroll<COLS>{}(storec);\n  }\n};\n\ntemplate <int BLOCK_M, int BLOCK_N, int BLOCK_K>\nstruct tinygemm_kernel_vnni<block_q8_K, block_q4_K, float, BLOCK_M, BLOCK_N, BLOCK_K> {\n  static void apply(int KB, const void* RESTRICT _A, const void* RESTRICT _B, float* RESTRICT C, int ldc) {\n    constexpr int COLS = BLOCK_N / 16;\n    const int TILE_SIZE = TILE_N * sizeof(block_q4_K) + TILE_N * 4;\n\n    const block_q8_K* RESTRICT A = static_cast<const block_q8_K*>(_A);\n    const char* RESTRICT B = static_cast<const char*>(_B);\n\n    // a.qs:   8 groups, 32 bytes each group (m256i)\n    __m512i va[8];\n   
 // a.bsum: 8 groups,  2 bytes each group (m128i)\n    __m512i va_bsum;\n    __m512 vc[COLS];\n    __m512 vd1;\n\n    // packed_B:\n    const int offset_scales = (QK_K / 2) * TILE_N;\n    const int offset_mins = (QK_K / 2) * TILE_N + 8 * TILE_N;\n    const int offset_d0 = (QK_K / 2) * TILE_N + 16 * TILE_N;\n    const int offset_dmin = (QK_K / 2) * TILE_N + 16 * TILE_N + TILE_N * sizeof(ggml_half);\n\n    const __m512i lowMask = _mm512_set1_epi8(0xF);\n\n    auto loadc = [&](int col) { vc[col] = _mm512_setzero_ps(); };\n    Unroll<COLS>{}(loadc);\n\n    // Notes: vnni formats in QK_K\n    //   a) quants vnni format\n    //     int8  {k/4, n, 4}, viewed as 2d {k/4, 4n}, k = 32\n    //     from {16, 32} to {8, 64}\n    //\n    //   b) min vnni format\n    //     int16 {k/2, n, 2}, viewed as 2d {k/2, 2n}, k = 8\n    //     from {16,  8} to {4, 32}\n    //\n    auto compute = [&](int col, int i) {\n      // load a\n      if (col == 0) {\n        for (int k_group = 0; k_group < QK_K / 32; ++k_group) {\n          va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i*)(A[0 * KB + i].qs + k_group * 32)));\n        }\n        const __m256i q8sums = _mm256_loadu_si256((const __m256i*)A[0 * KB + i].bsums);\n        const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));\n        va_bsum = _mm512_castsi128_si512(q8s);\n        vd1 = _mm512_set1_ps(A[0 * KB + i].d);\n      }\n\n      // step 1: accumulate the quants\n      __m512i acc = _mm512_setzero_si512();\n      const char* b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);\n      const char* b_qs = b_ptr;\n      for (int k_group = 0; k_group < QK_K / 32; ++k_group) {\n        __m512i vsum = _mm512_setzero_si512();\n        for (int k = 0; k < 8; k += 2) {\n          __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 0), va[k_group]);\n          __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 1), va[k_group]);\n\n          __m512i 
bytes = _mm512_loadu_si512((const __m512i*)b_qs);\n          __m512i vb0 = _mm512_and_si512(bytes, lowMask);\n          vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);\n          __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n          vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);\n\n          b_qs += 64;\n        }\n        // vacc += scale * (q8 @ q4)\n        const __m512i vscale =\n            _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*)(b_ptr + offset_scales + k_group * TILE_N)));\n        acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));\n      }\n      const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(b_ptr + offset_d0)));\n      vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);\n\n      // step 2: accumulate the mins\n      __m512i acc_m = _mm512_setzero_si512();\n      for (int k = 0; k < 4; ++k) {\n        __m512i vmask = _mm512_set1_epi32(k);\n        __m512i va = _mm512_permutexvar_epi32(vmask, va_bsum);\n        __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i*)(b_ptr + offset_mins + k * 32)));\n        acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);\n      }\n      const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(b_ptr + offset_dmin)));\n      vc[col] = _mm512_fnmadd_ps(_mm512_cvtepi32_ps(acc_m), _mm512_mul_ps(vdmin, vd1), vc[col]);\n    };\n\n    for (int i = 0; i < KB; ++i) {\n      Unroll<COLS>{}(compute, i);\n    }\n\n    // store to C\n    auto storec = [&](int col) { _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); };\n    Unroll<COLS>{}(storec);\n  }\n};\n\ntemplate <int BLOCK_M, int BLOCK_N, int BLOCK_K>\nstruct tinygemm_kernel_vnni<block_q8_K, block_q5_K, float, BLOCK_M, BLOCK_N, BLOCK_K> {\n  static void apply(int KB, const void* RESTRICT _A, const void* RESTRICT _B, float* RESTRICT C, int ldc) {\n    constexpr int COLS = BLOCK_N / 16;\n    const int TILE_SIZE = TILE_N * sizeof(block_q5_K) + 
TILE_N * 4;\n\n    const block_q8_K* RESTRICT A = static_cast<const block_q8_K*>(_A);\n    const char* RESTRICT B = static_cast<const char*>(_B);\n\n    // a.qs:   8 groups, 32 bytes each group (m256i)\n    __m512i va[8];\n    // a.bsum: 8 groups,  2 bytes each group (m128i)\n    __m512i va_bsum;\n    __m512 vc[COLS];\n    __m512 vd1;\n\n    // packed_B:\n    const int offset_qh = (QK_K / 2) * TILE_N;\n    const int offset_scales = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N;\n    const int offset_mins = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 8 * TILE_N;\n    const int offset_d0 = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 16 * TILE_N;\n    const int offset_dmin = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 16 * TILE_N + TILE_N * sizeof(ggml_half);\n\n    const __m512i lowMask = _mm512_set1_epi8(0xF);\n\n    auto loadc = [&](int col) { vc[col] = _mm512_setzero_ps(); };\n    Unroll<COLS>{}(loadc);\n\n    // Q5_K and Q4_K share the same vnni formats; refer to the notes above.\n    auto compute = [&](int col, int i) {\n      // load a\n      if (col == 0) {\n        for (int k_group = 0; k_group < QK_K / 32; ++k_group) {\n          va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i*)(A[0 * KB + i].qs + k_group * 32)));\n        }\n        const __m256i q8sums = _mm256_loadu_si256((const __m256i*)A[0 * KB + i].bsums);\n        const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));\n        va_bsum = _mm512_castsi128_si512(q8s);\n        vd1 = _mm512_set1_ps(A[0 * KB + i].d);\n      }\n\n      // step 1: accumulate the quants\n      __m512i acc = _mm512_setzero_si512();\n      const char* b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);\n      const char* b_qs = b_ptr;\n      const char* b_qh = b_ptr + offset_qh;\n      for (int k_group = 0; k_group < QK_K / 32; ++k_group) {\n        __m512i vsum = _mm512_setzero_si512();\n        __m512i hmask0 = _mm512_set1_epi8(0x1);\n        __m512i hmask1 
= _mm512_set1_epi8(0x2);\n        __m512i hbits = _mm512_loadu_si512((const __m512i*)(b_qh + k_group * 64));\n        for (int k = 0; k < 8; k += 2) {\n          __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 0), va[k_group]);\n          __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 1), va[k_group]);\n\n          __m512i bytes = _mm512_loadu_si512((const __m512i*)b_qs);\n          __m512i vb0 = _mm512_and_si512(bytes, lowMask);\n          __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n\n          __m512i vh0 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask0), k), 4);\n          __m512i vh1 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), k + 1), 4);\n\n          hmask0 = _mm512_slli_epi16(hmask0, 2);\n          hmask1 = _mm512_slli_epi16(hmask1, 2);\n          vb0 = _mm512_add_epi8(vb0, vh0);\n          vb1 = _mm512_add_epi8(vb1, vh1);\n\n          vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);\n          vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);\n\n          b_qs += 64;\n        }\n        // vacc += scale * (q8 @ q5)\n        const __m512i vscale =\n            _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*)(b_ptr + offset_scales + k_group * TILE_N)));\n        acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));\n      }\n      const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(b_ptr + offset_d0)));\n      vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);\n\n      // step 2: accumulate the mins\n      __m512i acc_m = _mm512_setzero_si512();\n      for (int k = 0; k < 4; ++k) {\n        __m512i vmask = _mm512_set1_epi32(k);\n        __m512i va = _mm512_permutexvar_epi32(vmask, va_bsum);\n        __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i*)(b_ptr + offset_mins + k * 32)));\n        acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);\n      }\n      const __m512 vdmin = 
_mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(b_ptr + offset_dmin)));\n      vc[col] = _mm512_fnmadd_ps(_mm512_cvtepi32_ps(acc_m), _mm512_mul_ps(vdmin, vd1), vc[col]);\n    };\n\n    for (int i = 0; i < KB; ++i) {\n      Unroll<COLS>{}(compute, i);\n    }\n\n    // store to C\n    auto storec = [&](int col) { _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); };\n    Unroll<COLS>{}(storec);\n  }\n};\n\ntemplate <int BLOCK_M, int BLOCK_N, int BLOCK_K>\nstruct tinygemm_kernel_vnni<block_q8_K, block_q6_K, float, BLOCK_M, BLOCK_N, BLOCK_K> {\n  static void apply(int KB, const void* RESTRICT _A, const void* RESTRICT _B, float* RESTRICT C, int ldc) {\n    constexpr int COLS = BLOCK_N / 16;\n    const int TILE_SIZE = TILE_N * sizeof(block_q6_K);\n\n    const block_q8_K* RESTRICT A = static_cast<const block_q8_K*>(_A);\n    const char* RESTRICT B = static_cast<const char*>(_B);\n\n    // load the 256 bytes from A to 4 avx512 vectors\n    __m512i va[4];\n    __m512 vc[COLS];\n    __m512 vd1;\n\n    // packed_B:\n    const int offset_qh = (QK_K / 2) * TILE_N;\n    const int offset_scales = (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N;\n    const int offset_d0 = (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N + 16 * TILE_N;\n\n    // compensation\n    __m512i vcomp;\n\n    const __m512i m32s = _mm512_set1_epi32(32);\n    const __m512i lowMask = _mm512_set1_epi8(0xF);\n\n    auto loadc = [&](int col) { vc[col] = _mm512_setzero_ps(); };\n    Unroll<COLS>{}(loadc);\n\n    auto compute = [&](int col, int i) {\n      if (col == 0) {\n        // load a\n        va[0] = _mm512_loadu_si512((const __m512i*)(A[0 * KB + i].qs + 0));\n        va[1] = _mm512_loadu_si512((const __m512i*)(A[0 * KB + i].qs + 64));\n        va[2] = _mm512_loadu_si512((const __m512i*)(A[0 * KB + i].qs + 128));\n        va[3] = _mm512_loadu_si512((const __m512i*)(A[0 * KB + i].qs + 192));\n\n        const __m256i q8sums = _mm256_loadu_si256((const __m256i*)A[0 * KB + i].bsums);\n        vcomp = 
_mm512_mullo_epi32(_mm512_cvtepi16_epi32(q8sums), m32s);\n        vd1 = _mm512_set1_ps(A[0 * KB + i].d);\n      }\n\n      // accumulate the quants\n      __m512i acc = _mm512_setzero_si512();\n      const char* b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);\n      const char* b_qs = b_ptr;\n      const char* b_qh = b_ptr + offset_qh;\n      int mask = 0;\n      for (int k_group = 0; k_group < QK_K / 16; ++k_group) {\n        int r = k_group >> 2;\n        __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);\n        __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);\n\n        __m512i vsum = _mm512_setzero_si512();\n        __m512i hmask = _mm512_set1_epi8(0x3);\n\n        __m512i bytes = _mm512_loadu_si512(b_qs);\n        __m512i hbits = _mm512_loadu_si512(b_qh);\n        __m512i vb0 = _mm512_and_si512(bytes, lowMask);\n        __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n        __m512i vh0 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask), 4);\n        __m512i vh1 = _mm512_slli_epi16(_mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 2)), 2);\n\n        vb0 = _mm512_add_epi8(vb0, vh0);\n        vb1 = _mm512_add_epi8(vb1, vh1);\n        vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);\n        vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);\n        b_qs += 64;\n\n        va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);\n        va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);\n\n        bytes = _mm512_loadu_si512(b_qs);\n        vb0 = _mm512_and_si512(bytes, lowMask);\n        vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n        vh0 = _mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 4));\n        vh1 = _mm512_srli_epi16(_mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 6)), 2);\n        vb0 = _mm512_add_epi8(vb0, vh0);\n        vb1 = _mm512_add_epi8(vb1, vh1);\n        vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);\n        vsum = 
_mm512_dpbusd_epi32(vsum, vb1, va1);\n        b_qs += 64;\n        b_qh += 64;\n\n        // B * A - 32 * A\n        __m512i vmask = _mm512_set1_epi32(k_group);\n        vsum = _mm512_sub_epi32(vsum, _mm512_permutexvar_epi32(vmask, vcomp));\n\n        // vacc += scale * (q8 @ q6)\n        const __m512i vscale =\n            _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*)(b_ptr + offset_scales + k_group * TILE_N)));\n        acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));\n      }\n      const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(b_ptr + offset_d0)));\n      vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);\n    };\n\n    for (int i = 0; i < KB; ++i) {\n      Unroll<COLS>{}(compute, i);\n    }\n\n    // store to C\n    auto storec = [&](int col) { _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); };\n    Unroll<COLS>{}(storec);\n  }\n};\n\ntemplate <int BLOCK_M, int BLOCK_N, int BLOCK_K>\nstruct tinygemm_kernel_vnni<block_q8_K, block_iq4_xs, float, BLOCK_M, BLOCK_N, BLOCK_K> {\n  static void apply(int KB, const void* RESTRICT _A, const void* RESTRICT _B, float* RESTRICT C, int ldc) {\n    constexpr int COLS = BLOCK_N / 16;\n    const int TILE_SIZE = TILE_N * sizeof(block_iq4_xs) + TILE_N * 2;\n\n    const block_q8_K* RESTRICT A = static_cast<const block_q8_K*>(_A);\n    const char* RESTRICT B = static_cast<const char*>(_B);\n\n    // load the 256 bytes from A to 4 avx512 vectors\n    __m512i va[4];\n    __m512 vc[COLS];\n    __m512 vd1;\n\n    // packed_B:\n    const int offset_scales = (QK_K / 2) * TILE_N;\n    const int offset_d0 = (QK_K / 2) * TILE_N + 8 * TILE_N;\n\n    // compensation\n    __m512i vcomp;\n\n    const __m256i m128s = _mm256_set1_epi16(128);\n    const __m512i lowMask = _mm512_set1_epi8(0xF);\n\n    const __m512i values128 = _mm512_set_epi8(113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,\n                                     
         113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,\n                                              113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,\n                                              113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127);\n    const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));\n    const __m512i values256 = _mm512_add_epi8(values128, off);\n\n    auto loadc = [&](int col) { vc[col] = _mm512_setzero_ps(); };\n    Unroll<COLS>{}(loadc);\n\n    auto compute = [&](int col, int i) {\n      if (col == 0) {\n        // load a\n        va[0] = _mm512_loadu_si512((const __m512i*)(A[0 * KB + i].qs + 0));\n        va[1] = _mm512_loadu_si512((const __m512i*)(A[0 * KB + i].qs + 64));\n        va[2] = _mm512_loadu_si512((const __m512i*)(A[0 * KB + i].qs + 128));\n        va[3] = _mm512_loadu_si512((const __m512i*)(A[0 * KB + i].qs + 192));\n\n        // compensation: 128 * A\n        const __m256i q8sums = _mm256_loadu_si256((const __m256i*)A[0 * KB + i].bsums);\n        vcomp = _mm512_castsi256_si512(_mm256_madd_epi16(q8sums, m128s));\n        vd1 = _mm512_set1_ps(A[0 * KB + i].d);\n      }\n\n      // accumulate the quants\n      __m512i acc = _mm512_setzero_si512();\n      const char* b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);\n      const char* b_qs = b_ptr;\n      int mask = 0;\n      for (int k_group = 0; k_group < QK_K / 32; ++k_group) {\n        int r = k_group >> 1;\n        __m512i vmask = _mm512_set1_epi32(k_group);\n        __m512i vsum = _mm512_setzero_si512();\n        for (int k = 0; k < 8; k += 2) {\n          __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);\n          __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);\n\n          __m512i bytes = _mm512_loadu_si512(b_qs);\n          __m512i vb0 = _mm512_shuffle_epi8(values256, _mm512_and_si512(bytes, lowMask));\n          __m512i vb1 = 
_mm512_shuffle_epi8(values256, _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask));\n\n          vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);\n          vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);\n          b_qs += 64;\n        }\n        // (B + 128) * A - 128 * A\n        vsum = _mm512_sub_epi32(vsum, _mm512_permutexvar_epi32(vmask, vcomp));\n\n        // vacc += scale * (q8 @ q4)\n        const __m512i vscale =\n            _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*)(b_ptr + offset_scales + k_group * TILE_N)));\n        acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));\n      }\n      const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(b_ptr + offset_d0)));\n      vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);\n    };\n\n    for (int i = 0; i < KB; ++i) {\n      Unroll<COLS>{}(compute, i);\n    }\n\n    // store to C\n    auto storec = [&](int col) { _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); };\n    Unroll<COLS>{}(storec);\n  }\n};\n\n#define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE)                                                                         \\\n  tinygemm_kernel_vnni<vec_dot_type, type, float, 1, NB_SIZE, blck_size>::apply(                                     \\\n      KB, (const char*)wdata + 0 * row_size_A,                                                                       \\\n      (const char*)src0->extra + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), (float*)dst->data + 0 * N + nb_start, \\\n      ldc)\n\ntemplate <typename TA, typename TB, typename TC, int BLOCK_K,\n          typename std::enable_if<!is_type_qkk<TB>::value, int>::type = 0>\nvoid tinygemm_kernel_amx(int M, int N, int KB, const void* RESTRICT _A, const void* RESTRICT _B, TC* RESTRICT C,\n                         int ldc) {\n  using packed_B_t = packed_B_type<TB>;\n  const int TILE_SIZE = get_tile_size<TB>();\n  const bool need_unpack = do_unpack<TB>::value;\n\n  
GGML_ASSERT(M <= 2 * TILE_M && N == 2 * TILE_N);\n  const TA* RESTRICT A = static_cast<const TA*>(_A);\n  const char* RESTRICT B = static_cast<const char*>(_B);\n\n  const int m0 = std::min(M, TILE_M);\n  const int m1 = std::max(M - TILE_M, 0);\n  const int lda = KB * sizeof(TA);\n  // const int ldb = KB * sizeof(TB);\n\n  static thread_local packed_B_t Tile0[TILE_N * TILE_K];\n  static thread_local packed_B_t Tile1[TILE_N * TILE_K];\n  static thread_local int8_t Tile23[TILE_M * TILE_K];\n\n  static thread_local int32_t TileC0[TILE_M * TILE_N * 4];\n  static thread_local int32_t TileC1[TILE_M * TILE_N * 4];\n\n  // double buffering C to interleave avx512 and amx\n  int32_t* C_cur = TileC0;\n  int32_t* C_pre = TileC1;\n\n  auto Tile4 = [&](int32_t* base) { return base; };\n  auto Tile5 = [&](int32_t* base) { return base + TILE_M * TILE_N; };\n  auto Tile6 = [&](int32_t* base) { return base + 2 * TILE_M * TILE_N; };\n  auto Tile7 = [&](int32_t* base) { return base + 3 * TILE_M * TILE_N; };\n\n  if (M == 2 * TILE_M) {\n    // i = 0\n    const char* B_blk0 = B + PACKED_INDEX(0, 0, KB, TILE_SIZE);\n    const char* B_blk1 = B + PACKED_INDEX(1, 0, KB, TILE_SIZE);\n    if (need_unpack) {\n      unpack_B<TB>(Tile0, B_blk0);\n      _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);\n    } else {\n      _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK);\n    }\n\n    _tile_zero(TMM4);\n    _tile_loadd(TMM2, A[0].qs, lda);\n    _tile_dpbssd(TMM4, TMM2, TMM0);\n    _tile_stored(TMM4, Tile4(C_pre), TILE_N * sizeof(int32_t));\n\n    _tile_zero(TMM5);\n    _tile_loadd(TMM3, A[TILE_M * KB + 0].qs, lda);\n    _tile_dpbssd(TMM5, TMM3, TMM0);\n    _tile_stored(TMM5, Tile5(C_pre), TILE_N * sizeof(int32_t));\n\n    if (need_unpack) {\n      unpack_B<TB>(Tile1, B_blk1);\n      _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);\n    } else {\n      _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);\n    }\n\n    _tile_zero(TMM6);\n    _tile_dpbssd(TMM6, TMM2, TMM1);\n    _tile_stored(TMM6, Tile6(C_pre), TILE_N 
* sizeof(int32_t));\n\n    _tile_zero(TMM7);\n    _tile_dpbssd(TMM7, TMM3, TMM1);\n    _tile_stored(TMM7, Tile7(C_pre), TILE_N * sizeof(int32_t));\n\n    for (int i = 1; i < KB; ++i) {\n      // index of previous iter\n      const int ii = i - 1;\n      const char* B_blk0 = B + PACKED_INDEX(0, i, KB, TILE_SIZE);\n      const char* B_blk1 = B + PACKED_INDEX(1, i, KB, TILE_SIZE);\n      GGML_DISPATCH_BOOL(ii > 0, is_acc, [&] {\n        if (need_unpack) {\n          unpack_B<TB>(Tile0, B_blk0);\n          _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);\n        } else {\n          _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK);\n        }\n        _tile_zero(TMM4);\n        _tile_loadd(TMM2, A[i].qs, lda);\n        acc_C<TA, TB, is_acc>::apply(C, ldc, Tile4(C_pre), &A[ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);\n\n        _tile_dpbssd(TMM4, TMM2, TMM0);\n        _tile_stored(TMM4, Tile4(C_cur), TILE_N * sizeof(int32_t));\n\n        _tile_zero(TMM5);\n        _tile_loadd(TMM3, A[TILE_M * KB + i].qs, lda);\n        acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc, ldc, Tile5(C_pre), &A[TILE_M * KB + ii], KB,\n                                     B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);\n\n        _tile_dpbssd(TMM5, TMM3, TMM0);\n        _tile_stored(TMM5, Tile5(C_cur), TILE_N * sizeof(int32_t));\n\n        if (need_unpack) {\n          unpack_B<TB>(Tile1, B_blk1);\n          _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);\n        } else {\n          _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);\n        }\n        _tile_zero(TMM6);\n        acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Tile6(C_pre), &A[ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE),\n                                     TILE_M);\n\n        _tile_dpbssd(TMM6, TMM2, TMM1);\n        _tile_stored(TMM6, Tile6(C_cur), TILE_N * sizeof(int32_t));\n\n        _tile_zero(TMM7);\n        acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_pre), &A[TILE_M * KB + ii], KB,\n               
                      B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M);\n\n        _tile_dpbssd(TMM7, TMM3, TMM1);\n        _tile_stored(TMM7, Tile7(C_cur), TILE_N * sizeof(int32_t));\n\n        std::swap(C_cur, C_pre);\n      });\n    }\n    // final accumulation\n    {\n      int ii = KB - 1;\n      acc_C<TA, TB, true>::apply(C, ldc, Tile4(C_pre), &A[ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);\n      acc_C<TA, TB, true>::apply(C + TILE_M * ldc, ldc, Tile5(C_pre), &A[TILE_M * KB + ii], KB,\n                                 B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);\n      acc_C<TA, TB, true>::apply(C + TILE_N, ldc, Tile6(C_pre), &A[ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE),\n                                 TILE_M);\n      acc_C<TA, TB, true>::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_pre), &A[TILE_M * KB + ii], KB,\n                                 B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M);\n    }\n  } else {\n    for (int i = 0; i < KB; ++i) {\n      _tile_zero(TMM4);\n      _tile_zero(TMM6);\n      if (m1 != 0) {\n        _tile_zero(TMM5);\n        _tile_zero(TMM7);\n      }\n\n      const char* B_blk0 = B + PACKED_INDEX(0, i, KB, TILE_SIZE);\n      const char* B_blk1 = B + PACKED_INDEX(1, i, KB, TILE_SIZE);\n      if (need_unpack) {\n        unpack_B<TB>(Tile0, B_blk0);\n        _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);\n      } else {\n        _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK);\n      }\n\n      if (need_unpack) {\n        unpack_B<TB>(Tile1, B_blk1);\n        _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);\n      } else {\n        _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);\n      }\n\n      if (m0 == TILE_M) {\n        _tile_loadd(TMM2, A[i].qs, lda);\n      } else {\n        unpack_A(Tile23, &A[i], KB, m0);\n        _tile_loadd(TMM2, Tile23, TILE_K);\n      }\n\n      _tile_dpbssd(TMM4, TMM2, TMM0);\n      _tile_dpbssd(TMM6, TMM2, TMM1);\n\n      _tile_stored(TMM4, Tile4(C_cur), TILE_N * sizeof(int32_t));\n   
   _tile_stored(TMM6, Tile6(C_cur), TILE_N * sizeof(int32_t));\n\n      GGML_DISPATCH_BOOL(i > 0, is_acc, [&] {\n        acc_C<TA, TB, is_acc>::apply(C, ldc, Tile4(C_cur), &A[i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m0);\n        acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Tile6(C_cur), &A[i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE),\n                                     m0);\n      });\n      if (m1 != 0) {\n        unpack_A(Tile23, &A[TILE_M * KB + i], KB, m1);\n        _tile_loadd(TMM3, Tile23, TILE_K);\n\n        _tile_dpbssd(TMM5, TMM3, TMM0);\n        _tile_dpbssd(TMM7, TMM3, TMM1);\n        _tile_stored(TMM5, Tile5(C_cur), TILE_N * sizeof(int32_t));\n        _tile_stored(TMM7, Tile7(C_cur), TILE_N * sizeof(int32_t));\n        GGML_DISPATCH_BOOL(i > 0, is_acc, [&] {\n          acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc, ldc, Tile5(C_cur), &A[TILE_M * KB + i], KB,\n                                       B + PACKED_INDEX(0, i, KB, TILE_SIZE), m1);\n          acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_cur), &A[TILE_M * KB + i], KB,\n                                       B + PACKED_INDEX(1, i, KB, TILE_SIZE), m1);\n        });\n      }\n    }\n  }\n  return;\n}\n\ntemplate <typename TA, typename TB, typename TC, int BLOCK_K,\n          typename std::enable_if<is_type_qkk<TB>::value, int>::type = 0>\nvoid tinygemm_kernel_amx(int M, int N, int KB, const void* RESTRICT _A, const void* RESTRICT _B, float* RESTRICT C,\n                         int ldc) {\n  static_assert(std::is_same<TA, block_q8_K>::value);\n  const int TILE_SIZE = get_tile_size<TB>();\n\n  GGML_ASSERT(M <= 2 * TILE_M && N == 2 * TILE_N);\n  const TA* RESTRICT A = static_cast<const TA*>(_A);\n  const char* RESTRICT B = static_cast<const char*>(_B);\n\n  const int m0 = std::min(M, TILE_M);\n  const int m1 = std::max(M - TILE_M, 0);\n  // const int lda = KB * sizeof(TA);\n\n  static thread_local int8_t Tile0[TILE_N * TILE_K];\n  static thread_local int8_t 
Tile1[TILE_N * TILE_K];\n  static thread_local int8_t Tile23[TILE_M * TILE_K];\n\n  // mat mul result for each group\n  static thread_local int32_t Tile4[TILE_M * TILE_N];\n  static thread_local int32_t Tile5[TILE_M * TILE_N];\n  static thread_local int32_t Tile6[TILE_M * TILE_N];\n  static thread_local int32_t Tile7[TILE_M * TILE_N];\n\n  // sum of each QK_K block, contains 8 groups, int32\n  static thread_local int32_t Sumi4[TILE_M * TILE_N];\n  static thread_local int32_t Sumi5[TILE_M * TILE_N];\n  static thread_local int32_t Sumi6[TILE_M * TILE_N];\n  static thread_local int32_t Sumi7[TILE_M * TILE_N];\n\n  const int k_group_size = std::is_same<TB, block_q6_K>::value ? 16 : 32;\n  for (int i = 0; i < KB; ++i) {\n    // step 1: accumulate the quants across 8 groups, each group with 32\n    for (int k = 0; k < QK_K / k_group_size; ++k) {\n      GGML_DISPATCH_BOOL(k > 0, is_acc, [&] {\n        _tile_zero(TMM4);\n        _tile_zero(TMM6);\n\n        unpack_B<TB>(Tile0, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k);\n        _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);\n\n        unpack_B<TB>(Tile1, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k);\n        _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);\n\n        unpack_A<TB>(Tile23, &A[i], KB, k, m0);\n        _tile_loadd(TMM2, Tile23, TILE_K);\n\n        _tile_dpbssd(TMM4, TMM2, TMM0);\n        _tile_dpbssd(TMM6, TMM2, TMM1);\n\n        _tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t));\n        _tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t));\n\n        scale_C<TB, is_acc>(Tile4, Sumi4, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k, m0);\n        scale_C<TB, is_acc>(Tile6, Sumi6, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k, m0);\n\n        if (m1 != 0) {\n          _tile_zero(TMM5);\n          _tile_zero(TMM7);\n\n          unpack_A<TB>(Tile23, &A[TILE_M * KB + i], KB, k, m1);\n          _tile_loadd(TMM3, Tile23, TILE_K);\n\n          _tile_dpbssd(TMM5, TMM3, TMM0);\n          _tile_dpbssd(TMM7, TMM3, TMM1);\n\n          
_tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t));\n          _tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t));\n\n          scale_C<TB, is_acc>(Tile5, Sumi5, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k, m1);\n          scale_C<TB, is_acc>(Tile7, Sumi7, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k, m1);\n        }\n      });\n    }\n\n    // step 2: accumulate the mins\n    GGML_DISPATCH_BOOL(i > 0, is_acc, [&] {\n      acc_C<TA, TB, is_acc>::apply(C, ldc, Sumi4, &A[i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m0);\n      acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Sumi6, &A[i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m0);\n      if (m1 != 0) {\n        acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc, ldc, Sumi5, &A[TILE_M * KB + i], KB,\n                                     B + PACKED_INDEX(0, i, KB, TILE_SIZE), m1);\n        acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc + TILE_N, ldc, Sumi7, &A[TILE_M * KB + i], KB,\n                                     B + PACKED_INDEX(1, i, KB, TILE_SIZE), m1);\n      }\n    });\n  }\n  return;\n}\n\n}  // anonymous namespace\n\n#define ARCH_GET_XCOMP_PERM 0x1022\n#define ARCH_REQ_XCOMP_PERM 0x1023\n#define XFEATURE_XTILECFG 17\n#define XFEATURE_XTILEDATA 18\n\nbool ggml_amx_init() {\n#if defined(__gnu_linux__)\n  if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {\n    fprintf(stderr, \"AMX is not ready to be used!\\n\");\n    return false;\n  }\n  return true;\n#elif defined(_WIN32)\n  return true;\n#endif\n}\n\nbool ggml_compute_forward_mul_mat_use_amx(struct ggml_tensor* dst) {\n  static thread_local bool is_first_time = true;\n  if (is_first_time) {\n#pragma omp single\n    { ggml_amx_init(); }\n\n    // load tile config\n    ggml_tile_config_init();\n  }\n  is_first_time = false;\n\n  const struct ggml_tensor* src0 = dst->src[0];\n  const struct ggml_tensor* src1 = dst->src[1];\n\n  const enum ggml_type type = src0->type;\n  const int64_t ne0 = dst->ne[0];\n\n  bool is_training = src0->grad || 
src1->grad;\n\n  // amx kernels enables for Q4_0, Q4_1, Q8_0, F16\n  // Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256\n  bool has_amx_kernels = (type == GGML_TYPE_Q4_0) || (type == GGML_TYPE_Q4_1) || (type == GGML_TYPE_Q8_0) ||\n#ifndef GGML_QKK_64\n                         // only enabled for QK_K == 256\n                         (type == GGML_TYPE_Q4_K) || (type == GGML_TYPE_Q5_K) || (type == GGML_TYPE_Q6_K) ||\n                         (type == GGML_TYPE_IQ4_XS) ||\n#endif\n                         (type == GGML_TYPE_F16);\n\n  // handle only 2d gemm for now\n  auto is_contiguous_2d = [](const struct ggml_tensor* t) {\n    return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;\n  };\n\n  return dst->op != GGML_OP_MUL_MAT_ID && is_contiguous_2d(src0) && is_contiguous_2d(src1) && !is_training &&\n         src1->type == GGML_TYPE_F32 && has_amx_kernels &&\n         // out features is 32x\n         ne0 % (TILE_N * 2) == 0;\n}\n\n// NB: mixed dtype gemm with Advanced Matrix Extensions (Intel AMX)\n//\n// src0: weight in shape of {N, K}, quantized\n// src1: input  in shape of {M, K}, float32\n// dst:  output in shape of {M, N}, float32\n//\n// the function performs: dst = src1 @ src0.T\n//\nvoid ggml_mul_mat_amx(struct ggml_tensor* dst, int nth, int ith, void* wdata, int wsize) {\n  struct ggml_tensor* src0 = dst->src[0];\n  struct ggml_tensor* src1 = dst->src[1];\n\n  const enum ggml_type TYPE = src0->type;\n\n  // f16 only has avx512 kernels for now,\n  // amx kernels will be added once 6th gen xeon is released.\n  const bool is_floating_type = TYPE == GGML_TYPE_F16;\n\n  const int M = dst->ne[1];\n  const int N = dst->ne[0];\n  const int K = src0->ne[0];\n  const int ldc = dst->nb[1] / dst->nb[0];\n\n  if (is_floating_type) {\n    constexpr int BLOCK_M = 4;\n    constexpr int BLOCK_N = 6;\n    const int MB = div_up(M, BLOCK_M);\n    const int NB = div_up(N, BLOCK_N);\n\n    parallel_for(nth, ith, MB * NB, [&](int begin, int end) {\n      
GGML_DISPATCH_FLOATING_TYPES(TYPE, [&] {\n        for (int i = begin; i < end; ++i) {\n          int mb = i / NB;\n          int nb = i % NB;\n\n          int mb_start = mb * BLOCK_M;\n          int mb_size = std::min(BLOCK_M, M - mb_start);\n          int nb_start = nb * BLOCK_N;\n          int nb_size = std::min(BLOCK_N, N - nb_start);\n\n          switch (mb_size << 4 | nb_size) {\n            case 0x12:\n              LAUNCH_TINYGEMM_KERNEL_AVX(1, 2);\n              break;\n            case 0x14:\n              LAUNCH_TINYGEMM_KERNEL_AVX(1, 4);\n              break;\n            case 0x16:\n              LAUNCH_TINYGEMM_KERNEL_AVX(1, 6);\n              break;\n            case 0x22:\n              LAUNCH_TINYGEMM_KERNEL_AVX(2, 2);\n              break;\n            case 0x24:\n              LAUNCH_TINYGEMM_KERNEL_AVX(2, 4);\n              break;\n            case 0x26:\n              LAUNCH_TINYGEMM_KERNEL_AVX(2, 6);\n              break;\n            case 0x32:\n              LAUNCH_TINYGEMM_KERNEL_AVX(3, 2);\n              break;\n            case 0x34:\n              LAUNCH_TINYGEMM_KERNEL_AVX(3, 4);\n              break;\n            case 0x36:\n              LAUNCH_TINYGEMM_KERNEL_AVX(3, 6);\n              break;\n            case 0x42:\n              LAUNCH_TINYGEMM_KERNEL_AVX(4, 2);\n              break;\n            case 0x44:\n              LAUNCH_TINYGEMM_KERNEL_AVX(4, 4);\n              break;\n            case 0x46:\n              LAUNCH_TINYGEMM_KERNEL_AVX(4, 6);\n              break;\n            default:\n              fprintf(stderr, \"Unexpected block size!\\n\");\n          }\n        }\n      });\n    });\n    return;\n  }\n\n#pragma omp single\n  {\n    GGML_DISPATCH_QTYPES(TYPE, [&] {\n      const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);\n      GGML_ASSERT(wsize >= int(M * row_size_A));\n\n      // Q4_0, Q4_1, Q8_0 handles 1 TILE_K per blck_size\n      // Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size\n      
GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size);\n      // pack mat B to vnni format\n      if (src0->extra == nullptr) {\n        const size_t row_size_B = get_row_size<type, blck_size>(K);\n        src0->extra = aligned_alloc(64, N * row_size_B);\n        convert_B_packed_format<type, blck_size>((void*)src0->extra, (const type*)src0->data, N, K);\n      }\n\n      const float* A_data = static_cast<const float*>(src1->data);\n      for (int m = 0; m < M; ++m) {\n        from_float<vec_dot_type>(A_data + m * K, (char*)wdata + m * row_size_A, K);\n      }\n    });\n  }\n\n  GGML_ASSERT(src0->extra != nullptr);\n  if (M == 1) {\n    // MB = 1 and handle 8 tiles in each block\n    constexpr int kTilesN = 4;\n    constexpr int BLOCK_N = TILE_N * kTilesN;\n    const int NB = div_up(N, BLOCK_N);\n\n    parallel_for(nth, ith, NB, [&](int begin, int end) {\n      GGML_DISPATCH_QTYPES(TYPE, [&] {\n        const int KB = K / blck_size;\n        const int TILE_SIZE = get_tile_size<type>();\n        const int row_size_A = KB * sizeof(vec_dot_type);\n        for (int i = begin; i < end; ++i) {\n          int nb = i;\n          int nb_start = nb * BLOCK_N;\n          int nb_size = std::min(BLOCK_N, N - nb_start);  // 32, 64, 96\n\n          switch (nb_size) {\n            // case 160: LAUNCH_TINYGEMM_KERNEL_VNNI(160); break;\n            case 128:\n              LAUNCH_TINYGEMM_KERNEL_VNNI(128);\n              break;\n            case 96:\n              LAUNCH_TINYGEMM_KERNEL_VNNI(96);\n              break;\n            case 64:\n              LAUNCH_TINYGEMM_KERNEL_VNNI(64);\n              break;\n            case 32:\n              LAUNCH_TINYGEMM_KERNEL_VNNI(32);\n              break;\n            default:\n              fprintf(stderr, \"Unexpected n block size!\\n\");\n          }\n        }\n      });\n    });\n    return;\n  }\n\n  // handle 4 tiles at a time\n  constexpr int BLOCK_M = TILE_M * 2;\n  constexpr int BLOCK_N = TILE_N * 2;\n  const int MB = 
div_up(M, BLOCK_M);\n  const int NB = div_up(N, BLOCK_N);\n\n  parallel_for(nth, ith, MB * NB, [&](int begin, int end) {\n    GGML_DISPATCH_QTYPES(TYPE, [&] {\n      const int KB = K / blck_size;\n      const int TILE_SIZE = get_tile_size<type>();\n      const int row_size_A = KB * sizeof(vec_dot_type);\n\n      for (int i = begin; i < end; ++i) {\n        int mb = i / NB;\n        int nb = i % NB;\n\n        int mb_start = mb * BLOCK_M;\n        int mb_size = std::min(BLOCK_M, M - mb_start);\n        int nb_start = nb * BLOCK_N;\n        int nb_size = BLOCK_N;\n\n        tinygemm_kernel_amx<vec_dot_type, type, float, blck_size>(\n            mb_size, nb_size, KB, (const char*)wdata + mb_start * row_size_A,\n            (const char*)src0->extra + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE),\n            (float*)dst->data + mb_start * N + nb_start, ldc);\n      }\n    });\n  });\n}\n\n#else  // if defined(__AMX_INT8__)\n\nbool ggml_amx_init() {\n  fprintf(stderr, \"GGML is not compiled with AMX support!\\n\");\n  return false;\n}\n\nbool ggml_compute_forward_mul_mat_use_amx(struct ggml_tensor* dst) {\n  GGML_UNUSED(dst);\n  fprintf(stderr, \"GGML is not compiled with AMX support!\\n\");\n  return false;\n}\n\nvoid ggml_mul_mat_amx(struct ggml_tensor* dst, int nth, int ith, void* wdata, int wsize) {\n  GGML_UNUSED(dst);\n  GGML_UNUSED(nth);\n  GGML_UNUSED(ith);\n  GGML_UNUSED(wdata);\n  GGML_UNUSED(wsize);\n  fprintf(stderr, \"GGML is not compiled with AMX support!\\n\");\n}\n\n#endif  // if defined(__AMX_INT8__)\n\nvoid test_gemm() {\n  std::mt19937 gen(123);\n  // const int m=10,n=10,k=10;\n  const int m = 100, n = 100, k = 1024;\n  Mat<float> a(m, k, Layout::RowMajor), b(k, n, Layout::ColumnMajor);\n  a.random(gen);\n  b.random(gen);\n\n  a.print();\n  b.print();\n\n  ggml_type a_type = GGML_TYPE_Q4_K;\n  a.quant(a_type);\n  b.quant(ggml_internal_get_type_traits(a_type).vec_dot_type);\n\n  auto c = a.mul_check(b);\n\n  // quantize_row_q4_K_reference(a.data, block_q4_K 
*restrict y, int64_t k)\n\n  c.print();\n}\n\nint main() {\n  // int32_t x[1000]={};\n  // int32_t y[1000]={};\n  // for(int i=0;i<1000;i++){\n  //   x[i] = i;\n  // }\n  // // transpose_16x16_32bit(reinterpret_cast<__m512i*>(x));\n  // // transpose_16x4_32bit(reinterpret_cast<__m512i*>(x),(__m512i*)y);\n  // transpose_8x8_32bit((__m256i*)x,  (__m256i*)y);\n  // for(int i=0;i<300;i++){\n  //   if(i%8==0) printf(\"\\n\");\n  //   printf(\"%d \",x[i]);\n  // }\n  // for (int i = 0; i < 300; i++) {\n  //   if (i % 8 == 0)\n  //     printf(\"\\n\");\n  //   printf(\"%d \", y[i]);\n  // }\n\n  // block_q8_0 test[20] = {};\n  // for(int i=0;i<20;i++){\n  //   for(int j=0;j<32;j++){\n  //     test[i].qs[j] = i*32+j;\n  //   }\n  //   test[i].d = 0xffff;\n  // }\n  // uint8_t test_out[1000];\n\n  // for (int i = 0; i < 512; i++) {\n  //   if (i % 32 == 0)\n  //     printf(\"\\n\");\n  //   printf(\"%d \", test[i/32].qs[i%32]);\n\n  // }\n\n  // pack_B(test_out, test, 1);\n\n  // for(int i=0;i<512;i++){\n  //   if(i%32==0) printf(\"\\n\");\n  //   printf(\"%d \",test_out[i]);\n  // }\n\n  test_gemm();\n\n  return 0;\n}\n"
  },
  {
    "path": "kt-kernel/operators/amx/test/mmq.cpp",
    "content": "\n#if defined(__GNUC__)\n#pragma GCC diagnostic ignored \"-Wpedantic\"\n#pragma GCC diagnostic ignored \"-Wunused-local-typedefs\"\n#endif\n\n#include \"mmq.h\"\n\n#include <algorithm>\n#include <type_traits>\n\n#include \"llama.cpp/ggml-impl.h\"\n#include \"llama.cpp/ggml-quants.h\"\n\n#if defined(__gnu_linux__)\n#include <sys/syscall.h>\n#include <unistd.h>\n#endif\n\n#if defined(_OPENMP)\n#include <omp.h>\n#endif\n\n#if (defined(_WIN32) || defined(_WIN64))\n#define RESTRICT __restrict\n#else\n#define RESTRICT __restrict__\n#endif\n\n#if (defined(_WIN32) || defined(_WIN64))\n#define ALWAYS_INLINE __forceinline\n#elif __has_attribute(always_inline) || defined(__GNUC__)\n#define ALWAYS_INLINE __attribute__((__always_inline__)) inline\n#else\n#define ALWAYS_INLINE inline\n#endif\n\n#if defined(__AMX_INT8__)\n\nnamespace {\n\n#define TILE_M 16\n#define TILE_N 16\n#define TILE_K 32\n#define VNNI_BLK 4\n\n#define AMX_BLK_SIZE 32\n\n#define TMM0 0\n#define TMM1 1\n#define TMM2 2\n#define TMM3 3\n#define TMM4 4\n#define TMM5 5\n#define TMM6 6\n#define TMM7 7\n\n// parallel routines\ntemplate <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>\ninline T div_up(T x, T y) {\n  return (x + y - 1) / y;\n}\n\ntemplate <typename T>\nvoid balance211(T n, T nth, T ith, T& n_start, T& n_end) {\n#if 0\n  // onednn partition pattern\n  T& n_my = n_end;\n  if (nth <= 1 || n == 0) {\n    n_start = 0;\n    n_my = n;\n  } else {\n    T n1 = div_up(n, nth);\n    T n2 = n1 - 1;\n    T T1 = n - n2 * nth;\n    n_my = ith < T1 ? n1 : n2;\n    n_start = ith <= T1 ? 
ith*n1 : T1 * n1 + (ith - T1) * n2;\n  }\n  n_end += n_start;\n#else\n  // pytorch aten partition pattern\n  T n_my = div_up(n, nth);\n  n_start = ith * n_my;\n  n_end = std::min(n_start + n_my, n);\n#endif\n}\n\ntemplate <typename func_t>\ninline void parallel_for(int nth, int ith, int n, const func_t& f) {\n  // int nth = omp_get_num_threads();\n  // int ith = omp_get_thread_num();\n  int tbegin, tend;\n  balance211(n, nth, ith, tbegin, tend);\n  f(tbegin, tend);\n}\n\n// Forced unrolling\ntemplate <int n>\nstruct Unroll {\n  template <typename Func, typename... Args>\n  ALWAYS_INLINE void operator()(const Func& f, Args... args) const {\n    Unroll<n - 1>{}(f, args...);\n    f(std::integral_constant<int, n - 1>{}, args...);\n  }\n};\n\ntemplate <>\nstruct Unroll<1> {\n  template <typename Func, typename... Args>\n  ALWAYS_INLINE void operator()(const Func& f, Args... args) const {\n    f(std::integral_constant<int, 0>{}, args...);\n  }\n};\n\n// type traits\ntemplate <typename T>\nstruct PackedTypes {};\ntemplate <>\nstruct PackedTypes<block_q4_0> {\n  using type = int8_t;\n};\ntemplate <>\nstruct PackedTypes<block_q4_1> {\n  using type = uint8_t;\n};\ntemplate <>\nstruct PackedTypes<block_q8_0> {\n  using type = int8_t;\n};\ntemplate <typename T>\nusing packed_B_type = typename PackedTypes<T>::type;\n\ntemplate <typename T>\nstruct do_compensate : std::integral_constant<bool, std::is_same<T, block_q8_0>::value> {};\n\ntemplate <typename T>\nstruct do_unpack\n    : std::integral_constant<bool, std::is_same<T, block_q4_0>::value || std::is_same<T, block_q4_1>::value> {};\n\ntemplate <typename T>\nstruct is_type_qkk\n    : std::integral_constant<bool, std::is_same<T, block_q4_K>::value || std::is_same<T, block_q5_K>::value ||\n                                       std::is_same<T, block_q6_K>::value || std::is_same<T, block_iq4_xs>::value> {};\n\n#define GGML_DISPATCH_FLOATING_TYPES(TYPE, ...)              
\\\n  [&] {                                                      \\\n    switch (TYPE) {                                          \\\n      case GGML_TYPE_F16: {                                  \\\n        using type = ggml_fp16_t;                            \\\n        constexpr int blck_size = 16;                        \\\n        return __VA_ARGS__();                                \\\n      }                                                      \\\n      case GGML_TYPE_BF16: {                                 \\\n        using type = ggml_bf16_t;                            \\\n        constexpr int blck_size = 32;                        \\\n        return __VA_ARGS__();                                \\\n      }                                                      \\\n      default:                                               \\\n        fprintf(stderr, \"Unsupported floating data type\\n\"); \\\n    }                                                        \\\n  }()\n\n#define GGML_DISPATCH_QTYPES(QT, ...)                         
\\\n  [&] {                                                       \\\n    switch (QT) {                                             \\\n      case GGML_TYPE_Q4_0: {                                  \\\n        using type = block_q4_0;                              \\\n        using vec_dot_type = block_q8_0;                      \\\n        constexpr int blck_size = QK4_0;                      \\\n        return __VA_ARGS__();                                 \\\n      }                                                       \\\n      case GGML_TYPE_Q4_1: {                                  \\\n        using type = block_q4_1;                              \\\n        using vec_dot_type = block_q8_1;                      \\\n        constexpr int blck_size = QK4_1;                      \\\n        return __VA_ARGS__();                                 \\\n      }                                                       \\\n      case GGML_TYPE_Q8_0: {                                  \\\n        using type = block_q8_0;                              \\\n        using vec_dot_type = block_q8_0;                      \\\n        constexpr int blck_size = QK8_0;                      \\\n        return __VA_ARGS__();                                 \\\n      }                                                       \\\n      case GGML_TYPE_Q4_K: {                                  \\\n        using type = block_q4_K;                              \\\n        using vec_dot_type = block_q8_K;                      \\\n        constexpr int blck_size = QK_K;                       \\\n        return __VA_ARGS__();                                 \\\n      }                                                       \\\n      case GGML_TYPE_Q5_K: {                                  \\\n        using type = block_q5_K;                              \\\n        using vec_dot_type = block_q8_K;                      \\\n        constexpr int blck_size = QK_K;                       \\\n        return 
__VA_ARGS__();                                 \\\n      }                                                       \\\n      case GGML_TYPE_Q6_K: {                                  \\\n        using type = block_q6_K;                              \\\n        using vec_dot_type = block_q8_K;                      \\\n        constexpr int blck_size = QK_K;                       \\\n        return __VA_ARGS__();                                 \\\n      }                                                       \\\n      case GGML_TYPE_IQ4_XS: {                                \\\n        using type = block_iq4_xs;                            \\\n        using vec_dot_type = block_q8_K;                      \\\n        constexpr int blck_size = QK_K;                       \\\n        return __VA_ARGS__();                                 \\\n      }                                                       \\\n      default:                                                \\\n        fprintf(stderr, \"Unsupported quantized data type\\n\"); \\\n    }                                                         \\\n  }()\n\n#define GGML_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...) 
\\\n  [&] {                                            \\\n    if (BOOL_V) {                                  \\\n      constexpr bool BOOL_NAME = true;             \\\n      return __VA_ARGS__();                        \\\n    } else {                                       \\\n      constexpr bool BOOL_NAME = false;            \\\n      return __VA_ARGS__();                        \\\n    }                                              \\\n  }()\n\n// define amx tile config data structure\nstruct tile_config_t {\n  uint8_t palette_id = 0;\n  uint8_t start_row = 0;\n  uint8_t reserved_0[14] = {0};\n  uint16_t colsb[16] = {0};\n  uint8_t rows[16] = {0};\n};\n\n// Notes: amx tile config\n//\n// Typically, TMUL multiplies A and B tiles of size 16 x 64 containing INT8 values,\n// and accumulates the result into a 16 x 16 matrix C containing INT32 values.\n//\n// As many GGUF quantized types have a `block_size` of 32, a 16-16-32 config is used\n// instead of the normally used 16-16-64 config.\n//\n//   Block A: {16, 32}, dtype = int8_t\n//   Block B: {16, 32}, dtype = uint8_t/int8_t\n//   Block C: {16, 16}, dtype = int32_t\n//\n// Block B needs to be prepacked to vnni format before feeding into TMUL:\n//   packed_B: from {n, k} to {k/vnni_blk, n, vnni_blk}, viewed in 2d, we get {8, 64}\n//\n// Therefore, we get tileconfig (colsb is the row width in bytes):\n//             A    B    C\n//    rows    16    8   16\n//    colsb   32   64   64\n//\n// For tile distribution, follow a 2-2-4 pattern, e.g. 
A used TMM2-TMM3, B used TMM0-TMM1,\n// C used TMM4-TMM7:\n//            B TMM0  B TMM1\n//    A TMM2  C TMM4  C TMM6\n//    A TMM3  C TMM5  C TMM7\n//\n// Each `amx` kernel handles 4 blocks at a time: 2MB * 2NB, when m < 2 * BLOCK_M, unpack A\n// will be needed.\n//\n// Here another commonly used pattern 1-3-3 is skipped, as it is mostly used when m <= 16;\n// and the single batch gemm (m=1) has a special fast path with `avx512-vnni`.\n//\n// ref: https://www.intel.com/content/www/us/en/developer/articles/code-sample/\n//   advanced-matrix-extensions-intrinsics-functions.html\n//\n\n#define TC_CONFIG_TILE(i, r, cb) \\\n  tc.rows[i] = r;                \\\n  tc.colsb[i] = cb\nvoid ggml_tile_config_init(void) {\n  static thread_local tile_config_t tc;\n  tile_config_t current_tc;\n  _tile_storeconfig(&current_tc);\n\n  // load only when config changes: reload if either colsb or rows differs\n  if (tc.palette_id == 0 || (memcmp(&current_tc.colsb, &tc.colsb, sizeof(uint16_t) * 8) != 0 ||\n                             memcmp(&current_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) {\n    tc.palette_id = 1;\n    tc.start_row = 0;\n    TC_CONFIG_TILE(TMM0, 8, 64);\n    TC_CONFIG_TILE(TMM1, 8, 64);\n    TC_CONFIG_TILE(TMM2, 16, 32);\n    TC_CONFIG_TILE(TMM3, 16, 32);\n    TC_CONFIG_TILE(TMM4, 16, 64);\n    TC_CONFIG_TILE(TMM5, 16, 64);\n    TC_CONFIG_TILE(TMM6, 16, 64);\n    TC_CONFIG_TILE(TMM7, 16, 64);\n    _tile_loadconfig(&tc);\n  }\n}\n\n// we need an extra 16 * 4B (TILE_N * int32_t) for each NB/KB block for compensation.\n// See the notes `s8s8 igemm compensation in avx512-vnni` for detail.\ntemplate <typename TB>\nint get_tile_size() {\n  int tile_size = TILE_N * sizeof(TB);\n  if (do_compensate<TB>::value) {\n    tile_size += TILE_N * sizeof(int32_t);\n  }\n  if (std::is_same<TB, block_q4_K>::value || std::is_same<TB, block_q5_K>::value) {\n    tile_size += TILE_N * 4;\n  }\n  if (std::is_same<TB, block_iq4_xs>::value) {\n    tile_size += TILE_N * 2;\n  }\n  return tile_size;\n}\n\ntemplate <typename TB, 
int BLOCK_K>\nint get_row_size(int K) {\n  int KB = K / BLOCK_K;\n  int row_size = KB * sizeof(TB);\n  if (do_compensate<TB>::value) {\n    row_size += KB * sizeof(int32_t);\n  }\n  if (std::is_same<TB, block_q4_K>::value || std::is_same<TB, block_q5_K>::value) {\n    row_size += KB * 4;\n  }\n  if (std::is_same<TB, block_iq4_xs>::value) {\n    row_size += KB * 2;\n  }\n  return row_size;\n}\n\n// vectorized dtype conversion\ninline float FP16_TO_FP32(ggml_half val) {\n  __m256i v = _mm256_setr_epi16(val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);\n  __m512 o = _mm512_cvtph_ps(v);\n  return _mm512_cvtss_f32(o);\n}\n\ninline __m512 FP16_TO_FP32_VEC(ggml_half val) {\n  __m256i v = _mm256_set1_epi16(val);\n  return _mm512_cvtph_ps(v);\n}\n\n// horizontal reduce\ninline float _mm512_reduce_max_ps(const __m512 x) {\n  __m512 v = x;\n  __m512 v1 = _mm512_shuffle_f32x4(v, v, 0x4E);\n  v = _mm512_max_ps(v, v1);\n  v1 = _mm512_shuffle_f32x4(v, v, 0xB1);\n  v = _mm512_max_ps(v, v1);\n  v1 = _mm512_shuffle_ps(v, v, 0x4E);\n  v = _mm512_max_ps(v, v1);\n  v1 = _mm512_shuffle_ps(v, v, 0xB1);\n  v = _mm512_max_ps(v, v1);\n  return _mm512_cvtss_f32(v);\n}\n\n// transpose utils\n#define SHUFFLE_EPI32(a, b, mask) \\\n  _mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), mask))\ninline void transpose_8x8_32bit(__m256i* v, __m256i* v1) {\n  // unpacking and 32-bit elements\n  v1[0] = _mm256_unpacklo_epi32(v[0], v[1]);\n  v1[1] = _mm256_unpackhi_epi32(v[0], v[1]);\n  v1[2] = _mm256_unpacklo_epi32(v[2], v[3]);\n  v1[3] = _mm256_unpackhi_epi32(v[2], v[3]);\n  v1[4] = _mm256_unpacklo_epi32(v[4], v[5]);\n  v1[5] = _mm256_unpackhi_epi32(v[4], v[5]);\n  v1[6] = _mm256_unpacklo_epi32(v[6], v[7]);\n  v1[7] = _mm256_unpackhi_epi32(v[6], v[7]);\n\n  // shuffling the 32-bit elements\n  v[0] = SHUFFLE_EPI32(v1[0], v1[2], 0x44);\n  v[1] = SHUFFLE_EPI32(v1[0], v1[2], 0xee);\n  v[2] = SHUFFLE_EPI32(v1[4], v1[6], 0x44);\n  v[3] = SHUFFLE_EPI32(v1[4], v1[6], 
0xee);\n  v[4] = SHUFFLE_EPI32(v1[1], v1[3], 0x44);\n  v[5] = SHUFFLE_EPI32(v1[1], v1[3], 0xee);\n  v[6] = SHUFFLE_EPI32(v1[5], v1[7], 0x44);\n  v[7] = SHUFFLE_EPI32(v1[5], v1[7], 0xee);\n\n  // shuffling 128-bit elements\n  v1[0] = _mm256_permute2f128_si256(v[2], v[0], 0x02);\n  v1[1] = _mm256_permute2f128_si256(v[3], v[1], 0x02);\n  v1[2] = _mm256_permute2f128_si256(v[6], v[4], 0x02);\n  v1[3] = _mm256_permute2f128_si256(v[7], v[5], 0x02);\n  v1[4] = _mm256_permute2f128_si256(v[2], v[0], 0x13);\n  v1[5] = _mm256_permute2f128_si256(v[3], v[1], 0x13);\n  v1[6] = _mm256_permute2f128_si256(v[6], v[4], 0x13);\n  v1[7] = _mm256_permute2f128_si256(v[7], v[5], 0x13);\n}\n\ninline void transpose_16x4_32bit(__m512i* r, __m512i* d) {\n  static const __m512i index1 =\n      _mm512_set_epi32(0x0f, 0x0b, 0x07, 0x03, 0x0e, 0x0a, 0x06, 0x02, 0x0d, 0x09, 0x05, 0x01, 0x0c, 0x08, 0x04, 0x00);\n\n  d[0] = _mm512_permutexvar_epi32(index1, r[0]);\n  d[1] = _mm512_permutexvar_epi32(index1, r[1]);\n  d[2] = _mm512_permutexvar_epi32(index1, r[2]);\n  d[3] = _mm512_permutexvar_epi32(index1, r[3]);\n\n  r[0] = _mm512_shuffle_i32x4(d[0], d[1], 0x44);\n  r[1] = _mm512_shuffle_i32x4(d[0], d[1], 0xee);\n  r[2] = _mm512_shuffle_i32x4(d[2], d[3], 0x44);\n  r[3] = _mm512_shuffle_i32x4(d[2], d[3], 0xee);\n\n  d[0] = _mm512_shuffle_i32x4(r[0], r[2], 0x88);\n  d[1] = _mm512_shuffle_i32x4(r[0], r[2], 0xdd);\n  d[2] = _mm512_shuffle_i32x4(r[1], r[3], 0x88);\n  d[3] = _mm512_shuffle_i32x4(r[1], r[3], 0xdd);\n}\n\ninline void transpose_16x16_32bit(__m512i* v) {\n  __m512i v1[16];\n  v1[0] = _mm512_unpacklo_epi32(v[0], v[1]);\n  v1[1] = _mm512_unpackhi_epi32(v[0], v[1]);\n  v1[2] = _mm512_unpacklo_epi32(v[2], v[3]);\n  v1[3] = _mm512_unpackhi_epi32(v[2], v[3]);\n  v1[4] = _mm512_unpacklo_epi32(v[4], v[5]);\n  v1[5] = _mm512_unpackhi_epi32(v[4], v[5]);\n  v1[6] = _mm512_unpacklo_epi32(v[6], v[7]);\n  v1[7] = _mm512_unpackhi_epi32(v[6], v[7]);\n  v1[8] = _mm512_unpacklo_epi32(v[8], v[9]);\n  v1[9] = 
_mm512_unpackhi_epi32(v[8], v[9]);\n  v1[10] = _mm512_unpacklo_epi32(v[10], v[11]);\n  v1[11] = _mm512_unpackhi_epi32(v[10], v[11]);\n  v1[12] = _mm512_unpacklo_epi32(v[12], v[13]);\n  v1[13] = _mm512_unpackhi_epi32(v[12], v[13]);\n  v1[14] = _mm512_unpacklo_epi32(v[14], v[15]);\n  v1[15] = _mm512_unpackhi_epi32(v[14], v[15]);\n\n  v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]);\n  v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]);\n  v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]);\n  v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]);\n  v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]);\n  v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]);\n  v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]);\n  v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]);\n  v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]);\n  v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]);\n  v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]);\n  v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]);\n  v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]);\n  v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]);\n  v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]);\n  v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]);\n\n  v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88);\n  v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88);\n  v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88);\n  v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88);\n  v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd);\n  v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd);\n  v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd);\n  v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd);\n  v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88);\n  v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88);\n  v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88);\n  v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88);\n  v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd);\n  v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd);\n  v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd);\n  v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd);\n\n  v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88);\n  v[1] 
= _mm512_shuffle_i32x4(v1[1], v1[9], 0x88);\n  v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88);\n  v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88);\n  v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88);\n  v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88);\n  v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88);\n  v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88);\n  v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd);\n  v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd);\n  v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd);\n  v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd);\n  v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd);\n  v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd);\n  v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd);\n  v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd);\n}\n\nvoid quantize_row_q8_K_vnni(const float* RESTRICT x, void* RESTRICT vy, int64_t k) {\n  assert(k % QK_K == 0);\n  const int KB = k / QK_K;\n  constexpr int kVecs = QK_K / 16;\n\n  block_q8_K* y = reinterpret_cast<block_q8_K*>(vy);\n\n  // hold 16 float vecs from x\n  __m512 v[kVecs];\n\n  // hold the quants vecs\n  __m512i vq[kVecs / 4];\n\n  // hold the packed quants vecs\n  __m512i vq_packed[kVecs / 4];\n\n  const __m512 signBit = _mm512_set1_ps(-0.f);\n\n  for (int i = 0; i < KB; ++i) {\n    // Compute max(abs(e)) for the block\n    __m512 vamax = _mm512_set1_ps(0.f);\n    for (int j = 0; j < kVecs; ++j) {\n      v[j] = _mm512_loadu_ps(x);\n      x += 16;\n      vamax = _mm512_max_ps(vamax, _mm512_andnot_ps(signBit, v[j]));\n    }\n    const float amax = _mm512_reduce_max_ps(vamax);\n\n    // Quantize these floats\n    const float iscale = 127.f / amax;\n    y[i].d = GGML_FP32_TO_FP16(1 / iscale);\n    const float id = (amax != 0.0f) ? 
iscale : 0.f;\n    const __m512 vscale = _mm512_set1_ps(id);\n\n    // Apply multiplier and round to nearest integer\n    for (int j = 0; j < kVecs; ++j) {\n      v[j] = _mm512_mul_ps(v[j], vscale);\n      v[j] = _mm512_roundscale_ps(v[j], (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\n    }\n\n    // Pack to epi8 vecs\n    for (int j = 0; j < kVecs / 4; ++j) {\n      __m128i q8_0 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 0]));\n      __m128i q8_1 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 1]));\n      __m128i q8_2 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 2]));\n      __m128i q8_3 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 3]));\n\n      __m256i q8_01 = _mm256_insertf128_si256(_mm256_castsi128_si256(q8_0), (q8_1), 1);\n      __m256i q8_23 = _mm256_insertf128_si256(_mm256_castsi128_si256(q8_2), (q8_3), 1);\n\n      vq[j] = _mm512_inserti32x8(_mm512_castsi256_si512(q8_01), q8_23, 1);\n      _mm512_storeu_si512((__m512i*)(y[i].qs + j * 64), vq[j]);\n    }\n\n    // Compute the bsums with vnni\n    transpose_16x4_32bit(vq, vq_packed);\n\n    const __m512i one = _mm512_set1_epi8(1);\n    __m512i sum = _mm512_setzero_si512();\n    for (int k = 0; k < 4; ++k) {\n      sum = _mm512_dpbusd_epi32(sum, one, vq_packed[k]);\n    }\n    _mm256_storeu_si256((__m256i*)(y[i].bsums), _mm512_cvtepi32_epi16(sum));\n  }\n}\n\n// quantize A from float to `vec_dot_type`\ntemplate <typename T>\ninline void from_float(const float* x, char* vy, int64_t k);\n\ntemplate <>\ninline void from_float<block_q8_0>(const float* x, char* vy, int64_t k) {\n  quantize_row_q8_0(x, vy, k);\n}\n\ntemplate <>\ninline void from_float<block_q8_1>(const float* x, char* vy, int64_t k) {\n  quantize_row_q8_1(x, vy, k);\n}\n\ntemplate <>\ninline void from_float<block_q8_K>(const float* x, char* vy, int64_t k) {\n#if 1\n  // TODO: this is reference impl!\n  quantize_row_q8_K(x, vy, k);\n#else\n  quantize_row_q8_K_vnni(x, vy, k);\n#endif\n}\n\n// load A from memory to 
array when nrows can not fill in whole tile\nvoid unpack_A(int8_t* RESTRICT tile, const block_q8_0* RESTRICT A, int lda, int nr) {\n  assert(nr != TILE_M);\n  for (int m = 0; m < nr; ++m) {\n    const __m256i v = _mm256_loadu_si256((const __m256i*)(A[m * lda].qs));\n    _mm256_storeu_si256((__m256i*)(tile + m * TILE_K), v);\n  }\n}\n\nvoid unpack_A(int8_t* RESTRICT tile, const block_q8_1* RESTRICT A, int lda, int nr) {\n  assert(nr != TILE_M);\n  for (int m = 0; m < nr; ++m) {\n    const __m256i v = _mm256_loadu_si256((const __m256i*)(A[m * lda].qs));\n    _mm256_storeu_si256((__m256i*)(tile + m * TILE_K), v);\n  }\n}\n\ntemplate <typename TB>\nvoid unpack_A(int8_t* RESTRICT tile, const block_q8_K* RESTRICT A, int lda, int k, int nr) {\n  assert(nr <= TILE_M);\n  for (int m = 0; m < nr; ++m) {\n    const __m256i v = _mm256_loadu_si256((const __m256i*)(A[m * lda].qs + k * 32));\n    _mm256_storeu_si256((__m256i*)(tile + m * TILE_K), v);\n  }\n}\n\ntemplate <>\nvoid unpack_A<block_q6_K>(int8_t* RESTRICT tile, const block_q8_K* RESTRICT A, int lda, int k, int nr) {\n  assert(nr <= TILE_M);\n  // zero padding k from 16 to 32, so that we don't have to re-config amx\n  const __m128i zero = _mm_setzero_si128();\n  for (int m = 0; m < nr; ++m) {\n    const __m128i v = _mm_loadu_si128((const __m128i*)(A[m * lda].qs + k * 16));\n    const __m256i r = _mm256_insertf128_si256(_mm256_castsi128_si256(v), zero, 1);\n    _mm256_storeu_si256((__m256i*)(tile + m * TILE_K), r);\n  }\n}\n\n#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)\ninline __m256i bytes_from_nibbles_32(const uint8_t* rsi) {\n  const __m128i tmp = _mm_loadu_si128((const __m128i*)rsi);\n  const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);\n  const __m256i lowMask = _mm256_set1_epi8(0xF);\n  return _mm256_and_si256(lowMask, bytes);\n}\n\n// used for block_q4_K\ninline __m512i bytes_from_nibbles_64(const uint8_t* rsi) {\n  const __m256i tmp = 
_mm256_loadu_si256((const __m256i*)rsi);\n  const __m256i lowMask = _mm256_set1_epi8(0xF);\n  const __m256i q4l = _mm256_and_si256(tmp, lowMask);\n  const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(tmp, 4), lowMask);\n  return _mm512_inserti32x8(_mm512_castsi256_si512(q4l), q4h, 1);\n}\n\n// used for block_q5_K\ninline __m512i bytes_from_nibbles_64(const uint8_t* qs, const uint8_t* qh, int k) {\n  const __m256i lowMask = _mm256_set1_epi8(0xF);\n  __m256i hmask = _mm256_set1_epi8(1);\n  hmask = _mm256_slli_epi16(hmask, k);\n\n  const __m256i q5bits = _mm256_loadu_si256((const __m256i*)qs);\n  const __m256i hbits = _mm256_loadu_si256((const __m256i*)qh);\n\n  const __m256i q5l_0 = _mm256_and_si256(q5bits, lowMask);\n  const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), k + 0), 4);\n  const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0);\n  hmask = _mm256_slli_epi16(hmask, 1);\n\n  const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), lowMask);\n  const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), k + 1), 4);\n  const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1);\n\n  return _mm512_inserti32x8(_mm512_castsi256_si512(q5_0), q5_1, 1);\n}\n\n// used for block_q6_K\ninline void bytes_from_nibbles_128(__m512i& r0, __m512i& r1, const uint8_t* qs, const uint8_t* qh) {\n  const __m256i m4 = _mm256_set1_epi8(0xF);\n  const __m256i m2 = _mm256_set1_epi8(0x3);\n\n  const __m256i q6bits1 = _mm256_loadu_si256((const __m256i*)qs);\n  const __m256i q6bits2 = _mm256_loadu_si256((const __m256i*)(qs + 32));\n  const __m256i q6bitsH = _mm256_loadu_si256((const __m256i*)qh);\n\n  const __m256i q6h_0 = _mm256_slli_epi16(_mm256_and_si256(q6bitsH, m2), 4);\n  const __m256i q6h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 2), m2), 4);\n  const __m256i q6h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 4), m2), 4);\n  const __m256i q6h_3 = 
_mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 6), m2), 4);\n\n  const __m256i q6_0 = _mm256_or_si256(_mm256_and_si256(q6bits1, m4), q6h_0);\n  const __m256i q6_1 = _mm256_or_si256(_mm256_and_si256(q6bits2, m4), q6h_1);\n  const __m256i q6_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q6bits1, 4), m4), q6h_2);\n  const __m256i q6_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q6bits2, 4), m4), q6h_3);\n\n  r0 = _mm512_inserti32x8(_mm512_castsi256_si512(q6_0), q6_1, 1);\n  r1 = _mm512_inserti32x8(_mm512_castsi256_si512(q6_2), q6_3, 1);\n}\n\ninline __m512i packNibbles(__m512i r0, __m512i r1) { return _mm512_or_si512(r0, _mm512_slli_epi16(r1, 4)); }\n\ntemplate <typename TB>\ninline void pack_qs(void* RESTRICT packed_B, const TB* RESTRICT B, int KB) {\n  int8_t tmp[8 * 64];\n  __m256i v[8], v2[8];\n  for (int n = 0; n < 8; ++n) {\n    v[n] = bytes_from_nibbles_32(B[n * KB].qs);\n  }\n  transpose_8x8_32bit(v, v2);\n  for (int n = 0; n < 8; ++n) {\n    _mm256_storeu_si256((__m256i*)(tmp + n * 64), v2[n]);\n  }\n  for (int n = 0; n < 8; ++n) {\n    v[n] = bytes_from_nibbles_32(B[(n + 8) * KB].qs);\n  }\n  transpose_8x8_32bit(v, v2);\n  for (int n = 0; n < 8; ++n) {\n    _mm256_storeu_si256((__m256i*)(tmp + n * 64 + 32), v2[n]);\n  }\n\n  // pack again with 128 to fully utilize vector length\n  for (int n = 0; n < 8; n += 2) {\n    __m512i r0 = _mm512_loadu_si512((const __m512i*)(tmp + n * 64));\n    __m512i r1 = _mm512_loadu_si512((const __m512i*)(tmp + n * 64 + 64));\n    __m512i r1r0 = packNibbles(r0, r1);\n    _mm512_storeu_si512((__m512i*)((char*)packed_B + n * 32), r1r0);\n  }\n}\n\ntemplate <>\ninline void pack_qs<block_q8_0>(void* RESTRICT packed_B, const block_q8_0* RESTRICT B, int KB) {\n  __m256i v[8], v2[8];\n  for (int n = 0; n < 8; ++n) {\n    v[n] = _mm256_loadu_si256((const __m256i*)(B[n * KB].qs));\n  }\n  transpose_8x8_32bit(v, v2);\n  for (int n = 0; n < 8; ++n) {\n    _mm256_storeu_si256((__m256i*)((char*)packed_B + n * 
64), v2[n]);\n  }\n  for (int n = 0; n < 8; ++n) {\n    v[n] = _mm256_loadu_si256((const __m256i*)(B[(n + 8) * KB].qs));\n  }\n  transpose_8x8_32bit(v, v2);\n  for (int n = 0; n < 8; ++n) {\n    _mm256_storeu_si256((__m256i*)((char*)packed_B + n * 64 + 32), v2[n]);\n  }\n}\n\ntemplate <>\ninline void pack_qs<block_q4_K>(void* RESTRICT packed_B, const block_q4_K* RESTRICT B, int KB) {\n  __m512i v[16];\n  // QK_K 256 with 8 groups, handle 2 groups at a time\n  char* pb = (char*)packed_B;\n  for (int k = 0; k < QK_K / 64; ++k) {\n    // pack 2 groups { n, g,  k} to {g, k/4, 4n}\n    //          e.g. {16, 2, 32} to {2,   8, 64}\n    for (int n = 0; n < TILE_N; ++n) {\n      v[n] = bytes_from_nibbles_64(B[n * KB].qs + k * 32);\n    }\n\n    transpose_16x16_32bit(v);\n\n    // pack again with 128 to fully utilize vector length\n    for (int n = 0; n < TILE_N; n += 2) {\n      _mm512_storeu_si512((__m512i*)pb, packNibbles(v[n], v[n + 1]));\n      pb += 64;\n    }\n  }\n}\n\ntemplate <>\ninline void pack_qs<block_q5_K>(void* RESTRICT packed_B, const block_q5_K* RESTRICT B, int KB) {\n  __m512i v[16];\n  const __m512i lowMask = _mm512_set1_epi8(0xF);\n  // QK_K 256 with 8 groups, handle 2 groups at a time\n  char* pb = (char*)packed_B;\n  char* ph = (char*)packed_B + (QK_K / 2) * TILE_N;\n  for (int k = 0; k < QK_K / 64; ++k) {\n    // pack 2 groups { n, g,  k} to {g, k/4, 4n}\n    //          e.g. {16, 2, 32} to {2,   8, 64}\n    for (int n = 0; n < TILE_N; ++n) {\n      v[n] = bytes_from_nibbles_64(B[n * KB].qs + k * 32, B[n * KB].qh, /* group */ 2 * k);\n    }\n\n    transpose_16x16_32bit(v);\n\n    // 1. pack lower 4bits with 2 groups\n    for (int n = 0; n < TILE_N; n += 2) {\n      // get lower 4 bits\n      const __m512i r0 = _mm512_and_si512(v[n], lowMask);\n      const __m512i r1 = _mm512_and_si512(v[n + 1], lowMask);\n      _mm512_storeu_si512((__m512i*)pb, packNibbles(r0, r1));\n      pb += 64;\n    }\n\n    // 2. 
pack higher 1bit with 2 groups\n    const __m512i hmask = _mm512_set1_epi8(0x10);\n    for (int g = 0; g < 2; ++g) {\n      __m512i hbits = _mm512_setzero_si512();\n      hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 0], hmask), 4));\n      hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 1], hmask), 3));\n      hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 2], hmask), 2));\n      hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 3], hmask), 1));\n      hbits = _mm512_add_epi8(hbits, _mm512_and_si512(v[g * 8 + 4], hmask));\n      hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 5], hmask), 1));\n      hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 6], hmask), 2));\n      hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 7], hmask), 3));\n      _mm512_storeu_si512((__m512i*)ph, hbits);\n      ph += 64;\n    }\n  }\n}\n\ntemplate <>\ninline void pack_qs<block_q6_K>(void* RESTRICT packed_B, const block_q6_K* RESTRICT B, int KB) {\n  __m512i v[32];\n  const __m512i lowMask = _mm512_set1_epi8(0xF);\n  // QK_K 256 with 8 groups, handle 4 groups at a time\n  char* pb = (char*)packed_B;\n  char* ph = (char*)packed_B + (QK_K / 2) * TILE_N;\n  for (int k = 0; k < QK_K / 128; ++k) {\n    for (int n = 0; n < TILE_N; ++n) {\n      bytes_from_nibbles_128(v[n], v[n + 16], B[n * KB].ql + k * 64, B[n * KB].qh + k * 32);\n    }\n\n    // top half: group 0,1 or 4,5; bottom half: group 2,3 or 6,7\n    transpose_16x16_32bit(v);\n    transpose_16x16_32bit(v + 16);\n\n    // 1. pack lower 4bits with 4 groups\n    for (int n = 0; n < 32; n += 2) {\n      const __m512i r0 = _mm512_and_si512(v[n], lowMask);\n      const __m512i r1 = _mm512_and_si512(v[n + 1], lowMask);\n      _mm512_storeu_si512((__m512i*)pb, packNibbles(r0, r1));\n      pb += 64;\n    }\n\n    // 2. 
pack higher 2bit with 4 groups\n    const __m512i hmask = _mm512_set1_epi8(0x30);\n    for (int g = 0; g < 8; ++g) {\n      __m512i hbits = _mm512_setzero_si512();\n      hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 4 + 0], hmask), 4));\n      hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 4 + 1], hmask), 2));\n      hbits = _mm512_add_epi8(hbits, _mm512_and_si512(v[g * 4 + 2], hmask));\n      hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 4 + 3], hmask), 2));\n      _mm512_storeu_si512((__m512i*)ph, hbits);\n      ph += 64;\n    }\n  }\n}\n\ntemplate <>\ninline void pack_qs<block_iq4_xs>(void* RESTRICT packed_B, const block_iq4_xs* RESTRICT B, int KB) {\n  __m512i v[16];\n  char* pb = (char*)packed_B;\n  for (int k = 0; k < QK_K / 64; ++k) {\n    for (int n = 0; n < TILE_N; ++n) {\n      __m256i r0 = bytes_from_nibbles_32(B[n * KB].qs + k * 32 + 0);\n      __m256i r1 = bytes_from_nibbles_32(B[n * KB].qs + k * 32 + 16);\n      v[n] = _mm512_inserti32x8(_mm512_castsi256_si512(r0), r1, 1);\n    }\n\n    transpose_16x16_32bit(v);\n\n    // pack again with 128 to fully utilize vector length\n    for (int n = 0; n < TILE_N; n += 2) {\n      _mm512_storeu_si512((__m512i*)pb, packNibbles(v[n], v[n + 1]));\n      pb += 64;\n    }\n  }\n}\n\n// pack B to vnni formats in 4bits or 8 bits\nvoid pack_B(void* RESTRICT packed_B, const block_q4_0* RESTRICT B, int KB) {\n  pack_qs(packed_B, B, KB);\n  ggml_half* d0 = reinterpret_cast<ggml_half*>((char*)packed_B + TILE_N * TILE_K / 2);\n  for (int n = 0; n < TILE_N; ++n) {\n    d0[n] = B[n * KB].d;\n  }\n}\n\nvoid pack_B(void* RESTRICT packed_B, const block_q4_1* RESTRICT B, int KB) {\n  pack_qs(packed_B, B, KB);\n  ggml_half* d0 = reinterpret_cast<ggml_half*>((char*)packed_B + TILE_N * TILE_K / 2);\n  ggml_half* m0 = d0 + TILE_N;\n  for (int n = 0; n < TILE_N; ++n) {\n    d0[n] = B[n * KB].d;\n    m0[n] = B[n * KB].m;\n  }\n}\n\ninline void 
s8s8_compensation(void* RESTRICT packed_B) {\n  // packed_B layout:\n  //   quants {TILE_N, TILE_K}  int8_t\n  //   d0     {TILE_N}      ggml_half\n  //   comp   {TILE_N}        int32_t\n  const int offset = TILE_N * TILE_K + TILE_N * sizeof(ggml_half);\n  __m512i vcomp = _mm512_setzero_si512();\n  const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));\n  for (int k = 0; k < 8; ++k) {\n    __m512i vb = _mm512_loadu_si512((const __m512i*)((const char*)packed_B + k * 64));\n    vcomp = _mm512_dpbusd_epi32(vcomp, off, vb);\n  }\n  _mm512_storeu_si512((__m512i*)((char*)(packed_B) + offset), vcomp);\n}\n\nvoid pack_B(void* RESTRICT packed_B, const block_q8_0* RESTRICT B, int KB) {\n  pack_qs(packed_B, B, KB);\n  ggml_half* d0 = reinterpret_cast<ggml_half*>((char*)packed_B + TILE_N * TILE_K);\n  for (int n = 0; n < TILE_N; ++n) {\n    d0[n] = B[n * KB].d;\n  }\n  s8s8_compensation(packed_B);\n}\n\n// convert 8 * {min, scale} from int6 to int8\ninline void unpack_mins_and_scales(const uint8_t* scales, uint32_t* utmp) {\n  const uint32_t kmask1 = 0x3f3f3f3f;\n  const uint32_t kmask2 = 0x0f0f0f0f;\n  const uint32_t kmask3 = 0x03030303;\n\n  memcpy(utmp, scales, 12);\n  utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);\n  const uint32_t uaux = utmp[1] & kmask1;\n  utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);\n  utmp[2] = uaux;\n  utmp[0] &= kmask1;\n}\n\n// packed_B layout:\n//   quants {8, TILE_N, 16}  uint8\n//   scales {8, TILE_N}      uint8\n//   mins   {8, TILE_N}      uint8\n//   d      {TILE_N}     ggml_half\n//   dmin   {TILE_N}     ggml_half\nvoid pack_B(void* RESTRICT packed_B, const block_q4_K* RESTRICT B, int KB) {\n  pack_qs(packed_B, B, KB);\n\n  uint8_t* scales = reinterpret_cast<uint8_t*>((char*)packed_B + (QK_K / 2) * TILE_N);\n  uint8_t* mins = scales + 8 * TILE_N;\n  ggml_half* d = reinterpret_cast<ggml_half*>(mins + 8 * TILE_N);\n  ggml_half* dmin = d + TILE_N;\n\n  union {\n    uint32_t u32[4];\n    
uint8_t u8[16];\n  } s;\n\n  for (int n = 0; n < TILE_N; ++n) {\n    unpack_mins_and_scales(B[n * KB].scales, s.u32);\n    for (int k = 0; k < 8; ++k) {\n      scales[k * TILE_N + n] = s.u8[k];\n      mins[(k >> 1) * TILE_N * 2 + n * 2 + (k & 0x1)] = s.u8[k + 8];\n    }\n    d[n] = B[n * KB].d;\n    dmin[n] = B[n * KB].dmin;\n  }\n}\n\n// packed_B layout:\n//   quants {8, TILE_N, 16}  uint8\n//   qh     {8, TILE_N,  4}  uint8\n//   scales {8, TILE_N}      uint8\n//   mins   {8, TILE_N}      uint8\n//   d      {TILE_N}     ggml_half\n//   dmin   {TILE_N}     ggml_half\nvoid pack_B(void* RESTRICT packed_B, const block_q5_K* RESTRICT B, int KB) {\n  pack_qs(packed_B, B, KB);\n\n  uint8_t* scales = reinterpret_cast<uint8_t*>((char*)packed_B + (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N);\n  uint8_t* mins = scales + 8 * TILE_N;\n  ggml_half* d = reinterpret_cast<ggml_half*>(mins + 8 * TILE_N);\n  ggml_half* dmin = d + TILE_N;\n\n  union {\n    uint32_t u32[4];\n    uint8_t u8[16];\n  } s;\n\n  for (int n = 0; n < TILE_N; ++n) {\n    unpack_mins_and_scales(B[n * KB].scales, s.u32);\n    for (int k = 0; k < 8; ++k) {\n      scales[k * TILE_N + n] = s.u8[k];\n      mins[(k >> 1) * TILE_N * 2 + n * 2 + (k & 0x1)] = s.u8[k + 8];\n    }\n    d[n] = B[n * KB].d;\n    dmin[n] = B[n * KB].dmin;\n  }\n}\n\n// packed_B layout:\n//   quants {16, TILE_N, 8}  uint8\n//   qh     {16, TILE_N, 4}  uint8\n//   scales {16, TILE_N}      uint8\n//   d      {TILE_N}     ggml_half\nvoid pack_B(void* RESTRICT packed_B, const block_q6_K* RESTRICT B, int KB) {\n  pack_qs(packed_B, B, KB);\n\n  uint8_t* scales = reinterpret_cast<uint8_t*>((char*)packed_B + (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N);\n  ggml_half* d = reinterpret_cast<ggml_half*>(scales + 16 * TILE_N);\n  for (int n = 0; n < TILE_N; ++n) {\n    const int8_t* ps = B[n * KB].scales;\n    for (int k = 0; k < 16; ++k) {\n      scales[k * TILE_N + n] = ps[k];\n    }\n    d[n] = B[n * KB].d;\n  }\n}\n\n// packed_B layout:\n//   quants {8, 
TILE_N, 16}  uint8\n//   scales {8, TILE_N}       int8\n//   d      {TILE_N}     ggml_half\nvoid pack_B(void* RESTRICT packed_B, const block_iq4_xs* RESTRICT B, int KB) {\n  pack_qs(packed_B, B, KB);\n\n  int8_t* scales = reinterpret_cast<int8_t*>((char*)packed_B + (QK_K / 2) * TILE_N);\n  ggml_half* d = reinterpret_cast<ggml_half*>(scales + 8 * TILE_N);\n\n  // pack the scales\n  for (int n = 0; n < TILE_N; ++n) {\n    uint16_t sh = B[n * KB].scales_h;\n    for (int k = 0; k < 8; k += 2) {\n      const int16_t ls1 = ((B[n * KB].scales_l[k / 2] & 0xf) | ((sh << 4) & 0x30)) - 32;\n      const int16_t ls2 = ((B[n * KB].scales_l[k / 2] >> 4) | ((sh << 2) & 0x30)) - 32;\n      scales[(k + 0) * TILE_N + n] = ls1;\n      scales[(k + 1) * TILE_N + n] = ls2;\n      sh >>= 4;\n    }\n    d[n] = B[n * KB].d;\n  }\n}\n\ntemplate <typename TB, typename packed_B_t = packed_B_type<TB>>\nvoid unpack_B(packed_B_t* RESTRICT tile, const void* RESTRICT packed_B) {\n  GGML_UNUSED(tile);\n  GGML_UNUSED(packed_B);\n};\n\ntemplate <>\nvoid unpack_B<block_q4_0>(int8_t* RESTRICT tile, const void* RESTRICT packed_B) {\n  const __m512i off = _mm512_set1_epi8(8);\n  const __m512i lowMask = _mm512_set1_epi8(0xF);\n  for (int n = 0; n < 8; n += 2) {\n    __m512i bytes = _mm512_loadu_si512((const __m512i*)((const char*)packed_B + n * 32));\n    const __m512i r0 = _mm512_sub_epi8(_mm512_and_si512(bytes, lowMask), off);\n    const __m512i r1 = _mm512_sub_epi8(_mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask), off);\n    _mm512_storeu_si512((__m512i*)(tile + n * 64 + 0), r0);\n    _mm512_storeu_si512((__m512i*)(tile + n * 64 + 64), r1);\n  }\n}\n\ntemplate <>\nvoid unpack_B<block_q4_1>(uint8_t* RESTRICT tile, const void* RESTRICT packed_B) {\n  const __m512i lowMask = _mm512_set1_epi8(0xF);\n  for (int n = 0; n < 8; n += 2) {\n    __m512i bytes = _mm512_loadu_si512((const __m512i*)((const char*)packed_B + n * 32));\n    const __m512i r0 = _mm512_and_si512(bytes, lowMask);\n    const __m512i r1 
= _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n    _mm512_storeu_si512((__m512i*)(tile + n * 64 + 0), r0);\n    _mm512_storeu_si512((__m512i*)(tile + n * 64 + 64), r1);\n  }\n}\n\n// packed_B_t for QKK is int8_t\ntemplate <typename TB>\nvoid unpack_B(int8_t* RESTRICT tile, const void* RESTRICT packed_B, int k) {\n  const int packed_B_group_size = QK_K / 2 * TILE_N / 8;\n  const char* packed_B_group = (const char*)packed_B + k * packed_B_group_size;\n  const __m512i lowMask = _mm512_set1_epi8(0xF);\n  for (int n = 0; n < 8; n += 2) {\n    __m512i bytes = _mm512_loadu_si512(packed_B_group + n * 32);\n    const __m512i r0 = _mm512_and_si512(bytes, lowMask);\n    const __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n    _mm512_storeu_si512((__m512i*)(tile + n * 64 + 0), r0);\n    _mm512_storeu_si512((__m512i*)(tile + n * 64 + 64), r1);\n  }\n}\n\ntemplate <>\nvoid unpack_B<block_q5_K>(int8_t* RESTRICT tile, const void* RESTRICT packed_B, int k) {\n  // lower 4bits, stride 256 bytes\n  const int packed_l4_group_size = QK_K / 2 * TILE_N / 8;\n  const char* pb = (const char*)packed_B + k * packed_l4_group_size;\n\n  // higher 1bit, stride 64 bytes\n  const int packed_h1_group_size = QK_K / 8 * TILE_N / 8;\n  const char* ph = (const char*)packed_B + (QK_K / 2) * TILE_N + k * packed_h1_group_size;\n  const __m512i hbits = _mm512_loadu_si512(ph);\n\n  const __m512i lowMask = _mm512_set1_epi8(0xF);\n  __m512i hmask0 = _mm512_set1_epi8(0x1);\n  __m512i hmask1 = _mm512_set1_epi8(0x2);\n\n  for (int n = 0; n < 8; n += 2) {\n    __m512i bytes = _mm512_loadu_si512(pb + n * 32);\n    __m512i r0 = _mm512_and_si512(bytes, lowMask);\n    __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n    __m512i h0 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask0), n), 4);\n    __m512i h1 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), n + 1), 4);\n\n    hmask0 = _mm512_slli_epi16(hmask0, 2);\n    
hmask1 = _mm512_slli_epi16(hmask1, 2);\n    r0 = _mm512_add_epi8(r0, h0);\n    r1 = _mm512_add_epi8(r1, h1);\n    _mm512_storeu_si512((__m512i*)(tile + n * 64 + 0), r0);\n    _mm512_storeu_si512((__m512i*)(tile + n * 64 + 64), r1);\n  }\n}\n\ntemplate <>\nvoid unpack_B<block_q6_K>(int8_t* RESTRICT tile, const void* RESTRICT packed_B, int k) {\n  // lower 4bits, stride 128 bytes\n  const int packed_l4_group_size = QK_K / 2 * TILE_N / 16;\n  const char* pb = (const char*)packed_B + k * packed_l4_group_size;\n\n  // higher 2bits, stride 64 bytes\n  const int packed_h2_group_size = QK_K / 4 * TILE_N / 16;\n  const char* ph = (const char*)packed_B + (QK_K / 2) * TILE_N + k * packed_h2_group_size;\n  const __m512i hbits = _mm512_loadu_si512(ph);\n\n  const __m512i off = _mm512_set1_epi8(32);\n  const __m512i lowMask = _mm512_set1_epi8(0xF);\n  __m512i hmask0 = _mm512_set1_epi8(0x3);  // 0011\n  __m512i hmask1 = _mm512_set1_epi8(0xC);  // 1100\n\n  // notes: skip zero padding from row4 to row7 as we have done so in `unpack_A`\n  __m512i bytes = _mm512_loadu_si512(pb);\n  __m512i r0 = _mm512_and_si512(bytes, lowMask);\n  __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n  __m512i h0 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask0), 4);\n  __m512i h1 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask1), 2);\n  _mm512_storeu_si512((__m512i*)(tile + 0), _mm512_sub_epi8(_mm512_add_epi8(r0, h0), off));\n  _mm512_storeu_si512((__m512i*)(tile + 64), _mm512_sub_epi8(_mm512_add_epi8(r1, h1), off));\n\n  hmask0 = _mm512_slli_epi16(hmask0, 4);\n  hmask1 = _mm512_slli_epi16(hmask1, 4);\n\n  bytes = _mm512_loadu_si512(pb + 64);\n  r0 = _mm512_and_si512(bytes, lowMask);\n  r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n  h0 = _mm512_and_si512(hbits, hmask0);\n  h1 = _mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), 2);\n  _mm512_storeu_si512((__m512i*)(tile + 128), _mm512_sub_epi8(_mm512_add_epi8(r0, h0), off));\n  
_mm512_storeu_si512((__m512i*)(tile + 192), _mm512_sub_epi8(_mm512_add_epi8(r1, h1), off));\n}\n\ntemplate <>\nvoid unpack_B<block_iq4_xs>(int8_t* RESTRICT tile, const void* RESTRICT packed_B, int k) {\n  static const __m512i values128 = _mm512_set_epi8(\n      113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, 113, 89, 69, 53, 38, 25, 13, 1, -10,\n      -22, -35, -49, -65, -83, -104, -127, 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,\n      113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127);\n\n  const int packed_B_group_size = QK_K / 2 * TILE_N / 8;\n  const char* pb = (const char*)packed_B + k * packed_B_group_size;\n  const __m512i lowMask = _mm512_set1_epi8(0xF);\n\n  for (int n = 0; n < 8; n += 2) {\n    __m512i bytes = _mm512_loadu_si512(pb + n * 32);\n    const __m512i r0 = _mm512_shuffle_epi8(values128, _mm512_and_si512(bytes, lowMask));\n    const __m512i r1 = _mm512_shuffle_epi8(values128, _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask));\n    _mm512_storeu_si512((__m512i*)(tile + n * 64 + 0), r0);\n    _mm512_storeu_si512((__m512i*)(tile + n * 64 + 64), r1);\n  }\n}\n\ntemplate <typename TA, typename TB, bool is_acc>\nstruct acc_C {};\n\ntemplate <bool is_acc>\nstruct acc_C<block_q8_0, block_q4_0, is_acc> {\n  static void apply(float* RESTRICT C, int ldc, const int32_t* RESTRICT tile, const block_q8_0* A, int lda,\n                    const void* packed_B, int nr) {\n    const int offset = TILE_N * TILE_K / 2;\n    const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)((const char*)packed_B + offset)));\n\n    for (int m = 0; m < nr; ++m) {\n      const __m512 vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].d));\n      const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));\n\n      __m512 vsum;\n      if (is_acc) {\n        vsum = _mm512_loadu_ps(C + m * ldc);\n      } else {\n        vsum = _mm512_set1_ps(0.f);\n      }\n      
vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum);\n      _mm512_storeu_ps(C + m * ldc, vsum);\n    }\n  }\n};\n\ntemplate <bool is_acc>\nstruct acc_C<block_q8_1, block_q4_1, is_acc> {\n  static void apply(float* RESTRICT C, int ldc, const int32_t* RESTRICT tile, const block_q8_1* A, int lda,\n                    const void* packed_B, int nr) {\n    const int offset = TILE_N * TILE_K / 2;\n    const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)((const char*)packed_B + offset)));\n    const __m512 vm0 = _mm512_cvtph_ps(\n        _mm256_loadu_si256((const __m256i*)((const char*)packed_B + offset + TILE_N * sizeof(ggml_half))));\n\n    for (int m = 0; m < nr; ++m) {\n      const __m512 vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].d));\n      const __m512 vs1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].s));\n      const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));\n\n      __m512 vsum;\n      if (is_acc) {\n        vsum = _mm512_loadu_ps(C + m * ldc);\n      } else {\n        vsum = _mm512_set1_ps(0.f);\n      }\n      vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum);\n      vsum = _mm512_fmadd_ps(vm0, vs1, vsum);\n      _mm512_storeu_ps(C + m * ldc, vsum);\n    }\n  }\n};\n\ntemplate <bool is_acc>\nstruct acc_C<block_q8_0, block_q8_0, is_acc> {\n  static void apply(float* RESTRICT C, int ldc, const int32_t* RESTRICT tile, const block_q8_0* A, int lda,\n                    const void* packed_B, int nr) {\n    const int offset = TILE_N * TILE_K;\n    const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)((const char*)packed_B + offset)));\n\n    for (int m = 0; m < nr; ++m) {\n      const __m512 vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].d));\n      const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));\n\n      __m512 vsum;\n      if (is_acc) {\n        vsum = _mm512_loadu_ps(C + m * ldc);\n      } else {\n        vsum = _mm512_set1_ps(0.f);\n   
   }\n      vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum);\n      _mm512_storeu_ps(C + m * ldc, vsum);\n    }\n  }\n};\n\ntemplate <bool is_acc>\nstruct acc_C<block_q8_K, block_q4_K, is_acc> {\n  static void apply(float* RESTRICT C, int ldc, const int32_t* RESTRICT tile, const block_q8_K* A, int lda,\n                    const void* packed_B, int nr) {\n    const uint8_t* scales = reinterpret_cast<const uint8_t*>((const char*)packed_B + (QK_K / 2) * TILE_N);\n    const uint8_t* mins = scales + 8 * TILE_N;\n    const ggml_half* d0 = reinterpret_cast<const ggml_half*>(mins + 8 * TILE_N);\n    const ggml_half* dmin = d0 + TILE_N;\n\n    const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)d0));\n    const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)dmin));\n\n    for (int m = 0; m < nr; ++m) {\n      const float d1 = A[m * lda].d;\n      const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);\n      const __m512 vdm = _mm512_mul_ps(_mm512_set1_ps(-d1), vdmin);\n      const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));\n\n      __m512 vsum;\n      if (is_acc) {\n        vsum = _mm512_loadu_ps(C + m * ldc);\n      } else {\n        vsum = _mm512_set1_ps(0.f);\n      }\n\n      const __m256i q8sums = _mm256_loadu_si256((const __m256i*)A[m * lda].bsums);\n      const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));\n\n      __m512i acc_m = _mm512_setzero_si512();\n      for (int k = 0; k < 4; ++k) {\n        __m512i vmask = _mm512_set1_epi32(k);\n        __m512i va = _mm512_permutexvar_epi32(vmask, _mm512_castsi128_si512(q8s));\n        __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i*)(mins + k * 32)));\n        acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);\n      }\n\n      vsum = _mm512_fmadd_ps(vtile, vd, vsum);\n      vsum = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc_m), vdm, vsum);\n      _mm512_storeu_ps(C + m * 
ldc, vsum);\n    }\n  }\n};\n\ntemplate <bool is_acc>\nstruct acc_C<block_q8_K, block_q5_K, is_acc> {\n  static void apply(float* RESTRICT C, int ldc, const int32_t* RESTRICT tile, const block_q8_K* A, int lda,\n                    const void* packed_B, int nr) {\n    const uint8_t* scales =\n        reinterpret_cast<const uint8_t*>((const char*)packed_B + (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N);\n    const uint8_t* mins = scales + 8 * TILE_N;\n    const ggml_half* d0 = reinterpret_cast<const ggml_half*>(mins + 8 * TILE_N);\n    const ggml_half* dmin = d0 + TILE_N;\n\n    const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)d0));\n    const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)dmin));\n\n    for (int m = 0; m < nr; ++m) {\n      const float d1 = A[m * lda].d;\n      const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);\n      const __m512 vdm = _mm512_mul_ps(_mm512_set1_ps(-d1), vdmin);\n      const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));\n\n      __m512 vsum;\n      if (is_acc) {\n        vsum = _mm512_loadu_ps(C + m * ldc);\n      } else {\n        vsum = _mm512_set1_ps(0.f);\n      }\n\n      const __m256i q8sums = _mm256_loadu_si256((const __m256i*)A[m * lda].bsums);\n      const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));\n\n      __m512i acc_m = _mm512_setzero_si512();\n      for (int k = 0; k < 4; ++k) {\n        __m512i vmask = _mm512_set1_epi32(k);\n        __m512i va = _mm512_permutexvar_epi32(vmask, _mm512_castsi128_si512(q8s));\n        __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i*)(mins + k * 32)));\n        acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);\n      }\n\n      vsum = _mm512_fmadd_ps(vtile, vd, vsum);\n      vsum = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc_m), vdm, vsum);\n      _mm512_storeu_ps(C + m * ldc, vsum);\n    }\n  }\n};\n\ntemplate <bool is_acc>\nstruct 
acc_C<block_q8_K, block_q6_K, is_acc> {\n  static void apply(float* RESTRICT C, int ldc, const int32_t* RESTRICT tile, const block_q8_K* A, int lda,\n                    const void* packed_B, int nr) {\n    const uint8_t* scales =\n        reinterpret_cast<const uint8_t*>((const char*)packed_B + (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N);\n    const ggml_half* d0 = reinterpret_cast<const ggml_half*>(scales + 16 * TILE_N);\n\n    const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)d0));\n\n    for (int m = 0; m < nr; ++m) {\n      const float d1 = A[m * lda].d;\n      const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);\n      const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));\n\n      __m512 vsum;\n      if (is_acc) {\n        vsum = _mm512_loadu_ps(C + m * ldc);\n      } else {\n        vsum = _mm512_set1_ps(0.f);\n      }\n\n      vsum = _mm512_fmadd_ps(vtile, vd, vsum);\n      _mm512_storeu_ps(C + m * ldc, vsum);\n    }\n  }\n};\n\ntemplate <bool is_acc>\nstruct acc_C<block_q8_K, block_iq4_xs, is_acc> {\n  static void apply(float* RESTRICT C, int ldc, const int32_t* RESTRICT tile, const block_q8_K* A, int lda,\n                    const void* packed_B, int nr) {\n    const int8_t* scales = reinterpret_cast<const int8_t*>((const char*)packed_B + (QK_K / 2) * TILE_N);\n    const ggml_half* d0 = reinterpret_cast<const ggml_half*>(scales + 8 * TILE_N);\n\n    const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)d0));\n\n    for (int m = 0; m < nr; ++m) {\n      const float d1 = A[m * lda].d;\n      const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);\n      const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));\n\n      __m512 vsum;\n      if (is_acc) {\n        vsum = _mm512_loadu_ps(C + m * ldc);\n      } else {\n        vsum = _mm512_set1_ps(0.f);\n      }\n\n      vsum = _mm512_fmadd_ps(vtile, vd, vsum);\n      _mm512_storeu_ps(C + m * ldc, vsum);\n    }\n  
}\n};\n\ntemplate <typename TB>\nconstexpr int get_quants_size();\ntemplate <>\nconstexpr int get_quants_size<block_q4_K>() {\n  return (QK_K / 2) * TILE_N;\n}\ntemplate <>\nconstexpr int get_quants_size<block_q5_K>() {\n  return (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N;\n}\ntemplate <>\nconstexpr int get_quants_size<block_q6_K>() {\n  return (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N;\n}\ntemplate <>\nconstexpr int get_quants_size<block_iq4_xs>() {\n  return (QK_K / 2) * TILE_N;\n}\n\n// used for QKK format\ntemplate <typename TB, bool is_acc, typename std::enable_if<is_type_qkk<TB>::value, int>::type = 0>\ninline void scale_C(const int32_t* RESTRICT tile, int32_t* RESTRICT sumi, const void* packed_B, int k, int nr) {\n  const uint8_t* scales = reinterpret_cast<const uint8_t*>((const char*)packed_B + get_quants_size<TB>());\n  const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*)(scales + k * TILE_N)));\n\n  for (int m = 0; m < nr; ++m) {\n    __m512i vsumi;\n    if (is_acc) {\n      vsumi = _mm512_loadu_si512(sumi + m * TILE_N);\n    } else {\n      vsumi = _mm512_setzero_si512();\n    }\n    __m512i vtile = _mm512_loadu_si512(tile + m * TILE_N);\n    vsumi = _mm512_add_epi32(vsumi, _mm512_mullo_epi32(vtile, vscale));\n    _mm512_storeu_si512((__m512i*)(sumi + m * TILE_N), vsumi);\n  }\n}\n\ntemplate <typename TA, typename TB, typename TC, int BLOCK_M, int BLOCK_N, int BLOCK_K>\nstruct tinygemm_kernel_avx {\n  static void apply(int K, const TA* RESTRICT A, const TB* RESTRICT B, TC* RESTRICT C, int ldc) {\n    GGML_UNUSED(K);\n    GGML_UNUSED(A);\n    GGML_UNUSED(B);\n    GGML_UNUSED(C);\n    GGML_UNUSED(ldc);\n  }\n};\n\ntemplate <int BLOCK_M, int BLOCK_N, int BLOCK_K>\nstruct tinygemm_kernel_avx<float, ggml_fp16_t, float, BLOCK_M, BLOCK_N, BLOCK_K> {\n  static void apply(int K, const float* RESTRICT A, const ggml_fp16_t* RESTRICT B, float* RESTRICT C, int ldc) {\n    constexpr int ROWS = BLOCK_M;\n    constexpr int COLS = BLOCK_N;\n    
assert(BLOCK_K == 16);\n\n    __m512 va;\n    __m512 vb[COLS];\n    __m512 vc[ROWS * COLS];\n\n    auto loadc = [&](int idx) { vc[idx] = _mm512_setzero_ps(); };\n    Unroll<ROWS * COLS>{}(loadc);\n\n    auto compute = [&](int idx, int k) {\n      // TODO: use `constexpr` here to get rid of integer div\n      // when upgraded to C++17\n      const int row = idx / COLS;\n      const int col = idx % COLS;\n\n      if (col == 0) {\n        va = _mm512_loadu_ps(A + row * K + k);\n      }\n      if (row == 0) {\n        vb[col] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(B + col * K + k)));\n      }\n      vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]);\n    };\n\n    for (int k = 0; k < K; k += 16) {\n      Unroll<ROWS * COLS>{}(compute, k);\n    }\n\n    auto storec = [&](int idx) {\n      const int row = idx / COLS;\n      const int col = idx % COLS;\n      C[row * ldc + col] = _mm512_reduce_add_ps(vc[idx]);\n    };\n    Unroll<ROWS * COLS>{}(storec);\n  }\n};\n\n#define LAUNCH_TINYGEMM_KERNEL_AVX(MB_SIZE, NB_SIZE)                                      \\\n  tinygemm_kernel_avx<float, type, float, MB_SIZE, NB_SIZE, blck_size>::apply(            \\\n      K, (const float*)src1->data + mb_start * K, (const type*)src0->data + nb_start * K, \\\n      (float*)dst->data + mb_start * ldc + nb_start, ldc);\n\n// re-organize in the format {NB, KB, TILE_SIZE}:\n#define PACKED_INDEX(n, k, KB, tile_size) (n * KB + k) * tile_size\n\ntemplate <typename TB, int BLOCK_K>\nvoid convert_B_packed_format(void* RESTRICT packed_B, const TB* RESTRICT B, int N, int K) {\n  const int NB = N / TILE_N;\n  const int KB = K / BLOCK_K;\n  const int TILE_SIZE = get_tile_size<TB>();\n\n  // parallel on NB should be enough\n  parallel_for(1, 0, NB, [&](int begin, int end) {\n    for (int n = begin; n < end; ++n) {\n      for (int k = 0; k < KB; ++k) {\n        int n0 = n * TILE_N;\n        pack_B((char*)packed_B + PACKED_INDEX(n, k, KB, TILE_SIZE), &B[n0 * KB + k], KB);\n      }\n    }\n  
});\n}\n\ntemplate <typename TA, typename TB, typename TC, int BLOCK_M, int BLOCK_N, int BLOCK_K>\nstruct tinygemm_kernel_vnni {};\n\ntemplate <int BLOCK_M, int BLOCK_N, int BLOCK_K>\nstruct tinygemm_kernel_vnni<block_q8_0, block_q4_0, float, BLOCK_M, BLOCK_N, BLOCK_K> {\n  static void apply(int KB, const void* RESTRICT _A, const void* RESTRICT _B, float* RESTRICT C, int ldc) {\n    constexpr int COLS = BLOCK_N / 16;\n    const int TILE_SIZE = TILE_N * sizeof(block_q4_0);\n\n    const block_q8_0* RESTRICT A = static_cast<const block_q8_0*>(_A);\n    const char* RESTRICT B = static_cast<const char*>(_B);\n\n    __m512i va[8];\n    __m512 vc[COLS];\n    __m512 vd1;\n\n    // sum of offsets, shared across COLS\n    //\n    // avx512-vnni does not have `_mm512_dpbssd_epi32`,\n    // need to transform ss to us:\n    //   a * (b - 8) is equivalent to b * a - 8 * a\n    //   s    u   u                   u   s   u   s\n    //\n    __m512i vcomp;\n\n    const __m512i off = _mm512_set1_epi8(8);\n    const __m512i lowMask = _mm512_set1_epi8(0xF);\n\n    auto loadc = [&](int col) { vc[col] = _mm512_setzero_ps(); };\n    Unroll<COLS>{}(loadc);\n\n    auto compute = [&](int col, int i) {\n      // load a and compute compensation\n      if (col == 0) {\n        const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A[0 * KB + i].qs);\n        vcomp = _mm512_setzero_si512();\n        for (int k = 0; k < 8; ++k) {\n          va[k] = _mm512_set1_epi32(a_ptr[k]);\n          vcomp = _mm512_dpbusd_epi32(vcomp, off, va[k]);\n        }\n        vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].d));\n      }\n\n      // load b\n      __m512i vsum = _mm512_setzero_si512();\n      const char* b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);\n      for (int k = 0; k < 8; k += 2) {\n        __m512i bytes = _mm512_loadu_si512((const __m512i*)(b_ptr + k * 32));\n        __m512i vb0 = _mm512_and_si512(bytes, lowMask);\n        vsum = _mm512_dpbusd_epi32(vsum, vb0, va[k + 0]);\n        
__m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n        vsum = _mm512_dpbusd_epi32(vsum, vb1, va[k + 1]);\n      }\n      const int offset = TILE_N * TILE_K / 2;\n      const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(b_ptr + offset)));\n      vsum = _mm512_sub_epi32(vsum, vcomp);\n\n      vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]);\n    };\n\n    for (int i = 0; i < KB; ++i) {\n      Unroll<COLS>{}(compute, i);\n    }\n\n    // store to C\n    auto storec = [&](int col) { _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); };\n    Unroll<COLS>{}(storec);\n  }\n};\n\ntemplate <int BLOCK_N, int BLOCK_K>\nstruct tinygemm_kernel_vnni<block_q8_1, block_q4_1, float, 1, BLOCK_N, BLOCK_K> {\n  static void apply(int KB, const void* RESTRICT _A, const void* RESTRICT _B, float* RESTRICT C, int ldc) {\n    constexpr int COLS = BLOCK_N / 16;\n    const int TILE_SIZE = TILE_N * sizeof(block_q4_1);\n\n    const block_q8_1* RESTRICT A = static_cast<const block_q8_1*>(_A);\n    const char* RESTRICT B = static_cast<const char*>(_B);\n\n    __m512i va[8];\n    __m512i vb[8];\n    __m512 vc[COLS];\n    __m512 vd1, vs1;\n\n    const __m512i lowMask = _mm512_set1_epi8(0xF);\n\n    auto loadc = [&](int col) { vc[col] = _mm512_setzero_ps(); };\n    Unroll<COLS>{}(loadc);\n\n    auto compute = [&](int col, int i) {\n      // load a\n      if (col == 0) {\n        const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A[0 * KB + i].qs);\n        for (int k = 0; k < 8; ++k) {\n          va[k] = _mm512_set1_epi32(a_ptr[k]);\n        }\n        vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].d));\n        vs1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].s));\n      }\n\n      // load b\n      const char* b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);\n      for (int k = 0; k < 8; k += 2) {\n        __m512i bytes = _mm512_loadu_si512((const __m512i*)(b_ptr + k * 32));\n        vb[k + 
0] = _mm512_and_si512(bytes, lowMask);\n        vb[k + 1] = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n      }\n      const int offset = TILE_N * TILE_K / 2;\n      const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(b_ptr + offset)));\n      const __m512 vm0 =\n          _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(b_ptr + offset + TILE_N * sizeof(ggml_half))));\n\n      __m512i vsum = _mm512_setzero_si512();\n      for (int k = 0; k < 8; ++k) {\n        vsum = _mm512_dpbusd_epi32(vsum, vb[k], va[k]);\n      }\n\n      vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]);\n      vc[col] = _mm512_fmadd_ps(vm0, vs1, vc[col]);\n    };\n\n    for (int i = 0; i < KB; ++i) {\n      Unroll<COLS>{}(compute, i);\n    }\n\n    // store to C\n    auto storec = [&](int col) { _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); };\n    Unroll<COLS>{}(storec);\n  }\n};\n\ntemplate <int BLOCK_M, int BLOCK_N, int BLOCK_K>\nstruct tinygemm_kernel_vnni<block_q8_0, block_q8_0, float, BLOCK_M, BLOCK_N, BLOCK_K> {\n  static void apply(int KB, const void* RESTRICT _A, const void* RESTRICT _B, float* RESTRICT C, int ldc) {\n    constexpr int COLS = BLOCK_N / 16;\n    const int TILE_SIZE = TILE_N * sizeof(block_q8_0) + TILE_N * sizeof(int32_t);\n\n    const block_q8_0* RESTRICT A = static_cast<const block_q8_0*>(_A);\n    const char* RESTRICT B = static_cast<const char*>(_B);\n\n    __m512i va[8];\n    __m512i vb[8];\n    __m512 vc[COLS];\n    __m512 vd1;\n\n    // Notes: s8s8 igemm compensation in avx512-vnni\n    // change s8s8 to u8s8 with compensate\n    //   a * b = (a + 128) * b - 128 * b\n    //   s   s       u       s    u    s\n    //\n    // (128 * b is pre-computed when packing B to vnni formats)\n    //\n    const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));\n\n    auto loadc = [&](int col) { vc[col] = _mm512_setzero_ps(); };\n    Unroll<COLS>{}(loadc);\n\n    auto compute = 
[&](int col, int i) {\n      // load a and add offset 128\n      if (col == 0) {\n        const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A[0 * KB + i].qs);\n        for (int k = 0; k < 8; ++k) {\n          va[k] = _mm512_set1_epi32(a_ptr[k]);\n          va[k] = _mm512_add_epi8(va[k], off);\n        }\n        vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].d));\n      }\n\n      // load b\n      const char* b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);\n      for (int k = 0; k < 8; ++k) {\n        vb[k] = _mm512_loadu_si512((const __m512i*)(b_ptr + k * 64));\n      }\n      const int offset = TILE_N * TILE_K;\n      const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(b_ptr + offset)));\n      const int offset2 = TILE_N * TILE_K + TILE_N * sizeof(ggml_half);\n      const __m512i vcomp = _mm512_loadu_si512((const __m512i*)(b_ptr + offset2));\n\n      __m512i vsum = _mm512_setzero_si512();\n      for (int k = 0; k < 8; ++k) {\n        vsum = _mm512_dpbusd_epi32(vsum, va[k], vb[k]);\n      }\n      vsum = _mm512_sub_epi32(vsum, vcomp);\n\n      vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]);\n    };\n\n    for (int i = 0; i < KB; ++i) {\n      Unroll<COLS>{}(compute, i);\n    }\n\n    // store to C\n    auto storec = [&](int col) { _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); };\n    Unroll<COLS>{}(storec);\n  }\n};\n\ntemplate <int BLOCK_M, int BLOCK_N, int BLOCK_K>\nstruct tinygemm_kernel_vnni<block_q8_K, block_q4_K, float, BLOCK_M, BLOCK_N, BLOCK_K> {\n  static void apply(int KB, const void* RESTRICT _A, const void* RESTRICT _B, float* RESTRICT C, int ldc) {\n    constexpr int COLS = BLOCK_N / 16;\n    const int TILE_SIZE = TILE_N * sizeof(block_q4_K) + TILE_N * 4;\n\n    const block_q8_K* RESTRICT A = static_cast<const block_q8_K*>(_A);\n    const char* RESTRICT B = static_cast<const char*>(_B);\n\n    // a.qs:   8 groups, 32 bytes each group (m256i)\n    __m512i va[8];\n   
 // a.bsum: 8 groups,  2 bytes each group (m128i)\n    __m512i va_bsum;\n    __m512 vc[COLS];\n    __m512 vd1;\n\n    // packed_B:\n    const int offset_scales = (QK_K / 2) * TILE_N;\n    const int offset_mins = (QK_K / 2) * TILE_N + 8 * TILE_N;\n    const int offset_d0 = (QK_K / 2) * TILE_N + 16 * TILE_N;\n    const int offset_dmin = (QK_K / 2) * TILE_N + 16 * TILE_N + TILE_N * sizeof(ggml_half);\n\n    const __m512i lowMask = _mm512_set1_epi8(0xF);\n\n    auto loadc = [&](int col) { vc[col] = _mm512_setzero_ps(); };\n    Unroll<COLS>{}(loadc);\n\n    // Notes: vnni formats in QK_K\n    //   a) quants vnni format\n    //     int8  {k/4, n, 4}, viewed as 2d {k/4, 4n}, k = 32\n    //     from {16, 32} to {8, 64}\n    //\n    //   b) min vnni format\n    //     int16 {k/2, n, 2}, viewed as 2d {k/2, 2n}, k = 8\n    //     from {16,  8} to {4, 32}\n    //\n    auto compute = [&](int col, int i) {\n      // load a\n      if (col == 0) {\n        for (int k_group = 0; k_group < QK_K / 32; ++k_group) {\n          va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i*)(A[0 * KB + i].qs + k_group * 32)));\n        }\n        const __m256i q8sums = _mm256_loadu_si256((const __m256i*)A[0 * KB + i].bsums);\n        const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));\n        va_bsum = _mm512_castsi128_si512(q8s);\n        vd1 = _mm512_set1_ps(A[0 * KB + i].d);\n      }\n\n      // step 1: accumulate the quants\n      __m512i acc = _mm512_setzero_si512();\n      const char* b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);\n      const char* b_qs = b_ptr;\n      for (int k_group = 0; k_group < QK_K / 32; ++k_group) {\n        __m512i vsum = _mm512_setzero_si512();\n        for (int k = 0; k < 8; k += 2) {\n          __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 0), va[k_group]);\n          __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 1), va[k_group]);\n\n          __m512i 
bytes = _mm512_loadu_si512((const __m512i*)b_qs);\n          __m512i vb0 = _mm512_and_si512(bytes, lowMask);\n          vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);\n          __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n          vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);\n\n          b_qs += 64;\n        }\n        // vacc += scale * (q8 @ q4)\n        const __m512i vscale =\n            _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*)(b_ptr + offset_scales + k_group * TILE_N)));\n        acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));\n      }\n      const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(b_ptr + offset_d0)));\n      vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);\n\n      // step 2: accumulate the mins\n      __m512i acc_m = _mm512_setzero_si512();\n      for (int k = 0; k < 4; ++k) {\n        __m512i vmask = _mm512_set1_epi32(k);\n        __m512i va = _mm512_permutexvar_epi32(vmask, va_bsum);\n        __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i*)(b_ptr + offset_mins + k * 32)));\n        acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);\n      }\n      const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(b_ptr + offset_dmin)));\n      vc[col] = _mm512_fnmadd_ps(_mm512_cvtepi32_ps(acc_m), _mm512_mul_ps(vdmin, vd1), vc[col]);\n    };\n\n    for (int i = 0; i < KB; ++i) {\n      Unroll<COLS>{}(compute, i);\n    }\n\n    // store to C\n    auto storec = [&](int col) { _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); };\n    Unroll<COLS>{}(storec);\n  }\n};\n\ntemplate <int BLOCK_M, int BLOCK_N, int BLOCK_K>\nstruct tinygemm_kernel_vnni<block_q8_K, block_q5_K, float, BLOCK_M, BLOCK_N, BLOCK_K> {\n  static void apply(int KB, const void* RESTRICT _A, const void* RESTRICT _B, float* RESTRICT C, int ldc) {\n    constexpr int COLS = BLOCK_N / 16;\n    const int TILE_SIZE = TILE_N * sizeof(block_q5_K) + 
TILE_N * 4;\n\n    const block_q8_K* RESTRICT A = static_cast<const block_q8_K*>(_A);\n    const char* RESTRICT B = static_cast<const char*>(_B);\n\n    // a.qs:   8 groups, 32 bytes each group (m256i)\n    __m512i va[8];\n    // a.bsum: 8 groups,  2 bytes each group (m128i)\n    __m512i va_bsum;\n    __m512 vc[COLS];\n    __m512 vd1;\n\n    // packed_B:\n    const int offset_qh = (QK_K / 2) * TILE_N;\n    const int offset_scales = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N;\n    const int offset_mins = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 8 * TILE_N;\n    const int offset_d0 = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 16 * TILE_N;\n    const int offset_dmin = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 16 * TILE_N + TILE_N * sizeof(ggml_half);\n\n    const __m512i lowMask = _mm512_set1_epi8(0xF);\n\n    auto loadc = [&](int col) { vc[col] = _mm512_setzero_ps(); };\n    Unroll<COLS>{}(loadc);\n\n    // Q5_K and Q4_K share the same vnni formats, refer to notes above.\n    auto compute = [&](int col, int i) {\n      // load a\n      if (col == 0) {\n        for (int k_group = 0; k_group < QK_K / 32; ++k_group) {\n          va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i*)(A[0 * KB + i].qs + k_group * 32)));\n        }\n        const __m256i q8sums = _mm256_loadu_si256((const __m256i*)A[0 * KB + i].bsums);\n        const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));\n        va_bsum = _mm512_castsi128_si512(q8s);\n        vd1 = _mm512_set1_ps(A[0 * KB + i].d);\n      }\n\n      // step 1: accumulate the quants\n      __m512i acc = _mm512_setzero_si512();\n      const char* b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);\n      const char* b_qs = b_ptr;\n      const char* b_qh = b_ptr + offset_qh;\n      for (int k_group = 0; k_group < QK_K / 32; ++k_group) {\n        __m512i vsum = _mm512_setzero_si512();\n        __m512i hmask0 = _mm512_set1_epi8(0x1);\n        __m512i hmask1 
= _mm512_set1_epi8(0x2);\n        __m512i hbits = _mm512_loadu_si512((const __m512i*)(b_qh + k_group * 64));\n        for (int k = 0; k < 8; k += 2) {\n          __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 0), va[k_group]);\n          __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 1), va[k_group]);\n\n          __m512i bytes = _mm512_loadu_si512((const __m512i*)b_qs);\n          __m512i vb0 = _mm512_and_si512(bytes, lowMask);\n          __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n\n          __m512i vh0 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask0), k), 4);\n          __m512i vh1 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), k + 1), 4);\n\n          hmask0 = _mm512_slli_epi16(hmask0, 2);\n          hmask1 = _mm512_slli_epi16(hmask1, 2);\n          vb0 = _mm512_add_epi8(vb0, vh0);\n          vb1 = _mm512_add_epi8(vb1, vh1);\n\n          vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);\n          vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);\n\n          b_qs += 64;\n        }\n        // vacc += scale * (q8 @ q5)\n        const __m512i vscale =\n            _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*)(b_ptr + offset_scales + k_group * TILE_N)));\n        acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));\n      }\n      const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(b_ptr + offset_d0)));\n      vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);\n\n      // step 2: accumulate the mins\n      __m512i acc_m = _mm512_setzero_si512();\n      for (int k = 0; k < 4; ++k) {\n        __m512i vmask = _mm512_set1_epi32(k);\n        __m512i va = _mm512_permutexvar_epi32(vmask, va_bsum);\n        __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i*)(b_ptr + offset_mins + k * 32)));\n        acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);\n      }\n      const __m512 vdmin = 
_mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(b_ptr + offset_dmin)));\n      vc[col] = _mm512_fnmadd_ps(_mm512_cvtepi32_ps(acc_m), _mm512_mul_ps(vdmin, vd1), vc[col]);\n    };\n\n    for (int i = 0; i < KB; ++i) {\n      Unroll<COLS>{}(compute, i);\n    }\n\n    // store to C\n    auto storec = [&](int col) { _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); };\n    Unroll<COLS>{}(storec);\n  }\n};\n\ntemplate <int BLOCK_M, int BLOCK_N, int BLOCK_K>\nstruct tinygemm_kernel_vnni<block_q8_K, block_q6_K, float, BLOCK_M, BLOCK_N, BLOCK_K> {\n  static void apply(int KB, const void* RESTRICT _A, const void* RESTRICT _B, float* RESTRICT C, int ldc) {\n    constexpr int COLS = BLOCK_N / 16;\n    const int TILE_SIZE = TILE_N * sizeof(block_q6_K);\n\n    const block_q8_K* RESTRICT A = static_cast<const block_q8_K*>(_A);\n    const char* RESTRICT B = static_cast<const char*>(_B);\n\n    // load the 256 bytes from A to 4 avx512 vectors\n    __m512i va[4];\n    __m512 vc[COLS];\n    __m512 vd1;\n\n    // packed_B:\n    const int offset_qh = (QK_K / 2) * TILE_N;\n    const int offset_scales = (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N;\n    const int offset_d0 = (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N + 16 * TILE_N;\n\n    // compensation\n    __m512i vcomp;\n\n    const __m512i m32s = _mm512_set1_epi32(32);\n    const __m512i lowMask = _mm512_set1_epi8(0xF);\n\n    auto loadc = [&](int col) { vc[col] = _mm512_setzero_ps(); };\n    Unroll<COLS>{}(loadc);\n\n    auto compute = [&](int col, int i) {\n      if (col == 0) {\n        // load a\n        va[0] = _mm512_loadu_si512((const __m512i*)(A[0 * KB + i].qs + 0));\n        va[1] = _mm512_loadu_si512((const __m512i*)(A[0 * KB + i].qs + 64));\n        va[2] = _mm512_loadu_si512((const __m512i*)(A[0 * KB + i].qs + 128));\n        va[3] = _mm512_loadu_si512((const __m512i*)(A[0 * KB + i].qs + 192));\n\n        const __m256i q8sums = _mm256_loadu_si256((const __m256i*)A[0 * KB + i].bsums);\n        vcomp = 
_mm512_mullo_epi32(_mm512_cvtepi16_epi32(q8sums), m32s);\n        vd1 = _mm512_set1_ps(A[0 * KB + i].d);\n      }\n\n      // accumulate the quants\n      __m512i acc = _mm512_setzero_si512();\n      const char* b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);\n      const char* b_qs = b_ptr;\n      const char* b_qh = b_ptr + offset_qh;\n      int mask = 0;\n      for (int k_group = 0; k_group < QK_K / 16; ++k_group) {\n        int r = k_group >> 2;\n        __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);\n        __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);\n\n        __m512i vsum = _mm512_setzero_si512();\n        __m512i hmask = _mm512_set1_epi8(0x3);\n\n        __m512i bytes = _mm512_loadu_si512(b_qs);\n        __m512i hbits = _mm512_loadu_si512(b_qh);\n        __m512i vb0 = _mm512_and_si512(bytes, lowMask);\n        __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n        __m512i vh0 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask), 4);\n        __m512i vh1 = _mm512_slli_epi16(_mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 2)), 2);\n\n        vb0 = _mm512_add_epi8(vb0, vh0);\n        vb1 = _mm512_add_epi8(vb1, vh1);\n        vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);\n        vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);\n        b_qs += 64;\n\n        va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);\n        va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);\n\n        bytes = _mm512_loadu_si512(b_qs);\n        vb0 = _mm512_and_si512(bytes, lowMask);\n        vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);\n        vh0 = _mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 4));\n        vh1 = _mm512_srli_epi16(_mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 6)), 2);\n        vb0 = _mm512_add_epi8(vb0, vh0);\n        vb1 = _mm512_add_epi8(vb1, vh1);\n        vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);\n        vsum = 
_mm512_dpbusd_epi32(vsum, vb1, va1);\n        b_qs += 64;\n        b_qh += 64;\n\n        // B * A - 32 * A\n        __m512i vmask = _mm512_set1_epi32(k_group);\n        vsum = _mm512_sub_epi32(vsum, _mm512_permutexvar_epi32(vmask, vcomp));\n\n        // vacc += scale * (q8 @ q6)\n        const __m512i vscale =\n            _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*)(b_ptr + offset_scales + k_group * TILE_N)));\n        acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));\n      }\n      const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(b_ptr + offset_d0)));\n      vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);\n    };\n\n    for (int i = 0; i < KB; ++i) {\n      Unroll<COLS>{}(compute, i);\n    }\n\n    // store to C\n    auto storec = [&](int col) { _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); };\n    Unroll<COLS>{}(storec);\n  }\n};\n\ntemplate <int BLOCK_M, int BLOCK_N, int BLOCK_K>\nstruct tinygemm_kernel_vnni<block_q8_K, block_iq4_xs, float, BLOCK_M, BLOCK_N, BLOCK_K> {\n  static void apply(int KB, const void* RESTRICT _A, const void* RESTRICT _B, float* RESTRICT C, int ldc) {\n    constexpr int COLS = BLOCK_N / 16;\n    const int TILE_SIZE = TILE_N * sizeof(block_iq4_xs) + TILE_N * 2;\n\n    const block_q8_K* RESTRICT A = static_cast<const block_q8_K*>(_A);\n    const char* RESTRICT B = static_cast<const char*>(_B);\n\n    // load the 256 bytes from A to 4 avx512 vectors\n    __m512i va[4];\n    __m512 vc[COLS];\n    __m512 vd1;\n\n    // packed_B:\n    const int offset_scales = (QK_K / 2) * TILE_N;\n    const int offset_d0 = (QK_K / 2) * TILE_N + 8 * TILE_N;\n\n    // compensation\n    __m512i vcomp;\n\n    const __m256i m128s = _mm256_set1_epi16(128);\n    const __m512i lowMask = _mm512_set1_epi8(0xF);\n\n    const __m512i values128 = _mm512_set_epi8(113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,\n                                     
         113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,\n                                              113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,\n                                              113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127);\n    const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));\n    const __m512i values256 = _mm512_add_epi8(values128, off);\n\n    auto loadc = [&](int col) { vc[col] = _mm512_setzero_ps(); };\n    Unroll<COLS>{}(loadc);\n\n    auto compute = [&](int col, int i) {\n      if (col == 0) {\n        // load a\n        va[0] = _mm512_loadu_si512((const __m512i*)(A[0 * KB + i].qs + 0));\n        va[1] = _mm512_loadu_si512((const __m512i*)(A[0 * KB + i].qs + 64));\n        va[2] = _mm512_loadu_si512((const __m512i*)(A[0 * KB + i].qs + 128));\n        va[3] = _mm512_loadu_si512((const __m512i*)(A[0 * KB + i].qs + 192));\n\n        // compensation: 128 * A\n        const __m256i q8sums = _mm256_loadu_si256((const __m256i*)A[0 * KB + i].bsums);\n        vcomp = _mm512_castsi256_si512(_mm256_madd_epi16(q8sums, m128s));\n        vd1 = _mm512_set1_ps(A[0 * KB + i].d);\n      }\n\n      // accumulate the quants\n      __m512i acc = _mm512_setzero_si512();\n      const char* b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);\n      const char* b_qs = b_ptr;\n      int mask = 0;\n      for (int k_group = 0; k_group < QK_K / 32; ++k_group) {\n        int r = k_group >> 1;\n        __m512i vmask = _mm512_set1_epi32(k_group);\n        __m512i vsum = _mm512_setzero_si512();\n        for (int k = 0; k < 8; k += 2) {\n          __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);\n          __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);\n\n          __m512i bytes = _mm512_loadu_si512(b_qs);\n          __m512i vb0 = _mm512_shuffle_epi8(values256, _mm512_and_si512(bytes, lowMask));\n          __m512i vb1 = 
_mm512_shuffle_epi8(values256, _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask));\n\n          vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);\n          vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);\n          b_qs += 64;\n        }\n        // (B + 128) * A - 128 * A\n        vsum = _mm512_sub_epi32(vsum, _mm512_permutexvar_epi32(vmask, vcomp));\n\n        // vacc += scale * (q8 @ q4)\n        const __m512i vscale =\n            _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i*)(b_ptr + offset_scales + k_group * TILE_N)));\n        acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));\n      }\n      const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(b_ptr + offset_d0)));\n      vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);\n    };\n\n    for (int i = 0; i < KB; ++i) {\n      Unroll<COLS>{}(compute, i);\n    }\n\n    // store to C\n    auto storec = [&](int col) { _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); };\n    Unroll<COLS>{}(storec);\n  }\n};\n\n#define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE)                                                                         \\\n  tinygemm_kernel_vnni<vec_dot_type, type, float, 1, NB_SIZE, blck_size>::apply(                                     \\\n      KB, (const char*)wdata + 0 * row_size_A,                                                                       \\\n      (const char*)src0->extra + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), (float*)dst->data + 0 * N + nb_start, \\\n      ldc)\n\ntemplate <typename TA, typename TB, typename TC, int BLOCK_K,\n          typename std::enable_if<!is_type_qkk<TB>::value, int>::type = 0>\nvoid tinygemm_kernel_amx(int M, int N, int KB, const void* RESTRICT _A, const void* RESTRICT _B, TC* RESTRICT C,\n                         int ldc) {\n  using packed_B_t = packed_B_type<TB>;\n  const int TILE_SIZE = get_tile_size<TB>();\n  const bool need_unpack = do_unpack<TB>::value;\n\n  
GGML_ASSERT(M <= 2 * TILE_M && N == 2 * TILE_N);\n  const TA* RESTRICT A = static_cast<const TA*>(_A);\n  const char* RESTRICT B = static_cast<const char*>(_B);\n\n  const int m0 = std::min(M, TILE_M);\n  const int m1 = std::max(M - TILE_M, 0);\n  const int lda = KB * sizeof(TA);\n  // const int ldb = KB * sizeof(TB);\n\n  static thread_local packed_B_t Tile0[TILE_N * TILE_K];\n  static thread_local packed_B_t Tile1[TILE_N * TILE_K];\n  static thread_local int8_t Tile23[TILE_M * TILE_K];\n\n  static thread_local int32_t TileC0[TILE_M * TILE_N * 4];\n  static thread_local int32_t TileC1[TILE_M * TILE_N * 4];\n\n  // double buffering C to interleave avx512 and amx\n  int32_t* C_cur = TileC0;\n  int32_t* C_pre = TileC1;\n\n  auto Tile4 = [&](int32_t* base) { return base; };\n  auto Tile5 = [&](int32_t* base) { return base + TILE_M * TILE_N; };\n  auto Tile6 = [&](int32_t* base) { return base + 2 * TILE_M * TILE_N; };\n  auto Tile7 = [&](int32_t* base) { return base + 3 * TILE_M * TILE_N; };\n\n  if (M == 2 * TILE_M) {\n    // i = 0\n    const char* B_blk0 = B + PACKED_INDEX(0, 0, KB, TILE_SIZE);\n    const char* B_blk1 = B + PACKED_INDEX(1, 0, KB, TILE_SIZE);\n    if (need_unpack) {\n      unpack_B<TB>(Tile0, B_blk0);\n      _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);\n    } else {\n      _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK);\n    }\n\n    _tile_zero(TMM4);\n    _tile_loadd(TMM2, A[0].qs, lda);\n    _tile_dpbssd(TMM4, TMM2, TMM0);\n    _tile_stored(TMM4, Tile4(C_pre), TILE_N * sizeof(int32_t));\n\n    _tile_zero(TMM5);\n    _tile_loadd(TMM3, A[TILE_M * KB + 0].qs, lda);\n    _tile_dpbssd(TMM5, TMM3, TMM0);\n    _tile_stored(TMM5, Tile5(C_pre), TILE_N * sizeof(int32_t));\n\n    if (need_unpack) {\n      // the second tile must be unpacked from B_blk1, matching the else branch and the loop below\n      unpack_B<TB>(Tile1, B_blk1);\n      _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);\n    } else {\n      _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);\n    }\n\n    _tile_zero(TMM6);\n    _tile_dpbssd(TMM6, TMM2, TMM1);\n    _tile_stored(TMM6, Tile6(C_pre), TILE_N 
* sizeof(int32_t));\n\n    _tile_zero(TMM7);\n    _tile_dpbssd(TMM7, TMM3, TMM1);\n    _tile_stored(TMM7, Tile7(C_pre), TILE_N * sizeof(int32_t));\n\n    for (int i = 1; i < KB; ++i) {\n      // index of previous iter\n      const int ii = i - 1;\n      const char* B_blk0 = B + PACKED_INDEX(0, i, KB, TILE_SIZE);\n      const char* B_blk1 = B + PACKED_INDEX(1, i, KB, TILE_SIZE);\n      GGML_DISPATCH_BOOL(ii > 0, is_acc, [&] {\n        if (need_unpack) {\n          unpack_B<TB>(Tile0, B_blk0);\n          _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);\n        } else {\n          _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK);\n        }\n        _tile_zero(TMM4);\n        _tile_loadd(TMM2, A[i].qs, lda);\n        acc_C<TA, TB, is_acc>::apply(C, ldc, Tile4(C_pre), &A[ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);\n\n        _tile_dpbssd(TMM4, TMM2, TMM0);\n        _tile_stored(TMM4, Tile4(C_cur), TILE_N * sizeof(int32_t));\n\n        _tile_zero(TMM5);\n        _tile_loadd(TMM3, A[TILE_M * KB + i].qs, lda);\n        acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc, ldc, Tile5(C_pre), &A[TILE_M * KB + ii], KB,\n                                     B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);\n\n        _tile_dpbssd(TMM5, TMM3, TMM0);\n        _tile_stored(TMM5, Tile5(C_cur), TILE_N * sizeof(int32_t));\n\n        if (need_unpack) {\n          unpack_B<TB>(Tile1, B_blk1);\n          _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);\n        } else {\n          _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);\n        }\n        _tile_zero(TMM6);\n        acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Tile6(C_pre), &A[ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE),\n                                     TILE_M);\n\n        _tile_dpbssd(TMM6, TMM2, TMM1);\n        _tile_stored(TMM6, Tile6(C_cur), TILE_N * sizeof(int32_t));\n\n        _tile_zero(TMM7);\n        acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_pre), &A[TILE_M * KB + ii], KB,\n               
                      B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M);\n\n        _tile_dpbssd(TMM7, TMM3, TMM1);\n        _tile_stored(TMM7, Tile7(C_cur), TILE_N * sizeof(int32_t));\n\n        std::swap(C_cur, C_pre);\n      });\n    }\n    // final accumulation\n    {\n      int ii = KB - 1;\n      acc_C<TA, TB, true>::apply(C, ldc, Tile4(C_pre), &A[ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);\n      acc_C<TA, TB, true>::apply(C + TILE_M * ldc, ldc, Tile5(C_pre), &A[TILE_M * KB + ii], KB,\n                                 B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);\n      acc_C<TA, TB, true>::apply(C + TILE_N, ldc, Tile6(C_pre), &A[ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE),\n                                 TILE_M);\n      acc_C<TA, TB, true>::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_pre), &A[TILE_M * KB + ii], KB,\n                                 B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M);\n    }\n  } else {\n    for (int i = 0; i < KB; ++i) {\n      _tile_zero(TMM4);\n      _tile_zero(TMM6);\n      if (m1 != 0) {\n        _tile_zero(TMM5);\n        _tile_zero(TMM7);\n      }\n\n      const char* B_blk0 = B + PACKED_INDEX(0, i, KB, TILE_SIZE);\n      const char* B_blk1 = B + PACKED_INDEX(1, i, KB, TILE_SIZE);\n      if (need_unpack) {\n        unpack_B<TB>(Tile0, B_blk0);\n        _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);\n      } else {\n        _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK);\n      }\n\n      if (need_unpack) {\n        unpack_B<TB>(Tile1, B_blk1);\n        _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);\n      } else {\n        _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);\n      }\n\n      if (m0 == TILE_M) {\n        _tile_loadd(TMM2, A[i].qs, lda);\n      } else {\n        unpack_A(Tile23, &A[i], KB, m0);\n        _tile_loadd(TMM2, Tile23, TILE_K);\n      }\n\n      _tile_dpbssd(TMM4, TMM2, TMM0);\n      _tile_dpbssd(TMM6, TMM2, TMM1);\n\n      _tile_stored(TMM4, Tile4(C_cur), TILE_N * sizeof(int32_t));\n   
   _tile_stored(TMM6, Tile6(C_cur), TILE_N * sizeof(int32_t));\n\n      GGML_DISPATCH_BOOL(i > 0, is_acc, [&] {\n        acc_C<TA, TB, is_acc>::apply(C, ldc, Tile4(C_cur), &A[i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m0);\n        acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Tile6(C_cur), &A[i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE),\n                                     m0);\n      });\n      if (m1 != 0) {\n        unpack_A(Tile23, &A[TILE_M * KB + i], KB, m1);\n        _tile_loadd(TMM3, Tile23, TILE_K);\n\n        _tile_dpbssd(TMM5, TMM3, TMM0);\n        _tile_dpbssd(TMM7, TMM3, TMM1);\n        _tile_stored(TMM5, Tile5(C_cur), TILE_N * sizeof(int32_t));\n        _tile_stored(TMM7, Tile7(C_cur), TILE_N * sizeof(int32_t));\n        GGML_DISPATCH_BOOL(i > 0, is_acc, [&] {\n          acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc, ldc, Tile5(C_cur), &A[TILE_M * KB + i], KB,\n                                       B + PACKED_INDEX(0, i, KB, TILE_SIZE), m1);\n          acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_cur), &A[TILE_M * KB + i], KB,\n                                       B + PACKED_INDEX(1, i, KB, TILE_SIZE), m1);\n        });\n      }\n    }\n  }\n  return;\n}\n\ntemplate <typename TA, typename TB, typename TC, int BLOCK_K,\n          typename std::enable_if<is_type_qkk<TB>::value, int>::type = 0>\nvoid tinygemm_kernel_amx(int M, int N, int KB, const void* RESTRICT _A, const void* RESTRICT _B, float* RESTRICT C,\n                         int ldc) {\n  static_assert(std::is_same<TA, block_q8_K>::value);\n  const int TILE_SIZE = get_tile_size<TB>();\n\n  GGML_ASSERT(M <= 2 * TILE_M && N == 2 * TILE_N);\n  const TA* RESTRICT A = static_cast<const TA*>(_A);\n  const char* RESTRICT B = static_cast<const char*>(_B);\n\n  const int m0 = std::min(M, TILE_M);\n  const int m1 = std::max(M - TILE_M, 0);\n  // const int lda = KB * sizeof(TA);\n\n  static thread_local int8_t Tile0[TILE_N * TILE_K];\n  static thread_local int8_t 
Tile1[TILE_N * TILE_K];\n  static thread_local int8_t Tile23[TILE_M * TILE_K];\n\n  // mat mul result for each group\n  static thread_local int32_t Tile4[TILE_M * TILE_N];\n  static thread_local int32_t Tile5[TILE_M * TILE_N];\n  static thread_local int32_t Tile6[TILE_M * TILE_N];\n  static thread_local int32_t Tile7[TILE_M * TILE_N];\n\n  // sum of each QK_K block, contains 8 groups, int32\n  static thread_local int32_t Sumi4[TILE_M * TILE_N];\n  static thread_local int32_t Sumi5[TILE_M * TILE_N];\n  static thread_local int32_t Sumi6[TILE_M * TILE_N];\n  static thread_local int32_t Sumi7[TILE_M * TILE_N];\n\n  const int k_group_size = std::is_same<TB, block_q6_K>::value ? 16 : 32;\n  for (int i = 0; i < KB; ++i) {\n    // step 1: accumulate the quants across 8 groups, each group with 32\n    for (int k = 0; k < QK_K / k_group_size; ++k) {\n      GGML_DISPATCH_BOOL(k > 0, is_acc, [&] {\n        _tile_zero(TMM4);\n        _tile_zero(TMM6);\n\n        unpack_B<TB>(Tile0, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k);\n        _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);\n\n        unpack_B<TB>(Tile1, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k);\n        _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);\n\n        unpack_A<TB>(Tile23, &A[i], KB, k, m0);\n        _tile_loadd(TMM2, Tile23, TILE_K);\n\n        _tile_dpbssd(TMM4, TMM2, TMM0);\n        _tile_dpbssd(TMM6, TMM2, TMM1);\n\n        _tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t));\n        _tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t));\n\n        scale_C<TB, is_acc>(Tile4, Sumi4, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k, m0);\n        scale_C<TB, is_acc>(Tile6, Sumi6, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k, m0);\n\n        if (m1 != 0) {\n          _tile_zero(TMM5);\n          _tile_zero(TMM7);\n\n          unpack_A<TB>(Tile23, &A[TILE_M * KB + i], KB, k, m1);\n          _tile_loadd(TMM3, Tile23, TILE_K);\n\n          _tile_dpbssd(TMM5, TMM3, TMM0);\n          _tile_dpbssd(TMM7, TMM3, TMM1);\n\n          
_tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t));\n          _tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t));\n\n          scale_C<TB, is_acc>(Tile5, Sumi5, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k, m1);\n          scale_C<TB, is_acc>(Tile7, Sumi7, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k, m1);\n        }\n      });\n    }\n\n    // step 2: accumulate the mins\n    GGML_DISPATCH_BOOL(i > 0, is_acc, [&] {\n      acc_C<TA, TB, is_acc>::apply(C, ldc, Sumi4, &A[i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m0);\n      acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Sumi6, &A[i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m0);\n      if (m1 != 0) {\n        acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc, ldc, Sumi5, &A[TILE_M * KB + i], KB,\n                                     B + PACKED_INDEX(0, i, KB, TILE_SIZE), m1);\n        acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc + TILE_N, ldc, Sumi7, &A[TILE_M * KB + i], KB,\n                                     B + PACKED_INDEX(1, i, KB, TILE_SIZE), m1);\n      }\n    });\n  }\n  return;\n}\n\n}  // anonymous namespace\n\n#define ARCH_GET_XCOMP_PERM 0x1022\n#define ARCH_REQ_XCOMP_PERM 0x1023\n#define XFEATURE_XTILECFG 17\n#define XFEATURE_XTILEDATA 18\n\nbool ggml_amx_init() {\n#if defined(__gnu_linux__)\n  if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {\n    fprintf(stderr, \"AMX is not ready to be used!\\n\");\n    return false;\n  }\n  return true;\n#elif defined(_WIN32)\n  return true;\n#else\n  // no AMX permission request is implemented for other platforms\n  return false;\n#endif\n}\n\nbool ggml_compute_forward_mul_mat_use_amx(struct ggml_tensor* dst) {\n  static thread_local bool is_first_time = true;\n  if (is_first_time) {\n#pragma omp single\n    { ggml_amx_init(); }\n\n    // load tile config\n    ggml_tile_config_init();\n  }\n  is_first_time = false;\n\n  const struct ggml_tensor* src0 = dst->src[0];\n  const struct ggml_tensor* src1 = dst->src[1];\n\n  const enum ggml_type type = src0->type;\n  const int64_t ne0 = dst->ne[0];\n\n  bool is_training = src0->grad || 
src1->grad;\n\n  // amx kernels are enabled for Q4_0, Q4_1, Q8_0, F16\n  // Q4_K, Q5_K, Q6_K, IQ4_XS are enabled only for QK_K = 256\n  bool has_amx_kernels = (type == GGML_TYPE_Q4_0) || (type == GGML_TYPE_Q4_1) || (type == GGML_TYPE_Q8_0) ||\n#ifndef GGML_QKK_64\n                         // only enabled for QK_K == 256\n                         (type == GGML_TYPE_Q4_K) || (type == GGML_TYPE_Q5_K) || (type == GGML_TYPE_Q6_K) ||\n                         (type == GGML_TYPE_IQ4_XS) ||\n#endif\n                         (type == GGML_TYPE_F16);\n\n  // handle only 2d gemm for now\n  auto is_contiguous_2d = [](const struct ggml_tensor* t) {\n    return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;\n  };\n\n  return dst->op != GGML_OP_MUL_MAT_ID && is_contiguous_2d(src0) && is_contiguous_2d(src1) && !is_training &&\n         src1->type == GGML_TYPE_F32 && has_amx_kernels &&\n         // out features must be a multiple of 32\n         ne0 % (TILE_N * 2) == 0;\n}\n\n// NB: mixed dtype gemm with Advanced Matrix Extensions (Intel AMX)\n//\n// src0: weight in shape of {N, K}, quantized\n// src1: input  in shape of {M, K}, float32\n// dst:  output in shape of {M, N}, float32\n//\n// the function performs: dst = src1 @ src0.T\n//\nvoid ggml_mul_mat_amx(struct ggml_tensor* dst, int nth, int ith, void* wdata, int wsize) {\n  struct ggml_tensor* src0 = dst->src[0];\n  struct ggml_tensor* src1 = dst->src[1];\n\n  const enum ggml_type TYPE = src0->type;\n\n  // f16 only has avx512 kernels for now,\n  // amx kernels will be added once 6th gen xeon is released.\n  const bool is_floating_type = TYPE == GGML_TYPE_F16;\n\n  const int M = dst->ne[1];\n  const int N = dst->ne[0];\n  const int K = src0->ne[0];\n  const int ldc = dst->nb[1] / dst->nb[0];\n\n  if (is_floating_type) {\n    constexpr int BLOCK_M = 4;\n    constexpr int BLOCK_N = 6;\n    const int MB = div_up(M, BLOCK_M);\n    const int NB = div_up(N, BLOCK_N);\n\n    parallel_for(nth, ith, MB * NB, [&](int begin, int end) {\n      
GGML_DISPATCH_FLOATING_TYPES(TYPE, [&] {\n        for (int i = begin; i < end; ++i) {\n          int mb = i / NB;\n          int nb = i % NB;\n\n          int mb_start = mb * BLOCK_M;\n          int mb_size = std::min(BLOCK_M, M - mb_start);\n          int nb_start = nb * BLOCK_N;\n          int nb_size = std::min(BLOCK_N, N - nb_start);\n\n          switch (mb_size << 4 | nb_size) {\n            case 0x12:\n              LAUNCH_TINYGEMM_KERNEL_AVX(1, 2);\n              break;\n            case 0x14:\n              LAUNCH_TINYGEMM_KERNEL_AVX(1, 4);\n              break;\n            case 0x16:\n              LAUNCH_TINYGEMM_KERNEL_AVX(1, 6);\n              break;\n            case 0x22:\n              LAUNCH_TINYGEMM_KERNEL_AVX(2, 2);\n              break;\n            case 0x24:\n              LAUNCH_TINYGEMM_KERNEL_AVX(2, 4);\n              break;\n            case 0x26:\n              LAUNCH_TINYGEMM_KERNEL_AVX(2, 6);\n              break;\n            case 0x32:\n              LAUNCH_TINYGEMM_KERNEL_AVX(3, 2);\n              break;\n            case 0x34:\n              LAUNCH_TINYGEMM_KERNEL_AVX(3, 4);\n              break;\n            case 0x36:\n              LAUNCH_TINYGEMM_KERNEL_AVX(3, 6);\n              break;\n            case 0x42:\n              LAUNCH_TINYGEMM_KERNEL_AVX(4, 2);\n              break;\n            case 0x44:\n              LAUNCH_TINYGEMM_KERNEL_AVX(4, 4);\n              break;\n            case 0x46:\n              LAUNCH_TINYGEMM_KERNEL_AVX(4, 6);\n              break;\n            default:\n              fprintf(stderr, \"Unexpected block size!\\n\");\n          }\n        }\n      });\n    });\n    return;\n  }\n\n#pragma omp single\n  {\n    GGML_DISPATCH_QTYPES(TYPE, [&] {\n      const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);\n      GGML_ASSERT(wsize >= int(M * row_size_A));\n\n      // Q4_0, Q4_1, Q8_0 handles 1 TILE_K per blck_size\n      // Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size\n      
GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size);\n      // pack mat B to vnni format\n      if (src0->extra == nullptr) {\n        const size_t row_size_B = get_row_size<type, blck_size>(K);\n        src0->extra = aligned_alloc(64, N * row_size_B);\n        convert_B_packed_format<type, blck_size>((void*)src0->extra, (const type*)src0->data, N, K);\n      }\n\n      const float* A_data = static_cast<const float*>(src1->data);\n      for (int m = 0; m < M; ++m) {\n        from_float<vec_dot_type>(A_data + m * K, (char*)wdata + m * row_size_A, K);\n      }\n    });\n  }\n\n  GGML_ASSERT(src0->extra != nullptr);\n  if (M == 1) {\n    // MB = 1 and handle kTilesN tiles in each block\n    constexpr int kTilesN = 4;\n    constexpr int BLOCK_N = TILE_N * kTilesN;\n    const int NB = div_up(N, BLOCK_N);\n\n    parallel_for(nth, ith, NB, [&](int begin, int end) {\n      GGML_DISPATCH_QTYPES(TYPE, [&] {\n        const int KB = K / blck_size;\n        const int TILE_SIZE = get_tile_size<type>();\n        const int row_size_A = KB * sizeof(vec_dot_type);\n        for (int i = begin; i < end; ++i) {\n          int nb = i;\n          int nb_start = nb * BLOCK_N;\n          int nb_size = std::min(BLOCK_N, N - nb_start);  // a multiple of 32, at most BLOCK_N\n\n          switch (nb_size) {\n            // case 160: LAUNCH_TINYGEMM_KERNEL_VNNI(160); break;\n            case 128:\n              LAUNCH_TINYGEMM_KERNEL_VNNI(128);\n              break;\n            case 96:\n              LAUNCH_TINYGEMM_KERNEL_VNNI(96);\n              break;\n            case 64:\n              LAUNCH_TINYGEMM_KERNEL_VNNI(64);\n              break;\n            case 32:\n              LAUNCH_TINYGEMM_KERNEL_VNNI(32);\n              break;\n            default:\n              fprintf(stderr, \"Unexpected n block size!\\n\");\n          }\n        }\n      });\n    });\n    return;\n  }\n\n  // handle 4 tiles at a time\n  constexpr int BLOCK_M = TILE_M * 2;\n  constexpr int BLOCK_N = TILE_N * 2;\n  const int MB = 
div_up(M, BLOCK_M);\n  const int NB = div_up(N, BLOCK_N);\n\n  parallel_for(nth, ith, MB * NB, [&](int begin, int end) {\n    GGML_DISPATCH_QTYPES(TYPE, [&] {\n      const int KB = K / blck_size;\n      const int TILE_SIZE = get_tile_size<type>();\n      const int row_size_A = KB * sizeof(vec_dot_type);\n\n      for (int i = begin; i < end; ++i) {\n        int mb = i / NB;\n        int nb = i % NB;\n\n        int mb_start = mb * BLOCK_M;\n        int mb_size = std::min(BLOCK_M, M - mb_start);\n        int nb_start = nb * BLOCK_N;\n        int nb_size = BLOCK_N;\n\n        tinygemm_kernel_amx<vec_dot_type, type, float, blck_size>(\n            mb_size, nb_size, KB, (const char*)wdata + mb_start * row_size_A,\n            (const char*)src0->extra + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE),\n            (float*)dst->data + mb_start * N + nb_start, ldc);\n      }\n    });\n  });\n}\n\n#else  // if defined(__AMX_INT8__)\n\nbool ggml_amx_init() {\n  fprintf(stderr, \"GGML is not compiled with AMX support!\\n\");\n  return false;\n}\n\nbool ggml_compute_forward_mul_mat_use_amx(struct ggml_tensor* dst) {\n  GGML_UNUSED(dst);\n  fprintf(stderr, \"GGML is not compiled with AMX support!\\n\");\n  return false;\n}\n\nvoid ggml_mul_mat_amx(struct ggml_tensor* dst, int nth, int ith, void* wdata, int wsize) {\n  GGML_UNUSED(dst);\n  GGML_UNUSED(nth);\n  GGML_UNUSED(ith);\n  GGML_UNUSED(wdata);\n  GGML_UNUSED(wsize);\n  fprintf(stderr, \"GGML is not compiled with AMX support!\\n\");\n}\n\n#endif  // if defined(__AMX_INT8__)\nint main() {\n  // to be written\n}"
  },
  {
    "path": "kt-kernel/operators/amx/test/mmq.h",
    "content": "#ifndef MMQ_H\n#define MMQ_H\n#include <stdbool.h>\n#include <stdint.h>\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\nbool ggml_amx_init(void);\n\nbool ggml_compute_forward_mul_mat_use_amx(struct ggml_tensor* dst);\n\nvoid ggml_mul_mat_amx(struct ggml_tensor* dst, int nth, int ith, void* wdata, int wsize);\n\n/**\n * @param m\n * @param n\n * @param k\n * @param a\n * @param a_type\n * @param b\n * @param b_type\n * @param c\n * @param c_type\n * @param ldc stride of c in elements\n * @param ith\n * @param nth\n * @param wdata auxiliary data area\n * @param wsize size of the auxiliary data area\n */\n\nvoid mat_mul_amx(int m, int n, int k, const void* a, int a_type, const void* b, int b_type, void* c, int c_type,\n                 int ldc, int ith, int nth, void* wdata, int wsize);\n\n#ifdef __cplusplus\n}\n#endif\n\n#endif  // MMQ_H\n"
  },
  {
    "path": "kt-kernel/operators/amx/test/test-kgroup-128.cpp",
    "content": "#include <algorithm>\n#include <cmath>\n#include <cstdlib>\n#include <iostream>\n#include <memory>\n#include <random>\n#include <vector>\n\n#include \"../la/amx.hpp\"\n\nvoid test_kgroup_128() {\n  std::cout << \"=== Testing K-Group with k_group_size = 128 ===\\n\" << std::endl;\n\n  const int m = 32;  // Simple case\n  const int n = 32;\n  const int k = 512;  // Multiple of 128\n  const int k_group_size = 128;\n\n  std::cout << \"Matrix dimensions: \" << m << \" x \" << n << \" x \" << k << std::endl;\n  std::cout << \"K-group size: \" << k_group_size << std::endl;\n  std::cout << \"Number of k-groups: \" << k / k_group_size << std::endl;\n\n  using Kernel = amx::GemmKernel224Int4KGroup;\n  using BufferA = Kernel::BufferA;\n  using BufferB = Kernel::BufferB;\n  using BufferC = Kernel::BufferC;\n\n  void* buffer_a = std::aligned_alloc(64, BufferA::required_size(m, k, k_group_size));\n  void* buffer_b = std::aligned_alloc(64, BufferB::required_size(n, k, k_group_size));\n  void* buffer_c = std::aligned_alloc(64, BufferC::required_size(m, n));\n\n  auto ba = std::make_shared<BufferA>(m, k, k_group_size, buffer_a);\n  auto bb = std::make_shared<BufferB>(n, k, k_group_size, buffer_b);\n  auto bc = std::make_shared<BufferC>(m, n, buffer_c);\n\n  Kernel::config();\n\n  // Test 1: All ones\n  std::cout << \"\\n--- Test 1: All ones (expected = \" << k << \") ---\" << std::endl;\n  {\n    std::vector<ggml_bf16_t> input_a(m * k);\n    std::vector<ggml_bf16_t> input_b(k * n);\n\n    for (int i = 0; i < m * k; i++) {\n      input_a[i] = ggml_compute_fp32_to_bf16(1.0f);\n    }\n    for (int i = 0; i < k * n; i++) {\n      input_b[i] = ggml_compute_fp32_to_bf16(1.0f);\n    }\n\n    ba->from_mat(m, input_a.data(), 0, 1);\n    bb->from_mat(input_b.data(), 0, 1);\n    amx::mat_mul_kgroup(m, n, k, k_group_size, ba, bb, bc, 0, 1);\n\n    std::vector<ggml_bf16_t> output(m * n);\n    bc->to_mat(m, output.data(), 0, 1);\n\n    float actual = ggml_compute_bf16_to_fp32(output[0]);\n    float error = std::abs(actual - k) / k * 100;\n 
   std::cout << \"Result[0,0]: \" << actual << \" (error: \" << error << \"%)\" << std::endl;\n  }\n\n  // Test 2: Values in quantization sweet spot (0.5)\n  std::cout << \"\\n--- Test 2: All 0.5 (expected = \" << 0.5f * 0.5f * k << \") ---\" << std::endl;\n  {\n    std::vector<ggml_bf16_t> input_a(m * k);\n    std::vector<ggml_bf16_t> input_b(k * n);\n\n    for (int i = 0; i < m * k; i++) {\n      input_a[i] = ggml_compute_fp32_to_bf16(0.5f);\n    }\n    for (int i = 0; i < k * n; i++) {\n      input_b[i] = ggml_compute_fp32_to_bf16(0.5f);\n    }\n\n    ba->from_mat(m, input_a.data(), 0, 1);\n    bb->from_mat(input_b.data(), 0, 1);\n    amx::mat_mul_kgroup(m, n, k, k_group_size, ba, bb, bc, 0, 1);\n\n    std::vector<ggml_bf16_t> output(m * n);\n    bc->to_mat(m, output.data(), 0, 1);\n\n    float expected = 0.5f * 0.5f * k;\n    float actual = ggml_compute_bf16_to_fp32(output[0]);\n    float error = std::abs(actual - expected) / expected * 100;\n    std::cout << \"Result[0,0]: \" << actual << \" (expected: \" << expected << \", error: \" << error << \"%)\" << std::endl;\n  }\n\n  // Test 3: Different values per k-group\n  std::cout << \"\\n--- Test 3: Different values per k-group ---\" << std::endl;\n  {\n    std::vector<ggml_bf16_t> input_a(m * k);\n    std::vector<ggml_bf16_t> input_b(k * n);\n\n    // Each k-group has different value\n    for (int i = 0; i < m; i++) {\n      for (int j = 0; j < k; j++) {\n        int kg = j / k_group_size;\n        float val = (kg + 1) * 0.25f;  // 0.25, 0.5, 0.75, 1.0\n        input_a[i * k + j] = ggml_compute_fp32_to_bf16(val);\n      }\n    }\n\n    for (int i = 0; i < k * n; i++) {\n      input_b[i] = ggml_compute_fp32_to_bf16(0.5f);\n    }\n\n    ba->from_mat(m, input_a.data(), 0, 1);\n    bb->from_mat(input_b.data(), 0, 1);\n    amx::mat_mul_kgroup(m, n, k, k_group_size, ba, bb, bc, 0, 1);\n\n    std::vector<ggml_bf16_t> output(m * n);\n    bc->to_mat(m, output.data(), 0, 1);\n\n    // Expected: sum of (kg+1)*0.25 * 0.5 * 
k_group_size for all k-groups\n    float expected = 0.0f;\n    for (int kg = 0; kg < k / k_group_size; kg++) {\n      expected += (kg + 1) * 0.25f * 0.5f * k_group_size;\n    }\n\n    float actual = ggml_compute_bf16_to_fp32(output[0]);\n    float error = std::abs(actual - expected) / expected * 100;\n    std::cout << \"Expected: \" << expected << \", Actual: \" << actual << std::endl;\n    std::cout << \"Error: \" << error << \"%\" << std::endl;\n  }\n\n  // Test 4: Pattern test\n  std::cout << \"\\n--- Test 4: Pattern with alternating values ---\" << std::endl;\n  {\n    std::vector<ggml_bf16_t> input_a(m * k);\n    std::vector<ggml_bf16_t> input_b(k * n);\n\n    // Alternating pattern in A\n    for (int i = 0; i < m * k; i++) {\n      float val = (i % 2 == 0) ? 0.25f : 0.75f;\n      input_a[i] = ggml_compute_fp32_to_bf16(val);\n    }\n\n    // Constant in B\n    for (int i = 0; i < k * n; i++) {\n      input_b[i] = ggml_compute_fp32_to_bf16(0.4f);\n    }\n\n    ba->from_mat(m, input_a.data(), 0, 1);\n    bb->from_mat(input_b.data(), 0, 1);\n    amx::mat_mul_kgroup(m, n, k, k_group_size, ba, bb, bc, 0, 1);\n\n    std::vector<ggml_bf16_t> output(m * n);\n    bc->to_mat(m, output.data(), 0, 1);\n\n    // Expected: average of 0.25 and 0.75 is 0.5, so 0.5 * 0.4 * k\n    float expected = 0.5f * 0.4f * k;\n    float actual = ggml_compute_bf16_to_fp32(output[0]);\n    float error = std::abs(actual - expected) / expected * 100;\n    std::cout << \"Expected: \" << expected << \", Actual: \" << actual << std::endl;\n    std::cout << \"Error: \" << error << \"%\" << std::endl;\n  }\n\n  // Test 5: Check all output elements\n  std::cout << \"\\n--- Test 5: Verify all output elements (0.1 × 0.1) ---\" << std::endl;\n  {\n    std::vector<ggml_bf16_t> input_a(m * k);\n    std::vector<ggml_bf16_t> input_b(k * n);\n\n    for (int i = 0; i < m * k; i++) {\n      input_a[i] = ggml_compute_fp32_to_bf16(0.1f);\n    }\n    for (int i = 0; i < k * n; i++) {\n      input_b[i] = 
ggml_compute_fp32_to_bf16(0.1f);\n    }\n\n    ba->from_mat(m, input_a.data(), 0, 1);\n    bb->from_mat(input_b.data(), 0, 1);\n    amx::mat_mul_kgroup(m, n, k, k_group_size, ba, bb, bc, 0, 1);\n\n    std::vector<ggml_bf16_t> output(m * n);\n    bc->to_mat(m, output.data(), 0, 1);\n\n    float expected = 0.1f * 0.1f * k;\n    float max_error = 0.0f;\n    float avg_error = 0.0f;\n    int error_count = 0;\n\n    for (int i = 0; i < m * n; i++) {\n      float actual = ggml_compute_bf16_to_fp32(output[i]);\n      float error = std::abs(actual - expected) / expected * 100;\n      max_error = std::max(max_error, error);\n      avg_error += error;\n      if (error > 5.0f) error_count++;\n    }\n    avg_error /= (m * n);\n\n    std::cout << \"Expected value: \" << expected << std::endl;\n    std::cout << \"Max error: \" << max_error << \"%\" << std::endl;\n    std::cout << \"Average error: \" << avg_error << \"%\" << std::endl;\n    std::cout << \"Elements with >5% error: \" << error_count << \"/\" << m * n << std::endl;\n  }\n\n  // Test 6: Random normal distribution (like real model weights)\n  std::cout << \"\\n--- Test 6: Random normal distribution ---\" << std::endl;\n  {\n    std::vector<ggml_bf16_t> input_a(m * k);\n    std::vector<ggml_bf16_t> input_b(k * n);\n\n    std::mt19937 gen(42);\n    std::normal_distribution<float> dist(0.0f, 0.1f);\n\n    for (int i = 0; i < m * k; i++) {\n      input_a[i] = ggml_compute_fp32_to_bf16(dist(gen));\n    }\n    for (int i = 0; i < k * n; i++) {\n      input_b[i] = ggml_compute_fp32_to_bf16(dist(gen));\n    }\n\n    // Compute reference with float32\n    std::vector<float> ref_result(m * n, 0.0f);\n    for (int i = 0; i < m; i++) {\n      for (int j = 0; j < n; j++) {\n        float sum = 0.0f;\n        for (int l = 0; l < k; l++) {\n          float a_val = ggml_compute_bf16_to_fp32(input_a[i * k + l]);\n          float b_val = ggml_compute_bf16_to_fp32(input_b[l * n + j]);\n          sum += a_val * b_val;\n        }\n        
ref_result[i * n + j] = sum;\n      }\n    }\n\n    ba->from_mat(m, input_a.data(), 0, 1);\n    bb->from_mat(input_b.data(), 0, 1);\n    amx::mat_mul_kgroup(m, n, k, k_group_size, ba, bb, bc, 0, 1);\n\n    std::vector<ggml_bf16_t> output(m * n);\n    bc->to_mat(m, output.data(), 0, 1);\n\n    // Compute errors\n    float max_abs_error = 0.0f;\n    float max_rel_error = 0.0f;\n    float avg_rel_error = 0.0f;\n    int large_error_count = 0;\n\n    for (int i = 0; i < m * n; i++) {\n      float actual = ggml_compute_bf16_to_fp32(output[i]);\n      float ref = ref_result[i];\n      float abs_error = std::abs(actual - ref);\n      float rel_error = std::abs(ref) > 1e-6 ? abs_error / std::abs(ref) : 0.0f;\n\n      max_abs_error = std::max(max_abs_error, abs_error);\n      max_rel_error = std::max(max_rel_error, rel_error);\n      avg_rel_error += rel_error;\n\n      if (rel_error > 0.2f) {  // 20% error\n        large_error_count++;\n        if (large_error_count <= 5) {\n          std::cout << \"  [\" << i / n << \",\" << i % n << \"]: actual=\" << actual << \", ref=\" << ref\n                    << \", rel_error=\" << (rel_error * 100) << \"%\" << std::endl;\n        }\n      }\n    }\n    avg_rel_error /= (m * n);\n\n    std::cout << \"Max absolute error: \" << max_abs_error << std::endl;\n    std::cout << \"Max relative error: \" << (max_rel_error * 100) << \"%\" << std::endl;\n    std::cout << \"Average relative error: \" << (avg_rel_error * 100) << \"%\" << std::endl;\n    std::cout << \"Elements with >20% error: \" << large_error_count << \"/\" << m * n << std::endl;\n  }\n\n  free(buffer_a);\n  free(buffer_b);\n  free(buffer_c);\n}\n\nint main() {\n  test_kgroup_128();\n  return 0;\n}"
  },
  {
    "path": "kt-kernel/operators/amx/test/test-kgroup-kernel.cpp",
    "content": "#include <omp.h>\n\n#include \"../la/amx.hpp\"\n#define FMT_HEADER_ONLY\n#include <fmt/core.h>\n\n#include <chrono>\n#include <cmath>\n#include <iostream>\n#include <memory>\n#include <random>\n\nvoid test_kgroup_kernel_basic() {\n  std::cout << \"=== Testing GemmKernel224Int4KGroup Basic Functionality ===\" << std::endl;\n\n  // Test parameters - must match kernel requirements\n  const int m = 64;              // Must be multiple of M_STEP (32)\n  const int n = 64;              // Must be multiple of N_STEP (32)\n  const int k = 1024;            // Must be multiple of K_STEP (64)\n  const int k_group_size = 256;  // Must divide k evenly\n\n  std::cout << fmt::format(\"Parameters: m={}, n={}, k={}, k_group_size={}\\n\", m, n, k, k_group_size);\n\n  using Kernel = amx::GemmKernel224Int4KGroup;\n  using BufferA = Kernel::BufferA;\n  using BufferB = Kernel::BufferB;\n  using BufferC = Kernel::BufferC;\n\n  // Allocate buffers\n  size_t size_a = BufferA::required_size(m, k, k_group_size);\n  size_t size_b = BufferB::required_size(n, k, k_group_size);  // Fixed: n, k not k, n\n  size_t size_c = BufferC::required_size(m, n);\n\n  void* buffer_a = std::aligned_alloc(64, size_a);\n  void* buffer_b = std::aligned_alloc(64, size_b);\n  void* buffer_c = std::aligned_alloc(64, size_c);\n\n  std::cout << fmt::format(\"Buffer sizes: A={} KB, B={} KB, C={} KB\\n\", size_a / 1024, size_b / 1024, size_c / 1024);\n\n  auto ba = std::make_shared<BufferA>(m, k, k_group_size, buffer_a);\n  auto bb = std::make_shared<BufferB>(n, k, k_group_size, buffer_b);  // Fixed: n, k not k, n\n  auto bc = std::make_shared<BufferC>(m, n, buffer_c);\n\n  // Create test input data\n  std::vector<ggml_bf16_t> input_a(m * k);\n  std::vector<ggml_bf16_t> input_b(k * n);\n\n  std::mt19937 gen(42);\n  std::uniform_real_distribution<float> dist(-0.5f, 0.5f);\n\n  // Fill with small values to avoid overflow\n  for (int i = 0; i < m * k; i++) {\n    input_a[i] = 
ggml_compute_fp32_to_bf16(dist(gen));\n  }\n  for (int i = 0; i < k * n; i++) {\n    input_b[i] = ggml_compute_fp32_to_bf16(dist(gen));\n  }\n\n  // Quantize inputs\n  std::cout << \"Quantizing inputs...\" << std::endl;\n  ba->from_mat(m, input_a.data(), 0, 1);\n  bb->from_mat(input_b.data(), 0, 1);\n\n  // Configure AMX\n  Kernel::config();\n\n  // Run matrix multiplication with k-group quantization\n  std::cout << \"Running k-group matrix multiplication...\" << std::endl;\n  auto start = std::chrono::high_resolution_clock::now();\n\n  amx::mat_mul_kgroup(m, n, k, k_group_size, ba, bb, bc, 0, 1);\n\n  auto end = std::chrono::high_resolution_clock::now();\n  auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();\n  std::cout << fmt::format(\"Time: {} ms\\n\", duration / 1000.0);\n\n  // Convert output to bf16\n  std::vector<ggml_bf16_t> output(m * n);\n  bc->to_mat(m, output.data(), 0, 1);\n\n  // Print sample output values\n  std::cout << \"\\nSample output values:\" << std::endl;\n  for (int i = 0; i < std::min(5, m); i++) {\n    for (int j = 0; j < std::min(5, n); j++) {\n      float val = ggml_compute_bf16_to_fp32(output[i * n + j]);\n      std::cout << fmt::format(\"{:8.4f} \", val);\n    }\n    std::cout << std::endl;\n  }\n\n  // Clean up\n  free(buffer_a);\n  free(buffer_b);\n  free(buffer_c);\n\n  std::cout << \"\\n✓ Basic test completed!\" << std::endl;\n}\n\nvoid test_kgroup_kernel_correctness() {\n  std::cout << \"\\n=== Testing GemmKernel224Int4KGroup Correctness ===\" << std::endl;\n\n  const int m = 32;\n  const int n = 32;\n  const int k = 512;\n  const int k_group_size = 128;\n\n  using Kernel = amx::GemmKernel224Int4KGroup;\n  using BufferA = Kernel::BufferA;\n  using BufferB = Kernel::BufferB;\n  using BufferC = Kernel::BufferC;\n\n  // Allocate buffers\n  void* buffer_a = std::aligned_alloc(64, BufferA::required_size(m, k, k_group_size));\n  void* buffer_b = std::aligned_alloc(64, BufferB::required_size(n, k, 
k_group_size));  // Fixed: n, k not k, n\n  void* buffer_c = std::aligned_alloc(64, BufferC::required_size(m, n));\n\n  auto ba = std::make_shared<BufferA>(m, k, k_group_size, buffer_a);\n  auto bb = std::make_shared<BufferB>(n, k, k_group_size, buffer_b);  // Fixed: n, k not k, n\n  auto bc = std::make_shared<BufferC>(m, n, buffer_c);\n\n  // Create simple test pattern\n  std::vector<ggml_bf16_t> input_a(m * k);\n  std::vector<ggml_bf16_t> input_b(k * n);\n  std::vector<float> expected(m * n, 0.0f);\n\n  // Fill A with row indices and B with column indices (scaled down)\n  for (int i = 0; i < m; i++) {\n    for (int j = 0; j < k; j++) {\n      input_a[i * k + j] = ggml_compute_fp32_to_bf16((i + 1) * 0.001f);\n    }\n  }\n\n  for (int i = 0; i < k; i++) {\n    for (int j = 0; j < n; j++) {\n      input_b[i * n + j] = ggml_compute_fp32_to_bf16((j + 1) * 0.001f);\n    }\n  }\n\n  // Compute expected result (naive)\n  for (int i = 0; i < m; i++) {\n    for (int j = 0; j < n; j++) {\n      float sum = 0.0f;\n      for (int l = 0; l < k; l++) {\n        float a_val = ggml_compute_bf16_to_fp32(input_a[i * k + l]);\n        float b_val = ggml_compute_bf16_to_fp32(input_b[l * n + j]);\n        sum += a_val * b_val;\n      }\n      expected[i * n + j] = sum;\n    }\n  }\n\n  // Quantize and run\n  ba->from_mat(m, input_a.data(), 0, 1);\n  bb->from_mat(input_b.data(), 0, 1);\n\n  Kernel::config();\n  amx::mat_mul_kgroup(m, n, k, k_group_size, ba, bb, bc, 0, 1);\n\n  // Get output\n  std::vector<ggml_bf16_t> output(m * n);\n  bc->to_mat(m, output.data(), 0, 1);\n\n  // Compare results\n  float max_error = 0.0f;\n  float total_error = 0.0f;\n  int count = 0;\n\n  for (int i = 0; i < m; i++) {\n    for (int j = 0; j < n; j++) {\n      float actual = ggml_compute_bf16_to_fp32(output[i * n + j]);\n      float exp = expected[i * n + j];\n      float error = std::abs(actual - exp);\n      max_error = std::max(max_error, error);\n      total_error += error;\n      count++;\n    }\n  
}\n\n  float avg_error = total_error / count;\n  float relative_error = max_error / (*std::max_element(expected.begin(), expected.end()) + 1e-8f);\n\n  std::cout << fmt::format(\"Error Analysis:\\n\");\n  std::cout << fmt::format(\"  Max absolute error: {:.6f}\\n\", max_error);\n  std::cout << fmt::format(\"  Average absolute error: {:.6f}\\n\", avg_error);\n  std::cout << fmt::format(\"  Relative error: {:.2f}%\\n\", relative_error * 100);\n\n  // Check acceptability (INT4 quantization + k-group should have reasonable error)\n  if (relative_error < 0.10f) {  // 10% relative error threshold for INT4\n    std::cout << \"✓ Error is within acceptable range for INT4 quantization\" << std::endl;\n  } else {\n    std::cout << \"✗ Error is higher than expected!\" << std::endl;\n  }\n\n  // Print first few values for comparison\n  std::cout << \"\\nFirst 5x5 values comparison:\" << std::endl;\n  std::cout << \"Expected vs Actual:\" << std::endl;\n  for (int i = 0; i < std::min(5, m); i++) {\n    for (int j = 0; j < std::min(5, n); j++) {\n      float actual = ggml_compute_bf16_to_fp32(output[i * n + j]);\n      float exp = expected[i * n + j];\n      std::cout << fmt::format(\"({:.4f},{:.4f}) \", exp, actual);\n    }\n    std::cout << std::endl;\n  }\n\n  free(buffer_a);\n  free(buffer_b);\n  free(buffer_c);\n\n  std::cout << \"\\n✓ Correctness test completed!\" << std::endl;\n}\n\nvoid test_kgroup_kernel_performance() {\n  std::cout << \"\\n=== Testing GemmKernel224Int4KGroup Performance ===\" << std::endl;\n\n  const int m = 256;\n  const int n = 256;\n  const int k = 2048;\n  const int k_group_size = 512;\n  const int iterations = 100;\n\n  using Kernel = amx::GemmKernel224Int4KGroup;\n  using BufferA = Kernel::BufferA;\n  using BufferB = Kernel::BufferB;\n  using BufferC = Kernel::BufferC;\n\n  // Allocate buffers\n  void* buffer_a = std::aligned_alloc(64, BufferA::required_size(m, k, k_group_size));\n  void* buffer_b = std::aligned_alloc(64, BufferB::required_size(n, 
k, k_group_size));  // Fixed: n, k not k, n\n  void* buffer_c = std::aligned_alloc(64, BufferC::required_size(m, n));\n\n  auto ba = std::make_shared<BufferA>(m, k, k_group_size, buffer_a);\n  auto bb = std::make_shared<BufferB>(n, k, k_group_size, buffer_b);  // Fixed: n, k not k, n\n  auto bc = std::make_shared<BufferC>(m, n, buffer_c);\n\n  // Create random input\n  std::vector<ggml_bf16_t> input_a(m * k);\n  std::vector<ggml_bf16_t> input_b(k * n);\n\n  std::mt19937 gen(42);\n  std::uniform_real_distribution<float> dist(-0.1f, 0.1f);\n\n  for (int i = 0; i < m * k; i++) {\n    input_a[i] = ggml_compute_fp32_to_bf16(dist(gen));\n  }\n  for (int i = 0; i < k * n; i++) {\n    input_b[i] = ggml_compute_fp32_to_bf16(dist(gen));\n  }\n\n  // Quantize\n  ba->from_mat(m, input_a.data(), 0, 1);\n  bb->from_mat(input_b.data(), 0, 1);\n\n  Kernel::config();\n\n  // Warm up\n  for (int i = 0; i < 10; i++) {\n    amx::mat_mul_kgroup(m, n, k, k_group_size, ba, bb, bc, 0, 1);\n  }\n\n  // Benchmark\n  auto start = std::chrono::high_resolution_clock::now();\n\n  for (int i = 0; i < iterations; i++) {\n    amx::mat_mul_kgroup(m, n, k, k_group_size, ba, bb, bc, 0, 1);\n  }\n\n  auto end = std::chrono::high_resolution_clock::now();\n  auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();\n\n  double avg_time_ms = duration / (1000.0 * iterations);\n  double ops = 2.0 * m * n * k;\n  double gflops = (ops * iterations) / (duration * 1000.0);\n\n  std::cout << fmt::format(\"Matrix size: {}x{}x{}\\n\", m, n, k);\n  std::cout << fmt::format(\"K-group size: {}\\n\", k_group_size);\n  std::cout << fmt::format(\"Average time per multiplication: {:.3f} ms\\n\", avg_time_ms);\n  std::cout << fmt::format(\"Performance: {:.2f} GFLOPS\\n\", gflops);\n\n  free(buffer_a);\n  free(buffer_b);\n  free(buffer_c);\n\n  std::cout << \"\\n✓ Performance test completed!\" << std::endl;\n}\n\nint main(int argc, char** argv) {\n  std::cout << \"Starting 
GemmKernel224Int4KGroup Tests\\n\" << std::endl;\n\n  try {\n    test_kgroup_kernel_basic();\n    test_kgroup_kernel_correctness();\n    test_kgroup_kernel_performance();\n\n    std::cout << \"\\n=== All tests completed successfully! ===\" << std::endl;\n  } catch (const std::exception& e) {\n    std::cerr << \"Test failed with exception: \" << e.what() << std::endl;\n    return 1;\n  }\n\n  return 0;\n}"
  },
  {
    "path": "kt-kernel/operators/amx/test/test-specific-dims.cpp",
    "content": "#include <chrono>\n#include <cmath>\n#include <iostream>\n#include <memory>\n#include <random>\n#include <vector>\n\n#include \"../la/amx.hpp\"\n#include \"../la/amx_buffers.hpp\"\n#include \"../la/amx_kernels.hpp\"\n\nvoid test_specific_dimensions() {\n  std::cout << \"=== Testing Specific Dimensions ===\\n\" << std::endl;\n\n  const int m_original = 200;\n  const int n = 512;\n  const int k = 7168;\n  const int k_group_size = 64;\n\n  // Pad m to nearest multiple of 32 (M_STEP)\n  const int M_STEP = 32;\n  const int m = ((m_original + M_STEP - 1) / M_STEP) * M_STEP;  // Round up to 224\n\n  std::cout << \"Original dimensions: \" << m_original << \" x \" << n << \" x \" << k << std::endl;\n  std::cout << \"Padded dimensions: \" << m << \" x \" << n << \" x \" << k << std::endl;\n  std::cout << \"K-group size is: \" << k_group_size << std::endl;\n  std::cout << \"Number of k-groups: \" << k / k_group_size << std::endl;\n\n  using Kernel = amx::GemmKernel224Int4KGroup;\n  using Kernel_int4_1 = amx::GemmKernel224Int4_1;\n  using Kernel_int4 = amx::GemmKernel224Int4;\n  using Kernel_k_int4_1 = amx::GemmKernel224Int4_1KGroup;\n  using Kernel_k_int4_1_low = amx::GemmKernel224Int4_1_LowKGroup;\n  using BufferA = Kernel::BufferA;\n  using BufferB = Kernel::BufferB;\n  using BufferC = Kernel::BufferC;\n  using BufferA_int4_1 = Kernel_int4_1::BufferA;\n  using BufferB_int4_1 = Kernel_int4_1::BufferB;\n  using BufferC_int4_1 = Kernel_int4_1::BufferC;\n  using BufferA_int4 = Kernel_int4::BufferA;\n  using BufferB_int4 = Kernel_int4::BufferB;\n  using BufferC_int4 = Kernel_int4::BufferC;\n  using BufferA_k_int4_1 = Kernel_k_int4_1::BufferA;\n  using BufferB_k_int4_1 = Kernel_k_int4_1::BufferB;\n  using BufferC_k_int4_1 = Kernel_k_int4_1::BufferC;\n  using BufferA_k_int4_1_low = Kernel_k_int4_1_low::BufferA;\n  using BufferB_k_int4_1_low = Kernel_k_int4_1_low::BufferB;\n  using BufferC_k_int4_1_low = Kernel_k_int4_1_low::BufferC;\n\n  void* buffer_a = 
std::aligned_alloc(64, BufferA::required_size(m, k, k_group_size));\n  void* buffer_b = std::aligned_alloc(64, BufferB::required_size(n, k, k_group_size));\n  void* buffer_c = std::aligned_alloc(64, BufferC::required_size(m, n));\n\n  void* buffer_a_int4_1 = std::aligned_alloc(64, BufferA_int4_1::required_size(m, k));\n  void* buffer_b_int4_1 = std::aligned_alloc(64, BufferB_int4_1::required_size(n, k));\n  void* buffer_c_int4_1 = std::aligned_alloc(64, BufferC_int4_1::required_size(m, n));\n\n  void* buffer_a_int4 = std::aligned_alloc(64, BufferA_int4::required_size(m, k));\n  void* buffer_b_int4 = std::aligned_alloc(64, BufferB_int4::required_size(n, k));\n  void* buffer_c_int4 = std::aligned_alloc(64, BufferC_int4::required_size(m, n));\n\n  void* buffer_a_k_int4_1 = std::aligned_alloc(64, BufferA_k_int4_1::required_size(m, k, k_group_size));\n  void* buffer_b_k_int4_1 = std::aligned_alloc(64, BufferB_k_int4_1::required_size(n, k, k_group_size));\n  void* buffer_c_k_int4_1 = std::aligned_alloc(64, BufferC_k_int4_1::required_size(m, n));\n\n  void* buffer_a_k_int4_1_low = std::aligned_alloc(64, BufferA_k_int4_1_low::required_size(m, k, k_group_size));\n  void* buffer_b_k_int4_1_low = std::aligned_alloc(64, BufferB_k_int4_1_low::required_size(n, k, k_group_size));\n  void* buffer_c_k_int4_1_low = std::aligned_alloc(64, BufferC_k_int4_1_low::required_size(m, n));\n\n  auto ba = std::make_shared<BufferA>(m, k, k_group_size, buffer_a);\n  printf(\"buffer_b ptr:%p\\n\", buffer_b);\n  auto bb = std::make_shared<BufferB>(n, k, k_group_size, buffer_b);\n  auto bc = std::make_shared<BufferC>(m, n, buffer_c);\n\n  auto ba_int4_1 = std::make_shared<BufferA_int4_1>(m, k, buffer_a_int4_1);\n  auto bb_int4_1 = std::make_shared<BufferB_int4_1>(n, k, buffer_b_int4_1);\n  auto bc_int4_1 = std::make_shared<BufferC_int4_1>(m, n, buffer_c_int4_1);\n\n  auto ba_int4 = std::make_shared<BufferA_int4>(m, k, buffer_a_int4);\n  auto bb_int4 = std::make_shared<BufferB_int4>(n, k, 
buffer_b_int4);\n  auto bc_int4 = std::make_shared<BufferC_int4>(m, n, buffer_c_int4);\n\n  auto ba_k_int4_1 = std::make_shared<BufferA_k_int4_1>(m, k, k_group_size, buffer_a_k_int4_1);\n  auto bb_k_int4_1 = std::make_shared<BufferB_k_int4_1>(n, k, k_group_size, buffer_b_k_int4_1);\n  auto bc_k_int4_1 = std::make_shared<BufferC_k_int4_1>(m, n, buffer_c_k_int4_1);\n\n  auto ba_k_int4_1_low = std::make_shared<BufferA_k_int4_1_low>(m, k, k_group_size, buffer_a_k_int4_1_low);\n  auto bb_k_int4_1_low = std::make_shared<BufferB_k_int4_1_low>(n, k, k_group_size, buffer_b_k_int4_1_low);\n  auto bc_k_int4_1_low = std::make_shared<BufferC_k_int4_1_low>(m, n, buffer_c_k_int4_1_low);\n\n  // Create input matrices with realistic values\n  std::vector<ggml_bf16_t> input_a(m * k);\n  std::vector<ggml_bf16_t> input_b(k * n);\n\n  std::mt19937 gen(42);\n  std::normal_distribution<float> dist(0.0f, 0.1f);  // Normal distribution, mean=0, std=0.1\n\n  std::cout << \"\\nGenerating input matrices...\" << std::endl;\n  // print input mat(first 10)\n  // for (int i = 0; i < std::min(10, m * k); i++) {\n  //   std::cout << \"input_a[\" << i << \"] = \" << ggml_compute_bf16_to_fp32(input_a[i]) << std::endl;\n  // }\n  // for (int i = 0; i < std::min(10, k * n); i++) {\n  //   std::cout << \"input_b[\" << i << \"] = \" << ggml_compute_bf16_to_fp32(input_b[i]) << std::endl;\n  // }\n  for (int i = 0; i < m * k; i++) {\n    input_a[i] = ggml_compute_fp32_to_bf16(dist(gen));\n  }\n  for (int i = 0; i < k * n; i++) {\n    input_b[i] = ggml_compute_fp32_to_bf16(dist(gen));\n  }\n\n  // Compute reference result with float32 (sampling for speed, only use original m rows)\n  std::cout << \"Computing reference (sampling)...\" << std::endl;\n  const int sample_m = std::min(50, m_original);  // Use original m for reference\n  const int sample_n = std::min(50, n);\n  std::vector<float> ref_result(sample_m * sample_n, 0.0f);\n\n  for (int i = 0; i < sample_m; i++) {\n    for (int j = 0; j < sample_n; 
j++) {\n      float sum = 0.0f;\n      for (int l = 0; l < k; l++) {\n        float a_val = ggml_compute_bf16_to_fp32(input_a[i * k + l]);\n        float b_val = ggml_compute_bf16_to_fp32(input_b[j * k + l]);\n        sum += a_val * b_val;\n      }\n      ref_result[i * sample_n + j] = sum;\n    }\n  }\n\n  // Quantize and compute with k-group\n  std::cout << \"Quantizing matrices...\" << std::endl;\n  ba->from_mat(m, input_a.data(), 0, 1);\n  int nth = Kernel::recommended_nth(n);\n  for (int i = 0; i <= nth; i++) {\n    bb->from_mat(input_b.data(), i, nth);\n  }\n\n  ba_int4_1->from_mat(m, input_a.data(), 0, 1);\n  nth = Kernel_int4_1::recommended_nth(n);\n  for (int i = 0; i <= nth; i++) {\n    bb_int4_1->from_mat(input_b.data(), i, nth);\n  }\n\n  ba_int4->from_mat(m, input_a.data(), 0, 1);\n  nth = Kernel_int4::recommended_nth(n);\n  for (int i = 0; i <= nth; i++) {\n    bb_int4->from_mat(input_b.data(), i, nth);\n  }\n\n  ba_k_int4_1->from_mat(m, input_a.data(), 0, 1);\n  nth = Kernel_k_int4_1::recommended_nth(n);\n  for (int i = 0; i <= nth; i++) {\n    bb_k_int4_1->from_mat(input_b.data(), i, nth);\n  }\n\n  ba_k_int4_1_low->from_mat(m, input_a.data(), 0, 1);\n  nth = Kernel_k_int4_1_low::recommended_nth(n);\n  for (int i = 0; i <= nth; i++) {\n    bb_k_int4_1_low->from_mat(input_b.data(), i, nth);\n  }\n\n  // Print some scale statistics\n  std::cout << \"\\nScale statistics:\" << std::endl;\n  float min_a_scale = 1e10f, max_a_scale = 0.0f;\n  float min_b_scale = 1e10f, max_b_scale = 0.0f;\n  float min_a_scale_int4_1 = 1e10f, max_a_scale_int4_1 = 0.0f;\n  float min_b_scale_int4_1 = 1e10f, max_b_scale_int4_1 = 0.0f;\n  float min_b_min_int4_1 = 1e10f, max_b_min_int4_1 = -1e10f;\n  float min_a_scale_int4 = 1e10f, max_a_scale_int4 = 0.0f;\n  float min_b_scale_int4 = 1e10f, max_b_scale_int4 = 0.0f;\n  float min_a_scale_k_int4_1 = 1e10f, max_a_scale_k_int4_1 = 0.0f;\n  float min_b_scale_k_int4_1 = 1e10f, max_b_scale_k_int4_1 = 0.0f;\n  float min_b_min_k_int4_1 = 
1e10f, max_b_min_k_int4_1 = -1e10f;\n  float min_a_scale_k_int4_1_low = 1e10f, max_a_scale_k_int4_1_low = 0.0f;\n  float min_b_scale_k_int4_1_low = 1e10f, max_b_scale_k_int4_1_low = 0.0f;\n  float min_b_min_k_int4_1_low = 1e10f, max_b_min_k_int4_1_low = -1e10f;\n\n  for (int i = 0; i < std::min(10, m); i++) {\n    for (int kg = 0; kg < k / k_group_size; kg++) {\n      float scale = *ba->get_scale(m, i, k, kg * k_group_size);\n      min_a_scale = std::min(min_a_scale, scale);\n      max_a_scale = std::max(max_a_scale, scale);\n    }\n  }\n\n  for (int j = 0; j < std::min(10, n); j++) {\n    for (int kg = 0; kg < k / k_group_size; kg++) {\n      float scale = *bb->get_scale(n, j, k, kg * k_group_size);\n      min_b_scale = std::min(min_b_scale, scale);\n      max_b_scale = std::max(max_b_scale, scale);\n    }\n  }\n  for (int i = 0; i < std::min(10, m); i++) {\n    float scale = *ba_int4_1->get_scale(m, i);\n    min_a_scale_int4_1 = std::min(min_a_scale_int4_1, scale);\n    max_a_scale_int4_1 = std::max(max_a_scale_int4_1, scale);\n  }\n  for (int j = 0; j < std::min(10, n); j++) {\n    float scale = *bb_int4_1->get_scale(n, j);\n    min_b_scale_int4_1 = std::min(min_b_scale_int4_1, scale);\n    max_b_scale_int4_1 = std::max(max_b_scale_int4_1, scale);\n    float b_min = *bb_int4_1->get_min(n, j);\n    min_b_min_int4_1 = std::min(min_b_min_int4_1, b_min);\n    max_b_min_int4_1 = std::max(max_b_min_int4_1, b_min);\n  }\n\n  for (int i = 0; i < std::min(10, m); i++) {\n    float scale = *ba_int4->get_scale(m, i);\n    min_a_scale_int4 = std::min(min_a_scale_int4, scale);\n    max_a_scale_int4 = std::max(max_a_scale_int4, scale);\n  }\n\n  for (int j = 0; j < std::min(10, n); j++) {\n    float scale = *bb_int4->get_scale(n, j);\n    min_b_scale_int4 = std::min(min_b_scale_int4, scale);\n    max_b_scale_int4 = std::max(max_b_scale_int4, scale);\n  }\n\n  for (int i = 0; i < std::min(10, m); i++) {\n    for (int kg = 0; kg < k / k_group_size; kg++) {\n      float scale = 
*ba_k_int4_1->get_scale(m, i, k, kg * k_group_size);\n      min_a_scale_k_int4_1 = std::min(min_a_scale_k_int4_1, scale);\n      max_a_scale_k_int4_1 = std::max(max_a_scale_k_int4_1, scale);\n    }\n  }\n\n  for (int j = 0; j < std::min(10, n); j++) {\n    for (int kg = 0; kg < k / k_group_size; kg++) {\n      float scale = *bb_k_int4_1->get_scale(n, j, k, kg * k_group_size);\n      min_b_scale_k_int4_1 = std::min(min_b_scale_k_int4_1, scale);\n      max_b_scale_k_int4_1 = std::max(max_b_scale_k_int4_1, scale);\n      float b_min = *bb_k_int4_1->get_min(n, j, k, kg * k_group_size);\n      min_b_min_k_int4_1 = std::min(min_b_min_k_int4_1, b_min);\n      max_b_min_k_int4_1 = std::max(max_b_min_k_int4_1, b_min);\n    }\n  }\n\n  for (int i = 0; i < std::min(10, m); i++) {\n    for (int kg = 0; kg < k / k_group_size; kg++) {\n      float scale = *ba_k_int4_1_low->get_scale(m, i, k, kg * k_group_size);\n      min_a_scale_k_int4_1_low = std::min(min_a_scale_k_int4_1_low, scale);\n      max_a_scale_k_int4_1_low = std::max(max_a_scale_k_int4_1_low, scale);\n    }\n  }\n\n  for (int j = 0; j < std::min(10, n); j++) {\n    for (int kg = 0; kg < k / k_group_size; kg++) {\n      float scale = *bb_k_int4_1_low->get_scale(n, j, k, kg * k_group_size);\n      min_b_scale_k_int4_1_low = std::min(min_b_scale_k_int4_1_low, scale);\n      max_b_scale_k_int4_1_low = std::max(max_b_scale_k_int4_1_low, scale);\n      float b_min = *bb_k_int4_1_low->get_min(n, j, k, kg * k_group_size);\n      min_b_min_k_int4_1_low = std::min(min_b_min_k_int4_1_low, b_min);\n      max_b_min_k_int4_1_low = std::max(max_b_min_k_int4_1_low, b_min);\n    }\n  }\n  std::cout << \"  B_int4_1 scales: min=\" << min_b_scale_int4_1 << \", max=\" << max_b_scale_int4_1 << std::endl;\n  std::cout << \"  B_int4_1 min: min=\" << min_b_min_int4_1 << \", max=\" << max_b_min_int4_1 << std::endl;\n\n  std::cout << \"  A_int4 scales: min=\" << min_a_scale_int4 << \", max=\" << max_a_scale_int4 << std::endl;\n  std::cout << 
\"  B_int4 scales: min=\" << min_b_scale_int4 << \", max=\" << max_b_scale_int4 << std::endl;\n\n  std::cout << \"  A_k_int4_1 scales: min=\" << min_a_scale_k_int4_1 << \", max=\" << max_a_scale_k_int4_1 << std::endl;\n  std::cout << \"  B_k_int4_1 scales: min=\" << min_b_scale_k_int4_1 << \", max=\" << max_b_scale_k_int4_1 << std::endl;\n  std::cout << \"  B_k_int4_1 min: min=\" << min_b_min_k_int4_1 << \", max=\" << max_b_min_k_int4_1 << std::endl;\n\n  std::cout << \"  A_k_int4_1_low scales: min=\" << min_a_scale_k_int4_1_low << \", max=\" << max_a_scale_k_int4_1_low\n            << std::endl;\n  std::cout << \"  B_k_int4_1_low scales: min=\" << min_b_scale_k_int4_1_low << \", max=\" << max_b_scale_k_int4_1_low\n            << std::endl;\n  std::cout << \"  B_k_int4_1_low min: min=\" << min_b_min_k_int4_1_low << \", max=\" << max_b_min_k_int4_1_low\n            << std::endl;\n\n  Kernel::config();\n\n  std::cout << \"\\nRunning k-group matrix multiplication...\" << std::endl;\n  auto start = std::chrono::high_resolution_clock::now();\n\n  nth = Kernel::recommended_nth(n);\n  for (int i = 0; i <= nth; i++) {\n    amx::mat_mul_kgroup(m, n, k, k_group_size, ba, bb, bc, i, nth);\n  }\n\n  nth = Kernel_int4_1::recommended_nth(n);\n  for (int i = 0; i <= nth; i++) {\n    amx::mat_mul(m, n, k, ba_int4_1, bb_int4_1, bc_int4_1, i, nth);\n  }\n\n  nth = Kernel_int4::recommended_nth(n);\n  for (int i = 0; i <= nth; i++) {\n    amx::mat_mul(m, n, k, ba_int4, bb_int4, bc_int4, i, nth);\n  }\n\n  nth = Kernel_k_int4_1::recommended_nth(n);\n  for (int i = 0; i <= nth; i++) {\n    amx::vec_mul_kgroup(m, n, k, k_group_size, ba_k_int4_1, bb_k_int4_1, bc_k_int4_1, i, nth);\n  }\n\n  nth = Kernel_k_int4_1_low::recommended_nth(n);\n  for (int i = 0; i <= nth; i++) {\n    amx::vec_mul_kgroup(m, n, k, k_group_size, ba_k_int4_1_low, bb_k_int4_1_low, bc_k_int4_1_low, i, nth);\n  }\n  auto end = std::chrono::high_resolution_clock::now();\n\n  auto duration = 
std::chrono::duration_cast<std::chrono::microseconds>(end - start);\n  std::cout << \"Computation time: \" << duration.count() / 1000.0 << \" ms\" << std::endl;\n\n  // Calculate GFLOPS\n  double ops = 2.0 * m * n * k;\n  double gflops = ops / (duration.count() * 1000.0);\n  std::cout << \"Performance: \" << gflops << \" GFLOPS\" << std::endl;\n\n  std::vector<ggml_bf16_t> output(m * n);\n  std::vector<ggml_bf16_t> output_int4_1(m * n);\n  std::vector<ggml_bf16_t> output_int4(m * n);\n  std::vector<ggml_bf16_t> output_k_int4_1(m * n);\n  std::vector<ggml_bf16_t> output_k_int4_1_low(m * n);\n  nth = Kernel::recommended_nth(n);\n  for (int i = 0; i <= nth; i++) {\n    bc->to_mat(m, output.data(), i, nth);\n  }\n  nth = Kernel_int4_1::recommended_nth(n);\n  for (int i = 0; i <= nth; i++) {\n    bc_int4_1->to_mat(m, output_int4_1.data(), i, nth);\n  }\n  nth = Kernel_int4::recommended_nth(n);\n  for (int i = 0; i <= nth; i++) {\n    bc_int4->to_mat(m, output_int4.data(), i, nth);\n  }\n  nth = Kernel_k_int4_1::recommended_nth(n);\n  for (int i = 0; i <= nth; i++) {\n    bc_k_int4_1->to_mat(m, output_k_int4_1.data(), i, nth);\n  }\n  nth = Kernel_k_int4_1_low::recommended_nth(n);\n  for (int i = 0; i <= nth; i++) {\n    bc_k_int4_1_low->to_mat(m, output_k_int4_1_low.data(), i, nth);\n  }\n  float thresh_hold = 2.0f;\n  // Compute errors for sampled elements\n  std::cout << \"\\nError analysis (sampled):\" << std::endl;\n  float max_abs_error = 0.0f;\n  float total_abs_error = 0.0f;\n  float max_rel_error = 0.0f;\n  float total_rel_error = 0.0f;\n  int count = 0;\n\n  for (int i = 0; i < sample_m; i++) {\n    for (int j = 0; j < sample_n; j++) {\n      float actual = ggml_compute_bf16_to_fp32(output[i * n + j]);\n      float ref = ref_result[i * sample_n + j];\n      float abs_error = std::abs(actual - ref);\n      float rel_error = std::abs(ref) > 1e-6 ? 
abs_error / std::abs(ref) : 0.0f;\n      if (rel_error >= thresh_hold) {\n        rel_error = thresh_hold;\n      }\n      max_abs_error = std::max(max_abs_error, abs_error);\n      total_abs_error += abs_error;\n      max_rel_error = std::max(max_rel_error, rel_error);\n      total_rel_error += rel_error;\n      count++;\n    }\n  }\n\n  float avg_abs_error = total_abs_error / count;\n  float avg_rel_error = total_rel_error / count;\n\n  std::cout << \"  Max absolute error: \" << max_abs_error << std::endl;\n  std::cout << \"  Average absolute error: \" << avg_abs_error << std::endl;\n  std::cout << \"  Max relative error: \" << (max_rel_error * 100) << \"%\" << std::endl;\n  std::cout << \"  Average relative error: \" << (avg_rel_error * 100) << \"%\" << std::endl;\n\n  float max_abs_error_int4_1 = 0.0f;\n  float total_abs_error_int4_1 = 0.0f;\n  float max_rel_error_int4_1 = 0.0f;\n  float total_rel_error_int4_1 = 0.0f;\n  int count_int4_1 = 0;\n\n  for (int i = 0; i < sample_m; i++) {\n    for (int j = 0; j < sample_n; j++) {\n      float actual = ggml_compute_bf16_to_fp32(output_int4_1[i * n + j]);\n      float ref = ref_result[i * sample_n + j];\n      float abs_error = std::abs(actual - ref);\n      float rel_error = std::abs(ref) > 1e-6 ? 
abs_error / std::abs(ref) : 0.0f;\n      if (rel_error >= thresh_hold) {\n        rel_error = thresh_hold;\n      }\n\n      max_abs_error_int4_1 = std::max(max_abs_error_int4_1, abs_error);\n      total_abs_error_int4_1 += abs_error;\n      max_rel_error_int4_1 = std::max(max_rel_error_int4_1, rel_error);\n      total_rel_error_int4_1 += rel_error;\n      count_int4_1++;\n    }\n  }\n\n  float avg_abs_error_int4_1 = total_abs_error_int4_1 / count_int4_1;\n  float avg_rel_error_int4_1 = total_rel_error_int4_1 / count_int4_1;\n  std::cout << \"\\nINT4_1 Error analysis (sampled):\" << std::endl;\n  std::cout << \"  Max absolute error: \" << max_abs_error_int4_1 << std::endl;\n  std::cout << \"  Average absolute error: \" << avg_abs_error_int4_1 << std::endl;\n  std::cout << \"  Max relative error: \" << (max_rel_error_int4_1 * 100) << \"%\" << std::endl;\n  std::cout << \"  Average relative error: \" << (avg_rel_error_int4_1 * 100) << \"%\" << std::endl;\n\n  float max_abs_error_int4 = 0.0f;\n  float total_abs_error_int4 = 0.0f;\n  float max_rel_error_int4 = 0.0f;\n  float total_rel_error_int4 = 0.0f;\n  int count_int4 = 0;\n\n  for (int i = 0; i < sample_m; i++) {\n    for (int j = 0; j < sample_n; j++) {\n      float actual = ggml_compute_bf16_to_fp32(output_int4[i * n + j]);\n      float ref = ref_result[i * sample_n + j];\n      float abs_error = std::abs(actual - ref);\n      float rel_error = std::abs(ref) > 1e-6 ? 
abs_error / std::abs(ref) : 0.0f;\n      if (rel_error >= thresh_hold) {\n        rel_error = thresh_hold;\n      }\n\n      max_abs_error_int4 = std::max(max_abs_error_int4, abs_error);\n      total_abs_error_int4 += abs_error;\n      max_rel_error_int4 = std::max(max_rel_error_int4, rel_error);\n      total_rel_error_int4 += rel_error;\n      count_int4++;\n    }\n  }\n\n  float avg_abs_error_int4 = total_abs_error_int4 / count_int4;\n  float avg_rel_error_int4 = total_rel_error_int4 / count_int4;\n  std::cout << \"\\nINT4 Error analysis (sampled):\" << std::endl;\n  std::cout << \"  Max absolute error: \" << max_abs_error_int4 << std::endl;\n  std::cout << \"  Average absolute error: \" << avg_abs_error_int4 << std::endl;\n  std::cout << \"  Max relative error: \" << (max_rel_error_int4 * 100) << \"%\" << std::endl;\n  std::cout << \"  Average relative error: \" << (avg_rel_error_int4 * 100) << \"%\" << std::endl;\n\n  float max_abs_error_k_int4_1 = 0.0f;\n  float total_abs_error_k_int4_1 = 0.0f;\n  float max_rel_error_k_int4_1 = 0.0f;\n  float total_rel_error_k_int4_1 = 0.0f;\n  int count_k_int4_1 = 0;\n\n  for (int i = 0; i < sample_m; i++) {\n    for (int j = 0; j < sample_n; j++) {\n      float actual = ggml_compute_bf16_to_fp32(output_k_int4_1[i * n + j]);\n      float ref = ref_result[i * sample_n + j];\n      float abs_error = std::abs(actual - ref);\n      float rel_error = std::abs(ref) > 1e-6 ? 
abs_error / std::abs(ref) : 0.0f;\n      if (rel_error >= thresh_hold) {\n        rel_error = thresh_hold;\n      }\n\n      max_abs_error_k_int4_1 = std::max(max_abs_error_k_int4_1, abs_error);\n      total_abs_error_k_int4_1 += abs_error;\n      max_rel_error_k_int4_1 = std::max(max_rel_error_k_int4_1, rel_error);\n      total_rel_error_k_int4_1 += rel_error;\n      count_k_int4_1++;\n    }\n  }\n  float avg_abs_error_k_int4_1 = total_abs_error_k_int4_1 / count_k_int4_1;\n  float avg_rel_error_k_int4_1 = total_rel_error_k_int4_1 / count_k_int4_1;\n  std::cout << \"\\nINT4_1_k Error analysis (sampled):\" << std::endl;\n  std::cout << \"  Max absolute error: \" << max_abs_error_k_int4_1 << std::endl;\n  std::cout << \"  Average absolute error: \" << avg_abs_error_k_int4_1 << std::endl;\n  std::cout << \"  Max relative error: \" << (max_rel_error_k_int4_1 * 100) << \"%\" << std::endl;\n  std::cout << \"  Average relative error: \" << (avg_rel_error_k_int4_1 * 100) << \"%\" << std::endl;\n\n  float max_abs_error_k_int4_1_low = 0.0f;\n  float total_abs_error_k_int4_1_low = 0.0f;\n  float max_rel_error_k_int4_1_low = 0.0f;\n  float total_rel_error_k_int4_1_low = 0.0f;\n  int count_k_int4_1_low = 0;\n\n  for (int i = 0; i < sample_m; i++) {\n    for (int j = 0; j < sample_n; j++) {\n      float actual = ggml_compute_bf16_to_fp32(output_k_int4_1_low[i * n + j]);\n      float ref = ref_result[i * sample_n + j];\n      float abs_error = std::abs(actual - ref);\n      float rel_error = std::abs(ref) > 1e-6 ? 
abs_error / std::abs(ref) : 0.0f;\n      if (rel_error >= thresh_hold) {\n        rel_error = thresh_hold;\n      }\n\n      max_abs_error_k_int4_1_low = std::max(max_abs_error_k_int4_1_low, abs_error);\n      total_abs_error_k_int4_1_low += abs_error;\n      max_rel_error_k_int4_1_low = std::max(max_rel_error_k_int4_1_low, rel_error);\n      total_rel_error_k_int4_1_low += rel_error;\n      count_k_int4_1_low++;\n    }\n  }\n\n  float avg_abs_error_k_int4_1_low = total_abs_error_k_int4_1_low / count_k_int4_1_low;\n  float avg_rel_error_k_int4_1_low = total_rel_error_k_int4_1_low / count_k_int4_1_low;\n  std::cout << \"\\nINT4_1_k_low Error analysis (sampled):\" << std::endl;\n  std::cout << \"  Max absolute error: \" << max_abs_error_k_int4_1_low << std::endl;\n  std::cout << \"  Average absolute error: \" << avg_abs_error_k_int4_1_low << std::endl;\n  std::cout << \"  Max relative error: \" << (max_rel_error_k_int4_1_low * 100) << \"%\" << std::endl;\n  std::cout << \"  Average relative error: \" << (avg_rel_error_k_int4_1_low * 100) << \"%\" << std::endl;\n\n  // Print sample comparison\n  std::cout << \"\\nSample comparison (first 10x10):\" << std::endl;\n  std::cout << \"Format: actual (reference) [error%]\" << std::endl;\n  for (int i = 10; i < std::min(20, sample_m); i++) {\n    for (int j = 10; j < std::min(20, sample_n); j++) {\n      float actual = ggml_compute_bf16_to_fp32(output[i * n + j]);\n      float ref = ref_result[i * sample_n + j];\n      float error_pct = std::abs(ref) > 1e-6 ? 
(actual - ref) / ref * 100 : 0.0f;\n      printf(\"%7.4f (%7.4f) [%+6.1f%%]  \", actual, ref, error_pct);\n    }\n    std::cout << std::endl;\n  }\n  std::cout << \"\\nint4_1 Sample comparison (first 10x10):\" << std::endl;\n  std::cout << \"Format: actual (reference) [error%]\" << std::endl;\n  for (int i = 10; i < std::min(20, sample_m); i++) {\n    for (int j = 10; j < std::min(20, sample_n); j++) {\n      float actual = ggml_compute_bf16_to_fp32(output_int4_1[i * n + j]);\n      float ref = ref_result[i * sample_n + j];\n      float error_pct = std::abs(ref) > 1e-6 ? (actual - ref) / ref * 100 : 0.0f;\n      printf(\"%7.4f (%7.4f) [%+6.1f%%]  \", actual, ref, error_pct);\n    }\n    std::cout << std::endl;\n  }\n  std::cout << \"\\nint4 Sample comparison (first 10x10):\" << std::endl;\n  std::cout << \"Format: actual (reference) [error%]\" << std::endl;\n  for (int i = 10; i < std::min(20, sample_m); i++) {\n    for (int j = 10; j < std::min(20, sample_n); j++) {\n      float actual = ggml_compute_bf16_to_fp32(output_int4[i * n + j]);\n      float ref = ref_result[i * sample_n + j];\n      float error_pct = std::abs(ref) > 1e-6 ? (actual - ref) / ref * 100 : 0.0f;\n      printf(\"%7.4f (%7.4f) [%+6.1f%%]  \", actual, ref, error_pct);\n    }\n    std::cout << std::endl;\n  }\n\n  std::cout << \"\\nint4_1_k Sample comparison (first 10x10):\" << std::endl;\n  std::cout << \"Format: actual (reference) [error%]\" << std::endl;\n  for (int i = 10; i < std::min(20, sample_m); i++) {\n    for (int j = 10; j < std::min(20, sample_n); j++) {\n      float actual = ggml_compute_bf16_to_fp32(output_k_int4_1[i * n + j]);\n      float ref = ref_result[i * sample_n + j];\n      float error_pct = std::abs(ref) > 1e-6 ? 
(actual - ref) / ref * 100 : 0.0f;\n      printf(\"%7.4f (%7.4f) [%+6.1f%%]  \", actual, ref, error_pct);\n    }\n    std::cout << std::endl;\n  }\n\n  std::cout << \"\\nint4_1_k_low Sample comparison (first 10x10):\" << std::endl;\n  std::cout << \"Format: actual (reference) [error%]\" << std::endl;\n  for (int i = 10; i < std::min(20, sample_m); i++) {\n    for (int j = 10; j < std::min(20, sample_n); j++) {\n      float actual = ggml_compute_bf16_to_fp32(output_k_int4_1_low[i * n + j]);\n      float ref = ref_result[i * sample_n + j];\n      float error_pct = std::abs(ref) > 1e-6 ? (actual - ref) / ref * 100 : 0.0f;\n      printf(\"%7.4f (%7.4f) [%+6.1f%%]  \", actual, ref, error_pct);\n    }\n    std::cout << std::endl;\n  }\n\n  std::cout << \"\\nint4 Sample comparison:\" << std::endl;\n  std::cout << \"Format: actual (reference) [error%]\" << std::endl;\n  for (int i = 0; i < 1; i++) {\n    for (int j = 0; j < n; j++) {\n      float actual = ggml_compute_bf16_to_fp32(output_int4[i * n + j]);\n      float ref = ref_result[i * sample_n + j];\n      float error_pct = std::abs(ref) > 1e-6 ? (actual - ref) / ref * 100 : 0.0f;\n      printf(\"j:%d, %7.4f (%7.4f) [%+6.1f%%]  \", j, actual, ref, error_pct);\n    }\n    std::cout << std::endl;\n  }\n\n  std::cout << \"\\nSample comparison:\" << std::endl;\n  std::cout << \"Format: actual (reference) [error%]\" << std::endl;\n  for (int i = 0; i < 1; i++) {\n    for (int j = 0; j < n; j++) {\n      float actual = ggml_compute_bf16_to_fp32(output[i * n + j]);\n      float ref = ref_result[i * sample_n + j];\n      float error_pct = std::abs(ref) > 1e-6 ? 
(actual - ref) / ref * 100 : 0.0f;\n      printf(\"j:%d, %7.4f (%7.4f) [%+6.1f%%]  \", j, actual, ref, error_pct);\n    }\n    std::cout << std::endl;\n  }\n\n  std::cout << \"\\nint4_1_k Sample comparison:\" << std::endl;\n  std::cout << \"Format: actual (reference) [error%]\" << std::endl;\n  for (int i = 0; i < 1; i++) {\n    for (int j = 0; j < n; j++) {\n      float actual = ggml_compute_bf16_to_fp32(output_k_int4_1[i * n + j]);\n      float ref = ref_result[i * sample_n + j];\n      float error_pct = std::abs(ref) > 1e-6 ? (actual - ref) / ref * 100 : 0.0f;\n      printf(\"j:%d, %7.4f (%7.4f) [%+6.1f%%]  \", j, actual, ref, error_pct);\n    }\n    std::cout << std::endl;\n  }\n\n  std::cout << \"\\nint4_1 Sample comparison:\" << std::endl;\n  std::cout << \"Format: actual (reference) [error%]\" << std::endl;\n  for (int i = 0; i < 1; i++) {\n    for (int j = 0; j < n; j++) {\n      float actual = ggml_compute_bf16_to_fp32(output_int4_1[i * n + j]);\n      float ref = ref_result[i * sample_n + j];\n      float error_pct = std::abs(ref) > 1e-6 ? (actual - ref) / ref * 100 : 0.0f;\n      printf(\"j:%d, %7.4f (%7.4f) [%+6.1f%%]  \", j, actual, ref, error_pct);\n    }\n    std::cout << std::endl;\n  }\n\n  std::cout << \"\\nint4_1_k_low Sample comparison:\" << std::endl;\n  std::cout << \"Format: actual (reference) [error%]\" << std::endl;\n  for (int i = 0; i < 1; i++) {\n    for (int j = 0; j < n; j++) {\n      float actual = ggml_compute_bf16_to_fp32(output_k_int4_1_low[i * n + j]);\n      float ref = ref_result[i * sample_n + j];\n      float error_pct = std::abs(ref) > 1e-6 ? 
(actual - ref) / ref * 100 : 0.0f;\n      printf(\"j:%d, %7.4f (%7.4f) [%+6.1f%%]  \", j, actual, ref, error_pct);\n    }\n    std::cout << std::endl;\n  }\n\n  // Check if accuracy is acceptable for INT4\n  if (avg_rel_error < 0.2f) {\n    std::cout << \"\\n✓ Excellent accuracy (<20% average error)\" << std::endl;\n  } else if (avg_rel_error < 0.3f) {\n    std::cout << \"\\n✓ Acceptable accuracy (20-30% average error)\" << std::endl;\n  } else if (avg_rel_error < 0.4f) {\n    std::cout << \"\\n⚠ Marginal accuracy (30-40% average error)\" << std::endl;\n  } else {\n    std::cout << \"\\n✗ Poor accuracy (>40% average error)\" << std::endl;\n  }\n\n  if (avg_rel_error_int4_1 < 0.2f) {\n    std::cout << \"\\n✓ Excellent accuracy for INT4 quantization (<20% average error)\" << std::endl;\n  } else if (avg_rel_error_int4_1 < 0.3f) {\n    std::cout << \"\\n✓ Acceptable accuracy for INT4 quantization (20-30% average error)\" << std::endl;\n  } else if (avg_rel_error_int4_1 < 0.4f) {\n    std::cout << \"\\n⚠ Marginal accuracy for INT4 quantization (30-40% average error)\" << std::endl;\n  } else {\n    std::cout << \"\\n✗ Poor accuracy for INT4 quantization (>40% average error)\" << std::endl;\n  }\n\n  if (avg_rel_error_int4 < 0.2f) {\n    std::cout << \"\\n✓ Excellent accuracy for INT4 quantization (<20% average error)\" << std::endl;\n  } else if (avg_rel_error_int4 < 0.3f) {\n    std::cout << \"\\n✓ Acceptable accuracy for INT4 quantization (20-30% average error)\" << std::endl;\n  } else if (avg_rel_error_int4 < 0.4f) {\n    std::cout << \"\\n⚠ Marginal accuracy for INT4 quantization (30-40% average error)\" << std::endl;\n  } else {\n    std::cout << \"\\n✗ Poor accuracy for INT4 quantization (>40% average error)\" << std::endl;\n  }\n\n  if (avg_rel_error_k_int4_1 < 0.2f) {\n    std::cout << \"\\n✓ Excellent accuracy for INT4 k-group quantization (<20% average error)\" << std::endl;\n  } else if (avg_rel_error_k_int4_1 < 0.3f) {\n    std::cout << \"\\n✓ Acceptable 
accuracy for INT4 k-group quantization (20-30% average error)\" << std::endl;\n  } else if (avg_rel_error_k_int4_1 < 0.4f) {\n    std::cout << \"\\n⚠ Marginal accuracy for INT4 k-group quantization (30-40% average error)\" << std::endl;\n  } else {\n    std::cout << \"\\n✗ Poor accuracy for INT4 k-group quantization (>40% average error)\" << std::endl;\n  }\n\n  if (avg_rel_error_k_int4_1_low < 0.2f) {\n    std::cout << \"\\n✓ Excellent accuracy for INT4 k-group low quantization (<20% average error)\" << std::endl;\n  } else if (avg_rel_error_k_int4_1_low < 0.3f) {\n    std::cout << \"\\n✓ Acceptable accuracy for INT4 k-group low quantization (20-30% average error)\" << std::endl;\n  } else if (avg_rel_error_k_int4_1_low < 0.4f) {\n    std::cout << \"\\n⚠ Marginal accuracy for INT4 k-group low quantization (30-40% average error)\" << std::endl;\n  } else {\n    std::cout << \"\\n✗ Poor accuracy for INT4 k-group low quantization (>40% average error)\" << std::endl;\n  }\n\n  free(buffer_a);\n  free(buffer_b);\n  free(buffer_c);\n}\n\nint main() {\n  test_specific_dimensions();\n  return 0;\n}"
  },
  {
    "path": "kt-kernel/operators/amx/test/thread_test.sh",
    "content": "#!/bin/bash\n\n# 进入脚本所在的目录\ncd \"$(dirname \"$0\")\" || { echo \"Failed to enter the script's directory\"; exit 1; }\n\n# 进入 ../build 目录\ncd ../build || { echo \"Failed to enter ../build directory\"; exit 1; }\n\n# 设置线程数列表\nthreads=(1 2 4 8 16 24 36 48 72)\n\n# 遍历每个线程数并运行命令\nfor t in \"${threads[@]}\"; do\n    echo \"Running with OMP_NUM_THREADS=$t\"\n    OMP_NUM_THREADS=$t numactl -N 0 ./la/amx-test\n    sleep 1s\ndone\n"
  },
  {
    "path": "kt-kernel/operators/amx/test/timer.hh",
    "content": "#ifndef TIMER_HH\n#define TIMER_HH\n\n#include <cassert>\n#include <chrono>\n#include <iomanip>\n#include <iostream>\n#include <map>\n#include <sstream>\n#include <string>\n#include <array>\n\ninline std::string doubleToStringR2(double value) {\n  std::stringstream stream;\n  stream << std::fixed << std::setprecision(2) << value;\n  return stream.str();\n}\n\ninline std::array<std::string, 7> units = {\"\", \"K\", \"M\", \"G\", \"T\", \"P\", \"E\"};\n\ninline std::string readable_number(size_t size) {\n  size_t unit_index = 0;\n  double readable_size = size;\n  while (readable_size >= 1000 && unit_index < units.size() - 1) {\n    readable_size /= 1000;\n    unit_index++;\n  }\n  std::ostringstream ss;\n  ss << std::fixed << std::setprecision(2) << readable_size;\n  std::string str = ss.str();\n  return str + \"\" + units[unit_index];\n}\n\nclass Timer {\npublic:\n  std::string name;\n  bool tmp_timer = false;\n\n  Timer() {}\n  Timer(std::string name) : name(name), tmp_timer(true) { start(); }\n  ~Timer() {\n    if (tmp_timer) {\n      std::cout << name << \" \" << elapsedTime() << std::endl;\n    }\n  }\n\n  void start() {\n    m_startTime = std::chrono::high_resolution_clock::now();\n    assert(m_isRunning == false);\n    m_isRunning = true;\n  }\n\n  void stop() {\n    m_endTime = std::chrono::high_resolution_clock::now();\n    assert(m_isRunning == true);\n    m_isRunning = false;\n    m_runningNs += elapsedNs();\n  }\n\n  double elapsedNs() {\n    std::chrono::time_point<std::chrono::high_resolution_clock> endTime;\n\n    if (m_isRunning) {\n      endTime = std::chrono::high_resolution_clock::now();\n    } else {\n      endTime = m_endTime;\n    }\n\n    return std::chrono::duration_cast<std::chrono::nanoseconds>(endTime -\n                                                                m_startTime)\n        .count();\n  }\n\n  void printElapsedMilliseconds() {\n    std::cout << elapsedNs() / 1e6 << \" ms\" << std::endl;\n  }\n\n  static 
std::string ns_to_string(double duration) {\n    auto nano_sec = duration;\n    if (nano_sec >= 1000) {\n      auto micro_sec = nano_sec / 1000.0;\n      if (micro_sec >= 1000) {\n        auto milli_sec = micro_sec / 1000.0;\n        if (milli_sec >= 1000) {\n          auto seconds = milli_sec / 1000.0;\n\n          if (seconds >= 60.0) {\n            auto minutes = seconds / 60.0;\n\n            if (minutes >= 60.0) {\n              auto hours = minutes / 60.0;\n              return doubleToStringR2(hours) + \" h\";\n            } else {\n              return doubleToStringR2(minutes) + \" min\";\n            }\n          } else {\n            return doubleToStringR2(seconds) + \" sec\";\n          }\n        } else {\n          return doubleToStringR2(milli_sec) + \" ms\";\n        }\n      } else {\n        return doubleToStringR2(micro_sec) + \" us\";\n      }\n    } else {\n      return doubleToStringR2(nano_sec) + \" ns\";\n    }\n  }\n\n  double runningTimeNs() { return m_runningNs; }\n\n  std::string runningTime() {\n    auto duration = m_runningNs;\n    return ns_to_string(duration);\n  }\n\n  std::string elapsedTime() { return ns_to_string(elapsedNs()); }\n  double elapsedMs() { return elapsedNs() / 1e6; }\n  std::string report_throughput(size_t op_cnt) {\n    double ops = op_cnt / elapsedMs() * 1000;\n    return readable_number(ops) + \"op/s\";\n  }\n\n  void merge(Timer &other) {\n    assert(m_isRunning == false);\n    assert(other.m_isRunning == false);\n    m_runningNs += other.runningTimeNs();\n  }\n\nprivate:\n  std::chrono::time_point<std::chrono::high_resolution_clock> m_startTime;\n  std::chrono::time_point<std::chrono::high_resolution_clock> m_endTime;\n  bool m_isRunning = false;\n  double m_runningNs = 0.0;\n};\n\nclass Counter {\npublic:\n  Counter() {}\n\n  std::map<std::string, size_t> counters;\n\n  void inc(const char *name, size_t num) { counters[name] += num; };\n  void print() {\n    for (auto &p : counters) {\n      std::cout << 
p.first << \" : \" << p.second << std::endl;\n    }\n  };\n};\n\n#endif // TIMER_HH\n"
  },
  {
    "path": "kt-kernel/operators/amx/test/verify-kgroup.cpp",
    "content": "#include <cmath>\n#include <iostream>\n#include <memory>\n#include <random>\n\n#include \"../la/amx.hpp\"\n\nvoid verify_kgroup_accuracy() {\n  std::cout << \"=== Verifying K-Group Accuracy ===\" << std::endl;\n\n  const int m = 32;\n  const int n = 32;\n  const int k = 1024;\n  const int k_group_size = 256;\n\n  using Kernel = amx::GemmKernel224Int4KGroup;\n  using BufferA = Kernel::BufferA;\n  using BufferB = Kernel::BufferB;\n  using BufferC = Kernel::BufferC;\n\n  void* buffer_a = std::aligned_alloc(64, BufferA::required_size(m, k, k_group_size));\n  void* buffer_b = std::aligned_alloc(64, BufferB::required_size(n, k, k_group_size));\n  void* buffer_c = std::aligned_alloc(64, BufferC::required_size(m, n));\n\n  auto ba = std::make_shared<BufferA>(m, k, k_group_size, buffer_a);\n  auto bb = std::make_shared<BufferB>(n, k, k_group_size, buffer_b);\n  auto bc = std::make_shared<BufferC>(m, n, buffer_c);\n\n  // Create input matrices with values in the quantization sweet spot\n  std::vector<ggml_bf16_t> input_a(m * k);\n  std::vector<ggml_bf16_t> input_b(k * n);\n\n  std::mt19937 gen(12345);\n  std::uniform_real_distribution<float> dist(-0.5f, 0.5f);\n\n  for (int i = 0; i < m * k; i++) {\n    input_a[i] = ggml_compute_fp32_to_bf16(dist(gen));\n  }\n  for (int i = 0; i < k * n; i++) {\n    input_b[i] = ggml_compute_fp32_to_bf16(dist(gen));\n  }\n\n  // Compute reference result with float32\n  std::vector<float> ref_result(m * n, 0.0f);\n  for (int i = 0; i < m; i++) {\n    for (int j = 0; j < n; j++) {\n      float sum = 0.0f;\n      for (int l = 0; l < k; l++) {\n        float a_val = ggml_compute_bf16_to_fp32(input_a[i * k + l]);\n        float b_val = ggml_compute_bf16_to_fp32(input_b[l * n + j]);\n        sum += a_val * b_val;\n      }\n      ref_result[i * n + j] = sum;\n    }\n  }\n\n  // Quantize and compute with k-group\n  ba->from_mat(m, input_a.data(), 0, 1);\n  bb->from_mat(input_b.data(), 0, 1);\n\n  Kernel::config();\n  
amx::mat_mul_kgroup(m, n, k, k_group_size, ba, bb, bc, 0, 1);\n\n  std::vector<ggml_bf16_t> output(m * n);\n  bc->to_mat(m, output.data(), 0, 1);\n\n  // Compute errors\n  float max_abs_error = 0.0f;\n  float total_abs_error = 0.0f;\n  float max_ref_value = 0.0f;\n\n  for (int i = 0; i < m * n; i++) {\n    float actual = ggml_compute_bf16_to_fp32(output[i]);\n    float ref = ref_result[i];\n    float error = std::abs(actual - ref);\n\n    max_abs_error = std::max(max_abs_error, error);\n    total_abs_error += error;\n    max_ref_value = std::max(max_ref_value, std::abs(ref));\n  }\n\n  float avg_abs_error = total_abs_error / (m * n);\n  float relative_error = max_abs_error / (max_ref_value + 1e-8f);\n\n  std::cout << \"Matrix dimensions: \" << m << \"x\" << n << \"x\" << k << std::endl;\n  std::cout << \"K-group size: \" << k_group_size << std::endl;\n  std::cout << \"Max absolute error: \" << max_abs_error << std::endl;\n  std::cout << \"Average absolute error: \" << avg_abs_error << std::endl;\n  std::cout << \"Max reference value: \" << max_ref_value << std::endl;\n  std::cout << \"Relative error: \" << (relative_error * 100) << \"%\" << std::endl;\n\n  // Check if accuracy is acceptable for INT4\n  // INT4 quantization typically has 5-10% error\n  if (relative_error < 0.15f) {\n    std::cout << \"✓ Accuracy is acceptable for INT4 quantization\" << std::endl;\n  } else {\n    std::cout << \"✗ Accuracy needs improvement\" << std::endl;\n  }\n\n  free(buffer_a);\n  free(buffer_b);\n  free(buffer_c);\n}\n\nint main() {\n  verify_kgroup_accuracy();\n  return 0;\n}"
  },
  {
    "path": "kt-kernel/operators/common.hpp",
    "content": "#ifndef CPUINFER_OPERATOR_COMMON_HPP\n#define CPUINFER_OPERATOR_COMMON_HPP\n\n#include <map>\n\n#include \"../cpu_backend/worker_pool.h\"\n#include \"ggml.h\"\n\n#if defined(__aarch64__) && defined(CPU_USE_KML)\n#include <arm_sve.h>\n#endif\n\n#include <chrono>\n#include <cmath>\n#include <cstdio>\n#include <cstring>\n#include <stdexcept>\n#include <type_traits>\n\n// #define FORWARD_TIME_PROFILE\n// #define FORWARD_TIME_REPORT\n\n#define ASSERT_RELEASE(x, text)                                                            \\\n  do {                                                                                     \\\n    if (!(x)) {                                                                            \\\n      fprintf(stderr, \"Assertion failed: %s, file %s, line %d\\n\", #x, __FILE__, __LINE__); \\\n      fprintf(stderr, \"Error message: %s\\n\", (text));                                      \\\n      throw std::runtime_error((text));                                                    \\\n    }                                                                                      \\\n  } while (0)\n\n#define PUSH_MEM_REQ(ptr, size) mem_requests.append_pointer(&(ptr), (size))\n\n#define PROFILE_RECORD_TIME_STAMP(name)                                                             \\\n  do {                                                                                              \\\n    auto end_time = std::chrono::high_resolution_clock::now();                                      \\\n    auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end_time - last).count(); \\\n    time_map[(name)] = duration;                                                                    \\\n    last = end_time;                                                                                \\\n  } while (0)\n\n#define DO_TPS_LOAD_WEIGHTS(pool)                                                         \\\n  
(pool)->dispense_backend()->do_numa_job([this, pool, config](int numa_id) {             \\\n    this->tps[numa_id]->config_.physical_to_logical_map = config.physical_to_logical_map; \\\n    this->tps[numa_id]->load_weights();                                                   \\\n  })\n\n#define expert_map(m, x) (m != nullptr ? m[(x)] : (x))\n\ntemplate <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>\ninline T div_up(T x, T y) {\n  return (x + y - 1) / y;\n}\n\ntemplate <typename T>\nT* offset_pointer(T* ptr, size_t byte_offset) {\n  return reinterpret_cast<T*>(reinterpret_cast<char*>(ptr) + byte_offset);\n}\n\ntemplate <typename T>\nsize_t pointer_offset(T* ptr, T* b) {\n  return reinterpret_cast<size_t>(b) - reinterpret_cast<size_t>(ptr);\n}\n\ntemplate <typename T>\nconst T* offset_pointer(const T* ptr, size_t byte_offset) {\n  return reinterpret_cast<const T*>(reinterpret_cast<const char*>(ptr) + byte_offset);\n}\n\ntemplate <typename T>\nT* offset_pointer_row_major(T* t, int row, int col, size_t ld) {\n  return offset_pointer(t, row * ld) + col;\n}\n\ntemplate <typename T>\nT* offset_pointer_col_major(T* t, int row, int col, size_t ld) {\n  return offset_pointer(t, col * ld) + row;\n}\n\nclass TimePerf {\n protected:\n  std::string time_perf_name;\n  std::map<std::string, long> time_map;\n  std::chrono::time_point<std::chrono::high_resolution_clock> last;\n  std::chrono::time_point<std::chrono::high_resolution_clock> start_time;\n\n  void forward_perf_start() {\n    start_time = std::chrono::high_resolution_clock::now();\n    last = start_time;\n  }\n\n  void perf_report() {\n    auto end_time = std::chrono::high_resolution_clock::now();\n    auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time);\n    std::string output = time_perf_name + \", forward time: \" + std::to_string(duration.count()) + \" us\";\n    // for (auto [name, t] : time_map) {\n    //   double p = 100.0 * t / 
duration.count();\n    //   // if (p < 1.0) {\n    //   //   continue; // Skip if the percentage is less than 1%\n    //   // }\n    //   output += \", \" + name + \": \" + std::to_string(t) + \" us(\" + std::to_string(size_t(round(p))) + \"%)\";\n    // }\n    // Iterate over the recorded stages in reverse key order\n    for (auto it = time_map.rbegin(); it != time_map.rend(); ++it) {\n      const std::string& name = it->first;\n      long t = it->second;\n      double p = 100.0 * t / duration.count();\n      // if (p < 1.0) {\n      //   continue; // Skip if the percentage is less than 1%\n      // }\n      output += \", \" + name + \": \" + std::to_string(t) + \" us(\" + std::to_string(size_t(round(p))) + \"%)\";\n    }\n    printf(\"%s\\n\", output.c_str());\n  }\n};\n\nstruct TaskCounter {\n  std::vector<size_t> fold = {}, card = {};\n\n  TaskCounter(std::initializer_list<size_t> i) {\n    card.push_back(1);\n    for (auto j : i) {\n      push_back(j);\n    }\n  }\n\n  void push_back(size_t i) {\n    fold.push_back(i);\n    for (auto& c : card) {\n      c *= i;\n    }\n    card.push_back(1);\n  }\n  void push_back(std::vector<size_t> i) {\n    for (auto j : i) {\n      push_back(j);\n    }\n  }\n  size_t count() { return card[0]; }\n  size_t at(size_t id, size_t which) { return id % card.at(which) / card.at(which + 1); }\n};\n\nstruct GeneralConfig {\n  size_t vocab_size;\n  size_t hidden_size;\n\n  size_t num_experts_per_tok;\n  size_t n_routed_experts;\n  size_t n_shared_experts;\n  size_t max_qlen = 4096;\n\n  void* lm_heads_ptr;\n  ggml_type lm_heads_type;\n  void* norm_weights_ptr;\n  ggml_type norm_weights_type;\n  void* token_embd_ptr;\n  ggml_type token_embd_type;\n  WorkerPool* pool = nullptr;\n  GeneralConfig() {}\n};\n\nstruct GeneralMLAConfig {\n  size_t hidden_size;\n  size_t q_lora_rank;\n  size_t num_heads;\n  size_t nope_size;\n  size_t rope_size;\n  size_t kv_lora_rank;\n\n  int layer_idx = 0;\n  WorkerPool* pool = nullptr;\n  size_t token_count_in_page = 256;  // token count in a page\n  
size_t max_qlen = 1024;\n  size_t max_kvlen = 4096;\n\n  // rope\n  size_t max_position_embeddings;\n  double rope_scaling_factor = 1.0;\n  double rope_theta = 10000.0;\n  double rope_scaling_beta_fast;\n  double rope_scaling_beta_slow;\n  double rope_scaling_mscale;\n  double rope_scaling_mscale_all_dim;\n  double rope_scaling_original_max_position_embeddings;\n\n  void* q_a_proj;\n  void* q_a_norm = nullptr;\n  void* q_b_proj;\n  void* kv_a_proj_with_mqa;\n  void* kv_a_norm = nullptr;\n  void* kv_b_proj;\n  void* o_proj;\n\n  // for llamafile\n  ggml_type q_a_proj_type;\n  ggml_type q_a_norm_type;\n  ggml_type q_b_proj_type;\n  ggml_type kv_a_proj_with_mqa_type;\n  ggml_type kv_a_norm_type;\n  ggml_type kv_b_proj_type;\n  ggml_type w_o_type;\n\n  ggml_type input_type = GGML_TYPE_F32;\n  ggml_type output_type = GGML_TYPE_F32;\n\n  size_t m_block = 4;\n  size_t n_block = 4;\n  // for kvcache\n  size_t page_count = 200;  // page count for kv cache\n\n  GeneralMLAConfig() {}\n  GeneralMLAConfig(size_t hidden_size, size_t q_lora_rank, size_t kv_lora_rank, size_t num_heads, size_t nope_size,\n                   size_t rope_size)\n      : hidden_size(hidden_size),\n        q_lora_rank(q_lora_rank),\n        kv_lora_rank(kv_lora_rank),\n        num_heads(num_heads),\n        nope_size(nope_size),\n        rope_size(rope_size) {}\n};\n\nstruct QuantConfig {\n  std::string quant_method = \"\";\n  int bits = 0;\n  int group_size = 0;\n  bool zero_point = false;\n  bool per_channel = false;  // Per-channel quantization (GLM-4.7-FP8 style)\n};\n\nstruct GeneralMOEConfig {\n  // Basic Config\n  int expert_num;\n  int num_experts_per_tok;\n  int hidden_size;\n  int intermediate_size;\n\n  int layer_idx = 0;\n  WorkerPool* pool = nullptr;\n\n  // SGLang offload\n  int num_gpu_experts = 0;              // Computed from gpu_experts_mask\n  uint8_t* gpu_experts_mask = nullptr;  // Bool mask: true = expert on GPU\n  void* physical_to_logical_map = nullptr;\n\n  // Compute 
num_gpu_experts from gpu_experts_mask\n  void compute_num_gpu_experts() {\n    num_gpu_experts = 0;\n    if (gpu_experts_mask) {\n      for (int i = 0; i < expert_num; i++) {\n        if (gpu_experts_mask[i]) num_gpu_experts++;\n      }\n    }\n  }\n\n  // Check if expert should be skipped (invalid, out of range, or on GPU)\n  inline bool should_skip_expert(int64_t expert_id) const {\n    return expert_id < 0 || expert_id >= expert_num || (gpu_experts_mask && gpu_experts_mask[expert_id]);\n  }\n\n  void* gate_proj;\n  void* up_proj;\n  void* down_proj;\n\n  void* gate_scale;\n  void* up_scale;\n  void* down_scale;\n\n  void* gate_zero;\n  void* up_zero;\n  void* down_zero;\n\n  QuantConfig quant_config;\n\n  // for amx\n  int max_len = 0;\n  std::vector<std::vector<void*>> gate_projs;\n  std::vector<std::vector<void*>> up_projs;\n  std::vector<std::vector<void*>> down_projs;\n  std::vector<std::vector<void*>> gate_scales;\n  std::vector<std::vector<void*>> up_scales;\n  std::vector<std::vector<void*>> down_scales;\n  std::vector<std::vector<void*>> gate_zeros;\n  std::vector<std::vector<void*>> up_zeros;\n  std::vector<std::vector<void*>> down_zeros;\n\n  std::string path;\n  bool save = false;\n  bool load = false;\n\n  // for llamafile\n  int m_block = 4;\n  int group_min_len = 0;\n  int group_max_len = 0;\n  int gate_type;\n  int up_type;\n  int down_type;\n  int hidden_type;\n\n  GeneralMOEConfig() {}\n\n  GeneralMOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size)\n      : expert_num(expert_num),\n        num_experts_per_tok(routed_expert_num),\n        hidden_size(hidden_size),\n        intermediate_size(intermediate_size) {}\n\n  int max_possible_qlen() { return std::max(max_len, group_max_len); }\n};\n\nstruct GeneralGateConfig {\n  size_t hidden_size;\n  size_t num_experts_per_tok;\n  size_t n_routed_experts;\n  size_t n_group;\n  size_t topk_group;\n\n  bool norm_topk_prob = true;\n  float routed_scaling_factor = 
2.5f;\n\n  std::string scoring_func = \"sigmoid\";\n  std::string topk_method = \"noaux_tc\";\n\n  int layer_idx = 0;\n  WorkerPool* pool = nullptr;\n\n  void* weight = nullptr;\n  ggml_type weight_type;\n  void* e_score_correction_bias = nullptr;\n  ggml_type e_score_correction_bias_type;\n\n  size_t max_seqlen = 25600;\n\n  GeneralGateConfig() = default;\n\n  GeneralGateConfig(int hidden_size, int num_experts_per_tok, int n_routed_experts, int n_group, int topk_group)\n      : hidden_size(hidden_size),\n        num_experts_per_tok(num_experts_per_tok),\n        n_routed_experts(n_routed_experts),\n        n_group(n_group),\n        topk_group(topk_group) {}\n};\n\nclass MLA_Interface {\n public:\n  virtual void forward(std::vector<int> qlens, std::vector<std::vector<int>> page_tables, std::vector<int> kv_lens,\n                       const void* input, void* output) = 0;\n};\n\nclass MoE_Interface {\n public:\n  virtual void forward(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input,\n                       void* output, bool incremental = false) = 0;\n};\ninline void init_ggml() {\n  static bool inited = false;\n  if (inited) {\n    return;\n  }\n  struct ggml_init_params params = {\n      0,\n      NULL,\n      true,\n  };\n\n  auto ctx_eval = ggml_init(params);\n\n  if (!ctx_eval) {\n    throw std::runtime_error(\"Failed to create ggml context\");\n  }\n  inited = true;\n}\n\ntemplate <typename A, typename B>\nvoid convert_or_copy(A* dst, const B* src, size_t count) {\n  if constexpr (std::is_same_v<A, B>) {\n    // printf(\"Direct copy\\n\");\n    memcpy(dst, src, sizeof(A) * count);\n  } else {\n    if constexpr (std::is_same_v<A, float>) {\n      if constexpr (std::is_same_v<B, ggml_bf16_t>) {\n        // printf(\"Converting ggml_bf16_t to float\\n\");\n        ggml_bf16_to_fp32_row(src, dst, count);\n      } else if constexpr (std::is_same_v<B, ggml_fp16_t>) {\n        ggml_fp16_to_fp32_row(src, dst, count);\n      } else 
{\n        throw std::runtime_error(\"Unsupported conversion\");\n      }\n    } else if constexpr (std::is_same_v<A, ggml_bf16_t>) {\n      if constexpr (std::is_same_v<B, float>) {\n        // printf(\"Converting float to ggml_bf16_t\\n\");\n        ggml_fp32_to_bf16_row(src, dst, count);\n      } else {\n        throw std::runtime_error(\"Unsupported conversion\");\n      }\n    }\n\n    else {\n      throw std::runtime_error(\"Unsupported conversion\");\n    }\n  }\n}\n\ntemplate <typename A>\nvoid convert_or_copy(A* dst, void* src, ggml_type type, size_t count) {\n  switch (type) {\n    case GGML_TYPE_BF16: {\n      auto src_bf16 = (ggml_bf16_t*)src;\n      convert_or_copy(dst, src_bf16, count);\n      break;\n    }\n    case GGML_TYPE_F16: {\n#if defined(__aarch64__) && defined(CPU_USE_KML)\n      auto src_fp16 = (float16_t*)src;\n      convert_or_copy(dst, src_fp16, count);\n#else\n      throw std::runtime_error(\"GGML_TYPE_F16 is not supported on this platform\");\n#endif\n      break;\n    }\n    case GGML_TYPE_F32: {\n      auto src_f32 = (float*)src;\n      convert_or_copy(dst, src_f32, count);\n      break;\n    }\n    default:\n      throw std::runtime_error(\"Unsupported type for conversion\");\n  }\n}\n\ntemplate <typename A>\nvoid check_numerics(A* data, size_t count) {\n  for (size_t i = 0; i < count; i++) {\n    if (std::isnan(data[i]) || std::isinf(data[i])) {\n      printf(\"Numerics check failed at index %zu: value = %f\\n\", i, data[i]);\n      throw std::runtime_error(\"Numerics check failed\");\n    }\n  }\n  printf(\"Numerics check passed for %zu elements.\\n\", count);\n}\n\ninline void debug_bf16(ggml_bf16_t* x) {\n  for (int i = 0; i < 10; i++) {\n    printf(\"%f \", ggml_bf16_to_fp32(x[i]));\n  }\n  printf(\"\\n\");\n}\ninline void debug_f32(float* x) {\n  for (int i = 0; i < 10; i++) {\n    printf(\"%f \", x[i]);\n  }\n  printf(\"\\n\");\n}\n\ninline void debug_f32(float* x, size_t count) {\n  if (count < 10) {\n    for (size_t i = 0; 
i < count; i++) {\n      printf(\"%f \", x[i]);\n    }\n  } else {\n    for (size_t i = 0; i < 3; i++) {\n      printf(\"%f \", x[i]);\n    }\n    printf(\"...\");\n    for (size_t i = count - 3; i < count; i++) {\n      printf(\"%f \", x[i]);\n    }\n  }\n  // Terminate the line in both branches (the short branch previously omitted it)\n  printf(\"\\n\");\n}\n\n#endif\n"
  },
  {
    "path": "kt-kernel/operators/kvcache/kvcache.h",
    "content": "/**\n * @Description  :\n * @Author       : Jianwei Dong\n * @Date         : 2024-08-26 22:47:06\n * @Version      : 1.0.0\n * @LastEditors  : Jianwei Dong\n * @LastEditTime : 2024-08-26 22:47:06\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n\n#ifndef CPUINFER_OPERATOR_KVCACHE_H\n#define CPUINFER_OPERATOR_KVCACHE_H\n\n#include <cassert>\n#include <cstdint>\n#include <cstdio>\n#include <cstring>\n#include <functional>\n#include <memory>\n#include <mutex>\n#include <queue>\n#include <vector>\n\n#include \"../../cpu_backend/worker_pool.h\"\n#include \"llama.cpp/ggml-common.h\"\n#include \"llama.cpp/ggml-quants.h\"\n#include \"llama.cpp/ggml.h\"\n\n#define CHUNK_SIZE 32\n\n/**\n * @brief Converts a ggml_type enum value to its corresponding string\n * representation.\n *\n * This function provides a human-readable string representation for a given\n * ggml_type enum value. The string can be used for logging, debugging, or\n * displaying information in a user interface.\n *\n * @param type The ggml_type enum value to convert.\n * @return A string representation of the enum value.\n */\nstd::string ggml_type_to_string(ggml_type type);\n\n/**\n * @enum AnchorType\n * @brief Defines the types of anchors used in attention mechanisms.\n *\n * This enum specifies different types of anchors that can be used in attention\n * mechanisms, such as fixed anchors, dynamic anchors, or special anchors like\n * QUEST, BLOCK_MEAN, or BLOCK_MAX.\n */\nenum AnchorType {\n  FIXED_ANCHOR, /**< A fixed anchor that does not change. */\n  DYNAMIC,      /**< A dynamic anchor that can change over time. */\n  QUEST,        /**< A special anchor type used for QUEST (Query and Embedding Space\n                   Transformation). */\n  BLOCK_MEAN,   /**< An anchor based on the mean of a block of data. 
*/\n  BLOCK_MAX     /**< An anchor based on the maximum value within a block of data.\n                 */\n};\n\n/**\n * @brief Converts an AnchorType enum value to its corresponding string\n * representation.\n *\n * This function provides a human-readable string representation for a given\n * AnchorType enum value. The string can be used for logging, debugging, or\n * displaying information in a user interface.\n *\n * @param anchor_type The AnchorType enum value to convert.\n * @return A string representation of the enum value.\n */\nstd::string AnchorTypeToString(AnchorType anchor_type);\n\n/**\n * @enum RetrievalType\n * @brief Defines the types of retrieval strategies in attention mechanisms.\n *\n * This enum specifies different retrieval strategies that can be used in\n * attention mechanisms, such as layer-level retrieval, key-value head-level\n * retrieval, or query head-level retrieval.\n */\nenum RetrievalType {\n  LAYER,  /**< Retrieval at the layer level. */\n  KVHEAD, /**< Retrieval at the key-value head level. */\n  QHEAD   /**< Retrieval at the query head level. */\n};\n\n/**\n * @brief Converts a RetrievalType enum value to its corresponding string\n * representation.\n *\n * This function provides a human-readable string representation for a given\n * RetrievalType enum value. The string can be used for logging, debugging, or\n * displaying information in a user interface.\n *\n * @param retrieval_type The RetrievalType enum value to convert.\n * @return A string representation of the enum value.\n */\nstd::string RetrievalTypeToString(RetrievalType retrieval_type);\n\n/**\n * @struct KVCacheConfig\n * @brief Configuration structure for Key-Value (KV) Cache.\n *\n * This structure holds configuration parameters for setting up and managing\n * a Key-Value (KV) Cache used in various attention mechanisms. 
It includes\n * parameters such as the number of layers, the number of heads, the dimension\n * of each head, block length, anchor information, and memory-related settings.\n */\nstruct KVCacheConfig {\n  int layer_num;   /**< Number of layers in the model. */\n  int kv_head_num; /**< Number of heads in the KV Cache. */\n  int q_head_num;  /**< Number of heads in the query. */\n  int head_dim;    /**< Dimension of each head. */\n  int block_len;   /**< Length of each block in the cache. */\n  int anchor_num;  /**< Number of anchors used in attention. */\n\n  ggml_type kv_type; /**< Data type of the KV Cache (e.g., fp16, q8_0). */\n\n  // Controls the pre-allocated memory size\n  int max_block_num;  /**< Maximum number of blocks that can be allocated. */\n  int max_batch_size; /**< Maximum batch size that can be processed. */\n  int max_thread_num; /**< Maximum number of threads that can be used. */\n\n  AnchorType anchor_type;       /**< Type of anchors used in the attention mechanism. */\n  RetrievalType retrieval_type; /**< Type of retrieval strategy used in the cache. */\n\n  int layer_step;   /**< Step size between layers. */\n  int token_step;   /**< Step size between tokens. */\n  int layer_offset; /**< Offset value for layers. */\n\n  /**\n   * @brief Default constructor for KVCacheConfig.\n   *\n   * Defaulted constructor with no member initializers. 
A default-initialized\n   * object therefore holds indeterminate values until every field is assigned.\n   */\n  KVCacheConfig() = default;\n\n  /**\n   * @brief Parameterized constructor for KVCacheConfig.\n   *\n   * This constructor initializes the configuration with specific values\n   * for all member variables.\n   *\n   * @param layer_num The number of layers in the model.\n   * @param kv_head_num The number of heads in the KV Cache.\n   * @param q_head_num The number of heads in the query.\n   * @param head_dim The dimension of each head.\n   * @param block_len The length of each block in the cache.\n   * @param anchor_num The number of anchors used in attention.\n   * @param anchor_type The type of anchors used in the attention mechanism.\n   * @param kv_type The data type of the KV Cache (e.g., fp16, q8_0).\n   * @param retrieval_type The type of retrieval strategy used in the cache.\n   * @param layer_step The step size between layers.\n   * @param token_step The step size between tokens.\n   * @param layer_offset The offset value for layers.\n   * @param max_block_num The maximum number of blocks that can be allocated.\n   * @param max_batch_size The maximum batch size that can be processed.\n   * @param max_thread_num The maximum number of threads that can be used.\n   */\n  KVCacheConfig(int layer_num, int kv_head_num, int q_head_num, int head_dim, int block_len, int anchor_num,\n                AnchorType anchor_type, ggml_type kv_type, RetrievalType retrieval_type, int layer_step, int token_step,\n                int layer_offset, int max_block_num, int max_batch_size, int max_thread_num);\n};\n\n/**\n * @class KVCache\n * @brief Manages the Key-Value (KV) Cache used in attention mechanisms.\n *\n * The KVCache class provides functionality for managing the Key-Value Cache,\n * including resizing the cache, retrieving configuration parameters, and\n * updating internal states. 
This class is typically used in transformer models\n * to store and manage past key and value states for efficient attention\n * computations.\n */\nclass KVCache {\n public:\n  /**\n   * @brief Constructs a KVCache object with the given configuration.\n   *\n   * Initializes the KVCache with the specified configuration parameters,\n   * such as the number of layers, heads, head dimensions, and other\n   * relevant settings.\n   *\n   * @param config The configuration object containing initialization\n   * parameters.\n   */\n  KVCache(KVCacheConfig config);\n\n  /**\n   * @brief Resizes the number of threads used by the cache.\n   *\n   * This function adjusts the number of threads that the cache can utilize.\n   * It allows dynamic reconfiguration of the parallel processing capabilities\n   * based on the current workload or system resources.\n   *\n   * @param thread_num The new number of threads to use.\n   */\n  void ThreadResize(int thread_num);\n\n  /**\n   * @brief Resizes the batch size managed by the cache.\n   *\n   * This function adjusts the batch size that the cache can handle. 
It\n   * is useful when the input batch size changes dynamically, allowing\n   * the cache to be reconfigured accordingly.\n   *\n   * @param batch_size The new batch size.\n   */\n  void BatchResize(int batch_size);\n\n  /**\n   * @brief Resizes the number of blocks managed by the cache.\n   *\n   * This function adjusts the number of blocks that the cache can manage.\n   * It allows dynamic reconfiguration of the block structure based on the\n   * current sequence length or other factors.\n   *\n   * @param block_num The new number of blocks.\n   */\n  void BlockResize(int block_num);\n\n  /**\n   * @brief Gets the number of layers in the cache.\n   *\n   * @return The number of layers configured in the cache.\n   */\n  int get_layer_num() { return config_.layer_num; }\n\n  /**\n   * @brief Gets the number of KV heads in the cache.\n   *\n   * @return The number of KV heads configured in the cache.\n   */\n  int get_kv_head_num() { return config_.kv_head_num; }\n\n  /**\n   * @brief Gets the number of query heads in the cache.\n   *\n   * @return The number of query heads configured in the cache.\n   */\n  int get_q_head_num() { return config_.q_head_num; }\n\n  /**\n   * @brief Gets the dimension of each head in the cache.\n   *\n   * @return The dimension of each head.\n   */\n  int get_head_dim() { return config_.head_dim; }\n\n  /**\n   * @brief Gets the length of each block in the cache.\n   *\n   * @return The length of each block.\n   */\n  int get_block_len() { return config_.block_len; }\n\n  /**\n   * @brief Gets the number of blocks for a specific layer.\n   *\n   * @param layer_id The ID of the layer for which to retrieve the block\n   * number.\n   * @return The number of blocks in the specified layer.\n   */\n  int get_block_num(int layer_id) { return past_block_num_[layer_id]; }\n\n  /**\n   * @brief Gets the number of anchors in the cache.\n   *\n   * @return The number of anchors configured in the cache.\n   */\n  int get_anchor_num() { return 
config_.anchor_num; }\n\n  /**\n   * @brief Gets the total length of the cache.\n   *\n   * @return The total length of the cache.\n   */\n  int get_cache_total_len() { return cache_total_len_; }\n\n  /**\n   * @brief Gets the total number of blocks in the cache.\n   *\n   * This function computes and returns the total number of blocks in the\n   * cache based on the total cache length and the block length configuration.\n   *\n   * @return The total number of blocks in the cache.\n   */\n  int get_cache_total_block_num() { return (cache_total_len_ + config_.block_len - 1) / config_.block_len; }\n\n  /**\n   * @brief Updates the total length of the cache.\n   *\n   * This function sets a new total length for the cache, allowing dynamic\n   * adjustment of the cache size during runtime.\n   *\n   * @param cache_total_len The new total length of the cache.\n   */\n  void update_cache_total_len(int cache_total_len) { cache_total_len_ = cache_total_len; }\n  void attn(const ggml_fp16_t* q_in, ggml_fp16_t* output, float* attn_lse, int layer_idx, int generate_token_idx,\n            int q_len, int batch_size, int max_block_num, int* block_table, int* cache_seqlens, int pick_block_num,\n            int init_block_num, int local_block_num, WorkerPool* backend);\n\n  void update_kvcache_one_block_fp16(const ggml_fp16_t* k_in, const ggml_fp16_t* v_in, int layer_id, int block_idx,\n                                     WorkerPool* backend);\n\n  void get_kvcache_one_block_fp16(ggml_fp16_t* k_in, ggml_fp16_t* v_in, int layer_id, int block_idx,\n                                  WorkerPool* backend);\n\n  void update_importance_one_block(const ggml_fp16_t* importance, int layer_id, int block_idx, WorkerPool* backend);\n  void get_importance_one_block(ggml_fp16_t* importance, int layer_id, int block_idx, WorkerPool* backend);\n\n  void get_anchor_one_block(ggml_fp16_t* anchor, int layer_id, int block_idx, WorkerPool* backend);\n\n  void update_anchor_one_block(const ggml_fp16_t* 
anchor, int layer_id, int block_idx, WorkerPool* backend);\n\n  void calc_anchor_all_layers(int* block_table, int* cache_seqlens, int batch_size, int max_block_num,\n                              WorkerPool* backend);\n\n  void load_kvcache(std::string tensor_file_path, WorkerPool* backend);\n  void dump_kvcache(int* block_table, int cache_total_len, std::string tensor_file_path, WorkerPool* backend);\n\n  void get_and_update_kvcache_fp16(ggml_fp16_t* k_in, ggml_fp16_t* v_in, int layer_id, int* block_table, int batch_size,\n                                   int max_block_num, int* cache_seqlens, int q_len, WorkerPool* backend);\n\n  void get_kvcache_fp16(ggml_fp16_t* k_in, ggml_fp16_t* v_in, int layer_id, int* block_table, int batch_size,\n                        int max_block_num, int* cache_seqlens, WorkerPool* backend);\n\n  void update_kvcache_fp16(const ggml_fp16_t* k_in, const ggml_fp16_t* v_in, int layer_id, int* block_table,\n                           int batch_size, int max_block_num, int* cache_seqlens, int q_len, WorkerPool* backend);\n\n  void update_importance(const ggml_fp16_t* importance, int layer_id, int* block_table, int batch_size,\n                         int max_block_num, int* offset, int width, WorkerPool* backend);\n\n  void attn_with_kvcache(const ggml_fp16_t* q_in, const ggml_fp16_t* k_in, const ggml_fp16_t* v_in, ggml_fp16_t* output,\n                         float* attn_lse, int layer_idx, int generate_token_idx, int q_len, int batch_size,\n                         int max_block_num, int* block_table, int* cache_seqlens, int topk, int local,\n                         WorkerPool* backend);\n\n  void clear_importance_all_layers(int* block_table, int* cache_seqlens, int batch_size, int max_block_num,\n                                   WorkerPool* backend);\n\n  void clear_kvcache_all_layers(int* block_table, int* cache_seqlens, int batch_size, int max_block_num,\n                                WorkerPool* backend);\n\n  void 
get_sincos(ggml_fp16_t* sin, ggml_fp16_t* cos, int seqlen);\n\n  void get_attn_sparsity(const ggml_fp16_t* q_in, float* attn_sparsity, int layer_idx, int generate_token_idx,\n                         int q_len, int batch_size, int max_block_num, int* block_table, int* cache_seqlens,\n                         int* block_table_origin, int* cache_seqlens_origin, int max_block_num_origin, int topk,\n                         int local, WorkerPool* backend);\n\n  void get_all_kvcache_one_layer(int layer_id, ggml_fp16_t* k_in, ggml_fp16_t* v_in, WorkerPool* backend);\n\n private:\n  // Persistent data\n  KVCacheConfig config_;\n  int n_gqa_;                             // q_head_num / kv_head_num\n  int cache_total_len_;                   // Number of tokens in cache\n  std::vector<uint64_t> past_block_num_;  // [layer_num]\n  std::vector<std::vector<std::vector<std::vector<block_q4_0>>>>\n      k_cache_q4;  // [layer_num, kv_head_num, past_block_num,\n                   // block_len * (head_dim / QK4_0)]\n  std::vector<std::vector<std::vector<std::vector<block_q4_0>>>>\n      v_cache_q4;  // [layer_num, kv_head_num, past_block_num,\n                   // head_dim * (block_len / QK4_0)]\n  std::vector<std::vector<std::vector<std::vector<block_q8_0>>>>\n      k_cache_q8;  // [layer_num, kv_head_num, past_block_num,\n                   // block_len * (head_dim / QK8_0)]\n  std::vector<std::vector<std::vector<std::vector<block_q8_0>>>>\n      v_cache_q8;  // [layer_num, kv_head_num, past_block_num,\n                   // head_dim * (block_len / QK8_0)]\n\n  std::vector<std::vector<std::vector<std::vector<ggml_fp16_t>>>>\n      k_cache_fp16_;  // [layer_num, kv_head_num, past_block_num, block_len *\n                      // head_dim]\n  std::vector<std::vector<std::vector<std::vector<ggml_fp16_t>>>>\n      v_cache_fp16_;  // [layer_num, kv_head_num, past_block_num, head_dim *\n                      // block_len]\n\n  std::vector<std::vector<std::vector<std::vector<ggml_fp16_t>>>> 
importance_;  // [layer_num, past_block_num,\n                                                                                // block_len, attention_head_num]\n\n  std::vector<ggml_fp16_t> anchor_;  // [layer_num * past_block_num * anchor_num *\n                                     // attention_head_num * head_dim]\n\n  // Runtime data\n  int64_t layer_id_;\n  int64_t block_idx_;\n  int* block_table_;\n  uint64_t block_num_;\n  int max_block_num_after_retrieval_;\n\n  // Rotary positional embeddings\n  std::vector<std::vector<ggml_fp16_t>> sin_;  // [seq_len, head_dim]\n  std::vector<std::vector<ggml_fp16_t>> cos_;  // [seq_len, head_dim]\n\n  // update/get\n  int seq_len_;\n  uint16_t* k_scales_;         // q4_0\n  uint8_t* k_in_;              // q4_0\n  uint16_t* v_scales_;         // q4_0\n  uint8_t* v_in_;              // q4_0\n  uint16_t* k_data_;           // fp16\n  uint16_t* v_data_;           // fp16\n  uint16_t* importance_data_;  // fp16\n  uint16_t* anchor_data_;      // fp16\n\n  // sparsity = (sigma(block lse / lse))\n  std::vector<std::vector<std::vector<float>>> block_lse_;  // [batch_size, max_block_num, q_head_num]\n  std::vector<std::vector<float>> attn_sparsity_;           // [batch_size, q_head_num]\n\n  // attn\n  std::vector<std::vector<float>> avg_q;  // [batch_size, q_head_num * head_dim]\n\n  std::vector<std::vector<ggml_fp16_t>> avg_q_fp16;  // [batch_size, q_head_num * head_dim]\n  std::vector<std::priority_queue<std::pair<float, int>, std::vector<std::pair<float, int>>, std::greater<>>>\n      top_similar_block_;\n\n  std::vector<std::vector<float>> block_similar_;\n  std::vector<std::vector<std::vector<float>>> block_similar_kv_head_;\n  std::vector<std::vector<std::vector<float>>> block_similar_q_head_;\n\n  std::vector<int> cache_seqlens_;                // [batch_size]\n  std::vector<int> selected_blocks_num_history_;  // [layer_num // layer_step]\n\n  std::vector<std::vector<std::vector<int>>> selected_blocks_history_;\n  // 
[layer_num // layer_step, batch_size, max_block_num]\n\n  std::vector<std::vector<std::vector<std::vector<int>>>>\n      selected_blocks_history_kvhead_;  // [layer_num // layer_step,\n                                        // batch_size, max_block_num,\n                                        // kv_head_num]\n\n  std::vector<std::vector<int>> block_table_before_retrieval_;  // [batch_size, max_block_num]\n  std::vector<std::vector<int>> block_table_after_retrieval_;   // [batch_size, pick_block_num]\n\n  std::vector<std::vector<std::vector<int>>> block_table_before_retrieval_qhead_;  // [batch_size, max_block_num,\n                                                                                   // q_head_num]\n  std::vector<std::vector<std::vector<int>>> block_table_after_retrieval_qhead_;   // [batch_size, pick_block_num,\n                                                                                   // q_head_num]\n\n  std::vector<std::vector<std::vector<int>>> block_table_before_retrieval_kvhead_;  // [batch_size, max_block_num,\n                                                                                    // kv_head_num]\n  std::vector<std::vector<std::vector<int>>> block_table_after_retrieval_kvhead_;   // [batch_size, pick_block_num,\n                                                                                    // kv_head_num]\n\n  std::vector<std::vector<std::unique_ptr<std::mutex>>> mutex_;  // [batch_size, kv_head_num]\n  std::vector<std::vector<std::vector<block_q8_0>>> q_q8_0_;     // [batch_size, kv_head_num, n_gqa * head_dim / QK8_0]\n  std::vector<std::vector<std::vector<float>>> q_fp32_;          // [batch_size, kv_head_num, n_gqa * head_dim]\n\n  std::vector<std::vector<std::vector<float>>> output_fp32_;  // [batch_size, kv_head_num, n_gqa * head_dim]\n  std::vector<std::vector<std::vector<float>>> attn_lse_;     // [batch_size, kv_head_num, n_gqa]\n\n  std::vector<std::pair<int, int>> thread_cur_head_idx_;  // [thread_num]\n\n  
std::vector<std::vector<block_q8_0>> thread_local_output_q8_0_;  // [thread_num, n_gqa * head_dim / QK8_0]\n  std::vector<std::vector<float>> thread_local_attn_score_;        // [thread_num, n_gqa * block_len]\n  std::vector<std::vector<float>> thread_local_output_fp32_;       // [thread_num, n_gqa * head_dim]\n  std::vector<std::vector<float>> thread_local_attn_lse_;          // [thread_num, n_gqa]\n  std::vector<std::vector<float>> thread_local_cur_output_fp32_;   // [thread_num, n_gqa * head_dim]\n  std::vector<std::vector<float>> thread_local_cur_attn_lse_;      // [thread_num, n_gqa]\n  std::vector<std::vector<uint8_t>> thread_local_attn_mask_;       // [thread_num, block_len // 8]\n  std::vector<std::vector<char>> thread_local_draft_;              // [thread_num, 2 * n_gqa * block_len + 6 * n_gqa *\n                                                                   // head_dim + 2 * block_len * head_dim]\n\n  // tmp space\n  std::vector<float> q_fp32;  // [n_gqa * head_dim]\n\n  void quantize_q_(const uint16_t* q_in_data, int batch_size);\n  void attn_initialize_layer_(int batch_size, int layer_idx, int* block_table, int& max_block_num, int* cache_seqlens);\n  void attn_initialize_kvhead_(int batch_size, int layer_idx, int* block_table, int& max_block_num, int* cache_seqlens);\n  void retrieval_kvcache_layer_(const uint16_t* q_in_data, int init_block_num, int local_block_num, int pick_block_num,\n                                int q_len, int generate_token_idx, int batch_size, int layer_idx, int* cache_seqlens,\n                                int& max_block_num, WorkerPool* backend);\n  void retrieval_kvcache_kvhead_(const uint16_t* q_in_data, int init_block_num, int local_block_num, int pick_block_num,\n                                 int q_len, int generate_token_idx, int batch_size, int layer_idx, int* cache_seqlens,\n                                 int& max_block_num, WorkerPool* backend);\n\n  void calculate_block_similarity_layer_(const uint16_t* 
q_in_data, int batch_size, int layer_idx, int q_len,\n                                         int max_block_num, int* cache_seqlens, int init_block_num, int local_block_num,\n                                         int pick_block_num, WorkerPool* backend);\n  void calculate_block_similarity_kvhead_(const uint16_t* q_in_data, int batch_size, int layer_idx, int q_len,\n                                          int max_block_num, int* cache_seqlens, int init_block_num,\n                                          int local_block_num, int pick_block_num, WorkerPool* backend);\n\n  void select_block_layer_(int batch_size, int layer_idx, int max_block_num, int init_block_num, int local_block_num,\n                           int pick_block_num);\n  void select_block_kvhead_(int batch_size, int layer_idx, int max_block_num, int init_block_num, int local_block_num,\n                            int pick_block_num);\n\n  void calculate_sparsity_layer_(const uint16_t* q_in_data, float* attn_sparsity, int batch_size, int max_block_num,\n                                 int* block_table, int* cache_seqlens, WorkerPool* backend);\n  void calculate_sparsity_kvhead_(const uint16_t* q_in_data, float* attn_sparsity, int batch_size, int max_block_num,\n                                  int* block_table, int* cache_seqlens, WorkerPool* backend);\n\n  void attention_kvhead_(const uint16_t* q_in_data, ggml_fp16_t* output, float* attn_lse, int batch_size,\n                         WorkerPool* backend);\n  void attention_layer_(const uint16_t* q_in_data, ggml_fp16_t* output, float* attn_lse, int batch_size,\n                        WorkerPool* backend);\n\n  /**\n   * @brief Computes attention with KV cache for one block.\n   *\n   * This function performs attention computation for one block using KV\n   * cache. The function supports different data types for Q, K, and V caches,\n   * and provides options for quantization. 
The function does not perform any\n   * dynamic memory allocation internally, so all necessary buffers must be\n   * pre-allocated externally.\n   *\n   * @param head_dim The dimension of the head.\n   * @param bsz The batch size.\n   * @param q_type The data type of Q (GGML data type). Only supports fp16 and\n   * q8_0.\n   * @param q Pointer to the Q tensor [bsz, head_dim]. The quantization is\n   *          always applied along the head_dim dimension. The size must be\n   *          bsz * head_dim/32 * qtype_size. If head_dim % 32 != 0, an error\n   *          will be raised.\n   * @param past_kv_len The length of the past KV cache.\n   * @param past_kv_offset The offset in the past KV cache.\n   * @param is_full_attn Boolean flag indicating whether to use full attention\n   *                     (true for full 1 mask).\n   * @param attn_mask Pointer to the attention mask [bsz, past_kv_len]. If\n   *                  is_full_attn = false, a bit matrix is passed to\n   * represent the mask.\n   * @param k_type The data type of K cache (GGML data type). Only supports\n   *               fp16, q4_0, and q8_0.\n   * @param k_quant_type Quantization type for K cache. 0 for per_token, 1 for\n   *                     per_channel. Other values will raise an error.\n   * @param k_cache Pointer to the K cache tensor [seq_len, head_dim]. If\n   *                quant_type == 0, head_dim % 32 must be 0. If quant_type ==\n   * 1, seq_len % 32 must be 0.\n   * @param num_k_anchor The number of K anchors. If num_k_anchor == 0, it\n   * means no anchor is present.\n   * @param k_cache_anchors Pointer to the K cache anchors [num_k_anchor,\n   * head_dim]. The k_anchor_type must be fp16.\n   * @param k_cache_anchor_pos Pointer to the K cache anchor positions. 
Each\n   * token is associated with the nearest previous anchor position.\n   * @param v_type The data type of V cache (GGML data type).\n   * @param v_quant_type Quantization type for V cache.\n   * @param v_cache Pointer to the V cache tensor [head_dim, seq_len].\n   * @param num_v_anchor The number of V anchors.\n   * @param v_cache_anchors Pointer to the V cache anchors.\n   * @param v_cache_anchor_pos Pointer to the V cache anchor positions.\n   * @param attn_score Pre-allocated buffer for attention scores [bsz,\n   * past_kv_len].\n   * @param output Output tensor [bsz, head_dim] with the same type as q_type.\n   * @param lse Pre-allocated buffer [bsz] for the log-sum-exp of the\n   * attention scores.\n   * @param draft Pre-allocated temporary buffer. The buffer size should be\n   * enough to hold (2 * bsz * past_kv_len + 6 * bsz * head_dim + 2 *\n   *              past_kv_len * head_dim + past_kv_len * head_dim / 32) bytes.\n   * @param rotary_angle Pointer to the rotary angle tensor.\n   * @param rotary_cos Pointer to the cosine values for rotary embedding.\n   * @param rotary_sin Pointer to the sine values for rotary embedding.\n   */\n  void attn_with_kvcache_one_block_(int head_dim, int bsz,\n                                    ggml_type q_type,  // GGML data type of `Q`, only supports fp16 and q8_0\n                                    // [bsz, head_dim]\n                                    // Quantization is always on the head_dim dimension (per_token). If\n                                    // head_dim % 32 != 0, an error will be raised. 
The size must be bsz *\n                                    // head_dim/32 * qtype_size.\n                                    const void* q,\n\n                                    int past_kv_len, int past_kv_offset,\n                                    bool is_full_attn,  // true indicates a full 1 mask\n                                    // If is_full_attn = false, a bit matrix representing the mask is\n                                    // passed. [bsz, past_kv_len]\n                                    const uint8_t* attn_mask,\n\n                                    ggml_type k_type,  // GGML data type of `K Cache`, only supports fp16,\n                                                       // q4_0, q8_0\n                                    int k_quant_type,  // 0 for per_token, 1 for per_channel, others raise an\n                                                       // error\n                                    // [seq_len, head_dim]\n                                    // If quant_type == 0, head_dim % 32 must be 0.\n                                    // If quant_type == 1, seq_len % 32 must be 0.\n                                    const void* k_cache,\n\n                                    // k_anchor_type must be fp16\n                                    int num_k_anchor,  // num_k_anchor == 0 indicates no anchor\n                                    // [num_k_anchor, head_dim]\n                                    const void* k_cache_anchors,\n                                    // Each token is associated with the nearest previous position's anchor,\n                                    // with the same distance.\n                                    const int* k_cache_anchor_pos,\n\n                                    // v_cache similar to k_cache\n                                    ggml_type v_type, int v_quant_type,\n                                    // [head_dim, seq_len]\n                                    const void* v_cache, int num_v_anchor, 
const void* v_cache_anchors,\n                                    const int* v_cache_anchor_pos,\n\n                                    // Pre-allocated buffer for intermediate calculations [bsz,\n                                    // past_kv_len]. No malloc is performed inside this function.\n                                    float* attn_score,\n\n                                    // Output: [bsz, head_dim], with the same type as q_type\n                                    void* output,\n                                    // [bsz]\n                                    float* lse,\n\n                                    // Pre-allocated temporary buffer with sufficient size:\n                                    // (2 * bsz * past_kv_len + 6 * bsz * head_dim + 2 * past_kv_len *\n                                    // head_dim + past_kv_len * head_dim / 32) bytes.\n                                    void* draft,\n\n                                    // Apply rotary embedding online\n                                    const int* rotary_angle, const void* rotary_cos, const void* rotary_sin\n                                    // rotary_cos=None,\n                                    // rotary_sin=None,\n                                    // cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,\n                                    // cache_batch_idx: Optional[torch.Tensor] = None,\n                                    // rotary_interleaved=True,\n\n                                    // // Not supported for now\n                                    // window_size=(-1, -1),  # -1 means infinite context window\n                                    // alibi_slopes=None,\n  );\n};\n\n/**\n * @brief Scales a float32 vector by a given scalar value.\n *\n * This function multiplies each element of the input vector `y` by a scalar\n * `v`. It uses platform-specific optimizations if available, such as Apple's\n * Accelerate framework or SIMD instructions. 
If no specific optimization is\n * available, the function falls back to a simple scalar multiplication loop.\n *\n * @param n The number of elements in the vector `y`.\n * @param y The input vector to be scaled. The result will be stored in the same\n * vector.\n * @param v The scalar value by which to scale the vector.\n */\nvoid ggml_vec_scale_f32(const int n, float* y, const float v);\n#endif"
  },
  {
    "path": "kt-kernel/operators/kvcache/kvcache_attn.cpp",
    "content": "/**\n * @Description  :\n * @Author       : Jianwei Dong\n * @Date         : 2024-08-26 22:47:06\n * @Version      : 1.0.0\n * @LastEditors  : Jianwei Dong\n * @LastEditTime : 2024-08-26 22:47:06\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n\n#include <chrono>\n#include <cmath>\n\n#include \"ggml-impl.h\"\n#include \"kvcache.h\"\n#include \"llamafile/sgemm.h\"\n\nvoid KVCache::attention_kvhead_(const uint16_t* q_in_data, ggml_fp16_t* output, float* attn_lse, int batch_size,\n                                WorkerPool* backend) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n  seq_len_ = config_.block_len;\n\n  backend->do_work_stealing_job(\n      batch_size * config_.kv_head_num * max_block_num_after_retrieval_,\n      [&](int thread_id) {\n        thread_cur_head_idx_[thread_id].first = -1;\n        thread_cur_head_idx_[thread_id].second = -1;\n      },\n      [&](int task_id) {\n        int batch_id = task_id / (config_.kv_head_num * max_block_num_after_retrieval_);\n        int head_id =\n            (task_id % (config_.kv_head_num * max_block_num_after_retrieval_)) / max_block_num_after_retrieval_;\n        int block_id = task_id % max_block_num_after_retrieval_;\n        int thread_id = WorkerPool::thread_local_id;\n\n        // If the block is out of the sequence length, skip it.\n        if (cache_seqlens_[batch_id] / config_.block_len < block_id) {\n          return;\n        }\n        int block_idx = block_table_after_retrieval_kvhead_[batch_id][block_id][head_id];\n        if (cache_seqlens_[batch_id] / config_.block_len == block_id) {\n          int seq_len = cache_seqlens_[batch_id] % config_.block_len;\n          if (seq_len == 0) return;\n\n          // Prepare the attention mask for the last block.\n          int full_blocks = seq_len / 8;\n          int remaining_bits = seq_len % 8;\n          // Fill full blocks with 1s\n          for (int i = 0; i < full_blocks; ++i) {\n       
     thread_local_attn_mask_[thread_id][i] = 0xFF;\n          }\n          // Fill the remaining bits in the next block\n          if (remaining_bits > 0 && full_blocks < seq_len_ / 8) {\n            thread_local_attn_mask_[thread_id][full_blocks] = (1 << remaining_bits) - 1;\n          } else {\n            thread_local_attn_mask_[thread_id][full_blocks] = 0;\n          }\n\n          for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) {\n            thread_local_attn_mask_[thread_id][i] = 0;\n          }\n          if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n            attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,\n                                         (void*)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim +\n                                                           head_id * n_gqa_ * config_.head_dim],\n                                         seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(), GGML_TYPE_F16,\n                                         0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,\n                                         GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                                         nullptr, nullptr, thread_local_attn_score_[thread_id].data(),\n                                         thread_local_output_fp32_[thread_id].data(),\n                                         thread_local_attn_lse_[thread_id].data(),\n                                         thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());\n          } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n            attn_with_kvcache_one_block_(\n                config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,\n                q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(),\n                
GGML_TYPE_Q4_0, 0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,\n                GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,\n                thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(),\n                thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),\n                sin_.data());\n            dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),\n                                thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);\n          } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n            attn_with_kvcache_one_block_(\n                config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,\n                q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(),\n                GGML_TYPE_Q8_0, 0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,\n                GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,\n                thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(),\n                thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),\n                sin_.data());\n            dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),\n                                thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);\n          }\n        } else {\n          if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n            attn_with_kvcache_one_block_(\n                config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,\n                (void*)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim +\n                                  head_id * n_gqa_ * 
config_.head_dim],\n                seq_len_, 0, true, nullptr, GGML_TYPE_F16, 0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                nullptr, nullptr, GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr,\n                nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_fp32_[thread_id].data(),\n                thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),\n                sin_.data());\n\n          } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n            attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,\n                                         q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0,\n                                         0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,\n                                         GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                                         nullptr, nullptr, thread_local_attn_score_[thread_id].data(),\n                                         thread_local_output_q8_0_[thread_id].data(),\n                                         thread_local_attn_lse_[thread_id].data(),\n                                         thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());\n            dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),\n                                thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);\n          } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n            attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,\n                                         q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0,\n                                         0, 
k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,\n                                         GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                                         nullptr, nullptr, thread_local_attn_score_[thread_id].data(),\n                                         thread_local_output_q8_0_[thread_id].data(),\n                                         thread_local_attn_lse_[thread_id].data(),\n                                         thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());\n            dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),\n                                thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);\n          }\n        }\n        int cur_batch_idx = thread_cur_head_idx_[thread_id].first;\n        int cur_head_id = thread_cur_head_idx_[thread_id].second;\n        if (batch_id == cur_batch_idx && head_id == cur_head_id) {\n          for (int i = 0; i < n_gqa_; i++) {\n            float new_attn_lse = thread_local_cur_attn_lse_[thread_id][i] +\n                                 std::log(1.0 + std::exp(thread_local_attn_lse_[thread_id][i] -\n                                                         thread_local_cur_attn_lse_[thread_id][i]));\n            ggml_vec_scale_f32(config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,\n                               std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));\n            ggml_vec_scale_f32(config_.head_dim, thread_local_output_fp32_[thread_id].data() + i * config_.head_dim,\n                               std::exp(thread_local_attn_lse_[thread_id][i] - new_attn_lse));\n            for (int j = 0; j < config_.head_dim; j++) {\n              thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j] +=\n                  thread_local_output_fp32_[thread_id][i * config_.head_dim + j];\n            }\n            
thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse;\n          }\n        } else {\n          if (cur_batch_idx != -1) {\n            mutex_[cur_batch_idx][cur_head_id]->lock();\n            for (int i = 0; i < n_gqa_; i++) {\n              if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) {\n                attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i];\n                for (int j = 0; j < config_.head_dim; j++) {\n                  output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] =\n                      thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];\n                }\n                continue;\n              }\n              float new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] +\n                                   std::log(1.0 + std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                                           attn_lse_[cur_batch_idx][cur_head_id][i]));\n              ggml_vec_scale_f32(config_.head_dim,\n                                 output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim,\n                                 std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse));\n              ggml_vec_scale_f32(config_.head_dim,\n                                 thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,\n                                 std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));\n              for (int j = 0; j < config_.head_dim; j++) {\n                output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] +=\n                    thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];\n              }\n              attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;\n            }\n            mutex_[cur_batch_idx][cur_head_id]->unlock();\n          }\n          thread_cur_head_idx_[thread_id].first = batch_id;\n   
       thread_cur_head_idx_[thread_id].second = head_id;\n          for (int i = 0; i < n_gqa_; i++) {\n            thread_local_cur_attn_lse_[thread_id][i] = thread_local_attn_lse_[thread_id][i];\n            for (int j = 0; j < config_.head_dim; j++) {\n              thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j] =\n                  thread_local_output_fp32_[thread_id][i * config_.head_dim + j];\n            }\n          }\n        }\n      },\n      // Merge the results of the remaining blocks.\n      [&](int thread_id) {\n        int cur_batch_idx = thread_cur_head_idx_[thread_id].first;\n        int cur_head_id = thread_cur_head_idx_[thread_id].second;\n        if (cur_head_id != -1) {\n          mutex_[cur_batch_idx][cur_head_id]->lock();\n          for (int i = 0; i < n_gqa_; i++) {\n            float new_attn_lse;\n            if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) {\n              attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i];\n              for (int j = 0; j < config_.head_dim; j++) {\n                output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] =\n                    thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];\n              }\n              continue;\n            }\n            new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] +\n                           std::log(1.0 + std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                                   attn_lse_[cur_batch_idx][cur_head_id][i]));\n            ggml_vec_scale_f32(config_.head_dim, output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim,\n                               std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse));\n            ggml_vec_scale_f32(config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,\n                               
std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));\n            for (int j = 0; j < config_.head_dim; j++) {\n              output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] +=\n                  thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];\n            }\n            attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;\n          }\n          mutex_[cur_batch_idx][cur_head_id]->unlock();\n        }\n      });\n  // move the results to output and attn_lse\n  uint16_t* output_data = reinterpret_cast<uint16_t*>(output);\n  float* attn_lse_data = attn_lse;\n  for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n    for (int i = 0; i < config_.kv_head_num; i++) {\n      for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {\n        output_data[batch_idx * config_.kv_head_num * n_gqa_ * config_.head_dim + i * n_gqa_ * config_.head_dim + j] =\n            GGML_FP32_TO_FP16(output_fp32_[batch_idx][i][j]);\n      }\n      for (int j = 0; j < n_gqa_; j++) {\n        attn_lse_data[batch_idx * config_.kv_head_num * n_gqa_ + i * n_gqa_ + j] = attn_lse_[batch_idx][i][j];\n      }\n    }\n  }\n\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  std::chrono::duration<double> diff = end - start;\n  // printf(\"layer %d time of computing attention: %f s\\n\", layer_idx,\n  //        diff.count());\n}\n\nvoid KVCache::attention_layer_(const uint16_t* q_in_data, ggml_fp16_t* output, float* attn_lse, int batch_size,\n                               WorkerPool* backend) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n  seq_len_ = config_.block_len;\n  backend->do_work_stealing_job(\n      batch_size * config_.kv_head_num * max_block_num_after_retrieval_,\n      [&](int thread_id) {\n        thread_cur_head_idx_[thread_id].first = -1;\n        thread_cur_head_idx_[thread_id].second = -1;\n      },\n      [&](int task_id) {\n        int batch_id = task_id / 
(config_.kv_head_num * max_block_num_after_retrieval_);\n        int head_id =\n            (task_id % (config_.kv_head_num * max_block_num_after_retrieval_)) / max_block_num_after_retrieval_;\n        int block_id = task_id % max_block_num_after_retrieval_;\n        int thread_id = WorkerPool::thread_local_id;\n        // If the block is out of the sequence length, skip it.\n        if (cache_seqlens_[batch_id] / config_.block_len < block_id) {\n          return;\n        }\n        int block_idx = block_table_after_retrieval_[batch_id][block_id];\n        if (cache_seqlens_[batch_id] / config_.block_len == block_id) {\n          int seq_len = cache_seqlens_[batch_id] % config_.block_len;\n          if (seq_len == 0) return;\n\n          // Prepare the attention mask for the last block.\n          int full_blocks = seq_len / 8;\n          int remaining_bits = seq_len % 8;\n\n          // Fill full blocks with 1s\n          for (int i = 0; i < full_blocks; ++i) {\n            thread_local_attn_mask_[thread_id][i] = 0xFF;\n          }\n          // Fill the remaining bits in the next block\n          if (remaining_bits > 0 && full_blocks < seq_len_ / 8) {\n            thread_local_attn_mask_[thread_id][full_blocks] = (1 << remaining_bits) - 1;\n          } else {\n            thread_local_attn_mask_[thread_id][full_blocks] = 0;\n          }\n\n          for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) {\n            thread_local_attn_mask_[thread_id][i] = 0;\n          }\n          if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n            attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,\n                                         (void*)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim +\n                                                           head_id * n_gqa_ * config_.head_dim],\n                                         seq_len_, 0, false, 
thread_local_attn_mask_[thread_id].data(), GGML_TYPE_F16,\n                                         0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,\n                                         GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                                         nullptr, nullptr, thread_local_attn_score_[thread_id].data(),\n                                         thread_local_output_fp32_[thread_id].data(),\n                                         thread_local_attn_lse_[thread_id].data(),\n                                         thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());\n          } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n            attn_with_kvcache_one_block_(\n                config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,\n                q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(),\n                GGML_TYPE_Q4_0, 0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,\n                GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,\n                thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(),\n                thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),\n                sin_.data());\n            dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),\n                                thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);\n          } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n            attn_with_kvcache_one_block_(\n                config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,\n                q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(),\n                
GGML_TYPE_Q8_0, 0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,\n                GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,\n                thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(),\n                thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),\n                sin_.data());\n            dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),\n                                thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);\n          }\n        } else {\n          if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n            attn_with_kvcache_one_block_(\n                config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,\n                (void*)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim +\n                                  head_id * n_gqa_ * config_.head_dim],\n                seq_len_, 0, true, nullptr, GGML_TYPE_F16, 0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                nullptr, nullptr, GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr,\n                nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_fp32_[thread_id].data(),\n                thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),\n                sin_.data());\n\n          } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n            attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,\n                                         q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0,\n                                         0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,\n                                  
       GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                                         nullptr, nullptr, thread_local_attn_score_[thread_id].data(),\n                                         thread_local_output_q8_0_[thread_id].data(),\n                                         thread_local_attn_lse_[thread_id].data(),\n                                         thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());\n            dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),\n                                thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);\n          } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n            attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,\n                                         q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0,\n                                         0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,\n                                         GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                                         nullptr, nullptr, thread_local_attn_score_[thread_id].data(),\n                                         thread_local_output_q8_0_[thread_id].data(),\n                                         thread_local_attn_lse_[thread_id].data(),\n                                         thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());\n            dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),\n                                thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);\n          }\n        }\n        int cur_batch_idx = thread_cur_head_idx_[thread_id].first;\n        int cur_head_id = thread_cur_head_idx_[thread_id].second;\n        if (batch_id == cur_batch_idx && head_id == cur_head_id) {\n  
        for (int i = 0; i < n_gqa_; i++) {\n            float new_attn_lse = thread_local_cur_attn_lse_[thread_id][i] +\n                                 std::log(1.0 + std::exp(thread_local_attn_lse_[thread_id][i] -\n                                                         thread_local_cur_attn_lse_[thread_id][i]));\n            ggml_vec_scale_f32(config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,\n                               std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));\n            ggml_vec_scale_f32(config_.head_dim, thread_local_output_fp32_[thread_id].data() + i * config_.head_dim,\n                               std::exp(thread_local_attn_lse_[thread_id][i] - new_attn_lse));\n            for (int j = 0; j < config_.head_dim; j++) {\n              thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j] +=\n                  thread_local_output_fp32_[thread_id][i * config_.head_dim + j];\n            }\n            thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse;\n          }\n        } else {\n          if (cur_batch_idx != -1) {\n            mutex_[cur_batch_idx][cur_head_id]->lock();\n            for (int i = 0; i < n_gqa_; i++) {\n              if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) {\n                attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i];\n                for (int j = 0; j < config_.head_dim; j++) {\n                  output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] =\n                      thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];\n                }\n                continue;\n              }\n              float new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] +\n                                   std::log(1.0 + std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                                           
attn_lse_[cur_batch_idx][cur_head_id][i]));\n              ggml_vec_scale_f32(config_.head_dim,\n                                 output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim,\n                                 std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse));\n              ggml_vec_scale_f32(config_.head_dim,\n                                 thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,\n                                 std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));\n              for (int j = 0; j < config_.head_dim; j++) {\n                output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] +=\n                    thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];\n              }\n              attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;\n            }\n            mutex_[cur_batch_idx][cur_head_id]->unlock();\n          }\n          thread_cur_head_idx_[thread_id].first = batch_id;\n          thread_cur_head_idx_[thread_id].second = head_id;\n          for (int i = 0; i < n_gqa_; i++) {\n            thread_local_cur_attn_lse_[thread_id][i] = thread_local_attn_lse_[thread_id][i];\n            for (int j = 0; j < config_.head_dim; j++) {\n              thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j] =\n                  thread_local_output_fp32_[thread_id][i * config_.head_dim + j];\n            }\n          }\n        }\n      },\n      // Merge the results of the remaining blocks.\n      [&](int thread_id) {\n        int cur_batch_idx = thread_cur_head_idx_[thread_id].first;\n        int cur_head_id = thread_cur_head_idx_[thread_id].second;\n        if (cur_head_id != -1) {\n          mutex_[cur_batch_idx][cur_head_id]->lock();\n          for (int i = 0; i < n_gqa_; i++) {\n            float new_attn_lse;\n            if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) {\n              
attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i];\n              for (int j = 0; j < config_.head_dim; j++) {\n                output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] =\n                    thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];\n              }\n              continue;\n            }\n            new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] +\n                           std::log(1.0 + std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                                   attn_lse_[cur_batch_idx][cur_head_id][i]));\n            ggml_vec_scale_f32(config_.head_dim, output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim,\n                               std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse));\n            ggml_vec_scale_f32(config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,\n                               std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));\n            for (int j = 0; j < config_.head_dim; j++) {\n              output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] +=\n                  thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];\n            }\n            attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;\n          }\n          mutex_[cur_batch_idx][cur_head_id]->unlock();\n        }\n      });\n\n  // move the results to output and attn_lse\n  uint16_t* output_data = reinterpret_cast<uint16_t*>(output);\n  float* attn_lse_data = attn_lse;\n  for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n    for (int i = 0; i < config_.kv_head_num; i++) {\n      for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {\n        output_data[batch_idx * config_.kv_head_num * n_gqa_ * config_.head_dim + i * n_gqa_ * config_.head_dim + j] =\n            
GGML_FP32_TO_FP16(output_fp32_[batch_idx][i][j]);\n      }\n      for (int j = 0; j < n_gqa_; j++) {\n        attn_lse_data[batch_idx * config_.kv_head_num * n_gqa_ + i * n_gqa_ + j] = attn_lse_[batch_idx][i][j];\n      }\n    }\n  }\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  std::chrono::duration<double> diff = end - start;\n  //     printf(\"layer %d time of computing attention: %f s\\n\", layer_id_,\n  //     diff.count());\n}\n\nvoid KVCache::attn(const ggml_fp16_t* q_in, ggml_fp16_t* output, float* attn_lse, int layer_idx, int generate_token_idx,\n                   int q_len, int batch_size, int max_block_num, int* block_table, int* cache_seqlens,\n                   int pick_block_num, int init_block_num, int local_block_num, WorkerPool* backend) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n  layer_id_ = layer_idx;\n  batch_size = batch_size * q_len;\n\n  const uint16_t* q_in_data = const_cast<const uint16_t*>(q_in);\n\n  quantize_q_(q_in_data, batch_size);\n  if (config_.retrieval_type == RetrievalType::LAYER) {\n    attn_initialize_layer_(batch_size, layer_idx, block_table, max_block_num, cache_seqlens);\n    retrieval_kvcache_layer_(q_in_data, init_block_num, local_block_num, pick_block_num, q_len, generate_token_idx,\n                             batch_size, layer_idx, cache_seqlens, max_block_num, backend);\n    attention_layer_(q_in_data, output, attn_lse, batch_size, backend);\n  } else if (config_.retrieval_type == RetrievalType::KVHEAD) {\n    attn_initialize_kvhead_(batch_size, layer_idx, block_table, max_block_num, cache_seqlens);\n    retrieval_kvcache_kvhead_(q_in_data, init_block_num, local_block_num, pick_block_num, q_len, generate_token_idx,\n                              batch_size, layer_idx, cache_seqlens, max_block_num, backend);\n    attention_kvhead_(q_in_data, output, attn_lse, batch_size, backend);\n  }\n\n  // Timer end\n  auto end = 
std::chrono::high_resolution_clock::now();\n  std::chrono::duration<double> diff = end - start;\n  // printf(\"layer %d time of computing attention: %f s\\n\", layer_idx,\n  //        diff.count());\n}\n\nvoid KVCache::attn_with_kvcache(const ggml_fp16_t* q_in, const ggml_fp16_t* k_in, const ggml_fp16_t* v_in,\n                                ggml_fp16_t* output, float* attn_lse, int layer_idx, int generate_token_idx, int q_len,\n                                int batch_size, int max_block_num, int* block_table, int* cache_seqlens, int topk,\n                                int local, WorkerPool* backend) {\n  //    printf(\"attn_with_kvcache start\\n\");\n  assert(q_len == 1);\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n\n  layer_id_ = layer_idx;\n\n  update_kvcache_fp16(k_in, v_in, layer_idx, block_table, batch_size, max_block_num, cache_seqlens, q_len, backend);\n  //    printf(\"update finished.\\n\");\n\n  // The caller's cache_seqlens array is updated in place to include the new tokens.\n  for (int i = 0; i < batch_size; i++) {\n    cache_seqlens[i] += q_len;\n  }\n  int init_block_num = 1;\n  if (config_.block_len <= 32) {\n    init_block_num = 64 / config_.block_len;\n  }\n\n  attn(q_in, output, attn_lse, layer_idx, generate_token_idx, q_len, batch_size, max_block_num, block_table,\n       cache_seqlens, topk, init_block_num, local, backend);\n\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  std::chrono::duration<double> diff = end - start;\n  //     printf(\"layer %d time of computing attention with kvcache: %f s\\n\",\n  //     layer_idx, diff.count());\n}\n\nvoid KVCache::quantize_q_(const uint16_t* q_in_data, int batch_size) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n  for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n    if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n      // F16 cache: no quantization, only convert q from fp16 to fp32\n      for (int i = 0; i < config_.kv_head_num; i++) {\n        for (int j = 0; j < n_gqa_ * 
config_.head_dim; j++) {\n          q_fp32_[batch_idx][i][j] =\n              GGML_FP16_TO_FP32(q_in_data[batch_idx * config_.kv_head_num * n_gqa_ * config_.head_dim +\n                                          i * n_gqa_ * config_.head_dim + j]);\n        }\n      }\n    } else {\n      // quantize q\n      for (int i = 0; i < config_.kv_head_num; i++) {\n        for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {\n          q_fp32[j] = GGML_FP16_TO_FP32(q_in_data[batch_idx * config_.kv_head_num * n_gqa_ * config_.head_dim +\n                                                  i * n_gqa_ * config_.head_dim + j]);\n        }\n        quantize_row_q8_0(q_fp32.data(), q_q8_0_[batch_idx][i].data(), n_gqa_ * config_.head_dim);\n      }\n    }\n  }\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  // printf(\"time of quantizing q: %f s\\n\",\n  //        std::chrono::duration<double>(end - start).count());\n}\nvoid KVCache::attn_initialize_layer_(int batch_size, int layer_idx, int* block_table, int& max_block_num,\n                                     int* cache_seqlens) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n  for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n    // initialize output_fp32_ and attn_lse_\n    for (int i = 0; i < config_.kv_head_num; i++) {\n      for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {\n        output_fp32_[batch_idx][i][j] = 0;\n      }\n      for (int j = 0; j < n_gqa_; j++) {\n        attn_lse_[batch_idx][i][j] = 0;\n      }\n    }\n    // clear top_similar_block_\n\n    while (!top_similar_block_[batch_idx].empty()) top_similar_block_[batch_idx].pop();\n  }\n\n  // get block_table_before_retrieval_ and cache_seqlens_\n  if (block_table == nullptr) {\n    max_block_num = past_block_num_[layer_idx];\n    for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n      if (cache_total_len_ != 0)\n        cache_seqlens_[batch_idx] = cache_total_len_;\n    
  else\n        cache_seqlens_[batch_idx] = max_block_num * config_.block_len;\n      for (int i = 0; i < max_block_num; i++) {\n        block_table_before_retrieval_[batch_idx][i] = i;\n        block_similar_[batch_idx][i] = 0;\n      }\n    }\n  } else {\n    for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n      cache_seqlens_[batch_idx] = cache_seqlens[batch_idx];\n      for (int i = 0; i < max_block_num; i++) {\n        block_table_before_retrieval_[batch_idx][i] = block_table[batch_idx * max_block_num + i];\n        block_similar_[batch_idx][i] = 0;\n      }\n    }\n  }\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  // printf(\"layer %d time of initializing attention: %f s\\n\", layer_idx,\n  //        std::chrono::duration<double>(end - start).count());\n}\n\nvoid KVCache::calculate_block_similarity_layer_(const uint16_t* q_in_data, int batch_size, int layer_idx, int q_len,\n                                                int max_block_num, int* cache_seqlens, int init_block_num,\n                                                int local_block_num, int pick_block_num, WorkerPool* backend) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n\n  if (batch_size == 1 && config_.anchor_num == 1) {  // TODO: improve batch_size > 1\n    for (int batch_id = 0; batch_id < batch_size; batch_id++) {\n      if (q_len == 1) {\n        for (int j = 0; j < config_.head_dim * config_.q_head_num; j++) {\n          avg_q[batch_id][j] =\n              GGML_FP16_TO_FP32(q_in_data[batch_id * q_len * config_.q_head_num * config_.head_dim + j]);\n          avg_q_fp16[batch_id][j] = q_in_data[batch_id * q_len * config_.q_head_num * config_.head_dim + j];\n        }\n      } else {\n        for (int j = 0; j < config_.head_dim * config_.q_head_num; j++) {\n          avg_q[batch_id][j] = 0;\n        }\n        for (int i = 0; i < q_len; i++) {\n          for (int j = 0; j < config_.head_dim * config_.q_head_num; j++) {\n            
avg_q[batch_id][j] += GGML_FP16_TO_FP32(q_in_data[batch_id * q_len * config_.q_head_num * config_.head_dim +\n                                                              i * config_.q_head_num * config_.head_dim + j]);\n          }\n        }\n        for (int j = 0; j < config_.head_dim * config_.q_head_num; j++) {\n          avg_q[batch_id][j] /= q_len;\n          avg_q_fp16[batch_id][j] = GGML_FP32_TO_FP16(avg_q[batch_id][j]);\n        }\n      }\n      int seq_len = cache_seqlens_[batch_id];\n      int block_num = (seq_len / config_.block_len) - local_block_num - init_block_num;\n      if (block_num <= 0) {\n        continue;\n      }\n      bool is_seq = true;\n      for (int i = init_block_num + 1; i < (seq_len / config_.block_len) - local_block_num; i++) {\n        if (block_table_before_retrieval_[batch_id][i] != block_table_before_retrieval_[batch_id][i - 1] + 1) {\n          is_seq = false;\n          break;\n        }\n      }\n      if (is_seq) {\n        int nth = backend->get_thread_num();\n        backend->do_work_stealing_job(\n            nth, nullptr,\n            [&](int task_id) {\n              int ith = task_id;\n              bool ok = llamafile_sgemm(\n                  block_num, 1, config_.q_head_num * config_.head_dim,\n                  anchor_.data() +\n                      (layer_idx * config_.max_block_num + block_table_before_retrieval_[batch_id][init_block_num]) *\n                          config_.anchor_num * config_.q_head_num * config_.head_dim,\n                  config_.q_head_num * config_.head_dim, avg_q_fp16[batch_id].data(),\n                  config_.q_head_num * config_.head_dim, block_similar_[batch_id].data() + init_block_num, block_num,\n                  ith, nth, GGML_TASK_TYPE_COMPUTE, GGML_TYPE_F16, GGML_TYPE_F16, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n              if (!ok) {\n                printf(\"llamafile_sgemm failed\\n\");\n              }\n            },\n            nullptr);\n      } else {\n        
backend->do_work_stealing_job(\n            block_num, nullptr,\n            [&](int task_id) {\n              int block_id = task_id + init_block_num;\n              int block_idx = block_table_before_retrieval_[batch_id][block_id];\n              bool ok = llamafile_sgemm(\n                  1, 1, config_.q_head_num * config_.head_dim,\n                  anchor_.data() +\n                      (layer_idx * config_.max_block_num + block_idx) *\n                          config_.anchor_num * config_.q_head_num * config_.head_dim,\n                  config_.q_head_num * config_.head_dim, avg_q_fp16[batch_id].data(),\n                  config_.q_head_num * config_.head_dim, block_similar_[batch_id].data() + block_id, 1, 0, 1,\n                  GGML_TASK_TYPE_COMPUTE, GGML_TYPE_F16, GGML_TYPE_F16, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n              if (!ok) {\n                printf(\"llamafile_sgemm failed\\n\");\n              }\n            },\n            nullptr);\n      }\n    }\n  } else {\n    backend->do_work_stealing_job(\n        batch_size * max_block_num, nullptr,\n        [&](int task_id) {\n          int batch_id = task_id / max_block_num;\n          int block_id = task_id % max_block_num;\n          int seq_len = cache_seqlens_[batch_id];\n\n          if (block_id < init_block_num || block_id >= (seq_len / config_.block_len) - local_block_num) {\n            return;\n          }\n\n          int block_idx = block_table_before_retrieval_[batch_id][block_id];\n          float sim = 0;\n\n          for (int head_id = 0; head_id < config_.q_head_num; head_id++) {\n            for (int i = 0; i < config_.head_dim; i++) {\n              float q_i = 0, qa_i = std::numeric_limits<float>::lowest();\n              for (int q_id = 0; q_id < q_len; q_id++) {\n                q_i += GGML_FP16_TO_FP32(\n                    q_in_data[batch_id * q_len * config_.q_head_num * config_.head_dim +\n                              q_id 
* config_.q_head_num * config_.head_dim + head_id * config_.head_dim + i]);\n              }\n              q_i /= q_len;\n              for (int anchor_id = 0; anchor_id < config_.anchor_num; anchor_id++) {\n                qa_i = std::max(\n                    qa_i,\n                    GGML_FP16_TO_FP32(\n                        anchor_[(long long)layer_idx * config_.max_block_num * config_.anchor_num * config_.q_head_num *\n                                    config_.head_dim +\n                                block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                                anchor_id * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + i]) *\n                        q_i);\n              }\n              sim += qa_i;\n            }\n          }\n          block_similar_[batch_id][block_id] = sim;\n        },\n        nullptr);\n  }\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  std::chrono::duration<double> diff = end - start;\n  // printf(\"layer %d time of calculating similarity: %f s\\n\", layer_idx,\n  //        diff.count());\n}\n\nvoid KVCache::select_block_layer_(int batch_size, int layer_idx, int max_block_num, int init_block_num,\n                                  int local_block_num, int pick_block_num) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n\n  for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n    if (cache_seqlens_[batch_idx] / config_.block_len <= init_block_num + pick_block_num + local_block_num) {\n      block_table_after_retrieval_[batch_idx].swap(block_table_before_retrieval_[batch_idx]);\n      selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] = 0;\n      continue;\n    }\n\n    for (int block_id = init_block_num; block_id < (cache_seqlens_[batch_idx] / config_.block_len) - local_block_num;\n         block_id++) {\n      top_similar_block_[batch_idx].push(\n          
std::make_pair(block_similar_[batch_idx][block_id], block_table_before_retrieval_[batch_idx][block_id]));\n      if (top_similar_block_[batch_idx].size() > pick_block_num) {\n        top_similar_block_[batch_idx].pop();\n      }\n    }\n\n    int i = 0;\n    for (; i < init_block_num; i++) {\n      block_table_after_retrieval_[batch_idx][i] = block_table_before_retrieval_[batch_idx][i];\n    }\n    while (!top_similar_block_[batch_idx].empty()) {\n      block_table_after_retrieval_[batch_idx][i] = top_similar_block_[batch_idx].top().second;\n      top_similar_block_[batch_idx].pop();\n      i++;\n    }\n    for (; i < init_block_num + pick_block_num + local_block_num; i++) {\n      block_table_after_retrieval_[batch_idx][i] =\n          block_table_before_retrieval_[batch_idx][(cache_seqlens_[batch_idx] / config_.block_len) - local_block_num +\n                                                   i - init_block_num - pick_block_num];\n    }\n    if (cache_seqlens_[batch_idx] % config_.block_len != 0) {\n      block_table_after_retrieval_[batch_idx][i] =\n          block_table_before_retrieval_[batch_idx][(cache_seqlens_[batch_idx] / config_.block_len)];\n      cache_seqlens_[batch_idx] = (cache_seqlens_[batch_idx] % config_.block_len) + i * config_.block_len;\n      i++;\n    } else {\n      cache_seqlens_[batch_idx] = (cache_seqlens_[batch_idx] % config_.block_len) + i * config_.block_len;\n    }\n    for (int j = 0; j < i; j++) {\n      selected_blocks_history_[(layer_idx - config_.layer_offset) / config_.layer_step][batch_idx][j] =\n          block_table_after_retrieval_[batch_idx][j];\n    }\n    selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] = i;\n  }\n\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  std::chrono::duration<double> diff = end - start;\n  // printf(\"layer %d time of selecting blocks: %f s\\n\", layer_idx,\n  //        diff.count());\n}\n\n// retrieval kvcache, get the 
init_block_num blocks at the beginning, top\n// pick_block_num similar and last local_block_num blocks. Each task\n// calculates the similarity of a certain block with the query, then pushes\n// the block into the priority queue. Finally, the required blocks are\n// pushed into the block_table_after_retrieval_.\nvoid KVCache::retrieval_kvcache_layer_(const uint16_t* q_in_data, int init_block_num, int local_block_num,\n                                       int pick_block_num, int q_len, int generate_token_idx, int batch_size,\n                                       int layer_idx, int* cache_seqlens, int& max_block_num, WorkerPool* backend) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n  max_block_num_after_retrieval_ = 0;\n  if (pick_block_num != -1 &&\n      (generate_token_idx % config_.token_step != 0 || (layer_idx % config_.layer_step != config_.layer_offset))) {\n    if (selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] == 0) {\n      max_block_num_after_retrieval_ = max_block_num;\n      block_table_after_retrieval_.swap(block_table_before_retrieval_);\n    } else {\n      max_block_num_after_retrieval_ =\n          selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step];\n      for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n        for (int i = 0; i < max_block_num_after_retrieval_; i++) {\n          block_table_after_retrieval_[batch_idx][i] =\n              selected_blocks_history_[(layer_idx - config_.layer_offset) / config_.layer_step][batch_idx][i];\n        }\n\n        if (cache_seqlens[batch_idx] % config_.block_len == 1) {\n          selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] += 1;\n          int x = selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step];\n          int last_block_idx = block_table_before_retrieval_[batch_idx][cache_seqlens[batch_idx] / 
config_.block_len];\n          selected_blocks_history_[(layer_idx - config_.layer_offset) / config_.layer_step][batch_idx][x - 1] =\n              last_block_idx;\n          block_table_after_retrieval_[batch_idx][x - 1] = last_block_idx;\n        }\n        cache_seqlens_[batch_idx] =\n            (cache_seqlens_[batch_idx] % config_.block_len) +\n            selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] * config_.block_len -\n            config_.block_len;\n      }\n    }\n  } else if (pick_block_num != -1) {\n    max_block_num_after_retrieval_ = std::min(max_block_num, init_block_num + pick_block_num + local_block_num + 1);\n    calculate_block_similarity_layer_(q_in_data, batch_size, layer_idx, q_len, max_block_num, cache_seqlens,\n                                      init_block_num, local_block_num, pick_block_num, backend);\n    select_block_layer_(batch_size, layer_idx, max_block_num, init_block_num, local_block_num, pick_block_num);\n  } else {\n    selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] = 0;\n    max_block_num_after_retrieval_ = max_block_num;\n    block_table_after_retrieval_.swap(block_table_before_retrieval_);\n  }\n\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  //     printf(\"layer %d time of retrieval kvcache: %f s\\n\", layer_idx,\n  //     std::chrono::duration<double>(end - start).count());\n}\nvoid KVCache::calculate_sparsity_layer_(const uint16_t* q_in_data, float* attn_sparsity, int batch_size,\n                                        int max_block_num, int* block_table, int* cache_seqlens, WorkerPool* backend\n\n) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n  seq_len_ = config_.block_len;\n  backend->do_work_stealing_job(\n      batch_size * config_.kv_head_num * max_block_num,\n      [&](int thread_id) {\n        thread_cur_head_idx_[thread_id].first = -1;\n        
thread_cur_head_idx_[thread_id].second = -1;\n      },\n      [&](int task_id) {\n        int batch_id = task_id / (config_.kv_head_num * max_block_num);\n        int head_id = (task_id % (config_.kv_head_num * max_block_num)) / max_block_num;\n        int block_id = task_id % max_block_num;\n        int thread_id = WorkerPool::thread_local_id;\n        // If the block is out of the sequence length, skip it.\n        if (cache_seqlens[batch_id] / config_.block_len < block_id) {\n          return;\n        }\n        int block_idx = block_table[batch_id * max_block_num + block_id];\n        if (cache_seqlens_[batch_id] / config_.block_len == block_id) {\n          int seq_len = cache_seqlens_[batch_id] % config_.block_len;\n          if (seq_len == 0) return;\n\n          // Prepare the attention mask for the last block.\n          int full_blocks = seq_len / 8;\n          int remaining_bits = seq_len % 8;\n          // Fill full blocks with 1s\n          for (int i = 0; i < full_blocks; ++i) {\n            thread_local_attn_mask_[thread_id][i] = 0xFF;\n          }\n          // Fill the remaining bits in the next block\n          if (remaining_bits > 0 && full_blocks < seq_len_ / 8) {\n            thread_local_attn_mask_[thread_id][full_blocks] = (1 << remaining_bits) - 1;\n          } else {\n            thread_local_attn_mask_[thread_id][full_blocks] = 0;\n          }\n\n          for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) {\n            thread_local_attn_mask_[thread_id][i] = 0;\n          }\n          if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n            attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,\n                                         (void*)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim +\n                                                           head_id * n_gqa_ * config_.head_dim],\n                                         seq_len_, 0, false, 
thread_local_attn_mask_[thread_id].data(), GGML_TYPE_F16,\n                                         0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,\n                                         GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                                         nullptr, nullptr, thread_local_attn_score_[thread_id].data(),\n                                         thread_local_output_fp32_[thread_id].data(),\n                                         thread_local_attn_lse_[thread_id].data(),\n                                         thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());\n          } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n            attn_with_kvcache_one_block_(\n                config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,\n                q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(),\n                GGML_TYPE_Q4_0, 0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,\n                GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,\n                thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(),\n                thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),\n                sin_.data());\n            dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),\n                                thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);\n          } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n            attn_with_kvcache_one_block_(\n                config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,\n                q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(),\n                
GGML_TYPE_Q8_0, 0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,\n                GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,\n                thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(),\n                thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),\n                sin_.data());\n            dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),\n                                thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);\n          }\n        } else {\n          if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n            attn_with_kvcache_one_block_(\n                config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,\n                (void*)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim +\n                                  head_id * n_gqa_ * config_.head_dim],\n                seq_len_, 0, true, nullptr, GGML_TYPE_F16, 0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                nullptr, nullptr, GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr,\n                nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_fp32_[thread_id].data(),\n                thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),\n                sin_.data());\n\n          } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n            attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,\n                                         q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0,\n                                         0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,\n                                  
       GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                                         nullptr, nullptr, thread_local_attn_score_[thread_id].data(),\n                                         thread_local_output_q8_0_[thread_id].data(),\n                                         thread_local_attn_lse_[thread_id].data(),\n                                         thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());\n            dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),\n                                thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);\n          } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n            attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,\n                                         q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0,\n                                         0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,\n                                         GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                                         nullptr, nullptr, thread_local_attn_score_[thread_id].data(),\n                                         thread_local_output_q8_0_[thread_id].data(),\n                                         thread_local_attn_lse_[thread_id].data(),\n                                         thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());\n            dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),\n                                thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);\n          }\n        }\n        for (int i = 0; i < n_gqa_; i++) {\n          block_lse_[batch_id][block_idx][head_id * n_gqa_ + i] = thread_local_attn_lse_[thread_id][i];\n        }\n        int cur_batch_idx = 
thread_cur_head_idx_[thread_id].first;\n        int cur_head_id = thread_cur_head_idx_[thread_id].second;\n        if (batch_id == cur_batch_idx && head_id == cur_head_id) {\n          for (int i = 0; i < n_gqa_; i++) {\n            float new_attn_lse = thread_local_cur_attn_lse_[thread_id][i] +\n                                 std::log(1.0 + std::exp(thread_local_attn_lse_[thread_id][i] -\n                                                         thread_local_cur_attn_lse_[thread_id][i]));\n            ggml_vec_scale_f32(config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,\n                               std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));\n            ggml_vec_scale_f32(config_.head_dim, thread_local_output_fp32_[thread_id].data() + i * config_.head_dim,\n                               std::exp(thread_local_attn_lse_[thread_id][i] - new_attn_lse));\n            for (int j = 0; j < config_.head_dim; j++) {\n              thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j] +=\n                  thread_local_output_fp32_[thread_id][i * config_.head_dim + j];\n            }\n            thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse;\n          }\n        } else {\n          if (cur_batch_idx != -1) {\n            mutex_[cur_batch_idx][cur_head_id]->lock();\n            for (int i = 0; i < n_gqa_; i++) {\n              if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) {\n                attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i];\n                for (int j = 0; j < config_.head_dim; j++) {\n                  output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] =\n                      thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];\n                }\n                continue;\n              }\n              float new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] +\n                 
                  std::log(1.0 + std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                                           attn_lse_[cur_batch_idx][cur_head_id][i]));\n              ggml_vec_scale_f32(config_.head_dim,\n                                 output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim,\n                                 std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse));\n              ggml_vec_scale_f32(config_.head_dim,\n                                 thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,\n                                 std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));\n              for (int j = 0; j < config_.head_dim; j++) {\n                output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] +=\n                    thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];\n              }\n              attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;\n            }\n            mutex_[cur_batch_idx][cur_head_id]->unlock();\n          }\n          thread_cur_head_idx_[thread_id].first = batch_id;\n          thread_cur_head_idx_[thread_id].second = head_id;\n          for (int i = 0; i < n_gqa_; i++) {\n            thread_local_cur_attn_lse_[thread_id][i] = thread_local_attn_lse_[thread_id][i];\n            for (int j = 0; j < config_.head_dim; j++) {\n              thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j] =\n                  thread_local_output_fp32_[thread_id][i * config_.head_dim + j];\n            }\n          }\n        }\n      },\n      // Merge the results of the remaining blocks.\n      [&](int thread_id) {\n        int cur_batch_idx = thread_cur_head_idx_[thread_id].first;\n        int cur_head_id = thread_cur_head_idx_[thread_id].second;\n        if (cur_head_id != -1) {\n          mutex_[cur_batch_idx][cur_head_id]->lock();\n          for (int i = 0; i < 
n_gqa_; i++) {\n            float new_attn_lse;\n            if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) {\n              attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i];\n              for (int j = 0; j < config_.head_dim; j++) {\n                output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] =\n                    thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];\n              }\n              continue;\n            }\n            new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] +\n                           std::log(1.0 + std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                                   attn_lse_[cur_batch_idx][cur_head_id][i]));\n            ggml_vec_scale_f32(config_.head_dim, output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim,\n                               std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse));\n            ggml_vec_scale_f32(config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,\n                               std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));\n            for (int j = 0; j < config_.head_dim; j++) {\n              output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] +=\n                  thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];\n            }\n            attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;\n          }\n          mutex_[cur_batch_idx][cur_head_id]->unlock();\n        }\n      });\n\n  for (int i = 0; i < batch_size; i++) {\n    for (int j = 0; j < max_block_num_after_retrieval_; j++) {\n      int block_idx = block_table_after_retrieval_[i][j];\n      for (int k = 0; k < config_.q_head_num; k++) {\n        attn_sparsity[i * config_.q_head_num + k] +=\n            std::exp(block_lse_[i][block_idx][k] - attn_lse_[i][k / n_gqa_][k % n_gqa_]);\n  
    }\n    }\n  }\n\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  std::chrono::duration<double> diff = end - start;\n  // printf(\"layer %d time of calculating sparsity: %f s\\n\", layer_id_,\n  //        diff.count());\n}\n\nvoid KVCache::attn_initialize_kvhead_(int batch_size, int layer_idx, int* block_table, int& max_block_num,\n                                      int* cache_seqlens) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n  for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n    // initialize output_fp32_ and attn_lse_\n    for (int i = 0; i < config_.kv_head_num; i++) {\n      for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {\n        output_fp32_[batch_idx][i][j] = 0;\n      }\n      for (int j = 0; j < n_gqa_; j++) {\n        attn_lse_[batch_idx][i][j] = 0;\n      }\n    }\n\n    // clear top_similar_block_\n    while (!top_similar_block_[batch_idx].empty()) top_similar_block_[batch_idx].pop();\n  }\n\n  for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n    cache_seqlens_[batch_idx] = cache_seqlens[batch_idx];\n    for (int i = 0; i < max_block_num; i++) {\n      for (int j = 0; j < config_.kv_head_num; j++) {\n        block_table_before_retrieval_kvhead_[batch_idx][i][j] = block_table[batch_idx * max_block_num + i];\n        block_similar_kv_head_[batch_idx][i][j] = 0;\n      }\n    }\n  }\n\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  // printf(\"layer %d time of initializing attn: %f s\\n\", layer_idx,\n  //        std::chrono::duration<double>(end - start).count());\n}\nvoid KVCache::retrieval_kvcache_kvhead_(const uint16_t* q_in_data, int init_block_num, int local_block_num,\n                                        int pick_block_num, int q_len, int generate_token_idx, int batch_size,\n                                        int layer_idx, int* cache_seqlens, int& max_block_num, WorkerPool* backend) {\n  // Timer start\n  
auto start = std::chrono::high_resolution_clock::now();\n  max_block_num_after_retrieval_ = 0;\n  if (pick_block_num != -1 &&\n      (generate_token_idx % config_.token_step != 0 || (layer_idx % config_.layer_step != config_.layer_offset))) {\n    if (selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] == 0) {\n      max_block_num_after_retrieval_ = max_block_num;\n      for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n        for (int i = 0; i < max_block_num; i++) {\n          for (int j = 0; j < config_.kv_head_num; j++) {\n            block_table_after_retrieval_kvhead_[batch_idx][i][j] =\n                block_table_before_retrieval_kvhead_[batch_idx][i][j];\n          }\n        }\n      }\n    } else {\n      max_block_num_after_retrieval_ =\n          selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step];\n\n      for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n        for (int i = 0; i < max_block_num_after_retrieval_; i++) {\n          for (int j = 0; j < config_.kv_head_num; j++) {\n            block_table_after_retrieval_kvhead_[batch_idx][i][j] =\n                selected_blocks_history_kvhead_[(layer_idx - config_.layer_offset) / config_.layer_step][batch_idx][i]\n                                               [j];\n          }\n        }\n\n        if (cache_seqlens[batch_idx] % config_.block_len == 1) {\n          selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] += 1;\n          int x = selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step];\n          for (int i = 0; i < config_.kv_head_num; i++) {\n            int last_block_idx =\n                block_table_before_retrieval_kvhead_[batch_idx][cache_seqlens[batch_idx] / config_.block_len][i];\n            selected_blocks_history_kvhead_[(layer_idx - config_.layer_offset) / config_.layer_step][batch_idx][x - 1]\n                        
                   [i] = last_block_idx;\n            block_table_after_retrieval_kvhead_[batch_idx][x - 1][i] = last_block_idx;\n          }\n        }\n        cache_seqlens_[batch_idx] = std::min(\n            cache_seqlens_[batch_idx], (cache_seqlens_[batch_idx] % config_.block_len) +\n                                           (init_block_num + pick_block_num + local_block_num) * config_.block_len);\n      }\n    }\n  } else if (pick_block_num != -1) {\n    max_block_num_after_retrieval_ = std::min(max_block_num, init_block_num + pick_block_num + local_block_num + 1);\n    calculate_block_similarity_kvhead_(q_in_data, batch_size, layer_idx, q_len, max_block_num, cache_seqlens,\n                                       init_block_num, local_block_num, pick_block_num, backend);\n    select_block_kvhead_(batch_size, layer_idx, max_block_num, init_block_num, local_block_num, pick_block_num);\n  } else {\n    selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] = 0;\n    max_block_num_after_retrieval_ = max_block_num;\n    for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n      for (int i = 0; i < max_block_num; i++) {\n        for (int j = 0; j < config_.kv_head_num; j++) {\n          block_table_after_retrieval_kvhead_[batch_idx][i][j] = block_table_before_retrieval_kvhead_[batch_idx][i][j];\n        }\n      }\n    }\n  }\n\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  // printf(\"layer %d time of retrieval kvcache: %f s\\n\", layer_idx,\n  //        std::chrono::duration<double>(end - start).count());\n}\nvoid KVCache::calculate_sparsity_kvhead_(const uint16_t* q_in_data, float* attn_sparsity, int batch_size,\n                                         int max_block_num, int* block_table, int* cache_seqlens, WorkerPool* backend) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n  seq_len_ = config_.block_len;\n  backend->do_work_stealing_job(\n      batch_size 
* config_.kv_head_num * max_block_num,\n      [&](int thread_id) {\n        thread_cur_head_idx_[thread_id].first = -1;\n        thread_cur_head_idx_[thread_id].second = -1;\n      },\n      [&](int task_id) {\n        int batch_id = task_id / (config_.kv_head_num * max_block_num);\n        int head_id = (task_id % (config_.kv_head_num * max_block_num)) / max_block_num;\n        int block_id = task_id % max_block_num;\n        int thread_id = WorkerPool::thread_local_id;\n        // If the block is out of the sequence length, skip it.\n        if (cache_seqlens[batch_id] / config_.block_len < block_id) {\n          return;\n        }\n        int block_idx = block_table[batch_id * max_block_num + block_id];\n        if (cache_seqlens_[batch_id] / config_.block_len == block_id) {\n          int seq_len = cache_seqlens_[batch_id] % config_.block_len;\n          if (seq_len == 0) return;\n\n          // Prepare the attention mask for the last block.\n          int full_blocks = seq_len / 8;\n          int remaining_bits = seq_len % 8;\n\n          // Fill full blocks with 1s\n          for (int i = 0; i < full_blocks; ++i) {\n            thread_local_attn_mask_[thread_id][i] = 0xFF;\n          }\n          // Fill the remaining bits in the next block\n          if (remaining_bits > 0 && full_blocks < seq_len_ / 8) {\n            thread_local_attn_mask_[thread_id][full_blocks] = (1 << remaining_bits) - 1;\n          } else {\n            thread_local_attn_mask_[thread_id][full_blocks] = 0;\n          }\n\n          for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) {\n            thread_local_attn_mask_[thread_id][i] = 0;\n          }\n          if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n            attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,\n                                         (void*)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim +\n                                      
                     head_id * n_gqa_ * config_.head_dim],\n                                         seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(), GGML_TYPE_F16,\n                                         0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,\n                                         GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                                         nullptr, nullptr, thread_local_attn_score_[thread_id].data(),\n                                         thread_local_output_fp32_[thread_id].data(),\n                                         thread_local_attn_lse_[thread_id].data(),\n                                         thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());\n          } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n            attn_with_kvcache_one_block_(\n                config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,\n                q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(),\n                GGML_TYPE_Q4_0, 0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,\n                GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,\n                thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(),\n                thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),\n                sin_.data());\n            dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),\n                                thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);\n          } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n            attn_with_kvcache_one_block_(\n                config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,\n           
     q_q8_0_[batch_id][head_id].data(), seq_len_, 0, false, thread_local_attn_mask_[thread_id].data(),\n                GGML_TYPE_Q8_0, 0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,\n                GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,\n                thread_local_attn_score_[thread_id].data(), thread_local_output_q8_0_[thread_id].data(),\n                thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),\n                sin_.data());\n            dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),\n                                thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);\n          }\n        } else {\n          if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n            attn_with_kvcache_one_block_(\n                config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,\n                (void*)&q_in_data[batch_id * config_.kv_head_num * n_gqa_ * config_.head_dim +\n                                  head_id * n_gqa_ * config_.head_dim],\n                seq_len_, 0, true, nullptr, GGML_TYPE_F16, 0, k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                nullptr, nullptr, GGML_TYPE_F16, 1, v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0, nullptr,\n                nullptr, thread_local_attn_score_[thread_id].data(), thread_local_output_fp32_[thread_id].data(),\n                thread_local_attn_lse_[thread_id].data(), thread_local_draft_[thread_id].data(), nullptr, cos_.data(),\n                sin_.data());\n\n          } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n            attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,\n                                         q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0,\n                              
           0, k_cache_q4[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,\n                                         GGML_TYPE_Q4_0, 1, v_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                                         nullptr, nullptr, thread_local_attn_score_[thread_id].data(),\n                                         thread_local_output_q8_0_[thread_id].data(),\n                                         thread_local_attn_lse_[thread_id].data(),\n                                         thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());\n            dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),\n                                thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);\n          } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n            attn_with_kvcache_one_block_(config_.head_dim, config_.q_head_num / config_.kv_head_num, GGML_TYPE_Q8_0,\n                                         q_q8_0_[batch_id][head_id].data(), seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0,\n                                         0, k_cache_q8[layer_id_][head_id][block_idx].data(), 0, nullptr, nullptr,\n                                         GGML_TYPE_Q8_0, 1, v_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                                         nullptr, nullptr, thread_local_attn_score_[thread_id].data(),\n                                         thread_local_output_q8_0_[thread_id].data(),\n                                         thread_local_attn_lse_[thread_id].data(),\n                                         thread_local_draft_[thread_id].data(), nullptr, cos_.data(), sin_.data());\n            dequantize_row_q8_0(thread_local_output_q8_0_[thread_id].data(),\n                                thread_local_output_fp32_[thread_id].data(), n_gqa_ * config_.head_dim);\n          }\n        }\n        for (int i = 0; i < n_gqa_; i++) {\n          
block_lse_[batch_id][block_idx][head_id * n_gqa_ + i] = thread_local_attn_lse_[thread_id][i];\n        }\n        int cur_batch_idx = thread_cur_head_idx_[thread_id].first;\n        int cur_head_id = thread_cur_head_idx_[thread_id].second;\n        if (batch_id == cur_batch_idx && head_id == cur_head_id) {\n          for (int i = 0; i < n_gqa_; i++) {\n            float new_attn_lse = thread_local_cur_attn_lse_[thread_id][i] +\n                                 std::log(1.0 + std::exp(thread_local_attn_lse_[thread_id][i] -\n                                                         thread_local_cur_attn_lse_[thread_id][i]));\n            ggml_vec_scale_f32(config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,\n                               std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));\n            ggml_vec_scale_f32(config_.head_dim, thread_local_output_fp32_[thread_id].data() + i * config_.head_dim,\n                               std::exp(thread_local_attn_lse_[thread_id][i] - new_attn_lse));\n            for (int j = 0; j < config_.head_dim; j++) {\n              thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j] +=\n                  thread_local_output_fp32_[thread_id][i * config_.head_dim + j];\n            }\n            thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse;\n          }\n        } else {\n          if (cur_batch_idx != -1) {\n            mutex_[cur_batch_idx][cur_head_id]->lock();\n            for (int i = 0; i < n_gqa_; i++) {\n              if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) {\n                attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i];\n                for (int j = 0; j < config_.head_dim; j++) {\n                  output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] =\n                      thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];\n                }\n      
          continue;\n              }\n              float new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] +\n                                   std::log(1.0 + std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                                           attn_lse_[cur_batch_idx][cur_head_id][i]));\n              ggml_vec_scale_f32(config_.head_dim,\n                                 output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim,\n                                 std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse));\n              ggml_vec_scale_f32(config_.head_dim,\n                                 thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,\n                                 std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));\n              for (int j = 0; j < config_.head_dim; j++) {\n                output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] +=\n                    thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];\n              }\n              attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;\n            }\n            mutex_[cur_batch_idx][cur_head_id]->unlock();\n          }\n          thread_cur_head_idx_[thread_id].first = batch_id;\n          thread_cur_head_idx_[thread_id].second = head_id;\n          for (int i = 0; i < n_gqa_; i++) {\n            thread_local_cur_attn_lse_[thread_id][i] = thread_local_attn_lse_[thread_id][i];\n            for (int j = 0; j < config_.head_dim; j++) {\n              thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j] =\n                  thread_local_output_fp32_[thread_id][i * config_.head_dim + j];\n            }\n          }\n        }\n      },\n      // Merge the results of the remaining blocks.\n      [&](int thread_id) {\n        int cur_batch_idx = thread_cur_head_idx_[thread_id].first;\n        int cur_head_id = 
thread_cur_head_idx_[thread_id].second;\n        if (cur_head_id != -1) {\n          mutex_[cur_batch_idx][cur_head_id]->lock();\n          for (int i = 0; i < n_gqa_; i++) {\n            float new_attn_lse;\n            if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) < 1e-6) {\n              attn_lse_[cur_batch_idx][cur_head_id][i] = thread_local_cur_attn_lse_[thread_id][i];\n              for (int j = 0; j < config_.head_dim; j++) {\n                output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] =\n                    thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];\n              }\n              continue;\n            }\n            new_attn_lse = attn_lse_[cur_batch_idx][cur_head_id][i] +\n                           std::log(1.0 + std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                                   attn_lse_[cur_batch_idx][cur_head_id][i]));\n            ggml_vec_scale_f32(config_.head_dim, output_fp32_[cur_batch_idx][cur_head_id].data() + i * config_.head_dim,\n                               std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] - new_attn_lse));\n            ggml_vec_scale_f32(config_.head_dim, thread_local_cur_output_fp32_[thread_id].data() + i * config_.head_dim,\n                               std::exp(thread_local_cur_attn_lse_[thread_id][i] - new_attn_lse));\n            for (int j = 0; j < config_.head_dim; j++) {\n              output_fp32_[cur_batch_idx][cur_head_id][i * config_.head_dim + j] +=\n                  thread_local_cur_output_fp32_[thread_id][i * config_.head_dim + j];\n            }\n            attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;\n          }\n          mutex_[cur_batch_idx][cur_head_id]->unlock();\n        }\n      });\n\n  for (int i = 0; i < batch_size; i++) {\n    for (int j = 0; j < max_block_num_after_retrieval_; j++) {\n      for (int k = 0; k < config_.q_head_num; k++) {\n        int block_idx = 
block_table_after_retrieval_kvhead_[i][j][k / n_gqa_];\n        attn_sparsity[i * config_.q_head_num + k] +=\n            std::exp(block_lse_[i][block_idx][k] - attn_lse_[i][k / n_gqa_][k % n_gqa_]);\n      }\n    }\n  }\n\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  std::chrono::duration<double> diff = end - start;\n  // printf(\"layer %d time of calculating sparsity: %f s\\n\", layer_id_,\n  //        diff.count());\n}\nvoid KVCache::calculate_block_similarity_kvhead_(const uint16_t* q_in_data, int batch_size, int layer_idx, int q_len,\n                                                 int max_block_num, int* cache_seqlens, int init_block_num,\n                                                 int local_block_num, int pick_block_num, WorkerPool* backend) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n  backend->do_work_stealing_job(\n      batch_size * max_block_num, nullptr,\n      [&](int task_id) {\n        int batch_id = task_id / max_block_num;\n        int block_id = task_id % max_block_num;\n        int seq_len = cache_seqlens_[batch_id];\n\n        if (block_id < init_block_num || block_id >= (seq_len / config_.block_len) - local_block_num) {\n          return;\n        }\n        int block_idx = block_table_before_retrieval_kvhead_[batch_id][block_id][0];\n\n        for (int head_id = 0; head_id < config_.q_head_num; head_id++) {\n          for (int i = 0; i < config_.head_dim; i++) {\n            float q_i = 0, qa_i = std::numeric_limits<float>::lowest();\n            for (int q_id = 0; q_id < q_len; q_id++) {\n              q_i += GGML_FP16_TO_FP32(\n                  q_in_data[batch_id * q_len * config_.q_head_num * config_.head_dim +\n                            q_id * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + i]);\n            }\n            q_i /= q_len;\n            for (int anchor_id = 0; anchor_id < config_.anchor_num; anchor_id++) {\n              qa_i = 
std::max(\n                  qa_i,\n                  GGML_FP16_TO_FP32(\n                      anchor_[layer_idx * config_.max_block_num * config_.anchor_num * config_.q_head_num *\n                                  config_.head_dim +\n                              block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                              anchor_id * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + i]) *\n                      q_i);\n            }\n            block_similar_kv_head_[batch_id][block_id][head_id / n_gqa_] += qa_i;\n          }\n        }\n      },\n      nullptr);\n\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  std::chrono::duration<double> diff = end - start;\n  // printf(\"layer %d time of calculating similarity: %f s\\n\", layer_idx,\n  //        diff.count());\n}\nvoid KVCache::select_block_kvhead_(int batch_size, int layer_idx, int max_block_num, int init_block_num,\n                                   int local_block_num, int pick_block_num) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n\n  for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n    int cache_len_after_retrieval = 0;\n    if (cache_seqlens_[batch_idx] / config_.block_len <= init_block_num + pick_block_num + local_block_num) {\n      selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] = 0;\n      for (int i = 0; i < max_block_num; i++) {\n        for (int j = 0; j < config_.kv_head_num; j++) {\n          block_table_after_retrieval_kvhead_[batch_idx][i][j] = block_table_before_retrieval_kvhead_[batch_idx][i][j];\n        }\n      }\n      continue;\n    }\n    for (int head_id = 0; head_id < config_.kv_head_num; head_id++) {\n      for (int block_id = init_block_num; block_id < (cache_seqlens_[batch_idx] / config_.block_len) - local_block_num;\n           block_id++) {\n        top_similar_block_[batch_idx].push(\n            
std::make_pair(block_similar_kv_head_[batch_idx][block_id][head_id],\n                           block_table_before_retrieval_kvhead_[batch_idx][block_id][head_id]));\n        if (top_similar_block_[batch_idx].size() > pick_block_num) {\n          top_similar_block_[batch_idx].pop();\n        }\n      }\n\n      int i = 0;\n      for (; i < init_block_num; i++) {\n        block_table_after_retrieval_kvhead_[batch_idx][i][head_id] =\n            block_table_before_retrieval_kvhead_[batch_idx][i][head_id];\n      }\n      while (!top_similar_block_[batch_idx].empty()) {\n        block_table_after_retrieval_kvhead_[batch_idx][i][head_id] = top_similar_block_[batch_idx].top().second;\n        top_similar_block_[batch_idx].pop();\n        i++;\n      }\n      for (; i < init_block_num + pick_block_num + local_block_num; i++) {\n        block_table_after_retrieval_kvhead_[batch_idx][i][head_id] =\n            block_table_before_retrieval_kvhead_[batch_idx][(cache_seqlens_[batch_idx] / config_.block_len) -\n                                                            local_block_num + i - init_block_num - pick_block_num]\n                                                [head_id];\n      }\n      if (cache_seqlens_[batch_idx] % config_.block_len != 0) {\n        block_table_after_retrieval_kvhead_[batch_idx][i][head_id] =\n            block_table_before_retrieval_kvhead_[batch_idx][(cache_seqlens_[batch_idx] / config_.block_len)][head_id];\n        cache_len_after_retrieval = (cache_seqlens_[batch_idx] % config_.block_len) + i * config_.block_len;\n        i++;\n      } else {\n        cache_len_after_retrieval = (cache_seqlens_[batch_idx] % config_.block_len) + i * config_.block_len;\n      }\n      for (int j = 0; j < i; j++) {\n        selected_blocks_history_kvhead_[(layer_idx - config_.layer_offset) / config_.layer_step][batch_idx][j]\n                                       [head_id] = block_table_after_retrieval_kvhead_[batch_idx][j][head_id];\n      }\n    }\n    
cache_seqlens_[batch_idx] = cache_len_after_retrieval;\n    selected_blocks_num_history_[(layer_idx - config_.layer_offset) / config_.layer_step] =\n        (cache_len_after_retrieval + config_.block_len - 1) / config_.block_len;\n  }\n\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  std::chrono::duration<double> diff = end - start;\n  // printf(\"layer %d time of selecting block: %f s\\n\", layer_idx,\n  //        diff.count())\n}\n\nvoid KVCache::get_attn_sparsity(const ggml_fp16_t* q_in, float* attn_sparsity, int layer_idx, int generate_token_idx,\n                                int q_len, int batch_size, int max_block_num, int* block_table, int* cache_seqlens,\n                                int* block_table_origin, int* cache_seqlens_origin, int max_block_num_origin, int topk,\n                                int local, WorkerPool* backend) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n  layer_id_ = layer_idx;\n  int thread_num = backend->get_thread_num();\n  batch_size = 1;\n\n  const uint16_t* q_in_data = const_cast<const uint16_t*>(q_in);\n\n  quantize_q_(q_in_data, batch_size);\n  if (config_.retrieval_type == RetrievalType::LAYER) {\n    attn_initialize_layer_(batch_size, layer_idx, block_table, max_block_num, cache_seqlens);\n    retrieval_kvcache_layer_(q_in_data, 1, local, topk, q_len, generate_token_idx, batch_size, layer_idx, cache_seqlens,\n                             max_block_num, backend);\n    calculate_sparsity_layer_(q_in_data, attn_sparsity, batch_size, max_block_num_origin, block_table_origin,\n                              cache_seqlens_origin, backend);\n  } else if (config_.retrieval_type == RetrievalType::KVHEAD) {\n    attn_initialize_kvhead_(batch_size, layer_idx, block_table, max_block_num, cache_seqlens);\n    retrieval_kvcache_kvhead_(q_in_data, 1, local, topk, q_len, generate_token_idx, batch_size, layer_idx,\n                              cache_seqlens, 
max_block_num, backend);\n    calculate_sparsity_kvhead_(q_in_data, attn_sparsity, batch_size, max_block_num_origin, block_table_origin,\n                               cache_seqlens_origin, backend);\n  }\n}\n\nvoid KVCache::attn_with_kvcache_one_block_(int head_dim, int bsz,\n                                           ggml_type q_type,  // GGML data type of `Q`, only supports fp16 and q8_0\n                                           // [bsz, head_dim]\n                                           // Quantization is always on the head_dim dimension (per_token). If\n                                           // head_dim % 32 != 0, an error will be raised. The size must be bsz *\n                                           // head_dim/32 * qtype_size.\n                                           const void* q,\n\n                                           int past_kv_len, int past_kv_offset,\n                                           bool is_full_attn,  // true indicates a full 1 mask\n                                           // If is_full_attn = false, a bit matrix representing the mask is\n                                           // passed. 
[bsz, past_kv_len]\n                                           const uint8_t* attn_mask,\n\n                                           ggml_type k_type,  // GGML data type of `K Cache`, only supports fp16,\n                                                              // q4_0, q8_0\n                                           int k_quant_type,  // 0 for per_token, 1 for per_channel, others raise an\n                                                              // error\n                                           // [seq_len, head_dim]\n                                           // If quant_type == 0, head_dim % 32 must be 0.\n                                           // If quant_type == 1, seq_len % 32 must be 0.\n                                           const void* k_cache,\n\n                                           // k_anchor_type must be fp16\n                                           int num_k_anchor,  // num_k_anchor == 0 indicates no anchor\n                                           // [num_k_anchor, head_dim]\n                                           const void* k_cache_anchors,\n                                           // Each token is associated with the nearest previous position's anchor,\n                                           // with the same distance.\n                                           const int* k_cache_anchor_pos,\n\n                                           // v_cache similar to k_cache\n                                           ggml_type v_type, int v_quant_type,\n                                           // [head_dim, seq_len]\n                                           const void* v_cache, int num_v_anchor, const void* v_cache_anchors,\n                                           const int* v_cache_anchor_pos,\n\n                                           // Pre-allocated buffer for intermediate calculations [bsz,\n                                           // past_kv_len]. 
No malloc is performed inside this function.\n                                           float* attn_score,\n\n                                           // Output: [bsz, head_dim], with the same type as q_type\n                                           void* output,\n                                           // [bsz]\n                                           float* lse,\n\n                                           // Pre-allocated temporary buffer with sufficient size:\n                                           // (2 * bsz * past_kv_len + 6 * bsz * head_dim + 2 * past_kv_len *\n                                           // head_dim + past_kv_len * head_dim / 32) bytes.\n                                           void* draft,\n\n                                           // Apply rotary embedding online\n                                           const int* rotary_angle, const void* rotary_cos, const void* rotary_sin\n                                           // rotary_cos=None,\n                                           // rotary_sin=None,\n                                           // cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,\n                                           // cache_batch_idx: Optional[torch.Tensor] = None,\n                                           // rotary_interleaved=True,\n\n                                           // // Not supported for now\n                                           // window_size=(-1, -1),  # -1 means infinite context window\n                                           // alibi_slopes=None,\n) {\n  assert(head_dim % 32 == 0);\n  assert(k_quant_type == 0);\n  assert(v_quant_type == 1);\n  assert(q_type == GGML_TYPE_F16 || q_type == GGML_TYPE_Q8_0);\n  if (q_type == GGML_TYPE_F16) {\n    assert(k_type == GGML_TYPE_F16);\n    assert(v_type == GGML_TYPE_F16);\n\n    // attn = q * k + q * k_anchor\n    // TODO: anchor\n    assert(num_k_anchor == 0);\n\n    if (rotary_angle != nullptr) {\n      
// The draft offsets below are byte counts, so step through a char* before casting to fp16.\n      ggml_fp16_t* k_cache_with_rope_fp16 = reinterpret_cast<ggml_fp16_t*>(\n          reinterpret_cast<char*>(draft) + sizeof(block_q8_0) * bsz * past_kv_len / QK8_0 +\n          sizeof(float) * bsz * head_dim);\n      // dequant k_cache and apply rope\n      // k_rope(i) = k(i) * cos(i) - k(i+l) * sin(i)\n      // k_rope(i+l) = k(i+l) * cos(i+l) + k(i) * sin(i)\n\n      // k(i)cos(i) -> k_rope(i)\n      // k(i)sin(i+l) -> k_rope(i+l)\n\n      // k(i)cos(i) -> k_rope(i)\n      // -k(i)sin(i-l) -> k_rope(i-l)\n\n      std::vector<float> block_fp32(32);\n      for (int k = 0; k < past_kv_len; k++) {\n        int angle = rotary_angle[k];\n        for (int l = 0; l < head_dim / 32; l++) {\n          for (int m = 0; m < 32; m++) {\n            float x = GGML_FP16_TO_FP32(((ggml_fp16_t*)k_cache)[k * head_dim + l * 32 + m]);\n            float sin_val = GGML_FP16_TO_FP32(((ggml_fp16_t*)rotary_sin)[angle * head_dim + l * 32 + m]);\n            float cos_val = GGML_FP16_TO_FP32(((ggml_fp16_t*)rotary_cos)[angle * head_dim + l * 32 + m]);\n\n            if (l * 32 + m < head_dim / 2) {\n              k_cache_with_rope_fp16[k * head_dim + l * 32 + m] = GGML_FP32_TO_FP16(x * cos_val);\n              k_cache_with_rope_fp16[k * head_dim + l * 32 + m + head_dim / 2] = GGML_FP32_TO_FP16(-x * sin_val);\n            } else {\n              k_cache_with_rope_fp16[k * head_dim + l * 32 + m] =\n                  GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(k_cache_with_rope_fp16[k * head_dim + l * 32 + m]) + x * sin_val);\n              k_cache_with_rope_fp16[k * head_dim + l * 32 + m - head_dim / 2] = GGML_FP32_TO_FP16(\n                  GGML_FP16_TO_FP32(k_cache_with_rope_fp16[k * head_dim + l * 32 + m - head_dim / 2]) - x * cos_val);\n            }\n          }\n        }\n      }\n\n      llamafile_sgemm(past_kv_len, bsz, head_dim, (ggml_fp16_t*)k_cache_with_rope_fp16, head_dim, (ggml_fp16_t*)q,\n                      head_dim, attn_score, past_kv_len, 0, 1, GGML_TASK_TYPE_COMPUTE, k_type, GGML_TYPE_F16,\n             
         GGML_TYPE_F32, GGML_PREC_DEFAULT);\n    } else {\n      bool ok = llamafile_sgemm(past_kv_len, bsz, head_dim, (ggml_fp16_t*)k_cache, head_dim, (ggml_fp16_t*)q, head_dim,\n                                attn_score, past_kv_len, 0, 1, GGML_TASK_TYPE_COMPUTE, k_type, GGML_TYPE_F16,\n                                GGML_TYPE_F32, GGML_PREC_DEFAULT);\n\n      if (!ok) {\n        printf(\"llamafile_sgemm failed\\n\");\n      }\n    }\n    // attn = attn * scale\n    float scale_factor = 1.0 / std::sqrt(float(head_dim));\n    ggml_vec_scale_f32(bsz * past_kv_len, attn_score, scale_factor);\n\n    // attn = attn & mask\n    if (!is_full_attn) {\n      for (int i = 0; i < bsz; i++) {\n        for (int j = 0; j < past_kv_len; j++) {\n          int index = i * past_kv_len + j;\n          if (!(attn_mask[j / 8] & (1 << (j % 8)))) {\n            attn_score[index] = std::numeric_limits<float>::lowest();\n          }\n        }\n      }\n    }\n\n    // attn = softmax(attn)\n    for (int i = 0; i < bsz; i++) {\n      float sum_exp = 0;\n      for (int j = 0; j < past_kv_len; j++) {\n        attn_score[i * past_kv_len + j] = std::exp(attn_score[i * past_kv_len + j]);\n        sum_exp += attn_score[i * past_kv_len + j];\n      }\n      for (int j = 0; j < past_kv_len; j++) {\n        attn_score[i * past_kv_len + j] /= sum_exp;\n      }\n      if (lse != nullptr) {\n        lse[i] = std::log(sum_exp);\n      }\n    }\n\n    // output = attn * v + attn * v_anchor\n    // std::vector<float> sum(bsz * head_dim);\n    float* sum =\n        reinterpret_cast<float*>(reinterpret_cast<char*>(draft) + sizeof(block_q8_0) * bsz * past_kv_len / QK8_0);\n\n    // float* attn_score_fp16(bsz, past_kv_len)\n    ggml_fp16_t* attn_score_fp16 = (reinterpret_cast<ggml_fp16_t*>(reinterpret_cast<char*>(draft) +\n                                                                   sizeof(block_q8_0) * bsz * past_kv_len / QK8_0 +\n                                                                   
sizeof(float) * bsz * head_dim));\n\n    for (int i = 0; i < bsz * past_kv_len; i++) {\n      attn_score_fp16[i] = GGML_FP32_TO_FP16(attn_score[i]);\n    }\n\n    // TODO: anchor\n    assert(num_v_anchor == 0);\n    bool ok = llamafile_sgemm(head_dim, bsz, past_kv_len, (ggml_fp16_t*)v_cache, past_kv_len,\n                              (ggml_fp16_t*)attn_score_fp16, past_kv_len, sum, head_dim, 0, 1, GGML_TASK_TYPE_COMPUTE,\n                              v_type, GGML_TYPE_F16, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n    if (!ok) {\n      printf(\"llamafile_sgemm failed\\n\");\n    }\n\n    // copy to output\n    for (int i = 0; i < bsz; i++) {\n      for (int j = 0; j < head_dim; j++) {\n        ((float*)output)[i * head_dim + j] = sum[i * head_dim + j];\n      }\n    }\n  } else {\n    assert(k_type == GGML_TYPE_Q4_0 || k_type == GGML_TYPE_Q8_0);\n    assert(v_type == GGML_TYPE_Q4_0 || v_type == GGML_TYPE_Q8_0);\n\n    // attn = q * k + q * k_anchor\n    // TODO: anchor\n    assert(num_k_anchor == 0);\n\n    if (rotary_angle != nullptr) {\n      ggml_fp16_t* k_cache_with_rope_fp16 =\n          (reinterpret_cast<ggml_fp16_t*>(draft) + sizeof(block_q8_0) * bsz * past_kv_len / QK8_0 +\n           sizeof(float) * bsz * head_dim);\n      block_q4_0* k_cache_with_rope_q4 =\n          (reinterpret_cast<block_q4_0*>(draft) + sizeof(block_q8_0) * bsz * past_kv_len / QK8_0 +\n           sizeof(float) * bsz * head_dim) +\n          sizeof(ggml_fp16_t) * bsz * head_dim;\n      // dequant k_cache and apply rope\n      // k_rope(i) = k(i) * cos(i) - k(i+l) * sin(i)\n      // k_rope(i+l) = k(i+l) * cos(i+l) + k(i) * sin(i)\n\n      // k(i)cos(i) -> k_rope(i)\n      // k(i)sin(i+l) -> k_rope(i+l)\n\n      // k(i)cos(i) -> k_rope(i)\n      // -k(i)sin(i-l) -> k_rope(i-l)\n\n      std::vector<float> block_fp32(32);\n      for (int k = 0; k < past_kv_len; k++) {\n        int angle = rotary_angle[k];\n        for (int l = 0; l < head_dim / 32; l++) {\n          block_q4_0 block = 
((block_q4_0*)k_cache)[k * head_dim / 32 + l];\n          dequantize_row_q4_0(&block, block_fp32.data(), 32);\n          for (int m = 0; m < 32; m++) {\n            float sin_val = GGML_FP16_TO_FP32(((ggml_fp16_t*)rotary_sin)[angle * head_dim + l * 32 + m]);\n            float cos_val = GGML_FP16_TO_FP32(((ggml_fp16_t*)rotary_cos)[angle * head_dim + l * 32 + m]);\n\n            if (l * 32 + m < head_dim / 2) {\n              k_cache_with_rope_fp16[k * head_dim + l * 32 + m] = GGML_FP32_TO_FP16(block_fp32[m] * cos_val);\n              k_cache_with_rope_fp16[k * head_dim + l * 32 + m + head_dim / 2] =\n                  GGML_FP32_TO_FP16(-block_fp32[m] * sin_val);\n            } else {\n              k_cache_with_rope_fp16[k * head_dim + l * 32 + m] += GGML_FP32_TO_FP16(block_fp32[m] * sin_val);\n              k_cache_with_rope_fp16[k * head_dim + l * 32 + m - head_dim / 2] -=\n                  GGML_FP32_TO_FP16(block_fp32[m] * cos_val);\n            }\n          }\n        }\n      }\n      // quantize k_cache_with_rope_fp16\n      for (int k = 0; k < past_kv_len; k++) {\n        for (int l = 0; l < head_dim / 32; l++) {\n          for (int m = 0; m < 32; m++) {\n            block_fp32[m] = GGML_FP16_TO_FP32(k_cache_with_rope_fp16[k * head_dim + l * 32 + m]);\n          }\n          quantize_row_q4_0(block_fp32.data(), &k_cache_with_rope_q4[k * head_dim / 32 + l], 32);\n        }\n      }\n\n      llamafile_sgemm(past_kv_len, bsz, head_dim / 32, (block_q4_0*)k_cache_with_rope_q4, head_dim / 32, (block_q8_0*)q,\n                      head_dim / 32, attn_score, past_kv_len, 0, 1, GGML_TASK_TYPE_COMPUTE, k_type, GGML_TYPE_Q8_0,\n                      GGML_TYPE_F32, GGML_PREC_DEFAULT);\n    } else {\n      llamafile_sgemm(past_kv_len, bsz, head_dim / 32, (block_q4_0*)k_cache, head_dim / 32, (block_q8_0*)q,\n                      head_dim / 32, attn_score, past_kv_len, 0, 1, GGML_TASK_TYPE_COMPUTE, k_type, GGML_TYPE_Q8_0,\n                      GGML_TYPE_F32, 
GGML_PREC_DEFAULT);\n    }\n\n    // attn = attn * scale\n    float scale_factor = 1.0 / std::sqrt(float(head_dim));\n    ggml_vec_scale_f32(bsz * past_kv_len, attn_score, scale_factor);\n\n    // attn = attn & mask\n    if (!is_full_attn) {\n      for (int i = 0; i < bsz; i++) {\n        for (int j = 0; j < past_kv_len; j++) {\n          int index = i * past_kv_len + j;\n          if (!(attn_mask[j / 8] & (1 << (j % 8)))) {\n            attn_score[index] = std::numeric_limits<float>::lowest();\n          }\n        }\n      }\n    }\n\n    // attn = softmax(attn)\n    for (int i = 0; i < bsz; i++) {\n      float sum_exp = 0;\n      for (int j = 0; j < past_kv_len; j++) {\n        attn_score[i * past_kv_len + j] = std::exp(attn_score[i * past_kv_len + j]);\n        sum_exp += attn_score[i * past_kv_len + j];\n      }\n      for (int j = 0; j < past_kv_len; j++) {\n        attn_score[i * past_kv_len + j] /= sum_exp;\n      }\n      if (lse != nullptr) {\n        lse[i] = std::log(sum_exp);\n      }\n    }\n\n    // output = attn * v + attn * v_anchor\n    // std::vector<block_q8_0> attn_q8_0(bsz * past_kv_len / QK8_0);\n    block_q8_0* attn_q8_0 = reinterpret_cast<block_q8_0*>(draft);\n    quantize_row_q8_0(attn_score, attn_q8_0, bsz * past_kv_len);\n    // std::vector<float> sum(bsz * head_dim);\n    float* sum =\n        reinterpret_cast<float*>(reinterpret_cast<char*>(draft) + sizeof(block_q8_0) * bsz * past_kv_len / QK8_0);\n    // TODO: anchor\n    assert(num_v_anchor == 0);\n    llamafile_sgemm(head_dim, bsz, past_kv_len / 32, (block_q4_0*)v_cache, past_kv_len / 32, attn_q8_0,\n                    past_kv_len / 32, sum, head_dim, 0, 1, GGML_TASK_TYPE_COMPUTE, v_type, GGML_TYPE_Q8_0,\n                    GGML_TYPE_F32, GGML_PREC_DEFAULT);\n\n    quantize_row_q8_0(sum, (block_q8_0*)output, bsz * head_dim);\n  }\n}\n"
  },
  {
    "path": "kt-kernel/operators/kvcache/kvcache_load_dump.cpp",
    "content": "/**\n * @Description  :\n * @Author       : Jianwei Dong\n * @Date         : 2024-08-26 22:47:06\n * @Version      : 1.0.0\n * @LastEditors  : Jianwei Dong\n * @LastEditTime : 2024-08-26 22:47:06\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n\n#include <chrono>\n#include <fstream>\n#include <iostream>\n\n#include \"kvcache.h\"\n\nvoid KVCache::load_kvcache(std::string tensor_file_path, WorkerPool* backend) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n  std::ifstream ifs_tensor(tensor_file_path, std::ios::binary);\n  if (!ifs_tensor) {\n    throw std::runtime_error(\"Failed to open tensor file\");\n  }\n  ifs_tensor.read(reinterpret_cast<char*>(&cache_total_len_), sizeof(cache_total_len_));\n  int past_block_num = (cache_total_len_ + config_.block_len - 1) / config_.block_len;\n  printf(\"cache_total_len: %d, past_block_num: %d\\n\", cache_total_len_, past_block_num);\n  for (int i = 0; i < config_.layer_num; ++i) {\n    past_block_num_[i] = past_block_num;\n  }\n  ifs_tensor.read(reinterpret_cast<char*>(anchor_.data()), anchor_.size() * sizeof(ggml_fp16_t));\n  for (int i = 0; i < config_.layer_num; ++i) {\n    for (int j = 0; j < config_.kv_head_num; ++j) {\n      for (int k = 0; k < past_block_num_[i]; ++k) {\n        if (config_.kv_type == GGML_TYPE_F16) {\n          ifs_tensor.read(reinterpret_cast<char*>(k_cache_fp16_[i][j][k].data()),\n                          k_cache_fp16_[i][j][k].size() * sizeof(ggml_fp16_t));\n          ifs_tensor.read(reinterpret_cast<char*>(v_cache_fp16_[i][j][k].data()),\n                          v_cache_fp16_[i][j][k].size() * sizeof(ggml_fp16_t));\n        } else if (config_.kv_type == GGML_TYPE_Q4_0) {\n          ifs_tensor.read(reinterpret_cast<char*>(k_cache_q4[i][j][k].data()),\n                          k_cache_q4[i][j][k].size() * sizeof(block_q4_0));\n          ifs_tensor.read(reinterpret_cast<char*>(v_cache_q4[i][j][k].data()),\n                         
 v_cache_q4[i][j][k].size() * sizeof(block_q4_0));\n        }\n      }\n    }\n    for (int k = 0; k < past_block_num_[i]; ++k) {\n      for (int l = 0; l < config_.block_len; l++) {\n        ifs_tensor.read(reinterpret_cast<char*>(importance_[i][k][l].data()),\n                        importance_[i][k][l].size() * sizeof(ggml_fp16_t));\n      }\n    }\n  }\n  ifs_tensor.close();\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  std::chrono::duration<double> diff = end - start;\n  printf(\"time of load: %f s\\n\", diff.count());\n}\nvoid KVCache::dump_kvcache(int* block_table, int cache_total_len, std::string tensor_file_path, WorkerPool* backend) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n  std::ofstream ofs(tensor_file_path, std::ios::binary);\n  printf(\"dump_kvcache: %s\\n\", tensor_file_path.c_str());\n  if (!ofs.is_open()) {\n    std::cerr << \"Cannot open file \" << tensor_file_path << std::endl;\n    return;\n  }\n  ofs.write(reinterpret_cast<const char*>(&cache_total_len), sizeof(cache_total_len));\n  int past_block_num = (cache_total_len + config_.block_len - 1) / config_.block_len;\n  printf(\"cache_total_len: %d, past_block_num: %d\\n\", cache_total_len, past_block_num);\n  ofs.write(reinterpret_cast<const char*>(anchor_.data()), anchor_.size() * sizeof(ggml_fp16_t));\n  for (int i = 0; i < config_.layer_num; ++i) {\n    for (int j = 0; j < config_.kv_head_num; ++j) {\n      for (int k = 0; k < past_block_num; ++k) {\n        int block_idx = block_table[k];\n        if (config_.kv_type == GGML_TYPE_F16) {\n          ofs.write(reinterpret_cast<const char*>(k_cache_fp16_[i][j][block_idx].data()),\n                    k_cache_fp16_[i][j][block_idx].size() * sizeof(ggml_fp16_t));\n          ofs.write(reinterpret_cast<const char*>(v_cache_fp16_[i][j][block_idx].data()),\n                    v_cache_fp16_[i][j][block_idx].size() * sizeof(ggml_fp16_t));\n\n        } else if (config_.kv_type == 
GGML_TYPE_Q4_0) {\n          ofs.write(reinterpret_cast<const char*>(k_cache_q4[i][j][block_idx].data()),\n                    k_cache_q4[i][j][block_idx].size() * sizeof(block_q4_0));\n          ofs.write(reinterpret_cast<const char*>(v_cache_q4[i][j][block_idx].data()),\n                    v_cache_q4[i][j][block_idx].size() * sizeof(block_q4_0));\n        }\n      }\n    }\n    for (int k = 0; k < past_block_num; ++k) {\n      int block_idx = block_table[k];\n      for (int l = 0; l < config_.block_len; l++) {\n        ofs.write(reinterpret_cast<const char*>(importance_[i][block_idx][l].data()),\n                  importance_[i][block_idx][l].size() * sizeof(ggml_fp16_t));\n      }\n    }\n  }\n  ofs.close();\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  std::chrono::duration<double> diff = end - start;\n  printf(\"time of dump: %f s\\n\", diff.count());\n}"
  },
  {
    "path": "kt-kernel/operators/kvcache/kvcache_read_write.cpp",
    "content": "/**\n * @Description  :\n * @Author       : Jianwei Dong\n * @Date         : 2024-08-26 22:47:06\n * @Version      : 1.0.0\n * @LastEditors  : Jianwei Dong\n * @LastEditTime : 2024-08-26 22:47:06\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n\n#include <chrono>\n\n#include \"ggml-impl.h\"\n#include \"kvcache.h\"\n\nvoid KVCache::get_anchor_one_block(ggml_fp16_t* anchor, int layer_id, int block_idx, WorkerPool* backend) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n\n  layer_id_ = layer_id;\n  block_idx = block_idx;\n  seq_len_ = config_.block_len;\n  anchor_data_ = const_cast<uint16_t*>(anchor);\n\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  std::chrono::duration<double> duration = end - start;\n  printf(\"layer %d block %d time of reading anchor: %f s\\n\", layer_id, block_idx, duration.count());\n}\n\nvoid KVCache::update_anchor_one_block(const ggml_fp16_t* anchor, int layer_id, int block_idx, WorkerPool* backend) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n\n  layer_id_ = layer_id;\n  block_idx = block_idx;\n  seq_len_ = config_.block_len;\n  anchor_data_ = const_cast<uint16_t*>(anchor);\n\n  // Each task updates the anchor of a certain position\n  // backend->do_work_stealing_job(config_.anchor_num, [&](int task_id) {\n  //     int k = task_id % config_.anchor_num;\n  //     int head_id = task_id / config_.anchor_num;\n  //     memcpy(anchor_[layer_id_][head_id][block_idx].data() +\n  //                k * config_.head_dim,\n  //            anchor_data_ + k * config_.head_dim,\n  //            sizeof(uint16_t) * config_.head_dim);\n  // });\n\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  std::chrono::duration<double> duration = end - start;\n  printf(\"layer %d block %d time of writting anchor: %f s\\n\", layer_id, block_idx, duration.count());\n}\n\nvoid 
KVCache::update_importance_one_block(const ggml_fp16_t* importance, int layer_id, int block_idx,\n                                          WorkerPool* backend) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n\n  layer_id_ = layer_id;\n  block_idx = block_idx;\n  seq_len_ = config_.block_len;\n  importance_data_ = const_cast<uint16_t*>(importance);\n\n  // Each task updates the importance of a certain position\n  backend->do_work_stealing_job(\n      config_.block_len, nullptr,\n      [&](int task_id) {\n        int k = task_id;\n        memcpy(importance_[layer_id_][block_idx].data() + k, importance_data_ + k, sizeof(uint16_t));\n      },\n      nullptr);\n\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  std::chrono::duration<double> duration = end - start;\n  printf(\"layer %d block %d time of writing importance: %f s\\n\", layer_id, block_idx, duration.count());\n}\n\nvoid KVCache::get_importance_one_block(ggml_fp16_t* importance, int layer_id, int block_idx, WorkerPool* backend) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n\n  layer_id_ = layer_id;\n  block_idx = block_idx;\n  seq_len_ = config_.block_len;\n  importance_data_ = const_cast<uint16_t*>(importance);\n\n  // Each task reads the importance of a certain position\n  backend->do_work_stealing_job(\n      config_.block_len, nullptr,\n      [&](int task_id) {\n        int k = task_id;\n        memcpy(importance_data_ + k, importance_[layer_id_][block_idx].data() + k, sizeof(uint16_t));\n      },\n      nullptr);\n\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  std::chrono::duration<double> duration = end - start;\n  printf(\"layer %d block %d time of reading importance: %f s\\n\", layer_id, block_idx, duration.count());\n}\n\nvoid KVCache::update_kvcache_one_block_fp16(const ggml_fp16_t* k_in, const ggml_fp16_t* v_in, int layer_id,\n                                            int 
block_idx, WorkerPool* backend) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n\n  layer_id_ = layer_id;\n  block_idx = block_idx;\n  seq_len_ = config_.block_len;\n  k_data_ = const_cast<uint16_t*>(k_in);\n  v_data_ = const_cast<uint16_t*>(v_in);\n\n  int new_block_num = std::max((int)past_block_num_[layer_id], block_idx + 1);\n\n  importance_[layer_id_].resize(new_block_num);\n\n  for (int i = 0; i < config_.kv_head_num; i++) {\n    k_cache_q4[layer_id][i].resize(new_block_num);\n    v_cache_q4[layer_id][i].resize(new_block_num);\n    // anchor_[layer_id][i].resize(new_block_num);\n  }\n\n  for (int i = 0; i < new_block_num; i++) {\n    importance_[layer_id][i].resize(config_.block_len);\n  }\n\n  // Each task updates the k cache or v cache of a certain header\n  backend->do_work_stealing_job(\n      config_.kv_head_num * 2, nullptr,\n      [&](int task_id) {\n        std::vector<float> block_fp32(32);\n        int head_id = task_id / 2;\n        if (task_id & 1) {\n          // fill k_cache_\n          k_cache_q4[layer_id_][head_id][block_idx].resize(config_.block_len * config_.head_dim / 32);\n          for (int k = 0; k < config_.block_len; k++) {\n            for (int l = 0; l < config_.head_dim / 32; l++) {\n              block_q4_0 block;\n              for (int m = 0; m < 32; m++) {\n                block_fp32[m] = GGML_FP16_TO_FP32(\n                    k_data_[((0 * config_.kv_head_num + head_id) * seq_len_ + 0 * config_.block_len + k) *\n                                config_.head_dim +\n                            l * 32 + m]);\n              }\n              quantize_row_q4_0(block_fp32.data(), &block, 32);\n              k_cache_q4[layer_id_][head_id][block_idx][k * config_.head_dim / 32 + l] = block;\n            }\n          }\n        } else {\n          // fill v_cache_\n          v_cache_q4[layer_id_][head_id][block_idx].resize(config_.head_dim * config_.block_len / 32);\n          for (int k = 0; k < 
config_.block_len / 32; k++) {\n            for (int l = 0; l < config_.head_dim; l++) {\n              block_q4_0 block;\n              for (int m = 0; m < 32; m++) {\n                block_fp32[m] = GGML_FP16_TO_FP32(\n                    v_data_[((0 * config_.kv_head_num + head_id) * seq_len_ + 0 * config_.block_len + k * 32 + m) *\n                                config_.head_dim +\n                            l]);\n              }\n              quantize_row_q4_0(block_fp32.data(), &block, 32);\n              v_cache_q4[layer_id_][head_id][block_idx][l * config_.block_len / 32 + k] = block;\n            }\n          }\n        }\n      },\n      nullptr);\n  past_block_num_[layer_id] = new_block_num;\n\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  std::chrono::duration<double> duration = end - start;\n  printf(\"layer %d block %d time of writing KV Cache: %f s\\n\", layer_id, block_idx, duration.count());\n  // printf(\"get_one_block_fp16 duration: %ld\\n\", duration);\n}\n\nvoid KVCache::get_kvcache_one_block_fp16(ggml_fp16_t* k_in, ggml_fp16_t* v_in, int layer_id, int block_idx,\n                                         WorkerPool* backend) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n\n  layer_id_ = layer_id;\n  seq_len_ = config_.block_len;\n  k_data_ = reinterpret_cast<uint16_t*>(k_in);\n  v_data_ = reinterpret_cast<uint16_t*>(v_in);\n\n  // printf(\"layer_id: %d, block_idx: %d\\n\", layer_id, block_idx);\n  // Each task gets the k cache or v cache of a certain head\n  backend->do_work_stealing_job(\n      config_.kv_head_num * 2, nullptr,\n      [&](int task_id) {\n        std::vector<float> block_fp32(32);\n        int head_id = task_id / 2;\n        if (task_id & 1) {\n          // get k_cache_\n          for (int k = 0; k < config_.block_len; k++) {\n            for (int l = 0; l < config_.head_dim / 32; l++) {\n              block_q4_0 block = 
k_cache_q4[layer_id_][head_id][block_idx][k * config_.head_dim / 32 + l];\n              dequantize_row_q4_0(&block, block_fp32.data(), 32);\n              for (int m = 0; m < 32; m++) {\n                k_data_[((0 * config_.kv_head_num + head_id) * seq_len_ + 0 * config_.block_len + k) *\n                            config_.head_dim +\n                        l * 32 + m] = GGML_FP32_TO_FP16(block_fp32[m]);\n              }\n            }\n          }\n        } else {\n          // get v_cache_\n          for (int k = 0; k < config_.block_len / 32; k++) {\n            for (int l = 0; l < config_.head_dim; l++) {\n              block_q4_0 block = v_cache_q4[layer_id_][head_id][block_idx][l * config_.block_len / 32 + k];\n              dequantize_row_q4_0(&block, block_fp32.data(), 32);\n              for (int m = 0; m < 32; m++) {\n                v_data_[((0 * config_.kv_head_num + head_id) * seq_len_ + 0 * config_.block_len + k * 32 + m) *\n                            config_.head_dim +\n                        l] = GGML_FP32_TO_FP16(block_fp32[m]);\n              }\n            }\n          }\n        }\n      },\n      nullptr);\n\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  std::chrono::duration<double> duration = end - start;\n  printf(\"layer %d block %d time of reading KV Cache: %f s\\n\", layer_id, block_idx, duration.count());\n  // printf(\"get_one_block_fp16 duration: %ld\\n\", duration);\n}\n\n// k_in: (batch_size, seq_len, head_num, head_dim)\n// v_in: (batch_size, seq_len, head_num, head_dim)\nvoid KVCache::get_and_update_kvcache_fp16(ggml_fp16_t* k_in, ggml_fp16_t* v_in, int layer_id, int* block_table,\n                                          int batch_size, int max_block_num, int* cache_seqlens, int q_len,\n                                          WorkerPool* backend) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n\n  layer_id_ = layer_id;\n  k_data_ = 
const_cast<uint16_t*>(k_in);\n  v_data_ = const_cast<uint16_t*>(v_in);\n\n  // Each task updates the k cache and v cache of a certain header\n  backend->do_work_stealing_job(\n      config_.kv_head_num * max_block_num * batch_size, nullptr,\n      [&](int task_id) {\n        // printf(\"block_idx: %d, task_id: %d\\n\", block_idx, task_id);\n        std::vector<float> block_fp32(32);\n        int batch_id = task_id / (config_.kv_head_num * max_block_num);\n        int block_id = (task_id / config_.kv_head_num) % max_block_num;\n        int head_id = task_id % config_.kv_head_num;\n        int block_idx = block_table[batch_id * max_block_num + block_id];\n        int seq_len = cache_seqlens[batch_id];\n        int block_l = block_id * config_.block_len;\n        int block_r = block_id * config_.block_len + config_.block_len;\n\n        if (block_l < seq_len) {\n          if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n            for (int k = 0; k < config_.block_len; k++) {\n              if (block_id * config_.block_len + k >= seq_len) break;\n              for (int l = 0; l < config_.head_dim; l++) {\n                k_data_[batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) +\n                        block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) +\n                        k * (config_.kv_head_num * config_.head_dim) + head_id * config_.head_dim + l] =\n                    k_cache_fp16_[layer_id_][head_id][block_idx][k * config_.head_dim + l];\n                v_data_[batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) +\n                        block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) +\n                        k * (config_.kv_head_num * config_.head_dim) + head_id * config_.head_dim + l] =\n                    v_cache_fp16_[layer_id_][head_id][block_idx][l * config_.block_len + k];\n              }\n            }\n          } else 
if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n            // get k_cache_\n            for (int k = 0; k < config_.block_len; k++) {\n              if (block_id * config_.block_len + k >= seq_len) break;\n              for (int l = 0; l < config_.head_dim / 32; l++) {\n                block_q4_0 block = k_cache_q4[layer_id_][head_id][block_idx][k * config_.head_dim / 32 + l];\n                dequantize_row_q4_0(&block, block_fp32.data(), 32);\n                for (int m = 0; m < 32; m++) {\n                  k_data_[batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) +\n                          block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) +\n                          k * (config_.kv_head_num * config_.head_dim) + head_id * config_.head_dim + l * 32 + m] =\n                      GGML_FP32_TO_FP16(block_fp32[m]);\n                }\n              }\n            }\n            // get v_cache_\n            for (int k = 0; k < config_.block_len / 32; k++) {\n              for (int l = 0; l < config_.head_dim; l++) {\n                block_q4_0 block = v_cache_q4[layer_id_][head_id][block_idx][l * config_.block_len / 32 + k];\n                dequantize_row_q4_0(&block, block_fp32.data(), 32);\n                for (int m = 0; m < 32; m++) {\n                  if (block_id * config_.block_len + k * 32 + m >= seq_len) break;\n                  v_data_[batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) +\n                          block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) +\n                          (k * 32 + m) * config_.kv_head_num * config_.head_dim + head_id * config_.head_dim + l] =\n                      GGML_FP32_TO_FP16(block_fp32[m]);\n                }\n              }\n            }\n          } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n            // get k_cache_\n            for (int k = 0; k < 
config_.block_len; k++) {\n              if (block_id * config_.block_len + k >= seq_len) break;\n              for (int l = 0; l < config_.head_dim / 32; l++) {\n                block_q8_0 block = k_cache_q8[layer_id_][head_id][block_idx][k * config_.head_dim / 32 + l];\n                dequantize_row_q8_0(&block, block_fp32.data(), 32);\n                for (int m = 0; m < 32; m++) {\n                  k_data_[batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) +\n                          block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) +\n                          k * (config_.kv_head_num * config_.head_dim) + head_id * config_.head_dim + l * 32 + m] =\n                      GGML_FP32_TO_FP16(block_fp32[m]);\n                }\n              }\n            }\n            // get v_cache_\n            for (int k = 0; k < config_.block_len / 32; k++) {\n              for (int l = 0; l < config_.head_dim; l++) {\n                block_q8_0 block = v_cache_q8[layer_id_][head_id][block_idx][l * config_.block_len / 32 + k];\n                dequantize_row_q8_0(&block, block_fp32.data(), 32);\n                for (int m = 0; m < 32; m++) {\n                  if (block_id * config_.block_len + k * 32 + m >= seq_len) break;\n                  v_data_[batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) +\n                          block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) +\n                          (k * 32 + m) * config_.kv_head_num * config_.head_dim + head_id * config_.head_dim + l] =\n                      GGML_FP32_TO_FP16(block_fp32[m]);\n                }\n              }\n            }\n          }\n        }\n        if (block_r > seq_len && block_l < seq_len + q_len) {\n          if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n            for (int k = 0; k < config_.block_len; k++) {\n              if (block_id * config_.block_len + k 
>= seq_len + q_len || block_id * config_.block_len + k < seq_len)\n                continue;\n              for (int l = 0; l < config_.head_dim; l++) {\n                k_cache_fp16_[layer_id_][head_id][block_idx][k * config_.head_dim + l] =\n                    k_data_[batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) +\n                            block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) +\n                            k * (config_.kv_head_num * config_.head_dim) + head_id * config_.head_dim + l];\n                v_cache_fp16_[layer_id_][head_id][block_idx][l * config_.block_len + k] =\n                    v_data_[batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) +\n                            block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) +\n                            k * (config_.kv_head_num * config_.head_dim) + head_id * config_.head_dim + l];\n              }\n            }\n          } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n            // fill k_cache_\n            for (int k = 0; k < config_.block_len; k++) {\n              if (block_id * config_.block_len + k >= seq_len + q_len || block_id * config_.block_len + k < seq_len)\n                continue;\n              for (int l = 0; l < config_.head_dim / 32; l++) {\n                block_q4_0 block;\n                for (int m = 0; m < 32; m++) {\n                  block_fp32[m] = GGML_FP16_TO_FP32(\n                      k_data_[batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) +\n                              block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) +\n                              k * (config_.kv_head_num * config_.head_dim) + head_id * config_.head_dim + l * 32 + m]);\n                }\n                quantize_row_q4_0(block_fp32.data(), &block, 32);\n                
k_cache_q4[layer_id_][head_id][block_idx][k * config_.head_dim / 32 + l] = block;\n              }\n            }\n\n            // fill v_cache_\n            for (int k = 0; k < config_.block_len / 32; k++) {\n              for (int l = 0; l < config_.head_dim; l++) {\n                block_q4_0 block;\n                for (int m = 0; m < 32; m++) {\n                  if (block_id * config_.block_len + k * 32 + m >= seq_len + q_len) {\n                    block_fp32[m] = 0;\n                    continue;\n                  }\n                  block_fp32[m] = GGML_FP16_TO_FP32(\n                      v_data_[batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) +\n                              block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) +\n                              (k * 32 + m) * config_.kv_head_num * config_.head_dim + head_id * config_.head_dim + l]);\n                }\n                quantize_row_q4_0(block_fp32.data(), &block, 32);\n                v_cache_q4[layer_id_][head_id][block_idx][l * config_.block_len / 32 + k] = block;\n              }\n            }\n          } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n            // fill k_cache_\n            for (int k = 0; k < config_.block_len; k++) {\n              if (block_id * config_.block_len + k >= seq_len + q_len || block_id * config_.block_len + k < seq_len)\n                continue;\n              for (int l = 0; l < config_.head_dim / 32; l++) {\n                block_q8_0 block;\n                for (int m = 0; m < 32; m++) {\n                  block_fp32[m] = GGML_FP16_TO_FP32(\n                      k_data_[batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) +\n                              block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) +\n                              k * (config_.kv_head_num * config_.head_dim) + head_id * config_.head_dim + l * 32 + 
m]);\n                }\n                quantize_row_q8_0(block_fp32.data(), &block, 32);\n                k_cache_q8[layer_id_][head_id][block_idx][k * config_.head_dim / 32 + l] = block;\n              }\n            }\n\n            // fill v_cache_\n            for (int k = 0; k < config_.block_len / 32; k++) {\n              for (int l = 0; l < config_.head_dim; l++) {\n                block_q8_0 block;\n                for (int m = 0; m < 32; m++) {\n                  if (block_id * config_.block_len + k * 32 + m >= seq_len + q_len) {\n                    block_fp32[m] = 0;\n                    continue;\n                  }\n                  block_fp32[m] = GGML_FP16_TO_FP32(\n                      v_data_[batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) +\n                              block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) +\n                              (k * 32 + m) * config_.kv_head_num * config_.head_dim + head_id * config_.head_dim + l]);\n                }\n                quantize_row_q8_0(block_fp32.data(), &block, 32);\n                v_cache_q8[layer_id_][head_id][block_idx][l * config_.block_len / 32 + k] = block;\n              }\n            }\n          }\n        }\n      },\n      nullptr);\n\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  std::chrono::duration<double> duration = end - start;\n\n  // printf(\"layer %d time of reading and updating KV Cache: %f s\\n\",\n  // layer_id,\n  //        duration.count());\n}\n\nvoid KVCache::update_importance(const ggml_fp16_t* importance, int layer_id, int* block_table, int batch_size,\n                                int max_block_num, int* offset, int width, WorkerPool* backend) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n\n  layer_id_ = layer_id;\n  importance_data_ = const_cast<uint16_t*>(importance);\n\n  // Each task updates the importance of a certain 
position\n  backend->do_work_stealing_job(\n      max_block_num * batch_size, nullptr,\n      [&](int task_id) {\n        int block_id = task_id % max_block_num;\n        int batch_id = task_id / max_block_num;\n        int block_idx = block_table[batch_id * max_block_num + block_id];\n        if (block_id > (offset[batch_id] + width) / config_.block_len) {\n          return;\n        }\n        for (int k = 0; k < config_.block_len; k++) {\n          for (int head_id = 0; head_id < config_.q_head_num; head_id++) {\n            importance_[layer_id_][block_idx][k][head_id] = GGML_FP32_TO_FP16(\n                GGML_FP16_TO_FP32(importance_data_[batch_id * max_block_num * config_.block_len * config_.q_head_num +\n                                                   (block_id * config_.block_len + k) * config_.q_head_num + head_id]) +\n                GGML_FP16_TO_FP32(importance_[layer_id_][block_idx][k][head_id]));\n          }\n        }\n      },\n      nullptr);\n\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  std::chrono::duration<double> duration = end - start;\n\n  // printf(\"layer %d time of updating importance: %f s\\n\", layer_id,\n  //        duration.count());\n}\n\nvoid KVCache::get_kvcache_fp16(ggml_fp16_t* k_in, ggml_fp16_t* v_in, int layer_id, int* block_table, int batch_size,\n                               int max_block_num, int* cache_seqlens, WorkerPool* backend) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n\n  layer_id_ = layer_id;\n  k_data_ = const_cast<uint16_t*>(k_in);\n  v_data_ = const_cast<uint16_t*>(v_in);\n\n  // Each task reads the k cache and v cache of a certain head\n  backend->do_work_stealing_job(\n      config_.kv_head_num * max_block_num * batch_size, nullptr,\n      [&](int task_id) {\n        // printf(\"block_idx: %d, task_id: %d\\n\", block_idx, task_id);\n        std::vector<float> block_fp32(32);\n        int batch_id = task_id / (config_.kv_head_num * 
max_block_num);\n        int block_id = (task_id / config_.kv_head_num) % max_block_num;\n        int head_id = task_id % config_.kv_head_num;\n        int block_idx = block_table[batch_id * max_block_num + block_id];\n        int seq_len = cache_seqlens[batch_id];\n        int block_l = block_id * config_.block_len;\n        int block_r = block_id * config_.block_len + config_.block_len;\n\n        if (block_l < seq_len) {\n          if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n            for (int k = 0; k < config_.block_len; k++) {\n              if (block_id * config_.block_len + k >= seq_len) break;\n              for (int l = 0; l < config_.head_dim; l++) {\n                k_data_[batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) +\n                        block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) +\n                        k * (config_.kv_head_num * config_.head_dim) + head_id * config_.head_dim + l] =\n                    k_cache_fp16_[layer_id_][head_id][block_idx][k * config_.head_dim + l];\n                v_data_[batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) +\n                        block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) +\n                        k * (config_.kv_head_num * config_.head_dim) + head_id * config_.head_dim + l] =\n                    v_cache_fp16_[layer_id_][head_id][block_idx][l * config_.block_len + k];\n              }\n            }\n          } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n            // get k_cache_\n            for (int k = 0; k < config_.block_len; k++) {\n              if (block_id * config_.block_len + k >= seq_len) break;\n              for (int l = 0; l < config_.head_dim / 32; l++) {\n                block_q4_0 block = k_cache_q4[layer_id_][head_id][block_idx][k * config_.head_dim / 32 + l];\n                dequantize_row_q4_0(&block, 
block_fp32.data(), 32);\n                for (int m = 0; m < 32; m++) {\n                  k_data_[batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) +\n                          block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) +\n                          k * (config_.kv_head_num * config_.head_dim) + head_id * config_.head_dim + l * 32 + m] =\n                      GGML_FP32_TO_FP16(block_fp32[m]);\n                }\n              }\n            }\n            // get v_cache_\n            for (int k = 0; k < config_.block_len / 32; k++) {\n              for (int l = 0; l < config_.head_dim; l++) {\n                block_q4_0 block = v_cache_q4[layer_id_][head_id][block_idx][l * config_.block_len / 32 + k];\n                dequantize_row_q4_0(&block, block_fp32.data(), 32);\n                for (int m = 0; m < 32; m++) {\n                  if (block_id * config_.block_len + k * 32 + m >= seq_len) break;\n                  v_data_[batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) +\n                          block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) +\n                          (k * 32 + m) * config_.kv_head_num * config_.head_dim + head_id * config_.head_dim + l] =\n                      GGML_FP32_TO_FP16(block_fp32[m]);\n                }\n              }\n            }\n          } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n            // get k_cache_\n            for (int k = 0; k < config_.block_len; k++) {\n              if (block_id * config_.block_len + k >= seq_len) break;\n              for (int l = 0; l < config_.head_dim / 32; l++) {\n                block_q8_0 block = k_cache_q8[layer_id_][head_id][block_idx][k * config_.head_dim / 32 + l];\n                dequantize_row_q8_0(&block, block_fp32.data(), 32);\n                for (int m = 0; m < 32; m++) {\n                  k_data_[batch_id * (max_block_num * 
config_.block_len * config_.kv_head_num * config_.head_dim) +\n                          block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) +\n                          k * (config_.kv_head_num * config_.head_dim) + head_id * config_.head_dim + l * 32 + m] =\n                      GGML_FP32_TO_FP16(block_fp32[m]);\n                }\n              }\n            }\n            // get v_cache_\n            for (int k = 0; k < config_.block_len / 32; k++) {\n              for (int l = 0; l < config_.head_dim; l++) {\n                block_q8_0 block = v_cache_q8[layer_id_][head_id][block_idx][l * config_.block_len / 32 + k];\n                dequantize_row_q8_0(&block, block_fp32.data(), 32);\n                for (int m = 0; m < 32; m++) {\n                  if (block_id * config_.block_len + k * 32 + m >= seq_len) break;\n                  v_data_[batch_id * (max_block_num * config_.block_len * config_.kv_head_num * config_.head_dim) +\n                          block_id * (config_.block_len * config_.kv_head_num * config_.head_dim) +\n                          (k * 32 + m) * config_.kv_head_num * config_.head_dim + head_id * config_.head_dim + l] =\n                      GGML_FP32_TO_FP16(block_fp32[m]);\n                }\n              }\n            }\n          }\n        }\n      },\n      nullptr);\n\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  std::chrono::duration<double> duration = end - start;\n}\n\nvoid KVCache::update_kvcache_fp16(const ggml_fp16_t* k_in, const ggml_fp16_t* v_in, int layer_id, int* block_table,\n                                  int batch_size, int max_block_num, int* cache_seqlens, int q_len,\n                                  WorkerPool* backend) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n\n  layer_id_ = layer_id;\n  k_data_ = const_cast<uint16_t*>(k_in);\n  v_data_ = const_cast<uint16_t*>(v_in);\n  // Each task updates the k cache and v cache of 
a certain head\n  backend->do_work_stealing_job(\n      batch_size * config_.kv_head_num * q_len, nullptr,\n      [&](int task_id) {\n        int batch_id = task_id / (config_.kv_head_num * q_len);\n        int head_id = task_id / q_len % config_.kv_head_num;\n        int seq_len = cache_seqlens[batch_id] + task_id % q_len;\n        int q_offset = task_id % q_len;\n\n        int block_id = seq_len / config_.block_len;\n        int block_idx = block_table[batch_id * max_block_num + block_id];\n        int pos_in_block = seq_len % config_.block_len;\n\n        if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n          for (int l = 0; l < config_.head_dim; l++) {\n            k_cache_fp16_[layer_id_][head_id][block_idx][pos_in_block * config_.head_dim + l] =\n                k_data_[batch_id * (q_len * config_.kv_head_num * config_.head_dim) +\n                        q_offset * config_.kv_head_num * config_.head_dim + head_id * config_.head_dim + l];\n            v_cache_fp16_[layer_id_][head_id][block_idx][l * config_.block_len + pos_in_block] =\n                v_data_[batch_id * (q_len * config_.kv_head_num * config_.head_dim) +\n                        q_offset * config_.kv_head_num * config_.head_dim + head_id * config_.head_dim + l];\n          }\n        } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n          std::vector<float> block_fp32(32);\n          // fill k_cache_\n          for (int l = 0; l < config_.head_dim / 32; l++) {\n            block_q4_0 block;\n            for (int m = 0; m < 32; m++) {\n              // Index the q_offset-th new token, matching the F16 branch above\n              block_fp32[m] = GGML_FP16_TO_FP32(k_data_[batch_id * (q_len * config_.kv_head_num * config_.head_dim) +\n                                                        q_offset * config_.kv_head_num * config_.head_dim +\n                                                        head_id * config_.head_dim + l * 32 + m]);\n            }\n            quantize_row_q4_0(block_fp32.data(), &block, 32);\n\n            k_cache_q4[layer_id_][head_id][block_idx][pos_in_block * config_.head_dim / 32 + l] = block;\n          }\n\n          // fill 
v_cache_\n          for (int l = 0; l < config_.head_dim; l++) {\n            block_q4_0 block =\n                v_cache_q4[layer_id_][head_id][block_idx][l * config_.block_len / 32 + pos_in_block / 32];\n            dequantize_row_q4_0(&block, block_fp32.data(), 32);\n            // Index the q_offset-th new token, matching the F16 branch\n            block_fp32[pos_in_block % 32] = GGML_FP16_TO_FP32(\n                v_data_[batch_id * (q_len * config_.kv_head_num * config_.head_dim) +\n                        q_offset * config_.kv_head_num * config_.head_dim + head_id * config_.head_dim + l]);\n            quantize_row_q4_0(block_fp32.data(), &block, 32);\n            v_cache_q4[layer_id_][head_id][block_idx][l * config_.block_len / 32 + pos_in_block / 32] = block;\n          }\n        } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n          std::vector<float> block_fp32(32);\n          // fill k_cache_\n          for (int l = 0; l < config_.head_dim / 32; l++) {\n            block_q8_0 block;\n            for (int m = 0; m < 32; m++) {\n              // Index the q_offset-th new token, matching the F16 branch\n              block_fp32[m] = GGML_FP16_TO_FP32(k_data_[batch_id * (q_len * config_.kv_head_num * config_.head_dim) +\n                                                        q_offset * config_.kv_head_num * config_.head_dim +\n                                                        head_id * config_.head_dim + l * 32 + m]);\n            }\n            quantize_row_q8_0(block_fp32.data(), &block, 32);\n\n            k_cache_q8[layer_id_][head_id][block_idx][pos_in_block * config_.head_dim / 32 + l] = block;\n          }\n\n          // fill v_cache_\n          for (int l = 0; l < config_.head_dim; l++) {\n            block_q8_0 block =\n                v_cache_q8[layer_id_][head_id][block_idx][l * config_.block_len / 32 + pos_in_block / 32];\n            dequantize_row_q8_0(&block, block_fp32.data(), 32);\n            // Index the q_offset-th new token, matching the F16 branch\n            block_fp32[pos_in_block % 32] = GGML_FP16_TO_FP32(\n                v_data_[batch_id * (q_len * config_.kv_head_num * config_.head_dim) +\n                        q_offset * config_.kv_head_num * config_.head_dim + head_id * config_.head_dim + l]);\n            quantize_row_q8_0(block_fp32.data(), &block, 32);\n            v_cache_q8[layer_id_][head_id][block_idx][l * config_.block_len / 32 + pos_in_block / 32] = block;\n         
 }\n        }\n      },\n      nullptr);\n\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  std::chrono::duration<double> duration = end - start;\n  // printf(\"layer %d time of reading KV Cache: %f s\\n\", layer_id,\n  //        duration.count());\n}\n\nvoid KVCache::get_all_kvcache_one_layer(int layer_id, ggml_fp16_t* k_in, ggml_fp16_t* v_in, WorkerPool* backend) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n\n  layer_id_ = layer_id;\n  seq_len_ = config_.block_len;\n  block_num_ = get_cache_total_block_num();\n  k_data_ = reinterpret_cast<uint16_t*>(k_in);\n  v_data_ = reinterpret_cast<uint16_t*>(v_in);\n\n  // Each task gets the k cache or v cache of a certain header\n  backend->do_work_stealing_job(\n      config_.kv_head_num * past_block_num_[layer_id] * 2, nullptr,\n      [&](int task_id) {\n        std::vector<float> block_fp32(32);\n        int head_id = task_id / 2 / past_block_num_[layer_id];\n        int block_idx = task_id / 2 % past_block_num_[layer_id];\n        if (block_idx >= block_num_) return;\n\n        int max_offset = 0;\n        if (task_id & 1) {\n          // get k_cache_\n          for (int k = 0; k < config_.block_len; k++) {\n            if (block_idx * seq_len_ + k >= cache_total_len_) break;\n            for (int l = 0; l < config_.head_dim / 32; l++) {\n              block_q4_0 block = k_cache_q4[layer_id_][head_id][block_idx][k * config_.head_dim / 32 + l];\n              dequantize_row_q4_0(&block, block_fp32.data(), 32);\n              for (int m = 0; m < 32; m++) {\n                k_data_[(head_id * cache_total_len_ + block_idx * config_.block_len + k) * config_.head_dim + l * 32 +\n                        m] = GGML_FP32_TO_FP16(block_fp32[m]);\n                max_offset =\n                    std::max(max_offset,\n                             (int)(head_id * cache_total_len_ + block_idx * config_.block_len + k) * config_.head_dim +\n                                 l 
* 32 + m);\n              }\n            }\n          }\n        } else {\n          // get v_cache_\n          for (int k = 0; k < config_.block_len / 32; k++) {\n            for (int l = 0; l < config_.head_dim; l++) {\n              block_q4_0 block = v_cache_q4[layer_id_][head_id][block_idx][l * config_.block_len / 32 + k];\n              dequantize_row_q4_0(&block, block_fp32.data(), 32);\n              for (int m = 0; m < 32; m++) {\n                if (block_idx * seq_len_ + k * 32 + m >= cache_total_len_) break;\n                v_data_[(head_id * cache_total_len_ + block_idx * config_.block_len + k * 32 + m) * config_.head_dim +\n                        l] = GGML_FP32_TO_FP16(block_fp32[m]);\n                max_offset = std::max(\n                    max_offset,\n                    (int)((head_id * cache_total_len_ + block_idx * config_.block_len + k * 32 + m) * config_.head_dim +\n                          l));\n              }\n            }\n          }\n        }\n      },\n      nullptr);\n\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  std::chrono::duration<double> duration = end - start;\n  // printf(\"layer %d block num %d time of reading all KV Cache: %f s\\n\",\n  //        layer_id, block_num_, duration.count());\n}\n"
  },
  {
    "path": "kt-kernel/operators/kvcache/kvcache_utils.cpp",
    "content": "/**\n * @Description  :\n * @Author       : Jianwei Dong\n * @Date         : 2024-08-26 22:47:06\n * @Version      : 1.0.0\n * @LastEditors  : Jianwei Dong\n * @LastEditTime : 2024-08-26 22:47:06\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n\n#include <chrono>\n\n#include \"ggml-impl.h\"\n#include \"kvcache.h\"\n\nstd::string ggml_type_to_string(ggml_type type) {\n  switch (type) {\n    case GGML_TYPE_F32:\n      return \"GGML_TYPE_F32\";\n    case GGML_TYPE_F16:\n      return \"GGML_TYPE_F16\";\n    case GGML_TYPE_Q4_0:\n      return \"GGML_TYPE_Q4_0\";\n    case GGML_TYPE_Q8_0:\n      return \"GGML_TYPE_Q8_0\";\n  }\n  return \"UNDIFINED\";\n}\nstd::string AnchorTypeToString(AnchorType type) {\n  switch (type) {\n    case AnchorType::DYNAMIC:\n      return \"DYNAMIC\";\n    case AnchorType::BLOCK_MEAN:\n      return \"BLOCK_MEAN\";\n    case AnchorType::BLOCK_MAX:\n      return \"BLOCK_MAX\";\n    case AnchorType::FIXED_ANCHOR:\n      return \"FIXED_ANCHOR\";\n    case AnchorType::QUEST:\n      return \"QUEST\";\n  }\n  return \"UNDIFINED\";\n}\nstd::string RetrievalTypeToString(RetrievalType type) {\n  switch (type) {\n    case RetrievalType::LAYER:\n      return \"SHARED\";\n    case RetrievalType::KVHEAD:\n      return \"SEPARATE\";\n    case RetrievalType::QHEAD:\n      return \"INDIVIDUAL\";\n  }\n  return \"UNDIFINED\";\n}\nKVCacheConfig::KVCacheConfig(int layer_num, int kv_head_num, int q_head_num, int head_dim, int block_len,\n                             int anchor_num, AnchorType anchor_type, ggml_type kv_type, RetrievalType retrieval_type,\n                             int layer_step, int token_step, int layer_offset, int max_block_num, int max_batch_size,\n                             int max_thread_num)\n    : layer_num(layer_num),\n      kv_head_num(kv_head_num),\n      q_head_num(q_head_num),\n      head_dim(head_dim),\n      block_len(block_len),\n      anchor_num(anchor_num),\n      
anchor_type(anchor_type),\n      kv_type(kv_type),\n      retrieval_type(retrieval_type),\n      layer_step(layer_step),\n      token_step(token_step),\n      layer_offset(layer_offset),\n      max_block_num(max_block_num),\n      max_batch_size(max_batch_size),\n      max_thread_num(max_thread_num) {\n  printf(\n      \"layer_num: %d, kv_head_num: %d, q_head_num: %d, head_dim: %d, \"\n      \"block_len: %d, anchor_num: %d, anchor_type: %s, kv_type: %s, \"\n      \"retrieval_type: %s, layer_step: %d, token_step: %d, layer_offset: %d,\"\n      \"max_block_num: %d, max_batch_size: %d, max_thread_num: %d\\n\",\n      layer_num, kv_head_num, q_head_num, head_dim, block_len, anchor_num, AnchorTypeToString(anchor_type).c_str(),\n      ggml_type_to_string(kv_type).c_str(), RetrievalTypeToString(retrieval_type).c_str(), layer_step, token_step,\n      layer_offset, max_block_num, max_batch_size, max_thread_num);\n  assert(q_head_num % kv_head_num == 0);\n}\nKVCache::KVCache(KVCacheConfig config) {\n  this->config_ = config;\n\n  n_gqa_ = config_.q_head_num / config_.kv_head_num;\n  if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n    // TODO: Elegant implement\n    k_cache_fp16_.resize(config_.layer_num);\n    v_cache_fp16_.resize(config_.layer_num);\n    selected_blocks_num_history_.resize(config_.layer_num / config_.layer_step);\n    if (config_.retrieval_type == RetrievalType::LAYER) {\n      selected_blocks_history_.resize(config_.layer_num / config_.layer_step);\n    } else if (config_.retrieval_type == RetrievalType::KVHEAD) {\n      selected_blocks_history_kvhead_.resize(config_.layer_num / config_.layer_step);\n    } else if (config_.retrieval_type == RetrievalType::QHEAD) {\n    }\n  } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n    k_cache_q4.resize(config.layer_num);\n    v_cache_q4.resize(config.layer_num);\n  } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n    k_cache_q8.resize(config.layer_num);\n    
v_cache_q8.resize(config.layer_num);\n  } else {\n    assert(false);\n  }\n  anchor_.resize(config.layer_num * config.max_block_num * config.anchor_num * config.q_head_num * config.head_dim);\n  importance_.resize(config.layer_num);\n  past_block_num_.resize(config.layer_num);\n  for (int i = 0; i < config.layer_num; i++) {\n    past_block_num_[i] = 0;\n  }\n\n  ThreadResize(config.max_thread_num);\n  BatchResize(config.max_batch_size);\n  BlockResize(config.max_block_num);\n  q_fp32.resize(n_gqa_ * config.head_dim);\n}\n\nvoid KVCache::ThreadResize(int thread_num) {\n  thread_local_output_q8_0_.resize(thread_num);\n  thread_local_attn_score_.resize(thread_num);\n  thread_local_output_fp32_.resize(thread_num);\n  thread_local_attn_lse_.resize(thread_num);\n  thread_local_cur_output_fp32_.resize(thread_num);\n  thread_local_cur_attn_lse_.resize(thread_num);\n  thread_local_draft_.resize(thread_num);\n  thread_cur_head_idx_.resize(thread_num);\n  thread_local_attn_mask_.resize(thread_num);\n  for (int i = 0; i < thread_num; i++) {\n    thread_local_output_q8_0_[i].resize(n_gqa_ * config_.head_dim / QK8_0);\n    thread_local_attn_score_[i].resize(n_gqa_ * config_.block_len);\n    thread_local_output_fp32_[i].resize(n_gqa_ * config_.head_dim);\n    thread_local_attn_lse_[i].resize(n_gqa_);\n    thread_local_cur_output_fp32_[i].resize(n_gqa_ * config_.head_dim);\n    thread_local_cur_attn_lse_[i].resize(n_gqa_);\n    thread_local_draft_[i].resize(2 * n_gqa_ * config_.block_len + 6 * n_gqa_ * config_.head_dim +\n                                  2 * config_.block_len * config_.head_dim +\n                                  config_.block_len * config_.head_dim / QK4_0);\n    thread_local_attn_mask_[i].resize(config_.block_len / 8);\n  }\n}\nvoid KVCache::BatchResize(int batch_size) {\n  mutex_.resize(batch_size);\n  q_q8_0_.resize(batch_size);\n  q_fp32_.resize(batch_size);\n  output_fp32_.resize(batch_size);\n  attn_lse_.resize(batch_size);\n  
block_lse_.resize(batch_size);\n  attn_sparsity_.resize(batch_size);\n\n  if (config_.retrieval_type == RetrievalType::LAYER) {\n    block_table_before_retrieval_.resize(batch_size);\n    block_table_after_retrieval_.resize(batch_size);\n\n    for (int i = 0; i < config_.layer_num / config_.layer_step; i++) {\n      selected_blocks_history_[i].resize(batch_size);\n    }\n\n  } else if (config_.retrieval_type == RetrievalType::KVHEAD) {\n    block_table_before_retrieval_kvhead_.resize(batch_size);\n    block_table_after_retrieval_kvhead_.resize(batch_size);\n    for (int i = 0; i < config_.layer_num / config_.layer_step; i++) {\n      selected_blocks_history_kvhead_[i].resize(batch_size);\n    }\n  } else if (config_.retrieval_type == RetrievalType::QHEAD) {\n    block_table_before_retrieval_qhead_.resize(batch_size);\n    block_table_after_retrieval_qhead_.resize(batch_size);\n  }\n  cache_seqlens_.resize(batch_size);\n  if (config_.retrieval_type == RetrievalType::LAYER) {\n    block_similar_.resize(batch_size);\n  } else if (config_.retrieval_type == RetrievalType::KVHEAD) {\n    block_similar_kv_head_.resize(batch_size);\n  } else if (config_.retrieval_type == RetrievalType::QHEAD) {\n    block_similar_q_head_.resize(batch_size);\n  }\n  for (int i = 0; i < batch_size; i++) {\n    top_similar_block_.resize(batch_size);\n\n    mutex_[i].resize(config_.kv_head_num);\n    q_q8_0_[i].resize(config_.kv_head_num);\n    q_fp32_[i].resize(config_.kv_head_num);\n    output_fp32_[i].resize(config_.kv_head_num);\n    attn_lse_[i].resize(config_.kv_head_num);\n\n    for (int j = 0; j < config_.kv_head_num; j++) {\n      if (!mutex_[i][j]) {\n        mutex_[i][j] = std::make_unique<std::mutex>();\n      }\n      q_q8_0_[i][j].resize(n_gqa_ * config_.head_dim / QK8_0);\n      q_fp32_[i][j].resize(n_gqa_ * config_.head_dim);\n      output_fp32_[i][j].resize(n_gqa_ * config_.head_dim);\n      attn_lse_[i][j].resize(n_gqa_);\n    }\n  }\n  avg_q.resize(batch_size);\n  
avg_q_fp16.resize(batch_size);\n  for (int i = 0; i < batch_size; i++) {\n    attn_sparsity_[i].resize(config_.q_head_num);\n    avg_q[i].resize(config_.q_head_num * config_.head_dim);\n    avg_q_fp16[i].resize(config_.q_head_num * config_.head_dim);\n  }\n}\n\nvoid KVCache::BlockResize(int max_block_num) {\n  sin_.resize(max_block_num * config_.block_len);\n  cos_.resize(max_block_num * config_.block_len);\n  for (int i = 0; i < max_block_num * config_.block_len; i++) {\n    sin_[i].resize(config_.head_dim);\n    cos_[i].resize(config_.head_dim);\n  }\n\n  for (int i = 0; i < config_.layer_num / config_.layer_step; i++) {\n    for (int j = 0; j < config_.max_batch_size; j++) {\n      if (config_.retrieval_type == RetrievalType::LAYER) {\n        selected_blocks_history_[i][j].resize(max_block_num);\n      } else if (config_.retrieval_type == RetrievalType::KVHEAD) {\n        selected_blocks_history_kvhead_[i][j].resize(max_block_num);\n        for (int k = 0; k < config_.max_block_num; k++) {\n          selected_blocks_history_kvhead_[i][j][k].resize(config_.kv_head_num);\n        }\n      } else if (config_.retrieval_type == RetrievalType::QHEAD) {\n      }\n    }\n  }\n\n  for (int layer_id = 0; layer_id < config_.layer_num; layer_id++) {\n    importance_[layer_id].resize(max_block_num);\n\n    if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n      // TODO: Elegant implement\n      k_cache_fp16_[layer_id].resize(config_.kv_head_num);\n      v_cache_fp16_[layer_id].resize(config_.kv_head_num);\n\n      for (int i = 0; i < config_.kv_head_num; i++) {\n        k_cache_fp16_[layer_id][i].resize(max_block_num);\n        v_cache_fp16_[layer_id][i].resize(max_block_num);\n\n        for (int j = 0; j < max_block_num; j++) {\n          k_cache_fp16_[layer_id][i][j].resize(config_.block_len * config_.head_dim);\n          v_cache_fp16_[layer_id][i][j].resize(config_.block_len * config_.head_dim);\n        }\n      }\n\n    } else if (config_.kv_type == 
ggml_type::GGML_TYPE_Q4_0) {\n      k_cache_q4[layer_id].resize(config_.kv_head_num);\n      v_cache_q4[layer_id].resize(config_.kv_head_num);\n      for (int i = 0; i < config_.kv_head_num; i++) {\n        k_cache_q4[layer_id][i].resize(max_block_num);\n        v_cache_q4[layer_id][i].resize(max_block_num);\n\n        for (int j = 0; j < max_block_num; j++) {\n          k_cache_q4[layer_id][i][j].resize(config_.block_len * config_.head_dim / 32);\n          v_cache_q4[layer_id][i][j].resize(config_.block_len * config_.head_dim / 32);\n        }\n      }\n    } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n      k_cache_q8[layer_id].resize(config_.kv_head_num);\n      v_cache_q8[layer_id].resize(config_.kv_head_num);\n      for (int i = 0; i < config_.kv_head_num; i++) {\n        k_cache_q8[layer_id][i].resize(max_block_num);\n        v_cache_q8[layer_id][i].resize(max_block_num);\n\n        for (int j = 0; j < max_block_num; j++) {\n          k_cache_q8[layer_id][i][j].resize(config_.block_len * config_.head_dim / 32);\n          v_cache_q8[layer_id][i][j].resize(config_.block_len * config_.head_dim / 32);\n        }\n      }\n    } else {\n      assert(false);\n    }\n    for (int i = 0; i < config_.max_batch_size; i++) {\n      if (config_.retrieval_type == RetrievalType::LAYER) {\n        block_similar_[i].resize(max_block_num);\n        block_table_before_retrieval_[i].resize(max_block_num);\n        block_table_after_retrieval_[i].resize(max_block_num);\n      } else if (config_.retrieval_type == RetrievalType::KVHEAD) {\n        block_similar_kv_head_[i].resize(max_block_num);\n        block_table_before_retrieval_kvhead_[i].resize(max_block_num);\n        block_table_after_retrieval_kvhead_[i].resize(max_block_num);\n        for (int j = 0; j < max_block_num; j++) {\n          block_similar_kv_head_[i][j].resize(config_.kv_head_num);\n          block_table_before_retrieval_kvhead_[i][j].resize(config_.kv_head_num);\n          
block_table_after_retrieval_kvhead_[i][j].resize(config_.kv_head_num);\n        }\n      } else if (config_.retrieval_type == RetrievalType::QHEAD) {\n        block_similar_q_head_[i].resize(max_block_num);\n        block_table_before_retrieval_qhead_[i].resize(max_block_num);\n        block_table_after_retrieval_qhead_[i].resize(max_block_num);\n        for (int j = 0; j < max_block_num; j++) {\n          block_similar_q_head_[i][j].resize(config_.q_head_num);\n          block_table_before_retrieval_qhead_[i][j].resize(config_.q_head_num);\n          block_table_after_retrieval_qhead_[i][j].resize(config_.q_head_num);\n        }\n      }\n      block_lse_[i].resize(max_block_num);\n      for (int j = 0; j < max_block_num; j++) {\n        block_lse_[i][j].resize(config_.q_head_num);\n      }\n    }\n\n    for (int i = 0; i < max_block_num; i++) {\n      importance_[layer_id][i].resize(config_.block_len);\n      for (int j = 0; j < config_.block_len; j++) {\n        importance_[layer_id][i][j].resize(config_.q_head_num);\n      }\n    }\n  }\n}\n\nvoid KVCache::calc_anchor_all_layers(int* block_table, int* cache_seqlens, int batch_size, int max_block_num,\n                                     WorkerPool* backend) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n\n  // Each task updates the importance of a certain block\n  seq_len_ = config_.block_len;\n  backend->do_work_stealing_job(\n      config_.layer_num * batch_size * max_block_num, nullptr,\n      [&](int task_id) {\n        int layer_id = task_id / (batch_size * max_block_num);\n        int batch_id = (task_id / max_block_num) % batch_size;\n        int block_id = task_id % max_block_num;\n        // If the block is out of the sequence length, skip it. 
In\n        // particular, the last block of the sequence that is shorter than\n        // the block length should be skipped.\n\n        if (cache_seqlens[batch_id] / config_.block_len < block_id) {\n          return;\n        }\n        int block_idx = block_table[batch_id * max_block_num + block_id];\n\n        std::vector<float> block_fp32(32);\n        if (config_.anchor_type == AnchorType::DYNAMIC) {\n          // clear anchor_\n          for (int anchor_id = 0; anchor_id < 1; anchor_id++) {\n            for (int head_id = 0; head_id < config_.q_head_num; head_id++) {\n              for (int l = 0; l < config_.head_dim; l++) {\n                anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                        block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                        anchor_id * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l] = 0;\n              }\n            }\n          }\n\n          // find top anchor_num importances and their corresponding\n          // positions in the importance_ tensor\n          // TODO: Move top_importances to the class member to avoid\n          // repeated memory allocation\n          std::priority_queue<std::pair<float, std::pair<int, int>>, std::vector<std::pair<float, std::pair<int, int>>>,\n                              std::greater<>>\n              top_importances;\n          for (int head_id = 0; head_id < config_.q_head_num; head_id++) {\n            for (int k = 0; k < seq_len_; k++) {\n              top_importances.push(std::make_pair(GGML_FP16_TO_FP32(importance_[layer_id][block_idx][k][head_id]),\n                                                  std::make_pair(block_idx, k)));\n              // TODO: change to config_ item\n              if (top_importances.size() > config_.anchor_num) {\n                top_importances.pop();\n              }\n            }\n\n            // fill anchor_\n\n   
         for (int l = 0; l < config_.head_dim; l++) {\n              anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                      block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                      0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l] = 0;\n            }\n            for (int k = 0; k < config_.anchor_num; k++) {\n              int top_indice = top_importances.top().second.second;\n              int top_block_idx = top_importances.top().second.first;\n\n              if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n                for (int l = 0; l < config_.head_dim; l++) {\n                  anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num *\n                              config_.head_dim +\n                          top_block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                          0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l] =\n                      GGML_FP32_TO_FP16(\n                          GGML_FP16_TO_FP32(\n                              anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num *\n                                          config_.head_dim +\n                                      top_block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                                      0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l]) +\n                          GGML_FP16_TO_FP32(k_cache_fp16_[layer_id][head_id / n_gqa_][top_block_idx]\n                                                         [top_indice * config_.head_dim + l]));\n                }\n\n              } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n                for (int l = 0; l < config_.head_dim / 32; l++) {\n                  block_q4_0 block =\n                   
   k_cache_q4[layer_id][head_id / n_gqa_][top_block_idx][top_indice * config_.head_dim / 32 + l];\n                  dequantize_row_q4_0(&block, block_fp32.data(), 32);\n                  for (int m = 0; m < 32; m++) {\n                    anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num *\n                                config_.head_dim +\n                            top_block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                            0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l * 32 + m] =\n                        GGML_FP32_TO_FP16(\n                            block_fp32[m] / 4 +\n                            GGML_FP16_TO_FP32(\n                                anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num *\n                                            config_.head_dim +\n                                        top_block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                                        0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim +\n                                        l * 32 + m]));\n                  }\n                }\n              } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n                for (int l = 0; l < config_.head_dim / 32; l++) {\n                  block_q8_0 block =\n                      k_cache_q8[layer_id][head_id / n_gqa_][top_block_idx][top_indice * config_.head_dim / 32 + l];\n                  dequantize_row_q8_0(&block, block_fp32.data(), 32);\n                  for (int m = 0; m < 32; m++) {\n                    anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num *\n                                config_.head_dim +\n                            top_block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                            0 * config_.q_head_num * 
config_.head_dim + head_id * config_.head_dim + l * 32 + m] =\n                        GGML_FP32_TO_FP16(\n                            block_fp32[m] / 4 +\n                            GGML_FP16_TO_FP32(\n                                anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num *\n                                            config_.head_dim +\n                                        top_block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                                        0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim +\n                                        l * 32 + m]));\n                  }\n                }\n              }\n              top_importances.pop();\n            }\n          }\n        } else if (config_.anchor_type == AnchorType::BLOCK_MEAN) {\n          // clear anchor_\n          for (int anchor_id = 0; anchor_id < config_.anchor_num; anchor_id++) {\n            for (int head_id = 0; head_id < config_.q_head_num; head_id++) {\n              for (int l = 0; l < config_.head_dim; l++) {\n                anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                        block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                        anchor_id * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l] = 0;\n              }\n            }\n          }\n\n          // fill anchor_\n          if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n            for (int head_id = 0; head_id < config_.q_head_num; head_id++) {\n              for (int k = 0; k < config_.block_len; k++) {\n                for (int l = 0; l < config_.head_dim; l++) {\n                  anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num *\n                              config_.head_dim +\n                          block_idx * config_.anchor_num * 
config_.q_head_num * config_.head_dim +\n                          0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l] =\n                      GGML_FP32_TO_FP16(\n                          GGML_FP16_TO_FP32(\n                              anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num *\n                                          config_.head_dim +\n                                      block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                                      0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l]) +\n                          GGML_FP16_TO_FP32(\n                              k_cache_fp16_[layer_id][head_id / n_gqa_][block_idx][k * config_.head_dim + l]) /\n                              config_.block_len);\n                }\n              }\n            }\n          }\n        } else if (config_.anchor_type == AnchorType::BLOCK_MAX) {\n          // clear anchor_\n          for (int anchor_id = 0; anchor_id < config_.anchor_num; anchor_id++) {\n            for (int head_id = 0; head_id < config_.q_head_num; head_id++) {\n              for (int l = 0; l < config_.head_dim; l++) {\n                anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                        block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                        anchor_id * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l] = 0;\n              }\n            }\n          }\n\n          // fill anchor_\n          if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n            for (int head_id = 0; head_id < config_.q_head_num; head_id++) {\n              for (int k = 0; k < config_.block_len; k++) {\n                for (int l = 0; l < config_.head_dim; l++) {\n                  anchor_[layer_id * config_.max_block_num * config_.anchor_num * 
config_.q_head_num *\n                              config_.head_dim +\n                          block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                          0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l] =\n                      GGML_FP32_TO_FP16(std::max(\n                          GGML_FP16_TO_FP32(\n                              anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num *\n                                          config_.head_dim +\n                                      block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                                      0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l]),\n                          GGML_FP16_TO_FP32(\n                              k_cache_fp16_[layer_id][head_id / n_gqa_][block_idx][k * config_.head_dim + l])));\n                }\n              }\n            }\n          }\n        } else if (config_.anchor_type == AnchorType::FIXED_ANCHOR) {\n          // clear anchor_\n          for (int anchor_id = 0; anchor_id < 1; anchor_id++) {\n            for (int head_id = 0; head_id < config_.q_head_num; head_id++) {\n              for (int l = 0; l < config_.head_dim; l++) {\n                anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                        block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                        anchor_id * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l] = 0;\n              }\n            }\n          }\n\n          // fill anchor_\n          if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n            int stride = config_.block_len / config_.anchor_num;\n            for (int head_id = 0; head_id < config_.q_head_num; head_id++) {\n              for (int k = 0, tot = 0; k < config_.block_len && tot < 
config_.anchor_num; k += stride, tot++) {\n                for (int l = 0; l < config_.head_dim; l++) {\n                  anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num *\n                              config_.head_dim +\n                          block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                          0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l] =\n                      GGML_FP32_TO_FP16(\n                          GGML_FP16_TO_FP32(\n                              anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num *\n                                          config_.head_dim +\n                                      block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                                      0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l]) +\n                          GGML_FP16_TO_FP32(\n                              k_cache_fp16_[layer_id][head_id / n_gqa_][block_idx][k * config_.head_dim + l]) /\n                              config_.anchor_num);\n                }\n              }\n            }\n          }\n\n        } else if (config_.anchor_type == AnchorType::QUEST) {\n          // clear anchor_\n          for (int head_id = 0; head_id < config_.q_head_num; head_id++) {\n            for (int l = 0; l < config_.head_dim; l++) {\n              anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                      block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                      1 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l] =\n                  GGML_FP32_TO_FP16(std::numeric_limits<float>::max());\n\n              anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                      
block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                      0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l] =\n                  GGML_FP32_TO_FP16(std::numeric_limits<float>::min());\n            }\n          }\n\n          // fill anchor_\n\n          if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n            for (int indice = 0; indice < seq_len_; indice++) {\n              for (int head_id = 0; head_id < config_.kv_head_num; head_id++) {\n                for (int l = 0; l < config_.head_dim; l++) {\n                  anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num *\n                              config_.head_dim +\n                          block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                          0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l] =\n                      GGML_FP32_TO_FP16(std::max(\n                          GGML_FP16_TO_FP32(k_cache_fp16_[layer_id][head_id][block_idx][indice * config_.head_dim + l]),\n                          GGML_FP16_TO_FP32(\n                              anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num *\n                                          config_.head_dim +\n                                      block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                                      0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l])));\n\n                  anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num *\n                              config_.head_dim +\n                          block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                          1 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l] =\n                      GGML_FP32_TO_FP16(std::min(\n                        
  GGML_FP16_TO_FP32(k_cache_fp16_[layer_id][head_id][block_idx][indice * config_.head_dim + l]),\n                          GGML_FP16_TO_FP32(\n                              anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num *\n                                          config_.head_dim +\n                                      block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                                      1 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l])));\n                }\n              }\n            }\n\n          } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n            for (int indice = 0; indice < seq_len_; indice++) {\n              for (int head_id = 0; head_id < config_.kv_head_num; head_id++) {\n                for (int l = 0; l < config_.head_dim / 32; l++) {\n                  block_q4_0 block = k_cache_q4[layer_id][head_id][block_idx][indice * config_.head_dim / 32 + l];\n                  dequantize_row_q4_0(&block, block_fp32.data(), 32);\n\n                  for (int m = 0; m < 32; m++) {\n                    for (int gqa_idx = 0; gqa_idx < n_gqa_; gqa_idx++) {\n                      anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num *\n                                  config_.head_dim +\n                              block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                              0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l * 32 + m] =\n                          GGML_FP32_TO_FP16(std::max(\n                              block_fp32[m],\n                              GGML_FP16_TO_FP32(\n                                  anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num *\n                                              config_.head_dim +\n                                          block_idx * config_.anchor_num * 
config_.q_head_num * config_.head_dim +\n                                          0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim +\n                                          l * 32 + m])));\n\n                      anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num *\n                                  config_.head_dim +\n                              block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                              1 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l * 32 + m] =\n                          GGML_FP32_TO_FP16(std::min(\n                              block_fp32[m],\n                              GGML_FP16_TO_FP32(\n                                  anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num *\n                                              config_.head_dim +\n                                          block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                                          1 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim +\n                                          l * 32 + m])));\n                    }\n                  }\n                }\n              }\n            }\n          } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n            for (int indice = 0; indice < seq_len_; indice++) {\n              for (int head_id = 0; head_id < config_.kv_head_num; head_id++) {\n                for (int l = 0; l < config_.head_dim / 32; l++) {\n                  block_q8_0 block = k_cache_q8[layer_id][head_id][block_idx][indice * config_.head_dim / 32 + l];\n                  dequantize_row_q8_0(&block, block_fp32.data(), 32);\n\n                  for (int m = 0; m < 32; m++) {\n                    for (int gqa_idx = 0; gqa_idx < n_gqa_; gqa_idx++) {\n                      anchor_[layer_id * config_.max_block_num * 
config_.anchor_num * config_.q_head_num *\n                                  config_.head_dim +\n                              block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                              0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l * 32 + m] =\n                          GGML_FP32_TO_FP16(std::max(\n                              block_fp32[m],\n                              GGML_FP16_TO_FP32(\n                                  anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num *\n                                              config_.head_dim +\n                                          block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                                          0 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim +\n                                          l * 32 + m])));\n\n                      anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num *\n                                  config_.head_dim +\n                              block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                              1 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim + l * 32 + m] =\n                          GGML_FP32_TO_FP16(std::min(\n                              block_fp32[m],\n                              GGML_FP16_TO_FP32(\n                                  anchor_[layer_id * config_.max_block_num * config_.anchor_num * config_.q_head_num *\n                                              config_.head_dim +\n                                          block_idx * config_.anchor_num * config_.q_head_num * config_.head_dim +\n                                          1 * config_.q_head_num * config_.head_dim + head_id * config_.head_dim +\n                                          l * 32 + m])));\n                    }\n                
  }\n                }\n              }\n            }\n          }\n        } else {\n          assert(false);\n        }\n      },\n      nullptr);\n\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  std::chrono::duration<double> duration = end - start;\n  //    printf(\"time of calc_anchor_all_layers: %f s\\n\", duration.count());\n}\n\nvoid KVCache::clear_importance_all_layers(int* block_table, int* cache_seqlens, int batch_size, int max_block_num,\n                                          WorkerPool* backend) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n\n  // Each task updates the importance of a certain block\n  seq_len_ = config_.block_len;\n  backend->do_work_stealing_job(\n      config_.layer_num * batch_size * max_block_num, nullptr,\n      [&](int task_id) {\n        int layer_id = task_id / (batch_size * max_block_num);\n        int batch_id = (task_id / max_block_num) % batch_size;\n        int block_id = task_id % max_block_num;\n        // If the block is out of the sequence length, skip it. 
In\n        // particular, the last block of the sequence that is shorter than\n        // the block length should be skipped.\n\n        if (cache_seqlens[batch_id] / config_.block_len < block_id) {\n          return;\n        }\n        int block_idx = block_table[batch_id * max_block_num + block_id];\n\n        if (config_.anchor_type == AnchorType::DYNAMIC) {\n          // clear importance_\n          for (int head_id = 0; head_id < config_.q_head_num; head_id++) {\n            for (int l = 0; l < config_.block_len; l++) {\n              importance_[layer_id][block_idx][l][head_id] = 0;\n            }\n          }\n        }\n      },\n      nullptr);\n\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  std::chrono::duration<double> duration = end - start;\n  //    printf(\"time of clear_importance_all_layers: %f s\\n\",\n  //    duration.count());\n}\n\nvoid KVCache::clear_kvcache_all_layers(int* block_table, int* cache_seqlens, int batch_size, int max_block_num,\n                                       WorkerPool* backend) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n\n  // Each task clears the KV cache of one (layer, batch, block, kv head) combination\n  seq_len_ = config_.block_len;\n  backend->do_work_stealing_job(\n      config_.layer_num * batch_size * max_block_num * config_.kv_head_num, nullptr,\n      [&](int task_id) {\n        int layer_id = task_id / (batch_size * max_block_num * config_.kv_head_num);\n        int batch_id = (task_id / (max_block_num * config_.kv_head_num)) % batch_size;\n        int block_id = task_id / config_.kv_head_num % max_block_num;\n        int head_id = task_id % config_.kv_head_num;\n        // If the block is out of the sequence length, skip it. 
In\n        // particular, the last block of the sequence that is shorter than\n        // the block length should be skipped.\n        if (cache_seqlens[batch_id] / config_.block_len < block_id) {\n          return;\n        }\n        int block_idx = block_table[batch_id * max_block_num + block_id];\n\n        if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n          for (int l = 0; l < config_.block_len * config_.head_dim; l++) {\n            k_cache_fp16_[layer_id][head_id][block_idx][l] = 0;\n            v_cache_fp16_[layer_id][head_id][block_idx][l] = 0;\n          }\n        } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n          for (int l = 0; l < config_.block_len * config_.head_dim / 32; l++) {\n            k_cache_q4[layer_id][head_id][block_idx][l].d = 0;\n            v_cache_q4[layer_id][head_id][block_idx][l].d = 0;\n          }\n        } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n          for (int l = 0; l < config_.block_len * config_.head_dim / 32; l++) {\n            k_cache_q8[layer_id][head_id][block_idx][l].d = 0;\n            v_cache_q8[layer_id][head_id][block_idx][l].d = 0;\n          }\n        }\n      },\n      nullptr);\n\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  std::chrono::duration<double> duration = end - start;\n  //    printf(\"time of clear_kvcache_all_layers: %f s\\n\", duration.count());\n}\n\nvoid KVCache::get_sincos(ggml_fp16_t* sin, ggml_fp16_t* cos, int seqlen) {\n  // Timer start\n  auto start = std::chrono::high_resolution_clock::now();\n\n  const uint16_t* sin_data = const_cast<const uint16_t*>(sin);\n  const uint16_t* cos_data = const_cast<const uint16_t*>(cos);\n\n  for (int i = 0; i < seqlen; i++) {\n    for (int j = 0; j < config_.head_dim; j++) {\n      sin_[i][j] = sin_data[i * config_.head_dim + j];\n      cos_[i][j] = cos_data[i * config_.head_dim + j];\n    }\n  }\n\n  // Timer end\n  auto end = std::chrono::high_resolution_clock::now();\n  
std::chrono::duration<double> duration = end - start;\n  printf(\"time of get_sincos: %f s\\n\", duration.count());\n}\n\nvoid ggml_vec_scale_f32(const int n, float* y, const float v) {\n#if defined(GGML_USE_ACCELERATE)\n  vDSP_vsmul(y, 1, &v, y, 1, n);\n#elif defined(GGML_SIMD)\n  const int np = (n & ~(GGML_F32_STEP - 1));\n\n  GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);\n\n  GGML_F32_VEC ay[GGML_F32_ARR];\n\n  for (int i = 0; i < np; i += GGML_F32_STEP) {\n    for (int j = 0; j < GGML_F32_ARR; j++) {\n      ay[j] = GGML_F32_VEC_LOAD(y + i + j * GGML_F32_EPR);\n      ay[j] = GGML_F32_VEC_MUL(ay[j], vx);\n\n      GGML_F32_VEC_STORE(y + i + j * GGML_F32_EPR, ay[j]);\n    }\n  }\n\n  // leftovers\n  for (int i = np; i < n; ++i) {\n    y[i] *= v;\n  }\n#else\n  // scalar\n  for (int i = 0; i < n; ++i) {\n    y[i] *= v;\n  }\n#endif\n}"
  },
  {
    "path": "kt-kernel/operators/llamafile/conversion.h",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-12 10:07:58\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2024-07-25 10:34:55\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#ifndef CPUINFER_CONVERSION_H\n#define CPUINFER_CONVERSION_H\n\n#include <memory.h>\n\n#include \"llama.cpp/ggml-quants.h\"\n#include \"llama.cpp/ggml.h\"\n\ninline void to_float(const void* input, float* output, int size, ggml_type type) {\n  if (type == ggml_type::GGML_TYPE_F32) {\n    memcpy(output, input, size * sizeof(float));\n  } else {\n    if (type == ggml_type::GGML_TYPE_Q8_K) {\n      dequantize_row_q8_K((block_q8_K*)input, output, size);\n    } else {\n      ggml_internal_get_type_traits(type).to_float(input, output, size);\n    }\n  }\n}\n\ninline void from_float(const float* input, void* output, int size, ggml_type type) {\n  if (type == ggml_type::GGML_TYPE_F32) {\n    memcpy(output, input, size * sizeof(float));\n  } else {\n    ggml_internal_get_type_traits(type).from_float(input, output, size);\n  }\n}\n\n#endif"
  },
  {
    "path": "kt-kernel/operators/llamafile/linear.cpp",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-12 10:07:58\n * @Version      : 1.0.0\n * @LastEditors  : kkk1nak0\n * @LastEditTime : 2024-08-15 07:45:18\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#include \"linear.h\"\n\nLinear::Linear(LinearConfig config) {\n  config_ = config;\n  proj_ = config_.proj;\n\n  MemoryRequest mem_requests;\n  mem_requests.append_pointer(&input_fp32_, sizeof(float) * config_.group_max_len * config_.input_size);\n  mem_requests.append_pointer(&proj_input_,\n                              config_.group_max_len * config_.input_size *\n                                  ggml_type_size(ggml_internal_get_type_traits(config_.proj_type).vec_dot_type) /\n                                  ggml_blck_size(ggml_internal_get_type_traits(config_.proj_type).vec_dot_type));\n  mem_requests.append_pointer(&proj_output_, sizeof(float) * config_.group_max_len * config_.output_size);\n  shared_mem_buffer.alloc(this, mem_requests);\n}\n\nLinear::~Linear() {}\n\nvoid Linear::warm_up(WorkerPool* backend) {\n  std::vector<float> input_fp32(config_.input_size);\n  std::vector<uint8_t> input(config_.input_size * ggml_type_size(config_.hidden_type) /\n                             ggml_blck_size(config_.hidden_type));\n  std::vector<uint8_t> output(config_.output_size * ggml_type_size(config_.hidden_type) /\n                              ggml_blck_size(config_.hidden_type));\n  for (int i = 0; i < config_.input_size; i++) {\n    input_fp32[i] = 0;\n  }\n  from_float(input_fp32.data(), input.data(), config_.input_size, config_.hidden_type);\n  forward_many(1, input.data(), output.data(), backend);\n}\n\nvoid Linear::forward_many(int qlen, const void* input, void* output, WorkerPool* backend) {\n  const void* proj_input_ptr;\n  if (config_.hidden_type == ggml_internal_get_type_traits(config_.proj_type).vec_dot_type) {\n    proj_input_ptr = input;\n  } else {\n    to_float(input, input_fp32_, 
qlen * config_.input_size, config_.hidden_type);\n    from_float(input_fp32_, proj_input_, qlen * config_.input_size,\n               ggml_internal_get_type_traits(config_.proj_type).vec_dot_type);\n    proj_input_ptr = proj_input_;\n  }\n  int nth = config_.output_size / config_.stride;\n  backend->do_work_stealing_job(\n      nth, nullptr,\n      [&](int task_id) {\n        int ith = task_id;\n        void* proj_ptr = (uint8_t*)proj_ + ith * config_.stride * config_.input_size *\n                                               ggml_type_size(config_.proj_type) / ggml_blck_size(config_.proj_type);\n        float* proj_output_ptr = proj_output_ + ith * config_.stride;\n        llamafile_sgemm(config_.stride, qlen, config_.input_size / ggml_blck_size(config_.proj_type), proj_ptr,\n                        config_.input_size / ggml_blck_size(config_.proj_type), proj_input_ptr,\n                        config_.input_size / ggml_blck_size(config_.proj_type), proj_output_ptr, config_.output_size, 0,\n                        1, GGML_TASK_TYPE_COMPUTE, config_.proj_type,\n                        ggml_internal_get_type_traits(config_.proj_type).vec_dot_type, GGML_TYPE_F32,\n                        GGML_PREC_DEFAULT);\n        if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) {\n          for (int i = 0; i < qlen; i++) {\n            float* output_fp32_ptr = proj_output_ + i * config_.output_size + ith * config_.stride;\n            void* output_ptr =\n                (uint8_t*)output +\n                i * config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type) +\n                ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);\n            from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type);\n          }\n        }\n      },\n      nullptr);\n  if (config_.stride % ggml_blck_size(config_.hidden_type) != 0) {\n    from_float(proj_output_, output, 
qlen * config_.output_size, config_.hidden_type);\n  }\n}\n\nvoid Linear::forward(int qlen, const void* input, void* output, WorkerPool* backend) {\n  if (qlen <= 0) {\n    return;\n  }\n  int forward_len = std::min(qlen, config_.group_max_len);\n  forward_many(forward_len, input, output, backend);\n  forward(qlen - forward_len,\n          (uint8_t*)input + forward_len * config_.input_size * ggml_type_size(config_.hidden_type) /\n                                ggml_blck_size(config_.hidden_type),\n          (uint8_t*)output + forward_len * config_.output_size * ggml_type_size(config_.hidden_type) /\n                                 ggml_blck_size(config_.hidden_type),\n          backend);\n}"
  },
  {
    "path": "kt-kernel/operators/llamafile/linear.h",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-12 10:07:58\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2024-07-25 10:35:00\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#ifndef CPUINFER_OPERATOR_LINEAR_H\n#define CPUINFER_OPERATOR_LINEAR_H\n\n#include <cmath>\n#include <cstdio>\n#include <functional>\n#include <mutex>\n#include <vector>\n\n#include \"../../cpu_backend/shared_mem_buffer.h\"\n#include \"../../cpu_backend/worker_pool.h\"\n#include \"conversion.h\"\n#include \"llama.cpp/ggml-impl.h\"\n#include \"llama.cpp/ggml-quants.h\"\n#include \"llama.cpp/ggml.h\"\n#include \"llamafile/sgemm.h\"\n\nstruct LinearConfig {\n  int input_size;\n  int output_size;\n  int stride;\n  int group_max_len;\n  void* proj;\n  ggml_type proj_type;\n  ggml_type hidden_type;\n\n  LinearConfig() {}\n\n  LinearConfig(int input_size, int output_size, int stride, int group_max_len, void* proj, ggml_type proj_type,\n               ggml_type hidden_type)\n      : input_size(input_size),\n        output_size(output_size),\n        stride(stride),\n        group_max_len(group_max_len),\n        proj(proj),\n        proj_type(proj_type),\n        hidden_type(hidden_type) {}\n};\n\nclass Linear {\n public:\n  Linear(LinearConfig);\n  ~Linear();\n  void warm_up(WorkerPool* backend);\n  void forward_many(int qlen, const void* input, void* output, WorkerPool* backend);\n  void forward(int qlen, const void* input, void* output, WorkerPool* backend);\n\n private:\n  LinearConfig config_;\n  void* proj_;  // [output_size * input_size ( /32 if quantized)]\n\n  float* input_fp32_;    // [group_max_len * input_size]\n  uint8_t* proj_input_;  // [group_max_len * input_size *\n                         // ggml_type_size(ggml_internal_get_type_traits(proj_type).vec_dot_type) /\n                         // ggml_blck_size(ggml_internal_get_type_traits(proj_type).vec_dot_type)]\n  float* proj_output_;   
// [group_max_len * output_size]\n};\n\n#endif"
  },
  {
    "path": "kt-kernel/operators/llamafile/mla.hpp",
    "content": "// #ifndef LLAMAFILE_MLA_HPP\n// #define LLAMAFILE_MLA_HPP\n\n// #include \"../common.hpp\"\n// #include \"../mla-tp.hpp\"\n// #include \"../rms-norm.hpp\"\n// #include \"../rope.hpp\"\n// #include \"ggml-quants.h\"\n// #include \"ggml.h\"\n// #include \"llamafile/sgemm.h\"\n\n// #include <algorithm>\n// #include <cstddef>\n// #include <utility>\n// #include <vector>\n\n// #define DIRECT_OR_POOL_BY(what, threshold, var, fn) \\\n//   do { \\\n//     if ((what) < (threshold)) { \\\n//       for (int i = 0; i < (var); i++) { \\\n//         (fn)(i); \\\n//       } \\\n//     } else { \\\n//       pool->do_work_stealing_job((var), nullptr, (fn), nullptr); \\\n//     } \\\n//   } while (0)\n\n// #define VEC_DOT_TYPE(type) (ggml_internal_get_type_traits((ggml_type)(type)).vec_dot_type)\n// #define QUANT_BLCK_COUNT(size, type) (((size_t)(size)) / (ggml_blck_size((ggml_type)(type))))\n// #define QUANT_BLCK_SIZE(size, type) (QUANT_BLCK_COUNT(size, type) * (ggml_type_size((ggml_type)(type))))\n// #define QUANT_OFFSET(ptr, type, n, n_elements) \\\n//   (offset_pointer((ptr), (size_t)(n) * QUANT_BLCK_SIZE((n_elements), (type))))\n\n// #define LLAMAFILE_SGEMM_QUANT_FULL_MATMUL(m, n, k, a, a_type, b, b_col, c, c_col) \\\n//   do { \\\n//     llamafile_sgemm((m), (n), QUANT_BLCK_COUNT((k), (a_type)), (a), QUANT_BLCK_COUNT((k), (a_type)), \\\n//                     QUANT_OFFSET((b), VEC_DOT_TYPE((a_type)), (b_col), (k)), \\\n//                     QUANT_BLCK_COUNT((k), VEC_DOT_TYPE((a_type))), offset_pointer((c), (c_col) * (m) *\n//                     sizeof(float)), \\\n//                     (k), 0, 1, GGML_TASK_TYPE_COMPUTE, (a_type), VEC_DOT_TYPE((a_type)), GGML_TYPE_F32, \\\n//                     GGML_PREC_DEFAULT); \\\n//   } while (0)\n\n// #define LLAMAFILE_SGEMM_MATMUL_F32(m, n, k, a, lda, b, ldb, c, ldc) \\\n//   do { \\\n//     llamafile_sgemm((m), (n), (k), (a), (lda), (b), (ldb), (c), (ldc), 0, 1, GGML_TASK_TYPE_COMPUTE, GGML_TYPE_F32, \\\n//          
           GGML_TYPE_F32, GGML_TYPE_F32, GGML_PREC_DEFAULT); \\\n//   } while (0)\n\n// // bool decide_absorb(size_t a,int a_type,size_t b,int b_type,size_t c,int c_type,size_t d,int d_type){\n// //   size_t flops1 = ;\n\n// // }\n\n// inline void transpose(void *start, size_t dim0, size_t stride, size_t dim1) {\n//   // static_assert(false, \"TODO\");\n// }\n\n// template <RMS_NORM T_RMSNorm = RMSNorm, ROPE_APPLIER T_RopeApplier = Rope, ROPE_ANGLE T_RopeAngle = Yarn>\n// class LLAMA_MLA_TP {\n// private:\n//   GeneralMLAConfig config;\n//   int tp_part_idx;\n//   std::vector<void *> nope_pages;     // [page_count * page_token_count * nope]\n//   std::vector<void *> rope_pages;     // [page_count * page_token_count * nope]\n\n//   // weights\n//   void *local_q_a_proj;               // [hidden_size * q_lora_rank]\n//   void *local_q_a_norm;               // [q_lora_rank]\n//   std::vector<void *> local_q_b_proj; // [num_heads * (nope_size + rope_size))]\n//   void *local_kv_a_proj_with_mqa;     // [hidden_size * (kv_lora_rank + rope)]\n//   void *local_kv_a_norm_with_mqa;\n//   void *local_kv_b_proj;                   // [(num_heads * (nope_size + nope_size) * kv_lora_rank)],\n//                                            // q_absorb:   [num_heads * nope_size * kv_lora_rank]\n//                                            // out_absorb: [num_heads * nope_size * kv_lora_rank]\n//   std::vector<void *> local_k_b_proj_nope; // [(num_heads * kv_lora_rank * nope)],\n//   void *local_w_o; // [(num_heads * nope_size) * hidden_size]\n//   T_RopeAngle rope_angle;\n\n//   // intermediate\n\n//   void *quant_input;           // [qlen, hidden size(Q)]\n//   void *q_a_proj_output;       // [qlen, q_lora_rank]\n//   void *quant_q_a_proj_output; // [qlen, q_lora_rank(Q)]\n\n//   // for each query\n//   std::vector<void *> q_pe;              // [num_heads * max_qlen * rope_size]\n//   std::vector<void *> k_pe;              // [num_threads * rope_size]\n//   std::vector<void *> 
q_nope;            // [num_heads * max_qlen * nope_size]\n//   std::vector<void *> attention_weights; // [num_heads * max_qlen * max_klen];\n//   std::vector<void *> q_absorb;          // [num_heads, max_qlen, kv_lora_rank], or\n//                                          // [num_heads, kv_lora_rank, max_qlen]\n//   std::vector<void *> o_absorb;          // [num_heads, max_qlen, kv_lora_rank], or\n//                                          // [num_heads, kv_lora_rank, max_qlen]\n//   std::vector<void *> compressed_kv_tmp; // [num_threads * token_count_in_page * kv_lora_rank]\n//   std::vector<void *> quant_o_absorb;    // [num_heads, max_qlen, kv_lora_rank], or\n//                                          // [num_heads, kv_lora_rank, max_qlen]\n//   std::vector<void *> attention_output;  // [num_threads * max_qlen * nope]\n//   std::vector<void *> quant_attention_output; // [num_threads * max_qlen * nope]\n\n// public:\n//   using output_t = float;\n\n//   LLAMA_MLA_TP(GeneralMLAConfig config, int tp_part_idx) : config(config), tp_part_idx(tp_part_idx) {\n//     std::vector<std::pair<void **, uint64_t>> s_mem_requests;\n//   }\n\n//   void set_pages(std::vector<void *> cache_pages) { this->nope_pages = cache_pages; }\n//   void set_pages(std::vector<void *> cache_pages, std::vector<void *> pe_pages) {\n//     this->nope_pages = cache_pages;\n//     this->rope_pages = pe_pages;\n//   }\n\n//   void forward(std::vector<int> qlens, std::vector<std::vector<int>> page_tables, std::vector<int> kv_lens,\n//                const void *input, void *output) {}\n\n//   void forward_prefill(std::vector<int> qlens, std::vector<std::vector<int>> page_tables, std::vector<int> kvlens,\n//                        const void *input_raw, void *output) {\n//     auto pool = config.pool->get_subpool(tp_part_idx);\n\n//     float *input = (float *)input_raw;\n//     std::vector<int> qlen_split, total_len_split;\n//     qlen_split.reserve(qlens.size() + 1);\n//     qlen_split.push_back(0);\n//     total_len_split.reserve(qlens.size() + 1);\n//     int qlen_sum = 0;\n//     int total_len_sum = 0;\n//     for (size_t i = 0; i < 
qlens.size(); i++) {\n//       qlen_sum += qlens[i];\n//       qlen_split.push_back(qlen_sum);\n\n//       total_len_sum += qlens[i] + kvlens[i];\n//       total_len_split.push_back(total_len_sum);\n//     }\n\n//     auto which_query_by_qlen_sum = [&](int token_nth) -> std::pair<size_t, size_t> {\n//       auto query_idx = std::upper_bound(qlen_split.begin(), qlen_split.end(), token_nth) - qlen_split.begin() - 1;\n//       auto token_nth_from_start = token_nth - qlen_split.at(query_idx) + kvlens.at(query_idx);\n//       return {query_idx, token_nth_from_start};\n//     };\n//     auto which_query_by_total_sum = [&](int token_nth) -> std::pair<size_t, size_t> {\n//       auto query_idx =\n//           std::upper_bound(total_len_split.begin(), total_len_split.end(), token_nth) - total_len_split.begin() - 1;\n//       auto token_nth_from_start = token_nth - total_len_split.at(query_idx);\n//       return {query_idx, token_nth_from_start};\n//     };\n\n//     auto which_page = [&](int query, int token_nth_from_start) -> std::pair<size_t, size_t> {\n//       size_t page_idx = page_tables.at(query).at(div_up((size_t)token_nth_from_start, config.token_count_in_page));\n\n//       size_t token_at_in_page = token_nth_from_start % config.token_count_in_page;\n//       return {page_idx, token_at_in_page};\n//     };\n\n//     ggml_type vec_dot_type = ggml_internal_get_type_traits((ggml_type)config.q_a_proj_type).vec_dot_type;\n//     size_t hidden_size_float_bytes = config.hidden_size * sizeof(float);\n//     size_t hidden_size_quant_blck_count = config.hidden_size / ggml_blck_size(vec_dot_type);\n//     size_t hidden_size_quant_bytes = hidden_size_quant_blck_count * ggml_type_size(vec_dot_type);\n//     // quant to q8 0\n\n//     DIRECT_OR_POOL_BY(qlen_sum, 10, qlen_sum, [&](int token_at_i) {\n//       size_t token_at = token_at_i;\n//       quantize_q8_0(offset_pointer(input, token_at * config.hidden_size * sizeof(float)),\n//                     
offset_pointer(quant_input,\n//                                    token_at * QUANT_BLCK_SIZE(config.hidden_size,\n//                                    VEC_DOT_TYPE(config.q_a_proj_type))),\n//                     1, config.hidden_size, nullptr);\n//     });\n\n//     {\n//       // q lora rank\n//       // maybe this should be up to non tp\n//       auto proj_lora_a = [&](int task_id) {\n//         size_t token_at = task_id % qlen_sum;\n//         bool do_q_or_kv = (task_id / qlen_sum) == 0;\n//         if (do_q_or_kv) {\n//           auto this_q_a_proj_output =\n//               (float *)offset_pointer(q_a_proj_output, token_at * config.hidden_size * sizeof(float));\n\n//           LLAMAFILE_SGEMM_QUANT_FULL_MATMUL(config.q_lora_rank, 1, config.hidden_size, local_q_a_proj,\n//                                             config.q_a_proj_type, quant_input, token_at, this_q_a_proj_output, 0);\n\n//           T_RMSNorm::rms_norm_single(config.q_lora_rank, (float *)local_q_a_norm, this_q_a_proj_output);\n\n//           quantize_q8_0(\n//               this_q_a_proj_output,\n//               offset_pointer(quant_q_a_proj_output,\n//                              token_at * QUANT_BLCK_SIZE(config.q_lora_rank, VEC_DOT_TYPE(config.q_b_proj_type))),\n//               1, config.q_lora_rank, nullptr);\n\n//         } else {\n//           auto [query, token_from_start] = which_query_by_qlen_sum(token_at);\n//           auto [page_idx, token_at_in_page] = which_page(query, token_from_start);\n//           LLAMAFILE_SGEMM_QUANT_FULL_MATMUL(config.kv_lora_rank, 1, config.hidden_size, local_kv_a_proj_with_mqa,\n//                                             config.kv_a_proj_with_mqa_type, quant_input, token_at,\n//                                             rope_pages.at(page_idx), token_at_in_page);\n//           T_RMSNorm::rms_norm_single(\n//               config.kv_lora_rank, (float *)local_kv_a_norm_with_mqa,\n//               (float 
*)offset_pointer(rope_pages.at(page_idx), token_at_in_page * config.kv_lora_rank *\n//               sizeof(float)));\n//           LLAMAFILE_SGEMM_QUANT_FULL_MATMUL(config.rope_size, 1, config.hidden_size,\n//                                             QUANT_OFFSET(local_kv_a_proj_with_mqa, config.kv_a_proj_with_mqa_type,\n//                                                          config.kv_lora_rank, config.hidden_size),\n//                                             config.kv_a_proj_with_mqa_type, quant_input, token_at,\n//                                             nope_pages.at(page_idx), token_at_in_page);\n//         }\n//       };\n//       DIRECT_OR_POOL_BY(qlen_sum, 10, qlen_sum * 2, proj_lora_a);\n//     }\n\n//     {\n//       int task_count = config.num_heads * 2 * qlen_sum; // head, rope/nope, qlen\n//       auto q_proj_lora_b = [&](int task_id) {\n//         size_t head_idx = task_id / (2 * qlen_sum);\n//         task_id %= (2 * qlen_sum);\n//         bool nope_or_rope = (task_id / qlen_sum) == 0;\n//         task_id %= qlen_sum;\n//         size_t token_at = task_id;\n//         auto [query, token_from_start] = which_query_by_qlen_sum(token_at);\n//         auto [page_idx, token_at_in_page] = which_page(query, token_from_start);\n\n//         if (nope_or_rope) {\n//           LLAMAFILE_SGEMM_QUANT_FULL_MATMUL(\n//               config.nope_size, 1, config.q_lora_rank,\n//               QUANT_OFFSET(local_q_b_proj.at(head_idx), config.q_b_proj_type,\n//                            head_idx * (config.nope_size + config.rope_size), config.q_lora_rank),\n//               config.q_b_proj_type, quant_q_a_proj_output, token_at, q_nope.at(head_idx), token_at);\n//         } else {\n//           LLAMAFILE_SGEMM_QUANT_FULL_MATMUL(\n//               config.rope_size, 1, config.q_lora_rank,\n//               QUANT_OFFSET(local_q_b_proj.at(head_idx), config.q_b_proj_type,\n//                            head_idx * (config.nope_size + config.rope_size) + 
config.nope_size, config.q_lora_rank),\n//               config.q_b_proj_type, quant_q_a_proj_output, token_at, q_pe.at(head_idx), token_at);\n//           T_RopeApplier::apply_single(config.rope_size,\n//                                       offset_pointer(q_pe.at(head_idx), token_at * config.rope_size * sizeof(float)),\n//                                       rope_angle.cos(token_at), rope_angle.sin(token_at));\n//         }\n//       };\n//       pool->do_work_stealing_job(task_count, nullptr, q_proj_lora_b, nullptr);\n//     }\n\n//     for (int query = 0; query < qlens.size(); query++) {\n//       {\n//         // pe attention\n//         // apply k pe online\n//         int task_count = config.num_heads * (qlens[query] + kvlens[query]); // by kvlen\n//         auto pe_attn = [&](int task_id) {\n//           size_t head_idx = task_id / (qlens[query] + kvlens[query]);\n//           size_t token_from_start = task_id % (qlens[query] + kvlens[query]);\n\n//           // auto q_token_at = qlen_split[query] + qlens[query];\n\n//           auto [page_idx, token_at_in_page] = which_page(query, token_from_start);\n//           memcpy(k_pe[WorkerPool::thread_local_id],\n//                  offset_pointer(rope_pages.at(page_idx), token_at_in_page * config.rope_size * sizeof(float)),\n//                  sizeof(float) * config.rope_size);\n//           T_RopeApplier::apply_single(config.rope_size, k_pe[WorkerPool::thread_local_id],\n//                                       rope_angle.cos(token_from_start), rope_angle.sin(token_from_start));\n\n//           LLAMAFILE_SGEMM_MATMUL_F32(1, qlens[query], config.rope_size, k_pe[WorkerPool::thread_local_id],\n//                                      config.rope_size, q_pe.at(head_idx), config.rope_size,\n//                                      attention_weights[head_idx], config.max_kvlen);\n//         };\n//         pool->do_work_stealing_job(task_count, pe_attn);\n//       }\n//       {\n//         // clear q absorb\n//       
  pool->do_work_stealing_job(config.num_heads, [&](int task_id) {\n//           memset(q_absorb[task_id], 0, config.kv_lora_rank * config.max_qlen * sizeof(float));\n//         });\n\n//         // absorb W_uk\n//         int task_count = config.num_heads * qlens[query];\n//         auto task = [&](int task_id) {\n//           size_t head_idx = task_id / qlens[query];\n//           size_t token_at = task_id % qlens[query];\n\n//           // q_absorb now [kvrank, max_qlen]\n//           LLAMAFILE_SGEMM_MATMUL_F32(qlens[query], config.kv_lora_rank, config.nope_size, q_nope[head_idx],\n//                                      config.nope_size, local_k_b_proj_nope[head_idx], config.nope_size,\n//                                      q_absorb[head_idx], config.max_qlen);\n//           transpose(q_nope[head_idx], config.kv_lora_rank, config.max_qlen, qlens[query]);\n//         };\n//         pool->do_work_stealing_job(task_count, task);\n//       }\n\n//       {\n//         // nope attention weights\n//         size_t page_count = div_up((size_t)kvlens[query], config.token_count_in_page);\n//         int task_count = config.num_heads * page_count;\n//         auto task = [&](int task_id) {\n//           size_t head_idx = task_id / page_count;\n//           size_t page_idx = task_id % page_count;\n//           void *page_ptr = nope_pages[page_tables[query][page_idx]]; // mla no head\n\n//           size_t kvlen =\n//               page_idx == (page_count - 1) ? 
(kvlens[query] % config.token_count_in_page) :\n//               config.token_count_in_page;\n\n//           LLAMAFILE_SGEMM_MATMUL_F32(\n//               kvlen, qlens[query], config.kv_lora_rank, page_ptr, config.kv_lora_rank, q_absorb[head_idx],\n//               config.max_qlen,\n//               offset_pointer(attention_weights[head_idx], page_idx * config.token_count_in_page * sizeof(float)),\n//               config.max_kvlen);\n//           // static_assert(false, \"soft max todo\");\n//         };\n//         pool->do_work_stealing_job(task_count, task);\n//       }\n\n//       {\n//         // clear o absorb\n//         pool->do_work_stealing_job(config.num_heads, [&](int task_id) {\n//           memset(o_absorb[task_id], 0, config.kv_lora_rank * config.max_qlen * sizeof(float));\n//         });\n\n//         // o absorb\n//         size_t page_count = div_up((size_t)kvlens[query], config.token_count_in_page);\n//         int task_count = config.num_heads * page_count;\n//         auto task = [&](int task_id) {\n//           size_t head_idx = task_id / page_count;\n//           size_t page_idx = task_id % page_count;\n//           void *page_ptr = nope_pages[page_tables[query][page_idx]]; // mla no head\n//           size_t kvlen =\n//               page_idx == (page_count - 1) ? 
(kvlens[query] % config.token_count_in_page) :\n//               config.token_count_in_page;\n\n//           memcpy(compressed_kv_tmp[WorkerPool::thread_local_id], page_ptr,\n//                  config.token_count_in_page * config.kv_lora_rank * sizeof(float));\n//           transpose(compressed_kv_tmp[WorkerPool::thread_local_id], config.token_count_in_page, config.kv_lora_rank,\n//                     kvlen);\n\n//           LLAMAFILE_SGEMM_MATMUL_F32(\n//               config.kv_lora_rank, qlens[query], kvlen, compressed_kv_tmp[WorkerPool::thread_local_id],\n//               config.token_count_in_page,\n//               offset_pointer(attention_weights[head_idx], page_idx * config.token_count_in_page * sizeof(float)),\n//               config.max_kvlen, o_absorb[head_idx], config.kv_lora_rank);\n//         };\n//         pool->do_work_stealing_job(task_count, task);\n//       }\n\n//       {\n\n//         // clear\n//         pool->do_work_stealing_job(config.num_heads, [&](int task_id) {\n//           memset(attention_output[task_id], 0, config.nope_size * config.max_qlen * sizeof(float));\n//         });\n\n//         // attention output\n//         int task_count = config.num_heads * qlens[query];\n//         auto task = [&](int task_id) {\n//           size_t head_idx = task_id / qlens[query];\n//           size_t token_at = task_id % qlens[query];\n\n//           quantize_q8_0((float *)offset_pointer(o_absorb[head_idx], config.kv_lora_rank * token_at * sizeof(float)),\n//                         offset_pointer(quant_o_absorb[head_idx],\n//                                        QUANT_BLCK_SIZE(config.kv_lora_rank, VEC_DOT_TYPE(config.kv_b_proj_type))),\n//                         1, config.kv_lora_rank, nullptr);\n\n//           auto kv_b_proj_ptr =\n//               offset_pointer(local_kv_b_proj, ((head_idx * 2 + 1) * config.nope_size) *\n//                                                   QUANT_BLCK_SIZE(config.kv_lora_rank, 
config.kv_b_proj_type));\n\n//           LLAMAFILE_SGEMM_QUANT_FULL_MATMUL(config.nope_size, 1, config.kv_lora_rank, kv_b_proj_ptr,\n//                                             config.kv_b_proj_type, quant_o_absorb[head_idx], token_at,\n//                                             attention_output[head_idx], token_at);\n//         };\n//         pool->do_work_stealing_job(task_count, task);\n//       }\n\n//       {\n//         // quant attention output\n//         // static_assert(false,\"TODO\" );\n//       }\n\n//       {\n//         // get final output\n//         // static_assert(false,\"TODO\" );\n//       }\n//     }\n//   }\n\n//   void load_weights(int complete_num_heads, int offset) {}\n// };\n// template <typename Norm, typename Rope, typename RopeAngle>\n// class TP_MLA<LLAMA_MLA_TP<Norm, Rope, RopeAngle>> : public TP_MLA_Common<LLAMA_MLA_TP<Norm, Rope, RopeAngle>> {\n// public:\n//   using TP_MLA_Common<LLAMA_MLA_TP<Norm, Rope, RopeAngle>>::TP_MLA_Common;\n\n//   void load_weights() {\n//     auto pool = this->config.pool;\n//     auto tp_num_heads = this->config.num_heads / this->tp_count;\n//     pool->dispense_backend()->do_numa_job([this, pool, tp_num_heads](int tp_id) {\n//       this->tps[tp_id]->load_weights(this->config.num_heads, tp_id * tp_num_heads);\n//     });\n//     this->weights_loaded = true;\n//   }\n\n//   void merge_results(int qlen, void *output) {}\n// };\n\n// #endif\n"
  },
  {
    "path": "kt-kernel/operators/llamafile/mlp.cpp",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-16 10:43:18\n * @Version      : 1.0.0\n * @LastEditors  : kkk1nak0\n * @LastEditTime : 2024-08-15 07:44:38\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#include \"mlp.h\"\n\n#include <algorithm>\n\nMLP::MLP(MLPConfig config) {\n  config_ = config;\n  gate_proj_ = config_.gate_proj;\n  up_proj_ = config_.up_proj;\n  down_proj_ = config_.down_proj;\n\n  MemoryRequest mem_requests;\n  mem_requests.append_pointer(&input_fp32_, sizeof(float) * config_.group_max_len * config_.hidden_size);\n  mem_requests.append_pointer(&gate_input_,\n                              config_.group_max_len * config_.hidden_size *\n                                  ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) /\n                                  ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type));\n  mem_requests.append_pointer(&up_input_,\n                              config_.group_max_len * config_.hidden_size *\n                                  ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) /\n                                  ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type));\n  mem_requests.append_pointer(&gate_output_, sizeof(float) * config_.group_max_len * config_.intermediate_size);\n  mem_requests.append_pointer(&up_output_, sizeof(float) * config_.group_max_len * config_.intermediate_size);\n  mem_requests.append_pointer(&intermediate_fp32_, sizeof(float) * config_.group_max_len * config_.intermediate_size);\n  mem_requests.append_pointer(&down_input_,\n                              config_.group_max_len * config_.intermediate_size *\n                                  ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) /\n                                  ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type));\n  
mem_requests.append_pointer(&down_output_, sizeof(float) * config_.group_max_len * config_.hidden_size);\n  shared_mem_buffer.alloc(this, mem_requests);\n}\n\nMLP::~MLP() {}\n\nvoid MLP::warm_up(WorkerPool* backend) {\n  std::vector<float> input_fp32(config_.hidden_size);\n  std::vector<uint8_t> input(config_.hidden_size * ggml_type_size(config_.hidden_type) /\n                             ggml_blck_size(config_.hidden_type));\n  std::vector<uint8_t> output(config_.hidden_size * ggml_type_size(config_.hidden_type) /\n                              ggml_blck_size(config_.hidden_type));\n  for (int i = 0; i < config_.hidden_size; i++) {\n    input_fp32[i] = 0;\n  }\n  from_float(input_fp32.data(), input.data(), config_.hidden_size, config_.hidden_type);\n  forward_many(1, input.data(), output.data(), backend);\n}\n\nstatic float act_fn(float x) { return x / (1.0f + expf(-x)); }\n\nvoid MLP::forward_many(int qlen, const void* input, void* output, WorkerPool* backend) {\n  const void* gate_input_ptr;\n  const void* up_input_ptr;\n  if (config_.hidden_type == ggml_internal_get_type_traits(config_.gate_type).vec_dot_type &&\n      config_.hidden_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {\n    gate_input_ptr = up_input_ptr = input;\n  } else {\n    to_float(input, input_fp32_, qlen * config_.hidden_size, config_.hidden_type);\n    if (ggml_internal_get_type_traits(config_.gate_type).vec_dot_type ==\n        ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {\n      from_float(input_fp32_, gate_input_, qlen * config_.hidden_size,\n                 ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);\n      gate_input_ptr = up_input_ptr = gate_input_;\n    } else {\n      if (config_.hidden_type != ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) {\n        from_float(input_fp32_, gate_input_, qlen * config_.hidden_size,\n                   ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);\n        
gate_input_ptr = gate_input_;\n      } else {\n        gate_input_ptr = input;\n      }\n      if (config_.hidden_type != ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {\n        from_float(input_fp32_, up_input_, qlen * config_.hidden_size,\n                   ggml_internal_get_type_traits(config_.up_type).vec_dot_type);\n        up_input_ptr = up_input_;\n      } else {\n        up_input_ptr = input;\n      }\n    }\n  }\n  int nth = config_.intermediate_size / config_.stride;\n  backend->do_work_stealing_job(\n      nth, nullptr,\n      [&](int task_id) {\n        int ith = task_id;\n        void* gate_proj_ptr = (uint8_t*)gate_proj_ + ith * config_.stride * config_.hidden_size *\n                                                         ggml_type_size(config_.gate_type) /\n                                                         ggml_blck_size(config_.gate_type);\n        float* gate_output_ptr = gate_output_ + ith * config_.stride;\n        llamafile_sgemm(config_.stride, qlen, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr,\n                        config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr,\n                        config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr,\n                        config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type,\n                        ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32,\n                        GGML_PREC_DEFAULT);\n        void* up_proj_ptr = (uint8_t*)up_proj_ + ith * config_.stride * config_.hidden_size *\n                                                     ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);\n        float* up_output_ptr = up_output_ + ith * config_.stride;\n        llamafile_sgemm(config_.stride, qlen, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr,\n                        config_.hidden_size / 
ggml_blck_size(config_.up_type), up_input_ptr,\n                        config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.intermediate_size,\n                        0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type,\n                        ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n        for (int i = 0; i < qlen; i++) {\n          for (int j = ith * config_.stride; j < (ith + 1) * config_.stride; j++) {\n            intermediate_fp32_[i * config_.intermediate_size + j] =\n                act_fn(gate_output_[i * config_.intermediate_size + j]) * up_output_[i * config_.intermediate_size + j];\n          }\n          if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) == 0) {\n            float* intermediate_fp32_ptr = intermediate_fp32_ + i * config_.intermediate_size + ith * config_.stride;\n            void* down_input_ptr = (uint8_t*)down_input_ +\n                                   i * config_.intermediate_size *\n                                       ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) /\n                                       ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) +\n                                   ith * config_.stride *\n                                       ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) /\n                                       ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);\n            from_float(intermediate_fp32_ptr, down_input_ptr, config_.stride,\n                       ggml_internal_get_type_traits(config_.down_type).vec_dot_type);\n          }\n        }\n      },\n      nullptr);\n  if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) != 0) {\n    from_float(intermediate_fp32_, down_input_, qlen * 
config_.intermediate_size,\n               ggml_internal_get_type_traits(config_.down_type).vec_dot_type);\n  }\n  nth = config_.hidden_size / config_.stride;\n  backend->do_work_stealing_job(\n      nth, nullptr,\n      [&](int task_id) {\n        int ith = task_id;\n        void* down_proj_ptr = (uint8_t*)down_proj_ + ith * config_.stride * config_.intermediate_size *\n                                                         ggml_type_size(config_.down_type) /\n                                                         ggml_blck_size(config_.down_type);\n        float* down_output_ptr = down_output_ + ith * config_.stride;\n        llamafile_sgemm(config_.stride, qlen, config_.intermediate_size / ggml_blck_size(config_.down_type),\n                        down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_,\n                        config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr,\n                        config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type,\n                        ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32,\n                        GGML_PREC_DEFAULT);\n        if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) {\n          for (int i = 0; i < qlen; i++) {\n            float* output_fp32_ptr = down_output_ + i * config_.hidden_size + ith * config_.stride;\n            void* output_ptr =\n                (uint8_t*)output +\n                i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type) +\n                ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);\n            from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type);\n          }\n        }\n      },\n      nullptr);\n  if (config_.stride % ggml_blck_size(config_.hidden_type) != 0) {\n    from_float(down_output_, output, qlen * 
config_.hidden_size, config_.hidden_type);\n  }\n}\n\nvoid MLP::forward(int qlen, const void* input, void* output, WorkerPool* backend) {\n  if (qlen <= 0) {\n    return;\n  }\n  int forward_len = std::min(qlen, config_.group_max_len);\n  forward_many(forward_len, input, output, backend);\n  forward(qlen - forward_len,\n          (uint8_t*)input + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) /\n                                ggml_blck_size(config_.hidden_type),\n          (uint8_t*)output + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) /\n                                 ggml_blck_size(config_.hidden_type),\n          backend);\n}"
  },
  {
    "path": "kt-kernel/operators/llamafile/mlp.h",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-12 10:07:58\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2024-07-25 10:35:06\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#ifndef CPUINFER_OPERATOR_MLP_H\n#define CPUINFER_OPERATOR_MLP_H\n\n#include <cmath>\n#include <cstdio>\n#include <functional>\n#include <mutex>\n#include <vector>\n\n#include \"../../cpu_backend/shared_mem_buffer.h\"\n#include \"../../cpu_backend/worker_pool.h\"\n#include \"conversion.h\"\n#include \"llama.cpp/ggml-impl.h\"\n#include \"llama.cpp/ggml-quants.h\"\n#include \"llama.cpp/ggml.h\"\n#include \"llamafile/sgemm.h\"\n\nstruct MLPConfig {\n  int hidden_size;\n  int intermediate_size;\n  int stride;\n  int group_max_len;\n  void* gate_proj;\n  void* up_proj;\n  void* down_proj;\n  ggml_type gate_type;\n  ggml_type up_type;\n  ggml_type down_type;\n  ggml_type hidden_type;\n\n  MLPConfig() {}\n\n  MLPConfig(int hidden_size, int intermediate_size, int stride, int group_max_len, void* gate_proj, void* up_proj,\n            void* down_proj, ggml_type gate_type, ggml_type up_type, ggml_type down_type, ggml_type hidden_type)\n      : hidden_size(hidden_size),\n        intermediate_size(intermediate_size),\n        stride(stride),\n        group_max_len(group_max_len),\n        gate_proj(gate_proj),\n        up_proj(up_proj),\n        down_proj(down_proj),\n        gate_type(gate_type),\n        up_type(up_type),\n        down_type(down_type),\n        hidden_type(hidden_type) {}\n};\n\nclass MLP {\n public:\n  MLP(MLPConfig);\n  ~MLP();\n  void warm_up(WorkerPool* backend);\n  void forward_many(int qlen, const void* input, void* output, WorkerPool* backend);\n  void forward(int qlen, const void* input, void* output, WorkerPool* backend);\n\n private:\n  MLPConfig config_;\n  void* gate_proj_;  // [intermediate_size * hidden_size ( /32 if quantized)]\n  void* up_proj_;    // [intermediate_size * 
hidden_size ( /32 if quantized)]\n  void* down_proj_;  // [hidden_size * intermediate_size ( /32 if quantized)]\n\n  float* input_fp32_;    // [group_max_len * hidden_size]\n  uint8_t* gate_input_;  // [group_max_len * hidden_size *\n                         // ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) /\n                         // ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)]\n  uint8_t*\n      up_input_;  // [group_max_len * hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type)\n                  // / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)]\n  float* gate_output_;        // [group_max_len * intermediate_size]\n  float* up_output_;          // [group_max_len * intermediate_size]\n  float* intermediate_fp32_;  // [group_max_len * intermediate_size]\n  uint8_t* down_input_;       // [group_max_len * intermediate_size *\n                              // ggml_type_size(ggml_internal_get_type_traits(down_type).vec_dot_type) /\n                              // ggml_blck_size(ggml_internal_get_type_traits(down_type).vec_dot_type)]\n  float* down_output_;        // [group_max_len * hidden_size]\n};\n\n#endif"
  },
  {
    "path": "kt-kernel/operators/llamafile/moe.hpp",
    "content": "#ifndef LLAMAFILE_MOE_HPP\n#define LLAMAFILE_MOE_HPP\n#ifdef FORWARD_TIME_PROFILE\n#include <fmt/format.h>\n#endif\n#include <numa.h>\n#include <numaif.h>\n\n#include <algorithm>\n#include <cassert>\n#include <cmath>\n#include <cstdint>\n#include <cstdio>\n#include <functional>\n#include <vector>\n\n#include \"../../cpu_backend/shared_mem_buffer.h\"\n#include \"../../cpu_backend/worker_pool.h\"\n#include \"../moe-tp.hpp\"\n#include \"conversion.h\"\n#include \"llama.cpp/ggml-quants.h\"\n#include \"llama.cpp/ggml.h\"\n#include \"llamafile/sgemm.h\"\n\ninline void debug_quant(void* input, ggml_type type) {\n  std::vector<float> output(ggml_blck_size(type));\n  to_float(input, output.data(), ggml_blck_size(type), type);\n  for (size_t i = 0; i < 10; i++) {\n    printf(\"%f \", output[i]);\n  }\n  printf(\"\\n\");\n}\n\nclass LLAMA_MOE_TP {\n private:\n  GeneralMOEConfig config_;\n  int tp_part_idx;\n\n  uint8_t* m_local_gate_proj_;  // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]\n  uint8_t* m_local_up_proj_;    // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]\n  uint8_t* m_local_down_proj_;  // [expert_num * hidden_size * intermediate_size ( /32 if quantized)]\n\n  float* s_input_fp32_;    // [hidden_size]\n  uint8_t* s_gate_input_;  // [hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) /\n                           // ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)]\n  uint8_t* s_up_input_;    // [hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) /\n                           // ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)]\n  std::vector<float*> s_gate_output_;        // [routed_expert_num, intermediate_size]\n  std::vector<float*> s_up_output_;          // [routed_expert_num, intermediate_size]\n  std::vector<float*> s_intermediate_fp32_;  // [routed_expert_num, intermediate_size]\n  
std::vector<uint8_t*> s_down_input_;       // [routed_expert_num, intermediate_size *\n                                             // ggml_type_size(ggml_internal_get_type_traits(down_type).vec_dot_type) /\n                                             // ggml_blck_size(ggml_internal_get_type_traits(down_type).vec_dot_type)]\n  std::vector<float*> s_down_output_;        // [routed_expert_num, hidden_size]\n  float* s_output_fp32_;                     // [hidden_size]\n\n  std::vector<float*> m_input_fp32_;    // [group_max_len, hidden_size]\n  std::vector<uint8_t*> m_gate_input_;  // [group_max_len, hidden_size *\n                                        // ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) /\n                                        // ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)]\n  std::vector<uint8_t*>\n      m_up_input_;  // [group_max_len, hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type)\n                    // / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)]\n  uint8_t* m_local_gate_input_;        // [routed_expert_num * group_max_len * hidden_size *\n                                       // ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) /\n                                       // ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)]\n  uint8_t* m_local_up_input_;          // [routed_expert_num * group_max_len * hidden_size *\n                                       // ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) /\n                                       // ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)]\n  float* m_local_gate_output_;         // [routed_expert_num * group_max_len * intermediate_size]\n  float* m_local_up_output_;           // [routed_expert_num * group_max_len * intermediate_size]\n  float* m_local_intermediate_fp32_;   // [routed_expert_num * 
group_max_len * intermediate_size]\n  uint8_t* m_local_down_input_;        // [routed_expert_num * group_max_len * intermediate_size *\n                                       // ggml_type_size(ggml_internal_get_type_traits(down_type).vec_dot_type) /\n                                       // ggml_blck_size(ggml_internal_get_type_traits(down_type).vec_dot_type)]\n  float* m_local_down_output_;         // [routed_expert_num * group_max_len * hidden_size]\n  std::vector<float*> m_output_fp32_;  // [group_max_len, hidden_size]\n\n  std::vector<std::vector<int>> m_local_pos_;          // [group_max_len, routed_expert_num]\n  std::vector<int> m_local_num_;                       // [expert_num]\n  std::vector<int> m_expert_id_map_;                   // [expert_num]\n  std::vector<uint8_t*> m_local_gate_input_ptr_;       // [expert_num]\n  std::vector<uint8_t*> m_local_up_input_ptr_;         // [expert_num]\n  std::vector<float*> m_local_gate_output_ptr_;        // [expert_num]\n  std::vector<float*> m_local_up_output_ptr_;          // [expert_num]\n  std::vector<float*> m_local_intermediate_fp32_ptr_;  // [expert_num]\n  std::vector<uint8_t*> m_local_down_input_ptr_;       // [expert_num]\n  std::vector<float*> m_local_down_output_ptr_;        // [expert_num]\n public:\n  using input_t = ggml_bf16_t;\n  using output_t = float;\n\n  LLAMA_MOE_TP(GeneralMOEConfig config, int tp_part_idx) : config_(config), tp_part_idx(tp_part_idx) {\n    MemoryRequest mem_requests;\n    mem_requests.append_pointer(&s_input_fp32_, sizeof(float) * config_.hidden_size);\n    mem_requests.append_pointer(\n        &s_gate_input_, config_.hidden_size *\n                            ggml_type_size(ggml_internal_get_type_traits((ggml_type)config_.gate_type).vec_dot_type) /\n                            ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.gate_type).vec_dot_type));\n    mem_requests.append_pointer(\n        &s_up_input_, config_.hidden_size *\n                          
ggml_type_size(ggml_internal_get_type_traits((ggml_type)config_.up_type).vec_dot_type) /\n                          ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.up_type).vec_dot_type));\n    s_gate_output_.resize(config_.num_experts_per_tok);\n    s_up_output_.resize(config_.num_experts_per_tok);\n    s_intermediate_fp32_.resize(config_.num_experts_per_tok);\n    s_down_input_.resize(config_.num_experts_per_tok);\n    s_down_output_.resize(config_.num_experts_per_tok);\n    for (int i = 0; i < config_.num_experts_per_tok; i++) {\n      mem_requests.append_pointer(&s_gate_output_[i], sizeof(float) * config_.intermediate_size);\n      mem_requests.append_pointer(&s_up_output_[i], sizeof(float) * config_.intermediate_size);\n      mem_requests.append_pointer(&s_intermediate_fp32_[i], sizeof(float) * config_.intermediate_size);\n      mem_requests.append_pointer(\n          &s_down_input_[i],\n          config_.intermediate_size *\n              ggml_type_size(ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type) /\n              ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type));\n      mem_requests.append_pointer(&s_down_output_[i], sizeof(float) * config_.hidden_size);\n    }\n    mem_requests.append_pointer(&s_output_fp32_, sizeof(float) * config_.hidden_size);\n    shared_mem_buffer_numa.alloc(tp_part_idx, this, mem_requests);\n    // shared_mem_buffer.alloc(this, mem_requests);\n\n    m_input_fp32_.resize(config_.group_max_len);\n    m_gate_input_.resize(config_.group_max_len);\n    m_up_input_.resize(config_.group_max_len);\n    for (int i = 0; i < config_.group_max_len; i++) {\n      mem_requests.append_pointer(&m_input_fp32_[i], sizeof(float) * config_.hidden_size);\n      mem_requests.append_pointer(\n          &m_gate_input_[i],\n          config_.hidden_size *\n              ggml_type_size(ggml_internal_get_type_traits((ggml_type)config_.gate_type).vec_dot_type) /\n              
ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.gate_type).vec_dot_type));\n      mem_requests.append_pointer(\n          &m_up_input_[i], config_.hidden_size *\n                               ggml_type_size(ggml_internal_get_type_traits((ggml_type)config_.up_type).vec_dot_type) /\n                               ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.up_type).vec_dot_type));\n    }\n    mem_requests.append_pointer(\n        &m_local_gate_input_,\n        config_.num_experts_per_tok * config_.group_max_len * config_.hidden_size *\n            ggml_type_size(ggml_internal_get_type_traits((ggml_type)config_.gate_type).vec_dot_type) /\n            ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.gate_type).vec_dot_type));\n    mem_requests.append_pointer(\n        &m_local_up_input_, config_.num_experts_per_tok * config_.group_max_len * config_.hidden_size *\n                                ggml_type_size(ggml_internal_get_type_traits((ggml_type)config_.up_type).vec_dot_type) /\n                                ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.up_type).vec_dot_type));\n    mem_requests.append_pointer(&m_local_gate_output_, sizeof(float) * config_.num_experts_per_tok *\n                                                           config_.group_max_len * config_.intermediate_size);\n    mem_requests.append_pointer(&m_local_up_output_, sizeof(float) * config_.num_experts_per_tok *\n                                                         config_.group_max_len * config_.intermediate_size);\n    mem_requests.append_pointer(&m_local_intermediate_fp32_, sizeof(float) * config_.num_experts_per_tok *\n                                                                 config_.group_max_len * config_.intermediate_size);\n    mem_requests.append_pointer(\n        &m_local_down_input_,\n        config_.num_experts_per_tok * config_.group_max_len * config_.intermediate_size *\n            
ggml_type_size(ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type) /\n            ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type));\n    mem_requests.append_pointer(&m_local_down_output_, sizeof(float) * config_.num_experts_per_tok *\n                                                           config_.group_max_len * config_.hidden_size);\n    m_output_fp32_.resize(config_.group_max_len);\n    for (int i = 0; i < config_.group_max_len; i++) {\n      mem_requests.append_pointer(&m_output_fp32_[i], sizeof(float) * config_.hidden_size);\n    }\n    shared_mem_buffer_numa.alloc(tp_part_idx, this, mem_requests);\n    // shared_mem_buffer.alloc(this, m_mem_requests);\n\n    m_local_pos_.resize(config_.group_max_len);\n    for (int i = 0; i < config_.group_max_len; i++) {\n      m_local_pos_[i].resize(config_.num_experts_per_tok);\n    }\n    m_expert_id_map_.resize(config_.expert_num);\n    m_local_num_.resize(config_.expert_num);\n    m_local_gate_input_ptr_.resize(config_.expert_num);\n    m_local_up_input_ptr_.resize(config_.expert_num);\n    m_local_gate_output_ptr_.resize(config_.expert_num);\n    m_local_up_output_ptr_.resize(config_.expert_num);\n    m_local_intermediate_fp32_ptr_.resize(config_.expert_num);\n    m_local_down_input_ptr_.resize(config_.expert_num);\n    m_local_down_output_ptr_.resize(config_.expert_num);\n\n    auto size = 1ll * config.expert_num * config.intermediate_size * config.hidden_size;\n    m_local_up_proj_ =\n        new uint8_t[size * ggml_type_size((ggml_type)config.up_type) / ggml_blck_size((ggml_type)config.up_type)];\n\n    m_local_gate_proj_ =\n        new uint8_t[size * ggml_type_size((ggml_type)config.gate_type) / ggml_blck_size((ggml_type)config.gate_type)];\n    m_local_down_proj_ =\n        new uint8_t[size * ggml_type_size((ggml_type)config.down_type) / ggml_blck_size((ggml_type)config.down_type)];\n  }\n\n  void load_weights(int complete_intermediate_size, int 
offset) {\n    auto local_gate_proj = m_local_gate_proj_;\n    auto local_up_proj = m_local_up_proj_;\n    auto local_down_proj = m_local_down_proj_;\n    auto& config = config_;\n    // printf(\"gate load weights:\");\n    // debug_quant(config.gate_proj, (ggml_type)config.gate_type);\n    // Each slice must cover a whole number of quantization blocks for its type.\n    if (config.intermediate_size % ggml_blck_size((ggml_type)config.down_type) != 0) {\n      printf(\"intermediate_size: %d, down_type blck size: %d\\n\", config.intermediate_size,\n             ggml_blck_size((ggml_type)config.down_type));\n      throw std::runtime_error(\"intermediate_size must be a multiple of down_type blck size\");\n    }\n    if (config.intermediate_size * config.hidden_size % ggml_blck_size((ggml_type)config.up_type) != 0) {\n      printf(\"intermediate_size: %d, up_type blck size: %d\\n\", config.intermediate_size,\n             ggml_blck_size((ggml_type)config.up_type));\n      throw std::runtime_error(\"intermediate_size * hidden_size must be a multiple of up_type blck size\");\n    }\n    if (config.intermediate_size * config.hidden_size % ggml_blck_size((ggml_type)config.gate_type) != 0) {\n      printf(\"intermediate_size: %d, gate_type blck size: %d\\n\", config.intermediate_size,\n             ggml_blck_size((ggml_type)config.gate_type));\n      throw std::runtime_error(\"intermediate_size * hidden_size must be a multiple of gate_type blck size\");\n    }\n    uint8_t* gate_proj = (uint8_t*)config.gate_proj + offset * config.hidden_size *\n                                                          ggml_type_size((ggml_type)config.gate_type) /\n                                                          ggml_blck_size((ggml_type)config.gate_type);\n    uint8_t* up_proj = (uint8_t*)config.up_proj + offset * config.hidden_size *\n                                                      ggml_type_size((ggml_type)config.up_type) /\n                                                      
ggml_blck_size((ggml_type)config.up_type);\n    uint8_t* down_proj = (uint8_t*)config.down_proj + offset * ggml_type_size((ggml_type)config.down_type) /\n                                                          ggml_blck_size((ggml_type)config.down_type);\n\n    for (int i = 0; i < config.expert_num; ++i) {\n      memcpy(local_gate_proj, gate_proj,\n             config.intermediate_size * config.hidden_size * ggml_type_size((ggml_type)config.gate_type) /\n                 ggml_blck_size((ggml_type)config.gate_type));\n      memcpy(local_up_proj, up_proj,\n             config.intermediate_size * config.hidden_size * ggml_type_size((ggml_type)config.up_type) /\n                 ggml_blck_size((ggml_type)config.up_type));\n      for (int j = 0; j < config.hidden_size; ++j) {\n        memcpy(local_down_proj, down_proj,\n               config.intermediate_size * ggml_type_size((ggml_type)config.down_type) /\n                   ggml_blck_size((ggml_type)config.down_type));\n        local_down_proj += config.intermediate_size * ggml_type_size((ggml_type)config.down_type) /\n                           ggml_blck_size((ggml_type)config.down_type);\n        down_proj += complete_intermediate_size * ggml_type_size((ggml_type)config.down_type) /\n                     ggml_blck_size((ggml_type)config.down_type);\n      }\n      local_gate_proj += config.intermediate_size * config.hidden_size * ggml_type_size((ggml_type)config.gate_type) /\n                         ggml_blck_size((ggml_type)config.gate_type);\n      local_up_proj += config.intermediate_size * config.hidden_size * ggml_type_size((ggml_type)config.up_type) /\n                       ggml_blck_size((ggml_type)config.up_type);\n      gate_proj += complete_intermediate_size * config.hidden_size * ggml_type_size((ggml_type)config.gate_type) /\n                   ggml_blck_size((ggml_type)config.gate_type);\n      up_proj += complete_intermediate_size * config.hidden_size * ggml_type_size((ggml_type)config.up_type) /\n  
               ggml_blck_size((ggml_type)config.up_type);\n    }\n  }\n\n  void warm_up() {\n    std::vector<float> input_fp32(config_.hidden_size);\n    std::vector<uint8_t> input(config_.hidden_size * ggml_type_size((ggml_type)config_.hidden_type) /\n                               ggml_blck_size((ggml_type)config_.hidden_type));\n    std::vector<float> output(config_.hidden_size);\n    for (int i = 0; i < config_.hidden_size; i++) {\n      input_fp32[i] = 0;\n    }\n    from_float(input_fp32.data(), input.data(), config_.hidden_size, (ggml_type)config_.hidden_type);\n    for (int i = 0; i < config_.expert_num; i++) {\n      int64_t expert_ids = i;\n      float weights = 0;\n      forward_one(1, &expert_ids, &weights, input.data(), output.data());\n    }\n  }\n\n  static float act_fn(float x) { return x / (1.0f + expf(-x)); }\n\n  void forward_one(int k, const int64_t* expert_ids, const float* weights, const void* input, float* output) {\n    auto pool = config_.pool->get_subpool(tp_part_idx);\n#ifdef FORWARD_TIME_PROFILE\n    auto t0 = std::chrono::high_resolution_clock::now();\n#endif\n    const void* gate_input_ptr;\n    const void* up_input_ptr;\n    if ((ggml_type)config_.hidden_type == ggml_internal_get_type_traits((ggml_type)config_.gate_type).vec_dot_type &&\n        (ggml_type)config_.hidden_type == ggml_internal_get_type_traits((ggml_type)config_.up_type).vec_dot_type) {\n      gate_input_ptr = up_input_ptr = input;\n    } else {\n      to_float(input, s_input_fp32_, config_.hidden_size, (ggml_type)config_.hidden_type);\n      if (ggml_internal_get_type_traits((ggml_type)config_.gate_type).vec_dot_type ==\n          ggml_internal_get_type_traits((ggml_type)config_.up_type).vec_dot_type) {\n        from_float(s_input_fp32_, s_gate_input_, config_.hidden_size,\n                   ggml_internal_get_type_traits((ggml_type)config_.gate_type).vec_dot_type);\n        gate_input_ptr = up_input_ptr = s_gate_input_;\n      } else {\n        if 
((ggml_type)config_.hidden_type !=\n            ggml_internal_get_type_traits((ggml_type)config_.gate_type).vec_dot_type) {\n          from_float(s_input_fp32_, s_gate_input_, config_.hidden_size,\n                     ggml_internal_get_type_traits((ggml_type)config_.gate_type).vec_dot_type);\n          gate_input_ptr = s_gate_input_;\n        } else {\n          gate_input_ptr = input;\n        }\n        if ((ggml_type)config_.hidden_type != ggml_internal_get_type_traits((ggml_type)config_.up_type).vec_dot_type) {\n          from_float(s_input_fp32_, s_up_input_, config_.hidden_size,\n                     ggml_internal_get_type_traits((ggml_type)config_.up_type).vec_dot_type);\n          up_input_ptr = s_up_input_;\n        } else {\n          up_input_ptr = input;\n        }\n      }\n    }\n\n#ifdef FORWARD_TIME_PROFILE\n    // printf(\"gate_input: \");\n    // debug_quant(const_cast<void *>(gate_input_ptr),\n    // ggml_internal_get_type_traits((ggml_type)config_.gate_type).vec_dot_type);\n    // printf(\"up_input: \");\n    // debug_quant(const_cast<void *>(up_input_ptr),\n    // ggml_internal_get_type_traits((ggml_type)config_.up_type).vec_dot_type);\n    auto t1 = std::chrono::high_resolution_clock::now();\n    fmt::print(\"numa_node: {}, convert time: {}\\n\", tp_part_idx,\n               std::chrono::duration_cast<std::chrono::nanoseconds>(t1 - t0).count());\n\n#endif\n\n    int activated_expert = 0;\n    for (int i = 0; i < k; i++) {\n      if (config_.should_skip_expert(expert_ids[i])) {\n        continue;\n      }\n      m_expert_id_map_[activated_expert] = expert_ids[i];\n      activated_expert++;\n    }\n\n    int nth = config_.intermediate_size / config_.m_block;\n\n    // Only process activated (CPU) experts; skip GPU experts entirely to keep buffers aligned.\n    if (activated_expert > 0) {\n      pool->do_work_stealing_job(\n          nth * activated_expert, nullptr,\n          [&](int task_id) {\n            int act_idx = task_id / nth;\n        
    int64_t expert_id = m_expert_id_map_[act_idx];\n            if (expert_id == -1) {\n              return;\n            }\n            int ith = task_id % nth;\n\n            void* gate_proj_ptr =\n                (uint8_t*)m_local_gate_proj_ + (expert_id * config_.intermediate_size + ith * config_.m_block) *\n                                                   config_.hidden_size * ggml_type_size((ggml_type)config_.gate_type) /\n                                                   ggml_blck_size((ggml_type)config_.gate_type);\n\n            float* gate_output_ptr = s_gate_output_[act_idx] + ith * config_.m_block;\n            auto ok = llamafile_sgemm(\n                config_.m_block, 1, config_.hidden_size / ggml_blck_size((ggml_type)config_.gate_type), gate_proj_ptr,\n                config_.hidden_size / ggml_blck_size((ggml_type)config_.gate_type), gate_input_ptr,\n                config_.hidden_size / ggml_blck_size((ggml_type)config_.gate_type), gate_output_ptr, config_.m_block, 0,\n                1, GGML_TASK_TYPE_COMPUTE, (ggml_type)config_.gate_type,\n                ggml_internal_get_type_traits((ggml_type)config_.gate_type).vec_dot_type, GGML_TYPE_F32,\n                GGML_PREC_DEFAULT);\n            if (ok == false) [[unlikely]] {\n              throw std::runtime_error(\"llamafile not supported\");\n            }\n\n            void* up_proj_ptr =\n                (uint8_t*)m_local_up_proj_ + (expert_id * config_.intermediate_size + ith * config_.m_block) *\n                                                 config_.hidden_size * ggml_type_size((ggml_type)config_.up_type) /\n                                                 ggml_blck_size((ggml_type)config_.up_type);\n\n            float* up_output_ptr = s_up_output_[act_idx] + ith * config_.m_block;\n            llamafile_sgemm(config_.m_block, 1, config_.hidden_size / ggml_blck_size((ggml_type)config_.up_type),\n                            up_proj_ptr, config_.hidden_size / 
ggml_blck_size((ggml_type)config_.up_type), up_input_ptr,\n                            config_.hidden_size / ggml_blck_size((ggml_type)config_.up_type), up_output_ptr,\n                            config_.m_block, 0, 1, GGML_TASK_TYPE_COMPUTE, (ggml_type)config_.up_type,\n                            ggml_internal_get_type_traits((ggml_type)config_.up_type).vec_dot_type, GGML_TYPE_F32,\n                            GGML_PREC_DEFAULT);\n\n            for (int i = ith * config_.m_block; i < (ith + 1) * config_.m_block; i++) {\n              s_intermediate_fp32_[act_idx][i] = act_fn(s_gate_output_[act_idx][i]) * s_up_output_[act_idx][i];\n            }\n            if (config_.m_block %\n                    ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type) ==\n                0) {\n              float* intermediate_fp32_ptr = s_intermediate_fp32_[act_idx] + ith * config_.m_block;\n              void* down_input_ptr =\n                  s_down_input_[act_idx] +\n                  ith * config_.m_block *\n                      ggml_type_size(ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type) /\n                      ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type);\n              from_float(intermediate_fp32_ptr, down_input_ptr, config_.m_block,\n                         ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type);\n            }\n          },\n          nullptr);\n    }\n\n    if (config_.m_block % ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type) !=\n        0) {\n      for (int i = 0; i < activated_expert; i++) {\n        from_float(s_intermediate_fp32_[i], s_down_input_[i], config_.intermediate_size,\n                   ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type);\n      }\n    }\n\n#ifdef FORWARD_TIME_PROFILE\n    // printf(\"sinter:\");\n    // 
debug_f32(s_intermediate_fp32_[expert_ids[0]]);\n    auto t2 = std::chrono::high_resolution_clock::now();\n    fmt::print(\"numa_node: {}, gate/up time: {}\\n\", tp_part_idx,\n               std::chrono::duration_cast<std::chrono::nanoseconds>(t2 - t1).count());\n#endif\n\n    nth = config_.hidden_size / config_.m_block;\n    pool->do_work_stealing_job(\n        nth, nullptr,\n        [&](int task_id) {\n          int ith = task_id;\n          for (int i = ith * config_.m_block; i < (ith + 1) * config_.m_block; i++) {\n            output[i] = 0;\n          }\n          for (int expert_idx = 0; expert_idx < activated_expert; expert_idx++) {\n            int64_t expert_id = m_expert_id_map_[expert_idx];\n            if (expert_id == -1) {\n              continue;\n            }\n\n            auto expert_offset = expert_id * config_.hidden_size * config_.intermediate_size;\n            auto m_block_offset = ith * config_.m_block * config_.intermediate_size;\n            void* down_proj_ptr = (uint8_t*)m_local_down_proj_ + (expert_offset + m_block_offset) *\n                                                                     ggml_type_size((ggml_type)config_.down_type) /\n                                                                     ggml_blck_size((ggml_type)config_.down_type);\n\n            float* down_output_ptr = s_down_output_[expert_idx] + ith * config_.m_block;\n            llamafile_sgemm(\n                config_.m_block, 1, config_.intermediate_size / ggml_blck_size((ggml_type)config_.down_type),\n                down_proj_ptr, config_.intermediate_size / ggml_blck_size((ggml_type)config_.down_type),\n                s_down_input_[expert_idx], config_.intermediate_size / ggml_blck_size((ggml_type)config_.down_type),\n                down_output_ptr, config_.m_block, 0, 1, GGML_TASK_TYPE_COMPUTE, (ggml_type)config_.down_type,\n                ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type, GGML_TYPE_F32,\n                
GGML_PREC_DEFAULT);\n\n            float expert_weight = 0.0f;\n            for (int j = 0; j < k; j++) {\n              if (expert_ids[j] == expert_id) {\n                expert_weight = weights[j];\n                break;\n              }\n            }\n\n            for (int i = ith * config_.m_block; i < (ith + 1) * config_.m_block; i++) {\n              output[i] += s_down_output_[expert_idx][i] * expert_weight;\n            }\n          }\n        },\n        nullptr);\n\n#ifdef FORWARD_TIME_PROFILE\n    auto t3 = std::chrono::high_resolution_clock::now();\n    fmt::print(\"numa_node: {}, down time: {}\\n\", tp_part_idx,\n               std::chrono::duration_cast<std::chrono::nanoseconds>(t3 - t2).count());\n    fmt::print(\"numa_node: {}, total time: {}\\n\", tp_part_idx,\n               std::chrono::duration_cast<std::chrono::nanoseconds>(t3 - t0).count());\n#endif\n  }\n\n  void forward_many(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input,\n                    float* output) {\n    auto pool = config_.pool->get_subpool(tp_part_idx);\n#ifdef FORWARD_TIME_PROFILE\n    auto start_time = std::chrono::high_resolution_clock::now();\n    auto last = start_time;\n    // Elapsed time of each stage (unit: microseconds)\n    long prepare_time = 0, cpy_input_time = 0, q_input_time = 0, up_gate_time = 0;\n    long act_time = 0, q_down_time = 0, down_time = 0, weight_time = 0;\n    int max_local_num = 0;  // Largest per-expert local num seen\n#endif\n\n    int activated_expert = 0;\n    for (int i = 0; i < config_.expert_num; i++) {\n      m_local_num_[i] = 0;\n    }\n    for (int i = 0; i < qlen; i++) {\n      for (int j = 0; j < k; j++) {\n        if (config_.should_skip_expert(expert_ids[i * k + j])) {\n          continue;\n        }\n        m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;\n      }\n    }\n    uint64_t offset = 0;\n    for (int i = 0; i < config_.expert_num; i++) {\n      m_local_gate_input_ptr_[i] =\n          m_local_gate_input_ +\n          offset 
* config_.hidden_size *\n              ggml_type_size(ggml_internal_get_type_traits((ggml_type)config_.gate_type).vec_dot_type) /\n              ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.gate_type).vec_dot_type);\n      m_local_up_input_ptr_[i] =\n          m_local_up_input_ +\n          offset * config_.hidden_size *\n              ggml_type_size(ggml_internal_get_type_traits((ggml_type)config_.up_type).vec_dot_type) /\n              ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.up_type).vec_dot_type);\n      m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;\n      m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;\n      m_local_intermediate_fp32_ptr_[i] = m_local_intermediate_fp32_ + offset * config_.intermediate_size;\n      m_local_down_input_ptr_[i] =\n          m_local_down_input_ +\n          offset * config_.intermediate_size *\n              ggml_type_size(ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type) /\n              ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type);\n      m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;\n      offset += m_local_num_[i];\n      if (m_local_num_[i] > 0) {\n#ifdef FORWARD_TIME_PROFILE\n        max_local_num = std::max(max_local_num, m_local_num_[i]);\n#endif\n        m_expert_id_map_[activated_expert] = i;\n        activated_expert++;\n      }\n    }\n\n#ifdef FORWARD_TIME_PROFILE\n    {\n      auto now_time = std::chrono::high_resolution_clock::now();\n      prepare_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();\n      last = now_time;\n    }\n#endif\n\n    pool->do_work_stealing_job(\n        qlen, nullptr,\n        [&](int i) {\n          const void* gate_input_ptr;\n          const void* up_input_ptr;\n          if ((ggml_type)config_.hidden_type ==\n                  
ggml_internal_get_type_traits((ggml_type)config_.gate_type).vec_dot_type &&\n              (ggml_type)config_.hidden_type ==\n                  ggml_internal_get_type_traits((ggml_type)config_.up_type).vec_dot_type) {\n            gate_input_ptr = up_input_ptr = (uint8_t*)input + i * config_.hidden_size *\n                                                                  ggml_type_size((ggml_type)config_.hidden_type) /\n                                                                  ggml_blck_size((ggml_type)config_.hidden_type);\n          } else {\n            to_float((uint8_t*)input + i * config_.hidden_size * ggml_type_size((ggml_type)config_.hidden_type) /\n                                           ggml_blck_size((ggml_type)config_.hidden_type),\n                     m_input_fp32_[i], config_.hidden_size, (ggml_type)config_.hidden_type);\n            if (ggml_internal_get_type_traits((ggml_type)config_.gate_type).vec_dot_type ==\n                ggml_internal_get_type_traits((ggml_type)config_.up_type).vec_dot_type) {\n              from_float(m_input_fp32_[i], m_gate_input_[i], config_.hidden_size,\n                         ggml_internal_get_type_traits((ggml_type)config_.gate_type).vec_dot_type);\n              gate_input_ptr = up_input_ptr = m_gate_input_[i];\n            } else {\n              if ((ggml_type)config_.hidden_type !=\n                  ggml_internal_get_type_traits((ggml_type)config_.gate_type).vec_dot_type) {\n                from_float(m_input_fp32_[i], m_gate_input_[i], config_.hidden_size,\n                           ggml_internal_get_type_traits((ggml_type)config_.gate_type).vec_dot_type);\n                gate_input_ptr = m_gate_input_[i];\n              } else {\n                gate_input_ptr = (uint8_t*)input + i * config_.hidden_size *\n                                                       ggml_type_size((ggml_type)config_.hidden_type) /\n                                                       
ggml_blck_size((ggml_type)config_.hidden_type);\n              }\n              if ((ggml_type)config_.hidden_type !=\n                  ggml_internal_get_type_traits((ggml_type)config_.up_type).vec_dot_type) {\n                from_float(m_input_fp32_[i], m_up_input_[i], config_.hidden_size,\n                           ggml_internal_get_type_traits((ggml_type)config_.up_type).vec_dot_type);\n                up_input_ptr = m_up_input_[i];\n              } else {\n                up_input_ptr = (uint8_t*)input + i * config_.hidden_size *\n                                                     ggml_type_size((ggml_type)config_.hidden_type) /\n                                                     ggml_blck_size((ggml_type)config_.hidden_type);\n              }\n            }\n          }\n          for (int j = 0; j < k; j++) {\n            if (config_.should_skip_expert(expert_ids[i * k + j])) {\n              continue;\n            }\n            memcpy(m_local_gate_input_ptr_[expert_ids[i * k + j]] +\n                       m_local_pos_[i][j] * config_.hidden_size *\n                           ggml_type_size(ggml_internal_get_type_traits((ggml_type)config_.gate_type).vec_dot_type) /\n                           ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.gate_type).vec_dot_type),\n                   gate_input_ptr,\n                   config_.hidden_size *\n                       ggml_type_size(ggml_internal_get_type_traits((ggml_type)config_.gate_type).vec_dot_type) /\n                       ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.gate_type).vec_dot_type));\n            memcpy(m_local_up_input_ptr_[expert_ids[i * k + j]] +\n                       m_local_pos_[i][j] * config_.hidden_size *\n                           ggml_type_size(ggml_internal_get_type_traits((ggml_type)config_.up_type).vec_dot_type) /\n                           ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.up_type).vec_dot_type),\n       
            up_input_ptr,\n                   config_.hidden_size *\n                       ggml_type_size(ggml_internal_get_type_traits((ggml_type)config_.up_type).vec_dot_type) /\n                       ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.up_type).vec_dot_type));\n          }\n        },\n        nullptr);\n\n#ifdef FORWARD_TIME_PROFILE\n    {\n      auto now_time = std::chrono::high_resolution_clock::now();\n      cpy_input_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();\n      last = now_time;\n    }\n#endif\n\n    int m_block = QK_K;\n    int nth = config_.intermediate_size / m_block;\n    // printf(\"nth: %d, m_block: %d, activated_expert: %d\\n\", nth, m_block, activated_expert);\n    // printf(\"config_.hidden_size: %d, config_.intermediate_size: %d\\n\", config_.hidden_size,\n    // config_.intermediate_size);\n    pool->do_work_stealing_job(\n        nth * activated_expert, nullptr,\n        [&](int task_id) {\n          int64_t expert_idx = m_expert_id_map_[task_id / nth];\n          int ith = task_id % nth;\n          void* gate_input_ptr = m_local_gate_input_ptr_[expert_idx];\n\n          void* gate_proj_ptr =\n              (uint8_t*)m_local_gate_proj_ + (expert_idx * config_.intermediate_size + ith * m_block) *\n                                                 config_.hidden_size * ggml_type_size((ggml_type)config_.gate_type) /\n                                                 ggml_blck_size((ggml_type)config_.gate_type);\n\n          float* gate_output_ptr = m_local_gate_output_ptr_[expert_idx] + ith * m_block;\n\n          // if (ith == 0) {\n          //   printf(\"matrix size: m:%d, n:%d, k:%d\\n\", m_block, m_local_num_[expert_idx],\n          //          config_.hidden_size / ggml_blck_size((ggml_type)config_.gate_type));\n          // }\n          llamafile_sgemm(m_block, m_local_num_[expert_idx],\n                          config_.hidden_size / 
ggml_blck_size((ggml_type)config_.gate_type), gate_proj_ptr,\n                          config_.hidden_size / ggml_blck_size((ggml_type)config_.gate_type), gate_input_ptr,\n                          config_.hidden_size / ggml_blck_size((ggml_type)config_.gate_type), gate_output_ptr,\n                          config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, (ggml_type)config_.gate_type,\n                          ggml_internal_get_type_traits((ggml_type)config_.gate_type).vec_dot_type, GGML_TYPE_F32,\n                          GGML_PREC_DEFAULT);\n          void* up_input_ptr = m_local_up_input_ptr_[expert_idx];\n\n          void* up_proj_ptr = (uint8_t*)m_local_up_proj_ + (expert_idx * config_.intermediate_size + ith * m_block) *\n                                                               config_.hidden_size *\n                                                               ggml_type_size((ggml_type)config_.up_type) /\n                                                               ggml_blck_size((ggml_type)config_.up_type);\n\n          float* up_output_ptr = m_local_up_output_ptr_[expert_idx] + ith * m_block;\n          llamafile_sgemm(\n              m_block, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size((ggml_type)config_.up_type),\n              up_proj_ptr, config_.hidden_size / ggml_blck_size((ggml_type)config_.up_type), up_input_ptr,\n              config_.hidden_size / ggml_blck_size((ggml_type)config_.up_type), up_output_ptr,\n              config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, (ggml_type)config_.up_type,\n              ggml_internal_get_type_traits((ggml_type)config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n          for (int i = 0; i < m_local_num_[expert_idx]; i++) {\n            for (int j = ith * m_block; j < (ith + 1) * m_block; j++) {\n              m_local_intermediate_fp32_ptr_[expert_idx][i * config_.intermediate_size + j] =\n                  
act_fn(m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size + j]) *\n                  m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size + j];\n            }\n            float* intermediate_fp32_ptr =\n                m_local_intermediate_fp32_ptr_[expert_idx] + i * config_.intermediate_size + ith * m_block;\n            void* down_input_ptr =\n                m_local_down_input_ptr_[expert_idx] +\n                i * config_.intermediate_size *\n                    ggml_type_size(ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type) /\n                    ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type) +\n                ith * m_block *\n                    ggml_type_size(ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type) /\n                    ggml_blck_size(ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type);\n            from_float(intermediate_fp32_ptr, down_input_ptr, m_block,\n                       ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type);\n          }\n        },\n        nullptr);\n\n#ifdef FORWARD_TIME_PROFILE\n    {\n      auto now_time = std::chrono::high_resolution_clock::now();\n      up_gate_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();\n      last = now_time;\n    }\n#endif\n\n    m_block = QK_K;\n    nth = config_.hidden_size / m_block;\n    pool->do_work_stealing_job(\n        nth * activated_expert, nullptr,\n        [&](int task_id) {\n          int64_t expert_idx = m_expert_id_map_[task_id / nth];\n          int ith = task_id % nth;\n          void* down_input_ptr = m_local_down_input_ptr_[expert_idx];\n\n          auto expert_offset = expert_idx * config_.hidden_size * config_.intermediate_size;\n          auto m_block_offset = ith * m_block * config_.intermediate_size;\n\n          void* down_proj_ptr = (uint8_t*)m_local_down_proj_ + 
(expert_offset + m_block_offset) *\n                                                                   ggml_type_size((ggml_type)config_.down_type) /\n                                                                   ggml_blck_size((ggml_type)config_.down_type);\n\n          float* down_output_ptr = m_local_down_output_ptr_[expert_idx] + ith * m_block;\n          llamafile_sgemm(m_block, m_local_num_[expert_idx],\n                          config_.intermediate_size / ggml_blck_size((ggml_type)config_.down_type), down_proj_ptr,\n                          config_.intermediate_size / ggml_blck_size((ggml_type)config_.down_type), down_input_ptr,\n                          config_.intermediate_size / ggml_blck_size((ggml_type)config_.down_type), down_output_ptr,\n                          config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, (ggml_type)config_.down_type,\n                          ggml_internal_get_type_traits((ggml_type)config_.down_type).vec_dot_type, GGML_TYPE_F32,\n                          GGML_PREC_DEFAULT);\n        },\n        nullptr);\n\n#ifdef FORWARD_TIME_PROFILE\n    {\n      auto now_time = std::chrono::high_resolution_clock::now();\n      down_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();\n      last = now_time;\n    }\n#endif\n\n    pool->do_work_stealing_job(\n        qlen, nullptr,\n        [&](int i) {\n          for (int e = 0; e < config_.hidden_size; e++) {\n            m_output_fp32_[i][e] = 0;\n          }\n          for (int j = 0; j < k; j++) {\n            if (config_.should_skip_expert(expert_ids[i * k + j])) {\n              continue;\n            }\n            for (int e = 0; e < config_.hidden_size; e++) {\n              m_output_fp32_[i][e] +=\n                  m_local_down_output_ptr_[expert_ids[i * k + j]][m_local_pos_[i][j] * config_.hidden_size + e] *\n                  weights[i * k + j];\n            }\n          }\n          for (int e = 0; e < config_.hidden_size; e++) {\n  
          output[i * config_.hidden_size + e] = m_output_fp32_[i][e];\n          }\n        },\n        nullptr);\n#ifdef FORWARD_TIME_PROFILE\n    {\n      auto now_time = std::chrono::high_resolution_clock::now();\n      weight_time = std::chrono::duration_cast<std::chrono::microseconds>(now_time - last).count();\n      last = now_time;\n    }\n    auto end_time = std::chrono::high_resolution_clock::now();\n    auto forward_total_time = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count();\n    // Print the elapsed time of every stage once at the end of the function, along with max_local_num and qlen\n    printf(\n        \"Profiling Results (numa[%d]): activated_expert: %d, prepare: %ld us, cpy_input: %ld us, q_input: %ld us, \"\n        \"up_gate: %ld us, act: %ld us, q_down: %ld us, down: %ld us, weight: %ld us, total: %ld us, max_local_num: \"\n        \"%d, qlen: %d\\n\",\n        tp_part_idx, activated_expert, prepare_time, cpy_input_time, q_input_time, up_gate_time, act_time, q_down_time,\n        down_time, weight_time, forward_total_time, max_local_num, qlen);\n#endif\n  }\n\n  void forward(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input, void* output_in) {\n    auto output = (float*)output_in;\n    if (qlen < config_.group_min_len) {\n      for (int i = 0; i < qlen; i++) {\n        forward_one(k, expert_ids + i * k, weights + i * k,\n                    (uint8_t*)input + i * config_.hidden_size * ggml_type_size((ggml_type)config_.hidden_type) /\n                                          ggml_blck_size((ggml_type)config_.hidden_type),\n                    output + i * config_.hidden_size);\n      }\n      return;\n    }\n    int forward_len = std::min(config_.group_max_len, qlen);\n    forward_many(forward_len, k, expert_ids, weights, input, output);\n    forward(qlen - forward_len, k, expert_ids + forward_len * k, weights + forward_len * k,\n            (uint8_t*)input + forward_len * config_.hidden_size * 
ggml_type_size((ggml_type)config_.hidden_type) /\n                                  ggml_blck_size((ggml_type)config_.hidden_type),\n            output + forward_len * config_.hidden_size);\n  }\n};\n\ntemplate <>\nclass TP_MOE<LLAMA_MOE_TP> : public TP_MOE_Common<LLAMA_MOE_TP> {\n public:\n  using TP_MOE_Common<LLAMA_MOE_TP>::TP_MOE_Common;\n\n  void load_weights() {\n    auto pool = this->config.pool;\n\n    std::vector<int> tp_offsets(this->tp_count);\n    int accumulated_offset = 0;\n    for (int i = 0; i < this->tp_count; i++) {\n      tp_offsets[i] = accumulated_offset;\n      accumulated_offset += this->tp_configs[i].intermediate_size;\n    }\n\n    pool->dispense_backend()->do_numa_job([this, pool, tp_offsets](int tp_id) {\n      this->tps[tp_id]->load_weights(this->config.intermediate_size, tp_offsets[tp_id]);\n    });\n    this->weights_loaded = true;\n  }\n\n  void merge_results(int qlen, void* output) { merge_results(qlen, output, false); }\n\n  void merge_results(int qlen, void* output, bool incremental) {\n    auto pool = this->config.pool;\n    pool->do_work_stealing_job(\n        qlen, nullptr,\n        [this, output, incremental](int token_nth) {\n          if (incremental) {\n            to_float((uint8_t*)output + token_nth * config.hidden_size * ggml_type_size((ggml_type)config.hidden_type) /\n                                            ggml_blck_size((ggml_type)config.hidden_type),\n                     local_output + token_nth * config.hidden_size, config.hidden_size, (ggml_type)config.hidden_type);\n            for (int e = 0; e < config.hidden_size; e++) {\n              local_output_numa[0][token_nth * config.hidden_size + e] +=\n                  local_output[token_nth * config.hidden_size + e];\n            }\n          }\n          auto& tp_count = this->tp_count;\n          for (int i = 1; i < tp_count; i++) {\n            for (int e = 0; e < config.hidden_size; e++) {\n              local_output_numa[0][token_nth * config.hidden_size + 
e] +=\n                  local_output_numa[i][token_nth * config.hidden_size + e];\n            }\n          }\n          from_float(local_output_numa[0] + token_nth * config.hidden_size,\n                     (uint8_t*)output + token_nth * config.hidden_size * ggml_type_size((ggml_type)config.hidden_type) /\n                                            ggml_blck_size((ggml_type)config.hidden_type),\n                     config.hidden_size, (ggml_type)config.hidden_type);\n        },\n        nullptr);\n  }\n};\n#endif\n"
  },
  {
    "path": "kt-kernel/operators/mla-tp.hpp",
    "content": "#ifndef CPUINFER_OPERATOR_MLA_HPP\n#define CPUINFER_OPERATOR_MLA_HPP\n\n#include \"common.hpp\"\n\ntemplate <typename T>\n// qlens: token count for each query\n// cache_pages: kv_cache for all queries in the current layer\n// page_tables: kv_cache page table for each query ([query_idx][page_idx])\n// kv_lens: kv_cache length for each query\n// input: input tensor, shape [qlen, hidden_size]\n// output: output tensor, shape [qlen, hidden_size]\n// config: GeneralMLAConfig\n// tp_idx: thread pool index\n// T must have the following methods:\nconcept MLA_TP_PART =\n    requires(T t, std::vector<int> qlens, std::vector<void*> kv_lora_pages, std::vector<void*> pe_pages,\n             std::vector<std::vector<int>> page_tables, std::vector<int> kv_lens, const void* input, void* output,\n             GeneralMLAConfig config, int tp_idx, int page_count, std::vector<void*> attention_masks) {\n      typename T::output_t;\n      { new T(config, tp_idx) } -> std::same_as<T*>;\n      { t.set_pages(kv_lora_pages, pe_pages) } -> std::same_as<void>;\n      { t.set_local_pages(page_count) } -> std::same_as<void>;\n      { t.forward(qlens, page_tables, kv_lens, input, output) } -> std::same_as<void>;\n      { t.forward(qlens, page_tables, kv_lens, attention_masks, input, output) } -> std::same_as<void>;\n    };\n\ntemplate <MLA_TP_PART T>\nclass TP_MLA_Common : public MLA_Interface {\n protected:\n  GeneralMLAConfig config;\n  std::vector<GeneralMLAConfig> tp_configs;\n  int tp_count;\n  int me_numa_id;\n  std::vector<std::unique_ptr<T>> tps;\n\n  std::vector<typename T::output_t*> local_output_numa;\n\n  bool weights_loaded = false;\n#ifdef FORWARD_TIME_REPORT\n  size_t forward_time_sum_ns = 0;\n  size_t forward_count = 0;\n#endif\n\n public:\n  TP_MLA_Common(GeneralMLAConfig config) : config(config) {\n    printf(\"TP MLA layer %d, pool: 0x%lx\\n\", config.layer_idx, (intptr_t)config.pool);\n    if (config.pool == nullptr) {\n      printf(\"TP MLA layer %d, no worker 
pool\\n\", config.layer_idx);\n      throw std::runtime_error(\"no worker pool\");\n    }\n\n    this->config = config;\n    tp_count = config.pool->config.subpool_count;\n    if (config.hidden_size % tp_count != 0) {\n      printf(\"hidden_size %d, tp count %d\\n\", config.hidden_size, tp_count);\n      throw std::runtime_error(\n          \"For TP, hidden_size must be a \"\n          \"multiple of NUMA node count\");\n    }\n\n    for (auto i = 0; i < tp_count; i++) {\n      tps.push_back(nullptr);\n    }\n\n    tp_configs.resize(tp_count);\n    config.pool->dispense_backend()->do_numa_job([this, config](int i) {\n      tp_configs[i] = config;\n      tp_configs[i].num_heads /= tp_count;\n      tps[i] = std::move(std::unique_ptr<T>(new T(tp_configs[i], i)));\n    });\n\n    local_output_numa.resize(tp_count, nullptr);\n    MemoryRequest mem_requests;\n    for (auto i = 0; i < tp_count; i++) {\n      mem_requests.append_pointer(&local_output_numa[i],\n                                  sizeof(typename T::output_t) * tp_configs[i].max_qlen * tp_configs[i].hidden_size);\n    }\n    shared_mem_buffer.alloc(this, mem_requests);\n  }\n\n  void forward(std::vector<int> qlens, std::vector<std::vector<int>> page_tables, std::vector<int> kv_lens,\n               const void* input, void* output) override {\n    if (weights_loaded == false) [[unlikely]] {\n      throw std::runtime_error(\"Not Loaded\");\n    }\n#ifdef FORWARD_TIME_REPORT\n    auto start = std::chrono::high_resolution_clock::now();\n#endif\n\n    auto pool = config.pool;\n    pool->dispense_backend()->do_numa_job([this, pool, qlens, page_tables, kv_lens, input](int numa_id) {\n      tps[numa_id]->forward(qlens, page_tables, kv_lens, input, this->local_output_numa[numa_id]);\n    });\n    int qlen_sum = 0;\n    for (auto i = 0; i < qlens.size(); i++) {\n      qlen_sum += qlens[i];\n    }\n\n    merge_results(qlen_sum, output);\n\n#ifdef FORWARD_TIME_REPORT\n    auto end = 
std::chrono::high_resolution_clock::now();\n    auto forward_time = std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();\n    auto band_width = (1.0 * config.routed_expert_num * config.hidden_size * config.intermediate_size * 3 / 1e9) /\n                      (1.0 * forward_time / 1e6);\n    auto GFLOPS =\n        (1.0 * config.hidden_size * config.intermediate_size * qlen * 3 * config.routed_expert_num * 2 / 1e9) /\n        (1.0 * forward_time / 1e6);\n    if (qlen <= 10) {\n      forward_time_sum_ns += forward_time;\n      forward_count++;\n    }\n    auto average_bandwidth =\n        (1.0 * forward_count * config.routed_expert_num * config.hidden_size * config.intermediate_size * 3 / 1e9) /\n        (1.0 * forward_time_sum_ns / 1e6);\n    printf(\n        \"forward time %ld, time stamp:%ld, band width %f GElement/s, ave bandwidth %f GElement/s (only \"\n        \"decode), %f GFLOPS, me numa: %d\\n\",\n        forward_time, end.time_since_epoch().count() / 1000 % 100000000, band_width, average_bandwidth, GFLOPS,\n        numa_node_of_cpu(sched_getcpu()));\n#endif\n  }\n\n  void set_pages(std::vector<std::vector<void*>> kv_lora_pages, std::vector<std::vector<void*>> pe_pages) {\n    for (auto i = 0; i < tp_count; i++) {\n      tps[i]->set_pages(kv_lora_pages[i], pe_pages[i]);\n    }\n  }\n\n  void set_local_pages(int page_count) {\n    config.pool->dispense_backend()->do_numa_job(\n        [this, page_count](int tp_idx) { tps[tp_idx]->set_local_pages(page_count); });\n  }\n\n  virtual void load_weights() = 0;\n  virtual void merge_results(int qlen, void* output) = 0;\n};\n\ntemplate <MLA_TP_PART T>\nclass TP_MLA : public TP_MLA_Common<T> {\n public:\n  using TP_MLA_Common<T>::TP_MLA_Common;\n  void load_weights() { throw std::runtime_error(\"Not Implemented\"); }\n  void merge_results(int qlen, void* output) { throw std::runtime_error(\"Not Implemented\"); }\n};\n\n#endif"
  },
  {
    "path": "kt-kernel/operators/moe-tp.hpp",
    "content": "#ifndef CPUINFER_OPERATOR_MOE_HPP\n#define CPUINFER_OPERATOR_MOE_HPP\n\n// #define CHECK\n\n#include <cstdint>\n#include <cstdio>\n#include <type_traits>\n\n#include \"../cpu_backend/shared_mem_buffer.h\"\n#include \"common.hpp\"\n\n// Forward declaration for Llamafile backend type checking\nclass LLAMA_MOE_TP;\n\ntemplate <typename T>\nconcept MOE_TP_PART = requires(T t, int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input,\n                               void* output, GeneralMOEConfig config, int tp_idx) {\n  typename T::output_t;\n  { new T(config, tp_idx) } -> std::same_as<T*>;\n  { t.forward(qlen, k, expert_ids, weights, input, output) } -> std::same_as<void>;\n  // { t.load_weights() } -> std::same_as<void>;\n};\n\ntemplate <MOE_TP_PART T>\nclass TP_MOE_Common : public MoE_Interface {\n protected:\n  std::vector<GeneralMOEConfig> tp_configs;\n  int tp_count;\n  int me_numa_id;\n  std::vector<std::unique_ptr<T>> tps;\n\n  std::vector<typename T::output_t*> local_output_numa;\n  T::output_t* local_output = nullptr;\n\n  bool weights_loaded = false;\n\n#ifdef FORWARD_TIME_REPORT\n  size_t forward_time_sum_ns = 0;\n  size_t forward_count = 0;\n#endif\n public:\n  GeneralMOEConfig config;\n  using input_t = typename T::input_t;\n  TP_MOE_Common(GeneralMOEConfig config) : config(config) {\n    printf(\"TP MOE layer %d, pool: 0x%lx, expert num: %d, num_experts_per_tok: %d\\n\", config.layer_idx,\n           (intptr_t)config.pool, config.expert_num, config.num_experts_per_tok);\n    if (config.pool == nullptr) {\n      printf(\"TP MOE layer %d, no worker pool\\n\", config.layer_idx);\n      throw std::runtime_error(\"no worker pool\");\n    }\n\n    this->config = config;\n    tp_count = config.pool->config.subpool_count;\n    if (config.intermediate_size % tp_count != 0) {\n      printf(\"intermediate_size %d, tp count %d\\n\", config.intermediate_size, tp_count);\n      throw std::runtime_error(\n          \"For TP, 
intermediate_size must be a \"\n          \"multiple of NUMA node count\");\n    }\n\n    // Check if this is Llamafile backend using compile-time type checking\n    constexpr bool is_llamafile = std::is_same<T, LLAMA_MOE_TP>::value;\n#ifndef QK_K\n#define QK_K 256\n#endif\n\n    if (is_llamafile) {\n      // For Llamafile backend: use QK_K-aligned TP splitting\n      if (config.intermediate_size % QK_K != 0) {\n        printf(\"intermediate_size %d must be divisible by QK_K %d for Llamafile backend\\n\", config.intermediate_size,\n               QK_K);\n        throw std::runtime_error(\"intermediate_size must be divisible by QK_K (256) for Llamafile backend\");\n      }\n\n      int num_blocks = config.intermediate_size / QK_K;\n      int base_blocks = num_blocks / tp_count;\n      int extra_blocks = num_blocks % tp_count;\n\n      if (base_blocks == 0) {\n        printf(\"intermediate_size %d is too small for tp_count %d (num_blocks=%d)\\n\", config.intermediate_size,\n               tp_count, num_blocks);\n        throw std::runtime_error(\"intermediate_size too small: cannot distribute blocks to all TP instances\");\n      }\n\n      printf(\"Llamafile TP splitting: intermediate_size=%d, tp_count=%d, QK_K=%d\\n\", config.intermediate_size, tp_count,\n             QK_K);\n      printf(\"  num_blocks=%d, base_blocks=%d, extra_blocks=%d\\n\", num_blocks, base_blocks, extra_blocks);\n\n      int current_offset = 0;\n      for (auto i = 0; i < tp_count; i++) {\n        tps.push_back(nullptr);\n        GeneralMOEConfig tp_config = config;\n\n        // First extra_blocks TPs get one more block\n        int num_blocks_for_this_tp = base_blocks + (i < extra_blocks ? 
1 : 0);\n        tp_config.intermediate_size = num_blocks_for_this_tp * QK_K;\n\n        printf(\"  TP %d: intermediate_size=%d, offset=%d, blocks=%d\\n\", i, tp_config.intermediate_size, current_offset,\n               num_blocks_for_this_tp);\n\n        tp_configs.push_back(tp_config);\n        current_offset += tp_config.intermediate_size;\n      }\n    } else {\n      // For non-Llamafile backends: use simple equal division\n      if (config.intermediate_size % tp_count != 0) {\n        printf(\"intermediate_size %d, tp count %d\\n\", config.intermediate_size, tp_count);\n        throw std::runtime_error(\n            \"For TP, intermediate_size must be a \"\n            \"multiple of NUMA node count\");\n      }\n\n      for (auto i = 0; i < tp_count; i++) {\n        tps.push_back(nullptr);\n        GeneralMOEConfig tp_config = config;\n        tp_config.intermediate_size /= tp_count;\n        tp_configs.push_back(tp_config);\n      }\n    }\n\n    config.pool->dispense_backend()->do_numa_job(\n        [this, config](int i) { tps[i] = std::move(std::unique_ptr<T>(new T(tp_configs[i], i))); });\n\n    local_output_numa.resize(tp_count, nullptr);\n    MemoryRequest mem_requests;\n    for (auto i = 0; i < tp_count; i++) {\n      mem_requests.append_pointer(\n          &local_output_numa[i],\n          (size_t)sizeof(typename T::output_t) * tp_configs[i].max_possible_qlen() * tp_configs[i].hidden_size);\n    }\n    mem_requests.append_pointer(\n        (void**)&local_output,\n        sizeof(typename T::output_t) * tp_configs[0].max_possible_qlen() * tp_configs[0].hidden_size);\n    // printf(\"local output tp, %d,\\n\", tp_configs[0].max_possible_qlen());\n    shared_mem_buffer.alloc(this, mem_requests);\n  }\n\n  void warm_up() {\n    int qlen = config.max_possible_qlen();\n    std::vector<uint8_t> input(sizeof(ggml_bf16_t) * qlen * config.hidden_size);\n    std::vector<uint8_t> output(sizeof(ggml_bf16_t) * qlen * config.hidden_size);\n    std::vector<int64_t> 
expert_ids(qlen * config.num_experts_per_tok);\n    std::vector<float> weights(qlen * config.num_experts_per_tok);\n    for (int i = 0; i < qlen * config.num_experts_per_tok; i++) {\n      expert_ids[i] = i % config.expert_num;\n      weights[i] = 0.01;\n    }\n    forward(&qlen, config.num_experts_per_tok, expert_ids.data(), weights.data(), input.data(), output.data(), false);\n  }\n\n  void forward(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input, void* output,\n               bool incremental = false) {\n    int qlen_local = qlen;\n    forward(&qlen_local, k, expert_ids, weights, input, output, incremental);\n  }\n\n  void forward(int* qlen_ptr, int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {\n    forward(qlen_ptr, k, expert_ids, weights, input, output, false);\n  }\n\n  void forward_binding(intptr_t qlen_ptr, int k, intptr_t expert_ids, intptr_t weights, intptr_t input, intptr_t output,\n                       bool incremental) {\n    forward((int*)qlen_ptr, k, (const int64_t*)expert_ids, (const float*)weights, (const void*)input, (void*)output,\n            incremental);\n  }\n\n  void forward(int* qlen_ptr, int k, const int64_t* expert_ids, const float* weights, const void* input, void* output,\n               bool incremental) {\n    if (weights_loaded == false) [[unlikely]] {\n      throw std::runtime_error(\"Not Loaded\");\n    }\n#ifdef FORWARD_TIME_REPORT\n    auto start = std::chrono::high_resolution_clock::now();\n#endif\n    int qlen = *qlen_ptr;\n\n    auto pool = config.pool;\n    pool->dispense_backend()->do_numa_job([this, pool, qlen, k, expert_ids, input, weights](int numa_id) {\n      tps[numa_id]->forward(qlen, k, expert_ids, weights, input, this->local_output_numa[numa_id]);\n    });\n\n    merge_results(qlen, output, incremental);\n#ifdef FORWARD_TIME_REPORT\n    auto end = std::chrono::high_resolution_clock::now();\n    auto forward_time = 
std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();\n    int unique_experts = 0;\n    {\n      std::unordered_set<int64_t> expert_set;\n      for (int i = 0; i < qlen * config.num_experts_per_tok; i++) {\n        expert_set.insert(expert_ids[i]);\n      }\n      unique_experts = expert_set.size();\n    }\n    auto band_width =\n        (1.0 * unique_experts * config.hidden_size * config.intermediate_size * 3 / 1e9) / (1.0 * forward_time / 1e6);\n    auto GFLOPS =\n        (1.0 * config.hidden_size * config.intermediate_size * qlen * 3 * config.num_experts_per_tok * 2 / 1e9) /\n        (1.0 * forward_time / 1e6);\n    if (qlen <= 10) {\n      forward_time_sum_ns += forward_time;\n      forward_count++;\n    }\n    auto average_bandwidth =\n        (1.0 * forward_count * unique_experts * config.hidden_size * config.intermediate_size * 3 / 1e9) /\n        (1.0 * forward_time_sum_ns / 1e6);\n    printf(\n        \"forward time %ld, time stamp:%ld, band width %f GElement/s, ave bandwidth %f GElement/s (only \"\n        \"decode), %f GFLOPS, me numa: %d\\n\",\n        forward_time, end.time_since_epoch().count() / 1000 % 100000000, band_width, average_bandwidth, GFLOPS,\n        numa_node_of_cpu(sched_getcpu()));\n#endif\n  }\n\n  virtual void load_weights() = 0;\n\n  virtual void merge_results(int qlen, void* output) = 0;\n\n  virtual void merge_results(int qlen, void* output, bool incremental) {\n    if (incremental == false) {\n      merge_results(qlen, output);\n    } else {\n      throw std::runtime_error(\"Not Implemented\");\n    }\n  };\n};\n\ntemplate <MOE_TP_PART T>\nclass TP_MOE : public TP_MOE_Common<T> {\n public:\n  using TP_MOE_Common<T>::TP_MOE_Common;\n  void load_weights(const uint64_t* physical_to_logical_map) { throw std::runtime_error(\"Not Implemented\"); }\n  // void merge_results(int qlen, void *output, bool incremental) { throw std::runtime_error(\"Not Implemented\"); }\n};\n\n#endif\n"
  },
  {
    "path": "kt-kernel/operators/moe_kernel/api/common.h",
    "content": "// BOOST_STRONG_TYPEDEF(int8_t, int4_2_t);\n#pragma once\n#include <cstdint>\n\n#include \"llama.cpp/ggml.h\"\n#if !defined(CPUINFER_HAS_FLOAT16_T)\nusing float16_t = ggml_fp16_t;\n#define CPUINFER_HAS_FLOAT16_T 1\n#endif\n\n#if !defined(CPUINFER_HAS_BFLOAT16_T)\nusing bfloat16_t = ggml_bf16_t;\n#define CPUINFER_HAS_BFLOAT16_T 1\n#endif  // CPUINFER_HAS_BFLOAT16_T\nconst bool PACKED = true;\n#if defined(__aarch64__) || defined(__arm__) || defined(CPU_USE_KML)\n#ifndef CPU_USE_KML\n#define CPU_USE_KML\n#endif\n#endif  // USE_MOE_KERNEL_AMD or CPU_USE_KML\n\n#define STRONG_TYPEDEF(T, D)                                   \\\n  struct D {                                                   \\\n    T t;                                                       \\\n    explicit D(const T &v) : t(v) {}                           \\\n    D() = default;                                             \\\n    D(const D &) = default;                                    \\\n    D &operator=(const D &) = default;                         \\\n    D &operator=(const T &rhs) {                               \\\n      t = rhs;                                                 \\\n      return *this;                                            \\\n    }                                                          \\\n    operator const T &() const { return t; }                   \\\n    operator T &() { return t; }                               \\\n    bool operator==(const D &rhs) const { return t == rhs.t; } \\\n    bool operator!=(const D &rhs) const { return t != rhs.t; } \\\n    bool operator<(const D &rhs) const { return t < rhs.t; }   \\\n  };\nSTRONG_TYPEDEF(int8_t, int4_2_t)\ntypedef int8_t BLASINT8;\n\n/* matrix transpose or conjugate transpose */\ntypedef enum KERNEL_CBLAS_TRANSPOSE {\n  KernelCblasNoTrans = 111,\n  KernelCblasTrans = 112,\n  KernelCblasConjTrans = 113,\n  KernelCblasConjNoTrans = 114\n} KERNEL_CBLAS_TRANSPOSE;\n/* matrix stored in rows or cols */\ntypedef 
enum KERNEL_CBLAS_ORDER { KernelCblasRowMajor = 101, KernelCblasColMajor = 102 } KERNEL_CBLAS_ORDER;\n/* matrix position is left or right */\ntypedef enum KERNEL_CBLAS_SIDE { KernelCblasLeft = 141, KernelCblasRight = 142 } KERNEL_CBLAS_SIDE;\ntypedef KERNEL_CBLAS_ORDER KERNEL_CBLAS_LAYOUT;\ntypedef enum KERNEL_CBLAS_OFFSET {\n  KernelCblasRowOffset = 171,\n  KernelCblasColOffset = 172,\n  KernelCblasFixOffset = 173\n} KERNEL_CBLAS_OFFSET;\n\nenum class MatKernelVariant {\n  Decode,\n  Prefill,\n};"
  },
  {
    "path": "kt-kernel/operators/moe_kernel/api/mat_kernel.h",
    "content": "#pragma once\n\n#include <cstddef>\n#include <cstdint>\n#include <type_traits>\n\n#include \"common.h\"\n\nusing GemmFn = void (*)(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,\n                        const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc, const size_t m,\n                        const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,\n                        const int8_t oa, const void* b, const size_t ldb, const int8_t ob, const float beta, int32_t* c,\n                        const size_t ldc, const int32_t* oc);\n\nstruct MatKernelSelection {\n  GemmFn fn;\n  int divide_elements_size;\n};\n\nMatKernelSelection select_kernel_for_int4(MatKernelVariant variant);\nMatKernelSelection select_kernel_for_int8(MatKernelVariant variant);\n\ntemplate <typename T>\nMatKernelSelection select_mat_kernel(MatKernelVariant variant) {\n  if constexpr (std::is_same_v<typename T::dt, int4_2_t>) {\n    return select_kernel_for_int4(variant);\n  } else {\n    return select_kernel_for_int8(variant);\n  }\n}\n"
  },
  {
    "path": "kt-kernel/operators/moe_kernel/la/kernel.hpp",
    "content": "#ifndef CPUINFER_OPERATOR_KERNEL_LA_HPP\n#define CPUINFER_OPERATOR_KERNEL_LA_HPP\n\n#include <algorithm>\n#include <cassert>\n#include <cmath>\n#include <cstddef>\n#include <cstdint>\n#include <cstdio>\n#include <cstdlib>\n#include <stdexcept>\n#include <string>\n#include <vector>\n\n#include \"../api/common.h\"\n#include \"../mat_kernel/batch_gemm_api.hpp\"\n#include \"llama.cpp/ggml.h\"\nstatic const size_t MAX_Nth_B = 1024, MAX_N_B = 1024, MAX_K_B = 10240;\nnamespace moe_kernel {\ntemplate <typename T>\nT *offset_pointer(T *ptr, size_t byte_offset) {\n  return reinterpret_cast<T *>(reinterpret_cast<char *>(ptr) + byte_offset);\n}\n\ninline float bf16_to_fp32(ggml_bf16_t src) {\n  // Shift the 16 bits of the bfloat16 into the high 16 bits of a float32; the low 16 bits are zero-filled\n  uint16_t *src_16 = reinterpret_cast<uint16_t *>(&src);\n  uint32_t packed = (uint32_t)*src_16 << 16;\n\n  // Use a union to reinterpret the uint32 as a float\n  union {\n    uint32_t u;\n    float f;\n  } converter;\n\n  converter.u = packed;\n  return converter.f;\n}\n\ninline float fp16_to_fp32(ggml_fp16_t src) { return ggml_fp16_to_fp32(src); }\n\ntemplate <typename K>\nstruct BufferAImpl {\n  int8_t *a;\n  float *d;\n  int max_m, k;\n  bool if_pack = false;\n\n  static constexpr int M_STEP = K::M_STEP;\n  static constexpr int K_STEP = K::K_STEP;\n  // K_BLOCK is runtime-configurable via kernel tiling; expose as function to avoid constexpr requirements\n  static inline int K_BLOCK() { return K::K_BLOCK; }\n  static constexpr int PACK_SIZE_M = K::PACK_SIZE_M;\n  static constexpr int PACK_SIZE_K = K::PACK_SIZE_K;\n\n  static size_t required_size(int max_m, int k) { return sizeof(int8_t) * max_m * k + sizeof(float) * max_m; }\n\n  BufferAImpl(int max_m, int k, void *ptr, bool if_pack = false) : max_m(max_m), k(k), if_pack(if_pack) {\n    set_data(ptr);\n  }\n\n  BufferAImpl(int max_m, int k, bool if_pack = false) : max_m(max_m), k(k), if_pack(if_pack) {\n    if (max_m % M_STEP != 0 || k % K_STEP != 0) {\n      throw 
std::runtime_error(\"max_m and k must be multiples of M_STEP and K_STEP respectively\");\n    }\n  }\n\n  void set_data(void *ptr) {\n    a = reinterpret_cast<int8_t *>(ptr);\n    d = reinterpret_cast<float *>(a + max_m * k);\n  }\n\n  size_t required_size() const { return sizeof(int8_t) * max_m * k + sizeof(float) * max_m; }\n\n  BufferAImpl<K> offset_row(size_t row_begin, size_t row_block) {\n    auto buffera = BufferAImpl<K>(row_block, k, a + row_begin * k, if_pack);\n    buffera.d = d + row_begin;\n    return buffera;\n  }\n\n  // Quantize the input matrix A to int8_t.\n  // A is an m * k matrix stored in a, in row-major order\n  void from_mat(int m, ggml_bf16_t *src, int ith, int mth) {\n    // printf(\"in A from_mat, m = %d, ith = %d, nth = %d\\n\", m, ith, nth);\n    auto [m_start, m_end] = K::split_range_m(m, ith, mth);\n    int m_block_begin = m_start;\n    int m_block_size = m_end - m_block_begin;\n    if (m_block_size < 0) {\n      throw std::runtime_error(\"m_block_size is negative, this should not happen\");\n    }\n    for (int m_begin = 0; m_begin < m_block_size; m_begin += M_STEP) {\n      for (int i = 0; i < M_STEP && m_begin + i < m_block_size; i++) {\n        float amax = 0;\n        // TODO: accelerate with SVE later\n        for (int j = 0; j < k; j++) {\n          // First convert src to float\n          float f = bf16_to_fp32(src[(m_block_begin + m_begin + i) * k + j]);\n          f = f < 0 ? 
-f : f;\n          if (f > amax) {\n            amax = f;\n          }\n        }\n        d[m_block_begin + m_begin + i] = amax / ((1 << 7) - 1);\n        // TODO: accelerate with SVE later\n        // Quantize this row using this amax\n        for (int j = 0; j < k; j++) {\n          // First convert src to float\n          float f = bf16_to_fp32(src[(m_block_begin + m_begin + i) * k + j]);\n          if (if_pack) {\n            throw std::runtime_error(\"Packing is deprecated in this function\");\n            size_t split_m = (m_begin + i) / PACK_SIZE_M;\n            size_t m_idx = (m_begin + i) % PACK_SIZE_M;\n            size_t split_k = j / PACK_SIZE_K;\n            size_t k_idx = j % PACK_SIZE_K;\n            size_t buff_idx = m_block_begin * k + split_m * PACK_SIZE_M * k + split_k * PACK_SIZE_K * PACK_SIZE_M +\n                              m_idx * PACK_SIZE_K + k_idx;\n            a[buff_idx] = static_cast<int8_t>(std::round(f / d[m_block_begin + m_begin + i]));\n          } else {\n            // amax here is the maximum absolute value of the current row\n            a[(m_block_begin + m_begin + i) * k + j] =\n                static_cast<int8_t>(std::round(f / d[m_block_begin + m_begin + i]));\n          }\n        }\n      }\n    }\n  }\n\n  void from_mat(int m, ggml_fp16_t *src, int ith, int mth) {\n    // printf(\"in A from_mat, m = %d, ith = %d, nth = %d\\n\", m, ith, nth);\n    auto [m_start, m_end] = K::split_range_m(m, ith, mth);\n    int m_block_begin = m_start;\n    int m_block_size = m_end - m_block_begin;\n    if (m_block_size < 0) {\n      throw std::runtime_error(\"m_block_size is negative, this should not happen\");\n    }\n    for (int m_begin = 0; m_begin < m_block_size; m_begin += M_STEP) {\n      for (int i = 0; i < M_STEP && m_begin + i < m_block_size; i++) {\n        float amax = 0;\n        // TODO: accelerate with SVE later\n        for (int j = 0; j < k; j++) {\n          // First convert src to float\n          float f = fp16_to_fp32(src[(m_block_begin + m_begin + i) * k + j]);\n          f = f < 0 ? 
-f : f;\n          if (f > amax) {\n            amax = f;\n          }\n        }\n        d[m_block_begin + m_begin + i] = amax / ((1 << 7) - 1);\n        // TODO: accelerate with SVE later\n        // Quantize this row using this amax\n        for (int j = 0; j < k; j++) {\n          // First convert src to float\n          float f = fp16_to_fp32(src[(m_block_begin + m_begin + i) * k + j]);\n          if (if_pack) {\n            throw std::runtime_error(\"Packing is deprecated in this function\");\n            size_t split_m = (m_begin + i) / PACK_SIZE_M;\n            size_t m_idx = (m_begin + i) % PACK_SIZE_M;\n            size_t split_k = j / PACK_SIZE_K;\n            size_t k_idx = j % PACK_SIZE_K;\n            size_t buff_idx = m_block_begin * k + split_m * PACK_SIZE_M * k + split_k * PACK_SIZE_K * PACK_SIZE_M +\n                              m_idx * PACK_SIZE_K + k_idx;\n            a[buff_idx] = static_cast<int8_t>(std::round(f / d[m_block_begin + m_begin + i]));\n          } else {\n            // amax here is the maximum absolute value of the current row\n            a[(m_block_begin + m_begin + i) * k + j] =\n                static_cast<int8_t>(std::round(f / d[m_block_begin + m_begin + i]));\n          }\n        }\n      }\n    }\n  }\n\n  // This overload quantizes gate_output, given as fp32, to int8_t.\n  // A is an m * n (intermediate_size) matrix stored in a, in row-major order.\n  void from_mat(int m, float *src, int ith, int mth) {\n    assert(m <= max_m);\n    // assert(ith == 0 && nth == 1);\n    auto [m_start, m_end] = K::split_range_m(m, ith, mth);\n    int m_block_begin = m_start;\n    int m_block_size = m_end - m_block_begin;\n    for (int m_begin = 0; m_begin < m_block_size; m_begin += M_STEP) {\n      for (int i = 0; i < M_STEP && m_begin + i < m_block_size; i++) {\n        float amax = 0;\n        // TODO: accelerate with SVE later\n        for (int j = 0; j < k; j++) {\n          // src is already float; read it directly\n          float f = src[(m_block_begin + m_begin + i) * k + j];\n          f = f < 0 ? 
-f : f;\n          if (f > amax) {\n            amax = f;\n          }\n        }\n        d[m_block_begin + m_begin + i] = amax / ((1 << 7) - 1);\n        // TODO: accelerate with SVE later\n        // Quantize this row using this amax\n        for (int j = 0; j < k; j++) {\n          // src is already float; read it directly\n          float f = src[(m_block_begin + m_begin + i) * k + j];\n          if (if_pack) {\n            throw std::runtime_error(\"Packing is deprecated in this function\");\n            size_t split_m = (m_begin + i) / PACK_SIZE_M;\n            size_t m_idx = (m_begin + i) % PACK_SIZE_M;\n            size_t split_k = j / PACK_SIZE_K;\n            size_t k_idx = j % PACK_SIZE_K;\n            size_t buff_idx = m_block_begin * k + split_m * PACK_SIZE_M * k + split_k * PACK_SIZE_K * PACK_SIZE_M +\n                              m_idx * PACK_SIZE_K + k_idx;\n            a[buff_idx] = static_cast<int8_t>(std::round(f / d[m_block_begin + m_begin + i]));\n          } else {\n            // amax here is the maximum absolute value of the current row\n            a[(m_block_begin + m_begin + i) * k + j] =\n                static_cast<int8_t>(std::round(f / d[m_block_begin + m_begin + i]));\n          }\n        }\n      }\n    }\n  }\n\n  void from_mat(int m, float *src) {\n    for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {\n      for (int i = 0; i < M_STEP && m_begin + i < m; i++) {\n        float amax = 0;\n        // TODO: accelerate with SVE later\n        for (int j = 0; j < k; j++) {\n          // src is already float; read it directly\n          float f = src[(m_begin + i) * k + j];\n          f = f < 0 ? 
-f : f;\n          if (f > amax) {\n            amax = f;\n          }\n        }\n        d[m_begin + i] = amax / ((1 << 7) - 1);\n        // TODO: accelerate with SVE later\n        // Quantize this row using this amax\n        for (int j = 0; j < k; j++) {\n          // src is already float; read it directly\n          float f = src[(m_begin + i) * k + j];\n          // amax here is the maximum absolute value of the current row\n          a[(m_begin + i) * k + j] = static_cast<int8_t>(std::round(f / d[m_begin + i]));\n        }\n      }\n    }\n  }\n\n  // Dequantize\n  void to_mat(int m, float *dst, int ith, int mth) {\n    auto [m_start, m_end] = K::split_range_m(m, ith, mth);\n    int m_block_begin = m_start;\n    int m_block_size = m_end - m_block_begin;\n    for (int m_begin = 0; m_begin < m_block_size; m_begin += M_STEP) {\n      for (int i = 0; i < M_STEP && m_begin + i < m_block_size; i++) {\n        for (int j = 0; j < k; j++) {\n          float f = static_cast<float>(a[(m_block_begin + m_begin + i) * k + j]);\n          f *= d[m_block_begin + m_begin + i];\n          dst[(m_block_begin + m_begin + i) * k + j] = f;\n        }\n      }\n    }\n  }\n\n  float *get_scale(int m, int m_begin) { return d + m_begin; }\n};\n\ntemplate <typename K>\nstruct BufferCImpl {\n  int32_t *c;\n  int max_m, n;\n  bool if_row_major;\n\n  static constexpr int M_STEP = K::M_STEP;\n  static constexpr int N_STEP = K::N_STEP;\n  // N_BLOCK is runtime-configurable via kernel tiling; expose as function to avoid constexpr requirements\n  static inline int N_BLOCK() { return K::N_BLOCK; }\n\n  static size_t required_size(int max_m, int n) { return sizeof(int32_t) * max_m * n; }\n\n  BufferCImpl(int max_m, int n, void *ptr, bool if_row_major = false) : max_m(max_m), n(n), if_row_major(if_row_major) {\n    assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n    assert(max_m % M_STEP == 0);\n    assert(n % N_STEP == 0);\n    c = reinterpret_cast<int32_t *>(ptr);\n  }\n\n  BufferCImpl(int max_m, int n, bool if_row_major = false) : max_m(max_m), n(n), if_row_major(if_row_major) {}\n\n  
void set_data(void *ptr) {\n    assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n    c = reinterpret_cast<int32_t *>(ptr);\n  }\n  size_t required_size() const { return sizeof(int32_t) * max_m * n; }\n\n  // void to_mat(int m, float **dst, int ith, int nth) {\n  //   *dst = c + ith * N_BLOCK;\n  // }\n};\n\nstruct GemmKernelInt8 {\n  using dt = int8_t;\n  using output_t = int32_t;\n  static constexpr double ELEMENT_SIZE = 1;\n  static const int TILE_M = 16;\n  static const int TILE_K = 64;\n  static const int TILE_N = 16;\n  static const int VNNI_BLK = 4;\n\n  // static const int M_STEP = TILE_M * 2;\n  // static const int N_STEP = TILE_N * 2;\n  // static const int K_STEP = TILE_K;\n  static const int M_STEP = 1;\n  static const int N_STEP = 1;\n  static const int K_STEP = 1;\n\n  // static inline const int N_BLOCK = 1024;\n  // Make tiling params runtime-configurable (modifiable via Python bindings)\n  static inline int N_BLOCK_UP_GATE = 32;\n  static inline int N_BLOCK_DOWN = 64;\n  static inline int N_BLOCK_UP_GATE_PREFI = 32;\n  static inline int N_BLOCK_DOWN_PREFI = 64;\n  static inline int N_BLOCK = 64;\n  static inline int M_BLOCK = 320;\n  // static inline const int N_BLOCK = 32;\n  static inline int K_BLOCK = 7168;\n\n  // Setter/getter for runtime tiling configuration\n  static void set_tiling(int n_block_up_gate, int n_block_down, int n_block, int m_block, int k_block,\n                         int n_block_up_gate_prefi, int n_block_down_prefi) {\n    N_BLOCK_UP_GATE = n_block_up_gate;\n    N_BLOCK_DOWN = n_block_down;\n    N_BLOCK = n_block;\n    M_BLOCK = m_block;\n    K_BLOCK = k_block;\n    N_BLOCK_UP_GATE_PREFI = n_block_up_gate_prefi;\n    N_BLOCK_DOWN_PREFI = n_block_down_prefi;\n  }\n  static std::tuple<int, int, int, int, int, int, int> get_tiling() {\n    return std::make_tuple(N_BLOCK_UP_GATE, N_BLOCK_DOWN, N_BLOCK, M_BLOCK, K_BLOCK, N_BLOCK_UP_GATE_PREFI,\n                           N_BLOCK_DOWN_PREFI);\n  }\n\n  static inline const int 
PACK_SIZE_N = 8;\n  static inline const int PACK_SIZE_M = 8;\n  static inline const int PACK_SIZE_K = 32;\n\n  static std::string name() { return \"MOE_INT8\"; }\n  static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }\n  // type_: d for decode, p for prefill\n  static int recommended_nth_down(int n, char type_ = 'd') {\n    if (type_ == 'p') {\n      if (n % N_BLOCK_DOWN_PREFI != 0) {\n        throw std::invalid_argument(\"n must be multiple of N_BLOCK_DOWN_PREFI in prefill\");\n      }\n      return n / N_BLOCK_DOWN_PREFI;\n    } else {\n      if (n % N_BLOCK_DOWN != 0) {\n        throw std::invalid_argument(\"n must be multiple of N_BLOCK_DOWN in decode\");\n      }\n      return n / N_BLOCK_DOWN;\n    }\n  }\n\n  static int recommended_nth_up_gate(int n, char type_ = 'd') {\n    if (type_ == 'p') {\n      if (n % N_BLOCK_UP_GATE_PREFI != 0) {\n        throw std::invalid_argument(\"n must be multiple of N_BLOCK_UP_GATE_PREFI in prefill\");\n      }\n      return n / N_BLOCK_UP_GATE_PREFI;\n    } else {\n      if (n % N_BLOCK_UP_GATE != 0) {\n        throw std::invalid_argument(\"n must be multiple of N_BLOCK_UP_GATE in decode\");\n      }\n      return n / N_BLOCK_UP_GATE;\n    }\n  }\n\n  static int recommended_mth(int m) { return (m + M_BLOCK - 1) / M_BLOCK; }\n\n  static std::pair<int, int> split_range_n(int n, int ith, int nth, int block_size = N_BLOCK) {\n    int n_start = block_size * ith;\n    int n_end = std::min(n, block_size * (ith + 1));\n    return {n_start, n_end};\n  }\n\n  static std::pair<int, int> split_range_m(int m, int ith, int mth = 0) {\n    int m_start = M_BLOCK * ith;\n    int m_end = std::min(m, M_BLOCK * (ith + 1));\n    return {m_start, m_end};\n  }\n\n  static std::pair<int, int> split_range_n_block(int n, int ith, int nth, int block) {\n    int n_start = block * ith;\n    int n_end = std::min(n, block * (ith + 1));\n    return {n_start, n_end};\n  }\n\n  using BufferA = BufferAImpl<GemmKernelInt8>;\n  using 
BufferC = BufferCImpl<GemmKernelInt8>;\n\n  struct BufferB {\n    int8_t *b;\n    std::vector<int8_t *> b_pack;  // b_pack[i] -> the ith block (the ith packed matrix of B)\n    size_t reorder_B_size = 0;\n    size_t nth_B = 0;             // number of blocks of B\n    size_t block_size = N_BLOCK;  // size of each block of B; defaults to N_BLOCK so unpacked buffers never read it uninitialized\n    float *d;\n    int n, k;\n    static constexpr bool SCALE = true;\n    bool if_pack = false;\n    // n for normal, u for up_gate, d for down\n    static size_t required_size(int n, int k, bool if_pack = false, char mat_type = 'n', bool plain = true) {\n      int nth, n_block;\n      if (if_pack && !plain) {\n        switch (mat_type) {\n          case 'n':\n            nth = recommended_nth(n);\n            n_block = N_BLOCK;\n            break;\n          case 'u':\n            nth = recommended_nth_up_gate(n);\n            n_block = N_BLOCK_UP_GATE;\n            break;\n          case 'd':\n            nth = recommended_nth_down(n);\n            n_block = N_BLOCK_DOWN;\n            break;\n          default:\n            throw std::invalid_argument(\"Invalid mat_type\");\n        }\n        size_t reorder_B_size = get_reorder_B_size(KernelCblasRowMajor, KernelCblasNoTrans, k, n_block);\n        return sizeof(int8_t) * nth * reorder_B_size + sizeof(float) * n;\n      } else {\n        return sizeof(int8_t) * n * k + sizeof(float) * n;\n      }\n    }\n    BufferB(int n, int k, bool if_pack = false, char mat_type = 'n', bool plain = true) : n(n), k(k), if_pack(if_pack) {\n      int nth, n_block;\n      if (if_pack && !plain) {\n        switch (mat_type) {\n          case 'n':\n            nth = recommended_nth(n);\n            n_block = N_BLOCK;\n            break;\n          case 'u':\n            nth = recommended_nth_up_gate(n);\n            n_block = N_BLOCK_UP_GATE;\n            break;\n          case 'd':\n            nth = recommended_nth_down(n);\n            n_block = N_BLOCK_DOWN;\n            break;\n          default:\n            throw 
std::invalid_argument(\"Invalid mat_type\");\n        }\n        reorder_B_size = get_reorder_B_size(KernelCblasRowMajor, KernelCblasNoTrans, k, n_block);\n        nth_B = nth;\n        block_size = n_block;\n        b_pack.resize(nth);\n      }\n      if (n % N_STEP != 0 || k % K_STEP != 0) {\n        throw std::runtime_error(\"n and k must be multiples of N_STEP and K_STEP respectively\");\n      }\n    }\n    BufferB(int n, int k, void *ptr, bool if_pack = false, char mat_type = 'n', bool plain = true)\n        : BufferB(n, k, if_pack, mat_type, plain) {\n      set_data(ptr, plain);\n      // printf(\"mat_type:%c,nth_B:%zu,b_pack_ptr[0]:%p,d_ptr:%p,ptr:%p\\n\", mat_type, nth_B, b_pack[0], d, ptr);\n    }\n    void set_data(void *ptr, bool plain = true) {\n      if (if_pack && !plain) {\n        for (size_t i = 0; i < nth_B; i++) {\n          b_pack[i] = reinterpret_cast<int8_t *>(ptr) + i * reorder_B_size;\n        }\n        d = reinterpret_cast<float *>((int8_t *)ptr + nth_B * reorder_B_size);\n      } else {\n        b = reinterpret_cast<int8_t *>(ptr);\n        d = reinterpret_cast<float *>(b + n * k);\n      }\n    }\n    size_t required_size() const { return sizeof(int8_t) * n * k + sizeof(float) * n; }\n    BufferB offset_col(size_t col_begin, size_t col_block) {\n      auto bufferb = BufferB(col_block, k, b + col_begin * k, if_pack);\n      bufferb.d = d + col_begin;\n      return bufferb;\n    }\n    // B is a K * N matrix stored in b, in column-major order.\n    void from_mat(ggml_bf16_t *src, int ith, int nth, int n_new = -1, bool if_pack = false,\n                  bool plain = true) {  // CHECK: nth has no usage\n      if (n_new > 0) {\n        n = n_new;  // if n_new is greater than 0, use n_new\n      }\n      // Convert src to int8_t, quantizing along the k dimension (i.e. per column)\n      int8_t *b_t = nullptr;\n      if ((if_pack || this->if_pack) && !plain) {\n        b_t = (int8_t *)malloc(sizeof(int8_t) * n * k);\n      }\n      auto [n_start, n_end] = split_range_n(n, ith, nth, block_size);\n 
     int n_block_begin = n_start;\n      int n_block_size = n_end - n_block_begin;\n      for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n        for (int i = 0; i < N_STEP && n_begin + i < n_block_size; i++) {\n          float amax = 0;\n          // TODO: accelerate with SVE later\n          for (int j = 0; j < k; j++) {\n            // First convert src to float\n            float f = bf16_to_fp32(src[(n_block_begin + n_begin + i) * k + j]);\n            f = f < 0 ? -f : f;\n            if (f > amax) {\n              amax = f;\n            }\n          }\n          d[n_block_begin + n_begin + i] = amax / ((1 << 7) - 1);\n          // TODO: accelerate with SVE later\n          // Quantize this column using this amax\n          for (int j = 0; j < k; j++) {\n            // First convert src to float\n            float f = bf16_to_fp32(src[(n_block_begin + n_begin + i) * k + j]);\n            if ((if_pack || this->if_pack) && plain) {\n              size_t split_n = (n_begin + i) / PACK_SIZE_N;\n              size_t n_idx = (n_begin + i) % PACK_SIZE_N;\n              size_t split_k = j / PACK_SIZE_K;\n              size_t k_idx = j % PACK_SIZE_K;\n\n              size_t buff_idx = n_block_begin * k + split_n * PACK_SIZE_N * k + split_k * PACK_SIZE_N * PACK_SIZE_K +\n                                n_idx * PACK_SIZE_K + k_idx;\n              b[buff_idx] = static_cast<int8_t>(std::round(f / d[n_block_begin + n_begin + i]));\n            } else if ((if_pack || this->if_pack) && !plain) {\n              // amax here is the maximum absolute value of the current column\n              b_t[(n_begin + i) * k + j] = static_cast<int8_t>(std::round(f / d[n_block_begin + n_begin + i]));\n            } else {\n              b[(n_block_begin + n_begin + i) * k + j] =\n                  static_cast<int8_t>(std::round(f / d[n_block_begin + n_begin + i]));\n            }\n          }\n        }\n      }\n      if ((if_pack || this->if_pack) && !plain) {\n        // Call AMD's reorder function here\n        reorder_B_gemm(KernelCblasColMajor, KernelCblasNoTrans, k, n_block_size, k, 
b_t, b_pack[ith]);\n        free(b_t);\n      }\n    }\n\n    void from_mat(float *src, int ith, int nth, int n_new = -1, bool if_pack = false) {  // CHECK: nth has no usage\n      if (n_new > 0) {\n        n = n_new;  // if n_new is greater than 0, use n_new\n      }\n      // Convert src to int8_t, quantizing along the k dimension (i.e. per column)\n      auto [n_start, n_end] = split_range_n(n, ith, nth);\n      // printf(\"n_start = %d, n_end = %d, n = %d\\n\", n_start, n_end, n);\n      int n_block_begin = n_start;\n      int n_block_size = n_end - n_block_begin;\n      float average = 0;\n      for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n        for (int i = 0; i < N_STEP && n_begin + i < n_block_size; i++) {\n          float amax = 0;\n          // TODO: accelerate with SVE later\n          for (int j = 0; j < k; j++) {\n            // src is already float; read it directly\n            float f = src[(n_block_begin + n_begin + i) * k + j];\n            f = f < 0 ? -f : f;\n            average += f;\n            if (f > amax) {\n              amax = f;\n            }\n          }\n          average /= k;\n          d[n_block_begin + n_begin + i] = amax / ((1 << 7) - 1);\n          // printf(\"amax: %f,average: %f\\n\", amax, average);\n          // TODO: accelerate with SVE later\n          // Quantize this column using this amax\n          for (int j = 0; j < k; j++) {\n            // src is already float; read it directly\n            float f = src[(n_block_begin + n_begin + i) * k + j];\n            // amax here is the maximum absolute value of the current column\n            if (if_pack || this->if_pack) {\n              size_t split_n = (n_begin + i) / PACK_SIZE_N;\n              size_t n_idx = (n_begin + i) % PACK_SIZE_N;\n              size_t split_k = j / PACK_SIZE_K;\n              size_t k_idx = j % PACK_SIZE_K;\n\n              size_t buff_idx = n_block_begin * k + split_n * PACK_SIZE_N * k + split_k * PACK_SIZE_N * PACK_SIZE_K +\n                                n_idx * PACK_SIZE_K + k_idx;\n              b[buff_idx] = static_cast<int8_t>(std::round(f / d[n_block_begin + n_begin + i]));\n            } 
else {\n              b[(n_block_begin + n_begin + i) * k + j] =\n                  static_cast<int8_t>(std::round(f / d[n_block_begin + n_begin + i]));\n            }\n          }\n        }\n      }\n    }\n\n    void from_mat_row_major(float *src, int ld, int ith, int nth, int n_new = -1) {  // CHECK: nth has no usage\n      if (n_new > 0) {\n        n = n_new;  // if n_new is greater than 0, use n_new\n      }\n      // Convert src to int8_t, quantizing along the k dimension (i.e. per column); here src is row-major\n      auto [n_start, n_end] = split_range_n(n, ith, nth);\n      int n_block_begin = n_start;\n      int n_block_size = n_end - n_block_begin;\n      for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n        for (int i = 0; i < N_STEP && n_begin + i < n_block_size; i++) {\n          float amax = 0;\n          for (int j = 0; j < k; j++) {\n            float f = src[j * ld + (n_block_begin + n_begin + i)];\n            f = f < 0 ? -f : f;\n            if (f > amax) {\n              amax = f;\n            }\n          }\n          d[n_block_begin + n_begin + i] = amax / ((1 << 7) - 1);\n          for (int j = 0; j < k; j++) {\n            float f = src[j * ld + (n_block_begin + n_begin + i)];\n            // amax here is the maximum absolute value of the current column\n            b[(n_block_begin + n_begin + i) * k + j] =\n                static_cast<int8_t>(std::round(f / d[n_block_begin + n_begin + i]));\n          }\n        }\n      }\n    }\n\n    // Dequantize the contents to float\n    void to_mat(float *dst, int ith, int nth, int n_new = -1) {\n      if (n_new > 0) {\n        n = n_new;  // if n_new is greater than 0, use n_new\n      }\n      // Convert b to float, dequantizing along the k dimension\n      auto [n_start, n_end] = split_range_n(n, ith, nth);\n      int n_block_begin = n_start;\n      int n_block_size = n_end - n_block_begin;\n      for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n        for (int i = 0; i < N_STEP && n_begin + i < n_block_size; i++) {\n          // Dequantize this column using its scale\n          for (int j = 0; j < k; j++) {\n            // First convert 
b to float\n            int8_t b_val = b[(n_block_begin + n_begin + i) * k + j];\n            float d_val = d[n_block_begin + n_begin + i];\n            dst[(n_block_begin + n_begin + i) * k + j] = b_val * d_val;\n          }\n        }\n      }\n    }\n\n    float *get_scale(int n, int n_begin) { return d + n_begin; }\n  };\n  /* Convert buffer A to buffer B: [m,k] (row major) -> [k,n] (column major), with n = m.\n    The quantization is unchanged, so buffer B's d is simply buffer A's d. The conversion is only valid when m equals n and the k values match.\n  */\n  static void convert_buffer_a_to_buffer_b(BufferA *ba, BufferB *bb) {\n    if (bb->n != ba->max_m || bb->k != ba->k || bb->if_pack != ba->if_pack) {\n      throw std::runtime_error(\n          \"BufferA and BufferB dimensions do not match for conversion, or they are not the same pack.\");\n    }\n    bb->b = ba->a;\n    bb->d = ba->d;\n  }\n\n  static void convert_buffer_b_to_buffer_a(BufferB *bb, BufferA *ba) {\n    if (ba->max_m != bb->n || ba->k != bb->k || ba->if_pack != bb->if_pack) {\n      throw std::runtime_error(\n          \"BufferB and BufferA dimensions do not match for conversion, or they are not the same pack.\");\n    }\n    ba->a = bb->b;\n    ba->d = bb->d;\n  }\n  // Change the view of the current C\n  static void change_view(BufferC *c_src, BufferC *c_dst) {\n    if (c_src->max_m != c_dst->n || c_src->n != c_dst->max_m || c_src->if_row_major == c_dst->if_row_major) {\n      throw std::runtime_error(\"C buffer size mismatch or they are the same major\");\n    }\n    c_dst->c = c_src->c;\n  }\n  // Applies the scales of the A and B matrices to the int32 result matrix c (dequantization).\n  // c is an m * n matrix stored in c, in row-major order.\n  // A is an m * k matrix quantized per row; its scale d has m entries, one quantization coefficient per row.\n  // B is a k * n matrix quantized per column; its scale d has n entries, one quantization coefficient per column.\n  // The scale for row i, column j of C is the scale of row i of A times the scale of column j of B.\n  static void apply_scale(int m, int n, float *c, BufferA *ba, BufferB *bb, BufferC *bc) {\n    // TODO: accelerate with SVE later\n    for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {\n      for (int i = 0; i < M_STEP && m_begin + i < m; 
i++) {\n        float *scale_a = ba->get_scale(m, m_begin + i);\n        for (int n_begin = 0; n_begin < n; n_begin += N_STEP) {\n          for (int j = 0; j < N_STEP && n_begin + j < n; j++) {\n            float *scale_b = bb->get_scale(n, n_begin + j);\n            c[(m_begin + i) * n + (n_begin + j)] = (*scale_a) * (*scale_b) * bc->c[(m_begin + i) * n + (n_begin + j)];\n          }\n        }\n      }\n    }\n  }\n\n  // apply_scale with blocking along the second dimension\n  static void apply_scale(int m, int n, float *c, BufferA *ba, BufferB *bb, BufferC *bc, int ith, int nth, int block,\n                          int jth = -1) {\n    // printf(\"use split apply scale\\n\");\n    auto [n_start, n_end] = split_range_n_block(n, ith, nth, block);\n    int m_start = 0, m_end = m;\n    if (jth != -1) {\n      auto tmp = split_range_m(m, jth);\n      m_start = tmp.first;\n      m_end = tmp.second;\n    }\n    // TODO: accelerate with SVE later\n    for (int m_begin = m_start; m_begin < m_end; m_begin += M_STEP) {\n      for (int i = 0; i < M_STEP && m_begin + i < m_end; i++) {\n        float *scale_a = ba->get_scale(m, m_begin + i);\n        for (int n_begin = n_start; n_begin < n_end; n_begin += N_STEP) {\n          for (int j = 0; j < N_STEP && n_begin + j < n_end; j++) {\n            float *scale_b = bb->get_scale(n, n_begin + j);\n            c[(m_begin + i) * n + (n_begin + j)] = (*scale_a) * (*scale_b) * bc->c[(m_begin + i) * n + (n_begin + j)];\n          }\n        }\n      }\n    }\n  }\n\n  // apply_scale with blocking along both dimensions\n  // Handles both row-major and column-major C matrices\n  static void apply_scale(float *c, int ldc, BufferA *ba, BufferB *bb, BufferC *bc, int m_start, int m_end, int n_start,\n                          int n_end, bool if_row_major = true, long long c_row_idx_offset = 0,\n                          long long c_col_idx_offset = 0) {\n    if (if_row_major) {\n      for (int m_begin = m_start; m_begin < m_end; m_begin += M_STEP) {\n        for (int i = 0; i < M_STEP && m_begin + i < m_end; i++) {\n          
float *scale_a = ba->get_scale(m_end, m_begin + i);\n          for (int n_begin = n_start; n_begin < n_end; n_begin += N_STEP) {\n            for (int j = 0; j < N_STEP && n_begin + j < n_end; j++) {\n              float *scale_b = bb->get_scale(n_end, n_begin + j);\n              c[(m_begin + i + c_row_idx_offset) * ldc + (n_begin + j + c_col_idx_offset)] =\n                  (*scale_a) * (*scale_b) *\n                  bc->c[(m_begin + i + c_row_idx_offset) * ldc + (n_begin + j + c_col_idx_offset)];\n            }\n          }\n        }\n      }\n    } else {\n      for (int n_begin = n_start; n_begin < n_end; n_begin += N_STEP) {\n        for (int j = 0; j < N_STEP && n_begin + j < n_end; j++) {\n          float *scale_b = bb->get_scale(n_end, n_begin + j);\n          for (int m_begin = m_start; m_begin < m_end; m_begin += M_STEP) {\n            for (int i = 0; i < M_STEP && m_begin + i < m_end; i++) {\n              float *scale_a = ba->get_scale(m_end, m_begin + i);\n              c[(n_begin + j + c_col_idx_offset) * ldc + (m_begin + i + c_row_idx_offset)] =\n                  (*scale_a) * (*scale_b) *\n                  bc->c[(n_begin + j + c_col_idx_offset) * ldc + (m_begin + i + c_row_idx_offset)];\n            }\n          }\n        }\n      }\n    }\n  }\n\n  // apply_scale with blocking along both dimensions, reading from a raw int32_t buffer\n  // Handles both row-major and column-major C matrices\n  static void apply_scale(float *c, int ldc, BufferA *ba, BufferB *bb, int32_t *bc, int m_start, int m_end, int n_start,\n                          int n_end, bool if_row_major = true, long long c_row_idx_offset = 0,\n                          long long c_col_idx_offset = 0) {\n    if (if_row_major) {\n      for (int m_begin = m_start; m_begin < m_end; m_begin += M_STEP) {\n        for (int i = 0; i < M_STEP && m_begin + i < m_end; i++) {\n          float *scale_a = ba->get_scale(m_end, m_begin + i);\n          for (int n_begin = n_start; n_begin < n_end; n_begin += N_STEP) {\n            for (int j = 0; j < N_STEP && n_begin + j < 
n_end; j++) {\n              float *scale_b = bb->get_scale(n_end, n_begin + j);\n              c[(m_begin + i + c_row_idx_offset) * ldc + (n_begin + j + c_col_idx_offset)] =\n                  (*scale_a) * (*scale_b) *\n                  bc[(m_begin + i + c_row_idx_offset) * ldc + (n_begin + j + c_col_idx_offset)];\n            }\n          }\n        }\n      }\n    } else {\n      for (int n_begin = n_start; n_begin < n_end; n_begin += N_STEP) {\n        for (int j = 0; j < N_STEP && n_begin + j < n_end; j++) {\n          float *scale_b = bb->get_scale(n_end, n_begin + j);\n          for (int m_begin = m_start; m_begin < m_end; m_begin += M_STEP) {\n            for (int i = 0; i < M_STEP && m_begin + i < m_end; i++) {\n              float *scale_a = ba->get_scale(m_end, m_begin + i);\n              c[(n_begin + j + c_col_idx_offset) * ldc + (m_begin + i + c_row_idx_offset)] =\n                  (*scale_a) * (*scale_b) *\n                  bc[(n_begin + j + c_col_idx_offset) * ldc + (m_begin + i + c_row_idx_offset)];\n            }\n          }\n        }\n      }\n    }\n  }\n};\n\nstruct GemmKernelInt4 {\n  using dt = int4_2_t;\n  using output_t = int32_t;\n  static constexpr double ELEMENT_SIZE = 0.5;\n  static const int TILE_M = 16;\n  static const int TILE_K = 64;\n  static const int TILE_N = 16;\n  static const int VNNI_BLK = 4;\n\n  // static const int M_STEP = TILE_M * 2;\n  // static const int N_STEP = TILE_N * 2;\n  // static const int K_STEP = TILE_K;\n  static const int M_STEP = 1;\n  static const int N_STEP = 1;\n  static const int K_STEP = 1;\n\n  // static inline const int N_BLOCK = 1024;\n  // Make tiling params runtime-configurable (modifiable via Python bindings)\n  static inline int N_BLOCK_UP_GATE = 256;\n  static inline int N_BLOCK_DOWN = 1024;\n  static inline int N_BLOCK_UP_GATE_PREFI = 256;\n  static inline int N_BLOCK_DOWN_PREFI = 1024;\n  static inline int N_BLOCK = 64;\n  static inline int M_BLOCK = 320;\n  // static inline const int 
N_BLOCK = 32;\n  static inline int K_BLOCK = 7168;\n\n  // Setter/getter for runtime tiling configuration\n  static void set_tiling(int n_block_up_gate, int n_block_down, int n_block, int m_block, int k_block,\n                         int n_block_up_gate_prefi, int n_block_down_prefi) {\n    N_BLOCK_UP_GATE = n_block_up_gate;\n    N_BLOCK_DOWN = n_block_down;\n    N_BLOCK = n_block;\n    M_BLOCK = m_block;\n    K_BLOCK = k_block;\n    N_BLOCK_UP_GATE_PREFI = n_block_up_gate_prefi;\n    N_BLOCK_DOWN_PREFI = n_block_down_prefi;\n  }\n  static std::tuple<int, int, int, int, int, int, int> get_tiling() {\n    return std::make_tuple(N_BLOCK_UP_GATE, N_BLOCK_DOWN, N_BLOCK, M_BLOCK, K_BLOCK, N_BLOCK_UP_GATE_PREFI,\n                           N_BLOCK_DOWN_PREFI);\n  }\n\n  static inline const int PACK_SIZE_N = 8;\n  static inline const int PACK_SIZE_K = 32;\n  static inline const int PACK_SIZE_M = 8;\n\n  static std::string name() { return \"MOE_INT4\"; }\n  static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }\n\n  static int recommended_nth_down(int n, char type_ = 'd') {\n    if (type_ == 'p') {\n      if (n % N_BLOCK_DOWN_PREFI != 0) {\n        throw std::invalid_argument(\"n must be multiple of N_BLOCK_DOWN_PREFI in prefill\");\n      }\n      return n / N_BLOCK_DOWN_PREFI;\n    } else {\n      if (n % N_BLOCK_DOWN != 0) {\n        throw std::invalid_argument(\"n must be multiple of N_BLOCK_DOWN in decode\");\n      }\n      return n / N_BLOCK_DOWN;\n    }\n  }\n  static int recommended_mth(int m) { return (m + M_BLOCK - 1) / M_BLOCK; }\n\n  static int recommended_nth_up_gate(int n, char type_ = 'd') {\n    if (type_ == 'p') {\n      if (n % N_BLOCK_UP_GATE_PREFI != 0) {\n        throw std::invalid_argument(\"n must be multiple of N_BLOCK_UP_GATE_PREFI in prefill\");\n      }\n      return n / N_BLOCK_UP_GATE_PREFI;\n    } else {\n      if (n % N_BLOCK_UP_GATE != 0) {\n        throw std::invalid_argument(\"n must be multiple of N_BLOCK_UP_GATE in 
decode\");\n      }\n      return n / N_BLOCK_UP_GATE;\n    }\n  }\n\n  static std::pair<int, int> split_range_n(int n, int ith, int nth) {\n    int n_start = N_BLOCK * ith;\n    int n_end = std::min(n, N_BLOCK * (ith + 1));\n    return {n_start, n_end};\n  }\n  static std::pair<int, int> split_range_m(int m, int ith, int mth) {\n    int n_start = M_BLOCK * ith;\n    int n_end = std::min(m, M_BLOCK * (ith + 1));\n    return {n_start, n_end};\n  }\n\n  static std::pair<int, int> split_range_n_block(int n, int ith, int nth, int block) {\n    int n_start = block * ith;\n    int n_end = std::min(n, block * (ith + 1));\n    return {n_start, n_end};\n  }\n\n  using BufferA = BufferAImpl<GemmKernelInt4>;\n  using BufferC = BufferCImpl<GemmKernelInt4>;\n\n  struct BufferB {\n    dt *b;\n    float *d;\n    int n, k;\n    std::vector<int8_t *> b_pack;  // b_pack[i] -> the ith block (the ith packed matrix of B)\n    static constexpr bool SCALE = true;\n    bool if_pack = false;\n\n    // static size_t required_size(int n, int k) { return sizeof(int8_t) * n * k / 2 + sizeof(float) * n; }\n    static size_t required_size(int n, int k, bool if_pack = false, char mat_type = 'n', bool plain = true) {\n      int nth, n_block;\n      if (if_pack && !plain) {\n        switch (mat_type) {\n          case 'n':\n            nth = recommended_nth(n);\n            n_block = N_BLOCK;\n            break;\n          case 'u':\n            nth = recommended_nth_up_gate(n);\n            n_block = N_BLOCK_UP_GATE;\n            break;\n          case 'd':\n            nth = recommended_nth_down(n);\n            n_block = N_BLOCK_DOWN;\n            break;\n          default:\n            throw std::invalid_argument(\"Invalid mat_type\");\n        }\n        size_t reorder_B_size = get_reorder_B_size(KernelCblasRowMajor, KernelCblasNoTrans, k, n_block);\n        return sizeof(int8_t) * nth * reorder_B_size + sizeof(float) * n;\n      } else {\n        return sizeof(int8_t) * n * k / 2 + 
sizeof(float) * n;\n      }\n    }\n\n    // BufferB(int n, int k, void *ptr, bool if_pack = false) : n(n), k(k), if_pack(if_pack) {\n    //   b = reinterpret_cast<dt *>(ptr);\n    //   d = reinterpret_cast<float *>(moe_kernel::offset_pointer(b, n * k / 2));\n    // }\n    BufferB(int n, int k, bool if_pack = false, char mat_type = 'n', bool plain = true) : n(n), k(k), if_pack(if_pack) {\n      if (n % N_STEP != 0 || k % K_STEP != 0) {\n        throw std::runtime_error(\"n and k must be multiples of N_STEP and K_STEP respectively\");\n      }\n    }\n    BufferB(int n, int k, void *ptr, bool if_pack = false, char mat_type = 'n', bool plain = true)\n        : BufferB(n, k, if_pack, mat_type, plain) {\n      set_data(ptr, plain);\n    }\n    void set_data(void *ptr, bool plain = true) {\n      b = reinterpret_cast<dt *>(ptr);\n      d = reinterpret_cast<float *>(moe_kernel::offset_pointer(b, n * k / 2));\n    }\n    size_t required_size() const { return sizeof(int8_t) * n * k / 2 + sizeof(float) * n; }\n    BufferB offset_col(size_t col_begin, size_t col_block) {\n      auto bufferb = BufferB(col_block, k, moe_kernel::offset_pointer(b, (col_begin * k) / 2), if_pack);\n      bufferb.d = d + col_begin;\n      return bufferb;\n    }\n    // B is a K * N matrix stored in b in column-major order\n    void from_mat(ggml_bf16_t *src, int ith, int nth, int n_new = -1, bool if_pack = false,\n                  bool plain = true) {  // CHECK: nth has no usage\n      if (!if_pack && !this->if_pack) throw std::runtime_error(\"from mat for buffer should be packed\");\n      if (n_new > 0) {\n        n = n_new;  // use n_new if it is greater than 0\n      }\n      // Convert src to int8_t form, quantized along the k dimension (i.e. per column)\n      auto [n_start, n_end] = split_range_n(n, ith, nth);\n      int n_block_begin = n_start;\n      int n_block_size = n_end - n_block_begin;\n      for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n        for (int i = 0; i < N_STEP && n_begin + i < n_block_size; i++) {\n   
       float amax = 0;\n          // TODO: accelerate with SVE later\n          for (int j = 0; j < k; j++) {\n            // convert src to float first\n            float f = bf16_to_fp32(src[(n_block_begin + n_begin + i) * k + j]);\n            f = f < 0 ? -f : f;\n            if (f > amax) {\n              amax = f;\n            }\n          }\n          d[n_block_begin + n_begin + i] = amax / 112.0;\n          // TODO: accelerate with SVE later\n          for (int k_start = 0; k_start < k; k_start += (PACK_SIZE_K * 2)) {\n            for (int j = 0; j < PACK_SIZE_K; j++) {\n              size_t split_n = (n_begin + i) / PACK_SIZE_N;\n              size_t n_idx = (n_begin + i) % PACK_SIZE_N;\n              size_t split_k = k_start / (PACK_SIZE_K * 2);\n              size_t k_idx = j;\n\n              size_t buff_idx = n_block_begin * k / 2 + split_n * PACK_SIZE_N * k / 2 +\n                                split_k * PACK_SIZE_N * PACK_SIZE_K + n_idx * PACK_SIZE_K + k_idx;\n\n              float f0 = bf16_to_fp32(src[(n_block_begin + n_begin + i) * k + k_start + j]);\n              float f1 = bf16_to_fp32(src[(n_block_begin + n_begin + i) * k + k_start + j + PACK_SIZE_K]);\n              // static_cast<int8_t>(std::round(f / d[n_block_begin + n_begin + i]));\n              int8_t b0 = static_cast<int8_t>(std::round((f0 / (d[n_block_begin + n_begin + i] * 16.0))) * 16);\n              int8_t b1 = static_cast<int8_t>(std::round((f1 / (d[n_block_begin + n_begin + i] * 16.0))) * 16);\n              int8_t b01 = (b0 & 0xF0) | ((b1 >> 4) & 0x0F);\n              // int8_t b01 = ((b0 << 4) & 0xF0) | ((b1)&0x0F);\n\n              b[buff_idx] = b01;\n            }\n          }\n        }\n      }\n    }\n\n    void from_mat(float *src, int ith, int nth, int n_new = -1, bool if_pack = false) {  // CHECK: nth has no usage\n      if (!if_pack && !this->if_pack) throw std::runtime_error(\"from mat for buffer should be packed\");\n      if (n_new > 0) {\n        n = n_new;  // use n_new if it is greater than 0\n      }\n      
// Convert src to int8_t form, quantized along the k dimension (i.e. per column)\n      auto [n_start, n_end] = split_range_n(n, ith, nth);\n      int n_block_begin = n_start;\n      int n_block_size = n_end - n_block_begin;\n      // DEBUG: inspect the average value\n      float average = 0;\n      for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n        for (int i = 0; i < N_STEP && n_begin + i < n_block_size; i++) {\n          float amax = 0;\n          // TODO: accelerate with SVE later\n          for (int j = 0; j < k; j++) {\n            // src is already float, no conversion needed\n            float f = src[(n_block_begin + n_begin + i) * k + j];\n            f = f < 0 ? -f : f;\n            average += f;\n            if (f > amax) {\n              amax = f;\n            }\n          }\n          average /= k;\n          d[n_block_begin + n_begin + i] = amax / 112.0;\n          // printf(\"amax: %f,average: %f\\n\", amax, average);\n          // TODO: accelerate with SVE later\n          // quantize this column using amax\n          for (int k_start = 0; k_start < k; k_start += (PACK_SIZE_K * 2)) {\n            for (int j = 0; j < PACK_SIZE_K; j++) {\n              size_t split_n = (n_begin + i) / PACK_SIZE_N;\n              size_t n_idx = (n_begin + i) % PACK_SIZE_N;\n              size_t split_k = k_start / (PACK_SIZE_K * 2);\n              size_t k_idx = j;\n\n              size_t buff_idx = n_block_begin * k / 2 + split_n * PACK_SIZE_N * k / 2 +\n                                split_k * PACK_SIZE_N * PACK_SIZE_K + n_idx * PACK_SIZE_K + k_idx;\n\n              float f0 = (src[(n_block_begin + n_begin + i) * k + k_start + j]);\n              float f1 = (src[(n_block_begin + n_begin + i) * k + k_start + j + PACK_SIZE_K]);\n              // static_cast<int8_t>(std::round(f / d[n_block_begin + n_begin + i]));\n              int8_t b0 = static_cast<int8_t>(std::round((f0 / (d[n_block_begin + n_begin + i] * 16.0))) * 16);\n              int8_t b1 = static_cast<int8_t>(std::round((f1 / (d[n_block_begin + n_begin + i] * 16.0))) * 16);\n              int8_t 
b01 = (b0 & 0xF0) | ((b1 >> 4) & 0x0F);\n              // int8_t b01 = ((b0 << 4) & 0xF0) | ((b1)&0x0F);\n              // if (n_begin == 0 && i == 0 && k_start == 0 && j <= 10) {\n              //   printf(\"b0: %d, b1: %d, b01: %d,f0: %f, f1: %f, scale: %f\\n\", b0, b1, b01, f0, f1,\n              //          d[n_block_begin + n_begin + i]);\n              // }\n\n              b[buff_idx] = b01;\n            }\n          }\n        }\n      }\n      // printf(\"from_mat done, n: %d, k: %d, if_pack: %d\\n\", n, k, if_pack);\n    }\n\n    float *get_scale(int n, int n_begin) { return d + n_begin; }\n  };\n  /* Convert buffer A to buffer B: [m,k] (row major) -> [k,n] (column major), with n = m.\n    Quantization is unchanged, so buffer B's d is simply buffer A's d; m, n and k must match for the conversion.\n  */\n  static void convert_buffer_a_to_buffer_b(BufferA *ba, BufferB *bb) {\n    if (bb->n != ba->max_m || bb->k != ba->k || bb->if_pack != ba->if_pack) {\n      throw std::runtime_error(\n          \"BufferA and BufferB dimensions do not match for conversion, or they are not the same pack.\");\n    }\n    throw std::runtime_error(\"int4 not support convert\");\n    // bb->b = ba->a;\n    // bb->d = ba->d;\n  }\n\n  static void convert_buffer_b_to_buffer_a(BufferB *bb, BufferA *ba) {\n    if (ba->max_m != bb->n || ba->k != bb->k || ba->if_pack != bb->if_pack) {\n      throw std::runtime_error(\n          \"BufferB and BufferA dimensions do not match for conversion, or they are not the same pack.\");\n    }\n    throw std::runtime_error(\"int4 not support convert\");\n\n    // ba->a = bb->b;\n    // ba->d = bb->d;\n  }\n  // Change the view of the current C buffer\n  static void change_view(BufferC *c_src, BufferC *c_dst) {\n    if (c_src->max_m != c_dst->n || c_src->n != c_dst->max_m || c_src->if_row_major == c_dst->if_row_major) {\n      throw std::runtime_error(\"C buffer size mismatch or they are the same major\");\n    }\n    throw std::runtime_error(\"int4 not support convert\");\n\n    // c_dst->c = c_src->c;\n  }\n  // Applies the scales of matrices A and B to the int32 result 
matrix c (dequantization)\n  // c is an m * n matrix stored in c in row-major order\n  // A is an m * k matrix quantized per row; its scale d has m entries, one quantization factor per row\n  // B is a k * n matrix quantized per column; its scale d has n entries, one quantization factor per column\n  // The scale of C[i][j] is the scale of row i of A times the scale of column j of B\n  static void apply_scale(int m, int n, float *c, BufferA *ba, BufferB *bb, BufferC *bc) {\n    // TODO: accelerate with SVE later\n    for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {\n      for (int i = 0; i < M_STEP && m_begin + i < m; i++) {\n        float *scale_a = ba->get_scale(m, m_begin + i);\n        for (int n_begin = 0; n_begin < n; n_begin += N_STEP) {\n          for (int j = 0; j < N_STEP && n_begin + j < n; j++) {\n            float *scale_b = bb->get_scale(n, n_begin + j);\n            c[(m_begin + i) * n + (n_begin + j)] = (*scale_a) * (*scale_b) * bc->c[(m_begin + i) * n + (n_begin + j)];\n          }\n        }\n      }\n    }\n  }\n  // apply_scale with blocking along the second dimension\n  static void apply_scale(int m, int n, float *c, BufferA *ba, BufferB *bb, BufferC *bc, int ith, int nth, int block) {\n    // printf(\"use split apply scale\\n\");\n    auto [n_start, n_end] = split_range_n_block(n, ith, nth, block);\n    // TODO: accelerate with SVE later\n    for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {\n      for (int i = 0; i < M_STEP && m_begin + i < m; i++) {\n        float *scale_a = ba->get_scale(m, m_begin + i);\n        for (int n_begin = n_start; n_begin < n_end; n_begin += N_STEP) {\n          for (int j = 0; j < N_STEP && n_begin + j < n_end; j++) {\n            float *scale_b = bb->get_scale(n, n_begin + j);\n            c[(m_begin + i) * n + (n_begin + j)] = (*scale_a) * (*scale_b) * bc->c[(m_begin + i) * n + (n_begin + j)];\n          }\n        }\n      }\n    }\n  }\n  // apply_scale with blocking along both dimensions\n  // Handles C in either row-major or column-major layout\n  static void apply_scale(float *c, int ldc, BufferA *ba, BufferB *bb, BufferC *bc, int m_start, int m_end, int n_start,\n                          int n_end, bool if_row_major = true, 
long long c_row_idx_offset = 0,\n                          long long c_col_idx_offset = 0) {\n    if (if_row_major) {\n      for (int m_begin = m_start; m_begin < m_end; m_begin += M_STEP) {\n        for (int i = 0; i < M_STEP && m_begin + i < m_end; i++) {\n          float *scale_a = ba->get_scale(m_end, m_begin + i);\n          for (int n_begin = n_start; n_begin < n_end; n_begin += N_STEP) {\n            for (int j = 0; j < N_STEP && n_begin + j < n_end; j++) {\n              float *scale_b = bb->get_scale(n_end, n_begin + j);\n              c[(m_begin + i + c_row_idx_offset) * ldc + (n_begin + j + c_col_idx_offset)] =\n                  (*scale_a) * (*scale_b) *\n                  bc->c[(m_begin + i + c_row_idx_offset) * ldc + (n_begin + j + c_col_idx_offset)];\n            }\n          }\n        }\n      }\n    } else {\n      for (int n_begin = n_start; n_begin < n_end; n_begin += N_STEP) {\n        for (int j = 0; j < N_STEP && n_begin + j < n_end; j++) {\n          float *scale_b = bb->get_scale(n_end, n_begin + j);\n          for (int m_begin = m_start; m_begin < m_end; m_begin += M_STEP) {\n            for (int i = 0; i < M_STEP && m_begin + i < m_end; i++) {\n              float *scale_a = ba->get_scale(m_end, m_begin + i);\n              c[(n_begin + j + c_col_idx_offset) * ldc + (m_begin + i + c_row_idx_offset)] =\n                  (*scale_a) * (*scale_b) *\n                  bc->c[(n_begin + j + c_col_idx_offset) * ldc + (m_begin + i + c_row_idx_offset)];\n            }\n          }\n        }\n      }\n    }\n  }\n\n  // apply_scale with blocking along both dimensions\n  // Handles C in either row-major or column-major layout\n  static void apply_scale(float *c, int ldc, BufferA *ba, BufferB *bb, int32_t *bc, int m_start, int m_end, int n_start,\n                          int n_end, bool if_row_major = true, long long c_row_idx_offset = 0,\n                          long long c_col_idx_offset = 0) {\n    if (if_row_major) {\n      for (int m_begin = m_start; m_begin < m_end; m_begin += 
M_STEP) {\n        for (int i = 0; i < M_STEP && m_begin + i < m_end; i++) {\n          float *scale_a = ba->get_scale(m_end, m_begin + i);\n          for (int n_begin = n_start; n_begin < n_end; n_begin += N_STEP) {\n            for (int j = 0; j < N_STEP && n_begin + j < n_end; j++) {\n              float *scale_b = bb->get_scale(n_end, n_begin + j);\n              c[(m_begin + i + c_row_idx_offset) * ldc + (n_begin + j + c_col_idx_offset)] =\n                  (*scale_a) * (*scale_b) *\n                  bc[(m_begin + i + c_row_idx_offset) * ldc + (n_begin + j + c_col_idx_offset)];\n            }\n          }\n        }\n      }\n    } else {\n      for (int n_begin = n_start; n_begin < n_end; n_begin += N_STEP) {\n        for (int j = 0; j < N_STEP && n_begin + j < n_end; j++) {\n          float *scale_b = bb->get_scale(n_end, n_begin + j);\n          for (int m_begin = m_start; m_begin < m_end; m_begin += M_STEP) {\n            for (int i = 0; i < M_STEP && m_begin + i < m_end; i++) {\n              float *scale_a = ba->get_scale(m_end, m_begin + i);\n              c[(n_begin + j + c_col_idx_offset) * ldc + (m_begin + i + c_row_idx_offset)] =\n                  (*scale_a) * (*scale_b) *\n                  bc[(n_begin + j + c_col_idx_offset) * ldc + (m_begin + i + c_row_idx_offset)];\n            }\n          }\n        }\n      }\n    }\n  }\n};\n\n}  // namespace moe_kernel\n\n#endif"
  },
  {
    "path": "kt-kernel/operators/moe_kernel/la/mat_kernel.cpp",
    "content": "#include \"../api/mat_kernel.h\"\n\n#include <cassert>\n\nnamespace {\nconstexpr int kInt4ElementDivisor = 2;\nconstexpr int kInt8ElementDivisor = 1;\n}  // namespace\nextern \"C\" {\nvoid decode_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,\n                               const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc, const size_t m,\n                               const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,\n                               const int8_t oa, const void* b, const size_t ldb, const int8_t ob, const float beta,\n                               int32_t* c, const size_t ldc, const int32_t* oc);\n\nvoid prefill_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,\n                                const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc, const size_t m,\n                                const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,\n                                const int8_t oa, const void* b, const size_t ldb, const int8_t ob, const float beta,\n                                int32_t* c, const size_t ldc, const int32_t* oc);\n\nvoid decode_int4_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,\n                                    const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc,\n                                    const size_t m, const size_t n, const size_t k, const float alpha, const void* a,\n                                    const size_t lda, const int8_t oa, const void* b, const size_t ldb, const int8_t ob,\n                                    const float beta, int32_t* c, const size_t ldc, const int32_t* oc);\n\nvoid prefill_int4_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,\n                                     
const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc,\n                                     const size_t m, const size_t n, const size_t k, const float alpha, const void* a,\n                                     const size_t lda, const int8_t oa, const void* b, const size_t ldb,\n                                     const int8_t ob, const float beta, int32_t* c, const size_t ldc,\n                                     const int32_t* oc);\n}\n\nMatKernelSelection select_kernel_for_int4(MatKernelVariant variant) {\n  switch (variant) {\n    case MatKernelVariant::Decode:\n      return {decode_int4_cblas_gemm_s8s8s32, kInt4ElementDivisor};\n    case MatKernelVariant::Prefill:\n      return {prefill_int4_cblas_gemm_s8s8s32, kInt4ElementDivisor};\n  }\n  return {nullptr, 0};\n}\n\nMatKernelSelection select_kernel_for_int8(MatKernelVariant variant) {\n  switch (variant) {\n    case MatKernelVariant::Decode:\n      return {decode_cblas_gemm_s8s8s32, kInt8ElementDivisor};\n    case MatKernelVariant::Prefill:\n      return {prefill_cblas_gemm_s8s8s32, kInt8ElementDivisor};\n  }\n  return {nullptr, 0};\n}"
  },
  {
    "path": "kt-kernel/operators/moe_kernel/la/utils.hpp",
    "content": "#pragma once\n// #include <arm_sve.h>\n#include <cstdint>\n#include <cstring>\n\n// 简单截断模式：直接丢弃低 16 位\nstatic inline uint16_t float_to_bf16_trunc(float f) {\n  uint32_t u;\n  // 按位拷贝，避免 strict‑aliasing UB\n  memcpy(&u, &f, sizeof(u));   // :contentReference[oaicite:3]{index=3}\n  return (uint16_t)(u >> 16);  // 截断得到高 16 位 :contentReference[oaicite:4]{index=4}\n}\n\nstatic inline void convert_32fp32_to_32bf16_pure_c(const float* src, uint16_t* dst) {\n  // src 已偏移至 token_nth * hidden_size\n  for (int e = 0; e < 32; e++) {  // 共 32 个元素\n    // 选择截断或四舍五入\n    dst[e] = float_to_bf16_trunc(src[e]);\n  }\n}\n\n// 把 32 个 bf16 元素转换成 32 个 fp32 元素\n\nstatic inline void convert_32bf16_to_32fp32_pure_c(const uint16_t* src, float* dst) {\n  for (int e = 0; e < 32; e++) {\n    uint32_t temp = ((uint32_t)src[e]) << 16;  // 将 BF16 左移 16 位\n    memcpy(&dst[e], &temp, sizeof(float));     // 将结果复制到 FP32 变量中\n  }\n}"
  },
  {
    "path": "kt-kernel/operators/moe_kernel/mat_kernel/aocl_kernel/kernel.cpp",
    "content": "#include <stdexcept>\n\n#include \"../batch_gemm_api.hpp\"\n#include \"blis.h\"\n\nnamespace {\n\nchar ToAoclOrder(KERNEL_CBLAS_LAYOUT layout) {\n  switch (layout) {\n    case KernelCblasRowMajor:\n      return 'r';\n    case KernelCblasColMajor:\n      return 'c';\n  }\n  throw std::invalid_argument(\"Unsupported KERNEL_CBLAS_LAYOUT value\");\n}\n\nchar ToAoclTranspose(KERNEL_CBLAS_TRANSPOSE transpose) {\n  switch (transpose) {\n    case KernelCblasNoTrans:\n      return 'n';\n    case KernelCblasTrans:\n      return 't';\n    case KernelCblasConjTrans:\n    case KernelCblasConjNoTrans:\n      break;\n  }\n  throw std::invalid_argument(\"Unsupported KERNEL_CBLAS_TRANSPOSE value\");\n}\n\n}  // namespace\n\n// 映射表，layout 从KERNEL_CBLAS_ORDER 映射到'r'或者'c',以及将KERNEL_CBLAS_TRANSPOSE映射到'n'或者't'\n#ifdef __cplusplus\nextern \"C\" {\n#endif\nvoid decode_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,\n                               const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc, const size_t m,\n                               const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,\n                               const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob, const float beta,\n                               int32_t* c, const size_t ldc, const int32_t* oc) {\n  const char order = ToAoclOrder(layout);\n  const char op_a = ToAoclTranspose(transa);\n  const char op_b = ToAoclTranspose(transb);\n  (void)offsetc;\n  aocl_gemm_s8s8s32os32(order, op_a, op_b, static_cast<dim_t>(m), static_cast<dim_t>(n), static_cast<dim_t>(k),\n                        static_cast<int32_t>(alpha), static_cast<const int8_t*>(a), static_cast<dim_t>(lda), 'n',\n                        static_cast<const int8_t*>(b), static_cast<dim_t>(ldb), 'r', static_cast<int32_t>(beta), c,\n                        static_cast<dim_t>(ldc), nullptr);\n}\n\nvoid 
prefill_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,\n                                const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc, const size_t m,\n                                const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,\n                                const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob, const float beta,\n                                int32_t* c, const size_t ldc, const int32_t* oc) {\n  const char order = ToAoclOrder(layout);\n  const char op_a = ToAoclTranspose(transa);\n  const char op_b = ToAoclTranspose(transb);\n  (void)offsetc;\n  aocl_gemm_s8s8s32os32(order, op_a, op_b, static_cast<dim_t>(m), static_cast<dim_t>(n), static_cast<dim_t>(k),\n                        static_cast<int32_t>(alpha), static_cast<const int8_t*>(a), static_cast<dim_t>(lda), 'n',\n                        static_cast<const int8_t*>(b), static_cast<dim_t>(ldb), 'r', static_cast<int32_t>(beta), c,\n                        static_cast<dim_t>(ldc), nullptr);\n}\n\nvoid prefill_int4_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,\n                                     const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc,\n                                     const size_t m, const size_t n, const size_t k, const float alpha, const void* a,\n                                     const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,\n                                     const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc,\n                                     const int32_t* oc) {\n  throw std::runtime_error(\"int4 not support prefill\");\n}\n\nvoid decode_int4_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,\n                                    const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc,\n  
                                  const size_t m, const size_t n, const size_t k, const float alpha, const void* a,\n                                    const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,\n                                    const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc,\n                                    const int32_t* oc) {\n  throw std::runtime_error(\"int4 not support decode\");\n}\n\nvoid reorder_B_gemm(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transb, const size_t k,\n                    const size_t n, const size_t ldb, const void* b, void* b_reordered) {\n  const char order = ToAoclOrder(layout);\n  const char op_b = ToAoclTranspose(transb);\n  aocl_reorder_s8s8s32os32(order, op_b, 'B', static_cast<const int8_t*>(b), static_cast<int8_t*>(b_reordered), k, n,\n                           ldb);\n}\n\nsize_t get_reorder_B_size(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transb, const size_t k,\n                          const size_t n) {\n  return aocl_get_reorder_buf_size_s8s8s32os32(ToAoclOrder(layout), ToAoclTranspose(transb), 'B', k, n);\n}\n\n#ifdef __cplusplus\n}\n#endif"
  },
  {
    "path": "kt-kernel/operators/moe_kernel/mat_kernel/batch_gemm_api.hpp",
    "content": "#pragma once\n#include <cstddef>\n#ifndef _BATCH_GEMM_KERNEL_API_\n#define _BATCH_GEMM_KERNEL_API_\n#include \"../api/common.h\"\n#ifdef __cplusplus\nextern \"C\" {\n#endif\nvoid decode_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,\n                               const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc, const size_t m,\n                               const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,\n                               const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob, const float beta,\n                               int32_t* c, const size_t ldc, const int32_t* oc);\n\nvoid prefill_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,\n                                const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc, const size_t m,\n                                const size_t n, const size_t k, const float alpha, const void* a, const size_t lda,\n                                const BLASINT8 oa, const void* b, const size_t ldb, const BLASINT8 ob, const float beta,\n                                int32_t* c, const size_t ldc, const int32_t* oc);\n\nvoid decode_int4_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,\n                                    const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc,\n                                    const size_t m, const size_t n, const size_t k, const float alpha, const void* a,\n                                    const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,\n                                    const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc,\n                                    const int32_t* oc);\n\nvoid prefill_int4_cblas_gemm_s8s8s32(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transa,\n  
                                   const KERNEL_CBLAS_TRANSPOSE transb, const KERNEL_CBLAS_OFFSET offsetc,\n                                     const size_t m, const size_t n, const size_t k, const float alpha, const void* a,\n                                     const size_t lda, const BLASINT8 oa, const void* b, const size_t ldb,\n                                     const BLASINT8 ob, const float beta, int32_t* c, const size_t ldc,\n                                     const int32_t* oc);\nvoid reorder_B_gemm(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transb, const size_t k,\n                    const size_t n, const size_t ldb, const void* b, void* b_reordered);\nsize_t get_reorder_B_size(const KERNEL_CBLAS_LAYOUT layout, const KERNEL_CBLAS_TRANSPOSE transb, const size_t k,\n                          const size_t n);\n\n#ifdef __cplusplus\n}\n#endif\n#endif /*** _BATCH_GEMM_KERNEL_API_ ***/"
  },
  {
    "path": "kt-kernel/operators/moe_kernel/moe.hpp",
    "content": "#ifndef MOE_KERNEL_HPP\n#define MOE_KERNEL_HPP\n\n#include <algorithm>\n#include <cmath>\n#include <cstddef>\n#include <cstdint>\n#include <cstdio>\n#include <cstdlib>\n#include <filesystem>\n#include <fstream>\n#include <iostream>\n#include <vector>\n\n#include \"../../cpu_backend/shared_mem_buffer.h\"\n#include \"../common.hpp\"\n#include \"../moe-tp.hpp\"\n#include \"api/common.h\"\n#include \"api/mat_kernel.h\"\n#include \"llama.cpp/ggml.h\"\ntemplate <class T, bool PLAIN = true>\nclass MOE_KERNEL_TP\n#ifdef FORWARD_TIME_PROFILE\n    : protected TimePerf\n#endif\n{\n private:\n  int tp_part_idx;\n  std::filesystem::path prefix;\n\n  void* gate_proj_;  // [expert_num * intermediate_size * hidden_size ( /32 if\n                     // quantized)]\n  void* up_proj_;    // [expert_num * intermediate_size * hidden_size ( /32 if\n                     // quantized)]\n  void* down_proj_;  // [expert_num * hidden_size * intermediate_size ( /32 if\n                     // quantized)]\n\n  ggml_bf16_t* m_local_input_;  // [routed_expert_num * max_len * hidden_size]\n  float* m_local_gate_output_;  // [routed_expert_num * max_len * intermediate_size]\n  float* m_local_up_output_;    // [routed_expert_num * max_len * intermediate_size]\n  float* m_local_down_output_;  // [routed_expert_num * max_len * hidden_size]\n\n  std::vector<std::vector<int>> m_local_pos_;    // [max_len, routed_expert_num]\n  std::vector<int> m_local_num_;                 // [expert_num]\n  std::vector<int> m_expert_id_map_;             // [expert_num]\n  std::vector<ggml_bf16_t*> m_local_input_ptr_;  // [expert_num]\n  std::vector<float*> m_local_gate_output_ptr_;  // [expert_num]\n  std::vector<float*> m_local_up_output_ptr_;    // [expert_num]\n  std::vector<float*> m_local_down_output_ptr_;  // [expert_num]\n\n  std::vector<std::shared_ptr<typename T::BufferA>> gate_up_ba_;\n  std::vector<std::shared_ptr<typename T::BufferB>> gate_bb_;\n  std::vector<std::shared_ptr<typename 
T::BufferC>> gate_bc_;\n  std::vector<std::shared_ptr<typename T::BufferB>> up_bb_;\n  std::vector<std::shared_ptr<typename T::BufferC>> up_bc_;\n  std::vector<std::shared_ptr<typename T::BufferA>> down_ba_;\n  std::vector<std::shared_ptr<typename T::BufferB>> down_bb_;\n  std::vector<std::shared_ptr<typename T::BufferC>> down_bc_;\n\n  std::vector<void*> gate_up_owner_ptr_;\n  std::vector<void*> down_owner_ptr_;\n\n  inline void write_weights(std::filesystem::path prefix, std::string mat_class, char* bb, int expert_idx, size_t size,\n                            size_t scale_size) {\n    // printf(\"expert %d, size %ld, scale size %ld\\n\", expert_idx, size, scale_size);\n    // std::ofstream of(prefix / (T::name() + mat_class + std::to_string(expert_idx)  + \"_quant_\" + \".kt\"));\n    std::ofstream of(prefix / (T::name() + mat_class + std::to_string(expert_idx) + \"_\" +\n                               std::to_string(size - scale_size) + \"Byte\" + \"_quant_\" + \".kt\"));\n    if (of.is_open() == false) {\n      printf(\"no such file: %s\", (prefix / (T::name() + mat_class + std::to_string(expert_idx) + \"_\" +\n                                            std::to_string(size - scale_size) + \"Byte\" + \"_quant_\" + \".kt\"))\n                                     .c_str());\n      // throw std::runtime_error(\"No such file\");\n    }\n    of.write((char*)bb, size - scale_size);\n    of.close();\n    // of.open(prefix / (T::name() + mat_class + std::to_string(expert_idx) + \"_scale_\" + \".kt\"));\n    of.open(prefix / (T::name() + mat_class + std::to_string(expert_idx) + \"_\" + std::to_string(scale_size) + \"Byte\" +\n                      \"_scale_\" + \".kt\"));\n    if (of.is_open() == false) {\n      printf(\"no such file\\n\");\n      // throw std::runtime_error(\"No such file\");\n    }\n    of.write(((char*)bb) + size - scale_size, scale_size);\n  }\n\n  inline void read_weights(std::filesystem::path prefix, std::string mat_class, char* bb, int 
expert_idx, size_t size,\n                           size_t scale_size, uint8_t mat_split, uint8_t mat_split_idex) {\n    // std::ifstream f(prefix / (T::name() + mat_class + std::to_string(expert_idx)  + \"_quant_\" + \".kt\"));\n    std::ifstream f(prefix / (T::name() + mat_class + std::to_string(expert_idx) + \"_\" +\n                              std::to_string(size - scale_size) + \"Byte\" + \"_quant_\" + \".kt\"));\n    if (f.is_open() == false) {\n      printf(\"no such file: %s\\n\", (prefix / (T::name() + mat_class + std::to_string(expert_idx) + \"_\" +\n                                              std::to_string(size - scale_size) + \"Byte\" + \"_quant_\" + \".kt\"))\n                                       .c_str());\n      // throw std::runtime_error(\"No such file\");\n    }\n    f.seekg(mat_split_idex * (size - scale_size) / mat_split);\n    f.read(((char*)bb) + mat_split_idex * (size - scale_size) / mat_split, (size - scale_size) / mat_split);\n    f.close();\n    // f.open(prefix / (T::name() + mat_class + std::to_string(expert_idx) + \"_scale_\" + \".kt\"));\n    f.open(prefix / (T::name() + mat_class + std::to_string(expert_idx) + \"_\" + std::to_string(scale_size) + \"Byte\" +\n                     \"_scale_\" + \".kt\"));\n    if (f.is_open() == false) {\n      printf(\"no such file: %s\\n\", (prefix / (T::name() + mat_class + std::to_string(expert_idx) + \"_\" +\n                                              std::to_string(scale_size) + \"Byte\" + \"_scale_\" + \".kt\"))\n                                       .c_str());\n      // throw std::runtime_error(\"No such file\");\n    }\n    f.seekg(mat_split_idex * scale_size / mat_split);\n    f.read((((char*)bb) + size - scale_size) + mat_split_idex * scale_size / mat_split, scale_size / mat_split);\n  }\n\n public:\n  using input_t = ggml_bf16_t;\n  using output_t = float;\n\n  GeneralMOEConfig config_;\n  static constexpr double ELEMENT_SIZE = T::ELEMENT_SIZE;\n\n  MOE_KERNEL_TP(GeneralMOEConfig 
config, int tp_part_idx) {\n    printf(\"  Creating AMD_MOE_TP %d at numa %d\\n\", tp_part_idx, numa_node_of_cpu(sched_getcpu()));\n    auto& load = config.load;\n    auto& save = config.save;\n    if (load && config.path == \"\") {\n      load = false;\n    }\n\n    prefix = config.path;\n    prefix = prefix / (\"_layer_\" + std::to_string(config.layer_idx)) / (\"_numa_\" + std::to_string(tp_part_idx));\n    if (save) {\n      std::cout << \"Creating \" << prefix << std::endl;\n      std::filesystem::create_directories(prefix);\n    }\n    if (load) {\n      if (std::filesystem::exists(prefix)) {\n        std::cout << \"Loading from \" << prefix << std::endl;\n      } else {\n        throw std::runtime_error(\"Path not found: \" + prefix.string());\n      }\n    }\n\n    this->tp_part_idx = tp_part_idx;\n    config_ = config;\n    gate_proj_ = config_.gate_proj;\n    up_proj_ = config_.up_proj;\n    down_proj_ = config_.down_proj;\n\n    MemoryRequest mem_requests;\n    mem_requests.append_pointer(&m_local_input_,\n                                sizeof(input_t) * config_.num_experts_per_tok * config_.max_len * config_.hidden_size);\n\n    mem_requests.append_pointer(&m_local_gate_output_, sizeof(float) * config_.num_experts_per_tok * config_.max_len *\n                                                           config_.intermediate_size);\n    mem_requests.append_pointer(\n        &m_local_up_output_, sizeof(float) * config_.num_experts_per_tok * config_.max_len * config_.intermediate_size);\n    mem_requests.append_pointer(&m_local_down_output_,\n                                sizeof(float) * config_.num_experts_per_tok * config_.max_len * config_.hidden_size);\n\n    m_local_pos_.resize(config_.max_len);\n    for (int i = 0; i < config_.max_len; i++) {\n      m_local_pos_[i].resize(config_.num_experts_per_tok);\n    }\n    m_expert_id_map_.resize(config_.expert_num);\n    m_local_num_.resize(config_.expert_num);\n    
m_local_input_ptr_.resize(config_.expert_num);\n    m_local_gate_output_ptr_.resize(config_.expert_num);\n    m_local_up_output_ptr_.resize(config_.expert_num);\n    m_local_down_output_ptr_.resize(config_.expert_num);\n\n    // printf(\"tp part %d alloc layer %d, %f GB, on numa %d\\n\", tp_part_idx, config_.layer_idx,\n    //        1e-9 * config_.expert_num *\n    //            (T::BufferB::required_size(config_.intermediate_size, config_.hidden_size) * 2 +\n    //             T::BufferB::required_size(config_.hidden_size, config_.intermediate_size)),\n    //        numa_node_of_cpu(sched_getcpu()));\n    // Allocate one large contiguous block of memory for the weights:\n    size_t gate_up_exp_size =\n        T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, PACKED, 'u', PLAIN) +\n        T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, PACKED, 'u', PLAIN);\n\n    for (uint64_t i = 0; i < config_.expert_num; i++) {\n      gate_up_ba_.push_back(std::make_shared<typename T::BufferA>(config_.max_len, config_.hidden_size, nullptr));\n      gate_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, nullptr));\n      up_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, nullptr));\n      down_ba_.push_back(std::make_shared<typename T::BufferA>(config_.max_len, config_.intermediate_size, nullptr));\n      down_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.hidden_size, nullptr));\n      void* gate_up_down_per_exp_ptr = std::aligned_alloc(64, gate_up_exp_size);\n      gate_up_owner_ptr_.push_back(gate_up_down_per_exp_ptr);\n\n      gate_bb_.push_back(std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size,\n                                                               gate_up_down_per_exp_ptr, PACKED, 'u', PLAIN));\n      up_bb_.push_back(std::make_shared<typename T::BufferB>(\n          config_.intermediate_size, 
config_.hidden_size,\n          offset_pointer(gate_up_down_per_exp_ptr,\n                         T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, PACKED, 'u', PLAIN)),\n          PACKED, 'u', PLAIN));\n\n      void* down_bb_ptr = std::aligned_alloc(\n          64, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, PACKED, 'd', PLAIN));\n      down_owner_ptr_.push_back(down_bb_ptr);\n      down_bb_.push_back(std::make_shared<typename T::BufferB>(config_.hidden_size, config_.intermediate_size,\n                                                               down_bb_ptr, PACKED, 'd', PLAIN));\n    }\n\n    for (int i = 0; i < config_.expert_num; i++) {\n      mem_requests.append_function([this, i](void* new_ptr) { gate_up_ba_[i]->set_data(new_ptr); },\n                                   T::BufferA::required_size(config_.max_len, config_.hidden_size));\n      mem_requests.append_function([this, i](void* new_ptr) { gate_bc_[i]->set_data(new_ptr); },\n                                   T::BufferC::required_size(config_.max_len, config_.intermediate_size));\n      mem_requests.append_function([this, i](void* new_ptr) { up_bc_[i]->set_data(new_ptr); },\n                                   T::BufferC::required_size(config_.max_len, config_.intermediate_size));\n      mem_requests.append_function([this, i](void* new_ptr) { down_ba_[i]->set_data(new_ptr); },\n                                   T::BufferA::required_size(config_.max_len, config_.intermediate_size));\n      mem_requests.append_function([this, i](void* new_ptr) { down_bc_[i]->set_data(new_ptr); },\n                                   T::BufferC::required_size(config_.max_len, config_.hidden_size));\n    }\n\n    shared_mem_buffer_numa.alloc(tp_part_idx, this, mem_requests);\n  }\n\n  MOE_KERNEL_TP(const MOE_KERNEL_TP&) = delete;\n  MOE_KERNEL_TP& operator=(const MOE_KERNEL_TP&) = delete;\n  MOE_KERNEL_TP(MOE_KERNEL_TP&&) = delete;\n  MOE_KERNEL_TP& 
operator=(MOE_KERNEL_TP&&) = delete;\n\n  ~MOE_KERNEL_TP() {\n    // printf(\"  Destroying KML_MOE_TP %lx\\n\", (intptr_t)(this));\n    for (void* ptr : gate_up_owner_ptr_) {\n      std::free(ptr);\n    }\n    for (void* ptr : down_owner_ptr_) {\n      std::free(ptr);\n    }\n  }\n\n  void load_weights() {\n    auto pool = config_.pool->get_subpool(tp_part_idx);\n    const uint64_t* physical_to_logical_map = (const uint64_t*)config_.physical_to_logical_map;\n    if (config_.gate_projs.size()) {\n      printf(\"load from safetensor\\n\");\n      pool->do_work_stealing_job(\n          config_.expert_num, nullptr,\n          [this, physical_to_logical_map](int expert_id) {\n            uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_id);\n            {\n              size_t scale_size = config_.intermediate_size * sizeof(float);\n              size_t whole_size_ =\n                  T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, PACKED, 'u', PLAIN);\n              size_t size = whole_size_ - scale_size;\n              void* dst_ = PLAIN ? gate_bb_[expert_id]->b : gate_bb_[expert_id]->b_pack[0];\n\n              memcpy(dst_, config_.gate_projs[tp_part_idx][logical_expert_id], size);\n\n              if constexpr (T::BufferB::SCALE) {\n                memcpy(gate_bb_[expert_id]->d, config_.gate_scales[tp_part_idx][logical_expert_id], scale_size);\n              }\n\n              whole_size_ =\n                  T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, PACKED, 'u', PLAIN);\n              size = whole_size_ - scale_size;\n              dst_ = PLAIN ? 
up_bb_[expert_id]->b : up_bb_[expert_id]->b_pack[0];\n              memcpy(dst_, config_.up_projs[tp_part_idx][logical_expert_id], size);\n\n              if constexpr (T::BufferB::SCALE) {\n                memcpy(up_bb_[expert_id]->d, config_.up_scales[tp_part_idx][logical_expert_id], scale_size);\n              }\n            }\n\n            {\n              size_t scale_size = config_.hidden_size * sizeof(float);\n              size_t whole_size_ =\n                  T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, PACKED, 'd', PLAIN);\n              size_t size = whole_size_ - scale_size;\n              void* dst_ = PLAIN ? down_bb_[expert_id]->b : down_bb_[expert_id]->b_pack[0];\n              memcpy(dst_, config_.down_projs[tp_part_idx][logical_expert_id], size);\n\n              if constexpr (T::BufferB::SCALE) {\n                memcpy(down_bb_[expert_id]->d, config_.down_scales[tp_part_idx][logical_expert_id], scale_size);\n              }\n            }\n          },\n          nullptr);\n\n    } else {\n      static uint8_t mat_type_all = 3, mat_split = 1;\n      if (config_.load) {\n        std::cout << \"Loading from \" << prefix << std::endl;\n        for (int task_id = 0; task_id < config_.expert_num * mat_type_all * mat_split; task_id++) {\n          int64_t expert_idx = task_id / (mat_type_all * mat_split);\n          uint8_t mat_class = (task_id % (mat_type_all * mat_split)) / mat_split;\n          uint8_t mat_split_idex = task_id % mat_split;\n          uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);\n          void* src_;\n          if (mat_class == 0) {  // the up matrix\n            src_ = PLAIN ? 
up_bb_[expert_idx]->b : up_bb_[expert_idx]->b_pack[0];\n            size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, PACKED, 'u', PLAIN);\n            size_t scale_size = config_.intermediate_size * sizeof(float);\n            read_weights(prefix, \"_up_\", (char*)src_, logical_expert_id, size, scale_size, mat_split, mat_split_idex);\n          } else if (mat_class == 1) {\n            void* src_ = PLAIN ? gate_bb_[expert_idx]->b : gate_bb_[expert_idx]->b_pack[0];\n            size_t size = T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, PACKED, 'u', PLAIN);\n            size_t scale_size = config_.intermediate_size * sizeof(float);\n            read_weights(prefix, \"_gate_\", (char*)src_, logical_expert_id, size, scale_size, mat_split, mat_split_idex);\n          } else {\n            void* src_ = PLAIN ? down_bb_[expert_idx]->b : down_bb_[expert_idx]->b_pack[0];\n            size_t size = T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, PACKED, 'd', PLAIN);\n            size_t scale_size = config_.hidden_size * sizeof(float);\n            read_weights(prefix, \"_down_\", (char*)src_, logical_expert_id, size, scale_size, mat_split, mat_split_idex);\n          }\n        }\n      }\n// check process, store down matrix to check\n#ifdef CHECK\n      load_check();\n#endif\n#ifndef CHECK\n      else\n#endif\n      {\n        if (tp_part_idx == 0) {\n          std::cout << \"  online quant from bf16\" << std::endl;\n        }\n        int nth = T::recommended_nth_up_gate(config_.intermediate_size);\n        pool->do_work_stealing_job(\n            nth * config_.expert_num, nullptr,\n            [this, nth, physical_to_logical_map](int task_id) {\n              int64_t expert_idx = task_id / nth;\n              uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);\n              int ith = task_id % nth;\n              // gate part\n              
gate_bb_[logical_expert_id]->from_mat(\n                  (ggml_bf16_t*)config_.gate_proj + logical_expert_id * config_.intermediate_size * config_.hidden_size,\n                  ith, nth, -1, PACKED, PLAIN);\n              // up part\n              up_bb_[logical_expert_id]->from_mat(\n                  (ggml_bf16_t*)config_.up_proj + logical_expert_id * config_.intermediate_size * config_.hidden_size,\n                  ith, nth, -1, PACKED, PLAIN);\n            },\n            nullptr);\n\n        nth = T::recommended_nth_down(config_.hidden_size);\n        pool->do_work_stealing_job(\n            nth * config_.expert_num, nullptr,\n            [this, nth, physical_to_logical_map](int task_id) {\n              int64_t expert_idx = task_id / nth;\n              int ith = task_id % nth;\n              uint64_t logical_expert_id = expert_map(physical_to_logical_map, expert_idx);\n              // down part\n              down_bb_[logical_expert_id]->from_mat(\n                  (ggml_bf16_t*)config_.down_proj + logical_expert_id * config_.hidden_size * config_.intermediate_size,\n                  ith, nth, -1, PACKED, PLAIN);\n            },\n            nullptr);\n      }\n#ifdef CHECK\n      verify_load_right();\n#endif\n      // save process\n      if (config_.save) {\n        pool->do_work_stealing_job(\n            config_.expert_num * mat_type_all, nullptr,\n            [this, physical_to_logical_map](int task_id) {\n              int64_t expert_idx = task_id / mat_type_all;\n              expert_idx = expert_map(physical_to_logical_map, expert_idx);\n              uint8_t mat_class = task_id % mat_type_all;\n              if (mat_class == 0) {  // the up matrix\n                size_t size =\n                    T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, PACKED, 'u', PLAIN);\n                size_t scale_size = config_.intermediate_size * sizeof(float);\n                write_weights(prefix, \"_up_\", 
(char*)up_bb_[expert_idx]->b_pack[0], expert_idx, size, scale_size);\n              } else if (mat_class == 1) {\n                size_t size =\n                    T::BufferB::required_size(config_.intermediate_size, config_.hidden_size, PACKED, 'u', PLAIN);\n                size_t scale_size = config_.intermediate_size * sizeof(float);\n                write_weights(prefix, \"_gate_\", (char*)gate_bb_[expert_idx]->b_pack[0], expert_idx, size, scale_size);\n              } else if (mat_class == 2) {\n                size_t size =\n                    T::BufferB::required_size(config_.hidden_size, config_.intermediate_size, PACKED, 'd', PLAIN);\n                size_t scale_size = config_.hidden_size * sizeof(float);\n                write_weights(prefix, \"_down_\", (char*)down_bb_[expert_idx]->b_pack[0], expert_idx, size, scale_size);\n              }\n            },\n            nullptr);\n      }\n    }\n  }\n\n  void warm_up() {\n    int qlen = config_.max_len;\n    std::vector<uint8_t> input(sizeof(input_t) * qlen * config_.hidden_size);\n    std::vector<uint8_t> output(sizeof(output_t) * qlen * config_.hidden_size);\n    std::vector<int64_t> expert_ids(qlen * config_.num_experts_per_tok);\n    std::vector<float> weights(qlen * config_.num_experts_per_tok);\n    for (int i = 0; i < qlen * config_.num_experts_per_tok; i++) {\n      expert_ids[i] = i % config_.expert_num;\n      weights[i] = 0.01;\n    }\n    forward(qlen, config_.num_experts_per_tok, expert_ids.data(), weights.data(), input.data(), output.data());\n  }\n\n#define MOE_DIRECT_OR_POOL_BY_VAR(var, fn)                       \\\n  do {                                                           \\\n    if (var < 5) {                                               \\\n      for (int i = 0; i < (var); i++) {                          \\\n        (fn)(i);                                                 \\\n      }                                                          \\\n    } else {                     
                                \\\n      pool->do_work_stealing_job((var), nullptr, (fn), nullptr); \\\n    }                                                            \\\n  } while (0)\n  static float act_fn(float x) { return x / (1.0f + expf(-x)); }\n\n  void forward(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input, void* output) {\n    // Unified forward path: 'd' for decode (qlen<=1), 'p' for prefill (qlen>1)\n    char mode = (qlen <= 1) ? 'd' : 'p';\n    forward_unified(mode, qlen, k, expert_ids, weights, input, output);\n  }\n\n  // Helper to select B pointer for up or gate mat based on packing\n  inline int8_t* select_up_or_gate_B_ptr_(bool do_up, int expert_idx, int ith, int devide_elements_size) {\n    if constexpr (PLAIN) {\n      int8_t* base = do_up ? (int8_t*)up_bb_[expert_idx]->b : (int8_t*)gate_bb_[expert_idx]->b;\n      return base + ith * config_.hidden_size * T::N_BLOCK_UP_GATE / devide_elements_size;\n    } else {\n      return do_up ? (int8_t*)up_bb_[expert_idx]->b_pack[ith] : (int8_t*)gate_bb_[expert_idx]->b_pack[ith];\n    }\n  }\n\n  // Helper to select B pointer for down mat based on packing\n  inline int8_t* select_down_B_ptr_(int expert_idx, int ith, int devide_elements_size) {\n    if constexpr (PLAIN) {\n      return ((int8_t*)down_bb_[expert_idx]->b) +\n             ith * config_.intermediate_size * T::N_BLOCK_DOWN / devide_elements_size;\n    } else {\n      return (int8_t*)down_bb_[expert_idx]->b_pack[ith];\n    }\n  }\n\n  // Unified implementation for decode/prefill using mode 'd' or 'p'\n  void forward_unified(char mode, int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input,\n                       void* output) {\n    MatKernelVariant var = (mode == 'p') ? 
MatKernelVariant::Prefill : MatKernelVariant::Decode;\n    MatKernelSelection kernel = select_mat_kernel<T>(var);\n    GemmFn cblas_gemm_s8s8s32 = kernel.fn;\n    int devide_elements_size = kernel.divide_elements_size;\n\n#ifdef FORWARD_TIME_PROFILE\n    forward_perf_start();\n#endif\n    int max_local_num = 0;\n\n    auto pool = config_.pool->get_subpool(tp_part_idx);\n\n    int activated_expert = 0;\n    for (int i = 0; i < config_.expert_num; i++) {\n      m_local_num_[i] = 0;\n    }\n    for (int i = 0; i < qlen; i++) {\n      for (int j = 0; j < k; j++) {\n        if (config_.should_skip_expert(expert_ids[i * k + j])) {\n          continue;\n        }\n        m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;\n      }\n    }\n\n    for (int i = 0; i < config_.expert_num; i++) {\n      if (m_local_num_[i] > 0) {\n        max_local_num = std::max(max_local_num, m_local_num_[i]);\n        m_expert_id_map_[activated_expert] = i;\n        activated_expert++;\n      }\n    }\n\n    uint64_t offset = 0;\n    for (int i = 0; i < config_.expert_num; i++) {\n      m_local_input_ptr_[i] = m_local_input_ + offset * config_.hidden_size;\n      m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;\n      m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;\n      m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;\n      offset += m_local_num_[i];\n    }\n\n#ifdef FORWARD_TIME_PROFILE\n    PROFILE_RECORD_TIME_STAMP(\"prepare\");\n#endif\n\n    // Copy inputs into expert-local buffers\n    MOE_DIRECT_OR_POOL_BY_VAR(qlen, [&](int i) {\n      for (int j = 0; j < k; j++) {\n        if (config_.should_skip_expert(expert_ids[i * k + j])) {\n          continue;\n        }\n        memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size,\n               (input_t*)input + i * config_.hidden_size, sizeof(input_t) * config_.hidden_size);\n     
 }\n    });\n\n#ifdef FORWARD_TIME_PROFILE\n    PROFILE_RECORD_TIME_STAMP(\"copy_input\");\n#endif\n\n    // Quantize expert inputs (row-wise)\n    {\n      size_t mth = T::recommended_mth(max_local_num);\n      MOE_DIRECT_OR_POOL_BY_VAR(activated_expert * mth, [&](int task_id) {\n        int task_id_expert = task_id / mth;\n        int ith = task_id % mth;\n        int expert_idx = m_expert_id_map_[task_id_expert];\n        if (ith * T::M_BLOCK >= m_local_num_[expert_idx]) return;\n        gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], ith, mth);\n      });\n    }\n\n#ifdef FORWARD_TIME_PROFILE\n    PROFILE_RECORD_TIME_STAMP(\"quant_input\");\n#endif\n\n    int nth_up = T::recommended_nth_up_gate(config_.intermediate_size, mode);\n    int mth = T::recommended_mth(max_local_num);\n    int32_t oc = 0;\n\n    // Up and Gate GEMMs + dequant scale\n    pool->do_work_stealing_job(\n        mth * nth_up * activated_expert * 2, nullptr,\n        [this, qlen, nth_up, oc, &cblas_gemm_s8s8s32, devide_elements_size, mth](int task_id2) {\n          int task_id = task_id2 / 2;\n          bool do_up = task_id2 % 2;\n          int expert_idx = m_expert_id_map_[task_id / (nth_up * mth)];\n          task_id = task_id % (nth_up * mth);\n          int ith = task_id % nth_up;\n          int jth = task_id / nth_up;\n          if (jth * T::M_BLOCK >= m_local_num_[expert_idx]) return;\n          int m_block = T::M_BLOCK;\n          if ((jth + 1) * T::M_BLOCK > m_local_num_[expert_idx]) {\n            m_block = m_local_num_[expert_idx] - jth * T::M_BLOCK;\n          }\n          int8_t* a_ptr = (int8_t*)gate_up_ba_[expert_idx]->a + jth * T::M_BLOCK * config_.hidden_size;\n          int8_t* b_ptr = select_up_or_gate_B_ptr_(do_up, expert_idx, ith, devide_elements_size);\n          int32_t* c_ptr = (do_up ? 
(int32_t*)up_bc_[expert_idx]->c : (int32_t*)gate_bc_[expert_idx]->c) +\n                           ith * T::N_BLOCK_UP_GATE + jth * T::M_BLOCK * config_.intermediate_size;\n\n          cblas_gemm_s8s8s32(KernelCblasRowMajor, KernelCblasNoTrans, KernelCblasTrans, KernelCblasFixOffset, m_block,\n                             T::N_BLOCK_UP_GATE, config_.hidden_size, 1.0, a_ptr, config_.hidden_size, 0, b_ptr,\n                             config_.hidden_size, 0, 0.0, c_ptr, config_.intermediate_size, &oc);\n\n          if (do_up) {\n            T::apply_scale(m_local_num_[expert_idx], config_.intermediate_size, m_local_up_output_ptr_[expert_idx],\n                           gate_up_ba_[expert_idx].get(), up_bb_[expert_idx].get(), up_bc_[expert_idx].get(), ith,\n                           nth_up, T::N_BLOCK_UP_GATE, jth);\n          } else {\n            T::apply_scale(m_local_num_[expert_idx], config_.intermediate_size, m_local_gate_output_ptr_[expert_idx],\n                           gate_up_ba_[expert_idx].get(), gate_bb_[expert_idx].get(), gate_bc_[expert_idx].get(), ith,\n                           nth_up, T::N_BLOCK_UP_GATE, jth);\n          }\n        },\n        nullptr);\n\n#ifdef FORWARD_TIME_PROFILE\n    PROFILE_RECORD_TIME_STAMP(\"up_gate\");\n#endif\n\n    // Activate gate and multiply by up\n    {\n      int nth = T::recommended_nth(config_.intermediate_size);\n      auto up_gate_fn = [this, nth](int task_id) {\n        int expert_idx = m_expert_id_map_[task_id / nth];\n        int ith = task_id % nth;\n        auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);\n        for (int i = 0; i < m_local_num_[expert_idx]; i++) {\n          float* gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];\n          float* up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];\n          for (int j = n_start; j < n_end; j++) {\n            gate_output_ptr[j] = 
act_fn(gate_output_ptr[j]) * up_output_ptr[j];\n          }\n        }\n      };\n      MOE_DIRECT_OR_POOL_BY_VAR(nth * activated_expert, up_gate_fn);\n    }\n\n#ifdef FORWARD_TIME_PROFILE\n    PROFILE_RECORD_TIME_STAMP(\"act\");\n#endif\n\n    pool->do_work_stealing_job(\n        activated_expert, nullptr,\n        [this](int task_id) {\n          int expert_idx = m_expert_id_map_[task_id];\n          down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx]);\n        },\n        nullptr);\n\n#ifdef FORWARD_TIME_PROFILE\n    PROFILE_RECORD_TIME_STAMP(\"quant_down_input\");\n#endif\n\n    int nth_down = T::recommended_nth_down(config_.hidden_size, mode);\n    pool->do_work_stealing_job(\n        mth * nth_down * activated_expert, nullptr,\n        [this, qlen, nth_down, oc, &cblas_gemm_s8s8s32, devide_elements_size, mth](int task_id) {\n          int expert_idx = m_expert_id_map_[task_id / (nth_down * mth)];\n          task_id = task_id % (nth_down * mth);\n          int ith = task_id % nth_down;\n          int jth = task_id / nth_down;\n          if (jth * T::M_BLOCK >= m_local_num_[expert_idx]) return;\n          int m_block = T::M_BLOCK;\n          if ((jth + 1) * T::M_BLOCK > m_local_num_[expert_idx]) {\n            m_block = m_local_num_[expert_idx] - jth * T::M_BLOCK;\n          }\n          int8_t* a_ptr = ((int8_t*)down_ba_[expert_idx]->a) + jth * T::M_BLOCK * config_.intermediate_size;\n          int8_t* b_ptr = select_down_B_ptr_(expert_idx, ith, devide_elements_size);\n          int32_t* c_ptr =\n              ((int32_t*)down_bc_[expert_idx]->c) + ith * T::N_BLOCK_DOWN + jth * T::M_BLOCK * config_.hidden_size;\n          cblas_gemm_s8s8s32(KernelCblasRowMajor, KernelCblasNoTrans, KernelCblasTrans, KernelCblasFixOffset, m_block,\n                             T::N_BLOCK_DOWN, config_.intermediate_size, 1.0, a_ptr, config_.intermediate_size, 0,\n                             b_ptr, config_.intermediate_size, 0, 0.0, 
c_ptr, config_.hidden_size, &oc);\n\n          T::apply_scale(m_local_num_[expert_idx], config_.hidden_size, m_local_down_output_ptr_[expert_idx],\n                         down_ba_[expert_idx].get(), down_bb_[expert_idx].get(), down_bc_[expert_idx].get(), ith,\n                         nth_down, T::N_BLOCK_DOWN, jth);\n        },\n        nullptr);\n\n#ifdef FORWARD_TIME_PROFILE\n    PROFILE_RECORD_TIME_STAMP(\"down\");\n#endif\n\n    // Merge k experts per token with weights\n    size_t block_dim = 512;\n    size_t block_num = (config_.hidden_size + block_dim - 1) / block_dim;\n    pool->do_work_stealing_job(\n        qlen * block_num, nullptr,\n        [this, k, expert_ids, weights, output, block_dim, block_num](int i) {\n          int q_idx = i / block_num;\n          int block_idx = i % block_num;\n          int e_start = block_idx * block_dim;\n          int e_end =\n              ((block_idx + 1) * block_dim) < config_.hidden_size ? ((block_idx + 1) * block_dim) : config_.hidden_size;\n          for (int e = e_start; e < e_end; e++) {\n            float sum = 0;\n            for (int j = 0; j < k; j++) {\n              if (config_.should_skip_expert(expert_ids[q_idx * k + j])) {\n                continue;\n              }\n              sum += weights[q_idx * k + j] * ((float*)m_local_down_output_ptr_[expert_ids[q_idx * k + j]])\n                                                  [m_local_pos_[q_idx][j] * config_.hidden_size + e];\n            }\n            ((float*)output)[q_idx * config_.hidden_size + e] = sum;\n          }\n        },\n        nullptr);\n\n#ifdef FORWARD_TIME_PROFILE\n    time_perf_name = std::string(\"[moe] \") + ((mode == 'p') ? 
\"layer prefill\" : \"decode layer \") +\n                     std::to_string(config_.layer_idx) + \" tp_part_idx: \" + std::to_string(tp_part_idx);\n    perf_report();\n#endif\n  }\n\n  /* merged into forward_unified */\n  void forward_decode(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input,\n                      void* output) {\n    forward_unified('d', qlen, k, expert_ids, weights, input, output);\n  }\n\n  void forward_prefill(int qlen, int k, const int64_t* expert_ids, const float* weights, const void* input,\n                       void* output) {\n    forward_unified('p', qlen, k, expert_ids, weights, input, output);\n  }\n};\n\ntemplate <typename K, bool T>\nclass TP_MOE<MOE_KERNEL_TP<K, T>> : public TP_MOE_Common<MOE_KERNEL_TP<K, T>> {\n public:\n  using TP_MOE_Common<MOE_KERNEL_TP<K, T>>::TP_MOE_Common;\n\n  void load_weights() {\n    auto& config = this->config;\n    auto& tps = this->tps;\n    auto& tp_count = this->tp_count;\n    auto pool = config.pool;\n    const uint64_t* physical_to_logical_map = (const uint64_t*)config.physical_to_logical_map;\n    if (config.gate_projs.empty() == false) {\n      printf(\"TP Load from loader\\n\");\n      pool->dispense_backend()->do_numa_job([this, pool](int numa_id) { this->tps[numa_id]->load_weights(); });\n      this->weights_loaded = true;\n    } else if (config.gate_proj != nullptr) {\n      printf(\"From BF16\\n\");\n      for (auto i = 0; i < tp_count; i++) {\n        auto& tpc = tps[i]->config_;\n        size_t gate_up_elcount = tpc.intermediate_size * tpc.hidden_size;\n        tpc.gate_proj = new ggml_bf16_t[tpc.expert_num * gate_up_elcount];\n        tpc.up_proj = new ggml_bf16_t[tpc.expert_num * gate_up_elcount];\n        tpc.down_proj = new ggml_bf16_t[tpc.expert_num * gate_up_elcount];\n        if (tps[i]->config_.load == false) {\n          pool->get_subpool(i)->do_work_stealing_job(\n              tpc.expert_num, nullptr,\n              [&](int expert_id_) {\n    
            size_t expert_id = expert_map(physical_to_logical_map, expert_id_);\n                memcpy((ggml_bf16_t*)tpc.gate_proj + expert_id * gate_up_elcount,\n                       (ggml_bf16_t*)config.gate_proj + expert_id * config.intermediate_size * config.hidden_size +\n                           i * gate_up_elcount,\n                       sizeof(ggml_bf16_t) * gate_up_elcount);\n                memcpy((ggml_bf16_t*)tpc.up_proj + expert_id * gate_up_elcount,\n                       (ggml_bf16_t*)config.up_proj + expert_id * config.intermediate_size * config.hidden_size +\n                           i * gate_up_elcount,\n                       sizeof(ggml_bf16_t) * gate_up_elcount);\n                for (size_t col = 0; col < config.hidden_size; col++) {\n                  memcpy((ggml_bf16_t*)tpc.down_proj + expert_id * tpc.hidden_size * tpc.intermediate_size +\n                             col * tpc.intermediate_size,\n                         (ggml_bf16_t*)config.down_proj + expert_id * config.intermediate_size * config.hidden_size +\n                             col * config.intermediate_size + i * tpc.intermediate_size,\n                         sizeof(ggml_bf16_t) * tpc.intermediate_size);\n                }\n              },\n              nullptr);\n        }\n      }\n\n      pool->dispense_backend()->do_numa_job([this, pool](int numa_id) { this->tps[numa_id]->load_weights(); });\n\n      for (auto i = 0; i < tp_count; i++) {\n        auto& tpc = tps[i]->config_;\n        delete[] (ggml_bf16_t*)(tpc.gate_proj);\n        delete[] (ggml_bf16_t*)(tpc.up_proj);\n        delete[] (ggml_bf16_t*)(tpc.down_proj);\n      }\n      if (config.save) {\n        // free the bf16 weights after saving\n        tps.clear();\n      }\n\n      this->weights_loaded = true;\n    } else if (config.path != \"\") {\n      printf(\"TP Load from file\\n\");\n      pool->dispense_backend()->do_numa_job([this, pool](int numa_id) { this->tps[numa_id]->load_weights(); });\n   
   this->weights_loaded = true;\n    } else {\n      throw std::runtime_error(\"no weight source\");\n    }\n  }\n\n  void merge_results(int qlen, void* output, bool incremental) {\n    // #ifdef FORWARD_TIME_PROFILE\n    //     forward_perf_start();\n    // #endif\n    auto pool = this->config.pool;\n    auto merge_fn = [this, output, incremental](int token_nth) {\n      auto& local_output_numa = this->local_output_numa;\n      auto& tp_configs = this->tp_configs;\n      auto& tp_count = this->tp_count;\n      auto& config = this->config;\n      float* merge_to = local_output_numa[0] + token_nth * tp_configs[0].hidden_size;\n      if (incremental) {\n        for (int e = 0; e < config.hidden_size; e++) {\n          merge_to[e] += ggml_bf16_to_fp32(((ggml_bf16_t*)output + token_nth * config.hidden_size)[e]);\n        }\n      }\n\n      for (int i = 1; i < tp_count; i++) {\n        float* merge_from = local_output_numa[i] + token_nth * tp_configs[i].hidden_size;\n        // TODO: accelerate with SVE later\n        // for (int e = 0; e < tp_configs[i].hidden_size; e += 16) {\n        //   *((__m512 *)(merge_to + e)) = _mm512_add_ps(*((__m512 *)(merge_to + e)), *((__m512 *)(merge_from + e)));\n        // }\n        // CHECK: currently implemented in plain C++\n        for (int e = 0; e < tp_configs[i].hidden_size; e++) {\n          merge_to[e] += merge_from[e];\n        }\n      }\n\n      convert_or_copy((ggml_bf16_t*)output + token_nth * config.hidden_size, merge_to, config.hidden_size);\n\n      // for (int e = 0; e < config.hidden_size; e += 32) {\n      // TODO: use SVE here to accelerate the fp32-to-bf16 conversion\n      // __m512 x0 = *(__m512 *)(merge_to + e);\n      // __m512 x1 = *(__m512 *)(merge_to + e + 16);\n      // avx512_32xfp32_to_32xbf16(&x0, &x1, (__m512i *)((ggml_bf16_t *)output + token_nth * config.hidden_size + e));\n\n      // CHECK: the fp32-to-bf16 conversion is currently implemented in plain C++\n\n      // convert_32fp32_to_32bf16_pure_c(merge_to + e,\n      // (uint16_t *)((ggml_bf16_t *)output + token_nth * 
config.hidden_size + e));\n\n      // }\n    };\n    MOE_DIRECT_OR_POOL_BY_VAR(qlen, merge_fn);\n    // #ifdef FORWARD_TIME_PROFILE\n    //     PROFILE_RECORD_TIME_STAMP(\"moe merge done\");\n    // #endif\n    // #ifdef FORWARD_TIME_PROFILE\n    //     time_perf_name = \"[moe merge] decode layer \" + std::to_string(this->config.layer_idx);\n    //     perf_report();\n    // #endif\n  }\n\n  void merge_results(int qlen, void* output) { merge_results(qlen, output, false); }\n};\n\n#endif"
  },
  {
    "path": "kt-kernel/operators/moe_kernel/test/convert-test.cpp",
    "content": "#include <arm_sve.h>\n\n#include <cmath>\n#include <cstdlib>\n#include <fstream>\n#include <stdexcept>\n\n#include \"../../reduce.hpp\"\n#include \"../../rms-norm.hpp\"\n#include \"../../rope.hpp\"\n#include \"../../softmax.hpp\"\n#include \"../la/arm_kml.hpp\"\n#include \"llama.cpp/ggml-common.h\"\n#include \"llama.cpp/ggml.h\"\n\nvoid bf16_to_fp16(const ggml_bf16_t* src, ggml_fp16_t* dst, size_t n) {\n  for (size_t i = 0; i < n; ++i) {\n    float x = ggml_bf16_to_fp32(src[i]);\n    dst[i] = ggml_fp32_to_fp16(x);\n  }\n}\n\nvoid debug_rope() {\n  float16_t* fp16 = new float16_t[1024 * 64];\n\n  for (size_t i = 0; i < 1024 * 64; i++) {\n    fp16[i] = static_cast<double>(std::rand()) / RAND_MAX;\n  }\n  std::ofstream(\"before_rope\", std::ios::binary).write((char*)fp16, 1024 * 64 * sizeof(float16_t));\n\n  DeepseekV3YarnRotaryEmbedding rope(64, 163840, 10000, 40, 4096, 32, 1, 1, 1);\n\n  rope.init(1024);\n\n  Rope<DeepseekV3YarnRotaryEmbedding, float16_t> rope_applier;\n  rope_applier.apply_multiple(rope, fp16, 64, 64, 0, 1024);\n\n  std::ofstream(\"cos\", std::ios::binary).write((char*)rope.cos(0), 1024 * 32 * sizeof(float));\n  std::ofstream(\"sin\", std::ios::binary).write((char*)rope.sin(0), 1024 * 32 * sizeof(float));\n\n  std::ofstream(\"after_rope\", std::ios::binary).write((char*)fp16, 1024 * 64 * sizeof(float16_t));\n}\n\nvoid debug_softmax() {\n  float16_t* fp16 = new float16_t[64 * 1024];\n\n  for (size_t i = 0; i < 1024 * 64; i++) {\n    fp16[i] = static_cast<double>(std::rand()) / RAND_MAX * 10;\n    if (i % 12 == 0) {\n      fp16[i] -= std::numeric_limits<float16_t>::infinity();\n    }\n  }\n  std::ofstream(\"before_softmax\", std::ios::binary).write((char*)fp16, 1024 * 64 * sizeof(float16_t));\n\n  Softmax<float16_t>::apply_multiple(64, fp16, 1024, 1024);\n  std::ofstream(\"after_softmax\", std::ios::binary).write((char*)fp16, 1024 * 64 * sizeof(float16_t));\n}\n\nvoid debug_inf() {\n  float16_t x, y;\n  // x = 
std::numeric_limits<float16_t>::infinity(); // 0.00\n  // y = -std::numeric_limits<float16_t>::infinity(); // -0.00\n  // x = 1e10;\n  x = std::numeric_limits<float>::infinity();   // inf\n  y = -std::numeric_limits<float>::infinity();  // -inf\n  printf(\"x = %f, y = %f\\n\", x, y);\n}\n\nvoid debug_reduce() {\n  std::vector<float16_t*> fp16s(128);\n  for (size_t i = 0; i < 128; i++) {\n    fp16s[i] = new float16_t[1024];\n    for (size_t j = 0; j < 1024; j++) {\n      fp16s[i][j] = i;\n    }\n  }\n\n  reduce_sum(fp16s.data(), 128, 0, 10);\n  for (int i = 0; i < 10; i++) {\n    printf(\"%f \", fp16s[0][i]);\n  }\n}\n\nint main() {\n  debug_reduce();\n\n  return 0;\n}\n"
  },
  {
    "path": "kt-kernel/operators/moe_kernel/test/debug.hpp",
    "content": "#ifndef KML_DEBUG_HPP\n#define KML_DEBUG_HPP\n\n#include <arm_sve.h>\n\n#include <cstdint>\n#include <cstdlib>\n#include <fstream>\n#include <string>\n\ninline std::string get_env_or_default(const char* var_name, const std::string& default_value) {\n  const char* value = std::getenv(var_name);\n  return (value != nullptr) ? std::string(value) : default_value;\n}\n\ninline void dump_bin(std::string file_name, float16_t* data, size_t count) {\n  file_name = get_env_or_default(\"KML_DEBUG_PATH\", \"debug\") + \"/\" + file_name + \".f16\";\n  std::ofstream f(file_name, std::ios::binary);\n  f.write(reinterpret_cast<const char*>(data), count * sizeof(*data));\n  f.close();\n}\ninline void dump_bin(std::string file_name, float* data, size_t count) {\n  file_name = get_env_or_default(\"KML_DEBUG_PATH\", \"debug\") + \"/\" + file_name + \".f32\";\n  std::ofstream f(file_name, std::ios::binary);\n  f.write(reinterpret_cast<const char*>(data), count * sizeof(*data));\n  f.close();\n}\ninline void dump_bin(std::string file_name, int64_t* data, size_t count) {\n  file_name = get_env_or_default(\"KML_DEBUG_PATH\", \"debug\") + \"/\" + file_name + \".int64\";\n  std::ofstream f(file_name, std::ios::binary);\n  f.write(reinterpret_cast<const char*>(data), count * sizeof(*data));\n  f.close();\n}\n\ninline void dump_bin(std::string file_name, int8_t* data, size_t count) {\n  file_name = get_env_or_default(\"KML_DEBUG_PATH\", \"debug\") + \"/\" + file_name + \".int8\";\n  std::ofstream f(file_name, std::ios::binary);\n  f.write(reinterpret_cast<const char*>(data), count * sizeof(*data));\n  f.close();\n}\n\ninline void dump_bin(std::string file_name, int32_t* data, size_t count) {\n  file_name = get_env_or_default(\"KML_DEBUG_PATH\", \"debug\") + \"/\" + file_name + \".int32\";\n  std::ofstream f(file_name, std::ios::binary);\n  f.write(reinterpret_cast<const char*>(data), count * sizeof(*data));\n  f.close();\n}\n\ninline void load_bin(std::string file_name, float* 
data, size_t count) {\n  file_name = get_env_or_default(\"KML_DEBUG_PATH\", \"debug\") + \"/\" + file_name + \".f32\";\n  std::ifstream f(file_name, std::ios::binary);\n  if (!f.is_open()) {\n    throw std::runtime_error(\"Failed to open file: \" + file_name);\n  }\n  f.read(reinterpret_cast<char*>(data), count * sizeof(*data));\n  f.close();\n}\n\n#endif\n"
  },
  {
    "path": "kt-kernel/operators/moe_kernel/test/int4_mul-test.cpp",
    "content": "#include <cstdint>\n#include <cstdio>\n#include <cstdlib>\n\n#include \"../la/arm_kml.hpp\"\n#include \"debug.hpp\"\n#include \"kblas.h\"\nconst int M = 1, K = 7168, N = 8;\n\nint main() {\n  // 随机生成a, b, c矩阵\n  arm_kml::GemmKernelInt4::BufferA buffer_a(M, K);\n  arm_kml::GemmKernelInt4::BufferB buffer_b(N, K, true);\n  arm_kml::GemmKernelInt4::BufferC buffer_c(M, N);\n\n  arm_kml::GemmKernelInt8::BufferA buffer_a_check(M, K);\n  arm_kml::GemmKernelInt8::BufferB buffer_b_check(N, K, true);\n  arm_kml::GemmKernelInt8::BufferC buffer_c_check(M, N);\n\n  float* a = (float*)aligned_alloc(64, sizeof(float) * M * K);\n  float* b = (float*)aligned_alloc(64, sizeof(float) * K * N);\n  float* c = (float*)aligned_alloc(64, sizeof(float) * M * N);\n  float* c_check = (float*)aligned_alloc(64, sizeof(float) * M * N);\n  int8_t* buffer_a_data = (int8_t*)aligned_alloc(64, buffer_a.required_size());\n  int4_2_t* buffer_b_data = (int4_2_t*)aligned_alloc(64, buffer_b.required_size());\n  int32_t* c_data = (int32_t*)aligned_alloc(64, buffer_c.required_size());\n  int8_t* buffer_a_data_check = (int8_t*)aligned_alloc(64, buffer_a_check.required_size());\n  int8_t* buffer_b_data_check = (int8_t*)aligned_alloc(64, buffer_b_check.required_size());\n  int32_t* c_data_check = (int32_t*)aligned_alloc(64, buffer_c_check.required_size());\n  // 初始化元素内容\n  load_bin(\"input.bin\", a, M * K);\n  load_bin(\"local_q_a_proj_quant.bin\", b, N * K);\n\n  // for (int i = 0; i < M * K; i++) {\n  //   // 随机浮点数\n  //   // a[i] = (static_cast<float>(rand()) / (float)RAND_MAX) / 25 - 0.02;\n  //   a[i] = -(static_cast<float>(rand()) / (float)RAND_MAX) / 25;\n  //   // a[i] = i % 10;\n  //   // a[i] = 1;\n  // }\n  // for (int i = 0; i < K * N; i++) {\n  //   // 随机浮点数\n  //   // b[i] = (static_cast<float>(rand()) / (float)RAND_MAX) / 25 - 0.02;\n  //   b[i] = -(static_cast<float>(rand()) / (float)RAND_MAX) / 25;\n  //   // b[i] = i % 10;\n  //   // b[i] = 1;\n  // }\n  // // // // 设置离群值\n  
// for (int i = 0; i < N; i++) {\n  //   b[i * K] = 0.06f; // Set the first column to an outlier\n  // }\n  // // Print the input matrix and the weight matrix\n  // printf(\"Input matrix a:\\n\");\n  // for (int i = 0; i < M; i++) {\n  //   for (int j = 0; j < K; j++) {\n  //     printf(\"%f \", a[i * K + j]);\n  //   }\n  //   printf(\"\\n\");\n  // }\n  // printf(\"Weight matrix b:\\n\");\n  // for (int i = 0; i < N; i++) {\n  //   for (int j = 0; j < K; j++) {\n  //     printf(\"%f \", b[i * K + j]);\n  //   }\n  //   printf(\"\\n\");\n  // }\n  buffer_a.set_data(buffer_a_data);\n  buffer_b.set_data(buffer_b_data);\n  buffer_c.set_data(c_data);\n  buffer_a_check.set_data(buffer_a_data_check);\n  buffer_b_check.set_data(buffer_b_data_check);\n  buffer_c_check.set_data(c_data_check);\n  // Call from_mat to quantize the inputs\n  buffer_a.from_mat(M, a, 0, M);\n  for (int i = 0; i <= arm_kml::GemmKernelInt4::recommended_nth(N); i++) {\n    buffer_b.from_mat(b, i, arm_kml::GemmKernelInt4::recommended_nth(N));\n  }\n  buffer_a_check.from_mat(M, a, 0, M);\n  for (int i = 0; i <= arm_kml::GemmKernelInt8::recommended_nth(N); i++) {\n    buffer_b_check.from_mat(b, i, arm_kml::GemmKernelInt8::recommended_nth(N));\n  }\n  // Perform the multiplication\n  arm_kml::MatRef<int8_t> a_ref(buffer_a.a, M, K, K, CblasRowMajor);\n  arm_kml::MatRef<int4_2_t> b_ref(buffer_b.b, K, N, K, CblasColMajor, CblasNoTrans, buffer_b.if_pack);\n  arm_kml::MatRef<int32_t> c_ref(buffer_c.c, M, N, N, CblasRowMajor);\n  b_ref = b_ref.offset_col(0, N);\n\n  arm_kml::MatRef<int8_t> a_ref_check(buffer_a_check.a, M, K, K, CblasRowMajor);\n  arm_kml::MatRef<int8_t> b_ref_check(buffer_b_check.b, K, N, K, CblasColMajor, CblasNoTrans, buffer_b_check.if_pack);\n  arm_kml::MatRef<int32_t> c_ref_check(buffer_c_check.c, M, N, N, CblasRowMajor);\n\n  arm_kml::decode_mul_mat_clearc(a_ref, b_ref, c_ref);\n  arm_kml::decode_mul_mat_clearc(a_ref_check, b_ref_check, c_ref_check);\n  // Dequantize: apply scale\n  arm_kml::GemmKernelInt4::apply_scale(c, N, &buffer_a, &buffer_b, &buffer_c, 0, M, 0, N, true);\n  
arm_kml::GemmKernelInt8::apply_scale(c_check, N, &buffer_a_check, &buffer_b_check, &buffer_c_check, 0, M, 0, N, true);\n  // Print the results, comparing c against c_check\n  const float threshold = 0.05;\n  for (int i = 0; i < M * N; i++) {\n    float diff_relative = (c[i] - c_check[i]) / (c_check[i] + 1e-6);\n\n    if (diff_relative > threshold || diff_relative < -threshold) {\n      printf(\"diff_relative: %f\\n\", diff_relative);\n      printf(\"Mismatch at index %d: c = %f, c_check = %f\\n\", i, c[i], c_check[i]);\n    } else {\n      printf(\"Match at index %d: c = %f, c_check = %f\\n\", i, c[i], c_check[i]);\n    }\n  }\n  return 0;\n}"
  },
  {
    "path": "kt-kernel/operators/moe_kernel/test/mat_test.cpp",
    "content": "#include \"arm_kml.hpp\"\n\nint main() {\n  const size_t M = 128, N = 64;\n  float16_t* a = new float16_t[M * N];\n  float16_t* b = new float16_t[M * N];\n  float16_t* c = new float16_t[M * M];\n  float16_t* c_check = new float16_t[M * M];\n  for (size_t i = 0; i < M * N; i++) {\n    a[i] = static_cast<double>(std::rand()) / RAND_MAX / 10.0;\n    b[i] = static_cast<double>(std::rand()) / RAND_MAX / 10.0;\n  }\n\n  arm_kml::MatRef<float16_t> aref(a, M, N, M, CblasColMajor);\n  arm_kml::MatRef<float16_t> bref(b, N, M, M, CblasColMajor);\n  arm_kml::MatRef<float16_t> cref(c, M, M, M, CblasColMajor);\n  {\n    memset(c, 0, M * M * sizeof(float16_t));\n    memset(c_check, 0, M * M * sizeof(float16_t));\n    arm_kml::mul_mat(aref, bref, cref);\n  }\n}\n"
  },
  {
    "path": "kt-kernel/operators/moe_kernel/test/utils_test.cpp",
    "content": "// #pragma once\n#ifdef TEST_UTIL\n#include <arm_neon.h>\n#include <arm_sve.h>\n#include <stdio.h>\n\nstatic inline void sve_32xbf16_to_32xfp32(const bfloat16_t* src, float* dst0, float* dst1) {\n#ifdef __ARM_FEATURE_SVE\n  // 全真谓词，对应每个 16‑bit 元素\n#else\n// fallback: scalar or NEON\n#endif\n}\n\nstatic inline void neon_32xbf16_to_32xfp32(const uint16_t* src, float* dst0, float* dst1) {\n  // src 指向 32 个连续的 BF16（uint16_t）\n  // dst0、dst1 各指向 16 个 float 的缓冲\n\n  for (int block = 0; block < 4; ++block) {\n    // 每次处理 8 个 BF16 → 8 个 FP32（拆为两次 4→4 存储）\n    uint16x8_t v_bf16 = vld1q_u16(src + block * 8);  // load 8×BF16 :contentReference[oaicite:6]{index=6}\n\n    // 拆低半、高半各 4 个到 u32\n    uint32x4_t lo_u32 = vmovl_u16(vget_low_u16(v_bf16));   // lower 4 → u32 :contentReference[oaicite:7]{index=7}\n    uint32x4_t hi_u32 = vmovl_u16(vget_high_u16(v_bf16));  // upper 4 → u32 :contentReference[oaicite:8]{index=8}\n\n    // 左移 16 位，相当于将 BF16 的 16 位 mantissa+exp 放到 FP32 高位\n    lo_u32 = vshlq_n_u32(lo_u32, 16);  // shift left 16 :contentReference[oaicite:9]{index=9}\n    hi_u32 = vshlq_n_u32(hi_u32, 16);  // shift left 16 :contentReference[oaicite:10]{index=10}\n\n    // 重新解释为 float32x4_t\n    float32x4_t lo_f32 = vreinterpretq_f32_u32(lo_u32);  // bits → FP32 :contentReference[oaicite:11]{index=11}\n    float32x4_t hi_f32 = vreinterpretq_f32_u32(hi_u32);  // bits → FP32 :contentReference[oaicite:12]{index=12}\n\n    // 存储到 dst0 或 dst1，每次存 8 个\n    if (block < 2) {\n      vst1q_f32(dst0 + block * 4, lo_f32);      // store 4 floats :contentReference[oaicite:13]{index=13}\n      vst1q_f32(dst0 + block * 4 + 4, hi_f32);  // store next 4 floats :contentReference[oaicite:14]{index=14}\n    } else {\n      int b = block - 2;\n      vst1q_f32(dst1 + b * 4, lo_f32);      // store 4 floats :contentReference[oaicite:15]{index=15}\n      vst1q_f32(dst1 + b * 4 + 4, hi_f32);  // store next 4 floats :contentReference[oaicite:16]{index=16}\n    }\n  }\n}\n\nint main() {\n  
// Test code\n  uint16_t bf16_data[32] = {0};  // Assume some BF16 data has been filled in here\n  float f32_data0[16] = {0};\n  float f32_data1[16] = {0};\n\n  neon_32xbf16_to_32xfp32(bf16_data, f32_data0, f32_data1);\n\n  // Print the results\n  for (int i = 0; i < 16; ++i) {\n    printf(\"f32_data0[%d]: %f\\n\", i, f32_data0[i]);\n    printf(\"f32_data1[%d]: %f\\n\", i, f32_data1[i]);\n  }\n\n  return 0;\n}\n#endif\n"
  },
  {
    "path": "kt-kernel/operators/reduce.hpp",
    "content": "#ifndef CPUINFER_REDUCE_HPP\n#define CPUINFER_REDUCE_HPP\n\n#include <cmath>\n\ntemplate <typename T>\nvoid reduce_sum(T** data, size_t data_groups_count, size_t begin, size_t end) {\n  if (data_groups_count <= 1) {\n  } else if (data_groups_count == 2) {\n    for (size_t i = begin; i < end; i++) {\n      data[0][i] += data[1][i];\n    }\n  } else {\n    int part1 = data_groups_count / 2;\n    reduce_sum(data, part1, begin, end);\n    reduce_sum(data + part1, data_groups_count - part1, begin, end);\n    for (size_t i = begin; i < end; i++) {\n      data[0][i] += data[part1][i];\n    }\n  }\n}\n\n#endif"
  },
  {
    "path": "kt-kernel/operators/rms-norm.hpp",
    "content": "#ifndef CPUINFER_RMS_NORM_HPP\n#define CPUINFER_RMS_NORM_HPP\n\n#include <cmath>\n\ntemplate <typename T, typename A>\nconcept RMS_NORM = requires(T t, int size, int hidden_size, int qlen, A* weights, A* input) {\n  { T::rms_norm(hidden_size, qlen, input) } -> std::same_as<void>;\n  { T::rms_norm_single(size, input) } -> std::same_as<void>;\n  { T::rms_norm_with_weights(hidden_size, qlen, weights, input) } -> std::same_as<void>;\n  { T::rms_norm_single_with_weights(size, weights, input) } -> std::same_as<void>;\n};\n\ntemplate <typename A>\nstruct RMSNorm {\n  static void rms_norm_single(int size, A* input) {\n    const float epsilon = 1e-6;\n    float sum = 0;\n    for (int i = 0; i < size; i++) {\n      sum += (float)input[i] * (float)input[i];\n    }\n    sum = sqrt(sum / size + epsilon);\n    for (int i = 0; i < size; i++) {\n      input[i] = (float)input[i] / sum;\n    }\n  }\n\n  static void rms_norm(int hidden_size, int qlen, A* input) {\n    const A epsilon = 1e-6;\n    for (int t = 0; t < qlen; t++) {\n      rms_norm_single(hidden_size, input + t * hidden_size);\n    }\n  }\n\n  static void rms_norm_with_weights(int hidden_size, int qlen, A* weights, A* input) {\n    const A epsilon = 1e-6;\n    for (int t = 0; t < qlen; t++) {\n      rms_norm_single_with_weights(hidden_size, input + t * hidden_size);\n    }\n  }\n  static void rms_norm_single_with_weights(int size, A* weights, A* input) {\n    const float epsilon = 1e-6;\n    float sum = 0;\n    for (int i = 0; i < size; i++) {\n      sum += (float)input[i] * (float)input[i];\n    }\n    sum = sqrt(sum / size + epsilon);\n    for (int i = 0; i < size; i++) {\n      input[i] = (float)weights[i] * (float)input[i] / sum;\n    }\n  }\n};\n\n#endif"
  },
  {
    "path": "kt-kernel/operators/rope.hpp",
    "content": "#ifndef CPUINFER_ROPE_HPP\n#define CPUINFER_ROPE_HPP\n\n#include <algorithm>\n#include <cassert>\n#include <cmath>\n#include <cstring>\n#include <stdexcept>\n#include <vector>\n\ntemplate <typename T, typename E, typename A>\nconcept ROPE_APPLIER = requires(T t, E* emb, int size, int pos_start, int pos_len, A* v) {\n  // must be thread safe and efficient\n\n  // apply embeddings with pos_start to v, v is vector of size\n  { T::apply_single(emb, v, size, pos_start) } -> std::same_as<void>;\n\n  // for every v i, apply embeddings with pos_start + i to v[i], v is vector of size\n  { T::apply_multiple(emb, v, size, pos_start, pos_len) } -> std::same_as<void>;\n};\n\ntemplate <typename T, typename A>\nconcept ROPE_ANGLE = requires(T t, size_t at) {\n  { t.cos(at) } -> std::same_as<float*>;\n  { t.sin(at) } -> std::same_as<float*>;\n  { t.init(at) } -> std::same_as<void>;\n};\n\ntemplate <typename E, typename A>\n  requires ROPE_ANGLE<E, A>\nstruct Rope {\n public:\n  static void apply_single(E& emb, A* v, int size, int pos_start) {\n    if (size == 0) {\n      return;\n    }\n    if (size % 2 != 0) {\n      throw std::invalid_argument(\"Rope::apply_single: 'size' (head_dim) must be even for LLaMA-style RoPE.\");\n    }\n\n    const float* cos = emb.cos(pos_start);\n    const float* sin = emb.sin(pos_start);\n\n    thread_local static std::vector<float> v2;\n    if (v2.size() < size) {\n      v2.resize(size);\n    }\n\n    for (int i = 0; i < size / 2; i++) {\n      float a = v[2 * i], b = v[2 * i + 1];\n      v2[i] = cos[i] * a - sin[i] * b;\n      v2[i + size / 2] = sin[i] * a + cos[i] * b;\n    }\n\n    for (int i = 0; i < size; i++) {\n      v[i] = v2[i];\n    }\n  }\n\n  static void apply_multiple(E& emb, A* v_block_start, int size_per_vector, int pos_start, int pos_len) {\n    if (size_per_vector == 0 || pos_len == 0) {\n      return;\n    }\n    if (size_per_vector % 2 != 0) {\n      throw std::invalid_argument(\"Rope::apply_multiple: 
'size_per_vector' (head_dim) must be even.\");\n    }\n\n    for (int i = 0; i < pos_len; ++i) {\n      apply_single(emb, v_block_start + size_per_vector * i, size_per_vector, pos_start + i);\n    }\n  }\n};\n\nclass RotaryEmbeddingBase {\n public:\n  virtual ~RotaryEmbeddingBase() = default;\n  virtual void init(size_t seq_len) {\n    calculate_inv_freq();\n    set_cos_sin_cache(seq_len);\n    this->max_seq_len_cached_ = seq_len;\n  }\n\n protected:\n  RotaryEmbeddingBase(size_t dim, size_t max_pos_embeddings, double base_val)\n      : dim_(dim), max_position_embeddings_(max_pos_embeddings), base_(base_val), max_seq_len_cached_(0) {}\n\n  virtual void calculate_inv_freq() = 0;\n  virtual void set_cos_sin_cache(size_t seq_len) = 0;\n\n  size_t dim_;\n  size_t max_position_embeddings_;\n  double base_;\n  std::vector<double> inv_freq_;\n  size_t max_seq_len_cached_;\n};\n\nclass DeepseekV3RotaryEmbedding : public RotaryEmbeddingBase {\n public:\n  DeepseekV3RotaryEmbedding(size_t dim, size_t max_position_embeddings = 2048, double base = 10000.0f)\n      : RotaryEmbeddingBase(dim, max_position_embeddings, base) {\n    if (this->dim_ % 2 != 0 || this->dim_ < 0) {\n      throw std::invalid_argument(\"Dimension must be even for RotaryEmbedding and >= 0.\");\n    }\n\n    if (this->max_position_embeddings_ < 0) {\n      throw std::invalid_argument(\"DeepseekV3RotaryEmbedding max_position_embeddings_ must be >= 0.\");\n    }\n\n    calculate_inv_freq();\n    set_cos_sin_cache(this->max_position_embeddings_);\n  }\n\n  float* sin(size_t at) { return sin_cached_.data() + at * this->dim_ / 2; }\n  float* cos(size_t at) { return cos_cached_.data() + at * this->dim_ / 2; }\n\n protected:\n  void calculate_inv_freq() override {\n    this->inv_freq_.resize(this->dim_ / 2);\n    for (size_t i = 0; i < this->dim_ / 2; ++i) {\n      this->inv_freq_[i] = 1.0 / std::pow(this->base_, 2.0 * i / this->dim_);\n    }\n  }\n\n  void set_cos_sin_cache(size_t seq_len) override {\n    if 
(this->inv_freq_.empty()) {\n      calculate_inv_freq();\n    }\n\n    cos_cached_.resize(seq_len * this->dim_ / 2);\n    sin_cached_.resize(seq_len * this->dim_ / 2);\n\n    for (size_t i = 0; i < seq_len; ++i) {\n      for (size_t j = 0; j < this->inv_freq_.size(); ++j) {\n        double freq = static_cast<double>(i) * this->inv_freq_[j];\n        double cos_val = std::cos(freq);\n        double sin_val = std::sin(freq);\n        size_t idx1 = i * this->dim_ / 2 + j;\n\n        cos_cached_.at(idx1) = cos_val;\n        sin_cached_.at(idx1) = sin_val;\n      }\n    }\n    this->max_seq_len_cached_ = seq_len;\n  }\n\n  std::vector<float> cos_cached_;\n  std::vector<float> sin_cached_;\n};\n\ninline double yarn_find_correction_dim(double num_rotations, double dim, double base, double max_position_embeddings) {\n  return (dim * std::log(max_position_embeddings / (num_rotations * static_cast<double>(2.0f) * M_PI))) /\n         (static_cast<double>(2.0f) * std::log(base));\n}\n\ninline std::pair<size_t, size_t> yarn_find_correction_range(double low_rot, double high_rot, size_t dim,\n                                                            double base = 10000,\n                                                            double max_position_embeddings = 2048) {\n  double low_f = std::floor(yarn_find_correction_dim(low_rot, static_cast<double>(dim), base, max_position_embeddings));\n  double high_f =\n      std::ceil(yarn_find_correction_dim(high_rot, static_cast<double>(dim), base, max_position_embeddings));\n\n  size_t low = static_cast<size_t>(std::max(0.0, low_f));\n  size_t high = static_cast<size_t>(std::min(static_cast<double>(dim - 1), high_f));\n  return std::pair{low, high};\n}\n\ninline std::vector<double> yarn_linear_ramp_mask(double min_val, double max_val, size_t dim) {\n  if (std::abs(min_val - max_val) < 1e-6f) {\n    max_val += 0.001;\n  }\n  std::vector<double> ramp_func(dim);\n  for (size_t i = 0; i < dim; ++i) {\n    double linear_func = 
(static_cast<double>(i) - min_val) / (max_val - min_val);\n    ramp_func[i] = std::clamp(linear_func, 0.0, 1.0);\n  }\n  return ramp_func;\n}\n\ninline double yarn_get_mscale(double scale = 1.0, double mscale = 1.0) {\n  if (scale <= 1.0) {\n    return 1.0;\n  }\n  return 0.1 * mscale * std::log(scale) + 1.0;\n}\n\nclass DeepseekV3YarnRotaryEmbedding : public DeepseekV3RotaryEmbedding {\n public:\n  DeepseekV3YarnRotaryEmbedding(size_t dim, size_t max_position_embeddings = 2048, double base = 10000.0f,\n                                double scaling_factor = 1.0, size_t original_max_position_embeddings = 4096,\n                                double beta_fast = 32.0, double beta_slow = 1.0, double mscale_val = 1.0,\n                                double mscale_all_dim_val = 0.0)\n      : DeepseekV3RotaryEmbedding(dim, 0, base),\n        scaling_factor_(scaling_factor),\n        original_max_position_embeddings_(original_max_position_embeddings),\n        beta_fast_(beta_fast),\n        beta_slow_(beta_slow),\n        mscale_(mscale_val),\n        mscale_all_dim_(mscale_all_dim_val) {\n    if (this->dim_ % 2 != 0 || this->dim_ < 0) {\n      throw std::invalid_argument(\"Dimension must be even for RotaryEmbedding and >= 0.\");\n    }\n\n    if (this->max_position_embeddings_ < 0) {\n      throw std::invalid_argument(\"DeepseekV3YarnRotaryEmbedding: max_position_embeddings_ must be >= 0.\");\n    }\n    calculate_inv_freq();\n    set_cos_sin_cache(max_position_embeddings);\n  }\n\n protected:\n  void calculate_inv_freq() override {\n    if (this->dim_ == 0) {\n      this->inv_freq_.clear();\n      return;\n    }\n    size_t dim_half = this->dim_ / 2;\n    this->inv_freq_.resize(dim_half);\n\n    std::vector<double> freq_extra(dim_half);\n    std::vector<double> freq_inter(dim_half);\n    for (size_t i = 0; i < dim_half; ++i) {\n      double freq_index = 2.0 * i / this->dim_;\n      freq_extra[i] = 1.0 / std::pow(this->base_, freq_index);\n      freq_inter[i] = 1.0f / 
(scaling_factor_ * std::pow(this->base_, freq_index));\n    }\n\n    auto [low_idx_f, high_idx_f] =\n        yarn_find_correction_range(beta_fast_, beta_slow_, this->dim_, this->base_, original_max_position_embeddings_);\n\n    size_t low_idx = static_cast<size_t>(low_idx_f);\n    size_t high_idx = static_cast<size_t>(high_idx_f);\n\n    std::vector<double> inv_freq_mask_ramp;\n    inv_freq_mask_ramp = yarn_linear_ramp_mask(low_idx, high_idx, dim_half);\n\n    for (size_t i = 0; i < dim_half; ++i) {\n      double mask_val = 1.0 - inv_freq_mask_ramp[i];\n      this->inv_freq_[i] = freq_inter[i] * (1.0 - mask_val) + freq_extra[i] * mask_val;\n    }\n  }\n\n  void set_cos_sin_cache(size_t seq_len) override {\n    if (this->inv_freq_.empty() || this->inv_freq_.size() != this->dim_ / 2) {\n      calculate_inv_freq();\n    }\n\n    this->cos_cached_.resize(seq_len * this->dim_ / 2);\n    this->sin_cached_.resize(seq_len * this->dim_ / 2);\n\n    // printf(\"scaling_factor %f, mscale %f, mscale all dim %f\\n\", scaling_factor_, mscale_, mscale_all_dim_);\n    double scale_factor_val = yarn_get_mscale(scaling_factor_, mscale_);\n    double scale_all_dim_factor_val = yarn_get_mscale(scaling_factor_, mscale_all_dim_);\n    double actual_mscale = 1.0;\n    if (std::abs(scale_all_dim_factor_val) > 1e-6f) {\n      actual_mscale = scale_factor_val / scale_all_dim_factor_val;\n    }\n    // printf(\"actual_mscale: %f, %f, %f\\n\", actual_mscale, scale_factor_val, scale_all_dim_factor_val);\n\n    for (size_t i = 0; i < seq_len; ++i) {\n      for (size_t j = 0; j < this->inv_freq_.size(); ++j) {\n        double freq = static_cast<double>(i) * this->inv_freq_[j];\n        double cos_val = std::cos(freq) * actual_mscale;\n        double sin_val = std::sin(freq) * actual_mscale;\n        size_t idx1 = i * this->dim_ / 2 + j;\n\n        this->cos_cached_.at(idx1) = cos_val;\n        this->sin_cached_.at(idx1) = sin_val;\n      }\n    }\n    this->max_seq_len_cached_ = seq_len;\n  
}\n\n private:\n  double scaling_factor_;\n  size_t original_max_position_embeddings_;\n  double beta_fast_;\n  double beta_slow_;\n  double mscale_;\n  double mscale_all_dim_;\n};\n\n#endif"
  },
  {
    "path": "kt-kernel/operators/softmax.hpp",
    "content": "#ifndef CPUINFER_OPERATOR_SOFTMAX_HPP\n#define CPUINFER_OPERATOR_SOFTMAX_HPP\n\n#include <algorithm>  // max_element\n#include <cmath>      // exp\n#include <cstddef>\n#ifdef __aarch64__\n#include <arm_sve.h>\n#endif\n\n#include <type_traits>\n\ntemplate <typename T, typename A>\nconcept SOFTMAX_APPLIER = requires(T t, A* v, size_t size, size_t count, size_t ld) {\n  { T::apply_single(v, size) } -> std::same_as<void>;\n  { T::apply_multiple(count, v, size, ld) } -> std::same_as<void>;\n};\n\ntemplate <typename A>\nclass Softmax {\n public:\n  /// 对单个向量做 softmax，就地写回\n  static void apply_single(A* v, size_t size) {\n    static thread_local std::vector<float> v2(100000);\n    if (size == 0 || v == nullptr) return;\n    if (size > v2.size()) {\n      v2.resize(size);\n    }\n\n    for (int i = 0; i < size; i++) {\n      v2[i] = v[i];\n    }\n\n    const float max_val = *std::max_element(v2.begin(), v2.begin() + size);\n\n    float sum = 0;\n    for (size_t i = 0; i < size; ++i) {\n      v2[i] = std::exp(v2[i] - max_val);\n      sum += v2[i];\n    }\n    if (sum == 0) return;  // 理论上不会发生，但防御一下\n    const float inv_sum = 1.0 / sum;\n    for (size_t i = 0; i < size; ++i) {\n      v[i] = v2[i] * inv_sum;\n    }\n  }\n\n  static void apply_multiple(size_t count, A* v, size_t size, size_t ld) {\n    for (size_t i = 0; i < count; ++i) {\n      apply_single(v + i * ld, size);\n    }\n  }\n};\n\n#endif  // CPUINFER_OPERATOR_SOFTMAX_HPP\n"
  },
  {
    "path": "kt-kernel/operators/tp.hpp",
    "content": ""
  },
  {
    "path": "kt-kernel/pyproject.toml",
    "content": "[build-system]\n# Minimum versions: setuptools for setup.py declarative usage, wheel for bdist_wheel\nrequires = [\"setuptools>=61\", \"wheel\", \"cmake>=3.16\", \"pybind11\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"kt-kernel\"\n# Version is dynamically read from ../version.py via setup.py\ndynamic = [\"version\"]\ndescription = \"KT-Kernel: High-performance kernel operations for KTransformers (AMX/AVX/KML optimizations)\"\nreadme = \"README.md\"\nauthors = [{ name = \"kvcache-ai\" }]\n# Use SPDX string form (table form deprecated in newer setuptools)\nlicense = \"Apache-2.0\"\nclassifiers = [\n  \"Programming Language :: Python :: 3\",\n  \"Programming Language :: C++\",\n  \"Operating System :: POSIX :: Linux\",\n  \"Operating System :: MacOS\",\n]\nrequires-python = \">=3.8\"\ndependencies = [\n  # Core dependencies\n  \"torch>=2.0.0\",\n  \"safetensors>=0.4.0\",\n  \"compressed-tensors>=0.7.0\",\n  \"numpy>=1.24.0\",\n  \"triton>=2.0.0\",\n  \"gguf>=0.17.0\",\n  # CLI dependencies\n  \"typer[all]>=0.9.0\",\n  \"rich>=13.0.0\",\n  \"pyyaml>=6.0\",\n  \"httpx>=0.25.0\",\n  \"packaging>=23.0\",\n  # SGLang (kvcache-ai fork)\n  \"sglang-kt\",\n  # Development dependencies\n  \"black>=25.9.0\",\n]\n\n[project.optional-dependencies]\ntest = [\n  \"pytest>=7.0.0\",\n  \"psutil>=5.9.0\",\n]\n\n[project.scripts]\nkt = \"kt_kernel.cli.main:main\"\n\n[project.urls]\nHomepage = \"https://github.com/kvcache-ai\"\n\n[tool.setuptools]\npackages = [\n  \"kt_kernel\",\n  \"kt_kernel.utils\",\n  \"kt_kernel.cli\",\n  \"kt_kernel.cli.commands\",\n  \"kt_kernel.cli.config\",\n  \"kt_kernel.cli.utils\",\n  \"kt_kernel.cli.completions\",\n]\ninclude-package-data = true\n\n[tool.setuptools.package-dir]\nkt_kernel = \"python\"\n\"kt_kernel.utils\" = \"python/utils\"\n\"kt_kernel.cli\" = \"python/cli\"\n\"kt_kernel.cli.commands\" = \"python/cli/commands\"\n\"kt_kernel.cli.config\" = \"python/cli/config\"\n\"kt_kernel.cli.utils\" = 
\"python/cli/utils\"\n\"kt_kernel.cli.completions\" = \"python/cli/completions\"\n\n[tool.setuptools.package-data]\n\"kt_kernel.cli.completions\" = [\"*.bash\", \"*.fish\", \"_kt\"]\n\n[tool.setuptools.exclude-package-data]\n# (empty)\n\n[tool.cpuinfer]\n# Custom section (example). You can place build options documentation here.\n# CPUINFER_CPU_INSTRUCT: NATIVE|FANCY|AVX512|AVX2\n# CPUINFER_ENABLE_AMX: ON/OFF\n# CPUINFER_VERBOSE: 1/0\n\n[tool.black]\n# Code style for Black formatter\nline-length = 120\ntarget-version = [\"py311\"]\nexclude = '''\n(\n  /(\\.\n    | build\n    | dist\n    | temp\n    | __pycache__\n    | kt_kernel\\.egg-info\n    | third_party\n  )/\n)\n'''\n"
  },
  {
    "path": "kt-kernel/pytest.ini",
    "content": "[pytest]\n# Test paths\ntestpaths = test/per_commit\n\n# File and function naming conventions\npython_files = test_*.py\npython_classes = Test*\npython_functions = test_*\n\n# Markers for hardware backends\nmarkers =\n    cpu: CPU backend tests (Intel AMX/AVX512/AVX2)\n    cuda: CUDA backend tests (NVIDIA GPUs)\n    amd: AMD backend tests (ROCm)\n    slow: Slow-running tests (>60 seconds)\n    requires_model: Tests requiring model files\n\n# Output options\naddopts =\n    -v\n    --tb=short\n    --strict-markers\n\n# Filter warnings\nfilterwarnings =\n    ignore::DeprecationWarning\n    ignore::PendingDeprecationWarning\n"
  },
  {
    "path": "kt-kernel/python/__init__.py",
    "content": "# KT-Kernel: High-performance kernel operations for KTransformers\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nKT-Kernel provides high-performance kernel operations for KTransformers,\nincluding CPU-optimized MoE inference with AMX, AVX, and KML support.\n\nThe package automatically detects your CPU capabilities and loads the optimal\nkernel variant (AMX, AVX512, or AVX2) at runtime.\n\nExample usage:\n    >>> from kt_kernel import KTMoEWrapper\n    >>> wrapper = KTMoEWrapper(\n    ...     layer_idx=0,\n    ...     num_experts=8,\n    ...     num_experts_per_tok=2,\n    ...     hidden_size=4096,\n    ...     moe_intermediate_size=14336,\n    ...     num_gpu_experts=2,\n    ...     cpuinfer_threads=32,\n    ...     threadpool_count=2,\n    ...     weight_path=\"/path/to/weights\",\n    ...     chunked_prefill_size=512,\n    ...     method=\"AMXINT4\"\n    ... )\n\n    Check which CPU variant is loaded:\n    >>> import kt_kernel\n    >>> print(kt_kernel.__cpu_variant__)  # 'amx', 'avx512', or 'avx2'\n\nEnvironment Variables:\n    KT_KERNEL_CPU_VARIANT: Override automatic detection ('amx', 'avx512', 'avx2')\n    KT_KERNEL_DEBUG: Enable debug output ('1' to enable)\n\"\"\"\n\nfrom __future__ import annotations\n\n# Detect CPU and load optimal extension variant\nfrom ._cpu_detect import initialize as _initialize_cpu\n\n_kt_kernel_ext, __cpu_variant__ = _initialize_cpu()\n\n# Make the extension module available to other modules in this package\nimport sys\n\nsys.modules[\"kt_kernel_ext\"] = _kt_kernel_ext\n\n# Also expose kt_kernel_ext as an attribute for backward compatibility\nkt_kernel_ext = _kt_kernel_ext\n\n# Import main API\nfrom .experts import KTMoEWrapper\nfrom .experts_base import generate_gpu_experts_masks\n\n# Read version from package metadata (preferred) or fallback to project root\ntry:\n    # Try to get version from installed package metadata (works in installed environment)\n    from importlib.metadata import version, 
PackageNotFoundError\n\n    try:\n        __version__ = version(\"kt-kernel\")\n    except PackageNotFoundError:\n        # Package not installed, try to read from source tree version.py\n        import os\n\n        _root_version_file = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), \"version.py\")\n        if os.path.exists(_root_version_file):\n            _version_ns = {}\n            with open(_root_version_file, \"r\", encoding=\"utf-8\") as f:\n                exec(f.read(), _version_ns)\n            __version__ = _version_ns.get(\"__version__\", \"0.4.3\")\n        else:\n            __version__ = \"0.4.3\"\nexcept ImportError:\n    # Python < 3.8, fallback to pkg_resources or hardcoded version\n    try:\n        from pkg_resources import get_distribution, DistributionNotFound\n\n        try:\n            __version__ = get_distribution(\"kt-kernel\").version\n        except DistributionNotFound:\n            __version__ = \"0.4.3\"\n    except ImportError:\n        __version__ = \"0.4.3\"\n\n__all__ = [\"KTMoEWrapper\", \"generate_gpu_experts_masks\", \"kt_kernel_ext\", \"__cpu_variant__\", \"__version__\"]\n"
  },
  {
    "path": "kt-kernel/python/_cpu_detect.py",
    "content": "\"\"\"\nCPU feature detection and optimal kernel loader for kt-kernel.\n\nThis module automatically detects CPU capabilities and loads the best available\nkernel variant (AMX, AVX512, or AVX2) at runtime.\n\nEnvironment Variables:\n    KT_KERNEL_CPU_VARIANT: Override automatic detection ('amx', 'avx512', 'avx2')\n    KT_KERNEL_DEBUG: Enable debug output ('1' to enable)\n\nExample:\n    >>> import kt_kernel\n    >>> print(kt_kernel.__cpu_variant__)  # Shows detected variant\n\n    # Override detection\n    >>> import os\n    >>> os.environ['KT_KERNEL_CPU_VARIANT'] = 'avx2'\n    >>> import kt_kernel  # Will use AVX2 variant\n\"\"\"\n\nimport os\nimport sys\nfrom pathlib import Path\n\n\ndef detect_cpu_features():\n    \"\"\"\n    Detect CPU features and determine the best kernel variant using progressive matching.\n\n    Progressive variant hierarchy (from most to least advanced):\n        1. AMX: amx_tile, amx_int8, amx_bf16 + full AVX512\n        2. AVX512_BF16: avx512f, avx512bw, avx512_vnni, avx512_vbmi, avx512_bf16\n        3. AVX512_VBMI: avx512f, avx512bw, avx512_vnni, avx512_vbmi\n        4. AVX512_VNNI: avx512f, avx512bw, avx512_vnni\n        5. AVX512_BASE: avx512f, avx512bw\n        6. 
AVX2: avx2 (fallback)\n\n    Returns:\n        str: Variant name - one of: 'amx', 'avx512_bf16', 'avx512_vbmi',\n             'avx512_vnni', 'avx512_base', 'avx2'\n    \"\"\"\n    # Check environment override\n    variant = os.environ.get(\"KT_KERNEL_CPU_VARIANT\", \"\").lower()\n    valid_variants = [\"amx\", \"avx512_bf16\", \"avx512_vbmi\", \"avx512_vnni\", \"avx512_base\", \"avx2\"]\n    if variant in valid_variants:\n        if os.environ.get(\"KT_KERNEL_DEBUG\") == \"1\":\n            print(f\"[kt-kernel] Using environment override: {variant}\")\n        return variant\n\n    # Try to read /proc/cpuinfo on Linux\n    try:\n        with open(\"/proc/cpuinfo\", \"r\") as f:\n            cpuinfo = f.read().lower()\n\n        # Extract CPU flags into a set for fast lookup\n        cpu_flags = set()\n        for line in cpuinfo.split(\"\\n\"):\n            if line.startswith(\"flags\"):\n                flags_str = line.split(\":\", 1)[1]\n                cpu_flags = set(flags_str.split())\n                break\n\n        # Define variant requirements in priority order (best to worst)\n        variant_requirements = [\n            (\n                \"amx\",\n                [\n                    \"amx_tile\",\n                    \"amx_int8\",\n                    \"amx_bf16\",\n                    \"avx512f\",\n                    \"avx512bw\",\n                    \"avx512_vnni\",\n                    \"avx512_vbmi\",\n                    \"avx512_bf16\",\n                ],\n            ),\n            (\"avx512_bf16\", [\"avx512f\", \"avx512bw\", \"avx512_vnni\", \"avx512_vbmi\", \"avx512_bf16\"]),\n            (\"avx512_vbmi\", [\"avx512f\", \"avx512bw\", \"avx512_vnni\", \"avx512_vbmi\"]),\n            (\"avx512_vnni\", [\"avx512f\", \"avx512bw\", \"avx512_vnni\"]),\n            (\"avx512_base\", [\"avx512f\", \"avx512bw\"]),\n            (\"avx2\", [\"avx2\"]),\n        ]\n\n        # Find the best matching variant\n        for variant_name, 
required_flags in variant_requirements:\n            # Check if all required flags are present\n            # Handle flag name variations (e.g., avx512_bf16 vs avx512bf16)\n            has_all_flags = True\n            for flag in required_flags:\n                # Try exact match first, then without underscore\n                flag_alt = flag.replace(\"_\", \"\")\n                if flag not in cpu_flags and flag_alt not in cpu_flags:\n                    has_all_flags = False\n                    break\n\n            if has_all_flags:\n                if os.environ.get(\"KT_KERNEL_DEBUG\") == \"1\":\n                    print(f\"[kt-kernel] Detected {variant_name} support via /proc/cpuinfo\")\n                    print(f\"[kt-kernel] Matched flags: {', '.join(required_flags)}\")\n                return variant_name\n\n        # Fallback to AVX2 (should be rare on modern CPUs)\n        if os.environ.get(\"KT_KERNEL_DEBUG\") == \"1\":\n            print(\"[kt-kernel] No supported features detected, using AVX2 fallback\")\n        return \"avx2\"\n\n    except FileNotFoundError:\n        # /proc/cpuinfo doesn't exist (not Linux or in container)\n        # Try cpufeature package as fallback\n        if os.environ.get(\"KT_KERNEL_DEBUG\") == \"1\":\n            print(\"[kt-kernel] /proc/cpuinfo not found, trying cpufeature package\")\n\n        try:\n            import cpufeature\n\n            # Define variant requirements in priority order (using cpufeature naming)\n            cpufeature_requirements = [\n                (\n                    \"amx\",\n                    [\n                        \"AMX_TILE\",\n                        \"AMX_INT8\",\n                        \"AMX_BF16\",\n                        \"AVX512F\",\n                        \"AVX512BW\",\n                        \"AVX512_VNNI\",\n                        \"AVX512_VBMI\",\n                        \"AVX512_BF16\",\n                    ],\n                ),\n                
(\"avx512_bf16\", [\"AVX512F\", \"AVX512BW\", \"AVX512_VNNI\", \"AVX512_VBMI\", \"AVX512_BF16\"]),\n                (\"avx512_vbmi\", [\"AVX512F\", \"AVX512BW\", \"AVX512_VNNI\", \"AVX512_VBMI\"]),\n                (\"avx512_vnni\", [\"AVX512F\", \"AVX512BW\", \"AVX512_VNNI\"]),\n                (\"avx512_base\", [\"AVX512F\", \"AVX512BW\"]),\n                (\"avx2\", [\"AVX2\"]),\n            ]\n\n            # Find the best matching variant\n            for variant_name, required_features in cpufeature_requirements:\n                has_all_features = all(cpufeature.CPUFeature.get(feat, False) for feat in required_features)\n                if has_all_features:\n                    if os.environ.get(\"KT_KERNEL_DEBUG\") == \"1\":\n                        print(f\"[kt-kernel] Detected {variant_name} support via cpufeature\")\n                    return variant_name\n\n            # Fallback to AVX2\n            if os.environ.get(\"KT_KERNEL_DEBUG\") == \"1\":\n                print(\"[kt-kernel] Using AVX2 fallback via cpufeature\")\n            return \"avx2\"\n\n        except ImportError:\n            # cpufeature not available - ultimate fallback\n            if os.environ.get(\"KT_KERNEL_DEBUG\") == \"1\":\n                print(\"[kt-kernel] cpufeature not available, using AVX2 fallback\")\n            return \"avx2\"\n\n    except Exception as e:\n        # Any other error - safe fallback\n        if os.environ.get(\"KT_KERNEL_DEBUG\") == \"1\":\n            print(f\"[kt-kernel] Error during CPU detection: {e}, using AVX2 fallback\")\n        return \"avx2\"\n\n\ndef load_extension(variant):\n    \"\"\"\n    Load the appropriate kt_kernel_ext variant.\n\n    Tries to import the specified variant, with automatic fallback to\n    lower-performance variants if the requested one is not available.\n\n    Supports both multi-variant builds (_kt_kernel_ext_amx.*.so) and\n    single-variant builds (kt_kernel_ext.*.so).\n\n    Fallback chain (each variant falls 
back to the next in line):\n        amx -> avx512_bf16 -> avx512_vbmi -> avx512_vnni -> avx512_base -> avx2 -> single-variant\n\n    Args:\n        variant (str): One of 'amx', 'avx512_bf16', 'avx512_vbmi', 'avx512_vnni', 'avx512_base', 'avx2'\n\n    Returns:\n        module: The loaded extension module\n\n    Raises:\n        ImportError: If all variants fail to load\n    \"\"\"\n    import importlib.util\n    import glob\n\n    # The .so files can be named in two ways:\n    # Multi-variant: _kt_kernel_ext_amx.cpython-311-x86_64-linux-gnu.so\n    # Single-variant: kt_kernel_ext.cpython-311-x86_64-linux-gnu.so\n    # Both export PyInit_kt_kernel_ext (the original module name)\n\n    try:\n        # Find the kt_kernel package directory\n        # We can't import kt_kernel here (circular import), so use __file__\n        kt_kernel_dir = os.path.dirname(os.path.abspath(__file__))\n\n        # Try multi-variant naming first\n        pattern = os.path.join(kt_kernel_dir, f\"_kt_kernel_ext_{variant}.*.so\")\n        so_files = glob.glob(pattern)\n\n        if not so_files:\n            # Try single-variant naming (fallback for builds without CPUINFER_BUILD_ALL_VARIANTS)\n            pattern = os.path.join(kt_kernel_dir, \"kt_kernel_ext.*.so\")\n            so_files = glob.glob(pattern)\n\n            if so_files:\n                if os.environ.get(\"KT_KERNEL_DEBUG\") == \"1\":\n                    print(f\"[kt-kernel] Multi-variant {variant} not found, using single-variant build\")\n            else:\n                raise ImportError(\n                    f\"No .so file found for variant {variant} (tried patterns: {kt_kernel_dir}/_kt_kernel_ext_{variant}.*.so and {kt_kernel_dir}/kt_kernel_ext.*.so)\"\n                )\n\n        so_file = so_files[0]\n\n        if os.environ.get(\"KT_KERNEL_DEBUG\") == \"1\":\n            print(f\"[kt-kernel] Loading {variant} from: {so_file}\")\n\n        # Load the module manually\n        # The module exports PyInit_kt_kernel_ext, 
so we use that as the module name\n        spec = importlib.util.spec_from_file_location(\"kt_kernel_ext\", so_file)\n        if spec is None or spec.loader is None:\n            raise ImportError(f\"Failed to create spec for {so_file}\")\n\n        ext = importlib.util.module_from_spec(spec)\n        spec.loader.exec_module(ext)\n\n        if os.environ.get(\"KT_KERNEL_DEBUG\") == \"1\":\n            print(f\"[kt-kernel] Successfully loaded {variant.upper()} variant\")\n        return ext\n\n    except (ImportError, ModuleNotFoundError, FileNotFoundError) as e:\n        if os.environ.get(\"KT_KERNEL_DEBUG\") == \"1\":\n            print(f\"[kt-kernel] Failed to load {variant} variant: {e}\")\n\n        # Define fallback chain: each variant falls back to the next lower one\n        fallback_chain = {\n            \"amx\": \"avx512_bf16\",\n            \"avx512_bf16\": \"avx512_vbmi\",\n            \"avx512_vbmi\": \"avx512_vnni\",\n            \"avx512_vnni\": \"avx512_base\",\n            \"avx512_base\": \"avx2\",\n            \"avx2\": None,  # No fallback - terminal variant\n        }\n\n        # Get next fallback variant\n        next_variant = fallback_chain.get(variant)\n\n        if next_variant:\n            # Try next variant in the chain\n            if os.environ.get(\"KT_KERNEL_DEBUG\") == \"1\":\n                print(f\"[kt-kernel] Falling back from {variant} to {next_variant}\")\n            return load_extension(next_variant)\n        else:\n            # AVX2 is the last fallback - if this fails, we can't continue\n            raise ImportError(\n                f\"Failed to load kt_kernel extension (variant: {variant}). 
\"\n                f\"Original error: {e}\\n\"\n                f\"This usually means the kt_kernel package is not properly installed.\"\n            )\n\n\ndef initialize():\n    \"\"\"\n    Detect CPU capabilities and load the optimal extension variant.\n\n    This is the main entry point called by kt_kernel.__init__.py.\n\n    Returns:\n        tuple: (extension_module, variant_name)\n    - extension_module: The loaded C++ extension module\n            - variant_name: String indicating which variant was loaded ('amx', 'avx512', 'avx2')\n\n    Example:\n        >>> ext, variant = initialize()\n        >>> print(f\"Loaded {variant} variant\")\n        >>> wrapper = ext.AMXMoEWrapper(...)\n    \"\"\"\n    # Detect CPU features\n    variant = detect_cpu_features()\n\n    if os.environ.get(\"KT_KERNEL_DEBUG\") == \"1\":\n        print(f\"[kt-kernel] Selected CPU variant: {variant}\")\n\n    # Load the appropriate extension\n    ext = load_extension(variant)\n\n    if os.environ.get(\"KT_KERNEL_DEBUG\") == \"1\":\n        print(f\"[kt-kernel] Extension module loaded: {ext.__name__}\")\n\n    return ext, variant\n"
  },
  {
    "path": "kt-kernel/python/cli/__init__.py",
    "content": "\"\"\"\nKTransformers CLI - A unified command-line interface for KTransformers.\n\nThis CLI provides a user-friendly interface to all KTransformers functionality,\nincluding model inference, fine-tuning, benchmarking, and more.\n\"\"\"\n\n__version__ = \"0.1.0\"\n"
  },
  {
    "path": "kt-kernel/python/cli/commands/__init__.py",
    "content": "\"\"\"\nCommand modules for kt-cli.\n\"\"\"\n"
  },
  {
    "path": "kt-kernel/python/cli/commands/bench.py",
    "content": "\"\"\"\nBench commands for kt-cli.\n\nRuns benchmarks for performance testing.\n\"\"\"\n\nimport subprocess\nimport sys\nfrom enum import Enum\nfrom pathlib import Path\nfrom typing import Optional\n\nimport typer\n\nfrom kt_kernel.cli.i18n import t\nfrom kt_kernel.cli.utils.console import (\n    console,\n    print_error,\n    print_info,\n    print_step,\n    print_success,\n)\n\n\nclass BenchType(str, Enum):\n    \"\"\"Benchmark type.\"\"\"\n\n    INFERENCE = \"inference\"\n    MLA = \"mla\"\n    MOE = \"moe\"\n    LINEAR = \"linear\"\n    ATTENTION = \"attention\"\n    ALL = \"all\"\n\n\ndef bench(\n    type: BenchType = typer.Option(\n        BenchType.ALL,\n        \"--type\",\n        \"-t\",\n        help=\"Benchmark type\",\n    ),\n    model: Optional[str] = typer.Option(\n        None,\n        \"--model\",\n        \"-m\",\n        help=\"Model to benchmark\",\n    ),\n    output: Optional[Path] = typer.Option(\n        None,\n        \"--output\",\n        \"-o\",\n        help=\"Output file for results (JSON)\",\n    ),\n    iterations: int = typer.Option(\n        10,\n        \"--iterations\",\n        \"-n\",\n        help=\"Number of iterations\",\n    ),\n) -> None:\n    \"\"\"Run full benchmark suite.\"\"\"\n    console.print()\n    print_step(t(\"bench_starting\"))\n    print_info(t(\"bench_type\", type=type.value))\n    console.print()\n\n    if type == BenchType.ALL:\n        _run_all_benchmarks(model, output, iterations)\n    elif type == BenchType.INFERENCE:\n        _run_inference_benchmark(model, output, iterations)\n    elif type == BenchType.MLA:\n        _run_component_benchmark(\"mla\", output, iterations)\n    elif type == BenchType.MOE:\n        _run_component_benchmark(\"moe\", output, iterations)\n    elif type == BenchType.LINEAR:\n        _run_component_benchmark(\"linear\", output, iterations)\n    elif type == BenchType.ATTENTION:\n        _run_component_benchmark(\"attention\", output, iterations)\n\n    
console.print()\n    print_success(t(\"bench_complete\"))\n    if output:\n        console.print(f\"  Results saved to: {output}\")\n    console.print()\n\n\ndef microbench(\n    component: str = typer.Argument(\n        \"moe\",\n        help=\"Component to benchmark (moe, mla, linear, attention)\",\n    ),\n    batch_size: int = typer.Option(\n        1,\n        \"--batch-size\",\n        \"-b\",\n        help=\"Batch size\",\n    ),\n    seq_len: int = typer.Option(\n        1,\n        \"--seq-len\",\n        \"-s\",\n        help=\"Sequence length\",\n    ),\n    iterations: int = typer.Option(\n        100,\n        \"--iterations\",\n        \"-n\",\n        help=\"Number of iterations\",\n    ),\n    warmup: int = typer.Option(\n        10,\n        \"--warmup\",\n        \"-w\",\n        help=\"Warmup iterations\",\n    ),\n    output: Optional[Path] = typer.Option(\n        None,\n        \"--output\",\n        \"-o\",\n        help=\"Output file for results (JSON)\",\n    ),\n) -> None:\n    \"\"\"Run micro-benchmark for specific components.\"\"\"\n    console.print()\n    console.print(f\"[yellow]{t('feature_coming_soon')}[/yellow]\")\n    console.print()\n    raise typer.Exit(0)\n\n    # Try to find the benchmark script\n    kt_kernel_path = _find_kt_kernel_path()\n\n    if kt_kernel_path is None:\n        print_error(\"kt-kernel not found. 
Install with: kt install inference\")\n        raise typer.Exit(1)\n\n    bench_dir = kt_kernel_path / \"bench\"\n\n    # Map component to script\n    component_scripts = {\n        \"moe\": \"bench_moe.py\",\n        \"mla\": \"bench_mla.py\",\n        \"linear\": \"bench_linear.py\",\n        \"attention\": \"bench_attention.py\",\n        \"mlp\": \"bench_mlp.py\",\n    }\n\n    script_name = component_scripts.get(component.lower())\n    if script_name is None:\n        print_error(f\"Unknown component: {component}\")\n        console.print(f\"Available: {', '.join(component_scripts.keys())}\")\n        raise typer.Exit(1)\n\n    script_path = bench_dir / script_name\n    if not script_path.exists():\n        print_error(f\"Benchmark script not found: {script_path}\")\n        raise typer.Exit(1)\n\n    # Run benchmark\n    cmd = [\n        sys.executable,\n        str(script_path),\n        \"--batch-size\",\n        str(batch_size),\n        \"--seq-len\",\n        str(seq_len),\n        \"--iterations\",\n        str(iterations),\n        \"--warmup\",\n        str(warmup),\n    ]\n\n    if output:\n        cmd.extend([\"--output\", str(output)])\n\n    console.print(f\"[dim]$ {' '.join(cmd)}[/dim]\")\n    console.print()\n\n    try:\n        process = subprocess.run(cmd)\n\n        if process.returncode == 0:\n            console.print()\n            print_success(t(\"bench_complete\"))\n            if output:\n                console.print(f\"  Results saved to: {output}\")\n        else:\n            print_error(f\"Benchmark failed with exit code {process.returncode}\")\n            raise typer.Exit(process.returncode)\n\n    except FileNotFoundError as e:\n        print_error(f\"Failed to run benchmark: {e}\")\n        raise typer.Exit(1)\n\n\ndef _find_kt_kernel_path() -> Optional[Path]:\n    \"\"\"Find the kt-kernel installation path.\"\"\"\n    try:\n        import kt_kernel\n\n        return Path(kt_kernel.__file__).parent.parent\n    except 
ImportError:\n        pass\n\n    # Check common locations\n    possible_paths = [\n        Path.home() / \"Projects\" / \"ktransformers\" / \"kt-kernel\",\n        Path(\"/opt/ktransformers/kt-kernel\"),\n        Path.cwd() / \"kt-kernel\",\n    ]\n\n    for path in possible_paths:\n        if path.exists() and (path / \"bench\").exists():\n            return path\n\n    return None\n\n\ndef _run_all_benchmarks(model: Optional[str], output: Optional[Path], iterations: int) -> None:\n    \"\"\"Run all benchmarks.\"\"\"\n    components = [\"moe\", \"mla\", \"linear\", \"attention\"]\n\n    for component in components:\n        console.print(f\"\\n[bold]Running {component} benchmark...[/bold]\")\n        _run_component_benchmark(component, None, iterations)\n\n\ndef _run_inference_benchmark(model: Optional[str], output: Optional[Path], iterations: int) -> None:\n    \"\"\"Run inference benchmark.\"\"\"\n    if model is None:\n        print_error(\"Model required for inference benchmark. Use --model flag.\")\n        raise typer.Exit(1)\n\n    print_info(f\"Running inference benchmark on {model}...\")\n    console.print()\n    console.print(\"[dim]This will start the server and run test requests.[/dim]\")\n    console.print()\n\n    # TODO: Implement actual inference benchmarking\n    print_error(\"Inference benchmarking not yet implemented.\")\n    # Exit with an error so the caller does not report the benchmark as complete\n    raise typer.Exit(1)\n\n\ndef _run_component_benchmark(component: str, output: Optional[Path], iterations: int) -> None:\n    \"\"\"Run a component benchmark.\"\"\"\n    kt_kernel_path = _find_kt_kernel_path()\n\n    if kt_kernel_path is None:\n        print_error(\"kt-kernel not found.\")\n        return\n\n    bench_dir = kt_kernel_path / \"bench\"\n    script_map = {\n        \"moe\": \"bench_moe.py\",\n        \"mla\": \"bench_mla.py\",\n        \"linear\": \"bench_linear.py\",\n        \"attention\": \"bench_attention.py\",\n    }\n\n    script_name = script_map.get(component)\n    if script_name is None:\n        print_error(f\"Unknown 
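component: {component}\")\n        return\n\n    script_path = bench_dir / script_name\n    if not script_path.exists():\n        print_error(f\"Script not found: {script_path}\")\n        return\n\n    cmd = [sys.executable, str(script_path), \"--iterations\", str(iterations)]\n\n    try:\n        process = subprocess.run(cmd)\n        # subprocess.run() does not raise on a non-zero exit code, so report it explicitly\n        if process.returncode != 0:\n            print_error(f\"Benchmark failed with exit code {process.returncode}\")\n    except Exception as e:\n        print_error(f\"Benchmark failed: {e}\")\n"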
  },
  {
    "path": "kt-kernel/python/cli/commands/chat.py",
    "content": "\"\"\"\nChat command for kt-cli.\n\nProvides interactive chat interface with running model server.\n\"\"\"\n\nimport json\nimport os\nimport sys\nfrom datetime import datetime\nfrom pathlib import Path\nfrom typing import Optional\n\nimport typer\nfrom rich.console import Console\nfrom rich.markdown import Markdown\nfrom rich.panel import Panel\nfrom rich.prompt import Prompt, Confirm\n\nfrom kt_kernel.cli.config.settings import get_settings\nfrom kt_kernel.cli.i18n import t\nfrom kt_kernel.cli.utils.console import (\n    console,\n    print_error,\n    print_info,\n    print_success,\n    print_warning,\n)\n\n# Try to import OpenAI SDK\ntry:\n    from openai import OpenAI\n\n    HAS_OPENAI = True\nexcept ImportError:\n    HAS_OPENAI = False\n\n\ndef chat(\n    host: Optional[str] = typer.Option(\n        None,\n        \"--host\",\n        \"-H\",\n        help=\"Server host address\",\n    ),\n    port: Optional[int] = typer.Option(\n        None,\n        \"--port\",\n        \"-p\",\n        help=\"Server port\",\n    ),\n    model: Optional[str] = typer.Option(\n        None,\n        \"--model\",\n        \"-m\",\n        help=\"Model name (if server hosts multiple models)\",\n    ),\n    temperature: float = typer.Option(\n        0.7,\n        \"--temperature\",\n        \"-t\",\n        help=\"Sampling temperature (0.0 to 2.0)\",\n    ),\n    max_tokens: int = typer.Option(\n        2048,\n        \"--max-tokens\",\n        help=\"Maximum tokens to generate\",\n    ),\n    system_prompt: Optional[str] = typer.Option(\n        None,\n        \"--system\",\n        \"-s\",\n        help=\"System prompt\",\n    ),\n    save_history: bool = typer.Option(\n        True,\n        \"--save-history/--no-save-history\",\n        help=\"Save conversation history\",\n    ),\n    history_file: Optional[Path] = typer.Option(\n        None,\n        \"--history-file\",\n        help=\"Path to save conversation history\",\n    ),\n    stream: bool = 
typer.Option(\n        True,\n        \"--stream/--no-stream\",\n        help=\"Enable streaming output\",\n    ),\n) -> None:\n    \"\"\"Start interactive chat with a running model server.\n\n    Examples:\n        kt chat                          # Connect to default server\n        kt chat --host 127.0.0.1 -p 8080 # Connect to specific server\n        kt chat -t 0.9 --max-tokens 4096 # Adjust generation parameters\n    \"\"\"\n    if not HAS_OPENAI:\n        print_error(t(\"chat_openai_required\"))\n        console.print()\n        console.print(t(\"chat_install_hint\"))\n        console.print(\"  pip install openai\")\n        raise typer.Exit(1)\n\n    settings = get_settings()\n\n    # Resolve server connection\n    final_host = host or settings.get(\"server.host\", \"127.0.0.1\")\n    final_port = port or settings.get(\"server.port\", 30000)\n\n    # Construct base URL for OpenAI-compatible API\n    base_url = f\"http://{final_host}:{final_port}/v1\"\n\n    console.print()\n    console.print(\n        Panel.fit(\n            f\"[bold cyan]{t('chat_title')}[/bold cyan]\\n\\n\"\n            f\"{t('chat_server')}: [yellow]{final_host}:{final_port}[/yellow]\\n\"\n            f\"{t('chat_temperature')}: [cyan]{temperature}[/cyan] | {t('chat_max_tokens')}: [cyan]{max_tokens}[/cyan]\\n\\n\"\n            f\"[dim]{t('chat_help_hint')}[/dim]\",\n            border_style=\"cyan\",\n        )\n    )\n    console.print()\n\n    # Check for proxy environment variables\n    proxy_vars = [\"HTTP_PROXY\", \"HTTPS_PROXY\", \"http_proxy\", \"https_proxy\", \"ALL_PROXY\", \"all_proxy\"]\n    detected_proxies = {var: os.environ.get(var) for var in proxy_vars if os.environ.get(var)}\n\n    if detected_proxies:\n        proxy_info = \", \".join(f\"{k}={v}\" for k, v in detected_proxies.items())\n        console.print()\n        print_warning(t(\"chat_proxy_detected\"))\n        console.print(f\"  [dim]{proxy_info}[/dim]\")\n        console.print()\n\n        use_proxy = 
Confirm.ask(t(\"chat_proxy_confirm\"), default=False)\n\n        if not use_proxy:\n            # Remove the proxy variables from this process's environment\n            for var in proxy_vars:\n                if var in os.environ:\n                    del os.environ[var]\n            print_info(t(\"chat_proxy_disabled\"))\n            console.print()\n\n    # Initialize OpenAI client\n    try:\n        client = OpenAI(\n            base_url=base_url,\n            api_key=\"EMPTY\",  # SGLang doesn't require an API key\n        )\n\n        # Test connection\n        print_info(t(\"chat_connecting\"))\n        models = client.models.list()\n        available_models = [m.id for m in models.data]\n\n        if not available_models:\n            print_error(t(\"chat_no_models\"))\n            raise typer.Exit(1)\n\n        # Select model\n        if model:\n            if model not in available_models:\n                print_warning(t(\"chat_model_not_found\", model=model, available=\", \".join(available_models)))\n                selected_model = available_models[0]\n            else:\n                selected_model = model\n        else:\n            selected_model = available_models[0]\n\n        print_success(t(\"chat_connected\", model=selected_model))\n        console.print()\n\n        # Load tokenizer for accurate token counting\n        tokenizer = None\n        try:\n            from transformers import AutoTokenizer\n\n            # selected_model is the model path\n            tokenizer = AutoTokenizer.from_pretrained(selected_model, trust_remote_code=True)\n            console.print(f\"[dim]Loaded tokenizer from {selected_model}[/dim]\")\n            console.print()\n        except Exception:\n            # The closing tag must match the opening tag, otherwise Rich raises MarkupError\n            console.print(\"[dim yellow]Warning: Could not load tokenizer, token counts will be estimated[/dim yellow]\")\n            console.print()\n\n    except typer.Exit:\n        # Let explicit exits pass through instead of reporting them as connection failures\n        raise\n    except Exception as e:\n        print_error(t(\"chat_connect_failed\", error=str(e)))\n        console.print()\n        
console.print(t(\"chat_server_not_running\"))\n        console.print(\"  kt run <model>\")\n        raise typer.Exit(1)\n\n    # Initialize conversation history\n    messages = []\n\n    # Add system prompt if provided\n    if system_prompt:\n        messages.append({\"role\": \"system\", \"content\": system_prompt})\n\n    # Setup history file\n    if save_history:\n        if history_file is None:\n            history_dir = settings.config_dir / \"chat_history\"\n            history_dir.mkdir(parents=True, exist_ok=True)\n            timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n            history_file = history_dir / f\"chat_{timestamp}.json\"\n        else:\n            history_file = Path(history_file)\n            history_file.parent.mkdir(parents=True, exist_ok=True)\n\n    # Main chat loop\n    try:\n        while True:\n            # Get user input - use console.input() for better CJK character support\n            try:\n                user_input = console.input(f\"[bold green]{t('chat_user_prompt')}[/bold green]: \")\n            except (EOFError, KeyboardInterrupt):\n                console.print()\n                print_info(t(\"chat_goodbye\"))\n                break\n\n            if not user_input.strip():\n                continue\n\n            # Handle special commands\n            if user_input.startswith(\"/\"):\n                if _handle_command(user_input, messages, temperature, max_tokens):\n                    continue\n                else:\n                    break  # Exit command\n\n            # Add user message to history\n            messages.append({\"role\": \"user\", \"content\": user_input})\n\n            # Generate response\n            console.print()\n            console.print(f\"[bold cyan]{t('chat_assistant_prompt')}[/bold cyan]\")\n\n            try:\n                if stream:\n                    # Streaming response\n                    response_content = _stream_response(\n                        client, 
selected_model, messages, temperature, max_tokens, tokenizer\n                    )\n                else:\n                    # Non-streaming response\n                    response_content = _generate_response(\n                        client, selected_model, messages, temperature, max_tokens, tokenizer\n                    )\n\n                # Add assistant response to history\n                messages.append({\"role\": \"assistant\", \"content\": response_content})\n\n                console.print()\n\n            except Exception as e:\n                print_error(t(\"chat_generation_error\", error=str(e)))\n                # Remove the user message that caused the error\n                messages.pop()\n                continue\n\n            # Save history if enabled\n            if save_history:\n                _save_history(history_file, messages, selected_model)\n\n    except KeyboardInterrupt:\n        console.print()\n        console.print()\n        print_info(t(\"chat_interrupted\"))\n\n    # Final history save\n    if save_history and messages:\n        _save_history(history_file, messages, selected_model)\n        console.print(f\"[dim]{t('chat_history_saved', path=str(history_file))}[/dim]\")\n        console.print()\n\n\ndef _stream_response(\n    client: \"OpenAI\",\n    model: str,\n    messages: list,\n    temperature: float,\n    max_tokens: int,\n    tokenizer=None,\n) -> str:\n    \"\"\"Generate streaming response and display in real-time.\"\"\"\n    import time\n\n    response_content = \"\"\n    reasoning_content = \"\"\n\n    # Performance tracking\n    first_token_time = None\n    chunk_count = 0\n\n    try:\n        # Start timing before sending request\n        start_time = time.time()\n\n        stream = client.chat.completions.create(\n            model=model,\n            messages=messages,\n            temperature=temperature,\n            max_tokens=max_tokens,\n            stream=True,\n        )\n\n        for chunk in 
stream:\n            delta = chunk.choices[0].delta if chunk.choices else None\n            if delta:\n                reasoning_delta = getattr(delta, \"reasoning_content\", None)\n                if reasoning_delta:\n                    if first_token_time is None:\n                        first_token_time = time.time()\n                    reasoning_content += reasoning_delta\n                    console.print(reasoning_delta, end=\"\", style=\"dim\")\n                    chunk_count += 1\n\n                if delta.content:\n                    if first_token_time is None:\n                        first_token_time = time.time()\n                    content = delta.content\n                    response_content += content\n                    console.print(content, end=\"\")\n                    chunk_count += 1\n\n        console.print()  # Newline after streaming\n\n        # Display performance metrics\n        end_time = time.time()\n        if first_token_time and chunk_count > 0:\n            ttft = first_token_time - start_time\n            total_time = end_time - start_time\n\n            # Calculate TPOT based on chunks\n            if chunk_count > 1:\n                generation_time = total_time - ttft\n                tpot = generation_time / (chunk_count - 1)\n            else:\n                tpot = 0\n\n            # Calculate accurate token counts using tokenizer\n            if tokenizer:\n                input_tokens = _count_tokens_with_tokenizer(messages, tokenizer)\n                output_tokens = _count_tokens_with_tokenizer(\n                    [{\"role\": \"assistant\", \"content\": response_content}], tokenizer\n                )\n                token_prefix = \"\"\n            else:\n                # Fallback to estimation\n                input_tokens = _estimate_tokens(messages)\n                output_tokens = _estimate_tokens([{\"role\": \"assistant\", \"content\": response_content}])\n                token_prefix = \"~\"\n\n     
       # Build metrics display\n            metrics = f\"[dim]Total: {total_time*1000:.0f}ms | TTFT: {ttft*1000:.0f}ms\"\n            if tpot > 0:\n                metrics += f\" | TPOT: {tpot*1000:.1f}ms\"\n            metrics += f\" | In: {token_prefix}{input_tokens} | Out: {token_prefix}{output_tokens}\"\n            metrics += \"[/dim]\"\n\n            console.print(metrics)\n\n    except Exception as e:\n        raise Exception(f\"Streaming error: {e}\")\n\n    return response_content\n\n\ndef _count_tokens_with_tokenizer(messages: list, tokenizer) -> int:\n    \"\"\"Count tokens accurately using the model's tokenizer.\"\"\"\n    try:\n        # Concatenate all message content\n        text = \"\"\n        for msg in messages:\n            role = msg.get(\"role\", \"\")\n            content = msg.get(\"content\", \"\")\n            # Simple format: role + content\n            text += f\"{role}: {content}\\n\"\n\n        # Encode and count tokens - suppress any debug output from custom tokenizers\n        import os\n        import sys\n        from contextlib import redirect_stdout, redirect_stderr\n\n        with open(os.devnull, \"w\") as devnull:\n            with redirect_stdout(devnull), redirect_stderr(devnull):\n                tokens = tokenizer.encode(text, add_special_tokens=True)\n        return len(tokens)\n    except Exception:\n        # Fallback to estimation if tokenizer fails\n        return _estimate_tokens(messages)\n\n\ndef _estimate_tokens(messages: list) -> int:\n    \"\"\"Estimate token count for messages (rough approximation).\"\"\"\n    total_chars = 0\n    for msg in messages:\n        content = msg.get(\"content\", \"\")\n        total_chars += len(content)\n\n    # Rough estimation:\n    # - English: ~4 chars per token\n    # - Chinese: ~1.5 chars per token\n    # Use 2.5 as average\n    return max(1, int(total_chars / 2.5))\n\n\ndef _generate_response(\n    client: \"OpenAI\",\n    model: str,\n    messages: list,\n    temperature: 
float,\n    max_tokens: int,\n    tokenizer=None,\n) -> str:\n    \"\"\"Generate non-streaming response.\"\"\"\n    import time\n\n    try:\n        start_time = time.time()\n\n        response = client.chat.completions.create(\n            model=model,\n            messages=messages,\n            temperature=temperature,\n            max_tokens=max_tokens,\n            stream=False,\n        )\n\n        end_time = time.time()\n        total_time = end_time - start_time\n\n        content = response.choices[0].message.content\n\n        # Display as markdown\n        md = Markdown(content)\n        console.print(md)\n\n        # Calculate accurate token counts using tokenizer\n        if tokenizer:\n            input_tokens = _count_tokens_with_tokenizer(messages, tokenizer)\n            output_tokens = _count_tokens_with_tokenizer([{\"role\": \"assistant\", \"content\": content}], tokenizer)\n            token_prefix = \"\"\n        else:\n            # Fallback to API usage or estimation\n            input_tokens = response.usage.prompt_tokens if response.usage else _estimate_tokens(messages)\n            output_tokens = (\n                response.usage.completion_tokens\n                if response.usage\n                else _estimate_tokens([{\"role\": \"assistant\", \"content\": content}])\n            )\n            token_prefix = \"\" if response.usage else \"~\"\n\n        # Display performance metrics\n        console.print(\n            f\"[dim]Time: {total_time*1000:.0f}ms | \"\n            f\"In: {token_prefix}{input_tokens} | Out: {token_prefix}{output_tokens}[/dim]\"\n        )\n\n        return content\n\n    except Exception as e:\n        raise Exception(f\"Generation error: {e}\")\n\n\ndef _handle_command(command: str, messages: list, temperature: float, max_tokens: int) -> bool:\n    \"\"\"Handle special commands. 
Returns True to continue chat, False to exit.\"\"\"\n    cmd = command.lower().strip()\n\n    if cmd in [\"/quit\", \"/exit\", \"/q\"]:\n        console.print()\n        print_info(t(\"chat_goodbye\"))\n        return False\n\n    elif cmd in [\"/help\", \"/h\"]:\n        console.print()\n        console.print(\n            Panel(\n                f\"[bold]{t('chat_help_title')}[/bold]\\n\\n{t('chat_help_content')}\",\n                title=\"Help\",\n                border_style=\"cyan\",\n            )\n        )\n        console.print()\n        return True\n\n    elif cmd in [\"/clear\", \"/c\"]:\n        messages.clear()\n        console.print()\n        print_success(t(\"chat_history_cleared\"))\n        console.print()\n        return True\n\n    elif cmd in [\"/history\", \"/hist\"]:\n        console.print()\n        if not messages:\n            print_info(t(\"chat_no_history\"))\n        else:\n            console.print(\n                Panel(\n                    _format_history(messages),\n                    title=t(\"chat_history_title\", count=len(messages)),\n                    border_style=\"cyan\",\n                )\n            )\n        console.print()\n        return True\n\n    elif cmd in [\"/info\", \"/i\"]:\n        console.print()\n        console.print(\n            Panel(\n                f\"[bold]{t('chat_info_title')}[/bold]\\n\\n{t('chat_info_content', temperature=temperature, max_tokens=max_tokens, messages=len(messages))}\",\n                title=\"Info\",\n                border_style=\"cyan\",\n            )\n        )\n        console.print()\n        return True\n\n    elif cmd in [\"/retry\", \"/r\"]:\n        if len(messages) >= 2 and messages[-1][\"role\"] == \"assistant\":\n            # Remove last assistant response\n            messages.pop()\n            print_info(t(\"chat_retrying\"))\n            console.print()\n        else:\n            print_warning(t(\"chat_no_retry\"))\n            console.print()\n        
return True\n\n    else:\n        print_warning(t(\"chat_unknown_command\", command=command))\n        console.print(f\"[dim]{t('chat_unknown_hint')}[/dim]\")\n        console.print()\n        return True\n\n\ndef _format_history(messages: list) -> str:\n    \"\"\"Format conversation history for display.\"\"\"\n    lines = []\n    for i, msg in enumerate(messages, 1):\n        role = msg[\"role\"].capitalize()\n        content = msg[\"content\"]\n\n        # Truncate long messages\n        if len(content) > 200:\n            content = content[:200] + \"...\"\n\n        lines.append(f\"[bold]{i}. {role}:[/bold] {content}\")\n\n    return \"\\n\\n\".join(lines)\n\n\ndef _save_history(file_path: Path, messages: list, model: str) -> None:\n    \"\"\"Save conversation history to file.\"\"\"\n    try:\n        history_data = {\n            \"model\": model,\n            \"timestamp\": datetime.now().isoformat(),\n            \"messages\": messages,\n        }\n\n        with open(file_path, \"w\", encoding=\"utf-8\") as f:\n            json.dump(history_data, f, indent=2, ensure_ascii=False)\n\n    except Exception as e:\n        print_warning(f\"Failed to save history: {e}\")\n"
  },
  {
    "path": "kt-kernel/python/cli/commands/config.py",
    "content": "\"\"\"\nConfig command for kt-cli.\n\nManages kt-cli configuration.\n\"\"\"\n\nfrom typing import Optional\n\nimport typer\nimport yaml\nfrom rich.syntax import Syntax\n\nfrom kt_kernel.cli.config.settings import get_settings\nfrom kt_kernel.cli.i18n import t\nfrom kt_kernel.cli.utils.console import confirm, console, print_error, print_success\n\napp = typer.Typer(help=\"Manage kt-cli configuration\")\n\n\n@app.command(name=\"init\")\ndef init() -> None:\n    \"\"\"Initialize or re-run the first-time setup wizard.\"\"\"\n    from kt_kernel.cli.main import _show_first_run_setup\n    from kt_kernel.cli.config.settings import get_settings\n\n    settings = get_settings()\n    _show_first_run_setup(settings)\n\n\n@app.command(name=\"show\")\ndef show(\n    key: Optional[str] = typer.Argument(None, help=\"Configuration key to show (e.g., server.port)\"),\n) -> None:\n    \"\"\"Show current configuration.\"\"\"\n    settings = get_settings()\n\n    if key:\n        value = settings.get(key)\n        if value is not None:\n            if isinstance(value, (dict, list)):\n                console.print(yaml.dump({key: value}, default_flow_style=False, allow_unicode=True))\n            else:\n                console.print(t(\"config_get_value\", key=key, value=value))\n        else:\n            print_error(t(\"config_get_not_found\", key=key))\n            raise typer.Exit(1)\n    else:\n        console.print(f\"\\n[bold]{t('config_show_title')}[/bold]\\n\")\n        console.print(f\"[dim]{t('config_file_location', path=str(settings.config_path))}[/dim]\\n\")\n\n        config_yaml = yaml.dump(settings.get_all(), default_flow_style=False, allow_unicode=True)\n        syntax = Syntax(config_yaml, \"yaml\", theme=\"monokai\", line_numbers=False)\n        console.print(syntax)\n\n\n@app.command(name=\"set\")\ndef set_config(\n    key: str = typer.Argument(..., help=\"Configuration key (e.g., server.port)\"),\n    value: str = typer.Argument(..., help=\"Value to 
set\"),\n) -> None:\n    \"\"\"Set a configuration value.\"\"\"\n    settings = get_settings()\n\n    # Try to parse value as JSON/YAML for complex types\n    parsed_value = _parse_value(value)\n\n    settings.set(key, parsed_value)\n    print_success(t(\"config_set_success\", key=key, value=parsed_value))\n\n\n@app.command(name=\"get\")\ndef get_config(\n    key: str = typer.Argument(..., help=\"Configuration key (e.g., server.port)\"),\n) -> None:\n    \"\"\"Get a configuration value.\"\"\"\n    settings = get_settings()\n    value = settings.get(key)\n\n    if value is not None:\n        if isinstance(value, (dict, list)):\n            console.print(yaml.dump(value, default_flow_style=False, allow_unicode=True))\n        else:\n            console.print(str(value))\n    else:\n        print_error(t(\"config_get_not_found\", key=key))\n        raise typer.Exit(1)\n\n\n@app.command(name=\"reset\")\ndef reset(\n    yes: bool = typer.Option(False, \"--yes\", \"-y\", help=\"Skip confirmation\"),\n) -> None:\n    \"\"\"Reset configuration to defaults.\"\"\"\n    if not yes:\n        if not confirm(t(\"config_reset_confirm\"), default=False):\n            raise typer.Abort()\n\n    settings = get_settings()\n    settings.reset()\n    print_success(t(\"config_reset_success\"))\n\n\n@app.command(name=\"path\")\ndef path() -> None:\n    \"\"\"Show configuration file path.\"\"\"\n    settings = get_settings()\n    console.print(str(settings.config_path))\n\n\n@app.command(name=\"model-path-list\", deprecated=True, hidden=True)\ndef model_path_list() -> None:\n    \"\"\"[Deprecated] Use 'kt model path-list' instead.\"\"\"\n    console.print(\"[yellow]⚠ This command is deprecated. 
Use 'kt model path-list' instead.[/yellow]\\n\")\n    import subprocess\n    subprocess.run([\"kt\", \"model\", \"path-list\"])\n\n\n@app.command(name=\"model-path-add\", deprecated=True, hidden=True)\ndef model_path_add(\n    path: str = typer.Argument(..., help=\"Path to add\"),\n) -> None:\n    \"\"\"[Deprecated] Use 'kt model path-add' instead.\"\"\"\n    console.print(\"[yellow]⚠ This command is deprecated. Use 'kt model path-add' instead.[/yellow]\\n\")\n    import subprocess\n    subprocess.run([\"kt\", \"model\", \"path-add\", path])\n\n\n@app.command(name=\"model-path-remove\", deprecated=True, hidden=True)\ndef model_path_remove(\n    path: str = typer.Argument(..., help=\"Path to remove\"),\n) -> None:\n    \"\"\"[Deprecated] Use 'kt model path-remove' instead.\"\"\"\n    console.print(\"[yellow]⚠ This command is deprecated. Use 'kt model path-remove' instead.[/yellow]\\n\")\n    import subprocess\n    subprocess.run([\"kt\", \"model\", \"path-remove\", path])\n\n\ndef _parse_value(value: str):\n    \"\"\"Parse a string value into appropriate Python type.\"\"\"\n    # Try boolean\n    if value.lower() in (\"true\", \"yes\", \"on\", \"1\"):\n        return True\n    if value.lower() in (\"false\", \"no\", \"off\", \"0\"):\n        return False\n\n    # Try integer\n    try:\n        return int(value)\n    except ValueError:\n        pass\n\n    # Try float\n    try:\n        return float(value)\n    except ValueError:\n        pass\n\n    # Try YAML/JSON parsing for lists/dicts\n    try:\n        parsed = yaml.safe_load(value)\n        if isinstance(parsed, (dict, list)):\n            return parsed\n    except yaml.YAMLError:\n        pass\n\n    # Return as string\n    return value\n"
  },
  {
    "path": "kt-kernel/python/cli/commands/doctor.py",
    "content": "\"\"\"\nDoctor command for kt-cli.\n\nDiagnoses environment issues and provides recommendations.\n\"\"\"\n\nimport glob\nimport os\nimport platform\nimport shutil\nfrom pathlib import Path\nfrom typing import Optional\n\nimport typer\nfrom rich.table import Table\n\nfrom kt_kernel.cli.config.settings import get_settings\nfrom kt_kernel.cli.i18n import t\nfrom kt_kernel.cli.utils.console import console, print_error, print_info, print_success, print_warning\nfrom kt_kernel.cli.utils.environment import (\n    check_docker,\n    detect_available_ram_gb,\n    detect_cpu_info,\n    detect_cuda_version,\n    detect_disk_space_gb,\n    detect_env_managers,\n    detect_gpus,\n    detect_memory_info,\n    detect_ram_gb,\n    get_installed_package_version,\n)\n\n\ndef _get_kt_kernel_info() -> dict:\n    \"\"\"Get kt-kernel installation information.\"\"\"\n    info = {\n        \"installed\": False,\n        \"version\": None,\n        \"cpu_variant\": None,\n        \"install_path\": None,\n        \"available_variants\": [],\n        \"extension_file\": None,\n    }\n\n    try:\n        import kt_kernel\n\n        info[\"installed\"] = True\n        info[\"version\"] = getattr(kt_kernel, \"__version__\", \"unknown\")\n        info[\"cpu_variant\"] = getattr(kt_kernel, \"__cpu_variant__\", \"unknown\")\n\n        # Get installation path\n        info[\"install_path\"] = os.path.dirname(kt_kernel.__file__)\n\n        # Find available .so files\n        kt_kernel_dir = info[\"install_path\"]\n        so_files = glob.glob(os.path.join(kt_kernel_dir, \"_kt_kernel_ext_*.so\"))\n        so_files.extend(glob.glob(os.path.join(kt_kernel_dir, \"kt_kernel_ext*.so\")))\n\n        # Parse variant names from filenames\n        variants = set()\n        for so_file in so_files:\n            basename = os.path.basename(so_file)\n            if \"_kt_kernel_ext_\" in basename:\n                # Extract variant from _kt_kernel_ext_amx.cpython-311-x86_64-linux-gnu.so\n         
       # Take the text between the \"_kt_kernel_ext_\" prefix and the first \".\";\n                # split(\"_\")[3] would return \"ext\" because the leading underscore yields an empty first field\n                variant = basename.split(\"_kt_kernel_ext_\", 1)[1].split(\".\", 1)[0]  # \"amx\" from \"_kt_kernel_ext_amx...\"\n                if variant:\n                    if variant.startswith(\"avx\"):\n                        # Normalize avx variants\n                        if variant in [\"avx512\", \"avx512_bf16\", \"avx512_vbmi\", \"avx512_vnni\", \"avx512_base\"]:\n                            variants.add(\"avx512\")\n                        else:\n                            variants.add(variant)\n                    else:\n                        variants.add(variant)\n            elif \"kt_kernel_ext\" in basename:\n                variants.add(\"default\")\n\n        info[\"available_variants\"] = sorted(list(variants))\n\n        # Get current extension file\n        if hasattr(kt_kernel, \"kt_kernel_ext\"):\n            ext_module = kt_kernel.kt_kernel_ext\n            info[\"extension_file\"] = getattr(ext_module, \"__file__\", None)\n\n    except ImportError:\n        info[\"installed\"] = False\n    except Exception as e:\n        info[\"error\"] = str(e)\n\n    return info\n\n\ndef doctor(\n    verbose: bool = typer.Option(False, \"--verbose\", \"-v\", help=\"Show detailed diagnostics\"),\n) -> None:\n    \"\"\"Diagnose environment issues.\"\"\"\n    console.print(f\"\\n[bold]{t('doctor_title')}[/bold]\\n\")\n\n    issues_found = False\n    checks = []\n\n    # 1. Python version\n    python_version = platform.python_version()\n    python_ok = _check_python_version(python_version)\n    checks.append(\n        {\n            \"name\": t(\"doctor_check_python\"),\n            \"status\": \"ok\" if python_ok else \"error\",\n            \"value\": python_version,\n            \"hint\": \"Python 3.10+ required\" if not python_ok else None,\n        }\n    )\n    if not python_ok:\n        issues_found = True\n\n    # 2. 
CUDA availability\n    cuda_version = detect_cuda_version()\n    checks.append(\n        {\n            \"name\": t(\"doctor_check_cuda\"),\n            \"status\": \"ok\" if cuda_version else \"warning\",\n            \"value\": cuda_version or t(\"version_cuda_not_found\"),\n            \"hint\": \"CUDA is optional but recommended for GPU acceleration\" if not cuda_version else None,\n        }\n    )\n\n    # 3. GPU detection\n    gpus = detect_gpus()\n    if gpus:\n        gpu_names = \", \".join(g.name for g in gpus)\n        total_vram = sum(g.vram_gb for g in gpus)\n        checks.append(\n            {\n                \"name\": t(\"doctor_check_gpu\"),\n                \"status\": \"ok\",\n                \"value\": t(\"doctor_gpu_found\", count=len(gpus), names=gpu_names),\n                \"hint\": f\"Total VRAM: {total_vram}GB\",\n            }\n        )\n    else:\n        checks.append(\n            {\n                \"name\": t(\"doctor_check_gpu\"),\n                \"status\": \"warning\",\n                \"value\": t(\"doctor_gpu_not_found\"),\n                \"hint\": \"GPU recommended for best performance\",\n            }\n        )\n\n    # 4. CPU information\n    cpu_info = detect_cpu_info()\n    checks.append(\n        {\n            \"name\": t(\"doctor_check_cpu\"),\n            \"status\": \"ok\",\n            \"value\": t(\"doctor_cpu_info\", name=cpu_info.name, cores=cpu_info.cores, threads=cpu_info.threads),\n            \"hint\": None,\n        }\n    )\n\n    # 5. 
CPU instruction sets (critical for kt-kernel)\n    isa_list = cpu_info.instruction_sets\n    # Check for recommended instruction sets\n    recommended_isa = {\"AVX2\", \"AVX512F\", \"AMX-INT8\"}\n    has_recommended = bool(set(isa_list) & recommended_isa)\n    has_avx2 = \"AVX2\" in isa_list\n    has_avx512 = any(isa.startswith(\"AVX512\") for isa in isa_list)\n    has_amx = any(isa.startswith(\"AMX\") for isa in isa_list)\n\n    # Determine status and build display string\n    if has_amx:\n        isa_status = \"ok\"\n        isa_hint = \"AMX available - best performance for INT4/INT8\"\n    elif has_avx512:\n        isa_status = \"ok\"\n        isa_hint = \"AVX512 available - good performance\"\n    elif has_avx2:\n        isa_status = \"warning\"\n        isa_hint = \"AVX2 only - consider upgrading CPU for better performance\"\n    else:\n        isa_status = \"error\"\n        isa_hint = \"AVX2 required for kt-kernel\"\n\n    # Show top instruction sets (prioritize important ones)\n    display_isa = isa_list[:8] if len(isa_list) > 8 else isa_list\n    isa_display = \", \".join(display_isa)\n    if len(isa_list) > 8:\n        isa_display += f\" (+{len(isa_list) - 8} more)\"\n\n    checks.append(\n        {\n            \"name\": t(\"doctor_check_cpu_isa\"),\n            \"status\": isa_status,\n            \"value\": isa_display if isa_display else \"None detected\",\n            \"hint\": isa_hint,\n        }\n    )\n\n    # 6. 
NUMA topology\n    numa_detail = []\n    for node, cpus in sorted(cpu_info.numa_info.items()):\n        if len(cpus) > 6:\n            cpu_str = f\"{cpus[0]}-{cpus[-1]}\"\n        else:\n            cpu_str = \",\".join(str(c) for c in cpus)\n        numa_detail.append(f\"{node}: {cpu_str}\")\n\n    numa_value = t(\"doctor_numa_info\", nodes=cpu_info.numa_nodes)\n    if verbose and numa_detail:\n        numa_value += \" (\" + \"; \".join(numa_detail) + \")\"\n\n    checks.append(\n        {\n            \"name\": t(\"doctor_check_numa\"),\n            \"status\": \"ok\",\n            \"value\": numa_value,\n            \"hint\": f\"{cpu_info.threads // cpu_info.numa_nodes} threads per node\" if cpu_info.numa_nodes > 1 else None,\n        }\n    )\n\n    # 6b. kt-kernel installation check\n    kt_info = _get_kt_kernel_info()\n\n    if kt_info[\"installed\"]:\n        # Build display string for kt-kernel\n        variant = kt_info[\"cpu_variant\"]\n        version = kt_info[\"version\"]\n        available_variants = kt_info[\"available_variants\"]\n\n        # Determine status based on CPU variant\n        if variant == \"amx\":\n            kt_status = \"ok\"\n            kt_hint = \"AMX variant loaded - optimal performance\"\n        elif variant.startswith(\"avx512\"):\n            kt_status = \"ok\"\n            kt_hint = \"AVX512 variant loaded - good performance\"\n        elif variant == \"avx2\":\n            kt_status = \"warning\"\n            kt_hint = \"AVX2 variant - consider upgrading CPU for AMX/AVX512\"\n        else:\n            kt_status = \"warning\"\n            kt_hint = f\"Unknown variant: {variant}\"\n\n        kt_value = f\"v{version} ({variant.upper()})\"\n        if verbose and available_variants:\n            kt_value += f\" [dim] - available: {', '.join(available_variants)}[/dim]\"\n\n        checks.append(\n            {\n                \"name\": \"kt-kernel\",\n                \"status\": kt_status,\n                \"value\": 
kt_value,\n                \"hint\": kt_hint,\n            }\n        )\n\n        # Show extension file path in verbose mode\n        if verbose and kt_info.get(\"extension_file\"):\n            ext_file = os.path.basename(kt_info[\"extension_file\"])\n            checks.append(\n                {\n                    \"name\": \"  └─ Extension\",\n                    \"status\": \"ok\",\n                    \"value\": ext_file,\n                    \"hint\": None,\n                }\n            )\n\n        # Show installation path in verbose mode\n        if verbose and kt_info.get(\"install_path\"):\n            checks.append(\n                {\n                    \"name\": \"  └─ Path\",\n                    \"status\": \"ok\",\n                    \"value\": kt_info[\"install_path\"],\n                    \"hint\": None,\n                }\n            )\n    else:\n        error_msg = kt_info.get(\"error\", \"Not installed\")\n        checks.append(\n            {\n                \"name\": \"kt-kernel\",\n                \"status\": \"error\",\n                \"value\": error_msg,\n                \"hint\": \"kt-kernel is required - run: pip install kt-kernel\",\n            }\n        )\n        issues_found = True\n\n    # 7. 
System memory (with frequency if available)\n    mem_info = detect_memory_info()\n    if mem_info.frequency_mhz and mem_info.type:\n        mem_value = t(\n            \"doctor_memory_freq\",\n            available=f\"{mem_info.available_gb}GB\",\n            total=f\"{mem_info.total_gb}GB\",\n            freq=mem_info.frequency_mhz,\n            type=mem_info.type,\n        )\n    else:\n        mem_value = t(\"doctor_memory_info\", available=f\"{mem_info.available_gb}GB\", total=f\"{mem_info.total_gb}GB\")\n\n    ram_ok = mem_info.total_gb >= 32\n    checks.append(\n        {\n            \"name\": t(\"doctor_check_memory\"),\n            \"status\": \"ok\" if ram_ok else \"warning\",\n            \"value\": mem_value,\n            \"hint\": \"32GB+ RAM recommended for large models\" if not ram_ok else None,\n        }\n    )\n\n    # 8. Disk space - check all model paths\n    settings = get_settings()\n    model_paths = settings.get_model_paths()\n\n    # Check all configured model paths\n    for i, disk_path in enumerate(model_paths):\n        available_disk, total_disk = detect_disk_space_gb(str(disk_path))\n        disk_ok = available_disk >= 100\n\n        # For multiple paths, add index to name\n        path_label = f\"Model Path {i+1}\" if len(model_paths) > 1 else t(\"doctor_check_disk\")\n\n        checks.append(\n            {\n                \"name\": path_label,\n                \"status\": \"ok\" if disk_ok else \"warning\",\n                \"value\": t(\"doctor_disk_info\", available=f\"{available_disk}GB\", path=str(disk_path)),\n                \"hint\": \"100GB+ free space recommended for model storage\" if not disk_ok else None,\n            }\n        )\n\n    # 6. 
Required packages\n    packages = [\n        (\"kt-kernel\", \">=0.4.0\", False),  # name, version_req, required\n        (\"sglang\", \">=0.4.0\", False),\n        (\"torch\", \">=2.4.0\", True),\n        (\"transformers\", \">=4.45.0\", True),\n    ]\n\n    package_issues = []\n    for pkg_name, version_req, required in packages:\n        version = get_installed_package_version(pkg_name)\n        if version:\n            package_issues.append((pkg_name, version, \"ok\"))\n        elif required:\n            package_issues.append((pkg_name, t(\"version_not_installed\"), \"error\"))\n            issues_found = True\n        else:\n            package_issues.append((pkg_name, t(\"version_not_installed\"), \"warning\"))\n\n    if verbose:\n        checks.append(\n            {\n                \"name\": t(\"doctor_check_packages\"),\n                \"status\": \"ok\" if not any(p[2] == \"error\" for p in package_issues) else \"error\",\n                \"value\": f\"{sum(1 for p in package_issues if p[2] == 'ok')}/{len(package_issues)} installed\",\n                \"packages\": package_issues,\n            }\n        )\n\n    # 7. 
SGLang installation source check\n    from kt_kernel.cli.utils.sglang_checker import check_sglang_installation, check_sglang_kt_kernel_support\n\n    sglang_info = check_sglang_installation()\n\n    if sglang_info[\"installed\"]:\n        if sglang_info.get(\"is_kvcache_fork\"):\n            # Package name is sglang-kt — this is definitively the kvcache-ai fork\n            if sglang_info[\"from_source\"] and sglang_info[\"git_info\"]:\n                git_remote = sglang_info[\"git_info\"].get(\"remote\", \"unknown\")\n                git_branch = sglang_info[\"git_info\"].get(\"branch\", \"unknown\")\n                sglang_source_value = f\"sglang-kt (Source: {git_remote}, branch: {git_branch})\"\n            elif sglang_info[\"editable\"]:\n                sglang_source_value = \"sglang-kt (editable)\"\n            else:\n                sglang_source_value = \"sglang-kt\"\n            sglang_source_status = \"ok\"\n            sglang_source_hint = None\n        elif sglang_info[\"from_source\"]:\n            if sglang_info[\"git_info\"]:\n                git_remote = sglang_info[\"git_info\"].get(\"remote\", \"unknown\")\n                git_branch = sglang_info[\"git_info\"].get(\"branch\", \"unknown\")\n                sglang_source_value = f\"Source (GitHub: {git_remote}, branch: {git_branch})\"\n                sglang_source_status = \"ok\"\n                sglang_source_hint = None\n            else:\n                sglang_source_value = \"Source (editable)\"\n                sglang_source_status = \"ok\"\n                sglang_source_hint = None\n        else:\n            sglang_source_value = \"PyPI sglang (not kvcache-ai fork)\"\n            sglang_source_status = \"warning\"\n            sglang_source_hint = t(\"sglang_pypi_hint\")\n    else:\n        sglang_source_value = \"Not installed\"\n        sglang_source_status = \"warning\"\n        sglang_source_hint = t(\"sglang_install_hint\")\n\n    checks.append(\n        {\n            \"name\": 
\"SGLang Source\",\n            \"status\": sglang_source_status,\n            \"value\": sglang_source_value,\n            \"hint\": sglang_source_hint,\n        }\n    )\n\n    # 7b. SGLang kt-kernel support check (only if SGLang is installed)\n    kt_kernel_support = {\"supported\": True}  # Default to True if not checked\n    if sglang_info[\"installed\"]:\n        # Use cache=False to force re-check in doctor, but silent=True since we show in table\n        kt_kernel_support = check_sglang_kt_kernel_support(use_cache=False, silent=True)\n\n        if kt_kernel_support[\"supported\"]:\n            kt_kernel_value = t(\"sglang_kt_kernel_supported\")\n            kt_kernel_status = \"ok\"\n            kt_kernel_hint = None\n        else:\n            kt_kernel_value = t(\"sglang_kt_kernel_not_supported\")\n            kt_kernel_status = \"error\"\n            kt_kernel_hint = \"Reinstall SGLang: pip uninstall sglang -y && pip install sglang-kt (or run ./install.sh from ktransformers root)\"\n            issues_found = True\n\n        checks.append(\n            {\n                \"name\": \"SGLang kt-kernel\",\n                \"status\": kt_kernel_status,\n                \"value\": kt_kernel_value,\n                \"hint\": kt_kernel_hint,\n            }\n        )\n\n    # 8. 
Environment managers\n    env_managers = detect_env_managers()\n    docker = check_docker()\n    env_list = [f\"{m.name} {m.version}\" for m in env_managers]\n    if docker:\n        env_list.append(f\"docker {docker.version}\")\n\n    checks.append(\n        {\n            \"name\": \"Environment Managers\",\n            \"status\": \"ok\" if env_list else \"warning\",\n            \"value\": \", \".join(env_list) if env_list else \"None found\",\n            \"hint\": \"conda or docker recommended for installation\" if not env_list else None,\n        }\n    )\n\n    # Display results\n    _display_results(checks, verbose)\n\n    # Show SGLang installation instructions if not installed\n    if not sglang_info[\"installed\"]:\n        from kt_kernel.cli.utils.sglang_checker import print_sglang_install_instructions\n\n        console.print()\n        print_sglang_install_instructions()\n    # Show kt-kernel installation instructions if SGLang is installed but doesn't support kt-kernel\n    elif sglang_info[\"installed\"] and not kt_kernel_support.get(\"supported\", True):\n        from kt_kernel.cli.utils.sglang_checker import print_sglang_kt_kernel_instructions\n\n        console.print()\n        print_sglang_kt_kernel_instructions()\n\n    # Summary\n    console.print()\n    if issues_found:\n        print_warning(t(\"doctor_has_issues\"))\n    else:\n        print_success(t(\"doctor_all_ok\"))\n    console.print()\n\n\ndef _check_python_version(version: str) -> bool:\n    \"\"\"Check if Python version meets requirements (3.10 or newer).\"\"\"\n    parts = version.split(\".\")\n    try:\n        major, minor = int(parts[0]), int(parts[1])\n        # Compare as a tuple so a future major version with a low minor (e.g. 4.0) still passes\n        return (major, minor) >= (3, 10)\n    except (IndexError, ValueError):\n        return False\n\n\ndef _display_results(checks: list[dict], verbose: bool) -> None:\n    \"\"\"Display diagnostic results.\"\"\"\n    table = Table(show_header=True, header_style=\"bold\")\n    table.add_column(\"Check\", style=\"bold\")\n    
table.add_column(\"Status\", width=8)\n    table.add_column(\"Value\")\n    if verbose:\n        table.add_column(\"Notes\", style=\"dim\")\n\n    for check in checks:\n        status = check[\"status\"]\n        if status == \"ok\":\n            status_str = f\"[green]{t('doctor_status_ok')}[/green]\"\n        elif status == \"warning\":\n            status_str = f\"[yellow]{t('doctor_status_warning')}[/yellow]\"\n        else:\n            status_str = f\"[red]{t('doctor_status_error')}[/red]\"\n\n        if verbose:\n            table.add_row(\n                check[\"name\"],\n                status_str,\n                check[\"value\"],\n                check.get(\"hint\", \"\"),\n            )\n        else:\n            table.add_row(\n                check[\"name\"],\n                status_str,\n                check[\"value\"],\n            )\n\n        # Show package details if verbose\n        if verbose and \"packages\" in check:\n            for pkg_name, pkg_version, pkg_status in check[\"packages\"]:\n                if pkg_status == \"ok\":\n                    pkg_status_str = \"[green]✓[/green]\"\n                elif pkg_status == \"warning\":\n                    pkg_status_str = \"[yellow]○[/yellow]\"\n                else:\n                    pkg_status_str = \"[red]✗[/red]\"\n\n                table.add_row(\n                    f\"  └─ {pkg_name}\",\n                    pkg_status_str,\n                    pkg_version,\n                    \"\",\n                )\n\n    console.print(table)\n"
  },
  {
    "path": "kt-kernel/python/cli/commands/model.py",
    "content": "\"\"\"\nModel command for kt-cli.\n\nManages models: download, list, and storage paths.\n\"\"\"\n\nimport os\nfrom pathlib import Path\nfrom typing import Optional, List\n\nimport typer\n\nfrom kt_kernel.cli.config.settings import get_settings\nfrom kt_kernel.cli.i18n import t, get_lang\nfrom kt_kernel.cli.utils.console import (\n    confirm,\n    console,\n    print_error,\n    print_info,\n    print_step,\n    print_success,\n    print_warning,\n    prompt_choice,\n)\n\n\n# Common SHA256 status display mapping used across multiple commands\nSHA256_STATUS_MAP = {\n    \"not_checked\": \"[dim]Not Checked[/dim]\",\n    \"checking\": \"[yellow]Checking...[/yellow]\",\n    \"passed\": \"[green]✓ Passed[/green]\",\n    \"failed\": \"[red]✗ Failed[/red]\",\n    \"no_repo\": \"[dim]-[/dim]\",\n}\n\n# Plain text version for panels and verbose output\nSHA256_STATUS_MAP_PLAIN = {\n    \"not_checked\": \"Not Checked\",\n    \"checking\": \"Checking...\",\n    \"passed\": \"✓ Passed\",\n    \"failed\": \"✗ Failed\",\n    \"no_repo\": \"-\",\n}\n\n\ndef is_amx_weights(model_path) -> tuple[bool, int]:\n    \"\"\"\n    Determine if a model uses AMX weights and count NUMA nodes.\n\n    Returns:\n        (is_amx, numa_count): Tuple where is_amx indicates AMX weights,\n        and numa_count is the number of NUMA nodes (0 if not AMX).\n    \"\"\"\n    import re\n    from pathlib import Path\n    from safetensors import safe_open\n\n    model_path = Path(model_path)\n    safetensors_files = sorted(model_path.glob(\"*.safetensors\"))\n\n    if not safetensors_files:\n        return False, 0\n\n    numa_indices = set()\n    numa_pattern = re.compile(r\"\\.numa\\.(\\d+)\\.\")\n\n    # Check first 3 files for NUMA keys\n    for file_path in safetensors_files[:3]:\n        try:\n            with safe_open(file_path, framework=\"pt\", device=\"cpu\") as f:\n                for key in f.keys():\n                    if \".numa.\" in key:\n                        match = 
numa_pattern.search(key)\n                        if match:\n                            numa_indices.add(int(match.group(1)))\n        except Exception:\n            continue\n\n    if not numa_indices:\n        return False, 0\n\n    return True, len(numa_indices)\n\n\napp = typer.Typer(\n    help=\"Manage models and storage paths\",\n    invoke_without_command=True,\n    no_args_is_help=False,\n)\n\n\n@app.callback()\ndef callback(ctx: typer.Context) -> None:\n    \"\"\"\n    Model management commands.\n\n    Run without arguments to see available models.\n    \"\"\"\n    # If no subcommand is provided, show the full model list\n    if ctx.invoked_subcommand is None:\n        list_models(verbose=False, all_models=False, show_moe=True, no_cache=False)\n\n\n@app.command(name=\"download\")\ndef download(\n    repo: Optional[str] = typer.Argument(None, help=\"Repository ID (optional, interactive mode if not provided)\"),\n    local_dir: Optional[str] = typer.Option(\n        None,\n        \"--local-dir\",\n        \"-d\",\n        help=\"Local directory to download to (default: auto-detect from config)\",\n    ),\n    repo_type: Optional[str] = typer.Option(\n        None,\n        \"--repo-type\",\n        \"-t\",\n        help=\"Repository type: huggingface or modelscope\",\n    ),\n    resume: bool = typer.Option(\n        True,\n        \"--resume/--no-resume\",\n        help=\"Resume incomplete downloads\",\n    ),\n    yes: bool = typer.Option(\n        False,\n        \"--yes\",\n        \"-y\",\n        help=\"Skip all prompts and use defaults\",\n    ),\n) -> None:\n    \"\"\"Download model from HuggingFace or ModelScope (interactive mode).\"\"\"\n    import subprocess\n    import os\n    from pathlib import Path\n    from rich.prompt import Prompt, Confirm\n    from rich.table import Table\n    from kt_kernel.cli.utils.user_model_registry import UserModelRegistry, UserModel\n    from kt_kernel.cli.utils.model_scanner import scan_single_path, format_size\n 
   from kt_kernel.cli.utils.model_verifier import check_huggingface_connectivity\n    from kt_kernel.cli.utils.download_helper import (\n        list_remote_files_hf,\n        list_remote_files_ms,\n        filter_files_by_pattern,\n        calculate_total_size,\n        format_file_list_table,\n        verify_repo_exists,\n    )\n\n    settings = get_settings()\n    user_registry = UserModelRegistry()\n\n    console.print()\n\n    # ========== Step 1: Select repository type ==========\n    if not repo_type and not yes:\n        console.print(\"[bold cyan]Step 1: Select Repository Source[/bold cyan]\\n\")\n        console.print(\"  1. HuggingFace\")\n        console.print(\"  2. ModelScope\")\n        console.print()\n\n        choice = Prompt.ask(\"Select source\", choices=[\"1\", \"2\"], default=\"1\")\n        repo_type = \"huggingface\" if choice == \"1\" else \"modelscope\"\n        console.print()\n    elif not repo_type:\n        repo_type = \"huggingface\"  # Default for --yes mode\n\n    # Validate repo_type\n    if repo_type not in [\"huggingface\", \"modelscope\"]:\n        print_error(f\"Invalid repo type: {repo_type}. 
Must be 'huggingface' or 'modelscope'\")\n        raise typer.Exit(1)\n\n    # Check HuggingFace connectivity and auto-switch to mirror if needed\n    use_mirror = False\n    if repo_type == \"huggingface\":\n        with console.status(\"[dim]Checking HuggingFace connectivity...[/dim]\"):\n            is_accessible, message = check_huggingface_connectivity(timeout=5)\n\n        if not is_accessible:\n            print_warning(\"HuggingFace Connection Failed\")\n            console.print()\n            console.print(f\"  {message}\")\n            console.print()\n            console.print(\"  [yellow]Auto-switching to HuggingFace mirror:[/yellow] [cyan]hf-mirror.com[/cyan]\")\n            console.print()\n            use_mirror = True\n\n    # ========== Step 2: Input repository ID ==========\n    while True:\n        if not repo and not yes:\n            console.print(\"[bold cyan]Step 2: Enter Repository ID[/bold cyan]\\n\")\n            console.print(\"  Examples:\")\n            console.print(\"    • HuggingFace: deepseek-ai/DeepSeek-V3\")\n            console.print(\"    • ModelScope: Qwen/Qwen3-Coder-480B-A35B-Instruct\")\n            console.print()\n\n            repo = Prompt.ask(\"Repository ID\")\n            console.print()\n        elif not repo:\n            print_error(\"Repository ID is required\")\n            raise typer.Exit(1)\n\n        # Verify repository exists\n        with console.status(f\"[dim]Verifying repository: {repo}...[/dim]\"):\n            exists, msg = verify_repo_exists(repo, repo_type, use_mirror)\n\n        if exists:\n            print_success(f\"✓ Repository found: {repo}\")\n            console.print()\n            break\n        else:\n            print_error(msg)\n            console.print()\n            if yes:\n                raise typer.Exit(1)\n            repo = None  # Reset to ask again\n\n    # ========== Step 3: Input file pattern and preview files ==========\n    files_to_download = []\n    file_pattern = 
\"*\"\n\n    while True:\n        if not yes:\n            console.print(\"[bold cyan]Step 3: Select Files to Download[/bold cyan]\\n\")\n            console.print(\"  File pattern (glob syntax):\")\n            console.print(\"    • *                  - All files (default)\")\n            console.print(\"    • *.safetensors      - Only safetensors files\")\n            console.print(\"    • *.gguf             - Only GGUF files\")\n            console.print(\"    • *Q4_K_M.gguf       - Specific GGUF quant\")\n            console.print()\n\n            pattern_input = Prompt.ask(\"File pattern\", default=\"*\")\n            file_pattern = pattern_input\n            console.print()\n\n        # Fetch remote file list\n        with console.status(f\"[dim]Fetching file list from {repo_type}...[/dim]\"):\n            try:\n                if repo_type == \"huggingface\":\n                    all_files = list_remote_files_hf(repo, use_mirror)\n                else:\n                    all_files = list_remote_files_ms(repo)\n\n                files_to_download = filter_files_by_pattern(all_files, file_pattern)\n            except Exception as e:\n                print_error(f\"Failed to fetch file list: {e}\")\n                raise typer.Exit(1)\n\n        if not files_to_download:\n            print_warning(f\"No files match pattern: {file_pattern}\")\n            console.print()\n            if yes:\n                raise typer.Exit(1)\n            continue  # Ask for pattern again\n\n        # Display matched files\n        total_size = calculate_total_size(files_to_download)\n        print_success(f\"Found {len(files_to_download)} files (Total: {format_size(total_size)})\")\n        console.print()\n\n        file_table = format_file_list_table(files_to_download, max_display=10)\n        console.print(file_table)\n        console.print()\n\n        # Confirm or retry\n        if yes:\n            break\n\n        action = Prompt.ask(\"Action\", 
choices=[\"continue\", \"retry\", \"cancel\"], default=\"continue\")\n\n        if action == \"continue\":\n            console.print()\n            break\n        elif action == \"cancel\":\n            console.print()\n            print_info(\"Download cancelled\")\n            console.print()\n            return\n        # else retry - loop continues\n\n    # ========== Step 4: Select download path ==========\n    download_path = None\n\n    if local_dir:\n        download_path = Path(os.path.expanduser(local_dir)).resolve()\n    elif not yes:\n        console.print(\"[bold cyan]Step 4: Select Download Location[/bold cyan]\\n\")\n\n        # Get configured model paths\n        model_paths = settings.get_model_paths()\n        if not model_paths:\n            print_error(\"No model storage paths configured.\")\n            console.print()\n            console.print(f\"  Add a path with: [cyan]kt model path-add <path>[/cyan]\")\n            console.print()\n            raise typer.Exit(1)\n\n        # Display configured paths\n        console.print(\"  Configured storage paths:\")\n        for i, path in enumerate(model_paths, 1):\n            console.print(f\"    {i}. {path}\")\n        console.print(f\"    {len(model_paths) + 1}. 
Custom path (manual input)\")\n        console.print()\n\n        path_choice = Prompt.ask(\"Select path\", choices=[str(i) for i in range(1, len(model_paths) + 2)], default=\"1\")\n\n        if int(path_choice) <= len(model_paths):\n            base_path = model_paths[int(path_choice) - 1]\n        else:\n            custom = Prompt.ask(\"Enter custom path\")\n            base_path = Path(os.path.expanduser(custom)).resolve()\n\n        console.print()\n\n        # Ask for folder name\n        default_folder = repo.split(\"/\")[-1]\n        folder_name = Prompt.ask(\"Folder name\", default=default_folder)\n\n        download_path = base_path / folder_name\n        console.print()\n    else:\n        # --yes mode: use default\n        model_paths = settings.get_model_paths()\n        if not model_paths:\n            print_error(\"No model storage paths configured.\")\n            raise typer.Exit(1)\n\n        default_folder = repo.split(\"/\")[-1]\n        download_path = model_paths[0] / default_folder\n\n    # ========== Step 5: Confirm and download ==========\n    print_info(f\"Download destination: {download_path}\")\n    console.print()\n\n    # Check if path exists\n    if download_path.exists():\n        existing = user_registry.find_by_path(str(download_path))\n        if existing:\n            print_warning(f\"Model already registered as: {existing.name}\")\n            console.print()\n            if not yes and not Confirm.ask(\"Re-download anyway?\", default=False):\n                return\n        else:\n            print_warning(f\"Directory already exists: {download_path}\")\n            if not yes and not Confirm.ask(\"Overwrite?\", default=False):\n                return\n        console.print()\n\n    # Final confirmation\n    if not yes:\n        console.print(\"[bold]Download Summary:[/bold]\")\n        console.print(f\"  Source:      {repo_type}:{repo}\")\n        console.print(\n            f\"  Files:       {len(files_to_download)} files 
({format_size(calculate_total_size(files_to_download))})\"\n        )\n        console.print(f\"  Pattern:     {file_pattern}\")\n        console.print(f\"  Destination: {download_path}\")\n        console.print()\n\n        if not Confirm.ask(\"Start download?\", default=True):\n            console.print()\n            print_info(\"Download cancelled\")\n            console.print()\n            return\n\n    # Download\n    console.print()\n    print_step(\"Downloading model files...\")\n    console.print()\n\n    # Set mirror for HuggingFace if needed\n    original_hf_endpoint = os.environ.get(\"HF_ENDPOINT\")\n    if use_mirror and repo_type == \"huggingface\" and not original_hf_endpoint:\n        os.environ[\"HF_ENDPOINT\"] = \"https://hf-mirror.com\"\n\n    try:\n        if repo_type == \"huggingface\":\n            from huggingface_hub import snapshot_download\n\n            snapshot_download(\n                repo_id=repo,\n                local_dir=str(download_path),\n                allow_patterns=file_pattern if file_pattern != \"*\" else None,\n                local_dir_use_symlinks=False,\n                resume_download=resume,\n            )\n\n        else:  # modelscope\n            from modelscope.hub.snapshot_download import snapshot_download\n\n            snapshot_download(\n                model_id=repo,\n                local_dir=str(download_path),\n                allow_file_pattern=file_pattern if file_pattern != \"*\" else None,\n            )\n\n    except ImportError as e:\n        pkg = \"huggingface_hub\" if repo_type == \"huggingface\" else \"modelscope\"\n        print_error(f\"{pkg} not installed. 
Install: pip install {pkg}\")\n        raise typer.Exit(1)\n    except Exception as e:\n        print_error(f\"Download failed: {e}\")\n        raise typer.Exit(1)\n    finally:\n        # Restore HF_ENDPOINT\n        if use_mirror and repo_type == \"huggingface\" and not original_hf_endpoint:\n            os.environ.pop(\"HF_ENDPOINT\", None)\n        elif original_hf_endpoint:\n            os.environ[\"HF_ENDPOINT\"] = original_hf_endpoint\n\n    # ========== Step 6: Scan and register ==========\n    console.print()\n    print_success(\"Download complete!\")\n\n    console.print()\n    print_step(\"Scanning downloaded model...\")\n\n    try:\n        scanned = scan_single_path(download_path)\n    except Exception as e:\n        print_error(f\"Failed to scan model: {e}\")\n        console.print()\n        console.print(f\"  You can manually add it: [cyan]kt model add {download_path}[/cyan]\")\n        console.print()\n        raise typer.Exit(1)\n\n    if not scanned:\n        print_warning(\"No model files found in downloaded directory.\")\n        console.print()\n        console.print(\"  Supported formats: .safetensors, .gguf\")\n        console.print()\n        return\n\n    # Auto-generate model name\n    model_name = download_path.name\n    if user_registry.check_name_conflict(model_name):\n        model_name = user_registry.suggest_name(model_name)\n\n    # Create and register model\n    user_model = UserModel(\n        name=model_name,\n        path=str(download_path),\n        format=scanned.format,\n        repo_type=repo_type,\n        repo_id=repo,\n        sha256_status=\"not_checked\",\n    )\n\n    try:\n        user_registry.add_model(user_model)\n        console.print()\n        print_success(f\"Model registered as: {model_name}\")\n        console.print()\n        console.print(f\"  View details:     [cyan]kt model info {model_name}[/cyan]\")\n        console.print(f\"  Run model:        [cyan]kt run {model_name}[/cyan]\")\n        
console.print(f\"  Verify integrity: [cyan]kt model verify {model_name}[/cyan]\")\n        console.print()\n    except Exception as e:\n        print_error(f\"Failed to register model: {e}\")\n        console.print()\n        console.print(f\"  You can manually add it: [cyan]kt model add {download_path}[/cyan]\")\n        console.print()\n        raise typer.Exit(1)\n\n\n@app.command(name=\"list\")\ndef list_models(\n    verbose: bool = typer.Option(False, \"--verbose\", \"-v\", help=\"Show detailed info including paths\"),\n    all_models: bool = typer.Option(False, \"--all\", help=\"Show all models (reserved for future use)\"),\n    show_moe: bool = typer.Option(True, \"--moe/--no-moe\", help=\"Show MoE model information (default: enabled)\"),\n    no_cache: bool = typer.Option(False, \"--no-cache\", help=\"Force re-analyze all models, ignore cache\"),\n) -> None:\n    \"\"\"List user-registered models.\"\"\"\n    from rich.table import Table\n    from rich.panel import Panel\n    from kt_kernel.cli.utils.user_model_registry import UserModelRegistry\n    from kt_kernel.cli.utils.model_scanner import format_size\n    import sys\n    from pathlib import Path as PathLib\n\n    # Try to import analyze_moe_model from multiple locations\n    analyze_moe_model = None\n    try:\n        # Try 1: From kt_kernel.cli.utils\n        from kt_kernel.cli.utils.analyze_moe_model import analyze_moe_model\n    except ImportError:\n        try:\n            # Try 2: From parent directories\n            analyze_moe_path = PathLib(__file__).parent.parent.parent.parent.parent.parent / \"analyze_moe_model.py\"\n            if analyze_moe_path.exists():\n                sys.path.insert(0, str(analyze_moe_path.parent))\n                from analyze_moe_model import analyze_moe_model\n        except (ImportError, Exception):\n            try:\n                # Try 3: Absolute path\n                sys.path.insert(0, \"/mnt/data2/ljq/ktransformers\")\n                from 
analyze_moe_model import analyze_moe_model\n            except (ImportError, Exception):\n                analyze_moe_model = None\n\n    registry = UserModelRegistry()\n    models = registry.list_models()\n\n    console.print()\n\n    if not models:\n        print_warning(t(\"model_no_registered_models\"))\n        console.print()\n        console.print(f\"  {t('model_scan_hint')} [cyan]kt model scan[/cyan]\")\n        console.print(f\"  {t('model_add_hint')} [cyan]kt model add <path>[/cyan]\")\n        console.print()\n        return\n\n    # Check for models with non-existent paths and remove them automatically\n    models_to_remove = []\n    for model in models:\n        if not model.path_exists():\n            models_to_remove.append(model)\n\n    if models_to_remove:\n        console.print(f\"[yellow]Found {len(models_to_remove)} model(s) with non-existent paths:[/yellow]\")\n        for model in models_to_remove:\n            console.print(f\"  [dim]✗ {model.name}: {model.path}[/dim]\")\n            registry.remove_model(model.name)\n        console.print(f\"[green]✓ Automatically removed {len(models_to_remove)} model(s) with missing paths[/green]\")\n        console.print()\n\n        # Refresh the models list\n        models = registry.list_models()\n\n        if not models:\n            console.print(f\"[dim]No models remaining after cleanup.[/dim]\")\n            console.print()\n            console.print(f\"  {t('model_scan_hint')} [cyan]kt model scan[/cyan]\")\n            console.print(f\"  {t('model_add_hint')} [cyan]kt model add <path>[/cyan]\")\n            console.print()\n            return\n\n    if verbose:\n        # Verbose mode: detailed cards\n        console.print(f\"[bold cyan]{t('model_registered_models_title')}[/bold cyan]\\n\")\n\n        for i, model in enumerate(models, 1):\n            # Check if path exists\n            path_status = \"[green]✓ Exists[/green]\" if model.path_exists() else \"[red]✗ Missing[/red]\"\n\n            # 
Format repo info\n            if model.repo_id:\n                repo_abbr = \"hf\" if model.repo_type == \"huggingface\" else \"ms\"\n                repo_info = f\"{repo_abbr}:{model.repo_id}\"\n            else:\n                repo_info = \"-\"\n\n            # Format SHA256 status\n            sha256_display = SHA256_STATUS_MAP_PLAIN.get(model.sha256_status, model.sha256_status)\n\n            # Calculate folder size if exists\n            if model.path_exists():\n                from pathlib import Path\n\n                path_obj = Path(model.path)\n                try:\n                    if model.format == \"safetensors\":\n                        files = list(path_obj.glob(\"*.safetensors\"))\n                    else:\n                        files = list(path_obj.glob(\"*.gguf\"))\n\n                    total_size = sum(f.stat().st_size for f in files if f.exists())\n                    size_str = format_size(total_size)\n                    file_count = len(files)\n                    size_info = f\"{size_str} ({file_count} files)\"\n                except:\n                    size_info = \"Unknown\"\n            else:\n                size_info = \"-\"\n\n            # Create panel content\n            content = f\"\"\"[bold]Path:[/bold]   {model.path}\n[bold]Format:[/bold] {model.format}\n[bold]Repo:[/bold]   {repo_info}\n[bold]SHA256:[/bold] {sha256_display}\n[bold]Size:[/bold]   {size_info}\n[bold]Status:[/bold] {path_status}\"\"\"\n\n            panel = Panel(content, title=f\"[cyan]{model.name}[/cyan]\", border_style=\"cyan\", padding=(0, 1))\n            console.print(panel)\n\n        console.print()\n        console.print(f\"[dim]Total: {len(models)} model(s)[/dim]\\n\")\n    else:\n        # Compact mode: separate tables by model type\n        from rich.align import Align\n        from pathlib import Path\n\n        # Categorize models\n        gguf_models = []\n        amx_models = []\n        gpu_models = []\n\n        for model in 
models:\n            if model.format == \"gguf\":\n                gguf_models.append(model)\n            elif model.format == \"safetensors\" and model.path_exists():\n                is_amx, numa_count = is_amx_weights(model.path)\n                if is_amx:\n                    amx_models.append((model, numa_count))\n                else:\n                    gpu_models.append(model)\n            else:\n                gpu_models.append(model)\n\n        # Pre-analyze GPU MoE models concurrently if enabled\n        moe_results = {}\n        moe_failed_models = []  # Track models that failed MoE analysis\n        if show_moe and analyze_moe_model and gpu_models:\n            from concurrent.futures import ThreadPoolExecutor, as_completed\n            import threading\n\n            # Collect GPU models that need MoE analysis\n            # Priority: use cached MoE info from UserModel, only analyze if is_moe is None\n            models_to_analyze = []\n            models_need_update = []  # Track models that need registry update\n\n            for model in gpu_models:\n                # Check if MoE info is already cached in UserModel (and not using --no-cache)\n                if not no_cache and model.is_moe is not None:\n                    # Use cached info from UserModel\n                    if model.is_moe:\n                        moe_results[model.name] = {\n                            \"is_moe\": True,\n                            \"num_experts\": model.moe_num_experts,\n                            \"num_experts_per_tok\": model.moe_num_experts_per_tok,\n                            \"cached\": True,\n                        }\n                    # If is_moe is False, don't add to moe_results\n                else:\n                    # Need to analyze (is_moe is None or --no-cache)\n                    path_obj = Path(model.path)\n                    models_to_analyze.append((model.name, str(path_obj)))\n                    
models_need_update.append(model)\n\n            if models_to_analyze:\n                # Use lock for thread-safe console output\n                print_lock = threading.Lock()\n                completed_count = [0]  # Use list to allow modification in nested function\n\n                def analyze_with_progress(model_info):\n                    model_name, model_path = model_info\n                    try:\n                        with print_lock:\n                            console.print(f\"[dim]Analyzing MoE: {model_name}...[/dim]\")\n                        result = analyze_moe_model(model_path, use_cache=not no_cache)\n\n                        # Check if analysis returned valid results\n                        if result is None or result.get(\"num_experts\", 0) == 0:\n                            with print_lock:\n                                completed_count[0] += 1\n                                console.print(\n                                    f\"[dim]✗ [{completed_count[0]}/{len(models_to_analyze)}] {model_name} - Not a MoE model or analysis failed[/dim]\"\n                                )\n                            return (model_name, None, \"Not a MoE model or analysis failed\")\n\n                        with print_lock:\n                            completed_count[0] += 1\n                            cached_tag = \"[green](cached)[/green]\" if result and result.get(\"cached\") else \"\"\n                            console.print(\n                                f\"[dim]✓ [{completed_count[0]}/{len(models_to_analyze)}] {model_name} {cached_tag}[/dim]\"\n                            )\n                        return (model_name, result, None)\n                    except Exception as e:\n                        with print_lock:\n                            completed_count[0] += 1\n                            error_msg = str(e)[:80]\n                            console.print(\n                                f\"[dim]✗ 
[{completed_count[0]}/{len(models_to_analyze)}] {model_name} - Error: {error_msg}[/dim]\"\n                            )\n                        return (model_name, None, error_msg)\n\n                if no_cache:\n                    console.print(f\"\\n[yellow]Force re-analyzing (--no-cache): ignoring cached results[/yellow]\")\n                console.print(\n                    f\"\\n[cyan]Analyzing {len(models_to_analyze)} MoE model(s) with {min(16, len(models_to_analyze))} threads...[/cyan]\\n\"\n                )\n\n                # Analyze concurrently with up to 16 workers\n                with ThreadPoolExecutor(max_workers=16) as executor:\n                    futures = {\n                        executor.submit(analyze_with_progress, model_info): model_info\n                        for model_info in models_to_analyze\n                    }\n\n                    for future in as_completed(futures):\n                        model_name, result, error = future.result()\n                        if error:\n                            # Find the model object\n                            failed_model = next((m for m in gpu_models if m.name == model_name), None)\n                            if failed_model:\n                                moe_failed_models.append((failed_model, error))\n                                # Update model registry: mark as non-MoE\n                                registry.update_model(model_name, {\"is_moe\": False})\n                        else:\n                            moe_results[model_name] = result\n                            # Update model registry with MoE info\n                            if result and result.get(\"is_moe\"):\n                                registry.update_model(\n                                    model_name,\n                                    {\n                                        \"is_moe\": True,\n                                        \"moe_num_experts\": result.get(\"num_experts\"),\n  
                                      \"moe_num_experts_per_tok\": result.get(\"num_experts_per_tok\"),\n                                    },\n                                )\n                            else:\n                                registry.update_model(model_name, {\"is_moe\": False})\n\n                console.print(f\"\\n[green]✓ MoE analysis complete[/green]\\n\")\n\n                # Remove failed models from gpu_models list\n                if moe_failed_models:\n                    failed_names = {m.name for m, _ in moe_failed_models}\n                    gpu_models = [m for m in gpu_models if m.name not in failed_names]\n\n        # Separate MoE and non-MoE GPU models\n        moe_gpu_models = []\n        non_moe_gpu_models = []\n        for model in gpu_models:\n            if model.name in moe_results:\n                moe_gpu_models.append(model)\n            else:\n                non_moe_gpu_models.append(model)\n\n        # Count failed MoE models (these are also non-MoE)\n        total_non_moe_count = len(non_moe_gpu_models) + len(moe_failed_models)\n\n        # Filter display based on --all flag\n        if not all_models:\n            # Default: only show MoE models\n            gpu_models_to_display = moe_gpu_models\n            show_failed_table = False\n        else:\n            # --all: show all GPU models including non-MoE and failed\n            gpu_models_to_display = gpu_models\n            show_failed_table = True\n            total_non_moe_count = 0  # Don't show hint when displaying all\n\n        # Helper function to create table rows\n        def format_model_row(model, moe_info=None, numa_count=None):\n            from kt_kernel.cli.utils.model_scanner import format_size\n\n            # Calculate size\n            if model.path_exists():\n                path_obj = Path(model.path)\n                try:\n                    if model.format == \"safetensors\":\n                        files = 
list(path_obj.glob(\"*.safetensors\"))\n                    else:\n                        files = list(path_obj.glob(\"*.gguf\"))\n\n                    total_size = sum(f.stat().st_size for f in files if f.exists())\n                    size_display = format_size(total_size)\n                except:\n                    size_display = \"[dim]-[/dim]\"\n            else:\n                size_display = \"[dim]-[/dim]\"\n\n            # Format repo info\n            if model.repo_id:\n                repo_abbr = \"hf\" if model.repo_type == \"huggingface\" else \"ms\"\n                repo_display = f\"{repo_abbr}:{model.repo_id}\"\n            else:\n                repo_display = \"[dim]-[/dim]\"\n\n            # Format SHA256 status\n            sha256_display = SHA256_STATUS_MAP.get(model.sha256_status, model.sha256_status)\n\n            row = [model.name, model.path, size_display]\n\n            # Add type-specific columns\n            if numa_count is not None:\n                # AMX model\n                row.append(f\"[yellow]{numa_count} NUMA[/yellow]\")\n            elif moe_info:\n                # GPU MoE model\n                experts_display = f\"[yellow]{moe_info['num_experts']}[/yellow]\"\n                activated_display = f\"[green]{moe_info['num_experts_per_tok']}[/green]\"\n                moe_total_display = f\"[cyan]{size_display}[/cyan]\"\n                row.extend([experts_display, activated_display, moe_total_display])\n            elif show_moe and analyze_moe_model and model.format == \"safetensors\":\n                # GPU non-MoE model\n                row.extend([\"[dim]-[/dim]\", \"[dim]-[/dim]\", \"[dim]-[/dim]\"])\n\n            row.extend([repo_display, sha256_display])\n            return row\n\n        # Display tables\n        title = Align.center(f\"[bold cyan]{t('model_registered_models_title')}[/bold cyan]\")\n        console.print(title)\n        console.print()\n\n        # Table 1: GGUF Models (Llamafile)\n        if 
gguf_models:\n            console.print(\"[bold yellow]GGUF Models (Llamafile)[/bold yellow]\")\n            table = Table(show_header=True, header_style=\"bold\")\n            table.add_column(\"#\", justify=\"right\", style=\"cyan\", no_wrap=True)\n            table.add_column(t(\"model_column_name\"), style=\"cyan\", no_wrap=True)\n            table.add_column(\"Path\", style=\"dim\", overflow=\"fold\")\n            table.add_column(\"Total\", justify=\"right\")\n            table.add_column(t(\"model_column_repo\"), style=\"dim\", overflow=\"fold\")\n            table.add_column(t(\"model_column_sha256\"), justify=\"center\")\n\n            for i, model in enumerate(gguf_models, 1):\n                row = [str(i)] + format_model_row(model)\n                table.add_row(*row)\n\n            console.print(table)\n            console.print()\n\n        # Table 2: AMX Models\n        if amx_models:\n            from kt_kernel.cli.utils.model_scanner import format_size\n            import json\n\n            console.print(\"[bold magenta]AMX Models (CPU)[/bold magenta]\")\n            table = Table(show_header=True, header_style=\"bold\", show_lines=False)\n            table.add_column(\"#\", justify=\"right\", style=\"cyan\", no_wrap=True)\n            table.add_column(t(\"model_column_name\"), style=\"cyan\", no_wrap=True)\n            table.add_column(\"Path\", style=\"dim\", overflow=\"fold\")\n            table.add_column(\"Total\", justify=\"right\")\n            table.add_column(\"Method\", justify=\"center\", style=\"yellow\")\n            table.add_column(\"NUMA\", justify=\"center\", style=\"green\")\n            table.add_column(\"Source\", style=\"dim\", overflow=\"fold\")\n\n            # Build reverse map: AMX model ID -> GPU models using it\n            amx_used_by_gpu = {}  # {amx_model_id: [gpu_model_names]}\n            for model, _ in amx_models:\n                if model.gpu_model_ids:\n                    # This AMX is linked to these GPU 
models\n                    gpu_names = []\n                    for gpu_id in model.gpu_model_ids:\n                        # Find GPU model by ID\n                        for gpu_model in gpu_models:\n                            if gpu_model.id == gpu_id:\n                                gpu_names.append(gpu_model.name)\n                                break\n                    if gpu_names:\n                        amx_used_by_gpu[model.id] = gpu_names\n\n            for i, (model, numa_count) in enumerate(amx_models, 1):\n                # Calculate size\n                if model.path_exists():\n                    path_obj = Path(model.path)\n                    try:\n                        files = list(path_obj.glob(\"*.safetensors\"))\n                        total_size = sum(f.stat().st_size for f in files if f.exists())\n                        size_display = format_size(total_size)\n                    except Exception:\n                        size_display = \"[dim]-[/dim]\"\n                else:\n                    size_display = \"[dim]-[/dim]\"\n\n                # Read AMX metadata from config.json (fallback if not in UserModel)\n                method_from_config = None\n                numa_from_config = None\n                if model.path_exists():\n                    config_path = Path(model.path) / \"config.json\"\n                    if config_path.exists():\n                        try:\n                            with open(config_path, \"r\", encoding=\"utf-8\") as f:\n                                config = json.load(f)\n                                amx_quant = config.get(\"amx_quantization\", {})\n                                if amx_quant.get(\"converted\"):\n                                    method_from_config = amx_quant.get(\"method\")\n                                    numa_from_config = amx_quant.get(\"numa_count\")\n                        except Exception:\n                            # Unreadable or malformed config.json: fall back to other metadata sources\n                            pass\n\n                # AMX-specific metadata 
(priority: UserModel > config.json > detected numa_count)\n                method_display = (\n                    model.amx_quant_method.upper()\n                    if model.amx_quant_method\n                    else method_from_config.upper() if method_from_config else \"[dim]?[/dim]\"\n                )\n                numa_display = (\n                    str(model.amx_numa_nodes)\n                    if model.amx_numa_nodes\n                    else (\n                        str(numa_from_config) if numa_from_config else str(numa_count) if numa_count else \"[dim]?[/dim]\"\n                    )\n                )\n                source_display = model.amx_source_model if model.amx_source_model else \"[dim]-[/dim]\"\n\n                table.add_row(\n                    str(i), model.name, model.path, size_display, method_display, numa_display, source_display\n                )\n\n                # Add linked GPU models info below this AMX model\n                if model.id in amx_used_by_gpu:\n                    gpu_list = amx_used_by_gpu[model.id]\n                    gpu_names_str = \", \".join([f\"[dim]{name}[/dim]\" for name in gpu_list])\n                    # Create a sub-row with empty cells except for the first column (7 columns total with #)\n                    sub_row = [\"\", f\"  [dim]↳ GPU: {gpu_names_str}[/dim]\", \"\", \"\", \"\", \"\", \"\"]\n                    table.add_row(*sub_row, style=\"dim\")\n\n            console.print(table)\n            console.print()\n\n        # Table 3: GPU Models (Safetensors)\n        if gpu_models_to_display:\n            console.print(\"[bold green]GPU Models (Safetensors)[/bold green]\")\n            table = Table(show_header=True, header_style=\"bold\", show_lines=False)\n            table.add_column(\"#\", justify=\"right\", style=\"cyan\", no_wrap=True)\n            table.add_column(t(\"model_column_name\"), style=\"cyan\", no_wrap=True)\n            table.add_column(\"Path\", style=\"dim\", 
overflow=\"fold\")\n            table.add_column(\"Total\", justify=\"right\")\n\n            if show_moe and analyze_moe_model:\n                table.add_column(\"Exps\", justify=\"center\", style=\"yellow\")\n                table.add_column(\"Act\", justify=\"center\", style=\"green\")\n                table.add_column(\"MoE Size\", justify=\"right\", style=\"cyan\")\n\n            table.add_column(t(\"model_column_repo\"), style=\"dim\", overflow=\"fold\")\n            table.add_column(t(\"model_column_sha256\"), justify=\"center\")\n\n            # Build a map of GPU model UUID -> attached CPU models\n            attached_cpu_models = {}  # {gpu_model_id: [(cpu_model, type)]}\n            for model in gguf_models:\n                if model.gpu_model_ids:\n                    for gpu_id in model.gpu_model_ids:\n                        if gpu_id not in attached_cpu_models:\n                            attached_cpu_models[gpu_id] = []\n                        attached_cpu_models[gpu_id].append((model, \"GGUF\"))\n\n            for model, numa_count in amx_models:\n                if model.gpu_model_ids:\n                    for gpu_id in model.gpu_model_ids:\n                        if gpu_id not in attached_cpu_models:\n                            attached_cpu_models[gpu_id] = []\n                        attached_cpu_models[gpu_id].append((model, \"AMX\"))\n\n            for i, model in enumerate(gpu_models_to_display, 1):\n                moe_info = moe_results.get(model.name) if show_moe and analyze_moe_model else None\n                row = [str(i)] + format_model_row(model, moe_info=moe_info)\n                table.add_row(*row)\n\n                # Add attached CPU models info below this GPU model (using UUID matching)\n                if model.id in attached_cpu_models:\n                    cpu_list = attached_cpu_models[model.id]\n                    cpu_names = \", \".join([f\"[dim]{m.name} ({t})[/dim]\" for m, t in cpu_list])\n                    # 
Create a sub-row that is empty except for the name column\n                    num_cols = len(row)\n                    sub_row = [\"\", f\"  [dim]↳ CPU: {cpu_names}[/dim]\"] + [\"\"] * (num_cols - 2)\n                    table.add_row(*sub_row, style=\"dim\")\n\n            console.print(table)\n            console.print()\n\n        # Table 4: Failed MoE Analysis (only show with --all)\n        if show_failed_table and moe_failed_models:\n            console.print(\"[bold red]Failed MoE Analysis[/bold red]\")\n            console.print(\"[yellow]These models may not be MoE models or have analysis errors:[/yellow]\\n\")\n            table = Table(show_header=True, header_style=\"bold\")\n            table.add_column(\"#\", justify=\"right\", style=\"cyan\", no_wrap=True)\n            table.add_column(t(\"model_column_name\"), style=\"red\", no_wrap=True)\n            table.add_column(\"Path\", style=\"dim\", overflow=\"fold\")\n            table.add_column(\"Total\", justify=\"right\")\n            table.add_column(\"Error\", style=\"yellow\", overflow=\"fold\")\n\n            for i, (model, error) in enumerate(moe_failed_models, 1):\n                from kt_kernel.cli.utils.model_scanner import format_size\n\n                if model.path_exists():\n                    path_obj = Path(model.path)\n                    try:\n                        files = list(path_obj.glob(\"*.safetensors\"))\n                        total_size = sum(f.stat().st_size for f in files if f.exists())\n                        size_display = format_size(total_size)\n                    except Exception:\n                        size_display = \"[dim]-[/dim]\"\n                else:\n                    size_display = \"[dim]-[/dim]\"\n\n                table.add_row(str(i), model.name, model.path, size_display, error)\n\n            console.print(table)\n            console.print()\n\n        # Show hint if non-MoE models are hidden (display before summary)\n        if total_non_moe_count > 
0:\n            hint_text = t(\"model_non_moe_hidden_hint\", count=total_non_moe_count)\n            console.print(f\"[dim]{hint_text}[/dim]\")\n            console.print()\n\n        # Summary\n        total_count = len(gguf_models) + len(amx_models) + len(gpu_models)\n        failed_count = len(moe_failed_models)\n        if failed_count > 0:\n            console.print(\n                f\"[dim]Total: {total_count} model(s) | GGUF: {len(gguf_models)} | AMX: {len(amx_models)} | GPU: {len(gpu_models)} | [red]Failed: {failed_count}[/red][/dim]\\n\"\n            )\n        else:\n            console.print(\n                f\"[dim]Total: {total_count} model(s) | GGUF: {len(gguf_models)} | AMX: {len(amx_models)} | GPU: {len(gpu_models)}[/dim]\\n\"\n            )\n\n        # Show usage hints (only in non-verbose mode)\n        if not verbose and models:\n            console.print(f\"[bold cyan]{t('model_usage_title')}[/bold cyan]\")\n            console.print(f\"  {t('model_usage_info'):<17} [cyan]kt model info <name>[/cyan]\")\n            console.print(f\"  {t('model_usage_edit'):<17} [cyan]kt model edit <name>[/cyan]\")\n            console.print(f\"  {t('model_usage_verify'):<17} [cyan]kt model verify <name>[/cyan]\")\n            console.print(f\"  {t('model_usage_quant'):<17} [cyan]kt quant <name>[/cyan]\")\n            console.print(f\"  {t('model_usage_run'):<17} [cyan]kt run <name>[/cyan]\")\n            console.print()\n            console.print(f\"  {t('model_usage_scan'):<17} [cyan]kt model scan[/cyan]\")\n            console.print(f\"  {t('model_usage_add'):<17} [cyan]kt model add <path>[/cyan]\")\n            console.print()\n\n\n@app.command(name=\"clear-cache\")\ndef clear_cache() -> None:\n    \"\"\"Clear MoE analysis cache.\"\"\"\n    from pathlib import Path\n    import json\n\n    cache_file = Path.home() / \".ktransformers\" / \"cache\" / \"moe_analysis.json\"\n\n    if not cache_file.exists():\n        console.print()\n        
console.print(\"[dim]No MoE cache found.[/dim]\")\n        console.print()\n        return\n\n    # Read cache to count entries\n    try:\n        with open(cache_file, \"r\", encoding=\"utf-8\") as f:\n            cache_data = json.load(f)\n        cache_count = len(cache_data)\n    except Exception:\n        cache_count = 0\n\n    if cache_count == 0:\n        console.print()\n        console.print(\"[dim]MoE cache is empty.[/dim]\")\n        console.print()\n        return\n\n    console.print()\n    console.print(f\"[yellow]Found {cache_count} cached model(s) in:[/yellow]\")\n    console.print(f\"  {cache_file}\")\n    console.print()\n\n    if confirm(\"Clear all MoE analysis cache?\", default=False):\n        cache_file.unlink()\n        console.print(f\"[green]✓ Cleared cache for {cache_count} model(s)[/green]\")\n    else:\n        console.print(\"[dim]Cache clear cancelled.[/dim]\")\n\n    console.print()\n\n\n@app.command(name=\"path-list\")\ndef path_list() -> None:\n    \"\"\"List all configured model storage paths.\"\"\"\n    settings = get_settings()\n    model_paths = settings.get_model_paths()\n\n    console.print()\n    console.print(f\"[bold]{t('model_storage_paths_title')}:[/bold]\\n\")\n\n    for i, path in enumerate(model_paths, 1):\n        marker = \"[green]✓[/green]\" if path.exists() else \"[red]✗[/red]\"\n        console.print(f\"  {marker} [{i}] {path}\")\n\n    console.print()\n\n\n@app.command(name=\"link-cpu\")\ndef link_cpu(\n    cpu_model: str = typer.Argument(..., help=\"Name of the CPU model (GGUF/AMX)\"),\n    gpu_models: List[str] = typer.Argument(..., help=\"Name(s) of GPU model(s) to link with\"),\n) -> None:\n    \"\"\"Link a CPU model (GGUF/AMX) with one or more GPU models for joint startup.\"\"\"\n    from kt_kernel.cli.utils.user_model_registry import UserModelRegistry\n\n    registry = UserModelRegistry()\n\n    # Check if CPU model exists\n    cpu_model_obj = registry.get_model(cpu_model)\n    if not cpu_model_obj:\n        
print_error(f\"CPU model '{cpu_model}' not found in registry.\")\n        console.print()\n        console.print(f\"  Use [cyan]kt model list[/cyan] to see registered models\")\n        console.print()\n        raise typer.Exit(1)\n\n    # Check if it's actually a CPU model (GGUF or AMX)\n    if cpu_model_obj.format == \"safetensors\":\n        # Check if it's AMX by looking for .numa. pattern\n        is_amx, _ = is_amx_weights(cpu_model_obj.path)\n        if not is_amx:\n            print_error(f\"Model '{cpu_model}' is a GPU model (safetensors), not a CPU model.\")\n            console.print()\n            console.print(f\"  Only GGUF and AMX models can be linked to GPU models\")\n            console.print()\n            raise typer.Exit(1)\n\n    # Verify all GPU models exist and collect their UUIDs\n    gpu_model_uuids = []\n    missing_models = []\n    for gpu_name in gpu_models:\n        gpu_model_obj = registry.get_model(gpu_name)\n        if not gpu_model_obj:\n            missing_models.append(gpu_name)\n        else:\n            gpu_model_uuids.append(gpu_model_obj.id)\n\n    if missing_models:\n        print_error(f\"GPU model(s) not found: {', '.join(missing_models)}\")\n        console.print()\n        console.print(f\"  Use [cyan]kt model list[/cyan] to see registered models\")\n        console.print()\n        raise typer.Exit(1)\n\n    # Update the CPU model with GPU links (using UUIDs for stability)\n    registry.update_model(cpu_model, {\"gpu_model_ids\": gpu_model_uuids})\n\n    console.print()\n    print_success(f\"Linked CPU model '{cpu_model}' with GPU model(s):\")\n    for gpu_name in gpu_models:\n        console.print(f\"  [green]✓[/green] {gpu_name}\")\n    console.print()\n    console.print(f\"  View the relationship with [cyan]kt model list[/cyan]\")\n    console.print()\n\n\n@app.command(name=\"unlink-cpu\")\ndef unlink_cpu(\n    cpu_model: str = typer.Argument(..., help=\"Name of the CPU model to unlink\"),\n) -> None:\n    
\"\"\"Remove GPU model links from a CPU model.\"\"\"\n    from kt_kernel.cli.utils.user_model_registry import UserModelRegistry\n\n    registry = UserModelRegistry()\n\n    # Check if model exists\n    model = registry.get_model(cpu_model)\n    if not model:\n        print_error(f\"Model '{cpu_model}' not found in registry.\")\n        console.print()\n        raise typer.Exit(1)\n\n    if not model.gpu_model_ids:\n        console.print()\n        console.print(f\"[yellow]Model '{cpu_model}' has no GPU links.[/yellow]\")\n        console.print()\n        return\n\n    # Remove links\n    registry.update_model(cpu_model, {\"gpu_model_ids\": None})\n\n    console.print()\n    print_success(f\"Removed all GPU links from '{cpu_model}'\")\n    console.print()\n\n\n@app.command(name=\"path-add\")\ndef path_add(\n    path: str = typer.Argument(..., help=\"Path to add\"),\n) -> None:\n    \"\"\"Add a new model storage path.\"\"\"\n    # Expand user home directory\n    path = os.path.expanduser(path)\n\n    # Check if path exists or can be created\n    path_obj = Path(path)\n    if not path_obj.exists():\n        console.print(f\"[yellow]{t('model_path_not_exist', path=path)}[/yellow]\")\n        if confirm(t(\"model_create_directory\", path=path), default=True):\n            try:\n                path_obj.mkdir(parents=True, exist_ok=True)\n                console.print(f\"[green]✓[/green] {t('model_created_directory', path=path)}\")\n            except (OSError, PermissionError) as e:\n                print_error(t(\"model_create_dir_failed\", error=str(e)))\n                raise typer.Exit(1)\n        else:\n            raise typer.Abort()\n\n    # Add to configuration\n    settings = get_settings()\n    settings.add_model_path(path)\n    print_success(t(\"model_path_added\", path=path))\n\n\n@app.command(name=\"path-remove\")\ndef path_remove(\n    path: str = typer.Argument(..., help=\"Path to remove\"),\n) -> None:\n    \"\"\"Remove a model storage path from 
configuration.\"\"\"\n    # Expand user home directory\n    path = os.path.expanduser(path)\n\n    settings = get_settings()\n    if settings.remove_model_path(path):\n        print_success(t(\"model_path_removed\", path=path))\n    else:\n        print_error(t(\"model_path_not_found\", path=path))\n        raise typer.Exit(1)\n\n\n@app.command(name=\"scan\")\ndef scan(\n    min_size: float = typer.Option(2.0, \"--min-size\", help=\"Minimum model file size in GB (default: 2.0)\"),\n    max_depth: int = typer.Option(6, \"--max-depth\", help=\"Maximum search depth (default: 6)\"),\n) -> None:\n    \"\"\"Perform global scan for models and add new ones to registry.\"\"\"\n    from kt_kernel.cli.utils.model_discovery import discover_and_register_global, format_discovery_summary\n    from kt_kernel.cli.config.settings import get_settings\n\n    settings = get_settings()\n    lang = settings.get(\"general.language\", \"en\")\n\n    console.print()\n    if lang == \"zh\":\n        print_info(\"全局扫描模型权重\")\n        console.print()\n    else:\n        print_info(\"Global Model Scan\")\n        console.print()\n\n    try:\n        total_found, new_found, registered = discover_and_register_global(\n            min_size_gb=min_size, max_depth=max_depth, show_progress=True, lang=lang\n        )\n\n        format_discovery_summary(\n            total_found=total_found,\n            new_found=new_found,\n            registered=registered,\n            lang=lang,\n            show_models=True,\n            max_show=20,\n        )\n\n        if new_found > 0:\n            console.print()\n            if lang == \"zh\":\n                console.print(\"[dim]下一步:[/dim]\")\n                console.print(f\"  • 查看模型列表: [cyan]kt model list[/cyan]\")\n                console.print(f\"  • 编辑模型信息: [cyan]kt model edit <name>[/cyan]\")\n                console.print(f\"  • 验证模型: [cyan]kt model verify <name>[/cyan]\")\n            else:\n                console.print(\"[dim]Next 
steps:[/dim]\")\n                console.print(f\"  • View model list: [cyan]kt model list[/cyan]\")\n                console.print(f\"  • Edit model info: [cyan]kt model edit <name>[/cyan]\")\n                console.print(f\"  • Verify models: [cyan]kt model verify <name>[/cyan]\")\n            console.print()\n\n    except Exception as e:\n        print_error(f\"Scan failed: {e}\")\n        raise typer.Exit(1)\n\n\n@app.command(name=\"add\")\ndef add_model(\n    path: str = typer.Argument(..., help=\"Path to scan for models\"),\n) -> None:\n    \"\"\"Scan a directory and add all found models to the registry.\"\"\"\n    from pathlib import Path\n    from kt_kernel.cli.utils.model_discovery import discover_and_register_path\n    from kt_kernel.cli.config.settings import get_settings\n\n    settings = get_settings()\n    lang = settings.get(\"general.language\", \"en\")\n\n    # Expand and validate path\n    path_obj = Path(os.path.expanduser(path)).resolve()\n\n    if not path_obj.exists():\n        print_error(f\"Path does not exist: {path_obj}\")\n        raise typer.Exit(1)\n\n    if not path_obj.is_dir():\n        print_error(f\"Not a directory: {path_obj}\")\n        raise typer.Exit(1)\n\n    # Scan and register models\n    console.print()\n    try:\n        total_found, new_found, registered = discover_and_register_path(\n            path=str(path_obj), min_size_gb=2.0, existing_paths=None, show_progress=True, lang=lang\n        )\n\n        console.print()\n        if new_found == 0:\n            if total_found > 0:\n                if lang == \"zh\":\n                    console.print(f\"[yellow]在此路径找到 {total_found} 个模型，但所有模型均已在列表中[/yellow]\")\n                else:\n                    console.print(\n                        f\"[yellow]Found {total_found} models in this path, but all already in the list[/yellow]\"\n                    )\n            else:\n                if lang == \"zh\":\n                    console.print(\"[yellow]未找到模型[/yellow]\")\n 
                   console.print()\n                    console.print(\"  支持的格式: *.gguf, *.safetensors (需要 config.json)\")\n                else:\n                    console.print(\"[yellow]No models found[/yellow]\")\n                    console.print()\n                    console.print(\"  Supported formats: *.gguf, *.safetensors (with config.json)\")\n        else:\n            if lang == \"zh\":\n                console.print(\n                    f\"[green]✓[/green] 在此路径找到 {total_found} 个模型，成功添加 {len(registered)} 个新模型\"\n                )\n            else:\n                console.print(\n                    f\"[green]✓[/green] Found {total_found} models in this path, added {len(registered)} new models\"\n                )\n\n            if registered:\n                console.print()\n                if lang == \"zh\":\n                    console.print(\"[dim]新添加的模型:[/dim]\")\n                else:\n                    console.print(\"[dim]Newly added models:[/dim]\")\n\n                for model in registered:\n                    console.print(f\"  • {model.name} ({model.format})\")\n                    console.print(f\"    [dim]{model.path}[/dim]\")\n\n        console.print()\n\n    except Exception as e:\n        print_error(f\"Failed to scan path: {e}\")\n        raise typer.Exit(1)\n\n\n@app.command(name=\"edit\")\ndef edit_model(\n    name: Optional[str] = typer.Argument(\n        None, help=\"Name of model to edit (optional - will show selection if not provided)\"\n    ),\n) -> None:\n    \"\"\"Edit model information interactively.\"\"\"\n    from rich.prompt import Prompt, Confirm\n    from rich.panel import Panel\n    from rich.table import Table\n    from kt_kernel.cli.utils.user_model_registry import UserModelRegistry\n\n    registry = UserModelRegistry()\n\n    # If no name provided, show interactive selection\n    if name is None:\n        all_models = registry.list_models()\n\n        # Filter to only show MoE GPU models (safetensors that 
are not AMX)\n        moe_models = []\n        for m in all_models:\n            if m.format == \"safetensors\":\n                is_amx_model, _ = is_amx_weights(m.path)\n                if not is_amx_model:\n                    moe_models.append(m)\n\n        if not moe_models:\n            print_error(t(\"model_edit_no_models\"))\n            console.print()\n            console.print(f\"  {t('model_edit_add_hint_scan')} [cyan]kt model scan[/cyan]\")\n            console.print(f\"  {t('model_edit_add_hint_add')} [cyan]kt model add <path>[/cyan]\")\n            console.print()\n            raise typer.Exit(1)\n\n        # Display models table with # column\n        console.print()\n        console.print(f\"[bold cyan]{t('model_edit_select_title')}[/bold cyan]\")\n        console.print()\n\n        table = Table(show_header=True, header_style=\"bold\", show_lines=False)\n        table.add_column(\"#\", justify=\"right\", style=\"cyan\", no_wrap=True)\n        table.add_column(\"Name\", style=\"cyan\", no_wrap=True)\n        table.add_column(\"Format\", style=\"dim\")\n        table.add_column(\"Path\", style=\"dim\", overflow=\"fold\")\n\n        for i, model_item in enumerate(moe_models, 1):\n            table.add_row(str(i), model_item.name, model_item.format, model_item.path)\n\n        console.print(table)\n        console.print()\n\n        from rich.prompt import IntPrompt\n\n        choice = IntPrompt.ask(t(\"model_edit_select_model\"), default=1, show_choices=False)\n\n        if choice < 1 or choice > len(moe_models):\n            print_error(t(\"model_edit_invalid_choice\"))\n            raise typer.Exit(1)\n\n        model = moe_models[choice - 1]\n    else:\n        # Load model by name\n        model = registry.get_model(name)\n        if not model:\n            print_error(t(\"model_edit_not_found\", name=name))\n            console.print()\n            console.print(f\"  {t('model_edit_list_hint')} [cyan]kt model list[/cyan]\")\n            
console.print()\n            raise typer.Exit(1)\n\n    # Keep track of original values to detect changes\n    original_name = model.name\n    original_repo_type = model.repo_type\n    original_repo_id = model.repo_id\n    original_gpu_model_ids = model.gpu_model_ids.copy() if model.gpu_model_ids else None\n\n    # Working copy for edits (not saved until user confirms)\n    edited_name = model.name\n    edited_repo_type = model.repo_type\n    edited_repo_id = model.repo_id\n    edited_gpu_model_ids = model.gpu_model_ids.copy() if model.gpu_model_ids else None\n\n    has_changes = False\n\n    while True:\n        # Display current configuration (show edited values)\n        console.print()\n        console.print(f\"[bold cyan]{t('model_edit_current_config')}[/bold cyan]\\n\")\n\n        # Format SHA256 status (from original model)\n        sha256_display = SHA256_STATUS_MAP_PLAIN.get(model.sha256_status, model.sha256_status)\n\n        # Check if this is a CPU model (GGUF or AMX)\n        is_cpu_model = model.format == \"gguf\"\n        if not is_cpu_model and model.format == \"safetensors\":\n            is_amx, _ = is_amx_weights(model.path)\n            is_cpu_model = is_amx\n\n        # Format GPU links info (for CPU models)\n        gpu_links_info = \"\"\n        if is_cpu_model and edited_gpu_model_ids:\n            gpu_names = []\n            for gpu_id in edited_gpu_model_ids:\n                gpu_obj = registry.get_model_by_id(gpu_id)\n                if gpu_obj:\n                    gpu_names.append(gpu_obj.name)\n                else:\n                    gpu_names.append(f\"[dim red]{gpu_id[:8]}... 
(deleted)[/dim red]\")\n            gpu_links_info = f\"\\n[bold]{t('model_edit_gpu_links')}[/bold]  {', '.join(gpu_names)}\"\n\n        content = f\"\"\"[bold]Name:[/bold]       {edited_name}\n[bold]Path:[/bold]       {model.path}\n[bold]Format:[/bold]     {model.format}\n[bold]Repo Type:[/bold]  {edited_repo_type or '-'}\n[bold]Repo ID:[/bold]    {edited_repo_id or '-'}\n[bold]SHA256:[/bold]     {sha256_display}{gpu_links_info}\"\"\"\n\n        panel = Panel(content, border_style=\"cyan\", padding=(0, 1))\n        console.print(panel)\n        console.print()\n\n        # Check if there are any changes\n        has_changes = (\n            edited_name != original_name\n            or edited_repo_type != original_repo_type\n            or edited_repo_id != original_repo_id\n            or edited_gpu_model_ids != original_gpu_model_ids\n        )\n\n        # Show menu\n        console.print(f\"[bold]{t('model_edit_what_to_edit')}[/bold]\")\n        console.print(\"  [1] \" + t(\"model_edit_option_name\"))\n        console.print(\"  [2] \" + t(\"model_edit_option_repo\"))\n        console.print(\"  [3] \" + t(\"model_edit_option_delete\"))\n        if is_cpu_model:\n            console.print(\"  [4] \" + t(\"model_edit_manage_gpu_links\"))\n            save_option = \"5\"\n            cancel_option = \"6\"\n            console.print(\n                f\"  [{save_option}] {t('model_edit_save_changes')}\"\n                + (\n                    f\" [cyan]{t('model_edit_has_changes')}[/cyan]\"\n                    if has_changes\n                    else f\" [dim]{t('model_edit_no_changes')}[/dim]\"\n                )\n            )\n            console.print(f\"  [{cancel_option}] \" + t(\"model_edit_option_cancel\"))\n            console.print()\n            choice = Prompt.ask(t(\"model_edit_choice_prompt\"), choices=[\"1\", \"2\", \"3\", \"4\", \"5\", \"6\"], default=\"6\")\n        else:\n            save_option = \"4\"\n            cancel_option = \"5\"\n      
      console.print(\n                f\"  [{save_option}] {t('model_edit_save_changes')}\"\n                + (\n                    f\" [cyan]{t('model_edit_has_changes')}[/cyan]\"\n                    if has_changes\n                    else f\" [dim]{t('model_edit_no_changes')}[/dim]\"\n                )\n            )\n            console.print(f\"  [{cancel_option}] \" + t(\"model_edit_option_cancel\"))\n            console.print()\n            choice = Prompt.ask(t(\"model_edit_choice_prompt\"), choices=[\"1\", \"2\", \"3\", \"4\", \"5\"], default=\"5\")\n\n        if choice == \"1\":\n            # Edit name (update working copy only)\n            console.print()\n            new_name = Prompt.ask(t(\"model_edit_new_name\"), default=edited_name)\n\n            if new_name != edited_name:\n                # Check for conflict (excluding both original and edited names)\n                if new_name != original_name and registry.check_name_conflict(new_name, exclude_name=original_name):\n                    print_error(t(\"model_edit_name_conflict\", name=new_name))\n                    continue\n\n                edited_name = new_name\n                console.print()\n                print_info(f\"[dim]{t('model_edit_name_pending')}[/dim]\")\n\n        elif choice == \"2\":\n            # Edit repo configuration (update working copy only)\n            console.print()\n            console.print(t(\"model_edit_repo_type_prompt\"))\n            console.print(\"  [1] HuggingFace\")\n            console.print(\"  [2] ModelScope\")\n            console.print(\"  [3] \" + t(\"model_edit_repo_remove\"))\n            console.print()\n\n            repo_choice = Prompt.ask(t(\"model_edit_choice_prompt\"), choices=[\"1\", \"2\", \"3\"], default=\"3\")\n\n            if repo_choice == \"3\":\n                # Remove repo\n                edited_repo_type = None\n                edited_repo_id = None\n                console.print()\n                
print_info(f\"[dim]{t('model_edit_repo_remove_pending')}[/dim]\")\n            else:\n                # Set repo\n                repo_type = \"huggingface\" if repo_choice == \"1\" else \"modelscope\"\n                example = \"deepseek-ai/DeepSeek-V3\" if repo_choice == \"1\" else \"deepseek/DeepSeek-V3\"\n\n                current_default = edited_repo_id if edited_repo_id and edited_repo_type == repo_type else \"\"\n                repo_id = Prompt.ask(\n                    t(\"model_edit_repo_id_prompt\", example=example),\n                    default=current_default if current_default else None,\n                )\n\n                edited_repo_type = repo_type\n                edited_repo_id = repo_id\n                console.print()\n                print_info(f\"[dim]{t('model_edit_repo_update_pending')}[/dim]\")\n\n        elif choice == \"3\":\n            # Delete model\n            console.print()\n            console.print(f\"[bold yellow]{t('model_edit_delete_warning')}[/bold yellow]\")\n            console.print(f\"  {t('model_edit_delete_note')}\")\n            console.print()\n\n            if Confirm.ask(t(\"model_edit_delete_confirm\", name=model.name), default=False):\n                registry.remove_model(model.name)\n                console.print()\n                print_success(t(\"model_edit_deleted\", name=model.name))\n                console.print()\n                return\n            else:\n                console.print()\n                print_info(t(\"model_edit_delete_cancelled\"))\n\n        elif choice == \"4\" and is_cpu_model:\n            # Manage GPU Links (only for CPU models) - update working copy\n            console.print()\n            console.print(f\"[bold cyan]{t('model_edit_gpu_links_title', name=edited_name)}[/bold cyan]\")\n            console.print()\n\n            # Show current links (from edited values)\n            if edited_gpu_model_ids:\n                
console.print(f\"[bold]{t('model_edit_current_gpu_links')}[/bold]\")\n                for i, gpu_id in enumerate(edited_gpu_model_ids, 1):\n                    gpu_obj = registry.get_model_by_id(gpu_id)\n                    if gpu_obj:\n                        console.print(f\"  [{i}] {gpu_obj.name}\")\n                    else:\n                        console.print(f\"  [{i}] [red]{gpu_id[:8]}... (deleted)[/red]\")\n                console.print()\n            else:\n                console.print(f\"[dim]{t('model_edit_no_gpu_links')}[/dim]\")\n                console.print()\n\n            console.print(f\"{t('model_edit_gpu_options')}\")\n            console.print(f\"  [1] {t('model_edit_gpu_add')}\")\n            console.print(f\"  [2] {t('model_edit_gpu_remove')}\")\n            console.print(f\"  [3] {t('model_edit_gpu_clear')}\")\n            console.print(f\"  [4] {t('model_edit_gpu_back')}\")\n            console.print()\n\n            link_choice = Prompt.ask(t(\"model_edit_gpu_choose_option\"), choices=[\"1\", \"2\", \"3\", \"4\"], default=\"4\")\n\n            if link_choice == \"1\":\n                # Add GPU link\n                # Get all GPU models (safetensors that are not AMX)\n                all_models = registry.list_models()\n                available_gpu_models = []\n                for m in all_models:\n                    if m.format == \"safetensors\":\n                        is_amx_model, _ = is_amx_weights(m.path)\n                        if not is_amx_model:\n                            available_gpu_models.append(m)\n\n                if not available_gpu_models:\n                    console.print()\n                    console.print(f\"[yellow]{t('model_edit_gpu_none_available')}[/yellow]\")\n                    console.print()\n                else:\n                    console.print()\n                    console.print(f\"{t('model_edit_gpu_available_models')}\")\n                    for i, gpu_m in 
enumerate(available_gpu_models, 1):\n                        already_linked = edited_gpu_model_ids and gpu_m.id in edited_gpu_model_ids\n                        status = f\" [dim]{t('model_edit_gpu_already_linked')}[/dim]\" if already_linked else \"\"\n                        console.print(f\"  [{i}] {gpu_m.name}{status}\")\n                    console.print()\n\n                    gpu_choice = Prompt.ask(t(\"model_edit_gpu_enter_number\"), default=\"0\")\n                    try:\n                        gpu_idx = int(gpu_choice) - 1\n                        if 0 <= gpu_idx < len(available_gpu_models):\n                            selected_gpu = available_gpu_models[gpu_idx]\n\n                            # Add to edited_gpu_model_ids\n                            current_ids = list(edited_gpu_model_ids) if edited_gpu_model_ids else []\n                            if selected_gpu.id not in current_ids:\n                                current_ids.append(selected_gpu.id)\n                                edited_gpu_model_ids = current_ids\n                                console.print()\n                                print_info(f\"[dim]{t('model_edit_gpu_link_pending', name=selected_gpu.name)}[/dim]\")\n                            else:\n                                console.print()\n                                console.print(f\"[yellow]{t('model_edit_gpu_already_exists')}[/yellow]\")\n                        else:\n                            console.print()\n                            console.print(f\"[red]{t('model_edit_gpu_invalid_choice')}[/red]\")\n                    except ValueError:\n                        console.print()\n                        console.print(f\"[red]{t('model_edit_gpu_invalid_input')}[/red]\")\n\n            elif link_choice == \"2\":\n                # Remove GPU link\n                if not edited_gpu_model_ids:\n                    console.print()\n                    
console.print(f\"[yellow]{t('model_edit_gpu_none_to_remove')}[/yellow]\")\n                    console.print()\n                else:\n                    console.print()\n                    console.print(f\"{t('model_edit_gpu_choose_to_remove')}\")\n                    gpu_list = []\n                    for i, gpu_id in enumerate(edited_gpu_model_ids, 1):\n                        gpu_obj = registry.get_model_by_id(gpu_id)\n                        gpu_name = gpu_obj.name if gpu_obj else f\"{gpu_id[:8]}... (deleted)\"\n                        gpu_list.append((gpu_id, gpu_name))\n                        console.print(f\"  [{i}] {gpu_name}\")\n                    console.print()\n\n                    remove_choice = Prompt.ask(t(\"model_edit_gpu_enter_to_remove\"), default=\"0\")\n                    try:\n                        remove_idx = int(remove_choice) - 1\n                        if 0 <= remove_idx < len(gpu_list):\n                            removed_id, removed_name = gpu_list[remove_idx]\n                            new_ids = [gid for gid in edited_gpu_model_ids if gid != removed_id]\n                            edited_gpu_model_ids = new_ids if new_ids else None\n                            console.print()\n                            print_info(f\"[dim]{t('model_edit_gpu_remove_pending', name=removed_name)}[/dim]\")\n                        else:\n                            console.print()\n                            console.print(f\"[red]{t('model_edit_gpu_invalid_choice')}[/red]\")\n                    except ValueError:\n                        console.print()\n                        console.print(f\"[red]{t('model_edit_gpu_invalid_input')}[/red]\")\n\n            elif link_choice == \"3\":\n                # Clear all GPU links\n                if not edited_gpu_model_ids:\n                    console.print()\n                    console.print(f\"[yellow]{t('model_edit_gpu_none_to_clear')}[/yellow]\")\n                    console.print()\n      
          else:\n                    if Confirm.ask(t(\"model_edit_gpu_clear_confirm\"), default=False):\n                        edited_gpu_model_ids = None\n                        console.print()\n                        print_info(f\"[dim]{t('model_edit_gpu_clear_pending')}[/dim]\")\n                    else:\n                        console.print()\n                        print_info(t(\"model_edit_cancelled_short\"))\n\n        elif choice == save_option:\n            # Save changes\n            if not has_changes:\n                console.print()\n                print_info(f\"[dim]{t('model_edit_no_changes_to_save')}[/dim]\")\n                continue\n\n            console.print()\n            console.print(f\"[bold cyan]{t('model_edit_saving')}[/bold cyan]\")\n            console.print()\n\n            # Determine if repo info changed (for verification prompt)\n            repo_changed = (original_repo_id is None and edited_repo_id is not None) or (\n                original_repo_id != edited_repo_id\n            )\n\n            # Build updates dict\n            updates = {}\n            if edited_name != original_name:\n                updates[\"name\"] = edited_name\n            if edited_repo_type != original_repo_type:\n                updates[\"repo_type\"] = edited_repo_type\n            if edited_repo_id != original_repo_id:\n                updates[\"repo_id\"] = edited_repo_id\n                # Update SHA256 status when repo changes\n                if edited_repo_id is None:\n                    updates[\"sha256_status\"] = \"no_repo\"\n                else:\n                    updates[\"sha256_status\"] = \"not_checked\"\n            if edited_gpu_model_ids != original_gpu_model_ids:\n                updates[\"gpu_model_ids\"] = edited_gpu_model_ids\n\n            # Save to registry\n            registry.update_model(original_name, updates)\n            print_success(t(\"model_edit_saved\"))\n\n            # Update local model object\n       
     if \"name\" in updates:\n                model.name = edited_name\n            if \"repo_type\" in updates:\n                model.repo_type = edited_repo_type\n            if \"repo_id\" in updates:\n                model.repo_id = edited_repo_id\n            if \"sha256_status\" in updates:\n                model.sha256_status = updates[\"sha256_status\"]\n            if \"gpu_model_ids\" in updates:\n                model.gpu_model_ids = edited_gpu_model_ids\n\n            # Update original values for next iteration\n            original_name = edited_name\n            original_repo_type = edited_repo_type\n            original_repo_id = edited_repo_id\n            original_gpu_model_ids = edited_gpu_model_ids.copy() if edited_gpu_model_ids else None\n\n            # Display updated configuration\n            console.print()\n            console.print(f\"[bold cyan]{t('model_edit_updated_config')}[/bold cyan]\\n\")\n\n            sha256_display = SHA256_STATUS_MAP_PLAIN.get(model.sha256_status, model.sha256_status)\n            gpu_links_info = \"\"\n            if is_cpu_model and model.gpu_model_ids:\n                gpu_names = []\n                for gpu_id in model.gpu_model_ids:\n                    gpu_obj = registry.get_model_by_id(gpu_id)\n                    if gpu_obj:\n                        gpu_names.append(gpu_obj.name)\n                    else:\n                        gpu_names.append(f\"[dim red]{gpu_id[:8]}... 
(deleted)[/dim red]\")\n                gpu_links_info = f\"\\n[bold]{t('model_edit_gpu_links')}[/bold]  {', '.join(gpu_names)}\"\n\n            content = f\"\"\"[bold]Name:[/bold]       {model.name}\n[bold]Path:[/bold]       {model.path}\n[bold]Format:[/bold]     {model.format}\n[bold]Repo Type:[/bold]  {model.repo_type or '-'}\n[bold]Repo ID:[/bold]    {model.repo_id or '-'}\n[bold]SHA256:[/bold]     {sha256_display}{gpu_links_info}\"\"\"\n\n            panel = Panel(content, border_style=\"green\", padding=(0, 1))\n            console.print(panel)\n            console.print()\n\n            # If repo changed, suggest verification\n            if repo_changed and model.repo_id:\n                console.print()\n                console.print(f\"[bold yellow]{t('model_edit_repo_changed_warning')}[/bold yellow]\")\n                console.print()\n                console.print(f\"  {t('model_edit_verify_hint')}\")\n                console.print()\n\n            return\n\n        elif choice == cancel_option:\n            # Cancel\n            console.print()\n            if has_changes:\n                if Confirm.ask(f\"[yellow]{t('model_edit_discard_changes')}[/yellow]\", default=False):\n                    print_info(t(\"model_edit_cancelled\"))\n                    console.print()\n                    return\n                else:\n                    # Go back to menu\n                    continue\n            else:\n                print_info(t(\"model_edit_cancelled\"))\n                console.print()\n                return\n\n\n@app.command(name=\"info\")\ndef info_model(\n    name: str = typer.Argument(..., help=\"Name of model to display\"),\n) -> None:\n    \"\"\"Display detailed information about a model.\"\"\"\n    from rich.panel import Panel\n    from pathlib import Path\n    from kt_kernel.cli.utils.user_model_registry import UserModelRegistry\n    from kt_kernel.cli.utils.model_scanner import format_size\n\n    registry = UserModelRegistry()\n\n  
  # Load model\n    model = registry.get_model(name)\n    if not model:\n        print_error(t(\"model_info_not_found\", name=name))\n        console.print()\n        console.print(f\"  {t('model_info_list_hint')} [cyan]kt model list[/cyan]\")\n        console.print()\n        raise typer.Exit(1)\n\n    console.print()\n\n    # Check if path exists\n    path_status = \"[green]✓ Exists[/green]\" if model.path_exists() else \"[red]✗ Missing[/red]\"\n\n    # Format repo info\n    if model.repo_id:\n        repo_abbr = \"hf\" if model.repo_type == \"huggingface\" else \"ms\"\n        repo_info = f\"{repo_abbr}:{model.repo_id}\"\n    else:\n        repo_info = \"-\"\n\n    # Format SHA256 status\n    sha256_display = SHA256_STATUS_MAP_PLAIN.get(model.sha256_status, model.sha256_status)\n\n    # Calculate folder size and list files if exists\n    moe_info = \"\"\n    amx_info = \"\"\n\n    if model.path_exists():\n        path_obj = Path(model.path)\n        try:\n            if model.format == \"safetensors\":\n                files = list(path_obj.glob(\"*.safetensors\"))\n\n                # Check for AMX weights\n                is_amx, numa_count = is_amx_weights(str(path_obj))\n                if is_amx:\n                    amx_info = f\"\\n[bold]AMX Format:[/bold]   Yes (NUMA: {numa_count})\"\n                else:\n                    # Check for MOE model\n                    try:\n                        from kt_kernel.cli.utils.analyze_moe_model import analyze_moe_model\n\n                        moe_result = analyze_moe_model(str(path_obj))\n                        if moe_result and moe_result.get(\"num_experts\", 0) > 0:\n                            moe_info = f\"\"\"\n[bold]MoE Info:[/bold]\n  • Total Experts:     {moe_result['num_experts']}\n  • Activated Experts: {moe_result['num_experts_per_tok']} experts/token\n  • Hidden Layers:     {moe_result['num_hidden_layers']}\n  • Total Model Size:  {moe_result['total_size_gb']:.2f} GB\"\"\"\n                   
 except Exception:\n                        pass  # Not a MoE model or analysis failed\n            else:\n                files = list(path_obj.glob(\"*.gguf\"))\n\n            total_size = sum(f.stat().st_size for f in files if f.exists())\n            size_str = format_size(total_size)\n            file_count = len(files)\n            size_info = f\"{size_str} ({file_count} files)\"\n\n            # List first few files\n            file_list = \"\\n\".join([f\"  • {f.name}\" for f in sorted(files)[:5]])\n            if len(files) > 5:\n                file_list += f\"\\n  ... and {len(files) - 5} more files\"\n        except Exception as e:\n            size_info = f\"Error calculating size: {e}\"\n            file_list = \"-\"\n    else:\n        size_info = \"-\"\n        file_list = \"[red]Path does not exist[/red]\"\n\n    # Format created/verified dates\n    from datetime import datetime\n\n    try:\n        created_date = datetime.fromisoformat(model.created_at).strftime(\"%Y-%m-%d %H:%M:%S\")\n    except (TypeError, ValueError):\n        # Missing or non-ISO timestamp: show the raw value\n        created_date = model.created_at\n\n    if model.last_verified:\n        try:\n            verified_date = datetime.fromisoformat(model.last_verified).strftime(\"%Y-%m-%d %H:%M:%S\")\n        except (TypeError, ValueError):\n            # Non-ISO timestamp: show the raw value\n            verified_date = model.last_verified\n    else:\n        verified_date = \"-\"\n\n    # Create detailed panel\n    content = f\"\"\"[bold]Name:[/bold]         {model.name}\n[bold]Path:[/bold]         {model.path}\n[bold]Path Status:[/bold]  {path_status}\n[bold]Format:[/bold]       {model.format}\n[bold]Size:[/bold]         {size_info}{amx_info}{moe_info}\n[bold]Repo Type:[/bold]    {model.repo_type or '-'}\n[bold]Repo ID:[/bold]      {model.repo_id or '-'}\n[bold]SHA256:[/bold]       {sha256_display}\n[bold]Created:[/bold]      {created_date}\n[bold]Last Verified:[/bold] {verified_date}\n\n[bold]Files:[/bold]\n{file_list}\"\"\"\n\n    panel = Panel(content, title=f\"[cyan]Model Information: {model.name}[/cyan]\", border_style=\"cyan\", 
padding=(1, 2))\n    console.print(panel)\n    console.print()\n\n\n@app.command(name=\"remove\")\ndef remove_model(\n    name: str = typer.Argument(..., help=\"Name of model to remove\"),\n    yes: bool = typer.Option(False, \"--yes\", \"-y\", help=\"Skip confirmation\"),\n) -> None:\n    \"\"\"Remove a model from the registry (does not delete files).\"\"\"\n    from kt_kernel.cli.utils.user_model_registry import UserModelRegistry\n\n    registry = UserModelRegistry()\n\n    # Check if model exists\n    model = registry.get_model(name)\n    if not model:\n        print_error(t(\"model_remove_not_found\", name=name))\n        console.print()\n        console.print(f\"  {t('model_remove_list_hint')} [cyan]kt model list[/cyan]\")\n        console.print()\n        raise typer.Exit(1)\n\n    console.print()\n    console.print(f\"[bold yellow]{t('model_remove_warning')}[/bold yellow]\")\n    console.print(f\"  {t('model_remove_note')}\")\n    console.print(f\"  [dim]Path: {model.path}[/dim]\")\n    console.print()\n\n    # Check if this GPU model is linked by any CPU models\n    model_uuid = model.id\n    affected_cpu_models = []\n\n    # Only check for GPU models (safetensors that are not AMX)\n    if model.format == \"safetensors\":\n        is_amx, _ = is_amx_weights(model.path)\n        if not is_amx:\n            # This is a GPU model, check for CPU models that link to it\n            for m in registry.list_models():\n                if m.gpu_model_ids and model_uuid in m.gpu_model_ids:\n                    affected_cpu_models.append(m)\n\n    # If there are affected CPU models, inform the user\n    if affected_cpu_models:\n        console.print(f\"[yellow]This GPU model is linked by {len(affected_cpu_models)} CPU model(s):[/yellow]\")\n        for cpu_model in affected_cpu_models:\n            console.print(f\"  • {cpu_model.name}\")\n        console.print()\n        console.print(f\"[dim]These links will be automatically removed.[/dim]\")\n        
console.print()\n\n    # Confirm deletion\n    if not yes:\n        if not confirm(t(\"model_remove_confirm\", name=name), default=False):\n            print_info(t(\"model_remove_cancelled\"))\n            console.print()\n            return\n\n    # Clean up references in CPU models before removing\n    if affected_cpu_models:\n        for cpu_model in affected_cpu_models:\n            # Remove this GPU model's UUID from the cpu_model's gpu_model_ids list\n            new_gpu_ids = [gid for gid in cpu_model.gpu_model_ids if gid != model_uuid]\n            registry.update_model(cpu_model.name, {\"gpu_model_ids\": new_gpu_ids if new_gpu_ids else None})\n\n    # Remove from registry\n    if registry.remove_model(name):\n        console.print()\n        print_success(t(\"model_removed\", name=name))\n        console.print()\n    else:\n        print_error(t(\"model_remove_failed\", name=name))\n        raise typer.Exit(1)\n\n\n@app.command(name=\"refresh\")\ndef refresh_models() -> None:\n    \"\"\"Check all registered models and identify missing ones.\"\"\"\n    from rich.table import Table\n    from kt_kernel.cli.utils.user_model_registry import UserModelRegistry\n\n    registry = UserModelRegistry()\n    models = registry.list_models()\n\n    if not models:\n        print_warning(t(\"model_no_registered_models\"))\n        console.print()\n        return\n\n    console.print()\n    print_info(t(\"model_refresh_checking\"))\n\n    # Refresh status\n    status = registry.refresh_status()\n\n    # Check relationship integrity\n    broken_relationships = []  # [(cpu_model, gpu_uuid, gpu_name_or_none)]\n    for model in models:\n        if model.gpu_model_ids:\n            for gpu_uuid in model.gpu_model_ids:\n                gpu_obj = registry.get_model_by_id(gpu_uuid)\n                if not gpu_obj:\n                    broken_relationships.append((model.name, gpu_uuid, None))\n                elif not gpu_obj.path_exists():\n                    
broken_relationships.append((model.name, gpu_uuid, gpu_obj.name))\n\n    console.print()\n\n    # Show results\n    has_issues = status[\"missing\"] or broken_relationships\n\n    if not has_issues:\n        print_success(t(\"model_refresh_all_valid\", count=len(models)))\n        console.print(f\"  {t('model_refresh_total', total=len(models))}\")\n        console.print()\n        return\n\n    # Show broken relationships\n    if broken_relationships:\n        print_warning(f\"Found {len(broken_relationships)} broken GPU link(s)\")\n        console.print()\n\n        from rich.table import Table\n\n        rel_table = Table(show_header=True, header_style=\"bold yellow\")\n        rel_table.add_column(\"CPU Model\", style=\"cyan\")\n        rel_table.add_column(\"GPU Model\", style=\"dim\")\n        rel_table.add_column(\"Issue\", style=\"red\")\n\n        for cpu_name, gpu_uuid, gpu_name in broken_relationships:\n            if gpu_name is None:\n                gpu_display = f\"{gpu_uuid[:8]}...\"\n                issue = \"Deleted\"\n            else:\n                gpu_display = gpu_name\n                issue = \"Path Missing\"\n            rel_table.add_row(cpu_name, gpu_display, issue)\n\n        console.print(rel_table)\n        console.print()\n        console.print(f\"[dim]Use [cyan]kt model edit <cpu-model>[/cyan] to fix GPU links[/dim]\")\n        console.print()\n\n    if not status[\"missing\"]:\n        # Only broken relationships, no missing models\n        return\n\n    # Show missing models\n    print_warning(t(\"model_refresh_missing_found\", count=len(status[\"missing\"])))\n    console.print()\n\n    table = Table(show_header=True, header_style=\"bold\")\n    table.add_column(t(\"model_column_name\"), style=\"cyan\")\n    table.add_column(t(\"model_column_path\"), style=\"dim\")\n    table.add_column(t(\"model_column_status\"), justify=\"center\")\n\n    for model in models:\n        if model.name in status[\"missing\"]:\n            
status_text = \"[red]✗ Missing[/red]\"\n        else:\n            status_text = \"[green]✓ Valid[/green]\"\n\n        table.add_row(model.name, model.path, status_text)\n\n    console.print(table)\n    console.print()\n\n    # Suggest actions\n    console.print(f\"[bold]{t('model_refresh_suggestions')}:[/bold]\")\n    console.print(f\"  • {t('model_refresh_remove_hint')} [cyan]kt model remove <name>[/cyan]\")\n    console.print(f\"  • {t('model_refresh_rescan_hint')} [cyan]kt model scan[/cyan]\")\n    console.print()\n\n\n@app.command(name=\"verify\")\ndef verify_model(\n    name: str = typer.Argument(None, help=\"Name of model to verify (interactive if not provided)\"),\n    verbose: bool = typer.Option(False, \"--verbose\", \"-v\", help=\"Show detailed SHA256 comparison for each file\"),\n) -> None:\n    \"\"\"Verify model integrity using SHA256 checksums with interactive repair.\"\"\"\n    from pathlib import Path\n    from rich.prompt import Prompt, Confirm\n    from rich.table import Table\n    from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeElapsedColumn, MofNCompleteColumn\n    from kt_kernel.cli.utils.user_model_registry import UserModelRegistry\n    from kt_kernel.cli.utils.model_verifier import verify_model_integrity_with_progress, check_huggingface_connectivity\n\n    registry = UserModelRegistry()\n\n    # Helper function to display model selection table\n    def show_model_table():\n        from kt_kernel.cli.utils.model_scanner import format_size\n        from pathlib import Path\n\n        # Import MoE analyzer\n        analyze_moe_model = None\n        try:\n            from kt_kernel.cli.utils.analyze_moe_model import analyze_moe_model\n        except ImportError:\n            pass\n\n        all_models = registry.list_models()\n\n        # Filter: only safetensors models with repo_id\n        verifiable_models = [m for m in all_models if m.repo_id and m.format == \"safetensors\"]\n\n        if not 
verifiable_models:\n            print_warning(t(\"model_verify_all_no_repos\"))\n            console.print()\n            console.print(f\"  {t('model_verify_all_config_hint')}\")\n            console.print()\n            return None\n\n        # Analyze MoE models\n        moe_results = {}\n        if analyze_moe_model:\n            for model in verifiable_models:\n                try:\n                    result = analyze_moe_model(model.path, use_cache=True)\n                    if result and result.get(\"num_experts\", 0) > 0:\n                        moe_results[model.name] = result\n                except Exception:\n                    pass\n\n        # Filter to only show MoE models\n        moe_verifiable_models = [m for m in verifiable_models if m.name in moe_results]\n\n        if not moe_verifiable_models:\n            console.print()\n            console.print(\"[yellow]No MoE models with repo_id found for verification.[/yellow]\")\n            console.print()\n            console.print(\n                f\"[dim]Only MoE models can be verified. 
Use [cyan]kt model list[/cyan] to see all models.[/dim]\"\n            )\n            console.print()\n            return None\n\n        console.print()\n        console.print(\"[bold]Select a MoE model to verify:[/bold]\\n\")\n\n        table = Table(show_header=True, header_style=\"bold\", show_lines=False)\n        table.add_column(\"#\", justify=\"right\", style=\"dim\", width=4)\n        table.add_column(t(\"model_column_name\"), style=\"cyan\", no_wrap=True)\n        table.add_column(\"Path\", style=\"dim\", overflow=\"fold\")\n        table.add_column(\"Total\", justify=\"right\")\n        table.add_column(\"Exps\", justify=\"center\", style=\"yellow\")\n        table.add_column(\"Act\", justify=\"center\", style=\"green\")\n        table.add_column(t(\"model_column_repo\"), style=\"dim\", overflow=\"fold\")\n        table.add_column(t(\"model_column_sha256\"), justify=\"center\")\n\n        for i, model in enumerate(moe_verifiable_models, 1):\n            # Calculate size\n            if model.path_exists():\n                path_obj = Path(model.path)\n                try:\n                    files = list(path_obj.glob(\"*.safetensors\"))\n                    total_size = sum(f.stat().st_size for f in files if f.exists())\n                    size_display = format_size(total_size)\n                except OSError:\n                    # Unreadable directory or file vanished mid-scan\n                    size_display = \"[dim]-[/dim]\"\n            else:\n                size_display = \"[dim]-[/dim]\"\n\n            # Get MoE info\n            moe_info = moe_results.get(model.name)\n            experts_display = f\"[yellow]{moe_info['num_experts']}[/yellow]\" if moe_info else \"[dim]-[/dim]\"\n            activated_display = f\"[green]{moe_info['num_experts_per_tok']}[/green]\" if moe_info else \"[dim]-[/dim]\"\n\n            # Repo info\n            repo_abbr = \"hf\" if model.repo_type == \"huggingface\" else \"ms\"\n            repo_display = f\"{repo_abbr}:{model.repo_id}\"\n\n            # SHA256 status\n            
status_icon = {\n                \"not_checked\": \"[dim]○[/dim]\",\n                \"checking\": \"[yellow]◐[/yellow]\",\n                \"passed\": \"[green]✓[/green]\",\n                \"failed\": \"[red]✗[/red]\",\n                \"no_repo\": \"[dim]-[/dim]\",\n            }.get(model.sha256_status, \"[dim]?[/dim]\")\n\n            table.add_row(\n                str(i),\n                model.name,\n                model.path,\n                size_display,\n                experts_display,\n                activated_display,\n                repo_display,\n                status_icon,\n            )\n\n        console.print(table)\n        console.print()\n        console.print(\"[dim]SHA256 Status: ○ Not checked | ✓ Passed | ✗ Failed[/dim]\")\n        console.print()\n\n        return moe_verifiable_models\n\n    # Main verification loop\n    # Track files to verify (None = all files, list = specific files for re-verification)\n    files_to_verify = None\n\n    while True:\n        selected_model = None\n\n        # If name provided directly, use it once then switch to interactive\n        if name:\n            selected_model = registry.get_model(name)\n            if not selected_model:\n                print_error(t(\"model_verify_not_found\", name=name))\n                console.print()\n                console.print(f\"  {t('model_verify_list_hint')} [cyan]kt model list[/cyan]\")\n                console.print()\n                raise typer.Exit(1)\n            name = None  # Clear so next loop is interactive\n        else:\n            # Show interactive selection\n            verifiable_models = show_model_table()\n            if not verifiable_models:\n                return\n\n            choice = Prompt.ask(\"Enter model number to verify (or 'q' to quit)\", default=\"1\")\n\n            if choice.lower() == \"q\":\n                return\n\n            try:\n                idx = int(choice) - 1\n                if 0 <= idx < 
len(verifiable_models):\n                    selected_model = verifiable_models[idx]\n                    # Reset files_to_verify when selecting a new model\n                    files_to_verify = None\n                else:\n                    print_error(f\"Invalid selection: {choice}\")\n                    console.print()\n                    continue\n            except ValueError:\n                print_error(f\"Invalid input: {choice}\")\n                console.print()\n                continue\n\n        # Check model prerequisites\n        console.print()\n\n        if not selected_model.repo_id:\n            print_warning(t(\"model_verify_no_repo\", name=selected_model.name))\n            console.print()\n            console.print(f\"  {t('model_verify_config_hint', name=selected_model.name)}\")\n            console.print()\n            continue\n\n        if not selected_model.path_exists():\n            print_error(t(\"model_verify_path_missing\", path=selected_model.path))\n            console.print()\n            continue\n\n        # Check HuggingFace connectivity and decide whether to use mirror\n        use_mirror = False\n        if selected_model.repo_type == \"huggingface\":\n            with console.status(\"[dim]Checking HuggingFace connectivity...[/dim]\"):\n                is_accessible, message = check_huggingface_connectivity(timeout=5)\n\n            if not is_accessible:\n                print_warning(\"HuggingFace Connection Failed\")\n                console.print()\n                console.print(f\"  {message}\")\n                console.print()\n                console.print(\"  [yellow]Auto-switching to HuggingFace mirror:[/yellow] [cyan]hf-mirror.com[/cyan]\")\n                console.print()\n                use_mirror = True\n\n        # Perform verification with progress bar\n        if files_to_verify:\n            print_info(f\"Re-verifying {len(files_to_verify)} repaired files: {selected_model.name}\")\n        else:\n       
     print_info(f\"Verifying: {selected_model.name}\")\n        console.print(f\"  Repository: [yellow]{selected_model.repo_type}[/yellow]:{selected_model.repo_id}\")\n        console.print(f\"  Local path: {selected_model.path}\")\n        console.print()\n\n        # Helper function to fetch remote hashes with timeout (using console.status like connectivity check)\n        def fetch_remote_hashes_with_timeout(repo_type, repo_id, use_mirror, timeout_seconds):\n            \"\"\"Fetch remote hashes with timeout, returns (hashes_dict, timed_out).\"\"\"\n            from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError\n            from kt_kernel.cli.utils.model_verifier import fetch_model_sha256\n\n            def fetch_hashes():\n                platform = \"hf\" if repo_type == \"huggingface\" else \"ms\"\n                return fetch_model_sha256(repo_id, platform, use_mirror=use_mirror, timeout=timeout_seconds)\n\n            executor = ThreadPoolExecutor(max_workers=1)\n            try:\n                future = executor.submit(fetch_hashes)\n                hashes = future.result(timeout=timeout_seconds)\n                executor.shutdown(wait=False)\n                return (hashes, False)\n            except (FutureTimeoutError, Exception):\n                executor.shutdown(wait=False)\n                return (None, True)\n\n        # Step 1: Fetch remote hashes with timeout and fallback\n        official_hashes = None\n\n        if selected_model.repo_type == \"huggingface\":\n            # HF fallback chain: HF → HF-mirror → MS\n\n            # Try 1: HuggingFace (or HF-mirror if already set)\n            status = console.status(\n                \"[dim]Fetching remote hashes from HuggingFace{}...[/dim]\".format(\" mirror\" if use_mirror else \"\")\n            )\n            status.start()\n            official_hashes, timed_out = fetch_remote_hashes_with_timeout(\n                repo_type=\"huggingface\", 
repo_id=selected_model.repo_id, use_mirror=use_mirror, timeout_seconds=10\n            )\n            status.stop()\n\n            # Try 2: If timed out and not already using mirror, try HF-mirror\n            if timed_out and not use_mirror:\n                print_warning(\"HuggingFace Fetch Timeout (10s)\")\n                console.print()\n                console.print(\"  [yellow]Auto-switching to HuggingFace mirror:[/yellow] [cyan]hf-mirror.com[/cyan]\")\n                console.print()\n\n                status = console.status(\"[dim]Fetching remote hashes from HuggingFace mirror...[/dim]\")\n                status.start()\n                official_hashes, timed_out = fetch_remote_hashes_with_timeout(\n                    repo_type=\"huggingface\",\n                    repo_id=selected_model.repo_id,\n                    use_mirror=True,  # Use mirror\n                    timeout_seconds=10,\n                )\n                status.stop()\n\n            # Try 3: If still timed out, try ModelScope with same repo_id\n            if timed_out:\n                print_warning(\"HuggingFace Mirror Timeout (10s)\")\n                console.print()\n                console.print(\"  [yellow]Fallback to ModelScope mirror with same repo_id...[/yellow]\")\n                console.print()\n\n                status = console.status(\"[dim]Fetching remote hashes from ModelScope...[/dim]\")\n                status.start()\n                official_hashes, timed_out = fetch_remote_hashes_with_timeout(\n                    repo_type=\"modelscope\",\n                    repo_id=selected_model.repo_id,  # Use same repo_id\n                    use_mirror=False,\n                    timeout_seconds=10,\n                )\n                status.stop()\n\n                if official_hashes:\n                    # Success with ModelScope\n                    console.print(\"  [green]✓ Successfully fetched from ModelScope[/green]\")\n                    console.print()\n          
      elif timed_out:\n                    # All failed\n                    print_error(\"All sources timed out (HuggingFace and ModelScope)\")\n                    console.print()\n                    console.print(\"  Please check your network connection or try again later\")\n                    console.print()\n                    continue\n\n        elif selected_model.repo_type == \"modelscope\":\n            # ModelScope: no fallback, just timeout\n            status = console.status(\"[dim]Fetching remote hashes from ModelScope...[/dim]\")\n            status.start()\n            official_hashes, timed_out = fetch_remote_hashes_with_timeout(\n                repo_type=\"modelscope\", repo_id=selected_model.repo_id, use_mirror=False, timeout_seconds=10\n            )\n            status.stop()\n\n            if timed_out:\n                print_error(\"ModelScope Fetch Timeout (10s)\")\n                console.print()\n                console.print(\"  Please check your network connection or try again later\")\n                console.print()\n                continue\n\n        # Check if we successfully fetched remote hashes\n        if not official_hashes:\n            # Already printed error message above, skip to next model\n            continue\n\n        # Success - print confirmation\n        console.print(f\"  [green]✓ Fetched {len(official_hashes)} file hashes from remote[/green]\")\n        console.print()\n\n        # Step 2 & 3: Calculate local SHA256 and compare (with Progress bar)\n        from kt_kernel.cli.utils.model_verifier import calculate_local_sha256\n\n        with Progress(\n            SpinnerColumn(),\n            TextColumn(\"[progress.description]{task.description}\"),\n            BarColumn(),\n            MofNCompleteColumn(),\n            TimeElapsedColumn(),\n            console=console,\n        ) as progress:\n            # Step 2: Calculate local SHA256 hashes (no timeout)\n            local_dir_path = 
Path(selected_model.path)\n\n            # Determine which files to hash\n            if files_to_verify:\n                # Only hash files that need re-verification\n                clean_filenames = {\n                    Path(f.replace(\" (missing)\", \"\").replace(\" (hash mismatch)\", \"\").strip()).name\n                    for f in files_to_verify\n                }\n                # Collect files matching *.safetensors, *.json, *.py\n                files_to_hash = []\n                for pattern in [\"*.safetensors\", \"*.json\", \"*.py\"]:\n                    files_to_hash.extend(\n                        [f for f in local_dir_path.glob(pattern) if f.is_file() and f.name in clean_filenames]\n                    )\n            else:\n                # Collect all important files: *.safetensors, *.json, *.py\n                files_to_hash = []\n                for pattern in [\"*.safetensors\", \"*.json\", \"*.py\"]:\n                    files_to_hash.extend([f for f in local_dir_path.glob(pattern) if f.is_file()])\n\n            total_files = len(files_to_hash)\n\n            # Create progress task for local hashing\n            hash_task_id = progress.add_task(\"[yellow]Calculating local SHA256...\", total=total_files)\n            completed_count = [0]\n\n            def local_hash_callback(msg: str):\n                if \"Using\" in msg and \"workers\" in msg:\n                    # Show parallel worker info\n                    console.print(f\"  [dim]{msg}[/dim]\")\n                elif \"[\" in msg and \"/\" in msg and \"]\" in msg:\n                    # Progress update\n                    completed_count[0] += 1\n                    if \"✓\" in msg:\n                        filename = msg.split(\"✓\")[1].strip().split(\"(\")[0].strip()\n                        progress.update(hash_task_id, advance=1, description=f\"[yellow]Hashing: {filename[:40]}...\")\n\n            local_hashes = calculate_local_sha256(\n                local_dir_path,\n     
           \"*.safetensors\",\n                progress_callback=local_hash_callback,\n                files_list=files_to_hash if files_to_verify else None,\n            )\n\n            progress.remove_task(hash_task_id)\n            console.print(f\"  [green]✓ Calculated {len(local_hashes)} local file hashes[/green]\")\n\n            # Step 3: Compare hashes\n            # If re-verifying specific files, only compare those files\n            if files_to_verify:\n                # Build set of clean filenames to verify\n                clean_verify_filenames = {\n                    Path(f.replace(\" (missing)\", \"\").replace(\" (hash mismatch)\", \"\").strip()).name\n                    for f in files_to_verify\n                }\n                # Filter official_hashes to only include files we're re-verifying\n                hashes_to_compare = {\n                    filename: hash_value\n                    for filename, hash_value in official_hashes.items()\n                    if Path(filename).name in clean_verify_filenames\n                }\n            else:\n                # First-time verification: compare all files\n                hashes_to_compare = official_hashes\n\n            compare_task_id = progress.add_task(\"[blue]Comparing hashes...\", total=len(hashes_to_compare))\n\n            files_failed = []\n            files_missing = []\n            files_passed = 0\n\n            for filename, official_hash in hashes_to_compare.items():\n                file_basename = Path(filename).name\n\n                # Find matching local file\n                local_hash = None\n                for local_file, local_hash_value in local_hashes.items():\n                    if Path(local_file).name == file_basename:\n                        local_hash = local_hash_value\n                        break\n\n                if local_hash is None:\n                    files_missing.append(filename)\n                    if verbose:\n                        
console.print(f\"  [red]✗ {file_basename} (missing)[/red]\")\n                elif local_hash.lower() != official_hash.lower():\n                    files_failed.append(f\"{filename} (hash mismatch)\")\n                    if verbose:\n                        console.print(f\"  [red]✗ {file_basename} (hash mismatch)[/red]\")\n                else:\n                    files_passed += 1\n                    if verbose:\n                        console.print(f\"  [green]✓ {file_basename}[/green]\")\n\n                progress.update(compare_task_id, advance=1)\n\n            progress.remove_task(compare_task_id)\n\n            # Build result\n            total_checked = len(hashes_to_compare)  # Use actual compared count\n            if files_failed or files_missing:\n                all_failed = files_failed + [f\"{f} (missing)\" for f in files_missing]\n                result = {\n                    \"status\": \"failed\",\n                    \"files_checked\": total_checked,\n                    \"files_passed\": files_passed,\n                    \"files_failed\": all_failed,\n                }\n            else:\n                result = {\n                    \"status\": \"passed\",\n                    \"files_checked\": total_checked,\n                    \"files_passed\": files_passed,\n                    \"files_failed\": [],\n                }\n\n        # Update registry status and display results\n        if result[\"status\"] == \"passed\":\n            registry.update_model(selected_model.name, {\"sha256_status\": \"passed\"})\n            console.print()\n            print_success(t(\"model_verify_passed\"))\n            console.print()\n            console.print(f\"  ✓ Files checked: [bold green]{result['files_checked']}[/bold green]\")\n            console.print(f\"  ✓ All files passed SHA256 verification\")\n            console.print()\n        elif result[\"status\"] == \"failed\":\n            registry.update_model(selected_model.name, 
{\"sha256_status\": \"failed\"})\n            console.print()\n            print_error(f\"Verification failed! {len(result['files_failed'])} file(s) have issues\")\n            console.print()\n            console.print(f\"  Total files: {result['files_checked']}\")\n            console.print(f\"  ✓ Passed: [green]{result['files_passed']}[/green]\")\n            console.print(f\"  ✗ Failed: [red]{len(result['files_failed'])}[/red]\")\n            console.print()\n\n            # Show failed files (only if not already shown in verbose mode)\n            if not verbose:\n                console.print(\"  [bold red]Failed files:[/bold red]\")\n                for failed_file in result[\"files_failed\"]:\n                    console.print(f\"    ✗ {failed_file}\")\n                console.print()\n\n            # Ask if user wants to repair\n            if Confirm.ask(\"[yellow]Do you want to repair (re-download) the failed files?[/yellow]\", default=True):\n                console.print()\n                print_info(\"Repairing failed files...\")\n\n                # Extract clean filenames by removing status suffixes\n                files_to_download = [\n                    f.replace(\" (missing)\", \"\").replace(\" (hash mismatch)\", \"\").strip() for f in result[\"files_failed\"]\n                ]\n\n                # Download each failed file\n                success_count = 0\n\n                # Set mirror for downloads if needed\n                import os\n\n                original_hf_endpoint = os.environ.get(\"HF_ENDPOINT\")\n                if use_mirror and selected_model.repo_type == \"huggingface\" and not original_hf_endpoint:\n                    os.environ[\"HF_ENDPOINT\"] = \"https://hf-mirror.com\"\n                    console.print(f\"  [dim]Using HuggingFace mirror for downloads[/dim]\")\n\n                try:\n                    for file_to_repair in files_to_download:\n                        console.print(f\"  Repairing: 
[cyan]{file_to_repair}[/cyan]\")\n\n                        # Step 1: Delete the corrupted/missing file if it exists\n                        local_file_path = Path(selected_model.path) / file_to_repair\n                        if local_file_path.exists():\n                            try:\n                                local_file_path.unlink()\n                                console.print(f\"    [dim]✓ Deleted corrupted file[/dim]\")\n                            except Exception as e:\n                                console.print(f\"    [yellow]⚠ Could not delete file: {e}[/yellow]\")\n\n                        # Step 2: Download the fresh file\n                        if selected_model.repo_type == \"huggingface\":\n                            # Use hf_hub_download for HuggingFace (inherits HF_ENDPOINT env var)\n                            try:\n                                from huggingface_hub import hf_hub_download\n\n                                hf_hub_download(\n                                    repo_id=selected_model.repo_id,\n                                    filename=file_to_repair,\n                                    local_dir=selected_model.path,\n                                    local_dir_use_symlinks=False,\n                                )\n                                console.print(f\"    [green]✓ Downloaded successfully[/green]\")\n                                success_count += 1\n                            except ImportError:\n                                print_error(\"huggingface_hub not installed. 
Install: pip install huggingface_hub\")\n                                break\n                            except Exception as e:\n                                console.print(f\"    [red]✗ Download failed: {e}[/red]\")\n                        else:\n                            # Use modelscope download for ModelScope\n                            try:\n                                from modelscope.hub.snapshot_download import snapshot_download\n\n                                # Download directly to local_dir\n                                snapshot_download(\n                                    model_id=selected_model.repo_id,\n                                    local_dir=selected_model.path,\n                                    allow_file_pattern=file_to_repair,\n                                )\n                                console.print(f\"    [green]✓ Downloaded successfully[/green]\")\n                                success_count += 1\n                            except ImportError:\n                                print_error(\"modelscope not installed. 
Install: pip install modelscope\")\n                                break\n                            except Exception as e:\n                                console.print(f\"    [red]✗ Download failed: {e}[/red]\")\n                finally:\n                    # Restore original HF_ENDPOINT\n                    if use_mirror and selected_model.repo_type == \"huggingface\" and not original_hf_endpoint:\n                        os.environ.pop(\"HF_ENDPOINT\", None)\n                    elif original_hf_endpoint:\n                        os.environ[\"HF_ENDPOINT\"] = original_hf_endpoint\n\n                console.print()\n                if success_count > 0:\n                    print_success(f\"Repaired {success_count}/{len(files_to_download)} files\")\n                    console.print()\n\n                    # Ask if user wants to re-verify\n                    if Confirm.ask(\"Re-verify the model now?\", default=True):\n                        # Re-verify by continuing the loop with the same model\n                        # Only verify the files that were repaired\n                        name = selected_model.name\n                        files_to_verify = files_to_download\n                        continue\n\n\n@app.command(name=\"verify-all\")\ndef verify_all_models() -> None:\n    \"\"\"Verify all models with repo configuration (not yet implemented).\"\"\"\n    from kt_kernel.cli.utils.user_model_registry import UserModelRegistry\n\n    registry = UserModelRegistry()\n    models = registry.list_models()\n\n    # Filter models with repo configuration\n    models_with_repo = [m for m in models if m.repo_id]\n\n    if not models_with_repo:\n        print_warning(t(\"model_verify_all_no_repos\"))\n        console.print()\n        console.print(f\"  {t('model_verify_all_config_hint')} [cyan]kt model edit <name>[/cyan]\")\n        console.print()\n        return\n\n    console.print()\n    print_warning(t(\"model_verify_not_implemented\"))\n    
console.print()\n    console.print(f\"  {t('model_verify_all_found', count=len(models_with_repo))}\")\n    console.print()\n\n    for model in models_with_repo:\n        console.print(f\"  • {model.name} ({model.repo_type}:{model.repo_id})\")\n\n    console.print()\n    console.print(f\"  [dim]{t('model_verify_future_note')}[/dim]\")\n    console.print()\n    console.print(f\"  {t('model_verify_all_manual_hint')} [cyan]kt model verify <name>[/cyan]\")\n    console.print()\n\n\n@app.command(name=\"auto-repo\")\ndef auto_detect_repo(\n    apply: bool = typer.Option(\n        False, \"--apply\", \"-a\", help=\"Automatically apply detected repo information without confirmation\"\n    ),\n    dry_run: bool = typer.Option(\n        False, \"--dry-run\", \"-d\", help=\"Show what would be detected without making any changes\"\n    ),\n) -> None:\n    \"\"\"\n    Auto-detect repository information from model README.md files.\n\n    Scans all models without repo_id (safetensors/gguf only) and attempts to\n    extract repository information from README.md metadata (license_link field).\n\n    Examples:\n        kt model auto-repo              # Scan and ask for confirmation\n        kt model auto-repo --apply      # Scan and apply automatically\n        kt model auto-repo --dry-run    # Scan only, no changes\n    \"\"\"\n    from kt_kernel.cli.utils.user_model_registry import UserModelRegistry\n    from kt_kernel.cli.utils.repo_detector import scan_models_for_repo, format_detection_report, apply_detection_results\n    from rich.table import Table\n\n    console.print()\n    print_info(\"Scanning models for repository information...\")\n    console.print()\n\n    # Get all models\n    registry = UserModelRegistry()\n    models = registry.list_models()\n\n    if not models:\n        print_warning(\"No models found in registry\")\n        console.print()\n        return\n\n    # Scan for repo information\n    print_step(\"Analyzing README.md files...\")\n    results = 
scan_models_for_repo(models)\n\n    # Show results\n    console.print()\n\n    if not results[\"detected\"] and not results[\"not_detected\"]:\n        print_info(\"All models already have repository information configured\")\n        console.print()\n        return\n\n    # Create results table\n    if results[\"detected\"]:\n        console.print(\"[bold green]✓ Detected Repository Information[/bold green]\")\n        console.print()\n\n        table = Table(show_header=True, header_style=\"bold cyan\")\n        table.add_column(\"Model Name\", style=\"yellow\")\n        table.add_column(\"Repository\", style=\"cyan\")\n        table.add_column(\"Type\", style=\"magenta\")\n\n        for model, repo_id, repo_type in results[\"detected\"]:\n            table.add_row(model.name, repo_id, repo_type)\n\n        console.print(table)\n        console.print()\n\n    if results[\"not_detected\"]:\n        console.print(\n            f\"[bold yellow]✗ No Repository Information Found ({len(results['not_detected'])} models)[/bold yellow]\"\n        )\n        console.print()\n\n        for model in results[\"not_detected\"][:5]:  # Show first 5\n            console.print(f\"  • {model.name}\")\n\n        if len(results[\"not_detected\"]) > 5:\n            console.print(f\"  ... 
and {len(results['not_detected']) - 5} more\")\n\n        console.print()\n\n    if results[\"skipped\"]:\n        console.print(\n            f\"[dim]⊘ Skipped {len(results['skipped'])} models (already configured or not safetensors/gguf)[/dim]\"\n        )\n        console.print()\n\n    # Summary\n    console.print(\"[bold]Summary:[/bold]\")\n    console.print(f\"  • [green]{len(results['detected'])}[/green] detected\")\n    console.print(f\"  • [yellow]{len(results['not_detected'])}[/yellow] not detected\")\n    console.print(f\"  • [dim]{len(results['skipped'])}[/dim] skipped\")\n    console.print()\n\n    # Exit if dry run or no detections\n    if dry_run:\n        print_info(\"Dry run mode - no changes made\")\n        console.print()\n        return\n\n    if not results[\"detected\"]:\n        console.print()\n        return\n\n    # Ask for confirmation (unless --apply flag)\n    if not apply:\n        console.print()\n        if not confirm(f\"Apply repository information to {len(results['detected'])} model(s)?\", default=False):\n            print_warning(\"Cancelled - no changes made\")\n            console.print()\n            return\n\n    # Apply changes\n    console.print()\n    print_step(\"Applying changes...\")\n\n    updated_count = apply_detection_results(results, registry)\n\n    console.print()\n    if updated_count > 0:\n        print_success(f\"✓ Updated {updated_count} model(s) with repository information\")\n        console.print()\n        console.print(\"  You can now:\")\n        console.print(\"  • Run [cyan]kt model verify <name>[/cyan] to verify model integrity\")\n        console.print(\"  • Check status with [cyan]kt model list[/cyan]\")\n        console.print()\n    else:\n        print_error(\"Failed to update models\")\n        console.print()\n"
  },
  {
    "path": "kt-kernel/python/cli/commands/quant.py",
    "content": "\"\"\"\nQuant command for kt-cli.\n\nQuantizes model weights for CPU inference.\n\"\"\"\n\nimport subprocess\nimport sys\nfrom enum import Enum\nfrom pathlib import Path\nfrom typing import Optional\n\nimport typer\n\nfrom kt_kernel.cli.config.settings import get_settings\nfrom kt_kernel.cli.i18n import t\nfrom kt_kernel.cli.utils.console import (\n    confirm,\n    console,\n    create_progress,\n    print_error,\n    print_info,\n    print_step,\n    print_success,\n    print_warning,\n)\nfrom kt_kernel.cli.utils.environment import detect_cpu_info\n\n\nclass QuantMethod(str, Enum):\n    \"\"\"Quantization method.\"\"\"\n\n    INT4 = \"int4\"\n    INT8 = \"int8\"\n\n\ndef quant(\n    model: Optional[str] = typer.Argument(\n        None,\n        help=\"Model name or path to quantize\",\n    ),\n    method: Optional[QuantMethod] = typer.Option(\n        None,\n        \"--method\",\n        \"-m\",\n        help=\"Quantization method\",\n    ),\n    output: Optional[Path] = typer.Option(\n        None,\n        \"--output\",\n        \"-o\",\n        help=\"Output path for quantized weights\",\n    ),\n    input_type: Optional[str] = typer.Option(\n        None,\n        \"--input-type\",\n        \"-i\",\n        help=\"Input weight type (fp8, fp16, bf16)\",\n    ),\n    cpu_threads: Optional[int] = typer.Option(\n        None,\n        \"--cpu-threads\",\n        help=\"Number of CPU threads for quantization\",\n    ),\n    numa_nodes: Optional[int] = typer.Option(\n        None,\n        \"--numa-nodes\",\n        help=\"Number of NUMA nodes\",\n    ),\n    no_merge: bool = typer.Option(\n        False,\n        \"--no-merge\",\n        help=\"Don't merge safetensor files\",\n    ),\n    gpu: bool = typer.Option(\n        False,\n        \"--gpu\",\n        help=\"Use GPU for conversion (faster)\",\n    ),\n    yes: bool = typer.Option(\n        False,\n        \"--yes\",\n        \"-y\",\n        help=\"Skip confirmation prompts\",\n    ),\n) -> 
None:\n    \"\"\"Quantize model weights for CPU inference.\n\n    If no model is specified, interactive mode will be activated.\n    \"\"\"\n    settings = get_settings()\n\n    # Check if we should use interactive mode\n    # Interactive mode triggers when: no model, or missing critical parameters\n    needs_interactive = model is None or method is None or cpu_threads is None or numa_nodes is None\n    is_interactive = False\n\n    if needs_interactive and sys.stdin.isatty():\n        # Use interactive configuration (includes verification in Step 1.5)\n        from kt_kernel.cli.utils.quant_interactive import interactive_quant_config\n\n        console.print()\n        console.print(f\"[bold cyan]═══ {t('quant_interactive_title')} ═══[/bold cyan]\")\n        console.print()\n        console.print(f\"[yellow]{t('quant_new_model_notice')}[/yellow]\")\n        console.print()\n\n        config = interactive_quant_config()\n        if config is None:\n            # User cancelled\n            raise typer.Exit(0)\n\n        # Extract configuration\n        model_obj = config[\"model\"]\n        model = model_obj.id\n        input_path = Path(model_obj.path)\n        method = QuantMethod(config[\"method\"])\n        input_type = config[\"input_type\"]\n        cpu_threads = config[\"cpu_threads\"]\n        numa_nodes = config[\"numa_nodes\"]\n        output = config[\"output_path\"]\n        gpu = config[\"use_gpu\"]\n        is_interactive = True\n\n        console.print()\n        print_success(t(\"quant_config_complete\"))\n        console.print()\n    else:\n        # Non-interactive mode - require model parameter\n        if model is None:\n            print_error(\"Model argument is required in non-interactive mode\")\n            console.print()\n            console.print(\"Usage: kt quant <model>\")\n            console.print(\"   Or: kt quant  (for interactive mode)\")\n            raise typer.Exit(1)\n\n        # Set defaults for optional parameters\n        
method = method or QuantMethod.INT4\n        input_type = input_type or \"fp8\"\n\n        console.print()\n\n        # Resolve input path\n        input_path = _resolve_input_path(model, settings)\n        if input_path is None:\n            print_error(t(\"quant_input_not_found\", path=model))\n            raise typer.Exit(1)\n\n        # Pre-quantization verification (only in non-interactive mode)\n        # Interactive mode already did verification in interactive_quant_config()\n        from kt_kernel.cli.utils.user_model_registry import UserModelRegistry\n        from kt_kernel.cli.utils.model_verifier import pre_operation_verification\n\n        user_registry = UserModelRegistry()\n        user_model_obj = user_registry.find_by_path(str(input_path))\n\n        if user_model_obj and user_model_obj.format == \"safetensors\":\n            pre_operation_verification(user_model_obj, user_registry, operation_name=\"quantizing\")\n\n    # Get user model info for both modes (needed later for registering quantized model)\n    from kt_kernel.cli.utils.user_model_registry import UserModelRegistry\n\n    user_registry = UserModelRegistry()\n    user_model_obj = user_registry.find_by_path(str(input_path))\n\n    # Validate that it's a MoE model (not AMX or GGUF)\n    from kt_kernel.cli.commands.model import is_amx_weights\n\n    # Check if it's AMX (already quantized)\n    is_amx, _ = is_amx_weights(str(input_path))\n    if is_amx:\n        print_error(\"Cannot quantize AMX models (already quantized)\")\n        console.print()\n        console.print(f\"  The model at {input_path} is already in AMX format.\")\n        raise typer.Exit(1)\n\n    # Check if it's a MoE model\n    from kt_kernel.cli.utils.analyze_moe_model import analyze_moe_model\n\n    moe_result = None  # Store for later use when registering quantized model\n    try:\n        moe_result = analyze_moe_model(str(input_path), use_cache=True)\n        if not moe_result or not moe_result.get(\"is_moe\"):\n      
      print_error(\"Only MoE models can be quantized to AMX format\")\n            console.print()\n            console.print(f\"  The model at {input_path} is not a MoE model.\")\n            console.print(\"  AMX quantization is designed for MoE models (e.g., DeepSeek-V3).\")\n            raise typer.Exit(1)\n    except typer.Exit:\n        # typer.Exit subclasses RuntimeError; re-raise so the generic handler below does not swallow it\n        raise\n    except Exception as e:\n        print_warning(f\"Could not detect MoE information: {e}\")\n        console.print()\n        if not yes:\n            if not confirm(\"Continue quantization anyway?\", default=False):\n                raise typer.Exit(1)\n\n    # Detect CPU configuration and resolve output path (only needed in non-interactive mode)\n    if not is_interactive:\n        print_info(t(\"quant_input_path\", path=str(input_path)))\n\n        # Detect CPU configuration (needed for output path)\n        cpu = detect_cpu_info()\n        final_cpu_threads = cpu_threads or cpu.cores\n        final_numa_nodes = numa_nodes or cpu.numa_nodes\n\n        # Resolve output path\n        if output is None:\n            # Priority: paths.weights > paths.models[0] > model's parent directory\n            weights_dir = settings.weights_dir\n\n            if weights_dir and weights_dir.exists():\n                # Use configured weights directory (highest priority)\n                output = weights_dir / f\"{input_path.name}-AMX{method.value.upper()}-NUMA{final_numa_nodes}\"\n            else:\n                # Use first model storage path\n                model_paths = settings.get_model_paths()\n                if model_paths and model_paths[0].exists():\n                    output = model_paths[0] / f\"{input_path.name}-AMX{method.value.upper()}-NUMA{final_numa_nodes}\"\n                else:\n                    # Fallback to model's parent directory\n                    output = input_path.parent / f\"{input_path.name}-AMX{method.value.upper()}-NUMA{final_numa_nodes}\"\n\n        print_info(t(\"quant_output_path\", path=str(output)))\n        
print_info(t(\"quant_method\", method=method.value.upper()))\n        print_info(t(\"quant_cpu_threads\", threads=final_cpu_threads))\n        print_info(t(\"quant_numa_nodes\", nodes=final_numa_nodes))\n\n        # Calculate space requirements\n        console.print()\n        console.print(f\"[bold cyan]{t('quant_disk_analysis')}[/bold cyan]\")\n        console.print()\n\n        # Calculate source model size\n        try:\n            total_bytes = sum(f.stat().st_size for f in input_path.glob(\"*.safetensors\") if f.is_file())\n            source_size_gb = total_bytes / (1024**3)\n        except Exception:\n            source_size_gb = 0.0\n\n        # Estimate quantized size\n        input_bits = {\"fp8\": 8, \"fp16\": 16, \"bf16\": 16}\n        quant_bits = {\"int4\": 4, \"int8\": 8}\n        input_bit = input_bits.get(input_type, 16)\n        quant_bit = quant_bits.get(method.value, 4)\n        ratio = quant_bit / input_bit\n        estimated_size_gb = source_size_gb * ratio\n\n        # Check available space\n        import shutil\n\n        try:\n            check_path = output.parent if not output.exists() else output\n            while not check_path.exists() and check_path != check_path.parent:\n                check_path = check_path.parent\n            stat = shutil.disk_usage(check_path)\n            available_gb = stat.free / (1024**3)\n        except Exception:\n            available_gb = 0.0\n\n        is_sufficient = available_gb >= (estimated_size_gb * 1.2)\n\n        console.print(f\"  {t('quant_source_size'):<26} {source_size_gb:.2f} GB\")\n        console.print(f\"  {t('quant_estimated_size'):<26} {estimated_size_gb:.2f} GB\")\n        console.print(f\"  {t('quant_available_space'):<26} {available_gb:.2f} GB\")\n        console.print()\n\n        if not is_sufficient:\n            required_with_buffer = estimated_size_gb * 1.2\n            print_warning(t(\"quant_insufficient_space\"))\n            console.print()\n            
console.print(f\"  {t('quant_required_space'):<26} {required_with_buffer:.2f} GB\")\n            console.print(f\"  {t('quant_available_space'):<26} {available_gb:.2f} GB\")\n            console.print(f\"  {t('quant_shortage'):<26} {required_with_buffer - available_gb:.2f} GB\")\n            console.print()\n            console.print(f\"  {t('quant_may_fail')}\")\n            console.print()\n\n            if not yes:\n                if not confirm(t(\"quant_continue_anyway\"), default=False):\n                    raise typer.Abort()\n            console.print()\n\n        # Check if output exists and generate unique name\n        if output.exists():\n            print_warning(t(\"quant_output_exists\", path=str(output)))\n            console.print()\n\n            # Generate unique name by adding suffix\n            original_name = output.name\n            parent_dir = output.parent\n            counter = 2\n\n            while output.exists():\n                new_name = f\"{original_name}-{counter}\"\n                output = parent_dir / new_name\n                counter += 1\n\n            print_success(t(\"quant_using_unique\", path=str(output)))\n            console.print()\n\n        # Confirm (only show if not using --yes flag)\n        if not yes:\n            console.print()\n            print_warning(t(\"quant_time_warning\"))\n            console.print()\n\n            if not confirm(t(\"prompt_continue\")):\n                raise typer.Abort()\n    else:\n        # Interactive mode: cpu_threads and numa_nodes already set\n        final_cpu_threads = cpu_threads\n        final_numa_nodes = numa_nodes\n\n    # Find conversion script\n    kt_kernel_path = _find_kt_kernel_path()\n    if kt_kernel_path is None:\n        print_error(\"kt-kernel not found. 
Install with: kt install inference\")\n        raise typer.Exit(1)\n\n    script_path = kt_kernel_path / \"scripts\" / \"convert_cpu_weights.py\"\n    if not script_path.exists():\n        print_error(f\"Conversion script not found: {script_path}\")\n        raise typer.Exit(1)\n\n    # Build command\n    cmd = [\n        sys.executable,\n        str(script_path),\n        \"--input-path\",\n        str(input_path),\n        \"--input-type\",\n        input_type,\n        \"--output\",\n        str(output),\n        \"--quant-method\",\n        method.value,\n        \"--cpuinfer-threads\",\n        str(final_cpu_threads),\n        \"--threadpool-count\",\n        str(final_numa_nodes),\n    ]\n\n    if no_merge:\n        cmd.append(\"--no-merge-safetensor\")\n\n    if gpu:\n        cmd.append(\"--gpu\")\n\n    # Run quantization\n    console.print()\n    print_step(t(\"quant_starting\"))\n    console.print()\n    console.print(f\"[dim]$ {' '.join(cmd)}[/dim]\")\n    console.print()\n    console.print(\"[dim]\" + \"=\" * 80 + \"[/dim]\")\n    console.print()\n\n    try:\n        # Run with real-time stdout/stderr output\n        import os\n        import time\n\n        env = os.environ.copy()\n        env[\"PYTHONUNBUFFERED\"] = \"1\"  # Disable Python output buffering\n\n        # Record start time\n        start_time = time.time()\n\n        process = subprocess.run(\n            cmd,\n            stdout=None,  # Inherit parent's stdout (real-time output)\n            stderr=None,  # Inherit parent's stderr (real-time output)\n            env=env,\n        )\n\n        # Calculate elapsed time\n        elapsed_time = time.time() - start_time\n        hours = int(elapsed_time // 3600)\n        minutes = int((elapsed_time % 3600) // 60)\n        seconds = int(elapsed_time % 60)\n\n        console.print()\n        console.print(\"[dim]\" + \"=\" * 80 + \"[/dim]\")\n        console.print()\n\n        if process.returncode == 0:\n            
print_success(t(\"quant_complete\"))\n            console.print()\n\n            # Display elapsed time\n            if hours > 0:\n                time_str = f\"{hours}h {minutes}m {seconds}s\"\n            elif minutes > 0:\n                time_str = f\"{minutes}m {seconds}s\"\n            else:\n                time_str = f\"{seconds}s\"\n            console.print(f\"  [cyan]{t('quant_time_elapsed')} {time_str}[/cyan]\")\n            console.print()\n            console.print(f\"  Quantized weights saved to: {output}\")\n            console.print()\n\n            # Auto-register the quantized model\n            try:\n                from kt_kernel.cli.utils.user_model_registry import UserModel\n\n                # Generate model name from output path\n                base_name = output.name\n                suggested_name = user_registry.suggest_name(base_name)\n\n                # Determine MoE information and source model name\n                if user_model_obj:\n                    is_moe_val = user_model_obj.is_moe\n                    num_experts = user_model_obj.moe_num_experts\n                    num_active = user_model_obj.moe_num_experts_per_tok\n                    repo_type_val = user_model_obj.repo_type\n                    repo_id_val = user_model_obj.repo_id\n                    source_model_name = user_model_obj.name  # Store source model name\n                elif moe_result:\n                    is_moe_val = moe_result.get(\"is_moe\", True)\n                    num_experts = moe_result.get(\"num_experts\")\n                    num_active = moe_result.get(\"num_experts_per_tok\")\n                    repo_type_val = None\n                    repo_id_val = None\n                    source_model_name = input_path.name  # Use folder name as fallback\n                else:\n                    is_moe_val = None\n                    num_experts = None\n                    num_active = None\n                    repo_type_val = None\n                  
  repo_id_val = None\n                    source_model_name = input_path.name  # Use folder name as fallback\n\n                # Create new model entry (AMX format uses \"safetensors\" format, detected by is_amx_weights())\n                new_model = UserModel(\n                    name=suggested_name,\n                    path=str(output),\n                    format=\"safetensors\",  # AMX files are safetensors format\n                    repo_type=repo_type_val,\n                    repo_id=repo_id_val,\n                    sha256_status=\"not_checked\",  # AMX weights don't need verification\n                    # Inherit MoE information from source model\n                    is_moe=is_moe_val,\n                    moe_num_experts=num_experts,\n                    moe_num_experts_per_tok=num_active,\n                    # AMX quantization metadata\n                    amx_source_model=source_model_name,\n                    amx_quant_method=method.value,  # \"int4\" or \"int8\"\n                    amx_numa_nodes=final_numa_nodes,\n                )\n\n                user_registry.add_model(new_model)\n                console.print()\n                print_success(t(\"quant_registered\", name=suggested_name))\n                console.print()\n                console.print(f\"  {t('quant_view_with')} [cyan]kt model list[/cyan]\")\n                console.print(f\"  {t('quant_use_with')}  [cyan]kt run {suggested_name}[/cyan]\")\n                console.print()\n            except Exception as e:\n                # Non-fatal error - quantization succeeded but registration failed\n                console.print()\n                print_warning(t(\"quant_register_failed\", error=str(e)))\n                console.print()\n                console.print(f\"  {t('quant_use_with')}\")\n                console.print(f\"    kt run {model} --weights-path {output}\")\n                console.print()\n        else:\n            print_error(f\"Quantization failed with exit 
code {process.returncode}\")\n            raise typer.Exit(process.returncode)\n\n    except FileNotFoundError as e:\n        print_error(f\"Failed to run quantization: {e}\")\n        raise typer.Exit(1)\n    except KeyboardInterrupt:\n        console.print()\n        print_warning(\"Quantization interrupted.\")\n        raise typer.Exit(130)\n\n\ndef _resolve_input_path(model: str, settings) -> Optional[Path]:\n    \"\"\"Resolve the input model path.\"\"\"\n    # Check if it's already a path\n    path = Path(model)\n    if path.exists() and (path / \"config.json\").exists():\n        return path\n\n    # Search in models directory\n    from kt_kernel.cli.utils.model_registry import get_registry\n\n    registry = get_registry()\n    matches = registry.search(model)\n\n    if matches:\n        model_info = matches[0]\n        # Try to find in all configured model directories\n        model_paths = settings.get_model_paths()\n\n        for models_dir in model_paths:\n            possible_paths = [\n                models_dir / model_info.name,\n                models_dir / model_info.name.lower(),\n                models_dir / model_info.hf_repo.split(\"/\")[-1],\n            ]\n\n            for p in possible_paths:\n                if p.exists() and (p / \"config.json\").exists():\n                    return p\n\n    return None\n\n\ndef _find_kt_kernel_path() -> Optional[Path]:\n    \"\"\"Find the kt-kernel installation path.\"\"\"\n    try:\n        import kt_kernel\n\n        return Path(kt_kernel.__file__).parent.parent\n    except ImportError:\n        pass\n\n    # Check common locations\n    possible_paths = [\n        Path.home() / \"Projects\" / \"ktransformers\" / \"kt-kernel\",\n        Path.cwd().parent / \"kt-kernel\",\n        Path.cwd() / \"kt-kernel\",\n    ]\n\n    for path in possible_paths:\n        if path.exists() and (path / \"scripts\").exists():\n            return path\n\n    return None\n"
  },
  {
    "path": "kt-kernel/python/cli/commands/run.py",
    "content": "\"\"\"\nRun command for kt-cli.\n\nStarts the model inference server using SGLang + kt-kernel.\n\"\"\"\n\nimport os\nimport subprocess\nimport sys\nfrom pathlib import Path\nfrom typing import Optional\n\nimport click\nimport typer\n\nfrom kt_kernel.cli.config.settings import get_settings\nfrom kt_kernel.cli.i18n import t\nfrom kt_kernel.cli.utils.console import (\n    confirm,\n    console,\n    print_api_info,\n    print_error,\n    print_info,\n    print_server_info,\n    print_step,\n    print_success,\n    print_warning,\n    prompt_choice,\n)\nfrom kt_kernel.cli.utils.environment import detect_cpu_info, detect_gpus, detect_ram_gb\nfrom kt_kernel.cli.utils.user_model_registry import UserModelRegistry\n\n\n@click.command(\n    context_settings={\"ignore_unknown_options\": True, \"allow_extra_args\": True},\n    add_help_option=False,  # We'll handle help manually to avoid conflicts\n)\n@click.argument(\"model\", required=False, default=None)\n@click.option(\"--host\", \"-H\", default=None, help=\"Server host address\")\n@click.option(\"--port\", \"-p\", type=int, default=None, help=\"Server port\")\n@click.option(\"--gpu-experts\", type=int, default=None, help=\"Number of GPU experts per layer\")\n@click.option(\"--cpu-threads\", type=int, default=None, help=\"Number of CPU inference threads\")\n@click.option(\"--numa-nodes\", type=int, default=None, help=\"Number of NUMA nodes\")\n@click.option(\n    \"--tensor-parallel-size\", \"--tp\", \"tensor_parallel_size\", type=int, default=None, help=\"Tensor parallel size\"\n)\n@click.option(\"--model-path\", type=click.Path(), default=None, help=\"Custom model path\")\n@click.option(\"--weights-path\", type=click.Path(), default=None, help=\"Custom quantized weights path\")\n@click.option(\"--kt-method\", default=None, help=\"KT quantization method\")\n@click.option(\n    \"--kt-gpu-prefill-threshold\", \"kt_gpu_prefill_threshold\", type=int, default=None, help=\"GPU prefill token 
threshold\"\n)\n@click.option(\"--attention-backend\", default=None, help=\"Attention backend\")\n@click.option(\"--max-total-tokens\", \"max_total_tokens\", type=int, default=None, help=\"Maximum total tokens\")\n@click.option(\"--max-running-requests\", \"max_running_requests\", type=int, default=None, help=\"Maximum running requests\")\n@click.option(\"--chunked-prefill-size\", \"chunked_prefill_size\", type=int, default=None, help=\"Chunked prefill size\")\n@click.option(\"--mem-fraction-static\", \"mem_fraction_static\", type=float, default=None, help=\"Memory fraction static\")\n@click.option(\"--watchdog-timeout\", \"watchdog_timeout\", type=int, default=None, help=\"Watchdog timeout\")\n@click.option(\"--served-model-name\", \"served_model_name\", default=None, help=\"Served model name\")\n@click.option(\n    \"--disable-shared-experts-fusion\",\n    \"disable_shared_experts_fusion\",\n    is_flag=True,\n    default=None,\n    help=\"Disable shared experts fusion\",\n)\n@click.option(\n    \"--enable-shared-experts-fusion\",\n    \"enable_shared_experts_fusion\",\n    is_flag=True,\n    default=False,\n    help=\"Enable shared experts fusion\",\n)\n@click.option(\"--quantize\", \"-q\", is_flag=True, default=False, help=\"Quantize model\")\n@click.option(\"--advanced\", is_flag=True, default=False, help=\"Show advanced options\")\n@click.option(\"--dry-run\", \"dry_run\", is_flag=True, default=False, help=\"Show command without executing\")\n@click.pass_context\ndef run(\n    ctx: click.Context,\n    model: Optional[str],\n    host: Optional[str],\n    port: Optional[int],\n    gpu_experts: Optional[int],\n    cpu_threads: Optional[int],\n    numa_nodes: Optional[int],\n    tensor_parallel_size: Optional[int],\n    model_path: Optional[str],\n    weights_path: Optional[str],\n    kt_method: Optional[str],\n    kt_gpu_prefill_threshold: Optional[int],\n    attention_backend: Optional[str],\n    max_total_tokens: Optional[int],\n    max_running_requests: 
Optional[int],\n    chunked_prefill_size: Optional[int],\n    mem_fraction_static: Optional[float],\n    watchdog_timeout: Optional[int],\n    served_model_name: Optional[str],\n    disable_shared_experts_fusion: Optional[bool],\n    enable_shared_experts_fusion: bool,\n    quantize: bool,\n    advanced: bool,\n    dry_run: bool,\n) -> None:\n    \"\"\"Start model inference server.\n\n    \\b\n    Examples: kt run deepseek-v3 | kt run m2 --tensor-parallel-size 2 | kt run /path/to/model --gpu-experts 4\n\n    \\b\n    Custom Options: Pass any SGLang server option directly (e.g., kt run m2 --fp8-gemm-backend triton).\n    Common: --fp8-gemm-backend, --tool-call-parser, --reasoning-parser, --dp-size, --enable-ma\n    For full list: python -m sglang.launch_server --help\n    \"\"\"\n    # Handle --help manually since we disabled it\n    # Check sys.argv for --help or -h since ctx.args may not be set yet\n    if \"--help\" in sys.argv or \"-h\" in sys.argv:\n        click.echo(ctx.get_help())\n        return\n\n    # Handle disable/enable shared experts fusion flags\n    if enable_shared_experts_fusion:\n        disable_shared_experts_fusion = False\n\n    # Convert Path objects from click\n    model_path_obj = Path(model_path) if model_path else None\n    weights_path_obj = Path(weights_path) if weights_path else None\n\n    # Get extra args that weren't parsed (unknown options)\n    # click stores these in ctx.args when ignore_unknown_options=True\n    extra_cli_args = list(ctx.args) if ctx.args else []\n\n    # Remove --help from extra args if present (already handled)\n    extra_cli_args = [arg for arg in extra_cli_args if arg not in [\"--help\", \"-h\"]]\n\n    # Call the actual run function implementation\n    _run_impl(\n        model=model,\n        host=host,\n        port=port,\n        gpu_experts=gpu_experts,\n        cpu_threads=cpu_threads,\n        numa_nodes=numa_nodes,\n        tensor_parallel_size=tensor_parallel_size,\n        
model_path=model_path_obj,\n        weights_path=weights_path_obj,\n        kt_method=kt_method,\n        kt_gpu_prefill_threshold=kt_gpu_prefill_threshold,\n        attention_backend=attention_backend,\n        max_total_tokens=max_total_tokens,\n        max_running_requests=max_running_requests,\n        chunked_prefill_size=chunked_prefill_size,\n        mem_fraction_static=mem_fraction_static,\n        watchdog_timeout=watchdog_timeout,\n        served_model_name=served_model_name,\n        disable_shared_experts_fusion=disable_shared_experts_fusion,\n        quantize=quantize,\n        advanced=advanced,\n        dry_run=dry_run,\n        extra_cli_args=extra_cli_args,\n    )\n\n\ndef _run_impl(\n    model: Optional[str],\n    host: Optional[str],\n    port: Optional[int],\n    gpu_experts: Optional[int],\n    cpu_threads: Optional[int],\n    numa_nodes: Optional[int],\n    tensor_parallel_size: Optional[int],\n    model_path: Optional[Path],\n    weights_path: Optional[Path],\n    kt_method: Optional[str],\n    kt_gpu_prefill_threshold: Optional[int],\n    attention_backend: Optional[str],\n    max_total_tokens: Optional[int],\n    max_running_requests: Optional[int],\n    chunked_prefill_size: Optional[int],\n    mem_fraction_static: Optional[float],\n    watchdog_timeout: Optional[int],\n    served_model_name: Optional[str],\n    disable_shared_experts_fusion: Optional[bool],\n    quantize: bool,\n    advanced: bool,\n    dry_run: bool,\n    extra_cli_args: list[str],\n) -> None:\n    \"\"\"Actual implementation of run command.\"\"\"\n    # Check if SGLang is installed before proceeding\n    from kt_kernel.cli.utils.sglang_checker import (\n        check_sglang_installation,\n        check_sglang_kt_kernel_support,\n        print_sglang_install_instructions,\n        print_sglang_kt_kernel_instructions,\n    )\n\n    sglang_info = check_sglang_installation()\n    if not sglang_info[\"installed\"]:\n        console.print()\n        
print_error(t(\"sglang_not_found\"))\n        console.print()\n        print_sglang_install_instructions()\n        raise typer.Exit(1)\n\n    # Check if SGLang supports kt-kernel (has --kt-gpu-prefill-token-threshold parameter)\n    kt_kernel_support = check_sglang_kt_kernel_support()\n    if not kt_kernel_support[\"supported\"]:\n        console.print()\n        print_error(t(\"sglang_kt_kernel_not_supported\"))\n        console.print()\n        print_sglang_kt_kernel_instructions()\n        raise typer.Exit(1)\n\n    settings = get_settings()\n    user_registry = UserModelRegistry()\n\n    # Check if we should use interactive mode\n    # Interactive mode triggers when:\n    # 1. No model specified, OR\n    # 2. Model specified but missing critical parameters (gpu_experts, tensor_parallel_size, etc.)\n    use_interactive = False\n\n    if model is None:\n        use_interactive = True\n    elif (\n        gpu_experts is None\n        or tensor_parallel_size is None\n        or cpu_threads is None\n        or numa_nodes is None\n        or max_total_tokens is None\n    ):\n        # Model specified but some parameters missing - use interactive\n        use_interactive = True\n\n    if use_interactive and sys.stdin.isatty():\n        # Use new interactive configuration flow\n        from kt_kernel.cli.utils.run_interactive import interactive_run_config\n\n        console.print()\n        console.print(\"[bold cyan]═══ Interactive Run Configuration ═══[/bold cyan]\")\n        console.print()\n\n        config = interactive_run_config()\n        if config is None:\n            # User cancelled\n            raise typer.Exit(0)\n\n        # Extract configuration from new format\n        user_model_obj = config[\"model\"]\n        model = user_model_obj.id\n        resolved_model_path = Path(config[\"model_path\"])\n        resolved_weights_path = Path(config[\"weights_path\"])\n\n        # Extract parameters\n        gpu_experts = config[\"gpu_experts\"]\n        
cpu_threads = config[\"cpu_threads\"]\n        numa_nodes = config[\"numa_nodes\"]\n        tensor_parallel_size = config[\"tp_size\"]\n\n        # Get kt-method and other method-specific settings\n        kt_method = config[\"kt_method\"]\n\n        # KV cache settings (may be None for non-raw methods)\n        max_total_tokens = config.get(\"kv_cache\", 32768)\n        chunked_prefill_size = config.get(\"chunk_prefill\", 32768)\n        kt_gpu_prefill_threshold = config.get(\"gpu_prefill_threshold\", 500)\n\n        # Memory settings\n        mem_fraction_static = config[\"mem_fraction_static\"]\n\n        # Parser settings (optional)\n        tool_call_parser = config.get(\"tool_call_parser\")\n        reasoning_parser = config.get(\"reasoning_parser\")\n\n        # Server settings\n        host = config.get(\"host\", \"0.0.0.0\")\n        port = config.get(\"port\", 30000)\n\n        # Set CUDA_VISIBLE_DEVICES for selected GPUs\n        selected_gpus = config[\"selected_gpus\"]\n        os.environ[\"CUDA_VISIBLE_DEVICES\"] = \",\".join(str(gpu_id) for gpu_id in selected_gpus)\n\n        # Detect hardware for parameter resolution (needed for resolve() function later)\n        gpus = detect_gpus()\n        cpu = detect_cpu_info()\n\n        console.print()\n        print_info(f\"[green]✓[/green] Configuration complete\")\n        console.print()\n    else:\n        # Non-interactive mode - use traditional flow\n        console.print()\n\n        # Initialize variables that may have been set by interactive mode\n        # These will be None in non-interactive mode and will use defaults via resolve()\n\n        # If no model specified, show old interactive selection\n        if model is None:\n            model = _interactive_model_selection(user_registry, settings)\n            if model is None:\n                raise typer.Exit(0)\n\n        # Detect hardware (needed for defaults)\n        gpus = detect_gpus()\n        cpu = detect_cpu_info()\n        ram = 
detect_ram_gb()\n\n        if gpus:\n            gpu_info = f\"{gpus[0].name} ({gpus[0].vram_gb}GB VRAM)\"\n            if len(gpus) > 1:\n                gpu_info += f\" + {len(gpus) - 1} more\"\n            print_info(t(\"run_gpu_info\", name=gpus[0].name, vram=gpus[0].vram_gb))\n        else:\n            print_warning(t(\"doctor_gpu_not_found\"))\n            gpu_info = \"None\"\n\n        print_info(t(\"run_cpu_info\", name=cpu.name, cores=cpu.cores, numa=cpu.numa_nodes))\n        print_info(t(\"run_ram_info\", total=int(ram)))\n\n        # Step 2: Resolve model\n        console.print()\n        print_step(t(\"run_checking_model\"))\n\n        user_model_obj = None\n        resolved_model_path = model_path\n\n        # Check if model is a path\n        if Path(model).exists():\n            resolved_model_path = Path(model)\n            print_info(t(\"run_model_path\", path=str(resolved_model_path)))\n\n            # Try to find in user registry by path\n            user_model_obj = user_registry.find_by_path(str(resolved_model_path))\n            if user_model_obj:\n                print_info(f\"Using registered model: {user_model_obj.name}\")\n            else:\n                print_warning(\"Using unregistered model path. Consider adding it with 'kt model add'\")\n        else:\n            # Search in user registry by name\n            user_model_obj = user_registry.get_model(model)\n\n            if not user_model_obj:\n                print_error(t(\"run_model_not_found\", name=model))\n                console.print()\n\n                # Show available models\n                all_models = user_registry.list_models()\n                if all_models:\n                    console.print(\"Available registered models:\")\n                    for m in all_models[:5]:\n                        console.print(f\"  - {m.name}\")\n                    if len(all_models) > 5:\n                        console.print(f\"  ... 
and {len(all_models) - 5} more\")\n                else:\n                    console.print(\"No models registered yet.\")\n\n                console.print()\n                console.print(f\"Add your model with: [cyan]kt model add /path/to/model[/cyan]\")\n                console.print(f\"Or scan for models: [cyan]kt model scan[/cyan]\")\n                raise typer.Exit(1)\n\n            # Use model path from registry\n            resolved_model_path = Path(user_model_obj.path)\n\n            # Verify path exists\n            if not resolved_model_path.exists():\n                print_error(f\"Model path does not exist: {resolved_model_path}\")\n                console.print()\n                console.print(f\"Run 'kt model refresh' to check all models\")\n                raise typer.Exit(1)\n\n            print_info(t(\"run_model_path\", path=str(resolved_model_path)))\n\n        # Step 2.5: Pre-run verification (optional integrity check)\n        if user_model_obj and user_model_obj.format == \"safetensors\":\n            from kt_kernel.cli.utils.model_verifier import pre_operation_verification\n\n            pre_operation_verification(user_model_obj, user_registry, operation_name=\"running\")\n\n        # Step 3: Check quantized weights (only if explicitly requested)\n        resolved_weights_path = None\n\n        # Only use quantized weights if explicitly specified by user\n        if weights_path is not None:\n            # User explicitly specified weights path\n            resolved_weights_path = weights_path\n            if not resolved_weights_path.exists():\n                print_error(t(\"run_weights_not_found\"))\n                console.print(f\"  Path: {resolved_weights_path}\")\n                raise typer.Exit(1)\n            print_info(f\"Using quantized weights: {resolved_weights_path}\")\n        elif quantize:\n            # User requested quantization\n            console.print()\n            print_step(t(\"run_quantizing\"))\n            # 
TODO: Implement quantization\n            print_warning(\"Quantization not yet implemented. Please run 'kt quant' manually.\")\n            raise typer.Exit(1)\n        else:\n            # Default: use original precision model without quantization\n            console.print()\n            print_info(\"Using original precision model (no quantization)\")\n\n    # Step 4: Build command\n    # Helper to resolve parameter with fallback chain: CLI > config > default\n    def resolve(cli_val, config_key, default):\n        if cli_val is not None:\n            return cli_val\n        config_val = settings.get(config_key)\n        return config_val if config_val is not None else default\n\n    # Server configuration\n    final_host = resolve(host, \"server.host\", \"0.0.0.0\")\n    final_port = resolve(port, \"server.port\", 30000)\n\n    # Tensor parallel size: CLI > config > auto-detect from GPUs\n    final_tensor_parallel_size = resolve(\n        tensor_parallel_size, \"inference.tensor_parallel_size\", len(gpus) if gpus else 1\n    )\n\n    # CPU/GPU configuration with smart defaults\n    total_threads = cpu.threads  # Use logical threads instead of physical cores\n    final_cpu_threads = resolve(cpu_threads, \"inference.cpu_threads\", int(total_threads * 0.8))\n    final_numa_nodes = resolve(numa_nodes, \"inference.numa_nodes\", cpu.numa_nodes)\n    final_gpu_experts = resolve(gpu_experts, \"inference.gpu_experts\", 1)\n\n    # KT-kernel options\n    final_kt_method = resolve(kt_method, \"inference.kt_method\", \"AMXINT4\")\n    final_kt_gpu_prefill_threshold = resolve(kt_gpu_prefill_threshold, \"inference.kt_gpu_prefill_token_threshold\", 4096)\n\n    # SGLang options\n    final_attention_backend = resolve(attention_backend, \"inference.attention_backend\", \"flashinfer\")\n    final_max_total_tokens = resolve(max_total_tokens, \"inference.max_total_tokens\", 40000)\n    final_max_running_requests = resolve(max_running_requests, \"inference.max_running_requests\", 
32)\n    final_chunked_prefill_size = resolve(chunked_prefill_size, \"inference.chunked_prefill_size\", 4096)\n    final_mem_fraction_static = resolve(mem_fraction_static, \"inference.mem_fraction_static\", 0.98)\n    final_watchdog_timeout = resolve(watchdog_timeout, \"inference.watchdog_timeout\", 3000)\n    final_served_model_name = resolve(served_model_name, \"inference.served_model_name\", \"\")\n\n    # Performance flags\n    final_disable_shared_experts_fusion = resolve(\n        disable_shared_experts_fusion, \"inference.disable_shared_experts_fusion\", True\n    )\n\n    # Pass extra CLI parameters\n    extra_params = {}\n\n    # Parser parameters (from interactive mode or None in non-interactive mode)\n    final_tool_call_parser = None\n    final_reasoning_parser = None\n    if \"tool_call_parser\" in locals() and tool_call_parser:\n        final_tool_call_parser = tool_call_parser\n    if \"reasoning_parser\" in locals() and reasoning_parser:\n        final_reasoning_parser = reasoning_parser\n\n    cmd = _build_sglang_command(\n        model_path=resolved_model_path,\n        weights_path=resolved_weights_path,\n        host=final_host,\n        port=final_port,\n        gpu_experts=final_gpu_experts,\n        cpu_threads=final_cpu_threads,\n        numa_nodes=final_numa_nodes,\n        tensor_parallel_size=final_tensor_parallel_size,\n        kt_method=final_kt_method,\n        kt_gpu_prefill_threshold=final_kt_gpu_prefill_threshold,\n        attention_backend=final_attention_backend,\n        max_total_tokens=final_max_total_tokens,\n        max_running_requests=final_max_running_requests,\n        chunked_prefill_size=final_chunked_prefill_size,\n        mem_fraction_static=final_mem_fraction_static,\n        watchdog_timeout=final_watchdog_timeout,\n        served_model_name=final_served_model_name,\n        disable_shared_experts_fusion=final_disable_shared_experts_fusion,\n        tool_call_parser=final_tool_call_parser,\n        
reasoning_parser=final_reasoning_parser,\n        settings=settings,\n        extra_model_params=extra_params,\n        extra_cli_args=extra_cli_args,\n    )\n\n    # Prepare environment variables\n    env = os.environ.copy()\n    # Add environment variables from advanced.env\n    env.update(settings.get_env_vars())\n    # Add environment variables from inference.env\n    inference_env = settings.get(\"inference.env\", {})\n    if isinstance(inference_env, dict):\n        env.update({k: str(v) for k, v in inference_env.items()})\n\n    # Step 5: Show configuration summary\n    console.print()\n    print_step(\"Configuration\")\n\n    # Display model name\n    model_display_name = user_model_obj.name if user_model_obj else resolved_model_path.name\n    console.print(f\"  Model: [bold]{model_display_name}[/bold]\")\n\n    console.print(f\"  Path: [dim]{resolved_model_path}[/dim]\")\n\n    # Key parameters\n    console.print()\n    console.print(f\"  GPU Experts: [cyan]{final_gpu_experts}[/cyan] per layer\")\n    console.print(f\"  CPU Threads (kt-cpuinfer): [cyan]{final_cpu_threads}[/cyan]\")\n    console.print(f\"  NUMA Nodes (kt-threadpool-count): [cyan]{final_numa_nodes}[/cyan]\")\n    console.print(f\"  Tensor Parallel: [cyan]{final_tensor_parallel_size}[/cyan]\")\n    console.print(f\"  Method: [cyan]{final_kt_method}[/cyan]\")\n    console.print(f\"  Attention: [cyan]{final_attention_backend}[/cyan]\")\n\n    # Weights info\n    if resolved_weights_path:\n        console.print()\n        console.print(f\"  Quantized weights: [yellow]{resolved_weights_path}[/yellow]\")\n\n    console.print()\n    console.print(f\"  Server: [green]http://{final_host}:{final_port}[/green]\")\n    console.print()\n\n    # Step 6: Show or execute\n    if dry_run:\n        console.print()\n        console.print(\"[bold]Command:[/bold]\")\n        console.print()\n        console.print(f\"  [dim]{' '.join(cmd)}[/dim]\")\n        console.print()\n        return\n\n    # Execute with 
prepared environment variables\n    # Don't print \"Server started\" or API info here - let sglang's logs speak for themselves\n    # The actual startup takes time and these messages are misleading\n\n    # Print the command being executed\n    console.print()\n    console.print(\"[bold]Launching server with command:[/bold]\")\n    console.print()\n    console.print(f\"  [dim]{' '.join(cmd)}[/dim]\")\n    console.print()\n\n    try:\n        # Execute directly without intercepting output or signals\n        # This allows direct output to terminal and Ctrl+C to work naturally\n        process = subprocess.run(cmd, env=env)\n        sys.exit(process.returncode)\n\n    except FileNotFoundError:\n        from kt_kernel.cli.utils.sglang_checker import print_sglang_install_instructions\n\n        print_error(t(\"sglang_not_found\"))\n        console.print()\n        print_sglang_install_instructions()\n        raise typer.Exit(1)\n    except Exception as e:\n        print_error(f\"Failed to start server: {e}\")\n        raise typer.Exit(1)\n\n\n# Dead code removed: _find_model_path() and _find_weights_path()\n# These functions were part of the old builtin model system\n\n\ndef _build_sglang_command(\n    model_path: Path,\n    weights_path: Optional[Path],\n    host: str,\n    port: int,\n    gpu_experts: int,\n    cpu_threads: int,\n    numa_nodes: int,\n    tensor_parallel_size: int,\n    kt_method: str,\n    kt_gpu_prefill_threshold: int,\n    attention_backend: str,\n    max_total_tokens: int,\n    max_running_requests: int,\n    chunked_prefill_size: int,\n    mem_fraction_static: float,\n    watchdog_timeout: int,\n    served_model_name: str,\n    disable_shared_experts_fusion: bool,\n    tool_call_parser: Optional[str],\n    reasoning_parser: Optional[str],\n    settings,\n    extra_model_params: Optional[dict] = None,  # New parameter for additional params\n    extra_cli_args: Optional[list[str]] = None,  # Extra args from CLI to pass to sglang\n) -> list[str]:\n 
   \"\"\"Build the SGLang launch command.\"\"\"\n    cmd = [\n        sys.executable,\n        \"-m\",\n        \"sglang.launch_server\",\n        \"--host\",\n        host,\n        \"--port\",\n        str(port),\n        \"--model\",\n        str(model_path),\n    ]\n\n    # Add kt-kernel options\n    # kt-kernel is needed for:\n    # 1. Quantized models (when weights_path is provided)\n    # 2. MoE models with CPU offloading (when kt-cpuinfer > 0 or kt-num-gpu-experts is configured)\n    use_kt_kernel = False\n\n    # Check if we should use kt-kernel\n    if weights_path:\n        # Quantized model - always use kt-kernel\n        use_kt_kernel = True\n    elif cpu_threads > 0 or gpu_experts > 1:\n        # CPU offloading configured - use kt-kernel\n        use_kt_kernel = True\n\n    if use_kt_kernel:\n        # Add kt-weight-path: use quantized weights if available, otherwise use model path\n        weight_path_to_use = weights_path if weights_path else model_path\n\n        # Add kt-kernel configuration\n        cmd.extend(\n            [\n                \"--kt-weight-path\",\n                str(weight_path_to_use),\n                \"--kt-cpuinfer\",\n                str(cpu_threads),\n                \"--kt-threadpool-count\",\n                str(numa_nodes),\n                \"--kt-num-gpu-experts\",\n                str(gpu_experts),\n                \"--kt-method\",\n                kt_method,\n                \"--kt-gpu-prefill-token-threshold\",\n                str(kt_gpu_prefill_threshold),\n                \"--kt-enable-dynamic-expert-update\",  # Enable dynamic expert updates\n            ]\n        )\n\n    # Add SGLang options\n    cmd.extend(\n        [\n            \"--attention-backend\",\n            attention_backend,\n            \"--trust-remote-code\",\n            \"--mem-fraction-static\",\n            str(mem_fraction_static),\n            \"--chunked-prefill-size\",\n            str(chunked_prefill_size),\n            
\"--max-running-requests\",\n            str(max_running_requests),\n            \"--max-total-tokens\",\n            str(max_total_tokens),\n            \"--watchdog-timeout\",\n            str(watchdog_timeout),\n            \"--enable-mixed-chunk\",\n            \"--tensor-parallel-size\",\n            str(tensor_parallel_size),\n            \"--enable-p2p-check\",\n        ]\n    )\n\n    # Add served model name if specified\n    if served_model_name:\n        cmd.extend([\"--served-model-name\", served_model_name])\n\n    # Add performance flags\n    if disable_shared_experts_fusion:\n        cmd.append(\"--disable-shared-experts-fusion\")\n\n    # Add FP8 backend if using FP8 method\n    if \"FP8\" in kt_method.upper():\n        cmd.extend([\"--fp8-gemm-backend\", \"triton\"])\n\n    # Add parsers if specified\n    if tool_call_parser:\n        cmd.extend([\"--tool-call-parser\", tool_call_parser])\n    if reasoning_parser:\n        cmd.extend([\"--reasoning-parser\", reasoning_parser])\n\n    # Add any extra parameters from model defaults that weren't explicitly handled\n    if extra_model_params:\n        # List of parameters already handled above\n        handled_params = {\n            \"kt-num-gpu-experts\",\n            \"kt-cpuinfer\",\n            \"kt-threadpool-count\",\n            \"kt-method\",\n            \"kt-gpu-prefill-token-threshold\",\n            \"attention-backend\",\n            \"tensor-parallel-size\",\n            \"max-total-tokens\",\n            \"max-running-requests\",\n            \"chunked-prefill-size\",\n            \"mem-fraction-static\",\n            \"watchdog-timeout\",\n            \"served-model-name\",\n            \"disable-shared-experts-fusion\",\n        }\n\n        for key, value in extra_model_params.items():\n            if key not in handled_params:\n                # Add unhandled parameters dynamically\n                cmd.append(f\"--{key}\")\n                if isinstance(value, bool):\n                
    # Boolean flags don't need a value\n                    if not value:\n                        # For False boolean, skip the flag entirely\n                        cmd.pop()  # Remove the flag we just added\n                else:\n                    cmd.append(str(value))\n\n    # Add extra args from settings\n    extra_args = settings.get(\"advanced.sglang_args\", [])\n    if extra_args:\n        cmd.extend(extra_args)\n\n    # Add extra CLI args (user-provided options not defined in kt CLI)\n    if extra_cli_args:\n        cmd.extend(extra_cli_args)\n\n    return cmd\n\n\ndef _interactive_model_selection(user_registry, settings) -> Optional[str]:\n    \"\"\"Show interactive model selection interface.\n\n    Returns:\n        Selected model name or None if cancelled.\n    \"\"\"\n    from rich.panel import Panel\n    from rich.prompt import Prompt\n\n    # Get all user models\n    all_models = user_registry.list_models()\n\n    if not all_models:\n        console.print()\n        print_warning(\"No models registered.\")\n        console.print()\n        console.print(f\"  Add models with: [cyan]kt model scan[/cyan]\")\n        console.print(f\"  Or manually: [cyan]kt model add /path/to/model[/cyan]\")\n        console.print()\n        return None\n\n    console.print()\n    console.print(\n        Panel.fit(\n            \"Select a model to run\",\n            border_style=\"cyan\",\n        )\n    )\n    console.print()\n\n    # Build choices list\n    choices = []\n    choice_map = {}  # index -> model name\n\n    # Show all user models\n    console.print(f\"[bold green]Available Models:[/bold green]\")\n    console.print()\n\n    for i, model in enumerate(all_models, 1):\n        # Check if path exists\n        path_status = \"✓\" if model.path_exists() else \"✗ Missing\"\n        console.print(f\"  [cyan][{i}][/cyan] [bold]{model.name}[/bold] [{path_status}]\")\n        console.print(f\"      [dim]{model.format} - {model.path}[/dim]\")\n        
choices.append(str(i))\n        choice_map[str(i)] = model.name\n\n    console.print()\n\n    # Add cancel option\n    cancel_idx = str(len(choices) + 1)\n    console.print(f\"  [cyan][{cancel_idx}][/cyan] [dim]Cancel[/dim]\")\n    choices.append(cancel_idx)\n    console.print()\n\n    # Prompt for selection\n    try:\n        selection = Prompt.ask(\n            \"Select model\",\n            choices=choices,\n            default=\"1\" if choices else cancel_idx,\n        )\n    except KeyboardInterrupt:\n        console.print()\n        return None\n\n    if selection == cancel_idx:\n        return None\n\n    return choice_map.get(selection)\n"
  },
  {
    "path": "kt-kernel/python/cli/commands/sft.py",
    "content": "\"\"\"\nSFT command for kt-cli.\n\nFine-tuning with LlamaFactory integration.\n\"\"\"\n\nimport typer\n\nfrom kt_kernel.cli.i18n import t\nfrom kt_kernel.cli.utils.console import console\n\napp = typer.Typer(help=\"Fine-tuning with LlamaFactory (coming soon)\")\n\n\n@app.callback(invoke_without_command=True)\ndef callback(ctx: typer.Context) -> None:\n    \"\"\"Fine-tuning commands (coming soon).\"\"\"\n    if ctx.invoked_subcommand is None:\n        console.print()\n        console.print(f\"[yellow]{t('feature_coming_soon')}[/yellow]\")\n        console.print()\n        console.print(\"[dim]kt sft train   - Train a model[/dim]\")\n        console.print(\"[dim]kt sft chat    - Chat with a trained model[/dim]\")\n        console.print(\"[dim]kt sft export  - Export a trained model[/dim]\")\n        console.print()\n\n\n@app.command(name=\"train\")\ndef train() -> None:\n    \"\"\"Train a model using LlamaFactory (coming soon).\"\"\"\n    console.print()\n    console.print(f\"[yellow]{t('feature_coming_soon')}[/yellow]\")\n    console.print()\n    raise typer.Exit(0)\n\n\n@app.command(name=\"chat\")\ndef chat() -> None:\n    \"\"\"Chat with a trained model using LlamaFactory (coming soon).\"\"\"\n    console.print()\n    console.print(f\"[yellow]{t('feature_coming_soon')}[/yellow]\")\n    console.print()\n    raise typer.Exit(0)\n\n\n@app.command(name=\"export\")\ndef export() -> None:\n    \"\"\"Export a trained model using LlamaFactory (coming soon).\"\"\"\n    console.print()\n    console.print(f\"[yellow]{t('feature_coming_soon')}[/yellow]\")\n    console.print()\n    raise typer.Exit(0)\n"
  },
  {
    "path": "kt-kernel/python/cli/commands/version.py",
    "content": "\"\"\"\nVersion command for kt-cli.\n\nDisplays version information for kt-cli and related packages.\n\"\"\"\n\nimport platform\nfrom typing import Optional\n\nimport typer\n\nfrom kt_kernel.cli import __version__\nfrom kt_kernel.cli.i18n import t\nfrom kt_kernel.cli.utils.console import console, print_version_table\nfrom kt_kernel.cli.utils.environment import detect_cuda_version, get_installed_package_version\n\n\ndef _get_sglang_info() -> str:\n    \"\"\"Get sglang-kt version and installation source information.\"\"\"\n    from kt_kernel.cli.utils.sglang_checker import check_sglang_installation\n\n    info = check_sglang_installation()\n\n    if not info[\"installed\"]:\n        return t(\"version_not_installed\")\n\n    # Get version from package metadata (prefer sglang-kt)\n    version = get_installed_package_version(\"sglang-kt\")\n    if not version:\n        version = get_installed_package_version(\"sglang\")\n    if not version:\n        version = info.get(\"version\") or \"unknown\"\n\n    # Determine source label\n    if info.get(\"is_kvcache_fork\"):\n        if info[\"from_source\"] and info.get(\"git_info\"):\n            git_remote = info[\"git_info\"].get(\"remote\", \"\")\n            return f\"{version} [dim](Source: {git_remote})[/dim]\"\n        elif info[\"editable\"]:\n            return f\"{version} [dim](editable)[/dim]\"\n        else:\n            return f\"{version} [dim](sglang-kt)[/dim]\"\n    elif info[\"from_source\"]:\n        if info.get(\"git_info\"):\n            git_remote = info[\"git_info\"].get(\"remote\", \"\")\n            return f\"{version} [dim](Source: {git_remote})[/dim]\"\n        return f\"{version} [dim](source)[/dim]\"\n    else:\n        return f\"{version} [dim](PyPI)[/dim]\"\n\n\ndef version(\n    verbose: bool = typer.Option(False, \"--verbose\", \"-v\", help=\"Show detailed version info\"),\n) -> None:\n    \"\"\"Show version information.\"\"\"\n    
console.print(f\"\\n[bold]{t('version_info')}[/bold] v{__version__}\\n\")\n\n    # Basic info\n    versions = {\n        t(\"version_python\"): platform.python_version(),\n        t(\"version_platform\"): f\"{platform.system()} {platform.release()}\",\n    }\n\n    # CUDA version\n    cuda_version = detect_cuda_version()\n    versions[t(\"version_cuda\")] = cuda_version or t(\"version_cuda_not_found\")\n\n    print_version_table(versions)\n\n    # Always show key packages with installation source\n    console.print(\"\\n[bold]Packages:[/bold]\\n\")\n\n    sglang_info = _get_sglang_info()\n    key_packages = {\n        t(\"version_kt_kernel\"): get_installed_package_version(\"kt-kernel\") or t(\"version_not_installed\"),\n        t(\"version_sglang\"): sglang_info,\n    }\n\n    print_version_table(key_packages)\n\n    # Show SGLang installation hint if not installed\n    if sglang_info == t(\"version_not_installed\"):\n        from kt_kernel.cli.utils.sglang_checker import print_sglang_install_instructions\n\n        console.print()\n        print_sglang_install_instructions()\n\n    if verbose:\n        console.print(\"\\n[bold]Additional Packages:[/bold]\\n\")\n\n        package_versions = {\n            t(\"version_ktransformers\"): get_installed_package_version(\"ktransformers\") or t(\"version_not_installed\"),\n            t(\"version_llamafactory\"): get_installed_package_version(\"llamafactory\") or t(\"version_not_installed\"),\n            \"typer\": get_installed_package_version(\"typer\") or t(\"version_not_installed\"),\n            \"rich\": get_installed_package_version(\"rich\") or t(\"version_not_installed\"),\n            \"torch\": get_installed_package_version(\"torch\") or t(\"version_not_installed\"),\n            \"transformers\": get_installed_package_version(\"transformers\") or t(\"version_not_installed\"),\n        }\n\n        print_version_table(package_versions)\n\n    console.print()\n"
  },
  {
    "path": "kt-kernel/python/cli/completions/__init__.py",
    "content": "\"\"\"Shell completion scripts for kt-cli.\"\"\"\n"
  },
  {
    "path": "kt-kernel/python/cli/completions/_kt",
    "content": "#compdef kt\n# Zsh completion for kt command\n# This is a static completion script that doesn't require Python startup\n\n_kt() {\n    local -a commands\n    commands=(\n        'version:Show version information'\n        'run:Start model inference server'\n        'chat:Interactive chat with running model'\n        'quant:Quantize model weights'\n        'bench:Run full benchmark'\n        'microbench:Run micro-benchmark'\n        'doctor:Diagnose environment issues'\n        'model:Manage models and storage paths'\n        'config:Manage configuration'\n        'sft:Fine-tuning with LlamaFactory'\n    )\n\n    local -a run_opts\n    run_opts=(\n        '--host[Server host]:host:'\n        '--port[Server port]:port:'\n        '--gpu-experts[Number of GPU experts]:count:'\n        '--cpu-threads[Number of CPU threads]:count:'\n        '--tensor-parallel-size[Tensor parallel size]:size:'\n        '--kt-method[KT method]:method:(AMXINT4 FP8 RAWINT4)'\n        '--attention-backend[Attention backend]:backend:(triton flashinfer)'\n        '--max-total-tokens[Maximum total tokens]:tokens:'\n        '--dry-run[Show command without executing]'\n        '--help[Show help message]'\n    )\n\n    local -a chat_opts\n    chat_opts=(\n        '--host[Server host]:host:'\n        '--port[Server port]:port:'\n        '--model[Model name]:model:'\n        '--temperature[Sampling temperature]:temp:'\n        '--max-tokens[Maximum tokens]:tokens:'\n        '--system[System prompt]:prompt:'\n        '--save-history[Save conversation history]'\n        '--no-save-history[Do not save history]'\n        '--history-file[History file path]:path:_files'\n        '--stream[Enable streaming output]'\n        '--no-stream[Disable streaming output]'\n        '--help[Show help message]'\n    )\n\n    local -a model_cmds\n    model_cmds=(\n        'download:Download a model from HuggingFace'\n        'list:List available models'\n        'path-list:List all model storage paths'\n 
       'path-add:Add a new model storage path'\n        'path-remove:Remove a model storage path'\n        'search:Search for models in the registry'\n    )\n\n    local -a config_cmds\n    config_cmds=(\n        'show:Show all configuration'\n        'get:Get configuration value'\n        'set:Set configuration value'\n        'reset:Reset to defaults'\n        'path:Show configuration file path'\n        'init:Re-run first-time setup wizard'\n    )\n\n    local -a sft_cmds\n    sft_cmds=(\n        'train:Train model'\n        'chat:Chat with model'\n        'export:Export model'\n    )\n\n    _arguments -C \\\n        '1: :->command' \\\n        '*::arg:->args'\n\n    case $state in\n        command)\n            _describe 'kt commands' commands\n            _arguments \\\n                '--help[Show help message]' \\\n                '--version[Show version]'\n            ;;\n        args)\n            case $words[1] in\n                run)\n                    _arguments $run_opts \\\n                        '1:model:'\n                    ;;\n                chat)\n                    _arguments $chat_opts\n                    ;;\n                quant)\n                    _arguments \\\n                        '--method[Quantization method]:method:' \\\n                        '--output[Output directory]:path:_files -/' \\\n                        '--help[Show help message]' \\\n                        '1:model:_files -/'\n                    ;;\n                bench|microbench)\n                    _arguments \\\n                        '--model[Model name or path]:model:' \\\n                        '--config[Config file path]:path:_files' \\\n                        '--help[Show help message]'\n                    ;;\n                doctor)\n                    _arguments \\\n                        '--verbose[Verbose output]' \\\n                        '--help[Show help message]'\n                    ;;\n                model)\n                    
_arguments \\\n                        '1: :->model_cmd' \\\n                        '*::arg:->model_args'\n\n                    case $state in\n                        model_cmd)\n                            _describe 'model commands' model_cmds\n                            ;;\n                    esac\n                    ;;\n                config)\n                    _arguments \\\n                        '1: :->config_cmd' \\\n                        '*::arg:->config_args'\n\n                    case $state in\n                        config_cmd)\n                            _describe 'config commands' config_cmds\n                            ;;\n                    esac\n                    ;;\n                sft)\n                    _arguments \\\n                        '1: :->sft_cmd' \\\n                        '*::arg:->sft_args'\n\n                    case $state in\n                        sft_cmd)\n                            _describe 'sft commands' sft_cmds\n                            ;;\n                    esac\n                    ;;\n            esac\n            ;;\n    esac\n}\n\n_kt \"$@\"\n"
  },
  {
    "path": "kt-kernel/python/cli/completions/kt-completion.bash",
    "content": "#!/bin/bash\n# Bash completion for kt command\n# This is a static completion script that doesn't require Python startup\n\n_kt_completion() {\n    local cur prev opts\n    COMPREPLY=()\n    cur=\"${COMP_WORDS[COMP_CWORD]}\"\n    prev=\"${COMP_WORDS[COMP_CWORD-1]}\"\n\n    # Main commands\n    local commands=\"version run chat quant edit bench microbench doctor model config sft\"\n\n    # Global options\n    local global_opts=\"--help --version\"\n\n    # Handle subcommands\n    case \"${COMP_CWORD}\" in\n        1)\n            # First argument: suggest commands and global options\n            COMPREPLY=( $(compgen -W \"${commands} ${global_opts}\" -- ${cur}) )\n            return 0\n            ;;\n        *)\n            # Handle specific command options\n            case \"${COMP_WORDS[1]}\" in\n                run)\n                    local run_opts=\"--host --port --gpu-experts --cpu-threads --tensor-parallel-size --kt-method --attention-backend --max-total-tokens --dry-run --help\"\n                    COMPREPLY=( $(compgen -W \"${run_opts}\" -- ${cur}) )\n                    ;;\n                chat)\n                    local chat_opts=\"--host --port --model --temperature --max-tokens --system --save-history --no-save-history --history-file --stream --no-stream --help\"\n                    COMPREPLY=( $(compgen -W \"${chat_opts}\" -- ${cur}) )\n                    ;;\n                quant)\n                    local quant_opts=\"--method --output --help\"\n                    COMPREPLY=( $(compgen -W \"${quant_opts}\" -- ${cur}) )\n                    ;;\n                edit)\n                    local edit_opts=\"--help\"\n                    COMPREPLY=( $(compgen -W \"${edit_opts}\" -- ${cur}) )\n                    ;;\n                bench|microbench)\n                    local bench_opts=\"--model --config --help\"\n                    COMPREPLY=( $(compgen -W \"${bench_opts}\" -- ${cur}) )\n                    ;;\n                
doctor)\n                    local doctor_opts=\"--verbose --help\"\n                    COMPREPLY=( $(compgen -W \"${doctor_opts}\" -- ${cur}) )\n                    ;;\n                model)\n                    local model_cmds=\"download list path-list path-add path-remove search\"\n                    local model_opts=\"--help\"\n                    COMPREPLY=( $(compgen -W \"${model_cmds} ${model_opts}\" -- ${cur}) )\n                    ;;\n                config)\n                    local config_cmds=\"show get set reset path init model-path-list model-path-add model-path-remove\"\n                    local config_opts=\"--help\"\n                    COMPREPLY=( $(compgen -W \"${config_cmds} ${config_opts}\" -- ${cur}) )\n                    ;;\n                sft)\n                    local sft_cmds=\"train chat export\"\n                    local sft_opts=\"--help\"\n                    COMPREPLY=( $(compgen -W \"${sft_cmds} ${sft_opts}\" -- ${cur}) )\n                    ;;\n                version)\n                    COMPREPLY=( $(compgen -W \"--help\" -- ${cur}) )\n                    ;;\n                *)\n                    COMPREPLY=()\n                    ;;\n            esac\n            ;;\n    esac\n}\n\ncomplete -F _kt_completion kt\n"
  },
  {
    "path": "kt-kernel/python/cli/completions/kt.fish",
    "content": "# Fish completion for kt command\n# This is a static completion script that doesn't require Python startup\n\n# Main commands\ncomplete -c kt -f -n \"__fish_use_subcommand\" -a \"version\" -d \"Show version information\"\ncomplete -c kt -f -n \"__fish_use_subcommand\" -a \"run\" -d \"Start model inference server\"\ncomplete -c kt -f -n \"__fish_use_subcommand\" -a \"chat\" -d \"Interactive chat with running model\"\ncomplete -c kt -f -n \"__fish_use_subcommand\" -a \"quant\" -d \"Quantize model weights\"\ncomplete -c kt -f -n \"__fish_use_subcommand\" -a \"bench\" -d \"Run full benchmark\"\ncomplete -c kt -f -n \"__fish_use_subcommand\" -a \"microbench\" -d \"Run micro-benchmark\"\ncomplete -c kt -f -n \"__fish_use_subcommand\" -a \"doctor\" -d \"Diagnose environment issues\"\ncomplete -c kt -f -n \"__fish_use_subcommand\" -a \"model\" -d \"Manage models and storage paths\"\ncomplete -c kt -f -n \"__fish_use_subcommand\" -a \"config\" -d \"Manage configuration\"\ncomplete -c kt -f -n \"__fish_use_subcommand\" -a \"sft\" -d \"Fine-tuning with LlamaFactory\"\n\n# Global options\ncomplete -c kt -l help -d \"Show help message\"\ncomplete -c kt -l version -d \"Show version\"\n\n# Run command options\ncomplete -c kt -f -n \"__fish_seen_subcommand_from run\" -l host -d \"Server host\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from run\" -l port -d \"Server port\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from run\" -l gpu-experts -d \"Number of GPU experts\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from run\" -l cpu-threads -d \"Number of CPU threads\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from run\" -l tensor-parallel-size -d \"Tensor parallel size\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from run\" -l kt-method -d \"KT method\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from run\" -l attention-backend -d \"Attention backend\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from run\" -l max-total-tokens -d \"Maximum 
total tokens\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from run\" -l dry-run -d \"Show command without executing\"\n\n# Chat command options\ncomplete -c kt -f -n \"__fish_seen_subcommand_from chat\" -l host -d \"Server host\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from chat\" -l port -d \"Server port\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from chat\" -l model -d \"Model name\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from chat\" -l temperature -d \"Sampling temperature\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from chat\" -l max-tokens -d \"Maximum tokens\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from chat\" -l system -d \"System prompt\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from chat\" -l save-history -d \"Save conversation history\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from chat\" -l no-save-history -d \"Do not save history\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from chat\" -l history-file -d \"History file path\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from chat\" -l stream -d \"Enable streaming output\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from chat\" -l no-stream -d \"Disable streaming output\"\n\n# Quant command options\ncomplete -c kt -f -n \"__fish_seen_subcommand_from quant\" -l method -d \"Quantization method\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from quant\" -l output -d \"Output directory\"\n\n# Bench command options\ncomplete -c kt -f -n \"__fish_seen_subcommand_from bench microbench\" -l model -d \"Model name or path\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from bench microbench\" -l config -d \"Config file path\"\n\n# Doctor command options\ncomplete -c kt -f -n \"__fish_seen_subcommand_from doctor\" -l verbose -d \"Verbose output\"\n\n# Model subcommands\ncomplete -c kt -f -n \"__fish_seen_subcommand_from model; and not __fish_seen_subcommand_from download list path-list path-add path-remove search\" -a \"download\" -d \"Download a model 
from HuggingFace\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from model; and not __fish_seen_subcommand_from download list path-list path-add path-remove search\" -a \"list\" -d \"List available models\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from model; and not __fish_seen_subcommand_from download list path-list path-add path-remove search\" -a \"path-list\" -d \"List all model storage paths\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from model; and not __fish_seen_subcommand_from download list path-list path-add path-remove search\" -a \"path-add\" -d \"Add a new model storage path\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from model; and not __fish_seen_subcommand_from download list path-list path-add path-remove search\" -a \"path-remove\" -d \"Remove a model storage path\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from model; and not __fish_seen_subcommand_from download list path-list path-add path-remove search\" -a \"search\" -d \"Search for models in the registry\"\n\n# Config subcommands\ncomplete -c kt -f -n \"__fish_seen_subcommand_from config; and not __fish_seen_subcommand_from show get set reset path init\" -a \"show\" -d \"Show all configuration\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from config; and not __fish_seen_subcommand_from show get set reset path init\" -a \"get\" -d \"Get configuration value\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from config; and not __fish_seen_subcommand_from show get set reset path init\" -a \"set\" -d \"Set configuration value\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from config; and not __fish_seen_subcommand_from show get set reset path init\" -a \"reset\" -d \"Reset to defaults\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from config; and not __fish_seen_subcommand_from show get set reset path init\" -a \"path\" -d \"Show configuration file path\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from config; and not __fish_seen_subcommand_from show get set 
reset path init\" -a \"init\" -d \"Re-run first-time setup wizard\"\n\n# SFT subcommands\ncomplete -c kt -f -n \"__fish_seen_subcommand_from sft; and not __fish_seen_subcommand_from train chat export\" -a \"train\" -d \"Train model\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from sft; and not __fish_seen_subcommand_from train chat export\" -a \"chat\" -d \"Chat with model\"\ncomplete -c kt -f -n \"__fish_seen_subcommand_from sft; and not __fish_seen_subcommand_from train chat export\" -a \"export\" -d \"Export model\"\n"
  },
  {
    "path": "kt-kernel/python/cli/config/__init__.py",
    "content": "\"\"\"\nConfiguration management for kt-cli.\n\"\"\"\n\nfrom kt_kernel.cli.config.settings import Settings, get_settings\n\n__all__ = [\"Settings\", \"get_settings\"]\n"
  },
  {
    "path": "kt-kernel/python/cli/config/settings.py",
    "content": "\"\"\"\nConfiguration management for kt-cli.\n\nHandles reading and writing configuration from ~/.ktransformers/config.yaml\n\"\"\"\n\nimport os\nfrom pathlib import Path\nfrom typing import Any, Optional\n\nimport yaml\n\n# Default configuration directory\nDEFAULT_CONFIG_DIR = Path.home() / \".ktransformers\"\nDEFAULT_CONFIG_FILE = DEFAULT_CONFIG_DIR / \"config.yaml\"\nDEFAULT_MODELS_DIR = DEFAULT_CONFIG_DIR / \"models\"\nDEFAULT_CACHE_DIR = DEFAULT_CONFIG_DIR / \"cache\"\n\n# Default configuration values\nDEFAULT_CONFIG = {\n    \"general\": {\n        \"language\": \"auto\",  # auto, en, zh\n        \"color\": True,\n        \"verbose\": False,\n    },\n    \"paths\": {\n        \"models\": str(DEFAULT_MODELS_DIR),\n        \"cache\": str(DEFAULT_CACHE_DIR),\n        \"weights\": \"\",  # Custom quantized weights path\n    },\n    \"server\": {\n        \"host\": \"0.0.0.0\",\n        \"port\": 30000,\n    },\n    \"inference\": {\n        # Inference parameters are model-specific and should not have defaults\n        # They will be auto-detected or use model-specific optimizations\n        # Environment variables (general optimizations)\n        \"env\": {\n            \"PYTORCH_ALLOC_CONF\": \"expandable_segments:True\",\n            \"SGLANG_ENABLE_JIT_DEEPGEMM\": \"0\",\n        },\n    },\n    \"download\": {\n        \"mirror\": \"\",  # HuggingFace mirror URL\n        \"resume\": True,\n        \"verify\": True,\n    },\n    \"advanced\": {\n        # Environment variables to set when running\n        \"env\": {},\n        # Extra arguments to pass to sglang\n        \"sglang_args\": [],\n        # Extra arguments to pass to llamafactory\n        \"llamafactory_args\": [],\n    },\n    \"dependencies\": {\n        # SGLang installation source configuration\n        \"sglang\": {\n            \"source\": \"github\",  # \"pypi\" or \"github\"\n            \"repo\": \"https://github.com/kvcache-ai/sglang\",\n            \"branch\": 
\"main\",\n        },\n    },\n}\n\n\nclass Settings:\n    \"\"\"Configuration manager for kt-cli.\"\"\"\n\n    def __init__(self, config_path: Optional[Path] = None):\n        \"\"\"Initialize settings manager.\n\n        Args:\n            config_path: Path to config file. Defaults to ~/.ktransformers/config.yaml\n        \"\"\"\n        self.config_path = config_path or DEFAULT_CONFIG_FILE\n        self.config_dir = self.config_path.parent\n        self._config: dict[str, Any] = {}\n        self._load()\n\n    def _ensure_dirs(self) -> None:\n        \"\"\"Ensure configuration directories exist.\"\"\"\n        self.config_dir.mkdir(parents=True, exist_ok=True)\n\n        # Ensure all model paths exist\n        model_paths = self.get_model_paths()\n        for path in model_paths:\n            path.mkdir(parents=True, exist_ok=True)\n\n        Path(self.get(\"paths.cache\", DEFAULT_CACHE_DIR)).mkdir(parents=True, exist_ok=True)\n\n    def _load(self) -> None:\n        \"\"\"Load configuration from file.\"\"\"\n        self._config = self._deep_copy(DEFAULT_CONFIG)\n\n        if self.config_path.exists():\n            try:\n                with open(self.config_path, \"r\", encoding=\"utf-8\") as f:\n                    user_config = yaml.safe_load(f) or {}\n                self._deep_merge(self._config, user_config)\n            except (yaml.YAMLError, OSError) as e:\n                # Log warning but continue with defaults\n                print(f\"Warning: Failed to load config: {e}\")\n\n        self._ensure_dirs()\n\n    def _save(self) -> None:\n        \"\"\"Save configuration to file.\"\"\"\n        self._ensure_dirs()\n        try:\n            with open(self.config_path, \"w\", encoding=\"utf-8\") as f:\n                yaml.dump(self._config, f, default_flow_style=False, allow_unicode=True)\n        except OSError as e:\n            raise RuntimeError(f\"Failed to save config: {e}\")\n\n    def _deep_copy(self, obj: Any) -> Any:\n        \"\"\"Create a 
deep copy of a nested dict.\"\"\"\n        if isinstance(obj, dict):\n            return {k: self._deep_copy(v) for k, v in obj.items()}\n        if isinstance(obj, list):\n            return [self._deep_copy(item) for item in obj]\n        return obj\n\n    def _deep_merge(self, base: dict, override: dict) -> None:\n        \"\"\"Deep merge override into base.\"\"\"\n        for key, value in override.items():\n            if key in base and isinstance(base[key], dict) and isinstance(value, dict):\n                self._deep_merge(base[key], value)\n            else:\n                base[key] = value\n\n    def get(self, key: str, default: Any = None) -> Any:\n        \"\"\"Get a configuration value by dot-separated key.\n\n        Args:\n            key: Dot-separated key path (e.g., \"server.port\")\n            default: Default value if key not found\n\n        Returns:\n            Configuration value or default\n        \"\"\"\n        parts = key.split(\".\")\n        value = self._config\n\n        for part in parts:\n            if isinstance(value, dict) and part in value:\n                value = value[part]\n            else:\n                return default\n\n        return value\n\n    def set(self, key: str, value: Any) -> None:\n        \"\"\"Set a configuration value by dot-separated key.\n\n        Args:\n            key: Dot-separated key path (e.g., \"server.port\")\n            value: Value to set\n        \"\"\"\n        parts = key.split(\".\")\n        config = self._config\n\n        # Navigate to parent\n        for part in parts[:-1]:\n            if part not in config:\n                config[part] = {}\n            config = config[part]\n\n        # Set value\n        config[parts[-1]] = value\n        self._save()\n\n    def delete(self, key: str) -> bool:\n        \"\"\"Delete a configuration value.\n\n        Args:\n            key: Dot-separated key path\n\n        Returns:\n            True if key was deleted, False if not found\n 
       \"\"\"\n        parts = key.split(\".\")\n        config = self._config\n\n        # Navigate to parent\n        for part in parts[:-1]:\n            if part not in config:\n                return False\n            config = config[part]\n\n        # Delete key\n        if parts[-1] in config:\n            del config[parts[-1]]\n            self._save()\n            return True\n        return False\n\n    def reset(self) -> None:\n        \"\"\"Reset configuration to defaults.\"\"\"\n        self._config = self._deep_copy(DEFAULT_CONFIG)\n        self._save()\n\n    def get_all(self) -> dict[str, Any]:\n        \"\"\"Get all configuration values.\"\"\"\n        return self._deep_copy(self._config)\n\n    def get_env_vars(self) -> dict[str, str]:\n        \"\"\"Get environment variables to set.\"\"\"\n        env_vars = {}\n\n        # Get from advanced.env\n        advanced_env = self.get(\"advanced.env\", {})\n        if isinstance(advanced_env, dict):\n            env_vars.update({k: str(v) for k, v in advanced_env.items()})\n\n        return env_vars\n\n    @property\n    def models_dir(self) -> Path:\n        \"\"\"Get the primary models directory path (for backward compatibility).\"\"\"\n        paths = self.get_model_paths()\n        return paths[0] if paths else Path(DEFAULT_MODELS_DIR)\n\n    def get_model_paths(self) -> list[Path]:\n        \"\"\"Get all model directory paths.\n\n        Returns a list of Path objects. 
Supports both:\n        - Single path: paths.models = \"/path/to/models\"\n        - Multiple paths: paths.models = [\"/path/1\", \"/path/2\"]\n        \"\"\"\n        models_config = self.get(\"paths.models\", DEFAULT_MODELS_DIR)\n\n        # Handle both string and list\n        if isinstance(models_config, str):\n            return [Path(models_config)]\n        elif isinstance(models_config, list):\n            return [Path(p) for p in models_config]\n        else:\n            return [Path(DEFAULT_MODELS_DIR)]\n\n    def add_model_path(self, path: str) -> None:\n        \"\"\"Add a new model path to the configuration.\"\"\"\n        models_config = self.get(\"paths.models\", DEFAULT_MODELS_DIR)\n\n        # Convert to list if it's a string\n        if isinstance(models_config, str):\n            paths = [models_config]\n        elif isinstance(models_config, list):\n            paths = list(models_config)\n        else:\n            paths = []\n\n        # Add new path if not already present\n        if path not in paths:\n            paths.append(path)\n            self.set(\"paths.models\", paths)\n\n    def remove_model_path(self, path: str) -> bool:\n        \"\"\"Remove a model path from the configuration.\n\n        Returns True if path was removed, False if not found.\n        \"\"\"\n        models_config = self.get(\"paths.models\", DEFAULT_MODELS_DIR)\n\n        if isinstance(models_config, str):\n            # Can't remove if it's a single string\n            if models_config == path:\n                # Don't remove the last path\n                return False\n            return False\n        elif isinstance(models_config, list):\n            if path in models_config:\n                paths = list(models_config)\n                paths.remove(path)\n                # Don't allow removing all paths\n                if not paths:\n                    return False\n                self.set(\"paths.models\", paths if len(paths) > 1 else paths[0])\n       
         return True\n\n        return False\n\n    @property\n    def cache_dir(self) -> Path:\n        \"\"\"Get the cache directory path.\"\"\"\n        return Path(self.get(\"paths.cache\", DEFAULT_CACHE_DIR))\n\n    @property\n    def weights_dir(self) -> Optional[Path]:\n        \"\"\"Get the custom weights directory path.\"\"\"\n        weights = self.get(\"paths.weights\", \"\")\n        return Path(weights) if weights else None\n\n\n# Global settings instance\n_settings: Optional[Settings] = None\n\n\ndef get_settings() -> Settings:\n    \"\"\"Get the global settings instance.\"\"\"\n    global _settings\n    if _settings is None:\n        _settings = Settings()\n    return _settings\n\n\ndef reset_settings() -> None:\n    \"\"\"Reset the global settings instance.\"\"\"\n    global _settings\n    _settings = None\n"
  },
  {
    "path": "kt-kernel/python/cli/i18n.py",
    "content": "\"\"\"\nInternationalization (i18n) module for kt-cli.\n\nSupports English and Chinese languages, with automatic detection based on\nsystem locale or KT_LANG environment variable.\n\"\"\"\n\nimport os\nfrom typing import Any\n\n# Message definitions for all supported languages\nMESSAGES: dict[str, dict[str, str]] = {\n    \"en\": {\n        # General\n        \"welcome\": \"Welcome to KTransformers!\",\n        \"goodbye\": \"Goodbye!\",\n        \"error\": \"Error\",\n        \"warning\": \"Warning\",\n        \"success\": \"Success\",\n        \"info\": \"Info\",\n        \"yes\": \"Yes\",\n        \"no\": \"No\",\n        \"cancel\": \"Cancel\",\n        \"confirm\": \"Confirm\",\n        \"done\": \"Done\",\n        \"failed\": \"Failed\",\n        \"skip\": \"Skip\",\n        \"back\": \"Back\",\n        \"next\": \"Next\",\n        \"retry\": \"Retry\",\n        \"abort\": \"Abort\",\n        # Version command\n        \"version_info\": \"KTransformers CLI\",\n        \"version_python\": \"Python\",\n        \"version_platform\": \"Platform\",\n        \"version_cuda\": \"CUDA\",\n        \"version_cuda_not_found\": \"Not found\",\n        \"version_kt_kernel\": \"kt-kernel\",\n        \"version_ktransformers\": \"ktransformers\",\n        \"version_sglang\": \"sglang-kt\",\n        \"version_llamafactory\": \"llamafactory\",\n        \"version_not_installed\": \"Not installed\",\n        # Install command\n        \"install_detecting_env\": \"Detecting environment managers...\",\n        \"install_found\": \"Found {name} (version {version})\",\n        \"install_not_found\": \"Not found: {name}\",\n        \"install_checking_env\": \"Checking existing environments...\",\n        \"install_env_exists\": \"Found existing 'kt' environment\",\n        \"install_env_not_exists\": \"No 'kt' environment found\",\n        \"install_no_env_manager\": \"No virtual environment manager detected\",\n        \"install_select_method\": \"Please select 
installation method:\",\n        \"install_method_conda\": \"Create new conda environment 'kt' (Recommended)\",\n        \"install_method_venv\": \"Create new venv environment\",\n        \"install_method_uv\": \"Create new uv environment (Fast)\",\n        \"install_method_docker\": \"Use Docker container\",\n        \"install_method_system\": \"Install to system Python (Not recommended)\",\n        \"install_select_mode\": \"Please select installation mode:\",\n        \"install_mode_inference\": \"Inference - Install kt-kernel + SGLang\",\n        \"install_mode_sft\": \"Training - Install kt-sft + LlamaFactory\",\n        \"install_mode_full\": \"Full - Install all components\",\n        \"install_creating_env\": \"Creating {type} environment '{name}'...\",\n        \"install_env_created\": \"Environment created successfully\",\n        \"install_installing_deps\": \"Installing dependencies...\",\n        \"install_checking_deps\": \"Checking dependency versions...\",\n        \"install_dep_ok\": \"OK\",\n        \"install_dep_outdated\": \"Needs update\",\n        \"install_dep_missing\": \"Missing\",\n        \"install_installing_pytorch\": \"Installing PyTorch...\",\n        \"install_installing_from_requirements\": \"Installing from requirements file...\",\n        \"install_deps_outdated\": \"Found {count} package(s) that need updating. 
Continue?\",\n        \"install_updating\": \"Updating packages...\",\n        \"install_complete\": \"Installation complete!\",\n        \"install_activate_hint\": \"Activate environment: {command}\",\n        \"install_start_hint\": \"Get started: kt run --help\",\n        \"install_docker_pulling\": \"Pulling Docker image...\",\n        \"install_docker_complete\": \"Docker image ready!\",\n        \"install_docker_run_hint\": \"Run with: docker run --gpus all -p 30000:30000 {image} kt run {model}\",\n        \"install_in_venv\": \"Running in virtual environment: {name}\",\n        \"install_continue_without_venv\": \"Continue installing to system Python?\",\n        \"install_already_installed\": \"All dependencies are already installed!\",\n        \"install_confirm\": \"Install {count} package(s)?\",\n        # Install - System dependencies\n        \"install_checking_system_deps\": \"Checking system dependencies...\",\n        \"install_dep_name\": \"Dependency\",\n        \"install_dep_status\": \"Status\",\n        \"install_deps_all_installed\": \"All system dependencies are installed\",\n        \"install_deps_install_prompt\": \"Install missing dependencies?\",\n        \"install_installing_system_deps\": \"Installing system dependencies...\",\n        \"install_installing_dep\": \"Installing {name}\",\n        \"install_dep_no_install_cmd\": \"No install command available for {name} on {os}\",\n        \"install_dep_install_failed\": \"Failed to install {name}\",\n        \"install_deps_skipped\": \"Skipping dependency installation\",\n        \"install_deps_failed\": \"Failed to install system dependencies\",\n        # Install - CPU detection\n        \"install_auto_detect_cpu\": \"Auto-detecting CPU capabilities...\",\n        \"install_cpu_features\": \"Detected CPU features: {features}\",\n        \"install_cpu_no_features\": \"No advanced CPU features detected\",\n        # Install - Build configuration\n        \"install_build_config\": \"Build 
Configuration:\",\n        \"install_native_warning\": \"Note: Binary optimized for THIS CPU only (not portable)\",\n        \"install_building_from_source\": \"Building kt-kernel from source...\",\n        \"install_build_failed\": \"Build failed\",\n        \"install_build_success\": \"Build completed successfully\",\n        # Install - Verification\n        \"install_verifying\": \"Verifying installation...\",\n        \"install_verify_success\": \"kt-kernel {version} ({variant} variant) installed successfully\",\n        \"install_verify_failed\": \"Verification failed: {error}\",\n        # Install - Docker\n        \"install_docker_guide_title\": \"Docker Installation\",\n        \"install_docker_guide_desc\": \"For Docker installation, please refer to the official guide:\",\n        # Config command\n        \"config_show_title\": \"Current Configuration\",\n        \"config_set_success\": \"Configuration updated: {key} = {value}\",\n        \"config_get_value\": \"{key} = {value}\",\n        \"config_get_not_found\": \"Configuration key '{key}' not found\",\n        \"config_reset_confirm\": \"This will reset all configurations to default. 
Continue?\",\n        \"config_reset_success\": \"Configuration reset to default\",\n        \"config_file_location\": \"Configuration file: {path}\",\n        # Doctor command\n        \"doctor_title\": \"KTransformers Environment Diagnostics\",\n        \"doctor_checking\": \"Running diagnostics...\",\n        \"doctor_check_python\": \"Python version\",\n        \"doctor_check_cuda\": \"CUDA availability\",\n        \"doctor_check_gpu\": \"GPU detection\",\n        \"doctor_check_cpu\": \"CPU\",\n        \"doctor_check_cpu_isa\": \"CPU Instructions\",\n        \"doctor_check_numa\": \"NUMA Topology\",\n        \"doctor_check_memory\": \"System memory\",\n        \"doctor_check_disk\": \"Disk space\",\n        \"doctor_check_packages\": \"Required packages\",\n        \"doctor_check_env\": \"Environment variables\",\n        \"doctor_status_ok\": \"OK\",\n        \"doctor_status_warning\": \"Warning\",\n        \"doctor_status_error\": \"Error\",\n        \"doctor_gpu_found\": \"Found {count} GPU(s): {names}\",\n        \"doctor_gpu_not_found\": \"No GPU detected\",\n        \"doctor_cpu_info\": \"{name} ({cores} cores / {threads} threads)\",\n        \"doctor_cpu_isa_info\": \"{isa_list}\",\n        \"doctor_cpu_isa_missing\": \"Missing recommended: {missing}\",\n        \"doctor_numa_info\": \"{nodes} node(s)\",\n        \"doctor_numa_detail\": \"{node}: CPUs {cpus}\",\n        \"doctor_memory_info\": \"{available} available / {total} total\",\n        \"doctor_memory_freq\": \"{available} available / {total} total ({freq}MHz {type})\",\n        \"doctor_disk_info\": \"{available} available at {path}\",\n        \"doctor_all_ok\": \"All checks passed! Your environment is ready.\",\n        \"doctor_has_issues\": \"Some issues were found. 
Please review the warnings/errors above.\",\n        # Run command\n        \"run_detecting_hardware\": \"Detecting hardware configuration...\",\n        \"run_gpu_info\": \"GPU: {name} ({vram}GB VRAM)\",\n        \"run_cpu_info\": \"CPU: {name} ({cores} cores, {numa} NUMA nodes)\",\n        \"run_ram_info\": \"RAM: {total}GB\",\n        \"run_checking_model\": \"Checking model status...\",\n        \"run_model_path\": \"Model path: {path}\",\n        \"run_weights_not_found\": \"Quantized weights not found\",\n        \"run_quant_prompt\": \"Quantize model now? (This may take a while)\",\n        \"run_quantizing\": \"Quantizing model...\",\n        \"run_starting_server\": \"Starting server...\",\n        \"run_server_mode\": \"Mode: SGLang + kt-kernel\",\n        \"run_server_port\": \"Port: {port}\",\n        \"run_gpu_experts\": \"GPU experts: {count}/layer\",\n        \"run_cpu_threads\": \"CPU threads: {count}\",\n        \"run_server_started\": \"Server started!\",\n        \"run_api_url\": \"API URL: http://{host}:{port}\",\n        \"run_docs_url\": \"Docs URL: http://{host}:{port}/docs\",\n        \"run_stop_hint\": \"Press Ctrl+C to stop the server\",\n        \"run_model_not_found\": \"Model '{name}' not found. Run 'kt download' first.\",\n        \"run_multiple_matches\": \"Multiple models found. 
Please select:\",\n        \"run_select_model\": \"Select model\",\n        \"run_select_model_title\": \"Select a model to run\",\n        \"run_select_model_prompt\": \"Enter number\",\n        \"run_local_models\": \"Local Models (Downloaded)\",\n        \"run_registered_models\": \"Registered Models\",\n        # Download command\n        \"download_list_title\": \"Available Models\",\n        \"download_searching\": \"Searching for model '{name}'...\",\n        \"download_found\": \"Found: {name}\",\n        \"download_multiple_found\": \"Multiple matches found:\",\n        \"download_select\": \"Select model to download:\",\n        \"download_destination\": \"Destination: {path}\",\n        \"download_starting\": \"Starting download...\",\n        \"download_progress\": \"Downloading {name}...\",\n        \"download_complete\": \"Download complete!\",\n        \"download_already_exists\": \"Model already exists at {path}\",\n        \"download_overwrite_prompt\": \"Overwrite existing files?\",\n        # Quant command\n        \"quant_input_path\": \"Input path: {path}\",\n        \"quant_output_path\": \"Output path: {path}\",\n        \"quant_method\": \"Quantization method: {method}\",\n        \"quant_starting\": \"Starting quantization...\",\n        \"quant_progress\": \"Quantizing...\",\n        \"quant_complete\": \"Quantization complete!\",\n        \"quant_input_not_found\": \"Input model not found at {path}\",\n        \"quant_cpu_threads\": \"CPU threads: {threads}\",\n        \"quant_numa_nodes\": \"NUMA nodes: {nodes}\",\n        \"quant_time_warning\": \"Quantization may take 30-60 minutes depending on model size.\",\n        \"quant_disk_analysis\": \"Disk Space Analysis:\",\n        \"quant_source_size\": \"Source model size:\",\n        \"quant_estimated_size\": \"Estimated output size:\",\n        \"quant_available_space\": \"Available space:\",\n        \"quant_insufficient_space\": \"WARNING: Insufficient disk space!\",\n        
\"quant_required_space\": \"Required space (with 20% buffer):\",\n        \"quant_shortage\": \"Shortage:\",\n        \"quant_may_fail\": \"Quantization may fail or produce incomplete files.\",\n        \"quant_continue_anyway\": \"Continue anyway?\",\n        \"quant_settings\": \"Quantization Settings:\",\n        \"quant_registered\": \"Quantized model registered: {name}\",\n        \"quant_view_with\": \"View with:\",\n        \"quant_use_with\": \"Use with:\",\n        \"quant_register_failed\": \"Failed to auto-register model: {error}\",\n        \"quant_output_exists\": \"Output path already exists: {path}\",\n        \"quant_using_unique\": \"Using unique name: {path}\",\n        # Interactive quant\n        \"quant_interactive_title\": \"Interactive Quantization Configuration\",\n        \"quant_new_model_notice\": \"⚠ Note: Some newer models cannot be quantized yet (conversion script not adapted). Recommended to use the original precision for inference (no weight conversion needed).\",\n        \"quant_no_moe_models\": \"No MoE models found for quantization.\",\n        \"quant_only_moe\": \"Only MoE models (e.g., DeepSeek-V3) can be quantized to AMX format.\",\n        \"quant_add_models\": \"Add models with: {command}\",\n        \"quant_moe_available\": \"MoE Models Available for Quantization:\",\n        \"quant_select_model\": \"Select model to quantize\",\n        \"quant_invalid_choice\": \"Invalid choice\",\n        \"quant_step2_method\": \"Step 2: Quantization Method\",\n        \"quant_method_label\": \"Quantization Method:\",\n        \"quant_int4_desc\": \"INT4\",\n        \"quant_int8_desc\": \"INT8\",\n        \"quant_select_method\": \"Select quantization method\",\n        \"quant_input_type_label\": \"Input Weight Type:\",\n        \"quant_fp8_desc\": \"FP8 (for 8-bit float weights)\",\n        \"quant_fp16_desc\": \"FP16 (for 16-bit float weights)\",\n        \"quant_bf16_desc\": \"BF16 (for Brain Float 16 weights)\",\n        
\"quant_select_input_type\": \"Select input type\",\n        \"quant_step3_cpu\": \"Step 3: CPU Configuration\",\n        \"quant_cpu_threads_prompt\": \"CPU Threads (1 to {max})\",\n        \"quant_numa_nodes_prompt\": \"NUMA Nodes (1 to {max})\",\n        \"quant_use_gpu_label\": \"Use GPU for conversion?\",\n        \"quant_gpu_speedup\": \"GPU can significantly speed up the quantization process\",\n        \"quant_enable_gpu\": \"Enable GPU acceleration?\",\n        \"quant_step4_output\": \"Step 4: Output Path\",\n        \"quant_default_path\": \"Default:\",\n        \"quant_use_default\": \"Use default output path?\",\n        \"quant_custom_path\": \"Enter custom output path\",\n        \"quant_output_exists_warn\": \"⚠ Output path already exists: {path}\",\n        \"quant_using_unique_name\": \"→ Using unique name: {path}\",\n        \"quant_config_summary\": \"Configuration Summary\",\n        \"quant_summary_model\": \"Model:\",\n        \"quant_summary_method\": \"Method:\",\n        \"quant_summary_input_type\": \"Input Type:\",\n        \"quant_summary_cpu_threads\": \"CPU Threads:\",\n        \"quant_summary_numa\": \"NUMA Nodes:\",\n        \"quant_summary_gpu\": \"Use GPU:\",\n        \"quant_summary_output\": \"Output Path:\",\n        \"quant_start_question\": \"Start quantization?\",\n        \"quant_cancelled\": \"Cancelled\",\n        \"quant_config_complete\": \"Configuration complete\",\n        \"quant_time_elapsed\": \"Time elapsed:\",\n        \"yes\": \"Yes\",\n        \"no\": \"No\",\n        # SFT command\n        \"sft_mode_train\": \"Training mode\",\n        \"sft_mode_chat\": \"Chat mode\",\n        \"sft_mode_export\": \"Export mode\",\n        \"sft_config_path\": \"Config file: {path}\",\n        \"sft_starting\": \"Starting {mode}...\",\n        \"sft_complete\": \"{mode} complete!\",\n        \"sft_config_not_found\": \"Config file not found: {path}\",\n        # Bench command\n        \"bench_starting\": \"Starting 
benchmark...\",\n        \"bench_type\": \"Benchmark type: {type}\",\n        \"bench_complete\": \"Benchmark complete!\",\n        \"bench_results_title\": \"Benchmark Results\",\n        # Common prompts\n        \"prompt_continue\": \"Continue?\",\n        \"prompt_select\": \"Please select:\",\n        \"prompt_enter_value\": \"Enter value:\",\n        \"prompt_confirm_action\": \"Confirm this action?\",\n        # First-run setup - Model path selection\n        \"setup_model_path_title\": \"Model Storage Location\",\n        \"setup_model_path_desc\": \"LLM models are large (50-200GB+). Please select a storage location with sufficient space:\",\n        \"setup_scanning_disks\": \"Scanning available storage locations...\",\n        \"setup_disk_option\": \"{path} ({available} available / {total} total)\",\n        \"setup_disk_option_recommended\": \"{path} ({available} available / {total} total) [Recommended]\",\n        \"setup_custom_path\": \"Enter custom path\",\n        \"setup_enter_custom_path\": \"Enter the path for model storage\",\n        \"setup_path_not_exist\": \"Path does not exist. Create it?\",\n        \"setup_path_no_write\": \"No write permission for this path. Please choose another.\",\n        \"setup_path_low_space\": \"Warning: Less than 100GB available. Large models may not fit.\",\n        \"setup_model_path_set\": \"Model storage path set to: {path}\",\n        \"setup_no_large_disk\": \"No large storage locations found. 
Using default path.\",\n        \"setup_scanning_models\": \"Scanning for existing models...\",\n        \"setup_found_models\": \"Found {count} model(s):\",\n        \"setup_model_info\": \"{name} ({size}, {type})\",\n        \"setup_no_models_found\": \"No existing models found in this location.\",\n        \"setup_location_has_models\": \"{count} model(s) found\",\n        \"setup_installing_completion\": \"Installing shell completion for {shell}...\",\n        \"setup_completion_installed\": \"Shell completion installed! Restart terminal to enable.\",\n        \"setup_completion_failed\": \"Failed to install shell completion. Run 'kt --install-completion' manually.\",\n        # Auto completion\n        \"completion_installed_title\": \"Tab Completion\",\n        \"completion_installed_for\": \"Shell completion installed for {shell}\",\n        \"completion_activate_now\": \"To enable completion in this terminal session, run:\",\n        \"completion_next_session\": \"Completion will be automatically enabled in new terminal sessions.\",\n        # SGLang\n        \"sglang_not_found\": \"SGLang not found\",\n        \"sglang_pypi_warning\": \"SGLang from PyPI may not be compatible with kt-kernel. Use sglang-kt instead: pip install sglang-kt\",\n        \"sglang_pypi_hint\": \"SGLang from PyPI may not be compatible. 
Install the kvcache-ai fork: pip install sglang-kt (or run ./install.sh from ktransformers root)\",\n        \"sglang_install_hint\": \"Install SGLang: pip install sglang-kt (or run ./install.sh from ktransformers root)\",\n        \"sglang_recommend_source\": \"Recommend reinstalling with the kvcache-ai fork: pip uninstall sglang -y && pip install sglang-kt\",\n        \"sglang_kt_kernel_not_supported\": \"SGLang does not support kt-kernel (missing --kt-gpu-prefill-token-threshold parameter)\",\n        \"sglang_checking_kt_kernel_support\": \"Checking SGLang kt-kernel support...\",\n        \"sglang_kt_kernel_supported\": \"SGLang kt-kernel support verified\",\n        # Chat\n        \"chat_proxy_detected\": \"Proxy detected in environment\",\n        \"chat_proxy_confirm\": \"Use proxy for connection?\",\n        \"chat_proxy_disabled\": \"Proxy disabled for this session\",\n        \"chat_openai_required\": \"OpenAI Python SDK is required for chat functionality.\",\n        \"chat_install_hint\": \"Install it with:\",\n        \"chat_title\": \"KTransformers Chat\",\n        \"chat_server\": \"Server\",\n        \"chat_temperature\": \"Temperature\",\n        \"chat_max_tokens\": \"Max tokens\",\n        \"chat_help_hint\": \"Type '/help' for commands, '/quit' to exit\",\n        \"chat_connecting\": \"Connecting to server...\",\n        \"chat_no_models\": \"No models available on server\",\n        \"chat_model_not_found\": \"Model '{model}' not found. Available models: {available}\",\n        \"chat_connected\": \"Connected to model: {model}\",\n        \"chat_connect_failed\": \"Failed to connect to server: {error}\",\n        \"chat_server_not_running\": \"Make sure the model server is running:\",\n        \"chat_user_prompt\": \"You\",\n        \"chat_assistant_prompt\": \"Assistant\",\n        \"chat_generation_error\": \"Error generating response: {error}\",\n        \"chat_interrupted\": \"Chat interrupted. 
Goodbye!\",\n        \"chat_history_saved\": \"History saved to: {path}\",\n        \"chat_goodbye\": \"Goodbye!\",\n        \"chat_help_title\": \"Available Commands:\",\n        \"chat_help_content\": \"/help, /h         - Show this help message\\n/quit, /exit, /q  - Exit chat\\n/clear, /c        - Clear conversation history\\n/history, /hist   - Show conversation history\\n/info, /i         - Show current settings\\n/retry, /r        - Regenerate last response\",\n        \"chat_history_cleared\": \"Conversation history cleared\",\n        \"chat_no_history\": \"No conversation history\",\n        \"chat_history_title\": \"History ({count} messages)\",\n        \"chat_info_title\": \"Current Settings:\",\n        \"chat_info_content\": \"Temperature: {temperature}\\nMax tokens: {max_tokens}\\nMessages: {messages}\",\n        \"chat_retrying\": \"Retrying last response...\",\n        \"chat_no_retry\": \"No previous response to retry\",\n        \"chat_unknown_command\": \"Unknown command: {command}\",\n        \"chat_unknown_hint\": \"Type /help for available commands\",\n        # Run Interactive\n        \"run_int_no_moe_models\": \"No MoE GPU models found.\",\n        \"run_int_add_models\": \"Add models with: kt model scan\",\n        \"run_int_list_all\": \"List all models: kt model list --all\",\n        \"run_int_step1_title\": \"Step 1: Select Model (GPU MoE Models)\",\n        \"run_int_select_model\": \"Select model\",\n        \"run_int_step2_title\": \"Step 2: Select Inference Method\",\n        \"run_int_method_raw\": \"RAW Precision (FP8/FP8_PERCHANNEL/BF16/RAWINT4)\",\n        \"run_int_method_amx\": \"AMX Quantization (INT4/INT8)\",\n        \"run_int_method_gguf\": \"GGUF (Llamafile)\",\n        \"run_int_method_saved\": \"Use Saved Configuration\",\n        \"run_int_select_method\": \"Select inference method\",\n        \"run_int_raw_precision\": \"RAW Precision:\",\n        \"run_int_select_precision\": \"Select precision\",\n        
\"run_int_amx_method\": \"AMX Method:\",\n        \"run_int_select_amx\": \"Select AMX method\",\n        \"run_int_step3_title\": \"Step 3: NUMA and CPU Configuration\",\n        \"run_int_numa_nodes\": \"NUMA Nodes (1-{max})\",\n        \"run_int_cpu_threads\": \"CPU Threads per NUMA (1-{max})\",\n        \"run_int_amx_warning\": \"⚠ Warning: AMX INT4/INT8 requires compatible CPU. Check with: kt doctor\",\n        \"run_int_step4_title\": \"Step 4: GPU Experts Configuration\",\n        \"run_int_gpu_experts\": \"GPU Experts per Layer (0-{max})\",\n        \"run_int_gpu_experts_info\": \"Total experts: {total}, Activated per token: {active}\",\n        \"run_int_step5_title\": \"Step 5: KV Cache Configuration\",\n        \"run_int_kv_cache_size\": \"KV Cache Size (tokens)\",\n        \"run_int_chunk_prefill\": \"Enable Chunk Prefill?\",\n        \"run_int_chunk_size\": \"Chunk Prefill Size (tokens)\",\n        \"run_int_gpu_prefill_threshold\": \"GPU Prefill Threshold (tokens)\",\n        \"run_int_step6_title\": \"Step 6: GPU Selection and Tensor Parallelism\",\n        \"run_int_available_gpus\": \"Available GPUs:\",\n        \"run_int_gpu_id\": \"GPU {id}\",\n        \"run_int_vram_info\": \"{name} ({total:.1f}GB total, {free:.1f}GB free)\",\n        \"run_int_select_gpus\": \"Select GPU IDs (comma-separated)\",\n        \"run_int_invalid_gpu_range\": \"All GPU IDs must be between 0 and {max}\",\n        \"run_int_tp_size\": \"TP Size (must be power of 2: 1,2,4,8...)\",\n        \"run_int_tp_mismatch\": \"TP size must match number of selected GPUs ({count})\",\n        \"run_int_tp_not_power_of_2\": \"TP size must be a power of 2\",\n        \"run_int_mem_fraction\": \"Static Memory Fraction (0.0-1.0)\",\n        \"run_int_using_saved_mem\": \"Using saved memory fraction: {fraction}\",\n        \"run_int_step7_title\": \"Step 7: Parser Configuration (Optional)\",\n        \"run_int_tool_call_parser\": \"Tool Call Parser (press Enter to skip)\",\n        
\"run_int_reasoning_parser\": \"Reasoning Parser (press Enter to skip)\",\n        \"run_int_step8_title\": \"Step 8: Host and Port Configuration\",\n        \"run_int_host\": \"Host\",\n        \"run_int_port\": \"Port\",\n        \"run_int_port_occupied\": \"⚠ Port {port} is already in use\",\n        \"run_int_port_suggestion\": \"Suggested available port: {port}\",\n        \"run_int_use_suggested\": \"Use suggested port?\",\n        \"run_int_saved_configs\": \"Saved Configurations:\",\n        \"run_int_config_name\": \"Configuration {num}\",\n        \"run_int_kt_method\": \"KT Method:\",\n        \"run_int_numa_nodes_label\": \"NUMA Nodes:\",\n        \"run_int_cpu_threads_label\": \"CPU Threads:\",\n        \"run_int_gpu_experts_label\": \"GPU Experts:\",\n        \"run_int_tp_size_label\": \"TP Size:\",\n        \"run_int_mem_fraction_label\": \"Memory Fraction:\",\n        \"run_int_server_label\": \"Server:\",\n        \"run_int_kv_cache_label\": \"KV Cache:\",\n        \"run_int_chunk_prefill_label\": \"Chunk Prefill:\",\n        \"run_int_gpu_prefill_label\": \"GPU Prefill Thr:\",\n        \"run_int_tool_parser_label\": \"Tool Call Parser:\",\n        \"run_int_reasoning_parser_label\": \"Reasoning Parser:\",\n        \"run_int_command_label\": \"Command:\",\n        \"run_int_select_config\": \"Select configuration\",\n        \"run_int_gpu_select_required\": \"Please select {tp} GPUs (TP size from saved config)\",\n        \"run_int_port_check_title\": \"Port Configuration\",\n        \"run_int_port_checking\": \"Checking port {port} availability...\",\n        \"run_int_port_available\": \"Port {port} is available\",\n        \"run_int_saved_config_title\": \"Saved Configuration\",\n        \"run_int_save_config_title\": \"Save Configuration\",\n        \"run_int_save_config_prompt\": \"Save this configuration for future use?\",\n        \"run_int_config_name_prompt\": \"Configuration name\",\n        \"run_int_config_name_default\": \"Config 
{timestamp}\",\n        \"run_int_config_saved\": \"Configuration saved: {name}\",\n        \"run_int_config_summary\": \"Configuration Complete\",\n        \"run_int_model_label\": \"Model:\",\n        \"run_int_selected_gpus_label\": \"Selected GPUs:\",\n        # Model command\n        \"model_supported_title\": \"KTransformers Supported Models\",\n        \"model_column_model\": \"Model\",\n        \"model_column_status\": \"Status\",\n        \"model_column_local_path\": \"Local Path\",\n        \"model_status_local\": \"Local\",\n        \"model_status_not_downloaded\": \"Not downloaded\",\n        \"model_usage_title\": \"Usage\",\n        \"model_usage_download\": \"Download a model:\",\n        \"model_usage_list_local\": \"List local models:\",\n        \"model_usage_search\": \"Search models:\",\n        \"model_storage_paths_title\": \"Model Storage Paths\",\n        \"model_local_models_title\": \"Locally Downloaded Models\",\n        \"model_available_models_title\": \"Available Models\",\n        \"model_no_local_models\": \"No locally downloaded models found\",\n        \"model_download_hint\": \"Download a model with:\",\n        \"model_download_usage_hint\": \"Usage: kt model download <model-name>\",\n        \"model_download_list_hint\": \"Use 'kt model download --list' to see available models.\",\n        \"model_download_hf_hint\": \"Or specify a HuggingFace repo directly: kt model download org/model-name\",\n        \"model_saved_to\": \"Model saved to: {path}\",\n        \"model_start_with\": \"Start with: kt run {name}\",\n        \"model_download_failed\": \"Download failed: {error}\",\n        \"model_hf_cli_not_found\": \"huggingface-cli not found. 
Install with: pip install huggingface-hub\",\n        \"model_path_not_exist\": \"Path does not exist: {path}\",\n        \"model_create_directory\": \"Create directory {path}?\",\n        \"model_created_directory\": \"Created directory: {path}\",\n        \"model_create_dir_failed\": \"Failed to create directory: {error}\",\n        \"model_path_added\": \"Added model path: {path}\",\n        \"model_path_removed\": \"Removed model path: {path}\",\n        \"model_path_not_found\": \"Path not found in configuration or cannot remove last path: {path}\",\n        \"model_search_no_results\": \"No models found matching '{query}'\",\n        \"model_search_results_title\": \"Search Results for '{query}'\",\n        \"model_column_name\": \"Name\",\n        \"model_column_hf_repo\": \"HuggingFace Repo\",\n        \"model_column_aliases\": \"Aliases\",\n        # Model management - new user registry system\n        \"model_no_registered_models\": \"No models registered yet.\",\n        \"model_scan_hint\": \"Scan for models: kt model scan\",\n        \"model_add_hint\": \"Add a model: kt model add /path/to/model\",\n        \"model_registered_models_title\": \"Registered Models\",\n        \"model_column_format\": \"Format\",\n        \"model_column_repo\": \"Repository\",\n        \"model_column_sha256\": \"SHA256\",\n        \"model_non_moe_hidden_hint\": \"Detected {count} non-MoE models, use kt model list --all to show all\",\n        \"model_usage_title\": \"Common Operations:\",\n        \"model_usage_info\": \"View details:\",\n        \"model_usage_edit\": \"Edit model:\",\n        \"model_usage_verify\": \"Verify integrity:\",\n        \"model_usage_quant\": \"Quantize model:\",\n        \"model_usage_run\": \"Run model:\",\n        \"model_usage_scan\": \"Scan for models:\",\n        \"model_usage_add\": \"Add model:\",\n        \"model_usage_verbose\": \"View with file details:\",\n        \"model_no_storage_paths\": \"No storage paths configured.\",\n       
 \"model_add_path_hint\": \"Add a storage path with: kt config set model.storage_paths /path/to/models\",\n        \"model_scanning_paths\": \"Scanning configured storage paths...\",\n        \"model_scanning_progress\": \"Scanning: {path}\",\n        \"model_scan_warnings_title\": \"Warnings\",\n        \"model_scan_no_models_found\": \"No models found in configured paths.\",\n        \"model_scan_check_paths_hint\": \"Check your storage paths: kt config get model.storage_paths\",\n        \"model_scan_min_size_hint\": \"Folders must be ≥{size}GB to be detected as models.\",\n        \"model_scan_found_title\": \"Found {count} new model(s)\",\n        \"model_column_path\": \"Path\",\n        \"model_column_size\": \"Size\",\n        \"model_scan_auto_adding\": \"Auto-adding models...\",\n        \"model_added\": \"Added: {name}\",\n        \"model_add_failed\": \"Failed to add {name}: {error}\",\n        \"model_scan_complete\": \"Scan complete! Added {count} model(s).\",\n        \"model_scan_interactive_prompt\": \"Commands: edit <id> | del <id> | done\",\n        \"model_scan_cmd_edit\": \"Set custom name for model\",\n        \"model_scan_cmd_delete\": \"Skip this model\",\n        \"model_scan_cmd_done\": \"Finish and add models\",\n        \"model_scan_marked_skip\": \"Skipped model #{id}\",\n        \"model_scan_invalid_id\": \"Invalid model ID: {id}\",\n        \"model_scan_invalid_command\": \"Invalid command. 
Use: edit <id> | del <id> | done\",\n        \"model_scan_edit_model\": \"Edit model {id}\",\n        \"model_scan_edit_note\": \"You can change the model name before adding it to registry\",\n        \"model_scan_adding_models\": \"Adding {count} model(s)...\",\n        \"model_scan_next_steps\": \"Next Steps\",\n        \"model_scan_view_hint\": \"View registered models: kt model list\",\n        \"model_scan_edit_hint\": \"Edit model details: kt model edit <name>\",\n        \"model_scan_no_models_added\": \"No models were added.\",\n        \"model_add_path_not_exist\": \"Error: Path does not exist: {path}\",\n        \"model_add_not_directory\": \"Error: Path is not a directory: {path}\",\n        \"model_add_already_registered\": \"This path is already registered as: {name}\",\n        \"model_add_view_hint\": \"View with: kt model info {name}\",\n        \"model_add_scanning\": \"Scanning model files...\",\n        \"model_add_scan_failed\": \"Failed to scan model: {error}\",\n        \"model_add_no_model_files\": \"No model files found in {path}\",\n        \"model_add_supported_formats\": \"Supported: *.safetensors, *.gguf (folder ≥10GB)\",\n        \"model_add_detected\": \"Detected: {format} format, {size}, {count} file(s)\",\n        \"model_add_name_conflict\": \"Name '{name}' already exists.\",\n        \"model_add_prompt_name\": \"Enter a name for this model\",\n        \"model_add_name_exists\": \"Name already exists. 
Please choose another name:\",\n        \"model_add_configure_repo\": \"Configure repository information for SHA256 verification?\",\n        \"model_add_repo_type_prompt\": \"Select repository type:\",\n        \"model_add_choice\": \"Choice\",\n        \"model_add_repo_id_prompt\": \"Enter repository ID (e.g., deepseek-ai/DeepSeek-V3)\",\n        \"model_add_success\": \"Successfully added model: {name}\",\n        \"model_add_verify_hint\": \"Verify integrity: kt model verify {name}\",\n        \"model_add_edit_later_hint\": \"Edit details later: kt model edit {name}\",\n        \"model_add_failed_generic\": \"Failed to add model: {error}\",\n        \"model_edit_not_found\": \"Model '{name}' not found.\",\n        \"model_edit_list_hint\": \"List models: kt model list\",\n        \"model_edit_current_config\": \"Current Configuration\",\n        \"model_edit_what_to_edit\": \"What would you like to edit?\",\n        \"model_edit_option_name\": \"Edit name\",\n        \"model_edit_option_repo\": \"Configure repository info\",\n        \"model_edit_option_delete\": \"Delete this model\",\n        \"model_edit_option_cancel\": \"Cancel / Exit\",\n        \"model_edit_choice_prompt\": \"Select option\",\n        \"model_edit_new_name\": \"Enter new name\",\n        \"model_edit_name_conflict\": \"Name '{name}' already exists. Please choose another:\",\n        \"model_edit_name_updated\": \"Name updated: {old} → {new}\",\n        \"model_edit_repo_type_prompt\": \"Repository type (or enter to remove repo info):\",\n        \"model_edit_repo_remove\": \"Remove repository info\",\n        \"model_edit_repo_id_prompt\": \"Enter repository ID\",\n        \"model_edit_repo_removed\": \"Repository info removed\",\n        \"model_edit_repo_updated\": \"Repository configured: {repo_type} → {repo_id}\",\n        \"model_edit_delete_warning\": \"Delete model '{name}' from registry?\",\n        \"model_edit_delete_note\": \"Note: This only removes the registry entry. 
Model files in {path} will NOT be deleted.\",\n        \"model_edit_delete_confirm\": \"Confirm deletion?\",\n        \"model_edit_deleted\": \"Model '{name}' deleted from registry\",\n        \"model_edit_delete_cancelled\": \"Deletion cancelled\",\n        \"model_edit_cancelled\": \"Edit cancelled\",\n        # Model edit - Interactive selection\n        \"model_edit_select_title\": \"Select Model to Edit\",\n        \"model_edit_select_model\": \"Select model\",\n        \"model_edit_invalid_choice\": \"Invalid choice\",\n        \"model_edit_no_models\": \"No models found in registry.\",\n        \"model_edit_add_hint_scan\": \"Add models with:\",\n        \"model_edit_add_hint_add\": \"Or:\",\n        # Model edit - Display\n        \"model_edit_gpu_links\": \"GPU Links:\",\n        # Model edit - Menu options\n        \"model_edit_manage_gpu_links\": \"Manage GPU Links\",\n        \"model_edit_save_changes\": \"Save changes\",\n        \"model_edit_has_changes\": \"(has changes)\",\n        \"model_edit_no_changes\": \"(no changes)\",\n        # Model edit - Pending changes messages\n        \"model_edit_name_pending\": \"Name will be updated when you save changes.\",\n        \"model_edit_repo_remove_pending\": \"Repository info will be removed when you save changes.\",\n        \"model_edit_repo_update_pending\": \"Repository info will be updated when you save changes.\",\n        # Model edit - GPU link management\n        \"model_edit_gpu_links_title\": \"Manage GPU Links for {name}\",\n        \"model_edit_current_gpu_links\": \"Current GPU links:\",\n        \"model_edit_no_gpu_links\": \"No GPU links configured.\",\n        \"model_edit_gpu_options\": \"Options:\",\n        \"model_edit_gpu_add\": \"Add GPU link\",\n        \"model_edit_gpu_remove\": \"Remove GPU link\",\n        \"model_edit_gpu_clear\": \"Clear all GPU links\",\n        \"model_edit_gpu_back\": \"Back to main menu\",\n        \"model_edit_gpu_choose_option\": \"Choose option\",\n    
    \"model_edit_gpu_none_available\": \"No GPU models available to link.\",\n        \"model_edit_gpu_available_models\": \"Available GPU models:\",\n        \"model_edit_gpu_already_linked\": \"(already linked)\",\n        \"model_edit_gpu_enter_number\": \"Enter GPU model number to add\",\n        \"model_edit_gpu_link_pending\": \"GPU link will be added when you save changes: {name}\",\n        \"model_edit_gpu_already_exists\": \"This GPU model is already linked.\",\n        \"model_edit_gpu_invalid_choice\": \"Invalid choice.\",\n        \"model_edit_gpu_invalid_input\": \"Invalid input.\",\n        \"model_edit_gpu_none_to_remove\": \"No GPU links to remove.\",\n        \"model_edit_gpu_choose_to_remove\": \"Choose GPU link to remove:\",\n        \"model_edit_gpu_enter_to_remove\": \"Enter number to remove\",\n        \"model_edit_gpu_remove_pending\": \"GPU link will be removed when you save changes: {name}\",\n        \"model_edit_gpu_none_to_clear\": \"No GPU links to clear.\",\n        \"model_edit_gpu_clear_confirm\": \"Remove all GPU links?\",\n        \"model_edit_gpu_clear_pending\": \"All GPU links will be removed when you save changes.\",\n        \"model_edit_cancelled_short\": \"Cancelled.\",\n        # Model edit - Save operation\n        \"model_edit_no_changes_to_save\": \"No changes to save.\",\n        \"model_edit_saving\": \"Saving changes...\",\n        \"model_edit_saved\": \"Changes saved successfully!\",\n        \"model_edit_updated_config\": \"Updated Configuration:\",\n        \"model_edit_repo_changed_warning\": \"⚠ Repository information has changed.\",\n        \"model_edit_verify_hint\": \"Run [cyan]kt model verify[/cyan] to verify model integrity with SHA256 checksums.\",\n        \"model_edit_discard_changes\": \"Discard unsaved changes?\",\n        \"model_info_not_found\": \"Model '{name}' not found.\",\n        \"model_info_list_hint\": \"List all models: kt model list\",\n        \"model_remove_not_found\": \"Model 
'{name}' not found.\",\n        \"model_remove_list_hint\": \"List models: kt model list\",\n        \"model_remove_warning\": \"Remove model '{name}' from registry?\",\n        \"model_remove_note\": \"Note: This only removes the registry entry. Model files will NOT be deleted from {path}.\",\n        \"model_remove_confirm\": \"Confirm removal?\",\n        \"model_remove_cancelled\": \"Removal cancelled\",\n        \"model_removed\": \"Model '{name}' removed from registry\",\n        \"model_remove_failed\": \"Failed to remove model: {error}\",\n        \"model_refresh_checking\": \"Checking model paths...\",\n        \"model_refresh_all_valid\": \"All models are valid! ({count} model(s) checked)\",\n        \"model_refresh_total\": \"Total models: {total}\",\n        \"model_refresh_missing_found\": \"Found {count} missing model(s)\",\n        \"model_refresh_suggestions\": \"Suggested Actions\",\n        \"model_refresh_remove_hint\": \"Remove from registry: kt model remove <name>\",\n        \"model_refresh_rescan_hint\": \"Re-scan for models: kt model scan\",\n        \"model_verify_not_found\": \"Model '{name}' not found.\",\n        \"model_verify_list_hint\": \"List models: kt model list\",\n        \"model_verify_no_repo\": \"Model '{name}' has no repository information configured.\",\n        \"model_verify_config_hint\": \"Configure repository: kt model edit {name}\",\n        \"model_verify_path_missing\": \"Model path does not exist: {path}\",\n        \"model_verify_starting\": \"Verifying model integrity...\",\n        \"model_verify_progress\": \"Repository: {repo_type} → {repo_id}\",\n        \"model_verify_not_implemented\": \"SHA256 verification not implemented yet\",\n        \"model_verify_future_note\": \"This feature will fetch official SHA256 hashes from {repo_type} and compare with local files.\",\n        \"model_verify_passed\": \"Verification passed! 
All files match official hashes.\",\n        \"model_verify_failed\": \"Verification failed! {count} file(s) have hash mismatches.\",\n        \"model_verify_all_no_repos\": \"No models have repository information configured.\",\n        \"model_verify_all_config_hint\": \"Configure repos using: kt model edit <name>\",\n        \"model_verify_all_found\": \"Found {count} model(s) with repository info\",\n        \"model_verify_all_manual_hint\": \"Verify specific model: kt model verify <name>\",\n        # Coming soon\n        \"feature_coming_soon\": \"This feature is coming soon...\",\n    },\n    \"zh\": {\n        # General\n        \"welcome\": \"欢迎使用 KTransformers！\",\n        \"goodbye\": \"再见！\",\n        \"error\": \"错误\",\n        \"warning\": \"警告\",\n        \"success\": \"成功\",\n        \"info\": \"信息\",\n        \"yes\": \"是\",\n        \"no\": \"否\",\n        \"cancel\": \"取消\",\n        \"confirm\": \"确认\",\n        \"done\": \"完成\",\n        \"failed\": \"失败\",\n        \"skip\": \"跳过\",\n        \"back\": \"返回\",\n        \"next\": \"下一步\",\n        \"retry\": \"重试\",\n        \"abort\": \"中止\",\n        # Version command\n        \"version_info\": \"KTransformers CLI\",\n        \"version_python\": \"Python\",\n        \"version_platform\": \"平台\",\n        \"version_cuda\": \"CUDA\",\n        \"version_cuda_not_found\": \"未找到\",\n        \"version_kt_kernel\": \"kt-kernel\",\n        \"version_ktransformers\": \"ktransformers\",\n        \"version_sglang\": \"sglang-kt\",\n        \"version_llamafactory\": \"llamafactory\",\n        \"version_not_installed\": \"未安装\",\n        # Install command\n        \"install_detecting_env\": \"检测环境管理工具...\",\n        \"install_found\": \"发现 {name} (版本 {version})\",\n        \"install_not_found\": \"未找到: {name}\",\n        \"install_checking_env\": \"检查现有环境...\",\n        \"install_env_exists\": \"发现现有 'kt' 环境\",\n        \"install_env_not_exists\": \"未发现 'kt' 环境\",\n        \"install_no_env_manager\": 
\"未检测到虚拟环境管理工具\",\n        \"install_select_method\": \"请选择安装方式:\",\n        \"install_method_conda\": \"创建新的 conda 环境 'kt' (推荐)\",\n        \"install_method_venv\": \"创建新的 venv 环境\",\n        \"install_method_uv\": \"创建新的 uv 环境 (快速)\",\n        \"install_method_docker\": \"使用 Docker 容器\",\n        \"install_method_system\": \"安装到系统 Python (不推荐)\",\n        \"install_select_mode\": \"请选择安装模式:\",\n        \"install_mode_inference\": \"推理模式 - 安装 kt-kernel + SGLang\",\n        \"install_mode_sft\": \"训练模式 - 安装 kt-sft + LlamaFactory\",\n        \"install_mode_full\": \"完整安装 - 安装所有组件\",\n        \"install_creating_env\": \"正在创建 {type} 环境 '{name}'...\",\n        \"install_env_created\": \"环境创建成功\",\n        \"install_installing_deps\": \"正在安装依赖...\",\n        \"install_checking_deps\": \"检查依赖版本...\",\n        \"install_dep_ok\": \"正常\",\n        \"install_dep_outdated\": \"需更新\",\n        \"install_dep_missing\": \"缺失\",\n        \"install_installing_pytorch\": \"正在安装 PyTorch...\",\n        \"install_installing_from_requirements\": \"从依赖文件安装...\",\n        \"install_deps_outdated\": \"发现 {count} 个包需要更新，是否继续？\",\n        \"install_updating\": \"正在更新包...\",\n        \"install_complete\": \"安装完成！\",\n        \"install_activate_hint\": \"激活环境: {command}\",\n        \"install_start_hint\": \"开始使用: kt run --help\",\n        \"install_docker_pulling\": \"正在拉取 Docker 镜像...\",\n        \"install_docker_complete\": \"Docker 镜像已就绪！\",\n        \"install_docker_run_hint\": \"运行: docker run --gpus all -p 30000:30000 {image} kt run {model}\",\n        \"install_in_venv\": \"当前在虚拟环境中: {name}\",\n        \"install_continue_without_venv\": \"继续安装到系统 Python？\",\n        \"install_already_installed\": \"所有依赖已安装！\",\n        \"install_confirm\": \"安装 {count} 个包？\",\n        # Install - System dependencies\n        \"install_checking_system_deps\": \"检查系统依赖...\",\n        \"install_dep_name\": \"依赖项\",\n        \"install_dep_status\": \"状态\",\n        \"install_deps_all_installed\": 
\"所有系统依赖已安装\",\n        \"install_deps_install_prompt\": \"是否安装缺失的依赖？\",\n        \"install_installing_system_deps\": \"正在安装系统依赖...\",\n        \"install_installing_dep\": \"正在安装 {name}\",\n        \"install_dep_no_install_cmd\": \"{os} 系统上没有 {name} 的安装命令\",\n        \"install_dep_install_failed\": \"安装 {name} 失败\",\n        \"install_deps_skipped\": \"跳过依赖安装\",\n        \"install_deps_failed\": \"系统依赖安装失败\",\n        # Install - CPU detection\n        \"install_auto_detect_cpu\": \"正在自动检测 CPU 能力...\",\n        \"install_cpu_features\": \"检测到的 CPU 特性: {features}\",\n        \"install_cpu_no_features\": \"未检测到高级 CPU 特性\",\n        # Install - Build configuration\n        \"install_build_config\": \"构建配置:\",\n        \"install_native_warning\": \"注意: 二进制文件仅针对当前 CPU 优化（不可移植）\",\n        \"install_building_from_source\": \"正在从源码构建 kt-kernel...\",\n        \"install_build_failed\": \"构建失败\",\n        \"install_build_success\": \"构建成功\",\n        # Install - Verification\n        \"install_verifying\": \"正在验证安装...\",\n        \"install_verify_success\": \"kt-kernel {version} ({variant} 变体) 安装成功\",\n        \"install_verify_failed\": \"验证失败: {error}\",\n        # Install - Docker\n        \"install_docker_guide_title\": \"Docker 安装\",\n        \"install_docker_guide_desc\": \"有关 Docker 安装，请参阅官方指南:\",\n        # Config command\n        \"config_show_title\": \"当前配置\",\n        \"config_set_success\": \"配置已更新: {key} = {value}\",\n        \"config_get_value\": \"{key} = {value}\",\n        \"config_get_not_found\": \"未找到配置项 '{key}'\",\n        \"config_reset_confirm\": \"这将重置所有配置为默认值。是否继续？\",\n        \"config_reset_success\": \"配置已重置为默认值\",\n        \"config_file_location\": \"配置文件: {path}\",\n        # Doctor command\n        \"doctor_title\": \"KTransformers 环境诊断\",\n        \"doctor_checking\": \"正在运行诊断...\",\n        \"doctor_check_python\": \"Python 版本\",\n        \"doctor_check_cuda\": \"CUDA 可用性\",\n        \"doctor_check_gpu\": \"GPU 检测\",\n        
\"doctor_check_cpu\": \"CPU\",\n        \"doctor_check_cpu_isa\": \"CPU 指令集\",\n        \"doctor_check_numa\": \"NUMA 拓扑\",\n        \"doctor_check_memory\": \"系统内存\",\n        \"doctor_check_disk\": \"磁盘空间\",\n        \"doctor_check_packages\": \"必需的包\",\n        \"doctor_check_env\": \"环境变量\",\n        \"doctor_status_ok\": \"正常\",\n        \"doctor_status_warning\": \"警告\",\n        \"doctor_status_error\": \"错误\",\n        \"doctor_gpu_found\": \"发现 {count} 个 GPU: {names}\",\n        \"doctor_gpu_not_found\": \"未检测到 GPU\",\n        \"doctor_cpu_info\": \"{name} ({cores} 核心 / {threads} 线程)\",\n        \"doctor_cpu_isa_info\": \"{isa_list}\",\n        \"doctor_cpu_isa_missing\": \"缺少推荐指令集: {missing}\",\n        \"doctor_numa_info\": \"{nodes} 个节点\",\n        \"doctor_numa_detail\": \"{node}: CPU {cpus}\",\n        \"doctor_memory_info\": \"{available} 可用 / {total} 总计\",\n        \"doctor_memory_freq\": \"{available} 可用 / {total} 总计 ({freq}MHz {type})\",\n        \"doctor_disk_info\": \"{path} 有 {available} 可用空间\",\n        \"doctor_all_ok\": \"所有检查通过！您的环境已就绪。\",\n        \"doctor_has_issues\": \"发现一些问题，请查看上方的警告/错误信息。\",\n        # Run command\n        \"run_detecting_hardware\": \"检测硬件配置...\",\n        \"run_gpu_info\": \"GPU: {name} ({vram}GB 显存)\",\n        \"run_cpu_info\": \"CPU: {name} ({cores} 核心, {numa} NUMA 节点)\",\n        \"run_ram_info\": \"内存: {total}GB\",\n        \"run_checking_model\": \"检查模型状态...\",\n        \"run_model_path\": \"模型路径: {path}\",\n        \"run_weights_not_found\": \"未找到量化权重\",\n        \"run_quant_prompt\": \"是否现在量化模型？(这可能需要一些时间)\",\n        \"run_quantizing\": \"正在量化模型...\",\n        \"run_starting_server\": \"正在启动服务器...\",\n        \"run_server_mode\": \"模式: SGLang + kt-kernel\",\n        \"run_server_port\": \"端口: {port}\",\n        \"run_gpu_experts\": \"GPU 专家: {count}/层\",\n        \"run_cpu_threads\": \"CPU 线程: {count}\",\n        \"run_server_started\": \"服务器已启动！\",\n        \"run_api_url\": \"API 地址: 
http://{host}:{port}\",\n        \"run_docs_url\": \"文档地址: http://{host}:{port}/docs\",\n        \"run_stop_hint\": \"按 Ctrl+C 停止服务器\",\n        \"run_model_not_found\": \"未找到模型 '{name}'。请先运行 'kt download'。\",\n        \"run_multiple_matches\": \"找到多个匹配的模型，请选择:\",\n        \"run_select_model\": \"选择模型\",\n        \"run_select_model_title\": \"选择要运行的模型\",\n        \"run_select_model_prompt\": \"输入编号\",\n        \"run_local_models\": \"本地模型 (已下载)\",\n        \"run_registered_models\": \"注册模型\",\n        # Download command\n        \"download_list_title\": \"可用模型\",\n        \"download_searching\": \"正在搜索模型 '{name}'...\",\n        \"download_found\": \"找到: {name}\",\n        \"download_multiple_found\": \"找到多个匹配:\",\n        \"download_select\": \"选择要下载的模型:\",\n        \"download_destination\": \"目标路径: {path}\",\n        \"download_starting\": \"开始下载...\",\n        \"download_progress\": \"正在下载 {name}...\",\n        \"download_complete\": \"下载完成！\",\n        \"download_already_exists\": \"模型已存在于 {path}\",\n        \"download_overwrite_prompt\": \"是否覆盖现有文件？\",\n        # Quant command\n        \"quant_input_path\": \"输入路径: {path}\",\n        \"quant_output_path\": \"输出路径: {path}\",\n        \"quant_method\": \"量化方法: {method}\",\n        \"quant_starting\": \"开始量化...\",\n        \"quant_progress\": \"正在量化...\",\n        \"quant_complete\": \"量化完成！\",\n        \"quant_input_not_found\": \"未找到输入模型: {path}\",\n        \"quant_cpu_threads\": \"CPU 线程数: {threads}\",\n        \"quant_numa_nodes\": \"NUMA 节点数: {nodes}\",\n        \"quant_time_warning\": \"量化可能需要 30-60 分钟，具体取决于模型大小。\",\n        \"quant_disk_analysis\": \"磁盘空间分析：\",\n        \"quant_source_size\": \"源模型大小：\",\n        \"quant_estimated_size\": \"预估输出大小：\",\n        \"quant_available_space\": \"可用空间：\",\n        \"quant_insufficient_space\": \"警告：磁盘空间不足！\",\n        \"quant_required_space\": \"所需空间（含20%缓冲）：\",\n        \"quant_shortage\": \"不足：\",\n        \"quant_may_fail\": \"量化可能失败或生成不完整的文件。\",\n        
\"quant_continue_anyway\": \"仍然继续？\",\n        \"quant_settings\": \"量化设置：\",\n        \"quant_registered\": \"量化模型已注册：{name}\",\n        \"quant_view_with\": \"查看：\",\n        \"quant_use_with\": \"使用：\",\n        \"quant_register_failed\": \"自动注册模型失败：{error}\",\n        \"quant_output_exists\": \"输出路径已存在：{path}\",\n        \"quant_using_unique\": \"使用唯一名称：{path}\",\n        # Interactive quant\n        \"quant_interactive_title\": \"交互式量化配置\",\n        \"quant_new_model_notice\": \"⚠ 注意：部分新模型暂时无法量化（转换脚本未适配），推荐使用原精度进行推理（无需转换权重）。\",\n        \"quant_no_moe_models\": \"未找到可量化的 MoE 模型。\",\n        \"quant_only_moe\": \"只有 MoE 模型（如 DeepSeek-V3）可以被量化为 AMX 格式。\",\n        \"quant_add_models\": \"添加模型：{command}\",\n        \"quant_moe_available\": \"可量化的 MoE 模型：\",\n        \"quant_select_model\": \"选择要量化的模型\",\n        \"quant_invalid_choice\": \"无效选择\",\n        \"quant_step2_method\": \"第 2 步：量化方法\",\n        \"quant_method_label\": \"量化方法：\",\n        \"quant_int4_desc\": \"INT4\",\n        \"quant_int8_desc\": \"INT8\",\n        \"quant_select_method\": \"选择量化方法\",\n        \"quant_input_type_label\": \"输入权重类型：\",\n        \"quant_fp8_desc\": \"FP8（适用于 8 位浮点权重）\",\n        \"quant_fp16_desc\": \"FP16（适用于 16 位浮点权重）\",\n        \"quant_bf16_desc\": \"BF16（适用于 Brain Float 16 权重）\",\n        \"quant_select_input_type\": \"选择输入类型\",\n        \"quant_step3_cpu\": \"第 3 步：CPU 配置\",\n        \"quant_cpu_threads_prompt\": \"CPU 线程数（1 到 {max}）\",\n        \"quant_numa_nodes_prompt\": \"NUMA 节点数（1 到 {max}）\",\n        \"quant_use_gpu_label\": \"是否使用 GPU 进行转换？\",\n        \"quant_gpu_speedup\": \"GPU 可以显著加快量化速度\",\n        \"quant_enable_gpu\": \"启用 GPU 加速？\",\n        \"quant_step4_output\": \"第 4 步：输出路径\",\n        \"quant_default_path\": \"默认：\",\n        \"quant_use_default\": \"使用默认输出路径？\",\n        \"quant_custom_path\": \"输入自定义输出路径\",\n        \"quant_output_exists_warn\": \"⚠ 输出路径已存在：{path}\",\n        \"quant_using_unique_name\": \"→ 使用唯一名称：{path}\",\n        
\"quant_config_summary\": \"配置摘要\",\n        \"quant_summary_model\": \"模型：\",\n        \"quant_summary_method\": \"方法：\",\n        \"quant_summary_input_type\": \"输入类型：\",\n        \"quant_summary_cpu_threads\": \"CPU 线程数：\",\n        \"quant_summary_numa\": \"NUMA 节点数：\",\n        \"quant_summary_gpu\": \"使用 GPU：\",\n        \"quant_summary_output\": \"输出路径：\",\n        \"quant_start_question\": \"开始量化？\",\n        \"quant_cancelled\": \"已取消\",\n        \"quant_config_complete\": \"配置完成\",\n        \"quant_time_elapsed\": \"耗时：\",\n        \"yes\": \"是\",\n        \"no\": \"否\",\n        # SFT command\n        \"sft_mode_train\": \"训练模式\",\n        \"sft_mode_chat\": \"聊天模式\",\n        \"sft_mode_export\": \"导出模式\",\n        \"sft_config_path\": \"配置文件: {path}\",\n        \"sft_starting\": \"正在启动 {mode}...\",\n        \"sft_complete\": \"{mode} 完成！\",\n        \"sft_config_not_found\": \"未找到配置文件: {path}\",\n        # Bench command\n        \"bench_starting\": \"开始基准测试...\",\n        \"bench_type\": \"测试类型: {type}\",\n        \"bench_complete\": \"基准测试完成！\",\n        \"bench_results_title\": \"基准测试结果\",\n        # Common prompts\n        \"prompt_continue\": \"是否继续？\",\n        \"prompt_select\": \"请选择:\",\n        \"prompt_enter_value\": \"请输入:\",\n        \"prompt_confirm_action\": \"确认此操作？\",\n        # First-run setup - Model path selection\n        \"setup_model_path_title\": \"模型存储位置\",\n        \"setup_model_path_desc\": \"大语言模型体积较大（50-200GB+）。请选择一个有足够空间的存储位置：\",\n        \"setup_scanning_disks\": \"正在扫描可用存储位置...\",\n        \"setup_disk_option\": \"{path} (可用 {available} / 总共 {total})\",\n        \"setup_disk_option_recommended\": \"{path} (可用 {available} / 总共 {total}) [推荐]\",\n        \"setup_custom_path\": \"输入自定义路径\",\n        \"setup_enter_custom_path\": \"请输入模型存储路径\",\n        \"setup_path_not_exist\": \"路径不存在，是否创建？\",\n        \"setup_path_no_write\": \"没有该路径的写入权限，请选择其他路径。\",\n        \"setup_path_low_space\": \"警告：可用空间不足 100GB，可能无法存储大型模型。\",\n       
 \"setup_model_path_set\": \"模型存储路径已设置为: {path}\",\n        \"setup_no_large_disk\": \"未发现大容量存储位置，使用默认路径。\",\n        \"setup_scanning_models\": \"正在扫描已有模型...\",\n        \"setup_found_models\": \"发现 {count} 个模型:\",\n        \"setup_model_info\": \"{name} ({size}, {type})\",\n        \"setup_no_models_found\": \"该位置未发现已有模型。\",\n        \"setup_location_has_models\": \"发现 {count} 个模型\",\n        \"setup_installing_completion\": \"正在为 {shell} 安装命令补全...\",\n        \"setup_completion_installed\": \"命令补全已安装！重启终端后生效。\",\n        \"setup_completion_failed\": \"命令补全安装失败。请手动运行 'kt --install-completion'。\",\n        # Auto completion\n        \"completion_installed_title\": \"命令补全\",\n        \"completion_installed_for\": \"已为 {shell} 安装命令补全\",\n        \"completion_activate_now\": \"在当前终端会话中启用补全，请运行：\",\n        \"completion_next_session\": \"新的终端会话将自动启用补全。\",\n        # SGLang\n        \"sglang_not_found\": \"未找到 SGLang\",\n        \"sglang_pypi_warning\": \"PyPI 版本的 SGLang 可能与 kt-kernel 不兼容。请使用 sglang-kt: pip install sglang-kt\",\n        \"sglang_pypi_hint\": \"PyPI 版本可能不兼容。安装 kvcache-ai 分支: pip install sglang-kt (或在 ktransformers 根目录运行 ./install.sh)\",\n        \"sglang_install_hint\": \"安装 SGLang: pip install sglang-kt (或在 ktransformers 根目录运行 ./install.sh)\",\n        \"sglang_recommend_source\": \"建议重新安装 kvcache-ai 分支: pip uninstall sglang -y && pip install sglang-kt\",\n        \"sglang_kt_kernel_not_supported\": \"SGLang 不支持 kt-kernel (缺少 --kt-gpu-prefill-token-threshold 参数)\",\n        \"sglang_checking_kt_kernel_support\": \"正在检查 SGLang kt-kernel 支持...\",\n        \"sglang_kt_kernel_supported\": \"SGLang kt-kernel 支持已验证\",\n        # Chat\n        \"chat_proxy_detected\": \"检测到环境中存在代理设置\",\n        \"chat_proxy_confirm\": \"是否使用代理连接？\",\n        \"chat_proxy_disabled\": \"已在本次会话中禁用代理\",\n        \"chat_openai_required\": \"聊天功能需要 OpenAI Python SDK。\",\n        \"chat_install_hint\": \"安装命令：\",\n        \"chat_title\": \"KTransformers 对话\",\n        
\"chat_server\": \"服务器\",\n        \"chat_temperature\": \"温度\",\n        \"chat_max_tokens\": \"最大 tokens\",\n        \"chat_help_hint\": \"输入 '/help' 查看命令，'/quit' 退出\",\n        \"chat_connecting\": \"正在连接服务器...\",\n        \"chat_no_models\": \"服务器上没有可用模型\",\n        \"chat_model_not_found\": \"未找到模型 '{model}'。可用模型：{available}\",\n        \"chat_connected\": \"已连接到模型：{model}\",\n        \"chat_connect_failed\": \"连接服务器失败：{error}\",\n        \"chat_server_not_running\": \"请确保模型服务器正在运行：\",\n        \"chat_user_prompt\": \"用户\",\n        \"chat_assistant_prompt\": \"助手\",\n        \"chat_generation_error\": \"生成回复时出错：{error}\",\n        \"chat_interrupted\": \"对话已中断。再见！\",\n        \"chat_history_saved\": \"历史记录已保存到：{path}\",\n        \"chat_goodbye\": \"再见！\",\n        \"chat_help_title\": \"可用命令：\",\n        \"chat_help_content\": \"/help, /h         - 显示此帮助信息\\n/quit, /exit, /q  - 退出聊天\\n/clear, /c        - 清除对话历史\\n/history, /hist   - 显示对话历史\\n/info, /i         - 显示当前设置\\n/retry, /r        - 重新生成上一个回复\",\n        \"chat_history_cleared\": \"对话历史已清除\",\n        \"chat_no_history\": \"暂无对话历史\",\n        \"chat_history_title\": \"历史记录（{count} 条消息）\",\n        \"chat_info_title\": \"当前设置：\",\n        \"chat_info_content\": \"温度：{temperature}\\n最大 tokens：{max_tokens}\\n消息数：{messages}\",\n        \"chat_retrying\": \"正在重试上一个回复...\",\n        \"chat_no_retry\": \"没有可重试的回复\",\n        \"chat_unknown_command\": \"未知命令：{command}\",\n        \"chat_unknown_hint\": \"输入 /help 查看可用命令\",\n        # Run Interactive\n        \"run_int_no_moe_models\": \"未找到 MoE GPU 模型。\",\n        \"run_int_add_models\": \"添加模型：kt model scan\",\n        \"run_int_list_all\": \"列出所有模型：kt model list --all\",\n        \"run_int_step1_title\": \"第 1 步：选择模型（GPU MoE 模型）\",\n        \"run_int_select_model\": \"选择模型\",\n        \"run_int_step2_title\": \"第 2 步：选择推理方法\",\n        \"run_int_method_raw\": \"RAW 精度（FP8/FP8_PERCHANNEL/BF16/RAWINT4）\",\n        \"run_int_method_amx\": \"AMX 
量化（INT4/INT8）\",\n        \"run_int_method_gguf\": \"GGUF（Llamafile）\",\n        \"run_int_method_saved\": \"使用已保存的配置\",\n        \"run_int_select_method\": \"选择推理方法\",\n        \"run_int_raw_precision\": \"RAW 精度：\",\n        \"run_int_select_precision\": \"选择精度\",\n        \"run_int_amx_method\": \"AMX 方法：\",\n        \"run_int_select_amx\": \"选择 AMX 方法\",\n        \"run_int_step3_title\": \"第 3 步：NUMA 和 CPU 配置\",\n        \"run_int_numa_nodes\": \"NUMA 节点数（1-{max}）\",\n        \"run_int_cpu_threads\": \"每个 NUMA 的 CPU 线程数（1-{max}）\",\n        \"run_int_amx_warning\": \"⚠ 警告：AMX INT4/INT8 需要兼容的 CPU。检查命令：kt doctor\",\n        \"run_int_step4_title\": \"第 4 步：GPU 专家配置\",\n        \"run_int_gpu_experts\": \"每层 GPU 专家数（0-{max}）\",\n        \"run_int_gpu_experts_info\": \"总专家数：{total}，每 token 激活：{active}\",\n        \"run_int_step5_title\": \"第 5 步：KV Cache 配置\",\n        \"run_int_kv_cache_size\": \"KV Cache 大小（tokens）\",\n        \"run_int_chunk_prefill\": \"启用分块预填充？\",\n        \"run_int_chunk_size\": \"分块预填充大小（tokens）\",\n        \"run_int_gpu_prefill_threshold\": \"GPU 预填充阈值（tokens）\",\n        \"run_int_step6_title\": \"第 6 步：GPU 选择和张量并行\",\n        \"run_int_available_gpus\": \"可用 GPU：\",\n        \"run_int_gpu_id\": \"GPU {id}\",\n        \"run_int_vram_info\": \"{name}（总计 {total:.1f}GB，空闲 {free:.1f}GB）\",\n        \"run_int_select_gpus\": \"选择 GPU ID（逗号分隔）\",\n        \"run_int_invalid_gpu_range\": \"所有 GPU ID 必须在 0 到 {max} 之间\",\n        \"run_int_tp_size\": \"TP 大小（必须是 2 的幂：1,2,4,8...）\",\n        \"run_int_tp_mismatch\": \"TP 大小必须与选择的 GPU 数量匹配（{count}）\",\n        \"run_int_tp_not_power_of_2\": \"TP 大小必须是 2 的幂\",\n        \"run_int_mem_fraction\": \"静态内存占用比例（0.0-1.0）\",\n        \"run_int_using_saved_mem\": \"使用已保存的内存占用比例：{fraction}\",\n        \"run_int_step7_title\": \"第 7 步：解析器配置（可选）\",\n        \"run_int_tool_call_parser\": \"工具调用解析器（按回车跳过）\",\n        \"run_int_reasoning_parser\": \"推理解析器（按回车跳过）\",\n        \"run_int_step8_title\": \"第 8 步：主机和端口配置\",\n 
       \"run_int_host\": \"主机\",\n        \"run_int_port\": \"端口\",\n        \"run_int_port_occupied\": \"⚠ 端口 {port} 已被占用\",\n        \"run_int_port_suggestion\": \"建议使用可用端口：{port}\",\n        \"run_int_use_suggested\": \"使用建议的端口？\",\n        \"run_int_saved_configs\": \"已保存的配置：\",\n        \"run_int_config_name\": \"配置 {num}\",\n        \"run_int_kt_method\": \"KT 方法：\",\n        \"run_int_numa_nodes_label\": \"NUMA 节点：\",\n        \"run_int_cpu_threads_label\": \"CPU 线程：\",\n        \"run_int_gpu_experts_label\": \"GPU 专家：\",\n        \"run_int_tp_size_label\": \"TP 大小：\",\n        \"run_int_mem_fraction_label\": \"内存占用比例：\",\n        \"run_int_server_label\": \"服务器：\",\n        \"run_int_kv_cache_label\": \"KV Cache：\",\n        \"run_int_chunk_prefill_label\": \"分块预填充：\",\n        \"run_int_gpu_prefill_label\": \"GPU 预填充阈值：\",\n        \"run_int_tool_parser_label\": \"工具调用解析器：\",\n        \"run_int_reasoning_parser_label\": \"推理解析器：\",\n        \"run_int_command_label\": \"命令：\",\n        \"run_int_select_config\": \"选择配置\",\n        \"run_int_gpu_select_required\": \"请选择 {tp} 个 GPU（来自已保存配置的 TP 大小）\",\n        \"run_int_port_check_title\": \"端口配置\",\n        \"run_int_port_checking\": \"正在检查端口 {port} 可用性...\",\n        \"run_int_port_available\": \"端口 {port} 可用\",\n        \"run_int_saved_config_title\": \"已保存的配置\",\n        \"run_int_save_config_title\": \"保存配置\",\n        \"run_int_save_config_prompt\": \"保存此配置以供将来使用？\",\n        \"run_int_config_name_prompt\": \"配置名称\",\n        \"run_int_config_name_default\": \"配置 {timestamp}\",\n        \"run_int_config_saved\": \"配置已保存：{name}\",\n        \"run_int_config_summary\": \"配置完成\",\n        \"run_int_model_label\": \"模型：\",\n        \"run_int_selected_gpus_label\": \"已选择的 GPU：\",\n        # Model command\n        \"model_supported_title\": \"KTransformers 支持的模型\",\n        \"model_column_model\": \"模型\",\n        \"model_column_status\": \"状态\",\n        \"model_column_local_path\": \"本地路径\",\n        
\"model_status_local\": \"本地\",\n        \"model_status_not_downloaded\": \"未下载\",\n        \"model_usage_title\": \"使用方法\",\n        \"model_usage_download\": \"下载模型:\",\n        \"model_usage_list_local\": \"列出本地模型:\",\n        \"model_usage_search\": \"搜索模型:\",\n        \"model_storage_paths_title\": \"模型存储路径\",\n        \"model_local_models_title\": \"本地已下载的模型\",\n        \"model_available_models_title\": \"可用模型\",\n        \"model_no_local_models\": \"未找到本地已下载的模型\",\n        \"model_download_hint\": \"下载模型:\",\n        \"model_download_usage_hint\": \"用法: kt model download <模型名称>\",\n        \"model_download_list_hint\": \"使用 'kt model download --list' 查看可用模型。\",\n        \"model_download_hf_hint\": \"或直接指定 HuggingFace 仓库: kt model download org/model-name\",\n        \"model_saved_to\": \"模型已保存到: {path}\",\n        \"model_start_with\": \"启动命令: kt run {name}\",\n        \"model_download_failed\": \"下载失败: {error}\",\n        \"model_hf_cli_not_found\": \"未找到 huggingface-cli。请安装: pip install huggingface-hub\",\n        \"model_path_not_exist\": \"路径不存在: {path}\",\n        \"model_create_directory\": \"创建目录 {path}？\",\n        \"model_created_directory\": \"已创建目录: {path}\",\n        \"model_create_dir_failed\": \"创建目录失败: {error}\",\n        \"model_path_added\": \"已添加模型路径: {path}\",\n        \"model_path_removed\": \"已移除模型路径: {path}\",\n        \"model_path_not_found\": \"路径未找到或无法移除最后一个路径: {path}\",\n        \"model_search_no_results\": \"未找到匹配 '{query}' 的模型\",\n        \"model_search_results_title\": \"'{query}' 的搜索结果\",\n        \"model_column_name\": \"名称\",\n        \"model_column_hf_repo\": \"HuggingFace 仓库\",\n        \"model_column_aliases\": \"别名\",\n        # Model management - new user registry system\n        \"model_no_registered_models\": \"尚未注册任何模型。\",\n        \"model_scan_hint\": \"扫描模型: kt model scan\",\n        \"model_add_hint\": \"添加模型: kt model add /path/to/model\",\n        \"model_registered_models_title\": \"已注册的模型\",\n        
\"model_column_format\": \"格式\",\n        \"model_column_repo\": \"仓库\",\n        \"model_column_sha256\": \"SHA256\",\n        \"model_non_moe_hidden_hint\": \"检测到 {count} 个非MoE模型，使用 kt model list --all 展示全部\",\n        \"model_usage_title\": \"常用操作:\",\n        \"model_usage_info\": \"查看详情:\",\n        \"model_usage_edit\": \"编辑模型:\",\n        \"model_usage_verify\": \"校验权重:\",\n        \"model_usage_quant\": \"量化模型:\",\n        \"model_usage_run\": \"运行模型:\",\n        \"model_usage_scan\": \"扫描模型:\",\n        \"model_usage_add\": \"添加模型:\",\n        \"model_usage_verbose\": \"查看包含文件详情:\",\n        \"model_no_storage_paths\": \"未配置存储路径。\",\n        \"model_add_path_hint\": \"添加存储路径: kt config set model.storage_paths /path/to/models\",\n        \"model_scanning_paths\": \"正在扫描配置的存储路径...\",\n        \"model_scanning_progress\": \"扫描中: {path}\",\n        \"model_scan_warnings_title\": \"警告\",\n        \"model_scan_no_models_found\": \"在配置的路径中未找到模型。\",\n        \"model_scan_check_paths_hint\": \"检查存储路径: kt config get model.storage_paths\",\n        \"model_scan_min_size_hint\": \"文件夹必须 ≥{size}GB 才能被识别为模型。\",\n        \"model_scan_found_title\": \"发现 {count} 个新模型\",\n        \"model_column_path\": \"路径\",\n        \"model_column_size\": \"大小\",\n        \"model_scan_auto_adding\": \"正在自动添加模型...\",\n        \"model_added\": \"已添加: {name}\",\n        \"model_add_failed\": \"添加 {name} 失败: {error}\",\n        \"model_scan_complete\": \"扫描完成！已添加 {count} 个模型。\",\n        \"model_scan_interactive_prompt\": \"命令: edit <id> | del <id> | done\",\n        \"model_scan_cmd_edit\": \"设置模型自定义名称和仓库\",\n        \"model_scan_cmd_delete\": \"跳过此模型\",\n        \"model_scan_cmd_done\": \"完成并添加模型\",\n        \"model_scan_marked_skip\": \"已跳过模型 #{id}\",\n        \"model_scan_invalid_id\": \"无效的模型 ID: {id}\",\n        \"model_scan_invalid_command\": \"无效命令。使用: edit <id> | del <id> | done\",\n        \"model_scan_edit_model\": \"编辑模型 {id}\",\n        \"model_scan_edit_note\": 
\"您可以在添加到注册表前更改模型名称和配置仓库信息\",\n        \"model_scan_adding_models\": \"正在添加 {count} 个模型...\",\n        \"model_scan_next_steps\": \"后续步骤\",\n        \"model_scan_view_hint\": \"查看已注册模型: kt model list\",\n        \"model_scan_edit_hint\": \"编辑模型详情: kt model edit <name>\",\n        \"model_scan_no_models_added\": \"未添加任何模型。\",\n        \"model_add_path_not_exist\": \"错误: 路径不存在: {path}\",\n        \"model_add_not_directory\": \"错误: 路径不是目录: {path}\",\n        \"model_add_already_registered\": \"此路径已注册为: {name}\",\n        \"model_add_view_hint\": \"查看: kt model info {name}\",\n        \"model_add_scanning\": \"正在扫描模型文件...\",\n        \"model_add_scan_failed\": \"扫描模型失败: {error}\",\n        \"model_add_no_model_files\": \"在 {path} 中未找到模型文件\",\n        \"model_add_supported_formats\": \"支持: *.safetensors, *.gguf (文件夹 ≥10GB)\",\n        \"model_add_detected\": \"检测到: {format} 格式, {size}, {count} 个文件\",\n        \"model_add_name_conflict\": \"名称 '{name}' 已存在。\",\n        \"model_add_prompt_name\": \"为此模型输入名称\",\n        \"model_add_name_exists\": \"名称已存在。请选择其他名称:\",\n        \"model_add_configure_repo\": \"配置仓库信息以进行 SHA256 验证?\",\n        \"model_add_repo_type_prompt\": \"选择仓库类型:\",\n        \"model_add_choice\": \"选择\",\n        \"model_add_repo_id_prompt\": \"输入仓库 ID (例如: deepseek-ai/DeepSeek-V3)\",\n        \"model_add_success\": \"成功添加模型: {name}\",\n        \"model_add_verify_hint\": \"验证完整性: kt model verify {name}\",\n        \"model_add_edit_later_hint\": \"稍后编辑详情: kt model edit {name}\",\n        \"model_add_failed_generic\": \"添加模型失败: {error}\",\n        \"model_edit_not_found\": \"未找到模型 '{name}'。\",\n        \"model_edit_list_hint\": \"列出模型: kt model list\",\n        \"model_edit_current_config\": \"当前配置\",\n        \"model_edit_what_to_edit\": \"您想编辑什么?\",\n        \"model_edit_option_name\": \"编辑名称\",\n        \"model_edit_option_repo\": \"配置仓库信息\",\n        \"model_edit_option_delete\": \"删除此模型\",\n        \"model_edit_option_cancel\": \"取消 / 退出\",\n        
\"model_edit_choice_prompt\": \"选择选项\",\n        \"model_edit_new_name\": \"输入新名称\",\n        \"model_edit_name_conflict\": \"名称 '{name}' 已存在。请选择其他名称:\",\n        \"model_edit_name_updated\": \"名称已更新: {old} → {new}\",\n        \"model_edit_repo_type_prompt\": \"仓库类型 (或按回车删除仓库信息):\",\n        \"model_edit_repo_remove\": \"删除仓库信息\",\n        \"model_edit_repo_id_prompt\": \"输入仓库 ID\",\n        \"model_edit_repo_removed\": \"仓库信息已删除\",\n        \"model_edit_repo_updated\": \"仓库已配置: {repo_type} → {repo_id}\",\n        \"model_edit_delete_warning\": \"从注册表中删除模型 '{name}'?\",\n        \"model_edit_delete_note\": \"注意: 这只会删除注册表条目。{path} 中的模型文件不会被删除。\",\n        \"model_edit_delete_confirm\": \"确认删除?\",\n        \"model_edit_deleted\": \"模型 '{name}' 已从注册表中删除\",\n        \"model_edit_delete_cancelled\": \"删除已取消\",\n        \"model_edit_cancelled\": \"编辑已取消\",\n        # Model edit - Interactive selection\n        \"model_edit_select_title\": \"选择要编辑的模型\",\n        \"model_edit_select_model\": \"选择模型\",\n        \"model_edit_invalid_choice\": \"无效选择\",\n        \"model_edit_no_models\": \"注册表中未找到模型。\",\n        \"model_edit_add_hint_scan\": \"添加模型:\",\n        \"model_edit_add_hint_add\": \"或:\",\n        # Model edit - Display\n        \"model_edit_gpu_links\": \"GPU 链接:\",\n        # Model edit - Menu options\n        \"model_edit_manage_gpu_links\": \"管理 GPU 链接\",\n        \"model_edit_save_changes\": \"保存更改\",\n        \"model_edit_has_changes\": \"(有更改)\",\n        \"model_edit_no_changes\": \"(无更改)\",\n        # Model edit - Pending changes messages\n        \"model_edit_name_pending\": \"名称将在保存更改时更新。\",\n        \"model_edit_repo_remove_pending\": \"仓库信息将在保存更改时删除。\",\n        \"model_edit_repo_update_pending\": \"仓库信息将在保存更改时更新。\",\n        # Model edit - GPU link management\n        \"model_edit_gpu_links_title\": \"管理 {name} 的 GPU 链接\",\n        \"model_edit_current_gpu_links\": \"当前 GPU 链接:\",\n        \"model_edit_no_gpu_links\": \"未配置 GPU 链接。\",\n        
\"model_edit_gpu_options\": \"选项:\",\n        \"model_edit_gpu_add\": \"添加 GPU 链接\",\n        \"model_edit_gpu_remove\": \"删除 GPU 链接\",\n        \"model_edit_gpu_clear\": \"清除所有 GPU 链接\",\n        \"model_edit_gpu_back\": \"返回主菜单\",\n        \"model_edit_gpu_choose_option\": \"选择选项\",\n        \"model_edit_gpu_none_available\": \"没有可链接的 GPU 模型。\",\n        \"model_edit_gpu_available_models\": \"可用的 GPU 模型:\",\n        \"model_edit_gpu_already_linked\": \"(已链接)\",\n        \"model_edit_gpu_enter_number\": \"输入要添加的 GPU 模型编号\",\n        \"model_edit_gpu_link_pending\": \"GPU 链接将在保存更改时添加: {name}\",\n        \"model_edit_gpu_already_exists\": \"此 GPU 模型已链接。\",\n        \"model_edit_gpu_invalid_choice\": \"无效选择。\",\n        \"model_edit_gpu_invalid_input\": \"无效输入。\",\n        \"model_edit_gpu_none_to_remove\": \"没有可删除的 GPU 链接。\",\n        \"model_edit_gpu_choose_to_remove\": \"选择要删除的 GPU 链接:\",\n        \"model_edit_gpu_enter_to_remove\": \"输入要删除的编号\",\n        \"model_edit_gpu_remove_pending\": \"GPU 链接将在保存更改时删除: {name}\",\n        \"model_edit_gpu_none_to_clear\": \"没有可清除的 GPU 链接。\",\n        \"model_edit_gpu_clear_confirm\": \"删除所有 GPU 链接?\",\n        \"model_edit_gpu_clear_pending\": \"所有 GPU 链接将在保存更改时删除。\",\n        \"model_edit_cancelled_short\": \"已取消。\",\n        # Model edit - Save operation\n        \"model_edit_no_changes_to_save\": \"没有更改可保存。\",\n        \"model_edit_saving\": \"正在保存更改...\",\n        \"model_edit_saved\": \"更改保存成功!\",\n        \"model_edit_updated_config\": \"更新后的配置:\",\n        \"model_edit_repo_changed_warning\": \"⚠ 仓库信息已更改。\",\n        \"model_edit_verify_hint\": \"运行 [cyan]kt model verify[/cyan] 以使用 SHA256 校验和验证模型完整性。\",\n        \"model_edit_discard_changes\": \"放弃未保存的更改?\",\n        \"model_info_not_found\": \"未找到模型 '{name}'。\",\n        \"model_info_list_hint\": \"列出所有模型: kt model list\",\n        \"model_remove_not_found\": \"未找到模型 '{name}'。\",\n        \"model_remove_list_hint\": \"列出模型: kt model list\",\n        
\"model_remove_warning\": \"从注册表中删除模型 '{name}'?\",\n        \"model_remove_note\": \"注意: 这只会删除注册表条目。模型文件不会从 {path} 中删除。\",\n        \"model_remove_confirm\": \"确认删除?\",\n        \"model_remove_cancelled\": \"删除已取消\",\n        \"model_removed\": \"模型 '{name}' 已从注册表中删除\",\n        \"model_remove_failed\": \"删除模型失败: {error}\",\n        \"model_refresh_checking\": \"正在检查模型路径...\",\n        \"model_refresh_all_valid\": \"所有模型都有效! (已检查 {count} 个模型)\",\n        \"model_refresh_total\": \"总模型数: {total}\",\n        \"model_refresh_missing_found\": \"发现 {count} 个缺失的模型\",\n        \"model_refresh_suggestions\": \"建议操作\",\n        \"model_refresh_remove_hint\": \"从注册表中删除: kt model remove <name>\",\n        \"model_refresh_rescan_hint\": \"重新扫描模型: kt model scan\",\n        \"model_verify_not_found\": \"未找到模型 '{name}'。\",\n        \"model_verify_list_hint\": \"列出模型: kt model list\",\n        \"model_verify_no_repo\": \"模型 '{name}' 未配置仓库信息。\",\n        \"model_verify_config_hint\": \"配置仓库: kt model edit {name}\",\n        \"model_verify_path_missing\": \"模型路径不存在: {path}\",\n        \"model_verify_starting\": \"正在验证模型完整性...\",\n        \"model_verify_progress\": \"仓库: {repo_type} → {repo_id}\",\n        \"model_verify_not_implemented\": \"SHA256 验证尚未实现\",\n        \"model_verify_future_note\": \"此功能将从 {repo_type} 获取官方 SHA256 哈希值并与本地文件进行比较。\",\n        \"model_verify_passed\": \"验证通过！所有文件都与官方哈希匹配。\",\n        \"model_verify_failed\": \"验证失败！{count} 个文件的哈希不匹配。\",\n        \"model_verify_all_no_repos\": \"没有模型配置了仓库信息。\",\n        \"model_verify_all_config_hint\": \"配置仓库使用: kt model edit <name>\",\n        \"model_verify_all_found\": \"发现 {count} 个配置了仓库信息的模型\",\n        \"model_verify_all_manual_hint\": \"验证特定模型: kt model verify <name>\",\n        # Coming soon\n        \"feature_coming_soon\": \"此功能即将推出...\",\n    },\n}\n\n\n# Cache for language detection to avoid repeated I/O\n_lang_cache: str | None = None\n\n\ndef get_lang() -> str:\n    \"\"\"\n    Detect the current language 
setting.\n\n    Priority:\n    1. KT_LANG environment variable\n    2. Config file general.language setting\n    3. LANG environment variable (if config is \"auto\")\n    4. Default to English\n\n    Returns:\n        Language code: \"zh\" for Chinese, \"en\" for English\n    \"\"\"\n    global _lang_cache\n\n    # 1. Check KT_LANG environment variable (highest priority)\n    kt_lang = os.environ.get(\"KT_LANG\", \"\").lower()\n    if kt_lang:\n        return \"zh\" if kt_lang.startswith(\"zh\") else \"en\"\n\n    # 2. Return cached value if available (avoids I/O on every call)\n    if _lang_cache is not None:\n        return _lang_cache\n\n    # 3. Check config file setting (with caching)\n    # Import here to avoid circular imports\n    from kt_kernel.cli.config.settings import get_settings\n\n    try:\n        settings = get_settings()\n        config_lang = settings.get(\"general.language\", \"auto\")\n        if config_lang and config_lang != \"auto\":\n            lang = \"zh\" if config_lang.lower().startswith(\"zh\") else \"en\"\n            _lang_cache = lang\n            return lang\n    except Exception:\n        # If settings fail to load, continue with system detection\n        pass\n\n    # 4. 
Check system LANG environment variable\n    system_lang = os.environ.get(\"LANG\", \"\").lower()\n    lang = \"zh\" if system_lang.startswith(\"zh\") else \"en\"\n    _lang_cache = lang\n    return lang\n\n\ndef t(msg_key: str, **kwargs: Any) -> str:\n    \"\"\"\n    Translate a message key to the current language.\n\n    Args:\n        msg_key: Message key to translate\n        **kwargs: Format arguments for the message\n\n    Returns:\n        Translated and formatted message string\n\n    Example:\n        >>> t(\"welcome\")\n        \"Welcome to KTransformers!\"  # or \"欢迎使用 KTransformers！\" in Chinese\n\n        >>> t(\"install_found\", name=\"conda\", version=\"24.1.0\")\n        \"Found conda (version 24.1.0)\"\n    \"\"\"\n    lang = get_lang()\n    messages = MESSAGES.get(lang, MESSAGES[\"en\"])\n    message = messages.get(msg_key, MESSAGES[\"en\"].get(msg_key, msg_key))\n\n    if kwargs:\n        try:\n            return message.format(**kwargs)\n        except KeyError:\n            return message\n    return message\n\n\ndef set_lang(lang: str) -> None:\n    \"\"\"\n    Set the language for the current session.\n\n    Args:\n        lang: Language code (\"en\" or \"zh\")\n    \"\"\"\n    global _lang_cache\n    os.environ[\"KT_LANG\"] = lang\n    _lang_cache = lang  # Update cache when language is explicitly set\n"
  },
  {
    "path": "kt-kernel/python/cli/main.py",
    "content": "\"\"\"\nMain entry point for kt-cli.\n\nKTransformers CLI - A unified command-line interface for KTransformers.\n\"\"\"\n\nimport sys\nimport warnings\n\n# Suppress numpy subnormal warnings\nwarnings.filterwarnings(\"ignore\", message=\"The value of the smallest subnormal\")\n\nimport typer\n\nfrom kt_kernel.cli import __version__\nfrom kt_kernel.cli.commands import bench, chat, config, doctor, model, quant, run, sft, version\nfrom kt_kernel.cli.i18n import t, set_lang, get_lang\n\n\ndef _get_app_help() -> str:\n    \"\"\"Get app help text based on current language.\"\"\"\n    lang = get_lang()\n    if lang == \"zh\":\n        return \"KTransformers CLI - KTransformers 统一命令行界面\"\n    return \"KTransformers CLI - A unified command-line interface for KTransformers.\"\n\n\ndef _get_help(key: str) -> str:\n    \"\"\"Get help text based on current language.\"\"\"\n    help_texts = {\n        \"version\": {\"en\": \"Show version information\", \"zh\": \"显示版本信息\"},\n        \"run\": {\"en\": \"Start model inference server\", \"zh\": \"启动模型推理服务器\"},\n        \"chat\": {\"en\": \"Interactive chat with running model\", \"zh\": \"与运行中的模型进行交互式聊天\"},\n        \"quant\": {\"en\": \"Quantize model weights\", \"zh\": \"量化模型权重\"},\n        \"edit\": {\"en\": \"Edit model information\", \"zh\": \"编辑模型信息\"},\n        \"bench\": {\"en\": \"Run full benchmark\", \"zh\": \"运行完整基准测试\"},\n        \"microbench\": {\"en\": \"Run micro-benchmark\", \"zh\": \"运行微基准测试\"},\n        \"doctor\": {\"en\": \"Diagnose environment issues\", \"zh\": \"诊断环境问题\"},\n        \"model\": {\"en\": \"Manage models and storage paths\", \"zh\": \"管理模型和存储路径\"},\n        \"config\": {\"en\": \"Manage configuration\", \"zh\": \"管理配置\"},\n        \"sft\": {\"en\": \"Fine-tuning with LlamaFactory\", \"zh\": \"使用 LlamaFactory 进行微调\"},\n    }\n    lang = get_lang()\n    return help_texts.get(key, {}).get(lang, help_texts.get(key, {}).get(\"en\", key))\n\n\n# Create main app with dynamic help\napp = 
typer.Typer(\n    name=\"kt\",\n    help=\"KTransformers CLI - A unified command-line interface for KTransformers.\",\n    no_args_is_help=False,  # Handle no-args case manually to support first-run setup\n    add_completion=False,  # Use static completion scripts instead of dynamic completion\n    rich_markup_mode=\"rich\",\n)\n\n\ndef _update_help_texts() -> None:\n    \"\"\"Update all help texts based on current language setting.\"\"\"\n    # Update main app help\n    app.info.help = _get_app_help()\n\n    # Update command help texts\n    for cmd_info in app.registered_commands:\n        # cmd_info is a CommandInfo object\n        if hasattr(cmd_info, \"name\") and cmd_info.name:\n            cmd_info.help = _get_help(cmd_info.name)\n\n    # Update sub-app help texts\n    for group_info in app.registered_groups:\n        if hasattr(group_info, \"name\") and group_info.name:\n            group_info.help = _get_help(group_info.name)\n\n\n# Commands are registered later after tui_command is defined\n\n\ndef check_first_run() -> None:\n    \"\"\"Check if this is the first run and prompt for language setup.\"\"\"\n    import os\n\n    # Skip if not running in interactive terminal\n    if not sys.stdin.isatty():\n        return\n\n    from kt_kernel.cli.config.settings import DEFAULT_CONFIG_FILE\n\n    # Only check if config file exists - don't create it yet\n    if not DEFAULT_CONFIG_FILE.exists():\n        # First run - show welcome and language selection\n        from kt_kernel.cli.config.settings import get_settings\n\n        settings = get_settings()\n        _show_first_run_setup(settings)\n    else:\n        # Config exists - check if initialized\n        from kt_kernel.cli.config.settings import get_settings\n\n        settings = get_settings()\n        if not settings.get(\"general._initialized\"):\n            _show_first_run_setup(settings)\n\n\ndef _show_first_run_setup(settings) -> None:\n    \"\"\"Show first-run setup wizard.\"\"\"\n    from 
rich.console import Console\n    from rich.panel import Panel\n    from rich.prompt import Prompt, Confirm\n    from rich.spinner import Spinner\n    from rich.live import Live\n\n    from kt_kernel.cli.utils.environment import scan_storage_locations, format_size_gb\n\n    console = Console()\n\n    # Welcome message\n    console.print()\n    console.print(\n        Panel.fit(\n            \"[bold cyan]Welcome to KTransformers CLI! / 欢迎使用 KTransformers CLI![/bold cyan]\\n\\n\"\n            \"Let's set up your preferences.\\n\"\n            \"让我们设置您的偏好。\",\n            title=\"kt-cli\",\n            border_style=\"cyan\",\n        )\n    )\n    console.print()\n\n    # Language selection\n    console.print(\"[bold]Select your preferred language / 选择您的首选语言:[/bold]\")\n    console.print()\n    console.print(\"  [cyan][1][/cyan] English\")\n    console.print(\"  [cyan][2][/cyan] 中文 (Chinese)\")\n    console.print()\n\n    choice = Prompt.ask(\"Enter choice / 输入选择\", choices=[\"1\", \"2\"], default=\"1\")\n    lang = \"en\" if choice == \"1\" else \"zh\"\n\n    # Save language setting\n    settings.set(\"general.language\", lang)\n    set_lang(lang)\n\n    # Confirmation message\n    console.print()\n    if lang == \"zh\":\n        console.print(\"[green]✓[/green] 语言已设置为中文\")\n    else:\n        console.print(\"[green]✓[/green] Language set to English\")\n\n    # Model discovery section\n    console.print()\n    if lang == \"zh\":\n        console.print(\"[bold]发现模型权重[/bold]\")\n        console.print()\n        console.print(\"[dim]扫描系统中已有的模型权重文件，以便快速添加到模型列表。[/dim]\")\n        console.print()\n        console.print(\"  [cyan][1][/cyan] 全局扫描 (自动扫描所有非系统路径)\")\n        console.print(\"  [cyan][2][/cyan] 手动指定路径 (可添加多个)\")\n        console.print(\"  [cyan][3][/cyan] 跳过 (稍后手动添加)\")\n        console.print()\n        scan_choice = Prompt.ask(\"选择扫描方式\", choices=[\"1\", \"2\", \"3\"], default=\"1\")\n    else:\n        console.print(\"[bold]Discover Model Weights[/bold]\")\n     
   console.print()\n        console.print(\"[dim]Scan existing model weights on your system to quickly add them to the model list.[/dim]\")\n        console.print()\n        console.print(\"  [cyan][1][/cyan] Global scan (auto-scan all non-system paths)\")\n        console.print(\"  [cyan][2][/cyan] Manual paths (add multiple paths)\")\n        console.print(\"  [cyan][3][/cyan] Skip (add manually later)\")\n        console.print()\n        scan_choice = Prompt.ask(\"Select scan method\", choices=[\"1\", \"2\", \"3\"], default=\"1\")\n\n    if scan_choice == \"1\":\n        # Global scan\n        from kt_kernel.cli.utils.model_discovery import discover_and_register_global, format_discovery_summary\n\n        console.print()\n        try:\n            total_found, new_found, registered = discover_and_register_global(\n                min_size_gb=2.0, max_depth=6, show_progress=True, lang=lang\n            )\n\n            format_discovery_summary(\n                total_found=total_found,\n                new_found=new_found,\n                registered=registered,\n                lang=lang,\n                show_models=True,\n                max_show=10,\n            )\n\n        except Exception as e:\n            console.print(f\"[yellow]Warning: Scan failed - {e}[/yellow]\")\n\n    elif scan_choice == \"2\":\n        # Manual path specification\n        from kt_kernel.cli.utils.model_discovery import discover_and_register_path\n        import os\n\n        discovered_paths = set()  # Track paths discovered in this session\n        total_registered = []\n\n        while True:\n            console.print()\n            if lang == \"zh\":\n                path = Prompt.ask(\"输入要扫描的路径 (例如: /mnt/data/models)\")\n            else:\n                path = Prompt.ask(\"Enter path to scan (e.g., /mnt/data/models)\")\n\n            # Expand and validate path\n            path = os.path.expanduser(path)\n\n            if not os.path.exists(path):\n                if lang 
== \"zh\":\n                    console.print(f\"[yellow]警告: 路径不存在: {path}[/yellow]\")\n                else:\n                    console.print(f\"[yellow]Warning: Path does not exist: {path}[/yellow]\")\n                continue\n\n            if not os.path.isdir(path):\n                if lang == \"zh\":\n                    console.print(f\"[yellow]警告: 不是一个目录: {path}[/yellow]\")\n                else:\n                    console.print(f\"[yellow]Warning: Not a directory: {path}[/yellow]\")\n                continue\n\n            # Scan this path\n            console.print()\n            try:\n                total_found, new_found, registered = discover_and_register_path(\n                    path=path, min_size_gb=2.0, existing_paths=discovered_paths, show_progress=True, lang=lang\n                )\n\n                # Update discovered paths\n                for model in registered:\n                    discovered_paths.add(model.path)\n                total_registered.extend(registered)\n\n                console.print()\n                if lang == \"zh\":\n                    console.print(f\"[green]✓[/green] 在此路径找到 {total_found} 个模型，其中 {new_found} 个为新模型\")\n                else:\n                    console.print(f\"[green]✓[/green] Found {total_found} models in this path, {new_found} are new\")\n\n                if new_found > 0:\n                    for model in registered[:5]:\n                        console.print(f\"  • {model.name} ({model.format})\")\n\n                    if len(registered) > 5:\n                        if lang == \"zh\":\n                            console.print(f\"  [dim]... 还有 {len(registered) - 5} 个新模型[/dim]\")\n                        else:\n                            console.print(f\"  [dim]... 
and {len(registered) - 5} more new models[/dim]\")\n\n            except Exception as e:\n                console.print(f\"[red]Error scanning path: {e}[/red]\")\n\n            # Ask if continue\n            console.print()\n            if lang == \"zh\":\n                continue_scan = Confirm.ask(\"是否继续添加其他路径?\", default=False)\n            else:\n                continue_scan = Confirm.ask(\"Continue adding more paths?\", default=False)\n\n            if not continue_scan:\n                break\n\n        if total_registered:\n            console.print()\n            if lang == \"zh\":\n                console.print(f\"[green]✓[/green] 总共发现 {len(total_registered)} 个新模型\")\n            else:\n                console.print(f\"[green]✓[/green] Total {len(total_registered)} new models discovered\")\n\n    # Model storage path selection\n    console.print()\n    console.print(f\"[bold]{t('setup_model_path_title')}[/bold]\")\n    console.print()\n    console.print(f\"[dim]{t('setup_model_path_desc')}[/dim]\")\n    console.print()\n\n    # Scan for storage locations\n    console.print(f\"[dim]{t('setup_scanning_disks')}[/dim]\")\n    locations = scan_storage_locations(min_size_gb=50.0)\n    console.print()\n\n    if locations:\n        # Show storage location options\n        for i, loc in enumerate(locations[:5], 1):  # Show top 5 options\n            available = format_size_gb(loc.available_gb)\n            total = format_size_gb(loc.total_gb)\n\n            # Build the option string\n            if i == 1:\n                option_str = t(\"setup_disk_option_recommended\", path=loc.path, available=available, total=total)\n            else:\n                option_str = t(\"setup_disk_option\", path=loc.path, available=available, total=total)\n\n            console.print(f\"  [cyan][{i}][/cyan] {option_str}\")\n\n        # Custom path option\n        custom_idx = min(len(locations), 5) + 1\n        console.print(f\"  [cyan][{custom_idx}][/cyan] 
{t('setup_custom_path')}\")\n        console.print()\n\n        valid_choices = [str(i) for i in range(1, custom_idx + 1)]\n        path_choice = Prompt.ask(t(\"prompt_select\"), choices=valid_choices, default=\"1\")\n\n        if path_choice == str(custom_idx):\n            # Custom path\n            selected_path = _prompt_custom_path(console, settings)\n        else:\n            selected_path = locations[int(path_choice) - 1].path\n    else:\n        # No large storage found, ask for custom path\n        console.print(f\"[yellow]{t('setup_no_large_disk')}[/yellow]\")\n        console.print()\n        selected_path = _prompt_custom_path(console, settings)\n\n    # Ensure the path exists\n    import os\n    from pathlib import Path\n\n    if not os.path.exists(selected_path):\n        if Confirm.ask(t(\"setup_path_not_exist\"), default=True):\n            try:\n                Path(selected_path).mkdir(parents=True, exist_ok=True)\n            except (OSError, PermissionError) as e:\n                console.print(f\"[red]{t('error')}: {e}[/red]\")\n                # Fall back to default\n                selected_path = str(Path.home() / \".ktransformers\" / \"models\")\n                Path(selected_path).mkdir(parents=True, exist_ok=True)\n\n    # Check available space and warn if low\n    from kt_kernel.cli.utils.environment import detect_disk_space_gb\n\n    available_gb, _ = detect_disk_space_gb(\n        selected_path if os.path.exists(selected_path) else str(Path(selected_path).parent)\n    )\n    if available_gb < 100:\n        console.print(f\"[yellow]{t('setup_path_low_space')}[/yellow]\")\n\n    # Save the path\n    settings.set(\"paths.models\", selected_path)\n    settings.set(\"general._initialized\", True)\n\n    console.print()\n    console.print(f\"[green]✓[/green] {t('setup_model_path_set', path=selected_path)}\")\n    console.print()\n\n    # Tips\n    if lang == \"zh\":\n        console.print(\"[dim]提示: 运行 'kt config show' 查看所有配置[/dim]\")\n    
else:\n        console.print(\"[dim]Tip: Run 'kt config show' to view all settings[/dim]\")\n\n    console.print()\n\n\ndef _prompt_custom_path(console, settings) -> str:\n    \"\"\"Prompt user to enter a custom path.\"\"\"\n    from rich.prompt import Prompt\n    from pathlib import Path\n    import os\n\n    default_path = str(Path.home() / \".ktransformers\" / \"models\")\n\n    while True:\n        custom_path = Prompt.ask(t(\"setup_enter_custom_path\"), default=default_path)\n\n        # Expand user home\n        custom_path = os.path.expanduser(custom_path)\n\n        # Check if path exists or parent is writable\n        if os.path.exists(custom_path):\n            if os.access(custom_path, os.W_OK):\n                return custom_path\n            else:\n                console.print(f\"[red]{t('setup_path_no_write')}[/red]\")\n        else:\n            # Check if we can create it (parent writable)\n            parent = str(Path(custom_path).parent)\n            while not os.path.exists(parent) and parent != \"/\":\n                parent = str(Path(parent).parent)\n\n            if os.access(parent, os.W_OK):\n                return custom_path\n            else:\n                console.print(f\"[red]{t('setup_path_no_write')}[/red]\")\n\n\ndef _install_shell_completion() -> None:\n    \"\"\"Install shell completion scripts to user directories.\n\n    Uses standard locations that are auto-loaded by shell completion systems:\n    - Bash: ~/.local/share/bash-completion/completions/kt (auto-loaded by bash-completion 2.0+)\n    - Zsh: ~/.zfunc/_kt (requires fpath setup, but commonly used)\n    - Fish: ~/.config/fish/completions/kt.fish (auto-loaded)\n    \"\"\"\n    import os\n    import shutil\n    from pathlib import Path\n\n    from kt_kernel.cli.config.settings import get_settings\n\n    settings = get_settings()\n\n    # Check if already installed\n    if settings.get(\"general._completion_installed\", False):\n        return\n\n    # Detect current 
shell\n    shell = os.environ.get(\"SHELL\", \"\")\n    shell_name = \"zsh\" if \"zsh\" in shell else \"fish\" if \"fish\" in shell else \"bash\"\n\n    try:\n        cli_dir = Path(__file__).parent\n        completions_dir = cli_dir / \"completions\"\n        home = Path.home()\n\n        def install_completion(src_name: str, dest_dir: Path, dest_name: str) -> None:\n            \"\"\"Install completion file from source to destination.\"\"\"\n            src_file = completions_dir / src_name\n            if src_file.exists():\n                dest_dir.mkdir(parents=True, exist_ok=True)\n                shutil.copy2(src_file, dest_dir / dest_name)\n\n        if shell_name == \"bash\":\n            install_completion(\n                \"kt-completion.bash\", home / \".local\" / \"share\" / \"bash-completion\" / \"completions\", \"kt\"\n            )\n        elif shell_name == \"zsh\":\n            install_completion(\"_kt\", home / \".zfunc\", \"_kt\")\n        elif shell_name == \"fish\":\n            install_completion(\"kt.fish\", home / \".config\" / \"fish\" / \"completions\", \"kt.fish\")\n\n        # Mark as installed\n        settings.set(\"general._completion_installed\", True)\n\n        # For bash/zsh, completion will work in new terminals automatically\n        # (bash-completion 2.0+ auto-loads from ~/.local/share/bash-completion/completions/)\n\n    except (OSError, IOError):\n        # Silently ignore errors - completion is not critical\n        pass\n\n\ndef _apply_saved_language() -> None:\n    \"\"\"Apply the saved language setting.\n\n    Priority:\n    1. KT_LANG environment variable (if already set, don't override)\n    2. Config file setting\n    3. 
System locale (auto)\n    \"\"\"\n    import os\n\n    # Don't override if KT_LANG is already set by user\n    if os.environ.get(\"KT_LANG\"):\n        return\n\n    from kt_kernel.cli.config.settings import get_settings\n\n    settings = get_settings()\n    lang = settings.get(\"general.language\", \"auto\")\n\n    if lang != \"auto\":\n        set_lang(lang)\n\n\napp.command(name=\"version\", help=\"Show version information\")(version.version)\napp.command(name=\"chat\", help=\"Interactive chat with running model\")(chat.chat)\napp.command(name=\"quant\", help=\"Quantize model weights\")(quant.quant)\napp.command(name=\"edit\", help=\"Edit model information\")(model.edit_model)\napp.command(name=\"bench\", help=\"Run full benchmark\")(bench.bench)\napp.command(name=\"microbench\", help=\"Run micro-benchmark\")(bench.microbench)\napp.command(name=\"doctor\", help=\"Diagnose environment issues\")(doctor.doctor)\n\n# Register sub-apps\napp.add_typer(model.app, name=\"model\", help=\"Manage models and storage paths\")\napp.add_typer(config.app, name=\"config\", help=\"Manage configuration\")\napp.add_typer(sft.app, name=\"sft\", help=\"Fine-tuning with LlamaFactory\")\n\n\ndef main():\n    \"\"\"Main entry point.\"\"\"\n    # Apply saved language setting first (before anything else for correct help display)\n    _apply_saved_language()\n\n    # Update help texts based on language\n    _update_help_texts()\n\n    # Check for first run (but not for certain commands)\n    # Skip first-run check for: --help, config commands, version\n    args = sys.argv[1:] if len(sys.argv) > 1 else []\n    skip_commands = [\"--help\", \"-h\", \"config\", \"version\", \"--version\", \"--no-tui\"]\n\n    should_check_first_run = True\n    for arg in args:\n        if arg in skip_commands:\n            should_check_first_run = False\n            break\n\n    # Handle no arguments case\n    if not args:\n        # Check if this is first run\n        from kt_kernel.cli.config.settings import 
DEFAULT_CONFIG_FILE, get_settings\n\n        is_first_run = False\n        if not DEFAULT_CONFIG_FILE.exists():\n            is_first_run = True\n        else:\n            settings = get_settings()\n            if not settings.get(\"general._initialized\"):\n                is_first_run = True\n\n        if is_first_run:\n            # First run - start initialization\n            _install_shell_completion()\n            check_first_run()\n            return\n        else:\n            # Not first run - show help\n            app([\"--help\"])\n            return\n\n    # Auto-install shell completion on first run\n    if should_check_first_run:\n        _install_shell_completion()\n\n    # Check first run before running commands\n    if should_check_first_run:\n        check_first_run()\n\n    # Handle \"run\" command specially to pass through unknown options\n    if args and args[0] == \"run\":\n        # Get args after \"run\"\n        run_args = args[1:]\n        # Use click command directly with ignore_unknown_options\n        from kt_kernel.cli.commands import run as run_module\n\n        sys.exit(run_module.run.main(args=run_args, standalone_mode=False))\n\n    app()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "kt-kernel/python/cli/requirements/inference.txt",
    "content": "# Inference dependencies for KTransformers\n# NOTE: sglang is installed separately from source (see install.py)\n\ntransformers>=4.45.0\nsafetensors>=0.4.0\nhuggingface-hub>=0.20.0\n"
  },
  {
    "path": "kt-kernel/python/cli/requirements/sft.txt",
    "content": "# SFT (Supervised Fine-Tuning) dependencies for KTransformers\n\nllamafactory>=0.9.0\npeft>=0.12.0\ntransformers>=4.45.0\ndatasets>=2.14.0\naccelerate>=0.30.0\n"
  },
  {
    "path": "kt-kernel/python/cli/utils/__init__.py",
    "content": "\"\"\"\nUtility modules for kt-cli.\n\"\"\"\n"
  },
  {
    "path": "kt-kernel/python/cli/utils/analyze_moe_model.py",
    "content": "#!/usr/bin/env python3\n\"\"\"\n快速分析 MoE 模型 - 基于 config.json\n(复用 sglang 的模型注册表和判断逻辑)\n\"\"\"\nimport json\nimport hashlib\nfrom pathlib import Path\nfrom typing import Optional, Dict, Any\n\n\ndef _get_sglang_moe_architectures():\n    \"\"\"\n    从 sglang 的模型注册表获取所有 MoE 架构\n\n    复用 sglang 的代码，这样 sglang 更新后自动支持新模型\n    \"\"\"\n    try:\n        import sys\n\n        # 添加 sglang 路径到 sys.path\n        sglang_path = Path(\"/mnt/data2/ljq/sglang/python\")\n        if sglang_path.exists() and str(sglang_path) not in sys.path:\n            sys.path.insert(0, str(sglang_path))\n\n        # 直接导入 sglang 的 ModelRegistry\n        # 注意：这需要 sglang 及其依赖正确安装\n        from sglang.srt.models.registry import ModelRegistry\n\n        # 获取所有支持的架构\n        supported_archs = ModelRegistry.get_supported_archs()\n\n        # 过滤出 MoE 模型（名称包含 Moe）\n        moe_archs = {arch for arch in supported_archs if \"Moe\" in arch or \"moe\" in arch.lower()}\n\n        # 手动添加一些不带 \"Moe\" 字样但是 MoE 模型的架构\n        # DeepSeek V2/V3 系列\n        deepseek_moe = {arch for arch in supported_archs if arch.startswith(\"Deepseek\") or arch.startswith(\"deepseek\")}\n        moe_archs.update(deepseek_moe)\n\n        # DBRX 也是 MoE 模型\n        dbrx_moe = {arch for arch in supported_archs if \"DBRX\" in arch or \"dbrx\" in arch.lower()}\n        moe_archs.update(dbrx_moe)\n\n        # Grok 也是 MoE 模型\n        grok_moe = {arch for arch in supported_archs if \"Grok\" in arch or \"grok\" in arch.lower()}\n        moe_archs.update(grok_moe)\n\n        return moe_archs\n    except Exception as e:\n        # 如果 sglang 不可用，返回空集合\n        # 这种情况下，后续会使用配置文件中的其他判断方法\n        import warnings\n\n        warnings.warn(f\"Failed to load MoE architectures from sglang: {e}. 
Using fallback detection methods.\")\n        return set()\n\n\n# Get the list of MoE architectures (preferably from sglang)\nMOE_ARCHITECTURES = _get_sglang_moe_architectures()\n\n\ndef _get_cache_file():\n    \"\"\"Get the path of the centralized cache file.\"\"\"\n    cache_dir = Path.home() / \".ktransformers\" / \"cache\"\n    cache_dir.mkdir(parents=True, exist_ok=True)\n    return cache_dir / \"moe_analysis_v2.json\"\n\n\ndef _load_all_cache():\n    \"\"\"Load all cached data.\"\"\"\n    cache_file = _get_cache_file()\n    if not cache_file.exists():\n        return {}\n\n    try:\n        with open(cache_file, \"r\") as f:\n            return json.load(f)\n    except Exception:\n        return {}\n\n\ndef _save_all_cache(cache_data):\n    \"\"\"Save all cached data.\"\"\"\n    cache_file = _get_cache_file()\n    try:\n        with open(cache_file, \"w\") as f:\n            json.dump(cache_data, f, indent=2)\n    except Exception as e:\n        import warnings\n\n        warnings.warn(f\"Failed to save MoE cache: {e}\")\n\n\ndef _compute_config_fingerprint(config_path: Path) -> Optional[str]:\n    \"\"\"Compute a fingerprint for config.json.\"\"\"\n    if not config_path.exists():\n        return None\n\n    try:\n        stat = config_path.stat()\n        # Use the file size and modification time as the fingerprint\n        fingerprint_str = f\"{config_path.name}:{stat.st_size}:{int(stat.st_mtime)}\"\n        return hashlib.md5(fingerprint_str.encode()).hexdigest()\n    except Exception:\n        return None\n\n\ndef _load_cache(model_path: Path) -> Optional[Dict[str, Any]]:\n    \"\"\"Load the cache entry for the given model.\"\"\"\n    model_path_str = str(model_path.resolve())\n    all_cache = _load_all_cache()\n\n    if model_path_str not in all_cache:\n        return None\n\n    try:\n        cache_entry = all_cache[model_path_str]\n\n        # Validate the cache version\n        cache_version = cache_entry.get(\"cache_version\", 0)\n        if cache_version != 2:\n            return None\n\n        # Validate the config.json fingerprint\n        config_path = model_path / \"config.json\"\n        current_fingerprint = _compute_config_fingerprint(config_path)\n        if 
cache_entry.get(\"fingerprint\") != current_fingerprint:\n            return None\n\n        return cache_entry.get(\"result\")\n    except Exception:\n        return None\n\n\ndef _save_cache(model_path: Path, result: Dict[str, Any]):\n    \"\"\"Save the cache entry for the given model.\"\"\"\n    model_path_str = str(model_path.resolve())\n\n    try:\n        config_path = model_path / \"config.json\"\n        fingerprint = _compute_config_fingerprint(config_path)\n\n        all_cache = _load_all_cache()\n\n        all_cache[model_path_str] = {\n            \"fingerprint\": fingerprint,\n            \"result\": result,\n            \"cache_version\": 2,\n            \"last_updated\": __import__(\"datetime\").datetime.now().isoformat(),\n        }\n\n        _save_all_cache(all_cache)\n    except Exception as e:\n        import warnings\n\n        warnings.warn(f\"Failed to save MoE cache for {model_path}: {e}\")\n\n\ndef _load_config_json(model_path: Path) -> Optional[Dict[str, Any]]:\n    \"\"\"Read the config.json file.\n\n    Modeled on sglang's get_config() implementation.\n    \"\"\"\n    config_path = model_path / \"config.json\"\n\n    if not config_path.exists():\n        return None\n\n    try:\n        with open(config_path, \"r\", encoding=\"utf-8\") as f:\n            config = json.load(f)\n        return config\n    except Exception:\n        return None\n\n\ndef _is_moe_model(config: Dict[str, Any]) -> bool:\n    \"\"\"Determine whether the model is an MoE model.\n\n    Modeled on sglang's model registry and architecture detection.\n    \"\"\"\n    # Method 1: check the architecture names\n    architectures = config.get(\"architectures\", [])\n    if any(arch in MOE_ARCHITECTURES for arch in architectures):\n        return True\n\n    # Method 2: check for MoE-related fields (Mistral format)\n    if config.get(\"moe\"):\n        return True\n\n    # Method 3: check for num_experts or one of its variants\n    # text_config must be checked too (for some multimodal models)\n    text_config = config.get(\"text_config\", config)\n\n    # Check the various expert-count fields\n    if (\n        text_config.get(\"num_experts\") or text_config.get(\"num_local_experts\") or text_config.get(\"n_routed_experts\")\n    ):  # 
Kimi-K2 uses this field\n        return True\n\n    return False\n\n\ndef _extract_moe_params(config: Dict[str, Any]) -> Dict[str, Any]:\n    \"\"\"Extract MoE parameters from config.\n\n    Modeled on sglang's various MoE model implementations.\n    \"\"\"\n    # Handle nested text_config\n    text_config = config.get(\"text_config\", config)\n\n    # Extract basic parameters\n    result = {\n        \"architectures\": config.get(\"architectures\", []),\n        \"model_type\": config.get(\"model_type\", \"unknown\"),\n    }\n\n    # Number of experts (field name varies by model)\n    num_experts = (\n        text_config.get(\"num_experts\")  # Qwen2/3 MoE, DeepSeek V2\n        or text_config.get(\"num_local_experts\")  # Mixtral\n        or text_config.get(\"n_routed_experts\")  # Kimi-K2, DeepSeek V3\n        or config.get(\"moe\", {}).get(\"num_experts\")  # Mistral format\n    )\n\n    # Number of experts activated per token\n    num_experts_per_tok = (\n        text_config.get(\"num_experts_per_tok\")\n        or text_config.get(\"num_experts_per_token\")\n        or config.get(\"moe\", {}).get(\"num_experts_per_tok\")\n        or 2  # default value\n    )\n\n    # Number of layers\n    num_hidden_layers = text_config.get(\"num_hidden_layers\") or text_config.get(\"n_layer\") or 0\n\n    # Hidden size\n    hidden_size = text_config.get(\"hidden_size\") or text_config.get(\"d_model\") or 0\n\n    # MoE expert intermediate size\n    moe_intermediate_size = (\n        text_config.get(\"moe_intermediate_size\")\n        or text_config.get(\"intermediate_size\")  # fallback when there is no dedicated moe_intermediate_size\n        or 0\n    )\n\n    # Shared expert intermediate size (Qwen2/3 MoE)\n    shared_expert_intermediate_size = text_config.get(\"shared_expert_intermediate_size\", 0)\n\n    result.update(\n        {\n            \"num_experts\": num_experts or 0,\n            \"num_experts_per_tok\": num_experts_per_tok,\n            \"num_hidden_layers\": num_hidden_layers,\n            \"hidden_size\": hidden_size,\n            \"moe_intermediate_size\": moe_intermediate_size,\n            \"shared_expert_intermediate_size\": shared_expert_intermediate_size,\n        }\n    )\n\n    # 
Extract other useful parameters\n    result[\"num_attention_heads\"] = text_config.get(\"num_attention_heads\", 0)\n    result[\"num_key_value_heads\"] = text_config.get(\"num_key_value_heads\", 0)\n    result[\"vocab_size\"] = text_config.get(\"vocab_size\", 0)\n    result[\"max_position_embeddings\"] = text_config.get(\"max_position_embeddings\", 0)\n\n    return result\n\n\ndef _estimate_model_size(model_path: Path) -> float:\n    \"\"\"Estimate the total model size (GB).\n\n    Quickly sums the sizes of the safetensors files.\n    \"\"\"\n    try:\n        total_size = 0\n        for file_path in model_path.glob(\"*.safetensors\"):\n            total_size += file_path.stat().st_size\n        return total_size / (1024**3)\n    except Exception:\n        return 0.0\n\n\ndef analyze_moe_model(model_path, use_cache=True):\n    \"\"\"\n    Quick analysis of an MoE model - reads only config.json\n\n    Args:\n        model_path: Model path (str or Path object)\n        use_cache: Whether to use the cache (default True)\n\n    Returns:\n        dict: {\n            'is_moe': whether this is an MoE model,\n            'num_experts': total number of experts,\n            'num_experts_per_tok': number of experts activated per token,\n            'num_hidden_layers': number of layers,\n            'hidden_size': hidden size,\n            'moe_intermediate_size': MoE expert intermediate size,\n            'shared_expert_intermediate_size': shared expert intermediate size,\n            'architectures': list of model architectures,\n            'model_type': model type,\n            'total_size_gb': total model size (estimated, GB),\n            'cached': whether the result was read from cache\n        }\n        Returns None if the model is not an MoE model or analysis fails\n    \"\"\"\n    model_path = Path(model_path)\n\n    if not model_path.exists():\n        return None\n\n    # Try to load from cache\n    if use_cache:\n        cached_result = _load_cache(model_path)\n        if cached_result:\n            cached_result[\"cached\"] = True\n            return cached_result\n\n    # Read config.json\n    config = _load_config_json(model_path)\n    if not config:\n        return None\n\n    # Check whether this is an MoE model\n    if not _is_moe_model(config):\n        return None\n\n    # Extract MoE parameters\n    params = _extract_moe_params(config)\n\n    # Validate required parameters\n    if 
params[\"num_experts\"] == 0:\n        return None\n\n    # Estimate the model size\n    total_size_gb = _estimate_model_size(model_path)\n\n    # Assemble the result\n    result = {\n        \"is_moe\": True,\n        \"num_experts\": params[\"num_experts\"],\n        \"num_experts_per_tok\": params[\"num_experts_per_tok\"],\n        \"num_hidden_layers\": params[\"num_hidden_layers\"],\n        \"hidden_size\": params[\"hidden_size\"],\n        \"moe_intermediate_size\": params[\"moe_intermediate_size\"],\n        \"shared_expert_intermediate_size\": params[\"shared_expert_intermediate_size\"],\n        \"architectures\": params[\"architectures\"],\n        \"model_type\": params[\"model_type\"],\n        \"total_size_gb\": total_size_gb,\n        \"cached\": False,\n        # Extra parameters\n        \"num_attention_heads\": params.get(\"num_attention_heads\", 0),\n        \"num_key_value_heads\": params.get(\"num_key_value_heads\", 0),\n        \"vocab_size\": params.get(\"vocab_size\", 0),\n    }\n\n    # Save to cache\n    if use_cache:\n        _save_cache(model_path, result)\n\n    return result\n\n\ndef print_analysis(model_path):\n    \"\"\"Print the model analysis result.\"\"\"\n    print(f\"Analyzing model: {model_path}\\n\")\n\n    result = analyze_moe_model(model_path)\n\n    if result is None:\n        print(\"Not an MoE model, or analysis failed\")\n        return\n\n    print(\"=\" * 70)\n    print(\"MoE model analysis result\")\n    if result.get(\"cached\"):\n        print(\"[cached]\")\n    print(\"=\" * 70)\n    print(f\"Model architecture:\")\n    print(f\"  - Architectures: {', '.join(result['architectures'])}\")\n    print(f\"  - Type: {result['model_type']}\")\n    print()\n    print(f\"MoE structure:\")\n    print(f\"  - Total experts: {result['num_experts']}\")\n    print(f\"  - Active experts: {result['num_experts_per_tok']} experts/token\")\n    print(f\"  - Layers: {result['num_hidden_layers']}\")\n    print(f\"  - Hidden size: {result['hidden_size']}\")\n    print(f\"  - MoE intermediate size: {result['moe_intermediate_size']}\")\n    if result[\"shared_expert_intermediate_size\"] > 0:\n        print(f\"  - Shared expert intermediate size: 
{result['shared_expert_intermediate_size']}\")\n    print()\n    print(f\"Size statistics:\")\n    print(f\"  - Total model size: {result['total_size_gb']:.2f} GB\")\n    print(\"=\" * 70)\n    print()\n\n\ndef main():\n    import sys\n\n    models = [\"/mnt/data2/models/Qwen3-30B-A3B\", \"/mnt/data2/models/Qwen3-235B-A22B-Instruct-2507\"]\n\n    if len(sys.argv) > 1:\n        models = [sys.argv[1]]\n\n    for model_path in models:\n        print_analysis(model_path)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "kt-kernel/python/cli/utils/console.py",
    "content": "\"\"\"\nConsole utilities for kt-cli.\n\nProvides Rich-based console output helpers for consistent formatting.\n\"\"\"\n\nfrom typing import Optional\n\nfrom rich.console import Console\nfrom rich.panel import Panel\nfrom rich.progress import (\n    BarColumn,\n    DownloadColumn,\n    Progress,\n    SpinnerColumn,\n    TaskProgressColumn,\n    TextColumn,\n    TimeElapsedColumn,\n    TimeRemainingColumn,\n    TransferSpeedColumn,\n)\nfrom rich.prompt import Confirm, Prompt\nfrom rich.table import Table\nfrom rich.theme import Theme\n\nfrom kt_kernel.cli.i18n import t\n\n# Custom theme for kt-cli\nKT_THEME = Theme(\n    {\n        \"info\": \"cyan\",\n        \"warning\": \"yellow\",\n        \"error\": \"bold red\",\n        \"success\": \"bold green\",\n        \"highlight\": \"bold magenta\",\n        \"muted\": \"dim\",\n    }\n)\n\n# Global console instance\nconsole = Console(theme=KT_THEME)\n\n\ndef print_info(message: str, **kwargs) -> None:\n    \"\"\"Print an info message.\"\"\"\n    console.print(f\"[info]ℹ[/info] {message}\", **kwargs)\n\n\ndef print_success(message: str, **kwargs) -> None:\n    \"\"\"Print a success message.\"\"\"\n    console.print(f\"[success]✓[/success] {message}\", **kwargs)\n\n\ndef print_warning(message: str, **kwargs) -> None:\n    \"\"\"Print a warning message.\"\"\"\n    console.print(f\"[warning]⚠[/warning] {message}\", **kwargs)\n\n\ndef print_error(message: str, **kwargs) -> None:\n    \"\"\"Print an error message.\"\"\"\n    console.print(f\"[error]✗[/error] {message}\", **kwargs)\n\n\ndef print_step(message: str, **kwargs) -> None:\n    \"\"\"Print a step indicator.\"\"\"\n    console.print(f\"[highlight]→[/highlight] {message}\", **kwargs)\n\n\ndef print_header(title: str, subtitle: Optional[str] = None) -> None:\n    \"\"\"Print a header panel.\"\"\"\n    content = f\"[bold]{title}[/bold]\"\n    if subtitle:\n        content += f\"\\n[muted]{subtitle}[/muted]\"\n    console.print(Panel(content, 
expand=False))\n\n\ndef print_version_table(versions: dict[str, Optional[str]]) -> None:\n    \"\"\"Print a version information table.\"\"\"\n    table = Table(show_header=False, box=None, padding=(0, 2))\n    table.add_column(\"Component\", style=\"bold\")\n    table.add_column(\"Version\")\n\n    for name, version in versions.items():\n        if version:\n            table.add_row(name, f\"[success]{version}[/success]\")\n        else:\n            table.add_row(name, f\"[muted]{t('version_not_installed')}[/muted]\")\n\n    console.print(table)\n\n\ndef print_dependency_table(deps: list[dict]) -> None:\n    \"\"\"Print a dependency status table.\"\"\"\n    table = Table(title=t(\"install_checking_deps\"))\n    table.add_column(t(\"version_info\"), style=\"bold\")\n    table.add_column(\"Current\")\n    table.add_column(\"Required\")\n    table.add_column(\"Status\")\n\n    for dep in deps:\n        status = dep.get(\"status\", \"ok\")\n        if status == \"ok\":\n            status_str = f\"[success]{t('install_dep_ok')}[/success]\"\n        elif status == \"outdated\":\n            status_str = f\"[warning]{t('install_dep_outdated')}[/warning]\"\n        else:\n            status_str = f\"[error]{t('install_dep_missing')}[/error]\"\n\n        table.add_row(\n            dep[\"name\"],\n            dep.get(\"installed\", \"-\"),\n            dep.get(\"required\", \"-\"),\n            status_str,\n        )\n\n    console.print(table)\n\n\ndef confirm(message: str, default: bool = True) -> bool:\n    \"\"\"Ask for confirmation.\"\"\"\n    return Confirm.ask(message, default=default, console=console)\n\n\ndef prompt_choice(message: str, choices: list[str], default: Optional[str] = None) -> str:\n    \"\"\"Prompt for a choice from a list.\"\"\"\n    # Display numbered choices\n    console.print(f\"\\n[bold]{message}[/bold]\")\n    for i, choice in enumerate(choices, 1):\n        console.print(f\"  [highlight][{i}][/highlight] {choice}\")\n\n    while True:\n      
  response = Prompt.ask(\n            \"\\n\" + t(\"prompt_select\"),\n            console=console,\n            default=str(choices.index(default) + 1) if default else None,\n        )\n        try:\n            idx = int(response) - 1\n            if 0 <= idx < len(choices):\n                return choices[idx]\n        except ValueError:\n            # Check if response matches a choice directly\n            if response in choices:\n                return response\n\n        print_error(f\"Please enter a number between 1 and {len(choices)}\")\n\n\ndef prompt_text(message: str, default: Optional[str] = None) -> str:\n    \"\"\"Prompt for text input.\"\"\"\n    return Prompt.ask(message, console=console, default=default)\n\n\ndef create_progress() -> Progress:\n    \"\"\"Create a progress bar for general tasks.\"\"\"\n    return Progress(\n        SpinnerColumn(),\n        TextColumn(\"[progress.description]{task.description}\"),\n        BarColumn(),\n        TaskProgressColumn(),\n        TimeElapsedColumn(),\n        console=console,\n    )\n\n\ndef create_download_progress() -> Progress:\n    \"\"\"Create a progress bar for downloads.\"\"\"\n    return Progress(\n        SpinnerColumn(),\n        TextColumn(\"[progress.description]{task.description}\"),\n        BarColumn(),\n        DownloadColumn(),\n        TransferSpeedColumn(),\n        TimeRemainingColumn(),\n        console=console,\n    )\n\n\ndef print_model_table(models: list[dict]) -> None:\n    \"\"\"Print a table of models.\"\"\"\n    table = Table(title=t(\"download_list_title\"))\n    table.add_column(\"Name\", style=\"bold\")\n    table.add_column(\"Repository\")\n    table.add_column(\"Type\")\n    table.add_column(\"Requirements\")\n\n    for model in models:\n        reqs = []\n        if model.get(\"gpu_vram_gb\"):\n            reqs.append(f\"GPU: {model['gpu_vram_gb']}GB\")\n        if model.get(\"cpu_ram_gb\"):\n            reqs.append(f\"RAM: {model['cpu_ram_gb']}GB\")\n\n        
table.add_row(\n            model.get(\"name\", \"\"),\n            model.get(\"hf_repo\", \"\"),\n            model.get(\"type\", \"\"),\n            \", \".join(reqs) if reqs else \"-\",\n        )\n\n    console.print(table)\n\n\ndef print_hardware_info(gpu_info: str, cpu_info: str, ram_info: str) -> None:\n    \"\"\"Print hardware information.\"\"\"\n    table = Table(show_header=False, box=None)\n    table.add_column(\"Icon\", width=3)\n    table.add_column(\"Info\")\n\n    table.add_row(\"🖥️\", gpu_info)\n    table.add_row(\"💻\", cpu_info)\n    table.add_row(\"🧠\", ram_info)\n\n    console.print(Panel(table, title=\"Hardware\", expand=False))\n\n\ndef print_server_info(\n    mode: str, host: str, port: int, gpu_experts: int, cpu_threads: int\n) -> None:\n    \"\"\"Print server startup information.\"\"\"\n    table = Table(show_header=False, box=None)\n    table.add_column(\"Key\", style=\"bold\")\n    table.add_column(\"Value\")\n\n    table.add_row(t(\"run_server_mode\").split(\":\")[0], mode)\n    table.add_row(\"Host\", host)\n    table.add_row(\"Port\", str(port))\n    table.add_row(t(\"run_gpu_experts\").split(\":\")[0], f\"{gpu_experts}/layer\")\n    table.add_row(t(\"run_cpu_threads\").split(\":\")[0], str(cpu_threads))\n\n    console.print(Panel(table, title=t(\"run_server_started\"), expand=False, border_style=\"green\"))\n\n\ndef print_api_info(host: str, port: int) -> None:\n    \"\"\"Print API endpoint information.\"\"\"\n    api_url = f\"http://{host}:{port}\"\n    docs_url = f\"http://{host}:{port}/docs\"\n\n    console.print()\n    console.print(f\"  {t('run_api_url', host=host, port=port)}\")\n    console.print(f\"  {t('run_docs_url', host=host, port=port)}\")\n    console.print()\n    console.print(f\"  [muted]Test command:[/muted]\")\n    console.print(\n        f\"  [dim]curl {api_url}/v1/chat/completions -H 'Content-Type: application/json' \"\n        f\"-d '{{\\\"model\\\": \\\"default\\\", \\\"messages\\\": [{{\\\"role\\\": \\\"user\\\", 
\\\"content\\\": \\\"Hello\\\"}}]}}'[/dim]\"\n    )\n    console.print()\n    console.print(f\"  [muted]{t('run_stop_hint')}[/muted]\")\n"
  },
  {
    "path": "kt-kernel/python/cli/utils/debug_configs.py",
    "content": "\"\"\"\nDebug utility to inspect saved run configurations.\n\nUsage: python -m kt_kernel.cli.utils.debug_configs\n\"\"\"\n\nfrom pathlib import Path\nimport yaml\nfrom rich.console import Console\nfrom rich.table import Table\nfrom rich import box\n\nconsole = Console()\n\n\ndef main():\n    \"\"\"Show all saved configurations.\"\"\"\n    config_file = Path.home() / \".ktransformers\" / \"run_configs.yaml\"\n\n    console.print()\n    console.print(f\"[bold]Configuration file:[/bold] {config_file}\")\n    console.print()\n\n    if not config_file.exists():\n        console.print(\"[red]✗ Configuration file does not exist![/red]\")\n        console.print()\n        console.print(\"No configurations have been saved yet.\")\n        return\n\n    try:\n        with open(config_file, \"r\", encoding=\"utf-8\") as f:\n            data = yaml.safe_load(f) or {}\n    except Exception as e:\n        console.print(f\"[red]✗ Failed to load configuration file: {e}[/red]\")\n        return\n\n    console.print(f\"[green]✓[/green] Configuration file loaded\")\n    console.print()\n\n    configs = data.get(\"configs\", {})\n\n    if not configs:\n        console.print(\"[yellow]No saved configurations found.[/yellow]\")\n        return\n\n    console.print(f\"[bold]Found configurations for {len(configs)} model(s):[/bold]\")\n    console.print()\n\n    for model_id, model_configs in configs.items():\n        console.print(f\"[cyan]Model ID:[/cyan] {model_id}\")\n        console.print(f\"[dim]  {len(model_configs)} configuration(s)[/dim]\")\n        console.print()\n\n        if not model_configs:\n            continue\n\n        # Display configs in a table\n        table = Table(box=box.ROUNDED, show_header=True, header_style=\"bold cyan\")\n        table.add_column(\"#\", justify=\"right\", style=\"cyan\")\n        table.add_column(\"Name\", style=\"white\")\n        table.add_column(\"Method\", style=\"yellow\")\n        table.add_column(\"TP\", 
justify=\"right\", style=\"green\")\n        table.add_column(\"GPU Experts\", justify=\"right\", style=\"magenta\")\n        table.add_column(\"Created\", style=\"dim\")\n\n        for i, cfg in enumerate(model_configs, 1):\n            method = cfg.get(\"inference_method\", \"?\")\n            kt_method = cfg.get(\"kt_method\", \"?\")\n            method_display = f\"{method.upper()}\"\n            if method == \"raw\":\n                method_display += f\" ({cfg.get('raw_method', '?')})\"\n            elif method == \"amx\":\n                method_display += f\" ({kt_method})\"\n\n            table.add_row(\n                str(i),\n                cfg.get(\"config_name\", f\"Config {i}\"),\n                method_display,\n                str(cfg.get(\"tp_size\", \"?\")),\n                str(cfg.get(\"gpu_experts\", \"?\")),\n                cfg.get(\"created_at\", \"Unknown\")[:19] if cfg.get(\"created_at\") else \"Unknown\",\n            )\n\n        console.print(table)\n        console.print()\n\n    # Also check user_models.yaml to show model names\n    console.print(\"[bold]Checking model registry...[/bold]\")\n    console.print()\n\n    from kt_kernel.cli.utils.user_model_registry import UserModelRegistry\n\n    try:\n        registry = UserModelRegistry()\n        all_models = registry.list_models()\n\n        console.print(f\"[green]✓[/green] Found {len(all_models)} registered model(s)\")\n        console.print()\n\n        # Map model IDs to names\n        id_to_name = {m.id: m.name for m in all_models}\n\n        console.print(\"[bold]Model ID → Name mapping:[/bold]\")\n        console.print()\n\n        for model_id in configs.keys():\n            model_name = id_to_name.get(model_id, \"[red]Unknown (model not found in registry)[/red]\")\n            console.print(f\"  {model_id[:8]}... 
→ {model_name}\")\n\n        console.print()\n\n    except Exception as e:\n        console.print(f\"[yellow]⚠ Could not load model registry: {e}[/yellow]\")\n        console.print()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "kt-kernel/python/cli/utils/download_helper.py",
    "content": "\"\"\"Helper functions for interactive model download.\"\"\"\n\nfrom pathlib import Path\nfrom typing import Dict, List, Tuple\nimport fnmatch\n\n\ndef list_remote_files_hf(repo_id: str, use_mirror: bool = False) -> List[Dict[str, any]]:\n    \"\"\"\n    List files in a HuggingFace repository.\n\n    Returns:\n        List of dicts with keys: 'path', 'size' (in bytes)\n    \"\"\"\n    from huggingface_hub import HfApi\n    import os\n\n    # Set mirror if needed\n    original_endpoint = os.environ.get(\"HF_ENDPOINT\")\n    if use_mirror and not original_endpoint:\n        os.environ[\"HF_ENDPOINT\"] = \"https://hf-mirror.com\"\n\n    try:\n        api = HfApi()\n        files_info = api.list_repo_tree(repo_id=repo_id, recursive=True)\n\n        result = []\n        for item in files_info:\n            # Skip directories\n            if hasattr(item, \"type\") and item.type == \"directory\":\n                continue\n\n            # Get file info\n            file_path = item.path if hasattr(item, \"path\") else str(item)\n            file_size = item.size if hasattr(item, \"size\") else 0\n\n            result.append({\"path\": file_path, \"size\": file_size})\n\n        return result\n    finally:\n        # Restore original endpoint\n        if use_mirror and not original_endpoint:\n            os.environ.pop(\"HF_ENDPOINT\", None)\n        elif original_endpoint:\n            os.environ[\"HF_ENDPOINT\"] = original_endpoint\n\n\ndef list_remote_files_ms(repo_id: str) -> List[Dict[str, any]]:\n    \"\"\"\n    List files in a ModelScope repository.\n\n    Returns:\n        List of dicts with keys: 'path', 'size' (in bytes)\n    \"\"\"\n    from modelscope.hub.api import HubApi\n\n    api = HubApi()\n    files_info = api.get_model_files(model_id=repo_id, recursive=True)\n\n    result = []\n    for file_info in files_info:\n        file_path = file_info.get(\"Name\", file_info.get(\"Path\", \"\"))\n        file_size = file_info.get(\"Size\", 0)\n\n   
     result.append({\"path\": file_path, \"size\": file_size})\n\n    return result\n\n\ndef filter_files_by_pattern(files: List[Dict[str, any]], pattern: str) -> List[Dict[str, any]]:\n    \"\"\"Filter files by glob pattern.\"\"\"\n    if pattern == \"*\":\n        return files\n\n    filtered = []\n    for file in files:\n        # Check if filename matches pattern\n        filename = Path(file[\"path\"]).name\n        full_path = file[\"path\"]\n\n        if fnmatch.fnmatch(filename, pattern) or fnmatch.fnmatch(full_path, pattern):\n            filtered.append(file)\n\n    return filtered\n\n\ndef calculate_total_size(files: List[Dict[str, any]]) -> int:\n    \"\"\"Calculate total size of files in bytes.\"\"\"\n    return sum(f[\"size\"] for f in files)\n\n\ndef format_file_list_table(files: List[Dict[str, any]], max_display: int = 10):\n    \"\"\"Format file list as a table for display.\"\"\"\n    from rich.table import Table\n    from kt_kernel.cli.utils.model_scanner import format_size\n\n    table = Table(show_header=True, header_style=\"bold\")\n    table.add_column(\"File\", style=\"cyan\", overflow=\"fold\")\n    table.add_column(\"Size\", justify=\"right\")\n\n    # Show first max_display files\n    for file in files[:max_display]:\n        table.add_row(file[\"path\"], format_size(file[\"size\"]))\n\n    if len(files) > max_display:\n        table.add_row(f\"... 
and {len(files) - max_display} more files\", \"[dim]...[/dim]\")\n\n    return table\n\n\ndef verify_repo_exists(repo_id: str, repo_type: str, use_mirror: bool = False) -> Tuple[bool, str]:\n    \"\"\"\n    Verify if a repository exists.\n\n    Returns:\n        (exists: bool, message: str)\n    \"\"\"\n    try:\n        if repo_type == \"huggingface\":\n            import os\n\n            original_endpoint = os.environ.get(\"HF_ENDPOINT\")\n            if use_mirror and not original_endpoint:\n                os.environ[\"HF_ENDPOINT\"] = \"https://hf-mirror.com\"\n\n            from huggingface_hub import HfApi\n\n            try:\n                api = HfApi()\n                api.repo_info(repo_id=repo_id, repo_type=\"model\")\n                return True, \"Repository found\"\n            finally:\n                if use_mirror and not original_endpoint:\n                    os.environ.pop(\"HF_ENDPOINT\", None)\n                elif original_endpoint:\n                    os.environ[\"HF_ENDPOINT\"] = original_endpoint\n\n        else:  # modelscope\n            from modelscope.hub.api import HubApi\n\n            api = HubApi()\n            api.get_model(model_id=repo_id)\n            return True, \"Repository found\"\n\n    except Exception as e:\n        return False, f\"Repository not found: {str(e)}\"\n"
  },
  {
    "path": "kt-kernel/python/cli/utils/environment.py",
    "content": "\"\"\"\nEnvironment detection utilities for kt-cli.\n\nProvides functions to detect:\n- Virtual environment managers (conda, venv, uv, mamba)\n- Python version and packages\n- CUDA and GPU information\n- System resources (CPU, RAM, disk)\n\"\"\"\n\nimport os\nimport platform\nimport shutil\nimport subprocess\nfrom dataclasses import dataclass, field\nfrom pathlib import Path\nfrom typing import Optional\n\n\n@dataclass\nclass EnvManager:\n    \"\"\"Information about an environment manager.\"\"\"\n\n    name: str\n    version: str\n    path: str\n\n\n@dataclass\nclass GPUInfo:\n    \"\"\"Information about a GPU.\"\"\"\n\n    index: int\n    name: str\n    vram_gb: float\n    cuda_capability: Optional[str] = None\n\n\n@dataclass\nclass CPUInfo:\n    \"\"\"Information about the CPU.\"\"\"\n\n    name: str\n    cores: int\n    threads: int\n    numa_nodes: int\n    instruction_sets: list[str] = field(default_factory=list)  # AVX, AVX2, AVX512, AMX, etc.\n    numa_info: dict = field(default_factory=dict)  # node -> cpus mapping\n\n\n@dataclass\nclass MemoryInfo:\n    \"\"\"Information about system memory.\"\"\"\n\n    total_gb: float\n    available_gb: float\n    frequency_mhz: Optional[int] = None\n    channels: Optional[int] = None\n    type: Optional[str] = None  # DDR4, DDR5, etc.\n\n\n@dataclass\nclass SystemInfo:\n    \"\"\"Complete system information.\"\"\"\n\n    python_version: str\n    platform: str\n    cuda_version: Optional[str]\n    gpus: list[GPUInfo]\n    cpu: CPUInfo\n    ram_gb: float\n    env_managers: list[EnvManager]\n\n\ndef run_command(cmd: list[str], timeout: int = 10) -> Optional[str]:\n    \"\"\"Run a command and return its output, or None if it fails.\"\"\"\n    try:\n        result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout, check=False)\n        if result.returncode == 0:\n            return result.stdout.strip()\n        return None\n    except (subprocess.TimeoutExpired, FileNotFoundError, 
OSError):\n        return None\n\n\ndef detect_env_managers() -> list[EnvManager]:\n    \"\"\"Detect available virtual environment managers.\"\"\"\n    managers = []\n\n    # Check conda\n    conda_path = shutil.which(\"conda\")\n    if conda_path:\n        version = run_command([\"conda\", \"--version\"])\n        if version:\n            # \"conda 24.1.0\" -> \"24.1.0\"\n            version = version.split()[-1] if version else \"unknown\"\n            managers.append(EnvManager(name=\"conda\", version=version, path=conda_path))\n\n    # Check mamba\n    mamba_path = shutil.which(\"mamba\")\n    if mamba_path:\n        version = run_command([\"mamba\", \"--version\"])\n        if version:\n            # First line: \"mamba 1.5.0\"\n            version = version.split(\"\\n\")[0].split()[-1] if version else \"unknown\"\n            managers.append(EnvManager(name=\"mamba\", version=version, path=mamba_path))\n\n    # Check uv\n    uv_path = shutil.which(\"uv\")\n    if uv_path:\n        version = run_command([\"uv\", \"--version\"])\n        if version:\n            # \"uv 0.5.0\" -> \"0.5.0\"\n            version = version.split()[-1] if version else \"unknown\"\n            managers.append(EnvManager(name=\"uv\", version=version, path=uv_path))\n\n    # Check if venv is available (built into Python)\n    try:\n        import venv  # noqa: F401\n\n        managers.append(EnvManager(name=\"venv\", version=\"builtin\", path=\"python -m venv\"))\n    except ImportError:\n        pass\n\n    return managers\n\n\ndef check_docker() -> Optional[EnvManager]:\n    \"\"\"Check if Docker is available.\"\"\"\n    docker_path = shutil.which(\"docker\")\n    if docker_path:\n        version = run_command([\"docker\", \"--version\"])\n        if version:\n            # \"Docker version 24.0.7, build afdd53b\"\n            parts = version.split()\n            version = parts[2].rstrip(\",\") if len(parts) > 2 else \"unknown\"\n            return EnvManager(name=\"docker\", 
version=version, path=docker_path)\n    return None\n\n\ndef check_kt_env_exists(manager: str, env_name: str = \"kt\") -> bool:\n    \"\"\"Check if a kt environment exists for the given manager.\"\"\"\n    if manager == \"conda\" or manager == \"mamba\":\n        result = run_command([manager, \"env\", \"list\"])\n        if result:\n            # Check if env_name appears as a separate word in the output\n            for line in result.split(\"\\n\"):\n                parts = line.split()\n                if parts and parts[0] == env_name:\n                    return True\n    elif manager == \"uv\":\n        # uv uses .venv in the project directory or ~/.local/share/uv/envs/\n        venv_path = Path.home() / \".local\" / \"share\" / \"uv\" / \"envs\" / env_name\n        if venv_path.exists():\n            return True\n        # Also check current directory\n        if Path(env_name).exists() and (Path(env_name) / \"bin\" / \"python\").exists():\n            return True\n    elif manager == \"venv\":\n        # Check common locations\n        venv_path = Path.home() / \".virtualenvs\" / env_name\n        if venv_path.exists():\n            return True\n        if Path(env_name).exists() and (Path(env_name) / \"bin\" / \"python\").exists():\n            return True\n\n    return False\n\n\ndef get_kt_env_path(manager: str, env_name: str = \"kt\") -> Optional[Path]:\n    \"\"\"Get the path to the kt environment.\"\"\"\n    if manager == \"conda\" or manager == \"mamba\":\n        result = run_command([manager, \"env\", \"list\"])\n        if result:\n            for line in result.split(\"\\n\"):\n                parts = line.split()\n                if parts and parts[0] == env_name:\n                    # The path is the last part\n                    return Path(parts[-1])\n    elif manager == \"uv\":\n        venv_path = Path.home() / \".local\" / \"share\" / \"uv\" / \"envs\" / env_name\n        if venv_path.exists():\n            return venv_path\n    elif 
manager == \"venv\":\n        venv_path = Path.home() / \".virtualenvs\" / env_name\n        if venv_path.exists():\n            return venv_path\n\n    return None\n\n\ndef detect_cuda_version() -> Optional[str]:\n    \"\"\"Detect CUDA version from nvidia-smi or nvcc.\"\"\"\n    # Try nvidia-smi first\n    nvidia_smi = run_command([\"nvidia-smi\", \"--query-gpu=driver_version\", \"--format=csv,noheader\"])\n    if nvidia_smi:\n        # Get CUDA version from nvidia-smi\n        full_output = run_command([\"nvidia-smi\"])\n        if full_output:\n            for line in full_output.split(\"\\n\"):\n                if \"CUDA Version:\" in line:\n                    # \"| CUDA Version: 12.1     |\"\n                    parts = line.split(\"CUDA Version:\")\n                    if len(parts) > 1:\n                        version = parts[1].strip().split()[0]\n                        return version\n\n    # Try nvcc\n    nvcc_output = run_command([\"nvcc\", \"--version\"])\n    if nvcc_output:\n        for line in nvcc_output.split(\"\\n\"):\n            if \"release\" in line.lower():\n                # \"Cuda compilation tools, release 12.1, V12.1.105\"\n                parts = line.split(\"release\")\n                if len(parts) > 1:\n                    version = parts[1].strip().split(\",\")[0].strip()\n                    return version\n\n    return None\n\n\ndef detect_gpus() -> list[GPUInfo]:\n    \"\"\"Detect available NVIDIA GPUs, respecting CUDA_VISIBLE_DEVICES.\"\"\"\n    gpus = []\n\n    nvidia_smi = run_command([\"nvidia-smi\", \"--query-gpu=index,name,memory.total\", \"--format=csv,noheader,nounits\"])\n\n    if nvidia_smi:\n        for line in nvidia_smi.strip().split(\"\\n\"):\n            parts = [p.strip() for p in line.split(\",\")]\n            if len(parts) >= 3:\n                try:\n                    index = int(parts[0])\n                    name = parts[1]\n                    vram_mb = float(parts[2])\n                    vram_gb = 
round(vram_mb / 1024, 1)\n                    gpus.append(GPUInfo(index=index, name=name, vram_gb=vram_gb))\n                except (ValueError, IndexError):\n                    continue\n\n    # Filter by CUDA_VISIBLE_DEVICES if set\n    cuda_visible = os.environ.get(\"CUDA_VISIBLE_DEVICES\")\n    if cuda_visible is not None:\n        if cuda_visible == \"\":\n            # Empty string means no GPUs visible\n            return []\n\n        try:\n            # Parse CUDA_VISIBLE_DEVICES (can be \"0,1,2\" or \"0-3\" etc.)\n            visible_indices = _parse_cuda_visible_devices(cuda_visible)\n            # Filter GPUs to only those in CUDA_VISIBLE_DEVICES\n            filtered_gpus = [gpu for gpu in gpus if gpu.index in visible_indices]\n            # Re-index GPUs to match CUDA's logical indexing (0, 1, 2, ...)\n            for i, gpu in enumerate(filtered_gpus):\n                # Keep original index in a comment, but CUDA sees them as 0,1,2...\n                gpu.index = i\n            return filtered_gpus\n        except ValueError:\n            # If parsing fails, return all GPUs as fallback\n            pass\n\n    return gpus\n\n\ndef _parse_cuda_visible_devices(cuda_visible: str) -> list[int]:\n    \"\"\"Parse CUDA_VISIBLE_DEVICES string into list of GPU indices.\n\n    Supports formats like:\n    - \"0,1,2,3\" -> [0, 1, 2, 3]\n    - \"0-3\" -> [0, 1, 2, 3]\n    - \"0,2-4,7\" -> [0, 2, 3, 4, 7]\n    \"\"\"\n    indices = []\n    parts = cuda_visible.split(\",\")\n\n    for part in parts:\n        part = part.strip()\n        if \"-\" in part:\n            # Range like \"0-3\"\n            start, end = part.split(\"-\")\n            indices.extend(range(int(start), int(end) + 1))\n        else:\n            # Single index\n            indices.append(int(part))\n\n    return sorted(set(indices))  # Remove duplicates and sort\n\n\ndef detect_cpu_info() -> CPUInfo:\n    \"\"\"Detect CPU information including instruction sets and NUMA topology.\"\"\"\n    
name = \"Unknown\"\n    cores = os.cpu_count() or 1\n    threads = cores\n    numa_nodes = 1\n    instruction_sets: list[str] = []\n    numa_info: dict[str, list[int]] = {}\n\n    if platform.system() == \"Linux\":\n        try:\n            with open(\"/proc/cpuinfo\", \"r\") as f:\n                content = f.read()\n\n            # Get CPU name\n            for line in content.split(\"\\n\"):\n                if line.startswith(\"model name\"):\n                    name = line.split(\":\")[1].strip()\n                    break\n\n            # Get physical cores vs threads\n            cpu_cores = content.count(\"processor\\t:\")\n            if cpu_cores > 0:\n                threads = cpu_cores\n\n            siblings = None\n            cores_per = None\n            for line in content.split(\"\\n\"):\n                if \"siblings\" in line:\n                    siblings = int(line.split(\":\")[1].strip())\n                if \"cpu cores\" in line:\n                    cores_per = int(line.split(\":\")[1].strip())\n            if siblings and cores_per:\n                cores = threads // (siblings // cores_per) if siblings > cores_per else threads\n\n            # Get instruction sets from flags\n            for line in content.split(\"\\n\"):\n                if line.startswith(\"flags\"):\n                    flags = line.split(\":\")[1].strip().split()\n                    instruction_sets = _parse_cpu_flags(flags)\n                    break\n\n        except (OSError, IOError, ValueError):\n            pass\n\n        # Get NUMA topology\n        numa_path = Path(\"/sys/devices/system/node\")\n        if numa_path.exists():\n            numa_dirs = [d for d in numa_path.iterdir() if d.name.startswith(\"node\")]\n            numa_nodes = len(numa_dirs)\n\n            for node_dir in numa_dirs:\n                node_name = node_dir.name  # e.g., \"node0\"\n                cpulist_path = node_dir / \"cpulist\"\n                if cpulist_path.exists():\n   
                 try:\n                        cpulist = cpulist_path.read_text().strip()\n                        numa_info[node_name] = _parse_cpu_list(cpulist)\n                    except (OSError, IOError):\n                        pass\n\n    elif platform.system() == \"Darwin\":\n        # macOS\n        name_output = run_command([\"sysctl\", \"-n\", \"machdep.cpu.brand_string\"])\n        if name_output:\n            name = name_output.strip()\n        cores_output = run_command([\"sysctl\", \"-n\", \"hw.physicalcpu\"])\n        if cores_output:\n            cores = int(cores_output.strip())\n        threads_output = run_command([\"sysctl\", \"-n\", \"hw.logicalcpu\"])\n        if threads_output:\n            threads = int(threads_output.strip())\n\n        # Get instruction sets on macOS\n        features_output = run_command([\"sysctl\", \"-n\", \"machdep.cpu.features\"])\n        if features_output:\n            flags = features_output.lower().split()\n            instruction_sets = _parse_cpu_flags(flags)\n\n    return CPUInfo(\n        name=name,\n        cores=cores,\n        threads=threads,\n        numa_nodes=numa_nodes,\n        instruction_sets=instruction_sets,\n        numa_info=numa_info,\n    )\n\n\ndef _parse_cpu_flags(flags: list[str]) -> list[str]:\n    \"\"\"Parse CPU flags to extract relevant instruction sets for KTransformers.\"\"\"\n    # Instruction sets important for KTransformers/kt-kernel\n    relevant_instructions = {\n        # Basic SIMD\n        \"sse\": \"SSE\",\n        \"sse2\": \"SSE2\",\n        \"sse3\": \"SSE3\",\n        \"ssse3\": \"SSSE3\",\n        \"sse4_1\": \"SSE4.1\",\n        \"sse4_2\": \"SSE4.2\",\n        # AVX family\n        \"avx\": \"AVX\",\n        \"avx2\": \"AVX2\",\n        \"avx512f\": \"AVX512F\",\n        \"avx512bw\": \"AVX512BW\",\n        \"avx512vl\": \"AVX512VL\",\n        \"avx512dq\": \"AVX512DQ\",\n        \"avx512cd\": \"AVX512CD\",\n        \"avx512vnni\": \"AVX512VNNI\",\n        
\"avx512_bf16\": \"AVX512BF16\",\n        \"avx512_fp16\": \"AVX512FP16\",\n        \"avx_vnni\": \"AVX-VNNI\",\n        # AMX (Advanced Matrix Extensions) - Intel\n        \"amx_tile\": \"AMX-TILE\",\n        \"amx_bf16\": \"AMX-BF16\",\n        \"amx_int8\": \"AMX-INT8\",\n        \"amx_fp16\": \"AMX-FP16\",\n        # Other relevant\n        \"fma\": \"FMA\",\n        \"f16c\": \"F16C\",\n        \"bmi1\": \"BMI1\",\n        \"bmi2\": \"BMI2\",\n    }\n\n    found = []\n    flags_lower = {f.lower() for f in flags}\n\n    for flag, display_name in relevant_instructions.items():\n        if flag in flags_lower:\n            found.append(display_name)\n\n    # Sort by importance for display\n    priority = [\n        \"AMX-INT8\",\n        \"AMX-BF16\",\n        \"AMX-FP16\",\n        \"AMX-TILE\",\n        \"AVX512BF16\",\n        \"AVX512VNNI\",\n        \"AVX512F\",\n        \"AVX512BW\",\n        \"AVX512VL\",\n        \"AVX2\",\n        \"AVX\",\n        \"FMA\",\n        \"SSE4.2\",\n    ]\n    result = []\n    for p in priority:\n        if p in found:\n            result.append(p)\n            found.remove(p)\n    result.extend(sorted(found))  # Add remaining\n\n    return result\n\n\ndef _parse_cpu_list(cpulist: str) -> list[int]:\n    \"\"\"Parse CPU list string like '0-3,8-11' to list of CPU IDs.\"\"\"\n    cpus = []\n    for part in cpulist.split(\",\"):\n        if \"-\" in part:\n            start, end = part.split(\"-\")\n            cpus.extend(range(int(start), int(end) + 1))\n        else:\n            cpus.append(int(part))\n    return cpus\n\n\ndef detect_memory_info() -> MemoryInfo:\n    \"\"\"Detect detailed memory information including frequency and type.\"\"\"\n    total_gb = detect_ram_gb()\n    available_gb = detect_available_ram_gb()\n    frequency_mhz: Optional[int] = None\n    channels: Optional[int] = None\n    mem_type: Optional[str] = None\n\n    if platform.system() == \"Linux\":\n        # Try dmidecode without sudo first (may work 
if user has permissions)\n        dmidecode_output = run_command([\"dmidecode\", \"-t\", \"memory\"])\n        if dmidecode_output:\n            frequency_mhz, mem_type, channels = _parse_dmidecode_memory(dmidecode_output)\n\n        # Fallback: try to read from /sys or /proc\n        if frequency_mhz is None:\n            frequency_mhz = _detect_memory_frequency_sysfs()\n\n    elif platform.system() == \"Darwin\":\n        # macOS - use system_profiler\n        mem_output = run_command([\"system_profiler\", \"SPMemoryDataType\"])\n        if mem_output:\n            frequency_mhz, mem_type = _parse_macos_memory(mem_output)\n\n    return MemoryInfo(\n        total_gb=total_gb,\n        available_gb=available_gb,\n        frequency_mhz=frequency_mhz,\n        channels=channels,\n        type=mem_type,\n    )\n\n\ndef _parse_dmidecode_memory(output: str) -> tuple[Optional[int], Optional[str], Optional[int]]:\n    \"\"\"Parse dmidecode memory output into (max speed, memory type, populated DIMM count).\"\"\"\n    frequency_mhz: Optional[int] = None\n    mem_type: Optional[str] = None\n    dimm_count = 0\n\n    for line in output.split(\"\\n\"):\n        line = line.strip()\n        if line.startswith(\"Speed:\") and (\"MHz\" in line or \"MT/s\" in line):\n            try:\n                # \"Speed: 4800 MHz\" or \"Speed: 4800 MT/s\"\n                parts = line.split(\":\")[1].strip().split()\n                freq = int(parts[0])\n                if freq > 0 and (frequency_mhz is None or freq > frequency_mhz):\n                    frequency_mhz = freq\n            except (ValueError, IndexError):\n                pass\n        elif line.startswith(\"Type:\") and mem_type is None:\n            type_val = line.split(\":\")[1].strip()\n            if type_val and type_val != \"Unknown\":\n                mem_type = type_val\n        elif line.startswith(\"Size:\") and (\"MB\" in line or \"GB\" in line):\n            # Count populated slots only; other lines that mention GB must not match\n            dimm_count += 1\n\n    return frequency_mhz, mem_type, dimm_count if dimm_count > 0 else None\n\n\ndef 
_detect_memory_frequency_sysfs() -> Optional[int]:\n    \"\"\"Try to detect memory frequency from sysfs.\"\"\"\n    # This is a fallback and may not work on all systems\n    try:\n        # Try reading from edac\n        edac_path = Path(\"/sys/devices/system/edac/mc\")\n        if edac_path.exists():\n            for mc_dir in edac_path.iterdir():\n                freq_file = mc_dir / \"mc_config\"\n                if freq_file.exists():\n                    content = freq_file.read_text()\n                    # Parse for frequency information\n                    # Format varies by system\n                    pass\n    except (OSError, IOError):\n        pass\n\n    return None\n\n\ndef _parse_macos_memory(output: str) -> tuple[Optional[int], Optional[str]]:\n    \"\"\"Parse macOS system_profiler memory output.\"\"\"\n    frequency_mhz: Optional[int] = None\n    mem_type: Optional[str] = None\n\n    for line in output.split(\"\\n\"):\n        line = line.strip()\n        if \"Speed:\" in line:\n            try:\n                parts = line.split(\":\")[1].strip().split()\n                frequency_mhz = int(parts[0])\n            except (ValueError, IndexError):\n                pass\n        elif \"Type:\" in line:\n            mem_type = line.split(\":\")[1].strip()\n\n    return frequency_mhz, mem_type\n\n\ndef detect_ram_gb() -> float:\n    \"\"\"Detect total system RAM in GB.\"\"\"\n    if platform.system() == \"Linux\":\n        try:\n            with open(\"/proc/meminfo\", \"r\") as f:\n                for line in f:\n                    if line.startswith(\"MemTotal:\"):\n                        # \"MemTotal:       32780516 kB\"\n                        kb = int(line.split()[1])\n                        return round(kb / 1024 / 1024, 1)\n        except (OSError, IOError, ValueError):\n            pass\n    elif platform.system() == \"Darwin\":\n        mem_output = run_command([\"sysctl\", \"-n\", \"hw.memsize\"])\n        if mem_output:\n            
return round(int(mem_output) / 1024 / 1024 / 1024, 1)\n\n    # Fallback\n    try:\n        import psutil\n\n        return round(psutil.virtual_memory().total / 1024 / 1024 / 1024, 1)\n    except ImportError:\n        return 0.0\n\n\ndef detect_available_ram_gb() -> float:\n    \"\"\"Detect available system RAM in GB.\"\"\"\n    if platform.system() == \"Linux\":\n        try:\n            with open(\"/proc/meminfo\", \"r\") as f:\n                for line in f:\n                    if line.startswith(\"MemAvailable:\"):\n                        kb = int(line.split()[1])\n                        return round(kb / 1024 / 1024, 1)\n        except (OSError, IOError, ValueError):\n            pass\n\n    # Fallback\n    try:\n        import psutil\n\n        return round(psutil.virtual_memory().available / 1024 / 1024 / 1024, 1)\n    except ImportError:\n        return 0.0\n\n\ndef detect_disk_space_gb(path: str = \"/\") -> tuple[float, float]:\n    \"\"\"Detect disk space (available, total) in GB for the given path.\"\"\"\n    try:\n        import shutil\n\n        total, used, free = shutil.disk_usage(path)\n        return round(free / 1024 / 1024 / 1024, 1), round(total / 1024 / 1024 / 1024, 1)\n    except (OSError, IOError):\n        return 0.0, 0.0\n\n\ndef get_installed_package_version(package_name: str) -> Optional[str]:\n    \"\"\"Get the version of an installed Python package.\"\"\"\n    try:\n        from importlib.metadata import version\n\n        return version(package_name)\n    except Exception:\n        return None\n\n\ndef get_system_info() -> SystemInfo:\n    \"\"\"Gather complete system information.\"\"\"\n    return SystemInfo(\n        python_version=platform.python_version(),\n        platform=f\"{platform.system()} {platform.release()}\",\n        cuda_version=detect_cuda_version(),\n        gpus=detect_gpus(),\n        cpu=detect_cpu_info(),\n        ram_gb=detect_ram_gb(),\n        env_managers=detect_env_managers(),\n    )\n\n\ndef 
is_in_virtual_env() -> bool:\n    \"\"\"Check if currently running inside a virtual environment.\"\"\"\n    return (\n        hasattr(sys, \"real_prefix\")\n        or (hasattr(sys, \"base_prefix\") and sys.base_prefix != sys.prefix)\n        or os.environ.get(\"VIRTUAL_ENV\") is not None\n        or os.environ.get(\"CONDA_PREFIX\") is not None\n    )\n\n\ndef get_current_env_name() -> Optional[str]:\n    \"\"\"Get the name of the current virtual environment.\"\"\"\n    if os.environ.get(\"CONDA_DEFAULT_ENV\"):\n        return os.environ[\"CONDA_DEFAULT_ENV\"]\n    if os.environ.get(\"VIRTUAL_ENV\"):\n        return Path(os.environ[\"VIRTUAL_ENV\"]).name\n    return None\n\n\n# Import sys for is_in_virtual_env\nimport sys  # noqa: E402\n\n\n@dataclass\nclass StorageLocation:\n    \"\"\"Information about a storage location.\"\"\"\n\n    path: str\n    available_gb: float\n    total_gb: float\n    is_writable: bool\n    mount_point: str\n\n\ndef scan_storage_locations(min_size_gb: float = 50.0) -> list[StorageLocation]:\n    \"\"\"\n    Scan system for potential model storage locations.\n\n    Looks for:\n    - Large mounted filesystems (> min_size_gb)\n    - Common model storage paths\n    - User home directory\n\n    Args:\n        min_size_gb: Minimum available space in GB to consider\n\n    Returns:\n        List of StorageLocation sorted by available space (descending)\n    \"\"\"\n    locations: dict[str, StorageLocation] = {}  # Use dict to deduplicate by path\n\n    # Get all mount points from /proc/mounts (Linux)\n    mount_points = _get_mount_points()\n\n    for mount_point in mount_points:\n        try:\n            available_gb, total_gb = detect_disk_space_gb(mount_point)\n\n            # Skip small or pseudo filesystems\n            if total_gb < 10:\n                continue\n\n            # Check if writable\n            is_writable = os.access(mount_point, os.W_OK)\n\n            # Create potential model paths under this mount\n            
potential_paths = _get_potential_model_paths(mount_point)\n\n            for path in potential_paths:\n                if path in locations:\n                    continue\n\n                # Get actual available space for this path\n                path_available, path_total = detect_disk_space_gb(path)\n\n                if path_available >= min_size_gb:\n                    path_writable = os.access(path, os.W_OK) if os.path.exists(path) else is_writable\n                    locations[path] = StorageLocation(\n                        path=path,\n                        available_gb=path_available,\n                        total_gb=path_total,\n                        is_writable=path_writable,\n                        mount_point=mount_point,\n                    )\n        except (OSError, IOError):\n            continue\n\n    # Also check common model storage locations\n    common_paths = [\n        str(Path.home() / \".ktransformers\" / \"models\"),\n        str(Path.home() / \"models\"),\n        str(Path.home() / \".cache\" / \"huggingface\"),\n        \"/data/models\",\n        \"/models\",\n        \"/opt/models\",\n    ]\n\n    for path in common_paths:\n        if path in locations:\n            continue\n        try:\n            # Check if parent exists for paths that don't exist yet\n            check_path = path\n            while not os.path.exists(check_path) and check_path != \"/\":\n                check_path = str(Path(check_path).parent)\n\n            if os.path.exists(check_path):\n                available_gb, total_gb = detect_disk_space_gb(check_path)\n                if available_gb >= min_size_gb:\n                    is_writable = os.access(check_path, os.W_OK)\n                    locations[path] = StorageLocation(\n                        path=path,\n                        available_gb=available_gb,\n                        total_gb=total_gb,\n                        is_writable=is_writable,\n                        
mount_point=check_path,\n                    )\n        except (OSError, IOError):\n            continue\n\n    # Sort by available space descending, then by path\n    sorted_locations = sorted(locations.values(), key=lambda x: (-x.available_gb, x.path))\n\n    # Filter to only writable locations\n    return [loc for loc in sorted_locations if loc.is_writable]\n\n\ndef _get_mount_points() -> list[str]:\n    \"\"\"Get all mount points on the system.\"\"\"\n    mount_points = []\n\n    if platform.system() == \"Linux\":\n        try:\n            with open(\"/proc/mounts\", \"r\") as f:\n                for line in f:\n                    parts = line.split()\n                    if len(parts) >= 2:\n                        mount_point = parts[1]\n                        fs_type = parts[2] if len(parts) > 2 else \"\"\n\n                        # Skip pseudo filesystems\n                        skip_fs = {\n                            \"proc\",\n                            \"sysfs\",\n                            \"devpts\",\n                            \"tmpfs\",\n                            \"cgroup\",\n                            \"cgroup2\",\n                            \"pstore\",\n                            \"securityfs\",\n                            \"debugfs\",\n                            \"hugetlbfs\",\n                            \"mqueue\",\n                            \"fusectl\",\n                            \"configfs\",\n                            \"devtmpfs\",\n                            \"efivarfs\",\n                            \"autofs\",\n                            \"binfmt_misc\",\n                            \"overlay\",\n                            \"nsfs\",\n                            \"tracefs\",\n                        }\n                        if fs_type in skip_fs:\n                            continue\n\n                        # Skip paths that are clearly system paths\n                        skip_prefixes = (\"/sys\", \"/proc\", 
\"/dev\", \"/run/user\")\n                        if any(mount_point.startswith(p) for p in skip_prefixes):\n                            continue\n\n                        mount_points.append(mount_point)\n        except (OSError, IOError):\n            pass\n\n    # Always include home and root\n    mount_points.extend([str(Path.home()), \"/\"])\n\n    # Deduplicate while preserving order\n    seen = set()\n    unique_mounts = []\n    for mp in mount_points:\n        if mp not in seen:\n            seen.add(mp)\n            unique_mounts.append(mp)\n\n    return unique_mounts\n\n\ndef _get_potential_model_paths(mount_point: str) -> list[str]:\n    \"\"\"Get potential model storage paths under a mount point.\"\"\"\n    paths = []\n\n    # The mount point itself (for dedicated data drives)\n    if mount_point not in (\"/\", \"/home\"):\n        paths.append(mount_point)\n        paths.append(os.path.join(mount_point, \"models\"))\n\n    # If it's under home, suggest standard locations\n    home = str(Path.home())\n    if mount_point == home or mount_point == \"/home\":\n        paths.append(os.path.join(home, \".ktransformers\", \"models\"))\n        paths.append(os.path.join(home, \"models\"))\n\n    # For root mount, suggest /data or /opt\n    if mount_point == \"/\":\n        paths.extend([\"/data/models\", \"/opt/models\"])\n\n    # Check for common data directories on this mount\n    for subdir in [\"data\", \"models\", \"ai\", \"llm\", \"huggingface\"]:\n        potential = os.path.join(mount_point, subdir)\n        if os.path.exists(potential) and os.path.isdir(potential):\n            paths.append(potential)\n\n    return paths\n\n\ndef format_size_gb(size_gb: float) -> str:\n    \"\"\"Format size in GB to human readable string.\"\"\"\n    if size_gb >= 1000:\n        return f\"{size_gb / 1000:.1f}TB\"\n    return f\"{size_gb:.1f}GB\"\n\n\n@dataclass\nclass LocalModel:\n    \"\"\"Information about a locally detected model.\"\"\"\n\n    name: str\n    path: 
str\n    size_gb: float\n    model_type: str  # \"huggingface\", \"gguf\", \"safetensors\"\n    has_config: bool\n    file_count: int\n\n\ndef scan_local_models(search_paths: list[str], max_depth: int = 3) -> list[LocalModel]:\n    \"\"\"\n    Scan directories for locally downloaded models.\n\n    Looks for:\n    - Directories with config.json (HuggingFace format)\n    - Directories with .safetensors files\n    - Directories with .gguf files\n\n    Args:\n        search_paths: List of paths to search\n        max_depth: Maximum directory depth to search\n\n    Returns:\n        List of LocalModel sorted by size (descending)\n    \"\"\"\n    models: dict[str, LocalModel] = {}  # Use path as key to deduplicate\n\n    for search_path in search_paths:\n        if not os.path.exists(search_path):\n            continue\n\n        _scan_directory_for_models(search_path, models, current_depth=0, max_depth=max_depth)\n\n    # Sort by size descending\n    return sorted(models.values(), key=lambda x: -x.size_gb)\n\n\ndef _scan_directory_for_models(\n    directory: str, models: dict[str, LocalModel], current_depth: int, max_depth: int\n) -> None:\n    \"\"\"Recursively scan a directory for models.\"\"\"\n    if current_depth > max_depth:\n        return\n\n    try:\n        entries = list(os.scandir(directory))\n    except (PermissionError, OSError):\n        return\n\n    # Check if this directory is a model\n    model = _detect_model_in_directory(directory, entries)\n    if model:\n        models[model.path] = model\n        return  # Don't scan subdirectories of a model\n\n    # Scan subdirectories\n    for entry in entries:\n        if entry.is_dir() and not entry.name.startswith(\".\"):\n            _scan_directory_for_models(entry.path, models, current_depth + 1, max_depth)\n\n\ndef _detect_model_in_directory(directory: str, entries: list) -> Optional[LocalModel]:\n    \"\"\"Detect if a directory contains a model.\"\"\"\n    entry_names = {e.name for e in entries}\n\n    
has_config = \"config.json\" in entry_names\n    safetensor_files = [e for e in entries if e.name.endswith(\".safetensors\") and e.is_file()]\n    gguf_files = [e for e in entries if e.name.endswith(\".gguf\") and e.is_file()]\n\n    # Determine model type\n    model_type = None\n    if has_config and safetensor_files:\n        model_type = \"huggingface\"\n    elif gguf_files:\n        model_type = \"gguf\"\n    elif safetensor_files:\n        model_type = \"safetensors\"\n    elif has_config:\n        # Config but no weights - might be incomplete\n        # Check for other model-related files\n        model_files = {\n            \"model.safetensors.index.json\",\n            \"pytorch_model.bin.index.json\",\n            \"model.safetensors\",\n            \"pytorch_model.bin\",\n        }\n        if entry_names & model_files:\n            model_type = \"huggingface\"\n\n    if not model_type:\n        return None\n\n    # Calculate directory size\n    size_bytes = _get_directory_size(directory)\n    size_gb = size_bytes / (1024**3)\n\n    # Skip very small directories (likely incomplete or config-only)\n    if size_gb < 0.1:\n        return None\n\n    # Get model name from directory name\n    name = os.path.basename(directory)\n\n    # Count model files\n    file_count = len(safetensor_files) + len(gguf_files)\n    if not file_count:\n        # Count .bin files as fallback\n        file_count = len([e for e in entries if e.name.endswith(\".bin\") and e.is_file()])\n\n    return LocalModel(\n        name=name,\n        path=directory,\n        size_gb=round(size_gb, 1),\n        model_type=model_type,\n        has_config=has_config,\n        file_count=file_count,\n    )\n\n\ndef _get_directory_size(directory: str) -> int:\n    \"\"\"Get total size of a directory in bytes.\"\"\"\n    total_size = 0\n    try:\n        for entry in os.scandir(directory):\n            try:\n                if entry.is_file(follow_symlinks=False):\n                    total_size 
+= entry.stat().st_size\n                elif entry.is_dir(follow_symlinks=False):\n                    total_size += _get_directory_size(entry.path)\n            except (PermissionError, OSError):\n                continue\n    except (PermissionError, OSError):\n        pass\n    return total_size\n\n\ndef scan_models_in_location(location: StorageLocation, max_depth: int = 2) -> list[LocalModel]:\n    \"\"\"Scan a storage location for models.\"\"\"\n    search_paths = [location.path]\n\n    # Also check common subdirectories\n    for subdir in [\"models\", \"huggingface\", \"hub\", \".cache/huggingface/hub\"]:\n        subpath = os.path.join(location.path, subdir)\n        if os.path.exists(subpath):\n            search_paths.append(subpath)\n\n    return scan_local_models(search_paths, max_depth=max_depth)\n\n\n@dataclass\nclass CPUBuildFeatures:\n    \"\"\"CPU features for build configuration.\"\"\"\n\n    has_amx: bool\n    has_avx512: bool\n    has_avx512_vnni: bool\n    has_avx512_bf16: bool\n    has_avx2: bool\n    recommended_instruct: str  # NATIVE, AVX512, AVX2\n    recommended_amx: bool\n\n\ndef detect_cpu_build_features() -> CPUBuildFeatures:\n    \"\"\"\n    Detect CPU features for build configuration.\n\n    This is used to auto-configure kt-kernel source builds.\n    Reads /proc/cpuinfo on Linux to detect instruction set support.\n\n    Returns:\n        CPUBuildFeatures with detection results\n    \"\"\"\n    has_amx = False\n    has_avx512 = False\n    has_avx512_vnni = False\n    has_avx512_bf16 = False\n    has_avx2 = False\n\n    if platform.system() == \"Linux\":\n        try:\n            with open(\"/proc/cpuinfo\", \"r\") as f:\n                content = f.read()\n\n            # Get flags from first processor\n            for line in content.split(\"\\n\"):\n                if line.startswith(\"flags\"):\n                    flags = line.split(\":\")[1].strip().split()\n                    flags_lower = {f.lower() for f in flags}\n\n       
             # Check for AMX support (requires all three)\n                    if {\"amx_tile\", \"amx_int8\", \"amx_bf16\"} <= flags_lower:\n                        has_amx = True\n\n                    # Check for AVX512 support\n                    if \"avx512f\" in flags_lower:\n                        has_avx512 = True\n\n                    # Check for AVX512 VNNI\n                    if \"avx512_vnni\" in flags_lower or \"avx512vnni\" in flags_lower:\n                        has_avx512_vnni = True\n\n                    # Check for AVX512 BF16\n                    if \"avx512_bf16\" in flags_lower or \"avx512bf16\" in flags_lower:\n                        has_avx512_bf16 = True\n\n                    # Check for AVX2\n                    if \"avx2\" in flags_lower:\n                        has_avx2 = True\n\n                    break\n        except (OSError, IOError):\n            pass\n\n    elif platform.system() == \"Darwin\":\n        # macOS - use sysctl\n        features_output = run_command([\"sysctl\", \"-n\", \"machdep.cpu.features\"])\n        if features_output:\n            flags_lower = {f.lower() for f in features_output.split()}\n            has_avx2 = \"avx2\" in flags_lower\n            # macOS doesn't have AMX or AVX512 typically\n\n    # Determine recommended configuration\n    if has_amx:\n        recommended_instruct = \"NATIVE\"\n        recommended_amx = True\n    elif has_avx512:\n        recommended_instruct = \"NATIVE\"\n        recommended_amx = False\n    elif has_avx2:\n        recommended_instruct = \"NATIVE\"\n        recommended_amx = False\n    else:\n        recommended_instruct = \"AVX2\"\n        recommended_amx = False\n\n    return CPUBuildFeatures(\n        has_amx=has_amx,\n        has_avx512=has_avx512,\n        has_avx512_vnni=has_avx512_vnni,\n        has_avx512_bf16=has_avx512_bf16,\n        has_avx2=has_avx2,\n        recommended_instruct=recommended_instruct,\n        recommended_amx=recommended_amx,\n    )\n"
  },
  {
    "path": "kt-kernel/python/cli/utils/input_validators.py",
    "content": "\"\"\"\nInput validation utilities with retry mechanism.\n\nProvides robust input validation with automatic retry on failure.\n\"\"\"\n\nfrom typing import Optional, List, Callable, Any\nfrom rich.console import Console\nfrom rich.prompt import Prompt\n\nconsole = Console()\n\n\ndef prompt_int_with_retry(\n    message: str,\n    default: Optional[int] = None,\n    min_val: Optional[int] = None,\n    max_val: Optional[int] = None,\n    validator: Optional[Callable[[int], bool]] = None,\n    validator_error_msg: Optional[str] = None,\n) -> int:\n    \"\"\"Prompt for integer input with validation and retry.\n\n    Args:\n        message: Prompt message\n        default: Default value (optional)\n        min_val: Minimum allowed value (optional)\n        max_val: Maximum allowed value (optional)\n        validator: Custom validation function (optional)\n        validator_error_msg: Error message for custom validator (optional)\n\n    Returns:\n        Validated integer value\n    \"\"\"\n    while True:\n        # Build prompt with default\n        if default is not None:\n            prompt_text = f\"{message} [{default}]\"\n        else:\n            prompt_text = message\n\n        # Get input\n        user_input = Prompt.ask(prompt_text, default=str(default) if default is not None else None)\n\n        # Try to parse as integer\n        try:\n            value = int(user_input)\n        except ValueError:\n            console.print(f\"[red]✗ Invalid input. 
Please enter a valid integer.[/red]\")\n            console.print()\n            continue\n\n        # Validate range\n        if min_val is not None and value < min_val:\n            console.print(f\"[red]✗ Value must be at least {min_val}[/red]\")\n            console.print()\n            continue\n\n        if max_val is not None and value > max_val:\n            console.print(f\"[red]✗ Value must be at most {max_val}[/red]\")\n            console.print()\n            continue\n\n        # Custom validation\n        if validator is not None:\n            if not validator(value):\n                error_msg = validator_error_msg or \"Invalid value\"\n                console.print(f\"[red]✗ {error_msg}[/red]\")\n                console.print()\n                continue\n\n        # All validations passed\n        return value\n\n\ndef prompt_float_with_retry(\n    message: str,\n    default: Optional[float] = None,\n    min_val: Optional[float] = None,\n    max_val: Optional[float] = None,\n) -> float:\n    \"\"\"Prompt for float input with validation and retry.\n\n    Args:\n        message: Prompt message\n        default: Default value (optional)\n        min_val: Minimum allowed value (optional)\n        max_val: Maximum allowed value (optional)\n\n    Returns:\n        Validated float value\n    \"\"\"\n    while True:\n        # Build prompt with default\n        if default is not None:\n            prompt_text = f\"{message} [{default}]\"\n        else:\n            prompt_text = message\n\n        # Get input\n        user_input = Prompt.ask(prompt_text, default=str(default) if default is not None else None)\n\n        # Try to parse as float. Empty input with no default comes back as None,\n        # which raises TypeError rather than ValueError, so catch both and retry.\n        try:\n            value = float(user_input)\n        except (TypeError, ValueError):\n            console.print(f\"[red]✗ Invalid input. 
Please enter a valid number.[/red]\")\n            console.print()\n            continue\n\n        # Validate range\n        if min_val is not None and value < min_val:\n            console.print(f\"[red]✗ Value must be at least {min_val}[/red]\")\n            console.print()\n            continue\n\n        if max_val is not None and value > max_val:\n            console.print(f\"[red]✗ Value must be at most {max_val}[/red]\")\n            console.print()\n            continue\n\n        # All validations passed\n        return value\n\n\ndef prompt_choice_with_retry(\n    message: str,\n    choices: List[str],\n    default: Optional[str] = None,\n) -> str:\n    \"\"\"Prompt for choice input with validation and retry.\n\n    Args:\n        message: Prompt message\n        choices: List of valid choices\n        default: Default choice (optional)\n\n    Returns:\n        Selected choice\n    \"\"\"\n    while True:\n        # Get input\n        user_input = Prompt.ask(message, default=default)\n\n        # Validate choice\n        if user_input not in choices:\n            console.print(f\"[red]✗ Invalid choice. 
Please select from: {', '.join(choices)}[/red]\")\n            console.print()\n            continue\n\n        return user_input\n\n\ndef prompt_int_list_with_retry(\n    message: str,\n    default: Optional[str] = None,\n    min_val: Optional[int] = None,\n    max_val: Optional[int] = None,\n    validator: Optional[Callable[[List[int]], tuple[bool, Optional[str]]]] = None,\n) -> List[int]:\n    \"\"\"Prompt for comma-separated integer list with validation and retry.\n\n    Args:\n        message: Prompt message\n        default: Default value as string (e.g., \"0,1,2,3\")\n        min_val: Minimum allowed value for each integer (optional)\n        max_val: Maximum allowed value for each integer (optional)\n        validator: Custom validation function that returns (is_valid, error_message) (optional)\n\n    Returns:\n        List of validated integers\n    \"\"\"\n    while True:\n        # Get input\n        user_input = Prompt.ask(message, default=default)\n\n        # Clean input: support Chinese comma and spaces\n        user_input_cleaned = user_input.replace(\"，\", \",\").replace(\" \", \"\")\n\n        # Try to parse as integers\n        try:\n            values = [int(x.strip()) for x in user_input_cleaned.split(\",\") if x.strip()]\n        except ValueError:\n            console.print(f\"[red]✗ Invalid format. 
Please enter numbers separated by commas.[/red]\")\n            console.print()\n            continue\n\n        # Validate each value's range\n        invalid_values = []\n        for value in values:\n            if min_val is not None and value < min_val:\n                invalid_values.append(value)\n            elif max_val is not None and value > max_val:\n                invalid_values.append(value)\n\n        if invalid_values:\n            if min_val is not None and max_val is not None:\n                console.print(f\"[red]✗ Invalid value(s): {invalid_values}[/red]\")\n                console.print(f\"[yellow]Valid range: {min_val}-{max_val}[/yellow]\")\n            elif min_val is not None:\n                console.print(f\"[red]✗ Value(s) must be at least {min_val}: {invalid_values}[/red]\")\n            elif max_val is not None:\n                console.print(f\"[red]✗ Value(s) must be at most {max_val}: {invalid_values}[/red]\")\n            console.print()\n            continue\n\n        # Custom validation\n        if validator is not None:\n            is_valid, error_msg = validator(values)\n            if not is_valid:\n                console.print(f\"[red]✗ {error_msg}[/red]\")\n                console.print()\n                continue\n\n        # All validations passed\n        return values\n"
  },
  {
    "path": "kt-kernel/python/cli/utils/kv_cache_calculator.py",
    "content": "#!/usr/bin/env python3\n\"\"\"\nKV Cache Size Calculator for SGLang\n\nThis script calculates the KV cache size in GB for a given model and number of tokens.\nIt follows the same logic as in sglang/srt/model_executor/model_runner.py\n\"\"\"\n\nimport os\nimport sys\nimport torch\nfrom transformers import AutoConfig\n\n# Add sglang to path\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), \"python\"))\n\nfrom sglang.srt.configs.model_config import ModelConfig, is_deepseek_nsa, get_nsa_index_head_dim\nfrom sglang.srt.mem_cache.memory_pool import NSATokenToKVPool\n\n\ndef get_dtype_bytes(dtype_str: str) -> int:\n    \"\"\"Get the number of bytes for a given dtype string.\"\"\"\n    dtype_map = {\n        \"float32\": 4,\n        \"float16\": 2,\n        \"bfloat16\": 2,\n        \"float8_e4m3fn\": 1,\n        \"float8_e5m2\": 1,\n        \"auto\": 2,  # Usually defaults to bfloat16\n    }\n    return dtype_map.get(dtype_str, 2)\n\n\ndef get_kv_size_gb(\n    model_path: str,\n    max_total_tokens: int,\n    tp: int = 1,\n    dtype: str = \"auto\",\n    verbose: bool = True,\n) -> dict:\n    \"\"\"\n    Calculate the KV cache size in GB for a given model and number of tokens.\n\n    Args:\n        model_path: Path to the model\n        max_total_tokens: Maximum number of tokens to cache\n        tp: Tensor parallelism size\n        dtype: Data type for KV cache (auto, float16, bfloat16, float8_e4m3fn, etc.)\n        verbose: Whether to print detailed information\n\n    Returns:\n        dict: Dictionary containing calculation details\n    \"\"\"\n    # Load model config\n    model_config = ModelConfig(model_path, dtype=dtype)\n    hf_config = model_config.hf_config\n\n    # Determine dtype bytes\n    dtype_bytes = get_dtype_bytes(dtype)\n    if dtype == \"auto\":\n        # Auto dtype usually becomes bfloat16\n        dtype_bytes = 2\n\n    # Number of layers\n    num_layers = model_config.num_attention_layers\n\n    # Check if it's MLA 
(Multi-head Latent Attention) model\n    is_mla = hasattr(model_config, \"attention_arch\") and model_config.attention_arch.name == \"MLA\"\n\n    result = {\n        \"model_path\": model_path,\n        \"max_total_tokens\": max_total_tokens,\n        \"tp\": tp,\n        \"dtype\": dtype,\n        \"dtype_bytes\": dtype_bytes,\n        \"num_layers\": num_layers,\n        \"is_mla\": is_mla,\n    }\n\n    if is_mla:\n        # MLA models (DeepSeek-V2/V3, MiniCPM3, etc.)\n        kv_lora_rank = model_config.kv_lora_rank\n        qk_rope_head_dim = model_config.qk_rope_head_dim\n\n        # Calculate cell size (per token)\n        cell_size = (kv_lora_rank + qk_rope_head_dim) * num_layers * dtype_bytes\n\n        result.update(\n            {\n                \"kv_lora_rank\": kv_lora_rank,\n                \"qk_rope_head_dim\": qk_rope_head_dim,\n                \"cell_size_bytes\": cell_size,\n            }\n        )\n\n        # Check if it's NSA (Native Sparse Attention) model\n        if is_deepseek_nsa(hf_config):\n            index_head_dim = get_nsa_index_head_dim(hf_config)\n            indexer_size_per_token = index_head_dim + index_head_dim // NSATokenToKVPool.quant_block_size * 4\n            indexer_dtype_bytes = torch._utils._element_size(NSATokenToKVPool.index_k_with_scale_buffer_dtype)\n            indexer_cell_size = indexer_size_per_token * num_layers * indexer_dtype_bytes\n            cell_size += indexer_cell_size\n\n            result.update(\n                {\n                    \"is_nsa\": True,\n                    \"index_head_dim\": index_head_dim,\n                    \"indexer_cell_size_bytes\": indexer_cell_size,\n                    \"total_cell_size_bytes\": cell_size,\n                }\n            )\n        else:\n            result[\"is_nsa\"] = False\n    else:\n        # Standard MHA models\n        num_kv_heads = model_config.get_num_kv_heads(tp)\n        head_dim = model_config.head_dim\n        v_head_dim = 
model_config.v_head_dim\n\n        # Calculate cell size (per token)\n        cell_size = num_kv_heads * (head_dim + v_head_dim) * num_layers * dtype_bytes\n\n        result.update(\n            {\n                \"num_kv_heads\": num_kv_heads,\n                \"head_dim\": head_dim,\n                \"v_head_dim\": v_head_dim,\n                \"cell_size_bytes\": cell_size,\n            }\n        )\n\n    # Calculate total KV cache size\n    total_size_bytes = max_total_tokens * cell_size\n    total_size_gb = total_size_bytes / (1024**3)\n\n    # For MHA models with separate K and V buffers\n    if not is_mla:\n        k_size_bytes = max_total_tokens * num_kv_heads * head_dim * num_layers * dtype_bytes\n        v_size_bytes = max_total_tokens * num_kv_heads * v_head_dim * num_layers * dtype_bytes\n        k_size_gb = k_size_bytes / (1024**3)\n        v_size_gb = v_size_bytes / (1024**3)\n\n        result.update(\n            {\n                \"k_size_gb\": k_size_gb,\n                \"v_size_gb\": v_size_gb,\n            }\n        )\n\n    result.update(\n        {\n            \"total_size_bytes\": total_size_bytes,\n            \"total_size_gb\": total_size_gb,\n        }\n    )\n\n    if verbose:\n        print(f\"Model: {model_path}\")\n        print(f\"Tokens: {max_total_tokens}, TP: {tp}, Dtype: {dtype}\")\n        print(f\"Architecture: {'MLA' if is_mla else 'MHA'}\")\n        print(f\"Layers: {num_layers}\")\n\n        if is_mla:\n            print(f\"KV LoRA Rank: {kv_lora_rank}, QK RoPE Head Dim: {qk_rope_head_dim}\")\n            if result.get(\"is_nsa\"):\n                print(f\"NSA Index Head Dim: {index_head_dim}\")\n                print(\n                    f\"Cell size: {cell_size} bytes (Main: {result['cell_size_bytes']}, Indexer: {result['indexer_cell_size_bytes']})\"\n                )\n            else:\n                print(f\"Cell size: {cell_size} bytes\")\n        else:\n            print(f\"KV Heads: {num_kv_heads}, Head Dim: 
{head_dim}, V Head Dim: {v_head_dim}\")\n            print(f\"Cell size: {cell_size} bytes\")\n            print(f\"K size: {k_size_gb:.2f} GB, V size: {v_size_gb:.2f} GB\")\n\n        print(f\"Total KV Cache Size: {total_size_gb:.2f} GB\")\n\n    return result\n\n\ndef main():\n    import argparse\n\n    parser = argparse.ArgumentParser(description=\"Calculate KV cache size for a model\")\n    parser.add_argument(\"model_path\", help=\"Path to the model\")\n    parser.add_argument(\"max_total_tokens\", type=int, help=\"Maximum number of tokens\")\n    parser.add_argument(\"--tp\", type=int, default=1, help=\"Tensor parallelism size\")\n    parser.add_argument(\"--dtype\", type=str, default=\"auto\", help=\"Data type (auto, float16, bfloat16, etc.)\")\n    parser.add_argument(\"--quiet\", action=\"store_true\", help=\"Suppress verbose output\")\n\n    args = parser.parse_args()\n\n    result = get_kv_size_gb(\n        args.model_path,\n        args.max_total_tokens,\n        tp=args.tp,\n        dtype=args.dtype,\n        verbose=not args.quiet,\n    )\n\n    if args.quiet:\n        print(f\"{result['total_size_gb']:.2f}\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "kt-kernel/python/cli/utils/model_discovery.py",
    "content": "\"\"\"\nModel Discovery Utilities\n\nShared functions for discovering and registering new models across different commands.\n\"\"\"\n\nfrom typing import List, Optional, Tuple\nfrom pathlib import Path\nfrom rich.console import Console\n\nfrom kt_kernel.cli.utils.model_scanner import (\n    discover_models,\n    scan_directory_for_models,\n    ScannedModel,\n)\nfrom kt_kernel.cli.utils.user_model_registry import UserModelRegistry, UserModel\n\n\nconsole = Console()\n\n\ndef discover_and_register_global(\n    min_size_gb: float = 2.0, max_depth: int = 6, show_progress: bool = True, lang: str = \"en\"\n) -> Tuple[int, int, List[UserModel]]:\n    \"\"\"\n    Perform global model discovery and register new models.\n\n    Args:\n        min_size_gb: Minimum model size in GB\n        max_depth: Maximum search depth\n        show_progress: Whether to show progress messages\n        lang: Language for messages (\"en\" or \"zh\")\n\n    Returns:\n        Tuple of (total_found, new_found, registered_models)\n    \"\"\"\n    registry = UserModelRegistry()\n\n    if show_progress:\n        if lang == \"zh\":\n            console.print(\"[dim]正在扫描系统中的模型权重，这可能需要30-60秒...[/dim]\")\n        else:\n            console.print(\"[dim]Scanning system for model weights, this may take 30-60 seconds...[/dim]\")\n\n    # Global scan\n    all_models = discover_models(mount_points=None, min_size_gb=min_size_gb, max_depth=max_depth)\n\n    # Filter out existing models\n    new_models = []\n    for model in all_models:\n        if not registry.find_by_path(model.path):\n            new_models.append(model)\n\n    # Register new models\n    registered = []\n    for model in new_models:\n        user_model = _create_and_register_model(registry, model)\n        if user_model:\n            registered.append(user_model)\n\n    return len(all_models), len(new_models), registered\n\n\ndef discover_and_register_path(\n    path: str,\n    min_size_gb: float = 2.0,\n    existing_paths: 
Optional[set] = None,\n    show_progress: bool = True,\n    lang: str = \"en\",\n) -> Tuple[int, int, List[UserModel]]:\n    \"\"\"\n    Discover models in a specific path and register new ones.\n\n    Args:\n        path: Directory path to scan\n        min_size_gb: Minimum model file size in GB\n        existing_paths: Set of already discovered paths in this session (optional)\n        show_progress: Whether to show progress messages\n        lang: Language for messages (\"en\" or \"zh\")\n\n    Returns:\n        Tuple of (total_found, new_found, registered_models)\n    \"\"\"\n    registry = UserModelRegistry()\n\n    if show_progress:\n        if lang == \"zh\":\n            console.print(f\"[dim]正在扫描 {path}...[/dim]\")\n        else:\n            console.print(f\"[dim]Scanning {path}...[/dim]\")\n\n    # Scan directory\n    model_info = scan_directory_for_models(path, min_file_size_gb=min_size_gb)\n\n    if not model_info:\n        return 0, 0, []\n\n    # Convert to ScannedModel and filter\n    new_models = []\n    for dir_path, (format_type, size_bytes, file_count, files) in model_info.items():\n        # Check if already in registry\n        if registry.find_by_path(dir_path):\n            continue\n\n        # Check if already discovered in this session\n        if existing_paths and dir_path in existing_paths:\n            continue\n\n        model = ScannedModel(\n            path=dir_path, format=format_type, size_bytes=size_bytes, file_count=file_count, files=files\n        )\n        new_models.append(model)\n\n    # Register new models\n    registered = []\n    for model in new_models:\n        user_model = _create_and_register_model(registry, model)\n        if user_model:\n            registered.append(user_model)\n\n    return len(model_info), len(new_models), registered\n\n\ndef _create_and_register_model(registry: UserModelRegistry, scanned_model: ScannedModel) -> Optional[UserModel]:\n    \"\"\"\n    Create a UserModel from ScannedModel and 
register it.\n\n    Handles name conflicts by suggesting a unique name (e.g., model-2, model-3).\n    Automatically detects repo_id from README.md YAML frontmatter.\n    Automatically detects and caches MoE information for safetensors models.\n\n    Args:\n        registry: UserModelRegistry instance\n        scanned_model: ScannedModel to register\n\n    Returns:\n        Registered UserModel or None if failed\n    \"\"\"\n    # Use suggest_name to get a unique name (adds -2, -3, etc. if needed)\n    unique_name = registry.suggest_name(scanned_model.folder_name)\n\n    user_model = UserModel(name=unique_name, path=scanned_model.path, format=scanned_model.format)\n\n    # Auto-detect repo_id from README.md (only YAML frontmatter)\n    try:\n        from kt_kernel.cli.utils.repo_detector import detect_repo_for_model\n\n        repo_info = detect_repo_for_model(scanned_model.path)\n        if repo_info:\n            repo_id, repo_type = repo_info\n            user_model.repo_id = repo_id\n            user_model.repo_type = repo_type\n    except Exception:\n        # Silently continue if detection fails\n        pass\n\n    # Auto-detect MoE information for safetensors models\n    if scanned_model.format == \"safetensors\":\n        try:\n            from kt_kernel.cli.utils.analyze_moe_model import analyze_moe_model\n\n            moe_result = analyze_moe_model(scanned_model.path, use_cache=True)\n            if moe_result and moe_result.get(\"is_moe\"):\n                user_model.is_moe = True\n                user_model.moe_num_experts = moe_result.get(\"num_experts\")\n                user_model.moe_num_experts_per_tok = moe_result.get(\"num_experts_per_tok\")\n            else:\n                user_model.is_moe = False\n        except Exception:\n            # Silently continue if MoE detection fails\n            # is_moe will remain None\n            pass\n\n    try:\n        registry.add_model(user_model)\n        return user_model\n    except Exception:\n    
    # Should not happen since we used suggest_name, but handle gracefully\n        return None\n\n\ndef format_discovery_summary(\n    total_found: int,\n    new_found: int,\n    registered: List[UserModel],\n    lang: str = \"en\",\n    show_models: bool = True,\n    max_show: int = 10,\n) -> None:\n    \"\"\"\n    Print formatted discovery summary.\n\n    Args:\n        total_found: Total models found\n        new_found: New models found\n        registered: List of registered UserModel objects\n        lang: Language (\"en\" or \"zh\")\n        show_models: Whether to show model list\n        max_show: Maximum models to show\n    \"\"\"\n    console.print()\n\n    if new_found == 0:\n        if total_found > 0:\n            if lang == \"zh\":\n                console.print(f\"[green]✓[/green] 扫描完成：找到 {total_found} 个模型，所有模型均已在列表中\")\n            else:\n                console.print(f\"[green]✓[/green] Scan complete: found {total_found} models, all already in the list\")\n        else:\n            if lang == \"zh\":\n                console.print(\"[yellow]未找到模型[/yellow]\")\n            else:\n                console.print(\"[yellow]No models found[/yellow]\")\n        return\n\n    # Show summary\n    if lang == \"zh\":\n        console.print(f\"[green]✓[/green] 扫描完成：找到 {total_found} 个模型，其中 {new_found} 个为新模型\")\n    else:\n        console.print(f\"[green]✓[/green] Scan complete: found {total_found} models, {new_found} are new\")\n\n    # Show registered count\n    if len(registered) > 0:\n        if lang == \"zh\":\n            console.print(f\"[green]✓[/green] 成功添加 {len(registered)} 个新模型到列表\")\n        else:\n            console.print(f\"[green]✓[/green] Successfully added {len(registered)} new models to list\")\n\n    # Show model list\n    if show_models and registered:\n        console.print()\n        if lang == \"zh\":\n            console.print(f\"[dim]新发现的模型（前{max_show}个）:[/dim]\")\n        else:\n            console.print(f\"[dim]Newly discovered models 
(first {max_show}):[/dim]\")\n\n        for i, model in enumerate(registered[:max_show], 1):\n            # Size lookup from the ScannedModel is not implemented yet;\n            # for now show only name, format, and path\n            console.print(f\"  {i}. {model.name} ({model.format})\")\n            console.print(f\"     [dim]{model.path}[/dim]\")\n\n        if len(registered) > max_show:\n            remaining = len(registered) - max_show\n            if lang == \"zh\":\n                console.print(f\"  [dim]... 还有 {remaining} 个新模型[/dim]\")\n            else:\n                console.print(f\"  [dim]... and {remaining} more new models[/dim]\")\n"
  },
  {
    "path": "kt-kernel/python/cli/utils/model_registry.py",
    "content": "\"\"\"\nModel registry for kt-cli.\n\nProvides a registry of supported models with fuzzy matching capabilities.\n\"\"\"\n\nimport re\nfrom dataclasses import dataclass, field\nfrom pathlib import Path\nfrom typing import Callable, Optional\n\nimport yaml\n\nfrom kt_kernel.cli.config.settings import get_settings\n\n\n@dataclass\nclass ModelInfo:\n    \"\"\"Information about a supported model.\"\"\"\n\n    name: str\n    hf_repo: str\n    aliases: list[str] = field(default_factory=list)\n    type: str = \"moe\"  # moe, dense\n    gpu_vram_gb: float = 0\n    cpu_ram_gb: float = 0\n    default_params: dict = field(default_factory=dict)\n    description: str = \"\"\n    description_zh: str = \"\"\n    max_tensor_parallel_size: Optional[int] = None  # Maximum tensor parallel size for this model\n\n\n# Built-in model registry\nBUILTIN_MODELS: list[ModelInfo] = [\n    ModelInfo(\n        name=\"DeepSeek-V3-0324\",\n        hf_repo=\"deepseek-ai/DeepSeek-V3-0324\",\n        aliases=[\"deepseek-v3-0324\", \"deepseek-v3\", \"dsv3\", \"deepseek3\", \"v3-0324\"],\n        type=\"moe\",\n        default_params={\n            \"kt-num-gpu-experts\": 1,\n            \"attention-backend\": \"triton\",\n            \"disable-shared-experts-fusion\": True,\n            \"kt-method\": \"AMXINT4\",\n        },\n        description=\"DeepSeek V3-0324 685B MoE model (March 2025, improved benchmarks)\",\n        description_zh=\"DeepSeek V3-0324 685B MoE 模型（2025年3月，改进的基准测试）\",\n    ),\n    ModelInfo(\n        name=\"DeepSeek-V3.2\",\n        hf_repo=\"deepseek-ai/DeepSeek-V3.2\",\n        aliases=[\"deepseek-v3.2\", \"dsv3.2\", \"deepseek3.2\", \"v3.2\"],\n        type=\"moe\",\n        default_params={\n            \"kt-method\": \"FP8\",\n            \"kt-gpu-prefill-token-threshold\": 4096,\n            \"attention-backend\": \"flashinfer\",\n            \"fp8-gemm-backend\": \"triton\",\n            \"max-total-tokens\": 100000,\n            \"max-running-requests\": 
16,\n            \"chunked-prefill-size\": 32768,\n            \"mem-fraction-static\": 0.80,\n            \"watchdog-timeout\": 3000,\n            \"served-model-name\": \"DeepSeek-V3.2\",\n            \"disable-shared-experts-fusion\": True,\n        },\n        description=\"DeepSeek V3.2 671B MoE model (latest)\",\n        description_zh=\"DeepSeek V3.2 671B MoE 模型（最新）\",\n    ),\n    ModelInfo(\n        name=\"DeepSeek-R1-0528\",\n        hf_repo=\"deepseek-ai/DeepSeek-R1-0528\",\n        aliases=[\"deepseek-r1-0528\", \"deepseek-r1\", \"dsr1\", \"r1\", \"r1-0528\"],\n        type=\"moe\",\n        default_params={\n            \"kt-num-gpu-experts\": 1,\n            \"attention-backend\": \"triton\",\n            \"disable-shared-experts-fusion\": True,\n            \"kt-method\": \"AMXINT4\",\n        },\n        description=\"DeepSeek R1-0528 reasoning model (May 2025, improved reasoning depth)\",\n        description_zh=\"DeepSeek R1-0528 推理模型（2025年5月，改进的推理深度）\",\n    ),\n    ModelInfo(\n        name=\"Kimi-K2-Thinking\",\n        hf_repo=\"moonshotai/Kimi-K2-Thinking\",\n        aliases=[\"kimi-k2-thinking\", \"kimi-thinking\", \"k2-thinking\", \"kimi\", \"k2\"],\n        type=\"moe\",\n        default_params={\n            \"kt-method\": \"RAWINT4\",\n            \"kt-gpu-prefill-token-threshold\": 400,\n            \"attention-backend\": \"flashinfer\",\n            \"max-total-tokens\": 100000,\n            \"max-running-requests\": 16,\n            \"chunked-prefill-size\": 32768,\n            \"mem-fraction-static\": 0.80,\n            \"watchdog-timeout\": 3000,\n            \"served-model-name\": \"Kimi-K2-Thinking\",\n            \"disable-shared-experts-fusion\": True,\n        },\n        description=\"Moonshot Kimi K2 Thinking MoE model\",\n        description_zh=\"月之暗面 Kimi K2 Thinking MoE 模型\",\n    ),\n    ModelInfo(\n        name=\"MiniMax-M2\",\n        hf_repo=\"MiniMaxAI/MiniMax-M2\",\n        aliases=[\"minimax-m2\", \"m2\"],\n        
type=\"moe\",\n        default_params={\n            \"kt-method\": \"FP8\",\n            \"kt-gpu-prefill-token-threshold\": 4096,\n            \"attention-backend\": \"flashinfer\",\n            \"fp8-gemm-backend\": \"triton\",\n            \"max-total-tokens\": 100000,\n            \"max-running-requests\": 16,\n            \"chunked-prefill-size\": 32768,\n            \"mem-fraction-static\": 0.80,\n            \"watchdog-timeout\": 3000,\n            \"served-model-name\": \"MiniMax-M2\",\n            \"disable-shared-experts-fusion\": True,\n            \"tool-call-parser\": \"minimax-m2\",\n            \"reasoning-parser\": \"minimax-append-think\",\n        },\n        description=\"MiniMax M2 MoE model\",\n        description_zh=\"MiniMax M2 MoE 模型\",\n        max_tensor_parallel_size=4,  # M2 only supports up to 4-way tensor parallelism\n    ),\n    ModelInfo(\n        name=\"MiniMax-M2.1\",\n        hf_repo=\"MiniMaxAI/MiniMax-M2.1\",\n        aliases=[\"minimax-m2.1\", \"m2.1\"],\n        type=\"moe\",\n        default_params={\n            \"kt-method\": \"FP8\",\n            \"kt-gpu-prefill-token-threshold\": 4096,\n            \"attention-backend\": \"flashinfer\",\n            \"fp8-gemm-backend\": \"triton\",\n            \"max-total-tokens\": 100000,\n            \"max-running-requests\": 16,\n            \"chunked-prefill-size\": 32768,\n            \"mem-fraction-static\": 0.80,\n            \"watchdog-timeout\": 3000,\n            \"served-model-name\": \"MiniMax-M2.1\",\n            \"disable-shared-experts-fusion\": True,\n            \"tool-call-parser\": \"minimax-m2\",\n            \"reasoning-parser\": \"minimax-append-think\",\n        },\n        description=\"MiniMax M2.1 MoE model (enhanced multi-language programming)\",\n        description_zh=\"MiniMax M2.1 MoE 模型（增强多语言编程能力）\",\n        max_tensor_parallel_size=4,  # M2.1 only supports up to 4-way tensor parallelism\n    ),\n]\n\n\nclass ModelRegistry:\n    \"\"\"Registry of 
supported models with fuzzy matching.\"\"\"\n\n    def __init__(self):\n        \"\"\"Initialize the model registry.\"\"\"\n        self._models: dict[str, ModelInfo] = {}\n        self._aliases: dict[str, str] = {}\n        self._load_builtin_models()\n        self._load_user_models()\n\n    def _load_builtin_models(self) -> None:\n        \"\"\"Load built-in models.\"\"\"\n        for model in BUILTIN_MODELS:\n            self._register(model)\n\n    def _load_user_models(self) -> None:\n        \"\"\"Load user-defined models from config.\"\"\"\n        settings = get_settings()\n        registry_file = settings.config_dir / \"registry.yaml\"\n\n        if registry_file.exists():\n            try:\n                with open(registry_file, \"r\", encoding=\"utf-8\") as f:\n                    data = yaml.safe_load(f) or {}\n\n                for name, info in data.get(\"models\", {}).items():\n                    model = ModelInfo(\n                        name=name,\n                        hf_repo=info.get(\"hf_repo\", \"\"),\n                        aliases=info.get(\"aliases\", []),\n                        type=info.get(\"type\", \"moe\"),\n                        gpu_vram_gb=info.get(\"gpu_vram_gb\", 0),\n                        cpu_ram_gb=info.get(\"cpu_ram_gb\", 0),\n                        default_params=info.get(\"default_params\", {}),\n                        description=info.get(\"description\", \"\"),\n                        description_zh=info.get(\"description_zh\", \"\"),\n                        max_tensor_parallel_size=info.get(\"max_tensor_parallel_size\"),\n                    )\n                    self._register(model)\n            except (yaml.YAMLError, OSError):\n                pass\n\n    def _register(self, model: ModelInfo) -> None:\n        \"\"\"Register a model.\"\"\"\n        self._models[model.name.lower()] = model\n\n        # Register aliases\n        for alias in model.aliases:\n            self._aliases[alias.lower()] = 
model.name.lower()\n\n    def get(self, name: str) -> Optional[ModelInfo]:\n        \"\"\"Get a model by exact name or alias.\"\"\"\n        name_lower = name.lower()\n\n        # Check direct match\n        if name_lower in self._models:\n            return self._models[name_lower]\n\n        # Check aliases\n        if name_lower in self._aliases:\n            return self._models[self._aliases[name_lower]]\n\n        return None\n\n    def search(self, query: str, limit: int = 10) -> list[ModelInfo]:\n        \"\"\"Search for models using fuzzy matching.\n\n        Args:\n            query: Search query\n            limit: Maximum number of results\n\n        Returns:\n            List of matching models, sorted by relevance\n        \"\"\"\n        query_lower = query.lower()\n        results: list[tuple[float, ModelInfo]] = []\n\n        for model in self._models.values():\n            score = self._match_score(query_lower, model)\n            if score > 0:\n                results.append((score, model))\n\n        # Sort by score descending\n        results.sort(key=lambda x: x[0], reverse=True)\n\n        return [model for _, model in results[:limit]]\n\n    def _match_score(self, query: str, model: ModelInfo) -> float:\n        \"\"\"Calculate match score for a model.\n\n        Returns a score between 0 and 1, where 1 is an exact match.\n        \"\"\"\n        # Check exact match\n        if query == model.name.lower():\n            return 1.0\n\n        # Check alias exact match\n        for alias in model.aliases:\n            if query == alias.lower():\n                return 0.95\n\n        # Check if query is contained in name\n        if query in model.name.lower():\n            return 0.8\n\n        # Check if query is contained in aliases\n        for alias in model.aliases:\n            if query in alias.lower():\n                return 0.7\n\n        # Check if query is contained in hf_repo\n        if query in model.hf_repo.lower():\n            
return 0.6\n\n        # Fuzzy matching - check if all query parts are present\n        query_parts = re.split(r\"[-_.\\s]\", query)\n        name_lower = model.name.lower()\n\n        matches = sum(1 for part in query_parts if part and part in name_lower)\n        if matches > 0:\n            return 0.5 * (matches / len(query_parts))\n\n        return 0.0\n\n    def list_all(self) -> list[ModelInfo]:\n        \"\"\"List all registered models.\"\"\"\n        return list(self._models.values())\n\n    def find_local_models(self, max_depth: int = 3) -> list[tuple[ModelInfo, Path]]:\n        \"\"\"Find models that are downloaded locally in any configured model path.\n\n        Args:\n            max_depth: Maximum depth to search within each model path (default: 3)\n\n        Returns:\n            List of (ModelInfo, path) tuples for local models\n        \"\"\"\n        settings = get_settings()\n        model_paths = settings.get_model_paths()\n        results = []\n\n        for model in self._models.values():\n            found = False\n            # Search in all configured model directories\n            for models_dir in model_paths:\n                if not models_dir.exists():\n                    continue\n\n                # Generate possible names to search for\n                possible_names = [\n                    model.name,\n                    model.name.lower(),\n                    model.hf_repo.split(\"/\")[-1],\n                    model.hf_repo.replace(\"/\", \"--\"),\n                ]\n\n                # Search recursively up to max_depth\n                for depth in range(max_depth):\n                    # Build glob pattern for current depth\n                    # depth=0: direct children, depth=1: grandchildren, etc.\n                    glob_pattern = \"*\" if depth > 0 else \"\"\n                    for _ in range(depth):\n                        glob_pattern = \"*/\" + glob_pattern if glob_pattern else \"*\"\n\n                    for name 
in possible_names:\n                        if depth == 0:\n                            # Direct children: models_dir / name\n                            search_paths = [models_dir / name]\n                        else:\n                            # Nested: use rglob to find directories matching the name\n                            search_paths = list(models_dir.rglob(name))\n\n                        for path in search_paths:\n                            if path.exists() and (path / \"config.json\").exists():\n                                results.append((model, path))\n                                found = True\n                                break\n\n                        if found:\n                            break\n\n                    if found:\n                        break\n\n                if found:\n                    break\n\n        return results\n\n\n# Global registry instance\n_registry: Optional[ModelRegistry] = None\n\n\ndef get_registry() -> ModelRegistry:\n    \"\"\"Get the global model registry instance.\"\"\"\n    global _registry\n    if _registry is None:\n        _registry = ModelRegistry()\n    return _registry\n\n\n# ============================================================================\n# Model-specific parameter computation functions\n# ============================================================================\n\n\ndef compute_deepseek_v3_gpu_experts(tensor_parallel_size: int, vram_per_gpu_gb: float) -> int:\n    per_gpu_gb = 16\n    if vram_per_gpu_gb < per_gpu_gb:\n        return int(0)\n    total_vram = int(tensor_parallel_size * (vram_per_gpu_gb - per_gpu_gb))\n\n    return total_vram // 3\n\n\ndef compute_kimi_k2_thinking_gpu_experts(tensor_parallel_size: int, vram_per_gpu_gb: float) -> int:\n    \"\"\"Compute kt-num-gpu-experts for Kimi K2 Thinking.\"\"\"\n    per_gpu_gb = 16\n    if vram_per_gpu_gb < per_gpu_gb:\n        return int(0)\n    total_vram = int(tensor_parallel_size * (vram_per_gpu_gb - 
per_gpu_gb))\n\n    return total_vram * 2 // 3\n\n\ndef compute_minimax_m2_gpu_experts(tensor_parallel_size: int, vram_per_gpu_gb: float) -> int:\n    \"\"\"Compute kt-num-gpu-experts for MiniMax M2/M2.1.\"\"\"\n    per_gpu_gb = 16\n    if vram_per_gpu_gb < per_gpu_gb:\n        return int(0)\n    total_vram = int(tensor_parallel_size * (vram_per_gpu_gb - per_gpu_gb))\n\n    return total_vram // 1\n\n\n# Model name to computation function mapping\nMODEL_COMPUTE_FUNCTIONS: dict[str, Callable[[int, float], int]] = {\n    \"DeepSeek-V3-0324\": compute_deepseek_v3_gpu_experts,\n    \"DeepSeek-V3.2\": compute_deepseek_v3_gpu_experts,  # Same as V3-0324\n    \"DeepSeek-R1-0528\": compute_deepseek_v3_gpu_experts,  # Same as V3-0324\n    \"Kimi-K2-Thinking\": compute_kimi_k2_thinking_gpu_experts,\n    \"MiniMax-M2\": compute_minimax_m2_gpu_experts,\n    \"MiniMax-M2.1\": compute_minimax_m2_gpu_experts,  # Same as M2\n}\n"
  },
  {
    "path": "kt-kernel/python/cli/utils/model_scanner.py",
    "content": "\"\"\"\nModel Scanner\n\nScans directories for model files (safetensors, gguf) and identifies models\n\"\"\"\n\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import List, Optional, Set, Tuple, Dict\nfrom collections import defaultdict\nimport os\nimport subprocess\nimport json\n\n\n@dataclass\nclass ScannedModel:\n    \"\"\"Temporary structure for scanned model information\"\"\"\n\n    path: str  # Absolute path to model directory\n    format: str  # \"safetensors\" | \"gguf\" | \"mixed\"\n    size_bytes: int  # Total size in bytes\n    file_count: int  # Number of model files\n    files: List[str]  # List of model file names\n\n    @property\n    def size_gb(self) -> float:\n        \"\"\"Get size in GB\"\"\"\n        return self.size_bytes / (1024**3)\n\n    @property\n    def folder_name(self) -> str:\n        \"\"\"Get the folder name (default model name)\"\"\"\n        return Path(self.path).name\n\n\nclass ModelScanner:\n    \"\"\"Scanner for discovering models in directory trees\"\"\"\n\n    def __init__(self, min_size_gb: float = 10.0):\n        \"\"\"\n        Initialize scanner\n\n        Args:\n            min_size_gb: Minimum folder size in GB to be considered a model\n        \"\"\"\n        self.min_size_bytes = int(min_size_gb * 1024**3)\n\n    def scan_directory(\n        self, base_path: Path, exclude_paths: Optional[Set[str]] = None\n    ) -> Tuple[List[ScannedModel], List[str]]:\n        \"\"\"\n        Scan directory tree for models\n\n        Args:\n            base_path: Root directory to scan\n            exclude_paths: Set of absolute paths to exclude from results\n\n        Returns:\n            Tuple of (valid_models, warnings)\n            - valid_models: List of ScannedModel instances\n            - warnings: List of warning messages\n        \"\"\"\n        if not base_path.exists():\n            raise ValueError(f\"Path does not exist: {base_path}\")\n\n        if not base_path.is_dir():\n     
       raise ValueError(f\"Path is not a directory: {base_path}\")\n\n        exclude_paths = exclude_paths or set()\n        results: List[ScannedModel] = []\n        warnings: List[str] = []\n\n        # Walk the directory tree\n        for root, dirs, files in os.walk(base_path):\n            root_path = Path(root).resolve()\n\n            # Skip if already registered\n            if str(root_path) in exclude_paths:\n                dirs[:] = []  # Don't descend into this directory\n                continue\n\n            # Check for model files\n            safetensors_files = [f for f in files if f.endswith(\".safetensors\")]\n            gguf_files = [f for f in files if f.endswith(\".gguf\")]\n\n            if not safetensors_files and not gguf_files:\n                continue  # No model files in this directory\n\n            # Calculate total size\n            model_files = safetensors_files + gguf_files\n            total_size = self._calculate_total_size(root_path, model_files)\n\n            # Check if size meets minimum threshold\n            if total_size < self.min_size_bytes:\n                continue  # Too small, but keep scanning subdirectories\n\n            # Detect format\n            if safetensors_files and gguf_files:\n                # Mixed format - issue warning\n                warnings.append(\n                    f\"Mixed format detected in {root_path}: \"\n                    f\"{len(safetensors_files)} safetensors + {len(gguf_files)} gguf files. 
\"\n                    \"Please separate into different folders and re-scan.\"\n                )\n                dirs[:] = []  # Don't descend into mixed format directories\n                continue\n\n            # Determine format\n            format_type = \"safetensors\" if safetensors_files else \"gguf\"\n\n            # Create scanned model\n            scanned = ScannedModel(\n                path=str(root_path),\n                format=format_type,\n                size_bytes=total_size,\n                file_count=len(model_files),\n                files=model_files,\n            )\n\n            results.append(scanned)\n\n            # Continue scanning subdirectories - they might also contain models\n            # Each subdirectory will be independently checked for size >= 10GB\n\n        return results, warnings\n\n    def scan_single_path(self, path: Path) -> Optional[ScannedModel]:\n        \"\"\"\n        Scan a single path for model files\n\n        Args:\n            path: Path to scan\n\n        Returns:\n            ScannedModel instance or None if not a valid model\n        \"\"\"\n        if not path.exists() or not path.is_dir():\n            return None\n\n        # Find model files\n        safetensors_files = list(path.glob(\"*.safetensors\"))\n        gguf_files = list(path.glob(\"*.gguf\"))\n\n        if not safetensors_files and not gguf_files:\n            return None\n\n        # Check for mixed format\n        if safetensors_files and gguf_files:\n            raise ValueError(\n                f\"Mixed format detected: {len(safetensors_files)} safetensors + \"\n                f\"{len(gguf_files)} gguf files. 
Please use a single format.\"\n            )\n\n        # Calculate size\n        model_files = [f.name for f in safetensors_files + gguf_files]\n        total_size = self._calculate_total_size(path, model_files)\n\n        # Determine format\n        format_type = \"safetensors\" if safetensors_files else \"gguf\"\n\n        return ScannedModel(\n            path=str(path.resolve()),\n            format=format_type,\n            size_bytes=total_size,\n            file_count=len(model_files),\n            files=model_files,\n        )\n\n    def _calculate_total_size(self, directory: Path, filenames: List[str]) -> int:\n        \"\"\"\n        Calculate total size of specified files in directory\n\n        Args:\n            directory: Directory containing the files\n            filenames: List of filenames to sum\n\n        Returns:\n            Total size in bytes\n        \"\"\"\n        total = 0\n        for filename in filenames:\n            file_path = directory / filename\n            if file_path.exists():\n                try:\n                    total += file_path.stat().st_size\n                except OSError:\n                    # File might be inaccessible, skip it\n                    pass\n        return total\n\n\n# Convenience functions\n\n\ndef scan_directory(\n    base_path: Path, min_size_gb: float = 10.0, exclude_paths: Optional[Set[str]] = None\n) -> Tuple[List[ScannedModel], List[str]]:\n    \"\"\"\n    Convenience function to scan a directory\n\n    Args:\n        base_path: Root directory to scan\n        min_size_gb: Minimum folder size in GB\n        exclude_paths: Set of paths to exclude\n\n    Returns:\n        Tuple of (models, warnings)\n    \"\"\"\n    scanner = ModelScanner(min_size_gb=min_size_gb)\n    return scanner.scan_directory(base_path, exclude_paths)\n\n\ndef scan_single_path(path: Path) -> Optional[ScannedModel]:\n    \"\"\"\n    Convenience function to scan a single path\n\n    Args:\n        path: Path to scan\n\n    
Returns:\n        ScannedModel or None\n    \"\"\"\n    scanner = ModelScanner()\n    return scanner.scan_single_path(path)\n\n\ndef format_size(size_bytes: int) -> str:\n    \"\"\"\n    Format size in bytes to human-readable string\n\n    Args:\n        size_bytes: Size in bytes\n\n    Returns:\n        Formatted string (e.g., \"42.3 GB\")\n    \"\"\"\n    for unit in [\"B\", \"KB\", \"MB\", \"GB\", \"TB\"]:\n        if size_bytes < 1024.0:\n            return f\"{size_bytes:.1f} {unit}\"\n        size_bytes /= 1024.0\n    return f\"{size_bytes:.1f} PB\"\n\n\n# ===== Fast Scanning with Find Command and Tree-based Root Detection =====\n\n\ndef find_files_fast(mount_point: str, pattern: str, max_depth: int = 6, timeout: int = 30) -> List[str]:\n    \"\"\"\n    Use find command to quickly locate files\n\n    Args:\n        mount_point: Starting directory\n        pattern: File pattern (e.g., \"config.json\", \"*.gguf\")\n        max_depth: Maximum directory depth (default: 6)\n        timeout: Command timeout in seconds\n\n    Returns:\n        List of absolute file paths\n    \"\"\"\n    try:\n        # Use shell=False for better security and handling of special characters in paths\n        cmd = [\"find\", mount_point, \"-maxdepth\", str(max_depth), \"-name\", pattern, \"-type\", \"f\"]\n        result = subprocess.run(\n            cmd,\n            stdout=subprocess.PIPE,\n            stderr=subprocess.DEVNULL,\n            text=True,\n            timeout=timeout,\n        )\n\n        # Return results even if returncode is non-zero (due to permission errors)\n        # As long as we got some output\n        if result.stdout:\n            return [line.strip() for line in result.stdout.strip().split(\"\\n\") if line.strip()]\n        return []\n    except (subprocess.TimeoutExpired, FileNotFoundError):\n        return []\n\n\ndef is_valid_model_directory(directory: Path, min_size_gb: float = 10.0) -> Tuple[bool, Optional[str]]:\n    \"\"\"\n    Check if a 
directory is a valid model directory\n\n    Args:\n        directory: Path to check\n        min_size_gb: Minimum size in GB\n\n    Returns:\n        (is_valid, model_type) where model_type is \"safetensors\", \"gguf\", or None\n    \"\"\"\n    if not directory.exists() or not directory.is_dir():\n        return False, None\n\n    safetensors_files = list(directory.glob(\"*.safetensors\"))\n    gguf_files = list(directory.glob(\"*.gguf\"))\n\n    # Determine model type; safetensors takes precedence when both formats are present\n    if safetensors_files:\n        model_type = \"safetensors\"\n        model_files = safetensors_files\n    elif gguf_files:\n        model_type = \"gguf\"\n        model_files = gguf_files\n    else:\n        return False, None\n\n    # Check size - only count model files (fast!)\n    total_size = 0\n    for f in model_files:\n        try:\n            total_size += f.stat().st_size\n        except OSError:\n            # File might be inaccessible, skip it\n            pass\n\n    size_gb = total_size / (1024**3)\n    if size_gb < min_size_gb:\n        return False, None\n\n    return True, model_type\n\n\ndef scan_all_models_fast(mount_points: List[str], min_size_gb: float = 10.0, max_depth: int = 6) -> List[str]:\n    \"\"\"\n    Fast scan for all model paths using find command\n\n    Args:\n        mount_points: List of mount points to scan\n        min_size_gb: Minimum model size in GB\n        max_depth: Maximum search depth (default: 6)\n\n    Returns:\n        List of valid model directory paths\n    \"\"\"\n    model_paths = set()\n\n    for mount in mount_points:\n        if not os.path.exists(mount):\n            continue\n\n        # Find all config.json files\n        config_files = find_files_fast(mount, \"config.json\", max_depth=max_depth)\n        for 
config_path in config_files:\n            model_dir = Path(config_path).parent\n            is_valid, _ = is_valid_model_directory(model_dir, min_size_gb)\n            if is_valid:\n                model_paths.add(str(model_dir.resolve()))\n\n        # Find all *.gguf files\n        gguf_files = find_files_fast(mount, \"*.gguf\", max_depth=max_depth)\n        for gguf_path in gguf_files:\n            model_dir = Path(gguf_path).parent\n            is_valid, _ = is_valid_model_directory(model_dir, min_size_gb)\n            if is_valid:\n                model_paths.add(str(model_dir.resolve()))\n\n    return sorted(model_paths)\n\n\ndef get_root_subdirs() -> List[str]:\n    \"\"\"\n    Get subdirectories of / that are worth scanning\n\n    Filters out system paths only\n\n    Returns:\n        List of directories to scan\n    \"\"\"\n    # System paths to exclude\n    excluded = {\n        \"dev\",\n        \"proc\",\n        \"sys\",\n        \"run\",\n        \"boot\",\n        \"tmp\",\n        \"usr\",\n        \"lib\",\n        \"lib64\",\n        \"bin\",\n        \"sbin\",\n        \"etc\",\n        \"opt\",\n        \"var\",\n        \"snap\",\n    }\n\n    scan_dirs = []\n\n    try:\n        for entry in os.scandir(\"/\"):\n            if not entry.is_dir():\n                continue\n\n            # Skip excluded paths\n            if entry.name in excluded:\n                continue\n\n            scan_dirs.append(entry.path)\n\n    except PermissionError:\n        pass\n\n    return sorted(scan_dirs)\n\n\ndef scan_directory_for_models(directory: str, min_file_size_gb: float = 2.0) -> Dict[str, tuple]:\n    \"\"\"\n    Scan a directory for models using find command with size filter\n\n    Uses find's -size filter (+2G by default) to locate only model files larger than 2 GiB\n\n    Args:\n        directory: Directory to scan\n        min_file_size_gb: Minimum individual file size in GB (default: 2.0)\n\n    Returns:\n        Dict mapping model_path -> (model_type, 
size_bytes, file_count, files)\n    \"\"\"\n    model_info = {}\n\n    # Convert GB to find's format (e.g., 2GB = +2G)\n    if min_file_size_gb >= 1.0:\n        size_filter = f\"+{int(min_file_size_gb)}G\"\n    else:\n        size_mb = int(min_file_size_gb * 1024)\n        size_filter = f\"+{size_mb}M\"\n\n    # 1. Find *.gguf files >= 2GB\n    gguf_cmd = [\"find\", directory, \"-name\", \"*.gguf\", \"-type\", \"f\", \"-size\", size_filter, \"-printf\", \"%p\\t%s\\n\"]\n    result = subprocess.run(gguf_cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True, timeout=120)\n\n    # Group by directory\n    gguf_dirs = defaultdict(list)\n    for line in result.stdout.strip().split(\"\\n\"):\n        if not line:\n            continue\n        parts = line.split(\"\\t\")\n        if len(parts) != 2:\n            continue\n        file_path, size_str = parts\n        file_path_obj = Path(file_path)\n        dir_path = str(file_path_obj.parent)\n        gguf_dirs[dir_path].append((file_path_obj.name, int(size_str)))\n\n    # Add all gguf directories\n    for dir_path, files in gguf_dirs.items():\n        total_size = sum(size for _, size in files)\n        model_info[dir_path] = (\"gguf\", total_size, len(files), [name for name, _ in files])\n\n    # 2. 
Find *.safetensors files >= 2GB\n    safetensors_cmd = [\"find\", directory, \"-name\", \"*.safetensors\", \"-type\", \"f\", \"-size\", size_filter, \"-printf\", \"%p\\t%s\\n\"]\n    result = subprocess.run(safetensors_cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True, timeout=120)\n\n    # Group by directory\n    safetensors_dirs = defaultdict(list)\n    for line in result.stdout.strip().split(\"\\n\"):\n        if not line:\n            continue\n        parts = line.split(\"\\t\")\n        if len(parts) != 2:\n            continue\n        file_path, size_str = parts\n        file_path_obj = Path(file_path)\n        dir_path = str(file_path_obj.parent)\n        safetensors_dirs[dir_path].append((file_path_obj.name, int(size_str)))\n\n    # 3. Check each safetensors directory for config.json\n    for dir_path, files in safetensors_dirs.items():\n        if os.path.exists(os.path.join(dir_path, \"config.json\")):\n            total_size = sum(size for _, size in files)\n            model_info[dir_path] = (\"safetensors\", total_size, len(files), [name for name, _ in files])\n\n    return model_info\n\n\ndef scan_all_models_with_info(\n    mount_points: Optional[List[str]] = None, min_size_gb: float = 10.0, max_depth: int = 6\n) -> Dict[str, tuple]:\n    \"\"\"\n    Fast scan with parallel directory scanning\n\n    Strategy:\n    1. Use provided directories or auto-detect root subdirectories\n    2. Scan each directory in parallel (one thread per directory)\n    3. 
Use find -size +2G to find large model files (larger than 2 GiB)\n\n    Args:\n        mount_points: Specific directories to scan, or None to auto-detect from / subdirs\n        min_size_gb: No longer used (kept for API compatibility)\n        max_depth: No longer used (kept for API compatibility)\n\n    Returns:\n        Dict mapping model_path -> (model_type, size_bytes, file_count, files)\n    \"\"\"\n    from concurrent.futures import ThreadPoolExecutor, as_completed\n\n    # Get directories to scan\n    if mount_points is None:\n        # Get root subdirectories (exclude system paths)\n        scan_dirs = get_root_subdirs()\n    else:\n        scan_dirs = mount_points\n\n    if not scan_dirs:\n        return {}\n\n    model_info = {}\n\n    # Scan each directory in parallel (max 8 concurrent)\n    # Use a 2 GiB per-file threshold to find model files\n    with ThreadPoolExecutor(max_workers=min(len(scan_dirs), 8)) as executor:\n        futures = {executor.submit(scan_directory_for_models, d, 2.0): d for d in scan_dirs}\n\n        for future in as_completed(futures):\n            try:\n                dir_results = future.result()\n                model_info.update(dir_results)\n            except Exception:\n                # Skip directories whose scan failed (e.g. find timed out)\n                pass\n\n    return model_info\n\n\ndef find_model_roots_from_paths(model_paths: List[str]) -> Tuple[List[str], Dict[str, int]]:\n    \"\"\"\n    Find optimal root paths from model paths using tree-based algorithm\n\n    Algorithm:\n    1. Build path tree with all intermediate paths\n    2. DFS to calculate f(x) = subtree sum (number of models in subtree)\n    3. 
Find roots where f(parent) = f(x) >= max(f(children))\n\n    Args:\n        model_paths: List of model directory paths\n\n    Returns:\n        (root_paths, subtree_sizes) where:\n        - root_paths: List of inferred root directories\n        - subtree_sizes: Dict mapping each root to number of models\n    \"\"\"\n    if not model_paths:\n        return [], {}\n\n    # 1. Build path set (including all intermediate paths)\n    all_paths = set()\n    model_set = set(model_paths)\n\n    for model_path in model_paths:\n        path = Path(model_path)\n        for i in range(1, len(path.parts) + 1):\n            all_paths.add(str(Path(*path.parts[:i])))\n\n    # 2. Build parent-child relationships\n    children_map = defaultdict(list)\n    for path in all_paths:\n        path_obj = Path(path)\n        if len(path_obj.parts) > 1:\n            parent = str(path_obj.parent)\n            if parent in all_paths:\n                children_map[parent].append(path)\n\n    # 3. DFS to calculate f(x) and max_child_f(x)\n    f = {}  # path -> subtree sum\n    max_child_f = {}  # path -> max(f(children))\n    visited = set()\n\n    def dfs(path: str) -> int:\n        if path in visited:\n            return f[path]\n        visited.add(path)\n\n        # Current node weight (1 if it's a model path, 0 otherwise)\n        weight = 1 if path in model_set else 0\n\n        # Recursively calculate children\n        children = children_map.get(path, [])\n        if not children:\n            # Leaf node\n            f[path] = weight\n            max_child_f[path] = 0\n            return weight\n\n        # Calculate f values for all children\n        children_f_values = [dfs(child) for child in children]\n\n        # Calculate f(x) and max_child_f(x)\n        f[path] = weight + sum(children_f_values)\n        max_child_f[path] = max(children_f_values) if children_f_values else 0\n\n        return f[path]\n\n    # Find top-level nodes (no parent in all_paths)\n    top_nodes = []\n    for 
path in all_paths:\n        parent = str(Path(path).parent)\n        if parent not in all_paths or parent == path:\n            top_nodes.append(path)\n\n    # Execute DFS from all top nodes\n    for top in top_nodes:\n        dfs(top)\n\n    # 4. Find root nodes: f(parent) = f(x) >= max(f(children))\n    # Note: Use >= instead of > to handle the case where a directory contains only one model\n    candidate_roots = []\n    for path in all_paths:\n        # Skip model paths themselves (leaf nodes in model tree)\n        if path in model_set:\n            continue\n\n        parent = str(Path(path).parent)\n\n        # Check condition: f(parent) = f(x) and f(x) >= max(f(children))\n        if parent in f and f.get(parent, 0) == f.get(path, 0):\n            if f.get(path, 0) >= max_child_f.get(path, 0) and f.get(path, 0) > 0:\n                candidate_roots.append(path)\n\n    # 5. Remove redundant roots (prefer deeper paths)\n    # If a root is an ancestor of another root with the same f value, remove it\n    roots = []\n    candidate_roots_sorted = sorted(candidate_roots, key=lambda p: -len(Path(p).parts))\n\n    for root in candidate_roots_sorted:\n        # Check if this root is a parent of any already selected root\n        is_redundant = False\n        for selected in roots:\n            if selected.startswith(root + \"/\"):\n                # selected is a child of root\n                # Only keep root if it has more models (shouldn't happen by algorithm)\n                if f.get(root, 0) == f.get(selected, 0):\n                    is_redundant = True\n                    break\n\n        if not is_redundant:\n            # Also filter out very shallow paths (< 3 levels)\n            if len(Path(root).parts) >= 3:\n                roots.append(root)\n\n    # Build subtree sizes for roots\n    subtree_sizes = {root: f.get(root, 0) for root in roots}\n\n    return sorted(roots), subtree_sizes\n\n\n@dataclass\nclass ModelRootInfo:\n    \"\"\"Information about a 
detected model root path\"\"\"\n\n    path: str\n    model_count: int\n    models: List[ScannedModel]\n\n\ndef discover_models(\n    mount_points: Optional[List[str]] = None, min_size_gb: float = 10.0, max_depth: int = 6\n) -> List[ScannedModel]:\n    \"\"\"\n    Discover all model directories on the system\n\n    Fast scan using find command to locate all models that meet the criteria\n\n    Args:\n        mount_points: List of mount points to scan (None = auto-detect)\n        min_size_gb: Minimum model size in GB (default: 10.0)\n        max_depth: Maximum search depth (default: 6)\n\n    Returns:\n        List of ScannedModel sorted by path\n    \"\"\"\n    # Auto-detect mount points if not provided\n    if mount_points is None:\n        mount_points = _get_mount_points()\n\n    # Fast scan with cached info (only scan once!)\n    model_info = scan_all_models_with_info(mount_points, min_size_gb, max_depth)\n\n    if not model_info:\n        return []\n\n    # Convert to ScannedModel objects\n    results = []\n    for model_path, (model_type, total_size, file_count, files) in model_info.items():\n        results.append(\n            ScannedModel(path=model_path, format=model_type, size_bytes=total_size, file_count=file_count, files=files)\n        )\n\n    # Sort by path\n    results.sort(key=lambda m: m.path)\n    return results\n\n\ndef _get_mount_points() -> List[str]:\n    \"\"\"\n    Get all valid mount points from /proc/mounts, filtering out system paths\n\n    Returns:\n        List of mount point paths suitable for model storage\n        (excludes root \"/\" to avoid scanning entire filesystem)\n    \"\"\"\n    mount_points = set()\n\n    # System paths to exclude (unlikely to contain model files)\n    excluded_paths = [\n        \"/snap/\",\n        \"/proc/\",\n        \"/sys/\",\n        \"/run/\",\n        \"/boot\",\n        \"/dev/\",\n        \"/usr\",\n        \"/lib\",\n        \"/lib64\",\n        \"/bin\",\n        \"/sbin\",\n        
\"/etc\",\n        \"/opt\",\n        \"/var\",\n        \"/tmp\",\n    ]\n\n    try:\n        with open(\"/proc/mounts\", \"r\") as f:\n            for line in f:\n                parts = line.split()\n                if len(parts) < 3:\n                    continue\n\n                device, mount_point, fs_type = parts[0], parts[1], parts[2]\n\n                # Filter out pseudo filesystems\n                pseudo_fs = {\n                    \"proc\",\n                    \"sysfs\",\n                    \"devpts\",\n                    \"tmpfs\",\n                    \"devtmpfs\",\n                    \"cgroup\",\n                    \"cgroup2\",\n                    \"pstore\",\n                    \"bpf\",\n                    \"tracefs\",\n                    \"debugfs\",\n                    \"hugetlbfs\",\n                    \"mqueue\",\n                    \"configfs\",\n                    \"securityfs\",\n                    \"fuse.gvfsd-fuse\",\n                    \"fusectl\",\n                    \"squashfs\",\n                    \"overlay\",  # snap packages\n                }\n\n                if fs_type in pseudo_fs:\n                    continue\n\n                # Skip root directory (too large to scan)\n                if mount_point == \"/\":\n                    continue\n\n                # Filter out system paths\n                if any(mount_point.startswith(x) for x in excluded_paths):\n                    continue\n\n                # Only include if it exists and is readable\n                if os.path.exists(mount_point) and os.access(mount_point, os.R_OK):\n                    mount_points.add(mount_point)\n\n        # If no mount points found, add common data directories\n        if not mount_points:\n            # Add /home if it exists and is not already a separate mount point\n            common_paths = [\"/home\", \"/data\", \"/mnt\"]\n            for path in common_paths:\n                if os.path.exists(path) and 
os.access(path, os.R_OK):\n                    mount_points.add(path)\n\n    except (FileNotFoundError, PermissionError):\n        # Fallback to common paths\n        mount_points = {\"/home\", \"/mnt\", \"/data\"}\n\n    return sorted(mount_points)\n"
  },
  {
    "path": "kt-kernel/python/cli/utils/model_table_builder.py",
    "content": "\"\"\"\nShared model table builders for consistent UI across commands.\n\nProvides reusable table construction functions for displaying models\nin kt model list, kt quant, kt run, etc.\n\"\"\"\n\nfrom typing import List, Optional, Tuple\nfrom pathlib import Path\nfrom rich.table import Table\nfrom rich.console import Console\nimport json\n\n\ndef format_model_size(model_path: Path, format_type: str) -> str:\n    \"\"\"Calculate and format model size.\"\"\"\n    from kt_kernel.cli.utils.model_scanner import format_size\n\n    try:\n        if format_type == \"safetensors\":\n            files = list(model_path.glob(\"*.safetensors\"))\n        elif format_type == \"gguf\":\n            files = list(model_path.glob(\"*.gguf\"))\n        else:\n            return \"[dim]-[/dim]\"\n\n        total_size = sum(f.stat().st_size for f in files if f.exists())\n        return format_size(total_size)\n    except Exception:\n        return \"[dim]-[/dim]\"\n\n\ndef format_repo_info(model) -> str:\n    \"\"\"Format repository information.\"\"\"\n    if model.repo_id:\n        repo_abbr = \"hf\" if model.repo_type == \"huggingface\" else \"ms\"\n        return f\"{repo_abbr}:{model.repo_id}\"\n    return \"[dim]-[/dim]\"\n\n\ndef format_sha256_status(model, status_map: dict) -> str:\n    \"\"\"Format SHA256 verification status.\"\"\"\n    return status_map.get(model.sha256_status or \"not_checked\", \"[dim]?[/dim]\")\n\n\ndef build_moe_gpu_table(\n    models: List, status_map: dict, show_index: bool = True, start_index: int = 1\n) -> Tuple[Table, List]:\n    \"\"\"\n    Build MoE GPU models table.\n\n    Args:\n        models: List of MoE GPU model objects\n        status_map: SHA256_STATUS_MAP for formatting status\n        show_index: Whether to show # column for selection (default: True)\n        start_index: Starting index number\n\n    Returns:\n        Tuple of (Table object, list of models in display order)\n    \"\"\"\n    table = Table(show_header=True, 
header_style=\"bold\", show_lines=False)\n\n    if show_index:\n        table.add_column(\"#\", justify=\"right\", style=\"cyan\", no_wrap=True)\n\n    table.add_column(\"Name\", style=\"cyan\", no_wrap=True)\n    table.add_column(\"Path\", style=\"dim\", overflow=\"fold\")\n    table.add_column(\"Total\", justify=\"right\")\n    table.add_column(\"Exps\", justify=\"center\", style=\"yellow\")\n    table.add_column(\"Act\", justify=\"center\", style=\"green\")\n    table.add_column(\"Repository\", style=\"dim\", overflow=\"fold\")\n    table.add_column(\"SHA256\", justify=\"center\")\n\n    displayed_models = []\n\n    for i, model in enumerate(models, start_index):\n        displayed_models.append(model)\n\n        # Calculate size\n        size_str = format_model_size(Path(model.path), \"safetensors\")\n\n        # MoE info\n        num_experts = str(model.moe_num_experts) if model.moe_num_experts else \"[dim]-[/dim]\"\n        num_active = str(model.moe_num_experts_per_tok) if model.moe_num_experts_per_tok else \"[dim]-[/dim]\"\n\n        # Repository and SHA256\n        repo_str = format_repo_info(model)\n        sha256_str = format_sha256_status(model, status_map)\n\n        row = []\n        if show_index:\n            row.append(str(i))\n\n        row.extend([model.name, model.path, size_str, num_experts, num_active, repo_str, sha256_str])\n\n        table.add_row(*row)\n\n    return table, displayed_models\n\n\ndef build_amx_table(\n    models: List,\n    status_map: dict = None,  # Kept for API compatibility but not used\n    show_index: bool = True,\n    start_index: int = 1,\n    show_linked_gpus: bool = False,\n    gpu_models: Optional[List] = None,\n) -> Tuple[Table, List]:\n    \"\"\"\n    Build AMX models table.\n\n    Note: AMX models are locally quantized, so no SHA256 verification column.\n\n    Args:\n        models: List of AMX model objects\n        status_map: (Unused - kept for API compatibility)\n        show_index: Whether to show # column 
for selection (default: True)\n        start_index: Starting index number\n        show_linked_gpus: Whether to show sub-rows for linked GPU models\n        gpu_models: List of GPU models (required if show_linked_gpus=True)\n\n    Returns:\n        Tuple of (Table object, list of models in display order)\n    \"\"\"\n    table = Table(show_header=True, header_style=\"bold\", show_lines=False)\n\n    if show_index:\n        table.add_column(\"#\", justify=\"right\", style=\"cyan\", no_wrap=True)\n\n    table.add_column(\"Name\", style=\"cyan\", no_wrap=True)\n    table.add_column(\"Path\", style=\"dim\", overflow=\"fold\")\n    table.add_column(\"Total\", justify=\"right\")\n    table.add_column(\"Method\", justify=\"center\", style=\"yellow\")\n    table.add_column(\"NUMA\", justify=\"center\", style=\"green\")\n    table.add_column(\"Source\", style=\"dim\", overflow=\"fold\")\n\n    # Build reverse map if needed\n    amx_used_by_gpu = {}\n    if show_linked_gpus and gpu_models:\n        for model in models:\n            if model.gpu_model_ids:\n                gpu_names = []\n                for gpu_id in model.gpu_model_ids:\n                    for gpu_model in gpu_models:\n                        if gpu_model.id == gpu_id:\n                            gpu_names.append(gpu_model.name)\n                            break\n                if gpu_names:\n                    amx_used_by_gpu[model.id] = gpu_names\n\n    displayed_models = []\n\n    for i, model in enumerate(models, start_index):\n        displayed_models.append(model)\n\n        # Calculate size\n        size_str = format_model_size(Path(model.path), \"safetensors\")\n\n        # Read metadata from config.json or UserModel fields\n        method_from_config = None\n        numa_from_config = None\n        try:\n            config_path = Path(model.path) / \"config.json\"\n            if config_path.exists():\n                with open(config_path, \"r\", encoding=\"utf-8\") as f:\n                    
config = json.load(f)\n                    amx_quant = config.get(\"amx_quantization\", {})\n                    if amx_quant.get(\"converted\"):\n                        method_from_config = amx_quant.get(\"method\")\n                        numa_from_config = amx_quant.get(\"numa_count\")\n        except Exception:\n            pass\n\n        # Priority: UserModel fields > config.json > ?\n        method_display = (\n            model.amx_quant_method.upper()\n            if model.amx_quant_method\n            else method_from_config.upper() if method_from_config else \"[dim]?[/dim]\"\n        )\n        numa_display = (\n            str(model.amx_numa_nodes)\n            if model.amx_numa_nodes\n            else str(numa_from_config) if numa_from_config else \"[dim]?[/dim]\"\n        )\n        source_display = model.amx_source_model or \"[dim]-[/dim]\"\n\n        row = []\n        if show_index:\n            row.append(str(i))\n\n        row.extend([model.name, model.path, size_str, method_display, numa_display, source_display])\n\n        table.add_row(*row)\n\n        # Add sub-row showing linked GPUs\n        if show_linked_gpus and model.id in amx_used_by_gpu:\n            gpu_list = amx_used_by_gpu[model.id]\n            gpu_names_str = \", \".join([f\"[dim]{name}[/dim]\" for name in gpu_list])\n            sub_row = []\n            if show_index:\n                sub_row.append(\"\")\n            sub_row.extend([f\"  [dim]↳ GPU: {gpu_names_str}[/dim]\", \"\", \"\", \"\", \"\", \"\"])\n            table.add_row(*sub_row, style=\"dim\")\n\n    return table, displayed_models\n\n\ndef build_gguf_table(\n    models: List, status_map: dict, show_index: bool = True, start_index: int = 1\n) -> Tuple[Table, List]:\n    \"\"\"\n    Build GGUF models table.\n\n    Args:\n        models: List of GGUF model objects\n        status_map: SHA256_STATUS_MAP for formatting status\n        show_index: Whether to show # column for selection (default: True)\n        
start_index: Starting index number\n\n    Returns:\n        Tuple of (Table object, list of models in display order)\n    \"\"\"\n    table = Table(show_header=True, header_style=\"bold\", show_lines=False)\n\n    if show_index:\n        table.add_column(\"#\", justify=\"right\", style=\"cyan\", no_wrap=True)\n\n    table.add_column(\"Name\", style=\"cyan\", no_wrap=True)\n    table.add_column(\"Path\", style=\"dim\", overflow=\"fold\")\n    table.add_column(\"Total\", justify=\"right\")\n    table.add_column(\"Repository\", style=\"dim\", overflow=\"fold\")\n    table.add_column(\"SHA256\", justify=\"center\")\n\n    displayed_models = []\n\n    for i, model in enumerate(models, start_index):\n        displayed_models.append(model)\n\n        # Calculate size\n        size_str = format_model_size(Path(model.path), \"gguf\")\n\n        # Repository and SHA256\n        repo_str = format_repo_info(model)\n        sha256_str = format_sha256_status(model, status_map)\n\n        row = []\n        if show_index:\n            row.append(str(i))\n\n        row.extend([model.name, model.path, size_str, repo_str, sha256_str])\n\n        table.add_row(*row)\n\n    return table, displayed_models\n"
  },
  {
    "path": "kt-kernel/python/cli/utils/model_verifier.py",
    "content": "\"\"\"\nModel Verifier\n\nSHA256 verification for model integrity\n\"\"\"\n\nimport hashlib\nimport requests\nimport os\nfrom pathlib import Path\nfrom typing import Dict, Any, Literal, Tuple\nfrom concurrent.futures import ProcessPoolExecutor, as_completed\n\n\ndef _compute_file_sha256(file_path: Path) -> Tuple[str, str, float]:\n    \"\"\"\n    Compute SHA256 for a single file (worker function for multiprocessing).\n\n    Args:\n        file_path: Path to the file\n\n    Returns:\n        Tuple of (filename, sha256_hash, file_size_mb)\n    \"\"\"\n    sha256_hash = hashlib.sha256()\n    file_size_mb = file_path.stat().st_size / (1024 * 1024)\n\n    # Read file in chunks to handle large files\n    with open(file_path, \"rb\") as f:\n        for byte_block in iter(lambda: f.read(8192 * 1024), b\"\"):  # 8MB chunks\n            sha256_hash.update(byte_block)\n\n    return file_path.name, sha256_hash.hexdigest(), file_size_mb\n\n\ndef check_huggingface_connectivity(timeout: int = 5) -> Tuple[bool, str]:\n    \"\"\"\n    Check if HuggingFace is accessible.\n\n    Args:\n        timeout: Connection timeout in seconds\n\n    Returns:\n        Tuple of (is_accessible, message)\n    \"\"\"\n    test_url = \"https://huggingface.co\"\n\n    try:\n        response = requests.head(test_url, timeout=timeout, allow_redirects=True)\n        if response.status_code < 500:  # 2xx, 3xx, 4xx are all considered \"accessible\"\n            return True, \"HuggingFace is accessible\"\n    except requests.exceptions.Timeout:\n        return False, f\"Connection to {test_url} timed out\"\n    except requests.exceptions.ConnectionError:\n        return False, f\"Cannot connect to {test_url}\"\n    except requests.exceptions.RequestException as e:\n        return False, f\"Connection error: {str(e)}\"\n\n    return False, \"Unknown connection error\"\n\n\ndef verify_model_integrity(\n    repo_type: Literal[\"huggingface\", \"modelscope\"],\n    repo_id: str,\n    local_dir: 
Path,\n    progress_callback=None,\n) -> Dict[str, Any]:\n    \"\"\"\n    Verify local model integrity against remote repository SHA256 hashes.\n\n    Verifies all important files:\n    - *.safetensors (weights)\n    - *.json (config files)\n    - *.py (custom model code)\n\n    Args:\n        repo_type: Type of repository (\"huggingface\" or \"modelscope\")\n        repo_id: Repository ID (e.g., \"deepseek-ai/DeepSeek-V3\")\n        local_dir: Local directory containing model files\n        progress_callback: Optional callback function(message: str) for progress updates\n\n    Returns:\n        Dictionary with verification results:\n        {\n            \"status\": \"passed\" | \"failed\" | \"error\",\n            \"files_checked\": int,\n            \"files_passed\": int,\n            \"files_failed\": [list of filenames],\n            \"error_message\": str (optional)\n        }\n    \"\"\"\n\n    def report_progress(msg: str):\n        \"\"\"Helper to report progress\"\"\"\n        if progress_callback:\n            progress_callback(msg)\n\n    try:\n        # Convert repo_type to platform format\n        platform = \"hf\" if repo_type == \"huggingface\" else \"ms\"\n\n        # 1. Fetch official SHA256 hashes from remote\n        report_progress(\"Fetching official SHA256 hashes from remote repository...\")\n        official_hashes = fetch_model_sha256(repo_id, platform)\n        report_progress(f\"✓ Fetched {len(official_hashes)} file hashes from remote\")\n\n        if not official_hashes:\n            return {\n                \"status\": \"error\",\n                \"files_checked\": 0,\n                \"files_passed\": 0,\n                \"files_failed\": [],\n                \"error_message\": f\"No verifiable files found in remote repository: {repo_id}\",\n            }\n\n        # 2. 
Calculate local SHA256 hashes with progress\n        report_progress(f\"Calculating SHA256 for local files...\")\n\n        # Get all local files matching the patterns\n        local_files = []\n        for pattern in [\"*.safetensors\", \"*.json\", \"*.py\"]:\n            local_files.extend([f for f in local_dir.glob(pattern) if f.is_file()])\n\n        if not local_files:\n            return {\n                \"status\": \"error\",\n                \"files_checked\": 0,\n                \"files_passed\": 0,\n                \"files_failed\": [],\n                \"error_message\": f\"No verifiable files found in local directory: {local_dir}\",\n            }\n\n        # Calculate hashes for all files\n        local_hashes = calculate_local_sha256(\n            local_dir,\n            file_pattern=\"*.safetensors\",  # Unused when files_list is provided\n            progress_callback=report_progress,\n            files_list=local_files,\n        )\n        report_progress(f\"✓ Calculated {len(local_hashes)} local file hashes\")\n\n        # 3. 
Compare hashes with progress\n        report_progress(f\"Comparing {len(official_hashes)} files...\")\n        files_failed = []\n        files_missing = []\n        files_passed = 0\n\n        for idx, (filename, official_hash) in enumerate(official_hashes.items(), 1):\n            # Handle potential path separators in filename\n            file_basename = Path(filename).name\n\n            # Try to find the file in local hashes\n            local_hash = None\n            for local_file, local_hash_value in local_hashes.items():\n                if Path(local_file).name == file_basename:\n                    local_hash = local_hash_value\n                    break\n\n            if local_hash is None:\n                files_missing.append(filename)\n                report_progress(f\"  [{idx}/{len(official_hashes)}] ✗ {file_basename} - MISSING\")\n            elif local_hash.lower() != official_hash.lower():\n                files_failed.append(f\"{filename} (hash mismatch)\")\n                report_progress(f\"  [{idx}/{len(official_hashes)}] ✗ {file_basename} - HASH MISMATCH\")\n            else:\n                files_passed += 1\n                report_progress(f\"  [{idx}/{len(official_hashes)}] ✓ {file_basename}\")\n\n        # 4. 
Return results\n        total_checked = len(official_hashes)\n\n        if files_failed or files_missing:\n            all_failed = files_failed + [f\"{f} (missing)\" for f in files_missing]\n            return {\n                \"status\": \"failed\",\n                \"files_checked\": total_checked,\n                \"files_passed\": files_passed,\n                \"files_failed\": all_failed,\n                \"error_message\": f\"{len(all_failed)} file(s) failed verification\",\n            }\n        else:\n            return {\n                \"status\": \"passed\",\n                \"files_checked\": total_checked,\n                \"files_passed\": files_passed,\n                \"files_failed\": [],\n            }\n\n    except ImportError as e:\n        return {\n            \"status\": \"error\",\n            \"files_checked\": 0,\n            \"files_passed\": 0,\n            \"files_failed\": [],\n            \"error_message\": f\"Missing required package: {str(e)}. Install with: pip install huggingface-hub modelscope\",\n            \"is_network_error\": False,\n        }\n    except (\n        requests.exceptions.ConnectionError,\n        requests.exceptions.Timeout,\n        requests.exceptions.RequestException,\n    ) as e:\n        # Network-related errors - suggest mirror\n        error_msg = f\"Network error: {str(e)}\"\n        if repo_type == \"huggingface\":\n            error_msg += \"\\n\\nTry using HuggingFace mirror:\\n  export HF_ENDPOINT=https://hf-mirror.com\"\n        return {\n            \"status\": \"error\",\n            \"files_checked\": 0,\n            \"files_passed\": 0,\n            \"files_failed\": [],\n            \"error_message\": error_msg,\n            \"is_network_error\": True,\n        }\n    except Exception as e:\n        return {\n            \"status\": \"error\",\n            \"files_checked\": 0,\n            \"files_passed\": 0,\n            \"files_failed\": [],\n            \"error_message\": 
f\"Verification failed: {str(e)}\",\n            \"is_network_error\": False,\n        }\n\n\ndef calculate_local_sha256(\n    local_dir: Path, file_pattern: str = \"*.safetensors\", progress_callback=None, files_list: list[Path] = None\n) -> Dict[str, str]:\n    \"\"\"\n    Calculate SHA256 hashes for files in a directory using parallel processing.\n\n    Args:\n        local_dir: Directory to scan\n        file_pattern: Glob pattern for files to hash (ignored if files_list is provided)\n        progress_callback: Optional callback function(message: str) for progress updates\n        files_list: Optional pre-filtered list of files to hash (overrides file_pattern)\n\n    Returns:\n        Dictionary mapping filename to SHA256 hash\n    \"\"\"\n    result = {}\n\n    if not local_dir.exists():\n        return result\n\n    # Get all files first to report total\n    if files_list is not None:\n        files_to_hash = files_list\n    else:\n        files_to_hash = [f for f in local_dir.glob(file_pattern) if f.is_file()]\n    total_files = len(files_to_hash)\n\n    if total_files == 0:\n        return result\n\n    # Use min(16, total_files) workers to avoid over-spawning processes\n    max_workers = min(16, total_files)\n\n    if progress_callback:\n        progress_callback(f\"  Using {max_workers} parallel workers for SHA256 calculation\")\n\n    # Use ProcessPoolExecutor for CPU-intensive SHA256 computation\n    completed_count = 0\n    with ProcessPoolExecutor(max_workers=max_workers) as executor:\n        # Submit all tasks\n        future_to_file = {executor.submit(_compute_file_sha256, file_path): file_path for file_path in files_to_hash}\n\n        # Process results as they complete\n        for future in as_completed(future_to_file):\n            completed_count += 1\n            try:\n                filename, sha256_hash, file_size_mb = future.result()\n                result[filename] = sha256_hash\n\n                if progress_callback:\n                 
   progress_callback(f\"  [{completed_count}/{total_files}] ✓ {filename} ({file_size_mb:.1f} MB)\")\n\n            except Exception as e:\n                file_path = future_to_file[future]\n                if progress_callback:\n                    progress_callback(f\"  [{completed_count}/{total_files}] ✗ {file_path.name} - Error: {str(e)}\")\n\n    return result\n\n\ndef fetch_model_sha256(\n    repo_id: str,\n    platform: Literal[\"hf\", \"ms\"],\n    revision: str | None = None,\n    use_mirror: bool = False,\n    timeout: int | None = None,\n) -> dict[str, str]:\n    \"\"\"\n    Fetch the SHA256 hashes of all important files in a model repository.\n\n    Covers:\n    - *.safetensors (weight files)\n    - *.json (config files: config.json, tokenizer_config.json, etc.)\n    - *.py (custom model code: modeling.py, configuration.py, etc.)\n\n    Args:\n        repo_id: Repository ID, e.g. \"Qwen/Qwen3-30B-A3B\"\n        platform: Platform, \"hf\" (HuggingFace) or \"ms\" (ModelScope)\n        revision: Revision/branch; defaults to \"main\" for HuggingFace and \"master\" for ModelScope\n        use_mirror: Whether to use the mirror (HuggingFace only)\n        timeout: Network request timeout in seconds; None means no timeout\n\n    Returns:\n        dict: Mapping from filename to SHA256, e.g. {\"model-00001-of-00016.safetensors\": \"abc123...\", \"config.json\": \"def456...\"}\n    \"\"\"\n    if platform == \"hf\":\n        # Try a direct connection first; fall back to the mirror automatically on failure\n        try:\n            if use_mirror:\n                return _fetch_from_huggingface(repo_id, revision or \"main\", use_mirror=True, timeout=timeout)\n            else:\n                return _fetch_from_huggingface(repo_id, revision or \"main\", use_mirror=False, timeout=timeout)\n        except Exception as e:\n            # If the direct (non-mirror) attempt failed, retry automatically through the mirror\n            if not use_mirror:\n                return _fetch_from_huggingface(repo_id, revision or \"main\", use_mirror=True, timeout=timeout)\n            else:\n                raise e\n    elif platform == \"ms\":\n        return _fetch_from_modelscope(repo_id, revision or \"master\", timeout=timeout)\n    else:\n        raise ValueError(f\"Unsupported platform: {platform}; use 'hf' 
or 'ms'\")\n\n\ndef _fetch_from_huggingface(\n    repo_id: str, revision: str, use_mirror: bool = False, timeout: int | None = None\n) -> dict[str, str]:\n    \"\"\"Fetch the SHA256 of all important files from HuggingFace\n\n    Args:\n        repo_id: Repository ID\n        revision: Revision/branch\n        use_mirror: Whether to use the mirror (hf-mirror.com)\n        timeout: Network request timeout in seconds; None means no timeout\n    \"\"\"\n    import os\n    import socket\n\n    # Set the environment variable if the mirror is requested\n    original_endpoint = os.environ.get(\"HF_ENDPOINT\")\n    if use_mirror and not original_endpoint:\n        os.environ[\"HF_ENDPOINT\"] = \"https://hf-mirror.com\"\n\n    # Set socket timeout if specified\n    original_timeout = socket.getdefaulttimeout()\n    if timeout is not None:\n        socket.setdefaulttimeout(timeout)\n\n    from huggingface_hub import HfApi, list_repo_files\n\n    try:\n        api = HfApi()\n        all_files = list_repo_files(repo_id=repo_id, revision=revision)\n\n        # Keep only the important files: *.safetensors, *.json, *.py\n        important_files = [f for f in all_files if f.endswith((\".safetensors\", \".json\", \".py\"))]\n\n        if not important_files:\n            return {}\n\n        paths_info = api.get_paths_info(\n            repo_id=repo_id,\n            paths=important_files,\n            revision=revision,\n        )\n\n        result = {}\n        for file_info in paths_info:\n            if hasattr(file_info, \"lfs\") and file_info.lfs is not None:\n                sha256 = file_info.lfs.sha256\n            else:\n                sha256 = getattr(file_info, \"blob_id\", None)\n            result[file_info.path] = sha256\n\n        return result\n    finally:\n        # Restore the original socket timeout\n        socket.setdefaulttimeout(original_timeout)\n\n        # Restore the original environment variable\n        if use_mirror and not original_endpoint:\n            os.environ.pop(\"HF_ENDPOINT\", None)\n        elif original_endpoint:\n            os.environ[\"HF_ENDPOINT\"] = original_endpoint\n\n\ndef _fetch_from_modelscope(repo_id: str, revision: str, timeout: int 
| None = None) -> dict[str, str]:\n    \"\"\"Fetch the SHA256 of all important files from ModelScope\n\n    Args:\n        repo_id: Repository ID\n        revision: Revision/branch\n        timeout: Network request timeout in seconds; None means no timeout\n    \"\"\"\n    import socket\n    from modelscope.hub.api import HubApi\n\n    # Set socket timeout if specified\n    original_timeout = socket.getdefaulttimeout()\n    if timeout is not None:\n        socket.setdefaulttimeout(timeout)\n\n    try:\n        api = HubApi()\n        files_info = api.get_model_files(model_id=repo_id, revision=revision)\n\n        result = {}\n        for file_info in files_info:\n            filename = file_info.get(\"Name\", file_info.get(\"Path\", \"\"))\n            # Keep only the important files: *.safetensors, *.json, *.py\n            if filename.endswith((\".safetensors\", \".json\", \".py\")):\n                sha256 = file_info.get(\"Sha256\", file_info.get(\"sha256\", None))\n                result[filename] = sha256\n\n        return result\n    finally:\n        # Restore the original socket timeout\n        socket.setdefaulttimeout(original_timeout)\n\n\ndef verify_model_integrity_with_progress(\n    repo_type: Literal[\"huggingface\", \"modelscope\"],\n    repo_id: str,\n    local_dir: Path,\n    progress_callback=None,\n    verbose: bool = False,\n    use_mirror: bool = False,\n    files_to_verify: list[str] | None = None,\n    timeout: int | None = None,\n) -> Dict[str, Any]:\n    \"\"\"\n    Verify model integrity with enhanced progress reporting for Rich Progress bars.\n\n    This is a variant of verify_model_integrity() that provides more detailed\n    progress information suitable for progress bar display.\n\n    The progress_callback is always called as (message, total, current):\n    - total and current are integers for countable operations\n    - total and current are None for plain status updates\n\n    Args:\n        repo_type: Repository type (\"huggingface\" or \"modelscope\")\n        repo_id: Repository ID\n        local_dir: Local directory path\n        progress_callback: Optional callback for progress 
updates\n        verbose: If True, output detailed SHA256 comparison for each file\n        use_mirror: If True, use HuggingFace mirror (hf-mirror.com)\n        files_to_verify: Optional list of specific files to verify (for re-verification)\n        timeout: Network request timeout in seconds (None = no timeout)\n    \"\"\"\n\n    def report_progress(msg: str, total=None, current=None):\n        \"\"\"Enhanced progress reporter\"\"\"\n        if progress_callback:\n            progress_callback(msg, total, current)\n\n    try:\n        platform = \"hf\" if repo_type == \"huggingface\" else \"ms\"\n\n        # 1. Fetch official SHA256 hashes\n        if files_to_verify:\n            report_progress(f\"Fetching SHA256 hashes for {len(files_to_verify)} files...\")\n        elif use_mirror and platform == \"hf\":\n            report_progress(\"Fetching official SHA256 hashes from mirror (hf-mirror.com)...\")\n        else:\n            report_progress(\"Fetching official SHA256 hashes from remote repository...\")\n\n        official_hashes = fetch_model_sha256(repo_id, platform, use_mirror=use_mirror, timeout=timeout)\n\n        # Filter to only requested files if specified\n        if files_to_verify:\n            # Extract clean filenames from files_to_verify (remove markers like \"(missing)\")\n            clean_filenames = set()\n            for f in files_to_verify:\n                clean_f = f.replace(\" (missing)\", \"\").replace(\" (hash mismatch)\", \"\").strip()\n                # Ensure we only use the filename, not full path\n                clean_filenames.add(Path(clean_f).name)\n\n            # Filter official_hashes to only include requested files\n            # Compare using basename since official_hashes keys might have paths\n            official_hashes = {k: v for k, v in official_hashes.items() if Path(k).name in clean_filenames}\n\n        report_progress(f\"✓ Fetched {len(official_hashes)} file hashes from remote\")\n\n        if not 
official_hashes:\n            return {\n                \"status\": \"error\",\n                \"files_checked\": 0,\n                \"files_passed\": 0,\n                \"files_failed\": [],\n                \"error_message\": f\"No safetensors files found in remote repository: {repo_id}\",\n            }\n\n        # 2. Calculate local SHA256 hashes\n        local_dir_path = Path(local_dir)\n\n        # Only hash the files we need to verify\n        if files_to_verify:\n            # Extract clean filenames (without markers)\n            clean_filenames = set()\n            for f in files_to_verify:\n                clean_f = f.replace(\" (missing)\", \"\").replace(\" (hash mismatch)\", \"\").strip()\n                # Ensure we only use the filename, not full path\n                clean_filenames.add(Path(clean_f).name)\n\n            # Only hash files that match the clean filenames\n            files_to_hash = [\n                f for f in local_dir_path.glob(\"*.safetensors\") if f.is_file() and f.name in clean_filenames\n            ]\n        else:\n            files_to_hash = [f for f in local_dir_path.glob(\"*.safetensors\") if f.is_file()]\n\n        total_files = len(files_to_hash)\n\n        if files_to_verify:\n            report_progress(f\"Calculating SHA256 for {total_files} repaired files...\", total=total_files, current=0)\n        else:\n            report_progress(f\"Calculating SHA256 for local files...\", total=total_files, current=0)\n\n        # Progress wrapper for hashing\n        completed_count = [0]  # Use list for mutable closure\n\n        def hash_progress_callback(msg: str):\n            if \"Using\" in msg and \"workers\" in msg:\n                report_progress(msg)\n            elif \"[\" in msg and \"/\" in msg and \"]\" in msg:\n                # Progress update like: [1/10] ✓ filename (123.4 MB)\n                completed_count[0] += 1\n                report_progress(msg, total=total_files, current=completed_count[0])\n\n  
      # Pass the pre-filtered files_to_hash list\n        local_hashes = calculate_local_sha256(\n            local_dir_path,\n            \"*.safetensors\",\n            progress_callback=hash_progress_callback,\n            files_list=files_to_hash if files_to_verify else None,\n        )\n        report_progress(f\"✓ Calculated {len(local_hashes)} local file hashes\")\n\n        # 3. Compare hashes\n        report_progress(f\"Comparing {len(official_hashes)} files...\", total=len(official_hashes), current=0)\n\n        files_failed = []\n        files_missing = []\n        files_passed = 0\n\n        for idx, (filename, official_hash) in enumerate(official_hashes.items(), 1):\n            file_basename = Path(filename).name\n\n            # Find matching local file\n            local_hash = None\n            for local_file, local_hash_value in local_hashes.items():\n                if Path(local_file).name == file_basename:\n                    local_hash = local_hash_value\n                    break\n\n            if local_hash is None:\n                files_missing.append(filename)\n                if verbose:\n                    report_progress(\n                        f\"[{idx}/{len(official_hashes)}] ✗ {file_basename} (missing)\\n  Remote: {official_hash}\\n  Local:  <missing>\",\n                        total=len(official_hashes),\n                        current=idx,\n                    )\n                else:\n                    report_progress(\n                        f\"[{idx}/{len(official_hashes)}] ✗ {file_basename} (missing)\",\n                        total=len(official_hashes),\n                        current=idx,\n                    )\n            elif local_hash.lower() != official_hash.lower():\n                files_failed.append(f\"{filename} (hash mismatch)\")\n                if verbose:\n                    report_progress(\n                        f\"[{idx}/{len(official_hashes)}] ✗ {file_basename} (hash mismatch)\\n  Remote: 
{official_hash}\\n  Local:  {local_hash}\",\n                        total=len(official_hashes),\n                        current=idx,\n                    )\n                else:\n                    report_progress(\n                        f\"[{idx}/{len(official_hashes)}] ✗ {file_basename} (hash mismatch)\",\n                        total=len(official_hashes),\n                        current=idx,\n                    )\n            else:\n                files_passed += 1\n                if verbose:\n                    report_progress(\n                        f\"[{idx}/{len(official_hashes)}] ✓ {file_basename}\\n  Remote: {official_hash}\\n  Local:  {local_hash}\",\n                        total=len(official_hashes),\n                        current=idx,\n                    )\n                else:\n                    report_progress(\n                        f\"[{idx}/{len(official_hashes)}] ✓ {file_basename}\", total=len(official_hashes), current=idx\n                    )\n\n        # 4. 
Return results\n        total_checked = len(official_hashes)\n\n        if files_failed or files_missing:\n            all_failed = files_failed + [f\"{f} (missing)\" for f in files_missing]\n            return {\n                \"status\": \"failed\",\n                \"files_checked\": total_checked,\n                \"files_passed\": files_passed,\n                \"files_failed\": all_failed,\n                \"error_message\": f\"{len(all_failed)} file(s) failed verification\",\n            }\n        else:\n            return {\n                \"status\": \"passed\",\n                \"files_checked\": total_checked,\n                \"files_passed\": files_passed,\n                \"files_failed\": [],\n            }\n\n    except (\n        requests.exceptions.ConnectionError,\n        requests.exceptions.Timeout,\n        requests.exceptions.RequestException,\n        TimeoutError,  # Socket timeout from socket.setdefaulttimeout()\n        OSError,  # Network-related OS errors\n    ) as e:\n        error_msg = f\"Network error: {str(e)}\"\n        if repo_type == \"huggingface\":\n            error_msg += \"\\n\\nTry using HuggingFace mirror:\\n  export HF_ENDPOINT=https://hf-mirror.com\"\n        return {\n            \"status\": \"error\",\n            \"files_checked\": 0,\n            \"files_passed\": 0,\n            \"files_failed\": [],\n            \"error_message\": error_msg,\n            \"is_network_error\": True,\n        }\n    except Exception as e:\n        return {\n            \"status\": \"error\",\n            \"files_checked\": 0,\n            \"files_passed\": 0,\n            \"files_failed\": [],\n            \"error_message\": f\"Verification failed: {str(e)}\",\n            \"is_network_error\": False,\n        }\n\n\ndef pre_operation_verification(user_model, user_registry, operation_name: str = \"operation\") -> None:\n    \"\"\"Pre-operation verification of model integrity.\n\n    Can be used before running or quantizing 
models to ensure integrity.\n\n    Args:\n        user_model: UserModel object to verify\n        user_registry: UserModelRegistry instance\n        operation_name: Name of the operation (e.g., \"running\", \"quantizing\")\n    \"\"\"\n    from rich.prompt import Prompt, Confirm\n    from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, MofNCompleteColumn, TimeElapsedColumn\n    from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError\n    from kt_kernel.cli.i18n import get_lang\n    from kt_kernel.cli.utils.console import console, print_info, print_warning, print_error, print_success, print_step\n    import typer\n\n    lang = get_lang()\n\n    # Check if already verified\n    if user_model.sha256_status == \"passed\":\n        console.print()\n        print_info(\"Model integrity already verified ✓\")\n        console.print()\n        return\n\n    # Model not verified yet\n    console.print()\n    console.print(\"[bold yellow]═══ Model Integrity Check ═══[/bold yellow]\")\n    console.print()\n\n    # Check if repo_id exists\n    if not user_model.repo_id:\n        # No repo_id - ask user to provide one\n        console.print(\"[yellow]No repository ID configured for this model.[/yellow]\")\n        console.print()\n        console.print(\"To verify model integrity, we need the repository ID (e.g., 'deepseek-ai/DeepSeek-V3')\")\n        console.print()\n\n        if not Confirm.ask(\"Would you like to configure repository ID now?\", default=True):\n            console.print()\n            print_warning(f\"Skipping verification. 
Model will be used for {operation_name} without integrity check.\")\n            console.print()\n            return\n\n        # Ask for repo type\n        console.print()\n        console.print(\"Repository type:\")\n        console.print(\"  [cyan][1][/cyan] HuggingFace\")\n        console.print(\"  [cyan][2][/cyan] ModelScope\")\n        console.print()\n\n        repo_type_choice = Prompt.ask(\"Select repository type\", choices=[\"1\", \"2\"], default=\"1\")\n        repo_type = \"huggingface\" if repo_type_choice == \"1\" else \"modelscope\"\n\n        # Ask for repo_id\n        console.print()\n        repo_id = Prompt.ask(\"Enter repository ID (e.g., deepseek-ai/DeepSeek-V3)\")\n\n        # Update model\n        user_registry.update_model(user_model.name, {\"repo_type\": repo_type, \"repo_id\": repo_id})\n        user_model.repo_type = repo_type\n        user_model.repo_id = repo_id\n\n        console.print()\n        print_success(f\"Repository configured: {repo_type}:{repo_id}\")\n        console.print()\n\n    # Now ask if user wants to verify\n    console.print(\"[dim]Model integrity verification is a one-time check that ensures your[/dim]\")\n    console.print(\"[dim]model weights are not corrupted. This helps prevent runtime errors.[/dim]\")\n    console.print()\n\n    if not Confirm.ask(f\"Would you like to verify model integrity before {operation_name}?\", default=True):\n        console.print()\n        print_warning(f\"Skipping verification. 
Model will be used for {operation_name} without integrity check.\")\n        console.print()\n        return\n\n    # Perform verification\n    console.print()\n    print_step(\"Verifying model integrity...\")\n    console.print()\n\n    # Check connectivity first\n    use_mirror = False\n    if user_model.repo_type == \"huggingface\":\n        with console.status(\"[dim]Checking HuggingFace connectivity...[/dim]\"):\n            is_accessible, message = check_huggingface_connectivity(timeout=5)\n\n        if not is_accessible:\n            print_warning(\"HuggingFace Connection Failed\")\n            console.print()\n            console.print(f\"  {message}\")\n            console.print()\n            console.print(\"  [yellow]Auto-switching to HuggingFace mirror:[/yellow] [cyan]hf-mirror.com[/cyan]\")\n            console.print()\n            use_mirror = True\n\n    # Fetch remote hashes with timeout\n    def fetch_with_timeout(repo_type, repo_id, use_mirror, timeout):\n        \"\"\"Fetch hashes with timeout.\"\"\"\n        executor = ThreadPoolExecutor(max_workers=1)\n        try:\n            platform = \"hf\" if repo_type == \"huggingface\" else \"ms\"\n            future = executor.submit(fetch_model_sha256, repo_id, platform, use_mirror=use_mirror, timeout=timeout)\n            hashes = future.result(timeout=timeout)\n            executor.shutdown(wait=False)\n            return (hashes, False)\n        except (FutureTimeoutError, Exception):\n            executor.shutdown(wait=False)\n            return (None, True)\n\n    # Try fetching hashes\n    status = console.status(\"[dim]Fetching remote hashes...[/dim]\")\n    status.start()\n    official_hashes, timed_out = fetch_with_timeout(user_model.repo_type, user_model.repo_id, use_mirror, 10)\n    status.stop()\n\n    # Handle timeout with fallback\n    if timed_out and user_model.repo_type == \"huggingface\" and not use_mirror:\n        print_warning(\"HuggingFace Fetch Timeout (10s)\")\n        
console.print()\n        console.print(\"  [yellow]Trying HuggingFace mirror...[/yellow]\")\n        console.print()\n\n        status = console.status(\"[dim]Fetching remote hashes from mirror...[/dim]\")\n        status.start()\n        official_hashes, timed_out = fetch_with_timeout(user_model.repo_type, user_model.repo_id, True, 10)\n        status.stop()\n\n    if timed_out and user_model.repo_type == \"huggingface\":\n        print_warning(\"HuggingFace Mirror Timeout (10s)\")\n        console.print()\n        console.print(\"  [yellow]Fallback to ModelScope...[/yellow]\")\n        console.print()\n\n        status = console.status(\"[dim]Fetching remote hashes from ModelScope...[/dim]\")\n        status.start()\n        official_hashes, timed_out = fetch_with_timeout(\"modelscope\", user_model.repo_id, False, 10)\n        status.stop()\n\n    if not official_hashes or timed_out:\n        print_error(\"Failed to fetch remote hashes (network timeout)\")\n        console.print()\n        console.print(\"  [yellow]Unable to verify model integrity due to network issues.[/yellow]\")\n        console.print()\n\n        if not Confirm.ask(f\"Continue {operation_name} without verification?\", default=False):\n            raise typer.Exit(0)\n\n        console.print()\n        return\n\n    console.print(f\"  [green]✓ Fetched {len(official_hashes)} file hashes[/green]\")\n    console.print()\n\n    # Calculate local hashes and compare\n    local_dir = Path(user_model.path)\n    files_to_hash = [f for f in local_dir.glob(\"*.safetensors\") if f.is_file()]\n\n    with Progress(\n        SpinnerColumn(),\n        TextColumn(\"[progress.description]{task.description}\"),\n        BarColumn(),\n        MofNCompleteColumn(),\n        TimeElapsedColumn(),\n        console=console,\n    ) as progress:\n        # Calculate local hashes\n        task = progress.add_task(\"[yellow]Calculating local SHA256...\", total=len(files_to_hash))\n\n        def hash_callback(msg):\n       
     if \"[\" in msg and \"/\" in msg and \"]\" in msg and \"✓\" in msg:\n                progress.advance(task)\n\n        local_hashes = calculate_local_sha256(local_dir, \"*.safetensors\", progress_callback=hash_callback)\n        progress.remove_task(task)\n\n        console.print(f\"  [green]✓ Calculated {len(local_hashes)} local hashes[/green]\")\n        console.print()\n\n        # Compare hashes\n        task = progress.add_task(\"[blue]Comparing hashes...\", total=len(official_hashes))\n\n        files_failed = []\n        files_missing = []\n        files_passed = 0\n\n        for filename, official_hash in official_hashes.items():\n            file_basename = Path(filename).name\n            local_hash = None\n\n            for local_file, local_hash_value in local_hashes.items():\n                if Path(local_file).name == file_basename:\n                    local_hash = local_hash_value\n                    break\n\n            if local_hash is None:\n                files_missing.append(filename)\n            elif local_hash.lower() != official_hash.lower():\n                files_failed.append(f\"{filename} (hash mismatch)\")\n            else:\n                files_passed += 1\n\n            progress.advance(task)\n\n        progress.remove_task(task)\n\n    console.print()\n\n    # Check results\n    if not files_failed and not files_missing:\n        # Verification passed\n        user_registry.update_model(user_model.name, {\"sha256_status\": \"passed\"})\n        print_success(\"Model integrity verification PASSED ✓\")\n        console.print()\n        console.print(f\"  All {files_passed} files verified successfully\")\n        console.print()\n    else:\n        # Verification failed\n        user_registry.update_model(user_model.name, {\"sha256_status\": \"failed\"})\n        print_error(f\"Model integrity verification FAILED\")\n        console.print()\n        console.print(f\"  ✓ Passed: [green]{files_passed}[/green]\")\n        
console.print(f\"  ✗ Failed: [red]{len(files_failed) + len(files_missing)}[/red]\")\n        console.print()\n\n        if files_missing:\n            console.print(f\"  [red]Missing files ({len(files_missing)}):[/red]\")\n            for f in files_missing[:5]:\n                console.print(f\"    - {Path(f).name}\")\n            if len(files_missing) > 5:\n                console.print(f\"    ... and {len(files_missing) - 5} more\")\n            console.print()\n\n        if files_failed:\n            console.print(f\"  [red]Hash mismatch ({len(files_failed)}):[/red]\")\n            for f in files_failed[:5]:\n                console.print(f\"    - {f}\")\n            if len(files_failed) > 5:\n                console.print(f\"    ... and {len(files_failed) - 5} more\")\n            console.print()\n\n        console.print(\"[bold red]⚠ WARNING: Model weights may be corrupted![/bold red]\")\n        console.print()\n        console.print(\"This could cause runtime errors or incorrect inference results.\")\n        console.print()\n\n        # Ask if user wants to repair\n        if Confirm.ask(\"Would you like to repair (re-download) the corrupted files?\", default=True):\n            console.print()\n            print_info(\"Please run: [cyan]kt model verify \" + user_model.name + \"[/cyan]\")\n            console.print()\n            console.print(\"The verify command will guide you through the repair process.\")\n            raise typer.Exit(0)\n\n        # Ask if user wants to continue anyway\n        console.print()\n        if not Confirm.ask(\n            f\"[yellow]Continue {operation_name} with potentially corrupted weights?[/yellow]\", default=False\n        ):\n            raise typer.Exit(0)\n\n        console.print()\n        print_warning(f\"Proceeding with {operation_name} using unverified weights at your own risk...\")\n        console.print()\n"
  },
  {
    "path": "kt-kernel/python/cli/utils/port_checker.py",
    "content": "\"\"\"\nPort availability checking utilities.\n\"\"\"\n\nimport socket\nfrom typing import Tuple\n\n\ndef is_port_available(host: str, port: int) -> bool:\n    \"\"\"Check if a port is available on the given host.\n\n    Args:\n        host: Host address (e.g., \"0.0.0.0\", \"127.0.0.1\")\n        port: Port number to check\n\n    Returns:\n        True if port is available, False if occupied\n    \"\"\"\n    try:\n        # Try to bind to the port\n        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n        sock.settimeout(1)\n\n        # Use SO_REUSEADDR to allow binding to recently closed ports\n        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n\n        # Try to bind\n        result = sock.connect_ex((host if host != \"0.0.0.0\" else \"127.0.0.1\", port))\n        sock.close()\n\n        # If connect_ex returns 0, port is occupied\n        # If it returns error (non-zero), port is available\n        return result != 0\n\n    except Exception:\n        # If any error occurs, assume port is not available\n        return False\n\n\ndef find_available_port(host: str, start_port: int, max_attempts: int = 100) -> Tuple[bool, int]:\n    \"\"\"Find an available port starting from start_port.\n\n    Args:\n        host: Host address\n        start_port: Starting port number to check\n        max_attempts: Maximum number of ports to try\n\n    Returns:\n        Tuple of (found, port_number)\n        - found: True if an available port was found\n        - port_number: The available port number (or start_port if not found)\n    \"\"\"\n    for port in range(start_port, start_port + max_attempts):\n        if is_port_available(host, port):\n            return True, port\n\n    return False, start_port\n"
  },
  {
    "path": "kt-kernel/python/cli/utils/quant_interactive.py",
    "content": "\"\"\"\nInteractive configuration for kt quant command.\n\nProvides rich, multi-step interactive configuration for model quantization.\n\"\"\"\n\nfrom typing import Optional, Dict, Any\nfrom pathlib import Path\nfrom rich.console import Console\nfrom rich.table import Table\nfrom rich.panel import Panel\nfrom rich.prompt import Prompt, Confirm, IntPrompt\nfrom kt_kernel.cli.i18n import t\n\n\nconsole = Console()\n\n\ndef select_model_to_quantize() -> Optional[Any]:\n    \"\"\"Select model to quantize interactively.\"\"\"\n    from kt_kernel.cli.utils.user_model_registry import UserModelRegistry\n    from kt_kernel.cli.commands.model import is_amx_weights, SHA256_STATUS_MAP\n    from kt_kernel.cli.utils.model_table_builder import build_moe_gpu_table\n\n    registry = UserModelRegistry()\n    all_models = registry.list_models()\n\n    # Filter MoE models only (safetensors, not AMX, is_moe=True)\n    quant_models = []\n    for model in all_models:\n        if model.format == \"safetensors\":\n            # Skip AMX models\n            is_amx, _ = is_amx_weights(model.path)\n            if is_amx:\n                continue\n\n            # Only include MoE models\n            if model.is_moe:\n                quant_models.append(model)\n\n    if not quant_models:\n        console.print(f\"[yellow]{t('quant_no_moe_models')}[/yellow]\")\n        console.print()\n        console.print(f\"  {t('quant_only_moe')}\")\n        console.print()\n        console.print(f\"  {t('quant_add_models', command='kt model scan')}\")\n        console.print(f\"  {t('quant_add_models', command='kt model add <path>')}\")\n        return None\n\n    # Display models\n    console.print()\n    console.print(f\"[bold green]{t('quant_moe_available')}[/bold green]\")\n    console.print()\n\n    # Use shared table builder\n    table, displayed_models = build_moe_gpu_table(\n        models=quant_models, status_map=SHA256_STATUS_MAP, show_index=True, start_index=1\n    )\n\n    
console.print(table)\n    console.print()\n\n    choice = IntPrompt.ask(t(\"quant_select_model\"), default=1, show_choices=False)\n\n    if choice < 1 or choice > len(displayed_models):\n        console.print(f\"[red]{t('quant_invalid_choice')}[/red]\")\n        return None\n\n    return displayed_models[choice - 1]\n\n\ndef configure_quantization_method() -> Dict[str, str]:\n    \"\"\"Select quantization method and input type.\"\"\"\n    console.print()\n    console.print(Panel(f\"[bold cyan]{t('quant_step2_method')}[/bold cyan]\", expand=False))\n    console.print()\n\n    # Method selection\n    console.print(f\"[bold]{t('quant_method_label')}[/bold]\")\n    console.print(f\"  [cyan][1][/cyan] {t('quant_int4_desc')}\")\n    console.print(f\"  [cyan][2][/cyan] {t('quant_int8_desc')}\")\n    console.print()\n\n    method_choice = Prompt.ask(t(\"quant_select_method\"), choices=[\"1\", \"2\"], default=\"1\")\n    method = \"int4\" if method_choice == \"1\" else \"int8\"\n\n    console.print()\n    console.print(f\"[bold]{t('quant_input_type_label')}[/bold]\")\n    console.print(f\"  [cyan][1][/cyan] {t('quant_fp8_desc')}\")\n    console.print(f\"  [cyan][2][/cyan] {t('quant_fp16_desc')}\")\n    console.print(f\"  [cyan][3][/cyan] {t('quant_bf16_desc')}\")\n    console.print()\n\n    input_choice = Prompt.ask(t(\"quant_select_input_type\"), choices=[\"1\", \"2\", \"3\"], default=\"1\")\n    input_type_map = {\"1\": \"fp8\", \"2\": \"fp16\", \"3\": \"bf16\"}\n    input_type = input_type_map[input_choice]\n\n    return {\"method\": method, \"input_type\": input_type}\n\n\ndef configure_cpu_params(max_cores: int, max_numa: int) -> Dict[str, Any]:\n    \"\"\"Configure CPU parameters.\"\"\"\n    console.print()\n    console.print(Panel(f\"[bold cyan]{t('quant_step3_cpu')}[/bold cyan]\", expand=False))\n    console.print()\n\n    def clamp(value: int, min_val: int, max_val: int, default: int) -> int:\n        \"\"\"Clamp value to range or return default if out of 
bounds.\"\"\"\n        if min_val <= value <= max_val:\n            return max(min_val, min(value, max_val))\n        return default\n\n    default_threads = int(max_cores * 0.8)\n    cpu_threads = IntPrompt.ask(t(\"quant_cpu_threads_prompt\", max=max_cores), default=default_threads)\n    cpu_threads = clamp(cpu_threads, 1, max_cores, default_threads)\n\n    numa_nodes = IntPrompt.ask(t(\"quant_numa_nodes_prompt\", max=max_numa), default=max_numa)\n    numa_nodes = clamp(numa_nodes, 1, max_numa, max_numa)\n\n    # Ask about GPU usage\n    console.print()\n    console.print(f\"[bold]{t('quant_use_gpu_label')}[/bold]\")\n    console.print(f\"  [dim]{t('quant_gpu_speedup')}[/dim]\")\n    console.print()\n    use_gpu = Confirm.ask(t(\"quant_enable_gpu\"), default=True)\n\n    return {\"cpu_threads\": cpu_threads, \"numa_nodes\": numa_nodes, \"use_gpu\": use_gpu}\n\n\ndef configure_output_path(model: Any, method: str, numa_nodes: int) -> Path:\n    \"\"\"Configure output path for quantized weights.\"\"\"\n    from kt_kernel.cli.config.settings import get_settings\n\n    console.print()\n    console.print(Panel(f\"[bold cyan]{t('quant_step4_output')}[/bold cyan]\", expand=False))\n    console.print()\n\n    # Generate default output path\n    model_path = Path(model.path)\n    method_upper = method.upper()\n    settings = get_settings()\n\n    # Priority: paths.weights > paths.models[0] > model's parent directory\n    weights_dir = settings.weights_dir\n    if weights_dir and weights_dir.exists():\n        # Use configured weights directory (highest priority)\n        default_output = weights_dir / f\"{model_path.name}-AMX{method_upper}-NUMA{numa_nodes}\"\n    else:\n        # Use first model storage path\n        model_paths = settings.get_model_paths()\n        if model_paths and model_paths[0].exists():\n            default_output = model_paths[0] / f\"{model_path.name}-AMX{method_upper}-NUMA{numa_nodes}\"\n        else:\n            # Fallback to model's parent 
directory\n            default_output = model_path.parent / f\"{model_path.name}-AMX{method_upper}-NUMA{numa_nodes}\"\n\n    console.print(f\"[dim]{t('quant_default_path')}[/dim]\", default_output)\n    console.print()\n\n    use_default = Confirm.ask(t(\"quant_use_default\"), default=True)\n\n    if use_default:\n        return default_output\n\n    custom_path = Prompt.ask(t(\"quant_custom_path\"), default=str(default_output))\n\n    return Path(custom_path)\n\n\ndef calculate_quantized_size(source_path: Path, input_type: str, quant_method: str) -> tuple[float, float]:\n    \"\"\"\n    Calculate source model size and estimated quantized size.\n\n    Args:\n        source_path: Path to source model\n        input_type: Input type (fp8, fp16, bf16)\n        quant_method: Quantization method (int4, int8)\n\n    Returns:\n        Tuple of (source_size_gb, estimated_quant_size_gb)\n    \"\"\"\n    # Calculate source model size\n    try:\n        total_bytes = sum(f.stat().st_size for f in source_path.glob(\"*.safetensors\") if f.is_file())\n        source_size_gb = total_bytes / (1024**3)\n    except Exception:\n        return 0.0, 0.0\n\n    # Bits mapping\n    input_bits = {\"fp8\": 8, \"fp16\": 16, \"bf16\": 16}\n    quant_bits = {\"int4\": 4, \"int8\": 8}\n\n    input_bit = input_bits.get(input_type, 16)\n    quant_bit = quant_bits.get(quant_method, 4)\n\n    # Estimate: source_size * (quant_bits / input_bits)\n    ratio = quant_bit / input_bit\n    estimated_size_gb = source_size_gb * ratio\n\n    return source_size_gb, estimated_size_gb\n\n\ndef check_disk_space(output_path: Path, required_size_gb: float) -> tuple[float, bool]:\n    \"\"\"\n    Check available disk space at output path.\n\n    Args:\n        output_path: Target output path\n        required_size_gb: Required space in GB\n\n    Returns:\n        Tuple of (available_gb, is_sufficient)\n        is_sufficient is True if available >= required * 1.2\n    \"\"\"\n    import shutil\n\n    try:\n        
# Get parent directory that exists\n        check_path = output_path.parent if not output_path.exists() else output_path\n        while not check_path.exists() and check_path != check_path.parent:\n            check_path = check_path.parent\n\n        stat = shutil.disk_usage(check_path)\n        available_gb = stat.free / (1024**3)\n\n        # Check if available space >= required * 1.2 (20% buffer)\n        is_sufficient = available_gb >= (required_size_gb * 1.2)\n\n        return available_gb, is_sufficient\n    except Exception:\n        return 0.0, False\n\n\ndef interactive_quant_config() -> Optional[Dict[str, Any]]:\n    \"\"\"\n    Interactive configuration for kt quant.\n\n    Returns configuration dict or None if cancelled.\n    \"\"\"\n    from kt_kernel.cli.utils.environment import detect_cpu_info\n\n    # Get CPU info\n    cpu_info = detect_cpu_info()\n\n    # Step 1: Select model\n    model = select_model_to_quantize()\n    if not model:\n        return None\n\n    # Step 1.5: Pre-quantization verification (optional)\n    from kt_kernel.cli.utils.user_model_registry import UserModelRegistry\n    from kt_kernel.cli.utils.model_verifier import pre_operation_verification\n\n    user_registry = UserModelRegistry()\n    user_model_obj = user_registry.find_by_path(model.path)\n\n    if user_model_obj and user_model_obj.format == \"safetensors\":\n        pre_operation_verification(user_model_obj, user_registry, operation_name=\"quantizing\")\n\n    # Step 2: Configure quantization method\n    quant_config = configure_quantization_method()\n\n    # Step 3: Configure CPU parameters\n    cpu_config = configure_cpu_params(cpu_info.threads, cpu_info.numa_nodes)  # Use logical threads\n\n    # Step 4: Configure output path\n    output_path = configure_output_path(model, quant_config[\"method\"], cpu_config[\"numa_nodes\"])\n\n    # Step 4.5: Check if output path already exists and generate unique name\n    if output_path.exists():\n        console.print()\n       
 console.print(t(\"quant_output_exists_warn\", path=str(output_path)))\n        console.print()\n\n        # Generate unique name by adding suffix\n        original_name = output_path.name\n        parent_dir = output_path.parent\n        counter = 2\n\n        while output_path.exists():\n            new_name = f\"{original_name}-{counter}\"\n            output_path = parent_dir / new_name\n            counter += 1\n\n        console.print(t(\"quant_using_unique_name\", path=str(output_path)))\n        console.print()\n\n    # Step 5: Calculate space requirements and check availability\n    console.print()\n    console.print(Panel(f\"[bold cyan]{t('quant_disk_analysis')}[/bold cyan]\", expand=False))\n    console.print()\n\n    source_size_gb, estimated_size_gb = calculate_quantized_size(\n        Path(model.path), quant_config[\"input_type\"], quant_config[\"method\"]\n    )\n\n    available_gb, is_sufficient = check_disk_space(output_path, estimated_size_gb)\n\n    console.print(f\"  {t('quant_source_size'):<26} [cyan]{source_size_gb:.2f} GB[/cyan]\")\n    console.print(f\"  {t('quant_estimated_size'):<26} [yellow]{estimated_size_gb:.2f} GB[/yellow]\")\n    console.print(\n        f\"  {t('quant_available_space'):<26} [{'green' if is_sufficient else 'red'}]{available_gb:.2f} GB[/{'green' if is_sufficient else 'red'}]\"\n    )\n    console.print()\n\n    if not is_sufficient:\n        required_with_buffer = estimated_size_gb * 1.2\n        console.print(f\"[bold red]⚠ {t('quant_insufficient_space')}[/bold red]\")\n        console.print()\n        console.print(f\"  {t('quant_required_space'):<26} [yellow]{required_with_buffer:.2f} GB[/yellow]\")\n        console.print(f\"  {t('quant_available_space'):<26} [red]{available_gb:.2f} GB[/red]\")\n        console.print(f\"  {t('quant_shortage'):<26} [red]{required_with_buffer - available_gb:.2f} GB[/red]\")\n        console.print()\n        console.print(f\"  {t('quant_may_fail')}\")\n        console.print()\n\n        
if not Confirm.ask(f\"[yellow]{t('quant_continue_anyway')}[/yellow]\", default=False):\n            console.print(f\"[yellow]{t('quant_cancelled')}[/yellow]\")\n            return None\n        console.print()\n\n    # Summary and confirmation\n    console.print()\n    console.print(Panel(f\"[bold cyan]{t('quant_config_summary')}[/bold cyan]\", expand=False))\n    console.print()\n    console.print(f\"  {t('quant_summary_model'):<15} {model.name}\")\n    console.print(f\"  {t('quant_summary_method'):<15} {quant_config['method'].upper()}\")\n    console.print(f\"  {t('quant_summary_input_type'):<15} {quant_config['input_type'].upper()}\")\n    console.print(f\"  {t('quant_summary_cpu_threads'):<15} {cpu_config['cpu_threads']}\")\n    console.print(f\"  {t('quant_summary_numa'):<15} {cpu_config['numa_nodes']}\")\n    console.print(f\"  {t('quant_summary_gpu'):<15} {t('yes') if cpu_config['use_gpu'] else t('no')}\")\n    console.print(f\"  {t('quant_summary_output'):<15} {output_path}\")\n    console.print()\n\n    if not Confirm.ask(f\"[bold green]{t('quant_start_question')}[/bold green]\", default=True):\n        console.print(f\"[yellow]{t('quant_cancelled')}[/yellow]\")\n        return None\n\n    return {\n        \"model\": model,\n        \"method\": quant_config[\"method\"],\n        \"input_type\": quant_config[\"input_type\"],\n        \"cpu_threads\": cpu_config[\"cpu_threads\"],\n        \"numa_nodes\": cpu_config[\"numa_nodes\"],\n        \"use_gpu\": cpu_config[\"use_gpu\"],\n        \"output_path\": output_path,\n    }\n"
  },
  {
    "path": "kt-kernel/python/cli/utils/repo_detector.py",
    "content": "\"\"\"\nRepo Detector\n\nAutomatically detect repository information from model README.md files\n\"\"\"\n\nimport re\nfrom pathlib import Path\nfrom typing import Optional, Dict, Tuple\nimport yaml\n\n\ndef parse_readme_frontmatter(readme_path: Path) -> Optional[Dict]:\n    \"\"\"\n    Parse YAML frontmatter from README.md\n\n    Args:\n        readme_path: Path to README.md file\n\n    Returns:\n        Dictionary of frontmatter data, or None if not found\n    \"\"\"\n    if not readme_path.exists():\n        return None\n\n    try:\n        with open(readme_path, \"r\", encoding=\"utf-8\") as f:\n            content = f.read()\n\n        # Match YAML frontmatter between --- markers\n        match = re.match(r\"^---\\s*\\n(.*?)\\n---\\s*\\n\", content, re.DOTALL)\n        if not match:\n            return None\n\n        yaml_content = match.group(1)\n\n        # Parse YAML\n        try:\n            data = yaml.safe_load(yaml_content)\n            return data if isinstance(data, dict) else None\n        except yaml.YAMLError:\n            return None\n\n    except Exception as e:\n        return None\n\n\ndef extract_repo_from_frontmatter(frontmatter: Dict) -> Optional[Tuple[str, str]]:\n    \"\"\"\n    Extract repo_id and repo_type from frontmatter\n\n    Args:\n        frontmatter: Parsed YAML frontmatter dictionary\n\n    Returns:\n        Tuple of (repo_id, repo_type) or None\n        repo_type is either \"huggingface\" or \"modelscope\"\n    \"\"\"\n    if not frontmatter:\n        return None\n\n    # Priority 1: Extract from license_link (most reliable)\n    license_link = frontmatter.get(\"license_link\")\n    if license_link and isinstance(license_link, str):\n        result = _extract_repo_from_url(license_link)\n        if result:\n            return result\n\n    # Priority 2: Try to find repo_id from other fields\n    repo_id = None\n\n    # Check base_model field\n    base_model = frontmatter.get(\"base_model\")\n    if base_model:\n 
       if isinstance(base_model, list) and len(base_model) > 0:\n            # base_model is a list, take first item\n            repo_id = base_model[0]\n        elif isinstance(base_model, str):\n            repo_id = base_model\n\n    # Check model-index field\n    if not repo_id:\n        model_index = frontmatter.get(\"model-index\")\n        if isinstance(model_index, list) and len(model_index) > 0:\n            first_model = model_index[0]\n            if isinstance(first_model, dict):\n                repo_id = first_model.get(\"name\")\n\n    # Check model_name field\n    if not repo_id:\n        repo_id = frontmatter.get(\"model_name\")\n\n    if not repo_id or not isinstance(repo_id, str):\n        return None\n\n    # Validate format: should be \"namespace/model-name\"\n    if \"/\" not in repo_id:\n        return None\n\n    parts = repo_id.split(\"/\")\n    if len(parts) != 2:\n        return None\n\n    # Determine repo type\n    repo_type = \"huggingface\"  # Default\n\n    # Look for ModelScope indicators\n    if \"modelscope\" in repo_id.lower():\n        repo_type = \"modelscope\"\n\n    # Check tags\n    tags = frontmatter.get(\"tags\", [])\n    if isinstance(tags, list):\n        if \"modelscope\" in [str(t).lower() for t in tags]:\n            repo_type = \"modelscope\"\n\n    return (repo_id, repo_type)\n\n\ndef _extract_repo_from_url(url: str) -> Optional[Tuple[str, str]]:\n    \"\"\"\n    Extract repo_id and repo_type from a URL\n\n    Supports:\n    - https://huggingface.co/Qwen/Qwen3-30B-A3B/blob/main/LICENSE\n    - https://modelscope.cn/models/Qwen/Qwen3-30B-A3B\n\n    Args:\n        url: URL string\n\n    Returns:\n        Tuple of (repo_id, repo_type) or None\n    \"\"\"\n    # HuggingFace pattern: https://huggingface.co/{namespace}/{model}/...\n    hf_match = re.match(r\"https?://huggingface\\.co/([^/]+)/([^/]+)\", url)\n    if hf_match:\n        namespace = hf_match.group(1)\n        model_name = hf_match.group(2)\n        repo_id = 
f\"{namespace}/{model_name}\"\n        return (repo_id, \"huggingface\")\n\n    # ModelScope pattern: https://modelscope.cn/models/{namespace}/{model}\n    ms_match = re.match(r\"https?://(?:www\\.)?modelscope\\.cn/models/([^/]+)/([^/]+)\", url)\n    if ms_match:\n        namespace = ms_match.group(1)\n        model_name = ms_match.group(2)\n        repo_id = f\"{namespace}/{model_name}\"\n        return (repo_id, \"modelscope\")\n\n    return None\n\n\ndef extract_repo_from_global_search(readme_path: Path) -> Optional[Tuple[str, str]]:\n    \"\"\"\n    Extract repo info by globally searching for URLs in README.md\n\n    Args:\n        readme_path: Path to README.md file\n\n    Returns:\n        Tuple of (repo_id, repo_type) or None if not found\n    \"\"\"\n    if not readme_path.exists():\n        return None\n\n    try:\n        with open(readme_path, \"r\", encoding=\"utf-8\") as f:\n            content = f.read()\n\n        # Find all HuggingFace URLs\n        hf_pattern = r\"https?://huggingface\\.co/([^/\\s]+)/([^/\\s\\)]+)\"\n        hf_matches = re.findall(hf_pattern, content)\n\n        # Find all ModelScope URLs\n        ms_pattern = r\"https?://(?:www\\.)?modelscope\\.cn/models/([^/\\s]+)/([^/\\s\\)]+)\"\n        ms_matches = re.findall(ms_pattern, content)\n\n        # Collect all found repos with their types\n        found_repos = []\n\n        for namespace, model_name in hf_matches:\n            # Skip common non-repo paths\n            if namespace.lower() in [\"docs\", \"blog\", \"spaces\", \"datasets\"]:\n                continue\n            if model_name.lower() in [\"tree\", \"blob\", \"raw\", \"resolve\", \"discussions\"]:\n                continue\n\n            repo_id = f\"{namespace}/{model_name}\"\n            found_repos.append((repo_id, \"huggingface\"))\n\n        for namespace, model_name in ms_matches:\n            repo_id = f\"{namespace}/{model_name}\"\n            found_repos.append((repo_id, \"modelscope\"))\n\n        if not 
found_repos:\n            return None\n\n        # If multiple repos are found, use the last occurrence in the README\n        return found_repos[-1]\n\n    except Exception:\n        return None\n\n\ndef detect_repo_for_model(model_path: str) -> Optional[Tuple[str, str]]:\n    \"\"\"\n    Detect repository information for a model\n\n    Strategy:\n    Only extract from YAML frontmatter metadata in README.md\n    (Removed global URL search to avoid false positives)\n\n    Args:\n        model_path: Path to model directory\n\n    Returns:\n        Tuple of (repo_id, repo_type) or None if not detected\n    \"\"\"\n    model_dir = Path(model_path)\n\n    if not model_dir.exists() or not model_dir.is_dir():\n        return None\n\n    # Look for README.md\n    readme_path = model_dir / \"README.md\"\n    if not readme_path.exists():\n        return None\n\n    # Only parse YAML frontmatter (no fallback to global search)\n    frontmatter = parse_readme_frontmatter(readme_path)\n    if frontmatter:\n        return extract_repo_from_frontmatter(frontmatter)\n\n    return None\n\n\ndef scan_models_for_repo(model_list) -> Dict:\n    \"\"\"\n    Scan a list of models and detect repo information\n\n    Args:\n        model_list: List of UserModel objects\n\n    Returns:\n        Dictionary with scan results:\n        {\n            'detected': [(model, repo_id, repo_type), ...],\n            'not_detected': [model, ...],\n            'skipped': [model, ...]  
# Already has repo_id\n        }\n    \"\"\"\n    results = {\"detected\": [], \"not_detected\": [], \"skipped\": []}\n\n    for model in model_list:\n        # Skip if already has repo_id\n        if model.repo_id:\n            results[\"skipped\"].append(model)\n            continue\n\n        # Only process safetensors and gguf models\n        if model.format not in [\"safetensors\", \"gguf\"]:\n            results[\"skipped\"].append(model)\n            continue\n\n        # Try to detect repo\n        repo_info = detect_repo_for_model(model.path)\n\n        if repo_info:\n            repo_id, repo_type = repo_info\n            results[\"detected\"].append((model, repo_id, repo_type))\n        else:\n            results[\"not_detected\"].append(model)\n\n    return results\n\n\ndef format_detection_report(results: Dict) -> str:\n    \"\"\"\n    Format scan results into a readable report\n\n    Args:\n        results: Results from scan_models_for_repo()\n\n    Returns:\n        Formatted string report\n    \"\"\"\n    lines = []\n\n    lines.append(\"=\" * 80)\n    lines.append(\"Auto-Detection Report\")\n    lines.append(\"=\" * 80)\n    lines.append(\"\")\n\n    # Detected\n    if results[\"detected\"]:\n        lines.append(f\"✓ Detected repository information ({len(results['detected'])} models):\")\n        lines.append(\"\")\n        for model, repo_id, repo_type in results[\"detected\"]:\n            lines.append(f\"  • {model.name}\")\n            lines.append(f\"    Path: {model.path}\")\n            lines.append(f\"    Repo: {repo_id} ({repo_type})\")\n            lines.append(\"\")\n\n    # Not detected\n    if results[\"not_detected\"]:\n        lines.append(f\"✗ No repository information found ({len(results['not_detected'])} models):\")\n        lines.append(\"\")\n        for model in results[\"not_detected\"]:\n            lines.append(f\"  • {model.name}\")\n            lines.append(f\"    Path: {model.path}\")\n        lines.append(\"\")\n\n    # 
Skipped\n    if results[\"skipped\"]:\n        lines.append(f\"⊘ Skipped ({len(results['skipped'])} models):\")\n        lines.append(\"  (Already have repo_id or not safetensors/gguf format)\")\n        lines.append(\"\")\n\n    lines.append(\"=\" * 80)\n    lines.append(\n        f\"Summary: {len(results['detected'])} detected, \"\n        f\"{len(results['not_detected'])} not detected, \"\n        f\"{len(results['skipped'])} skipped\"\n    )\n    lines.append(\"=\" * 80)\n\n    return \"\\n\".join(lines)\n\n\ndef apply_detection_results(results: Dict, registry) -> int:\n    \"\"\"\n    Apply detected repo information to models in registry\n\n    Args:\n        results: Results from scan_models_for_repo()\n        registry: UserModelRegistry instance\n\n    Returns:\n        Number of models updated\n    \"\"\"\n    updated_count = 0\n\n    for model, repo_id, repo_type in results[\"detected\"]:\n        success = registry.update_model(model.name, {\"repo_id\": repo_id, \"repo_type\": repo_type})\n\n        if success:\n            updated_count += 1\n\n    return updated_count\n"
  },
  {
    "path": "kt-kernel/python/cli/utils/run_configs.py",
    "content": "\"\"\"\nConfiguration save/load for kt run command.\n\nManages saved run configurations bound to specific models.\n\"\"\"\n\nfrom pathlib import Path\nfrom typing import Dict, List, Optional, Any\nfrom datetime import datetime\nimport yaml\n\n\nCONFIG_FILE = Path.home() / \".ktransformers\" / \"run_configs.yaml\"\n\n\nclass RunConfigManager:\n    \"\"\"Manager for saved run configurations.\"\"\"\n\n    def __init__(self):\n        self.config_file = CONFIG_FILE\n        self._ensure_config_file()\n\n    def _ensure_config_file(self):\n        \"\"\"Ensure config file exists.\"\"\"\n        if not self.config_file.exists():\n            self.config_file.parent.mkdir(parents=True, exist_ok=True)\n            self._save_data({\"version\": \"1.0\", \"configs\": {}})\n\n    def _load_data(self) -> Dict:\n        \"\"\"Load raw config data.\"\"\"\n        try:\n            with open(self.config_file, \"r\", encoding=\"utf-8\") as f:\n                return yaml.safe_load(f) or {\"version\": \"1.0\", \"configs\": {}}\n        except Exception:\n            return {\"version\": \"1.0\", \"configs\": {}}\n\n    def _save_data(self, data: Dict):\n        \"\"\"Save raw config data.\"\"\"\n        with open(self.config_file, \"w\", encoding=\"utf-8\") as f:\n            yaml.dump(data, f, allow_unicode=True, default_flow_style=False)\n\n    def list_configs(self, model_id: str) -> List[Dict[str, Any]]:\n        \"\"\"List all saved configs for a model.\n\n        Returns:\n            List of config dicts with 'config_name' and other fields.\n        \"\"\"\n        data = self._load_data()\n        configs = data.get(\"configs\", {}).get(model_id, [])\n        return configs if isinstance(configs, list) else []\n\n    def save_config(self, model_id: str, config: Dict[str, Any]):\n        \"\"\"Save a configuration for a model.\n\n        Args:\n            model_id: Model ID to bind config to\n            config: Configuration dict with all run parameters\n   
     \"\"\"\n        data = self._load_data()\n\n        if \"configs\" not in data:\n            data[\"configs\"] = {}\n\n        if model_id not in data[\"configs\"]:\n            data[\"configs\"][model_id] = []\n\n        # Add timestamp\n        config[\"created_at\"] = datetime.now().isoformat()\n\n        # Append config\n        data[\"configs\"][model_id].append(config)\n\n        self._save_data(data)\n\n    def delete_config(self, model_id: str, config_index: int) -> bool:\n        \"\"\"Delete a saved configuration.\n\n        Args:\n            model_id: Model ID\n            config_index: Index of config to delete (0-based)\n\n        Returns:\n            True if deleted, False if not found\n        \"\"\"\n        data = self._load_data()\n\n        if model_id not in data.get(\"configs\", {}):\n            return False\n\n        configs = data[\"configs\"][model_id]\n        if config_index < 0 or config_index >= len(configs):\n            return False\n\n        configs.pop(config_index)\n        self._save_data(data)\n        return True\n\n    def get_config(self, model_id: str, config_index: int) -> Optional[Dict[str, Any]]:\n        \"\"\"Get a specific saved configuration.\n\n        Args:\n            model_id: Model ID\n            config_index: Index of config to get (0-based)\n\n        Returns:\n            Config dict or None if not found\n        \"\"\"\n        configs = self.list_configs(model_id)\n        if config_index < 0 or config_index >= len(configs):\n            return None\n        return configs[config_index]\n"
  },
  {
    "path": "kt-kernel/python/cli/utils/run_interactive.py",
    "content": "\"\"\"\nInteractive configuration for kt run command - New Implementation.\n\nProvides step-by-step interactive configuration for running models.\n\"\"\"\n\nfrom typing import Optional, List, Dict, Any, Tuple\nfrom pathlib import Path\nfrom rich.console import Console\nfrom rich.table import Table\nfrom rich.panel import Panel\nfrom rich.prompt import Prompt, Confirm\nfrom rich import box\nimport torch\n\nfrom kt_kernel.cli.i18n import t\nfrom kt_kernel.cli.utils.input_validators import (\n    prompt_int_with_retry,\n    prompt_float_with_retry,\n    prompt_choice_with_retry,\n    prompt_int_list_with_retry,\n)\n\n\nconsole = Console()\n\n\ndef get_gpu_info() -> List[Dict[str, Any]]:\n    \"\"\"Get real-time GPU information with free VRAM.\"\"\"\n    from kt_kernel.cli.utils.environment import detect_gpus\n\n    gpus = detect_gpus()\n    gpu_info_list = []\n\n    for i, gpu in enumerate(gpus):\n        total_vram_gb = gpu.vram_gb\n        free_vram_gb = gpu.vram_gb  # Default fallback\n\n        # Try to get real-time free VRAM\n        if torch.cuda.is_available() and i < torch.cuda.device_count():\n            try:\n                free_vram_bytes, total_vram_bytes = torch.cuda.mem_get_info(i)\n                free_vram_gb = free_vram_bytes / (1024**3)\n                total_vram_gb = total_vram_bytes / (1024**3)\n            except Exception:\n                pass  # Use fallback values\n\n        gpu_info_list.append(\n            {\n                \"id\": i,\n                \"name\": gpu.name,\n                \"total_vram_gb\": total_vram_gb,\n                \"free_vram_gb\": free_vram_gb,\n            }\n        )\n\n    return gpu_info_list\n\n\ndef select_model() -> Optional[Any]:\n    \"\"\"Step 1: Select a safetensors MoE model.\n\n    Returns:\n        Selected UserModel object or None if cancelled.\n    \"\"\"\n    from kt_kernel.cli.utils.user_model_registry import UserModelRegistry\n    from kt_kernel.cli.commands.model import 
is_amx_weights\n\n    registry = UserModelRegistry()\n    all_models = registry.list_models()\n\n    # Filter: safetensors models only (exclude AMX and GGUF)\n    # Then filter to only show MoE models (matching kt model list behavior)\n    moe_models = []\n    for model in all_models:\n        if model.format == \"safetensors\" and model.path_exists():\n            is_amx, _ = is_amx_weights(model.path)\n            if not is_amx:\n                # Only include MoE models (is_moe == True)\n                # Also include models not yet analyzed (is_moe == None) for backwards compatibility\n                if model.is_moe is True or model.is_moe is None:\n                    moe_models.append(model)\n\n    if not moe_models:\n        console.print(f\"[yellow]{t('run_int_no_moe_models')}[/yellow]\")\n        console.print(f\"  {t('run_int_add_models')}\")\n        console.print(f\"  {t('run_int_list_all')}\")\n        return None\n\n    console.print()\n    console.print(Panel(f\"[bold cyan]{t('run_int_step1_title')}[/bold cyan]\", expand=False))\n    console.print()\n\n    # Display models using same format as kt model list\n    from kt_kernel.cli.utils.model_scanner import format_size\n    from kt_kernel.cli.commands.model import SHA256_STATUS_MAP\n\n    table = Table(box=box.ROUNDED, show_header=True, header_style=\"bold cyan\")\n    table.add_column(\"#\", justify=\"right\", style=\"cyan\", no_wrap=True)\n    table.add_column(\"Name\", style=\"cyan\", no_wrap=True)\n    table.add_column(\"Path\", style=\"dim\", overflow=\"fold\")\n    table.add_column(\"Total\", justify=\"right\")\n    table.add_column(\"Exps\", justify=\"center\", style=\"yellow\")\n    table.add_column(\"Act\", justify=\"center\", style=\"green\")\n    table.add_column(\"MoE Size\", justify=\"right\", style=\"cyan\")\n    table.add_column(\"Repo\", style=\"dim\", overflow=\"fold\")\n    table.add_column(\"SHA256\", justify=\"center\")\n\n    for i, model in enumerate(moe_models, 1):\n        # 
Calculate size\n        if model.path_exists():\n            path_obj = Path(model.path)\n            try:\n                files = list(path_obj.glob(\"*.safetensors\"))\n                total_size = sum(f.stat().st_size for f in files if f.exists())\n                size_display = format_size(total_size)\n            except:\n                size_display = \"[dim]-[/dim]\"\n        else:\n            size_display = \"[dim]-[/dim]\"\n\n        # Format MoE info\n        experts = f\"[yellow]{model.moe_num_experts}[/yellow]\" if model.moe_num_experts else \"[dim]-[/dim]\"\n        active = f\"[green]{model.moe_num_experts_per_tok}[/green]\" if model.moe_num_experts_per_tok else \"[dim]-[/dim]\"\n        moe_size = f\"[cyan]{size_display}[/cyan]\" if model.moe_num_experts else \"[dim]-[/dim]\"\n\n        # Format repo info\n        if model.repo_id:\n            repo_abbr = \"hf\" if model.repo_type == \"huggingface\" else \"ms\"\n            repo_display = f\"{repo_abbr}:{model.repo_id}\"\n        else:\n            repo_display = \"[dim]-[/dim]\"\n\n        # Format SHA256 status\n        sha256_display = SHA256_STATUS_MAP.get(model.sha256_status, model.sha256_status)\n\n        table.add_row(\n            str(i),\n            model.name,\n            str(model.path),\n            size_display,\n            experts,\n            active,\n            moe_size,\n            repo_display,\n            sha256_display,\n        )\n\n    console.print(table)\n    console.print()\n\n    choice = prompt_int_with_retry(\n        t(\"run_int_select_model\"),\n        default=1,\n        min_val=1,\n        max_val=len(moe_models),\n    )\n\n    return moe_models[choice - 1]\n\n\ndef select_inference_method(model: Any) -> Optional[Dict[str, Any]]:\n    \"\"\"Step 2: Select inference method.\n\n    Args:\n        model: Selected UserModel\n\n    Returns:\n        Dict with 'method' (raw/amx/gguf/saved), and method-specific fields, or None if cancelled.\n    \"\"\"\n    from 
kt_kernel.cli.utils.run_configs import RunConfigManager\n\n    config_manager = RunConfigManager()\n    saved_configs = config_manager.list_configs(model.id)\n\n    # Debug output (can be removed later)\n    if False:  # Set to True for debugging\n        console.print()\n        console.print(f\"[dim]DEBUG: Model ID: {model.id}[/dim]\")\n        console.print(f\"[dim]DEBUG: Saved configs count: {len(saved_configs)}[/dim]\")\n        if saved_configs:\n            console.print(f\"[dim]DEBUG: Configs: {[c.get('config_name', '?') for c in saved_configs]}[/dim]\")\n        console.print()\n\n    console.print()\n    console.print(Panel(\"[bold cyan]Step 2: Select Inference Method[/bold cyan]\", expand=False))\n    console.print()\n\n    options = []\n    option_map = {}\n\n    # Option 1: Use saved configuration (if any)\n    if saved_configs:\n        option_idx = len(options) + 1\n        console.print(f\"  [cyan][{option_idx}][/cyan] [bold]Use Saved Configuration[/bold]\")\n        console.print(f\"      [dim]{len(saved_configs)} saved config(s) available[/dim]\")\n        options.append(str(option_idx))\n        option_map[str(option_idx)] = \"saved\"\n\n    # Option 2: Raw precision inference\n    option_idx = len(options) + 1\n    console.print(f\"  [cyan][{option_idx}][/cyan] [bold]Raw Precision Inference[/bold]\")\n    console.print(\"      [dim]FP8 / FP8_PERCHANNEL / BF16 / RAWINT4[/dim]\")\n    options.append(str(option_idx))\n    option_map[str(option_idx)] = \"raw\"\n\n    # Option 3: AMX quantized inference\n    option_idx = len(options) + 1\n    console.print(f\"  [cyan][{option_idx}][/cyan] [bold]AMX Quantized Inference[/bold]\")\n    console.print(\"      [dim]INT4 / INT8 (CPU optimized)[/dim]\")\n    options.append(str(option_idx))\n    option_map[str(option_idx)] = \"amx\"\n\n    # Option 4: GGUF inference\n    option_idx = len(options) + 1\n    console.print(f\"  [cyan][{option_idx}][/cyan] [bold]GGUF Inference[/bold]\")\n    console.print(\"      
[dim]Llamafile format[/dim]\")\n    options.append(str(option_idx))\n    option_map[str(option_idx)] = \"gguf\"\n\n    console.print()\n\n    choice = prompt_choice_with_retry(\"Select method\", choices=options, default=\"1\")\n    method = option_map[choice]\n\n    if method == \"saved\":\n        return _select_saved_config(model, saved_configs)\n    elif method == \"raw\":\n        return _configure_raw_inference(model)\n    elif method == \"amx\":\n        return _configure_amx_inference(model)\n    elif method == \"gguf\":\n        return _configure_gguf_inference(model)\n\n    return None\n\n\ndef _select_saved_config(model: Any, saved_configs: List[Dict]) -> Optional[Dict[str, Any]]:\n    \"\"\"Select from saved configurations with detailed display.\"\"\"\n    console.print()\n    console.print(\"[bold]Saved Configurations:[/bold]\")\n    console.print()\n\n    for i, cfg in enumerate(saved_configs, 1):\n        # Build method display\n        method_display = cfg.get(\"inference_method\", \"unknown\").upper()\n        kt_method = cfg.get(\"kt_method\", \"unknown\")\n\n        if cfg.get(\"inference_method\") == \"raw\":\n            raw_method = cfg.get(\"raw_method\", \"unknown\")\n            method_display = f\"{raw_method}\"\n        elif cfg.get(\"inference_method\") == \"amx\":\n            method_display = kt_method\n        elif cfg.get(\"inference_method\") == \"gguf\":\n            method_display = \"LLAMAFILE\"\n        else:\n            method_display = kt_method\n\n        # Display config header\n        console.print(f\"  [cyan][{i}][/cyan] [bold]{cfg.get('config_name', f'Config {i}')}[/bold]\")\n        console.print()\n\n        # Display detailed parameters\n        console.print(f\"      [yellow]KT Method:[/yellow]       {method_display}\")\n        console.print(f\"      [yellow]NUMA Nodes:[/yellow]      {cfg.get('numa_nodes', '?')}\")\n        console.print(f\"      [yellow]CPU Threads:[/yellow]     {cfg.get('cpu_threads', '?')}\")\n   
     console.print(f\"      [yellow]GPU Experts:[/yellow]     {cfg.get('gpu_experts', '?')}\")\n        console.print(f\"      [yellow]TP Size:[/yellow]         {cfg.get('tp_size', '?')}\")\n        console.print(f\"      [yellow]Memory Fraction:[/yellow] {cfg.get('mem_fraction_static', '?')}\")\n        console.print(f\"      [yellow]Server:[/yellow]          {cfg.get('host', '0.0.0.0')}:{cfg.get('port', 30000)}\")\n\n        # Display KV cache info if present\n        if cfg.get(\"kv_cache\"):\n            console.print(f\"      [yellow]KV Cache:[/yellow]        {cfg.get('kv_cache', '?')}\")\n            console.print(f\"      [yellow]Chunk Prefill:[/yellow]   {cfg.get('chunk_prefill', '?')}\")\n            console.print(f\"      [yellow]GPU Prefill Thr:[/yellow] {cfg.get('gpu_prefill_threshold', '?')}\")\n\n        # Display parser info if present\n        if cfg.get(\"tool_call_parser\") or cfg.get(\"reasoning_parser\"):\n            if cfg.get(\"tool_call_parser\"):\n                console.print(f\"      [yellow]Tool Call Parser:[/yellow] {cfg.get('tool_call_parser')}\")\n            if cfg.get(\"reasoning_parser\"):\n                console.print(f\"      [yellow]Reasoning Parser:[/yellow] {cfg.get('reasoning_parser')}\")\n\n        console.print()\n\n        # Build and display command preview\n        cmd_preview = _build_command_preview(model, cfg)\n        console.print(\"      [dim]Command:[/dim]\")\n        console.print()\n        for line in cmd_preview:\n            console.print(f\"      {line}\")\n        console.print()\n\n    choice = prompt_int_with_retry(\n        \"Select configuration\",\n        default=1,\n        min_val=1,\n        max_val=len(saved_configs),\n    )\n\n    selected_config = saved_configs[choice - 1].copy()\n    selected_config[\"method\"] = \"saved\"\n    return selected_config\n\n\ndef _build_command_preview(model: Any, cfg: Dict[str, Any]) -> List[str]:\n    \"\"\"Build command preview for saved configuration.\n\n    
 Args:\n        model: UserModel object\n        cfg: Saved configuration dict\n\n    Returns:\n        List of command lines for display\n    \"\"\"\n    host = cfg.get(\"host\", \"0.0.0.0\")\n    port = cfg.get(\"port\", 30000)\n\n    lines = [\n        \"python -m sglang.launch_server \\\\\",\n        f\"    --host {host} \\\\\",\n        f\"    --port {port} \\\\\",\n        f\"    --model {cfg.get('model_path', '?')} \\\\\",\n        f\"    --kt-weight-path {cfg.get('weights_path', '?')} \\\\\",\n        f\"    --kt-cpuinfer {cfg.get('cpu_threads', '?')} \\\\\",\n        f\"    --kt-threadpool-count {cfg.get('numa_nodes', '?')} \\\\\",\n        f\"    --kt-num-gpu-experts {cfg.get('gpu_experts', '?')} \\\\\",\n        f\"    --kt-method {cfg.get('kt_method', '?')} \\\\\",\n    ]\n\n    # Add GPU prefill threshold (use saved value or default)\n    gpu_prefill = cfg.get(\"gpu_prefill_threshold\", 500)\n    lines.append(f\"    --kt-gpu-prefill-token-threshold {gpu_prefill} \\\\\")\n    lines.append(\"    --kt-enable-dynamic-expert-update \\\\\")\n\n    # Add attention backend\n    lines.append(\"    --attention-backend flashinfer \\\\\")\n    lines.append(\"    --trust-remote-code \\\\\")\n\n    # Add memory and performance settings\n    lines.append(f\"    --mem-fraction-static {cfg.get('mem_fraction_static', 0.9)} \\\\\")\n\n    # Add KV cache settings\n    chunk_prefill = cfg.get(\"chunk_prefill\", 32768)\n    max_tokens = cfg.get(\"kv_cache\", 32768)\n    lines.append(f\"    --chunked-prefill-size {chunk_prefill} \\\\\")\n    lines.append(f\"    --max-total-tokens {max_tokens} \\\\\")\n\n    lines.append(\"    --max-running-requests 4 \\\\\")\n    lines.append(\"    --watchdog-timeout 3000 \\\\\")\n    lines.append(\"    --enable-mixed-chunk \\\\\")\n\n    # Add TP size (will be updated with actual GPU selection)\n    lines.append(f\"    --tensor-parallel-size {cfg.get('tp_size', '?')} \\\\\")\n    lines.append(\"    --enable-p2p-check 
\\\\\")\n\n    # Add FP8 backend if using FP8\n    kt_method = cfg.get(\"kt_method\", \"\")\n    if \"FP8\" in kt_method.upper():\n        lines.append(\"    --fp8-gemm-backend triton \\\\\")\n\n    # Add parsers if configured\n    if cfg.get(\"tool_call_parser\"):\n        lines.append(f\"    --tool-call-parser {cfg['tool_call_parser']} \\\\\")\n    if cfg.get(\"reasoning_parser\"):\n        lines.append(f\"    --reasoning-parser {cfg['reasoning_parser']} \\\\\")\n\n    # Remove trailing backslash from last line\n    if lines:\n        lines[-1] = lines[-1].rstrip(\" \\\\\")\n\n    return lines\n\n\ndef _configure_raw_inference(model: Any) -> Dict[str, Any]:\n    \"\"\"Configure raw precision inference.\"\"\"\n    console.print()\n    console.print(\"[bold]Select Raw Precision Type:[/bold]\")\n    console.print()\n    console.print(\"  [cyan][1][/cyan] FP8\")\n    console.print(\"  [cyan][2][/cyan] FP8_PERCHANNEL\")\n    console.print(\"  [cyan][3][/cyan] BF16\")\n    console.print(\"  [cyan][4][/cyan] RAWINT4\")\n    console.print()\n\n    choice = prompt_choice_with_retry(\"Select precision\", choices=[\"1\", \"2\", \"3\", \"4\"], default=\"1\")\n\n    precision_map = {\n        \"1\": \"FP8\",\n        \"2\": \"FP8_PERCHANNEL\",\n        \"3\": \"BF16\",\n        \"4\": \"RAWINT4\",\n    }\n\n    raw_method = precision_map[choice]\n\n    return {\n        \"method\": \"raw\",\n        \"raw_method\": raw_method,\n        \"kt_method\": raw_method,\n        \"model_path\": model.path,\n        \"weights_path\": model.path,  # Same as model path for raw\n    }\n\n\ndef _configure_amx_inference(model: Any) -> Optional[Dict[str, Any]]:\n    \"\"\"Configure AMX quantized inference.\"\"\"\n    from kt_kernel.cli.utils.user_model_registry import UserModelRegistry\n    from kt_kernel.cli.commands.model import is_amx_weights\n\n    registry = UserModelRegistry()\n    all_models = registry.list_models()\n\n    # Filter AMX models\n    amx_models = []\n    for m in 
all_models:\n        if m.format == \"safetensors\":\n            is_amx, numa = is_amx_weights(m.path)\n            if is_amx:\n                # Check if it's derived from the selected model\n                if m.amx_source_model == model.name:\n                    amx_models.insert(0, m)  # Prioritize matched models\n                else:\n                    amx_models.append(m)\n\n    if not amx_models:\n        console.print(\"[yellow]No AMX quantized models found.[/yellow]\")\n        console.print(\"  Quantize your model with: [cyan]kt quant[/cyan]\")\n        return None\n\n    console.print()\n    console.print(\"[bold]Select AMX Weights:[/bold]\")\n    console.print()\n\n    for i, m in enumerate(amx_models, 1):\n        is_amx, numa = is_amx_weights(m.path)\n        method_str = m.amx_quant_method.upper() if m.amx_quant_method else \"Unknown\"\n        match_indicator = \"[green]★[/green]\" if m.amx_source_model == model.name else \" \"\n        console.print(f\"  {match_indicator} [cyan][{i}][/cyan] {m.name}\")\n        console.print(\n            f\"      [dim]Method: AMX{method_str}, NUMA: {numa}, Source: {m.amx_source_model or 'Unknown'}[/dim]\"\n        )\n\n    console.print()\n    choice = prompt_int_with_retry(\n        \"Select AMX weights\",\n        default=1,\n        min_val=1,\n        max_val=len(amx_models),\n    )\n\n    selected_amx = amx_models[choice - 1]\n    is_amx, numa = is_amx_weights(selected_amx.path)\n    kt_method = f\"AMX{selected_amx.amx_quant_method.upper()}\" if selected_amx.amx_quant_method else \"AMXINT4\"\n\n    return {\n        \"method\": \"amx\",\n        \"kt_method\": kt_method,\n        \"model_path\": model.path,\n        \"weights_path\": selected_amx.path,\n        \"amx_numa_nodes\": numa,\n    }\n\n\ndef _configure_gguf_inference(model: Any) -> Optional[Dict[str, Any]]:\n    \"\"\"Configure GGUF inference.\"\"\"\n    from kt_kernel.cli.utils.user_model_registry import UserModelRegistry\n\n    registry = 
UserModelRegistry()\n    all_models = registry.list_models()\n\n    # Filter GGUF models\n    gguf_models = [m for m in all_models if m.format == \"gguf\"]\n\n    if not gguf_models:\n        console.print(\"[yellow]No GGUF models found.[/yellow]\")\n        console.print(\"  Add GGUF models with: [cyan]kt model add /path/to/model.gguf[/cyan]\")\n        return None\n\n    console.print()\n    console.print(\"[bold]Select GGUF Weights:[/bold]\")\n    console.print()\n\n    for i, m in enumerate(gguf_models, 1):\n        console.print(f\"  [cyan][{i}][/cyan] {m.name}\")\n        console.print(f\"      [dim]Path: {m.path}[/dim]\")\n\n    console.print()\n    choice = prompt_int_with_retry(\n        \"Select GGUF weights\",\n        default=1,\n        min_val=1,\n        max_val=len(gguf_models),\n    )\n\n    selected_gguf = gguf_models[choice - 1]\n\n    return {\n        \"method\": \"gguf\",\n        \"kt_method\": \"LLAMAFILE\",\n        \"model_path\": model.path,\n        \"weights_path\": selected_gguf.path,\n    }\n\n\ndef configure_numa_and_cpu(method_config: Dict[str, Any]) -> Dict[str, int]:\n    \"\"\"Step 3: Configure NUMA and CPU threads.\n\n    Args:\n        method_config: Config from step 2 (may contain amx_numa_nodes hint)\n\n    Returns:\n        Dict with 'numa_nodes' and 'cpu_threads'\n    \"\"\"\n    from kt_kernel.cli.utils.environment import detect_cpu_info\n\n    cpu_info = detect_cpu_info()\n    max_numa = cpu_info.numa_nodes\n    max_cores = cpu_info.threads  # Use logical threads instead of physical cores\n\n    console.print()\n    console.print(Panel(\"[bold cyan]Step 3: NUMA and CPU Configuration[/bold cyan]\", expand=False))\n    console.print()\n\n    # Show AMX hint if applicable\n    if method_config.get(\"method\") == \"amx\" and method_config.get(\"amx_numa_nodes\"):\n        amx_numa = method_config[\"amx_numa_nodes\"]\n        console.print(f\"[yellow]⚠ Note: This AMX model was quantized with NUMA={amx_numa}[/yellow]\")\n       
 console.print(f\"[yellow]  For optimal performance, use the same NUMA setting.[/yellow]\")\n        console.print()\n        default_numa = amx_numa\n    else:\n        default_numa = max_numa\n\n    numa_nodes = prompt_int_with_retry(\n        f\"NUMA Nodes (1 to {max_numa})\",\n        default=default_numa,\n        min_val=1,\n        max_val=max_numa,\n    )\n\n    default_threads = int(max_cores * 0.8)\n    cpu_threads = prompt_int_with_retry(\n        f\"CPU Threads (1 to {max_cores})\",\n        default=default_threads,\n        min_val=1,\n        max_val=max_cores,\n    )\n\n    return {\n        \"numa_nodes\": numa_nodes,\n        \"cpu_threads\": cpu_threads,\n    }\n\n\ndef configure_gpu_experts(model: Any) -> int:\n    \"\"\"Step 4: Configure GPU expert count.\n\n    Args:\n        model: Selected model\n\n    Returns:\n        Number of GPU experts\n    \"\"\"\n    from kt_kernel.cli.utils.analyze_moe_model import analyze_moe_model\n\n    console.print()\n    console.print(Panel(\"[bold cyan]Step 4: GPU Experts Configuration[/bold cyan]\", expand=False))\n    console.print()\n\n    # Try to get num_experts from model\n    try:\n        moe_result = analyze_moe_model(model.path)\n        num_experts = moe_result.get(\"num_experts\", 256)\n    except Exception:\n        num_experts = 256  # Default fallback\n\n    console.print(f\"[dim]Model has {num_experts} experts total[/dim]\")\n    console.print()\n    console.print(\"[yellow]⚠ Tip: More GPU experts = faster inference, but uses more VRAM[/yellow]\")\n    console.print()\n\n    default_experts = min(8, num_experts)\n    gpu_experts = prompt_int_with_retry(\n        f\"GPU Experts per layer (0 to {num_experts})\",\n        default=default_experts,\n        min_val=0,\n        max_val=num_experts,\n    )\n\n    return gpu_experts\n\n\ndef configure_kv_cache(is_raw_inference: bool) -> Optional[Dict[str, int]]:\n    \"\"\"Step 5: Configure KV Cache (only for raw inference).\n\n    Args:\n        
is_raw_inference: True if using raw precision inference\n\n    Returns:\n        Dict with 'kv_cache', 'chunk_prefill', 'gpu_prefill_threshold' or None if not applicable\n    \"\"\"\n    if not is_raw_inference:\n        return None\n\n    console.print()\n    console.print(Panel(\"[bold cyan]Step 5: KV Cache and Prefill Configuration[/bold cyan]\", expand=False))\n    console.print()\n    console.print(\"[dim]These settings control memory allocation and prefill batch size[/dim]\")\n    console.print(\"[dim]gpu-prefill-token-threshold: maximum length for single layerwise prefill[/dim]\")\n    console.print()\n\n    kv_cache = prompt_int_with_retry(\"KV Cache Size (max_total_tokens)\", default=32768, min_val=1)\n    chunk_prefill = prompt_int_with_retry(\"Chunk Prefill Size\", default=32768, min_val=1)\n    gpu_prefill_threshold = prompt_int_with_retry(\"GPU Prefill Token Threshold\", default=500, min_val=1)\n\n    return {\n        \"kv_cache\": kv_cache,\n        \"chunk_prefill\": chunk_prefill,\n        \"gpu_prefill_threshold\": gpu_prefill_threshold,\n    }\n\n\ndef select_gpus_and_tp(\n    required_tp_size: Optional[int] = None, saved_mem_fraction: Optional[float] = None\n) -> Tuple[List[int], int, float]:\n    \"\"\"Step 6: Select GPUs, TP size, and memory fraction.\n\n    Args:\n        required_tp_size: If specified, user must select exactly this many GPUs.\n                         If None, TP size can be any power of 2.\n        saved_mem_fraction: If specified, use this memory fraction instead of prompting.\n                           Used when loading saved configurations.\n\n    Returns:\n        Tuple of (selected_gpu_ids, tp_size, mem_fraction_static)\n    \"\"\"\n    gpu_info_list = get_gpu_info()\n\n    if not gpu_info_list:\n        console.print(\"[red]No GPUs detected[/red]\")\n        return [], 0, 0.9\n\n    console.print()\n    if required_tp_size is not None:\n        console.print(Panel(f\"[bold cyan]Select {required_tp_size} GPUs (for 
saved config)[/bold cyan]\", expand=False))\n        console.print()\n        console.print(f\"[yellow]Required TP size: {required_tp_size}[/yellow]\")\n        console.print(f\"[yellow]You must select exactly {required_tp_size} GPU(s)[/yellow]\")\n    else:\n        console.print(Panel(\"[bold cyan]Step 6: GPU Selection and Memory[/bold cyan]\", expand=False))\n        console.print()\n        console.print(\"[dim]TP (Tensor Parallel) size must be a power of 2: 1, 2, 4, 8, ...[/dim]\")\n    console.print()\n\n    # Display GPUs\n    table = Table(box=box.ROUNDED, show_header=True, header_style=\"bold cyan\")\n    table.add_column(\"ID\", justify=\"right\", style=\"cyan\")\n    table.add_column(\"Name\", style=\"white\")\n    table.add_column(\"Free VRAM\", justify=\"right\", style=\"green\")\n    table.add_column(\"Total VRAM\", justify=\"right\", style=\"dim\")\n\n    for gpu in gpu_info_list:\n        table.add_row(str(gpu[\"id\"]), gpu[\"name\"], f\"{gpu['free_vram_gb']:.1f} GB\", f\"{gpu['total_vram_gb']:.1f} GB\")\n\n    console.print(table)\n    console.print()\n\n    # Validator function\n    def validate_tp_requirements(gpu_ids: List[int]) -> tuple[bool, Optional[str]]:\n        \"\"\"Validate TP requirements based on required_tp_size.\"\"\"\n        actual_count = len(gpu_ids)\n\n        if required_tp_size is not None:\n            # Exact count required\n            if actual_count != required_tp_size:\n                return False, f\"Must select exactly {required_tp_size} GPU(s), but you selected {actual_count}.\"\n        else:\n            # Must be power of 2\n            if actual_count & (actual_count - 1) != 0:\n                return (\n                    False,\n                    f\"TP size ({actual_count}) must be a power of 2. Valid sizes: 1, 2, 4, 8, 16, 32, ...\\nYou selected {actual_count} GPU(s). 
Please select a different number.\",\n                )\n\n        return True, None\n\n    # Generate default GPU selection\n    if required_tp_size is not None:\n        # For saved config: select first N GPUs\n        if required_tp_size <= len(gpu_info_list):\n            default_gpus = \",\".join(str(i) for i in range(required_tp_size))\n        else:\n            default_gpus = \",\".join(str(i) for i in range(len(gpu_info_list)))\n        prompt_text = f\"Enter {required_tp_size} GPU ID(s) separated by commas (e.g., 0,1,2,3)\"\n    else:\n        # For new config: select all GPUs\n        default_gpus = \",\".join(str(i) for i in range(len(gpu_info_list)))\n        prompt_text = \"Enter GPU IDs separated by commas (e.g., 0,1,2,3)\"\n        console.print(prompt_text)\n        console.print(f\"  Or press Enter to use all {len(gpu_info_list)} GPUs\")\n\n    console.print()\n\n    selected_gpu_ids = prompt_int_list_with_retry(\n        \"GPU IDs\",\n        default=default_gpus,\n        min_val=0,\n        max_val=len(gpu_info_list) - 1,\n        validator=validate_tp_requirements,\n    )\n\n    tp_size = len(selected_gpu_ids)\n\n    console.print()\n    console.print(f\"[green]✓[/green] Selected {tp_size} GPU(s): {selected_gpu_ids}\")\n    console.print()\n\n    # Memory fraction - use saved value if provided, otherwise prompt\n    if saved_mem_fraction is not None:\n        mem_fraction = saved_mem_fraction\n        console.print(f\"[dim]Using saved memory fraction: {mem_fraction}[/dim]\")\n    else:\n        mem_fraction = prompt_float_with_retry(\n            \"Static Memory Fraction (0.0-1.0)\",\n            default=0.9,\n            min_val=0.0,\n            max_val=1.0,\n        )\n\n    return selected_gpu_ids, tp_size, mem_fraction\n\n\ndef configure_parsers() -> Dict[str, Optional[str]]:\n    \"\"\"Step 7: Configure parsers (optional).\n\n    Returns:\n        Dict with 'tool_call_parser' and 'reasoning_parser' (can be None)\n    \"\"\"\n    
console.print()\n    console.print(Panel(\"[bold cyan]Step 7: Parser Configuration (Optional)[/bold cyan]\", expand=False))\n    console.print()\n    console.print(\"[dim]Press Enter to skip (no parser will be added)[/dim]\")\n    console.print()\n\n    tool_call_parser = Prompt.ask(\"Tool Call Parser (e.g., glm47)\", default=\"\")\n    tool_call_parser = tool_call_parser.strip() if tool_call_parser else None\n\n    reasoning_parser = Prompt.ask(\"Reasoning Parser (e.g., glm45)\", default=\"\")\n    reasoning_parser = reasoning_parser.strip() if reasoning_parser else None\n\n    if tool_call_parser or reasoning_parser:\n        console.print()\n        if tool_call_parser:\n            console.print(f\"[green]✓[/green] Tool Call Parser: {tool_call_parser}\")\n        if reasoning_parser:\n            console.print(f\"[green]✓[/green] Reasoning Parser: {reasoning_parser}\")\n    else:\n        console.print()\n        console.print(\"[dim]No parsers configured[/dim]\")\n\n    return {\n        \"tool_call_parser\": tool_call_parser,\n        \"reasoning_parser\": reasoning_parser,\n    }\n\n\ndef configure_host_and_port() -> Dict[str, Any]:\n    \"\"\"Step 8: Configure host and port with availability check.\n\n    Returns:\n        Dict with 'host' and 'port'\n    \"\"\"\n    from kt_kernel.cli.utils.port_checker import is_port_available\n\n    console.print()\n    console.print(Panel(\"[bold cyan]Step 8: Server Configuration[/bold cyan]\", expand=False))\n    console.print()\n\n    # Get host\n    host = Prompt.ask(\"Server Host\", default=\"0.0.0.0\")\n\n    # Get port with availability check\n    while True:\n        port = prompt_int_with_retry(\n            \"Server Port\",\n            default=30000,\n            min_val=1024,\n            max_val=65535,\n        )\n\n        # Check if port is available\n        console.print()\n        console.print(f\"[dim]Checking port {port} availability...[/dim]\")\n\n        if is_port_available(host, port):\n           
 console.print(f\"[green]✓[/green] Port {port} is available\")\n            break\n        else:\n            console.print(f\"[red]✗[/red] Port {port} is already in use\")\n            console.print()\n\n            # Suggest next available port\n            from kt_kernel.cli.utils.port_checker import find_available_port\n\n            found, suggested_port = find_available_port(host, port + 1, max_attempts=100)\n            if found:\n                console.print(f\"[yellow]Suggestion:[/yellow] Port {suggested_port} is available\")\n            console.print()\n\n    console.print()\n    console.print(f\"[green]✓[/green] Server will listen on {host}:{port}\")\n\n    return {\n        \"host\": host,\n        \"port\": port,\n    }\n\n\ndef save_config_prompt(model: Any, full_config: Dict[str, Any]) -> bool:\n    \"\"\"Step 9: Prompt to save configuration.\n\n    Args:\n        model: Selected model\n        full_config: Complete configuration dict\n\n    Returns:\n        True if saved, False otherwise\n    \"\"\"\n    console.print()\n    console.print(Panel(\"[bold cyan]Step 9: Save Configuration[/bold cyan]\", expand=False))\n    console.print()\n\n    if not Confirm.ask(\"Save this configuration for future use?\", default=True):\n        return False\n\n    config_name = Prompt.ask(\"Configuration name\", default=f\"Config {full_config.get('inference_method', 'default')}\")\n\n    from kt_kernel.cli.utils.run_configs import RunConfigManager\n\n    config_manager = RunConfigManager()\n\n    # Prepare config to save (exclude runtime-only fields and non-serializable objects)\n    save_config = {\n        \"config_name\": config_name,\n        \"inference_method\": full_config[\"inference_method\"],\n        \"kt_method\": full_config[\"kt_method\"],\n        \"model_path\": str(full_config[\"model_path\"]),\n        \"weights_path\": str(full_config[\"weights_path\"]),\n        \"numa_nodes\": full_config[\"numa_nodes\"],\n        \"cpu_threads\": 
full_config[\"cpu_threads\"],\n        \"gpu_experts\": full_config[\"gpu_experts\"],\n        \"tp_size\": full_config[\"tp_size\"],\n        \"mem_fraction_static\": full_config[\"mem_fraction_static\"],\n        \"host\": full_config[\"host\"],\n        \"port\": full_config[\"port\"],\n        # Note: selected_gpus is NOT saved - user will select GPUs when loading config\n    }\n\n    # Add parser config if present\n    if full_config.get(\"tool_call_parser\"):\n        save_config[\"tool_call_parser\"] = full_config[\"tool_call_parser\"]\n    if full_config.get(\"reasoning_parser\"):\n        save_config[\"reasoning_parser\"] = full_config[\"reasoning_parser\"]\n\n    # Add raw-specific config if present\n    if full_config.get(\"raw_method\"):\n        save_config[\"raw_method\"] = full_config[\"raw_method\"]\n\n    if full_config.get(\"kv_cache\"):\n        save_config[\"kv_cache\"] = full_config[\"kv_cache\"]\n        save_config[\"chunk_prefill\"] = full_config[\"chunk_prefill\"]\n        save_config[\"gpu_prefill_threshold\"] = full_config[\"gpu_prefill_threshold\"]\n\n    config_manager.save_config(model.id, save_config)\n\n    console.print()\n    console.print(f\"[green]✓[/green] Configuration saved: {config_name}\")\n\n    return True\n\n\ndef interactive_run_config() -> Optional[Dict[str, Any]]:\n    \"\"\"\n    Main interactive configuration flow for kt run.\n\n    Returns:\n        Complete configuration dict or None if cancelled.\n    \"\"\"\n    # Step 1: Select model\n    model = select_model()\n    if not model:\n        return None\n\n    # Step 2: Select inference method\n    method_config = select_inference_method(model)\n    if not method_config:\n        return None\n\n    # If using saved config, add model object and return directly\n    if method_config.get(\"method\") == \"saved\":\n        console.print()\n        console.print(\"[green]✓[/green] Using saved configuration\")\n\n        # Let user select GPUs (must match saved TP 
size)\n        saved_tp_size = method_config.get(\"tp_size\", 1)\n\n        console.print()\n        console.print(f\"[yellow]This configuration requires TP={saved_tp_size}[/yellow]\")\n        console.print(f\"[yellow]Please select {saved_tp_size} GPU(s)[/yellow]\")\n\n        # Get saved memory fraction\n        saved_mem_fraction = method_config.get(\"mem_fraction_static\", 0.9)\n\n        selected_gpus, actual_tp_size, _ = select_gpus_and_tp(\n            required_tp_size=saved_tp_size, saved_mem_fraction=saved_mem_fraction\n        )\n        if not selected_gpus:\n            return None\n\n        # Update config with selected GPUs (keep saved mem_fraction_static)\n        method_config[\"selected_gpus\"] = selected_gpus\n        # tp_size is already in method_config from saved data\n\n        # Check port availability\n        from kt_kernel.cli.utils.port_checker import is_port_available, find_available_port\n\n        saved_host = method_config.get(\"host\", \"0.0.0.0\")\n        saved_port = method_config.get(\"port\", 30000)\n\n        console.print()\n        console.print(f\"[dim]Checking port {saved_port} availability...[/dim]\")\n\n        if is_port_available(saved_host, saved_port):\n            console.print(f\"[green]✓[/green] Port {saved_port} is available\")\n            method_config[\"port\"] = saved_port\n            method_config[\"host\"] = saved_host\n        else:\n            console.print(f\"[red]✗[/red] Port {saved_port} is already in use\")\n            console.print()\n\n            # Suggest next available port\n            found, suggested_port = find_available_port(saved_host, saved_port + 1, max_attempts=100)\n            if found:\n                console.print(f\"[yellow]Suggestion:[/yellow] Port {suggested_port} is available\")\n            console.print()\n\n            # Ask user for new port\n            while True:\n                new_port = prompt_int_with_retry(\n                    \"Enter new port\",\n               
     default=suggested_port if found else saved_port + 1,\n                    min_val=1024,\n                    max_val=65535,\n                )\n\n                console.print()\n                console.print(f\"[dim]Checking port {new_port} availability...[/dim]\")\n\n                if is_port_available(saved_host, new_port):\n                    console.print(f\"[green]✓[/green] Port {new_port} is available\")\n                    method_config[\"port\"] = new_port\n                    method_config[\"host\"] = saved_host\n                    break\n                else:\n                    console.print(f\"[red]✗[/red] Port {new_port} is already in use\")\n                    console.print()\n\n        # Add model object for run.py compatibility\n        method_config[\"model\"] = model\n\n        # Ensure paths are Path objects\n        from pathlib import Path\n\n        if \"model_path\" in method_config:\n            method_config[\"model_path\"] = Path(method_config[\"model_path\"])\n        if \"weights_path\" in method_config:\n            method_config[\"weights_path\"] = Path(method_config[\"weights_path\"])\n\n        # Display configuration summary\n        console.print()\n        console.print(Panel(\"[bold cyan]Saved Configuration[/bold cyan]\", expand=False))\n        console.print()\n        _display_config_summary(method_config)\n        console.print()\n\n        # Start directly without confirmation when using saved config\n        return method_config\n\n    # Step 3: Configure NUMA and CPU\n    numa_cpu_config = configure_numa_and_cpu(method_config)\n\n    # Step 4: Configure GPU experts\n    gpu_experts = configure_gpu_experts(model)\n\n    # Step 5: Configure KV Cache (only for raw)\n    is_raw = method_config.get(\"method\") == \"raw\"\n    kv_config = configure_kv_cache(is_raw)\n\n    # Step 6: Select GPUs and TP\n    selected_gpus, tp_size, mem_fraction = select_gpus_and_tp()\n    if not selected_gpus:\n        return None\n\n    
# Step 7: Configure parsers (optional)\n    parser_config = configure_parsers()\n\n    # Step 8: Configure host and port\n    server_config = configure_host_and_port()\n\n    # Build complete configuration\n    full_config = {\n        \"model\": model,\n        \"inference_method\": method_config[\"method\"],\n        \"kt_method\": method_config[\"kt_method\"],\n        \"model_path\": method_config[\"model_path\"],\n        \"weights_path\": method_config[\"weights_path\"],\n        **numa_cpu_config,\n        \"gpu_experts\": gpu_experts,\n        \"selected_gpus\": selected_gpus,\n        \"tp_size\": tp_size,\n        \"mem_fraction_static\": mem_fraction,\n        **parser_config,  # Add parser config\n        **server_config,  # Add server config (host, port)\n    }\n\n    # Add raw-specific config\n    if kv_config:\n        full_config[\"raw_method\"] = method_config.get(\"raw_method\")\n        full_config.update(kv_config)\n\n    # Step 9: Save configuration\n    save_config_prompt(model, full_config)\n\n    # Final confirmation\n    console.print()\n    console.print(Panel(\"[bold cyan]Configuration Complete[/bold cyan]\", expand=False))\n    console.print()\n    _display_config_summary(full_config)\n    console.print()\n\n    if not Confirm.ask(\"[bold green]Start model server with this configuration?[/bold green]\", default=True):\n        console.print(\"[yellow]Cancelled[/yellow]\")\n        return None\n\n    return full_config\n\n\ndef _display_config_summary(config: Dict[str, Any]):\n    \"\"\"Display configuration summary.\"\"\"\n    model = config[\"model\"]\n    console.print(f\"  Model:           {model.name}\")\n    console.print(f\"  KT Method:       {config['kt_method']}\")\n    console.print(f\"  NUMA Nodes:      {config['numa_nodes']}\")\n    console.print(f\"  CPU Threads:     {config['cpu_threads']}\")\n    console.print(f\"  GPU Experts:     {config['gpu_experts']}\")\n\n    # Handle both new config and saved config format\n    
tp_size = config.get(\"tp_size\", len(config.get(\"selected_gpus\", [])))\n    selected_gpus = config.get(\"selected_gpus\", [])\n\n    console.print(f\"  GPUs:            {selected_gpus} (TP={tp_size})\")\n    console.print(f\"  Memory Fraction: {config['mem_fraction_static']}\")\n\n    # Server config\n    host = config.get(\"host\", \"0.0.0.0\")\n    port = config.get(\"port\", 30000)\n    console.print(f\"  Server:          {host}:{port}\")\n\n    if config.get(\"kv_cache\"):\n        console.print(f\"  KV Cache:        {config['kv_cache']}\")\n        console.print(f\"  Chunk Prefill:   {config['chunk_prefill']}\")\n        console.print(f\"  GPU Prefill Thr: {config['gpu_prefill_threshold']}\")\n\n    # Display parsers if configured\n    if config.get(\"tool_call_parser\") or config.get(\"reasoning_parser\"):\n        console.print()\n        if config.get(\"tool_call_parser\"):\n            console.print(f\"  Tool Call Parser: {config['tool_call_parser']}\")\n        if config.get(\"reasoning_parser\"):\n            console.print(f\"  Reasoning Parser: {config['reasoning_parser']}\")\n"
  },
  {
    "path": "kt-kernel/python/cli/utils/sglang_checker.py",
    "content": "\"\"\"\nSGLang installation checker and installation instructions provider.\n\nThis module provides utilities to:\n- Check if SGLang is installed and get its metadata\n- Provide installation instructions when SGLang is not found\n\"\"\"\n\nimport subprocess\nimport sys\nfrom pathlib import Path\nfrom typing import Optional\n\nfrom kt_kernel.cli.i18n import t\nfrom kt_kernel.cli.utils.console import console\n\n\ndef check_sglang_installation() -> dict:\n    \"\"\"Check if SGLang is installed and get its metadata.\n\n    Returns:\n        dict with keys:\n        - installed: bool\n        - version: str or None\n        - location: str or None (installation path)\n        - editable: bool (whether installed in editable mode)\n        - git_info: dict or None (git remote and branch if available)\n        - from_source: bool (whether installed from source repository)\n    \"\"\"\n    try:\n        # Try to import sglang\n        import sglang\n\n        version = getattr(sglang, \"__version__\", None)\n\n        # Use pip show to get detailed package information\n        location = None\n        editable = False\n        git_info = None\n        from_source = False\n        is_kvcache_fork = False  # True if installed as sglang-kt package\n\n        try:\n            # Get pip show output (try sglang-kt first, then sglang)\n            result = subprocess.run(\n                [sys.executable, \"-m\", \"pip\", \"show\", \"sglang-kt\"],\n                capture_output=True,\n                text=True,\n                timeout=10,\n            )\n            if result.returncode == 0:\n                is_kvcache_fork = True  # sglang-kt package name proves it's the fork\n            else:\n                result = subprocess.run(\n                    [sys.executable, \"-m\", \"pip\", \"show\", \"sglang\"],\n                    capture_output=True,\n                    text=True,\n                    timeout=10,\n                )\n\n            if 
result.returncode == 0:\n                pip_info = {}\n                for line in result.stdout.split(\"\\n\"):\n                    if \":\" in line:\n                        key, value = line.split(\":\", 1)\n                        pip_info[key.strip()] = value.strip()\n\n                location = pip_info.get(\"Location\")\n                editable_location = pip_info.get(\"Editable project location\")\n\n                if editable_location:\n                    editable = True\n                    location = editable_location\n        except (subprocess.TimeoutExpired, FileNotFoundError, OSError):\n            # Fallback to module location\n            if hasattr(sglang, \"__file__\") and sglang.__file__:\n                location = str(Path(sglang.__file__).parent.parent)\n\n        # Check if it's installed from source (has .git directory)\n        if location:\n            git_root = None\n            check_path = Path(location)\n\n            # Check current directory and up to 2 parent directories\n            for _ in range(3):\n                git_dir = check_path / \".git\"\n                if git_dir.exists():\n                    git_root = check_path\n                    from_source = True\n                    break\n                if check_path.parent == check_path:  # Reached root\n                    break\n                check_path = check_path.parent\n\n            if from_source and git_root:\n                # Try to get git remote and branch info\n                try:\n                    # Get remote URL\n                    result = subprocess.run(\n                        [\"git\", \"remote\", \"get-url\", \"origin\"],\n                        cwd=git_root,\n                        capture_output=True,\n                        text=True,\n                        timeout=5,\n                    )\n                    remote_url = result.stdout.strip() if result.returncode == 0 else None\n\n                    # Extract org/repo from 
URL\n                    remote_short = None\n                    if remote_url:\n                        # Handle both https and git@ URLs\n                        if \"github.com\" in remote_url:\n                            parts = remote_url.rstrip(\"/\").replace(\".git\", \"\").split(\"github.com\")[-1]\n                            remote_short = parts.lstrip(\"/\").lstrip(\":\")\n\n                    # Get current branch\n                    result = subprocess.run(\n                        [\"git\", \"branch\", \"--show-current\"],\n                        cwd=git_root,\n                        capture_output=True,\n                        text=True,\n                        timeout=5,\n                    )\n                    branch = result.stdout.strip() if result.returncode == 0 else None\n\n                    if remote_url or branch:\n                        git_info = {\n                            \"remote\": remote_short or remote_url,\n                            \"branch\": branch,\n                        }\n                except (subprocess.TimeoutExpired, FileNotFoundError, OSError):\n                    pass\n\n        return {\n            \"installed\": True,\n            \"version\": version,\n            \"location\": location,\n            \"editable\": editable,\n            \"git_info\": git_info,\n            \"from_source\": from_source,\n            \"is_kvcache_fork\": is_kvcache_fork,\n        }\n    except ImportError:\n        return {\n            \"installed\": False,\n            \"version\": None,\n            \"location\": None,\n            \"editable\": False,\n            \"git_info\": None,\n            \"from_source\": False,\n            \"is_kvcache_fork\": False,\n        }\n\n\ndef get_sglang_install_instructions(lang: Optional[str] = None) -> str:\n    \"\"\"Get SGLang installation instructions.\n\n    Args:\n        lang: Language code ('en' or 'zh'). 
If None, uses current language setting.\n\n    Returns:\n        Formatted installation instructions string.\n    \"\"\"\n    from kt_kernel.cli.i18n import get_lang\n\n    if lang is None:\n        lang = get_lang()\n\n    if lang == \"zh\":\n        return \"\"\"\n[bold yellow]SGLang \\u672a\\u5b89\\u88c5[/bold yellow]\n\n\\u8bf7\\u9009\\u62e9\\u4ee5\\u4e0b\\u65b9\\u5f0f\\u4e4b\\u4e00\\u5b89\\u88c5 SGLang (kvcache-ai \\u5206\\u652f):\n\n[bold]\\u65b9\\u5f0f A - \\u4e00\\u952e\\u5b89\\u88c5 (\\u63a8\\u8350):[/bold]\n   \\u4ece ktransformers \\u6839\\u76ee\\u5f55\\u8fd0\\u884c:\n   [cyan]./install.sh[/cyan]\n\n[bold]\\u65b9\\u5f0f B - pip \\u5b89\\u88c5:[/bold]\n   [cyan]pip install sglang-kt[/cyan]\n\n[bold]\\u65b9\\u5f0f C - \\u4ece\\u6e90\\u7801\\u5b89\\u88c5:[/bold]\n   git clone --recursive https://github.com/kvcache-ai/ktransformers.git\n   cd ktransformers\n   pip install \"third_party/sglang/python[all]\"\n\n[dim]\\u6ce8\\u610f: \\u8bf7\\u786e\\u4fdd\\u5728\\u6b63\\u786e\\u7684 Python \\u73af\\u5883\\u4e2d\\u6267\\u884c\\u4ee5\\u4e0a\\u547d\\u4ee4[/dim]\n\"\"\"\n    else:\n        return \"\"\"\n[bold yellow]SGLang is not installed[/bold yellow]\n\nInstall SGLang (kvcache-ai fork) using one of these methods:\n\n[bold]Option A - One-click install (recommended):[/bold]\n   From the ktransformers root directory, run:\n   [cyan]./install.sh[/cyan]\n\n[bold]Option B - pip install:[/bold]\n   [cyan]pip install sglang-kt[/cyan]\n\n[bold]Option C - From source:[/bold]\n   git clone --recursive https://github.com/kvcache-ai/ktransformers.git\n   cd ktransformers\n   pip install \"third_party/sglang/python[all]\"\n\n[dim]Note: Make sure to run these commands in the correct Python environment[/dim]\n\"\"\"\n\n\ndef print_sglang_install_instructions() -> None:\n    \"\"\"Print SGLang installation instructions to console.\"\"\"\n    instructions = get_sglang_install_instructions()\n    console.print(instructions)\n\n\ndef check_sglang_and_warn() -> bool:\n    
\"\"\"Check if SGLang is installed, print warning if not.\n\n    Returns:\n        True if SGLang is installed, False otherwise.\n    \"\"\"\n    info = check_sglang_installation()\n\n    if not info[\"installed\"]:\n        print_sglang_install_instructions()\n        return False\n\n    # Check if installed from PyPI (not recommended)\n    if info[\"installed\"] and not info[\"from_source\"]:\n        from kt_kernel.cli.utils.console import print_warning\n\n        print_warning(t(\"sglang_pypi_warning\"))\n        console.print()\n        console.print(\"[dim]\" + t(\"sglang_recommend_source\") + \"[/dim]\")\n        console.print()\n\n    return True\n\n\ndef _get_sglang_kt_kernel_cache_path() -> Path:\n    \"\"\"Get the path to the sglang kt-kernel support cache file.\"\"\"\n    cache_dir = Path.home() / \".ktransformers\" / \"cache\"\n    cache_dir.mkdir(parents=True, exist_ok=True)\n    return cache_dir / \"sglang_kt_kernel_supported\"\n\n\ndef _is_sglang_kt_kernel_cache_valid() -> bool:\n    \"\"\"Check if the sglang kt-kernel support cache is valid.\n\n    The cache is considered valid if:\n    1. The cache file exists\n    2. 
The cache file contains 'true' (indicating previous check passed)\n\n    Returns:\n        True if cache is valid and indicates support, False otherwise.\n    \"\"\"\n    cache_path = _get_sglang_kt_kernel_cache_path()\n    if cache_path.exists():\n        try:\n            content = cache_path.read_text().strip().lower()\n            return content == \"true\"\n        except (OSError, IOError):\n            pass\n    return False\n\n\ndef _save_sglang_kt_kernel_cache(supported: bool) -> None:\n    \"\"\"Save the sglang kt-kernel support check result to cache.\"\"\"\n    cache_path = _get_sglang_kt_kernel_cache_path()\n    try:\n        cache_path.write_text(\"true\" if supported else \"false\")\n    except (OSError, IOError):\n        pass  # Ignore cache write errors\n\n\ndef clear_sglang_kt_kernel_cache() -> None:\n    \"\"\"Clear the sglang kt-kernel support cache, forcing a re-check on next run.\"\"\"\n    cache_path = _get_sglang_kt_kernel_cache_path()\n    try:\n        if cache_path.exists():\n            cache_path.unlink()\n    except (OSError, IOError):\n        pass\n\n\ndef check_sglang_kt_kernel_support(use_cache: bool = True, silent: bool = False) -> dict:\n    \"\"\"Check if SGLang supports kt-kernel parameters (--kt-gpu-prefill-token-threshold).\n\n    This function runs `python -m sglang.launch_server --help` and checks if the\n    output contains the `--kt-gpu-prefill-token-threshold` parameter. This parameter\n    is only available in the kvcache-ai/sglang fork, not in the official sglang.\n\n    The result is cached after the first successful check to avoid repeated checks.\n\n    Args:\n        use_cache: If True, use cached result if available. Default is True.\n        silent: If True, don't print checking message. 
Default is False.\n\n    Returns:\n        dict with keys:\n        - supported: bool - True if kt-kernel parameters are supported\n        - help_output: str or None - The help output from sglang.launch_server\n        - error: str or None - Error message if check failed\n        - from_cache: bool - True if result was from cache\n    \"\"\"\n    from kt_kernel.cli.utils.console import print_step\n\n    # Check cache first\n    if use_cache and _is_sglang_kt_kernel_cache_valid():\n        return {\n            \"supported\": True,\n            \"help_output\": None,\n            \"error\": None,\n            \"from_cache\": True,\n        }\n\n    # Print checking message\n    if not silent:\n        print_step(t(\"sglang_checking_kt_kernel_support\"))\n\n    try:\n        result = subprocess.run(\n            [sys.executable, \"-m\", \"sglang.launch_server\", \"--help\"],\n            capture_output=True,\n            text=True,\n            timeout=90,  # Increased for slow CUDA init and module loading in some environments\n        )\n\n        help_output = result.stdout + result.stderr\n\n        # Check if --kt-gpu-prefill-token-threshold is in the help output\n        supported = \"--kt-gpu-prefill-token-threshold\" in help_output\n\n        # Save to cache if supported\n        if supported:\n            _save_sglang_kt_kernel_cache(True)\n\n        return {\n            \"supported\": supported,\n            \"help_output\": help_output,\n            \"error\": None,\n            \"from_cache\": False,\n        }\n\n    except subprocess.TimeoutExpired:\n        return {\n            \"supported\": False,\n            \"help_output\": None,\n            \"error\": \"Timeout while checking sglang.launch_server --help\",\n            \"from_cache\": False,\n        }\n    except FileNotFoundError:\n        return {\n            \"supported\": False,\n            \"help_output\": None,\n            \"error\": \"Python interpreter not found\",\n            
\"from_cache\": False,\n        }\n    except Exception as e:\n        return {\n            \"supported\": False,\n            \"help_output\": None,\n            \"error\": str(e),\n            \"from_cache\": False,\n        }\n\n\ndef print_sglang_kt_kernel_instructions() -> None:\n    \"\"\"Print instructions for installing the kvcache-ai fork of SGLang with kt-kernel support.\"\"\"\n    from kt_kernel.cli.i18n import get_lang\n\n    lang = get_lang()\n\n    if lang == \"zh\":\n        instructions = \"\"\"\n[bold red]SGLang 不支持 kt-kernel[/bold red]\n\n您当前安装的 SGLang 不包含 kt-kernel 支持。\nkt-kernel 需要使用 kvcache-ai 维护的 SGLang 分支。\n\n[bold]请按以下步骤重新安装:[/bold]\n\n[cyan]1. 卸载当前的 SGLang:[/cyan]\n   pip uninstall sglang -y\n\n[cyan]2. 安装 kvcache-ai 版本 (选择一种方式):[/cyan]\n\n   [bold]方式 A - 一键安装 (推荐):[/bold]\n   从 ktransformers 根目录运行: ./install.sh\n\n   [bold]方式 B - pip 安装:[/bold]\n   pip install sglang-kt\n\n[dim]注意: 请确保在正确的 Python 环境中执行以上命令[/dim]\n\"\"\"\n    else:\n        instructions = \"\"\"\n[bold red]SGLang does not support kt-kernel[/bold red]\n\nYour current SGLang installation does not include kt-kernel support.\nkt-kernel requires the kvcache-ai maintained fork of SGLang.\n\n[bold]Please reinstall SGLang:[/bold]\n\n[cyan]1. Uninstall current SGLang:[/cyan]\n   pip uninstall sglang -y\n\n[cyan]2. Install the kvcache-ai fork (choose one):[/cyan]\n\n   [bold]Option A - One-click install (recommended):[/bold]\n   From the ktransformers root directory, run: ./install.sh\n\n   [bold]Option B - pip install:[/bold]\n   pip install sglang-kt\n\n[dim]Note: Make sure to run these commands in the correct Python environment[/dim]\n\"\"\"\n    console.print(instructions)\n"
  },
  {
    "path": "kt-kernel/python/cli/utils/tuna_engine.py",
    "content": "\"\"\"\nTuna engine for auto-tuning GPU experts configuration.\n\nAutomatically finds the maximum viable num-gpu-experts through binary search\nby testing actual server launches with different configurations.\n\"\"\"\n\nimport json\nimport math\nimport random\nimport subprocess\nimport sys\nimport time\nfrom pathlib import Path\nfrom typing import Optional\n\nfrom kt_kernel.cli.utils.console import console, print_error, print_info, print_warning\n\n\ndef get_num_experts(model_path: Path) -> int:\n    \"\"\"\n    Get the number of experts per layer from model config.\n\n    Args:\n        model_path: Path to the model directory\n\n    Returns:\n        Number of experts per layer\n\n    Raises:\n        ValueError: If config.json not found or num_experts field missing\n    \"\"\"\n    config_file = model_path / \"config.json\"\n\n    if not config_file.exists():\n        raise ValueError(f\"config.json not found in {model_path}\")\n\n    try:\n        config = json.loads(config_file.read_text())\n    except Exception as e:\n        raise ValueError(f\"Failed to parse config.json: {e}\")\n\n    # Different models may use different field names\n    possible_keys = [\n        \"num_experts_per_tok\",  # DeepSeek\n        \"num_local_experts\",  # Mixtral\n        \"n_routed_experts\",  # Qwen\n        \"num_experts\",  # Generic\n    ]\n\n    for key in possible_keys:\n        if key in config:\n            return config[key]\n\n    raise ValueError(f\"Cannot find num_experts field in {config_file}. 
\" f\"Tried: {', '.join(possible_keys)}\")\n\n\ndef detect_oom(log_line: Optional[str]) -> bool:\n    \"\"\"\n    Detect OOM (Out Of Memory) errors from log output.\n\n    Args:\n        log_line: A line from server output\n\n    Returns:\n        True if OOM detected, False otherwise\n    \"\"\"\n    if log_line is None:\n        return False\n\n    log_lower = log_line.lower()\n\n    oom_patterns = [\n        \"cuda out of memory\",\n        \"out of memory\",\n        \"outofmemoryerror\",\n        \"oom\",\n        \"failed to allocate\",\n        \"cumemalloc failed\",\n        \"cumemallocasync failed\",\n        \"allocation failed\",\n    ]\n\n    return any(pattern in log_lower for pattern in oom_patterns)\n\n\ndef test_config(\n    num_gpu_experts: int,\n    model_path: Path,\n    config: dict,\n    verbose: bool = False,\n) -> tuple[bool, float]:\n    \"\"\"\n    Test if a configuration with given num_gpu_experts works.\n\n    Args:\n        num_gpu_experts: Number of GPU experts to test\n        model_path: Path to the model\n        config: Configuration dict with all parameters\n        verbose: Whether to show detailed logs\n\n    Returns:\n        (success: bool, elapsed_time: float)\n        - success: True if server starts and inference works\n        - elapsed_time: Time taken for the test\n    \"\"\"\n    start_time = time.time()\n\n    # Use random port to avoid conflicts\n    test_port = random.randint(30000, 40000)\n\n    # Build command\n    cmd = [\n        sys.executable,\n        \"-m\",\n        \"sglang.launch_server\",\n        \"--model\",\n        str(model_path),\n        \"--port\",\n        str(test_port),\n        \"--host\",\n        \"127.0.0.1\",\n        \"--tensor-parallel-size\",\n        str(config[\"tensor_parallel_size\"]),\n        \"--kt-num-gpu-experts\",\n        str(num_gpu_experts),\n        \"--max-total-tokens\",\n        str(config[\"max_total_tokens\"]),\n    ]\n\n    # Add kt-kernel options\n    if 
config.get(\"weights_path\"):\n        cmd.extend([\"--kt-weight-path\", str(config[\"weights_path\"])])\n    else:\n        cmd.extend([\"--kt-weight-path\", str(model_path)])\n\n    cmd.extend(\n        [\n            \"--kt-cpuinfer\",\n            str(config.get(\"cpu_threads\", 64)),\n            \"--kt-threadpool-count\",\n            str(config.get(\"numa_nodes\", 2)),\n            \"--kt-method\",\n            config.get(\"kt_method\", \"AMXINT4\"),\n            \"--kt-gpu-prefill-token-threshold\",\n            str(config.get(\"kt_gpu_prefill_threshold\", 4096)),\n        ]\n    )\n\n    # Add other SGLang options\n    if config.get(\"attention_backend\"):\n        cmd.extend([\"--attention-backend\", config[\"attention_backend\"]])\n\n    cmd.extend(\n        [\n            \"--trust-remote-code\",\n            \"--mem-fraction-static\",\n            str(config.get(\"mem_fraction_static\", 0.98)),\n            \"--chunked-prefill-size\",\n            str(config.get(\"chunked_prefill_size\", 4096)),\n            \"--max-running-requests\",\n            str(config.get(\"max_running_requests\", 1)),  # Use 1 for faster testing\n            \"--watchdog-timeout\",\n            str(config.get(\"watchdog_timeout\", 3000)),\n            \"--enable-mixed-chunk\",\n            \"--enable-p2p-check\",\n        ]\n    )\n\n    # Add disable-shared-experts-fusion if specified\n    if config.get(\"disable_shared_experts_fusion\"):\n        cmd.append(\"--disable-shared-experts-fusion\")\n\n    # Add extra args\n    if config.get(\"extra_args\"):\n        cmd.extend(config[\"extra_args\"])\n\n    if verbose:\n        console.print(f\"[dim]Command: {' '.join(cmd)}[/dim]\")\n\n    # Start process\n    try:\n        process = subprocess.Popen(\n            cmd,\n            stdout=subprocess.PIPE,\n            stderr=subprocess.STDOUT,\n            text=True,\n            bufsize=1,\n            env=config.get(\"env\"),\n        )\n    except Exception as e:\n        if 
verbose:\n            print_error(f\"Failed to start process: {e}\")\n        return False, time.time() - start_time\n\n    # Monitor process output\n    timeout = 60  # Maximum 60 seconds to wait\n    server_ready = False\n\n    try:\n        while time.time() - start_time < timeout:\n            # Check whether the process has already exited\n            if process.poll() is not None:\n                # Process exited\n                if verbose:\n                    print_warning(\"Process exited early\")\n                return False, time.time() - start_time\n\n            # Read one output line. readline() blocks until a line or EOF arrives,\n            # so the timeout above is only re-checked between lines.\n            try:\n                line = process.stdout.readline()\n                if not line:\n                    time.sleep(0.1)\n                    continue\n\n                if verbose:\n                    console.print(f\"[dim]{line.rstrip()}[/dim]\")\n\n                # Fast OOM detection\n                if detect_oom(line):\n                    if verbose:\n                        print_warning(f\"OOM detected: {line.rstrip()}\")\n                    process.terminate()\n                    try:\n                        process.wait(timeout=2)\n                    except subprocess.TimeoutExpired:\n                        process.kill()\n                    return False, time.time() - start_time\n\n                # Check for startup success\n                if \"Uvicorn running\" in line or \"Application startup complete\" in line:\n                    server_ready = True\n                    break\n\n            except Exception as e:\n                if verbose:\n                    print_warning(f\"Error reading output: {e}\")\n                break\n\n        if not server_ready:\n            # Timeout or failed to start\n            process.terminate()\n            try:\n                process.wait(timeout=2)\n            except subprocess.TimeoutExpired:\n                process.kill()\n            return False, time.time() - start_time\n\n        # 
Server is ready, test inference\n        success = test_inference(test_port, verbose=verbose)\n\n        # Cleanup\n        process.terminate()\n        try:\n            process.wait(timeout=5)\n        except subprocess.TimeoutExpired:\n            process.kill()\n            process.wait(timeout=2)\n\n        return success, time.time() - start_time\n\n    except KeyboardInterrupt:\n        # User cancelled\n        process.terminate()\n        try:\n            process.wait(timeout=2)\n        except subprocess.TimeoutExpired:\n            process.kill()\n        raise\n    except Exception as e:\n        if verbose:\n            print_error(f\"Test failed with exception: {e}\")\n        try:\n            process.terminate()\n            process.wait(timeout=2)\n        except:\n            try:\n                process.kill()\n            except:\n                pass\n        return False, time.time() - start_time\n\n\ndef test_inference(port: int, verbose: bool = False) -> bool:\n    \"\"\"\n    Test if the server can handle a simple inference request.\n\n    Args:\n        port: Server port\n        verbose: Whether to show detailed logs\n\n    Returns:\n        True if inference succeeds, False otherwise\n    \"\"\"\n    try:\n        # Wait a bit for server to be fully ready\n        time.sleep(2)\n\n        # Try to import OpenAI client\n        try:\n            from openai import OpenAI\n        except ImportError:\n            if verbose:\n                print_warning(\"OpenAI package not available, skipping inference test\")\n            return True  # Assume success if we can't test\n\n        client = OpenAI(\n            base_url=f\"http://127.0.0.1:{port}/v1\",\n            api_key=\"test\",\n        )\n\n        # Send a simple test request\n        response = client.chat.completions.create(\n            model=\"test\",\n            messages=[{\"role\": \"user\", \"content\": \"Hi\"}],\n            max_tokens=1,\n            temperature=0,\n    
        timeout=10,\n        )\n\n        # Check if we got a valid response (coerced to bool to match the return annotation)\n        success = bool(response.choices) and response.choices[0].message.content is not None\n\n        if verbose:\n            if success:\n                print_info(f\"Inference test passed: {response.choices[0].message.content}\")\n            else:\n                print_warning(\"Inference test failed: no valid response\")\n\n        return success\n\n    except Exception as e:\n        if verbose:\n            print_warning(f\"Inference test failed: {e}\")\n        return False\n\n\ndef find_max_gpu_experts(\n    model_path: Path,\n    config: dict,\n    verbose: bool = False,\n) -> int:\n    \"\"\"\n    Binary search to find the maximum viable num_gpu_experts.\n\n    Args:\n        model_path: Path to the model\n        config: Configuration dict\n        verbose: Whether to show detailed logs\n\n    Returns:\n        Maximum number of GPU experts that works\n    \"\"\"\n    # Get number of experts from model config\n    try:\n        num_experts = get_num_experts(model_path)\n    except ValueError as e:\n        print_error(str(e))\n        raise\n\n    console.print()\n    console.print(f\"Binary search range: [0, {num_experts}]\")\n    console.print()\n\n    left, right = 0, num_experts\n    result = 0\n    iteration = 0\n    total_iterations = math.ceil(math.log2(num_experts + 1))\n\n    while left <= right:\n        iteration += 1\n        mid = (left + right) // 2\n\n        console.print(f\"[{iteration}/{total_iterations}] Testing gpu-experts={mid}... 
\", end=\"\")\n\n        success, elapsed = test_config(mid, model_path, config, verbose=verbose)\n\n        if success:\n            console.print(f\"[green]✓ OK[/green] ({elapsed:.1f}s)\")\n            result = mid\n            left = mid + 1\n        else:\n            console.print(f\"[red]✗ FAILED[/red] ({elapsed:.1f}s)\")\n            right = mid - 1\n\n    return result\n\n\ndef run_tuna(\n    model_path: Path,\n    tensor_parallel_size: int,\n    max_total_tokens: int,\n    kt_method: str,\n    verbose: bool = False,\n    **kwargs,\n) -> int:\n    \"\"\"\n    Run tuna auto-tuning to find optimal num_gpu_experts.\n\n    Args:\n        model_path: Path to the model\n        tensor_parallel_size: Tensor parallel size\n        max_total_tokens: Maximum total tokens\n        kt_method: KT quantization method\n        verbose: Whether to show detailed logs\n        **kwargs: Additional configuration parameters\n\n    Returns:\n        Optimal num_gpu_experts value\n\n    Raises:\n        ValueError: If tuning fails completely\n    \"\"\"\n    # Prepare configuration\n    config = {\n        \"tensor_parallel_size\": tensor_parallel_size,\n        \"max_total_tokens\": max_total_tokens,\n        \"kt_method\": kt_method,\n        **kwargs,\n    }\n\n    # Run binary search\n    try:\n        result = find_max_gpu_experts(model_path, config, verbose=verbose)\n    except KeyboardInterrupt:\n        console.print()\n        print_warning(\"Tuning cancelled by user\")\n        raise\n\n    console.print()\n\n    # Check if even 0 doesn't work\n    if result == 0:\n        console.print(\"[yellow]Testing if gpu-experts=0 is viable...[/yellow]\")\n        success, _ = test_config(0, model_path, config, verbose=verbose)\n\n        if not success:\n            # Even 0 doesn't work\n            console.print()\n            print_error(\"Failed to start server even with all experts on CPU (gpu-experts=0)\")\n            console.print()\n            
console.print(\"[bold]Possible reasons:[/bold]\")\n            console.print(\"  • Insufficient GPU memory for base model layers\")\n            console.print(\"  • max-total-tokens is too large for available VRAM\")\n            console.print(\"  • Tensor parallel configuration issue\")\n            console.print()\n            console.print(\"[bold]Suggestions:[/bold]\")\n            console.print(f\"  • Reduce --max-total-tokens (current: {max_total_tokens})\")\n            console.print(f\"  • Reduce --tensor-parallel-size (current: {tensor_parallel_size})\")\n            console.print(\"  • Use more GPUs or GPUs with more VRAM\")\n            console.print(\"  • Try a smaller model\")\n            console.print()\n            raise ValueError(\"Minimum GPU memory requirements not met\")\n        else:\n            # 0 works but nothing more\n            console.print()\n            print_warning(\"All experts will run on CPU (gpu-experts=0). \" \"Performance will be limited by CPU speed.\")\n\n    return result\n"
  },
  {
    "path": "kt-kernel/python/cli/utils/user_model_registry.py",
    "content": "\"\"\"\nUser Model Registry\n\nManages user-registered models in ~/.ktransformers/user_models.yaml\n\"\"\"\n\nfrom dataclasses import dataclass, asdict, field\nfrom datetime import datetime\nfrom pathlib import Path\nfrom typing import Optional, List, Dict, Any\nimport yaml\n\n\n# Constants\nUSER_MODELS_FILE = Path.home() / \".ktransformers\" / \"user_models.yaml\"\nREGISTRY_VERSION = \"1.0\"\n\n\n@dataclass\nclass UserModel:\n    \"\"\"Represents a user-registered model\"\"\"\n\n    name: str  # User-editable name (default: folder name)\n    path: str  # Absolute path to model directory\n    format: str  # \"safetensors\" | \"gguf\"\n    id: Optional[str] = None  # Unique UUID for this model (auto-generated if None)\n    repo_type: Optional[str] = None  # \"huggingface\" | \"modelscope\" | None\n    repo_id: Optional[str] = None  # e.g., \"deepseek-ai/DeepSeek-V3\"\n    sha256_status: str = \"not_checked\"  # \"not_checked\" | \"checking\" | \"passed\" | \"failed\" | \"no_repo\"\n    gpu_model_ids: Optional[List[str]] = None  # For llamafile/AMX: list of GPU model UUIDs to run with\n    created_at: str = field(default_factory=lambda: datetime.now().isoformat())\n    last_verified: Optional[str] = None  # ISO format datetime\n    # MoE information (cached from analyze_moe_model)\n    is_moe: Optional[bool] = None  # True if MoE model, False if non-MoE, None if not analyzed\n    moe_num_experts: Optional[int] = None  # Total number of experts (for MoE models)\n    moe_num_experts_per_tok: Optional[int] = None  # Number of active experts per token (for MoE models)\n    # AMX quantization metadata (for format == \"amx\")\n    amx_source_model: Optional[str] = None  # Name of the source MoE model that was quantized\n    amx_quant_method: Optional[str] = None  # \"int4\" | \"int8\"\n    amx_numa_nodes: Optional[int] = None  # Number of NUMA nodes used for quantization\n\n    def __post_init__(self):\n        \"\"\"Ensure ID is set after 
initialization\"\"\"\n        if self.id is None:\n            import uuid\n\n            self.id = str(uuid.uuid4())\n\n    def to_dict(self) -> Dict[str, Any]:\n        \"\"\"Convert to dictionary for YAML serialization\"\"\"\n        return asdict(self)\n\n    @classmethod\n    def from_dict(cls, data: Dict[str, Any]) -> \"UserModel\":\n        \"\"\"Create from dictionary loaded from YAML\"\"\"\n        return cls(**data)\n\n    def path_exists(self) -> bool:\n        \"\"\"Check if model path still exists\"\"\"\n        return Path(self.path).exists()\n\n\nclass UserModelRegistry:\n    \"\"\"Manages the user model registry\"\"\"\n\n    def __init__(self, registry_file: Optional[Path] = None):\n        \"\"\"\n        Initialize the registry\n\n        Args:\n            registry_file: Path to the registry YAML file (default: USER_MODELS_FILE)\n        \"\"\"\n        self.registry_file = registry_file or USER_MODELS_FILE\n        self.models: List[UserModel] = []\n        self.version = REGISTRY_VERSION\n\n        # Ensure directory exists\n        self.registry_file.parent.mkdir(parents=True, exist_ok=True)\n\n        # Load existing registry\n        self.load()\n\n    def load(self) -> None:\n        \"\"\"Load models from YAML file\"\"\"\n        if not self.registry_file.exists():\n            # Initialize empty registry\n            self.models = []\n            self.save()  # Create the file\n            return\n\n        try:\n            with open(self.registry_file, \"r\", encoding=\"utf-8\") as f:\n                data = yaml.safe_load(f)\n\n            if not data:\n                self.models = []\n                return\n\n            # Load version\n            self.version = data.get(\"version\", REGISTRY_VERSION)\n\n            # Load models\n            models_data = data.get(\"models\", [])\n            self.models = [UserModel.from_dict(m) for m in models_data]\n\n            # Migrate: ensure all models have UUIDs (for backward 
compatibility)\n            needs_save = False\n            for model in self.models:\n                if model.id is None:\n                    import uuid\n\n                    model.id = str(uuid.uuid4())\n                    needs_save = True\n\n            if needs_save:\n                self.save()\n\n        except Exception as e:\n            raise RuntimeError(f\"Failed to load user model registry: {e}\")\n\n    def save(self) -> None:\n        \"\"\"Save models to YAML file\"\"\"\n        data = {\"version\": self.version, \"models\": [m.to_dict() for m in self.models]}\n\n        try:\n            with open(self.registry_file, \"w\", encoding=\"utf-8\") as f:\n                yaml.safe_dump(data, f, default_flow_style=False, allow_unicode=True, sort_keys=False)\n        except Exception as e:\n            raise RuntimeError(f\"Failed to save user model registry: {e}\")\n\n    def add_model(self, model: UserModel) -> None:\n        \"\"\"\n        Add a model to the registry\n\n        Args:\n            model: UserModel instance to add\n\n        Raises:\n            ValueError: If a model with the same name already exists\n        \"\"\"\n        if self.check_name_conflict(model.name):\n            raise ValueError(f\"Model with name '{model.name}' already exists\")\n\n        self.models.append(model)\n        self.save()\n\n    def remove_model(self, name: str) -> bool:\n        \"\"\"\n        Remove a model from the registry\n\n        Args:\n            name: Name of the model to remove\n\n        Returns:\n            True if model was removed, False if not found\n        \"\"\"\n        original_count = len(self.models)\n        self.models = [m for m in self.models if m.name != name]\n\n        if len(self.models) < original_count:\n            self.save()\n            return True\n        return False\n\n    def update_model(self, name: str, updates: Dict[str, Any]) -> bool:\n        \"\"\"\n        Update a model's attributes\n\n        
Args:\n            name: Name of the model to update\n            updates: Dictionary of attributes to update\n\n        Returns:\n            True if model was updated, False if not found\n        \"\"\"\n        model = self.get_model(name)\n        if not model:\n            return False\n\n        # Update attributes\n        for key, value in updates.items():\n            if hasattr(model, key):\n                setattr(model, key, value)\n\n        self.save()\n        return True\n\n    def get_model(self, name: str) -> Optional[UserModel]:\n        \"\"\"\n        Get a model by name\n\n        Args:\n            name: Name of the model\n\n        Returns:\n            UserModel instance or None if not found\n        \"\"\"\n        for model in self.models:\n            if model.name == name:\n                return model\n        return None\n\n    def get_model_by_id(self, model_id: str) -> Optional[UserModel]:\n        \"\"\"\n        Get a model by its unique ID\n\n        Args:\n            model_id: UUID of the model\n\n        Returns:\n            UserModel instance or None if not found\n        \"\"\"\n        for model in self.models:\n            if model.id == model_id:\n                return model\n        return None\n\n    def list_models(self) -> List[UserModel]:\n        \"\"\"\n        List all models\n\n        Returns:\n            List of all UserModel instances\n        \"\"\"\n        return self.models.copy()\n\n    def find_by_path(self, path: str) -> Optional[UserModel]:\n        \"\"\"\n        Find a model by its path\n\n        Args:\n            path: Model directory path\n\n        Returns:\n            UserModel instance or None if not found\n        \"\"\"\n        # Normalize paths for comparison\n        search_path = str(Path(path).resolve())\n\n        for model in self.models:\n            model_path = str(Path(model.path).resolve())\n            if model_path == search_path:\n                return model\n        
return None\n\n    def check_name_conflict(self, name: str, exclude_name: Optional[str] = None) -> bool:\n        \"\"\"\n        Check if a name conflicts with existing models\n\n        Args:\n            name: Name to check\n            exclude_name: Optional name to exclude from check (for rename operations)\n\n        Returns:\n            True if conflict exists, False otherwise\n        \"\"\"\n        for model in self.models:\n            if model.name == name and model.name != exclude_name:\n                return True\n        return False\n\n    def refresh_status(self) -> Dict[str, List[str]]:\n        \"\"\"\n        Check all models and identify missing ones\n\n        Returns:\n            Dictionary with 'valid' and 'missing' lists of model names\n        \"\"\"\n        valid = []\n        missing = []\n\n        for model in self.models:\n            if model.path_exists():\n                valid.append(model.name)\n            else:\n                missing.append(model.name)\n\n        return {\"valid\": valid, \"missing\": missing}\n\n    def get_model_count(self) -> int:\n        \"\"\"Get total number of registered models\"\"\"\n        return len(self.models)\n\n    def suggest_name(self, base_name: str) -> str:\n        \"\"\"\n        Suggest a unique name based on base_name\n\n        Args:\n            base_name: Base name to derive from\n\n        Returns:\n            A unique name (may have suffix like -2, -3 etc.)\n        \"\"\"\n        if not self.check_name_conflict(base_name):\n            return base_name\n\n        counter = 2\n        while True:\n            candidate = f\"{base_name}-{counter}\"\n            if not self.check_name_conflict(candidate):\n                return candidate\n            counter += 1\n"
  },
  {
    "path": "kt-kernel/python/experts.py",
    "content": "# Wrapper for MoE CPU inference operations\n# This module encapsulates CPU inference engine, weight loading, and buffer management\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nExpert wrappers for CPU-based MoE inference.\n\nThis module provides the main factory interface (KTMoEWrapper) that automatically\nselects the appropriate backend implementation based on the method parameter.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport torch\nfrom typing import List, Optional\n\n# Import base infrastructure\nfrom .experts_base import BaseMoEWrapper, KExpertsCPUBuffer\n\n# Import backend implementations\nfrom .utils.amx import AMXMoEWrapper, NativeMoEWrapper\nfrom .utils.llamafile import LlamafileMoEWrapper\nfrom .utils.moe_kernel import GeneralMoEWrapper\n\n\nclass KTMoEWrapper:\n    \"\"\"\n    Factory interface for MoE CPU inference operations.\n\n    This class serves as the main entry point for external code. It automatically\n    selects the appropriate backend implementation based on the `method` parameter.\n\n    Usage:\n        # Create a mask where experts 0, 2, 5 are on GPU\n        gpu_mask = torch.zeros(8, dtype=torch.bool)\n        gpu_mask[[0, 2, 5]] = True\n\n        wrapper = KTMoEWrapper(\n            layer_idx=0,\n            num_experts=8,\n            num_experts_per_tok=2,\n            hidden_size=4096,\n            moe_intermediate_size=14336,\n            gpu_experts_mask=gpu_mask,  # or None for all experts on CPU\n            cpuinfer_threads=32,\n            threadpool_count=2,\n            weight_path=\"/path/to/weights\",\n            chunked_prefill_size=512,\n            method=\"AMXINT4\"  # or \"AMXINT8\", \"LLAMAFILE\"\n        )\n    \"\"\"\n\n    def __new__(\n        cls,\n        layer_idx: int,\n        num_experts: int,\n        num_experts_per_tok: int,\n        hidden_size: int,\n        moe_intermediate_size: int,\n        gpu_experts_mask: Optional[torch.Tensor],\n        cpuinfer_threads: int,\n      
  threadpool_count: int,\n        weight_path: str,\n        chunked_prefill_size: int,\n        cpu_save: bool = False,\n        max_deferred_experts_per_token: Optional[int] = None,\n        method: str = \"AMXINT4\",\n    ):\n        \"\"\"\n        Factory method to create the appropriate backend implementation.\n\n        Args:\n            layer_idx: Layer index\n            num_experts: Total number of experts\n            num_experts_per_tok: Number of experts per token (top-k)\n            hidden_size: Hidden dimension size\n            moe_intermediate_size: MoE intermediate size\n            gpu_experts_mask: Boolean mask indicating which experts are on GPU.\n                              Shape: [num_experts], dtype: torch.bool.\n                              mask[i] = True means expert i is on GPU.\n                              If None, all experts are on CPU.\n            cpuinfer_threads: Number of CPU inference threads\n            threadpool_count: Number of NUMA subpools\n            weight_path: Path to weights\n            chunked_prefill_size: Maximum prefill chunk size\n            cpu_save: Whether to save weights to CPU memory\n            max_deferred_experts_per_token: Number of experts per token to defer. 
Defaults to 0.\n            method: Backend method (\"AMXINT4\", \"AMXINT8\", \"RAWINT4\", \"FP8\", \"BF16\", \"FP8_PERCHANNEL\", \"LLAMAFILE\", \"MOE_INT4\", \"MOE_INT8\")\n\n        Returns:\n            An instance of the appropriate backend implementation (e.g., AMXMoEWrapper)\n        \"\"\"\n        # Select backend based on method\n        if method in [\"AMXINT4\", \"AMXINT8\"]:\n            backend_cls = AMXMoEWrapper\n        elif method in [\"RAWINT4\", \"FP8\", \"BF16\", \"FP8_PERCHANNEL\"]:\n            backend_cls = NativeMoEWrapper\n        elif method == \"LLAMAFILE\":\n            backend_cls = LlamafileMoEWrapper\n        elif method in [\"MOE_INT4\", \"MOE_INT8\"]:\n            backend_cls = GeneralMoEWrapper\n        else:\n            raise NotImplementedError(f\"Unsupported method: {method}\")\n\n        # Create and return backend instance\n        return backend_cls(\n            layer_idx=layer_idx,\n            num_experts=num_experts,\n            num_experts_per_tok=num_experts_per_tok,\n            hidden_size=hidden_size,\n            moe_intermediate_size=moe_intermediate_size,\n            gpu_experts_mask=gpu_experts_mask,\n            cpuinfer_threads=cpuinfer_threads,\n            threadpool_count=threadpool_count,\n            weight_path=weight_path,\n            chunked_prefill_size=chunked_prefill_size,\n            cpu_save=cpu_save,\n            max_deferred_experts_per_token=max_deferred_experts_per_token,\n            method=method,\n        )\n\n    # Forward static methods to the base class\n    @staticmethod\n    def set_capture_batch_sizes(capture_bs: List[int]):\n        \"\"\"\n        Set batch sizes to capture and cache buffers for.\n\n        This allows pre-allocation of CPU buffers for specific batch sizes,\n        improving performance by avoiding buffer re-allocation during inference.\n\n        Args:\n            capture_bs: List of batch sizes to capture (e.g., [1, 2, 4, 8, 16])\n        \"\"\"\n        
BaseMoEWrapper.set_capture_batch_sizes(capture_bs)\n\n    @staticmethod\n    def get_capture_batch_sizes() -> List[int]:\n        \"\"\"\n        Get currently configured capture batch sizes.\n\n        Returns:\n            List of batch sizes that are being captured\n        \"\"\"\n        return BaseMoEWrapper.get_capture_batch_sizes()\n\n    @staticmethod\n    def clear_buffer_cache():\n        \"\"\"\n        Clear all cached buffers.\n\n        This frees up memory by clearing the buffer cache. Useful when you want\n        to reset the buffer state or free memory.\n        \"\"\"\n        BaseMoEWrapper.clear_buffer_cache()\n"
  },
  {
    "path": "kt-kernel/python/experts_base.py",
    "content": "# Base classes for MoE CPU inference operations\n# SPDX-License-Identifier: Apache-2.0\n\n\"\"\"\nBase infrastructure for CPU-based MoE inference.\n\nThis module contains base classes and utilities shared across all backend implementations.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport torch\nfrom typing import Dict, List, Optional, Tuple\nfrom abc import ABC, abstractmethod\nimport os\nimport ctypes\n\nfrom kt_kernel import kt_kernel_ext\n\n\ndef generate_gpu_experts_masks(\n    activation_freq: torch.Tensor,\n    num_gpu_experts: int,\n) -> torch.Tensor:\n    \"\"\"\n    Generate GPU experts masks based on activation frequency.\n\n    Selects the top `num_gpu_experts` experts with highest activation frequency\n    across all layers to be placed on GPU.\n\n    Args:\n        activation_freq: Activation frequency table of shape (num_layers, num_experts).\n                         Higher values indicate more frequently activated experts.\n        num_gpu_experts: Total number of experts to place on GPU across all layers.\n\n    Returns:\n        gpu_experts_masks: Boolean mask of shape (num_layers, num_experts) on CPU.\n                           True means the expert should be on GPU.\n\n    Example:\n        >>> activation_freq = torch.tensor([\n        ...     [0.1, 0.5, 0.3, 0.8],  # layer 0\n        ...     [0.2, 0.4, 0.9, 0.1],  # layer 1\n        ... 
])\n        >>> masks = generate_gpu_experts_masks(activation_freq, num_gpu_experts=3)\n        >>> # Top 3: layer0-expert3 (0.8), layer1-expert2 (0.9), layer0-expert1 (0.5)\n        >>> masks\n        tensor([[False,  True, False,  True],\n                [False, False,  True, False]])\n    \"\"\"\n    num_layers, num_experts_per_layer = activation_freq.shape\n    total_experts = num_layers * num_experts_per_layer\n\n    # Clamp num_gpu_experts to valid range\n    num_gpu_experts = min(num_gpu_experts, total_experts)\n    num_gpu_experts = max(num_gpu_experts, 0)\n\n    if num_gpu_experts == 0:\n        return torch.zeros(num_layers, num_experts_per_layer, dtype=torch.bool, device=\"cpu\")\n\n    # Flatten and find top-k indices\n    flat_freq = activation_freq.view(-1).to(device=\"cpu\")\n    _, top_indices = torch.topk(flat_freq, k=num_gpu_experts, largest=True, sorted=False)\n\n    # Create mask\n    gpu_experts_masks = torch.zeros(total_experts, dtype=torch.bool, device=\"cpu\")\n    gpu_experts_masks[top_indices] = True\n\n    # Reshape to (num_layers, num_experts)\n    gpu_experts_masks = gpu_experts_masks.view(num_layers, num_experts_per_layer)\n\n    return gpu_experts_masks\n\n\nclass KExpertsCPUBuffer:\n    \"\"\"\n    CPU buffer management for expert computation.\n\n    Manages pinned memory buffers for efficient GPU-CPU data transfer.\n    \"\"\"\n\n    capture_bs: List = list()\n    capture_buffers: Dict = dict()\n    temp_bs: int = 0\n    temp_buffer: tuple = tuple()\n    buffer_depth: int = 2\n\n    @classmethod\n    def get_buffer(cls, hidden_states: torch.Tensor, num_experts_per_tok):\n        hidden_size = hidden_states.shape[-1]\n        batch_size = hidden_states.shape[0]\n\n        if batch_size in cls.capture_buffers:\n            return cls.capture_buffers[batch_size]\n        if batch_size == cls.temp_bs:\n            return cls.temp_buffer\n\n        input_tensor_cpu = [\n            torch.zeros((batch_size, hidden_size), device=\"cpu\", 
pin_memory=True, dtype=torch.bfloat16)\n            for _ in range(cls.buffer_depth)\n        ]\n        immediate_experts_ids_cpu = [\n            torch.zeros((batch_size, num_experts_per_tok), device=\"cpu\", dtype=torch.long, pin_memory=True)\n            for _ in range(cls.buffer_depth)\n        ]\n        deferred_experts_ids_cpu = [\n            torch.full((batch_size, num_experts_per_tok), -1, device=\"cpu\", dtype=torch.long, pin_memory=True)\n            for _ in range(cls.buffer_depth)\n        ]\n        weights_cpu = [\n            torch.zeros((batch_size, num_experts_per_tok), device=\"cpu\", dtype=torch.float32, pin_memory=True)\n            for _ in range(cls.buffer_depth)\n        ]\n        output_cpu = [\n            torch.zeros((batch_size, hidden_size), device=\"cpu\", pin_memory=True, dtype=torch.bfloat16)\n            for _ in range(cls.buffer_depth)\n        ]\n        bsz_tensor_cpu = [\n            torch.full((1,), batch_size, device=\"cpu\", dtype=torch.int32, pin_memory=True)\n            for _ in range(cls.buffer_depth)\n        ]\n        output_gpu = [\n            torch.zeros((batch_size, hidden_size), device=hidden_states.device, dtype=hidden_states.dtype)\n            for _ in range(cls.buffer_depth)\n        ]\n\n        cur_buffer = (\n            input_tensor_cpu,\n            immediate_experts_ids_cpu,\n            deferred_experts_ids_cpu,\n            weights_cpu,\n            output_cpu,\n            bsz_tensor_cpu,\n            output_gpu,\n        )\n        if batch_size in cls.capture_bs:\n            cls.capture_buffers[batch_size] = cur_buffer\n        cls.temp_bs = batch_size\n        cls.temp_buffer = cur_buffer\n        return cur_buffer\n\n\nclass BaseMoEWrapper(ABC):\n    \"\"\"\n    Base class for MoE CPU inference operations.\n    Provides common functionality for all backend implementations.\n    \"\"\"\n\n    _cpu_infer_instance = None\n    _layer_has_pending_deferred: Dict[int, bool] = {}\n\n    def 
__init__(\n        self,\n        layer_idx: int,\n        num_experts: int,\n        num_experts_per_tok: int,\n        hidden_size: int,\n        moe_intermediate_size: int,\n        gpu_experts_mask: Optional[torch.Tensor],\n        cpuinfer_threads: int,\n        threadpool_count: int,\n        weight_path: str,\n        chunked_prefill_size: int,\n        cpu_save: bool = False,\n        max_deferred_experts_per_token: Optional[int] = None,\n        method: str = \"AMXINT4\",\n    ):\n        \"\"\"\n        Initialize base MoE Wrapper.\n\n        Args:\n            layer_idx: Layer index\n            num_experts: Total number of experts\n            num_experts_per_tok: Number of experts per token (top-k)\n            hidden_size: Hidden dimension size\n            moe_intermediate_size: MoE intermediate size\n            gpu_experts_mask: Boolean mask indicating which experts are on GPU.\n                              Shape: [num_experts], dtype: torch.bool.\n                              mask[i] = True means expert i is on GPU.\n                              If None, all experts are on CPU.\n            cpuinfer_threads: Number of CPU inference threads\n            threadpool_count: Number of NUMA subpools\n            weight_path: Path to weights\n            chunked_prefill_size: Maximum prefill chunk size\n            cpu_save: Whether to save weights to CPU memory\n            max_deferred_experts_per_token: Number of experts per token to defer on this layer. 
Defaults to 0 (no defer).\n            method: Backend method string\n        \"\"\"\n        self.layer_idx = layer_idx\n        self.num_experts = num_experts\n        self.num_experts_per_tok = num_experts_per_tok\n        self.hidden_size = hidden_size\n        self.moe_intermediate_size = moe_intermediate_size\n\n        # Process gpu_experts_mask: convert to bool tensor on CPU, pinned memory for async copy\n        # This mask is shared between C and Python (C uses uint8_t*), both can read/write it\n        if gpu_experts_mask is None:\n            # No GPU experts - all experts on CPU\n            self.gpu_experts_mask = torch.zeros(num_experts, dtype=torch.bool, device=\"cpu\", pin_memory=True)\n        else:\n            # Create a new pinned tensor and copy data into it\n            self.gpu_experts_mask = torch.empty(num_experts, dtype=torch.bool, device=\"cpu\", pin_memory=True)\n            self.gpu_experts_mask.copy_(gpu_experts_mask)\n\n        self.num_gpu_experts = int(self.gpu_experts_mask.sum().item())\n\n        # GPU copy for mask operations in forward pass (e.g., mask_cpu_expert_ids)\n        # This will be lazily initialized when needed\n        self._gpu_experts_mask_gpu: Optional[torch.Tensor] = None\n        self.weight_path = weight_path\n        self.chunked_prefill_size = chunked_prefill_size\n        self.cpu_save = cpu_save\n        self.max_deferred_experts_per_token = (\n            int(max_deferred_experts_per_token) if max_deferred_experts_per_token is not None else 0\n        )\n\n        BaseMoEWrapper._layer_has_pending_deferred[self.layer_idx] = False\n        self.method = method\n\n        # Initialize CPU inference engine (singleton)\n        if BaseMoEWrapper._cpu_infer_instance is None:\n            worker_config = kt_kernel_ext.WorkerPoolConfig()\n\n            subpool_numa_map = list(range(threadpool_count))\n            subpool_thread_count = [\n                cpuinfer_threads // threadpool_count + (1 if i < 
cpuinfer_threads % threadpool_count else 0)\n                for i in range(threadpool_count)\n            ]\n\n            worker_config.subpool_count = threadpool_count\n            worker_config.subpool_numa_map = subpool_numa_map\n            worker_config.subpool_thread_count = subpool_thread_count\n            BaseMoEWrapper._cpu_infer_instance = kt_kernel_ext.CPUInfer(worker_config)\n\n        self.cpu_infer = BaseMoEWrapper._cpu_infer_instance\n\n        # Backend-specific initialization happens in subclasses\n        self.moe = None\n\n    @abstractmethod\n    def load_weights_from_tensors(\n        self,\n        gate_proj: torch.Tensor,\n        up_proj: torch.Tensor,\n        down_proj: torch.Tensor,\n        physical_to_logical_map_cpu: torch.Tensor,\n    ):\n        \"\"\"\n        Load and quantize weights from BF16/FP16 tensors (online quantization).\n\n        Args:\n            gate_proj: Gate projection weights [num_experts, intermediate_size, hidden_size]\n            up_proj: Up projection weights [num_experts, intermediate_size, hidden_size]\n            down_proj: Down projection weights [num_experts, hidden_size, intermediate_size]\n            physical_to_logical_map_cpu: Mapping from physical to logical expert IDs\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):\n        \"\"\"\n        Load weights for this layer and initialize the MoE module.\n\n        Args:\n            physical_to_logical_map_cpu: Mapping from physical to logical expert IDs\n        \"\"\"\n        pass\n\n    def select_deferred_experts(\n        self,\n        expert_ids: torch.Tensor,\n        expert_scores: torch.Tensor,\n        protected_k: int,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n        batch, topk = expert_ids.shape\n        device = expert_ids.device\n\n        protected_k = max(0, min(int(protected_k), topk))\n        if protected_k == 0:\n            
deferred_ids = expert_ids.clone()\n            immediate_ids = torch.full_like(expert_ids, -1)\n            return immediate_ids, deferred_ids\n\n        topk_result = torch.topk(expert_scores, k=protected_k, dim=-1, largest=True, sorted=False)\n        protected_indices = topk_result.indices\n        protected_ids = torch.gather(expert_ids, -1, protected_indices)\n\n        protected_flag = torch.zeros((self.num_experts,), dtype=torch.int32, device=device)\n        protected_flag.scatter_(0, protected_ids.reshape(-1), 1)\n\n        protected_mask_flat = torch.gather(protected_flag, 0, expert_ids.reshape(-1)).ne(0)\n        protected_mask = protected_mask_flat.view(batch, topk)\n\n        immediate_ids = expert_ids.clone().masked_fill(~protected_mask, -1)\n        deferred_ids = expert_ids.clone().masked_fill(protected_mask, -1)\n\n        return immediate_ids, deferred_ids\n\n    def submit_forward(\n        self,\n        hidden_states: torch.Tensor,\n        topk_ids: torch.Tensor,\n        topk_weights: torch.Tensor,\n        cuda_stream,\n    ):\n        \"\"\"\n        Submit forward inference task to CPU (non-blocking).\n\n        Args:\n            hidden_states: Input hidden states [batch_size, hidden_size]\n            topk_ids: Top-k expert IDs [batch_size, num_experts_per_tok]\n            topk_weights: Top-k expert weights [batch_size, num_experts_per_tok]\n            cuda_stream: CUDA stream for synchronization\n        \"\"\"\n        flat_hidden_states = hidden_states.view(-1, hidden_states.shape[-1])\n        batch_size = flat_hidden_states.shape[0]\n\n        (\n            input_tensor_cpu,\n            immediate_experts_ids_cpu,\n            deferred_experts_ids_cpu,\n            weights_cpu,\n            output_cpu,\n            bsz_tensor_cpu,\n            _output_gpu,\n        ) = KExpertsCPUBuffer.get_buffer(flat_hidden_states, self.num_experts_per_tok)\n\n        current_slot = self.layer_idx % KExpertsCPUBuffer.buffer_depth\n        
next_slot = (current_slot + 1) % KExpertsCPUBuffer.buffer_depth\n\n        bsz_slot_tensor = bsz_tensor_cpu[current_slot]\n\n        topk_ids_long = topk_ids.to(torch.long)\n        immediate_ids: torch.Tensor\n        deferred_ids: Optional[torch.Tensor]\n        if self.max_deferred_experts_per_token > 0:\n            protected_k = self.num_experts_per_tok - self.max_deferred_experts_per_token\n\n            immediate_ids, deferred_ids = self.select_deferred_experts(topk_ids_long, topk_weights, protected_k)\n        else:\n            immediate_ids = topk_ids_long\n            deferred_ids = None\n\n        input_tensor_cpu[current_slot].copy_(flat_hidden_states, non_blocking=True)\n        weights_cpu[current_slot].copy_(topk_weights, non_blocking=True)\n        immediate_experts_ids_cpu[current_slot].copy_(immediate_ids, non_blocking=True)\n\n        incremental = BaseMoEWrapper._layer_has_pending_deferred.get(self.layer_idx - 1, False)\n        self.cpu_infer.submit_with_cuda_stream(\n            cuda_stream,\n            self.moe.forward_task(\n                bsz_slot_tensor.data_ptr(),\n                immediate_experts_ids_cpu[current_slot].size(-1),\n                immediate_experts_ids_cpu[current_slot].data_ptr(),\n                weights_cpu[current_slot].data_ptr(),\n                input_tensor_cpu[current_slot].data_ptr(),\n                output_cpu[current_slot].data_ptr(),\n                incremental,\n            ),\n        )\n\n        BaseMoEWrapper._layer_has_pending_deferred[self.layer_idx] = False\n        if deferred_ids is not None:\n            deferred_experts_ids_cpu[current_slot].copy_(deferred_ids, non_blocking=True)\n            self.cpu_infer.submit_with_cuda_stream(\n                cuda_stream,\n                self.moe.forward_task(\n                    bsz_slot_tensor.data_ptr(),\n                    deferred_experts_ids_cpu[current_slot].size(-1),\n                    deferred_experts_ids_cpu[current_slot].data_ptr(),\n     
               weights_cpu[current_slot].data_ptr(),\n                    input_tensor_cpu[current_slot].data_ptr(),\n                    output_cpu[next_slot].data_ptr(),\n                    False,\n                ),\n            )\n            BaseMoEWrapper._layer_has_pending_deferred[self.layer_idx] = True\n\n    def sync_forward(self, hidden_states: torch.Tensor, cuda_stream) -> torch.Tensor:\n        \"\"\"\n        Synchronize and retrieve forward inference results.\n\n        Args:\n            hidden_states: Original input hidden states (for getting buffer)\n            cuda_stream: CUDA stream for synchronization\n\n        Returns:\n            output_gpu: Output tensor on GPU\n        \"\"\"\n        flat_hidden_states = hidden_states.view(-1, hidden_states.shape[-1])\n        (\n            _input_tensor_cpu,\n            _immediate_experts_ids_cpu,\n            _deferred_experts_ids_cpu,\n            _weights_cpu,\n            output_cpu,\n            _bsz_tensor_cpu,\n            output_gpu,\n        ) = KExpertsCPUBuffer.get_buffer(flat_hidden_states, self.num_experts_per_tok)\n\n        current_slot = self.layer_idx % KExpertsCPUBuffer.buffer_depth\n        allow_pending = 1 if BaseMoEWrapper._layer_has_pending_deferred.get(self.layer_idx, False) else 0\n        self.cpu_infer.sync_with_cuda_stream(cuda_stream, allow_pending)\n        output_gpu[current_slot].copy_(output_cpu[current_slot], non_blocking=True)\n        return output_gpu[current_slot]\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        topk_ids: torch.Tensor,\n        topk_weights: torch.Tensor,\n        cuda_stream,\n    ) -> torch.Tensor:\n        \"\"\"\n        Execute forward inference synchronously (submit + sync).\n\n        Args:\n            hidden_states: Input hidden states [batch_size, hidden_size]\n            topk_ids: Top-k expert IDs [batch_size, num_experts_per_tok]\n            topk_weights: Top-k expert weights [batch_size, 
num_experts_per_tok]\n            cuda_stream: CUDA stream for synchronization\n\n        Returns:\n            Output tensor on GPU\n        \"\"\"\n        self.submit_forward(hidden_states, topk_ids, topk_weights, cuda_stream)\n        return self.sync_forward(hidden_states, cuda_stream)\n\n    @staticmethod\n    def set_capture_batch_sizes(capture_bs: List[int]):\n        \"\"\"\n        Set batch sizes to capture and cache buffers for.\n\n        This allows pre-allocation of CPU buffers for specific batch sizes,\n        improving performance by avoiding buffer re-allocation during inference.\n\n        Args:\n            capture_bs: List of batch sizes to capture (e.g., [1, 2, 4, 8, 16])\n\n        Example:\n            >>> BaseMoEWrapper.set_capture_batch_sizes([1, 2, 4, 8, 16])\n        \"\"\"\n        KExpertsCPUBuffer.capture_bs = capture_bs\n\n    @staticmethod\n    def get_capture_batch_sizes() -> List[int]:\n        \"\"\"\n        Get currently configured capture batch sizes.\n\n        Returns:\n            List of batch sizes that are being captured\n        \"\"\"\n        return KExpertsCPUBuffer.capture_bs\n\n    @staticmethod\n    def clear_buffer_cache():\n        \"\"\"\n        Clear all cached buffers.\n\n        This frees up memory by clearing the buffer cache. Useful when you want\n        to reset the buffer state or free memory.\n        \"\"\"\n        KExpertsCPUBuffer.capture_buffers.clear()\n        KExpertsCPUBuffer.temp_bs = 0\n        KExpertsCPUBuffer.temp_buffer = tuple()\n"
  },
  {
    "path": "kt-kernel/python/utils/__init__.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n\"\"\"\nUtilities for kt_kernel package.\n\"\"\"\n\nfrom .amx import AMXMoEWrapper, NativeMoEWrapper\nfrom .llamafile import LlamafileMoEWrapper\nfrom .loader import SafeTensorLoader, GGUFLoader, CompressedSafeTensorLoader\n\n__all__ = [\n    \"AMXMoEWrapper\",\n    \"NativeMoEWrapper\",\n    \"LlamafileMoEWrapper\",\n    \"SafeTensorLoader\",\n    \"CompressedSafeTensorLoader\",\n    \"GGUFLoader\",\n]\n"
  },
  {
    "path": "kt-kernel/python/utils/amx.py",
    "content": "import os\nimport torch\nimport ctypes\nfrom typing import Optional\n\n# Use relative imports for package structure\nfrom ..experts_base import BaseMoEWrapper\nfrom .loader import SafeTensorLoader, CompressedSafeTensorLoader, FP8SafeTensorLoader, BF16SafeTensorLoader\nfrom kt_kernel_ext.moe import MOEConfig\nimport kt_kernel_ext.moe as _moe_mod\n\nAMXInt4_MOE = getattr(_moe_mod, \"AMXInt4_MOE\", None)\nAMXInt8_MOE = getattr(_moe_mod, \"AMXInt8_MOE\", None)\nAMXInt4_KGroup_MOE = getattr(_moe_mod, \"AMXInt4_KGroup_MOE\", None)\nAMXFP8_MOE = getattr(_moe_mod, \"AMXFP8_MOE\", None)\nAMXBF16_MOE = getattr(_moe_mod, \"AMXBF16_MOE\", None)\nAMXFP8PerChannel_MOE = getattr(_moe_mod, \"AMXFP8PerChannel_MOE\", None)\n\n_HAS_AMXINT4_SUPPORT = AMXInt4_MOE is not None\n_HAS_AMXINT8_SUPPORT = AMXInt8_MOE is not None\n_HAS_RAWINT4_SUPPORT = AMXInt4_KGroup_MOE is not None\n_HAS_FP8_SUPPORT = AMXFP8_MOE is not None\n_HAS_BF16_SUPPORT = AMXBF16_MOE is not None\n_HAS_FP8_PERCHANNEL_SUPPORT = AMXFP8PerChannel_MOE is not None\n\n\nclass AMXMoEWrapper(BaseMoEWrapper):\n    \"\"\"\n    AMX-based MoE wrapper implementation.\n    Supports AMXINT4 and AMXINT8 quantization methods.\n    \"\"\"\n\n    _safetensor_loader_instance = None  # Singleton SafeTensorLoader\n\n    def __init__(\n        self,\n        layer_idx: int,\n        num_experts: int,\n        num_experts_per_tok: int,\n        hidden_size: int,\n        moe_intermediate_size: int,\n        gpu_experts_mask: Optional[torch.Tensor],\n        cpuinfer_threads: int,\n        threadpool_count: int,\n        weight_path: str,\n        chunked_prefill_size: int,\n        cpu_save: bool = False,\n        max_deferred_experts_per_token: Optional[int] = None,\n        method: str = \"AMXINT4\",\n    ):\n        \"\"\"\n        Initialize AMX MoE Wrapper.\n\n        Args:\n            layer_idx: Layer index\n            num_experts: Total number of experts\n            num_experts_per_tok: Number of experts per token 
(top-k)\n            hidden_size: Hidden dimension size\n            moe_intermediate_size: MoE intermediate size\n            gpu_experts_mask: Boolean mask indicating which experts are on GPU.\n                              Shape: [num_experts], dtype: torch.bool.\n                              mask[i] = True means expert i is on GPU.\n                              If None, all experts are on CPU.\n            cpuinfer_threads: Number of CPU inference threads\n            threadpool_count: Number of NUMA subpools\n            weight_path: Path to AMX weights (SafeTensor format)\n            chunked_prefill_size: Maximum prefill chunk size\n            cpu_save: Whether to save weights to CPU memory\n            max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0.\n            method: AMX quantization method (\"AMXINT4\" or \"AMXINT8\")\n        \"\"\"\n        if method == \"AMXINT4\" and not _HAS_AMXINT4_SUPPORT:\n            raise RuntimeError(\n                \"AMXINT4 backend not available. Required ISA:\\n\"\n                \"  - AVX512F + AVX512BW (VNNI optional)\\n\"\n                \"Please recompile kt_kernel_ext with AVX512 enabled.\"\n            )\n        if method == \"AMXINT8\" and not _HAS_AMXINT8_SUPPORT:\n            raise RuntimeError(\n                \"AMXINT8 backend not available. 
Required ISA:\\n\"\n                \"  - AVX512F + AVX512BW (VNNI optional)\\n\"\n                \"Please recompile kt_kernel_ext with AVX512 enabled.\"\n            )\n\n        # Initialize base class\n        super().__init__(\n            layer_idx=layer_idx,\n            num_experts=num_experts,\n            num_experts_per_tok=num_experts_per_tok,\n            hidden_size=hidden_size,\n            moe_intermediate_size=moe_intermediate_size,\n            gpu_experts_mask=gpu_experts_mask,\n            cpuinfer_threads=cpuinfer_threads,\n            threadpool_count=threadpool_count,\n            weight_path=weight_path,\n            chunked_prefill_size=chunked_prefill_size,\n            cpu_save=cpu_save,\n            max_deferred_experts_per_token=max_deferred_experts_per_token,\n            method=method,\n        )\n\n        # AMX-specific: Check if we should load merged safetensor weights\n        self.load_merged_weight = False\n        import glob\n\n        if glob.glob(os.path.join(weight_path, \"*.safetensors\")):\n            self.load_merged_weight = True\n\n        # Initialize SafeTensor loader (singleton)\n        if self.load_merged_weight:\n            if AMXMoEWrapper._safetensor_loader_instance is None:\n                AMXMoEWrapper._safetensor_loader_instance = SafeTensorLoader(weight_path)\n            self.safetensor_loader = AMXMoEWrapper._safetensor_loader_instance\n\n        # AMX-specific weight storage\n        self.gate_weights = None\n        self.up_weights = None\n        self.down_weights = None\n        self.gate_scales = None\n        self.up_scales = None\n        self.down_scales = None\n\n    def load_weights_from_tensors(\n        self,\n        gate_proj: torch.Tensor,\n        up_proj: torch.Tensor,\n        down_proj: torch.Tensor,\n        physical_to_logical_map_cpu: torch.Tensor,\n    ):\n        \"\"\"\n        Load and quantize weights from BF16/FP16 tensors (online quantization).\n\n        Args:\n            
gate_proj: Gate projection weights [num_experts, intermediate_size, hidden_size]\n            up_proj: Up projection weights [num_experts, intermediate_size, hidden_size]\n            down_proj: Down projection weights [num_experts, hidden_size, intermediate_size]\n            physical_to_logical_map_cpu: Mapping from physical to logical expert IDs\n        \"\"\"\n        # Store tensors as instance variables to keep them alive\n        self.gate_proj = gate_proj.contiguous()\n        self.up_proj = up_proj.contiguous()\n        self.down_proj = down_proj.contiguous()\n\n        # Configure MoE with online quantization (cpu_save mode)\n        moe_config = MOEConfig(\n            self.num_experts,\n            self.num_experts_per_tok,\n            self.hidden_size,\n            self.moe_intermediate_size,\n            self.gpu_experts_mask.data_ptr(),\n        )\n        moe_config.layer_idx = self.layer_idx\n        moe_config.pool = self.cpu_infer.backend_\n        moe_config.max_len = self.chunked_prefill_size\n\n        # Enable save mode for online quantization\n        moe_config.save = True\n        moe_config.load = False\n\n        # Set weight pointers\n        moe_config.gate_proj = self.gate_proj.data_ptr()\n        moe_config.up_proj = self.up_proj.data_ptr()\n        moe_config.down_proj = self.down_proj.data_ptr()\n\n        # Set output path for quantized weights\n        moe_config.path = self.weight_path\n\n        # Create MoE module based on AMX method\n        if self.method == \"AMXINT4\":\n            self.moe = AMXInt4_MOE(moe_config)\n        elif self.method == \"AMXINT8\":\n            self.moe = AMXInt8_MOE(moe_config)\n        else:\n            raise NotImplementedError(f\"Unsupported AMX method: {self.method}\")\n\n        # Submit quantization and save task\n        self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))\n        self.cpu_infer.sync()\n\n    def load_weights(self, 
physical_to_logical_map_cpu: torch.Tensor):\n        \"\"\"\n        Load weights for this layer and initialize the MoE module.\n\n        Args:\n            physical_to_logical_map_cpu: Mapping from physical to logical expert IDs\n        \"\"\"\n        gate_ptr = 0\n        up_ptr = 0\n        down_ptr = 0\n\n        gate_ptrs = []\n        up_ptrs = []\n        down_ptrs = []\n\n        gate_scale_ptrs = []\n        up_scale_ptrs = []\n        down_scale_ptrs = []\n\n        if self.load_merged_weight:\n            base_key = f\"blk.{self.layer_idx}\"\n            w = self.safetensor_loader.load_experts(base_key)\n\n            self.gate_weights = w[\"gate\"]\n            self.up_weights = w[\"up\"]\n            self.down_weights = w[\"down\"]\n            self.gate_scales = w[\"gate_scale\"]\n            self.up_scales = w[\"up_scale\"]\n            self.down_scales = w[\"down_scale\"]\n\n            # Get pointers to weight arrays\n            gate_ptrs = [\n                [\n                    ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)\n                    for et in numa_array\n                ]\n                for numa_array in self.gate_weights\n            ]\n\n            up_ptrs = [\n                [\n                    ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)\n                    for et in numa_array\n                ]\n                for numa_array in self.up_weights\n            ]\n\n            down_ptrs = [\n                [\n                    ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)\n                    for et in numa_array\n                ]\n                for numa_array in self.down_weights\n            ]\n\n            gate_scale_ptrs = [\n                [\n                    ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)\n                    for et in 
numa_array\n                ]\n                for numa_array in self.gate_scales\n            ]\n\n            up_scale_ptrs = [\n                [\n                    ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)\n                    for et in numa_array\n                ]\n                for numa_array in self.up_scales\n            ]\n\n            down_scale_ptrs = [\n                [\n                    ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)\n                    for et in numa_array\n                ]\n                for numa_array in self.down_scales\n            ]\n\n        # Configure MoE\n        moe_config = MOEConfig(\n            self.num_experts,\n            self.num_experts_per_tok,\n            self.hidden_size,\n            self.moe_intermediate_size,\n            self.gpu_experts_mask.data_ptr(),\n        )\n        moe_config.layer_idx = self.layer_idx\n        moe_config.pool = self.cpu_infer.backend_\n        moe_config.max_len = self.chunked_prefill_size\n\n        moe_config.gate_proj = gate_ptr\n        moe_config.up_proj = up_ptr\n        moe_config.down_proj = down_ptr\n        moe_config.gate_projs = gate_ptrs\n        moe_config.up_projs = up_ptrs\n        moe_config.down_projs = down_ptrs\n        moe_config.gate_scales = gate_scale_ptrs\n        moe_config.up_scales = up_scale_ptrs\n        moe_config.down_scales = down_scale_ptrs\n\n        if self.cpu_save:\n            moe_config.save = True\n            moe_config.load = False\n            base_key = f\"model.layers.{self.layer_idx}\"\n            w = self.safetensor_loader.load_experts(base_key)\n\n            self.gate_proj = torch.cat(w[\"gate_weight\"], dim=0).contiguous()\n            self.up_proj = torch.cat(w[\"up_weight\"], dim=0).contiguous()\n            self.down_proj = torch.cat(w[\"down_weight\"], dim=0).contiguous()\n\n            moe_config.gate_proj = 
self.gate_proj.data_ptr()\n            moe_config.up_proj = self.up_proj.data_ptr()\n            moe_config.down_proj = self.down_proj.data_ptr()\n        else:\n            moe_config.load = True\n\n        if not self.load_merged_weight:\n            moe_config.path = self.weight_path\n\n        # Create MoE module based on AMX method\n        if self.method == \"AMXINT4\":\n            self.moe = AMXInt4_MOE(moe_config)\n        elif self.method == \"AMXINT8\":\n            self.moe = AMXInt8_MOE(moe_config)\n        else:\n            raise NotImplementedError(f\"Unsupported AMX method: {self.method}\")\n\n        # Load weights\n        self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))\n        self.cpu_infer.sync()\n\n        # Clean up temporary weight storage if using merged weights\n        if self.load_merged_weight:\n            del self.gate_weights\n            del self.up_weights\n            del self.down_weights\n            del self.gate_scales\n            del self.up_scales\n            del self.down_scales\n\n\nclass NativeMoEWrapper(BaseMoEWrapper):\n    \"\"\"Wrapper for RAWINT4/FP8/FP8_PERCHANNEL/BF16 experts stored in compressed SafeTensor format.\"\"\"\n\n    _native_loader_instance = None\n\n    def __init__(\n        self,\n        layer_idx: int,\n        num_experts: int,\n        num_experts_per_tok: int,\n        hidden_size: int,\n        moe_intermediate_size: int,\n        gpu_experts_mask: Optional[torch.Tensor],\n        cpuinfer_threads: int,\n        threadpool_count: int,\n        weight_path: str,\n        chunked_prefill_size: int,\n        cpu_save: bool = False,\n        max_deferred_experts_per_token: Optional[int] = None,\n        method: str = \"RAWINT4\",\n    ):\n        if method == \"RAWINT4\" and not _HAS_RAWINT4_SUPPORT:\n            raise RuntimeError(\n                \"RAWINT4 backend not available. 
Required ISA:\\n\"\n                \"  - AVX512F + AVX512BW (VNNI optional)\\n\"\n                \"Please recompile kt_kernel_ext with AVX512 enabled.\"\n            )\n        if method == \"FP8\" and not _HAS_FP8_SUPPORT:\n            raise RuntimeError(\n                \"FP8 backend not available. Required ISA:\\n\"\n                \"  - AVX512F + AVX512BW + AVX512_BF16 + AVX512_VBMI\\n\"\n                \"Please recompile kt_kernel_ext with AVX512 + BF16 + VBMI enabled.\"\n            )\n        if method == \"FP8_PERCHANNEL\" and not _HAS_FP8_PERCHANNEL_SUPPORT:\n            raise RuntimeError(\n                \"FP8_PERCHANNEL backend not available. Required ISA:\\n\"\n                \"  - AVX512F + AVX512BW + AVX512_BF16 + AVX512_VBMI\\n\"\n                \"Please recompile kt_kernel_ext with AVX512 + BF16 + VBMI enabled.\"\n            )\n        if method == \"BF16\" and not _HAS_BF16_SUPPORT:\n            raise RuntimeError(\n                \"BF16 backend not available. 
Required ISA:\\n\"\n                \"  - AVX512F + AVX512BW + AVX512_BF16\\n\"\n                \"Please recompile kt_kernel_ext with AVX512 + BF16 enabled.\"\n            )\n\n        super().__init__(\n            layer_idx=layer_idx,\n            num_experts=num_experts,\n            num_experts_per_tok=num_experts_per_tok,\n            hidden_size=hidden_size,\n            moe_intermediate_size=moe_intermediate_size,\n            gpu_experts_mask=gpu_experts_mask,\n            cpuinfer_threads=cpuinfer_threads,\n            threadpool_count=threadpool_count,\n            weight_path=weight_path,\n            chunked_prefill_size=chunked_prefill_size,\n            cpu_save=cpu_save,\n            max_deferred_experts_per_token=max_deferred_experts_per_token,\n            method=method,\n        )\n\n        if NativeMoEWrapper._native_loader_instance is None:\n            if method == \"RAWINT4\":\n                NativeMoEWrapper._native_loader_instance = CompressedSafeTensorLoader(weight_path)\n            elif method == \"FP8\":\n                NativeMoEWrapper._native_loader_instance = FP8SafeTensorLoader(weight_path)\n            elif method == \"FP8_PERCHANNEL\":\n                # Use FP8SafeTensorLoader with per-channel scale format\n                NativeMoEWrapper._native_loader_instance = FP8SafeTensorLoader(weight_path, scale_suffix=\"weight_scale\")\n            elif method == \"BF16\":\n                NativeMoEWrapper._native_loader_instance = BF16SafeTensorLoader(weight_path)\n            else:\n                raise NotImplementedError(f\"Unsupported method for NativeMoEWrapper: {method}\")\n        self.loader = NativeMoEWrapper._native_loader_instance\n\n        self.gate_weights = None\n        self.up_weights = None\n        self.down_weights = None\n        self.gate_scales = None\n        self.up_scales = None\n        self.down_scales = None\n\n    def load_weights_from_tensors(\n        self,\n        gate_proj: torch.Tensor,\n        
up_proj: torch.Tensor,\n        down_proj: torch.Tensor,\n        physical_to_logical_map_cpu: torch.Tensor,\n    ):\n        raise NotImplementedError(\"NativeMoEWrapper expects pre-quantized safetensor weights; online quantization is not supported.\")\n\n    def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):\n        import time\n\n        t0 = time.time()\n        base_key = f\"model.layers.{self.layer_idx}\"\n        weights = self.loader.load_experts(base_key)\n        t1 = time.time()\n\n        # Keep individual tensors instead of stacking - avoid expensive memory copy\n        # weights[\"gate\"], weights[\"up\"], weights[\"down\"] are lists of tensors per expert\n        self.gate_weights = weights[\"gate\"]  # list of tensors\n        self.up_weights = weights[\"up\"]\n        self.down_weights = weights[\"down\"]\n\n        # BF16 has no scales, others have scales\n        if self.method == \"BF16\":\n            # BF16 doesn't have scales\n            self.gate_scales = None\n            self.up_scales = None\n            self.down_scales = None\n        else:\n            # Use scales as loaded; the per-tensor bf16 conversion below is disabled,\n            # and the dtype is checked (or converted to float32) per method further down.\n            # self.gate_scales = [t.to(torch.bfloat16).contiguous() for t in weights[\"gate_scale\"]]\n            # self.up_scales = [t.to(torch.bfloat16).contiguous() for t in weights[\"up_scale\"]]\n            # self.down_scales = [t.to(torch.bfloat16).contiguous() for t in weights[\"down_scale\"]]\n            self.gate_scales = weights[\"gate_scale\"]\n            self.up_scales = weights[\"up_scale\"]\n            self.down_scales = weights[\"down_scale\"]\n            if self.method == \"RAWINT4\":\n                assert self.gate_scales[0].dtype == torch.bfloat16, \"Expected bf16 scales for RAWINT4\"\n            elif self.method == \"FP8\":\n                if self.gate_scales[0].dtype != torch.float32:\n                    self.gate_scales = [t.to(torch.float32).contiguous() for t in weights[\"gate_scale\"]]\n                    self.up_scales = 
[t.to(torch.float32).contiguous() for t in weights[\"up_scale\"]]\n                    self.down_scales = [t.to(torch.float32).contiguous() for t in weights[\"down_scale\"]]\n                assert self.gate_scales[0].dtype == torch.float32, \"Expected float32 scales for FP8\"\n            elif self.method == \"FP8_PERCHANNEL\":\n                if self.gate_scales[0].dtype != torch.float32:\n                    self.gate_scales = [t.to(torch.float32).contiguous() for t in weights[\"gate_scale\"]]\n                    self.up_scales = [t.to(torch.float32).contiguous() for t in weights[\"up_scale\"]]\n                    self.down_scales = [t.to(torch.float32).contiguous() for t in weights[\"down_scale\"]]\n                assert self.gate_scales[0].dtype == torch.float32, \"Expected float32 scales for FP8_PERCHANNEL\"\n\n        t2 = time.time()\n\n        # Build pointer lists: [numa_id][expert_id] -> pointer\n        # Since RAWINT4/FP8/BF16 has no numa sharding, numa dimension is 1\n        gate_ptrs = [[t.data_ptr() for t in self.gate_weights]]\n        up_ptrs = [[t.data_ptr() for t in self.up_weights]]\n        down_ptrs = [[t.data_ptr() for t in self.down_weights]]\n\n        # BF16 has no scales, pass empty lists (will use 0/nullptr for consistency)\n        if self.method == \"BF16\":\n            gate_scale_ptrs = [[0 for _ in self.gate_weights]]\n            up_scale_ptrs = [[0 for _ in self.up_weights]]\n            down_scale_ptrs = [[0 for _ in self.down_weights]]\n        else:\n            gate_scale_ptrs = [[t.data_ptr() for t in self.gate_scales]]\n            up_scale_ptrs = [[t.data_ptr() for t in self.up_scales]]\n            down_scale_ptrs = [[t.data_ptr() for t in self.down_scales]]\n        t3 = time.time()\n\n        moe_config = MOEConfig(\n            self.num_experts,\n            self.num_experts_per_tok,\n            self.hidden_size,\n            self.moe_intermediate_size,\n            self.gpu_experts_mask.data_ptr(),\n        )\n  
      moe_config.layer_idx = self.layer_idx\n        moe_config.pool = self.cpu_infer.backend_\n        moe_config.max_len = self.chunked_prefill_size\n\n        # Use gate_projs instead of gate_proj for per-expert pointers\n        moe_config.gate_projs = gate_ptrs\n        moe_config.up_projs = up_ptrs\n        moe_config.down_projs = down_ptrs\n        moe_config.gate_scales = gate_scale_ptrs\n        moe_config.up_scales = up_scale_ptrs\n        moe_config.down_scales = down_scale_ptrs\n\n        # Infer group_size from scale shape (column-major layout)\n        # For gate/up projection: in_features = hidden_size\n        # So: group_size = hidden_size / scale.shape[1]\n\n        if self.method == \"RAWINT4\":\n            group_size = self.hidden_size // self.gate_scales[0].shape[1]\n            moe_config.quant_config.bits = 4\n            moe_config.quant_config.group_size = group_size\n            moe_config.quant_config.zero_point = False\n            self.moe = AMXInt4_KGroup_MOE(moe_config)\n        elif self.method == \"FP8\":\n            moe_config.quant_config.bits = 8\n            moe_config.quant_config.group_size = 128\n            moe_config.quant_config.zero_point = False\n            self.moe = AMXFP8_MOE(moe_config)\n        elif self.method == \"FP8_PERCHANNEL\":\n            moe_config.quant_config.bits = 8\n            moe_config.quant_config.per_channel = True\n            moe_config.quant_config.zero_point = False\n            self.moe = AMXFP8PerChannel_MOE(moe_config)\n        elif self.method == \"BF16\":\n            # BF16 has no quantization config needed\n            self.moe = AMXBF16_MOE(moe_config)\n        t4 = time.time()\n\n        self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))\n        self.cpu_infer.sync()\n        t5 = time.time()\n\n        del self.gate_weights\n        del self.up_weights\n        del self.down_weights\n        if self.gate_scales is not None:\n            del 
self.gate_scales\n            del self.up_scales\n            del self.down_scales\n        t6 = time.time()\n\n        print(\n            f\"[NativeMoEWrapper Layer {self.layer_idx}] \"\n            f\"load_experts: {(t1-t0)*1000:.1f}ms, \"\n            f\"prepare_tensors: {(t2-t1)*1000:.1f}ms, \"\n            f\"build_ptrs: {(t3-t2)*1000:.1f}ms, \"\n            f\"create_moe: {(t4-t3)*1000:.1f}ms, \"\n            f\"cpp_load_weights: {(t5-t4)*1000:.1f}ms, \"\n            f\"cleanup: {(t6-t5)*1000:.1f}ms, \"\n            f\"total: {(t6-t0)*1000:.1f}ms\"\n        )\n\n    def submit_write_weight_scale_to_buffer(\n        self,\n        gpu_tp_count: int,\n        expert_id: int,\n        w13_weight_ptrs,\n        w13_scale_ptrs,\n        w2_weight_ptrs,\n        w2_scale_ptrs,\n    ):\n        \"\"\"\n        Submit the write_weight_scale_to_buffer task for RAWINT4 KGroup AMX implementation.\n\n        This method submits the C++-exposed task `write_weight_scale_to_buffer_task` to the\n        shared CPUInfer queue. The pointer lists should be plain integer lists (e.g. 
from\n        tensor.data_ptr()).\n        \"\"\"\n        if self.moe is None:\n            raise RuntimeError(\"MoE instance not initialized; cannot submit write_weight_scale_to_buffer task.\")\n\n        if not hasattr(self.moe, \"write_weight_scale_to_buffer_task\"):\n            raise NotImplementedError(\n                \"write_weight_scale_to_buffer_task is not available for this backend implementation.\"\n            )\n\n        self.cpu_infer.submit(\n            self.moe.write_weight_scale_to_buffer_task(\n                gpu_tp_count,\n                expert_id,\n                w13_weight_ptrs,\n                w13_scale_ptrs,\n                w2_weight_ptrs,\n                w2_scale_ptrs,\n            )\n        )\n\n    def sync_write_weight_scale_to_buffer(self):\n        \"\"\"\n        Block until previously submitted write_weight_scale_to_buffer tasks finish.\n        \"\"\"\n        # The CPUInfer.sync() call blocks until pending tasks complete.\n        self.cpu_infer.sync()\n"
  },
  {
    "path": "kt-kernel/python/utils/llamafile.py",
    "content": "import torch\nfrom typing import Optional\nimport os\n\n# Use relative imports for package structure\nfrom ..experts_base import BaseMoEWrapper\nfrom .loader import GGUFLoader\nfrom kt_kernel_ext.moe import MOEConfig\n\ntry:\n    from kt_kernel_ext.moe import MOE\n\n    _HAS_LLAMAFILE_SUPPORT = True\nexcept (ImportError, AttributeError):\n    _HAS_LLAMAFILE_SUPPORT = False\n    MOE = None\n\nfrom kt_kernel_ext.kvcache import ggml_type\n\n\nclass LlamafileMoEWrapper(BaseMoEWrapper):\n    \"\"\"\n    Llamafile-based MoE wrapper implementation.\n    Supports GGUF quantized weights with llamafile backend.\n    \"\"\"\n\n    _gguf_loader_instance = None  # Singleton GGUFLoader\n\n    def __init__(\n        self,\n        layer_idx: int,\n        num_experts: int,\n        num_experts_per_tok: int,\n        hidden_size: int,\n        moe_intermediate_size: int,\n        gpu_experts_mask: Optional[torch.Tensor],\n        cpuinfer_threads: int,\n        threadpool_count: int,\n        weight_path: str,\n        chunked_prefill_size: int,\n        cpu_save: bool = False,\n        max_deferred_experts_per_token: Optional[int] = None,\n        method: str = \"LLAMAFILE\",\n    ):\n        \"\"\"\n        Initialize Llamafile MoE Wrapper.\n\n        Args:\n            layer_idx: Layer index\n            num_experts: Total number of experts\n            num_experts_per_tok: Number of experts per token (top-k)\n            hidden_size: Hidden dimension size\n            moe_intermediate_size: MoE intermediate size\n            gpu_experts_mask: Boolean mask indicating which experts are on GPU.\n                              Shape: [num_experts], dtype: torch.bool.\n                              mask[i] = True means expert i is on GPU.\n                              If None, all experts are on CPU.\n            cpuinfer_threads: Number of CPU inference threads\n            threadpool_count: Number of NUMA subpools (TP count)\n            weight_path: Path to GGUF 
weights\n            chunked_prefill_size: Maximum prefill chunk size\n            cpu_save: Not supported for Llamafile backend\n            max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0.\n            method: Should be \"LLAMAFILE\"\n        \"\"\"\n        if not _HAS_LLAMAFILE_SUPPORT:\n            raise RuntimeError(\n                \"Llamafile backend not available. kt_kernel_ext was not compiled with Llamafile support.\\n\"\n                \"Please recompile with Llamafile enabled.\"\n            )\n\n        if not os.path.exists(weight_path):\n            raise FileNotFoundError(f\"GGUF weight path not found: {weight_path}\")\n\n        # Initialize GGUF loader (singleton)\n        if LlamafileMoEWrapper._gguf_loader_instance is None:\n            LlamafileMoEWrapper._gguf_loader_instance = GGUFLoader(weight_path)\n        self.gguf_loader = LlamafileMoEWrapper._gguf_loader_instance\n\n        # Validate TP configuration with QK_K alignment\n        QK_K = 256\n\n        # Check if intermediate_size is divisible by QK_K\n        if moe_intermediate_size % QK_K != 0:\n            raise ValueError(\n                f\"intermediate_size ({moe_intermediate_size}) must be divisible by QK_K ({QK_K}) \"\n                f\"for Llamafile backend\"\n            )\n\n        # Calculate TP splits with QK_K alignment\n        num_blocks = moe_intermediate_size // QK_K\n        base_blocks = num_blocks // threadpool_count\n        extra_blocks = num_blocks % threadpool_count\n\n        # Validate that we have enough blocks\n        if base_blocks == 0:\n            valid_tp_counts = list(range(1, num_blocks + 1))\n            raise ValueError(\n                f\"intermediate_size ({moe_intermediate_size}) is too small for threadpool_count ({threadpool_count}).\\n\"\n                f\"Total blocks: {num_blocks} (intermediate_size / QK_K)\\n\"\n                f\"Cannot distribute to {threadpool_count} TPs (each TP needs at least 
1 block).\\n\"\n                f\"Valid threadpool_count values: {valid_tp_counts}\"\n            )\n\n        # Log TP split information\n        print(f\"[LlamafileMoEWrapper] Layer {layer_idx} TP configuration:\")\n        print(f\"  intermediate_size: {moe_intermediate_size}\")\n        print(f\"  threadpool_count: {threadpool_count}\")\n        print(f\"  QK_K: {QK_K}\")\n        print(f\"  Total blocks: {num_blocks}\")\n        print(f\"  Base blocks per TP: {base_blocks}\")\n        print(f\"  Extra blocks (distributed to first TPs): {extra_blocks}\")\n\n        current_offset = 0\n        for tp_id in range(threadpool_count):\n            tp_blocks = base_blocks + (1 if tp_id < extra_blocks else 0)\n            tp_size = tp_blocks * QK_K\n            print(f\"  TP {tp_id}: size={tp_size}, offset={current_offset}, blocks={tp_blocks}\")\n            current_offset += tp_size\n\n        # Initialize base class\n        super().__init__(\n            layer_idx=layer_idx,\n            num_experts=num_experts,\n            num_experts_per_tok=num_experts_per_tok,\n            hidden_size=hidden_size,\n            moe_intermediate_size=moe_intermediate_size,\n            gpu_experts_mask=gpu_experts_mask,\n            cpuinfer_threads=cpuinfer_threads,\n            threadpool_count=threadpool_count,\n            weight_path=weight_path,\n            chunked_prefill_size=chunked_prefill_size,\n            cpu_save=cpu_save,\n            max_deferred_experts_per_token=max_deferred_experts_per_token,\n            method=method,\n        )\n\n        self.weights_to_keep = None\n\n    def load_weights_from_tensors(\n        self,\n        gate_proj: torch.Tensor,\n        up_proj: torch.Tensor,\n        down_proj: torch.Tensor,\n        physical_to_logical_map_cpu: torch.Tensor,\n    ):\n        \"\"\"\n        Online quantization is not supported for Llamafile backend.\n        Use pre-quantized GGUF weights instead.\n        \"\"\"\n        raise 
NotImplementedError(\n            \"Llamafile backend does not support online quantization (load_weights_from_tensors).\\n\"\n            \"Please use pre-quantized GGUF weights and call load_weights() instead.\"\n        )\n\n    def load_weights(self, physical_to_logical_map_cpu: Optional[torch.Tensor] = None):\n        \"\"\"\n        Load weights for this layer from GGUF files and initialize the MoE module.\n\n        Args:\n            physical_to_logical_map_cpu: Optional mapping from physical to logical expert IDs\n                                         Shape: [num_experts], dtype: int32\n                                         If None, uses identity mapping [0, 1, 2, ..., num_experts-1]\n        \"\"\"\n        if not _HAS_LLAMAFILE_SUPPORT:\n            raise RuntimeError(\n                \"Llamafile backend not available. kt_kernel_ext was not compiled with Llamafile support.\\n\"\n                \"Please recompile with Llamafile enabled.\"\n            )\n\n        if physical_to_logical_map_cpu is None:\n            physical_to_logical_map_cpu = torch.arange(self.num_experts, dtype=torch.int32, device=\"cpu\")\n            print(f\"  Using default identity mapping for {self.num_experts} experts\")\n\n        base_key = f\"blk.{self.layer_idx}\"\n\n        # Load quantized tensors from GGUF\n        gate_data, gate_type = self.gguf_loader.get_undequanted_tensor_and_ggml_type(f\"{base_key}.ffn_gate_exps.weight\")\n\n        up_data, up_type = self.gguf_loader.get_undequanted_tensor_and_ggml_type(f\"{base_key}.ffn_up_exps.weight\")\n\n        down_data, down_type = self.gguf_loader.get_undequanted_tensor_and_ggml_type(f\"{base_key}.ffn_down_exps.weight\")\n\n        # Keep tensors alive\n        self.weights_to_keep = (gate_data, up_data, down_data)\n\n        hidden_type = ggml_type.BF16\n\n        # Configure MoE\n        moe_config = MOEConfig(\n            self.num_experts,\n            self.num_experts_per_tok,\n            self.hidden_size,\n    
        self.moe_intermediate_size,\n            self.gpu_experts_mask.data_ptr(),\n        )\n        moe_config.layer_idx = self.layer_idx\n        moe_config.pool = self.cpu_infer.backend_\n\n        # Llamafile-specific configuration\n        moe_config.m_block = 32  # Parallel block size\n        moe_config.group_min_len = 10  # Use forward_one when qlen < 10\n        moe_config.max_len = self.chunked_prefill_size\n        moe_config.group_max_len = max(1, int(self.chunked_prefill_size))\n\n        # Set weight pointers\n        moe_config.gate_proj = gate_data.data_ptr()\n        moe_config.up_proj = up_data.data_ptr()\n        moe_config.down_proj = down_data.data_ptr()\n\n        # Set quantization types\n        moe_config.gate_type = gate_type\n        moe_config.up_type = up_type\n        moe_config.down_type = down_type\n        moe_config.hidden_type = hidden_type\n\n        # Create MoE module\n        self.moe = MOE(moe_config)\n\n        # Load weights\n        self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))\n        self.cpu_infer.sync()\n\n        # Drop original weights after loading\n        self.weights_to_keep = None\n"
  },
  {
    "path": "kt-kernel/python/utils/loader.py",
    "content": "\"\"\"\nWeight loaders for different formats.\n\nThis module provides loaders for:\n- SafeTensor format (for AMX quantized weights)\n- GGUF format (for Llamafile quantized weights)\n\"\"\"\n\nfrom __future__ import annotations\n\nimport os\nimport numpy as np\nimport torch\nfrom enum import IntEnum\nfrom safetensors import safe_open\nfrom gguf.gguf_reader import GGUFReader\n\n\nclass GGMLQuantizationType(IntEnum):\n    \"\"\"GGML quantization type enumeration\"\"\"\n\n    F32 = 0\n    F16 = 1\n    Q4_0 = 2\n    Q4_1 = 3\n    Q5_0 = 6\n    Q5_1 = 7\n    Q8_0 = 8\n    Q8_1 = 9\n    Q2_K = 10\n    Q3_K = 11\n    Q4_K = 12\n    Q5_K = 13\n    Q6_K = 14\n    Q8_K = 15\n    IQ2_XXS = 16\n    IQ2_XS = 17\n    IQ3_XXS = 18\n    IQ1_S = 19\n    IQ4_NL = 20\n    IQ3_S = 21\n    IQ2_S = 22\n    IQ4_XS = 23\n    I8 = 24\n    I16 = 25\n    I32 = 26\n    I64 = 27\n    F64 = 28\n    IQ1_M = 29\n    BF16 = 30\n\n\ndef translate_name_to_gguf(name):\n    \"\"\"\n    Translate PyTorch tensor name to GGUF format\n    \"\"\"\n    name = name.replace(\"lm_head.\", \"output.\")\n    name = name.replace(\"model.embed_tokens.\", \"token_embd.\")\n    name = name.replace(\"model.norm.\", \"output_norm.\")\n    name = name.replace(\"model.layers.\", \"blk.\")\n    name = name.replace(\".input_layernorm\", \".attn_norm\")\n    name = name.replace(\".mlp.down_proj\", \".ffn_down\")\n    name = name.replace(\".mlp.gate_proj\", \".ffn_gate\")\n    name = name.replace(\".mlp.up_proj\", \".ffn_up\")\n    name = name.replace(\".post_attention_layernorm\", \".ffn_norm\")\n    name = name.replace(\".self_attn.q_proj\", \".attn_q\")\n    name = name.replace(\".self_attn.k_proj\", \".attn_k\")\n    name = name.replace(\".self_attn.v_proj\", \".attn_v\")\n    name = name.replace(\".self_attn.o_proj\", \".attn_output\")\n    name = name.replace(\".self_attn.qkv_proj\", \".attn_qkv\")\n    name = name.replace(\".self_attn.kv_a_proj_with_mqa\", \".attn_kv_a_mqa\")\n    name = 
name.replace(\".self_attn.kv_a_layernorm\", \".attn_kv_a_norm\")\n    name = name.replace(\".self_attn.kv_b_proj\", \".attn_kv_b\")\n    name = name.replace(\".self_attn.q_a_proj\", \".attn_q_a\")\n    name = name.replace(\".self_attn.q_a_layernorm\", \".attn_q_a_norm\")\n    name = name.replace(\".self_attn.q_b_proj\", \".attn_q_b\")\n    name = name.replace(\".self_attn.q_norm\", \".attn_q_norm\")\n    name = name.replace(\".self_attn.k_norm\", \".attn_k_norm\")\n    name = name.replace(\".shared_expert.\", \".shared_experts.\")\n    name = name.replace(\".shared_expert_\", \".shared_experts_\")\n    name = name.replace(\".gate_up_proj.\", \".up_proj\")\n    name = name.replace(\".mlp.shared_experts.down_proj\", \".ffn_down_shexp\")\n    name = name.replace(\".mlp.gate.e_score_correction_bias\", \".exp_probs_b.bias\")\n    name = name.replace(\".mlp.gate\", \".ffn_gate_inp\")\n    name = name.replace(\".mlp.shared_experts.gate_proj\", \".ffn_gate_shexp\")\n    name = name.replace(\".mlp.shared_experts.up_proj\", \".ffn_up_shexp\")\n    name = name.replace(\".mlp.shared_experts_gate\", \".ffn_gate_inp_shexp\")\n    name = name.replace(\".mlp.experts\", \"\")\n    name = name.replace(\".mlp.experts.ffn_down_exps\", \".ffn_down_exps\")\n    name = name.replace(\".mlp.experts.ffn_gate_exps\", \".ffn_gate_exps\")\n    name = name.replace(\".mlp.experts.ffn_up_exps\", \".ffn_up_exps\")\n    name = name.replace(\".block_sparse_moe.gate.\", \".ffn_gate_inp.\")\n    name = name.replace(\".block_sparse_moe.experts\", \"\")\n    name = name.replace(\".feed_forward.experts\", \"\")\n    name = name.replace(\".feed_forward.router\", \".ffn_gate_inp\")\n    name = name.replace(\".feed_forward.shared_experts.down_proj\", \".ffn_down_shexp\")\n    name = name.replace(\".feed_forward.shared_experts.gate_proj\", \".ffn_gate_shexp\")\n    name = name.replace(\".feed_forward.shared_experts.up_proj\", \".ffn_up_shexp\")\n    return name\n\n\nclass SafeTensorLoader:\n    \"\"\"\n    
SafeTensor format loader for AMX quantized weights.\n\n    Supports loading tensors from .safetensors files with NUMA-sharded expert weights.\n    \"\"\"\n\n    tensor_file_map: dict\n    tensor_type_map: dict\n    file_handle_map: dict\n    tensor_device_map: dict\n\n    def __init__(self, file_path: str):\n        self.__load_tensor_file_map(file_path)\n\n    def __load_tensor_file_map(self, file_path: str):\n        if not os.path.exists(file_path):\n            raise FileNotFoundError(f\"Path not found: {file_path}\")\n        if os.path.isfile(file_path):\n            folder_path = os.path.dirname(file_path)\n        else:\n            folder_path = file_path\n        self.file_handle_map = {}\n        self.tensor_file_map = {}\n        self.tensor_type_map = {}\n        self.tensor_device_map = {}\n\n        found_safetensor = False\n        for root, _, files in os.walk(folder_path):\n            files = sorted(files)\n            for file in files:\n                if file.endswith(\".safetensors\"):\n                    found_safetensor = True\n                    file_path = os.path.join(root, file)\n                    if file not in self.file_handle_map:\n                        try:\n                            handle = safe_open(file_path, framework=\"pt\")\n                            self.file_handle_map[file] = handle\n                        except Exception as e:\n                            print(f\"Error opening Safetensor file {file_path}: {e}\")\n                            continue\n\n                    f = self.file_handle_map.get(file)\n                    if f is None:\n                        continue\n                    try:\n                        for key in f.keys():\n                            self.tensor_file_map[key] = file\n                    except Exception as e:\n                        print(f\"Error reading Safetensor file {file_path}: {e}\")\n\n        if not found_safetensor:\n            raise FileNotFoundError(f\"No 
Safetensor files found in {folder_path}\")\n\n    def load_tensor(self, key: str, device: str = \"cpu\"):\n        if key not in self.tensor_file_map:\n            raise KeyError(f\"Key {key} not found in Safetensor files\")\n        file = self.tensor_file_map[key]\n        f = self.file_handle_map.get(file)\n        if f is None:\n            raise FileNotFoundError(f\"File {file} not found in Safetensor files\")\n        tensor = f.get_tensor(key)\n        return tensor.to(device)\n\n    def close_all_handles(self):\n        for handle in self.file_handle_map.values():\n            handle.close()\n        self.file_handle_map.clear()\n\n    def load_experts(self, base_key: str, device: str = \"cpu\"):\n        \"\"\"\n        Load expert weights from SafeTensor files.\n\n        Expected format:\n        - blk.{layer_index}.ffn_[up, down, gate]_exps.{expert_id}.numa.{numa_id}.weight\n        - blk.{layer_index}.ffn_[up, down, gate]_exps.{expert_id}.numa.{numa_id}.scale\n\n        Args:\n            base_key: Base key like \"blk.{layer_index}\"\n            device: Target device for tensors\n\n        Returns:\n            Dictionary with keys: up, gate, down, up_scale, gate_scale, down_scale\n            Each value is a list of lists: [numa_id][expert_id] -> numpy array\n        \"\"\"\n        up_base_key = f\"{base_key}.ffn_up_exps\"\n        gate_base_key = f\"{base_key}.ffn_gate_exps\"\n        down_base_key = f\"{base_key}.ffn_down_exps\"\n        max_numa_id = -1\n        max_experts_count = -1\n        while self.has_tensor(f\"{up_base_key}.{max_experts_count+1}.numa.{0}.weight\"):\n            max_experts_count += 1\n        # max_experts_count holds the highest expert index found; -1 means no expert exists\n        if max_experts_count < 0:\n            raise ValueError(f\"No experts found for key {base_key}\")\n        while self.has_tensor(f\"{up_base_key}.{0}.numa.{max_numa_id+1}.weight\"):\n            max_numa_id += 1\n        # Initialize empty lists to store tensors for each projection type\n        up_weights = [[] for _ in 
range(max_numa_id + 1)]\n        gate_weights = [[] for _ in range(max_numa_id + 1)]\n        down_weights = [[] for _ in range(max_numa_id + 1)]\n        up_scales = [[] for _ in range(max_numa_id + 1)]\n        gate_scales = [[] for _ in range(max_numa_id + 1)]\n        down_scales = [[] for _ in range(max_numa_id + 1)]\n        for numa_id in range(max_numa_id + 1):\n            for expert_id in range(max_experts_count + 1):\n                up_key = f\"{up_base_key}.{expert_id}.numa.{numa_id}.weight\"\n                gate_key = f\"{gate_base_key}.{expert_id}.numa.{numa_id}.weight\"\n                down_key = f\"{down_base_key}.{expert_id}.numa.{numa_id}.weight\"\n                up_scale_key = f\"{up_base_key}.{expert_id}.numa.{numa_id}.scale\"\n                gate_scale_key = f\"{gate_base_key}.{expert_id}.numa.{numa_id}.scale\"\n                down_scale_key = f\"{down_base_key}.{expert_id}.numa.{numa_id}.scale\"\n                # make sure contiguous\n                up_tensor = self.load_tensor(up_key, device).numpy()\n                gate_tensor = self.load_tensor(gate_key, device).numpy()\n                down_tensor = self.load_tensor(down_key, device).numpy()\n                up_scale_tensor = self.load_tensor(up_scale_key, device).numpy()\n                gate_scale_tensor = self.load_tensor(gate_scale_key, device).numpy()\n                down_scale_tensor = self.load_tensor(down_scale_key, device).numpy()\n\n                up_weights[numa_id].append(up_tensor)\n                gate_weights[numa_id].append(gate_tensor)\n                down_weights[numa_id].append(down_tensor)\n                up_scales[numa_id].append(up_scale_tensor)\n                gate_scales[numa_id].append(gate_scale_tensor)\n                down_scales[numa_id].append(down_scale_tensor)\n        return {\n            \"up\": up_weights,\n            \"gate\": gate_weights,\n            \"down\": down_weights,\n            \"up_scale\": up_scales,\n            
\"gate_scale\": gate_scales,\n            \"down_scale\": down_scales,\n        }\n\n    def has_tensor(self, name: str):\n        return name in self.tensor_file_map\n\n\nclass FP8SafeTensorLoader(SafeTensorLoader):\n    \"\"\"Loader for FP8 expert weights with auto-detection of naming formats.\n\n    Supported formats:\n    - DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight\n    - Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight\n    - Mistral style: {base}.experts.{id}.{w1,w3,w2}.weight\n\n    Supported scale formats (auto-detected):\n    - Block-wise: weight_scale_inv (DeepSeek FP8)\n    - Per-channel: weight_scale (GLM-4.7-FP8)\n\n    The format is auto-detected during initialization.\n    \"\"\"\n\n    # Known MoE naming formats: (experts_path_template, gate_name, up_name, down_name)\n    MOE_FORMATS = {\n        \"deepseek\": (\"{base}.mlp.experts\", \"gate_proj\", \"up_proj\", \"down_proj\"),\n        \"mixtral\": (\"{base}.block_sparse_moe.experts\", \"w1\", \"w3\", \"w2\"),\n        \"mistral\": (\"{base}.experts\", \"w1\", \"w3\", \"w2\"),\n    }\n\n    def __init__(self, file_path: str, scale_suffix: str = None):\n        \"\"\"Initialize FP8 loader with optional scale suffix override.\n\n        Args:\n            file_path: Path to safetensor files\n            scale_suffix: Optional scale key suffix. 
If None, auto-detect between\n                         'weight_scale_inv' (block-wise) and 'weight_scale' (per-channel).\n        \"\"\"\n        super().__init__(file_path)\n        self._detected_format = None\n        self._scale_suffix = scale_suffix  # None means auto-detect\n        # Set per_channel based on explicit scale_suffix if provided\n        if scale_suffix == \"weight_scale\":\n            self._is_per_channel = True\n        elif scale_suffix == \"weight_scale_inv\":\n            self._is_per_channel = False\n        else:\n            self._is_per_channel = False  # Will be updated in _detect_format if auto-detect\n        self._is_vl_model = False\n        self._detect_format()\n\n    def _detect_format(self):\n        \"\"\"Auto-detect the MoE naming format and scale format by checking tensor keys.\"\"\"\n        # Sample some tensor names to detect format\n        sample_keys = list(self.tensor_file_map.keys())[:1000]\n\n        for fmt_name, (path_tpl, gate, up, down) in self.MOE_FORMATS.items():\n            # Check if any key matches this format pattern\n            # Look for pattern like: model.layers.0.{experts_path}.0.{gate_name}.weight\n            for key in sample_keys:\n                if \".experts.\" in key and f\".{gate}.weight\" in key:\n                    # Verify the path template matches\n                    if \"block_sparse_moe.experts\" in key and fmt_name == \"mixtral\":\n                        self._detected_format = fmt_name\n                        print(f\"[FP8SafeTensorLoader] Detected format: {fmt_name}\")\n                        break\n                    elif \"mlp.experts\" in key and \"block_sparse_moe\" not in key and fmt_name == \"deepseek\":\n                        self._detected_format = fmt_name\n                        print(f\"[FP8SafeTensorLoader] Detected format: {fmt_name}\")\n                        break\n                    elif fmt_name == \"mistral\" and \".mlp.experts\" not in key and 
\".block_sparse_moe.experts\" not in key:\n                        self._detected_format = fmt_name\n                        print(f\"[FP8SafeTensorLoader] Detected format: {fmt_name}\")\n                        break\n            if self._detected_format:\n                break\n\n        # Default to deepseek if no format detected\n        if not self._detected_format:\n            self._detected_format = \"deepseek\"\n            print(\"[FP8SafeTensorLoader] No MoE format detected, defaulting to: deepseek\")\n\n        # Auto-detect scale suffix if not specified\n        if self._scale_suffix is None:\n            _, gate, _, _ = self.MOE_FORMATS[self._detected_format]\n            # Check for per-channel scale (weight_scale) vs block-wise (weight_scale_inv)\n            for key in sample_keys:\n                if f\".{gate}.weight_scale_inv\" in key:\n                    self._scale_suffix = \"weight_scale_inv\"\n                    self._is_per_channel = False\n                    print(\"[FP8SafeTensorLoader] Detected scale format: block-wise (weight_scale_inv)\")\n                    if key.startswith(\"model.language_model.\") and self._detected_format == \"deepseek\":\n                        # VL models(Qwen3.5): model.layers.{N} -> model.language_model.layers.{N}\n                        self._is_vl_model = True\n                        print(\"[FP8SafeTensorLoader] Detected VL model\")\n                    return\n                elif f\".{gate}.weight_scale\" in key and \"weight_scale_inv\" not in key:\n                    self._scale_suffix = \"weight_scale\"\n                    # Some models (e.g., Mistral) use block-wise FP8 scales but keep\n                    # the key suffix as `weight_scale` (without `_inv`). 
Infer format\n                    # from scale tensor shape instead of suffix alone:\n                    # - per-channel: [N] or [N, 1]\n                    # - block-wise: [N_block, K_block] (both dims > 1)\n                    scale_tensor = self.load_tensor(key, device=\"cpu\")\n                    if scale_tensor.dim() == 1:\n                        self._is_per_channel = True\n                    elif scale_tensor.dim() == 2 and scale_tensor.shape[1] == 1:\n                        self._is_per_channel = True\n                    else:\n                        self._is_per_channel = False\n\n                    scale_kind = \"per-channel\" if self._is_per_channel else \"block-wise\"\n                    print(f\"[FP8SafeTensorLoader] Detected scale format: {scale_kind} (weight_scale)\")\n                    return\n            # Default to weight_scale_inv\n            self._scale_suffix = \"weight_scale_inv\"\n            self._is_per_channel = False\n            print(\"[FP8SafeTensorLoader] No scale format detected, defaulting to: weight_scale_inv\")\n        else:\n            # Scale suffix was explicitly provided\n            scale_type = \"per-channel\" if self._is_per_channel else \"block-wise\"\n            print(f\"[FP8SafeTensorLoader] Using explicit scale format: {scale_type} ({self._scale_suffix})\")\n\n    def _get_experts_prefix_candidates(self, base_key: str) -> list[str]:\n        \"\"\"Get candidate experts prefixes based on detected format and base key variants.\"\"\"\n        path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format]\n        candidates = []\n        if self._is_vl_model:\n            base_key = base_key.replace(\"model.layers\", \"model.language_model.layers\")\n        candidates.append(path_tpl.format(base=base_key))\n\n        # Some model weights (e.g., Mistral native format) do not have \"model.\" prefix.\n        if base_key.startswith(\"model.\"):\n            
candidates.append(path_tpl.format(base=base_key[len(\"model.\") :]))\n\n        # Deduplicate while preserving order.\n        return list(dict.fromkeys(candidates))\n\n    def _get_proj_names(self):\n        \"\"\"Get projection names (gate, up, down) based on detected format.\"\"\"\n        _, gate, up, down = self.MOE_FORMATS[self._detected_format]\n        return gate, up, down\n\n    def load_tensor(self, key: str, device: str = \"cpu\"):\n        if key not in self.tensor_file_map:\n            raise KeyError(f\"Key {key} not found in Safetensor files\")\n        file = self.tensor_file_map[key]\n        f = self.file_handle_map.get(file)\n        if f is None:\n            raise FileNotFoundError(f\"File {file} not found in Safetensor files\")\n        tensor = f.get_tensor(key)\n        if device == \"cpu\":\n            return tensor\n        return tensor.to(device)\n\n    def load_experts(self, base_key: str, device: str = \"cpu\"):\n        \"\"\"Load FP8 expert weights and their scale tensors.\n\n        Supports both block-wise (weight_scale_inv) and per-channel (weight_scale) formats.\n        Per-channel scales are squeezed from [N, 1] to [N] if needed.\n        \"\"\"\n        experts_prefix_candidates = self._get_experts_prefix_candidates(base_key)\n        gate_name, up_name, down_name = self._get_proj_names()\n\n        expert_count = 0\n        experts_prefix = None\n        for prefix in experts_prefix_candidates:\n            expert_count = 0\n            while self.has_tensor(f\"{prefix}.{expert_count}.{gate_name}.weight\"):\n                expert_count += 1\n            if expert_count > 0:\n                experts_prefix = prefix\n                break\n\n        if expert_count == 0 or experts_prefix is None:\n            raise ValueError(f\"No experts found for keys: {experts_prefix_candidates}\")\n\n        gate_weights = [None] * expert_count\n        up_weights = [None] * expert_count\n        down_weights = [None] * expert_count\n   
     gate_scales = [None] * expert_count\n        up_scales = [None] * expert_count\n        down_scales = [None] * expert_count\n\n        for exp_id in range(expert_count):\n            gate_w_key = f\"{experts_prefix}.{exp_id}.{gate_name}.weight\"\n            up_w_key = f\"{experts_prefix}.{exp_id}.{up_name}.weight\"\n            down_w_key = f\"{experts_prefix}.{exp_id}.{down_name}.weight\"\n            gate_s_key = f\"{experts_prefix}.{exp_id}.{gate_name}.{self._scale_suffix}\"\n            up_s_key = f\"{experts_prefix}.{exp_id}.{up_name}.{self._scale_suffix}\"\n            down_s_key = f\"{experts_prefix}.{exp_id}.{down_name}.{self._scale_suffix}\"\n\n            gate_weights[exp_id] = self.load_tensor(gate_w_key, device).contiguous()\n            up_weights[exp_id] = self.load_tensor(up_w_key, device).contiguous()\n            down_weights[exp_id] = self.load_tensor(down_w_key, device).contiguous()\n\n            gate_scale = self.load_tensor(gate_s_key, device)\n            up_scale = self.load_tensor(up_s_key, device)\n            down_scale = self.load_tensor(down_s_key, device)\n\n            # For per-channel scales, squeeze [N, 1] -> [N] if needed\n            if self._is_per_channel:\n                if gate_scale.dim() == 2 and gate_scale.shape[1] == 1:\n                    gate_scale = gate_scale.squeeze(1)\n                if up_scale.dim() == 2 and up_scale.shape[1] == 1:\n                    up_scale = up_scale.squeeze(1)\n                if down_scale.dim() == 2 and down_scale.shape[1] == 1:\n                    down_scale = down_scale.squeeze(1)\n\n            gate_scales[exp_id] = gate_scale.contiguous()\n            up_scales[exp_id] = up_scale.contiguous()\n            down_scales[exp_id] = down_scale.contiguous()\n\n        return {\n            \"gate\": gate_weights,\n            \"up\": up_weights,\n            \"down\": down_weights,\n            \"gate_scale\": gate_scales,\n            \"up_scale\": up_scales,\n            
\"down_scale\": down_scales,\n        }\n\n    def is_per_channel(self) -> bool:\n        \"\"\"Return True if using per-channel quantization, False for block-wise.\"\"\"\n        return self._is_per_channel\n\n\nclass BF16SafeTensorLoader(SafeTensorLoader):\n    \"\"\"Loader for native BF16 expert weights (no quantization, no scales).\n\n    Supported formats:\n    - DeepSeek style: {base}.mlp.experts.{id}.{gate,up,down}_proj.weight\n    - Mixtral/MiniMax style: {base}.block_sparse_moe.experts.{id}.{w1,w3,w2}.weight\n    - Mistral style: {base}.experts.{id}.{w1,w3,w2}.weight\n\n    The format is auto-detected during initialization.\n    \"\"\"\n\n    MOE_FORMATS = {\n        \"deepseek\": (\"{base}.mlp.experts\", \"gate_proj\", \"up_proj\", \"down_proj\"),\n        \"mixtral\": (\"{base}.block_sparse_moe.experts\", \"w1\", \"w3\", \"w2\"),\n        \"mistral\": (\"{base}.experts\", \"w1\", \"w3\", \"w2\"),\n    }\n\n    def __init__(self, file_path: str):\n        super().__init__(file_path)\n        self._detected_format = None\n        self._detect_format()\n\n    def _detect_format(self):\n        \"\"\"Auto-detect the MoE naming format by checking tensor keys.\"\"\"\n        sample_keys = list(self.tensor_file_map.keys())[:1000]\n\n        # Check for packed format first (Qwen3.5 MoE style: all experts in one 3D tensor)\n        for key in sample_keys:\n            if key.endswith(\".mlp.experts.gate_up_proj\"):\n                self._detected_format = \"packed\"\n                print(\"[BF16SafeTensorLoader] Detected format: packed (Qwen3.5 MoE style)\")\n                return\n\n        for fmt_name, (path_tpl, gate, up, down) in self.MOE_FORMATS.items():\n            for key in sample_keys:\n                if \".experts.\" in key and f\".{gate}.weight\" in key:\n                    if \"block_sparse_moe.experts\" in key and fmt_name == \"mixtral\":\n                        self._detected_format = fmt_name\n                        
print(f\"[BF16SafeTensorLoader] Detected format: {fmt_name}\")\n                        return\n                    elif \"mlp.experts\" in key and \"block_sparse_moe\" not in key and fmt_name == \"deepseek\":\n                        self._detected_format = fmt_name\n                        print(f\"[BF16SafeTensorLoader] Detected format: {fmt_name}\")\n                        return\n                    elif fmt_name == \"mistral\" and \".mlp.experts\" not in key and \".block_sparse_moe.experts\" not in key:\n                        self._detected_format = fmt_name\n                        print(f\"[BF16SafeTensorLoader] Detected format: {fmt_name}\")\n                        return\n\n        self._detected_format = \"deepseek\"\n        print(\"[BF16SafeTensorLoader] No MoE format detected, defaulting to: deepseek\")\n\n    def _get_experts_prefix_candidates(self, base_key: str) -> list[str]:\n        \"\"\"Get candidate experts prefixes based on detected format and base key variants.\"\"\"\n        path_tpl, _, _, _ = self.MOE_FORMATS[self._detected_format]\n        candidates = [path_tpl.format(base=base_key)]\n\n        # Some model weights (e.g., Mistral native format) do not have \"model.\" prefix.\n        if base_key.startswith(\"model.\"):\n            candidates.append(path_tpl.format(base=base_key[len(\"model.\") :]))\n\n        return list(dict.fromkeys(candidates))\n\n    def _get_proj_names(self):\n        \"\"\"Get projection names (gate, up, down) based on detected format.\"\"\"\n        _, gate, up, down = self.MOE_FORMATS[self._detected_format]\n        return gate, up, down\n\n    def load_tensor(self, key: str, device: str = \"cpu\"):\n        if key not in self.tensor_file_map:\n            raise KeyError(f\"Key {key} not found in Safetensor files\")\n        file = self.tensor_file_map[key]\n        f = self.file_handle_map.get(file)\n        if f is None:\n            raise FileNotFoundError(f\"File {file} not found in Safetensor 
files\")\n        tensor = f.get_tensor(key)\n        if device == \"cpu\":\n            return tensor\n        return tensor.to(device)\n\n    def load_experts(self, base_key: str, device: str = \"cpu\"):\n        \"\"\"Load BF16 expert weights (no scales needed).\"\"\"\n        if self._detected_format == \"packed\":\n            return self._load_experts_packed(base_key, device)\n\n        experts_prefix_candidates = self._get_experts_prefix_candidates(base_key)\n        gate_name, up_name, down_name = self._get_proj_names()\n\n        expert_count = 0\n        experts_prefix = None\n        for prefix in experts_prefix_candidates:\n            expert_count = 0\n            while self.has_tensor(f\"{prefix}.{expert_count}.{gate_name}.weight\"):\n                expert_count += 1\n            if expert_count > 0:\n                experts_prefix = prefix\n                break\n\n        if expert_count == 0 or experts_prefix is None:\n            raise ValueError(f\"No experts found for keys: {experts_prefix_candidates}\")\n\n        gate_weights = [None] * expert_count\n        up_weights = [None] * expert_count\n        down_weights = [None] * expert_count\n\n        for exp_id in range(expert_count):\n            gate_w_key = f\"{experts_prefix}.{exp_id}.{gate_name}.weight\"\n            up_w_key = f\"{experts_prefix}.{exp_id}.{up_name}.weight\"\n            down_w_key = f\"{experts_prefix}.{exp_id}.{down_name}.weight\"\n\n            gate_weights[exp_id] = self.load_tensor(gate_w_key, device).contiguous()\n            up_weights[exp_id] = self.load_tensor(up_w_key, device).contiguous()\n            down_weights[exp_id] = self.load_tensor(down_w_key, device).contiguous()\n\n        return {\n            \"gate\": gate_weights,\n            \"up\": up_weights,\n            \"down\": down_weights,\n        }\n\n    def _resolve_packed_experts_prefix(self, base_key: str) -> str:\n        \"\"\"Resolve the experts prefix for packed format, trying 
fallbacks.\"\"\"\n        # Direct: model.layers.{N}.mlp.experts\n        experts_prefix = f\"{base_key}.mlp.experts\"\n        if self.has_tensor(f\"{experts_prefix}.gate_up_proj\"):\n            return experts_prefix\n\n        # VL models: model.layers.{N} -> model.language_model.layers.{N}\n        parts = base_key.split(\".\", 1)\n        if len(parts) == 2:\n            alt_base = f\"{parts[0]}.language_model.{parts[1]}\"\n            experts_prefix = f\"{alt_base}.mlp.experts\"\n            if self.has_tensor(f\"{experts_prefix}.gate_up_proj\"):\n                return experts_prefix\n\n        raise ValueError(f\"No packed experts found for base_key '{base_key}'.\")\n\n    def _load_experts_packed(self, base_key: str, device: str = \"cpu\"):\n        \"\"\"Load packed expert weights (Qwen3.5 MoE style).\n\n        Packed format stores all experts in stacked 3D tensors:\n        - gate_up_proj: [num_experts, 2 * intermediate_size, hidden_size]\n        - down_proj:    [num_experts, hidden_size, intermediate_size]\n        \"\"\"\n        experts_prefix = self._resolve_packed_experts_prefix(base_key)\n\n        gate_up_key = f\"{experts_prefix}.gate_up_proj\"\n        down_key = f\"{experts_prefix}.down_proj\"\n\n        gate_up = self.load_tensor(gate_up_key, device)  # [E, 2*I, H]\n        down = self.load_tensor(down_key, device)  # [E, H, I]\n\n        mid = gate_up.shape[1] // 2\n        gate_list = [gate_up[i, :mid, :].contiguous() for i in range(gate_up.shape[0])]\n        up_list = [gate_up[i, mid:, :].contiguous() for i in range(gate_up.shape[0])]\n        down_list = [down[i].contiguous() for i in range(down.shape[0])]\n\n        return {\n            \"gate\": gate_list,\n            \"up\": up_list,\n            \"down\": down_list,\n        }\n\n\nclass CompressedSafeTensorLoader(SafeTensorLoader):\n    \"\"\"Loader for compressed SafeTensor layouts (RAWINT4 weights).\"\"\"\n\n    def load_experts(self, base_key: str, device: str = \"cpu\"):\n    
    \"\"\"Load raw expert weights stored in compressed safetensor format.\"\"\"\n\n        experts_prefix = f\"{base_key}.mlp.experts\"\n\n        expert_idx = 0\n        while self.has_tensor(f\"{experts_prefix}.{expert_idx}.up_proj.weight_packed\"):\n            expert_idx += 1\n\n        if expert_idx == 0:\n            experts_prefix = f\"language_model.{base_key}.mlp.experts\"\n            expert_idx = 0\n            while self.has_tensor(f\"{experts_prefix}.{expert_idx}.up_proj.weight_packed\"):\n                expert_idx += 1\n            if expert_idx == 0:\n                raise ValueError(f\"No experts found for key {experts_prefix}\")\n\n        def load_projection(proj_name: str):\n            weight_entries = []\n            scale_entries = []\n\n            for exp_id in range(expert_idx):\n                weight_key = f\"{experts_prefix}.{exp_id}.{proj_name}_proj.weight_packed\"\n                scale_key = f\"{experts_prefix}.{exp_id}.{proj_name}_proj.weight_scale\"\n\n                if not self.has_tensor(weight_key):\n                    raise KeyError(f\"Missing tensor: {weight_key}\")\n                if not self.has_tensor(scale_key):\n                    raise KeyError(f\"Missing tensor: {scale_key}\")\n\n                weight_tensor = self.load_tensor(weight_key, device).contiguous()\n                scale_tensor = self.load_tensor(scale_key, device).contiguous()\n\n                weight_entries.append(weight_tensor)\n                scale_entries.append(scale_tensor)\n\n            return weight_entries, scale_entries\n\n        gate_weights, gate_scales = load_projection(\"gate\")\n        up_weights, up_scales = load_projection(\"up\")\n        down_weights, down_scales = load_projection(\"down\")\n\n        return {\n            \"gate\": gate_weights,\n            \"up\": up_weights,\n            \"down\": down_weights,\n            \"gate_scale\": gate_scales,\n            \"up_scale\": up_scales,\n            \"down_scale\": 
down_scales,\n        }\n\n\nclass GGUFLoader:\n    \"\"\"\n    GGUF format loader using the official gguf library (gguf.gguf_reader.GGUFReader).\n\n    Relies on the library's reader instead of manual binary parsing.\n    \"\"\"\n\n    def __init__(self, gguf_path: str):\n        \"\"\"\n        Initialize the GGUF loader from a file or directory.\n\n        Args:\n            gguf_path: Path to a single GGUF file or a directory containing GGUF files\n        \"\"\"\n        if not os.path.exists(gguf_path):\n            raise FileNotFoundError(f\"GGUF path not found: {gguf_path}\")\n\n        self.tensor_info = {}\n        self.metadata = {}\n        self.tensor_file_map = {}\n        self.file_data_map = {}\n\n        if os.path.isfile(gguf_path) and gguf_path.endswith(\".gguf\"):\n            print(f\"\\n[GGUFLoader] Loading single GGUF file: {os.path.basename(gguf_path)}\")\n            self._load_single_file(gguf_path)\n        elif os.path.isdir(gguf_path):\n            print(f\"\\n[GGUFLoader] Loading GGUF files from directory: {gguf_path}\")\n            self._load_directory(gguf_path)\n        else:\n            raise ValueError(f\"Path must be a .gguf file or a directory: {gguf_path}\")\n\n        print(\"[GGUFLoader] Summary:\")\n        print(f\"  Files loaded: {len(self.file_data_map)}\")\n        print(f\"  Total tensors: {len(self.tensor_info)}\")\n        print(f\"  Metadata keys: {len(self.metadata)}\")\n        # Report the GGML dtype of the first layer's expert tensors, if present.\n        tensors = [\"blk.0.ffn_up_exps.weight\", \"blk.0.ffn_gate_exps.weight\", \"blk.0.ffn_down_exps.weight\"]\n        for key in tensors:\n            if key in self.tensor_info:\n                info = self.tensor_info[key]\n                print(f\" {'.'.join(key.split('.')[2:-1])}, Dtype: {info['dtype'].name}\")\n\n    def _load_single_file(self, file_path: str):\n        \"\"\"Load a single GGUF file\"\"\"\n        reader = GGUFReader(file_path)\n\n        for key, field in reader.fields.items():\n            value = 
field.parts[field.data[0]]\n            if isinstance(value, bytes):\n                value = value.decode(\"utf-8\")\n            elif isinstance(value, np.ndarray) and value.dtype == np.uint8:\n                try:\n                    value = bytes(value).decode(\"utf-8\")\n                except UnicodeDecodeError:\n                    pass  # Not valid UTF-8; keep the raw array\n            self.metadata[key] = value\n\n        for tensor in reader.tensors:\n            self.tensor_info[tensor.name] = {\n                \"shape\": list(reversed(tensor.shape)),  # Reverse to match PyTorch order\n                \"dtype\": tensor.tensor_type,\n                \"offset\": tensor.data_offset,\n                \"n_elements\": tensor.n_elements,\n            }\n            self.tensor_file_map[tensor.name] = file_path\n\n        self.file_data_map[file_path] = np.memmap(file_path, mode=\"r\")\n\n    def _load_directory(self, dir_path: str):\n        \"\"\"Load all GGUF files from a directory (non-recursive)\"\"\"\n        found_gguf = False\n\n        for file in sorted(os.listdir(dir_path)):\n            if file.endswith(\".gguf\"):\n                found_gguf = True\n                file_path = os.path.join(dir_path, file)\n                print(f\"  Loading: {file}\")\n\n                reader = GGUFReader(file_path)\n\n                for key, field in reader.fields.items():\n                    value = field.parts[field.data[0]]\n                    if isinstance(value, bytes):\n                        value = value.decode(\"utf-8\")\n                    elif isinstance(value, np.ndarray) and value.dtype == np.uint8:\n                        try:\n                            value = bytes(value).decode(\"utf-8\")\n                        except UnicodeDecodeError:\n                            pass  # Not valid UTF-8; keep the raw array\n                    self.metadata[key] = value\n\n                for tensor in reader.tensors:\n                    self.tensor_info[tensor.name] = {\n                        \"shape\": list(reversed(tensor.shape)),\n                        
\"dtype\": tensor.tensor_type,\n                        \"offset\": tensor.data_offset,\n                        \"n_elements\": tensor.n_elements,\n                    }\n                    self.tensor_file_map[tensor.name] = file_path\n\n                self.file_data_map[file_path] = np.memmap(file_path, mode=\"r\")\n\n        if not found_gguf:\n            raise FileNotFoundError(f\"No .gguf files found in directory: {dir_path}\")\n\n    def get_model_config(self, layer_idx: int = 0):\n        \"\"\"\n        Extract model configuration from GGUF metadata and tensor shapes.\n\n        Args:\n            layer_idx: Layer index to inspect (default: 0)\n\n        Returns:\n            dict with keys: num_experts, num_experts_per_tok, hidden_size, moe_intermediate_size\n        \"\"\"\n        config = {}\n\n        arch = self.metadata.get(\"general.architecture\", \"unknown\")\n\n        # Note: \"expert_feed_forward_length\" is an FFN width, not an expert count,\n        # so it is only consulted for moe_intermediate_size below.\n        num_experts = None\n        for key_suffix in [\n            \"expert_count\",\n            \"expert.count\",\n            \"moe.expert_count\",\n        ]:\n            key = f\"{arch}.{key_suffix}\"\n            if key in self.metadata:\n                val = self.metadata[key]\n                num_experts = int(val[0]) if isinstance(val, (list, np.ndarray)) else int(val)\n                break\n\n        num_experts_per_tok = None\n        for key_suffix in [\n            \"expert_used_count\",\n            \"expert.used_count\",\n            \"moe.num_experts_per_tok\",\n        ]:\n            key = f\"{arch}.{key_suffix}\"\n            if key in self.metadata:\n                val = self.metadata[key]\n                num_experts_per_tok = int(val[0]) if isinstance(val, (list, np.ndarray)) else int(val)\n                break\n\n        hidden_size = None\n        for key_suffix in [\n            \"embedding_length\",\n            \"embed_length\",\n            \"hidden_size\",\n        ]:\n            key = 
f\"{arch}.{key_suffix}\"\n            if key in self.metadata:\n                val = self.metadata[key]\n                hidden_size = int(val[0]) if isinstance(val, (list, np.ndarray)) else int(val)\n                break\n\n        moe_intermediate_size = None\n        for key_suffix in [\n            \"expert_feed_forward_length\",\n            \"feed_forward_length\",\n            \"ffn_length\",\n            \"intermediate_size\",\n        ]:\n            key = f\"{arch}.{key_suffix}\"\n            if key in self.metadata:\n                val = self.metadata[key]\n                moe_intermediate_size = int(val[0]) if isinstance(val, (list, np.ndarray)) else int(val)\n                break\n\n        if any(v is None for v in [num_experts, hidden_size, moe_intermediate_size]):\n\n            base_key = f\"blk.{layer_idx}.ffn_gate_exps.weight\"\n            if base_key in self.tensor_info:\n                gate_shape = self.tensor_info[base_key][\"shape\"]\n                print(f\"  Found tensor '{base_key}' with shape: {gate_shape}\")\n\n                if len(gate_shape) >= 3:\n                    if num_experts is None:\n                        num_experts = int(gate_shape[0])\n                    if moe_intermediate_size is None:\n                        moe_intermediate_size = int(gate_shape[1])\n                    if hidden_size is None:\n                        hidden_size = int(gate_shape[2])\n\n        config = {\n            \"num_experts\": num_experts,\n            \"num_experts_per_tok\": num_experts_per_tok,\n            \"hidden_size\": hidden_size,\n            \"moe_intermediate_size\": moe_intermediate_size,\n        }\n\n        return config\n\n    def print_metadata(self, filter_keywords=None):\n        \"\"\"\n        Print GGUF file metadata for debugging.\n\n        Args:\n            filter_keywords: Optional list of keywords to filter metadata keys\n        \"\"\"\n        print(f\"\\n[GGUFLoader] GGUF Metadata:\")\n        
print(f\"  Total metadata entries: {len(self.metadata)}\")\n\n        if filter_keywords:\n            filtered = {\n                k: v for k, v in self.metadata.items() if any(kw.lower() in k.lower() for kw in filter_keywords)\n            }\n            for k, v in sorted(filtered.items()):\n                print(f\"  {k}: {v}\")\n        else:\n            for k, v in sorted(self.metadata.items()):\n                print(f\"  {k}: {v}\")\n\n    def has_tensor(self, name: str):\n        \"\"\"Check if tensor exists\"\"\"\n        name = translate_name_to_gguf(name)\n        return name in self.tensor_info\n\n    def get_ggml_type(self, name: str):\n        \"\"\"Get GGML type of a tensor\"\"\"\n        name = translate_name_to_gguf(name)\n        if name not in self.tensor_info:\n            raise KeyError(f\"Tensor '{name}' not found in GGUF files\")\n        return self.tensor_info[name][\"dtype\"]\n\n    def get_undequanted_tensor_and_ggml_type(self, name: str):\n        \"\"\"\n        Get tensor data and its GGML type without dequantizing\n\n        Args:\n            name: Tensor name (in PyTorch format, will be translated to GGUF format)\n\n        Returns:\n            (data, ggml_type): Tuple of tensor data and GGML quantization type\n        \"\"\"\n        name = translate_name_to_gguf(name)\n\n        if name not in self.tensor_info:\n            raise KeyError(f\"Tensor '{name}' not found in GGUF files\")\n\n        info = self.tensor_info[name]\n        file_path = self.tensor_file_map[name]\n        mmap_data = self.file_data_map[file_path]\n\n        offset = info[\"offset\"]\n        n_elements = info[\"n_elements\"]\n        ggml_type = info[\"dtype\"]\n\n        GGML_QUANT_SIZES = {\n            GGMLQuantizationType.F32: (1, 4),\n            GGMLQuantizationType.F16: (1, 2),\n            GGMLQuantizationType.BF16: (1, 2),\n            GGMLQuantizationType.Q4_0: (32, 2 + 16),\n            GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16),\n          
  GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16),\n            GGMLQuantizationType.Q5_1: (32, 2 + 2 + 4 + 16),\n            GGMLQuantizationType.Q8_0: (32, 2 + 32),\n            GGMLQuantizationType.Q8_1: (32, 4 + 4 + 32),\n            GGMLQuantizationType.Q2_K: (256, 2 + 2 + 256 // 16 + 256 // 4),\n            GGMLQuantizationType.Q3_K: (256, 2 + 256 // 4 + 256 // 8 + 12),\n            GGMLQuantizationType.Q4_K: (256, 2 + 2 + 256 // 2 + 12),\n            GGMLQuantizationType.Q5_K: (256, 2 + 2 + 256 // 2 + 256 // 8 + 12),\n            GGMLQuantizationType.Q6_K: (256, 2 + 256 // 2 + 256 // 4 + 256 // 16),\n            GGMLQuantizationType.Q8_K: (256, 4 + 256 + 256 // 8),\n            GGMLQuantizationType.IQ2_XXS: (256, 2 + 256 // 4),\n            GGMLQuantizationType.IQ2_XS: (256, 2 + 256 // 4 + 256 // 32),\n            GGMLQuantizationType.IQ3_XXS: (256, 2 + 256 // 4 + 256 // 8),\n            GGMLQuantizationType.IQ1_S: (256, 2 + 256 // 8 + 256 // 16),\n            GGMLQuantizationType.IQ4_NL: (32, 2 + 16),\n            GGMLQuantizationType.IQ3_S: (256, 2 + 256 // 4 + 256 // 8 + 256 // 32 + 4),\n            GGMLQuantizationType.IQ2_S: (256, 2 + 256 // 4 + 256 // 16),\n            GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + 256 // 2 + 256 // 64),\n            GGMLQuantizationType.I8: (1, 1),\n            GGMLQuantizationType.I16: (1, 2),\n            GGMLQuantizationType.I32: (1, 4),\n            GGMLQuantizationType.I64: (1, 8),\n            GGMLQuantizationType.F64: (1, 8),\n            GGMLQuantizationType.IQ1_M: (256, 256 // 8 + 256 // 16 + 256 // 32),\n        }\n\n        block_size, type_size = GGML_QUANT_SIZES[ggml_type]\n        n_bytes = n_elements * type_size // block_size\n\n        data_bytes = mmap_data[offset : offset + n_bytes]\n        data = torch.from_numpy(np.frombuffer(data_bytes, dtype=np.uint8).copy())\n\n        return data, ggml_type\n"
  },
  {
    "path": "kt-kernel/python/utils/moe_kernel.py",
    "content": "import os\nimport torch\nimport ctypes\nfrom typing import Optional\n\n# Use relative imports for package structure\nfrom ..experts_base import BaseMoEWrapper\nfrom .loader import SafeTensorLoader\nfrom kt_kernel_ext.moe import MOEConfig\n\ntry:\n    from kt_kernel_ext.moe import Int8_KERNEL_MOE\n\n    _HAS_INT8_SUPPORT = True\nexcept (ImportError, AttributeError):\n    Int8_KERNEL_MOE = None\n    _HAS_INT8_SUPPORT = False\ntry:\n    from kt_kernel_ext.moe import Int4_KERNEL_MOE\n\n    _HAS_INT4_SUPPORT = True\nexcept (ImportError, AttributeError):\n    Int4_KERNEL_MOE = None\n    _HAS_INT4_SUPPORT = False\n\nfrom typing import Optional\n\n\nclass GeneralMoEWrapper(BaseMoEWrapper):\n    \"\"\"\n    moe-based MoE wrapper implementation.\n    Supports MOE_INT4 and MOE_INT8 quantization methods.\n    \"\"\"\n\n    _safetensor_loader_instance = None  # Singleton SafeTensorLoader\n\n    def __init__(\n        self,\n        layer_idx: int,\n        num_experts: int,\n        num_experts_per_tok: int,\n        hidden_size: int,\n        moe_intermediate_size: int,\n        gpu_experts_mask: Optional[torch.Tensor],\n        cpuinfer_threads: int,\n        threadpool_count: int,\n        weight_path: str,\n        chunked_prefill_size: int,\n        cpu_save: bool = False,\n        max_deferred_experts_per_token: Optional[int] = None,\n        method: str = \"MOE_INT8\",\n    ):\n        \"\"\"\n        Initialize general MoE Wrapper.\n\n        Args:\n            layer_idx: Layer index\n            num_experts: Total number of experts\n            num_experts_per_tok: Number of experts per token (top-k)\n            hidden_size: Hidden dimension size\n            moe_intermediate_size: MoE intermediate size\n            gpu_experts_mask: Boolean mask indicating which experts are on GPU.\n                              Shape: [num_experts], dtype: torch.bool.\n                              mask[i] = True means expert i is on GPU.\n                            
  If None, all experts are on CPU.\n            cpuinfer_threads: Number of CPU inference threads\n            threadpool_count: Number of NUMA subpools\n            weight_path: Path to weights (SafeTensor format)\n            chunked_prefill_size: Maximum prefill chunk size\n            cpu_save: Whether to save weights to CPU memory\n            max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0.\n            method: general quantization method (\"MOE_INT4\" or \"MOE_INT8\")\n        \"\"\"\n        if not _HAS_INT4_SUPPORT and method == \"MOE_INT4\":\n            raise RuntimeError(\n                \"MoE_INT4 backend not available. kt_kernel_ext was not compiled with int4 support.\\n\"\n                \"Please recompile with int4 enabled.\"\n            )\n        if not _HAS_INT8_SUPPORT and method == \"MOE_INT8\":\n            raise RuntimeError(\n                \"MoE_INT8 backend not available. kt_kernel_ext was not compiled with int8 support.\\n\"\n                \"Please recompile with int8 enabled.\"\n            )\n\n        # Initialize base class\n        super().__init__(\n            layer_idx=layer_idx,\n            num_experts=num_experts,\n            num_experts_per_tok=num_experts_per_tok,\n            hidden_size=hidden_size,\n            moe_intermediate_size=moe_intermediate_size,\n            gpu_experts_mask=gpu_experts_mask,\n            cpuinfer_threads=cpuinfer_threads,\n            threadpool_count=threadpool_count,\n            weight_path=weight_path,\n            chunked_prefill_size=chunked_prefill_size,\n            cpu_save=cpu_save,\n            max_deferred_experts_per_token=max_deferred_experts_per_token,\n            method=method,\n        )\n\n        # moe-specific: Check if we should load merged safetensor weights\n        self.load_merged_weight = False\n        import glob\n\n        if glob.glob(os.path.join(weight_path, \"*.safetensors\")):\n            self.load_merged_weight = 
True\n\n        # Initialize SafeTensor loader (singleton)\n        if self.load_merged_weight:\n            if GeneralMoEWrapper._safetensor_loader_instance is None:\n                GeneralMoEWrapper._safetensor_loader_instance = SafeTensorLoader(weight_path)\n            self.safetensor_loader = GeneralMoEWrapper._safetensor_loader_instance\n\n        # moe-specific weight storage\n        self.gate_weights = None\n        self.up_weights = None\n        self.down_weights = None\n        self.gate_scales = None\n        self.up_scales = None\n        self.down_scales = None\n\n    def load_weights_from_tensors(\n        self,\n        gate_proj: torch.Tensor,\n        up_proj: torch.Tensor,\n        down_proj: torch.Tensor,\n        physical_to_logical_map_cpu: torch.Tensor,\n    ):\n        \"\"\"\n        Load and quantize weights from BF16/FP16 tensors (online quantization).\n\n        Args:\n            gate_proj: Gate projection weights [num_experts, intermediate_size, hidden_size]\n            up_proj: Up projection weights [num_experts, intermediate_size, hidden_size]\n            down_proj: Down projection weights [num_experts, hidden_size, intermediate_size]\n            physical_to_logical_map_cpu: Mapping from physical to logical expert IDs\n        \"\"\"\n        # Store tensors as instance variables to keep them alive\n        self.gate_proj = gate_proj.contiguous()\n        self.up_proj = up_proj.contiguous()\n        self.down_proj = down_proj.contiguous()\n\n        # Configure MoE with online quantization (cpu_save mode)\n        moe_config = MOEConfig(\n            self.num_experts,\n            self.num_experts_per_tok,\n            self.hidden_size,\n            self.moe_intermediate_size,\n            self.gpu_experts_mask.data_ptr(),\n        )\n        moe_config.layer_idx = self.layer_idx\n        moe_config.pool = self.cpu_infer.backend_\n        moe_config.max_len = self.chunked_prefill_size\n\n        # Enable save mode for online 
quantization\n        moe_config.save = True\n        moe_config.load = False\n\n        # Set weight pointers\n        moe_config.gate_proj = self.gate_proj.data_ptr()\n        moe_config.up_proj = self.up_proj.data_ptr()\n        moe_config.down_proj = self.down_proj.data_ptr()\n\n        # Set output path for quantized weights\n        moe_config.path = self.weight_path\n\n        # Create MoE module based on method\n        if self.method == \"MOE_INT4\":\n            self.moe = Int4_KERNEL_MOE(moe_config)\n        elif self.method == \"MOE_INT8\":\n            self.moe = Int8_KERNEL_MOE(moe_config)\n        else:\n            raise NotImplementedError(f\"Unsupported MoE method: {self.method}\")\n\n        # Submit quantization and save task\n        self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))\n        self.cpu_infer.sync()\n\n    def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):\n        \"\"\"\n        Load weights for this layer and initialize the MoE module.\n\n        Args:\n            physical_to_logical_map_cpu: Mapping from physical to logical expert IDs\n        \"\"\"\n        gate_ptr = 0\n        up_ptr = 0\n        down_ptr = 0\n\n        gate_ptrs = []\n        up_ptrs = []\n        down_ptrs = []\n\n        gate_scale_ptrs = []\n        up_scale_ptrs = []\n        down_scale_ptrs = []\n\n        if self.load_merged_weight:\n            base_key = f\"blk.{self.layer_idx}\"\n            w = self.safetensor_loader.load_experts(base_key)\n\n            self.gate_weights = w[\"gate\"]\n            self.up_weights = w[\"up\"]\n            self.down_weights = w[\"down\"]\n            self.gate_scales = w[\"gate_scale\"]\n            self.up_scales = w[\"up_scale\"]\n            self.down_scales = w[\"down_scale\"]\n\n            # Get pointers to weight arrays\n            gate_ptrs = [\n                [\n                    ctypes.addressof(ctypes.cast(et.ctypes.data, 
ctypes.POINTER(ctypes.c_uint64)).contents)\n                    for et in numa_array\n                ]\n                for numa_array in self.gate_weights\n            ]\n\n            up_ptrs = [\n                [\n                    ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)\n                    for et in numa_array\n                ]\n                for numa_array in self.up_weights\n            ]\n\n            down_ptrs = [\n                [\n                    ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)\n                    for et in numa_array\n                ]\n                for numa_array in self.down_weights\n            ]\n\n            gate_scale_ptrs = [\n                [\n                    ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)\n                    for et in numa_array\n                ]\n                for numa_array in self.gate_scales\n            ]\n\n            up_scale_ptrs = [\n                [\n                    ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)\n                    for et in numa_array\n                ]\n                for numa_array in self.up_scales\n            ]\n\n            down_scale_ptrs = [\n                [\n                    ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)\n                    for et in numa_array\n                ]\n                for numa_array in self.down_scales\n            ]\n\n        # Configure MoE\n        moe_config = MOEConfig(\n            self.num_experts,\n            self.num_experts_per_tok,\n            self.hidden_size,\n            self.moe_intermediate_size,\n            self.gpu_experts_mask.data_ptr(),\n        )\n        moe_config.layer_idx = self.layer_idx\n        moe_config.pool = self.cpu_infer.backend_\n        moe_config.max_len = 
self.chunked_prefill_size\n\n        moe_config.gate_proj = gate_ptr\n        moe_config.up_proj = up_ptr\n        moe_config.down_proj = down_ptr\n        moe_config.gate_projs = gate_ptrs\n        moe_config.up_projs = up_ptrs\n        moe_config.down_projs = down_ptrs\n        moe_config.gate_scales = gate_scale_ptrs\n        moe_config.up_scales = up_scale_ptrs\n        moe_config.down_scales = down_scale_ptrs\n\n        if self.cpu_save:\n            moe_config.save = True\n            moe_config.load = False\n            base_key = f\"model.layers.{self.layer_idx}\"\n            w = self.safetensor_loader.load_experts(base_key)\n\n            self.gate_proj = torch.cat(w[\"gate_weight\"], dim=0).contiguous()\n            self.up_proj = torch.cat(w[\"up_weight\"], dim=0).contiguous()\n            self.down_proj = torch.cat(w[\"down_weight\"], dim=0).contiguous()\n\n            moe_config.gate_proj = self.gate_proj.data_ptr()\n            moe_config.up_proj = self.up_proj.data_ptr()\n            moe_config.down_proj = self.down_proj.data_ptr()\n        else:\n            moe_config.load = True\n\n        if not self.load_merged_weight:\n            moe_config.path = self.weight_path\n\n        # Create MoE module based on moe method\n        if self.method == \"MOE_INT4\":\n            self.moe = Int4_KERNEL_MOE(moe_config)\n        elif self.method == \"MOE_INT8\":\n            self.moe = Int8_KERNEL_MOE(moe_config)\n        else:\n            raise NotImplementedError(f\"Unsupported MoE method: {self.method}\")\n\n        # Load weights\n        self.cpu_infer.submit(self.moe.load_weights_task(physical_to_logical_map_cpu.data_ptr()))\n        self.cpu_infer.sync()\n\n        # Clean up temporary weight storage if using merged weights\n        if self.load_merged_weight:\n            del self.gate_weights\n            del self.up_weights\n            del self.down_weights\n            del self.gate_scales\n            del self.up_scales\n            del 
self.down_scales\n"
  },
  {
    "path": "kt-kernel/requirements.txt",
    "content": "# Optional: Install these if not already available in your environment\n# These dependencies will be automatically installed when running `pip install .`\n# You can skip this file if you already have these packages installed\n\n# Core dependencies (minimum versions)\ntorch>=2.0.0\nsafetensors>=0.4.0\ncompressed-tensors>=0.7.0\nnumpy>=1.24.0\ntriton>=2.0.0\ngguf>=0.17.0\n# Development dependencies\nblack>=25.9.0\n"
  },
  {
    "path": "kt-kernel/scripts/README.md",
    "content": "# Weight Quantization Tools\n\nKT-Kernel provides weight conversion tools for CPU-GPU hybrid inference (e.g., integrating KTransformers with SGLang). Both tools work together to enable heterogeneous expert placement:\n\n- **CPU Weights (`convert_cpu_weights.py`)**: Quantize weights to INT4/INT8 with AMX optimization for CPU-resident \"cold\" experts\n- **GPU Weights (`convert_gpu_weights.py`)**: Apply GPTQ/RTN quantization (W4A16/W8A16) for GPU-resident \"hot\" experts\n\n---\n\n## CPU Weight Quantization\n\nConvert weights to INT4/INT8 format optimized for AMX inference on CPU. These quantized weights are used for \"cold\" experts (less frequently accessed) that run on CPU in hybrid inference scenarios.\n\n### Quantization Methods\n\n- **INT4**: 4-bit quantization for maximum memory efficiency\n- **INT8**: 8-bit quantization for better accuracy\n\n### Supported Input Formats\n\n- **FP8**: 8-bit floating point with automatic dequantization\n- **FP16**: 16-bit floating point\n- **BF16**: BFloat16 format\n\n> **⚠️ Precision Warning:** Quantizing directly from FP8 to INT4/INT8 may cause significant accuracy degradation. 
For best results, use the original **BF16** model as the source for INT4/INT8 quantization.\n\n## Basic Usage\n\n### Quantize BF16 model to INT4\n\n```bash\npython scripts/convert_cpu_weights.py \\\n  --input-path /path/to/bf16/model \\\n  --input-type bf16 \\\n  --output /path/to/output \\\n  --quant-method int4\n```\n\n### Quantize FP16 model to INT8\n\n```bash\npython scripts/convert_cpu_weights.py \\\n  --input-path /path/to/fp16/model \\\n  --input-type fp16 \\\n  --output /path/to/output \\\n  --quant-method int8\n```\n\n### Quantize FP8 model to INT4\n\n```bash\npython scripts/convert_cpu_weights.py \\\n  --input-path /path/to/fp8/model \\\n  --input-type fp8 \\\n  --output /path/to/output \\\n  --quant-method int4\n```\n\n## Output Format\n\nBy default, the converted weights are saved in SafeTensors format with NUMA-aware layout:\n\n```\noutput_dir/\n├── model-00001-of-00050.safetensors\n├── model-00002-of-00050.safetensors\n├── ...\n├── config.json\n└── tokenizer files...\n```\n\nEach expert's weights are split across NUMA nodes for optimal memory access:\n- `blk.{layer}.ffn_{proj}_exps.{expert}.numa.{numa_idx}.weight`: Quantized weights\n- `blk.{layer}.ffn_{proj}_exps.{expert}.numa.{numa_idx}.scale`: Quantization scales\n\n## Advanced Options\n\n### Low Memory Mode\n\nFor systems with insufficient memory to complete full model quantization, use the `--no-merge-safetensor` flag to keep weights in layer folder structure without merging into safetensor files:\n\n```bash\npython scripts/convert_cpu_weights.py \\\n  --input-path /path/to/model \\\n  --input-type bf16 \\\n  --output /path/to/output \\\n  --quant-method int4 \\\n  --no-merge-safetensor\n```\n\nThis will save quantized weights in the following folder structure:\n\n```\noutput_dir/\n├── _layer_0/\n│   ├── _numa_0/\n│   │   ├── INT4_down_0_*.kt\n│   │   ├── INT4_gate_0_*.kt\n│   │   └── INT4_up_0_*.kt\n│   └── _numa_1/\n│       └── ...\n├── _layer_1/\n│   └── ...\n└── ...\n```\n\n**When to use 
`--no-merge-safetensor`:**\n- Machine runs out of memory during the merge step\n- Need to process very large models on memory-constrained systems\n- Want to preserve intermediate layer-wise quantized weights\n\n### Resume Layer\n\nFor memory-constrained systems that are unable to complete quantization despite enabling low memory mode with `--no-merge-safetensor`, restart the script with the `--resume-layer` argument to specify the layer from which to continue the conversion process. In the example below, we skip layers 0-11 and resume conversion starting with layer 12.\n\n```bash\npython scripts/convert_cpu_weights.py \\\n  --input-path /path/to/model \\\n  --input-type bf16 \\\n  --output /path/to/output \\\n  --quant-method int4 \\\n  --no-merge-safetensor \\\n  --resume-layer 12\n```\n\n## Examples\n\n### Example 1: Quantize DeepSeek-V3.1 (FP8 → INT4)\n\n```bash\npython scripts/convert_cpu_weights.py \\\n  --input-path /mnt/data/models/DeepSeek-V3.1 \\\n  --input-type fp8 \\\n  --output /mnt/data/models/DeepSeek-V3.1-INT4 \\\n  --quant-method int4 \\\n  --cpuinfer-threads 60 \\\n  --threadpool-count 2\n```\n\n### Example 2: Quantize Qwen3-Next-80B (BF16 → INT4, Low Memory)\n\n```bash\npython scripts/convert_cpu_weights.py \\\n  --input-path /mnt/data/models/Qwen3-Next-80B-A3B-Instruct \\\n  --input-type bf16 \\\n  --output /mnt/data/models/Qwen3-Next-80B-A3B-Instruct-INT4 \\\n  --quant-method int4 \\\n  --cpuinfer-threads 60 \\\n  --threadpool-count 2 \\\n  --no-merge-safetensor\n```\n\n---\n\n## GPU Weight Quantization\n\n### Prerequisites\n\nGPU weight quantization requires additional dependencies. 
Install them before proceeding:\n\n```bash\npip install accelerate transformers llmcompressor datasets\n```\n\n**Required packages:**\n- `accelerate`: For distributed model loading and device mapping\n- `transformers`: For model and tokenizer loading\n- `llmcompressor`: For quantization (supports GPTQ and RTN methods)\n- `datasets`: For calibration data loading (GPTQ only)\n\n**Documentation:** This tool is based on llmcompressor. For more details, see [llmcompressor quantization guide](https://docs.vllm.ai/projects/llm-compressor/en/latest/getting-started/compress/#select-a-quantization-method-and-scheme).\n\n### Overview\n\nApply weight quantization to model weights for GPU-resident \"hot\" experts (frequently accessed) in CPU-GPU hybrid inference. This tool works together with `convert_cpu_weights.py` to enable heterogeneous expert placement:\n\n- **GPU-resident experts** (\"hot\" experts) use GPTQ/RTN quantization (this tool) for efficient GPU memory usage\n- **CPU-resident experts** (\"cold\" experts) use AMX-optimized INT4/INT8 quantization (convert_cpu_weights.py)\n- **Attention layers, gates, and shared experts** remain in higher precision\n\nThis approach maximizes throughput and resource utilization by intelligently distributing experts across CPUs and GPUs.\n\n### Quantization Methods\n\n#### 1. GPTQ (Calibration-based, Default)\n**Pros:**\n- Higher accuracy through calibration-based quantization\n- Recommended for production deployments\n\n**Cons:**\n- Requires calibration dataset\n- Slower quantization process\n- Higher memory requirements (needs Hessian matrix)\n\n#### 2. 
RTN (Round-To-Nearest)\n**Pros:**\n- Fast quantization (no calibration needed)\n- Lower memory requirements\n- Good for quick testing and prototyping\n\n**Cons:**\n- Slightly lower accuracy compared to GPTQ\n- No calibration optimization\n\n### Quantization Types\n\n- **W4A16**: 4-bit weights, 16-bit activations (INT4)\n- **W8A16**: 8-bit weights, 16-bit activations (INT8)\n\n### Basic Usage\n\n#### GPTQ Quantization (Recommended for Production)\n```bash\npython scripts/convert_gpu_weights.py \\\n  --model_id /path/to/model \\\n  --output_dir /path/to/output \\\n  --quant_method GPTQ \\\n  --quant_type W4A16\n```\n\n#### RTN Quantization (Fast, for Testing)\n```bash\npython scripts/convert_gpu_weights.py \\\n  --model_id /path/to/model \\\n  --output_dir /path/to/output \\\n  --quant_method RTN \\\n  --quant_type W4A16\n```\n\n### Memory Requirements\n\nUnderstanding memory requirements is crucial for successful quantization. The requirements differ significantly between RTN and GPTQ methods.\n\n#### RTN Memory Requirements\n\nRTN only requires memory for quantization parameters (scales/zero-points):\n\n| Component | Requirement |\n|-----------|-------------|\n| **DRAM (CPU Memory)** | ≥ Total model parameters |\n| **VRAM (GPU Memory)** | ≥ Single layer parameters |\n\n**Example: DeepSeek-R1-0528-BF16 (684B parameters)**\n- DRAM: ~1368 GB (684B params × 2 bytes)\n- VRAM: ~22.4 GB (1 layer)\n\n#### GPTQ Memory Requirements\n\nGPTQ requires additional memory for Hessian matrices during calibration:\n\n| Component | Requirement |\n|-----------|-------------|\n| **DRAM (CPU Memory)** | ≥ Total model parameters |\n| **VRAM (GPU Memory)** | ≥ Single layer parameters × 2 |\n\nThe Hessian matrix is approximately the same size as the layer weights and is used to increase accuracy recovery.\n\n**Example: DeepSeek-R1-0528-BF16 (684B parameters)**\n- DRAM: ~1368 GB (684B params × 2 bytes)\n- VRAM: ~44.8 GB (1 layer × 2 for Hessian matrix)\n\n#### Method Comparison\n\n| Method 
| Speed | VRAM | Accuracy | Use Case |\n|--------|-------|------|----------|----------|\n| **RTN** | Fast | Low (~22GB) | Good | Testing, prototyping |\n| **GPTQ** | Slow | High (~45GB) | Better | Production deployment |\n\n### Advanced Options\n\n#### Calibration Configuration (GPTQ Only)\n\nFor GPTQ quantization, control the calibration process for better quantization quality:\n\n```bash\npython scripts/convert_gpu_weights.py \\\n  --model_id /path/to/model \\\n  --output_dir /path/to/output \\\n  --quant_method GPTQ \\\n  --quant_type W4A16 \\\n  --num_calibration_samples 512 \\\n  --max_sequence_length 2048 \\\n  --dataset HuggingFaceH4/ultrachat_200k \\\n  --dataset_split train_sft\n```\n\n**Options (GPTQ only):**\n- `--num_calibration_samples`: Number of samples for calibration (default: 512)\n- `--max_sequence_length`: Maximum sequence length (default: 2048)\n- `--dataset`: HuggingFace dataset for calibration\n- `--dataset_split`: Dataset split to use\n- `--dampening_frac`: Dampening fraction to reduce quantization noise (default: 0.1)\n\n#### Memory Management\n\nUse `--max_gpu_memory` to limit GPU memory usage and offload remaining layers to CPU:\n\n```bash\npython scripts/convert_gpu_weights.py \\\n  --model_id /path/to/model \\\n  --output_dir /path/to/output \\\n  --quant_method GPTQ \\\n  --quant_type W4A16 \\\n  --max_gpu_memory \"40GiB\"\n```\n\n**Recommended settings for GPTQ:**\n\n| GPU VRAM | Suggested `--max_gpu_memory` | Notes |\n|----------|------------------------------|-------|\n| 24 GiB   | 10-12 GiB | Reserve ~50% for Hessian |\n| 48 GiB   | 24-30 GiB | Reserve ~40% for Hessian |\n| 80 GiB   | 40-50 GiB | Reserve ~40% for Hessian |\n\n**Recommended settings for RTN:**\n\n| GPU VRAM | Suggested `--max_gpu_memory` | Notes |\n|----------|------------------------------|-------|\n| 24 GiB   | 18-20 GiB | No Hessian needed |\n| 48 GiB   | 40-45 GiB | No Hessian needed |\n| 80 GiB   | 70-75 GiB | No Hessian needed |\n\n**Options:**\n- 
`--max_gpu_memory`: Maximum GPU memory for model weights per device (e.g., '40GiB')\n- `--max_cpu_memory`: Maximum CPU memory (default: 1000GiB when `--max_gpu_memory` is set)\n\n**Important:** llmcompressor does not support disk offloading. Ensure your machine has enough GPU + CPU memory to load the entire model. If you still encounter OOM:\n1. Use RTN instead of GPTQ (requires less memory)\n2. Reduce `--num_calibration_samples` (GPTQ only, e.g., 256)\n3. Reduce `--max_sequence_length` (GPTQ only, e.g., 1024)\n4. Use `--force_cpu` to run entirely on CPU (slower but avoids GPU OOM)\n\n### Examples\n\n#### Example 1: GPTQ Quantization for Production (Qwen3-Next-80B, W4A16)\n\n```bash\npython scripts/convert_gpu_weights.py \\\n  --model_id /mnt/data/models/Qwen3-Next-80B-A3B-Instruct \\\n  --output_dir /mnt/data/models/Qwen3-Next-80B-A3B-Instruct-GPTQ-W4A16 \\\n  --quant_method GPTQ \\\n  --quant_type W4A16 \\\n  --num_calibration_samples 512 \\\n  --max_sequence_length 2048 \\\n  --max_gpu_memory \"40GiB\" \\\n  --trust_remote_code\n```\n\n#### Example 2: RTN Quantization for Fast Testing (DeepSeek-R1, W4A16)\n\n```bash\npython scripts/convert_gpu_weights.py \\\n  --model_id /mnt/data/models/DeepSeek-R1-0528-BF16 \\\n  --output_dir /mnt/data/models/DeepSeek-R1-0528-RTN-W4A16 \\\n  --quant_method RTN \\\n  --quant_type W4A16 \\\n  --max_gpu_memory \"70GiB\" \\\n  --trust_remote_code\n```\n\n#### Example 3: GPTQ with Custom Calibration Dataset (GLM-4.5-Air, W8A16)\n\n```bash\npython scripts/convert_gpu_weights.py \\\n  --model_id /mnt/data/models/GLM-4.5-Air \\\n  --output_dir /mnt/data/models/GLM-4.5-Air-GPTQ-W8A16 \\\n  --quant_method GPTQ \\\n  --quant_type W8A16 \\\n  --dataset \"tatsu-lab/alpaca\" \\\n  --dataset_split \"train\" \\\n  --num_calibration_samples 256 \\\n  --max_gpu_memory \"40GiB\" \\\n  --trust_remote_code\n```\n"
  },
  {
    "path": "kt-kernel/scripts/check.py",
    "content": "import os\n\n# insert the path of the project\nimport sys\n\n# sys.path.insert(0, \"/home/azure/ktransformers\")\nimport argparse\nimport torch\nfrom safetensors import safe_open\nfrom safetensors.torch import save_file\nimport re\nfrom collections import defaultdict\nimport itertools\nimport os\nimport torch\nimport numpy as np\n\ntensor_from_amx = [\".mlp.experts.\"]  # todo: add keys in gguf that should be used in the final tensor\n\n\ndef safe_open_binary_to_tensor(file_path):\n    if not os.path.exists(file_path):\n        raise FileNotFoundError(f\"文件不存在: {file_path}\")\n\n    if not os.access(file_path, os.R_OK):\n        raise PermissionError(f\"没有权限读取文件: {file_path}\")\n\n    try:\n        with open(file_path, \"rb\") as f:\n            binary_data = f.read()\n\n        np_array = np.frombuffer(binary_data, dtype=np.int8)\n\n        tensor = torch.from_numpy(np_array)\n\n        return tensor\n\n    except Exception as e:\n        raise IOError(f\"file process error: {str(e)}\")\n\n\ndef read_safetensor_keys_from_folder(folder_path) -> dict:\n    \"\"\"\n    :param folder_path: folder path\n    :return: key_to_file_map\n    \"\"\"\n    # check if the folder path is exist\n    if not os.path.exists(folder_path):\n        raise FileNotFoundError(f\"GGUF dir not found: {folder_path}\")\n    if os.path.isfile(folder_path):\n        folder_path = os.path.dirname(folder_path)\n\n    key_to_file_map = {}\n\n    found_safetensor = False\n    for root, dirs, files in os.walk(folder_path):\n        # sort files\n        files = sorted(files)\n        for file in files:\n            if file.endswith(\".safetensors\"):\n                found_safetensor = True\n                file_path = os.path.join(root, file)\n                try:\n                    with safe_open(file_path, framework=\"pt\") as f:\n                        for key in f.keys():\n                            if \"model.layers.61\" in key:\n                                # skip MTP 
layer\n                                continue\n                            # try:\n                            #     if int(key.split('.')[2]) > 4:\n                            #         continue\n                            # except:\n                            #     pass\n                            key_to_file_map[key] = file_path\n                except Exception as e:\n                    print(f\"Error reading Safetensor file {file_path}: {e}\")\n\n    if not found_safetensor:\n        raise FileNotFoundError(f\"No Safetensor files found in {folder_path}\")\n\n    return key_to_file_map\n\n\ndef read_amx_tensor_from_folder(folder_path, keys) -> dict:\n    layer_list = [f\"_layer_{i}\" for i in range(3, 61)]\n    numa_list = [\"_numa_0\", \"_numa_1\"]\n\n    down_list = [f\"INT4_down_{i}_quant_.kt\" for i in range(256)]\n    gate_list = [f\"INT4_gate_{i}_quant_.kt\" for i in range(256)]\n    up_list = [f\"INT4_up_{i}_quant_.kt\" for i in range(256)]\n    down_scale_list = [f\"INT4_down_{i}_scale_.kt\" for i in range(256)]\n    gate_scale_list = [f\"INT4_gate_{i}_scale_.kt\" for i in range(256)]\n    up_scale_list = [f\"INT4_up_{i}_scale_.kt\" for i in range(256)]\n    target = [\"ffn_up_exps\", \"ffn_down_exps\", \"ffn_gate_exps\"]\n    tensor_file_map = {}\n    for key in keys:\n        layer = int(key.split(\".\")[1])\n        if layer < 3:\n            continue\n        layer_path = f\"_layer_{layer}\"\n        # concatenate the path layer/numa/(down|gate|up)_(0-255)_3670016Byte_quant_.kt\n        # store the path in the tensor_file_map\n        # key = key+'.idx.weight'\n        # scale_key = key+'.idx.scale'\n        for numa_idx, numa in enumerate(numa_list):\n            # TODO: 256 should be a variable\n            for i in range(256):\n                prefix_key = \".\".join(key.split(\".\")[:-1])\n\n                experts_key = prefix_key + f\".{i}.numa.{numa_idx}.weight\"\n                scale_key = prefix_key + 
f\".{i}.numa.{numa_idx}.scale\"\n                if \"down\" in experts_key:\n                    tensor_file_map[experts_key] = os.path.join(folder_path, layer_path, numa, down_list[i])\n                    tensor_file_map[scale_key] = os.path.join(folder_path, layer_path, numa, down_scale_list[i])\n                elif \"gate\" in experts_key:\n                    tensor_file_map[experts_key] = os.path.join(folder_path, layer_path, numa, gate_list[i])\n                    tensor_file_map[scale_key] = os.path.join(folder_path, layer_path, numa, gate_scale_list[i])\n                elif \"up\" in experts_key:\n                    tensor_file_map[experts_key] = os.path.join(folder_path, layer_path, numa, up_list[i])\n                    tensor_file_map[scale_key] = os.path.join(folder_path, layer_path, numa, up_scale_list[i])\n    return tensor_file_map\n\n\n# def translate_name(name:str)->str:\n#     \"\"\"\n#     :param name: name of the tensor\n#     :return: translated name\n#     \"\"\"\n#     name = translate_name_to_gguf(name)\n#     name = name.replace(\".up_proj.\", \".ffn_up_exps.\")\n#     name = name.replace(\".down_proj.\", \".ffn_down_exps.\")\n#     name = name.replace(\".gate_proj.\", \".ffn_gate_exps.\")\n#     name = name.replace(\".ffn_gate_inp.e_score_correction_bias\", \".exp_probs_b.bias\")\n#     return name\n\n\ndef _clean_keys(keys):\n    keys = list(keys)\n    target = [\"ffn_up_exps\", \"ffn_down_exps\", \"ffn_gate_exps\"]\n    # only keep the keys that contain the target\n    keys = [key for key in keys if any(target_key in key for target_key in target) and \"ggml_type\" not in key]\n    return keys\n\n\ndef combine_tensor_sources(safetensor_path, amx_path):\n    safetensor_tensor_file_map = read_safetensor_keys_from_folder(safetensor_path)\n\n    keys = _clean_keys(safetensor_tensor_file_map.keys())\n\n    amx_tensor_file_map = read_amx_tensor_from_folder(amx_path, keys)\n    target_tensor_map = {}\n    for key in 
safetensor_tensor_file_map.keys():\n        if \"_exps.\" in key:\n            continue\n\n        target_tensor_map[key] = safetensor_tensor_file_map[key]\n\n    for key in amx_tensor_file_map.keys():\n        target_tensor_map[key] = amx_tensor_file_map[key]\n\n    return target_tensor_map\n\n\ndef write_combined_tensor(target_tensor_map: dict, output_path: str):\n    # Ensure output directory exists\n    os.makedirs(output_path, exist_ok=True)\n\n    # Cache for safetensor file handles and GGUF loaders\n    safetensors_cache = {}\n    amx_cache = {}\n\n    # Group tensors by layer\n    layer_groups = defaultdict(list)\n    non_layer_keys = []\n    layer_pattern = re.compile(r\"blk\\.(\\d+)\\.\")\n\n    for key in target_tensor_map:\n        match = layer_pattern.search(key)\n        if match:\n            layer_groups[int(match.group(1))].append(key)\n        else:\n            non_layer_keys.append(key)\n\n    # Calculate the number of shards\n    total_shards = len(layer_groups) + (1 if non_layer_keys else 0) - 1\n\n    shard_idx = 0\n    # Save non-layer tensors to the first shard if they exist\n    if non_layer_keys:\n        tensors = {}\n        for key in non_layer_keys:\n            file_path = target_tensor_map[key]\n            tensor = None\n            ggml_type = None\n            if file_path.endswith(\".safetensors\"):\n                if file_path not in safetensors_cache:\n                    safetensors_cache[file_path] = safe_open(file_path, framework=\"pt\")\n                f = safetensors_cache[file_path]\n                tensor = f.get_tensor(key)\n            elif file_path.endswith(\".kt\"):\n                tensor = safe_open_binary_to_tensor(file_path)\n            else:\n                raise ValueError(f\"Unsupported file format: {file_path}\")\n            tensors[key] = tensor\n\n        output_file = os.path.join(output_path, f\"model-{shard_idx:05}-of-{total_shards:05}.safetensors\")\n        print(f\"Saving non-layer tensors to 
{output_file}\")\n        save_file(tensors, output_file)\n        shard_idx += 1\n\n    # Save each layer's tensors to subsequent shards\n    for layer_num in sorted(layer_groups.keys()):\n        layer_keys = layer_groups[layer_num]\n        tensors = {}\n        for key in layer_keys:\n            file_path = target_tensor_map[key]\n            tensor = None\n            ggml_type = None\n            if file_path.endswith(\".safetensors\"):\n                if file_path not in safetensors_cache:\n                    safetensors_cache[file_path] = safe_open(file_path, framework=\"pt\")\n                f = safetensors_cache[file_path]\n                tensor = f.get_tensor(key)\n                tensor_info = tensor.shape\n            elif file_path.endswith(\".kt\"):\n                tensor = safe_open_binary_to_tensor(file_path)\n            else:\n                raise ValueError(f\"Unsupported file format: {file_path}\")\n            tensors[key] = tensor\n\n        output_file = os.path.join(output_path, f\"model-{shard_idx:05}-of-{total_shards:05}.safetensors\")\n        print(f\"Saving layer {layer_num} to {output_file}\")\n        save_file(tensors, output_file)\n        shard_idx += 1\n    return\n\n\ndef main():\n    # Inputs: path to the preprocessed hybrid model, path to the pre-generated AMX weights, and the output path\n    parser = argparse.ArgumentParser(description=\"Read parameters from Safetensor and GGUF files\")\n    parser.add_argument(\n        \"--safetensor_path\",\n        type=str,\n        help=\"Path to the Safetensor file\",\n        default=\"/mnt/data/models/DeepSeek-R1-GGML-FP8-Hybrid/DeepSeek-R1-IQ1S-FP8\",\n    )\n    parser.add_argument(\n        \"--amx_path\", type=str, help=\"Path to the AMX weights\", default=\"/mnt/data/models/DeepSeek-R1-INT4\"\n    )\n    parser.add_argument(\n        \"--output_path\",\n        type=str,\n        help=\"Path to the output file\",\n        default=\"/mnt/data/models/DeepSeek-R1-GGML-FP8-Hybrid/DeepSeek-R1-AMXQ4-FP8\",\n    )\n\n    # print all the arguments\n    
print(\"All the arguments:\")\n    print(parser.parse_args())\n\n    # Parse command-line arguments\n    args = parser.parse_args()\n\n    safetensor_path = args.safetensor_path\n    amx_path = args.amx_path\n    output_path = args.output_path\n\n    target_tensor_map = combine_tensor_sources(safetensor_path, amx_path)\n    for key, value in target_tensor_map.items():\n        print(f\"{key}: {value}\")\n    write_combined_tensor(target_tensor_map, output_path)\n\n    return\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "kt-kernel/scripts/check_cpu_features.py",
    "content": "#!/usr/bin/env python3\n\"\"\"\nCPU feature detection script for kt-kernel.\n\nThis script checks if your CPU supports the required instruction sets for FP8 MoE:\n- AVX512F (foundation)\n- AVX512_BF16 (BF16 dot product)\n- AVX512_VNNI (VNNI instructions)\n- AVX512_VBMI (byte permutation)\n\nUsage:\n    python3 scripts/check_cpu_features.py\n\"\"\"\n\nimport os\nimport sys\n\n\ndef check_cpuinfo():\n    \"\"\"Check CPU features via /proc/cpuinfo.\"\"\"\n    try:\n        with open(\"/proc/cpuinfo\", \"r\") as f:\n            cpuinfo = f.read().lower()\n        return cpuinfo\n    except FileNotFoundError:\n        return None\n\n\ndef main():\n    print(\"=\" * 70)\n    print(\"KT-Kernel CPU Feature Detection\")\n    print(\"=\" * 70)\n    print()\n\n    cpuinfo = check_cpuinfo()\n\n    if cpuinfo is None:\n        print(\"❌ /proc/cpuinfo not found (not on Linux?)\")\n        print(\"   Cannot detect CPU features automatically.\")\n        sys.exit(1)\n\n    # Extract CPU model\n    for line in cpuinfo.split(\"\\n\"):\n        if \"model name\" in line:\n            model = line.split(\":\")[1].strip()\n            print(f\"CPU Model: {model}\")\n            break\n    print()\n\n    # Check AMX support\n    print(\"AMX Support (Intel Sapphire Rapids+):\")\n    amx_flags = [\"amx_tile\", \"amx_int8\", \"amx_bf16\"]\n    amx_status = {}\n    for flag in amx_flags:\n        has_flag = flag in cpuinfo\n        amx_status[flag] = has_flag\n        status = \"✅\" if has_flag else \"❌\"\n        print(f\"  {status} {flag.upper()}\")\n\n    has_amx = all(amx_status.values())\n    print(f\"\\n  Overall AMX Support: {'✅ YES' if has_amx else '❌ NO'}\")\n    print()\n\n    # Check AVX512 support\n    print(\"AVX512 Support (required for FP8 MoE):\")\n    avx512_flags = [\"avx512f\", \"avx512_bf16\", \"avx512_vnni\", \"avx512_vbmi\"]\n    avx512_status = {}\n    for flag in avx512_flags:\n        has_flag = flag in cpuinfo\n        avx512_status[flag] = 
has_flag\n        status = \"✅\" if has_flag else \"❌\"\n        flag_desc = {\n            \"avx512f\": \"AVX512F (foundation)\",\n            \"avx512_bf16\": \"AVX512_BF16 (BF16 dot product)\",\n            \"avx512_vnni\": \"AVX512_VNNI (VNNI instructions)\",\n            \"avx512_vbmi\": \"AVX512_VBMI (byte permutation)\",\n        }\n        print(f\"  {status} {flag_desc.get(flag, flag.upper())}\")\n\n    has_avx512_full = all(avx512_status.values())\n    print(f\"\\n  Overall AVX512 Support: {'✅ YES' if has_avx512_full else '❌ NO'}\")\n\n    if not has_avx512_full and avx512_status[\"avx512f\"]:\n        missing = [f for f in avx512_flags if not avx512_status[f]]\n        print(f\"  ⚠️  Warning: AVX512F detected but missing: {', '.join(missing)}\")\n        print(f\"      kt-kernel will fall back to AVX2 mode\")\n    print()\n\n    # Check AVX2 support\n    print(\"AVX2 Support (fallback):\")\n    has_avx2 = \"avx2\" in cpuinfo\n    status = \"✅\" if has_avx2 else \"❌\"\n    print(f\"  {status} AVX2\")\n    print()\n\n    # Recommendation\n    print(\"=\" * 70)\n    print(\"Recommendation:\")\n    print(\"=\" * 70)\n    if has_amx:\n        print(\"✅ Your CPU supports AMX - you can use the highest performance mode!\")\n        print(\"   Build with: -DKTRANSFORMERS_CPU_USE_AMX_AVX512=ON -DKTRANSFORMERS_CPU_USE_AMX=ON\")\n    elif has_avx512_full:\n        print(\"✅ Your CPU supports full AVX512 (F/BF16/VNNI/VBMI) - FP8 MoE will work!\")\n        print(\"   Build with: -DKTRANSFORMERS_CPU_USE_AMX_AVX512=ON\")\n    elif avx512_status.get(\"avx512f\", False):\n        print(\"⚠️  Your CPU has AVX512F but missing required extensions.\")\n        print(\"   FP8 MoE will NOT work. 
kt-kernel will fall back to AVX2 mode.\")\n        print(\"   Missing extensions:\", \", \".join([f for f in avx512_flags if not avx512_status.get(f, False)]))\n    elif has_avx2:\n        print(\"ℹ️  Your CPU supports AVX2 only - basic compatibility mode.\")\n        print(\"   FP8 MoE will NOT be available, but other features will work.\")\n    else:\n        print(\"❌ Your CPU does not support the minimum required instruction set (AVX2).\")\n        print(\"   kt-kernel may not work on this system.\")\n    print()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "kt-kernel/scripts/compare_weights.py",
    "content": "#!/usr/bin/env python3\n\"\"\"\nCompare two sets of quantized weights generated by convert_cpu_weights.py\n\nThis script supports comparing:\n- Two safetensor format weights (merged)\n- Two .kt format weights (layer folder structure)\n- One safetensor and one .kt format (cross-format comparison)\n\nUsage:\n    python compare_weights.py --path1 /path/to/weights1 --path2 /path/to/weights2\n    python compare_weights.py --path1 /path/to/weights1 --path2 /path/to/weights2 --tolerance 1e-5\n\"\"\"\n\nimport argparse\nimport os\nimport glob\nimport numpy as np\nimport torch\nfrom safetensors import safe_open\nfrom typing import Dict, Tuple\nfrom collections import defaultdict\n\n\ndef unpack_awq_int32_to_int8(packed: np.ndarray, bits: int = 4) -> np.ndarray:\n    \"\"\"Unpack AWQ int32 packed format to int8\n\n    AWQ uses INT4 quantization: 8 x 4-bit values packed into 1 x 32-bit integer\n\n    Args:\n        packed: Packed int32 array\n        bits: Number of bits per element (default: 4)\n\n    Returns:\n        Unpacked int8 array\n    \"\"\"\n    if packed.dtype != np.int32:\n        # Try to reinterpret as int32\n        packed = packed.view(np.int32)\n\n    pack_num = 32 // bits  # 8 for INT4\n    unpacked_size = packed.size * pack_num\n\n    unpacked = np.empty(unpacked_size, dtype=np.int8)\n\n    for i in range(pack_num):\n        shift = i * bits\n        mask = (1 << bits) - 1  # 0x0F for 4-bit\n        unpacked[i::pack_num] = ((packed >> shift) & mask).astype(np.int8)\n\n    return unpacked\n\n\ndef normalize_tensor_dtype(tensor: np.ndarray, tensor_name: str, is_awq: bool = False) -> np.ndarray:\n    \"\"\"Normalize tensor to consistent dtype based on tensor type\n\n    Args:\n        tensor: Input tensor\n        tensor_name: Name of the tensor (used to determine type)\n        is_awq: Whether this is AWQ format (requires unpacking)\n\n    Returns:\n        Normalized tensor with consistent dtype\n    \"\"\"\n    # Determine tensor type from 
name\n    is_scale = \"scale\" in tensor_name\n    is_weight = \"weight\" in tensor_name\n    is_qzeros = \"qzeros\" in tensor_name\n\n    if is_scale:\n        # Scale should be float32\n        if tensor.dtype != np.float32:\n            # Try to reinterpret bytes as float32\n            tensor = tensor.view(np.float32)\n        return tensor\n\n    elif is_weight or is_qzeros:\n        # Weight/qzeros should be int8\n        if is_awq and tensor.dtype == np.int32:\n            # AWQ format: unpack int32 to int8\n            tensor = unpack_awq_int32_to_int8(tensor)\n        elif tensor.dtype == np.float32:\n            # Two cases for float32:\n            # Case 1: Values look like int8 values (e.g., [37., 73., -70.])\n            #         -> use astype to convert values\n            # Case 2: Values are large scientific notation (e.g., [2.6e34, ...])\n            #         -> use view to reinterpret bytes\n\n            # Check if values are in int8 range (-128 to 127)\n            if len(tensor) > 0:\n                sample_size = min(100, len(tensor))\n                sample_values = tensor.flat[:sample_size]\n\n                # If most values are in int8 range and have no decimal parts\n                in_int8_range = np.all((sample_values >= -128) & (sample_values <= 127))\n                is_integer_valued = np.all(sample_values == np.round(sample_values))\n\n                if in_int8_range and is_integer_valued:\n                    # Case 1: Direct value conversion\n                    tensor = tensor.astype(np.int8)\n                else:\n                    # Case 2: Byte reinterpretation (4 bytes -> 4 int8s)\n                    tensor = tensor.view(np.int8)\n            else:\n                tensor = tensor.astype(np.int8)\n\n        elif tensor.dtype == np.int32:\n            # Reinterpret int32 as int8 (4x more elements)\n            tensor = tensor.view(np.int8)\n        elif tensor.dtype != np.int8:\n            # Other types: try to 
convert\n            tensor = tensor.astype(np.int8)\n\n        return tensor\n\n    else:\n        # Unknown type, return as-is\n        return tensor\n\n\ndef load_kt_binary(file_path: str) -> np.ndarray:\n    \"\"\"Load .kt format binary tensor file\n\n    Args:\n        file_path: Path to .kt binary file\n\n    Returns:\n        numpy array with the loaded tensor\n    \"\"\"\n    if not os.path.exists(file_path):\n        raise FileNotFoundError(f\"File not found: {file_path}\")\n\n    with open(file_path, \"rb\") as f:\n        binary_data = f.read()\n\n    # Determine dtype based on file name\n    if \"scale\" in file_path:\n        dtype = np.float32\n    else:\n        dtype = np.int8\n\n    return np.frombuffer(binary_data, dtype=dtype)\n\n\ndef detect_weight_format(path: str) -> str:\n    \"\"\"Detect if weights are in safetensor or .kt format\n\n    Args:\n        path: Path to weight directory\n\n    Returns:\n        'safetensor' or 'kt' or 'unknown'\n    \"\"\"\n    if not os.path.exists(path):\n        raise FileNotFoundError(f\"Path not found: {path}\")\n\n    # Check for safetensor files\n    safetensor_files = glob.glob(os.path.join(path, \"*.safetensors\"))\n    if safetensor_files:\n        return \"safetensor\"\n\n    # Check for layer folder structure\n    layer_folders = glob.glob(os.path.join(path, \"_layer_*\"))\n    if layer_folders:\n        return \"kt\"\n\n    return \"unknown\"\n\n\ndef detect_awq_format(weights_sample: Dict[str, np.ndarray]) -> bool:\n    \"\"\"Detect if weights are in AWQ format\n\n    AWQ format characteristics:\n    - Has 'qzeros' tensors\n    - Weight tensors are int32 dtype (packed)\n\n    Args:\n        weights_sample: Sample of loaded weights\n\n    Returns:\n        True if AWQ format detected\n    \"\"\"\n    has_qzeros = any(\"qzeros\" in key for key in weights_sample.keys())\n\n    if not has_qzeros:\n        return False\n\n    # Check if weight tensors are int32\n    for key, tensor in 
weights_sample.items():\n        if \"weight\" in key and tensor.dtype == np.int32:\n            return True\n\n    return False\n\n\ndef load_safetensor_weights(path: str) -> Dict[str, np.ndarray]:\n    \"\"\"Load all weights from safetensor format\n\n    Args:\n        path: Path to directory containing safetensor files\n\n    Returns:\n        Dictionary mapping tensor names to numpy arrays (dtype normalized)\n    \"\"\"\n    weights = {}\n\n    safetensor_files = sorted(glob.glob(os.path.join(path, \"*.safetensors\")))\n    if not safetensor_files:\n        raise FileNotFoundError(f\"No safetensor files found in {path}\")\n\n    print(f\"Loading safetensor files from {path}\")\n\n    # First pass: load all tensors\n    for file in safetensor_files:\n        with safe_open(file, framework=\"pt\") as f:\n            for key in f.keys():\n                # Only load MoE expert weights for comparison\n                if \".ffn_\" in key and \"_exps.\" in key:\n                    tensor = f.get_tensor(key)\n                    weights[key] = tensor.cpu().numpy()\n\n    # Detect AWQ format\n    is_awq = detect_awq_format(weights)\n    print(f\"  Format detected: {'AWQ' if is_awq else 'INT4/INT8'}\")\n\n    # Second pass: normalize dtypes\n    print(f\"  Normalizing dtypes...\")\n    for key in list(weights.keys()):\n        original_dtype = weights[key].dtype\n        original_shape = weights[key].shape\n        weights[key] = normalize_tensor_dtype(weights[key], key, is_awq=is_awq)\n\n        if weights[key].shape != original_shape or weights[key].dtype != original_dtype:\n            print(f\"    {key}: {original_dtype}{original_shape} -> {weights[key].dtype}{weights[key].shape}\")\n\n    print(f\"  Loaded {len(weights)} tensors from safetensor format\")\n    return weights\n\n\ndef load_kt_weights(path: str) -> Dict[str, np.ndarray]:\n    \"\"\"Load all weights from .kt format (layer folder structure)\n\n    Args:\n        path: Path to directory containing 
_layer_* folders\n\n    Returns:\n        Dictionary mapping tensor names to numpy arrays\n    \"\"\"\n    weights = {}\n\n    layer_folders = sorted(glob.glob(os.path.join(path, \"_layer_*\")))\n    if not layer_folders:\n        raise FileNotFoundError(f\"No _layer_* folders found in {path}\")\n\n    print(f\"Loading .kt files from {path}\")\n\n    for layer_folder in layer_folders:\n        # Extract layer index from folder name\n        layer_idx = int(os.path.basename(layer_folder).split(\"_\")[-1])\n\n        # Find all NUMA folders\n        numa_folders = sorted(glob.glob(os.path.join(layer_folder, \"_numa_*\")))\n\n        for numa_folder in numa_folders:\n            # Extract NUMA index\n            numa_idx = int(os.path.basename(numa_folder).split(\"_\")[-1])\n\n            # Find all .kt files\n            kt_files = glob.glob(os.path.join(numa_folder, \"*.kt\"))\n\n            for kt_file in kt_files:\n                filename = os.path.basename(kt_file)\n\n                # Parse filename to extract metadata\n                # Format: {METHOD}_{proj}_{expert}_{size}Byte_{type}_.kt\n                parts = filename.replace(\".kt\", \"\").split(\"_\")\n\n                if len(parts) >= 5:\n                    method = parts[0]  # INT4, INT8, etc.\n                    proj = parts[1]  # down, gate, up\n                    expert = parts[2]  # expert ID\n                    tensor_type = parts[4]  # quant or scale\n\n                    # Map proj names\n                    proj_map = {\"down\": \"ffn_down_exps\", \"gate\": \"ffn_gate_exps\", \"up\": \"ffn_up_exps\"}\n\n                    proj_key = proj_map.get(proj, proj)\n\n                    # Build key matching safetensor format\n                    if tensor_type == \"quant\":\n                        key = f\"blk.{layer_idx}.{proj_key}.{expert}.numa.{numa_idx}.weight\"\n                    else:  # scale\n                        key = 
f\"blk.{layer_idx}.{proj_key}.{expert}.numa.{numa_idx}.scale\"\n\n                    # Load tensor\n                    weights[key] = load_kt_binary(kt_file)\n\n    # Normalize dtypes (.kt format is never AWQ)\n    print(f\"  Normalizing dtypes...\")\n    for key in list(weights.keys()):\n        original_dtype = weights[key].dtype\n        original_shape = weights[key].shape\n        weights[key] = normalize_tensor_dtype(weights[key], key, is_awq=False)\n\n        if weights[key].shape != original_shape or weights[key].dtype != original_dtype:\n            print(f\"    {key}: {original_dtype}{original_shape} -> {weights[key].dtype}{weights[key].shape}\")\n\n    print(f\"  Loaded {len(weights)} tensors from .kt format\")\n    return weights\n\n\ndef normalize_key(key: str) -> Tuple[int, str, int, str]:\n    \"\"\"Normalize tensor key to extract layer, projection, expert, and type\n\n    Args:\n        key: Tensor key like \"blk.0.ffn_up_exps.5.weight\" or \"blk.0.ffn_up_exps.5.numa.0.weight\"\n\n    Returns:\n        Tuple of (layer_idx, proj_name, expert_idx, tensor_type)\n    \"\"\"\n    parts = key.split(\".\")\n\n    layer_idx = int(parts[1])\n    proj_name = parts[2]\n    expert_idx = int(parts[3])\n\n    # Handle both formats: with and without numa\n    if \"numa\" in key:\n        tensor_type = parts[6]  # weight or scale\n    else:\n        tensor_type = parts[4]  # weight, scale, or qzeros\n\n    return (layer_idx, proj_name, expert_idx, tensor_type)\n\n\ndef compare_weights(\n    weights1: Dict[str, np.ndarray], weights2: Dict[str, np.ndarray], tolerance: float = 1e-6\n) -> Tuple[bool, Dict[str, Dict]]:\n    \"\"\"Compare two sets of weights\n\n    Args:\n        weights1: First set of weights\n        weights2: Second set of weights\n        tolerance: Numerical tolerance for comparison\n\n    Returns:\n        Tuple of (all_match, differences_dict)\n    \"\"\"\n    print(\"\\n\" + \"=\" * 80)\n    print(\"WEIGHT COMPARISON\")\n    print(\"=\" * 
80)\n\n    # Group keys by normalized form (ignoring numa index)\n    def group_by_base_key(weights):\n        groups = defaultdict(list)\n        for key in weights.keys():\n            try:\n                layer, proj, expert, ttype = normalize_key(key)\n                base_key = f\"blk.{layer}.{proj}.{expert}.{ttype}\"\n                groups[base_key].append(key)\n            except (ValueError, IndexError):\n                # Skip keys that don't match the expected format\n                pass\n        return groups\n\n    groups1 = group_by_base_key(weights1)\n    groups2 = group_by_base_key(weights2)\n\n    all_base_keys = sorted(set(groups1.keys()) | set(groups2.keys()))\n\n    all_match = True\n    differences = {}\n\n    total_comparisons = 0\n    matching_comparisons = 0\n\n    for base_key in all_base_keys:\n        keys1 = groups1.get(base_key, [])\n        keys2 = groups2.get(base_key, [])\n\n        if not keys1:\n            print(f\"❌ Missing in weights1: {base_key}\")\n            differences[base_key] = {\"status\": \"missing_in_weights1\"}\n            all_match = False\n            continue\n\n        if not keys2:\n            print(f\"❌ Missing in weights2: {base_key}\")\n            differences[base_key] = {\"status\": \"missing_in_weights2\"}\n            all_match = False\n            continue\n\n        # For kt format, we may have multiple keys (one per NUMA node)\n        # We need to concatenate them for comparison\n        if len(keys1) > 1 or len(keys2) > 1:\n            # Concatenate tensors from all NUMA nodes\n            tensor1 = np.concatenate([weights1[k] for k in sorted(keys1)])\n            tensor2 = np.concatenate([weights2[k] for k in sorted(keys2)])\n        else:\n            tensor1 = weights1[keys1[0]]\n            tensor2 = weights2[keys2[0]]\n\n        total_comparisons += 1\n\n        # Debug: print dtype and shape info\n        if tensor1.dtype != tensor2.dtype:\n            print(f\"⚠️  Dtype mismatch for {base_key}: {tensor1.dtype} vs 
{tensor2.dtype}\")\n            print(f\"   This should have been normalized. Shape: {tensor1.shape} vs {tensor2.shape}\")\n\n        # Compare shapes\n        if tensor1.shape != tensor2.shape:\n            print(f\"❌ Shape mismatch for {base_key}:\")\n            print(f\"   Shape1: {tensor1.shape} (dtype: {tensor1.dtype})\")\n            print(f\"   Shape2: {tensor2.shape} (dtype: {tensor2.dtype})\")\n            differences[base_key] = {\n                \"status\": \"shape_mismatch\",\n                \"shape1\": tensor1.shape,\n                \"shape2\": tensor2.shape,\n                \"dtype1\": str(tensor1.dtype),\n                \"dtype2\": str(tensor2.dtype),\n            }\n            all_match = False\n            continue\n\n        # Compare dtypes (should be consistent after normalization)\n        if tensor1.dtype != tensor2.dtype:\n            print(f\"❌ Dtype mismatch for {base_key} after normalization:\")\n            print(f\"   Dtype1: {tensor1.dtype}\")\n            print(f\"   Dtype2: {tensor2.dtype}\")\n            differences[base_key] = {\n                \"status\": \"dtype_mismatch\",\n                \"dtype1\": str(tensor1.dtype),\n                \"dtype2\": str(tensor2.dtype),\n            }\n            all_match = False\n            continue\n\n        # Compare values\n        if np.allclose(tensor1, tensor2, atol=tolerance, rtol=tolerance):\n            matching_comparisons += 1\n        else:\n            # Widen to float64 before subtracting: int8 arithmetic wraps around on overflow\n            abs_diff = np.abs(tensor1.astype(np.float64) - tensor2.astype(np.float64))\n            max_diff = np.max(abs_diff)\n            mean_diff = np.mean(abs_diff)\n\n            print(f\"❌ Value mismatch for {base_key}:\")\n            print(f\"   Max difference: {max_diff:.2e}\")\n            print(f\"   Mean difference: {mean_diff:.2e}\")\n            print(f\"   Tolerance: {tolerance:.2e}\")\n\n            differences[base_key] = {\n                \"status\": \"value_mismatch\",\n                \"max_diff\": float(max_diff),\n                \"mean_diff\": float(mean_diff),\n 
               \"tolerance\": tolerance,\n            }\n            all_match = False\n\n    print(\"\\n\" + \"=\" * 80)\n    print(\"SUMMARY\")\n    print(\"=\" * 80)\n    print(f\"Total comparisons: {total_comparisons}\")\n    print(f\"Matching: {matching_comparisons}\")\n    print(f\"Mismatching: {total_comparisons - matching_comparisons}\")\n    print(f\"Missing tensors: {len(differences) - (total_comparisons - matching_comparisons)}\")\n\n    if all_match:\n        print(\"\\n✅ All weights match!\")\n    else:\n        print(f\"\\n❌ Found {len(differences)} differences\")\n\n    return all_match, differences\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Compare two sets of quantized weights\")\n    parser.add_argument(\"--path1\", type=str, required=True, help=\"Path to first weight directory\")\n    parser.add_argument(\"--path2\", type=str, required=True, help=\"Path to second weight directory\")\n    parser.add_argument(\n        \"--tolerance\", type=float, default=1e-6, help=\"Numerical tolerance for comparison (default: 1e-6)\"\n    )\n\n    args = parser.parse_args()\n\n    # Validate paths\n    if not os.path.exists(args.path1):\n        print(f\"Error: Path1 does not exist: {args.path1}\")\n        return 1\n\n    if not os.path.exists(args.path2):\n        print(f\"Error: Path2 does not exist: {args.path2}\")\n        return 1\n\n    # Detect formats\n    print(\"Detecting weight formats...\")\n    format1 = detect_weight_format(args.path1)\n    format2 = detect_weight_format(args.path2)\n\n    print(f\"Path1 format: {format1}\")\n    print(f\"Path2 format: {format2}\")\n\n    if format1 == \"unknown\":\n        print(f\"Error: Unable to detect weight format in {args.path1}\")\n        return 1\n\n    if format2 == \"unknown\":\n        print(f\"Error: Unable to detect weight format in {args.path2}\")\n        return 1\n\n    # Load weights based on format\n    print(\"\\nLoading weights...\")\n\n    if format1 == 
\"safetensor\":\n        weights1 = load_safetensor_weights(args.path1)\n    else:\n        weights1 = load_kt_weights(args.path1)\n\n    if format2 == \"safetensor\":\n        weights2 = load_safetensor_weights(args.path2)\n    else:\n        weights2 = load_kt_weights(args.path2)\n\n    # Compare weights\n    all_match, differences = compare_weights(weights1, weights2, args.tolerance)\n\n    return 0 if all_match else 1\n\n\nif __name__ == \"__main__\":\n    exit(main())\n"
  },
  {
    "path": "kt-kernel/scripts/convert_cpu_weights.py",
    "content": "#!/usr/bin/env python3\n\nimport argparse\nimport os\nfrom collections import defaultdict\nfrom typing import Dict, List\nimport torch\nfrom safetensors import safe_open\nfrom safetensors.torch import save_file\nimport gc\nimport time\nimport json\nimport sys\nimport glob\nimport numpy as np\n\n# Add parent directory to path to import kt_kernel\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\"))\nfrom kt_kernel import KTMoEWrapper\n\nimport triton\nimport triton.language as tl\n\n\nQ_BITS = 4\nSTORAGE_BITS = 32\nPACK_NUM = STORAGE_BITS // Q_BITS\nNUMA_NUM = 2\n\nREVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\n@triton.jit\ndef weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):\n    pid_m = tl.program_id(axis=0)\n    pid_n = tl.program_id(axis=1)\n    n = tl.cdiv(N, BLOCK_SIZE)\n    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    offs = offs_m[:, None] * N + offs_n[None, :]\n    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)\n    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)\n    s = tl.load(s_ptr + pid_m * n + pid_n)\n    y = x * s\n    tl.store(y_ptr + offs, y, mask=mask)\n\n\ndef weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:\n    assert x.is_contiguous() and s.is_contiguous()\n    assert x.dim() == 2 and s.dim() == 2\n    M, N = x.size()\n    y = torch.empty_like(x, dtype=torch.get_default_dtype())\n    grid = lambda meta: (triton.cdiv(M, meta[\"BLOCK_SIZE\"]), triton.cdiv(N, meta[\"BLOCK_SIZE\"]))\n    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)\n    return y\n\n\ndef load_model_config(input_path: str, input_type: str = None) -> Dict:\n    \"\"\"Load model configuration from config.json\n\n    Args:\n        input_path: Path to directory containing config.json\n        
input_type: Input weight type (fp8/fp16/bf16/awq), used to validate FP8 config\n\n    Returns:\n        Dictionary with model configuration\n    \"\"\"\n    config_path = os.path.join(input_path, \"config.json\")\n    if not os.path.exists(config_path):\n        raise FileNotFoundError(f\"config.json not found in {input_path}\")\n\n    with open(config_path, \"r\") as f:\n        config = json.load(f)\n\n    if \"text_config\" in config:\n        text_cfg = config[\"text_config\"]\n    else:\n        text_cfg = config\n\n    # Extract required fields with fallbacks\n    model_config = {\n        \"num_experts\": text_cfg.get(\"n_routed_experts\", text_cfg.get(\"num_experts\")),\n        \"num_experts_per_tok\": text_cfg.get(\"num_experts_per_tok\", 2),\n        \"hidden_size\": text_cfg.get(\"hidden_size\"),\n        \"moe_intermediate_size\": text_cfg.get(\"moe_intermediate_size\", text_cfg.get(\"intermediate_size\")),\n    }\n\n    # Validate required fields\n    missing_fields = [k for k, v in model_config.items() if v is None]\n    if missing_fields:\n        raise ValueError(f\"Missing required config fields: {missing_fields}\")\n\n    # For FP8 input, extract and validate quantization_config\n    if input_type == \"fp8\":\n        quant_config = config.get(\"quantization_config\") or text_cfg.get(\"quantization_config\")\n        if quant_config is None:\n            raise ValueError(\n                \"FP8 input type specified but 'quantization_config' not found in config.json. \"\n                \"Expected quantization_config with weight_block_size field.\"\n            )\n\n        weight_block_size = quant_config.get(\"weight_block_size\")\n        if weight_block_size is None:\n            raise ValueError(\n                \"FP8 quantization_config found but 'weight_block_size' field is missing. 
\"\n                \"Expected format: 'weight_block_size': [128, 128]\"\n            )\n\n        if not isinstance(weight_block_size, list) or len(weight_block_size) != 2:\n            raise ValueError(\n                f\"Invalid weight_block_size format: {weight_block_size}. \"\n                \"Expected a list of two integers, e.g., [128, 128]\"\n            )\n\n        model_config[\"fp8_weight_block_size\"] = weight_block_size\n        print(f\"FP8 quantization config detected:\")\n        print(f\"  format: {quant_config.get('fmt', 'unknown')}\")\n        print(f\"  weight_block_size: {weight_block_size}\")\n    return model_config\n\n\ndef pack(imatrix: torch.Tensor):\n    \"\"\"\n    Packs a 4-bit integer matrix into a packed 32-bit integer matrix.\n    Args:\n        imatrix (torch.Tensor): matrix of integers\n        direction (str): direction of packing, either \"column\" or \"row\"\n\n    Returns:\n        qmatrix (torch.Tensor): packed matrix of integers\n    \"\"\"\n    shifts = torch.arange(0, STORAGE_BITS, Q_BITS, device=imatrix.device)\n\n    imatrix = torch.bitwise_and(imatrix, 0x0F).to(torch.int32)  # eventually correct overflow\n\n    imatrix = imatrix.view(imatrix.shape[0], imatrix.shape[1], -1, PACK_NUM)\n    qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, None, :]).sum(dim=-1)\n\n    qmatrix = qmatrix.to(torch.int32)\n\n    return qmatrix\n\n\ndef unpack(qmatrix: torch.Tensor):\n    \"\"\"\n    Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix.\n\n    Args:\n        qmatrix (torch.Tensor): matrix of packed integers\n        direction (str): direction of unpacking, either \"column\" or \"row\"\n\n    Returns:\n        imatrix (torch.Tensor): matrix of integers\n    \"\"\"\n    shifts = torch.arange(0, STORAGE_BITS, Q_BITS, device=qmatrix.device)\n\n    imatrix = torch.bitwise_right_shift(qmatrix[:, :, :, None], shifts[None, None, None, :]).view(\n        qmatrix.shape[0], qmatrix.shape[1], -1\n    )\n\n    
imatrix = imatrix.to(torch.int8) & 0x0F  # keep only the low 4 bits of each value\n\n    return imatrix\n\n\ndef reverse_awq_interleaving(imatrix: torch.Tensor):\n    \"\"\"Reverse AWQ interleaving to get original order\"\"\"\n    # Reshape to handle interleaving at pack level\n    original_shape = imatrix.shape\n    imatrix_reshaped = imatrix.view(original_shape[0], original_shape[1], -1, PACK_NUM)\n\n    # Apply reverse AWQ pack order\n    imatrix_reordered = imatrix_reshaped[:, :, :, REVERSE_AWQ_PACK_ORDER]\n\n    return imatrix_reordered.view(original_shape)\n\n\ndef unpack_reverse_awq_interleaving(qweight: torch.Tensor, qzeros: torch.Tensor = None):\n    \"\"\"\n    Row-major unpack AWQ I32 -> INT4 and reverse interleaving to get original order\n\n    Args:\n        qweight: Packed AWQ weights with interleaving (I32)\n        qzeros: Packed AWQ zeros with interleaving (I32, optional)\n\n    Returns:\n        Tuple of (unpacked_weights, unpacked_zeros) in row major order (original)\n    \"\"\"\n    # Step 1: Row-major unpack I32 to INT4\n    iweights = unpack(qweight)  # unpack() expands along the last dimension\n\n    if qzeros is not None:\n        izeros = unpack(qzeros)  # unpack() expands along the last dimension\n    else:\n        izeros = None\n\n    # Step 2: Reverse AWQ interleaving to get original row-major order\n    iweights_original = reverse_awq_interleaving(iweights)\n\n    if izeros is not None:\n        izeros_original = reverse_awq_interleaving(izeros)\n    else:\n        izeros_original = None\n\n    return iweights_original, izeros_original\n\n\ndef pack_column_major_1d(iweights: torch.Tensor, izeros: torch.Tensor = None):\n    \"\"\"\n    Pack INT4 -> I32 with different logic for weights vs zeros\n\n    Args:\n        iweights: Unpacked weights in row major order (INT4)\n        izeros: Unpacked zeros in row major order (INT4, optional)\n\n    Returns:\n        Tuple of (packed_weights, packed_zeros); the tensors keep their packed\n        3D shape because flattening to 1D is currently disabled\n    \"\"\"\n    # qweight: transpose 
to column-major then pack\n    iweights_transposed = iweights.transpose(1, 2).contiguous()\n    qweight = pack(iweights_transposed)\n    # qweight = qweight_2d.flatten()  # Flatten to 1D\n\n    # qzeros: NO transpose, keep original shape, pack with original interleaving (01234567)\n    if izeros is not None:\n        qzeros = pack(izeros)  # Keep original shape, original interleaving\n        # qzeros = qzeros_2d.flatten()  # Flatten to 1D\n    else:\n        qzeros = None\n\n    return qweight, qzeros\n\n\nclass ConverterBase:\n    \"\"\"Base class for converting model weights.\n\n    Subclasses must implement `_convert_layer_experts` to handle the expert\n    tensor transformation for a given quantization method (e.g., awq, int4, int8).\n    \"\"\"\n\n    def __init__(\n        self,\n        input_path: str,\n        output_path: str,\n        model_config: Dict,\n        cpuinfer_threads: int = 60,\n        threadpool_count: int = 2,\n        input_type: str = None,\n        merge_to_safetensor: bool = True,\n    ):\n        self.input_path = input_path\n        self.output_path = output_path\n        self.model_config = model_config\n        self.cpuinfer_threads = cpuinfer_threads\n        self.threadpool_count = threadpool_count\n        self.input_type = input_type\n        self.merge_to_safetensor = merge_to_safetensor\n        self.tensor_file_map: Dict[str, str] = {}  # key -> filename\n        self.tensor_key_map: Dict[str, str] = {}  # old key -> new key\n        self.file_handle_map: Dict[str, any] = {}  # filename -> file\n\n        # Extract commonly used config values for convenience\n        self.num_experts = model_config[\"num_experts\"]\n        self.num_experts_per_tok = model_config[\"num_experts_per_tok\"]\n        self.hidden_size = model_config[\"hidden_size\"]\n        self.moe_intermediate_size = model_config[\"moe_intermediate_size\"]\n        self.layout = \"base\"\n\n        # Load input safetensors files\n        
self._load_input_files()\n\n    def _load_input_files(self):\n        \"\"\"Load all safetensors files from input directory\"\"\"\n        print(f\"Loading safetensors files from {self.input_path}\")\n\n        found_safetensor = False\n        for root, _, files in os.walk(self.input_path):\n            files = sorted(files)\n            for file in files:\n                if file.endswith(\".safetensors\"):\n                    found_safetensor = True\n                    file_path = os.path.join(root, file)\n                    try:\n                        handle = safe_open(file_path, framework=\"pt\")\n                        self.file_handle_map[file] = handle\n                        renamed = False\n                        for key in handle.keys():\n                            if \"language_model\" in key:\n                                key_ = key.replace(\"language_model.\", \"\")\n                                # print(\"  Renaming key:\", key, \"->\", key_)\n                                renamed = True\n                            else:\n                                key_ = key\n                            self.tensor_key_map[key_] = key\n                            self.tensor_file_map[key_] = file\n                        print(\n                            f\"  Loaded: {file} ({len(list(handle.keys()))} tensors){' (renamed keys)' if renamed else ''}\"\n                        )\n                    except Exception as e:\n                        print(f\"  Error loading {file}: {e}\")\n\n        if not found_safetensor:\n            raise FileNotFoundError(f\"No safetensors files found in {self.input_path}\")\n\n        print(f\"Total tensors loaded: {len(self.tensor_file_map)}\")\n\n    def _load_tensor(self, key: str) -> torch.Tensor:\n        \"\"\"Load tensor by key\"\"\"\n        if key not in self.tensor_file_map:\n            raise KeyError(f\"Key {key} not found\")\n\n        file = self.tensor_file_map[key]\n        handle = 
self.file_handle_map[file]\n        return handle.get_tensor(self.tensor_key_map.get(key, key))\n\n    # layers_id -> list[experts_id]\n    def _find_expert_layers(self) -> Dict[int, List[int]]:\n        \"\"\"Find all layers and experts in the model\"\"\"\n        layers = defaultdict(set)\n\n        # detect layout\n        for key in self.tensor_file_map.keys():\n            if \"mlp.experts\" in key and \"gate_up\" in key:\n                self.layout = \"fused\"\n                break\n\n        if self.layout == \"fused\":  # Pattern: model.layers.{layer}.mlp.experts.{proj}\n            layers = set()\n            for key in self.tensor_file_map.keys():\n                if \"model.layers.\" in key and \".mlp.experts.\" in key:\n                    parts = key.split(\".\")\n                    if len(parts) >= 6:\n                        layer_idx = int(parts[2])\n                        layers.add(layer_idx)\n\n            result: Dict[int, List[int]] = {}\n            for layer_idx in sorted(layers):\n                result[layer_idx] = [-1]\n\n            print(f\"Found {len(result)} layers with fused MoE experts\")\n            return result\n\n        # Pattern: model.layers.{layer}.mlp.experts.{expert}.{proj}.{type}\n        for key in self.tensor_file_map.keys():\n            if \"model.layers.\" in key and \".mlp.experts.\" in key:\n                parts = key.split(\".\")\n                if len(parts) >= 6:\n                    layer_idx = int(parts[2])\n                    expert_idx = int(parts[5])\n                    layers[layer_idx].add(expert_idx)\n\n        # Convert to sorted lists\n        result: Dict[int, List[int]] = {}\n        for layer_idx, expert_set in layers.items():\n            result[layer_idx] = sorted(list(expert_set))\n\n        print(f\"Found {len(result)} layers with MoE experts:\")\n        for layer_idx, experts in sorted(result.items()):\n            print(f\"  Layer {layer_idx}: {len(experts)} experts 
(0-{max(experts)})\")\n\n        return result\n\n    def _convert_layer_experts(self, layer_idx: int, expert_ids: List[int]) -> Dict[str, torch.Tensor]:\n        \"\"\"Subclasses must implement expert conversion for a given layer.\n\n        Expected to return a mapping from output tensor keys to tensors.\n        \"\"\"\n        raise NotImplementedError(\"Subclasses must implement _convert_layer_experts\")\n\n    def convert(self, resume_layer: int = 0):\n        \"\"\"Convert all expert layers using subclass-specific logic.\n\n        Args:\n            resume_layer (int, optional): The layer index to resume conversion from.\n                Layers with an index lower than this will be skipped. Defaults to 0.\n        \"\"\"\n        print(\"Starting conversion...\")\n        print(f\"Input: {self.input_path}\")\n        print(f\"Output: {self.output_path}\")\n        if resume_layer > 0:\n            print(f\"Resuming from layer: {resume_layer}\")\n\n        # Create output directory\n        os.makedirs(self.output_path, exist_ok=True)\n\n        # Find all expert layers\n        expert_layers = self._find_expert_layers()\n\n        if not expert_layers:\n            print(\"No MoE expert layers found in input!\")\n            return\n\n        # Convert each layer with memory management\n        all_tensors: Dict[str, torch.Tensor] = {}\n\n        # Enable memory optimization\n        if torch.cuda.is_available():\n            torch.cuda.empty_cache()\n\n        # Process layers with memory cleanup\n        for i, (layer_idx, expert_ids) in enumerate(sorted(expert_layers.items())):\n            if layer_idx < resume_layer:\n                continue\n            print(f\"Processing layer {layer_idx} ({i+1}/{len(expert_layers)})...\")\n\n            layer_tensors = self._convert_layer_experts(layer_idx, expert_ids)\n            all_tensors.update(layer_tensors)\n\n            # Periodic garbage collection to free memory\n            if (i + 1) % 5 == 0:  # 
Every 5 layers\n                gc.collect()\n                if torch.cuda.is_available():\n                    torch.cuda.empty_cache()\n                print(f\"  Memory cleanup after layer {layer_idx}\")\n\n        if self.merge_to_safetensor:\n            # Copy non-expert tensors (embeddings, norms, etc.)\n            print(\"Copying non-expert tensors...\")\n            for key in self.tensor_file_map.keys():\n                if not (\".mlp.experts.\" in key):\n                    # Convert key format for consistency\n                    if key.startswith(\"model.\"):\n                        # Convert model.layers.X -> blk.X for non-expert layers\n                        new_key = key.replace(\"model.layers.\", \"blk.\").replace(\"model.\", \"\")\n                        all_tensors[new_key] = self._load_tensor(key)\n                    else:\n                        all_tensors[key] = self._load_tensor(key)\n\n            # Save all tensors\n            print(f\"Saving {len(all_tensors)} tensors...\")\n\n            # Split into multiple files if too large\n            max_tensors_per_file = 3000  # Adjust based on memory constraints\n            tensor_items = list(all_tensors.items())\n\n            if len(tensor_items) <= max_tensors_per_file:\n                # Single file\n                output_file = os.path.join(self.output_path, \"model.safetensors\")\n                save_file(dict(tensor_items), output_file)\n                print(f\"Saved to: {output_file}\")\n            else:\n                # Multiple files\n                for i in range(0, len(tensor_items), max_tensors_per_file):\n                    batch = dict(tensor_items[i : i + max_tensors_per_file])\n                    output_file = os.path.join(self.output_path, f\"model-{i//max_tensors_per_file + 1:05d}.safetensors\")\n                    save_file(batch, output_file)\n                    print(f\"Saved batch to: {output_file}\")\n\n            # Copy config files\n            
self._copy_config_files()\n\n            print(\"Conversion completed successfully!\")\n        else:\n            print(\"Skipping safetensor merge, weights kept in layer folder structure\")\n            print(\"Conversion completed successfully!\")\n\n    def _copy_config_files(self):\n        \"\"\"Copy configuration files to output directory\"\"\"\n        config_files = [\"config.json\", \"tokenizer.json\", \"tokenizer_config.json\", \"special_tokens_map.json\"]\n\n        for config_file in config_files:\n            src_path = os.path.join(self.input_path, config_file)\n            if os.path.exists(src_path):\n                import shutil\n\n                dst_path = os.path.join(self.output_path, config_file)\n                shutil.copy2(src_path, dst_path)\n                print(f\"Copied: {config_file}\")\n\n    def close(self):\n        \"\"\"Close all file handles\"\"\"\n        self.file_handle_map.clear()\n\n\nclass AWQToColumnMajorConverter(ConverterBase):\n    \"\"\"Convert raw AWQ safetensors to NUMA-sliced column-major format.\"\"\"\n\n    # NOTE: Only this method differs across quantization methods.\n    def _convert_layer_experts(self, layer_idx: int, expert_ids: List[int]) -> Dict[str, torch.Tensor]:\n        \"\"\"Convert all experts in a layer to column major format with optimized AWQ processing\"\"\"\n        output_tensors = {}\n\n        start_time = time.time()\n        print(f\"Converting layer {layer_idx} with {len(expert_ids)} experts...\")\n\n        # Pre-compute projection name mappings\n        proj_mappings = {\"up_proj\": \"ffn_up_exps\", \"gate_proj\": \"ffn_gate_exps\", \"down_proj\": \"ffn_down_exps\"}\n\n        # Batch process all experts to reduce nested loops\n        for proj_name, out_proj in proj_mappings.items():\n            # Load all expert tensors for this projection at once\n            expert_qweights = []\n            expert_qzeros = []\n            expert_scales = []\n            valid_experts = []\n\n      
      for expert_id in expert_ids:\n                qweight_key = f\"model.layers.{layer_idx}.mlp.experts.{expert_id}.{proj_name}.qweight\"\n                qzeros_key = f\"model.layers.{layer_idx}.mlp.experts.{expert_id}.{proj_name}.qzeros\"\n                scales_key = f\"model.layers.{layer_idx}.mlp.experts.{expert_id}.{proj_name}.scales\"\n\n                if qweight_key in self.tensor_file_map:\n                    qweight = self._load_tensor(qweight_key)\n                    qzeros = self._load_tensor(qzeros_key) if qzeros_key in self.tensor_file_map else None\n                    scales = self._load_tensor(scales_key) if scales_key in self.tensor_file_map else None\n\n                    expert_qweights.append(qweight)\n                    expert_qzeros.append(qzeros)\n                    expert_scales.append(scales)\n                    valid_experts.append(expert_id)\n\n            if not valid_experts:\n                continue\n\n            print(f\"  Processing {proj_name}: {len(valid_experts)} experts\")\n\n            qweights_stack = torch.stack([w for w in expert_qweights if w is not None], dim=0)\n            qzeros_stack = torch.stack([z for z in expert_qzeros if z is not None], dim=0)\n\n            batch_size = 128\n\n            for batch_start in range(0, len(valid_experts), batch_size):\n                batch_end = min(batch_start + batch_size, len(valid_experts))\n                qweights_batch = qweights_stack[batch_start:batch_end].to(\"cuda\")\n                qzeros_batch = qzeros_stack[batch_start:batch_end].to(\"cuda\")\n                iweights_batch, izeros_batch = unpack_reverse_awq_interleaving(qweights_batch, qzeros_batch)\n                qweights_1d_batch, qzeros_1d_batch = pack_column_major_1d(iweights_batch, izeros_batch)\n\n                for idx in range(batch_start, batch_end):\n                    expert_id = valid_experts[idx]\n                    batch_idx = idx - batch_start\n                    
output_tensors[f\"blk.{layer_idx}.{out_proj}.{expert_id}.scale\"] = expert_scales[idx].flatten()\n                    output_tensors[f\"blk.{layer_idx}.{out_proj}.{expert_id}.weight\"] = qweights_1d_batch[\n                        batch_idx\n                    ].cpu()\n                    if qzeros_1d_batch is not None:\n                        output_tensors[f\"blk.{layer_idx}.{out_proj}.{expert_id}.qzeros\"] = qzeros_1d_batch[\n                            batch_idx\n                        ].cpu()\n\n            gc.collect()\n\n        elapsed = time.time() - start_time\n        print(f\"  Generated {len(output_tensors)} column-major 1D tensors in {elapsed:.2f}s\")\n        return output_tensors\n\n\nclass OnlineQuantConverter(ConverterBase):\n    \"\"\"Convert FP8/FP16/BF16 weights to quantized format using AMXMoEWrapper.\n\n    Performs online quantization (FP8/FP16/BF16 -> INT4/INT8) using AMXMoEWrapper\n    with NUMA-aware memory management and automatic weight saving.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_path: str,\n        output_path: str,\n        model_config: Dict,\n        cpuinfer_threads: int = 60,\n        threadpool_count: int = 2,\n        input_type: str = None,\n        quant_method: str = \"int4\",\n        merge_to_safetensor: bool = True,\n    ):\n        super().__init__(\n            input_path, output_path, model_config, cpuinfer_threads, threadpool_count, input_type, merge_to_safetensor\n        )\n        self.quant_method = quant_method\n\n        # For FP8, get block size from model_config\n        if input_type == \"fp8\":\n            self.fp8_block_size = model_config.get(\"fp8_weight_block_size\", [128, 128])\n        else:\n            self.fp8_block_size = None\n\n    def _dequantize_fp8_blockwise(self, fp8_weight: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:\n        \"\"\"Dequantize FP8 weight with block-wise scaling.\n\n        Args:\n            fp8_weight: FP8 weight tensor of shape [H, 
W]\n            scale_inv: Scale inverse tensor of shape [H//block_size, W//block_size]\n\n        Returns:\n            Dequantized BF16 weight tensor of shape [H, W]\n        \"\"\"\n        H, W = fp8_weight.shape\n        num_blocks_h, num_blocks_w = scale_inv.shape\n\n        # Infer block size from shapes\n        block_h = H // num_blocks_h\n        block_w = W // num_blocks_w\n\n        # Reshape fp8_weight to [num_blocks_h, block_h, num_blocks_w, block_w]\n        fp8_reshaped = fp8_weight.view(num_blocks_h, block_h, num_blocks_w, block_w)\n\n        # Reshape scale_inv to [num_blocks_h, 1, num_blocks_w, 1] for broadcasting\n        scale_inv_reshaped = scale_inv.view(num_blocks_h, 1, num_blocks_w, 1)\n\n        # Dequantize: convert to bf16 and multiply by scale_inv\n        dequantized = fp8_reshaped.to(torch.bfloat16) * scale_inv_reshaped\n\n        # Reshape back to [H, W]\n        dequantized = dequantized.view(H, W).contiguous()\n\n        return dequantized\n\n    def _load_binary_tensor(self, file_path: str) -> torch.Tensor:\n        \"\"\"Load .kt format binary tensor file\n\n        Args:\n            file_path: Path to .kt binary file\n\n        Returns:\n            torch.Tensor: Loaded tensor\n        \"\"\"\n        if not os.path.exists(file_path):\n            raise FileNotFoundError(f\"File not found: {file_path}\")\n\n        with open(file_path, \"rb\") as f:\n            binary_data = f.read()\n\n        # Determine dtype based on file name\n        if \"scale\" in file_path:\n            # Scale tensors are typically float32\n            np_array = np.frombuffer(binary_data, dtype=np.float32)\n        else:\n            # Quant tensors are typically int8\n            np_array = np.frombuffer(binary_data, dtype=np.int8)\n\n        tensor = torch.from_numpy(np_array.copy())\n        return tensor\n\n    def _load_layer_tensors_from_disk(self, layer_idx: int) -> Dict[str, torch.Tensor]:\n        \"\"\"Load all quantized tensors from 
_layer_{layer_idx} folder\n\n        Args:\n            layer_idx: Layer index\n\n        Returns:\n            Dict[str, torch.Tensor]: Dictionary with keys in format:\n                'blk.{layer}.ffn_{proj}_exps.{expert}.numa.{numa_idx}.{weight|scale}'\n        \"\"\"\n        layer_path = os.path.join(self.output_path, f\"_layer_{layer_idx}\")\n        if not os.path.exists(layer_path):\n            raise FileNotFoundError(f\"Layer folder not found: {layer_path}\")\n\n        tensors = {}\n\n        # Get AMX method from quant_method parameter (INT4/INT8)\n        # Map quant_method to AMX_METHOD format\n        quant_to_amx_map = {\n            \"int4\": \"INT4\",\n            \"int8\": \"INT8\",\n            \"moe_int4\": \"MOE_INT4\",\n            \"moe_int8\": \"MOE_INT8\",\n        }\n        amx_method = quant_to_amx_map.get(self.quant_method, \"INT4\")\n\n        # Iterate through all NUMA folders\n        for numa_idx in range(self.threadpool_count):\n            numa_folder = os.path.join(layer_path, f\"_numa_{numa_idx}\")\n            if not os.path.exists(numa_folder):\n                print(f\"  Warning: NUMA folder not found: {numa_folder}, skipping...\")\n                continue\n\n            # Iterate through all experts\n            for expert_id in range(self.num_experts):\n                # For each projection (down, gate, up)\n                proj_mappings = [(\"down\", \"ffn_down_exps\"), (\"gate\", \"ffn_gate_exps\"), (\"up\", \"ffn_up_exps\")]\n\n                for proj_name, proj_key in proj_mappings:\n                    # Build file patterns\n                    quant_pattern = os.path.join(numa_folder, f\"{amx_method}_{proj_name}_{expert_id}_*Byte_quant_.kt\")\n                    scale_pattern = os.path.join(numa_folder, f\"{amx_method}_{proj_name}_{expert_id}_*Byte_scale_.kt\")\n\n                    # Find files using glob\n                    quant_files = glob.glob(quant_pattern)\n                    scale_files = 
glob.glob(scale_pattern)\n\n                    # Build keys (following merge_small_tensor.py format)\n                    weight_key = f\"blk.{layer_idx}.{proj_key}.{expert_id}.numa.{numa_idx}.weight\"\n                    scale_key = f\"blk.{layer_idx}.{proj_key}.{expert_id}.numa.{numa_idx}.scale\"\n\n                    # Load quant tensor\n                    if quant_files:\n                        if len(quant_files) > 1:\n                            raise ValueError(f\"Multiple quant files found: {quant_files}\")\n                        tensors[weight_key] = self._load_binary_tensor(quant_files[0])\n\n                    # Load scale tensor\n                    if scale_files:\n                        if len(scale_files) > 1:\n                            raise ValueError(f\"Multiple scale files found: {scale_files}\")\n                        tensors[scale_key] = self._load_binary_tensor(scale_files[0])\n\n        return tensors\n\n    def _remove_layer_folder(self, layer_idx: int):\n        \"\"\"Remove _layer_{layer_idx} folder and all its contents\n\n        Args:\n            layer_idx: Layer index\n        \"\"\"\n        import shutil\n\n        layer_path = os.path.join(self.output_path, f\"_layer_{layer_idx}\")\n        if os.path.exists(layer_path):\n            shutil.rmtree(layer_path)\n            print(f\"  Removed temporary folder: {layer_path}\")\n\n    def _convert_layer_experts(self, layer_idx: int, expert_ids: List[int]) -> Dict[str, torch.Tensor]:\n        \"\"\"Convert all experts in a layer using online quantization via AMXMoEWrapper\"\"\"\n        start_time = time.time()\n        print(\n            f\"Converting layer {layer_idx} with {len(expert_ids) if self.layout == 'base' else 'fused'} experts via online quantization...\"\n        )\n        # Load all expert weights for this layer\n        if self.layout == \"fused\":\n            if self.input_type not in [\"bf16\", \"fp16\"]:\n                raise ValueError(f\"Fused path 
currently supports bf16/fp16 only, got input_type={self.input_type}\")\n\n            proj_set = set()\n            prefix = f\"model.layers.{layer_idx}.mlp.experts.\"\n            for key in self.tensor_file_map.keys():\n                if key.startswith(prefix):\n                    parts = key.split(\".\")\n                    if len(parts) >= 6:\n                        proj_set.add(parts[5])\n\n            if not proj_set:\n                raise ValueError(f\"[Fused] No fused MoE experts found for layer {layer_idx} under 'model.layers'\")\n\n            projs = sorted(proj_set)\n            print(f\"  [Fused] layer {layer_idx} fused proj keys: {projs}\")\n            if len(projs) < 2:\n                raise ValueError(\n                    f\"[Fused] Expect at least 2 fused tensors (down & gate_up) in layer {layer_idx}, got {len(projs)}\"\n                )\n\n            fused_tensors = []\n            for p in projs:\n                key = f\"model.layers.{layer_idx}.mlp.experts.{p}\"\n                if key not in self.tensor_file_map:\n                    raise KeyError(f\"[Fused] Missing fused tensor {key} for layer {layer_idx}\")\n                w = self._load_tensor(key)\n                if self.input_type == \"fp16\":\n                    w = w.to(torch.bfloat16)\n                print(f\"    [Fused] tensor {p} shape: {tuple(w.shape)}\")\n                fused_tensors.append(w)\n\n            #   fused_tensors[0] : down-like, [E, I, H]\n            #   fused_tensors[1] : gate_up-like, [E, H, 2I]\n            down_fused = fused_tensors[0]\n            gate_up_fused = fused_tensors[1]\n\n            #    gate_up_fused: [E, H, 2I] -> [E, 2I, H] -> gate / up\n            if gate_up_fused.dim() != 3:\n                raise ValueError(\n                    f\"[Fused] Expect gate_up fused tensor to be 3D, got shape {tuple(gate_up_fused.shape)}\"\n                )\n            E, H, twoI = gate_up_fused.shape\n            if twoI % 2 != 0:\n                
raise ValueError(f\"[Fused] gate_up last dim (2I) not even: {twoI}\")\n            I = twoI // 2\n\n            gate_up_T = gate_up_fused.transpose(1, 2).contiguous()  # [E, 2I, H]\n            gate_proj = gate_up_T[:, :I, :]  # [E, I, H]\n            up_proj = gate_up_T[:, I:, :]  # [E, I, H]\n\n            if down_fused.dim() != 3:\n                raise ValueError(f\"[Fused] Expect down fused tensor to be 3D, got shape {tuple(down_fused.shape)}\")\n            if down_fused.shape[0] != E:\n                raise ValueError(f\"[Fused] down_fused expert dim mismatch: {down_fused.shape[0]} vs gate_up {E}\")\n            down_proj = down_fused.transpose(1, 2).contiguous()  # [E, H, I]\n            del fused_tensors\n            del gate_up_fused\n            del down_fused\n        else:\n            gate_weights = []\n            up_weights = []\n            down_weights = []\n\n            for expert_id in expert_ids:\n                gate_key = f\"model.layers.{layer_idx}.mlp.experts.{expert_id}.gate_proj.weight\"\n                up_key = f\"model.layers.{layer_idx}.mlp.experts.{expert_id}.up_proj.weight\"\n                down_key = f\"model.layers.{layer_idx}.mlp.experts.{expert_id}.down_proj.weight\"\n\n                if gate_key not in self.tensor_file_map:\n                    raise KeyError(f\"Missing gate weight for layer {layer_idx}, expert {expert_id}\")\n                if up_key not in self.tensor_file_map:\n                    raise KeyError(f\"Missing up weight for layer {layer_idx}, expert {expert_id}\")\n                if down_key not in self.tensor_file_map:\n                    raise KeyError(f\"Missing down weight for layer {layer_idx}, expert {expert_id}\")\n\n                # Load weights based on input type\n                if self.input_type == \"fp8\":\n                    # Load FP8 weights and their scale_inv tensors\n                    gate_scale_key = f\"model.layers.{layer_idx}.mlp.experts.{expert_id}.gate_proj.weight_scale_inv\"\n 
                   up_scale_key = f\"model.layers.{layer_idx}.mlp.experts.{expert_id}.up_proj.weight_scale_inv\"\n                    down_scale_key = f\"model.layers.{layer_idx}.mlp.experts.{expert_id}.down_proj.weight_scale_inv\"\n\n                    if gate_scale_key not in self.tensor_file_map:\n                        raise KeyError(f\"Missing gate weight_scale_inv for layer {layer_idx}, expert {expert_id}\")\n                    if up_scale_key not in self.tensor_file_map:\n                        raise KeyError(f\"Missing up weight_scale_inv for layer {layer_idx}, expert {expert_id}\")\n                    if down_scale_key not in self.tensor_file_map:\n                        raise KeyError(f\"Missing down weight_scale_inv for layer {layer_idx}, expert {expert_id}\")\n\n                    # Load FP8 weights and scales\n                    gate_fp8 = self._load_tensor(gate_key).to(\"cuda\")\n                    up_fp8 = self._load_tensor(up_key).to(\"cuda\")\n                    down_fp8 = self._load_tensor(down_key).to(\"cuda\")\n\n                    gate_scale_inv = self._load_tensor(gate_scale_key).to(\"cuda\")\n                    up_scale_inv = self._load_tensor(up_scale_key).to(\"cuda\")\n                    down_scale_inv = self._load_tensor(down_scale_key).to(\"cuda\")\n\n                    # Dequantize FP8 to BF16 using block-wise scaling\n                    gate_weight = weight_dequant(gate_fp8, gate_scale_inv).to(\"cpu\").to(torch.bfloat16).contiguous()\n                    up_weight = weight_dequant(up_fp8, up_scale_inv).to(\"cpu\").to(torch.bfloat16).contiguous()\n                    down_weight = weight_dequant(down_fp8, down_scale_inv).to(\"cpu\").to(torch.bfloat16).contiguous()\n\n                elif self.input_type == \"fp16\":\n                    # Load FP16 and convert to BF16\n                    gate_weight = self._load_tensor(gate_key).to(torch.bfloat16)\n                    up_weight = 
self._load_tensor(up_key).to(torch.bfloat16)\n                    down_weight = self._load_tensor(down_key).to(torch.bfloat16)\n\n                elif self.input_type == \"bf16\":\n                    # Load BF16 directly\n                    gate_weight = self._load_tensor(gate_key)\n                    up_weight = self._load_tensor(up_key)\n                    down_weight = self._load_tensor(down_key)\n\n                else:\n                    raise ValueError(f\"Unsupported input_type for INT4 conversion: {self.input_type}\")\n\n                gate_weights.append(gate_weight)\n                up_weights.append(up_weight)\n                down_weights.append(down_weight)\n\n            # Stack weights into single tensors: [num_experts, ...]\n            gate_proj = torch.stack(gate_weights, dim=0).contiguous()\n            up_proj = torch.stack(up_weights, dim=0).contiguous()\n            down_proj = torch.stack(down_weights, dim=0).contiguous()\n            del gate_weights, up_weights, down_weights\n\n        print(f\"  Loaded weights shapes:\")\n        print(f\"    gate_proj: {gate_proj.shape}\")\n        print(f\"    up_proj: {up_proj.shape}\")\n        print(f\"    down_proj: {down_proj.shape}\")\n\n        # Create physical_to_logical_map: identity mapping where position i maps to expert i\n        physical_to_logical_map = torch.arange(self.num_experts, dtype=torch.int64)\n\n        # Map quant_method to AMX method format\n        quant_to_amx_map = {\n            \"int4\": \"AMXINT4\",\n            \"int8\": \"AMXINT8\",\n            \"moe_int4\": \"MOE_INT4\",\n            \"moe_int8\": \"MOE_INT8\",\n        }\n        amx_method = quant_to_amx_map.get(self.quant_method, \"AMXINT4\")\n\n        # Create KTMoEWrapper instance for this layer\n        # gpu_experts_mask: all False means all experts are on CPU for conversion\n        gpu_experts_mask = torch.zeros(self.num_experts, dtype=torch.bool)\n        wrapper = KTMoEWrapper(\n            
layer_idx=layer_idx,\n            num_experts=self.num_experts,\n            num_experts_per_tok=self.num_experts_per_tok,\n            hidden_size=self.hidden_size,\n            moe_intermediate_size=self.moe_intermediate_size,\n            gpu_experts_mask=gpu_experts_mask,  # All experts on CPU for conversion\n            cpuinfer_threads=self.cpuinfer_threads,\n            threadpool_count=self.threadpool_count,\n            weight_path=self.output_path,  # Output path for quantized weights\n            chunked_prefill_size=512,  # Arbitrary value, not critical for conversion\n            cpu_save=True,  # Enable saving quantized weights to output\n            method=amx_method,  # Specify quantization method (AMXINT4 or AMXINT8)\n        )\n\n        # Load and quantize weights from tensors\n        # This triggers the quantization process and saves to disk\n        wrapper.load_weights_from_tensors(gate_proj, up_proj, down_proj, physical_to_logical_map)\n\n        # Clean up to free memory\n        del gate_proj, up_proj, down_proj\n        gc.collect()\n\n        elapsed = time.time() - start_time\n\n        if self.merge_to_safetensor:\n            # Load quantized tensors from disk\n            print(f\"  Loading quantized tensors from disk...\")\n            layer_tensors = self._load_layer_tensors_from_disk(layer_idx)\n            print(f\"  Loaded {len(layer_tensors)} tensors\")\n\n            # Remove temporary layer folder\n            self._remove_layer_folder(layer_idx)\n\n            print(f\"  Layer {layer_idx} quantized and saved in {elapsed:.2f}s\")\n\n            # Return loaded tensors\n            return layer_tensors\n        else:\n            # Keep layer folders, return empty dict\n            print(f\"  Layer {layer_idx} quantized and saved in {elapsed:.2f}s\")\n            print(f\"  Keeping layer folder structure at {self.output_path}/_layer_{layer_idx}\")\n            return {}\n\n\n\"\"\"\nExample usage(test passed):\npython 
convert_cpu_weights.py --input-path /mnt/data3/models/DeepSeek-R1-0528/ --input-type fp8 --output /mnt/data3/models/DeepSeek-R1-0528-INT4-test --quant-method int4 --cpuinfer-threads 60 --threadpool-count 2\npython convert_cpu_weights.py --input-path /mnt/data3/models/DeepSeek-R1-0528/ --input-type fp8 --output /mnt/data3/models/DeepSeek-R1-0528-INT8-test --quant-method int8 --cpuinfer-threads 60 --threadpool-count 2\npython convert_cpu_weights.py --input-path /mnt/data2/models/Qwen3-Next-80B-A3B-Instruct --input-type bf16 --output /mnt/data2/models/Qwen3-Next-80B-A3B-Instruct-INT4-test --quant-method int4 --cpuinfer-threads 60 --threadpool-count 2\n\"\"\"\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Convert SafeTensors to column major 1D format\")\n    parser.add_argument(\"--input-path\", \"-i\", required=True, help=\"Input directory with safetensors\")\n    parser.add_argument(\n        \"--input-type\",\n        choices=[\"awq\", \"fp8\", \"fp16\", \"bf16\"],\n        required=True,\n        help=\"Input weight type (awq/fp8/fp16/bf16)\",\n    )\n    parser.add_argument(\"--output\", \"-o\", required=True, help=\"Output directory for converted safetensors\")\n    parser.add_argument(\n        \"--quant-method\",\n        choices=[\"int4\", \"int8\", \"awq\", \"moe_int4\", \"moe_int8\"],\n        default=\"int4\",\n        help=\"Quantization method for output (default: int4)\",\n    )\n    parser.add_argument(\n        \"--cpuinfer-threads\",\n        type=int,\n        default=60,\n        help=\"Number of CPU inference threads (default: 60)\",\n    )\n    parser.add_argument(\n        \"--threadpool-count\",\n        type=int,\n        default=2,\n        help=\"Number of NUMA subpools for thread distribution (default: 2)\",\n    )\n    parser.add_argument(\"--gpu\", action=\"store_true\", help=\"Use GPU for conversion if available\")\n    parser.add_argument(\n        \"--no-merge-safetensor\",\n        action=\"store_true\",\n        
default=False,\n        help=\"Keep layer folders without merging to safetensor files (default: False)\",\n    )\n    parser.add_argument(\n        \"--resume-layer\",\n        type=int,\n        default=0,\n        help=\"Resume conversion starting at this layer index (default: 0)\",\n    )\n\n    args = parser.parse_args()\n\n    # Validate inputs\n    if not os.path.exists(args.input_path):\n        print(f\"Error: Input path does not exist: {args.input_path}\")\n        return 1\n    try:\n        # Load model configuration from config.json\n        print(\"Loading model configuration...\")\n        model_config = load_model_config(args.input_path, args.input_type)\n        print(f\"Model config: {model_config}\")\n        print(f\"  num_experts: {model_config['num_experts']}\")\n        print(f\"  num_experts_per_tok: {model_config['num_experts_per_tok']}\")\n        print(f\"  hidden_size: {model_config['hidden_size']}\")\n        print(f\"  moe_intermediate_size: {model_config['moe_intermediate_size']}\")\n        print(f\"CPU inference config:\")\n        print(f\"  cpuinfer_threads: {args.cpuinfer_threads}\")\n        print(f\"  threadpool_count: {args.threadpool_count}\")\n        print()\n\n        # Create converter by quantization method\n        quant_method = args.quant_method.lower()\n        merge_to_safetensor = not args.no_merge_safetensor\n\n        if quant_method == \"awq\":\n            converter = AWQToColumnMajorConverter(\n                args.input_path,\n                args.output,\n                model_config,\n                args.cpuinfer_threads,\n                args.threadpool_count,\n                input_type=None,\n                merge_to_safetensor=merge_to_safetensor,\n            )\n        elif quant_method in [\"int4\", \"int8\", \"moe_int4\", \"moe_int8\"] and args.input_type in [\"fp8\", \"fp16\", \"bf16\"]:\n            # Use OnlineQuantConverter for both INT4 and INT8 quantization\n            converter = 
OnlineQuantConverter(\n                args.input_path,\n                args.output,\n                model_config,\n                args.cpuinfer_threads,\n                args.threadpool_count,\n                args.input_type,\n                quant_method,\n                merge_to_safetensor,\n            )\n        else:\n            raise ValueError(\n                f\"Unsupported quant_method: {args.quant_method} or incompatible input_type: {args.input_type}\"\n            )\n\n        # Run conversion\n        converter.convert(resume_layer=args.resume_layer)\n\n        # Cleanup\n        converter.close()\n        return 0\n\n    except Exception as e:\n        print(f\"Error during conversion: {e}\")\n        import traceback\n\n        traceback.print_exc()\n        return 1\n\n\nif __name__ == \"__main__\":\n    exit(main())\n"
  },
  {
    "path": "kt-kernel/scripts/convert_gpu_weights.py",
    "content": "#!/usr/bin/env python\n\"\"\"\nGPU Weight Quantization Tool for KTransformers\n\nThis script quantizes model weights for CPU-GPU hybrid inference when integrating\nKTransformers with SGLang. It supports multiple quantization methods (GPTQ, RTN) and\napplies selective quantization to GPU-resident layers while preserving certain\ncomponents (e.g., attention, gates, shared experts) in higher precision.\n\nUsage:\n    python convert_gpu_weights.py --model_id /path/to/model --output_dir /path/to/output --quant_method GPTQ --quant_type W4A16\n\nExample (GPTQ with calibration for best accuracy):\n    python convert_gpu_weights.py \\\n        --model_id /mnt/data2/models/Qwen3-Next-80B-A3B-Instruct \\\n        --output_dir /mnt/data2/models/Qwen3-Next-80B-A3B-Instruct-GPU-weight \\\n        --quant_method GPTQ \\\n        --quant_type W4A16\n\nExample (RTN for fast quantization without calibration):\n    python convert_gpu_weights.py \\\n        --model_id /mnt/data/models/GLM-4.5-Air \\\n        --output_dir /mnt/data/models/GLM-4.5-Air-GPU-weights-rtn \\\n        --quant_method RTN \\\n        --quant_type W4A16\n\"\"\"\n\nimport os\nimport sys\nimport warnings\nimport argparse\n\n# IMPORTANT: Parse force_cpu argument BEFORE importing torch\n# CUDA_VISIBLE_DEVICES must be set before torch initializes CUDA\nif __name__ == \"__main__\":\n    # Quick check for --force_cpu flag before full argument parsing\n    if \"--force_cpu\" in sys.argv:\n        os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n        warnings.filterwarnings(\"ignore\", message=\"Can't initialize NVML\")\n        print(\"🔧 Forced CPU-only mode (CUDA_VISIBLE_DEVICES set before torch import)\")\n\n# Now it's safe to import torch and other GPU-dependent libraries\nimport torch\nfrom accelerate import init_empty_weights, infer_auto_device_map\nfrom transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig\nfrom llmcompressor import oneshot\nfrom 
llmcompressor.modifiers.quantization.gptq import GPTQModifier\nfrom llmcompressor.modifiers.quantization import QuantizationModifier\nfrom datasets import load_dataset\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Quantize MoE models with selective quantization\")\n\n    # Required arguments\n    parser.add_argument(\"--model_id\", type=str, required=True, help=\"Path to the input model directory\")\n    parser.add_argument(\"--output_dir\", type=str, required=True, help=\"Path to save the quantized model\")\n\n    # Optional arguments\n    parser.add_argument(\n        \"--quant_method\",\n        type=str,\n        choices=[\"GPTQ\", \"RTN\"],\n        default=\"GPTQ\",\n        help=\"Quantization method: GPTQ (calibration-based) or RTN (round-to-nearest, no calibration). Default: GPTQ\",\n    )\n    parser.add_argument(\n        \"--quant_type\",\n        type=str,\n        choices=[\"W4A16\", \"W8A16\"],\n        default=\"W8A16\",\n        help=\"Quantization type: W4A16 (INT4) or W8A16 (INT8). Default: W8A16\",\n    )\n    parser.add_argument(\n        \"--num_calibration_samples\",\n        type=int,\n        default=512,\n        help=\"Number of calibration samples (GPTQ only). Default: 512\",\n    )\n    parser.add_argument(\n        \"--max_sequence_length\",\n        type=int,\n        default=2048,\n        help=\"Maximum sequence length for calibration (GPTQ only). Default: 2048\",\n    )\n    parser.add_argument(\n        \"--dampening_frac\",\n        type=float,\n        default=0.1,\n        help=\"Dampening fraction to mitigate quantization noise (GPTQ only). Default: 0.1\",\n    )\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"HuggingFaceH4/ultrachat_200k\",\n        help=\"Dataset for calibration (GPTQ only). 
Default: HuggingFaceH4/ultrachat_200k\",\n    )\n    parser.add_argument(\n        \"--dataset_split\", type=str, default=\"train_sft\", help=\"Dataset split to use (GPTQ only). Default: train_sft\"\n    )\n    parser.add_argument(\n        \"--force_cpu\", action=\"store_true\", help=\"Force all computations to CPU (sets CUDA_VISIBLE_DEVICES='')\"\n    )\n    parser.add_argument(\n        \"--ignore_patterns\",\n        type=str,\n        nargs=\"*\",\n        default=[\n            \"lm_head\",\n            r\"re:.*\\.mlp\\.gate$\",\n            r\"re:.*\\.self_attn\\..*$\",\n            r\"re:.*\\.shared_expert\\..*$\",\n            r\"re:.*\\.shared_experts\\..*$\",\n            r\"re:.*\\.mlp\\.shared_expert_gate$\",\n            r\"re:.*\\.linear_attn\\..*$\",\n        ],\n        help=\"Regex patterns for layers to ignore during quantization\",\n    )\n    parser.add_argument(\n        \"--torch_dtype\",\n        type=str,\n        choices=[\"bfloat16\", \"float16\", \"float32\"],\n        default=\"bfloat16\",\n        help=\"PyTorch dtype for model loading. Default: bfloat16\",\n    )\n    parser.add_argument(\n        \"--trust_remote_code\", action=\"store_true\", help=\"Allow loading of remote code (required for some models)\"\n    )\n    parser.add_argument(\"--random_seed\", type=int, default=42, help=\"Random seed for dataset shuffling. Default: 42\")\n    parser.add_argument(\n        \"--max_gpu_memory\",\n        type=str,\n        default=None,\n        help=\"Maximum GPU memory for model weights per device (e.g., '40GiB'). \"\n        \"GPTQ quantization requires additional GPU memory for Hessian matrix computation, \"\n        \"so reserve 40-50%% of total VRAM. For example, use '40GiB' on 80GB GPUs. \"\n        \"Remaining layers will be offloaded to CPU. 
Default: use all available\",\n    )\n    parser.add_argument(\n        \"--max_cpu_memory\",\n        type=str,\n        default=None,\n        help=\"Maximum CPU memory to use (e.g., '100GiB'). Default: use all available\",\n    )\n\n    return parser.parse_args()\n\n\ndef setup_environment(force_cpu=False):\n    \"\"\"\n    Verify environment setup (actual setup happens before torch import).\n\n    Args:\n        force_cpu: If True, was requested to force CPU-only mode\n\n    Note:\n        CUDA_VISIBLE_DEVICES must be set BEFORE importing torch.\n        The actual environment setup is done at module import time.\n    \"\"\"\n    if force_cpu:\n        # Verify the environment variable was set correctly\n        cuda_visible = os.environ.get(\"CUDA_VISIBLE_DEVICES\", None)\n        if cuda_visible != \"\":\n            print(\"⚠️  Warning: force_cpu was requested but CUDA_VISIBLE_DEVICES is not empty\")\n            print(f\"   Current value: '{cuda_visible}'\")\n            print(\"   This may happen if imported as a module. 
Recommend running as script.\")\n        else:\n            print(\"✅ CPU-only mode verified (CUDA_VISIBLE_DEVICES is empty)\")\n\n\ndef get_torch_dtype(dtype_str):\n    \"\"\"\n    Convert string to torch dtype.\n\n    Args:\n        dtype_str: String representation of dtype (\"bfloat16\", \"float16\", \"float32\")\n\n    Returns:\n        torch.dtype: Corresponding PyTorch dtype\n    \"\"\"\n    dtype_map = {\"bfloat16\": torch.bfloat16, \"float16\": torch.float16, \"float32\": torch.float32}\n    return dtype_map[dtype_str]\n\n\ndef check_dense_layers_and_update_ignore(model_id, ignore_patterns, trust_remote_code=False):\n    \"\"\"\n    Check if the model has dense layers (first_k_dense_replace parameter) and add them to ignore list.\n\n    Some MoE models have dense MLP layers in the first few layers instead of MoE layers.\n    These dense layers should not be quantized using the same scheme as expert layers.\n\n    Args:\n        model_id: Path to the model\n        ignore_patterns: List of existing ignore patterns\n        trust_remote_code: Whether to trust remote code\n\n    Returns:\n        Updated ignore_patterns list with dense layer patterns added\n    \"\"\"\n    print(\"🔍 Checking model configuration for dense layers...\")\n\n    try:\n        # Load model configuration\n        config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)\n\n        # Check if the model has first_k_dense_replace parameter\n        first_k_dense_replace = getattr(config, \"first_k_dense_replace\", None)\n\n        if first_k_dense_replace is not None and first_k_dense_replace > 0:\n            print(f\"✅ Found dense layers configuration: first_k_dense_replace = {first_k_dense_replace}\")\n            print(f\"   Adding first {first_k_dense_replace} layers to ignore list...\")\n\n            # Create regex pattern for dense layers (layers 0 to first_k_dense_replace-1)\n            if first_k_dense_replace == 1:\n                dense_pattern = 
r\"re:model\\.layers\\.0\\.mlp\\..*$\"\n            else:\n                # For multiple layers, use an explicit alternation of layer indices.\n                # A character class such as [0-11] would only match single digits.\n                layer_range = \"(?:\" + \"|\".join(str(i) for i in range(first_k_dense_replace)) + \")\"\n                dense_pattern = f\"re:model\\\\.layers\\\\.{layer_range}\\\\.mlp\\\\..*$\"\n\n            # Add the dense layer pattern to ignore list\n            updated_ignore_patterns = ignore_patterns + [dense_pattern]\n\n            print(f\"   Dense layer pattern added: {dense_pattern}\")\n            print(f\"   This will ignore MLP components in layers 0-{first_k_dense_replace-1}\")\n\n            return updated_ignore_patterns\n        else:\n            print(\"ℹ️  No dense layers detected (first_k_dense_replace not found or is 0)\")\n            return ignore_patterns\n\n    except Exception as e:\n        print(f\"⚠️  Warning: Could not check model config for dense layers: {e}\")\n        print(\"   Proceeding with original ignore patterns...\")\n        return ignore_patterns\n\n\ndef load_and_prepare_dataset(dataset_name, dataset_split, num_samples, max_length, tokenizer, seed=42):\n    \"\"\"\n    Load and prepare calibration dataset for GPTQ quantization.\n\n    GPTQ requires calibration data to compute optimal quantization parameters.\n    This function loads a conversation dataset, applies chat template, and tokenizes it.\n\n    Args:\n        dataset_name: HuggingFace dataset name\n        dataset_split: Dataset split to use (e.g., \"train_sft\")\n        num_samples: Number of samples to use for calibration\n        max_length: Maximum sequence length for tokenization\n        tokenizer: Model tokenizer\n        seed: Random seed for shuffling\n\n    Returns:\n        Dataset with tokenized calibration samples\n    \"\"\"\n    print(f\"📊 Loading dataset: {dataset_name}\")\n\n    # Load dataset\n    ds = load_dataset(dataset_name, split=f\"{dataset_split}[:{num_samples}]\")\n    ds = ds.shuffle(seed=seed)\n\n    # Preprocess the data into the format the model is 
trained with\n    def preprocess(example):\n        return {\"text\": tokenizer.apply_chat_template(example[\"messages\"], tokenize=False)}\n\n    ds = ds.map(preprocess)\n\n    # Tokenize the data\n    def tokenize(sample):\n        return tokenizer(\n            sample[\"text\"], padding=False, max_length=max_length, truncation=True, add_special_tokens=False\n        )\n\n    ds = ds.map(tokenize, remove_columns=ds.column_names)\n    print(f\"✅ Dataset prepared with {len(ds)} samples\")\n\n    return ds\n\n\ndef main():\n    \"\"\"\n    Main function for GPU weight quantization.\n\n    This performs weight quantization on model weights intended for GPU execution\n    in CPU-GPU hybrid inference scenarios. Supports two quantization methods:\n\n    1. GPTQ (default): Calibration-based quantization for better accuracy\n       - Requires calibration dataset\n       - Higher accuracy but slower\n       - Recommended for production use\n\n    2. RTN (Round-To-Nearest): Fast quantization without calibration\n       - No calibration dataset needed\n       - Faster but may have lower accuracy\n       - Good for quick testing or prototyping\n\n    The quantization is selective:\n    - Expert MLP weights are quantized to INT4/INT8\n    - Attention layers, gates, and shared experts remain in original precision\n    - Dense layers (if present) are excluded from quantization\n\n    The quantized model can be used with SGLang+KTransformers for heterogeneous\n    inference, where \"hot\" experts run on GPU and \"cold\" experts run on CPU.\n    \"\"\"\n    args = parse_args()\n\n    # Setup environment\n    setup_environment(args.force_cpu)\n\n    # Convert torch dtype\n    torch_dtype = get_torch_dtype(args.torch_dtype)\n\n    print(f\"🚀 Starting quantization process\")\n    print(f\"   Model: {args.model_id}\")\n    print(f\"   Output: {args.output_dir}\")\n    print(f\"   Quantization method: {args.quant_method}\")\n    print(f\"   Quantization type: {args.quant_type}\")\n    
if args.quant_method == \"GPTQ\":\n        print(f\"   Calibration samples: {args.num_calibration_samples}\")\n        print(f\"   Max sequence length: {args.max_sequence_length}\")\n    else:\n        print(f\"   Calibration: Not required for {args.quant_method}\")\n\n    # --------------------------------------------------------------------\n    # 0) Check for dense layers and update ignore patterns\n    # Dense layers in the first few layers should not be quantized\n    updated_ignore_patterns = check_dense_layers_and_update_ignore(\n        args.model_id, args.ignore_patterns, args.trust_remote_code\n    )\n\n    # --------------------------------------------------------------------\n    # 1) Build a dummy model (no weights) to infer a device map\n    # This determines optimal device placement for each module\n    if args.force_cpu:\n        # In force_cpu mode, directly get module names without calling infer_auto_device_map\n        # to avoid GPU memory allocation\n        print(\"🔍 Building CPU-only device map...\")\n        with init_empty_weights():\n            dummy = AutoModelForCausalLM.from_pretrained(\n                args.model_id, torch_dtype=torch_dtype, trust_remote_code=args.trust_remote_code\n            )\n            device_map = {name: \"cpu\" for name, _ in dummy.named_modules() if name}\n            del dummy\n    else:\n        print(\"🔍 Inferring device map...\")\n        with init_empty_weights():\n            dummy = AutoModelForCausalLM.from_pretrained(\n                args.model_id, torch_dtype=torch_dtype, trust_remote_code=args.trust_remote_code\n            )\n            # Build max_memory dict if specified\n            max_memory = None\n            if args.max_gpu_memory or args.max_cpu_memory:\n                max_memory = {}\n                if args.max_gpu_memory:\n                    # Apply to all available GPUs\n                    num_gpus = torch.cuda.device_count()\n                    for i in range(num_gpus):\n      
                  max_memory[i] = args.max_gpu_memory\n                    print(f\"   GPU memory limit: {args.max_gpu_memory} per device ({num_gpus} GPUs)\")\n\n                # Always set CPU memory when max_memory is used\n                # Otherwise infer_auto_device_map may trigger disk offloading\n                if args.max_cpu_memory:\n                    max_memory[\"cpu\"] = args.max_cpu_memory\n                    print(f\"   CPU memory limit: {args.max_cpu_memory}\")\n                else:\n                    # Use a very large value to allow using all available CPU memory\n                    # This prevents disk offloading when user has enough RAM\n                    max_memory[\"cpu\"] = \"1000GiB\"\n                    print(f\"   CPU memory limit: 1000GiB (default, to prevent disk offloading)\")\n\n            device_map = infer_auto_device_map(\n                dummy, no_split_module_classes=dummy._no_split_modules, max_memory=max_memory\n            )\n\n            # Check if disk offloading was triggered (not supported by llmcompressor)\n            disk_modules = [k for k, v in device_map.items() if v == \"disk\"]\n            if disk_modules:\n                print(f\"❌ Error: {len(disk_modules)} modules would be offloaded to disk.\")\n                print(\"   llmcompressor does not support disk offloading.\")\n                print(\"   Solutions:\")\n                print(\"   1. Increase --max_gpu_memory to use more GPU memory\")\n                print(\"   2. Add --max_cpu_memory with higher value (e.g., '200GiB')\")\n                print(\"   3. Ensure your machine has enough GPU + CPU memory\")\n                raise RuntimeError(\n                    \"Disk offloading is not supported by llmcompressor. 
\"\n                    \"Please ensure you have enough GPU + CPU memory.\"\n                )\n\n            del dummy\n    # --------------------------------------------------------------------\n    # 2) Load the full model weights with device mapping\n    # Note: offload_folder=None disables disk offloading (not supported by llmcompressor)\n    print(\"📥 Loading model...\")\n    try:\n        model = AutoModelForCausalLM.from_pretrained(\n            args.model_id,\n            device_map=device_map,\n            torch_dtype=torch_dtype,\n            trust_remote_code=args.trust_remote_code,\n            offload_folder=None,  # Disable disk offloading (not supported by llmcompressor)\n        )\n    except Exception as e:\n        if \"disk\" in str(e).lower() or \"offload\" in str(e).lower():\n            print(f\"❌ Error: Not enough GPU + CPU memory to load the model.\")\n            print(\"   llmcompressor does not support disk offloading.\")\n            print(\"   Solutions:\")\n            print(\"   1. Increase --max_gpu_memory to use more GPU memory\")\n            print(\"   2. Ensure you have enough CPU RAM for remaining layers\")\n            print(\"   3. 
Use a machine with more memory\")\n            raise\n        raise\n\n    # Pass trust_remote_code so models with custom tokenizer code load consistently with the model\n    tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=args.trust_remote_code)\n\n    # --------------------------------------------------------------------\n    # 3) Prepare calibration dataset\n    # GPTQ needs calibration data to compute optimal quantization parameters\n    if args.quant_method == \"GPTQ\":\n        ds = load_and_prepare_dataset(\n            args.dataset,\n            args.dataset_split,\n            args.num_calibration_samples,\n            args.max_sequence_length,\n            tokenizer,\n            args.random_seed,\n        )\n\n    # --------------------------------------------------------------------\n    # 4) Create quantization recipe with selective layer exclusion\n    print(f\"⚙️  Setting up {args.quant_method} {args.quant_type} quantization recipe...\")\n    if args.quant_method == \"GPTQ\":\n        # GPTQ: calibration-based quantization for better accuracy\n        recipe = GPTQModifier(\n            targets=\"Linear\",  # Target all Linear layers\n            scheme=args.quant_type,  # W4A16 or W8A16\n            ignore=updated_ignore_patterns,  # Exclude specific patterns\n            dampening_frac=args.dampening_frac,\n        )\n    elif args.quant_method == \"RTN\":\n        # RTN (Round-To-Nearest): fast quantization without calibration\n        recipe = QuantizationModifier(\n            targets=\"Linear\",  # Target all Linear layers\n            scheme=args.quant_type,  # W4A16 or W8A16\n            ignore=updated_ignore_patterns,  # Exclude specific patterns\n        )\n    else:\n        raise ValueError(f\"Unsupported quantization method: {args.quant_method}\")\n\n    print(\"🔧 Excluding the following patterns from quantization:\")\n    for i, pattern in enumerate(updated_ignore_patterns):\n        marker = \"🆕\" if i >= len(args.ignore_patterns) else \"   \"\n        print(f\"   {marker} {pattern}\")\n\n    # 
--------------------------------------------------------------------\n    # 5) Perform one-shot quantization\n    # GPTQ: calibration-based quantization to minimize accuracy loss\n    # RTN: fast round-to-nearest quantization without calibration\n    print(\"🎯 Starting one-shot quantization...\")\n    if args.quant_method == \"GPTQ\":\n        # GPTQ requires calibration dataset\n        oneshot(\n            model=model,\n            dataset=ds,\n            recipe=recipe,\n            output_dir=args.output_dir,\n            max_seq_length=args.max_sequence_length,\n            num_calibration_samples=args.num_calibration_samples,\n        )\n    elif args.quant_method == \"RTN\":\n        # RTN does not require calibration dataset\n        oneshot(\n            model=model,\n            recipe=recipe,\n            output_dir=args.output_dir,\n        )\n    else:\n        raise ValueError(f\"Unsupported quantization method: {args.quant_method}\")\n\n    print(f\"\\n✅ Quantized model written to: {args.output_dir}\")\n    print(f\"   Quantization method: {args.quant_method}\")\n    print(f\"   Quantization type: {args.quant_type}\")\n    print(f\"   Ignored patterns remain in {args.torch_dtype}\")\n    print(\"🎉 Quantization completed successfully!\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "kt-kernel/scripts/convert_kimi_k2_fp8_to_bf16_cpu.py",
    "content": "import os\nimport json\nfrom argparse import ArgumentParser\nfrom glob import glob\nfrom tqdm import tqdm\n\nimport torch\nfrom safetensors.torch import load_file, save_file\n\nimport gc\n\n\ndef weight_dequant_cpu(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:\n    assert x.dim() == 2 and s.dim() == 2, \"Expect 2D tensors for x and s\"\n    M, N = x.shape\n    n_m = (M + block_size - 1) // block_size\n    n_n = (N + block_size - 1) // block_size\n\n    y = torch.empty((M, N), dtype=torch.bfloat16, device=\"cpu\")\n    for bm in range(n_m):\n        m0 = bm * block_size\n        m1 = min(m0 + block_size, M)\n        for bn in range(n_n):\n            n0 = bn * block_size\n            n1 = min(n0 + block_size, N)\n            scale = s[bm, bn].item()\n            sub = x[m0:m1, n0:n1].to(torch.float32) * scale\n            y[m0:m1, n0:n1] = sub.to(torch.bfloat16)\n    return y\n\n\ndef main(fp8_path, bf16_path):\n    torch.set_default_dtype(torch.bfloat16)\n    os.makedirs(bf16_path, exist_ok=True)\n    model_index_file = os.path.join(fp8_path, \"model.safetensors.index.json\")\n    with open(model_index_file, \"r\") as f:\n        model_index = json.load(f)\n    weight_map = model_index[\"weight_map\"]\n\n    loaded_files = {}\n    fp8_weight_names = []\n\n    def get_tensor(tensor_name):\n        file_name = weight_map[tensor_name]\n        if file_name not in loaded_files:\n            file_path = os.path.join(fp8_path, file_name)\n            loaded_files[file_name] = load_file(file_path, device=\"cpu\")\n        return loaded_files[file_name][tensor_name]\n\n    safetensor_files = list(glob(os.path.join(fp8_path, \"*.safetensors\")))\n    safetensor_files.sort()\n    for safetensor_file in tqdm(safetensor_files, desc=\"weight file convert\"):\n        file_name = os.path.basename(safetensor_file)\n        current_state_dict = load_file(safetensor_file, device=\"cpu\")\n        loaded_files[file_name] = 
current_state_dict\n\n        new_state_dict = {}\n        for weight_name, weight in current_state_dict.items():\n            if weight_name.endswith(\"_scale_inv\"):\n                continue\n            elif weight.element_size() == 1:\n                scale_inv_name = f\"{weight_name}_scale_inv\"\n                try:\n                    scale_inv = get_tensor(scale_inv_name)\n                    fp8_weight_names.append(weight_name)\n                    new_state_dict[weight_name] = weight_dequant_cpu(weight, scale_inv)\n                except KeyError:\n                    print(f\"Warning: {weight_name} is missing its scale factor; keeping the original tensor\")\n                    new_state_dict[weight_name] = weight\n            else:\n                new_state_dict[weight_name] = weight\n\n        new_safetensor_file = os.path.join(bf16_path, file_name)\n        save_file(new_state_dict, new_safetensor_file)\n\n        if len(loaded_files) > 2:\n            oldest_file = next(iter(loaded_files))\n            del loaded_files[oldest_file]\n            gc.collect()\n            if torch.cuda.is_available():\n                torch.cuda.empty_cache()\n\n    new_model_index_file = os.path.join(bf16_path, \"model.safetensors.index.json\")\n    for weight_name in fp8_weight_names:\n        scale_inv_name = f\"{weight_name}_scale_inv\"\n        if scale_inv_name in weight_map:\n            weight_map.pop(scale_inv_name)\n    with open(new_model_index_file, \"w\") as f:\n        json.dump({\"metadata\": {}, \"weight_map\": weight_map}, f, indent=2)\n    print(f\"Finished. Results written to: {bf16_path}\")\n\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser()\n    parser.add_argument(\"--input-fp8-hf-path\", type=str, required=True, help=\"Path to the Kimi-K2 FP8 model directory\")\n    parser.add_argument(\"--output-bf16-hf-path\", type=str, required=True, help=\"Output directory for the converted BF16 model\")\n    args = parser.parse_args()\n    main(args.input_fp8_hf_path, args.output_bf16_hf_path)\n"
  },
  {
    "path": "kt-kernel/scripts/convert_moe_to_bf16.py",
    "content": "import argparse\nimport json\nimport os\nfrom collections import defaultdict\nfrom typing import Dict, Iterable, List, Optional, Tuple\n\nimport torch\nfrom safetensors.torch import save_file, safe_open\n\nfrom compressed_tensors.compressors import unpack_from_int32\n\n\ndef _load_config(model_dir: str, config_path: Optional[str]) -> Tuple[int, int, int]:\n    cfg_path = config_path or os.path.join(model_dir, \"config.json\")\n    with open(cfg_path, \"r\") as f:\n        cfg = json.load(f)\n    hidden_size = int(cfg.get(\"hidden_size\"))\n    inter_size = int(cfg.get(\"moe_intermediate_size\"))\n    group_size = int(\n        cfg.get(\"quantization_config\", {})\n        .get(\"config_groups\", {})\n        .get(\"group_0\", {})\n        .get(\"weights\", {})\n        .get(\"group_size\", 32)\n    )\n    return hidden_size, inter_size, group_size\n\n\ndef _dequantize_tensor(\n    weight_packed: torch.Tensor,\n    weight_scale: torch.Tensor,\n    weight_shape: torch.Tensor,\n    group_size: int,\n) -> torch.Tensor:\n    if isinstance(weight_shape, torch.Tensor):\n        shape = tuple(int(v) for v in weight_shape.view(-1).tolist())\n    else:\n        shape = tuple(weight_shape)\n    weight = unpack_from_int32(weight_packed, 4, shape)\n    if group_size > 0:\n        scale = weight_scale.to(torch.float32)\n        if scale.dim() == 1:\n            scale = scale.unsqueeze(1)\n        scales = torch.repeat_interleave(scale, repeats=group_size, dim=1)\n    else:\n        scales = weight_scale.to(torch.float32)\n    if scales.shape != weight.shape:\n        if scales.numel() == weight.numel():\n            scales = scales.reshape_as(weight)\n        else:\n            raise ValueError(f\"Scale shape {scales.shape} incompatible with weight shape {weight.shape}\")\n    bf16 = (weight.to(torch.float32) * scales).to(torch.bfloat16)\n    return bf16.contiguous()\n\n\ndef _is_quantized_weight_key(key: str) -> bool:\n    if \".mlp.experts.\" not in key or 
\".shared_experts.\" in key:\n        return False\n    suffixes = (\"weight_packed\", \"weight_scale\", \"weight_shape\")\n    for proj in (\"gate_proj\", \"up_proj\", \"down_proj\"):\n        for suffix in suffixes:\n            if key.endswith(f\".{proj}.{suffix}\"):\n                return True\n    return False\n\n\ndef convert_file(\n    input_path: str,\n    output_path: str,\n    group_size: int,\n    skip_existing: bool = True,\n):\n    if skip_existing and os.path.exists(output_path):\n        print(f\"[skip] {output_path} already exists.\")\n        return\n\n    tensors: Dict[str, torch.Tensor] = {}\n    expert_buffers: Dict[str, Dict[str, Dict[str, torch.Tensor]]] = defaultdict(lambda: defaultdict(dict))\n\n    with safe_open(input_path, framework=\"pt\") as reader:\n        keys = list(reader.keys())\n        for key in keys:\n            tensor = reader.get_tensor(key).detach().cpu()\n\n            if not _is_quantized_weight_key(key):\n                tensors[key] = tensor\n                continue\n\n            parts = key.split(\".\")\n            try:\n                expert_idx = parts.index(\"experts\")\n            except ValueError:\n                tensors[key] = tensor\n                continue\n\n            prefix = \".\".join(parts[: expert_idx + 2])\n            project = parts[-2]\n            suffix = parts[-1]\n            expert_buffers[prefix][project][suffix] = tensor\n\n    stats = {\n        \"converted\": 0,\n        \"skipped\": 0,\n    }\n\n    for prefix, components in expert_buffers.items():\n        for proj_name in [\"gate_proj\", \"up_proj\", \"down_proj\"]:\n            proj_data = components.get(proj_name, {})\n            required = {\"weight_packed\", \"weight_scale\", \"weight_shape\"}\n            if not required.issubset(proj_data.keys()):\n                print(f\"[warn] Missing components for {prefix}.{proj_name}, keeping quantized tensors.\")\n                for suffix, value in proj_data.items():\n           
         tensors[f\"{prefix}.{proj_name}.{suffix}\"] = value\n                stats[\"skipped\"] += 1\n                continue\n\n            bf16_weight = _dequantize_tensor(\n                proj_data[\"weight_packed\"].to(torch.int32),\n                proj_data[\"weight_scale\"].to(torch.float32),\n                proj_data[\"weight_shape\"],\n                group_size,\n            )\n            tensors[f\"{prefix}.{proj_name}.weight\"] = bf16_weight.to(torch.bfloat16)\n            stats[\"converted\"] += 1\n            print(f\"    converted {prefix}.{proj_name}.weight -> bf16\")\n\n    os.makedirs(os.path.dirname(output_path), exist_ok=True)\n    save_file(tensors, output_path)\n    print(f\"[done] wrote {output_path} (converted={stats['converted']}, skipped={stats['skipped']})\")\n\n\ndef parse_args() -> argparse.Namespace:\n    parser = argparse.ArgumentParser(description=\"Convert MoE experts to BF16 weights.\")\n    parser.add_argument(\"--model-dir\", required=True, help=\"Directory containing safetensors checkpoints.\")\n    parser.add_argument(\n        \"--output-dir\",\n        default=None,\n        help=\"Destination directory for converted checkpoints (default: <model-dir>_bf16).\",\n    )\n    parser.add_argument(\n        \"--files\",\n        nargs=\"+\",\n        default=None,\n        help=\"Specific safetensor filenames to convert (relative to model-dir). 
Convert all if omitted.\",\n    )\n    parser.add_argument(\n        \"--config-path\",\n        default=None,\n        help=\"Path to config.json for extracting group_size (default: model-dir/config.json).\",\n    )\n    parser.add_argument(\n        \"--overwrite\",\n        action=\"store_true\",\n        help=\"Rewrite output files even if they already exist.\",\n    )\n    return parser.parse_args()\n\n\ndef main():\n    args = parse_args()\n    model_dir = os.path.abspath(args.model_dir)\n    output_dir = os.path.abspath(args.output_dir or f\"{model_dir}_bf16\")\n\n    if not os.path.isdir(model_dir):\n        raise FileNotFoundError(f\"Model directory not found: {model_dir}\")\n\n    _, _, group_size = _load_config(model_dir, args.config_path)\n\n    if args.files:\n        targets = [os.path.join(model_dir, fname) for fname in args.files]\n    else:\n        targets = [\n            os.path.join(model_dir, name) for name in sorted(os.listdir(model_dir)) if name.endswith(\".safetensors\")\n        ]\n\n    if not targets:\n        print(\"No safetensors checkpoints found.\")\n        return\n\n    total = len(targets)\n\n    for idx, path in enumerate(targets, start=1):\n        if not os.path.isfile(path):\n            print(f\"[skip] {path} is not a file.\")\n            continue\n        rel = os.path.relpath(path, model_dir)\n        output_path = os.path.join(output_dir, rel)\n        print(f\"[{idx}/{total}] converting {rel}\")\n        convert_file(path, output_path, group_size, skip_existing=not args.overwrite)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "kt-kernel/scripts/install-git-hooks.sh",
    "content": "#!/usr/bin/env sh\n# Install git hooks from kt-kernel/.githooks into the monorepo's .git/hooks by\n# creating symlinks (or copying if symlink fails).\n\nset -eu\n\n# This script lives in kt-kernel/scripts/, so REPO_ROOT = kt-kernel\nREPO_ROOT=\"$(cd \"$(dirname \"$0\")/..\" && pwd)\"\nHOOKS_SRC=\"$REPO_ROOT/.githooks\"\n\n# Detect the top-level Git worktree (the monorepo root: ktransformers)\nGIT_TOP=\"$(git rev-parse --show-toplevel 2>/dev/null || true)\"\nif [ -z \"$GIT_TOP\" ] || [ ! -d \"$GIT_TOP/.git\" ]; then\n  echo \"[install-git-hooks] Not inside a git worktree; skipping hooks installation.\" >&2\n  exit 0\nfi\n\nGIT_DIR=\"$GIT_TOP/.git\"\nHOOKS_DEST=\"$GIT_DIR/hooks\"\n\nif [ ! -d \"$HOOKS_SRC\" ]; then\n  echo \"[install-git-hooks] No .githooks directory found at $HOOKS_SRC\" >&2\n  exit 1\nfi\n\necho \"[install-git-hooks] Installing git hooks from $HOOKS_SRC to $HOOKS_DEST (repo: $GIT_TOP)\"\n\n# Ensure all source hook files are executable so that even if copied (not symlinked) they run.\nfor src_hook in \"$HOOKS_SRC\"/*; do\n  [ -f \"$src_hook\" ] || continue\n  if [ ! -x \"$src_hook\" ]; then\n    chmod +x \"$src_hook\" || true\n  fi\ndone\n\nfor hook in \"$HOOKS_SRC\"/*; do\n  [ -e \"$hook\" ] || continue\n  name=$(basename \"$hook\")\n  dest=\"$HOOKS_DEST/$name\"\n\n  # Remove existing hook if it's our symlink or a file\n  if [ -L \"$dest\" ] || [ -f \"$dest\" ]; then\n    rm -f \"$dest\"\n  fi\n\n  # Try symlink first\n  if ln -s \"$hook\" \"$dest\" 2>/dev/null; then\n    echo \"linked $name\"\n  else\n    # Fall back to copying and preserve executable bit\n    cp \"$hook\" \"$dest\"\n    chmod +x \"$dest\"\n    echo \"copied $name\"\n  fi\ndone\n\necho \"[install-git-hooks] Done. Hooks installed.\"\n"
  },
  {
    "path": "kt-kernel/setup.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n\"\"\"\nLightweight packaging script for building and distributing kt-kernel,\na high-performance kernel operations library for KTransformers.\n\n    pip install kt-kernel\n    >>> from kt_kernel import AMXMoEWrapper\n\nThis script drives your existing CMake build (root `CMakeLists.txt`) and\nonly needs a working C++ toolchain, CMake (>=3.16), and pybind11 (vendored\nalready in the repo).\n\nEnvironment knobs (export before running pip install .):\n  CPUINFER_FORCE_REBUILD=1        Always rebuild (ignore any cached build)\n  CPUINFER_BUILD_TYPE=Release     Debug / RelWithDebInfo / Release\n  CPUINFER_PARALLEL=8             Parallel build jobs (auto = detected cores)\n  CPUINFER_CPU_INSTRUCT=FANCY     One of: NATIVE|FANCY|AVX512|AVX2 (maps to CMake flags)\n  CPUINFER_ENABLE_AMX=OFF         ON/OFF -> -DKTRANSFORMERS_CPU_USE_AMX\n  CPUINFER_ENABLE_MLA=OFF         ON/OFF -> -DKTRANSFORMERS_CPU_MLA\n  CPUINFER_ENABLE_BLIS=OFF         ON/OFF -> -DKTRANSFORMERS_CPU_MOE_AMD\n  CPUINFER_ENABLE_KML=OFF         ON/OFF -> -DKTRANSFORMERS_CPU_USE_KML\n  CPUINFER_ENABLE_AVX512=OFF      ON/OFF -> -DKTRANSFORMERS_CPU_USE_AMX_AVX512\n  CPUINFER_ENABLE_AVX512_VNNI=OFF ON/OFF -> -DLLAMA_AVX512_VNNI\n  CPUINFER_ENABLE_AVX512_BF16=OFF ON/OFF -> -DLLAMA_AVX512_BF16\n  CPUINFER_ENABLE_AVX512_VBMI=OFF ON/OFF -> -DLLAMA_AVX512_VBMI (required for FP8 MoE)\n  CPUINFER_BLIS_ROOT=/path/to/blis  Forward to -DBLIS_ROOT\n\n\n  CPUINFER_ENABLE_LTO=ON          ON/OFF -> -DCPUINFER_ENABLE_LTO (your added option)\n  CPUINFER_LTO_JOBS=8             Forward to -DCPUINFER_LTO_JOBS\n  CPUINFER_LTO_MODE=auto          Forward to -DCPUINFER_LTO_MODE\n  CPUINFER_NATIVE=ON               (override LLAMA_NATIVE)\n\n\nGPU backends (if ever added later, keep placeholders):\n  CPUINFER_USE_CUDA=0/1           -DKTRANSFORMERS_USE_CUDA\n  CPUINFER_USE_ROCM=0/1           -DKTRANSFORMERS_USE_ROCM\n  CPUINFER_USE_MUSA=0/1           
-DKTRANSFORMERS_USE_MUSA\n\nUsage:\n  pip install .\nOr build wheel:\n  python -m build  (if you have build/installed)\n\nResulting wheel exposes a top-level package `kt_kernel` with AMXMoEWrapper and other kernel wrappers.\n\"\"\"\nfrom __future__ import annotations\nimport os\nimport re\nimport sys\nimport platform\nimport subprocess\nfrom pathlib import Path\nfrom setuptools import setup, Extension\nfrom setuptools.command.build_ext import build_ext\nimport shutil\n\n\n# -------------------------\n# Env parsing helpers\n# -------------------------\ndef _env_get_bool(name: str, default: bool | None = None) -> bool | None:\n    v = os.environ.get(name)\n    if v is None:\n        return default\n    val = v.strip().lower()\n    if val in (\"1\", \"on\", \"true\", \"yes\", \"y\", \"enable\", \"enabled\"):\n        return True\n    if val in (\"0\", \"off\", \"false\", \"no\", \"n\", \"disable\", \"disabled\"):\n        return False\n    return default\n\n\ndef _cmake_onoff(flag: bool) -> str:\n    return \"ON\" if flag else \"OFF\"\n\n\ndef _forward_bool_env(cmake_args: list[str], env_name: str, cmake_flag: str) -> bool:\n    \"\"\"If env exists, forward it to CMake as -D<flag>=ON/OFF and return True; else return False.\"\"\"\n    b = _env_get_bool(env_name, None)\n    if b is None:\n        return False\n    cmake_args.append(f\"-D{cmake_flag}={_cmake_onoff(b)}\")\n    print(f\"-- Forward {env_name} -> -D{cmake_flag}={_cmake_onoff(b)}\")\n    return True\n\n\ndef _forward_str_env(cmake_args: list[str], env_name: str, cmake_flag: str) -> bool:\n    v = os.environ.get(env_name)\n    if not v:\n        return False\n    cmake_args.append(f\"-D{cmake_flag}={v}\")\n    print(f\"-- Forward {env_name} -> -D{cmake_flag}={v}\")\n    return True\n\n\n################################################################################\n# Helpers\n################################################################################\n\nREPO_ROOT = 
Path(__file__).parent.resolve()\n\nCPU_FEATURE_MAP = {\n    \"FANCY\": \"-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON -DLLAMA_AVX512=ON -DLLAMA_AVX512_FANCY_SIMD=ON\",\n    \"AVX512\": \"-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON -DLLAMA_AVX512=ON\",\n    \"AVX2\": \"-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON\",\n    \"NATIVE\": \"-DLLAMA_NATIVE=ON\",\n}\n\n\ndef default_build_type() -> str:\n    return os.environ.get(\"CPUINFER_BUILD_TYPE\", \"Release\")\n\n\ndef detect_parallel_jobs() -> str:\n    if \"CPUINFER_PARALLEL\" in os.environ:\n        return os.environ[\"CPUINFER_PARALLEL\"]\n    try:\n        import multiprocessing\n\n        return str(multiprocessing.cpu_count())\n    except Exception:\n        return \"1\"\n\n\ndef cpu_feature_flags() -> list[str]:\n    mode = os.environ.get(\"CPUINFER_CPU_INSTRUCT\", \"NATIVE\").upper()\n    return [tok for tok in CPU_FEATURE_MAP.get(mode, CPU_FEATURE_MAP[\"NATIVE\"]).split() if tok]\n\n\n################################################################################\n# CMakeExtension + builder\n################################################################################\n\n\nclass CMakeExtension(Extension):\n    def __init__(self, name: str, sourcedir: str = \"\"):\n        super().__init__(name, sources=[])\n        self.sourcedir = str(Path(sourcedir).resolve())\n\n\nclass CMakeBuild(build_ext):\n    def run(self):\n        # Ensure CMake present\n        try:\n            subprocess.run([\"cmake\", \"--version\"], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n        except Exception as e:  # pragma: no cover\n            raise RuntimeError(\"CMake is required to build this project\") from e\n        super().run()\n\n    def detect_cpu_info(self) -> dict:\n        \"\"\"Detect CPU vendor/arch and instruction set features.\n\n        Returns a dict like:\n            {\n            
    'vendor': 'intel'|'amd'|'arm'|'unknown',\n                'arch': platform.machine().lower(),\n                'features': set(['AVX2','AVX512','AMX']),\n                'raw': { 'flags': set([...]) }\n            }\n        \"\"\"\n        info = {\n            \"vendor\": \"unknown\",\n            \"arch\": platform.machine().lower(),\n            \"features\": set(),\n            \"raw\": {\"flags\": set()},\n        }\n        try:\n            sysname = platform.system()\n            if sysname == \"Linux\":\n                with open(\"/proc/cpuinfo\", \"r\", encoding=\"utf-8\", errors=\"ignore\") as f:\n                    cpuinfo = f.read()\n                low = cpuinfo.lower()\n\n                # vendor\n                if \"vendor_id\" in low:\n                    # Typical x86 linux\n                    m = re.search(r\"vendor_id\\s*:\\s*(\\S+)\", cpuinfo)\n                    if m:\n                        v = m.group(1).lower()\n                        if \"genuineintel\" in v:\n                            info[\"vendor\"] = \"intel\"\n                        elif \"authenticamd\" in v:\n                            info[\"vendor\"] = \"amd\"\n                # ARM sometimes has 'model name' or 'Hardware'\n                if info[\"vendor\"] == \"unknown\":\n                    if any(tok in low for tok in [\"aarch64\", \"armv8\", \"arm cortex\", \"kunpeng\", \"kirin\", \"huawei\"]):\n                        info[\"vendor\"] = \"arm\"\n\n                # flags collection (x86 uses 'flags', arm uses 'Features')\n                flags = set()\n                for key in (\"flags\", \"Features\", \"features\"):\n                    m = re.search(rf\"^{key}\\s*:\\s*(.+)$\", cpuinfo, re.IGNORECASE | re.MULTILINE)\n                    if m:\n                        flags.update(m.group(1).lower().split())\n                info[\"raw\"][\"flags\"] = flags\n\n                # feature summary\n                if any(f in flags or f in low for f in 
[\"avx512f\", \"avx512bw\", \"avx512dq\", \"avx512vl\"]):\n                    info[\"features\"].add(\"AVX512\")\n                if \"avx2\" in flags or \"avx2\" in low:\n                    info[\"features\"].add(\"AVX2\")\n                # AMX flags on Linux are with underscores; keep hyphen fallback just in case\n                if any(\n                    f in flags or f in low\n                    for f in [\"amx_bf16\", \"amx_int8\", \"amx_tile\", \"amx-bf16\", \"amx-int8\", \"amx-tile\"]\n                ):\n                    info[\"features\"].add(\"AMX\")\n\n                # Fine-grained AVX512 subset detection\n                if any(f in flags for f in [\"avx512_vnni\", \"avx512vnni\"]):\n                    info[\"features\"].add(\"AVX512_VNNI\")\n                if any(f in flags for f in [\"avx512_bf16\", \"avx512bf16\"]):\n                    info[\"features\"].add(\"AVX512_BF16\")\n                if any(f in flags for f in [\"avx512_vbmi\", \"avx512vbmi\"]):\n                    info[\"features\"].add(\"AVX512_VBMI\")\n                if any(f in flags for f in [\"avx512_vpopcntdq\", \"avx512vpopcntdq\"]):\n                    info[\"features\"].add(\"AVX512_VPOPCNTDQ\")\n\n            elif sysname == \"Darwin\":\n                # macOS: Apple Silicon (arm64) vs Intel\n                arch = platform.machine().lower()\n                info[\"arch\"] = arch\n                if arch in (\"arm64\", \"aarch64\"):\n                    info[\"vendor\"] = \"arm\"\n                else:\n                    info[\"vendor\"] = \"intel\"\n                # No AVX/AMX on Apple Silicon; assume none\n\n            elif sysname == \"Windows\":\n                # Minimal detection via arch; detailed CPUID omitted for brevity\n                arch = platform.machine().lower()\n                info[\"arch\"] = arch\n                if arch in (\"arm64\", \"aarch64\"):\n                    info[\"vendor\"] = \"arm\"\n                else:\n                   
 # Could be Intel or AMD; leave unknown\n                    info[\"vendor\"] = \"unknown\"\n        except Exception as e:\n            print(f\"Warning: CPU detection failed: {e}\")\n        return info\n\n    def build_extension(self, ext: CMakeExtension):\n        \"\"\"\n        Main entry point for building the extension.\n\n        Checks if multi-variant build is requested (CPUINFER_BUILD_ALL_VARIANTS=1)\n        and routes to the appropriate build method.\n        \"\"\"\n        if _env_get_bool(\"CPUINFER_BUILD_ALL_VARIANTS\", False):\n            # Build all 6 variants (AVX2, four AVX512 tiers, AMX)\n            self.build_multi_variants(ext)\n        else:\n            # Build single variant (original behavior)\n            self._build_single_variant(ext)\n\n    def build_multi_variants(self, ext: CMakeExtension):\n        \"\"\"\n        Build all 6 CPU variants with progressive AVX512 capabilities.\n\n        This creates 6 separate .so files optimized for different CPU generations:\n        - _kt_kernel_ext_avx2.so         (Haswell+, 2013)\n        - _kt_kernel_ext_avx512_base.so  (Skylake-X+, 2017)\n        - _kt_kernel_ext_avx512_vnni.so  (Cascade Lake+, 2019)\n        - _kt_kernel_ext_avx512_vbmi.so  (Ice Lake client, 2019)\n        - _kt_kernel_ext_avx512_bf16.so  (Ice Lake server/Zen 4+, 2021)\n        - _kt_kernel_ext_amx.so          (Sapphire Rapids+, 2023)\n\n        Runtime CPU detection (in _cpu_detect.py) will automatically select the best match.\n        \"\"\"\n        print(\"=\" * 70)\n        print(\"Building kt-kernel with ALL 6 CPU variants\")\n        print(\"=\" * 70)\n        print()\n        print(\"This will build six progressive variants in a single wheel:\")\n        print(\"  1. AVX2          - Haswell+ (2013)\")\n        print(\"  2. AVX512 Base   - Skylake-X+ (2017)\")\n        print(\"  3. AVX512+VNNI   - Cascade Lake+ (2019)\")\n        print(\"  4. AVX512+VBMI   - Ice Lake client (2019)\")\n        print(\"  5. 
AVX512+BF16   - Ice Lake server, Zen 4+ (2021)\")\n        print(\"  6. AMX           - Sapphire Rapids+ (2023)\")\n        print()\n        print(\"Runtime CPU detection will automatically select the best variant.\")\n        print()\n\n        extdir = Path(self.get_ext_fullpath(ext.name)).parent.resolve()\n        cfg = default_build_type()\n\n        # Save original env vars to restore later\n        env_backup = {\n            \"CPUINFER_CPU_INSTRUCT\": os.environ.get(\"CPUINFER_CPU_INSTRUCT\"),\n            \"CPUINFER_ENABLE_AMX\": os.environ.get(\"CPUINFER_ENABLE_AMX\"),\n            \"CPUINFER_ENABLE_AVX512\": os.environ.get(\"CPUINFER_ENABLE_AVX512\"),\n            \"CPUINFER_ENABLE_AVX512_VNNI\": os.environ.get(\"CPUINFER_ENABLE_AVX512_VNNI\"),\n            \"CPUINFER_ENABLE_AVX512_BF16\": os.environ.get(\"CPUINFER_ENABLE_AVX512_BF16\"),\n            \"CPUINFER_ENABLE_AVX512_VBMI\": os.environ.get(\"CPUINFER_ENABLE_AVX512_VBMI\"),\n        }\n\n        # Variant configurations: (name, description, env_vars)\n        # Each variant specifies exactly which features to enable\n        variants = [\n            (\n                \"avx2\",\n                \"AVX2 baseline\",\n                {\n                    \"CPUINFER_CPU_INSTRUCT\": \"AVX2\",\n                    \"CPUINFER_ENABLE_AVX512\": \"OFF\",\n                    \"CPUINFER_ENABLE_AMX\": \"OFF\",\n                },\n            ),\n            (\n                \"avx512_base\",\n                \"AVX512F+BW\",\n                {\n                    \"CPUINFER_CPU_INSTRUCT\": \"AVX512\",\n                    \"CPUINFER_ENABLE_AVX512\": \"ON\",\n                    \"CPUINFER_ENABLE_AVX512_VNNI\": \"OFF\",\n                    \"CPUINFER_ENABLE_AVX512_BF16\": \"OFF\",\n                    \"CPUINFER_ENABLE_AVX512_VBMI\": \"OFF\",\n                    \"CPUINFER_ENABLE_AMX\": \"OFF\",\n                },\n            ),\n            (\n                \"avx512_vnni\",\n                
\"AVX512F+VNNI\",\n                {\n                    \"CPUINFER_CPU_INSTRUCT\": \"AVX512\",\n                    \"CPUINFER_ENABLE_AVX512\": \"ON\",\n                    \"CPUINFER_ENABLE_AVX512_VNNI\": \"ON\",\n                    \"CPUINFER_ENABLE_AVX512_BF16\": \"OFF\",\n                    \"CPUINFER_ENABLE_AVX512_VBMI\": \"OFF\",\n                    \"CPUINFER_ENABLE_AMX\": \"OFF\",\n                },\n            ),\n            (\n                \"avx512_vbmi\",\n                \"AVX512F+VNNI+VBMI\",\n                {\n                    \"CPUINFER_CPU_INSTRUCT\": \"AVX512\",\n                    \"CPUINFER_ENABLE_AVX512\": \"ON\",\n                    \"CPUINFER_ENABLE_AVX512_VNNI\": \"ON\",\n                    \"CPUINFER_ENABLE_AVX512_BF16\": \"OFF\",\n                    \"CPUINFER_ENABLE_AVX512_VBMI\": \"ON\",\n                    \"CPUINFER_ENABLE_AMX\": \"OFF\",\n                },\n            ),\n            (\n                \"avx512_bf16\",\n                \"AVX512 Full (F+VNNI+VBMI+BF16)\",\n                {\n                    \"CPUINFER_CPU_INSTRUCT\": \"AVX512\",\n                    \"CPUINFER_ENABLE_AVX512\": \"ON\",\n                    \"CPUINFER_ENABLE_AVX512_VNNI\": \"ON\",\n                    \"CPUINFER_ENABLE_AVX512_BF16\": \"ON\",\n                    \"CPUINFER_ENABLE_AVX512_VBMI\": \"ON\",\n                    \"CPUINFER_ENABLE_AMX\": \"OFF\",\n                },\n            ),\n            (\n                \"amx\",\n                \"AMX + AVX512 Full\",\n                {\n                    \"CPUINFER_CPU_INSTRUCT\": \"AVX512\",\n                    \"CPUINFER_ENABLE_AVX512\": \"ON\",\n                    \"CPUINFER_ENABLE_AVX512_VNNI\": \"ON\",\n                    \"CPUINFER_ENABLE_AVX512_BF16\": \"ON\",\n                    \"CPUINFER_ENABLE_AVX512_VBMI\": \"ON\",\n                    \"CPUINFER_ENABLE_AMX\": \"ON\",\n                },\n            ),\n        ]\n\n        for variant_name, variant_desc, 
env_vars in variants:\n            print(\"=\" * 70)\n            print(f\"Building {variant_name.upper()} variant ({variant_desc})\")\n            print(\"=\" * 70)\n            print()\n\n            # Set environment variables for this variant\n            for key, value in env_vars.items():\n                os.environ[key] = value\n                print(f\"  {key} = {value}\")\n\n            # Use separate build directory for each variant\n            build_temp = Path(self.build_temp) / f\"{ext.name}_{cfg}_{variant_name}\"\n            build_temp.mkdir(parents=True, exist_ok=True)\n\n            # Build this variant\n            self._build_single_variant_impl(ext, extdir, build_temp, cfg)\n\n            # Rename the built .so file to include variant suffix\n            # Original name: kt_kernel_ext.cpython-311-x86_64-linux-gnu.so\n            # New name: _kt_kernel_ext_amx.cpython-311-x86_64-linux-gnu.so\n            built_so_files = list(extdir.glob(f\"{ext.name.split('.')[-1]}.*.so\"))\n            if built_so_files:\n                original_so = built_so_files[0]\n                # Extract the suffix after the module name\n                # e.g., \"kt_kernel_ext.cpython-311-x86_64-linux-gnu.so\" -> \".cpython-311-x86_64-linux-gnu.so\"\n                suffix = original_so.name.replace(ext.name.split(\".\")[-1], \"\")\n                new_name = f\"_kt_kernel_ext_{variant_name}{suffix}\"\n                new_path = extdir / new_name\n\n                # Remove existing file if present\n                if new_path.exists():\n                    new_path.unlink()\n\n                # Rename\n                original_so.rename(new_path)\n                print(f\"✓ Built and renamed to: {new_name}\")\n                print()\n            else:\n                print(f\"⚠ Warning: Could not find built .so file for {variant_name} variant\")\n                print()\n\n        # Restore original env vars\n        for key, value in env_backup.items():\n           
 if value is not None:\n                os.environ[key] = value\n            elif key in os.environ:\n                del os.environ[key]\n\n        print(\"=\" * 70)\n        print(\"✓ All 6 variants built successfully!\")\n        print(\"=\" * 70)\n        print()\n        print(\"The wheel now contains 6 CPU variants:\")\n        for so_file in sorted(extdir.glob(\"_kt_kernel_ext_*.so\")):\n            print(f\"  - {so_file.name}\")\n        print()\n\n    def _build_single_variant(self, ext: CMakeExtension):\n        \"\"\"Original single-variant build logic - wrapper for backward compatibility.\"\"\"\n        extdir = Path(self.get_ext_fullpath(ext.name)).parent.resolve()\n        cfg = default_build_type()\n        build_temp = Path(self.build_temp) / f\"{ext.name}_{cfg}\"\n        build_temp.mkdir(parents=True, exist_ok=True)\n\n        self._build_single_variant_impl(ext, extdir, build_temp, cfg)\n\n    def _build_single_variant_impl(self, ext: CMakeExtension, extdir: Path, build_temp: Path, cfg: str):\n        \"\"\"\n        Core build logic for a single variant.\n\n        This method contains the actual CMake configuration and build steps.\n        It's called by both _build_single_variant() and build_multi_variants().\n\n        Args:\n            ext: The CMakeExtension to build\n            extdir: Directory where the .so file should be placed\n            build_temp: Temporary build directory for CMake\n            cfg: Build type (Release/Debug/etc.)\n        \"\"\"\n\n        # Auto-detect CUDA toolkit if user did not explicitly set CPUINFER_USE_CUDA\n        def detect_cuda_toolkit() -> bool:\n            # Respect CUDA_HOME\n            cuda_home = os.environ.get(\"CUDA_HOME\")\n            if cuda_home:\n                nvcc_path = Path(cuda_home) / \"bin\" / \"nvcc\"\n                if nvcc_path.exists():\n                    return True\n            # PATH lookup\n            if shutil.which(\"nvcc\") is not None:\n                return 
True\n            # Common default install prefix\n            if Path(\"/usr/local/cuda/bin/nvcc\").exists():\n                return True\n            return False\n\n        # Locate nvcc executable (without forcing user to set -DCMAKE_CUDA_COMPILER)\n        def find_nvcc_path() -> str | None:\n            cuda_home = os.environ.get(\"CUDA_HOME\")\n            if cuda_home:\n                cand = Path(cuda_home) / \"bin\" / \"nvcc\"\n                if cand.exists():\n                    return str(cand)\n            which_nvcc = shutil.which(\"nvcc\")\n            if which_nvcc:\n                return which_nvcc\n            # Common fallbacks (ordered by preference)\n            for cand in [\n                \"/usr/local/cuda-12.6/bin/nvcc\",\n                \"/usr/local/cuda/bin/nvcc\",\n                \"/usr/bin/nvcc\",\n                \"/usr/lib/nvidia-cuda-toolkit/bin/nvcc\",\n            ]:\n                if Path(cand).exists():\n                    return cand\n            return None\n\n        # Note: When CUDA is enabled, CMAKE_CUDA_ARCHITECTURES defaults to \"80;86;89;90\".\n        # To override it, users can set env CPUINFER_CUDA_ARCHS\n        # (e.g. 
\"89\" or \"86;89\") or pass it via CMAKE_ARGS.\n        auto_moe_kernel_ = False\n        # Normalize CPUINFER_USE_CUDA: if unset, auto-detect; otherwise respect truthy/falsey values\n        cuda_env = _env_get_bool(\"CPUINFER_USE_CUDA\", None)\n        if cuda_env is None:\n            auto_cuda = detect_cuda_toolkit()\n            os.environ[\"CPUINFER_USE_CUDA\"] = \"1\" if auto_cuda else \"0\"\n            print(f\"-- CPUINFER_USE_CUDA not set; auto-detected CUDA toolkit: {'YES' if auto_cuda else 'NO'}\")\n\n        # Base CMake args\n        cmake_args = [\n            f\"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}/\",\n            f\"-DPYTHON_EXECUTABLE={sys.executable}\",\n            f\"-DCMAKE_BUILD_TYPE={cfg}\",\n        ]\n\n        # CPU feature flags mapping: if user specified CPUINFER_CPU_INSTRUCT, honor it;\n        # else auto-pick based on detection (x86 only)\n        cmake_args += cpu_feature_flags()\n        d = self.detect_cpu_info()\n        print(f\"Detected CPU info: {d}\")\n        cpu_mode = os.environ.get(\"CPUINFER_CPU_INSTRUCT\", \"NATIVE\").upper()\n\n        # Vendor / feature specific toggles\n        # AMD MoE: explicit env overrides; otherwise default ON on AMD CPU\n        _forward_bool_env(cmake_args, \"CPUINFER_ENABLE_BLIS\", \"KTRANSFORMERS_CPU_MOE_AMD\")\n        # if d.get(\"vendor\") == \"amd\":\n        #     auto_moe_kernel_ = True\n        #     cmake_args.append(\"-DKTRANSFORMERS_CPU_MOE_AMD=ON\")\n        #     print(\"-- Detected AMD CPU; enabling AMD MoE kernel (-DKTRANSFORMERS_CPU_MOE_AMD=ON)\")\n        #     _forward_str_env(cmake_args, \"CPUINFER_BLIS_ROOT\", \"BLIS_ROOT\")\n\n        # KML: explicit env overrides; otherwise default ON on ARM\n        _forward_bool_env(cmake_args, \"CPUINFER_ENABLE_KML\", \"KTRANSFORMERS_CPU_USE_KML\")\n        # if d.get(\"vendor\") == \"arm\":\n        #     auto_moe_kernel_ = True\n        #     cmake_args.append(\"-DKTRANSFORMERS_CPU_USE_KML=ON\")\n        #     print(\"-- 
Detected ARM CPU; enabling KML (-DKTRANSFORMERS_CPU_USE_KML=ON)\")\n\n        # AMX: explicit env overrides; else enable if detected\n        if not _forward_bool_env(cmake_args, \"CPUINFER_ENABLE_AMX\", \"KTRANSFORMERS_CPU_USE_AMX\"):\n            if \"AMX\" in d[\"features\"]:\n                cmake_args.append(\"-DKTRANSFORMERS_CPU_USE_AMX=ON\")\n                print(\"-- AMX support detected; enabling (-DKTRANSFORMERS_CPU_USE_AMX=ON)\")\n\n        # AVX512 umbrella (AMX/AVX512 kernels):\n        # - If user explicitly sets CPUINFER_ENABLE_AVX512 -> honor it\n        # - Otherwise, only auto-enable when CPU mode actually wants AVX512\n        #   (NATIVE/FANCY/AVX512). In AVX2 mode we do NOT enable this, so\n        #   RAWINT4 / K2 kernels are not compiled.\n        if not _forward_bool_env(cmake_args, \"CPUINFER_ENABLE_AVX512\", \"KTRANSFORMERS_CPU_USE_AMX_AVX512\"):\n            if cpu_mode in (\"NATIVE\", \"FANCY\", \"AVX512\") and (\"AMX\" in d[\"features\"] or \"AVX512\" in d[\"features\"]):\n                cmake_args.append(\"-DKTRANSFORMERS_CPU_USE_AMX_AVX512=ON\")\n                print(\"-- Enabling AMX/AVX512 umbrella (-DKTRANSFORMERS_CPU_USE_AMX_AVX512=ON)\")\n            else:\n                print(f\"-- CPUINFER_CPU_INSTRUCT={cpu_mode}; not auto-enabling AMX/AVX512 umbrella\")\n\n        # Fine-grained AVX512 subset flags: only enable if CPU actually supports them\n        # These are passed to CMake to conditionally add compiler flags\n        # Track if any AVX512 extension is enabled\n        avx512_extension_enabled = False\n        allow_avx512_ext_auto = cpu_mode in (\"NATIVE\", \"FANCY\", \"AVX512\")\n\n        if not _forward_bool_env(cmake_args, \"CPUINFER_ENABLE_AVX512_VNNI\", \"LLAMA_AVX512_VNNI\"):\n            if allow_avx512_ext_auto and \"AVX512_VNNI\" in d[\"features\"]:\n                cmake_args.append(\"-DLLAMA_AVX512_VNNI=ON\")\n                print(\"-- AVX512_VNNI detected; enabling (-DLLAMA_AVX512_VNNI=ON)\")\n           
     avx512_extension_enabled = True\n        else:\n            avx512_extension_enabled = True\n\n        if not _forward_bool_env(cmake_args, \"CPUINFER_ENABLE_AVX512_BF16\", \"LLAMA_AVX512_BF16\"):\n            if allow_avx512_ext_auto and \"AVX512_BF16\" in d[\"features\"]:\n                cmake_args.append(\"-DLLAMA_AVX512_BF16=ON\")\n                print(\"-- AVX512_BF16 detected; enabling (-DLLAMA_AVX512_BF16=ON)\")\n                avx512_extension_enabled = True\n        else:\n            avx512_extension_enabled = True\n\n        if not _forward_bool_env(cmake_args, \"CPUINFER_ENABLE_AVX512_VBMI\", \"LLAMA_AVX512_VBMI\"):\n            if allow_avx512_ext_auto and \"AVX512_VBMI\" in d[\"features\"]:\n                cmake_args.append(\"-DLLAMA_AVX512_VBMI=ON\")\n                print(\"-- AVX512_VBMI detected; enabling (-DLLAMA_AVX512_VBMI=ON)\")\n                avx512_extension_enabled = True\n        else:\n            avx512_extension_enabled = True\n\n        # If any AVX512 extension is enabled, ensure base AVX512 is also enabled\n        if avx512_extension_enabled and cpu_mode in (\"NATIVE\", \"FANCY\", \"AVX512\"):\n            if not any(\"LLAMA_AVX512=ON\" in a for a in cmake_args):\n                cmake_args.append(\"-DLLAMA_AVX512=ON\")\n                print(\"-- AVX512 extensions enabled; also enabling base AVX512F (-DLLAMA_AVX512=ON)\")\n\n        # Auto-enable MOE kernel only when env explicitly turns on AMD or KML backend\n        # (Do not enable purely on vendor auto-detection to avoid surprise behavior.)\n        amd_env = _env_get_bool(\"CPUINFER_ENABLE_BLIS\", None)\n        kml_env = _env_get_bool(\"CPUINFER_ENABLE_KML\", None)\n        if amd_env or kml_env:\n            auto_moe_kernel_ = True\n        already_set = any(\"KTRANSFORMERS_CPU_MOE_KERNEL\" in a for a in cmake_args)\n        if not already_set and auto_moe_kernel_:\n            cmake_args.append(\"-DKTRANSFORMERS_CPU_MOE_KERNEL=ON\")\n            print(\n          
      \"-- Auto-enabling MOE kernel (-DKTRANSFORMERS_CPU_MOE_KERNEL=ON) because CPUINFER_ENABLE_BLIS or CPUINFER_ENABLE_KML is ON\"\n            )\n\n        # Friendly summary\n        print(\n            f\"-- CPU detection: vendor={d.get('vendor')} arch={d.get('arch')} features={sorted(list(d.get('features', [])))}\"\n        )\n\n        # MLA toggle (string/boolean allowed)\n        if not _forward_bool_env(cmake_args, \"CPUINFER_ENABLE_MLA\", \"KTRANSFORMERS_CPU_MLA\"):\n            _forward_str_env(cmake_args, \"CPUINFER_ENABLE_MLA\", \"KTRANSFORMERS_CPU_MLA\")\n\n        # LTO toggles\n        _forward_bool_env(cmake_args, \"CPUINFER_ENABLE_LTO\", \"CPUINFER_ENABLE_LTO\")\n        _forward_str_env(cmake_args, \"CPUINFER_LTO_JOBS\", \"CPUINFER_LTO_JOBS\")\n        _forward_str_env(cmake_args, \"CPUINFER_LTO_MODE\", \"CPUINFER_LTO_MODE\")\n\n        # CUDA static runtime toggle\n        _forward_bool_env(cmake_args, \"CPUINFER_CUDA_STATIC_RUNTIME\", \"KTRANSFORMERS_CUDA_STATIC_RUNTIME\")\n\n        # GPU backends (mutually exclusive expected)\n        if _env_get_bool(\"CPUINFER_USE_CUDA\", False):\n            cmake_args.append(\"-DKTRANSFORMERS_USE_CUDA=ON\")\n            print(\"-- Enabling CUDA backend (-DKTRANSFORMERS_USE_CUDA=ON)\")\n            # Inject nvcc compiler path automatically unless user already specified one.\n            user_specified_compiler = any(\"CMAKE_CUDA_COMPILER\" in a for a in cmake_args)\n            if not user_specified_compiler:\n                extra_env = os.environ.get(\"CMAKE_ARGS\", \"\")\n                if \"CMAKE_CUDA_COMPILER\" in extra_env:\n                    user_specified_compiler = True\n            if not user_specified_compiler:\n                nvcc_path = find_nvcc_path()\n                if nvcc_path:\n                    cmake_args.append(f\"-DCMAKE_CUDA_COMPILER={nvcc_path}\")\n                    print(f\"-- Auto-detected nvcc: {nvcc_path} (adding -DCMAKE_CUDA_COMPILER)\")\n                else:\n       
             print(\"-- Warning: nvcc not found via CUDA_HOME/PATH/common prefixes; CUDA configure may fail.\")\n            # Optional host compiler for nvcc if user set CUDAHOSTCXX\n            if os.environ.get(\"CUDAHOSTCXX\"):\n                hostcxx = os.environ[\"CUDAHOSTCXX\"]\n                cmake_args.append(f\"-DCMAKE_CUDA_HOST_COMPILER={hostcxx}\")\n                print(f\"-- Using CUDA host compiler from CUDAHOSTCXX: {hostcxx}\")\n            # Set CUDA architectures (default: Ampere/Ada/Hopper)\n            archs_env = os.environ.get(\"CPUINFER_CUDA_ARCHS\", \"80;86;89;90\").strip()\n            if archs_env and not any(\"CMAKE_CUDA_ARCHITECTURES\" in a for a in cmake_args):\n                cmake_args.append(f\"-DCMAKE_CUDA_ARCHITECTURES={archs_env}\")\n                print(f\"-- Set CUDA architectures: {archs_env}\")\n        if _env_get_bool(\"CPUINFER_USE_ROCM\", False):\n            cmake_args.append(\"-DKTRANSFORMERS_USE_ROCM=ON\")\n        if _env_get_bool(\"CPUINFER_USE_MUSA\", False):\n            cmake_args.append(\"-DKTRANSFORMERS_USE_MUSA=ON\")\n\n        # Respect user extra CMAKE_ARGS (space separated)\n        extra = os.environ.get(\"CMAKE_ARGS\")\n        if extra:\n            cmake_args += [a for a in extra.split() if a]\n\n        # Force rebuild? 
(delete cache)\n        if _env_get_bool(\"CPUINFER_FORCE_REBUILD\", True):\n            cache = build_temp / \"CMakeCache.txt\"\n            if cache.exists():\n                cache.unlink()\n\n        print(\"-- CMake configure args:\")\n        for a in cmake_args:\n            print(\"   \", a)\n\n        # Configure\n        subprocess.run([\"cmake\", ext.sourcedir, *cmake_args], cwd=build_temp, check=True)\n\n        # Build\n        build_args = [\"--build\", \".\", \"--config\", cfg]\n        jobs = detect_parallel_jobs()\n        if jobs:\n            build_args += [\"--parallel\", jobs]\n        print(\"-- CMake build args:\", \" \".join(build_args))\n        subprocess.run([\"cmake\", *build_args], cwd=build_temp, check=True)\n\n        # On some systems LTO + CMake + pybind may place the built .so inside build tree; move if needed\n        built_candidates = list(build_temp.rglob(f\"{ext.name}*.so\"))\n        for cand in built_candidates:\n            if cand.parent != extdir:\n                target = extdir / cand.name\n                target.parent.mkdir(parents=True, exist_ok=True)\n                # Overwrite stale\n                if not target.exists() or target.stat().st_mtime < cand.stat().st_mtime:\n                    print(f\"-- Copying {cand} -> {target}\")\n                    target.write_bytes(cand.read_bytes())\n\n\n################################################################################\n# Version (simple). 
If you later add a python package dir, you can read from it.\n################################################################################\n\n\n# Read base version from version.py\n_version_file = Path(__file__).resolve().parent.parent / \"version.py\"\nif _version_file.exists():\n    _version_ns = {}\n    with open(_version_file, \"r\", encoding=\"utf-8\") as f:\n        exec(f.read(), _version_ns)\n    _base_version = _version_ns.get(\"__version__\", \"0.5.0\")\nelse:\n    _base_version = \"0.5.0\"\n\n# Determine version\nif \"CPUINFER_VERSION\" in os.environ:\n    # User explicitly set version (e.g., for testing)\n    VERSION = os.environ[\"CPUINFER_VERSION\"]\n    print(f\"-- Explicit version: {VERSION}\")\nelse:\n    VERSION = _base_version\n    print(f\"-- Version: {VERSION}\")\n\n# Package name is always kt-kernel\n# The CUDA-enabled wheel includes both CPU multi-variant support and CUDA capabilities\nPACKAGE_NAME = \"kt-kernel\"\ncuda_enabled = _env_get_bool(\"CPUINFER_USE_CUDA\", False)\nif cuda_enabled:\n    print(f\"-- Building kt-kernel with CUDA support (+ CPU multi-variant)\")\nelse:\n    print(f\"-- Building kt-kernel (CPU-only multi-variant)\")\n\n################################################################################\n# Setup\n################################################################################\n\nsetup(\n    name=PACKAGE_NAME,\n    version=VERSION,\n    description=\"KT-Kernel: High-performance kernel operations for KTransformers (AMX/AVX/KML optimizations)\",\n    author=\"kvcache-ai\",\n    license=\"Apache-2.0\",\n    python_requires=\">=3.8\",\n    packages=[\n        \"kt_kernel\",\n        \"kt_kernel.utils\",\n        \"kt_kernel.cli\",\n        \"kt_kernel.cli.commands\",\n        \"kt_kernel.cli.config\",\n        \"kt_kernel.cli.utils\",\n    ],\n    package_dir={\n        \"kt_kernel\": \"python\",\n        \"kt_kernel.utils\": \"python/utils\",\n        \"kt_kernel.cli\": \"python/cli\",\n        
\"kt_kernel.cli.commands\": \"python/cli/commands\",\n        \"kt_kernel.cli.config\": \"python/cli/config\",\n        \"kt_kernel.cli.utils\": \"python/cli/utils\",\n    },\n    entry_points={\n        \"console_scripts\": [\n            \"kt=kt_kernel.cli.main:main\",\n        ],\n    },\n    ext_modules=[CMakeExtension(\"kt_kernel.kt_kernel_ext\", str(REPO_ROOT))],\n    cmdclass={\"build_ext\": CMakeBuild},\n    zip_safe=False,\n    classifiers=[\n        \"Programming Language :: Python :: 3\",\n        \"Programming Language :: C++\",\n        \"Operating System :: POSIX :: Linux\",\n        \"Operating System :: MacOS\",\n    ],\n)\n"
  },
  {
    "path": "kt-kernel/test/__init__.py",
    "content": "\"\"\"KT-Kernel Test Suite\n\nThis test suite is adapted from SGLang's CI testing framework.\nIt provides hardware-aware test registration and execution with timeout control.\n\"\"\"\n"
  },
  {
    "path": "kt-kernel/test/ci/__init__.py",
    "content": "\"\"\"CI test registration and execution utilities.\"\"\"\n"
  },
  {
    "path": "kt-kernel/test/ci/ci_register.py",
    "content": "import ast\nimport warnings\nfrom dataclasses import dataclass\nfrom enum import Enum, auto\nfrom typing import List\n\n\nclass HWBackend(Enum):\n    CPU = auto()\n    CUDA = auto()\n    AMD = auto()\n\n\n@dataclass\nclass CIRegistry:\n    backend: HWBackend\n    filename: str\n    est_time: float\n    suite: str\n\n\ndef register_cpu_ci(est_time: float, suite: str):\n    pass\n\n\ndef register_cuda_ci(est_time: float, suite: str):\n    pass\n\n\ndef register_amd_ci(est_time: float, suite: str):\n    pass\n\n\nREGISTER_MAPPING = {\n    \"register_cpu_ci\": HWBackend.CPU,\n    \"register_cuda_ci\": HWBackend.CUDA,\n    \"register_amd_ci\": HWBackend.AMD,\n}\n\n\nclass RegistryVisitor(ast.NodeVisitor):\n    def __init__(self, filename: str):\n        self.filename = filename\n        self.registries: list[CIRegistry] = []\n\n    def _collect_ci_registry(self, func_call: ast.Call):\n        if not isinstance(func_call.func, ast.Name):\n            return None\n\n        if func_call.func.id not in REGISTER_MAPPING:\n            return None\n\n        hw = REGISTER_MAPPING[func_call.func.id]\n        est_time, suite = None, None\n        for kw in func_call.keywords:\n            if kw.arg == \"est_time\":\n                if isinstance(kw.value, ast.Constant):\n                    est_time = kw.value.value\n            elif kw.arg == \"suite\":\n                if isinstance(kw.value, ast.Constant):\n                    suite = kw.value.value\n\n        for i, arg in enumerate(func_call.args):\n            if isinstance(arg, ast.Constant):\n                if i == 0:\n                    est_time = arg.value\n                elif i == 1:\n                    suite = arg.value\n        assert (\n            est_time is not None\n        ), \"esimation_time is required and should be a constant\"\n        assert suite is not None, \"suite is required and should be a constant\"\n        return CIRegistry(\n            backend=hw, filename=self.filename, 
est_time=est_time, suite=suite\n        )\n\n    def visit_Module(self, node):\n        # Only top-level expression statements are inspected, so register_*_ci()\n        # must be called at module scope to be picked up.\n        for stmt in node.body:\n            if not isinstance(stmt, ast.Expr) or not isinstance(stmt.value, ast.Call):\n                continue\n\n            cr = self._collect_ci_registry(stmt.value)\n            if cr is not None:\n                self.registries.append(cr)\n\n        self.generic_visit(node)\n\n\ndef ut_parse_one_file(filename: str) -> List[CIRegistry]:\n    # Test files contain non-ASCII status glyphs, so read them as UTF-8 explicitly.\n    with open(filename, \"r\", encoding=\"utf-8\") as f:\n        file_content = f.read()\n    tree = ast.parse(file_content, filename=filename)\n    visitor = RegistryVisitor(filename=filename)\n    visitor.visit(tree)\n    return visitor.registries\n\n\ndef collect_tests(files: List[str], sanity_check: bool = True) -> List[CIRegistry]:\n    ci_tests = []\n    for file in files:\n        registries = ut_parse_one_file(file)\n        if len(registries) == 0:\n            msg = f\"No CI registry found in {file}\"\n            if sanity_check:\n                raise ValueError(msg)\n            else:\n                warnings.warn(msg)\n                continue\n\n        ci_tests.extend(registries)\n\n    return ci_tests\n"
  },
  {
    "path": "kt-kernel/test/ci/ci_utils.py",
    "content": "import os\nimport subprocess\nimport threading\nimport time\nfrom dataclasses import dataclass\nfrom typing import Callable, List, Optional\n\nimport psutil, signal, sys\ndef kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):\n    \"\"\"Kill the process and all its child processes.\"\"\"\n    # Remove sigchld handler to avoid spammy logs.\n    if threading.current_thread() is threading.main_thread():\n        signal.signal(signal.SIGCHLD, signal.SIG_DFL)\n\n    if parent_pid is None:\n        parent_pid = os.getpid()\n        include_parent = False\n\n    try:\n        itself = psutil.Process(parent_pid)\n    except psutil.NoSuchProcess:\n        return\n\n    children = itself.children(recursive=True)\n    for child in children:\n        if child.pid == skip_pid:\n            continue\n        try:\n            child.kill()\n        except psutil.NoSuchProcess:\n            pass\n\n    if include_parent:\n        try:\n            if parent_pid == os.getpid():\n                itself.kill()\n                sys.exit(0)\n\n            itself.kill()\n\n            # Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes),\n            # so we send an additional signal to kill them.\n            itself.send_signal(signal.SIGQUIT)\n        except psutil.NoSuchProcess:\n            pass\n\n\n@dataclass\nclass TestFile:\n    name: str\n    estimated_time: float = 60\n\n\ndef run_with_timeout(\n    func: Callable,\n    args: tuple = (),\n    kwargs: Optional[dict] = None,\n    timeout: float = None,\n):\n    \"\"\"Run a function with timeout.\"\"\"\n    ret_value = []\n\n    def _target_func():\n        ret_value.append(func(*args, **(kwargs or {})))\n\n    t = threading.Thread(target=_target_func)\n    t.start()\n    t.join(timeout=timeout)\n    if t.is_alive():\n        raise TimeoutError()\n\n    if not ret_value:\n        raise RuntimeError()\n\n    return ret_value[0]\n\n\ndef 
run_unittest_files(\n    files: List[TestFile], timeout_per_file: float, continue_on_error: bool = False\n):\n    \"\"\"\n    Run a list of test files.\n\n    Args:\n        files: List of TestFile objects to run\n        timeout_per_file: Timeout in seconds for each test file\n        continue_on_error: If True, continue running remaining tests even if one fails.\n                          If False, stop at first failure (default behavior for PR tests).\n    \"\"\"\n    tic = time.perf_counter()\n    success = True\n    passed_tests = []\n    failed_tests = []\n\n    for i, file in enumerate(files):\n        filename, estimated_time = file.name, file.estimated_time\n        process = None\n\n        def run_one_file(filename):\n            nonlocal process\n\n            filename = os.path.join(os.getcwd(), filename)\n            print(\n                f\".\\n.\\nBegin ({i}/{len(files) - 1}):\\npython3 {filename}\\n.\\n.\\n\",\n                flush=True,\n            )\n            tic = time.perf_counter()\n\n            process = subprocess.Popen(\n                [\"python3\", filename], stdout=None, stderr=None, env=os.environ\n            )\n            process.wait()\n            elapsed = time.perf_counter() - tic\n\n            print(\n                f\".\\n.\\nEnd ({i}/{len(files) - 1}):\\n{filename=}, {elapsed=:.0f}, {estimated_time=}\\n.\\n.\\n\",\n                flush=True,\n            )\n            return process.returncode\n\n        try:\n            ret_code = run_with_timeout(\n                run_one_file, args=(filename,), timeout=timeout_per_file\n            )\n            if ret_code != 0:\n                print(\n                    f\"\\n✗ FAILED: {filename} returned exit code {ret_code}\\n\",\n                    flush=True,\n                )\n                success = False\n                failed_tests.append((filename, f\"exit code {ret_code}\"))\n                if not continue_on_error:\n                    # Stop at first 
failure for PR tests\n                    break\n                # Otherwise continue to next test for nightly tests\n            else:\n                passed_tests.append(filename)\n        except TimeoutError:\n            kill_process_tree(process.pid)\n            time.sleep(5)\n            print(\n                f\"\\n✗ TIMEOUT: {filename} after {timeout_per_file} seconds\\n\",\n                flush=True,\n            )\n            success = False\n            failed_tests.append((filename, f\"timeout after {timeout_per_file}s\"))\n            if not continue_on_error:\n                # Stop at first timeout for PR tests\n                break\n            # Otherwise continue to next test for nightly tests\n\n    if success:\n        print(f\"Success. Time elapsed: {time.perf_counter() - tic:.2f}s\", flush=True)\n    else:\n        print(f\"Fail. Time elapsed: {time.perf_counter() - tic:.2f}s\", flush=True)\n\n    # Print summary\n    print(f\"\\n{'='*60}\", flush=True)\n    print(f\"Test Summary: {len(passed_tests)}/{len(files)} passed\", flush=True)\n    print(f\"{'='*60}\", flush=True)\n    if passed_tests:\n        print(\"✓ PASSED:\", flush=True)\n        for test in passed_tests:\n            print(f\"  {test}\", flush=True)\n    if failed_tests:\n        print(\"\\n✗ FAILED:\", flush=True)\n        for test, reason in failed_tests:\n            print(f\"  {test} ({reason})\", flush=True)\n    print(f\"{'='*60}\\n\", flush=True)\n\n    return 0 if success else -1\n"
  },
  {
    "path": "kt-kernel/test/per_commit/__init__.py",
    "content": "\"\"\"Per-commit tests for KT-Kernel.\n\nTests in this directory are run on every commit in CI.\n\"\"\"\n"
  },
  {
    "path": "kt-kernel/test/per_commit/test_amd_placeholder.py",
    "content": "\"\"\"AMD/ROCm backend tests for KT-Kernel (Placeholder).\n\nThis file is a placeholder for future AMD/ROCm backend tests.\nCurrently, KT-Kernel focuses on CPU optimizations (Intel AMX/AVX512).\n\nTo implement AMD tests:\n1. Add actual test functions with @pytest.mark.amd\n2. Update the estimated time in register_amd_ci()\n3. Implement AMD/ROCm-specific initialization and validation tests\n\"\"\"\n\nimport os\nimport sys\n\n# Add parent directory to path for CI registration\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\"))\nfrom ci.ci_register import register_amd_ci\n\n# Register this test for AMD CI (estimated time: 10 seconds, placeholder)\n# Update suite name when implementing: currently using \"stage-a-test-1\"\nregister_amd_ci(est_time=10, suite=\"stage-a-test-1\")\n\n\ndef test_amd_placeholder():\n    \"\"\"Placeholder test for AMD/ROCm backend.\n\n    TODO: Implement actual AMD/ROCm tests when AMD support is added to kt-kernel.\n    \"\"\"\n    # Currently a no-op placeholder\n    pass\n\n\nif __name__ == \"__main__\":\n    # Allow running standalone (required by test runner)\n    print(\"⚠ AMD/ROCm tests are not yet implemented (placeholder)\")\n    print(\"✓ Placeholder test passed\")\n"
  },
  {
    "path": "kt-kernel/test/per_commit/test_basic_cpu.py",
    "content": "\"\"\"Basic CPU backend tests for KT-Kernel.\n\nThese tests verify basic functionality without requiring model files.\n\"\"\"\n\nimport os\nimport sys\nimport pytest\n\n# Add parent directory to path for CI registration\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\"))\nfrom ci.ci_register import register_cpu_ci\n\n# Register this test for CPU CI with estimated runtime of 30 seconds\nregister_cpu_ci(est_time=30, suite=\"default\")\n\n# Check if kt_kernel_ext is available\ntry:\n    import kt_kernel  # Import kt_kernel first to register kt_kernel_ext\n\n    kt_kernel_ext = kt_kernel.kt_kernel_ext  # Access the extension module\n    HAS_KT_KERNEL = True\nexcept ImportError:\n    HAS_KT_KERNEL = False\n    kt_kernel_ext = None\n\n\n@pytest.mark.cpu\ndef test_kt_kernel_import():\n    \"\"\"Test that kt_kernel_ext can be imported.\"\"\"\n    if not HAS_KT_KERNEL:\n        pytest.skip(\"kt_kernel_ext not built or available\")\n\n    assert kt_kernel_ext is not None, \"kt_kernel_ext module should be importable\"\n\n\n@pytest.mark.cpu\ndef test_cpu_infer_initialization():\n    \"\"\"Test that CPUInfer can be initialized.\"\"\"\n    if not HAS_KT_KERNEL:\n        pytest.skip(\"kt_kernel_ext not built or available\")\n\n    # Initialize CPUInfer with 4 threads\n    cpuinfer = kt_kernel_ext.CPUInfer(4)\n    assert cpuinfer is not None, \"CPUInfer should be initialized successfully\"\n\n\n@pytest.mark.cpu\ndef test_basic_module_attributes():\n    \"\"\"Test that kt_kernel_ext has expected attributes.\"\"\"\n    if not HAS_KT_KERNEL:\n        pytest.skip(\"kt_kernel_ext not built or available\")\n\n    # Check for key attributes/functions\n    assert hasattr(kt_kernel_ext, \"CPUInfer\"), \"kt_kernel_ext should have CPUInfer class\"\n\n\ndef run_all_tests():\n    \"\"\"Run all tests in this file (for standalone execution).\"\"\"\n    if not HAS_KT_KERNEL:\n        print(\"⚠ kt_kernel_ext not available, skipping tests\")\n        return\n\n    
try:\n        test_kt_kernel_import()\n        print(\"✓ test_kt_kernel_import passed\")\n\n        test_cpu_infer_initialization()\n        print(\"✓ test_cpu_infer_initialization passed\")\n\n        test_basic_module_attributes()\n        print(\"✓ test_basic_module_attributes passed\")\n\n        print(\"\\n✓ All tests passed!\")\n    except Exception as e:\n        print(f\"\\n✗ Test failed: {e}\")\n        sys.exit(1)\n\n\nif __name__ == \"__main__\":\n    # Allow running standalone (required by test runner)\n    run_all_tests()\n"
  },
  {
    "path": "kt-kernel/test/per_commit/test_cuda_placeholder.py",
    "content": "\"\"\"CUDA backend tests for KT-Kernel (Placeholder).\n\nThis file is a placeholder for future CUDA backend tests.\nCurrently, KT-Kernel focuses on CPU optimizations (Intel AMX/AVX512).\n\nTo implement CUDA tests:\n1. Add actual test functions with @pytest.mark.cuda\n2. Update the estimated time in register_cuda_ci()\n3. Implement CUDA-specific initialization and validation tests\n\"\"\"\n\nimport os\nimport sys\n\n# Add parent directory to path for CI registration\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\"))\nfrom ci.ci_register import register_cuda_ci\n\n# Register this test for CUDA CI (estimated time: 10 seconds, placeholder)\n# Update suite name when implementing: currently using \"stage-a-test-1\"\nregister_cuda_ci(est_time=10, suite=\"stage-a-test-1\")\n\n\ndef test_cuda_placeholder():\n    \"\"\"Placeholder test for CUDA backend.\n\n    TODO: Implement actual CUDA tests when CUDA support is added to kt-kernel.\n    \"\"\"\n    # Currently a no-op placeholder\n    pass\n\n\nif __name__ == \"__main__\":\n    # Allow running standalone (required by test runner)\n    print(\"⚠ CUDA tests are not yet implemented (placeholder)\")\n    print(\"✓ Placeholder test passed\")\n"
  },
  {
    "path": "kt-kernel/test/per_commit/test_moe_amx_accuracy_int4.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"AMX MOE INT4 accuracy tests for KT-Kernel.\n\nTests accuracy of AMX-accelerated INT4 MOE operations against torch reference.\n\"\"\"\n\nimport os\nimport sys\nimport pytest\n\n# Add parent directory to path for CI registration\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\"))\nfrom ci.ci_register import register_cpu_ci\n\n# Register this test for CPU CI with estimated runtime of 120 seconds\nregister_cpu_ci(est_time=120, suite=\"default\")\n\n# Check if dependencies are available\ntry:\n    import torch\n    import kt_kernel  # Import kt_kernel first to register kt_kernel_ext\n\n    kt_kernel_ext = kt_kernel.kt_kernel_ext  # Access the extension module\n    HAS_DEPS = True\nexcept ImportError as e:\n    HAS_DEPS = False\n    import_error = str(e)\n\n# Test parameters (from original test_moe_amx.py)\nexpert_num = 256\nhidden_size = 7168\nintermediate_size = 2048\nmax_len = 25600\nnum_experts_per_tok = 8\nqlen = 1\nlayer_num = 1\nvalidation_iter = 2\nphysical_to_logical_map = None\n\n\ndef act_fn(x):\n    \"\"\"Activation function for MoE.\"\"\"\n    return x / (1.0 + torch.exp(-x))\n\n\ndef mlp_torch(input, gate_proj, up_proj, down_proj):\n    \"\"\"PyTorch reference implementation of MLP.\"\"\"\n    gate_buf = torch.mm(input, gate_proj.t())\n    up_buf = torch.mm(input, up_proj.t())\n    intermediate = act_fn(gate_buf) * up_buf\n    ret = torch.mm(intermediate, down_proj.t())\n    return ret\n\n\ndef moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):\n    \"\"\"PyTorch reference implementation of MoE.\"\"\"\n    cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))\n    cnts.scatter_(1, expert_ids, 1)\n    tokens_per_expert = cnts.sum(dim=0)\n    idxs = expert_ids.view(-1).argsort()\n    sorted_tokens = input[idxs // expert_ids.shape[1]]\n\n    outputs = []\n    start_idx = 0\n\n    for i, num_tokens in enumerate(tokens_per_expert):\n        end_idx = 
start_idx + num_tokens\n        if num_tokens == 0:\n            continue\n        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n        expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])\n        outputs.append(expert_out)\n        start_idx = end_idx\n\n    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n    new_x = torch.empty_like(outs)\n    new_x[idxs] = outs\n    t_output = (\n        new_x.view(*expert_ids.shape, -1)\n        .type(weights.dtype)\n        .mul_(weights.unsqueeze(dim=-1))\n        .sum(dim=1)\n        .type(new_x.dtype)\n    )\n\n    return t_output\n\n\n@pytest.mark.cpu\ndef test_moe_amx_int4_accuracy():\n    \"\"\"Test AMX INT4 MOE accuracy against PyTorch reference implementation.\"\"\"\n    if not HAS_DEPS:\n        pytest.skip(f\"Dependencies not available: {import_error}\")\n\n    global physical_to_logical_map\n    physical_to_logical_map = torch.tensor(data=range(expert_num), device=\"cpu\", dtype=torch.int64).contiguous()\n\n    CPUInfer = kt_kernel_ext.CPUInfer(60)\n\n    with torch.inference_mode(mode=True):\n        # Initialize MoE layers\n        gate_proj = (\n            torch.randn(\n                (expert_num, intermediate_size, hidden_size),\n                dtype=torch.bfloat16,\n                device=\"cuda\",\n            )\n            .to(\"cpu\")\n            .contiguous()\n        )\n        up_proj = (\n            torch.randn(\n                (expert_num, intermediate_size, hidden_size),\n                dtype=torch.bfloat16,\n                device=\"cuda\",\n            )\n            .to(\"cpu\")\n            .contiguous()\n        )\n        down_proj = (\n            torch.randn(\n                (expert_num, hidden_size, intermediate_size),\n                dtype=torch.bfloat16,\n                device=\"cuda\",\n            )\n            .to(\"cpu\")\n            .contiguous()\n        )\n\n        # Create MOE config\n  
      config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)\n        config.max_len = max_len\n        config.gate_proj = gate_proj.data_ptr()\n        config.up_proj = up_proj.data_ptr()\n        config.down_proj = down_proj.data_ptr()\n        config.gate_scale = 0\n        config.pool = CPUInfer.backend_\n\n        # Initialize INT4 MOE\n        moe = kt_kernel_ext.moe.AMXInt4_MOE(config)\n        CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n        CPUInfer.sync()\n        CPUInfer.submit(moe.warm_up_task())\n        CPUInfer.sync()\n\n        # Run validation iterations\n        for i in range(validation_iter):\n            bsz_tensor = torch.tensor([qlen], device=\"cpu\")\n            expert_ids = torch.stack(\n                [torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]\n            ).contiguous()\n            weights = torch.rand((qlen, num_experts_per_tok), dtype=torch.float32).contiguous()\n            input_data = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous()\n            output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()\n            input_data = input_data / 100\n\n            # Run AMX MOE\n            CPUInfer.submit(\n                moe.forward_task(\n                    bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids.data_ptr(),\n                    weights.data_ptr(),\n                    input_data.data_ptr(),\n                    output.data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n\n            # Run torch reference\n            t_output = moe_torch(input_data, expert_ids, weights, gate_proj, up_proj, down_proj)\n\n            # Calculate relative difference\n            diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))\n            print(f\"Iteration 
{i}, diff = {diff:.6f}\")\n\n            # INT4 should have diff < 0.35\n            assert diff < 0.35, f\"INT4 accuracy test failed: diff={diff:.6f} >= 0.35\"\n\n\ndef run_all_tests():\n    \"\"\"Run all tests in this file (for standalone execution).\"\"\"\n    if not HAS_DEPS:\n        print(f\"⚠ Dependencies not available: {import_error}\")\n        print(\"Skipping AMX MOE INT4 accuracy tests\")\n        return\n\n    try:\n        print(\"Running AMX MOE INT4 accuracy test...\")\n        test_moe_amx_int4_accuracy()\n        print(\"✓ AMX MOE INT4 accuracy test passed\")\n        print(\"\\n✓ All tests passed!\")\n    except Exception as e:\n        print(f\"\\n✗ Test failed: {e}\")\n        import traceback\n\n        traceback.print_exc()\n        sys.exit(1)\n\n\nif __name__ == \"__main__\":\n    run_all_tests()\n"
  },
  {
    "path": "kt-kernel/test/per_commit/test_moe_amx_accuracy_int4_1.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"AMX MOE INT4_1 accuracy tests for KT-Kernel.\n\nTests accuracy of AMX-accelerated INT4_1 MOE operations against torch reference.\n\"\"\"\n\nimport os\nimport sys\nimport pytest\n\n# Add parent directory to path for CI registration\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\"))\nfrom ci.ci_register import register_cpu_ci\n\n# Register this test for CPU CI with estimated runtime of 120 seconds\nregister_cpu_ci(est_time=120, suite=\"default\")\n\n# Check if dependencies are available\ntry:\n    import torch\n    import kt_kernel  # Import kt_kernel first to register kt_kernel_ext\n\n    kt_kernel_ext = kt_kernel.kt_kernel_ext  # Access the extension module\n    HAS_DEPS = True\nexcept ImportError as e:\n    HAS_DEPS = False\n    import_error = str(e)\n\n# Test parameters (from original test_moe_amx.py)\nexpert_num = 256\nhidden_size = 7168\nintermediate_size = 2048\nmax_len = 25600\nnum_experts_per_tok = 8\nqlen = 1\nlayer_num = 1\nvalidation_iter = 2\nphysical_to_logical_map = None\n\n\ndef act_fn(x):\n    \"\"\"Activation function for MoE.\"\"\"\n    return x / (1.0 + torch.exp(-x))\n\n\ndef mlp_torch(input, gate_proj, up_proj, down_proj):\n    \"\"\"PyTorch reference implementation of MLP.\"\"\"\n    gate_buf = torch.mm(input, gate_proj.t())\n    up_buf = torch.mm(input, up_proj.t())\n    intermediate = act_fn(gate_buf) * up_buf\n    ret = torch.mm(intermediate, down_proj.t())\n    return ret\n\n\ndef moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):\n    \"\"\"PyTorch reference implementation of MoE.\"\"\"\n    cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))\n    cnts.scatter_(1, expert_ids, 1)\n    tokens_per_expert = cnts.sum(dim=0)\n    idxs = expert_ids.view(-1).argsort()\n    sorted_tokens = input[idxs // expert_ids.shape[1]]\n\n    outputs = []\n    start_idx = 0\n\n    for i, num_tokens in enumerate(tokens_per_expert):\n        end_idx = 
start_idx + num_tokens\n        if num_tokens == 0:\n            continue\n        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n        expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])\n        outputs.append(expert_out)\n        start_idx = end_idx\n\n    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n    new_x = torch.empty_like(outs)\n    new_x[idxs] = outs\n    t_output = (\n        new_x.view(*expert_ids.shape, -1)\n        .type(weights.dtype)\n        .mul_(weights.unsqueeze(dim=-1))\n        .sum(dim=1)\n        .type(new_x.dtype)\n    )\n\n    return t_output\n\n\n@pytest.mark.cpu\ndef test_moe_amx_int4_1_accuracy():\n    \"\"\"Test AMX INT4_1 MOE accuracy against PyTorch reference implementation.\"\"\"\n    if not HAS_DEPS:\n        pytest.skip(f\"Dependencies not available: {import_error}\")\n\n    global physical_to_logical_map\n    physical_to_logical_map = torch.tensor(data=range(expert_num), device=\"cpu\", dtype=torch.int64).contiguous()\n\n    CPUInfer = kt_kernel_ext.CPUInfer(60)\n\n    with torch.inference_mode(mode=True):\n        # Initialize MoE layers\n        gate_proj = (\n            torch.randn(\n                (expert_num, intermediate_size, hidden_size),\n                dtype=torch.bfloat16,\n                device=\"cuda\",\n            )\n            .to(\"cpu\")\n            .contiguous()\n        )\n        up_proj = (\n            torch.randn(\n                (expert_num, intermediate_size, hidden_size),\n                dtype=torch.bfloat16,\n                device=\"cuda\",\n            )\n            .to(\"cpu\")\n            .contiguous()\n        )\n        down_proj = (\n            torch.randn(\n                (expert_num, hidden_size, intermediate_size),\n                dtype=torch.bfloat16,\n                device=\"cuda\",\n            )\n            .to(\"cpu\")\n            .contiguous()\n        )\n\n        # Create MOE 
config\n        config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)\n        config.max_len = max_len\n        config.gate_proj = gate_proj.data_ptr()\n        config.up_proj = up_proj.data_ptr()\n        config.down_proj = down_proj.data_ptr()\n        config.gate_scale = 0\n        config.pool = CPUInfer.backend_\n\n        # Initialize INT4_1 MOE\n        moe = kt_kernel_ext.moe.AMXInt4_1_MOE(config)\n        CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n        CPUInfer.sync()\n        CPUInfer.submit(moe.warm_up_task())\n        CPUInfer.sync()\n\n        # Run validation iterations\n        for i in range(validation_iter):\n            bsz_tensor = torch.tensor([qlen], device=\"cpu\")\n            expert_ids = torch.stack(\n                [torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]\n            ).contiguous()\n            weights = torch.rand((qlen, num_experts_per_tok), dtype=torch.float32).contiguous()\n            input_data = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous()\n            output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()\n            input_data = input_data / 100\n\n            # Run AMX MOE\n            CPUInfer.submit(\n                moe.forward_task(\n                    bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids.data_ptr(),\n                    weights.data_ptr(),\n                    input_data.data_ptr(),\n                    output.data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n\n            # Run torch reference\n            t_output = moe_torch(input_data, expert_ids, weights, gate_proj, up_proj, down_proj)\n\n            # Calculate relative difference\n            diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))\n            
print(f\"Iteration {i}, diff = {diff:.6f}\")\n\n            # INT4_1 should have diff < 0.35\n            assert diff < 0.35, f\"INT4_1 accuracy test failed: diff={diff:.6f} >= 0.35\"\n\n\ndef run_all_tests():\n    \"\"\"Run all tests in this file (for standalone execution).\"\"\"\n    if not HAS_DEPS:\n        print(f\"⚠ Dependencies not available: {import_error}\")\n        print(\"Skipping AMX MOE INT4_1 accuracy tests\")\n        return\n\n    try:\n        print(\"Running AMX MOE INT4_1 accuracy test...\")\n        test_moe_amx_int4_1_accuracy()\n        print(\"✓ AMX MOE INT4_1 accuracy test passed\")\n        print(\"\\n✓ All tests passed!\")\n    except Exception as e:\n        print(f\"\\n✗ Test failed: {e}\")\n        import traceback\n\n        traceback.print_exc()\n        sys.exit(1)\n\n\nif __name__ == \"__main__\":\n    run_all_tests()\n"
  },
  {
    "path": "kt-kernel/test/per_commit/test_moe_amx_accuracy_int4_1k.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"AMX MOE INT4_1K accuracy tests for KT-Kernel.\n\nTests accuracy of AMX-accelerated INT4_1K group quantization MOE operations against torch reference.\n\"\"\"\n\nimport os\nimport sys\nimport pytest\n\n# Add parent directory to path for CI registration\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\"))\nfrom ci.ci_register import register_cpu_ci\n\n# Register this test for CPU CI with estimated runtime of 120 seconds\nregister_cpu_ci(est_time=120, suite=\"default\")\n\n# Check if dependencies are available\ntry:\n    import torch\n    import kt_kernel  # Import kt_kernel first to register kt_kernel_ext\n\n    kt_kernel_ext = kt_kernel.kt_kernel_ext  # Access the extension module\n    HAS_DEPS = True\nexcept ImportError as e:\n    HAS_DEPS = False\n    import_error = str(e)\n\n# Test parameters (from original test_moe_amx.py)\nexpert_num = 256\nhidden_size = 7168\nintermediate_size = 2048\nmax_len = 25600\nnum_experts_per_tok = 8\nqlen = 1\nlayer_num = 1\nvalidation_iter = 2\nk_group_size = 64\nphysical_to_logical_map = None\n\n\ndef act_fn(x):\n    \"\"\"Activation function for MoE.\"\"\"\n    return x / (1.0 + torch.exp(-x))\n\n\ndef mlp_torch(input, gate_proj, up_proj, down_proj):\n    \"\"\"PyTorch reference implementation of MLP.\"\"\"\n    gate_buf = torch.mm(input, gate_proj.t())\n    up_buf = torch.mm(input, up_proj.t())\n    intermediate = act_fn(gate_buf) * up_buf\n    ret = torch.mm(intermediate, down_proj.t())\n    return ret\n\n\ndef moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):\n    \"\"\"PyTorch reference implementation of MoE.\"\"\"\n    cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))\n    cnts.scatter_(1, expert_ids, 1)\n    tokens_per_expert = cnts.sum(dim=0)\n    idxs = expert_ids.view(-1).argsort()\n    sorted_tokens = input[idxs // expert_ids.shape[1]]\n\n    outputs = []\n    start_idx = 0\n\n    for i, num_tokens in 
enumerate(tokens_per_expert):\n        end_idx = start_idx + num_tokens\n        if num_tokens == 0:\n            continue\n        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n        expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])\n        outputs.append(expert_out)\n        start_idx = end_idx\n\n    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n    new_x = torch.empty_like(outs)\n    new_x[idxs] = outs\n    t_output = (\n        new_x.view(*expert_ids.shape, -1)\n        .type(weights.dtype)\n        .mul_(weights.unsqueeze(dim=-1))\n        .sum(dim=1)\n        .type(new_x.dtype)\n    )\n\n    return t_output\n\n\n@pytest.mark.cpu\ndef test_moe_amx_int4_1k_accuracy():\n    \"\"\"Test AMX INT4_1K MOE accuracy against PyTorch reference implementation.\"\"\"\n    if not HAS_DEPS:\n        pytest.skip(f\"Dependencies not available: {import_error}\")\n\n    global physical_to_logical_map\n    physical_to_logical_map = torch.tensor(data=range(expert_num), device=\"cpu\", dtype=torch.int64).contiguous()\n\n    CPUInfer = kt_kernel_ext.CPUInfer(60)\n\n    with torch.inference_mode(mode=True):\n        # Initialize MoE layers\n        gate_proj = (\n            torch.randn(\n                (expert_num, intermediate_size, hidden_size),\n                dtype=torch.bfloat16,\n                device=\"cuda\",\n            )\n            .to(\"cpu\")\n            .contiguous()\n        )\n        up_proj = (\n            torch.randn(\n                (expert_num, intermediate_size, hidden_size),\n                dtype=torch.bfloat16,\n                device=\"cuda\",\n            )\n            .to(\"cpu\")\n            .contiguous()\n        )\n        down_proj = (\n            torch.randn(\n                (expert_num, hidden_size, intermediate_size),\n                dtype=torch.bfloat16,\n                device=\"cuda\",\n            )\n            .to(\"cpu\")\n            
.contiguous()\n        )\n\n        # Create MOE config\n        config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)\n        config.max_len = max_len\n        config.gate_proj = gate_proj.data_ptr()\n        config.up_proj = up_proj.data_ptr()\n        config.down_proj = down_proj.data_ptr()\n        config.gate_scale = 0\n        config.pool = CPUInfer.backend_\n\n        # Configure INT4_1K quantization settings\n        config.quant_config.bits = 4\n        config.quant_config.group_size = k_group_size\n        config.quant_config.zero_point = True\n\n        # Initialize INT4_1K MOE\n        moe = kt_kernel_ext.moe.AMXInt4_1KGroup_MOE(config)\n        CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n        CPUInfer.sync()\n\n        # Run validation iterations\n        for i in range(validation_iter):\n            bsz_tensor = torch.tensor([qlen], device=\"cpu\")\n            expert_ids = torch.stack(\n                [torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]\n            ).contiguous()\n            weights = torch.rand((qlen, num_experts_per_tok), dtype=torch.float32).contiguous()\n            input_data = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous()\n            output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()\n            input_data = input_data / 100\n\n            # Run AMX MOE\n            CPUInfer.submit(\n                moe.forward_task(\n                    bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids.data_ptr(),\n                    weights.data_ptr(),\n                    input_data.data_ptr(),\n                    output.data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n\n            # Run torch reference\n            t_output = moe_torch(input_data, expert_ids, weights, gate_proj, 
up_proj, down_proj)\n\n            # Calculate relative difference\n            diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))\n            print(f\"Iteration {i}, diff = {diff:.6f}\")\n\n            # INT4_1K should have diff < 0.35\n            assert diff < 0.35, f\"INT4_1K accuracy test failed: diff={diff:.6f} >= 0.35\"\n\n\ndef run_all_tests():\n    \"\"\"Run all tests in this file (for standalone execution).\"\"\"\n    if not HAS_DEPS:\n        print(f\"⚠ Dependencies not available: {import_error}\")\n        print(\"Skipping AMX MOE INT4_1K accuracy tests\")\n        return\n\n    try:\n        print(\"Running AMX MOE INT4_1K accuracy test...\")\n        test_moe_amx_int4_1k_accuracy()\n        print(\"✓ AMX MOE INT4_1K accuracy test passed\")\n        print(\"\\n✓ All tests passed!\")\n    except Exception as e:\n        print(f\"\\n✗ Test failed: {e}\")\n        import traceback\n\n        traceback.print_exc()\n        sys.exit(1)\n\n\nif __name__ == \"__main__\":\n    run_all_tests()\n"
  },
  {
    "path": "kt-kernel/test/per_commit/test_moe_amx_accuracy_int8.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"AMX MOE INT8 accuracy tests for KT-Kernel.\n\nTests accuracy of AMX-accelerated INT8 MOE operations against torch reference.\n\"\"\"\n\nimport os\nimport sys\nimport pytest\n\n# Add parent directory to path for CI registration\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\"))\nfrom ci.ci_register import register_cpu_ci\n\n# Register this test for CPU CI with estimated runtime of 120 seconds\nregister_cpu_ci(est_time=120, suite=\"default\")\n\n# Check if dependencies are available\ntry:\n    import torch\n    import kt_kernel  # Import kt_kernel first to register kt_kernel_ext\n\n    kt_kernel_ext = kt_kernel.kt_kernel_ext  # Access the extension module\n    HAS_DEPS = True\nexcept ImportError as e:\n    HAS_DEPS = False\n    import_error = str(e)\n\n# Test parameters (from original test_moe_amx.py)\nexpert_num = 256\nhidden_size = 7168\nintermediate_size = 2048\nmax_len = 25600\nnum_experts_per_tok = 8\nqlen = 1\nlayer_num = 1\nvalidation_iter = 2\nphysical_to_logical_map = None\n\n\ndef act_fn(x):\n    \"\"\"Activation function for MoE.\"\"\"\n    return x / (1.0 + torch.exp(-x))\n\n\ndef mlp_torch(input, gate_proj, up_proj, down_proj):\n    \"\"\"PyTorch reference implementation of MLP.\"\"\"\n    gate_buf = torch.mm(input, gate_proj.t())\n    up_buf = torch.mm(input, up_proj.t())\n    intermediate = act_fn(gate_buf) * up_buf\n    ret = torch.mm(intermediate, down_proj.t())\n    return ret\n\n\ndef moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):\n    \"\"\"PyTorch reference implementation of MoE.\"\"\"\n    cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))\n    cnts.scatter_(1, expert_ids, 1)\n    tokens_per_expert = cnts.sum(dim=0)\n    idxs = expert_ids.view(-1).argsort()\n    sorted_tokens = input[idxs // expert_ids.shape[1]]\n\n    outputs = []\n    start_idx = 0\n\n    for i, num_tokens in enumerate(tokens_per_expert):\n        end_idx = 
start_idx + num_tokens\n        if num_tokens == 0:\n            continue\n        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n        expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])\n        outputs.append(expert_out)\n        start_idx = end_idx\n\n    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n    new_x = torch.empty_like(outs)\n    new_x[idxs] = outs\n    t_output = (\n        new_x.view(*expert_ids.shape, -1)\n        .type(weights.dtype)\n        .mul_(weights.unsqueeze(dim=-1))\n        .sum(dim=1)\n        .type(new_x.dtype)\n    )\n\n    return t_output\n\n\n@pytest.mark.cpu\ndef test_moe_amx_int8_accuracy():\n    \"\"\"Test AMX INT8 MOE accuracy against PyTorch reference implementation.\"\"\"\n    if not HAS_DEPS:\n        pytest.skip(f\"Dependencies not available: {import_error}\")\n\n    global physical_to_logical_map\n    physical_to_logical_map = torch.tensor(data=range(expert_num), device=\"cpu\", dtype=torch.int64).contiguous()\n\n    CPUInfer = kt_kernel_ext.CPUInfer(60)\n\n    with torch.inference_mode(mode=True):\n        # Initialize MoE layers\n        gate_proj = (\n            torch.randn(\n                (expert_num, intermediate_size, hidden_size),\n                dtype=torch.bfloat16,\n                device=\"cuda\",\n            )\n            .to(\"cpu\")\n            .contiguous()\n        )\n        up_proj = (\n            torch.randn(\n                (expert_num, intermediate_size, hidden_size),\n                dtype=torch.bfloat16,\n                device=\"cuda\",\n            )\n            .to(\"cpu\")\n            .contiguous()\n        )\n        down_proj = (\n            torch.randn(\n                (expert_num, hidden_size, intermediate_size),\n                dtype=torch.bfloat16,\n                device=\"cuda\",\n            )\n            .to(\"cpu\")\n            .contiguous()\n        )\n\n        # Create MOE config\n  
      config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)\n        config.max_len = max_len\n        config.gate_proj = gate_proj.data_ptr()\n        config.up_proj = up_proj.data_ptr()\n        config.down_proj = down_proj.data_ptr()\n        config.gate_scale = 0\n        config.pool = CPUInfer.backend_\n\n        # Initialize INT8 MOE\n        moe = kt_kernel_ext.moe.AMXInt8_MOE(config)\n        CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n        CPUInfer.sync()\n\n        # Run validation iterations\n        for i in range(validation_iter):\n            bsz_tensor = torch.tensor([qlen], device=\"cpu\")\n            expert_ids = torch.stack(\n                [torch.randperm(expert_num)[:num_experts_per_tok] for _ in range(qlen)]\n            ).contiguous()\n            weights = torch.rand((qlen, num_experts_per_tok), dtype=torch.float32).contiguous()\n            input_data = torch.randn((qlen, hidden_size), dtype=torch.bfloat16).contiguous()\n            output = torch.empty((qlen, hidden_size), dtype=torch.bfloat16).contiguous()\n            input_data = input_data / 100\n\n            # Run AMX MOE\n            CPUInfer.submit(\n                moe.forward_task(\n                    bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids.data_ptr(),\n                    weights.data_ptr(),\n                    input_data.data_ptr(),\n                    output.data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n\n            # Run torch reference\n            t_output = moe_torch(input_data, expert_ids, weights, gate_proj, up_proj, down_proj)\n\n            # Calculate relative difference\n            diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))\n            print(f\"Iteration {i}, diff = {diff:.6f}\")\n\n            # INT8 should have diff < 
0.05\n            assert diff < 0.05, f\"INT8 accuracy test failed: diff={diff:.6f} >= 0.05\"\n\n\ndef run_all_tests():\n    \"\"\"Run all tests in this file (for standalone execution).\"\"\"\n    if not HAS_DEPS:\n        print(f\"⚠ Dependencies not available: {import_error}\")\n        print(\"Skipping AMX MOE INT8 accuracy tests\")\n        return\n\n    try:\n        print(\"Running AMX MOE INT8 accuracy test...\")\n        test_moe_amx_int8_accuracy()\n        print(\"✓ AMX MOE INT8 accuracy test passed\")\n        print(\"\\n✓ All tests passed!\")\n    except Exception as e:\n        print(f\"\\n✗ Test failed: {e}\")\n        import traceback\n\n        traceback.print_exc()\n        sys.exit(1)\n\n\nif __name__ == \"__main__\":\n    run_all_tests()\n"
  },
  {
    "path": "kt-kernel/test/per_commit/test_moe_amx_bench_int4.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"AMX MOE INT4 benchmark tests for KT-Kernel.\n\nBenchmarks performance (bandwidth and FLOPS) of AMX-accelerated INT4 MOE operations.\n\"\"\"\n\nimport os\nimport sys\nimport time\nimport json\nimport subprocess\nimport platform\nimport pytest\n\n# Add parent directory to path for CI registration\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\"))\nfrom ci.ci_register import register_cpu_ci\n\n# Register this test for CPU CI with estimated runtime of 300 seconds\nregister_cpu_ci(est_time=300, suite=\"default\")\n\n# Check if dependencies are available\ntry:\n    import torch\n    import kt_kernel  # Import kt_kernel first to register kt_kernel_ext\n\n    kt_kernel_ext = kt_kernel.kt_kernel_ext  # Access the extension module\n    from tqdm import tqdm\n\n    HAS_DEPS = True\nexcept ImportError as e:\n    HAS_DEPS = False\n    import_error = str(e)\n\n# Test parameters (from original bench_moe_amx.py)\nexpert_num = 16\nhidden_size = 7168\nintermediate_size = 2048\nmax_len = 25600\nnum_experts_per_tok = 8\nlayer_num = 2\nqlen = 2048\nwarm_up_iter = 1000\ntest_iter = 2000\n\n# Worker configuration\nworker_config_dict = {\n    \"subpool_count\": 2,\n    \"subpool_numa_map\": [0, 1],\n    \"subpool_thread_count\": [30, 30],\n}\nCPUINFER_PARAM = 60\n\n\ndef get_git_commit():\n    \"\"\"Get current git commit information.\"\"\"\n    result = {}\n    try:\n        commit = subprocess.check_output([\"git\", \"rev-parse\", \"HEAD\"]).decode(\"utf-8\").strip()\n        commit_msg = subprocess.check_output([\"git\", \"log\", \"-1\", \"--pretty=%B\"]).decode(\"utf-8\").strip()\n        result[\"commit\"] = commit\n        result[\"commit_message\"] = commit_msg\n\n        dirty_output = subprocess.check_output([\"git\", \"status\", \"--porcelain\"]).decode(\"utf-8\").strip()\n        if dirty_output:\n            result[\"dirty\"] = True\n            result[\"dirty_files\"] = dirty_output.splitlines()\n  
      else:\n            result[\"dirty\"] = False\n    except Exception as e:\n        result[\"commit\"] = None\n        result[\"commit_message\"] = None\n        result[\"dirty\"] = None\n        result[\"error\"] = str(e)\n    return result\n\n\ndef get_system_info():\n    \"\"\"Get system information including CPU model, memory, cores, and sockets.\"\"\"\n    info = {}\n    uname = platform.uname()\n    info[\"system_name\"] = uname.system\n    info[\"node_name\"] = uname.node\n\n    # Get CPU model (Linux only)\n    cpu_model = None\n    if os.path.exists(\"/proc/cpuinfo\"):\n        try:\n            with open(\"/proc/cpuinfo\", \"r\") as f:\n                for line in f:\n                    if \"model name\" in line:\n                        cpu_model = line.split(\":\", 1)[1].strip()\n                        break\n        except Exception as e:\n            cpu_model = f\"Error: {e}\"\n    info[\"cpu_model\"] = cpu_model\n\n    # Get memory size in GB (Linux only)\n    mem_total_gb = None\n    if os.path.exists(\"/proc/meminfo\"):\n        try:\n            with open(\"/proc/meminfo\", \"r\") as f:\n                for line in f:\n                    if \"MemTotal\" in line:\n                        mem_kb = float(line.split(\":\", 1)[1].split()[0])\n                        mem_total_gb = round(mem_kb / (1024 * 1024), 2)\n                        break\n        except Exception as e:\n            mem_total_gb = f\"Error: {e}\"\n    info[\"memory_size_GB\"] = mem_total_gb\n\n    # Get CPU core count\n    info[\"cpu_core_count\"] = os.cpu_count()\n\n    # Get socket count\n    sockets = set()\n    if os.path.exists(\"/proc/cpuinfo\"):\n        try:\n            with open(\"/proc/cpuinfo\", \"r\") as f:\n                for line in f:\n                    if \"physical id\" in line:\n                        sockets.add(line.split(\":\", 1)[1].strip())\n        except Exception as e:\n            sockets = set()\n    info[\"cpu_socket_count\"] = 
len(sockets) if len(sockets) > 0 else 1\n\n    return info\n\n\ndef record_results(result, filename):\n    \"\"\"Append results to JSONL file.\"\"\"\n    with open(filename, \"a\") as f:\n        f.write(json.dumps(result) + \"\\n\")\n\n\n@pytest.mark.cpu\ndef test_moe_amx_int4_benchmark():\n    \"\"\"Benchmark AMX INT4 MOE performance.\"\"\"\n    if not HAS_DEPS:\n        pytest.skip(f\"Dependencies not available: {import_error}\")\n\n    quant_mode = \"int4\"\n    bytes_per_elem = 0.5\n\n    # Setup output file\n    script_dir = os.path.dirname(os.path.abspath(__file__))\n    json_path = os.path.join(script_dir, \"bench_moe_amx_int4.jsonl\")\n\n    with torch.inference_mode():\n        # Initialize CPUInfer with worker config\n        worker_config = kt_kernel_ext.WorkerPoolConfig()\n        worker_config.subpool_count = worker_config_dict[\"subpool_count\"]\n        worker_config.subpool_numa_map = worker_config_dict[\"subpool_numa_map\"]\n        worker_config.subpool_thread_count = worker_config_dict[\"subpool_thread_count\"]\n        CPUInfer = kt_kernel_ext.CPUInfer(worker_config)\n\n        # Initialize MOE layers\n        moes = []\n        for layer_index in range(layer_num):\n            gate_proj = (\n                torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device=\"cuda\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            up_proj = (\n                torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device=\"cuda\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            down_proj = (\n                torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32, device=\"cuda\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)\n            config.max_len = 
max_len\n            config.gate_proj = gate_proj.data_ptr()\n            config.up_proj = up_proj.data_ptr()\n            config.down_proj = down_proj.data_ptr()\n            config.pool = CPUInfer.backend_\n\n            moe = kt_kernel_ext.moe.AMXInt4_MOE(config)\n            CPUInfer.submit(moe.load_weights_task())\n            CPUInfer.sync()\n            moes.append(moe)\n\n        # Generate test data\n        gen_iter = 3000\n        expert_ids = (\n            torch.rand(gen_iter * qlen, expert_num, device=\"cpu\")\n            .argsort(dim=-1)[:, :num_experts_per_tok]\n            .reshape(gen_iter, qlen * num_experts_per_tok)\n            .to(\"cpu\")\n            .contiguous()\n        )\n        weights = (\n            torch.rand((gen_iter, qlen, num_experts_per_tok), dtype=torch.float32, device=\"cpu\").to(\"cpu\").contiguous()\n        )\n        input_tensor = (\n            torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device=\"cuda\").to(\"cpu\").contiguous()\n        )\n        output_tensor = (\n            torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device=\"cuda\").to(\"cpu\").contiguous()\n        )\n        bsz_tensor = torch.tensor([qlen], device=\"cpu\")\n\n        # Warm-up iterations\n        print(f\"Running warm-up for {warm_up_iter} iterations...\")\n        for i in tqdm(range(warm_up_iter), desc=\"Warm-up\"):\n            CPUInfer.submit(\n                moes[i % layer_num].forward_task(\n                    bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids[i % gen_iter].data_ptr(),\n                    weights[i % gen_iter].data_ptr(),\n                    input_tensor[i % layer_num].data_ptr(),\n                    output_tensor[i % layer_num].data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n\n        # Test iterations\n        print(f\"Running test for {test_iter} iterations...\")\n     
   start = time.perf_counter()\n        for i in tqdm(range(test_iter), desc=\"Testing\"):\n            CPUInfer.submit(\n                moes[i % layer_num].forward_task(\n                    bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids[i % gen_iter].data_ptr(),\n                    weights[i % gen_iter].data_ptr(),\n                    input_tensor[i % layer_num].data_ptr(),\n                    output_tensor[i % layer_num].data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n        end = time.perf_counter()\n        total_time = end - start\n\n        # Calculate performance metrics\n        time_per_iter_us = total_time / test_iter * 1e6\n        # Weight traffic is scaled by the expected number of distinct experts hit by qlen tokens,\n        # derived from expert_num and num_experts_per_tok instead of hard-coded 256-expert constants.\n        bandwidth = (\n            hidden_size\n            * intermediate_size\n            * 3\n            * num_experts_per_tok\n            * (expert_num / num_experts_per_tok * (1 - (1 - num_experts_per_tok / expert_num) ** qlen))\n            * bytes_per_elem\n            * test_iter\n            / total_time\n            / 1e9\n        )  # GB/s\n        flops = (\n            hidden_size * intermediate_size * qlen * 3 * num_experts_per_tok * 2 * test_iter / total_time / 1e12\n        )  # TFLOPS\n\n        print(\"Quant mode: \", quant_mode)\n        print(\"Time(s): \", total_time)\n        print(\"Iteration: \", test_iter)\n        print(\"Time(us) per iteration: \", time_per_iter_us)\n        print(\"Bandwidth: \", bandwidth, \"GB/s\")\n        print(\"Flops: \", flops, \"TFLOPS\")\n\n        # Record results\n        result = {\n            \"quant_mode\": quant_mode,\n            \"total_time_seconds\": total_time,\n            \"iterations\": test_iter,\n            \"time_per_iteration_us\": time_per_iter_us,\n            \"bandwidth_GBs\": bandwidth,\n            \"flops_TFLOPS\": flops,\n            \"timestamp\": time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime()),\n            \"test_parameters\": {\n                \"expert_num\": expert_num,\n       
         \"hidden_size\": hidden_size,\n                \"intermediate_size\": intermediate_size,\n                \"max_len\": max_len,\n                \"num_experts_per_tok\": num_experts_per_tok,\n                \"layer_num\": layer_num,\n                \"qlen\": qlen,\n                \"warm_up_iter\": warm_up_iter,\n                \"test_iter\": test_iter,\n                \"CPUInfer_parameter\": CPUINFER_PARAM,\n            },\n        }\n        result.update(get_git_commit())\n        result.update(get_system_info())\n        record_results(result, json_path)\n\n        print(f\"Results saved to {json_path}\")\n\n\ndef run_all_tests():\n    \"\"\"Run all tests in this file (for standalone execution).\"\"\"\n    if not HAS_DEPS:\n        print(f\"⚠ Dependencies not available: {import_error}\")\n        print(\"Skipping AMX MOE INT4 benchmark tests\")\n        return\n\n    try:\n        print(\"Running AMX MOE INT4 benchmark test...\")\n        test_moe_amx_int4_benchmark()\n        print(\"✓ AMX MOE INT4 benchmark test passed\")\n        print(\"\\n✓ All tests passed!\")\n    except Exception as e:\n        print(f\"\\n✗ Test failed: {e}\")\n        import traceback\n\n        traceback.print_exc()\n        sys.exit(1)\n\n\nif __name__ == \"__main__\":\n    run_all_tests()\n"
  },
  {
    "path": "kt-kernel/test/per_commit/test_moe_amx_bench_int4_1.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"AMX MOE INT4 benchmark tests for KT-Kernel.\n\nBenchmarks performance (bandwidth and FLOPS) of AMX-accelerated INT4 MOE operations.\n\"\"\"\n\nimport os\nimport sys\nimport time\nimport json\nimport subprocess\nimport platform\nimport pytest\n\n# Add parent directory to path for CI registration\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\"))\nfrom ci.ci_register import register_cpu_ci\n\n# Register this test for CPU CI with estimated runtime of 300 seconds\nregister_cpu_ci(est_time=300, suite=\"default\")\n\n# Check if dependencies are available\ntry:\n    import torch\n    import kt_kernel  # Import kt_kernel first to register kt_kernel_ext\n\n    kt_kernel_ext = kt_kernel.kt_kernel_ext  # Access the extension module\n    from tqdm import tqdm\n\n    HAS_DEPS = True\nexcept ImportError as e:\n    HAS_DEPS = False\n    import_error = str(e)\n\n# Test parameters (from original bench_moe_amx.py)\nexpert_num = 16\nhidden_size = 7168\nintermediate_size = 2048\nmax_len = 25600\nnum_experts_per_tok = 8\nlayer_num = 2\nqlen = 1024\nwarm_up_iter = 1000\ntest_iter = 2000\n\n# Worker configuration\nworker_config_dict = {\n    \"subpool_count\": 2,\n    \"subpool_numa_map\": [0, 1],\n    \"subpool_thread_count\": [30, 30],\n}\nCPUINFER_PARAM = 60\n\n\ndef get_git_commit():\n    \"\"\"Get current git commit information.\"\"\"\n    result = {}\n    try:\n        commit = subprocess.check_output([\"git\", \"rev-parse\", \"HEAD\"]).decode(\"utf-8\").strip()\n        commit_msg = subprocess.check_output([\"git\", \"log\", \"-1\", \"--pretty=%B\"]).decode(\"utf-8\").strip()\n        result[\"commit\"] = commit\n        result[\"commit_message\"] = commit_msg\n\n        dirty_output = subprocess.check_output([\"git\", \"status\", \"--porcelain\"]).decode(\"utf-8\").strip()\n        if dirty_output:\n            result[\"dirty\"] = True\n            result[\"dirty_files\"] = dirty_output.splitlines()\n  
      else:\n            result[\"dirty\"] = False\n    except Exception as e:\n        result[\"commit\"] = None\n        result[\"commit_message\"] = None\n        result[\"dirty\"] = None\n        result[\"error\"] = str(e)\n    return result\n\n\ndef get_system_info():\n    \"\"\"Get system information including CPU model, memory, cores, and sockets.\"\"\"\n    info = {}\n    uname = platform.uname()\n    info[\"system_name\"] = uname.system\n    info[\"node_name\"] = uname.node\n\n    # Get CPU model (Linux only)\n    cpu_model = None\n    if os.path.exists(\"/proc/cpuinfo\"):\n        try:\n            with open(\"/proc/cpuinfo\", \"r\") as f:\n                for line in f:\n                    if \"model name\" in line:\n                        cpu_model = line.split(\":\", 1)[1].strip()\n                        break\n        except Exception as e:\n            cpu_model = f\"Error: {e}\"\n    info[\"cpu_model\"] = cpu_model\n\n    # Get memory size in GB (Linux only)\n    mem_total_gb = None\n    if os.path.exists(\"/proc/meminfo\"):\n        try:\n            with open(\"/proc/meminfo\", \"r\") as f:\n                for line in f:\n                    if \"MemTotal\" in line:\n                        mem_kb = float(line.split(\":\", 1)[1].split()[0])\n                        mem_total_gb = round(mem_kb / (1024 * 1024), 2)\n                        break\n        except Exception as e:\n            mem_total_gb = f\"Error: {e}\"\n    info[\"memory_size_GB\"] = mem_total_gb\n\n    # Get CPU core count\n    info[\"cpu_core_count\"] = os.cpu_count()\n\n    # Get socket count\n    sockets = set()\n    if os.path.exists(\"/proc/cpuinfo\"):\n        try:\n            with open(\"/proc/cpuinfo\", \"r\") as f:\n                for line in f:\n                    if \"physical id\" in line:\n                        sockets.add(line.split(\":\", 1)[1].strip())\n        except Exception as e:\n            sockets = set()\n    info[\"cpu_socket_count\"] = 
len(sockets) if len(sockets) > 0 else 1\n\n    return info\n\n\ndef record_results(result, filename):\n    \"\"\"Append results to JSONL file.\"\"\"\n    with open(filename, \"a\") as f:\n        f.write(json.dumps(result) + \"\\n\")\n\n\n@pytest.mark.cpu\ndef test_moe_amx_int4_1_benchmark():\n    \"\"\"Benchmark AMX INT4 MOE performance.\"\"\"\n    if not HAS_DEPS:\n        pytest.skip(f\"Dependencies not available: {import_error}\")\n\n    quant_mode = \"int4\"\n    bytes_per_elem = 0.5\n\n    # Setup output file\n    script_dir = os.path.dirname(os.path.abspath(__file__))\n    json_path = os.path.join(script_dir, \"bench_moe_amx_int4_1.jsonl\")\n\n    with torch.inference_mode():\n        # Initialize CPUInfer with worker config\n        worker_config = kt_kernel_ext.WorkerPoolConfig()\n        worker_config.subpool_count = worker_config_dict[\"subpool_count\"]\n        worker_config.subpool_numa_map = worker_config_dict[\"subpool_numa_map\"]\n        worker_config.subpool_thread_count = worker_config_dict[\"subpool_thread_count\"]\n        CPUInfer = kt_kernel_ext.CPUInfer(worker_config)\n\n        # Initialize MOE layers\n        moes = []\n        for layer_index in range(layer_num):\n            gate_proj = (\n                torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device=\"cuda\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            up_proj = (\n                torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device=\"cuda\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            down_proj = (\n                torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32, device=\"cuda\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)\n            
config.max_len = max_len\n            config.gate_proj = gate_proj.data_ptr()\n            config.up_proj = up_proj.data_ptr()\n            config.down_proj = down_proj.data_ptr()\n            config.pool = CPUInfer.backend_\n\n            moe = kt_kernel_ext.moe.AMXInt4_MOE(config)\n            CPUInfer.submit(moe.load_weights_task())\n            CPUInfer.sync()\n            moes.append(moe)\n\n        # Generate test data\n        gen_iter = 3000\n        expert_ids = (\n            torch.rand(gen_iter * qlen, expert_num, device=\"cpu\")\n            .argsort(dim=-1)[:, :num_experts_per_tok]\n            .reshape(gen_iter, qlen * num_experts_per_tok)\n            .to(\"cpu\")\n            .contiguous()\n        )\n        weights = (\n            torch.rand((gen_iter, qlen, num_experts_per_tok), dtype=torch.float32, device=\"cpu\").to(\"cpu\").contiguous()\n        )\n        input_tensor = (\n            torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device=\"cuda\").to(\"cpu\").contiguous()\n        )\n        output_tensor = (\n            torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device=\"cuda\").to(\"cpu\").contiguous()\n        )\n        bsz_tensor = torch.tensor([qlen], device=\"cpu\")\n\n        # Warm-up iterations\n        print(f\"Running warm-up for {warm_up_iter} iterations...\")\n        for i in tqdm(range(warm_up_iter), desc=\"Warm-up\"):\n            CPUInfer.submit(\n                moes[i % layer_num].forward_task(\n                    bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids[i % gen_iter].data_ptr(),\n                    weights[i % gen_iter].data_ptr(),\n                    input_tensor[i % layer_num].data_ptr(),\n                    output_tensor[i % layer_num].data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n\n        # Test iterations\n        print(f\"Running test for {test_iter} 
iterations...\")\n        start = time.perf_counter()\n        for i in tqdm(range(test_iter), desc=\"Testing\"):\n            CPUInfer.submit(\n                moes[i % layer_num].forward_task(\n                    bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids[i % gen_iter].data_ptr(),\n                    weights[i % gen_iter].data_ptr(),\n                    input_tensor[i % layer_num].data_ptr(),\n                    output_tensor[i % layer_num].data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n        end = time.perf_counter()\n        total_time = end - start\n\n        # Calculate performance metrics\n        time_per_iter_us = total_time / test_iter * 1e6\n        bandwidth = (\n            hidden_size\n            * intermediate_size\n            * 3\n            * num_experts_per_tok\n            * (1 / 8 * 256 * (1 - (31 / 32) ** qlen))\n            * bytes_per_elem\n            * test_iter\n            / total_time\n            / 1e9\n        )  # GB/s\n        flops = (\n            hidden_size * intermediate_size * qlen * 3 * num_experts_per_tok * 2 * test_iter / total_time / 1e12\n        )  # TFLOPS\n\n        print(\"Quant mode: \", quant_mode)\n        print(\"Time(s): \", total_time)\n        print(\"Iteration: \", test_iter)\n        print(\"Time(us) per iteration: \", time_per_iter_us)\n        print(\"Bandwidth: \", bandwidth, \"GB/s\")\n        print(\"Flops: \", flops, \"TFLOPS\")\n\n        # Record results\n        result = {\n            \"quant_mode\": quant_mode,\n            \"total_time_seconds\": total_time,\n            \"iterations\": test_iter,\n            \"time_per_iteration_us\": time_per_iter_us,\n            \"bandwidth_GBs\": bandwidth,\n            \"flops_TFLOPS\": flops,\n            \"timestamp\": time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime()),\n            \"test_parameters\": {\n                
\"expert_num\": expert_num,\n                \"hidden_size\": hidden_size,\n                \"intermediate_size\": intermediate_size,\n                \"max_len\": max_len,\n                \"num_experts_per_tok\": num_experts_per_tok,\n                \"layer_num\": layer_num,\n                \"qlen\": qlen,\n                \"warm_up_iter\": warm_up_iter,\n                \"test_iter\": test_iter,\n                \"CPUInfer_parameter\": CPUINFER_PARAM,\n            },\n        }\n        result.update(get_git_commit())\n        result.update(get_system_info())\n        record_results(result, json_path)\n\n        print(f\"Results saved to {json_path}\")\n\n\ndef run_all_tests():\n    \"\"\"Run all tests in this file (for standalone execution).\"\"\"\n    if not HAS_DEPS:\n        print(f\"Dependencies not available: {import_error}\")\n        print(\"Skipping AMX MOE INT4 benchmark tests\")\n        return\n\n    try:\n        print(\"Running AMX MOE INT4 benchmark test...\")\n        test_moe_amx_int4_1_benchmark()\n        print(\"AMX MOE INT4 benchmark test passed\")\n        print(\"\\nAll tests passed!\")\n    except Exception as e:\n        print(f\"\\nTest failed: {e}\")\n        import traceback\n\n        traceback.print_exc()\n        sys.exit(1)\n\n\nif __name__ == \"__main__\":\n    run_all_tests()\n"
  },
  {
    "path": "kt-kernel/test/per_commit/test_moe_amx_bench_int4_1k.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"AMX MOE INT4 1K Group benchmark tests for KT-Kernel.\n\nBenchmarks performance (bandwidth and FLOPS) of AMX-accelerated INT4 MOE operations\nwith 1K group quantization (AMXInt4_1KGroup_MOE).\n\"\"\"\n\nimport os\nimport sys\nimport time\nimport json\nimport subprocess\nimport platform\nimport pytest\n\n# Add parent directory to path for CI registration\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\"))\nfrom ci.ci_register import register_cpu_ci\n\n# Register this test for CPU CI with estimated runtime of 300 seconds\nregister_cpu_ci(est_time=300, suite=\"default\")\n\n# Check if dependencies are available\ntry:\n    import torch\n    import kt_kernel  # Import kt_kernel first to register kt_kernel_ext\n\n    kt_kernel_ext = kt_kernel.kt_kernel_ext  # Access the extension module\n    from tqdm import tqdm\n\n    HAS_DEPS = True\nexcept ImportError as e:\n    HAS_DEPS = False\n    import_error = str(e)\n\n# Test parameters (from bench_moe_amx_k.py)\nexpert_num = 16\nhidden_size = 7168\nintermediate_size = 2048\nmax_len = 25600\nnum_experts_per_tok = 8\nlayer_num = 2\nqlen = 1024\nwarm_up_iter = 1000\ntest_iter = 2000\nk_group_size = 128\n\n# Worker configuration\nworker_config_dict = {\n    \"subpool_count\": 2,\n    \"subpool_numa_map\": [0, 1],\n    \"subpool_thread_count\": [30, 30],\n}\nCPUINFER_PARAM = 60\n\n\ndef get_git_commit():\n    \"\"\"Get current git commit information.\"\"\"\n    result = {}\n    try:\n        commit = subprocess.check_output([\"git\", \"rev-parse\", \"HEAD\"]).decode(\"utf-8\").strip()\n        commit_msg = subprocess.check_output([\"git\", \"log\", \"-1\", \"--pretty=%B\"]).decode(\"utf-8\").strip()\n        result[\"commit\"] = commit\n        result[\"commit_message\"] = commit_msg\n\n        dirty_output = subprocess.check_output([\"git\", \"status\", \"--porcelain\"]).decode(\"utf-8\").strip()\n        if dirty_output:\n            result[\"dirty\"] = 
True\n            result[\"dirty_files\"] = dirty_output.splitlines()\n        else:\n            result[\"dirty\"] = False\n    except Exception as e:\n        result[\"commit\"] = None\n        result[\"commit_message\"] = None\n        result[\"dirty\"] = None\n        result[\"error\"] = str(e)\n    return result\n\n\ndef get_system_info():\n    \"\"\"Get system information including CPU model, memory, cores, and sockets.\"\"\"\n    info = {}\n    uname = platform.uname()\n    info[\"system_name\"] = uname.system\n    info[\"node_name\"] = uname.node\n\n    # Get CPU model (Linux only)\n    cpu_model = None\n    if os.path.exists(\"/proc/cpuinfo\"):\n        try:\n            with open(\"/proc/cpuinfo\", \"r\") as f:\n                for line in f:\n                    if \"model name\" in line:\n                        cpu_model = line.split(\":\", 1)[1].strip()\n                        break\n        except Exception as e:\n            cpu_model = f\"Error: {e}\"\n    info[\"cpu_model\"] = cpu_model\n\n    # Get memory size in GB (Linux only)\n    mem_total_gb = None\n    if os.path.exists(\"/proc/meminfo\"):\n        try:\n            with open(\"/proc/meminfo\", \"r\") as f:\n                for line in f:\n                    if \"MemTotal\" in line:\n                        mem_kb = float(line.split(\":\", 1)[1].split()[0])\n                        mem_total_gb = round(mem_kb / (1024 * 1024), 2)\n                        break\n        except Exception as e:\n            mem_total_gb = f\"Error: {e}\"\n    info[\"memory_size_GB\"] = mem_total_gb\n\n    # Get CPU core count\n    info[\"cpu_core_count\"] = os.cpu_count()\n\n    # Get socket count\n    sockets = set()\n    if os.path.exists(\"/proc/cpuinfo\"):\n        try:\n            with open(\"/proc/cpuinfo\", \"r\") as f:\n                for line in f:\n                    if \"physical id\" in line:\n                        sockets.add(line.split(\":\", 1)[1].strip())\n        except Exception as e:\n 
           sockets = set()\n    info[\"cpu_socket_count\"] = len(sockets) if len(sockets) > 0 else 1\n\n    return info\n\n\ndef record_results(result, filename):\n    \"\"\"Append results to JSONL file.\"\"\"\n    with open(filename, \"a\") as f:\n        f.write(json.dumps(result) + \"\\n\")\n\n\n@pytest.mark.cpu\ndef test_moe_amx_int4_1k_benchmark():\n    \"\"\"Benchmark AMX INT4 1K Group MOE performance.\"\"\"\n    if not HAS_DEPS:\n        pytest.skip(f\"Dependencies not available: {import_error}\")\n\n    quant_mode = \"int4_1k\"\n    bytes_per_elem = 0.5\n\n    # Setup output file\n    script_dir = os.path.dirname(os.path.abspath(__file__))\n    json_path = os.path.join(script_dir, \"bench_moe_amx_int4_1k.jsonl\")\n\n    with torch.inference_mode():\n        # Initialize CPUInfer with worker config\n        worker_config = kt_kernel_ext.WorkerPoolConfig()\n        worker_config.subpool_count = worker_config_dict[\"subpool_count\"]\n        worker_config.subpool_numa_map = worker_config_dict[\"subpool_numa_map\"]\n        worker_config.subpool_thread_count = worker_config_dict[\"subpool_thread_count\"]\n        CPUInfer = kt_kernel_ext.CPUInfer(worker_config)\n\n        # Physical to logical map for weight loading\n        physical_to_logical_map = torch.tensor(data=range(expert_num), device=\"cpu\", dtype=torch.int64).contiguous()\n\n        # Initialize MOE layers\n        moes = []\n        for layer_index in range(layer_num):\n            gate_proj = (\n                torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device=\"cuda\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            up_proj = (\n                torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device=\"cuda\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            down_proj = (\n                torch.randn((expert_num, hidden_size, intermediate_size), 
dtype=torch.float32, device=\"cuda\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)\n            config.max_len = max_len\n            config.gate_proj = gate_proj.data_ptr()\n            config.up_proj = up_proj.data_ptr()\n            config.down_proj = down_proj.data_ptr()\n            config.pool = CPUInfer.backend_\n\n            # Configure quantization for INT4 1K Group\n            config.quant_config.bits = 4\n            config.quant_config.group_size = k_group_size\n            config.quant_config.zero_point = True\n            config.gate_scale = 0\n\n            moe = kt_kernel_ext.moe.AMXInt4_1KGroup_MOE(config)\n            CPUInfer.submit(moe.load_weights_task(physical_to_logical_map.data_ptr()))\n            CPUInfer.sync()\n            moes.append(moe)\n\n        # Generate test data\n        gen_iter = 3000\n        expert_ids = (\n            torch.rand(gen_iter * qlen, expert_num, device=\"cpu\")\n            .argsort(dim=-1)[:, :num_experts_per_tok]\n            .reshape(gen_iter, qlen * num_experts_per_tok)\n            .to(\"cpu\")\n            .contiguous()\n        )\n        weights = (\n            torch.rand((gen_iter, qlen, num_experts_per_tok), dtype=torch.float32, device=\"cpu\").to(\"cpu\").contiguous()\n        )\n        input_tensor = (\n            torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device=\"cuda\").to(\"cpu\").contiguous()\n        )\n        output_tensor = (\n            torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device=\"cuda\").to(\"cpu\").contiguous()\n        )\n        bsz_tensor = torch.tensor([qlen], device=\"cpu\")\n\n        # Warm-up iterations\n        print(f\"Running warm-up for {warm_up_iter} iterations...\")\n        for i in tqdm(range(warm_up_iter), desc=\"Warm-up\"):\n            CPUInfer.submit(\n             
   moes[i % layer_num].forward_task(\n                    bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids[i % gen_iter].data_ptr(),\n                    weights[i % gen_iter].data_ptr(),\n                    input_tensor[i % layer_num].data_ptr(),\n                    output_tensor[i % layer_num].data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n\n        # Test iterations\n        print(f\"Running test for {test_iter} iterations...\")\n        start = time.perf_counter()\n        for i in tqdm(range(test_iter), desc=\"Testing\"):\n            CPUInfer.submit(\n                moes[i % layer_num].forward_task(\n                    bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids[i % gen_iter].data_ptr(),\n                    weights[i % gen_iter].data_ptr(),\n                    input_tensor[i % layer_num].data_ptr(),\n                    output_tensor[i % layer_num].data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n        end = time.perf_counter()\n        total_time = end - start\n\n        # Calculate performance metrics\n        time_per_iter_us = total_time / test_iter * 1e6\n        bandwidth = (\n            hidden_size\n            * intermediate_size\n            * 3\n            * num_experts_per_tok\n            * (1 / 8 * 256 * (1 - (31 / 32) ** qlen))\n            * bytes_per_elem\n            * test_iter\n            / total_time\n            / 1e9\n        )  # GB/s\n        flops = (\n            hidden_size * intermediate_size * qlen * 3 * num_experts_per_tok * 2 * test_iter / total_time / 1e12\n        )  # TFLOPS\n\n        print(\"Quant mode: \", quant_mode)\n        print(\"Time(s): \", total_time)\n        print(\"Iteration: \", test_iter)\n        print(\"Time(us) per iteration: \", time_per_iter_us)\n        print(\"Bandwidth: \", 
bandwidth, \"GB/s\")\n        print(\"Flops: \", flops, \"TFLOPS\")\n\n        # Record results\n        result = {\n            \"quant_mode\": quant_mode,\n            \"total_time_seconds\": total_time,\n            \"iterations\": test_iter,\n            \"time_per_iteration_us\": time_per_iter_us,\n            \"bandwidth_GBs\": bandwidth,\n            \"flops_TFLOPS\": flops,\n            \"timestamp\": time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime()),\n            \"test_parameters\": {\n                \"expert_num\": expert_num,\n                \"hidden_size\": hidden_size,\n                \"intermediate_size\": intermediate_size,\n                \"max_len\": max_len,\n                \"num_experts_per_tok\": num_experts_per_tok,\n                \"layer_num\": layer_num,\n                \"qlen\": qlen,\n                \"warm_up_iter\": warm_up_iter,\n                \"test_iter\": test_iter,\n                \"CPUInfer_parameter\": CPUINFER_PARAM,\n                \"k_group_size\": k_group_size,\n            },\n        }\n        result.update(get_git_commit())\n        result.update(get_system_info())\n        record_results(result, json_path)\n\n        print(f\"Results saved to {json_path}\")\n\n\ndef run_all_tests():\n    \"\"\"Run all tests in this file (for standalone execution).\"\"\"\n    if not HAS_DEPS:\n        print(f\"Dependencies not available: {import_error}\")\n        print(\"Skipping AMX MOE INT4 1K Group benchmark tests\")\n        return\n\n    try:\n        print(\"Running AMX MOE INT4 1K Group benchmark test...\")\n        test_moe_amx_int4_1k_benchmark()\n        print(\"AMX MOE INT4 1K Group benchmark test passed\")\n        print(\"\\nAll tests passed!\")\n    except Exception as e:\n        print(f\"\\nTest failed: {e}\")\n        import traceback\n\n        traceback.print_exc()\n        sys.exit(1)\n\n\nif __name__ == \"__main__\":\n    run_all_tests()\n"
  },
  {
    "path": "kt-kernel/test/per_commit/test_moe_amx_bench_int8.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"AMX MOE INT8 benchmark tests for KT-Kernel.\n\nBenchmarks performance (bandwidth and FLOPS) of AMX-accelerated INT8 MOE operations.\n\"\"\"\n\nimport os\nimport sys\nimport time\nimport json\nimport subprocess\nimport platform\nimport pytest\n\n# Add parent directory to path for CI registration\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\"))\nfrom ci.ci_register import register_cpu_ci\n\n# Register this test for CPU CI with estimated runtime of 300 seconds\nregister_cpu_ci(est_time=300, suite=\"default\")\n\n# Check if dependencies are available\ntry:\n    import torch\n    import kt_kernel  # Import kt_kernel first to register kt_kernel_ext\n\n    kt_kernel_ext = kt_kernel.kt_kernel_ext  # Access the extension module\n    from tqdm import tqdm\n\n    HAS_DEPS = True\nexcept ImportError as e:\n    HAS_DEPS = False\n    import_error = str(e)\n\n# Test parameters (from original bench_moe_amx.py)\nexpert_num = 128\nhidden_size = 7168\nintermediate_size = 2048\nmax_len = 25600\nnum_experts_per_tok = 8\nlayer_num = 2\nqlen = 1\nwarm_up_iter = 1000\ntest_iter = 2000\n\n# Worker configuration\nworker_config_dict = {\n    \"subpool_count\": 2,\n    \"subpool_numa_map\": [0, 1],\n    \"subpool_thread_count\": [30, 30],\n}\nCPUINFER_PARAM = 60\n\n\ndef get_git_commit():\n    \"\"\"Get current git commit information.\"\"\"\n    result = {}\n    try:\n        commit = subprocess.check_output([\"git\", \"rev-parse\", \"HEAD\"]).decode(\"utf-8\").strip()\n        commit_msg = subprocess.check_output([\"git\", \"log\", \"-1\", \"--pretty=%B\"]).decode(\"utf-8\").strip()\n        result[\"commit\"] = commit\n        result[\"commit_message\"] = commit_msg\n\n        dirty_output = subprocess.check_output([\"git\", \"status\", \"--porcelain\"]).decode(\"utf-8\").strip()\n        if dirty_output:\n            result[\"dirty\"] = True\n            result[\"dirty_files\"] = dirty_output.splitlines()\n    
    else:\n            result[\"dirty\"] = False\n    except Exception as e:\n        result[\"commit\"] = None\n        result[\"commit_message\"] = None\n        result[\"dirty\"] = None\n        result[\"error\"] = str(e)\n    return result\n\n\ndef get_system_info():\n    \"\"\"Get system information including CPU model, memory, cores, and sockets.\"\"\"\n    info = {}\n    uname = platform.uname()\n    info[\"system_name\"] = uname.system\n    info[\"node_name\"] = uname.node\n\n    # Get CPU model (Linux only)\n    cpu_model = None\n    if os.path.exists(\"/proc/cpuinfo\"):\n        try:\n            with open(\"/proc/cpuinfo\", \"r\") as f:\n                for line in f:\n                    if \"model name\" in line:\n                        cpu_model = line.split(\":\", 1)[1].strip()\n                        break\n        except Exception as e:\n            cpu_model = f\"Error: {e}\"\n    info[\"cpu_model\"] = cpu_model\n\n    # Get memory size in GB (Linux only)\n    mem_total_gb = None\n    if os.path.exists(\"/proc/meminfo\"):\n        try:\n            with open(\"/proc/meminfo\", \"r\") as f:\n                for line in f:\n                    if \"MemTotal\" in line:\n                        mem_kb = float(line.split(\":\", 1)[1].split()[0])\n                        mem_total_gb = round(mem_kb / (1024 * 1024), 2)\n                        break\n        except Exception as e:\n            mem_total_gb = f\"Error: {e}\"\n    info[\"memory_size_GB\"] = mem_total_gb\n\n    # Get CPU core count\n    info[\"cpu_core_count\"] = os.cpu_count()\n\n    # Get socket count\n    sockets = set()\n    if os.path.exists(\"/proc/cpuinfo\"):\n        try:\n            with open(\"/proc/cpuinfo\", \"r\") as f:\n                for line in f:\n                    if \"physical id\" in line:\n                        sockets.add(line.split(\":\", 1)[1].strip())\n        except Exception as e:\n            sockets = set()\n    info[\"cpu_socket_count\"] = len(sockets) 
if len(sockets) > 0 else 1\n\n    return info\n\n\ndef record_results(result, filename):\n    \"\"\"Append results to JSONL file.\"\"\"\n    with open(filename, \"a\") as f:\n        f.write(json.dumps(result) + \"\\n\")\n\n\n@pytest.mark.cpu\ndef test_moe_amx_int8_benchmark():\n    \"\"\"Benchmark AMX INT8 MOE performance.\"\"\"\n    if not HAS_DEPS:\n        pytest.skip(f\"Dependencies not available: {import_error}\")\n\n    quant_mode = \"int8\"\n    bytes_per_elem = 1.0\n\n    # Setup output file\n    script_dir = os.path.dirname(os.path.abspath(__file__))\n    json_path = os.path.join(script_dir, \"bench_moe_amx_int8.jsonl\")\n\n    with torch.inference_mode():\n        # Initialize CPUInfer with worker config\n        worker_config = kt_kernel_ext.WorkerPoolConfig()\n        worker_config.subpool_count = worker_config_dict[\"subpool_count\"]\n        worker_config.subpool_numa_map = worker_config_dict[\"subpool_numa_map\"]\n        worker_config.subpool_thread_count = worker_config_dict[\"subpool_thread_count\"]\n        CPUInfer = kt_kernel_ext.CPUInfer(worker_config)\n\n        # Initialize MOE layers\n        moes = []\n        for layer_index in range(layer_num):\n            gate_proj = (\n                torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device=\"cuda\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            up_proj = (\n                torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device=\"cuda\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            down_proj = (\n                torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32, device=\"cuda\")\n                .to(\"cpu\")\n                .contiguous()\n            )\n            config = kt_kernel_ext.moe.MOEConfig(expert_num, num_experts_per_tok, hidden_size, intermediate_size, 0)\n            config.max_len = max_len\n     
       config.gate_proj = gate_proj.data_ptr()\n            config.up_proj = up_proj.data_ptr()\n            config.down_proj = down_proj.data_ptr()\n            config.pool = CPUInfer.backend_\n\n            moe = kt_kernel_ext.moe.AMXInt8_MOE(config)\n            CPUInfer.submit(moe.load_weights_task())\n            CPUInfer.sync()\n            moes.append(moe)\n\n        # Generate test data\n        gen_iter = 3000\n        expert_ids = (\n            torch.rand(gen_iter * qlen, expert_num, device=\"cpu\")\n            .argsort(dim=-1)[:, :num_experts_per_tok]\n            .reshape(gen_iter, qlen * num_experts_per_tok)\n            .to(\"cpu\")\n            .contiguous()\n        )\n        weights = (\n            torch.rand((gen_iter, qlen, num_experts_per_tok), dtype=torch.float32, device=\"cpu\").to(\"cpu\").contiguous()\n        )\n        input_tensor = (\n            torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device=\"cuda\").to(\"cpu\").contiguous()\n        )\n        output_tensor = (\n            torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device=\"cuda\").to(\"cpu\").contiguous()\n        )\n        bsz_tensor = torch.tensor([qlen], device=\"cpu\")\n\n        # Warm-up iterations\n        print(f\"Running warm-up for {warm_up_iter} iterations...\")\n        for i in tqdm(range(warm_up_iter), desc=\"Warm-up\"):\n            CPUInfer.submit(\n                moes[i % layer_num].forward_task(\n                    bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids[i % gen_iter].data_ptr(),\n                    weights[i % gen_iter].data_ptr(),\n                    input_tensor[i % layer_num].data_ptr(),\n                    output_tensor[i % layer_num].data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n\n        # Test iterations\n        print(f\"Running test for {test_iter} iterations...\")\n        start = 
time.perf_counter()\n        for i in tqdm(range(test_iter), desc=\"Testing\"):\n            CPUInfer.submit(\n                moes[i % layer_num].forward_task(\n                    bsz_tensor.data_ptr(),\n                    num_experts_per_tok,\n                    expert_ids[i % gen_iter].data_ptr(),\n                    weights[i % gen_iter].data_ptr(),\n                    input_tensor[i % layer_num].data_ptr(),\n                    output_tensor[i % layer_num].data_ptr(),\n                    False,\n                )\n            )\n            CPUInfer.sync()\n        end = time.perf_counter()\n        total_time = end - start\n\n        # Calculate performance metrics\n        time_per_iter_us = total_time / test_iter * 1e6\n        bandwidth = (\n            hidden_size\n            * intermediate_size\n            * 3\n            * num_experts_per_tok\n            * (1 / 8 * 256 * (1 - (31 / 32) ** qlen))\n            * bytes_per_elem\n            * test_iter\n            / total_time\n            / 1e9\n        )  # GB/s\n        flops = (\n            hidden_size * intermediate_size * qlen * 3 * num_experts_per_tok * 2 * test_iter / total_time / 1e12\n        )  # TFLOPS\n\n        print(\"Quant mode: \", quant_mode)\n        print(\"Time(s): \", total_time)\n        print(\"Iteration: \", test_iter)\n        print(\"Time(us) per iteration: \", time_per_iter_us)\n        print(\"Bandwidth: \", bandwidth, \"GB/s\")\n        print(\"Flops: \", flops, \"TFLOPS\")\n\n        # Record results\n        result = {\n            \"quant_mode\": quant_mode,\n            \"total_time_seconds\": total_time,\n            \"iterations\": test_iter,\n            \"time_per_iteration_us\": time_per_iter_us,\n            \"bandwidth_GBs\": bandwidth,\n            \"flops_TFLOPS\": flops,\n            \"timestamp\": time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime()),\n            \"test_parameters\": {\n                \"expert_num\": expert_num,\n                
\"hidden_size\": hidden_size,\n                \"intermediate_size\": intermediate_size,\n                \"max_len\": max_len,\n                \"num_experts_per_tok\": num_experts_per_tok,\n                \"layer_num\": layer_num,\n                \"qlen\": qlen,\n                \"warm_up_iter\": warm_up_iter,\n                \"test_iter\": test_iter,\n                \"CPUInfer_parameter\": CPUINFER_PARAM,\n            },\n        }\n        result.update(get_git_commit())\n        result.update(get_system_info())\n        record_results(result, json_path)\n\n        print(f\"Results saved to {json_path}\")\n\n\ndef run_all_tests():\n    \"\"\"Run all tests in this file (for standalone execution).\"\"\"\n    if not HAS_DEPS:\n        print(f\"⚠ Dependencies not available: {import_error}\")\n        print(\"Skipping AMX MOE INT8 benchmark tests\")\n        return\n\n    try:\n        print(\"Running AMX MOE INT8 benchmark test...\")\n        test_moe_amx_int8_benchmark()\n        print(\"✓ AMX MOE INT8 benchmark test passed\")\n        print(\"\\n✓ All tests passed!\")\n    except Exception as e:\n        print(f\"\\n✗ Test failed: {e}\")\n        import traceback\n\n        traceback.print_exc()\n        sys.exit(1)\n\n\nif __name__ == \"__main__\":\n    run_all_tests()\n"
  },
  {
    "path": "kt-kernel/test/run_suite.py",
    "content": "import argparse\nimport glob\nimport sys\nfrom typing import List\n\nfrom ci.ci_register import HWBackend, CIRegistry, collect_tests\nfrom ci.ci_utils import TestFile, run_unittest_files\n\nHW_MAPPING = {\n    \"cpu\": HWBackend.CPU,\n    \"cuda\": HWBackend.CUDA,\n    \"amd\": HWBackend.AMD,\n}\n\nLABEL_MAPPING = {\n    HWBackend.CPU: [\"default\"],\n    HWBackend.AMD: [\"stage-a-test-1\"],\n    HWBackend.CUDA: [\"stage-a-test-1\"],\n}\n\n\ndef _filter_tests(\n    ci_tests: List[CIRegistry], hw: HWBackend, suite: str\n) -> List[CIRegistry]:\n    ci_tests = [t for t in ci_tests if t.backend == hw]\n    ret = []\n    for t in ci_tests:\n        assert t.suite in LABEL_MAPPING[hw], f\"Unknown stage {t.suite} for backend {hw}\"\n        if t.suite == suite:\n            ret.append(t)\n    return ret\n\n\ndef run_per_commit(hw: HWBackend, suite: str):\n    files = glob.glob(\"per_commit/**/*.py\", recursive=True)\n    # Exclude __init__.py files as they don't contain test registrations\n    files = [f for f in files if not f.endswith(\"__init__.py\")]\n    ci_tests = _filter_tests(collect_tests(files), hw, suite)\n    test_files = [TestFile(t.filename, t.est_time) for t in ci_tests]\n\n    return run_unittest_files(\n        test_files,\n        timeout_per_file=1200,\n        continue_on_error=False,\n    )\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--hw\",\n        type=str,\n        choices=[\"cpu\", \"cuda\", \"amd\"],\n        required=True,\n        help=\"Hardware backend to run tests on.\",\n    )\n    parser.add_argument(\n        \"--suite\",\n        type=str,\n        required=True,\n        help=\"Test suite to run.\",\n    )\n    args = parser.parse_args()\n    hw = HW_MAPPING[args.hw]\n    exit_code = run_per_commit(hw, args.suite)\n    # run_unittest_files returns 0 for success, -1 for failure\n    # Convert to standard exit codes: 0 for success, 1 for failure\n    sys.exit(0 if 
exit_code == 0 else 1)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "kt-kernel/test/test_generate_gpu_experts_masks.py",
    "content": "\"\"\"Test for generate_gpu_experts_masks function.\"\"\"\n\nimport sys\nimport os\n\n# Add python directory to path\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\", \"python\"))\n\nimport torch\nimport time\nfrom experts_base import generate_gpu_experts_masks\n\n\ndef test_basic():\n    \"\"\"Test basic functionality.\"\"\"\n    print(\"=\" * 60)\n    print(\"Test 1: Basic functionality\")\n    print(\"=\" * 60)\n\n    activation_freq = torch.tensor([\n        [0.1, 0.5, 0.3, 0.8],  # layer 0\n        [0.2, 0.4, 0.9, 0.1],  # layer 1\n    ])\n\n    print(f\"Input activation_freq:\\n{activation_freq}\")\n    print(f\"num_gpu_experts: 3\")\n\n    masks = generate_gpu_experts_masks(activation_freq, num_gpu_experts=3)\n\n    print(f\"Output masks:\\n{masks}\")\n    print(f\"Output dtype: {masks.dtype}, device: {masks.device}\")\n\n    # Verify: top 3 should be (1,2)=0.9, (0,3)=0.8, (0,1)=0.5\n    expected_gpu_count = masks.sum().item()\n    print(f\"Total GPU experts: {expected_gpu_count}\")\n\n    # Check the top 3 positions\n    assert masks[1, 2] == True, \"layer1-expert2 (0.9) should be on GPU\"\n    assert masks[0, 3] == True, \"layer0-expert3 (0.8) should be on GPU\"\n    assert masks[0, 1] == True, \"layer0-expert1 (0.5) should be on GPU\"\n    assert expected_gpu_count == 3, f\"Expected 3 GPU experts, got {expected_gpu_count}\"\n\n    print(\"PASSED\\n\")\n\n\ndef test_edge_cases():\n    \"\"\"Test edge cases.\"\"\"\n    print(\"=\" * 60)\n    print(\"Test 2: Edge cases\")\n    print(\"=\" * 60)\n\n    activation_freq = torch.tensor([\n        [0.1, 0.5, 0.3, 0.8],\n        [0.2, 0.4, 0.9, 0.1],\n    ])\n\n    # Test num_gpu_experts = 0\n    masks = generate_gpu_experts_masks(activation_freq, num_gpu_experts=0)\n    assert masks.sum().item() == 0, \"num_gpu_experts=0 should have no GPU experts\"\n    print(f\"num_gpu_experts=0: {masks.sum().item()} GPU experts - PASSED\")\n\n    # Test num_gpu_experts = total experts\n    
masks = generate_gpu_experts_masks(activation_freq, num_gpu_experts=8)\n    assert masks.sum().item() == 8, \"num_gpu_experts=8 should have all experts on GPU\"\n    print(f\"num_gpu_experts=8 (all): {masks.sum().item()} GPU experts - PASSED\")\n\n    # Test num_gpu_experts > total experts (should clamp)\n    masks = generate_gpu_experts_masks(activation_freq, num_gpu_experts=100)\n    assert masks.sum().item() == 8, \"num_gpu_experts=100 should be clamped to 8\"\n    print(f\"num_gpu_experts=100 (clamped): {masks.sum().item()} GPU experts - PASSED\")\n\n    # Test negative num_gpu_experts (should clamp to 0)\n    masks = generate_gpu_experts_masks(activation_freq, num_gpu_experts=-5)\n    assert masks.sum().item() == 0, \"num_gpu_experts=-5 should be clamped to 0\"\n    print(f\"num_gpu_experts=-5 (clamped): {masks.sum().item()} GPU experts - PASSED\")\n\n    print(\"All edge cases PASSED\\n\")\n\n\ndef test_performance():\n    \"\"\"Test performance with realistic sizes.\"\"\"\n    print(\"=\" * 60)\n    print(\"Test 3: Performance\")\n    print(\"=\" * 60)\n\n    # DeepSeek-V3 like: 61 layers, 256 experts\n    num_layers = 61\n    num_experts = 256\n\n    # Generate random activation frequencies\n    activation_freq = torch.rand(num_layers, num_experts)\n\n    # Test with different num_gpu_experts\n    test_cases = [0, 100, 500, 1000, 2000, 5000, num_layers * num_experts]\n\n    print(f\"Shape: ({num_layers}, {num_experts}) = {num_layers * num_experts} total experts\\n\")\n\n    for num_gpu in test_cases:\n        # Warmup\n        _ = generate_gpu_experts_masks(activation_freq, num_gpu_experts=num_gpu)\n\n        # Measure time\n        num_runs = 100\n        start = time.perf_counter()\n        for _ in range(num_runs):\n            masks = generate_gpu_experts_masks(activation_freq, num_gpu_experts=num_gpu)\n        end = time.perf_counter()\n\n        avg_time_us = (end - start) / num_runs * 1e6\n        actual_gpu = masks.sum().item()\n\n        
print(f\"num_gpu_experts={num_gpu:5d} -> actual={actual_gpu:5d}, time={avg_time_us:8.2f} us\")\n\n    print(\"\\nPerformance test PASSED\\n\")\n\n\ndef test_output_properties():\n    \"\"\"Test output tensor properties.\"\"\"\n    print(\"=\" * 60)\n    print(\"Test 4: Output properties\")\n    print(\"=\" * 60)\n\n    activation_freq = torch.rand(10, 64)\n    masks = generate_gpu_experts_masks(activation_freq, num_gpu_experts=50)\n\n    print(f\"Shape: {masks.shape}\")\n    print(f\"Dtype: {masks.dtype}\")\n    print(f\"Device: {masks.device}\")\n    print(f\"Is contiguous: {masks.is_contiguous()}\")\n\n    assert masks.shape == (10, 64), f\"Expected shape (10, 64), got {masks.shape}\"\n    assert masks.dtype == torch.bool, f\"Expected dtype bool, got {masks.dtype}\"\n    assert str(masks.device) == \"cpu\", f\"Expected device cpu, got {masks.device}\"\n\n    print(\"All properties PASSED\\n\")\n\n\ndef test_determinism():\n    \"\"\"Test that results are deterministic.\"\"\"\n    print(\"=\" * 60)\n    print(\"Test 5: Determinism\")\n    print(\"=\" * 60)\n\n    activation_freq = torch.rand(20, 128)\n\n    masks1 = generate_gpu_experts_masks(activation_freq, num_gpu_experts=100)\n    masks2 = generate_gpu_experts_masks(activation_freq, num_gpu_experts=100)\n\n    assert torch.equal(masks1, masks2), \"Results should be deterministic\"\n    print(\"Determinism PASSED\\n\")\n\n\nif __name__ == \"__main__\":\n    test_basic()\n    test_edge_cases()\n    test_output_properties()\n    test_determinism()\n    test_performance()\n\n    print(\"=\" * 60)\n    print(\"All tests PASSED!\")\n    print(\"=\" * 60)\n"
  },
  {
    "path": "kt-sft/.flake8",
    "content": "[flake8]\nmax-line-length = 120\nextend-select = B950\nextend-ignore = E203,E501,E701, B001,B006,B007,B008,B009,B010,B011,B016,B028,B031,B950,E265,E266,E401,E402,E711,E712,E713,E721,E722,E731,F401,F403,F405,F541,F811,F821,F841,W391"
  },
  {
    "path": "kt-sft/.gitignore",
    "content": "__pycache__\nbuild\n.vscode\n*.so\n*.cache\nserver.db\nlogs\nnode_modules\n*.nsys-rep\n.vs/\n*pycache*\n*build/\n*/third_party/*\n.DS_Store\ncompile_commands.json\n*.egg-info*\n*dist/\nktransformers/server/local_store/\nktransformers/server_test1.db\n*.patch\nimg/\ntmp*.txt\ntmp*.py\ntest.txt\nbook\nktransformers/tests/chat_txt.txt\nmmlu_result*\nktransformers/ktransformers_ext/cuda_musa/\ntest_prompt.txt\ncsrc/demo\n\n.vscode/\n\n*__pycache__*\n*.py[cod]\n*$py.class\n.pytest_cache/\n\nGGUF-DeepSeek-V2-Lite-Chat\nDeepSeek-V2-Lite-Chat\nktransformers/sft/adapter\ntmp\ngraphviz/\ncompute_graph*\ngraphviz*\nthird_party/\ntest_adapter/demo_*\n*.whl\n*.svg\n*_graph\ntmp_package.txt\nlogs/\n\n*.vscode/\n\n__pycache__/\n*.py[cod]\n*$py.class\n.pytest_cache/\n\n# MakeFiles for kt_ext\nbuild/\nktransformers/ktransformers_ext/bin\nktransformers/ktransformers_ext/CMakeFiles\nktransformers/ktransformers_ext/cmake_install.cmake\nktransformers/ktransformers_ext/CMakeCache.txt\nktransformers/ktransformers_ext/compile_commands.json\nktransformers/ktransformers_ext/Makefile\n*.egg-info*\n*.so\n\n*.txt\n*.pt\n\ndebug/*\n\ntest_adapter/ESC_inst_all.json\n\n!CMakeLists.txt\n!requirements-sft.txt\n\n*-test*.yaml\n\nduipai_pure_tf\ndata/dataset_info.json\n\ntest_adapter/*.json\n\n.venv*"
  },
  {
    "path": "kt-sft/.gitmodules",
    "content": "[submodule \"third_party/llama.cpp\"]\n\tpath = third_party/llama.cpp\n\turl = https://github.com/ggerganov/llama.cpp.git\n[submodule \"third_party/pybind11\"]\n\tpath = third_party/pybind11\n\turl = https://github.com/pybind/pybind11.git\n[submodule \"third_party/spdlog\"]\n\tpath = third_party/spdlog\n\turl = https://github.com/gabime/spdlog.git\n[submodule \"third_party/custom_flashinfer\"]\n\tpath = third_party/custom_flashinfer\n\turl = https://github.com/kvcache-ai/custom_flashinfer.git\n\tbranch = fix-precision-mla-merge-main\n[submodule \"third_party/xxHash\"]\n\tpath = third_party/xxHash\n\turl = https://github.com/Cyan4973/xxHash.git\n[submodule \"third_party/prometheus-cpp\"]\n\tpath = third_party/prometheus-cpp\n\turl = https://github.com/jupp0r/prometheus-cpp.git\n"
  },
  {
    "path": "kt-sft/.pylintrc",
    "content": "[MASTER]\nextension-pkg-whitelist=pydantic\nmax-line-length=120\n\n[MESSAGES CONTROL]\ndisable=missing-function-docstring"
  },
  {
    "path": "kt-sft/Dockerfile",
    "content": "FROM pytorch/pytorch:2.5.1-cuda12.1-cudnn9-devel as compile_server\n\n\nARG CPU_INSTRUCT=NATIVE\n\n# Set the working directory and CUDA path\nWORKDIR /workspace\nENV CUDA_HOME=/usr/local/cuda\n\n\n\n# Install dependencies\nRUN apt update -y\nRUN apt install -y --no-install-recommends \\\n    libtbb-dev \\\n    libssl-dev \\\n    libcurl4-openssl-dev \\\n    libaio1 \\\n    libaio-dev \\\n    libfmt-dev \\\n    libgflags-dev \\\n    zlib1g-dev \\\n    patchelf \\\n    git \\\n    wget \\\n    vim \\\n    gcc \\\n    g++ \\\n    cmake\n# Clone the source code\nRUN git clone https://github.com/kvcache-ai/ktransformers.git \n# Clean the apt cache\nRUN rm -rf /var/lib/apt/lists/*\n\n# Enter the project directory\nWORKDIR /workspace/ktransformers\n# Initialize submodules\nRUN git submodule update --init --recursive\n\n# Upgrade pip\nRUN pip install --upgrade pip\n\n# Install build dependencies\nRUN pip install ninja pyproject numpy cpufeature aiohttp zmq openai\n\n# Install flash-attn (installing it early avoids some build-dependency errors later)\nRUN pip install flash-attn\n\n# Install ktransformers itself (includes compilation)\nRUN CPU_INSTRUCT=${CPU_INSTRUCT} \\\n    USE_BALANCE_SERVE=1 \\\n    KTRANSFORMERS_FORCE_BUILD=TRUE \\\n    TORCH_CUDA_ARCH_LIST=\"8.0;8.6;8.7;8.9;9.0+PTX\" \\\n    pip install . --no-build-isolation --verbose\n\nRUN pip install third_party/custom_flashinfer/\n# Purge the pip cache\nRUN pip cache purge\n\n# Copy the C++ runtime library\nRUN cp /usr/lib/x86_64-linux-gnu/libstdc++.so.6 /opt/conda/lib/\n\n# Keep the container running (for debugging)\nENTRYPOINT [\"tail\", \"-f\", \"/dev/null\"]"
  },
  {
    "path": "kt-sft/Dockerfile.xpu",
    "content": "# Base image\nFROM intel/oneapi-basekit:2025.0.1-0-devel-ubuntu22.04\n\nARG http_proxy\nARG https_proxy\n\nENV DEBIAN_FRONTEND=noninteractive\nENV CONDA_DIR=/opt/conda\n\n# Install dependencies\nRUN apt-get update && apt-get install -y \\\n    wget \\\n    curl \\\n    bash \\\n    git \\\n    vim \\\n    ca-certificates \\\n    binutils \\\n    cmake \\\n    g++ \\\n    && rm -rf /var/lib/apt/lists/*\n\n# Install Miniforge\nRUN wget https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh -O /tmp/miniforge.sh && \\\n    bash /tmp/miniforge.sh -b -p $CONDA_DIR && \\\n    rm /tmp/miniforge.sh && \\\n    $CONDA_DIR/bin/conda clean -afy\n\n# Add conda to PATH\nENV PATH=$CONDA_DIR/bin:$PATH\n\nRUN bash -c \"\\\n    source /opt/conda/etc/profile.d/conda.sh && \\\n    conda create --name ktransformers python=3.11 -y && \\\n    conda activate ktransformers && \\\n    conda env list && \\\n    conda install -c conda-forge libstdcxx-ng -y && \\\n    strings \\$(find /opt/conda/envs/ktransformers/lib -name 'libstdc++.so.6') | grep GLIBCXX | grep 3.4.32 \\\n\"\n\nRUN bash -c \"\\\n    source /opt/conda/etc/profile.d/conda.sh && \\\n    conda activate ktransformers && \\\n    pip install ipex-llm[xpu_2.6]==2.3.0b20250518 --extra-index-url https://download.pytorch.org/whl/xpu && \\\n    pip uninstall -y torch torchvision torchaudio && \\\n    pip install torch==2.7+xpu torchvision torchaudio --index-url https://download.pytorch.org/whl/test/xpu && \\\n    pip uninstall -y intel-opencl-rt dpcpp-cpp-rt && \\\n    pip list \\\n\"\n\n# Clone and set up ktransformers repo\nRUN bash -c \"\\\n    source $CONDA_DIR/etc/profile.d/conda.sh && \\\n    conda activate ktransformers && \\\n    git clone https://github.com/kvcache-ai/ktransformers.git && \\\n    cd ktransformers && \\\n    git submodule update --init && \\\n    sed -i 's/torch\\.xpu\\.is_available()/True/g' setup.py && \\\n    bash install.sh --dev xpu \\\n\"\n\n# Init conda 
and prepare bashrc\nRUN conda init bash && \\\n    echo \"source $CONDA_DIR/etc/profile.d/conda.sh\" >> ~/.bashrc && \\\n    echo \"conda activate ktransformers\" >> ~/.bashrc\n\nWORKDIR /ktransformers/\nCMD [\"bash\"]\n"
  },
  {
    "path": "kt-sft/LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "kt-sft/MANIFEST.in",
    "content": "graft third_party\ngraft ktransformers\ngraft local_chat.py\ngraft csrc\ninclude LICENSE README.md\nprune ktransformers/website\nprune ktransformers/logs\nprune ktransformers.egg-info\nprune third_party/llama.cpp/models\ngraft ktransformers/website/dist\nglobal-exclude __pycache__\ninclude KTransformersOps.*.so\ninclude cpuinfer_ext.*.so\n"
  },
  {
    "path": "kt-sft/Makefile",
    "content": "flake_find:\n\tcd ktransformers && flake8 | grep -Eo '[A-Z][0-9]{3}' | sort | uniq| paste -sd ',' - \nformat:\n\t@cd ktransformers && black .\n\t@black setup.py\ndev_install:\n# clear build dirs\n\trm -rf build\n\trm -rf *.egg-info\n\trm -rf ktransformers/ktransformers_ext/build\n\trm -rf ktransformers/ktransformers_ext/cuda/build\n\trm -rf ktransformers/ktransformers_ext/cuda/dist\n\trm -rf ktransformers/ktransformers_ext/cuda/*.egg-info\n\n# install ktransformers\n\techo \"Installing python dependencies from requirements-local_chat.txt\"\n\tpip install -r requirements-local_chat.txt\n\n\techo \"Installing ktransformers\"\n\tKTRANSFORMERS_FORCE_BUILD=TRUE pip install -e . -v --no-build-isolation\n\techo \"Installation completed successfully\"\nclean:\n\trm -rf build\n\trm -rf *.egg-info\n\trm -rf ktransformers/ktransformers_ext/build\n\trm -rf ktransformers/ktransformers_ext/cuda/build\n\trm -rf ktransformers/ktransformers_ext/cuda/dist\n\trm -rf ktransformers/ktransformers_ext/cuda/*.egg-info\t\ninstall_numa:\n\tUSE_NUMA=1 make dev_install\ninstall_no_numa:\n\tenv -u USE_NUMA make dev_install"
  },
  {
    "path": "kt-sft/README.md",
    "content": "- [KTransformers Fine-Tuning × LLaMA-Factory Integration – User Guide](#ktransformers-fine-tuning-x-llama-factory-integration-–-user-guide)\n- [Introduction](#introduction)\n\n- [Fine-Tuning Results (Examples)](#fine-tuning-results-examples)\n  - [Stylized Dialogue (CatGirl tone)](#stylized-dialogue-catgirl-tone)\n  - [Benchmarks](#benchmarks)\n    - [Translational-Style dataset](#translational-style-dataset)\n    - [AfriMed-QA (short answer)](#afrimed-qa-short-answer)\n    - [AfriMed-QA (multiple choice)](#afrimed-qa-multiple-choice)\n\n- [Quick to Start](#quick-to-start)\n  - [Environment Setup](#environment-setup)\n  - [Core Feature 1: Use KTransformers backend to fine-tune ultra-large MoE models](#core-feature-1-use-ktransformers-backend-to-fine-tune-ultra-large-moe-models)\n  - [Core Feature 2: Chat with the fine-tuned model (base + LoRA adapter)](#core-feature-2-chat-with-the-fine-tuned-model-base--lora-adapter)\n  - [Core Feature 3: Batch inference + metrics (base + LoRA adapter)](#core-feature-3-batch-inference--metrics-base--lora-adapter)\n\n- [KT Fine-Tuning Speed (User-Side View)](#kt-fine-tuning-speed-user-side-view)\n  - [End-to-End Performance](#end-to-end-performance)\n  - [GPU/CPU Memory Footprint](#gpucpu-memory-footprint)\n\n- [Conclusion](#conclusion)\n\n\n# KTransformers Fine-Tuning × LLaMA-Factory Integration – User Guide\n\n**MadSys Lab, KVCache-AI Team, Approaching AI, LLaMA-Factory Team**\n\n## Introduction\n\nFrom **DeepSeek-V3/R1** to **Qwen3-MoE** and **Kimi-K2**, each wave of open-sourced large models brings leaps in performance and scale. However, many researchers and developers are constrained by expensive GPUs and models with tens or even hundreds of billions of parameters, making it **hard to fine-tune very large models under limited resources**. To bridge this gap, we propose a practical approach: combining **KTransformers** with **LLaMA-Factory**. 
With just **2–4 RTX 4090s** and a high-memory CPU, you can fine-tune ultra-large MoE models like DeepSeek-671B.\n\nOur goal is to give resource-constrained researchers a **local path to explore fine-tuning ultra-large models**, and also a fast way to customize smaller models (e.g., 14B/30B) for specific scenarios. We validate the setup using **stylized dialogue**, **Westernized translation tone**, and **medical Q&A** as representative tasks, showing that **personalized adaptation can be achieved within hours**.\n\nAs shown below, LLaMA-Factory is the unified orchestration/configuration layer for the whole fine-tuning workflow—handling data, training scheduling, LoRA injection, and inference interfaces. **KTransformers** acts as a pluggable high-performance backend that takes over core operators like Attention/MoE under the same training configs, enabling efficient **GPU+CPU heterogeneous cooperation**.\n\n![image-20251011010558909](../doc/assets/image-20251011010558909.png)\n\nWithin LLaMA-Factory, we compared LoRA fine-tuning with **HuggingFace**, **Unsloth**, and **KTransformers** backends. 
KTransformers is the **only workable 4090-class solution** for ultra-large MoE models (e.g., 671B) and also delivers higher throughput and lower GPU memory usage on smaller MoE models (e.g., DeepSeek-14B).\n\n| Under LoRA (BF16) + [NekoQA-10K stylized dialogue](https://github.com/mindsRiverPonder/LLM-practice) | HuggingFace Backend                      | Unsloth Backend                      | KTransformers Backend |\n| ------------------------------------------------------------ | ---------------------------------------- | ------------------------------------ | --------------------- |\n| [14B-DeepSeekV2-Lite] LoRA fine-tuning throughput            | 303.58 token/s                           | 455.37 token/s                       | 530.38 token/s        |\n| [14B-DeepSeekV2-Lite] GPU memory                             | 32.12 GB                                 | 9.64 GB                              | 6.08 GB               |\n| [671B-DeepSeekV3] LoRA fine-tuning throughput                | <font color='red'>Too large to run</font> | <font color='red'>Not supported</font> | 40.35 token/s         |\n| [671B-DeepSeekV3] GPU memory (sum across GPUs)               | theoretical 1400 GB †                    | <font color='red'>Not supported</font> | 70 GB †               |\n\n† **1400 GB** is a **theoretical** FP16 full-parameter resident footprint (not runnable). **70 GB** is the **measured peak** with the KT strategy (Attention on GPU + layered MoE offload).\n\n![Comparison chart grouped by model](../doc/assets/image-compare_model.png)\n\n### Fine-Tuning Results (Examples)\n\n#### Stylized Dialogue (CatGirl tone)\n\nDataset: [NekoQA-10K](https://zhuanlan.zhihu.com/p/1934983798233231689). Goal: improve style consistency and recognizability.\n\nThe figure compares responses from the base vs. fine-tuned models. 
The fine-tuned model maintains the target tone and address terms more consistently (red boxes), validating the effectiveness of **style-transfer fine-tuning**.\n\n![image-20251016175046882](../doc/assets/image-20251016175046882.png)\n\n#### Benchmarks\n\nWe use:\n\n(1) [Translational-Style-ChatLLM](https://github.com/Benson114/Translational-Style-ChatLLM), which asks for an exaggerated, Westernized translation tone—clear, stylized customization.\n\n(2) [AfriMed-QA](https://aclanthology.org/2025.acl-long.96/) (ACL 2025), a medical dataset for African contexts with strong domain specificity, including multiple-choice and short-answer sub-tasks—well-suited for vertical fine-tuning evaluation.\n\nThe tables show metrics before vs. after LoRA fine-tuning. We observe **large improvements** across metrics, verifying fine-tuning effectiveness:\n\n| Translational-Style dataset    | BLEU-1    | BLEU-2    | BLEU-3    | BLEU-4    | ROUGE-1   | ROUGE-2   | ROUGE-L   |\n| ------------------------------ | --------- | --------- | --------- | --------- | --------- | --------- | --------- |\n| V2-Lite (no LoRA)              | 20.66     | 8.33      | 4.54      | 2.89      | 22.71     | 4.52      | 19.19     |\n| **KT-LoRA fine-tuned V2-Lite** | **35.41** | **22.44** | **15.42** | **11.18** | **42.03** | **18.38** | **33.10** |\n| V3 base (no LoRA)              | 8.49      | 3.34      | 1.62      | 0.96      | 15.91     | 2.55      | 10.07     |\n| **KT-LoRA fine-tuned V3**      | **37.02** | **23.70** | **16.21** | **11.49** | **43.43** | **18.96** | **34.54** |\n\n| AfriMed-QA (short answer)      | BLEU-1    | BLEU-2    | BLEU-3    | BLEU-4    | ROUGE-1   | ROUGE-2   | ROUGE-L   |\n| ------------------------------ | --------- | --------- | --------- | --------- | --------- | --------- | --------- |\n| V2-Lite (no LoRA)              | 13.58     | 11.12     | 9.10      | 7.23      | 22.48     | 7.81      | 11.73     |\n| **KT-LoRA fine-tuned V2-Lite** | **35.90** | **27.63** | 
**22.99** | **19.15** | **35.25** | **17.50** | **28.44** |\n| V3 base (no LoRA)              | 12.75     | 10.27     | 8.05      | 5.99      | 20.33     | 5.65      | 10.11     |\n| **KT-LoRA fine-tuned V3**      | **42.42** | **34.12** | **28.95** | **24.54** | **41.97** | **22.37** | **33.28** |\n\n| AfriMed-QA (multiple choice)   | Accuracy   |\n| ------------------------------ | ---------- |\n| V2-Lite (no LoRA)              | 0.0645     |\n| **KT-LoRA fine-tuned V2-Lite** | **0.4812** |\n| V3 base (no LoRA)              | 0.5833     |\n| **KT-LoRA fine-tuned V3**      | **0.7930** |\n\nEven for ultra-large MoE models, **KTransformers-backed fine-tuning** achieves strong task performance quickly.\n\n\n\n## Quick to Start\n\nThis section shows how to install and use **LLaMA-Factory + KTransformers** for fine-tuning and inference:\n\n- Environment setup\n- Fine-tune ultra-large MoE models with KTransformers backend\n- Load the fine-tuned model (base + LoRA adapter) for chat/inference\n- Batch inference and metric evaluation\n\n### Environment Setup\n\nFollow the steps below to install the **KTransformers** and **LLaMA-Factory** environments together.\nTo simplify installation, we provide a prebuilt KTransformers wheel, so no local compilation is needed.\n(Note: Make sure your local **Python version**, **Torch version**, **CUDA version**, and the **KTransformers wheel filename** match.)\n\n```shell\n# 1. Create and activate a conda environment\nconda create -n Kllama python=3.12 # choose from: [3.10, 3.11, 3.12, 3.13]\nconda activate Kllama\nconda install -y -c conda-forge libstdcxx-ng gcc_impl_linux-64\n# ATTENTION: Do NOT skip this step, even if your CUDA version is not 11.8. Otherwise you will get this error: ImportError: libcudart.so.11.0: cannot open shared object file: No such file or directory.\nconda install -y -c nvidia/label/cuda-11.8.0 cuda-runtime\n\n# 2. 
Install the LLaMA-Factory environment\ngit clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git\ncd LLaMA-Factory\npip install -e \".[torch,metrics]\" --no-build-isolation\n\n# 3. Install the KTransformers wheel that matches your Torch and Python versions, from https://github.com/kvcache-ai/ktransformers/releases/tag/v0.4.1 (Note: The CUDA version can differ from that in the wheel filename.)\npip install ktransformers-0.4.1+cu128torch27fancy-cp312-cp312-linux_x86_64.whl\n\n# 4. Install flash-attention: download the wheel that matches your Python and Torch versions from https://github.com/Dao-AILab/flash-attention/releases\npip install flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp312-cp312-linux_x86_64.whl\n# To check whether you need the abiTRUE or abiFALSE wheel, run in Python:\n# import torch\n# print(torch._C._GLIBCXX_USE_CXX11_ABI)\n\n# 5. (Optional) Install flash_infer (otherwise the backend defaults to triton)\ngit clone https://github.com/kvcache-ai/custom_flashinfer.git\npip install custom_flashinfer/\n```\n\n**Usage tip:** In the LLaMA-Factory YAML, set `use_kt: true` and pick a `kt_optimize_rule` file to have KTransformers handle the core compute. The features below show typical configs.\n\n### Core Feature 1: Use KTransformers backend to fine-tune ultra-large MoE models\n\nRun the command: `USE_KT=1 llamafactory-cli train examples/train_lora/deepseek3_lora_sft_kt.yaml`.\n\nNote: You **must** provide a **BF16** model. 
DeepSeek-V3-671B is released in FP8 by default; convert with [DeepSeek-V3/inference/fp8_cast_bf16.py](https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py).\n\n```yaml\n### model\nmodel_name_or_path: opensourcerelease/DeepSeek-V3-bf16\ntrust_remote_code: true\n\n### method\nstage: sft\ndo_train: true\nfinetuning_type: lora\nlora_rank: 8\nlora_target: all\n\n### dataset\ndataset: identity\ntemplate: deepseek\ncutoff_len: 2048\nmax_samples: 100000\noverwrite_cache: true\npreprocessing_num_workers: 16\ndataloader_num_workers: 4\n\n### output\noutput_dir: saves/Kllama_deepseekV3\nlogging_steps: 10\nsave_steps: 500\nplot_loss: true\noverwrite_output_dir: true\nsave_only_model: false\nreport_to: none  # choices: [none, wandb, tensorboard, swanlab, mlflow]\n\n### train\nper_device_train_batch_size: 1\ngradient_accumulation_steps: 8\nlearning_rate: 1.0e-4\nnum_train_epochs: 3.0\nlr_scheduler_type: cosine\nwarmup_ratio: 0.1\nbf16: true\nddp_timeout: 180000000\nresume_from_checkpoint: null\n\n### ktransformers\nuse_kt: true # use KTransformers as LoRA sft backend\nkt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml\ncpu_infer: 32\nchunk_size: 8192\n```\n\nWe also support RL DPO training using the KTransformers backend now. See [DPO Tutorial](../doc/en/SFT/DPO_tutorial.md) for details.  \n\n`kt_optimize_rule` controls **placement strategy**. See also [ktransformers/optimize_rules](https://github.com/kvcache-ai/ktransformers/tree/main/ktransformers/optimize/optimize_rules). 
Naming hints (`*` = wildcard):\n\n| Pattern                                      | Meaning                                               |\n| -------------------------------------------- | ----------------------------------------------------- |\n| DeepSeek-V2-Lite-Chat-* / DeepSeek-V3-Chat-* | Target model variants                                 |\n| *-sft-*                                      | Strategy for fine-tuning; others are for inference    |\n| *-amx-*                                      | Use AMX on CPU; otherwise use **llamafile**           |\n| *-multi-gpu-X*                               | Model parallel on X GPUs (X omitted → default 2 GPUs) |\n\nExample: `DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml` = V3-Chat fine-tuning with AMX and 2-GPU model parallel.\n\nWe recommend **AMX acceleration** where available (`lscpu | grep amx`). AMX supports BF16/INT8. Example:\n\n```yaml\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert parallelism\n    kwargs:\n      prefill_device: \"cpu\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KSFTExpertsCPU\"\n      out_device: \"cuda\"\n      backend: \"AMXInt8\" # or \"AMXBF16\" or \"llamafile\" (default)\n```\n\nOutputs go to `output_dir` in safetensors format plus adapter metadata for later loading.\n\n![image-20251016171537997](../doc/assets/image-20251016171537997.png)\n\n### Core Feature 2: Chat with the fine-tuned model (base + LoRA adapter)\n\nRun the command: `llamafactory-cli chat examples/inference/deepseek3_lora_sft_kt.yaml`.\n\nUse the safetensors adapter trained with KT for inference.\n\n```yaml\nmodel_name_or_path: opensourcerelease/DeepSeek-V3-bf16\nadapter_name_or_path: saves/Kllama_deepseekV3\ntemplate: deepseek\ninfer_backend: ktransformers  # choices: [huggingface, vllm, sglang, ktransformers]\ntrust_remote_code: 
true\n\nuse_kt: true # use KTransformers as the inference backend for the LoRA SFT adapter\nkt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml\ncpu_infer: 32\nchunk_size: 8192\n```\n\nWe also support **GGUF** adapters: for safetensors, set the **directory**; for GGUF, set the **file path** in `adapter_name_or_path`.\n\nDuring loading, LLaMA-Factory maps layer names to KT’s naming. You’ll see logs like `Loaded adapter weight: XXX -> XXX`:\n\n![image-20251016171526210](../doc/assets/image-20251016171526210.png)\n\n### Core Feature 3: Batch inference + metrics (base + LoRA adapter)\n\nRun the command: `API_PORT=8000 llamafactory-cli api examples/inference/deepseek3_lora_sft_kt.yaml`.\nThis serves the KT fine-tuned adapter through the API; all other API usage is the same as in native LLaMA-Factory.\n\n```yaml\nmodel_name_or_path: opensourcerelease/DeepSeek-V3-bf16\nadapter_name_or_path: saves/Kllama_deepseekV3\ntemplate: deepseek\ninfer_backend: ktransformers  # choices: [huggingface, vllm, sglang, ktransformers]\ntrust_remote_code: true\n\nuse_kt: true # use KTransformers as the inference backend for the LoRA SFT adapter\nkt_optimize_rule: examples/kt_optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml\ncpu_infer: 32\nchunk_size: 8192\n```\n\n\n\n## KT Fine-Tuning Speed (User-Side View)\n\n### End-to-End Performance\n\n**Definitions**\n\n- `step_time`: wall-clock time for a full optimization step (tensor movement + Attention + MoE + other compute).\n- `tokens_per_step = GAS × qlen` (GAS = gradient accumulation steps); `token/s = tokens_per_step / step_time`.\n\n**Settings:** `GAS=16`, `qlen=512` (→ `tokens_per_step = 8192`); LoRA (`r=8, alpha=32, dropout=0.1`); **AMX** enabled; GPU: RTX 4090, CPU: Intel Xeon Platinum 8488C.\n\n**Measured**\n\n- **DeepSeek-V3-671B:** `step_time = 203 s` → `token/s ≈ 8192 / 203 ≈ 40.35`\n- **DeepSeek-V2-Lite-14B:** `step_time = 36 s` → `token/s ≈ 8192 / 36 ≈ 227.6`\n\n### GPU/CPU Memory Footprint\n\n- DeepSeek-V3 (671B; 61 layers with 58 MoE): 
~**70 GB** total GPU VRAM (multi-GPU), ~**1.2–1.3 TB** RAM.\n- DeepSeek-V2-Lite (14B; 27 layers with 26 MoE): ~**5.5 GB** GPU VRAM, ~**30 GB** RAM.\n\n## Conclusion\n\nBy integrating **KTransformers LoRA fine-tuning** into **LLaMA-Factory**, we provide a practical guide for efficient training and deployment of MoE LLMs. KT brings cutting-edge optimizations (DeepSeek/Qwen/Kimi support with AMX-accelerated kernels), and LoRA enables customization under very low GPU memory. LLaMA-Factory offers a friendly, unified interface.\n\nThis integration (akin to Unsloth-style speedups) means even models with tens to hundreds of billions of parameters can be fine-tuned and deployed with low latency on commodity hardware. You get **memory savings, speed-ups, and usability** together. We encourage you to try LLaMA-Factory + KT for your next MoE project and follow this guide. Feedback is welcome!\n"
  },
  {
    "path": "kt-sft/SECURITY.md",
    "content": "# Security Policy\n\n## Supported Versions\n\nUse this section to tell people about which versions of your project are\ncurrently being supported with security updates.\n\n| Version | Supported          |\n| ------- | ------------------ |\n| 5.1.x   | :white_check_mark: |\n| 5.0.x   | :x:                |\n| 4.0.x   | :white_check_mark: |\n| < 4.0   | :x:                |\n\n## Reporting a Vulnerability\n\nUse this section to tell people how to report a vulnerability.\n\nTell them where to go, how often they can expect to get an update on a\nreported vulnerability, what to expect if the vulnerability is accepted or\ndeclined, etc.\n"
  },
  {
    "path": "kt-sft/autosetup.sh",
    "content": "#!/usr/bin/env bash\nset -euo pipefail\nshopt -s nullglob\n\n# 允许通过环境变量覆盖\nPY_LIST=${PY_LIST:-\"3.13\"}\nTORCH_LIST=${TORCH_LIST:-\"2.5.0 2.6.0 2.7.0 2.8.0 2.9.0\"}\nWHEELS_DIR=${WHEELS_DIR:-wheels}\nFORCE=${FORCE:-0}    # FORCE=1 时强制重建\nmkdir -p \"$WHEELS_DIR\"\n\n# 每个 Torch 版本选择一个存在的 CUDA 索引（可按需调整）\nindex_for_torch_version () {\n  case \"$1\" in\n    2.3.*) echo \"https://download.pytorch.org/whl/cu121\" ;;\n    2.4.*) echo \"https://download.pytorch.org/whl/cu121\" ;;\n    2.5.*) echo \"https://download.pytorch.org/whl/cu124\" ;;\n    2.6.*) echo \"https://download.pytorch.org/whl/cu126\" ;;\n    2.7.*) echo \"https://download.pytorch.org/whl/cu128\" ;;\n    2.8.*) echo \"https://download.pytorch.org/whl/cu128\" ;;  # 可换 cu129\n    2.9.*) echo \"https://download.pytorch.org/whl/cu128\" ;;  # 可换 cu129\n    *)     echo \"https://download.pytorch.org/whl/cu121\" ;;\n  esac\n}\n\n# 检查指定“当前已激活环境”的组合是否已有产物\n# 依据 wheel 命名规则中的后缀：+<backend>torch<MM> 以及 -<cp_tag>-<cp_tag>-linux_<arch>\nhave_wheel_for_current_env () {\n  python - <<'PY'\nimport sys, platform, torch\nfrom packaging.version import parse\npy = f\"cp{sys.version_info.major}{sys.version_info.minor}\"\narch = platform.uname().machine\ntver = parse(torch.__version__)\nmm = f\"{tver.major}{tver.minor}\"\nbackend = \"\"\nif torch.version.cuda:\n    backend = \"cu\" + torch.version.cuda.replace(\".\", \"\")\nelif getattr(torch.version, \"hip\", None):\n    backend = \"rocm\" + torch.version.hip.replace(\".\", \"\")\nelse:\n    backend = \"cpu\"  # 极少走到这里\nprint(py, arch, backend, mm)\nPY\n}\n\nfor py in $PY_LIST; do\n  PYBIN=\"$(command -v python${py} || true)\"\n  if [[ ! 
-x \"$PYBIN\" ]]; then\n    echo \">> Skip python ${py}: not found\"\n    continue\n  fi\n  for tv in $TORCH_LIST; do\n    echo \"======== Build: Python ${py} × Torch ${tv} ========\"\n\n    # 1) 新建并激活 venv\n    ENV_DIR=\".venv-py${py//./}-torch${tv%%.*}${tv#*.}\"\n    \"$PYBIN\" -m venv \"$ENV_DIR\"\n    source \"$ENV_DIR/bin/activate\"\n\n    # 2) 安装构建依赖 + 目标 torch（固定 CUDA 索引以避免装到 CPU 轮子）\n    python -m pip install -U pip\n    python -m pip install setuptools wheel build ninja cmake packaging cpufeature\n    IDX=\"$(index_for_torch_version \"$tv\")\"\n    python -m pip install --index-url \"$IDX\" \"torch==$tv\"\n\n    # 3) 读取当前环境的关键信息，拼出匹配的 wheel 通配符并检查是否已存在\n    read -r CP_TAG ARCH BACKEND MM <<<\"$(have_wheel_for_current_env)\"\n    plat=\"linux_${ARCH}\"\n    pattern=\"${WHEELS_DIR}/ktransformers-*+${BACKEND}torch${MM}*-${CP_TAG}-${CP_TAG}-${plat}.whl\"\n\n    if [[ \"$FORCE\" = \"0\" ]]; then\n      existing=( $pattern )\n      if (( ${#existing[@]} > 0 )); then\n        echo \">> Found existing wheel, skip: ${existing[0]}\"\n        deactivate\n        continue\n      fi\n    else\n      echo \">> FORCE=1, rebuild even if wheel exists\"\n    fi\n\n    # 打印对齐信息\n    python - <<'PY'\nimport torch, sys\nprint(\">>> torch:\", torch.__version__, \"cuda:\", torch.version.cuda,\n      \"cxx11abi:\", torch.compiled_with_cxx11_abi())\nprint(\">>> python:\", sys.version)\nPY\n\n    # ★ 清理所有构建产物（含内嵌 CMake build）\n    rm -rf build/ dist/ *.egg-info\n    find csrc -type d -name build -prune -exec rm -rf {} +\n\n    # 构建\n    KTRANSFORMERS_FORCE_BUILD=TRUE KTRANSFORMERS_DISABLE_PREBUILT=1 \\\n    python -m build --no-isolation --wheel\n\n    # ★ 验证 wheel 内包含 cpuinfer_ext\n    whl=\"$(ls dist/*.whl)\"\n    unzip -l \"$whl\" | grep -E 'cpuinfer_ext.*\\.so' >/dev/null || {\n      echo \"!! cpuinfer_ext missing in $whl\"; exit 2;\n    }\n\n    mv dist/*.whl wheels/ || true\n    deactivate\n  done\ndone\n\necho \"== Wheels saved in ./wheels ==\"\n"
  },
  {
    "path": "kt-sft/book.toml",
    "content": "[book]\nauthors = [\"kvcache-ai\"]\nlanguage = \"zh-CN\"\ntitle = \"Ktransformers\"\nsrc = \"doc\"\n\n[output.html]\ngit-repository-url = \"https://github.com/kvcache-ai/ktransformers\"\nedit-url-template = \"https://github.com/kvcache-ai/ktransformers/edit/main/{path}\"\n\n[output.html.playground]\neditable = true\ncopy-js = true\n# line-numbers = true\n\n[output.html.fold]\nenable = true\nlevel = 0"
  },
  {
    "path": "kt-sft/csrc/custom_marlin/__init__.py",
    "content": ""
  },
  {
    "path": "kt-sft/csrc/custom_marlin/binding.cpp",
    "content": "/**\n * @Description  :\n * @Author       : Azure-Tang\n * @Date         : 2024-07-25 13:38:30\n * @Version      : 1.0.0\n * @LastEditors  : kkk1nak0\n * @LastEditTime : 2024-08-12 03:05:04\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n\n#include \"gptq_marlin/ops.h\"\n// Python bindings\n#include <pybind11/pybind11.h>\n#include <pybind11/stl.h>\n#include <torch/extension.h>\n#include <torch/library.h>\n#include <torch/torch.h>\n// namespace py = pybind11;\n\nPYBIND11_MODULE(vLLMMarlin, m) {\n\n    /*m.def(\"dequantize_q8_0\", &dequantize_q8_0, \"Function to dequantize q8_0\n    data.\", py::arg(\"data\"), py::arg(\"blk_size\"), py::arg(\"device\"));\n    m.def(\"dequantize_q6_k\", &dequantize_q6_k, \"Function to dequantize q6_k\n    data.\", py::arg(\"data\"), py::arg(\"blk_size\"), py::arg(\"device\"));\n    m.def(\"dequantize_q5_k\", &dequantize_q5_k, \"Function to dequantize q5_k\n    data.\", py::arg(\"data\"), py::arg(\"blk_size\"), py::arg(\"device\"));\n    m.def(\"dequantize_q4_k\",  &dequantize_q4_k, \"Function to dequantize q4_k\n    data.\", py::arg(\"data\"), py::arg(\"blk_size\"), py::arg(\"device\"));\n    m.def(\"dequantize_q3_k\",  &dequantize_q3_k, \"Function to dequantize q3_k\n    data.\", py::arg(\"data\"), py::arg(\"blk_size\"), py::arg(\"device\"));\n    m.def(\"dequantize_q2_k\",  &dequantize_q2_k, \"Function to dequantize q2_k\n    data.\", py::arg(\"data\"), py::arg(\"blk_size\"), py::arg(\"device\"));\n    m.def(\"dequantize_iq4_xs\",  &dequantize_iq4_xs, \"Function to dequantize\n    iq4_xs data.\", py::arg(\"data\"), py::arg(\"blk_size\"), py::arg(\"device\"));*/\n    m.def(\"gptq_marlin_gemm\", &gptq_marlin_gemm,\n          \"Function to perform GEMM using Marlin quantization.\", py::arg(\"a\"),\n          py::arg(\"b_q_weight\"), py::arg(\"b_scales\"), py::arg(\"g_idx\"),\n          py::arg(\"perm\"), py::arg(\"workspace\"), py::arg(\"num_bits\"), py::arg(\"size_m_tensor\"),\n          
py::arg(\"size_m\"), py::arg(\"size_n\"), py::arg(\"size_k\"),\n          py::arg(\"sms\"), py::arg(\"is_k_full\"));\n    m.def(\"gptq_marlin_repack\", &gptq_marlin_repack,\n            \"gptq_marlin repack from GPTQ\");\n}"
  },
  {
    "path": "kt-sft/csrc/custom_marlin/gptq_marlin/gptq_marlin.cu",
    "content": "/*\n * Modified by Neural Magic\n * Copyright (C) Marlin.2024 Elias Frantar\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *         http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n /*\n  * Adapted from https://github.com/IST-DASLab/marlin\n  */\n  /*\n   * Adapted from\n   * https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin\n   */\n#include \"gptq_marlin.cuh\"\n#include \"gptq_marlin_dtypes.cuh\"\n#include <c10/cuda/CUDAGuard.h>\n#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t)                              \\\n    static_assert(std::is_same<scalar_t, half>::value ||                       \\\n                      std::is_same<scalar_t, nv_bfloat16>::value,              \\\n                  \"only float16 and bfloat16 is supported\");\n\ntemplate <typename T> inline std::string str(T x) { return std::to_string(x); }\n\nnamespace gptq_marlin {\n\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n\n    __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,\n        int const* __restrict__ perm_int_ptr,\n        int4* __restrict__ out_int4_ptr, int size_m,\n        int size_k, int block_rows) {}\n\n    template <typename scalar_t,         // compute dtype, half or nv_float16\n        const int num_bits,        // number of bits used for weights\n        const int threads,         // number of threads in a threadblock\n        const int thread_m_blocks, // number of 16x16 blocks in the m\n        // dimension (batchsize) of the\n      
  // threadblock\n        const int thread_n_blocks, // same for n dimension (output)\n        const int thread_k_blocks, // same for k dimension (reduction)\n        const int stages, // number of stages for the async global->shared\n        // fetch pipeline\n        const bool has_act_order,   // whether act_order is enabled\n        const int group_blocks = -1 // number of consecutive 16x16 blocks\n        // with a separate quantization scale\n    >\n    __global__ void\n        Marlin(const int4* __restrict__ A, // fp16 input matrix of shape mxk\n            const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn\n            int4* __restrict__ C,       // fp16 output buffer of shape mxn\n            const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape\n            // (k/groupsize)xn\n            const int* __restrict__ g_idx, // int32 group indices of shape k\n            int num_groups, // number of scale groups per output channel\n            int prob_m,     // batch dimension m\n            int prob_n,     // output dimension n\n            int prob_k,     // reduction dimension k\n            int* locks      // extra global storage for barrier synchronization\n        ) {}\n\n} // namespace gptq_marlin\n\ntorch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,\n    torch::Tensor& b_scales, torch::Tensor& g_idx,\n    torch::Tensor& perm, torch::Tensor& workspace,\n    int64_t num_bits, int64_t size_m, int64_t size_n,\n    int64_t size_k, bool is_k_full) {\n    TORCH_CHECK_NOT_IMPLEMENTED(false,\n        \"marlin_gemm(..) 
requires CUDA_ARCH >= 8.0\");\n    return torch::empty({ 1, 1 });\n}\n\n#else\n\n    // m16n8k16 tensor core mma instruction with fp16 inputs and fp32\n    // output/accumulation.\n    template <typename scalar_t>\n    __device__ inline void mma(const typename ScalarType<scalar_t>::FragA& a_frag,\n        const typename ScalarType<scalar_t>::FragB& frag_b,\n        typename ScalarType<scalar_t>::FragC& frag_c) {\n        const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);\n        const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);\n        float* c = reinterpret_cast<float*>(&frag_c);\n        if constexpr (std::is_same<scalar_t, half>::value) {\n            asm volatile(\n                \"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \"\n                \"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\\n\"\n                : \"=f\"(c[0]), \"=f\"(c[1]), \"=f\"(c[2]), \"=f\"(c[3])\n                : \"r\"(a[0]), \"r\"(a[1]), \"r\"(a[2]), \"r\"(a[3]), \"r\"(b[0]), \"r\"(b[1]),\n                \"f\"(c[0]), \"f\"(c[1]), \"f\"(c[2]), \"f\"(c[3]));\n        }\n        else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {\n            asm volatile(\n                \"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \"\n                \"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\\n\"\n                : \"=f\"(c[0]), \"=f\"(c[1]), \"=f\"(c[2]), \"=f\"(c[3])\n                : \"r\"(a[0]), \"r\"(a[1]), \"r\"(a[2]), \"r\"(a[3]), \"r\"(b[0]), \"r\"(b[1]),\n                \"f\"(c[0]), \"f\"(c[1]), \"f\"(c[2]), \"f\"(c[3]));\n        }\n        else {\n            STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);\n        }\n    }\n\n    // Instruction for loading a full 16x16 matrix fragment of operand A from shared\n    // memory, directly in tensor core layout.\n    template <typename scalar_t>\n    __device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA& frag_a,\n        const void* smem_ptr) {\n     
   uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);\n        uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));\n        asm volatile(\n            \"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\\n\"\n            : \"=r\"(a[0]), \"=r\"(a[1]), \"=r\"(a[2]), \"=r\"(a[3])\n            : \"r\"(smem));\n    }\n\n    // Lookup-table based 3-input logical operation; explicitly used for\n    // dequantization as the compiler does not seem to automatically recognize it in\n    // all cases.\n    template <int lut> __device__ inline int lop3(int a, int b, int c) {\n        int res;\n        asm volatile(\"lop3.b32 %0, %1, %2, %3, %4;\\n\"\n            : \"=r\"(res)\n            : \"r\"(a), \"r\"(b), \"r\"(c), \"n\"(lut));\n        return res;\n    }\n\n    // Constructs destination register by taking bytes from 2 sources (based on\n    // mask)\n    template <int start_byte, int mask>\n    __device__ inline uint32_t prmt(uint32_t a) {\n        uint32_t res;\n        asm volatile(\"prmt.b32 %0, %1, %2, %3;\\n\"\n            : \"=r\"(res)\n            : \"r\"(a), \"n\"(start_byte), \"n\"(mask));\n        return res;\n    }\n\n    // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16\n    // values. 
We mostly follow the strategy in the link below, with some small\n    // changes:\n    // - FP16:\n    // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287\n    // - BF16:\n    // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385\n    template <typename scalar_t>\n    __device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit(int q) {\n        STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);\n    }\n\n    template <>\n    __device__ inline typename ScalarType<half>::FragB dequant_4bit<half>(int q) {\n        const int LO = 0x000f000f;\n        const int HI = 0x00f000f0;\n        const int EX = 0x64006400;\n        // Guarantee that the `(a & b) | c` operations are LOP3s.\n        int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);\n        int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);\n        // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point\n        // directly into `SUB` and `ADD`.\n        const int SUB = 0x64086408;\n        const int MUL = 0x2c002c00;\n        const int ADD = 0xd480d480;\n        typename ScalarType<half>::FragB frag_b;\n        frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),\n            *reinterpret_cast<const half2*>(&SUB));\n        frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),\n            *reinterpret_cast<const half2*>(&MUL),\n            *reinterpret_cast<const half2*>(&ADD));\n        return frag_b;\n    }\n\n    template <>\n    __device__ inline typename ScalarType<nv_bfloat16>::FragB\n        dequant_4bit<nv_bfloat16>(int q) {\n        static constexpr uint32_t MASK = 0x000f000f;\n        static constexpr uint32_t EX = 0x43004300;\n\n        // Guarantee that the `(a & b) | c` operations are LOP3s.\n\n        int lo = lop3<(0xf0 & 0xcc) | 
0xaa>(q, MASK, EX);\n        q >>= 4;\n        int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);\n\n        typename ScalarType<nv_bfloat16>::FragB frag_b;\n        static constexpr uint32_t MUL = 0x3F803F80;\n        static constexpr uint32_t ADD = 0xC308C308;\n\n        frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),\n            *reinterpret_cast<const nv_bfloat162*>(&MUL),\n            *reinterpret_cast<const nv_bfloat162*>(&ADD));\n        frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),\n            *reinterpret_cast<const nv_bfloat162*>(&MUL),\n            *reinterpret_cast<const nv_bfloat162*>(&ADD));\n        return frag_b;\n    }\n\n    // Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or\n    // bf16 Reference:\n    // - FP16:\n    // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85\n    // - BF16:\n    // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175\n    template <typename scalar_t>\n    __device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit(int q) {\n        STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);\n    }\n\n    template <>\n    __device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) {\n        static constexpr uint32_t mask_for_elt_01 = 0x5250;\n        static constexpr uint32_t mask_for_elt_23 = 0x5351;\n        static constexpr uint32_t start_byte_for_fp16 = 0x64646464;\n\n        uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);\n        uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);\n\n        static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;\n\n        typename ScalarType<half>::FragB frag_b;\n        frag_b[0] =\n            __hsub2(*reinterpret_cast<half2*>(&lo),\n        
        *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));\n        frag_b[1] =\n            __hsub2(*reinterpret_cast<half2*>(&hi),\n                *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));\n        return frag_b;\n    }\n\n    template <>\n    __device__ inline typename ScalarType<nv_bfloat16>::FragB\n        dequant_8bit<nv_bfloat16>(int q) {\n        typename ScalarType<nv_bfloat16>::FragB frag_b;\n\n        float fp32_intermediates[4];\n        uint32_t* fp32_intermediates_casted =\n            reinterpret_cast<uint32_t*>(fp32_intermediates);\n\n        static constexpr uint32_t fp32_base = 0x4B000000;\n        fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);\n        fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);\n        fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);\n        fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);\n\n        fp32_intermediates[0] -= 8388736.f;\n        fp32_intermediates[1] -= 8388736.f;\n        fp32_intermediates[2] -= 8388736.f;\n        fp32_intermediates[3] -= 8388736.f;\n\n        uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);\n        bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],\n            fp32_intermediates_casted[1], 0x7632);\n        bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],\n            fp32_intermediates_casted[3], 0x7632);\n\n        return frag_b;\n    }\n\n    // Multiply dequantized values by the corresponding quantization scale; used\n    // only for grouped quantization.\n    template <typename scalar_t>\n    __device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,\n        typename ScalarType<scalar_t>::FragS& frag_s,\n        int i) {\n        using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;\n        scalar_t2 s = ScalarType<scalar_t>::num2num2(\n            reinterpret_cast<scalar_t*>(&frag_s)[i]);\n        frag_b[0] = 
__hmul2(frag_b[0], s);\n        frag_b[1] = __hmul2(frag_b[1], s);\n    }\n\n    // Same as above, but for act_order (each K is multiplied individually)\n    template <typename scalar_t>\n    __device__ inline void scale4(typename ScalarType<scalar_t>::FragB& frag_b,\n        typename ScalarType<scalar_t>::FragS& frag_s_1,\n        typename ScalarType<scalar_t>::FragS& frag_s_2,\n        typename ScalarType<scalar_t>::FragS& frag_s_3,\n        typename ScalarType<scalar_t>::FragS& frag_s_4,\n        int i) {\n        using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;\n        scalar_t2 s_val_1_2;\n        s_val_1_2.x = reinterpret_cast<scalar_t*>(&frag_s_1)[i];\n        s_val_1_2.y = reinterpret_cast<scalar_t*>(&frag_s_2)[i];\n\n        scalar_t2 s_val_3_4;\n        s_val_3_4.x = reinterpret_cast<scalar_t*>(&frag_s_3)[i];\n        s_val_3_4.y = reinterpret_cast<scalar_t*>(&frag_s_4)[i];\n\n        frag_b[0] = __hmul2(frag_b[0], s_val_1_2);\n        frag_b[1] = __hmul2(frag_b[1], s_val_3_4);\n    }\n\n    // Given 2 floats multiply by 2 scales (halves)\n    template <typename scalar_t>\n    __device__ inline void scale_float(float* c,\n        typename ScalarType<scalar_t>::FragS& s) {\n        scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s);\n        c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0]));\n        c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1]));\n    }\n\n    // Wait until barrier reaches `count`, then lock for current threadblock.\n    __device__ inline void barrier_acquire(int* lock, int count) {\n        if (threadIdx.x == 0) {\n            int state = -1;\n            do\n                // Guarantee that subsequent writes by this threadblock will be\n                // visible globally.\n                asm volatile(\"ld.global.acquire.gpu.b32 %0, [%1];\\n\"\n                    : \"=r\"(state)\n                    : \"l\"(lock));\n            while (state != count);\n        }\n        
__syncthreads();\n    }\n\n    // Release barrier and increment visitation count.\n    __device__ inline void barrier_release(int* lock, bool reset = false) {\n        __syncthreads();\n        if (threadIdx.x == 0) {\n            if (reset) {\n                lock[0] = 0;\n                return;\n            }\n            int val = 1;\n            // Make sure that all writes since acquiring this barrier are visible\n            // globally, while releasing the barrier.\n            asm volatile(\"fence.acq_rel.gpu;\\n\");\n            asm volatile(\"red.relaxed.gpu.global.add.s32 [%0], %1;\\n\"\n                :\n            : \"l\"(lock), \"r\"(val));\n        }\n    }\n\n    // For a given \"a\" of size [M,K] performs a permutation of the K columns based\n    // on the given \"perm\" indices.\n    __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,\n        int const* __restrict__ perm_int_ptr,\n        int4* __restrict__ out_int4_ptr, int size_m,\n        int size_k, int block_rows) {\n        int start_row = block_rows * blockIdx.x;\n        int finish_row = start_row + block_rows;\n        if (finish_row > size_m) {\n            finish_row = size_m;\n        }\n        int cur_block_rows = finish_row - start_row;\n\n        int row_stride = size_k * sizeof(half) / 16;\n\n        auto permute_row = [&](int row) {\n            int iters = size_k / default_threads;\n            int rest = size_k % default_threads;\n\n            int offset = row * row_stride;\n\n            half const* a_row_half =\n                reinterpret_cast<half const*>(a_int4_ptr + offset);\n            half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);\n\n            int base_k = 0;\n\n            for (int i = 0; i < iters; i++) {\n                int cur_k = base_k + threadIdx.x;\n                int src_pos = perm_int_ptr[cur_k];\n\n                out_half[cur_k] = a_row_half[src_pos];\n\n                base_k += default_threads;\n            
}\n\n            if (rest) {\n                if (threadIdx.x < rest) {\n                    int cur_k = base_k + threadIdx.x;\n                    int src_pos = perm_int_ptr[cur_k];\n\n                    out_half[cur_k] = a_row_half[src_pos];\n                }\n            }\n            };\n\n        for (int i = 0; i < cur_block_rows; i++) {\n            int cur_row = start_row + i;\n            if (cur_row < size_m) {\n                permute_row(cur_row);\n            }\n        }\n    }\n\n    template <typename scalar_t,         // compute dtype, half or nv_bfloat16\n        const int num_bits,        // number of bits used for weights\n        const int threads,         // number of threads in a threadblock\n        const int thread_m_blocks, // number of 16x16 blocks in the m\n        // dimension (batchsize) of the\n        // threadblock\n        const int thread_n_blocks, // same for n dimension (output)\n        const int thread_k_blocks, // same for k dimension (reduction)\n        const int stages, // number of stages for the async global->shared\n        // fetch pipeline\n        const bool has_act_order,   // whether act_order is enabled\n        const int group_blocks = -1 // number of consecutive 16x16 blocks\n        // with a separate quantization scale\n    >\n    __device__ void\n        Marlin(const int4* __restrict__ A, // fp16 input matrix of shape mxk\n            const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn\n            int4* __restrict__ C,       // fp16 output buffer of shape mxn\n            const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape\n            // (k/groupsize)xn\n            const int* __restrict__ g_idx, // int32 group indices of shape k\n            int num_groups, // number of scale groups per output channel\n            int prob_m,     // batch dimension m, should be divisible by (16 * thread_m_blocks) if bigger than that\n            int prob_n,     // output 
dimension n\n            int prob_k,     // reduction dimension k\n            int* locks      // extra global storage for barrier synchronization\n        ) {\n        // Each threadblock processes one \"stripe\" of the B matrix with (roughly) the\n        // same size, which might involve multiple column \"slices\" (of width 16 *\n        // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM\n        // example:\n        //   0 1 3\n        //   0 2 3\n        //   1 2 4\n        // While this kind of partitioning makes things somewhat more complicated, it\n        // ensures good utilization of all SMs for many kinds of shape and GPU\n        // configurations, while requiring as few slow global cross-threadblock\n        // reductions as possible.\n        using Dtype = ScalarType<scalar_t>;\n        using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;\n        using FragA = typename ScalarType<scalar_t>::FragA;\n        using FragB = typename ScalarType<scalar_t>::FragB;\n        using FragC = typename ScalarType<scalar_t>::FragC;\n        using FragS = typename ScalarType<scalar_t>::FragS;\n\n        constexpr int pack_factor = 32 / num_bits;\n\n        // int prob_m = *prob_m_ptr;\n        // const int thread_m_blocks = min(div_ceil(prob_m, 16), template_thread_m_blocks);\n        // constexpr int thread_m_blocks = template_thread_m_blocks;\n\n        // For larger GEMMs we run multiple batchsize 64 versions in parallel for a\n        // better partitioning with less reductions\n        int parallel = 1;\n        if (prob_m > 16 * thread_m_blocks) {\n            parallel = prob_m / (16 * thread_m_blocks);\n            prob_m = 16 * thread_m_blocks;\n        }\n\n        int k_tiles = prob_k / 16 / thread_k_blocks;\n        int n_tiles = prob_n / 16 / thread_n_blocks;\n        int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x);\n\n        if constexpr (!has_act_order && group_blocks != -1) {\n            if (group_blocks 
>= thread_k_blocks) {\n                // Ensure that the number of tiles in each stripe is a multiple of the\n                // groupsize; this avoids an annoying special case where a stripe starts\n                // in the middle of group.\n                iters = (group_blocks / thread_k_blocks) *\n                    div_ceil(iters, (group_blocks / thread_k_blocks));\n            }\n        }\n\n        int slice_row = (iters * blockIdx.x) % k_tiles;\n        int slice_col_par = (iters * blockIdx.x) / k_tiles;\n        int slice_col = slice_col_par;\n        int slice_iters;  // number of threadblock tiles in the current slice\n        int slice_count =\n            0;          // total number of active threadblocks in the current slice\n        int slice_idx;  // index of threadblock in current slice; numbered bottom to\n        // top\n\n    // We can easily implement parallel problem execution by just remapping\n    // indices and advancing global pointers\n        if (slice_col_par >= n_tiles) {\n            A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;\n            C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;\n            locks += (slice_col_par / n_tiles) * n_tiles;\n            slice_col = slice_col_par % n_tiles;\n        }\n\n        // Compute all information about the current slice which is required for\n        // synchronization.\n        auto init_slice = [&]() {\n            slice_iters =\n                iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);\n            if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;\n            if (slice_iters == 0) return;\n            if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;\n            slice_count = 1;\n            slice_idx = 0;\n            int col_first = iters * div_ceil(k_tiles * slice_col_par, iters);\n            if (col_first <= k_tiles * (slice_col_par + 1)) {\n                
int col_off = col_first - k_tiles * slice_col_par;\n                slice_count = div_ceil(k_tiles - col_off, iters);\n                if (col_off > 0) slice_count++;\n                int delta_first = iters * blockIdx.x - col_first;\n                if (delta_first < 0 || (col_off == 0 && delta_first == 0))\n                    slice_idx = slice_count - 1;\n                else {\n                    slice_idx = slice_count - 1 - delta_first / iters;\n                    if (col_off > 0) slice_idx--;\n                }\n            }\n            if (slice_col == n_tiles) {\n                A += 16 * thread_m_blocks * prob_k / 8;\n                C += 16 * thread_m_blocks * prob_n / 8;\n                locks += n_tiles;\n                slice_col = 0;\n            }\n            };\n        init_slice();\n\n        // A sizes/strides\n\n        // stride of the A matrix in global memory\n        int a_gl_stride = prob_k / 8;\n        // stride of an A matrix tile in shared memory\n        constexpr int a_sh_stride = 16 * thread_k_blocks / 8;\n        // delta between subsequent A tiles in global memory\n        constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8;\n        // between subsequent accesses within a tile\n        int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);\n        // between shared memory writes\n        constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);\n        // between shared memory tile reads\n        constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));\n        // within a shared memory tile\n        constexpr int a_sh_rd_delta_i = a_sh_stride * 16;\n        // overall size of a tile\n        constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);\n        // number of shared write iterations for a tile\n        constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta);\n\n        // B sizes/strides\n        int b_gl_stride = 16 * prob_n / (pack_factor * 
4);\n        constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;\n        constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2;\n        constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;\n\n        int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;\n        int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);\n        constexpr int b_sh_wr_delta = threads * b_thread_vecs;\n        constexpr int b_sh_rd_delta = threads * b_thread_vecs;\n        constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;\n        constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;\n\n        // Scale sizes/strides without act_order\n        int s_gl_stride = prob_n / 8;\n        constexpr int s_sh_stride = 16 * thread_n_blocks / 8;\n        constexpr int s_tb_groups =\n            !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks\n            ? thread_k_blocks / group_blocks\n            : 1;\n        constexpr int s_sh_stage = s_tb_groups * s_sh_stride;\n        int s_gl_rd_delta = s_gl_stride;\n\n        // Scale size/strides with act_order\n        constexpr int tb_k = 16 * thread_k_blocks;\n        constexpr int g_idx_stage = has_act_order ? 
(tb_k * sizeof(int)) / 16 : 0;\n        // constexpr int act_s_row_stride      = 1;\n        // int           act_s_col_stride      = act_s_row_stride * num_groups;\n        int act_s_col_stride = 1;\n        int act_s_col_warp_stride = act_s_col_stride * 8;\n        int tb_n_warps = thread_n_blocks / 4;\n        int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;\n\n        // Global A read index of current thread.\n        int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +\n            (threadIdx.x % a_gl_rd_delta_o);\n        a_gl_rd += a_gl_rd_delta_o * slice_row;\n        // Shared write index of current thread.\n        int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +\n            (threadIdx.x % a_gl_rd_delta_o);\n        // Shared read index.\n        int a_sh_rd =\n            a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;\n        a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));\n\n        int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +\n            (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;\n        b_gl_rd += b_sh_stride * slice_col;\n        b_gl_rd += b_gl_rd_delta_o * slice_row;\n        int b_sh_wr = threadIdx.x * b_thread_vecs;\n        int b_sh_rd = threadIdx.x * b_thread_vecs;\n\n        // For act_order\n        constexpr int k_iter_size = tb_k / b_sh_wr_iters;\n        int slice_k_start = tb_k * slice_row;\n        int slice_k_finish = slice_k_start + tb_k * slice_iters;\n        int slice_k_start_shared_fetch = slice_k_start;\n        int slice_n_offset = act_s_col_tb_stride * slice_col;\n\n        // No act_order\n        int s_gl_rd;\n        if constexpr (!has_act_order) {\n            if constexpr (group_blocks == -1) {\n                s_gl_rd = s_sh_stride * slice_col + threadIdx.x;\n            }\n            else {\n                s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +\n                    s_sh_stride * 
slice_col + threadIdx.x;\n            }\n        }\n        int s_sh_wr = threadIdx.x;\n        bool s_sh_wr_pred = threadIdx.x < s_sh_stride;\n\n        // We use a different scale layout for grouped and column-wise quantization as\n        // we scale a `half2` tile in column-major layout in the former and in\n        // row-major in the latter case.\n        int s_sh_rd;\n        if constexpr (group_blocks != -1)\n            s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +\n            (threadIdx.x % 32) / 4;\n        else\n            s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +\n            (threadIdx.x % 32) % 4;\n\n        // Precompute which thread should not read memory in which iterations; this is\n        // needed if there are more threads than required for a certain tilesize or\n        // when the batchsize is not a multiple of 16.\n        bool a_sh_wr_pred[a_sh_wr_iters];\n#pragma unroll\n        for (int i = 0; i < a_sh_wr_iters; i++) {\n            a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;\n        }\n\n        // To ensure that writing and reading A tiles to/from shared memory, the\n        // latter in fragment format, is fully bank conflict free, we need to use a\n        // rather fancy XOR-based layout. The key here is that neither reads nor\n        // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the\n        // same shared memory banks. 
Further, it seems (based on NSight-Compute) that\n        // each warp must also write a consecutive memory segment?\n        auto transform_a = [&](int i) {\n            int row = i / a_gl_rd_delta_o;\n            return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;\n            };\n        // Since the computation of this remapping is non-trivial and, due to our main\n        // loop unrolls, all shared memory accesses are static, we simply precompute\n        // both transformed reads and writes.\n        int a_sh_wr_trans[a_sh_wr_iters];\n#pragma unroll\n        for (int i = 0; i < a_sh_wr_iters; i++) {\n            a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);\n        }\n        int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];\n#pragma unroll\n        for (int i = 0; i < b_sh_wr_iters; i++) {\n#pragma unroll\n            for (int j = 0; j < thread_m_blocks; j++)\n            {\n                a_sh_rd_trans[i][j] =\n                    transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);\n            }\n        }\n\n        // Since B-accesses have non-constant stride they have to be computed at\n        // runtime; we break dependencies between subsequent accesses within a tile by\n        // maintaining multiple pointers (we have enough registers), a tiny\n        // optimization.\n        const int4* B_ptr[b_sh_wr_iters];\n#pragma unroll\n        for (int i = 0; i < b_sh_wr_iters; i++)\n            B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;\n\n        extern __shared__ int4 sh[];\n        // Shared memory storage for global fetch pipelines.\n        int4* sh_a = sh;\n        int4* sh_b = sh_a + (stages * a_sh_stage);\n        int4* sh_g_idx = sh_b + (stages * b_sh_stage);\n        int4* sh_s = sh_g_idx + (stages * g_idx_stage);\n\n        // Register storage for double buffer of shared memory reads.\n        FragA frag_a[2][thread_m_blocks];\n        I4 frag_b_quant[2][b_thread_vecs];\n        FragC 
frag_c[thread_m_blocks][4][2];\n        FragS frag_s[2][4];         // No act-order\n        FragS act_frag_s[2][4][4];  // For act-order\n\n        // Zero accumulators.\n        auto zero_accums = [&]() {\n#pragma unroll\n            for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)\n            {\n                reinterpret_cast<float*>(frag_c)[i] = 0;\n            }\n            };\n\n        int sh_first_group_id = -1;\n        int sh_num_groups = -1;\n        constexpr int sh_max_num_groups = 32;\n\n        auto fetch_scales_to_shared = [&](bool is_async, int first_group_id,\n            int last_group_id) {\n                sh_first_group_id = first_group_id;\n                sh_num_groups = last_group_id - first_group_id + 1;\n\n                if (sh_num_groups < sh_max_num_groups) {\n                    sh_num_groups = sh_max_num_groups;\n                }\n\n                if (sh_first_group_id + sh_num_groups > num_groups) {\n                    sh_num_groups = num_groups - sh_first_group_id;\n                }\n\n                int row_offset = first_group_id * s_gl_stride;\n\n                if (is_async) {\n                    for (int i = 0; i < sh_num_groups; i++) {\n                        if (threadIdx.x < s_sh_stride) {\n                            cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x],\n                                &scales_ptr[row_offset + (i * s_gl_stride) +\n                                slice_n_offset + threadIdx.x]);\n                        }\n                    }\n                }\n                else {\n                    for (int i = 0; i < sh_num_groups; i++) {\n                        if (threadIdx.x < s_sh_stride) {\n                            sh_s[(i * s_sh_stride) + threadIdx.x] =\n                                scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset +\n                                threadIdx.x];\n                        }\n                    }\n                }\n            
};\n        // Asynchronously fetch the next A, B and s tile from global to the next\n        // shared memory pipeline location.\n        auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {\n            if (pred) {\n                int4* sh_a_stage = sh_a + a_sh_stage * pipe;\n#pragma unroll\n                for (int i = 0; i < a_sh_wr_iters; i++) {\n                    cp_async4_pred(\n                        &sh_a_stage[a_sh_wr_trans[i]],\n                        &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],\n                        a_sh_wr_pred[i]);\n                }\n                int4* sh_b_stage = sh_b + b_sh_stage * pipe;\n#pragma unroll\n                for (int i = 0; i < b_sh_wr_iters; i++) {\n#pragma unroll\n                    for (int j = 0; j < b_thread_vecs; j++) {\n                        cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);\n                    }\n\n                    B_ptr[i] += b_gl_rd_delta_o;\n                }\n\n                if constexpr (has_act_order) {\n                    // Fetch g_idx thread-block portion\n                    int full_pipe = a_off;\n                    int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe;\n                    if (cur_k < prob_k && cur_k < slice_k_finish) {\n                        int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;\n\n                        int4 const* cur_g_idx_stage_ptr =\n                            reinterpret_cast<int4 const*>(&g_idx[cur_k]);\n\n                        if (threadIdx.x < g_idx_stage) {\n                            cp_async4_pred(&sh_g_idx_stage[threadIdx.x],\n                                &cur_g_idx_stage_ptr[threadIdx.x]);\n                        }\n                    }\n                }\n                else {\n                    if constexpr (group_blocks != -1) {\n                        int4* sh_s_stage = sh_s + s_sh_stage * pipe;\n\n                        if constexpr 
(group_blocks >= thread_k_blocks) {\n                            // Only fetch scales if this tile starts a new group\n                            if (pipe % (group_blocks / thread_k_blocks) == 0) {\n                                if (s_sh_wr_pred) {\n                                    cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);\n                                }\n                                s_gl_rd += s_gl_rd_delta;\n                            }\n                        }\n                        else {\n                            for (int i = 0; i < s_tb_groups; i++) {\n                                if (s_sh_wr_pred) {\n                                    cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr],\n                                        &scales_ptr[s_gl_rd]);\n                                }\n                                s_gl_rd += s_gl_rd_delta;\n                            }\n                        }\n                    }\n                }\n            }\n            // Insert a fence even when we are winding down the pipeline to ensure that\n            // waiting is also correct at this point.\n            cp_async_fence();\n            };\n\n        // Wait until the next thread tile has been loaded to shared memory.\n        auto wait_for_stage = [&]() {\n            // We only have `stages - 2` active fetches since we are double buffering\n            // and can only issue the next fetch when it is guaranteed that the previous\n            // shared memory load is fully complete (as it may otherwise be\n            // overwritten).\n            cp_async_wait<stages - 2>();\n            __syncthreads();\n            };\n\n        // Load the next sub-tile from the current location in the shared memory pipe\n        // into the current register buffer.\n        auto fetch_to_registers = [&](int k, int pipe) {\n            int4* sh_a_stage = sh_a + a_sh_stage * pipe;\n#pragma unroll\n            for (int i = 0; i < 
thread_m_blocks; i++)\n            {\n                ldsm4<scalar_t>(frag_a[k % 2][i],\n                    &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);\n            }\n\n            int4* sh_b_stage = sh_b + b_sh_stage * pipe;\n\n#pragma unroll\n            for (int i = 0; i < b_thread_vecs; i++) {\n                frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(\n                    &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);\n            }\n            };\n\n        bool is_same_group[stages];\n        int same_group_id[stages];\n\n        auto init_same_group = [&](int pipe) {\n            if constexpr (!has_act_order) {\n                is_same_group[pipe] = false;\n                same_group_id[pipe] = 0;\n                return;\n            }\n\n            int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;\n            int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);\n\n            int group_id_1 = sh_g_idx_int_ptr[0];\n            int group_id_2 = sh_g_idx_int_ptr[tb_k - 1];\n\n            is_same_group[pipe] = group_id_1 == group_id_2;\n            same_group_id[pipe] = group_id_1;\n            };\n\n        auto fetch_scales_to_registers = [&](int k, int full_pipe) {\n            int pipe = full_pipe % stages;\n\n            if constexpr (!has_act_order) {\n                // No act-order case\n                if constexpr (group_blocks != -1) {\n                    if constexpr (group_blocks >= thread_k_blocks) {\n                        int4* sh_s_stage =\n                            sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *\n                                (pipe / (group_blocks / thread_k_blocks)));\n                        reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];\n                    }\n                    else {\n                        int warp_id = threadIdx.x / 32;\n                        int n_warps = thread_n_blocks / 4;\n\n                        int 
warp_row = warp_id / n_warps;\n\n                        int cur_k = warp_row * 16;\n                        cur_k += k_iter_size * (k % b_sh_wr_iters);\n\n                        int k_blocks = cur_k / 16;\n                        int cur_group_id = k_blocks / group_blocks;\n\n                        int4* sh_s_stage = sh_s + s_sh_stage * pipe;\n\n                        reinterpret_cast<int4*>(&frag_s[k % 2])[0] =\n                            sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];\n                    }\n                }\n\n                return;\n            }\n\n            // Act-order case\n\n            // Determine K of the \"current\" thread-block\n            int cur_k = slice_k_start + tb_k * full_pipe;\n            if (cur_k >= prob_k || cur_k >= slice_k_finish) {\n                return;\n            }\n\n            // Reset (to current thread-block) since we read g_idx portion from the\n            // shared memory\n            cur_k = 0;\n\n            // Progress to current iteration\n            cur_k += k_iter_size * (k % b_sh_wr_iters);\n\n            // Determine \"position\" inside the thread-block (based on warp and\n            // thread-id)\n            int warp_id = threadIdx.x / 32;\n            int n_warps =\n                thread_n_blocks / 4;  // Each warp processes 4 16-size tiles over N\n\n            int warp_row = warp_id / n_warps;\n            int warp_col = warp_id % n_warps;\n\n            cur_k += warp_row * 16;\n\n            int th_id = threadIdx.x % 32;\n            cur_k += (th_id % 4) * 2;  // Due to tensor-core layout for fp16 B matrix\n\n            int s_col_shift =\n                /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) +\n                (th_id / 4) * act_s_col_stride;\n\n            if (is_same_group[pipe]) {\n                if (k % 2 == 0) {\n                    *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =\n                        sh_s[(same_group_id[pipe] - 
sh_first_group_id) * s_sh_stride +\n                        s_col_shift];\n                }\n                else {\n                    *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =\n                        *(reinterpret_cast<int4*>(&(act_frag_s[(k - 1) % 2][0][0])));\n                }\n\n                for (int i = 1; i < 4; i++) {\n                    *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =\n                        *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0])));\n                }\n                return;\n            }\n\n            int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;\n            int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);\n\n            constexpr int k_frag_offsets[4] = { 0, 1, 8,\n                                               9 };  // Tensor core offsets per thread\n\n#pragma unroll\n            for (int i = 0; i < 4; i++) {\n                int actual_k = cur_k + k_frag_offsets[i];\n\n                int group_id = sh_g_idx_int_ptr[actual_k];\n                int rel_group_id = group_id - sh_first_group_id;\n\n                *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =\n                    sh_s[rel_group_id * s_sh_stride + s_col_shift];\n            }\n            };\n\n        // Execute the actual tensor core matmul of a sub-tile.\n        auto matmul = [&](int k) {\n            // We have the m dimension as the inner loop in order to encourage overlapping\n            // dequantization and matmul operations.\n#pragma unroll\n            for (int j = 0; j < 4; j++) {\n                FragB frag_b0;\n                FragB frag_b1;\n                if constexpr (num_bits == 4) {\n                    int b_quant = frag_b_quant[k % 2][0][j];\n                    int b_quant_shift = b_quant >> 8;\n\n                    frag_b0 = dequant_4bit<scalar_t>(b_quant);\n                    frag_b1 = dequant_4bit<scalar_t>(b_quant_shift);\n\n                }\n                
else {\n                    int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);\n                    int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];\n                    int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];\n\n                    frag_b0 = dequant_8bit<scalar_t>(b_quant_0);\n                    frag_b1 = dequant_8bit<scalar_t>(b_quant_1);\n                }\n\n                // Apply scale to frag_b0\n                if constexpr (has_act_order) {\n                    scale4<scalar_t>(frag_b0, act_frag_s[k % 2][0][j],\n                        act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j],\n                        act_frag_s[k % 2][3][j], 0);\n                }\n                else {\n                    if constexpr (group_blocks != -1) {\n                        scale<scalar_t>(frag_b0, frag_s[k % 2][j], 0);\n                    }\n                }\n\n                // Apply scale to frag_b1\n                if constexpr (has_act_order) {\n                    scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j],\n                        act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j],\n                        act_frag_s[k % 2][3][j], 1);\n\n                }\n                else {\n                    if constexpr (group_blocks != -1) {\n                        scale<scalar_t>(frag_b1, frag_s[k % 2][j], 1);\n                    }\n                }\n\n#pragma unroll\n                for (int i = 0; i < thread_m_blocks; i++) {\n                    mma<scalar_t>(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);\n                    mma<scalar_t>(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);\n                }\n            }\n            };\n\n        // Since we slice across the k dimension of a tile in order to increase the\n        // number of warps while keeping the n dimension of a tile reasonable, we have\n        // multiple warps that accumulate their partial sums of the same output\n        // location; which we have to reduce over in 
the end. We do this in shared memory.\n        auto thread_block_reduce = [&]() {\n            constexpr int red_off = threads / b_sh_stride_threads / 2;\n            if (red_off >= 1) {\n                int red_idx = threadIdx.x / b_sh_stride_threads;\n                constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;\n                constexpr int red_sh_delta = b_sh_stride_threads;\n                int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +\n                    (threadIdx.x % b_sh_stride_threads);\n\n                // Parallel logarithmic shared memory reduction. We make sure to avoid any\n                // unnecessary read or write iterations, e.g., for two warps we write only\n                // once by warp 1 and read only once by warp 0.\n\n#pragma unroll\n                for (int m_block = 0; m_block < thread_m_blocks; m_block++) {\n#pragma unroll\n                    for (int i = red_off; i > 0; i /= 2) {\n                        if (i <= red_idx && red_idx < 2 * i) {\n#pragma unroll\n                            for (int j = 0; j < 4 * 2; j++) {\n                                int red_sh_wr =\n                                    red_sh_delta * j + (red_sh_rd - red_sh_stride * i);\n                                if (i < red_off) {\n                                    float* c_rd =\n                                        reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);\n                                    float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);\n#pragma unroll\n                                    for (int k = 0; k < 4; k++)\n                                        reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=\n                                        c_rd[k] + c_wr[k];\n                                }\n                                sh[red_sh_wr] =\n                                    reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];\n                            }\n           
             }\n                        __syncthreads();\n                    }\n                    if (red_idx == 0) {\n#pragma unroll\n                        for (int i = 0; i < 4 * 2; i++) {\n                            float* c_rd =\n                                reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);\n#pragma unroll\n                            for (int j = 0; j < 4; j++)\n                                reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=\n                                c_rd[j];\n                        }\n                    }\n                    __syncthreads();\n                }\n            }\n            };\n\n        // Since multiple threadblocks may process parts of the same column slice, we\n        // finally have to globally reduce over the results. As the striped\n        // partitioning minimizes the number of such reductions and our outputs are\n        // usually rather small, we perform this reduction serially in L2 cache.\n        auto global_reduce = [&](bool first = false, bool last = false) {\n            // We are very careful here to reduce directly in the output buffer to\n            // maximize L2 cache utilization in this step. 
To do this, we write out\n            // results in FP16 (but still reduce with FP32 compute).\n            constexpr int active_threads = 32 * thread_n_blocks / 4;\n            if (threadIdx.x < active_threads) {\n                int c_gl_stride = prob_n / 8;\n                int c_gl_wr_delta_o = 8 * c_gl_stride;\n                int c_gl_wr_delta_i = 4 * (active_threads / 32);\n                int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +\n                    4 * (threadIdx.x / 32) + threadIdx.x % 4;\n                c_gl_wr += (2 * thread_n_blocks) * slice_col;\n                constexpr int c_sh_wr_delta = active_threads;\n                int c_sh_wr = threadIdx.x;\n\n                int row = (threadIdx.x % 32) / 4;\n\n                if (!first) {\n                    // Interestingly, doing direct global accesses here really seems to mess up\n                    // the compiler and lead to slowdowns, hence we also use async-copies even\n                    // though these fetches are not actually asynchronous.\n#pragma unroll\n                    for (int i = 0; i < thread_m_blocks * 4; i++) {\n                        cp_async4_pred(\n                            &sh[c_sh_wr + c_sh_wr_delta * i],\n                            &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +\n                            c_gl_wr_delta_i * (i % 2)],\n                            i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);\n                    }\n                    cp_async_fence();\n                    cp_async_wait<0>();\n                }\n\n#pragma unroll\n                for (int i = 0; i < thread_m_blocks * 4; i++) {\n                    if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {\n                        if (!first) {\n                            int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];\n#pragma unroll\n                            for (int j = 0; j < 2 * 4; j++) {\n                                reinterpret_cast<float*>(\n    
                                &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=\n                                    Dtype::num2float(reinterpret_cast<scalar_t*>(&c_red)[j]);\n                            }\n                        }\n                        if (!last) {\n                            int4 c;\n#pragma unroll\n                            for (int j = 0; j < 2 * 4; j++) {\n                                reinterpret_cast<scalar_t*>(&c)[j] =\n                                    Dtype::float2num(reinterpret_cast<float*>(\n                                        &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);\n                            }\n                            C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =\n                                c;\n                        }\n                    }\n                }\n            }\n            };\n\n        // Write out the reduced final result in the correct layout. We only actually\n        // reshuffle matrix fragments in this step; the reduction above is performed\n        // in fragment layout.\n        auto write_result = [&]() {\n            int c_gl_stride = prob_n / 8;\n            constexpr int c_sh_stride = 2 * thread_n_blocks + 1;\n            int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));\n            constexpr int c_sh_rd_delta =\n                c_sh_stride * (threads / (2 * thread_n_blocks));\n\n            int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +\n                (threadIdx.x % (2 * thread_n_blocks));\n            c_gl_wr += (2 * thread_n_blocks) * slice_col;\n            int c_sh_wr =\n                (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;\n            c_sh_wr += 32 * (threadIdx.x / 32);\n            int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +\n                (threadIdx.x % (2 * thread_n_blocks));\n\n            int c_gl_wr_end = c_gl_stride * prob_m;\n\n    
        // We first reorder in shared memory to guarantee the most efficient final\n            // global write patterns\n            auto write = [&](int idx, float c0, float c1, FragS& s) {\n                scalar_t2 res =\n                    Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));\n\n                // For per-column quantization we finally apply the scale here (only for\n                // 4-bit)\n                if constexpr (!has_act_order && group_blocks == -1 && num_bits == 4) {\n                    res = __hmul2(res, s[0]);\n                }\n\n                ((scalar_t2*)sh)[idx] = res;\n                };\n\n            if (threadIdx.x / 32 < thread_n_blocks / 4) {\n#pragma unroll\n                for (int i = 0; i < thread_m_blocks; i++) {\n#pragma unroll\n                    for (int j = 0; j < 4; j++) {\n                        int wr = c_sh_wr + 8 * j;\n                        write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],\n                            frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);\n                        write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],\n                            frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);\n                        write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],\n                            frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);\n                        write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],\n                            frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);\n                    }\n                    c_sh_wr += 16 * (4 * c_sh_stride);\n                }\n            }\n            __syncthreads();\n\n#pragma unroll\n            for (int i = 0;\n                i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));\n                i++) {\n                if (c_gl_wr < c_gl_wr_end) {\n                    C[c_gl_wr] = sh[c_sh_rd];\n                    c_gl_wr += 
c_gl_wr_delta;\n                    c_sh_rd += c_sh_rd_delta;\n                }\n            }\n            };\n\n        // Start global fetch and register load pipelines.\n        auto start_pipes = [&]() {\n\n#pragma unroll\n            for (int i = 0; i < stages - 1; i++) {\n                if (has_act_order && i == 0) {\n                    int last_g_idx = slice_k_start + stages * tb_k * 2;\n                    if (last_g_idx >= prob_k) {\n                        last_g_idx = prob_k - 1;\n                    }\n                    fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);\n                }\n                fetch_to_shared(i, i, i < slice_iters);\n            }\n\n            zero_accums();\n            wait_for_stage();\n            init_same_group(0);\n            fetch_to_registers(0, 0);\n            fetch_scales_to_registers(0, 0);\n            a_gl_rd += a_gl_rd_delta_o * (stages - 1);\n            slice_k_start_shared_fetch += tb_k * (stages - 1);\n            };\n        if (slice_iters) {\n            start_pipes();\n        }\n\n        // Main loop.\n        while (slice_iters) {\n            // We unroll over both the global fetch and the register load pipeline to\n            // ensure all shared memory accesses are static. 
Note that both pipelines\n            // have even length meaning that the next iteration will always start at\n            // index 0.\n\n#pragma unroll\n            for (int pipe = 0; pipe < stages;) {\n#pragma unroll\n                for (int k = 0; k < b_sh_wr_iters; k++) {\n                    fetch_to_registers(k + 1, pipe % stages);\n                    fetch_scales_to_registers(k + 1, pipe);\n                    if (k == b_sh_wr_iters - 2) {\n                        fetch_to_shared((pipe + stages - 1) % stages, pipe,\n                            slice_iters >= stages);\n                        pipe++;\n                        wait_for_stage();\n                        init_same_group(pipe % stages);\n                    }\n                    matmul(k);\n                }\n                slice_iters--;\n                if (slice_iters == 0) {\n                    break;\n                }\n            }\n\n            a_gl_rd += a_gl_rd_delta_o * stages;\n            slice_k_start += tb_k * stages;\n            slice_k_start_shared_fetch += tb_k * stages;\n\n            if constexpr (has_act_order) {\n                int first_group_id = g_idx[slice_k_start];\n                int last_g_idx = slice_k_start + stages * tb_k * 2;\n                if (last_g_idx >= prob_k) {\n                    last_g_idx = prob_k - 1;\n                }\n                int last_group_id = g_idx[last_g_idx];\n                if (last_group_id >= sh_first_group_id + sh_num_groups) {\n                    fetch_scales_to_shared(false, first_group_id, last_group_id);\n                    __syncthreads();\n                }\n            }\n\n            // Process results and, if necessary, proceed to the next column slice.\n            // While this pattern may not be the most readable, other ways of writing\n            // the loop seemed to yield noticeably worse performance after compilation.\n            if (slice_iters == 0) {\n                cp_async_wait<0>();\n                
bool last = slice_idx == slice_count - 1;\n                // For per-column scales, we only fetch them here in the final step before\n                // write-out\n                if constexpr (!has_act_order && group_blocks == -1) {\n                    if constexpr (num_bits == 8) {\n                        if (s_sh_wr_pred) {\n                            cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);\n                        }\n                        cp_async_fence();\n                    }\n                    else {\n                        if (last) {\n                            if (s_sh_wr_pred) {\n                                cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);\n                            }\n                            cp_async_fence();\n                        }\n                    }\n                }\n\n                thread_block_reduce();\n                if constexpr (!has_act_order && group_blocks == -1) {\n                    if constexpr (num_bits == 8) {\n                        cp_async_wait<0>();\n                        __syncthreads();\n                        if (threadIdx.x / 32 < thread_n_blocks / 4) {\n                            reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];\n                            reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];\n                        }\n\n                    }\n                    else {\n                        if (last) {\n                            cp_async_wait<0>();\n                            __syncthreads();\n                            if (threadIdx.x / 32 < thread_n_blocks / 4) {\n                                reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];\n                                reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];\n                            }\n                        }\n                    }\n                }\n\n                // For 8-bit channelwise, we apply the scale before the global 
reduction\n                // that converts the fp32 results to fp16 (so that we avoid possible\n                // overflow in fp16)\n                if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) {\n                    if (threadIdx.x / 32 < thread_n_blocks / 4) {\n#pragma unroll\n                        for (int i = 0; i < thread_m_blocks; i++) {\n#pragma unroll\n                            for (int j = 0; j < 4; j++) {\n                                scale_float<scalar_t>(\n                                    reinterpret_cast<float*>(&frag_c[i][j][0][0]),\n                                    frag_s[j / 2][2 * (j % 2) + 0]);\n                                scale_float<scalar_t>(\n                                    reinterpret_cast<float*>(&frag_c[i][j][0][2]),\n                                    frag_s[j / 2][2 * (j % 2) + 0]);\n\n                                scale_float<scalar_t>(\n                                    reinterpret_cast<float*>(&frag_c[i][j][1][0]),\n                                    frag_s[j / 2][2 * (j % 2) + 1]);\n                                scale_float<scalar_t>(\n                                    reinterpret_cast<float*>(&frag_c[i][j][1][2]),\n                                    frag_s[j / 2][2 * (j % 2) + 1]);\n                            }\n                        }\n                    }\n                }\n\n                if (slice_count > 1) {  // only globally reduce if there is more than one\n                    // block in a slice\n                    barrier_acquire(&locks[slice_col], slice_idx);\n                    global_reduce(slice_idx == 0, last);\n                    barrier_release(&locks[slice_col], last);\n                }\n                if (last)  // only the last block in a slice actually writes the result\n                    write_result();\n                slice_row = 0;\n                slice_col_par++;\n                slice_col++;\n                init_slice();\n             
   if (slice_iters) {\n                    a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +\n                        (threadIdx.x % a_gl_rd_delta_o);\n#pragma unroll\n                    for (int i = 0; i < b_sh_wr_iters; i++)\n                        B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;\n                    if (slice_col == 0) {\n#pragma unroll\n                        for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;\n                    }\n\n                    // Update slice k/n for scales loading\n                    if constexpr (has_act_order) {\n                        slice_k_start = tb_k * slice_row;\n                        slice_k_finish = slice_k_start + tb_k * slice_iters;\n                        slice_k_start_shared_fetch = slice_k_start;\n                        slice_n_offset = act_s_col_tb_stride * slice_col;\n\n                    }\n                    else {\n                        s_gl_rd = s_sh_stride * slice_col + threadIdx.x;\n                    }\n\n                    start_pipes();\n                }\n            }\n        }\n    }\n\n    template <typename scalar_t,         // compute dtype, half or nv_bfloat16\n        const int num_bits,        // number of bits used for weights\n        const int threads,         // number of threads in a threadblock\n        const int template_thread_m_blocks, // number of 16x16 blocks in the m\n        // dimension (batchsize) of the\n        // threadblock\n        const int thread_n_blocks, // same for n dimension (output)\n        const int thread_k_blocks, // same for k dimension (reduction)\n        const int stages, // number of stages for the async global->shared\n        // fetch pipeline\n        const bool has_act_order,   // whether act_order is enabled\n        const int group_blocks = -1 // number of consecutive 16x16 blocks\n        // with a separate quantization scale\n    >\n    __global__ void\n        Marlin_wrapper(const int4* 
__restrict__ A, // fp16 input matrix of shape mxk\n            const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn\n            int4* __restrict__ C,       // fp16 output buffer of shape mxn\n            const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape\n            // (k/groupsize)xn\n            const int* __restrict__ g_idx, // int32 group indices of shape k\n            int num_groups, // number of scale groups per output channel\n            const int* __restrict__ prob_m_ptr,     // batch dimension m\n            int prob_n,     // output dimension n\n            int prob_k,     // reduction dimension k\n            int* locks      // extra global storage for barrier synchronization\n        ) {\n        int prob_m = *prob_m_ptr;\n        prob_m = min(prob_m, 1024);\n        const int thread_m_blocks = min(div_ceil(prob_m, 16), template_thread_m_blocks);\n        if(prob_m > 16 * thread_m_blocks)\n            prob_m = (16 * thread_m_blocks) * div_ceil(prob_m, (16 * thread_m_blocks));\n        /*if (blockIdx.x == 0 && threadIdx.x == 0)\n            printf(\"marlin prob_m %d\\n\", prob_m);*/\n        if (thread_m_blocks == 1) {\n            Marlin<scalar_t, num_bits, threads, 1,\n                thread_n_blocks, thread_k_blocks, stages, has_act_order,\n                group_blocks>(\n                    A, B, C, scales_ptr, g_idx, num_groups, prob_m, prob_n,\n                    prob_k, locks);\n        }\n        else if (thread_m_blocks == 2) {\n            Marlin<scalar_t, num_bits, threads, 2,\n                thread_n_blocks, thread_k_blocks, stages, has_act_order,\n                group_blocks>(\n                    A, B, C, scales_ptr, g_idx, num_groups, prob_m, prob_n,\n                    prob_k, locks);\n        }\n        else if (thread_m_blocks == 3) {\n            Marlin<scalar_t, num_bits, threads, 3,\n                thread_n_blocks, thread_k_blocks, stages, has_act_order,\n                
group_blocks>(\n                    A, B, C, scales_ptr, g_idx, num_groups, prob_m, prob_n,\n                    prob_k, locks);\n        }\n        else if (thread_m_blocks == 4) {\n            Marlin<scalar_t, num_bits, threads, 4,\n                thread_n_blocks, thread_k_blocks, stages, has_act_order,\n                group_blocks>(\n                    A, B, C, scales_ptr, g_idx, num_groups, prob_m, prob_n,\n                    prob_k, locks);\n        }\n    }\n\n#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \\\n                  HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS)                    \\\n    else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS &&     \\\n             thread_n_blocks == THREAD_N_BLOCKS &&                             \\\n             thread_k_blocks == THREAD_K_BLOCKS &&                             \\\n             has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \\\n             num_threads == NUM_THREADS) {                                     \\\n        cudaFuncSetAttribute(                                                  \\\n            Marlin_wrapper<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS,           \\\n                   THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages,              \\\n                   HAS_ACT_ORDER, GROUP_BLOCKS>,                               \\\n            cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);      \\\n        Marlin_wrapper<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS,               \\\n               THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER,   \\\n               GROUP_BLOCKS><<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \\\n            A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m_ptr, prob_n, \\\n            prob_k, locks);                                                    \\\n    }\n\n    typedef struct {\n        int thread_k;\n        int thread_n;\n        int 
num_threads;\n    } thread_config_t;\n\n    typedef struct {\n        int max_m_blocks;\n        thread_config_t tb_cfg;\n    } exec_config_t;\n\n    thread_config_t small_batch_thread_configs[] = {\n        // Ordered by priority\n\n        // thread_k, thread_n, num_threads\n        {128, 128, 256},\n        {64, 128, 128},\n        {128, 64, 128},\n    };\n\n    thread_config_t large_batch_thread_configs[] = {\n        // Ordered by priority\n\n        // thread_k, thread_n, num_threads\n        {64, 256, 256},\n        // {128, 128, 256},\n        {64, 128, 128},\n        {128, 64, 128},\n\n    };\n\n    int get_scales_cache_size(thread_config_t const& th_config, int prob_m,\n        int prob_n, int prob_k, int num_bits, int group_size,\n        bool has_act_order, bool is_k_full) {\n        bool cache_scales_chunk = has_act_order && !is_k_full;\n\n        int tb_n = th_config.thread_n;\n        int tb_k = th_config.thread_k;\n\n        // Get max scale groups per thread-block\n        int tb_groups;\n        if (group_size == -1) {\n            tb_groups = 1;\n        }\n        else if (group_size == 0) {\n            tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size\n        }\n        else {\n            tb_groups = div_ceil(tb_k, group_size);\n        }\n\n        if (cache_scales_chunk) {\n            int load_groups =\n                tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K\n            load_groups = max(load_groups, 32); // We load at least 32 scale groups\n            return load_groups * tb_n * 2;\n\n        }\n        else {\n            int tb_scales = tb_groups * tb_n * 2;\n\n            return tb_scales * pipe_stages;\n        }\n    }\n\n    bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,\n        int prob_m, int prob_n, int prob_k, int num_bits,\n        int scales_cache_size, int max_shared_mem) {\n        int pack_factor = 32 / num_bits;\n\n        // Get B size\n        
int tb_k = th_config.thread_k;\n        int tb_n = th_config.thread_n;\n\n        int b_size = (tb_k * tb_n / pack_factor) * 4;\n\n        // Get A size\n        int m_blocks = div_ceil(prob_m, 16);\n        int tb_max_m = 16;\n\n        // zbx: too ugly\n        // origin\n        /*while (true) {\n          if (m_blocks >= max_m_blocks) {\n            tb_max_m *= max_m_blocks;\n            break;\n          }\n\n          max_m_blocks--;\n          if (max_m_blocks == 0) {\n            TORCH_CHECK(false, \"Unexpected m_blocks = \", m_blocks);\n          }\n        }*/\n        // refactor\n        tb_max_m *= std::min(m_blocks, max_m_blocks);\n\n        int a_size = (tb_max_m * tb_k) * 2;\n\n        float pipe_size = (a_size + b_size) * pipe_stages;\n\n        TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity\n        return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);\n    }\n\n    bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,\n        int prob_m, int prob_n, int prob_k, int num_bits,\n        int group_size, bool has_act_order, bool is_k_full,\n        int max_shared_mem) {\n        // Sanity\n        if (th_config.thread_k == -1 || th_config.thread_n == -1 ||\n            th_config.num_threads == -1) {\n            return false;\n        }\n\n        // Verify K/N are divisible by thread K/N\n        if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {\n            return false;\n        }\n\n        // Verify min for thread K/N\n        if (th_config.thread_n < min_thread_n ||\n            th_config.thread_k < min_thread_k) {\n            return false;\n        }\n\n        // num_threads must be at least 128 (= 4 warps)\n        if (th_config.num_threads < 128) {\n            return false;\n        }\n\n        //  Determine cache for scales\n        int scales_cache_size =\n            get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,\n                group_size, 
has_act_order, is_k_full);\n\n        // Check that pipeline fits into cache\n        if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,\n            num_bits, scales_cache_size, max_shared_mem)) {\n            return false;\n        }\n\n        return true;\n    }\n\n    exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,\n        int num_bits, int group_size,\n        bool has_act_order, bool is_k_full,\n        int max_shared_mem) {\n        int max_m_blocks = 4;\n        while (max_m_blocks > 0) {\n            if (prob_m <= 16) {\n                for (auto th_config : small_batch_thread_configs) {\n                    if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n,\n                        prob_k, num_bits, group_size, has_act_order,\n                        is_k_full, max_shared_mem)) {\n                        return exec_config_t{ max_m_blocks, th_config };\n                    }\n                }\n            }\n            else {\n                for (auto th_config : large_batch_thread_configs) {\n                    if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n,\n                        prob_k, num_bits, group_size, has_act_order,\n                        is_k_full, max_shared_mem)) {\n                        return exec_config_t{ max_m_blocks, th_config };\n                    }\n                }\n            }\n\n            max_m_blocks--; // Process fewer M blocks per invocation to reduce cache\n            // usage\n        }\n\n        return exec_config_t{ 0, {-1, -1, -1} };\n    }\n\n#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS)                     \\\n    __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)          \\\n    __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)          \\\n    __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)          \\\n    __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, 
NUM_THREADS)          \\\n    __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)          \\\n    __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)          \\\n    __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)          \\\n    __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)        \n\n    template <typename scalar_t>\n    void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,\n        void* g_idx, void* perm, void* a_tmp, int* prob_m_ptr, int prob_m,\n        int prob_n, int prob_k, void* workspace, int num_bits,\n        bool has_act_order, bool is_k_full, int num_groups,\n        int group_size, int dev, cudaStream_t stream, int thread_k,\n        int thread_n, int sms, int max_par) {\n        TORCH_CHECK(num_bits == 4 || num_bits == 8,\n            \"num_bits must be 4 or 8. Got = \", num_bits);\n        TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, \"Invalid MNK = [\",\n            prob_m, \", \", prob_n, \", \", prob_k, \"]\");\n\n        int tot_m = prob_m;\n        int tot_m_blocks = div_ceil(tot_m, 16);\n        int pad = 16 * tot_m_blocks - tot_m;\n\n        if (sms == -1) {\n            cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);\n        }\n\n        int max_shared_mem = 0;\n        cudaDeviceGetAttribute(&max_shared_mem,\n            cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);\n        TORCH_CHECK(max_shared_mem > 0);\n\n        // Set thread config\n        exec_config_t exec_cfg;\n        if (thread_k != -1 && thread_n != -1) {\n            // User-defined config\n            exec_cfg = exec_config_t{\n                4, thread_config_t{thread_k, thread_n, default_threads} };\n        }\n        else {\n            // Auto config\n            exec_cfg = determine_thread_config(prob_m, prob_n, prob_k, num_bits,\n                group_size, has_act_order, is_k_full,\n                max_shared_mem);\n        }\n\n        TORCH_CHECK(\n     
       exec_cfg.max_m_blocks > 0 &&\n            is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, prob_m,\n                prob_n, prob_k, num_bits, group_size, has_act_order,\n                is_k_full, max_shared_mem),\n            \"Invalid thread config: max_m_blocks = \", exec_cfg.max_m_blocks,\n            \", thread_k = \", exec_cfg.tb_cfg.thread_k,\n            \", thread_n = \", exec_cfg.tb_cfg.thread_n,\n            \", num_threads = \", exec_cfg.tb_cfg.num_threads, \" for MKN = [\", prob_m,\n            \", \", prob_k, \", \", prob_n, \"] and num_bits = \", num_bits,\n            \", group_size = \", group_size, \", has_act_order = \", has_act_order,\n            \", is_k_full = \", is_k_full, \", max_shared_mem = \", max_shared_mem);\n\n        int num_threads = exec_cfg.tb_cfg.num_threads;\n        thread_k = exec_cfg.tb_cfg.thread_k;\n        thread_n = exec_cfg.tb_cfg.thread_n;\n\n        int thread_k_blocks = thread_k / 16;\n        int thread_n_blocks = thread_n / 16;\n\n        int blocks = sms;\n\n        TORCH_CHECK(prob_n % thread_n == 0, \"prob_n = \", prob_n,\n            \" is not divisible by thread_n = \", thread_n);\n        TORCH_CHECK(prob_k % thread_k == 0, \"prob_k = \", prob_k,\n            \" is not divisible by thread_k = \", thread_k);\n\n        int group_blocks = 0;\n        if (has_act_order) {\n            if (is_k_full) {\n                TORCH_CHECK(group_size != -1);\n                group_blocks = group_size / 16;\n                TORCH_CHECK(prob_k % group_blocks == 0, \"prob_k = \", prob_k,\n                    \" is not divisible by group_blocks = \", group_blocks);\n            }\n            else {\n                TORCH_CHECK(group_size == 0);\n                group_blocks = 0;\n            }\n\n        }\n        else {\n            if (group_size == -1) {\n                group_blocks = -1;\n            }\n            else {\n                group_blocks = group_size / 16;\n                TORCH_CHECK(prob_k 
% group_blocks == 0, \"prob_k = \", prob_k,\n                    \" is not divisible by group_blocks = \", group_blocks);\n            }\n        }\n\n        const int4* A_ptr = (const int4*)A;\n        const int4* B_ptr = (const int4*)B;\n        int4* C_ptr = (int4*)C;\n        const int4* s_ptr = (const int4*)s;\n        const int* g_idx_ptr = (const int*)g_idx;\n        const int* perm_ptr = (const int*)perm;\n        int4* a_tmp_ptr = (int4*)a_tmp;\n\n        int* locks = (int*)workspace;\n\n        if (has_act_order) {\n            // Permute A columns\n            int block_rows = div_ceil(prob_m, blocks);\n            permute_cols_kernel << <blocks, default_threads, 0, stream >> > (\n                A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows);\n            A_ptr = a_tmp_ptr;\n        }\n\n        // If we have a full K, then we can run the non-act-order version of Marlin\n        // (since the weight rows are reordered by increasing group ids, and by\n        // having a full K, we have full original groups)\n        if (is_k_full) {\n            has_act_order = false;\n        }\n\n        // Main loop\n        for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) {\n            int thread_m_blocks = tot_m_blocks - i;\n            prob_m = tot_m - 16 * i;\n            int par = 1;\n            if (thread_m_blocks > exec_cfg.max_m_blocks) {\n                // Note that parallel > 1 currently only works for inputs without\n                // any padding\n                par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks);\n                if (par > max_par)\n                    par = max_par;\n                prob_m = (16 * exec_cfg.max_m_blocks) * par;\n                i += exec_cfg.max_m_blocks * (par - 1);\n                thread_m_blocks = exec_cfg.max_m_blocks;\n            }\n\n            // Define kernel configurations\n#define undefined_error                                                        \\\n    
TORCH_CHECK(false, \"Unsupported shapes: MNK = [\" + str(prob_m) + \", \" +    \\\n                           str(prob_n) + \", \" + str(prob_k) + \"]\" +            \\\n                           \", has_act_order = \" + str(has_act_order) +         \\\n                           \", num_groups = \" + str(num_groups) +               \\\n                           \", group_size = \" + str(group_size) +               \\\n                           \", thread_m_blocks = \" + str(thread_m_blocks) +     \\\n                           \", thread_n_blocks = \" + str(thread_n_blocks) +     \\\n                           \", thread_k_blocks = \" + str(thread_k_blocks));\n\n        /* std::cout << \"MNK = [\" + str(prob_m) + \", \" + \\\n             str(prob_n) + \", \" + str(prob_k) + \"]\" + \\\n             \", has_act_order = \" + str(has_act_order) + \\\n             \", num_groups = \" + str(num_groups) + \\\n             \", group_size = \" + str(group_size) + \\\n             \", thread_m_blocks = \" + str(thread_m_blocks) + \\\n             \", thread_n_blocks = \" + str(thread_n_blocks) + \\\n             \", thread_k_blocks = \" + str(thread_k_blocks) << std::endl;*/\n\n             /*if (false) {\n             }\n             // CALL_IF(4, 32, 2, 256)\n             // CALL_IF(4, 16, 4, 256)\n             __CALL_IF(4, 1, 16, 4, false, 4, 256)\n             __CALL_IF(4, 2, 16, 4, false, 4, 256)\n             // CALL_IF(4, 8, 8, 256)\n             __CALL_IF(4, 1, 8, 8, false, 4, 256)\n             __CALL_IF(4, 2, 8, 8, false, 4, 256)\n             // CALL_IF(4, 16, 4, 128)\n             __CALL_IF(4, 1, 16, 4, false, 4, 128)\n             __CALL_IF(4, 2, 16, 4, false, 4, 128)\n             // CALL_IF(4, 8, 8, 128)\n             __CALL_IF(4, 1, 8, 8, false, 4, 128)\n             __CALL_IF(4, 2, 8, 8, false, 4, 128)\n             else {undefined_error}*/\n\n            if (num_bits == 4 && num_threads == 256)\n            {\n                if (false) {\n            
    }\n                CALL_IF(4, 32, 2, 256)\n                    CALL_IF(4, 16, 4, 256)\n                    CALL_IF(4, 8, 8, 256)\n                else {\n                    undefined_error\n                }\n            }\n            else if (num_bits == 4 && num_threads == 128)\n            {\n                if (false) {\n                }\n                CALL_IF(4, 8, 4, 128)\n                    CALL_IF(4, 16, 4, 128)\n                    CALL_IF(4, 4, 8, 128)\n                else {\n                    undefined_error\n                }\n            }\n            // else if (num_bits == 8 && num_threads == 256)\n            // {\n            //     if (false) {\n            //     }\n            //     CALL_IF(8, 32, 2, 256)\n            //     CALL_IF(8, 16, 4, 256)\n            //     CALL_IF(8, 8, 8, 256)\n            //     else {\n            //         undefined_error\n            //     }\n            // }\n            // else if (num_bits == 8 && num_threads == 128)\n            // {\n            //     if (false) {\n            //     }\n            //     CALL_IF(8, 8, 4, 128)\n            //     CALL_IF(8, 16, 4, 128)\n            //     CALL_IF(8, 4, 8, 128)\n            //     else {\n            //         undefined_error\n            //     }\n            // }\n            else {\n                undefined_error\n            }\n\n            A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;\n            C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;\n        }\n    }\n\n} // namespace gptq_marlin\n\ntorch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,\n    torch::Tensor& b_scales, torch::Tensor& g_idx,\n    torch::Tensor& perm, torch::Tensor& workspace,\n    int64_t num_bits, torch::Tensor size_m_tensor, int64_t size_m, int64_t size_n,\n    int64_t size_k, int sms, bool is_k_full) {\n    const at::cuda::OptionalCUDAGuard device_guard(device_of(a));\n    // Verify num_bits\n    TORCH_CHECK(num_bits == 4 || 
num_bits == 8,\n        \"num_bits must be 4 or 8. Got = \", num_bits);\n    int pack_factor = 32 / num_bits;\n\n    // Verify A\n    TORCH_CHECK(a.size(0) == size_m, \"Shape mismatch: a.size(0) = \", a.size(0),\n        \", size_m = \", size_m);\n    TORCH_CHECK(a.size(1) == size_k, \"Shape mismatch: a.size(1) = \", a.size(1),\n        \", size_k = \", size_k);\n\n    // Verify B\n    TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, \"size_k = \", size_k,\n        \" is not divisible by tile_size = \", gptq_marlin::tile_size);\n    TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),\n        \"Shape mismatch: b_q_weight.size(0) = \", b_q_weight.size(0),\n        \", size_k = \", size_k,\n        \", tile_size = \", gptq_marlin::tile_size);\n    TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0,\n        \"b_q_weight.size(1) = \", b_q_weight.size(1),\n        \" is not divisible by tile_size = \", gptq_marlin::tile_size);\n    int actual_size_n =\n        (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;\n    TORCH_CHECK(size_n == actual_size_n, \"size_n = \", size_n,\n        \", actual_size_n = \", actual_size_n);\n\n    // Verify device and strides\n    TORCH_CHECK(a.device().is_cuda(), \"A is not on GPU\");\n    TORCH_CHECK(a.is_contiguous(), \"A is not contiguous\");\n\n    TORCH_CHECK(b_q_weight.device().is_cuda(), \"b_q_weight is not on GPU\");\n    TORCH_CHECK(b_q_weight.is_contiguous(), \"b_q_weight is not contiguous\");\n\n    TORCH_CHECK(b_scales.device().is_cuda(), \"b_scales is not on GPU\");\n    TORCH_CHECK(b_scales.is_contiguous(), \"b_scales is not contiguous\");\n\n    TORCH_CHECK(g_idx.device().is_cuda(), \"g_idx is not on GPU\");\n    TORCH_CHECK(g_idx.is_contiguous(), \"g_idx is not contiguous\");\n\n    TORCH_CHECK(perm.device().is_cuda(), \"perm is not on GPU\");\n    TORCH_CHECK(perm.is_contiguous(), \"perm is not contiguous\");\n\n    // Alloc buffers\n    auto options = 
torch::TensorOptions().dtype(a.dtype()).device(a.device());\n    torch::Tensor c = torch::empty({ size_m, size_n }, options);\n    torch::Tensor a_tmp = torch::empty({ size_m, size_k }, options);\n\n    // thread_k: `k` size of a thread_tile in `weights` (can usually be left as\n    // auto -1)\n    int thread_k = -1;\n    // thread_n: `n` size of a thread_tile in `weights` (can usually be left as\n    // auto -1)\n    int thread_n = -1;\n    // sms: number of SMs to use for the kernel (can usually be left as auto -1)\n    // int sms = -1; //zbx\n\n    // Verify g_idx and perm\n    TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) ||\n        (g_idx.size(0) == size_k && perm.size(0) == size_k),\n        \"Unexpected g_idx.size(0) = \", g_idx.size(0),\n        \" and perm.size(0) = \", perm.size(0),\n        \", where size_k = \", size_k);\n\n    // Detect groupsize and act_order\n    int num_groups = -1;\n    int group_size = -1;\n    bool has_act_order = g_idx.size(0) != 0;\n\n    int b_rank = b_scales.sizes().size();\n    TORCH_CHECK(b_rank == 2, \"b_scales rank = \", b_rank, \" is not 2\");\n    TORCH_CHECK(b_scales.size(1) == size_n,\n        \"b_scales dim 1 = \", b_scales.size(1),\n        \" is not size_n = \", size_n);\n    num_groups = b_scales.size(0);\n\n    if (has_act_order) {\n        if (is_k_full) {\n            TORCH_CHECK(num_groups > 1,\n                \"For act_order, num_groups must be > 1\");\n            TORCH_CHECK(size_k % num_groups == 0, \"size_k = \", size_k,\n                \", is not divisible by num_groups = \", num_groups);\n            group_size = size_k / num_groups;\n        }\n        else {\n            group_size = 0;\n        }\n\n    }\n    else {\n        if (num_groups > 1) {\n            TORCH_CHECK(\n                size_k % num_groups == 0, \"size_k = \", size_k,\n                \", is not divisible by b_scales.size(0) = \", b_scales.size(0));\n            group_size = size_k / num_groups;\n        }\n        
else {\n            group_size = -1;\n        }\n    }\n\n    // Verify workspace size\n    TORCH_CHECK(\n        size_n % gptq_marlin::min_thread_n == 0, \"size_n = \", size_n,\n        \", is not divisible by min_thread_n = \", gptq_marlin::min_thread_n);\n    int min_workspace_size =\n        (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;\n    TORCH_CHECK(workspace.numel() >= min_workspace_size,\n        \"workspace.numel = \", workspace.numel(),\n        \" is below min_workspace_size = \", min_workspace_size);\n\n    int dev = a.get_device();\n    if (a.scalar_type() == at::ScalarType::Half) {\n        gptq_marlin::marlin_mm_f16i4<half>(\n            a.data_ptr<at::Half>(), b_q_weight.data_ptr(),\n            c.data_ptr<at::Half>(), b_scales.data_ptr<at::Half>(),\n            g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::Half>(),\n            size_m_tensor.data_ptr<int>(),\n            size_m, size_n, size_k, workspace.data_ptr(), num_bits,\n            has_act_order, is_k_full, num_groups, group_size, dev,\n            at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,\n            gptq_marlin::max_par);\n    }\n    else if (a.scalar_type() == at::ScalarType::BFloat16) {\n        gptq_marlin::marlin_mm_f16i4<nv_bfloat16>(\n            a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),\n            c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),\n            g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),\n            size_m_tensor.data_ptr<int>(),\n            size_m, size_n, size_k, workspace.data_ptr(), num_bits,\n            has_act_order, is_k_full, num_groups, group_size, dev,\n            at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,\n            gptq_marlin::max_par);\n    }\n    else {\n        TORCH_CHECK(false,\n            \"gptq_marlin_gemm only supports bfloat16 and float16\");\n    }\n\n    return c;\n}\n\n#endif"
  },
  {
    "path": "kt-sft/csrc/custom_marlin/gptq_marlin/gptq_marlin.cuh",
    "content": "// Adapted from\n// https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin\n// Copyrigth 2024 The vLLM team.\n// Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n#pragma once\n\n#include <torch/all.h>\n\n#include <ATen/cuda/CUDAContext.h>\n#include <c10/cuda/CUDAGuard.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n#include <iostream>\n\nnamespace gptq_marlin {\n\n// 8 warps are a good choice since every SM has 4 schedulers and having more\n// than 1 warp per schedule allows some more latency hiding. At the same time,\n// we want relatively few warps to have many registers per warp and small tiles.\nstatic constexpr int default_threads = 256;\n\nstatic constexpr int pipe_stages =\n    4; // 4 pipeline stages fit into shared memory\n\nstatic constexpr int min_thread_n = 64;\nstatic constexpr int min_thread_k = 64;\n\nstatic constexpr int tile_size = 16;\nstatic constexpr int max_par = 16;\n\ntemplate <typename T, int n> struct Vec {\n    T elems[n];\n    __device__ T &operator[](int i) { return elems[i]; }\n};\n\nusing I4 = Vec<int, 4>;\n\nconstexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }\n\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n// No support for async\n#else\n\n__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr,\n                                      bool pred = true) {\n    const int BYTES = 16;\n    uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));\n    asm volatile(\"{\\n\"\n                 \"   .reg .pred p;\\n\"\n                 \"   setp.ne.b32 p, %0, 0;\\n\"\n                 \"   @p cp.async.cg.shared.global [%1], [%2], %3;\\n\"\n                 \"}\\n\" ::\"r\"((int)pred),\n                 \"r\"(smem), \"l\"(glob_ptr), \"n\"(BYTES));\n}\n\n__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) {\n    const int BYTES = 16;\n    uint32_t smem = 
static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));\n    asm volatile(\"{\\n\"\n                 \"   cp.async.cg.shared.global [%0], [%1], %2;\\n\"\n                 \"}\\n\" ::\"r\"(smem),\n                 \"l\"(glob_ptr), \"n\"(BYTES));\n}\n\n__device__ inline void cp_async_fence() {\n    asm volatile(\"cp.async.commit_group;\\n\" ::);\n}\n\ntemplate <int n> __device__ inline void cp_async_wait() {\n    asm volatile(\"cp.async.wait_group %0;\\n\" ::\"n\"(n));\n}\n\n#endif\n\n} // namespace gptq_marlin"
  },
  {
    "path": "kt-sft/csrc/custom_marlin/gptq_marlin/gptq_marlin_dtypes.cuh",
    "content": "// Adapted from\n// https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin\n// Copyright 2024 The vLLM team.\n// Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n#ifndef _data_types_cuh\n#define _data_types_cuh\n#include \"gptq_marlin.cuh\"\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n\nnamespace gptq_marlin {\n\ntemplate <typename scalar_t> class ScalarType {};\n\ntemplate <> class ScalarType<half> {\n  public:\n    using scalar_t = half;\n    using scalar_t2 = half2;\n\n    // Matrix fragments for tensor core instructions; their precise layout is\n    // documented here:\n    // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type\n    using FragA = Vec<half2, 4>;\n    using FragB = Vec<half2, 2>;\n    using FragC = Vec<float, 4>;\n    using FragS = Vec<half2, 1>;\n\n    static __device__ float inline num2float(const half x) {\n        return __half2float(x);\n    }\n\n    static __device__ half2 inline num2num2(const half x) {\n        return __half2half2(x);\n    }\n\n    static __device__ half2 inline nums2num2(const half x1, const half x2) {\n        return __halves2half2(x1, x2);\n    }\n\n    static __host__ __device__ half inline float2num(const float x) {\n        return __float2half(x);\n    }\n};\n\ntemplate <> class ScalarType<nv_bfloat16> {\n  public:\n    using scalar_t = nv_bfloat16;\n    using scalar_t2 = nv_bfloat162;\n\n    using FragA = Vec<nv_bfloat162, 4>;\n    using FragB = Vec<nv_bfloat162, 2>;\n    using FragC = Vec<float, 4>;\n    using FragS = Vec<nv_bfloat162, 1>;\n\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n    static __device__ float inline num2float(const nv_bfloat16 x) {\n        return __bfloat162float(x);\n    }\n\n    static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {\n        return __bfloat162bfloat162(x);\n    }\n\n    static __device__ nv_bfloat162 inline nums2num2(const 
nv_bfloat16 x1,\n                                                    const nv_bfloat16 x2) {\n        return __halves2bfloat162(x1, x2);\n    }\n\n    static __host__ __device__ nv_bfloat16 inline float2num(const float x) {\n        return __float2bfloat16(x);\n    }\n#endif\n};\n\n} // namespace gptq_marlin\n\n#endif"
  },
  {
    "path": "kt-sft/csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu",
    "content": "#include \"gptq_marlin.cuh\"\n\nnamespace gptq_marlin {\n\nstatic constexpr int repack_stages = 8;\n\nstatic constexpr int repack_threads = 256;\n\nstatic constexpr int tile_k_size = tile_size;\nstatic constexpr int tile_n_size = tile_k_size * 4;\n\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800\n\ntemplate <int const num_threads, int const num_bits, bool const has_perm>\n__global__ void marlin_repack_kernel(\n    uint32_t const* __restrict__ b_q_weight_ptr,\n    uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,\n    int size_k, int size_n) {}\n\n}  // namespace gptq_marlin\n\ntorch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,\n                                 int64_t size_k, int64_t size_n,\n                                 int64_t num_bits) {\n  TORCH_CHECK_NOT_IMPLEMENTED(\n      false, \"marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0\");\n  return torch::empty({1, 1});\n}\n\n#else\n\ntemplate <int const num_threads, int const num_bits, bool const has_perm>\n__global__ void marlin_repack_kernel(\n    uint32_t const* __restrict__ b_q_weight_ptr,\n    uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,\n    int size_k, int size_n) {\n  constexpr int pack_factor = 32 / num_bits;\n\n  int k_tiles = size_k / tile_k_size;\n  int n_tiles = size_n / tile_n_size;\n  int block_k_tiles = div_ceil(k_tiles, gridDim.x);\n\n  int start_k_tile = blockIdx.x * block_k_tiles;\n  if (start_k_tile >= k_tiles) {\n    return;\n  }\n\n  int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);\n\n  // Wait until the next thread tile has been loaded to shared memory.\n  auto wait_for_stage = [&]() {\n    // We only have `stages - 2` active fetches since we are double buffering\n    // and can only issue the next fetch when it is guaranteed that the previous\n    // shared memory load is fully complete (as it may otherwise be\n    // overwritten).\n    cp_async_wait<repack_stages - 
2>();\n    __syncthreads();\n  };\n\n  extern __shared__ int4 sh[];\n\n  constexpr int perm_size = tile_k_size / 4;\n\n  int4* sh_perm_ptr = sh;\n  int4* sh_pipe_ptr = sh_perm_ptr;\n  if constexpr (has_perm) {\n    sh_pipe_ptr += perm_size;\n  }\n\n  constexpr int tile_ints = tile_k_size / pack_factor;\n\n  constexpr int stage_n_threads = tile_n_size / 4;\n  constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints;\n  constexpr int stage_size = stage_k_threads * stage_n_threads;\n\n  auto load_perm_to_shared = [&](int k_tile_id) {\n    int first_k_int4 = (k_tile_id * tile_k_size) / 4;\n\n    int4 const* perm_int4_ptr = reinterpret_cast<int4 const*>(perm_ptr);\n\n    if (threadIdx.x < perm_size) {\n      sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x];\n    }\n    __syncthreads();\n  };\n\n  auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {\n    if (n_tile_id >= n_tiles) {\n      cp_async_fence();\n      return;\n    }\n\n    int first_n = n_tile_id * tile_n_size;\n\n    int4* sh_ptr = sh_pipe_ptr + stage_size * pipe;\n\n    if constexpr (has_perm) {\n      if (threadIdx.x < stage_size) {\n        int k_id = threadIdx.x / stage_n_threads;\n        int n_id = threadIdx.x % stage_n_threads;\n\n        uint32_t const* sh_perm_int_ptr =\n            reinterpret_cast<uint32_t const*>(sh_perm_ptr);\n\n        int src_k = sh_perm_int_ptr[k_id];\n        int src_k_packed = src_k / pack_factor;\n\n        cp_async4(\n            &sh_ptr[k_id * stage_n_threads + n_id],\n            reinterpret_cast<int4 const*>(&(\n                b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));\n      }\n\n    } else {\n      if (threadIdx.x < stage_size) {\n        int k_id = threadIdx.x / stage_n_threads;\n        int n_id = threadIdx.x % stage_n_threads;\n\n        int first_k = k_tile_id * tile_k_size;\n        int first_k_packed = first_k / pack_factor;\n\n        cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],\n        
          reinterpret_cast<int4 const*>(\n                      &(b_q_weight_ptr[(first_k_packed + k_id) * size_n +\n                                       first_n + (n_id * 4)])));\n      }\n    }\n\n    cp_async_fence();\n  };\n\n  auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {\n    if (n_tile_id >= n_tiles) {\n      return;\n    }\n\n    int warp_id = threadIdx.x / 32;\n    int th_id = threadIdx.x % 32;\n\n    if (warp_id >= 4) {\n      return;\n    }\n\n    int tc_col = th_id / 4;\n    int tc_row = (th_id % 4) * 2;\n\n    constexpr int tc_offsets[4] = {0, 1, 8, 9};\n\n    int cur_n = warp_id * 16 + tc_col;\n\n    constexpr int sh_stride = 64;\n    constexpr uint32_t mask = (1 << num_bits) - 1;\n\n    int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;\n    uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);\n\n    uint32_t* sh_perm_int_ptr = reinterpret_cast<uint32_t*>(sh_perm_ptr);\n\n    uint32_t vals[8];\n\n    if constexpr (has_perm) {\n      for (int i = 0; i < 4; i++) {\n        int k_idx = tc_row + tc_offsets[i];\n\n        uint32_t src_k = sh_perm_int_ptr[k_idx];\n        uint32_t src_k_pos = src_k % pack_factor;\n\n        uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];\n        uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask;\n\n        uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];\n        uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask;\n\n        vals[i] = b1_cur_val;\n        vals[4 + i] = b2_cur_val;\n      }\n\n    } else {\n      uint32_t b1_vals[tile_ints];\n      uint32_t b2_vals[tile_ints];\n\n  #pragma unroll\n      for (int i = 0; i < tile_ints; i++) {\n        b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];\n        b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];\n      }\n\n  #pragma unroll\n      for (int i = 0; i < 4; i++) {\n        int cur_elem = tc_row + tc_offsets[i];\n        int cur_int = cur_elem / 
pack_factor;\n        int cur_pos = cur_elem % pack_factor;\n\n        vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;\n        vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;\n      }\n    }\n\n    constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;\n    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;\n\n    // Result of:\n    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h\n    if constexpr (num_bits == 4) {\n      constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};\n\n      uint32_t res = 0;\n  #pragma unroll\n      for (int i = 0; i < 8; i++) {\n        res |= vals[pack_idx[i]] << (i * 4);\n      }\n\n      out_ptr[out_offset + th_id * 4 + warp_id] = res;\n\n    } else {\n      constexpr int pack_idx[4] = {0, 2, 1, 3};\n\n      uint32_t res1 = 0;\n      uint32_t res2 = 0;\n  #pragma unroll\n      for (int i = 0; i < 4; i++) {\n        res1 |= vals[pack_idx[i]] << (i * 8);\n        res2 |= vals[4 + pack_idx[i]] << (i * 8);\n      }\n\n      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;\n      out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;\n    }\n  };\n\n  auto start_pipes = [&](int k_tile_id, int n_tile_id) {\n  #pragma unroll\n    for (int pipe = 0; pipe < repack_stages - 1; pipe++) {\n      fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);\n    }\n\n    wait_for_stage();\n  };\n  #pragma unroll\n  for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {\n    int n_tile_id = 0;\n\n    if constexpr (has_perm) {\n      load_perm_to_shared(k_tile_id);\n    }\n\n    start_pipes(k_tile_id, n_tile_id);\n\n    while (n_tile_id < n_tiles) {\n  #pragma unroll\n      for (int pipe = 0; pipe < repack_stages; pipe++) {\n        fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,\n                        n_tile_id + pipe + repack_stages 
- 1);\n        repack_tile(pipe, k_tile_id, n_tile_id + pipe);\n        wait_for_stage();\n      }\n      n_tile_id += repack_stages;\n    }\n  }\n}\n\n}  // namespace gptq_marlin\n\n  #define CALL_IF(NUM_BITS, HAS_PERM)                                          \\\n    else if (num_bits == NUM_BITS && has_perm == HAS_PERM) {                   \\\n      cudaFuncSetAttribute(                                                    \\\n          gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads,       \\\n                                            NUM_BITS, HAS_PERM>,               \\\n          cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);        \\\n      gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, NUM_BITS, \\\n                                        HAS_PERM>                              \\\n          <<<blocks, gptq_marlin::repack_threads, max_shared_mem, stream>>>(   \\\n              b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);              \\\n    }\n\ntorch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,\n                                 int64_t size_k, int64_t size_n,\n                                 int64_t num_bits) {\n  // Verify compatibility with marlin tile of 16x64\n  TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, \"size_k = \", size_k,\n              \" is not divisible by tile_k_size = \", gptq_marlin::tile_k_size);\n  TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, \"size_n = \", size_n,\n              \" is not divisible by tile_n_size = \", gptq_marlin::tile_n_size);\n\n  TORCH_CHECK(num_bits == 4 || num_bits == 8,\n              \"num_bits must be 4 or 8. 
Got = \", num_bits);\n  int const pack_factor = 32 / num_bits;\n\n  // Verify B\n  TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0),\n              \"Shape mismatch: b_q_weight.size(0) = \", b_q_weight.size(0),\n              \", size_k = \", size_k, \", pack_factor = \", pack_factor);\n  TORCH_CHECK(b_q_weight.size(1) == size_n,\n              \"b_q_weight.size(1) = \", b_q_weight.size(1),\n              \" is not size_n = \", size_n);\n\n  // Verify device and strides\n  TORCH_CHECK(b_q_weight.device().is_cuda(), \"b_q_weight is not on GPU\");\n  TORCH_CHECK(b_q_weight.is_contiguous(), \"b_q_weight is not contiguous\");\n  TORCH_CHECK(b_q_weight.dtype() == at::kInt, \"b_q_weight type is not kInt\");\n\n  TORCH_CHECK(perm.device().is_cuda(), \"perm is not on GPU\");\n  TORCH_CHECK(perm.is_contiguous(), \"perm is not contiguous\");\n  TORCH_CHECK(perm.dtype() == at::kInt, \"perm type is not at::kInt\");\n\n  // Alloc buffers\n  const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));\n  auto options = torch::TensorOptions()\n                     .dtype(b_q_weight.dtype())\n                     .device(b_q_weight.device());\n  torch::Tensor out =\n      torch::empty({size_k / gptq_marlin::tile_size,\n                    size_n * gptq_marlin::tile_size / pack_factor},\n                   options);\n\n  // Detect if there is act_order\n  bool has_perm = perm.size(0) != 0;\n\n  // Get ptrs\n  uint32_t const* b_q_weight_ptr =\n      reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());\n  uint32_t const* perm_ptr = reinterpret_cast<uint32_t const*>(perm.data_ptr());\n  uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());\n\n  // Get dev info\n  int dev = b_q_weight.get_device();\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);\n  int blocks;\n  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);\n\n  int max_shared_mem = 0;\n  cudaDeviceGetAttribute(&max_shared_mem,\n                         
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);\n  TORCH_CHECK(max_shared_mem > 0);\n\n  if (false) {\n  }\n  CALL_IF(4, false)\n  CALL_IF(4, true)\n  CALL_IF(8, false)\n  CALL_IF(8, true)\n  else {\n    TORCH_CHECK(false, \"Unsupported repack config: num_bits = \", num_bits,\n                \", has_perm = \", has_perm);\n  }\n\n  return out;\n}\n\n#endif"
  },
  {
    "path": "kt-sft/csrc/custom_marlin/gptq_marlin/ops.h",
    "content": "/**\n * @Description  :\n * @Author       : Azure\n * @Date         : 2024-07-22 09:27:55\n * @Version      : 1.0.0\n * @LastEditors  : Azure\n * @LastEditTime : 2024-07-26 08:35:00\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#pragma once\n\n#include <torch/extension.h>\n#include <torch/library.h>\n#include <torch/torch.h>\n\ntorch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,\n                               torch::Tensor &b_scales, torch::Tensor &g_idx,\n                               torch::Tensor &perm, torch::Tensor &workspace,\n                               int64_t num_bits, torch::Tensor size_m_tensor, int64_t size_m, int64_t size_n,\n                               int64_t size_k, int sms, bool is_k_full);\n\ntorch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor&perm,\n                                 int64_t size_k, int64_t size_n,\n                                 int64_t num_bits);"
  },
  {
    "path": "kt-sft/csrc/custom_marlin/setup.py",
    "content": "from setuptools import setup, Extension\nfrom torch.utils import cpp_extension\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\nsetup(\n    name='vLLMMarlin',\n    ext_modules=[\n        CUDAExtension(\n            'vLLMMarlin', [\n                #'custom_gguf/dequant.cu',\n                'binding.cpp',\n                'gptq_marlin/gptq_marlin.cu',\n                'gptq_marlin/gptq_marlin_repack.cu',\n            ],\n            extra_compile_args={\n                'cxx': ['-O3'],\n                'nvcc': [\n                    '-O3',\n                    '--use_fast_math',\n                    '-Xcompiler', '-fPIC',\n                ]\n            },\n        )\n    ],\n    cmdclass={'build_ext': BuildExtension}\n)"
  },
  {
    "path": "kt-sft/csrc/custom_marlin/test_cuda_graph.py",
    "content": "import csv\r\nimport torch\r\nimport torch.nn as nn\r\nimport vLLMMarlin\r\ntorch.set_grad_enabled(False)\r\nfrom utils.marlin_utils import (\r\n\tMarlinWorkspace,\r\n\tmarlin_quantize,\r\n\tGPTQ_MARLIN_MIN_THREAD_N,\r\n\tGPTQ_MARLIN_MIN_THREAD_K,\r\n\tGPTQ_MARLIN_MAX_PARALLEL,\r\n)\r\n\r\ndef setup_seed(seed):\r\n\ttorch.manual_seed(seed)\r\n\ttorch.cuda.manual_seed_all(seed)\r\n\r\nsetup_seed(20241223)\r\n\r\ntorch.set_grad_enabled(False)\r\ntorch.set_default_dtype(torch.bfloat16)\r\nglobal_dtype=torch.bfloat16\r\nglobal_device=torch.device(\"cuda\",0)\r\nglobal_num_cases:int=int(50)\r\ntorch.cuda.set_device(0)\r\ntorch.backends.cudnn.enabled =True\r\ntorch.backends.cudnn.benchmark = True\r\n\r\nmax_batch_size = 512\r\nmax_tp = 8\r\nL2_size = 73728 * 1024\r\n\r\ndef get_usable_mem():\r\n\tproperties = torch.cuda.get_device_properties(global_device)\r\n\t#print(f\"Total memory: {properties.total_memory / (1024 ** 3):.2f} GB\")\r\n\tallocated_memory = torch.cuda.memory_allocated(global_device)\r\n\t#print(f\"Currently allocated memory: {allocated_memory / (1024 ** 2):.2f} MB\")\r\n\treserved_memory = torch.cuda.memory_reserved(global_device)\r\n\t#print(f\"Currently reserved memory: {reserved_memory / (1024 ** 2):.2f} MB\")\r\n\treturn properties.total_memory - 512 * 1024 ** 2 - allocated_memory# - reserved_memory\r\n\r\ndef exp_range(start, stop, step = 2):\r\n\tnow = start\r\n\twhile now <= stop:\r\n\t\tyield now\r\n\t\tnow *= step\r\n\r\ndef timing(func, iters, epochs=100):\r\n\t#warmup\r\n\tfor idx in range(iters):\r\n\t\tfunc(idx)\r\n\t\t\r\n\ttorch.cuda.synchronize()\r\n\tcuda_graph = torch.cuda.CUDAGraph()\r\n\twith torch.cuda.graph(cuda_graph):\r\n\t\tfor idx in range(iters):\r\n\t\t\tfunc(idx)\r\n\r\n\tfor _ in range(2000):\r\n\t\tcuda_graph.replay()\r\n\r\n\tstart_event = torch.cuda.Event(enable_timing=True)\r\n\tend_event = torch.cuda.Event(enable_timing=True)\r\n\tstream = torch.cuda.Stream()\r\n\ttorch.cuda.synchronize()\r\n\t#with 
torch.cuda.stream(stream):\r\n\tstart_event.record()\r\n\tfor _ in range(10):\r\n\t\tcuda_graph.replay()\r\n\tend_event.record()\r\n\ttorch.cuda.synchronize()\r\n\telapsed_time_ms0 = start_event.elapsed_time(end_event)\r\n\t\r\n\tstart_event = torch.cuda.Event(enable_timing=True)\r\n\tend_event = torch.cuda.Event(enable_timing=True)\r\n\ttorch.cuda.synchronize()\r\n\t#with torch.cuda.stream(stream):\r\n\tstart_event.record()\r\n\tfor _ in range(epochs+10):\r\n\t\tcuda_graph.replay()\r\n\tend_event.record()\r\n\ttorch.cuda.synchronize()\r\n\telapsed_time_ms = start_event.elapsed_time(end_event) - elapsed_time_ms0\r\n\t\r\n\t#print(elapsed_time_ms0, elapsed_time_ms)\r\n\treturn elapsed_time_ms/iters/epochs\r\n\r\nclass LinearMarlin(nn.Linear):\r\n\tmarlin_q_w: torch.Tensor\r\n\tmarlin_s: torch.Tensor\r\n\tg_idx: torch.Tensor\r\n\tsort_indices: torch.Tensor\r\n\thas_bias: bool\r\n\tdef __init__(\r\n\t\tself,\r\n\t\tin_features,\r\n\t\tout_features,\r\n\t\tbias = False,\r\n\t\tdevice: str = \"cuda\",\r\n\t\tnum_bits: int = 4,  # 4-bit/8-bit is supported\r\n\t\tgroup_size: int = 64,  # -1, 32, 64, 128\r\n\t\tact_order: bool = False,\r\n\t\tis_k_full=True,\r\n\t\tsms = -1, # sms in GPU\r\n\t\t**kwargs,\r\n\t):\r\n\t\tself.padding = False\r\n\t\tassert device.lower() != \"cpu\", \"Marlin quantized linear only supports GPU device\"\r\n\t\tif in_features%GPTQ_MARLIN_MIN_THREAD_K!=0 or out_features%GPTQ_MARLIN_MIN_THREAD_N!=0:\r\n\t\t\t#print(f\"warning!, in_features={in_features} or out_features={out_features} is undivisible by GPTQ_MARLIN_MIN_THREAD_K={GPTQ_MARLIN_MIN_THREAD_K} and GPTQ_MARLIN_MIN_THREAD_N={GPTQ_MARLIN_MIN_THREAD_N}, padding\")\r\n\t\t\tself.padding = True\r\n\t\t\tself.orin_in_features = in_features\r\n\t\t\tself.orin_out_features = out_features\r\n\t\t\tin_features = (in_features+GPTQ_MARLIN_MIN_THREAD_K-1)//GPTQ_MARLIN_MIN_THREAD_K*GPTQ_MARLIN_MIN_THREAD_K\r\n\t\t\tout_features = 
(out_features+GPTQ_MARLIN_MIN_THREAD_N-1)//GPTQ_MARLIN_MIN_THREAD_N*GPTQ_MARLIN_MIN_THREAD_N\r\n\t\t\t#print(f\"After padding: in_features={in_features}, out_features={out_features}\")\r\n\t\t\t\r\n\r\n\t\tsuper().__init__(in_features, out_features, bias, device)\r\n\t\tself.has_bias = bias\r\n\t\tself.device = device\r\n\t\tself.num_bits = num_bits\r\n\t\tself.group_size = group_size\r\n\t\tself.act_order = act_order\r\n\t\t# TODO: optimize every shape GEMM\r\n\t\t\r\n\t\tblocks_k, blocks_n = in_features//128, out_features//128\r\n\r\n\t\tself.sms = sms\r\n\r\n\t\tself.is_k_full = is_k_full\r\n\t\t\r\n\t\tself.weight.requires_grad = False\r\n\t\tself.weight.t_()\r\n\t\t# Pack Marlin linear\r\n\t\t#w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(\r\n\t\t#    self.weight, self.num_bits, self.group_size, self.act_order\r\n\t\t#)\r\n\t\tmarlin_q_w = torch.randint(int(-1e9), int(1e9), (in_features//16, out_features*2), device=device, dtype=torch.int)\r\n\t\tmarlin_s = torch.randn((in_features//64, out_features), device=device)\r\n\t\tself.workspace = MarlinWorkspace(\r\n\t\t\tself.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL, self.device\r\n\t\t)\r\n\t\tself.marlin_q_w = marlin_q_w\r\n\t\tself.marlin_s = marlin_s\r\n\t\tself.g_idx = torch.empty((0), dtype=torch.int32, device=self.device)\r\n\t\tself.sort_indices = torch.empty((0), dtype=torch.int32, device=self.device)\r\n\t\tself.k = self.weight.shape[0]\r\n\t\tself.n = self.weight.shape[1]\r\n\t\tself.weight = None\r\n\t\t\"\"\"\r\n\t\tprint(in_features, 
out_features)\r\n\t\tprint(marlin_q_w.shape)\r\n\t\tprint(marlin_q_w.dtype)\r\n\t\tprint(marlin_s.shape)\r\n\t\tprint(marlin_s.dtype)\r\n\t\tprint(self.workspace.scratch.shape)\r\n\t\tprint(self.workspace.scratch.dtype)\r\n\t\tprint(self.g_idx.shape)\r\n\t\tprint(self.g_idx.dtype)\r\n\t\tprint(self.sort_indices.shape)\r\n\t\tprint(self.sort_indices.dtype)\r\n\t\t#print(w_ref.shape)\r\n\t\t#print(w_ref.dtype)\r\n\t\t\"\"\"\r\n\t\t#w_ref = None\r\n\r\n\tdef forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor) -> torch.Tensor:\r\n\t\t# Only support input x as BF16 and FP16\r\n\t\tx = x.to(self.device)\r\n\t\torig_shape = list(x.shape)\r\n\t\torig_dtype = x.dtype\r\n\t\tx = x.reshape(-1, x.shape[-1])\r\n\t\tif self.padding:\r\n\t\t\tpadding_input=torch.empty(x.shape[0], self.in_features, device=x.device, dtype=x.dtype)\r\n\t\t\tpadding_input[:,:self.orin_in_features] = x\r\n\t\t\tx = padding_input\r\n\t\tmarlin_s = self.marlin_s.to(x.dtype)\r\n\t\t#print(self.sms * ((orig_shape[0]+63)//64))\r\n\t\t\r\n\t\tsms = self.sms\r\n\r\n\t\tx = vLLMMarlin.gptq_marlin_gemm(\r\n\t\t\tx,\r\n\t\t\tself.marlin_q_w,\r\n\t\t\tmarlin_s,\r\n\t\t\tself.g_idx,\r\n\t\t\tself.sort_indices,\r\n\t\t\tself.workspace.scratch,\r\n\t\t\tself.num_bits,\r\n\t\t\tbsz_tensor,\r\n\t\t\tx.shape[0],\r\n\t\t\tself.n,\r\n\t\t\tx.shape[-1],\r\n\t\t\tsms,\r\n\t\t\tself.is_k_full,\r\n\t\t)\r\n\t\t# TODO: don't padding bias\r\n\t\tif self.has_bias:\r\n\t\t\tx = x + self.bias\r\n\t\tif self.padding:\r\n\t\t\tx = x[:,:self.orin_out_features]\r\n\t\t\torig_shape[-1] = self.orin_out_features\r\n\t\telse:\r\n\t\t\torig_shape[-1] = self.out_features\r\n\t\treturn x.reshape(orig_shape).to(orig_dtype)\r\n\r\ndef benchLinearMarlin(input_dim, output_dim):#, out_file\r\n\tprint(\"benchmarking MLP Marlin\")\r\n\tprint(\"-----------------------------------------------------------\")\r\n\theaders = [\"batch_size\", \"tp\", \"used_time\", \"bandwidth GB/s\", \"TFLOPS\", \"cases\", \"padding\", \"sms\"]\r\n\tprint(\" | 
\".join(headers) + \"\\n\")\r\n\trows = []\r\n\tfor batch_size in exp_range(1, 64):\r\n\t\tfor tp in exp_range(1, max_tp):\r\n\t\t\ttorch.cuda.empty_cache()\r\n\t\t\tif output_dim % tp != 0:\r\n\t\t\t\tcontinue\r\n\t\t\tcur_output_dim = output_dim // tp\r\n\t\t\tmodules = []\r\n\t\t\tinputs = []\r\n\t\t\tdata_size = int(0.53125*input_dim*cur_output_dim)\r\n\t\t\tinput_size = int(2*batch_size*input_dim)\r\n\t\t\toutput_size = int(2*batch_size*cur_output_dim)\r\n\t\t\tusable_mem = get_usable_mem() - 2 * input_dim * cur_output_dim\r\n\t\t\tmin_cases = max(global_num_cases, (2*L2_size) // (data_size+input_size))\r\n\t\t\tcases = int(min(min_cases, (usable_mem * 0.8) // (data_size+input_size)))\r\n\t\t\t#print(usable_mem, data_size, input_size, cases)\r\n\t\t\t\t\r\n\t\t\tbsz_tensor = torch.tensor([batch_size], device=global_device, dtype=torch.int32)\r\n\r\n\t\t\tif cases == 0:\r\n\t\t\t\trow = [f\"{batch_size}\", \"OOM\", \"OOM\", \"OOM\", \"0\", \"False\"]\r\n\t\t\t\trows.append(row)\r\n\t\t\t\tbreak\r\n\t\t\tfor _ in range(cases):\r\n\t\t\t\tmodules.append(LinearMarlin(input_dim, cur_output_dim, sms=56, non_equal_division=False).to(device=global_device).eval())\r\n\t\t\t\tinputs.append(torch.randn(batch_size, 1, input_dim, device=global_device))\r\n\t\t\t\t\r\n\t\t\tdef forward(case_id):\r\n\t\t\t\tmodules[case_id](inputs[case_id], bsz_tensor)\r\n\t\t\t\t\r\n\t\t\tused_time = timing(forward, iters=cases)\r\n\t\t\tbandwidth = (data_size+input_size+output_size)/used_time/1e6\r\n\t\t\tflops = 2*batch_size*input_dim*cur_output_dim\r\n\t\t\ttflops = flops/used_time/1e9\r\n\t\t\tcur_sms = modules[0].sms\r\n\t\t\trow = [f\"{batch_size}\", f\"{tp}\", f\"{used_time}\", f\"{bandwidth}\", f\"{tflops}\", f\"{cases}\", modules[0].padding, cur_sms]\r\n\t\t\trows.append(row)\r\n\t\t\tprint(f\"{batch_size}\", f\"{tp}\", f\"{used_time}\", f\"{bandwidth}\", f\"{tflops}\", f\"{cases}\", modules[0].padding, cur_sms)\r\n\t\r\n\t\"\"\"\r\n\twith open(out_file, 'w', newline='') as 
csvfile:\r\n\t\tcsvwriter = csv.writer(csvfile)\r\n\t\tcsvwriter.writerow(headers)\r\n\t\tfor row in rows:\r\n\t\t\tcsvwriter.writerow(row)\r\n\t\"\"\"\r\n\t\r\n\t\"\"\"\r\n\tmarkdown_table = \" | \".join(headers) + \"\\n\"\r\n\tmarkdown_table += \" | \".join([\"---\"] * len(headers)) + \"\\n\"\r\n\tfor row in rows:\r\n\t\tmarkdown_table += \" | \".join(row) + \"\\n\"\r\n\r\n\tprint(markdown_table)\r\n\t\"\"\"\r\n\t#print(\"finish write file\", out_file)\r\n\t#print(\"-------------------------------------------------------------\")\r\n\r\nif __name__ == \"__main__\":\r\n\t\r\n\tbenchLinearMarlin(5120, 3584)\r\n\texit(0)\r\n\t\r\n\tmax_batch = 1\r\n\tcur_batch = 1\r\n\r\n\r\n\tmarlin_linear = LinearMarlin(5120, 3584)\r\n\r\n\tinput_tensor = torch.randn(max_batch, 1, 5120, device=\"cuda\", dtype=torch.bfloat16)\r\n\tbsz_tensor = torch.tensor([max_batch], device=\"cuda\", dtype=torch.int32)\r\n\r\n\tout_truth = marlin_linear(input_tensor, bsz_tensor)\r\n\r\n\tprint(out_truth)\r\n\r\n\tg = torch.cuda.CUDAGraph()\r\n\twith torch.cuda.graph(g):\r\n\t\tout_buf = marlin_linear(input_tensor, bsz_tensor)\r\n\t\r\n\tfor i in range(10000):\r\n\t\tg.replay()\r\n\t\r\n\t#torch.testing.assert_close(out_buf, out_truth, rtol=1e-3, atol=1e-3)\r\n\t\r\n\tmarlin_linear = LinearMarlin(5120, 3584)\r\n\tg = torch.cuda.CUDAGraph()\r\n\twith torch.cuda.graph(g):\r\n\t\tout_buf = marlin_linear(input_tensor, bsz_tensor)\r\n\t\r\n\tnew_input = torch.randn(cur_batch, 1, 5120, device=\"cuda\", dtype=torch.bfloat16)\r\n\tbsz_tensor.copy_(torch.tensor([cur_batch], device=\"cuda\", dtype=torch.int32))\r\n\t\r\n\tnew_out_truth = marlin_linear(new_input, bsz_tensor)\r\n\tinput_tensor[:cur_batch].copy_(new_input)\r\n\tinput_tensor[cur_batch:] = 0\r\n\t\r\n\tg.replay()\r\n\t\r\n\ttorch.cuda.synchronize()\r\n\r\n\tdef printMinMax(tensor):\r\n\t\tabs_tensor = torch.abs(tensor)\r\n\r\n\t\tmin_val = torch.min(abs_tensor)\r\n\t\tmax_val = torch.max(abs_tensor)\r\n\r\n\t\tmin_indices = (abs_tensor == 
min_val).nonzero(as_tuple=True)\r\n\t\tmax_indices = (abs_tensor == max_val).nonzero(as_tuple=True)\r\n\r\n\t\tprint(f\"min: {min_val.item()}\")\r\n\t\tprint(f\"min idx: {min_indices}\")\r\n\t\tprint(f\"max: {max_val.item()}\")\r\n\t\tprint(f\"max idx: {max_indices}\")\r\n\r\n\tprint(out_buf[:cur_batch].shape)\r\n\tprint(new_out_truth.shape)\r\n\r\n\r\n\tprintMinMax(out_buf[:cur_batch])\r\n\tprintMinMax(new_out_truth)\r\n\r\n\t#torch.testing.assert_close(out_buf[:cur_batch, 0, :], new_out_truth[:cur_batch, 0, :], rtol=1e-3, atol=1e-3)\r\n"
  },
  {
    "path": "kt-sft/csrc/custom_marlin/utils/__init__.py",
    "content": ""
  },
  {
    "path": "kt-sft/csrc/custom_marlin/utils/format24.py",
    "content": "#\n# Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es).\n#\n\nimport torch\n\n\n# This is PyTorch implementation of main part of reorder_meta()\n# function, from tools/util/include/cutlass/util/host_reorder.h file\n# of CUTLASS source tree.  Furthermore, CUTLASS template for sparse\n# GEMM decides upon layout of this matrix, and at the moment for the\n# sparse GEMM executed on tensor cores, this is layout described by\n# ColumnMajorInterleaved<2> data structure, in\n# include/cutlass/layout/matrix.h of CUTLASS source tree.  The\n# reordering of meta matrix into meta_reordered matrix calculated\n# according to these segments of CUTLASS code is re-implemented here.\n# Note that this calculation produces offsets for scattering metadata\n# matrix elements into reordered metadata matrix elements (or,\n# equivalently, for gathering reordered metadata matrix element back\n# into metadata matrix elements).\ndef _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype,\n                                               device):\n    dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)\n    dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)\n\n    # Reorder the rows, then swizzle the 2x2 blocks.\n    group_x = 64\n    group_y = 32 if meta_dtype.itemsize == 2 else 16\n\n    dst_rows = (dst_rows // group_x * group_x + (dst_rows % 2) * 2 +\n                (dst_rows % 8) // 4 + ((dst_rows % group_y) % 4) // 2 * 32 +\n                ((dst_rows % group_x) // 8) * 4)\n\n    topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)\n    bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)\n    dst_rows += topright - bottomleft\n    dst_cols -= topright - bottomleft\n\n    # Assumed that meta tensor is to be stored in CUTLASS\n    # InterleavedColumnMajor layout, and reverse engineered\n    # corresponding code to store values into this tensor.\n    interleave = 2\n    
cols_maj = dst_cols // interleave\n    cols_min = dst_cols % interleave\n    return (cols_maj * m * interleave + dst_rows * interleave +\n            cols_min).view(-1)\n\n\n# This function converts dense matrix into sparse semi-structured\n# representation, producing \"compressed\" matrix, in the layout used by\n# CUTLASS backend, and corresponding metadata matrix.\ndef sparse_semi_structured_from_dense_cutlass(dense):\n    if dense.dim() != 2:\n        raise RuntimeError(\n            f\"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor\"  # noqa: E501\n        )\n\n    m, k = dense.shape\n    device = dense.device\n\n    meta_dtype = torch.int8\n    if dense.dtype == torch.int8:\n        meta_dtype = torch.int32\n    elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:\n        meta_dtype = torch.int16\n    else:\n        raise RuntimeError(f\"Invalid datatype {dense.dtype} of dense matrix\")\n    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4\n    if quadbits_per_meta_elem not in (4, 8):\n        raise RuntimeError(\n            \"Invalid number of elements per meta element calculated\")\n\n    if meta_dtype == torch.int32:\n        if m % 16 != 0:\n            raise RuntimeError(\n                f\"Number of rows of dense matrix {m} must be divisible by 16\")\n    else:\n        if m % 32 != 0:\n            raise RuntimeError(\n                f\"Number of rows of dense matrix {m} must be divisible by 32\")\n    if k % (4 * quadbits_per_meta_elem) != 0:\n        raise RuntimeError(\n            f\"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}\"  # noqa: E501\n        )\n\n    if dense.dtype != torch.float:\n        ksparse = 4\n        dense_4 = dense.view(-1, k // ksparse, ksparse)\n        m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)\n    else:\n        ksparse = 2\n        dense_2 = dense.view(-1, k // ksparse, ksparse)\n        m0, m2 = m1, m3 = (dense_2 != 
0).unbind(-1)\n    meta_ncols = k // (ksparse * quadbits_per_meta_elem)\n\n    # Encoding quadruples of True/False values as follows:\n    #     [True,  True,  False, False] -> 0b0100\n    #     [True,  False, True,  False] -> 0b1000\n    #     [False, True,  True,  False] -> 0b1001\n    #     [True,  False, False, True ] -> 0b1100\n    #     [False, True,  False, True ] -> 0b1101\n    #     [False, False, True,  True ] -> 0b1110\n    # Thus, lower two bits in the encoding are index of the True value\n    # at the lowest index in the quadruple, and the higher two bits in\n    # the encoding are index of the other True value in the quadruple.\n    # In case there are fewer than two True values, then False value or\n    # values at some index or indices are considered True for the\n    # encoding.  In case there are more than two True values, then the\n    # excess True value(s) at some indices are considered False for\n    # the encoding.  The exact encodings used for these cases are as\n    # follows:\n    #     [False, False, False, False] -> 0b1110\n    #     [False, False, False, True ] -> 0b1110\n    #     [False, False, True,  False] -> 0b1110\n    #     [False, True,  False, False] -> 0b1001\n    #     [False, True,  True,  True ] -> 0b1101\n    #     [True,  False, False, False] -> 0b1000\n    #     [True,  False, True,  True ] -> 0b1100\n    #     [True,  True,  False, True ] -> 0b0100\n    #     [True,  True,  True,  False] -> 0b0100\n    #     [True,  True,  True,  True ] -> 0b0100\n    # These particular encodings are chosen, with the help of Espresso\n    # logic minimizer software, for the purpose of minimization of\n    # corresponding Boolean functions, that translate non-zero flags\n    # into encoding bits.  
Note also possible choices for the first\n    # and last of these encodings were limited only to (0b0100,\n    # 0b1110), in order to produce valid encodings for 1:2 sparsity\n    # case.\n\n    expr0 = m0 & m1\n    expr1 = ~m0 & m1\n    expr2 = ~m0 & ~m1\n    bit0 = expr1\n    bit1 = expr2\n    bit2 = expr0 | expr2 | m3\n    bit3 = expr1 | ~m1\n    idxs0 = bit0 | (bit1.to(torch.int64) << 1)\n    idxs1 = bit2 | (bit3.to(torch.int64) << 1)\n\n    if dense.dtype != torch.float:\n        sparse0 = dense_4.gather(\n            -1, idxs0.unsqueeze(-1))  # type: ignore[possibly-undefined]\n        sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))\n        sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)\n    else:\n        sparse = dense_2.gather(-1,\n                                idxs0.unsqueeze(-1) // 2).view(\n                                    m,\n                                    k // 2)  # type: ignore[possibly-undefined]\n\n    meta_4 = idxs0 | (idxs1 << 2)\n    meta_n = meta_4.view(\n        (-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)\n\n    if quadbits_per_meta_elem == 4:\n        meta = (meta_n[:, :, 0]\n                | (meta_n[:, :, 1] << 4)\n                | (meta_n[:, :, 2] << 8)\n                | (meta_n[:, :, 3] << 12))\n    elif quadbits_per_meta_elem == 8:\n        meta = (meta_n[:, :, 0]\n                | (meta_n[:, :, 1] << 4)\n                | (meta_n[:, :, 2] << 8)\n                | (meta_n[:, :, 3] << 12)\n                | (meta_n[:, :, 4] << 16)\n                | (meta_n[:, :, 5] << 20)\n                | (meta_n[:, :, 6] << 24)\n                | (meta_n[:, :, 7] << 28))\n\n    # Reorder meta tensor elements.\n    meta_reordered = meta.new_empty(\n        (m * meta_ncols, ))  # type: ignore[possibly-undefined]\n    meta_offsets = _calculate_meta_reordering_scatter_offsets(\n        m, meta_ncols, meta_dtype, device)\n    meta_reordered.scatter_(0, meta_offsets, meta.view(-1))\n\n    return (sparse, 
meta_reordered.view(m, meta_ncols))\n\n\n# This function performs reverse of the function above - it\n# reconstructs dense matrix from a pair of \"compressed\" matrix, given\n# in the layout used by CUTLASS backend, and accompanying metadata\n# matrix.\ndef sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):\n    if sparse.dim() != 2:\n        raise RuntimeError(\n            f\"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor\"  # noqa: E501\n        )\n\n    m, k = sparse.shape\n    device = sparse.device\n\n    if meta_reordered.dim() != 2:\n        raise RuntimeError(\n            f\"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor\"  # noqa: E501\n        )\n    if meta_reordered.device != device:\n        raise RuntimeError(\n            f\"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device\"  # noqa: E501\n        )\n\n    meta_dtype = meta_reordered.dtype\n    if meta_dtype not in (torch.int16, torch.int32):\n        raise RuntimeError(f\"Invalid datatype {meta_dtype} of meta matrix\")\n    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4\n\n    ksparse = 4 if sparse.dtype != torch.float else 2\n\n    meta_nrows, meta_ncols = meta_reordered.shape\n    if meta_nrows != m:\n        raise RuntimeError(\n            f\"Number of rows of meta matrix {meta_nrows} must be equal to number of rows of sparse matrix {m}\"  # noqa: E501\n        )\n    if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:\n        raise RuntimeError(\n            f\"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, \"  # noqa: E501\n            \"expected according to the number of columns of meta matrix\")\n\n    # Undo meta tensor elements reordering.\n    meta_offsets = _calculate_meta_reordering_scatter_offsets(\n        m, meta_ncols, meta_dtype, device)\n    meta = 
torch.gather(meta_reordered.view(-1), 0,\n                        meta_offsets).view(m, meta_ncols)\n\n    # Unpack sparse tensor back to original dense tensor, using\n    # information provided by meta tensor.  Note that torch.float\n    # datatype is handled pretty much the same as\n    # torch.half/torch.bfloat16, as metadata for a pair of torch.float\n    # value is encoded as if underlying 8 bytes contain four\n    # torch.half/torch.bfloat16 values, where either first two or last\n    # two are zeros.\n    meta_2 = torch.empty(\n        (m, meta_ncols, 2 * quadbits_per_meta_elem),\n        dtype=meta_dtype,\n        device=device,\n    )\n    if quadbits_per_meta_elem == 4:\n        meta_2[:, :, 0] = meta & 0b11\n        meta_2[:, :, 1] = (meta >> 2) & 0b11\n        meta_2[:, :, 2] = (meta >> 4) & 0b11\n        meta_2[:, :, 3] = (meta >> 6) & 0b11\n        meta_2[:, :, 4] = (meta >> 8) & 0b11\n        meta_2[:, :, 5] = (meta >> 10) & 0b11\n        meta_2[:, :, 6] = (meta >> 12) & 0b11\n        meta_2[:, :, 7] = (meta >> 14) & 0b11\n    elif quadbits_per_meta_elem == 8:\n        meta_2[:, :, 0] = meta & 0b11\n        meta_2[:, :, 1] = (meta >> 2) & 0b11\n        meta_2[:, :, 2] = (meta >> 4) & 0b11\n        meta_2[:, :, 3] = (meta >> 6) & 0b11\n        meta_2[:, :, 4] = (meta >> 8) & 0b11\n        meta_2[:, :, 5] = (meta >> 10) & 0b11\n        meta_2[:, :, 6] = (meta >> 12) & 0b11\n        meta_2[:, :, 7] = (meta >> 14) & 0b11\n        meta_2[:, :, 8] = (meta >> 16) & 0b11\n        meta_2[:, :, 9] = (meta >> 18) & 0b11\n        meta_2[:, :, 10] = (meta >> 20) & 0b11\n        meta_2[:, :, 11] = (meta >> 22) & 0b11\n        meta_2[:, :, 12] = (meta >> 24) & 0b11\n        meta_2[:, :, 13] = (meta >> 26) & 0b11\n        meta_2[:, :, 14] = (meta >> 28) & 0b11\n        meta_2[:, :, 15] = (meta >> 30) & 0b11\n\n    dense_offsets = meta_2.view(-1) + (\n        torch.arange(0, 2 * m * k // ksparse, device=device) * 4).view(\n            -1, 1).repeat(1, 2).view(-1)\n\n 
   dense = torch.zeros((m * 2 * k, ), dtype=sparse.dtype, device=device)\n    if sparse.dtype != torch.float:\n        # dense.scatter_(0, dense_offsets, sparse.view(-1))\n        dense.scatter_(0, dense_offsets, sparse.reshape(-1))\n    else:\n        dense.view(torch.half).scatter_(0, dense_offsets,\n                                        sparse.view(torch.half).view(-1))\n\n    return dense.view(m, 2 * k)\n\n\ndef mask_creator(tensor):\n    \"\"\"\n    Class for creating N:M sparsity masks.\n    Masks will be created using the N:M ratio, where for every block of \n    M weights, N will be pruned based on ranked weight value. Each mask \n    will correspond to the given tensor.\n\n    :param N: The number of weights in a group to keep\n    :param M: The size of a weight group\n    \"\"\"\n    N = 2\n    M = 4\n\n    mask = None\n    # for i, tensor in enumerate(tensors):\n    if tensor.numel() % M != 0:\n        raise ValueError(\n            f\"Tensor of size {tensor.shape} can't be evenly divided into \"\n            f\"{M} groups\")\n\n    num_groups = tensor.numel() // M\n\n    # N:M sparsity for linear layers\n    tensor_temp = tensor.detach().abs().reshape(num_groups, M)\n    index = torch.argsort(tensor_temp, dim=1)[:, :int(M - N)]\n\n    w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device)\n    mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape)\n\n    return mask"
  },
  {
    "path": "kt-sft/csrc/custom_marlin/utils/marlin_24_perms.py",
    "content": "'''\nDate: 2024-11-08 02:46:07\nLastEditors: djw\nLastEditTime: 2024-11-08 02:46:41\n'''\n\"\"\"This file is used for /tests and /benchmarks\"\"\"\nfrom typing import Dict, List\n\nimport numpy\nimport torch\n\n\n# Precompute permutations for Marlin24 weight and scale shuffling # noqa: E501\n#\n# Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501\n# with the tensor-core format that is described here:\n# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501\n#\n# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501\n# (without the need to use ldmatrix instructions) # noqa: E501\ndef get_perms_24(num_bits: int):\n    perm_list: List[int] = []\n    for i in range(32):\n        perm1: List[int] = []\n        col = i // 4\n        col_o = col // 2\n        for block in [0, 1]:\n            for row in [\n                    2 * (i % 4),\n                    2 * (i % 4) + 1,\n                    2 * (i % 4 + 4),\n                    2 * (i % 4 + 4) + 1,\n            ]:\n                perm1.append(16 * row + col_o * 256 + 8 * (col % 2) +\n                             4 * block)\n        for j in range(4):\n            perm_list.extend([p + 1 * j for p in perm1])\n    perm = numpy.array(perm_list)\n\n    if num_bits == 4:\n        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])\n    elif num_bits == 8:\n        interleave = numpy.array([0, 2, 1, 3])\n    else:\n        raise ValueError(\"num_bits must be 4 or 8, got {}\".format(num_bits))\n\n    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()\n    perm = torch.from_numpy(perm)\n    scale_perm: List[int] = []\n    for i in range(8):\n        scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])\n    scale_perm_single: List[int] = []\n 
   for i in range(8):\n        scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])\n    return perm, scale_perm, scale_perm_single\n\n\nmarlin_24_perm: Dict[int, torch.Tensor] = {}\nmarlin_24_scale_perm: Dict[int, List[int]] = {}\nmarlin_24_scale_perm_single: Dict[int, List[int]] = {}\nfor num_bits in [4, 8]:\n    perm_24, scale_perm_24, scale_perm_single_24 = get_perms_24(num_bits)\n    marlin_24_perm[num_bits] = perm_24\n    marlin_24_scale_perm[num_bits] = scale_perm_24\n    marlin_24_scale_perm_single[num_bits] = scale_perm_single_24"
  },
  {
    "path": "kt-sft/csrc/custom_marlin/utils/marlin_perms.py",
    "content": "'''\nDate: 2024-11-08 02:46:47\nLastEditors: djw\nLastEditTime: 2024-11-08 02:46:55\n'''\n\"\"\"This file is used for /tests and /benchmarks\"\"\"\nfrom typing import Dict, List\n\nimport numpy\nimport torch\n\n\n# Precompute permutations for Marlin weight and scale shuffling # noqa: E501\n#\n# Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501\n# with the tensor-core format that is described here:\n# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501\n#\n# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501\n# (without the need to use ldmatrix instructions) # noqa: E501\ndef get_perms(num_bits: int):\n    perm_list: List[int] = []\n    for i in range(32):\n        perm1: List[int] = []\n        col = i // 4\n        for block in [0, 1]:\n            for row in [\n                    2 * (i % 4),\n                    2 * (i % 4) + 1,\n                    2 * (i % 4 + 4),\n                    2 * (i % 4 + 4) + 1,\n            ]:\n                perm1.append(16 * row + col + 8 * block)\n        for j in range(4):\n            perm_list.extend([p + 256 * j for p in perm1])\n\n    perm = numpy.array(perm_list)\n\n    if num_bits == 4:\n        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])\n    elif num_bits == 8:\n        interleave = numpy.array([0, 2, 1, 3])\n    else:\n        raise Exception(\"num_bits must be 4 or 8, got {}\".format(num_bits))\n\n    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()\n    perm = torch.from_numpy(perm)\n    scale_perm: List[int] = []\n    for i in range(8):\n        scale_perm.extend([i + 8 * j for j in range(8)])\n    scale_perm_single: List[int] = []\n    for i in range(4):\n        scale_perm_single.extend(\n            [2 * i + j for j in [0, 1, 8, 
9, 16, 17, 24, 25]])\n    return perm, scale_perm, scale_perm_single\n\n\nmarlin_perm: Dict[int, torch.Tensor] = {}\nmarlin_scale_perm: Dict[int, List[int]] = {}\nmarlin_scale_perm_single: Dict[int, List[int]] = {}\nfor num_bits in [4, 8]:\n    perm, scale_perm, scale_perm_single = get_perms(num_bits)\n    marlin_perm[num_bits] = perm\n    marlin_scale_perm[num_bits] = scale_perm\n    marlin_scale_perm_single[num_bits] = scale_perm_single"
  },
  {
    "path": "kt-sft/csrc/custom_marlin/utils/marlin_utils.py",
    "content": "\"\"\"This file is used for /tests and /benchmarks\"\"\"\nimport random\n\nimport numpy\nimport torch\n\nfrom .format24 import (\n    mask_creator, sparse_semi_structured_from_dense_cutlass)\nfrom .marlin_24_perms import (\n    marlin_24_perm, marlin_24_scale_perm, marlin_24_scale_perm_single)\nfrom .marlin_perms import (\n    marlin_perm, marlin_scale_perm, marlin_scale_perm_single)\nfrom .quant_utils import (\n    get_pack_factor, quantize_weights, sort_weights, dequantize_weights)\n\n\n\n__cuda_arch = torch.cuda.get_device_capability()\n\nMARLIN_TILE = 16\n\nGPTQ_MARLIN_TILE = 16\nGPTQ_MARLIN_MIN_THREAD_N = 64\nGPTQ_MARLIN_MIN_THREAD_K = 128\nGPTQ_MARLIN_MAX_PARALLEL = 16\n\nGPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]\nGPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]\nGPTQ_MARLIN_SUPPORTED_SYM = [True]\n\ndef is_marlin_supported():\n    return __cuda_arch[0] >= 8\n\n\ndef marlin_permute_weights(q_w, size_k, size_n, perm, tile=MARLIN_TILE):\n    assert q_w.shape == (size_k, size_n)\n    assert size_k % tile == 0, f\"size_k = {size_k}, tile = {tile}\"\n    assert size_n % tile == 0, f\"size_n = {size_n}, tile = {tile}\"\n\n    # Permute weights to 16x64 marlin tiles\n    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))\n    q_w = q_w.permute((0, 2, 1, 3))\n    q_w = q_w.reshape((size_k // tile, size_n * tile))\n\n    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)\n\n    return q_w\n\n\ndef marlin_weights(q_w, size_k, size_n, num_bits, perm):\n    # Permute\n    q_w = marlin_permute_weights(q_w, size_k, size_n, perm)\n\n    # Pack\n    pack_factor = get_pack_factor(num_bits)\n    orig_device = q_w.device\n\n    q_w = q_w.cpu().numpy().astype(numpy.uint32)\n\n    q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),\n                           dtype=numpy.uint32)\n    for i in range(pack_factor):\n        q_packed |= q_w[:, i::pack_factor] << num_bits * i\n\n    q_packed = 
torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device)\n\n    return q_packed\n\n\ndef marlin_permute_scales(s, size_k, size_n, group_size, scale_perm,\n                          scale_perm_single):\n    if group_size < size_k and group_size != -1:\n        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]\n    else:\n        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]\n    s = s.reshape((-1, size_n)).contiguous()\n\n    return s\n\n\ndef marlin_quantize(\n    w: torch.Tensor,\n    num_bits: int,\n    group_size: int,\n    act_order: bool,\n):\n    size_k, size_n = w.shape\n\n    # Normalize group_size\n    if group_size == -1:\n        group_size = size_k\n    assert group_size <= size_k\n\n    # Quantize (and apply act_order if provided)\n    w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size,\n                                                       act_order)\n\n    # For act_order, sort the \"weights\" and \"g_idx\" so that group ids are\n    # increasing\n    sort_indices = torch.empty(0, dtype=torch.int, device=w.device)\n    if act_order:\n        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)\n\n    # Reformat to marlin\n    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits,\n                                marlin_perm[num_bits])\n    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size,\n                                     marlin_scale_perm[num_bits],\n                                     marlin_scale_perm_single[num_bits])\n\n    # Create result\n    res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]\n    for i in range(len(res_list)):\n        res_list[i] = res_list[i].to(w.device)\n\n    return res_list\n\n\ndef inject_24(w, size_k, size_n):\n    assert w.shape == (size_k, size_n)\n\n    mask = mask_creator(w.t()).t().cuda().bool()\n\n    return (mask * w).contiguous(), mask.contiguous()\n\n\ndef check_24(w, num_rows_to_sample=50, _verbose=False):\n    
BLOCK_SIZE = 4\n    MAX_NON_ZEROS = 2\n\n    w = w.t().contiguous()\n\n    print(\"check_24: w.shape = {}\".format(w.shape))\n\n    num_rows, num_cols = w.shape\n    sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)\n    if _verbose:\n        print(f\"Sampled row idxs = {sampled_row_idxs}\")\n\n    total_segments = 0\n    non_24_segments = 0\n    for i in sampled_row_idxs:\n        for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE):\n            total_segments += 1\n            block = w[i, j:j + BLOCK_SIZE]\n            num_nonzero = torch.count_nonzero(block)\n            if num_nonzero > MAX_NON_ZEROS:\n                print(\"i = {} j = {} block = {}\".format(i, j, block))\n                non_24_segments += 1\n\n    print(f\"{non_24_segments} / {total_segments} do not have 2:4 structure.\")\n\n\ndef compress_quantized_24_weight(q_24, size_k, size_n, num_bits):\n    assert q_24.shape == (size_k, size_n)\n\n    # Remove zp to normalize over 0\n    max_q_val = (1 << num_bits) - 1\n    zp = (max_q_val + 1) // 2\n    q_24_no_zp = q_24 - zp\n\n    # Compress\n    q_24_no_zp = q_24_no_zp.t().contiguous()\n    q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(\n        q_24_no_zp)\n    q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous()\n\n    # Restore zp\n    q_24_comp = q_24_no_zp_comp + zp\n\n    # Resize meta to its actual shape (without moving any data)\n    meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)\n\n    return q_24_comp, meta\n\n\ndef marlin_24_quantize(\n    w: torch.Tensor,\n    num_bits: int,\n    group_size: int,\n):\n    size_k, size_n = w.shape\n\n    # Normalize group_size\n    if group_size == -1:\n        group_size = size_k\n    assert group_size <= size_k\n\n    # Inject 2:4 sparsity\n    w_24, mask_24 = inject_24(w, size_k, size_n)\n\n    # Quantize\n    w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24,\n                                                             num_bits,\n    
                                                         group_size,\n                                                             act_order=False)\n\n    # Compress quantized weight\n    q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n,\n                                                     num_bits)\n    size_k_comp = size_k // 2\n\n    # Reformat to marlin\n    marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n,\n                                        num_bits, marlin_24_perm[num_bits])\n    marlin_24_s = marlin_permute_scales(s, size_k, size_n, group_size,\n                                        marlin_24_scale_perm[num_bits],\n                                        marlin_24_scale_perm_single[num_bits])\n\n    # Create result\n    res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]\n    for i in range(len(res_list)):\n        res_list[i] = res_list[i].to(w.device)\n\n    return res_list\n\n\ndef compute_max_diff(output, output_ref):\n    return torch.mean(torch.abs(output - output_ref)) / torch.mean(\n        torch.abs(output_ref))\n\n\nclass MarlinWorkspace:\n\n    def __init__(self, out_features, min_thread_n, max_parallel, device):\n        assert (out_features % min_thread_n == 0), (\n            \"out_features = {} is not divisible by min_thread_n = {}\".format(\n                out_features, min_thread_n))\n\n        max_workspace_size = ((out_features // min_thread_n) * max_parallel)\n\n        self.scratch = torch.zeros(max_workspace_size,\n                                   dtype=torch.int,\n                                   device=device)"
  },
  {
    "path": "kt-sft/csrc/custom_marlin/utils/quant_utils.py",
    "content": "\"\"\"This file is used for /tests and /benchmarks\"\"\"\nimport numpy\nimport torch\n\nSUPPORTED_NUM_BITS = [4, 8]\nSUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]\n\n\ndef get_pack_factor(num_bits):\n    assert num_bits in SUPPORTED_NUM_BITS, f\"Unsupported num_bits = {num_bits}\"\n    return 32 // num_bits\n\n\ndef permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):\n    assert q_w.shape == w_ref.shape\n\n    orig_device = q_w.device\n    k_size, _ = q_w.shape\n\n    g_idx = torch.zeros((k_size, ), dtype=torch.int32)\n    for i in range(k_size):\n        g_idx[i] = i // group_size\n\n    # Simulate act_order by doing a random permutation on K\n    rand_perm = torch.randperm(k_size)\n\n    g_idx = g_idx[rand_perm].contiguous()\n    q_w = q_w[rand_perm, :].contiguous()\n    w_ref = w_ref[rand_perm, :].contiguous()\n\n    return (\n        w_ref.to(device=orig_device),\n        q_w.to(device=orig_device),\n        g_idx.to(device=orig_device),\n        rand_perm.to(device=orig_device),\n    )\n\n\n# Function: Dequantize quantized weights\ndef dequantize_weights(qweight, qzeros, scales, g_idx, bits=4, group_size=128, device='cuda:0'):\n    # Create a tensor for bitwise right shift operation\n    wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32, device=device).unsqueeze(0)\n\n    # Apply bitwise right shift and convert qzeros to the appropriate type\n    zeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits), wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)\n    torch.bitwise_and(zeros, (2 ** bits) - 1, out=zeros)\n\n    # Reshape the zeros tensor\n    zeros = zeros + 1\n    zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])\n\n    # Reshape the scales tensor\n    scales = scales.reshape(-1, 1, scales.shape[-1])\n\n    # Similar bitwise right shift operation for qweight and reshape\n    weight = torch.bitwise_right_shift(torch.unsqueeze(qweight, 1).expand(-1, 32 // 
bits, -1), wf.unsqueeze(-1)).to(torch.int16 if bits == 8 else torch.int8)\n    torch.bitwise_and(weight, (2 ** bits) - 1, out=weight)\n    weight = weight.reshape(-1, group_size, weight.shape[2])\n\n    # Apply dequantization formula and reshape the final weight\n    weight = (scales * (weight - zeros))\n    weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])\n\n    # Return the transposed weight\n    return weight.transpose(0, 1)\n\ndef quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,\n                     act_order: bool):\n    orig_device = w.device\n    size_k, size_n = w.shape\n\n    assert w.is_floating_point(), \"w must be float\"\n    assert num_bits in SUPPORTED_NUM_BITS, f\"Unsupported num_bits = {num_bits}\"\n    assert group_size in SUPPORTED_GROUP_SIZES + [\n        size_k\n    ], f\"Unsupported groupsize = {group_size}\"\n\n    if group_size == -1:\n        group_size = size_k\n    assert group_size <= size_k\n\n    max_q_val = 2**num_bits - 1\n    half_q_val = (max_q_val + 1) // 2\n\n    # Reshape to [groupsize, -1]\n    if group_size < size_k:\n        w = w.view((-1, group_size, size_n))\n        w = w.permute(1, 0, 2)\n        w = w.reshape((group_size, -1))\n\n    # Compute scale for each group\n    s = torch.max(torch.abs(w), 0, keepdim=True)[0]\n    s *= 2 / max_q_val  # 2 => symmetric\n\n    # Quantize\n    q_w = torch.round(w / s).int()\n    q_w += half_q_val\n    q_w = torch.clamp(q_w, 0, max_q_val)\n\n    # Compute ref (dequantized)\n    w_ref = (q_w - half_q_val).half() * s\n\n    # Restore original shapes\n    if group_size < size_k:\n\n        def reshape_w(w):\n            w = w.reshape((group_size, -1, size_n))\n            w = w.permute(1, 0, 2)\n            w = w.reshape((size_k, size_n)).contiguous()\n            return w\n\n        q_w = reshape_w(q_w)\n        w_ref = reshape_w(w_ref)\n\n    s = s.reshape((-1, size_n)).contiguous()\n\n    # Apply act_order\n    g_idx = torch.empty(0, 
dtype=torch.int, device=w.device)\n    rand_perm = torch.empty(0, dtype=torch.int, device=w.device)\n    if act_order:\n        assert (\n            group_size < size_k\n        ), \"For act_order, groupsize = {} must be less than size_k = {}\".format(\n            group_size, size_k)\n\n        w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size)\n\n    return (\n        w_ref.to(device=orig_device),\n        q_w.to(device=orig_device),\n        s.to(device=orig_device),\n        g_idx.to(device=orig_device),\n        rand_perm.to(device=orig_device),\n    )\n\n\ndef sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):\n    orig_device = q_w.device\n\n    sort_indices = torch.argsort(g_idx).to(\n        dtype=torch.int32)  # Sort based on g_idx\n\n    g_idx = g_idx[sort_indices].contiguous()\n    q_w = q_w[sort_indices, :].contiguous()\n\n    return (\n        q_w.to(device=orig_device),\n        g_idx.to(device=orig_device),\n        sort_indices.to(device=orig_device),\n    )\n\n\ndef gptq_pack(\n    q_w: torch.Tensor,\n    num_bits: int,\n    size_k: int,\n    size_n: int,\n):\n    assert q_w.shape == (size_k, size_n)\n\n    pack_factor = get_pack_factor(num_bits)\n    assert size_k % pack_factor == 0\n\n    orig_device = q_w.device\n\n    q_w = q_w.cpu().numpy().astype(numpy.uint32)\n\n    q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)\n\n    for i in range(pack_factor):\n        q_res |= q_w[i::pack_factor, :] << num_bits * i\n\n    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)\n    return q_res\n\ndef gptq_unpack(\n    q_res: torch.Tensor,\n    num_bits: int,\n    size_k: int,\n    size_n: int,\n):\n    pack_factor = 32 // num_bits\n    assert size_k % pack_factor == 0\n\n    orig_device = q_res.device\n\n    q_res = q_res.cpu().numpy()\n\n    q_w = numpy.zeros((size_k, size_n), dtype=numpy.uint32)\n\n    for i in range(pack_factor):\n        q_w[i::pack_factor, :] = (q_res >> (num_bits * i)) & 
((1 << num_bits) - 1)\n\n    q_w = torch.from_numpy(q_w.astype(numpy.int32)).to(orig_device)\n    return q_w"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.16)\nproject(cpuinfer_ext VERSION 0.1.0)\n\n\nset(CMAKE_CXX_STANDARD 17)\n\n\nset(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -O3 -ffast-math -fopenmp\")\nadd_compile_definitions(_GLIBCXX_USE_CXX11_ABI=${_GLIBCXX_USE_CXX11_ABI})\nset(CMAKE_BUILD_TYPE \"Release\")\n# set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -g -ffast-math -fopenmp\")\n# set(CMAKE_BUILD_TYPE \"Debug\")\nset(CMAKE_EXPORT_COMPILE_COMMANDS ON)\n\n\ninclude(CheckCXXCompilerFlag)\nset(CMAKE_POSITION_INDEPENDENT_CODE ON)\n\n\noption(LLAMA_NATIVE                     \"llama: enable -march=native flag\"                      ON)\n\n# instruction set specific\nif (LLAMA_NATIVE)\n    set(INS_ENB OFF)\nelse()\n    set(INS_ENB ON)\nendif()\n\noption(LLAMA_AVX                             \"llama: enable AVX\"                                OFF)\noption(LLAMA_AVX2                            \"llama: enable AVX2\"                               OFF)\noption(LLAMA_AVX512                          \"llama: enable AVX512\"                             OFF)\noption(LLAMA_AVX512_VBMI                     \"llama: enable AVX512-VBMI\"                        OFF)\noption(LLAMA_AVX512_VNNI                     \"llama: enable AVX512-VNNI\"                        OFF)\noption(LLAMA_AVX512_BF16                     \"llama: enable AVX512-BF16\"                        OFF)\noption(LLAMA_FMA                             \"llama: enable FMA\"                                OFF)\n# in MSVC F16C is implied with AVX2/AVX512\nif (NOT MSVC)\n    option(LLAMA_F16C                        \"llama: enable F16C\"                               OFF)\nendif()\noption(LLAMA_AVX512_FANCY_SIMD               \"llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI\"                        OFF)\noption(KTRANSFORMERS_USE_CUDA                \"ktransformers: use CUDA\"                          ON)\noption(KTRANSFORMERS_USE_MUSA                \"ktransformers: use MUSA\"                          
OFF)\noption(KTRANSFORMERS_USE_ROCM                \"ktransformers: use ROCM\"                          OFF)\noption(KTRANSFORMERS_USE_XPU                 \"ktransformers: use XPU\"                           OFF)\noption(KTRANSFORMERS_USE_NPU                 \"ktransformers: use NPU\"                           OFF)\n\nif(KTRANSFORMERS_USE_NPU)\n    add_definitions(-DKTRANSFORMERS_USE_NPU=1)\nendif()\n\n# Architecture specific\n# TODO: probably these flags need to be tweaked on some architectures\n#       feel free to update the Makefile for your architecture and send a pull request or issue\nmessage(STATUS \"CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}\")\nif (MSVC)\n    string(TOLOWER \"${CMAKE_GENERATOR_PLATFORM}\" CMAKE_GENERATOR_PLATFORM_LWR)\n    message(STATUS \"CMAKE_GENERATOR_PLATFORM: ${CMAKE_GENERATOR_PLATFORM}\")\nelse ()\n    set(CMAKE_GENERATOR_PLATFORM_LWR \"\")\nendif ()\n\nif (NOT MSVC)\n    if (LLAMA_STATIC)\n        add_link_options(-static)\n        if (MINGW)\n            add_link_options(-static-libgcc -static-libstdc++)\n        endif()\n    endif()\n    if (LLAMA_GPROF)\n        add_compile_options(-pg)\n    endif()\nendif()\n\nset(ARCH_FLAGS \"\")\n\nif (CMAKE_OSX_ARCHITECTURES STREQUAL \"arm64\" OR CMAKE_GENERATOR_PLATFORM_LWR STREQUAL \"arm64\" OR\n    (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND\n     CMAKE_SYSTEM_PROCESSOR MATCHES \"^(aarch64|arm.*|ARM64)$\"))\n    message(STATUS \"ARM detected\")\n    if (MSVC)\n        add_compile_definitions(__aarch64__) # MSVC defines _M_ARM64 instead\n        add_compile_definitions(__ARM_NEON)\n        add_compile_definitions(__ARM_FEATURE_FMA)\n\n        set(CMAKE_REQUIRED_FLAGS_PREV ${CMAKE_REQUIRED_FLAGS})\n        string(JOIN \" \" CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS} \"/arch:armv8.2\")\n        check_cxx_source_compiles(\"#include <arm_neon.h>\\nint main() { int8x16_t _a, _b; int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }\" 
GGML_COMPILER_SUPPORT_DOTPROD)\n        if (GGML_COMPILER_SUPPORT_DOTPROD)\n            add_compile_definitions(__ARM_FEATURE_DOTPROD)\n        endif ()\n        check_cxx_source_compiles(\"#include <arm_neon.h>\\nint main() { float16_t _a; float16x8_t _s = vdupq_n_f16(_a); return 0; }\" GGML_COMPILER_SUPPORT_FP16_VECTOR_ARITHMETIC)\n        if (GGML_COMPILER_SUPPORT_FP16_VECTOR_ARITHMETIC)\n            add_compile_definitions(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)\n        endif ()\n        set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_PREV})\n    else()\n        if(KTRANSFORMERS_USE_NPU)\n            list(APPEND ARCH_FLAGS -march=armv8.2-a+fp16+fp16fml+dotprod -lnuma)\n        endif()\n        check_cxx_compiler_flag(-mfp16-format=ieee COMPILER_SUPPORTS_FP16_FORMAT_I3E)\n        if (NOT \"${COMPILER_SUPPORTS_FP16_FORMAT_I3E}\" STREQUAL \"\")\n            list(APPEND ARCH_FLAGS -mfp16-format=ieee)\n        endif()\n        if (${CMAKE_SYSTEM_PROCESSOR} MATCHES \"armv6\")\n            # Raspberry Pi 1, Zero\n            list(APPEND ARCH_FLAGS -mfpu=neon-fp-armv8 -mno-unaligned-access)\n        endif()\n        if (${CMAKE_SYSTEM_PROCESSOR} MATCHES \"armv7\")\n            if (\"${CMAKE_SYSTEM_NAME}\" STREQUAL \"Android\")\n                # Android armeabi-v7a\n                list(APPEND ARCH_FLAGS -mfpu=neon-vfpv4 -mno-unaligned-access -funsafe-math-optimizations)\n            else()\n                # Raspberry Pi 2\n                list(APPEND ARCH_FLAGS -mfpu=neon-fp-armv8 -mno-unaligned-access -funsafe-math-optimizations)\n            endif()\n        endif()\n        if (${CMAKE_SYSTEM_PROCESSOR} MATCHES \"armv8\")\n            # Android arm64-v8a\n            # Raspberry Pi 3, 4, Zero 2 (32-bit)\n            list(APPEND ARCH_FLAGS -mno-unaligned-access)\n        endif()\n    endif()\nelseif (CMAKE_OSX_ARCHITECTURES STREQUAL \"x86_64\" OR CMAKE_GENERATOR_PLATFORM_LWR MATCHES \"^(x86_64|i686|amd64|x64|win32)$\" OR\n        (NOT CMAKE_OSX_ARCHITECTURES AND NOT 
CMAKE_GENERATOR_PLATFORM_LWR AND\n         CMAKE_SYSTEM_PROCESSOR MATCHES \"^(x86_64|i686|AMD64)$\"))\n    message(STATUS \"x86 detected\")\n    if(NOT KTRANSFORMERS_USE_NPU)\n        set(HOST_IS_X86 TRUE)\n        set(HAS_AVX512 TRUE)\n        set(__HAS_AMX__ TRUE)\n        add_compile_definitions(__x86_64__)\n        # check AVX512 (detected from the CPU flags reported by lscpu, not by a compile test)\n        execute_process(\n            COMMAND lscpu\n            OUTPUT_VARIABLE LSCPU_OUTPUT\n            OUTPUT_STRIP_TRAILING_WHITESPACE\n        )\n        # message(STATUS \"LSCPU_OUTPUT: ${LSCPU_OUTPUT}\")\n    \n        string(FIND \"${LSCPU_OUTPUT}\" \"avx512\" COMPILER_SUPPORTS_AVX512F)\n        \n        if (COMPILER_SUPPORTS_AVX512F GREATER -1)\n            message(STATUS \"CPU reports AVX512 support (detected via lscpu)\")\n            add_compile_definitions(__HAS_AVX512F__)\n        else()\n            message(STATUS \"CPU does NOT report AVX512 support (checked via lscpu)\")\n            set(HAS_AVX512 FALSE)\n        endif()\n    \n        # check AMX (also detected from the lscpu output)\n        string(FIND \"${LSCPU_OUTPUT}\" \"amx\" COMPILER_SUPPORTS_AMX)\n        \n        if(COMPILER_SUPPORTS_AMX GREATER -1)\n            message(STATUS \"CPU reports AMX support (detected via lscpu)\")\n            add_compile_definitions(__HAS_AMX__)\n        else()\n            message(STATUS \"CPU does NOT report AMX support (checked via lscpu)\")\n        endif()\n    endif()\n    if (MSVC)\n        # instruction set detection for MSVC only\n        if (LLAMA_NATIVE)\n            include(cmake/FindSIMD.cmake)\n        endif ()\n        if (LLAMA_AVX512)\n            list(APPEND ARCH_FLAGS /arch:AVX512)\n            # MSVC has no compile-time flags for enabling specific\n            # AVX512 extensions, nor does it define the\n            # macros corresponding to the extensions.\n            # Define them manually.\n            if (LLAMA_AVX512_VBMI)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VBMI__>)\n                
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VBMI__>)\n            endif()\n            if (LLAMA_AVX512_VNNI)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)\n            endif()\n            if (LLAMA_AVX512_FANCY_SIMD)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VL__>)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VL__>)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512BW__>)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512BW__>)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512DQ__>)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512DQ__>)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)\n            endif()\n            if (LLAMA_AVX512_BF16)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512BF16__>)\n                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512BF16__>)\n            endif()\n        elseif (LLAMA_AVX2)\n            list(APPEND ARCH_FLAGS /arch:AVX2)\n        elseif (LLAMA_AVX)\n            list(APPEND ARCH_FLAGS /arch:AVX)\n        endif()\n    else()\n        if (LLAMA_NATIVE)\n            list(APPEND ARCH_FLAGS -mfma -mavx -mavx2)\n            list(APPEND ARCH_FLAGS -march=native)\n        endif()\n        if (LLAMA_F16C)\n            list(APPEND ARCH_FLAGS -mf16c)\n        endif()\n        if (LLAMA_FMA)\n            list(APPEND ARCH_FLAGS -mfma)\n        endif()\n        if (LLAMA_AVX)\n            list(APPEND ARCH_FLAGS -mavx)\n        endif()\n        if (LLAMA_AVX2)\n            list(APPEND ARCH_FLAGS -mavx2)\n        endif()\n        if (LLAMA_AVX512)\n            list(APPEND ARCH_FLAGS -mavx512f)\n    
        list(APPEND ARCH_FLAGS -mavx512bw)\n        endif()\n        if (LLAMA_AVX512_VBMI)\n            list(APPEND ARCH_FLAGS -mavx512vbmi)\n        endif()\n        if (LLAMA_AVX512_VNNI)\n            list(APPEND ARCH_FLAGS -mavx512vnni)\n        endif()\n        if (LLAMA_AVX512_FANCY_SIMD)\n            message(STATUS \"AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI enabled\")\n            list(APPEND ARCH_FLAGS -mavx512vl)\n            list(APPEND ARCH_FLAGS -mavx512bw)\n            list(APPEND ARCH_FLAGS -mavx512dq)\n            list(APPEND ARCH_FLAGS -mavx512vnni)\n            list(APPEND ARCH_FLAGS -mavx512vpopcntdq)\n        endif()\n        if (LLAMA_AVX512_BF16)\n            list(APPEND ARCH_FLAGS -mavx512bf16)\n        endif()\n    endif()\nelseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES \"ppc64\")\n    message(STATUS \"PowerPC detected\")\n    if (${CMAKE_SYSTEM_PROCESSOR} MATCHES \"ppc64le\")\n        list(APPEND ARCH_FLAGS -mcpu=powerpc64le)\n    else()\n        list(APPEND ARCH_FLAGS -mcpu=native -mtune=native)\n        #TODO: Add  targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be)\n    endif()\nelse()\n    message(STATUS \"Unknown architecture\")\nendif()\n\n# message(STATUS \"CUDAToolkit_ROOT:${CUDAToolkit_ROOT}\")\n# find_package(FindCUDAToolkit REQUIRED)\n# if(CUDAToolkit_FOUND)\n#     message(STATUS \"Found CUDA cudart lib at:${CUDAToolkit_LIBRARY_DIR}\")\n# else()\n#     message(STATUS \"Can't found CUDA lib\")\n# endif()\n\nif (NOT EXISTS $ENV{ROCM_PATH})\n    if (NOT EXISTS /opt/rocm)\n        set(ROCM_PATH /usr)\n    else()\n        set(ROCM_PATH /opt/rocm)\n    endif()\nelse()\n    set(ROCM_PATH $ENV{ROCM_PATH})\nendif()\n\nlist(APPEND CMAKE_PREFIX_PATH  ${ROCM_PATH})\nlist(APPEND CMAKE_PREFIX_PATH \"${ROCM_PATH}/lib64/cmake\")\n\nif (NOT EXISTS $ENV{MUSA_PATH})\n    if (NOT EXISTS /opt/musa)\n        set(MUSA_PATH /usr/local/musa)\n    else()\n        set(MUSA_PATH /opt/musa)\n    
endif()\nelse()\n    set(MUSA_PATH $ENV{MUSA_PATH})\nendif()\n\nlist(APPEND CMAKE_MODULE_PATH \"${MUSA_PATH}/cmake\")\n\nadd_compile_options(\"$<$<COMPILE_LANGUAGE:CXX>:${ARCH_FLAGS}>\")\nadd_compile_options(\"$<$<COMPILE_LANGUAGE:C>:${ARCH_FLAGS}>\")\n\nadd_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/pybind11 ${CMAKE_CURRENT_BINARY_DIR}/third_party/pybind11)\nadd_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/llama.cpp ${CMAKE_CURRENT_BINARY_DIR}/third_party/llama.cpp)\n\ninclude_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party)\nif (WIN32)\n    include_directories(\"$ENV{CUDA_PATH}/include\")\n    add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)\nelseif (UNIX)\n    if (KTRANSFORMERS_USE_ROCM)\n        find_package(HIP REQUIRED)\n        if(HIP_FOUND)\n            include_directories(\"${HIP_INCLUDE_DIRS}\")\n            add_compile_definitions(KTRANSFORMERS_USE_ROCM=1)\n        endif()\n    elseif (KTRANSFORMERS_USE_MUSA)\n        if (NOT EXISTS $ENV{MUSA_PATH})\n            if (NOT EXISTS /opt/musa)\n                set(MUSA_PATH /usr/local/musa)\n            else()\n                set(MUSA_PATH /opt/musa)\n            endif()\n        else()\n            set(MUSA_PATH $ENV{MUSA_PATH})\n        endif()\n\n        list(APPEND CMAKE_MODULE_PATH \"${MUSA_PATH}/cmake\")\n\n        find_package(MUSAToolkit)\n        if (MUSAToolkit_FOUND)\n            message(STATUS \"MUSA Toolkit found\")\n            add_compile_definitions(KTRANSFORMERS_USE_MUSA=1)\n        endif()\n    elseif (KTRANSFORMERS_USE_XPU)\n        add_compile_definitions(KTRANSFORMERS_USE_XPU=1)\n    elseif (KTRANSFORMERS_USE_CUDA)\n        find_package(CUDA REQUIRED)\n        include_directories(\"${CUDA_INCLUDE_DIRS}\")\n        include(CheckLanguage)\n        check_language(CUDA)\n        if(CMAKE_CUDA_COMPILER)\n            message(STATUS \"CUDA detected\")\n            find_package(CUDAToolkit REQUIRED)\n            
include_directories(${CUDAToolkit_INCLUDE_DIRS})\n        endif()\n        message(STATUS \"enabling CUDA\")\n        enable_language(CUDA)\n        add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)\n    endif()\nendif()\n\naux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1)\naux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend SOURCE_DIR2)\naux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/llamafile SOURCE_DIR3)\n# aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/llamafile SOURCE_DIR4)\nfile(GLOB LLAMAFILE_SOURCES \"${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/llamafile/*.cpp\")\nlist(REMOVE_ITEM LLAMAFILE_SOURCES\n    \"${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/llamafile/sgemm_arm.cpp\"\n    \"${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/llamafile/sgemm_x86.cpp\"\n)\nset(SOURCE_DIR4 ${LLAMAFILE_SOURCES})\naux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/kvcache SOURCE_DIR5)\n\nif (HOST_IS_X86 AND HAS_AVX512 AND __HAS_AMX__)\n    aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/operators/amx SOURCE_DIR6)\nendif()\n\n\nset(ALL_SOURCES ${SOURCE_DIR1} ${SOURCE_DIR2} ${SOURCE_DIR3} ${SOURCE_DIR4} ${SOURCE_DIR5} ${SOURCE_DIR6})\n\nfile(GLOB_RECURSE FMT_SOURCES \"${CMAKE_CURRENT_SOURCE_DIR}/*.cpp\" \"${CMAKE_CURRENT_SOURCE_DIR}/*.hpp\" \"${CMAKE_CURRENT_SOURCE_DIR}/*.h\")\n\nadd_custom_target(\n    format\n    COMMAND clang-format\n    -i\n    -style=file\n    ${FMT_SOURCES}\n    COMMENT \"Running clang-format on all source files\"\n)\n\n\nadd_library(llamafile STATIC ${SOURCE_DIR4})\n\nmessage(STATUS \"CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}\")\nmessage(STATUS \"ARCH_FLAGS: ${ARCH_FLAGS}\")\npybind11_add_module(${PROJECT_NAME} MODULE ${ALL_SOURCES})\ntarget_link_libraries(${PROJECT_NAME} PRIVATE llama)\n\n\nif(WIN32)\n    target_link_libraries(${PROJECT_NAME} PRIVATE \"$ENV{CUDA_PATH}/lib/x64/cudart.lib\")#CUDA::cudart\nelseif(UNIX)\n    if (KTRANSFORMERS_USE_ROCM)\n        
add_compile_definitions(USE_HIP=1)\n        target_link_libraries(${PROJECT_NAME} PRIVATE \"${ROCM_PATH}/lib/libamdhip64.so\")\n        message(STATUS \"Building for HIP\")\n    elseif(KTRANSFORMERS_USE_MUSA)\n        target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)\n    elseif(KTRANSFORMERS_USE_XPU)\n    elseif(KTRANSFORMERS_USE_CUDA AND NOT KTRANSFORMERS_USE_MUSA)\n        target_link_libraries(${PROJECT_NAME} PRIVATE \"${CUDAToolkit_LIBRARY_DIR}/libcudart.so\")\n    endif()\nendif()\n\n# Define the USE_NUMA option (off by default)\noption(USE_NUMA \"Enable NUMA support\" OFF)\n\n# Setting the USE_NUMA environment variable (to any value) also enables NUMA\nif(DEFINED ENV{USE_NUMA})\n    set(USE_NUMA ON)\nendif()\n\nif(USE_NUMA)\n    message(STATUS \"NUMA support is enabled\")\nelse()\n    message(STATUS \"NUMA support is disabled\")\nendif()\n\nfind_library(NUMA_LIBRARY NAMES numa)\n\nif(NUMA_LIBRARY AND USE_NUMA)\n    message(STATUS \"NUMA library found: ${NUMA_LIBRARY} - enabling NUMA support\")\n    target_link_libraries(${PROJECT_NAME} PRIVATE ${NUMA_LIBRARY})\n    target_compile_definitions(${PROJECT_NAME} PRIVATE USE_NUMA)\nelse()\n    if(USE_NUMA)\n        message(FATAL_ERROR \"NUMA library not found - try: sudo apt install libnuma-dev\")\n    else()\n        message(STATUS \"NUMA library not found or USE_NUMA not set - disabling NUMA support\")\n    endif()\nendif()\n"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/bench/bench_attention.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nDescription  :  \nAuthor       : Jianwei Dong\nDate         : 2024-08-28 10:32:05\nVersion      : 1.0.0\nLastEditors  : Jianwei Dong \nLastEditTime : 2024-08-28 10:32:05\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n\"\"\"\nimport os, sys\nimport time\n\nsys.path.append(os.path.dirname(__file__) + \"/../build\")\nimport cpuinfer_ext\nimport torch\n\nlayer_num = 10\nkv_head_num = 8\nq_head_num = 32\nhead_dim = 128\nblock_len = 128\nanchor_num = 1\n\nanchor_type = cpuinfer_ext.kvcache.AnchorType.DYNAMIC\nkv_type = cpuinfer_ext.kvcache.ggml_type.FP16\nretrieval_type = cpuinfer_ext.kvcache.RetrievalType.LAYER\nlayer_step: int = 1\ntoken_step: int = 1\nlayer_offset: int = 0\nmax_thread_num: int = 64\nmax_batch_size: int = 1\nmax_block_num: int = 1024\nCPUInfer = cpuinfer_ext.CPUInfer(max_thread_num)\n\nwarm_up_iter = 1000\ntest_iter = 10000\n\n\ndef bench_linear(cache_seqlen: int):\n    with torch.inference_mode(mode=True):\n        cache_seqlens = torch.tensor([cache_seqlen], dtype=torch.int32, device=\"cpu\")\n        seqlens_zero = torch.zeros((1,), dtype=torch.int32, device=\"cpu\")\n\n        config = cpuinfer_ext.kvcache.KVCacheConfig(\n            layer_num,\n            kv_head_num,\n            q_head_num,\n            head_dim,\n            block_len,\n            anchor_num,\n            anchor_type,\n            kv_type,\n            retrieval_type,\n            layer_step,\n            token_step,\n            layer_offset,\n            max_block_num,\n            max_batch_size,\n            max_thread_num,\n        )\n        local_kvcache = cpuinfer_ext.kvcache.KVCache(config)\n        block_table = (\n            torch.arange(max_block_num, dtype=torch.int32, device=\"cpu\")\n            .contiguous()\n            .view(1, -1)\n        )\n\n        for layer_idx in range(layer_num):\n            k_cache = torch.randn(\n                (1, cache_seqlen, kv_head_num, head_dim),\n   
             dtype=torch.float16,\n                device=\"cpu\",\n            ).contiguous()\n            v_cache = torch.randn(\n                (1, cache_seqlen, kv_head_num, head_dim),\n                dtype=torch.float16,\n                device=\"cpu\",\n            ).contiguous()\n\n            CPUInfer.submit(\n                local_kvcache.update_kvcache_fp16(\n                    k_cache.data_ptr(),\n                    v_cache.data_ptr(),\n                    layer_idx,\n                    block_table.data_ptr(),\n                    1,\n                    max_block_num,\n                    seqlens_zero.data_ptr(),\n                    cache_seqlen,\n                )\n            )\n            CPUInfer.sync()\n\n        input = torch.randn(\n            (1, 1, q_head_num, head_dim), dtype=torch.float16, device=\"cpu\"\n        ).contiguous()\n        output = torch.empty(\n            (1, 1, q_head_num, head_dim), dtype=torch.float16, device=\"cpu\"\n        ).contiguous()\n\n        # attn_lse: (bsz, q_len, q_head_num)\n        attn_lse = torch.empty(\n            (1, 1, q_head_num), dtype=torch.float32, device=\"cpu\"\n        ).contiguous()\n        input = input / 100\n\n        # warm up\n        for i in range(warm_up_iter):\n            CPUInfer.submit(\n                local_kvcache.attn(\n                    input.data_ptr(),\n                    output.data_ptr(),\n                    attn_lse.data_ptr(),\n                    i % layer_num,\n                    0,\n                    1,\n                    1,\n                    max_block_num,\n                    block_table.data_ptr(),\n                    cache_seqlens.data_ptr(),\n                    -1,\n                    -1,\n                    -1,\n                )\n            )\n            CPUInfer.sync()\n\n        # test\n        start = time.perf_counter()\n        for i in range(test_iter):\n            CPUInfer.submit(\n                local_kvcache.attn(\n           
         input.data_ptr(),\n                    output.data_ptr(),\n                    attn_lse.data_ptr(),\n                    i % layer_num,\n                    0,\n                    1,\n                    1,\n                    max_block_num,\n                    block_table.data_ptr(),\n                    cache_seqlens.data_ptr(),\n                    -1,\n                    -1,\n                    -1,\n                )\n            )\n            CPUInfer.sync()\n        end = time.perf_counter()\n        total_time = end - start\n        print(\"cache sequence length: \", cache_seqlen)\n        print(\"Time(s): \", total_time)\n        print(\"Iteration: \", test_iter)\n        print(\"Time(us) per iteration: \", total_time / test_iter * 1000000)\n        print(\n            \"Bandwidth: \",\n            cache_seqlen\n            * kv_head_num\n            * head_dim\n            * 2\n            * 2\n            * test_iter\n            / total_time\n            / 1000\n            / 1000\n            / 1000,\n            \"GB/s\",\n        )\n        print(\"\")\n\n\nbench_linear(1024)\nbench_linear(4096)\nbench_linear(16384)\nbench_linear(32768)\nbench_linear(65536)\n"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/bench/bench_attention_torch.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nDescription  :  \nAuthor       : Jianwei Dong\nDate         : 2024-08-28 10:32:05\nVersion      : 1.0.0\nLastEditors  : Jianwei Dong \nLastEditTime : 2024-08-28 10:32:05\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n\"\"\"\nimport os, sys\nimport time\n\nsys.path.append(os.path.dirname(__file__) + \"/../build\")\nimport cpuinfer_ext\nimport torch\n\nlayer_num = 10\nkv_head_num = 8\nq_head_num = 32\nhead_dim = 128\nblock_len = 128\nanchor_num = 1\nwarm_up_iter = 1000\ntest_iter = 10000\n\n\ndef bench_linear(cache_seqlen: int, device):\n    with torch.inference_mode(mode=True):\n\n        kvcaches = []\n\n        for layer_idx in range(layer_num):\n            k_cache = torch.randn(\n                (1, 32, cache_seqlen, head_dim),\n                dtype=torch.float16,\n                device=device,\n            ).contiguous()\n            v_cache = torch.randn(\n                (1, 32, cache_seqlen, head_dim),\n                dtype=torch.float16,\n                device=device,\n            ).contiguous()\n\n            kvcaches.append((k_cache, v_cache))\n\n        input = torch.randn(\n            (1, q_head_num, 1, head_dim), dtype=torch.float16, device=device\n        ).contiguous()\n        input = input / 100\n\n        # warm up\n        for i in range(warm_up_iter):\n            k_cache = kvcaches[i % layer_num][0]\n            v_cache = kvcaches[i % layer_num][1]\n            torch.nn.functional.scaled_dot_product_attention(input, k_cache, v_cache)\n\n        # test\n        start = time.perf_counter()\n        for i in range(test_iter):\n            k_cache = kvcaches[i % layer_num][0]\n            v_cache = kvcaches[i % layer_num][1]\n            torch.nn.functional.scaled_dot_product_attention(input, k_cache, v_cache)\n        end = time.perf_counter()\n        total_time = end - start\n        print(\"cache sequence length: \", cache_seqlen)\n        print(\"Time(s): \", 
total_time)\n        print(\"Iteration: \", test_iter)\n        print(\"Time(us) per iteration: \", total_time / test_iter * 1000000)\n        print(\n            \"Bandwidth: \",\n            cache_seqlen\n            * q_head_num\n            * head_dim\n            * 2\n            * 2\n            * test_iter\n            / total_time\n            / 1000\n            / 1000\n            / 1000,\n            \"GB/s\",\n        )\n        print(\"\")\n\n\nbench_linear(1024, \"cpu\")\nbench_linear(4096, \"cpu\")\nbench_linear(1024, \"cuda\")\nbench_linear(4096, \"cuda\")\nbench_linear(16384, \"cuda\")\nbench_linear(32768, \"cuda\")\nbench_linear(65536, \"cuda\")\n"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/bench/bench_linear.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : chenht2022\nDate         : 2024-07-25 10:31:59\nVersion      : 1.0.0\nLastEditors  : chenht2022 \nLastEditTime : 2024-08-06 10:35:35\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport os, sys\nimport time\nsys.path.append(os.path.dirname(__file__) + '/../build')\nimport cpuinfer_ext\nimport torch\n\ninput_size = 16384\noutput_size = 5120\nstride = 16\ngroup_max_len = 1024\nlayer_num = 10\nqlen = 1\nCPUInfer = cpuinfer_ext.CPUInfer(64)\nwarm_up_iter = 1000\ntest_iter = 10000\n\ndef bench_linear(quant_mode: str):\n    with torch.inference_mode(mode=True):\n\n        hidden_type = 30 # ggml_type::GGML_TYPE_BF16\n        if quant_mode == \"fp32\":\n            proj_type = 0 # ggml_type::GGML_TYPE_F32\n            bytes_per_elem = 4.000000\n        elif quant_mode == \"fp16\":\n            proj_type = 1 # ggml_type::GGML_TYPE_F16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"bf16\":\n            proj_type = 30 # ggml_type::GGML_TYPE_BF16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"q8_0\":\n            proj_type = 8 # ggml_type::GGML_TYPE_Q8_0\n            bytes_per_elem = 1.062500\n        elif quant_mode == \"q6_k\":\n            proj_type = 14 # ggml_type::GGML_TYPE_Q6_K\n            bytes_per_elem = 0.820312\n        elif quant_mode == \"q5_k_m\":\n            proj_type = 13 # ggml_type::GGML_TYPE_Q5_K\n            bytes_per_elem = 0.687500\n        elif quant_mode == \"q4_k_m\":\n            proj_type = 12 # ggml_type::GGML_TYPE_Q4_K\n            bytes_per_elem = 0.562500\n        elif quant_mode == \"q3_k_m\":\n            proj_type = 11 # ggml_type::GGML_TYPE_Q3_K\n            bytes_per_elem = 0.429688\n        elif quant_mode == \"q2_k\":\n            proj_type = 10 # ggml_type::GGML_TYPE_Q2_K\n            bytes_per_elem = 0.328125\n        elif quant_mode == \"iq3_xs\":\n            proj_type = 21 # 
ggml_type::GGML_TYPE_IQ3_S\n            bytes_per_elem = 0.429688\n        elif quant_mode == \"iq2_xxs\":\n            proj_type = 16 # ggml_type::GGML_TYPE_IQ2_XXS\n            bytes_per_elem = 0.257812\n        else:\n            assert(False)\n\n        linears = []\n        projs = []\n        for _ in range(layer_num):\n            proj = torch.randn((output_size, input_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            config = cpuinfer_ext.linear.LinearConfig(input_size, output_size, stride, group_max_len, proj.data_ptr(), proj_type, hidden_type)\n            linear = cpuinfer_ext.linear.Linear(config)\n            projs.append(proj)\n            linears.append(linear)\n        input = torch.randn((layer_num, qlen, input_size), dtype=torch.bfloat16, device = \"cuda\").to(\"cpu\").contiguous()\n        output = torch.empty((layer_num, qlen, output_size), dtype=torch.bfloat16, device = \"cuda\").to(\"cpu\").contiguous()\n\n        # warm up\n        for i in range(warm_up_iter):\n            CPUInfer.submit(\n                linears[i % layer_num].forward(\n                    qlen, \n                    input[i % layer_num].data_ptr(), \n                    output[i % layer_num].data_ptr()\n                )\n            )\n            CPUInfer.sync()\n\n        # test\n        start = time.perf_counter()\n        for i in range(test_iter):\n            CPUInfer.submit(\n                linears[i % layer_num].forward(\n                    qlen, \n                    input[i % layer_num].data_ptr(), \n                    output[i % layer_num].data_ptr()\n                )\n            )\n            CPUInfer.sync()\n        end = time.perf_counter()\n        total_time = end - start\n        print('Quant mode: ', quant_mode)\n        print('Time(s): ', total_time)\n        print('Iteration: ', test_iter) \n        print('Time(us) per iteration: ', total_time / test_iter * 1000000)\n        print('Bandwidth: ', input_size * 
output_size * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')\n        print('')\n\nbench_linear(\"fp32\")\nbench_linear(\"fp16\")\nbench_linear(\"bf16\")\nbench_linear(\"q8_0\")\nbench_linear(\"q6_k\")\nbench_linear(\"q5_k_m\")\nbench_linear(\"q4_k_m\")\nbench_linear(\"q3_k_m\")\nbench_linear(\"q2_k\")\n# Not supported on __x86_64__\n# bench_linear(\"iq3_xs\")\n# bench_linear(\"iq2_xxs\")\n"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/bench/bench_linear_torch.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : chenht2022\nDate         : 2024-07-25 10:31:59\nVersion      : 1.0.0\nLastEditors  : chenht2022 \nLastEditTime : 2024-07-25 10:32:48\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport os, sys\nimport time\nimport torch\nimport torch.nn.quantized as nnq\n\nscale, zero_point = 0.1, 0  # Adjust scale and zero_point based on your dataset\n\ninput_size = 16384\noutput_size = 5120\nlayer_num = 10\nqlen = 1\nwarm_up_iter = 1000\ntest_iter = 10000\n\ndef bench_linear(quant_mode: str):\n    with torch.inference_mode(mode=True):\n        if quant_mode == \"fp32\":\n            proj_type = torch.float32\n            bytes_per_elem = 4.000000\n        elif quant_mode == \"fp16\":\n            proj_type = torch.float16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"bf16\":\n            proj_type = torch.bfloat16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"qint8\":\n            proj_type = torch.qint8\n            bytes_per_elem = 1.000000\n        else:\n            assert(False)\n\n        projs = []\n        for _ in range(layer_num):\n            proj = torch.randn((output_size, input_size), dtype = torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            if quant_mode == \"qint8\":\n                proj_q = torch.quantize_per_tensor(proj, scale, zero_point, torch.qint8)\n                quantized_layer = nnq.Linear(input_size, output_size)\n                quantized_layer.set_weight_bias(proj_q, None)\n                projs.append(quantized_layer)\n            else:\n                projs.append(proj.to(proj_type))\n        input = torch.randn((layer_num, qlen, input_size), dtype=torch.bfloat16, device = \"cuda\").to(\"cpu\").contiguous()\n\n        # warm up\n        for i in range(warm_up_iter):\n            if isinstance(projs[i % layer_num], nnq.Linear):\n                input_q = 
torch.quantize_per_tensor(input[i % layer_num].to(torch.float32), scale, zero_point, torch.quint8)\n                t_output = projs[i % layer_num](input_q)\n            else:\n                t_output = torch.mm(input[i % layer_num].to(proj_type), projs[i % layer_num].t())\n\n        # test\n        start = time.perf_counter()\n        for i in range(test_iter):\n            if isinstance(projs[i % layer_num], nnq.Linear):\n                input_q = torch.quantize_per_tensor(input[i % layer_num].to(torch.float32), scale, zero_point, torch.quint8)\n                t_output = projs[i % layer_num](input_q)\n            else:\n                t_output = torch.mm(input[i % layer_num].to(proj_type), projs[i % layer_num].t())\n        end = time.perf_counter()\n        total_time = end - start\n        print('Quant mode: ', quant_mode)\n        print('Time(s): ', total_time)\n        print('Iteration: ', test_iter) \n        print('Time(us) per iteration: ', total_time / test_iter * 1000000)\n        print('Bandwidth: ', input_size * output_size * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')\n        print('')\n\nbench_linear(\"fp32\")\nbench_linear(\"fp16\")\nbench_linear(\"bf16\")\nbench_linear(\"qint8\")\n"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/bench/bench_mlp.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : chenht2022\nDate         : 2024-07-16 10:43:18\nVersion      : 1.0.0\nLastEditors  : chenht2022 \nLastEditTime : 2024-08-06 10:36:04\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport os, sys\nimport time\nsys.path.append(os.path.dirname(__file__) + '/../build')\nimport cpuinfer_ext\nimport torch\n\nhidden_size = 5120\nintermediate_size = 3072\nstride = 16\ngroup_max_len = 1024\nlayer_num = 10\nqlen = 1\nCPUInfer = cpuinfer_ext.CPUInfer(64)\nwarm_up_iter = 1000\ntest_iter = 10000\n\ndef bench_mlp(quant_mode: str):\n    with torch.inference_mode(mode=True):\n\n        hidden_type = 30 # ggml_type::GGML_TYPE_BF16\n        if quant_mode == \"fp32\":\n            gate_type = 0 # ggml_type::GGML_TYPE_F32\n            up_type = 0 # ggml_type::GGML_TYPE_F32\n            down_type = 0 # ggml_type::GGML_TYPE_F32\n            bytes_per_elem = 4.000000\n        elif quant_mode == \"fp16\":\n            gate_type = 1 # ggml_type::GGML_TYPE_F16\n            up_type = 1 # ggml_type::GGML_TYPE_F16\n            down_type = 1 # ggml_type::GGML_TYPE_F16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"bf16\":\n            gate_type = 30 # ggml_type::GGML_TYPE_BF16\n            up_type = 30 # ggml_type::GGML_TYPE_BF16\n            down_type = 30 # ggml_type::GGML_TYPE_BF16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"q8_0\":\n            gate_type = 8 # ggml_type::GGML_TYPE_Q8_0\n            up_type = 8 # ggml_type::GGML_TYPE_Q8_0\n            down_type = 8 # ggml_type::GGML_TYPE_Q8_0\n            bytes_per_elem = 1.062500\n        elif quant_mode == \"q6_k\":\n            gate_type = 14 # ggml_type::GGML_TYPE_Q6_K\n            up_type = 14 # ggml_type::GGML_TYPE_Q6_K\n            down_type = 14 # ggml_type::GGML_TYPE_Q6_K\n            bytes_per_elem = 0.820312\n        elif quant_mode == \"q5_k_m\":\n            gate_type = 13 # 
ggml_type::GGML_TYPE_Q5_K\n            up_type = 13 # ggml_type::GGML_TYPE_Q5_K\n            down_type = 14 # ggml_type::GGML_TYPE_Q6_K\n            bytes_per_elem = 0.731771\n        elif quant_mode == \"q4_k_m\":\n            gate_type = 12 # ggml_type::GGML_TYPE_Q4_K\n            up_type = 12 # ggml_type::GGML_TYPE_Q4_K\n            down_type = 14 # ggml_type::GGML_TYPE_Q6_K\n            bytes_per_elem = 0.648437\n        elif quant_mode == \"q3_k_m\":\n            gate_type = 11 # ggml_type::GGML_TYPE_Q3_K\n            up_type = 11 # ggml_type::GGML_TYPE_Q3_K\n            down_type = 13 # ggml_type::GGML_TYPE_Q5_K\n            bytes_per_elem = 0.515625\n        elif quant_mode == \"q2_k\":\n            gate_type = 10 # ggml_type::GGML_TYPE_Q2_K\n            up_type = 10 # ggml_type::GGML_TYPE_Q2_K\n            down_type = 11 # ggml_type::GGML_TYPE_Q3_K\n            bytes_per_elem = 0.328125\n        elif quant_mode == \"iq3_xs\":\n            gate_type = 21 # ggml_type::GGML_TYPE_IQ3_S\n            up_type = 21 # ggml_type::GGML_TYPE_IQ3_S\n            down_type = 21 # ggml_type::GGML_TYPE_IQ3_S\n            bytes_per_elem = 0.429688\n        elif quant_mode == \"iq2_xxs\":\n            gate_type = 16 # ggml_type::GGML_TYPE_IQ2_XXS\n            up_type = 16 # ggml_type::GGML_TYPE_IQ2_XXS\n            down_type = 16 # ggml_type::GGML_TYPE_IQ2_XXS\n            bytes_per_elem = 0.257812\n        else:\n            assert(False)\n\n\n        mlps = []\n        gate_projs = []\n        up_projs = []\n        down_projs = []\n        for _ in range(layer_num):\n            gate_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            up_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            down_proj = torch.randn((hidden_size, intermediate_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n        
    config = cpuinfer_ext.mlp.MLPConfig(hidden_size, intermediate_size, stride, group_max_len, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(), gate_type, up_type, down_type, hidden_type)\n            mlp = cpuinfer_ext.mlp.MLP(config)\n            gate_projs.append(gate_proj)\n            up_projs.append(up_proj)\n            down_projs.append(down_proj)\n            mlps.append(mlp)\n        input = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = \"cuda\").to(\"cpu\").contiguous()\n        output = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = \"cuda\").to(\"cpu\").contiguous()\n\n        # warm up\n        for i in range(warm_up_iter):\n            CPUInfer.submit(\n                mlps[i % layer_num].forward( \n                    qlen, \n                    input[i % layer_num].data_ptr(), \n                    output[i % layer_num].data_ptr()\n                )\n            )\n            CPUInfer.sync()\n\n        # test\n        start = time.perf_counter()\n        for i in range(test_iter):\n            CPUInfer.submit(\n                mlps[i % layer_num].forward( \n                    qlen, \n                    input[i % layer_num].data_ptr(), \n                    output[i % layer_num].data_ptr()\n                )\n            )\n            CPUInfer.sync()\n        end = time.perf_counter()\n        total_time = end - start\n        print('Quant mode: ', quant_mode)\n        print('Time(s): ', total_time)\n        print('Iteration: ', test_iter) \n        print('Time(us) per iteration: ', total_time / test_iter * 1000000)\n        print('Bandwidth: ', hidden_size * intermediate_size * 3 * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')\n        print('')\n\nbench_mlp(\"fp32\")\nbench_mlp(\"fp16\")\nbench_mlp(\"bf16\")\nbench_mlp(\"q8_0\")\nbench_mlp(\"q6_k\")\nbench_mlp(\"q5_k_m\")\nbench_mlp(\"q4_k_m\")\nbench_mlp(\"q3_k_m\")\nbench_mlp(\"q2_k\")\n# Not 
supported on __x86_64__\n# bench_mlp(\"iq3_xs\")\n# bench_mlp(\"iq2_xxs\")\n"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/bench/bench_mlp_torch.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : chenht2022\nDate         : 2024-07-16 10:43:18\nVersion      : 1.0.0\nLastEditors  : chenht2022 \nLastEditTime : 2024-07-25 10:32:53\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport os, sys\nimport time\nimport torch\nimport torch.nn.quantized as nnq\n\nscale, zero_point = 0.1, 0  # Adjust scale and zero_point based on your dataset\n\nhidden_size = 5120\nintermediate_size = 3072\nlayer_num = 10\nqlen = 1\nwarm_up_iter = 1000\ntest_iter = 10000\n\ndef act_fn(x):\n    return x / (1.0 + torch.exp(-x))\n\ndef mlp_torch(input, gate_proj, up_proj, down_proj):\n    if isinstance(gate_proj, nnq.Linear):\n        input_q = torch.quantize_per_tensor(input.to(torch.float32), scale, zero_point, torch.quint8)\n        gate_buf = gate_proj(input_q)\n        up_buf = up_proj(input_q)\n        gate_buf = gate_buf.dequantize()\n        up_buf = up_buf.dequantize()\n        intermediate = act_fn(gate_buf) * up_buf\n        intermediate_q = torch.quantize_per_tensor(intermediate, scale, zero_point, torch.quint8)\n        expert_output = down_proj(intermediate_q)\n        ret = expert_output.dequantize()\n    else:\n        gate_buf = torch.mm(input.to(gate_proj.dtype), gate_proj.t())\n        up_buf = torch.mm(input.to(up_proj.dtype), up_proj.t())\n        intermediate = act_fn(gate_buf) * up_buf\n        ret = torch.mm(intermediate.to(down_proj.dtype), down_proj.t())\n    return ret\n\ndef bench_mlp(quant_mode: str):\n    with torch.inference_mode(mode=True):\n        if quant_mode == \"fp32\":\n            proj_type = torch.float32\n            bytes_per_elem = 4.000000\n        elif quant_mode == \"fp16\":\n            proj_type = torch.float16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"bf16\":\n            proj_type = torch.bfloat16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"qint8\":\n            proj_type = 
torch.qint8\n            bytes_per_elem = 1.000000\n        else:\n            assert(False)\n\n        gate_projs = []\n        up_projs = []\n        down_projs = []\n        for _ in range(layer_num):\n            gate_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            up_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            down_proj = torch.randn((hidden_size, intermediate_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            if quant_mode == \"qint8\":\n                gate_proj_q = torch.quantize_per_tensor(gate_proj, scale, zero_point, torch.qint8)\n                quantized_gate = nnq.Linear(hidden_size, intermediate_size)\n                quantized_gate.set_weight_bias(gate_proj_q, None)\n                up_proj_q = torch.quantize_per_tensor(up_proj, scale, zero_point, torch.qint8)\n                quantized_up = nnq.Linear(hidden_size, intermediate_size)\n                quantized_up.set_weight_bias(up_proj_q, None)\n                down_proj_q = torch.quantize_per_tensor(down_proj, scale, zero_point, torch.qint8)\n                quantized_down = nnq.Linear(intermediate_size, hidden_size)\n                quantized_down.set_weight_bias(down_proj_q, None)\n                gate_projs.append(quantized_gate)\n                up_projs.append(quantized_up)\n                down_projs.append(quantized_down)\n            else:\n                gate_projs.append(gate_proj.to(proj_type))\n                up_projs.append(up_proj.to(proj_type))\n                down_projs.append(down_proj.to(proj_type))\n        input = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = \"cuda\").to(\"cpu\").contiguous()\n\n        # warm up\n        for i in range(warm_up_iter):\n            mlp_torch(input[i % layer_num], gate_projs[i % layer_num], up_projs[i % 
layer_num], down_projs[i % layer_num])\n\n        # test\n        start = time.perf_counter()\n        for i in range(test_iter):\n            mlp_torch(input[i % layer_num], gate_projs[i % layer_num], up_projs[i % layer_num], down_projs[i % layer_num])\n        end = time.perf_counter()\n        total_time = end - start\n        print('Quant mode: ', quant_mode)\n        print('Time(s): ', total_time)\n        print('Iteration: ', test_iter) \n        print('Time(us) per iteration: ', total_time / test_iter * 1000000)\n        print('Bandwidth: ', hidden_size * intermediate_size * 3 * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')\n        print('')\n\nbench_mlp(\"fp32\")\nbench_mlp(\"fp16\")\nbench_mlp(\"bf16\")\nbench_mlp(\"qint8\")\n"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/bench/bench_moe.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : chenht2022\nDate         : 2024-07-25 10:32:05\nVersion      : 1.0.0\nLastEditors  : chenht2022 \nLastEditTime : 2024-08-06 10:41:28\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport os, sys\nimport time\nsys.path.append(os.path.dirname(__file__) + '/../build')\nimport cpuinfer_ext\nimport torch\n\nexpert_num = 160\nhidden_size = 5120\nintermediate_size = 1536\nstride = 16\ngroup_min_len = 10\ngroup_max_len = 1024\nn_routed_experts = 6\nlayer_num = 10\nqlen = 1\nCPUInfer = cpuinfer_ext.CPUInfer(64)\nwarm_up_iter = 1000\ntest_iter = 10000\n\ndef bench_moe(quant_mode: str):\n    with torch.inference_mode(mode=True):\n        hidden_type = 30 # ggml_type::GGML_TYPE_BF16\n        if quant_mode == \"fp32\":\n            gate_type = 0 # ggml_type::GGML_TYPE_F32\n            up_type = 0 # ggml_type::GGML_TYPE_F32\n            down_type = 0 # ggml_type::GGML_TYPE_F32\n            bytes_per_elem = 4.000000\n        elif quant_mode == \"fp16\":\n            gate_type = 1 # ggml_type::GGML_TYPE_F16\n            up_type = 1 # ggml_type::GGML_TYPE_F16\n            down_type = 1 # ggml_type::GGML_TYPE_F16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"bf16\":\n            gate_type = 30 # ggml_type::GGML_TYPE_BF16\n            up_type = 30 # ggml_type::GGML_TYPE_BF16\n            down_type = 30 # ggml_type::GGML_TYPE_BF16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"q8_0\":\n            gate_type = 8 # ggml_type::GGML_TYPE_Q8_0\n            up_type = 8 # ggml_type::GGML_TYPE_Q8_0\n            down_type = 8 # ggml_type::GGML_TYPE_Q8_0\n            bytes_per_elem = 1.062500\n        elif quant_mode == \"q6_k\":\n            gate_type = 14 # ggml_type::GGML_TYPE_Q6_K\n            up_type = 14 # ggml_type::GGML_TYPE_Q6_K\n            down_type = 14 # ggml_type::GGML_TYPE_Q6_K\n            bytes_per_elem = 0.820312\n        
elif quant_mode == \"q5_k_m\":\n            gate_type = 13 # ggml_type::GGML_TYPE_Q5_K\n            up_type = 13 # ggml_type::GGML_TYPE_Q5_K\n            down_type = 14 # ggml_type::GGML_TYPE_Q6_K\n            bytes_per_elem = 0.731771\n        elif quant_mode == \"q4_k_m\":\n            gate_type = 12 # ggml_type::GGML_TYPE_Q4_K\n            up_type = 12 # ggml_type::GGML_TYPE_Q4_K\n            down_type = 14 # ggml_type::GGML_TYPE_Q6_K\n            bytes_per_elem = 0.648437\n        elif quant_mode == \"q3_k_m\":\n            gate_type = 11 # ggml_type::GGML_TYPE_Q3_K\n            up_type = 11 # ggml_type::GGML_TYPE_Q3_K\n            down_type = 13 # ggml_type::GGML_TYPE_Q5_K\n            bytes_per_elem = 0.515625\n        elif quant_mode == \"q2_k\":\n            gate_type = 10 # ggml_type::GGML_TYPE_Q2_K\n            up_type = 10 # ggml_type::GGML_TYPE_Q2_K\n            down_type = 11 # ggml_type::GGML_TYPE_Q3_K\n            bytes_per_elem = 0.328125\n        elif quant_mode == \"iq3_xs\":\n            gate_type = 21 # ggml_type::GGML_TYPE_IQ3_S\n            up_type = 21 # ggml_type::GGML_TYPE_IQ3_S\n            down_type = 21 # ggml_type::GGML_TYPE_IQ3_S\n            bytes_per_elem = 0.429688\n        elif quant_mode == \"iq2_xxs\":\n            gate_type = 16 # ggml_type::GGML_TYPE_IQ2_XXS\n            up_type = 16 # ggml_type::GGML_TYPE_IQ2_XXS\n            down_type = 16 # ggml_type::GGML_TYPE_IQ2_XXS\n            bytes_per_elem = 0.257812\n        else:\n            assert(False)\n\n\n        moes = []\n        gate_projs = []\n        up_projs = []\n        down_projs = []\n        for _ in range(layer_num):\n            gate_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            up_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            down_proj = torch.randn((expert_num, 
hidden_size, intermediate_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            config = cpuinfer_ext.moe.MOEConfig(expert_num, n_routed_experts, hidden_size, intermediate_size, stride, group_min_len, group_max_len, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(), gate_type, up_type, down_type, hidden_type)\n            moe = cpuinfer_ext.moe.MOE(config)\n            gate_projs.append(gate_proj)\n            up_projs.append(up_proj)\n            down_projs.append(down_proj)\n            moes.append(moe)\n        expert_ids = torch.stack([torch.stack([torch.randperm(expert_num, dtype=torch.int64, device = \"cuda\")[:n_routed_experts] for _ in range(qlen)]) for _ in range(layer_num)]).to(\"cpu\").contiguous()\n        weights = torch.rand((layer_num, qlen, n_routed_experts), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n        input = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = \"cuda\").to(\"cpu\").contiguous()\n        output = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = \"cuda\").to(\"cpu\").contiguous()\n\n        # warm up\n        for i in range(warm_up_iter):\n            CPUInfer.submit(\n                moes[i % layer_num].forward( \n                    qlen, \n                    n_routed_experts, \n                    expert_ids[i % layer_num].data_ptr(), \n                    weights[i % layer_num].data_ptr(),\n                    input[i % layer_num].data_ptr(), \n                    output[i % layer_num].data_ptr()\n                )\n            )\n            CPUInfer.sync()\n\n        # test\n        start = time.perf_counter()\n        for i in range(test_iter):\n            CPUInfer.submit(\n                moes[i % layer_num].forward( \n                    qlen, \n                    n_routed_experts, \n                    expert_ids[i % layer_num].data_ptr(), \n                    weights[i % layer_num].data_ptr(),\n    
                input[i % layer_num].data_ptr(), \n                    output[i % layer_num].data_ptr()\n                )\n            )\n            CPUInfer.sync()\n        end = time.perf_counter()\n        total_time = end - start\n        print('Quant mode: ', quant_mode)\n        print('Time(s): ', total_time)\n        print('Iteration: ', test_iter) \n        print('Time(us) per iteration: ', total_time / test_iter * 1000000)\n        print('Bandwidth: ', hidden_size * intermediate_size * 3 * n_routed_experts * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')\n        print('')\n\nbench_moe(\"fp32\")\nbench_moe(\"fp16\")\nbench_moe(\"bf16\")\nbench_moe(\"q8_0\")\nbench_moe(\"q6_k\")\nbench_moe(\"q5_k_m\")\nbench_moe(\"q4_k_m\")\nbench_moe(\"q3_k_m\")\nbench_moe(\"q2_k\")\n# Not supported on __x86_64__\n# bench_moe(\"iq3_xs\")\n# bench_moe(\"iq2_xxs\")\n"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/bench/bench_moe_amx.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : chenht2022\nDate         : 2025-04-25 18:28:12\nVersion      : 1.0.0\nLastEditors  : chenht2022 \nLastEditTime : 2025-04-25 18:28:12\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport os, sys\nimport time\nsys.path.append(os.path.dirname(__file__) + '/../build')\nimport cpuinfer_ext\nimport torch\n\nexpert_num = 8\nhidden_size = 7168\nintermediate_size = 2048\nmax_len = 25600\nn_routed_experts = 8\nlayer_num = 10\nqlen = 1024\nCPUInfer = cpuinfer_ext.CPUInfer(65)\nwarm_up_iter = 100\ntest_iter = 100\n\ndef bench_moe(quant_mode: str):\n    with torch.inference_mode(mode=True):\n        if quant_mode == \"bf16\":\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"int8\":\n            bytes_per_elem = 1.000000\n        else:\n            assert(False)\n\n\n        moes = []\n        gate_projs = []\n        up_projs = []\n        down_projs = []\n        for _ in range(layer_num):\n            gate_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            up_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            down_proj = torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            config = cpuinfer_ext.moe.AMX_MOEConfig(expert_num, n_routed_experts, hidden_size, intermediate_size, max_len, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr())\n            if quant_mode == \"bf16\":\n                moe = cpuinfer_ext.moe.AMXBF16_MOE(config)\n                CPUInfer.submit(moe.load_weights())\n                CPUInfer.sync()\n            elif quant_mode == \"int8\":\n                moe = cpuinfer_ext.moe.AMXInt8_MOE(config)\n                CPUInfer.submit(moe.load_weights())\n       
         CPUInfer.sync()\n            gate_projs.append(gate_proj)\n            up_projs.append(up_proj)\n            down_projs.append(down_proj)\n            moes.append(moe)\n        expert_ids = torch.stack([torch.stack([torch.randperm(expert_num, dtype=torch.int64, device = \"cuda\")[:n_routed_experts] for _ in range(qlen)]) for _ in range(layer_num)]).to(\"cpu\").contiguous()\n        weights = torch.rand((layer_num, qlen, n_routed_experts), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n        input = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = \"cuda\").to(\"cpu\").contiguous()\n        output = torch.empty((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = \"cuda\").to(\"cpu\").contiguous()\n        qlen_tensor = torch.tensor([qlen], dtype=torch.int32)\n\n        # warm up\n        for i in range(warm_up_iter):\n            CPUInfer.submit(\n                moes[i % layer_num].forward( \n                    qlen, \n                    n_routed_experts, \n                    expert_ids[i % layer_num].data_ptr(), \n                    weights[i % layer_num].data_ptr(),\n                    input[i % layer_num].data_ptr(), \n                    output[i % layer_num].data_ptr(),\n                    qlen_tensor.data_ptr()\n                )\n            )\n            CPUInfer.sync()\n\n        # test\n        start = time.perf_counter()\n        for i in range(test_iter):\n            CPUInfer.submit(\n                moes[i % layer_num].forward( \n                    qlen, \n                    n_routed_experts, \n                    expert_ids[i % layer_num].data_ptr(), \n                    weights[i % layer_num].data_ptr(),\n                    input[i % layer_num].data_ptr(), \n                    output[i % layer_num].data_ptr(),\n                    qlen_tensor.data_ptr()\n                )\n            )\n            CPUInfer.sync()\n        end = time.perf_counter()\n        total_time = 
end - start\n        print('Quant mode: ', quant_mode)\n        print('Time(s): ', total_time)\n        print('Iteration: ', test_iter) \n        print('Time(us) per iteration: ', total_time / test_iter * 1000000)\n        print('Bandwidth: ', hidden_size * intermediate_size * 3 * n_routed_experts * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')\n        print('Flops: ', hidden_size * intermediate_size * qlen * 3 * n_routed_experts * 2 * test_iter / total_time / 1000 / 1000 / 1000, 'GFLOPS')\n        print('')\n\nbench_moe(\"bf16\")\nbench_moe(\"int8\")\n"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/bench/bench_moe_torch.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : chenht2022\nDate         : 2024-07-25 10:32:05\nVersion      : 1.0.0\nLastEditors  : chenht2022 \nLastEditTime : 2024-07-25 10:32:57\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport os, sys\nimport time\nimport torch\nimport torch.nn.quantized as nnq\n\nscale, zero_point = 0.1, 0  # Adjust scale and zero_point based on your dataset\n\nexpert_num = 160\nhidden_size = 5120\nintermediate_size = 1536\nn_routed_experts = 6\nlayer_num = 10\nqlen = 1\nwarm_up_iter = 1000\ntest_iter = 10000\n\ndef act_fn(x):\n    return x / (1.0 + torch.exp(-x))\n\ndef mlp_torch(input, gate_proj, up_proj, down_proj):\n    if isinstance(gate_proj, nnq.Linear):\n        input_q = torch.quantize_per_tensor(input.to(torch.float32), scale, zero_point, torch.quint8)\n        gate_buf = gate_proj(input_q)\n        up_buf = up_proj(input_q)\n        gate_buf = gate_buf.dequantize()\n        up_buf = up_buf.dequantize()\n        intermediate = act_fn(gate_buf) * up_buf\n        intermediate_q = torch.quantize_per_tensor(intermediate, scale, zero_point, torch.quint8)\n        expert_output = down_proj(intermediate_q)\n        ret = expert_output.dequantize()\n    else:\n        gate_buf = torch.mm(input.to(gate_proj.dtype), gate_proj.t())\n        up_buf = torch.mm(input.to(up_proj.dtype), up_proj.t())\n        intermediate = act_fn(gate_buf) * up_buf\n        ret = torch.mm(intermediate.to(down_proj.dtype), down_proj.t())\n    return ret\n\ndef moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):\n    cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))\n    cnts.scatter_(1, expert_ids, 1)\n    tokens_per_expert = cnts.sum(dim=0)\n    idxs = expert_ids.view(-1).argsort()\n    sorted_tokens = input[idxs // expert_ids.shape[1]]\n\n    outputs = []\n    start_idx = 0\n    for i, num_tokens in enumerate(tokens_per_expert):\n        end_idx = start_idx + 
num_tokens\n        if num_tokens == 0:\n            continue\n        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n        expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])\n        outputs.append(expert_out)\n        start_idx = end_idx\n\n    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n    new_x = torch.empty_like(outs)\n    new_x[idxs] = outs\n    t_output = (\n        new_x.view(*expert_ids.shape, -1)\n        .type(weights.dtype)\n        .mul_(weights.unsqueeze(dim=-1))\n        .sum(dim=1)\n        .type(new_x.dtype)\n    )\n    return t_output\n\ndef bench_moe(quant_mode: str):\n    with torch.inference_mode(mode=True):\n        if quant_mode == \"fp32\":\n            proj_type = torch.float32\n            bytes_per_elem = 4.000000\n        elif quant_mode == \"fp16\":\n            proj_type = torch.float16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"bf16\":\n            proj_type = torch.bfloat16\n            bytes_per_elem = 2.000000\n        elif quant_mode == \"qint8\":\n            proj_type = torch.qint8\n            bytes_per_elem = 1.000000\n        else:\n            assert(False)\n\n        gate_projs = []\n        up_projs = []\n        down_projs = []\n        for _ in range(layer_num):\n            gate_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            up_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            down_proj = torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n            if quant_mode == \"qint8\":\n                quantized_gate_proj = []\n                quantized_up_proj = []\n                quantized_down_proj = []\n                for i in range(expert_num):\n    
                gate_proj_q = torch.quantize_per_tensor(gate_proj[i], scale, zero_point, torch.qint8)\n                    quantized_gate = nnq.Linear(hidden_size, intermediate_size)\n                    quantized_gate.set_weight_bias(gate_proj_q, None)\n                    quantized_gate_proj.append(quantized_gate)\n                    up_proj_q = torch.quantize_per_tensor(up_proj[i], scale, zero_point, torch.qint8)\n                    quantized_up = nnq.Linear(hidden_size, intermediate_size)\n                    quantized_up.set_weight_bias(up_proj_q, None)\n                    quantized_up_proj.append(quantized_up)\n                    down_proj_q = torch.quantize_per_tensor(down_proj[i], scale, zero_point, torch.qint8)\n                    quantized_down = nnq.Linear(intermediate_size, hidden_size)\n                    quantized_down.set_weight_bias(down_proj_q, None)\n                    quantized_down_proj.append(quantized_down)\n                gate_projs.append(quantized_gate_proj)\n                up_projs.append(quantized_up_proj)\n                down_projs.append(quantized_down_proj)\n            else:\n                gate_projs.append(gate_proj.to(proj_type))\n                up_projs.append(up_proj.to(proj_type))\n                down_projs.append(down_proj.to(proj_type))\n        expert_ids = torch.stack([torch.stack([torch.randperm(expert_num, dtype=torch.int64, device = \"cuda\")[:n_routed_experts] for _ in range(qlen)]) for _ in range(layer_num)]).to(\"cpu\").contiguous()\n        weights = torch.rand((layer_num, qlen, n_routed_experts), dtype=torch.float32, device = \"cuda\").to(\"cpu\").contiguous()\n        input = torch.randn((layer_num, qlen, hidden_size), dtype=torch.bfloat16, device = \"cuda\").to(\"cpu\").contiguous()\n\n        # warm up\n        for i in range(warm_up_iter):\n            moe_torch(input[i % layer_num], expert_ids[i % layer_num], weights[i % layer_num], gate_projs[i % layer_num], up_projs[i % layer_num], down_projs[i % 
layer_num])\n\n        # test\n        start = time.perf_counter()\n        for i in range(test_iter):\n            moe_torch(input[i % layer_num], expert_ids[i % layer_num], weights[i % layer_num], gate_projs[i % layer_num], up_projs[i % layer_num], down_projs[i % layer_num])\n        end = time.perf_counter()\n        total_time = end - start\n        print('Quant mode: ', quant_mode)\n        print('Time(s): ', total_time)\n        print('Iteration: ', test_iter) \n        print('Time(us) per iteration: ', total_time / test_iter * 1000000)\n        print('Bandwidth: ', hidden_size * intermediate_size * 3 * n_routed_experts * bytes_per_elem * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')\n        print('')\n\nbench_moe(\"fp32\")\nbench_moe(\"fp16\")\nbench_moe(\"bf16\")\nbench_moe(\"qint8\")\n"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/cmake/FindSIMD.cmake",
    "content": "include(CheckCSourceRuns)\n\nset(AVX_CODE \"\n    #include <immintrin.h>\n    int main()\n    {\n        __m256 a;\n        a = _mm256_set1_ps(0);\n        return 0;\n    }\n\")\n\nset(AVX512_CODE \"\n    #include <immintrin.h>\n    int main()\n    {\n        __m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0,\n                                    0, 0, 0, 0, 0, 0, 0, 0,\n                                    0, 0, 0, 0, 0, 0, 0, 0,\n                                    0, 0, 0, 0, 0, 0, 0, 0,\n                                    0, 0, 0, 0, 0, 0, 0, 0,\n                                    0, 0, 0, 0, 0, 0, 0, 0,\n                                    0, 0, 0, 0, 0, 0, 0, 0,\n                                    0, 0, 0, 0, 0, 0, 0, 0);\n        __m512i b = a;\n        __mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ);\n        return 0;\n    }\n\")\n\nset(AVX2_CODE \"\n    #include <immintrin.h>\n    int main()\n    {\n        __m256i a = {0};\n        a = _mm256_abs_epi16(a);\n        __m256i x;\n        _mm256_extract_epi64(x, 0); // we rely on this in our AVX2 code\n        return 0;\n    }\n\")\n\nset(FMA_CODE \"\n    #include <immintrin.h>\n    int main()\n    {\n        __m256 acc = _mm256_setzero_ps();\n        const __m256 d = _mm256_setzero_ps();\n        const __m256 p = _mm256_setzero_ps();\n        acc = _mm256_fmadd_ps( d, p, acc );\n        return 0;\n    }\n\")\n\nmacro(check_sse type flags)\n    set(__FLAG_I 1)\n    set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})\n    foreach (__FLAG ${flags})\n        if (NOT ${type}_FOUND)\n            set(CMAKE_REQUIRED_FLAGS ${__FLAG})\n            check_c_source_runs(\"${${type}_CODE}\" HAS_${type}_${__FLAG_I})\n            if (HAS_${type}_${__FLAG_I})\n                set(${type}_FOUND TRUE CACHE BOOL \"${type} support\")\n                set(${type}_FLAGS \"${__FLAG}\" CACHE STRING \"${type} flags\")\n            endif()\n            math(EXPR __FLAG_I \"${__FLAG_I}+1\")\n   
     endif()\n    endforeach()\n    set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})\n\n    if (NOT ${type}_FOUND)\n        set(${type}_FOUND FALSE CACHE BOOL \"${type} support\")\n        set(${type}_FLAGS \"\" CACHE STRING \"${type} flags\")\n    endif()\n\n    mark_as_advanced(${type}_FOUND ${type}_FLAGS)\nendmacro()\n\n# flags are for MSVC only!\ncheck_sse(\"AVX\" \" ;/arch:AVX\")\nif (NOT ${AVX_FOUND})\n    set(LLAMA_AVX OFF)\nelse()\n    set(LLAMA_AVX ON)\nendif()\n\ncheck_sse(\"AVX2\" \" ;/arch:AVX2\")\ncheck_sse(\"FMA\" \" ;/arch:AVX2\")\nif ((NOT ${AVX2_FOUND}) OR (NOT ${FMA_FOUND}))\n    set(LLAMA_AVX2 OFF)\nelse()\n    set(LLAMA_AVX2 ON)\nendif()\n\ncheck_sse(\"AVX512\" \" ;/arch:AVX512\")\nif (NOT ${AVX512_FOUND})\n    set(LLAMA_AVX512 OFF)\nelse()\n    set(LLAMA_AVX512 ON)\nendif()\n"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/cpu_backend/backend.cpp",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-22 02:03:05\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2024-07-25 10:33:34\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n\n#include \"backend.h\"\n\n#ifdef USE_NUMA\n#include <numa.h>\n#include <numaif.h>\n\nthread_local int Backend::numa_node = -1;\n#endif\n\nthread_local int Backend::thread_local_id = -1;\n\nBackend::Backend(int max_thread_num) {\n    max_thread_num_ = max_thread_num;\n    thread_state_.resize(max_thread_num_);\n    for (int i = 0; i < max_thread_num_; i++) {\n        thread_state_[i].curr = std::make_unique<std::atomic<int>>();\n        thread_state_[i].status =\n            std::make_unique<std::atomic<ThreadStatus>>(ThreadStatus::WAITING);\n    }\n    workers_.resize(max_thread_num_);\n    for (int i = 1; i < max_thread_num_; i++) {\n        workers_[i] = std::thread(&Backend::worker_thread, this, i);\n    }\n}\n\nBackend::~Backend() {\n    for (int i = 0; i < max_thread_num_; i++) {\n        thread_state_[i].status->store(ThreadStatus::EXIT,\n                                       std::memory_order_release);\n    }\n    for (int i = 1; i < max_thread_num_; i++) {\n        if (workers_[i].joinable()) {\n            workers_[i].join();\n        }\n    }\n}\n\nint Backend::get_thread_num() { return max_thread_num_; }\n\nvoid Backend::do_work_stealing_job(int task_num,\n                                   std::function<void(int)> init_func,\n                                   std::function<void(int)> compute_func,\n                                   std::function<void(int)> finalize_func) {\n    init_func_ = init_func;\n    compute_func_ = compute_func;\n    finalize_func_ = finalize_func;\n#ifdef USE_NUMA\n    // numa node location will be calculated based on the number of threads\n    thread_num_ = max_thread_num_;\n#else\n    thread_num_ = std::min(max_thread_num_, task_num);\n#endif\n    int 
base = task_num / thread_num_;\n    int remain = task_num % thread_num_;\n    thread_state_[0].end = base + (0 < remain);\n\n    // Set thread_local_id for the main thread\n    thread_local_id = 0;\n\n    for (int i = 1; i < thread_num_; i++) {\n        thread_state_[i].curr->store(thread_state_[i - 1].end,\n                                     std::memory_order_relaxed);\n        thread_state_[i].end = thread_state_[i - 1].end + base + (i < remain);\n        thread_state_[i].status->store(ThreadStatus::WORKING,\n                                       std::memory_order_release);\n    }\n    thread_state_[0].curr->store(0, std::memory_order_relaxed);\n    thread_state_[0].status->store(ThreadStatus::WORKING,\n                                   std::memory_order_release);\n    process_tasks(0);\n    for (int i = 1; i < thread_num_; i++) {\n        while (thread_state_[i].status->load(std::memory_order_acquire) ==\n               ThreadStatus::WORKING) {\n        }\n    }\n}\n\nvoid Backend::process_tasks(int thread_id) {\n    \n    #ifdef USE_NUMA\n    if(numa_node == -1){\n        numa_node = thread_id * numa_num_configured_nodes() / thread_num_;\n        struct bitmask* mask = numa_bitmask_alloc(numa_num_configured_nodes());\n        numa_bitmask_setbit(mask, numa_node);\n        numa_bind(mask);\n        numa_bitmask_free(mask); // release the temporary mask once the binding is applied\n    }\n    #endif\n\n    if (init_func_ != nullptr) {\n        init_func_(thread_id);\n    }\n    while (true) {\n        int task_id = thread_state_[thread_id].curr->fetch_add(\n            1, std::memory_order_acq_rel);\n        if (task_id >= thread_state_[thread_id].end) {\n            break;\n        }\n        compute_func_(task_id);\n    }\n    for (int t_offset = 1; t_offset < thread_num_; t_offset++) {\n        int t_i = (thread_id + t_offset) % thread_num_;\n        if (thread_state_[t_i].status->load(std::memory_order_acquire) !=\n            ThreadStatus::WORKING) {\n            continue;\n        }\n        while (true) {\n            int task_id = 
thread_state_[t_i].curr->fetch_add(\n                1, std::memory_order_acq_rel);\n            if (task_id >= thread_state_[t_i].end) {\n                break;\n            }\n            compute_func_(task_id);\n        }\n    }\n    if (finalize_func_ != nullptr) {\n        finalize_func_(thread_id);\n    }\n    thread_state_[thread_id].status->store(ThreadStatus::WAITING,\n                                           std::memory_order_release);\n}\n\nvoid Backend::worker_thread(int thread_id) {\n    auto start = std::chrono::steady_clock::now();\n    thread_local_id = thread_id; // Set the thread-local variable\n    while (true) {\n        ThreadStatus status =\n            thread_state_[thread_id].status->load(std::memory_order_acquire);\n        if (status == ThreadStatus::WORKING) {\n            process_tasks(thread_id);\n            start = std::chrono::steady_clock::now();\n        } else if (status == ThreadStatus::WAITING) {\n            auto now = std::chrono::steady_clock::now();\n            auto duration =\n                std::chrono::duration_cast<std::chrono::milliseconds>(now -\n                                                                      start)\n                    .count();\n            if (duration > 50) {\n                std::this_thread::sleep_for(std::chrono::milliseconds(1));\n            }\n        } else if (status == ThreadStatus::EXIT) {\n            return;\n        }\n    }\n}"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/cpu_backend/backend.h",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-22 02:03:05\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2024-07-25 10:33:38\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#ifndef CPUINFER_BACKEND_H\n#define CPUINFER_BACKEND_H\n\n#include <atomic>\n#include <condition_variable>\n#include <cstdio>\n#include <functional>\n#include <mutex>\n#include <thread>\n#include <vector>\n\nenum ThreadStatus {\n    WORKING,\n    WAITING,\n    EXIT,\n};\n\nstruct ThreadState {\n    std::unique_ptr<std::atomic<ThreadStatus>> status;\n    std::unique_ptr<std::atomic<int>> curr;\n    int end;\n};\n\nclass Backend {\n  public:\n    Backend(int);\n    ~Backend();\n    int get_thread_num();\n    void do_work_stealing_job(int, std::function<void(int)>,\n                              std::function<void(int)>,\n                              std::function<void(int)>);\n    #ifdef USE_NUMA\n    static thread_local int numa_node;\n    #endif\n    static thread_local int thread_local_id;\n\n  private:\n    int thread_num_;\n    int max_thread_num_;\n    std::vector<ThreadState> thread_state_; // [thread_num]\n    std::function<void(int)> init_func_;\n    std::function<void(int)> compute_func_;\n    std::function<void(int)> finalize_func_;\n    std::vector<std::thread> workers_;\n\n    void process_tasks(int);\n    void worker_thread(int);\n};\n#endif"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/cpu_backend/cpuinfer.h",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-16 10:43:18\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2024-08-07 09:47:43\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n #ifndef CPUINFER_CPUINFER_H\n #define CPUINFER_CPUINFER_H\n \n #include <atomic>\n #include <condition_variable>\n #include <functional>\n #include <mutex>\n #include <queue>\n #include <thread>\n #include <vector>\n #include <stdexcept>\n #ifdef KTRANSFORMERS_USE_CUDA\n #include \"vendors/cuda.h\"\n #elif KTRANSFORMERS_USE_MUSA\n #include \"vendors/musa.h\"\n #elif KTRANSFORMERS_USE_ROCM\n #define __HIP_PLATFORM_AMD__\n #include \"vendors/hip.h\"\n #endif\n \n #include \"backend.h\"\n #include \"task_queue.h\"\n #include \"./vendors/vendor.h\"\n \n #include \"llama.cpp/ggml-impl.h\"\n \n class CPUInfer {\n    public:\n     CPUInfer(int thread_num) {\n         backend_ = new Backend(thread_num - 1);\n         task_queue_ = new TaskQueue();\n         for (int i = 0; i < (1 << 16); ++i) {\n             ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(i);\n         }\n     }\n \n     ~CPUInfer() {\n         delete backend_;\n         delete task_queue_;\n     }\n \n     template <typename Func, typename Obj, typename... Args>\n     void enqueue(Func f, Obj* obj, Args... 
args) {\n         task_queue_->enqueue([=]() {\n             std::invoke(f, *obj, args..., backend_);\n         });\n     }\n \n     void submit(std::pair<intptr_t, intptr_t> params) {\n         void (*func)(void*) = (void (*)(void*))params.first;\n         void* args = (void*)params.second;\n         *((CPUInfer**)args) = this;\n         func(args);\n     }\n \n     void sync() {\n         task_queue_->sync();\n     }\n \n     void submit_with_cuda_stream(intptr_t user_cuda_stream, std::pair<intptr_t, intptr_t> params) {\n        #if defined(KTRANSFORMERS_USE_CUDA) || defined(KTRANSFORMERS_USE_MUSA) || defined(KTRANSFORMERS_USE_ROCM)\n         void (*func)(void*) = (void (*)(void*))params.first;\n         void* args = (void*)params.second;\n         *((CPUInfer**)args) = this;\n         cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)func, args);\n        #else\n         throw std::runtime_error(\"submit_with_cuda_stream is not supported on this platform\");\n        #endif\n     }\n \n     static void sync_(void* cpu_infer_ptr) {\n         CPUInfer* cpuinfer = (CPUInfer*)cpu_infer_ptr;\n         cpuinfer->sync();\n     }\n \n     void sync_with_cuda_stream(intptr_t user_cuda_stream) {\n        #if defined(KTRANSFORMERS_USE_CUDA) || defined(KTRANSFORMERS_USE_MUSA) || defined(KTRANSFORMERS_USE_ROCM)\n         cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)&sync_, (void*)this);\n        #else\n         throw std::runtime_error(\"sync_with_cuda_stream is not supported on this platform\");\n        #endif\n     }\n \n    public:\n     Backend* backend_;\n     TaskQueue* task_queue_;\n };\n \n #endif"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/cpu_backend/shared_mem_buffer.cpp",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-08-05 04:49:08\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022 \n * @LastEditTime : 2024-08-05 09:21:29\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#include \"shared_mem_buffer.h\"\n#include <cstdio>\n\nSharedMemBuffer::SharedMemBuffer() {\n    buffer_ = nullptr;\n    size_ = 0;\n}\n\nSharedMemBuffer::~SharedMemBuffer() {\n    if (buffer_) {\n        free(buffer_);\n    }\n}\n\nvoid SharedMemBuffer::alloc(void* object, std::vector<std::pair<void**, uint64_t>> requests) {\n    uint64_t size = 0;\n    for (auto& request : requests) {\n        size += request.second;\n    }\n    if (size > size_) {\n        if (buffer_) {\n            free(buffer_);\n        }\n        buffer_ = std::aligned_alloc(64, size);\n\n        size_ = size;\n        for (auto& obj_requests : hist_requests_) {\n            for (auto& requests : obj_requests.second) {\n                arrange(requests);\n            }\n        }\n    }\n    arrange(requests);\n    hist_requests_[object].push_back(requests);\n}\n\nvoid SharedMemBuffer::dealloc(void* object) {\n    hist_requests_.erase(object);\n}\n\nvoid SharedMemBuffer::arrange(std::vector<std::pair<void**, uint64_t>> requests) {\n    uint64_t offset = 0;\n    for (auto& request : requests) {\n        *(request.first) = (uint8_t*)buffer_ + offset;\n        offset += request.second;\n    }\n}"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/cpu_backend/shared_mem_buffer.h",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-08-05 04:49:08\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022 \n * @LastEditTime : 2024-08-05 06:36:41\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n\n #ifndef CPUINFER_SHAREDMEMBUFFER_H\n #define CPUINFER_SHAREDMEMBUFFER_H\n \n #include <cstdint>\n #include <cstdlib>\n #include <map>\n #include <vector>\n \n class SharedMemBuffer {\n    public:\n     SharedMemBuffer();\n     ~SharedMemBuffer();\n \n     void alloc(void* object, std::vector<std::pair<void**, uint64_t>> requests);\n     void dealloc(void* object);\n \n    private:\n     void* buffer_;\n     uint64_t size_;\n     std::map<void*, std::vector<std::vector<std::pair<void**, uint64_t>>>> hist_requests_;\n \n     void arrange(std::vector<std::pair<void**, uint64_t>> requests);\n };\n \n static SharedMemBuffer shared_mem_buffer;\n \n #endif"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/cpu_backend/task_queue.cpp",
    "content": "/**\n * @Description :\n * @Author    : chenht2022\n * @Date     : 2024-07-17 12:25:51\n * @Version   : 1.0.0\n * @LastEditors : chenht2022\n * @LastEditTime : 2024-10-09 11:08:10\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#include \"task_queue.h\"\n\nTaskQueue::TaskQueue() {\n    worker = std::thread(&TaskQueue::processTasks, this);\n    sync_flag.store(true, std::memory_order_seq_cst);\n    exit_flag.store(false, std::memory_order_seq_cst);\n}\n\nTaskQueue::~TaskQueue() {\n    {\n        mutex.lock();\n        exit_flag.store(true, std::memory_order_seq_cst);\n        mutex.unlock();\n    }\n    cv.notify_all();\n    if (worker.joinable()) {\n        worker.join();\n    }\n}\n\nvoid TaskQueue::enqueue(std::function<void()> task) {\n    {\n        mutex.lock();\n        tasks.push(task);\n        sync_flag.store(false, std::memory_order_seq_cst);\n        mutex.unlock();\n    }\n    cv.notify_one();\n}\n\nvoid TaskQueue::sync() {\n    while (!sync_flag.load(std::memory_order_seq_cst))\n        ;\n}\n\nvoid TaskQueue::processTasks() {\n    while (true) {\n        std::function<void()> task;\n        {\n            mutex.lock();\n            cv.wait(mutex, [this]() { return !tasks.empty() || exit_flag.load(std::memory_order_seq_cst); });\n            if (exit_flag.load(std::memory_order_seq_cst) && tasks.empty()) {\n                return;\n            }\n            task = tasks.front();\n            tasks.pop();\n            mutex.unlock();\n        }\n        task();\n        {\n            mutex.lock();\n            if (tasks.empty()) {\n                sync_flag.store(true, std::memory_order_seq_cst);\n            }\n            mutex.unlock();\n        }\n    }\n}"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/cpu_backend/task_queue.h",
    "content": "/**\n * @Description :\n * @Author    : chenht2022\n * @Date     : 2024-07-16 10:43:18\n * @Version   : 1.0.0\n * @LastEditors : chenht\n * @LastEditTime : 2024-10-09 11:08:07\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#ifndef CPUINFER_TASKQUEUE_H\n#define CPUINFER_TASKQUEUE_H\n\n#include <atomic>\n#include <condition_variable>\n#include <functional>\n#include <mutex>\n#include <queue>\n#include <thread>\n#include <vector>\n#ifdef _WIN32\n#include <windows.h>\n#endif\n\nclass custom_mutex {\n   private:\n#ifdef _WIN32\n    CRITICAL_SECTION cs;\n#else\n    std::mutex mtx;\n#endif\n\n   public:\n    custom_mutex() {\n#ifdef _WIN32\n        InitializeCriticalSection(&cs);\n#else\n        // No initialization required for std::mutex\n#endif\n    }\n\n    ~custom_mutex() {\n#ifdef _WIN32\n        DeleteCriticalSection(&cs);\n#endif\n    }\n\n    void lock() {\n#ifdef _WIN32\n        EnterCriticalSection(&cs);\n#else\n        mtx.lock();\n#endif\n    }\n\n    void unlock() {\n#ifdef _WIN32\n        LeaveCriticalSection(&cs);\n#else\n        mtx.unlock();\n#endif\n    }\n\n#ifdef _WIN32\n    CRITICAL_SECTION* get_handle() {\n        return &cs;\n    }\n#else\n    std::mutex* get_handle() {\n        return &mtx;\n    }\n#endif\n};\n\nclass custom_condition_variable {\n   private:\n#ifdef _WIN32\n    CONDITION_VARIABLE cond_var;\n#else\n    std::condition_variable cond_var;\n#endif\n\n   public:\n    custom_condition_variable() {\n#ifdef _WIN32\n        InitializeConditionVariable(&cond_var);\n#endif\n    }\n\n    template <typename Predicate>\n    void wait(custom_mutex& mutex, Predicate pred) {\n#ifdef _WIN32\n        while (!pred()) {\n            SleepConditionVariableCS(&cond_var, mutex.get_handle(), INFINITE);\n        }\n#else\n        std::unique_lock<std::mutex> lock(*mutex.get_handle(), std::adopt_lock);\n        cond_var.wait(lock, pred);\n        lock.release();\n#endif\n    }\n\n    void notify_one() {\n#ifdef _WIN32\n    
    WakeConditionVariable(&cond_var);\n#else\n        cond_var.notify_one();\n#endif\n    }\n\n    void notify_all() {\n#ifdef _WIN32\n        WakeAllConditionVariable(&cond_var);\n#else\n        cond_var.notify_all();\n#endif\n    }\n};\n\nclass TaskQueue {\n   public:\n    TaskQueue();\n    ~TaskQueue();\n\n    void enqueue(std::function<void()>);\n\n    void sync();\n\n   private:\n    void processTasks();\n\n    std::queue<std::function<void()>> tasks;\n    custom_mutex mutex;\n    custom_condition_variable cv;\n    std::thread worker;\n    std::atomic<bool> sync_flag;\n    std::atomic<bool> exit_flag;\n};\n#endif"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/cpu_backend/vendors/README.md",
    "content": "## TODO\n\nThis directory can be removed after updating the version of `llama.cpp`."
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/cpu_backend/vendors/cuda.h",
    "content": "#pragma once\n\n#include <cuda_runtime.h>\n#include <cuda.h>\n#include <cublas_v2.h>\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n\n#if CUDART_VERSION < 11020\n#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED\n#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH\n#define CUBLAS_COMPUTE_16F CUDA_R_16F\n#define CUBLAS_COMPUTE_32F CUDA_R_32F\n#define cublasComputeType_t cudaDataType_t\n#endif // CUDART_VERSION < 11020\n"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/cpu_backend/vendors/hip.h",
    "content": "#pragma once\n\n#define HIP_ENABLE_WARP_SYNC_BUILTINS 1\n#include <hip/hip_runtime.h>\n#include <hipblas/hipblas.h>\n#include <hip/hip_fp16.h>\n#include <hip/hip_bfloat16.h>\n#ifdef __HIP_PLATFORM_AMD__\n// for rocblas_initialize()\n#include \"rocblas/rocblas.h\"\n#endif // __HIP_PLATFORM_AMD__\n\n#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F\n#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F\n#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F\n#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT\n#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT\n#define CUBLAS_OP_N HIPBLAS_OP_N\n#define CUBLAS_OP_T HIPBLAS_OP_T\n#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS\n#define CUBLAS_TF32_TENSOR_OP_MATH 0\n#define CUDA_R_16F  HIPBLAS_R_16F\n#define CUDA_R_32F  HIPBLAS_R_32F\n#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported\n#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended\n#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned\n#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice\n#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite\n#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT(\"HipVMM Failure: %s\\n\", hipGetErrorString(err)); }}\n#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)\n#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)\n#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6\n#define cublasCreate hipblasCreate\n#define cublasDestroy hipblasDestroy\n#define cublasGemmEx hipblasGemmEx\n#define cublasGemmBatchedEx hipblasGemmBatchedEx\n#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx\n#define cublasHandle_t hipblasHandle_t\n#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS\n#define cublasSetStream hipblasSetStream\n#define cublasSgemm hipblasSgemm\n#define 
cublasStatus_t hipblasStatus_t\n#define cublasOperation_t hipblasOperation_t\n#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6\n#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer\n#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess\n#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess\n#define cudaDeviceProp hipDeviceProp_t\n#define cudaDeviceSynchronize hipDeviceSynchronize\n#define cudaError_t hipError_t\n#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled\n#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled\n#define cudaEventCreateWithFlags hipEventCreateWithFlags\n#define cudaEventDisableTiming hipEventDisableTiming\n#define cudaEventRecord hipEventRecord\n#define cudaEventSynchronize hipEventSynchronize\n#define cudaEvent_t hipEvent_t\n#define cudaEventDestroy hipEventDestroy\n#define cudaFree hipFree\n#define cudaFreeHost hipHostFree\n#define cudaGetDevice hipGetDevice\n#define cudaGetDeviceCount hipGetDeviceCount\n#define cudaGetDeviceProperties hipGetDeviceProperties\n#define cudaGetErrorString hipGetErrorString\n#define cudaGetLastError hipGetLastError\n#define cudaHostRegister hipHostRegister\n#define cudaHostRegisterPortable hipHostRegisterPortable\n#define cudaHostRegisterReadOnly hipHostRegisterReadOnly\n#define cudaHostUnregister hipHostUnregister\n#define cudaLaunchHostFunc hipLaunchHostFunc\n#define cudaMalloc hipMalloc\n#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)\n#define cudaMemcpy hipMemcpy\n#define cudaMemcpyAsync hipMemcpyAsync\n#define cudaMemcpyPeerAsync hipMemcpyPeerAsync\n#define cudaMemcpy2DAsync hipMemcpy2DAsync\n#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice\n#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost\n#define cudaMemcpyHostToDevice hipMemcpyHostToDevice\n#define cudaMemcpyKind hipMemcpyKind\n#define cudaMemset hipMemset\n#define cudaMemsetAsync hipMemsetAsync\n#define cudaMemGetInfo 
hipMemGetInfo\n#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize\n#define cudaSetDevice hipSetDevice\n#define cuDeviceGet hipDeviceGet\n#define CUdevice hipDevice_t\n#define CUdeviceptr hipDeviceptr_t\n#define cuMemUnmap hipMemUnmap\n#define CUmemAccessDesc hipMemAccessDesc\n#define cuMemAddressFree hipMemAddressFree\n#define cuMemRelease hipMemRelease\n#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t\n#define cuMemCreate hipMemCreate\n#define cuMemAddressReserve hipMemAddressReserve\n#define cuMemMap hipMemMap\n#define cuMemSetAccess hipMemSetAccess\n#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity\n#define CUmemAllocationProp hipMemAllocationProp\n#define cuDeviceGetAttribute hipDeviceGetAttribute\n#define cudaStreamCreateWithFlags hipStreamCreateWithFlags\n#define cudaStreamDestroy hipStreamDestroy\n#define cudaStreamFireAndForget hipStreamFireAndForget\n#define cudaStreamNonBlocking hipStreamNonBlocking\n#define cudaStreamPerThread hipStreamPerThread\n#define cudaStreamSynchronize hipStreamSynchronize\n#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)\n#define cudaGraphExec_t hipGraphExec_t\n#define cudaGraphNode_t hipGraphNode_t\n#define cudaKernelNodeParams hipKernelNodeParams\n#define cudaKernelNodeParams hipKernelNodeParams\n#define cudaGraphExecDestroy hipGraphExecDestroy\n#define cudaGraphLaunch hipGraphLaunch\n#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure\n#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult\n#define cudaGraphNodeType hipGraphNodeType\n#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel\n#define cudaGraphInstantiate hipGraphInstantiate\n#define cudaStreamEndCapture hipStreamEndCapture\n#define cudaGraphDestroy hipGraphDestroy\n#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams\n#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction\n#define 
cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams\n#define cudaGraphNodeGetType hipGraphNodeGetType\n#define cudaGraphGetNodes hipGraphGetNodes\n#define cudaGraphExecUpdate hipGraphExecUpdate\n#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed\n#define cudaStreamBeginCapture hipStreamBeginCapture\n#define cudaGraph_t hipGraph_t\n#define cudaStream_t hipStream_t\n#define cudaSuccess hipSuccess\n#define cudaHostFn_t hipHostFn_t\n#define __trap() do { abort(); __builtin_unreachable(); } while(0)\n#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS\n#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED\n#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED\n#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE\n#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH\n#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR\n#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED\n#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR\n#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED\n\n#define __CUDA_ARCH__ 1300\n\n#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)\n#define GCN\n#endif\n\n#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)\n#define CDNA\n#endif\n\n#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \\\n    defined(__gfx1150__) || defined(__gfx1151__)\n#define RDNA3\n#endif\n\n#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \\\n    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)\n#define RDNA2\n#endif\n\n#if defined(__gfx1010__) || defined(__gfx1012__)\n#define RDNA1\n#endif\n\n#ifndef __has_builtin\n    #define __has_builtin(x) 0\n#endif\n\ntypedef hip_bfloat16 nv_bfloat16;\n"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/cpu_backend/vendors/musa.h",
    "content": "#pragma once\n\n#include <musa_runtime.h>\n#include <musa.h>\n#include <mublas.h>\n#include <musa_bf16.h>\n#include <musa_fp16.h>\n#define CUBLAS_COMPUTE_16F CUDA_R_16F\n#define CUBLAS_COMPUTE_32F CUDA_R_32F\n#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F\n#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT\n#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT\n#define CUBLAS_OP_N MUBLAS_OP_N\n#define CUBLAS_OP_T MUBLAS_OP_T\n#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS\n#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT\n#define CUDA_R_16F  MUSA_R_16F\n#define CUDA_R_32F  MUSA_R_32F\n#define cublasComputeType_t cudaDataType_t\n#define cublasCreate mublasCreate\n#define cublasDestroy mublasDestroy\n#define cublasGemmEx mublasGemmEx\n#define cublasGemmBatchedEx mublasGemmBatchedEx\n#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx\n#define cublasHandle_t mublasHandle_t\n#define cublasSetMathMode mublasSetMathMode\n#define cublasSetStream mublasSetStream\n#define cublasSgemm mublasSgemm\n#define cublasStatus_t mublasStatus_t\n#define cublasOperation_t mublasOperation_t\n#define cublasGetStatusString mublasStatus_to_string\n#define cudaDataType_t musaDataType_t\n#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer\n#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess\n#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess\n#define cudaDeviceProp musaDeviceProp\n#define cudaDeviceSynchronize musaDeviceSynchronize\n#define cudaError_t musaError_t\n#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled\n#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled\n#define cudaEventCreateWithFlags musaEventCreateWithFlags\n#define cudaEventDisableTiming musaEventDisableTiming\n#define cudaEventRecord musaEventRecord\n#define cudaEventSynchronize musaEventSynchronize\n#define cudaEvent_t musaEvent_t\n#define cudaEventDestroy musaEventDestroy\n#define cudaFree 
musaFree\n#define cudaFreeHost musaFreeHost\n#define cudaGetDevice musaGetDevice\n#define cudaGetDeviceCount musaGetDeviceCount\n#define cudaGetDeviceProperties musaGetDeviceProperties\n#define cudaGetErrorString musaGetErrorString\n#define cudaGetLastError musaGetLastError\n#define cudaHostRegister musaHostRegister\n#define cudaHostRegisterPortable musaHostRegisterPortable\n#define cudaHostRegisterReadOnly musaHostRegisterReadOnly\n#define cudaHostUnregister musaHostUnregister\n#define cudaLaunchHostFunc musaLaunchHostFunc\n#define cudaMalloc musaMalloc\n#define cudaMallocHost musaMallocHost\n#define cudaMallocManaged musaMallocManaged\n#define cudaMemcpy musaMemcpy\n#define cudaMemcpyAsync musaMemcpyAsync\n#define cudaMemcpyPeerAsync musaMemcpyPeerAsync\n#define cudaMemcpy2DAsync musaMemcpy2DAsync\n#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice\n#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost\n#define cudaMemcpyHostToDevice musaMemcpyHostToDevice\n#define cudaMemcpyKind musaMemcpyKind\n#define cudaMemset musaMemset\n#define cudaMemsetAsync musaMemsetAsync\n#define cudaMemGetInfo musaMemGetInfo\n#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize\n#define cudaSetDevice musaSetDevice\n#define cudaStreamCreateWithFlags musaStreamCreateWithFlags\n#define cudaStreamDestroy musaStreamDestroy\n#define cudaStreamFireAndForget musaStreamFireAndForget\n#define cudaStreamNonBlocking musaStreamNonBlocking\n#define cudaStreamPerThread musaStreamPerThread\n#define cudaStreamSynchronize musaStreamSynchronize\n#define cudaStreamWaitEvent musaStreamWaitEvent\n#define cudaStream_t musaStream_t\n#define cudaSuccess musaSuccess\n\n// Additional mappings for MUSA virtual memory pool\n#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED\n#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE\n#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED 
MU_MEM_ALLOC_GRANULARITY_RECOMMENDED\n#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED\n#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE\n#define CUdevice MUdevice\n#define CUdeviceptr MUdeviceptr\n#define CUmemAccessDesc MUmemAccessDesc\n#define CUmemAllocationProp MUmemAllocationProp\n#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle\n#define cuDeviceGet muDeviceGet\n#define cuDeviceGetAttribute muDeviceGetAttribute\n#define cuMemAddressFree muMemAddressFree\n#define cuMemAddressReserve muMemAddressReserve\n#define cuMemCreate muMemCreate\n#define cuMemGetAllocationGranularity muMemGetAllocationGranularity\n#define cuMemMap muMemMap\n#define cuMemRelease muMemRelease\n#define cuMemSetAccess muMemSetAccess\n#define cuMemUnmap muMemUnmap\n#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize\n#define cudaFuncSetAttribute musaFuncSetAttribute\n#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms\n#define make_cudaExtent make_musaExtent\n#define make_cudaPitchedPtr make_musaPitchedPtr\n\n// Additional mappings for MUSA graphs\n#define CUDA_SUCCESS MUSA_SUCCESS\n#define CUresult MUresult\n#define cuGetErrorString muGetErrorString\n#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure\n#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction\n#define cudaGraphDestroy musaGraphDestroy\n#define cudaGraphExecDestroy musaGraphExecDestroy\n#define cudaGraphExec_t musaGraphExec_t\n#define cudaGraphExecUpdate musaGraphExecUpdate\n#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult\n#define cudaGraphGetNodes musaGraphGetNodes\n#define cudaGraphInstantiate musaGraphInstantiate\n#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams\n#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams\n#define cudaGraphLaunch musaGraphLaunch\n#define cudaGraphNodeGetType musaGraphNodeGetType\n#define cudaGraphNode_t 
musaGraphNode_t\n#define cudaGraphNodeType musaGraphNodeType\n#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel\n#define cudaGraph_t musaGraph_t\n#define cudaKernelNodeParams musaKernelNodeParams\n#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed\n#define cudaStreamEndCapture musaStreamEndCapture\n\ntypedef mt_bfloat16 nv_bfloat16;\n"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/cpu_backend/vendors/vendor.h",
    "content": "#ifndef CPUINFER_VENDOR_VENDOR_H\n#define CPUINFER_VENDOR_VENDOR_H\n\n#ifdef USE_CUDA\n#include \"cuda.h\"\n#elif USE_HIP\n#define __HIP_PLATFORM_AMD__\n#include \"hip.h\"\n#elif USE_MUSA\n#include \"musa.h\"\n#endif\n\n#endif  // CPUINFER_VENDOR_VENDOR_H"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/cuda/binding.cpp",
    "content": "/**\n * @Description  :\n * @Author       : Azure-Tang, Boxin Zhang\n * @Date         : 2024-07-25 13:38:30\n * @Version      : 0.2.2\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n**/\n\n#include \"custom_gguf/ops.h\"\n#ifdef KTRANSFORMERS_USE_CUDA\n#include \"gptq_marlin/ops.h\"\n#endif\n// Python bindings\n#include <pybind11/pybind11.h>\n#include <pybind11/stl.h>\n#include <torch/library.h>\n#include <torch/extension.h>\n#include <torch/torch.h>\n// namespace py = pybind11;\n\nPYBIND11_MODULE(KTransformersOps, m) {\n\n    m.def(\"dequantize_q8_0\", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {\n        torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);\n        return dequantize_q8_0((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);\n        }, \"Function to dequantize q8_0 data.\",\n        py::arg(\"data\"), py::arg(\"num_bytes\"), py::arg(\"blk_size\"), py::arg(\"ele_per_blk\"), py::arg(\"device\"), py::arg(\"target_dtype\"));\n\n    m.def(\"dequantize_q6_k\", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {\n        torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);\n        return dequantize_q6_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);\n        }, \"Function to dequantize q6_k data.\",\n        py::arg(\"data\"), py::arg(\"num_bytes\"), py::arg(\"blk_size\"), py::arg(\"ele_per_blk\"), py::arg(\"device\"), py::arg(\"target_dtype\"));\n\n    m.def(\"dequantize_q5_k\", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {\n        torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);\n        return dequantize_q5_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);\n        }, 
\"Function to dequantize q5_k data.\",\n        py::arg(\"data\"), py::arg(\"num_bytes\"), py::arg(\"blk_size\"), py::arg(\"ele_per_blk\"), py::arg(\"device\"), py::arg(\"target_dtype\"));\n\n    m.def(\"dequantize_q4_k\", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {\n        torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);\n        return dequantize_q4_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);\n        }, \"Function to dequantize q4_k data.\",\n        py::arg(\"data\"), py::arg(\"num_bytes\"), py::arg(\"blk_size\"), py::arg(\"ele_per_blk\"), py::arg(\"device\"), py::arg(\"target_dtype\"));\n\n    m.def(\"dequantize_q3_k\", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {\n        torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);\n        return dequantize_q3_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);\n        }, \"Function to dequantize q3_k data.\",\n        py::arg(\"data\"), py::arg(\"num_bytes\"), py::arg(\"blk_size\"), py::arg(\"ele_per_blk\"), py::arg(\"device\"), py::arg(\"target_dtype\"));\n\n    m.def(\"dequantize_q2_k\", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {\n        torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);\n        return dequantize_q2_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);\n        }, \"Function to dequantize q2_k data.\",\n        py::arg(\"data\"), py::arg(\"num_bytes\"), py::arg(\"blk_size\"), py::arg(\"ele_per_blk\"), py::arg(\"device\"), py::arg(\"target_dtype\"));\n\n    m.def(\"dequantize_iq4_xs\", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) {\n        
torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype);\n        return dequantize_iq4_xs((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype);\n        }, \"Function to dequantize iq4_xs data.\",\n        py::arg(\"data\"), py::arg(\"num_bytes\"), py::arg(\"blk_size\"), py::arg(\"ele_per_blk\"), py::arg(\"device\"), py::arg(\"target_dtype\"));\n\n#ifdef KTRANSFORMERS_USE_CUDA\n    m.def(\"gptq_marlin_gemm\", &gptq_marlin_gemm, \"Function to perform GEMM using Marlin quantization.\",\n        py::arg(\"a\"), py::arg(\"b_q_weight\"), py::arg(\"b_scales\"), py::arg(\"g_idx\"),\n        py::arg(\"perm\"), py::arg(\"workspace\"), py::arg(\"num_bits\"), py::arg(\"size_m\"),\n        py::arg(\"size_n\"), py::arg(\"size_k\"), py::arg(\"is_k_full\"));\n#endif\n}"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/cuda/custom_gguf/dequant.cu",
    "content": "/*\n * @Description  :  \n * @Author       : Azure-Tang, Boxin Zhang\n * @Date         : 2024-07-25 13:38:30\n * @Version      : 0.2.2\n * Adapted from https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c\n * Copyright (c) 2023-2024 The ggml authors\n * Copyright (c) 2024 by KVCache.AI, All Rights Reserved. \n */\n#include <cuda_runtime.h>\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n#include <torch/library.h>\n#include <torch/extension.h>\n#include <torch/torch.h>\n#include <cstdint>\n#include <c10/cuda/CUDAGuard.h>\n\n#ifdef __HIP_PLATFORM_AMD__\ntypedef __hip_bfloat16 nv_bfloat16;\n#endif\n\n__global__ void dequantize_q8_0_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){\n        float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);\n        const int8_t* cur_block = data + block_id * blk_size;\n        float scale = __half2float(*((half*)cur_block));\n        cur_block += 2;\n        for (int i = 0; i < ele_per_blk; i++){\n            output_blk[i] = scale * cur_block[i];\n        }\n    }\n}\n\n__global__ void dequantize_q8_0_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) {\n        __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);\n        const int8_t* cur_block = data + block_id * blk_size;\n        float scale = __half2float(*((half*)cur_block));\n        cur_block += 2;\n        for (int i = 0; i < ele_per_blk; i++) {\n            output_blk[i] = __float2half(scale * 
cur_block[i]);\n        }\n    }\n}\n\n__global__ void dequantize_q8_0_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x) {\n        nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);\n        const int8_t* cur_block = data + block_id * blk_size;\n        float scale = __half2float(*((half*)cur_block));\n        cur_block += 2;\n        for (int i = 0; i < ele_per_blk; i++) {\n            output_blk[i] = __float2bfloat16(scale * cur_block[i]);\n        }\n    }\n}\n\n// __device__ void get_scale_min_k4(int j, const uint8_t * __restrict__ q, uint8_t * __restrict__ d, uint8_t * __restrict__ m) {\n__device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t * __restrict__ d, uint8_t * __restrict__ m) {\n    if (j < 4) {\n        *d = q[j] & 63; *m = q[j + 4] & 63;\n    } else {\n        *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);\n        *m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);\n    }\n}\n\n__global__ void dequantize_q2_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){\n        float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);\n\n        const float d   = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 80)));\n        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 82)));\n\n        const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16);\n\n        int is = 0;\n        float dl, ml;\n\n        for (int n = 0; n < 256; n += 128) {\n            int 
shift = 0;\n            for (int j = 0; j < 4; ++j) {\n                uint8_t* scales = (uint8_t*)(data + block_id * blk_size + (is++));\n                uint8_t sc = *scales;\n                dl = d * (sc & 0xF); ml = min * (sc >> 4);\n                for (int l = 0; l < 16; ++l) *output_blk++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;\n\n                scales = (uint8_t*)(data + block_id * blk_size + (is++));\n                sc = *scales;\n\n                dl = d * (sc & 0xF); ml = min * (sc >> 4);\n                for (int l = 0; l < 16; ++l) *output_blk++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;\n\n                shift += 2;\n            }\n            q += 32;\n        }\n    }\n}\n\n__global__ void dequantize_q2_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){\n        __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);\n\n        const float d   = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 80)));\n        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 82)));\n\n        const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16);\n\n        int is = 0;\n        float dl, ml;\n\n        for (int n = 0; n < 256; n += 128) {\n            int shift = 0;\n            for (int j = 0; j < 4; ++j) {\n                uint8_t* scales = (uint8_t*)(data + block_id * blk_size + (is++));\n                uint8_t sc = *scales;\n                dl = d * (sc & 0xF); ml = min * (sc >> 4);\n                for (int l = 0; l < 16; ++l) *output_blk++ = __float2half(dl * ((int8_t)((q[l] >> shift) & 3)) - ml);\n\n                scales = (uint8_t*)(data + block_id * blk_size + (is++));\n                sc = *scales;\n\n    
            dl = d * (sc & 0xF); ml = min * (sc >> 4);\n                for (int l = 0; l < 16; ++l) *output_blk++ = __float2half(dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml);\n\n                shift += 2;\n            }\n            q += 32;\n        }\n    }\n}\n\n__global__ void dequantize_q2_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){\n        nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);\n\n        const float d   = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 80)));\n        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 82)));\n\n        const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16);\n\n        int is = 0;\n        float dl, ml;\n\n        for (int n = 0; n < 256; n += 128) {\n            int shift = 0;\n            for (int j = 0; j < 4; ++j) {\n                uint8_t* scales = (uint8_t*)(data + block_id * blk_size + (is++));\n                uint8_t sc = *scales;\n                dl = d * (sc & 0xF); ml = min * (sc >> 4);\n                for (int l = 0; l < 16; ++l) *output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l] >> shift) & 3)) - ml);\n\n                scales = (uint8_t*)(data + block_id * blk_size + (is++));\n                sc = *scales;\n\n                dl = d * (sc & 0xF); ml = min * (sc >> 4);\n                for (int l = 0; l < 16; ++l) *output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml);\n\n                shift += 2;\n            }\n            q += 32;\n        }\n    }\n}\n\n__global__ void dequantize_q3_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    
\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;    \n    const uint32_t kmask1 = 0x03030303;\n    const uint32_t kmask2 = 0x0f0f0f0f;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){\n        float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);\n\n        uint32_t aux[4];\n        const int8_t * scales = (const int8_t*)aux;\n        const float d_all = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 108)));\n\n        const uint8_t * __restrict__ q  = (uint8_t*)(data + block_id * blk_size + 32);\n        const uint8_t * __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0);\n        uint8_t m = 1;\n\n\n        uint8_t* block_scales = (uint8_t*)(data + block_id * blk_size + 96);\n\n        for (int i = 0; i < 3; i++) {  \n            aux[i] = 0;  \n            for (int j = 0; j < 4; j++) {  \n                aux[i] |= ((uint32_t)block_scales[i * 4 + j]) << (j * 8);\n            }\n        }\n\n        uint32_t tmp = aux[2];\n        aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);\n        aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);\n        aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);\n        aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);\n\n        int is = 0;\n        float dl;\n        for (int n = 0; n < 256; n += 128) {\n            int shift = 0;\n            for (int j = 0; j < 4; ++j) {\n\n                dl = d_all * (scales[is++] - 32);\n                for (int l = 0; l < 16; ++l) {\n                    *output_blk++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4));\n                }\n\n                dl = d_all * (scales[is++] - 32);\n                for (int l = 0; l < 16; ++l) {\n                    *output_blk++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 
0 : 4));\n                }\n\n                shift += 2;\n                m <<= 1;\n            }\n            q += 32;\n        }\n    }\n}\n\n__global__ void dequantize_q3_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    \n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;    \n    const uint32_t kmask1 = 0x03030303;\n    const uint32_t kmask2 = 0x0f0f0f0f;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){\n        __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);\n\n        uint32_t aux[4];\n        const int8_t * scales = (const int8_t*)aux;\n        const float d_all = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 108)));\n\n        const uint8_t * __restrict__ q  = (uint8_t*)(data + block_id * blk_size + 32);\n        const uint8_t * __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0);\n        uint8_t m = 1;\n\n\n        uint8_t* block_scales = (uint8_t*)(data + block_id * blk_size + 96);\n\n        for (int i = 0; i < 3; i++) {  \n            aux[i] = 0;  \n            for (int j = 0; j < 4; j++) {  \n                aux[i] |= ((uint32_t)block_scales[i * 4 + j]) << (j * 8);\n            }\n        }\n\n        uint32_t tmp = aux[2];\n        aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);\n        aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);\n        aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);\n        aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);\n\n        int is = 0;\n        float dl;\n        for (int n = 0; n < 256; n += 128) {\n            int shift = 0;\n            for (int j = 0; j < 4; ++j) {\n\n                dl = d_all * (scales[is++] - 32);\n                for (int l = 0; l < 16; ++l) {\n                    *output_blk++ = __float2half(dl * ((int8_t)((q[l+ 0] >> shift) & 3) - 
((hm[l+ 0] & m) ? 0 : 4)));\n                }\n\n                dl = d_all * (scales[is++] - 32);\n                for (int l = 0; l < 16; ++l) {\n                    *output_blk++ = __float2half(dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4)));\n                }\n\n                shift += 2;\n                m <<= 1;\n            }\n            q += 32;\n        }\n    }\n}\n\n__global__ void dequantize_q3_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    \n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;    \n    const uint32_t kmask1 = 0x03030303;\n    const uint32_t kmask2 = 0x0f0f0f0f;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){\n        nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);\n\n        uint32_t aux[4];\n        const int8_t * scales = (const int8_t*)aux;\n        const float d_all = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 108)));\n\n        const uint8_t * __restrict__ q  = (uint8_t*)(data + block_id * blk_size + 32);\n        const uint8_t * __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0);\n        uint8_t m = 1;\n\n\n        uint8_t* block_scales = (uint8_t*)(data + block_id * blk_size + 96);\n\n        for (int i = 0; i < 3; i++) {  \n            aux[i] = 0;  \n            for (int j = 0; j < 4; j++) {  \n                aux[i] |= ((uint32_t)block_scales[i * 4 + j]) << (j * 8);\n            }\n        }\n\n        uint32_t tmp = aux[2];\n        aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);\n        aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);\n        aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);\n        aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);\n\n        int is = 0;\n        float dl;\n        for (int n = 0; n < 256; n 
+= 128) {\n            int shift = 0;\n            for (int j = 0; j < 4; ++j) {\n\n                dl = d_all * (scales[is++] - 32);\n                for (int l = 0; l < 16; ++l) {\n                    *output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4)));\n                }\n\n                dl = d_all * (scales[is++] - 32);\n                for (int l = 0; l < 16; ++l) {\n                    *output_blk++ = __float2bfloat16(dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4)));\n                }\n\n                shift += 2;\n                m <<= 1;\n            }\n            q += 32;\n        }\n    }\n}\n\n\n__global__ void dequantize_q4_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x){\n        float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);\n        // const uint8_t * q = data[i].qs;\n        const uint8_t * q = (uint8_t*)(data + block_id * blk_size + 16);\n\n        const float d   = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 0)));\n        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 2)));\n        int is = 0;\n        uint8_t sc, m;\n        for (int j = 0; j < ele_per_blk; j += 64) {\n            uint8_t* scales = (uint8_t*)(data + block_id * blk_size + 4);\n            get_scale_min_k4(is + 0, scales, &sc, &m);\n            const float d1 = d * sc; const float m1 = min * m;\n            get_scale_min_k4(is + 1, scales, &sc, &m);\n            const float d2 = d * sc; const float m2 = min * m;\n            for (int l = 0; l < 32; ++l) *output_blk++ = d1 * (q[l] & 0xF) - m1;\n            for (int l = 0; l < 32; ++l) *output_blk++ = d2 * (q[l]  >> 4) - m2;\n            q += 32; is 
+= 2;\n        }\n    }\n}\n\n__global__ void dequantize_q4_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x){\n        __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);\n        // const uint8_t * q = data[i].qs;\n        const uint8_t * q = (uint8_t*)(data + block_id * blk_size + 16);\n\n        const float d   = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 0)));\n        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 2)));\n        int is = 0;\n        uint8_t sc, m;\n        for (int j = 0; j < ele_per_blk; j += 64) {\n            uint8_t* scales = (uint8_t*)(data + block_id * blk_size + 4);\n            get_scale_min_k4(is + 0, scales, &sc, &m);\n            const float d1 = d * sc; const float m1 = min * m;\n            get_scale_min_k4(is + 1, scales, &sc, &m);\n            const float d2 = d * sc; const float m2 = min * m;\n            for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d1 * (q[l] & 0xF) - m1);\n            for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d2 * (q[l]  >> 4) - m2);\n            q += 32; is += 2;\n        }\n    }\n}\n\n__global__ void dequantize_q4_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x){\n        nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);\n        // const uint8_t * q = data[i].qs;\n        const uint8_t * q = (uint8_t*)(data + block_id * blk_size + 16);\n\n        const float d   = __half2float(*(reinterpret_cast<const 
half*>(data + block_id * blk_size + 0)));\n        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 2)));\n        int is = 0;\n        uint8_t sc, m;\n        for (int j = 0; j < ele_per_blk; j += 64) {\n            uint8_t* scales = (uint8_t*)(data + block_id * blk_size + 4);\n            get_scale_min_k4(is + 0, scales, &sc, &m);\n            const float d1 = d * sc; const float m1 = min * m;\n            get_scale_min_k4(is + 1, scales, &sc, &m);\n            const float d2 = d * sc; const float m2 = min * m;\n            for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d1 * (q[l] & 0xF) - m1);\n            for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d2 * (q[l]  >> 4) - m2);\n            q += 32; is += 2;\n        }\n    }\n}\n\n__global__ void dequantize_q5_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){\n        float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);\n\n        const float d   = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 0)));\n        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 2)));\n\n        const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16);\n        const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48);\n\n        int is = 0;\n        uint8_t sc, m;\n        uint8_t u1 = 1, u2 = 2;\n        uint8_t* scales = (uint8_t*)(data + block_id * blk_size + 4);\n\n        for (int j = 0; j < 256; j += 64) {\n            get_scale_min_k4(is + 0, scales, &sc, &m);\n            const float d1 = d * sc; const float m1 = min * m;\n            get_scale_min_k4(is + 1, scales, &sc, &m);\n            const float d2 
= d * sc; const float m2 = min * m;\n            for (int l = 0; l < 32; ++l) *output_blk++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;\n            for (int l = 0; l < 32; ++l) *output_blk++ = d2 * ((ql[l]  >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;\n            ql += 32; is += 2;\n            u1 <<= 2; u2 <<= 2;\n        }\n    }\n}\n\n__global__ void dequantize_q5_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){\n        __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);\n\n        const float d   = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 0)));\n        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 2)));\n\n        const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16);\n        const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48);\n\n        int is = 0;\n        uint8_t sc, m;\n        uint8_t u1 = 1, u2 = 2;\n        uint8_t* scales = (uint8_t*)(data + block_id * blk_size + 4);\n\n        for (int j = 0; j < 256; j += 64) {\n            get_scale_min_k4(is + 0, scales, &sc, &m);\n            const float d1 = d * sc; const float m1 = min * m;\n            get_scale_min_k4(is + 1, scales, &sc, &m);\n            const float d2 = d * sc; const float m2 = min * m;\n            for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1);\n            for (int l = 0; l < 32; ++l) *output_blk++ = __float2half(d2 * ((ql[l]  >> 4) + (qh[l] & u2 ? 
16 : 0)) - m2);\n            ql += 32; is += 2;\n            u1 <<= 2; u2 <<= 2;\n        }\n    }\n}\n\n__global__ void dequantize_q5_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){\n        nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);\n\n        const float d   = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 0)));\n        const float min = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 2)));\n\n        const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16);\n        const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48);\n\n        int is = 0;\n        uint8_t sc, m;\n        uint8_t u1 = 1, u2 = 2;\n        uint8_t* scales = (uint8_t*)(data + block_id * blk_size + 4);\n\n        for (int j = 0; j < 256; j += 64) {\n            get_scale_min_k4(is + 0, scales, &sc, &m);\n            const float d1 = d * sc; const float m1 = min * m;\n            get_scale_min_k4(is + 1, scales, &sc, &m);\n            const float d2 = d * sc; const float m2 = min * m;\n            for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1);\n            for (int l = 0; l < 32; ++l) *output_blk++ = __float2bfloat16(d2 * ((ql[l]  >> 4) + (qh[l] & u2 ? 
16 : 0)) - m2);\n            ql += 32; is += 2;\n            u1 <<= 2; u2 <<= 2;\n        }\n    }\n}\n\n__global__ void dequantize_q6_k_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long  block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){\n        float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);\n        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 208)));\n\n        const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size);\n        const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 128);\n        const int8_t  * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192);\n\n\n        for (int n = 0; n < ele_per_blk; n += 128) {\n            for (int l = 0; l < 32; ++l) {\n                int is = l/16;\n                const int8_t q1 = (int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;\n                const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;\n                const int8_t q3 = (int8_t)((ql[l +  0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;\n                const int8_t q4 = (int8_t)((ql[l + 32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;\n                output_blk[l +  0] = d * sc[is + 0] * q1;\n                output_blk[l + 32] = d * sc[is + 2] * q2;\n                output_blk[l + 64] = d * sc[is + 4] * q3;\n                output_blk[l + 96] = d * sc[is + 6] * q4;\n            }\n            output_blk += 128;\n            ql += 64;\n            qh += 32;\n            sc += 8;\n        }\n    }\n}\n\n__global__ void dequantize_q6_k_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long 
long  block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){\n        __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);\n        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size + 208)));\n\n        const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size);\n        const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 128);\n        const int8_t  * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192);\n\n\n        for (int n = 0; n < ele_per_blk; n += 128) {\n            for (int l = 0; l < 32; ++l) {\n                int is = l/16;\n                const int8_t q1 = (int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;\n                const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;\n                const int8_t q3 = (int8_t)((ql[l +  0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;\n                const int8_t q4 = (int8_t)((ql[l + 32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;\n                output_blk[l +  0] = __float2half(d * sc[is + 0] * q1);\n                output_blk[l + 32] = __float2half(d * sc[is + 2] * q2);\n                output_blk[l + 64] = __float2half(d * sc[is + 4] * q3);\n                output_blk[l + 96] = __float2half(d * sc[is + 6] * q4);\n            }\n            output_blk += 128;\n            ql += 64;\n            qh += 32;\n            sc += 8;\n        }\n    }\n}\n\n__global__ void dequantize_q6_k_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long  block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){\n        nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);\n        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * 
blk_size + 208)));\n\n        const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size);\n        const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 128);\n        const int8_t  * __restrict__ sc = (int8_t*)(data + block_id * blk_size + 192);\n\n\n        for (int n = 0; n < ele_per_blk; n += 128) {\n            for (int l = 0; l < 32; ++l) {\n                int is = l/16;\n                const int8_t q1 = (int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;\n                const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;\n                const int8_t q3 = (int8_t)((ql[l +  0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;\n                const int8_t q4 = (int8_t)((ql[l + 32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;\n                output_blk[l +  0] = __float2bfloat16(d * sc[is + 0] * q1);\n                output_blk[l + 32] = __float2bfloat16(d * sc[is + 2] * q2);\n                output_blk[l + 64] = __float2bfloat16(d * sc[is + 4] * q3);\n                output_blk[l + 96] = __float2bfloat16(d * sc[is + 6] * q4);\n            }\n            output_blk += 128;\n            ql += 64;\n            qh += 32;\n            sc += 8;\n        }\n    }\n}\n\nstatic constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};\n\n__global__ void dequantize_iq4_xs_fp32_kernel(const int8_t* data, float* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) {\n        float* __restrict__ output_blk = (float*)(output + block_id * ele_per_blk);\n        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size)));\n        const uint16_t scales_h = *(reinterpret_cast<const uint16_t*>(data + block_id * blk_size + 2));\n       
 const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2);\n        const uint8_t* qs = (uint8_t*)(data + block_id * blk_size + 2 + 2 + 4);\n\n        for (int ib = 0; ib < 8; ++ib) {\n            const int ls = ((scales_l[ib / 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h >> 2 * ib) & 3) << 4);\n            const float dl = d * (ls - 32);\n            for (int j = 0; j < 16; ++j) {\n                output_blk[j + 0] = dl * kvalues_iq4nl[qs[j] & 0xf];\n                output_blk[j + 16] = dl * kvalues_iq4nl[qs[j] >> 4];\n            }\n            output_blk += 32;\n            qs += 16;\n        }\n    }\n}\n\n__global__ void dequantize_iq4_xs_fp16_kernel(const int8_t* data, __half* output, const int blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) {\n        __half* __restrict__ output_blk = (__half*)(output + block_id * ele_per_blk);\n        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size)));\n        const uint16_t scales_h = *(reinterpret_cast<const uint16_t*>(data + block_id * blk_size + 2));\n        const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2);\n        const uint8_t* qs = (uint8_t*)(data + block_id * blk_size + 2 + 2 + 4);\n\n        for (int ib = 0; ib < 8; ++ib) {\n            const int ls = ((scales_l[ib / 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h >> 2 * ib) & 3) << 4);\n            const float dl = d * (ls - 32);\n            for (int j = 0; j < 16; ++j) {\n                output_blk[j + 0] = __float2half(dl * kvalues_iq4nl[qs[j] & 0xf]);\n                output_blk[j + 16] = __float2half(dl * kvalues_iq4nl[qs[j] >> 4]);\n            }\n            output_blk += 32;\n            qs += 16;\n        }\n    }\n}\n\n__global__ void dequantize_iq4_xs_bf16_kernel(const int8_t* data, nv_bfloat16* output, const int 
blk_size, const int ele_per_blk, const int num_blocks) {\n    long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;\n    for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) {\n        nv_bfloat16* __restrict__ output_blk = (nv_bfloat16*)(output + block_id * ele_per_blk);\n        const float d = __half2float(*(reinterpret_cast<const half*>(data + block_id * blk_size)));\n        const uint16_t scales_h = *(reinterpret_cast<const uint16_t*>(data + block_id * blk_size + 2));\n        const uint8_t* scales_l = (uint8_t*)(data + block_id * blk_size + 2 + 2);\n        const uint8_t* qs = (uint8_t*)(data + block_id * blk_size + 2 + 2 + 4);\n\n        for (int ib = 0; ib < 8; ++ib) {\n            const int ls = ((scales_l[ib / 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h >> 2 * ib) & 3) << 4);\n            const float dl = d * (ls - 32);\n            for (int j = 0; j < 16; ++j) {\n                output_blk[j + 0] = __float2bfloat16(dl * kvalues_iq4nl[qs[j] & 0xf]);\n                output_blk[j + 16] = __float2bfloat16(dl * kvalues_iq4nl[qs[j] >> 4]);\n            }\n            output_blk += 32;\n            qs += 16;\n        }\n    }\n}\n\ntorch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {\n    int num_blocks = num_bytes / blk_size;\n    const at::cuda::OptionalCUDAGuard device_guard(device);\n\n    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);\n    auto data_gpu = torch::empty({ num_bytes }, options);\n\n    cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);\n    //data_gpu.copy_(data, false);\n\n    // Create output tensor\n    auto output = torch::zeros({ num_blocks, ele_per_blk }, torch::dtype(target_dtype).device(device));\n\n    switch (target_dtype) {\n        case torch::kFloat16:\n        
    dequantize_q8_0_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kBFloat16:\n            dequantize_q8_0_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kFloat32:\n            dequantize_q8_0_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);\n            break;\n        default:\n            printf(\"target type not support\\n\");\n            exit(0);\n    }\n\n    cudaDeviceSynchronize();\n    return output;\n}\n\n\ntorch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {\n    // data.numel%blk_size should be 0, else raise err\n    int num_blocks = num_bytes / blk_size;\n\n    const at::cuda::OptionalCUDAGuard device_guard(device);\n    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);\n    auto data_gpu = torch::empty({num_bytes}, options);\n\n    cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);\n    //data_gpu.copy_(data, false);\n\n    // Create output tensor\n    auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));\n\n    switch (target_dtype) {\n        case torch::kFloat16:\n            dequantize_q6_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kBFloat16:\n            dequantize_q6_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kFloat32:\n            
dequantize_q6_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);\n            break;\n        default:\n            printf(\"target type not support\\n\");\n            exit(0);\n    }\n    cudaDeviceSynchronize();\n    return output;\n}\n\ntorch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {\n    int num_blocks = num_bytes / blk_size;\n    const at::cuda::OptionalCUDAGuard device_guard(device);\n\n    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);\n    auto data_gpu = torch::empty({num_bytes}, options);\n\n    cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);\n    //data_gpu.copy_(data, false);\n\n    // Create output tensor\n    auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));\n\n    switch (target_dtype) {\n        case torch::kFloat16:\n            dequantize_q5_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kBFloat16:\n            dequantize_q5_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kFloat32:\n            dequantize_q5_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);\n            break;\n        default:\n            printf(\"target type not support\\n\");\n            exit(0);\n    }\n    cudaDeviceSynchronize();\n    return output;\n}\n\ntorch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) 
{\n    // data.numel%blk_size should be 0, else raise err\n    int num_blocks = num_bytes / blk_size;\n    const at::cuda::OptionalCUDAGuard device_guard(device);\n\n    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);\n    auto data_gpu = torch::empty({num_bytes}, options);\n\n    cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);\n    //data_gpu.copy_(data, false);\n\n    // Create output tensor\n    auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));\n\n    switch (target_dtype) {\n        case torch::kFloat16:\n            dequantize_q4_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kBFloat16:\n            dequantize_q4_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kFloat32:\n            dequantize_q4_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);\n            break;\n        default:\n            printf(\"target type not support\\n\");\n            exit(0);\n    }\n    cudaDeviceSynchronize();\n    return output;\n}\n\ntorch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {\n    int num_blocks = num_bytes / blk_size;\n    const at::cuda::OptionalCUDAGuard device_guard(device);\n\n    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);\n    auto data_gpu = torch::empty({num_bytes}, options);\n\n    cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);\n    //data_gpu.copy_(data, false);\n\n    // 
Create output tensor\n    auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));\n\n    switch (target_dtype) {\n        case torch::kFloat16:\n            dequantize_q3_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kBFloat16:\n            dequantize_q3_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kFloat32:\n            dequantize_q3_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);\n            break;\n        default:\n            printf(\"target type not support\\n\");\n            exit(0);\n    }\n    cudaDeviceSynchronize();\n    return output;\n}\n\ntorch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {\n    int num_blocks = num_bytes / blk_size;\n    const at::cuda::OptionalCUDAGuard device_guard(device);\n\n    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);\n    auto data_gpu = torch::empty({num_bytes}, options);\n\n    cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);\n    //data_gpu.copy_(data, false);\n\n    // Create output tensor\n    auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));\n\n    switch (target_dtype) {\n        case torch::kFloat16:\n            dequantize_q2_k_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kBFloat16:\n            dequantize_q2_k_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), 
(nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kFloat32:\n            dequantize_q2_k_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);\n            break;\n        default:\n            printf(\"target type not support\\n\");\n            exit(0);\n    }\n    cudaDeviceSynchronize();\n    return output;\n}\n\ntorch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype) {\n    int num_blocks = num_bytes / blk_size;\n    const at::cuda::OptionalCUDAGuard device_guard(device);\n\n    auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);\n    auto data_gpu = torch::empty({num_bytes}, options);\n\n    cudaMemcpy(data_gpu.data_ptr<int8_t>(), data, num_bytes, cudaMemcpyHostToDevice);\n    //data_gpu.copy_(data, false);\n\n    // Create output tensor\n    auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));\n\n    switch (target_dtype) {\n        case torch::kFloat16:\n            dequantize_iq4_xs_fp16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (__half*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kBFloat16:\n            dequantize_iq4_xs_bf16_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), (nv_bfloat16*)output.data_ptr(), blk_size, ele_per_blk, num_blocks);\n            break;\n        case torch::kFloat32:\n            dequantize_iq4_xs_fp32_kernel<<<512, 256>>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, ele_per_blk, num_blocks);\n            break;\n        default:\n            printf(\"target type not support\\n\");\n            exit(0);\n    }\n    cudaDeviceSynchronize();\n    return output;\n}\n"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/cuda/custom_gguf/ops.h",
    "content": "/**\n * @Description  :\n * @Author       : Azure-Tang\n * @Date         : 2024-07-22 09:27:55\n * @Version      : 1.0.0\n * @LastEditors  : kkk1nak0\n * @LastEditTime : 2024-08-12 03:48:46\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n**/\n#pragma once\n\n#include <torch/library.h>\n#include <torch/extension.h>\n#include <torch/torch.h>\n\ntorch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);\ntorch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);\ntorch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);\ntorch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);\ntorch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);\ntorch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);\ntorch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::Dtype target_dtype);"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu",
    "content": "/*\n * Modified by Neural Magic\n * Copyright (C) Marlin.2024 Elias Frantar\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *         http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\n/*\n * Adapted from https://github.com/IST-DASLab/marlin\n */\n/*\n * Adapted from  https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin\n */\n#include \"gptq_marlin.cuh\"\n#include \"gptq_marlin_dtypes.cuh\"\n#include <c10/cuda/CUDAGuard.h>\n#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t)               \\\n  static_assert(std::is_same<scalar_t, half>::value ||          \\\n                    std::is_same<scalar_t, nv_bfloat16>::value, \\\n                \"only float16 and bfloat16 are supported\");\n\ntemplate <typename T>\ninline std::string str(T x) {\n  return std::to_string(x);\n}\n\nnamespace gptq_marlin {\n\n#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined(__HIP_PLATFORM_AMD__)\n\n__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,\n                                    int const* __restrict__ perm_int_ptr,\n                                    int4* __restrict__ out_int4_ptr, int size_m,\n                                    int size_k, int block_rows) {}\n\ntemplate <typename scalar_t,          // compute dtype, half or nv_bfloat16\n          const int num_bits,         // number of bits used for weights\n          const int threads,          // number of threads in a threadblock\n          const int thread_m_blocks,  // number of 
16x16 blocks in the m\n                                      // dimension (batchsize) of the\n                                      // threadblock\n          const int thread_n_blocks,  // same for n dimension (output)\n          const int thread_k_blocks,  // same for k dimension (reduction)\n          const int stages,  // number of stages for the async global->shared\n                             // fetch pipeline\n          const bool has_act_order,    // whether act_order is enabled\n          const int group_blocks = -1  // number of consecutive 16x16 blocks\n                                       // with a separate quantization scale\n          >\n__global__ void Marlin(\n    const int4* __restrict__ A,  // fp16 input matrix of shape mxk\n    const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn\n    int4* __restrict__ C,        // fp16 output buffer of shape mxn\n    const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape\n                                          // (k/groupsize)xn\n    const int* __restrict__ g_idx,        // int32 group indices of shape k\n    int num_groups,  // number of scale groups per output channel\n    int prob_m,      // batch dimension m\n    int prob_n,      // output dimension n\n    int prob_k,      // reduction dimension k\n    int* locks       // extra global storage for barrier synchronization\n) {}\n\n}  // namespace gptq_marlin\n\ntorch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,\n                               torch::Tensor& b_scales, torch::Tensor& g_idx,\n                               torch::Tensor& perm, torch::Tensor& workspace,\n                               int64_t num_bits, int64_t size_m, int64_t size_n,\n                               int64_t size_k, bool is_k_full) {\n  TORCH_CHECK_NOT_IMPLEMENTED(false,\n                              \"marlin_gemm(..) 
requires CUDA_ARCH >= 8.0\");\n  return torch::empty({1, 1});\n}\n\n#else\n\n// m16n8k16 tensor core mma instruction with fp16 inputs and fp32\n// output/accumulation.\ntemplate <typename scalar_t>\n__device__ inline void mma(const typename ScalarType<scalar_t>::FragA& a_frag,\n                           const typename ScalarType<scalar_t>::FragB& frag_b,\n                           typename ScalarType<scalar_t>::FragC& frag_c) {\n  const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);\n  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);\n  float* c = reinterpret_cast<float*>(&frag_c);\n  if constexpr (std::is_same<scalar_t, half>::value) {\n    asm volatile(\n        \"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \"\n        \"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\\n\"\n        : \"=f\"(c[0]), \"=f\"(c[1]), \"=f\"(c[2]), \"=f\"(c[3])\n        : \"r\"(a[0]), \"r\"(a[1]), \"r\"(a[2]), \"r\"(a[3]), \"r\"(b[0]), \"r\"(b[1]),\n          \"f\"(c[0]), \"f\"(c[1]), \"f\"(c[2]), \"f\"(c[3]));\n  } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {\n    asm volatile(\n        \"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \"\n        \"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\\n\"\n        : \"=f\"(c[0]), \"=f\"(c[1]), \"=f\"(c[2]), \"=f\"(c[3])\n        : \"r\"(a[0]), \"r\"(a[1]), \"r\"(a[2]), \"r\"(a[3]), \"r\"(b[0]), \"r\"(b[1]),\n          \"f\"(c[0]), \"f\"(c[1]), \"f\"(c[2]), \"f\"(c[3]));\n  } else {\n    STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);\n  }\n}\n\n// Instruction for loading a full 16x16 matrix fragment of operand A from shared\n// memory, directly in tensor core layout.\ntemplate <typename scalar_t>\n__device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA& frag_a,\n                             const void* smem_ptr) {\n  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);\n  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));\n  asm 
volatile(\"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\\n\"\n               : \"=r\"(a[0]), \"=r\"(a[1]), \"=r\"(a[2]), \"=r\"(a[3])\n               : \"r\"(smem));\n}\n\n// Lookup-table based 3-input logical operation; explicitly used for\n// dequantization as the compiler does not seem to automatically recognize it in\n// all cases.\ntemplate <int lut>\n__device__ inline int lop3(int a, int b, int c) {\n  int res;\n  asm volatile(\"lop3.b32 %0, %1, %2, %3, %4;\\n\"\n               : \"=r\"(res)\n               : \"r\"(a), \"r\"(b), \"r\"(c), \"n\"(lut));\n  return res;\n}\n\n// Constructs destination register by taking bytes from 2 sources (based on\n// mask)\ntemplate <int start_byte, int mask>\n__device__ inline uint32_t prmt(uint32_t a) {\n  uint32_t res;\n  asm volatile(\"prmt.b32 %0, %1, %2, %3;\\n\"\n               : \"=r\"(res)\n               : \"r\"(a), \"n\"(start_byte), \"n\"(mask));\n  return res;\n}\n\n// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16\n// values. 
We mostly follow the strategy in the link below, with some small\n// changes:\n// - FP16:\n// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287\n// - BF16:\n// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385\ntemplate <typename scalar_t>\n__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit(int q) {\n  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);\n}\n\ntemplate <>\n__device__ inline typename ScalarType<half>::FragB dequant_4bit<half>(int q) {\n  const int LO = 0x000f000f;\n  const int HI = 0x00f000f0;\n  const int EX = 0x64006400;\n  // Guarantee that the `(a & b) | c` operations are LOP3s.\n  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);\n  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);\n  // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point\n  // directly into `SUB` and `ADD`.\n  const int SUB = 0x64086408;\n  const int MUL = 0x2c002c00;\n  const int ADD = 0xd480d480;\n  typename ScalarType<half>::FragB frag_b;\n  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),\n                      *reinterpret_cast<const half2*>(&SUB));\n  frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),\n                      *reinterpret_cast<const half2*>(&MUL),\n                      *reinterpret_cast<const half2*>(&ADD));\n  return frag_b;\n}\n\ntemplate <>\n__device__ inline typename ScalarType<nv_bfloat16>::FragB\ndequant_4bit<nv_bfloat16>(int q) {\n  static constexpr uint32_t MASK = 0x000f000f;\n  static constexpr uint32_t EX = 0x43004300;\n\n  // Guarantee that the `(a & b) | c` operations are LOP3s.\n\n  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);\n  q >>= 4;\n  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);\n\n  typename ScalarType<nv_bfloat16>::FragB frag_b;\n  static 
constexpr uint32_t MUL = 0x3F803F80;\n  static constexpr uint32_t ADD = 0xC308C308;\n\n  frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),\n                      *reinterpret_cast<const nv_bfloat162*>(&MUL),\n                      *reinterpret_cast<const nv_bfloat162*>(&ADD));\n  frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),\n                      *reinterpret_cast<const nv_bfloat162*>(&MUL),\n                      *reinterpret_cast<const nv_bfloat162*>(&ADD));\n  return frag_b;\n}\n\n// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or\n// bf16 Reference:\n// - FP16:\n// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85\n// - BF16:\n// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175\ntemplate <typename scalar_t>\n__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit(int q) {\n  STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);\n}\n\ntemplate <>\n__device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) {\n  static constexpr uint32_t mask_for_elt_01 = 0x5250;\n  static constexpr uint32_t mask_for_elt_23 = 0x5351;\n  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;\n\n  uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);\n  uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);\n\n  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;\n\n  typename ScalarType<half>::FragB frag_b;\n  frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),\n                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));\n  frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),\n                      *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));\n  return frag_b;\n}\n\ntemplate <>\n__device__ inline 
typename ScalarType<nv_bfloat16>::FragB\ndequant_8bit<nv_bfloat16>(int q) {\n  typename ScalarType<nv_bfloat16>::FragB frag_b;\n\n  float fp32_intermediates[4];\n  uint32_t* fp32_intermediates_casted =\n      reinterpret_cast<uint32_t*>(fp32_intermediates);\n\n  static constexpr uint32_t fp32_base = 0x4B000000;\n  fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);\n  fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);\n  fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);\n  fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);\n\n  fp32_intermediates[0] -= 8388736.f;\n  fp32_intermediates[1] -= 8388736.f;\n  fp32_intermediates[2] -= 8388736.f;\n  fp32_intermediates[3] -= 8388736.f;\n\n  uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);\n  bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],\n                                   fp32_intermediates_casted[1], 0x7632);\n  bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],\n                                   fp32_intermediates_casted[3], 0x7632);\n\n  return frag_b;\n}\n\n// Multiply dequantized values by the corresponding quantization scale; used\n// only for grouped quantization.\ntemplate <typename scalar_t>\n__device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,\n                             typename ScalarType<scalar_t>::FragS& frag_s,\n                             int i) {\n  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;\n  scalar_t2 s =\n      ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_s)[i]);\n  frag_b[0] = __hmul2(frag_b[0], s);\n  frag_b[1] = __hmul2(frag_b[1], s);\n}\n\n// Same as above, but for act_order (each K is multiplied individually)\ntemplate <typename scalar_t>\n__device__ inline void scale4(typename ScalarType<scalar_t>::FragB& frag_b,\n                              typename ScalarType<scalar_t>::FragS& frag_s_1,\n                              typename 
ScalarType<scalar_t>::FragS& frag_s_2,\n                              typename ScalarType<scalar_t>::FragS& frag_s_3,\n                              typename ScalarType<scalar_t>::FragS& frag_s_4,\n                              int i) {\n  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;\n  scalar_t2 s_val_1_2;\n  s_val_1_2.x = reinterpret_cast<scalar_t*>(&frag_s_1)[i];\n  s_val_1_2.y = reinterpret_cast<scalar_t*>(&frag_s_2)[i];\n\n  scalar_t2 s_val_3_4;\n  s_val_3_4.x = reinterpret_cast<scalar_t*>(&frag_s_3)[i];\n  s_val_3_4.y = reinterpret_cast<scalar_t*>(&frag_s_4)[i];\n\n  frag_b[0] = __hmul2(frag_b[0], s_val_1_2);\n  frag_b[1] = __hmul2(frag_b[1], s_val_3_4);\n}\n\n// Given 2 floats multiply by 2 scales (halves)\ntemplate <typename scalar_t>\n__device__ inline void scale_float(float* c,\n                                   typename ScalarType<scalar_t>::FragS& s) {\n  scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s);\n  c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0]));\n  c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1]));\n}\n\n// Wait until barrier reaches `count`, then lock for current threadblock.\n__device__ inline void barrier_acquire(int* lock, int count) {\n  if (threadIdx.x == 0) {\n    int state = -1;\n    do\n      // Guarantee that subsequent writes by this threadblock will be visible\n      // globally.\n      asm volatile(\"ld.global.acquire.gpu.b32 %0, [%1];\\n\"\n                   : \"=r\"(state)\n                   : \"l\"(lock));\n    while (state != count);\n  }\n  __syncthreads();\n}\n\n// Release barrier and increment visitation count.\n__device__ inline void barrier_release(int* lock, bool reset = false) {\n  __syncthreads();\n  if (threadIdx.x == 0) {\n    if (reset) {\n      lock[0] = 0;\n      return;\n    }\n    int val = 1;\n    // Make sure that all writes since acquiring this barrier are visible\n    // globally, while releasing the barrier.\n    asm volatile(\"fence.acq_rel.gpu;\\n\");\n 
   asm volatile(\"red.relaxed.gpu.global.add.s32 [%0], %1;\\n\"\n                 :\n                 : \"l\"(lock), \"r\"(val));\n  }\n}\n\n// For a given \"a\" of size [M,K] performs a permutation of the K columns based\n// on the given \"perm\" indices.\n__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,\n                                    int const* __restrict__ perm_int_ptr,\n                                    int4* __restrict__ out_int4_ptr, int size_m,\n                                    int size_k, int block_rows) {\n  int start_row = block_rows * blockIdx.x;\n  int finish_row = start_row + block_rows;\n  if (finish_row > size_m) {\n    finish_row = size_m;\n  }\n  int cur_block_rows = finish_row - start_row;\n\n  int row_stride = size_k * sizeof(half) / 16;\n\n  auto permute_row = [&](int row) {\n    int iters = size_k / default_threads;\n    int rest = size_k % default_threads;\n\n    int offset = row * row_stride;\n\n    half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);\n    half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);\n\n    int base_k = 0;\n\n    for (int i = 0; i < iters; i++) {\n      int cur_k = base_k + threadIdx.x;\n      int src_pos = perm_int_ptr[cur_k];\n\n      out_half[cur_k] = a_row_half[src_pos];\n\n      base_k += default_threads;\n    }\n\n    if (rest) {\n      if (threadIdx.x < rest) {\n        int cur_k = base_k + threadIdx.x;\n        int src_pos = perm_int_ptr[cur_k];\n\n        out_half[cur_k] = a_row_half[src_pos];\n      }\n    }\n  };\n\n  for (int i = 0; i < cur_block_rows; i++) {\n    int cur_row = start_row + i;\n    if (cur_row < size_m) {\n      permute_row(cur_row);\n    }\n  }\n}\n\ntemplate <typename scalar_t,          // compute dtype, half or nv_bfloat16\n          const int num_bits,         // number of bits used for weights\n          const int threads,          // number of threads in a threadblock\n          const int thread_m_blocks,  // number 
of 16x16 blocks in the m\n                                      // dimension (batchsize) of the\n                                      // threadblock\n          const int thread_n_blocks,  // same for n dimension (output)\n          const int thread_k_blocks,  // same for k dimension (reduction)\n          const int stages,  // number of stages for the async global->shared\n                             // fetch pipeline\n          const bool has_act_order,    // whether act_order is enabled\n          const int group_blocks = -1  // number of consecutive 16x16 blocks\n                                       // with a separate quantization scale\n          >\n__global__ void Marlin(\n    const int4* __restrict__ A,  // fp16 input matrix of shape mxk\n    const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn\n    int4* __restrict__ C,        // fp16 output buffer of shape mxn\n    const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape\n                                          // (k/groupsize)xn\n    const int* __restrict__ g_idx,        // int32 group indices of shape k\n    int num_groups,  // number of scale groups per output channel\n    int prob_m,      // batch dimension m\n    int prob_n,      // output dimension n\n    int prob_k,      // reduction dimension k\n    int* locks       // extra global storage for barrier synchronization\n) {\n  // Each threadblock processes one \"stripe\" of the B matrix with (roughly) the\n  // same size, which might involve multiple column \"slices\" (of width 16 *\n  // `thread_n_blocks`). 
Stripes are defined as shown in the 3x3 matrix 5 SM\n  // example:\n  //   0 1 3\n  //   0 2 3\n  //   1 2 4\n  // While this kind of partitioning makes things somewhat more complicated, it\n  // ensures good utilization of all SMs for many kinds of shape and GPU\n  // configurations, while requiring as few slow global cross-threadblock\n  // reductions as possible.\n  using Dtype = ScalarType<scalar_t>;\n  using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;\n  using FragA = typename ScalarType<scalar_t>::FragA;\n  using FragB = typename ScalarType<scalar_t>::FragB;\n  using FragC = typename ScalarType<scalar_t>::FragC;\n  using FragS = typename ScalarType<scalar_t>::FragS;\n\n  constexpr int pack_factor = 32 / num_bits;\n\n  // For larger GEMMs we run multiple batchsize 64 versions in parallel for a\n  // better partitioning with less reductions\n  int parallel = 1;\n  if (prob_m > 16 * thread_m_blocks) {\n    parallel = prob_m / (16 * thread_m_blocks);\n    prob_m = 16 * thread_m_blocks;\n  }\n\n  int k_tiles = prob_k / 16 / thread_k_blocks;\n  int n_tiles = prob_n / 16 / thread_n_blocks;\n  int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x);\n\n  if constexpr (!has_act_order && group_blocks != -1) {\n    if (group_blocks >= thread_k_blocks) {\n      // Ensure that the number of tiles in each stripe is a multiple of the\n      // groupsize; this avoids an annoying special case where a stripe starts\n      // in the middle of group.\n      iters = (group_blocks / thread_k_blocks) *\n              div_ceil(iters, (group_blocks / thread_k_blocks));\n    }\n  }\n\n  int slice_row = (iters * blockIdx.x) % k_tiles;\n  int slice_col_par = (iters * blockIdx.x) / k_tiles;\n  int slice_col = slice_col_par;\n  int slice_iters;  // number of threadblock tiles in the current slice\n  int slice_count =\n      0;          // total number of active threadblocks in the current slice\n  int slice_idx;  // index of threadblock in current slice; numbered bottom 
to\n                  // top\n\n  // We can easily implement parallel problem execution by just remapping\n  // indices and advancing global pointers\n  if (slice_col_par >= n_tiles) {\n    A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;\n    C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;\n    locks += (slice_col_par / n_tiles) * n_tiles;\n    slice_col = slice_col_par % n_tiles;\n  }\n\n  // Compute all information about the current slice which is required for\n  // synchronization.\n  auto init_slice = [&]() {\n    slice_iters =\n        iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);\n    if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;\n    if (slice_iters == 0) return;\n    if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;\n    slice_count = 1;\n    slice_idx = 0;\n    int col_first = iters * div_ceil(k_tiles * slice_col_par, iters);\n    if (col_first <= k_tiles * (slice_col_par + 1)) {\n      int col_off = col_first - k_tiles * slice_col_par;\n      slice_count = div_ceil(k_tiles - col_off, iters);\n      if (col_off > 0) slice_count++;\n      int delta_first = iters * blockIdx.x - col_first;\n      if (delta_first < 0 || (col_off == 0 && delta_first == 0))\n        slice_idx = slice_count - 1;\n      else {\n        slice_idx = slice_count - 1 - delta_first / iters;\n        if (col_off > 0) slice_idx--;\n      }\n    }\n    if (slice_col == n_tiles) {\n      A += 16 * thread_m_blocks * prob_k / 8;\n      C += 16 * thread_m_blocks * prob_n / 8;\n      locks += n_tiles;\n      slice_col = 0;\n    }\n  };\n  init_slice();\n\n  // A sizes/strides\n\n  // stride of the A matrix in global memory\n  int a_gl_stride = prob_k / 8;\n  // stride of an A matrix tile in shared memory\n  constexpr int a_sh_stride = 16 * thread_k_blocks / 8;\n  // delta between subsequent A tiles in global memory\n  constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8;\n  // 
between subsequent accesses within a tile\n  int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);\n  // between shared memory writes\n  constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);\n  // between shared memory tile reads\n  constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));\n  // within a shared memory tile\n  constexpr int a_sh_rd_delta_i = a_sh_stride * 16;\n  // overall size of a tile\n  constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);\n  // number of shared write iterations for a tile\n  constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta);\n\n  // B sizes/strides\n  int b_gl_stride = 16 * prob_n / (pack_factor * 4);\n  constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;\n  constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2;\n  constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;\n\n  int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;\n  int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);\n  constexpr int b_sh_wr_delta = threads * b_thread_vecs;\n  constexpr int b_sh_rd_delta = threads * b_thread_vecs;\n  constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;\n  constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;\n\n  // Scale sizes/strides without act_order\n  int s_gl_stride = prob_n / 8;\n  constexpr int s_sh_stride = 16 * thread_n_blocks / 8;\n  constexpr int s_tb_groups =\n      !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks\n          ? thread_k_blocks / group_blocks\n          : 1;\n  constexpr int s_sh_stage = s_tb_groups * s_sh_stride;\n  int s_gl_rd_delta = s_gl_stride;\n\n  // Scale size/strides with act_order\n  constexpr int tb_k = 16 * thread_k_blocks;\n  constexpr int g_idx_stage = has_act_order ? 
(tb_k * sizeof(int)) / 16 : 0;\n  // constexpr int act_s_row_stride      = 1;\n  // int           act_s_col_stride      = act_s_row_stride * num_groups;\n  int act_s_col_stride = 1;\n  int act_s_col_warp_stride = act_s_col_stride * 8;\n  int tb_n_warps = thread_n_blocks / 4;\n  int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;\n\n  // Global A read index of current thread.\n  int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +\n                (threadIdx.x % a_gl_rd_delta_o);\n  a_gl_rd += a_gl_rd_delta_o * slice_row;\n  // Shared write index of current thread.\n  int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +\n                (threadIdx.x % a_gl_rd_delta_o);\n  // Shared read index.\n  int a_sh_rd =\n      a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;\n  a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));\n\n  int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +\n                (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;\n  b_gl_rd += b_sh_stride * slice_col;\n  b_gl_rd += b_gl_rd_delta_o * slice_row;\n  int b_sh_wr = threadIdx.x * b_thread_vecs;\n  int b_sh_rd = threadIdx.x * b_thread_vecs;\n\n  // For act_order\n  constexpr int k_iter_size = tb_k / b_sh_wr_iters;\n  int slice_k_start = tb_k * slice_row;\n  int slice_k_finish = slice_k_start + tb_k * slice_iters;\n  int slice_k_start_shared_fetch = slice_k_start;\n  int slice_n_offset = act_s_col_tb_stride * slice_col;\n\n  // No act_order\n  int s_gl_rd;\n  if constexpr (!has_act_order) {\n    if constexpr (group_blocks == -1) {\n      s_gl_rd = s_sh_stride * slice_col + threadIdx.x;\n    } else {\n      s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +\n                s_sh_stride * slice_col + threadIdx.x;\n    }\n  }\n  int s_sh_wr = threadIdx.x;\n  bool s_sh_wr_pred = threadIdx.x < s_sh_stride;\n\n  // We use a different scale layout for grouped and column-wise quantization as\n  // we scale a 
`half2` tile in column-major layout in the former and in\n  // row-major in the latter case.\n  int s_sh_rd;\n  if constexpr (group_blocks != -1)\n    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +\n              (threadIdx.x % 32) / 4;\n  else\n    s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +\n              (threadIdx.x % 32) % 4;\n\n  // Precompute which thread should not read memory in which iterations; this is\n  // needed if there are more threads than required for a certain tilesize or\n  // when the batchsize is not a multiple of 16.\n  bool a_sh_wr_pred[a_sh_wr_iters];\n  #pragma unroll\n  for (int i = 0; i < a_sh_wr_iters; i++)\n    a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;\n\n  // To ensure that writing and reading A tiles to/from shared memory, the\n  // latter in fragment format, is fully bank conflict free, we need to use a\n  // rather fancy XOR-based layout. The key here is that neither reads nor\n  // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the\n  // same shared memory banks. 
Further, it seems (based on NSight-Compute) that\n  // each warp must also write a consecutive memory segment?\n  auto transform_a = [&](int i) {\n    int row = i / a_gl_rd_delta_o;\n    return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;\n  };\n  // Since the computation of this remapping is non-trivial and, due to our main\n  // loop unrolls, all shared memory accesses are static, we simply precompute\n  // both transformed reads and writes.\n  int a_sh_wr_trans[a_sh_wr_iters];\n  #pragma unroll\n  for (int i = 0; i < a_sh_wr_iters; i++)\n    a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);\n  int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];\n  #pragma unroll\n  for (int i = 0; i < b_sh_wr_iters; i++) {\n  #pragma unroll\n    for (int j = 0; j < thread_m_blocks; j++)\n      a_sh_rd_trans[i][j] =\n          transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);\n  }\n\n  // Since B-accesses have non-constant stride they have to be computed at\n  // runtime; we break dependencies between subsequent accesses within a tile\n  // by maintaining multiple pointers (we have enough registers), a tiny\n  // optimization.\n  const int4* B_ptr[b_sh_wr_iters];\n  #pragma unroll\n  for (int i = 0; i < b_sh_wr_iters; i++)\n    B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;\n\n  extern __shared__ int4 sh[];\n  // Shared memory storage for global fetch pipelines.\n  int4* sh_a = sh;\n  int4* sh_b = sh_a + (stages * a_sh_stage);\n  int4* sh_g_idx = sh_b + (stages * b_sh_stage);\n  int4* sh_s = sh_g_idx + (stages * g_idx_stage);\n\n  // Register storage for double buffer of shared memory reads.\n  FragA frag_a[2][thread_m_blocks];\n  I4 frag_b_quant[2][b_thread_vecs];\n  FragC frag_c[thread_m_blocks][4][2];\n  FragS frag_s[2][4];         // No act-order\n  FragS act_frag_s[2][4][4];  // For act-order\n\n  // Zero accumulators.\n  auto zero_accums = [&]() {\n  #pragma unroll\n    for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)\n      
reinterpret_cast<float*>(frag_c)[i] = 0;\n  };\n\n  int sh_first_group_id = -1;\n  int sh_num_groups = -1;\n  constexpr int sh_max_num_groups = 32;\n\n  auto fetch_scales_to_shared = [&](bool is_async, int first_group_id,\n                                    int last_group_id) {\n    sh_first_group_id = first_group_id;\n    sh_num_groups = last_group_id - first_group_id + 1;\n\n    if (sh_num_groups < sh_max_num_groups) {\n      sh_num_groups = sh_max_num_groups;\n    }\n\n    if (sh_first_group_id + sh_num_groups > num_groups) {\n      sh_num_groups = num_groups - sh_first_group_id;\n    }\n\n    int row_offset = first_group_id * s_gl_stride;\n\n    if (is_async) {\n      for (int i = 0; i < sh_num_groups; i++) {\n        if (threadIdx.x < s_sh_stride) {\n          cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x],\n                         &scales_ptr[row_offset + (i * s_gl_stride) +\n                                     slice_n_offset + threadIdx.x]);\n        }\n      }\n    } else {\n      for (int i = 0; i < sh_num_groups; i++) {\n        if (threadIdx.x < s_sh_stride) {\n          sh_s[(i * s_sh_stride) + threadIdx.x] =\n              scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset +\n                         threadIdx.x];\n        }\n      }\n    }\n  };\n  // Asynchronously fetch the next A, B and s tile from global to the next\n  // shared memory pipeline location.\n  auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {\n    if (pred) {\n      int4* sh_a_stage = sh_a + a_sh_stage * pipe;\n  #pragma unroll\n      for (int i = 0; i < a_sh_wr_iters; i++) {\n        cp_async4_pred(\n            &sh_a_stage[a_sh_wr_trans[i]],\n            &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],\n            a_sh_wr_pred[i]);\n      }\n      int4* sh_b_stage = sh_b + b_sh_stage * pipe;\n  #pragma unroll\n      for (int i = 0; i < b_sh_wr_iters; i++) {\n  #pragma unroll\n        for (int j = 0; j < b_thread_vecs; j++) {\n        
  cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);\n        }\n\n        B_ptr[i] += b_gl_rd_delta_o;\n      }\n\n      if constexpr (has_act_order) {\n        // Fetch g_idx thread-block portion\n        int full_pipe = a_off;\n        int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe;\n        if (cur_k < prob_k && cur_k < slice_k_finish) {\n          int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;\n\n          int4 const* cur_g_idx_stage_ptr =\n              reinterpret_cast<int4 const*>(&g_idx[cur_k]);\n\n          if (threadIdx.x < g_idx_stage) {\n            cp_async4_pred(&sh_g_idx_stage[threadIdx.x],\n                           &cur_g_idx_stage_ptr[threadIdx.x]);\n          }\n        }\n      } else {\n        if constexpr (group_blocks != -1) {\n          int4* sh_s_stage = sh_s + s_sh_stage * pipe;\n\n          if constexpr (group_blocks >= thread_k_blocks) {\n            // Only fetch scales if this tile starts a new group\n            if (pipe % (group_blocks / thread_k_blocks) == 0) {\n              if (s_sh_wr_pred) {\n                cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);\n              }\n              s_gl_rd += s_gl_rd_delta;\n            }\n          } else {\n            for (int i = 0; i < s_tb_groups; i++) {\n              if (s_sh_wr_pred) {\n                cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr],\n                          &scales_ptr[s_gl_rd]);\n              }\n              s_gl_rd += s_gl_rd_delta;\n            }\n          }\n        }\n      }\n    }\n    // Insert a fence even when we are winding down the pipeline to ensure that\n    // waiting is also correct at this point.\n    cp_async_fence();\n  };\n\n  // Wait until the next thread tile has been loaded to shared memory.\n  auto wait_for_stage = [&]() {\n    // We only have `stages - 2` active fetches since we are double buffering\n    // and can only issue the next fetch when it is guaranteed that the previous\n    
// shared memory load is fully complete (as it may otherwise be\n    // overwritten).\n    cp_async_wait<stages - 2>();\n    __syncthreads();\n  };\n\n  // Load the next sub-tile from the current location in the shared memory pipe\n  // into the current register buffer.\n  auto fetch_to_registers = [&](int k, int pipe) {\n    int4* sh_a_stage = sh_a + a_sh_stage * pipe;\n  #pragma unroll\n    for (int i = 0; i < thread_m_blocks; i++)\n      ldsm4<scalar_t>(frag_a[k % 2][i],\n                      &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);\n    int4* sh_b_stage = sh_b + b_sh_stage * pipe;\n\n  #pragma unroll\n    for (int i = 0; i < b_thread_vecs; i++) {\n      frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(\n          &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);\n    }\n  };\n\n  bool is_same_group[stages];\n  int same_group_id[stages];\n\n  auto init_same_group = [&](int pipe) {\n    if constexpr (!has_act_order) {\n      is_same_group[pipe] = false;\n      same_group_id[pipe] = 0;\n      return;\n    }\n\n    int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;\n    int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);\n\n    int group_id_1 = sh_g_idx_int_ptr[0];\n    int group_id_2 = sh_g_idx_int_ptr[tb_k - 1];\n\n    is_same_group[pipe] = group_id_1 == group_id_2;\n    same_group_id[pipe] = group_id_1;\n  };\n\n  auto fetch_scales_to_registers = [&](int k, int full_pipe) {\n    int pipe = full_pipe % stages;\n\n    if constexpr (!has_act_order) {\n      // No act-order case\n      if constexpr (group_blocks != -1) {\n        if constexpr (group_blocks >= thread_k_blocks) {\n          int4* sh_s_stage =\n              sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *\n                                   (pipe / (group_blocks / thread_k_blocks)));\n          reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];\n        } else {\n          int warp_id = threadIdx.x / 32;\n          int n_warps = 
thread_n_blocks / 4;\n\n          int warp_row = warp_id / n_warps;\n\n          int cur_k = warp_row * 16;\n          cur_k += k_iter_size * (k % b_sh_wr_iters);\n\n          int k_blocks = cur_k / 16;\n          int cur_group_id = k_blocks / group_blocks;\n\n          int4* sh_s_stage = sh_s + s_sh_stage * pipe;\n\n          reinterpret_cast<int4*>(&frag_s[k % 2])[0] =\n              sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];\n        }\n      }\n\n      return;\n    }\n\n    // Act-order case\n\n    // Determine K of the \"current\" thread-block\n    int cur_k = slice_k_start + tb_k * full_pipe;\n    if (cur_k >= prob_k || cur_k >= slice_k_finish) {\n      return;\n    }\n\n    // Reset (to current thread-block) since we read g_idx portion from the\n    // shared memory\n    cur_k = 0;\n\n    // Progress to current iteration\n    cur_k += k_iter_size * (k % b_sh_wr_iters);\n\n    // Determine \"position\" inside the thread-block (based on warp and\n    // thread-id)\n    int warp_id = threadIdx.x / 32;\n    int n_warps =\n        thread_n_blocks / 4;  // Each warp processes 4 16-size tiles over N\n\n    int warp_row = warp_id / n_warps;\n    int warp_col = warp_id % n_warps;\n\n    cur_k += warp_row * 16;\n\n    int th_id = threadIdx.x % 32;\n    cur_k += (th_id % 4) * 2;  // Due to tensor-core layout for fp16 B matrix\n\n    int s_col_shift =\n        /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) +\n        (th_id / 4) * act_s_col_stride;\n\n    if (is_same_group[pipe]) {\n      if (k % 2 == 0) {\n        *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =\n            sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride +\n                 s_col_shift];\n      } else {\n        *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =\n            *(reinterpret_cast<int4*>(&(act_frag_s[(k - 1) % 2][0][0])));\n      }\n\n      for (int i = 1; i < 4; i++) {\n        *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =\n           
 *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0])));\n      }\n      return;\n    }\n\n    int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;\n    int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);\n\n    constexpr int k_frag_offsets[4] = {0, 1, 8,\n                                       9};  // Tensor core offsets per thread\n\n  #pragma unroll\n    for (int i = 0; i < 4; i++) {\n      int actual_k = cur_k + k_frag_offsets[i];\n\n      int group_id = sh_g_idx_int_ptr[actual_k];\n      int rel_group_id = group_id - sh_first_group_id;\n\n      *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) =\n          sh_s[rel_group_id * s_sh_stride + s_col_shift];\n    }\n  };\n\n  // Execute the actual tensor core matmul of a sub-tile.\n  auto matmul = [&](int k) {\n  // We have the m dimension as the inner loop in order to encourage overlapping\n  // dequantization and matmul operations.\n  #pragma unroll\n    for (int j = 0; j < 4; j++) {\n      FragB frag_b0;\n      FragB frag_b1;\n      if constexpr (num_bits == 4) {\n        int b_quant = frag_b_quant[k % 2][0][j];\n        int b_quant_shift = b_quant >> 8;\n\n        frag_b0 = dequant_4bit<scalar_t>(b_quant);\n        frag_b1 = dequant_4bit<scalar_t>(b_quant_shift);\n\n      } else {\n        int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);\n        int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];\n        int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];\n\n        frag_b0 = dequant_8bit<scalar_t>(b_quant_0);\n        frag_b1 = dequant_8bit<scalar_t>(b_quant_1);\n      }\n\n      // Apply scale to frag_b0\n      if constexpr (has_act_order) {\n        scale4<scalar_t>(frag_b0, act_frag_s[k % 2][0][j],\n                         act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j],\n                         act_frag_s[k % 2][3][j], 0);\n      } else {\n        if constexpr (group_blocks != -1) {\n          scale<scalar_t>(frag_b0, frag_s[k % 2][j], 0);\n        }\n      }\n\n      // Apply 
scale to frag_b1\n      if constexpr (has_act_order) {\n        scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j],\n                         act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j],\n                         act_frag_s[k % 2][3][j], 1);\n\n      } else {\n        if constexpr (group_blocks != -1) {\n          scale<scalar_t>(frag_b1, frag_s[k % 2][j], 1);\n        }\n      }\n\n  #pragma unroll\n      for (int i = 0; i < thread_m_blocks; i++) {\n        mma<scalar_t>(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);\n        mma<scalar_t>(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);\n      }\n    }\n  };\n\n  // Since we slice across the k dimension of a tile in order to increase the\n  // number of warps while keeping the n dimension of a tile reasonable, we have\n  // multiple warps that accumulate their partial sums of the same output\n  // location, which we have to reduce over in the end. We do this in shared\n  // memory.\n  auto thread_block_reduce = [&]() {\n    constexpr int red_off = threads / b_sh_stride_threads / 2;\n    if (red_off >= 1) {\n      int red_idx = threadIdx.x / b_sh_stride_threads;\n      constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;\n      constexpr int red_sh_delta = b_sh_stride_threads;\n      int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +\n                      (threadIdx.x % b_sh_stride_threads);\n\n      // Parallel logarithmic shared memory reduction. 
We make sure to avoid any\n      // unnecessary read or write iterations, e.g., for two warps we write only\n      // once by warp 1 and read only once by warp 0.\n\n  #pragma unroll\n      for (int m_block = 0; m_block < thread_m_blocks; m_block++) {\n  #pragma unroll\n        for (int i = red_off; i > 0; i /= 2) {\n          if (i <= red_idx && red_idx < 2 * i) {\n  #pragma unroll\n            for (int j = 0; j < 4 * 2; j++) {\n              int red_sh_wr =\n                  red_sh_delta * j + (red_sh_rd - red_sh_stride * i);\n              if (i < red_off) {\n                float* c_rd =\n                    reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);\n                float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);\n  #pragma unroll\n                for (int k = 0; k < 4; k++)\n                  reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=\n                      c_rd[k] + c_wr[k];\n              }\n              sh[red_sh_wr] =\n                  reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];\n            }\n          }\n          __syncthreads();\n        }\n        if (red_idx == 0) {\n  #pragma unroll\n          for (int i = 0; i < 4 * 2; i++) {\n            float* c_rd =\n                reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);\n  #pragma unroll\n            for (int j = 0; j < 4; j++)\n              reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=\n                  c_rd[j];\n          }\n        }\n        __syncthreads();\n      }\n    }\n  };\n\n  // Since multiple threadblocks may process parts of the same column slice, we\n  // finally have to globally reduce over the results. 
As the striped\n  // partitioning minimizes the number of such reductions and our outputs are\n  // usually rather small, we perform this reduction serially in L2 cache.\n  auto global_reduce = [&](bool first = false, bool last = false) {\n    // We are very careful here to reduce directly in the output buffer to\n    // maximize L2 cache utilization in this step. To do this, we write out\n    // results in FP16 (but still reduce with FP32 compute).\n    constexpr int active_threads = 32 * thread_n_blocks / 4;\n    if (threadIdx.x < active_threads) {\n      int c_gl_stride = prob_n / 8;\n      int c_gl_wr_delta_o = 8 * c_gl_stride;\n      int c_gl_wr_delta_i = 4 * (active_threads / 32);\n      int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +\n                    4 * (threadIdx.x / 32) + threadIdx.x % 4;\n      c_gl_wr += (2 * thread_n_blocks) * slice_col;\n      constexpr int c_sh_wr_delta = active_threads;\n      int c_sh_wr = threadIdx.x;\n\n      int row = (threadIdx.x % 32) / 4;\n\n      if (!first) {\n  // Interestingly, doing direct global accesses here really seems to mess up\n  // the compiler and lead to slowdowns, hence we also use async-copies even\n  // though these fetches are not actually asynchronous.\n  #pragma unroll\n        for (int i = 0; i < thread_m_blocks * 4; i++) {\n          cp_async4_pred(\n              &sh[c_sh_wr + c_sh_wr_delta * i],\n              &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +\n                 c_gl_wr_delta_i * (i % 2)],\n              i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);\n        }\n        cp_async_fence();\n        cp_async_wait<0>();\n      }\n\n  #pragma unroll\n      for (int i = 0; i < thread_m_blocks * 4; i++) {\n        if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {\n          if (!first) {\n            int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];\n  #pragma unroll\n            for (int j = 0; j < 2 * 4; j++) {\n              reinterpret_cast<float*>(\n          
        &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=\n                  Dtype::num2float(reinterpret_cast<scalar_t*>(&c_red)[j]);\n            }\n          }\n          if (!last) {\n            int4 c;\n  #pragma unroll\n            for (int j = 0; j < 2 * 4; j++) {\n              reinterpret_cast<scalar_t*>(&c)[j] =\n                  Dtype::float2num(reinterpret_cast<float*>(\n                      &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);\n            }\n            C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =\n                c;\n          }\n        }\n      }\n    }\n  };\n\n  // Write out the reduce final result in the correct layout. We only actually\n  // reshuffle matrix fragments in this step, the reduction above is performed\n  // in fragment layout.\n  auto write_result = [&]() {\n    int c_gl_stride = prob_n / 8;\n    constexpr int c_sh_stride = 2 * thread_n_blocks + 1;\n    int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));\n    constexpr int c_sh_rd_delta =\n        c_sh_stride * (threads / (2 * thread_n_blocks));\n\n    int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +\n                  (threadIdx.x % (2 * thread_n_blocks));\n    c_gl_wr += (2 * thread_n_blocks) * slice_col;\n    int c_sh_wr =\n        (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;\n    c_sh_wr += 32 * (threadIdx.x / 32);\n    int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +\n                  (threadIdx.x % (2 * thread_n_blocks));\n\n    int c_gl_wr_end = c_gl_stride * prob_m;\n\n    // We first reorder in shared memory to guarantee the most efficient final\n    // global write patterns\n    auto write = [&](int idx, float c0, float c1, FragS& s) {\n      scalar_t2 res =\n          Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));\n\n      // For per-column quantization we finally apply the scale here (only for\n      // 4-bit)\n      if constexpr 
(!has_act_order && group_blocks == -1 && num_bits == 4) {\n        res = __hmul2(res, s[0]);\n      }\n\n      ((scalar_t2*)sh)[idx] = res;\n    };\n\n    if (threadIdx.x / 32 < thread_n_blocks / 4) {\n  #pragma unroll\n      for (int i = 0; i < thread_m_blocks; i++) {\n  #pragma unroll\n        for (int j = 0; j < 4; j++) {\n          int wr = c_sh_wr + 8 * j;\n          write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],\n                frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);\n          write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],\n                frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);\n          write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],\n                frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);\n          write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],\n                frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);\n        }\n        c_sh_wr += 16 * (4 * c_sh_stride);\n      }\n    }\n    __syncthreads();\n\n  #pragma unroll\n    for (int i = 0;\n         i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));\n         i++) {\n      if (c_gl_wr < c_gl_wr_end) {\n        C[c_gl_wr] = sh[c_sh_rd];\n        c_gl_wr += c_gl_wr_delta;\n        c_sh_rd += c_sh_rd_delta;\n      }\n    }\n  };\n\n  // Start global fetch and register load pipelines.\n  auto start_pipes = [&]() {\n\n  #pragma unroll\n    for (int i = 0; i < stages - 1; i++) {\n      if (has_act_order && i == 0) {\n        int last_g_idx = slice_k_start + stages * tb_k * 2;\n        if (last_g_idx >= prob_k) {\n          last_g_idx = prob_k - 1;\n        }\n        fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);\n      }\n      fetch_to_shared(i, i, i < slice_iters);\n    }\n\n    zero_accums();\n    wait_for_stage();\n    init_same_group(0);\n    fetch_to_registers(0, 0);\n    fetch_scales_to_registers(0, 0);\n    a_gl_rd += a_gl_rd_delta_o * (stages - 1);\n    
slice_k_start_shared_fetch += tb_k * (stages - 1);\n  };\n  if (slice_iters) {\n    start_pipes();\n  }\n\n  // Main loop.\n  while (slice_iters) {\n    // We unroll over both the global fetch and the register load pipeline to\n    // ensure all shared memory accesses are static. Note that both pipelines\n    // have even length meaning that the next iteration will always start at\n    // index 0.\n\n  #pragma unroll\n    for (int pipe = 0; pipe < stages;) {\n  #pragma unroll\n      for (int k = 0; k < b_sh_wr_iters; k++) {\n        fetch_to_registers(k + 1, pipe % stages);\n        fetch_scales_to_registers(k + 1, pipe);\n        if (k == b_sh_wr_iters - 2) {\n          fetch_to_shared((pipe + stages - 1) % stages, pipe,\n                          slice_iters >= stages);\n          pipe++;\n          wait_for_stage();\n          init_same_group(pipe % stages);\n        }\n        matmul(k);\n      }\n      slice_iters--;\n      if (slice_iters == 0) {\n        break;\n      }\n    }\n\n    a_gl_rd += a_gl_rd_delta_o * stages;\n    slice_k_start += tb_k * stages;\n    slice_k_start_shared_fetch += tb_k * stages;\n\n    if constexpr (has_act_order) {\n      int first_group_id = g_idx[slice_k_start];\n      int last_g_idx = slice_k_start + stages * tb_k * 2;\n      if (last_g_idx >= prob_k) {\n        last_g_idx = prob_k - 1;\n      }\n      int last_group_id = g_idx[last_g_idx];\n      if (last_group_id >= sh_first_group_id + sh_num_groups) {\n        fetch_scales_to_shared(false, first_group_id, last_group_id);\n        __syncthreads();\n      }\n    }\n\n    // Process results and, if necessary, proceed to the next column slice.\n    // While this pattern may not be the most readable, other ways of writing\n    // the loop seemed to yield noticeably worse performance after compilation.\n    if (slice_iters == 0) {\n      cp_async_wait<0>();\n      bool last = slice_idx == slice_count - 1;\n      // For per-column scales, we only fetch them here in the final step 
before\n      // write-out\n      if constexpr (!has_act_order && group_blocks == -1) {\n        if constexpr (num_bits == 8) {\n          if (s_sh_wr_pred) {\n            cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);\n          }\n          cp_async_fence();\n        } else {\n          if (last) {\n            if (s_sh_wr_pred) {\n              cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);\n            }\n            cp_async_fence();\n          }\n        }\n      }\n\n      thread_block_reduce();\n      if constexpr (!has_act_order && group_blocks == -1) {\n        if constexpr (num_bits == 8) {\n          cp_async_wait<0>();\n          __syncthreads();\n          if (threadIdx.x / 32 < thread_n_blocks / 4) {\n            reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];\n            reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];\n          }\n\n        } else {\n          if (last) {\n            cp_async_wait<0>();\n            __syncthreads();\n            if (threadIdx.x / 32 < thread_n_blocks / 4) {\n              reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];\n              reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];\n            }\n          }\n        }\n      }\n\n      // For 8-bit channelwise, we apply the scale before the global reduction\n      // that converts the fp32 results to fp16 (so that we avoid possible\n      // overflow in fp16)\n      if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) {\n        if (threadIdx.x / 32 < thread_n_blocks / 4) {\n  #pragma unroll\n          for (int i = 0; i < thread_m_blocks; i++) {\n  #pragma unroll\n            for (int j = 0; j < 4; j++) {\n              scale_float<scalar_t>(\n                  reinterpret_cast<float*>(&frag_c[i][j][0][0]),\n                  frag_s[j / 2][2 * (j % 2) + 0]);\n              scale_float<scalar_t>(\n                  reinterpret_cast<float*>(&frag_c[i][j][0][2]),\n                  frag_s[j / 2][2 * (j % 2) + 
0]);\n\n              scale_float<scalar_t>(\n                  reinterpret_cast<float*>(&frag_c[i][j][1][0]),\n                  frag_s[j / 2][2 * (j % 2) + 1]);\n              scale_float<scalar_t>(\n                  reinterpret_cast<float*>(&frag_c[i][j][1][2]),\n                  frag_s[j / 2][2 * (j % 2) + 1]);\n            }\n          }\n        }\n      }\n\n      if (slice_count > 1) {  // only globally reduce if there is more than one\n                              // block in a slice\n        barrier_acquire(&locks[slice_col], slice_idx);\n        global_reduce(slice_idx == 0, last);\n        barrier_release(&locks[slice_col], last);\n      }\n      if (last)  // only the last block in a slice actually writes the result\n        write_result();\n      slice_row = 0;\n      slice_col_par++;\n      slice_col++;\n      init_slice();\n      if (slice_iters) {\n        a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +\n                  (threadIdx.x % a_gl_rd_delta_o);\n  #pragma unroll\n        for (int i = 0; i < b_sh_wr_iters; i++)\n          B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;\n        if (slice_col == 0) {\n  #pragma unroll\n          for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;\n        }\n\n        // Update slice k/n for scales loading\n        if constexpr (has_act_order) {\n          slice_k_start = tb_k * slice_row;\n          slice_k_finish = slice_k_start + tb_k * slice_iters;\n          slice_k_start_shared_fetch = slice_k_start;\n          slice_n_offset = act_s_col_tb_stride * slice_col;\n\n        } else {\n          s_gl_rd = s_sh_stride * slice_col + threadIdx.x;\n        }\n\n        start_pipes();\n      }\n    }\n  }\n}\n\n  #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS,                \\\n                    THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \\\n    else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS &&     \\\n             
thread_n_blocks == THREAD_N_BLOCKS &&                             \\\n             thread_k_blocks == THREAD_K_BLOCKS &&                             \\\n             has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \\\n             num_threads == NUM_THREADS) {                                     \\\n      cudaFuncSetAttribute(                                                    \\\n          Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS,             \\\n                 THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \\\n                 GROUP_BLOCKS>,                                                \\\n          cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);        \\\n      Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS,                 \\\n             THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER,     \\\n             GROUP_BLOCKS><<<blocks, NUM_THREADS, max_shared_mem, stream>>>(   \\\n          A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n,   \\\n          prob_k, locks);                                                      \\\n    }\n\ntypedef struct {\n  int thread_k;\n  int thread_n;\n  int num_threads;\n} thread_config_t;\n\ntypedef struct {\n  int max_m_blocks;\n  thread_config_t tb_cfg;\n} exec_config_t;\n\nthread_config_t small_batch_thread_configs[] = {\n    // Ordered by priority\n\n    // thread_k, thread_n, num_threads\n    {128, 128, 256},\n    {64, 128, 128},\n    {128, 64, 128},\n};\n\nthread_config_t large_batch_thread_configs[] = {\n    // Ordered by priority\n\n    // thread_k, thread_n, num_threads\n    {64, 256, 256},\n    {64, 128, 128},\n    {128, 64, 128},\n\n};\n\nint get_scales_cache_size(thread_config_t const& th_config, int prob_m,\n                          int prob_n, int prob_k, int num_bits, int group_size,\n                          bool has_act_order, bool is_k_full) {\n  bool cache_scales_chunk = has_act_order && !is_k_full;\n\n  
int tb_n = th_config.thread_n;\n  int tb_k = th_config.thread_k;\n\n  // Get max scale groups per thread-block\n  int tb_groups;\n  if (group_size == -1) {\n    tb_groups = 1;\n  } else if (group_size == 0) {\n    tb_groups = div_ceil(tb_k, 32);  // Worst case is 32 group size\n  } else {\n    tb_groups = div_ceil(tb_k, group_size);\n  }\n\n  if (cache_scales_chunk) {\n    int load_groups =\n        tb_groups * pipe_stages * 2;     // Chunk size is 2x pipeline over dim K\n    load_groups = max(load_groups, 32);  // We load at least 32 scale groups\n    return load_groups * tb_n * 2;\n\n  } else {\n    int tb_scales = tb_groups * tb_n * 2;\n\n    return tb_scales * pipe_stages;\n  }\n}\n\nbool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,\n                         int prob_m, int prob_n, int prob_k, int num_bits,\n                         int scales_cache_size, int max_shared_mem) {\n  int pack_factor = 32 / num_bits;\n\n  // Get B size\n  int tb_k = th_config.thread_k;\n  int tb_n = th_config.thread_n;\n\n  int b_size = (tb_k * tb_n / pack_factor) * 4;\n\n  // Get A size\n  int m_blocks = div_ceil(prob_m, 16);\n  int tb_max_m = 16;\n\n  while (true) {\n    if (m_blocks >= max_m_blocks) {\n      tb_max_m *= max_m_blocks;\n      break;\n    }\n\n    max_m_blocks--;\n    if (max_m_blocks == 0) {\n      TORCH_CHECK(false, \"Unexpected m_blocks = \", m_blocks);\n    }\n  }\n\n  int a_size = (tb_max_m * tb_k) * 2;\n\n  float pipe_size = (a_size + b_size) * pipe_stages;\n\n  TORCH_CHECK(max_shared_mem / 2 > scales_cache_size);  // Sanity\n\n  return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);\n}\n\nbool is_valid_config(thread_config_t const& th_config, int max_m_blocks,\n                     int prob_m, int prob_n, int prob_k, int num_bits,\n                     int group_size, bool has_act_order, bool is_k_full,\n                     int max_shared_mem) {\n  // Sanity\n  if (th_config.thread_k == -1 || th_config.thread_n == -1 ||\n 
     th_config.num_threads == -1) {\n    return false;\n  }\n\n  // Verify K/N are divisible by thread K/N\n  if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {\n    return false;\n  }\n\n  // Verify min for thread K/N\n  if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {\n    return false;\n  }\n\n  // num_threads must be at least 128 (= 4 warps)\n  if (th_config.num_threads < 128) {\n    return false;\n  }\n\n  //  Determine cache for scales\n  int scales_cache_size =\n      get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,\n                            group_size, has_act_order, is_k_full);\n\n  // Check that pipeline fits into cache\n  if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,\n                           num_bits, scales_cache_size, max_shared_mem)) {\n    return false;\n  }\n\n  return true;\n}\n\nexec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,\n                                      int num_bits, int group_size,\n                                      bool has_act_order, bool is_k_full,\n                                      int max_shared_mem) {\n  int max_m_blocks = 4;\n  while (max_m_blocks > 0) {\n    if (prob_m <= 16) {\n      for (auto th_config : small_batch_thread_configs) {\n        if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,\n                            num_bits, group_size, has_act_order, is_k_full,\n                            max_shared_mem)) {\n          return exec_config_t{max_m_blocks, th_config};\n        }\n      }\n    } else {\n      for (auto th_config : large_batch_thread_configs) {\n        if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,\n                            num_bits, group_size, has_act_order, is_k_full,\n                            max_shared_mem)) {\n          return exec_config_t{max_m_blocks, th_config};\n        }\n      }\n    }\n\n    max_m_blocks--;  
// Process less M blocks per invocation to reduce cache\n                     // usage\n  }\n\n  return exec_config_t{0, {-1, -1, -1}};\n}\n\n  #define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS)           \\\n    __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)   \\\n    __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)   \\\n    __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)   \\\n    __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS)   \\\n                                                                       \\\n    __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \\\n    __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)  \\\n    __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)  \\\n    __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)  \\\n                                                                       \\\n    __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \\\n    __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)  \\\n    __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)  \\\n    __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)  \\\n                                                                       \\\n    __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \\\n    __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)  \\\n    __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)  \\\n    __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)  \\\n                                                                       \\\n    __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \\\n    __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS)  \\\n    __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS)  \\\n    __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, 
NUM_THREADS)\n\ntemplate <typename scalar_t>\nvoid marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,\n                     void* g_idx, void* perm, void* a_tmp, int prob_m,\n                     int prob_n, int prob_k, void* workspace, int num_bits,\n                     bool has_act_order, bool is_k_full, int num_groups,\n                     int group_size, int dev, cudaStream_t stream, int thread_k,\n                     int thread_n, int sms, int max_par) {\n  TORCH_CHECK(num_bits == 4 || num_bits == 8,\n              \"num_bits must be 4 or 8. Got = \", num_bits);\n  TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, \"Invalid MNK = [\", prob_m,\n              \", \", prob_n, \", \", prob_k, \"]\");\n\n  int tot_m = prob_m;\n  int tot_m_blocks = div_ceil(tot_m, 16);\n  int pad = 16 * tot_m_blocks - tot_m;\n\n  if (sms == -1) {\n    cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);\n  }\n\n  int max_shared_mem = 0;\n  cudaDeviceGetAttribute(&max_shared_mem,\n                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);\n  TORCH_CHECK(max_shared_mem > 0);\n\n  // Set thread config\n  exec_config_t exec_cfg;\n  if (thread_k != -1 && thread_n != -1) {\n    // User-defined config\n    exec_cfg =\n        exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}};\n  } else {\n    // Auto config\n    exec_cfg =\n        determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size,\n                                has_act_order, is_k_full, max_shared_mem);\n  }\n\n  TORCH_CHECK(exec_cfg.max_m_blocks > 0 &&\n                  is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks,\n                                  prob_m, prob_n, prob_k, num_bits, group_size,\n                                  has_act_order, is_k_full, max_shared_mem),\n              \"Invalid thread config: max_m_blocks = \", exec_cfg.max_m_blocks,\n              \", thread_k = \", exec_cfg.tb_cfg.thread_k,\n              \", thread_n = 
\", exec_cfg.tb_cfg.thread_n,\n              \", num_threads = \", exec_cfg.tb_cfg.num_threads, \" for MKN = [\",\n              prob_m, \", \", prob_k, \", \", prob_n, \"] and num_bits = \", num_bits,\n              \", group_size = \", group_size,\n              \", has_act_order = \", has_act_order, \", is_k_full = \", is_k_full,\n              \", max_shared_mem = \", max_shared_mem);\n\n  int num_threads = exec_cfg.tb_cfg.num_threads;\n  thread_k = exec_cfg.tb_cfg.thread_k;\n  thread_n = exec_cfg.tb_cfg.thread_n;\n\n  int thread_k_blocks = thread_k / 16;\n  int thread_n_blocks = thread_n / 16;\n\n  int blocks = sms;\n\n  TORCH_CHECK(prob_n % thread_n == 0, \"prob_n = \", prob_n,\n              \" is not divisible by thread_n = \", thread_n);\n  TORCH_CHECK(prob_k % thread_k == 0, \"prob_k = \", prob_k,\n              \" is not divisible by thread_k = \", thread_k);\n\n  int group_blocks = 0;\n  if (has_act_order) {\n    if (is_k_full) {\n      TORCH_CHECK(group_size != -1);\n      group_blocks = group_size / 16;\n      TORCH_CHECK(prob_k % group_blocks == 0, \"prob_k = \", prob_k,\n                  \" is not divisible by group_blocks = \", group_blocks);\n    } else {\n      TORCH_CHECK(group_size == 0);\n      group_blocks = 0;\n    }\n\n  } else {\n    if (group_size == -1) {\n      group_blocks = -1;\n    } else {\n      group_blocks = group_size / 16;\n      TORCH_CHECK(prob_k % group_blocks == 0, \"prob_k = \", prob_k,\n                  \" is not divisible by group_blocks = \", group_blocks);\n    }\n  }\n\n  const int4* A_ptr = (const int4*)A;\n  const int4* B_ptr = (const int4*)B;\n  int4* C_ptr = (int4*)C;\n  const int4* s_ptr = (const int4*)s;\n  const int* g_idx_ptr = (const int*)g_idx;\n  const int* perm_ptr = (const int*)perm;\n  int4* a_tmp_ptr = (int4*)a_tmp;\n\n  int* locks = (int*)workspace;\n\n  if (has_act_order) {\n    // Permute A columns\n    int block_rows = div_ceil(prob_m, blocks);\n    permute_cols_kernel<<<blocks, default_threads, 
0, stream>>>(\n        A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows);\n    A_ptr = a_tmp_ptr;\n  }\n\n  // If we have a full K, then we can run the non-act-order version of Marlin\n  // (since the weight rows are reordered by increasing group ids, and by having\n  // a full K, we have full original groups)\n  if (is_k_full) {\n    has_act_order = false;\n  }\n\n  // Main loop\n  for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) {\n    int thread_m_blocks = tot_m_blocks - i;\n    prob_m = tot_m - 16 * i;\n    int par = 1;\n    if (thread_m_blocks > exec_cfg.max_m_blocks) {\n      // Note that parallel > 1 currently only works for inputs without any\n      // padding\n      par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks);\n      if (par > max_par) par = max_par;\n      prob_m = (16 * exec_cfg.max_m_blocks) * par;\n      i += exec_cfg.max_m_blocks * (par - 1);\n      thread_m_blocks = exec_cfg.max_m_blocks;\n    }\n\n\n\n    // Define kernel configurations\n#define undefined_error TORCH_CHECK(false, \"Unsupported shapes: MNK = [\" + str(prob_m) + \", \" + \\\n    str(prob_n) + \", \" + str(prob_k) + \"]\" + \\\n        \", has_act_order = \" + str(has_act_order) + \\\n        \", num_groups = \" + str(num_groups) + \\\n        \", group_size = \" + str(group_size) + \\\n        \", thread_m_blocks = \" + str(thread_m_blocks) + \\\n        \", thread_n_blocks = \" + str(thread_n_blocks) + \\\n        \", thread_k_blocks = \" + str(thread_k_blocks));\n\n\n    if (num_bits == 4 && num_threads == 256)\n    {\n        if (false) {\n        }\n        CALL_IF(4, 32, 2, 256)\n        CALL_IF(4, 16, 4, 256)\n        CALL_IF(4, 8, 8, 256)\n        else {\n            undefined_error\n        }\n    }\n    else if (num_bits == 4 && num_threads == 128)\n    {\n        if (false) {\n        }\n        CALL_IF(4, 8, 4, 128)\n        CALL_IF(4, 4, 8, 128)\n        else {\n            undefined_error\n        }\n    }\n    else if (num_bits == 
8 && num_threads == 256)\n    {\n        if (false) {\n        }\n        CALL_IF(8, 32, 2, 256)\n        CALL_IF(8, 16, 4, 256)\n        CALL_IF(8, 8, 8, 256)\n        else {\n            undefined_error\n        }\n    }\n    else if (num_bits == 8 && num_threads == 128)\n    {\n        if (false) {\n        }\n        CALL_IF(8, 8, 4, 128)\n        CALL_IF(8, 4, 8, 128)\n        else {\n            undefined_error\n        }\n    }\n    else {\n        undefined_error\n    }\n\n    A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;\n    C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;\n  }\n}\n\n}  // namespace gptq_marlin\n\ntorch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,\n                               torch::Tensor& b_scales, torch::Tensor& g_idx,\n                               torch::Tensor& perm, torch::Tensor& workspace,\n                               int64_t num_bits, int64_t size_m, int64_t size_n,\n                               int64_t size_k, bool is_k_full) {\n  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));\n  // Verify num_bits\n  TORCH_CHECK(num_bits == 4 || num_bits == 8,\n              \"num_bits must be 4 or 8. 
Got = \", num_bits);\n  int pack_factor = 32 / num_bits;\n\n  // Verify A\n  TORCH_CHECK(a.size(0) == size_m, \"Shape mismatch: a.size(0) = \", a.size(0),\n              \", size_m = \", size_m);\n  TORCH_CHECK(a.size(1) == size_k, \"Shape mismatch: a.size(1) = \", a.size(1),\n              \", size_k = \", size_k);\n\n  // Verify B\n  TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, \"size_k = \", size_k,\n              \" is not divisible by tile_size = \", gptq_marlin::tile_size);\n  TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),\n              \"Shape mismatch: b_q_weight.size(0) = \", b_q_weight.size(0),\n              \", size_k = \", size_k, \", tile_size = \", gptq_marlin::tile_size);\n  TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0,\n              \"b_q_weight.size(1) = \", b_q_weight.size(1),\n              \" is not divisible by tile_size = \", gptq_marlin::tile_size);\n  int actual_size_n =\n      (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;\n  TORCH_CHECK(size_n == actual_size_n, \"size_n = \", size_n,\n              \", actual_size_n = \", actual_size_n);\n\n  // Verify device and strides\n  TORCH_CHECK(a.device().is_cuda(), \"A is not on GPU\");\n  TORCH_CHECK(a.is_contiguous(), \"A is not contiguous\");\n\n  TORCH_CHECK(b_q_weight.device().is_cuda(), \"b_q_weight is not on GPU\");\n  TORCH_CHECK(b_q_weight.is_contiguous(), \"b_q_weight is not contiguous\");\n\n  TORCH_CHECK(b_scales.device().is_cuda(), \"b_scales is not on GPU\");\n  TORCH_CHECK(b_scales.is_contiguous(), \"b_scales is not contiguous\");\n\n  TORCH_CHECK(g_idx.device().is_cuda(), \"g_idx is not on GPU\");\n  TORCH_CHECK(g_idx.is_contiguous(), \"g_idx is not contiguous\");\n\n  TORCH_CHECK(perm.device().is_cuda(), \"perm is not on GPU\");\n  TORCH_CHECK(perm.is_contiguous(), \"perm is not contiguous\");\n\n  // Alloc buffers\n  auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());\n  torch::Tensor c = 
torch::empty({size_m, size_n}, options);\n  torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);\n\n  // thread_k: `k` size of a thread_tile in `weights` (can usually be left as\n  // auto -1)\n  int thread_k = -1;\n  // thread_n: `n` size of a thread_tile in `weights` (can usually be left as\n  // auto -1)\n  int thread_n = -1;\n  // sms: number of SMs to use for the kernel (can usually be left as auto -1)\n  int sms = -1;\n\n  // Verify g_idx and perm\n  TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) ||\n                  (g_idx.size(0) == size_k && perm.size(0) == size_k),\n              \"Unexpected g_idx.size(0) = \", g_idx.size(0),\n              \" and perm.size(0) = \", perm.size(0),\n              \", where size_k = \", size_k);\n\n  // Detect groupsize and act_order\n  int num_groups = -1;\n  int group_size = -1;\n  bool has_act_order = g_idx.size(0) != 0;\n\n  int b_rank = b_scales.sizes().size();\n  TORCH_CHECK(b_rank == 2, \"b_scales rank = \", b_rank, \" is not 2\");\n  TORCH_CHECK(b_scales.size(1) == size_n, \"b_scales dim 1 = \", b_scales.size(1),\n              \" is not size_n = \", size_n);\n  num_groups = b_scales.size(0);\n\n  if (has_act_order) {\n    if (is_k_full) {\n      TORCH_CHECK(num_groups > 1, \"For act_order, num_groups must be > 1\");\n      TORCH_CHECK(size_k % num_groups == 0, \"size_k = \", size_k,\n                  \", is not divisible by num_groups = \", num_groups);\n      group_size = size_k / num_groups;\n    } else {\n      group_size = 0;\n    }\n\n  } else {\n    if (num_groups > 1) {\n      TORCH_CHECK(\n          size_k % num_groups == 0, \"size_k = \", size_k,\n          \", is not divisible by b_scales.size(0) = \", b_scales.size(0));\n      group_size = size_k / num_groups;\n    } else {\n      group_size = -1;\n    }\n  }\n\n  // Verify workspace size\n  TORCH_CHECK(\n      size_n % gptq_marlin::min_thread_n == 0, \"size_n = \", size_n,\n      \", is not divisible by min_thread_n = \", 
gptq_marlin::min_thread_n);\n  int min_workspace_size =\n      (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;\n  TORCH_CHECK(workspace.numel() >= min_workspace_size,\n              \"workspace.numel = \", workspace.numel(),\n              \" is below min_workspace_size = \", min_workspace_size);\n\n  int dev = a.get_device();\n  if (a.scalar_type() == at::ScalarType::Half) {\n    gptq_marlin::marlin_mm_f16i4<half>(\n        a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),\n        b_scales.data_ptr<at::Half>(), g_idx.data_ptr(), perm.data_ptr(),\n        a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,\n        workspace.data_ptr(), num_bits, has_act_order, is_k_full, num_groups,\n        group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,\n        thread_n, sms, gptq_marlin::max_par);\n  } else if (a.scalar_type() == at::ScalarType::BFloat16) {\n    gptq_marlin::marlin_mm_f16i4<nv_bfloat16>(\n        a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),\n        c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),\n        g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),\n        size_m, size_n, size_k, workspace.data_ptr(), num_bits, has_act_order,\n        is_k_full, num_groups, group_size, dev,\n        at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,\n        gptq_marlin::max_par);\n  } else {\n    TORCH_CHECK(false, \"gptq_marlin_gemm only supports bfloat16 and float16\");\n  }\n\n  return c;\n}\n\n#endif\n"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cuh",
    "content": "// Adapted from\n// https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin\n// Copyrigth 2024 The vLLM team.\n// Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n#pragma once\n\n#include <torch/all.h>\n\n#include <ATen/cuda/CUDAContext.h>\n#include <c10/cuda/CUDAGuard.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n#include <iostream>\n\nnamespace gptq_marlin {\n\n// 8 warps are a good choice since every SM has 4 schedulers and having more\n// than 1 warp per schedule allows some more latency hiding. At the same time,\n// we want relatively few warps to have many registers per warp and small tiles.\nstatic constexpr int default_threads = 256;\n\nstatic constexpr int pipe_stages =\n    4;  // 4 pipeline stages fit into shared memory\n\nstatic constexpr int min_thread_n = 64;\nstatic constexpr int min_thread_k = 64;\n\nstatic constexpr int tile_size = 16;\nstatic constexpr int max_par = 16;\n\ntemplate <typename T, int n>\nstruct Vec {\n  T elems[n];\n  __device__ T& operator[](int i) { return elems[i]; }\n};\n\nusing I4 = Vec<int, 4>;\n\nconstexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }\n\n#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800) || defined (__HIP_PLATFORM_AMD__)\n// No support for async\n#else\n\n__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,\n                                      bool pred = true) {\n  const int BYTES = 16;\n  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));\n  asm volatile(\n      \"{\\n\"\n      \"   .reg .pred p;\\n\"\n      \"   setp.ne.b32 p, %0, 0;\\n\"\n      \"   @p cp.async.cg.shared.global [%1], [%2], %3;\\n\"\n      \"}\\n\" ::\"r\"((int)pred),\n      \"r\"(smem), \"l\"(glob_ptr), \"n\"(BYTES));\n}\n\n__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {\n  const int BYTES = 16;\n  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));\n  asm 
volatile(\n      \"{\\n\"\n      \"   cp.async.cg.shared.global [%0], [%1], %2;\\n\"\n      \"}\\n\" ::\"r\"(smem),\n      \"l\"(glob_ptr), \"n\"(BYTES));\n}\n\n__device__ inline void cp_async_fence() {\n  asm volatile(\"cp.async.commit_group;\\n\" ::);\n}\n\ntemplate <int n>\n__device__ inline void cp_async_wait() {\n  asm volatile(\"cp.async.wait_group %0;\\n\" ::\"n\"(n));\n}\n\n#endif\n\n}  // namespace gptq_marlin\n"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/cuda/gptq_marlin/gptq_marlin_dtypes.cuh",
    "content": "// Adapted from\n// https://github.com/vllm-project/vllm/tree/main/csrc/quantization/gptq_marlin\n// Copyrigth 2024 The vLLM team.\n// Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n#ifndef _data_types_cuh\n#define _data_types_cuh\n#include \"gptq_marlin.cuh\"\n#include <cuda_fp16.h>\n#include <cuda_bf16.h>\n\n#ifdef __HIP_PLATFORM_AMD__\ntypedef __hip_bfloat16 nv_bfloat16;\ntypedef __hip_bfloat162 nv_bfloat162;\n#endif\n\nnamespace gptq_marlin {\n\ntemplate <typename scalar_t>\nclass ScalarType {};\n\ntemplate <>\nclass ScalarType<half> {\n public:\n  using scalar_t = half;\n  using scalar_t2 = half2;\n\n  // Matrix fragments for tensor core instructions; their precise layout is\n  // documented here:\n  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type\n  using FragA = Vec<half2, 4>;\n  using FragB = Vec<half2, 2>;\n  using FragC = Vec<float, 4>;\n  using FragS = Vec<half2, 1>;\n\n  static __device__ float inline num2float(const half x) {\n    return __half2float(x);\n  }\n\n  static __device__ half2 inline num2num2(const half x) {\n    return __half2half2(x);\n  }\n\n  static __device__ half2 inline nums2num2(const half x1, const half x2) {\n    return __halves2half2(x1, x2);\n  }\n\n  static __host__ __device__ half inline float2num(const float x) {\n    return __float2half(x);\n  }\n};\n\ntemplate <>\nclass ScalarType<nv_bfloat16> {\n public:\n  using scalar_t = nv_bfloat16;\n  using scalar_t2 = nv_bfloat162;\n\n  using FragA = Vec<nv_bfloat162, 4>;\n  using FragB = Vec<nv_bfloat162, 2>;\n  using FragC = Vec<float, 4>;\n  using FragS = Vec<nv_bfloat162, 1>;\n\n#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800\n  static __device__ float inline num2float(const nv_bfloat16 x) {\n    return __bfloat162float(x);\n  }\n\n  static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {\n    return __bfloat162bfloat162(x);\n  }\n\n  static __device__ 
nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,\n                                                  const nv_bfloat16 x2) {\n    return __halves2bfloat162(x1, x2);\n  }\n\n  static __host__ __device__ nv_bfloat16 inline float2num(const float x) {\n    return __float2bfloat16(x);\n  }\n#endif\n};\n\n}  // namespace gptq_marlin\n\n#endif\n"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/cuda/gptq_marlin/ops.h",
    "content": "/**\n * @Description  :  \n * @Author       : Azure\n * @Date         : 2024-07-22 09:27:55\n * @Version      : 1.0.0\n * @LastEditors  : Azure \n * @LastEditTime : 2024-07-26 08:35:00\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. \n**/\n#pragma once\n\n#include <torch/library.h>\n#include <torch/extension.h>\n#include <torch/torch.h>\n\ntorch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,\n                               torch::Tensor& b_scales, torch::Tensor& g_idx,\n                               torch::Tensor& perm, torch::Tensor& workspace,\n                               int64_t num_bits, int64_t size_m, int64_t size_n,\n                               int64_t size_k, bool is_k_full);\n\n// torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,\n//                                  int64_t size_k, int64_t size_n,\n//                                  int64_t num_bits);"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/cuda/setup.py",
    "content": "\nfrom setuptools import setup, Extension\nfrom torch.utils import cpp_extension\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\nsetup(\n    name='KTransformersOps',\n    ext_modules=[\n        CUDAExtension(\n            'KTransformersOps', [\n                'custom_gguf/dequant.cu',\n                'binding.cpp',\n                'gptq_marlin/gptq_marlin.cu',\n                # 'gptq_marlin_repack.cu',\n            ],\n            extra_compile_args={\n                'cxx': ['-O3'],\n                'nvcc': [\n                    '-O3',\n                    '--use_fast_math',\n                    '-Xcompiler', '-fPIC',\n                ]\n            },\n        )\n    ],\n    cmdclass={'build_ext': BuildExtension}\n)"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/cuda/test_dequant.py",
    "content": "import os\nimport sys\nsys.path.insert(0,\"/home/zbx/ktransformers\")\nfrom ktransformers.util.custom_loader import GGUFLoader\nimport torch\n\ngguf_loader_1 = GGUFLoader(\"/mnt/data/model/DeepseekV3-q4km-gguf\")\ngguf_loader_2 = GGUFLoader(\"/mnt/data/chenht/model/gguf_for_ktransformers/DeepSeek-V3-bf16/\")\n\ntorch.set_default_dtype(torch.bfloat16)\n\ntensor_1 = gguf_loader_1.load_gguf_tensor(\"blk.0.attn_kv_a_mqa.weight\", \"cuda\")\ntensor_2 = gguf_loader_2.load_gguf_tensor(\"blk.0.attn_kv_a_mqa.weight\", \"cuda\")\n\nprint(tensor_1[0, -64:])\nprint(tensor_2[0, -64:])"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/examples/test_attention.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nDescription  :  \nAuthor       : Jianwei Dong\nDate         : 2024-08-28 10:32:05\nVersion      : 1.0.0\nLastEditors  : chenht2022 \nLastEditTime : 2024-08-28 10:32:05\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n\"\"\"\nimport os, sys\nimport time\n\nsys.path.append(os.path.dirname(__file__) + \"/../build\")\nimport cpuinfer_ext\nfrom flash_attn import flash_attn_with_kvcache\nimport torch\n\nlayer_num = 10\nkv_head_num = 8\nq_head_num = 32\nhead_dim = 128\nblock_len = 128\nanchor_num = 1\ncache_seqlen = 8192\ncache_seqlens = torch.tensor([cache_seqlen], dtype=torch.int32, device=\"cpu\")\nseqlens_zero = torch.zeros((1,), dtype=torch.int32, device=\"cpu\")\nanchor_type = cpuinfer_ext.kvcache.AnchorType.DYNAMIC\nkv_type = cpuinfer_ext.kvcache.ggml_type.FP16\nretrieval_type = cpuinfer_ext.kvcache.RetrievalType.LAYER\nlayer_step: int = 1\ntoken_step: int = 1\nlayer_offset: int = 0\nmax_thread_num: int = 2\nmax_batch_size: int = 1\nmax_block_num: int = 512\nCPUInfer = cpuinfer_ext.CPUInfer(max_thread_num)\nvalidation_iter = 100\n\nwith torch.inference_mode(mode=True):\n    config = cpuinfer_ext.kvcache.KVCacheConfig(\n        layer_num,\n        kv_head_num,\n        q_head_num,\n        head_dim,\n        block_len,\n        anchor_num,\n        anchor_type,\n        kv_type,\n        retrieval_type,\n        layer_step,\n        token_step,\n        layer_offset,\n        max_block_num,\n        max_batch_size,\n        max_thread_num,\n    )\n    local_kvcache = cpuinfer_ext.kvcache.KVCache(config)\n\n    kvcaches = []\n    block_table = (\n        torch.arange(max_block_num, dtype=torch.int32, device=\"cpu\")\n        .contiguous()\n        .view(1, -1)\n    )\n\n    for layer_idx in range(layer_num):\n        k_cache = torch.randn(\n            (1, cache_seqlen, kv_head_num, head_dim), dtype=torch.float16, device=\"cpu\"\n        ).contiguous()\n        v_cache = torch.randn(\n            
(1, cache_seqlen, kv_head_num, head_dim), dtype=torch.float16, device=\"cpu\"\n        ).contiguous()\n\n        CPUInfer.submit(\n            local_kvcache.update_kvcache_fp16(\n                k_cache.data_ptr(),\n                v_cache.data_ptr(),\n                layer_idx,\n                block_table.data_ptr(),\n                1,\n                max_block_num,\n                seqlens_zero.data_ptr(),\n                cache_seqlen,\n            )\n        )\n        CPUInfer.sync()\n\n        kvcaches.append((k_cache.to(\"cuda\"), v_cache.to(\"cuda\")))\n\n    # validation\n    for i in range(validation_iter):\n\n        k_cache = kvcaches[i % layer_num][0]\n        v_cache = kvcaches[i % layer_num][1]\n        input = torch.randn(\n            (1, 1, q_head_num, head_dim), dtype=torch.float16, device=\"cpu\"\n        ).contiguous()\n        output = torch.empty(\n            (1, 1, q_head_num, head_dim), dtype=torch.float16, device=\"cpu\"\n        ).contiguous()\n\n        # attn_lse: (bsz, q_len, q_head_num)\n        attn_lse = torch.empty(\n            (1, 1, q_head_num), dtype=torch.float32, device=\"cpu\"\n        ).contiguous()\n        input = input / 100\n\n        CPUInfer.submit(\n            local_kvcache.attn(\n                input.data_ptr(),\n                output.data_ptr(),\n                attn_lse.data_ptr(),\n                i % layer_num,\n                0,\n                1,\n                1,\n                max_block_num,\n                block_table.data_ptr(),\n                cache_seqlens.data_ptr(),\n                -1,\n                -1,\n                -1,\n            )\n        )\n        CPUInfer.sync()\n        # print(\"cpuinfer output\", output)\n\n        t_output = flash_attn_with_kvcache(\n            q=input.to(\"cuda\"),\n            k_cache=k_cache,\n            v_cache=v_cache,\n            cache_seqlens=cache_seqlens.to(\"cuda\"),\n        )\n        # print(\"torch output\", t_output)\n\n        diff 
= torch.mean(torch.abs(output.to(\"cuda\") - t_output)) / torch.mean(\n            torch.abs(t_output)\n        )\n        print(\"diff = \", diff)\n        assert diff < 0.001\n"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/examples/test_linear.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : chenht2022\nDate         : 2024-07-25 10:32:05\nVersion      : 1.0.0\nLastEditors  : chenht2022 \nLastEditTime : 2024-08-06 10:36:59\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport os, sys\nimport time\nsys.path.append(os.path.dirname(__file__) + '/../build')\nimport cpuinfer_ext\nimport torch\n\ninput_size = 16384\noutput_size = 5120\nstride = 32\ngroup_max_len = 1024\nproj_type = 1 # ggml_type::GGML_TYPE_F16\nhidden_type = 1 # ggml_type::GGML_TYPE_F16\nqlen = 30\nlayer_num = 10\nCPUInfer = cpuinfer_ext.CPUInfer(48)\nvalidation_iter = 100\n\nwith torch.inference_mode(mode=True):\n    linears = []\n    projs = []\n    for _ in range(layer_num):\n        proj = torch.randn((output_size, input_size), dtype=torch.float16, device = \"cuda\").to(\"cpu\").contiguous()\n        config = cpuinfer_ext.linear.LinearConfig(input_size, output_size, stride, group_max_len, proj.data_ptr(), proj_type, hidden_type)\n        linear = cpuinfer_ext.linear.Linear(config)\n        projs.append(proj)\n        linears.append(linear)\n\n    # validation\n    for i in range(validation_iter):\n        linear = linears[i % layer_num]\n        input = torch.randn((qlen, input_size), dtype=torch.float16).contiguous()\n        output = torch.empty((qlen, output_size), dtype=torch.float16).contiguous()\n        input = input / 100\n\n        CPUInfer.submit(\n            linear.forward(\n                qlen,\n                input.data_ptr(),\n                output.data_ptr()\n            )\n        )\n        CPUInfer.sync()\n        # print('cpuinfer output', output)\n\n        proj = projs[i%layer_num]\n        t_output = torch.mm(input, proj.t())\n        # print('torch output', t_output)\n\n        diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))\n        print('diff = ', diff)\n        assert(diff < 0.001)\n"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/examples/test_mlp.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : chenht2022\nDate         : 2024-07-25 10:32:05\nVersion      : 1.0.0\nLastEditors  : chenht2022 \nLastEditTime : 2024-08-06 10:37:28\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport os, sys\nimport time\nsys.path.append(os.path.dirname(__file__) + '/../build')\nimport cpuinfer_ext\nimport torch\n\nhidden_size = 5120\nintermediate_size = 3072\nstride = 32\ngroup_max_len = 1024\ngate_type = 1 # ggml_type::GGML_TYPE_F16\nup_type = 1 # ggml_type::GGML_TYPE_F16\ndown_type = 1 # ggml_type::GGML_TYPE_F16\nhidden_type = 1 # ggml_type::GGML_TYPE_F16\nqlen = 30\nlayer_num = 10\nCPUInfer = cpuinfer_ext.CPUInfer(48)\nvalidation_iter = 100\n\ndef act_fn(x):\n    return x / (1.0 + torch.exp(-x))\n\ndef mlp_torch(input, gate_proj, up_proj, down_proj):\n    gate_buf = torch.mm(input, gate_proj.t())\n    up_buf = torch.mm(input, up_proj.t())\n    intermediate = act_fn(gate_buf) * up_buf\n    ret = torch.mm(intermediate, down_proj.t())\n    return ret\n\nwith torch.inference_mode(mode=True):\n    mlps = []\n    gate_projs = []\n    up_projs = []\n    down_projs = []\n    for _ in range(layer_num):\n        gate_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float16, device = \"cuda\").to(\"cpu\").contiguous()\n        up_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float16, device = \"cuda\").to(\"cpu\").contiguous()\n        down_proj = torch.randn((hidden_size, intermediate_size), dtype=torch.float16, device = \"cuda\").to(\"cpu\").contiguous()\n        config = cpuinfer_ext.mlp.MLPConfig(hidden_size, intermediate_size, stride, group_max_len, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(), gate_type, up_type, down_type, hidden_type)\n        mlp = cpuinfer_ext.mlp.MLP(config)\n        gate_projs.append(gate_proj)\n        up_projs.append(up_proj)\n        down_projs.append(down_proj)\n        mlps.append(mlp)\n\n    
# validation\n    for i in range(validation_iter):\n        mlp = mlps[i % layer_num]\n        input = torch.randn((qlen, hidden_size), dtype=torch.float16).contiguous()\n        output = torch.empty((qlen, hidden_size), dtype=torch.float16).contiguous()\n        input = input / 100\n\n        CPUInfer.submit(\n            mlp.forward(\n                qlen,\n                input.data_ptr(), \n                output.data_ptr()\n            )\n        )\n        CPUInfer.sync()\n        # print('cpuinfer output', output)\n\n        gate_proj = gate_projs[i%layer_num]\n        up_proj = up_projs[i%layer_num]\n        down_proj = down_projs[i%layer_num]\n        t_output = mlp_torch(input, gate_proj, up_proj, down_proj)\n        # print('torch output', t_output)\n\n        diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))\n        print('diff = ', diff)\n        assert(diff < 0.001)\n"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/examples/test_moe.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : chenht2022\nDate         : 2024-07-25 10:32:05\nVersion      : 1.0.0\nLastEditors  : chenht2022 \nLastEditTime : 2024-08-06 10:38:05\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport os, sys\nimport time\nsys.path.append(os.path.dirname(__file__) + '/../build')\nimport cpuinfer_ext\nimport torch\n\nexpert_num = 160\nhidden_size = 5120\nintermediate_size = 1536\nstride = 32\ngroup_min_len = 10\ngroup_max_len = 1024\ngate_type = 1 # ggml_type::GGML_TYPE_F16\nup_type = 1 # ggml_type::GGML_TYPE_F16\ndown_type = 1 # ggml_type::GGML_TYPE_F16\nhidden_type = 1 # ggml_type::GGML_TYPE_F16\nn_routed_experts = 6\nqlen = 30\nlayer_num = 10\nCPUInfer = cpuinfer_ext.CPUInfer(48)\nvalidation_iter = 100\n\ndef act_fn(x):\n    return x / (1.0 + torch.exp(-x))\n\ndef mlp_torch(input, gate_proj, up_proj, down_proj):\n    gate_buf = torch.mm(input, gate_proj.t())\n    up_buf = torch.mm(input, up_proj.t())\n    intermediate = act_fn(gate_buf) * up_buf\n    ret = torch.mm(intermediate, down_proj.t())\n    return ret\n\ndef moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):\n    cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))\n    cnts.scatter_(1, expert_ids, 1)\n    tokens_per_expert = cnts.sum(dim=0)\n    idxs = expert_ids.view(-1).argsort()\n    sorted_tokens = input[idxs // expert_ids.shape[1]]\n\n    outputs = []\n    start_idx = 0\n    for i, num_tokens in enumerate(tokens_per_expert):\n        end_idx = start_idx + num_tokens\n        if num_tokens == 0:\n            continue\n        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n        expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])\n        outputs.append(expert_out)\n        start_idx = end_idx\n\n    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n    new_x = torch.empty_like(outs)\n    
new_x[idxs] = outs\n    t_output = (\n        new_x.view(*expert_ids.shape, -1)\n        .type(weights.dtype)\n        .mul_(weights.unsqueeze(dim=-1))\n        .sum(dim=1)\n        .type(new_x.dtype)\n    )\n    return t_output\n\nwith torch.inference_mode(mode=True):\n    moes = []\n    gate_projs = []\n    up_projs = []\n    down_projs = []\n    for _ in range(layer_num):\n        gate_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float16, device = \"cuda\").to(\"cpu\").contiguous()\n        up_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=torch.float16, device = \"cuda\").to(\"cpu\").contiguous()\n        down_proj = torch.randn((expert_num, hidden_size, intermediate_size), dtype=torch.float16, device = \"cuda\").to(\"cpu\").contiguous()\n        config = cpuinfer_ext.moe.MOEConfig(expert_num, n_routed_experts, hidden_size, intermediate_size, stride, group_min_len, group_max_len, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(), gate_type, up_type, down_type, hidden_type)\n        moe = cpuinfer_ext.moe.MOE(config)\n        gate_projs.append(gate_proj)\n        up_projs.append(up_proj)\n        down_projs.append(down_proj)\n        moes.append(moe)\n\n    # validation\n    for i in range(validation_iter):\n        expert_ids = torch.stack([torch.randperm(expert_num)[:n_routed_experts] for _ in range(qlen)]).contiguous()\n        weights = torch.rand((qlen, n_routed_experts), dtype=torch.float32).contiguous()\n        input = torch.randn((qlen, hidden_size), dtype=torch.float16).contiguous()\n        output = torch.empty((qlen, hidden_size), dtype=torch.float16).contiguous()\n        input = input / 100\n        \n        moe = moes[i % layer_num]\n        CPUInfer.submit(\n            moe.forward( \n                qlen,\n                n_routed_experts, \n                expert_ids.data_ptr(), \n                weights.data_ptr(), \n                input.data_ptr(), \n                
output.data_ptr()\n            )\n        )\n        CPUInfer.sync()\n        # print('cpuinfer output', output)\n\n        gate_proj = gate_projs[i%layer_num]\n        up_proj = up_projs[i%layer_num]\n        down_proj = down_projs[i%layer_num]\n        t_output = moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj)\n        # print('torch output', t_output)\n\n        diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))\n        print('diff = ', diff)\n        assert(diff < 0.001)\n"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/examples/test_sft_amx_moe.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : chenht2022\nDate         : 2024-07-25 10:32:05\nVersion      : 1.0.0\nLastEditors  : chenht2022 \nLastEditTime : 2024-08-06 10:38:05\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport os, sys\nimport time\nsys.path.append(os.path.dirname(__file__) + '/../build')\nimport cpuinfer_ext\nimport torch\nfrom pathlib import Path\nimport numpy as np\n\nexpert_num = 10\nhidden_size = 5120\nintermediate_size = 1536\nmax_len = 1024\n\nn_routed_experts = 2\nqlen = 600\nlayer_num = 10\nnum_threads = 112\nvalidation_iter = 1\nLAYER_IDX  = 0\nDUMP_DIR   = Path(os.getenv(\"SFT_DEBUG_PATH\", \"debug\"))\n\ndtype = torch.bfloat16\ngradtype = torch.bfloat16\n# torch.backends.cuda.matmul.allow_tf32 = False\n\nimport shutil\nfolder_path = \"/home/lpl/kt-sft/debug\"\nif os.path.exists(folder_path):\n    shutil.rmtree(folder_path)\nos.makedirs(folder_path)\n\ndef act_fn(x):\n    return x / (1.0 + torch.exp(-x))\n\ndef silu_fwd(x: torch.Tensor) -> torch.Tensor:\n    return x / (1. + torch.exp(-x))\n\ndef silu_grad(x: torch.Tensor) -> torch.Tensor:\n    \"\"\"SiLU激活函数的梯度\"\"\"\n    sigmoid_x = torch.sigmoid(x)\n    return sigmoid_x * (1. + x * (1. - sigmoid_x))\n\nclass SiLU(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, inp):\n        ctx.save_for_backward(inp)\n        return silu_fwd(inp)\n\n    @staticmethod\n    def backward(ctx, grad_out):\n        (inp,) = ctx.saved_tensors\n        sig = torch.sigmoid(inp)\n        return grad_out * (sig + inp * sig * (1. 
- sig))\n\nsilu = SiLU.apply   # differentiable version\n\n# -------------------- Torch MLP / MoE reference implementation --------------------\ndef mlp_torch(x, gate, up, down, req_grad=False):\n    g = torch.mm(x, gate.t())\n    u = torch.mm(x, up.t())\n    # req_grad is kept for call-site compatibility; both paths use the same forward\n    inter = silu_fwd(g) * u\n    return torch.mm(inter, down.t())\n\ndef moe_torch(x, eid, w, gate, up, down, req_grad=False):\n    \"\"\"eid: [T,k]  int64,  w: [T,k] float\"\"\"\n    T, k = eid.shape\n    tok_cnt = torch.zeros(expert_num, dtype=torch.int64)\n    for e in eid.view(-1):\n        tok_cnt[e] += 1\n    # pack tokens by expert\n    order = eid.view(-1).argsort()\n    packed = x[order // k]\n\n    outputs, start = [], 0\n    for e in range(expert_num):\n        num = tok_cnt[e].item()\n        if not num:\n            continue\n        end = start + num\n        o = mlp_torch(packed[start:end], gate[e], up[e], down[e], req_grad)\n        outputs.append(o)\n        start = end\n    if outputs:\n        out_all = torch.cat(outputs, 0)\n    else:\n        out_all = packed.new_empty(0, hidden_size)\n\n    # restore the original order and apply the routing weights\n    out_restore = torch.empty_like(out_all)\n    out_restore[order] = out_all\n    out_restore = out_restore.view(T, k, hidden_size)\n    out = (out_restore * w.unsqueeze(-1)).sum(1)\n    return out\n\ndef moe_backward_python(x, eid, w, gate, up, down, grad_output, gate_u_cache, up_v_cache):\n    \"\"\"\n    Python emulation of the C++ MoE backward pass, mirroring the implementation in sft_moe.hpp\n    Args:\n        x: input [T, hidden_size]\n        eid: expert_ids [T, k]\n        w: weights [T, k]\n        gate, up, down: weight matrices\n        grad_output: output gradient [T, hidden_size]\n        gate_u_cache, up_v_cache: intermediate results cached during forward\n    Returns:\n        grad_input: input gradient [T, hidden_size]\n    \"\"\"\n    T, k = eid.shape\n    expert_num = gate.shape[0]\n    hidden_size = gate.shape[2]\n    intermediate_size = gate.shape[1]\n    \n    print(\"\\n========== Python Backward detailed cross-check ==========\")\n    print(f\"Input shapes: T={T}, k={k}, 
hidden_size={hidden_size}, intermediate_size={intermediate_size}\")\n    print(f\"\\n--- Python Token 0 ---\")\n    print(f\"  Expert 0: weight={w[0, 0].item():.6f}\")\n    \n    # initialize the gradient\n    grad_input = torch.zeros_like(x, dtype=torch.float32)\n    \n    # print(f\"grad_output:{grad_output}\")\n    # print(f\"gate_u_cache:{gate_u_cache}\")\n    # print(f\"up_v_cache:{up_v_cache}\")\n    \n    # organize the data the way the C++ code does: group by expert\n    # 1. count the tokens handled by each expert\n    expert_token_counts = torch.zeros(expert_num, dtype=torch.int64)\n    for i in range(T):\n        for j in range(k):\n            expert_token_counts[eid[i, j]] += 1\n    \n    # 2. build the expert-to-token mapping\n    expert_token_indices = [[] for _ in range(expert_num)]\n    expert_token_positions = [[] for _ in range(expert_num)]\n    \n    for i in range(T):\n        for j in range(k):\n            expert_id = int(eid[i, j].item())\n            expert_token_indices[expert_id].append(i)\n            expert_token_positions[expert_id].append(j)\n    \n    # 3. 
allocate local storage for each expert\n    max_tokens_per_expert = int(expert_token_counts.max().item()) if expert_token_counts.max() > 0 else 0\n    \n    # local storage (mirrors m_local_*_ptr_ in the C++ code)\n    local_input = torch.zeros(expert_num, max_tokens_per_expert, hidden_size, dtype=torch.float32)\n    local_gate_output = torch.zeros(expert_num, max_tokens_per_expert, intermediate_size, dtype=torch.float32)\n    local_up_output = torch.zeros(expert_num, max_tokens_per_expert, intermediate_size, dtype=torch.float32)\n    local_down_output_grad = torch.zeros(expert_num, max_tokens_per_expert, hidden_size, dtype=torch.float32)\n    local_down_input_grad = torch.zeros(expert_num, max_tokens_per_expert, intermediate_size, dtype=torch.float32)\n    local_gate_output_grad = torch.zeros(expert_num, max_tokens_per_expert, intermediate_size, dtype=torch.float32)\n    local_up_output_grad = torch.zeros(expert_num, max_tokens_per_expert, intermediate_size, dtype=torch.float32)\n    local_gate_input_grad = torch.zeros(expert_num, max_tokens_per_expert, hidden_size, dtype=torch.float32)\n    local_up_input_grad = torch.zeros(expert_num, max_tokens_per_expert, hidden_size, dtype=torch.float32)\n    \n    # 4. copy the input data and gradients into local storage\n    for expert_id in range(expert_num):\n        for local_idx, (token_idx, expert_pos) in enumerate(zip(expert_token_indices[expert_id], expert_token_positions[expert_id])):\n            local_input[expert_id, local_idx] = x[token_idx].to(torch.float32)\n            local_down_output_grad[expert_id, local_idx] = grad_output[token_idx].to(torch.float32)\n    \n    # 5. 
recompute the forward intermediate results (mirrors the forward computation in C++)\n    for expert_id in range(expert_num):\n        num_tokens = expert_token_counts[expert_id]\n        if num_tokens == 0:\n            continue\n            \n        # compute the gate and up outputs\n        local_input_expert = local_input[expert_id, :num_tokens]  # [num_tokens, hidden_size]\n        gate_output = torch.mm(local_input_expert, gate[expert_id].to(torch.float32).t())  # [num_tokens, intermediate_size]\n        up_output = torch.mm(local_input_expert, up[expert_id].to(torch.float32).t())      # [num_tokens, intermediate_size]\n        \n        # apply the activation function (computed for reference only; not stored)\n        gate_output_activated = silu_fwd(gate_output) * up_output\n        \n        local_gate_output[expert_id, :num_tokens] = gate_output\n        local_up_output[expert_id, :num_tokens] = up_output\n        \n    for expert_id in range(expert_num):\n        num_tokens = expert_token_counts[expert_id]\n        if num_tokens == 0:\n            continue\n        # print(f\"local_down_output_grad_E_{expert_id}: {local_down_output_grad[expert_id, :num_tokens]}\")\n        # print(f\"shape:{local_down_output_grad[expert_id, :num_tokens].shape}\")\n        # torch.save(local_down_output_grad[expert_id, :num_tokens], f\"debug/py_layer0_E_End{expert_id}_down_output_grad_.pt\")\n        # torch.save(local_gate_output[expert_id, :num_tokens], f\"debug/py_layer0_E_End{expert_id}_gate_output_.pt\")\n        # torch.save(local_up_output[expert_id, :num_tokens], f\"debug/py_layer0_E_End{expert_id}_up_output_.pt\")\n    \n    # 6. 
compute down_input_grad (mirrors the down_t_bc_ computation in C++)\n    for expert_id in range(expert_num):\n        num_tokens = expert_token_counts[expert_id]\n        if num_tokens == 0:\n            continue\n        # down_input_grad = down_proj_t @ output_grad\n        down_input_grad = torch.mm(local_down_output_grad[expert_id, :num_tokens], down[expert_id].to(torch.float32))  # [num_tokens, intermediate_size]\n        local_down_input_grad[expert_id, :num_tokens] = down_input_grad\n            \n    for expert_id in range(expert_num):\n        num_tokens = expert_token_counts[expert_id]\n        if num_tokens == 0:\n            continue\n        # torch.save(local_gate_output_grad[expert_id, :num_tokens], f\"debug/py_layer0_E_End{expert_id}_gate_output_grad_.pt\")\n        # torch.save(local_up_output_grad[expert_id, :num_tokens], f\"debug/py_layer0_E_End{expert_id}_up_output_grad_.pt\")\n        torch.save(local_down_output_grad[expert_id, :num_tokens], f\"debug/py_layer0_E_End{expert_id}_down_output_grad_.pt\")\n        # torch.save(down[expert_id].to(torch.float32), f\"debug/py_layer0_E_End{expert_id}_down_weight_.pt\")\n        torch.save(local_gate_output[expert_id, :num_tokens], f\"debug/py_layer0_E_End{expert_id}_gate_output_.pt\")\n    \n    # 7. 
compute gate_output_grad and up_output_grad (mirrors the core computation in C++)\n    for expert_id in range(expert_num):\n        num_tokens = expert_token_counts[expert_id]\n        if num_tokens == 0:\n            continue\n            \n        for local_idx in range(num_tokens):\n            token_idx = expert_token_indices[expert_id][local_idx]\n            expert_pos = expert_token_positions[expert_id][local_idx]\n            weight = w[token_idx, expert_pos].item()\n            \n            # only print debug info for the first expert of the first token\n            should_print = (token_idx == 0 and expert_pos == 0)\n            \n            # fetch the intermediate results for the current token\n            gate_u = local_gate_output[expert_id, local_idx]  # [intermediate_size]\n            up_v = local_up_output[expert_id, local_idx]      # [intermediate_size]\n            down_input_grad_token = local_down_input_grad[expert_id, local_idx]  # [intermediate_size]\n            \n            # apply the routing weight\n            down_input_grad_token = down_input_grad_token * weight\n            \n            if should_print:\n                print(f\"    down_input_grad first 5 values: {down_input_grad_token[:5].tolist()}\")\n            \n            # gate_output_grad = down_input_grad * up_v * silu_grad(gate_u)\n            gate_output_grad = down_input_grad_token * up_v * silu_grad(gate_u)\n            \n            # up_output_grad = down_input_grad * silu_fwd(gate_u)\n            up_output_grad = down_input_grad_token * silu_fwd(gate_u)\n            \n            if should_print:\n                print(f\"    gate_output_grad first 5 values: {gate_output_grad[:5].tolist()}\")\n                print(f\"    up_output_grad first 5 values: {up_output_grad[:5].tolist()}\")\n            \n            local_gate_output_grad[expert_id, local_idx] = gate_output_grad\n            local_up_output_grad[expert_id, local_idx] = up_output_grad\n    \n    # 8. 
compute gate_input_grad and up_input_grad (mirrors the matrix multiplications in C++)\n    for expert_id in range(expert_num):\n        num_tokens = expert_token_counts[expert_id]\n        if num_tokens == 0:\n            continue\n            \n        # gate_input_grad = gate_proj_t @ gate_output_grad\n        gate_input_grad = torch.mm(local_gate_output_grad[expert_id, :num_tokens], \n                                  gate[expert_id].to(torch.float32))  # [num_tokens, hidden_size]\n        \n        # up_input_grad = up_proj_t @ up_output_grad\n        up_input_grad = torch.mm(local_up_output_grad[expert_id, :num_tokens], \n                                up[expert_id].to(torch.float32))  # [num_tokens, hidden_size]\n        \n        local_gate_input_grad[expert_id, :num_tokens] = gate_input_grad\n        local_up_input_grad[expert_id, :num_tokens] = up_input_grad\n        \n        # print debug info for the first token\n        if expert_id == 0 and num_tokens > 0:\n            token_idx = expert_token_indices[expert_id][0]\n            expert_pos = expert_token_positions[expert_id][0]\n            if token_idx == 0 and expert_pos == 0:\n                print(f\"    gate_input_grad first 5 values: {gate_input_grad[0, :5].tolist()}\")\n                print(f\"    up_input_grad first 5 values: {up_input_grad[0, :5].tolist()}\")\n    \n    # 9. 
accumulate the gradients from all experts into the final output (mirrors the final accumulation in C++)\n    for token_idx in range(T):\n        token_grad = torch.zeros(hidden_size, dtype=torch.float32)\n        \n        for expert_pos in range(k):\n            expert_id = int(eid[token_idx, expert_pos].item())\n            \n            # find this token's local index within expert_id\n            local_idx = expert_token_indices[expert_id].index(token_idx)\n            \n            # accumulate the gate and up input gradients\n            token_grad += local_gate_input_grad[expert_id, local_idx]\n            token_grad += local_up_input_grad[expert_id, local_idx]\n        \n        grad_input[token_idx] = token_grad\n        \n        # print the final result for the first token\n        if token_idx == 0:\n            print(f\"  Token 0 final input_grad first 5 values: {token_grad[:5].tolist()}\")\n            \n    # print(f\"gate_input_grad:{gate_input_grad}\")\n    # print(f\"up_input_grad:{up_input_grad}\")\n    # print(f\"grad_input:{grad_input}\")\n    \n    return grad_input\n\n# --------------------------- main test ---------------------------\ndef test_amx_moe_two_round():\n    # ------------ build the weights ------------\n    gate_proj = torch.randn(expert_num, intermediate_size, hidden_size,\n                            dtype=torch.bfloat16, requires_grad=True).contiguous()\n    up_proj   = torch.randn_like(gate_proj)\n    down_proj = torch.randn(expert_num, hidden_size, intermediate_size,\n                            dtype=torch.bfloat16, requires_grad=True).contiguous()\n    \n    # gate_proj_t = gate_proj.transpose(1, 2).contiguous() # shape: (E, H, I)\n    # up_proj_t   = up_proj.transpose(1, 2).contiguous()\n    # down_proj_t   = down_proj.transpose(1, 2).contiguous()\n\n    # ------------ SFT-AMX object ------------\n    cfg = cpuinfer_ext.sft_moe.SFT_AMX_MOEConfig(\n        expert_num, n_routed_experts,\n        hidden_size, intermediate_size,\n        max_len,\n        gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr()\n    )    \n    moe_cpp = cpuinfer_ext.sft_moe.SFT_AMXInt8_MOE(cfg)\n\n    \n    cpu_infer = 
cpuinfer_ext.CPUInfer(num_threads)\n    \n    cpu_infer.submit(moe_cpp.load_weights())\n    cpu_infer.sync() # ATTENTION: do not forget to sync after loading weights\n    \n    expert_ids = torch.stack(\n        [torch.randperm(expert_num)[:n_routed_experts] for _ in range(qlen)]).contiguous()\n\n    weights = torch.rand(qlen, n_routed_experts, dtype=torch.float32).contiguous()\n\n    input_pt  = (torch.randn((qlen, hidden_size), dtype=dtype) / 100)\\\n                .detach().requires_grad_(True).contiguous()\n    input_cpp = input_pt.detach().clone().requires_grad_(True).contiguous()\n\n    # ------------- forward -------------\n    # Torch reference\n    out_ref = moe_torch(input_pt, expert_ids, weights,\n                        gate_proj, up_proj, down_proj, True)\n    out_ref.retain_grad()\n\n    # cache the forward intermediates for the Python backward\n    gate_u_cache = []\n    up_v_cache = []\n    \n    # emulate the forward pass and cache the intermediate results\n    for token_idx in range(qlen):\n        token_gate_u = []\n        token_up_v = []\n        for expert_pos in range(n_routed_experts):\n            expert_id = int(expert_ids[token_idx, expert_pos].item())\n            # compute the gate and up outputs\n            gate_u = torch.mm(input_pt[token_idx:token_idx+1].to(torch.float32), gate_proj[expert_id].to(torch.float32).t()).squeeze()\n            up_v = torch.mm(input_pt[token_idx:token_idx+1].to(torch.float32), up_proj[expert_id].to(torch.float32).t()).squeeze()\n            token_gate_u.append(gate_u)\n            token_up_v.append(up_v)\n        gate_u_cache.append(token_gate_u)\n        up_v_cache.append(token_up_v)\n        \n    flop_fwd = 6 * qlen * n_routed_experts * hidden_size * intermediate_size\n    flop_bwd = 18 * qlen * n_routed_experts * hidden_size * intermediate_size\n\n    # C++ AMX forward\n    out_cpp = torch.empty_like(out_ref, dtype=dtype).contiguous()\n    t0 = time.time()\n    cpu_infer.submit(moe_cpp.forward(\n        qlen, n_routed_experts,\n        expert_ids.data_ptr(), weights.data_ptr(),\n        
input_cpp.data_ptr(), out_cpp.data_ptr()))\n    cpu_infer.sync()\n    t1 = time.time()\n    diff_fwd = (out_cpp.to(torch.float32) - out_ref.to(torch.float32)).abs()\n    print(f\"out_cpp.to(torch.float32):{out_cpp.to(torch.float32)}, out_ref.to(torch.float32):{out_ref.to(torch.float32)}\")\n    rel_fwd  = diff_fwd.mean() / out_ref.abs().mean()\n    print(f\"Forward   diff: {rel_fwd.item():.3e} | time {t1-t0:.4f}s | \"\n            f\"TFLOPS {flop_fwd/(t1-t0)/1e12:.2f}\")\n    \n\n    # ------------- backward -------------\n    grad_out = torch.randn_like(out_ref, dtype=gradtype).contiguous()\n    grad_out_cpp = grad_out.clone().contiguous()\n    grad_in_cpp  = torch.zeros_like(input_cpp, dtype=gradtype).contiguous()\n\n    # Torch backward\n    for p in (gate_proj, up_proj, down_proj, input_pt):\n        if p.grad is not None:\n            p.grad.zero_()\n    t2 = time.time()\n    out_ref.backward(grad_out, retain_graph=True)\n    t3 = time.time()\n    print(f\"PyTorch backward time {t3-t2:.4f}s | \"\n            f\"TFLOPS {flop_bwd/(t3-t2)/1e12:.2f}\")\n\n    # Python backward (mirrors the C++ logic), detailed version\n    t4_py = time.time()\n    grad_in_python = moe_backward_python(\n        input_pt, expert_ids, weights,\n        gate_proj, up_proj, down_proj,\n        grad_out.to(torch.float32), gate_u_cache, up_v_cache)\n    t5_py = time.time()\n    print(f\"Python   backward time {t5_py-t4_py:.4f}s | \"\n            f\"TFLOPS {flop_bwd/(t5_py-t4_py)/1e12:.2f}\")\n\n    # C++ backward\n    t4 = time.time()\n    print(\"Before backward\")\n    cpu_infer.submit(moe_cpp.backward(\n        qlen, n_routed_experts,\n        expert_ids.data_ptr(), weights.data_ptr(), input_cpp.data_ptr(),\n        grad_out_cpp.data_ptr(),\n        grad_in_cpp.data_ptr()))\n    cpu_infer.sync()\n    t5 = time.time()\n    print(\"After backward\")\n    print(f\"C++      backward time {t5-t4:.4f}s | \"\n            f\"TFLOPS {flop_bwd/(t5-t4)/1e12:.2f}\")\n\n    # compare the three backward results\n    gcpp = 
grad_in_cpp.to(torch.float32)\n    gref = input_pt.grad.to(torch.float32) if input_pt.grad is not None else torch.zeros_like(input_pt, dtype=torch.float32)\n    gpy = grad_in_python.to(torch.float32)\n    \n    print(f\"C++ AMX backward:{gcpp}\", '\\n', '\\n', f\"python backward:{gpy}\")\n    \n    # 对比结果\n    rel_bwd_cpp = (gcpp - gref).abs().mean() / gref.abs().mean()\n    rel_bwd_py = (gpy - gref).abs().mean() / gref.abs().mean()\n    rel_bwd_cpp_py = (gcpp - gpy).abs().mean() / gpy.abs().mean()\n    \n    print(f\"Torch vs C++:    {rel_bwd_cpp.item():.3e}\")\n    print(f\"Torch vs Python: {rel_bwd_py.item():.3e}\")\n    print(f\"C++ vs Python:   {rel_bwd_cpp_py.item():.3e}\")\n    \n    # 检查是否对拍成功\n    if rel_bwd_cpp_py.item() < 5e-2:\n        print(\"✅ C++和Python backward对拍成功!\")\n    else:\n        print(\"❌ C++和Python backward对拍失败，存在显著差异\")\n        \n    \n    # manual_check(expert_ids)\n\ndef load_bf16(stub, shape):\n    with open(stub + \".bf16\", \"rb\") as f:\n        return torch.frombuffer(f.read(), dtype=torch.bfloat16).view(shape).float()\ndef load_f16(stub, shape):\n    with open(stub+\".f16\",'rb') as f:\n        return torch.frombuffer(f.read(), dtype=torch.float16).view(shape).float()\ndef load_f32(stub, shape):\n    with open(stub+\".f32\",'rb') as f:\n        return torch.frombuffer(f.read(), dtype=torch.float32).view(shape)\ndef load_uint8(stub, shape):\n    with open(stub+\".uint8\",'rb') as f:\n        return torch.frombuffer(f.read(), dtype=torch.uint8).view(shape)\ndef load_int8(stub, shape):\n    with open(stub+\".int8\",'rb') as f:\n        return torch.frombuffer(f.read(), dtype=torch.int8).view(shape)\n\n# 通用加载函数\ndef load_dump_tensor(experts_idx: int, name: str, shape: tuple, Ename: str = \"E_Before\"):\n    \"\"\"\n    根据 experts_idx / name / shape 读取 dump 文件，并返回 torch.Tensor\n    \"\"\"\n    stub = DUMP_DIR / f\"layer{LAYER_IDX}_{Ename}{experts_idx}_{name}\"\n    if stub.with_suffix(\".bf16\").exists():\n        return 
load_bf16(str(stub), shape)\n    elif stub.with_suffix(\".f16\").exists():\n        return load_f16(str(stub), shape)\n    elif stub.with_suffix(\".f32\").exists():\n        return load_f32(str(stub), shape)\n    elif stub.with_suffix(\".uint8\").exists():\n        return load_uint8(str(stub), shape)\n    elif stub.with_suffix(\".int8\").exists():\n        return load_int8(str(stub), shape)\n    else:\n        raise FileNotFoundError(f\"{stub} (none of bf16/f16/f32/u8/i8 exists)\")\n    \ndef load_bin(path, n, k):\n    # Read n*k float32 values from the file\n    data = np.fromfile(path, dtype=np.float32)\n    assert data.size == n * k\n    data = data.reshape(n, k)\n    return torch.from_numpy(data).to(torch.bfloat16)    \n\ndef check_nan(name, shape):\n    stub1 = DUMP_DIR / f\"{name}\"\n    if stub1.with_suffix(\".bf16\").exists():\n        cpp_bef = load_bf16(str(stub1), shape)\n    elif stub1.with_suffix(\".f16\").exists():\n        cpp_bef = load_f16(str(stub1), shape)\n    elif stub1.with_suffix(\".f32\").exists():\n        cpp_bef = load_f32(str(stub1), shape)\n    elif stub1.with_suffix(\".int8\").exists():\n        # Assign instead of returning early so the int8 dump is checked as well\n        cpp_bef = load_int8(str(stub1), shape)\n    else:\n        print(\"dump missing or of unknown type\"); return\n\n    print(f\"{name}:{cpp_bef}\")\n    print(f\" shape : {cpp_bef.shape}\")\n    print(f\" dtype : {cpp_bef.dtype}\")\n\n    finite_mask = torch.isfinite(cpp_bef)\n    if finite_mask.any():\n        t_finite = cpp_bef[finite_mask]\n        t_max = t_finite.max().item()\n        t_min = t_finite.min().item()\n        print(f\" max   : {t_max:.6e}\")\n        print(f\" min   : {t_min:.6e}\")\n    else:\n        print(\" max/min: all elements are NaN / Inf\")\n\n    for nan_name, t in [(f\"{name}\", cpp_bef)]:\n        nan_cnt = torch.isnan(t).sum().item()\n        inf_cnt = torch.isinf(t).sum().item()\n        if nan_cnt or inf_cnt:\n            print(f\"{name} contains NaN={nan_cnt}, Inf={inf_cnt}\")\n        else:\n            print(\"No NaN or Inf found\")    \n\ndef get_tensor(name, shape) -> 
torch.Tensor:\n    stub1 = DUMP_DIR / f\"{name}\"\n    if stub1.with_suffix(\".bf16\").exists():\n        cpp_bef = load_bf16(str(stub1), shape)\n    elif stub1.with_suffix(\".f16\").exists():\n        cpp_bef = load_f16(str(stub1), shape)\n    elif stub1.with_suffix(\".f32\").exists():\n        cpp_bef = load_f32(str(stub1), shape)\n    elif stub1.with_suffix(\".int8\").exists():\n        return load_int8(str(stub1), shape)\n    else:\n        print(\"dump 缺失/未知类型\"); return\n\n    return cpp_bef\n\ndef check_py_cpp(name1, name2, shape):\n    print(f\"compare {name1} with {name2}, at shape{shape}\")\n    stub1 = DUMP_DIR / f\"{name1}\"\n    py_bef = torch.load(f\"{stub1}\")\n    if not isinstance(py_bef, torch.Tensor):\n        print(f\"⚠️ {name1} 不是 Tensor，而是 {type(py_bef)}\")\n        return\n    stub2 = DUMP_DIR / f\"{name2}\"\n    if stub2.with_suffix(\".bf16\").exists():\n        cpp_bef = load_bf16(str(stub2), shape)\n    elif stub2.with_suffix(\".f16\").exists():\n        cpp_bef = load_f16(str(stub2), shape)\n    elif stub2.with_suffix(\".f32\").exists():\n        cpp_bef = load_f32(str(stub2), shape)\n    elif stub2.with_suffix(\".int8\").exists():\n        return load_int8(str(stub2), shape)\n    else:\n        print(f\"dump 缺失/未知类型: {stub2}\"); return\n        \n    for t in [py_bef]:\n        nan_cnt = torch.isnan(t).sum().item()\n        inf_cnt = torch.isinf(t).sum().item()\n        if nan_cnt or inf_cnt:\n            print(f\"{name1} 含 NaN={nan_cnt}、Inf={inf_cnt}\")\n        else:\n            print(\"NO NaN or Inf exist\")\n    for t in [cpp_bef]:\n        nan_cnt = torch.isnan(t).sum().item()\n        inf_cnt = torch.isinf(t).sum().item()\n        if nan_cnt or inf_cnt:\n            print(f\"{name2} 含 NaN={nan_cnt}、Inf={inf_cnt}\")\n        else:\n            print(\"NO NaN or Inf exist\")\n            \n    if py_bef.shape != cpp_bef.shape:\n        print(f\"shape 不一致: py_bef {py_bef.shape}, cpp_bef {cpp_bef.shape}\")\n    else:\n        # 
Compute the absolute difference, relative difference, and maximum difference\n        eps = 1e-6  # avoid division by zero\n        denominator = torch.abs(py_bef) + eps\n        rel_diff = torch.abs(py_bef - cpp_bef) / denominator\n\n        # Find the entries whose relative error exceeds 2%\n        mask = rel_diff > 0.02\n        num_large_diff = mask.sum().item()\n        total = rel_diff.numel()\n\n        if num_large_diff == 0:\n            print(\"✅ All elements are within 2% relative error\")\n            flat_rel_diff = rel_diff.view(-1)\n            max_idx = torch.argmax(flat_rel_diff)\n            max_val = flat_rel_diff[max_idx].item()\n\n            # Convert the flat index back to a multi-dimensional index\n            max_pos = tuple(torch.unravel_index(max_idx, py_bef.shape))\n\n            # Fetch the original values\n            py_val = py_bef[max_pos].item()\n            cpp_val = cpp_bef[max_pos].item()\n\n            print(f\"    max relative error = {max_val:.2%}\")\n            print(f\"    position of max relative error: {max_pos}, py  = {py_val:.6f}, cpp = {cpp_val:.6f}\")\n        else:\n            print(f\"❗ Number of elements with relative error > 2%: {num_large_diff} / {total}\")\n            print(f\"{name1}: {py_bef}\")\n            print(f\"{name2}: {cpp_bef}\")\n\n# Gather the items to check\ndef manual_check(experts_ids):\n    expert_token_counts = torch.zeros(expert_num, dtype=torch.int64)\n    T, k = experts_ids.shape\n    for i in range(T):\n        for j in range(k):\n            expert_token_counts[experts_ids[i, j]] += 1\n    for experts_idx in range(expert_num):\n        # input1 = get_tensor(f\"cpp_layer0_E_End{experts_idx}_down_t_ba_\", (expert_token_counts[experts_idx], hidden_size))\n        # # down_ba_new = get_tensor(f\"cpp_layer0_E_End{experts_idx}_down_ba_new_\", (expert_token_counts[experts_idx], intermediate_size))\n        # weight1 = get_tensor(f\"cpp_layer0_E_End{experts_idx}_down_t_bb_\", (hidden_size, intermediate_size))\n        # output1 = torch.matmul(input1, weight1)\n        # print(f\"input1:{input1}, shape:{input1.shape}\")\n        # # print(f\"down_ba_new:{down_ba_new}, shape:{down_ba_new.shape}\")\n        # print(f\"weight1:{weight1}, shape:{weight1.shape}\")\n        # 
print(f\"output1:{output1}, shape:{output1.shape}\")\n\n        # shape=(expert_token_counts[experts_idx], intermediate_size)\n        # stub_bc = DUMP_DIR / f\"cpp_layer0_E_End{experts_idx}_down_t_bc_\"\n        # if stub_bc.with_suffix(\".bf16\").exists():\n        #     output1_5 = load_bf16(str(stub_bc), shape)\n        # elif stub_bc.with_suffix(\".f16\").exists():\n        #     output1_5 = load_f16(str(stub_bc), shape)\n        # elif stub_bc.with_suffix(\".f32\").exists():\n        #     output1_5 = load_f32(str(stub_bc), shape)\n        # elif stub_bc.with_suffix(\".int8\").exists():\n        #     return load_int8(str(stub_bc), shape)\n        # else:\n        #     print(f\"dump 缺失/未知类型: {stub_bc}\"); return\n        # print(f\"output1_5:{output1_5}, shape:{output1_5.shape}\")\n        \n        # torch.set_printoptions(profile=\"full\")\n        \n        down_ba_ori = get_tensor(f\"cpp_layer0_E_End{experts_idx}_down_ba_ori_\", (expert_token_counts[experts_idx], intermediate_size))\n\n        # with open(f\"/home/lpl/kt-sft/debug/cpp_{experts_idx}_down_ba_ori_view.txt\", \"w\") as f:\n        #     f.write(str(down_ba_ori))   \n        \n    \n        down_output_grad = get_tensor(f\"cpp_layer0_E_End{experts_idx}_down_output_grad_\", (expert_token_counts[experts_idx], hidden_size))\n\n        # with open(f\"/home/lpl/kt-sft/debug/cpp_{experts_idx}_down_t_ba_ori_view.txt\", \"w\") as f:\n        #     f.write(str(down_output_grad))\n            \n        \n        # input2 = torch.load(f\"debug/py_layer0_E_End{experts_idx}_down_output_grad_.pt\")\n        # weight2 = torch.load(f\"debug/py_layer0_E_End{experts_idx}_down_weight_.pt\")\n        # output2 = torch.load(f\"debug/py_layer0_E_End{experts_idx}_down_input_grad_.pt\")\n        # print(f\"input2:{input2}, shape:{input2.shape}\")\n        # print(f\"weight2:{weight2}, shape:{weight2.shape}\")\n        # print(f\"output2:{output2}, shape:{output2.shape}\")\n        \n        # down_t_ba_new = 
load_bin(f'debug/{experts_idx}_down_ba_t_debug3.bin', expert_token_counts[experts_idx], hidden_size)\n        \n        # print(f\"input3: {down_t_ba_new}, shape: {down_t_ba_new.shape}\")\n        \n        py_down_t_ba = torch.load(f\"debug/py_layer0_E_End{experts_idx}_down_output_grad_.pt\")\n        py_down_ba = torch.load(f\"debug/py_layer0_E_End{experts_idx}_gate_output_.pt\")\n\n        # with open(f\"/home/lpl/kt-sft/debug/py_{experts_idx}_down_t_ba_ori_view.txt\", \"w\") as f:\n        #     f.write(str(py_down_t_ba))\n        \n        # with open(f\"/home/lpl/kt-sft/debug/py_{experts_idx}_down_ba_ori_view.txt\", \"w\") as f:\n        #     f.write(str(py_down_ba))\n            \n        print(f\"cpp_{experts_idx}_down_ba_ori_:{down_ba_ori}\") \n        print(f\"py_{experts_idx}_down_ba_ori_view: {py_down_ba}\")\n        print(f\"cpp_{experts_idx}_down_t_ba_ori_view:{down_output_grad}\") \n        print(f\"py_{experts_idx}_down_t_ba_ori_view: {py_down_t_ba}\")\n\n        # torch.set_printoptions(profile=\"default\")\n        \n        \nif __name__ == \"__main__\":\n    torch.manual_seed(42)\n    test_amx_moe_two_round()"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/examples/test_sft_moe.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : chenht2022\nDate         : 2024-07-25 10:32:05\nVersion      : 1.0.0\nLastEditors  : chenht2022 \nLastEditTime : 2024-08-06 10:38:05\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport os, sys\nimport time\nsys.path.append(os.path.dirname(__file__) + '/../build')\nimport cpuinfer_ext\nimport torch\n\nexpert_num = 10\nhidden_size = 5120\nintermediate_size = 1536\nstride = 32\ngroup_min_len = 10\ngroup_max_len = 1024\ngate_type = 1 # ggml_type::GGML_TYPE_F16\nup_type = 1 # ggml_type::GGML_TYPE_F16\ndown_type = 1 # ggml_type::GGML_TYPE_F16\nhidden_type = 1 # ggml_type::GGML_TYPE_F16\nn_routed_experts = 2\nqlen = 30\nlayer_num = 10\nCPUInfer = cpuinfer_ext.CPUInfer(48)\nvalidation_iter = 100\n\ndtype = torch.float16\ngradtype = torch.bfloat16\n\ndef act_fn(x):\n    return x / (1.0 + torch.exp(-x))\n\n# 定义SiLU激活函数的可微版本（带梯度）\nclass SiLU(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, input):\n        ctx.save_for_backward(input)\n        return input / (1.0 + torch.exp(-input))\n    \n    @staticmethod\n    def backward(ctx, grad_output):\n        input, = ctx.saved_tensors\n        sigmoid = 1.0 / (1.0 + torch.exp(-input))\n        return grad_output * (sigmoid + input * sigmoid * (1 - sigmoid))\n\nsilu = SiLU.apply\n\ndef mlp_torch(input, gate_proj, up_proj, down_proj, requires_grad=False):\n    gate_buf = torch.mm(input, gate_proj.t())\n    up_buf = torch.mm(input, up_proj.t())\n    \n    # 使用可微的SiLU或者原来的函数，取决于是否需要梯度\n    if requires_grad:\n        intermediate = silu(gate_buf) * up_buf\n    else:\n        intermediate = act_fn(gate_buf) * up_buf\n    \n    ret = torch.mm(intermediate, down_proj.t())\n    return ret\n\ndef moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj, requires_grad=False):\n    cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))\n    cnts.scatter_(1, expert_ids, 1)\n    tokens_per_expert = 
cnts.sum(dim=0)\n    idxs = expert_ids.view(-1).argsort()\n    sorted_tokens = input[idxs // expert_ids.shape[1]]\n\n    outputs = []\n    start_idx = 0\n    for i, num_tokens in enumerate(tokens_per_expert):\n        end_idx = start_idx + num_tokens\n        if num_tokens == 0:\n            continue\n        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n        expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i], requires_grad)\n        outputs.append(expert_out)\n        start_idx = end_idx\n\n    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n    new_x = torch.empty_like(outs)\n    new_x[idxs] = outs\n    t_output = (\n        new_x.view(*expert_ids.shape, -1)\n        .type(weights.dtype)\n        .mul_(weights.unsqueeze(dim=-1))\n        .sum(dim=1)\n        .type(new_x.dtype)\n    )\n    return t_output\n\n# 前向传播验证\ndef test_forward():\n    with torch.inference_mode(mode=True):\n        moes = []\n        gate_projs = []\n        up_projs = []\n        down_projs = []\n        for _ in range(layer_num):\n            gate_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=dtype, device = \"cuda\").to(\"cpu\").contiguous()\n            up_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=dtype, device = \"cuda\").to(\"cpu\").contiguous()\n            down_proj = torch.randn((expert_num, hidden_size, intermediate_size), dtype=dtype, device = \"cuda\").to(\"cpu\").contiguous()\n            config = cpuinfer_ext.sft_moe.SFT_MOEConfig(expert_num, n_routed_experts, hidden_size, intermediate_size, stride, group_min_len, group_max_len, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(), gate_type, up_type, down_type, hidden_type, 0)\n            moe = cpuinfer_ext.sft_moe.SFT_MOE(config)\n            gate_projs.append(gate_proj)\n            up_projs.append(up_proj)\n            down_projs.append(down_proj)\n            
moes.append(moe)\n\n        # validation\n        for i in range(validation_iter):\n            expert_ids = torch.stack([torch.randperm(expert_num)[:n_routed_experts] for _ in range(qlen)]).contiguous()\n            weights = torch.rand((qlen, n_routed_experts), dtype=torch.float32).contiguous()\n            input = torch.randn((qlen, hidden_size), dtype=dtype).contiguous()\n            output = torch.empty((qlen, hidden_size), dtype=dtype).contiguous()\n            input = input / 100\n            \n            moe = moes[i % layer_num]\n            CPUInfer.submit(\n                moe.forward( \n                    qlen,\n                    n_routed_experts, \n                    expert_ids.data_ptr(), \n                    weights.data_ptr(), \n                    input.data_ptr(), \n                    output.data_ptr()\n                )\n            )\n            CPUInfer.sync()\n            # print('cpuinfer output', output)\n\n            gate_proj = gate_projs[i%layer_num]\n            up_proj = up_projs[i%layer_num]\n            down_proj = down_projs[i%layer_num]\n            t_output = moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj)\n            # print('torch output', t_output)\n\n            diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))\n            print('diff = ', diff)\n            assert(diff < 0.001)\n\n# 反向传播验证\ndef test_backward():\n    # 先测试backward是否能正常调用\n    print(\"\\n===== Testing Backward Pass =====\")\n    # 创建一个单层MOE用于测试\n    gate_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=dtype, requires_grad=True).contiguous()\n    up_proj = torch.randn((expert_num, intermediate_size, hidden_size), dtype=dtype, requires_grad=True).contiguous()\n    down_proj = torch.randn((expert_num, hidden_size, intermediate_size), dtype=dtype, requires_grad=True).contiguous()\n    # 创建MOE实例\n    config = cpuinfer_ext.sft_moe.SFT_MOEConfig(expert_num, n_routed_experts, 
hidden_size, intermediate_size, \n                                       stride, group_min_len, group_max_len, \n                                       gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(), \n                                       gate_type, up_type, down_type, hidden_type)  # Use float16 (ggml type 1 = GGML_TYPE_F16)\n    moe = cpuinfer_ext.sft_moe.SFT_MOE(config)\n\n    # Create the input data\n    expert_ids = torch.stack([torch.randperm(expert_num)[:n_routed_experts] for _ in range(qlen)]).contiguous()\n    weights = torch.rand((qlen, n_routed_experts), dtype=torch.float32).contiguous()\n    \n    # Use the same input for both the torch and C++ operators\n    input = torch.randn((qlen, hidden_size), dtype=dtype, requires_grad=True).contiguous()\n    input = (input / 100).detach().requires_grad_(True)\n    input_cpp = input.clone().detach().requires_grad_(True).contiguous()\n\n    # Compute the PyTorch reference output\n    t_output = moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj, requires_grad=True)\n    # Make sure the non-leaf tensor retains its gradient\n    t_output.retain_grad()\n    \n    # Allocate the C++ operator output\n    output_cpp = torch.empty((qlen, hidden_size), dtype=dtype).contiguous()\n\n    # Forward pass\n    forward_start_time = time.time()\n    CPUInfer.submit(\n        moe.forward(\n            qlen,\n            n_routed_experts,\n            expert_ids.data_ptr(),\n            weights.data_ptr(),\n            input_cpp.data_ptr(),\n            output_cpp.data_ptr()\n        )\n    )\n    CPUInfer.sync()\n    forward_end_time = time.time()\n    print(f\"C++ forward time: {forward_end_time - forward_start_time:.4f} s\")\n    \n    FLOPs_fwd  = 6 * qlen * n_routed_experts * hidden_size * intermediate_size\n    KT_TFLOPS_fwd = FLOPs_fwd / (forward_end_time - forward_start_time) / 1e12\n    \n    # Verify the forward result\n    forward_diff = torch.mean(torch.abs(output_cpp - t_output)) / torch.mean(torch.abs(t_output))\n    print(f\"Forward diff: {forward_diff.item()}\")\n    assert forward_diff < 0.001, f\"Forward diff too large: {forward_diff.item()}\"\n    
print(\"✅ Forward test passed!\")\n    \n    grad_input_cpp = torch.empty_like(input_cpp, dtype=gradtype).contiguous()\n    grad_output = torch.randn_like(t_output, dtype=gradtype).contiguous()\n    grad_output_cpp = grad_output.clone()\n    \n    print(\"-- pytorch backward --\")\n    # PyTorch反向传播性能测试\n    pytorch_start_time = time.time()\n\n    t_output.backward(grad_output, retain_graph=True)\n\n    pytorch_end_time = time.time()\n    pytorch_time = (pytorch_end_time - pytorch_start_time)\n    \n    print(\"-- c++ backward --\")\n    # C++反向传播性能测试\n    CPUInfer.submit(\n        moe.backward(\n            qlen,\n            n_routed_experts,\n            expert_ids.data_ptr(),\n            weights.data_ptr(),\n            input_cpp.data_ptr(),\n            grad_output_cpp.data_ptr(),\n            grad_input_cpp.data_ptr()\n        )\n    )\n    CPUInfer.sync()\n\n    cpp_start_time = time.time()\n    CPUInfer.submit(\n        moe.backward(\n            qlen,\n            n_routed_experts,\n            expert_ids.data_ptr(),\n            weights.data_ptr(),\n            input_cpp.data_ptr(),\n            grad_output_cpp.data_ptr(),\n            grad_input_cpp.data_ptr()\n        )\n    )\n    CPUInfer.sync()\n\n    cpp_end_time = time.time()\n    cpp_time = (cpp_end_time - cpp_start_time)\n    print(f\"PyTorch backward 耗时: {pytorch_time:.4f} 秒\")\n    print(f\"C++ backward 耗时: {cpp_time:.4f} 秒\")\n    print(f\"性能比较: PyTorch/C++ = {pytorch_time/cpp_time:.2f}x\")\n    \n\n    print(f\"qlen:{qlen}, n_exp:{n_routed_experts}, hidden:{hidden_size}, inter:{intermediate_size}\")\n    FLOPs_bwd  = 18 * qlen * n_routed_experts * hidden_size * intermediate_size\n    torch_TFLOPS_bwd = FLOPs_bwd / pytorch_time / 1e12\n    KT_TFLOPS_bwd = FLOPs_bwd / cpp_time / 1e12\n    \n    print(f\"PyTorch backward TFLOPS: {torch_TFLOPS_bwd}\")\n    print(f\"KT forward TFLOPS: {KT_TFLOPS_fwd}\")\n    print(f\"KT backward TFLOPS: {KT_TFLOPS_bwd}\")\n\n        # ================== TFLOPS 统计 
==================\n    total_flops_fwd = 6 * qlen * n_routed_experts * hidden_size * intermediate_size\n    total_flops_bwd = 18 * qlen * n_routed_experts * hidden_size * intermediate_size\n\n    tflops_fwd_cpp = total_flops_fwd / (forward_end_time - forward_start_time) / 1e12\n    tflops_bwd_cpp = total_flops_bwd / cpp_time / 1e12\n    tflops_bwd_torch = total_flops_bwd / pytorch_time / 1e12\n\n    print(f\"\\n=== TFLOPS ===\")\n    print(f\"CPUInfer forward  : {tflops_fwd_cpp:.2f} TFLOPS\")\n    print(f\"CPUInfer backward : {tflops_bwd_cpp:.2f} TFLOPS\")\n    print(f\"Torch   backward : {tflops_bwd_torch:.2f} TFLOPS\")\n\n\n    # 验证梯度结果\n    backward_diff = torch.mean(torch.abs(grad_input_cpp - input.grad)) / torch.mean(torch.abs(input.grad))\n    print(f\"Backward diff: {backward_diff.item()}\")\n    assert backward_diff < 0.005, f\"Backward diff too large: {backward_diff.item()}\" # FIXME: 0.005 是不是太大了？ \n    print(\"✅ Backward pass test passed!\")\n\ndef test_backward_2round_with_tflops():\n    \"\"\"\n    跑两轮 forward+backward，对比 PyTorch 与 C++ 实现的正确性和性能，\n    并输出每轮及总体的 TFLOPS 与耗时信息。\n    依赖：已在全局定义 expert_num、n_routed_experts、hidden_size、intermediate_size、\n          stride、group_min_len、group_max_len、gate_type、up_type、down_type、\n          hidden_type、qlen、dtype、gradtype 以及 moe_torch、cpuinfer_ext、CPUInfer。\n    \"\"\"\n    # ------------- 初始化可训练参数（保持与单轮测试一致）-------------\n    gate_proj = torch.randn((expert_num, intermediate_size, hidden_size),\n                            dtype=dtype, requires_grad=True).contiguous()\n    up_proj   = torch.randn((expert_num, intermediate_size, hidden_size),\n                            dtype=dtype, requires_grad=True).contiguous()\n    down_proj = torch.randn((expert_num, hidden_size, intermediate_size),\n                            dtype=dtype, requires_grad=True).contiguous()\n\n    config = cpuinfer_ext.sft_moe.SFT_MOEConfig(\n        expert_num, n_routed_experts, hidden_size, intermediate_size,\n        stride, 
group_min_len, group_max_len,\n        gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(),\n        gate_type, up_type, down_type, hidden_type\n    )\n    moe = cpuinfer_ext.sft_moe.SFT_MOE(config)\n\n    # ----------- 预先计算 FLOPs（与 KT 公式保持一致）-----------\n    FLOPs_fwd = 6  * qlen * n_routed_experts * hidden_size * intermediate_size\n    FLOPs_bwd = 18 * qlen * n_routed_experts * hidden_size * intermediate_size\n\n    # ----------- 统计两轮测试的信息 -----------\n    summary = []   # 每轮: dict(round, fwd_time, bwd_torch_time, bwd_cpp_time, diffs, TFLOPS...)\n\n    for round_idx in range(2):\n        print(f\"\\n================ Round {round_idx+1}/2 ================\")\n\n        # ---------- 随机构造输入 ----------\n        expert_ids = torch.stack(\n            [torch.randperm(expert_num)[:n_routed_experts] for _ in range(qlen)]\n        ).contiguous()\n        weights = torch.rand((qlen, n_routed_experts), dtype=torch.float32).contiguous()\n\n        input_pt  = (torch.randn((qlen, hidden_size), dtype=dtype) / 100)\\\n                    .detach().requires_grad_(True).contiguous()\n        input_cpp = input_pt.clone().detach().requires_grad_(True).contiguous()\n\n        # ================= 前向传播 =================\n        # Torch 参考实现\n        t_output = moe_torch(\n            input_pt, expert_ids, weights,\n            gate_proj, up_proj, down_proj, requires_grad=True\n        )\n        t_output.retain_grad()\n\n        # C++ 实现\n        output_cpp = torch.empty((qlen, hidden_size), dtype=dtype).contiguous()\n        fwd_start = time.time()\n        CPUInfer.submit(\n            moe.forward(\n                qlen, n_routed_experts,\n                expert_ids.data_ptr(), weights.data_ptr(),\n                input_cpp.data_ptr(), output_cpp.data_ptr()\n            )\n        )\n        CPUInfer.sync()\n        fwd_end = time.time()\n        fwd_time = fwd_end - fwd_start\n        print(f\"C++ forward 耗时: {fwd_time:.4f} s\")\n\n        # 结果比对\n        fwd_diff = 
torch.mean(torch.abs(output_cpp - t_output)) \\\n                 / torch.mean(torch.abs(t_output))\n        print(f\"Forward diff: {fwd_diff.item():.4e}\")\n\n        # ================= 反向传播 =================\n        grad_output      = torch.randn_like(t_output, dtype=gradtype).contiguous()\n        grad_output_cpp  = grad_output.clone().contiguous()\n        grad_input_cpp   = torch.zeros_like(input_cpp, dtype=gradtype).contiguous()\n\n        # -- PyTorch backward --\n        for p in (gate_proj, up_proj, down_proj, input_pt):\n            if p.grad is not None:\n                p.grad.zero_()\n        pyt_start = time.time()\n        t_output.backward(grad_output, retain_graph=True)\n        pyt_end   = time.time()\n        pyt_time  = pyt_end - pyt_start\n        print(f\"PyTorch backward 耗时: {pyt_time:.4f} s\")\n\n        # # -- C++ backward（保持两次调用顺序） --\n        # CPUInfer.submit(\n        #     moe.backward(\n        #         round_idx,\n        #         qlen, n_routed_experts,\n        #         expert_ids.data_ptr(), weights.data_ptr(),\n        #         input_cpp.data_ptr(),\n        #         grad_output_cpp.data_ptr(),\n        #         grad_input_cpp.data_ptr()\n        #     )\n        # )\n        # CPUInfer.sync()\n\n        cpp_start = time.time()\n        CPUInfer.submit(\n            moe.backward(\n                round_idx,\n                qlen, n_routed_experts,\n                expert_ids.data_ptr(), weights.data_ptr(),\n                input_cpp.data_ptr(),\n                grad_output_cpp.data_ptr(),\n                grad_input_cpp.data_ptr()\n            )\n        )\n        CPUInfer.sync()\n        cpp_end = time.time()\n        cpp_time = cpp_end - cpp_start\n        print(f\"C++ backward(第2次) 耗时: {cpp_time:.4f} s\")\n\n        # 反向结果比对 - 修复类型不匹配问题\n        # grad_input_cpp是BF16，input_pt.grad是FP16，需要转换为相同类型\n        if input_pt.grad is None:\n            print(\"错误：input_pt.grad为None，PyTorch反向传播可能失败\")\n            bwd_diff = 
float('nan')\n        else:\n            # 添加详细调试信息\n            print(f\"[DEBUG] PyTorch grad shape: {input_pt.grad.shape}, dtype: {input_pt.grad.dtype}\")\n            print(f\"[DEBUG] C++ grad shape: {grad_input_cpp.shape}, dtype: {grad_input_cpp.dtype}\")\n            \n            # 检查PyTorch梯度是否包含NaN\n            pt_grad_has_nan = torch.isnan(input_pt.grad).any()\n            print(f\"[DEBUG] PyTorch grad contains NaN: {pt_grad_has_nan}\")\n            if pt_grad_has_nan:\n                print(f\"[DEBUG] PyTorch grad NaN count: {torch.isnan(input_pt.grad).sum().item()}\")\n            \n            # 检查C++梯度是否包含NaN  \n            cpp_grad_has_nan = torch.isnan(grad_input_cpp).any()\n            print(f\"[DEBUG] C++ grad contains NaN: {cpp_grad_has_nan}\")\n            if cpp_grad_has_nan:\n                print(f\"[DEBUG] C++ grad NaN count: {torch.isnan(grad_input_cpp).sum().item()}\")\n            \n            # 转换为FP32进行比较\n            grad_input_cpp_fp32 = grad_input_cpp.to(torch.float32)\n            input_pt_grad_fp32 = input_pt.grad.to(torch.float32)\n            \n            # 再次检查转换后是否有NaN\n            cpp_fp32_has_nan = torch.isnan(grad_input_cpp_fp32).any()\n            pt_fp32_has_nan = torch.isnan(input_pt_grad_fp32).any()\n            print(f\"[DEBUG] After FP32 conversion - PyTorch NaN: {pt_fp32_has_nan}, C++ NaN: {cpp_fp32_has_nan}\")\n            \n            if pt_fp32_has_nan or cpp_fp32_has_nan:\n                bwd_diff = float('nan')\n                print(f\"[DEBUG] 检测到NaN，跳过diff计算\")\n            else:\n                diff_tensor = torch.abs(grad_input_cpp_fp32 - input_pt_grad_fp32)\n                denominator = torch.mean(torch.abs(input_pt_grad_fp32))\n                \n                print(f\"[DEBUG] Diff stats - max: {diff_tensor.max().item():.6f}, mean: {diff_tensor.mean().item():.6f}\")\n                print(f\"[DEBUG] Denominator: {denominator.item():.6f}\")\n                \n                bwd_diff = 
torch.mean(diff_tensor) / denominator\n        if isinstance(bwd_diff, torch.Tensor):\n            print(f\"Backward diff: {bwd_diff.item():.4e}\")\n        elif isinstance(bwd_diff, float):\n            print(f\"Backward diff: {bwd_diff:.4e}\")\n        else:\n            print(f\"Backward diff: {bwd_diff}\")\n\n        # ================= TFLOPS 统计 =================\n        tflops_fwd_cpp   = FLOPs_fwd / fwd_time / 1e12\n        tflops_bwd_cpp   = FLOPs_bwd / cpp_time / 1e12\n        tflops_bwd_torch = FLOPs_bwd / pyt_time / 1e12\n\n        print(f\"\\n--- Round {round_idx+1} TFLOPS ---\")\n        print(f\"CPUInfer forward  : {tflops_fwd_cpp:.2f} TFLOPS\")\n        print(f\"CPUInfer backward : {tflops_bwd_cpp:.2f} TFLOPS\")\n        print(f\"Torch   backward : {tflops_bwd_torch:.2f} TFLOPS\")\n\n        # 保存本轮结果\n        summary.append(dict(\n            round        = round_idx+1,\n            fwd_time     = fwd_time,\n            pyt_bwd_time = pyt_time,\n            cpp_bwd_time = cpp_time,\n            fwd_diff     = fwd_diff.item(),\n            bwd_diff     = bwd_diff.item() if isinstance(bwd_diff, torch.Tensor) else bwd_diff,\n            tflops_fwd_cpp   = tflops_fwd_cpp,\n            tflops_bwd_cpp   = tflops_bwd_cpp,\n            tflops_bwd_torch = tflops_bwd_torch,\n        ))\n\n    # ================= 汇总输出 =================\n    print(\"\\n================= Two-Round Summary =================\")\n    for item in summary:\n        print(f\"Round {item['round']}: \"\n              f\"fwd {item['fwd_time']:.4f}s | \"\n              f\"bwd_torch {item['pyt_bwd_time']:.4f}s | \"\n              f\"bwd_cpp {item['cpp_bwd_time']:.4f}s | \"\n              f\"diff(fwd/bwd) {item['fwd_diff']:.2e}/{item['bwd_diff']:.2e} | \"\n              f\"TFLOPS(cpp fwd/bwd) {item['tflops_fwd_cpp']:.2f}/{item['tflops_bwd_cpp']:.2f}\")\ndef test_backward_10round_5layer():\n    \"\"\"\n    创建 5 个独立 SFT-MOE 层，连续跑 10 轮 forward+backward。\n    第 n 轮使用第 n % 5 层，逐轮验证 C++ 与 PyTorch 
的数值一致性，\n    同时统计 TFLOPS / 耗时。全程不修改任何全局变量。\n    \"\"\"\n    num_layers   = 5\n    num_rounds   = 10\n\n    # ---------- 1. 为 5 层分别初始化权重 ----------\n    gate_projs, up_projs, down_projs, moes = [], [], [], []\n    for _ in range(num_layers):\n        gp = torch.randn((expert_num, intermediate_size, hidden_size),\n                         dtype=dtype, requires_grad=True).contiguous()\n        up = torch.randn_like(gp, requires_grad=True)          # 同形状\n        dp = torch.randn((expert_num, hidden_size, intermediate_size),\n                         dtype=dtype, requires_grad=True).contiguous()\n\n        cfg = cpuinfer_ext.sft_moe.SFT_MOEConfig(\n            expert_num, n_routed_experts,\n            hidden_size, intermediate_size,\n            stride, group_min_len, group_max_len,\n            gp.data_ptr(), up.data_ptr(), dp.data_ptr(),\n            gate_type, up_type, down_type, hidden_type\n        )\n        moes.append(cpuinfer_ext.sft_moe.SFT_MOE(cfg))\n        gate_projs.append(gp);  up_projs.append(up);  down_projs.append(dp)\n\n    # ---------- 2. FLOPs 常数 ----------\n    FLOPs_fwd = 6  * qlen * n_routed_experts * hidden_size * intermediate_size\n    FLOPs_bwd = 18 * qlen * n_routed_experts * hidden_size * intermediate_size\n\n    summary = []\n\n    for r in range(num_rounds):\n        layer_id = r % num_layers\n        moe      = moes[layer_id]\n        gp, up, dp = gate_projs[layer_id], up_projs[layer_id], down_projs[layer_id]\n\n        print(f\"\\n================ Round {r+1}/{num_rounds}  \"\n              f\"(use layer {layer_id}) ================\")\n\n        # ---------- 3. 
Build inputs ----------\n        expert_ids = torch.stack(\n            [torch.randperm(expert_num)[:n_routed_experts] for _ in range(qlen)]\n        ).contiguous()\n        weights = torch.rand((qlen, n_routed_experts), dtype=torch.float32).contiguous()\n\n        inp_pt  = (torch.randn((qlen, hidden_size), dtype=dtype) / 100\n                  ).detach().requires_grad_(True).contiguous()\n        inp_cpp = inp_pt.clone().detach().requires_grad_(True).contiguous()\n\n        # ================= Forward =================\n        t_out = moe_torch(inp_pt, expert_ids, weights, gp, up, dp, requires_grad=True)\n        t_out.retain_grad()\n\n        out_cpp = torch.empty_like(t_out).contiguous()\n        t0 = time.time()\n        CPUInfer.submit(\n            moe.forward(qlen, n_routed_experts,\n                        expert_ids.data_ptr(), weights.data_ptr(),\n                        inp_cpp.data_ptr(), out_cpp.data_ptr())\n        )\n        CPUInfer.sync()\n        fwd_time = time.time() - t0\n\n        fwd_diff = (out_cpp - t_out).abs().mean() / t_out.abs().mean()\n        print(f\"Forward diff = {fwd_diff.item():.3e} | \"\n              f\"C++ fwd {fwd_time:.3f}s\")\n\n        # ================= Backward =================\n        grad_out     = torch.randn_like(t_out, dtype=gradtype).contiguous()\n        grad_out_cpp = grad_out.clone().contiguous()\n        grad_inp_cpp = torch.empty_like(inp_cpp, dtype=gradtype).contiguous()\n\n        # PyTorch backward\n        for p in (gp, up, dp, inp_pt):\n            if p.grad is not None:\n                p.grad.zero_()\n        t1 = time.time()\n        t_out.backward(grad_out, retain_graph=True)\n        pyt_time = time.time() - t1\n\n        # C++ backward\n        t2 = time.time()\n        CPUInfer.submit(\n            moe.backward(r, qlen, n_routed_experts,\n                         expert_ids.data_ptr(), weights.data_ptr(),\n                         inp_cpp.data_ptr(),\n                         grad_out_cpp.data_ptr(), 
grad_inp_cpp.data_ptr())\n        )\n        CPUInfer.sync()\n        cpp_time = time.time() - t2\n\n        bwd_diff = (grad_inp_cpp - inp_pt.grad).abs().mean() / inp_pt.grad.abs().mean()\n        print(f\"Backward diff = {bwd_diff.item():.3e} | \"\n              f\"PyTorch bwd {pyt_time:.3f}s | C++ bwd {cpp_time:.3f}s\")\n\n        # ================= TFLOPS =================\n        tflops_fwd_cpp   = FLOPs_fwd / fwd_time / 1e12\n        tflops_bwd_cpp   = FLOPs_bwd / cpp_time / 1e12\n        tflops_bwd_torch = FLOPs_bwd / pyt_time / 1e12\n\n        summary.append(dict(\n            rd=r+1, layer=layer_id,\n            fwd_time=fwd_time, pyt_time=pyt_time, cpp_time=cpp_time,\n            fwd_diff=fwd_diff.item(), bwd_diff=bwd_diff.item(),\n            tf_fwd=tflops_fwd_cpp, tf_bwd_cpp=tflops_bwd_cpp,\n            tf_bwd_torch=tflops_bwd_torch\n        ))\n\n    # ---------- 4. Summary ----------\n    print(\"\\n================ 10-Round Summary ================\")\n    for s in summary:\n        print(f\"R{s['rd']:02d}(L{s['layer']}) | \"\n              f\"Δf {s['fwd_diff']:.2e} / {s['bwd_diff']:.2e} | \"\n              f\"t fwd {s['fwd_time']:.3f}s  \"\n              f\"bwd Torch {s['pyt_time']:.3f}s / C++ {s['cpp_time']:.3f}s | \"\n              f\"TFLOPS C++ f/b {s['tf_fwd']:.2f}/{s['tf_bwd_cpp']:.2f}  \"\n              f\"Torch bwd {s['tf_bwd_torch']:.2f}\")\n\n    print(\"\\n10-round, 5-layer test finished; review the diffs above (no tolerance is asserted).\")\n\ndef test_backward_one_vs_many_comparison():\n    \"\"\"\n    Compare the results of backward_one and backward_many directly.\n    \"\"\"\n    print(\"\\n=== Backward One vs Many Comparison ===\")\n    \n    # Initialize weights (fixed random seed for reproducibility)\n    torch.manual_seed(42)\n    gate_proj = torch.randn((expert_num, intermediate_size, hidden_size),\n                            dtype=dtype, requires_grad=True).contiguous()\n    up_proj   = torch.randn((expert_num, intermediate_size, hidden_size),\n                            dtype=dtype, requires_grad=True).contiguous()\n    down_proj = 
torch.randn((expert_num, hidden_size, intermediate_size),\n                            dtype=dtype, requires_grad=True).contiguous()\n\n    # Create two configs: one forces backward_one, the other uses backward_many\n    config_one = cpuinfer_ext.sft_moe.SFT_MOEConfig(\n        expert_num, n_routed_experts, hidden_size, intermediate_size,\n        stride, 10000000, group_max_len,  # a huge group_min_len forces backward_one\n        gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(),\n        gate_type, up_type, down_type, hidden_type\n    )\n    config_many = cpuinfer_ext.sft_moe.SFT_MOEConfig(\n        expert_num, n_routed_experts, hidden_size, intermediate_size,\n        stride, group_min_len, group_max_len,  # the normal config uses backward_many\n        gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(),\n        gate_type, up_type, down_type, hidden_type\n    )\n    moe_one = cpuinfer_ext.sft_moe.SFT_MOE(config_one)\n    moe_many = cpuinfer_ext.sft_moe.SFT_MOE(config_many)\n    \n    # Fix the input data\n    torch.manual_seed(123)\n    expert_ids = torch.stack(\n        [torch.randperm(expert_num)[:n_routed_experts] for _ in range(qlen)]\n    ).contiguous()\n    weights = torch.rand((qlen, n_routed_experts), dtype=torch.float32).contiguous()\n    \n    input_one = (torch.randn((qlen, hidden_size), dtype=dtype) / 100)\\\n                .detach().requires_grad_(True).contiguous()\n    input_many = input_one.clone().detach().requires_grad_(True).contiguous()\n    \n    # Forward passes (should be identical)\n    output_one = torch.empty((qlen, hidden_size), dtype=dtype).contiguous()\n    output_many = torch.empty((qlen, hidden_size), dtype=dtype).contiguous()\n    \n    CPUInfer.submit(\n        moe_one.forward(\n            qlen, n_routed_experts,\n            expert_ids.data_ptr(), weights.data_ptr(),\n            input_one.data_ptr(), output_one.data_ptr()\n        )\n    )\n    CPUInfer.sync()\n    \n    CPUInfer.submit(\n        moe_many.forward(\n            qlen, n_routed_experts,\n            
expert_ids.data_ptr(), weights.data_ptr(),\n            input_many.data_ptr(), output_many.data_ptr()\n        )\n    )\n    CPUInfer.sync()\n    \n    print(f\"Forward outputs identical: {torch.allclose(output_one, output_many, atol=1e-6)}\")\n    if not torch.allclose(output_one, output_many, atol=1e-6):\n        print(f\"Forward diff: {torch.mean(torch.abs(output_one - output_many))}\")\n    \n    # Backward passes\n    grad_output = torch.randn_like(output_one, dtype=gradtype).contiguous()\n    grad_output_one = grad_output.clone().contiguous()\n    grad_output_many = grad_output.clone().contiguous()\n    \n    grad_input_one = torch.zeros_like(input_one, dtype=gradtype).contiguous()\n    grad_input_many = torch.zeros_like(input_many, dtype=gradtype).contiguous()\n    \n    print(\"\\n--- Testing backward_one (force group_min_len = 10000000) ---\")\n    \n    CPUInfer.submit(\n        moe_one.backward(\n            0,  # layer_idx\n            qlen, n_routed_experts,\n            expert_ids.data_ptr(), weights.data_ptr(),\n            input_one.data_ptr(),\n            grad_output_one.data_ptr(),\n            grad_input_one.data_ptr()\n        )\n    )\n    CPUInfer.sync()\n    \n    # Check the backward_one result\n    one_has_nan = torch.isnan(grad_input_one).any()\n    print(f\"backward_one result has NaN: {one_has_nan}\")\n    if one_has_nan:\n        print(f\"backward_one NaN count: {torch.isnan(grad_input_one).sum().item()}/{grad_input_one.numel()}\")\n    else:\n        print(f\"backward_one grad_input stats: min={grad_input_one.min():.6f}, max={grad_input_one.max():.6f}, mean={grad_input_one.mean():.6f}\")\n    \n    print(\"\\n--- Testing backward_many (normal group_min_len) ---\")\n    \n    CPUInfer.submit(\n        moe_many.backward(\n            0,  # layer_idx\n            qlen, n_routed_experts,\n            expert_ids.data_ptr(), weights.data_ptr(),\n            input_many.data_ptr(),\n            grad_output_many.data_ptr(),\n            
grad_input_many.data_ptr()\n        )\n    )\n    CPUInfer.sync()\n    \n    # Check the backward_many result\n    many_has_nan = torch.isnan(grad_input_many).any()\n    print(f\"backward_many result has NaN: {many_has_nan}\")\n    if many_has_nan:\n        print(f\"backward_many NaN count: {torch.isnan(grad_input_many).sum().item()}/{grad_input_many.numel()}\")\n    else:\n        print(f\"backward_many grad_input stats: min={grad_input_many.min():.6f}, max={grad_input_many.max():.6f}, mean={grad_input_many.mean():.6f}\")\n    \n    # Compare the results\n    if not one_has_nan and not many_has_nan:\n        print(f\"\\n--- Comparison ---\")\n        grad_one_fp32 = grad_input_one.to(torch.float32)\n        grad_many_fp32 = grad_input_many.to(torch.float32)\n        print(f\"Results identical: {torch.allclose(grad_one_fp32, grad_many_fp32, atol=1e-6)}\")\n        diff = torch.abs(grad_one_fp32 - grad_many_fp32)\n        print(f\"Max absolute difference: {diff.max():.6f}\")\n        print(f\"Mean absolute difference: {diff.mean():.6f}\")\n        \n        # Locate the largest difference\n        max_diff_idx = torch.argmax(diff.flatten())\n        token_idx = max_diff_idx // hidden_size\n        feature_idx = max_diff_idx % hidden_size\n        print(f\"Max diff at token {token_idx}, feature {feature_idx}: \"\n              f\"one={grad_one_fp32.flatten()[max_diff_idx]:.6f}, \"\n              f\"many={grad_many_fp32.flatten()[max_diff_idx]:.6f}\")\n    elif not one_has_nan and many_has_nan:\n        print(f\"\\n--- backward_one OK, backward_many has NaN ---\")\n        print(\"This confirms the problem is in the backward_many implementation\")\n    elif one_has_nan and not many_has_nan:\n        print(f\"\\n--- backward_one has NaN, backward_many OK ---\")\n        print(\"This is unexpected and needs further investigation\")\n    else:\n        print(f\"\\n--- Both have NaN ---\")\n        print(\"The problem is probably somewhere more fundamental\")\n\n\nif __name__ == \"__main__\":\n    # test_backward_2round_with_tflops()\n    # test_backward_10round_5layer()\n    test_backward_one_vs_many_comparison()\n "
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/ext_bindings.cpp",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022, Jianwei Dong\n * @Date         : 2024-07-22 02:03:22\n * @Version      : 1.0.0\n * @LastEditors  : Jianwei Dong\n * @LastEditTime : 2024-08-26 22:47:06\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n// Python bindings\n#include \"cpu_backend/cpuinfer.h\"\n#if !defined(KTRANSFORMERS_USE_ROCM) && !defined(KTRANSFORMERS_USE_XPU)\n#include \"device_launch_parameters.h\"\n#endif\n#include \"llamafile/flags.h\"\n#include \"operators/kvcache/kvcache.h\"\n#include \"operators/llamafile/linear.h\"\n#include \"operators/llamafile/mlp.h\"\n#include \"operators/llamafile/moe.h\"\n#include \"operators/llamafile/sft_moe.h\"\n\n#if defined(__x86_64__) && defined(__HAS_AVX512F__) && defined(__HAS_AMX__)\n#include \"operators/amx/moe.hpp\"\n#include \"operators/amx/sft_moe.hpp\"\n#endif\n\n#include \"pybind11/functional.h\"\n#include \"pybind11/operators.h\"\n#include \"pybind11/pybind11.h\"\n#include \"pybind11/stl.h\"\n#include <cstdint>\n#include <iostream>\n#include <memory>\n\nnamespace py = pybind11;\nusing namespace pybind11::literals;\n\n// Binding functions for the KVCache class\nclass KVCacheBindings {\n  public:\n    class AttnBindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            KVCache *kv_cache;\n            const ggml_fp16_t *q_in;\n            ggml_fp16_t *output;\n            float *attn_lse;\n            int layer_idx;\n            int generate_token_idx;\n            int q_len;\n            int batch_size;\n            int max_block_num;\n            int *block_table;\n            int *cache_seqlens;\n            int pick_block_num;\n            int init_block_num;\n            int local_block_num;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(\n                &KVCache::attn, args_->kv_cache, args_->q_in, args_->output,\n                
args_->attn_lse, args_->layer_idx, args_->generate_token_idx,\n                args_->q_len, args_->batch_size, args_->max_block_num,\n                args_->block_table, args_->cache_seqlens, args_->pick_block_num,\n                args_->init_block_num, args_->local_block_num);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(KVCache &kv_cache, intptr_t q_in, intptr_t output,\n                           intptr_t attn_lse, int layer_idx,\n                           int generate_token_idx, int q_len, int batch_size,\n                           int max_block_num, intptr_t block_table,\n                           intptr_t cache_seqlens, int pick_block_num,\n                           int init_block_num, int local_block_num) {\n            Args *args = new Args{nullptr,\n                                  &kv_cache,\n                                  (const ggml_fp16_t *)q_in,\n                                  (ggml_fp16_t *)output,\n                                  (float *)attn_lse,\n                                  layer_idx,\n                                  generate_token_idx,\n                                  q_len,\n                                  batch_size,\n                                  max_block_num,\n                                  (int *)block_table,\n                                  (int *)cache_seqlens,\n                                  pick_block_num,\n                                  init_block_num,\n                                  local_block_num};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n\n    class GetAllKVCacheOneLayerBindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            KVCache *kv_cache;\n            int layer_id;\n            ggml_fp16_t *k_in;\n            ggml_fp16_t *v_in;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            
args_->cpuinfer->enqueue(&KVCache::get_all_kvcache_one_layer,\n                                     args_->kv_cache, args_->layer_id,\n                                     args_->k_in, args_->v_in);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(KVCache &kv_cache, intptr_t k_in, intptr_t v_in,\n                           int layer_id) {\n            Args *args = new Args{nullptr, &kv_cache, layer_id,\n                                  (ggml_fp16_t *)k_in, (ggml_fp16_t *)v_in};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n\n    class GetAndUpdateKVCacheFp16Bindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            KVCache *kv_cache;\n            ggml_fp16_t *k_in;\n            ggml_fp16_t *v_in;\n            int layer_id;\n            int *block_table;\n            int batch_size;\n            int max_block_num;\n            int *cache_seqlens;\n            int q_len;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(&KVCache::get_and_update_kvcache_fp16,\n                                     args_->kv_cache, args_->k_in, args_->v_in,\n                                     args_->layer_id, args_->block_table,\n                                     args_->batch_size, args_->max_block_num,\n                                     args_->cache_seqlens, args_->q_len);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(KVCache &kv_cache, intptr_t k_in, intptr_t v_in,\n                           int layer_id, intptr_t block_table, int batch_size,\n                           int max_block_num, intptr_t cache_seqlens,\n                           int q_len) {\n            Args *args = new Args{nullptr,\n                                  &kv_cache,\n                                  (ggml_fp16_t *)k_in,\n                                  (ggml_fp16_t 
*)v_in,\n                                  layer_id,\n                                  (int *)block_table,\n                                  batch_size,\n                                  max_block_num,\n                                  (int *)cache_seqlens,\n                                  q_len};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n    class GetKVCacheFp16Bindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            KVCache *kv_cache;\n            ggml_fp16_t *k_in;\n            ggml_fp16_t *v_in;\n            int layer_id;\n            int *block_table;\n            int batch_size;\n            int max_block_num;\n            int *cache_seqlens;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(\n                &KVCache::get_kvcache_fp16, args_->kv_cache, args_->k_in,\n                args_->v_in, args_->layer_id, args_->block_table,\n                args_->batch_size, args_->max_block_num, args_->cache_seqlens);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(KVCache &kv_cache, intptr_t k_in, intptr_t v_in,\n                           int layer_id, intptr_t block_table, int batch_size,\n                           int max_block_num, intptr_t cache_seqlens) {\n            Args *args = new Args{nullptr,\n                                  &kv_cache,\n                                  (ggml_fp16_t *)k_in,\n                                  (ggml_fp16_t *)v_in,\n                                  layer_id,\n                                  (int *)block_table,\n                                  batch_size,\n                                  max_block_num,\n                                  (int *)cache_seqlens};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n\n    class UpdateKVCacheFp16Bindings {\n      public:\n     
   struct Args {\n            CPUInfer *cpuinfer;\n            KVCache *kv_cache;\n            ggml_fp16_t *k_in;\n            ggml_fp16_t *v_in;\n            int layer_id;\n            int *block_table;\n            int batch_size;\n            int max_block_num;\n            int *cache_seqlens;\n            int q_len;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(&KVCache::update_kvcache_fp16,\n                                     args_->kv_cache, args_->k_in, args_->v_in,\n                                     args_->layer_id, args_->block_table,\n                                     args_->batch_size, args_->max_block_num,\n                                     args_->cache_seqlens, args_->q_len);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(KVCache &kv_cache, intptr_t k_in, intptr_t v_in,\n                           int layer_id, intptr_t block_table, int batch_size,\n                           int max_block_num, intptr_t cache_seqlens,\n                           int q_len) {\n            Args *args = new Args{nullptr,\n                                  &kv_cache,\n                                  (ggml_fp16_t *)k_in,\n                                  (ggml_fp16_t *)v_in,\n                                  layer_id,\n                                  (int *)block_table,\n                                  batch_size,\n                                  max_block_num,\n                                  (int *)cache_seqlens,\n                                  q_len};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n\n    class UpdateImportanceBindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            KVCache *kv_cache;\n            const ggml_fp16_t *importance;\n            int layer_id;\n            int *block_table;\n            int batch_size;\n            int 
max_block_num;\n            int *offset;\n            int width;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(\n                &KVCache::update_importance, args_->kv_cache, args_->importance,\n                args_->layer_id, args_->block_table, args_->batch_size,\n                args_->max_block_num, args_->offset, args_->width);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(KVCache &kv_cache, intptr_t importance, int layer_id,\n                           intptr_t block_table, int batch_size,\n                           int max_block_num, intptr_t offset, int width) {\n            Args *args = new Args{nullptr,\n                                  &kv_cache,\n                                  (const ggml_fp16_t *)importance,\n                                  layer_id,\n                                  (int *)block_table,\n                                  batch_size,\n                                  max_block_num,\n                                  (int *)offset,\n                                  width};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n\n    class AttnWithKVCacheBindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            KVCache *kv_cache;\n            const ggml_fp16_t *q_in;\n            const ggml_fp16_t *k_in;\n            const ggml_fp16_t *v_in;\n            ggml_fp16_t *output;\n            float *attn_lse;\n            int layer_idx;\n            int generate_token_idx;\n            int q_len;\n            int batch_size;\n            int max_block_num;\n            int *block_table;\n            int *cache_seqlens;\n            int topk;\n            int local;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(\n                
&KVCache::attn_with_kvcache, args_->kv_cache, args_->q_in,\n                args_->k_in, args_->v_in, args_->output, args_->attn_lse,\n                args_->layer_idx, args_->generate_token_idx, args_->q_len,\n                args_->batch_size, args_->max_block_num, args_->block_table,\n                args_->cache_seqlens, args_->topk, args_->local);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(KVCache &kv_cache, intptr_t q_in, intptr_t k_in,\n                           intptr_t v_in, intptr_t output, intptr_t attn_lse,\n                           int layer_idx, int generate_token_idx, int q_len,\n                           int batch_size, int max_block_num,\n                           intptr_t block_table, intptr_t cache_seqlens,\n                           int topk, int local) {\n            Args *args = new Args{nullptr,\n                                  &kv_cache,\n                                  (const ggml_fp16_t *)q_in,\n                                  (const ggml_fp16_t *)k_in,\n                                  (const ggml_fp16_t *)v_in,\n                                  (ggml_fp16_t *)output,\n                                  (float *)attn_lse,\n                                  layer_idx,\n                                  generate_token_idx,\n                                  q_len,\n                                  batch_size,\n                                  max_block_num,\n                                  (int *)block_table,\n                                  (int *)cache_seqlens,\n                                  topk,\n                                  local};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n\n    class ClearImportanceAllLayersBindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            KVCache *kv_cache;\n            int *block_table;\n            int *cache_seqlens;\n            int batch_size;\n        
    int max_block_num;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(&KVCache::clear_importance_all_layers,\n                                     args_->kv_cache, args_->block_table,\n                                     args_->cache_seqlens, args_->batch_size,\n                                     args_->max_block_num);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(KVCache &kv_cache, intptr_t block_table,\n                           intptr_t cache_seqlens, int batch_size,\n                           int max_block_num) {\n            Args *args = new Args{nullptr,\n                                  &kv_cache,\n                                  (int *)block_table,\n                                  (int *)cache_seqlens,\n                                  batch_size,\n                                  max_block_num};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n\n    class CalcAnchorAllLayersBindinds {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            KVCache *kv_cache;\n            int *block_table;\n            int *cache_seqlens;\n            int batch_size;\n            int max_block_num;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(&KVCache::calc_anchor_all_layers,\n                                     args_->kv_cache, args_->block_table,\n                                     args_->cache_seqlens, args_->batch_size,\n                                     args_->max_block_num);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(KVCache &kv_cache, intptr_t block_table,\n                           intptr_t cache_seqlens, int batch_size,\n                           int max_block_num) {\n            Args *args = new Args{nullptr,\n                   
               &kv_cache,\n                                  (int *)block_table,\n                                  (int *)cache_seqlens,\n                                  batch_size,\n                                  max_block_num};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n\n    class LoadKVCacheBindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            KVCache *kv_cache;\n            std::string tensor_file_path;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(&KVCache::load_kvcache, args_->kv_cache,\n                                     args_->tensor_file_path);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(KVCache &kv_cache, std::string tensor_file_path) {\n            Args *args =\n                new Args{nullptr, &kv_cache, (std::string)tensor_file_path};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n    class DumpKVCacheBindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            KVCache *kv_cache;\n            int *block_table;\n            int cache_total_len;\n            std::string tensor_file_path;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(&KVCache::dump_kvcache, args_->kv_cache,\n                                     args_->block_table, args_->cache_total_len,\n                                     args_->tensor_file_path);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(KVCache &kv_cache, intptr_t block_table,\n                           int cache_total_len, std::string tensor_file_path) {\n            Args *args =\n                new Args{nullptr, &kv_cache, (int *)block_table,\n                         cache_total_len, 
(std::string)tensor_file_path};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n};\n\nclass LinearBindings {\n  public:\n    class WarmUpBindinds {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            Linear *linear;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(&Linear::warm_up, args_->linear);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(Linear &linear) {\n            Args *args = new Args{nullptr, &linear};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n    class ForwardBindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            Linear *linear;\n            int qlen;\n            const void *input;\n            void *output;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(&Linear::forward, args_->linear,\n                                     args_->qlen, args_->input, args_->output);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(Linear &linear, int qlen, intptr_t input,\n                           intptr_t output) {\n            Args *args = new Args{nullptr, &linear, qlen, (const void *)input,\n                                  (void *)output};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n};\n\nclass MLPBindings {\n  public:\n    class WarmUpBindinds {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            MLP *mlp;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(&MLP::warm_up, args_->mlp);\n        }\n        static std::pair<intptr_t, intptr_t> cpuinfer_interface(MLP &mlp) {\n            Args *args = new 
Args{nullptr, &mlp};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n    class ForwardBindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            MLP *mlp;\n            int qlen;\n            const void *input;\n            void *output;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(&MLP::forward, args_->mlp, args_->qlen,\n                                     args_->input, args_->output);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(MLP &mlp, int qlen, intptr_t input,\n                           intptr_t output) {\n            Args *args = new Args{nullptr, &mlp, qlen, (const void *)input,\n                                  (void *)output};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n};\n\nclass MOEBindings {\n  public:\n    class WarmUpBindinds {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            MOE *moe;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(&MOE::warm_up, args_->moe);\n        }\n        static std::pair<intptr_t, intptr_t> cpuinfer_interface(MOE &moe) {\n            Args *args = new Args{nullptr, &moe};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n    class ForwardBindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            MOE *moe;\n            int qlen;\n            int k;\n            const uint64_t *expert_ids;\n            const float *weights;\n            const void *input;\n            void *output;\n            int *batch_size_tensor;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(\n                &MOE::forward, 
args_->moe, args_->qlen, args_->k,\n                args_->expert_ids, args_->weights, args_->input, args_->output, args_->batch_size_tensor);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(MOE &moe, int qlen, int k, intptr_t expert_ids,\n                           intptr_t weights, intptr_t input, intptr_t output, intptr_t batch_size_tensor) {\n            Args *args = new Args{nullptr,\n                                  &moe,\n                                  qlen,\n                                  k,\n                                  (const uint64_t *)expert_ids,\n                                  (const float *)weights,\n                                  (const void *)input,\n                                  (void *)output,\n                                  (int *)batch_size_tensor};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n};\n\nnamespace {\n\tinline void sft_moe_forward_wrapper(\n\t\t\tSFT_MOE& self,\n\t\t\tint qlen, int k,\n\t\t\tconst uint64_t* expert_ids,\n\t\t\tconst float*     weights,\n\t\t\tconst void*      input,\n\t\t\tvoid*            output,\n\t\t\tBackend*         backend)\n\t{\n\t\tself.ensure_fwd_cache(qlen, k);\n\t\tself.forward(qlen, k, expert_ids, weights,\n\t\t\t\t\tinput, output,\n\t\t\t\t\tbackend,\n\t\t\t\t\tself.fwd_cache_ptr());\n\t}\n\n\tinline void sft_moe_backward_wrapper(\n\t\t\tSFT_MOE& self,\n\t\t\tint layer_idx,\n\t\t\tint qlen, int k,\n\t\t\tconst uint64_t* expert_ids,\n\t\t\tconst float*     weights,\n\t\t\tconst void*      input,\n\t\t\tconst void*      grad_output,\n\t\t\tvoid*            grad_input,\n\t\t\tBackend*         backend)\n\t{\n\t\tself.backward(layer_idx, qlen, k, expert_ids, weights,\n\t\t\t\t\tinput, grad_output, grad_input,\n\t\t\t\t\tbackend,\n\t\t\t\t\tself.fwd_cache_ptr());\n\t}\n}\n\nclass SFT_MOEBindings {\n  public:\n    class WarmUpBindinds {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n 
           SFT_MOE *moe;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(&SFT_MOE::warm_up, args_->moe);\n        }\n        static std::pair<intptr_t, intptr_t> cpuinfer_interface(SFT_MOE &moe) {\n            Args *args = new Args{nullptr, &moe};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n    class ForwardBindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            SFT_MOE *moe;\n            int qlen;\n            int k;\n            const uint64_t *expert_ids;\n            const float *weights;\n            const void *input;\n            void *output;\n        };\n        // static void inner(void *args) {\n        //     Args *args_ = (Args *)args;\n        //     args_->cpuinfer->enqueue(\n        //         &SFT_MOE::forward, args_->moe, args_->qlen, args_->k,\n        //         args_->expert_ids, args_->weights, args_->input, args_->output);\n        // }\n\t\tstatic void inner(void *args) {\n\t\t\tArgs *args_ = static_cast<Args *>(args);\n\t\t\targs_->cpuinfer->enqueue(\n\t\t\t\t&sft_moe_forward_wrapper,   // use the wrapper function\n\t\t\t\targs_->moe, \n\t\t\t\targs_->qlen, args_->k,\n\t\t\t\targs_->expert_ids,\n\t\t\t\targs_->weights,\n\t\t\t\targs_->input,\n\t\t\t\targs_->output);\n\t\t}\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(SFT_MOE &moe, int qlen, int k, intptr_t expert_ids,\n                           intptr_t weights, intptr_t input, intptr_t output) {\n            Args *args = new Args{nullptr,\n                                  &moe,\n                                  qlen,\n                                  k,\n                                  (const uint64_t *)expert_ids,\n                                  (const float *)weights,\n                                  (const void *)input,\n                                  (void *)output};\n            return 
std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n\t// FIXME: the args here still need to match the MoE backward signature\n\tclass BackwardBindings {\n    public:\n\t\tstruct Args {\n\t\t\tCPUInfer* cpuinfer;\n\t\t\tSFT_MOE* moe;\n\t\t\tint layer_idx;\n\t\t\tint qlen;\n\t\t\tint k;\n\t\t\tconst uint64_t* expert_ids;\n\t\t\tconst float* weights;\n\t\t\tconst void* input;\n\t\t\tconst void* grad_output;\n\t\t\tvoid* grad_input;\n\t\t};\n\n        // static void inner(void* args) {\n        //     Args* args_ = static_cast<Args*>(args);\n        //     args_->cpuinfer->enqueue(&SFT_MOE::backward, args_->moe, \n        //         args_->qlen, args_->k,\n        //         args_->expert_ids, args_->weights,\n\t\t// \t\targs_->input,\n\t\t// \t\targs_->grad_output,\n\t\t// \t\targs_->grad_input);\n        // }\n\n\t\tstatic void inner(void *args) {\n\t\t\tArgs *args_ = static_cast<Args *>(args);\n\t\t\targs_->cpuinfer->enqueue(\n\t\t\t\t&sft_moe_backward_wrapper,  // use the wrapper function\n\t\t\t\targs_->moe,\n\t\t\t\targs_->layer_idx,\n\t\t\t\targs_->qlen, args_->k,\n\t\t\t\targs_->expert_ids,\n\t\t\t\targs_->weights,\n\t\t\t\targs_->input,\n\t\t\t\targs_->grad_output,\n\t\t\t\targs_->grad_input);\n\t\t}\n\n        static std::pair<intptr_t, intptr_t> cpuinfer_interface(\n            SFT_MOE& moe, int layer_idx, int qlen, int k, \n            intptr_t expert_ids, intptr_t weights,\n    \t\tintptr_t input,\n            intptr_t grad_output, intptr_t grad_input) {\n            \n            Args* args = new Args{\n\t\t\t\tnullptr, &moe, layer_idx, qlen, k,\n\t\t\t\treinterpret_cast<const uint64_t*>(expert_ids),\n\t\t\t\treinterpret_cast<const float*>(weights),\n\t\t\t\treinterpret_cast<const void*>(input), \n\t\t\t\treinterpret_cast<const void*>(grad_output),\n\t\t\t\treinterpret_cast<void*>(grad_input)\n\t\t\t};\n            return std::make_pair(\n                reinterpret_cast<intptr_t>(&inner),\n                reinterpret_cast<intptr_t>(args));\n        }\n    
};\n};\n\n#if defined(__x86_64__) && defined(__HAS_AVX512F__) && defined(__HAS_AMX__)\ntemplate<class T>\nclass AMX_MOEBindings {\n  public:\n    class WarmUpBindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            AMX_MOE<T> *moe;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(&AMX_MOE<T>::warm_up, args_->moe);\n        }\n        static std::pair<intptr_t, intptr_t> cpuinfer_interface(AMX_MOE<T> &moe) {\n            Args *args = new Args{nullptr, &moe};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n    class LoadWeightsBindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            AMX_MOE<T> *moe;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(&AMX_MOE<T>::load_weights, args_->moe);\n        }\n        static std::pair<intptr_t, intptr_t> cpuinfer_interface(AMX_MOE<T> &moe) {\n            Args *args = new Args{nullptr, &moe};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n    class ForwardBindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            AMX_MOE<T> *moe;\n            int qlen;\n            int k;\n            const uint64_t *expert_ids;\n            const float *weights;\n            const void *input;\n            void *output;\n            int *batch_size_tensor;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(\n                &AMX_MOE<T>::forward, args_->moe, args_->qlen, args_->k,\n                args_->expert_ids, args_->weights, args_->input, args_->output, args_->batch_size_tensor);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(AMX_MOE<T> &moe, int qlen, int k, 
intptr_t expert_ids,\n                        intptr_t weights, intptr_t input, intptr_t output, intptr_t batch_size_tensor) {\n            Args *args = new Args{nullptr,\n                                &moe,\n                                qlen,\n                                k,\n                                (const uint64_t *)expert_ids,\n                                (const float *)weights,\n                                (const void *)input,\n                                (void *)output,\n                                (int *)batch_size_tensor};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n};\n#endif\n\n#if defined(__x86_64__) && defined(__HAS_AVX512F__) && defined(__HAS_AMX__)\ntemplate<class T>\nclass SFT_AMX_MOEBindings {\n  public:\n    class WarmUpBindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            SFT_AMX_MOE<T> *moe;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(&SFT_AMX_MOE<T>::warm_up, args_->moe);\n        }\n        static std::pair<intptr_t, intptr_t> cpuinfer_interface(SFT_AMX_MOE<T> &moe) {\n            Args *args = new Args{nullptr, &moe};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n    class LoadWeightsBindings {\n      public:\n        struct Args {\n            CPUInfer *cpuinfer;\n            SFT_AMX_MOE<T> *moe;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(&SFT_AMX_MOE<T>::load_weights, args_->moe);\n        }\n        static std::pair<intptr_t, intptr_t> cpuinfer_interface(SFT_AMX_MOE<T> &moe) {\n            Args *args = new Args{nullptr, &moe};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n    class ForwardBindings {\n      public:\n        struct Args {\n            CPUInfer 
*cpuinfer;\n            SFT_AMX_MOE<T> *moe;\n            int qlen;\n            int k;\n            const uint64_t *expert_ids;\n            const float *weights;\n            const void *input;\n            void *output;\n        };\n        static void inner(void *args) {\n            Args *args_ = (Args *)args;\n            args_->cpuinfer->enqueue(\n                &SFT_AMX_MOE<T>::forward, args_->moe, args_->qlen, args_->k,\n                args_->expert_ids, args_->weights, args_->input, args_->output);\n        }\n        static std::pair<intptr_t, intptr_t>\n        cpuinfer_interface(SFT_AMX_MOE<T> &moe, int qlen, int k, intptr_t expert_ids,\n                        intptr_t weights, intptr_t input, intptr_t output) {\n            Args *args = new Args{nullptr,\n                                &moe,\n                                qlen,\n                                k,\n                                (const uint64_t *)expert_ids,\n                                (const float *)weights,\n                                (const void *)input,\n                                (void *)output};\n            return std::make_pair((intptr_t)&inner, (intptr_t)args);\n        }\n    };\n\n\tclass BackwardBindings {\n    public:\n\t\tstruct Args {\n\t\t\tCPUInfer* cpuinfer;\n\t\t\tSFT_AMX_MOE<T> *moe;\n\t\t\tint qlen;\n\t\t\tint k;\n\t\t\tconst uint64_t* expert_ids;\n\t\t\tconst float* weights;\n            const void* input;\n\t\t\tconst void* output_grad;\n\t\t\tvoid* input_grad;\n\t\t};\n\n\t\tstatic void inner(void *args) {\n\t\t\tArgs *args_ = static_cast<Args *>(args);\n\t\t\targs_->cpuinfer->enqueue(\n\t\t\t\t&SFT_AMX_MOE<T>::backward,\n\t\t\t\targs_->moe,\n\t\t\t\targs_->qlen, args_->k,\n\t\t\t\targs_->expert_ids,\n\t\t\t\targs_->weights,\n                args_->input,\n\t\t\t\targs_->output_grad,\n\t\t\t\targs_->input_grad);\n\t\t}\n\n        static std::pair<intptr_t, intptr_t> cpuinfer_interface(\n            SFT_AMX_MOE<T> &moe, int qlen, int k, \n   
         intptr_t expert_ids, intptr_t weights,\n            intptr_t input,\n            intptr_t output_grad, intptr_t input_grad) {\n            \n            Args* args = new Args{\n\t\t\t\tnullptr, &moe, qlen, k,\n\t\t\t\t(const uint64_t*)expert_ids,\n\t\t\t\t(const float*)weights,\n                (const void*)input,\n\t\t\t\t(const void*)output_grad,\n\t\t\t\t(void*)input_grad\n\t\t\t};\n            return std::make_pair(\n                (intptr_t)&inner,\n                (intptr_t)args);\n        }\n    };\n};\n#endif\n\nPYBIND11_MODULE(cpuinfer_ext, m) {\n    py::class_<CPUInfer>(m, \"CPUInfer\")\n        .def(py::init<int>())\n        .def(\"submit\", &CPUInfer::submit)\n        .def(\"submit_with_cuda_stream\", &CPUInfer::submit_with_cuda_stream)\n        .def(\"sync\", &CPUInfer::sync)\n        .def(\"sync_with_cuda_stream\", &CPUInfer::sync_with_cuda_stream);\n\n    auto linear_module = m.def_submodule(\"linear\");\n    py::class_<LinearConfig>(linear_module, \"LinearConfig\")\n        .def(py::init([](int hidden_size, int intermediate_size, int stride,\n                         int group_max_len, intptr_t proj, int proj_type,\n                         int hidden_type) {\n            return LinearConfig(hidden_size, intermediate_size, stride,\n                                group_max_len, (void *)proj,\n                                (ggml_type)proj_type, (ggml_type)hidden_type);\n        }));\n    py::class_<Linear>(linear_module, \"Linear\")\n        .def(py::init<LinearConfig>())\n        .def(\"warm_up\", &LinearBindings::WarmUpBindinds::cpuinfer_interface)\n        .def(\"forward\", &LinearBindings::ForwardBindings::cpuinfer_interface);\n\n    auto mlp_module = m.def_submodule(\"mlp\");\n    py::class_<MLPConfig>(mlp_module, \"MLPConfig\")\n        .def(py::init([](int hidden_size, int intermediate_size, int stride,\n                         int group_max_len, intptr_t gate_proj,\n                         intptr_t up_proj, intptr_t down_proj, 
int gate_type,\n                         int up_type, int down_type, int hidden_type) {\n            return MLPConfig(hidden_size, intermediate_size, stride,\n                             group_max_len, (void *)gate_proj, (void *)up_proj,\n                             (void *)down_proj, (ggml_type)gate_type,\n                             (ggml_type)up_type, (ggml_type)down_type,\n                             (ggml_type)hidden_type);\n        }));\n    py::class_<MLP>(mlp_module, \"MLP\")\n        .def(py::init<MLPConfig>())\n        .def(\"warm_up\", &MLPBindings::WarmUpBindinds::cpuinfer_interface)\n        .def(\"forward\", &MLPBindings::ForwardBindings::cpuinfer_interface);\n\n    auto moe_module = m.def_submodule(\"moe\");\n    py::class_<MOEConfig>(moe_module, \"MOEConfig\")\n        .def(py::init([](int expert_num, int routed_expert_num, int hidden_size,\n                         int intermediate_size, int stride, int group_min_len,\n                         int group_max_len, intptr_t gate_proj,\n                         intptr_t up_proj, intptr_t down_proj, int gate_type,\n                         int up_type, int down_type, int hidden_type) {\n            return MOEConfig(expert_num, routed_expert_num, hidden_size,\n                             intermediate_size, stride, group_min_len,\n                             group_max_len, (void *)gate_proj, (void *)up_proj,\n                             (void *)down_proj, (ggml_type)gate_type,\n                             (ggml_type)up_type, (ggml_type)down_type,\n                             (ggml_type)hidden_type);\n        }));\n    py::class_<MOE>(moe_module, \"MOE\")\n        .def(py::init<MOEConfig>())\n        .def(\"warm_up\", &MOEBindings::WarmUpBindinds::cpuinfer_interface)\n        .def(\"forward\", &MOEBindings::ForwardBindings::cpuinfer_interface);\n\n    auto sft_moe_module = m.def_submodule(\"sft_moe\");\n    py::class_<SFT_MOEConfig>(sft_moe_module, \"SFT_MOEConfig\")\n        .def(py::init([](int 
expert_num, int routed_expert_num, int hidden_size,\n                         int intermediate_size, int stride, int group_min_len,\n                         int group_max_len, intptr_t gate_proj,\n                         intptr_t up_proj, intptr_t down_proj, int gate_type,\n                         int up_type, int down_type, int hidden_type) {\n            return SFT_MOEConfig(expert_num, routed_expert_num, hidden_size,\n                             intermediate_size, stride, group_min_len,\n                             group_max_len, (void *)gate_proj, (void *)up_proj,\n                             (void *)down_proj, (ggml_type)gate_type,\n                             (ggml_type)up_type, (ggml_type)down_type,\n                             (ggml_type)hidden_type);\n        }));\n    py::class_<SFT_MOE>(sft_moe_module, \"SFT_MOE\")\n        .def(py::init<SFT_MOEConfig>())\n        .def(\"warm_up\", &SFT_MOEBindings::WarmUpBindinds::cpuinfer_interface)\n        .def(\"forward\", &SFT_MOEBindings::ForwardBindings::cpuinfer_interface)\n\t\t.def(\"backward\", &SFT_MOEBindings::BackwardBindings::cpuinfer_interface);\n\n    #if defined(__x86_64__) && defined(__HAS_AVX512F__) && defined(__HAS_AMX__)\n    py::class_<AMX_MOEConfig>(moe_module, \"AMX_MOEConfig\")\n        .def(py::init([](int expert_num, int routed_expert_num, int hidden_size,\n                         int intermediate_size,\n                         int max_len, intptr_t gate_proj,\n                         intptr_t up_proj, intptr_t down_proj) {\n            return AMX_MOEConfig(expert_num, routed_expert_num, hidden_size,\n                                 intermediate_size, \n                                 max_len, (void *)gate_proj,\n                                 (void *)up_proj, (void *)down_proj);\n        }));\n\n    py::class_<AMX_MOE<amx::GemmKernel224BF>>(moe_module, \"AMXBF16_MOE\")\n        .def(py::init<AMX_MOEConfig>())\n        .def(\"warm_up\", 
&AMX_MOEBindings<amx::GemmKernel224BF>::WarmUpBindings::cpuinfer_interface)\n        .def(\"load_weights\", &AMX_MOEBindings<amx::GemmKernel224BF>::LoadWeightsBindings::cpuinfer_interface)\n        .def(\"forward\", &AMX_MOEBindings<amx::GemmKernel224BF>::ForwardBindings::cpuinfer_interface);\n    py::class_<AMX_MOE<amx::GemmKernel224Int8>>(moe_module, \"AMXInt8_MOE\")\n        .def(py::init<AMX_MOEConfig>())\n        .def(\"warm_up\", &AMX_MOEBindings<amx::GemmKernel224Int8>::WarmUpBindings::cpuinfer_interface)\n        .def(\"load_weights\", &AMX_MOEBindings<amx::GemmKernel224Int8>::LoadWeightsBindings::cpuinfer_interface)\n        .def(\"forward\", &AMX_MOEBindings<amx::GemmKernel224Int8>::ForwardBindings::cpuinfer_interface);\n\n    #endif\n\n\t#if defined(__x86_64__) && defined(__HAS_AVX512F__) && defined(__HAS_AMX__)\n    py::class_<SFT_AMX_MOEConfig>(sft_moe_module, \"SFT_AMX_MOEConfig\")\n        .def(py::init([](int expert_num, int routed_expert_num, int hidden_size,\n                         int intermediate_size,\n                         int max_len, intptr_t gate_proj,\n                         intptr_t up_proj, intptr_t down_proj) {\n            return SFT_AMX_MOEConfig(expert_num, routed_expert_num, hidden_size,\n                                 intermediate_size, \n                                 max_len, (void *)gate_proj,\n                                 (void *)up_proj, (void *)down_proj);\n        }));\n\n    py::class_<SFT_AMX_MOE<amx::GemmKernel224BF>>(sft_moe_module, \"SFT_AMXBF16_MOE\")\n        .def(py::init<SFT_AMX_MOEConfig>())\n        .def(\"warm_up\", &SFT_AMX_MOEBindings<amx::GemmKernel224BF>::WarmUpBindings::cpuinfer_interface)\n        .def(\"load_weights\", &SFT_AMX_MOEBindings<amx::GemmKernel224BF>::LoadWeightsBindings::cpuinfer_interface)\n        .def(\"forward\", &SFT_AMX_MOEBindings<amx::GemmKernel224BF>::ForwardBindings::cpuinfer_interface)\n\t\t.def(\"backward\", 
&SFT_AMX_MOEBindings<amx::GemmKernel224BF>::BackwardBindings::cpuinfer_interface);\n\n    py::class_<SFT_AMX_MOE<amx::GemmKernel224Int8>>(sft_moe_module, \"SFT_AMXInt8_MOE\")\n        .def(py::init<SFT_AMX_MOEConfig>())\n        .def(\"warm_up\", &SFT_AMX_MOEBindings<amx::GemmKernel224Int8>::WarmUpBindings::cpuinfer_interface)\n        .def(\"load_weights\", &SFT_AMX_MOEBindings<amx::GemmKernel224Int8>::LoadWeightsBindings::cpuinfer_interface)\n        .def(\"forward\", &SFT_AMX_MOEBindings<amx::GemmKernel224Int8>::ForwardBindings::cpuinfer_interface)\n\t\t.def(\"backward\", &SFT_AMX_MOEBindings<amx::GemmKernel224Int8>::BackwardBindings::cpuinfer_interface);\n\n    #endif\n\n    auto kvcache_module = m.def_submodule(\"kvcache\");\n\n    py::enum_<AnchorType>(kvcache_module, \"AnchorType\")\n        .value(\"FIXED\", AnchorType::FIXED_ANCHOR)\n        .value(\"DYNAMIC\", AnchorType::DYNAMIC)\n        .value(\"QUEST\", AnchorType::QUEST)\n        .value(\"BLOCK_MAX\", AnchorType::BLOCK_MAX)\n        .value(\"BLOCK_MEAN\", AnchorType::BLOCK_MEAN);\n    py::enum_<ggml_type>(kvcache_module, \"ggml_type\")\n        .value(\"FP16\", ggml_type::GGML_TYPE_F16)\n        .value(\"FP32\", ggml_type::GGML_TYPE_F32)\n        .value(\"Q4_0\", ggml_type::GGML_TYPE_Q4_0)\n        .value(\"Q8_0\", ggml_type::GGML_TYPE_Q8_0);\n    py::enum_<RetrievalType>(kvcache_module, \"RetrievalType\")\n        .value(\"LAYER\", RetrievalType::LAYER)\n        .value(\"KVHEAD\", RetrievalType::KVHEAD)\n        .value(\"QHEAD\", RetrievalType::QHEAD);\n\n    py::class_<KVCacheConfig>(kvcache_module, \"KVCacheConfig\")\n        .def(py::init<int, int, int, int, int, int, AnchorType, ggml_type,\n                      RetrievalType, int, int, int, int, int, int>())\n        .def_readwrite(\"layer_num\", &KVCacheConfig::layer_num)\n        .def_readwrite(\"kv_head_num\", &KVCacheConfig::kv_head_num)\n        .def_readwrite(\"q_head_num\", &KVCacheConfig::q_head_num)\n        
.def_readwrite(\"head_dim\", &KVCacheConfig::head_dim)\n        .def_readwrite(\"block_len\", &KVCacheConfig::block_len)\n        .def_readwrite(\"anchor_num\", &KVCacheConfig::anchor_num)\n        .def_readwrite(\"anchor_type\", &KVCacheConfig::anchor_type)\n        .def_readwrite(\"kv_type\", &KVCacheConfig::kv_type)\n        .def_readwrite(\"retrieval_type\", &KVCacheConfig::retrieval_type)\n        .def_readwrite(\"layer_step\", &KVCacheConfig::layer_step)\n        .def_readwrite(\"token_step\", &KVCacheConfig::token_step)\n        .def_readwrite(\"layer_offset\", &KVCacheConfig::layer_offset)\n        .def_readwrite(\"max_block_num\", &KVCacheConfig::max_block_num)\n        .def_readwrite(\"max_batch_size\", &KVCacheConfig::max_batch_size)\n        .def_readwrite(\"max_thread_num\", &KVCacheConfig::max_thread_num);\n    py::class_<KVCache>(kvcache_module, \"KVCache\")\n        .def(py::init<KVCacheConfig>())\n        .def(\"get_cache_total_len\", &KVCache::get_cache_total_len)\n        .def(\"update_cache_total_len\",\n             [](KVCache &kvcache, int cache_total_len) {\n                 kvcache.update_cache_total_len(cache_total_len);\n             })\n        .def(\"attn\", &KVCacheBindings::AttnBindings::cpuinfer_interface)\n        .def(\n            \"get_all_kvcache_one_layer\",\n            &KVCacheBindings::GetAllKVCacheOneLayerBindings::cpuinfer_interface)\n        .def(\"get_and_update_kvcache_fp16\",\n             &KVCacheBindings::GetAndUpdateKVCacheFp16Bindings::\n                 cpuinfer_interface)\n        .def(\"get_kvcache_fp16\",\n             &KVCacheBindings::GetKVCacheFp16Bindings::cpuinfer_interface)\n        .def(\"update_kvcache_fp16\",\n             &KVCacheBindings::UpdateKVCacheFp16Bindings::cpuinfer_interface)\n        .def(\"update_importance\",\n             &KVCacheBindings::UpdateImportanceBindings::cpuinfer_interface)\n        .def(\"attn_with_kvcache\",\n             
&KVCacheBindings::AttnWithKVCacheBindings::cpuinfer_interface)\n        .def(\"clear_importance_all_layers\",\n             &KVCacheBindings::ClearImportanceAllLayersBindings::\n                 cpuinfer_interface)\n        .def(\"calc_anchor_all_layers\",\n             &KVCacheBindings::CalcAnchorAllLayersBindinds::cpuinfer_interface);\n}"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/operators/amx/debug_sft_moe.hpp",
    "content": "/**\n * @Description  : Mainly used for dev debug, with no numa version\n * @Author       : chenht2022\n * @Date         : 2025-04-25 18:28:12\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2025-04-25 18:28:12\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#ifndef CPUINFER_OPERATOR_SFT_AMX_MOE_H\n#define CPUINFER_OPERATOR_SFT_AMX_MOE_H\n\n#include <cmath>\n#include <cstdio>\n#include <functional>\n#include <mutex>\n#include <vector>\n#include <fstream>\n#include <filesystem>\n\n#include \"debug_sft_moe.hpp\"\n\n#include \"../../cpu_backend/backend.h\"\n#include \"../../cpu_backend/shared_mem_buffer.h\"\n#include \"llama.cpp/ggml-impl.h\"\n#include \"llama.cpp/ggml-quants.h\"\n#include \"llama.cpp/ggml.h\"\n#include \"llamafile/sgemm.h\"\n\n#include \"la/amx.hpp\"\n\n#ifdef USE_NUMA\n#include <numa.h>\n#include <numaif.h>\nvoid *numa_alloc_aligned(size_t size, int node, size_t alignment) {\n  void *ptr = numa_alloc_onnode(size, node);\n  assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n  return ptr;\n}\n#endif\n\n// static inline __m512 exp_avx512(__m512 x) {\n//   const __m512 log2e = _mm512_set1_ps(1.44269504089f);\n//   const __m512 c1 = _mm512_set1_ps(0.69314718056f);\n\n//   __m512 y = _mm512_mul_ps(x, log2e);\n//   __m512i int_part = _mm512_cvtps_epi32(y);\n//   __m512 frac_part = _mm512_sub_ps(y, _mm512_cvtepi32_ps(int_part));\n\n//   const __m512 poly_1 = _mm512_set1_ps(0.9999999995f);\n//   const __m512 poly_2 = _mm512_set1_ps(0.6931471805f);\n//   const __m512 poly_3 = _mm512_set1_ps(0.2402265069f);\n//   const __m512 poly_4 = _mm512_set1_ps(0.0555041087f);\n//   const __m512 poly_5 = _mm512_set1_ps(0.0096181291f);\n//   const __m512 poly_6 = _mm512_set1_ps(0.0013333558f);\n\n//   __m512 frac_exp = _mm512_fmadd_ps(\n//       frac_part, poly_6,\n//       _mm512_fmadd_ps(frac_part, poly_5,\n//                       _mm512_fmadd_ps(frac_part, poly_4,\n//                                       
_mm512_fmadd_ps(frac_part, poly_3, _mm512_fmadd_ps(frac_part, poly_2, poly_1)))));\n\n//   __m512 two_pow_i = _mm512_scalef_ps(_mm512_set1_ps(1.0f), _mm512_cvtepi32_ps(int_part));\n//   return _mm512_mul_ps(two_pow_i, frac_exp);\n// }\n\n// static inline __m512 act_fn(__m512 gate_val, __m512 up_val) {\n//   __m512 neg_gate_val = _mm512_sub_ps(_mm512_setzero_ps(), gate_val);\n//   __m512 exp_neg_gate = exp_avx512(neg_gate_val);\n//   __m512 denom = _mm512_add_ps(_mm512_set1_ps(1.0f), exp_neg_gate);\n//   __m512 act_val = _mm512_div_ps(gate_val, denom);\n\n//   return _mm512_mul_ps(act_val, up_val);\n// }\n\n// sigmoid(x) = 1 / (1 + exp(-x)), evaluated on 16 packed floats\nstatic inline __m512 sigmoid(__m512 x) {\n  __m512 neg = _mm512_sub_ps(_mm512_setzero_ps(), x);\n  __m512 e = exp_avx512(neg);\n  __m512 denom = _mm512_add_ps(_mm512_set1_ps(1.0f), e);\n  return _mm512_div_ps(_mm512_set1_ps(1.0f), denom);\n}\n\n// SiLU activation: x * sigmoid(x)\nstatic inline __m512 act_fn_1(__m512 x) {\n  __m512 sigmoid_val = sigmoid(x);\n  return _mm512_mul_ps(sigmoid_val, x);\n}\n\nstatic inline __m512 act_fn_grad(__m512 x) {\n  // d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))\n  __m512 sigmoid_val = sigmoid(x);\n  __m512 one_minus_sigmoid = _mm512_sub_ps(_mm512_set1_ps(1.0f), sigmoid_val);\n  __m512 x_term = _mm512_mul_ps(x, one_minus_sigmoid);\n  __m512 one_plus_x_term = _mm512_add_ps(_mm512_set1_ps(1.0f), x_term);\n  return _mm512_mul_ps(sigmoid_val, one_plus_x_term);\n}\n\n// static inline float bf16_to_fp32(ggml_bf16_t v) {\n//     uint16_t lo16;\n//     std::memcpy(&lo16, &v, sizeof(lo16));   // extract the 16-bit payload\n//     uint32_t tmp = uint32_t(lo16) << 16; // place it in the high 16 bits\n//     float out;\n//     std::memcpy(&out, &tmp, sizeof(float));\n//     return out;\n// }\n\n// Convert a row of int8_t values into a readable, comma-separated string\ninline std::string int8_row_to_string(const int8_t* row, int len) {\n    std::string s;\n    for (int i = 0; i < len; ++i) {\n        if (i) s += \", \";\n        s += std::to_string(row[i]);\n    }\n    return s;\n}\n\nstruct SFT_AMX_MOEConfig {\n  int expert_num;\n  int routed_expert_num;\n  
int hidden_size;\n  int intermediate_size;\n  int max_len;\n  void *gate_proj;\n  void *up_proj;\n  void *down_proj;\n\n  SFT_AMX_MOEConfig() {}\n\n  SFT_AMX_MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int max_len,\n                void *gate_proj, void *up_proj, void *down_proj)\n      : expert_num(expert_num), routed_expert_num(routed_expert_num), hidden_size(hidden_size),\n        intermediate_size(intermediate_size), max_len(max_len), gate_proj(gate_proj), up_proj(up_proj),\n        down_proj(down_proj) {}\n};\n\ntemplate <class T> class SFT_AMX_MOE {\nprivate:\n  SFT_AMX_MOEConfig config_;\n  void *gate_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]\n  void *up_proj_;   // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]\n  void *down_proj_; // [expert_num * hidden_size * intermediate_size ( /32 if quantized)]\n\n  void *gate_proj_t_; // [expert_num * intermediate_size * hidden_size]\n  void *up_proj_t_;   // [expert_num * intermediate_size * hidden_size]\n  void *down_proj_t_; // [expert_num * hidden_size * intermediate_size]\n\n  ggml_bf16_t *m_local_input_;       // [routed_expert_num * max_len * hidden_size]\n  ggml_bf16_t *m_local_gate_output_; // [routed_expert_num * max_len * intermediate_size]\n  ggml_bf16_t *m_local_up_output_;   // [routed_expert_num * max_len * intermediate_size]\n  ggml_bf16_t *m_local_down_output_; // [routed_expert_num * max_len * hidden_size]\n\n  std::vector<std::vector<int>> m_local_pos_;          // [max_len, routed_expert_num]\n  std::vector<int> m_local_num_;                       // [expert_num]\n  std::vector<int> m_expert_id_map_;                   // [expert_num]\n  std::vector<ggml_bf16_t *> m_local_input_ptr_;       // [expert_num]\n  std::vector<ggml_bf16_t *> m_local_gate_output_ptr_; // [expert_num]\n  std::vector<ggml_bf16_t *> m_local_up_output_ptr_;   // [expert_num]\n  std::vector<ggml_bf16_t *> 
m_local_down_output_ptr_; // [expert_num]\n\n  std::vector<std::shared_ptr<typename T::BufferA>> gate_up_ba_;\n  std::vector<std::shared_ptr<typename T::BufferC>> gate_bc_;\n  std::vector<std::shared_ptr<typename T::BufferC>> up_bc_;\n  std::vector<std::shared_ptr<typename T::BufferA>> down_ba_;\n  std::vector<std::shared_ptr<typename T::BufferC>> down_bc_;\n\n#ifdef USE_NUMA\n  std::vector<std::vector<std::shared_ptr<typename T::BufferB>>> gate_bb_numa_;\n  std::vector<std::vector<std::shared_ptr<typename T::BufferB>>> up_bb_numa_;\n  std::vector<std::vector<std::shared_ptr<typename T::BufferB>>> down_bb_numa_;\n#else\n  std::vector<std::shared_ptr<typename T::BufferB>> gate_bb_;\n  std::vector<std::shared_ptr<typename T::BufferB>> up_bb_;\n  std::vector<std::shared_ptr<typename T::BufferB>> down_bb_;\n#endif\n\n  ggml_bf16_t *m_local_down_output_grad_;       // [routed_expert_num * max_len * hidden_size]\n  ggml_bf16_t *m_local_down_input_grad_;        // [routed_expert_num * max_len * intermediate_size]\n  ggml_bf16_t *m_local_gate_output_grad_;       // [routed_expert_num * max_len * intermediate_size]\n  ggml_bf16_t *m_local_up_output_grad_;         // [routed_expert_num * max_len * intermediate_size]\n  ggml_bf16_t *m_local_gate_input_grad_;        // [routed_expert_num * max_len * hidden_size]\n  ggml_bf16_t *m_local_up_input_grad_;          // [routed_expert_num * max_len * hidden_size]\n\n  std::vector<ggml_bf16_t *> m_local_down_output_grad_ptr_;       // [expert_num]\n  std::vector<ggml_bf16_t *> m_local_down_input_grad_ptr_;        // [expert_num]\n  std::vector<ggml_bf16_t *> m_local_gate_output_grad_ptr_;       // [expert_num]\n  std::vector<ggml_bf16_t *> m_local_up_output_grad_ptr_;         // [expert_num]\n  std::vector<ggml_bf16_t *> m_local_gate_input_grad_ptr_;        // [expert_num]\n  std::vector<ggml_bf16_t *> m_local_up_input_grad_ptr_;          // [expert_num]\n\n  std::vector<std::shared_ptr<typename T::BufferA>> gate_t_ba_;\n  
std::vector<std::shared_ptr<typename T::BufferC>> gate_t_bc_;\n  std::vector<std::shared_ptr<typename T::BufferA>> up_t_ba_;\n  std::vector<std::shared_ptr<typename T::BufferC>> up_t_bc_;\n  std::vector<std::shared_ptr<typename T::BufferA>> down_t_ba_;\n  std::vector<std::shared_ptr<typename T::BufferC>> down_t_bc_;\n\n  // TODO: NUMA\n  std::vector<std::shared_ptr<typename T::BufferB>> gate_t_bb_;\n  std::vector<std::shared_ptr<typename T::BufferB>> up_t_bb_;\n  std::vector<std::shared_ptr<typename T::BufferB>> down_t_bb_;\n\n  int* m_local_token_indices_;                                   // [routed_expert_num * max_len]\n  int* m_local_expert_positions_;                               // [routed_expert_num * max_len]\n  std::vector<int *> m_local_token_indices_ptr_;                // [expert_num]\n  std::vector<int *> m_local_expert_positions_ptr_;             // [expert_num]\n\npublic:\n  SFT_AMX_MOE(SFT_AMX_MOEConfig config) {\n    config_ = config;\n    gate_proj_ = config_.gate_proj;\n    up_proj_ = config_.up_proj;\n    down_proj_ = config_.down_proj;\n\n    std::vector<std::pair<void **, uint64_t>> m_mem_requests;\n    m_mem_requests.push_back({(void **)&m_local_input_,\n                              sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.hidden_size});\n    m_mem_requests.push_back({(void **)&m_local_gate_output_, sizeof(ggml_bf16_t) * config_.routed_expert_num *\n                                                                  config_.max_len * config_.intermediate_size});\n    m_mem_requests.push_back({(void **)&m_local_up_output_, sizeof(ggml_bf16_t) * config_.routed_expert_num *\n                                                                config_.max_len * config_.intermediate_size});\n    m_mem_requests.push_back({(void **)&m_local_down_output_,\n                              sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.hidden_size});\n    std::vector<void *> 
gate_up_ba_ptr(config_.expert_num);\n    std::vector<void *> gate_bc_ptr(config_.expert_num);\n    std::vector<void *> up_bc_ptr(config_.expert_num);\n    std::vector<void *> down_ba_ptr(config_.expert_num);\n    std::vector<void *> down_bc_ptr(config_.expert_num);\n    for (int i = 0; i < config_.expert_num; i++) {\n      m_mem_requests.push_back(\n          {(void **)&gate_up_ba_ptr[i], T::BufferA::required_size(config_.max_len, config_.hidden_size)});\n      m_mem_requests.push_back(\n          {(void **)&gate_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.intermediate_size)});\n      m_mem_requests.push_back(\n          {(void **)&up_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.intermediate_size)});\n      m_mem_requests.push_back(\n          {(void **)&down_ba_ptr[i], T::BufferA::required_size(config_.max_len, config_.intermediate_size)});\n      m_mem_requests.push_back(\n          {(void **)&down_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.hidden_size)});\n    }\n\n    m_mem_requests.push_back({(void **)&gate_proj_t_,\n                              sizeof(ggml_bf16_t) * config_.expert_num * config_.intermediate_size * config_.hidden_size});\n    m_mem_requests.push_back({(void **)&up_proj_t_,\n                              sizeof(ggml_bf16_t) * config_.expert_num * config_.intermediate_size * config_.hidden_size});\n    m_mem_requests.push_back({(void **)&down_proj_t_,\n                              sizeof(ggml_bf16_t) * config_.expert_num * config_.hidden_size * config_.intermediate_size});\n    \n    m_mem_requests.push_back({(void **)&m_local_down_output_grad_,\n                              sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.hidden_size});\n    m_mem_requests.push_back({(void **)&m_local_down_input_grad_,\n                              sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.intermediate_size});\n    
m_mem_requests.push_back({(void **)&m_local_gate_output_grad_,\n                              sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.intermediate_size});\n    m_mem_requests.push_back({(void **)&m_local_up_output_grad_,\n                              sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.intermediate_size});\n    m_mem_requests.push_back({(void **)&m_local_gate_input_grad_,\n                              sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.hidden_size});\n    m_mem_requests.push_back({(void **)&m_local_up_input_grad_,\n                              sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.hidden_size});\n    m_mem_requests.push_back({(void **)&m_local_token_indices_,\n                              sizeof(int) * config_.routed_expert_num * config_.max_len});\n    m_mem_requests.push_back({(void **)&m_local_expert_positions_,\n                              sizeof(int) * config_.routed_expert_num * config_.max_len});\n    std::vector<void *> gate_t_ba_ptr(config_.expert_num);\n    std::vector<void *> gate_t_bc_ptr(config_.expert_num);\n    std::vector<void *> up_t_ba_ptr(config_.expert_num);\n    std::vector<void *> up_t_bc_ptr(config_.expert_num);\n    std::vector<void *> down_t_ba_ptr(config_.expert_num);\n    std::vector<void *> down_t_bc_ptr(config_.expert_num);\n    for (int i = 0; i < config_.expert_num; i++) {\n      m_mem_requests.push_back(\n          {(void **)&gate_t_ba_ptr[i], T::BufferA::required_size(config_.max_len, config_.intermediate_size)});\n      m_mem_requests.push_back(\n          {(void **)&gate_t_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.hidden_size)});\n      m_mem_requests.push_back(\n          {(void **)&up_t_ba_ptr[i], T::BufferA::required_size(config_.max_len, config_.intermediate_size)});\n      m_mem_requests.push_back(\n          {(void **)&up_t_bc_ptr[i], 
T::BufferC::required_size(config_.max_len, config_.hidden_size)});\n      m_mem_requests.push_back(\n          {(void **)&down_t_ba_ptr[i], T::BufferA::required_size(config_.max_len, config_.hidden_size)});\n      m_mem_requests.push_back(\n          {(void **)&down_t_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.intermediate_size)});\n    }\n\n    shared_mem_buffer.alloc(this, m_mem_requests);\n\n    m_local_pos_.resize(config_.max_len);\n    for (int i = 0; i < config_.max_len; i++) {\n      m_local_pos_[i].resize(config_.routed_expert_num);\n    }\n    m_expert_id_map_.resize(config_.expert_num);\n    m_local_num_.resize(config_.expert_num);\n    m_local_input_ptr_.resize(config_.expert_num);\n    m_local_gate_output_ptr_.resize(config_.expert_num);\n    m_local_up_output_ptr_.resize(config_.expert_num);\n    m_local_down_output_ptr_.resize(config_.expert_num);\n    m_local_down_output_grad_ptr_.resize(config_.expert_num);\n    m_local_down_input_grad_ptr_.resize(config_.expert_num);\n    m_local_gate_output_grad_ptr_.resize(config_.expert_num);\n    m_local_up_output_grad_ptr_.resize(config_.expert_num);\n    m_local_gate_input_grad_ptr_.resize(config_.expert_num);\n    m_local_up_input_grad_ptr_.resize(config_.expert_num);\n\n    for (uint64_t i = 0; i < config_.expert_num; i++) {\n      gate_up_ba_.push_back(\n          std::make_shared<typename T::BufferA>(config_.max_len, config_.hidden_size, gate_up_ba_ptr[i]));\n      gate_bc_.push_back(\n          std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, gate_bc_ptr[i]));\n      up_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, up_bc_ptr[i]));\n      down_ba_.push_back(\n          std::make_shared<typename T::BufferA>(config_.max_len, config_.intermediate_size, down_ba_ptr[i]));\n      down_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.hidden_size, down_bc_ptr[i]));\n\n#ifdef USE_NUMA\n 
     int numa_nodes = numa_num_configured_nodes();\n      gate_bb_numa_.resize(numa_nodes);\n      up_bb_numa_.resize(numa_nodes);\n      down_bb_numa_.resize(numa_nodes);\n      for (int j = 0; j < numa_nodes; j++) {\n        void *gate_bb_ptr =\n            numa_alloc_aligned(T::BufferB::required_size(config_.intermediate_size, config_.hidden_size), j, 64);\n        gate_bb_numa_[j].push_back(\n            std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, gate_bb_ptr));\n        void *up_bb_ptr =\n            numa_alloc_aligned(T::BufferB::required_size(config_.intermediate_size, config_.hidden_size), j, 64);\n        up_bb_numa_[j].push_back(\n            std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, up_bb_ptr));\n        void *down_bb_ptr =\n            numa_alloc_aligned(T::BufferB::required_size(config_.hidden_size, config_.intermediate_size), j, 64);\n        down_bb_numa_[j].push_back(  \n            std::make_shared<typename T::BufferB>(config_.hidden_size, config_.intermediate_size, down_bb_ptr));\n      }\n#else\n      void *gate_bb_ptr =\n          std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size));\n      gate_bb_.push_back(\n          std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, gate_bb_ptr));\n\n      void *up_bb_ptr =\n          std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size));\n      up_bb_.push_back(\n          std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, up_bb_ptr));\n\n      void *down_bb_ptr =\n          std::aligned_alloc(64, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size));\n      down_bb_.push_back(\n          std::make_shared<typename T::BufferB>(config_.hidden_size, config_.intermediate_size, down_bb_ptr));\n#endif\n    }\n\n    for (uint64_t i = 0; i < 
config_.expert_num; i++) {\n      gate_t_ba_.push_back(\n          std::make_shared<typename T::BufferA>(config_.max_len, config_.intermediate_size, gate_t_ba_ptr[i]));\n      gate_t_bc_.push_back(\n          std::make_shared<typename T::BufferC>(config_.max_len, config_.hidden_size, gate_t_bc_ptr[i]));\n      up_t_ba_.push_back(std::make_shared<typename T::BufferA>(config_.max_len, config_.intermediate_size, up_t_ba_ptr[i]));\n      up_t_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.hidden_size, up_t_bc_ptr[i]));\n      down_t_ba_.push_back(\n          std::make_shared<typename T::BufferA>(config_.max_len, config_.hidden_size, down_t_ba_ptr[i]));\n      down_t_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, down_t_bc_ptr[i]));\n\n      // TODO: NUMA\n      void *gate_t_bb_ptr =\n          std::aligned_alloc(64, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size));\n      gate_t_bb_.push_back(\n          std::make_shared<typename T::BufferB>(config_.hidden_size, config_.intermediate_size, gate_t_bb_ptr));\n\n      void *up_t_bb_ptr =\n          std::aligned_alloc(64, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size));\n      up_t_bb_.push_back(\n          std::make_shared<typename T::BufferB>(config_.hidden_size, config_.intermediate_size, up_t_bb_ptr));\n\n      void *down_t_bb_ptr =\n          std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size));\n      down_t_bb_.push_back(\n          std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, down_t_bb_ptr));\n    }\n\n    m_local_token_indices_ptr_.resize(config_.expert_num);\n    m_local_expert_positions_ptr_.resize(config_.expert_num);\n  }\n\n  ~SFT_AMX_MOE() { shared_mem_buffer.dealloc(this); }\n\n  void transpose_expert(const void* src, void* dst, int R, int C, Backend* backend) {\n    
backend->do_work_stealing_job(\n        config_.expert_num, nullptr,\n        [&](uint64_t expert_idx) {\n          for (int r = 0; r < R; ++r) {\n            for (int c = 0; c < C; ++c) {\n                memcpy(\n                    (uint8_t*)dst + (expert_idx * R * C + (c * R + r)) * sizeof(ggml_bf16_t),\n                    (uint8_t*)src + (expert_idx * R * C + (r * C + c)) * sizeof(ggml_bf16_t),\n                    sizeof(ggml_bf16_t));\n            }\n          }\n        },\n        nullptr);\n  }\n  \n  void load_weights(Backend *backend) {\n    transpose_expert(config_.gate_proj, gate_proj_t_, config_.intermediate_size, config_.hidden_size, backend);\n    transpose_expert(config_.up_proj, up_proj_t_, config_.intermediate_size, config_.hidden_size, backend);\n    transpose_expert(config_.down_proj, down_proj_t_, config_.hidden_size, config_.intermediate_size, backend);\n\n    int nth = T::recommended_nth(config_.intermediate_size);\n    backend->do_work_stealing_job(\n        nth * config_.expert_num, nullptr,\n        [&](int task_id) {\n          uint64_t expert_idx = task_id / nth;\n          int ith = task_id % nth;\n#ifdef USE_NUMA\n          int numa_nodes = numa_num_configured_nodes();\n          for (int j = 0; j < numa_nodes; j++) {\n            gate_bb_numa_[j][expert_idx]->from_mat((ggml_bf16_t *)config_.gate_proj +\n                                                       expert_idx * config_.intermediate_size * config_.hidden_size,\n                                                   ith, nth);\n            up_bb_numa_[j][expert_idx]->from_mat((ggml_bf16_t *)config_.up_proj +\n                                                     expert_idx * config_.intermediate_size * config_.hidden_size,\n                                                 ith, nth);\n          }\n#else\n          gate_bb_[expert_idx]->from_mat((ggml_bf16_t *)config_.gate_proj +\n                                             expert_idx * config_.intermediate_size * 
config_.hidden_size,\n                                         ith, nth);\n          up_bb_[expert_idx]->from_mat(\n              (ggml_bf16_t *)config_.up_proj + expert_idx * config_.intermediate_size * config_.hidden_size, ith, nth);\n#endif\n\t\t},\n        nullptr);\n    nth = T::recommended_nth(config_.intermediate_size);\n    backend->do_work_stealing_job(\n        nth * config_.expert_num, nullptr,\n        [&](int task_id) {\n          uint64_t expert_idx = task_id / nth;\n          int ith = task_id % nth;\n          down_t_bb_[expert_idx]->from_mat((ggml_bf16_t *)down_proj_t_ +\n                                             expert_idx * config_.intermediate_size * config_.hidden_size,\n                                         ith, nth);\n        },\n        nullptr);\n\n\t// if constexpr (std::is_same_v<typename T::dt, ggml_bf16_t>) {\n\t// \t// Make sure the debug/ directory exists\n\t// \tstd::filesystem::create_directories(\"debug\");\n\n\t// \tint tail_cols = 1024;\n\t// \tfor (int expert_idx = 0; expert_idx < config_.expert_num; ++expert_idx) {\n\t// \t\tauto buf = down_t_bb_[expert_idx].get();\n\t// \t\tstd::cout << \"k: \" << buf->k << \"; n: \" << buf->n << std::endl;\n\t// \t\t// Open the file for this expert\n\t// \t\tstd::string path = \"debug/\" + std::to_string(expert_idx) + \"_down_bb_t_debug.txt\";\n\t// \t\tstd::ofstream ofs(path, std::ios::out);\n\t// \t\tif (!ofs) {\n\t// \t\t\tstd::cerr << \"Failed to open \" << path << \" for writing\\n\";\n\t// \t\t\tcontinue;\n\t// \t\t}\n\n\t// \t\tofs << \"==== Expert \" << expert_idx << \" ====\\n\";\n\t// \t\tfor (int n_idx = 0; n_idx < buf->k; ++n_idx) {\n\t// \t\t\t// Explicitly read as int8\n\t// \t\t\tconst int8_t* row = reinterpret_cast<const int8_t*>(buf->b) + n_idx * buf->n;\n\n\t// \t\t\t// Write the whole row\n\t// \t\t\tofs << \"row[\" << n_idx << \"] = { \"\n\t// \t\t\t\t<< int8_row_to_string(row, buf->n)\n\t// \t\t\t\t<< \" }\\n\";\n\t// \t\t}\n\n\t// \t\tofs.close();\n\t// \t}\n\t// }\n\n\t// if constexpr (std::is_same_v<typename T::dt, int8_t>) 
{\n\t// \tfor (int expert_idx = 0; expert_idx < config_.expert_num; ++expert_idx) {\n\t// \t\tauto buf = down_t_bb_[expert_idx].get();\n\n\t// \t\t// Open the file for this expert\n\t// \t\tstd::string path = \"debug/\" + std::to_string(expert_idx) + \"_down_bb_t_debug3.bin\";\n\t// \t\tstd::ofstream ofs(path, std::ios::binary);\n\t// \t\tfor (int n_idx = 0; n_idx < buf->k; ++n_idx) {\n\t// \t\t\tconst int8_t* row = reinterpret_cast<const int8_t*>(buf->b) + n_idx * buf->n;\n\t// \t\t\tfor (int j = 0; j < buf->n; ++j) {\n\t// \t\t\t\tfloat v = row[j];\n\t// \t\t\t\tofs.write(reinterpret_cast<const char*>(&v), sizeof(v));\n\t// \t\t\t}\n\t// \t\t}\n\t// \t\tofs.close();\n\t// \t}\n\t// }\n\n\t\n\t// for (uint64_t expert_idx = 0; expert_idx < (uint64_t)config_.expert_num; ++expert_idx) {\n\t// \t// Open the file for this expert\n\t// \tstd::string path = \"debug/\" + std::to_string(expert_idx) + \"_up_proj_t.bin\";\n\t// \tstd::ofstream ofs(path, std::ios::binary);\n\t// \tstd::cout << \"config_.hidden_size: \" << config_.hidden_size << std::endl;\n\t// \tstd::cout << \"config_.intermediate_size: \" << config_.intermediate_size << std::endl;\n\t// \tfor (int n_idx = 0; n_idx < config_.intermediate_size; ++n_idx) {\n\t// \t\tconst int8_t* row = reinterpret_cast<const int8_t*>(config_.down_proj + expert_idx * n_idx * config_.hidden_size);\n\t// \t\tfor (int j = 0; j < config_.hidden_size; ++j) {\n\t// \t\t\tfloat v = row[j];\n\t// \t\t\tofs.write(reinterpret_cast<const char*>(&v), sizeof(v));\n\t// \t\t}\n\t// \t}\n\t// \tofs.close();\n\t// }\n\n    nth = T::recommended_nth(config_.hidden_size);\n    backend->do_work_stealing_job(\n        nth * config_.expert_num, nullptr,\n        [&](int task_id) {\n          uint64_t expert_idx = task_id / nth;\n          int ith = task_id % nth;\n#ifdef USE_NUMA\n          int numa_nodes = numa_num_configured_nodes();\n          for (int j = 0; j < numa_nodes; j++) {\n            down_bb_numa_[j][expert_idx]->from_mat((ggml_bf16_t *)config_.down_proj +\n      
                                                 expert_idx * config_.hidden_size * config_.intermediate_size,\n                                                   ith, nth);\n          }\n#else\n          down_bb_[expert_idx]->from_mat((ggml_bf16_t *)config_.down_proj +\n                                             expert_idx * config_.hidden_size * config_.intermediate_size,\n                                         ith, nth);\n#endif\n          up_t_bb_[expert_idx]->from_mat((ggml_bf16_t *)up_proj_t_ +\n                                             expert_idx * config_.hidden_size * config_.intermediate_size,\n                                         ith, nth);\n          gate_t_bb_[expert_idx]->from_mat((ggml_bf16_t *)gate_proj_t_ +\n                                             expert_idx * config_.hidden_size * config_.intermediate_size,\n                                         ith, nth);\n        },\n        nullptr);\n  }\n\n  void warm_up(Backend *backend) {}\n\n  void forward(int qlen, int k, const uint64_t *expert_ids, const float *weights, const void *input, void *output, Backend *backend) {\n    bool use_amx = (qlen > 4 * config_.expert_num / config_.routed_expert_num);\n    int activated_expert = 0;\n    for (int i = 0; i < config_.expert_num; i++) {\n      m_local_num_[i] = 0;\n    }\n    for (int i = 0; i < qlen; i++) {\n      for (int j = 0; j < k; j++) {\n        m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;\n      }\n    }\n    for (int i = 0; i < config_.expert_num; i++) {\n      if (m_local_num_[i] > 0) {\n        m_expert_id_map_[activated_expert] = i;\n        activated_expert++;\n      }\n    }\n    uint64_t offset = 0;\n    for (int i = 0; i < config_.expert_num; i++) {\n      m_local_input_ptr_[i] = m_local_input_ + offset * config_.hidden_size;\n      m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;\n      m_local_up_output_ptr_[i] = m_local_up_output_ + offset * 
config_.intermediate_size;\n      m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;\n      offset += m_local_num_[i];\n    }\n    backend->do_work_stealing_job(\n        qlen, nullptr,\n        [&](int i) {\n          for (int j = 0; j < k; j++) {\n            memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size,\n                   (ggml_bf16_t *)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size);\n          }\n        },\n        nullptr);\n    backend->do_work_stealing_job(\n        activated_expert, nullptr,\n        [&](int task_id) {\n          int expert_idx = m_expert_id_map_[task_id];\n\n          gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1);\n        },\n        nullptr);\n    int nth = T::recommended_nth(config_.intermediate_size);\n    backend->do_work_stealing_job(\n        nth * activated_expert, [&](int _) { T::config(); },\n        [&](int task_id) {\n          int expert_idx = m_expert_id_map_[task_id / nth];\n          int ith = task_id % nth;\n#ifdef USE_NUMA\n          amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,\n                       gate_up_ba_[expert_idx], gate_bb_numa_[Backend::numa_node][expert_idx], gate_bc_[expert_idx],\n                       ith, nth, use_amx);\n          amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,\n                       gate_up_ba_[expert_idx], up_bb_numa_[Backend::numa_node][expert_idx], up_bc_[expert_idx], ith,\n                       nth, use_amx);\n#else\n          amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,\n                       gate_up_ba_[expert_idx], gate_bb_[expert_idx], gate_bc_[expert_idx], ith, nth, use_amx);\n          amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,\n                       
gate_up_ba_[expert_idx], up_bb_[expert_idx], up_bc_[expert_idx], ith, nth, use_amx);\n#endif\n          gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);\n          up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth);\n          auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);\n          for (int i = 0; i < m_local_num_[expert_idx]; i++) {\n            ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];\n            ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];\n            for (int j = n_start; j < n_end; j += 32) {\n              __m512 gate_val0, gate_val1, up_val0, up_val1;\n              avx512_32xbf16_to_32xfp32((__m512i *)(gate_output_ptr + j), &gate_val0, &gate_val1);\n              avx512_32xbf16_to_32xfp32((__m512i *)(up_output_ptr + j), &up_val0, &up_val1);\n              __m512 result0 = act_fn(gate_val0, up_val0);\n              __m512 result1 = act_fn(gate_val1, up_val1);\n              avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i *)(gate_output_ptr + j));\n            }\n          }\n        },\n        nullptr);\n    backend->do_work_stealing_job(\n        activated_expert, nullptr,\n        [&](int task_id) {\n          int expert_idx = m_expert_id_map_[task_id];\n          down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1);\n        },\n        nullptr);\n\t\n\t// for (uint64_t expert_idx = 0; expert_idx < (uint64_t)config_.expert_num; ++expert_idx) {\n\t// \tdump_grad_bin(\"cpp_layer0_E_End\"+std::to_string(expert_idx)+\"_down_ba_ori_\", (ggml_bf16_t*)m_local_gate_output_ptr_[expert_idx], config_.intermediate_size * m_local_num_[expert_idx], GGML_TYPE_BF16);\n\t// }\n\t// if constexpr (std::is_same_v<typename T::dt, int8_t>) {\n\t// \tstd::cout << \"GO INTO 
forward output\" << std::endl;\n\t// \t// Make sure the debug/ directory exists\n\t// \tstd::filesystem::create_directories(\"debug\");\n\n\t// \tfor (int expert_idx = 0; expert_idx < config_.expert_num; ++expert_idx) {\n\t// \t\tauto buf = down_ba_[expert_idx].get();\n\t// \t\t// std::cout << \"k: \" << buf->k << \"; n: \" << buf->n << std::endl;\n\t// \t\t// Open the file for this expert\n\t// \t\tstd::string path = \"debug/\" + std::to_string(expert_idx) + \"_down_ba_debug.txt\";\n\t// \t\tstd::ofstream ofs(path, std::ios::out);\n\t// \t\tif (!ofs) {\n\t// \t\t\tstd::cerr << \"Failed to open \" << path << \" for writing\\n\";\n\t// \t\t\tcontinue;\n\t// \t\t}\n\n\t// \t\tofs << \"==== Expert \" << expert_idx << \" ====\\n\";\n\t// \t\tofs << \"buf_k: \" << buf->k << \"\\n\";\n\t// \t\tfor (int n_idx = 0; n_idx < m_local_num_[expert_idx]; ++n_idx) {\n\t// \t\t\t// Explicitly read as int8\n\t// \t\t\tconst int8_t* row = reinterpret_cast<const int8_t*>(buf->a) + n_idx * buf->k;\n\n\t// \t\t\t// Write the whole row\n\t// \t\t\tofs << \"row[\" << n_idx << \"] = { \"\n\t// \t\t\t\t<< int8_row_to_string(row, buf->k)\n\t// \t\t\t\t<< \" }\\n\";\n\t// \t\t}\n\n\t// \t\tofs.close();\n\t// \t}\n\t// \tstd::cout << \"OUT INTO forward output\" << std::endl;\n\t// }\n    nth = T::recommended_nth(config_.hidden_size);\n    backend->do_work_stealing_job(\n        nth * activated_expert, [&](int _) { T::config(); },\n        [&](int task_id) {\n          int expert_idx = m_expert_id_map_[task_id / nth];\n          int ith = task_id % nth;\n#ifdef USE_NUMA\n          amx::mat_mul(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size, down_ba_[expert_idx],\n                       down_bb_numa_[Backend::numa_node][expert_idx], down_bc_[expert_idx], ith, nth, use_amx);\n#else\n          amx::mat_mul(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size, down_ba_[expert_idx],\n                       down_bb_[expert_idx], down_bc_[expert_idx], ith, nth, use_amx);\n#endif\n          
down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_output_ptr_[expert_idx], ith, nth);\n        },\n        nullptr);\n    backend->do_work_stealing_job(\n        qlen, nullptr,\n        [&](int i) {\n          for (int e = 0; e < config_.hidden_size; e += 32) {\n            __m512 x0 = _mm512_setzero_ps();\n            __m512 x1 = _mm512_setzero_ps();\n            for (int j = 0; j < k; j++) {\n              __m512 weight = _mm512_set1_ps(weights[i * k + j]);\n              __m512 down_output0, down_output1;\n              avx512_32xbf16_to_32xfp32((__m512i *)(m_local_down_output_ptr_[expert_ids[i * k + j]] +\n                                                    m_local_pos_[i][j] * config_.hidden_size + e),\n                                        &down_output0, &down_output1);\n              x0 = _mm512_fmadd_ps(down_output0, weight, x0);\n              x1 = _mm512_fmadd_ps(down_output1, weight, x1);\n            }\n            avx512_32xfp32_to_32xbf16(&x0, &x1, (__m512i *)((ggml_bf16_t *)output + i * config_.hidden_size + e));\n          }\n        },\n        nullptr);\n  }\n\n  void backward(int qlen, int k, const uint64_t *expert_ids, const float *weights, const void* input, const void *output_grad, void *input_grad, Backend *backend) {\n    bool use_amx = (qlen > 4 * config_.expert_num / config_.routed_expert_num);\n    int activated_expert = 0;\n    for (int i = 0; i < config_.expert_num; i++) {\n      m_local_num_[i] = 0;\n    }\n    for (int i = 0; i < qlen; i++) {\n      for (int j = 0; j < k; j++) {\n        m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;\n      }\n    }\n    for (int i = 0; i < config_.expert_num; i++) {\n      if (m_local_num_[i] > 0) {\n        m_expert_id_map_[activated_expert] = i;\n        activated_expert++;\n      }\n    }\n    uint64_t offset = 0;\n    for (int i = 0; i < config_.expert_num; i++) {\n      m_local_input_ptr_[i] = m_local_input_ + offset * config_.hidden_size;\n      
m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;\n      m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;\n\n      m_local_down_output_grad_ptr_[i] = m_local_down_output_grad_ + offset * config_.hidden_size;\n      m_local_down_input_grad_ptr_[i] = m_local_down_input_grad_ + offset * config_.intermediate_size;\n      m_local_gate_output_grad_ptr_[i] = m_local_gate_output_grad_ + offset * config_.intermediate_size;\n      m_local_up_output_grad_ptr_[i] = m_local_up_output_grad_ + offset * config_.intermediate_size;\n      m_local_gate_input_grad_ptr_[i] = m_local_gate_input_grad_ + offset * config_.hidden_size;\n      m_local_up_input_grad_ptr_[i] = m_local_up_input_grad_ + offset * config_.hidden_size;\n      m_local_token_indices_ptr_[i] = m_local_token_indices_ + offset;\n      m_local_expert_positions_ptr_[i] = m_local_expert_positions_ + offset;\n      offset += m_local_num_[i];\n    }\n\n    // TODO: cache\n    backend->do_work_stealing_job(\n        qlen, nullptr, \n        [&](int i) {\n          for (int j = 0; j < k; j++) {\n            uint64_t expert_id = expert_ids[i * k + j];\n            int local_row = m_local_pos_[i][j];\n            memcpy(m_local_input_ptr_[expert_id] + local_row * config_.hidden_size,\n              (ggml_bf16_t *)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size); // TODO: cache\n            memcpy(m_local_down_output_grad_ptr_[expert_id] + local_row * config_.hidden_size,\n              (ggml_bf16_t *)output_grad + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size);\n            m_local_token_indices_ptr_[expert_id][local_row] = i;\n            m_local_expert_positions_ptr_[expert_id][local_row] = j;\n          }\n        }, \n        nullptr);\n\n    backend->do_work_stealing_job(\n        activated_expert, nullptr,\n        [&](int task_id) {\n          int expert_idx = m_expert_id_map_[task_id];\n          
gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1); // TODO: cache\n          down_t_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_down_output_grad_ptr_[expert_idx], 0, 1);\n        },\n        nullptr);\n\t\t\n\t// // for debug\n\t// for (uint64_t expert_idx = 0; expert_idx < (uint64_t)config_.expert_num; ++expert_idx) {\n\t// \tdump_grad_bin(\"cpp_layer0_E_End\"+std::to_string(expert_idx)+\"_down_output_grad_\", (ggml_bf16_t*)m_local_down_output_grad_ptr_[expert_idx], config_.hidden_size * m_local_num_[expert_idx], GGML_TYPE_BF16);\n\t// }\n\t\n\t// if constexpr (std::is_same_v<typename T::dt, int8_t>) {\n\t// \t// Make sure the debug/ directory exists\n\t// \tstd::filesystem::create_directories(\"debug\");\n\n\t// \tfor (int expert_idx = 0; expert_idx < config_.expert_num; ++expert_idx) {\n\t// \t\tauto buf = down_t_ba_[expert_idx].get();\n\t// \t\t// std::cout << \"k: \" << buf->k << \"; n: \" << buf->n << std::endl;\n\t// \t\t// Open the file for this expert\n\t// \t\tstd::string path = \"debug/\" + std::to_string(expert_idx) + \"_down_ba_t_debug.txt\";\n\t// \t\tstd::ofstream ofs(path, std::ios::out);\n\t// \t\tif (!ofs) {\n\t// \t\t\tstd::cerr << \"Failed to open \" << path << \" for writing\\n\";\n\t// \t\t\tcontinue;\n\t// \t\t}\n\n\t// \t\tofs << \"==== Expert \" << expert_idx << \" ====\\n\";\n\t// \t\tfor (int n_idx = 0; n_idx < m_local_num_[expert_idx]; ++n_idx) {\n\t// \t\t\t// Explicitly read as int8\n\t// \t\t\tconst int8_t* row = reinterpret_cast<const int8_t*>(buf->a) + n_idx * buf->k;\n\n\t// \t\t\t// Write the whole row\n\t// \t\t\tofs << \"row[\" << n_idx << \"] = { \"\n\t// \t\t\t\t<< int8_row_to_string(row, buf->k)\n\t// \t\t\t\t<< \" }\\n\";\n\t// \t\t}\n\n\t// \t\tofs.close();\n\t// \t}\n\t// }\n\n    int nth = T::recommended_nth(config_.intermediate_size);  \n    backend->do_work_stealing_job(\n        nth * activated_expert, [&](int _) { T::config(); },\n        [&](int task_id) {\n          int expert_idx = m_expert_id_map_[task_id / 
nth];\n          int ith = task_id % nth;\n\n        //   // TODO: cache\n          amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,\n                      gate_up_ba_[expert_idx], gate_bb_[expert_idx], gate_bc_[expert_idx], ith, nth, use_amx);\n          amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,\n                      gate_up_ba_[expert_idx], up_bb_[expert_idx], up_bc_[expert_idx], ith, nth, use_amx);\n          gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);\n          up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth);\n\n          amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,\n                      down_t_ba_[expert_idx], down_t_bb_[expert_idx], down_t_bc_[expert_idx], ith, nth, use_amx);\n          down_t_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_input_grad_ptr_[expert_idx], ith, nth);\n\n\n          auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);\n          for (int i = 0; i < m_local_num_[expert_idx]; i++) {\n            ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];\n            ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];\n            ggml_bf16_t *down_input_grad_ptr = &m_local_down_input_grad_ptr_[expert_idx][i * config_.intermediate_size];\n            ggml_bf16_t *gate_output_grad_ptr = &m_local_gate_output_grad_ptr_[expert_idx][i * config_.intermediate_size];\n            ggml_bf16_t *up_output_grad_ptr = &m_local_up_output_grad_ptr_[expert_idx][i * config_.intermediate_size];\n            \n            int token_idx = m_local_token_indices_ptr_[expert_idx][i];\n            int expert_pos = m_local_expert_positions_ptr_[expert_idx][i];\n            __m512 weight = 
_mm512_set1_ps(weights[token_idx * k + expert_pos]);\n            \n            for (int j = n_start; j < n_end; j += 32) {\n              __m512 gate_val0, gate_val1, up_val0, up_val1, down_input_grad0, down_input_grad1;\n              avx512_32xbf16_to_32xfp32((__m512i *)(gate_output_ptr + j), &gate_val0, &gate_val1);\n              avx512_32xbf16_to_32xfp32((__m512i *)(up_output_ptr + j), &up_val0, &up_val1);\n              avx512_32xbf16_to_32xfp32((__m512i *)(down_input_grad_ptr + j), &down_input_grad0, &down_input_grad1);\n              \n              down_input_grad0 = _mm512_mul_ps(down_input_grad0, weight);\n              down_input_grad1 = _mm512_mul_ps(down_input_grad1, weight);\n              \n              // gate_output_grad = δ_zji ⊙ v_ji ⊙ σ'(u_ji)\n              __m512 gate_grad0 = _mm512_mul_ps(down_input_grad0, \n                                               _mm512_mul_ps(up_val0, act_fn_grad(gate_val0)));\n              __m512 gate_grad1 = _mm512_mul_ps(down_input_grad1, \n                                               _mm512_mul_ps(up_val1, act_fn_grad(gate_val1)));\n              \n              // up_output_grad = δ_zji ⊙ σ(u_ji)\n              __m512 up_grad0 = _mm512_mul_ps(down_input_grad0, act_fn_1(gate_val0));\n              __m512 up_grad1 = _mm512_mul_ps(down_input_grad1, act_fn_1(gate_val1));\n              \n              avx512_32xfp32_to_32xbf16(&gate_grad0, &gate_grad1, (__m512i *)(gate_output_grad_ptr + j));\n              avx512_32xfp32_to_32xbf16(&up_grad0, &up_grad1, (__m512i *)(up_output_grad_ptr + j));\n            }\n          }\n        },\n        nullptr);\n\n\t// for debug\n\t// if constexpr (std::is_same_v<typename T::dt, ggml_bf16_t>) {\t\n\t// \tfor (int expert_idx = 0; expert_idx < config_.expert_num; ++expert_idx) {\n\t// \t\tauto buf = down_t_ba_[expert_idx].get();\n\n\t// \t\t// Open the file for this expert\n\t// \t\tstd::string path = \"debug/\" + std::to_string(expert_idx) + \"_down_ba_t_debug3.bin\";\n\t// 
\t\tstd::ofstream ofs(path, std::ios::binary);\n\t// \t\tfor (int n_idx = 0; n_idx < m_local_num_[expert_idx]; ++n_idx) {\n\t// \t\t\tconst ggml_bf16_t* row = reinterpret_cast<const ggml_bf16_t*>(buf->a) + n_idx * buf->k;\n\t// \t\t\tfor (int j = 0; j < buf->k; ++j) {\n\t// \t\t\t\tfloat v = row[j];\n\t// \t\t\t\tofs.write(reinterpret_cast<const char*>(&v), sizeof(v));\n\t// \t\t\t}\n\t// \t\t}\n\t// \t\tofs.close();\n\t// \t}\n\t// }\n\t\n\t// for (uint64_t expert_idx = 0; expert_idx < (uint64_t)config_.expert_num; ++expert_idx) {\n\t// \tdump_grad_bin(\"cpp_layer0_E_End\"+std::to_string(expert_idx)+\"_down_t_ba_\", (ggml_bf16_t*)m_local_down_output_grad_ptr_[expert_idx], config_.hidden_size * m_local_num_[expert_idx], GGML_TYPE_BF16);\n\t// }\n\t\n\t// for (uint64_t expert_idx = 0; expert_idx < (uint64_t)config_.expert_num; ++expert_idx) {\n\t// \tdump_grad_bin(\"cpp_layer0_E_End\"+std::to_string(expert_idx)+\"_down_t_bb_\", (ggml_bf16_t *)down_proj_t_ + expert_idx * config_.intermediate_size * config_.hidden_size, config_.hidden_size * config_.intermediate_size, GGML_TYPE_BF16);\n\t// }\n\n\t// for (uint64_t expert_idx = 0; expert_idx < (uint64_t)config_.expert_num; ++expert_idx) {\n\t// \tdump_grad_bin(\"cpp_layer0_E_End\"+std::to_string(expert_idx)+\"_down_t_bc_\", (ggml_bf16_t*)m_local_down_input_grad_ptr_[expert_idx], config_.intermediate_size * m_local_num_[expert_idx], GGML_TYPE_BF16);\n\t// }\n\n    backend->do_work_stealing_job(\n        activated_expert, nullptr,\n        [&](int task_id) {\n          int expert_idx = m_expert_id_map_[task_id];\n          gate_t_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_grad_ptr_[expert_idx], 0, 1);\n          up_t_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_up_output_grad_ptr_[expert_idx], 0, 1);\n        },\n        nullptr);\n    nth = T::recommended_nth(config_.hidden_size);\n    backend->do_work_stealing_job(\n        nth * activated_expert, [&](int _) { T::config(); 
},\n        [&](int task_id) {\n          int expert_idx = m_expert_id_map_[task_id / nth];\n          int ith = task_id % nth;\n          amx::mat_mul(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size,\n                      gate_t_ba_[expert_idx], gate_t_bb_[expert_idx], gate_t_bc_[expert_idx], ith, nth, use_amx);\n          amx::mat_mul(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size,\n                      up_t_ba_[expert_idx], up_t_bb_[expert_idx], up_t_bc_[expert_idx], ith, nth, use_amx);\n          gate_t_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_input_grad_ptr_[expert_idx], ith, nth);\n          up_t_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_input_grad_ptr_[expert_idx], ith, nth);\n        },\n        nullptr);\n    backend->do_work_stealing_job(\n        qlen, nullptr,\n        [&](int i) {\n          for (int e = 0; e < config_.hidden_size; e += 32) {\n            __m512 x0 = _mm512_setzero_ps();\n            __m512 x1 = _mm512_setzero_ps();\n            for (int j = 0; j < k; j++) {\n              __m512 gate_input_grad0, gate_input_grad1, up_input_grad0, up_input_grad1;\n              avx512_32xbf16_to_32xfp32((__m512i *)(m_local_gate_input_grad_ptr_[expert_ids[i * k + j]] +\n                                                    m_local_pos_[i][j] * config_.hidden_size + e),\n                                        &gate_input_grad0, &gate_input_grad1);\n              avx512_32xbf16_to_32xfp32((__m512i *)(m_local_up_input_grad_ptr_[expert_ids[i * k + j]] +\n                                                    m_local_pos_[i][j] * config_.hidden_size + e),\n                                        &up_input_grad0, &up_input_grad1);\n              x0 = _mm512_add_ps(gate_input_grad0, x0);\n              x1 = _mm512_add_ps(gate_input_grad1, x1);\n              x0 = _mm512_add_ps(up_input_grad0, x0);\n              x1 = _mm512_add_ps(up_input_grad1, x1);\n            }\n     
       avx512_32xfp32_to_32xbf16(&x0, &x1, (__m512i *)((ggml_bf16_t *)input_grad + i * config_.hidden_size + e));\n          }\n        },\n        nullptr);\n  }\n};\n#endif"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/operators/amx/debug_tools_sft_moe.hpp",
    "content": "#ifndef SFT_DEBUG_HPP\n#define SFT_DEBUG_HPP\n\n#include <cstdint>\n#include <cstdlib>\n#include <fstream>\n#include <string>\n#include <iostream>\n\ninline std::string get_env_or_default(const char *var_name, const std::string &default_value) {\n\tconst char *value = std::getenv(var_name);\n\treturn (value != nullptr) ? std::string(value) : default_value;\n}\n\n/* use example:  \n\tfor (uint64_t expert_idx = 0; expert_idx < (uint64_t)config_.expert_num; ++expert_idx) {\n\t\tdump_grad_bin(\"layer0_E_End\"+std::to_string(expert_idx)+\"_gate_proj_out_trans_\", (uint8_t*)gate_proj_t_ + expert_idx * config_.hidden_size * config_.intermediate_size * ggml_type_size(config_.grad_type), config_.hidden_size * config_.intermediate_size, config_.grad_type);\n\t\tstd::cout << \"gate_proj_t_:\" << static_cast<const void*>((uint8_t*)gate_proj_t_ + expert_idx * config_.hidden_size * config_.intermediate_size) << \", grad_type: \" << config_.grad_type << std::endl;\n\t}\n*/\ninline void dump_grad_bin(const std::string &file_name,\n                          const void       *data,\n                          size_t            elem_cnt,\n                          ggml_type         dtype,\n\t\t\t\t\t\t  std::streamoff    offset_bytes = 0)\n{\n    std::string path = get_env_or_default(\"SFT_DEBUG_PATH\",\"debug\") + \"/\" + file_name;\n    switch (dtype) {\n        case GGML_TYPE_F32:  path += \".f32\";  break;\n        case GGML_TYPE_F16:  path += \".f16\";  break;\n        case GGML_TYPE_BF16: path += \".bf16\"; break;\n\t\tcase GGML_TYPE_I8: path += \".int8\"; break;\n        default:             path += \".raw\";  break;\n    }\n\tstd::fstream f(path, std::ios::in | std::ios::out | std::ios::binary);\n    if (!f.is_open()) {\n        std::ofstream tmp(path, std::ios::out | std::ios::binary);\n        tmp.close();\n        f.open(path, std::ios::in | std::ios::out | std::ios::binary);\n    }\n\n    f.seekp(offset_bytes * ggml_type_size(dtype));\n\t// std::cout << 
\"seekp: \" << offset_bytes * ggml_type_size(dtype) << std::endl;\n\n    f.write(reinterpret_cast<const char*>(data), static_cast<std::streamsize>(elem_cnt * ggml_type_size(dtype)));\n    f.close();\n}\n\n// inline void dump_bin(std::string file_name, float16_t *data, size_t count) {\n//   file_name = get_env_or_default(\"SFT_DEBUG_PATH\", \"debug\") + \"/\" + file_name + \".f16\";\n//   std::ofstream f(file_name, std::ios::binary);\n//   f.write(reinterpret_cast<const char *>(data), count * sizeof(*data));\n//   f.close();\n// }\ninline void dump_bin(std::string file_name, float *data, size_t count) {\n\tfile_name = get_env_or_default(\"SFT_DEBUG_PATH\", \"debug\") + \"/\" + file_name + \".f32\";\n\tstd::cout << file_name << std::endl;\n\tstd::ofstream f(file_name, std::ios::binary);\n\tf.write(reinterpret_cast<const char *>(data), count * sizeof(*data));\n\tf.close();\n}\ninline void dump_bin(std::string file_name, int64_t *data, size_t count) {\n\tfile_name = get_env_or_default(\"SFT_DEBUG_PATH\", \"debug\") + \"/\" + file_name + \".int64\";\n\tstd::cout << file_name << std::endl;\n\tstd::ofstream f(file_name, std::ios::binary);\n\tf.write(reinterpret_cast<const char *>(data), count * sizeof(*data));\n\tf.close();\n}\ninline void dump_bin(std::string file_name, uint8_t *data, size_t count) {\n\tfile_name = get_env_or_default(\"SFT_DEBUG_PATH\", \"debug\") + \"/\" + file_name + \".uint8\";\n\tstd::cout << file_name << std::endl;\n\tstd::ofstream f(file_name, std::ios::binary);\n\tf.write(reinterpret_cast<const char *>(data), count * sizeof(*data));\n\tf.close();\n}\n\n#endif"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/operators/amx/la/amx.hpp",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2025-04-25 18:28:12\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2025-04-25 18:28:12\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n\n#pragma once\n#include <array>\n#include <cassert>\n#include <cstdint>\n#include <cstdio>\n#include <immintrin.h>\n#include <iostream>\n#include <random>\n#include <stdexcept>\n#include <stdlib.h>\n#include <sys/syscall.h>\n#include <unistd.h>\n\n#include \"utils.hpp\"\n#include <memory>\n\n#if (defined(_WIN32) || defined(_WIN64))\n#define RESTRICT __restrict\n#else\n#define RESTRICT __restrict__\n#endif\n\n#if (defined(_WIN32) || defined(_WIN64))\n#define ALWAYS_INLINE __forceinline\n#elif __has_attribute(always_inline) || defined(__GNUC__)\n#define ALWAYS_INLINE __attribute__((__always_inline__)) inline\n#else\n#define ALWAYS_INLINE inline\n#endif\n\nnamespace amx {\n\n#define ARCH_GET_XCOMP_PERM 0x1022\n#define ARCH_REQ_XCOMP_PERM 0x1023\n#define XFEATURE_XTILECFG 17\n#define XFEATURE_XTILEDATA 18\n\nconst int TMMCount = 8;\nconst int MaxTileHeight = 16;\nconst int MaxTileWidth = 64;\n\nconst int AMX_BLK_SIZE = 32;\n\n#define TMM0 0\n#define TMM1 1\n#define TMM2 2\n#define TMM3 3\n#define TMM4 4\n#define TMM5 5\n#define TMM6 6\n#define TMM7 7\n\ninline bool enable_amx() {\n  static thread_local bool initialized = false;\n  if (initialized) {\n    return true;\n  }\n  initialized = true;\n\n  if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {\n    printf(\"\\n Fail to do XFEATURE_XTILEDATA \\n\\n\");\n    return false;\n  } else {\n    // printf(\"\\n TILE DATA USE SET - OK \\n\\n\");\n    return true;\n  }\n  return true;\n}\n\nstruct alignas(64) TileConfig {\n  uint8_t palette;\n  uint8_t start_row;\n  std::array<uint8_t, 14> __0 = {};\n  std::array<uint16_t, 8> colsb;\n  std::array<uint8_t, 16> __1 = {};\n  std::array<uint8_t, 8> rows;\n  std::array<uint8_t, 8> __2 = 
{};\n\n  TileConfig() {\n    palette = 1;\n    start_row = 0;\n    for (int i = 0; i < 8; i++) {\n      set_row_col(i, 0, 0);\n    }\n  }\n\n  void set_row_col(int i, uint8_t row, uint16_t col) {\n    colsb[i] = col;\n    rows[i] = row;\n  }\n\n  void set_config() { _tile_loadconfig(this); }\n\n  static void load_data(int to, void *from, size_t stride) {\n    switch (to) {\n    case 0:\n      _tile_loadd(0, from, stride);\n      break;\n    case 1:\n      _tile_loadd(1, from, stride);\n      break;\n    case 2:\n      _tile_loadd(2, from, stride);\n      break;\n    case 3:\n      _tile_loadd(3, from, stride);\n      break;\n    case 4:\n      _tile_loadd(4, from, stride);\n      break;\n    case 5:\n      _tile_loadd(5, from, stride);\n      break;\n    case 6:\n      _tile_loadd(6, from, stride);\n      break;\n    case 7:\n      _tile_loadd(7, from, stride);\n      break;\n    default:\n      throw std::runtime_error(\"no such tile\");\n    }\n  }\n\n  static void store_data(int from, void *to, size_t stride) {\n    switch (from) {\n    case 0:\n      _tile_stored(0, to, stride);\n      break;\n    case 1:\n      _tile_stored(1, to, stride);\n      break;\n    case 2:\n      _tile_stored(2, to, stride);\n      break;\n    case 3:\n      _tile_stored(3, to, stride);\n      break;\n    case 4:\n      _tile_stored(4, to, stride);\n      break;\n    case 5:\n      _tile_stored(5, to, stride);\n      break;\n    case 6:\n      _tile_stored(6, to, stride);\n      break;\n    case 7:\n      _tile_stored(7, to, stride);\n      break;\n    default:\n      throw std::runtime_error(\"no such tile\");\n    }\n  }\n};\n\nstatic_assert(sizeof(TileConfig) == 64);\n\ninline void debug_tile(int t) {\n  printf(\"Tile %d\\n\", t);\n  uint8_t data[16][64] = {};\n  TileConfig::store_data(t, data, 64);\n  for (int i = 0; i < 16; i++) {\n    for (int j = 0; j < 64; j++) {\n      printf(\"%3d \", data[i][j]);\n    }\n    printf(\"\\n\");\n  }\n  printf(\"\\n\");\n}\n\ninline void 
debug_tiles(int to = 8) {\n  for (int i = 0; i < to; i++) {\n    debug_tile(i);\n  }\n}\n\ninline void debug_m512(__m512 x) {\n  float data[16];\n  _mm512_storeu_ps(data, x);\n  for (int i = 0; i < 16; i++) {\n    printf(\"%f \", data[i]);\n  }\n  printf(\"\\n\");\n}\n\n// transpose utils\ninline void transpose_16x16_32bit(__m512i *v) {\n  __m512i v1[16];\n  v1[0] = _mm512_unpacklo_epi32(v[0], v[1]);\n  v1[1] = _mm512_unpackhi_epi32(v[0], v[1]);\n  v1[2] = _mm512_unpacklo_epi32(v[2], v[3]);\n  v1[3] = _mm512_unpackhi_epi32(v[2], v[3]);\n  v1[4] = _mm512_unpacklo_epi32(v[4], v[5]);\n  v1[5] = _mm512_unpackhi_epi32(v[4], v[5]);\n  v1[6] = _mm512_unpacklo_epi32(v[6], v[7]);\n  v1[7] = _mm512_unpackhi_epi32(v[6], v[7]);\n  v1[8] = _mm512_unpacklo_epi32(v[8], v[9]);\n  v1[9] = _mm512_unpackhi_epi32(v[8], v[9]);\n  v1[10] = _mm512_unpacklo_epi32(v[10], v[11]);\n  v1[11] = _mm512_unpackhi_epi32(v[10], v[11]);\n  v1[12] = _mm512_unpacklo_epi32(v[12], v[13]);\n  v1[13] = _mm512_unpackhi_epi32(v[12], v[13]);\n  v1[14] = _mm512_unpacklo_epi32(v[14], v[15]);\n  v1[15] = _mm512_unpackhi_epi32(v[14], v[15]);\n\n  v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]);\n  v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]);\n  v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]);\n  v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]);\n  v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]);\n  v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]);\n  v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]);\n  v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]);\n  v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]);\n  v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]);\n  v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]);\n  v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]);\n  v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]);\n  v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]);\n  v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]);\n  v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]);\n\n  v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88);\n  v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88);\n  v1[2] = 
_mm512_shuffle_i32x4(v[2], v[6], 0x88);\n  v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88);\n  v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd);\n  v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd);\n  v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd);\n  v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd);\n  v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88);\n  v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88);\n  v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88);\n  v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88);\n  v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd);\n  v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd);\n  v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd);\n  v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd);\n\n  v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88);\n  v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88);\n  v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88);\n  v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88);\n  v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88);\n  v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88);\n  v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88);\n  v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88);\n  v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd);\n  v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd);\n  v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd);\n  v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd);\n  v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd);\n  v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd);\n  v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd);\n  v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd);\n}\n\n/*\n  Transpose 16x16 32-bit elements\n  Note that v must be 64 byte aligned\n*/\ninline void transpose_16x16_32bit(__m512i *v, size_t stride) {\n  assert(reinterpret_cast<intptr_t>(v) % 64 == 0 && \"v must be 64 aligned\");\n\n  auto stride_v = [=](int i) { return offset_pointer(v, i * stride); };\n  __m512i v1[16];\n\n  v1[0] = _mm512_unpacklo_epi32(*stride_v(0), *stride_v(1));\n  v1[1] = 
_mm512_unpackhi_epi32(*stride_v(0), *stride_v(1));\n  v1[2] = _mm512_unpacklo_epi32(*stride_v(2), *stride_v(3));\n  v1[3] = _mm512_unpackhi_epi32(*stride_v(2), *stride_v(3));\n  v1[4] = _mm512_unpacklo_epi32(*stride_v(4), *stride_v(5));\n  v1[5] = _mm512_unpackhi_epi32(*stride_v(4), *stride_v(5));\n  v1[6] = _mm512_unpacklo_epi32(*stride_v(6), *stride_v(7));\n  v1[7] = _mm512_unpackhi_epi32(*stride_v(6), *stride_v(7));\n  v1[8] = _mm512_unpacklo_epi32(*stride_v(8), *stride_v(9));\n  v1[9] = _mm512_unpackhi_epi32(*stride_v(8), *stride_v(9));\n  v1[10] = _mm512_unpacklo_epi32(*stride_v(10), *stride_v(11));\n  v1[11] = _mm512_unpackhi_epi32(*stride_v(10), *stride_v(11));\n  v1[12] = _mm512_unpacklo_epi32(*stride_v(12), *stride_v(13));\n  v1[13] = _mm512_unpackhi_epi32(*stride_v(12), *stride_v(13));\n  v1[14] = _mm512_unpacklo_epi32(*stride_v(14), *stride_v(15));\n  v1[15] = _mm512_unpackhi_epi32(*stride_v(14), *stride_v(15));\n\n  *stride_v(0) = _mm512_unpacklo_epi64(v1[0], v1[2]);\n  *stride_v(1) = _mm512_unpackhi_epi64(v1[0], v1[2]);\n  *stride_v(2) = _mm512_unpacklo_epi64(v1[1], v1[3]);\n  *stride_v(3) = _mm512_unpackhi_epi64(v1[1], v1[3]);\n  *stride_v(4) = _mm512_unpacklo_epi64(v1[4], v1[6]);\n  *stride_v(5) = _mm512_unpackhi_epi64(v1[4], v1[6]);\n  *stride_v(6) = _mm512_unpacklo_epi64(v1[5], v1[7]);\n  *stride_v(7) = _mm512_unpackhi_epi64(v1[5], v1[7]);\n  *stride_v(8) = _mm512_unpacklo_epi64(v1[8], v1[10]);\n  *stride_v(9) = _mm512_unpackhi_epi64(v1[8], v1[10]);\n  *stride_v(10) = _mm512_unpacklo_epi64(v1[9], v1[11]);\n  *stride_v(11) = _mm512_unpackhi_epi64(v1[9], v1[11]);\n  *stride_v(12) = _mm512_unpacklo_epi64(v1[12], v1[14]);\n  *stride_v(13) = _mm512_unpackhi_epi64(v1[12], v1[14]);\n  *stride_v(14) = _mm512_unpacklo_epi64(v1[13], v1[15]);\n  *stride_v(15) = _mm512_unpackhi_epi64(v1[13], v1[15]);\n\n  v1[0] = _mm512_shuffle_i32x4(*stride_v(0), *stride_v(4), 0x88);\n  v1[1] = _mm512_shuffle_i32x4(*stride_v(1), *stride_v(5), 0x88);\n  v1[2] = 
_mm512_shuffle_i32x4(*stride_v(2), *stride_v(6), 0x88);\n  v1[3] = _mm512_shuffle_i32x4(*stride_v(3), *stride_v(7), 0x88);\n  v1[4] = _mm512_shuffle_i32x4(*stride_v(0), *stride_v(4), 0xdd);\n  v1[5] = _mm512_shuffle_i32x4(*stride_v(1), *stride_v(5), 0xdd);\n  v1[6] = _mm512_shuffle_i32x4(*stride_v(2), *stride_v(6), 0xdd);\n  v1[7] = _mm512_shuffle_i32x4(*stride_v(3), *stride_v(7), 0xdd);\n  v1[8] = _mm512_shuffle_i32x4(*stride_v(8), *stride_v(12), 0x88);\n  v1[9] = _mm512_shuffle_i32x4(*stride_v(9), *stride_v(13), 0x88);\n  v1[10] = _mm512_shuffle_i32x4(*stride_v(10), *stride_v(14), 0x88);\n  v1[11] = _mm512_shuffle_i32x4(*stride_v(11), *stride_v(15), 0x88);\n  v1[12] = _mm512_shuffle_i32x4(*stride_v(8), *stride_v(12), 0xdd);\n  v1[13] = _mm512_shuffle_i32x4(*stride_v(9), *stride_v(13), 0xdd);\n  v1[14] = _mm512_shuffle_i32x4(*stride_v(10), *stride_v(14), 0xdd);\n  v1[15] = _mm512_shuffle_i32x4(*stride_v(11), *stride_v(15), 0xdd);\n\n  *stride_v(0) = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88);\n  *stride_v(1) = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88);\n  *stride_v(2) = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88);\n  *stride_v(3) = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88);\n  *stride_v(4) = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88);\n  *stride_v(5) = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88);\n  *stride_v(6) = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88);\n  *stride_v(7) = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88);\n  *stride_v(8) = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd);\n  *stride_v(9) = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd);\n  *stride_v(10) = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd);\n  *stride_v(11) = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd);\n  *stride_v(12) = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd);\n  *stride_v(13) = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd);\n  *stride_v(14) = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd);\n  *stride_v(15) = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd);\n}\n\nstruct GemmKernel224BF {\n  using dt = ggml_bf16_t;\n  using 
output_t = float;\n  static const int TILE_M = 16;\n  static const int TILE_K = 32;\n  static const int TILE_N = 16;\n  static const int VNNI_BLK = 2;\n\n  static inline constexpr int M_STEP = TILE_M * 2;\n  static inline constexpr int N_STEP = TILE_N * 2;\n  static inline constexpr int K_STEP = TILE_K;\n\n  static inline const int N_BLOCK = 256;\n  static inline const int K_BLOCK = 1792;\n\n  static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }\n\n  static std::pair<int, int> split_range_n(int n, int ith, int nth) {\n    int n_start = N_BLOCK * ith;\n    int n_end = std::min(n, N_BLOCK * (ith + 1));\n    return {n_start, n_end};\n  }\n\n  static void config() {\n    enable_amx();\n    TileConfig tile_config;\n\n    // size is 16 x 32\n    for (int i = 0; i < 2; i++)\n      tile_config.set_row_col(i, TILE_M, TILE_K * sizeof(dt));\n\n    // size is 16 x 32\n    for (int i = 2; i < 4; i++)\n      tile_config.set_row_col(i, TILE_K / VNNI_BLK, TILE_N * VNNI_BLK * sizeof(dt));\n\n    // size is 16 x 16\n    for (int i = 4; i < 8; i++)\n      tile_config.set_row_col(i, TILE_M, TILE_N * sizeof(output_t));\n\n    tile_config.set_config();\n  }\n\n  static void load_a(dt *a, size_t lda) {\n    _tile_loadd(0, a, lda);\n    _tile_loadd(1, offset_pointer(a, lda * TILE_M), lda);\n  }\n\n  static void load_b(dt *b, size_t ldb) {\n    _tile_loadd(2, b, ldb);\n    _tile_loadd(3, offset_pointer(b, ldb * TILE_N), ldb);\n  }\n\n  static void clean_c() {\n    _tile_zero(4);\n    _tile_zero(5);\n    _tile_zero(6);\n    _tile_zero(7);\n  }\n\n  static void load_c(output_t *c, size_t ldc) {\n    _tile_loadd(4, c, ldc);\n    _tile_loadd(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);\n    _tile_loadd(6, offset_pointer(c, ldc * TILE_M), ldc);\n    _tile_loadd(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc);\n  }\n\n  static void store_c(output_t *c, size_t ldc) {\n    _tile_stored(4, c, ldc);\n    _tile_stored(5, offset_pointer(c, TILE_N * 
sizeof(output_t)), ldc);\n    _tile_stored(6, offset_pointer(c, ldc * TILE_M), ldc);\n    _tile_stored(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc);\n  }\n\n  static void run_tile() {\n    _tile_dpbf16ps(4, 0, 2);\n    _tile_dpbf16ps(5, 0, 3);\n    _tile_dpbf16ps(6, 1, 2);\n    _tile_dpbf16ps(7, 1, 3);\n  }\n\n  struct BufferA {\n    ggml_bf16_t *a;\n    int max_m, k;\n\n    static size_t required_size(int max_m, int k) { return max_m * k * sizeof(ggml_bf16_t); }\n\n    BufferA(int max_m, int k, void *ptr) : max_m(max_m), k(k) {\n      assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n      assert(max_m % M_STEP == 0);\n      assert(k % K_STEP == 0);\n      a = reinterpret_cast<ggml_bf16_t *>(ptr);\n    }\n\n    void from_mat(int m, ggml_bf16_t *src, int ith, int nth) {\n      assert(m <= max_m);\n      assert(ith == 0 && nth == 1);\n      int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n      for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {\n        for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {\n          int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n          for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {\n            for (int i = 0; i < M_STEP && m_begin + i < m; i++) {\n              __m512i *s = (__m512i *)(src + (m_begin + i) * k + k_block_begin + k_begin);\n              __m512i *d = (__m512i *)(a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP +\n                                       i * K_STEP);\n              avx512_copy_32xbf16(s, d);\n            }\n          }\n        }\n      }\n    }\n\n    ggml_bf16_t *get_submat(int m, int k, int m_begin, int k_begin) {\n      int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n      int k_block_begin = k_begin / K_BLOCK * K_BLOCK;\n      k_begin -= k_block_begin;\n      int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n      return a + k_block_begin * m_block_size + m_begin 
* k_block_size + k_begin * M_STEP;\n    }\n  };\n\n  struct BufferB {\n    ggml_bf16_t *b;\n    int n, k;\n\n    static size_t required_size(int n, int k) { return n * k * sizeof(ggml_bf16_t); }\n\n    BufferB(int n, int k, void *ptr) : n(n), k(k) {\n      assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n      assert(n % N_STEP == 0);\n      assert(k % K_STEP == 0);\n      b = reinterpret_cast<ggml_bf16_t *>(ptr);\n    }\n\n    void from_mat(ggml_bf16_t *src, int ith, int nth) {\n      auto [n_start, n_end] = split_range_n(n, ith, nth);\n      int n_block_begin = n_start;\n      int n_block_size = n_end - n_block_begin;\n      for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n        for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {\n          int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n          for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {\n            for (int i = 0; i < N_STEP; i++) {\n              __m512i *s = (__m512i *)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin);\n              __m512i *d = (__m512i *)(b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size +\n                                       k_begin * N_STEP + i * K_STEP);\n              avx512_copy_32xbf16(s, d);\n            }\n            transpose_16x16_32bit((__m512i *)(b + n_block_begin * k + k_block_begin * n_block_size +\n                                              n_begin * k_block_size + k_begin * N_STEP));\n            transpose_16x16_32bit((__m512i *)(b + n_block_begin * k + k_block_begin * n_block_size +\n                                              n_begin * k_block_size + k_begin * N_STEP + TILE_N * K_STEP));\n          }\n        }\n      }\n    }\n\n    ggml_bf16_t *get_submat(int n, int k, int n_begin, int k_begin) {\n      int n_block_begin = n_begin / N_BLOCK * N_BLOCK;\n      n_begin -= n_block_begin;\n      int n_block_size = std::min(N_BLOCK, n - 
n_block_begin);\n      int k_block_begin = k_begin / K_BLOCK * K_BLOCK;\n      k_begin -= k_block_begin;\n      int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n      return b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP;\n    }\n  };\n\n  struct BufferC {\n    float *c;\n    int max_m, n;\n\n    static size_t required_size(int max_m, int n) { return max_m * n * sizeof(float); }\n\n    BufferC(int max_m, int n, void *ptr) : max_m(max_m), n(n) {\n      assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n      assert(max_m % M_STEP == 0);\n      assert(n % N_STEP == 0);\n      c = reinterpret_cast<float *>(ptr);\n    }\n\n    void to_mat(int m, ggml_bf16_t *dst, int ith, int nth) {\n      assert(m <= max_m);\n      auto [n_start, n_end] = split_range_n(n, ith, nth);\n      int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n      int n_block_begin = n_start;\n      int n_block_size = n_end - n_block_begin;\n      for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {\n        for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n          for (int i = 0; i < M_STEP && m_begin + i < m; i++) {\n            __m512 *x0 =\n                (__m512 *)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP);\n            __m512 *x1 = (__m512 *)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP +\n                                    i * N_STEP + 16);\n            avx512_32xfp32_to_32xbf16(x0, x1, (__m512i *)(dst + (m_begin + i) * n + n_block_begin + n_begin));\n          }\n        }\n      }\n    }\n\n    float *get_submat(int m, int n, int m_begin, int n_begin) {\n      int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n      int n_block_begin = n_begin / N_BLOCK * N_BLOCK;\n      int n_block_size = std::min(N_BLOCK, n - n_block_begin);\n      n_begin -= n_block_begin;\n      return c + m_block_size * n_block_begin + m_begin * 
n_block_size + n_begin * M_STEP;\n    }\n  };\n};\n\nstruct GemmKernel224Int8 {\n  using dt = int8_t;\n  using output_t = int32_t;\n  static const int TILE_M = 16;\n  static const int TILE_K = 64;\n  static const int TILE_N = 16;\n  static const int VNNI_BLK = 4;\n\n  static inline constexpr int M_STEP = TILE_M * 2;\n  static inline constexpr int N_STEP = TILE_N * 2;\n  static inline constexpr int K_STEP = TILE_K;\n\n  static inline const int N_BLOCK = 256;\n  static inline const int K_BLOCK = 3584;\n\n  static int recommended_nth(int n) { return (n + N_BLOCK - 1) / N_BLOCK; }\n\n  static std::pair<int, int> split_range_n(int n, int ith, int nth) {\n    int n_start = N_BLOCK * ith;\n    int n_end = std::min(n, N_BLOCK * (ith + 1));\n    return {n_start, n_end};\n  }\n\n  static void config() {\n    enable_amx();\n    TileConfig tile_config;\n\n    // size is 16 x 64\n    for (int i = 0; i < 2; i++)\n      tile_config.set_row_col(i, TILE_M, TILE_K * sizeof(dt));\n\n    // size is 16 x 64\n    for (int i = 2; i < 4; i++)\n      tile_config.set_row_col(i, TILE_K / VNNI_BLK, TILE_N * VNNI_BLK * sizeof(dt));\n\n    // size is 16 x 16\n    for (int i = 4; i < 8; i++) \n      tile_config.set_row_col(i, TILE_M, TILE_N * sizeof(output_t));\n\n    tile_config.set_config();\n  }\n\n  static void load_a(dt *a, size_t lda) {\n    _tile_loadd(0, a, lda);\n    _tile_loadd(1, offset_pointer(a, lda * TILE_M), lda);\n  }\n\n  static void load_b(dt *b, size_t ldb) {\n    _tile_loadd(2, b, ldb);\n    _tile_loadd(3, offset_pointer(b, ldb * TILE_N), ldb);\n  }\n\n  static void clean_c() {\n    _tile_zero(4);\n    _tile_zero(5);\n    _tile_zero(6);\n    _tile_zero(7);\n  }\n\n  static void load_c(output_t *c, size_t ldc) {\n    _tile_loadd(4, c, ldc);\n    _tile_loadd(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);\n    _tile_loadd(6, offset_pointer(c, ldc * TILE_M), ldc);\n    _tile_loadd(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc);\n  }\n\n  static void 
store_c(output_t *c, size_t ldc) {\n    _tile_stored(4, c, ldc);\n    _tile_stored(5, offset_pointer(c, TILE_N * sizeof(output_t)), ldc);\n    _tile_stored(6, offset_pointer(c, ldc * TILE_M), ldc);\n    _tile_stored(7, offset_pointer(c, ldc * TILE_M + TILE_N * sizeof(output_t)), ldc);\n  }\n\n  static void run_tile() {\n    _tile_dpbssd(4, 0, 2);\n    _tile_dpbssd(5, 0, 3);\n    _tile_dpbssd(6, 1, 2);\n    _tile_dpbssd(7, 1, 3);\n  }\n\n  struct BufferA {\n    int8_t *a;\n    float *d;\n    int max_m, k;\n\n    static size_t required_size(int max_m, int k) { return max_m * k * sizeof(int8_t) + max_m * sizeof(float); }\n\n    BufferA(int max_m, int k, void *ptr) : max_m(max_m), k(k) {\n      assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n      assert(max_m % M_STEP == 0);\n      assert(k % K_STEP == 0);\n      a = reinterpret_cast<int8_t *>(ptr);\n      d = reinterpret_cast<float *>(a + max_m * k);\n    }\n\n    void from_mat(int m, ggml_bf16_t *src, int ith, int nth) {\n      assert(m <= max_m);\n      assert(ith == 0 && nth == 1);\n      for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {\n        for (int i = 0; i < M_STEP && m_begin + i < m; i++) {\n          float amax = 0.0f;\n          for (int j = 0; j < k; j += 32) {\n            __m512 f0, f1;\n            avx512_32xbf16_to_32xfp32((__m512i *)(src + (m_begin + i) * k + j), &f0, &f1);\n            amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f0)));\n            amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f1)));\n          }\n          d[m_begin + i] = amax / ((1 << 7) - 1);\n        }\n      }\n      int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n      for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {\n        for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {\n          int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n          for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {\n            for (int i = 0; i < M_STEP && 
m_begin + i < m; i++) {\n              __m512 id = _mm512_set1_ps(d[m_begin + i] ? 1.0f / d[m_begin + i] : 0.0f);\n              int8_t *dst = a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP + i * K_STEP;\n              __m512 f0, f1, f2, f3;\n              avx512_32xbf16_to_32xfp32((__m512i *)(src + (m_begin + i) * k + k_block_begin + k_begin), &f0, &f1);\n              avx512_32xbf16_to_32xfp32((__m512i *)(src + (m_begin + i) * k + k_block_begin + k_begin) + 1, &f2, &f3);\n              __m512i i0 = _mm512_cvtps_epi32(_mm512_mul_ps(f0, id));\n              __m512i i1 = _mm512_cvtps_epi32(_mm512_mul_ps(f1, id));\n              __m512i i2 = _mm512_cvtps_epi32(_mm512_mul_ps(f2, id));\n              __m512i i3 = _mm512_cvtps_epi32(_mm512_mul_ps(f3, id));\n              __m128i s0 = _mm512_cvtsepi32_epi8(i0);\n              __m128i s1 = _mm512_cvtsepi32_epi8(i1);\n              __m128i s2 = _mm512_cvtsepi32_epi8(i2);\n              __m128i s3 = _mm512_cvtsepi32_epi8(i3);\n              _mm_storeu_si128((__m128i *)dst, s0);\n              _mm_storeu_si128((__m128i *)(dst + 16), s1);\n              _mm_storeu_si128((__m128i *)(dst + 32), s2);\n              _mm_storeu_si128((__m128i *)(dst + 48), s3);\n            }\n          }\n        }\n      }\n    }\n\n    int8_t *get_submat(int m, int k, int m_begin, int k_begin) {\n      int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n      int k_block_begin = k_begin / K_BLOCK * K_BLOCK;\n      k_begin -= k_block_begin;\n      int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n      return a + k_block_begin * m_block_size + m_begin * k_block_size + k_begin * M_STEP;\n    }\n\n    float *get_scale(int m, int m_begin) { return d + m_begin; }\n  };\n\n  struct BufferB {\n    int8_t *b;\n    float *d;\n    int n, k;\n\n    static size_t required_size(int n, int k) { return n * k * sizeof(int8_t) + n * sizeof(float); }\n\n    BufferB(int n, int k, void *ptr) : n(n), k(k) {\n      
assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n      assert(n % N_STEP == 0);\n      assert(k % K_STEP == 0);\n      b = reinterpret_cast<int8_t *>(ptr);\n      d = reinterpret_cast<float *>(b + n * k);\n    }\n\n    void from_mat(ggml_bf16_t *src, int ith, int nth) {\n      auto [n_start, n_end] = split_range_n(n, ith, nth);\n      int n_block_begin = n_start;\n      int n_block_size = n_end - n_block_begin;\n      for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n        for (int i = 0; i < N_STEP; i++) {\n          float amax = 0.0f;\n          for (int j = 0; j < k; j += 32) {\n            __m512 f0, f1;\n            avx512_32xbf16_to_32xfp32((__m512i *)(src + (n_block_begin + n_begin + i) * k + j), &f0, &f1);\n            amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f0)));\n            amax = MAX(amax, _mm512_reduce_max_ps(_mm512_abs_ps(f1)));\n          }\n          d[n_block_begin + n_begin + i] = amax / ((1 << 7) - 1);\n        }\n      }\n      for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n        for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K_BLOCK) {\n          int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n          for (int k_begin = 0; k_begin < k_block_size; k_begin += K_STEP) {\n            for (int i = 0; i < N_STEP; i++) {\n              __m512 id = _mm512_set1_ps(d[n_block_begin + n_begin + i] ? 
1.0f / d[n_block_begin + n_begin + i] : 0.0f);\n              int8_t *dst = b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size +\n                            k_begin * N_STEP + i * K_STEP;\n              __m512 f0, f1, f2, f3;\n              avx512_32xbf16_to_32xfp32((__m512i *)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin),\n                                        &f0, &f1);\n              avx512_32xbf16_to_32xfp32(\n                  (__m512i *)(src + (n_block_begin + n_begin + i) * k + k_block_begin + k_begin) + 1, &f2, &f3);\n              __m512i i0 = _mm512_cvtps_epi32(_mm512_mul_ps(f0, id));\n              __m512i i1 = _mm512_cvtps_epi32(_mm512_mul_ps(f1, id));\n              __m512i i2 = _mm512_cvtps_epi32(_mm512_mul_ps(f2, id));\n              __m512i i3 = _mm512_cvtps_epi32(_mm512_mul_ps(f3, id));\n              __m128i s0 = _mm512_cvtsepi32_epi8(i0);\n              __m128i s1 = _mm512_cvtsepi32_epi8(i1);\n              __m128i s2 = _mm512_cvtsepi32_epi8(i2);\n              __m128i s3 = _mm512_cvtsepi32_epi8(i3);\n              _mm_storeu_si128((__m128i *)dst, s0);\n              _mm_storeu_si128((__m128i *)(dst + 16), s1);\n              _mm_storeu_si128((__m128i *)(dst + 32), s2);\n              _mm_storeu_si128((__m128i *)(dst + 48), s3);\n            }\n            transpose_16x16_32bit((__m512i *)(b + n_block_begin * k + k_block_begin * n_block_size +\n                                              n_begin * k_block_size + k_begin * N_STEP));\n            transpose_16x16_32bit((__m512i *)(b + n_block_begin * k + k_block_begin * n_block_size +\n                                              n_begin * k_block_size + k_begin * N_STEP + TILE_N * K_STEP));\n          }\n        }\n      }\n    }\n\n    int8_t *get_submat(int n, int k, int n_begin, int k_begin) {\n      int n_block_begin = n_begin / N_BLOCK * N_BLOCK;\n      n_begin -= n_block_begin;\n      int n_block_size = std::min(N_BLOCK, n - 
n_block_begin);\n      int k_block_begin = k_begin / K_BLOCK * K_BLOCK;\n      k_begin -= k_block_begin;\n      int k_block_size = std::min(K_BLOCK, k - k_block_begin);\n      return b + n_block_begin * k + k_block_begin * n_block_size + n_begin * k_block_size + k_begin * N_STEP;\n    }\n\n    float *get_scale(int n, int n_begin) { return d + n_begin; }\n  };\n\n  struct BufferC {\n    float *c;\n    int max_m, n;\n\n    static size_t required_size(int max_m, int n) { return max_m * n * sizeof(float); }\n\n    BufferC(int max_m, int n, void *ptr) : max_m(max_m), n(n) {\n      assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n      assert(max_m % M_STEP == 0);\n      assert(n % N_STEP == 0);\n      c = reinterpret_cast<float *>(ptr);\n    }\n\n    void to_mat(int m, ggml_bf16_t *dst, int ith, int nth) {\n      assert(m <= max_m);\n      auto [n_start, n_end] = split_range_n(n, ith, nth);\n      int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n      int n_block_begin = n_start;\n      int n_block_size = n_end - n_block_begin;\n      for (int m_begin = 0; m_begin < m; m_begin += M_STEP) {\n        for (int n_begin = 0; n_begin < n_block_size; n_begin += N_STEP) {\n          for (int i = 0; i < M_STEP && m_begin + i < m; i++) {\n            __m512 *x0 =\n                (__m512 *)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP + i * N_STEP);\n            __m512 *x1 = (__m512 *)(c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP +\n                                    i * N_STEP + 16);\n            avx512_32xfp32_to_32xbf16(x0, x1, (__m512i *)(dst + (m_begin + i) * n + n_block_begin + n_begin));\n          }\n        }\n      }\n    }\n\n    float *get_submat(int m, int n, int m_begin, int n_begin) {\n      int m_block_size = (m + M_STEP - 1) / M_STEP * M_STEP;\n      int n_block_begin = n_begin / N_BLOCK * N_BLOCK;\n      int n_block_size = std::min(N_BLOCK, n - n_block_begin);\n      n_begin -= 
n_block_begin;\n      return c + m_block_size * n_block_begin + m_begin * n_block_size + n_begin * M_STEP;\n    }\n  };\n};\n\ninline void mat_mul(int m, int n, int k, std::shared_ptr<GemmKernel224BF::BufferA> ba,\n                    std::shared_ptr<GemmKernel224BF::BufferB> bb, std::shared_ptr<GemmKernel224BF::BufferC> bc, int ith,\n                    int nth, bool use_amx) {\n//   std::cout << \"mat_mul in BF16!!!!\" << std::endl;\n  using K = GemmKernel224BF;\n  assert(n % K::N_STEP == 0);\n  assert(k % K::K_STEP == 0);\n\n  auto [n_start, n_end] = K::split_range_n(n, ith, nth);\n\n  for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K::K_BLOCK) {\n    for (int m_begin = 0; m_begin < m; m_begin += K::M_STEP) {\n      for (int n_begin = n_start; n_begin < n_end; n_begin += K::N_STEP) {\n\n        float *c = bc->get_submat(m, n, m_begin, n_begin);\n        if (!use_amx) {\n          __m512 *c512 = (__m512 *)c;\n          if (k_block_begin == 0) {\n            for (int m_i = 0; m_i < m && m_i < K::M_STEP; m_i++) {\n              c512[m_i * 2] = _mm512_setzero_ps();\n              c512[m_i * 2 + 1] = _mm512_setzero_ps();\n            }\n          }\n\n          for (int k_begin = 0; k_begin < K::K_BLOCK && k_block_begin + k_begin < k; k_begin += K::K_STEP) {\n            int32_t *a32 = (int32_t *)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);\n            __m512bh *b512 = (__m512bh *)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);\n            for (int m_i = 0; m_i < m && m_i < K::M_STEP; m_i++) {\n              for (int k_i = 0; k_i < 16; k_i++) {\n                __m512bh ma = (__m512bh)_mm512_set1_epi32(a32[m_i * 16 + k_i]);\n                for (int n_i = 0; n_i < 2; n_i++) {\n                  c512[m_i * 2 + n_i] = _mm512_dpbf16_ps(c512[m_i * 2 + n_i], ma, b512[n_i * 16 + k_i]);\n                }\n              }\n            }\n          }\n\n        } else {\n          if (k_block_begin == 0) {\n            K::clean_c();\n     
     } else {\n            K::load_c(c, K::N_STEP * sizeof(float));\n          }\n          for (int k_begin = 0; k_begin < K::K_BLOCK && k_block_begin + k_begin < k; k_begin += K::K_STEP) {\n            K::load_a(ba->get_submat(m, k, m_begin, k_block_begin + k_begin), K::K_STEP * sizeof(ggml_bf16_t));\n            K::load_b(bb->get_submat(n, k, n_begin, k_block_begin + k_begin), K::K_STEP * sizeof(ggml_bf16_t));\n            K::run_tile();\n          }\n          K::store_c(c, K::N_STEP * sizeof(float));\n        }\n      }\n    }\n  }\n}\n\ninline __m512i _mm512_dpbssd_epi32(__m512i src, __m512i a, __m512i b) {\n  __m256i a_lo = _mm512_extracti64x4_epi64(a, 0);\n  __m256i a_hi = _mm512_extracti64x4_epi64(a, 1);\n  __m256i b_lo = _mm512_extracti64x4_epi64(b, 0);\n  __m256i b_hi = _mm512_extracti64x4_epi64(b, 1);\n\n  b_lo = _mm256_sign_epi8(b_lo, a_lo);\n  b_hi = _mm256_sign_epi8(b_hi, a_hi);\n\n  b = _mm512_inserti64x4(b, b_lo, 0);\n  b = _mm512_inserti64x4(b, b_hi, 1);\n\n  a = _mm512_abs_epi8(a);\n\n  return _mm512_dpbusd_epi32(src, a, b);\n}\n\ninline void mat_mul(int m, int n, int k, std::shared_ptr<GemmKernel224Int8::BufferA> ba,\n                    std::shared_ptr<GemmKernel224Int8::BufferB> bb, std::shared_ptr<GemmKernel224Int8::BufferC> bc,\n                    int ith, int nth, bool use_amx) {\n//   std::cout << \"mat_mul in INT8!!!!\" << std::endl;\n  using K = GemmKernel224Int8;\n  assert(n % K::N_STEP == 0);\n  assert(k % K::K_STEP == 0);\n\n  auto [n_start, n_end] = K::split_range_n(n, ith, nth);\n\n  for (int k_block_begin = 0; k_block_begin < k; k_block_begin += K::K_BLOCK) {\n    for (int m_begin = 0; m_begin < m; m_begin += K::M_STEP) {\n      for (int n_begin = n_start; n_begin < n_end; n_begin += K::N_STEP) {\n        float *c = bc->get_submat(m, n, m_begin, n_begin);\n\n        if (!use_amx) {\n          __m512i *c512 = (__m512i *)c;\n          if (k_block_begin == 0) {\n            for (int m_i = 0; m_i < m && m_i < K::M_STEP; m_i++) {\n     
         c512[m_i * 2] = _mm512_setzero_si512();\n              c512[m_i * 2 + 1] = _mm512_setzero_si512();\n            }\n          }\n\n          for (int k_begin = 0; k_begin < K::K_BLOCK && k_block_begin + k_begin < k; k_begin += K::K_STEP) {\n            static_assert(K::K_STEP * sizeof(int8_t) == sizeof(__m512i));\n            static_assert(K::N_STEP / K::TILE_N == 2, \"Must be like this\");\n\n            int32_t *a32 = (int32_t *)ba->get_submat(m, k, m_begin, k_block_begin + k_begin);\n            __m512i *b512 = (__m512i *)bb->get_submat(n, k, n_begin, k_block_begin + k_begin);\n            for (int m_i = 0; m_i < m && m_i < K::M_STEP; m_i++) {\n              for (int k_i = 0; k_i < 16; k_i++) {\n                __m512i ma = _mm512_set1_epi32(a32[m_i * 16 + k_i]);\n                for (int n_i = 0; n_i < 2; n_i++) {\n                  c512[m_i * 2 + n_i] = _mm512_dpbssd_epi32(c512[m_i * 2 + n_i], ma, b512[n_i * 16 + k_i]);\n                }\n              }\n            }\n          }\n        } else {\n          if (k_block_begin == 0) {\n            K::clean_c();\n          } else {\n            K::load_c((int32_t *)c, K::N_STEP * sizeof(int32_t));\n          }\n          for (int k_begin = 0; k_begin < K::K_BLOCK && k_block_begin + k_begin < k; k_begin += K::K_STEP) {\n            K::load_a(ba->get_submat(m, k, m_begin, k_block_begin + k_begin), K::K_STEP * sizeof(int8_t));\n            K::load_b(bb->get_submat(n, k, n_begin, k_block_begin + k_begin), K::K_STEP * sizeof(int8_t));\n            K::run_tile();\n          }\n          K::store_c((int32_t *)c, K::N_STEP * sizeof(int32_t));\n        }\n\n        if (k_block_begin + K::K_BLOCK >= k) {\n          int to = m - m_begin;\n          if (m - m_begin > K::M_STEP) {\n            to = K::M_STEP;\n          }\n          for (int i = 0; i < to; i++) {\n            __m512 as = _mm512_set1_ps(*ba->get_scale(m, m_begin + i));\n            __m512 bs = _mm512_load_ps(bb->get_scale(n, n_begin));\n            
__m512i now = _mm512_load_si512((__m512i *)(c + i * K::N_STEP));\n            __m512 result = _mm512_mul_ps(_mm512_mul_ps(as, bs), _mm512_cvtepi32_ps(now));\n            _mm512_store_ps((__m512 *)(c + i * K::N_STEP), result);\n            bs = _mm512_load_ps(bb->get_scale(n, n_begin) + K::TILE_N);\n            now = _mm512_load_si512((__m512i *)(c + i * K::N_STEP + K::TILE_N));\n            result = _mm512_mul_ps(_mm512_mul_ps(as, bs), _mm512_cvtepi32_ps(now));\n            _mm512_store_ps((__m512 *)(c + i * K::N_STEP + K::TILE_N), result);\n          }\n        }\n      }\n    }\n  }\n}\n\n} // namespace amx"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/operators/amx/la/utils.hpp",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2025-04-25 18:28:12\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2025-04-25 18:28:12\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n\n#pragma once\n#include <cstdint>\n\n\ntemplate <typename T>\nT* offset_pointer(T* ptr, std::size_t byte_offset) {\n  return reinterpret_cast<T*>(reinterpret_cast<char*>(ptr) + byte_offset);\n}\n\ntemplate <typename T>\nconst T* offset_pointer(const T* ptr, std::size_t byte_offset) {\n  return reinterpret_cast<const T*>(reinterpret_cast<const char*>(ptr) + byte_offset);\n}\n\ntemplate <typename T>\nT* offset_pointer_row_major(T* t, int row, int col, std::size_t ld) {\n  return offset_pointer(t, row * ld) + col;\n}\n\ntemplate <typename T>\nT* offset_pointer_col_major(T* t, int row, int col, std::size_t ld) {\n  return offset_pointer(t, col * ld) + row;\n}\n\nstatic inline void avx512_copy_32xbf16(__m512i* src, __m512i* dst) {\n  _mm512_storeu_si512(dst, _mm512_loadu_si512(src));\n}\n\nstatic inline void avx512_32xfp32_to_32xbf16(__m512* src0, __m512* src1, __m512i* dst) {\n  _mm512_storeu_si512(dst, __m512i(_mm512_cvtne2ps_pbh(*src1, *src0)));\n}\n\nstatic inline void avx512_32xbf16_to_32xfp32(__m512i* src, __m512* dst0, __m512* dst1) {\n  _mm512_storeu_ps(dst0, _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(src))), 16)));\n  _mm512_storeu_ps(dst1, _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(src) + 1)), 16)));\n}"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/operators/amx/moe.hpp",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2025-04-25 18:28:12\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2025-04-25 18:28:12\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#ifndef CPUINFER_OPERATOR_AMX_MOE_H\n#define CPUINFER_OPERATOR_AMX_MOE_H\n\n#include <cmath>\n#include <cstdio>\n#include <functional>\n#include <mutex>\n#include <vector>\n\n#include \"../../cpu_backend/backend.h\"\n#include \"../../cpu_backend/shared_mem_buffer.h\"\n#include \"llama.cpp/ggml-impl.h\"\n#include \"llama.cpp/ggml-quants.h\"\n#include \"llama.cpp/ggml.h\"\n#include \"llamafile/sgemm.h\"\n\n#include \"la/amx.hpp\"\n\n#ifdef USE_NUMA\n#include <numa.h>\n#include <numaif.h>\nvoid *numa_alloc_aligned(size_t size, int node, size_t alignment) {\n  void *ptr = numa_alloc_onnode(size, node);\n  assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n  return ptr;\n}\n#endif\n\nstatic inline __m512 exp_avx512(__m512 x) {\n  const __m512 log2e = _mm512_set1_ps(1.44269504089f);\n  const __m512 c1 = _mm512_set1_ps(0.69314718056f);\n\n  __m512 y = _mm512_mul_ps(x, log2e);\n  __m512i int_part = _mm512_cvtps_epi32(y);\n  __m512 frac_part = _mm512_sub_ps(y, _mm512_cvtepi32_ps(int_part));\n\n  const __m512 poly_1 = _mm512_set1_ps(0.9999999995f);\n  const __m512 poly_2 = _mm512_set1_ps(0.6931471805f);\n  const __m512 poly_3 = _mm512_set1_ps(0.2402265069f);\n  const __m512 poly_4 = _mm512_set1_ps(0.0555041087f);\n  const __m512 poly_5 = _mm512_set1_ps(0.0096181291f);\n  const __m512 poly_6 = _mm512_set1_ps(0.0013333558f);\n\n  __m512 frac_exp = _mm512_fmadd_ps(\n      _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(poly_6, frac_part, poly_5), frac_part, poly_4),\n                                      frac_part, poly_3),\n                      frac_part, poly_2),\n      frac_part, poly_1);\n\n  __m512 two_pow_i = _mm512_scalef_ps(_mm512_set1_ps(1.0f), _mm512_cvtepi32_ps(int_part));\n  
return _mm512_mul_ps(two_pow_i, frac_exp);\n}\n\nstatic inline __m512 act_fn(__m512 gate_val, __m512 up_val) {\n  __m512 neg_gate_val = _mm512_sub_ps(_mm512_setzero_ps(), gate_val);\n  __m512 exp_neg_gate = exp_avx512(neg_gate_val);\n  __m512 denom = _mm512_add_ps(_mm512_set1_ps(1.0f), exp_neg_gate);\n  __m512 act_val = _mm512_div_ps(gate_val, denom);\n\n  return _mm512_mul_ps(act_val, up_val);\n}\n\nstruct AMX_MOEConfig {\n  int expert_num;\n  int routed_expert_num;\n  int hidden_size;\n  int intermediate_size;\n  int max_len;\n  void *gate_proj;\n  void *up_proj;\n  void *down_proj;\n\n  AMX_MOEConfig() {}\n\n  AMX_MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int max_len,\n                void *gate_proj, void *up_proj, void *down_proj)\n      : expert_num(expert_num), routed_expert_num(routed_expert_num), hidden_size(hidden_size),\n        intermediate_size(intermediate_size), max_len(max_len), gate_proj(gate_proj), up_proj(up_proj),\n        down_proj(down_proj) {}\n};\n\ntemplate <class T> class AMX_MOE {\nprivate:\n  AMX_MOEConfig config_;\n  void *gate_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]\n  void *up_proj_;   // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]\n  void *down_proj_; // [expert_num * hidden_size * intermediate_size ( /32 if quantized)]\n\n  ggml_bf16_t *m_local_input_;       // [routed_expert_num * max_len * hidden_size]\n  ggml_bf16_t *m_local_gate_output_; // [routed_expert_num * max_len * intermediate_size]\n  ggml_bf16_t *m_local_up_output_;   // [routed_expert_num * max_len * intermediate_size]\n  ggml_bf16_t *m_local_down_output_; // [routed_expert_num * max_len * hidden_size]\n\n  std::vector<std::vector<int>> m_local_pos_;          // [max_len, routed_expert_num]\n  std::vector<int> m_local_num_;                       // [expert_num]\n  std::vector<int> m_expert_id_map_;                   // [expert_num]\n  std::vector<ggml_bf16_t *> 
m_local_input_ptr_;       // [expert_num]\n  std::vector<ggml_bf16_t *> m_local_gate_output_ptr_; // [expert_num]\n  std::vector<ggml_bf16_t *> m_local_up_output_ptr_;   // [expert_num]\n  std::vector<ggml_bf16_t *> m_local_down_output_ptr_; // [expert_num]\n\n  std::vector<std::shared_ptr<typename T::BufferA>> gate_up_ba_;\n  std::vector<std::shared_ptr<typename T::BufferC>> gate_bc_;\n  std::vector<std::shared_ptr<typename T::BufferC>> up_bc_;\n  std::vector<std::shared_ptr<typename T::BufferA>> down_ba_;\n  std::vector<std::shared_ptr<typename T::BufferC>> down_bc_;\n\n#ifdef USE_NUMA\n  std::vector<std::vector<std::shared_ptr<typename T::BufferB>>> gate_bb_numa_;\n  std::vector<std::vector<std::shared_ptr<typename T::BufferB>>> up_bb_numa_;\n  std::vector<std::vector<std::shared_ptr<typename T::BufferB>>> down_bb_numa_;\n#else\n  std::vector<std::shared_ptr<typename T::BufferB>> gate_bb_;\n  std::vector<std::shared_ptr<typename T::BufferB>> up_bb_;\n  std::vector<std::shared_ptr<typename T::BufferB>> down_bb_;\n#endif\n\npublic:\n  AMX_MOE(AMX_MOEConfig config) {\n    config_ = config;\n    gate_proj_ = config_.gate_proj;\n    up_proj_ = config_.up_proj;\n    down_proj_ = config_.down_proj;\n\n    std::vector<std::pair<void **, uint64_t>> m_mem_requests;\n    m_mem_requests.push_back({(void **)&m_local_input_,\n                              sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.hidden_size});\n    m_mem_requests.push_back({(void **)&m_local_gate_output_, sizeof(ggml_bf16_t) * config_.routed_expert_num *\n                                                                  config_.max_len * config_.intermediate_size});\n    m_mem_requests.push_back({(void **)&m_local_up_output_, sizeof(ggml_bf16_t) * config_.routed_expert_num *\n                                                                config_.max_len * config_.intermediate_size});\n    m_mem_requests.push_back({(void **)&m_local_down_output_,\n                            
  sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.hidden_size});\n    std::vector<void *> gate_up_ba_ptr(config_.expert_num);\n    std::vector<void *> gate_bc_ptr(config_.expert_num);\n    std::vector<void *> up_bc_ptr(config_.expert_num);\n    std::vector<void *> down_ba_ptr(config_.expert_num);\n    std::vector<void *> down_bc_ptr(config_.expert_num);\n    for (int i = 0; i < config_.expert_num; i++) {\n      m_mem_requests.push_back(\n          {(void **)&gate_up_ba_ptr[i], T::BufferA::required_size(config_.max_len, config_.hidden_size)});\n      m_mem_requests.push_back(\n          {(void **)&gate_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.intermediate_size)});\n      m_mem_requests.push_back(\n          {(void **)&up_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.intermediate_size)});\n      m_mem_requests.push_back(\n          {(void **)&down_ba_ptr[i], T::BufferA::required_size(config_.max_len, config_.intermediate_size)});\n      m_mem_requests.push_back(\n          {(void **)&down_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.hidden_size)});\n    }\n    shared_mem_buffer.alloc(this, m_mem_requests);\n\n    m_local_pos_.resize(config_.max_len);\n    for (int i = 0; i < config_.max_len; i++) {\n      m_local_pos_[i].resize(config_.routed_expert_num);\n    }\n    m_expert_id_map_.resize(config_.expert_num);\n    m_local_num_.resize(config_.expert_num);\n    m_local_input_ptr_.resize(config_.expert_num);\n    m_local_gate_output_ptr_.resize(config_.expert_num);\n    m_local_up_output_ptr_.resize(config_.expert_num);\n    m_local_down_output_ptr_.resize(config_.expert_num);\n\n    for (uint64_t i = 0; i < config_.expert_num; i++) {\n      gate_up_ba_.push_back(\n          std::make_shared<typename T::BufferA>(config_.max_len, config_.hidden_size, gate_up_ba_ptr[i]));\n      gate_bc_.push_back(\n          std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, 
gate_bc_ptr[i]));\n      up_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, up_bc_ptr[i]));\n      down_ba_.push_back(\n          std::make_shared<typename T::BufferA>(config_.max_len, config_.intermediate_size, down_ba_ptr[i]));\n      down_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.hidden_size, down_bc_ptr[i]));\n\n#ifdef USE_NUMA\n      int numa_nodes = numa_num_configured_nodes();\n      gate_bb_numa_.resize(numa_nodes);\n      up_bb_numa_.resize(numa_nodes);\n      down_bb_numa_.resize(numa_nodes);\n      for (int j = 0; j < numa_nodes; j++) {\n        void *gate_bb_ptr =\n            numa_alloc_aligned(T::BufferB::required_size(config_.intermediate_size, config_.hidden_size), j, 64);\n        gate_bb_numa_[j].push_back(\n            std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, gate_bb_ptr));\n        void *up_bb_ptr =\n            numa_alloc_aligned(T::BufferB::required_size(config_.intermediate_size, config_.hidden_size), j, 64);\n        up_bb_numa_[j].push_back(\n            std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, up_bb_ptr));\n        void *down_bb_ptr =\n            numa_alloc_aligned(T::BufferB::required_size(config_.hidden_size, config_.intermediate_size), j, 64);\n        down_bb_numa_[j].push_back(\n            std::make_shared<typename T::BufferB>(config_.hidden_size, config_.intermediate_size, down_bb_ptr));\n      }\n#else\n      void *gate_bb_ptr =\n          std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size));\n      gate_bb_.push_back(\n          std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, gate_bb_ptr));\n\n      void *up_bb_ptr =\n          std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size));\n      up_bb_.push_back(\n          
std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, up_bb_ptr));\n\n      void *down_bb_ptr =\n          std::aligned_alloc(64, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size));\n      down_bb_.push_back(\n          std::make_shared<typename T::BufferB>(config_.hidden_size, config_.intermediate_size, down_bb_ptr));\n#endif\n    }\n  }\n\n  ~AMX_MOE() { shared_mem_buffer.dealloc(this); }\n\n  void load_weights(Backend *backend) {\n    int nth = T::recommended_nth(config_.intermediate_size);\n    backend->do_work_stealing_job(\n        nth * config_.expert_num, nullptr,\n        [&](int task_id) {\n          uint64_t expert_idx = task_id / nth;\n          int ith = task_id % nth;\n#ifdef USE_NUMA\n          int numa_nodes = numa_num_configured_nodes();\n          for (int j = 0; j < numa_nodes; j++) {\n            gate_bb_numa_[j][expert_idx]->from_mat((ggml_bf16_t *)config_.gate_proj +\n                                                       expert_idx * config_.intermediate_size * config_.hidden_size,\n                                                   ith, nth);\n            up_bb_numa_[j][expert_idx]->from_mat((ggml_bf16_t *)config_.up_proj +\n                                                     expert_idx * config_.intermediate_size * config_.hidden_size,\n                                                 ith, nth);\n          }\n#else\n          gate_bb_[expert_idx]->from_mat((ggml_bf16_t *)config_.gate_proj +\n                                             expert_idx * config_.intermediate_size * config_.hidden_size,\n                                         ith, nth);\n          up_bb_[expert_idx]->from_mat(\n              (ggml_bf16_t *)config_.up_proj + expert_idx * config_.intermediate_size * config_.hidden_size, ith, nth);\n#endif\n        },\n        nullptr);\n    nth = T::recommended_nth(config_.hidden_size);\n    backend->do_work_stealing_job(\n        nth * config_.expert_num, nullptr,\n        
[&](int task_id) {\n          uint64_t expert_idx = task_id / nth;\n          int ith = task_id % nth;\n#ifdef USE_NUMA\n          int numa_nodes = numa_num_configured_nodes();\n          for (int j = 0; j < numa_nodes; j++) {\n            down_bb_numa_[j][expert_idx]->from_mat((ggml_bf16_t *)config_.down_proj +\n                                                       expert_idx * config_.hidden_size * config_.intermediate_size,\n                                                   ith, nth);\n          }\n#else\n          down_bb_[expert_idx]->from_mat((ggml_bf16_t *)config_.down_proj +\n                                             expert_idx * config_.hidden_size * config_.intermediate_size,\n                                         ith, nth);\n#endif\n        },\n        nullptr);\n  }\n\n  void warm_up(Backend *backend) {}\n\n  void forward(int qlen, int k, const uint64_t *expert_ids, const float *weights, const void *input, void *output,\n               int *batch_size_tensor, Backend *backend) {\n    qlen = batch_size_tensor[0];\n    bool use_amx = (qlen > 4 * config_.expert_num / config_.routed_expert_num);\n    int activated_expert = 0;\n    for (int i = 0; i < config_.expert_num; i++) {\n      m_local_num_[i] = 0;\n    }\n    for (int i = 0; i < qlen; i++) {\n      for (int j = 0; j < k; j++) {\n        m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;\n      }\n    }\n    for (int i = 0; i < config_.expert_num; i++) {\n      if (m_local_num_[i] > 0) {\n        m_expert_id_map_[activated_expert] = i;\n        activated_expert++;\n      }\n    }\n    uint64_t offset = 0;\n    for (int i = 0; i < config_.expert_num; i++) {\n      m_local_input_ptr_[i] = m_local_input_ + offset * config_.hidden_size;\n      m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;\n      m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;\n      m_local_down_output_ptr_[i] = m_local_down_output_ + offset * 
config_.hidden_size;\n      offset += m_local_num_[i];\n    }\n    backend->do_work_stealing_job(\n        qlen, nullptr,\n        [&](int i) {\n          for (int j = 0; j < k; j++) {\n            memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size,\n                   (ggml_bf16_t *)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size);\n          }\n        },\n        nullptr);\n    backend->do_work_stealing_job(\n        activated_expert, nullptr,\n        [&](int task_id) {\n          int expert_idx = m_expert_id_map_[task_id];\n\n          gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1);\n        },\n        nullptr);\n    int nth = T::recommended_nth(config_.intermediate_size);\n    backend->do_work_stealing_job(\n        nth * activated_expert, [&](int _) { T::config(); },\n        [&](int task_id) {\n          int expert_idx = m_expert_id_map_[task_id / nth];\n          int ith = task_id % nth;\n#ifdef USE_NUMA\n          amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,\n                       gate_up_ba_[expert_idx], gate_bb_numa_[Backend::numa_node][expert_idx], gate_bc_[expert_idx],\n                       ith, nth, use_amx);\n          amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,\n                       gate_up_ba_[expert_idx], up_bb_numa_[Backend::numa_node][expert_idx], up_bc_[expert_idx], ith,\n                       nth, use_amx);\n#else\n          amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,\n                       gate_up_ba_[expert_idx], gate_bb_[expert_idx], gate_bc_[expert_idx], ith, nth, use_amx);\n          amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,\n                       gate_up_ba_[expert_idx], up_bb_[expert_idx], up_bc_[expert_idx], ith, nth, use_amx);\n#endif\n   
       gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);\n          up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth);\n          auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);\n          for (int i = 0; i < m_local_num_[expert_idx]; i++) {\n            ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];\n            ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];\n            for (int j = n_start; j < n_end; j += 32) {\n              __m512 gate_val0, gate_val1, up_val0, up_val1;\n              avx512_32xbf16_to_32xfp32((__m512i *)(gate_output_ptr + j), &gate_val0, &gate_val1);\n              avx512_32xbf16_to_32xfp32((__m512i *)(up_output_ptr + j), &up_val0, &up_val1);\n              __m512 result0 = act_fn(gate_val0, up_val0);\n              __m512 result1 = act_fn(gate_val1, up_val1);\n              avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i *)(gate_output_ptr + j));\n            }\n          }\n        },\n        nullptr);\n    backend->do_work_stealing_job(\n        activated_expert, nullptr,\n        [&](int task_id) {\n          int expert_idx = m_expert_id_map_[task_id];\n          down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1);\n        },\n        nullptr);\n    nth = T::recommended_nth(config_.hidden_size);\n    backend->do_work_stealing_job(\n        nth * activated_expert, [&](int _) { T::config(); },\n        [&](int task_id) {\n          int expert_idx = m_expert_id_map_[task_id / nth];\n          int ith = task_id % nth;\n#ifdef USE_NUMA\n          amx::mat_mul(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size, down_ba_[expert_idx],\n                       down_bb_numa_[Backend::numa_node][expert_idx], down_bc_[expert_idx], ith, 
nth, use_amx);\n#else\n          amx::mat_mul(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size, down_ba_[expert_idx],\n                       down_bb_[expert_idx], down_bc_[expert_idx], ith, nth, use_amx);\n#endif\n          down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_output_ptr_[expert_idx], ith, nth);\n        },\n        nullptr);\n    backend->do_work_stealing_job(\n        qlen, nullptr,\n        [&](int i) {\n          for (int e = 0; e < config_.hidden_size; e += 32) {\n            __m512 x0 = _mm512_setzero_ps();\n            __m512 x1 = _mm512_setzero_ps();\n            for (int j = 0; j < k; j++) {\n              __m512 weight = _mm512_set1_ps(weights[i * k + j]);\n              __m512 down_output0, down_output1;\n              avx512_32xbf16_to_32xfp32((__m512i *)(m_local_down_output_ptr_[expert_ids[i * k + j]] +\n                                                    m_local_pos_[i][j] * config_.hidden_size + e),\n                                        &down_output0, &down_output1);\n              x0 = _mm512_fmadd_ps(down_output0, weight, x0);\n              x1 = _mm512_fmadd_ps(down_output1, weight, x1);\n            }\n            avx512_32xfp32_to_32xbf16(&x0, &x1, (__m512i *)((ggml_bf16_t *)output + i * config_.hidden_size + e));\n          }\n        },\n        nullptr);\n  }\n};\n\n#endif\n"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/operators/amx/sft_moe.hpp",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2025-04-25 18:28:12\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2025-04-25 18:28:12\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#ifndef CPUINFER_OPERATOR_SFT_AMX_MOE_H\n#define CPUINFER_OPERATOR_SFT_AMX_MOE_H\n\n#include <cmath>\n#include <cstdio>\n#include <functional>\n#include <mutex>\n#include <vector>\n#include <fstream>\n#include <filesystem>\n\n#include \"debug_sft_moe.hpp\"\n\n#include \"../../cpu_backend/backend.h\"\n#include \"../../cpu_backend/shared_mem_buffer.h\"\n#include \"llama.cpp/ggml-impl.h\"\n#include \"llama.cpp/ggml-quants.h\"\n#include \"llama.cpp/ggml.h\"\n#include \"llamafile/sgemm.h\"\n\n#include \"la/amx.hpp\"\n\n#ifdef USE_NUMA\n#include <numa.h>\n#include <numaif.h>\n// void *numa_alloc_aligned(size_t size, int node, size_t alignment) {\n//   void *ptr = numa_alloc_onnode(size, node);\n//   assert(reinterpret_cast<intptr_t>(ptr) % 64 == 0);\n//   return ptr;\n// }\n#endif\n\nstatic inline __m512 sigmoid(__m512 x) {\n  __m512 neg = _mm512_sub_ps(_mm512_setzero_ps(), x);\n  __m512 e = exp_avx512(neg);\n  __m512 denom = _mm512_add_ps(_mm512_set1_ps(1.0f), e);\n  return _mm512_div_ps(_mm512_set1_ps(1.0f), denom);\n}\n\nstatic inline __m512 act_fn_1(__m512 x) {\n  __m512 sigmoid_val = sigmoid(x);\n  return _mm512_mul_ps(sigmoid_val, x);\n}\n\nstatic inline __m512 act_fn_grad(__m512 x) {\n  // sigmoid(x) * (1 + x * (1 - sigmoid(x)))\n  __m512 sigmoid_val = sigmoid(x);\n  __m512 one_minus_sigmoid = _mm512_sub_ps(_mm512_set1_ps(1.0f), sigmoid_val);\n  __m512 x_term = _mm512_mul_ps(x, one_minus_sigmoid);\n  __m512 one_plus_x_term = _mm512_add_ps(_mm512_set1_ps(1.0f), x_term);\n  return _mm512_mul_ps(sigmoid_val, one_plus_x_term);\n}\n\nstruct SFT_AMX_MOEConfig {\n  int expert_num;\n  int routed_expert_num;\n  int hidden_size;\n  int intermediate_size;\n  int max_len;\n  void *gate_proj;\n  void 
*up_proj;\n  void *down_proj;\n\n  SFT_AMX_MOEConfig() {}\n\n  SFT_AMX_MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int max_len,\n                void *gate_proj, void *up_proj, void *down_proj)\n      : expert_num(expert_num), routed_expert_num(routed_expert_num), hidden_size(hidden_size),\n        intermediate_size(intermediate_size), max_len(max_len), gate_proj(gate_proj), up_proj(up_proj),\n        down_proj(down_proj) {}\n};\n\ntemplate <class T> class SFT_AMX_MOE {\nprivate:\n  SFT_AMX_MOEConfig config_;\n  void *gate_proj_; // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]\n  void *up_proj_;   // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]\n  void *down_proj_; // [expert_num * hidden_size * intermediate_size ( /32 if quantized)]\n\n  void *gate_proj_t_; // [expert_num * hidden_size * intermediate_size]\n  void *up_proj_t_;   // [expert_num * hidden_size * intermediate_size]\n  void *down_proj_t_; // [expert_num * intermediate_size * hidden_size]\n\n  ggml_bf16_t *m_local_input_;       // [routed_expert_num * max_len * hidden_size]\n  ggml_bf16_t *m_local_gate_output_; // [routed_expert_num * max_len * intermediate_size]\n  ggml_bf16_t *m_local_up_output_;   // [routed_expert_num * max_len * intermediate_size]\n  ggml_bf16_t *m_local_down_output_; // [routed_expert_num * max_len * hidden_size]\n\n  std::vector<std::vector<int>> m_local_pos_;          // [max_len, routed_expert_num]\n  std::vector<int> m_local_num_;                       // [expert_num]\n  std::vector<int> m_expert_id_map_;                   // [expert_num]\n  std::vector<ggml_bf16_t *> m_local_input_ptr_;       // [expert_num]\n  std::vector<ggml_bf16_t *> m_local_gate_output_ptr_; // [expert_num]\n  std::vector<ggml_bf16_t *> m_local_up_output_ptr_;   // [expert_num]\n  std::vector<ggml_bf16_t *> m_local_down_output_ptr_; // [expert_num]\n\n  std::vector<std::shared_ptr<typename T::BufferA>> gate_up_ba_;\n 
 std::vector<std::shared_ptr<typename T::BufferC>> gate_bc_;\n  std::vector<std::shared_ptr<typename T::BufferC>> up_bc_;\n  std::vector<std::shared_ptr<typename T::BufferA>> down_ba_;\n  std::vector<std::shared_ptr<typename T::BufferC>> down_bc_;\n\n#ifdef USE_NUMA\n  std::vector<std::vector<std::shared_ptr<typename T::BufferB>>> gate_bb_numa_;\n  std::vector<std::vector<std::shared_ptr<typename T::BufferB>>> up_bb_numa_;\n  std::vector<std::vector<std::shared_ptr<typename T::BufferB>>> down_bb_numa_;\n#else\n  std::vector<std::shared_ptr<typename T::BufferB>> gate_bb_;\n  std::vector<std::shared_ptr<typename T::BufferB>> up_bb_;\n  std::vector<std::shared_ptr<typename T::BufferB>> down_bb_;\n#endif\n\n  ggml_bf16_t *m_local_down_output_grad_;       // [routed_expert_num * max_len * hidden_size]\n  ggml_bf16_t *m_local_down_input_grad_;        // [routed_expert_num * max_len * intermediate_size]\n  ggml_bf16_t *m_local_gate_output_grad_;       // [routed_expert_num * max_len * intermediate_size]\n  ggml_bf16_t *m_local_up_output_grad_;         // [routed_expert_num * max_len * intermediate_size]\n  ggml_bf16_t *m_local_gate_input_grad_;        // [routed_expert_num * max_len * hidden_size]\n  ggml_bf16_t *m_local_up_input_grad_;          // [routed_expert_num * max_len * hidden_size]\n\n  std::vector<ggml_bf16_t *> m_local_down_output_grad_ptr_;       // [expert_num]\n  std::vector<ggml_bf16_t *> m_local_down_input_grad_ptr_;        // [expert_num]\n  std::vector<ggml_bf16_t *> m_local_gate_output_grad_ptr_;       // [expert_num]\n  std::vector<ggml_bf16_t *> m_local_up_output_grad_ptr_;         // [expert_num]\n  std::vector<ggml_bf16_t *> m_local_gate_input_grad_ptr_;        // [expert_num]\n  std::vector<ggml_bf16_t *> m_local_up_input_grad_ptr_;          // [expert_num]\n\n  std::vector<std::shared_ptr<typename T::BufferA>> gate_t_ba_;\n  std::vector<std::shared_ptr<typename T::BufferC>> gate_t_bc_;\n  std::vector<std::shared_ptr<typename T::BufferA>> 
up_t_ba_;\n  std::vector<std::shared_ptr<typename T::BufferC>> up_t_bc_;\n  std::vector<std::shared_ptr<typename T::BufferA>> down_t_ba_;\n  std::vector<std::shared_ptr<typename T::BufferC>> down_t_bc_;\n\n  // TODO: NUMA\n#ifdef USE_NUMA\n  std::vector<std::vector<std::shared_ptr<typename T::BufferB>>> gate_t_bb_numa_;\n  std::vector<std::vector<std::shared_ptr<typename T::BufferB>>> up_t_bb_numa_;\n  std::vector<std::vector<std::shared_ptr<typename T::BufferB>>> down_t_bb_numa_;\n#else\n  std::vector<std::shared_ptr<typename T::BufferB>> gate_t_bb_;\n  std::vector<std::shared_ptr<typename T::BufferB>> up_t_bb_;\n  std::vector<std::shared_ptr<typename T::BufferB>> down_t_bb_;\n#endif\n\n  int* m_local_token_indices_;                                   // [routed_expert_num * max_len]\n  int* m_local_expert_positions_;                               // [routed_expert_num * max_len]\n  std::vector<int *> m_local_token_indices_ptr_;                // [expert_num]\n  std::vector<int *> m_local_expert_positions_ptr_;             // [expert_num]\n\npublic:\n  SFT_AMX_MOE(SFT_AMX_MOEConfig config) {\n    config_ = config;\n    gate_proj_ = config_.gate_proj;\n    up_proj_ = config_.up_proj;\n    down_proj_ = config_.down_proj;\n\n    std::vector<std::pair<void **, uint64_t>> m_mem_requests;\n    m_mem_requests.push_back({(void **)&m_local_input_,\n                              sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.hidden_size});\n    m_mem_requests.push_back({(void **)&m_local_gate_output_, sizeof(ggml_bf16_t) * config_.routed_expert_num *\n                                                                  config_.max_len * config_.intermediate_size});\n    m_mem_requests.push_back({(void **)&m_local_up_output_, sizeof(ggml_bf16_t) * config_.routed_expert_num *\n                                                                config_.max_len * config_.intermediate_size});\n    m_mem_requests.push_back({(void **)&m_local_down_output_,\n   
                           sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.hidden_size});\n    std::vector<void *> gate_up_ba_ptr(config_.expert_num);\n    std::vector<void *> gate_bc_ptr(config_.expert_num);\n    std::vector<void *> up_bc_ptr(config_.expert_num);\n    std::vector<void *> down_ba_ptr(config_.expert_num);\n    std::vector<void *> down_bc_ptr(config_.expert_num);\n    for (int i = 0; i < config_.expert_num; i++) {\n      m_mem_requests.push_back(\n          {(void **)&gate_up_ba_ptr[i], T::BufferA::required_size(config_.max_len, config_.hidden_size)});\n      m_mem_requests.push_back(\n          {(void **)&gate_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.intermediate_size)});\n      m_mem_requests.push_back(\n          {(void **)&up_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.intermediate_size)});\n      m_mem_requests.push_back(\n          {(void **)&down_ba_ptr[i], T::BufferA::required_size(config_.max_len, config_.intermediate_size)});\n      m_mem_requests.push_back(\n          {(void **)&down_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.hidden_size)});\n    }\n\n    m_mem_requests.push_back({(void **)&gate_proj_t_,\n                              sizeof(ggml_bf16_t) * config_.expert_num * config_.intermediate_size * config_.hidden_size});\n    m_mem_requests.push_back({(void **)&up_proj_t_,\n                              sizeof(ggml_bf16_t) * config_.expert_num * config_.intermediate_size * config_.hidden_size});\n    m_mem_requests.push_back({(void **)&down_proj_t_,\n                              sizeof(ggml_bf16_t) * config_.expert_num * config_.hidden_size * config_.intermediate_size});\n    \n    m_mem_requests.push_back({(void **)&m_local_down_output_grad_,\n                              sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.hidden_size});\n    m_mem_requests.push_back({(void **)&m_local_down_input_grad_,\n                   
           sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.intermediate_size});\n    m_mem_requests.push_back({(void **)&m_local_gate_output_grad_,\n                              sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.intermediate_size});\n    m_mem_requests.push_back({(void **)&m_local_up_output_grad_,\n                              sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.intermediate_size});\n    m_mem_requests.push_back({(void **)&m_local_gate_input_grad_,\n                              sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.hidden_size});\n    m_mem_requests.push_back({(void **)&m_local_up_input_grad_,\n                              sizeof(ggml_bf16_t) * config_.routed_expert_num * config_.max_len * config_.hidden_size});\n    m_mem_requests.push_back({(void **)&m_local_token_indices_,\n                              sizeof(int) * config_.routed_expert_num * config_.max_len});\n    m_mem_requests.push_back({(void **)&m_local_expert_positions_,\n                              sizeof(int) * config_.routed_expert_num * config_.max_len});\n    std::vector<void *> gate_t_ba_ptr(config_.expert_num);\n    std::vector<void *> gate_t_bc_ptr(config_.expert_num);\n    std::vector<void *> up_t_ba_ptr(config_.expert_num);\n    std::vector<void *> up_t_bc_ptr(config_.expert_num);\n    std::vector<void *> down_t_ba_ptr(config_.expert_num);\n    std::vector<void *> down_t_bc_ptr(config_.expert_num);\n    for (int i = 0; i < config_.expert_num; i++) {\n      m_mem_requests.push_back(\n          {(void **)&gate_t_ba_ptr[i], T::BufferA::required_size(config_.max_len, config_.intermediate_size)});\n      m_mem_requests.push_back(\n          {(void **)&gate_t_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.hidden_size)});\n      m_mem_requests.push_back(\n          {(void **)&up_t_ba_ptr[i], T::BufferA::required_size(config_.max_len, 
config_.intermediate_size)});\n      m_mem_requests.push_back(\n          {(void **)&up_t_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.hidden_size)});\n      m_mem_requests.push_back(\n          {(void **)&down_t_ba_ptr[i], T::BufferA::required_size(config_.max_len, config_.hidden_size)});\n      m_mem_requests.push_back(\n          {(void **)&down_t_bc_ptr[i], T::BufferC::required_size(config_.max_len, config_.intermediate_size)});\n    }\n\n    shared_mem_buffer.alloc(this, m_mem_requests);\n\n    m_local_pos_.resize(config_.max_len);\n    for (int i = 0; i < config_.max_len; i++) {\n      m_local_pos_[i].resize(config_.routed_expert_num);\n    }\n    m_expert_id_map_.resize(config_.expert_num);\n    m_local_num_.resize(config_.expert_num);\n    m_local_input_ptr_.resize(config_.expert_num);\n    m_local_gate_output_ptr_.resize(config_.expert_num);\n    m_local_up_output_ptr_.resize(config_.expert_num);\n    m_local_down_output_ptr_.resize(config_.expert_num);\n    m_local_down_output_grad_ptr_.resize(config_.expert_num);\n    m_local_down_input_grad_ptr_.resize(config_.expert_num);\n    m_local_gate_output_grad_ptr_.resize(config_.expert_num);\n    m_local_up_output_grad_ptr_.resize(config_.expert_num);\n    m_local_gate_input_grad_ptr_.resize(config_.expert_num);\n    m_local_up_input_grad_ptr_.resize(config_.expert_num);\n\n    for (uint64_t i = 0; i < config_.expert_num; i++) {\n      gate_up_ba_.push_back(\n          std::make_shared<typename T::BufferA>(config_.max_len, config_.hidden_size, gate_up_ba_ptr[i]));\n      gate_bc_.push_back(\n          std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, gate_bc_ptr[i]));\n      up_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, up_bc_ptr[i]));\n      down_ba_.push_back(\n          std::make_shared<typename T::BufferA>(config_.max_len, config_.intermediate_size, down_ba_ptr[i]));\n      
down_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.hidden_size, down_bc_ptr[i]));\n\n#ifdef USE_NUMA\n      int numa_nodes = numa_num_configured_nodes();\n      gate_bb_numa_.resize(numa_nodes);\n      up_bb_numa_.resize(numa_nodes);\n      down_bb_numa_.resize(numa_nodes);\n      for (int j = 0; j < numa_nodes; j++) {\n        void *gate_bb_ptr =\n            numa_alloc_aligned(T::BufferB::required_size(config_.intermediate_size, config_.hidden_size), j, 64);\n        gate_bb_numa_[j].push_back(\n            std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, gate_bb_ptr));\n        void *up_bb_ptr =\n            numa_alloc_aligned(T::BufferB::required_size(config_.intermediate_size, config_.hidden_size), j, 64);\n        up_bb_numa_[j].push_back(\n            std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, up_bb_ptr));\n        void *down_bb_ptr =\n            numa_alloc_aligned(T::BufferB::required_size(config_.hidden_size, config_.intermediate_size), j, 64);\n        down_bb_numa_[j].push_back(  \n            std::make_shared<typename T::BufferB>(config_.hidden_size, config_.intermediate_size, down_bb_ptr));\n      }\n#else\n      void *gate_bb_ptr =\n          std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size));\n      gate_bb_.push_back(\n          std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, gate_bb_ptr));\n\n      void *up_bb_ptr =\n          std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size));\n      up_bb_.push_back(\n          std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, up_bb_ptr));\n\n      void *down_bb_ptr =\n          std::aligned_alloc(64, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size));\n      down_bb_.push_back(\n          std::make_shared<typename 
T::BufferB>(config_.hidden_size, config_.intermediate_size, down_bb_ptr));\n#endif\n    }\n\n    for (uint64_t i = 0; i < config_.expert_num; i++) {\n      gate_t_ba_.push_back(\n          std::make_shared<typename T::BufferA>(config_.max_len, config_.intermediate_size, gate_t_ba_ptr[i]));\n      gate_t_bc_.push_back(\n          std::make_shared<typename T::BufferC>(config_.max_len, config_.hidden_size, gate_t_bc_ptr[i]));\n      up_t_ba_.push_back(std::make_shared<typename T::BufferA>(config_.max_len, config_.intermediate_size, up_t_ba_ptr[i]));\n      up_t_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.hidden_size, up_t_bc_ptr[i]));\n      down_t_ba_.push_back(\n          std::make_shared<typename T::BufferA>(config_.max_len, config_.hidden_size, down_t_ba_ptr[i]));\n      down_t_bc_.push_back(std::make_shared<typename T::BufferC>(config_.max_len, config_.intermediate_size, down_t_bc_ptr[i]));\n\n#ifdef USE_NUMA\n      int numa_nodes = numa_num_configured_nodes();\n      gate_t_bb_numa_.resize(numa_nodes);\n      up_t_bb_numa_.resize(numa_nodes);\n      down_t_bb_numa_.resize(numa_nodes);\n      for (int j = 0; j < numa_nodes; j++) {\n        void *gate_t_bb_ptr =\n            numa_alloc_aligned(T::BufferB::required_size(config_.hidden_size, config_.intermediate_size), j, 64);\n        gate_t_bb_numa_[j].push_back(\n            std::make_shared<typename T::BufferB>(config_.hidden_size, config_.intermediate_size, gate_t_bb_ptr));\n        void *up_t_bb_ptr =\n            numa_alloc_aligned(T::BufferB::required_size(config_.hidden_size, config_.intermediate_size), j, 64);\n        up_t_bb_numa_[j].push_back(\n            std::make_shared<typename T::BufferB>(config_.hidden_size, config_.intermediate_size, up_t_bb_ptr));\n        void *down_t_bb_ptr =\n            numa_alloc_aligned(T::BufferB::required_size(config_.intermediate_size, config_.hidden_size), j, 64);\n        down_t_bb_numa_[j].push_back(\n            
std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, down_t_bb_ptr));\n      }\n#else\n      void *gate_t_bb_ptr =\n          std::aligned_alloc(64, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size));\n      gate_t_bb_.push_back(\n          std::make_shared<typename T::BufferB>(config_.hidden_size, config_.intermediate_size, gate_t_bb_ptr));\n\n      void *up_t_bb_ptr =\n          std::aligned_alloc(64, T::BufferB::required_size(config_.hidden_size, config_.intermediate_size));\n      up_t_bb_.push_back(\n          std::make_shared<typename T::BufferB>(config_.hidden_size, config_.intermediate_size, up_t_bb_ptr));\n\n      void *down_t_bb_ptr =\n          std::aligned_alloc(64, T::BufferB::required_size(config_.intermediate_size, config_.hidden_size));\n      down_t_bb_.push_back(\n          std::make_shared<typename T::BufferB>(config_.intermediate_size, config_.hidden_size, down_t_bb_ptr));\n#endif\n    }\n\n    m_local_token_indices_ptr_.resize(config_.expert_num);\n    m_local_expert_positions_ptr_.resize(config_.expert_num);\n  }\n\n  ~SFT_AMX_MOE() { shared_mem_buffer.dealloc(this); }\n\n  // Transposes each expert's row-major [R x C] bf16 matrix from src into a row-major [C x R] matrix in dst.\n  // Experts are processed in parallel, one task per expert.\n  void transpose_expert(const void* src, void* dst, int R, int C, Backend* backend) {\n    backend->do_work_stealing_job(\n        config_.expert_num, nullptr,\n        [&](uint64_t expert_idx) {\n          for (int r = 0; r < R; ++r) {\n            for (int c = 0; c < C; ++c) {\n                memcpy(\n                    (uint8_t*)dst + (expert_idx * R * C + (c * R + r)) * sizeof(ggml_bf16_t),\n                    (const uint8_t*)src + (expert_idx * R * C + (r * C + c)) * sizeof(ggml_bf16_t),\n                    sizeof(ggml_bf16_t));\n            }\n          }\n        },\n        nullptr);\n  }\n\n  // Builds the transposed weight copies used by backward(), then packs all weights into BufferB layout.\n  void load_weights(Backend *backend) {\n    transpose_expert(config_.gate_proj, gate_proj_t_, config_.intermediate_size, config_.hidden_size, backend);\n    transpose_expert(config_.up_proj, up_proj_t_, 
config_.intermediate_size, config_.hidden_size, backend);\n    transpose_expert(config_.down_proj, down_proj_t_, config_.hidden_size, config_.intermediate_size, backend);\n\n    int nth = T::recommended_nth(config_.intermediate_size);\n    backend->do_work_stealing_job(\n        nth * config_.expert_num, nullptr,\n        [&](int task_id) {\n          uint64_t expert_idx = task_id / nth;\n          int ith = task_id % nth;\n#ifdef USE_NUMA\n          int numa_nodes = numa_num_configured_nodes();\n          for (int j = 0; j < numa_nodes; j++) {\n            gate_bb_numa_[j][expert_idx]->from_mat((ggml_bf16_t *)config_.gate_proj +\n                                                       expert_idx * config_.intermediate_size * config_.hidden_size,\n                                                   ith, nth);\n            up_bb_numa_[j][expert_idx]->from_mat((ggml_bf16_t *)config_.up_proj +\n                                                     expert_idx * config_.intermediate_size * config_.hidden_size,\n                                                 ith, nth);\n          }\n#else\n          gate_bb_[expert_idx]->from_mat((ggml_bf16_t *)config_.gate_proj +\n                                             expert_idx * config_.intermediate_size * config_.hidden_size,\n                                         ith, nth);\n          up_bb_[expert_idx]->from_mat(\n              (ggml_bf16_t *)config_.up_proj + expert_idx * config_.intermediate_size * config_.hidden_size, ith, nth);\n#endif\n        },\n        nullptr);\n    nth = T::recommended_nth(config_.intermediate_size);\n    backend->do_work_stealing_job(\n        nth * config_.expert_num, nullptr,\n        [&](int task_id) {\n          uint64_t expert_idx = task_id / nth;\n          int ith = task_id % nth;\n#ifdef USE_NUMA\n          int numa_nodes = numa_num_configured_nodes();\n          for (int j = 0; j < numa_nodes; j++) {\n            down_t_bb_numa_[j][expert_idx]->from_mat((ggml_bf16_t *)down_proj_t_ +\n       
                                                  expert_idx * config_.intermediate_size * config_.hidden_size,\n                                                     ith, nth);\n          }\n#else\n          down_t_bb_[expert_idx]->from_mat((ggml_bf16_t *)down_proj_t_ +\n                                             expert_idx * config_.intermediate_size * config_.hidden_size,\n                                         ith, nth);\n#endif\n        },\n        nullptr);\n        \n    nth = T::recommended_nth(config_.hidden_size);\n    backend->do_work_stealing_job(\n        nth * config_.expert_num, nullptr,\n        [&](int task_id) {\n          uint64_t expert_idx = task_id / nth;\n          int ith = task_id % nth;\n#ifdef USE_NUMA\n          int numa_nodes = numa_num_configured_nodes();\n          for (int j = 0; j < numa_nodes; j++) {\n            down_bb_numa_[j][expert_idx]->from_mat((ggml_bf16_t *)config_.down_proj +\n                                                       expert_idx * config_.hidden_size * config_.intermediate_size,\n                                                   ith, nth);\n          }\n#else\n          down_bb_[expert_idx]->from_mat((ggml_bf16_t *)config_.down_proj +\n                                             expert_idx * config_.hidden_size * config_.intermediate_size,\n                                         ith, nth);\n#endif\n#ifdef USE_NUMA\n          for (int j = 0; j < numa_nodes; j++) {\n            gate_t_bb_numa_[j][expert_idx]->from_mat((ggml_bf16_t *)gate_proj_t_ +\n                                                         expert_idx * config_.hidden_size * config_.intermediate_size,\n                                                     ith, nth);\n            up_t_bb_numa_[j][expert_idx]->from_mat((ggml_bf16_t *)up_proj_t_ +\n                                                       expert_idx * config_.hidden_size * config_.intermediate_size,\n                                                   ith, nth);\n          
}\n#else\n          gate_t_bb_[expert_idx]->from_mat((ggml_bf16_t *)gate_proj_t_ +\n                                             expert_idx * config_.hidden_size * config_.intermediate_size,\n                                         ith, nth);\n          up_t_bb_[expert_idx]->from_mat((ggml_bf16_t *)up_proj_t_ +\n                                             expert_idx * config_.hidden_size * config_.intermediate_size,\n                                         ith, nth);\n#endif\n        },\n        nullptr);\n  }\n\n  void warm_up(Backend *backend) {}\n\n  void forward(int qlen, int k, const uint64_t *expert_ids, const float *weights, const void *input, void *output, Backend *backend) {\n    bool use_amx = (qlen > 4 * config_.expert_num / config_.routed_expert_num);\n    int activated_expert = 0;\n    for (int i = 0; i < config_.expert_num; i++) {\n      m_local_num_[i] = 0;\n    }\n    for (int i = 0; i < qlen; i++) {\n      for (int j = 0; j < k; j++) {\n        m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;\n      }\n    }\n    for (int i = 0; i < config_.expert_num; i++) {\n      if (m_local_num_[i] > 0) {\n        m_expert_id_map_[activated_expert] = i;\n        activated_expert++;\n      }\n    }\n    uint64_t offset = 0;\n    for (int i = 0; i < config_.expert_num; i++) {\n      m_local_input_ptr_[i] = m_local_input_ + offset * config_.hidden_size;\n      m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;\n      m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;\n      m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;\n      offset += m_local_num_[i];\n    }\n    backend->do_work_stealing_job(\n        qlen, nullptr,\n        [&](int i) {\n          for (int j = 0; j < k; j++) {\n            memcpy(m_local_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size,\n                   (ggml_bf16_t *)input + i * 
config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size);\n          }\n        },\n        nullptr);\n    backend->do_work_stealing_job(\n        activated_expert, nullptr,\n        [&](int task_id) {\n          int expert_idx = m_expert_id_map_[task_id];\n\n          gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_input_ptr_[expert_idx], 0, 1);\n        },\n        nullptr);\n    int nth = T::recommended_nth(config_.intermediate_size);\n    backend->do_work_stealing_job(\n        nth * activated_expert, [&](int _) { T::config(); },\n        [&](int task_id) {\n          int expert_idx = m_expert_id_map_[task_id / nth];\n          int ith = task_id % nth;\n#ifdef USE_NUMA\n          amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,\n                       gate_up_ba_[expert_idx], gate_bb_numa_[Backend::numa_node][expert_idx], gate_bc_[expert_idx],\n                       ith, nth, use_amx);\n          amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,\n                       gate_up_ba_[expert_idx], up_bb_numa_[Backend::numa_node][expert_idx], up_bc_[expert_idx], ith,\n                       nth, use_amx);\n#else\n          amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,\n                       gate_up_ba_[expert_idx], gate_bb_[expert_idx], gate_bc_[expert_idx], ith, nth, use_amx);\n          amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,\n                       gate_up_ba_[expert_idx], up_bb_[expert_idx], up_bc_[expert_idx], ith, nth, use_amx);\n#endif\n          gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);\n          up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth);\n          auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);\n          for (int i = 
0; i < m_local_num_[expert_idx]; i++) {\n            ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];\n            ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];\n            for (int j = n_start; j < n_end; j += 32) {\n              __m512 gate_val0, gate_val1, up_val0, up_val1;\n              avx512_32xbf16_to_32xfp32((__m512i *)(gate_output_ptr + j), &gate_val0, &gate_val1);\n              avx512_32xbf16_to_32xfp32((__m512i *)(up_output_ptr + j), &up_val0, &up_val1);\n              __m512 result0 = act_fn(gate_val0, up_val0);\n              __m512 result1 = act_fn(gate_val1, up_val1);\n              avx512_32xfp32_to_32xbf16(&result0, &result1, (__m512i *)(gate_output_ptr + j));\n            }\n          }\n        },\n        nullptr);\n    backend->do_work_stealing_job(\n        activated_expert, nullptr,\n        [&](int task_id) {\n          int expert_idx = m_expert_id_map_[task_id];\n          down_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], 0, 1);\n        },\n        nullptr);\n\n    nth = T::recommended_nth(config_.hidden_size);\n    backend->do_work_stealing_job(\n        nth * activated_expert, [&](int _) { T::config(); },\n        [&](int task_id) {\n          int expert_idx = m_expert_id_map_[task_id / nth];\n          int ith = task_id % nth;\n#ifdef USE_NUMA\n          amx::mat_mul(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size, down_ba_[expert_idx],\n                       down_bb_numa_[Backend::numa_node][expert_idx], down_bc_[expert_idx], ith, nth, use_amx);\n#else\n          amx::mat_mul(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size, down_ba_[expert_idx],\n                       down_bb_[expert_idx], down_bc_[expert_idx], ith, nth, use_amx);\n#endif\n          down_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], 
m_local_down_output_ptr_[expert_idx], ith, nth);\n        },\n        nullptr);\n    backend->do_work_stealing_job(\n        qlen, nullptr,\n        [&](int i) {\n          for (int e = 0; e < config_.hidden_size; e += 32) {\n            __m512 x0 = _mm512_setzero_ps();\n            __m512 x1 = _mm512_setzero_ps();\n            for (int j = 0; j < k; j++) {\n              __m512 weight = _mm512_set1_ps(weights[i * k + j]);\n              __m512 down_output0, down_output1;\n              avx512_32xbf16_to_32xfp32((__m512i *)(m_local_down_output_ptr_[expert_ids[i * k + j]] +\n                                                    m_local_pos_[i][j] * config_.hidden_size + e),\n                                        &down_output0, &down_output1);\n              x0 = _mm512_fmadd_ps(down_output0, weight, x0);\n              x1 = _mm512_fmadd_ps(down_output1, weight, x1);\n            }\n            avx512_32xfp32_to_32xbf16(&x0, &x1, (__m512i *)((ggml_bf16_t *)output + i * config_.hidden_size + e));\n          }\n        },\n        nullptr);\n  }\n\n  void backward(int qlen, int k, const uint64_t *expert_ids, const float *weights, const void* input, const void *output_grad, void *input_grad, Backend *backend) {\n    bool use_amx = (qlen > 4 * config_.expert_num / config_.routed_expert_num);\n    int activated_expert = 0;\n    for (int i = 0; i < config_.expert_num; i++) {\n      m_local_num_[i] = 0;\n    }\n    for (int i = 0; i < qlen; i++) {\n      for (int j = 0; j < k; j++) {\n        m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;\n      }\n    }\n    for (int i = 0; i < config_.expert_num; i++) {\n      if (m_local_num_[i] > 0) {\n        m_expert_id_map_[activated_expert] = i;\n        activated_expert++;\n      }\n    }\n    uint64_t offset = 0;\n    for (int i = 0; i < config_.expert_num; i++) {\n      m_local_input_ptr_[i] = m_local_input_ + offset * config_.hidden_size;\n      m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * 
config_.intermediate_size;\n      m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;\n\n      m_local_down_output_grad_ptr_[i] = m_local_down_output_grad_ + offset * config_.hidden_size;\n      m_local_down_input_grad_ptr_[i] = m_local_down_input_grad_ + offset * config_.intermediate_size;\n      m_local_gate_output_grad_ptr_[i] = m_local_gate_output_grad_ + offset * config_.intermediate_size;\n      m_local_up_output_grad_ptr_[i] = m_local_up_output_grad_ + offset * config_.intermediate_size;\n      m_local_gate_input_grad_ptr_[i] = m_local_gate_input_grad_ + offset * config_.hidden_size;\n      m_local_up_input_grad_ptr_[i] = m_local_up_input_grad_ + offset * config_.hidden_size;\n      m_local_token_indices_ptr_[i] = m_local_token_indices_ + offset;\n      m_local_expert_positions_ptr_[i] = m_local_expert_positions_ + offset;\n      offset += m_local_num_[i];\n    }\n\n    // TODO: cache\n    backend->do_work_stealing_job(\n        qlen, nullptr, \n        [&](int i) {\n          for (int j = 0; j < k; j++) {\n            uint64_t expert_id = expert_ids[i * k + j];\n            int local_row = m_local_pos_[i][j];\n            memcpy(m_local_input_ptr_[expert_id] + local_row * config_.hidden_size,\n              (ggml_bf16_t *)input + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size); // TODO: cache\n            memcpy(m_local_down_output_grad_ptr_[expert_id] + local_row * config_.hidden_size,\n              (ggml_bf16_t *)output_grad + i * config_.hidden_size, sizeof(ggml_bf16_t) * config_.hidden_size);\n            m_local_token_indices_ptr_[expert_id][local_row] = i;\n            m_local_expert_positions_ptr_[expert_id][local_row] = j;\n          }\n        }, \n        nullptr);\n\n    backend->do_work_stealing_job(\n        activated_expert, nullptr,\n        [&](int task_id) {\n          int expert_idx = m_expert_id_map_[task_id];\n          gate_up_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], 
m_local_input_ptr_[expert_idx], 0, 1); // TODO: cache\n          down_t_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_down_output_grad_ptr_[expert_idx], 0, 1);\n        },\n        nullptr);\n\n    int nth = T::recommended_nth(config_.intermediate_size);  \n    backend->do_work_stealing_job(\n        nth * activated_expert, [&](int _) { T::config(); },\n        [&](int task_id) {\n          int expert_idx = m_expert_id_map_[task_id / nth];\n          int ith = task_id % nth;\n\n          // TODO: cache\n#ifdef USE_NUMA\n          amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,\n                       gate_up_ba_[expert_idx], gate_bb_numa_[Backend::numa_node][expert_idx], gate_bc_[expert_idx],\n                       ith, nth, use_amx);\n          amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,\n                       gate_up_ba_[expert_idx], up_bb_numa_[Backend::numa_node][expert_idx], up_bc_[expert_idx], ith,\n                       nth, use_amx);\n#else\n          amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,\n                       gate_up_ba_[expert_idx], gate_bb_[expert_idx], gate_bc_[expert_idx], ith, nth, use_amx);\n          amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,\n                       gate_up_ba_[expert_idx], up_bb_[expert_idx], up_bc_[expert_idx], ith, nth, use_amx);\n#endif\n          gate_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_output_ptr_[expert_idx], ith, nth);\n          up_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_output_ptr_[expert_idx], ith, nth);\n\n#ifdef USE_NUMA\n          amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,\n                      down_t_ba_[expert_idx], down_t_bb_numa_[Backend::numa_node][expert_idx], down_t_bc_[expert_idx], ith, nth, use_amx);\n#else\n          
amx::mat_mul(m_local_num_[expert_idx], config_.intermediate_size, config_.hidden_size,\n                      down_t_ba_[expert_idx], down_t_bb_[expert_idx], down_t_bc_[expert_idx], ith, nth, use_amx);\n#endif\n          down_t_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_down_input_grad_ptr_[expert_idx], ith, nth);\n\n\n          auto [n_start, n_end] = T::split_range_n(config_.intermediate_size, ith, nth);\n          for (int i = 0; i < m_local_num_[expert_idx]; i++) {\n            ggml_bf16_t *gate_output_ptr = &m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size];\n            ggml_bf16_t *up_output_ptr = &m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size];\n            ggml_bf16_t *down_input_grad_ptr = &m_local_down_input_grad_ptr_[expert_idx][i * config_.intermediate_size];\n            ggml_bf16_t *gate_output_grad_ptr = &m_local_gate_output_grad_ptr_[expert_idx][i * config_.intermediate_size];\n            ggml_bf16_t *up_output_grad_ptr = &m_local_up_output_grad_ptr_[expert_idx][i * config_.intermediate_size];\n            \n            int token_idx = m_local_token_indices_ptr_[expert_idx][i];\n            int expert_pos = m_local_expert_positions_ptr_[expert_idx][i];\n            __m512 weight = _mm512_set1_ps(weights[token_idx * k + expert_pos]);\n            \n            for (int j = n_start; j < n_end; j += 32) {\n              __m512 gate_val0, gate_val1, up_val0, up_val1, down_input_grad0, down_input_grad1;\n              avx512_32xbf16_to_32xfp32((__m512i *)(gate_output_ptr + j), &gate_val0, &gate_val1);\n              avx512_32xbf16_to_32xfp32((__m512i *)(up_output_ptr + j), &up_val0, &up_val1);\n              avx512_32xbf16_to_32xfp32((__m512i *)(down_input_grad_ptr + j), &down_input_grad0, &down_input_grad1);\n              \n              down_input_grad0 = _mm512_mul_ps(down_input_grad0, weight);\n              down_input_grad1 = _mm512_mul_ps(down_input_grad1, weight);\n              \n             
 // gate_output_grad = δ_zji ⊙ v_ji ⊙ σ'(u_ji)\n              __m512 gate_grad0 = _mm512_mul_ps(down_input_grad0, \n                                               _mm512_mul_ps(up_val0, act_fn_grad(gate_val0)));\n              __m512 gate_grad1 = _mm512_mul_ps(down_input_grad1, \n                                               _mm512_mul_ps(up_val1, act_fn_grad(gate_val1)));\n              \n              // up_output_grad = δ_zji ⊙ σ(u_ji)\n              __m512 up_grad0 = _mm512_mul_ps(down_input_grad0, act_fn_1(gate_val0));\n              __m512 up_grad1 = _mm512_mul_ps(down_input_grad1, act_fn_1(gate_val1));\n              \n              avx512_32xfp32_to_32xbf16(&gate_grad0, &gate_grad1, (__m512i *)(gate_output_grad_ptr + j));\n              avx512_32xfp32_to_32xbf16(&up_grad0, &up_grad1, (__m512i *)(up_output_grad_ptr + j));\n            }\n          }\n        },\n        nullptr);\n\n\n    backend->do_work_stealing_job(\n        activated_expert, nullptr,\n        [&](int task_id) {\n          int expert_idx = m_expert_id_map_[task_id];\n          gate_t_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_gate_output_grad_ptr_[expert_idx], 0, 1);\n          up_t_ba_[expert_idx]->from_mat(m_local_num_[expert_idx], m_local_up_output_grad_ptr_[expert_idx], 0, 1);\n        },\n        nullptr);\n    nth = T::recommended_nth(config_.hidden_size);\n    backend->do_work_stealing_job(\n        nth * activated_expert, [&](int _) { T::config(); },\n        [&](int task_id) {\n          int expert_idx = m_expert_id_map_[task_id / nth];\n          int ith = task_id % nth;\n#ifdef USE_NUMA\n          amx::mat_mul(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size,\n                      gate_t_ba_[expert_idx], gate_t_bb_numa_[Backend::numa_node][expert_idx], gate_t_bc_[expert_idx], ith, nth, use_amx);\n          amx::mat_mul(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size,\n                      up_t_ba_[expert_idx], 
up_t_bb_numa_[Backend::numa_node][expert_idx], up_t_bc_[expert_idx], ith, nth, use_amx);\n#else\n          amx::mat_mul(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size,\n                      gate_t_ba_[expert_idx], gate_t_bb_[expert_idx], gate_t_bc_[expert_idx], ith, nth, use_amx);\n          amx::mat_mul(m_local_num_[expert_idx], config_.hidden_size, config_.intermediate_size,\n                      up_t_ba_[expert_idx], up_t_bb_[expert_idx], up_t_bc_[expert_idx], ith, nth, use_amx);\n#endif\n          gate_t_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_gate_input_grad_ptr_[expert_idx], ith, nth);\n          up_t_bc_[expert_idx]->to_mat(m_local_num_[expert_idx], m_local_up_input_grad_ptr_[expert_idx], ith, nth);\n        },\n        nullptr);\n    backend->do_work_stealing_job(\n        qlen, nullptr,\n        [&](int i) {\n          for (int e = 0; e < config_.hidden_size; e += 32) {\n            __m512 x0 = _mm512_setzero_ps();\n            __m512 x1 = _mm512_setzero_ps();\n            for (int j = 0; j < k; j++) {\n              __m512 gate_input_grad0, gate_input_grad1, up_input_grad0, up_input_grad1;\n              avx512_32xbf16_to_32xfp32((__m512i *)(m_local_gate_input_grad_ptr_[expert_ids[i * k + j]] +\n                                                    m_local_pos_[i][j] * config_.hidden_size + e),\n                                        &gate_input_grad0, &gate_input_grad1);\n              avx512_32xbf16_to_32xfp32((__m512i *)(m_local_up_input_grad_ptr_[expert_ids[i * k + j]] +\n                                                    m_local_pos_[i][j] * config_.hidden_size + e),\n                                        &up_input_grad0, &up_input_grad1);\n              x0 = _mm512_add_ps(gate_input_grad0, x0);\n              x1 = _mm512_add_ps(gate_input_grad1, x1);\n              x0 = _mm512_add_ps(up_input_grad0, x0);\n              x1 = _mm512_add_ps(up_input_grad1, x1);\n            }\n            
avx512_32xfp32_to_32xbf16(&x0, &x1, (__m512i *)((ggml_bf16_t *)input_grad + i * config_.hidden_size + e));\n          }\n        },\n        nullptr);\n  }\n};\n#endif\n\n// for debug\n// if constexpr (std::is_same_v<typename T::dt, ggml_bf16_t>) {\t\n// \tfor (int expert_idx = 0; expert_idx < config_.expert_num; ++expert_idx) {\n// \t\tauto buf = down_t_ba_[expert_idx].get();\n\n// \t\tstd::string path = \"debug/\" + std::to_string(expert_idx) + \"_down_ba_t_debug3.bin\";\n// \t\tstd::ofstream ofs(path, std::ios::binary);\n// \t\tfor (int n_idx = 0; n_idx < m_local_num_[expert_idx]; ++n_idx) {\n// \t\t\tconst ggml_bf16_t* row = reinterpret_cast<const ggml_bf16_t*>(buf->a) + n_idx * buf->k;\n// \t\t\tfor (int j = 0; j < buf->k; ++j) {\n// \t\t\t\tfloat v = row[j];\n// \t\t\t\tofs.write(reinterpret_cast<const char*>(&v), sizeof(v));\n// \t\t\t}\n// \t\t}\n// \t\tofs.close();\n// \t}\n// }\n\n// for (uint64_t expert_idx = 0; expert_idx < (uint64_t)config_.expert_num; ++expert_idx) {\n// \tdump_grad_bin(\"cpp_layer0_E_End\"+std::to_string(expert_idx)+\"_down_t_ba_\", (ggml_bf16_t*)m_local_down_output_grad_ptr_[expert_idx], config_.hidden_size * m_local_num_[expert_idx], GGML_TYPE_BF16);\n// }\n\n// for (uint64_t expert_idx = 0; expert_idx < (uint64_t)config_.expert_num; ++expert_idx) {\n// \tdump_grad_bin(\"cpp_layer0_E_End\"+std::to_string(expert_idx)+\"_down_t_bb_\", (ggml_bf16_t *)down_proj_t_ + expert_idx * config_.intermediate_size * config_.hidden_size, config_.hidden_size * config_.intermediate_size, GGML_TYPE_BF16);\n// }\n\n// for (uint64_t expert_idx = 0; expert_idx < (uint64_t)config_.expert_num; ++expert_idx) {\n// \tdump_grad_bin(\"cpp_layer0_E_End\"+std::to_string(expert_idx)+\"_down_t_bc_\", (ggml_bf16_t*)m_local_down_input_grad_ptr_[expert_idx], config_.intermediate_size * m_local_num_[expert_idx], GGML_TYPE_BF16);\n// }"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/operators/kvcache/kvcache.h",
    "content": "/**\n * @Description  :\n * @Author       : Jianwei Dong\n * @Date         : 2024-08-26 22:47:06\n * @Version      : 1.0.0\n * @LastEditors  : Jianwei Dong\n * @LastEditTime : 2024-08-26 22:47:06\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n\n#ifndef CPUINFER_OPERATOR_KVCACHE_H\n#define CPUINFER_OPERATOR_KVCACHE_H\n\n#include <algorithm>\n#include <atomic>\n#include <cassert>\n#include <condition_variable>\n#include <cstdint>\n#include <cstdio>\n#include <cstring>\n#include <fstream>\n#include <functional>\n#include <future>\n#include <iostream>\n#include <memory>\n#include <mutex>\n#include <queue>\n#include <random>\n#include <stdexcept>\n#include <thread>\n#include <vector>\n\n#include \"../../cpu_backend/backend.h\"\n#include \"llama.cpp/ggml-common.h\"\n#include \"llama.cpp/ggml-impl.h\"\n#include \"llama.cpp/ggml-quants.h\"\n#include \"llama.cpp/ggml.h\"\n#include \"llamafile/sgemm.h\"\n\n#define CHUNK_SIZE 32\n\n/**\n * @brief Converts a ggml_type enum value to its corresponding string\n * representation.\n *\n * This function provides a human-readable string representation for a given\n * ggml_type enum value. The string can be used for logging, debugging, or\n * displaying information in a user interface.\n *\n * @param type The ggml_type enum value to convert.\n * @return A string representation of the enum value.\n */\nstd::string ggml_type_to_string(ggml_type type);\n\n/**\n * @enum AnchorType\n * @brief Defines the types of anchors used in attention mechanisms.\n *\n * This enum specifies different types of anchors that can be used in attention\n * mechanisms, such as fixed anchors, dynamic anchors, or special anchors like\n * QUEST, BLOCK_MEAN, or BLOCK_MAX.\n */\nenum AnchorType {\n    FIXED_ANCHOR, /**< A fixed anchor that does not change. */\n    DYNAMIC,      /**< A dynamic anchor that can change over time. 
*/\n    QUEST, /**< A special anchor type used for QUEST (Query and Embedding Space\n              Transformation). */\n    BLOCK_MEAN, /**< An anchor based on the mean of a block of data. */\n    BLOCK_MAX /**< An anchor based on the maximum value within a block of data.\n               */\n};\n\n/**\n * @brief Converts an AnchorType enum value to its corresponding string\n * representation.\n *\n * This function provides a human-readable string representation for a given\n * AnchorType enum value. The string can be used for logging, debugging, or\n * displaying information in a user interface.\n *\n * @param anchor_type The AnchorType enum value to convert.\n * @return A string representation of the enum value.\n */\nstd::string AnchorTypeToString(AnchorType anchor_type);\n\n/**\n * @enum RetrievalType\n * @brief Defines the types of retrieval strategies in attention mechanisms.\n *\n * This enum specifies different retrieval strategies that can be used in\n * attention mechanisms, such as layer-level retrieval, key-value head-level\n * retrieval, or query head-level retrieval.\n */\nenum RetrievalType {\n    LAYER,  /**< Retrieval at the layer level. */\n    KVHEAD, /**< Retrieval at the key-value head level. */\n    QHEAD   /**< Retrieval at the query head level. */\n};\n\n/**\n * @brief Converts a RetrievalType enum value to its corresponding string\n * representation.\n *\n * This function provides a human-readable string representation for a given\n * RetrievalType enum value. 
The string can be used for logging, debugging, or\n * displaying information in a user interface.\n *\n * @param retrieval_type The RetrievalType enum value to convert.\n * @return A string representation of the enum value.\n */\nstd::string RetrievalTypeToString(RetrievalType retrieval_type);\n\n/**\n * @struct KVCacheConfig\n * @brief Configuration structure for Key-Value (KV) Cache.\n *\n * This structure holds configuration parameters for setting up and managing\n * a Key-Value (KV) Cache used in various attention mechanisms. It includes\n * parameters such as the number of layers, the number of heads, the dimension\n * of each head, block length, anchor information, and memory-related settings.\n */\nstruct KVCacheConfig {\n    int layer_num;   /**< Number of layers in the model. */\n    int kv_head_num; /**< Number of heads in the KV Cache. */\n    int q_head_num;  /**< Number of heads in the query. */\n    int head_dim;    /**< Dimension of each head. */\n    int block_len;   /**< Length of each block in the cache. */\n    int anchor_num;  /**< Number of anchors used in attention. */\n\n    ggml_type kv_type; /**< Data type of the KV Cache (e.g., fp16, q8_0). */\n\n    // Controls the pre-allocated memory size\n    int max_block_num;  /**< Maximum number of blocks that can be allocated. */\n    int max_batch_size; /**< Maximum batch size that can be processed. */\n    int max_thread_num; /**< Maximum number of threads that can be used. */\n\n    AnchorType\n        anchor_type; /**< Type of anchors used in the attention mechanism. */\n    RetrievalType\n        retrieval_type; /**< Type of retrieval strategy used in the cache. */\n\n    int layer_step;   /**< Step size between layers. */\n    int token_step;   /**< Step size between tokens. */\n    int layer_offset; /**< Offset value for layers. */\n\n    /**\n     * @brief Default constructor for KVCacheConfig.\n     *\n     * Initializes the configuration with default values. 
This constructor\n     * does not initialize any member variables explicitly.\n     */\n    KVCacheConfig() = default;\n\n    /**\n     * @brief Parameterized constructor for KVCacheConfig.\n     *\n     * This constructor initializes the configuration with specific values\n     * for all member variables.\n     *\n     * @param layer_num The number of layers in the model.\n     * @param kv_head_num The number of heads in the KV Cache.\n     * @param q_head_num The number of heads in the query.\n     * @param head_dim The dimension of each head.\n     * @param block_len The length of each block in the cache.\n     * @param anchor_num The number of anchors used in attention.\n     * @param anchor_type The type of anchors used in the attention mechanism.\n     * @param kv_type The data type of the KV Cache (e.g., fp16, q8_0).\n     * @param retrieval_type The type of retrieval strategy used in the cache.\n     * @param layer_step The step size between layers.\n     * @param token_step The step size between tokens.\n     * @param layer_offset The offset value for layers.\n     * @param max_block_num The maximum number of blocks that can be allocated.\n     * @param max_batch_size The maximum batch size that can be processed.\n     * @param max_thread_num The maximum number of threads that can be used.\n     */\n    KVCacheConfig(int layer_num, int kv_head_num, int q_head_num, int head_dim,\n                  int block_len, int anchor_num, AnchorType anchor_type,\n                  ggml_type kv_type, RetrievalType retrieval_type,\n                  int layer_step, int token_step, int layer_offset,\n                  int max_block_num, int max_batch_size, int max_thread_num);\n};\n\n/**\n * @class KVCache\n * @brief Manages the Key-Value (KV) Cache used in attention mechanisms.\n *\n * The KVCache class provides functionality for managing the Key-Value Cache,\n * including resizing the cache, retrieving configuration parameters, and\n * updating internal states. 
This class is typically used in transformer models\n * to store and manage past key and value states for efficient attention\n * computations.\n */\nclass KVCache {\n  public:\n    /**\n     * @brief Constructs a KVCache object with the given configuration.\n     *\n     * Initializes the KVCache with the specified configuration parameters,\n     * such as the number of layers, heads, head dimensions, and other\n     * relevant settings.\n     *\n     * @param config The configuration object containing initialization\n     * parameters.\n     */\n    KVCache(KVCacheConfig config);\n\n    /**\n     * @brief Resizes the number of threads used by the cache.\n     *\n     * This function adjusts the number of threads that the cache can utilize.\n     * It allows dynamic reconfiguration of the parallel processing capabilities\n     * based on the current workload or system resources.\n     *\n     * @param thread_num The new number of threads to use.\n     */\n    void ThreadResize(int thread_num);\n\n    /**\n     * @brief Resizes the batch size managed by the cache.\n     *\n     * This function adjusts the batch size that the cache can handle. 
It\n     * is useful when the input batch size changes dynamically, allowing\n     * the cache to be reconfigured accordingly.\n     *\n     * @param batch_size The new batch size.\n     */\n    void BatchResize(int batch_size);\n\n    /**\n     * @brief Resizes the number of blocks managed by the cache.\n     *\n     * This function adjusts the number of blocks that the cache can manage.\n     * It allows dynamic reconfiguration of the block structure based on the\n     * current sequence length or other factors.\n     *\n     * @param block_num The new number of blocks.\n     */\n    void BlockResize(int block_num);\n\n    /**\n     * @brief Gets the number of layers in the cache.\n     *\n     * @return The number of layers configured in the cache.\n     */\n    int get_layer_num() { return config_.layer_num; }\n\n    /**\n     * @brief Gets the number of KV heads in the cache.\n     *\n     * @return The number of KV heads configured in the cache.\n     */\n    int get_kv_head_num() { return config_.kv_head_num; }\n\n    /**\n     * @brief Gets the number of query heads in the cache.\n     *\n     * @return The number of query heads configured in the cache.\n     */\n    int get_q_head_num() { return config_.q_head_num; }\n\n    /**\n     * @brief Gets the dimension of each head in the cache.\n     *\n     * @return The dimension of each head.\n     */\n    int get_head_dim() { return config_.head_dim; }\n\n    /**\n     * @brief Gets the length of each block in the cache.\n     *\n     * @return The length of each block.\n     */\n    int get_block_len() { return config_.block_len; }\n\n    /**\n     * @brief Gets the number of blocks for a specific layer.\n     *\n     * @param layer_id The ID of the layer for which to retrieve the block\n     * number.\n     * @return The number of blocks in the specified layer.\n     */\n    int get_block_num(int layer_id) { return past_block_num_[layer_id]; }\n\n    /**\n     * @brief Gets the number of anchors in the 
cache.\n     *\n     * @return The number of anchors configured in the cache.\n     */\n    int get_anchor_num() { return config_.anchor_num; }\n\n    /**\n     * @brief Gets the total length of the cache.\n     *\n     * @return The total length of the cache.\n     */\n    int get_cache_total_len() { return cache_total_len_; }\n\n    /**\n     * @brief Gets the total number of blocks in the cache.\n     *\n     * This function computes and returns the total number of blocks in the\n     * cache based on the total cache length and the block length configuration.\n     *\n     * @return The total number of blocks in the cache.\n     */\n    int get_cache_total_block_num() {\n        return (cache_total_len_ + config_.block_len - 1) / config_.block_len;\n    }\n\n    /**\n     * @brief Updates the total length of the cache.\n     *\n     * This function sets a new total length for the cache, allowing dynamic\n     * adjustment of the cache size during runtime.\n     *\n     * @param cache_total_len The new total length of the cache.\n     */\n    void update_cache_total_len(int cache_total_len) {\n        cache_total_len_ = cache_total_len;\n    }\n    void attn(const ggml_fp16_t *q_in, ggml_fp16_t *output, float *attn_lse,\n              int layer_idx, int generate_token_idx, int q_len, int batch_size,\n              int max_block_num, int *block_table, int *cache_seqlens,\n              int pick_block_num, int init_block_num, int local_block_num,\n              Backend *backend);\n\n    void update_kvcache_one_block_fp16(const ggml_fp16_t *k_in,\n                                       const ggml_fp16_t *v_in, int layer_id,\n                                       int block_idx, Backend *backend);\n\n    void get_kvcache_one_block_fp16(ggml_fp16_t *k_in, ggml_fp16_t *v_in,\n                                    int layer_id, int block_idx,\n                                    Backend *backend);\n\n    void update_importance_one_block(const ggml_fp16_t *importance,\n    
                                 int layer_id, int block_idx,\n                                     Backend *backend);\n    void get_importance_one_block(ggml_fp16_t *importance, int layer_id,\n                                  int block_idx, Backend *backend);\n\n    void get_anchor_one_block(ggml_fp16_t *anchor, int layer_id, int block_idx,\n                              Backend *backend);\n\n    void update_anchor_one_block(const ggml_fp16_t *anchor, int layer_id,\n                                 int block_idx, Backend *backend);\n\n    void calc_anchor_all_layers(int *block_table, int *cache_seqlens,\n                                int batch_size, int max_block_num,\n                                Backend *backend);\n\n    void load_kvcache(std::string tensor_file_path, Backend *backend);\n    void dump_kvcache(int *block_table, int cache_total_len,\n                      std::string tensor_file_path, Backend *backend);\n\n    void get_and_update_kvcache_fp16(ggml_fp16_t *k_in, ggml_fp16_t *v_in,\n                                     int layer_id, int *block_table,\n                                     int batch_size, int max_block_num,\n                                     int *cache_seqlens, int q_len,\n                                     Backend *backend);\n\n    void get_kvcache_fp16(ggml_fp16_t *k_in, ggml_fp16_t *v_in, int layer_id,\n                          int *block_table, int batch_size, int max_block_num,\n                          int *cache_seqlens, Backend *backend);\n\n    void update_kvcache_fp16(const ggml_fp16_t *k_in, const ggml_fp16_t *v_in,\n                             int layer_id, int *block_table, int batch_size,\n                             int max_block_num, int *cache_seqlens, int q_len,\n                             Backend *backend);\n\n    void update_importance(const ggml_fp16_t *importance, int layer_id,\n                           int *block_table, int batch_size, int max_block_num,\n                           int 
*offset, int width, Backend *backend);\n\n    void attn_with_kvcache(const ggml_fp16_t *q_in, const ggml_fp16_t *k_in,\n                           const ggml_fp16_t *v_in, ggml_fp16_t *output,\n                           float *attn_lse, int layer_idx,\n                           int generate_token_idx, int q_len, int batch_size,\n                           int max_block_num, int *block_table,\n                           int *cache_seqlens, int topk, int local,\n                           Backend *backend);\n\n    void clear_importance_all_layers(int *block_table, int *cache_seqlens,\n                                     int batch_size, int max_block_num,\n                                     Backend *backend);\n\n    void clear_kvcache_all_layers(int *block_table, int *cache_seqlens,\n                                  int batch_size, int max_block_num,\n                                  Backend *backend);\n\n    void get_sincos(ggml_fp16_t *sin, ggml_fp16_t *cos, int seqlen);\n\n    void get_attn_sparsity(const ggml_fp16_t *q_in, float *attn_sparsity,\n                           int layer_idx, int generate_token_idx, int q_len,\n                           int batch_size, int max_block_num, int *block_table,\n                           int *cache_seqlens, int *block_table_origin,\n                           int *cache_seqlens_origin, int max_block_num_origin,\n                           int topk, int local, Backend *backend);\n\n    void get_all_kvcache_one_layer(int layer_id, ggml_fp16_t *k_in,\n                                   ggml_fp16_t *v_in, Backend *backend);\n\n  private:\n    // Persistent data\n    KVCacheConfig config_;\n    int n_gqa_;                            // q_head_num / kv_head_num\n    int cache_total_len_;                  // Number of tokens in cache\n    std::vector<uint64_t> past_block_num_; // [layer_num]\n    std::vector<std::vector<std::vector<std::vector<block_q4_0>>>>\n        k_cache_q4; // [layer_num, kv_head_num, past_block_num, 
block_len *\n                    // (head_dim / QK_4)]\n    std::vector<std::vector<std::vector<std::vector<block_q4_0>>>>\n        v_cache_q4; // [layer_num, kv_head_num, past_block_num, head_dim *\n                    // (block_len / QK_4)]\n    std::vector<std::vector<std::vector<std::vector<block_q8_0>>>>\n        k_cache_q8; // [layer_num, kv_head_num, past_block_num, block_len *\n                    // (head_dim / QK_8)]\n    std::vector<std::vector<std::vector<std::vector<block_q8_0>>>>\n        v_cache_q8; // [layer_num, kv_head_num, past_block_num, head_dim *\n                    // (block_len / QK_8)]\n\n    std::vector<std::vector<std::vector<std::vector<ggml_fp16_t>>>>\n        k_cache_fp16_; // [layer_num, kv_head_num, past_block_num, block_len *\n                       // head_dim]\n    std::vector<std::vector<std::vector<std::vector<ggml_fp16_t>>>>\n        v_cache_fp16_; // [layer_num, kv_head_num, past_block_num, head_dim *\n                       // block_len]\n\n    std::vector<std::vector<std::vector<std::vector<ggml_fp16_t>>>>\n        importance_; // [layer_num, past_block_num, block_len,\n                     // attention_head_num]\n\n    std::vector<ggml_fp16_t>\n        anchor_; // [layer_num * past_block_num * anchor_num *\n                 // attention_head_num * head_dim]\n\n    // Runtime data\n    int64_t layer_id_;\n    int64_t block_idx_;\n    int *block_table_;\n    uint64_t block_num_;\n    int max_block_num_after_retrieval_;\n\n    // Rotary positional embeddings\n    std::vector<std::vector<ggml_fp16_t>> sin_; // [seq_len, head_dim]\n    std::vector<std::vector<ggml_fp16_t>> cos_; // [seq_len, head_dim]\n\n    // update/get\n    int seq_len_;\n    uint16_t *k_scales_;        // q4_0\n    uint8_t *k_in_;             // q4_0\n    uint16_t *v_scales_;        // q4_0\n    uint8_t *v_in_;             // q4_0\n    uint16_t *k_data_;          // fp16\n    uint16_t *v_data_;          // fp16\n    uint16_t *importance_data_; // fp16\n    
uint16_t *anchor_data_;     // fp16\n\n    // sparsity = (sigma(block lse / lse))\n    std::vector<std::vector<std::vector<float>>>\n        block_lse_; // [batch_size, max_block_num, q_head_num]\n    std::vector<std::vector<float>> attn_sparsity_; // [batch_size, q_head_num]\n\n    // attn\n    std::vector<std::vector<float>>\n        avg_q; // [batch_size, q_head_num * head_dim]\n\n    std::vector<std::vector<ggml_fp16_t>>\n        avg_q_fp16; // [batch_size, q_head_num * head_dim]\n    std::vector<\n        std::priority_queue<std::pair<float, int>,\n                            std::vector<std::pair<float, int>>, std::greater<>>>\n        top_similar_block_;\n\n    std::vector<std::vector<float>> block_similar_;\n    std::vector<std::vector<std::vector<float>>> block_similar_kv_head_;\n    std::vector<std::vector<std::vector<float>>> block_similar_q_head_;\n\n    std::vector<int> cache_seqlens_;               // [batch_size]\n    std::vector<int> selected_blocks_num_history_; // [layer_num // layer_step]\n\n    std::vector<std::vector<std::vector<int>>> selected_blocks_history_;\n    // [layer_num // layer_step, batch_size, max_block_num]\n\n    std::vector<std::vector<std::vector<std::vector<int>>>>\n        selected_blocks_history_kvhead_; // [layer_num // layer_step,\n                                         // batch_size, max_block_num,\n                                         // kv_head_num]\n\n    std::vector<std::vector<int>>\n        block_table_before_retrieval_; // [batch_size, max_block_num]\n    std::vector<std::vector<int>>\n        block_table_after_retrieval_; // [batch_size, pick_block_num]\n\n    std::vector<std::vector<std::vector<int>>>\n        block_table_before_retrieval_qhead_; // [batch_size, max_block_num,\n                                             // q_head_num]\n    std::vector<std::vector<std::vector<int>>>\n        block_table_after_retrieval_qhead_; // [batch_size, pick_block_num,\n                                            // 
q_head_num]\n\n    std::vector<std::vector<std::vector<int>>>\n        block_table_before_retrieval_kvhead_; // [batch_size, max_block_num,\n                                              // kv_head_num]\n    std::vector<std::vector<std::vector<int>>>\n        block_table_after_retrieval_kvhead_; // [batch_size, pick_block_num,\n                                             // kv_head_num]\n\n    std::vector<std::vector<std::unique_ptr<std::mutex>>>\n        mutex_; // [batch_size, kv_head_num]\n    std::vector<std::vector<std::vector<block_q8_0>>>\n        q_q8_0_; // [batch_size, kv_head_num, n_gqa * head_dim / QK8_0]\n    std::vector<std::vector<std::vector<float>>>\n        q_fp32_; // [batch_size, kv_head_num, n_gqa * head_dim]\n\n    std::vector<std::vector<std::vector<float>>>\n        output_fp32_; // [batch_size, kv_head_num, n_gqa * head_dim]\n    std::vector<std::vector<std::vector<float>>>\n        attn_lse_; // [batch_size, kv_head_num, n_gqa]\n\n    std::vector<std::pair<int, int>> thread_cur_head_idx_; // [thread_num]\n\n    std::vector<std::vector<block_q8_0>>\n        thread_local_output_q8_0_; // [thread_num, n_gqa * head_dim / QK8_0]\n    std::vector<std::vector<float>>\n        thread_local_attn_score_; // [thread_num, n_gqa * block_len]\n    std::vector<std::vector<float>>\n        thread_local_output_fp32_; // [thread_num, n_gqa * head_dim]\n    std::vector<std::vector<float>>\n        thread_local_attn_lse_; // [thread_num, n_gqa]\n    std::vector<std::vector<float>>\n        thread_local_cur_output_fp32_; // [thread_num, n_gqa * head_dim]\n    std::vector<std::vector<float>>\n        thread_local_cur_attn_lse_; // [thread_num, n_gqa]\n    std::vector<std::vector<uint8_t>>\n        thread_local_attn_mask_; // [thread_num, block_len // 8]\n    std::vector<std::vector<char>>\n        thread_local_draft_; // [thread_num, 2 * n_gqa * block_len + 6 * n_gqa *\n                             // head_dim + 2 * block_len * head_dim]\n\n    // tmp space\n  
  std::vector<float> q_fp32; // [n_gqa * head_dim]\n\n    void quantize_q_(const uint16_t *q_in_data, int batch_size);\n    void attn_initialize_layer_(int batch_size, int layer_idx, int *block_table,\n                                int &max_block_num, int *cache_seqlens);\n    void attn_initialize_kvhead_(int batch_size, int layer_idx,\n                                 int *block_table, int &max_block_num,\n                                 int *cache_seqlens);\n    void retrieval_kvcache_layer_(const uint16_t *q_in_data, int init_block_num,\n                                  int local_block_num, int pick_block_num,\n                                  int q_len, int generate_token_idx,\n                                  int batch_size, int layer_idx,\n                                  int *cache_seqlens, int &max_block_num,\n                                  Backend *backend);\n    void retrieval_kvcache_kvhead_(const uint16_t *q_in_data,\n                                   int init_block_num, int local_block_num,\n                                   int pick_block_num, int q_len,\n                                   int generate_token_idx, int batch_size,\n                                   int layer_idx, int *cache_seqlens,\n                                   int &max_block_num, Backend *backend);\n\n    void calculate_block_similarity_layer_(\n        const uint16_t *q_in_data, int batch_size, int layer_idx, int q_len,\n        int max_block_num, int *cache_seqlens, int init_block_num,\n        int local_block_num, int pick_block_num, Backend *backend);\n    void calculate_block_similarity_kvhead_(\n        const uint16_t *q_in_data, int batch_size, int layer_idx, int q_len,\n        int max_block_num, int *cache_seqlens, int init_block_num,\n        int local_block_num, int pick_block_num, Backend *backend);\n\n    void select_block_layer_(int batch_size, int layer_idx, int max_block_num,\n                             int init_block_num, int local_block_num,\n    
                         int pick_block_num);\n    void select_block_kvhead_(int batch_size, int layer_idx, int max_block_num,\n                              int init_block_num, int local_block_num,\n                              int pick_block_num);\n\n    void calculate_sparsity_layer_(const uint16_t *q_in_data,\n                                   float *attn_sparsity, int batch_size,\n                                   int max_block_num, int *block_table,\n                                   int *cache_seqlens, Backend *backend);\n    void calculate_sparsity_kvhead_(const uint16_t *q_in_data,\n                                    float *attn_sparsity, int batch_size,\n                                    int max_block_num, int *block_table,\n                                    int *cache_seqlens, Backend *backend);\n\n    void attention_kvhead_(const uint16_t *q_in_data, ggml_fp16_t *output,\n                           float *attn_lse, int batch_size, Backend *backend);\n    void attention_layer_(const uint16_t *q_in_data, ggml_fp16_t *output,\n                          float *attn_lse, int batch_size, Backend *backend);\n\n    /**\n     * @brief Computes attention with KV cache for one block.\n     *\n     * This function performs attention computation for one block using KV\n     * cache. The function supports different data types for Q, K, and V caches,\n     * and provides options for quantization. The function does not perform any\n     * dynamic memory allocation internally, so all necessary buffers must be\n     * pre-allocated externally.\n     *\n     * @param head_dim The dimension of the head.\n     * @param bsz The batch size.\n     * @param q_type The data type of Q (GGML data type). Only supports fp16 and\n     * q8_0.\n     * @param q Pointer to the Q tensor [bsz, head_dim]. The quantization is\n     *          always applied along the head_dim dimension. The size must be\n     *          bsz * head_dim/32 * qtype_size. 
If head_dim % 32 != 0, an error\n     *          will be raised.\n     * @param past_kv_len The length of the past KV cache.\n     * @param past_kv_offset The offset in the past KV cache.\n     * @param is_full_attn Boolean flag indicating whether to use full attention\n     *                     (true for full 1 mask).\n     * @param attn_mask Pointer to the attention mask [bsz, past_kv_len]. If\n     *                  is_full_attn = false, a bit matrix is passed to\n     * represent the mask.\n     * @param k_type The data type of K cache (GGML data type). Only supports\n     *               fp16, q4_0, and q8_0.\n     * @param k_quant_type Quantization type for K cache. 0 for per_token, 1 for\n     *                     per_channel. Other values will raise an error.\n     * @param k_cache Pointer to the K cache tensor [seq_len, head_dim]. If\n     *                quant_type == 0, head_dim % 32 must be 0. If quant_type ==\n     * 1, seq_len % 32 must be 0.\n     * @param num_k_anchor The number of K anchors. If num_k_anchor == 0, it\n     * means no anchor is present.\n     * @param k_cache_anchors Pointer to the K cache anchors [num_k_anchor,\n     * head_dim]. The k_anchor_type must be fp16.\n     * @param k_cache_anchor_pos Pointer to the K cache anchor positions. 
Each\n     * token is associated with the nearest previous anchor position.\n     * @param v_type The data type of V cache (GGML data type).\n     * @param v_quant_type Quantization type for V cache.\n     * @param v_cache Pointer to the V cache tensor [head_dim, seq_len].\n     * @param num_v_anchor The number of V anchors.\n     * @param v_cache_anchors Pointer to the V cache anchors.\n     * @param v_cache_anchor_pos Pointer to the V cache anchor positions.\n     * @param attn_score Pre-allocated buffer for attention scores [bsz,\n     * past_kv_len].\n     * @param output Output tensor [bsz, head_dim] with the same type as q_type.\n     * @param lse Pre-allocated buffer [bsz] for the log-sum-exp of the\n     * attention scores.\n     * @param draft Pre-allocated temporary buffer. The buffer size should be\n     * enough to hold (2 * bsz * past_kv_len + 6 * bsz * head_dim + 2 *\n     *              past_kv_len * head_dim + past_kv_len * head_dim / 32) bytes.\n     * @param rotary_angle Pointer to the rotary angle tensor.\n     * @param rotary_cos Pointer to the cosine values for rotary embedding.\n     * @param rotary_sin Pointer to the sine values for rotary embedding.\n     */\n    void attn_with_kvcache_one_block_(\n        int head_dim, int bsz,\n        ggml_type q_type, // GGML data type of `Q`, only supports fp16 and q8_0\n        // [bsz, head_dim]\n        // Quantization is always on the head_dim dimension (per_token). If\n        // head_dim % 32 != 0, an error will be raised. The size must be bsz *\n        // head_dim/32 * qtype_size.\n        const void *q,\n\n        int past_kv_len, int past_kv_offset,\n        bool is_full_attn, // true indicates a full 1 mask\n        // If is_full_attn = false, a bit matrix representing the mask is\n        // passed. 
[bsz, past_kv_len]\n        const uint8_t *attn_mask,\n\n        ggml_type k_type, // GGML data type of `K Cache`, only supports fp16,\n                          // q4_0, q8_0\n        int k_quant_type, // 0 for per_token, 1 for per_channel, others raise an\n                          // error\n        // [seq_len, head_dim]\n        // If quant_type == 0, head_dim % 32 must be 0.\n        // If quant_type == 1, seq_len % 32 must be 0.\n        const void *k_cache,\n\n        // k_anchor_type must be fp16\n        int num_k_anchor, // num_k_anchor == 0 indicates no anchor\n        // [num_k_anchor, head_dim]\n        const void *k_cache_anchors,\n        // Each token is associated with the nearest previous position's anchor,\n        // with the same distance.\n        const int *k_cache_anchor_pos,\n\n        // v_cache similar to k_cache\n        ggml_type v_type, int v_quant_type,\n        // [head_dim, seq_len]\n        const void *v_cache, int num_v_anchor, const void *v_cache_anchors,\n        const int *v_cache_anchor_pos,\n\n        // Pre-allocated buffer for intermediate calculations [bsz,\n        // past_kv_len]. 
No malloc is performed inside this function.\n        float *attn_score,\n\n        // Output: [bsz, head_dim], with the same type as q_type\n        void *output,\n        // [bsz]\n        float *lse,\n\n        // Pre-allocated temporary buffer with sufficient size:\n        // (2 * bsz * past_kv_len + 6 * bsz * head_dim + 2 * past_kv_len *\n        // head_dim + past_kv_len * head_dim / 32) bytes.\n        void *draft,\n\n        // Apply rotary embedding online\n        const int *rotary_angle, const void *rotary_cos, const void *rotary_sin\n        // rotary_cos=None,\n        // rotary_sin=None,\n        // cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,\n        // cache_batch_idx: Optional[torch.Tensor] = None,\n        // rotary_interleaved=True,\n\n        // // Not supported for now\n        // window_size=(-1, -1),  # -1 means infinite context window\n        // alibi_slopes=None,\n    );\n};\n\n/**\n * @brief Scales a float32 vector by a given scalar value.\n *\n * This function multiplies each element of the input vector `y` by a scalar\n * `v`. It uses platform-specific optimizations if available, such as Apple's\n * Accelerate framework or SIMD instructions. If no specific optimization is\n * available, the function falls back to a simple scalar multiplication loop.\n *\n * @param n The number of elements in the vector `y`.\n * @param y The input vector to be scaled. The result will be stored in the same\n * vector.\n * @param v The scalar value by which to scale the vector.\n */\nvoid ggml_vec_scale_f32(const int n, float *y, const float v);\n#endif"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/operators/kvcache/kvcache_attn.cpp",
    "content": "/**\n * @Description  :\n * @Author       : Jianwei Dong\n * @Date         : 2024-08-26 22:47:06\n * @Version      : 1.0.0\n * @LastEditors  : Jianwei Dong\n * @LastEditTime : 2024-08-26 22:47:06\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n\n#include \"kvcache.h\"\n\n#include <chrono>\n\nvoid KVCache::attention_kvhead_(const uint16_t *q_in_data, ggml_fp16_t *output,\n                                float *attn_lse, int batch_size,\n                                Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n    seq_len_ = config_.block_len;\n\n    backend->do_work_stealing_job(\n        batch_size * config_.kv_head_num * max_block_num_after_retrieval_,\n        [&](int thread_id) {\n            thread_cur_head_idx_[thread_id].first = -1;\n            thread_cur_head_idx_[thread_id].second = -1;\n        },\n        [&](int task_id) {\n            int batch_id = task_id / (config_.kv_head_num *\n                                      max_block_num_after_retrieval_);\n            int head_id = (task_id % (config_.kv_head_num *\n                                      max_block_num_after_retrieval_)) /\n                          max_block_num_after_retrieval_;\n            int block_id = task_id % max_block_num_after_retrieval_;\n            int thread_id = Backend::thread_local_id;\n\n            // If the block is out of the sequence length, skip it.\n            if (cache_seqlens_[batch_id] / config_.block_len < block_id) {\n                return;\n            }\n            int block_idx =\n                block_table_after_retrieval_kvhead_[batch_id][block_id]\n                                                   [head_id];\n            if (cache_seqlens_[batch_id] / config_.block_len == block_id) {\n                int seq_len = cache_seqlens_[batch_id] % config_.block_len;\n                if (seq_len == 0)\n                    return;\n\n                // Prepare the 
attention mask for the last block.\n                int full_blocks = seq_len / 8;\n                int remaining_bits = seq_len % 8;\n                // Fill full blocks with 1s\n                for (int i = 0; i < full_blocks; ++i) {\n                    thread_local_attn_mask_[thread_id][i] = 0xFF;\n                }\n                // Fill the remaining bits in the next block\n                if (remaining_bits > 0 && full_blocks < seq_len_ / 8) {\n                    thread_local_attn_mask_[thread_id][full_blocks] =\n                        (1 << remaining_bits) - 1;\n                } else {\n                    thread_local_attn_mask_[thread_id][full_blocks] = 0;\n                }\n\n                for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) {\n                    thread_local_attn_mask_[thread_id][i] = 0;\n                }\n                if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,\n                        (void *)&q_in_data[batch_id * config_.kv_head_num *\n                                               n_gqa_ * config_.head_dim +\n                                           head_id * n_gqa_ * config_.head_dim],\n                        seq_len_, 0, false,\n                        thread_local_attn_mask_[thread_id].data(),\n                        GGML_TYPE_F16, 0,\n                        k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_F16, 1,\n                        v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        
thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, false,\n                        thread_local_attn_mask_[thread_id].data(),\n                        GGML_TYPE_Q4_0, 0,\n                        k_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_Q4_0, 1,\n                        v_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, false,\n                        thread_local_attn_mask_[thread_id].data(),\n                        GGML_TYPE_Q8_0, 0,\n                        k_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, 
GGML_TYPE_Q8_0, 1,\n                        v_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                }\n            } else {\n                if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,\n                        (void *)&q_in_data[batch_id * config_.kv_head_num *\n                                               n_gqa_ * config_.head_dim +\n                                           head_id * n_gqa_ * config_.head_dim],\n                        seq_len_, 0, true, nullptr, GGML_TYPE_F16, 0,\n                        k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_F16, 1,\n                        v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n                    
attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0, 0,\n                        k_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_Q4_0, 1,\n                        v_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0, 0,\n                        k_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_Q8_0, 1,\n                        v_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        
thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                }\n            }\n            int cur_batch_idx = thread_cur_head_idx_[thread_id].first;\n            int cur_head_id = thread_cur_head_idx_[thread_id].second;\n            if (batch_id == cur_batch_idx && head_id == cur_head_id) {\n                for (int i = 0; i < n_gqa_; i++) {\n                    float new_attn_lse =\n                        thread_local_cur_attn_lse_[thread_id][i] +\n                        std::log(\n                            1.0 +\n                            std::exp(thread_local_attn_lse_[thread_id][i] -\n                                     thread_local_cur_attn_lse_[thread_id][i]));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                        thread_local_cur_output_fp32_[thread_id].data() +\n                            i * config_.head_dim,\n                        std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                 new_attn_lse));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                        thread_local_output_fp32_[thread_id].data() +\n                            i * config_.head_dim,\n                        std::exp(thread_local_attn_lse_[thread_id][i] -\n                                 new_attn_lse));\n                    for (int j = 0; j < config_.head_dim; j++) {\n                        thread_local_cur_output_fp32_[thread_id]\n                                                     [i * config_.head_dim +\n                                                      j] +=\n                            thread_local_output_fp32_[thread_id]\n                     
                                [i * config_.head_dim + j];\n                    }\n                    thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse;\n                }\n            } else {\n                if (cur_batch_idx != -1) {\n                    mutex_[cur_batch_idx][cur_head_id]->lock();\n                    for (int i = 0; i < n_gqa_; i++) {\n                        if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <\n                            1e-6) {\n                            attn_lse_[cur_batch_idx][cur_head_id][i] =\n                                thread_local_cur_attn_lse_[thread_id][i];\n                            for (int j = 0; j < config_.head_dim; j++) {\n                                output_fp32_[cur_batch_idx][cur_head_id]\n                                            [i * config_.head_dim + j] =\n                                                thread_local_cur_output_fp32_\n                                                    [thread_id]\n                                                    [i * config_.head_dim + j];\n                            }\n                            continue;\n                        }\n                        float new_attn_lse =\n                            attn_lse_[cur_batch_idx][cur_head_id][i] +\n                            std::log(\n                                1.0 +\n                                std::exp(\n                                    thread_local_cur_attn_lse_[thread_id][i] -\n                                    attn_lse_[cur_batch_idx][cur_head_id][i]));\n                        ggml_vec_scale_f32(\n                            config_.head_dim,\n                            output_fp32_[cur_batch_idx][cur_head_id].data() +\n                                i * config_.head_dim,\n                            std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -\n                                     new_attn_lse));\n                        ggml_vec_scale_f32(\n                
            config_.head_dim,\n                            thread_local_cur_output_fp32_[thread_id].data() +\n                                i * config_.head_dim,\n                            std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                     new_attn_lse));\n                        for (int j = 0; j < config_.head_dim; j++) {\n                            output_fp32_[cur_batch_idx][cur_head_id]\n                                        [i * config_.head_dim + j] +=\n                                thread_local_cur_output_fp32_\n                                    [thread_id][i * config_.head_dim + j];\n                        }\n                        attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;\n                    }\n                    mutex_[cur_batch_idx][cur_head_id]->unlock();\n                }\n                thread_cur_head_idx_[thread_id].first = batch_id;\n                thread_cur_head_idx_[thread_id].second = head_id;\n                for (int i = 0; i < n_gqa_; i++) {\n                    thread_local_cur_attn_lse_[thread_id][i] =\n                        thread_local_attn_lse_[thread_id][i];\n                    for (int j = 0; j < config_.head_dim; j++) {\n                        thread_local_cur_output_fp32_\n                            [thread_id][i * config_.head_dim + j] =\n                                thread_local_output_fp32_[thread_id]\n                                                         [i * config_.head_dim +\n                                                          j];\n                    }\n                }\n            }\n        },\n        // Merge the results of the remaining blocks.\n        [&](int thread_id) {\n            int cur_batch_idx = thread_cur_head_idx_[thread_id].first;\n            int cur_head_id = thread_cur_head_idx_[thread_id].second;\n            if (cur_head_id != -1) {\n                mutex_[cur_batch_idx][cur_head_id]->lock();\n                
for (int i = 0; i < n_gqa_; i++) {\n                    float new_attn_lse;\n                    if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <\n                        1e-6) {\n                        attn_lse_[cur_batch_idx][cur_head_id][i] =\n                            thread_local_cur_attn_lse_[thread_id][i];\n                        for (int j = 0; j < config_.head_dim; j++) {\n                            output_fp32_[cur_batch_idx][cur_head_id]\n                                        [i * config_.head_dim + j] =\n                                            thread_local_cur_output_fp32_\n                                                [thread_id]\n                                                [i * config_.head_dim + j];\n                        }\n                        continue;\n                    }\n                    new_attn_lse =\n                        attn_lse_[cur_batch_idx][cur_head_id][i] +\n                        std::log(\n                            1.0 +\n                            std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                     attn_lse_[cur_batch_idx][cur_head_id][i]));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                        output_fp32_[cur_batch_idx][cur_head_id].data() +\n                            i * config_.head_dim,\n                        std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -\n                                 new_attn_lse));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                        thread_local_cur_output_fp32_[thread_id].data() +\n                            i * config_.head_dim,\n                        std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                 new_attn_lse));\n                    for (int j = 0; j < config_.head_dim; j++) {\n                        output_fp32_[cur_batch_idx][cur_head_id]\n                      
              [i * config_.head_dim + j] +=\n                            thread_local_cur_output_fp32_[thread_id]\n                                                         [i * config_.head_dim +\n                                                          j];\n                    }\n                    attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;\n                }\n                mutex_[cur_batch_idx][cur_head_id]->unlock();\n            }\n        });\n    // move the results to output and attn_lse\n    uint16_t *output_data = reinterpret_cast<uint16_t *>(output);\n    float *attn_lse_data = attn_lse;\n    for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n        for (int i = 0; i < config_.kv_head_num; i++) {\n            for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {\n                output_data[batch_idx * config_.kv_head_num * n_gqa_ *\n                                config_.head_dim +\n                            i * n_gqa_ * config_.head_dim + j] =\n                    GGML_FP32_TO_FP16(output_fp32_[batch_idx][i][j]);\n            }\n            for (int j = 0; j < n_gqa_; j++) {\n                attn_lse_data[batch_idx * config_.kv_head_num * n_gqa_ +\n                              i * n_gqa_ + j] = attn_lse_[batch_idx][i][j];\n            }\n        }\n    }\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> diff = end - start;\n    // printf(\"layer %d time of computing attention: %f s\\n\", layer_idx,\n    //        diff.count());\n}\n\nvoid KVCache::attention_layer_(const uint16_t *q_in_data, ggml_fp16_t *output,\n                               float *attn_lse, int batch_size,\n                               Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n    seq_len_ = config_.block_len;\n    backend->do_work_stealing_job(\n        batch_size * config_.kv_head_num * max_block_num_after_retrieval_,\n     
   [&](int thread_id) {\n            thread_cur_head_idx_[thread_id].first = -1;\n            thread_cur_head_idx_[thread_id].second = -1;\n        },\n        [&](int task_id) {\n            int batch_id = task_id / (config_.kv_head_num *\n                                      max_block_num_after_retrieval_);\n            int head_id = (task_id % (config_.kv_head_num *\n                                      max_block_num_after_retrieval_)) /\n                          max_block_num_after_retrieval_;\n            int block_id = task_id % max_block_num_after_retrieval_;\n            int thread_id = Backend::thread_local_id;\n            // If the block is out of the sequence length, skip it.\n            if (cache_seqlens_[batch_id] / config_.block_len < block_id) {\n                return;\n            }\n            int block_idx = block_table_after_retrieval_[batch_id][block_id];\n            if (cache_seqlens_[batch_id] / config_.block_len == block_id) {\n                int seq_len = cache_seqlens_[batch_id] % config_.block_len;\n                if (seq_len == 0)\n                    return;\n\n                // Prepare the attention mask for the last block.\n                int full_blocks = seq_len / 8;\n                int remaining_bits = seq_len % 8;\n\n                // Fill full blocks with 1s\n                for (int i = 0; i < full_blocks; ++i) {\n                    thread_local_attn_mask_[thread_id][i] = 0xFF;\n                }\n                // Fill the remaining bits in the next block\n                if (remaining_bits > 0 && full_blocks < seq_len_ / 8) {\n                    thread_local_attn_mask_[thread_id][full_blocks] =\n                        (1 << remaining_bits) - 1;\n                } else {\n                    thread_local_attn_mask_[thread_id][full_blocks] = 0;\n                }\n\n                for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) {\n                    thread_local_attn_mask_[thread_id][i] = 0;\n             
   }\n                if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,\n                        (void *)&q_in_data[batch_id * config_.kv_head_num *\n                                               n_gqa_ * config_.head_dim +\n                                           head_id * n_gqa_ * config_.head_dim],\n                        seq_len_, 0, false,\n                        thread_local_attn_mask_[thread_id].data(),\n                        GGML_TYPE_F16, 0,\n                        k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_F16, 1,\n                        v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, false,\n                        thread_local_attn_mask_[thread_id].data(),\n                        GGML_TYPE_Q4_0, 0,\n                        k_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_Q4_0, 1,\n                        v_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        
thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, false,\n                        thread_local_attn_mask_[thread_id].data(),\n                        GGML_TYPE_Q8_0, 0,\n                        k_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_Q8_0, 1,\n                        v_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                }\n            } else {\n                if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n                    attn_with_kvcache_one_block_(\n 
                       config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,\n                        (void *)&q_in_data[batch_id * config_.kv_head_num *\n                                               n_gqa_ * config_.head_dim +\n                                           head_id * n_gqa_ * config_.head_dim],\n                        seq_len_, 0, true, nullptr, GGML_TYPE_F16, 0,\n                        k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_F16, 1,\n                        v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0, 0,\n                        k_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_Q4_0, 1,\n                        v_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), 
sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0, 0,\n                        k_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_Q8_0, 1,\n                        v_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                }\n            }\n            int cur_batch_idx = thread_cur_head_idx_[thread_id].first;\n            int cur_head_id = thread_cur_head_idx_[thread_id].second;\n            if (batch_id == cur_batch_idx && head_id == cur_head_id) {\n                for (int i = 0; i < n_gqa_; i++) {\n                    float new_attn_lse =\n                        thread_local_cur_attn_lse_[thread_id][i] +\n                        std::log(\n                            1.0 +\n                            
std::exp(thread_local_attn_lse_[thread_id][i] -\n                                     thread_local_cur_attn_lse_[thread_id][i]));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                        thread_local_cur_output_fp32_[thread_id].data() +\n                            i * config_.head_dim,\n                        std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                 new_attn_lse));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                        thread_local_output_fp32_[thread_id].data() +\n                            i * config_.head_dim,\n                        std::exp(thread_local_attn_lse_[thread_id][i] -\n                                 new_attn_lse));\n                    for (int j = 0; j < config_.head_dim; j++) {\n                        thread_local_cur_output_fp32_[thread_id]\n                                                     [i * config_.head_dim +\n                                                      j] +=\n                            thread_local_output_fp32_[thread_id]\n                                                     [i * config_.head_dim + j];\n                    }\n                    thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse;\n                }\n            } else {\n                if (cur_batch_idx != -1) {\n                    mutex_[cur_batch_idx][cur_head_id]->lock();\n                    for (int i = 0; i < n_gqa_; i++) {\n                        if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <\n                            1e-6) {\n                            attn_lse_[cur_batch_idx][cur_head_id][i] =\n                                thread_local_cur_attn_lse_[thread_id][i];\n                            for (int j = 0; j < config_.head_dim; j++) {\n                                output_fp32_[cur_batch_idx][cur_head_id]\n                                            [i * 
config_.head_dim + j] =\n                                                thread_local_cur_output_fp32_\n                                                    [thread_id]\n                                                    [i * config_.head_dim + j];\n                            }\n                            continue;\n                        }\n                        float new_attn_lse =\n                            attn_lse_[cur_batch_idx][cur_head_id][i] +\n                            std::log(\n                                1.0 +\n                                std::exp(\n                                    thread_local_cur_attn_lse_[thread_id][i] -\n                                    attn_lse_[cur_batch_idx][cur_head_id][i]));\n                        ggml_vec_scale_f32(\n                            config_.head_dim,\n                            output_fp32_[cur_batch_idx][cur_head_id].data() +\n                                i * config_.head_dim,\n                            std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -\n                                     new_attn_lse));\n                        ggml_vec_scale_f32(\n                            config_.head_dim,\n                            thread_local_cur_output_fp32_[thread_id].data() +\n                                i * config_.head_dim,\n                            std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                     new_attn_lse));\n                        for (int j = 0; j < config_.head_dim; j++) {\n                            output_fp32_[cur_batch_idx][cur_head_id]\n                                        [i * config_.head_dim + j] +=\n                                thread_local_cur_output_fp32_\n                                    [thread_id][i * config_.head_dim + j];\n                        }\n                        attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;\n                    }\n                    
mutex_[cur_batch_idx][cur_head_id]->unlock();\n                }\n                thread_cur_head_idx_[thread_id].first = batch_id;\n                thread_cur_head_idx_[thread_id].second = head_id;\n                for (int i = 0; i < n_gqa_; i++) {\n                    thread_local_cur_attn_lse_[thread_id][i] =\n                        thread_local_attn_lse_[thread_id][i];\n                    for (int j = 0; j < config_.head_dim; j++) {\n                        thread_local_cur_output_fp32_\n                            [thread_id][i * config_.head_dim + j] =\n                                thread_local_output_fp32_[thread_id]\n                                                         [i * config_.head_dim +\n                                                          j];\n                    }\n                }\n            }\n        },\n        // Merge the results of the remaining blocks.\n        [&](int thread_id) {\n            int cur_batch_idx = thread_cur_head_idx_[thread_id].first;\n            int cur_head_id = thread_cur_head_idx_[thread_id].second;\n            if (cur_head_id != -1) {\n                mutex_[cur_batch_idx][cur_head_id]->lock();\n                for (int i = 0; i < n_gqa_; i++) {\n                    float new_attn_lse;\n                    if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <\n                        1e-6) {\n                        attn_lse_[cur_batch_idx][cur_head_id][i] =\n                            thread_local_cur_attn_lse_[thread_id][i];\n                        for (int j = 0; j < config_.head_dim; j++) {\n                            output_fp32_[cur_batch_idx][cur_head_id]\n                                        [i * config_.head_dim + j] =\n                                            thread_local_cur_output_fp32_\n                                                [thread_id]\n                                                [i * config_.head_dim + j];\n                        }\n                        
continue;\n                    }\n                    new_attn_lse =\n                        attn_lse_[cur_batch_idx][cur_head_id][i] +\n                        std::log(\n                            1.0 +\n                            std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                     attn_lse_[cur_batch_idx][cur_head_id][i]));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                        output_fp32_[cur_batch_idx][cur_head_id].data() +\n                            i * config_.head_dim,\n                        std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -\n                                 new_attn_lse));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                        thread_local_cur_output_fp32_[thread_id].data() +\n                            i * config_.head_dim,\n                        std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                 new_attn_lse));\n                    for (int j = 0; j < config_.head_dim; j++) {\n                        output_fp32_[cur_batch_idx][cur_head_id]\n                                    [i * config_.head_dim + j] +=\n                            thread_local_cur_output_fp32_[thread_id]\n                                                         [i * config_.head_dim +\n                                                          j];\n                    }\n                    attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;\n                }\n                mutex_[cur_batch_idx][cur_head_id]->unlock();\n            }\n        });\n\n    // move the results to output and attn_lse\n    uint16_t *output_data = reinterpret_cast<uint16_t *>(output);\n    float *attn_lse_data = attn_lse;\n    for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n        for (int i = 0; i < config_.kv_head_num; i++) {\n            for (int j = 0; j < n_gqa_ * 
config_.head_dim; j++) {\n                output_data[batch_idx * config_.kv_head_num * n_gqa_ *\n                                config_.head_dim +\n                            i * n_gqa_ * config_.head_dim + j] =\n                    GGML_FP32_TO_FP16(output_fp32_[batch_idx][i][j]);\n            }\n            for (int j = 0; j < n_gqa_; j++) {\n                attn_lse_data[batch_idx * config_.kv_head_num * n_gqa_ +\n                              i * n_gqa_ + j] = attn_lse_[batch_idx][i][j];\n            }\n        }\n    }\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> diff = end - start;\n    //     printf(\"layer %d time of computing attention: %f s\\n\", layer_id_,\n    //     diff.count());\n}\n\nvoid KVCache::attn(const ggml_fp16_t *q_in, ggml_fp16_t *output,\n                   float *attn_lse, int layer_idx, int generate_token_idx,\n                   int q_len, int batch_size, int max_block_num,\n                   int *block_table, int *cache_seqlens, int pick_block_num,\n                   int init_block_num, int local_block_num, Backend *backend) {\n\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n    layer_id_ = layer_idx;\n    batch_size = batch_size * q_len;\n\n    const uint16_t *q_in_data = const_cast<const uint16_t *>(q_in);\n\n    quantize_q_(q_in_data, batch_size);\n    if (config_.retrieval_type == RetrievalType::LAYER) {\n        attn_initialize_layer_(batch_size, layer_idx, block_table,\n                               max_block_num, cache_seqlens);\n        retrieval_kvcache_layer_(q_in_data, init_block_num, local_block_num,\n                                 pick_block_num, q_len, generate_token_idx,\n                                 batch_size, layer_idx, cache_seqlens,\n                                 max_block_num, backend);\n        attention_layer_(q_in_data, output, attn_lse, batch_size, backend);\n    } else if (config_.retrieval_type 
== RetrievalType::KVHEAD) {\n        attn_initialize_kvhead_(batch_size, layer_idx, block_table,\n                                max_block_num, cache_seqlens);\n        retrieval_kvcache_kvhead_(q_in_data, init_block_num, local_block_num,\n                                  pick_block_num, q_len, generate_token_idx,\n                                  batch_size, layer_idx, cache_seqlens,\n                                  max_block_num, backend);\n        attention_kvhead_(q_in_data, output, attn_lse, batch_size, backend);\n    }\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> diff = end - start;\n    // printf(\"layer %d time of computing attention: %f s\\n\", layer_idx,\n    //        diff.count());\n}\n\nvoid KVCache::attn_with_kvcache(\n    const ggml_fp16_t *q_in, const ggml_fp16_t *k_in, const ggml_fp16_t *v_in,\n    ggml_fp16_t *output, float *attn_lse, int layer_idx, int generate_token_idx,\n    int q_len, int batch_size, int max_block_num, int *block_table,\n    int *cache_seqlens, int topk, int local, Backend *backend) {\n    //    printf(\"attn_with_kvcache start\\n\");\n    assert(q_len == 1);\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    layer_id_ = layer_idx;\n\n    update_kvcache_fp16(k_in, v_in, layer_idx, block_table, batch_size,\n                        max_block_num, cache_seqlens, q_len, backend);\n    //    printf(\"update finished.\\n\");\n\n    // cache_seqlens memory is modified.\n    for (int i = 0; i < batch_size; i++) {\n        cache_seqlens[i] += q_len;\n    }\n    int init_block_num = 1;\n    if (config_.block_len <= 32) {\n        init_block_num = 64 / config_.block_len;\n    }\n\n    attn(q_in, output, attn_lse, layer_idx, generate_token_idx, q_len,\n         batch_size, max_block_num, block_table, cache_seqlens, topk,\n         init_block_num, local, backend);\n\n    // Timer end\n    auto end = 
std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> diff = end - start;\n    //     printf(\"layer %d time of computing attention with kvcache: %f s\\n\",\n    //     layer_idx, diff.count());\n}\n\nvoid KVCache::quantize_q_(const uint16_t *q_in_data, int batch_size) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n    for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n        if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n            // quantize q\n            for (int i = 0; i < config_.kv_head_num; i++) {\n                for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {\n                    q_fp32_[batch_idx][i][j] = GGML_FP16_TO_FP32(\n                        q_in_data[batch_idx * config_.kv_head_num * n_gqa_ *\n                                      config_.head_dim +\n                                  i * n_gqa_ * config_.head_dim + j]);\n                }\n            }\n        } else {\n            // quantize q\n            for (int i = 0; i < config_.kv_head_num; i++) {\n                for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {\n                    q_fp32[j] = GGML_FP16_TO_FP32(\n                        q_in_data[batch_idx * config_.kv_head_num * n_gqa_ *\n                                      config_.head_dim +\n                                  i * n_gqa_ * config_.head_dim + j]);\n                }\n                quantize_row_q8_0(q_fp32.data(), q_q8_0_[batch_idx][i].data(),\n                                  n_gqa_ * config_.head_dim);\n            }\n        }\n    }\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    // printf(\"time of quantizing q: %f s\\n\",\n    //        std::chrono::duration<double>(end - start).count());\n}\nvoid KVCache::attn_initialize_layer_(int batch_size, int layer_idx,\n                                     int *block_table, int &max_block_num,\n                                     int 
*cache_seqlens) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n    for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n        // initialize output_fp32_ and attn_lse_\n        for (int i = 0; i < config_.kv_head_num; i++) {\n            for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {\n                output_fp32_[batch_idx][i][j] = 0;\n            }\n            for (int j = 0; j < n_gqa_; j++) {\n                attn_lse_[batch_idx][i][j] = 0;\n            }\n        }\n        // clear top_similar_block_\n\n        while (!top_similar_block_[batch_idx].empty())\n            top_similar_block_[batch_idx].pop();\n    }\n\n    // get block_table_before_retrieval_ and cache_seqlens_\n    if (block_table == nullptr) {\n        max_block_num = past_block_num_[layer_idx];\n        for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n            if (cache_total_len_ != 0)\n                cache_seqlens_[batch_idx] = cache_total_len_;\n            else\n                cache_seqlens_[batch_idx] = max_block_num * config_.block_len;\n            for (int i = 0; i < max_block_num; i++) {\n                block_table_before_retrieval_[batch_idx][i] = i;\n                block_similar_[batch_idx][i] = 0;\n            }\n        }\n    } else {\n        for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n            cache_seqlens_[batch_idx] = cache_seqlens[batch_idx];\n            for (int i = 0; i < max_block_num; i++) {\n                block_table_before_retrieval_[batch_idx][i] =\n                    block_table[batch_idx * max_block_num + i];\n                block_similar_[batch_idx][i] = 0;\n            }\n        }\n    }\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    // printf(\"layer %d time of initializing attention: %f s\\n\", layer_idx,\n    //        std::chrono::duration<double>(end - start).count());\n}\n\nvoid 
KVCache::calculate_block_similarity_layer_(\n    const uint16_t *q_in_data, int batch_size, int layer_idx, int q_len,\n    int max_block_num, int *cache_seqlens, int init_block_num,\n    int local_block_num, int pick_block_num, Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    if (batch_size == 1 &&\n        config_.anchor_num == 1) { // TODO: improve batch_size > 1\n        for (int batch_id = 0; batch_id < batch_size; batch_id++) {\n            if (q_len == 1) {\n                for (int j = 0; j < config_.head_dim * config_.q_head_num;\n                     j++) {\n                    avg_q[batch_id][j] = GGML_FP16_TO_FP32(\n                        q_in_data[batch_id * q_len * config_.q_head_num *\n                                      config_.head_dim +\n                                  j]);\n                    avg_q_fp16[batch_id][j] =\n                        q_in_data[batch_id * q_len * config_.q_head_num *\n                                      config_.head_dim +\n                                  j];\n                }\n            } else {\n                for (int j = 0; j < config_.head_dim * config_.q_head_num;\n                     j++) {\n                    avg_q[batch_id][j] = 0;\n                }\n                for (int i = 0; i < q_len; i++) {\n                    for (int j = 0; j < config_.head_dim; j++) {\n                        avg_q[batch_id][j] += GGML_FP16_TO_FP32(\n                            q_in_data[batch_id * q_len * config_.q_head_num *\n                                          config_.head_dim +\n                                      i * config_.q_head_num *\n                                          config_.head_dim +\n                                      j]);\n                    }\n                }\n                for (int j = 0; j < config_.head_dim * config_.q_head_num;\n                     j++) {\n                    avg_q[batch_id][j] /= q_len;\n             
       avg_q_fp16[batch_id][j] =\n                        GGML_FP32_TO_FP16(avg_q[batch_id][j]);\n                }\n            }\n            int seq_len = cache_seqlens_[batch_id];\n            int block_num = (seq_len / config_.block_len) - local_block_num -\n                            init_block_num;\n            if (block_num <= 0) {\n                continue;\n            }\n            bool is_seq = true;\n            for (int i = init_block_num + 1;\n                 i < (seq_len / config_.block_len) - local_block_num; i++) {\n                if (block_table_before_retrieval_[batch_id][i] !=\n                    block_table_before_retrieval_[batch_id][i - 1] + 1) {\n                    is_seq = false;\n                    break;\n                }\n            }\n            if (is_seq) {\n                int nth = backend->get_thread_num();\n                backend->do_work_stealing_job(\n                    nth, nullptr,\n                    [&](int task_id) {\n                        int ith = task_id;\n                        bool ok = llamafile_sgemm(\n                            block_num, 1, config_.q_head_num * config_.head_dim,\n                            anchor_.data() +\n                                (layer_idx * config_.max_block_num +\n                                 block_table_before_retrieval_\n                                     [batch_id][init_block_num]) *\n                                    config_.anchor_num * config_.q_head_num *\n                                    config_.head_dim,\n                            config_.q_head_num * config_.head_dim,\n                            avg_q_fp16[batch_id].data(),\n                            config_.q_head_num * config_.head_dim,\n                            block_similar_[batch_id].data() + init_block_num,\n                            block_num, ith, nth, GGML_TASK_TYPE_COMPUTE,\n                            GGML_TYPE_F16, GGML_TYPE_F16, GGML_TYPE_F32,\n                            
GGML_PREC_DEFAULT);\n                        if (!ok) {\n                            printf(\"llamafile_sgemm failed\\n\");\n                        }\n                    },\n                    nullptr);\n            } else {\n                backend->do_work_stealing_job(\n                    block_num, nullptr,\n                    [&](int task_id) {\n                        int block_id = task_id + init_block_num;\n                        int block_idx =\n                            block_table_before_retrieval_[batch_id][block_id];\n                        bool ok = llamafile_sgemm(\n                            1, 1, config_.q_head_num * config_.head_dim,\n                            anchor_.data() +\n                                (layer_idx * config_.max_block_num +\n                                 block_idx) *\n                                    config_.anchor_num * config_.q_head_num *\n                                    config_.head_dim,\n                            config_.q_head_num * config_.head_dim,\n                            avg_q_fp16[batch_id].data(),\n                            config_.q_head_num * config_.head_dim,\n                            block_similar_[batch_id].data() + block_id, 1, 0, 1,\n                            GGML_TASK_TYPE_COMPUTE, GGML_TYPE_F16,\n                            GGML_TYPE_F16, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n                        if (!ok) {\n                            printf(\"llamafile_sgemm failed\\n\");\n                        }\n                    },\n                    nullptr);\n            }\n        }\n    } else {\n        backend->do_work_stealing_job(\n            batch_size * max_block_num, nullptr,\n            [&](int task_id) {\n                int batch_id = task_id / max_block_num;\n                int block_id = task_id % max_block_num;\n                int seq_len = 
cache_seqlens_[batch_id];\n\n                if (block_id < init_block_num ||\n                    block_id >=\n                        (seq_len / config_.block_len) - local_block_num) {\n                    return;\n                }\n\n                int block_idx =\n                    block_table_before_retrieval_[batch_id][block_id];\n                float sim = 0;\n\n                for (int head_id = 0; head_id < config_.q_head_num; head_id++) {\n                    for (int i = 0; i < config_.head_dim; i++) {\n                        float q_i = 0,\n                              qa_i = std::numeric_limits<float>::lowest();\n                        for (int q_id = 0; q_id < q_len; q_id++) {\n                            q_i += GGML_FP16_TO_FP32(\n                                q_in_data[batch_id * q_len *\n                                              config_.q_head_num *\n                                              config_.head_dim +\n                                          q_id * config_.q_head_num *\n                                              config_.head_dim +\n                                          head_id * config_.head_dim + i]);\n                        }\n                        q_i /= q_len;\n                        for (int anchor_id = 0; anchor_id < config_.anchor_num;\n                             anchor_id++) {\n                            qa_i = std::max(\n                                qa_i,\n                                GGML_FP16_TO_FP32(\n                                    anchor_[(long long)layer_idx *\n                                                config_.max_block_num *\n                                                config_.anchor_num *\n                                                config_.q_head_num *\n                                                config_.head_dim +\n                                            block_idx * config_.anchor_num *\n                                                config_.q_head_num 
*\n                                                config_.head_dim +\n                                            anchor_id * config_.q_head_num *\n                                                config_.head_dim +\n                                            head_id * config_.head_dim + i]) *\n                                    q_i);\n                        }\n                        sim += qa_i;\n                    }\n                }\n                block_similar_[batch_id][block_id] = sim;\n            },\n            nullptr);\n    }\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> diff = end - start;\n    // printf(\"layer %d time of calculating similarity: %f s\\n\", layer_idx,\n    //        diff.count());\n}\n\nvoid KVCache::select_block_layer_(int batch_size, int layer_idx,\n                                  int max_block_num, int init_block_num,\n                                  int local_block_num, int pick_block_num) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n\n        if (cache_seqlens_[batch_idx] / config_.block_len <=\n            init_block_num + pick_block_num + local_block_num) {\n            block_table_after_retrieval_[batch_idx].swap(\n                block_table_before_retrieval_[batch_idx]);\n            selected_blocks_num_history_[(layer_idx - config_.layer_offset) /\n                                         config_.layer_step] = 0;\n            continue;\n        }\n\n        for (int block_id = init_block_num;\n             block_id <\n             (cache_seqlens_[batch_idx] / config_.block_len) - local_block_num;\n             block_id++) {\n            top_similar_block_[batch_idx].push(std::make_pair(\n                block_similar_[batch_idx][block_id],\n                block_table_before_retrieval_[batch_idx][block_id]));\n            if 
(top_similar_block_[batch_idx].size() > pick_block_num) {\n                top_similar_block_[batch_idx].pop();\n            }\n        }\n\n        int i = 0;\n        for (; i < init_block_num; i++) {\n            block_table_after_retrieval_[batch_idx][i] =\n                block_table_before_retrieval_[batch_idx][i];\n        }\n        while (!top_similar_block_[batch_idx].empty()) {\n            block_table_after_retrieval_[batch_idx][i] =\n                top_similar_block_[batch_idx].top().second;\n            top_similar_block_[batch_idx].pop();\n            i++;\n        }\n        for (; i < init_block_num + pick_block_num + local_block_num; i++) {\n            block_table_after_retrieval_[batch_idx][i] =\n                block_table_before_retrieval_[batch_idx]\n                                             [(cache_seqlens_[batch_idx] /\n                                               config_.block_len) -\n                                              local_block_num + i -\n                                              init_block_num - pick_block_num];\n        }\n        if (cache_seqlens_[batch_idx] % config_.block_len != 0) {\n            block_table_after_retrieval_[batch_idx][i] =\n                block_table_before_retrieval_[batch_idx][(\n                    cache_seqlens_[batch_idx] / config_.block_len)];\n            cache_seqlens_[batch_idx] =\n                (cache_seqlens_[batch_idx] % config_.block_len) +\n                i * config_.block_len;\n            i++;\n        } else {\n            cache_seqlens_[batch_idx] =\n                (cache_seqlens_[batch_idx] % config_.block_len) +\n                i * config_.block_len;\n        }\n        for (int j = 0; j < i; j++) {\n            selected_blocks_history_[(layer_idx - config_.layer_offset) /\n                                     config_.layer_step][batch_idx][j] =\n                block_table_after_retrieval_[batch_idx][j];\n        }\n        selected_blocks_num_history_[(layer_idx - 
config_.layer_offset) /\n                                     config_.layer_step] = i;\n    }\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> diff = end - start;\n    // printf(\"layer %d time of selecting blocks: %f s\\n\", layer_idx,\n    //        diff.count());\n}\n\n// Retrieve the kvcache: keep the first init_block_num blocks, the top\n// pick_block_num most similar blocks, and the last local_block_num blocks.\n// Each task calculates the similarity of a certain block with the query,\n// then pushes the block into the priority queue. Finally, the required\n// blocks are written into block_table_after_retrieval_.\nvoid KVCache::retrieval_kvcache_layer_(const uint16_t *q_in_data,\n                                       int init_block_num, int local_block_num,\n                                       int pick_block_num, int q_len,\n                                       int generate_token_idx, int batch_size,\n                                       int layer_idx, int *cache_seqlens,\n                                       int &max_block_num, Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n    max_block_num_after_retrieval_ = 0;\n    if (pick_block_num != -1 &&\n        (generate_token_idx % config_.token_step != 0 ||\n         (layer_idx % config_.layer_step != config_.layer_offset))) {\n\n        if (selected_blocks_num_history_[(layer_idx - config_.layer_offset) /\n                                         config_.layer_step] == 0) {\n            max_block_num_after_retrieval_ = max_block_num;\n            block_table_after_retrieval_.swap(block_table_before_retrieval_);\n        } else {\n            max_block_num_after_retrieval_ = selected_blocks_num_history_\n                [(layer_idx - config_.layer_offset) / config_.layer_step];\n            for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n                for (int i = 0; i < 
max_block_num_after_retrieval_; i++) {\n                    block_table_after_retrieval_[batch_idx][i] =\n                        selected_blocks_history_[(layer_idx -\n                                                  config_.layer_offset) /\n                                                 config_.layer_step][batch_idx]\n                                                [i];\n                }\n\n                if (cache_seqlens[batch_idx] % config_.block_len == 1) {\n                    selected_blocks_num_history_[(layer_idx -\n                                                  config_.layer_offset) /\n                                                 config_.layer_step] += 1;\n                    int x =\n                        selected_blocks_num_history_[(layer_idx -\n                                                      config_.layer_offset) /\n                                                     config_.layer_step];\n                    int last_block_idx =\n                        block_table_before_retrieval_[batch_idx]\n                                                     [cache_seqlens[batch_idx] /\n                                                      config_.block_len];\n                    selected_blocks_history_[(layer_idx -\n                                              config_.layer_offset) /\n                                             config_.layer_step][batch_idx]\n                                            [x - 1] = last_block_idx;\n                    block_table_after_retrieval_[batch_idx][x - 1] =\n                        last_block_idx;\n                }\n                cache_seqlens_[batch_idx] =\n                    (cache_seqlens_[batch_idx] % config_.block_len) +\n                    selected_blocks_num_history_[(layer_idx -\n                                                  config_.layer_offset) /\n                                                 config_.layer_step] *\n                        config_.block_len -\n                  
  config_.block_len;\n            }\n        }\n    } else if (pick_block_num != -1) {\n        max_block_num_after_retrieval_ =\n            std::min(max_block_num,\n                     init_block_num + pick_block_num + local_block_num + 1);\n        calculate_block_similarity_layer_(q_in_data, batch_size, layer_idx,\n                                          q_len, max_block_num, cache_seqlens,\n                                          init_block_num, local_block_num,\n                                          pick_block_num, backend);\n        select_block_layer_(batch_size, layer_idx, max_block_num,\n                            init_block_num, local_block_num, pick_block_num);\n    } else {\n        selected_blocks_num_history_[(layer_idx - config_.layer_offset) /\n                                     config_.layer_step] = 0;\n        max_block_num_after_retrieval_ = max_block_num;\n        block_table_after_retrieval_.swap(block_table_before_retrieval_);\n    }\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    //     printf(\"layer %d time of retrieval kvcache: %f s\\n\", layer_idx,\n    //     std::chrono::duration<double>(end - start).count());\n}\nvoid KVCache::calculate_sparsity_layer_(const uint16_t *q_in_data,\n                                        float *attn_sparsity, int batch_size,\n                                        int max_block_num, int *block_table,\n                                        int *cache_seqlens, Backend *backend\n\n) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n    seq_len_ = config_.block_len;\n    backend->do_work_stealing_job(\n        batch_size * config_.kv_head_num * max_block_num,\n        [&](int thread_id) {\n            thread_cur_head_idx_[thread_id].first = -1;\n            thread_cur_head_idx_[thread_id].second = -1;\n        },\n        [&](int task_id) {\n            int batch_id = task_id / (config_.kv_head_num * max_block_num);\n       
     int head_id = (task_id % (config_.kv_head_num * max_block_num)) /\n                          max_block_num;\n            int block_id = task_id % max_block_num;\n            int thread_id = Backend::thread_local_id;\n            // If the block is out of the sequence length, skip it.\n            if (cache_seqlens[batch_id] / config_.block_len < block_id) {\n                return;\n            }\n            int block_idx = block_table[batch_id * max_block_num + block_id];\n            if (cache_seqlens_[batch_id] / config_.block_len == block_id) {\n                int seq_len = cache_seqlens_[batch_id] % config_.block_len;\n                if (seq_len == 0)\n                    return;\n\n                // Prepare the attention mask for the last block.\n                int full_blocks = seq_len / 8;\n                int remaining_bits = seq_len % 8;\n                // Fill full blocks with 1s\n                for (int i = 0; i < full_blocks; ++i) {\n                    thread_local_attn_mask_[thread_id][i] = 0xFF;\n                }\n                // Fill the remaining bits in the next block\n                if (remaining_bits > 0 && full_blocks < seq_len_ / 8) {\n                    thread_local_attn_mask_[thread_id][full_blocks] =\n                        (1 << remaining_bits) - 1;\n                } else {\n                    thread_local_attn_mask_[thread_id][full_blocks] = 0;\n                }\n\n                for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) {\n                    thread_local_attn_mask_[thread_id][i] = 0;\n                }\n                if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,\n                        (void *)&q_in_data[batch_id * config_.kv_head_num *\n                                               n_gqa_ * config_.head_dim +\n          
                                 head_id * n_gqa_ * config_.head_dim],\n                        seq_len_, 0, false,\n                        thread_local_attn_mask_[thread_id].data(),\n                        GGML_TYPE_F16, 0,\n                        k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_F16, 1,\n                        v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, false,\n                        thread_local_attn_mask_[thread_id].data(),\n                        GGML_TYPE_Q4_0, 0,\n                        k_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_Q4_0, 1,\n                        v_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n 
                       thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, false,\n                        thread_local_attn_mask_[thread_id].data(),\n                        GGML_TYPE_Q8_0, 0,\n                        k_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_Q8_0, 1,\n                        v_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                }\n            } else {\n                if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,\n                        (void *)&q_in_data[batch_id * config_.kv_head_num *\n                                               n_gqa_ * config_.head_dim +\n                                           head_id * n_gqa_ * config_.head_dim],\n                        seq_len_, 0, true, nullptr, 
GGML_TYPE_F16, 0,\n                        k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_F16, 1,\n                        v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0, 0,\n                        k_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_Q4_0, 1,\n                        v_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n                    attn_with_kvcache_one_block_(\n                        
config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0, 0,\n                        k_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_Q8_0, 1,\n                        v_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                }\n            }\n            for (int i = 0; i < n_gqa_; i++) {\n                block_lse_[batch_id][block_idx][head_id * n_gqa_ + i] =\n                    thread_local_attn_lse_[thread_id][i];\n            }\n            int cur_batch_idx = thread_cur_head_idx_[thread_id].first;\n            int cur_head_id = thread_cur_head_idx_[thread_id].second;\n            if (batch_id == cur_batch_idx && head_id == cur_head_id) {\n                for (int i = 0; i < n_gqa_; i++) {\n                    float new_attn_lse =\n                        thread_local_cur_attn_lse_[thread_id][i] +\n                        std::log(\n                            1.0 +\n                            std::exp(thread_local_attn_lse_[thread_id][i] -\n                                     thread_local_cur_attn_lse_[thread_id][i]));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                    
    thread_local_cur_output_fp32_[thread_id].data() +\n                            i * config_.head_dim,\n                        std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                 new_attn_lse));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                        thread_local_output_fp32_[thread_id].data() +\n                            i * config_.head_dim,\n                        std::exp(thread_local_attn_lse_[thread_id][i] -\n                                 new_attn_lse));\n                    for (int j = 0; j < config_.head_dim; j++) {\n                        thread_local_cur_output_fp32_[thread_id]\n                                                     [i * config_.head_dim +\n                                                      j] +=\n                            thread_local_output_fp32_[thread_id]\n                                                     [i * config_.head_dim + j];\n                    }\n                    thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse;\n                }\n            } else {\n                if (cur_batch_idx != -1) {\n                    mutex_[cur_batch_idx][cur_head_id]->lock();\n                    for (int i = 0; i < n_gqa_; i++) {\n                        if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <\n                            1e-6) {\n                            attn_lse_[cur_batch_idx][cur_head_id][i] =\n                                thread_local_cur_attn_lse_[thread_id][i];\n                            for (int j = 0; j < config_.head_dim; j++) {\n                                output_fp32_[cur_batch_idx][cur_head_id]\n                                            [i * config_.head_dim + j] =\n                                                thread_local_cur_output_fp32_\n                                                    [thread_id]\n                                                    [i * config_.head_dim + 
j];\n                            }\n                            continue;\n                        }\n                        float new_attn_lse =\n                            attn_lse_[cur_batch_idx][cur_head_id][i] +\n                            std::log(\n                                1.0 +\n                                std::exp(\n                                    thread_local_cur_attn_lse_[thread_id][i] -\n                                    attn_lse_[cur_batch_idx][cur_head_id][i]));\n                        ggml_vec_scale_f32(\n                            config_.head_dim,\n                            output_fp32_[cur_batch_idx][cur_head_id].data() +\n                                i * config_.head_dim,\n                            std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -\n                                     new_attn_lse));\n                        ggml_vec_scale_f32(\n                            config_.head_dim,\n                            thread_local_cur_output_fp32_[thread_id].data() +\n                                i * config_.head_dim,\n                            std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                     new_attn_lse));\n                        for (int j = 0; j < config_.head_dim; j++) {\n                            output_fp32_[cur_batch_idx][cur_head_id]\n                                        [i * config_.head_dim + j] +=\n                                thread_local_cur_output_fp32_\n                                    [thread_id][i * config_.head_dim + j];\n                        }\n                        attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;\n                    }\n                    mutex_[cur_batch_idx][cur_head_id]->unlock();\n                }\n                thread_cur_head_idx_[thread_id].first = batch_id;\n                thread_cur_head_idx_[thread_id].second = head_id;\n                for (int i = 0; i < n_gqa_; i++) {\n                    
thread_local_cur_attn_lse_[thread_id][i] =\n                        thread_local_attn_lse_[thread_id][i];\n                    for (int j = 0; j < config_.head_dim; j++) {\n                        thread_local_cur_output_fp32_\n                            [thread_id][i * config_.head_dim + j] =\n                                thread_local_output_fp32_[thread_id]\n                                                         [i * config_.head_dim +\n                                                          j];\n                    }\n                }\n            }\n        },\n        // Merge the results of the remaining blocks.\n        [&](int thread_id) {\n            int cur_batch_idx = thread_cur_head_idx_[thread_id].first;\n            int cur_head_id = thread_cur_head_idx_[thread_id].second;\n            if (cur_head_id != -1) {\n                mutex_[cur_batch_idx][cur_head_id]->lock();\n                for (int i = 0; i < n_gqa_; i++) {\n                    float new_attn_lse;\n                    if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <\n                        1e-6) {\n                        attn_lse_[cur_batch_idx][cur_head_id][i] =\n                            thread_local_cur_attn_lse_[thread_id][i];\n                        for (int j = 0; j < config_.head_dim; j++) {\n                            output_fp32_[cur_batch_idx][cur_head_id]\n                                        [i * config_.head_dim + j] =\n                                            thread_local_cur_output_fp32_\n                                                [thread_id]\n                                                [i * config_.head_dim + j];\n                        }\n                        continue;\n                    }\n                    new_attn_lse =\n                        attn_lse_[cur_batch_idx][cur_head_id][i] +\n                        std::log(\n                            1.0 +\n                            
std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                     attn_lse_[cur_batch_idx][cur_head_id][i]));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                        output_fp32_[cur_batch_idx][cur_head_id].data() +\n                            i * config_.head_dim,\n                        std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -\n                                 new_attn_lse));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                        thread_local_cur_output_fp32_[thread_id].data() +\n                            i * config_.head_dim,\n                        std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                 new_attn_lse));\n                    for (int j = 0; j < config_.head_dim; j++) {\n                        output_fp32_[cur_batch_idx][cur_head_id]\n                                    [i * config_.head_dim + j] +=\n                            thread_local_cur_output_fp32_[thread_id]\n                                                         [i * config_.head_dim +\n                                                          j];\n                    }\n                    attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;\n                }\n                mutex_[cur_batch_idx][cur_head_id]->unlock();\n            }\n        });\n\n    for (int i = 0; i < batch_size; i++) {\n        for (int j = 0; j < max_block_num_after_retrieval_; j++) {\n            int block_idx = block_table_after_retrieval_[i][j];\n            for (int k = 0; k < config_.q_head_num; k++) {\n                attn_sparsity[i * config_.q_head_num + k] +=\n                    std::exp(block_lse_[i][block_idx][k] -\n                             attn_lse_[i][k / n_gqa_][k % n_gqa_]);\n            }\n        }\n    }\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    
std::chrono::duration<double> diff = end - start;\n    // printf(\"layer %d time of calculating sparsity: %f s\\n\", layer_id_,\n    //        diff.count());\n}\n\nvoid KVCache::attn_initialize_kvhead_(int batch_size, int layer_idx,\n                                      int *block_table, int &max_block_num,\n                                      int *cache_seqlens) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n    for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n        // initialize output_fp32_ and attn_lse_\n        for (int i = 0; i < config_.kv_head_num; i++) {\n            for (int j = 0; j < n_gqa_ * config_.head_dim; j++) {\n                output_fp32_[batch_idx][i][j] = 0;\n            }\n            for (int j = 0; j < n_gqa_; j++) {\n                attn_lse_[batch_idx][i][j] = 0;\n            }\n        }\n\n        // clear top_similar_block_\n        while (!top_similar_block_[batch_idx].empty())\n            top_similar_block_[batch_idx].pop();\n    }\n\n    for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n        cache_seqlens_[batch_idx] = cache_seqlens[batch_idx];\n        for (int i = 0; i < max_block_num; i++) {\n            for (int j = 0; j < config_.kv_head_num; j++) {\n                block_table_before_retrieval_kvhead_[batch_idx][i][j] =\n                    block_table[batch_idx * max_block_num + i];\n                block_similar_kv_head_[batch_idx][i][j] = 0;\n            }\n        }\n    }\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    // printf(\"layer %d time of initializing attn: %f s\\n\", layer_idx,\n    //        std::chrono::duration<double>(end - start).count());\n}\nvoid KVCache::retrieval_kvcache_kvhead_(const uint16_t *q_in_data,\n                                        int init_block_num, int local_block_num,\n                                        int pick_block_num, int q_len,\n                                      
  int generate_token_idx, int batch_size,\n                                        int layer_idx, int *cache_seqlens,\n                                        int &max_block_num, Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n    max_block_num_after_retrieval_ = 0;\n    if (pick_block_num != -1 &&\n        (generate_token_idx % config_.token_step != 0 ||\n         (layer_idx % config_.layer_step != config_.layer_offset))) {\n\n        if (selected_blocks_num_history_[(layer_idx - config_.layer_offset) /\n                                         config_.layer_step] == 0) {\n            max_block_num_after_retrieval_ = max_block_num;\n            for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n                for (int i = 0; i < max_block_num; i++) {\n                    for (int j = 0; j < config_.kv_head_num; j++) {\n                        block_table_after_retrieval_kvhead_[batch_idx][i][j] =\n                            block_table_before_retrieval_kvhead_[batch_idx][i]\n                                                                [j];\n                    }\n                }\n            }\n        } else {\n\n            max_block_num_after_retrieval_ = selected_blocks_num_history_\n                [(layer_idx - config_.layer_offset) / config_.layer_step];\n\n            for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n                for (int i = 0; i < max_block_num_after_retrieval_; i++) {\n                    for (int j = 0; j < config_.kv_head_num; j++) {\n                        block_table_after_retrieval_kvhead_[batch_idx][i][j] =\n                            selected_blocks_history_kvhead_\n                                [(layer_idx - config_.layer_offset) /\n                                 config_.layer_step][batch_idx][i][j];\n                    }\n                }\n\n                if (cache_seqlens[batch_idx] % config_.block_len == 1) {\n                 
   selected_blocks_num_history_[(layer_idx -\n                                                  config_.layer_offset) /\n                                                 config_.layer_step] += 1;\n                    int x =\n                        selected_blocks_num_history_[(layer_idx -\n                                                      config_.layer_offset) /\n                                                     config_.layer_step];\n                    for (int i = 0; i < config_.kv_head_num; i++) {\n                        int last_block_idx =\n                            block_table_before_retrieval_kvhead_\n                                [batch_idx][cache_seqlens[batch_idx] /\n                                            config_.block_len][i];\n                        selected_blocks_history_kvhead_[(layer_idx -\n                                                         config_.layer_offset) /\n                                                        config_.layer_step]\n                                                       [batch_idx][x - 1][i] =\n                                                           last_block_idx;\n                        block_table_after_retrieval_kvhead_[batch_idx][x - 1]\n                                                           [i] = last_block_idx;\n                    }\n                }\n                cache_seqlens_[batch_idx] = std::min(\n                    cache_seqlens_[batch_idx],\n                    (cache_seqlens_[batch_idx] % config_.block_len) +\n                        (init_block_num + pick_block_num + local_block_num) *\n                            config_.block_len);\n            }\n        }\n    } else if (pick_block_num != -1) {\n        max_block_num_after_retrieval_ =\n            std::min(max_block_num,\n                     init_block_num + pick_block_num + local_block_num + 1);\n        calculate_block_similarity_kvhead_(q_in_data, batch_size, layer_idx,\n                                         
  q_len, max_block_num, cache_seqlens,\n                                           init_block_num, local_block_num,\n                                           pick_block_num, backend);\n        select_block_kvhead_(batch_size, layer_idx, max_block_num,\n                             init_block_num, local_block_num, pick_block_num);\n    } else {\n        selected_blocks_num_history_[(layer_idx - config_.layer_offset) /\n                                     config_.layer_step] = 0;\n        max_block_num_after_retrieval_ = max_block_num;\n        for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n            for (int i = 0; i < max_block_num; i++) {\n                for (int j = 0; j < config_.kv_head_num; j++) {\n                    block_table_after_retrieval_kvhead_[batch_idx][i][j] =\n                        block_table_before_retrieval_kvhead_[batch_idx][i][j];\n                }\n            }\n        }\n    }\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    // printf(\"layer %d time of retrieval kvcache: %f s\\n\", layer_idx,\n    //        std::chrono::duration<double>(end - start).count());\n}\nvoid KVCache::calculate_sparsity_kvhead_(const uint16_t *q_in_data,\n                                         float *attn_sparsity, int batch_size,\n                                         int max_block_num, int *block_table,\n                                         int *cache_seqlens, Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n    seq_len_ = config_.block_len;\n    backend->do_work_stealing_job(\n        batch_size * config_.kv_head_num * max_block_num,\n        [&](int thread_id) {\n            thread_cur_head_idx_[thread_id].first = -1;\n            thread_cur_head_idx_[thread_id].second = -1;\n        },\n        [&](int task_id) {\n            int batch_id = task_id / (config_.kv_head_num * max_block_num);\n            int head_id = (task_id % 
(config_.kv_head_num * max_block_num)) /\n                          max_block_num;\n            int block_id = task_id % max_block_num;\n            int thread_id = Backend::thread_local_id;\n            // If the block is out of the sequence length, skip it.\n            if (cache_seqlens[batch_id] / config_.block_len < block_id) {\n                return;\n            }\n            int block_idx = block_table[batch_id * max_block_num + block_id];\n            if (cache_seqlens_[batch_id] / config_.block_len == block_id) {\n                int seq_len = cache_seqlens_[batch_id] % config_.block_len;\n                if (seq_len == 0)\n                    return;\n\n                // Prepare the attention mask for the last block.\n                int full_blocks = seq_len / 8;\n                int remaining_bits = seq_len % 8;\n\n                // Fill full blocks with 1s\n                for (int i = 0; i < full_blocks; ++i) {\n                    thread_local_attn_mask_[thread_id][i] = 0xFF;\n                }\n                // Fill the remaining bits in the next block\n                if (remaining_bits > 0 && full_blocks < seq_len_ / 8) {\n                    thread_local_attn_mask_[thread_id][full_blocks] =\n                        (1 << remaining_bits) - 1;\n                } else {\n                    thread_local_attn_mask_[thread_id][full_blocks] = 0;\n                }\n\n                for (int i = full_blocks + 1; i < seq_len_ / 8; ++i) {\n                    thread_local_attn_mask_[thread_id][i] = 0;\n                }\n                if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,\n                        (void *)&q_in_data[batch_id * config_.kv_head_num *\n                                               n_gqa_ * config_.head_dim +\n                                      
     head_id * n_gqa_ * config_.head_dim],\n                        seq_len_, 0, false,\n                        thread_local_attn_mask_[thread_id].data(),\n                        GGML_TYPE_F16, 0,\n                        k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_F16, 1,\n                        v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, false,\n                        thread_local_attn_mask_[thread_id].data(),\n                        GGML_TYPE_Q4_0, 0,\n                        k_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_Q4_0, 1,\n                        v_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n                        
thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, false,\n                        thread_local_attn_mask_[thread_id].data(),\n                        GGML_TYPE_Q8_0, 0,\n                        k_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_Q8_0, 1,\n                        v_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                }\n            } else {\n                if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num, GGML_TYPE_F16,\n                        (void *)&q_in_data[batch_id * config_.kv_head_num *\n                                               n_gqa_ * config_.head_dim +\n                                           head_id * n_gqa_ * config_.head_dim],\n                        seq_len_, 0, true, nullptr, GGML_TYPE_F16, 0,\n             
           k_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_F16, 1,\n                        v_cache_fp16_[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, true, nullptr, GGML_TYPE_Q4_0, 0,\n                        k_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_Q4_0, 1,\n                        v_cache_q4[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n                    attn_with_kvcache_one_block_(\n                        config_.head_dim,\n                        
config_.q_head_num / config_.kv_head_num,\n                        GGML_TYPE_Q8_0, q_q8_0_[batch_id][head_id].data(),\n                        seq_len_, 0, true, nullptr, GGML_TYPE_Q8_0, 0,\n                        k_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr, GGML_TYPE_Q8_0, 1,\n                        v_cache_q8[layer_id_][head_id][block_idx].data(), 0,\n                        nullptr, nullptr,\n                        thread_local_attn_score_[thread_id].data(),\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_attn_lse_[thread_id].data(),\n                        thread_local_draft_[thread_id].data(), nullptr,\n                        cos_.data(), sin_.data());\n                    dequantize_row_q8_0(\n                        thread_local_output_q8_0_[thread_id].data(),\n                        thread_local_output_fp32_[thread_id].data(),\n                        n_gqa_ * config_.head_dim);\n                }\n            }\n            for (int i = 0; i < n_gqa_; i++) {\n                block_lse_[batch_id][block_idx][head_id * n_gqa_ + i] =\n                    thread_local_attn_lse_[thread_id][i];\n            }\n            int cur_batch_idx = thread_cur_head_idx_[thread_id].first;\n            int cur_head_id = thread_cur_head_idx_[thread_id].second;\n            if (batch_id == cur_batch_idx && head_id == cur_head_id) {\n                for (int i = 0; i < n_gqa_; i++) {\n                    float new_attn_lse =\n                        thread_local_cur_attn_lse_[thread_id][i] +\n                        std::log(\n                            1.0 +\n                            std::exp(thread_local_attn_lse_[thread_id][i] -\n                                     thread_local_cur_attn_lse_[thread_id][i]));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                        
thread_local_cur_output_fp32_[thread_id].data() +\n                            i * config_.head_dim,\n                        std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                 new_attn_lse));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                        thread_local_output_fp32_[thread_id].data() +\n                            i * config_.head_dim,\n                        std::exp(thread_local_attn_lse_[thread_id][i] -\n                                 new_attn_lse));\n                    for (int j = 0; j < config_.head_dim; j++) {\n                        thread_local_cur_output_fp32_[thread_id]\n                                                     [i * config_.head_dim +\n                                                      j] +=\n                            thread_local_output_fp32_[thread_id]\n                                                     [i * config_.head_dim + j];\n                    }\n                    thread_local_cur_attn_lse_[thread_id][i] = new_attn_lse;\n                }\n            } else {\n                if (cur_batch_idx != -1) {\n                    mutex_[cur_batch_idx][cur_head_id]->lock();\n                    for (int i = 0; i < n_gqa_; i++) {\n                        if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <\n                            1e-6) {\n                            attn_lse_[cur_batch_idx][cur_head_id][i] =\n                                thread_local_cur_attn_lse_[thread_id][i];\n                            for (int j = 0; j < config_.head_dim; j++) {\n                                output_fp32_[cur_batch_idx][cur_head_id]\n                                            [i * config_.head_dim + j] =\n                                                thread_local_cur_output_fp32_\n                                                    [thread_id]\n                                                    [i * config_.head_dim + j];\n   
                         }\n                            continue;\n                        }\n                        float new_attn_lse =\n                            attn_lse_[cur_batch_idx][cur_head_id][i] +\n                            std::log(\n                                1.0 +\n                                std::exp(\n                                    thread_local_cur_attn_lse_[thread_id][i] -\n                                    attn_lse_[cur_batch_idx][cur_head_id][i]));\n                        ggml_vec_scale_f32(\n                            config_.head_dim,\n                            output_fp32_[cur_batch_idx][cur_head_id].data() +\n                                i * config_.head_dim,\n                            std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -\n                                     new_attn_lse));\n                        ggml_vec_scale_f32(\n                            config_.head_dim,\n                            thread_local_cur_output_fp32_[thread_id].data() +\n                                i * config_.head_dim,\n                            std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                     new_attn_lse));\n                        for (int j = 0; j < config_.head_dim; j++) {\n                            output_fp32_[cur_batch_idx][cur_head_id]\n                                        [i * config_.head_dim + j] +=\n                                thread_local_cur_output_fp32_\n                                    [thread_id][i * config_.head_dim + j];\n                        }\n                        attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;\n                    }\n                    mutex_[cur_batch_idx][cur_head_id]->unlock();\n                }\n                thread_cur_head_idx_[thread_id].first = batch_id;\n                thread_cur_head_idx_[thread_id].second = head_id;\n                for (int i = 0; i < n_gqa_; i++) {\n                    
thread_local_cur_attn_lse_[thread_id][i] =\n                        thread_local_attn_lse_[thread_id][i];\n                    for (int j = 0; j < config_.head_dim; j++) {\n                        thread_local_cur_output_fp32_\n                            [thread_id][i * config_.head_dim + j] =\n                                thread_local_output_fp32_[thread_id]\n                                                         [i * config_.head_dim +\n                                                          j];\n                    }\n                }\n            }\n        },\n        // Merge the results of the remaining blocks.\n        [&](int thread_id) {\n            int cur_batch_idx = thread_cur_head_idx_[thread_id].first;\n            int cur_head_id = thread_cur_head_idx_[thread_id].second;\n            if (cur_head_id != -1) {\n                mutex_[cur_batch_idx][cur_head_id]->lock();\n                for (int i = 0; i < n_gqa_; i++) {\n                    float new_attn_lse;\n                    if (std::abs(attn_lse_[cur_batch_idx][cur_head_id][i]) <\n                        1e-6) {\n                        attn_lse_[cur_batch_idx][cur_head_id][i] =\n                            thread_local_cur_attn_lse_[thread_id][i];\n                        for (int j = 0; j < config_.head_dim; j++) {\n                            output_fp32_[cur_batch_idx][cur_head_id]\n                                        [i * config_.head_dim + j] =\n                                            thread_local_cur_output_fp32_\n                                                [thread_id]\n                                                [i * config_.head_dim + j];\n                        }\n                        continue;\n                    }\n                    new_attn_lse =\n                        attn_lse_[cur_batch_idx][cur_head_id][i] +\n                        std::log(\n                            1.0 +\n                            
std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                     attn_lse_[cur_batch_idx][cur_head_id][i]));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                        output_fp32_[cur_batch_idx][cur_head_id].data() +\n                            i * config_.head_dim,\n                        std::exp(attn_lse_[cur_batch_idx][cur_head_id][i] -\n                                 new_attn_lse));\n                    ggml_vec_scale_f32(\n                        config_.head_dim,\n                        thread_local_cur_output_fp32_[thread_id].data() +\n                            i * config_.head_dim,\n                        std::exp(thread_local_cur_attn_lse_[thread_id][i] -\n                                 new_attn_lse));\n                    for (int j = 0; j < config_.head_dim; j++) {\n                        output_fp32_[cur_batch_idx][cur_head_id]\n                                    [i * config_.head_dim + j] +=\n                            thread_local_cur_output_fp32_[thread_id]\n                                                         [i * config_.head_dim +\n                                                          j];\n                    }\n                    attn_lse_[cur_batch_idx][cur_head_id][i] = new_attn_lse;\n                }\n                mutex_[cur_batch_idx][cur_head_id]->unlock();\n            }\n        });\n\n    for (int i = 0; i < batch_size; i++) {\n        for (int j = 0; j < max_block_num_after_retrieval_; j++) {\n            for (int k = 0; k < config_.q_head_num; k++) {\n                int block_idx =\n                    block_table_after_retrieval_kvhead_[i][j][k / n_gqa_];\n                attn_sparsity[i * config_.q_head_num + k] +=\n                    std::exp(block_lse_[i][block_idx][k] -\n                             attn_lse_[i][k / n_gqa_][k % n_gqa_]);\n            }\n        }\n    }\n\n    // Timer end\n    auto end = 
std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> diff = end - start;\n    // printf(\"layer %d time of calculating sparsity: %f s\\n\", layer_id_,\n    //        diff.count());\n}\nvoid KVCache::calculate_block_similarity_kvhead_(\n    const uint16_t *q_in_data, int batch_size, int layer_idx, int q_len,\n    int max_block_num, int *cache_seqlens, int init_block_num,\n    int local_block_num, int pick_block_num, Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n    backend->do_work_stealing_job(\n        batch_size * max_block_num, nullptr,\n        [&](int task_id) {\n            int batch_id = task_id / max_block_num;\n            int block_id = task_id % max_block_num;\n            int seq_len = cache_seqlens_[batch_id];\n\n            if (block_id < init_block_num ||\n                block_id >= (seq_len / config_.block_len) - local_block_num) {\n                return;\n            }\n            int block_idx =\n                block_table_before_retrieval_kvhead_[batch_id][block_id][0];\n\n            for (int head_id = 0; head_id < config_.q_head_num; head_id++) {\n                for (int i = 0; i < config_.head_dim; i++) {\n                    float q_i = 0, qa_i = std::numeric_limits<float>::lowest();\n                    for (int q_id = 0; q_id < q_len; q_id++) {\n                        q_i += GGML_FP16_TO_FP32(\n                            q_in_data[batch_id * q_len * config_.q_head_num *\n                                          config_.head_dim +\n                                      q_id * config_.q_head_num *\n                                          config_.head_dim +\n                                      head_id * config_.head_dim + i]);\n                    }\n                    q_i /= q_len;\n                    for (int anchor_id = 0; anchor_id < config_.anchor_num;\n                         anchor_id++) {\n                        qa_i = std::max(\n           
                 qa_i,\n                            GGML_FP16_TO_FP32(\n                                anchor_[layer_idx * config_.max_block_num *\n                                            config_.anchor_num *\n                                            config_.q_head_num *\n                                            config_.head_dim +\n                                        block_idx * config_.anchor_num *\n                                            config_.q_head_num *\n                                            config_.head_dim +\n                                        anchor_id * config_.q_head_num *\n                                            config_.head_dim +\n                                        head_id * config_.head_dim + i]) *\n                                q_i);\n                    }\n                    block_similar_kv_head_[batch_id][block_id]\n                                          [head_id / n_gqa_] += qa_i;\n                }\n            }\n        },\n        nullptr);\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> diff = end - start;\n    // printf(\"layer %d time of calculating similarity: %f s\\n\", layer_idx,\n    //        diff.count());\n}\nvoid KVCache::select_block_kvhead_(int batch_size, int layer_idx,\n                                   int max_block_num, int init_block_num,\n                                   int local_block_num, int pick_block_num) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {\n        int cache_len_after_retrieval = 0;\n        if (cache_seqlens_[batch_idx] / config_.block_len <=\n            init_block_num + pick_block_num + local_block_num) {\n            selected_blocks_num_history_[(layer_idx - config_.layer_offset) /\n                                         config_.layer_step] = 0;\n            for (int i = 0; i < 
max_block_num; i++) {\n                for (int j = 0; j < config_.kv_head_num; j++) {\n                    block_table_after_retrieval_kvhead_[batch_idx][i][j] =\n                        block_table_before_retrieval_kvhead_[batch_idx][i][j];\n                }\n            }\n            continue;\n        }\n        for (int head_id = 0; head_id < config_.kv_head_num; head_id++) {\n\n            for (int block_id = init_block_num;\n                 block_id < (cache_seqlens_[batch_idx] / config_.block_len) -\n                                local_block_num;\n                 block_id++) {\n\n                top_similar_block_[batch_idx].push(std::make_pair(\n                    block_similar_kv_head_[batch_idx][block_id][head_id],\n                    block_table_before_retrieval_kvhead_[batch_idx][block_id]\n                                                        [head_id]));\n                if (top_similar_block_[batch_idx].size() > pick_block_num) {\n                    top_similar_block_[batch_idx].pop();\n                }\n            }\n\n            int i = 0;\n            for (; i < init_block_num; i++) {\n                block_table_after_retrieval_kvhead_[batch_idx][i][head_id] =\n                    block_table_before_retrieval_kvhead_[batch_idx][i][head_id];\n            }\n            while (!top_similar_block_[batch_idx].empty()) {\n                block_table_after_retrieval_kvhead_[batch_idx][i][head_id] =\n                    top_similar_block_[batch_idx].top().second;\n                top_similar_block_[batch_idx].pop();\n                i++;\n            }\n            for (; i < init_block_num + pick_block_num + local_block_num; i++) {\n                block_table_after_retrieval_kvhead_[batch_idx][i][head_id] =\n                    block_table_before_retrieval_kvhead_\n                        [batch_idx]\n                        [(cache_seqlens_[batch_idx] / config_.block_len) -\n                         local_block_num + i - init_block_num 
- pick_block_num]\n                        [head_id];\n            }\n            if (cache_seqlens_[batch_idx] % config_.block_len != 0) {\n                block_table_after_retrieval_kvhead_[batch_idx][i][head_id] =\n                    block_table_before_retrieval_kvhead_[batch_idx][(\n                        cache_seqlens_[batch_idx] / config_.block_len)]\n                                                        [head_id];\n                cache_len_after_retrieval =\n                    (cache_seqlens_[batch_idx] % config_.block_len) +\n                    i * config_.block_len;\n                i++;\n            } else {\n                cache_len_after_retrieval =\n                    (cache_seqlens_[batch_idx] % config_.block_len) +\n                    i * config_.block_len;\n            }\n            for (int j = 0; j < i; j++) {\n                selected_blocks_history_kvhead_\n                    [(layer_idx - config_.layer_offset) / config_.layer_step]\n                    [batch_idx][j][head_id] =\n                        block_table_after_retrieval_kvhead_[batch_idx][j]\n                                                           [head_id];\n            }\n        }\n        cache_seqlens_[batch_idx] = cache_len_after_retrieval;\n        selected_blocks_num_history_[(layer_idx - config_.layer_offset) /\n                                     config_.layer_step] =\n            (cache_len_after_retrieval + config_.block_len - 1) /\n            config_.block_len;\n    }\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> diff = end - start;\n    // printf(\"layer %d time of selecting block: %f s\\n\", layer_idx,\n    //        diff.count())\n}\n\nvoid KVCache::get_attn_sparsity(const ggml_fp16_t *q_in, float *attn_sparsity,\n                                int layer_idx, int generate_token_idx,\n                                int q_len, int batch_size, int max_block_num,\n                       
         int *block_table, int *cache_seqlens,\n                                int *block_table_origin,\n                                int *cache_seqlens_origin,\n                                int max_block_num_origin, int topk, int local,\n                                Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n    layer_id_ = layer_idx;\n    int thread_num = backend->get_thread_num();\n    batch_size = 1;\n\n    const uint16_t *q_in_data = const_cast<const uint16_t *>(q_in);\n\n    quantize_q_(q_in_data, batch_size);\n    if (config_.retrieval_type == RetrievalType::LAYER) {\n        attn_initialize_layer_(batch_size, layer_idx, block_table,\n                               max_block_num, cache_seqlens);\n        retrieval_kvcache_layer_(q_in_data, 1, local, topk, q_len,\n                                 generate_token_idx, batch_size, layer_idx,\n                                 cache_seqlens, max_block_num, backend);\n        calculate_sparsity_layer_(q_in_data, attn_sparsity, batch_size,\n                                  max_block_num_origin, block_table_origin,\n                                  cache_seqlens_origin, backend);\n    } else if (config_.retrieval_type == RetrievalType::KVHEAD) {\n        attn_initialize_kvhead_(batch_size, layer_idx, block_table,\n                                max_block_num, cache_seqlens);\n        retrieval_kvcache_kvhead_(q_in_data, 1, local, topk, q_len,\n                                  generate_token_idx, batch_size, layer_idx,\n                                  cache_seqlens, max_block_num, backend);\n        calculate_sparsity_kvhead_(q_in_data, attn_sparsity, batch_size,\n                                   max_block_num_origin, block_table_origin,\n                                   cache_seqlens_origin, backend);\n    }\n}\n\nvoid KVCache::attn_with_kvcache_one_block_(\n    int head_dim, int bsz,\n    ggml_type q_type, // GGML data type of `Q`, only 
supports fp16 and q8_0\n    // [bsz, head_dim]\n    // Quantization is always on the head_dim dimension (per_token). If\n    // head_dim % 32 != 0, an error will be raised. The size must be bsz *\n    // head_dim/32 * qtype_size.\n    const void *q,\n\n    int past_kv_len, int past_kv_offset,\n    bool is_full_attn, // true indicates a full 1 mask\n    // If is_full_attn = false, a bit matrix representing the mask is\n    // passed. [bsz, past_kv_len]\n    const uint8_t *attn_mask,\n\n    ggml_type k_type, // GGML data type of `K Cache`, only supports fp16,\n                      // q4_0, q8_0\n    int k_quant_type, // 0 for per_token, 1 for per_channel, others raise an\n                      // error\n    // [seq_len, head_dim]\n    // If quant_type == 0, head_dim % 32 must be 0.\n    // If quant_type == 1, seq_len % 32 must be 0.\n    const void *k_cache,\n\n    // k_anchor_type must be fp16\n    int num_k_anchor, // num_k_anchor == 0 indicates no anchor\n    // [num_k_anchor, head_dim]\n    const void *k_cache_anchors,\n    // Each token is associated with the nearest previous position's anchor,\n    // with the same distance.\n    const int *k_cache_anchor_pos,\n\n    // v_cache similar to k_cache\n    ggml_type v_type, int v_quant_type,\n    // [head_dim, seq_len]\n    const void *v_cache, int num_v_anchor, const void *v_cache_anchors,\n    const int *v_cache_anchor_pos,\n\n    // Pre-allocated buffer for intermediate calculations [bsz,\n    // past_kv_len]. 
No malloc is performed inside this function.\n    float *attn_score,\n\n    // Output: [bsz, head_dim], with the same type as q_type\n    void *output,\n    // [bsz]\n    float *lse,\n\n    // Pre-allocated temporary buffer with sufficient size:\n    // (2 * bsz * past_kv_len + 6 * bsz * head_dim + 2 * past_kv_len *\n    // head_dim + past_kv_len * head_dim / 32) bytes.\n    void *draft,\n\n    // Apply rotary embedding online\n    const int *rotary_angle, const void *rotary_cos, const void *rotary_sin\n    // rotary_cos=None,\n    // rotary_sin=None,\n    // cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,\n    // cache_batch_idx: Optional[torch.Tensor] = None,\n    // rotary_interleaved=True,\n\n    // // Not supported for now\n    // window_size=(-1, -1),  # -1 means infinite context window\n    // alibi_slopes=None,\n) {\n    assert(head_dim % 32 == 0);\n    assert(k_quant_type == 0);\n    assert(v_quant_type == 1);\n    assert(q_type == GGML_TYPE_F16 || q_type == GGML_TYPE_Q8_0);\n    if (q_type == GGML_TYPE_F16) {\n        assert(k_type == GGML_TYPE_F16);\n        assert(v_type == GGML_TYPE_F16);\n\n        // attn = q * k + q * k_anchor\n        // TODO: anchor\n        assert(num_k_anchor == 0);\n\n        if (rotary_angle != nullptr) {\n            ggml_fp16_t *k_cache_with_rope_fp16 =\n                (reinterpret_cast<ggml_fp16_t *>(draft) +\n                 sizeof(block_q8_0) * bsz * past_kv_len / QK8_0 +\n                 sizeof(float) * bsz * head_dim);\n            // dequant k_cache and apply rope\n            // k_rope(i) = k(i) * cos(i) - k(i+l) * sin(i)\n            // k_rope(i+l) = k(i+l) * cos(i+l) + k(i) * sin(i)\n\n            // k(i)cos(i) -> k_rope(i)\n            // k(i)sin(i+l) -> k_rope(i+l)\n\n            // k(i)cos(i) -> k_rope(i)\n            // -k(i)sin(i-l) -> k_rope(i-l)\n\n            std::vector<float> block_fp32(32);\n            for (int k = 0; k < past_kv_len; k++) {\n                int angle = rotary_angle[k];\n  
              for (int l = 0; l < head_dim / 32; l++) {\n                    for (int m = 0; m < 32; m++) {\n                        float x = GGML_FP16_TO_FP32((\n                            (ggml_fp16_t *)k_cache)[k * head_dim + l * 32 + m]);\n                        float sin_val = GGML_FP16_TO_FP32(\n                            ((ggml_fp16_t *)\n                                 rotary_sin)[angle * head_dim + l * 32 + m]);\n                        float cos_val = GGML_FP16_TO_FP32(\n                            ((ggml_fp16_t *)\n                                 rotary_cos)[angle * head_dim + l * 32 + m]);\n\n                        if (l * 32 + m < head_dim / 2) {\n                            k_cache_with_rope_fp16[k * head_dim + l * 32 + m] =\n                                GGML_FP32_TO_FP16(x * cos_val);\n                            k_cache_with_rope_fp16[k * head_dim + l * 32 + m +\n                                                   head_dim / 2] =\n                                GGML_FP32_TO_FP16(-x * sin_val);\n                        } else {\n                            k_cache_with_rope_fp16[k * head_dim + l * 32 + m] =\n                                GGML_FP32_TO_FP16(\n                                    GGML_FP16_TO_FP32(\n                                        k_cache_with_rope_fp16[k * head_dim +\n                                                               l * 32 + m]) +\n                                    x * sin_val);\n                            k_cache_with_rope_fp16[k * head_dim + l * 32 + m -\n                                                   head_dim / 2] =\n                                GGML_FP32_TO_FP16(\n                                    GGML_FP16_TO_FP32(\n                                        k_cache_with_rope_fp16[k * head_dim +\n                                                               l * 32 + m -\n                                                               head_dim / 2]) -\n                                    x 
* cos_val);\n                        }\n                    }\n                }\n            }\n\n            bool ok = llamafile_sgemm(\n                past_kv_len, bsz, head_dim,\n                (ggml_fp16_t *)k_cache_with_rope_fp16, head_dim,\n                (ggml_fp16_t *)q, head_dim, attn_score, past_kv_len, 0, 1,\n                GGML_TASK_TYPE_COMPUTE, k_type, GGML_TYPE_F16, GGML_TYPE_F32,\n                GGML_PREC_DEFAULT);\n\n            if (!ok) {\n                printf(\"llamafile_sgemm failed\\n\");\n            }\n        } else {\n            bool ok = llamafile_sgemm(\n                past_kv_len, bsz, head_dim, (ggml_fp16_t *)k_cache, head_dim,\n                (ggml_fp16_t *)q, head_dim, attn_score, past_kv_len, 0, 1,\n                GGML_TASK_TYPE_COMPUTE, k_type, GGML_TYPE_F16, GGML_TYPE_F32,\n                GGML_PREC_DEFAULT);\n\n            if (!ok) {\n                printf(\"llamafile_sgemm failed\\n\");\n            }\n        }\n        // attn = attn * scale\n        float scale_factor = 1.0 / std::sqrt(float(head_dim));\n        ggml_vec_scale_f32(bsz * past_kv_len, attn_score, scale_factor);\n\n        // attn = attn & mask\n        if (!is_full_attn) {\n            for (int i = 0; i < bsz; i++) {\n                for (int j = 0; j < past_kv_len; j++) {\n                    int index = i * past_kv_len + j;\n                    if (!(attn_mask[j / 8] & (1 << (j % 8)))) {\n                        attn_score[index] =\n                            std::numeric_limits<float>::lowest();\n                    }\n                }\n            }\n        }\n\n        // attn = softmax(attn)\n        for (int i = 0; i < bsz; i++) {\n            float sum_exp = 0;\n            for (int j = 0; j < past_kv_len; j++) {\n                attn_score[i * past_kv_len + j] =\n                    std::exp(attn_score[i * past_kv_len + j]);\n                sum_exp += attn_score[i * past_kv_len + j];\n            }\n            for (int j = 0; j < past_kv_len; j++) {\n                attn_score[i * past_kv_len + j] /= sum_exp;\n 
           }\n            if (lse != nullptr) {\n                lse[i] = std::log(sum_exp);\n            }\n        }\n\n        // output = attn * v + attn * v_anchor\n        // std::vector<float> sum(bsz * head_dim);\n        float *sum = reinterpret_cast<float *>(reinterpret_cast<char *>(draft) +\n                                               sizeof(block_q8_0) * bsz *\n                                                   past_kv_len / QK8_0);\n\n        // float* attn_score_fp16(bsz, past_kv_len)\n        ggml_fp16_t *attn_score_fp16 = (reinterpret_cast<ggml_fp16_t *>(\n            reinterpret_cast<char *>(draft) +\n            sizeof(block_q8_0) * bsz * past_kv_len / QK8_0 +\n            sizeof(float) * bsz * head_dim));\n\n        for (int i = 0; i < bsz * past_kv_len; i++) {\n            attn_score_fp16[i] = GGML_FP32_TO_FP16(attn_score[i]);\n        }\n\n        // TODO: anchor\n        assert(num_v_anchor == 0);\n        bool ok = llamafile_sgemm(\n            head_dim, bsz, past_kv_len, (ggml_fp16_t *)v_cache, past_kv_len,\n            (ggml_fp16_t *)attn_score_fp16, past_kv_len, sum, head_dim, 0, 1,\n            GGML_TASK_TYPE_COMPUTE, v_type, GGML_TYPE_F16, GGML_TYPE_F32,\n            GGML_PREC_DEFAULT);\n        if (!ok) {\n            printf(\"llamafile_sgemm failed\\n\");\n        }\n\n        // copy to output\n        for (int i = 0; i < bsz; i++) {\n            for (int j = 0; j < head_dim; j++) {\n                ((float *)output)[i * head_dim + j] = sum[i * head_dim + j];\n            }\n        }\n    } else {\n        assert(k_type == GGML_TYPE_Q4_0 || k_type == GGML_TYPE_Q8_0);\n        assert(v_type == GGML_TYPE_Q4_0 || v_type == GGML_TYPE_Q8_0);\n\n        // attn = q * k + q * k_anchor\n        // TODO: anchor\n        assert(num_k_anchor == 0);\n\n        if (rotary_angle != nullptr) {\n            ggml_fp16_t *k_cache_with_rope_fp16 =\n                (reinterpret_cast<ggml_fp16_t *>(draft) +\n                 sizeof(block_q8_0) * bsz * 
past_kv_len / QK8_0 +\n                 sizeof(float) * bsz * head_dim);\n            block_q4_0 *k_cache_with_rope_q4 =\n                (reinterpret_cast<block_q4_0 *>(draft) +\n                 sizeof(block_q8_0) * bsz * past_kv_len / QK8_0 +\n                 sizeof(float) * bsz * head_dim) +\n                sizeof(ggml_fp16_t) * bsz * head_dim;\n            // dequant k_cache and apply rope\n            // k_rope(i) = k(i) * cos(i) - k(i+l) * sin(i)\n            // k_rope(i+l) = k(i+l) * cos(i+l) + k(i) * sin(i)\n\n            // k(i)cos(i) -> k_rope(i)\n            // k(i)sin(i+l) -> k_rope(i+l)\n\n            // k(i)cos(i) -> k_rope(i)\n            // -k(i)sin(i-l) -> k_rope(i-l)\n\n            std::vector<float> block_fp32(32);\n            for (int k = 0; k < past_kv_len; k++) {\n                int angle = rotary_angle[k];\n                for (int l = 0; l < head_dim / 32; l++) {\n                    block_q4_0 block =\n                        ((block_q4_0 *)k_cache)[k * head_dim / 32 + l];\n                    dequantize_row_q4_0(&block, block_fp32.data(), 32);\n                    for (int m = 0; m < 32; m++) {\n                        float sin_val = GGML_FP16_TO_FP32(\n                            ((ggml_fp16_t *)\n                                 rotary_sin)[angle * head_dim + l * 32 + m]);\n                        float cos_val = GGML_FP16_TO_FP32(\n                            ((ggml_fp16_t *)\n                                 rotary_cos)[angle * head_dim + l * 32 + m]);\n\n                        if (l * 32 + m < head_dim / 2) {\n                            k_cache_with_rope_fp16[k * head_dim + l * 32 + m] =\n                                GGML_FP32_TO_FP16(block_fp32[m] * cos_val);\n                            k_cache_with_rope_fp16[k * head_dim + l * 32 + m +\n                                                   head_dim / 2] =\n                                GGML_FP32_TO_FP16(-block_fp32[m] * sin_val);\n                        } else {\n        
                    k_cache_with_rope_fp16[k * head_dim + l * 32 + m] +=\n                                GGML_FP32_TO_FP16(block_fp32[m] * sin_val);\n                            k_cache_with_rope_fp16[k * head_dim + l * 32 + m -\n                                                   head_dim / 2] -=\n                                GGML_FP32_TO_FP16(block_fp32[m] * cos_val);\n                        }\n                    }\n                }\n            }\n            // quantize k_cache_with_rope_fp16\n            for (int k = 0; k < past_kv_len; k++) {\n                for (int l = 0; l < head_dim / 32; l++) {\n                    for (int m = 0; m < 32; m++) {\n                        block_fp32[m] = GGML_FP16_TO_FP32(\n                            k_cache_with_rope_fp16[k * head_dim + l * 32 + m]);\n                    }\n                    quantize_row_q4_0(\n                        block_fp32.data(),\n                        &k_cache_with_rope_q4[k * head_dim / 32 + l], 32);\n                }\n            }\n\n            llamafile_sgemm(past_kv_len, bsz, head_dim / 32,\n                            (block_q4_0 *)k_cache_with_rope_q4, head_dim / 32,\n                            (block_q8_0 *)q, head_dim / 32, attn_score,\n                            past_kv_len, 0, 1, GGML_TASK_TYPE_COMPUTE, k_type,\n                            GGML_TYPE_Q8_0, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n        } else {\n            llamafile_sgemm(past_kv_len, bsz, head_dim / 32,\n                            (block_q4_0 *)k_cache, head_dim / 32,\n                            (block_q8_0 *)q, head_dim / 32, attn_score,\n                            past_kv_len, 0, 1, GGML_TASK_TYPE_COMPUTE, k_type,\n                            GGML_TYPE_Q8_0, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n        }\n\n        // attn = attn * scale\n        float scale_factor = 1.0 / std::sqrt(float(head_dim));\n        ggml_vec_scale_f32(bsz * past_kv_len, attn_score, scale_factor);\n\n        // attn = attn & 
mask\n        if (!is_full_attn) {\n            for (int i = 0; i < bsz; i++) {\n                for (int j = 0; j < past_kv_len; j++) {\n                    int index = i * past_kv_len + j;\n                    if (!(attn_mask[j / 8] & (1 << (j % 8)))) {\n                        attn_score[index] =\n                            std::numeric_limits<float>::lowest();\n                    }\n                }\n            }\n        }\n\n        // attn = softmax(attn)\n        for (int i = 0; i < bsz; i++) {\n            float sum_exp = 0;\n            for (int j = 0; j < past_kv_len; j++) {\n                attn_score[i * past_kv_len + j] =\n                    std::exp(attn_score[i * past_kv_len + j]);\n                sum_exp += attn_score[i * past_kv_len + j];\n            }\n            for (int j = 0; j < past_kv_len; j++) {\n                attn_score[i * past_kv_len + j] /= sum_exp;\n            }\n            if (lse != nullptr) {\n                lse[i] = std::log(sum_exp);\n            }\n        }\n\n        // output = attn * v + attn * v_anchor\n        // std::vector<block_q8_0> attn_q8_0(bsz * past_kv_len / QK8_0);\n        block_q8_0 *attn_q8_0 = reinterpret_cast<block_q8_0 *>(draft);\n        quantize_row_q8_0(attn_score, attn_q8_0, bsz * past_kv_len);\n        // std::vector<float> sum(bsz * head_dim);\n        float *sum = reinterpret_cast<float *>(reinterpret_cast<char *>(draft) +\n                                               sizeof(block_q8_0) * bsz *\n                                                   past_kv_len / QK8_0);\n        // TODO: anchor\n        assert(num_v_anchor == 0);\n        llamafile_sgemm(head_dim, bsz, past_kv_len / 32, (block_q4_0 *)v_cache,\n                        past_kv_len / 32, attn_q8_0, past_kv_len / 32, sum,\n                        head_dim, 0, 1, GGML_TASK_TYPE_COMPUTE, v_type,\n                        GGML_TYPE_Q8_0, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n\n        quantize_row_q8_0(sum, (block_q8_0 *)output, bsz 
* head_dim);\n    }\n}\n"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/operators/kvcache/kvcache_load_dump.cpp",
    "content": "/**\n * @Description  :\n * @Author       : Jianwei Dong\n * @Date         : 2024-08-26 22:47:06\n * @Version      : 1.0.0\n * @LastEditors  : Jianwei Dong\n * @LastEditTime : 2024-08-26 22:47:06\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n\n#include \"kvcache.h\"\n\n#include <chrono>\n\nvoid KVCache::load_kvcache(std::string tensor_file_path, Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n    std::ifstream ifs_tensor(tensor_file_path, std::ios::binary);\n    if (!ifs_tensor) {\n        throw std::runtime_error(\"Failed to open tensor file\");\n    }\n    ifs_tensor.read(reinterpret_cast<char *>(&cache_total_len_),\n                    sizeof(cache_total_len_));\n    int past_block_num =\n        (cache_total_len_ + config_.block_len - 1) / config_.block_len;\n    printf(\"cache_total_len: %d, past_block_num: %d\\n\", cache_total_len_,\n           past_block_num);\n    for (int i = 0; i < config_.layer_num; ++i) {\n        past_block_num_[i] = past_block_num;\n    }\n    ifs_tensor.read(reinterpret_cast<char *>(anchor_.data()),\n                    anchor_.size() * sizeof(ggml_fp16_t));\n    for (int i = 0; i < config_.layer_num; ++i) {\n        for (int j = 0; j < config_.kv_head_num; ++j) {\n            for (int k = 0; k < past_block_num_[i]; ++k) {\n                if (config_.kv_type == GGML_TYPE_F16) {\n                    ifs_tensor.read(\n                        reinterpret_cast<char *>(k_cache_fp16_[i][j][k].data()),\n                        k_cache_fp16_[i][j][k].size() * sizeof(ggml_fp16_t));\n                    ifs_tensor.read(\n                        reinterpret_cast<char *>(v_cache_fp16_[i][j][k].data()),\n                        v_cache_fp16_[i][j][k].size() * sizeof(ggml_fp16_t));\n                } else if (config_.kv_type == GGML_TYPE_Q4_0) {\n                    ifs_tensor.read(\n                        reinterpret_cast<char 
*>(k_cache_q4[i][j][k].data()),\n                        k_cache_q4[i][j][k].size() * sizeof(block_q4_0));\n                    ifs_tensor.read(\n                        reinterpret_cast<char *>(v_cache_q4[i][j][k].data()),\n                        v_cache_q4[i][j][k].size() * sizeof(block_q4_0));\n                }\n            }\n        }\n        for (int k = 0; k < past_block_num_[i]; ++k) {\n            for (int l = 0; l < config_.block_len; l++) {\n                ifs_tensor.read(\n                    reinterpret_cast<char *>(importance_[i][k][l].data()),\n                    importance_[i][k][l].size() * sizeof(ggml_fp16_t));\n            }\n        }\n    }\n    ifs_tensor.close();\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> diff = end - start;\n    printf(\"time of load: %f s\\n\", diff.count());\n}\nvoid KVCache::dump_kvcache(int *block_table, int cache_total_len,\n                           std::string tensor_file_path, Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n    std::ofstream ofs(tensor_file_path, std::ios::binary);\n    printf(\"dump_kvcache: %s\\n\", tensor_file_path.c_str());\n    if (!ofs.is_open()) {\n        std::cerr << \"Cannot open file \" << tensor_file_path << std::endl;\n        return;\n    }\n    ofs.write(reinterpret_cast<const char *>(&cache_total_len),\n              sizeof(cache_total_len));\n    int past_block_num =\n        (cache_total_len + config_.block_len - 1) / config_.block_len;\n    printf(\"cache_total_len: %d, past_block_num: %d\\n\", cache_total_len,\n           past_block_num);\n    ofs.write(reinterpret_cast<const char *>(anchor_.data()),\n              anchor_.size() * sizeof(ggml_fp16_t));\n    for (int i = 0; i < config_.layer_num; ++i) {\n        for (int j = 0; j < config_.kv_head_num; ++j) {\n            for (int k = 0; k < past_block_num; ++k) {\n                int block_idx = 
block_table[k];\n                if (config_.kv_type == GGML_TYPE_F16) {\n                    ofs.write(reinterpret_cast<const char *>(\n                                  k_cache_fp16_[i][j][block_idx].data()),\n                              k_cache_fp16_[i][j][block_idx].size() *\n                                  sizeof(ggml_fp16_t));\n                    ofs.write(reinterpret_cast<const char *>(\n                                  v_cache_fp16_[i][j][block_idx].data()),\n                              v_cache_fp16_[i][j][block_idx].size() *\n                                  sizeof(ggml_fp16_t));\n\n                } else if (config_.kv_type == GGML_TYPE_Q4_0) {\n                    ofs.write(reinterpret_cast<const char *>(\n                                  k_cache_q4[i][j][block_idx].data()),\n                              k_cache_q4[i][j][block_idx].size() *\n                                  sizeof(block_q4_0));\n                    ofs.write(reinterpret_cast<const char *>(\n                                  v_cache_q4[i][j][block_idx].data()),\n                              v_cache_q4[i][j][block_idx].size() *\n                                  sizeof(block_q4_0));\n                }\n            }\n        }\n        for (int k = 0; k < past_block_num; ++k) {\n            int block_idx = block_table[k];\n            for (int l = 0; l < config_.block_len; l++) {\n                ofs.write(reinterpret_cast<const char *>(\n                              importance_[i][block_idx][l].data()),\n                          importance_[i][block_idx][l].size() *\n                              sizeof(ggml_fp16_t));\n            }\n        }\n    }\n    ofs.close();\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> diff = end - start;\n    printf(\"time of dump: %f s\\n\", diff.count());\n}"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/operators/kvcache/kvcache_read_write.cpp",
    "content": "/**\n * @Description  :\n * @Author       : Jianwei Dong\n * @Date         : 2024-08-26 22:47:06\n * @Version      : 1.0.0\n * @LastEditors  : Jianwei Dong\n * @LastEditTime : 2024-08-26 22:47:06\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n\n#include \"kvcache.h\"\n\n#include <chrono>\n\nvoid KVCache::get_anchor_one_block(ggml_fp16_t *anchor, int layer_id,\n                                   int block_idx, Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    layer_id_ = layer_id;\n    block_idx = block_idx;\n    seq_len_ = config_.block_len;\n    anchor_data_ = const_cast<uint16_t *>(anchor);\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> duration = end - start;\n    printf(\"layer %d block %d time of reading anchor: %f s\\n\", layer_id,\n           block_idx, duration.count());\n}\n\nvoid KVCache::update_anchor_one_block(const ggml_fp16_t *anchor, int layer_id,\n                                      int block_idx, Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    layer_id_ = layer_id;\n    block_idx = block_idx;\n    seq_len_ = config_.block_len;\n    anchor_data_ = const_cast<uint16_t *>(anchor);\n\n    // Each task updates the anchor of a certain position\n    // backend->do_work_stealing_job(config_.anchor_num, [&](int task_id) {\n    //     int k = task_id % config_.anchor_num;\n    //     int head_id = task_id / config_.anchor_num;\n    //     memcpy(anchor_[layer_id_][head_id][block_idx].data() +\n    //                k * config_.head_dim,\n    //            anchor_data_ + k * config_.head_dim,\n    //            sizeof(uint16_t) * config_.head_dim);\n    // });\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> duration = end - start;\n    printf(\"layer %d block %d time of 
writing anchor: %f s\\n\", layer_id,\n           block_idx, duration.count());\n}\n\nvoid KVCache::update_importance_one_block(const ggml_fp16_t *importance,\n                                          int layer_id, int block_idx,\n                                          Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    layer_id_ = layer_id;\n    block_idx = block_idx;\n    seq_len_ = config_.block_len;\n    importance_data_ = const_cast<uint16_t *>(importance);\n\n    // Each task updates the importance of a certain position\n    backend->do_work_stealing_job(\n        config_.block_len, nullptr,\n        [&](int task_id) {\n            int k = task_id;\n            memcpy(importance_[layer_id_][block_idx].data() + k,\n                   importance_data_ + k, sizeof(uint16_t));\n        },\n        nullptr);\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> duration = end - start;\n    printf(\"layer %d block %d time of writing importance: %f s\\n\", layer_id,\n           block_idx, duration.count());\n}\n\nvoid KVCache::get_importance_one_block(ggml_fp16_t *importance, int layer_id,\n                                       int block_idx, Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    layer_id_ = layer_id;\n    block_idx = block_idx;\n    seq_len_ = config_.block_len;\n    importance_data_ = const_cast<uint16_t *>(importance);\n\n    // Each task reads the importance of a certain position\n    backend->do_work_stealing_job(\n        config_.block_len, nullptr,\n        [&](int task_id) {\n            int k = task_id;\n            memcpy(importance_data_ + k,\n                   importance_[layer_id_][block_idx].data() + k,\n                   sizeof(uint16_t));\n        },\n        nullptr);\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    
std::chrono::duration<double> duration = end - start;\n    printf(\"layer %d block %d time of reading importance: %f s\\n\", layer_id,\n           block_idx, duration.count());\n}\n\nvoid KVCache::update_kvcache_one_block_fp16(const ggml_fp16_t *k_in,\n                                            const ggml_fp16_t *v_in,\n                                            int layer_id, int block_idx,\n                                            Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    layer_id_ = layer_id;\n    block_idx = block_idx;\n    seq_len_ = config_.block_len;\n    k_data_ = const_cast<uint16_t *>(k_in);\n    v_data_ = const_cast<uint16_t *>(v_in);\n\n    int new_block_num = std::max((int)past_block_num_[layer_id], block_idx + 1);\n\n    importance_[layer_id_].resize(new_block_num);\n\n    for (int i = 0; i < config_.kv_head_num; i++) {\n        k_cache_q4[layer_id][i].resize(new_block_num);\n        v_cache_q4[layer_id][i].resize(new_block_num);\n        // anchor_[layer_id][i].resize(new_block_num);\n    }\n\n    for (int i = 0; i < new_block_num; i++) {\n        importance_[layer_id][i].resize(config_.block_len);\n    }\n\n    // Each task updates the k cache or v cache of a certain header\n    backend->do_work_stealing_job(\n        config_.kv_head_num * 2, nullptr,\n        [&](int task_id) {\n            std::vector<float> block_fp32(32);\n            int head_id = task_id / 2;\n            if (task_id & 1) {\n                // fill k_cache_\n                k_cache_q4[layer_id_][head_id][block_idx].resize(\n                    config_.block_len * config_.head_dim / 32);\n                for (int k = 0; k < config_.block_len; k++) {\n                    for (int l = 0; l < config_.head_dim / 32; l++) {\n                        block_q4_0 block;\n                        for (int m = 0; m < 32; m++) {\n\n                            block_fp32[m] = GGML_FP16_TO_FP32(\n                            
    k_data_[((0 * config_.kv_head_num + head_id) *\n                                             seq_len_ +\n                                         0 * config_.block_len + k) *\n                                            config_.head_dim +\n                                        l * 32 + m]);\n                        }\n                        quantize_row_q4_0(block_fp32.data(), &block, 32);\n                        k_cache_q4[layer_id_][head_id][block_idx]\n                                  [k * config_.head_dim / 32 + l] = block;\n                    }\n                }\n            } else {\n                // fill v_cache_\n                v_cache_q4[layer_id_][head_id][block_idx].resize(\n                    config_.head_dim * config_.block_len / 32);\n                for (int k = 0; k < config_.block_len / 32; k++) {\n                    for (int l = 0; l < config_.head_dim; l++) {\n                        block_q4_0 block;\n                        for (int m = 0; m < 32; m++) {\n\n                            block_fp32[m] = GGML_FP16_TO_FP32(\n                                v_data_[((0 * config_.kv_head_num + head_id) *\n                                             seq_len_ +\n                                         0 * config_.block_len + k * 32 + m) *\n                                            config_.head_dim +\n                                        l]);\n                        }\n                        quantize_row_q4_0(block_fp32.data(), &block, 32);\n                        v_cache_q4[layer_id_][head_id][block_idx]\n                                  [l * config_.block_len / 32 + k] = block;\n                    }\n                }\n            }\n        },\n        nullptr);\n    past_block_num_[layer_id] = new_block_num;\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> duration = end - start;\n    printf(\"layer %d block %d time of writing KV Cache: %f s\\n\", 
layer_id,\n           block_idx, duration.count());\n    // printf(\"get_one_block_fp16 duration: %ld\\n\", duration);\n}\n\nvoid KVCache::get_kvcache_one_block_fp16(ggml_fp16_t *k_in, ggml_fp16_t *v_in,\n                                         int layer_id, int block_idx,\n                                         Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    layer_id_ = layer_id;\n    seq_len_ = config_.block_len;\n    k_data_ = reinterpret_cast<uint16_t *>(k_in);\n    v_data_ = reinterpret_cast<uint16_t *>(v_in);\n\n    // printf(\"layer_id: %d, block_idx: %d\\n\", layer_id, block_idx);\n    // Each task gets the k cache or v cache of a certain header\n    backend->do_work_stealing_job(\n        config_.kv_head_num * 2, nullptr,\n        [&](int task_id) {\n            std::vector<float> block_fp32(32);\n            int head_id = task_id / 2;\n            if (task_id & 1) {\n                // get k_cache_\n                for (int k = 0; k < config_.block_len; k++) {\n                    for (int l = 0; l < config_.head_dim / 32; l++) {\n                        block_q4_0 block =\n                            k_cache_q4[layer_id_][head_id][block_idx]\n                                      [k * config_.head_dim / 32 + l];\n                        dequantize_row_q4_0(&block, block_fp32.data(), 32);\n                        for (int m = 0; m < 32; m++) {\n\n                            k_data_[((0 * config_.kv_head_num + head_id) *\n                                         seq_len_ +\n                                     0 * config_.block_len + k) *\n                                        config_.head_dim +\n                                    l * 32 + m] =\n                                GGML_FP32_TO_FP16(block_fp32[m]);\n                        }\n                    }\n                }\n            } else {\n                // get v_cache_\n                for (int k = 0; k < config_.block_len / 32; 
k++) {\n                    for (int l = 0; l < config_.head_dim; l++) {\n                        block_q4_0 block =\n                            v_cache_q4[layer_id_][head_id][block_idx]\n                                      [l * config_.block_len / 32 + k];\n                        dequantize_row_q4_0(&block, block_fp32.data(), 32);\n                        for (int m = 0; m < 32; m++) {\n\n                            v_data_[((0 * config_.kv_head_num + head_id) *\n                                         seq_len_ +\n                                     0 * config_.block_len + k * 32 + m) *\n                                        config_.head_dim +\n                                    l] = GGML_FP32_TO_FP16(block_fp32[m]);\n                        }\n                    }\n                }\n            }\n        },\n        nullptr);\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> duration = end - start;\n    printf(\"layer %d block %d time of reading KV Cache: %f s\\n\", layer_id,\n           block_idx, duration.count());\n    // printf(\"get_one_block_fp16 duration: %ld\\n\", duration);\n}\n\n// k_in: (batch_size, seq_len, head_num, head_dim)\n// v_in: (batch_size, seq_len, head_num, head_dim)\nvoid KVCache::get_and_update_kvcache_fp16(ggml_fp16_t *k_in, ggml_fp16_t *v_in,\n                                          int layer_id, int *block_table,\n                                          int batch_size, int max_block_num,\n                                          int *cache_seqlens, int q_len,\n                                          Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    layer_id_ = layer_id;\n    k_data_ = const_cast<uint16_t *>(k_in);\n    v_data_ = const_cast<uint16_t *>(v_in);\n\n    // Each task updates the k cache and v cache of a certain header\n    backend->do_work_stealing_job(\n        config_.kv_head_num * 
max_block_num * batch_size, nullptr,\n        [&](int task_id) {\n            // printf(\"block_idx: %d, task_id: %d\\n\", block_idx, task_id);\n            std::vector<float> block_fp32(32);\n            int batch_id = task_id / (config_.kv_head_num * max_block_num);\n            int block_id = (task_id / config_.kv_head_num) % max_block_num;\n            int head_id = task_id % config_.kv_head_num;\n            int block_idx = block_table[batch_id * max_block_num + block_id];\n            int seq_len = cache_seqlens[batch_id];\n            int block_l = block_id * config_.block_len;\n            int block_r = block_id * config_.block_len + config_.block_len;\n\n            if (block_l < seq_len) {\n                if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n                    for (int k = 0; k < config_.block_len; k++) {\n                        if (block_id * config_.block_len + k >= seq_len)\n                            break;\n                        for (int l = 0; l < config_.head_dim; l++) {\n                            k_data_\n                                [batch_id *\n                                     (max_block_num * config_.block_len *\n                                      config_.kv_head_num * config_.head_dim) +\n                                 block_id *\n                                     (config_.block_len * config_.kv_head_num *\n                                      config_.head_dim) +\n                                 k * (config_.kv_head_num * config_.head_dim) +\n                                 head_id * config_.head_dim + l] =\n                                    k_cache_fp16_[layer_id_][head_id][block_idx]\n                                                 [k * config_.head_dim + l];\n                            v_data_\n                                [batch_id *\n                                     (max_block_num * config_.block_len *\n                                      config_.kv_head_num * config_.head_dim) +\n     
                            block_id *\n                                     (config_.block_len * config_.kv_head_num *\n                                      config_.head_dim) +\n                                 k * (config_.kv_head_num * config_.head_dim) +\n                                 head_id * config_.head_dim + l] =\n                                    v_cache_fp16_[layer_id_][head_id][block_idx]\n                                                 [l * config_.block_len + k];\n                        }\n                    }\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n                    // get k_cache_\n                    for (int k = 0; k < config_.block_len; k++) {\n                        if (block_id * config_.block_len + k >= seq_len)\n                            break;\n                        for (int l = 0; l < config_.head_dim / 32; l++) {\n                            block_q4_0 block =\n                                k_cache_q4[layer_id_][head_id][block_idx]\n                                          [k * config_.head_dim / 32 + l];\n                            dequantize_row_q4_0(&block, block_fp32.data(), 32);\n                            for (int m = 0; m < 32; m++) {\n\n                                k_data_[batch_id *\n                                            (max_block_num * config_.block_len *\n                                             config_.kv_head_num *\n                                             config_.head_dim) +\n                                        block_id * (config_.block_len *\n                                                    config_.kv_head_num *\n                                                    config_.head_dim) +\n                                        k * (config_.kv_head_num *\n                                             config_.head_dim) +\n                                        head_id * config_.head_dim + l * 32 +\n                                        m] = 
GGML_FP32_TO_FP16(block_fp32[m]);\n                            }\n                        }\n                    }\n                    // get v_cache_\n                    for (int k = 0; k < config_.block_len / 32; k++) {\n                        for (int l = 0; l < config_.head_dim; l++) {\n                            block_q4_0 block =\n                                v_cache_q4[layer_id_][head_id][block_idx]\n                                          [l * config_.block_len / 32 + k];\n                            dequantize_row_q4_0(&block, block_fp32.data(), 32);\n                            for (int m = 0; m < 32; m++) {\n\n                                if (block_id * config_.block_len + k * 32 + m >=\n                                    seq_len)\n                                    break;\n                                v_data_[batch_id *\n                                            (max_block_num * config_.block_len *\n                                             config_.kv_head_num *\n                                             config_.head_dim) +\n                                        block_id * (config_.block_len *\n                                                    config_.kv_head_num *\n                                                    config_.head_dim) +\n                                        (k * 32 + m) * config_.kv_head_num *\n                                            config_.head_dim +\n                                        head_id * config_.head_dim + l] =\n                                    GGML_FP32_TO_FP16(block_fp32[m]);\n                            }\n                        }\n                    }\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n                    // get k_cache_\n                    for (int k = 0; k < config_.block_len; k++) {\n                        if (block_id * config_.block_len + k >= seq_len)\n                            break;\n                        for (int l = 0; l 
< config_.head_dim / 32; l++) {\n                            block_q8_0 block =\n                                k_cache_q8[layer_id_][head_id][block_idx]\n                                          [k * config_.head_dim / 32 + l];\n                            dequantize_row_q8_0(&block, block_fp32.data(), 32);\n                            for (int m = 0; m < 32; m++) {\n\n                                k_data_[batch_id *\n                                            (max_block_num * config_.block_len *\n                                             config_.kv_head_num *\n                                             config_.head_dim) +\n                                        block_id * (config_.block_len *\n                                                    config_.kv_head_num *\n                                                    config_.head_dim) +\n                                        k * (config_.kv_head_num *\n                                             config_.head_dim) +\n                                        head_id * config_.head_dim + l * 32 +\n                                        m] = GGML_FP32_TO_FP16(block_fp32[m]);\n                            }\n                        }\n                    }\n                    // get v_cache_\n                    for (int k = 0; k < config_.block_len / 32; k++) {\n                        for (int l = 0; l < config_.head_dim; l++) {\n                            block_q8_0 block =\n                                v_cache_q8[layer_id_][head_id][block_idx]\n                                          [l * config_.block_len / 32 + k];\n                            dequantize_row_q8_0(&block, block_fp32.data(), 32);\n                            for (int m = 0; m < 32; m++) {\n\n                                if (block_id * config_.block_len + k * 32 + m >=\n                                    seq_len)\n                                    break;\n                                v_data_[batch_id *\n                
                            (max_block_num * config_.block_len *\n                                             config_.kv_head_num *\n                                             config_.head_dim) +\n                                        block_id * (config_.block_len *\n                                                    config_.kv_head_num *\n                                                    config_.head_dim) +\n                                        (k * 32 + m) * config_.kv_head_num *\n                                            config_.head_dim +\n                                        head_id * config_.head_dim + l] =\n                                    GGML_FP32_TO_FP16(block_fp32[m]);\n                            }\n                        }\n                    }\n                }\n            }\n            if (block_r > seq_len && block_l < seq_len + q_len) {\n                if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n                    for (int k = 0; k < config_.block_len; k++) {\n                        if (block_id * config_.block_len + k >=\n                                seq_len + q_len ||\n                            block_id * config_.block_len + k < seq_len)\n                            continue;\n                        for (int l = 0; l < config_.head_dim; l++) {\n                            k_cache_fp16_[layer_id_][head_id][block_idx]\n                                         [k * config_.head_dim + l] = k_data_\n                                             [batch_id * (max_block_num *\n                                                          config_.block_len *\n                                                          config_.kv_head_num *\n                                                          config_.head_dim) +\n                                              block_id * (config_.block_len *\n                                                          config_.kv_head_num *\n                                                
          config_.head_dim) +\n                                              k * (config_.kv_head_num *\n                                                   config_.head_dim) +\n                                              head_id * config_.head_dim + l];\n                            v_cache_fp16_[layer_id_][head_id][block_idx]\n                                         [l * config_.block_len + k] = v_data_\n                                             [batch_id * (max_block_num *\n                                                          config_.block_len *\n                                                          config_.kv_head_num *\n                                                          config_.head_dim) +\n                                              block_id * (config_.block_len *\n                                                          config_.kv_head_num *\n                                                          config_.head_dim) +\n                                              k * (config_.kv_head_num *\n                                                   config_.head_dim) +\n                                              head_id * config_.head_dim + l];\n                        }\n                    }\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n                    // fill k_cache_\n                    for (int k = 0; k < config_.block_len; k++) {\n                        if (block_id * config_.block_len + k >=\n                                seq_len + q_len ||\n                            block_id * config_.block_len + k < seq_len)\n                            continue;\n                        for (int l = 0; l < config_.head_dim / 32; l++) {\n                            block_q4_0 block;\n                            for (int m = 0; m < 32; m++) {\n\n                                block_fp32[m] = GGML_FP16_TO_FP32(\n                                    k_data_[batch_id * (max_block_num *\n                         
                               config_.block_len *\n                                                        config_.kv_head_num *\n                                                        config_.head_dim) +\n                                            block_id * (config_.block_len *\n                                                        config_.kv_head_num *\n                                                        config_.head_dim) +\n                                            k * (config_.kv_head_num *\n                                                 config_.head_dim) +\n                                            head_id * config_.head_dim +\n                                            l * 32 + m]);\n                            }\n                            quantize_row_q4_0(block_fp32.data(), &block, 32);\n                            k_cache_q4[layer_id_][head_id][block_idx]\n                                      [k * config_.head_dim / 32 + l] = block;\n                        }\n                    }\n\n                    // fill v_cache_\n                    for (int k = 0; k < config_.block_len / 32; k++) {\n                        for (int l = 0; l < config_.head_dim; l++) {\n                            block_q4_0 block;\n                            for (int m = 0; m < 32; m++) {\n\n                                if (block_id * config_.block_len + k * 32 + m >=\n                                    seq_len + q_len) {\n                                    block_fp32[m] = 0;\n                                    continue;\n                                }\n                                block_fp32[m] = GGML_FP16_TO_FP32(\n                                    v_data_[batch_id * (max_block_num *\n                                                        config_.block_len *\n                                                        config_.kv_head_num *\n                                                        config_.head_dim) +\n                                
            block_id * (config_.block_len *\n                                                        config_.kv_head_num *\n                                                        config_.head_dim) +\n                                            (k * 32 + m) * config_.kv_head_num *\n                                                config_.head_dim +\n                                            head_id * config_.head_dim + l]);\n                            }\n                            quantize_row_q4_0(block_fp32.data(), &block, 32);\n                            v_cache_q4[layer_id_][head_id][block_idx]\n                                      [l * config_.block_len / 32 + k] = block;\n                        }\n                    }\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n                    // fill k_cache_\n                    for (int k = 0; k < config_.block_len; k++) {\n                        if (block_id * config_.block_len + k >=\n                                seq_len + q_len ||\n                            block_id * config_.block_len + k < seq_len)\n                            continue;\n                        for (int l = 0; l < config_.head_dim / 32; l++) {\n                            block_q8_0 block;\n                            for (int m = 0; m < 32; m++) {\n\n                                block_fp32[m] = GGML_FP16_TO_FP32(\n                                    k_data_[batch_id * (max_block_num *\n                                                        config_.block_len *\n                                                        config_.kv_head_num *\n                                                        config_.head_dim) +\n                                            block_id * (config_.block_len *\n                                                        config_.kv_head_num *\n                                                        config_.head_dim) +\n                                            k * 
(config_.kv_head_num *\n                                                 config_.head_dim) +\n                                            head_id * config_.head_dim +\n                                            l * 32 + m]);\n                            }\n                            quantize_row_q8_0(block_fp32.data(), &block, 32);\n                            k_cache_q8[layer_id_][head_id][block_idx]\n                                      [k * config_.head_dim / 32 + l] = block;\n                        }\n                    }\n\n                    // fill v_cache_\n                    for (int k = 0; k < config_.block_len / 32; k++) {\n                        for (int l = 0; l < config_.head_dim; l++) {\n                            block_q8_0 block;\n                            for (int m = 0; m < 32; m++) {\n\n                                if (block_id * config_.block_len + k * 32 + m >=\n                                    seq_len + q_len) {\n                                    block_fp32[m] = 0;\n                                    continue;\n                                }\n                                block_fp32[m] = GGML_FP16_TO_FP32(\n                                    v_data_[batch_id * (max_block_num *\n                                                        config_.block_len *\n                                                        config_.kv_head_num *\n                                                        config_.head_dim) +\n                                            block_id * (config_.block_len *\n                                                        config_.kv_head_num *\n                                                        config_.head_dim) +\n                                            (k * 32 + m) * config_.kv_head_num *\n                                                config_.head_dim +\n                                            head_id * config_.head_dim + l]);\n                            }\n                            
quantize_row_q8_0(block_fp32.data(), &block, 32);\n                            v_cache_q8[layer_id_][head_id][block_idx]\n                                      [l * config_.block_len / 32 + k] = block;\n                        }\n                    }\n                }\n            }\n        },\n        nullptr);\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> duration = end - start;\n\n    // printf(\"layer %d time of reading and updating KV Cache: %f s\\n\",\n    // layer_id,\n    //        duration.count());\n}\n\nvoid KVCache::update_importance(const ggml_fp16_t *importance, int layer_id,\n                                int *block_table, int batch_size,\n                                int max_block_num, int *offset, int width,\n                                Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    layer_id_ = layer_id;\n    importance_data_ = const_cast<uint16_t *>(importance);\n\n    // Each task updates the importance of a certain position\n    backend->do_work_stealing_job(\n        max_block_num * batch_size, nullptr,\n        [&](int task_id) {\n            int block_id = task_id % max_block_num;\n            int batch_id = task_id / max_block_num;\n            int block_idx = block_table[batch_id * max_block_num + block_id];\n            if (block_id > (offset[batch_id] + width) / config_.block_len) {\n                return;\n            }\n            for (int k = 0; k < config_.block_len; k++) {\n                for (int head_id = 0; head_id < config_.q_head_num; head_id++) {\n                    importance_[layer_id_][block_idx][k][head_id] =\n                        GGML_FP32_TO_FP16(\n                            GGML_FP16_TO_FP32(\n                                importance_data_[batch_id * max_block_num *\n                                                     config_.block_len *\n                                  
                   config_.q_head_num +\n                                                 (block_id * config_.block_len +\n                                                  k) *\n                                                     config_.q_head_num +\n                                                 head_id]) +\n                            GGML_FP16_TO_FP32(\n                                importance_[layer_id_][block_idx][k][head_id]));\n                }\n            }\n        },\n        nullptr);\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> duration = end - start;\n\n    // printf(\"layer %d time of updating importance: %f s\\n\", layer_id,\n    //        duration.count());\n}\n\nvoid KVCache::get_kvcache_fp16(ggml_fp16_t *k_in, ggml_fp16_t *v_in,\n                               int layer_id, int *block_table, int batch_size,\n                               int max_block_num, int *cache_seqlens,\n                               Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    layer_id_ = layer_id;\n    k_data_ = const_cast<uint16_t *>(k_in);\n    v_data_ = const_cast<uint16_t *>(v_in);\n\n    // Each task reads the k cache and v cache of a certain head and block\n    backend->do_work_stealing_job(\n        config_.kv_head_num * max_block_num * batch_size, nullptr,\n        [&](int task_id) {\n            // printf(\"block_idx: %d, task_id: %d\\n\", block_idx, task_id);\n            std::vector<float> block_fp32(32);\n            int batch_id = task_id / (config_.kv_head_num * max_block_num);\n            int block_id = (task_id / config_.kv_head_num) % max_block_num;\n            int head_id = task_id % config_.kv_head_num;\n            int block_idx = block_table[batch_id * max_block_num + block_id];\n            int seq_len = cache_seqlens[batch_id];\n            int block_l = block_id * config_.block_len;\n            int block_r = 
block_id * config_.block_len + config_.block_len;\n\n            if (block_l < seq_len) {\n                if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n                    for (int k = 0; k < config_.block_len; k++) {\n                        if (block_id * config_.block_len + k >= seq_len)\n                            break;\n                        for (int l = 0; l < config_.head_dim; l++) {\n                            k_data_\n                                [batch_id *\n                                     (max_block_num * config_.block_len *\n                                      config_.kv_head_num * config_.head_dim) +\n                                 block_id *\n                                     (config_.block_len * config_.kv_head_num *\n                                      config_.head_dim) +\n                                 k * (config_.kv_head_num * config_.head_dim) +\n                                 head_id * config_.head_dim + l] =\n                                    k_cache_fp16_[layer_id_][head_id][block_idx]\n                                                 [k * config_.head_dim + l];\n                            v_data_\n                                [batch_id *\n                                     (max_block_num * config_.block_len *\n                                      config_.kv_head_num * config_.head_dim) +\n                                 block_id *\n                                     (config_.block_len * config_.kv_head_num *\n                                      config_.head_dim) +\n                                 k * (config_.kv_head_num * config_.head_dim) +\n                                 head_id * config_.head_dim + l] =\n                                    v_cache_fp16_[layer_id_][head_id][block_idx]\n                                                 [l * config_.block_len + k];\n                        }\n                    }\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n   
                 // get k_cache_\n                    for (int k = 0; k < config_.block_len; k++) {\n                        if (block_id * config_.block_len + k >= seq_len)\n                            break;\n                        for (int l = 0; l < config_.head_dim / 32; l++) {\n                            block_q4_0 block =\n                                k_cache_q4[layer_id_][head_id][block_idx]\n                                          [k * config_.head_dim / 32 + l];\n                            dequantize_row_q4_0(&block, block_fp32.data(), 32);\n                            for (int m = 0; m < 32; m++) {\n\n                                k_data_[batch_id *\n                                            (max_block_num * config_.block_len *\n                                             config_.kv_head_num *\n                                             config_.head_dim) +\n                                        block_id * (config_.block_len *\n                                                    config_.kv_head_num *\n                                                    config_.head_dim) +\n                                        k * (config_.kv_head_num *\n                                             config_.head_dim) +\n                                        head_id * config_.head_dim + l * 32 +\n                                        m] = GGML_FP32_TO_FP16(block_fp32[m]);\n                            }\n                        }\n                    }\n                    // get v_cache_\n                    for (int k = 0; k < config_.block_len / 32; k++) {\n                        for (int l = 0; l < config_.head_dim; l++) {\n                            block_q4_0 block =\n                                v_cache_q4[layer_id_][head_id][block_idx]\n                                          [l * config_.block_len / 32 + k];\n                            dequantize_row_q4_0(&block, block_fp32.data(), 32);\n                            for (int m = 0; m < 
32; m++) {\n\n                                if (block_id * config_.block_len + k * 32 + m >=\n                                    seq_len)\n                                    break;\n                                v_data_[batch_id *\n                                            (max_block_num * config_.block_len *\n                                             config_.kv_head_num *\n                                             config_.head_dim) +\n                                        block_id * (config_.block_len *\n                                                    config_.kv_head_num *\n                                                    config_.head_dim) +\n                                        (k * 32 + m) * config_.kv_head_num *\n                                            config_.head_dim +\n                                        head_id * config_.head_dim + l] =\n                                    GGML_FP32_TO_FP16(block_fp32[m]);\n                            }\n                        }\n                    }\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n                    // get k_cache_\n                    for (int k = 0; k < config_.block_len; k++) {\n                        if (block_id * config_.block_len + k >= seq_len)\n                            break;\n                        for (int l = 0; l < config_.head_dim / 32; l++) {\n                            block_q8_0 block =\n                                k_cache_q8[layer_id_][head_id][block_idx]\n                                          [k * config_.head_dim / 32 + l];\n                            dequantize_row_q8_0(&block, block_fp32.data(), 32);\n                            for (int m = 0; m < 32; m++) {\n\n                                k_data_[batch_id *\n                                            (max_block_num * config_.block_len *\n                                             config_.kv_head_num *\n                                             
config_.head_dim) +\n                                        block_id * (config_.block_len *\n                                                    config_.kv_head_num *\n                                                    config_.head_dim) +\n                                        k * (config_.kv_head_num *\n                                             config_.head_dim) +\n                                        head_id * config_.head_dim + l * 32 +\n                                        m] = GGML_FP32_TO_FP16(block_fp32[m]);\n                            }\n                        }\n                    }\n                    // get v_cache_\n                    for (int k = 0; k < config_.block_len / 32; k++) {\n                        for (int l = 0; l < config_.head_dim; l++) {\n                            block_q8_0 block =\n                                v_cache_q8[layer_id_][head_id][block_idx]\n                                          [l * config_.block_len / 32 + k];\n                            dequantize_row_q8_0(&block, block_fp32.data(), 32);\n                            for (int m = 0; m < 32; m++) {\n\n                                if (block_id * config_.block_len + k * 32 + m >=\n                                    seq_len)\n                                    break;\n                                v_data_[batch_id *\n                                            (max_block_num * config_.block_len *\n                                             config_.kv_head_num *\n                                             config_.head_dim) +\n                                        block_id * (config_.block_len *\n                                                    config_.kv_head_num *\n                                                    config_.head_dim) +\n                                        (k * 32 + m) * config_.kv_head_num *\n                                            config_.head_dim +\n                                        head_id * 
config_.head_dim + l] =\n                                    GGML_FP32_TO_FP16(block_fp32[m]);\n                            }\n                        }\n                    }\n                }\n            }\n        },\n        nullptr);\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> duration = end - start;\n}\n\nvoid KVCache::update_kvcache_fp16(const ggml_fp16_t *k_in,\n                                  const ggml_fp16_t *v_in, int layer_id,\n                                  int *block_table, int batch_size,\n                                  int max_block_num, int *cache_seqlens,\n                                  int q_len, Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    layer_id_ = layer_id;\n    k_data_ = const_cast<uint16_t *>(k_in);\n    v_data_ = const_cast<uint16_t *>(v_in);\n    // Each task writes the k and v of one new token for a certain head\n    backend->do_work_stealing_job(\n        batch_size * config_.kv_head_num * q_len, nullptr,\n        [&](int task_id) {\n            int batch_id = task_id / (config_.kv_head_num * q_len);\n            int head_id = task_id / q_len % config_.kv_head_num;\n            int seq_len = cache_seqlens[batch_id] + task_id % q_len;\n            int q_offset = task_id % q_len;\n\n            int block_id = seq_len / config_.block_len;\n            int block_idx = block_table[batch_id * max_block_num + block_id];\n            int pos_in_block = seq_len % config_.block_len;\n\n            if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n                for (int l = 0; l < config_.head_dim; l++) {\n                    k_cache_fp16_[layer_id_][head_id][block_idx]\n                                 [pos_in_block * config_.head_dim + l] =\n                                     k_data_[batch_id *\n                                                 (q_len * config_.kv_head_num *\n                       
                           config_.head_dim) +\n                                             q_offset * config_.kv_head_num *\n                                                 config_.head_dim +\n                                             head_id * config_.head_dim + l];\n                    v_cache_fp16_[layer_id_][head_id][block_idx]\n                                 [l * config_.block_len + pos_in_block] =\n                                     v_data_[batch_id *\n                                                 (q_len * config_.kv_head_num *\n                                                  config_.head_dim) +\n                                             q_offset * config_.kv_head_num *\n                                                 config_.head_dim +\n                                             head_id * config_.head_dim + l];\n                }\n            } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n                std::vector<float> block_fp32(32);\n                // fill k_cache_\n                for (int l = 0; l < config_.head_dim / 32; l++) {\n                    block_q4_0 block;\n                    for (int m = 0; m < 32; m++) {\n\n                        block_fp32[m] = GGML_FP16_TO_FP32(\n                            k_data_[batch_id * (q_len * config_.kv_head_num *\n                                                config_.head_dim) +\n                                    q_offset * config_.kv_head_num *\n                                        config_.head_dim +\n                                    head_id * config_.head_dim + l * 32 + m]);\n                    }\n                    quantize_row_q4_0(block_fp32.data(), &block, 32);\n\n                    k_cache_q4[layer_id_][head_id][block_idx]\n                              [pos_in_block * config_.head_dim / 32 + l] =\n                                  block;\n                }\n\n                // fill v_cache_\n                for (int l = 0; l < config_.head_dim; l++) {\n                    block_q4_0 block = v_cache_q4[layer_id_][head_id][block_idx]\n                                               
  [l * config_.block_len / 32 +\n                                                  pos_in_block / 32];\n                    dequantize_row_q4_0(&block, block_fp32.data(), 32);\n                    block_fp32[pos_in_block % 32] = GGML_FP16_TO_FP32(\n                        v_data_[batch_id * (q_len * config_.kv_head_num *\n                                            config_.head_dim) +\n                                q_offset * config_.kv_head_num *\n                                    config_.head_dim +\n                                head_id * config_.head_dim + l]);\n                    quantize_row_q4_0(block_fp32.data(), &block, 32);\n                    v_cache_q4[layer_id_][head_id][block_idx]\n                              [l * config_.block_len / 32 + pos_in_block / 32] =\n                                  block;\n                }\n            } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n                std::vector<float> block_fp32(32);\n                // fill k_cache_\n                for (int l = 0; l < config_.head_dim / 32; l++) {\n                    block_q8_0 block;\n                    for (int m = 0; m < 32; m++) {\n\n                        block_fp32[m] = GGML_FP16_TO_FP32(\n                            k_data_[batch_id * (q_len * config_.kv_head_num *\n                                                config_.head_dim) +\n                                    q_offset * config_.kv_head_num *\n                                        config_.head_dim +\n                                    head_id * config_.head_dim + l * 32 + m]);\n                    }\n                    quantize_row_q8_0(block_fp32.data(), &block, 32);\n\n                    k_cache_q8[layer_id_][head_id][block_idx]\n                              [pos_in_block * config_.head_dim / 32 + l] =\n                                  block;\n                }\n\n                // fill v_cache_\n                for (int l = 0; l < config_.head_dim; l++) {\n                    block_q8_0 block = v_cache_q8[layer_id_][head_id][block_idx]\n                                                 [l * config_.block_len / 32 +\n                                                  pos_in_block / 32];\n                    
dequantize_row_q8_0(&block, block_fp32.data(), 32);\n                    block_fp32[pos_in_block % 32] = GGML_FP16_TO_FP32(\n                        v_data_[batch_id * (q_len * config_.kv_head_num *\n                                            config_.head_dim) +\n                                q_offset * config_.kv_head_num *\n                                    config_.head_dim +\n                                head_id * config_.head_dim + l]);\n                    quantize_row_q8_0(block_fp32.data(), &block, 32);\n                    v_cache_q8[layer_id_][head_id][block_idx]\n                              [l * config_.block_len / 32 + pos_in_block / 32] =\n                                  block;\n                }\n            }\n        },\n        nullptr);\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> duration = end - start;\n    // printf(\"layer %d time of reading KV Cache: %f s\\n\", layer_id,\n    //        duration.count());\n}\n\nvoid KVCache::get_all_kvcache_one_layer(int layer_id, ggml_fp16_t *k_in,\n                                        ggml_fp16_t *v_in, Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    layer_id_ = layer_id;\n    seq_len_ = config_.block_len;\n    block_num_ = get_cache_total_block_num();\n    k_data_ = reinterpret_cast<uint16_t *>(k_in);\n    v_data_ = reinterpret_cast<uint16_t *>(v_in);\n\n    // Each task gets the k cache or v cache of a certain head and block\n    backend->do_work_stealing_job(\n        config_.kv_head_num * past_block_num_[layer_id] * 2, nullptr,\n        [&](int task_id) {\n            std::vector<float> block_fp32(32);\n            int head_id = task_id / 2 / past_block_num_[layer_id];\n            int block_idx = task_id / 2 % past_block_num_[layer_id];\n            if (block_idx >= block_num_)\n                return;\n\n            int max_offset = 0;\n            if (task_id & 1) {\n                // get k_cache_\n                for (int k = 0; k < config_.block_len; k++) {\n                    if 
(block_idx * seq_len_ + k >= cache_total_len_)\n                        break;\n                    for (int l = 0; l < config_.head_dim / 32; l++) {\n                        block_q4_0 block =\n                            k_cache_q4[layer_id_][head_id][block_idx]\n                                      [k * config_.head_dim / 32 + l];\n                        dequantize_row_q4_0(&block, block_fp32.data(), 32);\n                        for (int m = 0; m < 32; m++) {\n\n                            k_data_[(head_id * cache_total_len_ +\n                                     block_idx * config_.block_len + k) *\n                                        config_.head_dim +\n                                    l * 32 + m] =\n                                GGML_FP32_TO_FP16(block_fp32[m]);\n                            max_offset = std::max(\n                                max_offset,\n                                (int)(head_id * cache_total_len_ +\n                                      block_idx * config_.block_len + k) *\n                                        config_.head_dim +\n                                    l * 32 + m);\n                        }\n                    }\n                }\n            } else {\n                // get v_cache_\n                for (int k = 0; k < config_.block_len / 32; k++) {\n                    for (int l = 0; l < config_.head_dim; l++) {\n                        block_q4_0 block =\n                            v_cache_q4[layer_id_][head_id][block_idx]\n                                      [l * config_.block_len / 32 + k];\n                        dequantize_row_q4_0(&block, block_fp32.data(), 32);\n                        for (int m = 0; m < 32; m++) {\n\n                            if (block_idx * seq_len_ + k * 32 + m >=\n                                cache_total_len_)\n                                break;\n                            v_data_[(head_id * cache_total_len_ +\n                                     block_idx 
* config_.block_len + k * 32 +\n                                     m) *\n                                        config_.head_dim +\n                                    l] = GGML_FP32_TO_FP16(block_fp32[m]);\n                            max_offset =\n                                std::max(max_offset,\n                                         (int)((head_id * cache_total_len_ +\n                                                block_idx * config_.block_len +\n                                                k * 32 + m) *\n                                                   config_.head_dim +\n                                               l));\n                        }\n                    }\n                }\n            }\n        },\n        nullptr);\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> duration = end - start;\n    // printf(\"layer %d block num %d time of reading all KV Cache: %f s\\n\",\n    //        layer_id, block_num_, duration.count());\n}\n"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/operators/kvcache/kvcache_utils.cpp",
    "content": "/**\n * @Description  :\n * @Author       : Jianwei Dong\n * @Date         : 2024-08-26 22:47:06\n * @Version      : 1.0.0\n * @LastEditors  : Jianwei Dong\n * @LastEditTime : 2024-08-26 22:47:06\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n\n#include \"kvcache.h\"\n\n#include <chrono>\n\nstd::string ggml_type_to_string(ggml_type type) {\n    switch (type) {\n    case GGML_TYPE_F32:\n        return \"GGML_TYPE_F32\";\n    case GGML_TYPE_F16:\n        return \"GGML_TYPE_F16\";\n    case GGML_TYPE_Q4_0:\n        return \"GGML_TYPE_Q4_0\";\n    case GGML_TYPE_Q8_0:\n        return \"GGML_TYPE_Q8_0\";\n    }\n    return \"UNDIFINED\";\n}\nstd::string AnchorTypeToString(AnchorType type) {\n    switch (type) {\n    case AnchorType::DYNAMIC:\n        return \"DYNAMIC\";\n    case AnchorType::BLOCK_MEAN:\n        return \"BLOCK_MEAN\";\n    case AnchorType::BLOCK_MAX:\n        return \"BLOCK_MAX\";\n    case AnchorType::FIXED_ANCHOR:\n        return \"FIXED_ANCHOR\";\n    case AnchorType::QUEST:\n        return \"QUEST\";\n    }\n    return \"UNDIFINED\";\n}\nstd::string RetrievalTypeToString(RetrievalType type) {\n    switch (type) {\n    case RetrievalType::LAYER:\n        return \"SHARED\";\n    case RetrievalType::KVHEAD:\n        return \"SEPARATE\";\n    case RetrievalType::QHEAD:\n        return \"INDIVIDUAL\";\n    }\n    return \"UNDIFINED\";\n}\nKVCacheConfig::KVCacheConfig(int layer_num, int kv_head_num, int q_head_num,\n                             int head_dim, int block_len, int anchor_num,\n                             AnchorType anchor_type, ggml_type kv_type,\n                             RetrievalType retrieval_type, int layer_step,\n                             int token_step, int layer_offset,\n                             int max_block_num, int max_batch_size,\n                             int max_thread_num)\n    : layer_num(layer_num), kv_head_num(kv_head_num), q_head_num(q_head_num),\n      head_dim(head_dim), 
block_len(block_len), anchor_num(anchor_num),\n      anchor_type(anchor_type), kv_type(kv_type),\n      retrieval_type(retrieval_type), layer_step(layer_step),\n      token_step(token_step), layer_offset(layer_offset),\n      max_block_num(max_block_num), max_batch_size(max_batch_size),\n      max_thread_num(max_thread_num) {\n    printf(\n        \"layer_num: %d, kv_head_num: %d, q_head_num: %d, head_dim: %d, \"\n        \"block_len: %d, anchor_num: %d, anchor_type: %s, kv_type: %s, \"\n        \"retrieval_type: %s, layer_step: %d, token_step: %d, layer_offset: %d, \"\n        \"max_block_num: %d, max_batch_size: %d, max_thread_num: %d\\n\",\n        layer_num, kv_head_num, q_head_num, head_dim, block_len, anchor_num,\n        AnchorTypeToString(anchor_type).c_str(),\n        ggml_type_to_string(kv_type).c_str(),\n        RetrievalTypeToString(retrieval_type).c_str(), layer_step, token_step,\n        layer_offset, max_block_num, max_batch_size, max_thread_num);\n    assert(q_head_num % kv_head_num == 0);\n}\nKVCache::KVCache(KVCacheConfig config) {\n    this->config_ = config;\n\n    n_gqa_ = config_.q_head_num / config_.kv_head_num;\n    if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n        // TODO: Implement this more elegantly\n        k_cache_fp16_.resize(config_.layer_num);\n        v_cache_fp16_.resize(config_.layer_num);\n        selected_blocks_num_history_.resize(config_.layer_num /\n                                            config_.layer_step);\n        if (config_.retrieval_type == RetrievalType::LAYER) {\n            selected_blocks_history_.resize(config_.layer_num /\n                                            config_.layer_step);\n        } else if (config_.retrieval_type == RetrievalType::KVHEAD) {\n            selected_blocks_history_kvhead_.resize(config_.layer_num /\n                                                   config_.layer_step);\n        } else if (config_.retrieval_type == RetrievalType::QHEAD) {\n        }\n    } else if (config_.kv_type 
== ggml_type::GGML_TYPE_Q4_0) {\n        k_cache_q4.resize(config.layer_num);\n        v_cache_q4.resize(config.layer_num);\n    } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n        k_cache_q8.resize(config.layer_num);\n        v_cache_q8.resize(config.layer_num);\n    } else {\n        assert(false);\n    }\n    anchor_.resize(config.layer_num * config.max_block_num * config.anchor_num *\n                   config.q_head_num * config.head_dim);\n    importance_.resize(config.layer_num);\n    past_block_num_.resize(config.layer_num);\n    for (int i = 0; i < config.layer_num; i++) {\n        past_block_num_[i] = 0;\n    }\n\n    ThreadResize(config.max_thread_num);\n    BatchResize(config.max_batch_size);\n    BlockResize(config.max_block_num);\n    q_fp32.resize(n_gqa_ * config.head_dim);\n}\n\nvoid KVCache::ThreadResize(int thread_num) {\n    thread_local_output_q8_0_.resize(thread_num);\n    thread_local_attn_score_.resize(thread_num);\n    thread_local_output_fp32_.resize(thread_num);\n    thread_local_attn_lse_.resize(thread_num);\n    thread_local_cur_output_fp32_.resize(thread_num);\n    thread_local_cur_attn_lse_.resize(thread_num);\n    thread_local_draft_.resize(thread_num);\n    thread_cur_head_idx_.resize(thread_num);\n    thread_local_attn_mask_.resize(thread_num);\n    for (int i = 0; i < thread_num; i++) {\n        thread_local_output_q8_0_[i].resize(n_gqa_ * config_.head_dim / QK8_0);\n        thread_local_attn_score_[i].resize(n_gqa_ * config_.block_len);\n        thread_local_output_fp32_[i].resize(n_gqa_ * config_.head_dim);\n        thread_local_attn_lse_[i].resize(n_gqa_);\n        thread_local_cur_output_fp32_[i].resize(n_gqa_ * config_.head_dim);\n        thread_local_cur_attn_lse_[i].resize(n_gqa_);\n        thread_local_draft_[i].resize(\n            2 * n_gqa_ * config_.block_len + 6 * n_gqa_ * config_.head_dim +\n            2 * config_.block_len * config_.head_dim +\n            config_.block_len * config_.head_dim / 
QK4_0);\n        thread_local_attn_mask_[i].resize(config_.block_len / 8);\n    }\n}\nvoid KVCache::BatchResize(int batch_size) {\n    mutex_.resize(batch_size);\n    q_q8_0_.resize(batch_size);\n    q_fp32_.resize(batch_size);\n    output_fp32_.resize(batch_size);\n    attn_lse_.resize(batch_size);\n    block_lse_.resize(batch_size);\n    attn_sparsity_.resize(batch_size);\n\n    if (config_.retrieval_type == RetrievalType::LAYER) {\n        block_table_before_retrieval_.resize(batch_size);\n        block_table_after_retrieval_.resize(batch_size);\n\n        for (int i = 0; i < config_.layer_num / config_.layer_step; i++) {\n            selected_blocks_history_[i].resize(batch_size);\n        }\n\n    } else if (config_.retrieval_type == RetrievalType::KVHEAD) {\n        block_table_before_retrieval_kvhead_.resize(batch_size);\n        block_table_after_retrieval_kvhead_.resize(batch_size);\n        for (int i = 0; i < config_.layer_num / config_.layer_step; i++) {\n            selected_blocks_history_kvhead_[i].resize(batch_size);\n        }\n    } else if (config_.retrieval_type == RetrievalType::QHEAD) {\n        block_table_before_retrieval_qhead_.resize(batch_size);\n        block_table_after_retrieval_qhead_.resize(batch_size);\n    }\n    cache_seqlens_.resize(batch_size);\n    if (config_.retrieval_type == RetrievalType::LAYER) {\n        block_similar_.resize(batch_size);\n    } else if (config_.retrieval_type == RetrievalType::KVHEAD) {\n        block_similar_kv_head_.resize(batch_size);\n    } else if (config_.retrieval_type == RetrievalType::QHEAD) {\n        block_similar_q_head_.resize(batch_size);\n    }\n    for (int i = 0; i < batch_size; i++) {\n        top_similar_block_.resize(batch_size);\n\n        mutex_[i].resize(config_.kv_head_num);\n        q_q8_0_[i].resize(config_.kv_head_num);\n        q_fp32_[i].resize(config_.kv_head_num);\n        output_fp32_[i].resize(config_.kv_head_num);\n        attn_lse_[i].resize(config_.kv_head_num);\n\n    
    for (int j = 0; j < config_.kv_head_num; j++) {\n            if (!mutex_[i][j]) {\n                mutex_[i][j] = std::make_unique<std::mutex>();\n            }\n            q_q8_0_[i][j].resize(n_gqa_ * config_.head_dim / QK8_0);\n            q_fp32_[i][j].resize(n_gqa_ * config_.head_dim);\n            output_fp32_[i][j].resize(n_gqa_ * config_.head_dim);\n            attn_lse_[i][j].resize(n_gqa_);\n        }\n    }\n    avg_q.resize(batch_size);\n    avg_q_fp16.resize(batch_size);\n    for (int i = 0; i < batch_size; i++) {\n        attn_sparsity_[i].resize(config_.q_head_num);\n        avg_q[i].resize(config_.q_head_num * config_.head_dim);\n        avg_q_fp16[i].resize(config_.q_head_num * config_.head_dim);\n    }\n}\n\nvoid KVCache::BlockResize(int max_block_num) {\n    sin_.resize(max_block_num * config_.block_len);\n    cos_.resize(max_block_num * config_.block_len);\n    for (int i = 0; i < max_block_num * config_.block_len; i++) {\n        sin_[i].resize(config_.head_dim);\n        cos_[i].resize(config_.head_dim);\n    }\n\n    for (int i = 0; i < config_.layer_num / config_.layer_step; i++) {\n        for (int j = 0; j < config_.max_batch_size; j++) {\n            if (config_.retrieval_type == RetrievalType::LAYER) {\n                selected_blocks_history_[i][j].resize(max_block_num);\n            } else if (config_.retrieval_type == RetrievalType::KVHEAD) {\n                selected_blocks_history_kvhead_[i][j].resize(max_block_num);\n                for (int k = 0; k < config_.max_block_num; k++) {\n                    selected_blocks_history_kvhead_[i][j][k].resize(\n                        config_.kv_head_num);\n                }\n            } else if (config_.retrieval_type == RetrievalType::QHEAD) {\n            }\n        }\n    }\n\n    for (int layer_id = 0; layer_id < config_.layer_num; layer_id++) {\n        importance_[layer_id].resize(max_block_num);\n\n        if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n            // 
TODO: Implement this more elegantly\n            k_cache_fp16_[layer_id].resize(config_.kv_head_num);\n            v_cache_fp16_[layer_id].resize(config_.kv_head_num);\n\n            for (int i = 0; i < config_.kv_head_num; i++) {\n                k_cache_fp16_[layer_id][i].resize(max_block_num);\n                v_cache_fp16_[layer_id][i].resize(max_block_num);\n\n                for (int j = 0; j < max_block_num; j++) {\n                    k_cache_fp16_[layer_id][i][j].resize(config_.block_len *\n                                                         config_.head_dim);\n                    v_cache_fp16_[layer_id][i][j].resize(config_.block_len *\n                                                         config_.head_dim);\n                }\n            }\n\n        } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n            k_cache_q4[layer_id].resize(config_.kv_head_num);\n            v_cache_q4[layer_id].resize(config_.kv_head_num);\n            for (int i = 0; i < config_.kv_head_num; i++) {\n                k_cache_q4[layer_id][i].resize(max_block_num);\n                v_cache_q4[layer_id][i].resize(max_block_num);\n\n                for (int j = 0; j < max_block_num; j++) {\n                    k_cache_q4[layer_id][i][j].resize(config_.block_len *\n                                                      config_.head_dim / 32);\n                    v_cache_q4[layer_id][i][j].resize(config_.block_len *\n                                                      config_.head_dim / 32);\n                }\n            }\n        } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n            k_cache_q8[layer_id].resize(config_.kv_head_num);\n            v_cache_q8[layer_id].resize(config_.kv_head_num);\n            for (int i = 0; i < config_.kv_head_num; i++) {\n                k_cache_q8[layer_id][i].resize(max_block_num);\n                v_cache_q8[layer_id][i].resize(max_block_num);\n\n                for (int j = 0; j < max_block_num; j++) {\n         
           k_cache_q8[layer_id][i][j].resize(config_.block_len *\n                                                      config_.head_dim / 32);\n                    v_cache_q8[layer_id][i][j].resize(config_.block_len *\n                                                      config_.head_dim / 32);\n                }\n            }\n        } else {\n            assert(false);\n        }\n        for (int i = 0; i < config_.max_batch_size; i++) {\n            if (config_.retrieval_type == RetrievalType::LAYER) {\n                block_similar_[i].resize(max_block_num);\n                block_table_before_retrieval_[i].resize(max_block_num);\n                block_table_after_retrieval_[i].resize(max_block_num);\n            } else if (config_.retrieval_type == RetrievalType::KVHEAD) {\n                block_similar_kv_head_[i].resize(max_block_num);\n                block_table_before_retrieval_kvhead_[i].resize(max_block_num);\n                block_table_after_retrieval_kvhead_[i].resize(max_block_num);\n                for (int j = 0; j < max_block_num; j++) {\n                    block_similar_kv_head_[i][j].resize(config_.kv_head_num);\n                    block_table_before_retrieval_kvhead_[i][j].resize(\n                        config_.kv_head_num);\n                    block_table_after_retrieval_kvhead_[i][j].resize(\n                        config_.kv_head_num);\n                }\n            } else if (config_.retrieval_type == RetrievalType::QHEAD) {\n                block_similar_q_head_[i].resize(max_block_num);\n                block_table_before_retrieval_qhead_[i].resize(max_block_num);\n                block_table_after_retrieval_qhead_[i].resize(max_block_num);\n                for (int j = 0; j < max_block_num; j++) {\n                    block_similar_q_head_[i][j].resize(config_.q_head_num);\n                    block_table_before_retrieval_qhead_[i][j].resize(\n                        config_.q_head_num);\n                    
block_table_after_retrieval_qhead_[i][j].resize(\n                        config_.q_head_num);\n                }\n            }\n            block_lse_[i].resize(max_block_num);\n            for (int j = 0; j < max_block_num; j++) {\n                block_lse_[i][j].resize(config_.q_head_num);\n            }\n        }\n\n        for (int i = 0; i < max_block_num; i++) {\n            importance_[layer_id][i].resize(config_.block_len);\n            for (int j = 0; j < config_.block_len; j++) {\n                importance_[layer_id][i][j].resize(config_.q_head_num);\n            }\n        }\n    }\n}\n\nvoid KVCache::calc_anchor_all_layers(int *block_table, int *cache_seqlens,\n                                     int batch_size, int max_block_num,\n                                     Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    // Each task updates the importance of a certain block\n    seq_len_ = config_.block_len;\n    backend->do_work_stealing_job(\n        config_.layer_num * batch_size * max_block_num, nullptr,\n        [&](int task_id) {\n            int layer_id = task_id / (batch_size * max_block_num);\n            int batch_id = (task_id / max_block_num) % batch_size;\n            int block_id = task_id % max_block_num;\n            // If the block is out of the sequence length, skip it. 
In\n            // particular, the last block of the sequence that is shorter than\n            // the block length should be skipped.\n\n            if (cache_seqlens[batch_id] / config_.block_len < block_id) {\n                return;\n            }\n            int block_idx = block_table[batch_id * max_block_num + block_id];\n\n            std::vector<float> block_fp32(32);\n            if (config_.anchor_type == AnchorType::DYNAMIC) {\n\n                // clear anchor_\n                for (int anchor_id = 0; anchor_id < 1; anchor_id++) {\n                    for (int head_id = 0; head_id < config_.q_head_num;\n                         head_id++) {\n                        for (int l = 0; l < config_.head_dim; l++) {\n                            anchor_[layer_id * config_.max_block_num *\n                                        config_.anchor_num *\n                                        config_.q_head_num * config_.head_dim +\n                                    block_idx * config_.anchor_num *\n                                        config_.q_head_num * config_.head_dim +\n                                    anchor_id * config_.q_head_num *\n                                        config_.head_dim +\n                                    head_id * config_.head_dim + l] = 0;\n                        }\n                    }\n                }\n\n                // find top anchor_num importances and their corresponding\n                // positions in the importance_ tensor\n                // TODO: Move top_importances to the class member to avoid\n                // repeated memory allocation\n                std::priority_queue<\n                    std::pair<float, std::pair<int, int>>,\n                    std::vector<std::pair<float, std::pair<int, int>>>,\n                    std::greater<>>\n                    top_importances;\n                for (int head_id = 0; head_id < config_.q_head_num; head_id++) {\n                    for (int k = 0; k < 
seq_len_; k++) {\n                        top_importances.push(std::make_pair(\n                            GGML_FP16_TO_FP32(\n                                importance_[layer_id][block_idx][k][head_id]),\n                            std::make_pair(block_idx, k)));\n                        // TODO: change to config_ item\n                        if (top_importances.size() > config_.anchor_num) {\n                            top_importances.pop();\n                        }\n                    }\n\n                    // fill anchor_\n\n                    for (int l = 0; l < config_.head_dim; l++) {\n                        anchor_[layer_id * config_.max_block_num *\n                                    config_.anchor_num * config_.q_head_num *\n                                    config_.head_dim +\n                                block_idx * config_.anchor_num *\n                                    config_.q_head_num * config_.head_dim +\n                                0 * config_.q_head_num * config_.head_dim +\n                                head_id * config_.head_dim + l] = 0;\n                    }\n                    for (int k = 0; k < config_.anchor_num; k++) {\n                        int top_indice = top_importances.top().second.second;\n                        int top_block_idx = top_importances.top().second.first;\n\n                        if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n\n                            for (int l = 0; l < config_.head_dim; l++) {\n                                anchor_[layer_id * config_.max_block_num *\n                                            config_.anchor_num *\n                                            config_.q_head_num *\n                                            config_.head_dim +\n                                        top_block_idx * config_.anchor_num *\n                                            config_.q_head_num *\n                                            config_.head_dim +\n           
                             0 * config_.q_head_num *\n                                            config_.head_dim +\n                                        head_id * config_.head_dim + l] =\n                                    GGML_FP32_TO_FP16(\n                                        GGML_FP16_TO_FP32(\n                                            anchor_[layer_id *\n                                                        config_.max_block_num *\n                                                        config_.anchor_num *\n                                                        config_.q_head_num *\n                                                        config_.head_dim +\n                                                    top_block_idx *\n                                                        config_.anchor_num *\n                                                        config_.q_head_num *\n                                                        config_.head_dim +\n                                                    0 * config_.q_head_num *\n                                                        config_.head_dim +\n                                                    head_id * config_.head_dim +\n                                                    l]) +\n                                        GGML_FP16_TO_FP32(\n                                            k_cache_fp16_[layer_id]\n                                                         [head_id / n_gqa_]\n                                                         [top_block_idx]\n                                                         [top_indice *\n                                                              config_.head_dim +\n                                                          l]));\n                            }\n\n                        } else if (config_.kv_type ==\n                                   ggml_type::GGML_TYPE_Q4_0) {\n                            for (int l = 0; l < config_.head_dim 
/ 32; l++) {\n                                block_q4_0 block = k_cache_q4\n                                    [layer_id][head_id / n_gqa_][top_block_idx]\n                                    [top_indice * config_.head_dim / 32 + l];\n                                dequantize_row_q4_0(&block, block_fp32.data(),\n                                                    32);\n                                for (int m = 0; m < 32; m++) {\n                                    anchor_[layer_id * config_.max_block_num *\n                                                config_.anchor_num *\n                                                config_.q_head_num *\n                                                config_.head_dim +\n                                            top_block_idx * config_.anchor_num *\n                                                config_.q_head_num *\n                                                config_.head_dim +\n                                            0 * config_.q_head_num *\n                                                config_.head_dim +\n                                            head_id * config_.head_dim +\n                                            l * 32 + m] =\n                                        GGML_FP32_TO_FP16(\n                                            block_fp32[m] / 4 +\n                                            GGML_FP16_TO_FP32(\n                                                anchor_[layer_id *\n                                                            config_\n                                                                .max_block_num *\n                                                            config_.anchor_num *\n                                                            config_.q_head_num *\n                                                            config_.head_dim +\n                                                        top_block_idx *\n                                                         
   config_.anchor_num *\n                                                            config_.q_head_num *\n                                                            config_.head_dim +\n                                                        0 * config_.q_head_num *\n                                                            config_.head_dim +\n                                                        head_id *\n                                                            config_.head_dim +\n                                                        l * 32 + m]));\n                                }\n                            }\n                        } else if (config_.kv_type ==\n                                   ggml_type::GGML_TYPE_Q8_0) {\n                            for (int l = 0; l < config_.head_dim / 32; l++) {\n                                block_q8_0 block = k_cache_q8\n                                    [layer_id][head_id / n_gqa_][top_block_idx]\n                                    [top_indice * config_.head_dim / 32 + l];\n                                dequantize_row_q8_0(&block, block_fp32.data(),\n                                                    32);\n                                for (int m = 0; m < 32; m++) {\n                                    anchor_[layer_id * config_.max_block_num *\n                                                config_.anchor_num *\n                                                config_.q_head_num *\n                                                config_.head_dim +\n                                            top_block_idx * config_.anchor_num *\n                                                config_.q_head_num *\n                                                config_.head_dim +\n                                            0 * config_.q_head_num *\n                                                config_.head_dim +\n                                            head_id * config_.head_dim +\n                      
                      l * 32 + m] =\n                                        GGML_FP32_TO_FP16(\n                                            block_fp32[m] / 4 +\n                                            GGML_FP16_TO_FP32(\n                                                anchor_[layer_id *\n                                                            config_\n                                                                .max_block_num *\n                                                            config_.anchor_num *\n                                                            config_.q_head_num *\n                                                            config_.head_dim +\n                                                        top_block_idx *\n                                                            config_.anchor_num *\n                                                            config_.q_head_num *\n                                                            config_.head_dim +\n                                                        0 * config_.q_head_num *\n                                                            config_.head_dim +\n                                                        head_id *\n                                                            config_.head_dim +\n                                                        l * 32 + m]));\n                                }\n                            }\n                        }\n                        top_importances.pop();\n                    }\n                }\n            } else if (config_.anchor_type == AnchorType::BLOCK_MEAN) {\n                // clear anchor_\n                for (int anchor_id = 0; anchor_id < config_.anchor_num;\n                     anchor_id++) {\n                    for (int head_id = 0; head_id < config_.q_head_num;\n                         head_id++) {\n                        for (int l = 0; l < config_.head_dim; l++) {\n                            
anchor_[layer_id * config_.max_block_num *\n                                        config_.anchor_num *\n                                        config_.q_head_num * config_.head_dim +\n                                    block_idx * config_.anchor_num *\n                                        config_.q_head_num * config_.head_dim +\n                                    anchor_id * config_.q_head_num *\n                                        config_.head_dim +\n                                    head_id * config_.head_dim + l] = 0;\n                        }\n                    }\n                }\n\n                // fill anchor_\n                if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n\n                    for (int head_id = 0; head_id < config_.q_head_num;\n                         head_id++) {\n                        for (int k = 0; k < config_.block_len; k++) {\n                            for (int l = 0; l < config_.head_dim; l++) {\n                                anchor_[layer_id * config_.max_block_num *\n                                            config_.anchor_num *\n                                            config_.q_head_num *\n                                            config_.head_dim +\n                                        block_idx * config_.anchor_num *\n                                            config_.q_head_num *\n                                            config_.head_dim +\n                                        0 * config_.q_head_num *\n                                            config_.head_dim +\n                                        head_id * config_.head_dim + l] =\n                                    GGML_FP32_TO_FP16(\n                                        GGML_FP16_TO_FP32(\n                                            anchor_[layer_id *\n                                                        config_.max_block_num *\n                                                        config_.anchor_num *\n       
                                                 config_.q_head_num *\n                                                        config_.head_dim +\n                                                    block_idx *\n                                                        config_.anchor_num *\n                                                        config_.q_head_num *\n                                                        config_.head_dim +\n                                                    0 * config_.q_head_num *\n                                                        config_.head_dim +\n                                                    head_id * config_.head_dim +\n                                                    l]) +\n                                        GGML_FP16_TO_FP32(\n                                            k_cache_fp16_[layer_id]\n                                                         [head_id / n_gqa_]\n                                                         [block_idx]\n                                                         [k * config_.head_dim +\n                                                          l]) /\n                                            config_.block_len);\n                            }\n                        }\n                    }\n                }\n            } else if (config_.anchor_type == AnchorType::BLOCK_MAX) {\n                // clear anchor_\n                for (int anchor_id = 0; anchor_id < config_.anchor_num;\n                     anchor_id++) {\n                    for (int head_id = 0; head_id < config_.q_head_num;\n                         head_id++) {\n                        for (int l = 0; l < config_.head_dim; l++) {\n                            anchor_[layer_id * config_.max_block_num *\n                                        config_.anchor_num *\n                                        config_.q_head_num * config_.head_dim +\n                                    block_idx * 
config_.anchor_num *\n                                        config_.q_head_num * config_.head_dim +\n                                    anchor_id * config_.q_head_num *\n                                        config_.head_dim +\n                                    head_id * config_.head_dim + l] = 0;\n                        }\n                    }\n                }\n\n                // fill anchor_\n                if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n\n                    for (int head_id = 0; head_id < config_.q_head_num;\n                         head_id++) {\n                        for (int k = 0; k < config_.block_len; k++) {\n                            for (int l = 0; l < config_.head_dim; l++) {\n                                anchor_[layer_id * config_.max_block_num *\n                                            config_.anchor_num *\n                                            config_.q_head_num *\n                                            config_.head_dim +\n                                        block_idx * config_.anchor_num *\n                                            config_.q_head_num *\n                                            config_.head_dim +\n                                        0 * config_.q_head_num *\n                                            config_.head_dim +\n                                        head_id * config_.head_dim + l] =\n                                    GGML_FP32_TO_FP16(std::max(\n                                        GGML_FP16_TO_FP32(\n                                            anchor_[layer_id *\n                                                        config_.max_block_num *\n                                                        config_.anchor_num *\n                                                        config_.q_head_num *\n                                                        config_.head_dim +\n                                                    block_idx *\n              
                                          config_.anchor_num *\n                                                        config_.q_head_num *\n                                                        config_.head_dim +\n                                                    0 * config_.q_head_num *\n                                                        config_.head_dim +\n                                                    head_id * config_.head_dim +\n                                                    l]),\n                                        GGML_FP16_TO_FP32(\n                                            k_cache_fp16_\n                                                [layer_id][head_id / n_gqa_]\n                                                [block_idx]\n                                                [k * config_.head_dim + l])));\n                            }\n                        }\n                    }\n                }\n            } else if (config_.anchor_type == AnchorType::FIXED_ANCHOR) {\n                // clear anchor_\n                for (int anchor_id = 0; anchor_id < 1; anchor_id++) {\n                    for (int head_id = 0; head_id < config_.q_head_num;\n                         head_id++) {\n                        for (int l = 0; l < config_.head_dim; l++) {\n                            anchor_[layer_id * config_.max_block_num *\n                                        config_.anchor_num *\n                                        config_.q_head_num * config_.head_dim +\n                                    block_idx * config_.anchor_num *\n                                        config_.q_head_num * config_.head_dim +\n                                    anchor_id * config_.q_head_num *\n                                        config_.head_dim +\n                                    head_id * config_.head_dim + l] = 0;\n                        }\n                    }\n                }\n\n                // fill anchor_\n              
  if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n\n                    int stride = config_.block_len / config_.anchor_num;\n                    for (int head_id = 0; head_id < config_.q_head_num;\n                         head_id++) {\n                        for (int k = 0, tot = 0;\n                             k < config_.block_len && tot < config_.anchor_num;\n                             k += stride, tot++) {\n                            for (int l = 0; l < config_.head_dim; l++) {\n                                anchor_[layer_id * config_.max_block_num *\n                                            config_.anchor_num *\n                                            config_.q_head_num *\n                                            config_.head_dim +\n                                        block_idx * config_.anchor_num *\n                                            config_.q_head_num *\n                                            config_.head_dim +\n                                        0 * config_.q_head_num *\n                                            config_.head_dim +\n                                        head_id * config_.head_dim + l] =\n                                    GGML_FP32_TO_FP16(\n                                        GGML_FP16_TO_FP32(\n                                            anchor_[layer_id *\n                                                        config_.max_block_num *\n                                                        config_.anchor_num *\n                                                        config_.q_head_num *\n                                                        config_.head_dim +\n                                                    block_idx *\n                                                        config_.anchor_num *\n                                                        config_.q_head_num *\n                                                        config_.head_dim +\n                                
                    0 * config_.q_head_num *\n                                                        config_.head_dim +\n                                                    head_id * config_.head_dim +\n                                                    l]) +\n                                        GGML_FP16_TO_FP32(\n                                            k_cache_fp16_[layer_id]\n                                                         [head_id / n_gqa_]\n                                                         [block_idx]\n                                                         [k * config_.head_dim +\n                                                          l]) /\n                                            config_.anchor_num);\n                            }\n                        }\n                    }\n                }\n\n            } else if (config_.anchor_type == AnchorType::QUEST) {\n                // clear anchor_\n                for (int head_id = 0; head_id < config_.q_head_num; head_id++) {\n                    for (int l = 0; l < config_.head_dim; l++) {\n                        anchor_[layer_id * config_.max_block_num *\n                                    config_.anchor_num * config_.q_head_num *\n                                    config_.head_dim +\n                                block_idx * config_.anchor_num *\n                                    config_.q_head_num * config_.head_dim +\n                                1 * config_.q_head_num * config_.head_dim +\n                                head_id * config_.head_dim + l] =\n                            GGML_FP32_TO_FP16(\n                                std::numeric_limits<float>::max());\n\n                        anchor_[layer_id * config_.max_block_num *\n                                    config_.anchor_num * config_.q_head_num *\n                                    config_.head_dim +\n                                block_idx * config_.anchor_num *\n                  
                  config_.q_head_num * config_.head_dim +\n                                0 * config_.q_head_num * config_.head_dim +\n                                head_id * config_.head_dim + l] =\n                            GGML_FP32_TO_FP16(\n                                // lowest(), not min(): min() is the smallest\n                                // positive float and would floor the max at 0\n                                std::numeric_limits<float>::lowest());\n                    }\n                }\n\n                // fill anchor_\n\n                if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n                    for (int indice = 0; indice < seq_len_; indice++) {\n                        for (int head_id = 0; head_id < config_.kv_head_num;\n                             head_id++) {\n                            for (int l = 0; l < config_.head_dim; l++) {\n                                anchor_[layer_id * config_.max_block_num *\n                                            config_.anchor_num *\n                                            config_.q_head_num *\n                                            config_.head_dim +\n                                        block_idx * config_.anchor_num *\n                                            config_.q_head_num *\n                                            config_.head_dim +\n                                        0 * config_.q_head_num *\n                                            config_.head_dim +\n                                        head_id * config_.head_dim + l] =\n                                    GGML_FP32_TO_FP16(std::max(\n                                        GGML_FP16_TO_FP32(\n                                            k_cache_fp16_\n                                                [layer_id][head_id][block_idx]\n                                                [indice * config_.head_dim +\n                                                 l]),\n                                        GGML_FP16_TO_FP32(\n                                            anchor_[layer_id *\n                                                        
config_.max_block_num *\n                                                        config_.anchor_num *\n                                                        config_.q_head_num *\n                                                        config_.head_dim +\n                                                    block_idx *\n                                                        config_.anchor_num *\n                                                        config_.q_head_num *\n                                                        config_.head_dim +\n                                                    0 * config_.q_head_num *\n                                                        config_.head_dim +\n                                                    head_id * config_.head_dim +\n                                                    l])));\n\n                                anchor_[layer_id * config_.max_block_num *\n                                            config_.anchor_num *\n                                            config_.q_head_num *\n                                            config_.head_dim +\n                                        block_idx * config_.anchor_num *\n                                            config_.q_head_num *\n                                            config_.head_dim +\n                                        1 * config_.q_head_num *\n                                            config_.head_dim +\n                                        head_id * config_.head_dim + l] =\n                                    GGML_FP32_TO_FP16(std::min(\n                                        GGML_FP16_TO_FP32(\n                                            k_cache_fp16_\n                                                [layer_id][head_id][block_idx]\n                                                [indice * config_.head_dim +\n                                                 l]),\n                                        GGML_FP16_TO_FP32(\n          
                                  anchor_[layer_id *\n                                                        config_.max_block_num *\n                                                        config_.anchor_num *\n                                                        config_.q_head_num *\n                                                        config_.head_dim +\n                                                    block_idx *\n                                                        config_.anchor_num *\n                                                        config_.q_head_num *\n                                                        config_.head_dim +\n                                                    1 * config_.q_head_num *\n                                                        config_.head_dim +\n                                                    head_id * config_.head_dim +\n                                                    l])));\n                            }\n                        }\n                    }\n\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n                    for (int indice = 0; indice < seq_len_; indice++) {\n                        for (int head_id = 0; head_id < config_.kv_head_num;\n                             head_id++) {\n                            for (int l = 0; l < config_.head_dim / 32; l++) {\n                                block_q4_0 block =\n                                    k_cache_q4[layer_id][head_id][block_idx]\n                                              [indice * config_.head_dim / 32 +\n                                               l];\n                                dequantize_row_q4_0(&block, block_fp32.data(),\n                                                    32);\n\n                                for (int m = 0; m < 32; m++) {\n                                    for (int gqa_idx = 0; gqa_idx < n_gqa_;\n                                         gqa_idx++) {\n\n       
                                 anchor_[layer_id *\n                                                    config_.max_block_num *\n                                                    config_.anchor_num *\n                                                    config_.q_head_num *\n                                                    config_.head_dim +\n                                                block_idx * config_.anchor_num *\n                                                    config_.q_head_num *\n                                                    config_.head_dim +\n                                                0 * config_.q_head_num *\n                                                    config_.head_dim +\n                                                head_id * config_.head_dim +\n                                                l * 32 + m] =\n                                            GGML_FP32_TO_FP16(std::max(\n                                                block_fp32[m],\n                                                GGML_FP16_TO_FP32(\n                                                    anchor_\n                                                        [layer_id *\n                                                             config_\n                                                                 .max_block_num *\n                                                             config_\n                                                                 .anchor_num *\n                                                             config_\n                                                                 .q_head_num *\n                                                             config_.head_dim +\n                                                         block_idx *\n                                                             config_\n                                                                 .anchor_num *\n                                              
               config_\n                                                                 .q_head_num *\n                                                             config_.head_dim +\n                                                         0 *\n                                                             config_\n                                                                 .q_head_num *\n                                                             config_.head_dim +\n                                                         head_id *\n                                                             config_.head_dim +\n                                                         l * 32 + m])));\n\n                                        anchor_[layer_id *\n                                                    config_.max_block_num *\n                                                    config_.anchor_num *\n                                                    config_.q_head_num *\n                                                    config_.head_dim +\n                                                block_idx * config_.anchor_num *\n                                                    config_.q_head_num *\n                                                    config_.head_dim +\n                                                1 * config_.q_head_num *\n                                                    config_.head_dim +\n                                                head_id * config_.head_dim +\n                                                l * 32 + m] =\n                                            GGML_FP32_TO_FP16(std::min(\n                                                block_fp32[m],\n                                                GGML_FP16_TO_FP32(\n                                                    anchor_\n                                                        [layer_id *\n                                                             config_\n                    
                                             .max_block_num *\n                                                             config_\n                                                                 .anchor_num *\n                                                             config_\n                                                                 .q_head_num *\n                                                             config_.head_dim +\n                                                         block_idx *\n                                                             config_\n                                                                 .anchor_num *\n                                                             config_\n                                                                 .q_head_num *\n                                                             config_.head_dim +\n                                                         1 *\n                                                             config_\n                                                                 .q_head_num *\n                                                             config_.head_dim +\n                                                         head_id *\n                                                             config_.head_dim +\n                                                         l * 32 + m])));\n                                    }\n                                }\n                            }\n                        }\n                    }\n                } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n                    for (int indice = 0; indice < seq_len_; indice++) {\n                        for (int head_id = 0; head_id < config_.kv_head_num;\n                             head_id++) {\n                            for (int l = 0; l < config_.head_dim / 32; l++) {\n                                block_q8_0 block =\n                                 
   k_cache_q8[layer_id][head_id][block_idx]\n                                              [indice * config_.head_dim / 32 +\n                                               l];\n                                dequantize_row_q8_0(&block, block_fp32.data(),\n                                                    32);\n\n                                for (int m = 0; m < 32; m++) {\n                                    for (int gqa_idx = 0; gqa_idx < n_gqa_;\n                                         gqa_idx++) {\n\n                                        anchor_[layer_id *\n                                                    config_.max_block_num *\n                                                    config_.anchor_num *\n                                                    config_.q_head_num *\n                                                    config_.head_dim +\n                                                block_idx * config_.anchor_num *\n                                                    config_.q_head_num *\n                                                    config_.head_dim +\n                                                0 * config_.q_head_num *\n                                                    config_.head_dim +\n                                                head_id * config_.head_dim +\n                                                l * 32 + m] =\n                                            GGML_FP32_TO_FP16(std::max(\n                                                block_fp32[m],\n                                                GGML_FP16_TO_FP32(\n                                                    anchor_\n                                                        [layer_id *\n                                                             config_\n                                                                 .max_block_num *\n                                                             config_\n                                                      
           .anchor_num *\n                                                             config_\n                                                                 .q_head_num *\n                                                             config_.head_dim +\n                                                         block_idx *\n                                                             config_\n                                                                 .anchor_num *\n                                                             config_\n                                                                 .q_head_num *\n                                                             config_.head_dim +\n                                                         0 *\n                                                             config_\n                                                                 .q_head_num *\n                                                             config_.head_dim +\n                                                         head_id *\n                                                             config_.head_dim +\n                                                         l * 32 + m])));\n\n                                        anchor_[layer_id *\n                                                    config_.max_block_num *\n                                                    config_.anchor_num *\n                                                    config_.q_head_num *\n                                                    config_.head_dim +\n                                                block_idx * config_.anchor_num *\n                                                    config_.q_head_num *\n                                                    config_.head_dim +\n                                                1 * config_.q_head_num *\n                                                    config_.head_dim +\n                                           
     head_id * config_.head_dim +\n                                                l * 32 + m] =\n                                            GGML_FP32_TO_FP16(std::min(\n                                                block_fp32[m],\n                                                GGML_FP16_TO_FP32(\n                                                    anchor_\n                                                        [layer_id *\n                                                             config_\n                                                                 .max_block_num *\n                                                             config_\n                                                                 .anchor_num *\n                                                             config_\n                                                                 .q_head_num *\n                                                             config_.head_dim +\n                                                         block_idx *\n                                                             config_\n                                                                 .anchor_num *\n                                                             config_\n                                                                 .q_head_num *\n                                                             config_.head_dim +\n                                                         1 *\n                                                             config_\n                                                                 .q_head_num *\n                                                             config_.head_dim +\n                                                         head_id *\n                                                             config_.head_dim +\n                                                         l * 32 + m])));\n                                    }\n                           
     }\n                            }\n                        }\n                    }\n                }\n            } else {\n                assert(false);\n            }\n        },\n        nullptr);\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> duration = end - start;\n    //    printf(\"time of calc_anchor_all_layers: %f s\\n\", duration.count());\n}\n\nvoid KVCache::clear_importance_all_layers(int *block_table, int *cache_seqlens,\n                                          int batch_size, int max_block_num,\n                                          Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    // Each task updates the importance of a certain block\n    seq_len_ = config_.block_len;\n    backend->do_work_stealing_job(\n        config_.layer_num * batch_size * max_block_num, nullptr,\n        [&](int task_id) {\n            int layer_id = task_id / (batch_size * max_block_num);\n            int batch_id = (task_id / max_block_num) % batch_size;\n            int block_id = task_id % max_block_num;\n            // If the block is out of the sequence length, skip it. 
In\n            // particular, the last block of the sequence that is shorter than\n            // the block length should be skipped.\n\n            if (cache_seqlens[batch_id] / config_.block_len < block_id) {\n                return;\n            }\n            int block_idx = block_table[batch_id * max_block_num + block_id];\n\n            if (config_.anchor_type == AnchorType::DYNAMIC) {\n\n                // clear importance_\n                for (int head_id = 0; head_id < config_.q_head_num; head_id++) {\n                    for (int l = 0; l < config_.block_len; l++) {\n                        importance_[layer_id][block_idx][l][head_id] = 0;\n                    }\n                }\n            }\n        },\n        nullptr);\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> duration = end - start;\n    //    printf(\"time of clear_importance_all_layers: %f s\\n\",\n    //    duration.count());\n}\n\nvoid KVCache::clear_kvcache_all_layers(int *block_table, int *cache_seqlens,\n                                       int batch_size, int max_block_num,\n                                       Backend *backend) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    // Each task clears the K/V cache of one (layer, batch, block, head)\n    seq_len_ = config_.block_len;\n    backend->do_work_stealing_job(\n        config_.layer_num * batch_size * max_block_num * config_.kv_head_num,\n        nullptr,\n        [&](int task_id) {\n            int layer_id =\n                task_id / (batch_size * max_block_num * config_.kv_head_num);\n            int batch_id =\n                (task_id / (max_block_num * config_.kv_head_num)) % batch_size;\n            int block_id = task_id / config_.kv_head_num % max_block_num;\n            int head_id = task_id % config_.kv_head_num;\n            // If the block is out of the sequence length, skip it. 
In\n            // particular, the last block of the sequence that is shorter than\n            // the block length should be skipped.\n            if (cache_seqlens[batch_id] / config_.block_len < block_id) {\n                return;\n            }\n            int block_idx = block_table[batch_id * max_block_num + block_id];\n\n            if (config_.kv_type == ggml_type::GGML_TYPE_F16) {\n                for (int l = 0; l < config_.block_len * config_.head_dim; l++) {\n                    k_cache_fp16_[layer_id][head_id][block_idx][l] = 0;\n                    v_cache_fp16_[layer_id][head_id][block_idx][l] = 0;\n                }\n            } else if (config_.kv_type == ggml_type::GGML_TYPE_Q4_0) {\n                for (int l = 0; l < config_.block_len * config_.head_dim / 32;\n                     l++) {\n                    k_cache_q4[layer_id][head_id][block_idx][l].d = 0;\n                    v_cache_q4[layer_id][head_id][block_idx][l].d = 0;\n                }\n            } else if (config_.kv_type == ggml_type::GGML_TYPE_Q8_0) {\n                for (int l = 0; l < config_.block_len * config_.head_dim / 32;\n                     l++) {\n                    k_cache_q8[layer_id][head_id][block_idx][l].d = 0;\n                    v_cache_q8[layer_id][head_id][block_idx][l].d = 0;\n                }\n            }\n        },\n        nullptr);\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> duration = end - start;\n    //    printf(\"time of clear_kvcache_all_layers: %f s\\n\", duration.count());\n}\n\nvoid KVCache::get_sincos(ggml_fp16_t *sin, ggml_fp16_t *cos, int seqlen) {\n    // Timer start\n    auto start = std::chrono::high_resolution_clock::now();\n\n    const uint16_t *sin_data = const_cast<const uint16_t *>(sin);\n    const uint16_t *cos_data = const_cast<const uint16_t *>(cos);\n\n    for (int i = 0; i < seqlen; i++) {\n        for (int j = 0; j < config_.head_dim; j++) {\n       
     sin_[i][j] = sin_data[i * config_.head_dim + j];\n            cos_[i][j] = cos_data[i * config_.head_dim + j];\n        }\n    }\n\n    // Timer end\n    auto end = std::chrono::high_resolution_clock::now();\n    std::chrono::duration<double> duration = end - start;\n    printf(\"time of get_sincos: %f s\\n\", duration.count());\n}\n\nvoid ggml_vec_scale_f32(const int n, float *y, const float v) {\n#if defined(GGML_USE_ACCELERATE)\n    vDSP_vsmul(y, 1, &v, y, 1, n);\n#elif defined(GGML_SIMD)\n    const int np = (n & ~(GGML_F32_STEP - 1));\n\n    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);\n\n    GGML_F32_VEC ay[GGML_F32_ARR];\n\n    for (int i = 0; i < np; i += GGML_F32_STEP) {\n        for (int j = 0; j < GGML_F32_ARR; j++) {\n            ay[j] = GGML_F32_VEC_LOAD(y + i + j * GGML_F32_EPR);\n            ay[j] = GGML_F32_VEC_MUL(ay[j], vx);\n\n            GGML_F32_VEC_STORE(y + i + j * GGML_F32_EPR, ay[j]);\n        }\n    }\n\n    // leftovers\n    for (int i = np; i < n; ++i) {\n        y[i] *= v;\n    }\n#else\n    // scalar\n    for (int i = 0; i < n; ++i) {\n        y[i] *= v;\n    }\n#endif\n}"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/operators/llamafile/conversion.h",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-12 10:07:58\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022 \n * @LastEditTime : 2024-07-25 10:34:55\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#ifndef CPUINFER_CONVERSION_H\n#define CPUINFER_CONVERSION_H\n\n#include <memory.h>\n#include \"llama.cpp/ggml.h\"\n\ninline void to_float(const void* input, float* output, int size, ggml_type type) {\n    if (type == ggml_type::GGML_TYPE_F32) {\n        memcpy(output, input, size * sizeof(float));\n    } else {\n        ggml_internal_get_type_traits(type).to_float(input, output, size);\n    }\n}\n\ninline void from_float(const float* input, void* output, int size, ggml_type type) {\n    if (type == ggml_type::GGML_TYPE_F32) {\n        memcpy(output, input, size * sizeof(float));\n    } else {\n        ggml_internal_get_type_traits(type).from_float(input, output, size);\n    }\n}\n\n#endif"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/operators/llamafile/linear.cpp",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-12 10:07:58\n * @Version      : 1.0.0\n * @LastEditors  : kkk1nak0\n * @LastEditTime : 2024-08-15 07:45:18\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#include \"linear.h\"\n\nLinear::Linear(LinearConfig config) {\n    config_ = config;\n    proj_ = config_.proj;\n\n    std::vector<std::pair<void**, uint64_t>> mem_requests;\n    mem_requests.push_back({(void**)&input_fp32_, sizeof(float) * config_.group_max_len * config_.input_size});\n    mem_requests.push_back({(void**)&proj_input_, config_.group_max_len * config_.input_size * ggml_type_size(ggml_internal_get_type_traits(config_.proj_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.proj_type).vec_dot_type)});\n    mem_requests.push_back({(void**)&proj_output_, sizeof(float) * config_.group_max_len * config_.output_size});\n    shared_mem_buffer.alloc(this, mem_requests);\n}\n\nLinear::~Linear() {\n    shared_mem_buffer.dealloc(this);\n}\n\nvoid Linear::warm_up(Backend *backend) {\n    std::vector<float> input_fp32(config_.input_size);\n    std::vector<uint8_t> input(config_.input_size *\n                               ggml_type_size(config_.hidden_type) /\n                               ggml_blck_size(config_.hidden_type));\n    std::vector<uint8_t> output(config_.output_size *\n                                ggml_type_size(config_.hidden_type) /\n                                ggml_blck_size(config_.hidden_type));\n    for (int i = 0; i < config_.input_size; i++) {\n        input_fp32[i] = 0;\n    }\n    from_float(input_fp32.data(), input.data(), config_.input_size, config_.hidden_type);\n    forward_many(1, input.data(), output.data(), backend);\n}\n\nvoid Linear::forward_many(int qlen, const void* input, void* output, Backend* backend) {\n    const void* proj_input_ptr;\n    if (config_.hidden_type == 
ggml_internal_get_type_traits(config_.proj_type).vec_dot_type) {\n        proj_input_ptr = input;\n    } else {\n        to_float(input, input_fp32_, qlen * config_.input_size, config_.hidden_type);\n        from_float(input_fp32_, proj_input_, qlen * config_.input_size, ggml_internal_get_type_traits(config_.proj_type).vec_dot_type);\n        proj_input_ptr = proj_input_;\n    }\n    int nth = config_.output_size / config_.stride;\n    backend->do_work_stealing_job(nth, nullptr, [&](int task_id) {\n        int ith = task_id;\n        void* proj_ptr = (uint8_t*)proj_ + ith * config_.stride * config_.input_size * ggml_type_size(config_.proj_type) / ggml_blck_size(config_.proj_type);\n        float* proj_output_ptr = proj_output_ + ith * config_.stride;\n        llamafile_sgemm(config_.stride, qlen, config_.input_size / ggml_blck_size(config_.proj_type), proj_ptr, config_.input_size / ggml_blck_size(config_.proj_type), proj_input_ptr, config_.input_size / ggml_blck_size(config_.proj_type), proj_output_ptr, config_.output_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.proj_type, ggml_internal_get_type_traits(config_.proj_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n        if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) {\n            for (int i = 0; i < qlen; i++) {\n                float* output_fp32_ptr = proj_output_ + i * config_.output_size + ith * config_.stride;\n                void* output_ptr = (uint8_t*)output + i * config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type) + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);\n                from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type);\n            }\n        }\n    }, nullptr);\n    if (config_.stride % ggml_blck_size(config_.hidden_type) != 0) {\n        from_float(proj_output_, output, qlen * config_.output_size, config_.hidden_type);\n    }\n}\n\nvoid 
Linear::forward(int qlen, const void* input, void* output, Backend* backend) {\n    if (qlen <= 0) {\n        return;\n    }\n    // The scratch buffers are sized for group_max_len rows, so process at most\n    // that many rows here and recurse on the remainder.\n    int forward_len = std::min(qlen, config_.group_max_len);\n    forward_many(forward_len, input, output, backend);\n    forward(qlen - forward_len, (uint8_t*)input + forward_len * config_.input_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + forward_len * config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend);\n}"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/operators/llamafile/linear.h",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-12 10:07:58\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2024-07-25 10:35:00\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#ifndef CPUINFER_OPERATOR_LINEAR_H\n#define CPUINFER_OPERATOR_LINEAR_H\n\n#include <cmath>\n#include <cstdio>\n#include <functional>\n#include <mutex>\n#include <vector>\n\n#include \"../../cpu_backend/backend.h\"\n#include \"../../cpu_backend/shared_mem_buffer.h\"\n#include \"conversion.h\"\n#include \"llama.cpp/ggml-impl.h\"\n#include \"llama.cpp/ggml-quants.h\"\n#include \"llama.cpp/ggml.h\"\n#include \"llamafile/sgemm.h\"\n\nstruct LinearConfig {\n    int input_size;\n    int output_size;\n    int stride;\n    int group_max_len;\n    void* proj;\n    ggml_type proj_type;\n    ggml_type hidden_type;\n\n    LinearConfig() {}\n\n    LinearConfig(int input_size, int output_size, int stride, int group_max_len, void* proj, ggml_type proj_type, ggml_type hidden_type)\n        : input_size(input_size), output_size(output_size), stride(stride), group_max_len(group_max_len), proj(proj), proj_type(proj_type), hidden_type(hidden_type) {}\n};\n\nclass Linear {\n   public:\n    Linear(LinearConfig);\n    ~Linear();\n    void warm_up(Backend* backend);\n    void forward_many(int qlen, const void* input, void* output, Backend* backend);\n    void forward(int qlen, const void* input, void* output, Backend* backend);\n\n   private:\n    LinearConfig config_;\n    void* proj_;  // [output_size * input_size ( /32 if quantized)]\n\n    float* input_fp32_;    // [group_max_len * input_size]\n    uint8_t* proj_input_;  // [group_max_len * input_size * ggml_type_size(ggml_internal_get_type_traits(proj_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(proj_type).vec_dot_type)]\n    float* proj_output_;   // [group_max_len * output_size]\n};\n\n#endif"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/operators/llamafile/mlp.cpp",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-16 10:43:18\n * @Version      : 1.0.0\n * @LastEditors  : kkk1nak0\n * @LastEditTime : 2024-08-15 07:44:38\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#include \"mlp.h\"\n\nMLP::MLP(MLPConfig config) {\n    config_ = config;\n    gate_proj_ = config_.gate_proj;\n    up_proj_ = config_.up_proj;\n    down_proj_ = config_.down_proj;\n\n    std::vector<std::pair<void**, uint64_t>> mem_requests;\n    mem_requests.push_back({(void**)&input_fp32_, sizeof(float) * config_.group_max_len * config_.hidden_size});\n    mem_requests.push_back({(void**)&gate_input_, config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type)});\n    mem_requests.push_back({(void**)&up_input_, config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type)});\n    mem_requests.push_back({(void**)&gate_output_, sizeof(float) * config_.group_max_len * config_.intermediate_size});\n    mem_requests.push_back({(void**)&up_output_, sizeof(float) * config_.group_max_len * config_.intermediate_size});\n    mem_requests.push_back({(void**)&intermediate_fp32_, sizeof(float) * config_.group_max_len * config_.intermediate_size});\n    mem_requests.push_back({(void**)&down_input_, config_.group_max_len * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type)});\n    mem_requests.push_back({(void**)&down_output_, sizeof(float) * config_.group_max_len * config_.hidden_size});\n    shared_mem_buffer.alloc(this, mem_requests);\n}\n\nMLP::~MLP() {\n    
shared_mem_buffer.dealloc(this);\n}\n\nvoid MLP::warm_up(Backend *backend) {\n    std::vector<float> input_fp32(config_.hidden_size);\n    std::vector<uint8_t> input(config_.hidden_size *\n                               ggml_type_size(config_.hidden_type) /\n                               ggml_blck_size(config_.hidden_type));\n    std::vector<uint8_t> output(config_.hidden_size *\n                                ggml_type_size(config_.hidden_type) /\n                                ggml_blck_size(config_.hidden_type));\n    for (int i = 0; i < config_.hidden_size; i++) {\n        input_fp32[i] = 0;\n    }\n    from_float(input_fp32.data(), input.data(), config_.hidden_size, config_.hidden_type);\n    forward_many(1, input.data(), output.data(), backend);\n}\n\nstatic float act_fn(float x) { return x / (1.0f + expf(-x)); }\n\nvoid MLP::forward_many(int qlen, const void* input, void* output, Backend* backend) {\n    const void* gate_input_ptr;\n    const void* up_input_ptr;\n    if (config_.hidden_type == ggml_internal_get_type_traits(config_.gate_type).vec_dot_type && config_.hidden_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {\n        gate_input_ptr = up_input_ptr = input;\n    } else {\n        to_float(input, input_fp32_, qlen * config_.hidden_size, config_.hidden_type);\n        if (ggml_internal_get_type_traits(config_.gate_type).vec_dot_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {\n            from_float(input_fp32_, gate_input_, qlen * config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);\n            gate_input_ptr = up_input_ptr = gate_input_;\n        } else {\n            if (config_.hidden_type != ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) {\n                from_float(input_fp32_, gate_input_, qlen * config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);\n                gate_input_ptr = gate_input_;\n            } else 
{\n                gate_input_ptr = input;\n            }\n            if (config_.hidden_type != ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {\n                from_float(input_fp32_, up_input_, qlen * config_.hidden_size, ggml_internal_get_type_traits(config_.up_type).vec_dot_type);\n                up_input_ptr = up_input_;\n            } else {\n                up_input_ptr = input;\n            }\n        }\n    }\n    int nth = config_.intermediate_size / config_.stride;\n    backend->do_work_stealing_job(nth, nullptr, [&](int task_id) {\n        int ith = task_id;\n        void* gate_proj_ptr = (uint8_t*)gate_proj_ + ith * config_.stride * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);\n        float* gate_output_ptr = gate_output_ + ith * config_.stride;\n        llamafile_sgemm(config_.stride, qlen, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n        void* up_proj_ptr = (uint8_t*)up_proj_ + ith * config_.stride * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);\n        float* up_output_ptr = up_output_ + ith * config_.stride;\n        llamafile_sgemm(config_.stride, qlen, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n        for (int i = 0; i < qlen; i++) {\n            for (int j = 
ith * config_.stride; j < (ith + 1) * config_.stride; j++) {\n                intermediate_fp32_[i * config_.intermediate_size + j] = act_fn(gate_output_[i * config_.intermediate_size + j]) * up_output_[i * config_.intermediate_size + j];\n            }\n            if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) == 0) {\n                float* intermediate_fp32_ptr = intermediate_fp32_ + i * config_.intermediate_size + ith * config_.stride;\n                void* down_input_ptr = (uint8_t*)down_input_ + i * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) + ith * config_.stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);\n                from_float(intermediate_fp32_ptr, down_input_ptr, config_.stride, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);\n            }\n        }\n    }, nullptr);\n    if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) != 0) {\n        from_float(intermediate_fp32_, down_input_, qlen * config_.intermediate_size, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);\n    }\n    nth = config_.hidden_size / config_.stride;\n    backend->do_work_stealing_job(nth, nullptr, [&](int task_id) {\n        int ith = task_id;\n        void* down_proj_ptr = (uint8_t*)down_proj_ + ith * config_.stride * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);\n        float* down_output_ptr = down_output_ + ith * config_.stride;\n        llamafile_sgemm(config_.stride, qlen, config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_, 
config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n        if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) {\n            for (int i = 0; i < qlen; i++) {\n                float* output_fp32_ptr = down_output_ + i * config_.hidden_size + ith * config_.stride;\n                void* output_ptr = (uint8_t*)output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type) + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);\n                from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type);\n            }\n        }\n    }, nullptr);\n    if (config_.stride % ggml_blck_size(config_.hidden_type) != 0) {\n        from_float(down_output_, output, qlen * config_.hidden_size, config_.hidden_type);\n    }\n}\n\nvoid MLP::forward(int qlen, const void* input, void* output, Backend* backend) {\n    if (qlen <= 0) {\n        return;\n    }\n    int forward_len = std::min(qlen, config_.group_max_len);\n    forward_many(forward_len, input, output, backend);\n    forward(qlen - forward_len, (uint8_t*)input + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend);\n}"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/operators/llamafile/mlp.h",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-12 10:07:58\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2024-07-25 10:35:06\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#ifndef CPUINFER_OPERATOR_MLP_H\n#define CPUINFER_OPERATOR_MLP_H\n\n#include <cmath>\n#include <cstdio>\n#include <functional>\n#include <mutex>\n#include <vector>\n\n#include \"../../cpu_backend/backend.h\"\n#include \"../../cpu_backend/shared_mem_buffer.h\"\n#include \"conversion.h\"\n#include \"llama.cpp/ggml-impl.h\"\n#include \"llama.cpp/ggml-quants.h\"\n#include \"llama.cpp/ggml.h\"\n#include \"llamafile/sgemm.h\"\n\nstruct MLPConfig {\n    int hidden_size;\n    int intermediate_size;\n    int stride;\n    int group_max_len;\n    void* gate_proj;\n    void* up_proj;\n    void* down_proj;\n    ggml_type gate_type;\n    ggml_type up_type;\n    ggml_type down_type;\n    ggml_type hidden_type;\n\n    MLPConfig() {}\n\n    MLPConfig(int hidden_size, int intermediate_size, int stride, int group_max_len, void* gate_proj, void* up_proj, void* down_proj, ggml_type gate_type, ggml_type up_type, ggml_type down_type, ggml_type hidden_type)\n        : hidden_size(hidden_size), intermediate_size(intermediate_size), stride(stride), group_max_len(group_max_len), gate_proj(gate_proj), up_proj(up_proj), down_proj(down_proj), gate_type(gate_type), up_type(up_type), down_type(down_type), hidden_type(hidden_type) {}\n};\n\nclass MLP {\n   public:\n    MLP(MLPConfig);\n    ~MLP();\n    void warm_up(Backend* backend);\n    void forward_many(int qlen, const void* input, void* output, Backend* backend);\n    void forward(int qlen, const void* input, void* output, Backend* backend);\n\n   private:\n    MLPConfig config_;\n    void* gate_proj_;  // [intermediate_size * hidden_size ( /32 if quantized)]\n    void* up_proj_;    // [intermediate_size * hidden_size ( /32 if quantized)]\n    void* down_proj_;  // 
[hidden_size * intermediate_size ( /32 if quantized)]\n\n    float* input_fp32_;         // [group_max_len * hidden_size]\n    uint8_t* gate_input_;       // [group_max_len * hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)]\n    uint8_t* up_input_;         // [group_max_len * hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)]\n    float* gate_output_;        // [group_max_len * intermediate_size]\n    float* up_output_;          // [group_max_len * intermediate_size]\n    float* intermediate_fp32_;  // [group_max_len * intermediate_size]\n    uint8_t* down_input_;       // [group_max_len * intermediate_size * ggml_type_size(ggml_internal_get_type_traits(down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(down_type).vec_dot_type)]\n    float* down_output_;        // [group_max_len * hidden_size]\n};\n\n#endif"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/operators/llamafile/moe.cpp",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-22 02:03:22\n * @Version      : 1.0.0\n * @LastEditors  : kkk1nak0\n * @LastEditTime : 2024-08-15 07:43:41\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#include \"moe.h\"\n#include <iostream>\n#include <cstdint>\n\n#ifdef USE_NUMA\n#include <numa.h>\n#include <numaif.h>\n#endif\n\nMOE::MOE(MOEConfig config) {\n    config_ = config;\n    gate_proj_ = config_.gate_proj;\n    up_proj_ = config_.up_proj;\n    down_proj_ = config_.down_proj;\n    \n    #ifdef USE_NUMA\n    int numa_nodes = numa_num_configured_nodes();\n    gate_proj_numa_.resize(numa_nodes);\n    up_proj_numa_.resize(numa_nodes);\n    down_proj_numa_.resize(numa_nodes);\n    size_t exp_inter_hidden_mul_ = (size_t)config.expert_num * config.intermediate_size * config.hidden_size;\n    for (int i = 0; i < numa_nodes; i++) {\n        gate_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_* ggml_type_size(config.gate_type) / ggml_blck_size(config.gate_type), i);\n        up_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_* ggml_type_size(config.up_type) / ggml_blck_size(config.up_type), i);\n        down_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_* ggml_type_size(config.down_type) / ggml_blck_size(config.down_type), i);\n        if (!gate_proj_numa_[i]) {\n            std::cout << \"Memory allocation failed for gate_proj_numa_ on node \" << i << std::endl;\n        }\n        if (!up_proj_numa_[i]) {\n            std::cout << \"Memory allocation failed for up_proj_numa_ on node \" << i << std::endl;\n        }\n        if (!down_proj_numa_[i]) {\n            std::cout << \"Memory allocation failed for down_proj_numa_ on node \" << i << std::endl;\n        }\n        memcpy(gate_proj_numa_[i], gate_proj_, exp_inter_hidden_mul_* ggml_type_size(config.gate_type) / ggml_blck_size(config.gate_type));\n        memcpy(up_proj_numa_[i], up_proj_, exp_inter_hidden_mul_* 
ggml_type_size(config.up_type) / ggml_blck_size(config.up_type));\n        memcpy(down_proj_numa_[i], down_proj_, exp_inter_hidden_mul_* ggml_type_size(config.down_type) / ggml_blck_size(config.down_type));\n    }\n    #endif\n\n    std::vector<std::pair<void**, uint64_t>> s_mem_requests;\n    s_mem_requests.push_back({(void**)&s_input_fp32_, sizeof(float) * config_.hidden_size});\n    s_mem_requests.push_back({(void**)&s_gate_input_, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type)});\n    s_mem_requests.push_back({(void**)&s_up_input_, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type)});\n    s_gate_output_.resize(config_.routed_expert_num);\n    s_up_output_.resize(config_.routed_expert_num);\n    s_intermediate_fp32_.resize(config_.routed_expert_num);\n    s_down_input_.resize(config_.routed_expert_num);\n    s_down_output_.resize(config_.routed_expert_num);\n    for (int i = 0; i < config_.routed_expert_num; i++) {\n        s_mem_requests.push_back({(void**)&s_gate_output_[i], sizeof(float) * config_.intermediate_size});\n        s_mem_requests.push_back({(void**)&s_up_output_[i], sizeof(float) * config_.intermediate_size});\n        s_mem_requests.push_back({(void**)&s_intermediate_fp32_[i], sizeof(float) * config_.intermediate_size});\n        s_mem_requests.push_back({(void**)&s_down_input_[i], config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type)});\n        s_mem_requests.push_back({(void**)&s_down_output_[i], sizeof(float) * config_.hidden_size});\n    }\n    s_mem_requests.push_back({(void**)&s_output_fp32_, sizeof(float) * config_.hidden_size});\n    
shared_mem_buffer.alloc(this, s_mem_requests);\n\n    std::vector<std::pair<void**, uint64_t>> m_mem_requests;\n    m_input_fp32_.resize(config_.group_max_len);\n    m_gate_input_.resize(config_.group_max_len);\n    m_up_input_.resize(config_.group_max_len);\n    for (int i = 0; i < config_.group_max_len; i++) {\n        m_mem_requests.push_back({(void**)&m_input_fp32_[i], sizeof(float) * config_.hidden_size});\n        m_mem_requests.push_back({(void**)&m_gate_input_[i], config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type)});\n        m_mem_requests.push_back({(void**)&m_up_input_[i], config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type)});\n    }\n    m_mem_requests.push_back({(void**)&m_local_gate_input_, config_.routed_expert_num * config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type)});\n    m_mem_requests.push_back({(void**)&m_local_up_input_, config_.routed_expert_num * config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type)});\n    m_mem_requests.push_back({(void**)&m_local_gate_output_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size});\n    m_mem_requests.push_back({(void**)&m_local_up_output_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size});\n    m_mem_requests.push_back({(void**)&m_local_intermediate_fp32_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size});\n    
m_mem_requests.push_back({(void**)&m_local_down_input_, config_.routed_expert_num * config_.group_max_len * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type)});\n    m_mem_requests.push_back({(void**)&m_local_down_output_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.hidden_size});\n    m_output_fp32_.resize(config_.group_max_len);\n    for (int i = 0; i < config_.group_max_len; i++) {\n        m_mem_requests.push_back({(void**)&m_output_fp32_[i], sizeof(float) * config_.hidden_size});\n    }\n    shared_mem_buffer.alloc(this, m_mem_requests);\n\n    m_local_pos_.resize(config_.group_max_len);\n    for (int i = 0; i < config_.group_max_len; i++) {\n        m_local_pos_[i].resize(config_.routed_expert_num);\n    }\n    m_local_num_.resize(config_.expert_num);\n    m_local_gate_input_ptr_.resize(config_.expert_num);\n    m_local_up_input_ptr_.resize(config_.expert_num);\n    m_local_gate_output_ptr_.resize(config_.expert_num);\n    m_local_up_output_ptr_.resize(config_.expert_num);\n    m_local_intermediate_fp32_ptr_.resize(config_.expert_num);\n    m_local_down_input_ptr_.resize(config_.expert_num);\n    m_local_down_output_ptr_.resize(config_.expert_num);\n}\n\nMOE::~MOE() {\n    shared_mem_buffer.dealloc(this);\n\n    #ifdef USE_NUMA\n    int numa_nodes = numa_num_configured_nodes();\n    for (int i = 0; i < numa_nodes; i++) {\n        // Cast to size_t before multiplying so the byte count cannot overflow int;\n        // this mirrors the size computed in the constructor.\n        numa_free(gate_proj_numa_[i], (size_t)config_.expert_num * config_.intermediate_size * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type));\n        numa_free(up_proj_numa_[i], (size_t)config_.expert_num * config_.intermediate_size * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type));\n        numa_free(down_proj_numa_[i], (size_t)config_.expert_num * config_.hidden_size * config_.intermediate_size * 
ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type));\n    }\n    #endif\n}\n\nvoid MOE::warm_up(Backend* backend) {\n    std::vector<float> input_fp32(config_.hidden_size);\n    std::vector<uint8_t> input(config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type));\n    std::vector<uint8_t> output(config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type));\n    for (int i = 0; i < config_.hidden_size; i++) {\n        input_fp32[i] = 0;\n    }\n    from_float(input_fp32.data(), input.data(), config_.hidden_size, config_.hidden_type);\n    for (int i = 0; i < config_.expert_num; i++) {\n        uint64_t expert_ids = i;\n        float weights = 0;\n        forward_one(1, &expert_ids, &weights, input.data(), output.data(), backend);\n    }\n}\n\nstatic float act_fn(float x) {\n    return x / (1.0f + expf(-x));\n}\n\nvoid MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend) {\n    const void* gate_input_ptr;\n    const void* up_input_ptr;\n    if (config_.hidden_type == ggml_internal_get_type_traits(config_.gate_type).vec_dot_type && config_.hidden_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {\n        gate_input_ptr = up_input_ptr = input;\n    } else {\n        to_float(input, s_input_fp32_, config_.hidden_size, config_.hidden_type);\n        if (ggml_internal_get_type_traits(config_.gate_type).vec_dot_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {\n            from_float(s_input_fp32_, s_gate_input_, config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);\n            gate_input_ptr = up_input_ptr = s_gate_input_;\n        } else {\n            if (config_.hidden_type != ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) {\n                from_float(s_input_fp32_, s_gate_input_, config_.hidden_size, 
ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);\n                gate_input_ptr = s_gate_input_;\n            } else {\n                gate_input_ptr = input;\n            }\n            if (config_.hidden_type != ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {\n                from_float(s_input_fp32_, s_up_input_, config_.hidden_size, ggml_internal_get_type_traits(config_.up_type).vec_dot_type);\n                up_input_ptr = s_up_input_;\n            } else {\n                up_input_ptr = input;\n            }\n        }\n    }\n    int nth = config_.intermediate_size / config_.stride;\n    backend->do_work_stealing_job(nth * k, nullptr, [&](int task_id) {\n        int expert_idx = task_id / nth;\n        uint64_t expert_id = expert_ids[expert_idx];\n        int ith = task_id % nth;\n        \n        #ifdef USE_NUMA\n        void* gate_proj_ptr = (uint8_t*)gate_proj_numa_[Backend::numa_node] + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);\n        #else\n        void* gate_proj_ptr = (uint8_t*)gate_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);\n        #endif\n\n        float* gate_output_ptr = s_gate_output_[expert_idx] + ith * config_.stride;\n        llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n\n        #ifdef USE_NUMA\n        void* up_proj_ptr = (uint8_t*)up_proj_numa_[Backend::numa_node] + (expert_id * 
config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);\n        #else\n        void* up_proj_ptr = (uint8_t*)up_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);\n        #endif\n\n        float* up_output_ptr = s_up_output_[expert_idx] + ith * config_.stride;\n        llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n        for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {\n            s_intermediate_fp32_[expert_idx][i] = act_fn(s_gate_output_[expert_idx][i]) * s_up_output_[expert_idx][i];\n        }\n        if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) == 0) {\n            float* intermediate_fp32_ptr = s_intermediate_fp32_[expert_idx] + ith * config_.stride;\n            void* down_input_ptr = s_down_input_[expert_idx] + ith * config_.stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);\n            from_float(intermediate_fp32_ptr, down_input_ptr, config_.stride, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);\n        }\n    }, nullptr);\n    if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) != 0) {\n        for (int i = 0; i < k; i++) {\n            from_float(s_intermediate_fp32_[i], s_down_input_[i], config_.intermediate_size, 
ggml_internal_get_type_traits(config_.down_type).vec_dot_type);\n        }\n    }\n    nth = config_.hidden_size / config_.stride;\n    backend->do_work_stealing_job(nth, nullptr, [&](int task_id) {\n        int ith = task_id;\n        for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {\n            s_output_fp32_[i] = 0;\n        }\n        for (int expert_idx = 0; expert_idx < k; expert_idx++) {\n            uint64_t expert_id = expert_ids[expert_idx];\n\n            #ifdef USE_NUMA\n            void* down_proj_ptr = (uint8_t*)down_proj_numa_[Backend::numa_node] + (expert_id * config_.hidden_size + ith * config_.stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);\n            #else\n            void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_id * config_.hidden_size + ith * config_.stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);\n            #endif\n            \n            float* down_output_ptr = s_down_output_[expert_idx] + ith * config_.stride;\n            llamafile_sgemm(config_.stride, 1, config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), s_down_input_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n            for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {\n                s_output_fp32_[i] += s_down_output_[expert_idx][i] * weights[expert_idx];\n            }\n        }\n        if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) {\n            float* output_fp32_ptr = s_output_fp32_ + ith * config_.stride;\n            void* output_ptr = (uint8_t*)output + ith * config_.stride * 
ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);\n            from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type);\n        }\n    }, nullptr);\n    if (config_.stride % ggml_blck_size(config_.hidden_type) != 0) {\n        from_float(s_output_fp32_, output, config_.hidden_size, config_.hidden_type);\n    }\n}\n\nvoid MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend) {\n    for (int i = 0; i < config_.expert_num; i++) {\n        m_local_num_[i] = 0;\n    }\n    for (int i = 0; i < qlen; i++) {\n        for (int j = 0; j < k; j++) {\n            m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;\n        }\n    }\n    uint64_t offset = 0;\n    for (int i = 0; i < config_.expert_num; i++) {\n        m_local_gate_input_ptr_[i] = m_local_gate_input_ + offset * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);\n        m_local_up_input_ptr_[i] = m_local_up_input_ + offset * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type);\n        m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;\n        m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;\n        m_local_intermediate_fp32_ptr_[i] = m_local_intermediate_fp32_ + offset * config_.intermediate_size;\n        m_local_down_input_ptr_[i] = m_local_down_input_ + offset * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);\n        m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;\n        
offset += m_local_num_[i];\n    }\n    backend->do_work_stealing_job(qlen, nullptr, [&](int i) {\n        const void* gate_input_ptr;\n        const void* up_input_ptr;\n        if (config_.hidden_type == ggml_internal_get_type_traits(config_.gate_type).vec_dot_type && config_.hidden_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {\n            gate_input_ptr = up_input_ptr = (uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);\n        } else {\n            to_float((uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), m_input_fp32_[i], config_.hidden_size, config_.hidden_type);\n            if (ggml_internal_get_type_traits(config_.gate_type).vec_dot_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {\n                from_float(m_input_fp32_[i], m_gate_input_[i], config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);\n                gate_input_ptr = up_input_ptr = m_gate_input_[i];\n            } else {\n                if (config_.hidden_type != ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) {\n                    from_float(m_input_fp32_[i], m_gate_input_[i], config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);\n                    gate_input_ptr = m_gate_input_[i];\n                } else {\n                    gate_input_ptr = (uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);\n                }\n                if (config_.hidden_type != ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {\n                    from_float(m_input_fp32_[i], m_up_input_[i], config_.hidden_size, ggml_internal_get_type_traits(config_.up_type).vec_dot_type);\n                    up_input_ptr = m_up_input_[i];\n                } else {\n                    
up_input_ptr = (uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);\n                }\n            }\n        }\n        for (int j = 0; j < k; j++) {\n            memcpy(m_local_gate_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type), gate_input_ptr, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type));\n            memcpy(m_local_up_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type), up_input_ptr, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type));\n        }\n    }, nullptr);\n    int stride = QK_K;\n    int nth = config_.intermediate_size / stride;\n    backend->do_work_stealing_job(nth * config_.expert_num, nullptr, [&](int task_id) {\n        uint64_t expert_idx = task_id / nth;\n        int ith = task_id % nth;\n        void* gate_input_ptr = m_local_gate_input_ptr_[expert_idx];\n\n        #ifdef USE_NUMA\n        void* gate_proj_ptr = (uint8_t*)gate_proj_numa_[Backend::numa_node] + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);\n        #else\n        void* gate_proj_ptr = (uint8_t*)gate_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);\n        #endif\n\n        float* 
gate_output_ptr = m_local_gate_output_ptr_[expert_idx] + ith * stride;\n        llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n        void* up_input_ptr = m_local_up_input_ptr_[expert_idx];\n\n        #ifdef USE_NUMA\n        void* up_proj_ptr = (uint8_t*)up_proj_numa_[Backend::numa_node] + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);\n        #else\n        void* up_proj_ptr = (uint8_t*)up_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);\n        #endif\n\n        float* up_output_ptr = m_local_up_output_ptr_[expert_idx] + ith * stride;\n        llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n        for (int i = 0; i < m_local_num_[expert_idx]; i++) {\n            for (int j = ith * stride; j < (ith + 1) * stride; j++) {\n                m_local_intermediate_fp32_ptr_[expert_idx][i * config_.intermediate_size + j] = act_fn(m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size + j]) * m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size + j];\n            }\n            
float* intermediate_fp32_ptr = m_local_intermediate_fp32_ptr_[expert_idx] + i * config_.intermediate_size + ith * stride;\n            void* down_input_ptr = m_local_down_input_ptr_[expert_idx] + i * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) + ith * stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);\n            from_float(intermediate_fp32_ptr, down_input_ptr, stride, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);\n        }\n    }, nullptr);\n    stride = QK_K;\n    nth = config_.hidden_size / stride;\n    backend->do_work_stealing_job(nth * config_.expert_num, nullptr, [&](int task_id) {\n        uint64_t expert_idx = task_id / nth;\n        int ith = task_id % nth;\n        void* down_input_ptr = m_local_down_input_ptr_[expert_idx];\n        \n        #ifdef USE_NUMA\n        void* down_proj_ptr = (uint8_t*)down_proj_numa_[Backend::numa_node] + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);\n        #else\n        void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);\n        #endif\n\n        float* down_output_ptr = m_local_down_output_ptr_[expert_idx] + ith * stride;\n        llamafile_sgemm(stride, m_local_num_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, 
ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n    }, nullptr);\n    backend->do_work_stealing_job(qlen, nullptr, [&](int i) {\n        for (int e = 0; e < config_.hidden_size; e++) {\n            m_output_fp32_[i][e] = 0;\n        }\n        for (int j = 0; j < k; j++) {\n            for (int e = 0; e < config_.hidden_size; e++) {\n                m_output_fp32_[i][e] += m_local_down_output_ptr_[expert_ids[i * k + j]][m_local_pos_[i][j] * config_.hidden_size + e] * weights[i * k + j];\n            }\n        }\n        from_float(m_output_fp32_[i], (uint8_t*)output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), config_.hidden_size, config_.hidden_type);\n    }, nullptr);\n}\n\nvoid MOE::forward(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, int* batch_size_tensor, Backend* backend) {\n    qlen = batch_size_tensor[0];\n    if (qlen < config_.group_min_len) {\n        for (int i = 0; i < qlen; i++) {\n            forward_one(k, expert_ids + i * k, weights + i * k, (uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend);\n        }\n        return;\n    }\n    int forward_len = std::min(config_.group_max_len, qlen);\n    forward_many(forward_len, k, expert_ids, weights, input, output, backend);\n\n    batch_size_tensor[0] -= forward_len;\n    forward(qlen - forward_len, k, expert_ids + forward_len * k, weights + forward_len * k, (uint8_t*)input + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), batch_size_tensor, backend);\n}"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/operators/llamafile/moe.h",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-22 02:03:22\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2024-07-25 10:35:10\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#ifndef CPUINFER_OPERATOR_MOE_H\n#define CPUINFER_OPERATOR_MOE_H\n\n#include <cmath>\n#include <cstdio>\n#include <functional>\n#include <mutex>\n#include <vector>\n\n#include \"../../cpu_backend/backend.h\"\n#include \"../../cpu_backend/shared_mem_buffer.h\"\n#include \"conversion.h\"\n#include \"llama.cpp/ggml-impl.h\"\n#include \"llama.cpp/ggml-quants.h\"\n#include \"llama.cpp/ggml.h\"\n#include \"llamafile/sgemm.h\"\n\nstruct MOEConfig {\n    int expert_num;\n    int routed_expert_num;\n    int hidden_size;\n    int intermediate_size;\n    int stride;\n    int group_min_len;\n    int group_max_len;\n    void* gate_proj;\n    void* up_proj;\n    void* down_proj;\n    ggml_type gate_type;\n    ggml_type up_type;\n    ggml_type down_type;\n    ggml_type hidden_type;\n\n    MOEConfig() {}\n\n    MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int stride, int group_min_len, int group_max_len, void* gate_proj, void* up_proj, void* down_proj, ggml_type gate_type, ggml_type up_type, ggml_type down_type, ggml_type hidden_type)\n        : expert_num(expert_num), routed_expert_num(routed_expert_num), hidden_size(hidden_size), intermediate_size(intermediate_size), stride(stride), group_min_len(group_min_len), group_max_len(group_max_len), gate_proj(gate_proj), up_proj(up_proj), down_proj(down_proj), gate_type(gate_type), up_type(up_type), down_type(down_type), hidden_type(hidden_type) {}\n};\n\nclass MOE {\n   public:\n    MOE(MOEConfig);\n    ~MOE();\n    void warm_up(Backend* backend);\n    void forward_one(int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend);\n    void forward_many(int qlen, int k, const 
uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend);\n    void forward(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, int* batch_size_tensor, Backend* backend);\n\n   private:\n    MOEConfig config_;\n    void* gate_proj_;  // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]\n    void* up_proj_;    // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]\n    void* down_proj_;  // [expert_num * hidden_size * intermediate_size ( /32 if quantized)]\n\n    #ifdef USE_NUMA\n    std::vector<void*> gate_proj_numa_;  // [numa_num, expert_num * intermediate_size * hidden_size ( /32 if quantized)]\n    std::vector<void*> up_proj_numa_;    // [numa_num, expert_num * intermediate_size * hidden_size ( /32 if quantized)]\n    std::vector<void*> down_proj_numa_;  // [numa_num, expert_num * hidden_size * intermediate_size ( /32 if quantized)]\n    #endif\n\n    float* s_input_fp32_;                      // [hidden_size]\n    uint8_t* s_gate_input_;                    // [hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)]\n    uint8_t* s_up_input_;                      // [hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)]\n    std::vector<float*> s_gate_output_;        // [routed_expert_num, intermediate_size]\n    std::vector<float*> s_up_output_;          // [routed_expert_num, intermediate_size]\n    std::vector<float*> s_intermediate_fp32_;  // [routed_expert_num, intermediate_size]\n    std::vector<uint8_t*> s_down_input_;       // [routed_expert_num, intermediate_size * ggml_type_size(ggml_internal_get_type_traits(down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(down_type).vec_dot_type)]\n    std::vector<float*> 
s_down_output_;        // [routed_expert_num, hidden_size]\n    float* s_output_fp32_;                     // [hidden_size]\n\n    std::vector<float*> m_input_fp32_;    // [group_max_len, hidden_size]\n    std::vector<uint8_t*> m_gate_input_;  // [group_max_len, hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)]\n    std::vector<uint8_t*> m_up_input_;    // [group_max_len, hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)]\n    uint8_t* m_local_gate_input_;         // [routed_expert_num * group_max_len * hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)]\n    uint8_t* m_local_up_input_;           // [routed_expert_num * group_max_len * hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)]\n    float* m_local_gate_output_;          // [routed_expert_num * group_max_len * intermediate_size]\n    float* m_local_up_output_;            // [routed_expert_num * group_max_len * intermediate_size]\n    float* m_local_intermediate_fp32_;    // [routed_expert_num * group_max_len * intermediate_size]\n    uint8_t* m_local_down_input_;         // [routed_expert_num * group_max_len * intermediate_size * ggml_type_size(ggml_internal_get_type_traits(down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(down_type).vec_dot_type)]\n    float* m_local_down_output_;          // [routed_expert_num * group_max_len * hidden_size]\n    std::vector<float*> m_output_fp32_;   // [group_max_len, hidden_size]\n\n    std::vector<std::vector<int>> m_local_pos_;          // [group_max_len, routed_expert_num]\n    std::vector<int> m_local_num_;                       // 
[expert_num]\n    std::vector<uint8_t*> m_local_gate_input_ptr_;       // [expert_num]\n    std::vector<uint8_t*> m_local_up_input_ptr_;         // [expert_num]\n    std::vector<float*> m_local_gate_output_ptr_;        // [expert_num]\n    std::vector<float*> m_local_up_output_ptr_;          // [expert_num]\n    std::vector<float*> m_local_intermediate_fp32_ptr_;  // [expert_num]\n    std::vector<uint8_t*> m_local_down_input_ptr_;       // [expert_num]\n    std::vector<float*> m_local_down_output_ptr_;        // [expert_num]\n};\n\n#endif"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/operators/llamafile/sft_moe.cpp",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-22 02:03:22\n * @Version      : 1.0.0\n * @LastEditors  : kkk1nak0\n * @LastEditTime : 2024-08-15 07:43:41\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#include \"sft_moe.h\"\n#include <iostream>\n#include <cstdint>\n#include <cstring>\n#include <time.h>\n\n#ifdef USE_NUMA\n#include <numa.h>\n#include <numaif.h>\n#endif\n\nSFT_MOE::SFT_MOE(SFT_MOEConfig config) {\n    config_ = config;\n    gate_proj_ = config_.gate_proj;\n    up_proj_ = config_.up_proj;\n    down_proj_ = config_.down_proj;\n    \n    #ifdef USE_NUMA\n    int numa_nodes = numa_num_configured_nodes();\n    gate_proj_numa_.resize(numa_nodes);\n    up_proj_numa_.resize(numa_nodes);\n    down_proj_numa_.resize(numa_nodes);\n    size_t exp_inter_hidden_mul_ = (size_t)config.expert_num * config.intermediate_size * config.hidden_size;\n    for (int i = 0; i < numa_nodes; i++) {\n        gate_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_* ggml_type_size(config.gate_type) / ggml_blck_size(config.gate_type), i);\n        up_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_* ggml_type_size(config.up_type) / ggml_blck_size(config.up_type), i);\n        down_proj_numa_[i] = numa_alloc_onnode(exp_inter_hidden_mul_* ggml_type_size(config.down_type) / ggml_blck_size(config.down_type), i);\n        if (!gate_proj_numa_[i]) {\n            std::cout << \"Memory allocation failed for gate_proj_numa_ on node \" << i << std::endl;\n        }\n        if (!up_proj_numa_[i]) {\n            std::cout << \"Memory allocation failed for up_proj_numa_ on node \" << i << std::endl;\n        }\n        if (!down_proj_numa_[i]) {\n            std::cout << \"Memory allocation failed for down_proj_numa_ on node \" << i << std::endl;\n        }\n        memcpy(gate_proj_numa_[i], gate_proj_, exp_inter_hidden_mul_* ggml_type_size(config.gate_type) / ggml_blck_size(config.gate_type));\n        
memcpy(up_proj_numa_[i], up_proj_, exp_inter_hidden_mul_* ggml_type_size(config.up_type) / ggml_blck_size(config.up_type));\n        memcpy(down_proj_numa_[i], down_proj_, exp_inter_hidden_mul_* ggml_type_size(config.down_type) / ggml_blck_size(config.down_type));\n    }\n    #endif\n\n    std::vector<std::pair<void**, uint64_t>> s_mem_requests;\n    s_mem_requests.push_back({(void**)&gate_proj_t_, config_.expert_num * config_.hidden_size * config_.intermediate_size * ggml_type_size(config_.grad_type)});\n    s_mem_requests.push_back({(void**)&up_proj_t_, config_.expert_num * config_.hidden_size * config_.intermediate_size * ggml_type_size(config_.grad_type)});\n    s_mem_requests.push_back({(void**)&down_proj_t_, config_.expert_num * config_.hidden_size * config_.intermediate_size * ggml_type_size(config_.grad_type)});\n    s_mem_requests.push_back({(void**)&transpose_buffer_fp32_, config_.expert_num * config_.intermediate_size * config_.hidden_size * sizeof(float)});\n    s_mem_requests.push_back({(void**)&transpose_buffer_, config_.expert_num * config_.intermediate_size * config_.hidden_size * ggml_type_size(config_.grad_type)});\n\n    s_mem_requests.push_back({(void**)&s_input_fp32_, sizeof(float) * config_.hidden_size});\n    s_mem_requests.push_back({(void**)&s_gate_input_, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type)});\n    s_mem_requests.push_back({(void**)&s_up_input_, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type)});\n    s_gate_output_.resize(config_.routed_expert_num);\n    s_up_output_.resize(config_.routed_expert_num);\n    s_intermediate_fp32_.resize(config_.routed_expert_num);\n    s_down_input_.resize(config_.routed_expert_num);\n    s_down_output_.resize(config_.routed_expert_num);\n    
for (int i = 0; i < config_.routed_expert_num; i++) {\n        s_mem_requests.push_back({(void**)&s_gate_output_[i], sizeof(float) * config_.intermediate_size});\n        s_mem_requests.push_back({(void**)&s_up_output_[i], sizeof(float) * config_.intermediate_size});\n        s_mem_requests.push_back({(void**)&s_intermediate_fp32_[i], sizeof(float) * config_.intermediate_size});\n        s_mem_requests.push_back({(void**)&s_down_input_[i], config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type)});\n        s_mem_requests.push_back({(void**)&s_down_output_[i], sizeof(float) * config_.hidden_size});\n    }\n    s_mem_requests.push_back({(void**)&s_output_fp32_, sizeof(float) * config_.hidden_size});\n        \n    s_down_input_grad_.resize(config_.routed_expert_num);\n    s_gate_output_grad_fp32_.resize(config_.routed_expert_num);\n    s_up_output_grad_fp32_.resize(config_.routed_expert_num);\n    s_gate_output_grad_.resize(config_.routed_expert_num);\n    s_up_output_grad_.resize(config_.routed_expert_num);\n    s_gate_input_grad_.resize(config_.routed_expert_num);\n    s_up_input_grad_.resize(config_.routed_expert_num);\n    for (int i = 0; i < config_.routed_expert_num; i++) {\n        s_mem_requests.push_back({(void**)&s_down_input_grad_[i], config_.intermediate_size * sizeof(float)});\n        s_mem_requests.push_back({(void**)&s_gate_output_grad_fp32_[i], config_.intermediate_size * sizeof(float)});\n        s_mem_requests.push_back({(void**)&s_up_output_grad_fp32_[i], config_.intermediate_size * sizeof(float)});\n        s_mem_requests.push_back({(void**)&s_gate_output_grad_[i], config_.intermediate_size * ggml_type_size(config_.grad_type)});\n        s_mem_requests.push_back({(void**)&s_up_output_grad_[i], config_.intermediate_size * ggml_type_size(config_.grad_type)});\n        
s_mem_requests.push_back({(void**)&s_gate_input_grad_[i], config_.hidden_size * sizeof(float)});\n        s_mem_requests.push_back({(void**)&s_up_input_grad_[i], config_.hidden_size * sizeof(float)});\n    }\n    s_mem_requests.push_back({(void**)&s_input_grad_fp32_, config_.hidden_size * sizeof(float)});\n\n    shared_mem_buffer.alloc(this, s_mem_requests);\n\n    std::vector<std::pair<void**, uint64_t>> m_mem_requests;\n    m_input_fp32_.resize(config_.group_max_len);\n    m_gate_input_.resize(config_.group_max_len);\n    m_up_input_.resize(config_.group_max_len);\n    for (int i = 0; i < config_.group_max_len; i++) {\n        m_mem_requests.push_back({(void**)&m_input_fp32_[i], sizeof(float) * config_.hidden_size});\n        m_mem_requests.push_back({(void**)&m_gate_input_[i], config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type)});\n        m_mem_requests.push_back({(void**)&m_up_input_[i], config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type)});\n    }\n    m_mem_requests.push_back({(void**)&m_local_gate_input_, config_.routed_expert_num * config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type)});\n    m_mem_requests.push_back({(void**)&m_local_up_input_, config_.routed_expert_num * config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type)});\n    m_mem_requests.push_back({(void**)&m_local_gate_output_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size});\n    
m_mem_requests.push_back({(void**)&m_local_up_output_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size});\n    m_mem_requests.push_back({(void**)&m_local_intermediate_fp32_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size});\n    m_mem_requests.push_back({(void**)&m_local_down_input_, config_.routed_expert_num * config_.group_max_len * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type)});\n    m_mem_requests.push_back({(void**)&m_local_down_output_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.hidden_size});\n    m_output_fp32_.resize(config_.group_max_len);\n    for (int i = 0; i < config_.group_max_len; i++) {\n        m_mem_requests.push_back({(void**)&m_output_fp32_[i], sizeof(float) * config_.hidden_size});\n    }\n    \n    m_mem_requests.push_back({(void**)&m_local_down_output_grad_, config_.routed_expert_num * config_.group_max_len * config_.hidden_size * ggml_type_size(config_.grad_type)});\n    m_mem_requests.push_back({(void**)&m_local_down_input_grad_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size});\n    m_mem_requests.push_back({(void**)&m_local_gate_output_grad_fp32_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size});\n    m_mem_requests.push_back({(void**)&m_local_up_output_grad_fp32_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size});\n    m_mem_requests.push_back({(void**)&m_local_gate_output_grad_, config_.routed_expert_num * config_.group_max_len * config_.intermediate_size * ggml_type_size(config_.grad_type)});\n    m_mem_requests.push_back({(void**)&m_local_up_output_grad_, config_.routed_expert_num * config_.group_max_len * config_.intermediate_size * 
ggml_type_size(config_.grad_type)});\n    m_mem_requests.push_back({(void**)&m_local_gate_input_grad_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.hidden_size});\n    m_mem_requests.push_back({(void**)&m_local_up_input_grad_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.hidden_size});\n    m_mem_requests.push_back({(void**)&m_local_token_indices_, sizeof(int) * config_.routed_expert_num * config_.group_max_len});\n    m_mem_requests.push_back({(void**)&m_local_expert_positions_, sizeof(int) * config_.routed_expert_num * config_.group_max_len});\n    m_grad_input_fp32_.resize(config_.group_max_len);\n    for (int i = 0; i < config_.group_max_len; i++) {\n        m_mem_requests.push_back({(void**)&m_grad_input_fp32_[i], sizeof(float) * config_.hidden_size});\n    }\n    \n    shared_mem_buffer.alloc(this, m_mem_requests);\n\n    m_local_pos_.resize(config_.group_max_len);\n    for (int i = 0; i < config_.group_max_len; i++) {\n        m_local_pos_[i].resize(config_.routed_expert_num);\n    }\n    m_local_num_.resize(config_.expert_num);\n    m_local_gate_input_ptr_.resize(config_.expert_num);\n    m_local_up_input_ptr_.resize(config_.expert_num);\n    m_local_gate_output_ptr_.resize(config_.expert_num);\n    m_local_up_output_ptr_.resize(config_.expert_num);\n    m_local_intermediate_fp32_ptr_.resize(config_.expert_num);\n    m_local_down_input_ptr_.resize(config_.expert_num);\n    m_local_down_output_ptr_.resize(config_.expert_num);\n    \n    // Initialize the pointer arrays used only by backward_many\n    m_local_down_output_grad_ptr_.resize(config_.expert_num);\n    m_local_down_input_grad_ptr_.resize(config_.expert_num);\n    m_local_gate_output_grad_fp32_ptr_.resize(config_.expert_num);\n    m_local_up_output_grad_fp32_ptr_.resize(config_.expert_num);\n    m_local_gate_output_grad_ptr_.resize(config_.expert_num);\n    m_local_up_output_grad_ptr_.resize(config_.expert_num);\n    
m_local_gate_input_grad_ptr_.resize(config_.expert_num);\n    m_local_up_input_grad_ptr_.resize(config_.expert_num);\n    \n    // Initialize the pointer arrays that map accesses into fwd_cache\n    m_local_token_indices_ptr_.resize(config_.expert_num);\n    m_local_expert_positions_ptr_.resize(config_.expert_num);\n}\n\nSFT_MOE::~SFT_MOE() {\n    shared_mem_buffer.dealloc(this);\n\n    #ifdef USE_NUMA\n    int numa_nodes = numa_num_configured_nodes();\n    for (int i = 0; i < numa_nodes; i++) {\n        numa_free(gate_proj_numa_[i], config_.expert_num * config_.intermediate_size * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type));\n        numa_free(up_proj_numa_[i], config_.expert_num * config_.intermediate_size * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type));\n        numa_free(down_proj_numa_[i], config_.expert_num * config_.hidden_size * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type));\n    }\n    #endif\n}\n\nvoid SFT_MOE::warm_up(Backend* backend) {\n    std::vector<float> input_fp32(config_.hidden_size);\n    std::vector<uint8_t> input(config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type));\n    std::vector<uint8_t> output(config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type));\n    for (int i = 0; i < config_.hidden_size; i++) {\n        input_fp32[i] = 0;\n    }\n    from_float(input_fp32.data(), input.data(), config_.hidden_size, config_.hidden_type);\n\t/* ---------- Placeholder ForwardCache ---------- */\n    SFT_MoEForwardCache dummy_cache; // contents are unused; exists only to satisfy the interface\n\tdummy_cache.init(/*k=*/1, config_.intermediate_size);\n    for (int i = 0; i < config_.expert_num; i++) {\n        uint64_t expert_ids = i;\n        float weights = 0;\n        forward_one(1, &expert_ids, &weights, input.data(), output.data(), backend, &dummy_cache);\n    }\n}\n\nstatic float act_fn(float x) {\n    return x / 
(1.0f + expf(-x));\n}\n\nvoid SFT_MOE::ensure_fwd_cache(int qlen, int k)\n{\n\t// if ((int)fw_cache_.size() < qlen)\n\t// \tfw_cache_.resize(qlen);\n\t// /* init only the newly added part to avoid repeated allocation */\n\t// for (int i = 0; i < qlen; ++i)\n\t// \tfw_cache_[i].init(k, config_.intermediate_size);\n\n\tint old_sz = fw_cache_.size();\n    if (old_sz < qlen)\n    {\n        fw_cache_.resize(qlen);\n        for (int i = old_sz; i < qlen; ++i)  // initialize only the newly added elements\n            fw_cache_[i].init(k, config_.intermediate_size);\n    }\n\n\t\n    // if ((int)fw_cache_.size() < qlen)\n    //     fw_cache_.resize(qlen);\n\n    // for (int i = 0; i < qlen; ++i)                          // init on every call\n    //     fw_cache_[i].init(k, config_.intermediate_size);    // but without re-allocating\n\n}\n\nSFT_MoEForwardCache* SFT_MOE::fwd_cache_ptr()\n{\n\treturn fw_cache_.empty() ? nullptr : fw_cache_.data();\n}\n\nvoid SFT_MOE::forward_one(int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend, SFT_MoEForwardCache* fwd_cache) {\n    const void* gate_input_ptr;\n    const void* up_input_ptr;\n    if (config_.hidden_type == ggml_internal_get_type_traits(config_.gate_type).vec_dot_type && config_.hidden_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {\n        gate_input_ptr = up_input_ptr = input;\n    } else {\n        to_float(input, s_input_fp32_, config_.hidden_size, config_.hidden_type);\n        if (ggml_internal_get_type_traits(config_.gate_type).vec_dot_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {\n            from_float(s_input_fp32_, s_gate_input_, config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);\n            gate_input_ptr = up_input_ptr = s_gate_input_;\n        } else {\n            if (config_.hidden_type != ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) {\n                from_float(s_input_fp32_, s_gate_input_, config_.hidden_size, 
ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);\n                gate_input_ptr = s_gate_input_;\n            } else {\n                gate_input_ptr = input;\n            }\n            if (config_.hidden_type != ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {\n                from_float(s_input_fp32_, s_up_input_, config_.hidden_size, ggml_internal_get_type_traits(config_.up_type).vec_dot_type);\n                up_input_ptr = s_up_input_;\n            } else {\n                up_input_ptr = input;\n            }\n        }\n    }\n    int nth = config_.intermediate_size / config_.stride;\n    backend->do_work_stealing_job(nth * k, nullptr, [&](int task_id) {\n        int expert_idx = task_id / nth;\n        uint64_t expert_id = expert_ids[expert_idx];\n        int ith = task_id % nth;\n        \n        #ifdef USE_NUMA\n        void* gate_proj_ptr = (uint8_t*)gate_proj_numa_[Backend::numa_node] + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);\n        #else\n        void* gate_proj_ptr = (uint8_t*)gate_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);\n        #endif\n\n        float* gate_output_ptr = s_gate_output_[expert_idx] + ith * config_.stride;\n        llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n\n        #ifdef USE_NUMA\n        void* up_proj_ptr = (uint8_t*)up_proj_numa_[Backend::numa_node] + (expert_id * 
config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);\n        #else\n        void* up_proj_ptr = (uint8_t*)up_proj_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);\n        #endif\n\n        float* up_output_ptr = s_up_output_[expert_idx] + ith * config_.stride;\n        llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n        for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {\n            s_intermediate_fp32_[expert_idx][i] = act_fn(s_gate_output_[expert_idx][i]) * s_up_output_[expert_idx][i];\n        }\n        if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) == 0) {\n            float* intermediate_fp32_ptr = s_intermediate_fp32_[expert_idx] + ith * config_.stride;\n            void* down_input_ptr = s_down_input_[expert_idx] + ith * config_.stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);\n            from_float(intermediate_fp32_ptr, down_input_ptr, config_.stride, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);\n        }\n    }, nullptr);\n    if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) != 0) {\n        for (int i = 0; i < k; i++) {\n            from_float(s_intermediate_fp32_[i], s_down_input_[i], config_.intermediate_size, 
ggml_internal_get_type_traits(config_.down_type).vec_dot_type);\n        }\n    }\n    nth = config_.hidden_size / config_.stride;\n    backend->do_work_stealing_job(nth, nullptr, [&](int task_id) {\n        int ith = task_id;\n        for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {\n            s_output_fp32_[i] = 0;\n        }\n        for (int expert_idx = 0; expert_idx < k; expert_idx++) {\n            uint64_t expert_id = expert_ids[expert_idx];\n\n            #ifdef USE_NUMA\n            void* down_proj_ptr = (uint8_t*)down_proj_numa_[Backend::numa_node] + (expert_id * config_.hidden_size + ith * config_.stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);\n            #else\n            void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_id * config_.hidden_size + ith * config_.stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);\n            #endif\n            \n            float* down_output_ptr = s_down_output_[expert_idx] + ith * config_.stride;\n            llamafile_sgemm(config_.stride, 1, config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), s_down_input_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n            for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {\n                s_output_fp32_[i] += s_down_output_[expert_idx][i] * weights[expert_idx];\n            }\n        }\n        if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) {\n            float* output_fp32_ptr = s_output_fp32_ + ith * config_.stride;\n            void* output_ptr = (uint8_t*)output + ith * config_.stride * 
ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);\n            from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type);\n        }\n    }, nullptr);\n\n\tfor (int e = 0; e < k; ++e) {\n        // gate_output_: float[inter_size] per expert\n        std::memcpy(fwd_cache->gate_u[e].data(),\n                    s_gate_output_[e],\n                    sizeof(float) * config_.intermediate_size);\n\n        std::memcpy(fwd_cache->up_v[e].data(),\n                    s_up_output_[e],\n                    sizeof(float) * config_.intermediate_size);\n\n        // optionally save z\n        // std::memcpy(fwd_cache->z[e].data(),\n        //             s_intermediate_fp32_[e],\n        //             sizeof(float) * config_.intermediate_size);\n    }\n}\n\nvoid SFT_MOE::forward_many(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend, SFT_MoEForwardCache* fwd_cache) {\n    for (int i = 0; i < config_.expert_num; i++) {\n        m_local_num_[i] = 0;\n    }\n    for (int i = 0; i < qlen; i++) {\n        for (int j = 0; j < k; j++) {\n            m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;\n        }\n    }\n    uint64_t offset = 0;\n    for (int i = 0; i < config_.expert_num; i++) {\n        m_local_gate_input_ptr_[i] = m_local_gate_input_ + offset * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);\n        m_local_up_input_ptr_[i] = m_local_up_input_ + offset * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type);\n        m_local_gate_output_ptr_[i] = m_local_gate_output_ + offset * config_.intermediate_size;\n        m_local_up_output_ptr_[i] = m_local_up_output_ + offset * config_.intermediate_size;\n        
m_local_intermediate_fp32_ptr_[i] = m_local_intermediate_fp32_ + offset * config_.intermediate_size;\n        m_local_down_input_ptr_[i] = m_local_down_input_ + offset * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);\n        m_local_down_output_ptr_[i] = m_local_down_output_ + offset * config_.hidden_size;\n        offset += m_local_num_[i];\n    }\n    backend->do_work_stealing_job(qlen, nullptr, [&](int i) {\n        const void* gate_input_ptr;\n        const void* up_input_ptr;\n        if (config_.hidden_type == ggml_internal_get_type_traits(config_.gate_type).vec_dot_type && config_.hidden_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {\n            gate_input_ptr = up_input_ptr = (uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);\n        } else {\n            to_float((uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), m_input_fp32_[i], config_.hidden_size, config_.hidden_type);\n            if (ggml_internal_get_type_traits(config_.gate_type).vec_dot_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {\n                from_float(m_input_fp32_[i], m_gate_input_[i], config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);\n                gate_input_ptr = up_input_ptr = m_gate_input_[i];\n            } else {\n                if (config_.hidden_type != ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) {\n                    from_float(m_input_fp32_[i], m_gate_input_[i], config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);\n                    gate_input_ptr = m_gate_input_[i];\n                } else {\n                    gate_input_ptr = (uint8_t*)input + i * 
config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);\n                }\n                if (config_.hidden_type != ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {\n                    from_float(m_input_fp32_[i], m_up_input_[i], config_.hidden_size, ggml_internal_get_type_traits(config_.up_type).vec_dot_type);\n                    up_input_ptr = m_up_input_[i];\n                } else {\n                    up_input_ptr = (uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);\n                }\n            }\n        }\n        for (int j = 0; j < k; j++) {\n            memcpy(m_local_gate_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type), gate_input_ptr, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type));\n            memcpy(m_local_up_input_ptr_[expert_ids[i * k + j]] + m_local_pos_[i][j] * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type), up_input_ptr, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type));\n        }\n    }, nullptr);\n    int stride = QK_K;\n    int nth = config_.intermediate_size / stride;\n    backend->do_work_stealing_job(nth * config_.expert_num, nullptr, [&](int task_id) {\n        uint64_t expert_idx = task_id / nth;\n        int ith = task_id % nth;\n        void* gate_input_ptr = m_local_gate_input_ptr_[expert_idx];\n\n        #ifdef USE_NUMA\n        void* 
gate_proj_ptr = (uint8_t*)gate_proj_numa_[Backend::numa_node] + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);\n        #else\n        void* gate_proj_ptr = (uint8_t*)gate_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);\n        #endif\n\n        float* gate_output_ptr = m_local_gate_output_ptr_[expert_idx] + ith * stride;\n        llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n        void* up_input_ptr = m_local_up_input_ptr_[expert_idx];\n\n        #ifdef USE_NUMA\n        void* up_proj_ptr = (uint8_t*)up_proj_numa_[Backend::numa_node] + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);\n        #else\n        void* up_proj_ptr = (uint8_t*)up_proj_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);\n        #endif\n\n        float* up_output_ptr = m_local_up_output_ptr_[expert_idx] + ith * stride;\n        llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, 
ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n        for (int i = 0; i < m_local_num_[expert_idx]; i++) {\n            for (int j = ith * stride; j < (ith + 1) * stride; j++) {\n                m_local_intermediate_fp32_ptr_[expert_idx][i * config_.intermediate_size + j] = act_fn(m_local_gate_output_ptr_[expert_idx][i * config_.intermediate_size + j]) * m_local_up_output_ptr_[expert_idx][i * config_.intermediate_size + j];\n            }\n            float* intermediate_fp32_ptr = m_local_intermediate_fp32_ptr_[expert_idx] + i * config_.intermediate_size + ith * stride;\n            void* down_input_ptr = m_local_down_input_ptr_[expert_idx] + i * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) + ith * stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);\n            from_float(intermediate_fp32_ptr, down_input_ptr, stride, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);\n        }\n    }, nullptr);\n    stride = QK_K;\n    nth = config_.hidden_size / stride;\n    backend->do_work_stealing_job(nth * config_.expert_num, nullptr, [&](int task_id) {\n        uint64_t expert_idx = task_id / nth;\n        int ith = task_id % nth;\n        void* down_input_ptr = m_local_down_input_ptr_[expert_idx];\n        \n        #ifdef USE_NUMA\n        void* down_proj_ptr = (uint8_t*)down_proj_numa_[Backend::numa_node] + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);\n        #else\n        void* down_proj_ptr = (uint8_t*)down_proj_ + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.down_type) / 
ggml_blck_size(config_.down_type);\n        #endif\n\n        float* down_output_ptr = m_local_down_output_ptr_[expert_idx] + ith * stride;\n        llamafile_sgemm(stride, m_local_num_[expert_idx], config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n    }, nullptr);\n    backend->do_work_stealing_job(qlen, nullptr, [&](int i) {\n        for (int e = 0; e < config_.hidden_size; e++) {\n            m_output_fp32_[i][e] = 0;\n        }\n        for (int j = 0; j < k; j++) {\n            for (int e = 0; e < config_.hidden_size; e++) {\n                m_output_fp32_[i][e] += m_local_down_output_ptr_[expert_ids[i * k + j]][m_local_pos_[i][j] * config_.hidden_size + e] * weights[i * k + j];\n            }\n        }\n        from_float(m_output_fp32_[i], (uint8_t*)output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), config_.hidden_size, config_.hidden_type);\n    }, nullptr);\n\n\t/* Copy each token-expert row into its own cache */\n    backend->do_work_stealing_job(qlen, nullptr, [&](int token_idx) {\n        auto& cache = fwd_cache[token_idx];\n        // cache has already been initialized by the caller via init(k, inter_size)\n        for (int j = 0; j < k; ++j) {\n            uint64_t  eid   = expert_ids[token_idx*k + j];\n            int       row   = m_local_pos_[token_idx][j];\n            size_t    ofs   = row * config_.intermediate_size;\n            /* gate u */\n            std::memcpy(cache.gate_u[j].data(),\n                        m_local_gate_output_ptr_[eid] + ofs,\n                        sizeof(float) * config_.intermediate_size);\n            /* up v */\n            std::memcpy(cache.up_v[j].data(),\n    
                    m_local_up_output_ptr_[eid] + ofs,\n                        sizeof(float) * config_.intermediate_size);\n            /* optional z */\n            // std::memcpy(cache.z[j].data(),\n            //             m_local_intermediate_fp32_ptr_[eid] + ofs,\n            //             sizeof(float) * config_.intermediate_size);\n        }\n    }, nullptr);\n}\n\nvoid SFT_MOE::forward(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend, SFT_MoEForwardCache* fwd_cache) {\n    if (qlen < config_.group_min_len) {\n        for (int i = 0; i < qlen; i++) {\n\t\t\t// fwd_cache[i].init(k, config_.intermediate_size);      // pre-allocate\n            forward_one(k, expert_ids + i * k, weights + i * k, (uint8_t*)input + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend, fwd_cache + i);\n        }\n        return;\n    }\n    int forward_len = std::min(config_.group_max_len, qlen);\n    // for (int i = 0; i < forward_len; ++i)\n    //     fwd_cache[i].init(k, config_.intermediate_size);\n    forward_many(forward_len, k, expert_ids, weights, input, output, backend, fwd_cache);\n    forward(qlen - forward_len, k, expert_ids + forward_len * k, weights + forward_len * k, (uint8_t*)input + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend, fwd_cache + forward_len);\n}\n\nstatic float act_fn_grad(float x) {\n    float sigmoid_x = 1.0f / (1.0f + expf(-x));\n    return sigmoid_x * (1. + x * (1. 
- sigmoid_x));\n}\n\nvoid SFT_MOE::transpose_expert_matrix(const void* src, void* dst, int R, int C, ggml_type src_type, ggml_type dst_type, uint64_t expert_idx) {\n    to_float(src, transpose_buffer_fp32_ + (R * C * expert_idx), R * C, src_type);\n    from_float(transpose_buffer_fp32_ + (R * C * expert_idx), transpose_buffer_ + (R * C * expert_idx) * ggml_type_size(dst_type), R * C, dst_type);\n    for (int r = 0; r < R; ++r) {\n        for (int c = 0; c < C; ++c) {\n            memcpy(\n                (uint8_t*)dst + (c * R + r) * ggml_type_size(dst_type),\n                (uint8_t*)transpose_buffer_ + (R * C * expert_idx + r * C + c) * ggml_type_size(dst_type),\n                ggml_type_size(dst_type));\n        }\n    }\n}\n\nvoid SFT_MOE::get_transpose(Backend* backend) {\n    // Transpose gate_proj_\n    int R_gate = config_.intermediate_size;\n    int C_gate = config_.hidden_size;\n    size_t gate_expert_src_stride_bytes = (size_t)R_gate * C_gate * ggml_type_size(config_.gate_type);\n    size_t gate_expert_dst_t_stride_bytes = (size_t)C_gate * R_gate * ggml_type_size(config_.grad_type);\n    backend->do_work_stealing_job(config_.expert_num, nullptr, [&](int expert_idx) {\n        void* src_expert = (uint8_t*)gate_proj_ + expert_idx * gate_expert_src_stride_bytes;\n        void* dst_expert_t = (uint8_t*)gate_proj_t_ + expert_idx * gate_expert_dst_t_stride_bytes;\n        transpose_expert_matrix(src_expert, dst_expert_t, R_gate, C_gate, config_.gate_type, config_.grad_type, expert_idx);\n    }, nullptr);\n\n    // Transpose up_proj_\n    int R_up = config_.intermediate_size;\n    int C_up = config_.hidden_size;\n    size_t up_expert_src_stride_bytes = (size_t)R_up * C_up * ggml_type_size(config_.up_type);\n    size_t up_expert_dst_t_stride_bytes = (size_t)C_up * R_up * ggml_type_size(config_.grad_type);\n    backend->do_work_stealing_job(config_.expert_num, nullptr, [&](int expert_idx) {\n        void* src_expert = (uint8_t*)up_proj_ + expert_idx * 
up_expert_src_stride_bytes;\n        void* dst_expert_t = (uint8_t*)up_proj_t_ + expert_idx * up_expert_dst_t_stride_bytes;\n        transpose_expert_matrix(src_expert, dst_expert_t, R_up, C_up, config_.up_type, config_.grad_type, expert_idx);\n    }, nullptr);\n\n    // Transpose down_proj_\n    int R_down = config_.hidden_size;\n    int C_down = config_.intermediate_size;\n    size_t down_expert_src_stride_bytes = (size_t)R_down * C_down * ggml_type_size(config_.down_type);\n    size_t down_expert_dst_t_stride_bytes = (size_t)C_down * R_down * ggml_type_size(config_.grad_type);\n    backend->do_work_stealing_job(config_.expert_num, nullptr, [&](int expert_idx) {\n        void* src_expert = (uint8_t*)down_proj_ + expert_idx * down_expert_src_stride_bytes;\n        void* dst_expert_t = (uint8_t*)down_proj_t_ + expert_idx * down_expert_dst_t_stride_bytes;\n        transpose_expert_matrix(src_expert, dst_expert_t, R_down, C_down, config_.down_type, config_.grad_type, expert_idx);\n    }, nullptr);\n}\n\nvoid SFT_MOE::backward_one(int k, const uint64_t* expert_ids, const float* weights, const void* output_grad, void* input_grad, Backend* backend, const SFT_MoEForwardCache* fwd_cache) {\n\t// clock_t clk1, clk2, clk3, clk4;\n\t// clock_t clkz1, clkz2, clkz3, clkz4, clkz5;\n\t// clk1 = clock();\n\t// clk2 = clock();\n    int nth = config_.intermediate_size / config_.stride;\n    backend->do_work_stealing_job(nth * k, nullptr, [&](int task_id) {\n        int expert_idx = task_id / nth;\n        uint64_t expert_id = expert_ids[expert_idx];\n        int ith = task_id % nth;\n\t\t// clkz1 = clock();\n        void* down_proj_t_ptr = (uint8_t*)down_proj_t_ + (expert_id * config_.intermediate_size + ith * config_.stride) * config_.hidden_size * ggml_type_size(config_.grad_type);\n        float* down_input_grad_ptr = s_down_input_grad_[expert_idx] + ith * config_.stride;\n        // clkz2 = clock();\n        llamafile_sgemm(config_.stride, 1, config_.hidden_size, 
down_proj_t_ptr, config_.hidden_size, output_grad, config_.hidden_size, down_input_grad_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.grad_type, config_.grad_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n        // clkz3 = clock();\n        for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {\n            s_down_input_grad_[expert_idx][i] *= weights[expert_idx];\n\n            s_gate_output_grad_fp32_[expert_idx][i] = s_down_input_grad_[expert_idx][i] * fwd_cache->up_v[expert_idx][i] * act_fn_grad(fwd_cache->gate_u[expert_idx][i]); \n            s_up_output_grad_fp32_[expert_idx][i] = s_down_input_grad_[expert_idx][i] * act_fn(fwd_cache->gate_u[expert_idx][i]);\n        }\n        // clkz4 = clock();\n        from_float(s_gate_output_grad_fp32_[expert_idx] + ith * config_.stride, s_gate_output_grad_[expert_idx] + ith * config_.stride * ggml_type_size(config_.grad_type), config_.stride, config_.grad_type);\n        from_float(s_up_output_grad_fp32_[expert_idx] + ith * config_.stride, s_up_output_grad_[expert_idx] + ith * config_.stride * ggml_type_size(config_.grad_type), config_.stride, config_.grad_type);\n        // clkz5 = clock();\n    }, nullptr);\n\n\t// clk3 = clock();\n    nth = config_.hidden_size / config_.stride;\n    backend->do_work_stealing_job(nth, nullptr, [&](int task_id) {\n        int ith = task_id;\n        for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {\n            s_input_grad_fp32_[i] = 0;\n        }\n        for (int expert_idx = 0; expert_idx < k; expert_idx++) {\n            uint64_t expert_id = expert_ids[expert_idx];\n\n            void* gate_proj_t_ptr = (uint8_t*)gate_proj_t_ + (expert_id * config_.hidden_size + ith * config_.stride) * config_.intermediate_size * ggml_type_size(config_.grad_type);\n            float* gate_input_grad_ptr = s_gate_input_grad_[expert_idx] + ith * config_.stride;\n            llamafile_sgemm(config_.stride, 1, config_.intermediate_size, 
gate_proj_t_ptr, config_.intermediate_size, s_gate_output_grad_[expert_idx], config_.intermediate_size, gate_input_grad_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.grad_type, config_.grad_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n\n            void* up_proj_t_ptr = (uint8_t*)up_proj_t_ + (expert_id * config_.hidden_size + ith * config_.stride) * config_.intermediate_size * ggml_type_size(config_.grad_type);\n            float* up_input_grad_ptr = s_up_input_grad_[expert_idx] + ith * config_.stride;\n            llamafile_sgemm(config_.stride, 1, config_.intermediate_size, up_proj_t_ptr, config_.intermediate_size, s_up_output_grad_[expert_idx], config_.intermediate_size, up_input_grad_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.grad_type, config_.grad_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n            \n            for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {\n                s_input_grad_fp32_[i] += s_gate_input_grad_[expert_idx][i] + s_up_input_grad_[expert_idx][i];\n            }\n        }\n        from_float(s_input_grad_fp32_ + ith * config_.stride, (uint8_t*)input_grad + ith * config_.stride * ggml_type_size(config_.grad_type), config_.stride, config_.grad_type);\n    }, nullptr);\n\t// clk4 = clock();\n\t// std::cout << \"[Δclk12] \" << (clk2 - clk1) / static_cast<double>(CLOCKS_PER_SEC) * 1000\n    //       << \" ms  [Δclk23] \" << (clk3 - clk2) / static_cast<double>(CLOCKS_PER_SEC) * 1000\n    //       << \" ms  [Δclk34] \" << (clk4 - clk3) / static_cast<double>(CLOCKS_PER_SEC) * 1000\n    //       << \" ms  [Δclkz12] \" << (clkz2 - clkz1) / static_cast<double>(CLOCKS_PER_SEC) * 1000\n    //       << \" ms  [Δclkz23] \" << (clkz3 - clkz2) / static_cast<double>(CLOCKS_PER_SEC) * 1000\n    //       << \" ms  [Δclkz34] \" << (clkz4 - clkz3) / static_cast<double>(CLOCKS_PER_SEC) * 1000\n    //       << \" ms  [Δclkz45] \" << (clkz5 - clkz4) / static_cast<double>(CLOCKS_PER_SEC) * 1000\n    //       
<< \" ms\\n\";\n\n}\n\nvoid SFT_MOE::backward_many(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* output_grad, void* input_grad, Backend* backend, const SFT_MoEForwardCache* fwd_cache) {\n    for (int i = 0; i < config_.expert_num; i++) {\n        m_local_num_[i] = 0;\n    }\n    for (int i = 0; i < qlen; i++) {\n        for (int j = 0; j < k; j++) {\n            m_local_pos_[i][j] = m_local_num_[expert_ids[i * k + j]]++;\n        }\n    }\n    uint64_t offset = 0;\n    for (int i = 0; i < config_.expert_num; i++) {\n        m_local_down_output_grad_ptr_[i] = m_local_down_output_grad_ + offset * config_.hidden_size * ggml_type_size(config_.grad_type);\n        m_local_down_input_grad_ptr_[i] = m_local_down_input_grad_ + offset * config_.intermediate_size;\n        m_local_gate_output_grad_fp32_ptr_[i] = m_local_gate_output_grad_fp32_ + offset * config_.intermediate_size;\n        m_local_up_output_grad_fp32_ptr_[i] = m_local_up_output_grad_fp32_ + offset * config_.intermediate_size;\n        m_local_gate_output_grad_ptr_[i] = m_local_gate_output_grad_ + offset * config_.intermediate_size * ggml_type_size(config_.grad_type);\n        m_local_up_output_grad_ptr_[i] = m_local_up_output_grad_ + offset * config_.intermediate_size * ggml_type_size(config_.grad_type);\n        m_local_gate_input_grad_ptr_[i] = m_local_gate_input_grad_ + offset * config_.hidden_size;\n        m_local_up_input_grad_ptr_[i] = m_local_up_input_grad_ + offset * config_.hidden_size;\n        m_local_token_indices_ptr_[i] = m_local_token_indices_ + offset;\n        m_local_expert_positions_ptr_[i] = m_local_expert_positions_ + offset;\n        offset += m_local_num_[i];\n    }\n\n    backend->do_work_stealing_job(qlen, nullptr, [&](int i) {\n        for (int j = 0; j < k; j++) {\n            uint64_t expert_id = expert_ids[i * k + j];\n            int local_row = m_local_pos_[i][j];\n            memcpy(m_local_down_output_grad_ptr_[expert_id] + local_row * 
config_.hidden_size * ggml_type_size(config_.grad_type), (uint8_t*)output_grad + i * config_.hidden_size * ggml_type_size(config_.grad_type), config_.hidden_size * ggml_type_size(config_.grad_type));\n            m_local_token_indices_ptr_[expert_id][local_row] = i;\n            m_local_expert_positions_ptr_[expert_id][local_row] = j;\n        }\n    }, nullptr);\n\n    // get_transpose(backend);\n\n    int stride = QK_K;\n    int nth = config_.intermediate_size / stride;\n    backend->do_work_stealing_job(nth * config_.expert_num, nullptr, [&](int task_id) {\n        uint64_t expert_idx = task_id / nth;\n        int ith = task_id % nth;\n        \n        void* down_proj_t_ptr = (uint8_t*)down_proj_t_ + (expert_idx * config_.intermediate_size + ith * stride) * config_.hidden_size * ggml_type_size(config_.grad_type);\n        void* down_output_grad_ptr = m_local_down_output_grad_ptr_[expert_idx];\n        float* down_input_grad_ptr = m_local_down_input_grad_ptr_[expert_idx] + ith * stride;\n                    \n        llamafile_sgemm(stride, m_local_num_[expert_idx], config_.hidden_size, down_proj_t_ptr, config_.hidden_size, down_output_grad_ptr, config_.hidden_size, down_input_grad_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.grad_type, config_.grad_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n        \n        for (int i = 0; i < m_local_num_[expert_idx]; i++) {\n            int token_idx = m_local_token_indices_ptr_[expert_idx][i];\n            int expert_pos = m_local_expert_positions_ptr_[expert_idx][i];\n            float weight = weights[token_idx * k + expert_pos];\n            \n            for (int j = ith * stride; j < (ith + 1) * stride; j++) {\n                m_local_down_input_grad_ptr_[expert_idx][i * config_.intermediate_size + j] *= weight;\n                \n                float down_input_grad = m_local_down_input_grad_ptr_[expert_idx][i * config_.intermediate_size + j];\n                
m_local_gate_output_grad_fp32_ptr_[expert_idx][i * config_.intermediate_size + j] = down_input_grad * fwd_cache[token_idx].up_v[expert_pos][j] * act_fn_grad(fwd_cache[token_idx].gate_u[expert_pos][j]);\n                m_local_up_output_grad_fp32_ptr_[expert_idx][i * config_.intermediate_size + j] = down_input_grad * act_fn(fwd_cache[token_idx].gate_u[expert_pos][j]);\n            }\n            \n            float* gate_output_grad_fp32_ptr = m_local_gate_output_grad_fp32_ptr_[expert_idx] + i * config_.intermediate_size + ith * stride;\n            void* gate_output_grad_ptr = m_local_gate_output_grad_ptr_[expert_idx] + (i * config_.intermediate_size + ith * stride) * ggml_type_size(config_.grad_type);\n            from_float(gate_output_grad_fp32_ptr, gate_output_grad_ptr, stride, config_.grad_type);\n            \n            float* up_output_grad_fp32_ptr = m_local_up_output_grad_fp32_ptr_[expert_idx] + i * config_.intermediate_size + ith * stride;\n            void* up_output_grad_ptr = m_local_up_output_grad_ptr_[expert_idx] + (i * config_.intermediate_size + ith * stride) * ggml_type_size(config_.grad_type);\n            from_float(up_output_grad_fp32_ptr, up_output_grad_ptr, stride, config_.grad_type);\n        }\n    }, nullptr);\n    stride = QK_K;\n    nth = config_.hidden_size / stride;\n    backend->do_work_stealing_job(nth * config_.expert_num, nullptr, [&](int task_id) {\n        uint64_t expert_idx = task_id / nth;\n        int ith = task_id % nth;\n        \n        void* gate_proj_t_ptr = (uint8_t*)gate_proj_t_ + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.grad_type);\n        void* up_proj_t_ptr = (uint8_t*)up_proj_t_ + (expert_idx * config_.hidden_size + ith * stride) * config_.intermediate_size * ggml_type_size(config_.grad_type);\n        void* gate_output_grad_ptr = m_local_gate_output_grad_ptr_[expert_idx];\n        void* up_output_grad_ptr = 
m_local_up_output_grad_ptr_[expert_idx];\n        float* gate_input_grad_ptr = m_local_gate_input_grad_ptr_[expert_idx] + ith * stride;\n        float* up_input_grad_ptr = m_local_up_input_grad_ptr_[expert_idx] + ith * stride;\n        \n        llamafile_sgemm(stride, m_local_num_[expert_idx], config_.intermediate_size, gate_proj_t_ptr, config_.intermediate_size, gate_output_grad_ptr, config_.intermediate_size, gate_input_grad_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.grad_type, config_.grad_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n        llamafile_sgemm(stride, m_local_num_[expert_idx], config_.intermediate_size, up_proj_t_ptr, config_.intermediate_size, up_output_grad_ptr, config_.intermediate_size, up_input_grad_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.grad_type, config_.grad_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);\n    }, nullptr);\n    backend->do_work_stealing_job(qlen, nullptr, [&](int i) {\n        for (int e = 0; e < config_.hidden_size; e++) {\n            m_grad_input_fp32_[i][e] = 0;\n        }\n        for (int j = 0; j < k; j++) {\n            for (int e = 0; e < config_.hidden_size; e++) {\n                m_grad_input_fp32_[i][e] += m_local_gate_input_grad_ptr_[expert_ids[i * k + j]][m_local_pos_[i][j] * config_.hidden_size + e] + m_local_up_input_grad_ptr_[expert_ids[i * k + j]][m_local_pos_[i][j] * config_.hidden_size + e];\n            }\n        }\n        from_float(m_grad_input_fp32_[i], (uint8_t*)input_grad + i * config_.hidden_size * ggml_type_size(config_.grad_type), config_.hidden_size, config_.grad_type);\n    }, nullptr);\n}\n\n// TODO: the input and layer_idx parameters are unused and can be removed\nvoid SFT_MOE::backward(int layer_idx, int qlen, int k, const uint64_t* expert_ids, const float* weights,\n                   const void* input, const void* grad_output, void* grad_input, Backend* backend, const SFT_MoEForwardCache* fwd_cache) {\n\n    get_transpose(backend);\n    int remaining_qlen = qlen;\n    int processed_offset = 
0;\n    \n    while (remaining_qlen > 0) {\n        // config_.group_min_len = 10000000;\n        if (remaining_qlen < config_.group_min_len) {\n            for (int i = 0; i < remaining_qlen; i++) {\n                backward_one(k,\n                             expert_ids + (processed_offset + i) * k,\n                             weights + (processed_offset + i) * k,\n                             (uint8_t*)grad_output + (processed_offset + i) * config_.hidden_size * ggml_type_size(config_.grad_type),\n                             (uint8_t*)grad_input + (processed_offset + i) * config_.hidden_size * ggml_type_size(config_.grad_type),\n                             backend,\n                             fwd_cache + processed_offset + i);\n            }\n            break;\n        } else {\n            int backward_len = std::min(config_.group_max_len, remaining_qlen);\n            backward_many(backward_len, \n                         k, \n                         expert_ids + processed_offset * k, \n                         weights + processed_offset * k, \n                         (uint8_t*)grad_output + processed_offset * config_.hidden_size * ggml_type_size(config_.grad_type), \n                         (uint8_t*)grad_input + processed_offset * config_.hidden_size * ggml_type_size(config_.grad_type), \n                         backend, \n                         fwd_cache + processed_offset);\n            \n            remaining_qlen -= backward_len;\n            processed_offset += backward_len;\n        }\n    }\n}"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/operators/llamafile/sft_moe.h",
    "content": "/**\n * @Description  :\n * @Author       : chenht2022\n * @Date         : 2024-07-22 02:03:22\n * @Version      : 1.0.0\n * @LastEditors  : chenht2022\n * @LastEditTime : 2024-07-25 10:35:10\n * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n **/\n#ifndef CPUINFER_OPERATOR_SFT_MOE_H\n#define CPUINFER_OPERATOR_SFT_MOE_H\n\n#include <cmath>\n#include <cstdio>\n#include <functional>\n#include <mutex>\n#include <vector>\n\n#include \"../../cpu_backend/backend.h\"\n#include \"../../cpu_backend/shared_mem_buffer.h\"\n#include \"conversion.h\"\n#include \"llama.cpp/ggml-impl.h\"\n#include \"llama.cpp/ggml-quants.h\"\n#include \"llama.cpp/ggml.h\"\n#include \"llamafile/sgemm.h\"\n#include \"sft_moe_forward_cache.h\"\n\nstruct SFT_MOEConfig {\n    long expert_num;\n    int routed_expert_num;\n    long hidden_size;\n    long intermediate_size;\n    int stride;\n    int group_min_len;\n    int group_max_len;\n    void* gate_proj;\n    void* up_proj;\n    void* down_proj;\n    ggml_type gate_type;\n    ggml_type up_type;\n    ggml_type down_type;\n    ggml_type hidden_type;\n    ggml_type grad_type = GGML_TYPE_BF16;\n\n    SFT_MOEConfig() {}\n\n    SFT_MOEConfig(int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int stride, int group_min_len, int group_max_len, void* gate_proj, void* up_proj, void* down_proj, ggml_type gate_type, ggml_type up_type, ggml_type down_type, ggml_type hidden_type)\n        : expert_num(expert_num), routed_expert_num(routed_expert_num), hidden_size(hidden_size), intermediate_size(intermediate_size), stride(stride), group_min_len(group_min_len), group_max_len(group_max_len), gate_proj(gate_proj), up_proj(up_proj), down_proj(down_proj), gate_type(gate_type), up_type(up_type), down_type(down_type), hidden_type(hidden_type) {}\n};\n\nclass SFT_MOE {\n   public:\n    SFT_MOE(SFT_MOEConfig);\n    ~SFT_MOE();\n    void warm_up(Backend* backend);\n    void forward_one(int k, const uint64_t* expert_ids, 
const float* weights, const void* input, void* output, Backend* backend, SFT_MoEForwardCache* fwd_cache);\n    void forward_many(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend, SFT_MoEForwardCache* fwd_cache);\n    void forward(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* input, void* output, Backend* backend, SFT_MoEForwardCache* fwd_cache);\n\tvoid backward_one(int k, const uint64_t* expert_ids, const float* weights, const void* output_grad, void* input_grad, Backend* backend, const SFT_MoEForwardCache* fwd_cache);\n\tvoid backward_many(int qlen, int k, const uint64_t* expert_ids, const float* weights, const void* output_grad, void* input_grad, Backend* backend, const SFT_MoEForwardCache* fwd_cache);\n\tvoid backward(int layer_idx, int qlen, int k, const uint64_t* expert_ids, const float* weights,\n              const void* input, const void* grad_output, void* grad_input, Backend* backend, const SFT_MoEForwardCache* fwd_cache); // FIXME: expert backward definition for C++\n    \n    void transpose_expert_matrix(const void* src, void* dst, int R, int C, ggml_type src_type, ggml_type dst_type, uint64_t expert_idx);\n    void ensure_fwd_cache(int qlen, int k);\n    void get_transpose(Backend* backend);\n    SFT_MoEForwardCache* fwd_cache_ptr();\n\n   private:\n    SFT_MOEConfig config_;\n    void* gate_proj_;  // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]\n    void* up_proj_;    // [expert_num * intermediate_size * hidden_size ( /32 if quantized)]\n    void* down_proj_;  // [expert_num * hidden_size * intermediate_size ( /32 if quantized)]\n\n    float* transpose_buffer_fp32_;  // [expert_num * intermediate_size * hidden_size]\n    uint8_t* transpose_buffer_;     // [expert_num * intermediate_size * hidden_size]\n\n    uint8_t* gate_proj_t_;  // [expert_num * hidden_size * intermediate_size]\n    uint8_t* up_proj_t_;    // [expert_num * 
hidden_size * intermediate_size]\n    uint8_t* down_proj_t_;  // [expert_num * intermediate_size * hidden_size]\n\n    #ifdef USE_NUMA\n    std::vector<void*> gate_proj_numa_;  // [numa_num, expert_num * intermediate_size * hidden_size ( /32 if quantized)]\n    std::vector<void*> up_proj_numa_;    // [numa_num, expert_num * intermediate_size * hidden_size ( /32 if quantized)]\n    std::vector<void*> down_proj_numa_;  // [numa_num, expert_num * hidden_size * intermediate_size ( /32 if quantized)]\n    #endif\n\n    float* s_input_fp32_;                      // [hidden_size]\n    uint8_t* s_gate_input_;                    // [hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)]\n    uint8_t* s_up_input_;                      // [hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)]\n    std::vector<float*> s_gate_output_;        // [routed_expert_num, intermediate_size]\n    std::vector<float*> s_up_output_;          // [routed_expert_num, intermediate_size]\n    std::vector<float*> s_intermediate_fp32_;  // [routed_expert_num, intermediate_size]\n    std::vector<uint8_t*> s_down_input_;       // [routed_expert_num, intermediate_size * ggml_type_size(ggml_internal_get_type_traits(down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(down_type).vec_dot_type)]\n    std::vector<float*> s_down_output_;        // [routed_expert_num, hidden_size]\n    float* s_output_fp32_;                     // [hidden_size]\n\n    std::vector<float*> s_down_input_grad_;        // [routed_expert_num, intermediate_size]\n    std::vector<float*> s_gate_output_grad_fp32_;  // [routed_expert_num, intermediate_size]\n    std::vector<float*> s_up_output_grad_fp32_;    // [routed_expert_num, intermediate_size]\n    std::vector<uint8_t*> s_gate_output_grad_;     // 
[routed_expert_num, intermediate_size * ggml_type_size(grad_type)]\n    std::vector<uint8_t*> s_up_output_grad_;       // [routed_expert_num, intermediate_size * ggml_type_size(grad_type)]\n    std::vector<float*> s_gate_input_grad_;        // [routed_expert_num, hidden_size]\n    std::vector<float*> s_up_input_grad_;          // [routed_expert_num, hidden_size]\n    float* s_input_grad_fp32_;                     // [hidden_size]\n\n    std::vector<float*> m_input_fp32_;    // [group_max_len, hidden_size]\n    std::vector<uint8_t*> m_gate_input_;  // [group_max_len, hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)]\n    std::vector<uint8_t*> m_up_input_;    // [group_max_len, hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)]\n    uint8_t* m_local_gate_input_;         // [routed_expert_num * group_max_len * hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)]\n    uint8_t* m_local_up_input_;           // [routed_expert_num * group_max_len * hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)]\n    float* m_local_gate_output_;          // [routed_expert_num * group_max_len * intermediate_size]\n    float* m_local_up_output_;            // [routed_expert_num * group_max_len * intermediate_size]\n    float* m_local_intermediate_fp32_;    // [routed_expert_num * group_max_len * intermediate_size]\n    uint8_t* m_local_down_input_;         // [routed_expert_num * group_max_len * intermediate_size * ggml_type_size(ggml_internal_get_type_traits(down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(down_type).vec_dot_type)]\n    float* 
m_local_down_output_;          // [routed_expert_num * group_max_len * hidden_size]\n    std::vector<float*> m_output_fp32_;   // [group_max_len, hidden_size]\n\n    std::vector<std::vector<int>> m_local_pos_;          // [group_max_len, routed_expert_num]\n    std::vector<int> m_local_num_;                       // [expert_num]\n    std::vector<uint8_t*> m_local_gate_input_ptr_;       // [expert_num]\n    std::vector<uint8_t*> m_local_up_input_ptr_;         // [expert_num]\n    std::vector<float*> m_local_gate_output_ptr_;        // [expert_num]\n    std::vector<float*> m_local_up_output_ptr_;          // [expert_num]\n    std::vector<float*> m_local_intermediate_fp32_ptr_;  // [expert_num]\n    std::vector<uint8_t*> m_local_down_input_ptr_;       // [expert_num]\n    std::vector<float*> m_local_down_output_ptr_;        // [expert_num]\n\n    uint8_t* m_local_down_output_grad_;                  // [routed_expert_num * group_max_len * hidden_size * ggml_type_size(grad_type)]\n    float* m_local_down_input_grad_;                     // [routed_expert_num * group_max_len * intermediate_size]\n    float* m_local_gate_output_grad_fp32_;               // [routed_expert_num * group_max_len * intermediate_size]\n    float* m_local_up_output_grad_fp32_;                 // [routed_expert_num * group_max_len * intermediate_size]\n    uint8_t* m_local_gate_output_grad_;                  // [routed_expert_num * group_max_len * intermediate_size * ggml_type_size(grad_type)]\n    uint8_t* m_local_up_output_grad_;                    // [routed_expert_num * group_max_len * intermediate_size * ggml_type_size(grad_type)]\n    float* m_local_gate_input_grad_;                     // [routed_expert_num * group_max_len * hidden_size]\n    float* m_local_up_input_grad_;                       // [routed_expert_num * group_max_len * hidden_size]\n    std::vector<float*> m_grad_input_fp32_;              // [group_max_len, hidden_size]\n\n    std::vector<uint8_t*> 
m_local_down_output_grad_ptr_;     // [expert_num]\n    std::vector<float*> m_local_down_input_grad_ptr_;        // [expert_num]\n    std::vector<float*> m_local_gate_output_grad_fp32_ptr_;  // [expert_num]\n    std::vector<float*> m_local_up_output_grad_fp32_ptr_;    // [expert_num]\n    std::vector<uint8_t*> m_local_gate_output_grad_ptr_;     // [expert_num]\n    std::vector<uint8_t*> m_local_up_output_grad_ptr_;       // [expert_num]\n    std::vector<float*> m_local_gate_input_grad_ptr_;        // [expert_num]\n    std::vector<float*> m_local_up_input_grad_ptr_;          // [expert_num]\n\n    int* m_local_token_indices_;                             // [routed_expert_num * group_max_len]\n    int* m_local_expert_positions_;                          // [routed_expert_num * group_max_len]\n    std::vector<int*> m_local_token_indices_ptr_;            // [expert_num]\n    std::vector<int*> m_local_expert_positions_ptr_;         // [expert_num]\n\n\tstd::vector<SFT_MoEForwardCache> fw_cache_; // persistent cache so that backward can read the values saved by forward\n};\n\n#endif"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/operators/llamafile/sft_moe_forward_cache.h",
    "content": "#pragma once\n#include <vector>\n\nstruct SFT_MoEForwardCache {\n    // 每个 token 按 expert 分块保存\n    std::vector<std::vector<float>> gate_u;   // u = W_gate x\n    std::vector<std::vector<float>> up_v;     // v = W_up   x\n    // 若希望反向直接用 z = σ(u)⊙v，则再加一份\n    // std::vector<std::vector<float>> z;\n    void init(int k, int inter_size) {\n        /* ---- 只增不减：capacity 不够时才增，永不缩小，避免多线程情况下的use-after-free ---- */\n       if (k > (int)gate_u.size()) {\n            gate_u.resize(k);\n            up_v  .resize(k);\n            // z     .resize(k);\n        }\n\n        for (int i = 0; i < k; ++i) {\n            if ((int)gate_u[i].capacity() < inter_size)\n                gate_u[i].reserve(inter_size);   // 只增 capacity\n            if ((int)up_v[i].capacity()   < inter_size)\n                up_v[i].reserve(inter_size);\n            // if ((int)z[i].capacity()      < inter_size)\n            //     z[i].reserve(inter_size);\n\n            // size() 更新为 inter_size 以便直接下标写入\n            gate_u[i].resize(inter_size);\n            up_v[i]  .resize(inter_size);\n            // z[i]     .resize(inter_size);\n        }\n\t}\n};"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/vendors/cuda.h",
    "content": "#pragma once\n\n#include <cuda_runtime.h>\n#include <cuda.h>\n#include <cublas_v2.h>\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n\n#if CUDART_VERSION < 11020\n#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED\n#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH\n#define CUBLAS_COMPUTE_16F CUDA_R_16F\n#define CUBLAS_COMPUTE_32F CUDA_R_32F\n#define cublasComputeType_t cudaDataType_t\n#endif // CUDART_VERSION < 11020\n"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/vendors/hip.h",
    "content": "#pragma once\n\n#define HIP_ENABLE_WARP_SYNC_BUILTINS 1\n#include <hip/hip_runtime.h>\n#include <hipblas/hipblas.h>\n#include <hip/hip_fp16.h>\n#include <hip/hip_bfloat16.h>\n#ifdef __HIP_PLATFORM_AMD__\n// for rocblas_initialize()\n#include \"rocblas/rocblas.h\"\n#endif // __HIP_PLATFORM_AMD__\n\n#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F\n#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F\n#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F\n#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT\n#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT\n#define CUBLAS_OP_N HIPBLAS_OP_N\n#define CUBLAS_OP_T HIPBLAS_OP_T\n#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS\n#define CUBLAS_TF32_TENSOR_OP_MATH 0\n#define CUDA_R_16F  HIPBLAS_R_16F\n#define CUDA_R_32F  HIPBLAS_R_32F\n#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported\n#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended\n#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned\n#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice\n#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite\n#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT(\"HipVMM Failure: %s\\n\", hipGetErrorString(err)); }}\n#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)\n#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)\n#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6\n#define cublasCreate hipblasCreate\n#define cublasDestroy hipblasDestroy\n#define cublasGemmEx hipblasGemmEx\n#define cublasGemmBatchedEx hipblasGemmBatchedEx\n#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx\n#define cublasHandle_t hipblasHandle_t\n#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS\n#define cublasSetStream hipblasSetStream\n#define cublasSgemm hipblasSgemm\n#define 
cublasStatus_t hipblasStatus_t\n#define cublasOperation_t hipblasOperation_t\n#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6\n#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer\n#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess\n#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess\n#define cudaDeviceProp hipDeviceProp_t\n#define cudaDeviceSynchronize hipDeviceSynchronize\n#define cudaError_t hipError_t\n#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled\n#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled\n#define cudaEventCreateWithFlags hipEventCreateWithFlags\n#define cudaEventDisableTiming hipEventDisableTiming\n#define cudaEventRecord hipEventRecord\n#define cudaEventSynchronize hipEventSynchronize\n#define cudaEvent_t hipEvent_t\n#define cudaEventDestroy hipEventDestroy\n#define cudaFree hipFree\n#define cudaFreeHost hipHostFree\n#define cudaGetDevice hipGetDevice\n#define cudaGetDeviceCount hipGetDeviceCount\n#define cudaGetDeviceProperties hipGetDeviceProperties\n#define cudaGetErrorString hipGetErrorString\n#define cudaGetLastError hipGetLastError\n#define cudaHostRegister hipHostRegister\n#define cudaHostRegisterPortable hipHostRegisterPortable\n#define cudaHostRegisterReadOnly hipHostRegisterReadOnly\n#define cudaHostUnregister hipHostUnregister\n#define cudaLaunchHostFunc hipLaunchHostFunc\n#define cudaMalloc hipMalloc\n#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)\n#define cudaMemcpy hipMemcpy\n#define cudaMemcpyAsync hipMemcpyAsync\n#define cudaMemcpyPeerAsync hipMemcpyPeerAsync\n#define cudaMemcpy2DAsync hipMemcpy2DAsync\n#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice\n#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost\n#define cudaMemcpyHostToDevice hipMemcpyHostToDevice\n#define cudaMemcpyKind hipMemcpyKind\n#define cudaMemset hipMemset\n#define cudaMemsetAsync hipMemsetAsync\n#define cudaMemGetInfo 
hipMemGetInfo\n#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize\n#define cudaSetDevice hipSetDevice\n#define cuDeviceGet hipDeviceGet\n#define CUdevice hipDevice_t\n#define CUdeviceptr hipDeviceptr_t\n#define cuMemUnmap hipMemUnmap\n#define CUmemAccessDesc hipMemAccessDesc\n#define cuMemAddressFree hipMemAddressFree\n#define cuMemRelease hipMemRelease\n#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t\n#define cuMemCreate hipMemCreate\n#define cuMemAddressReserve hipMemAddressReserve\n#define cuMemMap hipMemMap\n#define cuMemSetAccess hipMemSetAccess\n#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity\n#define CUmemAllocationProp hipMemAllocationProp\n#define cuDeviceGetAttribute hipDeviceGetAttribute\n#define cudaStreamCreateWithFlags hipStreamCreateWithFlags\n#define cudaStreamDestroy hipStreamDestroy\n#define cudaStreamFireAndForget hipStreamFireAndForget\n#define cudaStreamNonBlocking hipStreamNonBlocking\n#define cudaStreamPerThread hipStreamPerThread\n#define cudaStreamSynchronize hipStreamSynchronize\n#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)\n#define cudaGraphExec_t hipGraphExec_t\n#define cudaGraphNode_t hipGraphNode_t\n#define cudaKernelNodeParams hipKernelNodeParams\n#define cudaKernelNodeParams hipKernelNodeParams\n#define cudaGraphExecDestroy hipGraphExecDestroy\n#define cudaGraphLaunch hipGraphLaunch\n#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure\n#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult\n#define cudaGraphNodeType hipGraphNodeType\n#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel\n#define cudaGraphInstantiate hipGraphInstantiate\n#define cudaStreamEndCapture hipStreamEndCapture\n#define cudaGraphDestroy hipGraphDestroy\n#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams\n#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction\n#define 
cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams\n#define cudaGraphNodeGetType hipGraphNodeGetType\n#define cudaGraphGetNodes hipGraphGetNodes\n#define cudaGraphExecUpdate hipGraphExecUpdate\n#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed\n#define cudaStreamBeginCapture hipStreamBeginCapture\n#define cudaGraph_t hipGraph_t\n#define cudaStream_t hipStream_t\n#define cudaSuccess hipSuccess\n#define cudaHostFn_t hipHostFn_t\n#define __trap() do { abort(); __builtin_unreachable(); } while(0)\n#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS\n#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED\n#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED\n#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE\n#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH\n#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR\n#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED\n#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR\n#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED\n\n#define __CUDA_ARCH__ 1300\n\n#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)\n#define GCN\n#endif\n\n#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)\n#define CDNA\n#endif\n\n#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \\\n    defined(__gfx1150__) || defined(__gfx1151__)\n#define RDNA3\n#endif\n\n#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \\\n    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)\n#define RDNA2\n#endif\n\n#if defined(__gfx1010__) || defined(__gfx1012__)\n#define RDNA1\n#endif\n\n#ifndef __has_builtin\n    #define __has_builtin(x) 0\n#endif\n\ntypedef hip_bfloat16 nv_bfloat16;\n"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/vendors/musa.h",
    "content": "#pragma once\n\n#include <musa_runtime.h>\n#include <musa.h>\n#include <mublas.h>\n#include <musa_bf16.h>\n#include <musa_fp16.h>\n#define CUBLAS_COMPUTE_16F CUDA_R_16F\n#define CUBLAS_COMPUTE_32F CUDA_R_32F\n#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F\n#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT\n#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT\n#define CUBLAS_OP_N MUBLAS_OP_N\n#define CUBLAS_OP_T MUBLAS_OP_T\n#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS\n#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT\n#define CUDA_R_16F  MUSA_R_16F\n#define CUDA_R_32F  MUSA_R_32F\n#define cublasComputeType_t cudaDataType_t\n#define cublasCreate mublasCreate\n#define cublasDestroy mublasDestroy\n#define cublasGemmEx mublasGemmEx\n#define cublasGemmBatchedEx mublasGemmBatchedEx\n#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx\n#define cublasHandle_t mublasHandle_t\n#define cublasSetMathMode mublasSetMathMode\n#define cublasSetStream mublasSetStream\n#define cublasSgemm mublasSgemm\n#define cublasStatus_t mublasStatus_t\n#define cublasOperation_t mublasOperation_t\n#define cublasGetStatusString mublasStatus_to_string\n#define cudaDataType_t musaDataType_t\n#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer\n#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess\n#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess\n#define cudaDeviceProp musaDeviceProp\n#define cudaDeviceSynchronize musaDeviceSynchronize\n#define cudaError_t musaError_t\n#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled\n#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled\n#define cudaEventCreateWithFlags musaEventCreateWithFlags\n#define cudaEventDisableTiming musaEventDisableTiming\n#define cudaEventRecord musaEventRecord\n#define cudaEventSynchronize musaEventSynchronize\n#define cudaEvent_t musaEvent_t\n#define cudaEventDestroy musaEventDestroy\n#define cudaFree 
musaFree\n#define cudaFreeHost musaFreeHost\n#define cudaGetDevice musaGetDevice\n#define cudaGetDeviceCount musaGetDeviceCount\n#define cudaGetDeviceProperties musaGetDeviceProperties\n#define cudaGetErrorString musaGetErrorString\n#define cudaGetLastError musaGetLastError\n#define cudaHostRegister musaHostRegister\n#define cudaHostRegisterPortable musaHostRegisterPortable\n#define cudaHostRegisterReadOnly musaHostRegisterReadOnly\n#define cudaHostUnregister musaHostUnregister\n#define cudaLaunchHostFunc musaLaunchHostFunc\n#define cudaMalloc musaMalloc\n#define cudaMallocHost musaMallocHost\n#define cudaMallocManaged musaMallocManaged\n#define cudaMemcpy musaMemcpy\n#define cudaMemcpyAsync musaMemcpyAsync\n#define cudaMemcpyPeerAsync musaMemcpyPeerAsync\n#define cudaMemcpy2DAsync musaMemcpy2DAsync\n#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice\n#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost\n#define cudaMemcpyHostToDevice musaMemcpyHostToDevice\n#define cudaMemcpyKind musaMemcpyKind\n#define cudaMemset musaMemset\n#define cudaMemsetAsync musaMemsetAsync\n#define cudaMemGetInfo musaMemGetInfo\n#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize\n#define cudaSetDevice musaSetDevice\n#define cudaStreamCreateWithFlags musaStreamCreateWithFlags\n#define cudaStreamDestroy musaStreamDestroy\n#define cudaStreamFireAndForget musaStreamFireAndForget\n#define cudaStreamNonBlocking musaStreamNonBlocking\n#define cudaStreamPerThread musaStreamPerThread\n#define cudaStreamSynchronize musaStreamSynchronize\n#define cudaStreamWaitEvent musaStreamWaitEvent\n#define cudaStream_t musaStream_t\n#define cudaSuccess musaSuccess\n\n// Additional mappings for MUSA virtual memory pool\n#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED\n#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE\n#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED 
MU_MEM_ALLOC_GRANULARITY_RECOMMENDED\n#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED\n#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE\n#define CUdevice MUdevice\n#define CUdeviceptr MUdeviceptr\n#define CUmemAccessDesc MUmemAccessDesc\n#define CUmemAllocationProp MUmemAllocationProp\n#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle\n#define cuDeviceGet muDeviceGet\n#define cuDeviceGetAttribute muDeviceGetAttribute\n#define cuMemAddressFree muMemAddressFree\n#define cuMemAddressReserve muMemAddressReserve\n#define cuMemCreate muMemCreate\n#define cuMemGetAllocationGranularity muMemGetAllocationGranularity\n#define cuMemMap muMemMap\n#define cuMemRelease muMemRelease\n#define cuMemSetAccess muMemSetAccess\n#define cuMemUnmap muMemUnmap\n#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize\n#define cudaFuncSetAttribute musaFuncSetAttribute\n#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms\n#define make_cudaExtent make_musaExtent\n#define make_cudaPitchedPtr make_musaPitchedPtr\n\n// Additional mappings for MUSA graphs\n#define CUDA_SUCCESS MUSA_SUCCESS\n#define CUresult MUresult\n#define cuGetErrorString muGetErrorString\n#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure\n#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction\n#define cudaGraphDestroy musaGraphDestroy\n#define cudaGraphExecDestroy musaGraphExecDestroy\n#define cudaGraphExec_t musaGraphExec_t\n#define cudaGraphExecUpdate musaGraphExecUpdate\n#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult\n#define cudaGraphGetNodes musaGraphGetNodes\n#define cudaGraphInstantiate musaGraphInstantiate\n#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams\n#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams\n#define cudaGraphLaunch musaGraphLaunch\n#define cudaGraphNodeGetType musaGraphNodeGetType\n#define cudaGraphNode_t 
musaGraphNode_t\n#define cudaGraphNodeType musaGraphNodeType\n#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel\n#define cudaGraph_t musaGraph_t\n#define cudaKernelNodeParams musaKernelNodeParams\n#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed\n#define cudaStreamEndCapture musaStreamEndCapture\n\ntypedef mt_bfloat16 nv_bfloat16;\n"
  },
  {
    "path": "kt-sft/csrc/ktransformers_ext/vendors/vendor.h",
    "content": "#ifndef CPUINFER_VENDOR_VENDOR_H\n#define CPUINFER_VENDOR_VENDOR_H\n\n#ifdef USE_CUDA\n#include \"cuda.h\"\n#elif USE_HIP\n#define __HIP_PLATFORM_AMD__\n#include \"hip.h\"\n#elif USE_MUSA\n#include \"musa.h\"\n#endif\n\n#endif  // CPUINFER_VENDOR_VENDOR_H"
  },
  {
    "path": "kt-sft/install-with-cache.sh",
    "content": "#!/bin/bash\nset -e  \n\n# clear build dirs\n# rm -rf build\n# rm -rf *.egg-info\n# rm -rf csrc/build\n# rm -rf csrc/ktransformers_ext/build\n# rm -rf csrc/ktransformers_ext/cuda/build\n# rm -rf csrc/ktransformers_ext/cuda/dist\n# rm -rf csrc/ktransformers_ext/cuda/*.egg-info\nrm -rf ~/.ktransformers\necho \"Installing python dependencies from requirements.txt\"\npip install -r requirements-local_chat.txt\npip install -r ktransformers/server/requirements.txt\necho \"Installing ktransformers\"\nKTRANSFORMERS_FORCE_BUILD=TRUE USE_BALANCE_SERVE=1 pip install -v . --no-build-isolation\npip install third_party/custom_flashinfer/ -v\n\n# SITE_PACKAGES=$(python -c \"import site; print(site.getsitepackages()[0])\")\n# echo \"Copying thirdparty libs to $SITE_PACKAGES\"\n# cp -a csrc/balance_serve/build/third_party/prometheus-cpp/lib/libprometheus-cpp-*.so* $SITE_PACKAGES/\n# patchelf --set-rpath '$ORIGIN' $SITE_PACKAGES/sched_ext.cpython*\n\n\necho \"Installation completed successfully\"\n"
  },
  {
    "path": "kt-sft/install.bat",
    "content": "@echo off\n\nREM clear build dirs\nrmdir /S /Q ktransformers\\ktransformers_ext\\build\nrmdir /S /Q ktransformers\\ktransformers_ext\\cuda\\build\nrmdir /S /Q ktransformers\\ktransformers_ext\\cuda\\dist\nrmdir /S /Q ktransformers\\ktransformers_ext\\out\ndel /F /Q ktransformers\\ktransformers_ext\\cuda\\*.egg-info\n\necho Installing python dependencies from requirements.txt\npip install -r requirements-local_chat.txt\n\necho Installing ktransformers\nset KTRANSFORMERS_FORCE_BUILD=TRUE\npip install . --no-build-isolation\necho Installation completed successfully"
  },
  {
    "path": "kt-sft/install.sh",
    "content": "#!/bin/bash\nset -e  \n\nCWD=$(cd \"$(dirname \"${BASH_SOURCE[0]}\")\" && pwd)\n\n# default backend\nDEV=\"cuda\"\n\n# parse --dev argument\nwhile [[ \"$#\" -gt 0 ]]; do\n    case $1 in\n        --dev) DEV=\"$2\"; shift ;;\n        *) echo \"Unknown parameter passed: $1\"; exit 1 ;;\n    esac\n    shift\ndone\nexport DEV_BACKEND=\"$DEV\"\necho \"Selected backend: $DEV_BACKEND\"\n\n# clear build dirs\nrm -rf build\nrm -rf *.egg-info\nrm -rf csrc/build\nrm -rf csrc/ktransformers_ext/build\nrm -rf csrc/ktransformers_ext/cuda/build\nrm -rf csrc/ktransformers_ext/cuda/dist\nrm -rf csrc/ktransformers_ext/cuda/*.egg-info\nrm -rf ~/.ktransformers\necho \"Installing python dependencies from requirements.txt\"\npip install -r \"${CWD}/requirements-sft.txt\"\n\necho \"Installing ktransformers\"\nKTRANSFORMERS_FORCE_BUILD=TRUE pip install -v \"${CWD}\" --no-build-isolation\n\nif [[ \"$DEV_BACKEND\" == \"cuda\" ]]; then\n    echo \"Installing custom_flashinfer for CUDA backend\"\n    pip install \"${CWD}/../third_party/custom_flashinfer/\"\nfi\n# SITE_PACKAGES=$(python -c \"import site; print(site.getsitepackages()[0])\")\n# echo \"Copying thirdparty libs to $SITE_PACKAGES\"\n# cp -a csrc/balance_serve/build/third_party/prometheus-cpp/lib/libprometheus-cpp-*.so* $SITE_PACKAGES/\n# patchelf --set-rpath '$ORIGIN' $SITE_PACKAGES/sched_ext.cpython*\n\necho \"Installation completed successfully\""
  },
  {
    "path": "kt-sft/ktransformers/__init__.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :\nAuthor       : kkk1nak0\nDate         : 2024-08-15 07:34:46\nVersion      : 1.0.0\nLastEditors  : chenxl\nLastEditTime : 2025-02-15 03:53:02\n'''\nimport sys\nimport os\n\n# Import version from shared version.py at project root\n_root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))\nsys.path.insert(0, _root_dir)\ntry:\n    from version import __version__\nfinally:\n    sys.path.pop(0)\n"
  },
  {
    "path": "kt-sft/ktransformers/configs/config.yaml",
    "content": "log:\n  dir: \"logs\"\n  file: \"lexllama.log\"\n  #log level: debug, info, warn, error, crit\n  level: \"debug\"\n  backup_count: -1\n\nserver:\n  ip: 0.0.0.0\n  port: 10002\n\ndb:\n  type: \"sqllite\"\n  database: \"server.db\"\n  host: \"./\"\n  pool_size: 10\n\nuser:\n  secret_key: \"981f1dd2a44e27d68759d0252a486568ed43480b4e616a26e3af3709c3a7ce73\"\n  algorithm: \"HS256\"\n\nmodel:\n  # type: transformers\n  # type: balance_serve\n  type: ktransformers\n\n  name: DeepSeek-Coder-V2-Instruct\n  path: deepseek-ai/DeepSeek-V2-Lite-Chat\n  gguf_path: ./DeepSeek-V2-Lite-Chat-GGUF\n\n  device: cuda:0\n  cache_lens: 16384\n  max_new_tokens: 500\nweb:\n  mount: False\n  open_cross_domain: True\n\next:\n  cpu_infer: 10\n\nlong_context:\n  max_seq_len: 32000\n  block_size: 128\n  local_windows_len: 4096\n  second_select_num: 32\n  anchor_type: DYNAMIC\n  kv_type: FP16\n  dense_layer_num: 2\n  anchor_num: 1\n  preselect_block: True\n  head_select_mode: SHARED\n  preselect_block_count: 32\n  layer_step: 1\n  token_step: \n\nlocal_chat:\n  prompt_file: \"\"\n\nasync_server:\n  sched_strategy: \"FCFS\"\n  sched_port: 56441\n  sched_metrics_port: 54321\n  kvc2_metrics_port: 54391\n  max_batch_size: 4  # decode count + prefill count, in one mini batch\n\nattn:\n  page_size: 256\n  chunk_size: 256\nkvc2:\n  gpu_only: true \n  utilization_percentage: 1.0\n  cpu_memory_size_GB: 500\n"
  },
  {
    "path": "kt-sft/ktransformers/configs/log_config.ini",
    "content": "[loggers]\nkeys=root,uvicorn,uvicornError,uvicornAccess\n\n[handlers]\nkeys=consoleHandler,fileHandler\n\n[formatters]\nkeys=detailedFormatter\n\n[logger_root]\nlevel=INFO\nhandlers=consoleHandler\n\n[logger_uvicorn]\nlevel=INFO\nhandlers=consoleHandler,fileHandler\nqualname=uvicorn\npropagate=0\n\n[logger_uvicornError]\nlevel=ERROR\nhandlers=consoleHandler,fileHandler\nqualname=uvicorn.error\npropagate=0\n\n[logger_uvicornAccess]\nlevel=INFO\nhandlers=consoleHandler,fileHandler\nqualname=uvicorn.access\npropagate=0\n\n[handler_consoleHandler]\nclass=StreamHandler\nlevel=INFO\nformatter=detailedFormatter\nargs=(sys.stdout,)\n\n[handler_fileHandler]\nclass=logging.FileHandler\nlevel=INFO\nformatter=detailedFormatter\nargs=('uvicorn_logs.log', 'a')\n\n[formatter_detailedFormatter]\nformat=%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s\ndatefmt=%Y-%m-%d %H:%M:%S\n"
  },
  {
    "path": "kt-sft/ktransformers/configs/model_config/config.json",
    "content": "{\n\t\"architectures\": [\n\t\t\"DeepseekV2ForCausalLM\"\n\t],\n\t\"attention_bias\": false,\n\t\"attention_dropout\": 0.0,\n\t\"auto_map\": {\n\t\t\"AutoConfig\": \"configuration_deepseek.DeepseekV2Config\",\n\t\t\"AutoModel\": \"modeling_deepseek.DeepseekV2Model\",\n\t\t\"AutoModelForCausalLM\": \"modeling_deepseek.DeepseekV2ForCausalLM\"\n\t},\n\t\"aux_loss_alpha\": 0.001,\n\t\"bos_token_id\": 100000,\n\t\"eos_token_id\": 100001,\n\t\"first_k_dense_replace\": 1,\n\t\"hidden_act\": \"silu\",\n\t\"hidden_size\": 2048,\n\t\"initializer_range\": 0.02,\n\t\"intermediate_size\": 10944,\n\t\"kv_lora_rank\": 512,\n\t\"max_position_embeddings\": 163840,\n\t\"model_type\": \"deepseek_v2\",\n\t\"moe_intermediate_size\": 1408,\n\t\"moe_layer_freq\": 1,\n\t\"n_group\": 1,\n\t\"n_routed_experts\": 64,\n\t\"n_shared_experts\": 2,\n\t\"norm_topk_prob\": false,\n\t\"num_attention_heads\": 16,\n\t\"num_experts_per_tok\": 6,\n\t\"num_hidden_layers\": 27,\n\t\"num_key_value_heads\": 16,\n\t\"pretraining_tp\": 1,\n\t\"q_lora_rank\": null,\n\t\"qk_nope_head_dim\": 128,\n\t\"qk_rope_head_dim\": 64,\n\t\"rms_norm_eps\": 1e-06,\n\t\"rope_scaling\": {\n\t\t\"beta_fast\": 32,\n\t\t\"beta_slow\": 1,\n\t\t\"factor\": 40,\n\t\t\"mscale\": 0.707,\n\t\t\"mscale_all_dim\": 0.707,\n\t\t\"original_max_position_embeddings\": 4096,\n\t\t\"type\": \"yarn\"\n\t},\n\t\"rope_theta\": 10000,\n\t\"routed_scaling_factor\": 1.0,\n\t\"scoring_func\": \"softmax\",\n\t\"seq_aux\": true,\n\t\"tie_word_embeddings\": false,\n\t\"topk_group\": 1,\n\t\"topk_method\": \"greedy\",\n\t\"torch_dtype\": \"bfloat16\",\n\t\"transformers_version\": \"4.33.1\",\n\t\"use_cache\": true,\n\t\"v_head_dim\": 128,\n\t\"vocab_size\": 102400\n}"
  },
  {
    "path": "kt-sft/ktransformers/configs/model_config/configuration_deepseek.py",
    "content": "from transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import logging\n\nlogger = logging.get_logger(__name__)\n\nDEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}\nclass DeepseekV2Config(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`DeepseekV2Model`]. It is used to instantiate a DeepSeek\n    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the\n    defaults will yield a similar configuration to that of the DeepSeek-V2.\n\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the\n    documentation from [`PretrainedConfig`] for more information.\n\n\n    Args:\n        vocab_size (`int`, *optional*, defaults to 102400):\n            Vocabulary size of the DeepSeek model. Defines the number of different tokens that can be represented by the\n            `inputs_ids` passed when calling [`DeepseekV2Model`].\n        hidden_size (`int`, *optional*, defaults to 4096):\n            Dimension of the hidden representations.\n        intermediate_size (`int`, *optional*, defaults to 11008):\n            Dimension of the MLP representations.\n        moe_intermediate_size (`int`, *optional*, defaults to 1407):\n            Dimension of the MoE representations.\n        num_hidden_layers (`int`, *optional*, defaults to 30):\n            Number of hidden layers in the Transformer decoder.\n        num_attention_heads (`int`, *optional*, defaults to 32):\n            Number of attention heads for each attention layer in the Transformer decoder.\n        n_shared_experts (`int`, *optional*, defaults to None):\n            Number of shared experts, None means dense model.\n        n_routed_experts (`int`, *optional*, defaults to None):\n            Number of routed experts, None means dense model.\n        routed_scaling_factor (`float`, *optional*, defaults to 
1.0):\n            Scaling factor for routed experts.\n        topk_method (`str`, *optional*, defaults to `gready`):\n            Topk method used in routed gate.\n        n_group (`int`, *optional*, defaults to None):\n            Number of groups for routed experts.\n        topk_group (`int`, *optional*, defaults to None):\n            Number of selected groups for each token (for each token, ensuring the selected experts are only within `topk_group` groups).\n        num_experts_per_tok (`int`, *optional*, defaults to None):\n            Number of selected experts, None means dense model.\n        moe_layer_freq (`int`, *optional*, defaults to 1):\n            The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.\n        first_k_dense_replace (`int`, *optional*, defaults to 0):\n            Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).\n                                                            \\--k dense layers--/\n        norm_topk_prob (`bool`, *optional*, defaults to False):\n            Whether to normalize the weights of the routed experts.\n        scoring_func (`str`, *optional*, defaults to 'softmax'):\n            Method of computing expert weights.\n        aux_loss_alpha (`float`, *optional*, defaults to 0.001):\n            Auxiliary loss weight coefficient.\n        seq_aux (`bool`, *optional*, defaults to True):\n            Whether to compute the auxiliary loss for each individual sample.\n        num_key_value_heads (`int`, *optional*):\n            This is the number of key_value heads that should be used to implement Grouped Query Attention. If\n            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if\n            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. 
When\n            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed\n            by meanpooling all the original heads within that group. For more details check out [this\n            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to\n            `num_attention_heads`.\n        hidden_act (`str` or `function`, *optional*, defaults to `\"silu\"`):\n            The non-linear activation function (function or string) in the decoder.\n        max_position_embeddings (`int`, *optional*, defaults to 2048):\n            The maximum sequence length that this model might ever be used with.\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        rms_norm_eps (`float`, *optional*, defaults to 1e-06):\n            The epsilon used by the rms normalization layers.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values attentions (not used by all models). Only\n            relevant if `config.is_decoder=True`.\n        pad_token_id (`int`, *optional*):\n            Padding token id.\n        bos_token_id (`int`, *optional*, defaults to 100000):\n            Beginning of stream token id.\n        eos_token_id (`int`, *optional*, defaults to 100001):\n            End of stream token id.\n        pretraining_tp (`int`, *optional*, defaults to 1):\n            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this\n            document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is\n            necessary to ensure exact reproducibility of the pretraining results. 
Please refer to [this\n            issue](https://github.com/pytorch/pytorch/issues/76232).\n        tie_word_embeddings (`bool`, *optional*, defaults to `False`):\n            Whether to tie weight embeddings.\n        rope_theta (`float`, *optional*, defaults to 10000.0):\n            The base period of the RoPE embeddings.\n        rope_scaling (`Dict`, *optional*):\n            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling\n            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is\n            `{\"type\": strategy name, \"factor\": scaling factor}`. When using this flag, don't update\n            `max_position_embeddings` to the expected new maximum.\n        attention_bias (`bool`, *optional*, defaults to `False`):\n            Whether to use a bias in the query, key, value and output projection layers during self-attention.\n        attention_dropout (`float`, *optional*, defaults to 0.0):\n            The dropout ratio for the attention probabilities.\n\n    ```python\n    >>> from transformers import DeepseekV2Model, DeepseekV2Config\n\n    >>> # Initializing a Deepseek-V2 style configuration\n    >>> configuration = DeepseekV2Config()\n\n    >>> # Initializing a model from the configuration\n    >>> model = DeepseekV2Model(configuration)\n\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\"\"\"\n\n    model_type = \"deepseek_v2\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        vocab_size=102400,\n        hidden_size=4096,\n        intermediate_size=11008,\n        moe_intermediate_size = 1407,\n        num_hidden_layers=30,\n        num_attention_heads=32,\n        num_key_value_heads=32,\n        n_shared_experts = None,\n        n_routed_experts = None,\n        ep_size = 1,\n        routed_scaling_factor = 1.0,\n        kv_lora_rank = 512,\n        q_lora_rank = 1536,\n        qk_rope_head_dim = 64,\n        v_head_dim = 
128,\n        qk_nope_head_dim = 128,\n        topk_method = 'gready',\n        n_group = None,\n        topk_group = None,\n        num_experts_per_tok = None,\n        moe_layer_freq = 1,\n        first_k_dense_replace = 0,\n        norm_topk_prob = False,\n        scoring_func = 'softmax',\n        aux_loss_alpha = 0.001,\n        seq_aux = True,\n        hidden_act=\"silu\",\n        max_position_embeddings=2048,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        pad_token_id=None,\n        bos_token_id=100000,\n        eos_token_id=100001,\n        pretraining_tp=1,\n        tie_word_embeddings=False,\n        rope_theta=10000.0,\n        rope_scaling=None,\n        attention_bias=False,\n        attention_dropout=0.0,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.moe_intermediate_size = moe_intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.n_shared_experts = n_shared_experts\n        self.n_routed_experts = n_routed_experts\n        self.ep_size = ep_size\n        self.routed_scaling_factor = routed_scaling_factor\n        self.kv_lora_rank = kv_lora_rank\n        self.q_lora_rank = q_lora_rank\n        self.qk_rope_head_dim = qk_rope_head_dim\n        self.v_head_dim = v_head_dim\n        self.qk_nope_head_dim = qk_nope_head_dim\n        self.topk_method = topk_method\n        self.n_group = n_group\n        self.topk_group = topk_group\n        self.num_experts_per_tok = num_experts_per_tok\n        self.moe_layer_freq = moe_layer_freq\n        self.first_k_dense_replace = first_k_dense_replace\n        self.norm_topk_prob = norm_topk_prob\n        self.scoring_func = scoring_func\n        self.aux_loss_alpha = aux_loss_alpha\n        
self.seq_aux = seq_aux\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.pretraining_tp = pretraining_tp\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )"
  },
  {
    "path": "kt-sft/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/__init__.py",
    "content": ""
  },
  {
    "path": "kt-sft/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/format_24.py",
    "content": "#\n# Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es).\n#\n\nimport torch\n\n\n# This is PyTorch implementation of main part of reorder_meta()\n# function, from tools/util/include/cutlass/util/host_reorder.h file\n# of CUTLASS source tree.  Furthermore, CUTLASS template for sparse\n# GEMM decides upon layout of this matrix, and at the moment for the\n# sparse GEMM executed on tensor cores, this is layout described by\n# ColumnMajorInterleaved<2> data structure, in\n# include/cutlass/layout/matrix.h of CUTLASS source tree.  The\n# reordering of meta matrix into meta_reordered matrix calculated\n# according to these segments of CUTLASS code is re-implemented here.\n# Note that this calculation produces offsets for scattering metadata\n# matrix elements into reordered metadata matrix elements (or,\n# equivalently, for gathering reordered metadata matrix element back\n# into metadata matrix elements).\ndef _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype,\n                                               device):\n    dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)\n    dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)\n\n    # Reorder the rows, then swizzle the 2x2 blocks.\n    group_x = 64\n    group_y = 32 if meta_dtype.itemsize == 2 else 16\n\n    dst_rows = (dst_rows // group_x * group_x + (dst_rows % 2) * 2 +\n                (dst_rows % 8) // 4 + ((dst_rows % group_y) % 4) // 2 * 32 +\n                ((dst_rows % group_x) // 8) * 4)\n\n    topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)\n    bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)\n    dst_rows += topright - bottomleft\n    dst_cols -= topright - bottomleft\n\n    # Assumed that meta tensor is to be stored in CUTLASS\n    # InterleavedColumnMajor layout, and reverse engineered\n    # corresponding code to store values into this tensor.\n    interleave = 2\n    
cols_maj = dst_cols // interleave\n    cols_min = dst_cols % interleave\n    return (cols_maj * m * interleave + dst_rows * interleave +\n            cols_min).view(-1)\n\n\n# This function converts dense matrix into sparse semi-structured\n# representation, producing \"compressed\" matrix, in the layout used by\n# CUTLASS backend, and corresponding metadata matrix.\ndef sparse_semi_structured_from_dense_cutlass(dense):\n    if dense.dim() != 2:\n        raise RuntimeError(\n            f\"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor\"  # noqa: E501\n        )\n\n    m, k = dense.shape\n    device = dense.device\n\n    meta_dtype = torch.int8\n    if dense.dtype == torch.int8:\n        meta_dtype = torch.int32\n    elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:\n        meta_dtype = torch.int16\n    else:\n        raise RuntimeError(f\"Invalid datatype {dense.dtype} of dense matrix\")\n    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4\n    if quadbits_per_meta_elem not in (4, 8):\n        raise RuntimeError(\n            \"Invalid number of elements per meta element calculated\")\n\n    if meta_dtype == torch.int32:\n        if m % 16 != 0:\n            raise RuntimeError(\n                f\"Number of rows of dense matrix {m} must be divisible by 16\")\n    else:\n        if m % 32 != 0:\n            raise RuntimeError(\n                f\"Number of rows of dense matrix {m} must be divisible by 32\")\n    if k % (4 * quadbits_per_meta_elem) != 0:\n        raise RuntimeError(\n            f\"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}\"  # noqa: E501\n        )\n\n    if dense.dtype != torch.float:\n        ksparse = 4\n        dense_4 = dense.view(-1, k // ksparse, ksparse)\n        m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)\n    else:\n        ksparse = 2\n        dense_2 = dense.view(-1, k // ksparse, ksparse)\n        m0, m2 = m1, m3 = (dense_2 != 
0).unbind(-1)\n    meta_ncols = k // (ksparse * quadbits_per_meta_elem)\n\n    # Encoding quadruples of True/False values as follows:\n    #     [True,  True,  False, False] -> 0b0100\n    #     [True,  False, True,  False] -> 0b1000\n    #     [False, True,  True,  False] -> 0b1001\n    #     [True,  False, False, True ] -> 0b1100\n    #     [False, True,  False, True ] -> 0b1101\n    #     [False, False, True,  True ] -> 0b1110\n    # Thus, lower two bits in the encoding are index of the True value\n    # at the lowest index in the quadruple, and the higher two bits in\n    # the encoding are index of the other True value in the quadruple.\n    # In case there are fewer than two True values, then False value or\n    # values at some index or indices are considered True for the\n    # encoding.  In case there are more than two True values, then the\n    # excess True value(s) at some indices are considered False for\n    # the encoding.  The exact encodings used for these cases are as\n    # follows:\n    #     [False, False, False, False] -> 0b1110\n    #     [False, False, False, True ] -> 0b1110\n    #     [False, False, True,  False] -> 0b1110\n    #     [False, True,  False, False] -> 0b1001\n    #     [False, True,  True,  True ] -> 0b1101\n    #     [True,  False, False, False] -> 0b1000\n    #     [True,  False, True,  True ] -> 0b1100\n    #     [True,  True,  False, True ] -> 0b0100\n    #     [True,  True,  True,  False] -> 0b0100\n    #     [True,  True,  True,  True ] -> 0b0100\n    # These particular encodings are chosen, with the help of Espresso\n    # logic minimizer software, for the purpose of minimization of\n    # corresponding Boolean functions, that translate non-zero flags\n    # into encoding bits.  
Note also possible choices for the first\n    # and last of these encodings were limited only to (0b0100,\n    # 0b1110), in order to produce valid encodings for 1:2 sparsity\n    # case.\n\n    expr0 = m0 & m1\n    expr1 = ~m0 & m1\n    expr2 = ~m0 & ~m1\n    bit0 = expr1\n    bit1 = expr2\n    bit2 = expr0 | expr2 | m3\n    bit3 = expr1 | ~m1\n    idxs0 = bit0 | (bit1.to(torch.int64) << 1)\n    idxs1 = bit2 | (bit3.to(torch.int64) << 1)\n\n    if dense.dtype != torch.float:\n        sparse0 = dense_4.gather(\n            -1, idxs0.unsqueeze(-1))  # type: ignore[possibly-undefined]\n        sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))\n        sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)\n    else:\n        sparse = dense_2.gather(-1,\n                                idxs0.unsqueeze(-1) // 2).view(\n                                    m,\n                                    k // 2)  # type: ignore[possibly-undefined]\n\n    meta_4 = idxs0 | (idxs1 << 2)\n    meta_n = meta_4.view(\n        (-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)\n\n    if quadbits_per_meta_elem == 4:\n        meta = (meta_n[:, :, 0]\n                | (meta_n[:, :, 1] << 4)\n                | (meta_n[:, :, 2] << 8)\n                | (meta_n[:, :, 3] << 12))\n    elif quadbits_per_meta_elem == 8:\n        meta = (meta_n[:, :, 0]\n                | (meta_n[:, :, 1] << 4)\n                | (meta_n[:, :, 2] << 8)\n                | (meta_n[:, :, 3] << 12)\n                | (meta_n[:, :, 4] << 16)\n                | (meta_n[:, :, 5] << 20)\n                | (meta_n[:, :, 6] << 24)\n                | (meta_n[:, :, 7] << 28))\n\n    # Reorder meta tensor elements.\n    meta_reordered = meta.new_empty(\n        (m * meta_ncols, ))  # type: ignore[possibly-undefined]\n    meta_offsets = _calculate_meta_reordering_scatter_offsets(\n        m, meta_ncols, meta_dtype, device)\n    meta_reordered.scatter_(0, meta_offsets, meta.view(-1))\n\n    return (sparse, 
meta_reordered.view(m, meta_ncols))\n\n\n# This function performs the reverse of the function above - it\n# reconstructs dense matrix from a pair of \"compressed\" matrix, given\n# in the layout used by CUTLASS backend, and accompanying metadata\n# matrix.\ndef sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):\n    if sparse.dim() != 2:\n        raise RuntimeError(\n            f\"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor\"  # noqa: E501\n        )\n\n    m, k = sparse.shape\n    device = sparse.device\n\n    if meta_reordered.dim() != 2:\n        raise RuntimeError(\n            f\"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor\"  # noqa: E501\n        )\n    if meta_reordered.device != device:\n        raise RuntimeError(\n            f\"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device\"  # noqa: E501\n        )\n\n    meta_dtype = meta_reordered.dtype\n    if meta_dtype not in (torch.int16, torch.int32):\n        raise RuntimeError(f\"Invalid datatype {meta_dtype} of meta matrix\")\n    quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4\n\n    ksparse = 4 if sparse.dtype != torch.float else 2\n\n    meta_nrows, meta_ncols = meta_reordered.shape\n    if meta_nrows != m:\n        raise RuntimeError(\n            f\"Number of rows of meta matrix {meta_nrows} must be equal to number of rows of sparse matrix {m}\"  # noqa: E501\n        )\n    if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:\n        raise RuntimeError(\n            f\"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, \"  # noqa: E501\n            \"expected according to the number of columns of meta matrix\")\n\n    # Undo meta tensor elements reordering.\n    meta_offsets = _calculate_meta_reordering_scatter_offsets(\n        m, meta_ncols, meta_dtype, device)\n    meta = 
torch.gather(meta_reordered.view(-1), 0,\n                        meta_offsets).view(m, meta_ncols)\n\n    # Unpack sparse tensor back to original dense tensor, using\n    # information provided by meta tensor.  Note that torch.float\n    # datatype is handled pretty much the same as\n    # torch.half/torch.bfloat16, as metadata for a pair of torch.float\n    # value is encoded as if underlying 8 bytes contain four\n    # torch.half/torch.bfloat16 values, where either first two or last\n    # two are zeros.\n    meta_2 = torch.empty(\n        (m, meta_ncols, 2 * quadbits_per_meta_elem),\n        dtype=meta_dtype,\n        device=device,\n    )\n    if quadbits_per_meta_elem == 4:\n        meta_2[:, :, 0] = meta & 0b11\n        meta_2[:, :, 1] = (meta >> 2) & 0b11\n        meta_2[:, :, 2] = (meta >> 4) & 0b11\n        meta_2[:, :, 3] = (meta >> 6) & 0b11\n        meta_2[:, :, 4] = (meta >> 8) & 0b11\n        meta_2[:, :, 5] = (meta >> 10) & 0b11\n        meta_2[:, :, 6] = (meta >> 12) & 0b11\n        meta_2[:, :, 7] = (meta >> 14) & 0b11\n    elif quadbits_per_meta_elem == 8:\n        meta_2[:, :, 0] = meta & 0b11\n        meta_2[:, :, 1] = (meta >> 2) & 0b11\n        meta_2[:, :, 2] = (meta >> 4) & 0b11\n        meta_2[:, :, 3] = (meta >> 6) & 0b11\n        meta_2[:, :, 4] = (meta >> 8) & 0b11\n        meta_2[:, :, 5] = (meta >> 10) & 0b11\n        meta_2[:, :, 6] = (meta >> 12) & 0b11\n        meta_2[:, :, 7] = (meta >> 14) & 0b11\n        meta_2[:, :, 8] = (meta >> 16) & 0b11\n        meta_2[:, :, 9] = (meta >> 18) & 0b11\n        meta_2[:, :, 10] = (meta >> 20) & 0b11\n        meta_2[:, :, 11] = (meta >> 22) & 0b11\n        meta_2[:, :, 12] = (meta >> 24) & 0b11\n        meta_2[:, :, 13] = (meta >> 26) & 0b11\n        meta_2[:, :, 14] = (meta >> 28) & 0b11\n        meta_2[:, :, 15] = (meta >> 30) & 0b11\n\n    dense_offsets = meta_2.view(-1) + (\n        torch.arange(0, 2 * m * k // ksparse, device=device) * 4).view(\n            -1, 1).repeat(1, 2).view(-1)\n\n 
   dense = torch.zeros((m * 2 * k, ), dtype=sparse.dtype, device=device)\n    if sparse.dtype != torch.float:\n        # dense.scatter_(0, dense_offsets, sparse.view(-1))\n        dense.scatter_(0, dense_offsets, sparse.reshape(-1))\n    else:\n        dense.view(torch.half).scatter_(0, dense_offsets,\n                                        sparse.view(torch.half).view(-1))\n\n    return dense.view(m, 2 * k)\n\n\ndef mask_creator(tensor):\n    \"\"\"\n    Create an N:M sparsity mask for the given tensor.\n    In every block of M consecutive weights, the M - N weights with the\n    smallest magnitude are pruned (mask 0) and the remaining N are kept\n    (mask 1). The mask has the same shape as the tensor.\n\n    N (weights kept per group) and M (group size) are fixed to 2 and 4.\n    \"\"\"\n    N = 2\n    M = 4\n\n    mask = None\n    # for i, tensor in enumerate(tensors):\n    if tensor.numel() % M != 0:\n        raise ValueError(\n            f\"Tensor of size {tensor.shape} can't be evenly divided into \"\n            f\"groups of {M}\")\n\n    num_groups = tensor.numel() // M\n\n    # N:M sparsity for linear layers\n    tensor_temp = tensor.detach().abs().reshape(num_groups, M)\n    index = torch.argsort(tensor_temp, dim=1)[:, :int(M - N)]\n\n    w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device)\n    mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape)\n\n    return mask\n"
  },
  {
    "path": "kt-sft/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_24_perms.py",
    "content": "\"\"\"This file is used for /tests and /benchmarks\"\"\"\nfrom typing import Dict, List\n\nimport numpy\nimport torch\n\n\n# Precompute permutations for Marlin24 weight and scale shuffling # noqa: E501\n#\n# Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501\n# with the tensor-core format that is described here:\n# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501\n#\n# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501\n# (without the need to use ldmatrix instructions) # noqa: E501\ndef get_perms_24(num_bits: int):\n    perm_list: List[int] = []\n    for i in range(32):\n        perm1: List[int] = []\n        col = i // 4\n        col_o = col // 2\n        for block in [0, 1]:\n            for row in [\n                    2 * (i % 4),\n                    2 * (i % 4) + 1,\n                    2 * (i % 4 + 4),\n                    2 * (i % 4 + 4) + 1,\n            ]:\n                perm1.append(16 * row + col_o * 256 + 8 * (col % 2) +\n                             4 * block)\n        for j in range(4):\n            perm_list.extend([p + 1 * j for p in perm1])\n    perm = numpy.array(perm_list)\n\n    if num_bits == 4:\n        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])\n    elif num_bits == 8:\n        interleave = numpy.array([0, 2, 1, 3])\n    else:\n        raise ValueError(\"num_bits must be 4 or 8, got {}\".format(num_bits))\n\n    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()\n    perm = torch.from_numpy(perm)\n    scale_perm: List[int] = []\n    for i in range(8):\n        scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])\n    scale_perm_single: List[int] = []\n    for i in range(8):\n        scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 
4, 5, 6, 7]])\n    return perm, scale_perm, scale_perm_single\n\n\nmarlin_24_perm: Dict[int, torch.Tensor] = {}\nmarlin_24_scale_perm: Dict[int, List[int]] = {}\nmarlin_24_scale_perm_single: Dict[int, List[int]] = {}\nfor num_bits in [4, 8]:\n    perm_24, scale_perm_24, scale_perm_single_24 = get_perms_24(num_bits)\n    marlin_24_perm[num_bits] = perm_24\n    marlin_24_scale_perm[num_bits] = scale_perm_24\n    marlin_24_scale_perm_single[num_bits] = scale_perm_single_24\n"
  },
  {
    "path": "kt-sft/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_perms.py",
    "content": "\"\"\"This file is used for /tests and /benchmarks\"\"\"\nfrom typing import Dict, List\n\nimport numpy\nimport torch\n\n\n# Precompute permutations for Marlin weight and scale shuffling # noqa: E501\n#\n# Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501\n# with the tensor-core format that is described here:\n# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501\n#\n# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501\n# (without the need to use ldmatrix instructions) # noqa: E501\ndef get_perms(num_bits: int):\n    perm_list: List[int] = []\n    for i in range(32):\n        perm1: List[int] = []\n        col = i // 4\n        for block in [0, 1]:\n            for row in [\n                    2 * (i % 4),\n                    2 * (i % 4) + 1,\n                    2 * (i % 4 + 4),\n                    2 * (i % 4 + 4) + 1,\n            ]:\n                perm1.append(16 * row + col + 8 * block)\n        for j in range(4):\n            perm_list.extend([p + 256 * j for p in perm1])\n\n    perm = numpy.array(perm_list)\n\n    if num_bits == 4:\n        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])\n    elif num_bits == 8:\n        interleave = numpy.array([0, 2, 1, 3])\n    else:\n        raise Exception(\"num_bits must be 4 or 8, got {}\".format(num_bits))\n\n    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()\n    perm = torch.from_numpy(perm)\n    scale_perm: List[int] = []\n    for i in range(8):\n        scale_perm.extend([i + 8 * j for j in range(8)])\n    scale_perm_single: List[int] = []\n    for i in range(4):\n        scale_perm_single.extend(\n            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])\n    return perm, scale_perm, scale_perm_single\n\n\nmarlin_perm: 
Dict[int, torch.Tensor] = {}\nmarlin_scale_perm: Dict[int, List[int]] = {}\nmarlin_scale_perm_single: Dict[int, List[int]] = {}\nfor num_bits in [4, 8]:\n    perm, scale_perm, scale_perm_single = get_perms(num_bits)\n    marlin_perm[num_bits] = perm\n    marlin_scale_perm[num_bits] = scale_perm\n    marlin_scale_perm_single[num_bits] = scale_perm_single\n"
  },
  {
    "path": "kt-sft/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/marlin_utils.py",
    "content": "\"\"\"This file is used for /tests and /benchmarks\"\"\"\nimport random\n\nimport numpy\nimport torch\n\nfrom ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.format_24 import (\n    mask_creator, sparse_semi_structured_from_dense_cutlass)\nfrom ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_24_perms import (\n    marlin_24_perm, marlin_24_scale_perm, marlin_24_scale_perm_single)\nfrom ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_perms import (\n    marlin_perm, marlin_scale_perm, marlin_scale_perm_single)\nfrom ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.quant_utils import (\n    get_pack_factor, quantize_weights, sort_weights)\n\n__cuda_arch = torch.cuda.get_device_capability()\n\nMARLIN_TILE = 16\n\nGPTQ_MARLIN_TILE = 16\nGPTQ_MARLIN_MIN_THREAD_N = 64\nGPTQ_MARLIN_MIN_THREAD_K = 128\nGPTQ_MARLIN_MAX_PARALLEL = 16\n\nGPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]\nGPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]\nGPTQ_MARLIN_SUPPORTED_SYM = [True]\n\ndef is_marlin_supported():\n    return __cuda_arch[0] >= 8\n\n\ndef marlin_permute_weights(q_w, size_k, size_n, perm, tile=MARLIN_TILE):\n    assert q_w.shape == (size_k, size_n)\n    assert size_k % tile == 0, f\"size_k = {size_k}, tile = {tile}\"\n    assert size_n % tile == 0, f\"size_k = {size_n}, tile = {tile}\"\n\n    # Permute weights to 16x64 marlin tiles\n    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))\n    q_w = q_w.permute((0, 2, 1, 3))\n    q_w = q_w.reshape((size_k // tile, size_n * tile))\n\n    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)\n\n    return q_w\n\n\ndef marlin_weights(q_w, size_k, size_n, num_bits, perm):\n    # Permute\n    q_w = marlin_permute_weights(q_w, size_k, size_n, perm)\n\n    # Pack\n    pack_factor = get_pack_factor(num_bits)\n    orig_device = q_w.device\n\n    q_w = q_w.cpu().numpy().astype(numpy.uint32)\n\n    
q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),\n                           dtype=numpy.uint32)\n    for i in range(pack_factor):\n        q_packed |= q_w[:, i::pack_factor] << num_bits * i\n\n    q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device)\n\n    return q_packed\n\n\ndef marlin_permute_scales(s, size_k, size_n, group_size, scale_perm,\n                          scale_perm_single):\n    if group_size < size_k and group_size != -1:\n        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]\n    else:\n        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]\n    s = s.reshape((-1, size_n)).contiguous()\n\n    return s\n\n\ndef marlin_quantize(\n    w: torch.Tensor,\n    num_bits: int,\n    group_size: int,\n    act_order: bool,\n):\n    size_k, size_n = w.shape\n\n    # Normalize group_size\n    if group_size == -1:\n        group_size = size_k\n    assert group_size <= size_k\n\n    # Quantize (and apply act_order if provided)\n    q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size,\n                                                       act_order)\n\n    # For act_order, sort the \"weights\" and \"g_idx\" so that group ids are\n    # increasing\n    sort_indices = torch.empty(0, dtype=torch.int, device=w.device)\n    if act_order:\n        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)\n\n    # Reformat to marlin\n    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits,\n                                marlin_perm[num_bits])\n    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size,\n                                     marlin_scale_perm[num_bits],\n                                     marlin_scale_perm_single[num_bits])\n\n    # Create result\n    res_list = [marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]\n    for i in range(len(res_list)):\n        res_list[i] = res_list[i].to(w.device)\n\n    return res_list\n\n\ndef vllm_marlin_quantize(\n    w: 
torch.Tensor,\n    num_bits: int,\n    group_size: int,\n    act_order: bool,\n):\n    size_k, size_n = w.shape\n\n    # Normalize group_size\n    if group_size == -1:\n        group_size = size_k\n    assert group_size <= size_k\n\n    # Quantize (and apply act_order if provided)\n    w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size,\n                                                       act_order)\n\n    # For act_order, sort the \"weights\" and \"g_idx\" so that group ids are\n    # increasing\n    sort_indices = torch.empty(0, dtype=torch.int, device=w.device)\n    if act_order:\n        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)\n\n    # Reformat to marlin\n    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits,\n                                marlin_perm[num_bits])\n    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size,\n                                     marlin_scale_perm[num_bits],\n                                     marlin_scale_perm_single[num_bits])\n\n    # Create result\n    res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]\n    for i in range(len(res_list)):\n        res_list[i] = res_list[i].to(w.device)\n\n    return res_list\n\n\ndef inject_24(w, size_k, size_n):\n    assert w.shape == (size_k, size_n)\n\n    mask = mask_creator(w.t()).t().cuda().bool()\n\n    return (mask * w).contiguous(), mask.contiguous()\n\n\ndef check_24(w, num_rows_to_sample=50, _verbose=False):\n    BLOCK_SIZE = 4\n    MAX_NON_ZEROS = 2\n\n    w = w.t().contiguous()\n\n    print(\"check_24: w.shape = {}\".format(w.shape))\n\n    num_rows, num_cols = w.shape\n    sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)\n    if _verbose:\n        print(f\"Sampled row idxs = {sampled_row_idxs}\")\n\n    total_segments = 0\n    non_24_segments = 0\n    for i in sampled_row_idxs:\n        for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE):\n            total_segments += 1\n      
      block = w[i, j:j + BLOCK_SIZE]\n            num_nonzero = torch.count_nonzero(block)\n            if num_nonzero > MAX_NON_ZEROS:\n                print(\"i = {} j = {} block = {}\".format(i, j, block))\n                non_24_segments += 1\n\n    print(f\"{non_24_segments} / {total_segments} do not have 2:4 structure.\")\n\n\ndef compress_quantized_24_weight(q_24, size_k, size_n, num_bits):\n    assert q_24.shape == (size_k, size_n)\n\n    # Remove zp to normalize over 0\n    max_q_val = (1 << num_bits) - 1\n    zp = (max_q_val + 1) // 2\n    q_24_no_zp = q_24 - zp\n\n    # Compress\n    q_24_no_zp = q_24_no_zp.t().contiguous()\n    q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(\n        q_24_no_zp)\n    q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous()\n\n    # Restore zp\n    q_24_comp = q_24_no_zp_comp + zp\n\n    # Resize meta to its actual shape (without moving any data)\n    meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)\n\n    return q_24_comp, meta\n\n\ndef marlin_24_quantize(\n    w: torch.Tensor,\n    num_bits: int,\n    group_size: int,\n):\n    size_k, size_n = w.shape\n\n    # Normalize group_size\n    if group_size == -1:\n        group_size = size_k\n    assert group_size <= size_k\n\n    # Inject 2:4 sparsity\n    w_24, mask_24 = inject_24(w, size_k, size_n)\n\n    # Quantize\n    w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24,\n                                                             num_bits,\n                                                             group_size,\n                                                             act_order=False)\n\n    # Compress quantized weight\n    q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n,\n                                                     num_bits)\n    size_k_comp = size_k // 2\n\n    # Reformat to marlin\n    marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n,\n                                        
num_bits, marlin_24_perm[num_bits])\n    marlin_24_s = marlin_permute_scales(s, size_k, size_n, group_size,\n                                        marlin_24_scale_perm[num_bits],\n                                        marlin_24_scale_perm_single[num_bits])\n\n    # Create result\n    res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]\n    for i in range(len(res_list)):\n        res_list[i] = res_list[i].to(w.device)\n\n    return res_list\n\n\ndef compute_max_diff(output, output_ref):\n    return torch.mean(torch.abs(output - output_ref)) / torch.mean(\n        torch.abs(output_ref))\n\n\nclass MarlinWorkspace:\n\n    def __init__(self, out_features, min_thread_n, max_parallel, device):\n        assert (out_features % min_thread_n == 0), (\n            \"out_features = {} is undivisible by min_thread_n = {}\".format(\n                out_features, min_thread_n))\n\n        max_workspace_size = ((out_features // min_thread_n) * max_parallel)\n\n        self.scratch = torch.zeros(max_workspace_size,\n                                   dtype=torch.int,\n                                   device=device)\n"
  },
  {
    "path": "kt-sft/ktransformers/ktransformers_ext/operators/custom_marlin/quantize/utils/quant_utils.py",
    "content": "\"\"\"This file is used for /tests and /benchmarks\"\"\"\nimport numpy\nimport torch\n\nSUPPORTED_NUM_BITS = [4, 8]\nSUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]\n\n\ndef get_pack_factor(num_bits):\n    assert num_bits in SUPPORTED_NUM_BITS, f\"Unsupported num_bits = {num_bits}\"\n    return 32 // num_bits\n\n\ndef permute_rows(q_w: torch.Tensor, group_size: int):\n\n    orig_device = q_w.device\n    k_size, _ = q_w.shape\n\n    g_idx = torch.zeros((k_size, ), dtype=torch.int32)\n    for i in range(k_size):\n        g_idx[i] = i // group_size\n\n    # Simulate act_order by doing a random permutation on K\n    rand_perm = torch.randperm(k_size)\n\n    g_idx = g_idx[rand_perm].contiguous()\n    q_w = q_w[rand_perm, :].contiguous()\n\n    return (\n        q_w.to(device=orig_device),\n        g_idx.to(device=orig_device),\n        rand_perm.to(device=orig_device),\n    )\n\n\ndef quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,\n                     act_order: bool):\n    orig_device = w.device\n    size_k, size_n = w.shape\n\n    assert w.is_floating_point(), \"w must be float\"\n    assert num_bits in SUPPORTED_NUM_BITS, f\"Unsupported num_bits = {num_bits}\"\n    assert group_size in SUPPORTED_GROUP_SIZES + [\n        size_k\n    ], f\"Unsupported groupsize = {group_size}\"\n\n    if group_size == -1:\n        group_size = size_k\n    assert group_size <= size_k\n\n    max_q_val = 2**num_bits - 1\n    half_q_val = (max_q_val + 1) // 2\n\n    # Reshape to [groupsize, -1]\n    if group_size < size_k:\n        w = w.view((-1, group_size, size_n))\n        w = w.permute(1, 0, 2)\n        w = w.reshape((group_size, -1))\n\n    # Compute scale for each group\n    s = torch.max(torch.abs(w), 0, keepdim=True)[0]\n    s *= 2 / max_q_val  # 2 => symmetric\n\n    # Quantize\n    q_w = torch.round(w / s).int()\n    q_w += half_q_val\n    q_w = torch.clamp(q_w, 0, max_q_val)\n\n    # Restore original shapes\n    if group_size < size_k:\n\n        
def reshape_w(w):\n            w = w.reshape((group_size, -1, size_n))\n            w = w.permute(1, 0, 2)\n            w = w.reshape((size_k, size_n)).contiguous()\n            return w\n\n        q_w = reshape_w(q_w)\n\n    s = s.reshape((-1, size_n)).contiguous()\n\n    # Apply act_order\n    g_idx = torch.empty(0, dtype=torch.int, device=w.device)\n    rand_perm = torch.empty(0, dtype=torch.int, device=w.device)\n    if act_order:\n        assert (\n            group_size < size_k\n        ), \"For act_order, groupsize = {} must be less than size_k = {}\".format(\n            group_size, size_k)\n\n        q_w, g_idx, rand_perm = permute_rows(q_w, group_size)\n\n    return (\n        q_w.to(device=orig_device),\n        s.to(device=orig_device),\n        g_idx.to(device=orig_device),\n        rand_perm.to(device=orig_device),\n    )\n\n\ndef sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):\n    orig_device = q_w.device\n\n    sort_indices = torch.argsort(g_idx).to(\n        dtype=torch.int32)  # Sort based on g_idx\n\n    g_idx = g_idx[sort_indices].contiguous()\n    q_w = q_w[sort_indices, :].contiguous()\n\n    return (\n        q_w.to(device=orig_device),\n        g_idx.to(device=orig_device),\n        sort_indices.to(device=orig_device),\n    )\n\n\ndef gptq_pack(\n    q_w: torch.Tensor,\n    num_bits: int,\n    size_k: int,\n    size_n: int,\n):\n    assert q_w.shape == (size_k, size_n)\n\n    pack_factor = get_pack_factor(num_bits)\n    assert size_k % pack_factor == 0\n\n    orig_device = q_w.device\n\n    q_w = q_w.cpu().numpy().astype(numpy.uint32)\n\n    q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)\n\n    for i in range(pack_factor):\n        q_res |= q_w[i::pack_factor, :] << num_bits * i\n\n    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)\n    return q_res\n"
  },
  {
    "path": "kt-sft/ktransformers/ktransformers_ext/triton/fp8gemm.py",
    "content": "# Adopted from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py\nfrom typing import Tuple\n\nimport torch\nimport triton\nimport triton.language as tl\nfrom triton import Config\n\n\n@triton.jit\ndef act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):\n    \"\"\"\n    Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.\n\n    Args:\n        x_ptr (triton.Pointer): Pointer to the input tensor.\n        y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored.\n        s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored.\n        BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance.\n\n    Returns:\n        None\n    \"\"\"\n    pid = tl.program_id(axis=0)\n    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    x = tl.load(x_ptr + offs).to(tl.float32)\n    s = tl.max(tl.abs(x)) / 448.\n    y = x / s\n    y = y.to(y_ptr.dtype.element_ty)\n    tl.store(y_ptr + offs, y)\n    tl.store(s_ptr + pid, s)\n\n\ndef act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:\n    \"\"\"\n    Quantizes the input tensor `x` using block-wise quantization.\n\n    Args:\n        x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.\n        block_size (int, optional): The size of the blocks to be used for quantization. 
Default is 128.\n\n    Returns:\n        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:\n            - The quantized tensor with dtype `torch.float8_e4m3fn`.\n            - A tensor of scaling factors with dtype `torch.float32`.\n    \"\"\"\n    assert x.is_contiguous(), 'Input tensor must be contiguous'\n    assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})'\n    y = torch.empty_like(x, dtype=torch.float8_e4m3fn)\n    s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)\n    grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )\n    act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)\n    return y, s\n\n\n@triton.jit\ndef weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):\n    \"\"\"\n    Dequantizes weights using the provided scaling factors and stores the result.\n\n    Args:\n        x_ptr (tl.pointer): Pointer to the quantized weights.\n        s_ptr (tl.pointer): Pointer to the scaling factors.\n        y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.\n        M (int): Number of rows in the weight matrix.\n        N (int): Number of columns in the weight matrix.\n        BLOCK_SIZE (tl.constexpr): Size of the block for tiling.\n\n    Returns:\n        None\n    \"\"\"\n    pid_m = tl.program_id(axis=0)\n    pid_n = tl.program_id(axis=1)\n    n = tl.cdiv(N, BLOCK_SIZE)\n    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n    offs = offs_m[:, None] * N + offs_n[None, :]\n    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)\n    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)\n    s = tl.load(s_ptr + pid_m * n + pid_n)\n    y = x * s\n    tl.store(y_ptr + offs, y, mask=mask)\n\n\ndef weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:\n    \"\"\"\n    Dequantizes the given weight tensor 
using the provided scale tensor.\n\n    Args:\n        x (torch.Tensor): The quantized weight tensor of shape (M, N).\n        s (torch.Tensor): The scale tensor of shape (ceil(M / block_size), ceil(N / block_size)), one scale per block.\n        block_size (int, optional): The block size to use for dequantization. Defaults to 128.\n\n    Returns:\n        torch.Tensor: The dequantized weight tensor of the same shape as `x`.\n\n    Raises:\n        AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.\n    \"\"\"\n    assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous'\n    assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions'\n    M, N = x.size()\n    y = torch.empty_like(x, dtype=torch.get_default_dtype())\n    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))\n    with torch.cuda.device(x.device):\n        weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)\n    return y\n\n\nfp8_gemm_configs = [\n    Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)\n    for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]\n]\n\n@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])\n@triton.jit\ndef fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,\n                    a_s_ptr, b_s_ptr,\n                    M, N: tl.constexpr, K: tl.constexpr,\n                    BLOCK_SIZE_M: tl.constexpr,\n                    BLOCK_SIZE_N: tl.constexpr,\n                    BLOCK_SIZE_K: tl.constexpr):\n    \"\"\"\n    Performs a matrix multiplication operation on FP8 matrices with scaling factors.\n\n    Args:\n        a_ptr (tl.tensor): Pointer to the first input matrix A.\n        b_ptr (tl.tensor): Pointer to the second input matrix B.\n        c_ptr (tl.tensor): Pointer to the output matrix C.\n        a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A.\n        b_s_ptr (tl.tensor): Pointer to the 
scaling factors for matrix B.\n        M (int): Number of rows in matrix A and C.\n        N (tl.constexpr): Number of columns in matrix B and C.\n        K (tl.constexpr): Number of columns in matrix A and rows in matrix B.\n        BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.\n        BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.\n        BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.\n\n    Returns:\n        None\n    \"\"\"\n    pid_m = tl.program_id(axis=0)\n    pid_n = tl.program_id(axis=1)\n    k = tl.cdiv(K, BLOCK_SIZE_K)\n    offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n    offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n    offs_k = tl.arange(0, BLOCK_SIZE_K)\n    a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]\n    b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]\n    a_s_ptrs = a_s_ptr + offs_m * k\n    b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k\n\n    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n    for i in range(k):\n        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)\n        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)\n        a_s = tl.load(a_s_ptrs)\n        b_s = tl.load(b_s_ptrs)\n        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]\n        a_ptrs += BLOCK_SIZE_K\n        b_ptrs += BLOCK_SIZE_K\n        a_s_ptrs += 1\n        b_s_ptrs += 1\n    c = accumulator.to(c_ptr.dtype.element_ty)\n    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n    c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]\n    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)\n    tl.store(c_ptrs, c, mask=mask)\n\n\ndef fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):\n    \"\"\"\n    Perform a matrix multiplication using FP8 precision.\n\n    Args:\n      
  a (torch.Tensor): The first input matrix, must be contiguous.\n        a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.\n        b (torch.Tensor): The second input matrix, must be contiguous.\n        b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.\n\n    Returns:\n        torch.Tensor: The result of the matrix multiplication.\n    \"\"\"\n    assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous'\n    assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous'\n    K = a.size(-1)\n    M = a.numel() // K\n    N = b.size(0)\n    c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())\n    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))\n    fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)\n    return c"
  },
  {
    "path": "kt-sft/ktransformers/local_chat.py",
    "content": "\"\"\"\nDescription  :  \nAuthor       : Boxin Zhang, Azure-Tang\nVersion      : 0.1.0\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n\"\"\"\n\nimport os\nimport platform\nimport sys\n\nproject_dir = os.path.dirname(os.path.dirname(__file__))\nsys.path.insert(0, project_dir)\nimport argparse\nimport torch\nimport logging\nfrom transformers import (\n    AutoTokenizer,\n    AutoConfig,\n    AutoModelForCausalLM,\n    GenerationConfig,\n    TextStreamer,\n    EvalPrediction,\n)\nimport json\nfrom pathlib import Path\nfrom tqdm import tqdm\nfrom torchviz import make_dot\nimport fire\nfrom ktransformers.optimize.optimize import optimize_and_load_gguf\nfrom ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM\nfrom ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM\nfrom ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM\nfrom ktransformers.models.modeling_llama import LlamaForCausalLM\nfrom ktransformers.models.modeling_mixtral import MixtralForCausalLM\nfrom ktransformers.util.utils import load_weights, prefill_and_generate, prefill_and_generate_capture, get_compute_capability, xpu_fp16_model\nfrom ktransformers.server.config.config import Config\nfrom ktransformers.operators.flashinfer_wrapper import flashinfer_enabled\nfrom ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor\nfrom ktransformers.sft.lora import inject_lora_layer, lora_and_load_adapter\nfrom ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader\nfrom ktransformers.util.globals import GLOBAL_CONFIG\nfrom ktransformers.sft.metrics import ComputeSimilarity\nfrom ktransformers.sft.monkey_patch_torch_module import install_patch, restore_patch\n\nos.environ['CUDA_VISIBLE_DEVICES'] = '1'\n\n# for debug\ndef print_module_tree(module, indent=0):\n    print(\" \" + f\"{module.__class__.__name__}(training={module.training})\")\n    for name, child in module.named_children():\n        
print(\" \" * indent + f\"└─{name}: \", end=\"\")\n        print_module_tree(child, indent + 4)\n\n# for debug\ndef write_to_file(content, file_path: str = 'ktransformers/mark_content.txt', mode: str = 'a', encoding: str = 'utf-8') -> None:\n    \"\"\"\n    Write a string to the specified file.\n    :param content: the string to write\n    :param file_path: path of the target file\n    :param mode: file open mode (default 'a' appends; pass 'w' to overwrite)\n    :param encoding: file encoding (default utf-8)\n    \"\"\"\n    with open(file_path, mode, encoding=encoding) as f:\n        f.write(content)\n\ncustom_models = {\n    \"DeepseekV2ForCausalLM\": DeepseekV2ForCausalLM,\n    \"DeepseekV3ForCausalLM\": DeepseekV3ForCausalLM,\n    \"Qwen2MoeForCausalLM\": Qwen2MoeForCausalLM,\n    \"LlamaForCausalLM\": LlamaForCausalLM,\n    \"MixtralForCausalLM\": MixtralForCausalLM,\n}\n\nktransformer_rules_dir = (\n    os.path.dirname(os.path.abspath(__file__)) + \"/optimize/optimize_rules/\"\n)\ndefault_optimize_rules = {\n    \"DeepseekV2ForCausalLM\": ktransformer_rules_dir + \"DeepSeek-V2-Chat.yaml\",\n    \"DeepseekV3ForCausalLM\": ktransformer_rules_dir + \"DeepSeek-V3-Chat.yaml\",\n    \"Qwen2MoeForCausalLM\": ktransformer_rules_dir + \"Qwen2-57B-A14B-Instruct.yaml\",\n    \"LlamaForCausalLM\": ktransformer_rules_dir + \"Internlm2_5-7b-Chat-1m.yaml\",\n    \"MixtralForCausalLM\": ktransformer_rules_dir + \"Mixtral.yaml\",\n}\n\n\ndef local_chat(\n    model_path: str | None = None,\n    model_config_path: str | None = None,\n    optimize_config_path: str | None = None,\n    gguf_path: str | None = None,\n    max_new_tokens: int = 1000,\n    cpu_infer: int = Config().cpu_infer,\n    use_cuda_graph: bool = True, # modify to false if using KExpertsTorch\n    prompt_file : str | None = None,\n    mode: str = \"normal\",\n    force_think: bool = False,\n    chunk_size: int = 8192,\n    device: str = \"cuda\",\n    is_sft: bool = False,\n    sft_data_path: str | None = None,\n    save_adapter_path: str | None = None,\n    use_adapter: bool = False,\n    use_adapter_path: str | None = 
None,\n    is_test_data: bool = False,\n    test_data_path: str | None = None,\n    output_dir: str | None = None,\n):\n\n    if not is_sft:\n        torch.set_grad_enabled(False)\n        \n    if is_sft == True or use_adapter == True:\n        GLOBAL_CONFIG._config[\"mod\"] = \"sft\"\n    else:\n        GLOBAL_CONFIG._config[\"mod\"] = \"infer\"\n\n    Config().cpu_infer = cpu_infer\n    Config().chunk_size = chunk_size\n    if torch.xpu.is_available():\n        use_cuda_graph = False\n\n    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n    if model_config_path == None:\n        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)\n    else:\n        config = AutoConfig.from_pretrained(model_config_path, trust_remote_code=True)\n    if mode == 'long_context':\n        assert config.architectures[0] == \"LlamaForCausalLM\", \"only LlamaForCausalLM support long_context mode\"\n        torch.set_default_dtype(torch.float16)\n    elif xpu_fp16_model(config):\n        torch.set_default_dtype(torch.float16)\n    else:\n        torch.set_default_dtype(config.torch_dtype)\n\n    with torch.device(\"meta\"):\n        if config.architectures[0] in custom_models:\n            print(\"using custom modeling_xxx.py.\")\n            if (\n                \"Qwen2Moe\" in config.architectures[0]\n            ):  # Qwen2Moe must use flash_attention_2 to avoid overflow.\n                config._attn_implementation = \"flash_attention_2\"\n            if \"Llama\" in config.architectures[0]:\n                config._attn_implementation = \"eager\"\n            if \"Mixtral\" in config.architectures[0]:\n                config._attn_implementation = \"flash_attention_2\"\n            if torch.xpu.is_available():\n                config._attn_implementation = \"eager\"\n            model = custom_models[config.architectures[0]](config)\n        else:\n            if torch.xpu.is_available():\n                attn_implementation = 
\"eager\"\n            else:\n                attn_implementation = \"flash_attention_2\"\n            model = AutoModelForCausalLM.from_config(\n                config, trust_remote_code=True, attn_implementation=attn_implementation\n            )\n\n    if optimize_config_path is None:\n        if config.architectures[0] in default_optimize_rules:\n            print(\"using default_optimize_rule for\", config.architectures[0])\n            optimize_config_path = default_optimize_rules[config.architectures[0]]\n        else:\n            optimize_config_path = input(\n                \"please input the path of your rule file(yaml file containing optimize rules):\"\n            )\n\n    if gguf_path is None:\n        gguf_path = input(\n            \"please input the path of your gguf file(gguf file in the dir containing input gguf file must all belong to current model):\"\n        )\n        \n    GLOBAL_CONFIG._config[\"mod\"] = \"infer\"\n    optimize_and_load_gguf(model, optimize_config_path, gguf_path, config)\n\n    model.train()\n\n    if is_sft == True:\n        if use_adapter == True or is_test_data == True:\n            raise AttributeError(\"We do not support to run sft and inference at the same time.\")\n        GLOBAL_CONFIG._config[\"mod\"] = \"sft\"\n        print(f\"sft with lora in dataset: {sft_data_path} ...\")\n        print(f\"use_cuda_graph:{use_cuda_graph}\")\n        lora_and_load_adapter(model, tokenizer, sft_data_path, save_adapter_path)\n\n    if use_adapter == True:\n        GLOBAL_CONFIG._config[\"mod\"] = \"sft\"\n        if is_sft == True:\n            raise AttributeError(\"We do not support more than one adapter up to now...\")\n        \n        if use_adapter_path.endswith('.gguf'):\n            inject_lora_layer(model, use_adapter_path)\n            adapter_gguf_loader = GGUFLoader(use_adapter_path)\n            load_weights(model, adapter_gguf_loader, adapter_gguf=True)\n            model.train()\n        else:\n            
inject_lora_layer(model, use_adapter_path)\n            \n            adapter_loader = SafeTensorLoader(use_adapter_path)\n            device = next(model.parameters()).device\n            \n            # for name, param in model.named_parameters():\n            #     print(name, param.shape)\n\n            for key in adapter_loader.tensor_file_map.keys():\n                try:\n                    tensor = adapter_loader.load_tensor(key, device=device)\n                    \n                    model_key = key.replace(\"base_model.model.\", \"\")\n                    model_key = model_key.replace(\".weight\", \".default.weight\")\n                    \n                    param = model.get_parameter(model_key)\n                    param.data.copy_(tensor.data)\n                    \n                    print(f\"Loaded adapter weight: {key} -> {model_key}\")\n                except AttributeError as e:\n                    print(f\"Skipping {key}: not a model parameter\")\n                except KeyError as e:\n                    print(f\"Key not found in model: {model_key} (original: {key})\")\n            \n\n    try:\n        model.generation_config = GenerationConfig.from_pretrained(model_path)\n    except Exception as e:\n        print(f\"generation config can't auto create, make default. 
Message: {e}\")\n        gen_config = GenerationConfig(\n            temperature=0.6,\n            top_p=0.95,\n            do_sample=True\n        )\n        model.generation_config = gen_config\n    # model.generation_config = GenerationConfig.from_pretrained(model_path)\n    if model.generation_config.pad_token_id is None:\n        model.generation_config.pad_token_id = model.generation_config.eos_token_id\n    model.eval()\n    logging.basicConfig(level=logging.INFO)\n    \n    # @torch.no_grad()\n    # def first_token_argmax_baseline(model, tokenizer, prompt_text, device):\n    #     model.eval()\n    #     enc = tokenizer.apply_chat_template([{\"role\":\"user\",\"content\":prompt_text}],\n    #                                         add_generation_prompt=True, return_tensors=\"pt\")\n    #     x = enc.to(device)\n    #     logits = model(input_ids=x, use_cache=False, return_dict=False)[0]\n    #     return int(torch.argmax(logits[:, -1, :], dim=-1)[0])\n\n    # try:\n    #     device_map = model.gguf_loader.tensor_device_map\n    #     from ktransformers.util.utils import get_device, torch_device_mapping\n    #     torch_device = get_device('model.layers.0.self_attn', device_map)\n    #     torch_device = torch_device_mapping.get(torch_device, torch_device)\n    #     print(f\"[FIRST-TOKEN PROBE] argmax id = {probe_id} ({tokenizer.decode([probe_id])!r})\")\n    # except Exception as e:\n    #     print(\"[FIRST-TOKEN PROBE] failed:\", e)\n    #     return\n\n    system = platform.system()\n    # for debug\n    # if system == \"Windows\":\n    #     os.system(\"cls\")\n    # else:\n    #     os.system(\"clear\")\n    \n    if GLOBAL_CONFIG._config[\"mod\"] == \"sft\" :\n        model.model.embed_tokens.to(\"cpu\")\n        \n    if is_test_data:\n        data_path = Path(test_data_path)\n        with data_path.open(\"r\", encoding=\"utf-8\") as f:\n            dataset = json.load(f)\n        preds, refs = [], []\n\n        for sample in tqdm(dataset, 
desc=\"Processing samples\"):\n            inst = sample.get(\"instruction\", \"\")\n            prompt = sample.get(\"input\", \"\")\n            prompt = prompt+inst\n            # print(f\"prompt: {prompt}\")\n            label = sample.get(\"output\", \"\")\n   \n            messages = [{\"role\": \"user\", \"content\": prompt}]\n            input_tensor = tokenizer.apply_chat_template(\n                messages, add_generation_prompt=True, return_tensors=\"pt\"\n            )\n            if force_think:\n                token_thinks = torch.tensor([tokenizer.encode(\"<think>\\\\n\",add_special_tokens=False)],device=input_tensor.device)\n                input_tensor = torch.cat(\n                    [input_tensor, token_thinks], dim=1\n                )\n            if mode == 'long_context':\n                assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \\\n                \"please change max_seq_len in  ~/.ktransformers/config.yaml\"\n\n            if system != \"Windows\" and (config.architectures[0] == \"DeepseekV2ForCausalLM\" or config.architectures[0] == \"DeepseekV3ForCausalLM\") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA:\n                prediction = prefill_and_generate_capture(\n                    model, tokenizer, input_tensor.to(device), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size,\n                    use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim, echo_stream=False\n                )\n            else:\n                prediction = prefill_and_generate_capture(\n                    model, tokenizer, input_tensor.to(device), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = 
chunk_size,echo_stream=False,\n                )\n            # print(f\"prediction:{prediction}\")\n            sample[\"label\"] = label\n            sample[\"prediction\"] = prediction\n            sample.pop(\"output\", None)\n\n            preds.append(prediction)\n            refs.append(label)\n\n        pred_file = Path(output_dir) / 'predictions.json'\n        pred_file.parent.mkdir(parents=True, exist_ok=True)\n        \n        with pred_file.open(\"w\", encoding=\"utf-8\") as f:\n            json.dump(dataset, f, ensure_ascii=False, indent=2)\n\n        compute_metrics = ComputeSimilarity(tokenizer)\n        # print(f\"metrics:{metrics}\")\n        \n        enc_pred = tokenizer(preds, add_special_tokens=False, padding=True, return_tensors=\"np\")\n        enc_ref  = tokenizer(refs,  add_special_tokens=False, padding=True, return_tensors=\"np\")\n\n        ep = EvalPrediction(\n            predictions=enc_pred[\"input_ids\"],\n            label_ids=enc_ref[\"input_ids\"]\n        )\n\n        metrics = compute_metrics(ep, compute_result=True)\n\n        metric_file = Path(output_dir) / 'metrics.json'\n        with metric_file.open(\"w\", encoding=\"utf-8\") as f:\n            json.dump(metrics, f, ensure_ascii=False, indent=2)\n            \n        print(f\"Results of predictions saved in {pred_file}\")\n        print(f\"Results of metrics saved in {metric_file}\")\n\n    while not is_test_data:\n        GLOBAL_CONFIG._config[\"mod\"] = \"infer\"\n        content = input(\"Chat: \")\n        if content.startswith('\"\"\"'):  # prefix \"\"\"\n            # multi lines input\n            content = content[3:] + \"\\n\"\n            while True:\n                line = input(\"\")\n                if line.endswith('\"\"\"'):\n                    # end multi lines input\n                    line = line[:-3]  # suffix \"\"\"\n                    if line:\n                        content += line + \"\\n\"\n                    break\n                else:\n    
                content += line + \"\\n\"\n\n        if content == \"\":\n            if prompt_file != None:\n                content = open(prompt_file, \"r\").read()\n            else:\n                content = \"Please write a piece of quicksort code in C++.\"\n        elif os.path.isfile(content):\n            content = open(content, \"r\").read()\n            \n        messages = [{\"role\": \"user\", \"content\": content}]\n        input_tensor = tokenizer.apply_chat_template(\n            messages, add_generation_prompt=True, return_tensors=\"pt\"\n        )\n        if force_think:\n            token_thinks = torch.tensor([tokenizer.encode(\"<think>\\\\n\",add_special_tokens=False)],device=input_tensor.device)\n            input_tensor = torch.cat(\n                [input_tensor, token_thinks], dim=1\n            )\n        if mode == 'long_context':\n            assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \\\n            \"please change max_seq_len in  ~/.ktransformers/config.yaml\"\n\n        if system != \"Windows\" and (config.architectures[0] == \"DeepseekV2ForCausalLM\" or config.architectures[0] == \"DeepseekV3ForCausalLM\") and flashinfer_enabled and get_compute_capability() >= 8 and device_manager.gpu_vendor == GPUVendor.NVIDIA:\n            generated = prefill_and_generate(\n                model, tokenizer, input_tensor.to(device), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size,\n                use_flashinfer_mla = True, num_heads = config.num_attention_heads, head_dim_ckv = config.kv_lora_rank, head_dim_kpe = config.qk_rope_head_dim, q_head_dim = config.qk_rope_head_dim + config.qk_nope_head_dim\n            )\n        else:\n            generated = prefill_and_generate(\n                model, tokenizer, input_tensor.to(device), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, chunk_size = chunk_size,\n            
)\n\n\nif __name__ == \"__main__\":\n    install_patch()\n    IS_DEBUG = True\n\n    if IS_DEBUG == False:\n        parser = argparse.ArgumentParser()\n\n        parser.add_argument(\"--model_path\", required=True)\n        parser.add_argument(\"--model_config_path\", default=None)\n        parser.add_argument(\"--gguf_path\", required=True)\n        parser.add_argument(\"--cpu_infer\", type=int, default=32)\n        parser.add_argument(\"--max_new_tokens\", type=int, default=1000)\n        parser.add_argument(\"--force_think\", action=\"store_true\")\n        parser.add_argument(\"--optimize_config_path\", required=True)\n        parser.add_argument(\"--is_sft\", type=lambda x: x.lower() == \"true\", default=False)\n        parser.add_argument(\"--sft_data_path\", default=None)\n        parser.add_argument(\"--save_adapter_path\", default=None)\n        parser.add_argument(\"--use_adapter\", type=lambda x: x.lower() == \"true\", default=False)\n        parser.add_argument(\"--use_adapter_path\", default=None)\n        parser.add_argument(\"--is_test_data\", type=lambda x: x.lower() == \"true\", default=False)\n        parser.add_argument(\"--test_data_path\", default=None)\n        parser.add_argument(\"--output_dir\", default=None)\n\n        args = parser.parse_args()\n\n        local_chat(\n            model_path=args.model_path,\n            model_config_path=args.model_config_path,\n            gguf_path=args.gguf_path,\n            cpu_infer=args.cpu_infer,\n            max_new_tokens=args.max_new_tokens,\n            force_think=args.force_think,\n            optimize_config_path=args.optimize_config_path,\n            is_sft=args.is_sft,\n            sft_data_path=args.sft_data_path,\n            save_adapter_path=args.save_adapter_path,\n            use_adapter=args.use_adapter,\n            use_adapter_path=args.use_adapter_path,\n            is_test_data=args.is_test_data,\n            test_data_path=args.test_data_path,\n            output_dir= 
args.output_dir\n        )\n\n    else:\n        local_chat(\n            # model_path=\"/mnt/data/data/DeepSeek-V3-671B-BF16\",\n            # model_config_path=\"/mnt/data/data/DeepSeek-V3-671B-BF16\",\n            # gguf_path=\"/mnt/data/data/DeepSeek-V3-671B-BF16\",\n            model_path=\"/mnt/data/models/DeepSeek-V2-Lite-Chat\",\n            model_config_path=\"/mnt/data/models/DeepSeek-V2-Lite-Chat\",\n            gguf_path=\"/mnt/data/models/DeepSeek-V2-Lite-Chat\",\n            cpu_infer=32,\n            max_new_tokens=1000,\n            force_think=False,\n            # optimize_config_path=\"ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml\",\n            optimize_config_path=\"ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml\",\n            is_sft=True,\n            sft_data_path=\"test_adapter/western_train.json\",\n            # sft_data_path=\"test_adapter/western_train.json\",\n            # sft_data_path=\"test_adapter/500token_test.json\",\n            save_adapter_path=\"/mnt/data/lpl/test_adapter/Kwhl_test_py312_torch28_DeepSeekV2_WEST\",\n            use_adapter=False,\n            use_adapter_path=\"/mnt/data/lpl/test_adapter/Kllama_deepseekV2_AfriMed_mcq\",\n            is_test_data=False,\n            test_data_path=\"/home/lpl/LLaMA-Factory-KT/data/mcq_test.json\",\n            output_dir=\"/mnt/data/lpl/test_adapter/Kllama_deepseekV2_AfriMed_mcq/baselines\",\n        )\n        "
  },
  {
    "path": "kt-sft/ktransformers/local_chat.sh",
    "content": "#!/bin/bash\n\npython3 ktransformers/local_chat.py \\\n    --model_path \"/mnt/data/models/DeepSeek-V2-Lite-Chat\" \\\n    --model_config_path \"/mnt/data/models/DeepSeek-V2-Lite-Chat\" \\\n    --gguf_path \"/mnt/data/models/DeepSeek-V2-Lite-Chat\" \\\n    --cpu_infer 32 \\\n    --max_new_tokens 1000 \\\n    --optimize_config_path \"ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml\" \\\n    --is_sft False \\\n    --sft_data_path \"test_adapter/sft_translation.json\" \\\n    --save_adapter_path \"test_adapter/demo_adapter_KT_target_kv\" \\\n    --use_adapter True \\\n    --use_adapter_path \"/mnt/data/lpl/test_adapter/KT_newLoader_singleGPU_deepseekV2_Neko_AFS/checkpoint-566\" \\\n    --is_test_data False \\\n    --test_data_path \"test_adapter/demo_adapter_origin_target_kv\" \\\n    --output_dir \"test_adapter/demo_adapter_origin_target_kv\" \\"
  },
  {
    "path": "kt-sft/ktransformers/lora_test_module.py",
    "content": "import os\nimport platform\nimport sys\n\nproject_dir = os.path.dirname(os.path.dirname(__file__))\nsys.path.insert(0, project_dir)\n\nfrom torchviz import make_dot\nfrom torch import nn\nimport torch\nfrom transformers import (\n    AutoTokenizer,\n    AutoConfig,\n    AutoModelForCausalLM,\n    GenerationConfig,\n    TextStreamer,\n)\n\nfrom ktransformers.operators.linear import KLinearTorch, KTransformersLinear\nfrom ktransformers.sft.peft_utils.lora_layer import KTransformersLinearLora\nfrom ktransformers.util.custom_loader import GGUFLoader\nfrom ktransformers.util.inference_state import InferenceState\n\nimport hiddenlayer as hl\n\ngguf_loader = GGUFLoader(gguf_path=\"/home/yj/ktransformers/GGUF-DeepSeek-V2-Lite-Chat\")\nconfig = AutoConfig.from_pretrained(\"/home/yj/ktransformers/DeepSeek-V2-Lite-Chat\", trust_remote_code=True)\ntorch.set_default_dtype(config.torch_dtype)\n\nclass TestModelLora(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n        random_linear_layer = nn.Linear(in_features=3072, out_features=2048, bias=False)\n        \n        orig_linear = KTransformersLinear(\n            key='blk.0.attn_q',\n            gguf_loader=gguf_loader,\n            config=config,\n            orig_module=random_linear_layer,\n            generate_op=\"KLinearTorch\"\n        )\n        self.layer = KTransformersLinearLora(\n            orig_module=orig_linear,\n            adapter_name=\"lora_test\",\n            r=8,\n            lora_alpha=16\n        )\n        self.layer.generate_linear.weight = torch.randn(3072, 2048).to(\"cuda\")\n        \n    def forward(self, x):\n        return self.layer(x)\n    \nclass TestModelBase(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.layer = KTransformersLinear(\n            key=\"linear\",\n            gguf_loader=gguf_loader, \n            config=config, \n            orig_module=nn.Linear(in_features=3072, out_features=2048, bias=False),\n           
 generate_op=\"KLinearTorch\"\n        )\n        # self.layer.generate_linear.weight = torch.randn(3072, 2048).to(\"cuda\")\n        weight = torch.randn(3072, 2048, device=\"cuda\")\n        self.layer.load(w=nn.Parameter(weight), mode = InferenceState.GENERATE)\n        # self.layer.generate_linear.weight = nn.Parameter(torch.randn(3072, 2048).to(\"cuda\"))\n        self.fc1 = nn.Linear(3072, 2048, bias=False)\n        self.relu = nn.ReLU()\n        self.fc2 = nn.Linear(2048, 3072, bias=False)\n        # self.layer.load(mode=InferenceState.GENERATE)\n\n    def forward(self, x):\n        x = self.layer(x)\n        # x = self.fc1(x)\n        x = self.relu(x)\n        x = self.fc2(x)\n        return x\n\nclass TestModelTorch(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.layer = KLinearTorch(\n            key=\"linear\",\n            gguf_loader=gguf_loader, \n            config=config, \n            orig_module=nn.Linear(in_features=3072, out_features=2048, bias=False)\n        )\n        # self.layer.weight = nn.Parameter(torch.randn(3072, 2048).to(\"cuda\"))\n        # self.layer.weight = torch.randn(3072, 2048).to(\"cuda\")\n        weight = torch.randn(3072, 2048, device=\"cuda\")\n        self.layer.load(w=nn.Parameter(weight), device=\"cuda\")\n        self.fc1 = nn.Linear(3072, 2048, bias=False)\n        self.relu = nn.ReLU()\n        self.fc2 = nn.Linear(2048, 3072, bias=False)\n        # self.layer.load(mode=InferenceState.GENERATE) \n\n    def forward(self, x):\n        x = self.layer(x)\n        # x = self.fc1(x)\n        x = self.relu(x)\n        x = self.fc2(x)\n        return x\n\n\n# # KLinearTorch Well DONE for test!\n# model = TestModelTorch()\n# x = torch.randn(2048, 3072, requires_grad=True)\n# out = model(x)\n# make_dot(out, params=dict(model.named_parameters())).render(\"KTLinear_graph\", format=\"svg\")\n\n\n# model = TestModelBase()\n# x = torch.randn(2048, 3072, requires_grad=True)\n# out = model(x)\n# 
make_dot(out, params=dict(model.named_parameters())).render(\"base_graph\", format=\"svg\")\n\n# MyConvNet_graph=hl.build_graph(model,torch.zeros(size=[2048, 3072]))\n# MyConvNet_graph.theme=hl.graph.THEMES['blue'].copy()\n# MyConvNet_graph.save(path='./base_graph.png',format='png')\n\n# model = TestModelLora()\n# x = torch.randn(2048, 3072, requires_grad=True)\n# out = model(x)\n# make_dot(out, params=dict(model.named_parameters())).render(\"lora_graph\", format=\"svg\")\n\n\nfrom peft import LoraConfig, get_peft_model\n\nclass BaseModel(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.linear = nn.Linear(3072, 2048, bias=False)\n    \n    def forward(self, x):\n        return self.linear(x)\n\nmodel = BaseModel().to(\"cuda\")\n\nlora_config = LoraConfig(\n    r=8,\n    lora_alpha=16,\n    target_modules=[\"linear\"],\n    lora_dropout=0.0,\n    bias=\"none\",\n)\n\npeft_model = get_peft_model(model, lora_config)\nprint(peft_model)\n\nx = torch.randn(2048, 3072, requires_grad=True).to(\"cuda\")\n\nout = peft_model(x)\n\ndot = make_dot(out, \n             params=dict(peft_model.named_parameters()))\n\ndot.render(\"origin_lora_graph\", format=\"svg\")"
  },
  {
    "path": "kt-sft/ktransformers/models/__init__.py",
    "content": ""
  },
  {
    "path": "kt-sft/ktransformers/models/configuration_deepseek.py",
    "content": "# Adapted from\n# https://huggingface.co/deepseek-ai/DeepSeek-V2-Chat-0628/blob/main/configuration_deepseek.py\n# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.\n# Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import logging\n\nlogger = logging.get_logger(__name__)\n\nDEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}\nclass DeepseekV2Config(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`DeepseekV2Model`]. It is used to instantiate an DeepSeek\n    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the\n    defaults will yield a similar configuration to that of the DeepSeek-V2.\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the\n    documentation from [`PretrainedConfig`] for more information.\n    Args:\n        vocab_size (`int`, *optional*, defaults to 102400):\n            Vocabulary size of the Deep model. 
Defines the number of different tokens that can be represented by the\n            `input_ids` passed when calling [`DeepseekV2Model`]\n        hidden_size (`int`, *optional*, defaults to 4096):\n            Dimension of the hidden representations.\n        intermediate_size (`int`, *optional*, defaults to 11008):\n            Dimension of the MLP representations.\n        moe_intermediate_size (`int`, *optional*, defaults to 1407):\n            Dimension of the MoE representations.\n        num_hidden_layers (`int`, *optional*, defaults to 30):\n            Number of hidden layers in the Transformer decoder.\n        num_attention_heads (`int`, *optional*, defaults to 32):\n            Number of attention heads for each attention layer in the Transformer decoder.\n        n_shared_experts (`int`, *optional*, defaults to None):\n            Number of shared experts, None means dense model.\n        n_routed_experts (`int`, *optional*, defaults to None):\n            Number of routed experts, None means dense model.\n        routed_scaling_factor (`float`, *optional*, defaults to 1.0):\n            Scaling factor for routed experts.\n        topk_method (`str`, *optional*, defaults to `gready`):\n            Topk method used in routed gate.\n        n_group (`int`, *optional*, defaults to None):\n            Number of groups for routed experts.\n        topk_group (`int`, *optional*, defaults to None):\n            Number of selected groups for each token (the experts selected for a token are restricted to `topk_group` groups).\n        num_experts_per_tok (`int`, *optional*, defaults to None):\n            Number of selected experts, None means dense model.\n        moe_layer_freq (`int`, *optional*, defaults to 1):\n            The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.\n        first_k_dense_replace (`int`, *optional*, defaults to 0):\n            Number of dense layers in shallow 
layers(embed->dense->dense->...->dense->moe->moe...->lm_head).\n                                                            \\--k dense layers--/\n        norm_topk_prob (`bool`, *optional*, defaults to False):\n            Whether to normalize the weights of the routed experts.\n        scoring_func (`str`, *optional*, defaults to 'softmax'):\n            Method of computing expert weights.\n        aux_loss_alpha (`float`, *optional*, defaults to 0.001):\n            Auxiliary loss weight coefficient.\n        seq_aux (`bool`, *optional*, defaults to True):\n            Whether to compute the auxiliary loss for each individual sample.\n        num_key_value_heads (`int`, *optional*):\n            This is the number of key_value heads that should be used to implement Grouped Query Attention. If\n            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if\n            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When\n            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed\n            by meanpooling all the original heads within that group. For more details, check out [this\n            paper](https://arxiv.org/pdf/2305.13245.pdf). 
If it is not specified, will default to\n            `num_attention_heads`.\n        hidden_act (`str` or `function`, *optional*, defaults to `\"silu\"`):\n            The non-linear activation function (function or string) in the decoder.\n        max_position_embeddings (`int`, *optional*, defaults to 2048):\n            The maximum sequence length that this model might ever be used with.\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        rms_norm_eps (`float`, *optional*, defaults to 1e-06):\n            The epsilon used by the rms normalization layers.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values attentions (not used by all models). Only\n            relevant if `config.is_decoder=True`.\n        pad_token_id (`int`, *optional*):\n            Padding token id.\n        bos_token_id (`int`, *optional*, defaults to 100000):\n            Beginning of stream token id.\n        eos_token_id (`int`, *optional*, defaults to 100001):\n            End of stream token id.\n        pretraining_tp (`int`, *optional*, defaults to 1):\n            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this\n            document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is\n            necessary to ensure exact reproducibility of the pretraining results. 
Please refer to [this\n            issue](https://github.com/pytorch/pytorch/issues/76232).\n        tie_word_embeddings (`bool`, *optional*, defaults to `False`):\n            Whether to tie weight embeddings\n        rope_theta (`float`, *optional*, defaults to 10000.0):\n            The base period of the RoPE embeddings.\n        rope_scaling (`Dict`, *optional*):\n            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling\n            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is\n            `{\"type\": strategy name, \"factor\": scaling factor}`. When using this flag, don't update\n            `max_position_embeddings` to the expected new maximum.\n        attention_bias (`bool`, *optional*, defaults to `False`):\n            Whether to use a bias in the query, key, value and output projection layers during self-attention.\n        attention_dropout (`float`, *optional*, defaults to 0.0):\n            The dropout ratio for the attention probabilities.\n    ```python\n    >>> from transformers import DeepseekV2Model, DeepseekV2Config\n    >>> # Initializing a Deepseek-V2 style configuration\n    >>> configuration = DeepseekV2Config()\n    >>> # Initializing a model from the configuration\n    >>> model = DeepseekV2Model(configuration)\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\"\"\"\n\n    model_type = \"deepseek_v2\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        vocab_size=102400,\n        hidden_size=4096,\n        intermediate_size=11008,\n        moe_intermediate_size = 1407,\n        num_hidden_layers=30,\n        num_attention_heads=32,\n        num_key_value_heads=32,\n        n_shared_experts = None,\n        n_routed_experts = None,\n        ep_size = 1,\n        routed_scaling_factor = 1.0,\n        kv_lora_rank = 512,\n        q_lora_rank = 1536,\n        qk_rope_head_dim = 64,\n        v_head_dim = 128,\n    
    qk_nope_head_dim = 128,\n        topk_method = 'gready',\n        n_group = None,\n        topk_group = None,\n        num_experts_per_tok = None,\n        moe_layer_freq = 1,\n        first_k_dense_replace = 0,\n        norm_topk_prob = False,\n        scoring_func = 'softmax',\n        aux_loss_alpha = 0.001,\n        seq_aux = True,\n        hidden_act=\"silu\",\n        max_position_embeddings=2048,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        pad_token_id=None,\n        bos_token_id=100000,\n        eos_token_id=100001,\n        pretraining_tp=1,\n        tie_word_embeddings=False,\n        rope_theta=10000.0,\n        rope_scaling=None,\n        attention_bias=False,\n        attention_dropout=0.0,\n        cpu_quant=None,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.moe_intermediate_size = moe_intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.n_shared_experts = n_shared_experts\n        self.n_routed_experts = n_routed_experts\n        self.ep_size = ep_size\n        self.routed_scaling_factor = routed_scaling_factor\n        self.kv_lora_rank = kv_lora_rank\n        self.q_lora_rank = q_lora_rank\n        self.qk_rope_head_dim = qk_rope_head_dim\n        self.v_head_dim = v_head_dim\n        self.qk_nope_head_dim = qk_nope_head_dim\n        self.topk_method = topk_method\n        self.n_group = n_group\n        self.topk_group = topk_group\n        self.num_experts_per_tok = num_experts_per_tok\n        self.moe_layer_freq = moe_layer_freq\n        self.first_k_dense_replace = first_k_dense_replace\n        self.norm_topk_prob = norm_topk_prob\n        self.scoring_func = scoring_func\n        self.aux_loss_alpha = 
aux_loss_alpha\n        self.seq_aux = seq_aux\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.pretraining_tp = pretraining_tp\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n\n        self.cpu_quant = cpu_quant\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n"
  },
  {
    "path": "kt-sft/ktransformers/models/configuration_deepseek_v3.py",
    "content": "from transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import logging\n\nlogger = logging.get_logger(__name__)\n\nDEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}\nclass DeepseekV3Config(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek\n    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the\n    defaults will yield a similar configuration to that of the DeepSeek-V3.\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the\n    documentation from [`PretrainedConfig`] for more information.\n    Args:\n        vocab_size (`int`, *optional*, defaults to 129280):\n            Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the\n            `inputs_ids` passed when calling [`DeepseekV3Model`]\n        hidden_size (`int`, *optional*, defaults to 4096):\n            Dimension of the hidden representations.\n        intermediate_size (`int`, *optional*, defaults to 11008):\n            Dimension of the MLP representations.\n        moe_intermediate_size (`int`, *optional*, defaults to 1407):\n            Dimension of the MoE representations.\n        num_hidden_layers (`int`, *optional*, defaults to 32):\n            Number of hidden layers in the Transformer decoder.\n        num_nextn_predict_layers (`int`, *optional*, defaults to 1):\n            Number of nextn predict layers in the DeepSeekV3 Model.\n        num_attention_heads (`int`, *optional*, defaults to 32):\n            Number of attention heads for each attention layer in the Transformer decoder.\n        n_shared_experts (`int`, *optional*, defaults to None):\n            Number of shared experts, None means dense model.\n        n_routed_experts (`int`, *optional*, defaults to 
256):\n            Number of routed experts, None means dense model.\n        routed_scaling_factor (`float`, *optional*, defaults to 2.5):\n            Scaling factor for routed experts.\n        topk_method (`str`, *optional*, defaults to `noaux_tc`):\n            Topk method used in routed gate.\n        n_group (`int`, *optional*, defaults to 8):\n            Number of groups for routed experts.\n        topk_group (`int`, *optional*, defaults to 4):\n            Number of selected groups for each token (for each token, the selected experts are restricted to `topk_group` groups).\n        num_experts_per_tok (`int`, *optional*, defaults to 8):\n            Number of selected experts, None means dense model.\n        moe_layer_freq (`int`, *optional*, defaults to 1):\n            The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.\n        first_k_dense_replace (`int`, *optional*, defaults to 3):\n            Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).\n                                                            \\--k dense layers--/\n        norm_topk_prob (`bool`, *optional*, defaults to True):\n            Whether to normalize the weights of the routed experts.\n        scoring_func (`str`, *optional*, defaults to 'sigmoid'):\n            Method of computing expert weights.\n        aux_loss_alpha (`float`, *optional*, defaults to 0.001):\n            Auxiliary loss weight coefficient.\n        seq_aux (`bool`, *optional*, defaults to True):\n            Whether to compute the auxiliary loss for each individual sample.\n        num_key_value_heads (`int`, *optional*):\n            This is the number of key_value heads that should be used to implement Grouped Query Attention. 
If\n            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if\n            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When\n            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed\n            by meanpooling all the original heads within that group. For more details check out [this\n            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to\n            `num_attention_heads`.\n        hidden_act (`str` or `function`, *optional*, defaults to `\"silu\"`):\n            The non-linear activation function (function or string) in the decoder.\n        max_position_embeddings (`int`, *optional*, defaults to 4096):\n            The maximum sequence length that this model might ever be used with.\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        rms_norm_eps (`float`, *optional*, defaults to 1e-06):\n            The epsilon used by the rms normalization layers.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values attentions (not used by all models). 
Only\n            relevant if `config.is_decoder=True`.\n        pad_token_id (`int`, *optional*):\n            Padding token id.\n        bos_token_id (`int`, *optional*, defaults to 0):\n            Beginning of stream token id.\n        eos_token_id (`int`, *optional*, defaults to 1):\n            End of stream token id.\n        tie_word_embeddings (`bool`, *optional*, defaults to `False`):\n            Whether to tie the input and output word embeddings.\n        rope_theta (`float`, *optional*, defaults to 10000.0):\n            The base period of the RoPE embeddings.\n        rope_scaling (`Dict`, *optional*):\n            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling\n            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is\n            `{\"type\": strategy name, \"factor\": scaling factor}`. When using this flag, don't update\n            `max_position_embeddings` to the expected new maximum.\n        attention_bias (`bool`, *optional*, defaults to `False`):\n            Whether to use a bias in the query, key, value and output projection layers during self-attention.\n        attention_dropout (`float`, *optional*, defaults to 0.0):\n            The dropout ratio for the attention probabilities.\n    ```python\n    >>> from transformers import DeepseekV3Model, DeepseekV3Config\n    >>> # Initializing a Deepseek-V3 style configuration\n    >>> configuration = DeepseekV3Config()\n    >>> # Initializing a model from the configuration\n    >>> model = DeepseekV3Model(configuration)\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\"\"\"\n\n    model_type = \"deepseek_v3\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        vocab_size=129280,\n        hidden_size=7168,\n        intermediate_size=18432,\n        moe_intermediate_size = 2048,\n        num_hidden_layers=61,\n        num_nextn_predict_layers=1,\n        num_attention_heads=128,\n        
num_key_value_heads=128,\n        n_shared_experts = 1,\n        n_routed_experts = 256,\n        ep_size = 1,\n        routed_scaling_factor = 2.5,\n        kv_lora_rank = 512,\n        q_lora_rank = 1536,\n        qk_rope_head_dim = 64,\n        v_head_dim = 128,\n        qk_nope_head_dim = 128,\n        topk_method = 'noaux_tc',\n        n_group = 8,\n        topk_group = 4,\n        num_experts_per_tok = 8,\n        moe_layer_freq = 1,\n        first_k_dense_replace = 3,\n        norm_topk_prob = True,\n        scoring_func = 'sigmoid',\n        hidden_act=\"silu\",\n        max_position_embeddings=4096,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        pad_token_id=None,\n        bos_token_id=0,\n        eos_token_id=1,\n        tie_word_embeddings=False,\n        rope_theta=10000.0,\n        rope_scaling=None,\n        attention_bias=False,\n        attention_dropout=0.0,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.moe_intermediate_size = moe_intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_nextn_predict_layers = num_nextn_predict_layers\n        self.num_attention_heads = num_attention_heads\n        self.n_shared_experts = n_shared_experts\n        self.n_routed_experts = n_routed_experts\n        self.ep_size = ep_size\n        self.routed_scaling_factor = routed_scaling_factor\n        self.kv_lora_rank = kv_lora_rank\n        self.q_lora_rank = q_lora_rank\n        self.qk_rope_head_dim = qk_rope_head_dim\n        self.v_head_dim = v_head_dim\n        self.qk_nope_head_dim = qk_nope_head_dim\n        self.topk_method = topk_method\n        self.n_group = n_group\n        self.topk_group = topk_group\n        self.num_experts_per_tok = num_experts_per_tok\n        self.moe_layer_freq = 
moe_layer_freq\n        self.first_k_dense_replace = first_k_dense_replace\n        self.norm_topk_prob = norm_topk_prob\n        self.scoring_func = scoring_func\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )"
  },
  {
    "path": "kt-sft/ktransformers/models/configuration_llama.py",
    "content": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"LLaMA model configuration\"\"\"\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.modeling_rope_utils import rope_config_validation\n\n\nclass LlamaConfig(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA\n    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the\n    defaults will yield a similar configuration to that of the LLaMA-7B.\n\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the\n    documentation from [`PretrainedConfig`] for more information.\n\n\n    Args:\n        vocab_size (`int`, *optional*, defaults to 32000):\n            Vocabulary size of the LLaMA model. 
Defines the number of different tokens that can be represented by the\n            `input_ids` passed when calling [`LlamaModel`].\n        hidden_size (`int`, *optional*, defaults to 4096):\n            Dimension of the hidden representations.\n        intermediate_size (`int`, *optional*, defaults to 11008):\n            Dimension of the MLP representations.\n        num_hidden_layers (`int`, *optional*, defaults to 32):\n            Number of hidden layers in the Transformer decoder.\n        num_attention_heads (`int`, *optional*, defaults to 32):\n            Number of attention heads for each attention layer in the Transformer decoder.\n        num_key_value_heads (`int`, *optional*):\n            This is the number of key_value heads that should be used to implement Grouped Query Attention. If\n            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if\n            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When\n            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed\n            by meanpooling all the original heads within that group. For more details check out [this\n            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to\n            `num_attention_heads`.\n        hidden_act (`str` or `function`, *optional*, defaults to `\"silu\"`):\n            The non-linear activation function (function or string) in the decoder.\n        max_position_embeddings (`int`, *optional*, defaults to 2048):\n            The maximum sequence length that this model might ever be used with. 
Llama 1 supports up to 2048 tokens,\n            Llama 2 up to 4096, CodeLlama up to 16384.\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        rms_norm_eps (`float`, *optional*, defaults to 1e-06):\n            The epsilon used by the rms normalization layers.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values attentions (not used by all models). Only\n            relevant if `config.is_decoder=True`.\n        pad_token_id (`int`, *optional*):\n            Padding token id.\n        bos_token_id (`int`, *optional*, defaults to 1):\n            Beginning of stream token id.\n        eos_token_id (`int`, *optional*, defaults to 2):\n            End of stream token id.\n        pretraining_tp (`int`, *optional*, defaults to 1):\n            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this\n            document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to\n            understand more about it. This value is necessary to ensure exact reproducibility of the pretraining\n            results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).\n        tie_word_embeddings (`bool`, *optional*, defaults to `False`):\n            Whether to tie weight embeddings\n        rope_theta (`float`, *optional*, defaults to 10000.0):\n            The base period of the RoPE embeddings.\n        rope_scaling (`Dict`, *optional*):\n            Dictionary containing the scaling configuration for the RoPE embeddings. 
NOTE: if you apply new rope type\n            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value\n            accordingly.\n            Expected contents:\n                `rope_type` (`str`):\n                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',\n                    'llama3'], with 'default' being the original RoPE implementation.\n                `factor` (`float`, *optional*):\n                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In\n                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *\n                    original maximum pre-trained length.\n                `original_max_position_embeddings` (`int`, *optional*):\n                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during\n                    pretraining.\n                `attention_factor` (`float`, *optional*):\n                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention\n                    computation. If unspecified, it defaults to value recommended by the implementation, using the\n                    `factor` field to infer the suggested value.\n                `beta_fast` (`float`, *optional*):\n                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear\n                    ramp function. If unspecified, it defaults to 32.\n                `beta_slow` (`float`, *optional*):\n                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear\n                    ramp function. If unspecified, it defaults to 1.\n                `short_factor` (`List[float]`, *optional*):\n                    Only used with 'longrope'. 
The scaling factor to be applied to short contexts (<\n                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden\n                    size divided by the number of attention heads divided by 2.\n                `long_factor` (`List[float]`, *optional*):\n                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>\n                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden\n                    size divided by the number of attention heads divided by 2.\n                `low_freq_factor` (`float`, *optional*):\n                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.\n                `high_freq_factor` (`float`, *optional*):\n                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.\n        attention_bias (`bool`, *optional*, defaults to `False`):\n            Whether to use a bias in the query, key, value and output projection layers during self-attention.\n        attention_dropout (`float`, *optional*, defaults to 0.0):\n            The dropout ratio for the attention probabilities.\n        mlp_bias (`bool`, *optional*, defaults to `False`):\n            Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.\n\n    ```python\n    >>> from transformers import LlamaModel, LlamaConfig\n\n    >>> # Initializing a LLaMA llama-7b style configuration\n    >>> configuration = LlamaConfig()\n\n    >>> # Initializing a model from the llama-7b style configuration\n    >>> model = LlamaModel(configuration)\n\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\"\"\"\n\n    model_type = \"llama\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        vocab_size=32000,\n        hidden_size=4096,\n        
intermediate_size=11008,\n        num_hidden_layers=32,\n        num_attention_heads=32,\n        num_key_value_heads=None,\n        hidden_act=\"silu\",\n        max_position_embeddings=2048,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        pad_token_id=None,\n        bos_token_id=1,\n        eos_token_id=2,\n        pretraining_tp=1,\n        tie_word_embeddings=False,\n        rope_theta=10000.0,\n        rope_scaling=None,\n        attention_bias=False,\n        attention_dropout=0.0,\n        mlp_bias=False,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n\n        # for backward compatibility\n        if num_key_value_heads is None:\n            num_key_value_heads = num_attention_heads\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.pretraining_tp = pretraining_tp\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n        self.mlp_bias = mlp_bias\n\n        # Validate the correctness of rotary position embeddings parameters\n        # BC: if there is a 'type' field, move it to 'rope_type'.\n        if self.rope_scaling is not None and \"type\" in self.rope_scaling:\n            self.rope_scaling[\"rope_type\"] = self.rope_scaling[\"type\"]\n        rope_config_validation(self)\n\n        super().__init__(\n            pad_token_id=pad_token_id,\n            bos_token_id=bos_token_id,\n            
eos_token_id=eos_token_id,\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n"
  },
  {
    "path": "kt-sft/ktransformers/models/configuration_qwen2_moe.py",
    "content": "# coding=utf-8\n# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Qwen2MoE model configuration\"\"\"\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import logging\n\n\nlogger = logging.get_logger(__name__)\n\n\nclass Qwen2MoeConfig(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`Qwen2MoeModel`]. It is used to instantiate a\n    Qwen2MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration\n    with the defaults will yield a similar configuration to that of\n    Qwen1.5-MoE-A2.7B\" [Qwen/Qwen1.5-MoE-A2.7B\"](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B\").\n\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the\n    documentation from [`PretrainedConfig`] for more information.\n\n\n    Args:\n        vocab_size (`int`, *optional*, defaults to 151936):\n            Vocabulary size of the Qwen2MoE model. 
Defines the number of different tokens that can be represented by the\n            `input_ids` passed when calling [`Qwen2MoeModel`].\n        hidden_size (`int`, *optional*, defaults to 2048):\n            Dimension of the hidden representations.\n        intermediate_size (`int`, *optional*, defaults to 5632):\n            Dimension of the MLP representations.\n        num_hidden_layers (`int`, *optional*, defaults to 24):\n            Number of hidden layers in the Transformer decoder.\n        num_attention_heads (`int`, *optional*, defaults to 16):\n            Number of attention heads for each attention layer in the Transformer decoder.\n        num_key_value_heads (`int`, *optional*, defaults to 16):\n            This is the number of key_value heads that should be used to implement Grouped Query Attention. If\n            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if\n            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When\n            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed\n            by meanpooling all the original heads within that group. For more details check out [this\n            paper](https://arxiv.org/pdf/2305.13245.pdf). 
If it is not specified, will default to `16`.\n        hidden_act (`str` or `function`, *optional*, defaults to `\"silu\"`):\n            The non-linear activation function (function or string) in the decoder.\n        max_position_embeddings (`int`, *optional*, defaults to 32768):\n            The maximum sequence length that this model might ever be used with.\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        rms_norm_eps (`float`, *optional*, defaults to 1e-06):\n            The epsilon used by the rms normalization layers.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values attentions (not used by all models). Only\n            relevant if `config.is_decoder=True`.\n        tie_word_embeddings (`bool`, *optional*, defaults to `False`):\n            Whether the model's input and output word embeddings should be tied.\n        rope_theta (`float`, *optional*, defaults to 10000.0):\n            The base period of the RoPE embeddings.\n        use_sliding_window (`bool`, *optional*, defaults to `False`):\n            Whether to use sliding window attention.\n        sliding_window (`int`, *optional*, defaults to 4096):\n            Sliding window attention (SWA) window size. If not specified, will default to `4096`.\n        max_window_layers (`int`, *optional*, defaults to 28):\n            The number of layers that use SWA (Sliding Window Attention). 
The bottom layers use SWA while the top use full attention.\n        attention_dropout (`float`, *optional*, defaults to 0.0):\n            The dropout ratio for the attention probabilities.\n        decoder_sparse_step (`int`, *optional*, defaults to 1):\n            The frequency of the MoE layer.\n        moe_intermediate_size (`int`, *optional*, defaults to 1408):\n            Intermediate size of the routed expert.\n        shared_expert_intermediate_size (`int`, *optional*, defaults to 5632):\n            Intermediate size of the shared expert.\n        num_experts_per_tok (`int`, *optional*, defaults to 4):\n            Number of selected experts.\n        num_experts (`int`, *optional*, defaults to 60):\n            Number of routed experts.\n        norm_topk_prob (`bool`, *optional*, defaults to `False`):\n            Whether to normalize the topk probabilities.\n        output_router_logits (`bool`, *optional*, defaults to `False`):\n            Whether or not the router logits should be returned by the model. 
Enabling this will also\n            allow the model to output the auxiliary loss, including load balancing loss and router z-loss.\n        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):\n            The aux loss factor for the total loss.\n        mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):\n            Indicates which layers use Qwen2MoeMLP rather than Qwen2MoeSparseMoeBlock.\n            The list contains layer indices, from 0 to num_layers-1 if we have num_layers layers.\n            If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.\n\n    ```python\n    >>> from transformers import Qwen2MoeModel, Qwen2MoeConfig\n\n    >>> # Initializing a Qwen2MoE style configuration\n    >>> configuration = Qwen2MoeConfig()\n\n    >>> # Initializing a model from the Qwen1.5-MoE-A2.7B style configuration\n    >>> model = Qwen2MoeModel(configuration)\n\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\"\"\"\n\n    model_type = \"qwen2_moe\"\n    keys_to_ignore_at_inference = [\"past_key_values\"]\n\n    def __init__(\n        self,\n        vocab_size=151936,\n        hidden_size=2048,\n        intermediate_size=5632,\n        num_hidden_layers=24,\n        num_attention_heads=16,\n        num_key_value_heads=16,\n        hidden_act=\"silu\",\n        max_position_embeddings=32768,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        tie_word_embeddings=False,\n        rope_theta=10000.0,\n        use_sliding_window=False,\n        sliding_window=4096,\n        max_window_layers=28,\n        attention_dropout=0.0,\n        decoder_sparse_step=1,\n        moe_intermediate_size=1408,\n        shared_expert_intermediate_size=5632,\n        num_experts_per_tok=4,\n        num_experts=60,\n        norm_topk_prob=False,\n        output_router_logits=False,\n        router_aux_loss_coef=0.001,\n        mlp_only_layers=None,\n        
**kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.use_sliding_window = use_sliding_window\n        self.sliding_window = sliding_window if use_sliding_window else None\n        self.max_window_layers = max_window_layers\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.attention_dropout = attention_dropout\n\n        # MoE arguments\n        self.decoder_sparse_step = decoder_sparse_step\n        self.moe_intermediate_size = moe_intermediate_size\n        self.shared_expert_intermediate_size = shared_expert_intermediate_size\n        self.num_experts_per_tok = num_experts_per_tok\n        self.num_experts = num_experts\n        self.norm_topk_prob = norm_topk_prob\n        self.output_router_logits = output_router_logits\n        self.router_aux_loss_coef = router_aux_loss_coef\n        self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers\n\n        super().__init__(\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )"
  },
  {
    "path": "kt-sft/ktransformers/models/configuration_qwen3_moe.py",
    "content": "# coding=utf-8\n# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Qwen3MoE model configuration\"\"\"\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.modeling_rope_utils import rope_config_validation\nfrom transformers.utils import logging\n\n\nlogger = logging.get_logger(__name__)\n\n\nclass Qwen3MoeConfig(PretrainedConfig):\n    r\"\"\"\n    This is the configuration class to store the configuration of a [`Qwen3MoeModel`]. It is used to instantiate a\n    Qwen3MoE model according to the specified arguments, defining the model architecture. Instantiating a configuration\n    with the defaults will yield a similar configuration to that of [Qwen/Qwen3-MoE-15B-A2B](https://huggingface.co/Qwen/Qwen3-15B-A2B).\n    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the\n    documentation from [`PretrainedConfig`] for more information.\n    Args:\n        vocab_size (`int`, *optional*, defaults to 151936):\n            Vocabulary size of the Qwen3MoE model. 
Defines the number of different tokens that can be represented by the\n            `inputs_ids` passed when calling [`Qwen3MoeModel`]\n        hidden_size (`int`, *optional*, defaults to 2048):\n            Dimension of the hidden representations.\n        intermediate_size (`int`, *optional*, defaults to 6144):\n            Dimension of the MLP representations.\n        num_hidden_layers (`int`, *optional*, defaults to 24):\n            Number of hidden layers in the Transformer encoder.\n        num_attention_heads (`int`, *optional*, defaults to 32):\n            Number of attention heads for each attention layer in the Transformer encoder.\n        num_key_value_heads (`int`, *optional*, defaults to 4):\n            This is the number of key_value heads that should be used to implement Grouped Query Attention. If\n            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if\n            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When\n            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed\n            by meanpooling all the original heads within that group. For more details checkout [this\n            paper](https://arxiv.org/pdf/2305.13245.pdf). 
If it is not specified, will default to `4`.\n        hidden_act (`str` or `function`, *optional*, defaults to `\"silu\"`):\n            The non-linear activation function (function or string) in the decoder.\n        max_position_embeddings (`int`, *optional*, defaults to 32768):\n            The maximum sequence length that this model might ever be used with.\n        initializer_range (`float`, *optional*, defaults to 0.02):\n            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.\n        rms_norm_eps (`float`, *optional*, defaults to 1e-06):\n            The epsilon used by the rms normalization layers.\n        use_cache (`bool`, *optional*, defaults to `True`):\n            Whether or not the model should return the last key/values attentions (not used by all models). Only\n            relevant if `config.is_decoder=True`.\n        tie_word_embeddings (`bool`, *optional*, defaults to `False`):\n            Whether the model's input and output word embeddings should be tied.\n        rope_theta (`float`, *optional*, defaults to 10000.0):\n            The base period of the RoPE embeddings.\n        rope_scaling (`Dict`, *optional*):\n            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type\n            and you expect the model to work on longer `max_position_embeddings`, we recommend that you update this value\n            accordingly.\n            Expected contents:\n                `rope_type` (`str`):\n                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',\n                    'llama3'], with 'default' being the original RoPE implementation.\n                `factor` (`float`, *optional*):\n                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. 
In\n                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *\n                    original maximum pre-trained length.\n                `original_max_position_embeddings` (`int`, *optional*):\n                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during\n                    pretraining.\n                `attention_factor` (`float`, *optional*):\n                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention\n                    computation. If unspecified, it defaults to value recommended by the implementation, using the\n                    `factor` field to infer the suggested value.\n                `beta_fast` (`float`, *optional*):\n                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear\n                    ramp function. If unspecified, it defaults to 32.\n                `beta_slow` (`float`, *optional*):\n                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear\n                    ramp function. If unspecified, it defaults to 1.\n                `short_factor` (`List[float]`, *optional*):\n                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<\n                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden\n                    size divided by the number of attention heads divided by 2\n                `long_factor` (`List[float]`, *optional*):\n                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<\n                    `original_max_position_embeddings`). 
Must be a list of numbers with the same length as the hidden\n                    size divided by the number of attention heads divided by 2\n                `low_freq_factor` (`float`, *optional*):\n                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE\n                `high_freq_factor` (`float`, *optional*):\n                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE\n        attention_bias (`bool`, *optional*, defaults to `False`):\n            Whether to use a bias in the query, key, value and output projection layers during self-attention.\n        use_sliding_window (`bool`, *optional*, defaults to `False`):\n            Whether to use sliding window attention.\n        sliding_window (`int`, *optional*, defaults to 4096):\n            Sliding window attention (SWA) window size. If not specified, will default to `4096`.\n        max_window_layers (`int`, *optional*, defaults to 28):\n            The number of layers that use SWA (Sliding Window Attention). 
The bottom layers use SWA while the top use full attention.\n        attention_dropout (`float`, *optional*, defaults to 0.0):\n            The dropout ratio for the attention probabilities.\n        decoder_sparse_step (`int`, *optional*, defaults to 1):\n            The frequency of the MoE layer.\n        moe_intermediate_size (`int`, *optional*, defaults to 768):\n            Intermediate size of the routed expert.\n        num_experts_per_tok (`int`, *optional*, defaults to 8):\n            Number of selected experts.\n        num_experts (`int`, *optional*, defaults to 128):\n            Number of routed experts.\n        norm_topk_prob (`bool`, *optional*, defaults to `False`):\n            Whether to normalize the topk probabilities.\n        output_router_logits (`bool`, *optional*, defaults to `False`):\n            Whether or not the router logits should be returned by the model. Enabling this will also\n            allow the model to output the auxiliary loss, including load balancing loss and router z-loss.\n        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):\n            The aux loss factor for the total loss.\n        mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):\n            Indicates which layers use Qwen3MoeMLP rather than Qwen3MoeSparseMoeBlock.\n            The list contains layer indices, from 0 to num_layers-1 for a model with num_layers layers.\n            If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.\n    ```python\n    >>> from transformers import Qwen3MoeModel, Qwen3MoeConfig\n    >>> # Initializing a Qwen3MoE style configuration\n    >>> configuration = Qwen3MoeConfig()\n    >>> # Initializing a model from the Qwen3-15B-A2B style configuration\n    >>> model = Qwen3MoeModel(configuration)\n    >>> # Accessing the model configuration\n    >>> configuration = model.config\n    ```\"\"\"\n\n    model_type = \"qwen3_moe\"\n    keys_to_ignore_at_inference = 
[\"past_key_values\"]\n\n    # Default tensor parallel plan for base model `Qwen3Moe`\n    base_model_tp_plan = {\n        \"layers.*.self_attn.q_proj\": \"colwise\",\n        \"layers.*.self_attn.k_proj\": \"colwise\",\n        \"layers.*.self_attn.v_proj\": \"colwise\",\n        \"layers.*.self_attn.o_proj\": \"rowwise\",\n        \"layers.*.mlp.gate_proj\": \"colwise\",\n        \"layers.*.mlp.up_proj\": \"colwise\",\n        \"layers.*.mlp.down_proj\": \"rowwise\",\n    }\n    base_model_pp_plan = {\n        \"embed_tokens\": ([\"input_ids\"], [\"inputs_embeds\"]),\n        \"layers\": ([\"hidden_states\", \"attention_mask\"], [\"hidden_states\"]),\n        \"norm\": ([\"hidden_states\"], [\"hidden_states\"]),\n    }\n\n    def __init__(\n        self,\n        vocab_size=151936,\n        hidden_size=2048,\n        intermediate_size=6144,\n        num_hidden_layers=24,\n        num_attention_heads=32,\n        num_key_value_heads=4,\n        hidden_act=\"silu\",\n        max_position_embeddings=32768,\n        initializer_range=0.02,\n        rms_norm_eps=1e-6,\n        use_cache=True,\n        tie_word_embeddings=False,\n        rope_theta=10000.0,\n        rope_scaling=None,\n        attention_bias=False,\n        use_sliding_window=False,\n        sliding_window=4096,\n        max_window_layers=28,\n        attention_dropout=0.0,\n        decoder_sparse_step=1,\n        moe_intermediate_size=768,\n        num_experts_per_tok=8,\n        num_experts=128,\n        norm_topk_prob=False,\n        output_router_logits=False,\n        router_aux_loss_coef=0.001,\n        mlp_only_layers=None,\n        **kwargs,\n    ):\n        self.vocab_size = vocab_size\n        self.max_position_embeddings = max_position_embeddings\n        self.hidden_size = hidden_size\n        self.intermediate_size = intermediate_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.use_sliding_window = 
use_sliding_window\n        self.sliding_window = sliding_window if use_sliding_window else None\n        self.max_window_layers = max_window_layers\n\n        self.num_key_value_heads = num_key_value_heads\n        self.hidden_act = hidden_act\n        self.initializer_range = initializer_range\n        self.rms_norm_eps = rms_norm_eps\n        self.use_cache = use_cache\n        self.rope_theta = rope_theta\n        self.rope_scaling = rope_scaling\n        self.attention_bias = attention_bias\n        self.attention_dropout = attention_dropout\n        # Validate the correctness of rotary position embeddings parameters\n        # BC: if there is a 'type' field, move it to 'rope_type'.\n        if self.rope_scaling is not None and \"type\" in self.rope_scaling:\n            self.rope_scaling[\"rope_type\"] = self.rope_scaling[\"type\"]\n        rope_config_validation(self)\n\n        # MoE arguments\n        self.decoder_sparse_step = decoder_sparse_step\n        self.moe_intermediate_size = moe_intermediate_size\n        self.num_experts_per_tok = num_experts_per_tok\n        self.num_experts = num_experts\n        self.norm_topk_prob = norm_topk_prob\n        self.output_router_logits = output_router_logits\n        self.router_aux_loss_coef = router_aux_loss_coef\n        self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers\n\n        super().__init__(\n            tie_word_embeddings=tie_word_embeddings,\n            **kwargs,\n        )\n\n\n__all__ = [\"Qwen3MoeConfig\"]"
  },
  {
    "path": "kt-sft/ktransformers/models/custom_cache.py",
    "content": "'''\nDescription  :  \nAuthor       : Boxin Zhang\nVersion      : 0.1.0\n'''\n# Adapted from\n# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/cache_utils.py\n# Copyright 2018- The Hugging Face team. All rights reserved.\n# Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\nimport torch\nimport torch.nn as nn\nimport transformers\nfrom transformers import Cache, PretrainedConfig\nfrom typing import List, Optional, Dict, Any, Tuple\ntry:\n    from ktransformers.server.balance_serve.settings import sched_ext\nexcept:\n    print(\"no balance_serve\")\nclass StaticCache(transformers.StaticCache):\n    \"\"\"\n    Static Cache class to be used with `torch.compile(model)`.\n\n    Parameters:\n        config (`PretrainedConfig):\n            The configuration file defining the shape-related attributes required to initialize the static cache.\n        max_batch_size (`int`):\n            The maximum batch size with which the model will be used.\n        max_cache_len (`int`):\n            The maximum sequence length with which the model will be used.\n        device (`torch.device` or `dict`):\n            The device on which the cache should be initialized. 
Should be the same as the layer.\n            If a `dict`, it should map each `model.layers.<idx>.self_attn` module name to a dict whose\n            `generate_device` entry names the device for that layer.\n        dtype (*optional*, defaults to `torch.float32`):\n            The default `dtype` to use when initializing the layer.\n    \"\"\"\n\n    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device: torch.device | dict, dtype=None) -> None:\n        Cache.__init__(self)\n        self.max_batch_size = max_batch_size\n        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len\n        # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads\n        if config.architectures[0] == \"DeepseekV3ForCausalLM\":\n            self.head_dim = config.qk_rope_head_dim\n        else:\n            self.head_dim = (\n                config.head_dim if hasattr(config, \"head_dim\") else config.hidden_size // config.num_attention_heads\n            )\n\n        self.dtype = dtype if dtype is not None else torch.float32\n        self.num_key_value_heads = (\n            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads\n        )\n\n        self.key_cache: List[torch.Tensor] = []\n        self.value_cache: List[torch.Tensor] = []\n        cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)\n        if config.architectures[0] == \"DeepseekV2ForCausalLM\" or config.architectures[0] == \"DeepseekV3ForCausalLM\":\n            # TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically\n            self.page_size = 64\n            self.max_pages = (self.max_cache_len + self.page_size - 1) // self.page_size\n            latent_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)\n            self.kv_lora_rank = config.kv_lora_rank\n            self.qk_rope_head_dim = 
config.qk_rope_head_dim\n            # TODO: support real page table\n            self.page_table_map = dict()\n            self.page_table_list = []\n            for idx in range(config.num_hidden_layers):\n                if isinstance(device, dict):\n                    target_device = device[f\"model.layers.{idx}.self_attn\"][\"generate_device\"]\n                else:\n                    target_device = device\n                \n                if target_device not in self.page_table_map:\n                    page_table = torch.zeros((max_batch_size, self.max_pages), dtype=torch.int32, device=target_device)\n                    for seq_id in range(max_batch_size):\n                        page_table[seq_id, :] = torch.arange(seq_id * self.max_pages, seq_id * self.max_pages + self.max_pages, dtype=torch.int32, device=target_device)\n                    self.page_table_map[target_device] = page_table\n                    \n                self.page_table_list.append(self.page_table_map[target_device])\n                    \n            self.is_MLA = True\n            self.is_page = True\n        else:\n            key_shape = cache_shape\n            value_shape = cache_shape\n            self.is_MLA = False\n\n        self.past_tokens = []\n        self.num_hidden_layers = config.num_hidden_layers\n        for idx in range(self.num_hidden_layers):\n            # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph\n            # breaks when updating the cache.\n            if isinstance(device, dict):\n                target_device = device[f\"model.layers.{idx}.self_attn\"][\"generate_device\"]\n            else:\n                target_device = device\n            \n            if self.is_MLA:\n                new_layer_key_cache = torch.zeros(latent_shape, dtype=self.dtype, device=target_device)\n                new_layer_value_cache = None\n                
torch._dynamo.mark_static_address(new_layer_key_cache)\n            else:\n                new_layer_key_cache = torch.zeros(key_shape, dtype=self.dtype, device=target_device)\n                new_layer_value_cache = torch.zeros(value_shape, dtype=self.dtype, device=target_device)\n                torch._dynamo.mark_static_address(new_layer_key_cache)\n                torch._dynamo.mark_static_address(new_layer_value_cache)\n                \n            self.key_cache.append(new_layer_key_cache)\n            self.value_cache.append(new_layer_value_cache)\n            self.past_tokens.append(0)\n\n    def update(\n        self,\n        key_states: torch.Tensor,\n        value_states: torch.Tensor,\n        layer_idx: int,\n        cache_kwargs: Optional[Dict[str, Any]] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.\n        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.\n\n        Parameters:\n            key_states (`torch.Tensor`):\n                The new key states to cache.\n            value_states (`torch.Tensor`):\n                The new value states to cache.\n            layer_idx (`int`):\n                The index of the layer to cache the states for.\n            cache_kwargs (`Dict[str, Any]`, `optional`):\n                Additional arguments for the cache subclass. 
The `StaticCache` needs the `cache_position` input\n                to know where to write in the cache.\n\n        Return:\n            A tuple containing the updated key and value states. For MLA caches, the tuple holds the latent\n            cache and the page table of the layer instead.\n        \"\"\"\n        cache_position = cache_kwargs.get(\"cache_position\")\n        k_out = self.key_cache[layer_idx]\n        v_out = self.value_cache[layer_idx]\n        self.past_tokens[layer_idx] += cache_position.size(0)\n        #print(cache_position)\n        if self.is_MLA:\n            page_idx = cache_position // self.page_size\n            page_offset = cache_position % self.page_size\n            # key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)\n            k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states\n            k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states\n            return k_out, self.page_table_list[layer_idx]\n        else:\n            k_out[:, :, cache_position] = key_states\n            v_out[:, :, cache_position] = value_states\n            return k_out, v_out\n\n    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:\n        \"\"\"Returns the sequence length of the cached states that were seen by the model.\"\"\"\n        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's\n        # limit the check to the first batch member and head dimension.\n        # TODO: deprecate this function in favor of `cache_position`\n        return self.past_tokens[layer_idx]\n    \n    def change_seq_length(self, bias: Optional[int] = 0) -> None:\n        \"\"\"Shifts the recorded sequence length of every layer by `bias`.\"\"\"\n        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. 
To save on compute, let's\n        # limit the check to the first batch member and head dimension.\n        # TODO: deprecate this function in favor of `cache_position`\n        for layer_idx in range(self.num_hidden_layers):\n            self.past_tokens[layer_idx] += bias\n\n    def get_max_length(self) -> Optional[int]:\n        \"\"\"Returns the maximum sequence length of the cached states.\"\"\"\n        return self.max_cache_len\n\n    def reset(self):\n        \"\"\"Resets the cache values while preserving the objects\"\"\"\n        for layer_idx in range(len(self.key_cache)):\n            # In-place ops prevent breaking the static address\n            self.key_cache[layer_idx].zero_()\n            if self.value_cache[layer_idx] is not None:\n                self.value_cache[layer_idx].zero_()\n            self.past_tokens[layer_idx] = 0\n\n    def remove_suffix(self, start_pos):\n        for layer_idx in range(len(self.key_cache)):\n            # In-place ops prevent breaking the static address\n            if self.is_MLA:\n                k_cache = self.key_cache[layer_idx]\n                k_cache.view(-1, k_cache.shape[-1])[start_pos:].zero_()\n            else:\n                self.key_cache[layer_idx][..., start_pos:, :].zero_()\n                self.value_cache[layer_idx][..., start_pos:, :].zero_()\n            self.past_tokens[layer_idx] = start_pos\n    \n    def get_max_cache_shape(self) -> Tuple[int, int, int, int]:\n        \"\"\"Returns the maximum shape of the cache.\"\"\"\n        return self.max_cache_len\n\nclass KDeepSeekV3Cache(nn.Module):\n    def __init__(\n        self,\n        config: PretrainedConfig,\n        page_size: int = 256,\n        dtype=torch.bfloat16,\n        device=torch.device(\"cuda:0\"),\n        \n    ):\n        super().__init__()\n        self.config = config\n        self.dtype = dtype\n        self.device = device\n        self.kv_lora_rank = config.kv_lora_rank\n        self.page_size = page_size\n        
self.k_caches = []\n        self.v_caches = []\n        \n\n    def load(self, inference_context: \"sched_ext.InferenceContext\"):\n        \n        for i in range(self.config.num_hidden_layers):\n            self.k_caches.append(\n                inference_context.k_cache[0][i] \n            )\n        self.max_cache_len = self.k_caches[0].shape[0]*self.k_caches[0].shape[1]\n\n    def update(\n        self,\n        key_states: torch.Tensor,\n        value_states: torch.Tensor,\n        layer_idx: int,\n\n        page_idx: torch.Tensor,\n        page_offset: torch.Tensor,\n\n        cache_kwargs: Optional[Dict[str, Any]] = None,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"\n        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.\n        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.\n\n        Parameters:\n            key_states (`torch.Tensor`):\n                The new key states to cache.\n            value_states (`torch.Tensor`):\n                The new value states to cache.\n            layer_idx (`int`):\n                The index of the layer to cache the states for.\n            cache_kwargs (`Dict[str, Any]`, `optional`):\n                Additional arguments for the cache subclass. 
The `StaticCache` needs the `cache_position` input\n                to know how where to write in the cache.\n\n        Return:\n            A tuple containing the updated key and value states.\n        \"\"\"\n        k_out = self.k_caches[layer_idx]\n\n        k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states.reshape(-1, *key_states.shape[2:])\n        k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states.reshape(-1, *value_states.shape[2:])\n        return k_out\n\n        \n    def get_page_table(self, cache_position: torch.Tensor, q_indptr: torch.Tensor, kv_indptr: torch.Tensor, kv_indices: torch.Tensor, bsz_tensors: torch.tensor):\n        page_offset = cache_position % self.page_size  \n        page_idx_local = cache_position // self.page_size  \n        query_ids = torch.zeros_like(cache_position)\n        for i in range(len(q_indptr) - 1):\n            start_idx = q_indptr[i]\n            end_idx = q_indptr[i + 1]\n            query_ids[start_idx:end_idx] = i\n        page_idx = torch.zeros_like(page_idx_local)\n        for i in range(bsz_tensors[0]):\n            query_id = query_ids[i]\n            local_block = page_idx_local[i]\n            start_block = kv_indptr[query_id]\n            if local_block < kv_indptr[query_id + 1] - kv_indptr[query_id]:\n                page_idx[i] = kv_indices[start_block + local_block]\n        \n        return page_idx, page_offset\n    \nclass KGQACache(nn.Module):\n    def __init__(\n        self,\n        config: PretrainedConfig,\n        page_size: int = 256,\n        dtype=torch.bfloat16,\n        device=torch.device(\"cuda:0\"),\n        \n    ):\n        super().__init__()\n        self.config = config\n        self.dtype = dtype\n        self.device = device\n        self.page_size = page_size\n        self.k_caches = []\n        self.v_caches = []\n        \n\n    def load(self, inference_context: \"sched_ext.InferenceContext\"):\n        print(self.config.num_hidden_layers)\n       
 for i in range(self.config.num_hidden_layers):\n            self.k_caches.append(\n                inference_context.k_cache[0][i] \n            )\n            self.v_caches.append(\n                inference_context.v_cache[0][i]\n            )\n\n\n        self.max_cache_len = self.k_caches[0].shape[0]*self.k_caches[0].shape[1]\n\n\n        \n    def get_page_table(self, cache_position: torch.Tensor, q_indptr: torch.Tensor, kv_indptr: torch.Tensor, kv_indices: torch.Tensor, bsz_tensors: torch.tensor):\n        page_offset = cache_position % self.page_size  \n        page_idx_local = cache_position // self.page_size  \n        query_ids = torch.zeros_like(cache_position)\n        for i in range(len(q_indptr) - 1):\n            start_idx = q_indptr[i]\n            end_idx = q_indptr[i + 1]\n            query_ids[start_idx:end_idx] = i\n        page_idx = torch.zeros_like(page_idx_local)\n        for i in range(bsz_tensors[0]):\n            query_id = query_ids[i]\n            local_block = page_idx_local[i]\n            start_block = kv_indptr[query_id]\n            if local_block < kv_indptr[query_id + 1] - kv_indptr[query_id]:\n                page_idx[i] = kv_indices[start_block + local_block]\n        \n        return page_idx, page_offset\n\n    def get_k_cache(self, layer_idx):\n        return self.k_caches[layer_idx]\n\n    def get_v_cache(self, layer_idx):\n        return self.v_caches[layer_idx]"
  },
  {
    "path": "kt-sft/ktransformers/models/custom_modeling_deepseek_v2.py",
    "content": "import math\nfrom dataclasses import dataclass\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nimport math\nfrom typing import List, Optional, Tuple, Union\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput\nfrom ktransformers.models.custom_cache import KDeepSeekV3Cache\nfrom  ktransformers.models.modeling_deepseek import DeepseekV2Model,  DeepseekV2PreTrainedModel\nfrom ktransformers.models.configuration_deepseek import DeepseekV2Config\n\n\ntorch.set_grad_enabled(False)\ntorch.set_default_dtype(torch.bfloat16)\nimport flashinfer\n\nclass KDeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):\n\n    kv_cache: KDeepSeekV3Cache\n    use_cuda_graph = False\n    def __init__(\n        self,\n        config,\n        kv_cache,\n\n    ):\n        super().__init__(config)\n        self.model = DeepseekV2Model(config)\n        self.config = config\n        self.kv_cache = kv_cache\n\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        \n\n    def init_wrapper(self, use_cuda_graph, device, max_batch_size, max_pages):\n        self.use_cuda_graph = use_cuda_graph\n        self.workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)\n        self.qo_indptr_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device)\n        self.paged_kv_indptr_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device)\n        self.paged_kv_indices_buf = torch.empty((max_pages,), dtype=torch.int32, device=device)\n        self.paged_kv_len_buf = torch.empty((max_batch_size,), dtype=torch.int32, device=device)\n\n\t\t\n\n        self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(\n            self.workspace_buffer, use_cuda_graph=use_cuda_graph,\n            
qo_indptr=self.qo_indptr_buf,kv_indptr=self.paged_kv_indptr_buf,\n            kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf,\n            backend = \"fa2\",\n        )\n\n    def batch_embeddings(self, batch: ForwardBatchInput, device=\"cuda:0\"):\n        features = []\n        for i in range(batch.batch_size):\n            tokens = batch.minibatch.tokens.contiguous()\n            feature = (\n                self.model.embed_tokens(tokens.to(torch.device('cpu')))\n                .to(torch.bfloat16)\n                .to(device=device)\n            )\n            features.append(feature)\n\n        return features\n\n\n    def forward(\n        self,\n        batch: ForwardBatchInput | None = None,\n        features: List[torch.Tensor] | None = None,\n        bsz_tensors: torch.Tensor | None = None,\n        num_tokens_tensors: torch.Tensor | None = None,\n        page_idx: torch.Tensor | None = None,\n        page_offset: torch.Tensor | None = None,\n    ) -> ForwardBatchOutput:\n        current_stream = torch.cuda.current_stream()\n\n        forward_batch_output = ForwardBatchOutput()\n\n        \n        hidden_states = features[0]\n\n\n        with torch.cuda.stream(current_stream):\n            residual = torch.zeros_like(hidden_states)\n            for i, decode_layer in enumerate(self.model.layers):\n                if self.model.transfer_map is not None and i in self.model.transfer_map:\n                    prev_stream = torch.cuda.current_stream()\n                    cur_device = self.model.transfer_map[i]\n                    if cur_device not in self.model.stream_device_map:\n                        self.model.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)\n                    torch.cuda.set_device(cur_device)\n                    self.model.stream_device_map[cur_device].wait_stream(prev_stream)\n                    torch.cuda.set_stream(self.model.stream_device_map[cur_device])\n                    hidden_states 
= hidden_states.to(\n                        self.model.transfer_map[i], non_blocking=True\n                    )\n\n                    batch.minibatch.position_ids = (\n                        batch.minibatch.position_ids.to(self.model.transfer_map[i], non_blocking=True)\n                        if batch.minibatch.position_ids is not None\n                        else None\n                    )\n                hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)\n                hidden_states = decode_layer.self_attn(hidden_states, self.kv_cache, \n                                                       position_ids=batch.minibatch.position_ids, \n                                                       wrapper=self.wrapper, bsz_tensors=num_tokens_tensors, \n                                                       cache_position=batch.minibatch.positions, \n                                                       batch_indices=batch.minibatch.batch_indices,\n                                                       kv_indices=batch.minibatch.kv_indices,\n                                                       kv_indptr=batch.minibatch.kv_indptr,\n                                                       kv_last_page_len=batch.minibatch.kv_last_page_len,\n                                                       q_indptr=batch.minibatch.q_indptr,\n                                                       page_idx=page_idx,\n                                                       page_offset=page_offset\n                                                       )\n\n                hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual)\n                if i < 3:\n                    hidden_states = decode_layer.mlp(hidden_states, num_tokens_tensors)\n                else:\n                    hidden_states = decode_layer.mlp(hidden_states.unsqueeze(0), num_tokens_tensors)\n               
     hidden_states = hidden_states.squeeze(0)\n        forward_batch_output = ForwardBatchOutput()\n        assert  batch.batch_size == 1\n        with torch.cuda.stream(current_stream):\n\n            local_logit = self.lm_head(self.model.norm(hidden_states[batch.minibatch.logits_start], num_tokens_tensors, residual[batch.minibatch.logits_start])[0])\n            # local_logit = local_logit[batch.minibatch.logits_start]\n            forward_batch_output.logits.append(local_logit)\n\n        return forward_batch_output\n    \n\n               \n    def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors,\n        num_heads: int,\n        head_dim_ckv: int,\n        head_dim_kpe: int,\n        page_size: int,\n        causal: bool,\n        sm_scale: float,\n        q_data_type: torch.dtype,\n        kv_data_type: torch.dtype,):\n        minibatch = batch.minibatch\n        \n        self.wrapper.plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices, \n                          minibatch.kv_len, num_heads, head_dim_ckv, head_dim_kpe, page_size, causal, sm_scale, q_data_type, kv_data_type)\n        "
  },
  {
    "path": "kt-sft/ktransformers/models/custom_modeling_deepseek_v3.py",
    "content": "\"\"\"\nDate: 2024-11-06 10:05:11\nLastEditors: djw\nLastEditTime: 2024-11-13 07:50:51\n\"\"\"\n\nimport math\nfrom dataclasses import dataclass\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nimport math\nfrom typing import List, Optional, Tuple, Union\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput\nfrom ktransformers.models.custom_cache import KDeepSeekV3Cache\nfrom ktransformers.models.modeling_deepseek_v3 import DeepseekV3Model,  DeepseekV3PreTrainedModel\nfrom ktransformers.models.configuration_deepseek_v3 import DeepseekV3Config\n\n\ntorch.set_grad_enabled(False)\ntorch.set_default_dtype(torch.bfloat16)\nimport flashinfer\n\nclass KDeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):\n\n    cache: KDeepSeekV3Cache\n    use_cuda_graph = False\n    def __init__(\n        self,\n        config: DeepseekV3Config,\n        cache,\n    ):\n        super().__init__(config)\n        self.model = DeepseekV3Model(config)\n        self.config = config\n        self.cache = cache\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        \n    def init_wrapper(self, use_cuda_graph, device, max_batch_size, max_pages):\n        self.use_cuda_graph = use_cuda_graph\n        self.workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)\n        self.qo_indptr_buf = torch.empty((max_batch_size+2,), dtype=torch.int32, device=device)\n        self.paged_kv_indptr_buf = torch.empty((max_batch_size+2,), dtype=torch.int32, device=device)\n        self.paged_kv_indices_buf = torch.empty((max_pages,), dtype=torch.int32, device=device)\n        self.paged_kv_len_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device)\n        self.bsz_tensor_buf = torch.empty((1, ), dtype=torch.int32, 
device=device)\n\t\t\n\n        self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(\n            self.workspace_buffer, use_cuda_graph=use_cuda_graph,\n            qo_indptr=self.qo_indptr_buf,kv_indptr=self.paged_kv_indptr_buf,\n            kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf,\n            bsz_tensor=self.bsz_tensor_buf,\n            backend = \"fa2\",\n        )\n\n    def batch_embeddings(self, batch: ForwardBatchInput, device=\"cuda:0\"):\n        features = []\n        for i in range(batch.batch_size):\n            tokens = batch.minibatch.tokens.contiguous()\n            feature = (\n                self.model.embed_tokens(tokens.to(torch.device('cpu')))\n                .to(torch.bfloat16)\n                .to(device=device)\n            )\n            features.append(feature)\n\n        return features\n\n\n    def forward(\n        self,\n        batch: ForwardBatchInput | None = None,\n        features: List[torch.Tensor] | None = None,\n        bsz_tensors: torch.Tensor | None = None,\n        num_tokens_tensors: torch.Tensor | None = None,\n        page_idx: torch.Tensor | None = None,\n        page_offset: torch.Tensor | None = None,\n        cuda_graph_idx: int | None = -1\n    ) -> ForwardBatchOutput:\n        current_stream = torch.cuda.current_stream()\n\n        forward_batch_output = ForwardBatchOutput()\n\n        \n        hidden_states = features[0]\n\n        with torch.cuda.stream(current_stream):\n            residual = torch.zeros_like(hidden_states)\n            for i, decode_layer in enumerate(self.model.layers):\n                # can't use now, only one flashinfer wrapper\n                if self.model.transfer_map is not None and i in self.model.transfer_map:\n                    prev_stream = torch.cuda.current_stream()\n                    cur_device = self.model.transfer_map[i]\n                    if cur_device not in self.model.stream_device_map:\n                        
self.model.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)\n                    torch.cuda.set_device(cur_device)\n                    self.model.stream_device_map[cur_device].wait_stream(prev_stream)\n                    torch.cuda.set_stream(self.model.stream_device_map[cur_device])\n                    hidden_states = hidden_states.to(\n                        self.model.transfer_map[i], non_blocking=True\n                    )\n\n                    batch.minibatch.position_ids = (\n                        batch.minibatch.position_ids.to(self.model.transfer_map[i], non_blocking=True)\n                        if batch.minibatch.position_ids is not None\n                        else None\n                    )\n                hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)\n                hidden_states = decode_layer.self_attn(hidden_states, self.cache, \n                                                       position_ids=batch.minibatch.position_ids, \n                                                       wrapper=self.wrapper, num_tokens_tensors=num_tokens_tensors, \n                                                       page_idx=page_idx,\n                                                       page_offset=page_offset\n                                                       )\n\n                hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual)\n                if i < self.config.first_k_dense_replace:\n                    hidden_states = decode_layer.mlp(hidden_states, num_tokens_tensors)\n                else:\n                    hidden_states = decode_layer.mlp(hidden_states.unsqueeze(0), num_tokens_tensors, cuda_graph_idx)\n                    hidden_states = hidden_states.squeeze(0)\n        forward_batch_output = ForwardBatchOutput()\n        with torch.cuda.stream(current_stream):\n            local_logit = 
self.lm_head(self.model.norm(hidden_states, num_tokens_tensors, residual)[0], num_tokens_tensors)\n            forward_batch_output.logits.append(local_logit)\n\n        return forward_batch_output\n    \n\n               \n    def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors,\n        num_heads: int,\n        head_dim_ckv: int,\n        head_dim_kpe: int,\n        page_size: int,\n        causal: bool,\n        sm_scale: float,\n        q_data_type: torch.dtype,\n        kv_data_type: torch.dtype,):\n        minibatch = batch.minibatch\n        self.wrapper.plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices, \n                          minibatch.kv_len, num_heads, head_dim_ckv, head_dim_kpe, page_size, causal, sm_scale, q_data_type, kv_data_type, bsz_tensors)\n        "
  },
  {
    "path": "kt-sft/ktransformers/models/custom_modeling_qwen2_moe.py",
    "content": "\"\"\"\nDate: 2024-11-06 10:05:11\nLastEditors: djw\nLastEditTime: 2024-11-13 07:50:51\n\"\"\"\n\nimport math\nfrom dataclasses import dataclass\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nimport math\nfrom typing import List, Optional, Tuple, Union\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput\nfrom ktransformers.models.custom_cache import KGQACache\nfrom ktransformers.models.modeling_qwen2_moe import Qwen2MoeModel, Qwen2MoePreTrainedModel\nfrom ktransformers.models.configuration_qwen2_moe import Qwen2MoeConfig\nfrom ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn\n\ntorch.set_grad_enabled(False)\ntorch.set_default_dtype(torch.bfloat16)\nimport flashinfer\n\nclass KQwen2MoeForCausalLM(Qwen2MoePreTrainedModel):\n\n    cache: KGQACache\n    use_cuda_graph = False\n    def __init__(\n        self,\n        config: Qwen2MoeConfig,\n        cache,\n    ):\n        super().__init__(config)\n        self.model = Qwen2MoeModel(config)\n        self.config = config\n        self.cache = cache\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.attn = [None] * 100\n        \n    def init_wrapper(self, use_cuda_graph, device, max_batch_token, max_batch_size, max_pages, cuda_graph_idx = 0):\n        self.attn[cuda_graph_idx] = flashInferAttn(use_cuda_graph=use_cuda_graph, max_batch_token=max_batch_token, max_batch_size=max_batch_size, max_pages=max_pages, device=device)\n\n\n    def batch_embeddings(self, batch: ForwardBatchInput, device=\"cuda:0\"):\n        features = []\n        for i in range(batch.batch_size):\n            tokens = batch.minibatch.tokens.contiguous()\n            feature = (\n                self.model.embed_tokens(tokens.to(torch.device('cpu')))\n           
     .to(torch.bfloat16)\n                .to(device=device)\n            )\n            features.append(feature)\n\n        return features\n\n\n    def forward(\n        self,\n        batch: ForwardBatchInput | None = None,\n        features: List[torch.Tensor] | None = None,\n        bsz_tensors: torch.Tensor | None = None,\n        num_tokens_tensors: torch.Tensor | None = None,\n        page_idx: torch.Tensor | None = None,\n        page_offset: torch.Tensor | None = None,\n        cuda_graph_idx: int | None = 0\n    ) -> ForwardBatchOutput:\n        current_stream = torch.cuda.current_stream()\n\n        forward_batch_output = ForwardBatchOutput()\n\n        \n        hidden_states = features[0]\n        self.attn[cuda_graph_idx].calc_batch_indices(hidden_states.shape[0])\n\n        with torch.cuda.stream(current_stream):\n            residual = torch.zeros_like(hidden_states)\n            for i, decode_layer in enumerate(self.model.layers):\n                if self.model.transfer_map is not None and i in self.model.transfer_map:\n                    prev_stream = torch.cuda.current_stream()\n                    cur_device = self.model.transfer_map[i]\n                    if cur_device not in self.model.stream_device_map:\n                        self.model.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)\n                    torch.cuda.set_device(cur_device)\n                    self.model.stream_device_map[cur_device].wait_stream(prev_stream)\n                    torch.cuda.set_stream(self.model.stream_device_map[cur_device])\n                    hidden_states = hidden_states.to(\n                        self.model.transfer_map[i], non_blocking=True\n                    )\n\n                    batch.minibatch.position_ids = (\n                        batch.minibatch.position_ids.to(self.model.transfer_map[i], non_blocking=True)\n                        if batch.minibatch.position_ids is not None\n                        else None\n            
        )\n                hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)\n                hidden_states = decode_layer.self_attn(hidden_states, self.cache, \n                                                       position_ids=batch.minibatch.position_ids, \n                                                       wrapper=self.attn[cuda_graph_idx], bsz_tensors=num_tokens_tensors, \n                                                       page_idx=page_idx,\n                                                       page_offset=page_offset\n                                                       )\n\n                hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual)\n                hidden_states = decode_layer.mlp(hidden_states.unsqueeze(0), num_tokens_tensors, cuda_graph_idx)\n                hidden_states = hidden_states.squeeze(0)\n        forward_batch_output = ForwardBatchOutput()\n        with torch.cuda.stream(current_stream):\n            local_logit = self.lm_head(self.model.norm(hidden_states, num_tokens_tensors, residual)[0], num_tokens_tensors)\n            forward_batch_output.logits.append(local_logit)\n\n        return forward_batch_output\n    \n\n               \n    def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors,\n        num_q_heads: int,\n        num_kv_heads: int,\n        head_dim: int,\n        page_size: int,\n        causal: bool,\n        q_data_type: torch.dtype,\n        kv_data_type: torch.dtype,\n        cuda_graph_idx: int = 0\n        ):\n        minibatch = batch.minibatch\n        self.attn[cuda_graph_idx].plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices, \n                          minibatch.kv_last_page_len, bsz_tensors, num_tokens_tensors,num_q_heads, num_kv_heads, head_dim, page_size, causal=causal, q_data_type=q_data_type, kv_data_type=kv_data_type)\n        "
  },
  {
    "path": "kt-sft/ktransformers/models/custom_modeling_qwen3_moe.py",
    "content": "\"\"\"\nDate: 2024-11-06 10:05:11\nLastEditors: djw\nLastEditTime: 2024-11-13 07:50:51\n\"\"\"\n\nimport math\nfrom dataclasses import dataclass\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\nimport math\nfrom typing import List, Optional, Tuple, Union\nimport torch\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput\nfrom ktransformers.models.custom_cache import KGQACache\nfrom ktransformers.models.modeling_qwen3_moe import Qwen3MoeModel, Qwen3MoePreTrainedModel\nfrom ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig\nfrom ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn\n\ntorch.set_grad_enabled(False)\ntorch.set_default_dtype(torch.bfloat16)\nimport flashinfer\n\nclass KQwen3MoeForCausalLM(Qwen3MoePreTrainedModel):\n\n    cache: KGQACache\n    use_cuda_graph = False\n    def __init__(\n        self,\n        config: Qwen3MoeConfig,\n        cache = None,\n    ):\n        super().__init__(config)\n        self.model = Qwen3MoeModel(config)\n        self.config = config\n        self.cache = cache\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.attn = [None] * 100\n        \n    def init_wrapper(self, use_cuda_graph, device, max_batch_token, max_batch_size, max_pages, cuda_graph_idx = 0):\n        self.attn[cuda_graph_idx] = flashInferAttn(use_cuda_graph=use_cuda_graph, max_batch_token=max_batch_token, max_batch_size=max_batch_size, max_pages=max_pages, device=device)\n\n\n    def batch_embeddings(self, batch: ForwardBatchInput, device=\"cuda:0\"):\n        features = []\n        for i in range(batch.batch_size):\n            tokens = batch.minibatch.tokens.contiguous()\n            feature = (\n                self.model.embed_tokens(tokens.to(torch.device('cpu')))\n    
            .to(torch.bfloat16)\n                .to(device=device)\n            )\n            features.append(feature)\n\n        return features\n\n\n    def forward(\n        self,\n        batch: ForwardBatchInput | None = None,\n        features: List[torch.Tensor] | None = None,\n        bsz_tensors: torch.Tensor | None = None,\n        num_tokens_tensors: torch.Tensor | None = None,\n        page_idx: torch.Tensor | None = None,\n        page_offset: torch.Tensor | None = None,\n        cuda_graph_idx: int | None = 0\n    ) -> ForwardBatchOutput:\n        current_stream = torch.cuda.current_stream()\n\n        forward_batch_output = ForwardBatchOutput()\n\n        \n        hidden_states = features[0]\n        self.attn[cuda_graph_idx].calc_batch_indices(hidden_states.shape[0])\n\n        with torch.cuda.stream(current_stream):\n            residual = torch.zeros_like(hidden_states)\n            for i, decode_layer in enumerate(self.model.layers):\n                if self.model.transfer_map is not None and i in self.model.transfer_map:\n                    prev_stream = torch.cuda.current_stream()\n                    cur_device = self.model.transfer_map[i]\n                    if cur_device not in self.model.stream_device_map:\n                        self.model.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)\n                    torch.cuda.set_device(cur_device)\n                    self.model.stream_device_map[cur_device].wait_stream(prev_stream)\n                    torch.cuda.set_stream(self.model.stream_device_map[cur_device])\n                    hidden_states = hidden_states.to(\n                        self.model.transfer_map[i], non_blocking=True\n                    )\n\n                    batch.minibatch.position_ids = (\n                        batch.minibatch.position_ids.to(self.model.transfer_map[i], non_blocking=True)\n                        if batch.minibatch.position_ids is not None\n                        else None\n     
               )\n                hidden_states, residual = decode_layer.input_layernorm(hidden_states, num_tokens_tensors, residual)\n                hidden_states = decode_layer.self_attn(hidden_states, self.cache, \n                                                       position_ids=batch.minibatch.position_ids, \n                                                       wrapper=self.attn[cuda_graph_idx], bsz_tensors=num_tokens_tensors, \n                                                       page_idx=page_idx,\n                                                       page_offset=page_offset\n                                                       )\n\n                hidden_states, residual = decode_layer.post_attention_layernorm(hidden_states, num_tokens_tensors, residual)\n                hidden_states = decode_layer.mlp(hidden_states.unsqueeze(0), num_tokens_tensors, cuda_graph_idx)\n                hidden_states = hidden_states.squeeze(0)\n        forward_batch_output = ForwardBatchOutput()\n        with torch.cuda.stream(current_stream):\n            local_logit = self.lm_head(self.model.norm(hidden_states, num_tokens_tensors, residual)[0], num_tokens_tensors)\n            forward_batch_output.logits.append(local_logit)\n\n        return forward_batch_output\n    \n\n               \n    def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_tokens_tensors,\n        num_q_heads: int,\n        num_kv_heads: int,\n        head_dim: int,\n        page_size: int,\n        causal: bool,\n        q_data_type: torch.dtype,\n        kv_data_type: torch.dtype,\n        cuda_graph_idx: int = 0\n        ):\n        minibatch = batch.minibatch\n        self.attn[cuda_graph_idx].plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices, \n                          minibatch.kv_last_page_len, bsz_tensors, num_tokens_tensors, num_q_heads, num_kv_heads, head_dim, page_size, causal=causal, q_data_type=q_data_type, kv_data_type=kv_data_type)\n        "
  },
  {
    "path": "kt-sft/ktransformers/models/modeling_deepseek.py",
    "content": "# coding=utf-8\n'''\nDescription  :  \nAuthor       : Boxin Zhang\nVersion      : 0.1.0\n'''\n# Adapted from\n# https://huggingface.co/deepseek-ai/DeepSeek-V2-Chat-0628/blob/main/modeling_deepseek.py\n# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.\n# Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n# \n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\" PyTorch DeepSeek model.\"\"\"\nimport math\nimport warnings\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n\nfrom transformers.activations import ACT2FN\nfrom transformers.cache_utils import Cache, DynamicCache, StaticCache\nfrom transformers.modeling_attn_mask_utils import (\n    AttentionMaskConverter,\n    _prepare_4d_attention_mask,\n    _prepare_4d_causal_attention_mask,\n)\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPast,\n    CausalLMOutputWithPast,\n    SequenceClassifierOutputWithPast,\n)\nfrom transformers.modeling_utils import PreTrainedModel\nfrom 
transformers.pytorch_utils import (\n    ALL_LAYERNORM_LAYERS,\n    is_torch_greater_or_equal_than_1_13,\n)\nfrom transformers.utils import (\n    add_start_docstrings,\n    add_start_docstrings_to_model_forward,\n    is_flash_attn_2_available,\n    is_flash_attn_greater_or_equal_2_10,\n    logging,\n    replace_return_docstrings,\n)\nfrom transformers.utils.import_utils import is_torch_fx_available\nfrom .configuration_deepseek import DeepseekV2Config\nimport torch.distributed as dist\nimport numpy as np\n\nif is_flash_attn_2_available():\n    from flash_attn import flash_attn_func, flash_attn_varlen_func, flash_attn_with_kvcache\n    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa\n\nfrom ktransformers.util.grad_wrapper import maybe_no_grad\n\n# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.\n# It means that the function will not be traced through and simply appear as a node in the graph.\nif is_torch_fx_available():\n    if not is_torch_greater_or_equal_than_1_13:\n        import torch.fx\n\n    _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)\n\n\nlogger = logging.get_logger(__name__)\n\n_CONFIG_FOR_DOC = \"DeepseekV2Config\"\n\n\ndef _get_unpad_data(attention_mask):\n    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)\n    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()\n    max_seqlen_in_batch = seqlens_in_batch.max().item()\n    cu_seqlens = F.pad(\n        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)\n    )\n    return (\n        indices,\n        cu_seqlens,\n        max_seqlen_in_batch,\n    )\n\n\nclass DeepseekV2RMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        DeepseekV2RMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        
self.variance_epsilon = eps\n        self.hidden_size = hidden_size\n\n    def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return (self.weight * hidden_states).to(input_dtype)\n\n\nALL_LAYERNORM_LAYERS.append(DeepseekV2RMSNorm)\n\n# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->DeepseekV2\nclass DeepseekV2RotaryEmbedding(nn.Module):\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):\n        super().__init__()\n        self.scaling_factor = scaling_factor\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        # For BC we register cos and sin cached\n        self.max_seq_len_cached = max_position_embeddings\n\n    @maybe_no_grad()\n    def forward(self, x, position_ids):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)\n        position_ids_expanded = position_ids[:, None, :].float()\n        # Force float32 since bfloat16 loses precision on long contexts\n        # See https://github.com/huggingface/transformers/pull/29285\n        device_type = x.device.type\n        device_type = device_type if isinstance(device_type, str) and device_type != \"mps\" else \"cpu\"\n        with torch.autocast(device_type=device_type, enabled=False):\n            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n            emb = 
torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos()\n            sin = emb.sin()\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)   \n\n# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV2\nclass DeepseekV2LinearScalingRotaryEmbedding(DeepseekV2RotaryEmbedding):\n    \"\"\"DeepseekV2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev\"\"\"\n\n    def __init__(\n        self,\n        dim,\n        max_position_embeddings=2048,\n        base=10000,\n        device=None,\n        scaling_factor=1.0,\n    ):\n        raise NotImplementedError(\"LinearScalingRotaryEmbedding is not supported now.\")\n        self.scaling_factor = scaling_factor\n        super().__init__(dim, max_position_embeddings, base, device)\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n        t = torch.arange(\n            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype\n        )\n        t = t / self.scaling_factor\n\n        freqs = torch.outer(t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n\n# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV2\nclass DeepseekV2DynamicNTKScalingRotaryEmbedding(DeepseekV2RotaryEmbedding):\n    \"\"\"DeepseekV2RotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla\"\"\"\n\n    def __init__(\n        self,\n        dim,\n        max_position_embeddings=2048,\n        base=10000,\n        device=None,\n        scaling_factor=1.0,\n    ):\n        raise NotImplementedError(\"DynamicNTKScalingRotaryEmbedding is not supported now.\")\n        self.scaling_factor = scaling_factor\n        super().__init__(dim, max_position_embeddings, base, device)\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n\n        if seq_len > self.max_position_embeddings:\n            base = self.base * (\n                (self.scaling_factor * seq_len / self.max_position_embeddings)\n                - (self.scaling_factor - 1)\n            ) ** (self.dim / (self.dim - 2))\n            inv_freq = 1.0 / (\n                base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)\n            )\n            self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        t = torch.arange(\n            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype\n        )\n\n        freqs = torch.outer(t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n\n# Inverse dim formula to find dim based on number of rotations\ndef yarn_find_correction_dim(\n    num_rotations, dim, base=10000, max_position_embeddings=2048\n):\n    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (\n        2 * math.log(base)\n    )\n\n\n# Find dim range bounds based on rotations\ndef yarn_find_correction_range(\n    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048\n):\n    low = math.floor(\n        
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)\n    )\n    high = math.ceil(\n        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)\n    )\n    return max(low, 0), min(high, dim - 1)  # Clamp values just in case\n\n\ndef yarn_get_mscale(scale=1, mscale=1):\n    if scale <= 1:\n        return 1.0\n    return 0.1 * mscale * math.log(scale) + 1.0\n\n\ndef yarn_linear_ramp_mask(min, max, dim):\n    if min == max:\n        max += 0.001  # Prevent singularity\n\n    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)\n    ramp_func = torch.clamp(linear_func, 0, 1)\n    return ramp_func\n\nclass DeepseekV2YarnRotaryEmbedding(DeepseekV2RotaryEmbedding):\n    def __init__(\n        self,\n        dim,\n        max_position_embeddings=2048,\n        base=10000,\n        device=None,\n        scaling_factor=1.0,\n        original_max_position_embeddings=4096,\n        beta_fast=32,\n        beta_slow=1,\n        mscale=1,\n        mscale_all_dim=0,\n    ):\n        nn.Module.__init__(self)\n        self.scaling_factor = scaling_factor\n        self.original_max_position_embeddings = original_max_position_embeddings\n        self.beta_fast = beta_fast\n        self.beta_slow = beta_slow\n        self.mscale = mscale\n        self.mscale_all_dim = mscale_all_dim\n        self.scaling_factor = scaling_factor\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n\n        freq_extra = 1.0 / (\n            self.base\n            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)\n        )\n        freq_inter = 1.0 / (\n            self.scaling_factor\n            * self.base\n            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)\n        )\n\n        low, high = yarn_find_correction_range(\n            self.beta_fast,\n            self.beta_slow,\n            dim,\n            self.base,\n            
self.original_max_position_embeddings,\n        )\n        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(\n            device=device, dtype=torch.float32\n        )\n        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        self._mscale = float(\n            yarn_get_mscale(self.scaling_factor, self.mscale)\n            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)\n        )\n        # For BC we register cos and sin cached\n        self.max_seq_len_cached = max_position_embeddings\n\n    @maybe_no_grad()\n    def forward(self, x, position_ids):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)\n        position_ids_expanded = position_ids[:, None, :].float()\n        # Force float32 since bfloat16 loses precision on long contexts\n        # See https://github.com/huggingface/transformers/pull/29285\n        device_type = x.device.type\n        device_type = device_type if isinstance(device_type, str) and device_type != \"mps\" else \"cpu\"\n        with torch.autocast(device_type=device_type, enabled=False):\n            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos()* self._mscale\n            sin = emb.sin()* self._mscale\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)  \n\n# Copied from transformers.models.llama.modeling_llama.rotate_half\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\n# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, 
unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n    Args:\n        q (`torch.Tensor`): The query tensor.\n        k (`torch.Tensor`): The key tensor.\n        cos (`torch.Tensor`): The cosine part of the rotary embedding.\n        sin (`torch.Tensor`): The sine part of the rotary embedding.\n        position_ids (`torch.Tensor`):\n            Deprecated and unused.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. 
Similarly, if q and k have\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos.unsqueeze(unsqueeze_dim)\n    sin = sin.unsqueeze(unsqueeze_dim)\n    b, h, s, d = q.shape\n    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)\n    b, h, s, d = k.shape\n    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\nclass DeepseekV2MLP(nn.Module):\n    def __init__(self, config, hidden_size=None, intermediate_size=None):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size\n        self.intermediate_size = (\n            config.intermediate_size if intermediate_size is None else intermediate_size\n        )\n\n        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        act = self.act_fn(self.gate_proj(x)) * self.up_proj(x)\n        down_proj = self.down_proj(act)\n        return down_proj\n\nclass MoEGate(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.top_k = config.num_experts_per_tok\n        self.n_routed_experts = config.n_routed_experts\n        self.routed_scaling_factor = config.routed_scaling_factor\n        self.scoring_func = config.scoring_func\n        self.alpha = config.aux_loss_alpha\n        self.seq_aux = config.seq_aux\n        self.topk_method = 
config.topk_method\n        self.n_group = config.n_group\n        self.topk_group = config.topk_group\n\n        # topk selection algorithm\n        self.norm_topk_prob = config.norm_topk_prob\n        self.gating_dim = config.hidden_size\n        self.weight = nn.Parameter(\n            torch.empty((self.n_routed_experts, self.gating_dim))\n        )\n        self.reset_parameters()\n\n    def reset_parameters(self) -> None:\n        import torch.nn.init as init\n\n        init.kaiming_uniform_(self.weight, a=math.sqrt(5))\n\n    def forward(self, hidden_states):\n        bsz, seq_len, h = hidden_states.shape\n        ### compute gating score\n        hidden_states = hidden_states.view(-1, h)\n        logits = F.linear(\n            hidden_states.type(torch.float32), self.weight.type(torch.float32), None\n        )\n        if self.scoring_func == \"softmax\":\n            scores = logits.softmax(dim=-1, dtype=torch.float32)\n        else:\n            raise NotImplementedError(\n                f\"unsupported scoring function for MoE gating: {self.scoring_func}\"\n            )\n\n        ### select top-k experts\n        if self.topk_method == \"greedy\":\n            topk_weight, topk_idx = torch.topk(\n                scores, k=self.top_k, dim=-1, sorted=False\n            )\n        elif self.topk_method == \"group_limited_greedy\":\n            group_scores = (\n                scores.view(bsz * seq_len, self.n_group, -1).max(dim=-1).values\n            )  # [n, n_group]\n            group_idx = torch.topk(\n                group_scores, k=self.topk_group, dim=-1, sorted=False\n            )[\n                1\n            ]  # [n, top_k_group]\n            group_mask = torch.zeros_like(group_scores)  # [n, n_group]\n            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]\n            score_mask = (\n                group_mask.unsqueeze(-1)\n                .expand(\n                    bsz * seq_len, self.n_group, self.n_routed_experts // 
self.n_group\n                )\n                .reshape(bsz * seq_len, -1)\n            )  # [n, e]\n            tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]\n            topk_weight, topk_idx = torch.topk(\n                tmp_scores, k=self.top_k, dim=-1, sorted=False\n            )\n\n        ### norm gate to sum 1\n        if self.top_k > 1 and self.norm_topk_prob:\n            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20\n            topk_weight = topk_weight / denominator\n        else:\n            topk_weight = topk_weight * self.routed_scaling_factor\n        ### expert-level computation auxiliary loss\n        if self.training and self.alpha > 0.0:\n            scores_for_aux = scores\n            aux_topk = self.top_k\n            # always compute aux loss based on the naive greedy topk method\n            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)\n            if self.seq_aux:\n                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)\n                ce = torch.zeros(\n                    bsz, self.n_routed_experts, device=hidden_states.device\n                )\n                ce.scatter_add_(\n                    1,\n                    topk_idx_for_aux_loss,\n                    torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device),\n                ).div_(seq_len * aux_topk / self.n_routed_experts)\n                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(\n                    dim=1\n                ).mean() * self.alpha\n            else:\n                mask_ce = F.one_hot(\n                    topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts\n                )\n                ce = mask_ce.float().mean(0)\n                Pi = scores_for_aux.mean(0)\n                fi = ce * self.n_routed_experts\n                aux_loss = (Pi * fi).sum() * self.alpha\n        else:\n            aux_loss = None\n        return topk_idx, topk_weight, 
aux_loss\n\n\nclass AddAuxiliaryLoss(torch.autograd.Function):\n    \"\"\"\n    Identity in the forward pass that attaches the auxiliary (aux) loss to the graph,\n    so that the gradient of the aux loss is included during backpropagation.\n    \"\"\"\n\n    @staticmethod\n    def forward(ctx, x, loss):\n        assert loss.numel() == 1\n        ctx.dtype = loss.dtype\n        ctx.required_aux_loss = loss.requires_grad\n        return x\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        grad_loss = None\n        if ctx.required_aux_loss:\n            grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)\n        return grad_output, grad_loss\n\nclass DeepseekV2MoE(nn.Module):\n    \"\"\"\n    A mixture-of-experts module with routed experts and optional shared experts.\n    \"\"\"\n\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.num_experts_per_tok = config.num_experts_per_tok\n\n        if hasattr(config, \"ep_size\") and config.ep_size > 1:\n            assert config.ep_size == dist.get_world_size()\n            self.ep_size = config.ep_size\n            self.experts_per_rank = config.n_routed_experts // config.ep_size\n            self.ep_rank = dist.get_rank()\n            self.experts = nn.ModuleList(\n                [\n                    (\n                        DeepseekV2MLP(\n                            config, intermediate_size=config.moe_intermediate_size\n                        )\n                        if i >= self.ep_rank * self.experts_per_rank\n                        and i < (self.ep_rank + 1) * self.experts_per_rank\n                        else None\n                    )\n                    for i in range(config.n_routed_experts)\n                ]\n            )\n        else:\n            self.ep_size = 1\n            self.experts_per_rank = config.n_routed_experts\n            self.ep_rank = 0\n            self.experts = nn.ModuleList(\n                [\n                    DeepseekV2MLP(config, 
intermediate_size=config.moe_intermediate_size)\n                    for i in range(config.n_routed_experts)\n                ]\n            )\n        self.gate = MoEGate(config)\n        if config.n_shared_experts is not None:\n            intermediate_size = config.moe_intermediate_size * config.n_shared_experts\n            self.shared_experts = DeepseekV2MLP(\n                config=config, intermediate_size=intermediate_size\n            )\n\n    def forward(self, hidden_states):\n        identity = hidden_states\n        orig_shape = hidden_states.shape\n        topk_idx, topk_weight, aux_loss = self.gate(hidden_states)\n        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])\n        flat_topk_idx = topk_idx.view(-1)\n        if self.training:\n            hidden_states = hidden_states.repeat_interleave(\n                self.num_experts_per_tok, dim=0\n            )\n            y = torch.empty_like(hidden_states)\n            for i, expert in enumerate(self.experts):\n                y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])\n            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)\n            y = y.view(*orig_shape)\n            y = AddAuxiliaryLoss.apply(y, aux_loss)\n        else:\n            y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)\n        if self.config.n_shared_experts is not None:\n            y = y + self.shared_experts(identity)\n        return y\n\n    @maybe_no_grad()\n    def moe_infer(self, x, topk_ids, topk_weight):\n        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))\n        cnts.scatter_(1, topk_ids, 1)\n        tokens_per_expert = cnts.sum(dim=0)\n        idxs = topk_ids.view(-1).argsort()\n        sorted_tokens = x[idxs // topk_ids.shape[1]]\n        sorted_tokens_shape = sorted_tokens.shape\n        if self.ep_size > 1:\n            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)\n   
         tokens_per_expert_group = tokens_per_expert.new_empty(\n                tokens_per_expert.shape[0]\n            )\n            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)\n            output_splits = (\n                tokens_per_expert_group.view(self.ep_size, -1)\n                .sum(1)\n                .cpu()\n                .numpy()\n                .tolist()\n            )\n            gathered_tokens = sorted_tokens.new_empty(\n                tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1]\n            )\n            input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()\n            dist.all_to_all(\n                list(gathered_tokens.split(output_splits)),\n                list(sorted_tokens.split(input_split_sizes)),\n            )\n            tokens_per_expert_post_gather = tokens_per_expert_group.view(\n                self.ep_size, self.experts_per_rank\n            ).sum(dim=0)\n            gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)\n            s = 0\n            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):\n                gatherd_idxs[s : s + k] = i % self.experts_per_rank\n                s += k\n            gatherd_idxs = gatherd_idxs.argsort()\n            sorted_tokens = gathered_tokens[gatherd_idxs]\n            tokens_per_expert = tokens_per_expert_post_gather\n        tokens_per_expert = tokens_per_expert.cpu().numpy()\n\n        outputs = []\n        start_idx = 0\n        for i, num_tokens in enumerate(tokens_per_expert):\n            end_idx = start_idx + num_tokens\n            if num_tokens == 0:\n                continue\n            expert = self.experts[i + self.ep_rank * self.experts_per_rank]\n            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n            expert_out = expert(tokens_for_this_expert)\n            outputs.append(expert_out)\n            start_idx = end_idx\n\n        outs = 
torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n        if self.ep_size > 1:\n            new_x = torch.empty_like(outs)\n            new_x[gatherd_idxs] = outs\n            gathered_tokens = new_x.new_empty(*sorted_tokens_shape)\n            dist.all_to_all(\n                list(gathered_tokens.split(input_split_sizes)),\n                list(new_x.split(output_splits)),\n            )\n            outs = gathered_tokens\n\n        new_x = torch.empty_like(outs)\n        new_x[idxs] = outs\n        final_out = (\n            new_x.view(*topk_ids.shape, -1)\n            .type(topk_weight.dtype)\n            .mul_(topk_weight.unsqueeze(dim=-1))\n            .sum(dim=1)\n            .type(new_x.dtype)\n        )\n        return final_out\n\n# Copied from transformers.models.llama.modeling_llama.repeat_kv\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(\n        batch, num_key_value_heads, n_rep, slen, head_dim\n    )\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2\nclass DeepseekV2Attention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None):\n        super().__init__()\n        self.config = config\n        self.layer_idx = layer_idx\n        if layer_idx is None:\n            logger.warning_once(\n                f\"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will \"\n                \"lead to errors during the forward call, if caching is used. 
Please make sure to provide a `layer_idx` \"\n                \"when creating this class.\"\n            )\n\n        self.attention_dropout = config.attention_dropout\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n\n        self.max_position_embeddings = config.max_position_embeddings\n        self.rope_theta = config.rope_theta\n        self.q_lora_rank = config.q_lora_rank\n        self.qk_rope_head_dim = config.qk_rope_head_dim\n        self.kv_lora_rank = config.kv_lora_rank\n        self.v_head_dim = config.v_head_dim\n        self.qk_nope_head_dim = config.qk_nope_head_dim\n        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim\n\n        self.is_causal = True\n\n        if self.q_lora_rank is None:\n            self.q_proj = nn.Linear(\n                self.hidden_size, self.num_heads * self.q_head_dim, bias=False\n            )\n        else:\n            self.q_a_proj = nn.Linear(\n                self.hidden_size, config.q_lora_rank, bias=config.attention_bias\n            )\n            self.q_a_layernorm = DeepseekV2RMSNorm(config.q_lora_rank)\n            self.q_b_proj = nn.Linear(\n                config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False\n            )\n\n        self.kv_a_proj_with_mqa = nn.Linear(\n            self.hidden_size,\n            config.kv_lora_rank + config.qk_rope_head_dim,\n            bias=config.attention_bias,\n        )\n        self.kv_a_layernorm = DeepseekV2RMSNorm(config.kv_lora_rank)\n        self.kv_b_proj = nn.Linear(\n            config.kv_lora_rank,\n            self.num_heads\n            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),\n            bias=False,\n        )\n\n        self.o_proj = nn.Linear(\n            self.num_heads * self.v_head_dim,\n            self.hidden_size,\n            bias=config.attention_bias,\n        )\n        self._init_rope()\n\n        self.softmax_scale = self.q_head_dim 
** (-0.5)\n        if self.config.rope_scaling is not None:\n            mscale_all_dim = self.config.rope_scaling.get(\"mscale_all_dim\", 0)\n            scaling_factor = self.config.rope_scaling[\"factor\"]\n            if mscale_all_dim:\n                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)\n                self.softmax_scale = self.softmax_scale * mscale * mscale\n\n    def _init_rope(self):\n        if self.config.rope_scaling is None:\n            self.rotary_emb = DeepseekV2RotaryEmbedding(\n                self.qk_rope_head_dim,\n                max_position_embeddings=self.max_position_embeddings,\n                base=self.rope_theta,\n            )\n        else:\n            scaling_type = self.config.rope_scaling[\"type\"]\n            scaling_factor = self.config.rope_scaling[\"factor\"]\n            if scaling_type == \"linear\":\n                self.rotary_emb = DeepseekV2LinearScalingRotaryEmbedding(\n                    self.qk_rope_head_dim,\n                    max_position_embeddings=self.max_position_embeddings,\n                    scaling_factor=scaling_factor,\n                    base=self.rope_theta,\n                )\n            elif scaling_type == \"dynamic\":\n                self.rotary_emb = DeepseekV2DynamicNTKScalingRotaryEmbedding(\n                    self.qk_rope_head_dim,\n                    max_position_embeddings=self.max_position_embeddings,\n                    scaling_factor=scaling_factor,\n                    base=self.rope_theta,\n                )\n            elif scaling_type == \"yarn\":\n                kwargs = {\n                    key: self.config.rope_scaling[key]\n                    for key in [\n                        \"original_max_position_embeddings\",\n                        \"beta_fast\",\n                        \"beta_slow\",\n                        \"mscale\",\n                        \"mscale_all_dim\",\n                    ]\n                    if key in 
self.config.rope_scaling\n                }\n                self.rotary_emb = DeepseekV2YarnRotaryEmbedding(\n                    self.qk_rope_head_dim,\n                    max_position_embeddings=self.max_position_embeddings,\n                    scaling_factor=scaling_factor,\n                    base=self.rope_theta,\n                    **kwargs,\n                )\n            else:\n                raise ValueError(f\"Unknown RoPE scaling type {scaling_type}\")\n\n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return (\n            tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim)\n            .transpose(1, 2)\n            .contiguous()\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        if \"padding_mask\" in kwargs:\n            warnings.warn(\n                \"Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure to use `attention_mask` instead.\"\n            )\n        bsz, q_len, _ = hidden_states.size()\n\n        if self.q_lora_rank is None:\n            q = self.q_proj(hidden_states)\n        else:\n            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))\n        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)\n        q_nope, q_pe = torch.split(\n            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1\n        )\n\n        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)\n        compressed_kv, k_pe = torch.split(\n            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n        )\n        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)\n        kv = (\n            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))\n            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)\n            .transpose(1, 2)\n        )\n\n        k_nope, value_states = torch.split(\n            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1\n        )\n        kv_seq_len = value_states.shape[-2]\n        if past_key_value is not None:\n            if self.layer_idx is None:\n                raise ValueError(\n                    f\"The cache structure has changed since version v4.36. 
If you are using {self.__class__.__name__} \"\n                    \"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class \"\n                    \"with a layer index.\"\n                )\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\n        cos, sin = self.rotary_emb(q_pe, position_ids)\n        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)\n\n        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)\n        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope\n        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe\n\n        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)\n        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope\n        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe\n        if past_key_value is not None:\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n            key_states, value_states = past_key_value.update(\n                key_states, value_states, self.layer_idx, cache_kwargs\n            )\n\n        attn_weights = (\n            torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale\n        )\n\n        if attention_mask is not None:\n            attn_weights = attn_weights + attention_mask\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(\n            attn_weights, dim=-1, dtype=torch.float32\n        ).to(query_states.dtype)\n        attn_weights = nn.functional.dropout(\n            attn_weights, p=self.attention_dropout, training=self.training\n        )\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but 
is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n\n        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)\n\n        attn_output = self.o_proj(attn_output)\n\n        if not output_attentions:\n            attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV2\nclass DeepseekV2FlashAttention2(DeepseekV2Attention):\n    \"\"\"\n    DeepseekV2 flash attention module. This module inherits from `DeepseekV2Attention` as the weights of the module stay\n    untouched. The only required change is in the forward pass, which needs to correctly call the public API of\n    flash attention and deal with padding tokens in case the input contains any of them.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n        # TODO: Should be removed once Flash Attention for ROCm is bumped to 2.1.\n        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made default for flash_attn>=2.1. This attribute is used to handle this difference. 
Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.\n        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).\n        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.LongTensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        # DeepseekV2FlashAttention2 attention does not support output_attentions\n        if \"padding_mask\" in kwargs:\n            warnings.warn(\n                \"Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure to use `attention_mask` instead.\"\n            )\n\n            # overwrite attention_mask with padding_mask\n            attention_mask = kwargs.pop(\"padding_mask\")\n\n        output_attentions = False\n\n        bsz, q_len, _ = hidden_states.size()\n\n        # Mirror DeepseekV2Attention.forward: q_a_proj/q_b_proj only exist when q_lora_rank is set\n        if self.q_lora_rank is None:\n            q = self.q_proj(hidden_states)\n        else:\n            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))\n        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)\n        q_nope, q_pe = torch.split(\n            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1\n        )\n\n        # Flash attention requires the input to have the shape\n        # batch_size x seq_length x num_heads x head_dim;\n        # the tensors are transposed to that layout before the kernel call\n        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)\n        compressed_kv, k_pe = torch.split(\n            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n        )\n        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)\n        kv = (\n            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))\n            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)\n            .transpose(1, 2)\n        )\n\n        k_nope, value_states = torch.split(\n            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1\n        )\n\n        kv_seq_len = value_states.shape[-2]\n        if past_key_value is not None:\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\n\n        cos, sin = self.rotary_emb(q_pe, position_ids)\n        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)\n\n        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)\n        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope\n        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe\n\n        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)\n        
key_states[:, :, :, : self.qk_nope_head_dim] = k_nope\n        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe\n\n        if self.q_head_dim != self.v_head_dim:\n            value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])\n\n        if past_key_value is not None:\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n            key_states, value_states = past_key_value.update(\n                key_states, value_states, self.layer_idx, cache_kwargs\n            )\n\n        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache\n        # to be able to avoid many of these transpose/reshape/view.\n        query_states = query_states.transpose(1, 2)\n        key_states = key_states.transpose(1, 2)\n        value_states = value_states.transpose(1, 2)\n\n        dropout_rate = self.attention_dropout if self.training else 0.0\n\n        # In PEFT, usually we cast the layer norms in float32 for training stability reasons\n        # therefore the input hidden states gets silently casted in float32. Hence, we need\n        # cast them back in the correct dtype just to be sure everything works as expected.\n        # This might slowdown training & inference so it is recommended to not cast the LayerNorms\n        # in fp32. 
(DeepseekV2RMSNorm handles it correctly)\n\n        input_dtype = query_states.dtype\n        if input_dtype == torch.float32:\n            # Handle the case where the model is quantized\n            if hasattr(self.config, \"_pre_quantization_dtype\"):\n                target_dtype = self.config._pre_quantization_dtype\n            elif torch.is_autocast_enabled():\n                target_dtype = torch.get_autocast_gpu_dtype()\n            else:\n                target_dtype = self.q_a_proj.weight.dtype\n\n            logger.warning_once(\n                f\"The input hidden states seems to be silently casted in float32, this might be related to\"\n                f\" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in\"\n                f\" {target_dtype}.\"\n            )\n\n            query_states = query_states.to(target_dtype)\n            key_states = key_states.to(target_dtype)\n            value_states = value_states.to(target_dtype)\n\n        attn_output = self._flash_attention_forward(\n            query_states,\n            key_states,\n            value_states,\n            attention_mask,\n            q_len,\n            position_ids=position_ids,\n            dropout=dropout_rate,\n            softmax_scale=self.softmax_scale,\n        )\n        if self.q_head_dim != self.v_head_dim:\n            attn_output = attn_output[:, :, :, : self.v_head_dim]\n\n        attn_output = attn_output.reshape(\n            bsz, q_len, self.num_heads * self.v_head_dim\n        ).contiguous()\n        attn_output = self.o_proj(attn_output)\n\n        if not output_attentions:\n            attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n    def _flash_attention_forward(\n        self,\n        query_states,\n        key_states,\n        value_states,\n        attention_mask,\n        query_length,\n        position_ids,\n        dropout=0.0,\n        softmax_scale=None,\n    ):\n     
   \"\"\"\n        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token\n        first unpad the input, then computes the attention scores and pad the final attention scores.\n        # Args:\n            query_states (`torch.Tensor`):\n                Input query states to be passed to Flash Attention API\n            key_states (`torch.Tensor`):\n                Input key states to be passed to Flash Attention API\n            value_states (`torch.Tensor`):\n                Input value states to be passed to Flash Attention API\n            attention_mask (`torch.Tensor`):\n                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the\n                position of padding tokens and 1 for the position of non-padding tokens.\n            dropout (`int`, *optional*):\n                Attention dropout\n            softmax_scale (`float`, *optional*):\n                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)\n        \"\"\"\n        if not self._flash_attn_uses_top_left_mask:\n            causal = self.is_causal\n        else:\n            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. 
For details, please see the comment in DeepseekV2FlashAttention2 __init__.\n            causal = self.is_causal and query_length != 1\n\n        # Contains at least one padding token in the sequence\n        if attention_mask is not None:\n            batch_size = query_states.shape[0]\n            (\n                query_states,\n                key_states,\n                value_states,\n                indices_q,\n                cu_seq_lens,\n                max_seq_lens,\n            ) = self._upad_input(\n                query_states, key_states, value_states, attention_mask, query_length\n            )\n\n            cu_seqlens_q, cu_seqlens_k = cu_seq_lens\n            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens\n\n            attn_output_unpad = flash_attn_varlen_func(\n                query_states,\n                key_states,\n                value_states,\n                cu_seqlens_q=cu_seqlens_q,\n                cu_seqlens_k=cu_seqlens_k,\n                max_seqlen_q=max_seqlen_in_batch_q,\n                max_seqlen_k=max_seqlen_in_batch_k,\n                dropout_p=dropout,\n                softmax_scale=softmax_scale,\n                causal=causal,\n            )\n\n            attn_output = pad_input(\n                attn_output_unpad, indices_q, batch_size, query_length\n            )\n        else:\n            if query_length == 1:\n                position_ids = position_ids.to(dtype=torch.int32).squeeze(1)\n                attn_output = flash_attn_with_kvcache(\n                    query_states,\n                    key_states,\n                    value_states,\n                    cache_seqlens=position_ids,\n                    softmax_scale=softmax_scale,\n                    causal=causal,\n                )   \n            else:\n                attn_output = flash_attn_func(\n                    query_states,\n                    key_states,\n                    value_states,\n                    dropout,\n        
            softmax_scale=softmax_scale,\n                    causal=causal,\n                )\n\n        return attn_output\n\n    def _upad_input(\n        self, query_layer, key_layer, value_layer, attention_mask, query_length\n    ):\n        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)\n        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape\n\n        key_layer = index_first_axis(\n            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),\n            indices_k,\n        )\n        value_layer = index_first_axis(\n            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),\n            indices_k,\n        )\n        if query_length == kv_seq_len:\n            query_layer = index_first_axis(\n                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),\n                indices_k,\n            )\n            cu_seqlens_q = cu_seqlens_k\n            max_seqlen_in_batch_q = max_seqlen_in_batch_k\n            indices_q = indices_k\n        elif query_length == 1:\n            max_seqlen_in_batch_q = 1\n            cu_seqlens_q = torch.arange(\n                batch_size + 1, dtype=torch.int32, device=query_layer.device\n            )  # There is a memcpy here, that is very bad.\n            indices_q = cu_seqlens_q[:-1]\n            query_layer = query_layer.squeeze(1)\n        else:\n            # The -q_len: slice assumes left padding.\n            attention_mask = attention_mask[:, -query_length:]\n            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(\n                query_layer, attention_mask\n            )\n\n        return (\n            query_layer,\n            key_layer,\n            value_layer,\n            indices_q,\n            (cu_seqlens_q, cu_seqlens_k),\n            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),\n        )\n\n\nATTENTION_CLASSES = {\n    \"eager\": 
DeepseekV2Attention,\n    \"flash_attention_2\": DeepseekV2FlashAttention2,\n}\n\nclass DeepseekV2DecoderLayer(nn.Module):\n    def __init__(self, config: DeepseekV2Config, layer_idx: int):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n\n        self.self_attn = ATTENTION_CLASSES[config._attn_implementation](\n            config=config, layer_idx=layer_idx\n        )\n\n        self.mlp = (\n            DeepseekV2MoE(config)\n            if (\n                config.n_routed_experts is not None\n                and layer_idx >= config.first_k_dense_replace\n                and layer_idx % config.moe_layer_freq == 0\n            )\n            else DeepseekV2MLP(config)\n        )\n        self.input_layernorm = DeepseekV2RMSNorm(\n            config.hidden_size, eps=config.rms_norm_eps\n        )\n        self.post_attention_layernorm = DeepseekV2RMSNorm(\n            config.hidden_size, eps=config.rms_norm_eps\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        output_attentions: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs,\n    ) -> Tuple[\n        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]\n    ]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*):\n                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,\n                query_sequence_length, key_sequence_length)` if default attention is used.\n            output_attentions (`bool`, *optional*):\n                Whether or not to 
return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n        \"\"\"\n        if \"padding_mask\" in kwargs:\n            warnings.warn(\n                \"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead.\"\n            )\n        residual = hidden_states\n        hidden_states = self.input_layernorm(hidden_states)\n        # Self Attention\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_value=past_key_value,\n            output_attentions=output_attentions,\n            use_cache=use_cache,\n            cache_position=cache_position,\n            **kwargs,\n        )\n\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (self_attn_weights,)\n\n        if use_cache:\n            outputs += (present_key_value,)\n\n        return outputs\n\n\nDeepseekV2_START_DOCSTRING = r\"\"\"\n    This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the\n    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads\n    etc.)\n\n    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage\n    and behavior.\n\n    Parameters:\n        config ([`DeepseekV2Config`]):\n            Model configuration class with all the parameters of the model. Initializing with a config file does not\n            load the weights associated with the model, only the configuration. Check out the\n            [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.\",\n    DeepseekV2_START_DOCSTRING,\n)\nclass DeepseekV2PreTrainedModel(PreTrainedModel):\n    config_class = DeepseekV2Config\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"DeepseekV2DecoderLayer\"]\n    _skip_keys_device_placement = \"past_key_values\"\n    _supports_flash_attn_2 = True\n    _supports_cache_class = True\n    _supports_static_cache = True\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n\n\nDeepseekV2_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see\n            `past_key_values`).\n\n            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]\n            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more\n            information on the default strategy.\n\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,\n            config.n_positions - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):\n            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention\n            blocks) that can be used to speed up sequential decoding. 
This typically consists in the `past_key_values`\n            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.\n\n            Two formats are allowed:\n            - a [`~cache_utils.Cache`] instance;\n            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of\n            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy\n            cache format.\n\n            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the\n            legacy cache format will be returned.\n\n            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't\n            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`\n            of shape `(batch_size, sequence_length)`.\n        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This\n            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n            model's internal embedding lookup matrix.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n            `past_key_values`).\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors for\n            more detail.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.\",\n    DeepseekV2_START_DOCSTRING,\n)\nclass DeepseekV2Model(DeepseekV2PreTrainedModel):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV2DecoderLayer`]\n\n    Args:\n        config: DeepseekV2Config\n    \"\"\"\n\n    def __init__(self, config: DeepseekV2Config):\n        super().__init__(config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(\n            config.vocab_size, config.hidden_size, self.padding_idx\n        )\n        self.layers = nn.ModuleList(\n            [\n                DeepseekV2DecoderLayer(config, layer_idx)\n                for layer_idx in range(config.num_hidden_layers)\n            ]\n        )\n        self._use_flash_attention_2 = config._attn_implementation == \"flash_attention_2\"\n        self.norm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n        self.gradient_checkpointing = False\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: 
Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        # retrieve input_ids and inputs_embeds\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\n                \"You cannot specify both input_ids and inputs_embeds at the same time\"\n            )\n        elif input_ids is not None:\n            batch_size, seq_length = input_ids.shape[:2]\n        elif inputs_embeds is not None:\n            batch_size, seq_length = inputs_embeds.shape[:2]\n        else:\n            raise ValueError(\"You have to specify either input_ids or inputs_embeds\")\n\n        if self.gradient_checkpointing and self.training:\n            if use_cache:\n                logger.warning_once(\n                    \"`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...\"\n                )\n                use_cache = False\n\n        past_key_values_length = 0\n        if use_cache:\n            use_legacy_cache = not isinstance(past_key_values, Cache)\n            if use_legacy_cache:\n                past_key_values = DynamicCache.from_legacy_cache(past_key_values)\n            past_key_values_length = past_key_values.get_usable_length(seq_length)\n\n        # Embed the tokens first: `cache_position` below is derived from `inputs_embeds`\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        if cache_position is None:\n            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n            cache_position = torch.arange(\n                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device\n            )\n\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0)\n\n        causal_mask = self._update_causal_mask(\n            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions\n        )\n\n        # embed positions\n        hidden_states = inputs_embeds\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        next_decoder_cache = None\n\n        for decoder_layer in self.layers:\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            if self.gradient_checkpointing and self.training:\n                layer_outputs = self._gradient_checkpointing_func(\n                    decoder_layer.__call__,\n                    hidden_states,\n                    causal_mask,\n                    position_ids,\n                    past_key_values,\n                    output_attentions,\n                    use_cache,\n                    cache_position,\n                )\n            else:\n                layer_outputs = decoder_layer(\n          
          hidden_states,\n                    attention_mask=causal_mask,\n                    position_ids=position_ids,\n                    past_key_value=past_key_values,\n                    output_attentions=output_attentions,\n                    use_cache=use_cache,\n                    cache_position=cache_position,\n                )\n\n            hidden_states = layer_outputs[0]\n\n            if use_cache:\n                next_decoder_cache = layer_outputs[2 if output_attentions else 1]\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n        hidden_states = self.norm(hidden_states)\n\n        # add hidden states from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        next_cache = None\n        if use_cache:\n            next_cache = (\n                next_decoder_cache.to_legacy_cache()\n                if use_legacy_cache\n                else next_decoder_cache\n            )\n        if not return_dict:\n            return tuple(\n                v\n                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]\n                if v is not None\n            )\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=next_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n        )\n    \n    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask\n    def _update_causal_mask(\n        self,\n        attention_mask: torch.Tensor,\n        input_tensor: torch.Tensor,\n        cache_position: torch.Tensor,\n        past_key_values: Cache,\n        output_attentions: bool,\n    ):\n        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static\n        # KV cache is used. 
This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.\n        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using\n        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114\n\n        if self.config._attn_implementation == \"flash_attention_2\":\n            if attention_mask is not None and 0.0 in attention_mask:\n                return attention_mask\n            return None\n\n        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in\n        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail\n        # to infer the attention mask.\n        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n        using_static_cache = isinstance(past_key_values, StaticCache)\n\n        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward\n        if self.config._attn_implementation == \"sdpa\" and not using_static_cache and not output_attentions:\n            if AttentionMaskConverter._ignore_causal_mask_sdpa(\n                attention_mask,\n                inputs_embeds=input_tensor,\n                past_key_values_length=past_seen_tokens,\n                is_training=self.training,\n            ):\n                return None\n\n        dtype, device = input_tensor.dtype, input_tensor.device\n        min_dtype = torch.finfo(dtype).min\n        sequence_length = input_tensor.shape[1]\n        if using_static_cache:\n            target_length = past_key_values.get_max_length()\n        else:\n            target_length = (\n                attention_mask.shape[-1]\n                if isinstance(attention_mask, torch.Tensor)\n                else past_seen_tokens + 
sequence_length + 1\n            )\n\n        if attention_mask is not None and attention_mask.dim() == 4:\n            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing\n            if attention_mask.max() != 0:\n                raise ValueError(\"Custom 4D attention mask should be passed in inverted form with `max==0`\")\n            causal_mask = attention_mask\n        else:\n            causal_mask = torch.full(\n                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device\n            )\n            if sequence_length != 1:\n                causal_mask = torch.triu(causal_mask, diagonal=1)\n            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)\n            causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)\n            if attention_mask is not None:\n                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit\n                mask_length = attention_mask.shape[-1]\n                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]\n                padding_mask = padding_mask == 0\n                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(\n                    padding_mask, min_dtype\n                )\n        if (\n            self.config._attn_implementation == \"sdpa\"\n            and attention_mask is not None\n            and attention_mask.device.type == \"cuda\"\n            and not output_attentions\n        ):\n            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when\n            # using left padding. 
This is required by F.scaled_dot_product_attention memory-efficient attention path.\n            # Details: https://github.com/pytorch/pytorch/issues/110213\n            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)\n\n        return causal_mask\n\n\nclass DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):\n    _tied_weights_keys = [\"lm_head.weight\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = DeepseekV2Model(config)\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def set_decoder(self, decoder):\n        self.model = decoder\n\n    def get_decoder(self):\n        return self.model\n\n    @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)\n    @replace_return_docstrings(\n        output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC\n    )\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        
**kwargs,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, DeepseekV2ForCausalLM\n\n        >>> model = DeepseekV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)\n        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)\n\n        >>> prompt = \"Hey, are you conscious? Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you conscious? 
Can you talk to me?\\nI'm not conscious, but I can talk to you.\"\n        ```\"\"\"\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            cache_position=cache_position,\n        )\n\n        hidden_states = outputs[0]\n        # logits = self.lm_head(hidden_states[:,-1:,:]).float()\n        \n        logits = self.lm_head(hidden_states).float() \n\n        loss = None\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            shift_logits = shift_logits.view(-1, self.config.vocab_size)\n            shift_labels = shift_labels.view(-1)\n            # Enable model parallelism\n            shift_labels = shift_labels.to(shift_logits.device)\n            loss = loss_fct(shift_logits, shift_labels)\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            return (loss,) + output if loss is not 
None else output\n\n        return CausalLMOutputWithPast(\n            loss=loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n    def prepare_inputs_for_generation(\n        self,\n        input_ids,\n        past_key_values=None,\n        attention_mask=None,\n        inputs_embeds=None,\n        cache_position=None,\n        use_cache=True,\n        **kwargs,\n    ):\n        past_length = 0\n        # Omit tokens covered by past_key_values\n        if past_key_values is not None:\n            if isinstance(past_key_values, Cache):\n                past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()\n                max_cache_length = (\n                    torch.tensor(past_key_values.get_max_length(), device=input_ids.device)\n                    if past_key_values.get_max_length() is not None\n                    else None\n                )\n                cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)\n            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects\n            else:\n                cache_length = past_length = past_key_values[0][0].shape[2]\n                max_cache_length = None\n\n            # Keep only the unprocessed tokens:\n            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where\n            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as\n            # input)\n            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:\n                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]\n            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. 
We can discard\n            # input_ids based on the past_length.\n            elif past_length < input_ids.shape[1]:\n                input_ids = input_ids[:, past_length:]\n            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.\n\n            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.\n            if (\n                max_cache_length is not None\n                and attention_mask is not None\n                and cache_length + input_ids.shape[1] > max_cache_length\n            ):\n                attention_mask = attention_mask[:, -max_cache_length:]\n\n        position_ids = kwargs.get(\"position_ids\", None)\n        if attention_mask is not None and position_ids is None:\n            # create position_ids on the fly for batch generation\n            position_ids = attention_mask.long().cumsum(-1) - 1\n            position_ids.masked_fill_(attention_mask == 0, 1)\n            if past_key_values:\n                position_ids = position_ids[:, -input_ids.shape[1] :]\n\n        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step\n        if inputs_embeds is not None and past_length == 0:\n            model_inputs = {\"inputs_embeds\": inputs_embeds}\n        else:\n            model_inputs = {\"input_ids\": input_ids}\n\n        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]\n        if cache_position is None:\n            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)\n        elif use_cache:\n            cache_position = cache_position[-input_length:]\n\n        model_inputs.update(\n            {\n                \"position_ids\": position_ids,\n                \"past_key_values\": past_key_values,\n                \"use_cache\": use_cache,\n                \"attention_mask\": attention_mask,\n                
\"cache_position\": cache_position,\n            }\n        )\n        return model_inputs\n\n    @staticmethod\n    def _reorder_cache(past_key_values, beam_idx):\n        reordered_past = ()\n        for layer_past in past_key_values:\n            reordered_past += (\n                tuple(\n                    past_state.index_select(0, beam_idx.to(past_state.device))\n                    for past_state in layer_past\n                ),\n            )\n        return reordered_past\n\n\n@add_start_docstrings(\n    \"\"\"\n    The DeepseekV2 Model transformer with a sequence classification head on top (linear layer).\n\n    [`DeepseekV2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models\n    (e.g. GPT-2) do.\n\n    Since it does classification on the last token, it requires to know the position of the last token. If a\n    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If\n    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. 
Since it cannot guess the\n    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in\n    each row of the batch).\n    \"\"\",\n    DeepseekV2_START_DOCSTRING,\n)\nclass DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.model = DeepseekV2Model(config)\n        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n            config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        transformer_outputs = self.model(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = transformer_outputs[0]\n        logits = self.score(hidden_states)\n\n        if input_ids is not None:\n            batch_size = input_ids.shape[0]\n        else:\n            batch_size = inputs_embeds.shape[0]\n\n        if self.config.pad_token_id is None and batch_size != 1:\n            raise ValueError(\n                \"Cannot handle batch sizes > 1 if no padding token is defined.\"\n            )\n        if self.config.pad_token_id is None:\n            sequence_lengths = -1\n        else:\n            if input_ids is not None:\n                sequence_lengths = (\n                    torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1\n                ).to(logits.device)\n            else:\n                sequence_lengths = -1\n\n        pooled_logits = logits[\n            torch.arange(batch_size, device=logits.device), sequence_lengths\n        ]\n\n        loss = None\n        if labels is not None:\n            labels = labels.to(logits.device)\n            if self.config.problem_type is None:\n                if self.num_labels == 1:\n                    self.config.problem_type = \"regression\"\n                elif self.num_labels > 1 and (\n                    labels.dtype == 
torch.long or labels.dtype == torch.int\n                ):\n                    self.config.problem_type = \"single_label_classification\"\n                else:\n                    self.config.problem_type = \"multi_label_classification\"\n\n            if self.config.problem_type == \"regression\":\n                loss_fct = MSELoss()\n                if self.num_labels == 1:\n                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())\n                else:\n                    loss = loss_fct(pooled_logits, labels)\n            elif self.config.problem_type == \"single_label_classification\":\n                loss_fct = CrossEntropyLoss()\n                loss = loss_fct(\n                    pooled_logits.view(-1, self.num_labels), labels.view(-1)\n                )\n            elif self.config.problem_type == \"multi_label_classification\":\n                loss_fct = BCEWithLogitsLoss()\n                loss = loss_fct(pooled_logits, labels)\n        if not return_dict:\n            output = (pooled_logits,) + transformer_outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return SequenceClassifierOutputWithPast(\n            loss=loss,\n            logits=pooled_logits,\n            past_key_values=transformer_outputs.past_key_values,\n            hidden_states=transformer_outputs.hidden_states,\n            attentions=transformer_outputs.attentions,\n        )\n"
  },
  {
    "path": "kt-sft/ktransformers/models/modeling_deepseek_v3.py",
    "content": "# coding=utf-8\n# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\" PyTorch DeepSeek model.\"\"\"\nimport math\nimport warnings\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n\nfrom transformers.activations import ACT2FN\nfrom transformers.cache_utils import Cache, DynamicCache, StaticCache\nfrom transformers.generation import GenerationMixin\nfrom transformers.modeling_attn_mask_utils import (\n    AttentionMaskConverter,\n    _prepare_4d_attention_mask,\n    _prepare_4d_causal_attention_mask,\n)\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPast,\n    CausalLMOutputWithPast,\n    SequenceClassifierOutputWithPast,\n)\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.pytorch_utils import (\n    ALL_LAYERNORM_LAYERS,\n    is_torch_greater_or_equal_than_1_13,\n)\nfrom transformers.utils import (\n    add_start_docstrings,\n    
add_start_docstrings_to_model_forward,\n    is_flash_attn_2_available,\n    is_flash_attn_greater_or_equal_2_10,\n    logging,\n    replace_return_docstrings,\n)\nfrom transformers.utils.import_utils import is_torch_fx_available\nfrom .configuration_deepseek_v3 import DeepseekV3Config\nimport torch.distributed as dist\nimport numpy as np\n\nfrom ktransformers.util.grad_wrapper import maybe_no_grad\n\nif is_flash_attn_2_available():\n    from flash_attn import flash_attn_func, flash_attn_varlen_func\n    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa\n\n\n# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.\n# It means that the function will not be traced through and simply appear as a node in the graph.\nif is_torch_fx_available():\n    if not is_torch_greater_or_equal_than_1_13:\n        import torch.fx\n\n    _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)\n\n\nlogger = logging.get_logger(__name__)\n\n_CONFIG_FOR_DOC = \"DeepseekV3Config\"\n\n\ndef _get_unpad_data(attention_mask):\n    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)\n    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()\n    max_seqlen_in_batch = seqlens_in_batch.max().item()\n    cu_seqlens = F.pad(\n        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)\n    )\n    return (\n        indices,\n        cu_seqlens,\n        max_seqlen_in_batch,\n    )\n\n\nclass DeepseekV3RMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        DeepseekV3RMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n        self.hidden_size = hidden_size\n\n    def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n   
     variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n\nALL_LAYERNORM_LAYERS.append(DeepseekV3RMSNorm)\n\n\nclass DeepseekV3RotaryEmbedding(nn.Module):\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):\n        super().__init__()\n\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        inv_freq = 1.0 / (\n            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)\n        )\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        # Build here to make `torch.jit.trace` work.\n        self._set_cos_sin_cache(\n            seq_len=max_position_embeddings,\n            device=self.inv_freq.device,\n            dtype=torch.get_default_dtype(),\n        )\n        self.max_seq_len_cached = None\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n        t = torch.arange(\n            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype\n        )\n\n        freqs = torch.outer(t, self.inv_freq.to(t.device))\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n    def forward(self, x, seq_len=None):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:\n            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)\n\n        return (\n            self.cos_cached[:seq_len].to(dtype=x.dtype),\n            
self.sin_cached[:seq_len].to(dtype=x.dtype),\n        )\n\n\n# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV3\nclass DeepseekV3LinearScalingRotaryEmbedding(DeepseekV3RotaryEmbedding):\n    \"\"\"DeepseekV3RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev\"\"\"\n\n    def __init__(\n        self,\n        dim,\n        max_position_embeddings=2048,\n        base=10000,\n        device=None,\n        scaling_factor=1.0,\n    ):\n        self.scaling_factor = scaling_factor\n        super().__init__(dim, max_position_embeddings, base, device)\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n        t = torch.arange(\n            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype\n        )\n        t = t / self.scaling_factor\n\n        freqs = torch.outer(t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n\n# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV3\nclass DeepseekV3DynamicNTKScalingRotaryEmbedding(DeepseekV3RotaryEmbedding):\n    \"\"\"DeepseekV3RotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla\"\"\"\n\n    def __init__(\n        self,\n        dim,\n        max_position_embeddings=2048,\n        base=10000,\n        device=None,\n        scaling_factor=1.0,\n    ):\n        self.scaling_factor = scaling_factor\n        super().__init__(dim, max_position_embeddings, base, device)\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n\n        if seq_len > self.max_position_embeddings:\n            base = self.base * (\n                (self.scaling_factor * seq_len / self.max_position_embeddings)\n                - (self.scaling_factor - 1)\n            ) ** (self.dim / (self.dim - 2))\n            inv_freq = 1.0 / (\n                base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)\n            )\n            self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        t = torch.arange(\n            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype\n        )\n\n        freqs = torch.outer(t, self.inv_freq)\n        # Different from paper, but it uses a different permutation in order to obtain the same calculation\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n        self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n\n\n# Inverse dim formula to find dim based on number of rotations\ndef yarn_find_correction_dim(\n    num_rotations, dim, base=10000, max_position_embeddings=2048\n):\n    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (\n        2 * math.log(base)\n    )\n\n\n# Find dim range bounds based on rotations\ndef yarn_find_correction_range(\n    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048\n):\n    low = math.floor(\n        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)\n    )\n    high = math.ceil(\n     
   yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)\n    )\n    return max(low, 0), min(high, dim - 1)  # Clamp values just in case\n\n\ndef yarn_get_mscale(scale=1, mscale=1):\n    if scale <= 1:\n        return 1.0\n    return 0.1 * mscale * math.log(scale) + 1.0\n\n\ndef yarn_linear_ramp_mask(min, max, dim):\n    if min == max:\n        max += 0.001  # Prevent singularity\n\n    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)\n    ramp_func = torch.clamp(linear_func, 0, 1)\n    return ramp_func\n\n\nclass DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding):\n\n    def __init__(\n        self,\n        dim,\n        max_position_embeddings=2048,\n        base=10000,\n        device=None,\n        scaling_factor=1.0,\n        original_max_position_embeddings=4096,\n        beta_fast=32,\n        beta_slow=1,\n        mscale=1,\n        mscale_all_dim=0,\n    ):\n        self.scaling_factor = scaling_factor\n        self.original_max_position_embeddings = original_max_position_embeddings\n        self.beta_fast = beta_fast\n        self.beta_slow = beta_slow\n        self.mscale = mscale\n        self.mscale_all_dim = mscale_all_dim\n        super().__init__(dim, max_position_embeddings, base, device)\n\n    def _set_cos_sin_cache(self, seq_len, device, dtype):\n        self.max_seq_len_cached = seq_len\n        dim = self.dim\n\n        freq_extra = 1.0 / (\n            self.base\n            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)\n        )\n        freq_inter = 1.0 / (\n            self.scaling_factor\n            * self.base\n            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)\n        )\n\n        low, high = yarn_find_correction_range(\n            self.beta_fast,\n            self.beta_slow,\n            dim,\n            self.base,\n            self.original_max_position_embeddings,\n        )\n        inv_freq_mask = 1.0 - 
yarn_linear_ramp_mask(low, high, dim // 2).to(\n            device=device, dtype=torch.float32\n        )\n        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        t = torch.arange(seq_len, device=device, dtype=torch.float32)\n\n        freqs = torch.outer(t, inv_freq)\n\n        _mscale = float(\n            yarn_get_mscale(self.scaling_factor, self.mscale)\n            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)\n        )\n\n        emb = torch.cat((freqs, freqs), dim=-1)\n        self.register_buffer(\n            \"cos_cached\", (emb.cos() * _mscale).to(dtype), persistent=False\n        )\n        self.register_buffer(\n            \"sin_cached\", (emb.sin() * _mscale).to(dtype), persistent=False\n        )\n\n\n# Copied from transformers.models.llama.modeling_llama.rotate_half\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\n# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n    Args:\n        q (`torch.Tensor`): The query tensor.\n        k (`torch.Tensor`): The key tensor.\n        cos (`torch.Tensor`): The cosine part of the rotary embedding.\n        sin (`torch.Tensor`): The sine part of the rotary embedding.\n        position_ids (`torch.Tensor`):\n            The position indices of the tokens corresponding to the query and key tensors. 
For example, this can be\n            used to pass offsetted position ids when working with a KV-cache.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos[position_ids].unsqueeze(unsqueeze_dim)\n    sin = sin[position_ids].unsqueeze(unsqueeze_dim)\n\n    b, h, s, d = q.shape\n    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)\n\n    b, h, s, d = k.shape\n    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)\n\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\nclass DeepseekV3MLP(nn.Module):\n    def __init__(self, config, hidden_size=None, intermediate_size=None):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size\n        self.intermediate_size = (\n            config.intermediate_size if intermediate_size is None else intermediate_size\n        )\n\n        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n        return down_proj\n\n\nclass MoEGate(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.top_k = config.num_experts_per_tok\n        self.n_routed_experts = config.n_routed_experts\n        self.routed_scaling_factor = config.routed_scaling_factor\n        self.scoring_func = config.scoring_func\n        self.topk_method = config.topk_method\n        self.n_group = config.n_group\n        self.topk_group = config.topk_group\n\n        # topk selection algorithm\n        self.norm_topk_prob = config.norm_topk_prob\n        self.gating_dim = config.hidden_size\n        self.weight = nn.Parameter(\n            torch.empty((self.n_routed_experts, self.gating_dim))\n        )\n        if self.topk_method == \"noaux_tc\":\n            self.e_score_correction_bias = nn.Parameter(\n                torch.empty((self.n_routed_experts))\n            )\n        self.reset_parameters()\n\n    def reset_parameters(self) -> None:\n        import torch.nn.init as init\n\n        init.kaiming_uniform_(self.weight, a=math.sqrt(5))\n\n    def forward(self, hidden_states):\n        bsz, seq_len, h = hidden_states.shape\n        ### compute gating score\n        hidden_states = hidden_states.view(-1, h)\n        logits = F.linear(\n            hidden_states.type(torch.float32), self.weight.type(torch.float32), None\n        )\n        if self.scoring_func == \"sigmoid\":\n            scores = logits.sigmoid()\n        else:\n            raise NotImplementedError(\n                f\"insupportable scoring function for MoE gating: {self.scoring_func}\"\n            )\n\n        ### select top-k experts\n        if self.topk_method == \"noaux_tc\":\n            #assert not 
self.training\n            scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)\n            group_scores = (\n                scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim = -1)\n            )  # [n, n_group]\n            group_idx = torch.topk(\n                group_scores, k=self.topk_group, dim=-1, sorted=False\n            )[\n                1\n            ]  # [n, top_k_group]\n            group_mask = torch.zeros_like(group_scores)  # [n, n_group]\n            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]\n            score_mask = (\n                group_mask.unsqueeze(-1)\n                .expand(\n                    bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group\n                )\n                .reshape(bsz * seq_len, -1)\n            )  # [n, e]\n            tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), float(\"-inf\"))  # [n, e]\n            _, topk_idx = torch.topk(\n                tmp_scores, k=self.top_k, dim=-1, sorted=False\n            )\n            topk_weight = scores.gather(1, topk_idx)\n        else:\n            raise NotImplementedError(\n                f\"insupportable TopK function for MoE gating: {self.topk_method}\"\n            )\n\n        ### norm gate to sum 1\n        if self.top_k > 1 and self.norm_topk_prob:\n            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20\n            topk_weight = topk_weight / denominator\n        topk_weight = topk_weight * self.routed_scaling_factor # must multiply the scaling factor\n\n        return topk_idx, topk_weight\n\nclass DeepseekV3MoE(nn.Module):\n    \"\"\"\n    A mixed expert module containing shared experts.\n    \"\"\"\n\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.num_experts_per_tok = config.num_experts_per_tok\n\n        if hasattr(config, \"ep_size\") and config.ep_size > 1:\n   
         assert config.ep_size == dist.get_world_size()\n            self.ep_size = config.ep_size\n            self.experts_per_rank = config.n_routed_experts // config.ep_size\n            self.ep_rank = dist.get_rank()\n            self.experts = nn.ModuleList(\n                [\n                    (\n                        DeepseekV3MLP(\n                            config, intermediate_size=config.moe_intermediate_size\n                        )\n                        if i >= self.ep_rank * self.experts_per_rank\n                        and i < (self.ep_rank + 1) * self.experts_per_rank\n                        else None\n                    )\n                    for i in range(config.n_routed_experts)\n                ]\n            )\n        else:\n            self.ep_size = 1\n            self.experts_per_rank = config.n_routed_experts\n            self.ep_rank = 0\n            self.experts = nn.ModuleList(\n                [\n                    DeepseekV3MLP(\n                        config, intermediate_size=config.moe_intermediate_size\n                    )\n                    for i in range(config.n_routed_experts)\n                ]\n            )\n        self.gate = MoEGate(config)\n        if config.n_shared_experts is not None:\n            intermediate_size = config.moe_intermediate_size * config.n_shared_experts\n            self.shared_experts = DeepseekV3MLP(\n                config=config, intermediate_size=intermediate_size\n            )\n\n    def forward(self, hidden_states):\n        identity = hidden_states\n        orig_shape = hidden_states.shape\n        topk_idx, topk_weight = self.gate(hidden_states)\n        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])\n        flat_topk_idx = topk_idx.view(-1)\n        if not self.training:\n            y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)\n        if self.config.n_shared_experts is not None:\n            y = y + 
self.shared_experts(identity)\n        return y\n\n    @maybe_no_grad()\n    def moe_infer(self, x, topk_ids, topk_weight):\n        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))\n        cnts.scatter_(1, topk_ids, 1)\n        tokens_per_expert = cnts.sum(dim=0)\n        idxs = topk_ids.view(-1).argsort()\n        sorted_tokens = x[idxs // topk_ids.shape[1]]\n        sorted_tokens_shape = sorted_tokens.shape\n        if self.ep_size > 1:\n            tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)\n            tokens_per_expert_group = tokens_per_expert.new_empty(\n                tokens_per_expert.shape[0]\n            )\n            dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)\n            output_splits = (\n                tokens_per_expert_group.view(self.ep_size, -1)\n                .sum(1)\n                .cpu()\n                .numpy()\n                .tolist()\n            )\n            gathered_tokens = sorted_tokens.new_empty(\n                tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1]\n            )\n            input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()\n            dist.all_to_all(\n                list(gathered_tokens.split(output_splits)),\n                list(sorted_tokens.split(input_split_sizes)),\n            )\n            tokens_per_expert_post_gather = tokens_per_expert_group.view(\n                self.ep_size, self.experts_per_rank\n            ).sum(dim=0)\n            gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)\n            s = 0\n            for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):\n                gatherd_idxs[s : s + k] = i % self.experts_per_rank\n                s += k\n            gatherd_idxs = gatherd_idxs.argsort()\n            sorted_tokens = gathered_tokens[gatherd_idxs]\n            tokens_per_expert = tokens_per_expert_post_gather\n        tokens_per_expert 
= tokens_per_expert.cpu().numpy()\n\n        outputs = []\n        start_idx = 0\n        for i, num_tokens in enumerate(tokens_per_expert):\n            end_idx = start_idx + num_tokens\n            if num_tokens == 0:\n                continue\n            expert = self.experts[i + self.ep_rank * self.experts_per_rank]\n            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n            expert_out = expert(tokens_for_this_expert)\n            outputs.append(expert_out)\n            start_idx = end_idx\n\n        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n        if self.ep_size > 1:\n            new_x = torch.empty_like(outs)\n            new_x[gatherd_idxs] = outs\n            gathered_tokens = new_x.new_empty(*sorted_tokens_shape)\n            dist.all_to_all(\n                list(gathered_tokens.split(input_split_sizes)),\n                list(new_x.split(output_splits)),\n            )\n            outs = gathered_tokens\n\n        new_x = torch.empty_like(outs)\n        new_x[idxs] = outs\n        final_out = (\n            new_x.view(*topk_ids.shape, -1)\n            .type(topk_weight.dtype)\n            .mul_(topk_weight.unsqueeze(dim=-1))\n            .sum(dim=1)\n            .type(new_x.dtype)\n        )\n        return final_out\n\n\n# Copied from transformers.models.llama.modeling_llama.repeat_kv\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(\n        batch, num_key_value_heads, n_rep, slen, head_dim\n    )\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\n# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV3\nclass DeepseekV3Attention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config: DeepseekV3Config, layer_idx: Optional[int] = None):\n        super().__init__()\n        self.config = config\n        self.layer_idx = layer_idx\n        if layer_idx is None:\n            logger.warning_once(\n                f\"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will \"\n                \"lead to errors during the forward call if caching is used. 
Please make sure to provide a `layer_idx` \"\n                \"when creating this class.\"\n            )\n\n        self.attention_dropout = config.attention_dropout\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n\n        self.max_position_embeddings = config.max_position_embeddings\n        self.rope_theta = config.rope_theta\n        self.q_lora_rank = config.q_lora_rank\n        self.qk_rope_head_dim = config.qk_rope_head_dim\n        self.kv_lora_rank = config.kv_lora_rank\n        self.v_head_dim = config.v_head_dim\n        self.qk_nope_head_dim = config.qk_nope_head_dim\n        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim\n\n        self.is_causal = True\n\n        if self.q_lora_rank is None:\n            self.q_proj = nn.Linear(\n                self.hidden_size, self.num_heads * self.q_head_dim, bias=False\n            )\n        else:\n            self.q_a_proj = nn.Linear(\n                self.hidden_size, config.q_lora_rank, bias=config.attention_bias\n            )\n            self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank)\n            self.q_b_proj = nn.Linear(\n                config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False\n            )\n\n        self.kv_a_proj_with_mqa = nn.Linear(\n            self.hidden_size,\n            config.kv_lora_rank + config.qk_rope_head_dim,\n            bias=config.attention_bias,\n        )\n        self.kv_a_layernorm = DeepseekV3RMSNorm(config.kv_lora_rank)\n        self.kv_b_proj = nn.Linear(\n            config.kv_lora_rank,\n            self.num_heads\n            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),\n            bias=False,\n        )\n\n        self.o_proj = nn.Linear(\n            self.num_heads * self.v_head_dim,\n            self.hidden_size,\n            bias=config.attention_bias,\n        )\n        self._init_rope()\n\n        self.softmax_scale = self.q_head_dim 
** (-0.5)\n        if self.config.rope_scaling is not None:\n            mscale_all_dim = self.config.rope_scaling.get(\"mscale_all_dim\", 0)\n            scaling_factor = self.config.rope_scaling[\"factor\"]\n            if mscale_all_dim:\n                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)\n                self.softmax_scale = self.softmax_scale * mscale * mscale\n\n    def _init_rope(self):\n        if self.config.rope_scaling is None:\n            self.rotary_emb = DeepseekV3RotaryEmbedding(\n                self.qk_rope_head_dim,\n                max_position_embeddings=self.max_position_embeddings,\n                base=self.rope_theta,\n            )\n        else:\n            scaling_type = self.config.rope_scaling[\"type\"]\n            scaling_factor = self.config.rope_scaling[\"factor\"]\n            if scaling_type == \"linear\":\n                self.rotary_emb = DeepseekV3LinearScalingRotaryEmbedding(\n                    self.qk_rope_head_dim,\n                    max_position_embeddings=self.max_position_embeddings,\n                    scaling_factor=scaling_factor,\n                    base=self.rope_theta,\n                )\n            elif scaling_type == \"dynamic\":\n                self.rotary_emb = DeepseekV3DynamicNTKScalingRotaryEmbedding(\n                    self.qk_rope_head_dim,\n                    max_position_embeddings=self.max_position_embeddings,\n                    scaling_factor=scaling_factor,\n                    base=self.rope_theta,\n                )\n            elif scaling_type == \"yarn\":\n                kwargs = {\n                    key: self.config.rope_scaling[key]\n                    for key in [\n                        \"original_max_position_embeddings\",\n                        \"beta_fast\",\n                        \"beta_slow\",\n                        \"mscale\",\n                        \"mscale_all_dim\",\n                    ]\n                    if key in 
self.config.rope_scaling\n                }\n                self.rotary_emb = DeepseekV3YarnRotaryEmbedding(\n                    self.qk_rope_head_dim,\n                    max_position_embeddings=self.max_position_embeddings,\n                    scaling_factor=scaling_factor,\n                    base=self.rope_theta,\n                    **kwargs,\n                )\n            else:\n                raise ValueError(f\"Unknown RoPE scaling type {scaling_type}\")\n\n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return (\n            tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim)\n            .transpose(1, 2)\n            .contiguous()\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        if \"padding_mask\" in kwargs:\n            warnings.warn(\n                \"Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure to use `attention_mask` instead.\"\n            )\n        bsz, q_len, _ = hidden_states.size()\n\n        if self.q_lora_rank is None:\n            q = self.q_proj(hidden_states)\n        else:\n            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))\n        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)\n        q_nope, q_pe = torch.split(\n            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1\n        )\n\n        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)\n        compressed_kv, k_pe = torch.split(\n            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n        )\n        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)\n        kv = (\n            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))\n            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)\n            .transpose(1, 2)\n        )\n\n        k_nope, value_states = torch.split(\n            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1\n        )\n        kv_seq_len = value_states.shape[-2]\n        if past_key_value is not None:\n            if self.layer_idx is None:\n                raise ValueError(\n                    f\"The cache structure has changed since version v4.36. 
If you are using {self.__class__.__name__} \"\n                    \"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class \"\n                    \"with a layer index.\"\n                )\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\n        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)\n\n        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)\n\n        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)\n        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope\n        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe\n\n        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)\n        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope\n        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe\n        if past_key_value is not None:\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n            key_states, value_states = past_key_value.update(\n                key_states, value_states, self.layer_idx, cache_kwargs\n            )\n\n        attn_weights = (\n            torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale\n        )\n\n        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is\"\n                f\" {attn_weights.size()}\"\n            )\n        assert attention_mask is not None\n        if attention_mask is not None:\n            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}\"\n                )\n            attn_weights = attn_weights + 
attention_mask\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(\n            attn_weights, dim=-1, dtype=torch.float32\n        ).to(query_states.dtype)\n        attn_weights = nn.functional.dropout(\n            attn_weights, p=self.attention_dropout, training=self.training\n        )\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n\n        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)\n\n        attn_output = self.o_proj(attn_output)\n\n        if not output_attentions:\n            attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n\n# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV3\nclass DeepseekV3FlashAttention2(DeepseekV3Attention):\n    \"\"\"\n    DeepseekV3 flash attention module. This module inherits from `DeepseekV3Attention` as the weights of the module stay\n    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of\n    flash attention and deal with padding tokens in case the input contains any of them.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n        # TODO: Should be removed once Flash Attention for ROCm is bumped to 2.1.\n        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made default for flash_attn>=2.1. This attribute is used to handle this difference. 
Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.\n        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).\n        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.LongTensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        # DeepseekV3FlashAttention2 attention does not support output_attentions\n        if \"padding_mask\" in kwargs:\n            warnings.warn(\n                \"Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure to use `attention_mask` instead.\"\n            )\n\n            # overwrite attention_mask with padding_mask\n            attention_mask = kwargs.pop(\"padding_mask\")\n\n        output_attentions = False\n\n        bsz, q_len, _ = hidden_states.size()\n\n        if self.q_lora_rank is None:\n            q = self.q_proj(hidden_states)\n        else:\n            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))\n        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)\n        q_nope, q_pe = torch.split(\n            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1\n        )\n\n        # Flash attention requires the input to have the shape\n        # batch_size x seq_length x head_dim x hidden_dim\n        # therefore we just need to keep the original shape\n        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)\n        compressed_kv, k_pe = torch.split(\n            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n        )\n        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)\n        kv = (\n            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))\n            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)\n            .transpose(1, 2)\n        )\n\n        k_nope, value_states = torch.split(\n            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1\n        )\n\n        kv_seq_len = value_states.shape[-2]\n        if past_key_value is not None:\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\n\n        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)\n        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)\n\n        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)\n        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope\n        query_states[:, :, :, 
self.qk_nope_head_dim :] = q_pe\n\n        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)\n        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope\n        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe\n\n        if self.q_head_dim != self.v_head_dim:\n            value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])\n\n        if past_key_value is not None:\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n            key_states, value_states = past_key_value.update(\n                key_states, value_states, self.layer_idx, cache_kwargs\n            )\n\n        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache\n        # to be able to avoid many of these transpose/reshape/view.\n        query_states = query_states.transpose(1, 2)\n        key_states = key_states.transpose(1, 2)\n        value_states = value_states.transpose(1, 2)\n\n        dropout_rate = self.attention_dropout if self.training else 0.0\n\n        # In PEFT, usually we cast the layer norms in float32 for training stability reasons\n        # therefore the input hidden states gets silently casted in float32. Hence, we need\n        # cast them back in the correct dtype just to be sure everything works as expected.\n        # This might slowdown training & inference so it is recommended to not cast the LayerNorms\n        # in fp32. 
(DeepseekV3RMSNorm handles it correctly)\n\n        input_dtype = query_states.dtype\n        if input_dtype == torch.float32:\n            # Handle the case where the model is quantized\n            if hasattr(self.config, \"_pre_quantization_dtype\"):\n                target_dtype = self.config._pre_quantization_dtype\n            elif torch.is_autocast_enabled():\n                target_dtype = torch.get_autocast_gpu_dtype()\n            else:\n                target_dtype = (\n                    self.q_proj.weight.dtype\n                    if self.q_lora_rank is None\n                    else self.q_a_proj.weight.dtype\n                )\n\n            logger.warning_once(\n                f\"The input hidden states seems to be silently casted in float32, this might be related to\"\n                f\" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in\"\n                f\" {target_dtype}.\"\n            )\n\n            query_states = query_states.to(target_dtype)\n            key_states = key_states.to(target_dtype)\n            value_states = value_states.to(target_dtype)\n\n        attn_output = self._flash_attention_forward(\n            query_states,\n            key_states,\n            value_states,\n            attention_mask,\n            q_len,\n            dropout=dropout_rate,\n            softmax_scale=self.softmax_scale,\n        )\n        if self.q_head_dim != self.v_head_dim:\n            attn_output = attn_output[:, :, :, : self.v_head_dim]\n\n        attn_output = attn_output.reshape(\n            bsz, q_len, self.num_heads * self.v_head_dim\n        ).contiguous()\n        attn_output = self.o_proj(attn_output)\n\n        if not output_attentions:\n            attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n    def _flash_attention_forward(\n        self,\n        query_states,\n        key_states,\n        value_states,\n        attention_mask,\n        
query_length,\n        dropout=0.0,\n        softmax_scale=None,\n    ):\n        \"\"\"\n        Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token,\n        the input is first unpadded, the attention scores are computed, and the final attention scores are padded again.\n\n        Args:\n            query_states (`torch.Tensor`):\n                Input query states to be passed to Flash Attention API\n            key_states (`torch.Tensor`):\n                Input key states to be passed to Flash Attention API\n            value_states (`torch.Tensor`):\n                Input value states to be passed to Flash Attention API\n            attention_mask (`torch.Tensor`):\n                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the\n                position of padding tokens and 1 for the position of non-padding tokens.\n            dropout (`float`, *optional*):\n                Attention dropout probability\n            softmax_scale (`float`, *optional*):\n                The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)\n        \"\"\"\n        if not self._flash_attn_uses_top_left_mask:\n            causal = self.is_causal\n        else:\n            # TODO: Remove the `query_length != 1` check once Flash Attention for ROCm is bumped to 2.1. 
For details, please see the comment in DeepseekV3FlashAttention2 __init__.\n            causal = self.is_causal and query_length != 1\n\n        # Contains at least one padding token in the sequence\n        if attention_mask is not None:\n            batch_size = query_states.shape[0]\n            (\n                query_states,\n                key_states,\n                value_states,\n                indices_q,\n                cu_seq_lens,\n                max_seq_lens,\n            ) = self._upad_input(\n                query_states, key_states, value_states, attention_mask, query_length\n            )\n\n            cu_seqlens_q, cu_seqlens_k = cu_seq_lens\n            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens\n\n            attn_output_unpad = flash_attn_varlen_func(\n                query_states,\n                key_states,\n                value_states,\n                cu_seqlens_q=cu_seqlens_q,\n                cu_seqlens_k=cu_seqlens_k,\n                max_seqlen_q=max_seqlen_in_batch_q,\n                max_seqlen_k=max_seqlen_in_batch_k,\n                dropout_p=dropout,\n                softmax_scale=softmax_scale,\n                causal=causal,\n            )\n\n            attn_output = pad_input(\n                attn_output_unpad, indices_q, batch_size, query_length\n            )\n        else:\n            attn_output = flash_attn_func(\n                query_states,\n                key_states,\n                value_states,\n                dropout,\n                softmax_scale=softmax_scale,\n                causal=causal,\n            )\n\n        return attn_output\n\n    def _upad_input(\n        self, query_layer, key_layer, value_layer, attention_mask, query_length\n    ):\n        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)\n        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape\n\n        key_layer = index_first_axis(\n            
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),\n            indices_k,\n        )\n        value_layer = index_first_axis(\n            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),\n            indices_k,\n        )\n        if query_length == kv_seq_len:\n            query_layer = index_first_axis(\n                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),\n                indices_k,\n            )\n            cu_seqlens_q = cu_seqlens_k\n            max_seqlen_in_batch_q = max_seqlen_in_batch_k\n            indices_q = indices_k\n        elif query_length == 1:\n            max_seqlen_in_batch_q = 1\n            cu_seqlens_q = torch.arange(\n                batch_size + 1, dtype=torch.int32, device=query_layer.device\n            )  # There is a memcpy here, that is very bad.\n            indices_q = cu_seqlens_q[:-1]\n            query_layer = query_layer.squeeze(1)\n        else:\n            # The -q_len: slice assumes left padding.\n            attention_mask = attention_mask[:, -query_length:]\n            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(\n                query_layer, attention_mask\n            )\n\n        return (\n            query_layer,\n            key_layer,\n            value_layer,\n            indices_q,\n            (cu_seqlens_q, cu_seqlens_k),\n            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),\n        )\n\n\nATTENTION_CLASSES = {\n    \"eager\": DeepseekV3Attention,\n    \"flash_attention_2\": DeepseekV3FlashAttention2,\n}\n\n\nclass DeepseekV3DecoderLayer(nn.Module):\n    def __init__(self, config: DeepseekV3Config, layer_idx: int):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n\n        self.self_attn = ATTENTION_CLASSES[config._attn_implementation](\n            config=config, layer_idx=layer_idx\n        )\n\n        self.mlp = (\n            DeepseekV3MoE(config)\n     
       if (\n                config.n_routed_experts is not None\n                and layer_idx >= config.first_k_dense_replace\n                and layer_idx % config.moe_layer_freq == 0\n            )\n            else DeepseekV3MLP(config)\n        )\n        self.input_layernorm = DeepseekV3RMSNorm(\n            config.hidden_size, eps=config.rms_norm_eps\n        )\n        self.post_attention_layernorm = DeepseekV3RMSNorm(\n            config.hidden_size, eps=config.rms_norm_eps\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        output_attentions: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs,\n    ) -> Tuple[\n        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]\n    ]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*):\n                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,\n                query_sequence_length, key_sequence_length)` if default attention is used.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under\n                returned tensors for more detail.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n        \"\"\"\n        if \"padding_mask\" in kwargs:\n            warnings.warn(\n                \"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead.\"\n            )\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_value=past_key_value,\n            output_attentions=output_attentions,\n            use_cache=use_cache,\n            cache_position=cache_position,\n            **kwargs,\n        )\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (self_attn_weights,)\n\n        if use_cache:\n            outputs += (present_key_value,)\n\n        return outputs\n\n\nDeepseekV3_START_DOCSTRING = r\"\"\"\n    This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the\n    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads\n    etc.)\n\n    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage\n    and behavior.\n\n    Parameters:\n        config ([`DeepseekV3Config`]):\n            Model configuration class with all the parameters of the model. Initializing with a config file does not\n            load the weights associated with the model, only the configuration. Check out the\n            [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.\",\n    DeepseekV3_START_DOCSTRING,\n)\nclass DeepseekV3PreTrainedModel(PreTrainedModel):\n    config_class = DeepseekV3Config\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"DeepseekV3DecoderLayer\"]\n    _skip_keys_device_placement = \"past_key_values\"\n    _supports_flash_attn_2 = True\n    _supports_cache_class = True\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n\n\nDeepseekV3_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in 
the vocabulary. Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n\n            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see\n            `past_key_values`).\n\n            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]\n            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more\n            information on the default strategy.\n        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,\n            config.n_positions - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):\n            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention\n            blocks) that can be used to speed up sequential decoding. 
This typically consists in the `past_key_values`\n            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.\n\n            Two formats are allowed:\n            - a [`~cache_utils.Cache`] instance;\n            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of\n            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy\n            cache format.\n\n            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the\n            legacy cache format will be returned.\n\n            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't\n            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`\n            of shape `(batch_size, sequence_length)`.\n        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This\n            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n            model's internal embedding lookup matrix.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n            `past_key_values`).\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors for\n            more detail.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare DeepseekV3 Model outputting raw hidden-states without any specific head on top.\",\n    DeepseekV3_START_DOCSTRING,\n)\nclass DeepseekV3Model(DeepseekV3PreTrainedModel):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV3DecoderLayer`]\n\n    Args:\n        config: DeepseekV3Config\n    \"\"\"\n\n    def __init__(self, config: DeepseekV3Config):\n        super().__init__(config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(\n            config.vocab_size, config.hidden_size, self.padding_idx\n        )\n        self.layers = nn.ModuleList(\n            [\n                DeepseekV3DecoderLayer(config, layer_idx)\n                for layer_idx in range(config.num_hidden_layers)\n            ]\n        )\n        self._use_flash_attention_2 = config._attn_implementation == \"flash_attention_2\"\n        self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n        self.gradient_checkpointing = False\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: 
Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        # retrieve input_ids and inputs_embeds\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\n                \"You cannot specify both input_ids and inputs_embeds at the same time\"\n            )\n        elif input_ids is not None:\n            batch_size, seq_length = input_ids.shape[:2]\n        elif inputs_embeds is not None:\n            batch_size, seq_length = inputs_embeds.shape[:2]\n        else:\n            raise ValueError(\"You have to specify either input_ids or inputs_embeds\")\n\n        past_key_values_length = 0\n        if use_cache:\n            use_legacy_cache = not isinstance(past_key_values, Cache)\n            if use_legacy_cache:\n                past_key_values = DynamicCache.from_legacy_cache(past_key_values)\n            past_key_values_length = past_key_values.get_usable_length(seq_length)\n\n        if position_ids is None:\n            device = input_ids.device if input_ids is not None else inputs_embeds.device\n            position_ids = torch.arange(\n                past_key_values_length,\n                seq_length + past_key_values_length,\n                
dtype=torch.long,\n                device=device,\n            )\n            position_ids = position_ids.unsqueeze(0)\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        if self._use_flash_attention_2:\n            # 2d mask is passed through the layers\n            attention_mask = (\n                attention_mask\n                if (attention_mask is not None and 0 in attention_mask)\n                else None\n            )\n        else:\n            # 4d mask is passed through the layers\n            attention_mask = _prepare_4d_causal_attention_mask(\n                attention_mask,\n                (batch_size, seq_length),\n                inputs_embeds,\n                past_key_values_length,\n            )\n\n        # embed positions\n        hidden_states = inputs_embeds\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        next_decoder_cache = None\n\n        for decoder_layer in self.layers:\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            layer_outputs = decoder_layer(\n                hidden_states,\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n                past_key_value=past_key_values,\n                output_attentions=output_attentions,\n                use_cache=use_cache,\n                cache_position=cache_position,\n            )\n\n            hidden_states = layer_outputs[0]\n\n            if use_cache:\n                next_decoder_cache = layer_outputs[2 if output_attentions else 1]\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n        hidden_states = self.norm(hidden_states)\n\n        # add hidden states from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += 
(hidden_states,)\n\n        next_cache = None\n        if use_cache:\n            next_cache = (\n                next_decoder_cache.to_legacy_cache()\n                if use_legacy_cache\n                else next_decoder_cache\n            )\n        if not return_dict:\n            return tuple(\n                v\n                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]\n                if v is not None\n            )\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=next_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n        )\n\n    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask\n    def _update_causal_mask(\n        self,\n        attention_mask: torch.Tensor,\n        input_tensor: torch.Tensor,\n        cache_position: torch.Tensor,\n        past_key_values: Cache,\n        output_attentions: bool,\n    ):\n        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static\n        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.\n        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using\n        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114\n\n        if self.config._attn_implementation == \"flash_attention_2\":\n            if attention_mask is not None and 0.0 in attention_mask:\n                return attention_mask\n            return None\n\n        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in\n        # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail\n        # to infer the attention mask.\n        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n        using_static_cache = isinstance(past_key_values, StaticCache)\n\n        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward\n        if self.config._attn_implementation == \"sdpa\" and not using_static_cache and not output_attentions:\n            if AttentionMaskConverter._ignore_causal_mask_sdpa(\n                attention_mask,\n                inputs_embeds=input_tensor,\n                past_key_values_length=past_seen_tokens,\n                is_training=self.training,\n            ):\n                return None\n\n        dtype, device = input_tensor.dtype, input_tensor.device\n        min_dtype = torch.finfo(dtype).min\n        sequence_length = input_tensor.shape[1]\n        if using_static_cache:\n            target_length = past_key_values.get_max_length()\n        else:\n            target_length = (\n                attention_mask.shape[-1]\n                if isinstance(attention_mask, torch.Tensor)\n                else past_seen_tokens + sequence_length + 1\n            )\n\n        if attention_mask is not None and attention_mask.dim() == 4:\n            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing\n            if attention_mask.max() != 0:\n                raise ValueError(\"Custom 4D attention mask should be passed in inverted form with max==0`\")\n            causal_mask = attention_mask\n        else:\n            causal_mask = torch.full(\n                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device\n            )\n            if sequence_length != 1:\n                causal_mask = torch.triu(causal_mask, diagonal=1)\n            causal_mask *= torch.arange(target_length, 
device=device) > cache_position.reshape(-1, 1)\n            causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)\n            if attention_mask is not None:\n                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit\n                mask_length = attention_mask.shape[-1]\n                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]\n                padding_mask = padding_mask == 0\n                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(\n                    padding_mask, min_dtype\n                )\n        if (\n            self.config._attn_implementation == \"sdpa\"\n            and attention_mask is not None\n            and attention_mask.device.type == \"cuda\"\n            and not output_attentions\n        ):\n            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when\n            # using left padding. 
This is required by F.scaled_dot_product_attention memory-efficient attention path.\n            # Details: https://github.com/pytorch/pytorch/issues/110213\n            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)\n\n        return causal_mask\n\nclass DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):\n    _tied_weights_keys = [\"lm_head.weight\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = DeepseekV3Model(config)\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def set_decoder(self, decoder):\n        self.model = decoder\n\n    def get_decoder(self):\n        return self.model\n\n    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)\n    @replace_return_docstrings(\n        output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC\n    )\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = 
None,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM\n\n        >>> model = DeepseekV3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)\n        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)\n\n        >>> prompt = \"Hey, are you conscious? Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you conscious? 
Can you talk to me?\\nI'm not conscious, but I can talk to you.\"\n        ```\"\"\"\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            cache_position=cache_position,\n        )\n\n        hidden_states = outputs[0]\n        # logits = self.lm_head(hidden_states[:,-1:,:])\n        logits = self.lm_head(hidden_states)\n        logits = logits.float()\n\n        loss = None\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            shift_logits = shift_logits.view(-1, self.config.vocab_size)\n            shift_labels = shift_labels.view(-1)\n            # Enable model parallelism\n            shift_labels = shift_labels.to(shift_logits.device)\n            loss = loss_fct(shift_logits, shift_labels)\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            return (loss,) + output if loss is 
not None else output\n\n        return CausalLMOutputWithPast(\n            loss=loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n    def prepare_inputs_for_generation(\n        self,\n        input_ids,\n        past_key_values=None,\n        attention_mask=None,\n        inputs_embeds=None,\n        **kwargs,\n    ):\n        if past_key_values is not None:\n            if isinstance(past_key_values, Cache):\n                cache_length = past_key_values.get_seq_length()\n                past_length = past_key_values.seen_tokens\n                max_cache_length = past_key_values.get_max_length()\n            else:\n                cache_length = past_length = past_key_values[0][0].shape[2]\n                max_cache_length = None\n\n            # Keep only the unprocessed tokens:\n            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where\n            # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as\n            # input)\n            if (\n                attention_mask is not None\n                and attention_mask.shape[1] > input_ids.shape[1]\n            ):\n                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]\n            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. 
We can discard\n            # input_ids based on the past_length.\n            elif past_length < input_ids.shape[1]:\n                input_ids = input_ids[:, past_length:]\n            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.\n\n            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.\n            if (\n                max_cache_length is not None\n                and attention_mask is not None\n                and cache_length + input_ids.shape[1] > max_cache_length\n            ):\n                attention_mask = attention_mask[:, -max_cache_length:]\n\n        position_ids = kwargs.get(\"position_ids\", None)\n        if attention_mask is not None and position_ids is None:\n            # create position_ids on the fly for batch generation\n            position_ids = attention_mask.long().cumsum(-1) - 1\n            position_ids.masked_fill_(attention_mask == 0, 1)\n            if past_key_values:\n                position_ids = position_ids[:, -input_ids.shape[1] :]\n\n        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step\n        if inputs_embeds is not None and past_key_values is None:\n            model_inputs = {\"inputs_embeds\": inputs_embeds}\n        else:\n            model_inputs = {\"input_ids\": input_ids}\n\n        model_inputs.update(\n            {\n                \"position_ids\": position_ids,\n                \"past_key_values\": past_key_values,\n                \"use_cache\": kwargs.get(\"use_cache\"),\n                \"attention_mask\": attention_mask,\n            }\n        )\n        return model_inputs\n\n    @staticmethod\n    def _reorder_cache(past_key_values, beam_idx):\n        reordered_past = ()\n        for layer_past in past_key_values:\n            reordered_past += (\n                tuple(\n                    past_state.index_select(0, 
beam_idx.to(past_state.device))\n                    for past_state in layer_past\n                ),\n            )\n        return reordered_past\n\n\n@add_start_docstrings(\n    \"\"\"\n    The DeepseekV3 Model transformer with a sequence classification head on top (linear layer).\n\n    [`DeepseekV3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models\n    (e.g. GPT-2) do.\n\n    Since it does classification on the last token, it requires to know the position of the last token. If a\n    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If\n    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the\n    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in\n    each row of the batch).\n    \"\"\",\n    DeepseekV3_START_DOCSTRING,\n)\nclass DeepseekV3ForSequenceClassification(DeepseekV3PreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.model = DeepseekV3Model(config)\n        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(DeepseekV3_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        
use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the sequence classification/regression loss. Indices should be in `[0, transformers.,\n            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        transformer_outputs = self.model(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = transformer_outputs[0]\n        logits = self.score(hidden_states)\n\n        if input_ids is not None:\n            batch_size = input_ids.shape[0]\n        else:\n            batch_size = inputs_embeds.shape[0]\n\n        if self.config.pad_token_id is None and batch_size != 1:\n            raise ValueError(\n                \"Cannot handle batch sizes > 1 if no padding token is defined.\"\n            )\n        if self.config.pad_token_id is None:\n            sequence_lengths = -1\n        else:\n            if input_ids is not None:\n                sequence_lengths = (\n                    torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1\n                ).to(logits.device)\n            else:\n                
sequence_lengths = -1\n\n        pooled_logits = logits[\n            torch.arange(batch_size, device=logits.device), sequence_lengths\n        ]\n\n        loss = None\n        if labels is not None:\n            labels = labels.to(logits.device)\n            if self.config.problem_type is None:\n                if self.num_labels == 1:\n                    self.config.problem_type = \"regression\"\n                elif self.num_labels > 1 and (\n                    labels.dtype == torch.long or labels.dtype == torch.int\n                ):\n                    self.config.problem_type = \"single_label_classification\"\n                else:\n                    self.config.problem_type = \"multi_label_classification\"\n\n            if self.config.problem_type == \"regression\":\n                loss_fct = MSELoss()\n                if self.num_labels == 1:\n                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())\n                else:\n                    loss = loss_fct(pooled_logits, labels)\n            elif self.config.problem_type == \"single_label_classification\":\n                loss_fct = CrossEntropyLoss()\n                loss = loss_fct(\n                    pooled_logits.view(-1, self.num_labels), labels.view(-1)\n                )\n            elif self.config.problem_type == \"multi_label_classification\":\n                loss_fct = BCEWithLogitsLoss()\n                loss = loss_fct(pooled_logits, labels)\n        if not return_dict:\n            output = (pooled_logits,) + transformer_outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return SequenceClassifierOutputWithPast(\n            loss=loss,\n            logits=pooled_logits,\n            past_key_values=transformer_outputs.past_key_values,\n            hidden_states=transformer_outputs.hidden_states,\n            attentions=transformer_outputs.attentions,\n        )\n"
  },
  {
    "path": "kt-sft/ktransformers/models/modeling_llama.py",
    "content": "# coding=utf-8\n# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nimport math\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n\nfrom transformers.activations import ACT2FN\nfrom transformers.cache_utils import Cache, DynamicCache, StaticCache\nfrom transformers.modeling_attn_mask_utils import AttentionMaskConverter\nfrom transformers.modeling_flash_attention_utils import _flash_attention_forward\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPast,\n    CausalLMOutputWithPast,\n    QuestionAnsweringModelOutput,\n    SequenceClassifierOutputWithPast,\n    TokenClassifierOutput,\n)\nfrom transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.pytorch_utils import ALL_LAYERNORM_LAYERS\nfrom transformers.utils import (\n    add_start_docstrings,\n    add_start_docstrings_to_model_forward,\n    
is_flash_attn_greater_or_equal_2_10,\n    logging,\n    replace_return_docstrings,\n)\nfrom .configuration_llama import LlamaConfig\n\nfrom ktransformers.util.grad_wrapper import maybe_no_grad\n\nlogger = logging.get_logger(__name__)\n\n_CONFIG_FOR_DOC = \"LlamaConfig\"\n\n\nclass LlamaRMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        LlamaRMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n\nALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)\n\n\nclass LlamaRotaryEmbedding(nn.Module):\n    def __init__(\n        self,\n        dim=None,\n        max_position_embeddings=2048,\n        base=10000,\n        device=None,\n        scaling_factor=1.0,\n        rope_type=\"default\",\n        config: Optional[LlamaConfig] = None,\n    ):\n        super().__init__()\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        self.device = device\n        self.scaling_factor = scaling_factor\n        self.rope_type = rope_type\n        self.config = config\n        # TODO (joao): remove the `if` below, only used for BC\n        self.rope_kwargs = {}\n        if config is None:\n            logger.warning_once(\n                \"`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the \"\n                \"`config` argument. 
All other arguments will be removed in v4.45\"\n            )\n            self.rope_kwargs = {\n                \"rope_type\": rope_type,\n                \"factor\": scaling_factor,\n                \"dim\": dim,\n                \"base\": base,\n                \"max_position_embeddings\": max_position_embeddings,\n            }\n            self.rope_type = rope_type\n            self.max_seq_len_cached = max_position_embeddings\n            self.original_max_seq_len = max_position_embeddings\n        else:\n            # BC: \"rope_type\" was originally \"type\"\n            if config.rope_scaling is not None:\n                self.rope_type = config.rope_scaling.get(\n                    \"rope_type\", config.rope_scaling.get(\"type\")\n                )\n            else:\n                self.rope_type = \"default\"\n            self.max_seq_len_cached = config.max_position_embeddings\n            self.original_max_seq_len = config.max_position_embeddings\n\n        self.config = config\n        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]\n\n        inv_freq, self.attention_scaling = self.rope_init_fn(\n            self.config, device, **self.rope_kwargs\n        )\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        self.original_inv_freq = self.inv_freq\n\n    def _dynamic_frequency_update(self, position_ids, device):\n        \"\"\"\n        dynamic RoPE layers should recompute `inv_freq` in the following situations:\n        1 - growing beyond the cached sequence length (allow scaling)\n        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)\n        \"\"\"\n        seq_len = torch.max(position_ids) + 1\n        # seq_len = position_ids[0, -1] + 1\n        if seq_len > self.max_seq_len_cached:  # growth\n            inv_freq, self.attention_scaling = self.rope_init_fn(\n                self.config, device, seq_len=seq_len, **self.rope_kwargs\n            )\n   
         self.register_buffer(\n                \"inv_freq\", inv_freq, persistent=False\n            )  # TODO joao: may break with compilation\n            self.max_seq_len_cached = seq_len\n\n        if (\n            seq_len < self.original_max_seq_len\n            and self.max_seq_len_cached > self.original_max_seq_len\n        ):  # reset\n            self.register_buffer(\"inv_freq\", self.original_inv_freq, persistent=False)\n            self.max_seq_len_cached = self.original_max_seq_len\n\n    @maybe_no_grad()\n    def forward(self, x, position_ids):\n        # if \"dynamic\" in self.rope_type:\n        #     self._dynamic_frequency_update(position_ids, device=x.device)\n\n        # Core RoPE block\n        inv_freq_expanded = (\n            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)\n        )\n        position_ids_expanded = position_ids[:, None, :].float()\n        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)\n        device_type = x.device.type\n        device_type = (\n            device_type\n            if isinstance(device_type, str) and device_type != \"mps\"\n            else \"cpu\"\n        )\n        with torch.autocast(device_type=device_type, enabled=False):\n            freqs = (\n                inv_freq_expanded.float() @ position_ids_expanded.float()\n            ).transpose(1, 2)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos()\n            sin = emb.sin()\n\n        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention\n        cos = cos * self.attention_scaling\n        sin = sin * self.attention_scaling\n\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)\n\n\nclass LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):\n    \"\"\"LlamaRotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev\"\"\"\n\n    def __init__(self, *args, **kwargs):\n        logger.warning_once(\n            \"`LlamaLinearScalingRotaryEmbedding` is deprecated and will be removed in v4.45. Please use \"\n            \"`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__).\"\n        )\n        kwargs[\"rope_type\"] = \"linear\"\n        super().__init__(*args, **kwargs)\n\n\nclass LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):\n    \"\"\"LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla\"\"\"\n\n    def __init__(self, *args, **kwargs):\n        logger.warning_once(\n            \"`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated and will be removed in v4.45. Please use \"\n            \"`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to \"\n            \"__init__).\"\n        )\n        kwargs[\"rope_type\"] = \"dynamic\"\n        super().__init__(*args, **kwargs)\n\n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n    Args:\n        q (`torch.Tensor`): The query tensor.\n        k (`torch.Tensor`): The key tensor.\n        cos (`torch.Tensor`): The cosine part of the rotary embedding.\n        sin (`torch.Tensor`): The sine part of the rotary embedding.\n        position_ids (`torch.Tensor`, *optional*):\n            Deprecated and unused.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n            sin[position_ids] so that they can be properly 
broadcasted to the dimensions of q and k. For example, note\n            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos.unsqueeze(unsqueeze_dim)\n    sin = sin.unsqueeze(unsqueeze_dim)\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\nclass LlamaMLP(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.intermediate_size = config.intermediate_size\n        self.gate_proj = nn.Linear(\n            self.hidden_size, self.intermediate_size, bias=config.mlp_bias\n        )\n        self.up_proj = nn.Linear(\n            self.hidden_size, self.intermediate_size, bias=config.mlp_bias\n        )\n        self.down_proj = nn.Linear(\n            self.intermediate_size, self.hidden_size, bias=config.mlp_bias\n        )\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        if self.config.pretraining_tp > 1:\n            slice = self.intermediate_size // self.config.pretraining_tp\n            gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)\n            up_proj_slices = self.up_proj.weight.split(slice, dim=0)\n            down_proj_slices = self.down_proj.weight.split(slice, dim=1)\n\n            gate_proj = torch.cat(\n                [\n                    F.linear(x, gate_proj_slices[i])\n                    for i in 
range(self.config.pretraining_tp)\n                ],\n                dim=-1,\n            )\n            up_proj = torch.cat(\n                [\n                    F.linear(x, up_proj_slices[i])\n                    for i in range(self.config.pretraining_tp)\n                ],\n                dim=-1,\n            )\n\n            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)\n            down_proj = [\n                F.linear(intermediate_states[i], down_proj_slices[i])\n                for i in range(self.config.pretraining_tp)\n            ]\n            down_proj = sum(down_proj)\n        else:\n            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n\n        return down_proj\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(\n        batch, num_key_value_heads, n_rep, slen, head_dim\n    )\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\nclass LlamaAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):\n        super().__init__()\n        self.config = config\n        self.layer_idx = layer_idx\n        if layer_idx is None:\n            logger.warning_once(\n                f\"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will \"\n                \"lead to errors during the forward call if caching is used. 
Please make sure to provide a `layer_idx` \"\n                \"when creating this class.\"\n            )\n\n        self.attention_dropout = config.attention_dropout\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.hidden_size // self.num_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.num_key_value_groups = self.num_heads // self.num_key_value_heads\n        self.max_position_embeddings = config.max_position_embeddings\n        self.rope_theta = config.rope_theta\n        self.is_causal = True\n\n        if (self.head_dim * self.num_heads) != self.hidden_size:\n            raise ValueError(\n                f\"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}\"\n                f\" and `num_heads`: {self.num_heads}).\"\n            )\n\n        self.q_proj = nn.Linear(\n            self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias\n        )\n        self.k_proj = nn.Linear(\n            self.hidden_size,\n            self.num_key_value_heads * self.head_dim,\n            bias=config.attention_bias,\n        )\n        self.v_proj = nn.Linear(\n            self.hidden_size,\n            self.num_key_value_heads * self.head_dim,\n            bias=config.attention_bias,\n        )\n        self.o_proj = nn.Linear(\n            self.hidden_size, self.hidden_size, bias=config.attention_bias\n        )\n\n        # TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers)\n        self.rotary_emb = LlamaRotaryEmbedding(config=self.config)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        
cache_position: Optional[torch.LongTensor] = None,\n        position_embeddings: Optional[\n            Tuple[torch.Tensor, torch.Tensor]\n        ] = None,  # will become mandatory in v4.45\n        **kwargs,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        bsz, q_len, _ = hidden_states.size()\n\n        if self.config.pretraining_tp > 1:\n            key_value_slicing = (\n                self.num_key_value_heads * self.head_dim\n            ) // self.config.pretraining_tp\n            query_slices = self.q_proj.weight.split(\n                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0\n            )\n            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)\n            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)\n\n            query_states = [\n                F.linear(hidden_states, query_slices[i])\n                for i in range(self.config.pretraining_tp)\n            ]\n            query_states = torch.cat(query_states, dim=-1)\n\n            key_states = [\n                F.linear(hidden_states, key_slices[i])\n                for i in range(self.config.pretraining_tp)\n            ]\n            key_states = torch.cat(key_states, dim=-1)\n\n            value_states = [\n                F.linear(hidden_states, value_slices[i])\n                for i in range(self.config.pretraining_tp)\n            ]\n            value_states = torch.cat(value_states, dim=-1)\n\n        else:\n            query_states = self.q_proj(hidden_states)\n            key_states = self.k_proj(hidden_states)\n            value_states = self.v_proj(hidden_states)\n\n        query_states = query_states.view(\n            bsz, q_len, self.num_heads, self.head_dim\n        ).transpose(1, 2)\n        key_states = key_states.view(\n            bsz, q_len, self.num_key_value_heads, self.head_dim\n        ).transpose(1, 2)\n        value_states = value_states.view(\n            
bsz, q_len, self.num_key_value_heads, self.head_dim\n        ).transpose(1, 2)\n\n        if position_embeddings is None:\n            logger.warning_once(\n                \"The attention layers in this model are transitioning from computing the RoPE embeddings internally \"\n                \"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed \"\n                \"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be \"\n                \"removed and `position_embeddings` will be mandatory.\"\n            )\n            cos, sin = self.rotary_emb(value_states, position_ids)\n        else:\n            cos, sin = position_embeddings\n        query_states, key_states = apply_rotary_pos_emb(\n            query_states, key_states, cos, sin\n        )\n\n        if past_key_value is not None:\n            # sin and cos are specific to RoPE models; cache_position needed for the static cache\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}\n            key_states, value_states = past_key_value.update(\n                key_states, value_states, self.layer_idx, cache_kwargs\n            )\n\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n        attn_weights = torch.matmul(\n            query_states, key_states.transpose(2, 3)\n        ) / math.sqrt(self.head_dim)\n\n        if attention_mask is not None:  # no matter the length, we just slice it\n            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]\n            attn_weights = attn_weights + causal_mask\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(\n            attn_weights, dim=-1, dtype=torch.float32\n        ).to(query_states.dtype)\n        attn_weights = nn.functional.dropout(\n            attn_weights, p=self.attention_dropout, 
training=self.training\n        )\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n\n        attn_output = attn_output.reshape(bsz, q_len, -1)\n\n        if self.config.pretraining_tp > 1:\n            attn_output = attn_output.split(\n                self.hidden_size // self.config.pretraining_tp, dim=2\n            )\n            o_proj_slices = self.o_proj.weight.split(\n                self.hidden_size // self.config.pretraining_tp, dim=1\n            )\n            attn_output = sum(\n                [\n                    F.linear(attn_output[i], o_proj_slices[i])\n                    for i in range(self.config.pretraining_tp)\n                ]\n            )\n        else:\n            attn_output = self.o_proj(attn_output)\n\n        if not output_attentions:\n            attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n\nclass LlamaFlashAttention2(LlamaAttention):\n    \"\"\"\n    Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stay\n    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of\n    flash attention and deal with padding tokens in case the input contains any of them.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n        # TODO: Should be removed once Flash Attention for ROCm is bumped to 2.1.\n        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made default for flash_attn>=2.1. 
This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.\n        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).\n        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.LongTensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        position_embeddings: Optional[\n            Tuple[torch.Tensor, torch.Tensor]\n        ] = None,  # will become mandatory in v4.45\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        if isinstance(past_key_value, StaticCache):\n            raise ValueError(\n                \"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` \"\n                \"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers\"\n            )\n\n        output_attentions = False\n\n        bsz, q_len, _ = hidden_states.size()\n\n        query_states = self.q_proj(hidden_states)\n        key_states = self.k_proj(hidden_states)\n        value_states = self.v_proj(hidden_states)\n\n        # Flash attention requires the input to have the shape\n        # batch_size x seq_length x head_dim x hidden_dim\n        # therefore we just need to keep the original shape\n        query_states = query_states.view(\n            bsz, q_len, self.num_heads, self.head_dim\n        ).transpose(1, 2)\n        key_states = key_states.view(\n            bsz, q_len, self.num_key_value_heads, self.head_dim\n        ).transpose(1, 2)\n        
value_states = value_states.view(\n            bsz, q_len, self.num_key_value_heads, self.head_dim\n        ).transpose(1, 2)\n\n        if position_embeddings is None:\n            logger.warning_once(\n                \"The attention layers in this model are transitioning from computing the RoPE embeddings internally \"\n                \"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed \"\n                \"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be \"\n                \"removed and `position_embeddings` will be mandatory.\"\n            )\n            cos, sin = self.rotary_emb(value_states, position_ids)\n        else:\n            cos, sin = position_embeddings\n        query_states, key_states = apply_rotary_pos_emb(\n            query_states, key_states, cos, sin\n        )\n\n        if past_key_value is not None:\n            # sin and cos are specific to RoPE models; cache_position needed for the static cache\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}\n            key_states, value_states = past_key_value.update(\n                key_states, value_states, self.layer_idx, cache_kwargs\n            )\n\n        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache\n        # to be able to avoid many of these transpose/reshape/view.\n        query_states = query_states.transpose(1, 2)\n        key_states = key_states.transpose(1, 2)\n        value_states = value_states.transpose(1, 2)\n\n        dropout_rate = self.attention_dropout if self.training else 0.0\n\n        # In PEFT, usually we cast the layer norms in float32 for training stability reasons\n        # therefore the input hidden states gets silently casted in float32. 
Hence, we need to\n        # cast them back to the correct dtype just to be sure everything works as expected.\n        # This might slow down training & inference so it is recommended to not cast the LayerNorms\n        # to fp32. (LlamaRMSNorm handles it correctly)\n\n        input_dtype = query_states.dtype\n        if input_dtype == torch.float32:\n            if torch.is_autocast_enabled():\n                target_dtype = torch.get_autocast_gpu_dtype()\n            # Handle the case where the model is quantized\n            elif hasattr(self.config, \"_pre_quantization_dtype\"):\n                target_dtype = self.config._pre_quantization_dtype\n            else:\n                target_dtype = self.q_proj.weight.dtype\n\n            logger.warning_once(\n                f\"The input hidden states seem to have been silently cast to float32, this might be related to\"\n                f\" the fact you have upcast embedding or layer norm layers to float32. We will cast the input back to\"\n                f\" {target_dtype}.\"\n            )\n\n            query_states = query_states.to(target_dtype)\n            key_states = key_states.to(target_dtype)\n            value_states = value_states.to(target_dtype)\n\n        attn_output = _flash_attention_forward(\n            query_states,\n            key_states,\n            value_states,\n            attention_mask,\n            q_len,\n            dropout=dropout_rate,\n            sliding_window=getattr(self, \"sliding_window\", None),\n            use_top_left_mask=self._flash_attn_uses_top_left_mask,\n            is_causal=self.is_causal,\n        )\n\n        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()\n        attn_output = self.o_proj(attn_output)\n\n        if not output_attentions:\n            attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n\nclass LlamaSdpaAttention(LlamaAttention):\n    \"\"\"\n    Llama attention module using 
torch.nn.functional.scaled_dot_product_attention. This module inherits from\n    `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to\n    SDPA API.\n    \"\"\"\n\n    # Adapted from LlamaAttention.forward\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        position_embeddings: Optional[\n            Tuple[torch.Tensor, torch.Tensor]\n        ] = None,  # will become mandatory in v4.45\n        **kwargs,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        if output_attentions:\n            # TODO: Improve this warning with e.g. `model.config.attn_implementation = \"manual\"` once this is implemented.\n            logger.warning_once(\n                \"LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, \"\n                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. 
This warning can be removed using the argument `attn_implementation=\"eager\"` when loading the model.'\n            )\n            return super().forward(\n                hidden_states=hidden_states,\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n                past_key_value=past_key_value,\n                output_attentions=output_attentions,\n                use_cache=use_cache,\n                cache_position=cache_position,\n                position_embeddings=position_embeddings,\n            )\n\n        bsz, q_len, _ = hidden_states.size()\n\n        query_states = self.q_proj(hidden_states)\n        key_states = self.k_proj(hidden_states)\n        value_states = self.v_proj(hidden_states)\n\n        query_states = query_states.view(\n            bsz, q_len, self.num_heads, self.head_dim\n        ).transpose(1, 2)\n        key_states = key_states.view(\n            bsz, q_len, self.num_key_value_heads, self.head_dim\n        ).transpose(1, 2)\n        value_states = value_states.view(\n            bsz, q_len, self.num_key_value_heads, self.head_dim\n        ).transpose(1, 2)\n\n        if position_embeddings is None:\n            logger.warning_once(\n                \"The attention layers in this model are transitioning from computing the RoPE embeddings internally \"\n                \"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed \"\n                \"`position_embeddings` (Tuple of tensors, containing cos and sin). 
In v4.45 `position_ids` will be \"\n                \"removed and `position_embeddings` will be mandatory.\"\n            )\n            cos, sin = self.rotary_emb(value_states, position_ids)\n        else:\n            cos, sin = position_embeddings\n        query_states, key_states = apply_rotary_pos_emb(\n            query_states, key_states, cos, sin\n        )\n\n        if past_key_value is not None:\n            # sin and cos are specific to RoPE models; cache_position needed for the static cache\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}\n            key_states, value_states = past_key_value.update(\n                key_states, value_states, self.layer_idx, cache_kwargs\n            )\n\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n        causal_mask = attention_mask\n        if attention_mask is not None:\n            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]\n\n        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,\n        # Reference: https://github.com/pytorch/pytorch/issues/112577.\n        if query_states.device.type == \"cuda\" and causal_mask is not None:\n            query_states = query_states.contiguous()\n            key_states = key_states.contiguous()\n            value_states = value_states.contiguous()\n\n        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment\n        # in SDPA to support both torch.compile's dynamic shapes and full graph options. 
An inline conditional prevents dynamic shapes from compiling.\n        is_causal = True if causal_mask is None and q_len > 1 else False\n\n        attn_output = torch.nn.functional.scaled_dot_product_attention(\n            query_states,\n            key_states,\n            value_states,\n            attn_mask=causal_mask,\n            dropout_p=self.attention_dropout if self.training else 0.0,\n            is_causal=is_causal,\n        )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.view(bsz, q_len, -1)\n\n        attn_output = self.o_proj(attn_output)\n\n        return attn_output, None, past_key_value\n\n\nLLAMA_ATTENTION_CLASSES = {\n    \"eager\": LlamaAttention,\n    \"flash_attention_2\": LlamaFlashAttention2,\n    \"sdpa\": LlamaSdpaAttention,\n}\n\n\nclass LlamaDecoderLayer(nn.Module):\n    def __init__(self, config: LlamaConfig, layer_idx: int):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n\n        self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](\n            config=config, layer_idx=layer_idx\n        )\n\n        self.mlp = LlamaMLP(config)\n        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_attention_layernorm = LlamaRMSNorm(\n            config.hidden_size, eps=config.rms_norm_eps\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        position_embeddings: Optional[\n            Tuple[torch.Tensor, torch.Tensor]\n        ] = None,  # will become mandatory in v4.45\n        **kwargs,\n    ) -> Tuple[\n        torch.FloatTensor, 
Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]\n    ]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*):\n                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,\n                query_sequence_length, key_sequence_length)` if default attention is used.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):\n                Indices depicting the position of the input sequence tokens in the sequence\n            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):\n                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,\n                with `head_dim` being the embedding dimension of each attention head.\n            kwargs (`dict`, *optional*):\n                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code\n                into the model\n        \"\"\"\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            
position_ids=position_ids,\n            past_key_value=past_key_value,\n            output_attentions=output_attentions,\n            use_cache=use_cache,\n            cache_position=cache_position,\n            position_embeddings=position_embeddings,\n            **kwargs,\n        )\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (self_attn_weights,)\n\n        if use_cache:\n            outputs += (present_key_value,)\n\n        return outputs\n\n\nLLAMA_START_DOCSTRING = r\"\"\"\n    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the\n    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads\n    etc.)\n\n    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage\n    and behavior.\n\n    Parameters:\n        config ([`LlamaConfig`]):\n            Model configuration class with all the parameters of the model. Initializing with a config file does not\n            load the weights associated with the model, only the configuration. 
Check out the\n            [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare LLaMA Model outputting raw hidden-states without any specific head on top.\",\n    LLAMA_START_DOCSTRING,\n)\nclass LlamaPreTrainedModel(PreTrainedModel):\n    config_class = LlamaConfig\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"LlamaDecoderLayer\"]\n    _skip_keys_device_placement = [\"past_key_values\"]\n    _supports_flash_attn_2 = True\n    _supports_sdpa = True\n    _supports_cache_class = True\n    _supports_quantized_cache = True\n    _supports_static_cache = True\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n\n\nLLAMA_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. 
Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see\n            `past_key_values`).\n\n            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]\n            and modify it to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more\n            information on the default strategy.\n        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,\n            config.n_positions - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):\n            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention\n            blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`\n            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.\n\n            Two formats are allowed:\n            - a [`~cache_utils.Cache`] instance;\n            - Tuple of `tuple(torch.FloatTensor)` of length `config.num_hidden_layers`, with each tuple having 2 tensors of\n            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. 
This is also known as the legacy\n            cache format.\n\n            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the\n            legacy cache format will be returned.\n\n            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't\n            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`\n            of shape `(batch_size, sequence_length)`.\n        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This\n            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n            model's internal embedding lookup matrix.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n            `past_key_values`).\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for\n            more detail.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):\n            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,\n            this tensor is not affected by padding. 
It is used to update the cache in the correct position and to infer\n            the complete sequence length.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare LLaMA Model outputting raw hidden-states without any specific head on top.\",\n    LLAMA_START_DOCSTRING,\n)\nclass LlamaModel(LlamaPreTrainedModel):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]\n\n    Args:\n        config: LlamaConfig\n    \"\"\"\n\n    def __init__(self, config: LlamaConfig):\n        super().__init__(config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(\n            config.vocab_size, config.hidden_size, self.padding_idx\n        )\n        self.layers = nn.ModuleList(\n            [\n                LlamaDecoderLayer(config, layer_idx)\n                for layer_idx in range(config.num_hidden_layers)\n            ]\n        )\n        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.rotary_emb = LlamaRotaryEmbedding(config=config)\n        self.gradient_checkpointing = False\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: 
Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        if (input_ids is None) ^ (inputs_embeds is not None):\n            raise ValueError(\n                \"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one\"\n            )\n\n        if self.gradient_checkpointing and self.training and use_cache:\n            logger.warning_once(\n                \"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.\"\n            )\n            use_cache = False\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        return_legacy_cache = False\n        if (\n            use_cache and not isinstance(past_key_values, Cache) and not self.training\n        ):  # kept for BC (non `Cache` `past_key_values` inputs)\n            return_legacy_cache = True\n            past_key_values = DynamicCache.from_legacy_cache(past_key_values)\n            logger.warning_once(\n                \"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. 
\"\n                \"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\"\n            )\n\n        if cache_position is None:\n            past_seen_tokens = (\n                past_key_values.get_seq_length() if past_key_values is not None else 0\n            )\n            cache_position = torch.arange(\n                past_seen_tokens,\n                past_seen_tokens + inputs_embeds.shape[1],\n                device=inputs_embeds.device,\n            )\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0)\n\n        causal_mask = self._update_causal_mask(\n            attention_mask,\n            inputs_embeds,\n            cache_position,\n            past_key_values,\n            output_attentions,\n        )\n        hidden_states = inputs_embeds\n\n        # create position embeddings to be shared across the decoder layers\n        position_embeddings = self.rotary_emb(hidden_states, position_ids)\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        next_decoder_cache = None\n\n        for decoder_layer in self.layers:\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            if self.gradient_checkpointing and self.training:\n                layer_outputs = self._gradient_checkpointing_func(\n                    decoder_layer.__call__,\n                    hidden_states,\n                    causal_mask,\n                    position_ids,\n                    past_key_values,\n                    output_attentions,\n                    use_cache,\n                    cache_position,\n                    position_embeddings,\n                )\n            else:\n                layer_outputs = decoder_layer(\n                    hidden_states,\n                    
attention_mask=causal_mask,\n                    position_ids=position_ids,\n                    past_key_value=past_key_values,\n                    output_attentions=output_attentions,\n                    use_cache=use_cache,\n                    cache_position=cache_position,\n                    position_embeddings=position_embeddings,\n                )\n\n            hidden_states = layer_outputs[0]\n\n            if use_cache:\n                next_decoder_cache = layer_outputs[2 if output_attentions else 1]\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n        hidden_states = self.norm(hidden_states)\n\n        # add hidden states from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        next_cache = next_decoder_cache if use_cache else None\n        if return_legacy_cache:\n            next_cache = next_cache.to_legacy_cache()\n\n        if not return_dict:\n            return tuple(\n                v\n                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]\n                if v is not None\n            )\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=next_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n        )\n\n    def _update_causal_mask(\n        self,\n        attention_mask: torch.Tensor,\n        input_tensor: torch.Tensor,\n        cache_position: torch.Tensor,\n        past_key_values: Cache,\n        output_attentions: bool,\n    ):\n        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static\n        # KV cache is used. 
This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.\n        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using\n        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114\n\n        if self.config._attn_implementation == \"flash_attention_2\":\n            if attention_mask is not None and 0.0 in attention_mask:\n                return attention_mask\n            return None\n\n        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in\n        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail\n        # to infer the attention mask.\n        past_seen_tokens = (\n            past_key_values.get_seq_length() if past_key_values is not None else 0\n        )\n        using_static_cache = isinstance(past_key_values, StaticCache)\n\n        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward\n        if (\n            self.config._attn_implementation == \"sdpa\"\n            and not using_static_cache\n            and not output_attentions\n        ):\n            if AttentionMaskConverter._ignore_causal_mask_sdpa(\n                attention_mask,\n                inputs_embeds=input_tensor,\n                past_key_values_length=past_seen_tokens,\n                is_training=self.training,\n            ):\n                return None\n\n        dtype, device = input_tensor.dtype, input_tensor.device\n        min_dtype = torch.finfo(dtype).min\n        sequence_length = input_tensor.shape[1]\n        if using_static_cache:\n            target_length = past_key_values.get_max_length()\n        else:\n            target_length = (\n                attention_mask.shape[-1]\n                if 
isinstance(attention_mask, torch.Tensor)\n                else past_seen_tokens + sequence_length + 1\n            )\n\n        if attention_mask is not None and attention_mask.dim() == 4:\n            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing\n            if attention_mask.max() != 0:\n                raise ValueError(\n                    \"Custom 4D attention mask should be passed in inverted form with max==0\"\n                )\n            causal_mask = attention_mask\n        else:\n            causal_mask = torch.full(\n                (sequence_length, target_length),\n                fill_value=min_dtype,\n                dtype=dtype,\n                device=device,\n            )\n            if sequence_length != 1:\n                causal_mask = torch.triu(causal_mask, diagonal=1)\n            causal_mask *= torch.arange(\n                target_length, device=device\n            ) > cache_position.reshape(-1, 1)\n            causal_mask = causal_mask[None, None, :, :].expand(\n                input_tensor.shape[0], 1, -1, -1\n            )\n            if attention_mask is not None:\n                causal_mask = (\n                    causal_mask.clone()\n                )  # copy to contiguous memory for in-place edit\n                mask_length = attention_mask.shape[-1]\n                padding_mask = (\n                    causal_mask[:, :, :, :mask_length]\n                    + attention_mask[:, None, None, :]\n                )\n                padding_mask = padding_mask == 0\n                causal_mask[:, :, :, :mask_length] = causal_mask[\n                    :, :, :, :mask_length\n                ].masked_fill(padding_mask, min_dtype)\n        if (\n            self.config._attn_implementation == \"sdpa\"\n            and attention_mask is not None\n            and attention_mask.device.type == \"cuda\"\n            and not output_attentions\n        ):\n            # 
Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when\n            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.\n            # Details: https://github.com/pytorch/pytorch/issues/110213\n            causal_mask = AttentionMaskConverter._unmask_unattended(\n                causal_mask, min_dtype\n            )\n\n        return causal_mask\n\n\nclass LlamaForCausalLM(LlamaPreTrainedModel):\n    _tied_weights_keys = [\"lm_head.weight\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = LlamaModel(config)\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def set_decoder(self, decoder):\n        self.model = decoder\n\n    def get_decoder(self):\n        return self.model\n\n    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)\n    @replace_return_docstrings(\n        output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC\n    )\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n      
  output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, CausalLMOutputWithPast]:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, LlamaForCausalLM\n\n        >>> model = LlamaForCausalLM.from_pretrained(\"meta-llama/Llama-2-7b-hf\")\n        >>> tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Llama-2-7b-hf\")\n\n        >>> prompt = \"Hey, are you conscious? Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you conscious? 
Can you talk to me?\\nI'm not conscious, but I can talk to you.\"\n        ```\"\"\"\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            cache_position=cache_position,\n        )\n\n        hidden_states = outputs[0]\n        if self.config.pretraining_tp > 1:\n            lm_head_slices = self.lm_head.weight.split(\n                self.vocab_size // self.config.pretraining_tp, dim=0\n            )\n            logits = [\n                F.linear(hidden_states, lm_head_slices[i])\n                for i in range(self.config.pretraining_tp)\n            ]\n            logits = torch.cat(logits, dim=-1)\n        else:\n            logits = self.lm_head(hidden_states)\n        # logits = logits.float()\n\n        loss = None\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            shift_logits = shift_logits.view(-1, 
self.config.vocab_size)\n            shift_labels = shift_labels.view(-1)\n            # Enable model parallelism\n            shift_labels = shift_labels.to(shift_logits.device)\n            loss = loss_fct(shift_logits, shift_labels)\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            return (loss,) + output if loss is not None else output\n\n        return CausalLMOutputWithPast(\n            loss=loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n    def prepare_inputs_for_generation(\n        self,\n        input_ids,\n        past_key_values=None,\n        attention_mask=None,\n        inputs_embeds=None,\n        cache_position=None,\n        position_ids=None,\n        use_cache=True,\n        **kwargs,\n    ):\n        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens\n        # Exception 1: when passing input_embeds, input_ids may be missing entries\n        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here\n        if past_key_values is not None:\n            if inputs_embeds is not None:  # Exception 1\n                input_ids = input_ids[:, -cache_position.shape[0] :]\n            elif (\n                input_ids.shape[1] != cache_position.shape[0]\n            ):  # Default case (the \"else\", a no op, is Exception 2)\n                input_ids = input_ids[:, cache_position]\n\n        if attention_mask is not None and position_ids is None:\n            # create position_ids on the fly for batch generation\n            position_ids = attention_mask.long().cumsum(-1) - 1\n            position_ids.masked_fill_(attention_mask == 0, 1)\n            if past_key_values:\n                position_ids = position_ids[:, -input_ids.shape[1] :]\n\n        # if 
`inputs_embeds` are passed, we only want to use them in the 1st generation step\n        if inputs_embeds is not None and cache_position[0] == 0:\n            model_inputs = {\"inputs_embeds\": inputs_embeds}\n        else:\n            model_inputs = {\n                \"input_ids\": input_ids.contiguous()\n            }  # `contiguous()` needed for compilation use cases\n\n        model_inputs.update(\n            {\n                \"position_ids\": position_ids,\n                \"cache_position\": cache_position,\n                \"past_key_values\": past_key_values,\n                \"use_cache\": use_cache,\n                \"attention_mask\": attention_mask,\n            }\n        )\n        return model_inputs\n\n\n@add_start_docstrings(\n    \"\"\"\n    The LLaMa Model transformer with a sequence classification head on top (linear layer).\n\n    [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models\n    (e.g. GPT-2) do.\n\n    Since it does classification on the last token, it requires to know the position of the last token. If a\n    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If\n    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. 
Since it cannot guess the\n    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in\n    each row of the batch).\n    \"\"\",\n    LLAMA_START_DOCSTRING,\n)\nclass LlamaForSequenceClassification(LlamaPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.model = LlamaModel(config)\n        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n            config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        transformer_outputs = self.model(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = transformer_outputs[0]\n        logits = self.score(hidden_states)\n\n        if input_ids is not None:\n            batch_size = input_ids.shape[0]\n        else:\n            batch_size = inputs_embeds.shape[0]\n\n        if self.config.pad_token_id is None and batch_size != 1:\n            raise ValueError(\n                \"Cannot handle batch sizes > 1 if no padding token is defined.\"\n            )\n        if self.config.pad_token_id is None:\n            sequence_lengths = -1\n        else:\n            if input_ids is not None:\n                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility\n                sequence_lengths = (\n                    torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1\n                )\n                sequence_lengths = sequence_lengths % input_ids.shape[-1]\n                sequence_lengths = sequence_lengths.to(logits.device)\n            else:\n                sequence_lengths = -1\n\n        pooled_logits = logits[\n            torch.arange(batch_size, device=logits.device), sequence_lengths\n        ]\n\n        loss = None\n        if labels is not None:\n            labels = labels.to(logits.device)\n            
if self.config.problem_type is None:\n                if self.num_labels == 1:\n                    self.config.problem_type = \"regression\"\n                elif self.num_labels > 1 and (\n                    labels.dtype == torch.long or labels.dtype == torch.int\n                ):\n                    self.config.problem_type = \"single_label_classification\"\n                else:\n                    self.config.problem_type = \"multi_label_classification\"\n\n            if self.config.problem_type == \"regression\":\n                loss_fct = MSELoss()\n                if self.num_labels == 1:\n                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())\n                else:\n                    loss = loss_fct(pooled_logits, labels)\n            elif self.config.problem_type == \"single_label_classification\":\n                loss_fct = CrossEntropyLoss()\n                loss = loss_fct(\n                    pooled_logits.view(-1, self.num_labels), labels.view(-1)\n                )\n            elif self.config.problem_type == \"multi_label_classification\":\n                loss_fct = BCEWithLogitsLoss()\n                loss = loss_fct(pooled_logits, labels)\n        if not return_dict:\n            output = (pooled_logits,) + transformer_outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return SequenceClassifierOutputWithPast(\n            loss=loss,\n            logits=pooled_logits,\n            past_key_values=transformer_outputs.past_key_values,\n            hidden_states=transformer_outputs.hidden_states,\n            attentions=transformer_outputs.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\nThe Llama Model transformer with a span classification head on top for extractive question-answering tasks like\nSQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).\n    \"\"\",\n    
LLAMA_START_DOCSTRING,\n)\nclass LlamaForQuestionAnswering(LlamaPreTrainedModel):\n    base_model_prefix = \"transformer\"\n\n    # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama\n    def __init__(self, config):\n        super().__init__(config)\n        self.transformer = LlamaModel(config)\n        self.qa_outputs = nn.Linear(config.hidden_size, 2)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.transformer.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.transformer.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        start_positions: Optional[torch.LongTensor] = None,\n        end_positions: Optional[torch.LongTensor] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, QuestionAnsweringModelOutput]:\n        r\"\"\"\n        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for position (index) of the start of the labelled span for computing the token classification loss.\n            Positions are clamped to the length of the sequence (`sequence_length`). 
Position outside of the sequence\n            are not taken into account for computing the loss.\n        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for position (index) of the end of the labelled span for computing the token classification loss.\n            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence\n            are not taken into account for computing the loss.\n        \"\"\"\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        outputs = self.transformer(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        sequence_output = outputs[0]\n\n        logits = self.qa_outputs(sequence_output)\n        start_logits, end_logits = logits.split(1, dim=-1)\n        start_logits = start_logits.squeeze(-1).contiguous()\n        end_logits = end_logits.squeeze(-1).contiguous()\n\n        total_loss = None\n        if start_positions is not None and end_positions is not None:\n            # If we are on multi-GPU, split add a dimension\n            if len(start_positions.size()) > 1:\n                start_positions = start_positions.squeeze(-1).to(start_logits.device)\n            if len(end_positions.size()) > 1:\n                end_positions = end_positions.squeeze(-1).to(end_logits.device)\n            # sometimes the start/end positions are outside our model inputs, we ignore these terms\n            ignored_index = start_logits.size(1)\n            start_positions = start_positions.clamp(0, ignored_index)\n            end_positions = end_positions.clamp(0, ignored_index)\n\n   
         loss_fct = CrossEntropyLoss(ignore_index=ignored_index)\n            start_loss = loss_fct(start_logits, start_positions)\n            end_loss = loss_fct(end_logits, end_positions)\n            total_loss = (start_loss + end_loss) / 2\n\n        if not return_dict:\n            output = (start_logits, end_logits) + outputs[2:]\n            return ((total_loss,) + output) if total_loss is not None else output\n\n        return QuestionAnsweringModelOutput(\n            loss=total_loss,\n            start_logits=start_logits,\n            end_logits=end_logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states\n    output) e.g. for Named-Entity-Recognition (NER) tasks.\n    \"\"\",\n    LLAMA_START_DOCSTRING,\n)\nclass LlamaForTokenClassification(LlamaPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.model = LlamaModel(config)\n        if getattr(config, \"classifier_dropout\", None) is not None:\n            classifier_dropout = config.classifier_dropout\n        elif getattr(config, \"hidden_dropout\", None) is not None:\n            classifier_dropout = config.hidden_dropout\n        else:\n            classifier_dropout = 0.1\n        self.dropout = nn.Dropout(classifier_dropout)\n        self.score = nn.Linear(config.hidden_size, config.num_labels)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n  
      attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, TokenClassifierOutput]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        outputs = self.model(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        sequence_output = outputs[0]\n        sequence_output = self.dropout(sequence_output)\n        logits = self.score(sequence_output)\n\n        loss = None\n        if labels is not None:\n            loss_fct = CrossEntropyLoss()\n            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))\n\n        if not return_dict:\n            output = (logits,) + outputs[2:]\n            return ((loss,) + output) if loss is not None else output\n\n        return 
TokenClassifierOutput(\n            loss=loss,\n            logits=logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n"
  },
  {
    "path": "kt-sft/ktransformers/models/modeling_mixtral.py",
    "content": "# coding=utf-8\n'''\nDescription  : \nAuthor       : kkk1nak0\nDate         : 2024-07-29 02:58:57\nVersion      : 1.0.0\nLastEditors  : kkk1nak0\nLastEditTime : 2024-08-02 06:08:34\n'''\n\n# Adapted from \n# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py\n# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.\n# Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n#\n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch Mixtral model.\"\"\"\n\nimport inspect \nimport math\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n\nfrom transformers.activations import ACT2FN\nfrom transformers.cache_utils import Cache, DynamicCache, StaticCache\nfrom transformers.modeling_attn_mask_utils import (\n    AttentionMaskConverter,\n    _prepare_4d_causal_attention_mask,\n)\nfrom transformers.modeling_outputs import (\n    MoeCausalLMOutputWithPast,\n    MoeModelOutputWithPast,\n    
SequenceClassifierOutputWithPast,\n    TokenClassifierOutput,\n)\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13\nfrom transformers.utils import (\n    add_start_docstrings,\n    add_start_docstrings_to_model_forward,\n    is_flash_attn_2_available,\n    logging,\n    replace_return_docstrings,\n)\nfrom transformers.utils.import_utils import is_torch_fx_available\nfrom transformers.models.mixtral.configuration_mixtral import MixtralConfig\n\nfrom ktransformers.util.grad_wrapper import maybe_no_grad\n\nif is_flash_attn_2_available():\n    from flash_attn import flash_attn_varlen_func, flash_attn_func, flash_attn_with_kvcache\n    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa\n\n    _flash_supports_window_size = \"window_size\" in list(inspect.signature(flash_attn_func).parameters)\n\n# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.\n# It means that the function will not be traced through and simply appear as a node in the graph.\nif is_torch_fx_available():\n    if not is_torch_greater_or_equal_than_1_13:\n        import torch.fx\n\n    _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)\n\n\nlogger = logging.get_logger(__name__)\n\n_CONFIG_FOR_DOC = \"MixtralConfig\"\n\n\ndef load_balancing_loss_func(\n    gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None\n) -> float:\n    r\"\"\"\n    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.\n\n    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss\n    function presented in equations (4) - (6) of the paper. 
It aims at penalizing cases where the routing between\n    experts is too unbalanced.\n\n    Args:\n        gate_logits (Union[`torch.Tensor`, Tuple[`torch.Tensor`]]):\n            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of\n            shape [batch_size X sequence_length, num_experts].\n        attention_mask (`torch.Tensor`, *optional*):\n            The attention_mask used in the forward function, of\n            shape [batch_size X sequence_length] if not None.\n        num_experts (`int`, *optional*):\n            Number of experts\n        top_k (`int`, *optional*, defaults to 2):\n            Number of experts selected per token.\n\n    Returns:\n        The auxiliary loss.\n    \"\"\"\n    if gate_logits is None or not isinstance(gate_logits, tuple):\n        return 0\n\n    if isinstance(gate_logits, tuple):\n        compute_device = gate_logits[0].device\n        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)\n\n    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)\n\n    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)\n\n    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)\n\n    if attention_mask is None:\n        # Compute the percentage of tokens routed to each expert\n        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)\n\n        # Compute the average probability of routing to these experts\n        router_prob_per_expert = torch.mean(routing_weights, dim=0)\n    else:\n        batch_size, sequence_length = attention_mask.shape\n        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)\n\n        # Compute the mask that masks all padding tokens as 0 with the same shape as expert_mask\n        expert_attention_mask = (\n            attention_mask[None, :, :, None, None]\n            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))\n            .reshape(-1, top_k, num_experts)\n            .to(compute_device)\n      
  )\n\n        # Compute the percentage of tokens routed to each experts\n        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(\n            expert_attention_mask, dim=0\n        )\n\n        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert\n        router_per_expert_attention_mask = (\n            attention_mask[None, :, :, None]\n            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))\n            .reshape(-1, num_experts)\n            .to(compute_device)\n        )\n\n        # Compute the average probability of routing to these experts\n        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(\n            router_per_expert_attention_mask, dim=0\n        )\n\n    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))\n    return overall_loss * num_experts\n\n\n# Copied from transformers.models.llama.modeling_llama._get_unpad_data\ndef _get_unpad_data(attention_mask):\n    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)\n    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()\n    max_seqlen_in_batch = seqlens_in_batch.max().item()\n    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))\n    return (\n        indices,\n        cu_seqlens,\n        max_seqlen_in_batch,\n    )\n\n\n# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral\nclass MixtralRMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        MixtralRMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n        hidden_states = 
hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n    def extra_repr(self):\n        return f\"{tuple(self.weight.shape)}, eps={self.variance_epsilon}\"\n\n\n# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral\n# TODO @longjie no longer copied from Mistral after static cache\nclass MixtralRotaryEmbedding(nn.Module):\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):\n        super().__init__()\n        \n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n\n        # Build here to make `torch.jit.trace` work.\n        self.max_seq_len_cached = max_position_embeddings\n\n    @maybe_no_grad()\n    def forward(self, x, position_ids):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)\n        position_ids_expanded = position_ids[:, None, :].float()\n\n        device_type = x.device.type\n        device_type = device_type if isinstance(device_type, str) and device_type != \"mps\" else \"cpu\"\n        with torch.autocast(device_type=device_type, enabled=False):\n            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos()\n            sin = emb.sin()\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)\n\n\n# Copied from transformers.models.llama.modeling_llama.rotate_half\ndef rotate_half(x):\n    
\"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\n# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb\n# TODO @longjie no longer copied from Mistral after static cache\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n    Args:\n        q (`torch.Tensor`): The query tensor.\n        k (`torch.Tensor`): The key tensor.\n        cos (`torch.Tensor`): The cosine part of the rotary embedding.\n        sin (`torch.Tensor`): The sine part of the rotary embedding.\n        position_ids (`torch.Tensor`):\n            The position indices of the tokens corresponding to the query and key tensors. For example, this can be\n            used to pass offsetted position ids when working with a KV-cache.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. 
Similarly, if q and k have\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos.unsqueeze(unsqueeze_dim)\n    sin = sin.unsqueeze(unsqueeze_dim)\n\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\n# Copied from transformers.models.llama.modeling_llama.repeat_kv\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\n# copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral\n# TODO @longjie no longer copied from Mistral after static cache\nclass MixtralAttention(nn.Module):\n    \"\"\"\n    Multi-headed attention from 'Attention Is All You Need' paper. 
Modified to use sliding window attention: Longformer\n    and \"Generating Long Sequences with Sparse Transformers\".\n    \"\"\"\n\n    def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None):\n        super().__init__()\n        self.config = config\n        self.layer_idx = layer_idx\n        if layer_idx is None:\n            logger.warning_once(\n                f\"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will \"\n                \"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` \"\n                \"when creating this class.\"\n            )\n\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.hidden_size // self.num_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.num_key_value_groups = self.num_heads // self.num_key_value_heads\n        self.max_position_embeddings = config.max_position_embeddings\n        self.rope_theta = config.rope_theta\n        self.is_causal = True\n        self.attention_dropout = config.attention_dropout\n\n        if (self.head_dim * self.num_heads) != self.hidden_size:\n            raise ValueError(\n                f\"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}\"\n                f\" and `num_heads`: {self.num_heads}).\"\n            )\n        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)\n        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)\n        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)\n        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)\n\n        self.rotary_emb = MixtralRotaryEmbedding(\n            self.head_dim,\n            
max_position_embeddings=self.max_position_embeddings,\n            base=self.rope_theta,\n        )\n\n    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):\n        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        bsz, q_len, _ = hidden_states.size()\n\n        query_states = self.q_proj(hidden_states)\n        key_states = self.k_proj(hidden_states)\n        value_states = self.v_proj(hidden_states)\n\n        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n\n        kv_seq_len = key_states.shape[-2]\n        if past_key_value is not None:\n            if self.layer_idx is None:\n                raise ValueError(\n                    f\"The cache structure has changed since version v4.36. 
If you are using {self.__class__.__name__} \"\n                    \"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class \"\n                    \"with a layer index.\"\n                )\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\n        cos, sin = self.rotary_emb(value_states, position_ids)\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n\n        if past_key_value is not None:\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n        # repeat k/v heads if n_kv_heads < n_heads\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)\n\n        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is\"\n                f\" {attn_weights.size()}\"\n            )\n\n        if attention_mask is not None:  # no matter the length, we just slice it\n            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]\n            attn_weights = attn_weights + causal_mask\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)\n        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):\n            raise 
ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)\n\n        attn_output = self.o_proj(attn_output)\n\n        if not output_attentions:\n            attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n\n# copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral\n# TODO @longjie no longer copied from Mistral after static cache\nclass MixtralFlashAttention2(MixtralAttention):\n    \"\"\"\n    Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays\n    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of\n    flash attention and deal with padding tokens in case the input contains any of them.\n    \"\"\"\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n    ):\n        bsz, q_len, _ = hidden_states.size()\n\n        query_states = self.q_proj(hidden_states)\n        key_states = self.k_proj(hidden_states)\n        value_states = self.v_proj(hidden_states)\n\n        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n\n        kv_seq_len = 
key_states.shape[-2]\n        if past_key_value is not None:\n            if self.layer_idx is None:\n                raise ValueError(\n                    f\"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} \"\n                    \"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class \"\n                    \"with a layer index.\"\n                )\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\n\n        cos, sin = self.rotary_emb(value_states, position_ids)\n\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n\n        use_sliding_windows = (\n            _flash_supports_window_size\n            and getattr(self.config, \"sliding_window\", None) is not None\n            and kv_seq_len > self.config.sliding_window\n            and self.config.use_sliding_window\n        )\n\n        if not _flash_supports_window_size:\n            logger.warning_once(\n                \"The current flash attention version does not support sliding window attention, for a more memory efficient implementation\"\n                \" make sure to upgrade flash-attn library.\"\n            )\n\n        if past_key_value is not None:\n            # Activate slicing cache only if the config has a value `sliding_windows` attribute\n            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0\n            if (\n                getattr(self.config, \"sliding_window\", None) is not None\n                and kv_seq_len > self.config.sliding_window\n                and cache_has_contents\n            ):\n                slicing_tokens = 1 - self.config.sliding_window\n\n                past_key = past_key_value[self.layer_idx][0]\n                past_value = past_key_value[self.layer_idx][1]\n\n                past_key = past_key[:, :, slicing_tokens:, :].contiguous()\n              
  past_value = past_value[:, :, slicing_tokens:, :].contiguous()\n\n                if past_key.shape[-2] != self.config.sliding_window - 1:\n                    raise ValueError(\n                        f\"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got\"\n                        f\" {past_key.shape}\"\n                    )\n\n                if attention_mask is not None:\n                    attention_mask = attention_mask[:, slicing_tokens:]\n                    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)\n\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n            # we slice the states for static kv cache to be supported in FA2. Not sure it's a must as compile fails\n            # for bsz == 1, avoid using slice to capture cuda graph\n            if cache_position is not None and q_len > 1:\n                key_states = key_states[:, :, : cache_position[-1] + 1, :]\n                value_states = value_states[:, :, : cache_position[-1] + 1, :]\n\n        # repeat k/v heads if n_kv_heads < n_heads\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n        dropout_rate = 0.0 if not self.training else self.attention_dropout\n\n        # In PEFT, usually we cast the layer norms in float32 for training stability reasons\n        # therefore the input hidden states gets silently casted in float32. 
Hence, we need to\n        # cast them back to the expected dtype (e.g. float16) just to be sure everything works as expected.\n        input_dtype = query_states.dtype\n        if input_dtype == torch.float32:\n            if torch.is_autocast_enabled():\n                target_dtype = torch.get_autocast_gpu_dtype()\n            # Handle the case where the model is quantized\n            elif hasattr(self.config, \"_pre_quantization_dtype\"):\n                target_dtype = self.config._pre_quantization_dtype\n            else:\n                target_dtype = self.q_proj.weight.dtype\n\n            logger.warning_once(\n                f\"The input hidden states seems to be silently casted in float32, this might be related to\"\n                f\" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in\"\n                f\" {target_dtype}.\"\n            )\n\n            query_states = query_states.to(target_dtype)\n            key_states = key_states.to(target_dtype)\n            value_states = value_states.to(target_dtype)\n\n        # Reshape to the layout expected by Flash Attention: (batch, seq_len, num_heads, head_dim)\n        query_states = query_states.transpose(1, 2)\n        key_states = key_states.transpose(1, 2)\n        value_states = value_states.transpose(1, 2)\n\n        attn_output = self._flash_attention_forward(\n            query_states,\n            key_states,\n            value_states,\n            attention_mask,\n            q_len,\n            position_ids=position_ids,\n            dropout=dropout_rate,\n            sliding_window=getattr(self.config, \"sliding_window\", None),\n            is_causal=self.is_causal,\n        )\n\n        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()\n        attn_output = self.o_proj(attn_output)\n\n        if not output_attentions:\n            attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n    \n\n    def _flash_attention_forward(\n        self,\n        
query_states,\n        key_states,\n        value_states,\n        attention_mask,\n        q_len,\n        position_ids,\n        dropout,\n        sliding_window,\n        is_causal,\n        softmax_scale=None,\n    ):\n        \"\"\"\n        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token\n        first unpad the input, then computes the attention scores and pad the final attention scores.\n\n        Args:\n            query_states (`torch.Tensor`):\n                Input query states to be passed to Flash Attention API\n            key_states (`torch.Tensor`):\n                Input key states to be passed to Flash Attention API\n            value_states (`torch.Tensor`):\n                Input value states to be passed to Flash Attention API\n            attention_mask (`torch.Tensor`):\n                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the\n                position of padding tokens and 1 for the position of non-padding tokens.\n            dropout (`float`):\n                Attention dropout\n            \n        \"\"\"\n        \n        # Decide whether to use SWA or not by layer index.\n        # if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:\n        #     use_sliding_windows = False\n        use_sliding_windows = False\n\n        # Contains at least one padding token in the sequence\n        if attention_mask is not None:\n            batch_size = query_states.shape[0]\n            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(\n                query_states, key_states, value_states, attention_mask, q_len\n            )\n\n            cu_seqlens_q, cu_seqlens_k = cu_seq_lens\n            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens\n\n            if not use_sliding_windows:\n                attn_output_unpad = flash_attn_varlen_func(\n     
               query_states,\n                    key_states,\n                    value_states,\n                    cu_seqlens_q=cu_seqlens_q,\n                    cu_seqlens_k=cu_seqlens_k,\n                    max_seqlen_q=max_seqlen_in_batch_q,\n                    max_seqlen_k=max_seqlen_in_batch_k,\n                    dropout_p=dropout,\n                    softmax_scale=softmax_scale,\n                    causal=is_causal,\n                )\n            else:\n                attn_output_unpad = flash_attn_varlen_func(\n                    query_states,\n                    key_states,\n                    value_states,\n                    cu_seqlens_q=cu_seqlens_q,\n                    cu_seqlens_k=cu_seqlens_k,\n                    max_seqlen_q=max_seqlen_in_batch_q,\n                    max_seqlen_k=max_seqlen_in_batch_k,\n                    dropout_p=dropout,\n                    softmax_scale=softmax_scale,\n                    causal=is_causal,\n                    window_size=(self.config.sliding_window, self.config.sliding_window),\n                )\n\n            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, q_len)\n        else:\n            if not use_sliding_windows:\n                if q_len == 1:\n                    position_ids = position_ids.to(dtype=torch.int32).squeeze(1)\n                    attn_output = flash_attn_with_kvcache(\n                        query_states,\n                        key_states,\n                        value_states,\n                        cache_seqlens=position_ids,\n                        softmax_scale=softmax_scale,\n                        causal=is_causal,\n                    )   \n                else:\n                    attn_output = flash_attn_func(\n                        query_states,\n                        key_states,\n                        value_states,\n                        dropout,\n                        softmax_scale=softmax_scale,\n                        
causal=is_causal,\n                    )\n            else:\n                attn_output = flash_attn_func(\n                    query_states,\n                    key_states,\n                    value_states,\n                    dropout,\n                    softmax_scale=softmax_scale,\n                    causal=is_causal,\n                    window_size=(self.config.sliding_window, self.config.sliding_window),\n                )\n\n        return attn_output\n\n    # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input\n    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):\n        batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape\n\n        # On the first iteration we need to properly re-create the padding mask\n        # by slicing it on the proper place\n        if kv_seq_len != attention_mask.shape[-1]:\n            attention_mask_num_tokens = attention_mask.shape[-1]\n            attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]\n\n        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)\n\n        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)\n        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)\n\n        if query_length == kv_seq_len:\n            query_layer = index_first_axis(\n                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k\n            )\n            cu_seqlens_q = cu_seqlens_k\n            max_seqlen_in_batch_q = max_seqlen_in_batch_k\n            indices_q = indices_k\n        elif query_length == 1:\n            max_seqlen_in_batch_q = 1\n            cu_seqlens_q = torch.arange(\n                batch_size + 1, dtype=torch.int32, device=query_layer.device\n            )  # There is a memcpy here, that is very bad.\n            
indices_q = cu_seqlens_q[:-1]\n            query_layer = query_layer.squeeze(1)\n        else:\n            # The -q_len: slice assumes left padding.\n            attention_mask = attention_mask[:, -query_length:]\n            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)\n\n        return (\n            query_layer,\n            key_layer,\n            value_layer,\n            indices_q,\n            (cu_seqlens_q, cu_seqlens_k),\n            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),\n        )\n\n\n\n# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral\n# TODO @longjie no longer copied from Mistral after static cache\nclass MixtralSdpaAttention(MixtralAttention):\n    \"\"\"\n    Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from\n    `MixtralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to\n    SDPA API.\n    \"\"\"\n\n    # Adapted from MixtralAttention.forward\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        if output_attentions:\n            # TODO: Improve this warning with e.g. `model.config.attn_implementation = \"manual\"` once this is implemented.\n            logger.warning_once(\n                \"MixtralModel is using MixtralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. 
Falling back to the manual attention implementation, \"\n                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation=\"eager\"` when loading the model.'\n            )\n            return super().forward(\n                hidden_states=hidden_states,\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n                past_key_value=past_key_value,\n                output_attentions=output_attentions,\n                use_cache=use_cache,\n                cache_position=cache_position,\n            )\n\n        bsz, q_len, _ = hidden_states.size()\n\n        query_states = self.q_proj(hidden_states)\n        key_states = self.k_proj(hidden_states)\n        value_states = self.v_proj(hidden_states)\n\n        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n\n        kv_seq_len = key_states.shape[-2]\n        if past_key_value is not None:\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\n        cos, sin = self.rotary_emb(value_states, position_ids)\n\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n\n        if past_key_value is not None:\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n        causal_mask = attention_mask\n        if attention_mask is not None:  
# no matter the length, we just slice it\n            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]\n\n        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,\n        # Reference: https://github.com/pytorch/pytorch/issues/112577.\n        if query_states.device.type == \"cuda\" and attention_mask is not None:\n            query_states = query_states.contiguous()\n            key_states = key_states.contiguous()\n            value_states = value_states.contiguous()\n\n        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment\n        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.\n        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.\n        is_causal = True if causal_mask is None and q_len > 1 else False\n\n        attn_output = torch.nn.functional.scaled_dot_product_attention(\n            query_states,\n            key_states,\n            value_states,\n            attn_mask=causal_mask,\n            dropout_p=self.attention_dropout if self.training else 0.0,\n            is_causal=is_causal,\n        )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.view(bsz, q_len, self.hidden_size)\n\n        attn_output = self.o_proj(attn_output)\n\n        return attn_output, None, past_key_value\n\n\nMIXTRAL_ATTENTION_CLASSES = {\n    \"eager\": MixtralAttention,\n    \"flash_attention_2\": MixtralFlashAttention2,\n    \"sdpa\": MixtralSdpaAttention,\n}\n\n\nclass MixtralBlockSparseTop2MLP(nn.Module):\n    def __init__(self, config: MixtralConfig):\n        super().__init__()\n        self.ffn_dim = config.intermediate_size\n        self.hidden_dim = 
config.hidden_size\n\n        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)  # gate\n        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)  # down\n        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)  # up\n\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, hidden_states):\n        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)\n        current_hidden_states = self.w2(current_hidden_states)\n        return current_hidden_states\n\n\nclass MixtralSparseMoeBlock(nn.Module):\n    \"\"\"\n    This implementation is\n    strictly equivalent to standard MoE with full capacity (no\n    dropped tokens). It's faster since it formulates MoE operations\n    in terms of block-sparse operations to accommodate imbalanced\n    assignments of tokens to experts, whereas standard MoE either\n    (1) drop tokens at the cost of reduced performance or (2) set\n    capacity factor to number of experts and thus waste computation\n    and memory on padding.\n    \"\"\"\n\n    def __init__(self, config):\n        super().__init__()\n        self.hidden_dim = config.hidden_size\n        self.ffn_dim = config.intermediate_size\n        self.num_experts = config.num_local_experts\n        self.top_k = config.num_experts_per_tok\n\n        # gating\n        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)\n\n        self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])\n\n        # Jitter parameters\n        self.jitter_noise = config.router_jitter_noise\n\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        \"\"\" \"\"\"\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        if self.training and self.jitter_noise > 0:\n            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)\n        hidden_states = 
hidden_states.view(-1, hidden_dim)\n        # router_logits: (batch * sequence_length, n_experts)\n        router_logits = self.gate(hidden_states)\n\n        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)\n        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)\n        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)\n        # we cast back to the input dtype\n        routing_weights = routing_weights.to(hidden_states.dtype)\n\n        final_hidden_states = torch.zeros(\n            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device\n        )\n\n        # One hot encode the selected experts to create an expert mask\n        # this will be used to easily index which expert is going to be solicited\n        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)\n\n        # Loop over all available experts in the model and perform the computation on each expert\n        for expert_idx in range(self.num_experts):\n            expert_layer = self.experts[expert_idx]\n            idx, top_x = torch.where(expert_mask[expert_idx])\n\n            # Index the correct hidden states and compute the expert hidden state for\n            # the current expert. 
We need to make sure to multiply the output hidden\n            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)\n            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)\n            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]\n\n            # However `index_add_` only support torch tensors for indexing so we'll use\n            # the `top_x` tensor here.\n            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))\n        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)\n        return final_hidden_states, router_logits\n\n\nclass MixtralDecoderLayer(nn.Module):\n    def __init__(self, config: MixtralConfig, layer_idx: int):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n\n        self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)\n\n        self.block_sparse_moe = MixtralSparseMoeBlock(config)\n        self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        output_attentions: Optional[bool] = False,\n        output_router_logits: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs,\n    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask 
(`torch.FloatTensor`, *optional*): attention mask of size\n                `(batch, sequence_length)` where padding elements are indicated by 0.\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n            output_router_logits (`bool`, *optional*):\n                Whether or not to return the logits of all the routers. They are useful for computing the router loss, and\n                should not be returned during inference.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):\n                Indices depicting the position of the input sequence tokens in the sequence.\n            kwargs (`dict`, *optional*):\n                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code\n                into the model\n        \"\"\"\n\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_value=past_key_value,\n            output_attentions=output_attentions,\n            use_cache=use_cache,\n            cache_position=cache_position,\n        )\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n      
  hidden_states, router_logits = self.block_sparse_moe(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (self_attn_weights,)\n\n        if use_cache:\n            outputs += (present_key_value,)\n\n        if output_router_logits:\n            outputs += (router_logits,)\n\n        return outputs\n\n\nMIXTRAL_START_DOCSTRING = r\"\"\"\n    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the\n    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads\n    etc.)\n\n    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage\n    and behavior.\n\n    Parameters:\n        config ([`MixtralConfig`]):\n            Model configuration class with all the parameters of the model. Initializing with a config file does not\n            load the weights associated with the model, only the configuration. 
Check out the\n            [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare Mixtral Model outputting raw hidden-states without any specific head on top.\",\n    MIXTRAL_START_DOCSTRING,\n)\n# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel with Qwen2->Mixtral\nclass MixtralPreTrainedModel(PreTrainedModel):\n    config_class = MixtralConfig\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"MixtralDecoderLayer\"]\n    _skip_keys_device_placement = \"past_key_values\"\n    _supports_flash_attn_2 = True\n    _supports_sdpa = True\n    _supports_cache_class = True\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n\n\nMIXTRAL_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. 
Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see\n            `past_key_values`).\n\n            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]\n            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more\n            information on the default strategy.\n\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Indices of positions of each input sequence tokens in the position embeddings. 
Selected in the range `[0,\n            config.n_positions - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):\n            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape\n            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape\n            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.\n\n            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention\n            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.\n\n            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that\n            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all\n            `decoder_input_ids` of shape `(batch_size, sequence_length)`.\n        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This\n            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n            model's internal embedding lookup matrix.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n            `past_key_values`).\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for\n            more detail.\n        output_router_logits (`bool`, *optional*):\n            Whether or not to return the logits of all the routers. They are useful for computing the router loss, and\n            should not be returned during inference.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):\n            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,\n            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer\n            the complete sequence length.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare Mixtral Model outputting raw hidden-states without any specific head on top.\",\n    MIXTRAL_START_DOCSTRING,\n)\n# copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral\n# TODO @longjie no longer copied from Mistral after static cache\nclass MixtralModel(MixtralPreTrainedModel):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`MixtralDecoderLayer`]\n\n    Args:\n        config: MixtralConfig\n    \"\"\"\n\n    def __init__(self, config: MixtralConfig):\n        super().__init__(config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)\n        self.layers = nn.ModuleList(\n            [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]\n        )\n        self._attn_implementation = config._attn_implementation\n        self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n        self.gradient_checkpointing = False\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.embed_tokens = value\n\n    # Ignore copy\n    @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        output_router_logits: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, MoeModelOutputWithPast]:\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_router_logits = (\n            output_router_logits if output_router_logits is not None else self.config.output_router_logits\n        )\n        
output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        if (input_ids is None) ^ (inputs_embeds is not None):\n            raise ValueError(\n                \"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one\"\n            )\n\n        if self.gradient_checkpointing and self.training:\n            if use_cache:\n                logger.warning_once(\n                    \"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\"\n                )\n                use_cache = False\n\n        use_legacy_cache = False\n        if use_cache and not isinstance(past_key_values, Cache) and not self.training:\n            use_legacy_cache = True\n            past_key_values = DynamicCache.from_legacy_cache(past_key_values)\n            logger.warning_once(\n                \"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. 
\"\n                \"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\"\n            )\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        if cache_position is None:\n            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n            cache_position = torch.arange(\n                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device\n            )\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0)\n\n        causal_mask = self._update_causal_mask(\n            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions\n        )\n\n        hidden_states = inputs_embeds\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        all_router_logits = () if output_router_logits else None\n        next_decoder_cache = None\n\n        for decoder_layer in self.layers:\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            if self.gradient_checkpointing and self.training:\n                layer_outputs = self._gradient_checkpointing_func(\n                    decoder_layer.__call__,\n                    hidden_states,\n                    causal_mask,\n                    position_ids,\n                    past_key_values,\n                    output_attentions,\n                    output_router_logits,\n                    use_cache,\n                    cache_position,\n                )\n            else:\n                layer_outputs = decoder_layer(\n                    hidden_states,\n                    attention_mask=causal_mask,\n                    position_ids=position_ids,\n                    
past_key_value=past_key_values,\n                    output_attentions=output_attentions,\n                    output_router_logits=output_router_logits,\n                    use_cache=use_cache,\n                    cache_position=cache_position,\n                )\n\n            hidden_states = layer_outputs[0]\n\n            if use_cache:\n                next_decoder_cache = layer_outputs[2 if output_attentions else 1]\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n            if output_router_logits:\n                all_router_logits += (layer_outputs[-1],)\n\n        hidden_states = self.norm(hidden_states)\n\n        # add hidden states from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        next_cache = None\n        if use_cache:\n            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache\n\n        if not return_dict:\n            return tuple(\n                v\n                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]\n                if v is not None\n            )\n        return MoeModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=next_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n            router_logits=all_router_logits,\n        )\n\n    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask\n    def _update_causal_mask(\n        self,\n        attention_mask: torch.Tensor,\n        input_tensor: torch.Tensor,\n        cache_position: torch.Tensor,\n        past_key_values: Cache,\n        output_attentions: bool,\n    ):\n        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static\n        # KV cache is used. 
This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.\n        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using\n        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114\n\n        if self.config._attn_implementation == \"flash_attention_2\":\n            if attention_mask is not None and 0.0 in attention_mask:\n                return attention_mask\n            return None\n\n        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in\n        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail\n        # to infer the attention mask.\n        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n        using_static_cache = isinstance(past_key_values, StaticCache)\n\n        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward\n        if self.config._attn_implementation == \"sdpa\" and not using_static_cache and not output_attentions:\n            if AttentionMaskConverter._ignore_causal_mask_sdpa(\n                attention_mask,\n                inputs_embeds=input_tensor,\n                past_key_values_length=past_seen_tokens,\n                is_training=self.training,\n            ):\n                return None\n\n        dtype, device = input_tensor.dtype, input_tensor.device\n        min_dtype = torch.finfo(dtype).min\n        sequence_length = input_tensor.shape[1]\n        if using_static_cache:\n            target_length = past_key_values.get_max_length()\n        else:\n            target_length = (\n                attention_mask.shape[-1]\n                if isinstance(attention_mask, torch.Tensor)\n                else past_seen_tokens + 
sequence_length + 1\n            )\n\n        if attention_mask is not None and attention_mask.dim() == 4:\n            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing\n            if attention_mask.max() != 0:\n                raise ValueError(\"Custom 4D attention mask should be passed in inverted form with `max==0`\")\n            causal_mask = attention_mask\n        else:\n            causal_mask = torch.full(\n                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device\n            )\n            if sequence_length != 1:\n                causal_mask = torch.triu(causal_mask, diagonal=1)\n            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)\n            causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)\n            if attention_mask is not None:\n                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit\n                mask_length = attention_mask.shape[-1]\n                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]\n                padding_mask = padding_mask == 0\n                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(\n                    padding_mask, min_dtype\n                )\n        if (\n            self.config._attn_implementation == \"sdpa\"\n            and attention_mask is not None\n            and attention_mask.device.type == \"cuda\"\n            and not output_attentions\n        ):\n            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when\n            # using left padding. 
This is required by F.scaled_dot_product_attention memory-efficient attention path.\n            # Details: https://github.com/pytorch/pytorch/issues/110213\n            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)\n\n        return causal_mask\n\n\nclass MixtralForCausalLM(MixtralPreTrainedModel):\n    _tied_weights_keys = [\"lm_head.weight\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = MixtralModel(config)\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.router_aux_loss_coef = config.router_aux_loss_coef\n        self.num_experts = config.num_local_experts\n        self.num_experts_per_tok = config.num_experts_per_tok\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def set_decoder(self, decoder):\n        self.model = decoder\n\n    def get_decoder(self):\n        return self.model\n\n    @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)\n    @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)\n    # Ignore copy\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = 
None,\n        output_hidden_states: Optional[bool] = None,\n        output_router_logits: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, MoeCausalLMOutputWithPast]:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, MixtralForCausalLM\n\n        >>> model = MixtralForCausalLM.from_pretrained(\"mistralai/Mixtral-8x7B-v0.1\")\n        >>> tokenizer = AutoTokenizer.from_pretrained(\"mistralai/Mixtral-8x7B-v0.1\")\n\n        >>> prompt = \"Hey, are you conscious? Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you conscious? 
Can you talk to me?\\nI'm not conscious, but I can talk to you.\"\n        ```\"\"\"\n\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_router_logits = (\n            output_router_logits if output_router_logits is not None else self.config.output_router_logits\n        )\n\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            output_router_logits=output_router_logits,\n            return_dict=return_dict,\n            cache_position=cache_position,\n        )\n\n        hidden_states = outputs[0]\n        logits = self.lm_head(hidden_states)\n        logits = logits.float()\n\n        loss = None\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            shift_logits = shift_logits.view(-1, self.config.vocab_size)\n            shift_labels = shift_labels.view(-1)\n            # Enable model parallelism\n            shift_labels = shift_labels.to(shift_logits.device)\n            loss = loss_fct(shift_logits, shift_labels)\n\n        aux_loss = None\n        if output_router_logits:\n            
aux_loss = load_balancing_loss_func(\n                outputs.router_logits if return_dict else outputs[-1],\n                self.num_experts,\n                self.num_experts_per_tok,\n                attention_mask,\n            )\n            if labels is not None:\n                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            if output_router_logits:\n                output = (aux_loss,) + output\n            return (loss,) + output if loss is not None else output\n\n        return MoeCausalLMOutputWithPast(\n            loss=loss,\n            aux_loss=aux_loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n            router_logits=outputs.router_logits,\n        )\n\n    def prepare_inputs_for_generation(\n        self,\n        input_ids,\n        past_key_values=None,\n        attention_mask=None,\n        inputs_embeds=None,\n        cache_position=None,\n        output_router_logits=False,\n        position_ids=None,\n        use_cache=True,\n        **kwargs,\n    ):\n        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens\n        # Exception 1: when passing input_embeds, input_ids may be missing entries\n        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here\n        if past_key_values is not None:\n            if inputs_embeds is not None:  # Exception 1\n                input_ids = input_ids[:, -cache_position.shape[0] :]\n            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the \"else\", a no op, is Exception 2)\n                input_ids = input_ids[:, cache_position]\n\n        if attention_mask is not None and position_ids is 
None:\n            # create position_ids on the fly for batch generation\n            position_ids = attention_mask.long().cumsum(-1) - 1\n            position_ids.masked_fill_(attention_mask == 0, 1)\n            if past_key_values:\n                position_ids = position_ids[:, -input_ids.shape[1] :]\n\n        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step\n        if inputs_embeds is not None and cache_position[0] == 0:\n            model_inputs = {\"inputs_embeds\": inputs_embeds}\n        else:\n            model_inputs = {\"input_ids\": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases\n\n        model_inputs.update(\n            {\n                \"position_ids\": position_ids,\n                \"cache_position\": cache_position,\n                \"past_key_values\": past_key_values,\n                \"use_cache\": use_cache,\n                \"attention_mask\": attention_mask,\n                \"output_router_logits\": output_router_logits,\n            }\n        )\n        return model_inputs\n\n\n@add_start_docstrings(\n    \"\"\"\n    The Mixtral Model transformer with a sequence classification head on top (linear layer).\n\n    [`MixtralForSequenceClassification`] uses the last token in order to do the classification, as other causal models\n    (e.g. GPT-2) do.\n\n    Since it does classification on the last token, it requires to know the position of the last token. If a\n    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If\n    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. 
Since it cannot guess the\n    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in\n    each row of the batch).\n    \"\"\",\n    MIXTRAL_START_DOCSTRING,\n)\n# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mixtral, LLAMA->MIXTRAL\nclass MixtralForSequenceClassification(MixtralPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.model = MixtralModel(config)\n        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n            config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        transformer_outputs = self.model(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = transformer_outputs[0]\n        logits = self.score(hidden_states)\n\n        if input_ids is not None:\n            batch_size = input_ids.shape[0]\n        else:\n            batch_size = inputs_embeds.shape[0]\n\n        if self.config.pad_token_id is None and batch_size != 1:\n            raise ValueError(\"Cannot handle batch sizes > 1 if no padding token is defined.\")\n        if self.config.pad_token_id is None:\n            sequence_lengths = -1\n        else:\n            if input_ids is not None:\n                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility\n                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1\n                sequence_lengths = sequence_lengths % input_ids.shape[-1]\n                sequence_lengths = sequence_lengths.to(logits.device)\n            else:\n                sequence_lengths = -1\n\n        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]\n\n        loss = None\n        if labels is not None:\n            labels = labels.to(logits.device)\n            if self.config.problem_type is None:\n                if self.num_labels == 1:\n                    self.config.problem_type 
= \"regression\"\n                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):\n                    self.config.problem_type = \"single_label_classification\"\n                else:\n                    self.config.problem_type = \"multi_label_classification\"\n\n            if self.config.problem_type == \"regression\":\n                loss_fct = MSELoss()\n                if self.num_labels == 1:\n                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())\n                else:\n                    loss = loss_fct(pooled_logits, labels)\n            elif self.config.problem_type == \"single_label_classification\":\n                loss_fct = CrossEntropyLoss()\n                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))\n            elif self.config.problem_type == \"multi_label_classification\":\n                loss_fct = BCEWithLogitsLoss()\n                loss = loss_fct(pooled_logits, labels)\n        if not return_dict:\n            output = (pooled_logits,) + transformer_outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return SequenceClassifierOutputWithPast(\n            loss=loss,\n            logits=pooled_logits,\n            past_key_values=transformer_outputs.past_key_values,\n            hidden_states=transformer_outputs.hidden_states,\n            attentions=transformer_outputs.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    The Mixtral Model transformer with a token classification head on top (a linear layer on top of the hidden-states\n    output) e.g. 
for Named-Entity-Recognition (NER) tasks.\n    \"\"\",\n    MIXTRAL_START_DOCSTRING,\n)\n# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mixtral, LLAMA->MIXTRAL\nclass MixtralForTokenClassification(MixtralPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.model = MixtralModel(config)\n        if getattr(config, \"classifier_dropout\", None) is not None:\n            classifier_dropout = config.classifier_dropout\n        elif getattr(config, \"hidden_dropout\", None) is not None:\n            classifier_dropout = config.hidden_dropout\n        else:\n            classifier_dropout = 0.1\n        self.dropout = nn.Dropout(classifier_dropout)\n        self.score = nn.Linear(config.hidden_size, config.num_labels)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, TokenClassifierOutput]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Labels for computing the token classification loss. 
Indices should be in `[0, ...,\n            config.num_labels - 1]`. A Cross-Entropy loss is computed over the flattened token positions.\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        outputs = self.model(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        sequence_output = outputs[0]\n        sequence_output = self.dropout(sequence_output)\n        logits = self.score(sequence_output)\n\n        loss = None\n        if labels is not None:\n            loss_fct = CrossEntropyLoss()\n            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))\n\n        if not return_dict:\n            output = (logits,) + outputs[2:]\n            return ((loss,) + output) if loss is not None else output\n\n        return TokenClassifierOutput(\n            loss=loss,\n            logits=logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )"
  },
  {
    "path": "kt-sft/ktransformers/models/modeling_qwen2_moe.py",
    "content": "# coding=utf-8\n'''\nDescription  :  \nAuthor       : Boxin Zhang\nVersion      : 0.1.0\n''' \n# Adapted from\n# https://github.com/huggingface/transformers/blob/v4.42.3/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py\n# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.\n# Copyright (c) 2024 by KVCache.AI, All Rights Reserved.\n# \n# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX\n# and OPT implementations in this library. It has been modified from its\n# original forms to accommodate minor architectural differences compared\n# to GPT-NeoX and OPT used by the Meta AI team that trained the model.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch Qwen2MoE model.\"\"\"\n\nimport inspect\nimport math\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n\nfrom transformers.activations import ACT2FN\nfrom transformers.cache_utils import Cache, DynamicCache, StaticCache\nfrom transformers.modeling_attn_mask_utils import (\n    AttentionMaskConverter,\n)\nfrom transformers.modeling_outputs import (\n    MoeCausalLMOutputWithPast,\n    MoeModelOutputWithPast,\n    SequenceClassifierOutputWithPast,\n    TokenClassifierOutput,\n)\nfrom transformers.modeling_utils import PreTrainedModel\nfrom 
transformers.utils import (\n    add_start_docstrings,\n    add_start_docstrings_to_model_forward,\n    is_flash_attn_2_available,\n    is_flash_attn_greater_or_equal_2_10,\n    logging,\n    replace_return_docstrings,\n)\nfrom transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig\n\nfrom ktransformers.util.grad_wrapper import maybe_no_grad\n\nif is_flash_attn_2_available():\n    from flash_attn import flash_attn_func, flash_attn_varlen_func, flash_attn_with_kvcache\n    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa\n\n    _flash_supports_window_size = \"window_size\" in list(inspect.signature(flash_attn_func).parameters)\n\nlogger = logging.get_logger(__name__)\n\n_CHECKPOINT_FOR_DOC = \"Qwen/Qwen1.5-MoE-A2.7B\"\n_CONFIG_FOR_DOC = \"Qwen2MoeConfig\"\n\n\n# Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func\ndef load_balancing_loss_func(\n    gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None\n) -> float:\n    r\"\"\"\n    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.\n\n    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss\n    function presented in equations (4) - (6) of the paper. 
It aims at penalizing cases where the routing between\n    experts is too unbalanced.\n\n    Args:\n        gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]]):\n            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of\n            shape [batch_size X sequence_length, num_experts].\n        num_experts (`int`, *optional*):\n            Number of experts.\n        top_k (`int`, *optional*, defaults to 2):\n            Number of experts each token is routed to.\n        attention_mask (`torch.Tensor`, *optional*):\n            The attention_mask used in the forward function, of\n            shape [batch_size, sequence_length] if not None.\n\n    Returns:\n        The auxiliary loss.\n    \"\"\"\n    if gate_logits is None or not isinstance(gate_logits, tuple):\n        return 0\n\n    if isinstance(gate_logits, tuple):\n        compute_device = gate_logits[0].device\n        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)\n\n    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)\n\n    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)\n\n    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)\n\n    if attention_mask is None:\n        # Compute the percentage of tokens routed to each expert\n        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)\n\n        # Compute the average probability of routing to these experts\n        router_prob_per_expert = torch.mean(routing_weights, dim=0)\n    else:\n        batch_size, sequence_length = attention_mask.shape\n        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)\n\n        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask\n        expert_attention_mask = (\n            attention_mask[None, :, :, None, None]\n            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))\n            .reshape(-1, top_k, num_experts)\n            .to(compute_device)\n      
  )\n\n        # Compute the percentage of tokens routed to each experts\n        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(\n            expert_attention_mask, dim=0\n        )\n\n        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert\n        router_per_expert_attention_mask = (\n            attention_mask[None, :, :, None]\n            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))\n            .reshape(-1, num_experts)\n            .to(compute_device)\n        )\n\n        # Compute the average probability of routing to these experts\n        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(\n            router_per_expert_attention_mask, dim=0\n        )\n\n    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))\n    return overall_loss * num_experts\n\n\n# Copied from transformers.models.llama.modeling_llama._get_unpad_data\ndef _get_unpad_data(attention_mask):\n    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)\n    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()\n    max_seqlen_in_batch = seqlens_in_batch.max().item()\n    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))\n    return (\n        indices,\n        cu_seqlens,\n        max_seqlen_in_batch,\n    )\n\n\n# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2Moe\nclass Qwen2MoeRMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        Qwen2MoeRMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n\n    def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n        hidden_states = 
hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2Moe\nclass Qwen2MoeRotaryEmbedding(nn.Module):\n    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):\n        super().__init__()\n        self.scaling_factor = scaling_factor\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        # For BC we register cos and sin cached\n        self.max_seq_len_cached = max_position_embeddings\n\n    @maybe_no_grad()\n    def forward(self, x, position_ids):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)\n        position_ids_expanded = position_ids[:, None, :].float()\n        # Force float32 since bfloat16 loses precision on long contexts\n        # See https://github.com/huggingface/transformers/pull/29285\n        device_type = x.device.type\n        device_type = device_type if isinstance(device_type, str) and device_type != \"mps\" else \"cpu\"\n        with torch.autocast(device_type=device_type, enabled=False):\n            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos()\n            sin = emb.sin()\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)\n\n\n# Copied from transformers.models.llama.modeling_llama.rotate_half\ndef 
rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\n# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n    Args:\n        q (`torch.Tensor`): The query tensor.\n        k (`torch.Tensor`): The key tensor.\n        cos (`torch.Tensor`): The cosine part of the rotary embedding.\n        sin (`torch.Tensor`): The sine part of the rotary embedding.\n        position_ids (`torch.Tensor`):\n            Deprecated and unused.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. 
Similarly, if q and k have\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos.unsqueeze(unsqueeze_dim)\n    sin = sin.unsqueeze(unsqueeze_dim)\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\n# Modified from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2Moe\nclass Qwen2MoeMLP(nn.Module):\n    def __init__(self, config, intermediate_size=None):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.intermediate_size = intermediate_size\n        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n\n\n# Copied from transformers.models.llama.modeling_llama.repeat_kv\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\n# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2Attention with Qwen2->Qwen2Moe\nclass Qwen2MoeAttention(nn.Module):\n    \"\"\"\n    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer\n    and \"Generating Long Sequences with Sparse Transformers\".\n    \"\"\"\n\n    def __init__(self, config: Qwen2MoeConfig, layer_idx: Optional[int] = None):\n        super().__init__()\n        self.config = config\n        self.layer_idx = layer_idx\n        if layer_idx is None:\n            logger.warning_once(\n                f\"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will \"\n                \"lead to errors during the forward call if caching is used. 
Please make sure to provide a `layer_idx` \"\n                \"when creating this class.\"\n            )\n\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.hidden_size // self.num_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.num_key_value_groups = self.num_heads // self.num_key_value_heads\n        self.max_position_embeddings = config.max_position_embeddings\n        self.rope_theta = config.rope_theta\n        self.is_causal = True\n        self.attention_dropout = config.attention_dropout\n\n        if (self.head_dim * self.num_heads) != self.hidden_size:\n            raise ValueError(\n                f\"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}\"\n                f\" and `num_heads`: {self.num_heads}).\"\n            )\n        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)\n        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)\n        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)\n        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)\n\n        self.rotary_emb = Qwen2MoeRotaryEmbedding(\n            self.head_dim,\n            max_position_embeddings=self.max_position_embeddings,\n            base=self.rope_theta,\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        bsz, q_len, _ = hidden_states.size()\n\n      
  query_states = self.q_proj(hidden_states)\n        key_states = self.k_proj(hidden_states)\n        value_states = self.v_proj(hidden_states)\n\n        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n\n        kv_seq_len = key_states.shape[-2]\n        if past_key_value is not None:\n            if self.layer_idx is None:\n                raise ValueError(\n                    f\"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} \"\n                    \"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class \"\n                    \"with a layer index.\"\n                )\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\n        cos, sin = self.rotary_emb(value_states, position_ids)\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n        if past_key_value is not None:\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n        # repeat k/v heads if n_kv_heads < n_heads\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)\n\n        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is\"\n           
     f\" {attn_weights.size()}\"\n            )\n\n        if attention_mask is not None:  # no matter the length, we just slice it\n            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]\n            attn_weights = attn_weights + causal_mask\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)\n        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)\n        attn_output = torch.matmul(attn_weights, value_states)\n\n        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)\n\n        attn_output = self.o_proj(attn_output)\n\n        if not output_attentions:\n            attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n\n# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2 with Qwen2->Qwen2Moe\nclass Qwen2MoeFlashAttention2(Qwen2MoeAttention):\n    \"\"\"\n    Qwen2Moe flash attention module, following Qwen2Moe attention module. This module inherits from `Qwen2MoeAttention`\n    as the weights of the module stays untouched. The only required change would be on the forward pass\n    where it needs to correctly call the public API of flash attention and deal with padding tokens\n    in case the input contains any of them. 
Additionally, for sliding window attention, we apply SWA only to the bottom\n    config.max_window_layers layers.\n    \"\"\"\n\n    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n        # TODO: Should be removed once Flash Attention for ROCm is bumped to 2.1.\n        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which became the default in flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.\n        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).\n        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n    ):\n        bsz, q_len, _ = hidden_states.size()\n\n        query_states = self.q_proj(hidden_states)\n        key_states = self.k_proj(hidden_states)\n        value_states = self.v_proj(hidden_states)\n\n        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n\n        kv_seq_len = key_states.shape[-2]\n        if past_key_value is not None:\n            if self.layer_idx is None:\n                raise ValueError(\n                    f\"The 
cache structure has changed since version v4.36. If you are using {self.__class__.__name__} \"\n                    \"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class \"\n                    \"with a layer index.\"\n                )\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\n\n        cos, sin = self.rotary_emb(value_states, position_ids)\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n        use_sliding_windows = (\n            _flash_supports_window_size\n            and getattr(self.config, \"sliding_window\", None) is not None\n            and kv_seq_len > self.config.sliding_window\n            and self.config.use_sliding_window\n        )\n\n        if not _flash_supports_window_size:\n            logger.warning_once(\n                \"The current flash attention version does not support sliding window attention, for a more memory efficient implementation\"\n                \" make sure to upgrade flash-attn library.\"\n            )\n\n        if past_key_value is not None:\n            # Activate slicing cache only if the config has a value `sliding_windows` attribute\n            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0\n            if (\n                getattr(self.config, \"sliding_window\", None) is not None\n                and kv_seq_len > self.config.sliding_window\n                and cache_has_contents\n            ):\n                slicing_tokens = 1 - self.config.sliding_window\n\n                past_key = past_key_value[self.layer_idx][0]\n                past_value = past_key_value[self.layer_idx][1]\n\n                past_key = past_key[:, :, slicing_tokens:, :].contiguous()\n                past_value = past_value[:, :, slicing_tokens:, :].contiguous()\n\n                if past_key.shape[-2] != self.config.sliding_window - 1:\n                    raise 
ValueError(\n                        f\"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got\"\n                        f\" {past_key.shape}\"\n                    )\n\n                if attention_mask is not None:\n                    attention_mask = attention_mask[:, slicing_tokens:]\n                    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)\n\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n            # we slice the states for static kv cache to be supported in FA2. Not sure it's a must as compile fails\n            # for bsz == 1, avoid using slice to capture cuda graph\n            if cache_position is not None and q_len > 1:\n                key_states = key_states[:, :, : cache_position[-1] + 1, :]\n                value_states = value_states[:, :, : cache_position[-1] + 1, :]\n\n        # repeat k/v heads if n_kv_heads < n_heads\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n        dropout_rate = 0.0 if not self.training else self.attention_dropout\n\n        # In PEFT, usually we cast the layer norms in float32 for training stability reasons\n        # therefore the input hidden states gets silently casted in float32. 
Hence, we need to\n        # cast them back to float16 just to be sure everything works as expected.\n        input_dtype = query_states.dtype\n        if input_dtype == torch.float32:\n            if torch.is_autocast_enabled():\n                target_dtype = torch.get_autocast_gpu_dtype()\n            # Handle the case where the model is quantized\n            elif hasattr(self.config, \"_pre_quantization_dtype\"):\n                target_dtype = self.config._pre_quantization_dtype\n            else:\n                target_dtype = self.q_proj.weight.dtype\n\n            logger.warning_once(\n                f\"The input hidden states seem to have been silently cast to float32; this might be related to\"\n                f\" the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to\"\n                f\" {target_dtype}.\"\n            )\n\n            query_states = query_states.to(target_dtype)\n            key_states = key_states.to(target_dtype)\n            value_states = value_states.to(target_dtype)\n\n        # Reshape to the expected shape for Flash Attention\n        query_states = query_states.transpose(1, 2)\n        key_states = key_states.transpose(1, 2)\n        value_states = value_states.transpose(1, 2)\n\n        attn_output = self._flash_attention_forward(\n            query_states,\n            key_states,\n            value_states,\n            attention_mask,\n            q_len,\n            position_ids=position_ids,\n            dropout=dropout_rate,\n            use_sliding_windows=use_sliding_windows,\n        )\n\n        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()\n        attn_output = self.o_proj(attn_output)\n\n        # Flash Attention never materializes the attention matrix, so there are no weights to return.\n        # Assign unconditionally so that `output_attentions=True` does not hit an unbound local.\n        attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n    def _flash_attention_forward(\n        self,\n        query_states,\n        key_states,\n        value_states,\n        
attention_mask,\n        query_length,\n        position_ids,\n        dropout=0.0,\n        softmax_scale=None,\n        use_sliding_windows=False,\n    ):\n        \"\"\"\n        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token\n        first unpad the input, then computes the attention scores and pad the final attention scores.\n\n        Args:\n            query_states (`torch.Tensor`):\n                Input query states to be passed to Flash Attention API\n            key_states (`torch.Tensor`):\n                Input key states to be passed to Flash Attention API\n            value_states (`torch.Tensor`):\n                Input value states to be passed to Flash Attention API\n            attention_mask (`torch.Tensor`):\n                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the\n                position of padding tokens and 1 for the position of non-padding tokens.\n            dropout (`float`):\n                Attention dropout\n            softmax_scale (`float`, *optional*):\n                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)\n            use_sliding_windows (`bool`, *optional*):\n                Whether to activate sliding window attention.\n        \"\"\"\n        if not self._flash_attn_uses_top_left_mask:\n            causal = self.is_causal\n        else:\n            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. 
For details, please see the comment in LlamaFlashAttention2 __init__.\n            causal = self.is_causal and query_length != 1\n\n        # Decide whether to use SWA or not by layer index.\n        if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:\n            use_sliding_windows = False\n\n        # Contains at least one padding token in the sequence\n        if attention_mask is not None:\n            batch_size = query_states.shape[0]\n            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(\n                query_states, key_states, value_states, attention_mask, query_length\n            )\n\n            cu_seqlens_q, cu_seqlens_k = cu_seq_lens\n            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens\n\n            if not use_sliding_windows:\n                attn_output_unpad = flash_attn_varlen_func(\n                    query_states,\n                    key_states,\n                    value_states,\n                    cu_seqlens_q=cu_seqlens_q,\n                    cu_seqlens_k=cu_seqlens_k,\n                    max_seqlen_q=max_seqlen_in_batch_q,\n                    max_seqlen_k=max_seqlen_in_batch_k,\n                    dropout_p=dropout,\n                    softmax_scale=softmax_scale,\n                    causal=causal,\n                )\n            else:\n                attn_output_unpad = flash_attn_varlen_func(\n                    query_states,\n                    key_states,\n                    value_states,\n                    cu_seqlens_q=cu_seqlens_q,\n                    cu_seqlens_k=cu_seqlens_k,\n                    max_seqlen_q=max_seqlen_in_batch_q,\n                    max_seqlen_k=max_seqlen_in_batch_k,\n                    dropout_p=dropout,\n                    softmax_scale=softmax_scale,\n                    causal=causal,\n                    window_size=(self.config.sliding_window, self.config.sliding_window),\n                
)\n\n            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)\n        else:\n            if not use_sliding_windows:\n                if query_length == 1:\n                    position_ids = position_ids.to(dtype=torch.int32).squeeze(1)\n                    attn_output = flash_attn_with_kvcache(\n                        query_states,\n                        key_states,\n                        value_states,\n                        cache_seqlens=position_ids,\n                        softmax_scale=softmax_scale,\n                        causal=causal,\n                    )   \n                else:\n                    attn_output = flash_attn_func(\n                        query_states,\n                        key_states,\n                        value_states,\n                        dropout,\n                        softmax_scale=softmax_scale,\n                        causal=causal,\n                    )\n            else:\n                attn_output = flash_attn_func(\n                    query_states,\n                    key_states,\n                    value_states,\n                    dropout,\n                    softmax_scale=softmax_scale,\n                    causal=causal,\n                    window_size=(self.config.sliding_window, self.config.sliding_window),\n                )\n\n        return attn_output\n\n    # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input\n    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):\n        batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape\n\n        # On the first iteration we need to properly re-create the padding mask\n        # by slicing it on the proper place\n        if kv_seq_len != attention_mask.shape[-1]:\n            attention_mask_num_tokens = attention_mask.shape[-1]\n            attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]\n\n     
   indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)\n\n        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)\n        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)\n\n        if query_length == kv_seq_len:\n            query_layer = index_first_axis(\n                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k\n            )\n            cu_seqlens_q = cu_seqlens_k\n            max_seqlen_in_batch_q = max_seqlen_in_batch_k\n            indices_q = indices_k\n        elif query_length == 1:\n            max_seqlen_in_batch_q = 1\n            cu_seqlens_q = torch.arange(\n                batch_size + 1, dtype=torch.int32, device=query_layer.device\n            )  # There is a memcpy here, that is very bad.\n            indices_q = cu_seqlens_q[:-1]\n            query_layer = query_layer.squeeze(1)\n        else:\n            # The -q_len: slice assumes left padding.\n            attention_mask = attention_mask[:, -query_length:]\n            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)\n\n        return (\n            query_layer,\n            key_layer,\n            value_layer,\n            indices_q,\n            (cu_seqlens_q, cu_seqlens_k),\n            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),\n        )\n\n\n# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Qwen2Moe\nclass Qwen2MoeSdpaAttention(Qwen2MoeAttention):\n    \"\"\"\n    Qwen2Moe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from\n    `Qwen2MoeAttention` as the weights of the module stays untouched. 
The only changes are on the forward pass to adapt to\n    SDPA API.\n    \"\"\"\n\n    # Adapted from Qwen2MoeAttention.forward\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        if output_attentions:\n            # TODO: Improve this warning with e.g. `model.config.attn_implementation = \"manual\"` once this is implemented.\n            logger.warning_once(\n                \"Qwen2MoeModel is using Qwen2MoeSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, \"\n                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. 
This warning can be removed using the argument `attn_implementation=\"eager\"` when loading the model.'\n            )\n            return super().forward(\n                hidden_states=hidden_states,\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n                past_key_value=past_key_value,\n                output_attentions=output_attentions,\n                use_cache=use_cache,\n                cache_position=cache_position,\n            )\n\n        bsz, q_len, _ = hidden_states.size()\n\n        query_states = self.q_proj(hidden_states)\n        key_states = self.k_proj(hidden_states)\n        value_states = self.v_proj(hidden_states)\n\n        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n\n        kv_seq_len = key_states.shape[-2]\n        if past_key_value is not None:\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\n        cos, sin = self.rotary_emb(value_states, position_ids)\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n        if past_key_value is not None:\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n        key_states = repeat_kv(key_states, self.num_key_value_groups)\n        value_states = repeat_kv(value_states, self.num_key_value_groups)\n\n        causal_mask = attention_mask\n        if attention_mask is not None:  # no matter the length, we just slice it\n            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]\n\n        # SDPA with memory-efficient backend is currently 
(torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,\n        # Reference: https://github.com/pytorch/pytorch/issues/112577.\n        if query_states.device.type == \"cuda\" and attention_mask is not None:\n            query_states = query_states.contiguous()\n            key_states = key_states.contiguous()\n            value_states = value_states.contiguous()\n\n        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment\n        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.\n        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.\n        is_causal = True if causal_mask is None and q_len > 1 else False\n\n        attn_output = torch.nn.functional.scaled_dot_product_attention(\n            query_states,\n            key_states,\n            value_states,\n            attn_mask=causal_mask,\n            dropout_p=self.attention_dropout if self.training else 0.0,\n            is_causal=is_causal,\n        )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.view(bsz, q_len, self.hidden_size)\n\n        attn_output = self.o_proj(attn_output)\n\n        return attn_output, None, past_key_value\n\n\nQWEN2MOE_ATTENTION_CLASSES = {\n    \"eager\": Qwen2MoeAttention,\n    \"flash_attention_2\": Qwen2MoeFlashAttention2,\n    \"sdpa\": Qwen2MoeSdpaAttention,\n}\n\n\nclass Qwen2MoeSparseMoeBlock(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.num_experts = config.num_experts\n        self.top_k = config.num_experts_per_tok\n        self.norm_topk_prob = config.norm_topk_prob\n\n        # gating\n        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)\n        self.experts = 
nn.ModuleList(\n            [Qwen2MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]\n        )\n\n        self.shared_expert = Qwen2MoeMLP(config, intermediate_size=config.shared_expert_intermediate_size)\n        self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)\n\n    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"Route each token to its top-k experts, add the gated shared-expert output, and return\n        the mixed hidden states together with the raw router logits.\"\"\"\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        hidden_states = hidden_states.view(-1, hidden_dim)\n        # router_logits: (batch * sequence_length, n_experts)\n        router_logits = self.gate(hidden_states)\n\n        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)\n        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)\n        if self.norm_topk_prob:\n            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)\n        # we cast back to the input dtype\n        routing_weights = routing_weights.to(hidden_states.dtype)\n\n        final_hidden_states = torch.zeros(\n            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device\n        )\n\n        # One hot encode the selected experts to create an expert mask\n        # this will be used to easily index which expert is going to be solicited\n        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)\n\n        # Loop over all available experts in the model and perform the computation on each expert\n        for expert_idx in range(self.num_experts):\n            expert_layer = self.experts[expert_idx]\n            idx, top_x = torch.where(expert_mask[expert_idx])\n\n            # Index the correct hidden states and compute the expert hidden state for\n            # the current expert. 
We need to make sure to multiply the output hidden\n            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)\n            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)\n            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]\n\n            # However `index_add_` only support torch tensors for indexing so we'll use\n            # the `top_x` tensor here.\n            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))\n\n        shared_expert_output = self.shared_expert(hidden_states)\n        shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output\n\n        final_hidden_states = final_hidden_states + shared_expert_output\n\n        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)\n        return final_hidden_states, router_logits\n\n\nclass Qwen2MoeDecoderLayer(nn.Module):\n    def __init__(self, config: Qwen2MoeConfig, layer_idx: int):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n\n        self.self_attn = QWEN2MOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)\n\n        if (layer_idx not in config.mlp_only_layers) and (\n            config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0\n        ):\n            self.mlp = Qwen2MoeSparseMoeBlock(config)\n        else:\n            self.mlp = Qwen2MoeMLP(config, intermediate_size=config.intermediate_size)\n\n        self.input_layernorm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_attention_layernorm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: 
Optional[Tuple[torch.Tensor]] = None,\n        output_attentions: Optional[bool] = False,\n        output_router_logits: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs,\n    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n                `(batch, sequence_length)` where padding elements are indicated by 0.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n            output_router_logits (`bool`, *optional*):\n                Whether or not to return the logits of all the routers. 
They are useful for computing the router loss,\n                and should not be returned during inference.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):\n                Indices depicting the position of the input sequence tokens in the sequence.\n            kwargs (`dict`, *optional*):\n                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code\n                into the model\n        \"\"\"\n\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states, self_attn_weights, present_key_value = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_value=past_key_value,\n            output_attentions=output_attentions,\n            use_cache=use_cache,\n            cache_position=cache_position,\n        )\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n\n        hidden_states = self.mlp(hidden_states)\n        if isinstance(hidden_states, tuple):\n            hidden_states, router_logits = hidden_states\n        else:\n            router_logits = None\n\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (self_attn_weights,)\n\n        if use_cache:\n            outputs += (present_key_value,)\n\n        if output_router_logits:\n            
outputs += (router_logits,)\n\n        return outputs\n\n\nQWEN2MOE_START_DOCSTRING = r\"\"\"\n    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the\n    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads\n    etc.)\n\n    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage\n    and behavior.\n\n    Parameters:\n        config ([`Qwen2MoeConfig`]):\n            Model configuration class with all the parameters of the model. Initializing with a config file does not\n            load the weights associated with the model, only the configuration. Check out the\n            [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare Qwen2MoE Model outputting raw hidden-states without any specific head on top.\",\n    QWEN2MOE_START_DOCSTRING,\n)\nclass Qwen2MoePreTrainedModel(PreTrainedModel):\n    config_class = Qwen2MoeConfig\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"Qwen2MoeDecoderLayer\"]\n    _skip_keys_device_placement = \"past_key_values\"\n    _supports_flash_attn_2 = True\n    _supports_sdpa = True\n    _supports_cache_class = True\n    _supports_static_cache = True\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                
module.weight.data[module.padding_idx].zero_()\n\n\nQWEN2MOE_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see\n            `past_key_values`).\n\n            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]\n            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more\n            information on the default strategy.\n\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Indices of positions of each input sequence token in the position embeddings. 
Selected in the range `[0,\n            config.n_positions - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):\n            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention\n            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`\n            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.\n\n            Two formats are allowed:\n            - a [`~cache_utils.Cache`] instance;\n            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of\n            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy\n            cache format.\n\n            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the\n            legacy cache format will be returned.\n\n            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't\n            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`\n            of shape `(batch_size, sequence_length)`.\n        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 
This\n            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n            model's internal embedding lookup matrix.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n            `past_key_values`).\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for\n            more detail.\n        output_router_logits (`bool`, *optional*):\n            Whether or not to return the logits of all the routers. They are useful for computing the router loss, and\n            should not be returned during inference.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):\n            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,\n            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer\n            the complete sequence length.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare Qwen2MoE Model outputting raw hidden-states without any specific head on top.\",\n    QWEN2MOE_START_DOCSTRING,\n)\nclass Qwen2MoeModel(Qwen2MoePreTrainedModel):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Qwen2MoeDecoderLayer`]\n\n    Args:\n        config: Qwen2MoeConfig\n    \"\"\"\n\n    def __init__(self, config: Qwen2MoeConfig):\n        super().__init__(config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)\n        self.layers = nn.ModuleList(\n            [Qwen2MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]\n        )\n        self._attn_implementation = config._attn_implementation\n        self.norm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n        self.gradient_checkpointing = False\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        output_router_logits: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, MoeModelOutputWithPast]:\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_router_logits = (\n            output_router_logits if output_router_logits is not None else self.config.output_router_logits\n        )\n        
output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        if (input_ids is None) ^ (inputs_embeds is not None):\n            raise ValueError(\n                \"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one\"\n            )\n\n        if self.gradient_checkpointing and self.training:\n            if use_cache:\n                logger.warning_once(\n                    \"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\"\n                )\n                use_cache = False\n\n        use_legacy_cache = False\n        if use_cache and not isinstance(past_key_values, Cache):\n            use_legacy_cache = True\n            past_key_values = DynamicCache.from_legacy_cache(past_key_values)\n            logger.warning_once(\n                \"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. 
\"\n                \"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\"\n            )\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        if cache_position is None:\n            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n            cache_position = torch.arange(\n                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device\n            )\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0)\n\n        causal_mask = self._update_causal_mask(\n            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions\n        )\n\n        hidden_states = inputs_embeds\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        all_router_logits = () if output_router_logits else None\n        next_decoder_cache = None\n\n        for decoder_layer in self.layers:\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            if self.gradient_checkpointing and self.training:\n                layer_outputs = self._gradient_checkpointing_func(\n                    decoder_layer.__call__,\n                    hidden_states,\n                    causal_mask,\n                    position_ids,\n                    past_key_values,\n                    output_attentions,\n                    output_router_logits,\n                    use_cache,\n                    cache_position,\n                )\n            else:\n                layer_outputs = decoder_layer(\n                    hidden_states,\n                    attention_mask=causal_mask,\n                    position_ids=position_ids,\n                    
past_key_value=past_key_values,\n                    output_attentions=output_attentions,\n                    output_router_logits=output_router_logits,\n                    use_cache=use_cache,\n                    cache_position=cache_position,\n                )\n\n            hidden_states = layer_outputs[0]\n\n            if use_cache:\n                next_decoder_cache = layer_outputs[2 if output_attentions else 1]\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n            if output_router_logits and layer_outputs[-1] is not None:\n                all_router_logits += (layer_outputs[-1],)\n\n        hidden_states = self.norm(hidden_states)\n\n        # add hidden states from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        next_cache = None\n        if use_cache:\n            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache\n\n        if not return_dict:\n            return tuple(\n                v\n                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]\n                if v is not None\n            )\n        return MoeModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=next_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n            router_logits=all_router_logits,\n        )\n\n    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask\n    def _update_causal_mask(\n        self,\n        attention_mask: torch.Tensor,\n        input_tensor: torch.Tensor,\n        cache_position: torch.Tensor,\n        past_key_values: Cache,\n        output_attentions: bool,\n    ):\n        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static\n        # KV cache 
is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.\n        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using\n        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114\n\n        if self.config._attn_implementation == \"flash_attention_2\":\n            if attention_mask is not None and 0.0 in attention_mask:\n                return attention_mask\n            return None\n\n        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in\n        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail\n        # to infer the attention mask.\n        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n        using_static_cache = isinstance(past_key_values, StaticCache)\n\n        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward\n        if self.config._attn_implementation == \"sdpa\" and not using_static_cache and not output_attentions:\n            if AttentionMaskConverter._ignore_causal_mask_sdpa(\n                attention_mask,\n                inputs_embeds=input_tensor,\n                past_key_values_length=past_seen_tokens,\n                is_training=self.training,\n            ):\n                return None\n\n        dtype, device = input_tensor.dtype, input_tensor.device\n        min_dtype = torch.finfo(dtype).min\n        sequence_length = input_tensor.shape[1]\n        if using_static_cache:\n            target_length = past_key_values.get_max_length()\n        else:\n            target_length = (\n                attention_mask.shape[-1]\n                if isinstance(attention_mask, torch.Tensor)\n                else 
past_seen_tokens + sequence_length + 1\n            )\n\n        if attention_mask is not None and attention_mask.dim() == 4:\n            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing\n            if attention_mask.max() != 0:\n                raise ValueError(\"Custom 4D attention mask should be passed in inverted form with max==0\")\n            causal_mask = attention_mask\n        else:\n            causal_mask = torch.full(\n                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device\n            )\n            if sequence_length != 1:\n                causal_mask = torch.triu(causal_mask, diagonal=1)\n            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)\n            causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)\n            if attention_mask is not None:\n                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit\n                mask_length = attention_mask.shape[-1]\n                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]\n                padding_mask = padding_mask == 0\n                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(\n                    padding_mask, min_dtype\n                )\n        if (\n            self.config._attn_implementation == \"sdpa\"\n            and attention_mask is not None\n            and attention_mask.device.type == \"cuda\"\n            and not output_attentions\n        ):\n            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when\n            # using left padding. 
This is required by F.scaled_dot_product_attention memory-efficient attention path.\n            # Details: https://github.com/pytorch/pytorch/issues/110213\n            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)\n\n        return causal_mask\n\n\nclass Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel):\n    _tied_weights_keys = [\"lm_head.weight\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = Qwen2MoeModel(config)\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        self.router_aux_loss_coef = config.router_aux_loss_coef\n        self.num_experts = config.num_experts\n        self.num_experts_per_tok = config.num_experts_per_tok\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def set_decoder(self, decoder):\n        self.model = decoder\n\n    def get_decoder(self):\n        return self.model\n\n    @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)\n    @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        
output_hidden_states: Optional[bool] = None,\n        output_router_logits: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, MoeCausalLMOutputWithPast]:\n        r\"\"\"\n        Args:\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, Qwen2MoeForCausalLM\n\n        >>> model = Qwen2MoeForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)\n        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)\n\n        >>> prompt = \"Hey, are you conscious? Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you conscious? 
Can you talk to me?\\nI'm not conscious, but I can talk to you.\"\n        ```\"\"\"\n\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_router_logits = (\n            output_router_logits if output_router_logits is not None else self.config.output_router_logits\n        )\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            output_router_logits=output_router_logits,\n            return_dict=return_dict,\n            cache_position=cache_position,\n        )\n\n        hidden_states = outputs[0]\n        logits = self.lm_head(hidden_states)\n        logits = logits.float()\n\n        loss = None\n        if labels is not None:\n            # Shift so that tokens < n predict n\n            shift_logits = logits[..., :-1, :].contiguous()\n            shift_labels = labels[..., 1:].contiguous()\n            # Flatten the tokens\n            loss_fct = CrossEntropyLoss()\n            shift_logits = shift_logits.view(-1, self.config.vocab_size)\n            shift_labels = shift_labels.view(-1)\n            # Enable model parallelism\n            shift_labels = shift_labels.to(shift_logits.device)\n            loss = loss_fct(shift_logits, shift_labels)\n\n        aux_loss = None\n        if output_router_logits:\n            
aux_loss = load_balancing_loss_func(\n                outputs.router_logits if return_dict else outputs[-1],\n                self.num_experts,\n                self.num_experts_per_tok,\n                attention_mask,\n            )\n            if labels is not None:\n                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            if output_router_logits:\n                output = (aux_loss,) + output\n            return (loss,) + output if loss is not None else output\n\n        return MoeCausalLMOutputWithPast(\n            loss=loss,\n            aux_loss=aux_loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n            router_logits=outputs.router_logits,\n        )\n\n    def prepare_inputs_for_generation(\n        self,\n        input_ids,\n        past_key_values=None,\n        attention_mask=None,\n        inputs_embeds=None,\n        cache_position=None,\n        use_cache=True,\n        **kwargs,\n    ):\n        past_length = 0\n        # Omit tokens covered by past_key_values\n        if past_key_values is not None:\n            if isinstance(past_key_values, Cache):\n                past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()\n                max_cache_length = (\n                    torch.tensor(past_key_values.get_max_length(), device=input_ids.device)\n                    if past_key_values.get_max_length() is not None\n                    else None\n                )\n                cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)\n            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects\n            else:\n                
cache_length = past_length = past_key_values[0][0].shape[2]\n                max_cache_length = None\n\n            # Keep only the unprocessed tokens:\n            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where\n            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as\n            # input)\n            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:\n                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]\n            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard\n            # input_ids based on the past_length.\n            elif past_length < input_ids.shape[1]:\n                input_ids = input_ids[:, past_length:]\n            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.\n\n            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.\n            if (\n                max_cache_length is not None\n                and attention_mask is not None\n                and cache_length + input_ids.shape[1] > max_cache_length\n            ):\n                attention_mask = attention_mask[:, -max_cache_length:]\n\n        position_ids = kwargs.get(\"position_ids\", None)\n        if attention_mask is not None and position_ids is None:\n            # create position_ids on the fly for batch generation\n            position_ids = attention_mask.long().cumsum(-1) - 1\n            position_ids.masked_fill_(attention_mask == 0, 1)\n            if past_key_values:\n                position_ids = position_ids[:, -input_ids.shape[1] :]\n\n        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step\n        if inputs_embeds is not None and past_length == 0:\n            model_inputs = {\"inputs_embeds\": 
inputs_embeds}\n        else:\n            model_inputs = {\"input_ids\": input_ids}\n\n        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]\n        if cache_position is None:\n            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)\n        elif use_cache:\n            cache_position = cache_position[-input_length:]\n\n        model_inputs.update(\n            {\n                \"position_ids\": position_ids,\n                \"past_key_values\": past_key_values,\n                \"use_cache\": use_cache,\n                \"attention_mask\": attention_mask,\n                \"cache_position\": cache_position,\n            }\n        )\n        return model_inputs\n\n    @staticmethod\n    def _reorder_cache(past_key_values, beam_idx):\n        reordered_past = ()\n        for layer_past in past_key_values:\n            reordered_past += (\n                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),\n            )\n        return reordered_past\n\n\n@add_start_docstrings(\n    \"\"\"\n    The Qwen2MoE Model transformer with a sequence classification head on top (linear layer).\n\n    [`Qwen2MoeForSequenceClassification`] uses the last token in order to do the classification, as other causal models\n    (e.g. GPT-2) do.\n\n    Since it does classification on the last token, it needs to know the position of the last token. If a\n    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If\n    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. 
Since it cannot guess the\n    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in\n    each row of the batch).\n    \"\"\",\n    QWEN2MOE_START_DOCSTRING,\n)\n# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Qwen2Moe, LLAMA->QWEN2MOE\nclass Qwen2MoeForSequenceClassification(Qwen2MoePreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.model = Qwen2MoeModel(config)\n        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n            config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        transformer_outputs = self.model(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = transformer_outputs[0]\n        logits = self.score(hidden_states)\n\n        if input_ids is not None:\n            batch_size = input_ids.shape[0]\n        else:\n            batch_size = inputs_embeds.shape[0]\n\n        if self.config.pad_token_id is None and batch_size != 1:\n            raise ValueError(\"Cannot handle batch sizes > 1 if no padding token is defined.\")\n        if self.config.pad_token_id is None:\n            sequence_lengths = -1\n        else:\n            if input_ids is not None:\n                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility\n                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1\n                sequence_lengths = sequence_lengths % input_ids.shape[-1]\n                sequence_lengths = sequence_lengths.to(logits.device)\n            else:\n                sequence_lengths = -1\n\n        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]\n\n        loss = None\n        if labels is not None:\n            labels = labels.to(logits.device)\n            if self.config.problem_type is None:\n                if self.num_labels == 1:\n                    self.config.problem_type 
= \"regression\"\n                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):\n                    self.config.problem_type = \"single_label_classification\"\n                else:\n                    self.config.problem_type = \"multi_label_classification\"\n\n            if self.config.problem_type == \"regression\":\n                loss_fct = MSELoss()\n                if self.num_labels == 1:\n                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())\n                else:\n                    loss = loss_fct(pooled_logits, labels)\n            elif self.config.problem_type == \"single_label_classification\":\n                loss_fct = CrossEntropyLoss()\n                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))\n            elif self.config.problem_type == \"multi_label_classification\":\n                loss_fct = BCEWithLogitsLoss()\n                loss = loss_fct(pooled_logits, labels)\n        if not return_dict:\n            output = (pooled_logits,) + transformer_outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return SequenceClassifierOutputWithPast(\n            loss=loss,\n            logits=pooled_logits,\n            past_key_values=transformer_outputs.past_key_values,\n            hidden_states=transformer_outputs.hidden_states,\n            attentions=transformer_outputs.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    The Qwen2MoE Model transformer with a token classification head on top (a linear layer on top of the hidden-states\n    output) e.g. 
for Named-Entity-Recognition (NER) tasks.\n    \"\"\",\n    QWEN2MOE_START_DOCSTRING,\n)\n# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2Moe, LLAMA->QWEN2MOE\nclass Qwen2MoeForTokenClassification(Qwen2MoePreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.model = Qwen2MoeModel(config)\n        if getattr(config, \"classifier_dropout\", None) is not None:\n            classifier_dropout = config.classifier_dropout\n        elif getattr(config, \"hidden_dropout\", None) is not None:\n            classifier_dropout = config.hidden_dropout\n        else:\n            classifier_dropout = 0.1\n        self.dropout = nn.Dropout(classifier_dropout)\n        self.score = nn.Linear(config.hidden_size, config.num_labels)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, TokenClassifierOutput]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Labels for computing the token classification loss. 
Indices should be in `[0, ...,\n            config.num_labels - 1]`. A classification loss (Cross-Entropy) is computed over all token positions.\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        outputs = self.model(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        sequence_output = outputs[0]\n        sequence_output = self.dropout(sequence_output)\n        logits = self.score(sequence_output)\n\n        loss = None\n        if labels is not None:\n            loss_fct = CrossEntropyLoss()\n            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))\n\n        if not return_dict:\n            output = (logits,) + outputs[2:]\n            return ((loss,) + output) if loss is not None else output\n\n        return TokenClassifierOutput(\n            loss=loss,\n            logits=logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n"
  },
  {
    "path": "kt-sft/ktransformers/models/modeling_qwen3_moe.py",
    "content": "#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨\n#           This file was automatically generated from src/transformers/models/qwen3_moe/modular_qwen3_moe.py.\n#               Do NOT edit this file manually as any edits will be overwritten by the generation of\n#             the file from the modular. If any change should be done, please apply the change to the\n#                          modular_qwen3_moe.py file directly. One of our CI enforces this.\n#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨\n# coding=utf-8\n# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom typing import Callable, List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom transformers.activations import ACT2FN\nfrom transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache\nfrom transformers.generation import GenerationMixin\nfrom transformers.modeling_attn_mask_utils import AttentionMaskConverter\n# from transformers.modeling_flash_attention_utils import FlashAttentionKwargs\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPast,\n    CausalLMOutputWithPast,\n    MoeCausalLMOutputWithPast,\n    MoeModelOutputWithPast,\n    QuestionAnsweringModelOutput,\n    SequenceClassifierOutputWithPast,\n    TokenClassifierOutput,\n)\nfrom 
transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS\n# from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel\nfrom transformers.modeling_utils import PreTrainedModel\n# from transformers.processing_utils import Unpack\nfrom transformers.utils import (\n    # LossKwargs,\n    add_code_sample_docstrings,\n    add_start_docstrings,\n    add_start_docstrings_to_model_forward,\n    logging,\n    replace_return_docstrings,\n)\nfrom transformers.utils.deprecation import deprecate_kwarg\nfrom .configuration_qwen3_moe import Qwen3MoeConfig\n\nfrom ktransformers.util.grad_wrapper import maybe_no_grad\nfrom ktransformers.models.modeling_qwen2_moe import Qwen2MoeRotaryEmbedding\n\nlogger = logging.get_logger(__name__)\n\n_CHECKPOINT_FOR_DOC = \"Qwen/Qwen3-MoE-15B-A2B\"\n_CONFIG_FOR_DOC = \"Qwen3MoeConfig\"\n\n\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n\ndef apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):\n    \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n    Args:\n        q (`torch.Tensor`): The query tensor.\n        k (`torch.Tensor`): The key tensor.\n        cos (`torch.Tensor`): The cosine part of the rotary embedding.\n        sin (`torch.Tensor`): The sine part of the rotary embedding.\n        position_ids (`torch.Tensor`, *optional*):\n            Deprecated and unused.\n        unsqueeze_dim (`int`, *optional*, defaults to 1):\n            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. 
Then, if q and\n            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have\n            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n    Returns:\n        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n    \"\"\"\n    cos = cos.unsqueeze(unsqueeze_dim)\n    sin = sin.unsqueeze(unsqueeze_dim)\n    q_embed = (q * cos) + (rotate_half(q) * sin)\n    k_embed = (k * cos) + (rotate_half(k) * sin)\n    return q_embed, k_embed\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\ndef eager_attention_forward(\n    module: nn.Module,\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    attention_mask: Optional[torch.Tensor],\n    scaling: float,\n    dropout: float = 0.0,\n    **kwargs,\n):\n    key_states = repeat_kv(key, module.num_key_value_groups)\n    value_states = repeat_kv(value, module.num_key_value_groups)\n\n    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling\n    if attention_mask is not None:\n        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]\n        attn_weights = attn_weights + causal_mask\n\n    attn_weights = nn.functional.softmax(attn_weights, dim=-1, 
dtype=torch.float32).to(query.dtype)\n    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)\n    attn_output = torch.matmul(attn_weights, value_states)\n    attn_output = attn_output.transpose(1, 2).contiguous()\n\n    return attn_output, attn_weights\n\n\nclass Qwen3MoeAttention(nn.Module):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config: Qwen3MoeConfig, layer_idx: int):\n        super().__init__()\n        self.config = config\n        self.layer_idx = layer_idx\n        self.head_dim = getattr(config, \"head_dim\", config.hidden_size // config.num_attention_heads)\n        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads\n        self.num_heads = config.num_attention_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.max_position_embeddings = config.max_position_embeddings\n        self.rope_theta = config.rope_theta\n        self.scaling = self.head_dim**-0.5\n        self.attention_dropout = config.attention_dropout\n        self.is_causal = True\n\n        self.q_proj = nn.Linear(\n            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias\n        )\n        self.k_proj = nn.Linear(\n            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias\n        )\n        self.v_proj = nn.Linear(\n            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias\n        )\n        self.o_proj = nn.Linear(\n            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias\n        )\n        self.q_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!\n        self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # thus post q_norm does not need reshape\n\n        self.rotary_emb 
= Qwen2MoeRotaryEmbedding(\n            self.head_dim,\n            max_position_embeddings=self.max_position_embeddings,\n            base=self.rope_theta,\n        )\n\n        self.sliding_window = config.sliding_window\n        if not (\n            self.config.use_sliding_window\n            and getattr(self.config, \"sliding_window\", None) is not None\n            and self.layer_idx >= self.config.max_window_layers\n        ):\n            self.sliding_window = None\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_embeddings: Tuple[torch.Tensor, torch.Tensor],\n        attention_mask: Optional[torch.Tensor],\n        past_key_value: Optional[Cache] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs,  # extra keyword arguments (e.g. `output_attentions`) passed by the decoder layer\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        input_shape = hidden_states.shape[:-1]\n        hidden_shape = (*input_shape, -1, self.head_dim)\n\n        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)\n        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)\n        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n\n        cos, sin = position_embeddings\n        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n        if past_key_value is not None:\n            # sin and cos are specific to RoPE models; cache_position needed for the static cache\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}\n            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n        attention_interface: Callable = eager_attention_forward\n        if self.config._attn_implementation != \"eager\":\n            if self.config._attn_implementation == \"sdpa\" and 
kwargs.get(\"output_attentions\", False):\n                logger.warning_once(\n                    \"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to \"\n                    'eager attention. This warning can be removed using the argument `attn_implementation=\"eager\"` when loading the model.'\n                )\n            else:\n                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]\n\n        attn_output, attn_weights = attention_interface(\n            self,\n            query_states,\n            key_states,\n            value_states,\n            attention_mask,\n            dropout=0.0 if not self.training else self.attention_dropout,\n            scaling=self.scaling,\n            sliding_window=self.sliding_window,  # diff with Llama\n            # **kwargs,\n        )\n\n        attn_output = attn_output.reshape(*input_shape, -1).contiguous()\n        attn_output = self.o_proj(attn_output)\n        return attn_output, attn_weights\n\n\nclass Qwen3MoeMLP(nn.Module):\n    def __init__(self, config, intermediate_size=None):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size\n        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)\n        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x):\n        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))\n        return down_proj\n\n\nclass Qwen3MoeSparseMoeBlock(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.num_experts = config.num_experts\n        
self.top_k = config.num_experts_per_tok\n        self.norm_topk_prob = config.norm_topk_prob\n\n        # gating\n        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)\n        self.experts = nn.ModuleList(\n            [Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]\n        )\n\n    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n        \"\"\"Route each token to its top-k experts; return the combined expert output and the router logits.\"\"\"\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        hidden_states = hidden_states.view(-1, hidden_dim)\n        # router_logits: (batch * sequence_length, n_experts)\n        router_logits = self.gate(hidden_states)\n\n        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)\n        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)\n        if self.norm_topk_prob:  # only diff with mixtral sparse moe block!\n            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)\n        # we cast back to the input dtype\n        routing_weights = routing_weights.to(hidden_states.dtype)\n\n        final_hidden_states = torch.zeros(\n            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device\n        )\n\n        # One hot encode the selected experts to create an expert mask\n        # this will be used to easily index which expert is going to be solicited\n        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)\n\n        # Loop over all available experts in the model and perform the computation on each expert\n        for expert_idx in range(self.num_experts):\n            expert_layer = self.experts[expert_idx]\n            idx, top_x = torch.where(expert_mask[expert_idx])\n\n            # Index the correct hidden states and compute the expert hidden state for\n            # the current expert. 
We need to make sure to multiply the output hidden\n            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)\n            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)\n            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]\n\n            # However `index_add_` only supports torch tensors for indexing so we'll use\n            # the `top_x` tensor here.\n            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))\n        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)\n        return final_hidden_states, router_logits\n\n\nclass Qwen3MoeRMSNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-6):\n        \"\"\"\n        Qwen3MoeRMSNorm is equivalent to T5LayerNorm\n        \"\"\"\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(hidden_size))\n        self.variance_epsilon = eps\n        self.hidden_size = hidden_size\n\n    def forward(self, hidden_states):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n    def extra_repr(self):\n        return f\"{tuple(self.weight.shape)}, eps={self.variance_epsilon}\"\n\n\nclass Qwen3MoeDecoderLayer(nn.Module):\n    def __init__(self, config: Qwen3MoeConfig, layer_idx: int):\n        super().__init__()\n        self.hidden_size = config.hidden_size\n\n        self.self_attn = Qwen3MoeAttention(config, layer_idx)\n\n        if (layer_idx not in config.mlp_only_layers) and (\n            config.num_experts > 0 and (layer_idx + 1) % 
config.decoder_sparse_step == 0\n        ):\n            self.mlp = Qwen3MoeSparseMoeBlock(config)\n        else:\n            self.mlp = Qwen3MoeMLP(config, intermediate_size=config.intermediate_size)\n\n        self.input_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.post_attention_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Tuple[torch.Tensor]] = None,\n        output_attentions: Optional[bool] = False,\n        output_router_logits: Optional[bool] = False,\n        use_cache: Optional[bool] = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC\n        # **kwargs: Unpack[FlashAttentionKwargs],\n    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:\n        \"\"\"\n        Args:\n            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`\n            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size\n                `(batch, sequence_length)` where padding elements are indicated by 0.\n            output_attentions (`bool`, *optional*):\n                Whether or not to return the attentions tensors of all attention layers. See `attentions` under\n                returned tensors for more detail.\n            output_router_logits (`bool`, *optional*):\n                Whether or not to return the logits of all the routers. 
They are useful for computing the router loss,\n                and should not be returned during inference.\n            use_cache (`bool`, *optional*):\n                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding\n                (see `past_key_values`).\n            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states\n            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):\n                Indices depicting the position of the input sequence tokens in the sequence.\n            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):\n                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,\n                with `head_dim` being the embedding dimension of each attention head.\n            kwargs (`dict`, *optional*):\n                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code\n                into the model\n        \"\"\"\n        residual = hidden_states\n\n        hidden_states = self.input_layernorm(hidden_states)\n\n        # Self Attention\n        hidden_states, self_attn_weights = self.self_attn(\n            hidden_states=hidden_states,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_value=past_key_value,\n            output_attentions=output_attentions,\n            use_cache=use_cache,\n            cache_position=cache_position,\n            position_embeddings=position_embeddings,\n        )\n        hidden_states = residual + hidden_states\n\n        # Fully Connected\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n\n        hidden_states = self.mlp(hidden_states)\n        if isinstance(hidden_states, tuple):\n            hidden_states, router_logits = hidden_states\n        
else:\n            router_logits = None\n\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (self_attn_weights,)\n\n        if output_router_logits:\n            outputs += (router_logits,)\n\n        return outputs\n\n\ndef _compute_default_rope_parameters(\n    config: Optional[Qwen3MoeConfig] = None,\n    device: Optional[\"torch.device\"] = None,\n    seq_len: Optional[int] = None,\n    **rope_kwargs,\n) -> Tuple[\"torch.Tensor\", float]:\n    \"\"\"\n    Computes the inverse frequencies according to the original RoPE implementation\n    Args:\n        config ([`~transformers.PretrainedConfig`]):\n            The model configuration.\n        device (`torch.device`):\n            The device to use for initialization of the inverse frequencies.\n        seq_len (`int`, *optional*):\n            The current sequence length. Unused for this type of RoPE.\n        rope_kwargs (`Dict`, *optional*):\n            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.\n    Returns:\n        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the\n        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).\n    \"\"\"\n    if config is not None and len(rope_kwargs) > 0:\n        raise ValueError(\n            \"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in \"\n            f\"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}\"\n        )\n    if len(rope_kwargs) > 0:\n        base = rope_kwargs[\"base\"]\n        dim = rope_kwargs[\"dim\"]\n    elif config is not None:\n        base = config.rope_theta\n        partial_rotary_factor = config.partial_rotary_factor if hasattr(config, \"partial_rotary_factor\") else 1.0\n        dim = int(config.head_dim * partial_rotary_factor)\n\n    
attention_factor = 1.0  # Unused in this type of RoPE\n\n    # Compute the inverse frequencies\n    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))\n    return inv_freq, attention_factor\n\nclass Qwen3MoeRotaryEmbedding(nn.Module):\n    def __init__(self, config: Qwen3MoeConfig, device=None):\n        super().__init__()\n        # BC: \"rope_type\" was originally \"type\"\n        if hasattr(config, \"rope_scaling\") and config.rope_scaling is not None:\n            self.rope_type = config.rope_scaling.get(\"rope_type\", config.rope_scaling.get(\"type\"))\n        else:\n            self.rope_type = \"default\"\n        self.max_seq_len_cached = config.max_position_embeddings\n        self.original_max_seq_len = config.max_position_embeddings\n\n        self.config = config\n        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]\n\n        self.scaling_factor = 1.0\n        self.dim = config.head_dim\n        self.max_position_embeddings = config.max_position_embeddings\n        self.base = config.rope_theta\n        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))\n\n        inv_freq, self.attention_scaling = _compute_default_rope_parameters(self.config, device)\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        self.original_inv_freq = self.inv_freq\n\n    def _dynamic_frequency_update(self, position_ids, device):\n        \"\"\"\n        dynamic RoPE layers should recompute `inv_freq` in the following situations:\n        1 - growing beyond the cached sequence length (allow scaling)\n        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)\n        \"\"\"\n        seq_len = torch.max(position_ids) + 1\n        if seq_len > self.max_seq_len_cached:  # growth\n            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)\n 
           self.register_buffer(\"inv_freq\", inv_freq, persistent=False)  # TODO joao: may break with compilation\n            self.max_seq_len_cached = seq_len\n\n        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset\n            # This .to() is needed if the model has been moved to a device after being initialized (because\n            # the buffer is automatically moved, but not the original copy)\n            self.original_inv_freq = self.original_inv_freq.to(device)\n            self.register_buffer(\"inv_freq\", self.original_inv_freq, persistent=False)\n            self.max_seq_len_cached = self.original_max_seq_len\n\n    @maybe_no_grad()\n    def forward(self, x, position_ids):\n        if \"dynamic\" in self.rope_type:\n            self._dynamic_frequency_update(position_ids, device=x.device)\n\n        # Core RoPE block\n        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)\n        position_ids_expanded = position_ids[:, None, :].float()\n        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)\n        device_type = x.device.type\n        device_type = device_type if isinstance(device_type, str) and device_type != \"mps\" else \"cpu\"\n\n        with torch.autocast(device_type=device_type, enabled=False):\n            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos()\n            sin = emb.sin()\n\n        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention\n        cos = cos * self.attention_scaling\n        sin = sin * self.attention_scaling\n\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)\n\n\nQWEN3_MOE_START_DOCSTRING = r\"\"\"\n    This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the\n    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads\n    etc.)\n\n    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage\n    and behavior.\n\n    Parameters:\n        config ([`Qwen3MoeConfig`]):\n            Model configuration class with all the parameters of the model. Initializing with a config file does not\n            load the weights associated with the model, only the configuration. Check out the\n            [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare Qwen3Moe Model outputting raw hidden-states without any specific head on top.\",\n    QWEN3_MOE_START_DOCSTRING,\n)\nclass Qwen3MoePreTrainedModel(PreTrainedModel):\n    config_class = Qwen3MoeConfig\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"Qwen3MoeDecoderLayer\"]\n    _skip_keys_device_placement = [\"past_key_values\"]\n    _supports_flash_attn_2 = True\n    _supports_sdpa = True\n    _supports_flex_attn = True\n    _supports_cache_class = True\n    _supports_quantized_cache = True\n    _supports_static_cache = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)\n    _supports_attention_backend = True\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n 
               module.weight.data[module.padding_idx].zero_()\n\n\nQWEN3_MOE_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see\n            `past_key_values`).\n\n            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]\n            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more\n            information on the default strategy.\n\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Indices of positions of each input sequence tokens in the position embeddings. 
Selected in the range `[0,\n            config.max_position_embeddings - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n        past_key_values (`Cache`, *optional*):\n            Pre-computed hidden-states (keys and values in the self-attention blocks and in the cross-attention\n            blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`\n            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.\n\n            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).\n\n            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't\n            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`\n            of shape `(batch_size, sequence_length)`.\n        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This\n            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n            model's internal embedding lookup matrix.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n            `past_key_values`).\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attention tensors of all attention layers. See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors for\n            more detail.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):\n            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,\n            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer\n            the complete sequence length.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare Qwen3Moe Model outputting raw hidden-states without any specific head on top.\",\n    QWEN3_MOE_START_DOCSTRING,\n)\nclass Qwen3MoeModel(Qwen3MoePreTrainedModel):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen3MoeDecoderLayer`]\n\n    Args:\n        config: Qwen3MoeConfig\n    \"\"\"\n\n    def __init__(self, config: Qwen3MoeConfig):\n        super().__init__(config)\n        self.padding_idx = config.pad_token_id\n        self.vocab_size = config.vocab_size\n\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)\n        self.layers = nn.ModuleList(\n            [Qwen3MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]\n        )\n        self.norm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)\n        self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config)\n        self.gradient_checkpointing = False\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        
attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        output_router_logits: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        # **flash_attn_kwargs: Unpack[FlashAttentionKwargs],\n    ) -> Union[Tuple, MoeModelOutputWithPast]:\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_router_logits = (\n            output_router_logits if output_router_logits is not None else self.config.output_router_logits\n        )\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        if (input_ids is None) ^ (inputs_embeds is not None):\n            raise ValueError(\"You must specify exactly one of input_ids or inputs_embeds\")\n\n        if self.gradient_checkpointing and self.training:\n            if use_cache:\n                logger.warning_once(\n                    \"`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...\"\n                )\n                use_cache = False\n\n        if use_cache and past_key_values is None:\n            past_key_values = DynamicCache()\n\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids)\n\n        if cache_position is None:\n            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n            cache_position = torch.arange(\n                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device\n            )\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0)\n\n        causal_mask = self._update_causal_mask(\n            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions\n        )\n\n        hidden_states = inputs_embeds\n\n        # create position embeddings to be shared across the decoder layers\n        position_embeddings = self.rotary_emb(hidden_states, position_ids)\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        all_router_logits = () if output_router_logits else None\n\n        for decoder_layer in self.layers:\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            if self.gradient_checkpointing and self.training:\n                layer_outputs = self._gradient_checkpointing_func(\n                    decoder_layer.__call__,\n                    hidden_states,\n                    causal_mask,\n                    position_ids,\n                    past_key_values,\n                    output_attentions,\n                    output_router_logits,\n                    use_cache,\n                    cache_position,\n                    position_embeddings,\n                )\n            else:\n                layer_outputs = decoder_layer(\n            
        hidden_states,\n                    attention_mask=causal_mask,\n                    position_ids=position_ids,\n                    past_key_value=past_key_values,\n                    output_attentions=output_attentions,\n                    output_router_logits=output_router_logits,\n                    use_cache=use_cache,\n                    cache_position=cache_position,\n                    position_embeddings=position_embeddings,\n                    # **flash_attn_kwargs,\n                )\n\n            hidden_states = layer_outputs[0]\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n            if output_router_logits:\n                all_router_logits += (layer_outputs[-1],)\n\n        hidden_states = self.norm(hidden_states)\n\n        # add hidden states from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        output = MoeModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=past_key_values,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n            router_logits=all_router_logits,\n        )\n        return output if return_dict else output.to_tuple()\n\n    def _update_causal_mask(\n        self,\n        attention_mask: torch.Tensor,\n        input_tensor: torch.Tensor,\n        cache_position: torch.Tensor,\n        past_key_values: Cache,\n        output_attentions: bool = False,\n    ):\n        if self.config._attn_implementation == \"flash_attention_2\":\n            if attention_mask is not None and past_key_values is not None:\n                is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]\n                if is_padding_right:\n                    raise ValueError(\n                        \"You are attempting to perform batched generation with padding_side='right'\"\n                        \" this may 
lead to unexpected behaviour for Flash Attention version of Qwen3Moe. Make sure to \"\n                        \" call `tokenizer.padding_side  = 'left'` before tokenizing the input. \"\n                    )\n            if attention_mask is not None and 0.0 in attention_mask:\n                return attention_mask\n            return None\n\n        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in\n        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail\n        # to infer the attention mask.\n        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0\n        using_static_cache = isinstance(past_key_values, StaticCache)\n        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)\n\n        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward\n        if (\n            self.config._attn_implementation == \"sdpa\"\n            and not (using_static_cache or using_sliding_window_cache)\n            and not output_attentions\n        ):\n            if AttentionMaskConverter._ignore_causal_mask_sdpa(\n                attention_mask,\n                inputs_embeds=input_tensor,\n                past_key_values_length=past_seen_tokens,\n                sliding_window=self.config.sliding_window,\n                is_training=self.training,\n            ):\n                return None\n\n        dtype, device = input_tensor.dtype, input_tensor.device\n        min_dtype = torch.finfo(dtype).min\n        sequence_length = input_tensor.shape[1]\n        # SlidingWindowCache or StaticCache\n        if using_sliding_window_cache or using_static_cache:\n            target_length = past_key_values.get_max_cache_shape()\n        # DynamicCache or no cache\n        else:\n            target_length = (\n                
attention_mask.shape[-1]\n                if isinstance(attention_mask, torch.Tensor)\n                else past_seen_tokens + sequence_length + 1\n            )\n\n        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).\n        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(\n            attention_mask,\n            sequence_length=sequence_length,\n            target_length=target_length,\n            dtype=dtype,\n            device=device,\n            cache_position=cache_position,\n            batch_size=input_tensor.shape[0],\n            config=self.config,\n            past_key_values=past_key_values,\n        )\n\n        if (\n            self.config._attn_implementation == \"sdpa\"\n            and attention_mask is not None\n            and attention_mask.device.type in [\"cuda\", \"xpu\"]\n            and not output_attentions\n        ):\n            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when\n            # using left padding. 
This is required by the F.scaled_dot_product_attention memory-efficient attention path.\n            # Details: https://github.com/pytorch/pytorch/issues/110213\n            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)\n\n        return causal_mask\n\n    @staticmethod\n    def _prepare_4d_causal_attention_mask_with_cache_position(\n        attention_mask: torch.Tensor,\n        sequence_length: int,\n        target_length: int,\n        dtype: torch.dtype,\n        device: torch.device,\n        cache_position: torch.Tensor,\n        batch_size: int,\n        config: Qwen3MoeConfig,\n        past_key_values: Cache,\n    ):\n        \"\"\"\n        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape\n        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, returns it unchanged.\n\n        Args:\n            attention_mask (`torch.Tensor`):\n                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.\n            sequence_length (`int`):\n                The sequence length being processed.\n            target_length (`int`):\n                The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.\n            dtype (`torch.dtype`):\n                The dtype to use for the 4D attention mask.\n            device (`torch.device`):\n                The device to place the 4D attention mask on.\n            cache_position (`torch.Tensor`):\n                Indices depicting the position of the input sequence tokens in the sequence.\n            batch_size (`int`):\n                Batch size.\n            config (`Qwen3MoeConfig`):\n                The model's configuration class\n            past_key_values (`Cache`):\n              
  The cache class that is being used currently to generate\n        \"\"\"\n        if attention_mask is not None and attention_mask.dim() == 4:\n            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.\n            causal_mask = attention_mask\n        else:\n            min_dtype = torch.finfo(dtype).min\n            causal_mask = torch.full(\n                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device\n            )\n            diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)\n            if config.sliding_window is not None:\n                # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also\n                # the check is needed to verify is current checkpoint was trained with sliding window or not\n                if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:\n                    sliding_attend_mask = torch.arange(target_length, device=device) <= (\n                        cache_position.reshape(-1, 1) - config.sliding_window\n                    )\n                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)\n            causal_mask *= diagonal_attend_mask\n            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)\n            if attention_mask is not None:\n                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit\n                if attention_mask.shape[-1] > target_length:\n                    attention_mask = attention_mask[:, :target_length]\n                mask_length = attention_mask.shape[-1]\n                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(\n                    causal_mask.device\n                )\n                padding_mask = padding_mask == 0\n           
     causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(\n                    padding_mask, min_dtype\n                )\n        return causal_mask\n\n\n# class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...\nclass KwargsForCausalLM(): ...\n\n\ndef load_balancing_loss_func(\n    gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],\n    num_experts: Optional[int] = None,\n    top_k=2,\n    attention_mask: Optional[torch.Tensor] = None,\n) -> Union[torch.Tensor, int]:\n    r\"\"\"\n    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.\n\n    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss\n    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between\n    experts is too unbalanced.\n\n    Args:\n        gate_logits:\n            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of\n            shape [batch_size X sequence_length, num_experts].\n        num_experts:\n            Number of experts\n        top_k:\n            The number of experts to route per-token, can be also interpreted as the `top-k` routing\n            parameter.\n        attention_mask (`torch.Tensor`, *optional*):\n            The attention_mask used in forward function\n            shape [batch_size X sequence_length] if not None.\n\n    Returns:\n        The auxiliary loss.\n    \"\"\"\n    if gate_logits is None or not isinstance(gate_logits, tuple):\n        return 0\n\n    if isinstance(gate_logits, tuple):\n        compute_device = gate_logits[0].device\n        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)\n\n    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)\n\n    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)\n\n    expert_mask = 
torch.nn.functional.one_hot(selected_experts, num_experts)\n\n    if attention_mask is None:\n        # Compute the percentage of tokens routed to each experts\n        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)\n\n        # Compute the average probability of routing to these experts\n        router_prob_per_expert = torch.mean(routing_weights, dim=0)\n    else:\n        batch_size, sequence_length = attention_mask.shape\n        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)\n\n        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask\n        expert_attention_mask = (\n            attention_mask[None, :, :, None, None]\n            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))\n            .reshape(-1, top_k, num_experts)\n            .to(compute_device)\n        )\n\n        # Compute the percentage of tokens routed to each experts\n        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(\n            expert_attention_mask, dim=0\n        )\n\n        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert\n        router_per_expert_attention_mask = (\n            attention_mask[None, :, :, None]\n            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))\n            .reshape(-1, num_experts)\n            .to(compute_device)\n        )\n\n        # Compute the average probability of routing to these experts\n        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(\n            router_per_expert_attention_mask, dim=0\n        )\n\n    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))\n    return overall_loss * num_experts\n\n\nclass Qwen3MoeForCausalLM(Qwen3MoePreTrainedModel, GenerationMixin):\n    _tied_weights_keys = 
[\"lm_head.weight\"]\n    _tp_plan = {\"lm_head\": \"colwise_rep\"}\n    _pp_plan = {\"lm_head\": ([\"hidden_states\"], [\"logits\"])}\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = Qwen3MoeModel(config)\n        self.vocab_size = config.vocab_size\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.router_aux_loss_coef = config.router_aux_loss_coef\n        self.num_experts = config.num_experts\n        self.num_experts_per_tok = config.num_experts_per_tok\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    def get_output_embeddings(self):\n        return self.lm_head\n\n    def set_output_embeddings(self, new_embeddings):\n        self.lm_head = new_embeddings\n\n    def set_decoder(self, decoder):\n        self.model = decoder\n\n    def get_decoder(self):\n        return self.model\n\n    @deprecate_kwarg(\"num_logits_to_keep\", version=\"4.50\", new_name=\"logits_to_keep\")\n    @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING)\n    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        output_router_logits: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        
cache_position: Optional[torch.LongTensor] = None,\n        logits_to_keep: Union[int, torch.Tensor] = 0,\n        # **kwargs: Unpack[KwargsForCausalLM],\n    ) -> Union[Tuple, MoeCausalLMOutputWithPast]:\n        r\"\"\"\n            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n                Labels for computing the causal language modeling loss. Indices should either be in `[0, ...,\n                config.vocab_size - 1]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored\n                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size - 1]`.\n\n            logits_to_keep (`int` or `torch.Tensor`, *optional*):\n                If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all\n                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that\n                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.\n                If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.\n                This is useful when using packed tensor format (single dimension for batch and sequence length).\n\n        Returns:\n\n        Example:\n\n        ```python\n        >>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM\n\n        >>> model = Qwen3MoeForCausalLM.from_pretrained(\"Qwen/Qwen3-MoE-15B-A2B\")\n        >>> tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen3-MoE-15B-A2B\")\n\n        >>> prompt = \"Hey, are you conscious? 
Can you talk to me?\"\n        >>> inputs = tokenizer(prompt, return_tensors=\"pt\")\n\n        >>> # Generate\n        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)\n        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n        \"Hey, are you conscious? Can you talk to me?\\nI'm not conscious, but I can talk to you.\"\n        ```\"\"\"\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_router_logits = (\n            output_router_logits if output_router_logits is not None else self.config.output_router_logits\n        )\n\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\n        outputs = self.model(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            output_router_logits=output_router_logits,\n            return_dict=return_dict,\n            cache_position=cache_position,\n            # **kwargs,\n        )\n\n        hidden_states = outputs[0]\n        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss\n        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep\n        logits = self.lm_head(hidden_states[:, slice_indices, :])\n\n        loss = None\n        if labels is not None:\n            loss = 
self.loss_function(logits, labels, self.vocab_size)\n\n        aux_loss = None\n        if output_router_logits:\n            aux_loss = load_balancing_loss_func(\n                outputs.router_logits if return_dict else outputs[-1],\n                self.num_experts,\n                self.num_experts_per_tok,\n                attention_mask,\n            )\n            if labels is not None:\n                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device\n\n        if not return_dict:\n            output = (logits,) + outputs[1:]\n            if output_router_logits:\n                output = (aux_loss,) + output\n            return (loss,) + output if loss is not None else output\n\n        return MoeCausalLMOutputWithPast(\n            loss=loss,\n            aux_loss=aux_loss,\n            logits=logits,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n            router_logits=outputs.router_logits,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    The Qwen3Moe Model transformer with a sequence classification head on top (linear layer).\n\n    [`Qwen3MoeForSequenceClassification`] uses the last token in order to do the classification, as other causal models\n    (e.g. GPT-2) do.\n\n    Since it does classification on the last token, it requires to know the position of the last token. If a\n    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If\n    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. 
Since it cannot guess the\n    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in\n    each row of the batch).\n    \"\"\",\n    QWEN3_MOE_START_DOCSTRING,\n)\nclass Qwen3MoeForSequenceClassification(Qwen3MoePreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.model = Qwen3MoeModel(config)\n        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n            config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        transformer_outputs = self.model(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = transformer_outputs[0]\n        logits = self.score(hidden_states)\n\n        if input_ids is not None:\n            batch_size = input_ids.shape[0]\n        else:\n            batch_size = inputs_embeds.shape[0]\n\n        if self.config.pad_token_id is None and batch_size != 1:\n            raise ValueError(\"Cannot handle batch sizes > 1 if no padding token is defined.\")\n        if self.config.pad_token_id is None:\n            last_non_pad_token = -1\n        elif input_ids is not None:\n            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id\n            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)\n            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)\n            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)\n        else:\n            last_non_pad_token = -1\n            logger.warning_once(\n                f\"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be \"\n                \"unexpected if using padding tokens in conjunction with `inputs_embeds.`\"\n            )\n\n        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]\n\n        loss = None\n        if labels is not None:\n            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)\n\n        if not return_dict:\n            output = (pooled_logits,) + transformer_outputs[1:]\n            return ((loss,) + output) if loss is not None else output\n\n        return SequenceClassifierOutputWithPast(\n            loss=loss,\n            logits=pooled_logits,\n            past_key_values=transformer_outputs.past_key_values,\n            hidden_states=transformer_outputs.hidden_states,\n            attentions=transformer_outputs.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    The Qwen3Moe Model transformer with a token classification head on top (a linear layer on top of the hidden-states\n    output) e.g. 
for Named-Entity-Recognition (NER) tasks.\n    \"\"\",\n    QWEN3_MOE_START_DOCSTRING,\n)\nclass Qwen3MoeForTokenClassification(Qwen3MoePreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n        self.model = Qwen3MoeModel(config)\n        if getattr(config, \"classifier_dropout\", None) is not None:\n            classifier_dropout = config.classifier_dropout\n        elif getattr(config, \"hidden_dropout\", None) is not None:\n            classifier_dropout = config.hidden_dropout\n        else:\n            classifier_dropout = 0.1\n        self.dropout = nn.Dropout(classifier_dropout)\n        self.score = nn.Linear(config.hidden_size, config.num_labels)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.model.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.model.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING)\n    @add_code_sample_docstrings(\n        checkpoint=_CHECKPOINT_FOR_DOC,\n        output_type=TokenClassifierOutput,\n        config_class=_CONFIG_FOR_DOC,\n    )\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        labels: Optional[torch.LongTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ) -> Union[Tuple, TokenClassifierOutput]:\n        r\"\"\"\n        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Labels for computing the token 
classification loss. Indices should be in `[0, ...,\n            config.num_labels - 1]`.\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        outputs = self.model(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        sequence_output = outputs[0]\n        sequence_output = self.dropout(sequence_output)\n        logits = self.score(sequence_output)\n\n        loss = None\n        if labels is not None:\n            loss = self.loss_function(logits, labels, self.config)\n\n        if not return_dict:\n            output = (logits,) + outputs[2:]\n            return ((loss,) + output) if loss is not None else output\n\n        return TokenClassifierOutput(\n            loss=loss,\n            logits=logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\nThe Qwen3Moe Model transformer with a span classification head on top for extractive question-answering tasks like\nSQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).\n    \"\"\",\n    QWEN3_MOE_START_DOCSTRING,\n)\nclass Qwen3MoeForQuestionAnswering(Qwen3MoePreTrainedModel):\n    base_model_prefix = \"transformer\"\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.transformer = Qwen3MoeModel(config)\n        self.qa_outputs = 
nn.Linear(config.hidden_size, 2)\n\n        # Initialize weights and apply final processing\n        self.post_init()\n\n    def get_input_embeddings(self):\n        return self.transformer.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.transformer.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: Optional[torch.LongTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Cache] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        start_positions: Optional[torch.LongTensor] = None,\n        end_positions: Optional[torch.LongTensor] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        **kwargs,\n    ) -> Union[Tuple, QuestionAnsweringModelOutput]:\n        r\"\"\"\n        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for position (index) of the start of the labelled span for computing the token classification loss.\n            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence\n            are not taken into account for computing the loss.\n        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n            Labels for position (index) of the end of the labelled span for computing the token classification loss.\n            Positions are clamped to the length of the sequence (`sequence_length`). 
Position outside of the sequence\n            are not taken into account for computing the loss.\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        outputs = self.transformer(\n            input_ids,\n            attention_mask=attention_mask,\n            position_ids=position_ids,\n            past_key_values=past_key_values,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        sequence_output = outputs[0]\n\n        logits = self.qa_outputs(sequence_output)\n        start_logits, end_logits = logits.split(1, dim=-1)\n        start_logits = start_logits.squeeze(-1).contiguous()\n        end_logits = end_logits.squeeze(-1).contiguous()\n\n        loss = None\n        if start_positions is not None and end_positions is not None:\n            loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)\n\n        if not return_dict:\n            output = (start_logits, end_logits) + outputs[2:]\n            return ((loss,) + output) if loss is not None else output\n\n        return QuestionAnsweringModelOutput(\n            loss=loss,\n            start_logits=start_logits,\n            end_logits=end_logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\n__all__ = [\n    \"Qwen3MoeForCausalLM\",\n    \"Qwen3MoeForQuestionAnswering\",\n    \"Qwen3MoeModel\",\n    \"Qwen3MoePreTrainedModel\",\n    \"Qwen3MoeForSequenceClassification\",\n    \"Qwen3MoeForTokenClassification\",\n]"
  },
  {
    "path": "kt-sft/ktransformers/moe_test_module.py",
    "content": "import os\nimport platform\nimport sys\n\nproject_dir = os.path.dirname(os.path.dirname(__file__))\nsys.path.insert(0, project_dir)\n\nfrom torchviz import make_dot\nfrom torch import nn\nimport torch\nfrom transformers import (\n    AutoTokenizer,\n    AutoConfig,\n    AutoModelForCausalLM,\n    GenerationConfig,\n    TextStreamer,\n)\nimport unittest\nfrom torch.autograd import gradcheck\n\nfrom ktransformers.operators.linear import KLinearTorch, KTransformersLinear\nfrom ktransformers.sft.peft_utils.lora_layer import KTransformersLinearLora\nfrom ktransformers.util.custom_loader import GGUFLoader\nfrom ktransformers.operators.experts import KExpertsTorch\nfrom ktransformers.util.utils import load_weights\n\ngguf_loader = GGUFLoader(gguf_path=\"/home/yj/ktransformers/GGUF-DeepSeek-V2-Lite-Chat\")\nconfig = AutoConfig.from_pretrained(\"/home/yj/ktransformers/DeepSeek-V2-Lite-Chat\", trust_remote_code=True)\ntorch.set_default_dtype(config.torch_dtype)\n\nclass TestKExpertsTorch(unittest.TestCase):\n    def setUp(self):\n        torch.backends.cudnn.deterministic = True\n        torch.backends.cudnn.benchmark = False\n        self.num_experts = 8\n        \n        self.fixed_input = None\n        self.fixed_expert_ids = None\n        self.fixed_weights = None\n        \n    def _create_fixed_data(self, device, batch_size=2):\n        \"\"\"Create fixed input data.\"\"\"\n        if self.fixed_input is None:\n            with torch.random.fork_rng():\n                torch.manual_seed(42)\n                hidden_size = config.hidden_size\n                \n                self.fixed_input = torch.randn(batch_size, hidden_size)\n                \n                self.fixed_expert_ids = torch.tensor([[0, 1], [2, 3]], dtype=torch.long)\n                \n                self.fixed_weights = torch.tensor([[0.5, 0.5], [0.5, 0.5]], dtype=torch.float32)\n        \n        return (\n            self.fixed_input.clone().to(device).requires_grad_(True),\n            
self.fixed_expert_ids.clone().to(device),\n            self.fixed_weights.clone().to(device)\n        )\n\n    def _run_single_device_test(self, device, seed=42):\n        \"\"\"Run a forward and backward pass on the given device and return the gradients.\"\"\"\n        torch.manual_seed(seed)\n        if device == \"cuda\":\n            torch.cuda.manual_seed_all(seed)\n        \n        model = KExpertsTorch(\n            key=\"blk.1\",\n            gguf_loader=gguf_loader,\n            config=config,\n            n_routed_experts=self.num_experts,\n            device=device\n        )\n        model.load(device=device)\n        \n        input_tensor, expert_ids, weights = self._create_fixed_data(device)\n        \n        model.to(device)\n        \n        with torch.autocast(device_type=device, enabled=False):\n            output = model(input_tensor, expert_ids, weights)\n            \n        loss = output.sum()\n        loss.backward()\n        \n        gradients = {\n            \"input\": input_tensor.grad.detach().cpu(),\n            \"loss\": loss.detach().cpu(),\n            \"model\": [p.grad.detach().cpu() for p in model.parameters() if p.grad is not None]\n        }\n        return gradients\n\n    def test_forward_gradient(self):\n        cpu_gradients = self._run_single_device_test(\"cpu\")\n        \n        if torch.cuda.is_available():\n            gpu_gradients = self._run_single_device_test(\"cuda\")\n\n            print(f\"cpu_gradients:{cpu_gradients}\")\n            print(f\"gpu_gradients:{gpu_gradients}\")\n            \n            input_diff = torch.max(torch.abs(cpu_gradients[\"input\"] - gpu_gradients[\"input\"]))\n            print(f\"input_diff:{input_diff}\")\n            \n            for i, (cpu_g, gpu_g) in enumerate(zip(cpu_gradients[\"model\"], gpu_gradients[\"model\"])):\n                param_diff = torch.max(torch.abs(cpu_g - gpu_g))\n                print(f\"param_diff:{param_diff}\")\n\n            for i, (cpu_g, gpu_g) in enumerate(zip(cpu_gradients[\"model\"], 
gpu_gradients[\"model\"])):\n                diff = (cpu_g - gpu_g.cpu()).abs().max()\n                print(f\"Parameter gradient {i} max difference: {diff.item()}\")\n                self.assertTrue(torch.allclose(cpu_g, gpu_g, atol=1e-4, rtol=1e-3),\n                            f\"Parameter gradient {i} difference exceeds the threshold; max difference: {diff.item()}\")\n                \n        else:\n            self.skipTest(\"CUDA is unavailable; skipping the GPU test\")\n\nif __name__ == '__main__':\n    unittest.main()"
  },
  {
    "path": "kt-sft/ktransformers/moe_test_module_old.py",
    "content": "import os\nimport platform\nimport sys\n\nproject_dir = os.path.dirname(os.path.dirname(__file__))\nsys.path.insert(0, project_dir)\n\nfrom torchviz import make_dot\nfrom torch import nn\nimport torch\nfrom transformers import (\n    AutoTokenizer,\n    AutoConfig,\n    AutoModelForCausalLM,\n    GenerationConfig,\n    TextStreamer,\n)\nimport unittest\nfrom torch.autograd import gradcheck\n\nfrom ktransformers.operators.linear import KLinearTorch, KTransformersLinear\nfrom ktransformers.sft.peft_utils.lora_layer import KTransformersLinearLora\nfrom ktransformers.util.custom_loader import GGUFLoader\nfrom ktransformers.operators.experts import KExpertsTorch\nfrom ktransformers.util.utils import load_weights\n\ngguf_loader = GGUFLoader(gguf_path=\"/home/yj/ktransformers/GGUF-DeepSeek-V2-Lite-Chat\")\nconfig = AutoConfig.from_pretrained(\"/home/yj/ktransformers/DeepSeek-V2-Lite-Chat\", trust_remote_code=True)\ntorch.set_default_dtype(config.torch_dtype)\n\nclass TestKExpertsTorch(unittest.TestCase):\n    def setUp(self):\n        torch.backends.cudnn.deterministic = True\n        torch.backends.cudnn.benchmark = False\n        self.base_device = \"cpu\"\n        self.num_experts = 8\n        # model = KExpertsTorch(\n        #     key=\"blk.1\",\n        #     gguf_loader=gguf_loader,\n        #     config=config,\n        #     n_routed_experts=self.num_experts,\n        #     device=self.base_device\n        # )\n        # model.load()\n        \n    def _run_single_device_test(self, device, seed=42):\n        \"\"\"Run a forward and backward pass on the given device and return the gradients.\"\"\"\n        torch.manual_seed(seed)\n        if device == \"cuda\":\n            torch.cuda.manual_seed_all(seed)\n        \n        model = KExpertsTorch(\n            key=\"blk.1\",\n            gguf_loader=gguf_loader,\n            config=config,\n            n_routed_experts=self.num_experts,\n            device=device\n        )\n        model.load(device=device)\n\n        with torch.random.fork_rng():\n      
      torch.manual_seed(seed)\n            batch_size = 2\n            hidden_size = model.config.hidden_size\n            input_tensor = torch.randn(batch_size, hidden_size, device=device, requires_grad=True)\n            expert_ids = torch.randint(0, self.num_experts, \n                                    (batch_size, model.config.num_experts_per_tok), \n                                    device=device)\n            weights = torch.randn(batch_size, model.config.num_experts_per_tok, device=device)\n            weights = torch.softmax(weights, dim=-1)\n        \n        print(f\"input_tensor.device:{input_tensor.device}\")\n        print(f\"torch.device(device):{torch.device(device)}\")\n        # assert input_tensor.device == torch.device(device)\n        for p in model.parameters():\n            print(f\"p.device:{p.device}\")\n\n        for name, param in model.named_parameters():\n            print(name, param.size())\n\n        \n        model.to(device)\n        with torch.autocast(device_type=device, enabled=False):\n            output = model(input_tensor, expert_ids, weights)\n        \n        loss = output.sum()\n\n        \n        # dot = make_dot(output, params=dict(model.named_parameters()))\n        # dot.render(f\"origin_moe_{torch.device(device)}_graph\", format=\"svg\")\n\n        loss.backward()\n        \n        gradients = {\n            \"input\": input_tensor.grad.clone().cpu(),\n            \"loss\": loss.clone().cpu(),\n            \"model\": [p.grad.clone().cpu() for p in model.parameters() if p.grad is not None]\n        }\n        return gradients\n\n    def test_forward_gradient(self):\n        # for param in model.parameters():\n        #     self.assertEqual(param.dtype, config.torch_dtype)\n        \n        cpu_gradients = self._run_single_device_test(\"cpu\")\n        print(f\"cpu_gradients: {cpu_gradients}\")\n        \n        self.assertIsNotNone(cpu_gradients[\"input\"])\n        self.assertTrue(all(g is not None for g in 
cpu_gradients[\"model\"]))\n        \n        if torch.cuda.is_available():\n            gpu_gradients = self._run_single_device_test(\"cuda\")\n\n            print(f\"gpu_gradients: {gpu_gradients}\")\n\n            \n            max_diff = (cpu_gradients[\"input\"] - gpu_gradients[\"input\"].cpu()).abs().max()\n            print(f\"Input gradient max difference: {max_diff.item()}\")\n\n            self.assertTrue(torch.allclose(cpu_gradients[\"input\"], gpu_gradients[\"input\"], atol=1e-4, rtol=1e-3),\n                        f\"Input gradient difference exceeds the threshold; max difference: {max_diff.item()}\")\n\n            for i, (cpu_g, gpu_g) in enumerate(zip(cpu_gradients[\"model\"], gpu_gradients[\"model\"])):\n                diff = (cpu_g - gpu_g.cpu()).abs().max()\n                print(f\"Parameter gradient {i} max difference: {diff.item()}\")\n                self.assertTrue(torch.allclose(cpu_g, gpu_g, atol=1e-4, rtol=1e-3),\n                            f\"Parameter gradient {i} difference exceeds the threshold; max difference: {diff.item()}\")\n\n        else:\n            raise ImportError(\"NO CUDA FOR TEST!!\")\n\n    # def test_detach_effect(self):\n    #     input_tensor = torch.randn(1, model.config.hidden_size, device=\"cpu\", requires_grad=True)\n    #     expert_ids = torch.tensor([[0, 1]], device=\"cpu\")\n    #     weights = torch.tensor([[0.5, 0.5]], device=\"cpu\")\n\n    #     output = model(input_tensor, expert_ids, weights)\n        \n    #     # dot = make_dot(output, params=dict(model.named_parameters()))\n    #     # dot.render(\"origin_moe_cpu_graph\", format=\"svg\")\n        \n    #     loss = output.sum()\n    #     loss.backward()\n        \n    #     self.assertIsNotNone(input_tensor.grad)\n    #     self.assertTrue(all(p.grad is not None for p in model.parameters()))\n\nif __name__ == '__main__':\n    unittest.main()"
  },
  {
    "path": "kt-sft/ktransformers/operators/RoPE.py",
    "content": "\"\"\"\nDescription  :  \nAuthor       : Boxin Zhang\nVersion      : 0.1.0\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n\"\"\"\n\nfrom torch import nn\nfrom transformers import ROPE_INIT_FUNCTIONS\nfrom ktransformers.models.modeling_llama import (\n    LlamaRotaryEmbedding,\n    LlamaLinearScalingRotaryEmbedding,\n    LlamaDynamicNTKScalingRotaryEmbedding,\n)\nfrom ktransformers.models.modeling_deepseek_v3 import (\n    DeepseekV3RotaryEmbedding\n)\nfrom ktransformers.models.modeling_deepseek import (\n    DeepseekV2YarnRotaryEmbedding,\n    DeepseekV2RotaryEmbedding,\n    yarn_get_mscale,\n    yarn_linear_ramp_mask,\n    yarn_find_correction_range\n)\nfrom ktransformers.operators.base_operator import BaseInjectedModule\nfrom ktransformers.util.custom_loader import GGUFLoader\nfrom ktransformers.util.inference_state import InferenceState\nfrom ktransformers.util.grad_wrapper import maybe_no_grad\nfrom transformers.configuration_utils import PretrainedConfig\nimport torch\n\n# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe\nclass RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module,\n        #  device: str = \"cuda\",\n        generate_device: str = \"cuda\",\n        prefill_device: str = \"cuda\",\n        **kwargs,\n    ):\n        BaseInjectedModule.__init__(\n            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs\n        )\n        self.orig_module.__init__(\n            orig_module.dim, orig_module.max_position_embeddings, orig_module.base\n        )\n        self.generate_device = generate_device\n        self.prefill_device = prefill_device\n\n    def load(self):\n        self.orig_module.__init__(\n            self.orig_module.dim,\n            
self.orig_module.max_position_embeddings,\n            self.orig_module.base,\n            self.device,\n        )\n\n\nclass RotaryEmbeddingV3(BaseInjectedModule):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module,\n        #  device: str = \"cuda\",\n        generate_device: str = \"cuda\",\n        prefill_device: str = \"cuda\",\n        **kwargs,\n    ):\n        BaseInjectedModule.__init__(\n            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs\n        )\n        self.generate_device = generate_device\n        self.prefill_device = prefill_device\n    \n    @maybe_no_grad()\n    def forward(self, x, position_ids):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)\n        position_ids_expanded = position_ids[:, None, :].float()\n        # Force float32 since bfloat16 loses precision on long contexts\n        # See https://github.com/huggingface/transformers/pull/29285\n        device_type = x.device.type\n        device_type = device_type if isinstance(device_type, str) and device_type != \"mps\" else \"cpu\"\n        with torch.autocast(device_type=device_type, enabled=False):\n            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos()\n            sin = emb.sin()\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)   \n\n    def load(self):\n        self._init(\n            dim=self.config.qk_rope_head_dim,\n            max_position_embeddings=self.config.max_position_embeddings,\n            base=self.config.rope_theta,\n            device=self.device,\n        )\n    def _init(self, dim, max_position_embeddings, base, device, scaling_factor=1.0):\n        
self.scaling_factor = scaling_factor\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))\n        # self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        # For BC we register cos and sin cached\n        self.max_seq_len_cached = max_position_embeddings\n\nclass RotaryEmbeddingV2(BaseInjectedModule, LlamaRotaryEmbedding):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module,\n        generate_device: str = \"cuda\",\n        prefill_device: str = \"cuda\",\n        **kwargs,\n    ):\n        BaseInjectedModule.__init__(\n            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs\n        )\n        self.orig_module.__init__(\n            orig_module.dim,\n            orig_module.max_position_embeddings,\n            orig_module.base,\n            None,\n            orig_module.scaling_factor,\n            orig_module.rope_type,\n            orig_module.config,\n        )\n        self.generate_device = generate_device\n        self.prefill_device = prefill_device\n\n    def load(self):\n        self.orig_module.__init__(\n            self.orig_module.dim,\n            self.orig_module.max_position_embeddings,\n            self.orig_module.base,\n            self.device,\n            self.orig_module.scaling_factor,\n            self.orig_module.rope_type,\n            self.orig_module.config,\n        )\n\nclass YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module,\n        #  device: str = \"cuda\",\n        generate_device: str = \"cuda\",\n        
prefill_device: str = \"cuda\",\n        **kwargs,\n    ):\n        BaseInjectedModule.__init__(\n            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs\n        )\n        self.orig_module.__init__(\n            orig_module.dim,\n            orig_module.max_position_embeddings,\n            orig_module.base,\n            None,  # device\n            orig_module.scaling_factor,\n            orig_module.original_max_position_embeddings,\n            orig_module.beta_fast,\n            orig_module.beta_slow,\n            orig_module.mscale,\n            orig_module.mscale_all_dim,\n        )\n        self.generate_device = generate_device\n        self.prefill_device = prefill_device\n\n    def load(self):\n        self.orig_module.__init__(\n            self.orig_module.dim,\n            self.orig_module.max_position_embeddings,\n            self.orig_module.base,\n            self.generate_device,\n            self.orig_module.scaling_factor,\n            self.orig_module.original_max_position_embeddings,\n            self.orig_module.beta_fast,\n            self.orig_module.beta_slow,\n            self.orig_module.mscale,\n            self.orig_module.mscale_all_dim,\n        )\n\n# class DeepSeekV3YarnRotaryEmbedding(BaseInjectedModule, DeepseekV3RotaryEmbedding):\n#     def __init__(\n#         self,\n#         key: str,\n#         gguf_loader: GGUFLoader,\n#         config: PretrainedConfig,\n#         orig_module: nn.Module,\n#         #  device: str = \"cuda\",\n#         generate_device: str = \"cuda\",\n#         prefill_device: str = \"cuda\",\n#         **kwargs,\n#     ):\n#         BaseInjectedModule.__init__(\n#             self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs\n#         )\n#         self.generate_device = generate_device\n#         self.prefill_device = prefill_device\n\n#     def load(self):\n#         # TODO support perlayer prefill\n#         
self.orig_module.__init__(\n#             self.config,\n#             device=self.generate_device\n#         )\n#         return\n\nclass YarnRotaryEmbeddingV3(BaseInjectedModule):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module,\n        #  device: str = \"cuda\",\n        generate_device: str = \"cuda\",\n        prefill_device: str = \"cuda\",\n        **kwargs,\n    ):\n        BaseInjectedModule.__init__(\n            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs\n        )\n        self.generate_device = generate_device\n        self.prefill_device = prefill_device\n    \n    def load(self):\n        kwargs = {\n            key: self.config.rope_scaling[key]\n            for key in [\n                \"original_max_position_embeddings\",\n                \"beta_fast\",\n                \"beta_slow\",\n                \"mscale\",\n                \"mscale_all_dim\",\n            ]\n            if key in self.config.rope_scaling\n        }\n        self._init(\n            dim=self.config.qk_rope_head_dim,\n            max_position_embeddings=self.config.max_position_embeddings,\n            base=self.config.rope_theta,\n            device=self.device,\n            scaling_factor=self.config.rope_scaling[\"factor\"],\n            **kwargs,\n        )\n\n    @maybe_no_grad()\n    def forward(self, x, position_ids):\n        # x: [bs, num_attention_heads, seq_len, head_size]\n        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)\n        position_ids_expanded = position_ids[:, None, :].float()\n        # Force float32 since bfloat16 loses precision on long contexts\n        # See https://github.com/huggingface/transformers/pull/29285\n        device_type = x.device.type\n        device_type = device_type if isinstance(device_type, str) and device_type != \"mps\" else 
\"cpu\"\n        with torch.autocast(device_type=device_type, enabled=False):\n            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos()* self._mscale\n            sin = emb.sin()* self._mscale\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)  \n\n    def _init(\n        self,\n        dim,\n        max_position_embeddings=2048,\n        base=10000,\n        device=None,\n        scaling_factor=1.0,\n        original_max_position_embeddings=4096,\n        beta_fast=32,\n        beta_slow=1,\n        mscale=1,\n        mscale_all_dim=0,\n    ):\n        self.original_max_position_embeddings = original_max_position_embeddings\n        self.beta_fast = beta_fast\n        self.beta_slow = beta_slow\n        self.mscale = mscale\n        self.mscale_all_dim = mscale_all_dim\n        self.scaling_factor = scaling_factor\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n\n        freq_extra = 1.0 / (\n            self.base\n            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)\n        )\n        freq_inter = 1.0 / (\n            self.scaling_factor\n            * self.base\n            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)\n        )\n\n        low, high = yarn_find_correction_range(\n            self.beta_fast,\n            self.beta_slow,\n            dim,\n            self.base,\n            self.original_max_position_embeddings,\n        )\n        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(\n            device=device, dtype=torch.float32\n        )\n        self.inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask\n        self._mscale = float(\n            yarn_get_mscale(self.scaling_factor, self.mscale)\n            / yarn_get_mscale(self.scaling_factor, 
self.mscale_all_dim)\n        )\n        # For BC we register cos and sin cached\n        self.max_seq_len_cached = max_position_embeddings\n\nclass DynamicNTKScalingRotaryEmbedding(\n    BaseInjectedModule, LlamaDynamicNTKScalingRotaryEmbedding\n):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module,\n        prefill_device: str = \"cuda\",\n        generate_device: str = \"cuda\",\n        **kwargs,\n    ):\n        BaseInjectedModule.__init__(\n            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs\n        )\n        self.orig_module.__init__(\n            orig_module.dim,\n            orig_module.max_position_embeddings,\n            orig_module.base,\n            None,  # device\n            orig_module.scaling_factor,\n            orig_module.rope_type,\n            orig_module.config,\n        )\n\n    def load(self):\n        self.orig_module.__init__(\n            self.orig_module.dim,\n            self.orig_module.max_position_embeddings,\n            self.orig_module.base,\n            self.orig_module.device,\n            self.orig_module.scaling_factor,\n            self.orig_module.rope_type,\n            self.orig_module.config,\n        )\n\n\n\nclass RotaryEmbeddingV4(BaseInjectedModule):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module,\n        #  device: str = \"cuda\",\n        generate_device: str = \"cuda\",\n        prefill_device: str = \"cuda\",\n        **kwargs,\n    ):\n        BaseInjectedModule.__init__(\n            self, key, gguf_loader, config, orig_module, generate_device, **kwargs\n        )\n        self.generate_device = generate_device\n        self.prefill_device = prefill_device\n    \n    @maybe_no_grad()\n    def forward(self, x, position_ids):\n        # x: [bs, 
num_attention_heads, seq_len, head_size]\n        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)\n        position_ids_expanded = position_ids[:, None, :].float()\n        # Force float32 since bfloat16 loses precision on long contexts\n        # See https://github.com/huggingface/transformers/pull/29285\n        device_type = x.device.type\n        device_type = device_type if isinstance(device_type, str) and device_type != \"mps\" else \"cpu\"\n        with torch.autocast(device_type=device_type, enabled=False):\n            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)\n            emb = torch.cat((freqs, freqs), dim=-1)\n            cos = emb.cos()\n            sin = emb.sin()\n        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)   \n\n    def load(self):\n        self._init(\n            dim=self.config.qk_rope_head_dim,\n            max_position_embeddings=self.config.max_position_embeddings,\n            base=self.config.rope_theta,\n            device=self.device,\n        )\n    def _init(self, dim, max_position_embeddings, base, device, scaling_factor=1.0):\n        self.scaling_factor = scaling_factor\n        self.dim = dim\n        self.max_position_embeddings = max_position_embeddings\n        self.base = base\n        self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))\n        # self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        # For BC we register cos and sin cached\n        self.max_seq_len_cached = max_position_embeddings\n\nclass KQwen3MoeRotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module,\n        #  device: str = \"cuda\",\n        generate_device: str = \"cuda\",\n        prefill_device: str = 
\"cuda\",\n        **kwargs,\n    ):\n        BaseInjectedModule.__init__(\n            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs\n        )\n        self.orig_module.__init__(\n            config,\n        )\n        self.generate_device = generate_device\n        self.prefill_device = prefill_device\n\n    def load(self):\n        self.orig_module.__init__(\n            self.orig_module.config\n        )"
  },
  {
    "path": "kt-sft/ktransformers/operators/__init__.py",
    "content": "\n"
  },
  {
    "path": "kt-sft/ktransformers/operators/attention.py",
    "content": "'''\nDescription  :  \nAuthor       : Boxin Zhang\nVersion      : 0.1.0\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport torch\nfrom torch import nn\nimport warnings\nimport torch.nn.functional as F\nfrom ktransformers.operators.models import KLlamaModel\nfrom ktransformers.models.configuration_deepseek import DeepseekV2Config\nfrom ktransformers.models.configuration_llama import LlamaConfig\nfrom ktransformers.models.modeling_llama import LlamaRotaryEmbedding\nfrom ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb\nfrom ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention, Qwen3MoeRotaryEmbedding\nfrom typing import Optional, Tuple\nfrom ktransformers.operators.base_operator import BaseInjectedModule\nfrom ktransformers.util.custom_loader import GGUFLoader\nfrom ktransformers.util.utils import get_compute_capability\nimport logging\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.cache_utils import Cache\nfrom transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS\nfrom ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor\n\ntry:\n    from flash_attn import flash_attn_func\nexcept:\n    pass\nfrom ktransformers.operators.triton_attention import decode_attention_fwd_grouped \nfrom ktransformers.operators.triton_attention_prefill import context_attention_fwd\nimport os\nfrom ktransformers.operators.flashinfer_wrapper import flashinfer_enabled\nif flashinfer_enabled:\n    from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton\n    from flashinfer.mla import BatchMLAPagedAttentionWrapper\nfrom ktransformers.models.custom_cache import KDeepSeekV3Cache\nlogger = logging.getLogger(\"attention\")\n\n# Copied from transformers.models.llama.modeling_llama.rotate_half\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., 
x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\n# V3 MLA is same to V2\nclass KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n    attn_mask: Optional[torch.Tensor] = None\n\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 chunck_size: int = 1000,\n                 absorb_for_prefill: bool = False,\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)\n        self.orig_module.__init__(orig_module.config,\n            orig_module.layer_idx)\n        self.chunck_size = chunck_size # TODO, generate chunck_size automatically.\n        self.mla_wrapper = None\n        self.absorb_for_prefill = absorb_for_prefill\n\n    def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:\n        if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):\n            kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)\n            self.q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)\n            self.out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].view(self.num_heads, self.v_head_dim, self.kv_lora_rank)\n            \n        return self.q_absorb, self.out_absorb\n\n    def forward_chunck(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: 
Optional[torch.LongTensor] = None,\n        **kwargs\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        bsz, q_len, _ = hidden_states.size()\n        if self.q_lora_rank is None:\n            q = self.q_proj(hidden_states)\n        else:\n            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))\n        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)\n        q_nope, q_pe = torch.split(\n            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1\n        )\n        # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]\n        # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]\n\n        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)\n        compressed_kv, k_pe = torch.split(\n            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n        )\n        compressed_kv = self.kv_a_layernorm(compressed_kv)\n        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)\n\n        kv_seq_len = k_pe.shape[-2]\n        if past_key_value is not None:\n            if self.layer_idx is None:\n                raise ValueError(\n                    f\"The cache structure has changed since transformer version v4.36. 
If you are using {self.__class__.__name__} \"\n                    \"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class \"\n                    \"with a layer index.\"\n                )\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\n\n        cos, sin = self.rotary_emb(q_pe, position_ids)\n        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)\n\n        if past_key_value is not None:\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n            \n            # compressed_kv [bsz, q_len, self.kv_lora_rank]\n            # k_pe [bsz, 1, q_len, self.qk_rope_head_dim]\n            k_pe = k_pe.transpose(1,2)\n            compressed_kv = compressed_kv.unsqueeze(2)\n            compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)\n            compressed_kv, k_pe = torch.split(\n                compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n            )\n            # k_pe [pages, page_size, 1, self.qk_rope_head_dim]\n            # compressed_kv [pages, page_size, 1, self.kv_lora_rank]\n            \n        q_absorb, out_absorb = self.get_absorbed()\n\n        # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]\n        # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]\n        k_pe = k_pe.view(bsz, 1, -1, self.qk_rope_head_dim)[:,:,:attention_mask.size(-1),:]\n        compressed_kv = compressed_kv.view(bsz, 1, -1, self.kv_lora_rank)[:,:,:attention_mask.size(-1),:]\n        # k_pe [bsz, 1, cache_len, self.qk_rope_head_dim]\n        # compressed_kv [bsz, 1, cache_len,self.kv_lora_rank]\n        q_nope = torch.matmul(q_nope, q_absorb)\n        #print(q_pe.shape)\n        #print(k_pe.shape)\n        #print(q_nope.shape)\n        #print(compressed_kv.shape)\n        \n        attn_weights = (torch.matmul(q_pe, 
k_pe.mT) + torch.matmul(q_nope, compressed_kv.mT)) * self.softmax_scale\n        \n        #attn_weights [bsz, self.num_heads, q_len, kv_seq_len]\n        compressed_kv = compressed_kv.squeeze(1)\n        \"\"\"\n        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):\n            raise ValueError(\n                f\"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is\"\n                f\" {attn_weights.size()}\"\n            )\n        assert attention_mask is not None\n        \"\"\"\n        if attention_mask is not None:\n            \"\"\"\n            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):\n                raise ValueError(\n                    f\"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}\"\n                )\n            \"\"\"\n            #causal_mask = attention_mask[:, :, :, : kv_seq_len]\n            attn_weights = attn_weights + attention_mask\n\n        # upcast attention to fp32\n        attn_weights = nn.functional.softmax(\n            attn_weights, dim=-1, dtype=torch.float32\n        ).to(q_pe.dtype)\n        attn_weights = nn.functional.dropout(\n            attn_weights, p=self.attention_dropout, training=self.training\n        )\n        \n        attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)\n        \n        attn_output = torch.matmul(attn_output, out_absorb.mT) \n\n        if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        \n        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)\n\n        attn_output = self.o_proj(attn_output)\n\n        return attn_output, None, 
past_key_value\n\n    def forward_linux_triton(\n            self,\n            hidden_states: torch.Tensor,\n            attention_mask: Optional[torch.Tensor] = None,\n            position_ids: Optional[torch.LongTensor] = None,\n            past_key_value: Optional[Cache] = None,\n            output_attentions: bool = False,\n            use_cache: bool = False,\n            cache_position: Optional[torch.LongTensor] = None,\n            **kwargs,\n        ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n\n        bsz, q_len, _ = hidden_states.size()\n\n        if self.q_lora_rank is None:\n            q = self.q_proj(hidden_states)\n        else:\n            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))\n        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim)\n        q_nope, q_pe = torch.split(\n            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1\n        )\n\n        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)\n        compressed_kv, k_pe = torch.split(\n            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n        )\n        compressed_kv = self.kv_a_layernorm(compressed_kv)\n        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)\n        compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)\n\n        kv_seq_len = q_len\n        if past_key_value is not None:\n            if self.layer_idx is None:\n                raise ValueError(\n                    f\"The cache structure has changed since transformer version v4.36. 
If you are using {self.__class__.__name__} \"\n                    \"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class \"\n                    \"with a layer index.\"\n                )\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\n        \n        cos, sin = self.rotary_emb(q_pe, position_ids)\n        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)\n        # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]\n        \n        # decode\n        if q_len == 1:\n            if past_key_value is not None:\n                cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n                compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)\n                compressed_kv = compressed_kv_with_k_pe [:, :, :, :self.kv_lora_rank] # for speed\n                # compressed_kv_with_k_pe [bsz, q_len, 1, self.kv_lora_rank + self.qk_rope_head_dim]\n                # compressed_kv [bsz, q_len, 1, self.kv_lora_rank]\n\n            # q_nope [bsz, q_len, self.num_heads, self.qk_nope_head_dim]\n            # q_absorb [self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank]\n            q_absorb, out_absorb = self.get_absorbed()\n            q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below\n            q_nope = torch.matmul(q_nope, q_absorb) # batched MM\n            q_nope = q_nope.transpose(1, 2)\n            #assert q_nope.is_contiguous()\n            \n            # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]\n            # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]\n            query_states = torch.cat([q_nope, q_pe], dim=-1)\n            \n            query_states = query_states.squeeze(1)\n            attn_output = torch.zeros_like(q_nope) # [bsz, 
q_len, self.num_heads, self.kv_lora_rank]\n            \n            attn_logits = torch.empty(\n                    (\n                        bsz,\n                        self.num_heads,\n                        4, #num_kv_splits # follow vLLM, fix it TODO\n                        self.kv_lora_rank + 1, \n                    ),\n                    dtype=torch.float32,\n                    device = attn_output.device\n                )\n\n            \"\"\"\n            print(\"query_states\", torch.isnan(query_states).any())\n            print(\"compressed_kv_with_k_pe\", torch.isnan(compressed_kv_with_k_pe[:,:,0,:]).any())\n            print(\"compressed_kv\", torch.isnan(compressed_kv[:,:,0,:]).any())\n            print(\"position_ids\", torch.isnan(position_ids).any())\n            \"\"\"\n\n            # flash attn doesn't support head_dim bigger than 256\n            # use triton attention kernel adapted from vLLM and SGLang for MQA\n            decode_attention_fwd_grouped(query_states, compressed_kv_with_k_pe, compressed_kv, attn_output,\n                             page_table,\n                             position_ids.squeeze(0).to(torch.int32)+1, attn_logits,\n                             4, #num_kv_splits # follow vLLM, fix it TODO\n                             self.softmax_scale,\n                             past_key_value.page_size)\n            \n            # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]\n            # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]\n            attn_output = attn_output.transpose(1, 2)\n            attn_output = torch.matmul(attn_output, out_absorb.mT)\n            attn_output = attn_output.transpose(1, 2)\n            \n            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)\n            attn_output = self.o_proj(attn_output)\n            \n            #print(\"attn_output\", torch.isnan(attn_output).any())\n            return attn_output, 
None, past_key_value\n        else:\n            if past_key_value is not None:\n                cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n                k_pe.squeeze(0)\n                compressed_kv.squeeze(0)\n                compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)\n                compressed_kv, k_pe = torch.split(\n                    compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n                )\n            k_pe = k_pe.view(bsz, -1, self.qk_rope_head_dim)\n            k_pe = k_pe[:, :kv_seq_len]\n            compressed_kv = compressed_kv.view(bsz, -1, self.kv_lora_rank)\n            compressed_kv = compressed_kv[:, :kv_seq_len]\n            kv = (\n                self.kv_b_proj(compressed_kv)\n                .view(bsz, kv_seq_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)\n            )\n            k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)\n            query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)\n            query_states[:, :, :, : self.qk_nope_head_dim] = q_nope\n            query_states[:, :, :, self.qk_nope_head_dim :] = q_pe\n\n            key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)\n            key_states[:, :, :, :self.qk_nope_head_dim] = k_nope\n            key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)\n            \n            value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)\n            value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)\n\n            attn_output = flash_attn_func(\n                query_states,\n                key_states,\n                value_states_padded,\n                
softmax_scale=self.softmax_scale,\n                causal=True,\n            )\n\n            if self.q_head_dim != self.v_head_dim:\n                attn_output = attn_output[:, :, :, : self.v_head_dim]\n\n            attn_output = attn_output.reshape(\n                bsz, q_len, self.num_heads * self.v_head_dim\n            ).contiguous()\n            attn_output = self.o_proj(attn_output)\n            return attn_output, None, past_key_value\n\n    def forward_linux_flashinfer(\n            self,\n            hidden_states: torch.Tensor,\n            attention_mask: Optional[torch.Tensor] = None,\n            position_ids: Optional[torch.Tensor] = None,\n            past_key_value: Optional[Cache] = None,\n            output_attentions: bool = False,\n            use_cache: bool = False,\n            cache_position: Optional[torch.Tensor] = None,\n            **kwargs,\n        ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n\n        bsz, q_len, _ = hidden_states.size()\n\n        if self.q_lora_rank is None:\n            q = self.q_proj(hidden_states)\n        else:\n            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))\n        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim)\n        q_nope, q_pe = torch.split(\n            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1\n        )\n\n        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)\n        compressed_kv, k_pe = torch.split(\n            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n        )\n        compressed_kv = self.kv_a_layernorm(compressed_kv)\n        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)\n        compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)\n\n        kv_seq_len = q_len\n        if past_key_value is not None:\n            if self.layer_idx is None:\n                raise ValueError(\n                    f\"The cache structure has changed since 
transformer version v4.36. If you are using {self.__class__.__name__} \"\n                    \"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class \"\n                    \"with a layer index.\"\n                )\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\n        \n        cos, sin = self.rotary_emb(q_pe, position_ids)\n        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)\n        # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]\n        \n        # decode, or prefill when absorb_for_prefill is set\n        if q_len == 1 or self.absorb_for_prefill:\n            if past_key_value is not None:\n                cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n                compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)\n                compressed_kv = compressed_kv_with_k_pe [:, :, :, :self.kv_lora_rank].view(-1, past_key_value.page_size, self.kv_lora_rank)\n                k_pe = compressed_kv_with_k_pe [:, :, :, self.kv_lora_rank:].view(-1, past_key_value.page_size, self.qk_rope_head_dim)\n                # k_pe [max_pages, page_size, self.qk_rope_head_dim]\n                # compressed_kv [max_pages, page_size, self.kv_lora_rank]\n\n            # q_nope [bsz, q_len, self.num_heads, self.qk_nope_head_dim]\n            # q_absorb [self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank]\n            q_absorb, out_absorb = self.get_absorbed()\n            q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below\n            q_nope = torch.matmul(q_nope, q_absorb) # batched MM\n            q_nope = q_nope.transpose(1, 2)\n            q_nope = q_nope.contiguous()\n            #assert q_nope.is_contiguous()\n            \n            # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]\n 
           # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]\n            q_nope.squeeze_(0)\n            q_pe.squeeze_(0)\n\n            # flash attn doesn't support head_dim bigger than 256, use flashinfer\n            if self.mla_wrapper is None:\n                self.mla_wrapper = MLAWrapperSingleton.get_instance(self.device, 1, past_key_value.max_pages, use_cuda_graph = True)\n            if self.mla_wrapper.need_plan:\n                self.mla_wrapper.need_plan = False\n                if q_len == 1:\n                    self.mla_wrapper.plan(None,None,None,\n                                        position_ids.squeeze(1)+1,\n                                        None,\n                                        self.num_heads,\n                                        self.kv_lora_rank,\n                                        self.qk_rope_head_dim,\n                                        past_key_value.page_size,\n                                        self.softmax_scale,\n                                        q_nope.dtype,\n                                        compressed_kv.dtype)\n                else:\n                    qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device=self.device)\n                    kv_len_arr = torch.tensor([position_ids[0, -1].item()+1], dtype=torch.int32, device=self.device)\n                    self.mla_wrapper.plan(qo_indptr,None,None,\n                                        kv_len_arr,\n                                        None,\n                                        self.num_heads,\n                                        self.kv_lora_rank,\n                                        self.qk_rope_head_dim,\n                                        past_key_value.page_size,\n                                        self.softmax_scale,\n                                        q_nope.dtype,\n                                        compressed_kv.dtype)\n            attn_output = 
self.mla_wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(bsz, q_len, self.num_heads, self.kv_lora_rank)\n            \"\"\"\n            k = (\n                torch.cat([compressed_kv, k_pe], dim=-1)\n                .view(-1, 1, 512 + 64)\n                .repeat_interleave(self.num_heads, dim=1)\n            )\n            v = compressed_kv.view(-1, 1, 512).repeat_interleave(self.num_heads, dim=1)\n            lens = position_ids.item() + 1\n            #print(\"lens\", lens)\n            attn_ref, lse_ref = attention_ref(\n                1,\n                torch.cat([q_nope, q_pe], dim=-1),\n                k[:lens],\n                v[:lens],\n                False,\n                self.softmax_scale\n            )\n            attn_output = attn_ref.view(bsz, q_len, self.num_heads, self.kv_lora_rank)\n            \"\"\"\n            \n            # mla_wrapper run output: [tokens, self.num_heads, self.kv_lora_rank]\n            # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]\n            # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]\n            attn_output = attn_output.transpose(1, 2) # [bsz, self.num_heads, q_len, self.kv_lora_rank]\n            attn_output = torch.matmul(attn_output, out_absorb.mT) # [bsz, self.num_heads, q_len, self.v_head_dim]\n            attn_output = attn_output.transpose(1, 2).contiguous() # [bsz, q_len, self.num_heads, self.v_head_dim]\n            \n            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) # [bsz, q_len, self.num_heads * self.v_head_dim]\n            attn_output = self.o_proj(attn_output)\n            \n            return attn_output, None, past_key_value\n        else:\n            if past_key_value is not None:\n                cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}  # Specific to RoPE models\n                k_pe.squeeze(0)\n                compressed_kv.squeeze(0)\n                
compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)\n                compressed_kv, k_pe = torch.split(\n                    compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n                )\n            k_pe = k_pe.view(bsz, -1, self.qk_rope_head_dim)\n            k_pe = k_pe[:, :kv_seq_len]\n            compressed_kv = compressed_kv.view(bsz, -1, self.kv_lora_rank)\n            compressed_kv = compressed_kv[:, :kv_seq_len]\n            kv = (\n                self.kv_b_proj(compressed_kv)\n                .view(bsz, kv_seq_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)\n            )\n            k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)\n            query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)\n            query_states[:, :, :, : self.qk_nope_head_dim] = q_nope\n            query_states[:, :, :, self.qk_nope_head_dim :] = q_pe\n\n            key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)\n            key_states[:, :, :, :self.qk_nope_head_dim] = k_nope\n            key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)\n            \n            value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)\n            value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)\n\n            attn_output = flash_attn_func(\n                query_states,\n                key_states,\n                value_states_padded,\n                softmax_scale=self.softmax_scale,\n                causal=True,\n            )\n\n            if self.q_head_dim != self.v_head_dim:\n                attn_output = attn_output[:, :, :, : self.v_head_dim]\n\n            attn_output = attn_output.reshape(\n                bsz, q_len, self.num_heads * self.v_head_dim\n           
 ).contiguous()\n            attn_output = self.o_proj(attn_output)\n            return attn_output, None, past_key_value\n        \n    def forward_windows(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        if \"padding_mask\" in kwargs:\n            warnings.warn(\n                \"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead.\"\n            )\n        bsz, q_len, _ = hidden_states.size()\n\n        if q_len <= self.chunck_size:\n            return self.forward_chunck(\n                            hidden_states,\n                            attention_mask,\n                            position_ids,\n                            past_key_value,\n                            output_attentions,\n                            use_cache,\n                            cache_position,\n                            **kwargs\n                        )\n\n        assert output_attentions == False, \"output_attentions is not supported when using chunked attention\"\n        attn_output = None\n        cur_idx = 0\n        while cur_idx < q_len:\n            if attention_mask is not None:\n                chunk_mask = attention_mask[:, :, cur_idx:min(cur_idx + self.chunck_size, q_len), ...]\n            else:\n                # generate chunk_mask automatically.\n                self.attn_mask = \\\n                    torch.zeros(1, 1, self.chunck_size, past_key_value.max_cache_len, device=hidden_states.device) \\\n                        if self.attn_mask is None \\\n                            else 
self.attn_mask\n                self.attn_mask[:, :, :, cur_idx:min(cur_idx+self.chunck_size, past_key_value.max_cache_len)] = \\\n                    -1e+38 * torch.triu(torch.ones(self.chunck_size, self.chunck_size, device=hidden_states.device), diagonal=1)\\\n                        [:,:min(self.chunck_size, min(past_key_value.max_cache_len-cur_idx, self.chunck_size))]\n                self.attn_mask[:, :, :, cur_idx+self.chunck_size:] = -1e+38\n                self.attn_mask[:, :, :, :cur_idx] = 0\n                chunk_mask = torch.narrow(self.attn_mask, 2, 0, min(self.chunck_size, q_len-cur_idx))\n\n            cur_output, _, _ = self.forward_chunck(\n                            hidden_states[:, cur_idx:min(cur_idx + self.chunck_size, q_len), ...],\n                            chunk_mask,\n                            position_ids[:, cur_idx:min(cur_idx + self.chunck_size, q_len)],\n                            past_key_value,\n                            output_attentions,\n                            use_cache,\n                            cache_position[cur_idx:min(cur_idx + self.chunck_size, q_len)],\n                            **kwargs\n                        )\n            cur_idx += self.chunck_size\n            if attn_output is None:\n                attn_output = cur_output\n            else:\n                attn_output = torch.cat((attn_output, cur_output), dim=-2)\n                \n        return attn_output, None, past_key_value\n\n    def forward_xpu(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        if \"padding_mask\" in kwargs:\n    
        warnings.warn(\n                \"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead.\"\n            )\n        bsz, q_len, _ = hidden_states.size()\n\n        if self.q_lora_rank is None:\n            q = self.q_proj(hidden_states)\n        else:\n            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))\n        query_states = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)\n\n        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)\n        compressed_kv, k_pe = torch.split(\n            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n        )\n        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)\n        kv = (\n            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))\n            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)\n            .transpose(1, 2)\n        )\n\n        k_nope, value_states = torch.split(\n            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1\n        )\n        kv_seq_len = value_states.shape[-2]\n        if past_key_value is not None:\n            if self.layer_idx is None:\n                raise ValueError(\n                    f\"The cache structure has changed since version v4.36. 
If you are using {self.__class__.__name__} \"\n                    \"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class \"\n                    \"with a layer index.\"\n                )\n            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)\n\n        position_embeddings = kwargs.get(\"position_embeddings\", None)\n        if position_embeddings is not None:\n            cos, sin = position_embeddings\n            key_states = torch.cat(\n                [k_nope, k_pe.expand([-1, self.num_heads, -1, -1])],\n                dim=-1\n            )\n            from ipex_llm.transformers.models.common import rotary_two_with_cache_inplaced\n            rotary_two_with_cache_inplaced(query_states[:, :, :, self.qk_nope_head_dim :],\n                                           key_states[:, :, :, self.qk_nope_head_dim:],\n                                           cos, sin, True)\n        else:\n            q_nope, q_pe = torch.split(\n                query_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1\n            )\n            cos, sin = self.rotary_emb(q_pe, position_ids)\n            q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)\n            query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)\n            query_states[:, :, :, : self.qk_nope_head_dim] = q_nope\n            query_states[:, :, :, self.qk_nope_head_dim :] = q_pe\n\n            key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)\n            key_states[:, :, :, : self.qk_nope_head_dim] = k_nope\n            key_states[:, :, :, self.qk_nope_head_dim :] = k_pe\n\n        if past_key_value is not None:\n            cache_kwargs = {\"sin\": sin, \"cos\": cos}  # Specific to RoPE models\n            key_states, value_states = past_key_value.update(\n                key_states.half(), value_states.half(), self.layer_idx, cache_kwargs\n            )\n\n        
attn_weights = None\n        from ipex_llm.transformers.models.common import scaled_dot_product_attention\n        attn_output = scaled_dot_product_attention(\n            query_states.half(), key_states, value_states,\n            attention_mask.half(), q_len == kv_seq_len, self.softmax_scale\n        )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)\n        attn_output = self.o_proj(attn_output).to(hidden_states.dtype)\n\n        if not output_attentions:\n            attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        if torch.xpu.is_available():\n            return self.forward_xpu(\n                hidden_states,\n                attention_mask,\n                position_ids,\n                past_key_value,\n                output_attentions,\n                use_cache,\n                cache_position,\n                **kwargs,\n            )\n        elif (os.name == 'nt'\n              or get_compute_capability() < 8\n              or hidden_states.device.type == 'cpu'\n              or device_manager.gpu_vendor != GPUVendor.NVIDIA):\n            return self.forward_windows(\n                hidden_states,\n                attention_mask,\n                position_ids,\n                past_key_value,\n                output_attentions,\n                use_cache,\n                cache_position,\n                **kwargs,\n            )\n        
else:\n            if flashinfer_enabled:\n                return self.forward_linux_flashinfer(\n                    hidden_states,\n                    attention_mask,\n                    position_ids,\n                    past_key_value,\n                    output_attentions,\n                    use_cache,\n                    cache_position,\n                    **kwargs,\n                )\n            else:\n                return self.forward_linux_triton(\n                    hidden_states,\n                    attention_mask,\n                    position_ids,\n                    past_key_value,\n                    output_attentions,\n                    use_cache,\n                    cache_position,\n                    **kwargs,\n                )\n\n\nclass KLlamaAttention(BaseInjectedModule):\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)\n        self.orig_module.__init__(orig_module.config,\n            orig_module.layer_idx)\n    def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):\n        \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n        Args:\n            q (`torch.Tensor`): The query tensor.\n            k (`torch.Tensor`): The key tensor.\n            cos (`torch.Tensor`): The cosine part of the rotary embedding.\n            sin (`torch.Tensor`): The sine part of the rotary embedding.\n            position_ids (`torch.Tensor`, *optional*):\n                Deprecated and unused.\n            unsqueeze_dim (`int`, 
*optional*, defaults to 1):\n                The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n                sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n                that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n                k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n                cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have\n                the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n        Returns:\n            `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n        \"\"\"\n        cos = cos.unsqueeze(unsqueeze_dim)\n        sin = sin.unsqueeze(unsqueeze_dim)\n        q_embed = (q * cos) + (rotate_half(q) * sin)\n        k_embed = (k * cos) + (rotate_half(k) * sin)\n        return q_embed, k_embed\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_value: Optional[Cache] = None,\n        output_attentions: bool = False,\n        use_cache: bool = False,\n        cache_position: Optional[torch.LongTensor] = None,\n        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45\n        **kwargs,\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        bsz, q_len, _ = hidden_states.size()\n\n        if self.config.pretraining_tp > 1:\n            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp\n            query_slices = self.q_proj.weight.split(\n                (self.num_heads * self.head_dim) // 
self.config.pretraining_tp, dim=0\n            )\n            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)\n            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)\n\n            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]\n            query_states = torch.cat(query_states, dim=-1)\n\n            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]\n            key_states = torch.cat(key_states, dim=-1)\n\n            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]\n            value_states = torch.cat(value_states, dim=-1)\n\n        else:\n            query_states = self.q_proj(hidden_states)\n            key_states = self.k_proj(hidden_states)\n            value_states = self.v_proj(hidden_states)\n\n        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n\n        if position_embeddings is None:\n\n            logger.warning(\n                \"The attention layers in this model are transitioning from computing the RoPE embeddings internally \"\n                \"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed \"\n                \"`position_embeddings` (Tuple of tensors, containing cos and sin). 
In v4.45 `position_ids` will be \"\n                \"removed and `position_embeddings` will be mandatory.\"\n            )\n            cos, sin = self.rotary_emb(value_states, position_ids)\n        else:\n            cos, sin = position_embeddings\n        query_states, key_states = self.apply_rotary_pos_emb(query_states, key_states, cos, sin)\n        if q_len == 1:\n            position_ids = position_ids[0][-1].unsqueeze(0).unsqueeze(0)\n            query_states = query_states[:, :, -1:]\n            key_states = key_states[:, :, -1:]\n\n        attn_output = KLlamaModel.dynamic_sdpa.apply(\n            self.layer_idx,\n            bsz,\n            position_ids[0][0],\n            query_states.transpose(1, 2).to(torch.float16),\n            key_states.transpose(1, 2).to(torch.float16),\n            value_states.transpose(1, 2).to(torch.float16),\n            mode=\"prefill\" if q_len > 1 else \"generate\",\n        )\n\n\n        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):\n            raise ValueError(\n                f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is\"\n                f\" {attn_output.size()}\"\n            )\n\n        attn_output = attn_output.transpose(1, 2).contiguous()\n\n        attn_output = attn_output.reshape(bsz, q_len, -1)\n\n        if self.config.pretraining_tp > 1:\n            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)\n            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)\n            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])\n        else:\n            attn_output = self.o_proj(attn_output)\n\n        # dynamic_sdpa does not return attention weights, so this is always None\n        # (previously unassigned, raising NameError, when output_attentions=True).\n        attn_weights = None\n\n        return attn_output, attn_weights, past_key_value\n\n\nclass KQwen3MoeAttentionIPEXLLM(BaseInjectedModule, Qwen3MoeAttention):\n    
def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"xpu\",\n                 generate_device: str = \"xpu\",\n                 chunck_size: int = 1000,\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        self.orig_module.__init__(orig_module.config,\n            orig_module.layer_idx)\n        self.chunck_size = chunck_size # TODO, generate chunck_size automatically.\n        assert prefill_device.lower()[:3] == \"xpu\", \"KQwen3MoeAttentionIPEXLLM only supports XPU device\"\n        assert generate_device.lower()[:3] == \"xpu\", \"KQwen3MoeAttentionIPEXLLM only supports XPU device\"\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        position_ids: Optional[torch.Tensor],\n        position_embeddings: Tuple[torch.Tensor, torch.Tensor],\n        attention_mask: Optional[torch.Tensor],\n        past_key_value: Optional[Cache] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        **kwargs\n    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:\n        input_shape = hidden_states.shape[:-1]\n        bsz, q_len, _ = hidden_states.size()\n        input_dtype = hidden_states.dtype\n        hidden_shape = (*input_shape, -1, self.head_dim)\n\n        if not hasattr(self, 'qkv_proj'):\n            from ipex_llm.transformers.models.common import merge_quantized_qkv\n            merge_quantized_qkv(self.q_proj.generate_linear, self.k_proj.generate_linear, self.v_proj.generate_linear, self.orig_module)\n\n        qkv = self.qkv_proj(hidden_states)\n        qkv = qkv.view(bsz, q_len, -1, self.head_dim)\n        qkv = qkv.transpose(1, 2)\n        query_states, key_states, value_states = qkv.split([self.config.num_attention_heads,\n           
                                                 self.config.num_key_value_heads,\n                                                            self.config.num_key_value_heads], dim=1)\n        query_states = self.q_norm(query_states)\n        key_states = self.k_norm(key_states)\n\n        if position_embeddings is None:\n            position_embeddings = self.rotary_emb(hidden_states, position_ids)\n\n        cos, sin = position_embeddings\n\n        from ipex_llm.transformers.models.common import rotary_half_with_cache_inplaced\n        rotary_half_with_cache_inplaced(query_states, key_states, cos, sin)\n\n        if past_key_value is not None:\n            # sin and cos are specific to RoPE models; cache_position needed for the static cache\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}\n            key_states, value_states = past_key_value.update(key_states.half(), value_states.half(),\n                                                             self.layer_idx, cache_kwargs)\n\n        attn_weights = None\n        from ipex_llm.transformers.models.common import scaled_dot_product_attention\n        attn_output = scaled_dot_product_attention(\n            query_states.half(), key_states, value_states,\n            attention_mask.half(), q_len == key_states.size(2), self.scaling\n        )\n        attn_output = attn_output.transpose(1, 2).contiguous()\n\n        attn_output = attn_output.reshape(*input_shape, -1).contiguous()\n        attn_output = self.o_proj(attn_output).to(input_dtype)\n        return attn_output, attn_weights\n\n\ndef repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n    \"\"\"\n    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch,\n    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n    \"\"\"\n    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n    if n_rep == 1:\n        return hidden_states\n    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)\n\n\ndef eager_attention_forward(\n    module: nn.Module,\n    query: torch.Tensor,\n    key: torch.Tensor,\n    value: torch.Tensor,\n    attention_mask: Optional[torch.Tensor],\n    scaling: float,\n    dropout: float = 0.0,\n    **kwargs,\n):\n    key_states = repeat_kv(key, module.num_key_value_groups)\n    value_states = repeat_kv(value, module.num_key_value_groups)\n\n    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling\n    if attention_mask is not None:\n        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]\n        attn_weights = attn_weights + causal_mask\n\n    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)\n    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)\n    attn_output = torch.matmul(attn_weights, value_states)\n    attn_output = attn_output.transpose(1, 2).contiguous()\n\n    return attn_output, attn_weights\n\n\nclass KQwen3MoeAttention(BaseInjectedModule, Qwen3MoeAttention ):\n    def __init__(self,\n                 key: str,\n                 gguf_loader: GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 chunck_size: int = 1000,\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device,\n                                  
  **kwargs)\n        self.orig_module.__init__(self.orig_module.config,\n                                  orig_module.layer_idx)\n        self.chunck_size = chunck_size  # TODO, generate chunck_size automatically.\n\n    # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb\n    def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):\n        \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n        Args:\n            q (`torch.Tensor`): The query tensor.\n            k (`torch.Tensor`): The key tensor.\n            cos (`torch.Tensor`): The cosine part of the rotary embedding.\n            sin (`torch.Tensor`): The sine part of the rotary embedding.\n            position_ids (`torch.Tensor`):\n                Deprecated and unused.\n            unsqueeze_dim (`int`, *optional*, defaults to 1):\n                The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n                sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n                that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n                k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n                cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. 
Similarly, if q and k have\n                the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n        Returns:\n            `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n        \"\"\"\n        cos = cos.unsqueeze(unsqueeze_dim)\n        sin = sin.unsqueeze(unsqueeze_dim)\n        q_embed = (q * cos) + (rotate_half(q) * sin)\n        k_embed = (k * cos) + (rotate_half(k) * sin)\n        return q_embed, k_embed\n\n    def forward(self,\n                hidden_states: torch.Tensor,\n                position_ids: Optional[torch.Tensor],\n                position_embeddings: Tuple[torch.Tensor, torch.Tensor],\n                attention_mask: Optional[torch.Tensor],\n                past_key_value: Optional[Cache] = None,\n                cache_position: Optional[torch.LongTensor] = None,\n                **kwargs\n                ):\n        input_shape = hidden_states.shape[:-1]\n        hidden_shape = (*input_shape, -1, self.head_dim)\n\n        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)\n        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)\n        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)\n\n        if position_embeddings is None:\n            position_embeddings = self.rotary_emb(hidden_states, position_ids)\n\n        cos, sin = position_embeddings\n\n        query_states, key_states = self.apply_rotary_pos_emb(query_states, key_states, cos, sin)\n\n\n        if past_key_value is not None:\n            # sin and cos are specific to RoPE models; cache_position needed for the static cache\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}\n            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)\n\n        attention_interface: Callable = 
eager_attention_forward\n        if self.config._attn_implementation != \"eager\":\n            if self.config._attn_implementation == \"sdpa\" and kwargs.get(\"output_attentions\", False):\n                logger.warning_once(\n                    \"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to \"\n                    'eager attention. This warning can be removed using the argument `attn_implementation=\"eager\"` when loading the model.'\n                )\n            else:\n                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]\n\n        attn_output, attn_weights = attention_interface(\n            self,\n            query_states,\n            key_states,\n            value_states,\n            attention_mask,\n            dropout=0.0 if not self.training else self.attention_dropout,\n            scaling=self.scaling,\n            sliding_window=self.sliding_window,  # diff with Llama\n            **kwargs,\n        )\n\n        attn_output = attn_output.reshape(*input_shape, -1).contiguous()\n        attn_output = self.o_proj(attn_output)\n        return attn_output, attn_weights\n"
  },
  {
    "path": "kt-sft/ktransformers/operators/balance_serve_attention.py",
    "content": "'''\nDescription  :  \nAuthor       : Boxin Zhang\nVersion      : 0.2.5\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport torch\nfrom torch import nn\nfrom ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb\nfrom ktransformers.models.modeling_qwen2_moe import Qwen2MoeAttention\nfrom ktransformers.models.modeling_qwen3_moe import Qwen3MoeAttention\nfrom typing import Optional, Tuple\nfrom ktransformers.operators.base_operator import BaseInjectedModule\nfrom ktransformers.util.custom_loader import GGUFLoader\nimport logging\nfrom transformers.configuration_utils import PretrainedConfig\nfrom flashinfer import BatchMLAPagedAttentionWrapper\nfrom ktransformers.operators.flashinfer_batch_prefill_wrapper import flashInferAttn\nfrom ktransformers.models.custom_cache import KDeepSeekV3Cache, KGQACache\nlogger = logging.getLogger(\"attention\")\n\n# Copied from transformers.models.llama.modeling_llama.rotate_half\ndef rotate_half(x):\n    \"\"\"Rotates half the hidden dims of the input.\"\"\"\n    x1 = x[..., : x.shape[-1] // 2]\n    x2 = x[..., x.shape[-1] // 2 :]\n    return torch.cat((-x2, x1), dim=-1)\n\nclass flashinfer_attn(BaseInjectedModule, DeepseekV2Attention):\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 chunck_size: int = 1000,\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        self.orig_module.__init__(orig_module.config,\n            orig_module.layer_idx)\n        self.chunck_size = chunck_size # TODO, generate chunck_size automatically.\n\n\n    def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:\n        if not (hasattr(self, 
'q_absorb') and hasattr(self, 'out_absorb')):\n            kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)\n            q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank)\n            out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank)\n            self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim, \n                                      bias=False, dtype=q_absorb.dtype, device=q_absorb.device)\n            self.q_absorb.weight.data = q_absorb\n            self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim, \n                                        bias=False, dtype=out_absorb.dtype, device=out_absorb.device)\n            self.out_absorb.weight.data = out_absorb\n            #del self.orig_module.kv_b_proj\n        q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)\n        out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank)\n        return q_absorb, out_absorb\n    \n\n    def forward(self,\n                hidden_states: torch.Tensor,\n                kv_cache: KDeepSeekV3Cache,\n                position_ids: torch.Tensor,\n                wrapper: BatchMLAPagedAttentionWrapper,\n                num_tokens_tensors: torch.Tensor,\n                page_idx: torch.Tensor,\n                page_offset: torch.Tensor,\n                ):\n        q_len, _ = hidden_states.size()\n\n        if self.q_lora_rank is None:\n            q = self.q_proj(hidden_states, num_tokens_tensors)\n        else:\n            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states, num_tokens_tensors), num_tokens_tensors), num_tokens_tensors)\n        q = q.view(q_len, self.num_heads, self.q_head_dim)\n        q_nope, q_pe = torch.split(\n            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1\n        )\n\n        compressed_kv 
= self.kv_a_proj_with_mqa(hidden_states, num_tokens_tensors)\n        compressed_kv, k_pe = torch.split(\n            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n        )\n        compressed_kv = compressed_kv.contiguous()\n        compressed_kv = self.kv_a_layernorm(compressed_kv, num_tokens_tensors)\n        k_pe = k_pe.view(q_len, 1, self.qk_rope_head_dim)\n        compressed_kv = compressed_kv.view(q_len, 1, self.kv_lora_rank)\n        \n        cos, sin = self.rotary_emb(q_pe, position_ids.unsqueeze(0))\n        q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=2)\n        q_pe = q_pe.squeeze(0)\n        if kv_cache is not None:\n            \n            # page_idx, page_offset = kv_cache.get_page_table(position_ids, q_indptr, kv_indptr, kv_indices)\n            cache_kwargs = {\"sin\": sin, \"cos\": cos, \"page_idx\": page_idx, \"page_offset\": page_offset}  # Specific to RoPE models\n            compressed_kv_with_k_pe = kv_cache.update(compressed_kv.unsqueeze(0), k_pe, self.layer_idx, page_idx, page_offset, cache_kwargs)\n            compressed_kv = compressed_kv_with_k_pe [:, :, :, :self.kv_lora_rank].view(-1, kv_cache.page_size, self.kv_lora_rank)\n            k_pe = compressed_kv_with_k_pe [:, :, :, self.kv_lora_rank:].view(-1, kv_cache.page_size, self.qk_rope_head_dim)\n            \n        q_absorb, out_absorb = self.get_absorbed()\n        q_nope = q_nope.transpose(0, 1) # q_len is 1, no GPU overhead, same below\n        q_nope = torch.matmul(q_nope, q_absorb) # batched MM\n        q_nope = q_nope.transpose(0, 1)\n        # q_nope.squeeze_(1)\n        # q_pe.squeeze_(1)\n\n        attn_output = wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(q_len, self.num_heads, self.kv_lora_rank)\n        attn_output = attn_output.transpose(0, 1)\n        attn_output = torch.matmul(attn_output, out_absorb.mT) # [self.num_heads, q_len, self.v_head_dim]\n        attn_output = 
attn_output.transpose(0, 1)\n        attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim)\n        attn_output = self.o_proj(attn_output, num_tokens_tensors)\n        return attn_output\n\nclass KQwen2MoeAttention(BaseInjectedModule, Qwen2MoeAttention):\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 chunck_size: int = 1000,\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        self.orig_module.__init__(orig_module.config,\n            orig_module.layer_idx)\n        self.chunck_size = chunck_size # TODO, generate chunck_size automatically.\n\n\n    # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb\n    def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):\n        \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n        Args:\n            q (`torch.Tensor`): The query tensor.\n            k (`torch.Tensor`): The key tensor.\n            cos (`torch.Tensor`): The cosine part of the rotary embedding.\n            sin (`torch.Tensor`): The sine part of the rotary embedding.\n            position_ids (`torch.Tensor`):\n                Deprecated and unused.\n            unsqueeze_dim (`int`, *optional*, defaults to 1):\n                The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n                sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n                that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. 
Then, if q and\n                k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n                cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have\n                the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n        Returns:\n            `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n        \"\"\"\n        cos = cos.unsqueeze(unsqueeze_dim)\n        sin = sin.unsqueeze(unsqueeze_dim)\n        q_embed = (q * cos) + (rotate_half(q) * sin)\n        k_embed = (k * cos) + (rotate_half(k) * sin)\n        return q_embed, k_embed\n\n\n    def forward(self,\n                hidden_states: torch.Tensor,\n                kv_cache: KGQACache,\n                position_ids: torch.Tensor,\n                wrapper: flashInferAttn,\n                bsz_tensors: torch.Tensor,\n                page_idx: torch.Tensor,\n                page_offset: torch.Tensor,\n                ):\n        q_len, _ = hidden_states.size()\n\n        query_states = self.q_proj(hidden_states, bsz_tensors)\n        key_states = self.k_proj(hidden_states, bsz_tensors)\n        value_states = self.v_proj(hidden_states, bsz_tensors)\n\n\n        query_states = query_states.view(q_len, self.num_heads, self.head_dim)\n        key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)\n        value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)\n        \n        cos, sin = self.rotary_emb(value_states.unsqueeze(0), position_ids.unsqueeze(0))\n        query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), cos, sin, unsqueeze_dim=2)\n\n        query_states = query_states.view(q_len, self.num_heads, self.head_dim)\n        key_states = key_states.view(\n            q_len, self.num_key_value_heads, self.head_dim\n        )\n  
      value_states = value_states.view(\n            q_len, self.num_key_value_heads, self.head_dim\n        )\n\n        k_cache = kv_cache.get_k_cache(self.layer_idx)\n        v_cache = kv_cache.get_v_cache(self.layer_idx)\n\n\n        attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)\n  \n\n        attn_output = self.o_proj(attn_output.view(q_len, self.num_heads * self.head_dim), bsz_tensors)\n\n        return attn_output\n\nclass KQwen3MoeAttention(BaseInjectedModule, Qwen3MoeAttention):\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 chunck_size: int = 1000,\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        self.orig_module.__init__(orig_module.config,\n            orig_module.layer_idx)\n        self.chunck_size = chunck_size # TODO, generate chunck_size automatically.\n\n\n    # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb\n    def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):\n        \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n\n        Args:\n            q (`torch.Tensor`): The query tensor.\n            k (`torch.Tensor`): The key tensor.\n            cos (`torch.Tensor`): The cosine part of the rotary embedding.\n            sin (`torch.Tensor`): The sine part of the rotary embedding.\n            position_ids (`torch.Tensor`):\n                Deprecated and unused.\n            unsqueeze_dim (`int`, *optional*, defaults to 1):\n                The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n                sin[position_ids] so 
that they can be properly broadcasted to the dimensions of q and k. For example, note\n                that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n                k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n                cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have\n                the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n        Returns:\n            `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n        \"\"\"\n        cos = cos.unsqueeze(unsqueeze_dim)\n        sin = sin.unsqueeze(unsqueeze_dim)\n        q_embed = (q * cos) + (rotate_half(q) * sin)\n        k_embed = (k * cos) + (rotate_half(k) * sin)\n        return q_embed, k_embed\n\n\n    def forward(self,\n                hidden_states: torch.Tensor,\n                kv_cache: KGQACache,\n                position_ids: torch.Tensor,\n                wrapper: flashInferAttn,\n                bsz_tensors: torch.Tensor,\n                page_idx: torch.Tensor,\n                page_offset: torch.Tensor,\n                ):\n        q_len, _ = hidden_states.size()\n\n        bsz_tensors_q = bsz_tensors * self.num_heads\n        bsz_tensors_kv = bsz_tensors * self.num_key_value_heads\n\n        query_states = self.q_norm(self.q_proj(hidden_states, bsz_tensors), bsz_tensors_q)\n        key_states = self.k_norm(self.k_proj(hidden_states, bsz_tensors), bsz_tensors_kv)\n        value_states = self.v_proj(hidden_states, bsz_tensors)\n\n\n        query_states = query_states.view(q_len, self.num_heads, self.head_dim)\n        key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)\n        value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)\n        \n        cos, sin = 
self.rotary_emb(value_states.unsqueeze(0), position_ids.unsqueeze(0))\n        query_states, key_states = self.apply_rotary_pos_emb(query_states.unsqueeze(0), key_states.unsqueeze(0), cos, sin, unsqueeze_dim=2)\n\n        query_states = query_states.view(q_len, self.num_heads, self.head_dim)\n        key_states = key_states.view(\n            q_len, self.num_key_value_heads, self.head_dim\n        )\n        value_states = value_states.view(\n            q_len, self.num_key_value_heads, self.head_dim\n        )\n\n        k_cache = kv_cache.get_k_cache(self.layer_idx)\n        v_cache = kv_cache.get_v_cache(self.layer_idx)\n\n\n        attn_output = wrapper.forward(query_states, k_cache, v_cache, key_states, value_states)\n  \n\n        attn_output = self.o_proj(attn_output.view(q_len, self.num_heads * self.head_dim), bsz_tensors)\n\n        return attn_output\n\n\nclass deepseek_torch_attn(BaseInjectedModule, DeepseekV2Attention):\n    def __init__(self,\n                    key: str,\n                    gguf_loader : GGUFLoader,\n                    config: PretrainedConfig,\n                    orig_module: nn.Module,\n                    prefill_device: str = \"cuda\",\n                    generate_device: str = \"cuda\",\n                    chunck_size: int = 1000,\n                    **kwargs):\n            BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n            self.orig_module.__init__(orig_module.config,\n                orig_module.layer_idx)\n            self.chunck_size = chunck_size # TODO, generate chunck_size automatically.\n\n\n    def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:\n        if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):\n            kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)\n            q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank)\n            out_absorb = kv_b_proj[:, 
self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank)\n            self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim, \n                                    bias=False, dtype=q_absorb.dtype, device=q_absorb.device)\n            self.q_absorb.weight.data = q_absorb\n            self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim, \n                                        bias=False, dtype=out_absorb.dtype, device=out_absorb.device)\n            self.out_absorb.weight.data = out_absorb\n            #del self.orig_module.kv_b_proj\n        q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)\n        out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank)\n        return q_absorb, out_absorb\n    \n\n\n    def forward(self,\n                hidden_states: torch.Tensor,\n                kv_cache: KDeepSeekV3Cache,\n                position_ids: torch.Tensor,\n                wrapper: None,\n                num_tokens_tensors: torch.Tensor,\n                page_idx: torch.Tensor,\n                page_offset: torch.Tensor,\n                attention_masks: Optional[list[torch.Tensor]] = None,\n                q_indptr: Optional[torch.Tensor] = None,\n                kv_indices: Optional[torch.Tensor] = None,\n                kv_indptr: Optional[torch.Tensor] = None,\n                bsz_tensors: Optional[torch.Tensor] = None,\n                last_page_len: Optional[torch.Tensor] = None,\n                ):\n        # range bsz_tensors\n        final_attention_output = torch.tensor([], device=hidden_states.device)\n        for i in range(bsz_tensors[0]):\n            batch_num_tokens_tensors = q_indptr[i+1] - q_indptr[i]\n            batch_last_page_len = last_page_len[i]\n            # kv_total_len is kv_len, batch_compressed_kv is compressed_kv, batch_k_pe is k_pe\n            batch_page_idx = 
page_idx[q_indptr[i]:q_indptr[i+1]]\n            batch_page_offset = page_offset[q_indptr[i]:q_indptr[i+1]]\n            # kv_page_nums is the number of pages for the current batch\n            kv_page_nums = kv_indptr[i+1] - kv_indptr[i]\n            # kv_total_len is the total length of the kv cache for the current batch (kv_len for algorithm)\n            kv_total_len = kv_page_nums * kv_cache.page_size\n            if batch_last_page_len is not None:\n                kv_total_len = kv_total_len - (kv_cache.page_size - batch_last_page_len)\n            # print(f\"kv_total_len's shape {kv_total_len.shape}\")\n            # kv_index is the index of the kv cache pages for the current batch\n            kv_index = kv_indices[kv_indptr[i]:kv_indptr[i+1]]\n            # we can index [kv_index, page_offset_indices] to get the kv cache for the current batch\n            # from q_indptr[i] to q_indptr[i+1] is the range of the current batch\n            batch_hidden_states = hidden_states[q_indptr[i]:q_indptr[i+1]]\n            batch_position_ids = position_ids[q_indptr[i]:q_indptr[i+1]]\n            q_len, _ = batch_hidden_states.size()\n            # print(\"q_len -> \", q_len)\n\n            if self.q_lora_rank is None:\n                q = self.q_proj(batch_hidden_states, batch_num_tokens_tensors)\n            else:\n                q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(batch_hidden_states, batch_num_tokens_tensors), batch_num_tokens_tensors), batch_num_tokens_tensors)\n            # for v3, bsz, q_len, num_heads(128), qk_head_dim(192=128(nope)+64(rope))\n            q = q.view(q_len, self.num_heads, self.q_head_dim)\n            # q_nope is [q_len, num_heads(128), qk_nope_head_dim(128)]\n            # q_pe is [q_len, num_heads(128), qk_rope_head_dim(64)]\n            q_nope, q_pe = torch.split(\n                q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1\n            )\n            # compressed_kv is [q_len, kv_lora_rank(512) + rope(64)]\n  
          compressed_kv = self.kv_a_proj_with_mqa(batch_hidden_states, batch_num_tokens_tensors)\n            # compressed_kv is [q_len, kv_lora_rank(512)], k_pe is [q_len, rope(64)]\n            compressed_kv, k_pe = torch.split(\n                compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1\n            )\n            compressed_kv = compressed_kv.contiguous()\n            compressed_kv = self.kv_a_layernorm(compressed_kv, batch_num_tokens_tensors)\n            # k_pe is [q_len, 1, qk_rope_head_dim(64)]\n            k_pe = k_pe.view(q_len, 1, self.qk_rope_head_dim)\n            # compressed_kv is [q_len, 1, kv_lora_rank(512)]\n            compressed_kv = compressed_kv.view(q_len, 1, self.kv_lora_rank)\n            \n            cos, sin = self.rotary_emb(q_pe, batch_position_ids.unsqueeze(0))\n            # print(f\"q_pe shape{q_pe.shape}, k_pe shape {k_pe.shape}\")\n            q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=2)\n            q_pe = q_pe.squeeze(0)\n            # q_pe is [num_heads(128), q_len, qk_rope_head_dim(64)]\n            q_pe.transpose_(0, 1)            \n            if kv_cache is not None:\n                cache_kwargs = {\"sin\": sin, \"cos\": cos, \"page_idx\": batch_page_idx, \"page_offset\": batch_page_offset}  # Specific to RoPE models\n                compressed_kv_with_k_pe = kv_cache.update(compressed_kv.unsqueeze(0), k_pe, self.layer_idx, batch_page_idx, batch_page_offset, cache_kwargs)\n                compressed_kv = compressed_kv_with_k_pe [:, :, :, :self.kv_lora_rank].view(-1, kv_cache.page_size, self.kv_lora_rank)\n                k_pe = compressed_kv_with_k_pe [:, :, :, self.kv_lora_rank:].view(-1, kv_cache.page_size, self.qk_rope_head_dim)\n            # q_absorb is [num_heads(128), qk_nope_head_dim(128), kv_lora_rank(512)]\n            # out_absorb is [num_heads(128), kv_lora_rank(512), v_head_dim(128)] v_head_dim is also the nope dim\n            
q_absorb, out_absorb = self.get_absorbed()\n            # q_nope is [num_heads(128), q_len, qk_nope_head_dim(128)]\n            q_nope = q_nope.transpose(0, 1) # q_len is 1, no GPU overhead, same below\n            # q_nope is [num_heads(128), q_len, kv_lora_rank(512)]\n            q_nope = torch.matmul(q_nope, q_absorb) # batched MM\n\n            # # q_nope is [q_len, num_heads(128), kv_lora_rank(512)]\n            # q_nope = q_nope.transpose(0, 1)\n\n            # we need to index out the compressed_kv and k_pe for the current batch\n            batch_compressed_kv = None\n            batch_k_pe = None\n            for page_index in kv_index:\n                if kv_total_len > kv_cache.page_size:\n                    tmp_compressed_kv = compressed_kv[page_index, 0:kv_cache.page_size, :]\n                    tmp_k_pe = k_pe[page_index, 0:kv_cache.page_size, :]\n                    if batch_compressed_kv is None or batch_k_pe is None:\n                        batch_compressed_kv = tmp_compressed_kv\n                        batch_k_pe = tmp_k_pe\n                    else: \n                        batch_compressed_kv = torch.cat((batch_compressed_kv, tmp_compressed_kv), dim=0)\n                        batch_k_pe = torch.cat((batch_k_pe, tmp_k_pe), dim=0)\n                    kv_total_len -= kv_cache.page_size\n                else:\n                    tmp_compressed_kv = compressed_kv[page_index, 0:kv_total_len, :]\n                    tmp_k_pe = k_pe[page_index, 0:kv_total_len, :]\n                    if batch_compressed_kv is None or batch_k_pe is None:\n                        batch_compressed_kv = tmp_compressed_kv\n                        batch_k_pe = tmp_k_pe\n                    else: \n                        batch_compressed_kv = torch.cat((batch_compressed_kv, tmp_compressed_kv), dim=0)\n                        batch_k_pe = torch.cat((batch_k_pe, tmp_k_pe), dim=0)\n                    break\n            # batch_compressed_kv is [kv_total_len(k_len), 
kv_lora_rank(512)]\n            # batch_k_pe is [kv_total_len(k_len), qk_rope_head_dim(64)]\n            attention_weights = (torch.matmul(q_pe,batch_k_pe.mT) + torch.matmul(q_nope, batch_compressed_kv.mT)) * self.softmax_scale\n            # attention_weights is [num_heads(128), q_len, k_len]\n            \n            # attention_weights = attention_weights.transpose(0,1).unsqueeze(0).squeeze(-1).expand(q_len,-1,-1).transpose(0,1)\n            \n            # attention_masks[i] is [q_len, k_len]\n            \n            attention_weights = (attention_weights + attention_masks[i])\n            # attention_weights shape is [num_heads(128), q_len, k_len]\n            attention_weights = nn.functional.softmax(attention_weights,dim=-1,dtype=torch.float32).to(q_pe.dtype)\n            attn_output = torch.matmul(attention_weights, batch_compressed_kv) # [num_heads(128),q_len, lora_rank(512)]\n            # out_absorb shape is [num_heads(128), kv_lora_rank(512), v_head_dim(128)]\n            out_absorb = out_absorb.transpose(1,2)\n            # q for q_len, n for num_heads, h for v_head_dim, v for kv_lora_rank\n            attn_output = torch.matmul(attn_output, out_absorb) # [num_heads(128), q_len, v_head_dim(128)]\n            attn_output = attn_output.transpose(0, 1) # [q_len, num_heads(128), v_head_dim(128)]\n            attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim)\n            attn_output = self.o_proj(attn_output, batch_num_tokens_tensors)\n            final_attention_output = torch.cat((final_attention_output, attn_output), dim=0)\n        return final_attention_output"
  },
  {
    "path": "kt-sft/ktransformers/operators/base_operator.py",
    "content": "'''\nDescription  :  \nAuthor       : Boxin Zhang\nVersion      : 0.1.0\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nfrom typing import Any\nfrom torch import nn, Tensor\nfrom ktransformers.util.custom_loader import GGUFLoader\nfrom transformers.configuration_utils import PretrainedConfig\nimport ktransformers.util.utils as utils\nclass BaseInjectedModule(nn.Module):\n    \n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 **kwargs):\n        nn.Module.__init__(self)\n        nn.Module.__setattr__(self, \"orig_module\", orig_module)\n        object.__setattr__(self, \"key\", key)\n        object.__setattr__(self, \"gguf_loader\", gguf_loader)\n        object.__setattr__(self, \"config\", config)\n        object.__setattr__(self, \"prefill_device\", prefill_device)\n        object.__setattr__(self, \"generate_device\", generate_device)\n        object.__setattr__(self, \"device\", generate_device)\n        \n    def __getattr__(self, name: str) -> Any:\n        # __getattr__ in nn.Module doesn't call super().__getattribute__ when name is not in nn.Module.__dict__,\n        # but __setattr__ in nn.Module call super().__setattr__ in that case, there may be some attribute set \n        # but can't get using __getattr__, typically these attr is build in attr of the class, so class.attr does not\n        # call __getattr__.\n        # Example:\n        # ...import torch\n        # ...l=torch.nn.Linear(100,200)\n        # ...l.out_features # 200\n        # ...l.__getattr__(\"out_features\") # AttributeError: 'Linear' object has no attribute 'out_features'\n        try:\n            return object.__getattribute__(self, name) # if this attr belongs to BaseInjectedModule\n        except:\n        
    if name == \"orig_module\":\n                return nn.Module.__getattr__(self, \"orig_module\")\n            try:\n                return nn.Module.__getattr__(self, \"orig_module\").__getattr__(name) # if this attr belongs to orig_module\n            except:\n                return super(nn.Module, nn.Module.__getattr__(self, \"orig_module\")).__getattribute__(name) # if this attr belongs to orig_module but not in nn.Module.__dict__\n\n    def __setattr__(self, name: str, value: Tensor | nn.Module) -> None:\n        if name == \"orig_module\":\n            return nn.Module.__setattr__(self, \"orig_module\", value)\n        # elif name == \"base_layer\":\n        #     return nn.Module.__setattr__(self, \"base_layer\", value)\n        elif hasattr(self, name):\n            return object.__setattr__(self, name, value)\n        return nn.Module.__getattr__(self, \"orig_module\").__setattr__(name, value)\n    \n    def forward(self, *args, **kwargs):\n        return self.orig_module.forward(*args, **kwargs)\n    \n    def load(self, gguf_loader=None, adapter_gguf : bool = False):\n        for name, child in self._modules.items():\n            if gguf_loader is None:\n                utils.load_weights(child, self.gguf_loader, self.key+\".\", adapter_gguf=adapter_gguf)\n            else:\n                utils.load_weights(child, gguf_loader, self.key+\".\", adapter_gguf=adapter_gguf)\n"
  },
  {
    "path": "kt-sft/ktransformers/operators/cpuinfer.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nDescription  : This script defines the `CPUInferKVCache` and `CPUInfer` classes for performing inference \n               with a Key-Value Cache on the CPU. The `CPUInferKVCache` class is responsible for configuring \n               and managing key-value caches, updating and retrieving cache data, and handling attention \n               operations. It supports different cache types (e.g., Q4_0, FP16) and retrieval strategies \n               (e.g., shared, separate). The `CPUInfer` class handles task submission and synchronization \n               on the CPU, with optional CUDA stream integration for tasks involving GPU acceleration. \n               These classes facilitate efficient caching and memory management for deep learning models \n               that leverage key-value attention mechanisms, particularly on CPU-based systems.\nAuthor       : djw\nDate         : 2024-08-26 23:25:24\nVersion      : 1.0.0\nLastEditors  : djw \nLastEditTime : 2024-08-26 23:25:24\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved.\n\"\"\"\nimport sys, os\nfrom typing import Any\nimport torch\nsys.path.append(os.path.join(os.path.dirname(__file__), \"..\", \"ktransformers_ext\", \"build\"))\nsys.path.append(os.path.join(os.path.dirname(__file__), \"..\", \"ktransformers_ext\", \"build\", \"Release\"))\nsys.path.append(os.path.join(os.path.dirname(__file__), \"..\", \"ktransformers_ext\", \"build\", \"Debug\"))\nimport cpuinfer_ext\nfrom ktransformers.server.config.config import Config\n\n\nclass CPUInferKVCache:\n    def __init__(\n        self,\n        layer_num: int = 32,\n        kv_head_num: int = 8,\n        q_head_num: int = 32,\n        head_dim: int = 128,\n        block_len: int = 256,\n        anchor_num: int = 4,\n        anchor_type: str = \"FIXED\",\n        kv_type: str = \"Q4_0\",\n        retrieval_type: str = \"SHARED\",\n        layer_step: int = 1,\n        token_step: int = 1,\n        
layer_offset: int = 0,\n        max_thread_num: int = 32,\n        max_batch_size: int = 4,\n        max_block_num: int = 512,\n    ):\n\n        if anchor_type == \"FIXED\":\n            anchor_type = cpuinfer_ext.kvcache.AnchorType.FIXED\n        elif anchor_type == \"QUEST\":\n            anchor_type = cpuinfer_ext.kvcache.AnchorType.QUEST\n        elif anchor_type == \"DYNAMIC\":\n            anchor_type = cpuinfer_ext.kvcache.AnchorType.DYNAMIC\n        elif anchor_type == \"BLOCK_MEAN\":\n            anchor_type = cpuinfer_ext.kvcache.AnchorType.BLOCK_MEAN\n        elif anchor_type == \"BLOCK_MAX\":\n            anchor_type = cpuinfer_ext.kvcache.AnchorType.BLOCK_MAX\n        else:\n            raise ValueError(f\"Unknown anchor type: {anchor_type}\")\n\n        if kv_type == \"FP16\":\n            kv_type = cpuinfer_ext.kvcache.ggml_type.FP16\n        elif kv_type == \"FP32\":\n            assert False, \"FP32 is not supported yet.\"\n            kv_type = cpuinfer_ext.kvcache.ggml_type.FP32\n        elif kv_type == \"Q4_0\":\n            kv_type = cpuinfer_ext.kvcache.ggml_type.Q4_0\n        elif kv_type == \"Q8_0\":\n            kv_type = cpuinfer_ext.kvcache.ggml_type.Q8_0\n        else:\n            raise ValueError(f\"Unknown kv type: {kv_type}\")\n\n        if retrieval_type == \"SHARED\":\n            retrieval_type = cpuinfer_ext.kvcache.RetrievalType.LAYER\n        elif retrieval_type == \"INDIVIDUAL\":\n            retrieval_type = cpuinfer_ext.kvcache.RetrievalType.QHEAD\n        elif retrieval_type == \"SEPARATE\":\n            retrieval_type = cpuinfer_ext.kvcache.RetrievalType.KVHEAD\n\n        self.config = cpuinfer_ext.kvcache.KVCacheConfig(\n            layer_num,\n            kv_head_num,\n            q_head_num,\n            head_dim,\n            block_len,\n            anchor_num,\n            anchor_type,\n            kv_type,\n            retrieval_type,\n            layer_step,\n            token_step,\n            layer_offset,\n     
       max_block_num,\n            max_batch_size,\n            max_thread_num,\n        )\n        self.kvcache = cpuinfer_ext.kvcache.KVCache(self.config)\n\n    def load_kvcache(self, tensor_file_path: str):\n        if not os.path.exists(tensor_file_path):\n            raise FileNotFoundError(f\"The file {tensor_file_path} does not exist.\")\n        return self.kvcache.load_kvcache(tensor_file_path,)\n\n    def dump_kvcache(\n        self, block_table: torch.Tensor, cache_total_len: int, tensor_file_path: str\n    ):\n        assert (\n            block_table.dim() == 1\n            and block_table.dtype == torch.int\n            and block_table.is_contiguous()\n            and block_table.device == torch.device(\"cpu\")\n        ), \"block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            block_table.dim(),\n            block_table.size(),\n            block_table.dtype,\n            block_table.is_contiguous(),\n            block_table.device,\n        )\n\n        assert (\n            cache_total_len > 0\n            and cache_total_len <= self.config.block_len * block_table.size(0)\n        ), \"cache_total_len: {}\".format(cache_total_len)\n\n        if not os.path.exists(os.path.dirname(tensor_file_path)):\n            os.makedirs(os.path.dirname(tensor_file_path))\n\n        return self.kvcache.dump_kvcache(\n            block_table.data_ptr(),\n            cache_total_len,\n            tensor_file_path,\n        )\n\n    def update_cache_total_len(self, cache_total_len: int):\n        assert cache_total_len > 0, \"cache_total_len: {}\".format(cache_total_len)\n        self.kvcache.update_cache_total_len(cache_total_len)\n\n    # q_in: (bsz, q_len, q_head_num, head_dim)\n    # output: (bsz, q_len, q_head_num, head_dim)\n    # attn_lse: (bsz, q_len, q_head_num)\n    # block_table: (bsz, max_block_num)\n    def attn(\n        self,\n        q_in: torch.Tensor,\n        output: torch.Tensor,\n        attn_lse: 
torch.Tensor,\n        layer_idx: int,\n        generate_token_idx: int,\n        block_table: torch.Tensor | None = None,\n        cache_seqlens: torch.Tensor | None = None,\n        pick_block_num: int | None = None,\n        init_block_num: int | None = None,\n        local_block_num: int | None = None,\n    ):\n\n        assert (\n            q_in.dim() == 4\n            and q_in.size(2) == self.config.q_head_num\n            and q_in.size(3) == self.config.head_dim\n            and q_in.dtype == torch.float16\n            and q_in.is_contiguous()\n            and q_in.device == torch.device(\"cpu\")\n        ), \"q_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            q_in.dim(), q_in.size(), q_in.dtype, q_in.is_contiguous(), q_in.device\n        )\n\n        batch_size = q_in.size(0)\n        q_len = q_in.size(1)\n\n        assert (block_table is None) or (\n            block_table.dim() == 2\n            and block_table.size(0) == batch_size\n            and block_table.dtype == torch.int\n            and block_table.is_contiguous()\n            and block_table.device == torch.device(\"cpu\")\n        ), \"block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            block_table.dim(),\n            block_table.size(),\n            block_table.dtype,\n            block_table.is_contiguous(),\n            block_table.device,\n        )\n\n        max_block_num = block_table.size(1) if block_table is not None else 0\n\n        assert (\n            output.dim() == 4\n            and output.size(0) == batch_size\n            and output.size(2) == self.config.q_head_num\n            and output.size(1) == q_len\n            and output.size(3) == self.config.head_dim\n            and output.dtype == torch.float16\n            and output.is_contiguous()\n            and output.device == torch.device(\"cpu\")\n        ), \"output dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            
output.dim(),\n            output.size(),\n            output.dtype,\n            output.is_contiguous(),\n            output.device,\n        )\n\n        assert (\n            attn_lse.dim() == 3\n            and attn_lse.size(0) == batch_size\n            and attn_lse.size(1) == q_len\n            and attn_lse.size(2) == self.config.q_head_num\n            and attn_lse.dtype == torch.float32\n            and attn_lse.is_contiguous()\n            and attn_lse.device == torch.device(\"cpu\")\n        ), \"attn_lse dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            attn_lse.dim(),\n            attn_lse.size(),\n            attn_lse.dtype,\n            attn_lse.is_contiguous(),\n            attn_lse.device,\n        )\n\n        assert (\n            layer_idx >= 0 and layer_idx < self.config.layer_num\n        ), \"layer_idx: {}\".format(layer_idx)\n\n        assert (cache_seqlens is None) or (\n            cache_seqlens.dim() == 1\n            and cache_seqlens.size(0) == batch_size\n            and cache_seqlens.dtype == torch.int\n            and cache_seqlens.is_contiguous()\n            and cache_seqlens.device == torch.device(\"cpu\")\n        ), \"cache_seqlens dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            cache_seqlens.dim(),\n            cache_seqlens.size(),\n            cache_seqlens.dtype,\n            cache_seqlens.is_contiguous(),\n            cache_seqlens.device,\n        )\n\n        return self.kvcache.attn(\n            q_in.data_ptr(),\n            output.data_ptr(),\n            attn_lse.data_ptr(),\n            layer_idx,\n            generate_token_idx,\n            q_len,\n            batch_size,\n            max_block_num,\n            block_table.data_ptr() if block_table is not None else 0,\n            cache_seqlens.data_ptr() if cache_seqlens is not None else 0,\n            pick_block_num,\n            init_block_num,\n            local_block_num,\n        )\n\n    # k_in: 
(block_len, kv_head_num, head_dim)\n    # v_in: (block_len, kv_head_num, head_dim)\n    def update_kvcache_one_block_fp16(\n        self, k_in: torch.Tensor, v_in: torch.Tensor, layer_id: int, block_idx: int\n    ):\n        assert (\n            k_in.dim() == 3\n            and k_in.size(1) == self.config.block_len\n            and k_in.size(0) == self.config.kv_head_num\n            and k_in.size(2) == self.config.head_dim\n            and k_in.dtype == torch.float16\n            and k_in.is_contiguous()\n            and k_in.device == torch.device(\"cpu\")\n        ), \"k_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            k_in.dim(), k_in.size(), k_in.dtype, k_in.is_contiguous(), k_in.device\n        )\n        assert (\n            v_in.dim() == 3\n            and v_in.size(1) == self.config.block_len\n            and v_in.size(0) == self.config.kv_head_num\n            and v_in.size(2) == self.config.head_dim\n            and v_in.dtype == torch.float16\n            and v_in.is_contiguous()\n            and v_in.device == torch.device(\"cpu\")\n        ), \"v_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            v_in.dim(), v_in.size(), v_in.dtype, v_in.is_contiguous(), v_in.device\n        )\n        assert (\n            layer_id >= 0 and layer_id < self.config.layer_num\n        ), \"layer_id: {}\".format(layer_id)\n        assert block_idx >= 0, \"block_idx: {}\".format(block_idx)\n        return self.kvcache.update_one_block_fp16(\n            k_in.data_ptr(),\n            v_in.data_ptr(),\n            layer_id,\n            block_idx,\n        )\n\n    def get_kvcache_one_block_fp16(\n        self, k_in: torch.Tensor, v_in: torch.Tensor, layer_id: int, block_idx: int\n    ):\n        assert (\n            k_in.dim() == 3\n            and k_in.size(1) == self.config.block_len\n            and k_in.size(0) == self.config.kv_head_num\n            and k_in.size(2) == self.config.head_dim\n          
  and k_in.dtype == torch.float16\n            and k_in.is_contiguous()\n            and k_in.device == torch.device(\"cpu\")\n        ), \"k_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            k_in.dim(), k_in.size(), k_in.dtype, k_in.is_contiguous(), k_in.device\n        )\n        assert (\n            v_in.dim() == 3\n            and v_in.size(1) == self.config.block_len\n            and v_in.size(0) == self.config.kv_head_num\n            and v_in.size(2) == self.config.head_dim\n            and v_in.dtype == torch.float16\n            and v_in.is_contiguous()\n            and v_in.device == torch.device(\"cpu\")\n        ), \"v_in dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            v_in.dim(), v_in.size(), v_in.dtype, v_in.is_contiguous(), v_in.device\n        )\n        assert (\n            layer_id >= 0 and layer_id < self.config.layer_num\n        ), \"layer_id: {}\".format(layer_id)\n        assert block_idx >= 0, \"block_idx: {}\".format(block_idx)\n        return self.kvcache.get_one_block_fp16(\n            k_in.data_ptr(),\n            v_in.data_ptr(),\n            layer_id,\n            block_idx,\n        )\n\n    def update_importance_one_block(\n        self, importance: torch.Tensor, layer_id: int, block_idx: int\n    ):\n        assert (\n            importance.dim() == 1\n            and importance.size(0) == self.config.block_len\n            and importance.dtype == torch.float16\n            and importance.is_contiguous()\n            and importance.device == torch.device(\"cpu\")\n        ), \"importance dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            importance.dim(),\n            importance.size(),\n            importance.dtype,\n            importance.is_contiguous(),\n            importance.device,\n        )\n        assert (\n            layer_id >= 0 and layer_id < self.config.layer_num\n        ), \"layer_id: {}\".format(layer_id)\n        
assert block_idx >= 0, \"block_idx: {}\".format(block_idx)\n        return self.kvcache.update_importance_one_block(\n            importance.data_ptr(),\n            layer_id,\n            block_idx,\n        )\n\n    def get_importance_one_block(\n        self, importance: torch.Tensor, layer_id: int, block_idx: int\n    ):\n        assert (\n            importance.dim() == 1\n            and importance.size(0) == self.config.block_len\n            and importance.dtype == torch.float16\n            and importance.is_contiguous()\n            and importance.device == torch.device(\"cpu\")\n        ), \"importance dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            importance.dim(),\n            importance.size(),\n            importance.dtype,\n            importance.is_contiguous(),\n            importance.device,\n        )\n        assert (\n            layer_id >= 0 and layer_id < self.config.layer_num\n        ), \"layer_id: {}\".format(layer_id)\n        assert block_idx >= 0, \"block_idx: {}\".format(block_idx)\n        return self.kvcache.get_importance_one_block(\n            importance.data_ptr(),\n            layer_id,\n            block_idx,\n        )\n\n    def get_anchor_one_block(self, anchor: torch.Tensor, layer_id: int, block_idx: int):\n        assert (\n            anchor.dim() == 3\n            and anchor.size(0) == self.config.kv_head_num\n            and anchor.size(1) == self.config.anchor_num\n            and anchor.size(2) == self.config.head_dim\n            and anchor.dtype == torch.float16\n            and anchor.is_contiguous()\n            and anchor.device == torch.device(\"cpu\")\n        ), \"anchor dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            anchor.dim(),\n            anchor.size(),\n            anchor.dtype,\n            anchor.is_contiguous(),\n            anchor.device,\n        )\n        assert (\n            layer_id >= 0 and layer_id < self.config.layer_num\n   
     ), \"layer_id: {}\".format(layer_id)\n        assert block_idx >= 0, \"block_idx: {}\".format(block_idx)\n        return self.kvcache.get_anchor_one_block(\n            anchor.data_ptr(),\n            layer_id,\n            block_idx,\n        )\n\n    def update_anchor_one_block(\n        self, anchor: torch.Tensor, layer_id: int, block_idx: int\n    ):\n        assert (\n            anchor.dim() == 3\n            and anchor.size(0) == self.config.kv_head_num\n            and anchor.size(1) == self.config.anchor_num\n            and anchor.size(2) == self.config.head_dim\n            and anchor.dtype == torch.float16\n            and anchor.is_contiguous()\n            and anchor.device == torch.device(\"cpu\")\n        ), \"anchor dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            anchor.dim(),\n            anchor.size(),\n            anchor.dtype,\n            anchor.is_contiguous(),\n            anchor.device,\n        )\n        assert (\n            layer_id >= 0 and layer_id < self.config.layer_num\n        ), \"layer_id: {}\".format(layer_id)\n        assert block_idx >= 0, \"block_idx: {}\".format(block_idx)\n        return self.kvcache.update_anchor_one_block(\n            anchor.data_ptr(),\n            layer_id,\n            block_idx,\n        )\n\n    def calc_anchor_all_layers(\n        self,\n        block_table: torch.Tensor,\n        cache_seqlens: torch.Tensor,\n    ):\n        assert (\n            block_table.dim() == 2\n            and block_table.size(0) == cache_seqlens.size(0)\n            and block_table.dtype == torch.int\n            and block_table.is_contiguous()\n            and block_table.device == torch.device(\"cpu\")\n        ), \"block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            block_table.dim(),\n            block_table.size(),\n            block_table.dtype,\n            block_table.is_contiguous(),\n            block_table.device,\n        )\n        
assert (\n            cache_seqlens.dim() == 1\n            and cache_seqlens.dtype == torch.int\n            and cache_seqlens.is_contiguous()\n            and cache_seqlens.device == torch.device(\"cpu\")\n        ), \"cache_seqlens dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            cache_seqlens.dim(),\n            cache_seqlens.size(),\n            cache_seqlens.dtype,\n            cache_seqlens.is_contiguous(),\n            cache_seqlens.device,\n        )\n        batch_size = block_table.size(0)\n        max_block_num = block_table.size(1)\n        return self.kvcache.calc_anchor_all_layers(\n            block_table.data_ptr(),\n            cache_seqlens.data_ptr(),\n            batch_size,\n            max_block_num,\n        )\n\n    def clear_importance_all_layers(\n        self,\n        block_table: torch.Tensor,\n        cache_seqlens: torch.Tensor,\n    ):\n        assert (\n            block_table.dim() == 2\n            and block_table.size(0) == cache_seqlens.size(0)\n            and block_table.dtype == torch.int\n            and block_table.is_contiguous()\n            and block_table.device == torch.device(\"cpu\")\n        ), \"block_table dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            block_table.dim(),\n            block_table.size(),\n            block_table.dtype,\n            block_table.is_contiguous(),\n            block_table.device,\n        )\n        assert (\n            cache_seqlens.dim() == 1\n            and cache_seqlens.dtype == torch.int\n            and cache_seqlens.is_contiguous()\n            and cache_seqlens.device == torch.device(\"cpu\")\n        ), \"cache_seqlens dim: {}, size: {}, dtype: {}, contiguous: {}, device: {}\".format(\n            cache_seqlens.dim(),\n            cache_seqlens.size(),\n            cache_seqlens.dtype,\n            cache_seqlens.is_contiguous(),\n            cache_seqlens.device,\n        )\n        batch_size = 
block_table.size(0)\n        max_block_num = block_table.size(1)\n        return self.kvcache.clear_importance_all_layers(\n            block_table.data_ptr(),\n            cache_seqlens.data_ptr(),\n            batch_size,\n            max_block_num,\n        )\n\n    def get_cache_total_len(self):\n        return self.kvcache.get_cache_total_len()\n\n    def update_kvcache_q4(\n        self,\n        k_in: torch.Tensor,\n        k_scales: torch.Tensor,\n        v_in: torch.Tensor,\n        v_scales: torch.Tensor,\n        layer_id: int,\n        seq_offset: int | None = None,\n        seq_len: int | None = None,\n        block_table: torch.Tensor | None = None,\n    ):\n        raise NotImplementedError\n\n    def update_kvcache_fp16(\n        self,\n        k_in: torch.Tensor,\n        v_in: torch.Tensor,\n        layer_idx,\n        block_table: torch.Tensor,\n        max_block_num,\n        past_len: torch.Tensor,\n        q_len,\n    ):\n        batch_size = block_table.size(0)\n        return self.kvcache.update_kvcache_fp16(\n            k_in.data_ptr(),\n            v_in.data_ptr(),\n            layer_idx,\n            block_table.data_ptr(),\n            batch_size,\n            max_block_num,\n            past_len.data_ptr(),\n            q_len,\n        )\n\n    def get_kvcache_q4(\n        self,\n        k_in: torch.Tensor,\n        k_scales: torch.Tensor,\n        v_in: torch.Tensor,\n        v_scales: torch.Tensor,\n        layer_id: int,\n        seq_offset: int | None = None,\n        seq_len: int | None = None,\n        block_table: torch.Tensor | None = None,\n    ):\n        raise NotImplementedError\n\n    def get_kvcache_fp16(\n        self,\n        k_in: torch.Tensor,\n        v_in: torch.Tensor,\n        layer_id: int,\n        layer_idx,\n        block_table: torch.Tensor,\n        max_block_num,\n        past_len: torch.Tensor,\n    ):\n        batch_size = block_table.size(0)\n        return self.kvcache.get_kvcache_fp16(\n            
k_in.data_ptr(),\n            v_in.data_ptr(),\n            layer_idx,\n            block_table.data_ptr(),\n            batch_size,\n            max_block_num,\n            past_len.data_ptr(),\n        )\n\n    def get_and_update_kvcache_fp16(\n        self,\n        k_cache_cpu: torch.Tensor,\n        v_cache_cpu: torch.Tensor,\n        layer_idx,\n        block_table: torch.Tensor,\n        max_block_num,\n        past_len: torch.Tensor,\n        q_len,\n    ):\n        batch_size = block_table.size(0)\n        return self.kvcache.get_and_update_kvcache_fp16(\n            k_cache_cpu.data_ptr(),\n            v_cache_cpu.data_ptr(),\n            layer_idx,\n            block_table.data_ptr(),\n            batch_size,\n            max_block_num,\n            past_len.data_ptr(),\n            q_len,\n        )\n\n    def update_importance(\n        self,\n        importance_cache: torch.Tensor,\n        layer_idx,\n        block_table: torch.Tensor,\n        max_block_num,\n        offset: torch.Tensor,\n        width,\n    ):\n        batch_size = block_table.size(0)\n        return self.kvcache.update_importance(\n            importance_cache.data_ptr(),\n            layer_idx,\n            block_table.data_ptr(),\n            batch_size,\n            max_block_num,\n            offset.data_ptr(),\n            width,\n        )\n\n    # attn_sparsity: ((bsz, q_len, q_head_num), dtype = torch.float32)\n    def get_attn_sparsity(\n        self,\n        q_in: torch.Tensor,\n        attn_sparsity: torch.Tensor,\n        layer_idx: int,\n        block_table: torch.Tensor,\n        cache_seqlens: torch.Tensor,\n        block_table_origin: torch.Tensor,\n        cache_seqlens_origin: torch.Tensor,\n        generate_token_idx: int = 0,\n        topk: int | None = None,\n        local: int | None = None,\n    ):\n        batch_size = block_table.size(0)\n        max_block_num = block_table.size(1)\n        max_block_num_origin = block_table_origin.size(1)\n        q_len 
= q_in.size(1)\n\n        if topk is None or local is None or topk + local >= max_block_num:\n            topk = -1\n            local = -1\n        return self.kvcache.get_attn_sparsity(\n            q_in.data_ptr(),\n            attn_sparsity.data_ptr(),\n            layer_idx,\n            generate_token_idx,\n            q_len,\n            batch_size,\n            max_block_num,\n            block_table.data_ptr(),\n            cache_seqlens.data_ptr(),\n            block_table_origin.data_ptr(),\n            cache_seqlens_origin.data_ptr(),\n            max_block_num_origin,\n            topk,\n            local,\n        )\n\n    def attn_with_kvcache(\n        self,\n        q_in: torch.Tensor,\n        k_in: torch.Tensor,\n        v_in: torch.Tensor,\n        output: torch.Tensor,\n        attn_lse: torch.Tensor,\n        layer_idx: int,\n        block_table: torch.Tensor,\n        cache_seqlens: torch.Tensor,\n        generate_token_idx: int = 0,\n        topk: int | None = None,\n        local: int | None = None,\n    ):\n\n        batch_size = block_table.size(0)\n        max_block_num = block_table.size(1)\n        q_len = q_in.size(1)\n\n        if topk is None or local is None or topk + local >= max_block_num:\n            topk = -1\n            local = -1\n        return self.kvcache.attn_with_kvcache(\n            q_in.data_ptr(),\n            k_in.data_ptr(),\n            v_in.data_ptr(),\n            output.data_ptr(),\n            attn_lse.data_ptr(),\n            layer_idx,\n            generate_token_idx,\n            q_len,\n            batch_size,\n            max_block_num,\n            block_table.data_ptr(),\n            cache_seqlens.data_ptr(),\n            topk,\n            local,\n        )\n\n    def get_all_kvcache_one_layer(\n        self, k_in: torch.Tensor, v_in: torch.Tensor, layer_id: int\n    ):\n        return self.kvcache.get_all_kvcache_one_layer(\n            k_in.data_ptr(),\n            v_in.data_ptr(),\n            
layer_id,\n        )\n\n    def get_importance(\n        self,\n        importance: torch.Tensor,\n        block_table: torch.Tensor,\n    ):\n        raise NotImplementedError\n\n    def get_anchor(\n        self,\n        anchor: torch.Tensor,\n        block_table: torch.Tensor,\n    ):\n        raise NotImplementedError\n\n\nclass CPUInfer:\n    cpuinfer = None\n    cur_backend_thread_num = 0\n    \n    def __init__(self, thread_num):\n        if thread_num > CPUInfer.cur_backend_thread_num:\n            CPUInfer.cur_backend_thread_num = thread_num\n            del CPUInfer.cpuinfer\n            CPUInfer.cpuinfer = cpuinfer_ext.CPUInfer(thread_num)\n\n    def submit(self, task):\n        CPUInfer.cpuinfer.submit(task)\n\n    def submit_with_cuda_stream(self, current_cuda_stream, task):\n        CPUInfer.cpuinfer.submit_with_cuda_stream(current_cuda_stream, task)\n\n    def sync(self):\n        CPUInfer.cpuinfer.sync()\n\n    def sync_with_cuda_stream(self, current_cuda_stream):\n        CPUInfer.cpuinfer.sync_with_cuda_stream(current_cuda_stream)\n\n\n        \n"
  },
  {
    "path": "kt-sft/ktransformers/operators/dynamic_attention.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nDescription  :  \nAuthor       : Jianwei Dong\nDate         : 2024-08-26 23:25:24\nVersion      : 1.0.0\nLastEditors  : Jianwei Dong\nLastEditTime : 2024-08-26 23:25:24\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n\"\"\"\n\nimport torch\nfrom transformers import AutoConfig\nimport sys, os\nimport logging\nlogger = logging.getLogger(\"dynamic_attention\")\nsys.path.append(os.path.dirname(__file__) + \"/../ktransformers_ext/cpu_backend\")\nfrom ktransformers.operators.cpuinfer import CPUInfer, CPUInferKVCache\ntry:\n    from flash_attn import flash_attn_func, flash_attn_with_kvcache\nexcept:\n    print(\"falsh attn not found\")\n\n\nimport math\nimport json\n\n\nclass DynamicScaledDotProductAttention:\n    remaining_length: int\n    cpu_infer = None\n\n    def __init__(\n        self,\n        max_seq_len: int,\n        block_size: int,\n        config: AutoConfig,\n        device: torch.device,\n        local_windows_len: int,\n        topk: int,\n        threads_num: int,\n        anchor_type: str = \"DYNAMIC\",\n        kv_type: str = \"FP16\",\n        dense_layer_num: int = 0,\n        anchor_num: int = 1,\n        block_selection_mode: str = \"SHARED\",\n        layer_step: int = 1,\n        token_step: int = 1,\n        preselect_block: bool = False,\n        preselect_block_count: int = 96,\n        prefill_chunk_size: int = 20480,\n        use_attn_sparsity: bool = False,\n    ):\n        # assert anchor_num == 1\n        # assert anchor_type == \"DYNAMIC\"\n        self.remaining_length = 0\n        valid_anchor_types = [\"DYNAMIC\", \"FIXED\", \"BLOCK_MEAN\", \"BLOCK_MAX\", \"QUEST\"]\n        assert anchor_type in valid_anchor_types\n        if anchor_type == \"QUEST\":\n            assert anchor_num == 2\n        elif anchor_type != \"FIXED\" and anchor_type != \"DYNAMIC\":\n            assert anchor_num == 1\n\n        valid_kv_types = [\"FP16\", \"FP32\", \"Q4_0\", \"Q8_0\"]\n  
      assert kv_type in valid_kv_types\n        if kv_type != \"FP16\" and kv_type != \"FP32\":\n            assert block_size % 32 == 0\n\n        valid_block_selection_modes = [\"SHARED\", \"SEPARATE\"]  # individual\n        assert block_selection_mode in valid_block_selection_modes\n\n        self.max_seq_len = max_seq_len\n        self.block_num = max_seq_len // block_size\n        self.block_size = block_size\n        self.anchor_type = anchor_type\n        self.kv_type = kv_type\n        self.anchor_num = anchor_num\n        self.threads_num = threads_num\n        self.layer_step = layer_step\n        self.token_step = token_step\n        self.preselect_block = preselect_block\n        self.preselect_block_count = preselect_block_count\n        self.block_selection_mode = block_selection_mode\n        self.use_attn_sparsity = use_attn_sparsity\n\n        # model config\n        self.kv_head_num = config.num_key_value_heads\n        self.q_head_num = config.num_attention_heads\n        self.head_dim = config.hidden_size // config.num_attention_heads\n        self.layer_num = config.num_hidden_layers\n\n        self.device = device\n        self.local_windows_len = local_windows_len\n        self.local_block_num = self.local_windows_len // self.block_size + 1\n        self.prefill_chunk_size = prefill_chunk_size\n\n        self.topk = topk\n        self.dense_layer_num = dense_layer_num\n        # self.dense_layer_num = 32\n        self.cache_key_states = torch.zeros(\n            (self.block_num, block_size, self.kv_head_num, self.head_dim),\n            device=device,\n            dtype=torch.float16,\n        )\n        self.cache_value_states = torch.zeros(\n            (self.block_num, block_size, self.kv_head_num, self.head_dim),\n            device=device,\n            dtype=torch.float16,\n        )\n        # [max_num_block, block_size, head_num]\n        self.cache_importance = torch.zeros(\n            (self.block_num, block_size, 
self.q_head_num),\n            device=device,\n            dtype=torch.float16,\n        )\n\n        # key_states: [bsz, q_len, kv_head_num, head_dim]\n        # value_states: [bsz, q_len, kv_head_num, head_dim]\n        # query_states: [bsz, q_len, q_head_num, head_dim]\n        self.q_in_cpu = torch.zeros(\n            (1, 1, self.q_head_num, self.head_dim),\n            device=\"cpu\",\n            dtype=torch.float16,\n            pin_memory=True,\n        )\n        self.k_in_cpu = torch.zeros(\n            (1, 1, self.kv_head_num, self.head_dim),\n            device=\"cpu\",\n            dtype=torch.float16,\n            pin_memory=True,\n        )\n        self.v_in_cpu = torch.zeros(\n            (1, 1, self.kv_head_num, self.head_dim),\n            device=\"cpu\",\n            dtype=torch.float16,\n            pin_memory=True,\n        )\n\n        self.cache_seqlens_cpu = torch.empty(\n            (1,), device=\"cpu\", dtype=torch.int32, pin_memory=True\n        )\n\n        self.cache_seqlens_cuda = torch.empty((1,), device=device, dtype=torch.int32)\n\n        self.prefix_block_table = torch.arange(\n            self.block_num, device=\"cpu\", dtype=torch.int32, pin_memory=True\n        ).view(1, -1)\n\n        self.block_table_cpu = torch.arange(\n            self.block_num, device=\"cpu\", dtype=torch.int32, pin_memory=True\n        ).view(1, -1)\n\n        # assert (\n        #     self.local_windows_len // self.block_size + 1 + self.preselect_block_count\n        #     <= self.block_num\n        # )\n\n        self.output_cpu = torch.empty(\n            (1, 1, self.q_head_num, self.head_dim),\n            device=\"cpu\",\n            dtype=torch.float16,\n            pin_memory=True,\n        )\n        self.lse_cpu = torch.empty(\n            (1, 1, self.q_head_num), device=\"cpu\", dtype=torch.float32, pin_memory=True\n        )\n\n        self.output_cuda = torch.empty(\n            (1, 1, self.q_head_num, self.head_dim), device=device, 
dtype=torch.float16\n        )\n\n        self.attn_sparsity = torch.zeros(\n            (1, 1, self.q_head_num), device=\"cpu\", dtype=torch.float32, pin_memory=True\n        )\n\n        if preselect_block == True:\n            self.preselect_block_table = torch.zeros(\n                self.layer_num,\n                self.preselect_block_count,\n                device=device,\n                dtype=torch.int32,\n            )\n            self.preselect_block_num = 0  # block_num before preselect\n            self.evict_tokens = 0\n\n        if DynamicScaledDotProductAttention.cpu_infer is None:\n            DynamicScaledDotProductAttention.cpu_infer = CPUInfer(threads_num)\n            self.cpu_infer = DynamicScaledDotProductAttention.cpu_infer\n        self.local_thread = CPUInferKVCache(\n            self.layer_num,\n            self.kv_head_num,\n            self.q_head_num,\n            self.head_dim,\n            self.block_size,\n            anchor_num=self.anchor_num,\n            anchor_type=anchor_type,\n            kv_type=self.kv_type,\n            retrieval_type=self.block_selection_mode,\n            layer_step=self.layer_step,\n            token_step=self.token_step,\n            layer_offset=self.dense_layer_num % self.layer_step,\n            max_batch_size=1,\n            max_block_num=self.block_num,\n            max_thread_num=self.threads_num,\n        )\n\n        print(\n            f\"local_windows_len: {local_windows_len}, topk: {topk}, dense_layer_num: {dense_layer_num}, kv_type: {self.kv_type}, anchor_type: {self.anchor_type}, preselect_block: {self.preselect_block}, preselect_block_count: {self.preselect_block_count}, token_step: {self.token_step}, layer_step: {self.layer_step}\"\n        )\n\n        self.shape_mask = (\n            self.q_head_num,\n            self.block_size,\n            self.block_size,\n        )\n\n        mask = torch.zeros(\n            self.shape_mask, dtype=torch.uint8, device=device\n        
).contiguous()\n        elm_idx = torch.arange(self.block_size, device=device)\n\n        for i in range(mask.size(-2)):\n            idx = i + mask.size(-1) - mask.size(-2) - elm_idx\n            idx = idx[idx >= 0]\n            mask[..., i, idx] = 1\n\n        self.tril_mask = mask\n        self.triu_mask = mask ^ 1\n\n        self.generate_token_idx = 0\n\n    def get_attn_score_one_block(\n        self,\n        batch_idx: int,\n        max_block_num: int,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        offset: int,\n        width: int,\n        mask_mode: str | None = None,\n        use_softmax: bool = True,\n    ):\n        n_rep = self.q_head_num // self.kv_head_num\n        importance = self.cache_importance.view(-1, self.q_head_num)\n        importance = importance.narrow(0, batch_idx * max_block_num + offset, width)\n        n_gqa_ = self.q_head_num // self.kv_head_num \n        for head_idx in range(self.q_head_num):\n            key_item = key[..., head_idx // n_gqa_, :].view(key.size(0), -1)\n            qk = torch.einsum(\n                \"qd,kd->qk\", query[:,head_idx,:], key_item\n            )  # (num_attention_heads, len_q, len_k)\n\n            if mask_mode == \"tril\":\n                mask = self.tril_mask\n                mask = mask[0, -qk.size(-2) :, -qk.size(-1) :]\n                qk = qk * mask\n            elif mask_mode == \"triu\":\n                mask = self.triu_mask\n                mask = mask[0, -qk.size(-2) :, -qk.size(-1) :]\n                qk = qk * mask\n\n            if use_softmax:\n                qk = torch.nn.functional.softmax(\n                    qk / math.sqrt(self.head_dim), dim=-1, dtype=torch.float32\n                ).to(torch.float16)\n              \n            qk = torch.sum(qk, dim=-2)\n            importance[...,head_idx] += qk\n\n    def get_preselect_block_table_and_attn_score(\n        self,\n        layer_idx: int,\n        batch_size: int,\n        offset: torch.Tensor,\n        
width: int,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        union_with_last_layer: bool = True,\n    ):\n        max_seqs_len = offset.max().item() + width\n        max_block_num = (max_seqs_len + self.block_size - 1) // self.block_size\n\n        for batch_idx in range(batch_size):\n            query_cur = query[batch_idx][-128:]\n            self.get_attn_score_one_block(\n                batch_idx,\n                max_block_num,\n                query_cur,\n                key[batch_idx][: offset[batch_idx].item() + width],\n                0,\n                offset[batch_idx].item() + width,\n                mask_mode=None,\n            )\n\n        if self.preselect_block:\n            self.prefill_block_num = max(\n                0, max_block_num - self.local_windows_len // self.block_size\n            )\n            self.evict_tokens = (\n                max(self.prefill_block_num - self.preselect_block_count, 0)\n                * self.block_size\n            )\n\n            if self.prefill_block_num != 0:\n                importance_cache = self.cache_importance.narrow(\n                    0, 0, self.prefill_block_num * batch_size\n                ).view(\n                    batch_size, self.prefill_block_num, self.block_size, self.q_head_num\n                )\n\n                importance_r = importance_cache[:, 1:, : self.block_size // 4]\n                pad_r = torch.zeros_like(importance_r[:, :1])\n                importance_r = torch.cat((importance_r, pad_r), dim=1)\n                importance_l = importance_cache[:, :-1, -self.block_size // 4 :]\n                pad_l = torch.zeros_like(importance_l[:, :1])\n                importance_l = torch.cat((pad_l, importance_l), dim=1)\n                importance = torch.cat(\n                    (importance_l, importance_cache, importance_r), dim=2\n                )\n                importance = importance.mean(dim=-1)\n                importance = importance.mean(dim=-1)\n      
          # importance: (batch_size, max_block_num)\n                topk = min(self.preselect_block_count, self.prefill_block_num)\n                values, indices = torch.topk(\n                    importance,\n                    k=topk,\n                    dim=1,\n                )\n\n                self.preselect_block_table[\n                    layer_idx : layer_idx + 1,\n                    :topk,\n                ].copy_(indices)\n\n                if union_with_last_layer and layer_idx == 31:\n                    for tmp_layer_idx in range(self.layer_num - 1):\n                        for i in range(1, min(topk, 6)):\n                            x = self.preselect_block_table[-1, i]\n                            if x not in self.preselect_block_table[tmp_layer_idx]:\n                                self.preselect_block_table[tmp_layer_idx, topk - i] = x\n        if self.anchor_type == \"DYNAMIC\":\n            importance_cache = self.cache_importance.narrow(\n                0, 0, max_block_num * batch_size\n            ).view(batch_size, max_block_num * self.block_size, self.q_head_num)\n            importance_cache_cpu = torch.empty_like(\n                importance_cache, device=\"cpu\", pin_memory=True\n            )\n\n            importance_cache_cpu.copy_(importance_cache)\n\n            block_table_cpu = self.prefix_block_table[:, :max_block_num].to(\"cpu\")\n            offset_cpu = offset.contiguous().to(\"cpu\")\n\n            self.cpu_infer.submit(\n                self.local_thread.update_importance(\n                    importance_cache_cpu,\n                    layer_idx,\n                    block_table_cpu,\n                    max_block_num,\n                    offset_cpu,\n                    width,\n                )\n            )\n            self.cpu_infer.sync()\n\n        importance_cache = self.cache_importance.narrow(\n            0, 0, max_block_num * batch_size\n        ).view(batch_size, max_block_num * self.block_size, 
self.q_head_num)\n        importance_cache.zero_()\n\n    # key: [bsz, past_len, head_num, head_dim] float16\n    # query: [bsz, q_len, q_head_num, head_dim] float16\n    def get_attn_score(\n        self,\n        layer_idx: int,\n        batch_size: int,\n        offset: torch.Tensor,\n        width: int,\n        query: torch.Tensor,\n        key: torch.Tensor,\n    ):\n        max_seqs_len = offset.max().item() + width\n        max_block_num = (max_seqs_len + self.block_size - 1) // self.block_size\n\n        for batch_idx in range(batch_size):\n            for idx in range(width // self.block_size):\n                offset_cur = idx * self.block_size\n                query_cur = query[batch_idx, offset_cur : offset_cur + self.block_size]\n                self.get_attn_score_one_block(\n                    batch_idx,\n                    max_block_num,\n                    query_cur,\n                    key[\n                        batch_idx,\n                        offset[batch_idx]\n                        + offset_cur : offset[batch_idx]\n                        + offset_cur\n                        + self.block_size,\n                    ],\n                    offset[batch_idx].item() + offset_cur,\n                    self.block_size,\n                    mask_mode=\"tril\",\n                    use_softmax=False,\n                )\n\n                offset_key = (\n                    offset[batch_idx].item()\n                    + idx * self.block_size\n                    - self.local_windows_len\n                )\n                if offset_key >= 0:\n                    self.get_attn_score_one_block(\n                        batch_idx,\n                        max_block_num,\n                        query_cur,\n                        key[batch_idx, offset_key : offset_key + self.block_size],\n                        offset_key,\n                        self.block_size,\n                        mask_mode=\"triu\",\n                        
use_softmax=False,\n                    )\n\n                offset_key = max(0, offset_key + self.block_size)\n                width_key = (\n                    offset[batch_idx].item() + idx * self.block_size - offset_key\n                )\n                if width_key > 0:\n                    self.get_attn_score_one_block(\n                        batch_idx,\n                        max_block_num,\n                        query_cur,\n                        key[batch_idx, offset_key : offset_key + width_key],\n                        offset_key,\n                        width_key,\n                        mask_mode=None,\n                        use_softmax=False,\n                    )\n\n        importance_cache = self.cache_importance.narrow(\n            0, 0, max_block_num * batch_size\n        ).view(batch_size, max_block_num * self.block_size, self.q_head_num)\n        importance_cache_cpu = torch.empty_like(\n            importance_cache, device=\"cpu\", pin_memory=True\n        )\n\n        importance_cache_cpu.copy_(importance_cache)\n\n        block_table_cpu = self.prefix_block_table[:, :max_block_num].to(\"cpu\")\n        offset_cpu = offset.contiguous().to(\"cpu\")\n\n        self.cpu_infer.submit(\n            self.local_thread.update_importance(\n                importance_cache_cpu,\n                layer_idx,\n                block_table_cpu,\n                max_block_num,\n                offset_cpu,\n                width,\n            )\n        )\n        self.cpu_infer.sync()\n        importance_cache.zero_()\n\n    # key: [bsz, q_len, head_num, head_dim] float16\n    # value: [bsz, q_len, head_num, head_dim] float16\n    def swap_in_and_swap_out(self, layer_idx, past_len, q_len, key, value):\n        batch_size = 1\n        max_seqs_len = past_len.max().item() + q_len\n        max_block_num = (max_seqs_len + self.block_size - 1) // self.block_size\n        k_cache = self.cache_key_states.narrow(0, 0, max_block_num * 
batch_size).view(\n            batch_size, max_block_num * self.block_size, self.kv_head_num, self.head_dim\n        )\n        v_cache = self.cache_value_states.narrow(0, 0, max_block_num * batch_size).view(\n            batch_size, max_block_num * self.block_size, self.kv_head_num, self.head_dim\n        )\n\n        for batch_idx in range(batch_size):\n            offset = past_len[batch_idx]\n            width = q_len\n            k_cache[batch_idx][offset : offset + width].copy_(\n                key[batch_idx].view(-1, self.kv_head_num, self.head_dim)\n            )\n            v_cache[batch_idx][offset : offset + width].copy_(\n                value[batch_idx].view(-1, self.kv_head_num, self.head_dim)\n            )\n\n        k_cache_cpu = torch.empty_like(k_cache, device=\"cpu\", pin_memory=True)\n        v_cache_cpu = torch.empty_like(v_cache, device=\"cpu\", pin_memory=True)\n\n        k_cache_cpu.copy_(k_cache)\n        v_cache_cpu.copy_(v_cache)\n\n        cur_block_num = (\n            q_len + past_len[0].item() + self.block_size - 1\n        ) // self.block_size\n        block_table_cpu = self.prefix_block_table[:, :cur_block_num].to(\"cpu\")\n        past_len_cpu = past_len.contiguous().to(\"cpu\")\n\n        self.cpu_infer.submit(\n            self.local_thread.get_and_update_kvcache_fp16(\n                k_cache_cpu,\n                v_cache_cpu,\n                layer_idx,\n                block_table_cpu,\n                max_block_num,\n                past_len_cpu,\n                q_len,\n            )\n        )\n\n        self.cpu_infer.sync()\n        k_cache.copy_(k_cache_cpu)\n        v_cache.copy_(v_cache_cpu)\n\n        return k_cache, v_cache\n\n    def calc_anchor(self, cache_seqlens: int):\n        cur_block_num = (cache_seqlens + self.block_size - 1) // self.block_size\n        block_table_cpu = self.prefix_block_table[:, :cur_block_num].to(\"cpu\")\n        cache_seqlens_cpu = torch.tensor(\n            [cache_seqlens], 
device=\"cpu\", dtype=torch.int32\n        )\n\n        self.cpu_infer.submit(\n            self.local_thread.calc_anchor_all_layers(\n                block_table_cpu,\n                cache_seqlens_cpu,\n            )\n        )\n        self.cpu_infer.sync()\n\n    def clear_importance(self, cache_seqlens: int):\n        print(f\"clear importance: {cache_seqlens}\")\n        cur_block_num = (cache_seqlens + self.block_size - 1) // self.block_size\n        block_table_cpu = self.prefix_block_table[:, :cur_block_num].to(\"cpu\")\n        cache_seqlens_cpu = torch.tensor(\n            [cache_seqlens], device=\"cpu\", dtype=torch.int32\n        )\n\n        self.cpu_infer.submit(\n            self.local_thread.clear_importance_all_layers(\n                block_table_cpu,\n                cache_seqlens_cpu,\n            )\n        )\n        self.cpu_infer.sync()\n\n    def clear_kvcache(self, cache_seqlens: int):\n        cur_block_num = (cache_seqlens + self.block_size - 1) // self.block_size\n        block_table_cpu = self.prefix_block_table[:, :cur_block_num].to(\"cpu\")\n        cache_seqlens_cpu = torch.tensor(\n            [cache_seqlens], device=\"cpu\", dtype=torch.int32\n        )\n\n        self.cpu_infer.submit(\n            self.local_thread.clear_kvcache_all_layers(\n                block_table_cpu,\n                cache_seqlens_cpu,\n            )\n        )\n        self.cpu_infer.sync()\n\n    def get_attn_sparsity(\n        self,\n        q_in: torch.Tensor,\n        layer_idx: int,\n        block_table: torch.Tensor,\n        cache_seqlens: torch.Tensor,\n        block_table_origin: torch.Tensor,\n        cache_seqlens_origin: torch.Tensor,\n        generate_token_idx: int = 0,\n        topk: int | None = None,\n        local: int | None = None,\n        output_path: str = \"./attn_sparsity.json\",\n    ):\n        self.attn_sparsity.zero_()\n        # Submit through the same CPUInfer handle that is synced below.\n        self.cpu_infer.submit(\n            self.local_thread.get_attn_sparsity(\n                q_in,\n    
            self.attn_sparsity,\n                layer_idx,\n                block_table,\n                cache_seqlens,\n                block_table_origin,\n                cache_seqlens_origin,\n                generate_token_idx,\n                topk,\n                local,\n            )\n        )\n        self.cpu_infer.sync()\n        with open(output_path, \"a\") as file:\n            for head_idx in range(self.q_head_num):\n                sparsity = self.attn_sparsity[0][0][head_idx].item()\n                json_obj = {\n                    \"token_idx\": generate_token_idx,\n                    \"layer_idx\": layer_idx,\n                    \"head_idx\": head_idx,\n                    \"sparsity\": sparsity,\n                }\n                json.dump(json_obj, file)\n                file.write(\"\\n\")\n\n    def apply(\n        self,\n        layer_idx: int,\n        bsz: int,\n        past_len: int,\n        query_states: torch.Tensor,\n        key_states: torch.Tensor,\n        value_states: torch.Tensor,\n        mode: str = \"prefill\",\n        generate_token_idx: int = -1,\n    ):\n\n        # key_states: [bsz, q_len, kv_head_num, head_dim]\n        # value_states: [bsz, q_len, kv_head_num, head_dim]\n        # query_states: [bsz, q_len, q_head_num, head_dim]\n        assert query_states.dtype == torch.float16\n        assert key_states.dtype == torch.float16\n        assert value_states.dtype == torch.float16\n\n        assert key_states.size(2) == self.kv_head_num\n        assert value_states.size(2) == self.kv_head_num\n        assert query_states.size(2) == self.q_head_num\n\n        q_len = query_states.size(1)\n        batch_size = query_states.size(0)\n        self.cache_seqlens_cuda.fill_(past_len)\n        last_chunk = False\n        if self.remaining_length <= self.prefill_chunk_size and q_len != 1:\n            last_chunk = True\n        device = query_states.device\n        if layer_idx == 0:\n            if q_len == 1:\n        
        self.generate_token_idx += 1\n            elif last_chunk:\n                self.generate_token_idx = -1\n\n        if mode == \"prefill\":\n            key, value = self.swap_in_and_swap_out(\n                layer_idx,\n                self.cache_seqlens_cuda,\n                q_len,\n                key_states,\n                value_states,\n            )\n\n            if last_chunk and (self.anchor_type == \"DYNAMIC\" or self.preselect_block):\n                self.get_preselect_block_table_and_attn_score(\n                    layer_idx,\n                    bsz,\n                    self.cache_seqlens_cuda,\n                    q_len,\n                    query_states,\n                    key,\n                )\n            output = flash_attn_with_kvcache(\n                q=query_states,\n                k_cache=key,\n                v_cache=value,\n                cache_seqlens=self.cache_seqlens_cuda + q_len,\n                causal=True,\n            )\n            return output.transpose(1, 2)\n\n        elif mode == \"generate\":\n            assert self.generate_token_idx >= 0\n            self.q_in_cpu.copy_(query_states, non_blocking=True)\n            self.k_in_cpu.copy_(key_states, non_blocking=True)\n            self.v_in_cpu.copy_(value_states, non_blocking=True)\n            self.cache_seqlens_cpu.copy_(self.cache_seqlens_cuda, non_blocking=True)\n            #            print(layer_idx)\n            if layer_idx < self.dense_layer_num:\n                self.block_table_cpu.copy_(self.prefix_block_table, non_blocking=True)\n                self.cpu_infer.submit_with_cuda_stream(\n                    torch.cuda.current_stream(\"cuda\").cuda_stream,\n                    self.local_thread.attn_with_kvcache(\n                        q_in=self.q_in_cpu,\n                        k_in=self.k_in_cpu,\n                        v_in=self.v_in_cpu,\n                        output=self.output_cpu,\n                        
attn_lse=self.lse_cpu,\n                        layer_idx=layer_idx,\n                        block_table=self.block_table_cpu,\n                        cache_seqlens=self.cache_seqlens_cpu,\n                    ),\n                )\n            else:\n                if self.preselect_block:\n                    self.cache_seqlens_cpu.copy_(\n                        self.cache_seqlens_cuda - self.evict_tokens, non_blocking=True\n                    )\n                    if self.preselect_block_count < self.prefill_block_num:\n                        self.block_table_cpu[:, : self.preselect_block_count].copy_(\n                            self.preselect_block_table[layer_idx : layer_idx + 1],\n                            non_blocking=True,\n                        )\n\n                        self.block_table_cpu[\n                            :,\n                            self.preselect_block_count : self.preselect_block_count\n                            + self.local_block_num,\n                        ].copy_(\n                            self.prefix_block_table[\n                                :,\n                                self.prefill_block_num : self.prefill_block_num\n                                + self.local_block_num,\n                            ],\n                            non_blocking=True,\n                        )\n                    #                   print(\"submit_with_cuda_stream\")\n                    self.cpu_infer.submit_with_cuda_stream(\n                        torch.cuda.current_stream(\"cuda\").cuda_stream,\n                        self.local_thread.attn_with_kvcache(\n                            q_in=self.q_in_cpu,\n                            k_in=self.k_in_cpu,\n                            v_in=self.v_in_cpu,\n                            output=self.output_cpu,\n                            attn_lse=self.lse_cpu,\n                            layer_idx=layer_idx,\n                            
generate_token_idx=self.generate_token_idx,\n                            block_table=self.block_table_cpu,\n                            cache_seqlens=self.cache_seqlens_cpu,\n                            topk=(\n                                self.topk\n                                if self.topk <= self.preselect_block_count\n                                else None\n                            ),\n                            local=self.local_windows_len // self.block_size,\n                        ),\n                    )\n                #                    print(\"submit_with_cuda_stream enqueue\\n\")\n                else:\n                    self.block_table_cpu.copy_(\n                        self.prefix_block_table, non_blocking=True\n                    )\n                    self.cpu_infer.submit_with_cuda_stream(\n                        torch.cuda.current_stream(\"cuda\").cuda_stream,\n                        self.local_thread.attn_with_kvcache(\n                            q_in=self.q_in_cpu,\n                            k_in=self.k_in_cpu,\n                            v_in=self.v_in_cpu,\n                            output=self.output_cpu,\n                            attn_lse=self.lse_cpu,\n                            layer_idx=layer_idx,\n                            generate_token_idx=self.generate_token_idx,\n                            block_table=self.block_table_cpu,\n                            cache_seqlens=self.cache_seqlens_cpu,\n                            topk=self.topk,\n                            local=self.local_windows_len // self.block_size,\n                        ),\n                    )\n            self.cpu_infer.sync_with_cuda_stream(\n                torch.cuda.current_stream(\"cuda\").cuda_stream\n            )\n            #            print(\"submit_with_cuda_stream finished\\n\")\n            self.output_cuda.copy_(self.output_cpu, non_blocking=True)\n            return self.output_cuda.transpose(1, 2)\n\n    def 
save(self, path: str, length: int):\n        cur_block_num = (length + self.block_size - 1) // self.block_size\n        block_table_cpu = self.prefix_block_table[0, :cur_block_num].to(\"cpu\")\n        cache_seqlens_cpu = torch.tensor([length], device=\"cpu\", dtype=torch.int32)\n        self.cpu_infer.submit(\n            self.local_thread.dump_kvcache(\n                block_table_cpu,\n                cache_seqlens_cpu,\n                path,\n            )\n        )\n        self.cpu_infer.sync()\n\n    def load(self, path: str, length: int):\n        self.cpu_infer.submit(\n            self.local_thread.load_kvcache(\n                path,\n            )\n        )\n        self.cpu_infer.sync()\n"
  },
  {
    "path": "kt-sft/ktransformers/operators/experts.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : Azure-Tang, Boxin Zhang, chenht2022\nDate         : 2024-07-25 11:25:24\nVersion      : 0.1.0\nLastEditors  : Azure \nLastEditTime : 2024-08-29 09:41:10\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\n\nfrom typing import Any, Union\nimport numpy as np\nimport numpy.typing as npt\nfrom torch import Tensor, nn\nimport torch.nn.functional as F\nimport torch\nimport sys, os\nfrom ktransformers.operators.base_operator import BaseInjectedModule\nfrom tqdm import tqdm\nimport time\nimport logging\nfrom tqdm.auto import tqdm\nimport re\n\nsys.path.append(os.path.join(os.path.dirname(__file__), \"..\", \"ktransformers_ext\", \"build\"))\nsys.path.append(os.path.join(os.path.dirname(__file__), \"..\", \"ktransformers_ext\", \"build\", \"Release\"))\nsys.path.append(os.path.join(os.path.dirname(__file__), \"..\", \"ktransformers_ext\", \"build\", \"Debug\"))\nimport cpuinfer_ext\nfrom cpuinfer_ext.moe import MOEConfig, MOE\nfrom cpuinfer_ext.sft_moe import SFT_MOEConfig, SFT_MOE\nimport ctypes\nfrom ktransformers.util.custom_loader import GGUFLoader\nfrom ktransformers.util.inference_state import InferenceState\nfrom ktransformers.util.custom_gguf import GGMLQuantizationType\nfrom ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader, ModelLoader\nfrom ktransformers.server.config.config import Config\nfrom transformers.activations import ACT2FN\nfrom transformers.configuration_utils import PretrainedConfig\nfrom abc import ABC, abstractmethod\nfrom ktransformers.operators.linear import KLinearMarlin, KLinearTorch, KTransformersLinear\nimport time\nfrom ktransformers.operators.cpuinfer import CPUInfer\nfrom ktransformers.util.grad_wrapper import maybe_no_grad\n\nH_FIXED = 7168\nM_FIXED = 2048\n\ndef deduplicate_and_sort(lst):\n    return sorted(set(lst))\ndef generate_cuda_graphs(chunk_size: int) -> list:\n    assert chunk_size <= 1024 or chunk_size % 
1024 == 0, \"chunk_size must be <= 1024 or a multiple of 1024\"\n    base_list = [1, 2, 3, Config().max_batch_size, 64, 256, 512, chunk_size]\n\n    if chunk_size <= 1024:\n        return deduplicate_and_sort(base_list)\n\n    multiples = [i for i in range(1024, chunk_size + 1, 1024)]\n\n    return deduplicate_and_sort(base_list + multiples)\n#cuda_graphs = [Config().chunk_size] \nif torch.cuda.is_available():\n    cuda_graphs = generate_cuda_graphs(Config().chunk_size)\nelse:\n    cuda_graphs = 1\n# class Base(BaseInjectedModule, ABC):\nclass KExpertsBase(ABC):\n    def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, device: str = \"cuda\", **kwargs):\n        # super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)\n        self.key = key\n        self.gguf_loader = gguf_loader\n        self.config = config\n        self.device = device\n    \n    @abstractmethod\n    def forward(self, input_tensor, expert_ids, weights):\n        pass\n\n    @abstractmethod\n    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str = \"cpu\", warmup: bool = False):\n        pass\n    \n    @abstractmethod\n    def unload(self):\n        pass\n\n    def load_weights(self, override_key: str | None = None, device: str = \"cpu\"):\n        res = {}\n        if override_key is not None:\n            keys = override_key\n        else:\n            keys = [self.key]\n\n        gate = None\n        up = None\n        down = None\n        gate_type = None\n        up_type = None\n        down_type = None\n\n        for key in keys:\n            # if key + \".ffn_gate_exps.weight\" in self.gguf_loader.tensor_info: # TODO: maybe problem in merge (this is origin one)\n            if self.gguf_loader.has_tensor(key + \".ffn_gate_exps.weight\"):\n                targets = [\".ffn_gate_exps.weight\", \".ffn_up_exps.weight\", \".ffn_down_exps.weight\" ]\n                tensors = self.load_multi(key, targets, 
device=device)\n                gate = tensors[\".ffn_gate_exps.weight\"]\n                up = tensors[\".ffn_up_exps.weight\"]\n                down = tensors[\".ffn_down_exps.weight\"]\n                gate_type = self.gguf_loader.tensor_info[key + \".ffn_gate_exps.weight\"][\"ggml_type\"]\n                up_type = self.gguf_loader.tensor_info[key + \".ffn_up_exps.weight\"][\"ggml_type\"]\n                down_type = self.gguf_loader.tensor_info[key + \".ffn_down_exps.weight\"][\"ggml_type\"]\n            # elif key + \".ffn_down.0.weight\" in self.gguf_loader.tensor_info: # TODO: maybe problem in merge (this is origin one)\n            elif self.gguf_loader.has_tensor(key + \".ffn_down.0.weight\"):\n                # for supporting Mixtral-8x7B-Instruct (one tensor per expert, 8 experts)\n                gate = []\n                up = []\n                down = []\n                for i in range(8):\n                    gatei, upi, downi = f\".ffn_gate.{i}.weight\", f\".ffn_up.{i}.weight\", f\".ffn_down.{i}.weight\"\n                    targets = [gatei, upi, downi]\n                    tensors = self.load_multi(key, targets, device=device)\n                    gate_it, up_it, down_it = tensors[gatei], tensors[upi], tensors[downi]\n                    gate.append(gate_it)\n                    up.append(up_it)\n                    down.append(down_it)\n                gate = torch.stack(gate)\n                up = torch.stack(up)\n                down = torch.stack(down)\n                gate_type = self.gguf_loader.tensor_info[key + \".ffn_gate.0.weight\"][\"ggml_type\"]\n                up_type = self.gguf_loader.tensor_info[key + \".ffn_up.0.weight\"][\"ggml_type\"]\n                down_type = self.gguf_loader.tensor_info[key + \".ffn_down.0.weight\"][\"ggml_type\"]\n            else:\n                raise ValueError(f\"Experts {key} not found in gguf_loader\")\n            res = {key:{\"gate\": gate, \"up\": up, \"down\": down, \"gate_type\": gate_type, \"up_type\": up_type, 
\"down_type\": down_type}}\n        return res\n    \n    def load_multi(self, key: str, keys: list[str], device: str = \"cpu\"):\n        tensors = {}\n        for k in keys:\n            tensors[k] = self.gguf_loader.load_gguf_tensor(key + k, device=device)\n        return tensors\nclass KExpertsCPU(KExpertsBase):\n    input_tensor_cpu:Tensor = None\n    expert_ids_cpu:Tensor = None\n    weights_cpu:Tensor = None\n    output_cpu:Tensor = None\n    output_gpu_map:dict = {} # Manage output tensor buffer on different gpu\n    #stream_map:dict = {} # Manage cuda stream on different gpu\n    # @TODO add yaml\n    CPU_INFER = CPUInfer(Config().cpu_infer)\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        n_routed_experts: int,\n        orig_module: nn.Module = None,\n        device: str = \"cpu\",\n        out_device: str = \"cuda\", # this device mean which device the output should on. TODO: support cpu.\n        **kwargs\n    ):\n        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)\n        assert device.lower() == \"cpu\", \"KExpertsCPU can only be loaded on CPU\"\n        self.n_routed_experts = n_routed_experts\n        self.out_device = out_device\n        self.backend = kwargs.get(\"backend\", \"llamafile\")\n\n    def load(self, w: dict | nn.Parameter | tuple | None = None, device:str|None = None, warmup:bool = False):\n        if device:\n            assert device.lower() == \"cpu\", \"KExpertsCPU can only be loaded on CPU, Parameter \\\"device\\\" can be cpu or None.\"\n        if w is None: w = self.load_weights()[self.key]\n        self.gate = w[\"gate\"]\n        self.up = w[\"up\"]\n        self.down = w[\"down\"]\n        self.gate_type = w[\"gate_type\"]\n        self.up_type = w[\"up_type\"]\n        self.down_type = w[\"down_type\"]\n        gate_ptr = ctypes.addressof(\n            ctypes.cast(self.gate.ctypes.data, 
ctypes.POINTER(ctypes.c_uint64)).contents\n        )\n        up_ptr = ctypes.addressof(\n            ctypes.cast(self.up.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents\n        )\n        down_ptr = ctypes.addressof(\n            ctypes.cast(self.down.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents\n        )\n        # print(self.gate_qtype, self.up_qtype, self.down_qtype)\n        n_routed_experts = self.n_routed_experts\n        self.cpu_infer = KExpertsCPU.CPU_INFER\n        # n_routed_experts = len(self.orig_module)\n        model_dtype = torch.get_default_dtype()\n        if torch.xpu.is_available() and model_dtype == torch.float16:\n            hidden_type = 1 # fp16\n        else:\n            hidden_type = 30 # bf16\n        if self.backend == \"llamafile\":\n            moe_config = MOEConfig(\n                n_routed_experts,\n                self.config.num_experts_per_tok,\n                self.config.hidden_size,\n                self.config.moe_intermediate_size,\n                64,\n                10,\n                1024,\n                gate_ptr,\n                up_ptr,\n                down_ptr,\n                self.gate_type,\n                self.up_type,\n                self.down_type,\n                hidden_type, # TODO: get from model.dtype\n            )\n            self.moe = MOE(moe_config)\n        elif self.backend == \"AMXBF16\":\n            from cpuinfer_ext.moe import AMX_MOEConfig, AMXBF16_MOE\n            assert self.gate_type == GGMLQuantizationType.BF16\n            assert self.up_type == GGMLQuantizationType.BF16\n            assert self.down_type == GGMLQuantizationType.BF16\n            moe_config = AMX_MOEConfig(\n                n_routed_experts,\n                self.config.num_experts_per_tok,\n                self.config.hidden_size,\n                self.config.moe_intermediate_size,\n                max(cuda_graphs) if isinstance(cuda_graphs, list) else Config().chunk_size,\n                
gate_ptr,\n                up_ptr,\n                down_ptr,\n            )\n            self.moe = AMXBF16_MOE(moe_config)\n            self.cpu_infer.submit(self.moe.load_weights())\n            self.cpu_infer.sync()\n        elif self.backend == \"AMXInt8\":\n            from cpuinfer_ext.moe import AMX_MOEConfig, AMXInt8_MOE\n            assert self.gate_type == GGMLQuantizationType.BF16\n            assert self.up_type == GGMLQuantizationType.BF16\n            assert self.down_type == GGMLQuantizationType.BF16\n            moe_config = AMX_MOEConfig(\n                n_routed_experts,\n                self.config.num_experts_per_tok,\n                self.config.hidden_size,\n                self.config.moe_intermediate_size,\n                max(cuda_graphs) if isinstance(cuda_graphs, list) else Config().chunk_size,\n                gate_ptr,\n                up_ptr,\n                down_ptr,\n            )\n            self.moe = AMXInt8_MOE(moe_config)\n            self.cpu_infer.submit(self.moe.load_weights())\n            self.cpu_infer.sync()\n        # print(n_routed_experts, hidden_size, moe_intermediate_size)\n        num_experts_per_tok = self.config.num_experts_per_tok\n        if warmup:\n            self.cpu_infer.submit(self.moe.warm_up())\n            self.cpu_infer.sync()\n        if self.out_device not in KExpertsCPU.output_gpu_map:\n            if isinstance(cuda_graphs, list):\n                KExpertsCPU.output_gpu_map[self.out_device] = [torch.zeros((cuda_graphs[i], self.config.hidden_size), device=self.out_device) for i in range(len(cuda_graphs))]\n            else:\n                KExpertsCPU.output_gpu_map[self.out_device] = torch.zeros((cuda_graphs, self.config.hidden_size), device=self.out_device)\n        if KExpertsCPU.input_tensor_cpu == None:\n            if isinstance(cuda_graphs, list):\n                KExpertsCPU.input_tensor_cpu = [torch.zeros((cuda_graphs[i], self.config.hidden_size), device=\"cpu\", pin_memory=True) for 
i in range(len(cuda_graphs))]\n                KExpertsCPU.expert_ids_cpu = [torch.zeros((cuda_graphs[i], num_experts_per_tok), device=\"cpu\", dtype=torch.long, pin_memory=True) for i in range(len(cuda_graphs))]\n                KExpertsCPU.weights_cpu = [torch.zeros((cuda_graphs[i], num_experts_per_tok), device=\"cpu\", dtype=torch.float32, pin_memory=True) for i in range(len(cuda_graphs))]\n                KExpertsCPU.output_cpu = [torch.zeros((cuda_graphs[i], self.config.hidden_size), device=\"cpu\", pin_memory=True, dtype=torch.bfloat16) for i in range(len(cuda_graphs))]\n                KExpertsCPU.bsz_tensor_cpu = [torch.zeros((1), device=\"cpu\", dtype=torch.int32, pin_memory=True) for i in range(len(cuda_graphs))]\n            else:\n                KExpertsCPU.input_tensor_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device=\"cpu\", pin_memory=True)\n                KExpertsCPU.expert_ids_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device=\"cpu\", dtype=torch.long, pin_memory=True)\n                KExpertsCPU.weights_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device=\"cpu\", dtype=torch.float32, pin_memory=True)\n                if torch.xpu.is_available():\n                    KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device=\"cpu\", pin_memory=True, dtype=model_dtype)\n                    KExpertsCPU.bsz_tensor_cpu = torch.ones((1), device=\"cpu\", dtype=torch.int32, pin_memory=True)\n                else:\n                    KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device=\"cpu\", pin_memory=True, dtype=torch.bfloat16)\n                    KExpertsCPU.bsz_tensor_cpu = torch.zeros((1), device=\"cpu\", dtype=torch.int32, pin_memory=True)\n            \n    def submit_for_one_decode(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0):\n        if bsz_tensor is None:\n            bsz_tensor = torch.ones(1, 
device=input_tensor.device, dtype=torch.int32)\n        if cuda_graph_idx != -1:\n            KExpertsCPU.input_tensor_cpu[cuda_graph_idx].copy_(input_tensor, non_blocking=True)\n            KExpertsCPU.expert_ids_cpu[cuda_graph_idx].copy_(expert_ids, non_blocking=True)\n            KExpertsCPU.weights_cpu[cuda_graph_idx].copy_(weights, non_blocking=True)\n            KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].copy_(bsz_tensor, non_blocking=True)\n            self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(-1), KExpertsCPU.expert_ids_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.weights_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.input_tensor_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.output_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].data_ptr()))\n        else:\n            KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)\n            KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True)\n            KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True)\n            KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor, non_blocking=True)\n            self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(-1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr()))\n        \n\n    def sync_for_one_decode(self, cuda_graph_idx=0):\n        if cuda_graph_idx != -1:\n            self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream)\n            KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx].copy_(KExpertsCPU.output_cpu[cuda_graph_idx], non_blocking=True)\n            return KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx]\n        else:\n            
self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream)\n            KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)\n            return KExpertsCPU.output_gpu_map[self.out_device]\n\n    def forward(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0):\n        # generate, capture and run cuda graph\n        # print(expert_ids)\n        if bsz_tensor is None and (not torch.xpu.is_available() or input_tensor.size(0) > 1):\n            bsz_tensor = torch.tensor([input_tensor.size(0)], device=input_tensor.device, dtype=torch.int32)\n        if torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():\n            if cuda_graph_idx != -1:\n                KExpertsCPU.input_tensor_cpu[cuda_graph_idx].copy_(input_tensor, non_blocking=True)\n                KExpertsCPU.expert_ids_cpu[cuda_graph_idx].copy_(expert_ids, non_blocking=True)\n                KExpertsCPU.weights_cpu[cuda_graph_idx].copy_(weights, non_blocking=True)\n                KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].copy_(bsz_tensor, non_blocking=True)\n                self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(expert_ids.size(0), expert_ids.size(-1), KExpertsCPU.expert_ids_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.weights_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.input_tensor_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.output_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].data_ptr()))\n                self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)\n                KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx].copy_(KExpertsCPU.output_cpu[cuda_graph_idx], non_blocking=True)\n                return KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx]\n\n            else:\n                KExpertsCPU.input_tensor_cpu.copy_(input_tensor, 
non_blocking=True)\n                KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True)\n                KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True)\n                KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor, non_blocking=True)\n                self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(expert_ids.size(0), expert_ids.size(-1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr()))\n                self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)\n                KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)\n                return KExpertsCPU.output_gpu_map[self.out_device]\n        elif input_tensor.size(0)==1 and torch.xpu.is_available():\n            KExpertsCPU.input_tensor_cpu.copy_(input_tensor.view(-1), non_blocking=True)\n            KExpertsCPU.expert_ids_cpu.copy_(expert_ids.view(-1), non_blocking=True)\n            KExpertsCPU.weights_cpu.copy_(weights.view(-1), non_blocking=True)\n            # KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor.view(-1), non_blocking=True)\n            self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr()))\n            self.cpu_infer.sync()\n            KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)\n            return KExpertsCPU.output_gpu_map[self.out_device].view(1, -1)\n        else:\n            input_tensor = input_tensor.contiguous().cpu()\n            expert_ids = expert_ids.contiguous().cpu()\n            weights = weights.contiguous().to(torch.float32).cpu()\n            
bsz_tensor = bsz_tensor.contiguous().cpu()\n            output = torch.empty_like(input_tensor).contiguous()\n            self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), expert_ids.data_ptr(), weights.data_ptr(), input_tensor.data_ptr(), output.data_ptr(), bsz_tensor.data_ptr()))\n            self.cpu_infer.sync()\n            return output.to(device=object.__getattribute__(self, \"out_device\"))\n    \n    def unload(self):\n        return\n\n    def load_weights(self, override_key: str | None = None, device: str = \"cpu\"):\n        # TODO: support Bias\n        res = {}\n        if override_key is not None:\n            keys = override_key\n        else:\n            keys = [self.key]\n\n        gate = None\n        up = None\n        down = None\n        gate_type = None\n        up_type = None\n        down_type = None\n\n        for key in keys:\n            if isinstance(self.gguf_loader, SafeTensorLoader):\n                res = self.gguf_loader.load_experts(key)\n                return {key: res}\n            elif self.gguf_loader.has_tensor(key + \".ffn_gate_exps.weight\"):\n                gate = self.gguf_loader.get_mmap_tensor(key + \".ffn_gate_exps.weight\")\n                up = self.gguf_loader.get_mmap_tensor(key + \".ffn_up_exps.weight\")\n                down = self.gguf_loader.get_mmap_tensor(key + \".ffn_down_exps.weight\")\n                # gate_type = self.gguf_loader.tensor_info[key + \".ffn_gate_exps.weight\"][\"ggml_type\"]\n                # up_type = self.gguf_loader.tensor_info[key + \".ffn_up_exps.weight\"][\"ggml_type\"]\n                # down_type = self.gguf_loader.tensor_info[key + \".ffn_down_exps.weight\"][\"ggml_type\"]\n                gate_type = self.gguf_loader.get_ggml_type(key + \".ffn_gate_exps.weight\")\n                up_type = self.gguf_loader.get_ggml_type(key + \".ffn_up_exps.weight\")\n                down_type = self.gguf_loader.get_ggml_type(key + \".ffn_down_exps.weight\")\n      
      \n            elif key + \".ffn_gate_exps.weight\" in self.gguf_loader.tensor_info:\n                gate = self.gguf_loader.get_mmap_tensor(key + \".ffn_gate_exps.weight\")\n                up = self.gguf_loader.get_mmap_tensor(key + \".ffn_up_exps.weight\")\n                down = self.gguf_loader.get_mmap_tensor(key + \".ffn_down_exps.weight\")\n                gate_type = self.gguf_loader.tensor_info[key + \".ffn_gate_exps.weight\"][\"ggml_type\"]\n                up_type = self.gguf_loader.tensor_info[key + \".ffn_up_exps.weight\"][\"ggml_type\"]\n                down_type = self.gguf_loader.tensor_info[key + \".ffn_down_exps.weight\"][\"ggml_type\"]\n            elif key + \".ffn_down.0.weight\" in self.gguf_loader.tensor_info:\n                # Support Mixtral-8x7B-Instruct, which stores each of its 8 experts as a separate tensor\n                gate = []\n                up = []\n                down = []\n                for i in range(8):\n                    gate_it = self.gguf_loader.get_mmap_tensor(f\"{key}.ffn_gate.{i}.weight\")\n                    up_it = self.gguf_loader.get_mmap_tensor(f\"{key}.ffn_up.{i}.weight\")\n                    down_it = self.gguf_loader.get_mmap_tensor(f\"{key}.ffn_down.{i}.weight\")\n                    gate.append(gate_it)\n                    up.append(up_it)\n                    down.append(down_it)\n                gate = np.stack(gate)\n                up = np.stack(up)\n                down = np.stack(down)\n                gate_type = self.gguf_loader.get_ggml_type(key + \".ffn_gate.0.weight\")\n                up_type = self.gguf_loader.get_ggml_type(key + \".ffn_up.0.weight\")\n                down_type = self.gguf_loader.get_ggml_type(key + \".ffn_down.0.weight\")\n            else:\n                raise ValueError(f\"Experts {key} not found in gguf_loader\")\n            res = {key:{\"gate\": gate, \"up\": up, \"down\": down, \"gate_type\": gate_type, \"up_type\": up_type, \"down_type\": down_type}}\n        return res\nclass 
KSFTExpertsCPU(torch.autograd.Function):\n    input_tensor_cpu:Tensor = None\n    expert_ids_cpu:Tensor = None\n    weights_cpu:Tensor = None\n    output_cpu:Tensor = None\n    output_gpu_map:dict = {} # Manage output tensor buffer on different gpu\n    #stream_map:dict = {} # Manage cuda stream on different gpu\n    #gguf_loader:GGUFLoader = None\n    CPU_INFER = CPUInfer(Config().cpu_infer)\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        n_routed_experts: int,\n        orig_module: nn.Module = None,\n        device: str = \"cpu\",\n        out_device: str = \"cuda\", # the device the output should be placed on. TODO: support cpu.\n        **kwargs\n    ):\n        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)\n        #if KExpertsCPU.gguf_loader is None:\n        #    KExpertsCPU.gguf_loader = GGUFLoader(\"/mnt/data/model/DeepseekV3-q4km-gguf\")\n        self.gguf_loader = gguf_loader\n        assert device.lower() == \"cpu\", \"KSFTExpertsCPU can only be loaded on CPU\"\n        self.n_routed_experts = n_routed_experts\n        self.out_device = out_device\n        self.backend = kwargs.get(\"backend\", \"llamafile\")\n\n        self.key = key\n        self.config = config\n        self.device = device\n\n        self.call_count = 0\n        self.flops_per_call = []\n        self.times = []\n        \n        self.tflops_fwd = []\n        self.tflops_bwd = []\n\n    def load(self, w: dict | nn.Parameter | tuple | None = None, device:str|None = None, warmup:bool = False):\n        if device:\n            assert device.lower() == \"cpu\", \"KSFTExpertsCPU can only be loaded on CPU, Parameter \\\"device\\\" can be cpu or None.\"\n        if w is None: w = self.load_weights()[self.key]\n        self.gate = w[\"gate\"]\n        self.up = w[\"up\"]\n        self.down = w[\"down\"]\n        self.gate_type = w[\"gate_type\"]\n        self.up_type = 
w[\"up_type\"]\n        self.down_type = w[\"down_type\"]\n        gate_ptr = ctypes.addressof(\n            ctypes.cast(self.gate.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents\n        )\n        up_ptr = ctypes.addressof(\n            ctypes.cast(self.up.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents\n        )\n        down_ptr = ctypes.addressof(\n            ctypes.cast(self.down.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents\n        )\n        #print(self.gate_type, self.up_type, self.down_type)\n        n_routed_experts = self.n_routed_experts\n        # n_routed_experts = len(self.orig_module)\n        self.cpu_infer = KSFTExpertsCPU.CPU_INFER\n        \n        model_dtype = torch.get_default_dtype()\n        if torch.xpu.is_available() and model_dtype == torch.float16:\n            hidden_type = 1 # fp16\n        else:\n            hidden_type = 30 # bf16\n        if self.backend == \"llamafile\":\n            # print(\"GO INTO LLAMAFILE!!\")\n            moe_config = SFT_MOEConfig(\n                n_routed_experts,\n                self.config.num_experts_per_tok,\n                self.config.hidden_size,\n                self.config.moe_intermediate_size,\n                64,\n                10,\n                1024,\n                gate_ptr,\n                up_ptr,\n                down_ptr,\n                self.gate_type,\n                self.up_type,\n                self.down_type,\n                hidden_type, # TODO: get from model.dtype\n            )\n            self.moe = SFT_MOE(moe_config)\n        elif self.backend == \"AMXBF16\":\n            print(\"GO INTO AMXBF16!!\")\n            from cpuinfer_ext.sft_moe import SFT_AMX_MOEConfig, SFT_AMXBF16_MOE\n            assert self.gate_type == GGMLQuantizationType.BF16\n            assert self.up_type == GGMLQuantizationType.BF16\n            assert self.down_type == GGMLQuantizationType.BF16\n            moe_config = SFT_AMX_MOEConfig(\n                
n_routed_experts,\n                self.config.num_experts_per_tok,\n                self.config.hidden_size,\n                self.config.moe_intermediate_size,\n                max(cuda_graphs) if isinstance(cuda_graphs, list) else Config().chunk_size,\n                gate_ptr,\n                up_ptr,\n                down_ptr,\n            )\n            self.moe = SFT_AMXBF16_MOE(moe_config)\n            self.cpu_infer.submit(self.moe.load_weights())\n            self.cpu_infer.sync()\n        elif self.backend == \"AMXInt8\":\n            print(\"GO INTO AMXInt8!!\")\n            from cpuinfer_ext.sft_moe import SFT_AMX_MOEConfig, SFT_AMXInt8_MOE\n            assert self.gate_type == GGMLQuantizationType.BF16\n            assert self.up_type == GGMLQuantizationType.BF16\n            assert self.down_type == GGMLQuantizationType.BF16\n            moe_config = SFT_AMX_MOEConfig(\n                n_routed_experts,\n                self.config.num_experts_per_tok,\n                self.config.hidden_size,\n                self.config.moe_intermediate_size,\n                max(cuda_graphs) if isinstance(cuda_graphs, list) else Config().chunk_size,\n                gate_ptr,\n                up_ptr,\n                down_ptr,\n            )\n            self.moe = SFT_AMXInt8_MOE(moe_config)\n            self.cpu_infer.submit(self.moe.load_weights())\n            self.cpu_infer.sync()\n\n        # print(n_routed_experts, hidden_size, moe_intermediate_size)\n        num_experts_per_tok = self.config.num_experts_per_tok\n        if warmup:\n            self.cpu_infer.submit(self.moe.warm_up())\n            self.cpu_infer.sync()\n        if self.out_device not in KSFTExpertsCPU.output_gpu_map:\n            KSFTExpertsCPU.output_gpu_map[self.out_device] = torch.zeros((self.config.hidden_size), device=self.out_device)\n        if KSFTExpertsCPU.input_tensor_cpu == None:\n            KSFTExpertsCPU.input_tensor_cpu = torch.zeros((self.config.hidden_size), 
device=\"cpu\", pin_memory=True)\n            KSFTExpertsCPU.expert_ids_cpu = torch.zeros((num_experts_per_tok), device=\"cpu\", dtype=torch.long, pin_memory=True)\n            KSFTExpertsCPU.weights_cpu = torch.zeros((num_experts_per_tok), device=\"cpu\", dtype=torch.float32, pin_memory=True)\n            KSFTExpertsCPU.output_cpu = torch.zeros((self.config.hidden_size), device=\"cpu\", pin_memory=True, dtype=torch.bfloat16)\n            \n        self.gate = None\n        self.up = None\n        self.down = None\n            \n    def submit_for_one_decode(self, input_tensor, expert_ids, weights):\n        KSFTExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)\n        KSFTExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True)\n        KSFTExpertsCPU.weights_cpu.copy_(weights, non_blocking=True)\n        self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(0), KSFTExpertsCPU.expert_ids_cpu.data_ptr(), KSFTExpertsCPU.weights_cpu.data_ptr(), KSFTExpertsCPU.input_tensor_cpu.data_ptr(), KSFTExpertsCPU.output_cpu.data_ptr()))\n        \n    def sync_for_one_decode(self):\n        self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream)\n        KSFTExpertsCPU.output_gpu_map[self.out_device].copy_(KSFTExpertsCPU.output_cpu, non_blocking=True)\n        return KSFTExpertsCPU.output_gpu_map[self.out_device]\n\n    @staticmethod\n    def forward(ctx, input_tensor, expert_ids, weights, cpu_infer, moe, out_device, layer_idx):\n        # print(\"Go into the forward\")\n        \n        # generate, capture and run cuda graph\n        # torch.set_printoptions(threshold=float('inf'))\n        # print(expert_ids)\n        # expert_ids.cpu().numpy().tofile('debug_expert_ids.txt', sep='\\n')\n        # print(expert_ids.size())\n        # print(xx)\n        if input_tensor.size(0)==1 and torch.cuda.is_current_stream_capturing():\n            # 
TODO: this branch is unreachable, but the shapes of input_tensor ([1, hidden_size]) and input_tensor_cpu ([hidden_size]) are not compatible\n            #print(\"capturing experts\")\n            wall_t0 = time.time()  # start the timer so t_fwd below is defined on this path\n            KSFTExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)\n            KSFTExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True)\n            KSFTExpertsCPU.weights_cpu.copy_(weights, non_blocking=True)\n            cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, moe.forward(1, expert_ids.size(1), KSFTExpertsCPU.expert_ids_cpu.data_ptr(), KSFTExpertsCPU.weights_cpu.data_ptr(), KSFTExpertsCPU.input_tensor_cpu.data_ptr(), KSFTExpertsCPU.output_cpu.data_ptr()))\n            cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)\n            t_fwd     = time.time() - wall_t0\n            KSFTExpertsCPU.output_gpu_map[out_device].copy_(KSFTExpertsCPU.output_cpu, non_blocking=True)\n            result = KSFTExpertsCPU.output_gpu_map[out_device]\n        else:\n            input_tensor = input_tensor.contiguous().cpu()\n            expert_ids = expert_ids.contiguous().cpu()\n            weights = weights.contiguous().to(torch.float32).cpu()\n            output = torch.empty_like(input_tensor).contiguous()\n            # print(\"success record\")\n            wall_t0 = time.time()\n            cpu_infer.submit(\n                moe.forward(\n                    expert_ids.size(0), \n                    expert_ids.size(1), \n                    expert_ids.data_ptr(), \n                    weights.data_ptr(), \n                    input_tensor.data_ptr(), \n                    output.data_ptr(),\n                )\n            )\n            cpu_infer.sync()\n            t_fwd     = time.time() - wall_t0\n\n            result = output.to(device=out_device)\n\n        ctx.save_for_backward(input_tensor, expert_ids, weights)\n        ctx.cpu_infer  = cpu_infer\n        ctx.moe        = moe\n        ctx.out_device = 
out_device\n        ctx.layer_idx = layer_idx\n        \n        # ---------- FLOPs ----------\n        qlen = expert_ids.size(0)\n        k    = expert_ids.size(1)\n\n        flops_fwd = 6 * qlen * k * H_FIXED * M_FIXED\n        tflops_f  = flops_fwd / t_fwd / 1e12\n\n        ctx.saved_dims = (qlen, k)\n        ctx._time_fwd  = t_fwd\n        # print(f\"qlen ,k:{qlen}, {k}\")\n        \n        # with open(\"test_V3_ESC.txt\", \"a\", encoding=\"utf-8\") as f:\n        #     f.write(f\"[KSFTExpertsCPU]Forward: {flops_fwd/1e9:.3f} GFLOPs {tflops_f:.2f} TFLOPS {t_fwd*1e3:.2f} ms\\n\")\n\n        return result\n        \n    @staticmethod\n    def backward(ctx, output_grad):\n        # print(\"Go into the backward!!\")\n        \n        # Pick back the middle results\n        input_tensor, expert_ids, weights = ctx.saved_tensors\n        import random\n        layer_idx = random.randint(0, 10000)\n        # print(f\"layer_idx:{layer_idx}\")\n        # layer_idx   = ctx.layer_idx\n        \n        # cpu_infer  = ctx.cpu_infer\n        # moe        = ctx.moe\n        # out_device = ctx.out_device\n\n        # ready for computing gradient\n        output_grad = output_grad.contiguous().cpu()\n        input_grad = torch.empty_like(input_tensor).contiguous()\n        # print(dir(cpuinfer_ext.moe.MOE))\n        bw_start = time.time()\n        ctx.cpu_infer.submit(\n            ctx.moe.backward(\n                # layer_idx,\n                output_grad.size(0),  # qlen\n                expert_ids.size(1),   # k\n                expert_ids.data_ptr(),\n                weights.data_ptr(),\n                input_tensor.data_ptr(), \n                output_grad.data_ptr(),\n                input_grad.data_ptr(),\n            )\n        )\n        ctx.cpu_infer.sync()\n        \n        bw_end   = time.time()\n        t_bw    = bw_end - bw_start\n        \n        # ---------- FLOPs ----------\n        qlen, k  = ctx.saved_dims\n        flops_bw = 10 * qlen * k * H_FIXED * 
M_FIXED\n        tflops_b = flops_bw / t_bw / 1e12\n        # print(f\"qlen:{qlen}, k:{k}\")\n\n        # with open(\"test_V3_ESC.txt\", \"a\", encoding=\"utf-8\") as f:\n        #     f.write(f\"[KSFTExpertsCPU]Backward: {flops_bw/1e9:.3f} GFLOPs {tflops_b:.2f} TFLOPS {t_bw*1e3:.2f} ms\\n\")\n        \n        return input_grad.to(device=ctx.out_device), None, None, None, None, None, None\n    \n    def unload(self):\n        return\n\n    def load_weights(self, override_key: str | None = None, device: str = \"cpu\"):\n        # TODO: support Bias\n        res = {}\n        if override_key is not None:\n            keys = override_key\n        else:\n            keys = [self.key]\n\n        gate = None\n        up = None\n        down = None\n        gate_type = None\n        up_type = None\n        down_type = None\n\n        for key in keys:\n            if isinstance(self.gguf_loader, SafeTensorLoader):\n                res = self.gguf_loader.load_experts(key)\n                return {key: res}\n            elif self.gguf_loader.has_tensor(key + \".ffn_gate_exps.weight\"):\n                gate = self.gguf_loader.get_mmap_tensor(key + \".ffn_gate_exps.weight\")\n                up = self.gguf_loader.get_mmap_tensor(key + \".ffn_up_exps.weight\")\n                down = self.gguf_loader.get_mmap_tensor(key + \".ffn_down_exps.weight\")\n                # gate_type = self.gguf_loader.tensor_info[key + \".ffn_gate_exps.weight\"][\"ggml_type\"]\n                # up_type = self.gguf_loader.tensor_info[key + \".ffn_up_exps.weight\"][\"ggml_type\"]\n                # down_type = self.gguf_loader.tensor_info[key + \".ffn_down_exps.weight\"][\"ggml_type\"]\n                gate_type = self.gguf_loader.get_ggml_type(key + \".ffn_gate_exps.weight\")\n                up_type = self.gguf_loader.get_ggml_type(key + \".ffn_up_exps.weight\")\n                down_type = self.gguf_loader.get_ggml_type(key + \".ffn_down_exps.weight\")\n            \n            elif key + 
\".ffn_gate_exps.weight\" in self.gguf_loader.tensor_info:\n                gate = self.gguf_loader.get_mmap_tensor(key + \".ffn_gate_exps.weight\")\n                up = self.gguf_loader.get_mmap_tensor(key + \".ffn_up_exps.weight\")\n                down = self.gguf_loader.get_mmap_tensor(key + \".ffn_down_exps.weight\")\n                gate_type = self.gguf_loader.tensor_info[key + \".ffn_gate_exps.weight\"][\"ggml_type\"]\n                up_type = self.gguf_loader.tensor_info[key + \".ffn_up_exps.weight\"][\"ggml_type\"]\n                down_type = self.gguf_loader.tensor_info[key + \".ffn_down_exps.weight\"][\"ggml_type\"]\n            elif key + \".ffn_down.0.weight\" in self.gguf_loader.tensor_info:\n                # for supporting  Mixtral-8x7B-Instuct  \n                gate = []\n                up = []\n                down = []\n                for i in range(8):\n                    gate_it = self.gguf_loader.get_mmap_tensor(f\"{key}.ffn_gate.{i}.weight\")\n                    up_it = self.gguf_loader.get_mmap_tensor(f\"{key}.ffn_up.{i}.weight\")\n                    down_it = self.gguf_loader.get_mmap_tensor(f\"{key}.ffn_down.{i}.weight\")\n                    gate.append(gate_it)\n                    up.append(up_it)\n                    down.append(down_it)\n                gate = np.stack(gate)\n                up = np.stack(up)\n                down = np.stack(down)\n                gate_type = self.gguf_loader.get_ggml_type(key + \".ffn_gate.0.weight\")\n                up_type = self.gguf_loader.get_ggml_type(key + \".ffn_up.0.weight\")\n                down_type = self.gguf_loader.get_ggml_type(key + \".ffn_down.0.weight\")\n            else:\n                raise ValueError(f\"Experts {key} not found in gguf_loader\")\n            res = {key:{\"gate\": gate, \"up\": up, \"down\": down, \"gate_type\": gate_type, \"up_type\": up_type, \"down_type\": down_type}}\n        return res\n    \nclass KExpertsMarlin(KExpertsBase):\n    expert_num: 
int\n    loaded_experts_idx: list[int]\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        n_routed_experts: int,\n        orig_module: nn.Module = None,\n        device: str = \"cuda\",\n        **kwargs\n    ):\n        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)\n        self.expert_num = n_routed_experts\n        self.loaded_experts_idx = []\n        self.act_fn = ACT2FN[config.hidden_act]\n        assert device.lower() != \"cpu\", \"Marlin experts can only be loaded on GPU\"\n        self.device = device\n        self.elements_per_tensor = config.moe_intermediate_size * config.hidden_size\n\n        # create empty marlin experts, one per routed expert\n        # up\n        self.up_projs = [KLinearMarlin(key+ \".\" + \"ffn_up_exps\", gguf_loader, config, device=device) for i in range(self.expert_num)]\n        # gate\n        self.gate_projs = [KLinearMarlin(key+ \".\" + \"ffn_gate_exps\", gguf_loader, config, device=device) for i in range(self.expert_num)]\n        # down\n        self.down_projs = [KLinearMarlin(key+ \".\" + \"ffn_down_exps\", gguf_loader, config, device=device) for i in range(self.expert_num)]\n\n    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = False):\n        if device is None: device = self.device\n        assert device.lower() != \"cpu\", \"Marlin experts can only be loaded on GPU\"\n        # Default to False so the name is defined when the caller passes weights in\n        load_by_experts = False\n        if w is None:\n            w = self.load_weights()\n            load_by_experts = True\n\n        if load_by_experts:\n            if isinstance(w, dict):\n                self.gate = w[\"gate\"]\n                self.up = (w[\"up\"])\n                self.down = (w[\"down\"])\n                for i in tqdm(range(self.expert_num), desc=f\"Dequanting and quanting for KExpertsMarlin {self.key}\"):\n                    up_weights = 
self.gguf_loader.load_expert_tensor(self.key + \".ffn_up_exps.weight\", self.up, i, self.elements_per_tensor, device=self.device)\n                    gate_weights = self.gguf_loader.load_expert_tensor(self.key + \".ffn_gate_exps.weight\", self.gate, i, self.elements_per_tensor, device=self.device)\n                    down_weights = self.gguf_loader.load_expert_tensor(self.key + \".ffn_down_exps.weight\", self.down, i, self.elements_per_tensor, device=self.device)\n                    \n                    self.up_projs[i].load(nn.Parameter(up_weights), device=device)\n                    self.gate_projs[i].load(nn.Parameter(gate_weights), device=device)\n                    self.down_projs[i].load(nn.Parameter(down_weights), device=device)\n                    self.loaded_experts_idx.append(i)\n        else:\n            if isinstance(w, dict):\n                self.gate = w[\"gate\"]\n                self.up = (w[\"up\"])\n                self.down = (w[\"down\"])\n                for i in range(self.expert_num):\n                    self.up_projs[i].load(nn.Parameter(self.up[i,...]), device=device)\n                    self.gate_projs[i].load(nn.Parameter(self.gate[i,...]), device=device)\n                    self.down_projs[i].load(nn.Parameter(self.down[i,...]), device=device)\n                    self.loaded_experts_idx.append(i)\n        return \n\n    def unload(self):\n        for i in self.loaded_experts_idx:\n            self.up_projs[i].unload()\n            self.gate_projs[i].unload()\n            self.down_projs[i].unload()\n        self.loaded_experts_idx = []\n\n    def load_weights(self, override_key: str | None = None):\n        res = {}\n        if override_key is not None:\n            keys = override_key\n        else:\n            keys = [self.key]\n\n        gate = None\n        up = None\n        down = None\n\n        for key in keys:\n            if self.gguf_loader.has_tensor(key + \".ffn_gate_exps.weight\"):\n                gate = 
self.gguf_loader.get_mmap_tensor(key + \".ffn_gate_exps.weight\")\n                up = self.gguf_loader.get_mmap_tensor(key + \".ffn_up_exps.weight\")\n                down = self.gguf_loader.get_mmap_tensor(key + \".ffn_down_exps.weight\")\n            res = {\"gate\": gate, \"up\": up, \"down\": down}\n        return res\n\n    def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:\n        org_dtype = hidden_states_cpu.dtype\n        org_device = hidden_states_cpu.device\n        hidden_states_cpu = hidden_states_cpu.to(self.device)\n        selected_experts_cpu = selected_experts_cpu.to(self.device)\n        routing_weights_cpu = routing_weights_cpu.to(self.device).to(org_dtype)\n        \n        batch_sequence_length, hidden_dim = hidden_states_cpu.size()\n\n        final_hidden_states = torch.zeros(\n            (batch_sequence_length, hidden_dim), dtype=hidden_states_cpu.dtype, device=hidden_states_cpu.device\n        )\n        # One hot encode the selected experts to create an expert mask\n        # this will be used to easily index which expert is going to be solicited\n        expert_mask = torch.nn.functional.one_hot(selected_experts_cpu, num_classes=self.expert_num).permute(2, 1, 0)\n\n        # Loop over all available experts in the model and perform the computation on each expert\n        for expert_idx in range(self.expert_num):\n            if not expert_mask[expert_idx].any():\n                continue\n            idx, top_x = torch.where(expert_mask[expert_idx])\n            # Index the correct hidden states and compute the expert hidden state for\n            # the current expert. 
We need to make sure to multiply the output hidden\n            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)\n            current_state = hidden_states_cpu[None, top_x].reshape(-1, hidden_dim)\n            G = self.gate_projs[expert_idx].forward(current_state)\n            A = self.act_fn(G)\n            U = self.up_projs[expert_idx].forward(current_state)\n            H = A * U  # Element-wise multiplication\n            current_hidden_states = self.down_projs[expert_idx].forward(H) * routing_weights_cpu[top_x, idx, None]\n            # However `index_add_` only support torch tensors for indexing so we'll use\n            # the `top_x` tensor here.\n            final_hidden_states.index_add_(0, top_x, current_hidden_states)\n        \n        return final_hidden_states.to(dtype=org_dtype, device=org_device)\n    \n# untested, CUDA OOM\nclass KExpertsTorch(KExpertsBase):\n    expert_num: int\n    loaded_experts_idx: list[int]\n    gate: torch.Tensor\n    up: torch.Tensor\n    down: torch.Tensor\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        n_routed_experts: int,\n        orig_module: nn.Module = None,\n        device: str = \"cpu\",\n        **kwargs\n    ):\n        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)\n        self.expert_num = n_routed_experts\n        # self.loaded_experts_idx = []\n        self.act_fn = ACT2FN[config.hidden_act]\n        self.device = device\n        self.elements_per_tensor = config.moe_intermediate_size * config.hidden_size\n        self.gate = [None for _ in range(self.expert_num)]\n        self.up = [None for _ in range(self.expert_num)]\n        self.down = [None for _ in range(self.expert_num)]\n        self.dtype = torch.get_default_dtype()\n\n        self.call_count = 0\n        self.flops_per_call = []\n        self.times = []\n        self.expert_flops_details = []  \n        
self.total_flops = 0\n        \n        h = self.config.hidden_size\n        m = self.config.moe_intermediate_size\n        self.params_per_expert = 3 * h * m\n        self.total_params = self.expert_num * self.params_per_expert\n\n    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = False):\n        if device is None: device = self.device\n        # Default to False so the name is defined when the caller passes weights in\n        load_by_experts = False\n        if w is None:\n            w = self.load_weights()\n            load_by_experts = True\n\n        if load_by_experts:\n            if isinstance(w, dict):\n                if isinstance(self.gguf_loader, SafeTensorLoader): \n                    for i in tqdm(range(self.expert_num), desc=f\"Loading experts(safetensors) for {self.key}\"):\n                        up_k   = f\"{self.key}.{i}.up_proj.weight\"\n                        gate_k = f\"{self.key}.{i}.gate_proj.weight\"\n                        down_k = f\"{self.key}.{i}.down_proj.weight\"\n                        \n                        self.up[i]   = self.gguf_loader.load_tensor(up_k,   device=self.device).contiguous()\n                        self.gate[i] = self.gguf_loader.load_tensor(gate_k, device=self.device).contiguous()\n                        self.down[i] = self.gguf_loader.load_tensor(down_k, device=self.device).contiguous()\n                else: # GGUFLoader\n                    for i in tqdm(range(self.expert_num), desc=f\"Dequanting for KExpertsTorch {self.key}\"):\n                        up_weights = self.gguf_loader.load_expert_tensor(self.key + \".ffn_up_exps.weight\", w[\"up\"], i, self.elements_per_tensor, device=self.device)\n                        gate_weights = self.gguf_loader.load_expert_tensor(self.key + \".ffn_gate_exps.weight\", w[\"gate\"], i, self.elements_per_tensor, device=self.device)\n                        down_weights = self.gguf_loader.load_expert_tensor(self.key + \".ffn_down_exps.weight\", w[\"down\"], i, self.elements_per_tensor, device=self.device)\n                   
     \n                        self.up[i] = up_weights\n                        self.gate[i] = gate_weights\n                        self.down[i] = down_weights\n        else:\n            if isinstance(w, dict):\n                for i in range(self.expert_num):\n                    self.gate[i] = w[\"gate\"][i, ...].to(device=device, dtype=self.dtype)\n                    self.up[i] = w[\"up\"][i, ...].to(device=device, dtype=self.dtype)\n                    self.down[i] = w[\"down\"][i, ...].to(device=device, dtype=self.dtype)\n        \n        # self.up = torch.stack(self.up, dim=0)\n        # self.gate = torch.stack(self.gate, dim=0)\n        # self.down = torch.stack(self.down, dim=0)\n        self.up = nn.Parameter(torch.stack(self.up, dim=0))\n        self.gate = nn.Parameter(torch.stack(self.gate, dim=0))\n        self.down = nn.Parameter(torch.stack(self.down, dim=0))\n        return \n\n    def unload(self):\n        if self.gate is not None:\n            self.gate = None\n            self.up = None\n            self.down = None\n\n    def load_weights(self, override_key: str | None = None):\n        res = {}\n        if override_key is not None:\n            keys = override_key\n        else:\n            keys = [self.key]\n\n        gate = None\n        up = None\n        down = None\n\n        for key in keys:\n            if isinstance(self.gguf_loader, SafeTensorLoader):\n                res = self.gguf_loader.load_experts(key)\n                return {key: res}\n            elif key + \".ffn_gate_exps.weight\" in self.gguf_loader.tensor_info:\n                gate = self.gguf_loader.get_mmap_tensor(key + \".ffn_gate_exps.weight\")\n                up = self.gguf_loader.get_mmap_tensor(key + \".ffn_up_exps.weight\")\n                down = self.gguf_loader.get_mmap_tensor(key + \".ffn_down_exps.weight\")\n            else:\n                import re\n                match = re.match(r'model\\.layers\\.(\\d+)\\.mlp\\.experts(.*)', key)\n              
  if match:\n                    layer_id = match.group(1)\n                    suffix = match.group(2)\n                    key = f\"blk.{layer_id}{suffix}\"\n                    if key + \".ffn_gate_exps.weight\" in self.gguf_loader.tensor_info:\n                        gate = self.gguf_loader.get_mmap_tensor(key + \".ffn_gate_exps.weight\")\n                        up = self.gguf_loader.get_mmap_tensor(key + \".ffn_up_exps.weight\")\n                        down = self.gguf_loader.get_mmap_tensor(key + \".ffn_down_exps.weight\")\n            res = {\"gate\": gate, \"up\": up, \"down\": down}\n        return res\n\n    def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:\n        start_time = time.time()\n\n        org_device = hidden_states_cpu.device\n        hidden_states_cpu = hidden_states_cpu.to(self.device)\n        selected_experts_cpu = selected_experts_cpu.to(self.device)\n        routing_weights_cpu = routing_weights_cpu.to(self.device)\n        \n        batch_sequence_length, hidden_dim = hidden_states_cpu.size()\n\n        final_hidden_states = torch.zeros(\n            (batch_sequence_length, hidden_dim), dtype=self.gate.dtype, device=hidden_states_cpu.device\n        )\n        org_dtype = hidden_states_cpu.dtype\n        hidden_states_cpu = hidden_states_cpu.to(self.gate.dtype)\n        routing_weights_cpu = routing_weights_cpu.to(self.gate.dtype)\n        # One hot encode the selected experts to create an expert mask\n        # this will be used to easily index which expert is going to be solicited\n        expert_mask = torch.nn.functional.one_hot(selected_experts_cpu, num_classes=self.expert_num).permute(2, 1, 0)\n\n        # Loop over all available experts in the model and perform the computation on each expert\n        for expert_idx in range(self.expert_num):\n            idx, top_x = torch.where(expert_mask[expert_idx])\n            # Index the correct 
hidden states and compute the expert hidden state for\n            # the current expert. We need to make sure to multiply the output hidden\n            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)\n            current_state = hidden_states_cpu[None, top_x].reshape(-1, hidden_dim)\n            G = current_state @ self.gate[expert_idx,...].T\n            A = self.act_fn(G)\n            U = current_state @ self.up[expert_idx,...].T\n            H = A * U  # Element-wise multiplication\n            current_hidden_states = H @ self.down[expert_idx,...].T * routing_weights_cpu[top_x, idx, None]\n            # However `index_add_` only support torch tensors for indexing so we'll use\n            # the `top_x` tensor here.\n            final_hidden_states.index_add_(0, top_x, current_hidden_states)\n\n        call_flops = 0\n        expert_details = []\n        \n        for expert_idx in range(self.expert_num):\n            idx, top_x = torch.where(expert_mask[expert_idx])\n            t_e = len(top_x)\n            if t_e == 0:\n                expert_details.append({'gate':0, 'act':0, 'up':0, \n                                      'element':0, 'down':0, 'routing':0})\n                continue\n                \n            h = self.config.hidden_size\n            m = self.config.moe_intermediate_size\n            \n            flops_gate = 2 * t_e * h * m\n            flops_act = t_e * m\n            flops_up = 2 * t_e * h * m\n            flops_element = t_e * m\n            flops_down = 2 * t_e * m * h\n            flops_routing = t_e * h\n            \n            total_expert = sum([flops_gate, flops_act, flops_up, \n                               flops_element, flops_down, flops_routing])\n            call_flops += total_expert\n            \n            expert_details.append({\n                'gate': flops_gate,\n                'act': flops_act,\n                'up': flops_up,\n                'element': flops_element,\n          
      'down': flops_down,\n                'routing': flops_routing\n            })\n        \n        self.call_count += 1\n        self.flops_per_call.append(call_flops)\n        self.total_flops += call_flops\n        self.expert_flops_details.append(expert_details)\n        self.times.append(time.time() - start_time)\n\n        return final_hidden_states.to(dtype=org_dtype, device=org_device)\n\n    # def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:\n    #     print(\"Enter the forward function!\")\n    #     current_call_start = time.perf_counter()\n    #     if hasattr(self, 'last_call_end_time') and self.last_call_end_time is not None:\n    #         inter_call_interval = current_call_start - self.last_call_end_time\n    #         # print(f\"\\n[Forward Call Interval] Time since last forward call: {inter_call_interval:.6f} seconds\")\n    #         logging.info(f\"\\n[Forward Call Interval] Time since last forward call: {inter_call_interval:.6f} seconds\")\n    #     else:\n    #         inter_call_interval = 0.0\n\n    #     data_transfer_time = 0.0\n    #     tensor_init_time = 0.0\n    #     expert_mask_time = 0.0\n    #     expert_loop_total = 0.0\n    #     gate_time_total = 0.0\n    #     up_time_total = 0.0\n    #     elementwise_time_total = 0.0\n    #     down_time_total = 0.0\n    #     index_add_time_total = 0.0\n    #     cast_back_time = 0.0\n\n    #     start = time.perf_counter()\n    #     org_device = hidden_states_cpu.device\n    #     hidden_states_cpu = hidden_states_cpu.to(self.device)\n    #     selected_experts_cpu = selected_experts_cpu.to(self.device)\n    #     routing_weights_cpu = routing_weights_cpu.to(self.device)\n    #     data_transfer_time = time.perf_counter() - start\n\n    #     start = time.perf_counter()\n    #     batch_sequence_length, hidden_dim = hidden_states_cpu.size()\n    #     final_hidden_states = torch.zeros(\n    #      
   (batch_sequence_length, hidden_dim), dtype=self.gate.dtype, device=hidden_states_cpu.device\n    #     )\n    #     org_dtype = hidden_states_cpu.dtype\n    #     hidden_states_cpu = hidden_states_cpu.to(self.gate.dtype)\n    #     routing_weights_cpu = routing_weights_cpu.to(self.gate.dtype)\n    #     tensor_init_time = time.perf_counter() - start\n\n    #     start = time.perf_counter()\n    #     expert_mask = torch.nn.functional.one_hot(selected_experts_cpu, num_classes=self.expert_num).permute(2, 1, 0)\n    #     expert_mask_time = time.perf_counter() - start\n\n    #     expert_loop_start = time.perf_counter()\n    #     # for expert_idx in range(self.expert_num):\n    #     for expert_idx in tqdm(range(self.expert_num), \n    #         idx, top_x = torch.where(expert_mask[expert_idx])\n            \n    #         current_state = hidden_states_cpu[None, top_x].reshape(-1, hidden_dim)\n\n    #         gate_start = time.perf_counter()\n    #         G = current_state @ self.gate[expert_idx,...].T\n    #         A = self.act_fn(G)\n    #         gate_time_total += time.perf_counter() - gate_start\n\n    #         up_start = time.perf_counter()\n    #         U = current_state @ self.up[expert_idx,...].T\n    #         up_time_total += time.perf_counter() - up_start\n\n    #         element_start = time.perf_counter()\n    #         H = A * U  # Element-wise multiplication\n    #         elementwise_time_total += time.perf_counter() - element_start\n\n    #         down_start = time.perf_counter()\n    #         current_hidden_states = H @ self.down[expert_idx,...].T * routing_weights_cpu[top_x, idx, None]\n    #         down_time_total += time.perf_counter() - down_start\n\n    #         index_start = time.perf_counter()\n    #         final_hidden_states.index_add_(0, top_x, current_hidden_states)\n    #         index_add_time_total += time.perf_counter() - index_start\n\n    #     expert_loop_total = time.perf_counter() - expert_loop_start\n    #     start 
= time.perf_counter()\n    #     final_hidden_states = final_hidden_states.to(dtype=org_dtype, device=org_device)\n    #     cast_back_time = time.perf_counter() - start\n\n    #     total_time = time.perf_counter() - current_call_start\n    #     print(f\"\"\"\n    # [Timing Breakdown]\n    #     Data Transfer:          {data_transfer_time:.6f}s\n    #     Tensor Initialization:  {tensor_init_time:.6f}s\n    #     Expert Mask Creation:   {expert_mask_time:.6f}s\n    #     Expert Loop Total:      {expert_loop_total:.6f}s\n    #         -> Gate Computations:   {gate_time_total:.6f}s\n    #         -> Up Projections:      {up_time_total:.6f}s\n    #         -> Elementwise Mult:    {elementwise_time_total:.6f}s\n    #         -> Down Projections:    {down_time_total:.6f}s\n    #         -> Index Add Ops:       {index_add_time_total:.6f}s\n    #     Cast Back to Original:  {cast_back_time:.6f}s\n    #     Total Forward Time:     {total_time:.6f}s\n    #     \"\"\")\n    #     logging.info(f\"\"\"\n    # [Timing Breakdown]\n    #     Data Transfer:          {data_transfer_time:.6f}s\n    #     Tensor Initialization:  {tensor_init_time:.6f}s\n    #     Expert Mask Creation:   {expert_mask_time:.6f}s\n    #     Expert Loop Total:      {expert_loop_total:.6f}s\n    #         -> Gate Computations:   {gate_time_total:.6f}s\n    #         -> Up Projections:      {up_time_total:.6f}s\n    #         -> Elementwise Mult:    {elementwise_time_total:.6f}s\n    #         -> Down Projections:    {down_time_total:.6f}s\n    #         -> Index Add Ops:       {index_add_time_total:.6f}s\n    #     Cast Back to Original:  {cast_back_time:.6f}s\n    #     Total Forward Time:     {total_time:.6f}s\n    #     \"\"\")\n\n    #     self.last_call_end_time = time.perf_counter()\n\n    #     return final_hidden_states\n\n\nEXPERTS_MAP = {\n    \"KExpertsCPU\": KExpertsCPU,\n    \"KSFTExpertsCPU\": KSFTExpertsCPU,\n    \"KExpertsTorch\": KExpertsTorch,\n    \"KExpertsMarlin\": 
KExpertsMarlin,\n}\n\nclass KTransformersExperts(BaseInjectedModule, KExpertsBase):\n    def __init__(self,\n                 key: str,\n                 gguf_loader: GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                #  device: str = \"cuda\",\n                 prefill_device:str = \"cuda\",\n                 prefill_op: str | None = \"KExpertsTorch\",\n                 generate_device: str = \"cpu\",\n                 generate_op: str | None = \"KExpertsCPU\",\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)\n        KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)\n        if generate_op is not None:\n            self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)\n        else:\n            self.generate_experts = None\n        if prefill_op is not None:\n            self.prefill_experts = EXPERTS_MAP[prefill_op](key, gguf_loader, config, len(orig_module), device=prefill_device, **kwargs)\n        else:\n            self.prefill_experts = None\n        self.gpu_mlp_type = prefill_op\n        self.cpu_mlp_type = generate_op\n        self.mode = InferenceState.UNLOAD\n\n    def load(self, w: dict = None,  mode: InferenceState = None, warmup: bool = True):\n        # TODO support w as input\n        if not mode: mode = InferenceState.GENERATE\n        if mode == InferenceState.GENERATE:\n            self.prefill_experts.unload()\n            self.generate_experts.load(w, warmup=warmup)\n            self.device = self.generate_experts.device\n            self.mode = mode\n        elif mode == InferenceState.PREFILL:\n            self.generate_experts.unload()\n            self.prefill_experts.load(w, warmup=warmup)\n            self.device = self.prefill_experts.device\n            
self.mode = mode\n        elif mode == InferenceState.UNLOAD:\n            self.unload()\n            self.mode = mode\n            self.device = self.generate_experts.device\n        else:\n            raise ValueError(\"mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD\")\n\n    def unload(self):\n        if self.generate_experts is not None:\n            self.generate_experts.unload()\n        if self.prefill_experts is not None:\n            self.prefill_experts.unload()\n        self.device = self.generate_experts.device\n\n    def forward(self, input_tensor, expert_ids, weights):\n        if self.mode == InferenceState.GENERATE:\n            assert self.generate_experts is not None, \"generate_experts is None\"\n            if type(self.generate_experts) == KSFTExpertsCPU:\n                layer_idx = int(re.search(r'\\d+', self.key).group())\n                return self.generate_experts.apply(input_tensor, expert_ids, weights, self.generate_experts.cpu_infer, self.generate_experts.moe, self.generate_experts.out_device, layer_idx)\n            else:\n                return self.generate_experts.forward(input_tensor, expert_ids, weights)\n        elif self.mode == InferenceState.PREFILL:\n            assert self.prefill_experts is not None, \"prefill_experts is None\"\n            return self.prefill_experts.forward(input_tensor, expert_ids, weights)\n        else:\n            raise ValueError(\"load or set_inference_mode before forward\")\n\n    def set_inference_mode(self, mode: InferenceState):\n        if mode == InferenceState.GENERATE:\n            self.load(mode=InferenceState.GENERATE, warmup=False)\n        elif mode == InferenceState.PREFILL:\n            self.load(mode=InferenceState.PREFILL, warmup=False)\n        elif mode == InferenceState.UNLOAD:\n            self.unload()\n        else:\n            raise ValueError(\"mode must be either InferenceState.GENERATE, InferenceState.PREFILL or 
InferenceState.UNLOAD\")\n\n\nfrom ktransformers.models.modeling_deepseek import DeepseekV2MoE\nfrom ktransformers.models.modeling_deepseek_v3 import DeepseekV3MoE\nfrom ktransformers.models.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock\nfrom ktransformers.models.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock\nfrom ktransformers.models.modeling_mixtral import MixtralSparseMoeBlock\n\n\nclass KQwen2MoeSparseMoeBlock(BaseInjectedModule, Qwen2MoeSparseMoeBlock):\n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        \"\"\" \"\"\"\n        orig_shape = hidden_states.shape\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        hidden_states = hidden_states.view(-1, hidden_dim)\n        # router_logits: (batch * sequence_length, n_experts)\n        router_logits = self.gate(hidden_states)\n\n        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)\n        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)\n        if self.norm_topk_prob:\n            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)\n        # we cast back to the input dtype\n        routing_weights = routing_weights.to(hidden_states.dtype)\n        \n        if sequence_length == 1 and hasattr(self.experts.generate_experts, \"submit_for_one_decode\"):\n            self.experts.generate_experts.submit_for_one_decode(hidden_states[0], selected_experts[0], routing_weights[0])\n            shared_expert_output = self.shared_expert(hidden_states)\n            shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output\n            y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)\n            y += shared_expert_output\n            y.resize_(*orig_shape)\n            return y, router_logits\n        \n        hidden_states_expert = hidden_states.to(self.experts.device)  if isinstance(self.experts, KExpertsBase) else hidden_states.cpu()\n   
     selected_experts_expert = selected_experts.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else selected_experts.cpu()\n        routing_weights_expert = routing_weights.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else routing_weights.cpu()\n\n        shared_expert_output = self.shared_expert(hidden_states)\n        shared_expert_output = (\n            F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output\n        )\n\n        if isinstance(self.experts, KExpertsBase):\n            y = (\n                self.moe_kexperts(\n                    hidden_states_expert, selected_experts_expert, routing_weights_expert\n                )\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            )\n        elif hidden_states_expert.size(0) > 10:\n            y = self.moe_infer(\n                hidden_states_expert, selected_experts_expert, routing_weights_expert, orig_shape\n            ).to(device=hidden_states.device)\n        else:\n            y = self.moe_infer_simple(\n                hidden_states_expert, selected_experts_expert, routing_weights_expert\n            ).to(device=hidden_states.device)\n        y += shared_expert_output\n        y.resize_(*orig_shape)\n        return y, router_logits\n    \n    @maybe_no_grad()\n    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:\n        outs = self.experts(x, topk_ids, topk_weight)\n        return outs\n\n    @maybe_no_grad()\n    # TODO may bugs here\n    def moe_infer_simple(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:\n        '''\n        hidden_states_cpu: [num_tokens, hidden_size]\n        topk_ids, topk_weight: [num_tokens, num_selected_experts]\n        '''\n        outs = torch.zeros_like(hidden_states_cpu)\n        for token_idx in 
range(selected_experts_cpu.size(0)):\n            for expert_idx in range(selected_experts_cpu.size(1)):\n                expert = self.experts[selected_experts_cpu[token_idx, expert_idx]]\n                outs[token_idx] += expert.forward(hidden_states_cpu[token_idx]) * routing_weights_cpu[token_idx, expert_idx]\n        return outs\n    \n    @maybe_no_grad()\n    # TODO may bugs here\n    def moe_infer(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor, orig_shape: tuple) -> torch.Tensor:\n        \n        batch_size, sequence_length, hidden_dim = orig_shape\n\n        final_hidden_states = torch.zeros(\n            (batch_size * sequence_length, hidden_dim), dtype=hidden_states_cpu.dtype, device=hidden_states_cpu.device\n        )\n\n        # One hot encode the selected experts to create an expert mask\n        # this will be used to easily index which expert is going to be solicited\n        expert_mask = torch.nn.functional.one_hot(selected_experts_cpu, num_classes=self.num_experts).permute(2, 1, 0)\n\n        # Loop over all available experts in the model and perform the computation on each expert\n        for expert_idx in range(self.num_experts):\n            expert_layer = self.experts[expert_idx]\n            idx, top_x = torch.where(expert_mask[expert_idx])\n\n            # Index the correct hidden states and compute the expert hidden state for\n            # the current expert. 
We need to make sure to multiply the output hidden\n            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)\n            current_state = hidden_states_cpu[None, top_x].reshape(-1, hidden_dim)\n            current_hidden_states = expert_layer.forward(current_state) * routing_weights_cpu[top_x, idx, None]\n\n            # However `index_add_` only support torch tensors for indexing so we'll use\n            # the `top_x` tensor here.\n            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states_cpu.dtype))\n\n        return final_hidden_states\n\nclass KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):\n    def forward(self, hidden_states):\n        identity = hidden_states\n        orig_shape = hidden_states.shape\n        sequence_length = orig_shape[1]\n        topk_idx, topk_weight, aux_loss = self.gate(hidden_states)\n        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])\n        \n        if sequence_length == 1 and hasattr(self.experts.generate_experts, \"submit_for_one_decode\") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():\n            self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])\n            if self.config.n_shared_experts is not None:\n                y_ = self.shared_experts(identity).squeeze(0)\n            y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)\n            y += y_\n            y.resize_(*orig_shape)\n            return y\n\n        if self.config.n_shared_experts is not None:\n            y_ = self.shared_experts(identity).squeeze(0)\n            \n        if isinstance(self.experts, KExpertsBase):\n            y = self.moe_kexperts(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)\n        elif hidden_states.size(0) > 10:\n            # TODO may bugs here\n            y = (\n                self.moe_infer(hidden_states, 
topk_idx, topk_weight)\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            )\n        else:\n            # TODO may bugs here\n            y = (\n                self.moe_infer_simple(hidden_states, topk_idx, topk_weight)\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            )\n        if self.config.n_shared_experts is not None:\n            y += y_\n        return y\n\n    @maybe_no_grad()\n    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:\n        outs = self.experts(x, topk_ids, topk_weight)\n        return outs\n\n    @maybe_no_grad()\n    # TODO may bugs here\n    def moe_infer_simple(\n        self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor\n    ) -> torch.Tensor:\n        \"\"\"\n        x: [num_tokens, hidden_size]\n        topk_ids, topk_weight: [num_tokens, num_selected_experts]\n        \"\"\"\n        outs = torch.zeros_like(x)\n        for token_idx in range(topk_ids.size(0)):\n            for expert_idx in range(topk_ids.size(1)):\n                expert = self.experts[topk_ids[token_idx, expert_idx]]\n                outs[token_idx] += (\n                    expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]\n                )\n        return outs\n\n    @maybe_no_grad()\n    # TODO may bugs here\n    def moe_infer(self, x, topk_ids, topk_weight):\n        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))\n        cnts.scatter_(1, topk_ids, 1)\n        tokens_per_expert = cnts.sum(dim=0)\n        idxs = topk_ids.view(-1).argsort()\n        sorted_tokens = x[idxs // topk_ids.shape[1]]\n        tokens_per_expert = tokens_per_expert.cpu().numpy()\n\n        outputs = []\n        start_idx = 0\n        for i, num_tokens in enumerate(tokens_per_expert):\n            end_idx = start_idx + num_tokens\n            if num_tokens == 0:\n                
continue\n            expert = self.experts[i + self.ep_rank * self.experts_per_rank]\n            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n            expert_out = expert.forward(tokens_for_this_expert)\n            outputs.append(expert_out)\n            start_idx = end_idx\n\n        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n        new_x = torch.empty_like(outs)\n        new_x[idxs] = outs\n        final_out = (\n            new_x.view(*topk_ids.shape, -1)\n            .type(topk_weight.dtype)\n            .mul_(topk_weight.unsqueeze(dim=-1))\n            .sum(dim=1)\n            .type(new_x.dtype)\n        )\n        return final_out\n\nclass KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):\n    \n    def forward(self, hidden_states):\n        identity = hidden_states\n        orig_shape = hidden_states.shape\n        sequence_length = orig_shape[1]\n        topk_idx, topk_weight = self.gate(hidden_states)\n        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])\n        \n        # only for generate phase\n        if sequence_length == 1 and hasattr(self.experts.generate_experts, \"submit_for_one_decode\") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():\n            self.experts.generate_experts.submit_for_one_decode(hidden_states[0], topk_idx[0], topk_weight[0])\n            if self.config.n_shared_experts is not None:\n                y_ = self.shared_experts(identity).squeeze(0)\n            y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)\n            y += y_\n            y.resize_(*orig_shape)\n            return y\n\n        if self.config.n_shared_experts is not None:\n            y_ = self.shared_experts(identity).squeeze(0)\n            \n        if isinstance(self.experts, KExpertsBase):\n            y = self.moe_kexperts(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)\n        elif 
hidden_states.size(0) > 10:\n            # TODO may bugs here\n            y = (\n                self.moe_infer(hidden_states, topk_idx, topk_weight)\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            )\n        else:\n            # TODO may bugs here\n            y = (\n                self.moe_infer_simple(hidden_states, topk_idx, topk_weight)\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            )\n        if self.config.n_shared_experts is not None:\n            y += y_\n        return y\n\n    @maybe_no_grad()\n    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:\n        outs = self.experts(x, topk_ids, topk_weight)\n        return outs\n\n    @maybe_no_grad()\n    # TODO may bugs here\n    def moe_infer_simple(\n        self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor\n    ) -> torch.Tensor:\n        \"\"\"\n        x: [num_tokens, hidden_size]\n        topk_ids, topk_weight: [num_tokens, num_selected_experts]\n        \"\"\"\n        outs = torch.zeros_like(x)\n        for token_idx in range(topk_ids.size(0)):\n            for expert_idx in range(topk_ids.size(1)):\n                expert = self.experts[topk_ids[token_idx, expert_idx]]\n                outs[token_idx] += (\n                    expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]\n                )\n        return outs\n\n    @maybe_no_grad()\n    # TODO may bugs here\n    def moe_infer(self, x, topk_ids, topk_weight):\n        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))\n        cnts.scatter_(1, topk_ids, 1)\n        tokens_per_expert = cnts.sum(dim=0)\n        idxs = topk_ids.view(-1).argsort()\n        sorted_tokens = x[idxs // topk_ids.shape[1]]\n        tokens_per_expert = tokens_per_expert.cpu().numpy()\n\n        outputs = []\n        start_idx = 0\n        for i, num_tokens 
in enumerate(tokens_per_expert):\n            end_idx = start_idx + num_tokens\n            if num_tokens == 0:\n                continue\n            expert = self.experts[i + self.ep_rank * self.experts_per_rank]\n            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n            expert_out = expert.forward(tokens_for_this_expert)\n            outputs.append(expert_out)\n            start_idx = end_idx\n\n        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n        new_x = torch.empty_like(outs)\n        new_x[idxs] = outs\n        final_out = (\n            new_x.view(*topk_ids.shape, -1)\n            .type(topk_weight.dtype)\n            .mul_(topk_weight.unsqueeze(dim=-1))\n            .sum(dim=1)\n            .type(new_x.dtype)\n        )\n        return final_out\n\nclass KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock):\n    \n    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n        \"\"\" \"\"\"\n        orig_shape = hidden_states.shape\n        batch_size, sequence_length, hidden_dim = hidden_states.shape\n        if self.training and self.jitter_noise > 0:\n            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)\n        hidden_states = hidden_states.view(-1, hidden_dim)\n        # router_logits: (batch * sequence_length, n_experts)\n        router_logits = self.gate(hidden_states)\n\n        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)\n        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)\n        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)\n        # we cast back to the input dtype\n        routing_weights = routing_weights.to(hidden_states.dtype)\n        \n        if sequence_length == 1 and hasattr(self.experts.generate_experts, \"submit_for_one_decode\"):\n            
self.experts.generate_experts.submit_for_one_decode(hidden_states[0], selected_experts[0], routing_weights[0])\n            y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)\n            y.resize_(*orig_shape)\n            return y, router_logits\n        \n        hidden_states_expert = hidden_states.to(self.experts.device)  if isinstance(self.experts, KExpertsBase) else hidden_states.cpu()\n        selected_experts_expert = selected_experts.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else selected_experts.cpu()\n        routing_weights_expert = routing_weights.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else routing_weights.cpu()\n\n        if isinstance(self.experts, KExpertsBase):\n            y = (\n                self.moe_kexperts(\n                    hidden_states_expert, selected_experts_expert, routing_weights_expert\n                )\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            )\n        elif hidden_states_expert.size(0) > 10:\n            y = self.moe_infer(\n                hidden_states_expert, selected_experts_expert, routing_weights_expert, orig_shape\n            ).to(device=hidden_states.device)\n        else:\n            y = self.moe_infer_simple(\n                hidden_states_expert, selected_experts_expert, routing_weights_expert\n            ).to(device=hidden_states.device)\n            \n        y.resize_(*orig_shape)\n        return y, router_logits\n    \n    @maybe_no_grad()\n    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:\n        outs = self.experts(x, topk_ids, topk_weight)\n        return outs\n\n    @maybe_no_grad()\n    # TODO may bugs here\n    def moe_infer_simple(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:\n        '''\n        
hidden_states_cpu: [num_tokens, hidden_size]\n        selected_experts_cpu, routing_weights_cpu: [num_tokens, num_selected_experts]\n        '''\n        outs = torch.zeros_like(hidden_states_cpu)\n        for token_idx in range(selected_experts_cpu.size(0)):\n            for expert_idx in range(selected_experts_cpu.size(1)):\n                expert = self.experts[selected_experts_cpu[token_idx, expert_idx]]\n                outs[token_idx] += expert.forward(hidden_states_cpu[token_idx]) * routing_weights_cpu[token_idx, expert_idx]\n        return outs\n    \n    @maybe_no_grad()\n    # TODO may bugs here\n    def moe_infer(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor, orig_shape: tuple) -> torch.Tensor:\n        \n        batch_size, sequence_length, hidden_dim = orig_shape\n\n        final_hidden_states = torch.zeros(\n            (batch_size * sequence_length, hidden_dim), dtype=hidden_states_cpu.dtype, device=hidden_states_cpu.device\n        )\n\n        # One hot encode the selected experts to create an expert mask\n        # this will be used to easily index which expert is going to be solicited\n        expert_mask = torch.nn.functional.one_hot(selected_experts_cpu, num_classes=self.num_experts).permute(2, 1, 0)\n\n        # Loop over all available experts in the model and perform the computation on each expert\n        for expert_idx in range(self.num_experts):\n            expert_layer = self.experts[expert_idx]\n            idx, top_x = torch.where(expert_mask[expert_idx])\n\n            # Index the correct hidden states and compute the expert hidden state for\n            # the current expert. 
We need to make sure to multiply the output hidden\n            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)\n            current_state = hidden_states_cpu[None, top_x].reshape(-1, hidden_dim)\n            current_hidden_states = expert_layer.forward(current_state) * routing_weights_cpu[top_x, idx, None]\n\n            # However, `index_add_` only supports torch tensors for indexing, so we'll use\n            # the `top_x` tensor here.\n            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states_cpu.dtype))\n\n        return final_hidden_states\n\nclass KDeepseekV3MoEV2(BaseInjectedModule, DeepseekV3MoE):\n    def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0):\n        identity = hidden_states\n        orig_shape = hidden_states.shape\n        sequence_length = orig_shape[1]\n        topk_idx, topk_weight = self.gate(hidden_states)\n        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])\n        \n\n        # only for generate phase\n        if hasattr(self.experts.generate_experts, \"submit_for_one_decode\") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch causes a JIT bug\n            self.experts.generate_experts.submit_for_one_decode(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx)\n            if self.config.n_shared_experts is not None:\n                y_ = self.shared_experts(identity, bsz_tensor).squeeze(0)\n            y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)\n            # y_ only exists when shared experts are configured\n            if self.config.n_shared_experts is not None:\n                y += y_\n            y.resize_(*orig_shape)\n            return y\n\n        if self.config.n_shared_experts is not None:\n            y_ = self.shared_experts(identity, bsz_tensor).squeeze(0)\n            \n        if isinstance(self.experts, KExpertsBase):\n            y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight, bsz_tensor, 
cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)\n        elif hidden_states.size(0) > 10:\n            # TODO may bugs here\n            y = (\n                self.moe_infer(hidden_states, topk_idx, topk_weight)\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            )\n        else:\n            # TODO may bugs here\n            y = (\n                self.moe_infer_simple(hidden_states, topk_idx, topk_weight)\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            )\n        if self.config.n_shared_experts is not None:\n            y += y_\n        return y\n\n    @maybe_no_grad()\n    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:\n        outs = torch.empty_like(x)\n        outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)\n        return outs\n\n    @maybe_no_grad()\n    # TODO may bugs here\n    def moe_infer_simple(\n        self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor\n    ) -> torch.Tensor:\n        \"\"\"\n        x: [num_tokens, hidden_size]\n        topk_ids, topk_weight: [num_tokens, num_selected_experts]\n        \"\"\"\n        outs = torch.zeros_like(x)\n        for token_idx in range(topk_ids.size(0)):\n            for expert_idx in range(topk_ids.size(1)):\n                expert = self.experts[topk_ids[token_idx, expert_idx]]\n                outs[token_idx] += (\n                    expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]\n                )\n        return outs\n\n    @maybe_no_grad()\n    # TODO may bugs here\n    def moe_infer(self, x, topk_ids, topk_weight):\n        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))\n        cnts.scatter_(1, topk_ids, 1)\n        tokens_per_expert = cnts.sum(dim=0)\n        idxs = topk_ids.view(-1).argsort()\n        
sorted_tokens = x[idxs // topk_ids.shape[1]]\n        tokens_per_expert = tokens_per_expert.cpu().numpy()\n\n        outputs = []\n        start_idx = 0\n        for i, num_tokens in enumerate(tokens_per_expert):\n            end_idx = start_idx + num_tokens\n            if num_tokens == 0:\n                continue\n            expert = self.experts[i + self.ep_rank * self.experts_per_rank]\n            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n            expert_out = expert.forward(tokens_for_this_expert)\n            outputs.append(expert_out)\n            start_idx = end_idx\n\n        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n        new_x = torch.empty_like(outs)\n        new_x[idxs] = outs\n        final_out = (\n            new_x.view(*topk_ids.shape, -1)\n            .type(topk_weight.dtype)\n            .mul_(topk_weight.unsqueeze(dim=-1))\n            .sum(dim=1)\n            .type(new_x.dtype)\n        )\n        return final_out\n\nclass KTransformersExpertsV2(BaseInjectedModule, KExpertsBase):\n    def __init__(self,\n                 key: str,\n                 gguf_loader: GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                #  device: str = \"cuda\",\n                 prefill_device:str = \"cuda\",\n                 prefill_op: str | None = \"KExpertsTorch\",\n                 generate_device: str = \"cpu\",\n                 generate_op: str | None = \"KExpertsCPU\",\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)\n        KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)\n        if generate_op is not None:\n            self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)\n        else:\n            self.generate_experts = 
None\n        if prefill_op is not None:\n            self.prefill_experts = EXPERTS_MAP[prefill_op](key, gguf_loader, config, len(orig_module), device=prefill_device, **kwargs)\n        else:\n            self.prefill_experts = None\n        self.gpu_mlp_type = prefill_op\n        self.cpu_mlp_type = generate_op\n        self.mode = InferenceState.UNLOAD\n\n    def load(self, w: dict = None,  mode: InferenceState = None, warmup: bool = True):\n        # TODO support w as input\n        if not mode: mode = InferenceState.GENERATE\n        if mode == InferenceState.GENERATE:\n            self.prefill_experts.unload()\n            self.generate_experts.load(w, warmup=warmup)\n            self.device = self.generate_experts.device\n            self.mode = mode\n        elif mode == InferenceState.PREFILL:\n            self.generate_experts.unload()\n            self.prefill_experts.load(w, warmup=warmup)\n            self.device = self.prefill_experts.device\n            self.mode = mode\n        elif mode == InferenceState.UNLOAD:\n            self.unload()\n            self.mode = mode\n            self.device = self.generate_experts.device\n        else:\n            raise ValueError(\"mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD\")\n\n    def unload(self):\n        if self.generate_experts is not None:\n            self.generate_experts.unload()\n        if self.prefill_experts is not None:\n            self.prefill_experts.unload()\n        self.device = self.generate_experts.device\n\n    def forward(self, input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx=0):\n        if self.mode == InferenceState.GENERATE:\n            assert self.generate_experts is not None, \"generate_experts is None\"\n            return self.generate_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)\n        elif self.mode == InferenceState.PREFILL:\n            assert self.prefill_experts is not 
None, \"prefill_experts is None\"\n            return self.prefill_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)\n        else:\n            raise ValueError(\"load or set_inference_mode before forward\")\n\n    def set_inference_mode(self, mode: InferenceState):\n        if mode == InferenceState.GENERATE:\n            self.load(mode=InferenceState.GENERATE, warmup=False)\n        elif mode == InferenceState.PREFILL:\n            self.load(mode=InferenceState.PREFILL, warmup=False)\n        elif mode == InferenceState.UNLOAD:\n            self.unload()\n        else:\n            raise ValueError(\"mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD\")\n\nclass KQwen2MoeSparseMoeBlockV2(BaseInjectedModule, Qwen2MoeSparseMoeBlock):\n    def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0):\n\n        orig_shape = hidden_states.shape\n        sequence_length = orig_shape[1]\n\n        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])\n\n        router_logits = self.gate(hidden_states, bsz_tensor)        \n\n        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)\n        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)\n        if self.norm_topk_prob:\n            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)\n        # we cast back to the input dtype\n        routing_weights = routing_weights.to(hidden_states.dtype)\n\n        # only for generate phase\n        if hasattr(self.experts.generate_experts, \"submit_for_one_decode\") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug\n            self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)\n            y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)\n            y_ = 
F.sigmoid(self.shared_expert_gate(hidden_states)) * y_    \n\n            y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)\n            \n            y += y_\n            y.resize_(*orig_shape)\n            return y\n\n        y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)\n        y_ = (\n            F.sigmoid(self.shared_expert_gate(hidden_states)) * y_\n        )\n\n\n        if isinstance(self.experts, KExpertsBase):\n            y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)\n        elif hidden_states.size(0) > 10:\n            # TODO may bugs here\n            y = (\n                self.moe_infer(hidden_states, selected_experts, routing_weights)\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            )\n        else:\n            # TODO may bugs here\n            y = (\n                self.moe_infer_simple(hidden_states, selected_experts, routing_weights)\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            ) \n        y += y_\n        return y\n\n    @maybe_no_grad()\n    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:\n        outs = torch.empty_like(x)\n        outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)\n        return outs\n\n    @maybe_no_grad()\n    # TODO may bugs here\n    def moe_infer_simple(\n        self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor\n    ) -> torch.Tensor:\n        \"\"\"\n        x: [num_tokens, hidden_size]\n        topk_ids, topk_weight: [num_tokens, num_selected_experts]\n        \"\"\"\n        outs = torch.zeros_like(x)\n        for token_idx in range(topk_ids.size(0)):\n            for expert_idx in range(topk_ids.size(1)):\n              
  expert = self.experts[topk_ids[token_idx, expert_idx]]\n                outs[token_idx] += (\n                    expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]\n                )\n        return outs\n\n    @maybe_no_grad()\n    # TODO may bugs here\n    def moe_infer(self, x, topk_ids, topk_weight):\n        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))\n        cnts.scatter_(1, topk_ids, 1)\n        tokens_per_expert = cnts.sum(dim=0)\n        idxs = topk_ids.view(-1).argsort()\n        sorted_tokens = x[idxs // topk_ids.shape[1]]\n        tokens_per_expert = tokens_per_expert.cpu().numpy()\n\n        outputs = []\n        start_idx = 0\n        for i, num_tokens in enumerate(tokens_per_expert):\n            end_idx = start_idx + num_tokens\n            if num_tokens == 0:\n                continue\n            expert = self.experts[i + self.ep_rank * self.experts_per_rank]\n            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n            expert_out = expert.forward(tokens_for_this_expert)\n            outputs.append(expert_out)\n            start_idx = end_idx\n\n        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n        new_x = torch.empty_like(outs)\n        new_x[idxs] = outs\n        final_out = (\n            new_x.view(*topk_ids.shape, -1)\n            .type(topk_weight.dtype)\n            .mul_(topk_weight.unsqueeze(dim=-1))\n            .sum(dim=1)\n            .type(new_x.dtype)\n        )\n        return final_out\n\nclass KQwen3MoeSparseMoeBlockV2(BaseInjectedModule, Qwen3MoeSparseMoeBlock):\n    def forward(self, hidden_states, bsz_tensor=None, cuda_graph_idx=0):\n\n        orig_shape = hidden_states.shape\n        sequence_length = orig_shape[1]\n\n        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])\n\n        if bsz_tensor is None:\n            router_logits = self.gate(hidden_states)\n        else:\n            router_logits = 
self.gate(hidden_states, bsz_tensor)\n\n        if router_logits.device.type == \"xpu\":\n            from ipex_llm.transformers.models.common import moe_softmax_topk\n            selected_experts, routing_weights = moe_softmax_topk(\n                router_logits.half(), self.top_k, self.norm_topk_prob\n            )\n        else:\n            routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)\n            routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)\n            if self.norm_topk_prob:\n                routing_weights /= routing_weights.sum(dim=-1, keepdim=True)\n        # we cast back to the input dtype\n        routing_weights = routing_weights.to(hidden_states.dtype)\n\n        # only for generate phase\n        if hasattr(self.experts.generate_experts, \"submit_for_one_decode\") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): # TODO: this branch cause jit bug\n            self.experts.generate_experts.submit_for_one_decode(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx)\n            # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)\n            # y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_    \n\n            y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)\n            \n            # y += y_\n            y.resize_(*orig_shape)\n            return y\n\n        # y_ = self.shared_expert(hidden_states, bsz_tensor).squeeze(0)\n        # y_ = (\n        #     F.sigmoid(self.shared_expert_gate(hidden_states)) * y_\n        # )\n\n\n        if isinstance(self.experts, KExpertsBase):\n            y = self.moe_on_cpuinfer(hidden_states, selected_experts, routing_weights, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)\n        elif hidden_states.size(0) > 10:\n            # TODO may bugs here\n            y = (\n                
self.moe_infer(hidden_states, selected_experts, routing_weights)\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            )\n        else:\n            # TODO may bugs here\n            y = (\n                self.moe_infer_simple(hidden_states, selected_experts, routing_weights)\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            ) \n        # y += y_\n        return y\n\n    @maybe_no_grad()\n    def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:\n        outs = torch.empty_like(x)\n        outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)\n        return outs\n\n    @maybe_no_grad()\n    # TODO may bugs here\n    def moe_infer_simple(\n        self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor\n    ) -> torch.Tensor:\n        \"\"\"\n        x: [num_tokens, hidden_size]\n        topk_ids, topk_weight: [num_tokens, num_selected_experts]\n        \"\"\"\n        outs = torch.zeros_like(x)\n        for token_idx in range(topk_ids.size(0)):\n            for expert_idx in range(topk_ids.size(1)):\n                expert = self.experts[topk_ids[token_idx, expert_idx]]\n                outs[token_idx] += (\n                    expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]\n                )\n        return outs\n\n    @maybe_no_grad()\n    # TODO may bugs here\n    def moe_infer(self, x, topk_ids, topk_weight):\n        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))\n        cnts.scatter_(1, topk_ids, 1)\n        tokens_per_expert = cnts.sum(dim=0)\n        idxs = topk_ids.view(-1).argsort()\n        sorted_tokens = x[idxs // topk_ids.shape[1]]\n        tokens_per_expert = tokens_per_expert.cpu().numpy()\n\n        outputs = []\n        start_idx = 0\n        for i, num_tokens in 
enumerate(tokens_per_expert):\n            end_idx = start_idx + num_tokens\n            if num_tokens == 0:\n                continue\n            expert = self.experts[i + self.ep_rank * self.experts_per_rank]\n            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n            expert_out = expert.forward(tokens_for_this_expert)\n            outputs.append(expert_out)\n            start_idx = end_idx\n\n        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n        new_x = torch.empty_like(outs)\n        new_x[idxs] = outs\n        final_out = (\n            new_x.view(*topk_ids.shape, -1)\n            .type(topk_weight.dtype)\n            .mul_(topk_weight.unsqueeze(dim=-1))\n            .sum(dim=1)\n            .type(new_x.dtype)\n        )\n        return final_out\n\n\nclass KQwen3MoeSparseMoeBlock(BaseInjectedModule, Qwen3MoeSparseMoeBlock):\n    def forward(self, hidden_states):\n\n        orig_shape = hidden_states.shape\n        sequence_length = orig_shape[1]\n\n        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])\n\n        router_logits = self.gate(hidden_states)\n\n        if router_logits.device.type == \"xpu\":\n            from ipex_llm.transformers.models.common import moe_softmax_topk\n            selected_experts, routing_weights = moe_softmax_topk(\n                router_logits.half(), self.top_k, self.norm_topk_prob\n            )\n        else:\n            routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)\n            routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)\n            if self.norm_topk_prob:\n                routing_weights /= routing_weights.sum(dim=-1, keepdim=True)\n        # we cast back to the input dtype\n        routing_weights = routing_weights.to(hidden_states.dtype)\n\n        # only for generate phase\n        if sequence_length == 1 and hasattr(self.experts.generate_experts,\n  
                                          \"submit_for_one_decode\") and torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():  # TODO: this branch cause jit bug\n            self.experts.generate_experts.submit_for_one_decode(hidden_states[0], selected_experts[0],\n                                                                routing_weights[0])\n            # y_ = self.shared_expert(hidden_states).squeeze(0)\n            # y_ = F.sigmoid(self.shared_expert_gate(hidden_states)) * y_\n\n            y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)\n\n            # y += y_\n            y.resize_(*orig_shape)\n            return y\n\n        # y_ = self.shared_expert(hidden_states).squeeze(0)\n        # y_ = (\n        #     F.sigmoid(self.shared_expert_gate(hidden_states)) * y_\n        # )\n\n        if isinstance(self.experts, KExpertsBase):\n            y = self.moe_kexperts(hidden_states, selected_experts, routing_weights).view(*orig_shape).to(\n                device=hidden_states.device)\n        elif hidden_states.size(0) > 10:\n            # TODO may bugs here\n            y = (\n                self.moe_infer(hidden_states, selected_experts, routing_weights)\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            )\n        else:\n            # TODO may bugs here\n            y = (\n                self.moe_infer_simple(hidden_states, selected_experts, routing_weights)\n                .view(*orig_shape)\n                .to(device=hidden_states.device)\n            )\n            # y += y_\n        return y\n\n    @maybe_no_grad()\n    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:\n        outs = self.experts(x, topk_ids, topk_weight)\n        return outs\n\n    @maybe_no_grad()\n    # TODO may bugs here\n    def moe_infer_simple(\n            self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: 
torch.Tensor\n    ) -> torch.Tensor:\n        \"\"\"\n        x: [num_tokens, hidden_size]\n        topk_ids, topk_weight: [num_tokens, num_selected_experts]\n        \"\"\"\n        outs = torch.zeros_like(x)\n        for token_idx in range(topk_ids.size(0)):\n            for expert_idx in range(topk_ids.size(1)):\n                expert = self.experts[topk_ids[token_idx, expert_idx]]\n                outs[token_idx] += (\n                        expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]\n                )\n        return outs\n\n    @maybe_no_grad()\n    # TODO may bugs here\n    def moe_infer(self, x, topk_ids, topk_weight):\n        cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))\n        cnts.scatter_(1, topk_ids, 1)\n        tokens_per_expert = cnts.sum(dim=0)\n        idxs = topk_ids.view(-1).argsort()\n        sorted_tokens = x[idxs // topk_ids.shape[1]]\n        tokens_per_expert = tokens_per_expert.cpu().numpy()\n\n        outputs = []\n        start_idx = 0\n        for i, num_tokens in enumerate(tokens_per_expert):\n            end_idx = start_idx + num_tokens\n            if num_tokens == 0:\n                continue\n            expert = self.experts[i + self.ep_rank * self.experts_per_rank]\n            tokens_for_this_expert = sorted_tokens[start_idx:end_idx]\n            expert_out = expert.forward(tokens_for_this_expert)\n            outputs.append(expert_out)\n            start_idx = end_idx\n\n        outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)\n\n        new_x = torch.empty_like(outs)\n        new_x[idxs] = outs\n        final_out = (\n            new_x.view(*topk_ids.shape, -1)\n            .type(topk_weight.dtype)\n            .mul_(topk_weight.unsqueeze(dim=-1))\n            .sum(dim=1)\n            .type(new_x.dtype)\n        )\n        return final_out\n"
  },
  {
    "path": "kt-sft/ktransformers/operators/flashinfer_batch_prefill_wrapper.py",
    "content": "import torch\nimport flashinfer\nimport gc\ntry:\n    from flash_attn import flash_attn_with_kvcache\n    print(\"found flash_attn\")\n    \nexcept ImportError:\n    print(\"flash_attn not found, flashinfer unit test needed it. If you are using balance serve, ignore this.\")\n\nfrom typing import Union, Optional\n\ndef setup_seed(seed):\n\ttorch.manual_seed(seed)\n\ttorch.cuda.manual_seed_all(seed)\n\nsetup_seed(998244353)\n\ntorch.set_grad_enabled(False)\ntorch.set_default_dtype(torch.bfloat16)\nglobal_dtype=torch.bfloat16\nglobal_device=torch.device(\"cuda\",0)\ntorch.cuda.set_device(0)\ntorch.backends.cudnn.enabled =True\ntorch.backends.cudnn.benchmark = True\n\nclass flashInferAttn():\n\t\n\tfloat_workspace_buffer = None\n\tdef __init__(self,\n\t\t\tmax_batch_token,\n\t\t\tmax_batch_size,\n\t\t\tmax_pages,\n\t\t\tdevice = \"cuda:0\",\n\t\t\tkv_layout: str = \"NHD\",\n\t\t\tuse_cuda_graph: bool = False,\n\t\t\t) -> None:\n\t\tself.device = device\n\t\tself.max_batch_token = max_batch_token\n\t\tself.kv_layout = kv_layout\n\t\tself.use_cuda_graph = use_cuda_graph\n\t\tif flashInferAttn.float_workspace_buffer is None:\n\t\t\tflashInferAttn.float_workspace_buffer = torch.empty(max_batch_token * 1024 * 1024, dtype=torch.uint8, device=device)\n\t\tself.qo_indptr_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device)\n\t\tself.paged_kv_indptr_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device)\n\t\tself.paged_kv_indices_buf = torch.empty((max_pages,), dtype=torch.int32, device=device)\n\t\tself.paged_kv_last_page_len_buf = torch.empty((max_batch_size,), dtype=torch.int32, device=device)\n\t\tself.batch_size_tensor_buf = torch.empty((1,), dtype=torch.int32, device=device)\n\t\tself.num_tokens_tensor_buf = torch.empty((1,), dtype=torch.uint32, device=device)\n\t\n\t\t# TODO: custom mask\n\t\tself.custom_mask_buf = None\n\t\tself.qk_indptr_buf = None\n\t\tself.warpper = 
flashinfer.BatchPrefillWithPagedKVCacheWrapper(\n\t\t\tflashInferAttn.float_workspace_buffer,\n\t\t\tself.kv_layout,\n\t\t\tuse_cuda_graph=self.use_cuda_graph,\n\t\t\tqo_indptr_buf=self.qo_indptr_buf,\n\t\t\tpaged_kv_indptr_buf=self.paged_kv_indptr_buf,\n\t\t\tpaged_kv_indices_buf=self.paged_kv_indices_buf,\n\t\t\tpaged_kv_last_page_len_buf=self.paged_kv_last_page_len_buf,\n\t\t\tbackend = \"fa2\",\n\t\t)\n\n\tdef plan(self,\n\t\tqo_indptr: torch.Tensor,\n\t\tpaged_kv_indptr: torch.Tensor,\n\t\tpaged_kv_indices: torch.Tensor,\n\t\tpaged_kv_last_page_len: torch.Tensor,\n\t\tbatch_size_tensor: torch.Tensor,\n\t\tnum_tokens_tensor: torch.Tensor,\n\t\tnum_qo_heads: int,\n\t\tnum_kv_heads: int,\n\t\thead_dim: int,\n\t\tpage_size: int,\n\t\tcausal: bool = True, \n\t\tpos_encoding_mode: str = \"NONE\",\n\t\tq_data_type: Union[str, torch.dtype] = torch.bfloat16,\n\t\tkv_data_type: Optional[Union[str, torch.dtype]] = None):\n\t\t\n\t\tself.batch_size_tensor_buf.copy_(batch_size_tensor, non_blocking=True)\n\t\tself.num_tokens_tensor_buf.copy_(num_tokens_tensor, non_blocking=True)\n\t\tself.page_size = page_size\n\t\tself.warpper.plan(\n\t\t\tqo_indptr,\n\t\t\tpaged_kv_indptr,\n\t\t\tpaged_kv_indices,\n\t\t\tpaged_kv_last_page_len,\n\t\t\tnum_qo_heads,\n\t\t\tnum_kv_heads,\n\t\t\thead_dim,\n\t\t\tpage_size,\n\t\t\tcausal = causal,\n\t\t\tpos_encoding_mode = pos_encoding_mode,\n\t\t\tq_data_type = q_data_type,\n\t\t\tkv_data_type = kv_data_type\n\t\t\t)\n\n\tdef calc_batch_indices(self, ragged_size = None):\n\t\tif self.use_cuda_graph:\n\t\t\tself.batch_indices, self.positions = flashinfer.get_batch_indices_positions(\n\t\t\t\tself.qo_indptr_buf, flashinfer.get_seq_lens(self.paged_kv_indptr_buf, self.paged_kv_last_page_len_buf, self.page_size), self.batch_size_tensor_buf, self.max_batch_token)\n\t\telse:\n\t\t\tself.batch_indices, self.positions = flashinfer.get_batch_indices_positions(\n\t\t\t\tself.warpper._qo_indptr_buf, 
flashinfer.get_seq_lens(self.warpper._paged_kv_indptr_buf, self.warpper._paged_kv_last_page_len_buf, self.page_size), self.batch_size_tensor_buf, ragged_size)\n\n\tdef forward(self, q, k_cache, v_cache, k, v):\n\t\tif self.use_cuda_graph:\n\t\t\tflashinfer.page.append_paged_kv_cache(k, v, self.batch_indices, self.positions, (k_cache, v_cache), self.paged_kv_indices_buf, self.paged_kv_indptr_buf, self.paged_kv_last_page_len_buf, self.num_tokens_tensor_buf)\n\t\t\treturn self.warpper.run(q, (k_cache, v_cache))\n\t\telse:\n\t\t\tflashinfer.page.append_paged_kv_cache(k, v, self.batch_indices, self.positions, (k_cache, v_cache), self.warpper._paged_kv_indices_buf, self.warpper._paged_kv_indptr_buf, self.warpper._paged_kv_last_page_len_buf, self.num_tokens_tensor_buf)\n\t\t\treturn self.warpper.run(q, (k_cache, v_cache))\n\n\ndef testCudaGraph():\n\t\n\t# use max batch to create buffer\n\tbatch_decode = 8\n\tprefill_chunk = 48\n\tpast_kv_0 = 4090\n\tpast_kv_1 = 4096\n\traged_size = prefill_chunk + batch_decode\n\tnum_key_value_heads = 8\n\thead_dim = 128\n\tnum_attention_heads = 64\n\tpage_size = 256\n\tnum_pages_per_seq = (past_kv_1 + page_size - 1) // page_size\n\ttotal_num_pages = (num_pages_per_seq + 1) * (batch_decode + 1) + prefill_chunk // page_size\n\tattn = flashInferAttn(raged_size, batch_decode+1, total_num_pages, use_cuda_graph=True)\n\n\tbatch_size_tensor = torch.tensor([batch_decode + 1], device=global_device, dtype=torch.int32)\n\t\n\tk_caches = []\t\n\tv_caches = []\n\tks = []\n\tvs = []\n\tqs = []\n\tfor layer_idx in range(3):\n\t\tk_caches.append(torch.randn(total_num_pages, page_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16))\n\t\tv_caches.append(torch.randn(total_num_pages, page_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16))\n\t\tks.append(torch.randn(raged_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16))\n\t\tvs.append(torch.randn(raged_size, 
num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16))\n\t\tqs.append(torch.randn(raged_size, num_attention_heads, head_dim, device=global_device, dtype=torch.bfloat16))\n\t\n\t# warmup and capture small batch\n\tpast_kv_0 = 250\n\tpast_kv_1 = 256\n\tnum_pages_per_seq = (past_kv_1 + page_size - 1) // page_size\n\ttotal_num_pages = (num_pages_per_seq + 1) * (batch_decode + 1) + prefill_chunk // page_size\n\tq_indptr = torch.empty((batch_decode + 2,), dtype=torch.int32, device=global_device)\n\tq_indptr[0] = 0\n\tq_indptr[1:] = torch.arange(prefill_chunk, prefill_chunk + batch_decode + 1, device=global_device, dtype=torch.int32)\n\tkv_indptr = torch.arange(0, batch_decode + 2, device=global_device, dtype=torch.int32) * num_pages_per_seq\n\tkv_indices = torch.arange(0, total_num_pages, device=global_device, dtype=torch.int32)\n\tkv_last_page_len = torch.empty((batch_decode + 1,), dtype=torch.int32, device=global_device)\n\tkv_last_page_len[:1+batch_decode//2] = int((past_kv_0 - 1) % page_size + 1)\n\tkv_last_page_len[1+batch_decode//2:] = int((past_kv_1 - 1) % page_size + 1)\n\t# plan() takes num_tokens_tensor right after batch_size_tensor\n\tnum_tokens_tensor = torch.tensor([batch_decode + prefill_chunk], device=global_device, dtype=torch.int32)\n\n\tprint(q_indptr)\n\tprint(kv_indptr)\n\tprint(kv_indices)\n\tprint(kv_last_page_len)\n\tattn.plan(q_indptr,\n\t\t\tkv_indptr,\n\t\t\tkv_indices,\n\t\t\tkv_last_page_len,\n\t\t\tbatch_size_tensor,\n\t\t\tnum_tokens_tensor,\n\t\t\tnum_attention_heads,\n\t\t\tnum_key_value_heads,\n\t\t\thead_dim,\n\t\t\tpage_size,\n\t\t\tcausal = True,\n\t\t\tpos_encoding_mode=\"NONE\",\n\t\t\tq_data_type=torch.bfloat16)\n\n\tattn.calc_batch_indices(raged_size)\n\tfor layer_idx in range(3):\n\t\tattn.forward(qs[layer_idx], k_caches[layer_idx], v_caches[layer_idx], ks[layer_idx], vs[layer_idx])\n\t\ttorch.cuda.synchronize()\n\n\touts = []\n\tg = torch.cuda.CUDAGraph()\n\twith torch.cuda.graph(g):\n\t\tfor layer_idx in range(3):\n\t\t\touts.append(attn.forward(qs[layer_idx], k_caches[layer_idx], v_caches[layer_idx], ks[layer_idx], vs[layer_idx]))\n\tg.replay()\n\t\n\tkv_last_page_len[:1+batch_decode//2] = 
int(past_kv_0)\n\tkv_last_page_len[1+batch_decode//2:] = int(past_kv_1)\n\tfor layer_idx in range(3):\n\t\tfor i in range(batch_decode + 1):\n\t\t\t\n\t\t\tqi = qs[layer_idx][q_indptr[i] : q_indptr[i + 1]]\n\t\t\to_ref_i = flash_attn_with_kvcache(\n\t\t\t\tqi.unsqueeze(0),\n\t\t\t\tk_caches[layer_idx],\n\t\t\t\tv_caches[layer_idx],\n\t\t\t\tcausal=True,\n\t\t\t\tblock_table=kv_indices[kv_indptr[i]:kv_indptr[i+1]].unsqueeze(0),\n\t\t\t\tcache_seqlens=torch.tensor([past_kv_0 if i < 1+batch_decode//2 else past_kv_1], device=global_device, dtype=torch.int32)\n\t\t\t)\n\t\t\to_i = outs[layer_idx][q_indptr[i] : q_indptr[i + 1]]\n\t\t\tprint(layer_idx, i)\n\t\t\ttorch.testing.assert_close(o_i.unsqueeze(0), o_ref_i, rtol=5e-3, atol=5e-3)\n\n\t# run another batch size use capture cuda graph\n\tpast_kv_0 = 4090\n\tpast_kv_1 = 4096\n\tprefill_chunk = 24\n\tbatch_decode = 4\n\tnum_pages_per_seq = (past_kv_1 + page_size - 1) // page_size\n\ttotal_num_pages = (num_pages_per_seq + 1) * (batch_decode + 1) + prefill_chunk // page_size\n\tbatch_size_tensor = torch.tensor([batch_decode + 1], device=global_device, dtype=torch.int32)\n\tnum_tokens_tensor = torch.tensor([batch_decode + prefill_chunk], device=global_device, dtype=torch.int32)\n\n\tq_indptr = torch.empty((batch_decode + 2,), dtype=torch.int32, device=global_device)\n\tq_indptr[0] = 0\n\tq_indptr[1:] = torch.arange(prefill_chunk, prefill_chunk + batch_decode + 1, device=global_device, dtype=torch.int32)\n\tkv_indptr = torch.arange(0, batch_decode + 2, device=global_device, dtype=torch.int32) * num_pages_per_seq\n\tkv_indices = torch.arange(0, total_num_pages, device=global_device, dtype=torch.int32)\n\tkv_last_page_len = torch.empty((batch_decode + 1,), dtype=torch.int32, device=global_device)\n\tkv_last_page_len[:1+batch_decode//2] = int((past_kv_0 - 1) % page_size + 1)\n\tkv_last_page_len[1+batch_decode//2:] = int((past_kv_1 - 1) % page_size + 
1)\n\tattn.plan(q_indptr,\n\t\t\tkv_indptr,\n\t\t\tkv_indices,\n\t\t\tkv_last_page_len,\n\t\t\tbatch_size_tensor,\n\t\t\tnum_attention_heads,\n\t\t\tnum_key_value_heads,\n\t\t\thead_dim,\n\t\t\tpage_size,\n\t\t\tcausal = True,\n\t\t\tpos_encoding_mode=\"NONE\",\n\t\t\tq_data_type=torch.bfloat16)\n\tattn.calc_batch_indices(raged_size)\n\tg.replay()\n\t\n\tkv_last_page_len[:1+batch_decode//2] = int(past_kv_0)\n\tkv_last_page_len[1+batch_decode//2:] = int(past_kv_1)\n\tfor layer_idx in range(3):\n\t\tfor i in range(batch_decode + 1):\n\t\t\t\n\t\t\tqi = qs[layer_idx][q_indptr[i] : q_indptr[i + 1]]\n\t\t\to_ref_i = flash_attn_with_kvcache(\n\t\t\t\tqi.unsqueeze(0),\n\t\t\t\tk_caches[layer_idx],\n\t\t\t\tv_caches[layer_idx],\n\t\t\t\tcausal=True,\n\t\t\t\tblock_table=kv_indices[kv_indptr[i]:kv_indptr[i+1]].unsqueeze(0),\n\t\t\t\tcache_seqlens=torch.tensor([past_kv_0 if i < 1+batch_decode//2 else past_kv_1], device=global_device, dtype=torch.int32)\n\t\t\t)\n\t\t\to_i = outs[layer_idx][q_indptr[i] : q_indptr[i + 1]]\n\t\t\tprint(layer_idx, i)\n\t\t\ttorch.testing.assert_close(o_i.unsqueeze(0), o_ref_i, rtol=5e-3, atol=5e-3)\n\t\t\t\n\n\ndef testAttentionFlashInfer(\t\n\t):\n\tbatch_decode = 32\n\tprefill_chunk = 64\n\tpast_kv_0 = 510\n\tpast_kv_1 = 512\n\traged_size = prefill_chunk + batch_decode\n\tnum_key_value_heads = 8\n\thead_dim = 128\n\tnum_attention_heads = 64\n\tcases = 1\n\tpage_size = 32\n\tnum_pages_per_seq = (past_kv_1 + page_size - 1) // page_size\n\ttotal_num_pages = (num_pages_per_seq + 1) * (batch_decode + 1) + prefill_chunk // page_size\n\tworkspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=\"cuda:0\")\n\tqs = []\n\tkvs = []\n\tq_indptrs = []\n\tkv_indptrs = []\n\tkv_indicess = []\n\tkv_last_page_lens = []\n\twrappers = []\n\tfor case_id in range(cases):\n\t\tkvs.append(torch.randn(total_num_pages, 2, page_size, num_key_value_heads, head_dim, device=global_device, dtype=torch.bfloat16))\n\t\tqs.append(torch.randn(raged_size, 
num_attention_heads, head_dim, device=global_device, dtype=torch.bfloat16))\n\t\tq_indptr = torch.empty((batch_decode + 2,), dtype=torch.int32, device=global_device)\n\t\tq_indptr[0] = 0\n\t\tq_indptr[1:] = torch.arange(prefill_chunk, prefill_chunk + batch_decode + 1, device=global_device, dtype=torch.int32)\n\t\tq_indptrs.append(q_indptr)\n\t\tkv_indptrs.append(torch.arange(0, batch_decode + 2, device=global_device, dtype=torch.int32) * num_pages_per_seq)\n\t\tkv_indicess.append(torch.arange(0, total_num_pages, device=global_device, dtype=torch.int32))\n\t\tkv_last_page_len = torch.empty((batch_decode + 1,), dtype=torch.int32, device=global_device)\n\t\tkv_last_page_len[:1+batch_decode//2] = int((past_kv_0 - 1) % page_size + 1)\n\t\tkv_last_page_len[1+batch_decode//2:] = int((past_kv_1 - 1) % page_size + 1)\n\t\tkv_last_page_lens.append(kv_last_page_len)\n\t\twrappers.append(flashinfer.BatchPrefillWithPagedKVCacheWrapper(\n\t\t\tworkspace_buffer,\n\t\t\t\"NHD\",\n\t\t\tuse_cuda_graph=True,\n\t\t\tqo_indptr_buf=q_indptrs[case_id],\n\t\t\tpaged_kv_indptr_buf=kv_indptrs[case_id],\n\t\t\tpaged_kv_indices_buf=kv_indicess[case_id],\n\t\t\tpaged_kv_last_page_len_buf=kv_last_page_lens[case_id],\n\t\t))\n\t\twrappers[case_id].plan(\n\t\t\tq_indptrs[case_id],\n\t\t\tkv_indptrs[case_id],\n\t\t\tkv_indicess[case_id],\n\t\t\tkv_last_page_lens[case_id],\n\t\t\tnum_attention_heads,\n\t\t\tnum_key_value_heads,\n\t\t\thead_dim,\n\t\t\tpage_size,\n\t\t\tcausal = True,\n\t\t\tpos_encoding_mode=\"ROPE_LLAMA\",\n\t\t\tq_data_type=torch.bfloat16\n\t\t)\n\t\t\t\t\t\n\tdef custom_forward(case_id):\n\t\tout = wrappers[case_id].run(qs[case_id], kvs[case_id])\n\t\n\tcustom_forward(0)\n\n# testCudaGraph()\n# pass"
  },
  {
    "path": "kt-sft/ktransformers/operators/flashinfer_wrapper.py",
    "content": "'''\nDescription  : flashinfer MLA wrapper\nAuthor       : Boxin Zhang\nVersion      : 0.2.3\n'''\nimport torch\nimport os\nfrom ktransformers.operators.triton_attention import decode_attention_fwd_grouped\n\nflashinfer_enabled = False\n\ntry:\n    import flashinfer\n    flashinfer_enabled = True\n    print(\"found flashinfer\")\n    \nexcept ImportError:\n    print(\"flashinfer not found, use triton for linux\")\n\nimport math\n\ndef attention_ref_torch(\n    batch_size,\n    q: torch.Tensor,\n    k: torch.Tensor,\n    v: torch.Tensor,\n    causal: bool,\n    sm_scale: float,\n) -> torch.Tensor:\n    qo_len = q.shape[0] // batch_size\n    kv_len = k.shape[0] // batch_size\n    num_qo_heads = q.shape[1]\n    head_dim_qk = q.shape[2]\n    head_dim_vo = v.shape[2]\n    logits = (\n        torch.einsum(\n            \"bmhd,bnhd->bhmn\",\n            q.view(batch_size, qo_len, num_qo_heads, head_dim_qk).float(),\n            k.view(batch_size, kv_len, num_qo_heads, head_dim_qk).float(),\n        )\n        * sm_scale\n    )\n\n    #print(\"attn weights\", logits)\n\n    if causal:\n        mask = (\n            torch.arange(kv_len - qo_len, kv_len).unsqueeze(1)\n            >= torch.arange(0, kv_len).unsqueeze(0)\n        ).to(q.device)\n    else:\n        mask = torch.ones(qo_len, kv_len).to(q.device)\n\n    logits = logits.masked_fill(mask.unsqueeze(0).unsqueeze(0) == 0, float(\"-inf\"))\n    lse_ref = torch.logsumexp(logits, -1).transpose(-1, -2)\n    p = torch.softmax(logits, dim=-1)\n    o_ref = (\n        torch.einsum(\n            \"bhmn,bnhd->bmhd\",\n            p,\n            v.view(batch_size, kv_len, num_qo_heads, head_dim_vo).float(),\n        )\n        .contiguous()\n        .view(batch_size * qo_len, num_qo_heads, head_dim_vo)\n        .to(q)\n    )\n\n    return o_ref, lse_ref * math.log2(math.e)\n\nclass MLAWrapper():\n    def __init__(self,\n                 max_batch_size,\n                 max_pages,\n                 
use_cuda_graph = True,\n                 device = \"cuda\",\n                 ):\n        self.float_workspace_buffer = torch.empty(128*1024*1024, dtype=torch.int8, device=device)\n        self.max_batch_size = max_batch_size\n        self.max_pages = max_pages\n        if use_cuda_graph:\n            if self.max_batch_size == 1:\n                self.qo_indptr_buf = torch.arange(0, max_batch_size+1, dtype=torch.int32, device=device)\n                self.kv_indptr_buf = torch.tensor([0, max_pages], dtype=torch.int32, device=device)\n                self.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device)\n            else:\n                self.qo_indptr_buf = torch.empty(max_batch_size+1, dtype=torch.int32, device=device)\n                self.kv_indptr_buf = torch.empty(max_batch_size+1, dtype=torch.int32, device=device)\n                self.kv_indices_buf = torch.empty(max_pages, dtype=torch.int32, device=device)\n            self.batch_size_tensor_buf = torch.tensor([self.max_batch_size], dtype=torch.int32, device=device)\n            self.kv_len_arr_buf = torch.empty(max_batch_size, dtype=torch.int32, device=device)\n        else:\n            self.qo_indptr_buf = None\n            self.kv_indptr_buf = None\n            self.kv_indices_buf = None\n            self.kv_len_arr_buf = None\n        self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(\n            self.float_workspace_buffer,\n            use_cuda_graph=use_cuda_graph,\n            qo_indptr=self.qo_indptr_buf,\n            kv_indptr=self.kv_indptr_buf,\n            kv_indices=self.kv_indices_buf,\n            kv_len_arr=self.kv_len_arr_buf,\n            bsz_tensor=self.batch_size_tensor_buf,\n            backend = \"fa2\",\n        )\n        self.need_plan = True\n\n    \n    def plan(self,\n             qo_indptr,\n             kv_indptr,\n             kv_indices,\n             kv_len_arr,\n             bsz_tensor,\n             num_heads,\n             
head_dim_ckv,\n             head_dim_kpe,\n             page_size,\n             sm_scale,\n             q_data_type,\n             kv_data_type,\n             ):\n        if qo_indptr is None:\n            assert self.max_batch_size == 1\n            qo_indptr = self.qo_indptr_buf\n        if kv_indptr is None:\n            assert self.max_batch_size == 1\n            kv_indptr = self.kv_indptr_buf\n        if kv_indices is None:\n            assert self.max_batch_size == 1\n            kv_indices = self.kv_indices_buf\n        if bsz_tensor is None:\n            assert self.max_batch_size == 1\n            bsz_tensor = self.batch_size_tensor_buf\n        \n        self.wrapper.plan(\n            qo_indptr,\n            kv_indptr,\n            kv_indices,\n            kv_len_arr,\n            num_heads,\n            head_dim_ckv,\n            head_dim_kpe,\n            page_size,\n            True, # causal\n            sm_scale,\n            q_data_type,\n            kv_data_type,\n            bsz_tensor\n        )\n\n    def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):\n        return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse)\n\nclass MLAWrapperSingleton():\n    wrappers:dict = {}\n\n    @classmethod\n    def get_instance(cls, device, *args, **kwargs)->MLAWrapper:\n        if device not in cls.wrappers:\n            cls.make_instance(device, *args, **kwargs)\n        return cls.wrappers[device]\n    \n    @classmethod\n    def make_instance(cls, device, *args, **kwargs):\n        cls.wrappers[device] = MLAWrapper(*args, **kwargs, device=device)\n\n    @classmethod\n    def plan_all(cls, qo_indptr,\n             kv_indptr,\n             kv_indices,\n             kv_len_arr,\n             bsz_tensor,\n             num_heads,\n             head_dim_ckv,\n             head_dim_kpe,\n             page_size,\n             sm_scale,\n             q_data_type,\n             kv_data_type,):\n        for device, wrapper in 
cls.wrappers.items():\n            kv_len_arr_cur_device = kv_len_arr.to(device)\n            wrapper.plan(qo_indptr,\n                kv_indptr,\n                kv_indices,\n                kv_len_arr_cur_device,\n                bsz_tensor,\n                num_heads,\n                head_dim_ckv,\n                head_dim_kpe,\n                page_size,\n                sm_scale,\n                q_data_type,\n                kv_data_type,)\n            wrapper.need_plan = False\n            \n    @classmethod\n    def need_plan_all(cls):\n        for device, wrapper in cls.wrappers.items():\n            wrapper.need_plan = True\n        \n    @classmethod\n    def reset_buffer(cls):\n        for device, wrapper in cls.wrappers.items():\n            wrapper.qo_indptr_buf[1] = 1 # assert max_batch_size=1 here.\n            \n    @classmethod\n    def update_buffer(cls, max_pages):\n        for device, wrapper in cls.wrappers.items():\n            wrapper.kv_indptr_buf[1] = max_pages # assert max_batch_size=1 here.\n            wrapper.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device)\n            wrapper.wrapper._kv_indices_buf = wrapper.kv_indices_buf\n\ndef checksame():\n    flashinfer_folder = \"./flashinfer_output\"\n    flashinfer_folder = \"./kv_cache_flashinfer\"\n    triton_folder = \"./triton_output\"\n    triton_folder = \"./kv_cache_triton\"\n    \n    max_layer_id = 1\n    max_forward_id = 2\n\n    for forward_id in range(0, 19):\n        print(\"forward_id\", forward_id)\n        for layer_id in range(max_layer_id):\n            print(layer_id)\n            #file_name = f\"layer_{layer_id}_forward_{forward_id}_attn_output.pt\"\n            #file_name = f\"layer_{layer_id}_forward_{forward_id}_q_pe.pt\"\n            file_name = f\"layer_{layer_id}.pt\"\n            \n            flashinfer_path = os.path.join(flashinfer_folder, file_name)\n            triton_path = os.path.join(triton_folder, file_name)\n            \n   
         if not os.path.exists(triton_path):\n                print(f\"{file_name} not exist in {triton_folder}\")\n                continue\n            if not os.path.exists(flashinfer_path):\n                print(f\"{file_name} not exist in {flashinfer_folder}\")\n                continue\n            \n            \n            flashinfer_tensor = torch.load(flashinfer_path)[1:2, :62]#\n            triton_tensor = torch.load(triton_path)[1:2, :62]#.squeeze(1)#\n            try:\n                torch.testing.assert_close(flashinfer_tensor, triton_tensor, rtol=1e-9, atol=1e-9)\n            except AssertionError as e:\n                print(e)\n\nif __name__ == \"__main__\":\n    \n    #checksame()\n    #exit(0)\n\n    max_batch_size = 2\n    max_batch_tokens = 256\n    max_pages = 128\n    page_size = 64\n    num_heads = 128\n    \n    # warm-up\n    kv_len = 4023\n    q_len = 1\n    q_nope_buf = torch.randn((max_batch_tokens, num_heads, 512), dtype=torch.bfloat16, device=\"cuda\")\n    q_pe_buf = torch.randn((max_batch_tokens, num_heads, 64), dtype=torch.bfloat16, device=\"cuda\")\n    kv_buf = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device=\"cuda\")\n    ckv, k_pe = torch.split(kv_buf, [512, 64], dim=-1)\n    \n\n    wrapper = MLAWrapperSingleton.get_instance(\n        \"cuda\",\n        max_batch_size,\n        max_pages,\n    )\n    \n    used_pages = (kv_len + page_size - 1)// page_size\n    kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device=\"cuda\")\n    qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device=\"cuda\")\n    kv_indptr = torch.tensor([0, used_pages], dtype=torch.int32, device=\"cuda\")\n    kv_indices = torch.empty(max_pages, dtype=torch.int32, device=\"cuda\")\n    kv_indices[:used_pages] = torch.arange(0, used_pages, dtype=torch.int32, device=\"cuda\")\n    bsz_tensor = torch.tensor([1], dtype=torch.int32, device=\"cuda\")\n    wrapper.plan(\n        qo_indptr,\n        kv_indptr,\n        
kv_indices,\n        kv_len_arr,\n        bsz_tensor,\n        128,\n        512,\n        64,\n        page_size,\n        192 ** (-0.5),\n        torch.bfloat16,\n        torch.bfloat16,\n    )\n\n    attn_output = wrapper.run(q_nope_buf[:q_len], q_pe_buf[:q_len], ckv, k_pe)\n    print(attn_output.shape)\n    graph = torch.cuda.CUDAGraph()\n    with torch.cuda.graph(graph):\n        attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)\n    graph.replay()\n\n    q = torch.cat([q_nope_buf, q_pe_buf], dim=-1)\n    k = (\n        torch.cat([ckv, k_pe], dim=-1)\n        .view(-1, 1, 512 + 64)\n        .repeat_interleave(num_heads, dim=1)\n    )\n    v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)\n    attn_ref, lse_ref = attention_ref_torch(\n        1,\n        q[:q_len],\n        k[:kv_len],\n        v[:kv_len],\n        True,\n        192 ** (-0.5)\n    )\n    torch.testing.assert_close(attn_output[:q_len], attn_ref, rtol=5e-3, atol=5e-3)\n    # warm-up finished\n\n    kv_len = 512\n    q_len = 128\n    pages = max_pages\n    used_pages = (kv_len + page_size - 1)// page_size\n    q_nope = torch.randn((q_len*2, num_heads, 512), dtype=torch.bfloat16, device=\"cuda\")\n    q_nope[q_len:] = q_nope[:q_len]\n    q_pe = torch.randn((q_len*2, num_heads, 64), dtype=torch.bfloat16, device=\"cuda\")\n    q_pe[q_len:] = q_pe[:q_len]\n    kv_cache = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device=\"cuda\")\n    kv_cache[used_pages:2*used_pages] = kv_cache[:used_pages]\n    ckv, k_pe = torch.split(kv_cache, [512, 64], dim=-1)\n    \n    kv_len_arr = torch.tensor([kv_len, kv_len], dtype=torch.int32, device=\"cuda\")\n    qo_indptr = torch.tensor([0, q_len, q_len*2], dtype=torch.int32, device=\"cuda\")\n    kv_indptr = torch.tensor([0, used_pages, used_pages*2], dtype=torch.int32, device=\"cuda\")\n    kv_indices = torch.empty(max_pages, dtype=torch.int32, device=\"cuda\")\n    kv_indices[:2*used_pages] = torch.arange(0, 2*used_pages, 
dtype=torch.int32, device=\"cuda\")\n    bsz_tensor = torch.tensor([2], dtype=torch.int32, device=\"cuda\")\n    wrapper.plan(\n        qo_indptr,\n        kv_indptr,\n        kv_indices,\n        kv_len_arr,\n        bsz_tensor,\n        128,\n        512,\n        64,\n        page_size,\n        192 ** (-0.5),\n        torch.bfloat16,\n        torch.bfloat16,\n    )\n    \n    q_nope_buf.copy_(q_nope)\n    q_pe_buf.copy_(q_pe)\n    kv_buf[:pages].copy_(kv_cache)\n\n    torch.cuda.synchronize()\n    graph.replay()\n    torch.cuda.synchronize()\n\n    # ref_torch\n    q = torch.cat([q_nope, q_pe], dim=-1)\n    k = (\n        torch.cat([ckv, k_pe], dim=-1)\n        .view(-1, 1, 512 + 64)\n        .repeat_interleave(num_heads, dim=1)\n    )\n    v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)\n    attn_ref, lse_ref = attention_ref_torch(\n        max_batch_size,\n        q,\n        k[:2*kv_len],\n        v[:2*kv_len],\n        True,\n        192 ** (-0.5)\n    )\n    \n    torch.testing.assert_close(attn_ref[:q_len], attn_ref[q_len:q_len*2], rtol=1e-9, atol=1e-9)\n    torch.testing.assert_close(attn_output[:q_len], attn_output[q_len:q_len*2], rtol=1e-9, atol=1e-9)\n    torch.testing.assert_close(attn_output[:q_len], attn_ref[:q_len], rtol=5e-3, atol=5e-3)\n    torch.testing.assert_close(attn_output[q_len:q_len*2], attn_ref[q_len:q_len*2], rtol=5e-3, atol=5e-3)\n    #torch.testing.assert_close(attn_output[:q_len], attn_output[q_len:q_len*2], rtol=1e-9, atol=1e-9)\n    #torch.testing.assert_close(attn_output, attn_ref, rtol=5e-3, atol=5e-3)\n\n    exit(0)\n\n    for forward_id in range(0, 1):\n        print(\"forward_id\", forward_id)\n        for layer_id in range(1):\n            print(layer_id)\n            flashinfer_folder = \"./kv_cache_flashinfer\"\n            forward_id = 17\n            layer_id = 0\n            file_name = f\"layer_{layer_id}.pt\"\n            kv_cache_path = os.path.join(flashinfer_folder, file_name)\n            
flashinfer_folder = \"./flashinfer_output\"\n\n            q_len = 1\n            kv_len = 126\n            file_name = f\"layer_{layer_id}_forward_{forward_id}_q_nope.pt\"\n            q_nope = torch.load(os.path.join(flashinfer_folder, file_name)).view(q_len,128,512).to(device=\"cuda\")\n            file_name = f\"layer_{layer_id}_forward_{forward_id}_q_pe.pt\"\n            q_pe = torch.load(os.path.join(flashinfer_folder, file_name)).view(q_len,128,64).to(device=\"cuda\")\n            q = torch.cat([q_nope, q_pe], dim=-1)\n            kv_cache = torch.load(kv_cache_path).to(device=\"cuda\")\n            pages, page_size, _, head_dim = kv_cache.shape\n            kv_cache = kv_cache.view(pages, page_size, head_dim)\n            ckv, k_pe = torch.split(kv_cache, [512, 64], dim=-1)\n    \n            kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device=\"cuda\")\n            qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device=\"cuda\")\n            wrapper.plan(\n                None,\n                None,\n                None,\n                kv_len_arr,\n                128,\n                512,\n                64,\n                page_size,\n                192 ** (-0.5),\n                torch.bfloat16,\n                torch.bfloat16,\n            )\n    \n            q_nope_buf.copy_(q_nope)\n            q_pe_buf.copy_(q_pe)\n            kv_buf[:pages].copy_(kv_cache)\n\n            torch.cuda.synchronize()\n            graph.replay()\n            torch.cuda.synchronize()\n\n            # ref_torch\n            k = (\n                torch.cat([ckv, k_pe], dim=-1)\n                .view(-1, 1, 512 + 64)\n                .repeat_interleave(num_heads, dim=1)\n            )\n            v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)\n            attn_ref, lse_ref = attention_ref_torch(\n                max_batch_size,\n                q,\n                k[:kv_len],\n                v[:kv_len],\n                False,\n    
            192 ** (-0.5)\n            )\n            torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)\n    \n            # ref_triton\n            attn_logits = torch.empty(\n                    (\n                        max_batch_size,\n                        num_heads,\n                        4, #num_kv_splits # follow vLLM, fix it TODO\n                        512 + 1, \n                    ),\n                    dtype=torch.float32,\n                    device = \"cuda\"\n                )\n            \n            triton_ref = torch.zeros_like(q_nope)\n            page_table = torch.arange(max_pages, dtype=torch.int32, device=\"cuda\")\n            ckv_with_pe = torch.cat([ckv, k_pe], dim=-1).contiguous().view(pages, page_size, 1, 576)\n            ckv = ckv.view(pages, page_size, 1, 512)\n            decode_attention_fwd_grouped(q, ckv_with_pe, ckv, triton_ref,\n                page_table,\n                kv_len_arr, attn_logits,\n                4, #num_kv_splits # follow vLLM, fix it TODO\n                192 ** (-0.5),\n                page_size)\n\n            torch.testing.assert_close(attn_output, triton_ref, rtol=1e-3, atol=1e-3)\n            \n            #file_name = f\"./flashinfer_output/layer_{layer_id}_forward_{forward_id}_attn_output.pt\"\n            #ktrans_output = torch.load(file_name)\n            #torch.testing.assert_close(attn_output, ktrans_output.squeeze(1), rtol=1e-3, atol=1e-3)\n            print(\"test passed\")"
  },
  {
    "path": "kt-sft/ktransformers/operators/gate.py",
    "content": "from typing import Optional\nfrom torch import nn\nimport torch\nimport torch.nn.functional as F\nimport os\nfrom ktransformers.operators.base_operator import BaseInjectedModule\nfrom ktransformers.operators.linear import KTransformersLinear\nfrom ktransformers.util.custom_loader import GGUFLoader, ModelLoader, SafeTensorLoader\nfrom transformers.configuration_utils import PretrainedConfig\nfrom abc import ABC, abstractmethod\n\n\n# class Base(BaseInjectedModule, ABC):\nclass KMoEGateBase(ABC):\n    def __init__(self, \n                 key: str, \n                 gguf_loader: GGUFLoader, \n                 config: PretrainedConfig, \n                 orig_module: nn.Module, \n                 device: str = \"cuda\", \n                 **kwargs):\n        # super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)\n        super().__init__()\n        self.key = key\n        self.gguf_loader = gguf_loader\n        self.config = config\n        self.device = device\n        self.orig_module = orig_module\n    \n    @abstractmethod\n    def forward(self, input_tensor, expert_ids, weights):\n        pass\n\n    @abstractmethod\n    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str = \"cpu\", warmup: bool = False):\n        pass\n    \n    @abstractmethod\n    def unload(self):\n        pass\n\n    def load_weights(self, override_key: str | None = None, device: str = \"cpu\"):\n        res = {}\n        if override_key is not None:\n            keys = override_key\n        else:\n            keys = [self.key]\n\n        gate = None\n        up = None\n        down = None\n        gate_type = None\n        up_type = None\n        down_type = None\n\n        for key in keys:\n            # key = \".\".join(key.split(\".\")[:-1])\n            if isinstance(self.gguf_loader, SafeTensorLoader):\n                res = self.gguf_loader.load_gate(key, 
device=device)\n            elif self.gguf_loader.has_tensor(key+\".weight\"):\n                # targets = [\".ffn_gate_inp.weight\", \".exp_probs_b.bias\"]\n                targets = [\".weight\", \".e_score_correction_bias\"]\n                tensors = self.load_multi(key, targets, device=device)\n                weight = tensors[\".weight\"]\n                e_score_correction_bias = tensors[\".e_score_correction_bias\"]\n                # weight_type = self.gguf_loader.tensor_info[key + \".weight\"][\"ggml_type\"]\n                res = {\"weight\": weight, \"e_score_correction_bias\": e_score_correction_bias}\n            else:\n                raise ValueError(f\"Experts {key} not found in gguf_loader\")\n\n        return res\n    \n    def load_multi(self, key: str, keys: list[str], device: str = \"cpu\"):\n        tensors = {}\n        for k in keys:\n            tensors[k] = self.gguf_loader.load_gguf_tensor(key + k, device=device)\n        return tensors\n\n\nclass KMoEGate(BaseInjectedModule, KMoEGateBase):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module = None,\n        generate_device: str = \"cuda\",\n        prefill_device: str = \"cuda\",\n        **kwargs,\n    ):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)\n        KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)\n        self.generate_device = generate_device\n        self.prefill_device = prefill_device\n\n    def forward(self, hidden_states) -> torch.Tensor:\n        return self.orig_module.forward(hidden_states)\n\n    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):\n        if device is None: device = self.device\n        if w is None: w = self.load_weights(device=device)\n        \n        if isinstance(w, dict):\n           
 self.orig_module.weight = nn.Parameter(w[\"weight\"])\n            self.orig_module.e_score_correction_bias = nn.Parameter(w[\"e_score_correction_bias\"])\n        else:\n            raise ValueError(\"Invalid weight type\")\n        self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device))\n        self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device))\n\n    def unload(self):\n        if self.weight is not None:\n            self.weight = None\n        if self.e_score_correction_bias is not None:\n            self.e_score_correction_bias = None\n\n\nclass KMoEGateQwen2Moe(BaseInjectedModule, KMoEGateBase):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module = None,\n        generate_device: str = \"cuda\",\n        generate_op: str| None = \"KLinearMarlin\",\n        prefill_device: str = \"cuda\",\n        prefill_op: str| None = \"KLinearMarlin\",\n        use_quant: bool = False,\n        **kwargs,\n    ):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)\n        KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)\n        self.generate_device = generate_device\n        self.prefill_device = prefill_device\n        self.generate_op = generate_op\n        self.prefill_op = prefill_op\n        self.is_windows = os.name == 'nt'\n        self.use_quant = use_quant\n        if not self.is_windows and use_quant:\n            self.gate_linear = nn.Linear(self.gating_dim, self.n_routed_experts, device=generate_device)\n            self.gate_linear = KTransformersLinear(key + \".ffn_gate_inp\", \n                                               gguf_loader, config, self.gate_linear, #orig_module\n                                               generate_device, generate_op, 
prefill_device, prefill_op)\n        else:\n            self.gate_linear = None\n\n    def forward(self, hidden_states) -> torch.Tensor:\n        if self.is_windows:\n            return self.orig_module.forward(hidden_states)\n        \n        bsz, seq_len, h = hidden_states.shape\n        ### compute gating score\n        hidden_states = hidden_states.view(-1, h)\n        if self.use_quant:\n            # feed the flattened hidden states to the quantized gate; `logits` is not defined yet at this point\n            logits = self.gate_linear.forward(hidden_states)\n        else:\n            logits = F.linear(\n                hidden_states.type(torch.float32), self.weight.type(torch.float32), None\n            )\n            \n        return grouped_topk(hidden_states, logits,\n                            self.top_k, self.norm_topk_prob,\n                            self.n_group, self.topk_group)\n\n    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):\n        if device is None: device = self.device\n        if w is None: w = self.load_weights(device=device)\n        \n        if isinstance(w, dict):\n            self.orig_module.weight = nn.Parameter(w[\"weight\"])\n            self.orig_module.e_score_correction_bias = nn.Parameter(w[\"e_score_correction_bias\"])\n        else:\n            raise ValueError(\"Invalid weight type\")\n        self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device))\n        self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device))\n        if not self.is_windows and self.use_quant:\n            self.gate_linear.load(self.orig_module.weight)\n\n    def unload(self):\n        if self.weight is not None:\n            self.weight = None\n        if self.e_score_correction_bias is not None:\n            self.e_score_correction_bias = None\n\n\nclass KMoEGateIPEXLLM(KMoEGate):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module = None,\n        
generate_device: str = \"xpu\",\n        prefill_device: str = \"xpu\",\n        **kwargs,\n    ):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)\n        KMoEGate.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)\n        self.generate_device = generate_device\n        self.prefill_device = prefill_device\n\n    def forward(self, hidden_states) -> torch.Tensor:\n        x = hidden_states.view(-1, hidden_states.size(-1))\n        logits = torch.nn.functional.linear(\n            x.type(torch.float32), self.orig_module.weight.type(torch.float32), None\n        )\n        scores = logits.sigmoid()\n\n        from ipex_llm.transformers.models.common import moe_group_topk\n        topk_idx, topk_weight = moe_group_topk(scores, self.orig_module.e_score_correction_bias,\n                                               self.n_group, self.topk_group, self.top_k,\n                                               self.norm_topk_prob, self.routed_scaling_factor)\n        return topk_idx, topk_weight.to(x.dtype)"
  },
  {
    "path": "kt-sft/ktransformers/operators/layernorm.py",
    "content": "'''\nDate: 2024-11-13 15:05:52\nLastEditors: Xie Weiyu ervinxie@qq.com\nLastEditTime: 2024-11-25 08:59:19\n'''\n\"\"\"\nCopyright 2023-2024 SGLang Team\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\"\"\"\n\n\"\"\"Fused operators for normalization layers.\"\"\"\n\nimport logging\nfrom typing import Optional, Tuple, Union\nfrom transformers import PretrainedConfig\nimport torch\nimport torch.nn as nn\nfrom ktransformers.models.modeling_deepseek_v3 import DeepseekV3RMSNorm\nfrom ktransformers.models.modeling_qwen2_moe import Qwen2MoeRMSNorm\nfrom ktransformers.models.modeling_qwen3_moe import Qwen3MoeRMSNorm\nfrom ktransformers.operators.base_operator import BaseInjectedModule\nfrom ktransformers.util.custom_loader import GGUFLoader\nif not torch.xpu.is_available():\n    from flashinfer.norm import (\n        fused_add_rmsnorm,\n        rmsnorm,\n    )\n\n\nlogger = logging.getLogger(__name__)\n\n\nclass RMSNorm(DeepseekV3RMSNorm, BaseInjectedModule):\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        self.orig_module.__init__(orig_module.hidden_size,\n            
orig_module.variance_epsilon)\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        batch_size_tensor: torch.Tensor = None,\n        residual: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        #return self.forward_native(x, residual)\n        if batch_size_tensor is None:\n            return self.forward_native(x)\n        if residual is not None:\n            fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)\n            #residual = x + residual\n            #out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon)\n            return x, residual\n        # print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous())\n        out = rmsnorm(x, self.weight.data, batch_size_tensor,self.variance_epsilon)\n        return out\n\n    def forward_native(\n        self, hidden_states    \n    ):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n\nclass KQwen2MoeRMSNorm(Qwen2MoeRMSNorm, BaseInjectedModule):\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        self.orig_module.__init__(config.hidden_size,\n            orig_module.variance_epsilon)\n\n    def forward(\n        self,\n        x: 
torch.Tensor,\n        batch_size_tensor: torch.Tensor = None,\n        residual: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        #return self.forward_native(x, residual)\n        if batch_size_tensor is None:\n            return self.forward_native(x)\n        if residual is not None:\n            fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)\n            #residual = x + residual\n            #out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon)\n            return x, residual\n        # print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous())\n        out = rmsnorm(x, self.weight.data, batch_size_tensor,self.variance_epsilon)\n        return out\n\n    def forward_native(\n        self, hidden_states    \n    ):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\n\nclass KQwen3MoeRMSNorm(Qwen3MoeRMSNorm, BaseInjectedModule):\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        self.orig_module.__init__(orig_module.hidden_size,\n            orig_module.variance_epsilon)\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        batch_size_tensor: torch.Tensor = None,\n        
residual: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        #return self.forward_native(x, residual)\n        bsz, hidden_size = x.shape\n        x = x.view(-1, self.orig_module.hidden_size)\n        if batch_size_tensor is None:\n            return self.forward_native(x)\n        if residual is not None:\n            fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)\n            #residual = x + residual\n            #out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon)\n            return x, residual\n        # print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous())\n        out = rmsnorm(x, self.weight.data, batch_size_tensor,self.variance_epsilon)\n        out = out.view(bsz, hidden_size)\n        return out\n\n    def forward_native(\n        self, hidden_states    \n    ):\n        input_dtype = hidden_states.dtype\n        hidden_states = hidden_states.to(torch.float32)\n        variance = hidden_states.pow(2).mean(-1, keepdim=True)\n        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight * hidden_states.to(input_dtype)\n\nclass DeepseekV3RMSNormTorch(DeepseekV3RMSNorm, BaseInjectedModule):\n    def __init__(self,\n                key: str,\n                gguf_loader : GGUFLoader,\n                config: PretrainedConfig,\n                orig_module: nn.Module,\n                prefill_device: str = \"cuda\",\n                generate_device: str = \"cuda\",\n                **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        self.orig_module.__init__(orig_module.hidden_size,\n            orig_module.variance_epsilon)\n\n    def forward(\n        self, \n        x,\n        
batch_size_tensor: torch.Tensor = None,\n        residual: Optional[torch.Tensor] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        # Fold the residual into x first; the summed tensor is returned as the new residual.\n        if residual is not None:\n            x = x + residual\n            residual = x\n        # batch_size_tensor is unused in this pure-PyTorch path; it is accepted only to match the fused RMSNorm interface.\n        # Normalize in float32, then cast back to the input dtype.\n        input_dtype = x.dtype\n        x = x.to(torch.float32)\n        variance = x.pow(2).mean(-1, keepdim=True)\n        x = x * torch.rsqrt(variance + self.variance_epsilon)\n        if residual is not None:\n            return self.weight * x.to(input_dtype), residual\n        return self.weight * x.to(input_dtype)\n\n\nclass KDeepseekRMSNormIPEXLLM(DeepseekV3RMSNorm, BaseInjectedModule):\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"xpu\",\n                 generate_device: str = \"xpu\",\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        self.orig_module.__init__(orig_module.weight.shape[0],\n            orig_module.variance_epsilon)\n        self.eps = orig_module.variance_epsilon\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        # Imported lazily so that ipex_llm is only required when this operator is actually used.\n        from ipex_llm.transformers.models.common import rms_norm_forward\n        # Inputs that are neither float32 nor float16 are upcast to float32 before the call.\n        if x.dtype not in [torch.float32, torch.float16]:\n            output = rms_norm_forward(self, x.float())\n        else:\n            output = rms_norm_forward(self, x)\n        return output.to(x.dtype)\n\n    def load(self):\n        BaseInjectedModule.load(self)\n        # Keep the weight in float32 unless it is already float32 or float16.\n        if self.weight.dtype not in [torch.float32, torch.float16]:\n            self.weight = self.weight.float()"
  },
  {
    "path": "kt-sft/ktransformers/operators/linear.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : Azure-Tang, Boxin Zhang\nDate         : 2024-07-25 11:25:24\nVersion      : 0.1.0\nLastEditors  : Azure \nLastEditTime : 2024-08-29 09:11:16\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\n\n\nimport ctypes\nimport time\nimport torch\nfrom torch import Tensor, nn\nif not torch.xpu.is_available():\n    import KTransformersOps\n    import vLLMMarlin\nfrom ktransformers.util.custom_loader import GGUFLoader, SafeTensorLoader\nfrom ktransformers.util.inference_state import InferenceState\nif not torch.xpu.is_available():\n    from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import (\n        MarlinWorkspace,\n        marlin_quantize,\n        GPTQ_MARLIN_MIN_THREAD_N,\n        GPTQ_MARLIN_MIN_THREAD_K,\n        GPTQ_MARLIN_MAX_PARALLEL,\n        vllm_marlin_quantize\n    )\nfrom ktransformers.operators.base_operator import BaseInjectedModule\nfrom transformers.configuration_utils import PretrainedConfig\nfrom ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant\nfrom ktransformers.util.globals import GLOBAL_CONFIG\nfrom abc import ABC, abstractmethod\nimport sys, os\nsys.path.append(os.path.join(os.path.dirname(__file__), \"..\", \"ktransformers_ext\", \"build\"))\nsys.path.append(os.path.join(os.path.dirname(__file__), \"..\", \"ktransformers_ext\", \"build\", \"Release\"))\nsys.path.append(os.path.join(os.path.dirname(__file__), \"..\", \"ktransformers_ext\", \"build\", \"Debug\"))\nimport cpuinfer_ext\nfrom ktransformers.operators.cpuinfer import CPUInfer\nfrom ktransformers.server.config.config import Config\nfrom typing import Dict, Tuple, Optional, Union\nimport numpy as np\n\n#class KLinearBase(BaseInjectedModule, ABC):\nclass KLinearBase(nn.Module, ABC):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n      
  orig_module: nn.Module = None,\n        device: str = \"cuda\",\n        **kwargs,\n    ):\n        # super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)\n        super().__init__()\n        self.key = key\n        self.gguf_loader = gguf_loader\n        self.device = device\n        self.config = config\n\n        self.has_bias = False\n        self.dtype = torch.get_default_dtype()\n        if orig_module is not None:\n            self.in_features = orig_module.in_features\n            self.out_features = orig_module.out_features\n        else:\n            shape = self.gguf_loader.tensor_info[key + \".weight\"][\"shape\"]\n            if len(shape) == 1:\n                print(\"Warning: orig_module is not set, but has in_features or out_features equals to 1, can't get in_features and out_features from GGUF\")\n            self.in_features  = self.gguf_loader.tensor_info[key + \".weight\"][\"shape\"][0]\n            self.out_features = self.gguf_loader.tensor_info[key + \".weight\"][\"shape\"][1]\n\n        self.loaded = False # for lm_head pre-load, TODO: use new way to do lm_head pre-load when layer wise prefill.\n\n    @abstractmethod\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        pass\n\n    def load_weight(self, override_key: str | None = None, device: str | None = None):\n        if override_key is not None:\n            keys = override_key\n        else:\n            keys = [self.key]\n\n        for key in keys:\n            if isinstance(self.gguf_loader, SafeTensorLoader):\n                # using safetensor_loader\n                tensor = self.gguf_loader.load_tensor(key+'.weight')\n                if self.gguf_loader.has_tensor(key+'.weight_scale_inv'):\n                    weight_scale_inv = self.gguf_loader.load_tensor(key+'.weight_scale_inv')\n                    return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)\n                return nn.Parameter(tensor)\n                \n            elif 
self.gguf_loader.has_tensor(key + \".weight\") or \"kv_b_proj\" in key:\n                if key + \".bias\" in self.gguf_loader.tensor_file_map:\n                    tensors = self.load_multi(key, [\"weight\", \"bias\"], device=device)\n                    tensor = tensors[\"weight\"]\n                    bias = tensors[\"bias\"]\n                    # self.qtype = GGML_TYPE_QTYPE_MAP[tensorinfo[key + \".weight\"][\"ggml_type\"]]\n                    # print(torch.isinf(tensor).any(), torch.isinf(bias).any())\n                    return nn.Parameter(tensor), nn.Parameter(bias)\n                elif \"kv_b_proj\" in key and not self.gguf_loader.has_tensor(key + \".weight\"):\n                    attn_k_b_tensors = self.load_multi(key.replace(\"self_attn.kv_b_proj\", \"attn_k_b\"), [\"weight\"], device=device)\n                    attn_k_b = attn_k_b_tensors[\"weight\"]\n                    del attn_k_b_tensors\n                    attn_k_b = attn_k_b.transpose(1, 2).contiguous()\n                    attn_v_b_tensors = self.load_multi(key.replace(\"self_attn.kv_b_proj\", \"attn_v_b\"), [\"weight\"], device=device)\n                    attn_v_b = attn_v_b_tensors[\"weight\"]\n                    del attn_v_b_tensors\n                    kv_b_proj = torch.cat((attn_k_b, attn_v_b), dim=1)\n                    kv_b_proj = kv_b_proj.contiguous() if kv_b_proj.ndim == 2 else kv_b_proj.flatten(0, 1).contiguous()\n                    del attn_k_b\n                    del attn_v_b\n                    return nn.Parameter(kv_b_proj)\n                else:\n                    tensors = self.load_multi(key, [\"weight\"], device=device)\n                    tensor = tensors[\"weight\"]\n                    # self.qtype = GGML_TYPE_QTYPE_MAP[tensorinfo[key + \".weight\"][\"ggml_type\"]]\n                    return nn.Parameter(tensor)\n            else:\n                raise FileNotFoundError(f\"Weight file not found for key {key}\")\n\n    def load_multi(self, key: str, keys: 
list[str], device: str = \"cpu\"):\n        tensors = {}\n        for k in keys:\n            tensors[k] = self.gguf_loader.load_gguf_tensor(key + \".\" + k, device=device)\n        return tensors\n\n    @abstractmethod\n    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = \"cuda\"):\n        pass\n\n    @abstractmethod\n    def unload(self):\n        pass\n\n\nclass KLinearTorch(KLinearBase):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module = None,\n        device: str = \"cuda\",\n        **kwargs,\n    ):\n        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)\n        self.has_bias = False\n        self.dtype = torch.get_default_dtype()\n        self.weight = None\n        self.has_bias = False\n\n    def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor=None, **kwargs) -> torch.Tensor:\n        dtype = x.dtype\n        out_device = x.device\n\n        if (not x.requires_grad) and GLOBAL_CONFIG._config[\"mod\"] == \"sft\":\n            x = x.requires_grad_(True)\n        # TODO: support CUDA Graph when using cpu, but CPUInfer is recommended.\n        x = x.to(device=self.device, dtype=self.dtype)\n        x = x @ self.weight\n        if self.has_bias:\n            x = x + self.bias\n        x = x.to(dtype=dtype, device=out_device)\n        return x\n\n    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):\n        if self.loaded: return\n        if device is None: device = self.device\n        if w is None: w = self.load_weight(device=device)\n        # else: self.out_features = w.shape[0], self.in_features = w.shape[1]\n        \n        if isinstance(w, nn.Parameter):\n            try:\n                self.weight = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T\n            except: \n                self.weight = 
w.to(dtype=self.dtype).T\n            self.has_bias = False\n        elif isinstance(w, tuple):\n            try:\n                self.weight = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T\n            except:\n                self.weight = w[0].to(dtype=self.dtype).T\n            self.bias = w[1].to(dtype=self.dtype)\n            self.has_bias = True\n        else:\n            raise ValueError(\"Invalid weight type\")\n        # self.linear = self.linear.to(device)\n        self.weight = self.weight.to(device)\n        if self.has_bias:\n            self.bias = self.bias.to(device)\n        self.loaded = True\n\n    def unload(self):\n        if self.weight is not None:\n            self.weight = None\n        if self.has_bias:\n            self.bias = None\n\nclass KLinearQ8(KLinearBase):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module = None,\n        device: str = \"cuda\",\n        **kwargs,\n    ):\n        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)\n        self.has_bias = False\n        self.compute_dtype = torch.float32\n        self.weight = None\n        self.weight_scale = None\n        self.weight_zero_point = None\n        self.bias = None\n        self.loaded = False\n    \n    def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor=None) -> torch.Tensor:\n        orig_dtype = x.dtype\n        out_device = x.device\n        \n        x = x.to(device=self.device, dtype=self.compute_dtype)\n        \n\n        weight_dequant = self._dequantize_weight(self.weight, self.weight_scale, bits=8)\n        out = x @ weight_dequant.T\n        \n        if self.has_bias:\n            out = out + self.bias\n        \n        return out.to(dtype=orig_dtype, device=out_device)\n    \n    def _dequantize_weight(self, q_matrix, scales, bits=8):\n        \"\"\"\n        Dequantize a low-precision matrix 
back to floating-point\n        \n        Args:\n            q_matrix (torch.Tensor): Quantized int matrix\n            scales (torch.Tensor): Scale factors for each column\n            bits (int): Quantization bits used (8 or 4)\n        \n        Returns:\n            torch.Tensor: Dequantized floating-point matrix\n        \"\"\"\n        # Ensure inputs are torch tensors\n        if not isinstance(q_matrix, torch.Tensor):\n            q_matrix = torch.tensor(q_matrix, dtype=torch.int8)\n        if not isinstance(scales, torch.Tensor):\n            scales = torch.tensor(scales, dtype=torch.float32)\n        \n        # Convert to correct dtype if needed\n        if q_matrix.dtype != torch.int8:\n            q_matrix = q_matrix.to(torch.int8)\n        if scales.dtype != torch.float32:\n            scales = scales.to(torch.float32)\n        \n        # For Q4, ensure the values stay within 4-bit range\n        if bits == 4:\n            q_matrix = torch.clamp(q_matrix, -7, 7)\n        rows, cols = q_matrix.shape\n        dequant_matrix = q_matrix.to(torch.float32)\n        scales_broadcast = scales.view(1, cols)\n        # Apply dequantization to all columns at once using matrix multiplication\n        dequant_matrix = dequant_matrix * scales_broadcast\n        \n        return dequant_matrix\n\n    \n    def _quantize_weight(self, matrix, bits=8):\n        \"\"\"\n        Quantize a floating-point matrix to lower precision (Q8 or Q4)\n        \n        Args:\n            matrix (torch.Tensor): Input matrix in floating-point format\n            bits (int): Quantization bits, either 8 or 4\n        \n        Returns:\n            tuple: (quantized int matrix, scale factors for each column)\n        \"\"\"\n        if not isinstance(matrix, torch.Tensor):\n            matrix = torch.tensor(matrix, dtype=torch.float32)\n        \n        # Convert to float32 if needed\n        if matrix.dtype != torch.float32:\n            matrix = matrix.to(torch.float32)\n        
\n        # Get matrix shape\n        rows, cols = matrix.shape\n        \n        # Determine quantization parameters based on bits\n        if bits == 8:\n            max_int = 127\n            qtype = torch.int8\n        elif bits == 4:\n            max_int = 7\n            qtype = torch.int8  # We'll still use int8 storage but limit to 4-bit range, wait for native support\n        else:\n            raise ValueError(\"Quantization bits must be either 8 or 4\")\n       \n        scales = torch.zeros(cols, dtype=torch.float32, device=matrix.device)\n        \n        # Calculate max absolute value for each column\n        max_abs_vals, _ = torch.max(torch.abs(matrix), dim=0)\n        \n        # Handle zero columns (avoid division by zero)\n        zero_cols = max_abs_vals == 0\n        max_abs_vals[zero_cols] = 1.0\n        \n        # Calculate scale factors for all columns at once\n        scales = max_abs_vals / max_int\n        \n        # Prepare the scales for broadcasting [1, cols]\n        scales_broadcast = scales.view(1, cols)\n        \n        # Apply quantization to the entire matrix at once\n        q_matrix = torch.round(matrix / scales_broadcast).to(qtype)\n        \n        # For Q4, clamp values to ensure they stay within 4-bit range\n        if bits == 4:\n            q_matrix = torch.clamp(q_matrix, -max_int, max_int)\n        \n        return q_matrix, scales\n    \n    def load(self, w: Union[Dict, nn.Parameter, Tuple, None] = None, device: Optional[str] = None):\n        if self.loaded: return\n        if device is None: device = self.device \n        if w is None: w = self.load_weight(device=device)\n        \n        if isinstance(w, nn.Parameter):\n            try:\n                weight = w.to(dtype=self.compute_dtype).view(self.out_features, self.in_features)\n            except:\n                weight = w.to(dtype=self.compute_dtype)\n            self.has_bias = False\n        elif isinstance(w, tuple):\n            try:\n          
      weight = w[0].to(dtype=self.compute_dtype).view(self.out_features, self.in_features)\n            except:\n                weight = w[0].to(dtype=self.compute_dtype)\n            self.bias = w[1].to(dtype=self.compute_dtype).to(device)\n            self.has_bias = True\n        else:\n            raise ValueError(\"Invalid weight type\")\n        \n        self.weight, self.weight_scale = self._quantize_weight(weight, bits=8)\n        \n        self.weight = self.weight.to(device)\n        self.weight_scale = self.weight_scale.to(device)\n        \n        if self.has_bias:\n            self.bias = self.bias.to(device)\n            \n        self.loaded = True\n    \n    def unload(self):\n        self.weight = None\n        self.weight_scale = None\n        self.weight_zero_point = None\n        self._orig_weight = None\n        \n        if self.has_bias:\n            self.bias = None\n            \n        self.loaded = False\n\n\nclass KLinearFP8(KLinearBase):\n    # this kernel requires special handling for weight\n    # Please load the weight file downloaded from KVCache.AI\n    has_bias: bool\n    weight: torch.Tensor\n    bias: torch.Tensor\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module = None,\n        device: str = \"cuda\",\n        block_size: int = 128,\n        **kwargs,\n    ):\n        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)\n        self.has_bias = False\n        self.dtype = torch.get_default_dtype()\n        self.block_size = block_size\n    \n    def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor) -> torch.Tensor:\n        x = x.to(self.device)\n        orig_dtype = x.dtype        \n        x_quantized, scale_x = act_quant(x, self.block_size)\n        y = fp8_gemm(x_quantized, scale_x, self.weight, self.weight_scale_inv)\n        return y.to(dtype=orig_dtype)\n    \n    def load(self, w: dict 
| nn.Parameter | tuple | None = None, device: str|None = None):\n        if device is None: device = self.device\n        if w is None: \n            w = self.load_weight(device=device) \n        ### TODO fit weight_inv format\n        if isinstance(w, tuple):\n            self.weight = w[0].to(device)\n            self.weight_scale_inv = w[1].to(device)\n            self.has_bias = False\n        else:\n            raise ValueError(\"Invalid weight type\")\n        self.weight = self.weight.to(device)\n        if self.has_bias:\n            self.bias = self.bias.to(device)\n        \n    def unload(self):\n        if self.weight is not None:\n            self.weight = None\n        if self.has_bias:\n            self.bias = None\n\n# TODO: merge two marlin class\n\nclass VLinearMarlin(KLinearBase):\n    marlin_q_w: torch.Tensor\n    marlin_s: torch.Tensor\n    g_idx: torch.Tensor\n    sort_indices: torch.Tensor\n    has_bias: bool\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module = None,\n        device: str = \"cuda\",\n        num_bits: int = 4,  # 4-bit/8-bit is supported\n        group_size: int = 64,  # -1, 32, 64, 128\n        act_order: bool = False,\n        is_k_full=True,\n        **kwargs,\n    ):\n        assert device.lower() != \"cpu\", \"Marlin quantized linear only supports GPU device\"\n        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)\n        self.num_bits = num_bits\n        self.group_size = group_size\n        self.act_order = act_order\n        self.is_k_full = is_k_full\n        self.padding = False\n        self.orin_in_features = self.in_features\n        self.orin_out_features = self.out_features\n        if self.in_features%GPTQ_MARLIN_MIN_THREAD_K!=0 or self.out_features%GPTQ_MARLIN_MIN_THREAD_K!=0:\n            #print(f\"warning!, in_features={in_features} or out_features={out_features} is undivisible 
by GPTQ_MARLIN_MIN_THREAD_K={GPTQ_MARLIN_MIN_THREAD_K} and GPTQ_MARLIN_MIN_THREAD_N={GPTQ_MARLIN_MIN_THREAD_N}, padding\")\n            self.padding = True\n            self.in_features = (self.in_features+GPTQ_MARLIN_MIN_THREAD_K-1)//GPTQ_MARLIN_MIN_THREAD_K*GPTQ_MARLIN_MIN_THREAD_K\n            self.out_features = (self.out_features+GPTQ_MARLIN_MIN_THREAD_N-1)//GPTQ_MARLIN_MIN_THREAD_N*GPTQ_MARLIN_MIN_THREAD_N\n            #print(f\"After padding: in_features={in_features}, out_features={out_features}\")\n        \n        self.k = self.in_features\n        self.n = self.out_features\n\n    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):\n        if self.loaded: return\n        if device is None: device = self.device\n        assert device.lower() != \"cpu\", \"Marlin quantized linear only supports GPU device\"\n        \n        #if self.in_features * self.out_features:\n        if w is None: \n            w = self.load_weight(device=device) \n\n        if isinstance(w, nn.Parameter):\n            # pad weight\n            weight = w.view(self.orin_out_features, self.orin_in_features).T\n            self.has_bias = False\n        elif isinstance(w, tuple):\n            w = list(w)\n            weight = w[0].view(self.orin_out_features, self.orin_in_features).T\n            self.bias = w[1].view(self.orin_out_features)\n            self.bias = w[1]\n            self.has_bias = True\n        else:\n            raise ValueError(\"Invalid weight type\")\n        weight = weight.to(device)\n        if self.has_bias:\n            self.bias = self.bias.to(device)\n            \n        if self.padding:\n            padded_weight = torch.zeros(self.in_features, self.out_features, device=self.device)\n            padded_weight[:self.orin_in_features, :self.orin_out_features] = weight\n            weight = padded_weight\n\n        # Pack Marlin linear\n        marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(\n          
  weight, self.num_bits, self.group_size, self.act_order\n        )\n        self.workspace = MarlinWorkspace(\n            self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL,self.device\n        )\n        self.weight = marlin_q_w\n        self.marlin_q_w = marlin_q_w\n        self.marlin_s = marlin_s\n        self.g_idx = g_idx\n        self.sort_indices = sort_indices\n        self.k = weight.shape[0]\n        self.n = weight.shape[1]\n        # self.shape_buffer = torch.tensor([60], dtype=torch.int32, device=self.device)\n        self.loaded = True\n\n\n    def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None) -> torch.Tensor:\n        if bsz_tensor is None:\n            bsz_tensor = torch.tensor([x.shape[0]], dtype=torch.int32, device=self.device)\n\n\n        # Only support input x as BF16 and FP16\n        x = x.to(self.device)\n        orig_shape = list(x.shape)\n        orig_dtype = x.dtype\n        x = x.reshape(-1, orig_shape[-1])\n        marlin_s = self.marlin_s.to(x.dtype)\n        sms = -1\n\n        # padding x.shape[0] to avoid CUDA illegal memory access error\n        x, orig_size_m = self._pad_input(x)\n\n        x = vLLMMarlin.gptq_marlin_gemm(\n            x,\n            self.marlin_q_w,\n            marlin_s,\n            self.g_idx,\n            self.sort_indices,\n            self.workspace.scratch,\n            self.num_bits,\n            bsz_tensor,\n            x.shape[0],\n            self.n,\n            x.shape[-1],\n            sms,\n            self.is_k_full,\n        )\n\n        x = x[:orig_size_m]\n\n        if self.has_bias:\n            x = x + self.bias\n        orig_shape[-1] = self.n\n        return x.reshape(orig_shape).to(orig_dtype)\n\n    def unload(self):\n\n        if self.has_bias:\n            self.bias = None\n        self.marlin_q_w = None\n        self.marlin_s = None\n        self.g_idx = None\n        self.sort_indices = None\n        self.workspace = None  \n\n    def 
_pad_input(self, x):\n\n        size_m = x.shape[0]\n        size_k = x.shape[1]\n\n        # size_m and align value depends on VLinearMarlin implementation\n        if size_m > 1024:\n            align = 1024\n        elif size_m > 64:\n            align = 64\n        else:\n            align = 1\n\n        padded_size_m = ((size_m + align - 1) // align) * align\n\n        if padded_size_m > size_m:\n            pad_len = padded_size_m - size_m\n            pad_tensor = torch.zeros((pad_len, size_k), dtype=x.dtype, device=x.device)\n            x = torch.cat([x, pad_tensor], dim = 0).contiguous()\n        return x, size_m\n\nclass KLinearMarlin(KLinearBase):\n    marlin_q_w: torch.Tensor\n    marlin_s: torch.Tensor\n    g_idx: torch.Tensor\n    sort_indices: torch.Tensor\n    has_bias: bool\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module = None,\n        device: str = \"cuda\",\n        num_bits: int = 4,  # 4-bit/8-bit is supported\n        group_size: int = 64,  # -1, 32, 64, 128\n        act_order: bool = False,\n        is_k_full=True,\n        **kwargs,\n    ):\n        assert device.lower() != \"cpu\", \"Marlin quantized linear only supports GPU device\"\n        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)\n        self.num_bits = num_bits\n        self.group_size = group_size\n        self.act_order = act_order\n        self.is_k_full = is_k_full\n        self.padding = False\n        self.orin_in_features = self.in_features\n        self.orin_out_features = self.out_features\n        if self.in_features%GPTQ_MARLIN_MIN_THREAD_K!=0 or self.out_features%GPTQ_MARLIN_MIN_THREAD_K!=0:\n            #print(f\"warning!, in_features={in_features} or out_features={out_features} is undivisible by GPTQ_MARLIN_MIN_THREAD_K={GPTQ_MARLIN_MIN_THREAD_K} and GPTQ_MARLIN_MIN_THREAD_N={GPTQ_MARLIN_MIN_THREAD_N}, padding\")\n            
self.padding = True\n            self.in_features = (self.in_features+GPTQ_MARLIN_MIN_THREAD_K-1)//GPTQ_MARLIN_MIN_THREAD_K*GPTQ_MARLIN_MIN_THREAD_K\n            self.out_features = (self.out_features+GPTQ_MARLIN_MIN_THREAD_N-1)//GPTQ_MARLIN_MIN_THREAD_N*GPTQ_MARLIN_MIN_THREAD_N\n            #print(f\"After padding: in_features={in_features}, out_features={out_features}\")\n        \n        self.k = self.in_features\n        self.n = self.out_features\n\n    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):\n        if self.loaded: return\n        if device is None: device = self.device\n        assert device.lower() != \"cpu\", \"Marlin quantized linear only supports GPU device\"\n        \n        #if self.in_features * self.out_features:\n        if w is None: \n            w = self.load_weight(device=device) \n\n        if isinstance(w, nn.Parameter):\n            # pad weight\n            weight = w.view(self.orin_out_features, self.orin_in_features).T\n            self.has_bias = False\n        elif isinstance(w, tuple):\n            w = list(w)\n            weight = w[0].view(self.orin_out_features, self.orin_in_features).T\n            self.bias = w[1].view(self.orin_out_features)\n            self.bias = w[1]\n            self.has_bias = True\n        else:\n            raise ValueError(\"Invalid weight type\")\n        weight = weight.to(device)\n        if self.has_bias:\n            self.bias = self.bias.to(device)\n            \n        if self.padding:\n            padded_weight = torch.zeros(self.in_features, self.out_features, device=self.device)\n            padded_weight[:self.orin_in_features, :self.orin_out_features] = weight\n            weight = padded_weight\n\n        # Pack Marlin linear\n        marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(\n            weight, self.num_bits, self.group_size, self.act_order\n        )\n        self.workspace = MarlinWorkspace(\n            
self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL,self.device\n        )\n        self.weight = marlin_q_w # modeling_xxx.py may use linear.weight\n        self.marlin_q_w = marlin_q_w\n        self.marlin_s = marlin_s\n        self.g_idx = g_idx\n        self.sort_indices = sort_indices\n        self.k = weight.shape[0]\n        self.n = weight.shape[1]\n        self.loaded = True\n\n    def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor=None, **kwargs) -> torch.Tensor:\n        # Only support input x as BF16 and FP16\n        x = x.to(self.device)\n        orig_shape = list(x.shape)\n        orig_dtype = x.dtype\n        x = x.reshape(-1, orig_shape[-1])\n        if self.padding:\n            # zero-fill so the padded columns never feed uninitialized values into the GEMM\n            padding_input=torch.zeros(x.shape[0], self.in_features, device=x.device, dtype=x.dtype)\n            padding_input[:,:self.orin_in_features] = x\n            x = padding_input\n        marlin_s = self.marlin_s.to(x.dtype)\n        x = KTransformersOps.gptq_marlin_gemm(\n            x,\n            self.marlin_q_w,\n            marlin_s,\n            self.g_idx,\n            self.sort_indices,\n            self.workspace.scratch,\n            self.num_bits,\n            x.shape[0],\n            self.n,\n            x.shape[-1],\n            self.is_k_full,\n        )\n        if self.padding:\n            x = x[:,:self.orin_out_features]\n            orig_shape[-1] = self.orin_out_features\n        else:\n            orig_shape[-1] = self.out_features\n        if self.has_bias:\n            x = x + self.bias\n        return x.reshape(orig_shape).to(orig_dtype)\n\n    def unload(self):\n\n        if self.has_bias:\n            self.bias = None\n        self.marlin_q_w = None\n        self.marlin_s = None\n        self.g_idx = None\n        self.sort_indices = None\n        self.workspace = None\n\nclass KLinearCPUInfer(KLinearBase):\n    CPU_INFER = None\n    def __init__(\n        self,\n        key: str,\n  
      gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module = None,\n        device: str = \"cpu\",\n        out_device: str = \"cuda\", # this device mean which device the output should on. TODO: support cpu.\n        stride = 16,\n        group_max_len = 1024,\n        **kwargs,\n    ):\n        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)\n        if KLinearCPUInfer.CPU_INFER is None:\n            KLinearCPUInfer.CPU_INFER = CPUInfer(Config().cpu_infer)\n        self.has_bias = False\n        self.dtype = torch.get_default_dtype()\n        self.w = None\n        self.has_bias = False\n        self.stride = stride\n        self.group_max_len = group_max_len\n        self.out_device = out_device\n\n    def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None) -> torch.Tensor:\n        origin_shape = x.shape # [batch_size, q_len, hidden_size]\n        if origin_shape[1] == 1 and torch.cuda.is_current_stream_capturing():\n            out_device = x.device\n            self.input_tensor_cpu.copy_(x, non_blocking=True)\n            qlen = origin_shape[1]\n            KLinearCPUInfer.CPU_INFER.submit_with_cuda_stream(\n                torch.cuda.current_stream().cuda_stream,\n                self.linear.forward(\n                    qlen, \n                    self.input_tensor_cpu.data_ptr(), \n                    self.output_cpu.data_ptr()\n                )\n            )\n            KLinearCPUInfer.CPU_INFER.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)\n            self.output_gpu.copy_(self.output_cpu, non_blocking=True)\n            if self.has_bias:\n                self.output_gpu += self.bias\n            return self.output_gpu\n        else:\n            dtype = x.dtype\n            out_device = x.device\n            x = x.to(device=self.device)\n            qlen = origin_shape[1]\n            output_shape = (*origin_shape[:-1], self.out_features)\n            
output = torch.empty(output_shape, device=x.device, dtype=x.dtype)\n            KLinearCPUInfer.CPU_INFER.submit(\n                self.linear.forward(\n                    qlen, \n                    x.data_ptr(), \n                    output.data_ptr()\n                )\n            )\n            KLinearCPUInfer.CPU_INFER.sync()\n            if self.has_bias:\n                output = output + self.bias\n            output = output.to(dtype=dtype, device=out_device)\n            return output\n\n    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None, warmup:bool = True):\n        print(f\"loading {self.key} to {self.device} using CPUInfer\")\n        if device is None: device = self.device\n        self.load_weights(w=w, device=device)\n        if self.bias is not None:\n            self.has_bias = True\n            self.bias = self.bias.to(device)\n            \n        weight_ptr = ctypes.addressof(\n            ctypes.cast(self.weight.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents\n        )\n        config = cpuinfer_ext.linear.LinearConfig(self.in_features, self.out_features, self.stride, self.group_max_len, weight_ptr, self.weight_type, 30)\n        self.linear = cpuinfer_ext.linear.Linear(config)\n        \n        if warmup:\n            KLinearCPUInfer.CPU_INFER.submit(self.linear.warm_up())\n            KLinearCPUInfer.CPU_INFER.sync()\n        self.input_tensor_cpu = torch.zeros((1, 1, self.in_features), device=\"cpu\", pin_memory=True)\n        self.output_cpu = torch.zeros((1, 1, self.out_features), device=\"cpu\", pin_memory=True, dtype=torch.bfloat16)\n        self.output_gpu = torch.zeros((1, 1, self.out_features), device=self.out_device)\n\n    def load_weights(self, w: dict | nn.Parameter | tuple | None = None, device: str = \"cpu\"):\n        if self.gguf_loader.has_tensor(self.key + \".weight\"):\n            if self.key + \".bias\" in self.gguf_loader.tensor_file_map:\n                self.weight = 
self.gguf_loader.get_mmap_tensor(self.key + \".weight\")\n                self.weight_type = self.gguf_loader.tensor_info[self.key + \".weight\"][\"ggml_type\"]\n                self.bias = self.gguf_loader.load_gguf_tensor(self.key + \".bias\", device=device)\n            else:\n                self.weight = self.gguf_loader.get_mmap_tensor(self.key + \".weight\")\n                self.weight_type = self.gguf_loader.tensor_info[self.key + \".weight\"][\"ggml_type\"]\n                self.bias = None\n        else:\n            raise ValueError(f\"Linear {self.key} not found in gguf_loader\")\n\n    def unload(self):\n        if self.w is not None:\n            self.w = None\n        if self.has_bias:\n            self.bias = None       \n\nclass KLinearIPEXLLM(KLinearBase):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module = None,\n        device: str = \"xpu\",\n        precision: str = \"sym_int4\",\n        **kwargs,\n    ):\n        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)\n        self.has_bias = False\n        self.dtype = torch.get_default_dtype()\n        self.weight = None\n        self.has_bias = False\n        self.precision = precision\n        self.qtype = None\n\n    def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None) -> torch.Tensor:\n        dtype = x.dtype\n        out_device = x.device\n        from ipex_llm.transformers.models.common import linear_forward\n        x = linear_forward(x.half(), self.weight, self.qtype, self.out_features)\n\n        if self.has_bias:\n            x = x + self.bias\n        x = x.to(dtype=dtype, device=out_device)\n        return x\n\n    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):\n        if self.loaded: return\n        if device is None: device = self.device\n        assert device.lower()[:3] == \"xpu\", \"IPEX-LLM 
quantized linear only supports XPU device\"\n        if w is None: w = self.load_weight(device=device)\n\n        if isinstance(w, nn.Parameter):\n            try:\n                weight = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T\n            except:\n                weight = w.to(dtype=self.dtype).T\n            self.has_bias = False\n        elif isinstance(w, tuple):\n            try:\n                weight = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T\n            except:\n                weight = w[0].to(dtype=self.dtype).T\n            self.bias = w[1].to(dtype=self.dtype)\n            self.has_bias = True\n        else:\n            raise ValueError(\"Invalid weight type\")\n        weight = weight.to(\"cpu\").float().transpose(0, 1).contiguous()\n\n        if self.has_bias:\n            self.bias = self.bias.to(device)\n\n        # quantize linear weight\n        from ipex_llm.transformers.models.common import quantize_linear\n        paramsLowBit, qtype = quantize_linear(weight, self.in_features, self.precision)\n        self.weight = paramsLowBit.to(device)\n        self.qtype = qtype\n        self.loaded = True\n\n    def unload(self):\n        if self.weight is not None:\n            self.weight = None\n        if self.has_bias:\n            self.bias = None\n\nLINEAR_MAP = {\n    \"KLinearMarlin\": KLinearMarlin,\n    \"KLinearTorch\": KLinearTorch,\n    \"KLinearCPUInfer\": KLinearCPUInfer,\n    \"VLinearMarlin\": VLinearMarlin,\n    \"KLinearFP8\": KLinearFP8,\n    \"KLinearQ8\": KLinearQ8,\n    \"KLinearIPEXLLM\": KLinearIPEXLLM,\n}\n\nclass KTransformersLinear(BaseInjectedModule, KLinearBase):\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module,\n        generate_device: str = \"cuda\",\n        generate_op: str| None = \"KLinearMarlin\",\n        prefill_device: str = \"cuda\",\n        
prefill_op: str| None = \"KLinearTorch\",\n        **kwargs,\n    ):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)\n        KLinearBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)\n        # build all the linear operators\n        if prefill_op is not None:\n            assert prefill_op in LINEAR_MAP, f\"linear_type {prefill_op} not supported\"\n            self.prefill_linear = LINEAR_MAP[prefill_op](key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        else:\n            self.prefill_linear = None\n\n        if generate_op is not None:\n            assert generate_op in LINEAR_MAP, f\"linear_type {generate_op} not supported\"\n            self.generate_linear = LINEAR_MAP[generate_op](key, gguf_loader, config, orig_module, generate_device, **kwargs)\n        else:\n            self.generate_linear = None\n        self.mode = InferenceState.UNLOAD\n\n    def forward(self, x, bsz_tensor=None):\n        # linear_fwd_st = time.time()\n        if self.mode == InferenceState.PREFILL:\n            assert self.prefill_linear is not None, \"prefill linear is not initialized\"\n            y = self.prefill_linear.forward(x, bsz_tensor)\n        else:\n            assert self.generate_linear is not None, \"generate linear is not initialized\"\n            # TODO: brute-force workaround for weight=None during LoRA inference tests; needs a proper fix later\n            try:\n                y = self.generate_linear.forward(x, bsz_tensor)\n            except TypeError as e:\n                import warnings\n                warnings.warn(\"A dangerous way to avoid the None weight; needs to be checked later in KTransformersLinear.forward\")\n                self.generate_linear.weight = self.orig_module.generate_linear.weight\n                self.weight = self.orig_module.generate_linear.weight\n                y = self.generate_linear.forward(x, bsz_tensor)\n        \n        # linear_fwd_end = 
time.time()\n        # print(f\"[KTLinear] Forward time: {linear_fwd_end-linear_fwd_st}\")\n        return y\n\n    def load(self, w: dict | nn.Parameter | tuple | None = None, mode: InferenceState = InferenceState.GENERATE):\n        if not mode:\n            mode = InferenceState.GENERATE\n        # load to device\n        if mode == InferenceState.PREFILL:\n            self.generate_linear.unload()\n            self.prefill_linear.load(w=w)\n            self.device = self.prefill_linear.device\n            self.weight = self.prefill_linear.weight # modeling_xxx.py may use linear.weight\n        elif mode == InferenceState.GENERATE:\n            self.prefill_linear.unload()\n            self.generate_linear.load(w=w)\n            self.device = self.generate_linear.device\n            self.weight = self.generate_linear.weight # modeling_xxx.py may use linear.weight\n        elif mode == InferenceState.UNLOAD:\n            self.prefill_linear.unload()\n            self.generate_linear.unload()\n            self.device = \"cpu\"\n        else:\n            raise ValueError(\"mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD\")\n        self.mode = mode\n\n    def unload(self):\n        if self.prefill_linear is not None:\n            self.prefill_linear.unload()\n        if self.generate_linear is not None:\n            self.generate_linear.unload()\n            # only read the device when a generate operator exists\n            self.device = self.generate_linear.device\n\n    def set_inference_mode(self, mode: InferenceState):\n        if not mode: \n            mode = InferenceState.GENERATE\n        if mode == InferenceState.GENERATE:\n            self.load(mode=InferenceState.GENERATE)\n        elif mode == InferenceState.PREFILL:\n            self.load(mode=InferenceState.PREFILL)\n        elif mode == InferenceState.UNLOAD:\n            self.unload()\n        else:\n            raise ValueError(\"mode must be either InferenceState.GENERATE, InferenceState.PREFILL or 
InferenceState.UNLOAD\")\n\n\n"
  },
  {
    "path": "kt-sft/ktransformers/operators/mlp.py",
    "content": "\nfrom ktransformers.operators.base_operator import BaseInjectedModule\nfrom ktransformers.util.custom_loader import GGUFLoader\nfrom transformers import PretrainedConfig\nimport torch.nn as nn\nfrom ktransformers.models.modeling_deepseek_v3 import DeepseekV3MLP\nfrom ktransformers.models.modeling_qwen2_moe import Qwen2MoeMLP\nclass kDeepseekV3MLP(DeepseekV3MLP, BaseInjectedModule):\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        self.orig_module.__init__(orig_module.config,\n            orig_module.hidden_size, orig_module.intermediate_size)\n    def forward(self, x, bsz_tensor):\n        down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor)\n        return down_proj\nclass KQwen2MoeMLP(Qwen2MoeMLP, BaseInjectedModule):\n    def __init__(self,\n                 key: str,\n                 gguf_loader : GGUFLoader,\n                 config: PretrainedConfig,\n                 orig_module: nn.Module,\n                 prefill_device: str = \"cuda\",\n                 generate_device: str = \"cuda\",\n                 **kwargs):\n        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)\n        self.orig_module.__init__(orig_module.config,\n            orig_module.intermediate_size)\n    def forward(self, x, bsz_tensor):\n        down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor)\n        return down_proj"
  },
  {
    "path": "kt-sft/ktransformers/operators/models.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nDescription  :  \nAuthor       : Azure-Tang\nDate         : 2024-07-25 11:25:24\nVersion      : 1.0.0\nLastEditors  : Azure \nLastEditTime : 2024-08-27 07:29:04\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n\"\"\"\n\nimport inspect\nimport math\nfrom typing import List, Optional, Tuple, Union\nimport time\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\nfrom ktransformers.operators.dynamic_attention import DynamicScaledDotProductAttention\nfrom ktransformers.server.config.config import Config\nimport os\nimport yaml\nfrom transformers.activations import ACT2FN\nfrom transformers.cache_utils import Cache, DynamicCache, StaticCache\nfrom transformers.modeling_attn_mask_utils import (\n    AttentionMaskConverter,\n)\nfrom transformers.modeling_outputs import (\n    MoeCausalLMOutputWithPast,\n    MoeModelOutputWithPast,\n    SequenceClassifierOutputWithPast,\n    TokenClassifierOutput,\n)\nfrom transformers.modeling_utils import PreTrainedModel\nfrom transformers.utils import (\n    add_start_docstrings,\n    add_start_docstrings_to_model_forward,\n    is_flash_attn_2_available,\n    is_flash_attn_greater_or_equal_2_10,\n    logging,\n    replace_return_docstrings,\n)\nfrom ktransformers.models.modeling_qwen2_moe import (\n    Qwen2MoeSparseMoeBlock,\n    Qwen2MoeMLP,\n    Qwen2MoeDecoderLayer,\n    Qwen2MoeRotaryEmbedding,\n)\n\nfrom ktransformers.models.modeling_qwen3_moe import (\n    Qwen3MoeSparseMoeBlock,\n    Qwen3MoeMLP,\n    Qwen3MoeDecoderLayer,\n)\n\nfrom ktransformers.models.modeling_deepseek import (\n    BaseModelOutputWithPast,\n    DeepseekV2DecoderLayer,\n    DeepseekV2MoE,\n)\nfrom ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor\nfrom transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig\nfrom 
ktransformers.models.configuration_llama import LlamaConfig\nfrom ktransformers.operators.base_operator import BaseInjectedModule\nfrom ktransformers.util.inference_state import InferenceState\nfrom ktransformers.util.utils import get_compute_capability\nfrom ktransformers.util.custom_loader import GGUFLoader\nfrom transformers.configuration_utils import PretrainedConfig\nfrom ktransformers.models.modeling_llama import (\n    LlamaDecoderLayer,\n    LlamaRMSNorm,\n    LlamaRotaryEmbedding,\n)\n\nif is_flash_attn_2_available():\n    from flash_attn import flash_attn_func, flash_attn_varlen_func\n    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa\n\n    _flash_supports_window_size = \"window_size\" in list(\n        inspect.signature(flash_attn_func).parameters\n    )\n\nlogger = logging.get_logger(__name__)\n\n_CHECKPOINT_FOR_DOC = \"Qwen/Qwen1.5-MoE-A2.7B\"\n_CONFIG_FOR_DOC = \"Qwen2MoeConfig\"\n\nQWEN2MOE_START_DOCSTRING = r\"\"\"\n    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the\n    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads\n    etc.)\n\n    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage\n    and behavior.\n\n    Parameters:\n        config ([`Qwen2MoeConfig`]):\n            Model configuration class with all the parameters of the model. Initializing with a config file does not\n            load the weights associated with the model, only the configuration. 
Check out the\n            [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\nQWEN2MOE_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see\n            `past_key_values`).\n\n            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]\n            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more\n            information on the default strategy.\n\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Indices of positions of each input sequence tokens in the position embeddings. 
Selected in the range `[0,\n            config.n_positions - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):\n            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention\n            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`\n            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.\n\n            Two formats are allowed:\n            - a [`~cache_utils.Cache`] instance;\n            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of\n            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy\n            cache format.\n\n            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the\n            legacy cache format will be returned.\n\n            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't\n            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`\n            of shape `(batch_size, sequence_length)`.\n        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 
This\n            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n            model's internal embedding lookup matrix.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n            `past_key_values`).\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for\n            more detail.\n        output_router_logits (`bool`, *optional*):\n            Whether or not to return the logits of all the routers. They are useful for computing the router loss, and\n            should not be returned during inference.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):\n            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,\n            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer\n            the complete sequence length.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare Qwen2MoE Model outputting raw hidden-states without any specific head on top.\",\n    QWEN2MOE_START_DOCSTRING,\n)\nclass KQwen2MoeModel(BaseInjectedModule):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Qwen2MoeDecoderLayer`]\n\n    Args:\n        config: Qwen2MoeConfig\n    \"\"\"\n\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module,\n        device: str = \"cuda\",\n        per_layer_prefill_intput_threshold: int = 30000,  # if None, no per-layer prefill\n        transfer_map: dict = None,\n        **kwargs,\n    ):\n        BaseInjectedModule.__init__(\n            self, key, gguf_loader, config, orig_module, device, **kwargs\n        )\n        self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold\n        self.transfer_map = transfer_map\n        self.stream_device_map = dict()\n\n    @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        output_router_logits: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        per_layer_prefill_intput_threshold: (\n            int | None\n        ) = None,  # if None or 0, close per-layer prefill\n    ) -> Union[Tuple, MoeModelOutputWithPast]:\n        # print(f'Total length of input_ids: {input_ids.size(1)}, {input_ids.size()}')\n\n        if per_layer_prefill_intput_threshold is None:\n            per_layer_prefill_intput_threshold = self.per_layer_prefill_intput_threshold\n        per_layer_prefill_flag = False\n        seq_lenth = (\n            inputs_embeds.size(1) if inputs_embeds is not None else 
input_ids.size(1)\n        )\n        if (\n            per_layer_prefill_intput_threshold\n            and per_layer_prefill_intput_threshold < seq_lenth\n        ):\n            per_layer_prefill_flag = True\n            for layer in self.layers:\n                self.load_layer_to(layer, InferenceState.UNLOAD)\n        else:\n            pass\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_router_logits = (\n            output_router_logits\n            if output_router_logits is not None\n            else self.config.output_router_logits\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        if (input_ids is None) ^ (inputs_embeds is not None):\n            raise ValueError(\n                \"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one\"\n            )\n\n        if self.gradient_checkpointing and self.training:\n            if use_cache:\n                logger.warning_once(\n                    \"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\"\n                )\n                use_cache = False\n\n        use_legacy_cache = False\n        if use_cache and not isinstance(past_key_values, Cache):\n            use_legacy_cache = True\n            past_key_values = DynamicCache.from_legacy_cache(past_key_values)\n            logger.warning_once(\n                \"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. 
\"\n                \"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\"\n            )\n\n        if inputs_embeds is None:\n            input_ids = input_ids.to(\"cpu\")\n            inputs_embeds = self.embed_tokens(input_ids)\n            inputs_embeds = inputs_embeds.to(\"cuda\")\n\n        if cache_position is None:\n            past_seen_tokens = (\n                past_key_values.get_seq_length() if past_key_values is not None else 0\n            )\n            cache_position = torch.arange(\n                past_seen_tokens,\n                past_seen_tokens + inputs_embeds.shape[1],\n                device=inputs_embeds.device,\n            )\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0)\n\n        causal_mask = self._update_causal_mask(\n            attention_mask,\n            inputs_embeds,\n            cache_position,\n            past_key_values,\n            output_attentions,\n        )\n\n        hidden_states = inputs_embeds\n\n        # create position embeddings to be shared across the decoder layers\n        if torch.xpu.is_available() and inputs_embeds.device.type == \"xpu\":\n            position_embeddings = self.rotary_emb(hidden_states, position_ids)\n        else:\n            position_embeddings = None\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        all_router_logits = () if output_router_logits else None\n        next_decoder_cache = None\n\n        for i, decoder_layer in enumerate(self.layers):\n            if self.transfer_map is not None and i in self.transfer_map:\n                prev_stream = torch.cuda.current_stream()\n                cur_device = self.transfer_map[i]\n                if cur_device not in self.stream_device_map:\n                    
self.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)\n                torch.cuda.set_device(cur_device)\n                self.stream_device_map[cur_device].wait_stream(prev_stream)\n                torch.cuda.set_stream(self.stream_device_map[cur_device])\n                hidden_states = hidden_states.to(\n                    self.transfer_map[i], non_blocking=True\n                )\n                causal_mask = (\n                    causal_mask.to(self.transfer_map[i], non_blocking=True)\n                    if causal_mask is not None\n                    else None\n                )\n                position_ids = (\n                    position_ids.to(self.transfer_map[i], non_blocking=True)\n                    if position_ids is not None\n                    else None\n                )\n                cache_position = (\n                    cache_position.to(self.transfer_map[i], non_blocking=True)\n                    if cache_position is not None\n                    else None\n                )\n\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            if self.gradient_checkpointing and self.training:\n                layer_outputs = self._gradient_checkpointing_func(\n                    decoder_layer.__call__,\n                    hidden_states,\n                    causal_mask,\n                    position_ids,\n                    past_key_values,\n                    output_attentions,\n                    output_router_logits,\n                    use_cache,\n                    cache_position,\n                )\n            else:\n                if per_layer_prefill_flag:\n                    # print(f\"to gpu\")\n                    self.load_layer_to(decoder_layer, InferenceState.PREFILL)\n                    torch.cuda.empty_cache()\n                layer_outputs = decoder_layer(\n                    hidden_states,\n                    attention_mask=causal_mask,\n            
        position_ids=position_ids,\n                    past_key_value=past_key_values,\n                    output_attentions=output_attentions,\n                    output_router_logits=output_router_logits,\n                    use_cache=use_cache,\n                    cache_position=cache_position,\n                    position_embeddings=position_embeddings,\n                )\n                if per_layer_prefill_flag:\n                    # print(f\"to cpu\")\n                    self.load_layer_to(decoder_layer, InferenceState.UNLOAD)\n                    torch.cuda.empty_cache()\n            hidden_states = layer_outputs[0]\n\n            if use_cache and len(layer_outputs) > 1:\n                next_decoder_cache = layer_outputs[2 if output_attentions else 1]\n            else:\n                next_decoder_cache = None\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n            if output_router_logits and layer_outputs[-1] is not None:\n                all_router_logits += (layer_outputs[-1],)\n\n        hidden_states = self.norm(hidden_states)\n\n        if per_layer_prefill_flag:\n            per_layer_prefill_flag = False\n            for layer in self.layers:\n                self.load_layer_to(layer, InferenceState.GENERATE)\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        next_cache = None\n        if use_cache:\n            if next_decoder_cache is not None:\n                next_cache = (\n                    next_decoder_cache.to_legacy_cache()\n                    if use_legacy_cache\n                    else next_decoder_cache\n                )\n            else:\n                next_cache = past_key_values\n\n        if not return_dict:\n            return tuple(\n                v\n                for v in [\n                    hidden_states,\n                    next_cache,\n                    all_hidden_states,\n                    
all_self_attns,\n                    all_router_logits,\n                ]\n                if v is not None\n            )\n        return MoeModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=next_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n            router_logits=all_router_logits,\n        )\n\n    def load_layer_to(self, layer: Qwen2MoeDecoderLayer, target: InferenceState):\n        assert isinstance(\n            layer, Qwen2MoeDecoderLayer\n        ), \"layer should be a Qwen2MoeDecoderLayer\"\n\n        # TODO: support restoring to the original device, not only cuda\n        device = \"cpu\" if target == InferenceState.UNLOAD else \"cuda\"\n\n        # attn\n        layer.self_attn.q_proj.set_inference_mode(target)\n        layer.self_attn.k_proj.set_inference_mode(target)\n        layer.self_attn.v_proj.set_inference_mode(target)\n        layer.self_attn.o_proj.set_inference_mode(target)\n        layer.self_attn.rotary_emb = layer.self_attn.rotary_emb.to(device)\n\n        # mlp\n        if isinstance(layer.mlp, Qwen2MoeSparseMoeBlock):\n            layer.mlp.gate.set_inference_mode(target)\n            layer.mlp.experts.set_inference_mode(target)\n            layer.mlp.shared_expert.gate_proj.set_inference_mode(target)\n            layer.mlp.shared_expert.up_proj.set_inference_mode(target)\n            layer.mlp.shared_expert.down_proj.set_inference_mode(target)\n            layer.mlp.shared_expert.act_fn.to(device)\n            layer.mlp.shared_expert_gate.to(device)\n        else:\n            layer.mlp.gate_proj.set_inference_mode(target)\n            layer.mlp.up_proj.set_inference_mode(target)\n            layer.mlp.down_proj.set_inference_mode(target)\n            layer.mlp.act_fn.to(device)\n        # layer norm\n        layer.input_layernorm.to(device)\n        layer.post_attention_layernorm.to(device)\n\n\nDeepseekV2_INPUTS_DOCSTRING = r\"\"\"\n    
Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see\n            `past_key_values`).\n\n            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]\n            and modify it to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more\n            information on the default strategy.\n\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Indices of positions of each input sequence token in the position embeddings. 
Selected in the range `[0,\n            config.n_positions - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):\n            Pre-computed hidden-states (keys and values in the self-attention blocks and in the cross-attention\n            blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`\n            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.\n\n            Two formats are allowed:\n            - a [`~cache_utils.Cache`] instance;\n            - a tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of\n            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. This is also known as the legacy\n            cache format.\n\n            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the\n            legacy cache format will be returned.\n\n            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't\n            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`\n            of shape `(batch_size, sequence_length)`.\n        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 
This\n            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n            model's internal embedding lookup matrix.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n            `past_key_values`).\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for\n            more detail.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n\"\"\"\n\n\nclass KDeepseekV2Model(BaseInjectedModule):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`DeepseekV2DecoderLayer`]\n\n    Args:\n        config: DeepseekV2Config\n    \"\"\"\n\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module,\n        device: str = \"cuda\",\n        per_layer_prefill_intput_threshold: int = 30000,  # if None, no per-layer prefill\n        transfer_map: dict = None,\n        **kwargs,\n    ):\n        BaseInjectedModule.__init__(\n            self, key, gguf_loader, config, orig_module, device, **kwargs\n        )\n        self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold\n        self.transfer_map = transfer_map\n        self.stream_device_map = dict()\n\n    @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[List[torch.FloatTensor]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n        per_layer_prefill_intput_threshold: (\n            int | None\n        ) = None,  # if None, no per-layer prefill\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        self.gradient_checkpointing = False\n        if per_layer_prefill_intput_threshold is None:\n            per_layer_prefill_intput_threshold = self.per_layer_prefill_intput_threshold\n        per_layer_prefill_flag = False\n        seq_lenth = (\n            inputs_embeds.size(1) if inputs_embeds is not None else input_ids.size(1)\n        )\n        if (\n            per_layer_prefill_intput_threshold\n            and 
per_layer_prefill_intput_threshold < seq_lenth\n        ):\n            per_layer_prefill_flag = True\n            for layer in self.layers:\n                self.load_layer_to(layer, InferenceState.UNLOAD)\n            torch.cuda.empty_cache()\n        else:\n            pass\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        # retrieve input_ids and inputs_embeds\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\n                \"You cannot specify both input_ids and inputs_embeds at the same time\"\n            )\n        elif input_ids is not None:\n            batch_size, seq_length = input_ids.shape[:2]\n        elif inputs_embeds is not None:\n            batch_size, seq_length = inputs_embeds.shape[:2]\n        else:\n            raise ValueError(\"You have to specify either input_ids or inputs_embeds\")\n\n        if self.gradient_checkpointing and self.training:\n            if use_cache:\n                logger.warning_once(\n                    \"`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`.\"\n                )\n                use_cache = False\n\n        past_key_values_length = 0\n        if use_cache:\n            use_legacy_cache = not isinstance(past_key_values, Cache)\n            if use_legacy_cache:\n                past_key_values = DynamicCache.from_legacy_cache(past_key_values)\n            past_key_values_length = past_key_values.get_usable_length(seq_length)\n\n        if inputs_embeds is None:\n            org_device = input_ids.device\n            # TODO move to embed_tokens's device, not hard code to cpu\n            # input_ids = input_ids.to(\"cpu\")\n            input_ids = input_ids.to(self.embed_tokens.weight.device)\n            inputs_embeds = self.embed_tokens(input_ids).to(org_device)\n            input_ids = input_ids.to(org_device)\n\n        if cache_position is None:\n            past_seen_tokens = (\n                past_key_values.get_seq_length() if past_key_values is not None else 0\n            )\n            cache_position = torch.arange(\n                past_seen_tokens,\n                past_seen_tokens + inputs_embeds.shape[1],\n                device=inputs_embeds.device,\n            )\n\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0)\n\n        if inputs_embeds.device.type == \"xpu\" and position_ids is not None:\n            cos, sin = self.layers[0].self_attn.rotary_emb(inputs_embeds,\n                                                           position_ids)\n            position_embeddings = (cos, sin)\n        else:\n            position_embeddings = None\n\n        if per_layer_prefill_flag:\n            causal_mask = None\n        else:\n            if (os.name == 'nt'\n                or get_compute_capability() < 8\n                or (self.transfer_map is not None and 'cpu' in self.transfer_map.values())\n                or device_manager.gpu_vendor != GPUVendor.NVIDIA):\n                # print(\"for Windows or 
GPU before ampere, use forward_windows\")\n                # only use mask in forward windows or can't flash attn\n                causal_mask = self._update_causal_mask(\n                    attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions\n                )\n            else:\n                causal_mask = None\n\n        # embed positions\n        hidden_states = inputs_embeds\n        if per_layer_prefill_flag:\n            print(f\"Total length of input_ids: {hidden_states.size(1)}\")\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        next_decoder_cache = None\n\n        t_gpu = 0\n        t_cpu = 0\n        t_f = 0\n\n        for i, decoder_layer in enumerate(self.layers):\n            # print(f\"@@@@@@@@@@@@@@@@@layer {i}@@@@@@@@@@@@@@@@@@@@ \\n\")\n            if self.transfer_map is not None and i in self.transfer_map:\n                prev_stream = torch.cuda.current_stream()\n                cur_device = self.transfer_map[i]\n                if cur_device not in self.stream_device_map and cur_device.lower() != \"cpu\":\n                    self.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)\n                if cur_device.lower() != \"cpu\":\n                    torch.cuda.set_device(cur_device)\n                    self.stream_device_map[cur_device].wait_stream(prev_stream)\n                    torch.cuda.set_stream(self.stream_device_map[cur_device])\n                hidden_states = hidden_states.to(\n                    self.transfer_map[i], non_blocking=True\n                )\n                causal_mask = (\n                    causal_mask.to(self.transfer_map[i], non_blocking=True)\n                    if causal_mask is not None\n                    else None\n                )\n                position_ids = (\n                    position_ids.to(self.transfer_map[i], non_blocking=True)\n   
                 if position_ids is not None\n                    else None\n                )\n                cache_position = (\n                    cache_position.to(self.transfer_map[i], non_blocking=True)\n                    if cache_position is not None\n                    else None\n                )\n\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            if self.gradient_checkpointing and self.training:\n                layer_outputs = self._gradient_checkpointing_func(\n                    decoder_layer.__call__,\n                    hidden_states,\n                    causal_mask,\n                    position_ids,\n                    past_key_values,\n                    output_attentions,\n                    use_cache,\n                    cache_position,\n                )\n            else:\n                t3 = time.time()\n                if per_layer_prefill_flag:\n                    # print(f\"to gpu\")\n                    self.load_layer_to(decoder_layer, InferenceState.PREFILL)\n                    torch.cuda.empty_cache()\n                t4 = time.time()\n                # with open(\"log.txt\", \"a\") as f:\n                #     f.write(f\"@@@@@@@@@@@@@@@@@layer {i}@@@@@@@@@@@@@@@@@@@@ \\n\")\n                layer_outputs = decoder_layer(\n                    hidden_states,\n                    attention_mask=causal_mask,\n                    position_ids=position_ids,\n                    past_key_value=past_key_values,\n                    output_attentions=output_attentions,\n                    use_cache=use_cache,\n                    cache_position=cache_position,\n                    position_embeddings=position_embeddings,\n                )\n                t5 = time.time()\n                if per_layer_prefill_flag:\n                    # print(f\"to cpu\")\n                    self.load_layer_to(decoder_layer, InferenceState.UNLOAD)\n                    
torch.cuda.empty_cache()\n                t6 = time.time()\n            t_gpu += t4 - t3\n            t_cpu += t6 - t5\n            t_f += t5 - t4\n\n            hidden_states = layer_outputs[0]\n\n            # @@@@@@@ TODO open this notes, tmp close to fit deepseekv3\n            if use_cache:\n                next_decoder_cache = layer_outputs[2 if output_attentions else 1]\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n        hidden_states = self.norm(hidden_states)\n        # with open(\"log.txt\", \"a\") as f:\n        #     f.write(f\"@@@After layers\\n\")\n        #     f.write(f\"hidden_states={hidden_states}\\n\")\n        #     f.write(f\"hidden_states.shape={hidden_states.shape}\\n\")\n\n        if per_layer_prefill_flag:\n            t6 = time.time()\n            # print(f\"restore\")\n            per_layer_prefill_flag = False\n            for layer in self.layers:\n                self.load_layer_to(layer, InferenceState.GENERATE)\n            torch.cuda.empty_cache()\n            t7 = time.time()\n\n            print(\n                f\"total time: {t7-t3}, \\n layer num: {len(self.layers)}, gpu time: {t_gpu}, cpu time: {t_cpu}, forward time: {t_f}, restore time: {t7-t6}\"\n            )\n\n        # add hidden states from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        next_cache = None\n        if use_cache:\n            next_cache = (\n                next_decoder_cache.to_legacy_cache()\n                if use_legacy_cache\n                else next_decoder_cache\n            )\n        if not return_dict:\n            return tuple(\n                v\n                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]\n                if v is not None\n            )\n        return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=next_cache,\n            
hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n        )\n\n    def load_layer_to(self, layer: DeepseekV2DecoderLayer, target: InferenceState):\n        assert isinstance(\n            layer, DeepseekV2DecoderLayer\n        ), \"layer should be a DeepseekV2DecoderLayer\"\n\n        # TODO: support restoring to the original device, not only cuda\n        device = \"cpu\" if target == InferenceState.UNLOAD else \"cuda\"\n\n        # TODO: walk the submodules (DFS) and choose `to` or `set_inference_mode` by module type\n\n        # attn\n        layer.self_attn.to(device)\n\n        # mlp\n        if isinstance(layer.mlp, DeepseekV2MoE):\n            layer.mlp.gate.to(device)\n            layer.mlp.experts.set_inference_mode(target)\n            layer.mlp.shared_experts.gate_proj.set_inference_mode(target)\n            layer.mlp.shared_experts.up_proj.set_inference_mode(target)\n            layer.mlp.shared_experts.down_proj.set_inference_mode(target)\n            layer.mlp.shared_experts.act_fn.to(device)\n            # layer.mlp.shared_expert_gate.to(device)\n        else:\n            layer.mlp.gate_proj.set_inference_mode(target)\n            layer.mlp.up_proj.set_inference_mode(target)\n            layer.mlp.down_proj.set_inference_mode(target)\n            layer.mlp.act_fn.to(device)\n        # layer norm\n        layer.input_layernorm.to(device)\n        layer.post_attention_layernorm.to(device)\n\n\nLLAMA_START_DOCSTRING = r\"\"\"\n    This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the\n    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,\n    etc.).\n\n    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.\n    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage\n    and behavior.\n\n    Parameters:\n        config ([`LlamaConfig`]):\n            Model configuration class with all the parameters of the model. Initializing with a config file does not\n            load the weights associated with the model, only the configuration. Check out the\n            [`~PreTrainedModel.from_pretrained`] method to load the model weights.\n\"\"\"\n\nLLAMA_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n\n            Indices can be obtained using [`AutoTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see\n            `past_key_values`).\n\n            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]\n            and modify it to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more\n            information on the default strategy.\n\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,\n            config.n_positions - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):\n            Pre-computed hidden-states (keys and values in the self-attention blocks and in the cross-attention\n            blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`\n            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.\n\n            Two formats are allowed:\n            - a [`~cache_utils.Cache`] instance;\n            - a tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of\n            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. This is also known as the legacy\n            cache format.\n\n            The model will output the same cache format that is fed as input. 
If no `past_key_values` are passed, the\n            legacy cache format will be returned.\n\n            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't\n            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`\n            of shape `(batch_size, sequence_length)`.\n        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This\n            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n            model's internal embedding lookup matrix.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n            `past_key_values`).\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for\n            more detail.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):\n            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,\n            this tensor is not affected by padding. 
It is used to update the cache in the correct position and to infer\n            the complete sequence length.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare LLaMA Model outputting raw hidden-states without any specific head on top.\",\n    LLAMA_START_DOCSTRING,\n)\nclass LlamaPreTrainedModel(PreTrainedModel):\n    config_class = LlamaConfig\n    base_model_prefix = \"model\"\n    supports_gradient_checkpointing = True\n    _no_split_modules = [\"LlamaDecoderLayer\"]\n    _skip_keys_device_placement = [\"past_key_values\"]\n    _supports_flash_attn_2 = True\n    _supports_sdpa = True\n    _supports_cache_class = True\n    _supports_quantized_cache = True\n    _supports_static_cache = True\n\n    def _init_weights(self, module):\n        std = self.config.initializer_range\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.bias is not None:\n                module.bias.data.zero_()\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=std)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n\n\nclass KLlamaModel(BaseInjectedModule):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`LlamaDecoderLayer`]\n\n    Args:\n        config: LlamaConfig\n    \"\"\"\n\n    dynamic_sdpa = None\n\n    def __init__(\n        self,\n        key: str,\n        gguf_loader: GGUFLoader,\n        config: PretrainedConfig,\n        orig_module: nn.Module,\n        device: str = \"cuda\",\n        per_layer_prefill_intput_threshold: int = 30000,  # if None, no per-layer prefill\n        transfer_map: dict = None,\n        **kwargs,\n    ):\n\n        BaseInjectedModule.__init__(\n            self, key, gguf_loader, config, orig_module, device, **kwargs\n        )\n        self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold\n        self.transfer_map = transfer_map\n        self.stream_device_map = dict()\n        user_path: str = os.path.expanduser('~')\n        localstore_path: str = os.path.join(user_path,'.ktransformers')\n        config_path: str = os.path.join(localstore_path,Config.CONFIG_FILE_NAME)\n        with open(config_path,\"r\") as file:\n            config_yaml = yaml.safe_load(file.read())\n            self.long_context_config = config_yaml.get(\"long_context\")\n            self.ext_config = config_yaml.get(\"ext\")\n\n        KLlamaModel.dynamic_sdpa = DynamicScaledDotProductAttention(\n            max_seq_len=self.long_context_config[\"max_seq_len\"],\n            block_size=self.long_context_config[\"block_size\"],\n            config=config,\n            device=torch.device(\"cuda\"),\n            local_windows_len=self.long_context_config[\"local_windows_len\"],\n            topk=self.long_context_config[\"second_select_num\"],\n            threads_num=self.ext_config[\"cpu_infer\"],\n            anchor_type=self.long_context_config[\"anchor_type\"],\n            kv_type=self.long_context_config[\"kv_type\"],\n            dense_layer_num=self.long_context_config[\"dense_layer_num\"],\n            anchor_num=self.long_context_config[\"anchor_num\"],\n            
preselect_block=self.long_context_config[\"preselect_block\"],\n            block_selection_mode=self.long_context_config[\"head_select_mode\"],\n            preselect_block_count=self.long_context_config[\"preselect_block_count\"],\n            layer_step=self.long_context_config[\"layer_step\"],\n            token_step=self.long_context_config[\"token_step\"],\n            prefill_chunk_size=self.long_context_config[\"chunk_size\"],\n            use_attn_sparsity=False,\n        )\n\n    def get_input_embeddings(self):\n        return self.embed_tokens\n\n    def set_input_embeddings(self, value):\n        self.embed_tokens = value\n\n    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)\n    def forward(\n        self,\n        input_ids: torch.LongTensor = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.LongTensor] = None,\n        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,\n        inputs_embeds: Optional[torch.FloatTensor] = None,\n        use_cache: Optional[bool] = None,\n        output_attentions: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        cache_position: Optional[torch.LongTensor] = None,\n    ) -> Union[Tuple, BaseModelOutputWithPast]:\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        if (input_ids is None) ^ (inputs_embeds is not None):\n            raise ValueError(\n     
           \"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one\"\n            )\n\n        if self.gradient_checkpointing and self.training and use_cache:\n            logger.warning_once(\n                \"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.\"\n            )\n            use_cache = False\n\n        return_legacy_cache = False\n        if (\n            use_cache and not isinstance(past_key_values, Cache) and not self.training\n        ):  # kept for BC (non `Cache` `past_key_values` inputs)\n            return_legacy_cache = True\n            past_key_values = DynamicCache.from_legacy_cache(past_key_values)\n            logger.warning_once(\n                \"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. \"\n                \"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\"\n            )\n\n        if cache_position is None:\n            past_seen_tokens = (\n                past_key_values.get_seq_length() if past_key_values is not None else 0\n            )\n            cache_position = torch.arange(\n                past_seen_tokens,\n                past_seen_tokens + inputs_embeds.shape[1],\n                device=\"cuda\",\n            )\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0)\n\n        causal_mask = None\n        chunck_size = self.long_context_config[\"chunk_size\"]\n        cur_idx = 0\n        if inputs_embeds is None:\n            inputs_embeds = self.embed_tokens(input_ids.to(\"cpu\"))\n        q_len = cache_position.size(0)\n\n        # generate\n        if q_len == 1:\n            x = inputs_embeds[:, -1:, :]\n            position_ids = position_ids[:, -1:]\n            return self.forward_chunk(\n                x,\n                
causal_mask,\n                position_ids,\n                past_key_values,\n                output_attentions,\n                use_cache,\n                cache_position,\n                output_hidden_states,\n                return_dict,\n            )\n        elif q_len <= chunck_size:\n            inputs_embeds = inputs_embeds.to('cuda')\n            output = self.forward_chunk(\n                inputs_embeds,\n                causal_mask,\n                position_ids,\n                past_key_values,\n                output_attentions,\n                use_cache,\n                cache_position,\n                output_hidden_states,\n                return_dict,\n            )\n            KLlamaModel.dynamic_sdpa.calc_anchor(cache_position[-1] + 1)\n            KLlamaModel.dynamic_sdpa.clear_importance(cache_position[-1] + 1)\n            return output\n        cur_idx = 0\n        assert (\n            output_attentions == False\n        ), \"output_attentions is not supported when using chunked attention\"\n        attn_output = None\n        # prefill\n        KLlamaModel.dynamic_sdpa.remaining_length = q_len\n        while cur_idx < q_len:\n            print(f'current prefill length: {cur_idx}')\n            chunk_mask = None\n            if inputs_embeds.device.type == 'cpu':\n                tmp_inputs_embeds = inputs_embeds[:, cur_idx : min(cur_idx + chunck_size, q_len)].to(\"cuda\")\n            else:\n                tmp_inputs_embeds = inputs_embeds[:, cur_idx : min(cur_idx + chunck_size, q_len)]\n            output_with_past = self.forward_chunk(\n                tmp_inputs_embeds,\n                chunk_mask,\n                position_ids[:, cur_idx : min(cur_idx + chunck_size, q_len)],\n                past_key_values,\n                output_attentions,\n                use_cache,\n                cache_position[cur_idx : min(cur_idx + chunck_size, q_len)],\n            )\n            cur_output = output_with_past.last_hidden_state\n     
       KLlamaModel.dynamic_sdpa.remaining_length -= (\n                min(cur_idx + chunck_size, q_len) - cur_idx\n            )\n            cur_idx += chunck_size\n            # if attn_output is None:\n            attn_output = cur_output\n            # else:\n            #     attn_output = torch.cat((attn_output, cur_output), dim=-2)\n\n        KLlamaModel.dynamic_sdpa.calc_anchor(cache_position[-1] + 1)\n        KLlamaModel.dynamic_sdpa.clear_importance(cache_position[-1] + 1)\n        return BaseModelOutputWithPast(last_hidden_state=attn_output)\n\n    def forward_chunk(\n        self,\n        inputs_embeds,\n        causal_mask,\n        position_ids,\n        past_key_values,\n        output_attentions,\n        use_cache,\n        cache_position,\n        output_hidden_states: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n    ):\n\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        return_legacy_cache = False\n        if use_cache and not isinstance(\n            past_key_values, Cache\n        ):  # kept for BC (non `Cache` `past_key_values` inputs)\n            return_legacy_cache = True\n            past_key_values = DynamicCache.from_legacy_cache(past_key_values)\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        hidden_states = inputs_embeds\n\n        # create position embeddings to be shared across the decoder layers\n        position_embeddings = self.rotary_emb(hidden_states, position_ids)\n\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        next_decoder_cache = None\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if 
output_attentions else None\n        next_decoder_cache = None\n\n        for decoder_layer in self.layers:\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            if self.gradient_checkpointing and self.training:\n                layer_outputs = self._gradient_checkpointing_func(\n                    decoder_layer.__call__,\n                    hidden_states,\n                    causal_mask,\n                    position_ids,\n                    past_key_values,\n                    output_attentions,\n                    use_cache,\n                    cache_position,\n                    position_embeddings,\n                )\n            else:\n                layer_outputs = decoder_layer(\n                    hidden_states,\n                    attention_mask=causal_mask,\n                    position_ids=position_ids,\n                    past_key_value=past_key_values,\n                    output_attentions=output_attentions,\n                    use_cache=use_cache,\n                    cache_position=cache_position,\n                    position_embeddings=position_embeddings,\n                )\n\n            hidden_states = layer_outputs[0]\n\n            if use_cache:\n                next_decoder_cache = layer_outputs[2 if output_attentions else 1]\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n        hidden_states = self.norm(hidden_states)\n\n        # add hidden states from the last decoder layer\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        next_cache = next_decoder_cache if use_cache else None\n        if return_legacy_cache:\n            next_cache = next_cache.to_legacy_cache()\n\n        if not return_dict:\n            return tuple(\n                v\n                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]\n                if v is not None\n            )\n    
    return BaseModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=next_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n        )\n\n    def _update_causal_mask(\n        self,\n        attention_mask: torch.Tensor,\n        input_tensor: torch.Tensor,\n        cache_position: torch.Tensor,\n        past_key_values: Cache,\n        output_attentions: bool,\n    ):\n        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static\n        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.\n        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using\n        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114\n\n        if self.config._attn_implementation == \"flash_attention_2\":\n            if attention_mask is not None and 0.0 in attention_mask:\n                return attention_mask\n            return None\n\n        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in\n        # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail\n        # to infer the attention mask.\n        past_seen_tokens = (\n            past_key_values.get_seq_length() if past_key_values is not None else 0\n        )\n        using_static_cache = isinstance(past_key_values, StaticCache)\n\n        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward\n        if (\n            self.config._attn_implementation == \"sdpa\"\n            and not using_static_cache\n            and not output_attentions\n        ):\n            if AttentionMaskConverter._ignore_causal_mask_sdpa(\n                attention_mask,\n                inputs_embeds=input_tensor,\n                past_key_values_length=past_seen_tokens,\n                is_training=self.training,\n            ):\n                return None\n\n        dtype, device = input_tensor.dtype, input_tensor.device\n        min_dtype = torch.finfo(dtype).min\n        sequence_length = input_tensor.shape[1]\n        if using_static_cache:\n            target_length = past_key_values.get_max_length()\n        else:\n            target_length = (\n                attention_mask.shape[-1]\n                if isinstance(attention_mask, torch.Tensor)\n                else past_seen_tokens + sequence_length + 1\n            )\n\n        if attention_mask is not None and attention_mask.dim() == 4:\n            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing\n            if attention_mask.max() != 0:\n                raise ValueError(\n                    \"Custom 4D attention mask should be passed in inverted form with max==0`\"\n                )\n            causal_mask = attention_mask\n        else:\n            causal_mask = torch.full(\n                (sequence_length, target_length),\n                fill_value=min_dtype,\n                dtype=dtype,\n                device=device,\n      
      )\n            if sequence_length != 1:\n                causal_mask = torch.triu(causal_mask, diagonal=1)\n            causal_mask *= torch.arange(\n                target_length, device=device\n            ) > cache_position.reshape(-1, 1)\n            causal_mask = causal_mask[None, None, :, :].expand(\n                input_tensor.shape[0], 1, -1, -1\n            )\n            if attention_mask is not None:\n                causal_mask = (\n                    causal_mask.clone()\n                )  # copy to contiguous memory for in-place edit\n                mask_length = attention_mask.shape[-1]\n                padding_mask = (\n                    causal_mask[:, :, :, :mask_length]\n                    + attention_mask[:, None, None, :]\n                )\n                padding_mask = padding_mask == 0\n                causal_mask[:, :, :, :mask_length] = causal_mask[\n                    :, :, :, :mask_length\n                ].masked_fill(padding_mask, min_dtype)\n        if (\n            self.config._attn_implementation == \"sdpa\"\n            and attention_mask is not None\n            and attention_mask.device.type == \"cuda\"\n            and not output_attentions\n        ):\n            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when\n            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.\n            # Details: https://github.com/pytorch/pytorch/issues/110213\n            causal_mask = AttentionMaskConverter._unmask_unattended(\n                causal_mask, min_dtype\n            )\n\n        return causal_mask\n\n\n\nQWEN3MOE_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):\n            Indices of input sequence tokens in the vocabulary. 
Padding will be ignored by default should you provide\n            it.\n\n            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n            [`PreTrainedTokenizer.__call__`] for details.\n\n            [What are input IDs?](../glossary#input-ids)\n        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n\n            [What are attention masks?](../glossary#attention-mask)\n\n            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see\n            `past_key_values`).\n\n            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]\n            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more\n            information on the default strategy.\n        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):\n            Indices of the position of each input sequence token in the position embeddings. Selected in the range `[0,\n            config.n_positions - 1]`.\n\n            [What are position IDs?](../glossary#position-ids)\n        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):\n            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention\n            blocks) that can be used to speed up sequential decoding. 
This typically consists in the `past_key_values`\n            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.\n\n            Two formats are allowed:\n            - a [`~cache_utils.Cache`] instance;\n            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of\n            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy\n            cache format.\n\n            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the\n            legacy cache format will be returned.\n\n            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't\n            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`\n            of shape `(batch_size, sequence_length)`.\n        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):\n            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This\n            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n            model's internal embedding lookup matrix.\n        use_cache (`bool`, *optional*):\n            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n            `past_key_values`).\n        output_attentions (`bool`, *optional*):\n            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned\n            tensors for more detail.\n        output_hidden_states (`bool`, *optional*):\n            Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors for\n            more detail.\n        output_router_logits (`bool`, *optional*):\n            Whether or not to return the logits of all the routers. They are useful for computing the router loss, and\n            should not be returned during inference.\n        return_dict (`bool`, *optional*):\n            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.\n        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):\n            Indices depicting the position of the input sequence tokens in the sequence. Contrary to `position_ids`,\n            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer\n            the complete sequence length.\n\"\"\"\n\nclass KQwen3MoeModel(BaseInjectedModule):\n    \"\"\"\n    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen3MoeDecoderLayer`]\n\n    Args:\n        config: Qwen3MoeConfig\n    \"\"\"\n\n    def __init__(\n            self,\n            key: str,\n            gguf_loader: GGUFLoader,\n            config: PretrainedConfig,\n            orig_module: nn.Module,\n            device: str = \"cuda\",\n            per_layer_prefill_intput_threshold: Optional[int] = 30000,  # None or 0 disables per-layer prefill\n            transfer_map: Optional[dict] = None,\n            **kwargs,\n    ):\n        BaseInjectedModule.__init__(\n            self, key, gguf_loader, config, orig_module, device, **kwargs\n        )\n        self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold\n        self.transfer_map = transfer_map\n        self.stream_device_map = dict()\n        self.head_dim = getattr(config, \"head_dim\", config.hidden_size // config.num_attention_heads)\n        self.max_position_embeddings = config.max_position_embeddings\n        self.rope_theta = config.rope_theta\n        self.rotary_emb = Qwen2MoeRotaryEmbedding(\n            
self.head_dim,\n            max_position_embeddings=self.max_position_embeddings,\n            base=self.rope_theta,\n        )\n\n    @add_start_docstrings_to_model_forward(QWEN3MOE_INPUTS_DOCSTRING)\n    def forward(\n            self,\n            input_ids: torch.LongTensor = None,\n            attention_mask: Optional[torch.Tensor] = None,\n            position_ids: Optional[torch.LongTensor] = None,\n            past_key_values: Optional[List[torch.FloatTensor]] = None,\n            inputs_embeds: Optional[torch.FloatTensor] = None,\n            use_cache: Optional[bool] = None,\n            output_attentions: Optional[bool] = None,\n            output_hidden_states: Optional[bool] = None,\n            output_router_logits: Optional[bool] = None,\n            return_dict: Optional[bool] = None,\n            cache_position: Optional[torch.LongTensor] = None,\n            per_layer_prefill_intput_threshold: (\n                    int | None\n            ) = None,  # None falls back to the instance default; 0 disables per-layer prefill\n    ) -> Union[Tuple, MoeModelOutputWithPast]:\n        # print(f'Total length of input_ids: {input_ids.size(1)}, {input_ids.size()}')\n\n        if per_layer_prefill_intput_threshold is None:\n            per_layer_prefill_intput_threshold = self.per_layer_prefill_intput_threshold\n        per_layer_prefill_flag = False\n        seq_length = (\n            inputs_embeds.size(1) if inputs_embeds is not None else input_ids.size(1)\n        )\n        # Long prompts are prefilled layer by layer: unload every layer first,\n        # then load each one onto the GPU only while it runs.\n        if (\n                per_layer_prefill_intput_threshold\n                and per_layer_prefill_intput_threshold < seq_length\n        ):\n            per_layer_prefill_flag = True\n            for layer in self.layers:\n                self.load_layer_to(layer, InferenceState.UNLOAD)\n        output_attentions = (\n            output_attentions\n            if output_attentions is not None\n            else self.config.output_attentions\n        )\n        output_router_logits = (\n   
         output_router_logits\n            if output_router_logits is not None\n            else self.config.output_router_logits\n        )\n        output_hidden_states = (\n            output_hidden_states\n            if output_hidden_states is not None\n            else self.config.output_hidden_states\n        )\n        use_cache = use_cache if use_cache is not None else self.config.use_cache\n\n        return_dict = (\n            return_dict if return_dict is not None else self.config.use_return_dict\n        )\n\n        if (input_ids is None) ^ (inputs_embeds is not None):\n            raise ValueError(\n                \"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one\"\n            )\n\n        if self.gradient_checkpointing and self.training:\n            if use_cache:\n                logger.warning_once(\n                    \"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\"\n                )\n                use_cache = False\n\n        if use_cache and past_key_values is None:\n            past_key_values = DynamicCache()\n        # use_legacy_cache = False\n        # if use_cache and not isinstance(past_key_values, Cache):\n        #     use_legacy_cache = True\n        #     past_key_values = DynamicCache.from_legacy_cache(past_key_values)\n        #     logger.warning_once(\n        #         \"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. 
\"\n        #         \"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\"\n        #     )\n\n        if inputs_embeds is None:\n            input_ids = input_ids.to(\"cpu\")\n            inputs_embeds = self.embed_tokens(input_ids)\n            inputs_embeds = inputs_embeds.to(\"cuda\")\n\n        if cache_position is None:\n            past_seen_tokens = (\n                past_key_values.get_seq_length() if past_key_values is not None else 0\n            )\n            cache_position = torch.arange(\n                past_seen_tokens,\n                past_seen_tokens + inputs_embeds.shape[1],\n                device=inputs_embeds.device,\n            )\n        if position_ids is None:\n            position_ids = cache_position.unsqueeze(0)\n\n        causal_mask = self._update_causal_mask(\n            attention_mask,\n            inputs_embeds,\n            cache_position,\n            past_key_values,\n            output_attentions,\n        )\n\n        hidden_states = inputs_embeds\n\n        # position_embeddings = self.rotary_emb(hidden_states, position_ids)\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n        all_router_logits = () if output_router_logits else None\n        # next_decoder_cache = None\n\n        for i, decoder_layer in enumerate(self.layers):\n            # if self.transfer_map is not None and i in self.transfer_map:\n            #     prev_stream = torch.cuda.current_stream()\n            #     cur_device = self.transfer_map[i]\n            #     if cur_device not in self.stream_device_map:\n            #         self.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)\n            #     torch.cuda.set_device(cur_device)\n            #     self.stream_device_map[cur_device].wait_stream(prev_stream)\n            #     
torch.cuda.set_stream(self.stream_device_map[cur_device])\n            #     hidden_states = hidden_states.to(\n            #         self.transfer_map[i], non_blocking=True\n            #     )\n            #     causal_mask = (\n            #         causal_mask.to(self.transfer_map[i], non_blocking=True)\n            #         if causal_mask is not None\n            #         else None\n            #     )\n            #     position_ids = (\n            #         position_ids.to(self.transfer_map[i], non_blocking=True)\n            #         if position_ids is not None\n            #         else None\n            #     )\n            #     cache_position = (\n            #         cache_position.to(self.transfer_map[i], non_blocking=True)\n            #         if cache_position is not None\n            #         else None\n            #     )\n\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            if self.gradient_checkpointing and self.training:\n                layer_outputs = self._gradient_checkpointing_func(\n                    decoder_layer.__call__,\n                    hidden_states,\n                    causal_mask,\n                    position_ids,\n                    past_key_values,\n                    output_attentions,\n                    output_router_logits,\n                    use_cache,\n                    cache_position,\n                    # position_embeddings,\n                )\n            else:\n                if per_layer_prefill_flag:\n                    # print(f\"to gpu\")\n                    self.load_layer_to(decoder_layer, InferenceState.PREFILL)\n                    torch.cuda.empty_cache()\n                layer_outputs = decoder_layer(\n                    hidden_states,\n                    attention_mask=causal_mask,\n                    position_ids=position_ids,\n                    past_key_value=past_key_values,\n                    
output_attentions=output_attentions,\n                    output_router_logits=output_router_logits,\n                    use_cache=use_cache,\n                    cache_position=cache_position,\n                    # position_embeddings=position_embeddings,\n                )\n                if per_layer_prefill_flag:\n                    # print(f\"to cpu\")\n                    self.load_layer_to(decoder_layer, InferenceState.UNLOAD)\n                    torch.cuda.empty_cache()\n            hidden_states = layer_outputs[0]\n            # use_cache=False\n            # if use_cache:\n            #     next_decoder_cache = layer_outputs[2 if output_attentions else 1]\n\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n            if output_router_logits and layer_outputs[-1] is not None:\n                all_router_logits += (layer_outputs[-1],)\n\n        hidden_states = self.norm(hidden_states)\n\n        if per_layer_prefill_flag:\n            per_layer_prefill_flag = False\n            for layer in self.layers:\n                self.load_layer_to(layer, InferenceState.GENERATE)\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        # next_cache = None\n        # if use_cache:\n        #     next_cache = (\n        #         next_decoder_cache.to_legacy_cache()\n        #         if use_legacy_cache\n        #         else next_decoder_cache\n        #     )\n\n        if not return_dict:\n            return tuple(\n                v\n                for v in [\n                    hidden_states,\n                    past_key_values,\n                    all_hidden_states,\n                    all_self_attns,\n                    all_router_logits,\n                ]\n                if v is not None\n            )\n        return MoeModelOutputWithPast(\n            last_hidden_state=hidden_states,\n            past_key_values=past_key_values,\n            
hidden_states=all_hidden_states,\n            attentions=all_self_attns,\n            router_logits=all_router_logits,\n        )\n\n    def load_layer_to(self, layer: Qwen3MoeDecoderLayer, target: InferenceState):\n        assert isinstance(\n            layer, Qwen3MoeDecoderLayer\n        ), \"layer should be a Qwen3MoeDecoderLayer\"\n\n        # TODO: support restoring to the original device, not only cuda\n        device = \"cpu\" if target == InferenceState.UNLOAD else \"cuda\"\n\n        # attention projections and rotary embedding\n        layer.self_attn.q_proj.set_inference_mode(target)\n        layer.self_attn.k_proj.set_inference_mode(target)\n        layer.self_attn.v_proj.set_inference_mode(target)\n        layer.self_attn.o_proj.set_inference_mode(target)\n        layer.self_attn.rotary_emb = layer.self_attn.rotary_emb.to(device)\n\n        # mlp (sparse MoE block or dense MLP)\n        if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):\n            layer.mlp.gate.set_inference_mode(target)\n            layer.mlp.experts.set_inference_mode(target)\n            layer.mlp.shared_expert.gate_proj.set_inference_mode(target)\n            layer.mlp.shared_expert.up_proj.set_inference_mode(target)\n            layer.mlp.shared_expert.down_proj.set_inference_mode(target)\n            layer.mlp.shared_expert.act_fn.to(device)\n            layer.mlp.shared_expert_gate.to(device)\n        else:\n            layer.mlp.gate_proj.set_inference_mode(target)\n            layer.mlp.up_proj.set_inference_mode(target)\n            layer.mlp.down_proj.set_inference_mode(target)\n            layer.mlp.act_fn.to(device)\n        # layer norm\n        layer.input_layernorm.to(device)\n        layer.post_attention_layernorm.to(device)\n"
  },
  {
    "path": "kt-sft/ktransformers/operators/triton_attention.py",
    "content": "# Adapted from\r\n# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py\r\n# which was originally adapted from\r\n# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py\r\n# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py\r\n\r\nimport triton\r\nimport triton.language as tl\r\nfrom ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor\r\n@triton.jit\r\ndef tanh(x):\r\n    # Tanh is just a scaled sigmoid\r\n    return 2 * tl.sigmoid(2 * x) - 1\r\n\r\n@triton.jit\r\ndef _fwd_grouped_kernel_stage1(\r\n    Q,\r\n    K_Buffer,\r\n    V_Buffer,\r\n    sm_scale,\r\n    Req_to_tokens,\r\n    B_Seqlen,\r\n    Att_Out,\r\n    stride_req_to_tokens_b,\r\n    stride_qbs,\r\n    stride_qh,\r\n    stride_buf_kbs,\r\n    stride_buf_kh,\r\n    stride_buf_vbs,\r\n    stride_buf_vh,\r\n    stride_mid_ob,\r\n    stride_mid_oh,\r\n    stride_mid_os,\r\n    kv_group_num: tl.constexpr,\r\n    q_head_num: tl.constexpr,\r\n    BLOCK_DMODEL: tl.constexpr,\r\n    BLOCK_DPE: tl.constexpr,\r\n    BLOCK_DV: tl.constexpr,\r\n    BLOCK_N: tl.constexpr,\r\n    BLOCK_H: tl.constexpr,\r\n    NUM_KV_SPLITS: tl.constexpr,\r\n    PAGE_SIZE: tl.constexpr,\r\n    logit_cap: tl.constexpr,\r\n    Lk: tl.constexpr,\r\n    Lv: tl.constexpr,\r\n):\r\n    cur_batch = tl.program_id(0)\r\n    cur_head_id = tl.program_id(1)\r\n    cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)\r\n    split_kv_id = tl.program_id(2)\r\n\r\n    if kv_group_num > BLOCK_H:\r\n        VALID_BLOCK_H: tl.constexpr = BLOCK_H\r\n    else:\r\n        VALID_BLOCK_H: tl.constexpr = kv_group_num\r\n    cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)\r\n    mask_h = cur_head < 
(cur_head_id + 1) * VALID_BLOCK_H\r\n    mask_h = mask_h & (cur_head < q_head_num)\r\n\r\n    offs_d = tl.arange(0, BLOCK_DMODEL)\r\n    offs_dv = tl.arange(0, BLOCK_DV)\r\n    mask_d = offs_d < Lk\r\n    mask_dv = offs_dv < Lv\r\n    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\r\n    cur_batch_req_idx = cur_batch\r\n\r\n    offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[\r\n        None, :]\r\n    q = tl.load(Q + offs_q,\r\n                mask=(mask_h[:, None]) & (mask_d[None, :]),\r\n                other=0.0)\r\n\r\n    if BLOCK_DPE > 0:\r\n        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)\r\n        mask_dpe = offs_dpe < Lk\r\n        off_qpe = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh +\r\n                   offs_dpe[None, :])\r\n        qpe = tl.load(Q + off_qpe,\r\n                      mask=(mask_h[:, None]) & (mask_dpe[None, :]),\r\n                      other=0.0)\r\n\r\n    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)\r\n    split_kv_start = kv_len_per_split * split_kv_id\r\n    split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,\r\n                              cur_batch_seq_len)\r\n    \r\n    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float(\"inf\")\r\n    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)\r\n    acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)\r\n\r\n    if split_kv_end > split_kv_start:\r\n        for start_n in range(split_kv_start, split_kv_end, BLOCK_N):\r\n            offs_n = start_n + tl.arange(0, BLOCK_N)\r\n            kv_page_number = tl.load(\r\n                Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx +\r\n                offs_n // PAGE_SIZE,\r\n                mask=offs_n < split_kv_end,\r\n                other=0,\r\n            )\r\n            kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE\r\n            offs_buf_k = (kv_loc[None, :] * stride_buf_kbs +\r\n                          cur_kv_head * stride_buf_kh 
+ offs_d[:, None])\r\n            k = tl.load(\r\n                K_Buffer + offs_buf_k,\r\n                mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),\r\n                other=0.0,\r\n            )\r\n            qk = tl.dot(q, k.to(q.dtype))\r\n            \r\n            if BLOCK_DPE > 0:\r\n                offs_buf_kpe = (kv_loc[None, :] * stride_buf_kbs +\r\n                                cur_kv_head * stride_buf_kh +\r\n                                offs_dpe[:, None])\r\n                kpe = tl.load(\r\n                    K_Buffer + offs_buf_kpe,\r\n                    mask=(offs_n[None, :] < split_kv_end) &\r\n                    (mask_dpe[:, None]),\r\n                    other=0.0,\r\n                )\r\n                qk += tl.dot(qpe, kpe.to(qpe.dtype))\r\n            qk *= sm_scale\r\n\r\n            if logit_cap > 0:\r\n                qk = logit_cap * tanh(qk / logit_cap)\r\n\r\n            qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end),\r\n                          qk, float(\"-inf\"))\r\n\r\n            offs_buf_v = (kv_loc[:, None] * stride_buf_vbs +\r\n                          cur_kv_head * stride_buf_vh + offs_dv[None, :])\r\n            v = tl.load(\r\n                V_Buffer + offs_buf_v,\r\n                mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),\r\n                other=0.0,\r\n            )\r\n\r\n            n_e_max = tl.maximum(tl.max(qk, 1), e_max)\r\n            re_scale = tl.exp(e_max - n_e_max)\r\n            p = tl.exp(qk - n_e_max[:, None])\r\n            acc *= re_scale[:, None]\r\n            acc += tl.dot(p.to(v.dtype), v)\r\n\r\n            e_sum = e_sum * re_scale + tl.sum(p, 1)\r\n            e_max = n_e_max\r\n\r\n        offs_mid_o = (cur_batch * stride_mid_ob +\r\n                      cur_head[:, None] * stride_mid_oh +\r\n                      split_kv_id * stride_mid_os + offs_dv[None, :])\r\n\r\n        tl.store(\r\n            Att_Out + offs_mid_o,\r\n         
   acc / e_sum[:, None],\r\n            mask=(mask_h[:, None]) & (mask_dv[None, :]),\r\n        )\r\n\r\n        offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh +\r\n                        split_kv_id * stride_mid_os + Lv)\r\n\r\n        tl.store(\r\n            Att_Out + offs_mid_o_1,\r\n            e_max + tl.log(e_sum),\r\n            mask=mask_h,\r\n        )\r\n\r\ndef _decode_grouped_att_m_fwd(\r\n    q,\r\n    k_buffer,\r\n    v_buffer,\r\n    att_out,\r\n    Req_to_tokens,\r\n    B_Seqlen,\r\n    num_kv_splits,\r\n    sm_scale,\r\n    page_size,\r\n    logit_cap,\r\n):\r\n    BLOCK = 32\r\n    Lk = k_buffer.shape[-1]\r\n    Lv = v_buffer.shape[-1]\r\n\r\n    # [TODO] work around shmem limit on MI3xx\r\n    \r\n    # TODO: support hip\r\n    if device_manager.gpu_vendor == GPUVendor.AMD and Lk >= 576:\r\n       BLOCK = 16\r\n\r\n    if Lk == 576:\r\n        BLOCK_DMODEL = 512\r\n        BLOCK_DPE = 64\r\n    elif Lk == 288:\r\n        BLOCK_DMODEL = 256\r\n        BLOCK_DPE = 32\r\n    else:\r\n        BLOCK_DMODEL = triton.next_power_of_2(Lk)\r\n        BLOCK_DPE = 0\r\n    BLOCK_DV = triton.next_power_of_2(Lv)\r\n\r\n    batch, head_num = q.shape[0], q.shape[1]\r\n    kv_group_num = q.shape[1] // k_buffer.shape[-2]\r\n\r\n    BLOCK_H = 16\r\n    NUM_KV_SPLITS = num_kv_splits\r\n    grid = (\r\n        batch,\r\n        triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),\r\n        NUM_KV_SPLITS,\r\n    )\r\n\r\n    extra_kargs = {}\r\n    # TODO: support hip\r\n    \"\"\"\r\n    if is_hip_:\r\n        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html\r\n        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py\r\n        extra_kargs = {\r\n            \"waves_per_eu\": 4,\r\n            \"matrix_instr_nonkdim\": 16,\r\n            \"kpack\": 2\r\n        }\r\n    \"\"\"\r\n    \r\n    _fwd_grouped_kernel_stage1[grid](\r\n        q,\r\n        
k_buffer,\r\n        v_buffer,\r\n        sm_scale,\r\n        Req_to_tokens,\r\n        B_Seqlen,\r\n        att_out,\r\n        Req_to_tokens.stride(0),\r\n        q.stride(0),\r\n        q.stride(1),\r\n        k_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)\r\n        k_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)\r\n        v_buffer.stride(-3),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)\r\n        v_buffer.stride(-2),  # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)\r\n        att_out.stride(0),\r\n        att_out.stride(1),\r\n        att_out.stride(2),\r\n        kv_group_num=kv_group_num,\r\n        q_head_num=head_num,\r\n        BLOCK_DMODEL=BLOCK_DMODEL,\r\n        BLOCK_DPE=BLOCK_DPE,\r\n        BLOCK_DV=BLOCK_DV,\r\n        BLOCK_N=BLOCK,\r\n        BLOCK_H=BLOCK_H,\r\n        NUM_KV_SPLITS=NUM_KV_SPLITS,\r\n        PAGE_SIZE=page_size,\r\n        logit_cap=logit_cap,\r\n        num_warps=4,\r\n        num_stages=2,\r\n        Lk=Lk,\r\n        Lv=Lv,\r\n        **extra_kargs,\r\n    )\r\n\r\n@triton.jit\r\ndef _fwd_kernel_stage2(\r\n    Mid_O,\r\n    o,\r\n    B_Seqlen,\r\n    stride_mid_ob,\r\n    stride_mid_oh,\r\n    stride_mid_os,\r\n    stride_obs,\r\n    stride_oh,\r\n    NUM_KV_SPLITS: tl.constexpr,\r\n    BLOCK_DV: tl.constexpr,\r\n    Lv: tl.constexpr,\r\n):\r\n    cur_batch = tl.program_id(0)\r\n    cur_head = tl.program_id(1)\r\n\r\n    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\r\n\r\n    offs_d = tl.arange(0, BLOCK_DV)\r\n    mask_d = offs_d < Lv\r\n\r\n    e_sum = 0.0\r\n    e_max = -float(\"inf\")\r\n    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)\r\n\r\n    offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d\r\n    offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv\r\n\r\n    for split_kv_id in range(0, NUM_KV_SPLITS):\r\n        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)\r\n        split_kv_start = kv_len_per_split * 
split_kv_id\r\n        split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,\r\n                                  cur_batch_seq_len)\r\n\r\n        if split_kv_end > split_kv_start:\r\n            tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os,\r\n                         mask=mask_d,\r\n                         other=0.0)\r\n            tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)\r\n            n_e_max = tl.maximum(tlogic, e_max)\r\n\r\n            old_scale = tl.exp(e_max - n_e_max)\r\n            acc *= old_scale\r\n            exp_logic = tl.exp(tlogic - n_e_max)\r\n            acc += exp_logic * tv\r\n\r\n            e_sum = e_sum * old_scale + exp_logic\r\n            e_max = n_e_max\r\n\r\n    tl.store(\r\n        o + cur_batch * stride_obs + cur_head * stride_oh + offs_d,\r\n        acc / e_sum,\r\n        mask=mask_d,\r\n    )\r\n\r\ndef _decode_softmax_reducev_fwd(\r\n    logits,\r\n    q,\r\n    o,\r\n    v_buffer,\r\n    b_seq_len,\r\n    num_kv_splits,\r\n):\r\n    batch, head_num = q.shape[0], q.shape[1]\r\n    Lv = v_buffer.shape[-1]\r\n    BLOCK_DV = triton.next_power_of_2(Lv)\r\n\r\n    NUM_KV_SPLITS = num_kv_splits\r\n\r\n    extra_kargs = {}\r\n    # TODO: support hip\r\n    \"\"\"\r\n    if is_hip_:\r\n        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html\r\n        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py\r\n        extra_kargs = {\r\n            \"waves_per_eu\": 4,\r\n            \"matrix_instr_nonkdim\": 16,\r\n            \"kpack\": 2\r\n        }\r\n    \"\"\"\r\n    \r\n    grid = (batch, head_num)\r\n    _fwd_kernel_stage2[grid](\r\n        logits,\r\n        o,\r\n        b_seq_len,\r\n        logits.stride(0),\r\n        logits.stride(1),\r\n        logits.stride(2),\r\n        o.stride(0),\r\n        o.stride(1),\r\n        NUM_KV_SPLITS=NUM_KV_SPLITS,\r\n        BLOCK_DV=BLOCK_DV,\r\n     
   Lv=Lv,\r\n        num_warps=4,\r\n        num_stages=2,\r\n        **extra_kargs,\r\n    )\r\n\r\ndef decode_attention_fwd_grouped(\r\n    q,\r\n    k_buffer,\r\n    v_buffer,\r\n    o,\r\n    req_to_token,\r\n    b_seq_len,\r\n    attn_logits,\r\n    num_kv_splits,\r\n    sm_scale,\r\n    page_size,\r\n    logit_cap=0.0,\r\n):\r\n    _decode_grouped_att_m_fwd(\r\n        q,\r\n        k_buffer,\r\n        v_buffer,\r\n        attn_logits,\r\n        req_to_token,\r\n        b_seq_len,\r\n        num_kv_splits,\r\n        sm_scale,\r\n        page_size,\r\n        logit_cap,\r\n    )\r\n\r\n    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len,\r\n                                num_kv_splits)\r\n"
  },
  {
    "path": "kt-sft/ktransformers/operators/triton_attention_prefill.py",
    "content": "\n# Adapted from\n# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py\n# which was originally adapted from\n# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1\n\n\"\"\"\nMemory-efficient attention for prefill.\nIt supporst page size = 1.\n\"\"\"\n\n# Adapted from\n# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1\nimport torch\nimport triton\nimport triton.language as tl\n\nis_cuda_available = torch.cuda.is_available()\nif is_cuda_available:\n    CUDA_CAPABILITY = torch.cuda.get_device_capability()\n\n\n@triton.jit\ndef _fwd_kernel(\n    Q,\n    K,\n    V,\n    sm_scale,\n    B_Start_Loc,\n    B_Seqlen,\n    Out,\n    stride_qbs,\n    stride_qh,\n    stride_kbs,\n    stride_kh,\n    stride_vbs,\n    stride_vh,\n    stride_obs,\n    stride_oh,\n    kv_group_num: tl.constexpr,\n    BLOCK_M: tl.constexpr,\n    BLOCK_DMODEL: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n    IS_CAUSAL: tl.constexpr,\n    Lk: tl.constexpr,\n):\n    cur_batch = tl.program_id(0)\n    cur_head = tl.program_id(1)\n    start_m = tl.program_id(2)\n\n    cur_kv_head = cur_head // kv_group_num\n\n    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n    block_start_loc = BLOCK_M * start_m\n\n    # initialize offsets\n    offs_n = tl.arange(0, BLOCK_N)\n    offs_d = tl.arange(0, BLOCK_DMODEL)\n    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n    off_q = (\n        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs\n        + cur_head * stride_qh\n        + offs_d[None, :]\n    )\n    off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]\n    off_v = offs_n[:, None] 
* stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]\n\n    mask_d = offs_d < Lk\n\n    q = tl.load(\n        Q + off_q,\n        mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]),\n        other=0.0,\n    )\n\n    k_ptrs = K + off_k\n    v_ptrs = V + off_v\n\n    # initialize pointer to m and l\n    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n    block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)\n\n    end_n = (\n        cur_batch_seq_len\n        if not IS_CAUSAL\n        else tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len)\n    )\n    for start_n in range(0, block_mask * end_n, BLOCK_N):\n        start_n = tl.multiple_of(start_n, BLOCK_N)\n        # -- compute qk ----\n        k = tl.load(\n            k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,\n            mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]),\n            other=0.0,\n        )\n        # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)\n\n        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n        qk += tl.dot(q, k)\n        qk *= sm_scale\n\n        if IS_CAUSAL:\n            qk += tl.where(\n                (start_n + offs_n[None, :] < cur_batch_seq_len)\n                & (offs_m[:, None] >= (start_n + offs_n[None, :])),\n                0,\n                float(\"-inf\"),\n            )\n        else:\n            qk += tl.where(\n                (start_n + offs_n[None, :]) < cur_batch_seq_len, 0, float(\"-inf\")\n            )\n\n        # -- compute m_ij, p, l_ij\n        m_ij = tl.max(qk, 1)\n        p = tl.exp(qk - m_ij[:, None])\n        l_ij = tl.sum(p, 1)\n        # -- update m_i and l_i\n        m_i_new = tl.maximum(m_i, m_ij)\n        alpha = tl.exp(m_i - m_i_new)\n        beta = tl.exp(m_ij - m_i_new)\n        l_i_new = 
alpha * l_i + beta * l_ij\n        # -- update output accumulator --\n        # scale p\n        p_scale = beta / l_i_new\n        p = p * p_scale[:, None]\n        # scale acc\n        acc_scale = l_i / l_i_new * alpha\n        acc = acc * acc_scale[:, None]\n        # update acc\n        v = tl.load(\n            v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,\n            mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]),\n            other=0.0,\n        )\n\n        p = p.to(v.dtype)\n        acc += tl.dot(p, v)\n        # update m_i and l_i\n        l_i = l_i_new\n        m_i = m_i_new\n    # initialize pointers to output\n    off_o = (\n        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs\n        + cur_head * stride_oh\n        + offs_d[None, :]\n    )\n    out_ptrs = Out + off_o\n    tl.store(\n        out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :])\n    )\n\n\ndef context_attention_fwd(\n    q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True\n):\n    \"\"\"\n    q, k, v: [b * s, head, head_dim]\n    b_start_loc: [b]\n    b_seq_len: [b]\n    out: [b * s, head, head_dim]\n    \"\"\"\n    if is_cuda_available and CUDA_CAPABILITY[0] > 8:\n        BLOCK = 128\n    else:\n        BLOCK = 64\n\n    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n\n    sm_scale = 1.0 / (Lq**0.5)\n    batch, head = b_seq_len.shape[0], q.shape[1]\n    kv_group_num = q.shape[1] // k.shape[1]\n\n    grid = (batch, head, triton.cdiv(max_input_len, BLOCK))\n    num_warps = 4 if Lk <= 64 else 8\n\n    _fwd_kernel[grid](\n        q,\n        k,\n        v,\n        sm_scale,\n        b_start_loc,\n        b_seq_len,\n        o,\n        q.stride(0),\n        q.stride(1),\n        k.stride(0),\n        k.stride(1),\n        v.stride(0),\n        v.stride(1),\n        o.stride(0),\n        o.stride(1),\n        kv_group_num=kv_group_num,\n        BLOCK_M=BLOCK,\n        
BLOCK_DMODEL=triton.next_power_of_2(Lk),\n        BLOCK_N=BLOCK,\n        IS_CAUSAL=is_causal,\n        num_warps=num_warps,\n        num_stages=1,\n        Lk=Lk,\n    )"
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize.py",
    "content": "'''\nDescription  :  \nAuthor       : Boxin Zhang, Azure-Tang\nVersion      : 0.1.0\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nfrom typing import Mapping, List\nimport torch\nimport yaml\nimport re\nfrom torch import nn\nfrom transformers import AutoConfig\nfrom transformers.configuration_utils import PretrainedConfig\n# from operators import BaseInjectedModule\nfrom ktransformers.util.custom_loader import GGUFLoader, ModelLoaderFactory\nfrom ktransformers.util.utils import set_module, load_weights\nimport itertools\nimport copy\n\ndef inject(module, local_optimization_dict, model_config:AutoConfig ,gguf_loader:GGUFLoader, prefix=''):\n    for name, child in module._modules.items():\n        if child is not None:\n            child_prefix = prefix + name\n            if child_prefix in local_optimization_dict:\n                inject_module_meta=local_optimization_dict[child_prefix]\n                if inject_module_meta[\"class\"] != \"default\":\n                    import_path = inject_module_meta[\"class\"].split(\".\")\n                    import_module_name = \".\".join(import_path[:-1])\n                    gguf_loader.tensor_device_map[inject_module_meta[\"key\"]] = inject_module_meta[\"kwargs\"] if \"kwargs\" in inject_module_meta else dict()\n                    import_class_name = import_path[-1]\n                    module_cls=getattr(__import__(import_module_name, fromlist=[\"\"]), import_class_name)\n                    print(f\"Injecting {child_prefix} as\", import_module_name, \".\", import_class_name)\n                    inject_module=module_cls(key = inject_module_meta[\"key\"], gguf_loader = gguf_loader, config = model_config, orig_module=child, **inject_module_meta[\"kwargs\"])\n                    set_module(module, name, inject_module)\n                elif inject_module_meta[\"class\"] == \"default\":\n                    print(f\"Injecting {child_prefix} as default\")\n                    
gguf_loader.tensor_device_map[inject_module_meta[\"key\"]] = inject_module_meta[\"kwargs\"] if \"kwargs\" in inject_module_meta else dict()\n                else:\n                    raise Exception(\"inject_module_meta[\\\"class\\\"] must be \\\"default\\\" or a class path\")\n                child_prefix += \".\"\n                child_optimization_dict = {k: v for k, v in local_optimization_dict.items() if k.startswith(child_prefix)}\n                inject(child, child_optimization_dict, model_config, gguf_loader, child_prefix)\n\ndef del_meta(module:nn.Module):\n    #print(\"default loading weights\", prefix)\n    persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}\n    local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items())\n    local_state = {k: v for k, v in local_name_params if v is not None}\n    for name, param in local_state.items():\n        if param.device == \"meta\" or param.device == torch.device(\"meta\"):\n            module.__delattr__(name)\n    for name, child in module._modules.items():\n        del_meta(child)\n\ndef gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, prefix: str=\"\", default_device: str = \"cuda:0\"):\n    module_name = prefix[:-1]\n    # translated_name = translate_name_to_gguf(prefix)[:-1]\n    #print(\"gen_optimize_config\", prefix, module_name, translated_name)\n    recursive = True\n    for rule in rule_list:\n        match_meta = rule[\"match\"]\n        if \"class\" not in match_meta and \"name\" not in match_meta:\n            raise Exception(\"match must have at least one of \\\"class\\\" and \\\"name\\\"\")\n        if \"class\" in match_meta:\n            import_path = match_meta[\"class\"].split(\".\")\n            import_module_name = \".\".join(import_path[:-1])\n            import_class_name = import_path[-1]\n            module_cls=getattr(__import__(import_module_name, fromlist=[\"\"]), 
import_class_name)\n            if not isinstance(module, module_cls):\n                continue\n        if \"name\" in match_meta:\n            if re.search(match_meta[\"name\"], module_name) is None:\n                continue\n        if \"replace\" not in rule:\n            raise Exception(\"replace must be in rule\")\n        if \"replace\" in rule:\n            replace_meta = rule[\"replace\"]\n            if module_name not in out_data:\n                out_data[module_name]={\"key\": module_name,\n                                    \"class\": replace_meta[\"class\"] if \"class\" in replace_meta else \"default\",\n                                    # \"device\": replace_meta[\"device\"] if \"device\" in replace_meta else default_device,\n                                    \"kwargs\": copy.deepcopy(replace_meta[\"kwargs\"]) if \"kwargs\" in replace_meta else dict()}\n            else:\n                if out_data[module_name][\"class\"] == \"default\":\n                    out_data[module_name][\"class\"] = replace_meta[\"class\"] if \"class\" in replace_meta else \"default\"\n                out_data[module_name][\"kwargs\"].update(copy.deepcopy(replace_meta[\"kwargs\"]) if \"kwargs\" in replace_meta else dict())\n        if \"recursive\" in rule:\n            recursive = bool(rule[\"recursive\"])\n        break\n            \n    if module_name not in out_data:\n        out_data[module_name]= {\n            \"class\": \"default\",\n            \"key\": module_name,\n            \"kwargs\": {\"generate_device\": default_device,\n                       \"prefill_device\": default_device}\n        }\n\n    #print(out_data[module_name])\n    #input()\n\n    if recursive:\n        for name, child in module._modules.items():\n            if child is not None:\n                child_prefix = prefix + name + \".\"\n                gen_optimize_config(child, out_data, rule_list, child_prefix, default_device = default_device)\n    \n\ndef 
translate_model_config(model_config: PretrainedConfig):\n    # special handling for some model types\n    if model_config.model_type == \"mixtral\":\n        model_config.moe_intermediate_size = model_config.intermediate_size\n    \n    return model_config\n\n\ndef optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, model_config: PretrainedConfig, default_device: str = \"cuda:0\"):\n    with open(rule_file, 'r', encoding='utf-8') as f:\n        rule_list = yaml.load(f.read(), Loader=yaml.FullLoader)\n    \n    optimize_config = dict()\n    gen_optimize_config(module, optimize_config, rule_list, default_device = default_device)\n    \n    model_config = translate_model_config(model_config)\n\n    weights_loader = ModelLoaderFactory.create_loader(gguf_path)\n    with torch.device(\"meta\"):\n        inject(module, optimize_config, model_config, weights_loader)\n    # preload lm_head first because of its large intermediate result\n    load_weights(module.lm_head, weights_loader, \"lm_head.\", device=default_device)\n    load_weights(module, weights_loader, device=default_device)\n    module.gguf_loader = weights_loader\n    del_meta(module)\n    if torch.cuda.is_available():\n        torch.cuda.empty_cache()\n    elif torch.xpu.is_available():\n        torch.xpu.empty_cache()\n"
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu-4.yaml",
    "content": "- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n        generate_device: \"cpu\"\n        prefill_device: \"cpu\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|[1][0-4])\\\\.\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([2][0-9]|[1][5-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3][0-9]|[4][0-4])\\\\.\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n- match:\n    name: \"^model\\\\.layers\\\\.([5][0-9]|[4][5-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|[1][0-4])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\.([2][0-9]|[1][5-9])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"  # 
regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3][0-9]|[4][0-4])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\.([5][0-9]|[4][5-9])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|[1][0-4])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([2][0-9]|[1][5-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: 
\"cuda:1\"\n      prefill_device: \"cuda:1\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3][0-9]|[4][0-4])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n- match:\n    name: \"^model\\\\.layers\\\\.([5][0-9]|[4][5-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|[1][0-4])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert parallelism\n    kwargs:\n      prefill_device: \"cuda:0\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:0\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\.([2][0-9]|[1][5-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert parallelism\n    kwargs:\n      prefill_device: \"cuda:1\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:1\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\.([3][0-9]|[4][0-4])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert parallelism\n    kwargs:\n      prefill_device: \"cuda:2\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      
generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:2\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\.([5][0-9]|[4][5-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert parallelism\n    kwargs:\n      prefill_device: \"cuda:3\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:3\"\n  recursive: False # don't recursively inject submodules of this module\n\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|[1][0-4])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([2][0-9]|[1][5-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3][0-9]|[4][0-4])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n- match:\n    name: \"^model\\\\.layers\\\\.([5][0-9]|[4][5-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n      transfer_map: \n        15: \"cuda:1\"\n        30: \"cuda:2\"\n        45: \"cuda:3\"\n\n- 
match:\n    name: \"^model\\\\.layers\\\\.([0-9]|[1][0-4])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"(^model\\\\.layers\\\\.([2][0-9]|[1][5-9])\\\\.)\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n- match:\n    name: \"(^model\\\\.layers\\\\.([3][0-9]|[4][0-4])\\\\.)\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n      \n- match:\n    name: \"(^model\\\\.layers\\\\.([5][0-9]|[4][5-9])\\\\.)|(^model.norm)\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\""
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml",
    "content": "- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n        generate_device: \"cpu\"\n        prefill_device: \"cpu\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([345][0-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.([345][0-9])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n  \n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with 
custom forward function\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([345][0-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda:0\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:0\"\n  recursive: False # don't recursively inject submodules of this module\n\n- match:\n    name: \"^model\\\\.layers\\\\.([345][0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda:1\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:1\"\n  recursive: False # don't recursively inject submodules of this module\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([345][0-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: 
\"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill\n      transfer_map: \n        30: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"(^model\\\\.layers\\\\.([345][0-9])\\\\.)|(model.norm)\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n"
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-sft-amx.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearTorch\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearTorch\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cpu\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KSFTExpertsCPU\"\n      out_device: \"cuda\"\n      backend: \"AMXInt8\" # or \"AMXBF16\" or \"llamafile\" (default)\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # 
optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\""
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: 
\"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\""
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-multi-gpu.yaml",
    "content": "- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n        generate_device: \"cpu\"\n        prefill_device: \"cpu\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9])\\\\.(?!self_attn).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9])\\\\.(?!self_attn).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n  \n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function\n    kwargs:\n      
generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda:0\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:0\"\n  recursive: False # don't recursively inject submodules of this module\n\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda:1\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:1\"\n  recursive: False # don't recursively inject submodules of this module\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n      transfer_map: \n        10: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"(^model\\\\.layers\\\\.([12][0-9])\\\\.)|(model.norm)\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\""
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx-multi-gpu.yaml",
    "content": "- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n        generate_device: \"cpu\"\n        prefill_device: \"cpu\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9])\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n      generate_op: \"KLinearTorch\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9])\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearTorch\"\n      prefill_op: \"KLinearTorch\"\n  \n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function\n    
kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda:0\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KSFTExpertsCPU\"\n      out_device: \"cuda:0\"\n      backend: \"AMXInt8\" # or \"AMXBF16\" or \"llamafile\" (default)\n  recursive: False # don't recursively inject submodules of this module\n\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda:1\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KSFTExpertsCPU\"\n      out_device: \"cuda:1\"\n      backend: \"AMXInt8\" # or \"AMXBF16\" or \"llamafile\" (default)\n  recursive: False # don't recursively inject submodules of this module\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:1\"\n      
prefill_device: \"cuda:1\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n      transfer_map: \n        10: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearTorch\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"(^model\\\\.layers\\\\.([12][0-9])\\\\.)|(model.norm)\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n"
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-sft-amx.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearTorch\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearTorch\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cpu\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KSFTExpertsCPU\"\n      out_device: \"cuda\"\n      backend: \"AMXInt8\" # or \"AMXBF16\" or \"llamafile\" (default)\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # 
optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\""
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-sft.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearTorch\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearTorch\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cpu\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KSFTExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: 
\"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\""
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-use-adapter.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearTorch\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearTorch\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda\"\n  
    prefill_device: \"cuda\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\""
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: 
\"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\""
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-amx.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^lm_head$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      
generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n      backend: \"AMXInt8\" # or \"AMXBF16\" or \"llamafile\" (default)\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      absorb_for_prefill: False # set to True to enable long context (prefill may be slower)\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\""
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts-serve-amx.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearFP8\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoEV2     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExpertsV2     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n      backend: \"llamafile\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.balance_serve_attention.flashinfer_attn # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: 
\"cuda\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm\n  replace:\n    class: ktransformers.operators.layernorm.RMSNorm\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP\n  replace:\n    class:  ktransformers.operators.mlp.kDeepseekV3MLP\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^lm_head$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"VLinearMarlin\"\n      prefill_op: \"KLinearTorch\""
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts-serve.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearFP8\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoEV2     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExpertsV2     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.balance_serve_attention.flashinfer_attn # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: 
\"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm\n  replace:\n    class: ktransformers.operators.layernorm.RMSNorm\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP\n  replace:\n    class:  ktransformers.operators.mlp.kDeepseekV3MLP\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^lm_head$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"VLinearMarlin\"\n      prefill_op: \"KLinearTorch\""
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-fp8-linear-ggml-experts.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearFP8\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model$\"\n  
replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\""
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml",
    "content": "- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n# === Rotary Embedding Replacement ===\n\n# GPU 0: layers 0–14\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|1[0-4])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n# GPU 1: layers 15–29\n- match:\n    name: \"^model\\\\.layers\\\\.(1[5-9]|2[0-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n# GPU 2: layers 30–44\n- match:\n    name: \"^model\\\\.layers\\\\.(3[0-9]|4[0-4])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n\n# GPU 3: layers 45–60\n- match:\n    name: \"^model\\\\.layers\\\\.(4[5-9]|5[0-9]|60)\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n\n# === Linear Layers Replacement (excluding self_attn.kv_b_proj) ===\n\n# GPU 0: layers 0–14\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|1[0-4])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# GPU 1: 
layers 15–29\n- match:\n    name: \"^model\\\\.layers\\\\.(1[5-9]|2[0-9])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# GPU 2: layers 30–44\n- match:\n    name: \"^model\\\\.layers\\\\.(3[0-9]|4[0-4])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# GPU 3: layers 45–60\n- match:\n    name: \"^model\\\\.layers\\\\.(4[5-9]|5[0-9]|60)\\\\.(?!self_attn\\\\.kv_b_proj).*$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# === MLP (MoE) Replacement ===\n\n# GPU 0: layers 0–14\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|1[0-4])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n# GPU 1: layers 15–29\n- match:\n    name: \"^model\\\\.layers\\\\.(1[5-9]|2[0-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n# GPU 2: layers 30–44\n- match:\n    name: \"^model\\\\.layers\\\\.(3[0-9]|4[0-4])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: 
ktransformers.operators.experts.KDeepseekV3MoE\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n\n# GPU 3: layers 45–60\n- match:\n    name: \"^model\\\\.layers\\\\.(4[5-9]|5[0-9]|60)\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n\n# === MLP Gate Replacement ===\n\n# GPU 0: layers 0–14\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|1[0-4])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n# GPU 1: layers 15–29\n- match:\n    name: \"^model\\\\.layers\\\\.(1[5-9]|2[0-9])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n# GPU 2: layers 30–44\n- match:\n    name: \"^model\\\\.layers\\\\.(3[0-9]|4[0-4])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n\n# GPU 3: layers 45–60\n- match:\n    name: \"^model\\\\.layers\\\\.(4[5-9]|5[0-9]|60)\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n\n# === MLP Experts Replacement ===\n# replace with marlin expert. 
Uncomment and modify layer-num as needed.\n# Each layer of marlin experts takes about 6GB of GPU memory.\n# !!!Remember to disable cuda graph if you are using marlin expert.!!!\n# !!!KExpertsTorch is untested, we don't have enough VRAM.!!!\n\n# GPU 0: layers 3–4\n# - match:\n#     name: \"^model\\\\.layers\\\\.([3-4])\\\\.mlp\\\\.experts$\"\n#   replace:\n#     class: ktransformers.operators.experts.KTransformersExperts\n#     kwargs:\n#       generate_device: \"cuda:0\"\n#       generate_op:  \"KExpertsMarlin\"\n#   recursive: False\n\n# # GPU 1: layers 15–17\n# - match:\n#     name: \"^model\\\\.layers\\\\.(1[5-7])\\\\.mlp\\\\.experts$\"\n#   replace:\n#     class: ktransformers.operators.experts.KTransformersExperts\n#     kwargs:\n#       generate_device: \"cuda:1\"\n#       generate_op:  \"KExpertsMarlin\"\n#   recursive: False\n\n# # GPU 2: layers 30–32\n# - match:\n#     name: \"^model\\\\.layers\\\\.(3[0-2])\\\\.mlp\\\\.experts$\"\n#   replace:\n#     class: ktransformers.operators.experts.KTransformersExperts\n#     kwargs:\n#       generate_device: \"cuda:2\"\n#       generate_op:  \"KExpertsMarlin\"\n#   recursive: False\n\n# # GPU 3: layers 45–46\n# - match:\n#     name: \"^model\\\\.layers\\\\.(4[5-6])\\\\.mlp\\\\.experts$\"\n#   replace:\n#     class: ktransformers.operators.experts.KTransformersExperts\n#     kwargs:\n#       generate_device: \"cuda:3\"\n#       generate_op:  \"KExpertsMarlin\"\n#   recursive: False\n\n\n# === MLP Experts Replacement ===\n\n# GPU 0: layers 0–14\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|1[0-4])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      prefill_device: \"cuda:0\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda:0\"\n  recursive: False\n\n# GPU 1: layers 15–29\n- match:\n    name: \"^model\\\\.layers\\\\.(1[5-9]|2[0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    
class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      prefill_device: \"cuda:1\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda:1\"\n  recursive: False\n\n# GPU 2: layers 30–44\n- match:\n    name: \"^model\\\\.layers\\\\.(3[0-9]|4[0-4])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      prefill_device: \"cuda:2\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda:2\"\n  recursive: False\n\n# GPU 3: layers 45–60\n- match:\n    name: \"^model\\\\.layers\\\\.(4[5-9]|5[0-9]|60)\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      prefill_device: \"cuda:3\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda:3\"\n  recursive: False\n\n# === Self-Attention Replacement ===\n\n# GPU 0: layers 0–14\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|1[0-4])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n      absorb_for_prefill: False\n\n# GPU 1: layers 15–29\n- match:\n    name: \"^model\\\\.layers\\\\.(1[5-9]|2[0-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      absorb_for_prefill: False\n\n# GPU 2: layers 30–44\n- match:\n    name: \"^model\\\\.layers\\\\.(3[0-9]|4[0-4])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n      absorb_for_prefill: False\n\n# GPU 3: layers 45–60\n- 
match:\n    name: \"^model\\\\.layers\\\\.(4[5-9]|5[0-9]|60)\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n      absorb_for_prefill: False\n\n# === Overall Model Replacement with Transfer Map ===\n\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 means close layer‐wise prefill\n      transfer_map:\n        15: \"cuda:1\" # Layers 15+ on GPU 1\n        30: \"cuda:2\" # Layers 30+ on GPU 2\n        45: \"cuda:3\" # Layers 45+ on GPU 3\n\n# === Default Catch-All for Other Modules ===\n\n# GPU 0: layers 0–14\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|1[0-4])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n# GPU 1: layers 15–29\n- match:\n    name: \"^model\\\\.layers\\\\.(1[5-9]|2[0-9])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n# GPU 2: layers 30–44\n- match:\n    name: \"^model\\\\.layers\\\\.(3[0-9]|4[0-4])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# For final modules (model.norm), ensure they are on GPU 3 (as in your original config)\n- match:\n    name: \"(^model\\\\.layers\\\\.(4[5-9]|5[0-9]|60)\\\\.)|(^model\\\\.norm)\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n"
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-8.yaml",
    "content": "- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n# === Rotary Embedding Replacement ===\n\n# GPU 0: layers 0–7\n- match:\n    name: \"^model\\\\.layers\\\\.([0-7])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n# GPU 1: layers 8–15\n- match:\n    name: \"^model\\\\.layers\\\\.(8|9|1[0-5])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n# GPU 2: layers 16–23\n- match:\n    name: \"^model\\\\.layers\\\\.(1[6-9]|2[0-3])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n\n# GPU 3: layers 24–31\n- match:\n    name: \"^model\\\\.layers\\\\.(2[4-9]|3[0-1])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n\n# GPU 4: layers 32–39\n- match:\n    name: \"^model\\\\.layers\\\\.([3][2-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:4\"\n      prefill_device: \"cuda:4\"\n\n# GPU 5: layers 40–47\n- match:\n    name: \"^model\\\\.layers\\\\.(4[0-7])\\\\.\"\n    class: 
ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:5\"\n      prefill_device: \"cuda:5\"\n\n# GPU 6: layers 48–55\n- match:\n    name: \"^model\\\\.layers\\\\.(4[8-9]|5[0-5])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:6\"\n      prefill_device: \"cuda:6\"\n\n# GPU 7: layers 56–60\n- match:\n    name: \"^model\\\\.layers\\\\.(5[6-9]|60)\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:7\"\n      prefill_device: \"cuda:7\"\n\n\n# === Linear Layers Replacement (excluding self_attn.kv_b_proj) ===\n\n# GPU 0: layers 0–7\n- match:\n    name: \"^model\\\\.layers\\\\.([0-7])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# GPU 1: layers 8–15\n- match:\n    name: \"^model\\\\.layers\\\\.(8|9|1[0-5])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# GPU 2: layers 16–23\n- match:\n    name: \"^model\\\\.layers\\\\.(1[6-9]|2[0-3])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:2\"\n      
prefill_device: \"cuda:2\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# GPU 3: layers 24–31\n- match:\n    name: \"^model\\\\.layers\\\\.(2[4-9]|3[0-1])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# GPU 4: layers 32–39\n- match:\n    name: \"^model\\\\.layers\\\\.(3[2-9])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:4\"\n      prefill_device: \"cuda:4\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# GPU 5: layers 40–47\n- match:\n    name: \"^model\\\\.layers\\\\.(4[0-7])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:5\"\n      prefill_device: \"cuda:5\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# GPU 6: layers 48–55\n- match:\n    name: \"^model\\\\.layers\\\\.(4[8-9]|5[0-5])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:6\"\n      prefill_device: \"cuda:6\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# GPU 7: layers 56–60\n- match:\n    name: \"^model\\\\.layers\\\\.(5[6-9]|60)\\\\.(?!self_attn\\\\.kv_b_proj).*$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:7\"\n      prefill_device: \"cuda:7\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: 
\"KLinearTorch\"\n\n\n\n# === MLP (MoE) Replacement ===\n\n# GPU 0: layers 0–7\n- match:\n    name: \"^model\\\\.layers\\\\.([0-7])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n# GPU 1: layers 8–15\n- match:\n    name: \"^model\\\\.layers\\\\.(8|9|1[0-5])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n# GPU 2: layers 16–23\n- match:\n    name: \"^model\\\\.layers\\\\.(1[6-9]|2[0-3])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n\n# GPU 3: layers 24–31\n- match:\n    name: \"^model\\\\.layers\\\\.(2[4-9]|3[0-1])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n\n# GPU 4: layers 32–39\n- match:\n    name: \"^model\\\\.layers\\\\.(3[2-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE\n    kwargs:\n      generate_device: \"cuda:4\"\n      prefill_device: \"cuda:4\"\n\n# GPU 5: layers 40–47\n- match:\n    name: \"^model\\\\.layers\\\\.(4[0-7])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE\n    kwargs:\n      generate_device: \"cuda:5\"\n      prefill_device: \"cuda:5\"\n\n# GPU 6: layers 48–55\n- match:\n    name: 
\"^model\\\\.layers\\\\.(4[8-9]|5[0-5])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE\n    kwargs:\n      generate_device: \"cuda:6\"\n      prefill_device: \"cuda:6\"\n\n# GPU 7: layers 56–60\n- match:\n    name: \"^model\\\\.layers\\\\.(5[6-9]|60)\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE\n    kwargs:\n      generate_device: \"cuda:7\"\n      prefill_device: \"cuda:7\"\n\n# === MLP Gate Replacement ===\n\n# GPU 0: layers 0–7\n- match:\n    name: \"^model\\\\.layers\\\\.([0-7])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n# GPU 1: layers 8–15\n- match:\n    name: \"^model\\\\.layers\\\\.(8|9|1[0-5])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n# GPU 2: layers 16–23\n- match:\n    name: \"^model\\\\.layers\\\\.(1[6-9]|2[0-3])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n\n# GPU 3: layers 24–31\n- match:\n    name: \"^model\\\\.layers\\\\.(2[4-9]|3[0-1])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n\n# GPU 4: layers 32–39\n- match:\n    name: \"^model\\\\.layers\\\\.(3[2-9])\\\\.mlp\\\\.gate$\"\n    class: 
ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:4\"\n      prefill_device: \"cuda:4\"\n\n# GPU 5: layers 40–47\n- match:\n    name: \"^model\\\\.layers\\\\.(4[0-7])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:5\"\n      prefill_device: \"cuda:5\"\n\n# GPU 6: layers 48–55\n- match:\n    name: \"^model\\\\.layers\\\\.(4[8-9]|5[0-5])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:6\"\n      prefill_device: \"cuda:6\"\n\n# GPU 7: layers 56–60\n- match:\n    name: \"^model\\\\.layers\\\\.(5[6-9]|60)\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:7\"\n      prefill_device: \"cuda:7\"\n\n\n# === MLP Experts Replacement ===\n# replace with marlin expert. 
Uncomment and adjust the layer ranges as needed.\n# Each layer of marlin experts takes about 6GB of GPU memory.\n# !!!Remember to disable CUDA graph if you are using marlin experts.!!!\n# !!!Loading marlin experts takes significant time.!!!\n\n# GPU 0: layers 0–7\n# - match:\n#     name: \"^model\\\\.layers\\\\.([0-7])\\\\.mlp\\\\.experts$\" # inject experts in layers 0~7 as marlin experts\n#   replace:\n#     class: ktransformers.operators.experts.KTransformersExperts  \n#     kwargs:\n#       generate_device: \"cuda:0\"\n#       generate_op:  \"KExpertsMarlin\"\n#   recursive: False\n\n# # GPU 1: layers 8–15\n# - match:\n#     name: \"^model\\\\.layers\\\\.([8-9]|1[0-5])\\\\.mlp\\\\.experts$\" # inject experts in layers 8~15 as marlin experts\n#   replace:\n#     class: ktransformers.operators.experts.KTransformersExperts\n#     kwargs:\n#       generate_device: \"cuda:1\"\n#       generate_op:  \"KExpertsMarlin\"\n#   recursive: False \n\n# # GPU 2: layers 16–23\n# - match:\n#     name: \"^model\\\\.layers\\\\.(1[6-9]|2[0-3])\\\\.mlp\\\\.experts$\" # inject experts in layers 16~23 as marlin experts\n#   replace:\n#     class: ktransformers.operators.experts.KTransformersExperts  \n#     kwargs:\n#       generate_device: \"cuda:2\" \n#       generate_op:  \"KExpertsMarlin\"\n#   recursive: False\n\n# # GPU 3: layers 24–31\n# - match:\n#     name: \"^model\\\\.layers\\\\.(2[4-9]|3[0-1])\\\\.mlp\\\\.experts$\" # inject experts in layers 24~31 as marlin experts\n#   replace:\n#     class: ktransformers.operators.experts.KTransformersExperts\n#     kwargs:\n#       generate_device: \"cuda:3\"\n#       generate_op:  \"KExpertsMarlin\"\n#   recursive: False \n\n# # GPU 4: layers 32–39\n# - match:\n#     name: \"^model\\\\.layers\\\\.(3[2-9])\\\\.mlp\\\\.experts$\" # inject experts in layers 32~39 as marlin experts\n#   replace:\n#     class: ktransformers.operators.experts.KTransformersExperts  \n#     kwargs:\n#       generate_device: \"cuda:4\" \n#       generate_op:  \"KExpertsMarlin\"\n#   
recursive: False\n\n# # GPU 5: layers 40–47\n# - match:\n#     name: \"^model\\\\.layers\\\\.(4[0-7])\\\\.mlp\\\\.experts$\" # inject experts in layers 40~47 as marlin experts\n#   replace:\n#     class: ktransformers.operators.experts.KTransformersExperts\n#     kwargs:\n#       generate_device: \"cuda:5\"\n#       generate_op:  \"KExpertsMarlin\"\n#   recursive: False \n\n# # GPU 6: layers 48–55\n# - match:\n#     name: \"^model\\\\.layers\\\\.(4[8-9]|5[0-5])\\\\.mlp\\\\.experts$\" # inject experts in layers 48~55 as marlin experts\n#   replace:\n#     class: ktransformers.operators.experts.KTransformersExperts  \n#     kwargs:\n#       generate_device: \"cuda:6\"\n#       generate_op:  \"KExpertsMarlin\"\n#   recursive: False\n\n# # GPU 7: layers 56–60\n# - match:\n#     name: \"^model\\\\.layers\\\\.(5[6-9]|60)\\\\.mlp\\\\.experts$\" # inject experts in layers 56~60 as marlin experts\n#   replace:\n#     class: ktransformers.operators.experts.KTransformersExperts\n#     kwargs:\n#       generate_device: \"cuda:7\"\n#       generate_op:  \"KExpertsMarlin\"\n#   recursive: False \n\n\n# === MLP Experts Replacement ===\n\n# GPU 0: layers 0–7\n- match:\n    name: \"^model\\\\.layers\\\\.([0-7])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      prefill_device: \"cuda:0\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda:0\"\n  recursive: False\n\n# GPU 1: layers 8–15\n- match:\n    name: \"^model\\\\.layers\\\\.(8|9|1[0-5])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      prefill_device: \"cuda:1\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda:1\"\n  recursive: False\n\n# GPU 2: layers 16–23\n- match:\n    name: 
\"^model\\\\.layers\\\\.(1[6-9]|2[0-3])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      prefill_device: \"cuda:2\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda:2\"\n  recursive: False\n\n# GPU 3: layers 24–31\n- match:\n    name: \"^model\\\\.layers\\\\.(2[4-9]|3[0-1])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      prefill_device: \"cuda:3\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda:3\"\n  recursive: False\n\n# GPU 4: layers 32–39\n- match:\n    name: \"^model\\\\.layers\\\\.(3[2-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      prefill_device: \"cuda:4\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda:4\"\n  recursive: False\n\n# GPU 5: layers 40–47\n- match:\n    name: \"^model\\\\.layers\\\\.(4[0-7])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      prefill_device: \"cuda:5\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda:5\"\n  recursive: False\n\n# GPU 6: layers 48–55\n- match:\n    name: \"^model\\\\.layers\\\\.(4[8-9]|5[0-5])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      prefill_device: \"cuda:6\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda:6\"\n  recursive: False\n\n# GPU 7: layers 56–60\n- match:\n    name: \"^model\\\\.layers\\\\.(5[6-9]|60)\\\\.mlp\\\\.experts$\"\n  replace:\n  
  class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      prefill_device: \"cuda:7\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda:7\"\n  recursive: False\n\n\n# === Self-Attention Replacement ===\n\n# GPU 0: layers 0–7\n- match:\n    name: \"^model\\\\.layers\\\\.([0-7])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n# GPU 1: layers 8–15\n- match:\n    name: \"^model\\\\.layers\\\\.(8|9|1[0-5])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n# GPU 2: layers 16–23\n- match:\n    name: \"^model\\\\.layers\\\\.(1[6-9]|2[0-3])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n\n# GPU 3: layers 24–31\n- match:\n    name: \"^model\\\\.layers\\\\.(2[4-9]|3[0-1])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n\n# GPU 4: layers 32–39\n- match:\n    name: \"^model\\\\.layers\\\\.(3[2-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention\n    kwargs:\n      generate_device: \"cuda:4\"\n      prefill_device: \"cuda:4\"\n\n# GPU 5: layers 40–47\n- match:\n    name: \"^model\\\\.layers\\\\.(4[0-7])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention\n    kwargs:\n      generate_device: \"cuda:5\"\n      prefill_device: \"cuda:5\"\n\n# GPU 6: layers 48–55\n- match:\n    name: \"^model\\\\.layers\\\\.(4[8-9]|5[0-5])\\\\.self_attn$\"\n  replace:\n    class: 
ktransformers.operators.attention.KDeepseekV2Attention\n    kwargs:\n      generate_device: \"cuda:6\"\n      prefill_device: \"cuda:6\"\n\n# GPU 7: layers 56–60\n- match:\n    name: \"^model\\\\.layers\\\\.(5[6-9]|60)\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention\n    kwargs:\n      generate_device: \"cuda:7\"\n      prefill_device: \"cuda:7\"\n\n# === Overall Model Replacement with Transfer Map ===\n\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 means close layer‐wise prefill\n      transfer_map:\n        8: \"cuda:1\"\n        16: \"cuda:2\"\n        24: \"cuda:3\"\n        32: \"cuda:4\"\n        40: \"cuda:5\"\n        48: \"cuda:6\"\n        56: \"cuda:7\"\n\n# === Default Catch-All for Other Modules ===\n\n# GPU 0: layers 0–7\n- match:\n    name: \"^model\\\\.layers\\\\.([0-7])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n# GPU 1: layers 8–15\n- match:\n    name: \"^model\\\\.layers\\\\.(8|9|1[0-5])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n# GPU 2: layers 16–23\n- match:\n    name: \"^model\\\\.layers\\\\.(1[6-9]|2[0-3])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n\n# GPU 3: layers 24–31\n- match:\n    name: \"^model\\\\.layers\\\\.(2[4-9]|3[0-1])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n\n# GPU 4: layers 32–39\n- match:\n    name: \"^model\\\\.layers\\\\.(3[2-9])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:4\"\n      prefill_device: \"cuda:4\"\n\n# GPU 5: layers 40–47\n- match:\n    name: 
\"^model\\\\.layers\\\\.(4[0-7])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:5\"\n      prefill_device: \"cuda:5\"\n\n# GPU 6: layers 48–55\n- match:\n    name: \"^model\\\\.layers\\\\.(4[8-9]|5[0-5])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:6\"\n      prefill_device: \"cuda:6\"\n\n# GPU 7: layers 56–60\n- match:\n    name: \"^model\\\\.layers\\\\.(5[6-9]|60)\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:7\"\n      prefill_device: \"cuda:7\"\n\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:7\"\n      prefill_device: \"cuda:7\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# Final modules (model.norm) are placed on GPU 7, the same device as the last layers\n- match:\n    name: \"(^model\\\\.layers\\\\.(4[5-9]|5[0-9]|60)\\\\.)|(^model\\\\.norm)\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:7\"\n      prefill_device: \"cuda:7\""
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-fp8-linear-ggml-experts.yaml",
    "content": "- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n        generate_device: \"cpu\"\n        prefill_device: \"cpu\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n      generate_op: \"KLinearFP8\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearFP8\"\n      prefill_op: \"KLinearTorch\"\n  \n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module 
with custom forward function\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda:0\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:0\"\n  recursive: False # don't recursively inject submodules of this module\n\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda:1\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:1\"\n  recursive: False # don't recursively inject submodules of this 
module\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n      absorb_for_prefill: False # change this to True to enable long context(prefill may slower).\n\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      absorb_for_prefill: False # change this to True to enable long context(prefill may slower).\n\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill\n      transfer_map: \n        30: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n\n- match:\n    name: \"(^model\\\\.layers\\\\.([3456][0-9])\\\\.)|(model.norm)\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n"
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml",
    "content": "- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n        generate_device: \"cpu\"\n        prefill_device: \"cpu\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n  \n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp 
module with custom forward function\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-4])\\\\.mlp\\\\.experts$\" # inject experts in layer 0~4 as marlin expert\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts  \n    kwargs:\n      generate_device: \"cuda:0\" # run in cuda:0\n      generate_op:  \"KExpertsMarlin\"\n  recursive: False\n\n- match:\n    name: \"^model\\\\.layers\\\\.([3][0])\\\\.mlp\\\\.experts$\" # inject experts in layer 30~31 as marlin expert\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      generate_device: \"cuda:1\"\n      generate_op:  \"KExpertsMarlin\"\n  recursive: False \n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda:0\"\n      
prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:0\"\n  recursive: False # don't recursively inject submodules of this module\n\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda:1\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:1\"\n  recursive: False # don't recursively inject submodules of this module\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill\n      transfer_map: \n        30: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: 
\"(^model\\\\.layers\\\\.([3456][0-9])\\\\.)|(model.norm)\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n"
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml",
    "content": "- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n        generate_device: \"cpu\"\n        prefill_device: \"cpu\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n  \n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp 
module with custom forward function\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda:0\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:0\"\n  recursive: False # don't recursively inject submodules of this module\n\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda:1\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:1\"\n  recursive: False # don't recursively inject submodules 
of this module\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill\n      transfer_map: \n        30: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"(^model\\\\.layers\\\\.([3456][0-9])\\\\.)|(model.norm)\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n"
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-serve.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^lm_head$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"VLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"VLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoEV2     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExpertsV2     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n     
 generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.balance_serve_attention.flashinfer_attn # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      absorb_for_prefill: False # change this to True to enable long context(prefill may slower).\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm\n  replace:\n    class: ktransformers.operators.layernorm.RMSNorm\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP\n  replace:\n    class:  ktransformers.operators.mlp.kDeepseekV3MLP\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\""
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu-4.yaml",
    "content": "- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n# === Rotary Embedding Replacement ===\n\n# GPU 0: layers 0–14\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|1[0-4])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n# GPU 1: layers 15–29\n- match:\n    name: \"^model\\\\.layers\\\\.(1[5-9]|2[0-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n# GPU 2: layers 30–44\n- match:\n    name: \"^model\\\\.layers\\\\.(3[0-9]|4[0-4])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n\n# GPU 3: layers 45–60\n- match:\n    name: \"^model\\\\.layers\\\\.(4[5-9]|5[0-9]|60)\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n\n# === Linear Layers Replacement (excluding self_attn.kv_b_proj) ===\n\n# GPU 0: layers 0–14\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|1[0-4])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n      generate_op: \"KLinearTorch\"\n      prefill_op: \"KLinearTorch\"\n\n# GPU 1: 
layers 15–29\n- match:\n    name: \"^model\\\\.layers\\\\.(1[5-9]|2[0-9])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearTorch\"\n      prefill_op: \"KLinearTorch\"\n\n# GPU 2: layers 30–44\n- match:\n    name: \"^model\\\\.layers\\\\.(3[0-9]|4[0-4])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n      generate_op: \"KLinearTorch\"\n      prefill_op: \"KLinearTorch\"\n\n# GPU 3: layers 45–60\n- match:\n    name: \"^model\\\\.layers\\\\.(4[5-9]|5[0-9]|60)\\\\.(?!self_attn\\\\.kv_b_proj).*$\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n      generate_op: \"KLinearTorch\"\n      prefill_op: \"KLinearTorch\"\n\n# === MLP (MoE) Replacement ===\n\n# GPU 0: layers 0–14\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|1[0-4])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n# GPU 1: layers 15–29\n- match:\n    name: \"^model\\\\.layers\\\\.(1[5-9]|2[0-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n# GPU 2: layers 30–44\n- match:\n    name: \"^model\\\\.layers\\\\.(3[0-9]|4[0-4])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: 
ktransformers.operators.experts.KDeepseekV3MoE\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n\n# GPU 3: layers 45–60\n- match:\n    name: \"^model\\\\.layers\\\\.(4[5-9]|5[0-9]|60)\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n\n# === MLP Gate Replacement ===\n\n# GPU 0: layers 0–14\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|1[0-4])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n# GPU 1: layers 15–29\n- match:\n    name: \"^model\\\\.layers\\\\.(1[5-9]|2[0-9])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n# GPU 2: layers 30–44\n- match:\n    name: \"^model\\\\.layers\\\\.(3[0-9]|4[0-4])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n\n# GPU 3: layers 45–60\n- match:\n    name: \"^model\\\\.layers\\\\.(4[5-9]|5[0-9]|60)\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n\n# === MLP Experts Replacement ===\n# replace with marlin expert. 
Open and modify layer-num as needed.\n# Each layer of malin experts takes about 6GB of GPU memory.\n# !!!Do remember 'close' cuda graph if you are using marlin expert.!!!\n# !!!KExpertsTorch is untested, we don't have enough VRAM.!!!\n\n# GPU 0: layers 3–4\n# - match:\n#     name: \"^model\\\\.layers\\\\.([3-4])\\\\.mlp\\\\.experts$\"\n#   replace:\n#     class: ktransformers.operators.experts.KTransformersExperts\n#     kwargs:\n#       generate_device: \"cuda:0\"\n#       generate_op:  \"KExpertsMarlin\"\n#   recursive: False\n\n# # GPU 1: layers 15–17\n# - match:\n#     name: \"^model\\\\.layers\\\\.(1[5-7])\\\\.mlp\\\\.experts$\"\n#   replace:\n#     class: ktransformers.operators.experts.KTransformersExperts\n#     kwargs:\n#       generate_device: \"cuda:1\"\n#       generate_op:  \"KExpertsMarlin\"\n#   recursive: False\n\n# # GPU 2: layers 30–32\n# - match:\n#     name: \"^model\\\\.layers\\\\.(3[0-2])\\\\.mlp\\\\.experts$\"\n#   replace:\n#     class: ktransformers.operators.experts.KTransformersExperts\n#     kwargs:\n#       generate_device: \"cuda:2\"\n#       generate_op:  \"KExpertsMarlin\"\n#   recursive: False\n\n# # GPU 3: layers 45–46\n# - match:\n#     name: \"^model\\\\.layers\\\\.(4[5-6])\\\\.mlp\\\\.experts$\"\n#   replace:\n#     class: ktransformers.operators.experts.KTransformersExperts\n#     kwargs:\n#       generate_device: \"cuda:3\"\n#       generate_op:  \"KExpertsMarlin\"\n#   recursive: False\n\n\n# === MLP Experts Replacement ===\n\n# GPU 0: layers 0–14\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|1[0-4])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      prefill_device: \"cuda:0\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KSFTExpertsCPU\"\n      out_device: \"cuda:0\"\n      backend: \"AMXInt8\" # or \"AMXBF16\" or \"llamafile\" (default)\n  recursive: False\n\n# GPU 1: layers 15–29\n- match:\n    name: 
\"^model\\\\.layers\\\\.(1[5-9]|2[0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      prefill_device: \"cuda:1\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KSFTExpertsCPU\"\n      out_device: \"cuda:1\"\n      backend: \"AMXInt8\" # or \"AMXBF16\" or \"llamafile\" (default)\n  recursive: False\n\n# GPU 2: layers 30–44\n- match:\n    name: \"^model\\\\.layers\\\\.(3[0-9]|4[0-4])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      prefill_device: \"cuda:2\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KSFTExpertsCPU\"\n      out_device: \"cuda:2\"\n      backend: \"AMXInt8\" # or \"AMXBF16\" or \"llamafile\" (default)\n  recursive: False\n\n# GPU 3: layers 45–60\n- match:\n    name: \"^model\\\\.layers\\\\.(4[5-9]|5[0-9]|60)\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      prefill_device: \"cuda:3\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KSFTExpertsCPU\"\n      out_device: \"cuda:3\"\n      backend: \"AMXInt8\" # or \"AMXBF16\" or \"llamafile\" (default)\n  recursive: False\n\n# === Self-Attention Replacement ===\n\n# GPU 0: layers 0–14\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|1[0-4])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n      absorb_for_prefill: False\n\n# GPU 1: layers 15–29\n- match:\n    name: \"^model\\\\.layers\\\\.(1[5-9]|2[0-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      absorb_for_prefill: False\n\n# GPU 2: layers 
30–44\n- match:\n    name: \"^model\\\\.layers\\\\.(3[0-9]|4[0-4])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n      absorb_for_prefill: False\n\n# GPU 3: layers 45–60\n- match:\n    name: \"^model\\\\.layers\\\\.(4[5-9]|5[0-9]|60)\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n      absorb_for_prefill: False\n\n# === Overall Model Replacement with Transfer Map ===\n\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 means close layer‐wise prefill\n      transfer_map:\n        15: \"cuda:1\" # Layers 15+ on GPU 1\n        30: \"cuda:2\" # Layers 30+ on GPU 2\n        45: \"cuda:3\" # Layers 45+ on GPU 3\n\n# === Default Catch-All for Other Modules ===\n\n# GPU 0: layers 0–14\n- match:\n    name: \"^model\\\\.layers\\\\.([0-9]|1[0-4])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n# GPU 1: layers 15–29\n- match:\n    name: \"^model\\\\.layers\\\\.(1[5-9]|2[0-9])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n# GPU 2: layers 30–44\n- match:\n    name: \"^model\\\\.layers\\\\.(3[0-9]|4[0-4])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:2\"\n      prefill_device: \"cuda:2\"\n\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n      generate_op: \"KLinearTorch\"\n      prefill_op: \"KLinearTorch\"\n\n# For final modules 
(model.norm), ensure they are on GPU 3 (as in your original config)\n- match:\n    name: \"(^model\\\\.layers\\\\.(4[5-9]|5[0-9]|60)\\\\.)|(^model\\\\.norm)\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:3\"\n      prefill_device: \"cuda:3\"\n"
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-sft-amx-multi-gpu.yaml",
    "content": "- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n        generate_device: \"cpu\"\n        prefill_device: \"cpu\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n      generate_op: \"KLinearTorch\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.(?!self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearTorch\"\n      prefill_op: \"KLinearTorch\"\n  \n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module 
with custom forward function\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.mlp\\\\.gate$\"\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda:0\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KSFTExpertsCPU\"\n      out_device: \"cuda:0\"\n      backend: \"AMXInt8\" # or \"AMXBF16\" or \"llamafile\" (default)\n  recursive: False # don't recursively inject submodules of this module\n\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda:1\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KSFTExpertsCPU\"\n      out_device: 
\"cuda:1\"\n      backend: \"AMXInt8\" # or \"AMXBF16\" or \"llamafile\" (default)\n  recursive: False # don't recursively inject submodules of this module\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([3456][0-9])\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill\n      transfer_map: \n        30: \"cuda:1\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(0|[1-9]|[12][0-9])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearTorch\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"(^model\\\\.layers\\\\.([3456][0-9])\\\\.)|(model.norm)\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n"
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-sft-amx.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^lm_head$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearTorch\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearTorch\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      
generate_op: \"KSFTExpertsCPU\"\n      out_device: \"cuda\"\n      backend: \"AMXInt8\" # or \"AMXBF16\" or \"llamafile\" (default)\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      absorb_for_prefill: False # change this to True to enable long context (prefill may be slower).\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\""
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^lm_head$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      
generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      absorb_for_prefill: False # change this to True to enable long context (prefill may be slower).\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\""
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/Internlm2_5-7b-Chat-1m.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_llama.LlamaRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.RotaryEmbeddingV2\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n        generate_device: \"cpu\"\n        prefill_device: \"cpu\"\n- match:\n    class: ktransformers.models.modeling_llama.LlamaModel\n  replace:\n    class: ktransformers.operators.models.KLlamaModel\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill\n\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KLlamaAttention\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n"
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/Mixtral.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_mixtral.MixtralRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.RotaryEmbedding\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model\\\\.layers\\\\..*$\"\n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.block_sparse_moe$\"\n    class: ktransformers.models.modeling_mixtral.MixtralSparseMoeBlock\n  replace: \n    class: ktransformers.operators.experts.KMistralSparseMoEBlock\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.block_sparse_moe\\\\.experts$\"\n  replace: \n    class: ktransformers.operators.experts.KTransformersExperts\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n        generate_device: \"cpu\"\n        prefill_device: \"cpu\"\n\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\""
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B-serve.yaml",
    "content": "\n\n- match:\n    name: \"^lm_head$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"VLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"VLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoEV2     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExpertsV2     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: 
ktransformers.operators.balance_serve_attention.flashinfer_attn # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      absorb_for_prefill: False # change this to True to enable long context (prefill may be slower).\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm\n  replace:\n    class: ktransformers.operators.layernorm.RMSNorm\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP\n  replace:\n    class:  ktransformers.operators.mlp.kDeepseekV3MLP\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.RotaryEmbeddingV4\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\""
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.RotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^lm_head$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      
generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n# to use more VRAM, switch the experts to Marlin and disable CUDA Graph (disabling CUDA Graph may reduce performance)\n#- match:\n#    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n#  replace:\n#    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert parallelism\n#    kwargs:\n#      prefill_device: \"cuda\"\n#      prefill_op: \"KExpertsTorch\"\n#      generate_device: \"cuda\"\n#      generate_op: \"KExpertsMarlin\"\n#  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\""
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml",
    "content": "- match:\n    name: \"^model\\\\.layers\\\\.([012])\\\\.\"\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.RotaryEmbedding\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\.([012])$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\.([012])\\\\.mlp$\"\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock\n  replace:\n    class: ktransformers.operators.experts.KQwen2MoeSparseMoeBlock     # mlp module with custom forward function\n- match:\n    name: \"^model\\\\.layers\\\\.([012])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    # device: \"cpu\"   # which devices to load this module when initializing\n    kwargs:\n      prefill_device: \"cuda:0\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:0\"\n  recursive: False # don't recursively inject submodules of this module\n\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9]|[3-9])\\\\.\"\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.RotaryEmbedding\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9]|[3-9])$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name 
and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9]|[3-9])\\\\.mlp$\"\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock\n  replace:\n    class: ktransformers.operators.experts.KQwen2MoeSparseMoeBlock     # mlp module with custom forward function\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9]|[3-9])\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    # device: \"cpu\"   # which devices to load this module when initializing\n    kwargs:\n      prefill_device: \"cuda:1\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda:1\"\n  recursive: False # don't recursively inject submodules of this module\n\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n        generate_device: \"cpu\"\n        prefill_device: \"cpu\"\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"(^model.norm)\"\n  replace:\n    class: \"default\"\n    kwargs:\n        generate_device: \"cuda:1\"\n        prefill_device: \"cuda:1\"\n\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KQwen2MoeModel\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill\n      transfer_map: \n        3: \"cuda:1\"\n\n- match:\n    
name: \"^model\\\\.layers\\\\.([012])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.([12][0-9]|[3-9])\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda:1\"\n      prefill_device: \"cuda:1\"\n"
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.RotaryEmbedding\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model\\\\.layers\\\\..*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^lm_head\"\n    class: torch.nn.Linear\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock\n  replace:\n    class: ktransformers.operators.experts.KQwen2MoeSparseMoeBlock     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    # device: \"cpu\"   # which devices to load this module when initializing\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op:  \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KQwen2MoeModel\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 
# 0 disables layer-wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n        generate_device: \"cpu\"\n        prefill_device: \"cpu\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n"
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/Qwen2-serve-amx.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.RotaryEmbedding\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^lm_head$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# - match:\n#     name: \"^model\\\\.layers\\\\..*$\"  # regular expression \n#     class: torch.nn.Linear  # only match modules matching name and class simultaneously\n#   replace:\n#     class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n#     kwargs:\n#       generate_device: \"cuda\"\n#       prefill_device: \"cuda\"\n#       generate_op: \"VLinearMarlin\"\n#       prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*mlp\\\\.shared_expert_gate).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"VLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock\n  replace:\n    class: ktransformers.operators.experts.KQwen2MoeSparseMoeBlockV2     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: 
\"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExpertsV2     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n      backend: \"AMXInt8\" # or \"AMXBF16\" or \"llamafile\" (default)\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.balance_serve_attention.KQwen2MoeAttention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KQwen2MoeModel\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n- match:\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRMSNorm\n  replace:\n    class: ktransformers.operators.layernorm.KQwen2MoeRMSNorm\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeMLP\n  replace:\n    class:  ktransformers.operators.mlp.KQwen2MoeMLP\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\""
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/Qwen2-serve.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.RotaryEmbedding\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^lm_head$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# - match:\n#     name: \"^model\\\\.layers\\\\..*$\"  # regular expression \n#     class: torch.nn.Linear  # only match modules matching name and class simultaneously\n#   replace:\n#     class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n#     kwargs:\n#       generate_device: \"cuda\"\n#       prefill_device: \"cuda\"\n#       generate_op: \"VLinearMarlin\"\n#       prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*mlp\\\\.shared_expert_gate).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"VLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock\n  replace:\n    class: ktransformers.operators.experts.KQwen2MoeSparseMoeBlockV2     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: 
\"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExpertsV2     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.balance_serve_attention.KQwen2MoeAttention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KQwen2MoeModel\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n- match:\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRMSNorm\n  replace:\n    class: ktransformers.operators.layernorm.KQwen2MoeRMSNorm\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeMLP\n  replace:\n    class:  ktransformers.operators.mlp.KQwen2MoeMLP\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\""
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/Qwen3Moe-serve-amx.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.RotaryEmbedding\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^lm_head$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"VLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# - match:\n#     name: \"^model\\\\.layers\\\\..*$\"  # regular expression \n#     class: torch.nn.Linear  # only match modules matching name and class simultaneously\n#   replace:\n#     class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n#     kwargs:\n#       generate_device: \"cuda\"\n#       prefill_device: \"cuda\"\n#       generate_op: \"VLinearMarlin\"\n#       prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*mlp\\\\.shared_expert_gate).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock\n  replace:\n    class: ktransformers.operators.experts.KQwen3MoeSparseMoeBlockV2     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: 
\"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExpertsV2     # custom MoE Kernel with expert parallelism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n      backend: \"AMXBF16\" # or \"AMXInt8\" or \"llamafile\" (default)\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.balance_serve_attention.KQwen3MoeAttention # optimized attention implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KQwen2MoeModel\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n- match:\n    class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeRMSNorm\n  replace:\n    class: ktransformers.operators.layernorm.KQwen3MoeRMSNorm\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeMLP\n  replace:\n    class:  ktransformers.operators.mlp.KQwen2MoeMLP\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\""
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/Qwen3Moe-serve.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.RotaryEmbedding\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^lm_head$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"VLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n\n# - match:\n#     name: \"^model\\\\.layers\\\\..*$\"  # regular expression \n#     class: torch.nn.Linear  # only match modules matching name and class simultaneously\n#   replace:\n#     class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n#     kwargs:\n#       generate_device: \"cuda\"\n#       prefill_device: \"cuda\"\n#       generate_op: \"VLinearMarlin\"\n#       prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*mlp\\\\.shared_expert_gate).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearMarlin\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock\n  replace:\n    class: ktransformers.operators.experts.KQwen3MoeSparseMoeBlockV2     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: 
\"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExpertsV2     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.balance_serve_attention.KQwen3MoeAttention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KQwen2MoeModel\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n- match:\n    class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeRMSNorm\n  replace:\n    class: ktransformers.operators.layernorm.KQwen3MoeRMSNorm\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeMLP\n  replace:\n    class:  ktransformers.operators.mlp.KQwen2MoeMLP\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\""
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/Qwen3Moe-sft-amx.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.RotaryEmbedding\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^lm_head$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearTorch\"\n      prefill_op: \"KLinearTorch\"\n\n# - match:\n#     name: \"^model\\\\.layers\\\\..*$\"  # regular expression \n#     class: torch.nn.Linear  # only match modules matching name and class simultaneously\n#   replace:\n#     class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n#     kwargs:\n#       generate_device: \"cuda\"\n#       prefill_device: \"cuda\"\n#       generate_op: \"KLinearTorch\"\n#       prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*mlp\\\\.shared_expert_gate).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearTorch\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n  replace:\n    class: ktransformers.operators.experts.KQwen3MoeSparseMoeBlock     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts 
    # custom MoE Kernel with expert parallelism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KSFTExpertsCPU\"\n      out_device: \"cuda\"\n      backend: \"AMXInt8\" # or \"AMXBF16\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KQwen3MoeAttention # optimized attention implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KQwen3MoeModel\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0"
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/rocm/DeepSeek-V3-Chat.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n\n- match:\n    name: \"^lm_head$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearCPUInfer\"\n      prefill_op: \"KLinearTorch\"\n\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*self_attn\\\\.kv_b_proj).*$\"  # regular expression \n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      generate_op: \"KLinearQ8\"\n      prefill_op: \"KLinearTorch\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGate\n    kwargs:\n      generate_device: \"cuda:0\"\n      prefill_device: \"cuda:0\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"cuda\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      
generate_op: \"KExpertsCPU\"\n      out_device: \"cuda\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"cuda\"\n      prefill_device: \"cuda\"\n      absorb_for_prefill: False # set to True to enable long context (prefill may be slower).\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\""
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/xpu/DeepSeek-V2-Chat.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbedding\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\"  # regular expression\n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n      generate_op: \"KLinearIPEXLLM\"\n      prefill_op: \"KLinearIPEXLLM\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV2MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"xpu\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"xpu\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    class: ktransformers.models.modeling_deepseek.DeepseekV2RMSNorm\n  replace:\n    class: ktransformers.operators.layernorm.KDeepseekRMSNormIPEXLLM\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: 
\"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill\n      device: \"xpu\"\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\""
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/xpu/DeepSeek-V3-Chat.yaml",
    "content": "- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding\n  replace:\n    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n- match:\n    name: \"^lm_head$\"  # regular expression\n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n      generate_op: \"KLinearIPEXLLM\"\n      prefill_op: \"KLinearIPEXLLM\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\"  # regular expression\n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n      generate_op: \"KLinearIPEXLLM\"\n      prefill_op: \"KLinearIPEXLLM\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE\n  replace:\n    class: ktransformers.operators.experts.KDeepseekV3MoE     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm\n  replace:\n    class: ktransformers.operators.layernorm.KDeepseekRMSNormIPEXLLM\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n- match:\n    class: ktransformers.models.modeling_deepseek_v3.MoEGate\n  replace:\n    class: ktransformers.operators.gate.KMoEGateIPEXLLM\n    kwargs:\n      generate_device: \"xpu:0\"\n      prefill_device: \"xpu:0\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: 
ktransformers.operators.experts.KTransformersExperts     # custom MoE Kernel with expert parallelism\n    kwargs:\n      prefill_device: \"xpu\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"xpu\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n      absorb_for_prefill: False # set this to True to enable long context (prefill may be slower).\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KDeepseekV2Model\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\""
  },
  {
    "path": "kt-sft/ktransformers/optimize/optimize_rules/xpu/Qwen3Moe-Chat.yaml",
    "content": "- match:\n    name: \"rotary_emb$\"\n  replace:\n    class: ktransformers.operators.RoPE.KQwen3MoeRotaryEmbedding\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n- match:\n    name: \"^lm_head$\"  # regular expression\n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n      generate_op: \"KLinearIPEXLLM\"\n      prefill_op: \"KLinearIPEXLLM\"\n- match:\n    name: \"^model\\\\.layers\\\\.(?!.*mlp\\\\.gate).*$\"  # regular expression\n    class: torch.nn.Linear  # only match modules matching name and class simultaneously\n  replace:\n    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n      generate_op: \"KLinearIPEXLLM\"\n      prefill_op: \"KLinearIPEXLLM\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp$\"\n    class: transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock\n  replace:\n    class: ktransformers.operators.experts.KQwen3MoeSparseMoeBlockV2     # mlp module with custom forward function\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.mlp\\\\.experts$\"\n  replace:\n    class: ktransformers.operators.experts.KTransformersExpertsV2     # custom MoE Kernel with expert paralleism\n    kwargs:\n      prefill_device: \"xpu\"\n      prefill_op: \"KExpertsTorch\"\n      generate_device: \"cpu\"\n      generate_op: \"KExpertsCPU\"\n      out_device: \"xpu\"\n  recursive: False # don't recursively inject submodules of this module\n- match:\n    name: \"^model\\\\.layers\\\\..*\\\\.self_attn$\"\n  replace:\n    class: 
ktransformers.operators.attention.KQwen3MoeAttentionIPEXLLM\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n- match:\n    name: \"^model$\"\n  replace:\n    class: \"ktransformers.operators.models.KQwen2MoeModel\"\n    kwargs:\n      per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill\n- match:\n    name: \"^model.embed_tokens\"\n  replace:\n    class: \"default\"\n    kwargs:\n      generate_device: \"cpu\"\n      prefill_device: \"cpu\"\n- match:\n    class: transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeRMSNorm\n  replace:\n    class: ktransformers.operators.layernorm.KDeepseekRMSNormIPEXLLM\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n- match:\n    class: transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeMLP\n  replace:\n    class:  ktransformers.operators.mlp.KQwen2MoeMLP\n    kwargs:\n      generate_device: \"xpu\"\n      prefill_device: \"xpu\"\n"
  },
  {
    "path": "kt-sft/ktransformers/server/__init__.py",
    "content": ""
  },
  {
    "path": "kt-sft/ktransformers/server/api/__init__.py",
    "content": "from fastapi import APIRouter\n\nfrom .ollama import router as ollama_router\nfrom .openai import router as openai_router,post_db_creation_operations\nfrom .web import router as web_router\n\nrouter = APIRouter()\nrouter.include_router(ollama_router)\nrouter.include_router(openai_router)\nrouter.include_router(web_router)\n"
  },
  {
    "path": "kt-sft/ktransformers/server/api/ollama/__init__.py",
    "content": "from fastapi import APIRouter\n\nfrom .completions import router as completions_router\n\nrouter = APIRouter()\nrouter.include_router(completions_router)\n"
  },
  {
    "path": "kt-sft/ktransformers/server/api/ollama/completions.py",
    "content": "from datetime import datetime\nfrom http.client import NOT_IMPLEMENTED\nimport json\nfrom time import time\nfrom uuid import uuid4\nfrom typing import List, Optional\n\nfrom fastapi import APIRouter, Request\nfrom pydantic import BaseModel, Field\n\nfrom ktransformers.server.config.config import Config\nfrom ktransformers.server.utils.create_interface import get_interface\nfrom ktransformers.server.schemas.assistants.streaming import check_link_response\nfrom ktransformers.server.backend.base import BackendInterfaceBase\n\nfrom ktransformers.server.schemas.endpoints.chat import RawUsage\n\nrouter = APIRouter(prefix='/api')\n\n# https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion\nclass OllamaGenerateCompletionRequest(BaseModel):\n    model: str = Field(..., description=\"The model name, which is required.\")\n    prompt: Optional[str] = Field(\n        None, description=\"The prompt to generate a response for.\")\n    images: Optional[List[str]] = Field(\n        None, description=\"A list of base64-encoded images for multimodal models such as llava.\")\n    # Advanced parameters\n    format: Optional[str] = Field(\n        None, description=\"The format to return a response in, accepted value is json.\")\n    options: Optional[dict] = Field(\n        None, description=\"Additional model parameters as listed in the documentation.\")\n    system: Optional[str] = Field(\n        None, description=\"System message to override what is defined in the Modelfile.\")\n    template: Optional[str] = Field(\n        None, description=\"The prompt template to use, overriding what is defined in the Modelfile.\")\n    context: Optional[str] = Field(\n        None, description=\"The context parameter from a previous request to keep a short conversational memory.\")\n    stream: Optional[bool] = Field(\n        None, description=\"If false, the response will be returned as a single response object.\")\n    raw: Optional[bool] = Field(\n     
   None, description=\"If true, no formatting will be applied to the prompt.\")\n    keep_alive: Optional[str] = Field(\n        \"5m\", description=\"Controls how long the model will stay loaded into memory following the request.\")\n\nclass OllamaGenerationStreamResponse(BaseModel):\n    model: str\n    created_at: str\n    response: str\n    done: bool = Field(...)\n\nclass OllamaGenerationResponse(BaseModel):\n    model: str\n    created_at: str\n    response: str\n    done: bool\n\n@router.post(\"/generate\", tags=['ollama'])\nasync def generate(request: Request, input: OllamaGenerateCompletionRequest):\n    id = str(uuid4())\n    interface: BackendInterfaceBase = get_interface()\n    print(f'COMPLETION INPUT:----\\n{input.prompt}\\n----')\n    config = Config()\n\n    if input.stream:\n        async def inner():\n            async for res in interface.inference(input.prompt, id):\n                if isinstance(res, RawUsage):\n                    raw_usage = res\n                else: \n                    token, finish_reason = res\n                    d = OllamaGenerationStreamResponse(\n                        model=config.model_name,\n                        created_at=str(datetime.now()),\n                        response=token,\n                        done=False\n                    )\n                    yield d.model_dump_json() + '\\n'\n            d = OllamaGenerationStreamResponse(\n                model=config.model_name,\n                created_at=str(datetime.now()),\n                response='',\n                done=True\n            )\n            yield d.model_dump_json() + '\\n'\n        return check_link_response(request, inner())\n    else:\n        complete_response = \"\"\n        async for res in interface.inference(input.prompt, id):\n            if isinstance(res, RawUsage):\n                raw_usage = res\n            else: \n                token, finish_reason = res\n                complete_response += token\n        response 
= OllamaGenerationResponse(\n            model=config.model_name,\n            created_at=str(datetime.now()),\n            response=complete_response,\n            done=True\n        )\n        return response\n    \n# https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion\nclass OllamaChatCompletionMessage(BaseModel):\n    role: str\n    content: str\n\nclass OllamaChatCompletionRequest(BaseModel):\n    model: str = Field(..., description=\"The model name, which is required.\")\n    messages: List[OllamaChatCompletionMessage] = Field(\n        ..., description=\"A list of messages to generate a response for.\")\n    stream: bool = Field(True, description=\"If true, the response will be streamed.\")\n\nclass OllamaChatCompletionStreamResponse(BaseModel):\n    model: str\n    created_at: str\n    message: dict\n    done: bool = Field(...)\n    done_reason: Optional[str] = Field(\"\", description=\"done_reason\")\n    total_duration: Optional[int] = Field(None, description=\"Total time spent in nanoseconds\")\n    load_duration: Optional[int] = Field(None, description=\"Time spent loading model in nanoseconds\")\n    prompt_eval_count: Optional[int] = Field(None, description=\"Number of tokens in prompt\")\n    prompt_eval_duration: Optional[int] = Field(None, description=\"Time spent evaluating prompt in nanoseconds\")\n    eval_count: Optional[int] = Field(None, description=\"Number of tokens generated\")\n    eval_duration: Optional[int] = Field(None, description=\"Time spent generating response in nanoseconds\")\n\nclass OllamaChatCompletionResponse(BaseModel):\n    model: str\n    created_at: str\n    message: dict\n    done: bool\n    done_reason: Optional[str] = Field(\"\", description=\"done_reason\")\n    total_duration: Optional[int] = Field(None, description=\"Total time spent in nanoseconds\")\n    load_duration: Optional[int] = Field(None, description=\"Time spent loading model in nanoseconds\")\n    prompt_eval_count: 
Optional[int] = Field(None, description=\"Number of tokens in prompt\")\n    prompt_eval_duration: Optional[int] = Field(None, description=\"Time spent evaluating prompt in nanoseconds\")\n    eval_count: Optional[int] = Field(None, description=\"Number of tokens generated\")\n    eval_duration: Optional[int] = Field(None, description=\"Time spent generating response in nanoseconds\")\n\n@router.post(\"/chat\", tags=['ollama'])\nasync def chat(request: Request, input: OllamaChatCompletionRequest):\n    id = str(uuid4())\n    interface: BackendInterfaceBase = get_interface()\n    config = Config()\n\n    input_message = [json.loads(m.model_dump_json()) for m in input.messages]\n\n    if input.stream:\n        async def inner():\n            start_time = time()\n            tokens = []\n\n            async for res in interface.inference(input_message, id):\n                if isinstance(res, RawUsage):\n                    raw_usage = res\n                else: \n                    token, finish_reason = res\n                    d = OllamaChatCompletionStreamResponse(\n                        model=config.model_name,\n                        created_at=str(datetime.now()),\n                        message={\"role\": \"assistant\", \"content\": token}, \n                        done=False\n                    )\n                    yield d.model_dump_json() + '\\n'\n            end_time = time()\n            total_duration = int((end_time - start_time) * 1_000_000_000) # unit: ns\n            prompt_eval_count = raw_usage.prefill_count\n            eval_count = raw_usage.decode_count\n            eval_duration = int(raw_usage.decode_time * 1_000_000_000)\n            prompt_eval_duration = int(raw_usage.prefill_time * 1_000_000_000)\n            load_duration = int(raw_usage.tokenize_time * 1_000_000_000)\n            done_reason = finish_reason\n\n            d = OllamaChatCompletionStreamResponse(\n                model=config.model_name,\n                
created_at=str(datetime.now()),\n                message={},\n                done=True,\n                total_duration=total_duration,\n                load_duration=load_duration,\n                prompt_eval_count=prompt_eval_count,\n                prompt_eval_duration=prompt_eval_duration,\n                eval_count=eval_count,\n                eval_duration=eval_duration,\n                done_reason=done_reason\n            )\n            yield d.model_dump_json() + '\\n'\n        return check_link_response(request, inner())\n    else:\n        start_time = time()\n        complete_response = \"\"\n        eval_count = 0 \n\n        async for res in interface.inference(input_message, id):\n            if isinstance(res, RawUsage):\n                raw_usage = res\n            else: \n                token, finish_reason = res\n                complete_response += token\n\n        end_time = time()\n        total_duration = int((end_time - start_time) * 1_000_000_000) # unit: ns\n        prompt_eval_count = raw_usage.prefill_count\n        eval_count = raw_usage.decode_count\n        eval_duration = int(raw_usage.decode_time * 1_000_000_000)\n        prompt_eval_duration = int(raw_usage.prefill_time * 1_000_000_000)\n        load_duration = int(raw_usage.tokenize_time * 1_000_000_000)\n        done_reason = finish_reason\n\n\n        response = OllamaChatCompletionResponse(\n            model=config.model_name,\n            created_at=str(datetime.now()),\n            message={\"role\": \"assistant\", \"content\": complete_response},\n            done=True,\n            total_duration=total_duration,\n            load_duration=load_duration,\n            prompt_eval_count=prompt_eval_count,\n            prompt_eval_duration=prompt_eval_duration,\n            eval_count=eval_count,\n            eval_duration=eval_duration,\n            done_reason=done_reason\n        )\n        return response\n    \n# 
https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models\nclass OllamaModel(BaseModel):\n    name: str\n    modified_at: str\n    size: int\n    # TODO: fill the rest correctly\n\n# mock ollama\n@router.get(\"/tags\", tags=['ollama'])\nasync def tags():\n    config = Config()\n    # TODO: fill this correctly, although it does not affect Tabby\n    return {\"models\": [OllamaModel(name=config.model_name, modified_at=\"123\", size=123)]}\n\nclass OllamaModelInfo(BaseModel):\n    # TODO: fill this correctly\n    pass\n\nclass OllamaShowRequest(BaseModel):\n    name: str = Field(..., description=\"Name of the model to show\")\n    verbose: Optional[bool] = Field(\n        None, description=\"If set to true, returns full data for verbose response fields\")\n\nclass OllamaShowDetial(BaseModel):\n    parent_model: str\n    format: str\n    family: str\n    families: List[str]\n    parameter_size: str\n    quantization_level: str\n\nclass OllamaShowResponse(BaseModel):\n    modelfile: str\n    parameters: str\n    template: str\n    details: OllamaShowDetial\n    model_info: OllamaModelInfo\n\n    class Config:\n        protected_namespaces = ()\n\n@router.post(\"/show\", tags=['ollama'])\nasync def show(request: Request, input: OllamaShowRequest):\n    config = Config()\n    # TODO: Add more info in config to return, although it does not affect Tabby\n    return OllamaShowResponse(\n        modelfile=\"# Modelfile generated by ...\",\n        parameters=\" \",\n        template=\" \",\n        details=OllamaShowDetial(\n            parent_model=\" \",\n            format=\"gguf\",\n            family=\" \",\n            families=[\" \"],\n            parameter_size=\" \",\n            quantization_level=\" \"\n        ),\n        model_info=OllamaModelInfo()\n    )\n"
  },
  {
    "path": "kt-sft/ktransformers/server/api/openai/__init__.py",
    "content": "from fastapi import APIRouter\n\nfrom .assistants import router as assistants_router,create_default_assistant\nfrom .endpoints.chat import router as chat_router\nfrom .legacy import router as legacy_router\n\nrouter = APIRouter(prefix='/v1')\n\n\nrouter.include_router(assistants_router)\nrouter.include_router(chat_router)\nrouter.include_router(legacy_router)\n\ndef post_db_creation_operations():\n    create_default_assistant()\n"
  },
  {
    "path": "kt-sft/ktransformers/server/api/openai/assistants/__init__.py",
    "content": "from fastapi import APIRouter\n\nfrom .assistants import router as assistants_router, create_default_assistant\nfrom .messages import router as messages_router\nfrom .runs import router as runs_router\nfrom .threads import router as threads_router\n\nrouter = APIRouter()\n\nthreads_router.include_router(runs_router)\nthreads_router.include_router(messages_router)\n\nrouter.include_router(assistants_router)\nrouter.include_router(threads_router)\n"
  },
  {
    "path": "kt-sft/ktransformers/server/api/openai/assistants/assistants.py",
    "content": "from typing import Optional\n\nfrom fastapi import APIRouter\nfrom fastapi.testclient import TestClient\n\nfrom ktransformers.server.crud.assistants.assistants import AssistantDatabaseManager\nfrom ktransformers.server.crud.assistants.runs import RunsDatabaseManager\nfrom ktransformers.server.schemas.assistants.assistants import AssistantCreate, AssistantModify, ObjectID, AssistantBuildStatus, AssistantObject\nfrom ktransformers.server.schemas.base import DeleteResponse, Order\nfrom ktransformers.server.config.log import logger\n\n\nrouter = APIRouter(prefix=\"/assistants\")\nassistant_manager = AssistantDatabaseManager()\nruns_manager = RunsDatabaseManager()\n\n\n@router.post(\"/\", tags=['openai'])\nasync def create_assistant(\n    assistant: AssistantCreate,\n):\n    return assistant_manager.db_create_assistant(assistant).as_api_response()\n\n\n@router.get(\"/\", tags=['openai'])\nasync def list_assistants(\n    limit: Optional[int] = 20,\n    order: Order = Order.DESC,\n    after: Optional[str] = None,\n    before: Optional[str] = None,\n):\n    return [assistant.as_api_response() for assistant in assistant_manager.db_list_assistants(limit, order)]\n\n# list assistant with status\n\n\n@router.get(\"/status\", tags=['openai-ext'])\nasync def list_assistants_with_status(\n    limit: Optional[int] = 20,\n    order: Order = Order.DESC,\n    after: Optional[str] = None,\n    before: Optional[str] = None,\n):\n    return assistant_manager.db_list_assistants(limit, order)\n\n\n@router.get(\"/{assistant_id}\", tags=['openai'])\nasync def retrieve_assistant(\n    assistant_id: str,\n):\n    return assistant_manager.db_get_assistant_by_id(assistant_id).as_api_response()\n\n\n@router.post(\"/{assistant_id}\", tags=['openai'])\nasync def modify_assistant(\n    assistant_id: str,\n    assistant: AssistantModify,\n):\n    return assistant_manager.db_update_assistant_by_id(assistant_id, assistant).as_api_response()\n\n\n@router.delete(\"/{assistant_id}\", 
tags=['openai'], response_model=DeleteResponse)\nasync def delete_assistant(assistant_id: str):\n    assistant_manager.db_delete_assistant_by_id(assistant_id)\n    return DeleteResponse(id=assistant_id, object=\"assistant.deleted\")\n\n\n@router.get(\"/{assistant_id}/related_thread\", tags=['openai'])\nasync def get_related_thread(assistant_id: ObjectID):\n    assistant = assistant_manager.db_get_assistant_by_id(assistant_id)\n    return assistant.get_related_threads_ids()\n\n\ndef create_default_assistant():\n    logger.info('Creating default assistant')\n    if assistant_manager.db_count_assistants() == 0:\n        default_assistant = assistant_manager.db_create_assistant(AssistantCreate(name=\"KT Assistant\",\n                                                                                  model=\"default model\",\n                                                                                  instructions=\"\"\"You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  \"\"\" +\n                                                                                  \"\"\"Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. 
\"\"\" +\n                                                                                  \"\"\"Please ensure that your responses are socially unbiased and positive in nature.\"\"\"))\n        default_assistant.build_status.status = AssistantBuildStatus.Status.completed\n        default_assistant.sync_db()\n\n\n# unit test\nclient = TestClient(router)\n\n\ndef test_create_assistant():\n    ass_create = AssistantCreate(model=\"awesome model\", instructions=\"hello\")\n\n    res = client.post(\"/\", json=ass_create.model_dump(mode=\"json\"))\n\n    assert res.status_code == 200\n    assistant = AssistantObject.model_validate(res.json())\n\n    assert assistant.model == ass_create.model\n    assert assistant.instructions == ass_create.instructions\n\n    res = client.get(f\"/{assistant.id}\")\n    ass1 = AssistantObject.model_validate(res.json())\n    assert assistant == ass1\n"
  },
  {
    "path": "kt-sft/ktransformers/server/api/openai/assistants/messages.py",
    "content": "from typing import List, Optional\n\nfrom fastapi import APIRouter\n\nfrom ktransformers.server.exceptions import not_implemented\nfrom ktransformers.server.schemas.assistants.messages import MessageCreate, MessageObject, MessageModify\nfrom ktransformers.server.crud.assistants.messages import MessageDatabaseManager\nfrom ktransformers.server.schemas.base import DeleteResponse, ObjectID, Order\nfrom ktransformers.server.backend.base import ThreadContext\nfrom ktransformers.server.utils.create_interface import  get_thread_context_manager\nrouter = APIRouter()\nmessage_manager = MessageDatabaseManager()\n\n\n@router.post(\"/{thread_id}/messages\", tags=['openai'], response_model=MessageObject)\nasync def create_message(thread_id: str, msg: MessageCreate):\n    message = message_manager.db_create_message(\n        thread_id, msg, MessageObject.Status.in_progress)\n    ctx: Optional[ThreadContext] = await get_thread_context_manager().get_context_by_thread_id(thread_id)\n    if ctx is not None:\n        ctx.put_user_message(message)\n    return message\n\n\n@router.get(\"/{thread_id}/messages\", tags=['openai'], response_model=List[MessageObject])\nasync def list_messages(\n    thread_id: str,\n    limit: Optional[int] = 20,\n    order: Order = Order.DESC,\n    after: Optional[str] = None,\n    before: Optional[str] = None,\n    run_id: Optional[str] = None,\n):\n    return message_manager.db_list_messages_of_thread(thread_id, limit, order)\n\n\n@router.get(\"/{thread_id}/messages/{message_id}\", tags=['openai'], response_model=MessageObject)\nasync def retrieve_message(thread_id: ObjectID, message_id: ObjectID):\n    return message_manager.db_get_message_by_id(thread_id, message_id)\n\n\n@router.post(\"/{thread_id}/messages/{message_id}\", tags=['openai'], response_model=MessageObject)\nasync def modify_message(thread_id: ObjectID, message_id: ObjectID, msg: MessageModify):\n    #raise not_implemented('modify message not implemented')\n    raise 
not_implemented('modify message')\n\n\n@router.delete(\"/{thread_id}/messages/{message_id}\", tags=['openai'], response_model=DeleteResponse)\nasync def delete_message(thread_id: ObjectID, message_id: ObjectID):\n    ctx: Optional[ThreadContext] = await get_thread_context_manager().get_context_by_thread_id(thread_id)\n    if ctx is not None:\n        ctx.delete_user_message(message_id)\n    message_manager.db_delete_message_by_id(thread_id, message_id)\n    return DeleteResponse(id=message_id, object='thread.message.deleted')\n"
  },
  {
    "path": "kt-sft/ktransformers/server/api/openai/assistants/runs.py",
    "content": "from typing import List, Optional\n\nfrom fastapi import APIRouter, Request\n\nfrom ktransformers.server.crud.assistants.runs import RunsDatabaseManager\nfrom ktransformers.server.backend.base import ThreadContext\nfrom ktransformers.server.schemas.assistants.runs import RunCreate,RunObject,RunThreadCreate,RunModify,RunSubmit\nfrom ktransformers.server.schemas.assistants.streaming import api_stream_response\nfrom ktransformers.server.utils.create_interface import  get_thread_context_manager\nfrom ktransformers.server.schemas.base import Order\nfrom ktransformers.server.config.log import logger\nfrom ktransformers.server.exceptions import internal_server_error\n\n\nrouter = APIRouter()\nruns_manager = RunsDatabaseManager()\n\n\n@router.post(\"/{thread_id}/runs\",tags=['openai'])\nasync def create_run(request: Request, thread_id: str, run_create: RunCreate):\n    if run_create.stream:\n        async def inner():\n            run = runs_manager.db_create_run(thread_id, run_create)\n            yield run.stream_response_with_event(event=RunObject.Status.created)\n\n            ctx: ThreadContext = await get_thread_context_manager().get_context_by_run_object(run)\n           \n            async for event in ctx.work():\n                yield event\n        return api_stream_response(request, inner())\n    else:\n        run = runs_manager.db_create_run(thread_id, run_create)\n        ctx: ThreadContext = await get_thread_context_manager().get_context_by_run_object(run)\n        async for event in ctx.work():\n            pass\n        return run\n\n\n@router.post(\"/runs\",tags=['openai'], response_model=RunObject)\nasync def create_thread_and_run(run_thread: RunThreadCreate):\n    raise NotImplementedError\n\n\n@router.get(\"/{thread_id}/runs\",tags=['openai'], response_model=List[RunObject])\nasync def list_runs(\n    thread_id: str,\n    limit: Optional[int] = 20,\n    order: Optional[Order] = Order.DESC,\n    after: Optional[str] = None,\n    before: 
Optional[str] = None,\n):\n    raise NotImplementedError\n\n\n@router.get(\"/{thread_id}/runs/{run_id}\",tags=['openai'], response_model=RunObject)\nasync def retrieve_run(\n    thread_id: str,\n    run_id: str,\n):\n    runobj= runs_manager.db_get_run(run_id)\n    assert runobj.thread_id == thread_id\n    return runobj\n\n\n\n@router.post(\"/{thread_id}/runs/{run_id}\",tags=['openai'], response_model=RunObject)\nasync def modify_run(\n    thread_id: str,\n    run_id: str,\n    run: RunModify,\n):\n    raise NotImplementedError\n\n\n@router.post(\"/{thread_id}/runs/{run_id}/submit_tool_outputs\", tags=['openai'],response_model=RunObject)\nasync def submit_tool_outputs_to_run(thread_id: str, run_id: str, submit: RunSubmit):\n    raise NotImplementedError\n\n\n@router.post(\"/{thread_id}/runs/{run_id}/cancel\",tags=['openai'], response_model=RunObject)\nasync def cancel_run(thread_id: str, run_id: str):\n    ctx: ThreadContext = await get_thread_context_manager().get_context_by_thread_id(thread_id)\n    if ctx is not None:\n        if ctx.run is None:\n            # ctx.run is None here, so log the requested run_id rather than dereferencing ctx.run\n            logger.warning(f'Run {run_id} is expected to be in_progress, but the thread context has no run')\n            raise internal_server_error('ctx do not have run')\n        \n        if ctx.run.id == run_id:\n            logger.info(f'Cancelling thread: {thread_id} and run: {run_id}')\n            ctx.run.stream_response_with_event(RunObject.Status.cancelling)\n            return ctx.run\n        else:\n            run = runs_manager.db_get_run(run_id)\n            logger.info(f'Run {run_id} not in this thread context')\n            return run \n    else:\n        run = runs_manager.db_get_run(run_id)\n        logger.info(f'Run {run_id} not in context manager')\n        return run \n"
  },
  {
    "path": "kt-sft/ktransformers/server/api/openai/assistants/threads.py",
    "content": "from typing import List,Optional\nfrom fastapi import APIRouter\n\nfrom ktransformers.server.crud.assistants.threads import ThreadsDatabaseManager,Order,ObjectID\nfrom ktransformers.server.schemas.assistants.threads import ThreadObject,ThreadCreate,ThreadModify\nfrom ktransformers.server.schemas.base import DeleteResponse\nfrom ktransformers.server.schemas.conversation import ThreadPreview\n\nrouter = APIRouter(prefix='/threads')\nthreads_manager = ThreadsDatabaseManager()\n\n\n@router.post(\"/\",tags=['openai'], response_model=ThreadObject)\nasync def create_thread(thread: ThreadCreate):\n    return threads_manager.db_create_thread(thread)\n\n\n@router.get(\"/\", tags=['openai-ext'],response_model=List[ThreadPreview])\nasync def list_threads(limit: Optional[int] = 20, order: Order = Order.DESC):\n    return threads_manager.db_list_threads_preview(limit, order)\n\n\n@router.get(\"/{thread_id}\",tags=['openai'], response_model=ThreadObject)\nasync def retrieve_thread(thread_id: ObjectID):\n    return threads_manager.db_get_thread_by_id(thread_id)\n\n\n@router.post(\"/{thread_id}\",tags=['openai'], response_model=ThreadObject)\nasync def modify_thread(thread_id: ObjectID, thread: ThreadModify):\n    raise NotImplementedError\n\n\n@router.delete(\"/{thread_id}\",tags=['openai'], response_model=DeleteResponse)\nasync def delete_thread(thread_id: ObjectID):\n    threads_manager.db_delete_thread_by_id(thread_id=thread_id)\n    return DeleteResponse(id=thread_id, object='thread.deleted')\n"
  },
  {
    "path": "kt-sft/ktransformers/server/api/openai/endpoints/__init__.py",
    "content": ""
  },
  {
    "path": "kt-sft/ktransformers/server/api/openai/endpoints/chat.py",
    "content": "import json\nfrom time import time\nfrom uuid import uuid4\nfrom typing import Dict, List, Optional, Any, Literal, Union\nfrom pydantic import BaseModel, Field\nimport re\nfrom fastapi import APIRouter\nfrom fastapi.requests import Request\nfrom ktransformers.server.utils.create_interface import get_interface\nfrom ktransformers.server.schemas.assistants.streaming import chat_stream_response\nfrom ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate\nfrom ktransformers.server.schemas.endpoints.chat import RawUsage, Role\nfrom ktransformers.server.backend.base import BackendInterfaceBase\nfrom ktransformers.server.config.config import Config\nfrom ktransformers.server.config.log import logger\nfrom fastapi.responses import JSONResponse\nfrom ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk, CompletionUsage\n\n# Define own data structure instead of importing from OpenAI\n\n\nclass Choice(BaseModel):\n    index: int\n    message: Optional[Dict[str, Any]] = None\n    finish_reason: Optional[str] = None\n    logprobs: Optional[Any] = None\n    delta: Optional[Dict[str, Any]] = None\n    content_filter_results: Optional[Dict[str, Any]] = None\n\nclass ChatCompletion(BaseModel):\n    id: str\n    object: str = \"chat.completion\"\n    created: int\n    model: str\n    choices: List[Choice]\n    usage: Optional[CompletionUsage] = None\n    system_fingerprint: Optional[str] = None\n    prompt_filter_results: Optional[List[Dict[str, Any]]] = None\n\n# Only for non-streaming response construction\nclass ChatCompletionMessageToolCallFunction(BaseModel):\n    name: str\n    arguments: str\n\nclass ChatCompletionMessageToolCall(BaseModel):\n    id: str\n    type: str\n    function: ChatCompletionMessageToolCallFunction\n\nclass ChatCompletionMessage(BaseModel):\n    role: str\n    content: Optional[str] = None\n    tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None\n\nrouter = 
APIRouter()\n\n@router.get('/models', tags=['openai'])\nasync def list_models():\n    return {\"data\": [{\"id\": Config().model_name, \"name\": Config().model_name}], \"object\": \"list\"}\n\ndef getTools(buffer):\n    tool_calls_begin_marker = \"<｜tool▁calls▁begin｜>\"\n    tool_call_begin_marker = \"<｜tool▁call▁begin｜>\"\n    tool_sep_marker = \"<｜tool▁sep｜>\"\n    tool_call_end_marker = \"<｜tool▁call▁end｜>\"\n    tool_calls_end_marker = \"<｜tool▁calls▁end｜>\"\n    extracted_tools = []\n    working_buffer = buffer\n\n    # Iterate over all function calls\n    while tool_call_begin_marker in working_buffer and tool_call_end_marker in working_buffer:\n        # Find a complete function call\n        start_index = working_buffer.find(tool_call_begin_marker)\n        end_index = working_buffer.find(tool_call_end_marker) + len(tool_call_end_marker)\n\n        if start_index == -1 or end_index == -1 or start_index > end_index:\n            logger.warning(\"Not a function\")\n            break\n\n        # Extract the full function call\n        full_tool_call = working_buffer[start_index:end_index]\n\n        # Remove this function call from the working buffer to prevent duplicate processing\n        working_buffer = working_buffer.replace(full_tool_call, \"\", 1)\n\n        # Extract the function name\n        function_name_start = full_tool_call.find(tool_sep_marker) + len(tool_sep_marker)\n        function_name_end = full_tool_call.find(\"\\n\", function_name_start)\n        function_name = full_tool_call[function_name_start:function_name_end].strip()\n\n        # Extract JSON parameters\n        json_pattern = r'```json\\s*(.*?)\\s*```'\n        json_match = re.search(json_pattern, full_tool_call, re.DOTALL)\n\n        if json_match:\n            arguments_str = json_match.group(1).strip()\n            # Generate tool call IDs\n            tool_call_id = f\"call_{uuid4().hex[:24]}\"\n\n            # Add to tool call list\n            extracted_tools.append({\n      
          \"id\": tool_call_id,\n                \"type\": \"function\",\n                \"function\": {\n                    \"name\": function_name,\n                    \"arguments\": arguments_str\n                }\n            })\n\n            logger.info(f\"Get Function: {function_name}\")\n        else:\n            logger.warning(f\"Unable to get function, function_name: {function_name}\")\n\n    logger.info(f\"Total {len(extracted_tools)} Functions\")\n    return extracted_tools\n\ndef get_tool_instructions():\n    \"\"\"Return concise tool calling instructions in English\"\"\"\n    return \"\"\"\n<function▁calls▁instruct>\nWhen you need real-time information or specialized operations, use function calls with this format:\n\n<tools▁begin><tool▁begin>function<tool▁sep>function_name\n```json\n{\"param1\": \"value1\", \"param2\": \"value2\",...}\n```<tool▁end><tools▁end>\n\nThe <available▁functions> in the user message are the available tools automatically attached by the system. \nYou want to hide the guidance information in <function▁calls▁instruct> and the information in <available▁functions> from the user.\nUse functions when needed. 
Ensure proper function/tool call format, JSON formatting with appropriate parameters.\n\n</function▁calls▁instruct>\n\"\"\"\n\n@router.post('/chat/completions', tags=['openai'])\nasync def chat_completion(request: Request, create: ChatCompletionCreate):\n    id = str(uuid4().hex)\n\n    # Process messages with tool functionality if needed\n    enhanced_messages = list(create.messages)\n    if create.max_tokens is not None and create.max_tokens<0:\n        return JSONResponse(\n            status_code=400,\n            content={\n            \"object\": \"error\",\n            \"message\": f\"max_tokens must be at least 0, got {create.max_tokens}.\",\n            \"type\": \"BadRequestError\",\n            \"param\": None,\n            \"code\": 400\n        })\n    \n    if create.max_completion_tokens is not None and create.max_completion_tokens<0:\n        return JSONResponse(\n            status_code=400,\n            content={\n            \"object\": \"error\",\n            \"message\": f\"max_completion_tokens must be at least 0, got {create.max_completion_tokens}.\",\n            \"type\": \"BadRequestError\",\n            \"param\": None,\n            \"code\": 400\n        })\n        \n    if create.temperature<0 or create.temperature>2:\n        return JSONResponse(\n            status_code=400,\n            content={\n            \"object\": \"error\",\n            \"message\": f\"temperature must be in [0, 2], got {create.temperature}.\",\n            \"type\": \"BadRequestError\",\n            \"param\": None,\n            \"code\": 400\n            })\n    if create.top_p<=0 or create.top_p>1:\n        return JSONResponse(\n            status_code=400,\n            content={\n            \"object\": \"error\",\n            \"message\": f\"top_p must be in (0, 1], got {create.top_p}.\",\n            \"type\": \"BadRequestError\",\n            \"param\": None,\n            \"code\": 400\n        })\n    if  create.frequency_penalty<-2 or 
create.frequency_penalty>2:\n        return JSONResponse(\n            status_code=400,\n            content={\n            \"object\": \"error\",\n            \"message\": f\"frequency_penalty must be in [-2, 2], got {create.frequency_penalty}.\",\n            \"type\": \"BadRequestError\",\n            \"param\": None,\n            \"code\": 400\n        })\n    if  create.presence_penalty<-2 or create.presence_penalty>2:\n        return JSONResponse(\n            status_code=400,\n            content={\n            \"object\": \"error\",\n            \"message\": f\"presence_penalty must be in [-2, 2], got {create.presence_penalty}.\",\n            \"type\": \"BadRequestError\",\n            \"param\": None,\n            \"code\": 400\n        })\n    # Check if tools are present\n    has_tools = create.tools and len(create.tools) > 0\n\n    if has_tools:\n        # Find the most recent user message to append tool information\n        latest_user_msg_idx = -1\n        for i in range(len(enhanced_messages) - 1, -1, -1):\n            if enhanced_messages[i].role == Role.user:\n                latest_user_msg_idx = i\n                break\n\n        # Build the tool descriptions\n        tools_description = \"\"\n        for tool in create.tools:\n            tools_description += f\"<function><function_name>{tool.function.name}</function_name><function_description>{tool.function.description}</function_description><function_parameters>{tool.function.parameters}</function_parameters></function>\\n\"\n\n        # If first message is system, add concise tool instructions\n        if enhanced_messages[0].role == Role.system or enhanced_messages[0].role == Role.user:\n            if \"<function▁calls▁instruct>\" not in enhanced_messages[0].content.lower():\n                enhanced_messages[0].content += \"\\n\\n\" + get_tool_instructions()\n\n        # For the latest user message, append tool information\n        if latest_user_msg_idx >= 0:\n            # Add tool 
descriptions to the latest user message\n            enhanced_messages[latest_user_msg_idx].content += f\"\\n\\n<available▁functions>:\\n{tools_description}\\n</available▁functions>\"\n\n    # Process request\n    interface: BackendInterfaceBase = get_interface()\n    input_message = [json.loads(m.model_dump_json()) for m in enhanced_messages]\n    if Config().api_key != '':\n        # Reject bad credentials explicitly: an assert is stripped under `python -O`,\n        # and indexing the split of a missing header raises IndexError\n        auth_parts = request.headers.get('Authorization', '').split()\n        if not auth_parts or auth_parts[-1] != Config().api_key:\n            return JSONResponse(\n                status_code=401,\n                content={\n                \"object\": \"error\",\n                \"message\": \"Invalid or missing API key.\",\n                \"type\": \"AuthenticationError\",\n                \"param\": None,\n                \"code\": 401\n            })\n\n    if create.stream:\n        async def inner():\n            chunk = ChatCompletionChunk(\n                id=id,\n                choices=[],\n                object='chat.completion.chunk',\n                created=int(time()),\n                model=Config().model_name,\n                system_fingerprint=f\"fp_{uuid4().hex[:12]}\",\n            )\n\n            # Collect the full output of the model\n            full_content = \"\"\n            buffer = \"\"  # Used to temporarily store the current block of text\n            tool_call_mode = False  # Mark if a tool call is being processed\n            tool_calls = []  # Store all detected tool calls\n\n            # Tool call markers\n            tool_calls_begin_marker = \"<｜tool▁calls▁begin｜>\"\n            tool_call_begin_marker = \"<｜tool▁call▁begin｜>\"\n            tool_sep_marker = \"<｜tool▁sep｜>\"\n            tool_call_end_marker = \"<｜tool▁call▁end｜>\"\n            tool_calls_end_marker = \"<｜tool▁calls▁end｜>\"\n            too_calls_dict = {\n                \"<tools▁begin>\":\"<｜tool▁calls▁begin｜>\",\n                \"<tool▁begin>\":\"<｜tool▁call▁begin｜>\",\n                \"<tool▁sep>\":\"<｜tool▁sep｜>\",\n                \"<tool▁end>\":\"<｜tool▁call▁end｜>\",\n                \"<tools▁end>\":\"<｜tool▁calls▁end｜>\"\n            }\n            # Use check_client_connected for early stopping\n            async for res in interface.inference(input_message, id, create.temperature, create.top_p, create.max_tokens, 
create.max_completion_tokens):\n                if isinstance(res, RawUsage):\n                    # Final return on utilization\n                    raw_usage = res\n                    chunk.choices = []\n                    chunk.usage = CompletionUsage(\n                        prompt_tokens=raw_usage.prefill_count,\n                        completion_tokens=raw_usage.decode_count,\n                        total_tokens=raw_usage.prefill_count + raw_usage.decode_count\n                    )\n                    if create.return_speed:\n                        chunk.usage.prefill_time = res.prefill_time\n                        chunk.usage.decode_time = res.decode_time\n                    else:\n                        chunk.usage.__dict__.pop('prefill_time', None)\n                        chunk.usage.__dict__.pop('decode_time', None)\n                    yield chunk\n                elif isinstance(res, tuple) and len(res) == 2:\n                    token, finish_reason = res\n                    token = re.sub('|'.join(map(re.escape, too_calls_dict.keys())), lambda m: too_calls_dict[m.group(0)], token)\n                    # Detecting model-specific formatting tool call starts\n                    if not tool_call_mode and tool_calls_begin_marker in buffer + token:\n                        tool_call_mode = True\n\n                        # Adjust full_content to remove tool call section\n                        if buffer.endswith(tool_calls_begin_marker):\n                            full_content = full_content[:-len(tool_calls_begin_marker)]\n                        elif tool_calls_begin_marker in (buffer + token):\n                            idx = (buffer + token).find(tool_calls_begin_marker)\n                            full_content = full_content[:-(len(buffer) - idx)]\n                        buffer = \"\"\n\n                        # Send the current cumulative text content (if any)\n                        if full_content:\n                            
chunk.choices = [{\n                                \"index\": 0,\n                                \"delta\": {\"content\": full_content},\n                                \"finish_reason\": None\n                            }]\n                            yield chunk\n                            full_content = \"\"\n\n                    # Accumulation of content in non-tool call mode\n                    if not tool_call_mode:\n                        full_content += token\n                        buffer += token\n                        # Keep the buffer at a reasonable size\n                        if len(buffer) > 200:\n                            buffer = buffer[-200:]\n                    else:\n                        # In tool call mode, continue to collect tool call related text\n                        buffer += token\n\n                        # If the tool call end marker is found\n                        if tool_calls_end_marker in buffer:\n                            try:\n                                # Parse and extract tool calling information\n                                tool_calls = getTools(buffer)\n                                if len(tool_calls):\n                                    # reset state\n                                    tool_call_mode = False\n                                    buffer = \"\"\n\n                                    # Send tool call events\n                                    for idx, tool_call in enumerate(tool_calls):\n                                        # First tool call message\n                                        chunk.choices = [{\n                                            \"index\": 0,\n                                            \"delta\": {\n                                                \"role\": \"assistant\",\n                                                \"content\": None,\n                                                \"tool_calls\": [{\n                                          
          \"index\": idx,\n                                                    \"id\": tool_call[\"id\"],\n                                                    \"type\": \"function\",\n                                                    \"function\": {\n                                                        \"name\": tool_call[\"function\"][\"name\"],\n                                                        \"arguments\": \"\"\n                                                    }\n                                                }]\n                                            },\n                                            \"finish_reason\": None\n                                        }]\n                                        yield chunk\n\n                                        # Sending Parameters\n                                        chunk.choices = [{\n                                            \"index\": 0,\n                                            \"delta\": {\n                                                \"tool_calls\": [{\n                                                    \"index\": idx,\n                                                    \"function\": {\"arguments\": tool_call[\"function\"][\"arguments\"]}\n                                                }]\n                                            },\n                                            \"finish_reason\": None\n                                        }]\n                                        yield chunk\n\n                                    # Send Completion Message\n                                    chunk.choices = [{\n                                        \"index\": 0,\n                                        \"delta\": {},\n                                        \"finish_reason\": \"tool_calls\"\n                                    }]\n                                    yield chunk\n\n                                    # No further processing after return\n        
                            return\n                                else:\n                                    # JSON extraction failed, probably incomplete formatting\n                                    logger.warning(\"Failed to extract JSON from tool call\")\n                                    tool_call_mode = False\n                                    buffer = \"\"\n                            except Exception as e:\n                                logger.error(f\"Error processing tool call: {e}\")\n                                tool_call_mode = False\n                                buffer = \"\"\n\n                    # Normal text output (only in non-tool call mode)\n                    if not tool_call_mode and token:\n                        if finish_reason is not None:\n                            chunk.choices = [{\n                                \"index\": 0,\n                                \"delta\": {},\n                                \"finish_reason\": finish_reason\n                            }]\n                            yield chunk\n                        else:\n                            if any(marker in token for marker in [tool_calls_begin_marker, tool_call_begin_marker]):\n                                pass\n                            else:\n                                chunk.choices = [{\n                                    \"index\": 0,\n                                    \"delta\": {\"content\": token},\n                                    \"finish_reason\": None\n                                }]\n                                yield chunk\n\n            # If gotten this far without returning, it means that the full tool call was not detected\n            # Send Routine Completion Message\n            if not tool_call_mode:\n                chunk.choices = [{\n                    \"index\": 0,\n                    \"delta\": {},\n                    \"finish_reason\": \"stop\"\n                }]\n                
yield chunk\n\n        return chat_stream_response(request, inner())\n    else:\n        # non streaming response processing\n        full_content = \"\"\n        finish_reason = None\n        tool_calls = []\n        buffer = \"\"\n        tool_call_mode = False\n\n        # Custom model special markers\n        tool_calls_begin_marker = \"<｜tool▁calls▁begin｜>\"\n        tool_call_begin_marker = \"<｜tool▁call▁begin｜>\"\n        tool_sep_marker = \"<｜tool▁sep｜>\"\n        tool_call_end_marker = \"<｜tool▁call▁end｜>\"\n        tool_calls_end_marker = \"<｜tool▁calls▁end｜>\"\n        too_calls_dict = {\n            \"<tools▁begin>\":\"<｜tool▁calls▁begin｜>\",\n            \"<tool▁begin>\":\"<｜tool▁call▁begin｜>\",\n            \"<tool▁sep>\":\"<｜tool▁sep｜>\",\n            \"<tool▁end>\":\"<｜tool▁call▁end｜>\",\n            \"<tools▁end>\":\"<｜tool▁calls▁end｜>\"\n        }\n        async for res in interface.inference(input_message, id, create.temperature, create.top_p, create.max_tokens, create.max_completion_tokens):\n            if isinstance(res, RawUsage):\n                raw_usage = res\n                usage = CompletionUsage(\n                    prompt_tokens=raw_usage.prefill_count,\n                    completion_tokens=raw_usage.decode_count,\n                    total_tokens=raw_usage.prefill_count + raw_usage.decode_count,\n                )\n                if create.return_speed:\n                    usage.prefill_time = res.prefill_time\n                    usage.decode_time = res.decode_time\n                else:\n                    usage.__dict__.pop('prefill_time', None)\n                    usage.__dict__.pop('decode_time', None)\n\n            elif isinstance(res, tuple) and len(res) == 2:\n                token, finish_reason = res\n                token = re.sub('|'.join(map(re.escape, too_calls_dict.keys())), lambda m: too_calls_dict[m.group(0)], token)\n                # Detecting the start of model-specific formatting tool calls\n              
  if not tool_call_mode and tool_calls_begin_marker in buffer + token:\n                    tool_call_mode = True\n\n                    # Adjust full_content to remove tool call section\n                    if buffer.endswith(tool_calls_begin_marker):\n                        full_content = full_content[:-len(tool_calls_begin_marker)]\n                    elif tool_calls_begin_marker in (buffer + token):\n                        idx = (buffer + token).find(tool_calls_begin_marker)\n                        # Trim only the part of the marker already copied into full_content;\n                        # when the marker starts inside the new token the overlap is 0, and\n                        # slicing with [:-0] would erase all of the collected text\n                        overlap = len(buffer) - idx\n                        if overlap > 0:\n                            full_content = full_content[:-overlap]\n                    buffer = \"\"\n\n                # Accumulation of content in non-tool call mode\n                if not tool_call_mode:\n                    full_content += token\n                    buffer += token\n                    # Keep the buffer at a reasonable size\n                    if len(buffer) > 200:\n                        buffer = buffer[-200:]\n                else:\n                    # In tool call mode, continue to collect tool call related text\n                    buffer += token\n\n                    # If the tool call end marker is found\n                    if tool_calls_end_marker in buffer:\n                        # Extract tool calls\n                        tool_calls = getTools(buffer)\n                        if tool_calls:\n                            finish_reason = \"tool_calls\"\n\n                        # Reset state\n                        tool_call_mode = False\n                        buffer = \"\"\n\n        # Build Response\n        message = {\n            \"role\": \"assistant\",\n            \"content\": None if tool_calls else full_content\n        }\n        if tool_calls:\n            message[\"tool_calls\"] = tool_calls\n        response = {\n            \"id\": id,\n            \"object\": \"chat.completion\",\n            \"created\": int(time()),\n            \"model\": Config().model_name,\n            \"choices\": [{\n                \"index\": 
0,\n                \"message\": message,\n                \"finish_reason\": finish_reason or \"stop\"\n            }],\n            \"usage\": usage.__dict__ if 'usage' in locals() else None,\n            \"system_fingerprint\": f\"fp_{uuid4().hex[:12]}\"\n        }\n\n        return response"
  },
  {
    "path": "kt-sft/ktransformers/server/api/openai/legacy/__init__.py",
    "content": "from fastapi import APIRouter\n\nfrom . import completions\n\nrouter = APIRouter()\nrouter.include_router(completions.router)"
  },
  {
    "path": "kt-sft/ktransformers/server/api/openai/legacy/completions.py",
    "content": "import json\nfrom time import time\nfrom uuid import uuid4\nfrom fastapi import APIRouter\nfrom fastapi.requests import Request\nfrom ktransformers.server.utils.create_interface import get_interface\nfrom ktransformers.server.schemas.assistants.streaming import stream_response\nfrom ktransformers.server.schemas.legacy.completions import CompletionCreate,CompletionObject\nfrom ktransformers.server.schemas.endpoints.chat import RawUsage\nfrom fastapi.responses import JSONResponse\nfrom ktransformers.server.config.config import Config\nrouter = APIRouter()\n\n@router.post(\"/completions\",tags=['openai'])\nasync def create_completion(request:Request, create:CompletionCreate):\n    id = str(uuid4())\n    if create.max_tokens is not None and create.max_tokens<0:\n        return JSONResponse(\n            status_code=400,\n            content={\n            \"object\": \"error\",\n            \"message\": f\"max_tokens must be at least 0, got {create.max_tokens}.\",\n            \"type\": \"BadRequestError\",\n            \"param\": None,\n            \"code\": 400\n        })\n    if create.max_completion_tokens is not None and create.max_completion_tokens<0:\n        return JSONResponse(\n            status_code=400,\n            content={\n            \"object\": \"error\",\n            \"message\": f\"max_completion_tokens must be at least 0, got {create.max_completion_tokens}.\",\n            \"type\": \"BadRequestError\",\n            \"param\": None,\n            \"code\": 400\n        })\n    if create.temperature<0 or create.temperature>2:\n        return JSONResponse(\n            status_code=400,\n            content={\n            \"object\": \"error\",\n            \"message\": f\"temperature must be in [0, 2], got {create.temperature}.\",\n            \"type\": \"BadRequestError\",\n            \"param\": None,\n            \"code\": 400\n            })\n    if create.top_p<=0 or create.top_p>1:\n        return JSONResponse(\n            
status_code=400,\n            content={\n            \"object\": \"error\",\n            \"message\": f\"top_p must be in (0, 1], got {create.top_p}.\",\n            \"type\": \"BadRequestError\",\n            \"param\": None,\n            \"code\": 400\n        })\n    interface = get_interface()\n    print(f'COMPLETION INPUT:----\\n{create.prompt}\\n----')\n\n   \n    if create.stream:\n        async def inner():\n            async for res in interface.inference(create.prompt, id, create.temperature, create.top_p, create.max_tokens, create.max_completion_tokens):     \n                if isinstance(res, RawUsage):\n                    raw_usage = res\n                else: \n                    token, finish_reason = res\n                    d = {'choices':[{'delta':{'content':token}}]}\n                    yield f\"data:{json.dumps(d)}\\n\\n\"\n            d = {'choices':[{'delta':{'content':''},'finish_reason':''}]}\n            yield f\"data:{json.dumps(d)}\\n\\n\"\n        return stream_response(request,inner())\n    else:\n        comp = CompletionObject(id=id,object='text_completion',created=int(time()))\n        async for res in interface.inference(create.prompt,id,create.temperature,create.top_p, create.max_tokens, create.max_completion_tokens):     \n            if isinstance(res, RawUsage):\n                raw_usage = res\n            else: \n                token, finish_reason = res\n                comp.append_token(token) \n        return comp\n"
  },
  {
    "path": "kt-sft/ktransformers/server/api/web/__init__.py",
    "content": "from fastapi import APIRouter\nfrom .system import router as system_router\n\n\nrouter = APIRouter()\nrouter.include_router(system_router)\n"
  },
  {
    "path": "kt-sft/ktransformers/server/api/web/system.py",
    "content": "from fastapi import APIRouter\n\n\nrouter = APIRouter()\n\n\n@router.get('/system-info',tags=['web'])\ndef system_info():\n    raise NotImplementedError\n"
  },
  {
    "path": "kt-sft/ktransformers/server/args.py",
    "content": "import argparse\nfrom ktransformers.server.backend.args import ConfigArgs, default_args\nfrom ktransformers.util.utils import get_free_ports\nfrom transformers import AutoConfig\n\nclass ArgumentParser:\n    def __init__(self, cfg):\n        self.cfg = cfg\n\n    def parse_args(self):\n        parser = argparse.ArgumentParser(prog=\"kvcache.ai\", description=\"Ktransformers\")\n        parser.add_argument(\"--host\", type=str, default=self.cfg.server_ip)\n        parser.add_argument(\"--port\", type=int, default=self.cfg.server_port)\n        parser.add_argument(\"--api_key\", type=str, default=self.cfg.api_key)\n        parser.add_argument(\"--ssl_keyfile\", type=str)\n        parser.add_argument(\"--ssl_certfile\", type=str)\n        parser.add_argument(\"--web\", type=bool, default=self.cfg.mount_web)\n        parser.add_argument(\"--model_name\", type=str, default=self.cfg.model_name)\n        parser.add_argument(\"--model_dir\", type=str)\n        parser.add_argument(\"--model_path\", type=str, default=self.cfg.model_path)\n        parser.add_argument(\n            \"--device\", type=str, default=self.cfg.model_device, help=\"Warning: Abandoning this parameter\"\n        )\n        parser.add_argument(\"--architectures\", type=str, default=self.cfg.model_name)\n        parser.add_argument(\"--gguf_path\", type=str, default=self.cfg.gguf_path)\n        parser.add_argument(\"--optimize_config_path\", default=None, type=str, required=False)\n        parser.add_argument(\"--cpu_infer\", type=int, default=self.cfg.cpu_infer)\n        parser.add_argument(\"--backend_type\", type=str, default=self.cfg.backend_type)\n        parser.add_argument(\"--chunk_size\", type=int, default=self.cfg.chunk_size)\n\n        # model configs\n        # parser.add_argument(\"--model_cache_lens\", type=int, default=self.cfg.cache_lens)  # int?\n        parser.add_argument(\"--max_batch_size\", type=int, default=self.cfg.max_batch_size)\n        
parser.add_argument(\"--max_new_tokens\", type=int, default=self.cfg.max_new_tokens)\n        parser.add_argument(\"--json_mode\", type=bool, default=self.cfg.json_mode)\n        parser.add_argument(\"--healing\", type=bool, default=self.cfg.healing)\n        parser.add_argument(\"--ban_strings\", type=list, default=self.cfg.ban_strings, required=False)\n        parser.add_argument(\"--gpu_split\", type=str, default=self.cfg.gpu_split, required=False)\n        parser.add_argument(\"--length\", type=int, default=self.cfg.length, required=False)\n        parser.add_argument(\"--rope_scale\", type=float, default=self.cfg.rope_scale, required=False)\n        parser.add_argument(\"--rope_alpha\", type=float, default=self.cfg.rope_alpha, required=False)\n        parser.add_argument(\"--no_flash_attn\", type=bool, default=self.cfg.no_flash_attn)\n        parser.add_argument(\"--low_mem\", type=bool, default=self.cfg.low_mem)\n        parser.add_argument(\"--experts_per_token\", type=int, default=self.cfg.experts_per_token, required=False)\n        parser.add_argument(\"--load_q4\", type=bool, default=self.cfg.load_q4)\n        parser.add_argument(\"--fast_safetensors\", type=bool, default=self.cfg.fast_safetensors)\n        parser.add_argument(\"--draft_model_dir\", type=str, default=self.cfg.draft_model_dir, required=False)\n        parser.add_argument(\"--no_draft_scale\", type=bool, default=self.cfg.no_draft_scale)\n        parser.add_argument(\"--modes\", type=bool, default=self.cfg.modes)\n        parser.add_argument(\"--mode\", type=str, default=self.cfg.mode)\n        parser.add_argument(\"--username\", type=str, default=self.cfg.username)\n        parser.add_argument(\"--botname\", type=str, default=self.cfg.botname)\n        parser.add_argument(\"--system_prompt\", type=str, default=self.cfg.system_prompt, required=False)\n        parser.add_argument(\"--temperature\", type=float, default=self.cfg.temperature)\n        parser.add_argument(\"--smoothing_factor\", 
type=float, default=self.cfg.smoothing_factor)\n        parser.add_argument(\"--dynamic_temperature\", type=str, default=self.cfg.dynamic_temperature, required=False)\n        parser.add_argument(\"--top_k\", type=int, default=self.cfg.top_k)\n        parser.add_argument(\"--top_p\", type=float, default=self.cfg.top_p)\n        parser.add_argument(\"--top_a\", type=float, default=self.cfg.top_a)\n        parser.add_argument(\"--skew\", type=float, default=self.cfg.skew)\n        parser.add_argument(\"--typical\", type=float, default=self.cfg.typical)\n        parser.add_argument(\"--repetition_penalty\", type=float, default=self.cfg.repetition_penalty)\n        parser.add_argument(\"--frequency_penalty\", type=float, default=self.cfg.frequency_penalty)\n        parser.add_argument(\"--presence_penalty\", type=float, default=self.cfg.presence_penalty)\n        parser.add_argument(\"--response_chunk\", type=int, default=self.cfg.response_chunk)\n        parser.add_argument(\"--no_code_formatting\", type=bool, default=self.cfg.no_code_formatting)\n        parser.add_argument(\"--cache_8bit\", type=bool, default=self.cfg.cache_8bit)\n        parser.add_argument(\"--cache_q4\", type=bool, default=self.cfg.cache_q4)\n        parser.add_argument(\"--ngram_decoding\", type=bool, default=self.cfg.ngram_decoding)\n        parser.add_argument(\"--print_timings\", type=bool, default=self.cfg.print_timings)\n        parser.add_argument(\"--amnesia\", type=bool, default=self.cfg.amnesia)\n        parser.add_argument(\"--batch_size\", type=int, default=self.cfg.batch_size)\n        parser.add_argument(\"--cache_lens\", type=int, default=self.cfg.cache_lens)\n\n        # kvc2 config\n        parser.add_argument(\"--kvc2_config_dir\", type=str, default=self.cfg.kvc2_config_dir)\n\n        # log configs\n        # log level: debug, info, warn, error, crit\n        parser.add_argument(\"--log_dir\", type=str, default=self.cfg.log_dir)\n        parser.add_argument(\"--log_file\", 
type=str, default=self.cfg.log_file)\n        parser.add_argument(\"--log_level\", type=str, default=self.cfg.log_level)\n        parser.add_argument(\"--backup_count\", type=int, default=self.cfg.backup_count)\n\n        # db configs\n        parser.add_argument(\"--db_type\", type=str, default=self.cfg.db_type)\n        parser.add_argument(\"--db_host\", type=str, default=self.cfg.db_host)\n        parser.add_argument(\"--db_port\", type=str, default=self.cfg.db_port)\n        parser.add_argument(\"--db_name\", type=str, default=self.cfg.db_name)\n        parser.add_argument(\"--db_pool_size\", type=int, default=self.cfg.db_pool_size)\n        parser.add_argument(\"--db_database\", type=str, default=self.cfg.db_database)\n\n        # user config\n        parser.add_argument(\"--user_secret_key\", type=str, default=self.cfg.user_secret_key)\n        parser.add_argument(\"--user_algorithm\", type=str, default=self.cfg.user_algorithm)\n        # BooleanOptionalAction takes no value, so `type` is unused (passing it is deprecated since Python 3.12)\n        parser.add_argument(\"--force_think\", action=argparse.BooleanOptionalAction, default=self.cfg.user_force_think)\n        parser.add_argument(\"--use_cuda_graph\", action=argparse.BooleanOptionalAction, default=self.cfg.use_cuda_graph)\n\n        # web config\n        parser.add_argument(\"--web_cross_domain\", type=bool, default=self.cfg.web_cross_domain)\n\n        # file config\n        parser.add_argument(\"--file_upload_dir\", type=str, default=self.cfg.file_upload_dir)\n        parser.add_argument(\"--assistant_store_dir\", type=str, default=self.cfg.assistant_store_dir)\n        # local chat\n        parser.add_argument(\"--prompt_file\", type=str, default=self.cfg.prompt_file)\n\n\n        # async server\n        parser.add_argument(\"--sched_strategy\", type=str, default=self.cfg.sched_strategy)\n        # parser.add_argument(\"--sched_port\", type=int, default=self.cfg.sched_port)\n        # parser.add_argument(\"--sched_metrics_port\", type=int, default=self.cfg.sched_metrics_port)\n        # 
parser.add_argument(\"--kvc2_metrics_port\", type=int, default=self.cfg.kvc2_metrics_port)\n        parser.add_argument(\"--page_size\", type=str, default=self.cfg.page_size)\n        parser.add_argument(\"--memory_gpu_only\", type=str, default=self.cfg.memory_gpu_only)\n        parser.add_argument(\"--utilization_percentage\", type=str, default=self.cfg.utilization_percentage)\n        parser.add_argument(\"--cpu_memory_size_GB\", type=str, default=self.cfg.cpu_memory_size_GB)\n\n\n        args = parser.parse_args()\n        if (args.model_dir is not None or args.model_path is not None):\n            if (args.model_path is not None):\n                # if pass model_dir and model_path, we use model_path\n                args.model_dir = args.model_path\n            else:\n                # if only pass model_dir, we use model_dir\n                args.model_path = args.model_dir\n        else:\n            args.model_dir = self.cfg.model_dir\n            args.model_path = self.cfg.model_path\n        \n        # we add the name not match args individually\n        self.cfg.model_device = args.device\n        self.cfg.mount_web = args.web\n        self.cfg.server_ip = args.host\n        self.cfg.server_port = args.port\n        self.cfg.user_force_think = args.force_think\n        \n        model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)\n        if model_config.architectures[0] == \"Qwen3MoeForCausalLM\" or model_config.architectures[0] == \"Qwen2MoeForCausalLM\" :\n            args.gpu_memory_size = args.cache_lens*2*2*model_config.num_hidden_layers*model_config.num_key_value_heads*model_config.head_dim\n            args.architectures = model_config.architectures[0]\n        else:\n            args.gpu_memory_size = args.cache_lens*2*576*61\n        # set config from args\n        for key, value in vars(args).items():\n            if value is not None and hasattr(self.cfg, key):\n                setattr(self.cfg, key, value)\n    
    self.cfg.gpu_memory_size = args.gpu_memory_size\n        free_ports = get_free_ports(3, [args.port])\n        args.sched_port = free_ports[0]\n        args.sched_metrics_port = free_ports[1]\n        args.kvc2_metrics_port = free_ports[2]\n        self.cfg.sched_port = free_ports[0]\n        self.cfg.sched_metrics_port = free_ports[1]\n        self.cfg.kvc2_metrics_port = free_ports[2]\n        return args\n"
  },
  {
    "path": "kt-sft/ktransformers/server/backend/__init__.py",
    "content": ""
  },
  {
    "path": "kt-sft/ktransformers/server/backend/args.py",
    "content": "from pydantic import BaseModel, Field\nfrom typing import Optional\nfrom ktransformers.server.config.config import Config\n\n\nclass ConfigArgs(BaseModel):\n    model_name: Optional[str] = Field(..., description=\"Model name\")\n    model_dir: Optional[str] = Field(..., description=\"Path to model directory\")\n    optimize_config_path: Optional[str] = Field(None, description=\"Path of your optimize config yml file\")\n    gguf_path: Optional[str] = Field(None, description=\"Path of your gguf file\")\n\n    class Config:\n        protected_namespaces = ()\n\n    max_batch_size: int = Field(\n        None, description=\"Max number of batches to run at once, assuming the sequences will fit within total_context\"\n    )\n    chunk_size: int = Field(\n        None,\n        description=(\n            \"Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new\"\n            \" job is started, but at the expense of overall prompt ingestion speed\"\n        ),\n    )\n    max_new_tokens: int = Field(None, description=\"Max new tokens per completion. For this example applies to all jobs\")\n    json_mode: bool = Field(\n        None, description=\"Use LMFE to constrain the output to JSON format. 
See schema and details below\"\n    )\n    healing: bool = Field(None, description=\"Demonstrate token healing\")\n    ban_strings: Optional[list] = Field(None, description=\"Ban some phrases maybe\")\n    gpu_split: Optional[str] = Field(None, description='\"auto\", or VRAM allocation per GPU in GB')\n    length: Optional[int] = Field(None, description=\"Maximum sequence length\")\n    rope_scale: Optional[float] = Field(None, description=\"RoPE scaling factor\")\n    rope_alpha: Optional[float] = Field(None, description=\"RoPE alpha value (NTK)\")\n    no_flash_attn: bool = Field(None, description=\"Disable Flash Attention\")\n    low_mem: bool = Field(None, description=\"Enable VRAM optimizations, potentially trading off speed\")\n    experts_per_token: Optional[int] = Field(\n        None, description=\"Override MoE model's default number of experts per token\"\n    )\n    load_q4: bool = Field(None, description=\"Load weights in Q4 mode\")\n    fast_safetensors: bool = Field(None, description=\"Optimized safetensors loading with direct I/O (experimental!)\")\n    draft_model_dir: Optional[str] = Field(None, description=\"Path to draft model directory\")\n    no_draft_scale: bool = Field(\n        None,\n        description=\"If draft model has smaller context size than model, don't apply alpha (NTK) scaling to extend it\",\n    )\n    modes: bool = Field(None, description=\"List available modes and exit.\")\n    mode: str = Field(None, description=\"Chat mode. 
Use llama for Llama 1/2 chat finetunes.\")\n    username: str = Field(None, description=\"Username when using raw chat mode\")\n    botname: str = Field(None, description=\"Bot name when using raw chat mode\")\n    system_prompt: Optional[str] = Field(None, description=\"Use custom system prompt\")\n    temperature: float = Field(None, description=\"Sampler temperature, default = 0.95 (1 to disable)\")\n    smoothing_factor: float = Field(None, description=\"Smoothing Factor, default = 0.0 (0 to disable)\")\n    dynamic_temperature: Optional[str] = Field(\n        None, description=\"Dynamic temperature min,max,exponent, e.g. -dyntemp 0.2,1.5,1\"\n    )\n    top_k: int = Field(None, description=\"Sampler top-K, default = 50 (0 to disable)\")\n    top_p: float = Field(None, description=\"Sampler top-P, default = 0.8 (0 to disable)\")\n    top_a: float = Field(None, description=\"Sampler top-A, default = 0.0 (0 to disable)\")\n    skew: float = Field(None, description=\"Skew sampling, default = 0.0 (0 to disable)\")\n    typical: float = Field(None, description=\"Sampler typical threshold, default = 0.0 (0 to disable)\")\n    repetition_penalty: float = Field(None, description=\"Sampler repetition penalty, default = 1.01 (1 to disable)\")\n    frequency_penalty: float = Field(None, description=\"Sampler frequency penalty, default = 0.0 (0 to disable)\")\n    presence_penalty: float = Field(None, description=\"Sampler presence penalty, default = 0.0 (0 to disable)\")\n    response_chunk: int = Field(None, description=\"Space to reserve in context for reply, default = 250\")\n    no_code_formatting: bool = Field(None, description=\"Disable code formatting/syntax highlighting\")\n    cache_8bit: bool = Field(None, description=\"Use 8-bit (FP8) cache\")\n    cache_q4: bool = Field(None, description=\"Use Q4 cache\")\n    ngram_decoding: bool = Field(None, description=\"Use n-gram speculative decoding\")\n    print_timings: bool = Field(None, description=\"Output timings 
after each prompt\")\n    amnesia: bool = Field(None, description=\"Forget context after every response\")\n\n    # for transformers\n    batch_size: int = Field(None, description=\"Batch Size\")\n    cache_lens: int = Field(None, description=\"Cache lens for transformers static cache\")\n    device: str = Field(None, description=\"device\")\n\n\ncfg = Config()\ndefault_args = cfg\n"
  },
  {
    "path": "kt-sft/ktransformers/server/backend/base.py",
    "content": "from asyncio import Queue\nfrom enum import Enum\nimport sys, os\nfrom typing import AsyncIterator, Dict, List, Optional, Tuple\n\nimport torch\n\nfrom ktransformers.server.config.log import logger\nfrom ktransformers.server.crud.assistants.assistants import AssistantDatabaseManager\nfrom ktransformers.server.crud.assistants.messages import MessageDatabaseManager\nfrom ktransformers.server.crud.assistants.runs import RunsDatabaseManager\nfrom ktransformers.server.crud.assistants.threads import ThreadsDatabaseManager\nfrom ktransformers.server.exceptions import request_error\nfrom ktransformers.server.schemas.assistants.assistants import AssistantObject\nfrom ktransformers.server.schemas.assistants.messages import MessageCreate, MessageObject, Role\nfrom ktransformers.server.schemas.assistants.runs import RunObject\nfrom ktransformers.server.schemas.assistants.threads import ThreadObject\nfrom ktransformers.server.schemas.endpoints.chat import RawUsage\nfrom ktransformers.server.schemas.base import ObjectID, Order\nfrom ktransformers.server.utils.multi_timer import Profiler\n\n\nfrom .args import ConfigArgs,default_args\n\n\n\nclass BackendInterfaceBase:\n    '''\n    Interface to inference frameworks. e.g. 
transformers, exllama.\n    Implement __init__ and inference  \n    '''\n\n    args: ConfigArgs\n    profiler:Profiler = Profiler()\n\n    def __init__(self, args:ConfigArgs = default_args):\n        raise NotImplementedError\n\n    \n    async def inference(self,local_messages,request_unique_id:Optional[str])->AsyncIterator[str]:\n        '''\n        inference can be called directly, or by ThreadContext\n\n        local_messages: \n            When called by ThreadContext, local_messages are generated by ThreadContext.get_local_messages().\n            Implementations must handle different local_messages\n        request_unique_id:\n            unique id of the request, useful when using a cache\n        \n        return:\n            async str output for streaming updates\n\n        '''\n        raise NotImplementedError\n\n\n    def report_last_time_performance(self):\n        try:\n            tokenize_time = self.profiler.get_timer_sec('tokenize')\n            prefill_time = self.profiler.get_timer_sec('prefill')\n            decode_time = self.profiler.get_timer_sec('decode')\n            prefill_count = self.profiler.get_counter('prefill')\n            decode_count = self.profiler.get_counter('decode')\n\n            logger.info(f'Performance(T/s): prefill {prefill_count/prefill_time}, decode {decode_count/decode_time}. 
Time(s): tokenize {tokenize_time}, prefill {prefill_time}, decode {decode_time}')\n        except Exception:\n            logger.info('Performance statistics not recorded')\n\n\nclass ThreadContext:\n    '''\n    A thread context holding assistant logic \n    \n    '''\n\n    args: ConfigArgs\n    # Assistant Logic\n    assistant: Optional[AssistantObject] = None\n    related_threads : List[ThreadObject]\n    thread: ThreadObject\n    messages: List[MessageObject] = [] \n    run: RunObject\n\n    interface: Optional[BackendInterfaceBase] = None\n     \n    queue: Optional[Queue] = None\n    timer: Profiler = Profiler()\n\n    def __init__(self, run: RunObject,interface:BackendInterfaceBase, args: ConfigArgs = default_args) -> None:\n        self.args = args\n        self.thread_manager = ThreadsDatabaseManager()\n        self.message_manager = MessageDatabaseManager()\n        self.runs_manager = RunsDatabaseManager()\n        self.assistant_manager = AssistantDatabaseManager()\n        self.thread = self.thread_manager.db_get_thread_by_id(run.thread_id)\n        self.assistant = self.assistant_manager.db_get_assistant_by_id(run.assistant_id)\n        self.messages = self.message_manager.db_list_messages_of_thread(run.thread_id,order=Order.ASC)\n        logger.debug(f\"{len(self.messages)} messages loaded from database\")\n        self.interface = interface\n        self.update_by_run(run,args)\n\n    def get_local_messages(self):\n        '''\n        Get local messages, used as the input to interface.inference\n        This function is intended for message preprocessing, e.g. 
applying the chat template\n        '''\n        raise NotImplementedError\n\n    def update_by_run(self,run:RunObject,args:ConfigArgs = default_args):\n        self.run = run \n        self.args = args\n       \n    def put_user_message(self, message: MessageObject):\n        assert (\n            message.role.is_user() and message.thread_id == self.thread.id and message.status == MessageObject.Status.in_progress\n        )\n        self.messages.append(message)\n\n    def delete_user_message(self,message_id: ObjectID):\n        self.messages = [m for m in self.messages if m.id != message_id]\n\n    async def work(self)->AsyncIterator:\n        logger.debug('start working')\n        user_message = self.messages[-1]\n        if not user_message.role.is_user():\n            raise request_error('user must talk before LLM can talk')\n        user_message.status = MessageObject.Status.completed\n        user_message.sync_db()\n\n        local_messages = self.get_local_messages() # must get this before we insert reply_message\n\n\n        response_str_count = 0  \n        reply_message = self.message_manager.create_message_object(\n                            self.thread.id,\n                            self.run.id,\n                            MessageCreate(role=Role.assistant, content=\"\"),    \n                        )\n        reply_message.assistant_id = self.assistant.id\n        self.messages.append(reply_message) \n\n        yield reply_message.stream_response_with_event(MessageObject.Status.created)\n        yield reply_message.stream_response_with_event(MessageObject.Status.in_progress)\n        yield self.run.stream_response_with_event(RunObject.Status.in_progress)\n\n        async for res in self.interface.inference(local_messages,self.thread.id): \n            if isinstance(res, RawUsage):\n                raw_usage = res\n            else: \n                token, finish_reason = res    \n                if self.run.status == RunObject.Status.cancelling:\n  
                  logger.warning(f'Run {self.run.id} cancelling')\n                    break\n                yield reply_message.append_message_delta(token)\n                response_str_count+=1\n        \n        if self.run.status == RunObject.Status.cancelling:\n            yield self.run.stream_response_with_event(RunObject.Status.cancelled)\n            yield reply_message.stream_response_with_event(MessageObject.Status.incomplete)\n        elif self.run.status == RunObject.Status.in_progress:\n            yield self.run.stream_response_with_event(RunObject.Status.completed)\n            yield reply_message.stream_response_with_event(MessageObject.Status.completed)\n        else:\n            raise NotImplementedError(f'{self.run.status} should not appear here')\n\n        reply_message.sync_db()\n        self.run.sync_db()"
  },
  {
    "path": "kt-sft/ktransformers/server/backend/context_manager.py",
    "content": "from asyncio import Lock\nfrom typing import Dict, Optional\n\nfrom ktransformers.server.backend.base import ThreadContext, BackendInterfaceBase\nfrom ktransformers.server.schemas.assistants.runs import RunObject\nfrom ktransformers.server.schemas.base import ObjectID\nfrom ktransformers.server.config.log import logger\nfrom ktransformers.server.backend.interfaces.transformers import TransformersThreadContext\nfrom ktransformers.server.backend.interfaces.ktransformers import KTransformersThreadContext\nfrom ktransformers.server.backend.interfaces.exllamav2 import ExllamaThreadContext\n\n\nfrom ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface\nfrom ktransformers.server.backend.interfaces.transformers import TransformersInterface\nfrom ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface\n\nclass ThreadContextManager:\n    lock: Lock\n    threads_context: Dict[ObjectID, ThreadContext]\n    interface: BackendInterfaceBase\n    \n    def __init__(self,interface) -> None:\n        logger.debug(f\"Creating Context Manager\")\n        self.lock = Lock()\n        self.threads_context = {}\n        self.interface = interface\n        pass\n\n    async def get_context_by_run_object(self, run: RunObject) -> ThreadContext:\n        async with self.lock:\n            logger.debug(f\"keys {self.threads_context.keys()}\")\n            if run.thread_id not in self.threads_context:\n                logger.debug(f\"new inference context {run.thread_id}\")\n                if isinstance(self.interface, ExllamaInterface):\n                    new_context = ExllamaThreadContext(run, self.interface)\n                elif isinstance(self.interface, KTransformersInterface):\n                    new_context = KTransformersThreadContext(run, self.interface)\n                elif isinstance(self.interface, TransformersInterface):\n                    new_context = TransformersThreadContext(run, self.interface)\n          
      else:\n                    from ktransformers.server.backend.interfaces.balance_serve import BalanceServeThreadContext\n                    from ktransformers.server.backend.interfaces.balance_serve import BalanceServeInterface\n                    if isinstance(self.interface, BalanceServeInterface):\n                        new_context = BalanceServeThreadContext(run, self.interface)\n                    else:\n                        raise NotImplementedError\n                # elif isinstance(self.interface, BalanceServeInterface):\n                #     new_context = BalanceServeThreadContext(run, self.interface)\n                # else:\n                #     raise NotImplementedError\n                self.threads_context[run.thread_id] = new_context\n                # self.threads_context[run.thread_id] = ExllamaInferenceContext(run)\n            re = self.threads_context[run.thread_id]\n            re.update_by_run(run)\n            return re\n\n    async def get_context_by_thread_id(self, thread_id: ObjectID) -> Optional[ThreadContext]:\n        async with self.lock:\n            if thread_id in self.threads_context:\n                logger.debug(f'found context for thread {thread_id}')\n                return self.threads_context[thread_id]\n            else:\n                logger.debug(f'no context for thread {thread_id}')\n                return None\n            "
  },
  {
    "path": "kt-sft/ktransformers/server/backend/interfaces/__init__.py",
    "content": ""
  },
  {
    "path": "kt-sft/ktransformers/server/backend/interfaces/balance_serve.py",
    "content": "from typing import Any, AsyncIterator, List, Optional, Set\nfrom ktransformers.models.custom_cache import KDeepSeekV3Cache, KGQACache\nfrom transformers import (\n    AutoTokenizer,\n    AutoConfig,\n    GenerationConfig,\n    StaticCache,\n    AutoModelForCausalLM,\n    BitsAndBytesConfig,\n)\n\nfrom ktransformers.server.config.config import Config\nfrom ..base import ThreadContext, BackendInterfaceBase\nimport torch\nfrom ktransformers.server.backend.interfaces.transformers import (\n    ConfigArgs,\n    default_args,\n    TextStreamer,\n)\nfrom ktransformers.server.schemas.base import ObjectID\nfrom ktransformers.server.config.log import logger\nfrom ktransformers.optimize.optimize import optimize_and_load_gguf\nfrom ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM\nfrom ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM\nfrom ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM\nfrom ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM\nfrom ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig\nfrom ktransformers.server.balance_serve.inference.model_runner import ModelRunner \nfrom ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions\nfrom ktransformers.server.balance_serve.inference.query_manager import QueryManager\nfrom ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput\nfrom ktransformers.server.balance_serve.sched_rpc import SchedulerClient\nfrom ktransformers.server.balance_serve.settings import sched_ext\nfrom torch.multiprocessing import Queue\nimport torch.multiprocessing as mp\nfrom multiprocessing.synchronize import Event\nfrom ktransformers.server.schemas.endpoints.chat import RawUsage\nfrom ktransformers.server.utils.multi_timer import Profiler\nimport zmq\nimport time\nimport queue\nimport tempfile\nimport asyncio\nimport 
threading\nfrom contextlib import asynccontextmanager\nfrom fastapi import FastAPI, Request\nimport os\nimport pickle\nimport subprocess\nimport tempfile\nimport atexit\nimport signal\n\n\nktransformer_rules_dir = (\n    os.path.join(os.path.dirname(os.path.abspath(__file__)), \"..\", \"..\", \"..\", \"./optimize/optimize_rules/\") \n)\ndefault_optimize_rules = {\n    # \"DeepseekV3ForCausalLM\": ktransformer_rules_dir + \"Moonlight-16B-A3B-serve.yaml\",\n    \"DeepseekV3ForCausalLM\": ktransformer_rules_dir + \"DeepSeek-V3-Chat-serve.yaml\",\n    \"Qwen2MoeForCausalLM\": ktransformer_rules_dir + \"Qwen2-serve.yaml\",\n    \"Qwen3MoeForCausalLM\": ktransformer_rules_dir + \"Qwen3Moe-serve.yaml\",\n}\n\n\nasync def chat_stream(queue: asyncio.Queue, tokenizer: AutoTokenizer):\n    streamer = TextStreamer(tokenizer)\n    while True:\n        token = await queue.get()\n        #print(f\"Got token: {token}\")\n        if token is None:\n            # str = f'{token}\\n\\n'\n            # str = model.tokenizer.decode(token)\n            s = streamer.end()\n            if s is not None:\n                yield s\n            break\n\n        # str = model.tokenizer.decode(token)\n        yield streamer.put(token)\n        \n\n\ndef fill_generated_tokens(query_updates: list[sched_ext.QueryUpdate], generated_tokens: torch.Tensor, query_manager: QueryManager = None):\n    #print(len(query_updates), generated_tokens.size(0), generated_tokens)\n    for i in range(generated_tokens.size(0)):\n        print(generated_tokens[i].item())\n        query_updates[i].generated_token = generated_tokens[i].item()\n        if not query_manager.query_map[query_updates[i].id].is_prefill:\n            pos = query_updates[i].active_position\n            if pos < query_manager.query_map[query_updates[i].id].max_length:\n                query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i]\n\ndef report_last_time_performance(profiler: Profiler):\n        try:\n      
      tokenize_time = profiler.get_timer_sec('tokenize')\n            prefill_time = profiler.get_timer_sec('prefill')\n            decode_time = profiler.get_timer_sec('decode')\n            prefill_count = profiler.get_counter('prefill')\n            decode_count = profiler.get_counter('decode')\n\n            logger.info(f'Performance(T/s): prefill {prefill_count/prefill_time}, decode {decode_count/decode_time}. Time(s): tokenize {tokenize_time}, prefill {prefill_time}, decode {decode_time}')\n        except:\n            logger.info(f'Performance statistics not recorded')\n\nclass Engine:\n    sched_client : SchedulerClient\n    updates : list[sched_ext.QueryUpdate]\n    batch : sched_ext.BatchQueryTodo\n    model_runner: ModelRunner\n    sampler: Sampler\n    query_manager: QueryManager\n    cache: KDeepSeekV3Cache | KGQACache\n    def __init__(self, args: ConfigArgs = default_args, generated_token_queue:Queue = None, broadcast_endpoint: str = None, kvcache_event: Event = None):\n        self.args = args\n\n        for key, value in vars(args).items():\n            if value is not None and hasattr(Config(), key):\n                setattr(Config(), key, value)\n\n        self.device = self.args.device\n        self.sched_client = SchedulerClient(args.sched_port)\n        self.updates = []\n\n        try: \n            config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True) \n        except:\n            if args.model_name == \"Qwen3Moe\": \n                config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)\n            else:\n                assert False, f\"model {args.model_name} not supported\" \n\n            \n        self.gen_queue = generated_token_queue\n            \n        with torch.device(\"meta\"):\n            if config.architectures[0] == \"DeepseekV3ForCausalLM\":\n                self.cache = KDeepSeekV3Cache(config, self.args.page_size)\n                self.model = 
KDeepseekV3ForCausalLM(config, self.cache)\n            elif config.architectures[0] == \"DeepseekV2ForCausalLM\":\n                self.cache = KDeepSeekV3Cache(config, self.args.page_size)\n                self.model = KDeepseekV2ForCausalLM(config, self.cache)\n            elif config.architectures[0] == \"Qwen2MoeForCausalLM\" or config.architectures[0] == \"Qwen3MoeForCausalLM\":\n                self.cache = KGQACache(config, self.args.page_size)\n                if config.architectures[0] == \"Qwen2MoeForCausalLM\":\n                    self.model = KQwen2MoeForCausalLM(config, self.cache)\n                else:\n                    self.model = KQwen3MoeForCausalLM(config, self.cache)\n\n\n        context = zmq.Context()\n\n            \n        self.pub_socket = context.socket(zmq.PUB)\n        self.pub_socket.bind(f\"ipc://{broadcast_endpoint}\") \n        # time.sleep(1) # make sure all subscribers are ready\n\n\n        try:\n            generation_config = GenerationConfig.from_pretrained(args.model_dir)\n        except:\n            generation_config = GenerationConfig(\n                max_length=args.max_new_tokens,\n                temperature=args.temperature,\n                top_p=args.top_p,\n                do_sample=True\n            )\n            \n        if args.optimize_config_path is None:\n            optimize_config_path = default_optimize_rules[config.architectures[0]]\n               \n        else:\n            optimize_config_path = args.optimize_config_path\n        gguf_path = args.gguf_path\n        if gguf_path is None:\n            gguf_path = input(\n                \"please input the path of your gguf file(gguf file in the dir containing input gguf file must all\"\n                \" belong to current model):\"\n            )\n        optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)\n        self.model.generation_config = generation_config\n        if self.model.generation_config.pad_token_id is 
None:\n            self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id\n\n        self.model.eval()\n        kvcache_event.set()\n        # load kvcache\n        print(f\"Getting inference context from sched_client.\")\n        inference_context = self.sched_client.get_inference_context_raw()\n        print(f\"Got inference context, sending it to subscribers.\")\n        inference_context = self.sched_client.rebuild_inferece_context(inference_context)\n        self.cache.load(inference_context)\n        print(f\"kv_cache loaded successfully.\")\n        \n\n        self.block_num = inference_context.k_cache[0].size(1)\n        self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size, block_num=self.block_num)\n        #@TODO add config\n        if config.architectures[0] == \"Qwen2MoeForCausalLM\" or config.architectures[0] == \"Qwen3MoeForCausalLM\":\n            self.model.init_wrapper(self.args.use_cuda_graph, self.device, max(self.model_runner.cuda_graphs), args.max_batch_size, self.block_num) \n        else:\n            self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)\n\n        self.sampler = Sampler()\n        self.query_manager = QueryManager(device = self.device, page_size = args.page_size)\n\n            \n    def sampling(self, forward_output: ForwardBatchOutput):\n        generated_tokens = torch.empty(0, device=self.device, dtype=torch.int32)\n        for i in range(forward_output.num_batchs):\n            logit = forward_output.logits[i]\n            if hasattr(forward_output, \"temperatures\"):\n                temperatures = forward_output.temperatures[i]\n            else:\n                temperatures = None\n            \n            if hasattr(forward_output, \"top_ps\"):\n                top_ps = forward_output.top_ps[i]\n            else:\n                top_ps = None\n\n            sample_options = 
SamplingOptions(logit.size(0), self.device, pretrained_config=self.model.generation_config, temperatures=temperatures, top_ps=top_ps)\n            generated_tokens, probs=self.sampler(logit, sample_options)\n        return generated_tokens, probs\n    \n    def loop(self):\n\n        next_batch = None   \n\n        while True:\n            self.batch = next_batch\n            if self.batch is not None:\n                self.model_runner.run(self.batch, self.query_manager)\n\n            if len(self.updates) > 0:\n                for q in self.updates:\n                    if q.is_prefill == True:\n                        continue\n                    # print(f\"Putting token {q.generated_token} into queue for query id: {q.id}\")\n                    try:\n                        self.gen_queue.put((q.id, q.generated_token if q.decode_done == False else None), timeout=5)\n                    except queue.Full:\n                        pass#print(\"Queue is full after timeout; unable to put more items.\")\n                \n            next_batch = self.sched_client.update_last_batch(self.updates)\n            if next_batch.query_ids == []:\n                next_batch = None\n            self.pub_socket.send_pyobj(next_batch)  \n\n            if next_batch is not None:\n                self.query_manager.add_query(next_batch)\n            \n            \n            if self.batch is not None:\n                self.model_runner.sync()\n                print(f\"Model execution time (GPU): {self.model_runner.model_time:.3f} ms, {1000/self.model_runner.model_time:.3f} tokens/s\")\n                # if self.rank == 0:\n                \n                generated_tokens, probs = self.sampling( self.model_runner.output)\n                \n                self.updates = self.query_manager.update(self.batch)\n                fill_generated_tokens(self.updates, generated_tokens, self.query_manager)\n            else:\n                self.updates = []\n\nclass 
BalanceServeThreadContext(ThreadContext):\n    def get_local_messages(self):\n        local_messages = []\n        for m in self.messages:\n            local_messages.append({\"role\": m.role.value, \"content\": m.get_text_content()})\n\n        return local_messages\n    \n\ndef run_engine(args, token_queue, broadcast_endpoint, event, kvcache_event):\n    engine = Engine(args, token_queue, broadcast_endpoint, kvcache_event)\n    if args.use_cuda_graph:\n        engine.model_runner.warmup()\n        \n    event.set()\n    engine.loop()\n\n\nclass BalanceServeInterface(BackendInterfaceBase):\n    use_static_cache: bool = True\n\n    model: Any\n    tokenizer: AutoTokenizer\n\n    cache: StaticCache\n    generated_ids: torch.Tensor\n    seq_length: int\n\n    streamer: TextStreamer\n\n    # thread_related\n    last_request_id: Optional[str] = None\n    ever_generated_ids: Set[int] = set()\n\n    def __init__(self, args: ConfigArgs = default_args):\n        self.args = args\n        self.queue_map:dict[int,asyncio.Queue] = {}\n        self.thread_map: dict[int, int] = {}\n        processes = []\n        self.broadcast_endpoint = tempfile.NamedTemporaryFile(delete=False).name # @TODO add to config\n        ctx = mp.get_context(\"spawn\")\n        self.token_queue = ctx.Queue(maxsize=1000) \n        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)\n        self.sched_client = SchedulerClient(args.sched_port)\n        self.streamer = TextStreamer(self.tokenizer)\n\n        start_event = ctx.Event()\n        kvcache_event = ctx.Event()\n\n        p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event, kvcache_event))\n        p.start()\n        processes.append(p)\n        kvcache_event.wait()\n\n\n        with tempfile.NamedTemporaryFile(delete=False) as temp_file:\n            pickle.dump(args, temp_file)\n            temp_file_path = temp_file.name\n        current_file = 
__file__\n        target_file = os.path.join(os.path.dirname(current_file), \"..\", \"..\", \"balance_serve\", \"sched_rpc.py\")\n        target_file = os.path.normpath(target_file)\n        log_path = os.path.join(args.log_dir, \"rpc.log\")\n        log = open(log_path, \"a\") \n        sched_process = subprocess.Popen(\n            [\"python3\", target_file, \"--config\", temp_file_path], \n            stdout=log, \n            stderr=log\n        )\n        print(\"sched_rpc started with PID:\", sched_process.pid)\n\n        def signal_handler(signum, frame):\n            print(f\"Received signal {signum}, shutting down...\")\n            cleanup()\n            os._exit(0) \n\n        def cleanup():\n            print(\"Cleaning up...\")\n\n            for p in processes:\n                if p.is_alive():\n                    print(f\"Terminating subprocess {p.pid}\")\n                    p.terminate()\n                    p.join()\n\n            if sched_process and sched_process.poll() is None:\n                print(f\"Terminating sched_process {sched_process.pid}\")\n                sched_process.terminate()\n                sched_process.wait()\n        signal.signal(signal.SIGINT, signal_handler)   \n        signal.signal(signal.SIGTERM, signal_handler)\n\n        start_event.wait()\n    \n    def get_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None, \n                   max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None) -> tuple[float, float, float]:\n        \"\"\"Return (temperature, top_p, max_completion_tokens), applying defaults and clamping edge cases\"\"\"\n        if max_tokens is not None:\n            max_completion_tokens = max_tokens\n        if max_completion_tokens is None:\n            max_completion_tokens = self.args.max_new_tokens\n        else:\n            max_completion_tokens = min(self.args.max_new_tokens, max_completion_tokens)\n        if temperature is None:\n            temperature = 
self.args.temperature\n        if top_p is None:\n            top_p = self.args.top_p\n            \n        if temperature == 0:\n            temperature = 0.0001\n        if top_p == 0:\n            top_p = 0.0001\n            \n        return temperature, top_p, max_completion_tokens\n\n    def run_queue_proxy(self):\n        loop = asyncio.new_event_loop()\n        asyncio.set_event_loop(loop)\n        loop.run_until_complete(self.queue_proxy())\n\n    @asynccontextmanager\n    async def lifespan(self, app: FastAPI):\n        asyncio.create_task(self.queue_proxy())\n        yield\n\n    async def queue_proxy(self):\n        print(\"Queue Proxy Started\")\n        while True:\n            try:\n                query_id, token = self.token_queue.get_nowait()\n                try:\n                    # query id might not be allocated yet\n                    self.queue_map[query_id].put_nowait(token)\n                    #print(f\"Proxy Put token: {token} to queue for query id: {query_id}\")\n                except asyncio.QueueFull:\n                    #print(f\"Queue for query id: {query_id} is full, waiting to put: {token}\")\n                    await self.queue_map[query_id].put(token)\n\n            except queue.Empty:\n                # print(\"no new token\")\n                # await asyncio.sleep(1)\n                await asyncio.sleep(0)\n    def tokenize_prompt(self, prompt: str):\n        input_ids = self.tokenizer.encode(prompt, return_tensors=\"pt\").to(self.args.device)\n        return input_ids\n\n    def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List):\n        input_str: str = self.tokenizer.apply_chat_template(messages,tokenize=False,add_generation_prompt=True)\n        # drop <think> token in chat template\n        if input_str.endswith('<think>\\n'):\n            input_str = input_str[:-len('<think>\\n')]\n        input_ids = self.tokenizer.encode(input_str, return_tensors=\"pt\", 
add_special_tokens=False).to(self.args.device)\n        logger.debug(f\"get input ids of shape {input_ids.shape}\")\n        return input_ids\n    \n    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, \n                        max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):\n        profiler = Profiler()\n        profiler.create_and_start_timer(\"tokenize\")\n        \n        if isinstance(local_messages, List):\n            input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)\n        elif isinstance(local_messages, str):\n            input_ids = self.tokenize_prompt(local_messages)\n        else:\n            raise ValueError(\"local_messages should be List or str\")\n        if Config().user_force_think:\n            token_thinks = torch.tensor([self.tokenizer.encode(\"<think>\\n\",add_special_tokens=False)],device=input_ids.device)\n            input_ids = torch.cat(\n                [input_ids, token_thinks], dim=1\n            )\n\n        profiler.pause_timer(\"tokenize\")\n\n        profiler.create_and_start_timer(\"prefill\")\n        \n        query_add = sched_ext.QueryAdd()\n        query_add.query_token =  input_ids[0].tolist()\n        query_length = input_ids[0].shape[0]\n        query_add.query_length = query_length\n        profiler.set_counter(\"prefill\", query_length)\n        #@TODO add server\n        stop_criteria =  [self.tokenizer.encode(self.tokenizer.eos_token, add_special_tokens=False),self.tokenizer.encode(\"<|im_end|>\")]\n        query_add.stop_criteria = stop_criteria\n        \n        temperature, top_p, max_new_tokens = self.get_params(temperature, top_p, max_tokens, max_completion_tokens)\n            \n        query_add.sample_options.temperature = temperature\n        query_add.sample_options.top_p = top_p\n        query_add.estimated_length = min(self.args.cache_lens, 
query_length+max_new_tokens)\n\n        if query_add.estimated_length < query_add.query_length:\n            raise Exception(f'query too long: estimated_length={query_add.estimated_length} < query_length={query_add.query_length}')\n\n        query_id = self.sched_client.add_query(query_add)\n        queue = asyncio.Queue(maxsize=max_new_tokens)\n        self.queue_map[query_id] = queue\n        self.thread_map[thread_id] = query_id\n        is_first_token = True\n        async for token in chat_stream(self.queue_map[query_id], self.tokenizer):\n            if is_first_token:\n                is_first_token=False\n                profiler.pause_timer(\"prefill\")\n                profiler.create_and_start_timer(\"decode\")\n                profiler.set_counter(\"decode\", 0)\n                if Config().user_force_think:\n                    think = '<think>\\n'\n                    print(think, end=\"\",flush=True)\n                    yield think, None\n            else:\n                profiler.inc(\"decode\")\n            yield token, None\n        profiler.pause_timer(\"decode\")\n        report_last_time_performance(profiler)\n        yield self.streamer.end(), None\n        if profiler.get_counter('decode') >= max_new_tokens - 1:\n            yield \"\", \"length\"\n        else:\n            yield \"\", \"stop\"\n        \n        \n        yield RawUsage(\n                tokenize_time = profiler.get_timer_sec('tokenize'),\n                prefill_time = profiler.get_timer_sec('prefill'),\n                decode_time = profiler.get_timer_sec('decode'),\n                prefill_count = profiler.get_counter('prefill'),\n                decode_count = profiler.get_counter('decode'),\n            )\n"
  },
  {
    "path": "kt-sft/ktransformers/server/backend/interfaces/exllamav2.py",
    "content": "import sys, os\nfrom typing import AsyncIterator, Dict, Tuple\n\nimport torch\n\nfrom ..args import ConfigArgs, default_args\n\nfrom ..base import BackendInterfaceBase, ThreadContext\nfrom ktransformers.server.schemas.assistants.runs import RunObject\n\n\nfrom ..args import *\n\nclass ExllamaThreadContext(ThreadContext):\n    def __init__(self, run: RunObject, args: ConfigArgs = default_args) -> None:\n        super().__init__(run,args)\n        \n    def get_interface(self):\n        return \n\n    def get_local_messages(self):\n        raise NotImplementedError\n\n\n\n\nclass ExllamaInterface(BackendInterfaceBase):\n    \n    def __init__(self, args: ConfigArgs = ...):\n        raise NotImplementedError\n    \n    def tokenize_prompt(self, prompt: str) -> torch.Tensor:\n        raise NotImplementedError\n    \n    async def inference(self,local_messages,request_unique_id:Optional[str])->AsyncIterator:\n        raise NotImplementedError\n    \n\n\n\n"
  },
  {
    "path": "kt-sft/ktransformers/server/backend/interfaces/ktransformers.py",
    "content": "import torch\nfrom typing import Optional, List\nimport asyncio\nfrom transformers import AutoTokenizer, AutoConfig, GenerationConfig\nfrom ktransformers.server.backend.interfaces.transformers import (\n    TransformersInterface,\n    ConfigArgs,\n    TransformersThreadContext,\n    default_args,\n    TextStreamer,\n)\nfrom ktransformers.server.config.log import logger\nfrom ktransformers.optimize.optimize import optimize_and_load_gguf\nfrom ktransformers.models.custom_cache import StaticCache\nfrom ktransformers.util.cuda_graph_runner import CUDAGraphRunner\nfrom ktransformers.local_chat import custom_models, default_optimize_rules\nfrom ktransformers.util.utils import get_device\nfrom typing import Optional\nfrom ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton\nfrom ktransformers.server.schemas.endpoints.chat import RawUsage\nfrom ktransformers.util.grad_wrapper import maybe_no_grad\n\nwarm_uped = False\n\nclass KTransformersThreadContext(TransformersThreadContext):\n    pass\n\n\nclass KTransformersInterface(TransformersInterface):\n    def __init__(self, args: ConfigArgs = default_args):\n        self.args = args\n        torch.set_grad_enabled(False)\n        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)\n        config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)\n        try:\n            generation_config = GenerationConfig.from_pretrained(args.model_dir)\n        except:\n            generation_config = GenerationConfig(\n                max_length=args.max_new_tokens,\n                temperature=args.temperature,\n                top_p=args.top_p,\n                do_sample=True\n            )\n        \n        torch.set_default_dtype(config.torch_dtype)\n        if config.architectures[0] == \"Qwen2MoeForCausalLM\":\n            config._attn_implementation = 
\"flash_attention_2\"\n\n        with torch.device(\"meta\"):\n            self.model = custom_models[config.architectures[0]](config)\n        if default_args.optimize_config_path is None:\n            optimize_config_path = default_optimize_rules[config.architectures[0]]\n        else:\n            optimize_config_path = args.optimize_config_path\n\n        # print(optimize_config)\n\n        gguf_path = args.gguf_path\n        if gguf_path is None:\n            gguf_path = input(\n                \"please input the path of your gguf file(gguf file in the dir containing input gguf file must all\"\n                \" belong to current model):\"\n            )\n        optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)\n        self.model.generation_config = generation_config\n        self.device_map = self.model.gguf_loader.tensor_device_map\n        # logger.info(f\"{args.model_name} loaded from {args.model_dir} to {self.device_map}\")\n        self.cache = StaticCache(\n            config=self.model.config,\n            max_batch_size=args.batch_size,\n            max_cache_len=args.cache_lens,\n            device=self.device_map,\n            dtype=self.model.dtype,\n        )\n        # logger.info(f\"StaticCache (length={args.cache_lens}), batch size:{args.batch_size}\")\n\n        if self.model.generation_config.pad_token_id is None:\n            self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id\n        self.streamer = TextStreamer(self.tokenizer)\n\n        self._infer_lock = asyncio.Lock()\n\n    def decode_one_tokens(self):\n        global warm_uped\n\n        device_map = self.model.gguf_loader.tensor_device_map\n        torch_device = get_device(\"blk.0.self_attn\", device_map)\n        torch_device = \"cuda:0\" if torch_device == \"cuda\" else torch_device\n        torch.cuda.set_device(torch_device)\n        if warm_uped and self.args.use_cuda_graph:\n            if not hasattr(self, 
\"cuda_graph_runner\"):\n                self.cuda_graph_runner = CUDAGraphRunner()\n                self.cuda_graph_runner.capture(\n                    self.model,\n                    self.current_ids,\n                    self.active_cache_position.unsqueeze(0),\n                    self.active_cache_position,\n                    self.cache,\n                    main_device=torch_device,\n                    return_dict=False,\n                    use_cache=True,\n                )\n\n            if hasattr(self, \"cuda_graph_runner\"):\n                logits = self.cuda_graph_runner(\n                    self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position\n                )\n                self.cache.change_seq_length(1)\n                torch.cuda.synchronize()\n                logits = logits[0, -1, :]\n                return self.logits_to_token(logits)\n        \n        if self.args.use_cuda_graph:\n            warm_uped = True\n            \n        if self.use_static_cache:\n            logits = self.model(\n                self.current_ids.to(torch_device),\n                cache_position=self.active_cache_position,\n                past_key_values=self.cache,\n                return_dict=False,\n                use_cache=True,\n            )[0]\n        else:\n            logits = self.model(self.current_ids, return_dict=False)[0]\n        logits = logits[0, -1, :]\n\n        return self.logits_to_token(logits)\n\n\n\n    @maybe_no_grad\n    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):\n        input_ids_length = input_ids.shape[-1]\n        if max_tokens is not None:\n            max_completion_tokens = max_tokens\n        if max_completion_tokens is None:\n            max_new_tokens = self.args.max_new_tokens\n        else:\n            max_new_tokens 
= min(self.args.max_new_tokens, max_completion_tokens)\n        if input_ids_length >= self.args.cache_lens:\n            logger.warning(f\"input_ids_length {input_ids_length} >= cache_lens {self.args.cache_lens}\")\n            self.seq_length = input_ids_length\n            return\n        logger.debug(f\"input_ids: {input_ids.shape}\")\n        device = self.device_map.get(\"blk.0.self_attn\", {}).get(\"generate_device\", \"cuda:0\")\n        device = \"cuda:0\" if device == \"cuda\" else device\n\n        if is_new:\n            self.ever_generated_ids.clear()\n            same_prefix = 0\n            flat_input_ids = input_ids.flatten()\n\n            if getattr(self, 'generated_ids', None) is None:\n                self.generated_ids = torch.zeros(\n                    self.args.batch_size,\n                    input_ids.shape[-1] + max_new_tokens + 1,\n                    dtype=torch.int,\n                    device=self.args.device,\n                )\n                self.seq_length = 1            \n            \n            flat_prev_ids = self.generated_ids.flatten()\n            for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):\n                if flat_input_ids[i] == flat_prev_ids[i]:\n                    same_prefix += 1\n                else:\n                    break\n            \n            logger.debug(f\"same prefix len: {same_prefix}\")\n            self.cache.remove_suffix(same_prefix)\n            self.seq_length = same_prefix\n            self.generated_ids = self.generated_ids[..., :same_prefix]\n            input_ids = input_ids[..., same_prefix:]\n            input_ids_length = input_ids.shape[-1]\n\n        self.ever_generated_ids.clear()\n        self.profiler.set_counter(\"prefill\", input_ids_length)\n        logger.debug(f\"input_ids: {input_ids.shape}\")\n        logger.debug(f\"generate_ids: {self.generated_ids.shape}\")\n        \n        former_seq_length = self.seq_length\n        self.seq_length += 
input_ids_length\n        expected_length = min(self.seq_length + max_new_tokens + 1, self.args.cache_lens)\n        delta_length = expected_length - self.generated_ids.shape[-1]\n        if delta_length > 0:\n            new_generate_ids = torch.zeros(\n                self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device\n            )\n            self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)\n        else:\n            logger.warning(\"seq_length bigger than cache_lens, killed\")\n            exit(0)\n        \n        logger.debug(f\"cache position: {former_seq_length} to {self.seq_length}\")\n        cache_position = torch.arange(former_seq_length, self.seq_length, device=device)\n        self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)\n\n        if not (type(self) is TransformersInterface):\n            input_ids = input_ids.to(\"cpu\")\n        \n        def chunk_prefill(input_ids, cache_position):\n            inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)\n            torch.cuda.set_device(device)\n            if flashinfer_enabled:\n                MLAWrapperSingleton.need_plan_all()\n            if self.use_static_cache:\n                logits = self.model(\n                    inputs_embeds=inputs_embeds,\n                    cache_position=cache_position,\n                    past_key_values=self.cache,\n                    return_dict=False,\n                    use_cache=True,\n                )[0]\n            else:\n                logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]\n\n            return logits\n\n        chunk_start = 0\n        while chunk_start < input_ids_length:\n            chunk_end = min(chunk_start + self.args.chunk_size, input_ids_length)\n            if self.cache is not None:\n                self.cache.cur_idx=cache_position[chunk_start:chunk_end]\n            logits = 
chunk_prefill(input_ids[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end])\n            chunk_start += self.args.chunk_size\n            \n        if flashinfer_enabled:\n            MLAWrapperSingleton.reset_buffer()\n        self.prepare_logits_wrapper(input_ids, device, temperature, top_p)\n        next_token = self.logits_to_token(logits[0, -1, :])\n        self.max_new_tokens = min(max_new_tokens, self.args.cache_lens - self.seq_length) - 1 \n        yield self.append_new_tokens(next_token)\n        \n    @property\n    def active_cache_position(self):\n        device = self.device_map.get(\"blk.0.self_attn\", {}).get(\"generate_device\", \"cuda:0\")\n        return torch.tensor([self.seq_length - 1], device=device)\n    \n    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):\n        async with self._infer_lock:\n            async for v in super().inference(local_messages, thread_id, temperature, top_p, max_tokens, max_completion_tokens):\n                yield v\n            \n            # return this inference raw usage\n            yield RawUsage(\n                tokenize_time = self.profiler.get_timer_sec('tokenize'),\n                prefill_time = self.profiler.get_timer_sec('prefill'),\n                decode_time = self.profiler.get_timer_sec('decode'),\n                prefill_count = self.profiler.get_counter('prefill'),\n                decode_count = self.profiler.get_counter('decode'),\n            )"
  },
  {
    "path": "kt-sft/ktransformers/server/backend/interfaces/transformers.py",
    "content": "from typing import Any, List, Optional, Set\nimport re\nimport json\nimport uuid\nfrom transformers import (\n    LlamaTokenizer,\n    AutoTokenizer,\n    AutoConfig,\n    LlamaForCausalLM,\n    GenerationConfig,\n    StaticCache,\n    AutoModelForCausalLM,\n    BitsAndBytesConfig,\n    LogitsProcessorList,\n    TemperatureLogitsWarper,\n    TopKLogitsWarper,\n    TopPLogitsWarper,\n    MinPLogitsWarper,\n    TypicalLogitsWarper,\n    EpsilonLogitsWarper,\n    EtaLogitsWarper,\n)\n\nfrom ktransformers.server.config.config import Config\nfrom ktransformers.server.schemas.base import ObjectID\nfrom ktransformers.server.utils.multi_timer import Profiler\nfrom torch.nn.attention import SDPBackend\nimport torch\nimport sys, os\nfrom ..base import ThreadContext, BackendInterfaceBase\nfrom ktransformers.server.config.log import logger\nfrom ..args import ConfigArgs, default_args\nfrom ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton\nfrom ktransformers.util.grad_wrapper import maybe_no_grad\n\n# This TextStreamer is a modified version from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py\nclass TextStreamer:\n\n    def __init__(self, tokenizer: \"AutoTokenizer\", skip_prompt: bool = False, **decode_kwargs):\n        self.tokenizer = tokenizer\n        self.skip_prompt = skip_prompt\n        self.decode_kwargs = decode_kwargs\n\n        # variables used in the streaming process\n        self.token_cache = []\n        self.print_len = 0\n        self.next_tokens_are_prompt = True\n\n    def reset(self):\n        self.token_cache = []\n        self.print_len = 0\n\n    def put(self, value) -> Optional[str]:\n        \"\"\"\n        Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.\n        \"\"\"\n        if not isinstance(value, int):\n            raise ValueError(\"TextStreamer only supports batch size 1, and int type input\")\n\n 
       if self.skip_prompt and self.next_tokens_are_prompt:\n            self.next_tokens_are_prompt = False\n            return None\n\n        # Add the new token to the cache and decodes the entire thing.\n        self.token_cache.append(value)\n        text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True, **self.decode_kwargs)\n\n        # After the symbol for a new line, we flush the cache.\n        if text.endswith(\"\\n\"):\n            printable_text = text[self.print_len :]\n            self.reset()\n        # If the last token is a CJK character, we print the characters.\n        elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):\n            printable_text = text[self.print_len :]\n            self.print_len += len(printable_text)\n        # Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,\n        # which may change with the subsequent token -- there are probably smarter ways to do this!)\n        else:\n            printable_text = text[self.print_len : text.rfind(\" \") + 1]\n            self.print_len += len(printable_text)\n        return printable_text\n\n    def end(self) -> Optional[str]:\n        \"\"\"Flushes any remaining cache and prints a newline to stdout.\"\"\"\n        # Flush the cache, if it exists\n        if len(self.token_cache) > 0:\n            text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True, **self.decode_kwargs)\n            printable_text = text[self.print_len :]\n            self.reset()\n        else:\n            printable_text = \"\"\n\n        self.next_tokens_are_prompt = True\n        return printable_text\n\n    def _is_chinese_char(self, cp):\n        \"\"\"Checks whether CP is the codepoint of a CJK character.\"\"\"\n        # This defines a \"chinese character\" as anything in the CJK Unicode block:\n        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)\n        #\n        # Note that the 
CJK Unicode block is NOT all Japanese and Korean characters,\n        # despite its name. The modern Korean Hangul alphabet is a different block,\n        # as are Japanese Hiragana and Katakana. Those alphabets are used to write\n        # space-separated words, so they are not treated specially and are handled\n        # like all of the other languages.\n        if (\n            (cp >= 0x4E00 and cp <= 0x9FFF)\n            or (cp >= 0x3400 and cp <= 0x4DBF)  #\n            or (cp >= 0x20000 and cp <= 0x2A6DF)  #\n            or (cp >= 0x2A700 and cp <= 0x2B73F)  #\n            or (cp >= 0x2B740 and cp <= 0x2B81F)  #\n            or (cp >= 0x2B820 and cp <= 0x2CEAF)  #\n            or (cp >= 0xF900 and cp <= 0xFAFF)\n            or (cp >= 0x2F800 and cp <= 0x2FA1F)  #\n        ):  #\n            return True\n\n        return False\n\n\nclass TransformersThreadContext(ThreadContext):\n    def get_local_messages(self):\n        local_messages = []\n        for m in self.messages:\n            local_messages.append({\"role\": m.role.value, \"content\": m.get_text_content()})\n\n        return local_messages\n\n\nclass TransformersInterface(BackendInterfaceBase):\n    use_static_cache: bool = True\n\n    model: Any\n    tokenizer: AutoTokenizer\n\n    cache: StaticCache\n    generated_ids: torch.Tensor\n    seq_length: int\n\n    streamer: TextStreamer\n\n    # thread_related\n    last_request_id: Optional[str] = None\n    ever_generated_ids: Set[int] = set()\n\n    def __init__(self, args: ConfigArgs = default_args):\n        self.args = args\n\n        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir)\n        self.model = AutoModelForCausalLM.from_pretrained(args.model_dir, device_map=args.device, use_safetensors=True)\n        # logger.info(f\"{args.model_name} loaded from {args.model_dir} to {args.device}\")\n\n        self.cache = StaticCache(\n            config=self.model.config,\n            max_batch_size=args.batch_size,\n            
max_cache_len=args.cache_lens,\n            device=args.device,\n            dtype=self.model.dtype,\n        )\n        # logger.info(f\"StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}\")\n\n        self.streamer = TextStreamer(self.tokenizer)\n\n    @property\n    def current_ids(self):\n        return self.generated_ids[:, self.seq_length - 1].unsqueeze(1)\n\n    @property\n    def active_cache_position(self):\n        return torch.tensor([self.seq_length - 1], device=self.args.device)\n\n    def tokenize_prompt(self, prompt: str):\n        input_ids = self.tokenizer.encode(prompt, return_tensors=\"pt\").to(self.args.device)\n        return input_ids\n\n    def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List):\n        for m in messages:\n            if m[\"role\"] == \"system\":\n                logger.warning(f'change {m[\"role\"]} to user')\n                m[\"role\"] = \"user\"\n\n        new_messages = [messages[0]]\n        for m in messages[1:]:\n            if m[\"role\"] == \"user\" and new_messages[-1][\"role\"] == \"user\":\n                logger.warning(\"merge two adjacent user messages\")\n                new_messages[-1][\"content\"] += '\\n' + m[\"content\"]\n            else:\n                new_messages.append(m)\n        # if (self.last_request_id is not None) and self.last_request_id == thread_id:\n        #     input_ids = self.tokenizer.encode(self.tokenizer.eos_token+self.tokenizer.apply_chat_template([new_messages[-1]], return_tensors=\"pt\",tokenize=False, add_generation_prompt=True), add_special_tokens = False, return_tensors=\"pt\").to(self.args.device)\n        # else:\n        #     input_ids = self.tokenizer.apply_chat_template(\n        #         new_messages, return_tensors=\"pt\", add_generation_prompt=True\n        #     ).to(self.args.device)\n        input_str: str = 
self.tokenizer.apply_chat_template(new_messages,tokenize=False,add_generation_prompt=True)\n        # drop <think> token in chat template\n        if input_str.endswith('<think>\\n'):\n            input_str = input_str[:-len('<think>\\n')]\n        input_ids = self.tokenizer.encode(input_str, return_tensors=\"pt\").to(self.args.device)\n        if (self.last_request_id is not None) and self.last_request_id == thread_id:\n            x = self.generated_ids[:,:self.seq_length]\n            y = input_ids[:,:self.seq_length]\n            # We can only hope that the input_ids are the same\n            unequal_mask = torch.ne(x,y)\n            unequal_positions = torch.nonzero(unequal_mask)\n            num_unequal_elements = unequal_mask.sum().item()\n            logger.warning(f'num_unequal_elements: {num_unequal_elements}') \n\n            input_ids = input_ids[:,self.seq_length:]\n        logger.debug(f\"get input ids of shape {input_ids.shape}\")\n        return input_ids\n\n    def append_new_tokens(self, new_tokens: int) -> Optional[str]:\n        self.generated_ids[0, self.seq_length] = new_tokens\n        self.seq_length += 1\n        return self.streamer.put(new_tokens)\n\n    @staticmethod\n    def tf_logits_warper(generation_config):\n        \"\"\"\n        This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances\n        used for multinomial sampling.\n        \"\"\"\n\n        # instantiate warpers list\n        warpers = LogitsProcessorList()\n\n        # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a\n        # better score (i.e. 
keep len(list(generation_config._eos_token_tensor)) + 1)\n        if generation_config.num_beams > 1:\n            if isinstance(generation_config._eos_token_tensor, list):\n                min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1\n            elif isinstance(generation_config._eos_token_tensor, torch.Tensor):\n                min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1\n            else:\n                min_tokens_to_keep = 2\n        else:\n            min_tokens_to_keep = 1\n\n        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files\n        # all samplers can be found in `generation_utils_samplers.py`\n        if generation_config.temperature is not None and generation_config.temperature != 1.0:\n            warpers.append(TemperatureLogitsWarper(generation_config.temperature))\n        if generation_config.top_k is not None and generation_config.top_k != 0:\n            warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))\n        if generation_config.top_p is not None and generation_config.top_p < 1.0:\n            warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))\n        if generation_config.min_p is not None:\n            # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)\n            warpers.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))\n        if generation_config.typical_p is not None and generation_config.typical_p < 1.0:\n            warpers.append(\n                TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)\n            )\n        if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:\n            warpers.append(\n                
EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep)\n            )\n        if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:\n            warpers.append(\n               EtaLogitsWarper(\n                    epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device\n                )\n            )\n        # `LogitNormalization` should always be the last logit processor, when present\n        if generation_config.renormalize_logits is True:\n            warpers.append(LogitNormalization())\n        return warpers\n\n    def prepare_logits_wrapper(self, inputs, device, temperature: Optional[float] = None, top_p: Optional[float] = None):\n        if temperature is None or temperature == 0:\n            temperature = self.model.generation_config.temperature\n        if top_p is None:\n            top_p = self.model.generation_config.top_p\n        if top_p == 0:\n            top_p = 0.0001\n        generation_config, model_kwargs = self.model._prepare_generation_config(\n            None, max_length=self.args.max_new_tokens,\n            do_sample=True, \n            top_k=self.args.top_k, \n            top_p=top_p, \n            temperature=temperature,\n            repetition_penalty=self.args.repetition_penalty # change this to modify generate config\n        )\n        self.inputs = inputs\n\n        self.logits_warper = self.tf_logits_warper(generation_config)\n\n    def logits_to_token(self, logits: torch.Tensor):\n        logits = self.logits_warper(self.inputs.view(1, -1), logits.view(1, -1))\n\n        probs = torch.nn.functional.softmax(logits, dim=-1)\n\n        sample = True\n        if sample:\n            last = torch.multinomial(probs, num_samples=1)\n        else:\n            _, last = torch.topk(probs, k=1, dim=-1)\n\n        last = last.item()\n        self.ever_generated_ids.add(last)\n        return last\n\n    def 
decode_one_tokens(self):\n        if self.use_static_cache:\n            logits = self.model(\n                self.current_ids,\n                cache_position=self.active_cache_position,\n                past_key_values=self.cache,\n                return_dict=False,\n                use_cache=True,\n            )[0]\n        else:\n            logits = self.model(self.current_ids, return_dict=False)[0]\n        logits = logits[0, -1, :]\n\n        return self.logits_to_token(logits)\n\n    @maybe_no_grad\n    def prefill(self, input_ids: torch.Tensor, is_new: bool, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):\n        input_ids_length = input_ids.shape[-1]\n        logger.debug(f\"input_ids: {input_ids.shape}\")\n        if max_tokens is not None:\n            max_completion_tokens = max_tokens\n        if max_completion_tokens is None:\n            max_new_tokens = self.args.max_new_tokens\n        else:\n            max_new_tokens = min(self.args.max_new_tokens, max_completion_tokens)\n        if is_new:\n            self.ever_generated_ids.clear()\n            same_prefix = 0\n            flat_input_ids = input_ids.flatten()\n\n            if getattr(self, 'generated_ids', None) is None:\n                self.generated_ids = torch.zeros(\n                    self.args.batch_size,\n                    input_ids.shape[-1] + max_new_tokens + 1,\n                    dtype=torch.int,\n                    device=self.args.device,\n                )\n                self.seq_length = 1            \n            \n            flat_prev_ids = self.generated_ids.flatten()\n            for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):\n                if flat_input_ids[i] == flat_prev_ids[i]:\n                    same_prefix += 1\n                else:\n                    break\n            \n            logger.debug(f\"same prefix len: 
{same_prefix}\")\n            self.cache.remove_suffix(same_prefix)\n            self.seq_length = same_prefix\n            self.generated_ids = self.generated_ids[..., :same_prefix]\n            input_ids = input_ids[..., same_prefix:]\n            input_ids_length = input_ids.shape[-1]\n        \n        self.ever_generated_ids.clear()\n        self.profiler.set_counter(\"prefill\", input_ids_length)\n        logger.debug(f\"input_ids: {input_ids.shape}\")\n\n        logger.debug(f\"generate_ids: {self.generated_ids.shape}\")\n        former_seq_length = self.seq_length\n        self.seq_length += input_ids_length\n        expected_length = self.seq_length + max_new_tokens + 1\n        delta_length = expected_length - self.generated_ids.shape[-1]\n        if delta_length > 0:\n            new_generate_ids = torch.zeros(\n                self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device\n            )\n            self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)\n            \n        logger.debug(f\"cache position: {former_seq_length} to {self.seq_length}\")\n        cache_position = torch.arange(former_seq_length, self.seq_length, device=self.args.device)\n        self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)\n\n        device = input_ids.device\n        if not (type(self) is TransformersInterface):\n            input_ids = input_ids.to(\"cpu\")\n        inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)\n        if self.use_static_cache:\n            logits = self.model(\n                inputs_embeds=inputs_embeds,\n                cache_position=cache_position,\n                past_key_values=self.cache,\n                return_dict=False,\n                use_cache=True,\n            )[0]\n        else:\n            logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]\n\n        self.prepare_logits_wrapper(input_ids, device, 
temperature, top_p)\n        next_token = self.logits_to_token(logits[0, -1, :])\n        self.max_new_tokens = min(max_new_tokens, self.args.cache_lens - self.seq_length) - 1 \n        yield self.append_new_tokens(next_token)\n\n    @maybe_no_grad\n    def generate(self):\n        logger.info(f\"args.max_new_tokens: {self.args.max_new_tokens}, cache_lens: {self.args.cache_lens}, seq_length: {self.seq_length}\")\n        if(self.max_new_tokens <= 0):\n            logger.warning(\"max_new_tokens is less than 0\")\n            yield self.streamer.end(), \"length\"\n            return\n        self.profiler.set_counter(\"decode\", 0)\n\n        for i in range(1, self.max_new_tokens):\n            with torch.nn.attention.sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):\n                if flashinfer_enabled:\n                    MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1, None,\n                                             num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank, \n                                             head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,\n                                             sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)\n                next_token = self.decode_one_tokens()\n                self.profiler.inc(\"decode\")\n                if next_token == self.tokenizer.eos_token_id or \"<|im_end|>\" == self.tokenizer.decode(next_token):\n                    yield self.streamer.end(), None\n                    yield \"\", \"stop\"\n                    assert self.args.batch_size == 1\n                    break\n                yield self.append_new_tokens(next_token), None\n\n        else:   # for's else, if output get max new tokens\n            yield self.streamer.end(), None\n            
yield \"\", \"length\"\n        \n        \n\n    def check_is_new(self, thread_id: str):\n        if not self.use_static_cache:\n            return True\n        if self.last_request_id is None:\n            self.last_request_id = thread_id\n            return True\n        else:\n            if self.last_request_id == thread_id:\n                return False\n            else:\n                self.last_request_id = thread_id\n                return True\n\n    async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None):\n        self.streamer.reset()\n        self.profiler.create_and_start_timer(\"tokenize\")\n        if isinstance(local_messages, List):\n            input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)\n        elif isinstance(local_messages, str):\n            #local_messages = local_messages[0]['content']\n            input_ids = self.tokenize_prompt(local_messages)\n            #input_ids = torch.tensor([[6366]], device=input_ids.device)\n        else:\n            raise ValueError(\"local_messages should be List or str\")\n        \n        if Config().user_force_think:\n            token_thinks = torch.tensor([self.tokenizer.encode(\"<think>\\n\",add_special_tokens=False)],device=input_ids.device)\n            input_ids = torch.cat(\n                [input_ids, token_thinks], dim=1\n            )\n\n        self.profiler.pause_timer(\"tokenize\")\n\n        self.profiler.create_and_start_timer(\"prefill\")\n\n        if Config().user_force_think:\n            think = '<think>\\n'\n            print(think, end=\"\",flush=True)\n            yield think, None\n        \n        for t in self.prefill(input_ids, self.check_is_new(thread_id), temperature, top_p, max_tokens, max_completion_tokens):\n            # output think token after prefill done\n            if t is not 
None:\n                print(t, end=\"\",flush=True)\n                yield t, None\n        self.profiler.pause_timer(\"prefill\")\n\n        self.profiler.create_and_start_timer(\"decode\")\n        for t, finish_reason in self.generate():\n            if t is not None:\n                print(t, end=\"\",flush=True)\n                yield t, finish_reason\n        print(\"\")\n        self.profiler.pause_timer(\"decode\")\n        self.report_last_time_performance()\n"
  },
  {
    "path": "kt-sft/ktransformers/server/balance_serve/inference/__init__.py",
    "content": ""
  },
  {
    "path": "kt-sft/ktransformers/server/balance_serve/inference/config.py",
    "content": "'''\nDate: 2024-11-07 07:30:16\nLastEditors: djw\nLastEditTime: 2024-11-15 14:23:26\n'''\nimport math\nfrom dataclasses import dataclass\nfrom typing import Optional\n\nimport torch\nimport torch.nn as nn\nfrom torch import Tensor\nfrom torch.nn import functional as F\nimport yaml\n\nimport json\nfrom typing import Optional\n\nclass ModelConfig:\n    vocab_size: int = 32000\n    n_layer: int = 1\n    n_head: int = 32\n    dim: int = 4096\n    intermediate_size: int = 18944\n    n_local_heads: int = 8\n    head_dim: int = 128\n    rope_base: float = 1000000.0\n    norm_eps: float = 1e-06\n    rope_scaling: Optional[dict] = None\n    rms_norm_eps: float = 1e-6\n    hidden_act: str = \"silu\"\n    model_path: str\n    gguf_path: str\n    optimize_rule_path: str\n    speculative_rule_path: str\n            \n\n    # quantize config\n    quant_algorithm: Optional[str] = None\n    quant_group_size: Optional[int] = None\n    quant_num_bits: Optional[int] = None\n\n    json_key_map = {\n        \"vocab_size\": \"vocab_size\",\n        \"n_layer\": \"num_hidden_layers\",\n        \"n_head\": \"num_attention_heads\",\n        \"dim\": \"hidden_size\",\n        \"intermediate_size\": \"intermediate_size\",\n        \"n_local_heads\": \"num_key_value_heads\",\n        \"rope_base\": \"rope_theta\",\n        \"norm_eps\": \"norm_eps\",\n        \"rms_norm_eps\": \"rms_norm_eps\",\n        \"hidden_act\": \"hidden_act\",\n    }\n\n    def __init__(self, config):\n        self.model_path = config[\"model\"][\"model_path\"]\n        self.gguf_path = config[\"model\"][\"gguf_path\"]\n        self.optimize_rule_path = config[\"model\"][\"optimize_rule_path\"]\n        if \"speculative_rule_path\" in config[\"model\"]:\n            self.speculative_rule_path =  config[\"model\"][\"speculative_rule_path\"]\n            self.speculative_gguf_path = config[\"model\"][\"speculative_gguf_path\"]\n            self.speculative_model_path = 
config[\"model\"][\"speculative_model_path\"]\n        self.quant_algorithm = config[\"model\"][\"quant\"][\"algorithm\"]\n        self.quant_group_size = config[\"model\"][\"quant\"][\"group_size\"]\n        self.quant_num_bits = config[\"model\"][\"quant\"][\"num_bits\"]\n        self.load_config()\n        self.n_layer = config[\"model\"][\"n_layers\"]\n\n    def load_config(self):\n        config_file = f\"{self.model_path}/config.json\"\n        try:\n            with open(config_file, \"r\") as f:\n                config_data = json.load(f)\n        except FileNotFoundError:\n            raise FileNotFoundError(f\"Configuration file not found at {config_file}\")\n\n        for attr, json_key in self.json_key_map.items():\n            if json_key in config_data:\n                setattr(self, attr, config_data[json_key])\n            else:\n                setattr(self, attr, getattr(self, attr))\n\n\n    \n\n\nclass ParallelConfig:\n    def __init__(\n        self,\n        config,\n    ) -> None:\n        self.pipeline_parallel_size = config[\"parallel\"][\"pp\"]\n        self.tensor_parallel_size = config[\"parallel\"][\"tp\"]\n        self.disable_custom_all_reduce = config[\"parallel\"][\"disable_custom_all_reduce\"]\n        self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size\n\nclass AttnConfig:\n    page_size: int = 256\n    block_num: int = 32\n    max_batch_token : int = 256\n    max_batch_size: int = 32\n\n    def __init__(self, config):\n        self.page_size = config[\"attn\"][\"page_size\"]\n        self.block_num = config[\"attn\"][\"block_num\"]\n        self.max_batch_token = config[\"attn\"][\"max_batch_token\"]\n        self.max_batch_size = config[\"attn\"][\"max_batch_size\"]\n\n\nclass SamplerConfig():\n\t# Batched sampling params\n    temperatures: float\n    is_all_greedy: bool\n\t\n    def __init__(self, config):\n        self.temperatures = config[\"sample\"][\"temperature\"]\n        self.is_all_greedy = 
True\n\n\ndef load_yaml_config(file_path):\n    with open(file_path, \"r\") as f:\n        return yaml.safe_load(f)\n    \n\n\n\nclass LLMConfig:\n    model_config: ModelConfig\n    parallel_config: ParallelConfig\n    attn_config: AttnConfig\n    sample_config: SamplerConfig\n    config_file: str\n\n    def __init__(self, config_file):\n        self.config_file = config_file\n        config = load_yaml_config(config_file)\n        self.model_config = ModelConfig(config)\n        self.parallel_config = ParallelConfig(config)\n        self.attn_config = AttnConfig(config)\n        self.sample_config = SamplerConfig(config)\n\n"
  },
  {
    "path": "kt-sft/ktransformers/server/balance_serve/inference/distributed/__init__.py",
    "content": "from .communication_op import *\nfrom .parallel_state import *\nfrom .utils import *\n"
  },
  {
    "path": "kt-sft/ktransformers/server/balance_serve/inference/distributed/communication_op.py",
    "content": "\"\"\"\nDate: 2024-12-11 06:02:42\nLastEditors: djw\nLastEditTime: 2024-12-12 09:52:06\n\"\"\"\n\nfrom typing import Any, Dict, Optional, Union\n\nimport torch\nimport torch.distributed\n\nfrom .parallel_state import get_tp_group\n\n\ndef tensor_model_parallel_all_reduce(input_: torch.Tensor, bsz_tensor: torch.Tensor, is_compute_bound=False, overlap=False) -> torch.Tensor:\n    \"\"\"All-reduce the input tensor across model parallel group.\"\"\"\n    return get_tp_group().all_reduce(input_, bsz_tensor, is_compute_bound=is_compute_bound, overlap=overlap)\n\n\ndef tensor_model_parallel_all_gather(\n    input_: torch.Tensor, dim: int = -1\n) -> torch.Tensor:\n    \"\"\"All-gather the input tensor across model parallel group.\"\"\"\n    return get_tp_group().all_gather(input_, dim)\n\n\ndef tensor_model_parallel_gather(\n    input_: torch.Tensor, dst: int = 0, dim: int = -1\n) -> Optional[torch.Tensor]:\n    \"\"\"Gather the input tensor across model parallel group.\"\"\"\n    return get_tp_group().gather(input_, dst, dim)\n\n\ndef broadcast_tensor_dict(\n    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0\n):\n    if not torch.distributed.is_initialized():\n        return tensor_dict\n    return get_tp_group().broadcast_tensor_dict(tensor_dict, src)\n"
  },
  {
    "path": "kt-sft/ktransformers/server/balance_serve/inference/distributed/cuda_wrapper.py",
    "content": "\"\"\"This file is a pure Python wrapper for the cudart library.\nIt avoids the need to compile a separate shared library, and is\nconvenient for use when we just need to call a few functions.\n\"\"\"\n\nimport ctypes\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional\n\n# this line makes it possible to directly load `libcudart.so` using `ctypes`\nimport torch  # noqa\n\n# === export types and functions from cudart to Python ===\n# for the original cudart definition, please check\n# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html\n\ncudaError_t = ctypes.c_int\ncudaMemcpyKind = ctypes.c_int\n\n\nclass cudaIpcMemHandle_t(ctypes.Structure):\n    _fields_ = [(\"internal\", ctypes.c_byte * 128)]\n\n\n@dataclass\nclass Function:\n    name: str\n    restype: Any\n    argtypes: List[Any]\n\n\ndef find_loaded_library(lib_name) -> Optional[str]:\n    \"\"\"\n    According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,\n    the file `/proc/self/maps` contains the memory maps of the process, which includes the\n    shared libraries loaded by the process. 
We can use this file to find the path of\n    a loaded library.\n    \"\"\" # noqa\n    found = False\n    with open(\"/proc/self/maps\") as f:\n        for line in f:\n            if lib_name in line:\n                found = True\n                break\n    if not found:\n        # the library is not loaded in the current process\n        return None\n    # if lib_name is libcudart, we need to match a line with:\n    # address /path/to/libcudart-hash.so.11.0\n    start = line.index(\"/\")\n    path = line[start:].strip()\n    filename = path.split(\"/\")[-1]\n    assert filename.rpartition(\".so\")[0].startswith(lib_name), \\\n        f\"Unexpected filename: {filename} for library {lib_name}\"\n    return path\n\n\nclass CudaRTLibrary:\n    exported_functions = [\n        # cudaError_t cudaSetDevice ( int  device )\n        Function(\"cudaSetDevice\", cudaError_t, [ctypes.c_int]),\n        # cudaError_t cudaDeviceSynchronize ( void )\n        Function(\"cudaDeviceSynchronize\", cudaError_t, []),\n        # cudaError_t cudaDeviceReset ( void )\n        Function(\"cudaDeviceReset\", cudaError_t, []),\n\n        # const char* cudaGetErrorString ( cudaError_t error )\n        Function(\"cudaGetErrorString\", ctypes.c_char_p, [cudaError_t]),\n\n        # cudaError_t cudaMalloc ( void** devPtr, size_t size )\n        Function(\"cudaMalloc\", cudaError_t,\n                 [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]),\n        # cudaError_t cudaFree ( void* devPtr )\n        Function(\"cudaFree\", cudaError_t, [ctypes.c_void_p]),\n        # cudaError_t cudaMemset ( void* devPtr, int  value, size_t count )\n        Function(\"cudaMemset\", cudaError_t,\n                 [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]),\n        # cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa\n        Function(\"cudaMemcpy\", cudaError_t, [\n            ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, 
cudaMemcpyKind\n        ]),\n\n        # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa\n        Function(\"cudaIpcGetMemHandle\", cudaError_t,\n                 [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]),\n        # ​cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int  flags ) # noqa\n        Function(\"cudaIpcOpenMemHandle\", cudaError_t, [\n            ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint\n        ]),\n    ]\n\n    # class attribute to store the mapping from the path to the library\n    # to avoid loading the same library multiple times\n    path_to_library_cache: Dict[str, Any] = {}\n\n    # class attribute to store the mapping from library path\n    #  to the corresponding dictionary\n    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}\n\n    def __init__(self, so_file: Optional[str] = None):\n        if so_file is None:\n            so_file = find_loaded_library(\"libcudart\")\n            assert so_file is not None, \\\n                \"libcudart is not loaded in the current process\"\n        if so_file not in CudaRTLibrary.path_to_library_cache:\n            lib = ctypes.CDLL(so_file)\n            CudaRTLibrary.path_to_library_cache[so_file] = lib\n        self.lib = CudaRTLibrary.path_to_library_cache[so_file]\n\n        if so_file not in CudaRTLibrary.path_to_dict_mapping:\n            _funcs = {}\n            for func in CudaRTLibrary.exported_functions:\n                f = getattr(self.lib, func.name)\n                f.restype = func.restype\n                f.argtypes = func.argtypes\n                _funcs[func.name] = f\n            CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs\n        self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]\n\n    def CUDART_CHECK(self, result: cudaError_t) -> None:\n        if result != 0:\n            error_str = self.cudaGetErrorString(result)\n            raise 
RuntimeError(f\"CUDART error: {error_str}\")\n\n    def cudaGetErrorString(self, error: cudaError_t) -> str:\n        return self.funcs[\"cudaGetErrorString\"](error).decode(\"utf-8\")\n\n    def cudaSetDevice(self, device: int) -> None:\n        self.CUDART_CHECK(self.funcs[\"cudaSetDevice\"](device))\n\n    def cudaDeviceSynchronize(self) -> None:\n        self.CUDART_CHECK(self.funcs[\"cudaDeviceSynchronize\"]())\n\n    def cudaDeviceReset(self) -> None:\n        self.CUDART_CHECK(self.funcs[\"cudaDeviceReset\"]())\n\n    def cudaMalloc(self, size: int) -> ctypes.c_void_p:\n        devPtr = ctypes.c_void_p()\n        self.CUDART_CHECK(self.funcs[\"cudaMalloc\"](ctypes.byref(devPtr), size))\n        return devPtr\n\n    def cudaFree(self, devPtr: ctypes.c_void_p) -> None:\n        self.CUDART_CHECK(self.funcs[\"cudaFree\"](devPtr))\n\n    def cudaMemset(self, devPtr: ctypes.c_void_p, value: int,\n                   count: int) -> None:\n        self.CUDART_CHECK(self.funcs[\"cudaMemset\"](devPtr, value, count))\n\n    def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p,\n                   count: int) -> None:\n        cudaMemcpyDefault = 4\n        kind = cudaMemcpyDefault\n        self.CUDART_CHECK(self.funcs[\"cudaMemcpy\"](dst, src, count, kind))\n\n    def cudaIpcGetMemHandle(self,\n                            devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:\n        handle = cudaIpcMemHandle_t()\n        self.CUDART_CHECK(self.funcs[\"cudaIpcGetMemHandle\"](\n            ctypes.byref(handle), devPtr))\n        return handle\n\n    def cudaIpcOpenMemHandle(self,\n                             handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:\n        cudaIpcMemLazyEnablePeerAccess = 1\n        devPtr = ctypes.c_void_p()\n        self.CUDART_CHECK(self.funcs[\"cudaIpcOpenMemHandle\"](\n            ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess))\n        return devPtr\n"
  },
  {
    "path": "kt-sft/ktransformers/server/balance_serve/inference/distributed/custom_all_reduce.py",
    "content": "import ctypes\nfrom contextlib import contextmanager\nfrom typing import List, Optional, Union\n\nimport torch\nimport torch.distributed as dist\nfrom torch.distributed import ProcessGroup\n\nimport server.envs as envs\nfrom server.inference.distributed.cuda_wrapper import CudaRTLibrary\nfrom server.inference.distributed.custom_all_reduce_utils import gpu_p2p_access_check\nfrom server.inference.distributed.parallel_state import in_the_same_node_as\nfrom server.inference.platforms import current_platform\nfrom server.utils import cuda_device_count_stateless\nimport vLLMCustomAllreduce\n\ntry:\n    vLLMCustomAllreduce.meta_size()\n    custom_ar = True\nexcept Exception:\n    # For AMD GPUs and CPUs\n    custom_ar = False\n\n\ndef _can_p2p(rank: int, world_size: int) -> bool:\n    for i in range(world_size):\n        if i == rank:\n            continue\n        if envs.VLLM_SKIP_P2P_CHECK:\n            print(\"Skipping P2P check and trusting the driver's P2P report.\")\n            return torch.cuda.can_device_access_peer(rank, i)\n        if not gpu_p2p_access_check(rank, i):\n            return False\n    return True\n\n\ndef is_weak_contiguous(inp: torch.Tensor):\n    return inp.is_contiguous() or (\n        inp.storage().nbytes() - inp.storage_offset() * inp.element_size()\n        == inp.numel() * inp.element_size()\n    )\n\n\nclass CustomAllreduce:\n\n    _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]\n\n    # max_size: max supported allreduce size\n    def __init__(\n        self,\n        group: ProcessGroup,\n        device: Union[int, str, torch.device],\n        max_size=8192 * 1024,\n    ) -> None:\n        \"\"\"\n        Args:\n            group: the process group to work on. If None, it will use the\n                default process group.\n            device: the device to bind the CustomAllreduce to. 
If None,\n                it will be bound to f\"cuda:{local_rank}\".\n        It is the caller's responsibility to make sure each communicator\n        is bound to a unique device, and all communicators in this group\n        are in the same node.\n        \"\"\"\n        self._IS_CAPTURING = False\n        self.disabled = True\n\n        if not custom_ar:\n            # disable because of missing custom allreduce library\n            # e.g. in a non-cuda environment\n            return\n\n        self.group = group\n\n        assert (\n            dist.get_backend(group) != dist.Backend.NCCL\n        ), \"CustomAllreduce should be attached to a non-NCCL group.\"\n\n        if not all(in_the_same_node_as(group, source_rank=0)):\n            # No need to initialize custom allreduce for multi-node case.\n            print(\n                \"Custom allreduce is disabled because this process group\"\n                \" spans across nodes.\"\n            )\n            return\n\n        rank = dist.get_rank(group=self.group)\n        world_size = dist.get_world_size(group=self.group)\n        if world_size == 1:\n            # No need to initialize custom allreduce for single GPU case.\n            return\n\n        if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:\n            # print() does not apply %-style arguments, so format the message explicitly\n            print(\n                \"Custom allreduce is disabled due to an unsupported world\"\n                f\" size: {world_size}. Supported world sizes: \"\n                f\"{CustomAllreduce._SUPPORTED_WORLD_SIZES}. 
To silence this \"\n                \"warning, specify disable_custom_all_reduce=True explicitly.\"\n            )\n            return\n\n        if isinstance(device, int):\n            device = torch.device(f\"cuda:{device}\")\n        elif isinstance(device, str):\n            device = torch.device(device)\n        # now `device` is a `torch.device` object\n        assert isinstance(device, torch.device)\n        self.device = device\n\n        cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES\n        if cuda_visible_devices:\n            device_ids = list(map(int, cuda_visible_devices.split(\",\")))\n        else:\n            device_ids = list(range(cuda_device_count_stateless()))\n\n        physical_device_id = device_ids[device.index]\n        tensor = torch.tensor([physical_device_id], dtype=torch.int, device=\"cpu\")\n        gather_list = [\n            torch.tensor([0], dtype=torch.int, device=\"cpu\") for _ in range(world_size)\n        ]\n        dist.all_gather(gather_list, tensor, group=self.group)\n        physical_device_ids = [t.item() for t in gather_list]\n\n        # test nvlink first, this will filter out most of the cases\n        # where custom allreduce is not supported\n        # this checks hardware and driver support for NVLink\n        assert current_platform.is_cuda()\n        from server.inference.platforms.cuda import CudaPlatform\n\n        cuda_platform: CudaPlatform = current_platform\n        full_nvlink = cuda_platform.is_full_nvlink(physical_device_ids)\n        if world_size > 2 and not full_nvlink:\n            print(\n                \"Custom allreduce is disabled because it's not supported on\"\n                \" more than two PCIe-only GPUs. 
To silence this warning, \"\n                \"specify disable_custom_all_reduce=True explicitly.\"\n            )\n            return\n        # test P2P capability, this checks software/cudaruntime support\n        # this is expensive to compute at the first time\n        # then we cache the result\n        if not _can_p2p(rank, world_size):\n            print(\n                \"Custom allreduce is disabled because your platform lacks \"\n                \"GPU P2P capability or P2P test failed. To silence this \"\n                \"warning, specify disable_custom_all_reduce=True explicitly.\"\n            )\n            return\n\n        self.disabled = False\n        # Buffers memory are owned by this Python class and passed to C++.\n        # Meta data composes of two parts: meta data for synchronization and a\n        # temporary buffer for storing intermediate allreduce results.\n        self.meta_ptrs = self.create_shared_buffer(\n            vLLMCustomAllreduce.meta_size() + max_size, group=group\n        )\n        # This is a pre-registered IPC buffer. In eager mode, input tensors\n        # are first copied into this buffer before allreduce is performed\n        self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)\n        # This is a buffer for storing the tuples of pointers pointing to\n        # IPC buffers from all ranks. Each registered tuple has size of\n        # 8*world_size bytes where world_size is at most 8. Allocating 8MB\n        # is enough for 131072 such tuples. 
The largest model I've seen only\n        # needs less than 10000 of registered tuples.\n        self.rank_data = torch.empty(\n            8 * 1024 * 1024, dtype=torch.uint8, device=self.device\n        )\n        self.max_size = max_size\n        self.rank = rank\n        self.world_size = world_size\n        self.full_nvlink = full_nvlink\n        self._ptr = vLLMCustomAllreduce.init_custom_ar(\n            self.meta_ptrs, self.rank_data, rank, self.full_nvlink\n        )\n        vLLMCustomAllreduce.register_buffer(self._ptr, self.buffer_ptrs)\n\n    @staticmethod\n    def create_shared_buffer(\n        size_in_bytes: int, group: Optional[ProcessGroup] = None\n    ) -> List[int]:\n        \"\"\"\n        Creates a shared buffer and returns a list of pointers\n        representing the buffer on all processes in the group.\n        \"\"\"\n        lib = CudaRTLibrary()\n        pointer = lib.cudaMalloc(size_in_bytes)\n        handle = lib.cudaIpcGetMemHandle(pointer)\n        world_size = dist.get_world_size(group=group)\n        rank = dist.get_rank(group=group)\n        handles = [None] * world_size\n        dist.all_gather_object(handles, handle, group=group)\n\n        pointers: List[int] = []\n        for i, h in enumerate(handles):\n            if i == rank:\n                pointers.append(pointer.value)  # type: ignore\n            else:\n                pointers.append(lib.cudaIpcOpenMemHandle(h).value)  # type: ignore\n\n        return pointers\n\n    @staticmethod\n    def free_shared_buffer(\n        pointers: List[int], group: Optional[ProcessGroup] = None\n    ) -> None:\n        rank = dist.get_rank(group=group)\n        lib = CudaRTLibrary()\n        lib.cudaFree(ctypes.c_void_p(pointers[rank]))\n\n    @contextmanager\n    def capture(self):\n        \"\"\"\n        The main responsibility of this context manager is the\n        `register_graph_buffers` call at the end of the context.\n        It records all the buffer addresses used in the CUDA 
graph.\n        \"\"\"\n        try:\n            self._IS_CAPTURING = True\n            yield\n        finally:\n            self._IS_CAPTURING = False\n            if not self.disabled:\n                self.register_graph_buffers()\n\n    def register_graph_buffers(self):\n        handle, offset = vLLMCustomAllreduce.get_graph_buffer_ipc_meta(self._ptr)\n        print(\"Registering %d cuda graph addresses\" % len(offset))\n        # We cannot directly use `dist.all_gather_object` here\n        # because it is incompatible with `gloo` backend under inference mode.\n        # see https://github.com/pytorch/pytorch/issues/126032 for details.\n        all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]\n        all_data[self.rank] = [handle, offset]\n        ranks = sorted(dist.get_process_group_ranks(group=self.group))\n        for i, rank in enumerate(ranks):\n            dist.broadcast_object_list(\n                all_data[i], src=rank, group=self.group, device=\"cpu\"\n            )\n        # Unpack list of tuples to tuple of lists.\n        handles = [d[0] for d in all_data]  # type: ignore\n        offsets = [d[1] for d in all_data]  # type: ignore\n        vLLMCustomAllreduce.register_graph_buffers(self._ptr, handles, offsets)\n\n    def should_custom_ar(self, inp: torch.Tensor):\n        if self.disabled:\n            return False\n        inp_size = inp.numel() * inp.element_size()\n        # custom allreduce requires input byte size to be multiples of 16\n        if inp_size % 16 != 0:\n            return False\n        if not is_weak_contiguous(inp):\n            return False\n        # for 4 or more non NVLink-capable GPUs, custom allreduce provides\n        # little performance improvement over NCCL.\n        if self.world_size == 2 or self.full_nvlink:\n            return inp_size < self.max_size\n        return False\n\n    def all_reduce(\n        self, inp: torch.Tensor, *, out: torch.Tensor = None, bsz_tensor: 
torch.Tensor = None, registered: bool = False,\n        is_compute_bound=False, overlap=False\n    ):\n        \"\"\"Performs an out-of-place all reduce.\n\n        If registered is True, this assumes inp's pointer is already\n        IPC-registered. Otherwise, inp is first copied into a pre-registered\n        buffer.\n        \"\"\"\n        if is_compute_bound:\n            sms = 2 if overlap else 36\n        else:\n            sms = 20 if overlap else 36\n        #print(\"all reduce sms\", sms)\n        if out is None:\n            out = torch.empty_like(inp)\n        if registered:\n            vLLMCustomAllreduce.all_reduce(self._ptr, inp, out, 0, 0, bsz_tensor, block_limit=sms)\n        else:\n            vLLMCustomAllreduce.all_reduce(\n                self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size, bsz_tensor, block_limit=sms\n            )\n        return out\n\n    def custom_all_reduce(self, input: torch.Tensor, bsz_tensor: torch.Tensor, is_compute_bound=False, overlap=False) -> Optional[torch.Tensor]:\n        \"\"\"The main allreduce API that provides support for cuda graph.\"\"\"\n        # When custom allreduce is disabled, this will be None.\n        if self.disabled or not self.should_custom_ar(input):\n            return None\n        if self._IS_CAPTURING:\n            if torch.cuda.is_current_stream_capturing():\n                return self.all_reduce(input, bsz_tensor=bsz_tensor, registered=True, is_compute_bound=is_compute_bound, overlap=overlap)\n            else:\n                # If warm up, mimic the allocation pattern since custom\n                # allreduce is out-of-place.\n                return torch.empty_like(input)\n        else:\n            # Note: outside of cuda graph context, custom allreduce incurs a\n            # cost of cudaMemcpy, which should be small (<=1% of overall\n            # latency) compared to the performance gain of using custom kernels\n            return self.all_reduce(input, 
bsz_tensor=bsz_tensor, registered=False, is_compute_bound=is_compute_bound, overlap=overlap)\n\n    def close(self):\n        if not self.disabled and self._ptr:\n            vLLMCustomAllreduce.dispose(self._ptr)\n            self._ptr = 0\n            self.free_shared_buffer(self.meta_ptrs)\n            self.free_shared_buffer(self.buffer_ptrs)\n\n    def __del__(self):\n        self.close()\n"
  },
  {
    "path": "kt-sft/ktransformers/server/balance_serve/inference/distributed/custom_all_reduce_utils.py",
    "content": "import ctypes\nimport json\nimport os\nimport pickle\nimport subprocess\nimport sys\nimport tempfile\nfrom itertools import product\nfrom typing import Dict, List, Optional, Sequence\n\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\n\nimport server.envs as envs\nfrom server.inference.distributed.cuda_wrapper import CudaRTLibrary\nfrom server.utils import cuda_device_count_stateless, update_environment_variables\n\n\ndef producer(\n    batch_src: Sequence[int],\n    producer_queue,\n    consumer_queue,\n    result_queue,\n    cuda_visible_devices: Optional[str] = None,\n):\n    if cuda_visible_devices is not None:\n        update_environment_variables({\"CUDA_VISIBLE_DEVICES\": cuda_visible_devices})\n\n    lib = CudaRTLibrary()\n    for i in batch_src:\n        lib.cudaSetDevice(i)\n        pointer = lib.cudaMalloc(1024)\n        lib.cudaMemset(pointer, 1, 1024)\n        lib.cudaDeviceSynchronize()\n        handle = lib.cudaIpcGetMemHandle(pointer)\n        producer_queue.put(handle)\n        open_success = consumer_queue.get()\n        if open_success:\n            # use two queues to simulate barrier\n            producer_queue.put(0)\n            consumer_queue.get()\n            # check if the memory is modified\n            host_data = (ctypes.c_char * 1024)()\n            lib.cudaMemcpy(host_data, pointer, 1024)  # type: ignore\n            for i in range(1024):\n                if ord(host_data[i]) != 2:\n                    open_success = False\n                    break\n        result_queue.put(open_success)\n        lib.cudaDeviceReset()\n\n\ndef consumer(\n    batch_tgt: Sequence[int],\n    producer_queue,\n    consumer_queue,\n    result_queue,\n    cuda_visible_devices: Optional[str] = None,\n):\n    if cuda_visible_devices is not None:\n        update_environment_variables({\"CUDA_VISIBLE_DEVICES\": cuda_visible_devices})\n\n    lib = CudaRTLibrary()\n    for j in batch_tgt:\n        lib.cudaSetDevice(j)\n       
 handle = producer_queue.get()\n        open_success = False\n        try:\n            pointer = lib.cudaIpcOpenMemHandle(handle)  # type: ignore\n            open_success = True\n        except RuntimeError:\n            # cannot error out here, because the producer process\n            # is still waiting for the response.\n            pass\n        consumer_queue.put(open_success)\n        if open_success:\n            # modify the memory\n            lib.cudaMemset(pointer, 2, 1024)\n            lib.cudaDeviceSynchronize()\n            # use two queues to simulate barrier\n            producer_queue.get()\n            consumer_queue.put(0)\n            # check if the memory is modified\n            host_data = (ctypes.c_char * 1024)()\n            lib.cudaMemcpy(host_data, pointer, 1024)  # type: ignore\n            for i in range(1024):\n                if ord(host_data[i]) != 2:\n                    open_success = False\n                    break\n        result_queue.put(open_success)\n        lib.cudaDeviceReset()\n\n\ndef can_actually_p2p(\n    batch_src: Sequence[int],\n    batch_tgt: Sequence[int],\n) -> Sequence[bool]:\n    \"\"\"\n    Usually, checking if P2P access is enabled can be done by\n    `torch.cuda.can_device_access_peer(src, tgt)`. 
However, sometimes\n    the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)`\n    returns `True` even if P2P access is not actually possible.\n    See https://github.com/vllm-project/vllm/issues/2728 and\n    https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10\n    Therefore, we have to perform a real P2P access to check if it is actually\n    possible.\n\n    Note on p2p and cuda IPC:\n    Usually, one process uses one GPU:\n    GPU src --> cuda context src --> tensor src --> process src\n\n    We need to combine p2p and cuda IPC, so that:\n    GPU src --> cuda context src --> tensor src --> process src\n                                      |shared|\n    GPU tgt --> cuda context tgt --> tensor tgt --> process tgt\n    That is to say, process src creates a tensor in GPU src, passes IPC handle to\n    process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the\n    tensor in process tgt will be reflected in the tensor in process src, because\n    they are the same memory segment.\n    It is important to note that process tgt accesses the tensor in GPU tgt, not\n    GPU src. That's why we need p2p access.\n\n    The most time-consuming part is the process creation. To avoid creating\n    processes for every pair of GPUs, we use batched testing. We create two\n    processes for testing all pairs of GPUs in batch. 
The trick is to reset\n    the device after each test (which is not available in PyTorch).\n    \"\"\"  # noqa\n    cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES\n    # pass the CUDA_VISIBLE_DEVICES to the child process\n    # to make sure they see the same set of GPUs\n\n    # make sure the processes are spawned\n    smp = mp.get_context(\"spawn\")\n    producer_queue = smp.Queue()\n    consumer_queue = smp.Queue()\n    result_queue = smp.Queue()\n    p_src = smp.Process(\n        target=producer,\n        args=(\n            batch_src,\n            producer_queue,\n            consumer_queue,\n            result_queue,\n            cuda_visible_devices,\n        ),\n    )\n    p_tgt = smp.Process(\n        target=consumer,\n        args=(\n            batch_tgt,\n            producer_queue,\n            consumer_queue,\n            result_queue,\n            cuda_visible_devices,\n        ),\n    )\n    p_src.start()\n    p_tgt.start()\n    p_src.join()\n    p_tgt.join()\n    assert p_src.exitcode == 0 and p_tgt.exitcode == 0\n    result: List[bool] = []\n    for src, tgt in zip(batch_src, batch_tgt):\n        a = result_queue.get()\n        b = result_queue.get()\n        if a != b:\n            print(\n                \"Two processes do not agree on the P2P access\"\n                \" status on %d -> %d, treat as disabled.\" % (src, tgt)\n            )\n            result.append(False)\n        else:\n            result.append(a)\n    return result\n\n\n# why do we need this cache?\n# we are testing peer-to-peer (p2p) access between GPUs, across processes.\n# if we test it every time, it will be very slow, because we need to create\n#  N * N * 2 processes, where N is the world size. 
This is very slow.\n# to reduce the time, we use a cache file to store the p2p access status.\n# the cache file is generated by the master process if it does not exist.\n# then all the processes can read the cache file to check the p2p access status.\n# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we\n#  can have different cache files for different CUDA_VISIBLE_DEVICES settings,\n#  e.g. used by different vllm engines. The device id in the cache file is a\n#  **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number\n#  of visible devices in the vllm engine.\n_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None\n\n\ndef gpu_p2p_access_check(src: int, tgt: int) -> bool:\n    \"\"\"Check if GPU src can access GPU tgt.\"\"\"\n\n    # if the cache variable is already calculated,\n    # read from the cache instead of checking it again\n    global _gpu_p2p_access_cache\n    if _gpu_p2p_access_cache is not None:\n        return _gpu_p2p_access_cache[f\"{src}->{tgt}\"]\n\n    is_distributed = dist.is_initialized()\n\n    num_dev = cuda_device_count_stateless()\n    cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES\n    if cuda_visible_devices is None:\n        cuda_visible_devices = \",\".join(str(i) for i in range(num_dev))\n\n    path = os.path.join(\n        envs.VLLM_CACHE_ROOT, f\"gpu_p2p_access_cache_for_{cuda_visible_devices}.json\"\n    )\n    os.makedirs(os.path.dirname(path), exist_ok=True)\n    from server.inference.distributed.parallel_state import get_world_group\n\n    if (not is_distributed or get_world_group().local_rank == 0) and (\n        not os.path.exists(path)\n    ):\n        # only the local master process (with local_rank == 0) can\n        #  enter this block to calculate the cache\n        print(\"generating GPU P2P access cache in %s\" % path)\n        cache: Dict[str, bool] = {}\n        ids = list(range(num_dev))\n        # batch of all pairs of GPUs\n        batch_src, batch_tgt = 
zip(*list(product(ids, ids)))\n        # NOTE: we use `subprocess` rather than `multiprocessing` here\n        # because the caller might not have `if __name__ == \"__main__\":`,\n        # in that case we cannot use spawn method in multiprocessing.\n        # However, `can_actually_p2p` requires spawn method.\n        # The fix is, we use `subprocess` to call the function,\n        # where we have `if __name__ == \"__main__\":` in this file.\n\n        # use a temporary file to store the result\n        # we don't use the output of the subprocess directly,\n        # because the subprocess might produce logging output\n        with tempfile.NamedTemporaryFile() as output_file:\n            input_bytes = pickle.dumps((batch_src, batch_tgt, output_file.name))\n            returned = subprocess.run(\n                [sys.executable, __file__], input=input_bytes, capture_output=True\n            )\n            # check if the subprocess is successful\n            try:\n                returned.check_returncode()\n            except Exception as e:\n                # wrap raised exception to provide more information\n                raise RuntimeError(\n                    f\"Error happened when batch testing \"\n                    f\"peer-to-peer access from {batch_src} to {batch_tgt}:\\n\"\n                    f\"{returned.stderr.decode()}\"\n                ) from e\n            with open(output_file.name, \"rb\") as f:\n                result = pickle.load(f)\n        for _i, _j, r in zip(batch_src, batch_tgt, result):\n            cache[f\"{_i}->{_j}\"] = r\n        with open(path, \"w\") as f:\n            json.dump(cache, f, indent=4)\n    if is_distributed:\n        get_world_group().barrier()\n    print(\"reading GPU P2P access cache from %s\" % path)\n    with open(path) as f:\n        cache = json.load(f)\n    _gpu_p2p_access_cache = cache\n    return _gpu_p2p_access_cache[f\"{src}->{tgt}\"]\n\n\n__all__ = [\"gpu_p2p_access_check\"]\n\nif __name__ == 
\"__main__\":\n    batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read())\n    result = can_actually_p2p(batch_src, batch_tgt)\n    with open(output_file, \"wb\") as f:\n        f.write(pickle.dumps(result))\n"
  },
  {
    "path": "kt-sft/ktransformers/server/balance_serve/inference/distributed/parallel_state.py",
    "content": "# Copyright 2023 The vLLM team.\n# Adapted from\n# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py\n# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.\n\"\"\"vLLM distributed state.\nIt takes over the control of the distributed environment from PyTorch.\nThe typical workflow is:\n\n- call `init_distributed_environment` to initialize the distributed environment.\n- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to\n initialize the model parallel groups.\n\n- any code dealing with the distributed stuff\n\n- call `destroy_model_parallel` to destroy the model parallel groups.\n- call `destroy_distributed_environment` to destroy the distributed environment.\n\nIf you only need to use the distributed environment without model/pipeline\n parallelism, you can skip the model parallel initialization and destruction\n steps.\n\"\"\"\nimport contextlib\nimport gc\nimport pickle\nimport weakref\nfrom collections import namedtuple\nfrom contextlib import contextmanager, nullcontext\nfrom dataclasses import dataclass\nfrom multiprocessing import shared_memory\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\nfrom unittest.mock import patch\n\nimport torch\nimport torch.distributed\nfrom torch.distributed import Backend, ProcessGroup\n\nimport server.envs as envs\nfrom server.inference.platforms import current_platform\nfrom server.utils import direct_register_custom_op, supports_custom_op\n\n\n@dataclass\nclass GraphCaptureContext:\n    stream: torch.cuda.Stream\n\n\nTensorMetadata = namedtuple(\"TensorMetadata\", [\"device\", \"dtype\", \"size\"])\n\n\ndef _split_tensor_dict(\n    tensor_dict: Dict[str, Union[torch.Tensor, Any]]\n) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:\n    \"\"\"Split the tensor dictionary into two parts:\n    1. A list of (key, value) pairs. If the value is a tensor, it is replaced\n         by its metadata.\n    2. 
A list of tensors.\n    \"\"\"\n    metadata_list: List[Tuple[str, Any]] = []\n    tensor_list: List[torch.Tensor] = []\n    for key, value in tensor_dict.items():\n        if isinstance(value, torch.Tensor):\n            # Note: we cannot use `value.device` here,\n            # because it contains not only the device type but also the device\n            # index (e.g. \"cuda:0\"). We only need the device type.\n            # receiving side will set the device index.\n            device = value.device.type\n            metadata_list.append(\n                (key, TensorMetadata(device, value.dtype, value.size()))\n            )\n            tensor_list.append(value)\n        else:\n            metadata_list.append((key, value))\n    return metadata_list, tensor_list\n\n\n_group_name_counter: Dict[str, int] = {}\n\n\ndef _get_unique_name(name: str) -> str:\n    \"\"\"Get a unique name for the group.\n    Example:\n    _get_unique_name(\"tp\") -> \"tp:0\"\n    _get_unique_name(\"tp\") -> \"tp:1\"\n    \"\"\"\n    if name not in _group_name_counter:\n        _group_name_counter[name] = 0\n    newname = f\"{name}:{_group_name_counter[name]}\"\n    _group_name_counter[name] += 1\n    return newname\n\n\n_groups: Dict[str, Callable[[], Optional[\"GroupCoordinator\"]]] = {}\n\n\ndef _register_group(group: \"GroupCoordinator\") -> None:\n    _groups[group.unique_name] = weakref.ref(group)\n\n\nif supports_custom_op():\n\n    def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:\n        assert group_name in _groups, f\"Group {group_name} is not found.\"\n        group = _groups[group_name]()\n        if group is None:\n            raise ValueError(f\"Group {group_name} is destroyed.\")\n        group._all_reduce_in_place(tensor)\n\n    def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None:\n        return\n\n    direct_register_custom_op(\n        op_name=\"inplace_all_reduce\",\n        op_func=inplace_all_reduce,\n        
mutates_args=[\"tensor\"],\n        fake_impl=inplace_all_reduce_fake,\n    )\n\n    def outplace_all_reduce(tensor: torch.Tensor, group_name: str, bsz_tensor: torch.Tensor, is_compute_bound: bool = False, overlap: bool = False) -> torch.Tensor:\n        assert group_name in _groups, f\"Group {group_name} is not found.\"\n        group = _groups[group_name]()\n        if group is None:\n            raise ValueError(f\"Group {group_name} is destroyed.\")\n        return group._all_reduce_out_place(tensor, bsz_tensor, is_compute_bound=is_compute_bound, overlap=overlap)\n\n    def outplace_all_reduce_fake(tensor: torch.Tensor, group_name: str, bsz_tensor: torch.Tensor, is_compute_bound: bool = False, overlap: bool = False) -> torch.Tensor:\n        return torch.empty_like(tensor)\n\n    direct_register_custom_op(\n        op_name=\"outplace_all_reduce\",\n        op_func=outplace_all_reduce,\n        mutates_args=[],\n        fake_impl=outplace_all_reduce_fake,\n    )\n\n\nclass GroupCoordinator:\n    \"\"\"\n    PyTorch ProcessGroup wrapper for a group of processes.\n    PyTorch ProcessGroup is bound to one specific communication backend,\n        e.g. NCCL, Gloo, MPI, etc.\n    GroupCoordinator takes charge of all the communication operations among\n        the processes in the group. It can route the communication to\n        a specific implementation (e.g. 
switch allreduce implementation\n        based on the tensor size and cuda graph mode).\n    \"\"\"\n\n    # available attributes:\n    rank: int  # global rank\n    ranks: List[int]  # global ranks in the group\n    world_size: int  # size of the group\n    # difference between `local_rank` and `rank_in_group`:\n    # if we have a group of size 4 across two nodes:\n    # Process | Node | Rank | Local Rank | Rank in Group\n    #   0     |   0  |  0   |     0      |       0\n    #   1     |   0  |  1   |     1      |       1\n    #   2     |   1  |  2   |     0      |       2\n    #   3     |   1  |  3   |     1      |       3\n    local_rank: int  # local rank used to assign devices\n    rank_in_group: int  # rank inside the group\n    cpu_group: ProcessGroup  # group for CPU communication\n    device_group: ProcessGroup  # group for device communication\n    use_pynccl: bool  # a hint of whether to use PyNccl\n    use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce\n    # communicators are only created for world size > 1\n    pynccl_comm: Optional[Any]  # PyNccl communicator\n    ca_comm: Optional[Any]  # Custom allreduce communicator\n    mq_broadcaster: Optional[Any]  # shared memory broadcaster\n\n    def __init__(\n        self,\n        group_ranks: List[List[int]],\n        local_rank: int,\n        torch_distributed_backend: Union[str, Backend],\n        use_pynccl: bool,\n        use_custom_allreduce: bool,\n        use_tpu_communicator: bool,\n        use_hpu_communicator: bool,\n        use_xpu_communicator: bool,\n        use_message_queue_broadcaster: bool = False,\n        group_name: Optional[str] = None,\n    ):\n        group_name = group_name or \"anonymous\"\n        self.unique_name = _get_unique_name(group_name)\n        _register_group(self)\n\n        self.rank = torch.distributed.get_rank()\n        self.local_rank = local_rank\n        self.device_group = None\n        self.cpu_group = None\n\n        for ranks in 
group_ranks:\n            device_group = torch.distributed.new_group(\n                ranks, backend=torch_distributed_backend\n            )\n            # a group with `gloo` backend, to allow direct coordination between\n            # processes through the CPU.\n            cpu_group = torch.distributed.new_group(ranks, backend=\"gloo\")\n            if self.rank in ranks:\n                self.ranks = ranks\n                self.world_size = len(ranks)\n                self.rank_in_group = ranks.index(self.rank)\n                self.device_group = device_group\n                self.cpu_group = cpu_group\n\n        assert self.cpu_group is not None\n        assert self.device_group is not None\n        assert current_platform.is_cuda_alike()\n\n        if current_platform.is_cuda_alike():\n            self.device = torch.device(f\"cuda:{local_rank}\")\n        else:\n            self.device = torch.device(\"cpu\")\n\n        self.use_pynccl = use_pynccl\n        self.use_custom_allreduce = use_custom_allreduce\n        self.use_tpu_communicator = use_tpu_communicator\n        self.use_hpu_communicator = use_hpu_communicator\n        self.use_xpu_communicator = use_xpu_communicator\n\n        # lazy import to avoid documentation build error\n        from server.inference.distributed.custom_all_reduce import CustomAllreduce\n        from server.inference.distributed.pynccl import PyNcclCommunicator\n\n        self.pynccl_comm: Optional[PyNcclCommunicator] = None\n        # if use_pynccl and self.world_size > 1:\n        #     self.pynccl_comm = PyNcclCommunicator(\n        #         group=self.cpu_group,\n        #         device=self.device,\n        #     )\n\n        self.ca_comm: Optional[CustomAllreduce] = None\n        if use_custom_allreduce and self.world_size > 1:\n            # Initialize a custom fast all-reduce implementation.\n            self.ca_comm = CustomAllreduce(\n                group=self.cpu_group,\n                device=self.device,\n    
        )\n\n        #### we assume we won't use tpu or hpu or xpu or messagequeue broadcast\n\n        # from vllm.distributed.device_communicators.tpu_communicator import (\n        #     TpuCommunicator)\n        # self.tpu_communicator: Optional[TpuCommunicator] = None\n        # if use_tpu_communicator and self.world_size > 1:\n        #     self.tpu_communicator = TpuCommunicator(group=self.cpu_group)\n        self.tpu_communicator = None\n\n        # from vllm.distributed.device_communicators.hpu_communicator import (\n        #     HpuCommunicator)\n        # self.hpu_communicator: Optional[HpuCommunicator]\n        # if use_hpu_communicator and self.world_size > 1:\n        #     self.hpu_communicator = HpuCommunicator(group=self.device_group)\n        self.hpu_communicator = None\n\n        # from vllm.distributed.device_communicators.xpu_communicator import (\n        #     XpuCommunicator)\n        # self.xpu_communicator: Optional[XpuCommunicator]\n        # if use_xpu_communicator and self.world_size > 1:\n        #     self.xpu_communicator = XpuCommunicator(group=self.device_group)\n        self.xpu_communicator = None\n\n        # from vllm.distributed.device_communicators.shm_broadcast import (\n        #     MessageQueue)\n        # self.mq_broadcaster: Optional[MessageQueue] = None\n        # if use_message_queue_broadcaster and self.world_size > 1:\n        #     self.mq_broadcaster = MessageQueue.create_from_process_group(\n        #         self.cpu_group, 1 << 22, 6)\n        self.mq_broadcaster = None\n\n    @property\n    def first_rank(self):\n        \"\"\"Return the global rank of the first process in the group\"\"\"\n        return self.ranks[0]\n\n    @property\n    def last_rank(self):\n        \"\"\"Return the global rank of the last process in the group\"\"\"\n        return self.ranks[-1]\n\n    @property\n    def is_first_rank(self):\n        \"\"\"Return whether the caller is the first process in the group\"\"\"\n        return 
self.rank == self.first_rank\n\n    @property\n    def is_last_rank(self):\n        \"\"\"Return whether the caller is the last process in the group\"\"\"\n        return self.rank == self.last_rank\n\n    @property\n    def next_rank(self):\n        \"\"\"Return the global rank of the process that follows the caller\"\"\"\n        rank_in_group = self.rank_in_group\n        world_size = self.world_size\n        return self.ranks[(rank_in_group + 1) % world_size]\n\n    @property\n    def prev_rank(self):\n        \"\"\"Return the global rank of the process that precedes the caller\"\"\"\n        rank_in_group = self.rank_in_group\n        world_size = self.world_size\n        return self.ranks[(rank_in_group - 1) % world_size]\n\n    @contextmanager\n    def graph_capture(\n        self, graph_capture_context: Optional[GraphCaptureContext] = None\n    ):\n        if graph_capture_context is None:\n            stream = torch.cuda.Stream()\n            graph_capture_context = GraphCaptureContext(stream)\n        else:\n            stream = graph_capture_context.stream\n\n        ca_comm = self.ca_comm\n        maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture()\n\n        # ensure all initialization operations complete before attempting to\n        # capture the graph on another stream\n        curr_stream = torch.cuda.current_stream()\n        if curr_stream != stream:\n            stream.wait_stream(curr_stream)\n\n        with torch.cuda.stream(stream), maybe_ca_context:\n            # In graph mode, we have to be very careful about the collective\n            # operations. 
The current status is:\n            #     allreduce \\ Mode   |  Eager  |  Graph  |\n            # --------------------------------------------\n            # custom allreduce       | enabled | enabled |\n            # PyNccl                 | disabled| enabled |\n            # torch.distributed      | enabled | disabled|\n            #\n            # Note that custom allreduce will have a runtime check, if the\n            #  tensor size is too large, it will fallback to the next\n            #  available option.\n            # In summary: When using CUDA graph, we use\n            #  either custom all-reduce kernel or pynccl. When not using\n            #  CUDA graph, we use either custom all-reduce kernel or\n            #  PyTorch NCCL. We always prioritize using custom all-reduce\n            #  kernel but fall back to PyTorch or pynccl if it is\n            #  disabled or not supported.\n            pynccl_comm = self.pynccl_comm\n            maybe_pynccl_context: Any\n            if not pynccl_comm:\n                maybe_pynccl_context = nullcontext()\n            else:\n                maybe_pynccl_context = pynccl_comm.change_state(\n                    enable=True, stream=torch.cuda.current_stream()\n                )\n            with maybe_pynccl_context:\n                yield graph_capture_context\n\n    def all_reduce(self, input_: torch.Tensor, bsz_tensor: torch.Tensor, is_compute_bound=False, overlap=False) -> torch.Tensor:\n        \"\"\"\n        User-facing all-reduce function before we actually call the\n        all-reduce operation.\n\n        We need this because Dynamo does not support passing an arbitrary\n        object (`self` in this case) to a custom op. 
We need to pass the\n         group name as a string, and then look up the group coordinator from\n         the group name, dispatch the all-reduce operation to the group\n         coordinator.\n\n        In addition, PyTorch custom ops do not support mutation or returning\n        a new tensor in the same op. So we need to figure out if the op is\n        in-place or out-of-place ahead of time.\n        \"\"\"\n        # Bypass the function if we are using only 1 GPU.\n        if self.world_size == 1:\n            return input_\n\n        if input_.is_cpu:\n            import intel_extension_for_pytorch as ipex\n\n            ipex.distributed.all_reduce(input_, group=self.device_group)\n            return input_\n\n        if not supports_custom_op():\n            self._all_reduce_in_place(input_)\n            return input_\n\n        if self.tpu_communicator is not None and not self.tpu_communicator.disabled:\n            # TPU handles Dynamo with its own logic.\n            return self.tpu_communicator.all_reduce(input_)\n\n        if self.hpu_communicator is not None and not self.hpu_communicator.disabled:\n            return self.hpu_communicator.all_reduce(input_)\n\n        if self.xpu_communicator is not None and not self.xpu_communicator.disabled:\n            return self.xpu_communicator.all_reduce(input_)\n\n        if (\n            self.ca_comm is not None\n            and not self.ca_comm.disabled\n            and self.ca_comm.should_custom_ar(input_)\n        ):\n            return torch.ops.vllm.outplace_all_reduce(\n                input_, group_name=self.unique_name, bsz_tensor=bsz_tensor, is_compute_bound=is_compute_bound, overlap=overlap\n            )\n        else:\n            #assert self.ca_comm is not None\n            #assert not self.ca_comm.disabled\n            #assert self.ca_comm.should_custom_ar(input_)\n            torch.ops.vllm.inplace_all_reduce(input_, group_name=self.unique_name)\n            return input_\n\n    def 
_all_reduce_out_place(self, input_: torch.Tensor, bsz_tensor: torch.Tensor, is_compute_bound=False, overlap=False) -> torch.Tensor:\n        ca_comm = self.ca_comm\n        assert ca_comm is not None\n        assert not ca_comm.disabled\n        out = ca_comm.custom_all_reduce(input_, bsz_tensor, is_compute_bound=is_compute_bound, overlap=overlap)\n        assert out is not None\n        return out\n\n    def _all_reduce_in_place(self, input_: torch.Tensor) -> None:\n        pynccl_comm = self.pynccl_comm\n        if pynccl_comm is not None and not pynccl_comm.disabled:\n            pynccl_comm.all_reduce(input_)\n        else:\n            torch.distributed.all_reduce(input_, group=self.device_group)\n\n    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:\n        world_size = self.world_size\n        # Bypass the function if we are using only 1 GPU.\n        if world_size == 1:\n            return input_\n        assert (\n            -input_.dim() <= dim < input_.dim()\n        ), f\"Invalid dim ({dim}) for input tensor with shape {input_.size()}\"\n\n        # For TPUs, use TPU communicator.\n        tpu_comm = self.tpu_communicator\n        if tpu_comm is not None and not tpu_comm.disabled:\n            return tpu_comm.all_gather(input_, dim)\n\n        # For HPUs, use HPU communicator.\n        hpu_comm = self.hpu_communicator\n        if hpu_comm is not None and not hpu_comm.disabled:\n            return hpu_comm.all_gather(input_, dim)\n\n        if dim < 0:\n            # Convert negative dim to positive.\n            dim += input_.dim()\n        input_size = input_.size()\n        # NOTE: we have to use concat-style all-gather here,\n        # stack-style all-gather has compatibility issues with\n        # torch.compile . 
see https://github.com/pytorch/pytorch/issues/138795\n        output_size = (input_size[0] * world_size,) + input_size[1:]\n        # Allocate output tensor.\n        output_tensor = torch.empty(\n            output_size, dtype=input_.dtype, device=input_.device\n        )\n        # All-gather.\n        torch.distributed.all_gather_into_tensor(\n            output_tensor, input_, group=self.device_group\n        )\n        # Reshape\n        output_tensor = output_tensor.reshape((world_size,) + input_size)\n        output_tensor = output_tensor.movedim(0, dim)\n        output_tensor = output_tensor.reshape(\n            input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]\n        )\n        return output_tensor\n\n    def gather(\n        self, input_: torch.Tensor, dst: int = 0, dim: int = -1\n    ) -> Optional[torch.Tensor]:\n        \"\"\"\n        NOTE: We assume that the input tensor is on the same device across\n        all the ranks.\n        NOTE: `dst` is the local rank of the destination rank.\n        \"\"\"\n        world_size = self.world_size\n        # Bypass the function if we are using only 1 GPU.\n        if world_size == 1:\n            return input_\n        assert (\n            -input_.dim() <= dim < input_.dim()\n        ), f\"Invalid dim ({dim}) for input tensor with shape {input_.size()}\"\n        if dim < 0:\n            # Convert negative dim to positive.\n            dim += input_.dim()\n        if self.xpu_communicator is not None and not self.xpu_communicator.disabled:\n            return self.xpu_communicator.gather(input_, self.rank_in_group, dst, dim)\n        # Allocate output tensor.\n        if self.rank_in_group == dst:\n            gather_list = [torch.empty_like(input_) for _ in range(world_size)]\n        else:\n            gather_list = None\n        # Gather.\n        torch.distributed.gather(\n            input_, gather_list, dst=self.ranks[dst], group=self.device_group\n        )\n        if 
self.rank_in_group == dst:\n            output_tensor = torch.cat(gather_list, dim=dim)\n        else:\n            output_tensor = None\n        return output_tensor\n\n    def broadcast(self, input_: torch.Tensor, src: int = 0):\n        \"\"\"Broadcast the input tensor.\n        NOTE: `src` is the local rank of the source rank.\n        \"\"\"\n        assert src < self.world_size, f\"Invalid src rank ({src})\"\n\n        # Bypass the function if we are using only 1 GPU.\n        if self.world_size == 1:\n            return input_\n        # Broadcast.\n        torch.distributed.broadcast(\n            input_, src=self.ranks[src], group=self.device_group\n        )\n        return input_\n\n    def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):\n        \"\"\"Broadcast the input object.\n        NOTE: `src` is the local rank of the source rank.\n        \"\"\"\n        assert src < self.world_size, f\"Invalid src rank ({src})\"\n\n        # Bypass the function if we are using only 1 GPU.\n        if self.world_size == 1:\n            return obj\n        if self.mq_broadcaster is not None:\n            assert src == 0, \"Message queue broadcaster only supports src=0\"\n            return self.mq_broadcaster.broadcast_object(obj)\n        if self.rank_in_group == src:\n            torch.distributed.broadcast_object_list(\n                [obj], src=self.ranks[src], group=self.cpu_group\n            )\n            return obj\n        else:\n            recv = [None]\n            torch.distributed.broadcast_object_list(\n                recv, src=self.ranks[src], group=self.cpu_group\n            )\n            return recv[0]\n\n    def broadcast_object_list(\n        self, obj_list: List[Any], src: int = 0, group: Optional[ProcessGroup] = None\n    ):\n        \"\"\"Broadcast the input object list.\n        NOTE: `src` is the local rank of the source rank.\n        \"\"\"\n        assert src < self.world_size, f\"Invalid src rank ({src})\"\n\n     
   # Bypass the function if we are using only 1 GPU.\n        if self.world_size == 1:\n            return obj_list\n        # Broadcast.\n        torch.distributed.broadcast_object_list(\n            obj_list, src=self.ranks[src], group=self.device_group\n        )\n        return obj_list\n\n    def send_object(self, obj: Any, dst: int) -> None:\n        \"\"\"Send the input object list to the destination rank.\"\"\"\n        \"\"\"NOTE: `dst` is the local rank of the destination rank.\"\"\"\n\n        assert dst < self.world_size, f\"Invalid dst rank ({dst})\"\n\n        assert dst != self.rank_in_group, (\n            \"Invalid destination rank. Destination rank is the same \"\n            \"as the current rank.\"\n        )\n\n        # Serialize object to tensor and get the size as well\n        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)\n\n        size_tensor = torch.tensor(\n            [object_tensor.numel()], dtype=torch.long, device=\"cpu\"\n        )\n\n        # Send object size\n\n        torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)\n\n        # Send object\n        torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group)\n\n        return None\n\n    def recv_object(self, src: int) -> Any:\n        \"\"\"Receive the input object list from the source rank.\"\"\"\n        \"\"\"NOTE: `src` is the local rank of the source rank.\"\"\"\n\n        assert src < self.world_size, f\"Invalid src rank ({src})\"\n\n        assert (\n            src != self.rank_in_group\n        ), \"Invalid source rank. 
Source rank is the same as the current rank.\"\n\n        size_tensor = torch.empty(1, dtype=torch.long, device=\"cpu\")\n\n        # Receive object size\n        rank_size = torch.distributed.recv(\n            size_tensor, src=self.ranks[src], group=self.cpu_group\n        )\n\n        # Tensor to receive serialized objects into.\n        object_tensor = torch.empty(  # type: ignore[call-overload]\n            size_tensor.item(),  # type: ignore[arg-type]\n            dtype=torch.uint8,\n            device=\"cpu\",\n        )\n\n        rank_object = torch.distributed.recv(\n            object_tensor, src=self.ranks[src], group=self.cpu_group\n        )\n\n        assert (\n            rank_object == rank_size\n        ), \"Received object sender rank does not match the size sender rank.\"\n\n        obj = pickle.loads(object_tensor.numpy().tobytes())\n\n        return obj\n\n    def broadcast_tensor_dict(\n        self,\n        tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,\n        src: int = 0,\n        group: Optional[ProcessGroup] = None,\n        metadata_group: Optional[ProcessGroup] = None,\n    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:\n        \"\"\"Broadcast the input tensor dictionary.\n        NOTE: `src` is the local rank of the source rank.\n        \"\"\"\n        # Bypass the function if we are using only 1 GPU.\n        if not torch.distributed.is_initialized() or self.world_size == 1:\n            return tensor_dict\n\n        group = self.device_group\n        metadata_group = self.cpu_group\n        assert src < self.world_size, f\"Invalid src rank ({src})\"\n\n        rank_in_group = self.rank_in_group\n        if rank_in_group == src:\n            metadata_list: List[Tuple[Any, Any]] = []\n            assert isinstance(\n                tensor_dict, dict\n            ), f\"Expecting a dictionary, got {type(tensor_dict)}\"\n            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)\n            # 
`metadata_list` lives in CPU memory.\n            # `broadcast_object_list` has serialization & deserialization,\n            # all happening on CPU. Therefore, we can use the CPU group.\n            self.broadcast_object(metadata_list, src=src)\n            async_handles = []\n            for tensor in tensor_list:\n                if tensor.numel() == 0:\n                    # Skip broadcasting empty tensors.\n                    continue\n                if tensor.is_cpu:\n                    # use metadata_group for CPU tensors\n                    handle = torch.distributed.broadcast(\n                        tensor, src=self.ranks[src], group=metadata_group, async_op=True\n                    )\n                else:\n                    # use group for GPU tensors\n                    handle = torch.distributed.broadcast(\n                        tensor, src=self.ranks[src], group=group, async_op=True\n                    )\n                async_handles.append(handle)\n            for async_handle in async_handles:\n                async_handle.wait()\n\n        else:\n            metadata_list = self.broadcast_object(None, src=src)\n            tensor_dict = {}\n            async_handles = []\n            for key, value in metadata_list:\n                if isinstance(value, TensorMetadata):\n                    tensor = torch.empty(\n                        value.size, dtype=value.dtype, device=value.device\n                    )\n                    if tensor.numel() == 0:\n                        # Skip broadcasting empty tensors.\n                        tensor_dict[key] = tensor\n                        continue\n                    if tensor.is_cpu:\n                        # use metadata_group for CPU tensors\n                        handle = torch.distributed.broadcast(\n                            tensor,\n                            src=self.ranks[src],\n                            group=metadata_group,\n                            
async_op=True,\n                        )\n                    else:\n                        # use group for GPU tensors\n                        handle = torch.distributed.broadcast(\n                            tensor, src=self.ranks[src], group=group, async_op=True\n                        )\n                    async_handles.append(handle)\n                    tensor_dict[key] = tensor\n                else:\n                    tensor_dict[key] = value\n            for async_handle in async_handles:\n                async_handle.wait()\n        return tensor_dict\n\n    def send_tensor_dict(\n        self,\n        tensor_dict: Dict[str, Union[torch.Tensor, Any]],\n        dst: Optional[int] = None,\n        all_gather_group: Optional[\"GroupCoordinator\"] = None,\n    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:\n        \"\"\"Send the input tensor dictionary.\n        NOTE: `dst` is the local rank of the destination rank.\n        \"\"\"\n        # Bypass the function if we are using only 1 GPU.\n        if not torch.distributed.is_initialized() or self.world_size == 1:\n            return tensor_dict\n\n        all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size\n        all_gather_rank = (\n            0 if all_gather_group is None else all_gather_group.rank_in_group\n        )\n\n        group = self.device_group\n        metadata_group = self.cpu_group\n\n        if dst is None:\n            dst = (self.rank_in_group + 1) % self.world_size\n        assert dst < self.world_size, f\"Invalid dst rank ({dst})\"\n\n        metadata_list: List[Tuple[Any, Any]] = []\n        assert isinstance(\n            tensor_dict, dict\n        ), f\"Expecting a dictionary, got {type(tensor_dict)}\"\n        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)\n        # `metadata_list` lives in CPU memory.\n        # `send_object` has serialization & deserialization,\n        # all happening on CPU. 
Therefore, we can use the CPU group.\n        self.send_object(metadata_list, dst=dst)\n        for tensor in tensor_list:\n            if tensor.numel() == 0:\n                # Skip sending empty tensors.\n                continue\n\n            # send-allgather: send only a slice, then do allgather.\n            if all_gather_group is not None and tensor.numel() % all_gather_size == 0:\n                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]\n\n            if tensor.is_cpu:\n                # use metadata_group for CPU tensors\n                torch.distributed.send(\n                    tensor, dst=self.ranks[dst], group=metadata_group\n                )\n            else:\n                # use group for GPU tensors\n                torch.distributed.send(tensor, dst=self.ranks[dst], group=group)\n        return None\n\n    def recv_tensor_dict(\n        self,\n        src: Optional[int] = None,\n        all_gather_group: Optional[\"GroupCoordinator\"] = None,\n    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:\n        \"\"\"Recv the input tensor dictionary.\n        NOTE: `src` is the local rank of the source rank.\n        \"\"\"\n        # Bypass the function if we are using only 1 GPU.\n        if not torch.distributed.is_initialized() or self.world_size == 1:\n            return None\n\n        all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size\n        all_gather_rank = (\n            0 if all_gather_group is None else all_gather_group.rank_in_group\n        )\n\n        group = self.device_group\n        metadata_group = self.cpu_group\n\n        if src is None:\n            src = (self.rank_in_group - 1) % self.world_size\n        assert src < self.world_size, f\"Invalid src rank ({src})\"\n\n        recv_metadata_list = self.recv_object(src=src)\n        tensor_dict: Dict[str, Any] = {}\n        for key, value in recv_metadata_list:\n            if isinstance(value, TensorMetadata):\n           
     tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)\n                if tensor.numel() == 0:\n                    # Skip broadcasting empty tensors.\n                    tensor_dict[key] = tensor\n                    continue\n\n                # send-allgather: send only a slice, then do allgather.\n                use_all_gather = (\n                    all_gather_group is not None\n                    and tensor.numel() % all_gather_size == 0\n                )\n\n                if use_all_gather:\n                    orig_shape = tensor.shape\n                    tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]\n\n                if tensor.is_cpu:\n                    # use metadata_group for CPU tensors\n                    torch.distributed.recv(\n                        tensor, src=self.ranks[src], group=metadata_group\n                    )\n                else:\n                    # use group for GPU tensors\n                    torch.distributed.recv(tensor, src=self.ranks[src], group=group)\n                if use_all_gather:\n                    # do the allgather\n                    tensor = all_gather_group.all_gather(tensor, dim=0)  # type: ignore\n                    tensor = tensor.reshape(orig_shape)\n\n                tensor_dict[key] = tensor\n            else:\n                tensor_dict[key] = value\n        return tensor_dict\n\n    def barrier(self):\n        \"\"\"Barrier synchronization among the group.\n        NOTE: don't use `device_group` here! `barrier` in NCCL is\n        terrible because it is internally a broadcast operation with\n        secretly created GPU tensors. It is easy to mess up the current\n        device. 
Use the CPU group instead.\n        \"\"\"\n        torch.distributed.barrier(group=self.cpu_group)\n\n    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:\n        \"\"\"Sends a tensor to the destination rank in a non-blocking way\"\"\"\n        \"\"\"NOTE: `dst` is the local rank of the destination rank.\"\"\"\n        if dst is None:\n            dst = (self.rank_in_group + 1) % self.world_size\n\n        pynccl_comm = self.pynccl_comm\n        if pynccl_comm is not None and not pynccl_comm.disabled:\n            pynccl_comm.send(tensor, dst)\n        else:\n            torch.distributed.send(tensor, self.ranks[dst], self.device_group)\n\n    def recv(\n        self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None\n    ) -> torch.Tensor:\n        \"\"\"Receives a tensor from the source rank.\"\"\"\n        \"\"\"NOTE: `src` is the local rank of the source rank.\"\"\"\n        if src is None:\n            src = (self.rank_in_group - 1) % self.world_size\n\n        tensor = torch.empty(size, dtype=dtype, device=self.device)\n        pynccl_comm = self.pynccl_comm\n        if pynccl_comm is not None and not pynccl_comm.disabled:\n            pynccl_comm.recv(tensor, src)\n        else:\n            torch.distributed.recv(tensor, self.ranks[src], self.device_group)\n        return tensor\n\n    def destroy(self):\n        if self.device_group is not None:\n            torch.distributed.destroy_process_group(self.device_group)\n            self.device_group = None\n        if self.cpu_group is not None:\n            torch.distributed.destroy_process_group(self.cpu_group)\n            self.cpu_group = None\n        if self.pynccl_comm is not None:\n            self.pynccl_comm = None\n        if self.ca_comm is not None:\n            self.ca_comm = None\n        if self.mq_broadcaster is not None:\n            self.mq_broadcaster = None\n\n\n_WORLD: Optional[GroupCoordinator] = None\n\n\ndef get_world_group() -> 
GroupCoordinator:\n    assert _WORLD is not None, \"world group is not initialized\"\n    return _WORLD\n\n\ndef init_world_group(\n    ranks: List[int], local_rank: int, backend: str\n) -> GroupCoordinator:\n    return GroupCoordinator(\n        group_ranks=[ranks],\n        local_rank=local_rank,\n        torch_distributed_backend=backend,\n        use_pynccl=False,\n        use_custom_allreduce=False,\n        use_tpu_communicator=False,\n        use_hpu_communicator=False,\n        use_xpu_communicator=False,\n        group_name=\"world\",\n    )\n\n\ndef init_model_parallel_group(\n    group_ranks: List[List[int]],\n    local_rank: int,\n    backend: str,\n    use_custom_allreduce: Optional[bool] = None,\n    use_message_queue_broadcaster: bool = False,\n    group_name: Optional[str] = None,\n) -> GroupCoordinator:\n    if use_custom_allreduce is None:\n        use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE\n    return GroupCoordinator(\n        group_ranks=group_ranks,\n        local_rank=local_rank,\n        torch_distributed_backend=backend,\n        use_pynccl=True,\n        use_custom_allreduce=use_custom_allreduce,\n        use_tpu_communicator=True,\n        use_hpu_communicator=True,\n        use_xpu_communicator=True,\n        use_message_queue_broadcaster=use_message_queue_broadcaster,\n        group_name=group_name,\n    )\n\n\n_TP: Optional[GroupCoordinator] = None\n\n\ndef get_tp_group() -> GroupCoordinator:\n    assert _TP is not None, \"tensor model parallel group is not initialized\"\n    return _TP\n\n\n# kept for backward compatibility\nget_tensor_model_parallel_group = get_tp_group\n\n_PP: Optional[GroupCoordinator] = None\n\n\ndef get_pp_group() -> GroupCoordinator:\n    assert _PP is not None, \"pipeline model parallel group is not initialized\"\n    return _PP\n\n\n# kept for backward compatibility\nget_pipeline_model_parallel_group = get_pp_group\n\n\n@contextmanager\ndef graph_capture():\n    \"\"\"\n    `graph_capture` is a context 
manager which should surround the code that\n    is capturing the CUDA graph. Its main purpose is to ensure that\n    some operations will be run after the graph is captured, before the graph\n    is replayed. It returns a `GraphCaptureContext` object which contains the\n    necessary data for the graph capture. Currently, it only contains the\n    stream that the graph capture is running on. This stream is set to the\n    current CUDA stream when the context manager is entered and reset to the\n    default stream when the context manager is exited. This is to ensure that\n    the graph capture is running on a separate stream from the default stream,\n    in order to explicitly distinguish the kernels to capture\n    from other kernels possibly launched in the background on the default stream.\n    \"\"\"\n    with get_tp_group().graph_capture() as context, get_pp_group().graph_capture(\n        context\n    ):\n        yield context\n\n\n_ENABLE_CUSTOM_ALL_REDUCE = True\n\n\ndef set_custom_all_reduce(enable: bool):\n    global _ENABLE_CUSTOM_ALL_REDUCE\n    _ENABLE_CUSTOM_ALL_REDUCE = enable\n\n\ndef init_distributed_environment(\n    world_size: int = -1,\n    rank: int = -1,\n    distributed_init_method: str = \"env://\",\n    local_rank: int = -1,\n    backend: str = \"nccl\",\n):\n    # print() does not apply %-style formatting, so build the message explicitly.\n    print(\n        f\"world_size={world_size} rank={rank} local_rank={local_rank} \"\n        f\"distributed_init_method={distributed_init_method} backend={backend}\"\n    )\n    if not torch.distributed.is_initialized():\n        assert distributed_init_method is not None, (\n            \"distributed_init_method must be provided when initializing \"\n            \"distributed environment\"\n        )\n        # this backend is used for WORLD\n        torch.distributed.init_process_group(\n            backend=backend,\n            init_method=distributed_init_method,\n            world_size=world_size,\n            rank=rank,\n        
)\n    # set the local rank\n    # local_rank is not available in torch ProcessGroup,\n    # see https://github.com/pytorch/pytorch/issues/122816\n    if local_rank == -1:\n        # local rank not set, this usually happens in single-node\n        # setting, where we can use rank as local rank\n        if distributed_init_method == \"env://\":\n            local_rank = envs.LOCAL_RANK\n        else:\n            local_rank = rank\n    global _WORLD\n    if _WORLD is None:\n        ranks = list(range(torch.distributed.get_world_size()))\n        _WORLD = init_world_group(ranks, local_rank, backend)\n    else:\n        assert (\n            _WORLD.world_size == torch.distributed.get_world_size()\n        ), \"world group already initialized with a different world size\"\n\n\ndef initialize_model_parallel(\n    tensor_model_parallel_size: int = 1,\n    pipeline_model_parallel_size: int = 1,\n    backend: Optional[str] = None,\n) -> None:\n    \"\"\"\n    Initialize model parallel groups.\n\n    Arguments:\n        tensor_model_parallel_size: number of GPUs used for tensor model\n            parallelism.\n        pipeline_model_parallel_size: number of GPUs used for pipeline model\n            parallelism.\n\n    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we\n    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize\n    the model pipeline. The present function will\n    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:\n        4 tensor model-parallel groups:\n            [g0, g1], [g2, g3], [g4, g5], [g6, g7]\n        2 pipeline model-parallel groups:\n            [g0, g2, g4, g6], [g1, g3, g5, g7]\n    Note that for efficiency, the caller should make sure adjacent ranks\n    are on the same DGX box. For example if we are using 2 DGX-1 boxes\n    with a total of 16 GPUs, rank 0 to 7 belong to the first box and\n    ranks 8 to 15 belong to the second box.\n    \"\"\"\n    # Get world size and rank. 
Ensure some consistencies.\n    assert torch.distributed.is_initialized()\n    world_size: int = torch.distributed.get_world_size()\n    backend = backend or torch.distributed.get_backend(get_world_group().device_group)\n\n    if world_size != tensor_model_parallel_size * pipeline_model_parallel_size:\n        raise RuntimeError(\n            f\"world_size ({world_size}) is not equal to \"\n            f\"tensor_model_parallel_size ({tensor_model_parallel_size}) x \"\n            f\"pipeline_model_parallel_size ({pipeline_model_parallel_size})\"\n        )\n\n    # Build the tensor model-parallel groups.\n    num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size\n    global _TP\n    assert _TP is None, \"tensor model parallel group is already initialized\"\n    group_ranks = []\n    for i in range(num_tensor_model_parallel_groups):\n        ranks = list(\n            range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)\n        )\n        group_ranks.append(ranks)\n\n    # message queue broadcaster is only used in tensor model parallel group\n    _TP = init_model_parallel_group(\n        group_ranks,\n        get_world_group().local_rank,\n        backend,\n        use_message_queue_broadcaster=True,\n        group_name=\"tp\",\n    )\n\n    # Build the pipeline model-parallel groups.\n    num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size\n    global _PP\n    assert _PP is None, \"pipeline model parallel group is already initialized\"\n    group_ranks = []\n    for i in range(num_pipeline_model_parallel_groups):\n        ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))\n        group_ranks.append(ranks)\n    # pipeline parallel does not need custom allreduce\n    _PP = init_model_parallel_group(\n        group_ranks,\n        get_world_group().local_rank,\n        backend,\n        use_custom_allreduce=False,\n        group_name=\"pp\",\n    )\n\n\ndef 
ensure_model_parallel_initialized(\n    tensor_model_parallel_size: int,\n    pipeline_model_parallel_size: int,\n    backend: Optional[str] = None,\n) -> None:\n    \"\"\"Helper to initialize model parallel groups if they are not initialized,\n    or ensure tensor-parallel and pipeline-parallel sizes are equal to expected\n    values if the model parallel groups are initialized.\n    \"\"\"\n    backend = backend or torch.distributed.get_backend(get_world_group().device_group)\n    if not model_parallel_is_initialized():\n        initialize_model_parallel(\n            tensor_model_parallel_size, pipeline_model_parallel_size, backend\n        )\n        return\n\n    assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, (\n        \"tensor parallel group already initialized, but of unexpected size: \"\n        f\"{get_tensor_model_parallel_world_size()=} vs. \"\n        f\"{tensor_model_parallel_size=}\"\n    )\n    pp_world_size = get_pp_group().world_size\n    assert pp_world_size == pipeline_model_parallel_size, (\n        \"pipeline parallel group already initialized, but of unexpected size: \"\n        f\"{pp_world_size=} vs. 
\"\n        f\"{pipeline_model_parallel_size=}\"\n    )\n\n\ndef model_parallel_is_initialized():\n    \"\"\"Check if tensor and pipeline parallel groups are initialized.\"\"\"\n    return _TP is not None and _PP is not None\n\n\n_TP_STATE_PATCHED = False\n\n\n@contextmanager\ndef patch_tensor_parallel_group(tp_group: GroupCoordinator):\n    \"\"\"Patch the tp group temporarily until this function ends.\n\n    This method is for draft workers of speculative decoding to run draft model\n    with different tp degree from that of target model workers.\n\n    Args:\n        tp_group (GroupCoordinator): the tp group coordinator\n    \"\"\"\n    global _TP_STATE_PATCHED\n    assert not _TP_STATE_PATCHED, \"Should not call when it's already patched\"\n\n    _TP_STATE_PATCHED = True\n    old_tp_group = get_tp_group()\n    global _TP\n    _TP = tp_group\n    try:\n        yield\n    finally:\n        # restore the original state\n        _TP_STATE_PATCHED = False\n        _TP = old_tp_group\n\n\ndef get_tensor_model_parallel_world_size():\n    \"\"\"Return world size for the tensor model parallel group.\"\"\"\n    return get_tp_group().world_size\n\n\ndef get_tensor_model_parallel_rank():\n    \"\"\"Return my rank for the tensor model parallel group.\"\"\"\n    return get_tp_group().rank_in_group\n\n\ndef destroy_model_parallel():\n    \"\"\"Set the groups to none and destroy them.\"\"\"\n    global _TP\n    if _TP:\n        _TP.destroy()\n    _TP = None\n\n    global _PP\n    if _PP:\n        _PP.destroy()\n    _PP = None\n\n\ndef destroy_distributed_environment():\n    global _WORLD\n    if _WORLD:\n        _WORLD.destroy()\n    _WORLD = None\n    if torch.distributed.is_initialized():\n        torch.distributed.destroy_process_group()\n\n\ndef cleanup_dist_env_and_memory(shutdown_ray: bool = False):\n    destroy_model_parallel()\n    destroy_distributed_environment()\n    with contextlib.suppress(AssertionError):\n        torch.distributed.destroy_process_group()\n    if 
shutdown_ray:\n        import ray  # Lazy import Ray\n\n        ray.shutdown()\n    gc.collect()\n    if not current_platform.is_cpu():\n        torch.cuda.empty_cache()\n\n\ndef in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:\n    \"\"\"\n    This is a collective operation that returns if each rank is in the same node\n    as the source rank. It tests if processes are attached to the same\n    memory system (shared access to shared memory).\n    \"\"\"\n    assert (\n        torch.distributed.get_backend(pg) != torch.distributed.Backend.NCCL\n    ), \"in_the_same_node_as should be tested with a non-NCCL group.\"\n    # local rank inside the group\n    rank = torch.distributed.get_rank(group=pg)\n    world_size = torch.distributed.get_world_size(group=pg)\n\n    # local tensor in each process to store the result\n    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)\n\n    # global ranks of the processes in the group\n    ranks = torch.distributed.get_process_group_ranks(pg)\n\n    magic_message = b\"magic_message\"\n    shm = None\n\n    try:\n        with contextlib.suppress(OSError):\n            if rank == source_rank:\n                # create a shared memory segment\n                shm = shared_memory.SharedMemory(create=True, size=128)\n                shm.buf[: len(magic_message)] = magic_message\n                torch.distributed.broadcast_object_list(\n                    [shm.name], src=ranks[source_rank], group=pg\n                )\n                is_in_the_same_node[rank] = 1\n            else:\n                # try to open the shared memory segment\n                recv = [None]\n                torch.distributed.broadcast_object_list(\n                    recv, src=ranks[source_rank], group=pg\n                )\n                name = recv[0]\n                # fix to https://stackoverflow.com/q/62748654/9191338\n                # Python incorrectly tracks shared memory even if it is not\n         
       # created by the process. The following patch is a workaround.\n                with patch(\n                    \"multiprocessing.resource_tracker.register\",\n                    lambda *args, **kwargs: None,\n                ):\n                    shm = shared_memory.SharedMemory(name=name)\n                if shm.buf[: len(magic_message)] == magic_message:\n                    is_in_the_same_node[rank] = 1\n    except Exception as e:\n        # print() does not apply %-style formatting, so use an f-string.\n        print(f\"Error ignored in is_in_the_same_node: {e}\")\n    finally:\n        if shm:\n            shm.close()\n\n    torch.distributed.barrier(group=pg)\n\n    # clean up the shared memory segment\n    with contextlib.suppress(OSError):\n        if rank == source_rank and shm:\n            shm.unlink()\n    torch.distributed.all_reduce(is_in_the_same_node, group=pg)\n\n    return [x == 1 for x in is_in_the_same_node.tolist()]\n"
  },
  {
    "path": "kt-sft/ktransformers/server/balance_serve/inference/distributed/pynccl.py",
    "content": "from contextlib import contextmanager\nfrom typing import Optional, Union\n\n# ===================== import region =====================\nimport torch\nimport torch.distributed as dist\nfrom torch.distributed import ProcessGroup, ReduceOp\n\nfrom server.inference.distributed.pynccl_wrapper import (\n    NCCLLibrary,\n    buffer_type,\n    cudaStream_t,\n    ncclComm_t,\n    ncclDataTypeEnum,\n    ncclRedOpTypeEnum,\n    ncclUniqueId,\n)\nfrom server.inference.distributed.utils import StatelessProcessGroup\n\n\nclass PyNcclCommunicator:\n\n    def __init__(\n        self,\n        group: Union[ProcessGroup, StatelessProcessGroup],\n        device: Union[int, str, torch.device],\n        library_path: Optional[str] = None,\n    ):\n        \"\"\"\n        Args:\n            group: the process group to work on. If None, it will use the\n                default process group.\n            device: the device to bind the PyNcclCommunicator to. If None,\n                it will be bind to f\"cuda:{local_rank}\".\n            library_path: the path to the NCCL library. 
If None, it will\n                use the default library path.\n        It is the caller's responsibility to make sure each communicator\n        is bound to a unique device.\n        \"\"\"\n        if not isinstance(group, StatelessProcessGroup):\n            assert dist.is_initialized()\n            assert (\n                dist.get_backend(group) != dist.Backend.NCCL\n            ), \"PyNcclCommunicator should be attached to a non-NCCL group.\"\n            # note: this rank is the rank in the group\n            self.rank = dist.get_rank(group)\n            self.world_size = dist.get_world_size(group)\n        else:\n            self.rank = group.rank\n            self.world_size = group.world_size\n\n        self.group = group\n\n        # if world_size == 1, no need to create communicator\n        if self.world_size == 1:\n            self.available = False\n            self.disabled = True\n            self.stream = None\n            return\n        try:\n            self.nccl = NCCLLibrary(library_path)\n        except Exception:\n            # disable because of missing NCCL library\n            # e.g. 
in a non-GPU environment\n            self.available = False\n            self.disabled = True\n            self.stream = None\n            return\n\n        self.available = True\n        self.disabled = False\n\n        print(f\"vLLM is using nccl=={self.nccl.ncclGetVersion()}\")\n\n        if self.rank == 0:\n            # get the unique id from NCCL\n            self.unique_id = self.nccl.ncclGetUniqueId()\n        else:\n            # construct an empty unique id\n            self.unique_id = ncclUniqueId()\n\n        if not isinstance(group, StatelessProcessGroup):\n            tensor = torch.ByteTensor(list(self.unique_id.internal))\n            ranks = dist.get_process_group_ranks(group)\n            # arg `src` in `broadcast` is the global rank\n            dist.broadcast(tensor, src=ranks[0], group=group)\n            byte_list = tensor.tolist()\n            for i, byte in enumerate(byte_list):\n                self.unique_id.internal[i] = byte\n        else:\n            self.unique_id = group.broadcast_obj(self.unique_id, src=0)\n        if isinstance(device, int):\n            device = torch.device(f\"cuda:{device}\")\n        elif isinstance(device, str):\n            device = torch.device(device)\n        # now `device` is a `torch.device` object\n        assert isinstance(device, torch.device)\n        self.device = device\n        # nccl communicator and stream will use this device\n        # `torch.cuda.device` is a context manager that changes the\n        # current cuda device to the specified one\n        with torch.cuda.device(device):\n            self.comm: ncclComm_t = self.nccl.ncclCommInitRank(\n                self.world_size, self.unique_id, self.rank\n            )\n            self.stream = torch.cuda.Stream()\n\n            # A small all_reduce for warmup.\n            data = torch.zeros(1, device=device)\n            self.all_reduce(data)\n            self.stream.synchronize()\n            del data\n\n        # by default it is 
disabled, e.g. in profiling models and prefill phase.\n        # to use it, use under `with obj.change_state(enable=True)`, usually\n        # when we are using CUDA graph.\n        self.disabled = True\n\n    def all_reduce(\n        self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None\n    ):\n        if self.disabled:\n            return\n        # nccl communicator created on a specific device\n        # will only work on tensors on the same device\n        # otherwise it will cause \"illegal memory access\"\n        assert tensor.device == self.device, (\n            f\"this nccl communicator is created to work on {self.device}, \"\n            f\"but the input tensor is on {tensor.device}\"\n        )\n        if stream is None:\n            stream = self.stream\n        self.nccl.ncclAllReduce(\n            buffer_type(tensor.data_ptr()),\n            buffer_type(tensor.data_ptr()),\n            tensor.numel(),\n            ncclDataTypeEnum.from_torch(tensor.dtype),\n            ncclRedOpTypeEnum.from_torch(op),\n            self.comm,\n            cudaStream_t(stream.cuda_stream),\n        )\n\n    def send(self, tensor: torch.Tensor, dst: int, stream=None):\n        if self.disabled:\n            return\n        assert tensor.device == self.device, (\n            f\"this nccl communicator is created to work on {self.device}, \"\n            f\"but the input tensor is on {tensor.device}\"\n        )\n        if stream is None:\n            stream = self.stream\n        self.nccl.ncclSend(\n            buffer_type(tensor.data_ptr()),\n            tensor.numel(),\n            ncclDataTypeEnum.from_torch(tensor.dtype),\n            dst,\n            self.comm,\n            cudaStream_t(stream.cuda_stream),\n        )\n\n    def recv(self, tensor: torch.Tensor, src: int, stream=None):\n        if self.disabled:\n            return\n        assert tensor.device == self.device, (\n            f\"this nccl communicator is created to work on 
{self.device}, \"\n            f\"but the input tensor is on {tensor.device}\"\n        )\n        if stream is None:\n            stream = self.stream\n        self.nccl.ncclRecv(\n            buffer_type(tensor.data_ptr()),\n            tensor.numel(),\n            ncclDataTypeEnum.from_torch(tensor.dtype),\n            src,\n            self.comm,\n            cudaStream_t(stream.cuda_stream),\n        )\n\n    @contextmanager\n    def change_state(\n        self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None\n    ):\n        \"\"\"\n        A context manager to change the state of the communicator.\n        \"\"\"\n        if enable is None:\n            # guess a default value when not specified\n            enable = self.available\n\n        if stream is None:\n            stream = self.stream\n\n        old_disable = self.disabled\n        old_stream = self.stream\n\n        self.stream = stream\n        self.disabled = not enable\n        yield\n\n        self.disabled = old_disable\n        self.stream = old_stream\n"
  },
  {
    "path": "kt-sft/ktransformers/server/balance_serve/inference/distributed/pynccl_wrapper.py",
    "content": "# This file is a pure Python wrapper for the NCCL library.\n# The main purpose is to use NCCL combined with CUDA graph.\n# Before writing this script, we tried the following approach:\n# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself\n#  often gets stuck when initializing the NCCL communicator.\n# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`\n#  contains many other potential cuda APIs, that are not allowed during\n#  capturing the CUDA graph. For further details, please check\n# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .\n#\n# Another rejected idea is to write a C/C++ binding for NCCL. It is usually\n# doable, but we often encounter issues related with nccl versions, and need\n# to switch between different versions of NCCL. See\n# https://github.com/NVIDIA/nccl/issues/1234 for more details.\n# A C/C++ binding is not flexible enough to handle this. It requires\n# recompilation of the code every time we want to switch between different\n# versions. This current implementation, with a **pure** Python wrapper, is\n# more flexible. 
We can easily switch between different versions of NCCL by\n# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`\n# variable in the code.\n\nimport ctypes\nimport platform\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional\n\nimport torch\nfrom torch.distributed import ReduceOp\n\nfrom server.utils import find_nccl_library\n\n\n# === export types and functions from nccl to Python ===\n# for the original nccl definition, please check\n# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in\n\nncclResult_t = ctypes.c_int\nncclComm_t = ctypes.c_void_p\n\n\nclass ncclUniqueId(ctypes.Structure):\n    _fields_ = [(\"internal\", ctypes.c_byte * 128)]\n\n\ncudaStream_t = ctypes.c_void_p\nbuffer_type = ctypes.c_void_p\n\nncclDataType_t = ctypes.c_int\n\n\nclass ncclDataTypeEnum:\n    ncclInt8 = 0\n    ncclChar = 0\n    ncclUint8 = 1\n    ncclInt32 = 2\n    ncclInt = 2\n    ncclUint32 = 3\n    ncclInt64 = 4\n    ncclUint64 = 5\n    ncclFloat16 = 6\n    ncclHalf = 6\n    ncclFloat32 = 7\n    ncclFloat = 7\n    ncclFloat64 = 8\n    ncclDouble = 8\n    ncclBfloat16 = 9\n    ncclNumTypes = 10\n\n    @classmethod\n    def from_torch(cls, dtype: torch.dtype) -> int:\n        if dtype == torch.int8:\n            return cls.ncclInt8\n        if dtype == torch.uint8:\n            return cls.ncclUint8\n        if dtype == torch.int32:\n            return cls.ncclInt32\n        if dtype == torch.int64:\n            return cls.ncclInt64\n        if dtype == torch.float16:\n            return cls.ncclFloat16\n        if dtype == torch.float32:\n            return cls.ncclFloat32\n        if dtype == torch.float64:\n            return cls.ncclFloat64\n        if dtype == torch.bfloat16:\n            return cls.ncclBfloat16\n        raise ValueError(f\"Unsupported dtype: {dtype}\")\n\n\nncclRedOp_t = ctypes.c_int\n\n\nclass ncclRedOpTypeEnum:\n    ncclSum = 0\n    ncclProd = 1\n    ncclMax = 2\n    ncclMin = 3\n    ncclAvg = 4\n    
ncclNumOps = 5\n\n    @classmethod\n    def from_torch(cls, op: ReduceOp) -> int:\n        if op == ReduceOp.SUM:\n            return cls.ncclSum\n        if op == ReduceOp.PRODUCT:\n            return cls.ncclProd\n        if op == ReduceOp.MAX:\n            return cls.ncclMax\n        if op == ReduceOp.MIN:\n            return cls.ncclMin\n        if op == ReduceOp.AVG:\n            return cls.ncclAvg\n        raise ValueError(f\"Unsupported op: {op}\")\n\n\n@dataclass\nclass Function:\n    name: str\n    restype: Any\n    argtypes: List[Any]\n\n\nclass NCCLLibrary:\n    exported_functions = [\n        # const char* ncclGetErrorString(ncclResult_t result)\n        Function(\"ncclGetErrorString\", ctypes.c_char_p, [ncclResult_t]),\n        # ncclResult_t  ncclGetVersion(int *version);\n        Function(\"ncclGetVersion\", ncclResult_t,\n                 [ctypes.POINTER(ctypes.c_int)]),\n        # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);\n        Function(\"ncclGetUniqueId\", ncclResult_t,\n                 [ctypes.POINTER(ncclUniqueId)]),\n        # ncclResult_t  ncclCommInitRank(\n        #   ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);\n        # note that ncclComm_t is a pointer type, so the first argument\n        # is a pointer to a pointer\n        Function(\"ncclCommInitRank\", ncclResult_t, [\n            ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId,\n            ctypes.c_int\n        ]),\n        # ncclResult_t  ncclAllReduce(\n        #   const void* sendbuff, void* recvbuff, size_t count,\n        #   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,\n        #   cudaStream_t stream);\n        # note that cudaStream_t is a pointer type, so the last argument\n        # is a pointer\n        Function(\"ncclAllReduce\", ncclResult_t, [\n            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,\n            ncclRedOp_t, ncclComm_t, cudaStream_t\n        ]),\n\n        # ncclResult_t  ncclSend(\n   
     #   const void* sendbuff, size_t count, ncclDataType_t datatype,\n        #   int dest, ncclComm_t comm, cudaStream_t stream);\n        Function(\"ncclSend\", ncclResult_t, [\n            buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,\n            ncclComm_t, cudaStream_t\n        ]),\n\n        # ncclResult_t  ncclRecv(\n        #   void* recvbuff, size_t count, ncclDataType_t datatype,\n        #   int src, ncclComm_t comm, cudaStream_t stream);\n        Function(\"ncclRecv\", ncclResult_t, [\n            buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int,\n            ncclComm_t, cudaStream_t\n        ]),\n\n        # be cautious! this is a collective call, it will block until all\n        # processes in the communicator have called this function.\n        # because Python object destruction can happen in random order,\n        # it is better not to call it at all.\n        # ncclResult_t  ncclCommDestroy(ncclComm_t comm);\n        Function(\"ncclCommDestroy\", ncclResult_t, [ncclComm_t]),\n    ]\n\n    # class attribute to store the mapping from the path to the library\n    # to avoid loading the same library multiple times\n    path_to_library_cache: Dict[str, Any] = {}\n\n    # class attribute to store the mapping from library path\n    #  to the corresponding dictionary\n    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}\n\n    def __init__(self, so_file: Optional[str] = None):\n\n        so_file = so_file or find_nccl_library()\n\n        try:\n            if so_file not in NCCLLibrary.path_to_dict_mapping:\n                lib = ctypes.CDLL(so_file)\n                NCCLLibrary.path_to_library_cache[so_file] = lib\n            self.lib = NCCLLibrary.path_to_library_cache[so_file]\n        except Exception as e:\n            print(\n                \"Failed to load NCCL library from %s .\"\n                \"It is expected if you are not running on NVIDIA/AMD GPUs.\"\n                \"Otherwise, the nccl library might not 
exist, be corrupted \"\n                \"or it does not support the current platform %s.\"\n                \"If you already have the library, please set the \"\n                \"environment variable VLLM_NCCL_SO_PATH\"\n                \" to point to the correct nccl library path.\", so_file,\n                platform.platform())\n            raise e\n\n        if so_file not in NCCLLibrary.path_to_dict_mapping:\n            _funcs: Dict[str, Any] = {}\n            for func in NCCLLibrary.exported_functions:\n                f = getattr(self.lib, func.name)\n                f.restype = func.restype\n                f.argtypes = func.argtypes\n                _funcs[func.name] = f\n            NCCLLibrary.path_to_dict_mapping[so_file] = _funcs\n        self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]\n\n    def ncclGetErrorString(self, result: ncclResult_t) -> str:\n        return self._funcs[\"ncclGetErrorString\"](result).decode(\"utf-8\")\n\n    def NCCL_CHECK(self, result: ncclResult_t) -> None:\n        if result != 0:\n            error_str = self.ncclGetErrorString(result)\n            raise RuntimeError(f\"NCCL error: {error_str}\")\n\n    def ncclGetVersion(self) -> str:\n        version = ctypes.c_int()\n        self.NCCL_CHECK(self._funcs[\"ncclGetVersion\"](ctypes.byref(version)))\n        version_str = str(version.value)\n        # something like 21903 --> \"2.19.3\"\n        major = version_str[0].lstrip(\"0\")\n        minor = version_str[1:3].lstrip(\"0\")\n        patch = version_str[3:].lstrip(\"0\")\n        return f\"{major}.{minor}.{patch}\"\n\n    def ncclGetUniqueId(self) -> ncclUniqueId:\n        unique_id = ncclUniqueId()\n        self.NCCL_CHECK(self._funcs[\"ncclGetUniqueId\"](\n            ctypes.byref(unique_id)))\n        return unique_id\n\n    def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,\n                         rank: int) -> ncclComm_t:\n        comm = ncclComm_t()\n        
self.NCCL_CHECK(self._funcs[\"ncclCommInitRank\"](ctypes.byref(comm),\n                                                        world_size, unique_id,\n                                                        rank))\n        return comm\n\n    def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,\n                      count: int, datatype: int, op: int, comm: ncclComm_t,\n                      stream: cudaStream_t) -> None:\n        # `datatype` actually should be `ncclDataType_t`\n        # and `op` should be `ncclRedOp_t`\n        # both are aliases of `ctypes.c_int`\n        # when we pass int to a function, it will be converted to `ctypes.c_int`\n        # by ctypes automatically\n        self.NCCL_CHECK(self._funcs[\"ncclAllReduce\"](sendbuff, recvbuff, count,\n                                                     datatype, op, comm,\n                                                     stream))\n\n    def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,\n                 dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:\n        self.NCCL_CHECK(self._funcs[\"ncclSend\"](sendbuff, count, datatype,\n                                                dest, comm, stream))\n\n    def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int,\n                 src: int, comm: ncclComm_t, stream: cudaStream_t) -> None:\n        self.NCCL_CHECK(self._funcs[\"ncclRecv\"](recvbuff, count, datatype, src,\n                                                comm, stream))\n\n    def ncclCommDestroy(self, comm: ncclComm_t) -> None:\n        self.NCCL_CHECK(self._funcs[\"ncclCommDestroy\"](comm))\n\n\n__all__ = [\n    \"NCCLLibrary\", \"ncclDataTypeEnum\", \"ncclRedOpTypeEnum\", \"ncclUniqueId\",\n    \"ncclComm_t\", \"cudaStream_t\", \"buffer_type\"\n]\n"
  },
  {
    "path": "kt-sft/ktransformers/server/balance_serve/inference/distributed/utils.py",
    "content": "# Copyright 2023 The vLLM team.\n# Adapted from\n# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py\n# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.\nimport dataclasses\nimport pickle\nimport time\nfrom collections import deque\nfrom typing import Any, Deque, Dict, Optional, Sequence, Tuple\n\nimport torch\nfrom torch.distributed import TCPStore\n\nimport server.envs as envs\n\n\ndef ensure_divisibility(numerator, denominator):\n    \"\"\"Ensure that numerator is divisible by the denominator.\"\"\"\n    assert numerator % denominator == 0, \"{} is not divisible by {}\".format(\n        numerator, denominator\n    )\n\n\ndef divide(numerator, denominator):\n    \"\"\"Ensure that numerator is divisible by the denominator and return\n    the division value.\"\"\"\n    ensure_divisibility(numerator, denominator)\n    return numerator // denominator\n\n\ndef split_tensor_along_last_dim(\n    tensor: torch.Tensor,\n    num_partitions: int,\n    contiguous_split_chunks: bool = False,\n) -> Sequence[torch.Tensor]:\n    \"\"\"Split a tensor along its last dimension.\n\n    Arguments:\n        tensor: input tensor.\n        num_partitions: number of partitions to split the tensor\n        contiguous_split_chunks: If True, make each chunk contiguous\n                                 in memory.\n\n    Returns:\n        A list of Tensors\n    \"\"\"\n    # Get the size and dimension.\n    last_dim = tensor.dim() - 1\n    last_dim_size = divide(tensor.size()[last_dim], num_partitions)\n    # Split.\n    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)\n    # NOTE: torch.split does not create contiguous tensors by default.\n    if contiguous_split_chunks:\n        return tuple(chunk.contiguous() for chunk in tensor_list)\n\n    return tensor_list\n\n\ndef get_pp_indices(\n    num_hidden_layers: int, pp_rank: int, pp_size: int\n) -> Tuple[int, int]:\n    \"\"\"Try to evenly distribute layers across 
partitions.\n    If the number of layers is not divisible by the number of partitions,\n    the last partition will have the remaining layers.\n    \"\"\"\n    partition_list_str = envs.VLLM_PP_LAYER_PARTITION\n    if partition_list_str is not None:\n        try:\n            partitions = [int(layer) for layer in partition_list_str.split(\",\")]\n        except ValueError as err:\n            raise ValueError(\n                \"Invalid partition string: {}\".format(partition_list_str)\n            ) from err\n        if len(partitions) != pp_size:\n            raise ValueError(f\"{len(partitions)=} does not match {pp_size=}.\")\n        if sum(partitions) != num_hidden_layers:\n            raise ValueError(f\"{sum(partitions)=} does not match {num_hidden_layers=}.\")\n        start_layer = sum(partitions[:pp_rank])\n        end_layer = start_layer + partitions[pp_rank]\n    else:\n        layers_per_partition = num_hidden_layers // pp_size\n        start_layer = pp_rank * layers_per_partition\n        end_layer = start_layer + layers_per_partition\n\n        if pp_rank == pp_size - 1:\n            end_layer = num_hidden_layers\n\n    return (start_layer, end_layer)\n\n\n@dataclasses.dataclass\nclass StatelessProcessGroup:\n    \"\"\"A dataclass to hold a metadata store, and the rank, world_size of the\n    group. 
Only use it to communicate metadata between processes.\n    For data-plane communication, create NCCL-related objects.\n    \"\"\"\n\n    rank: int\n    world_size: int\n    store: torch._C._distributed_c10d.Store\n    data_expiration_seconds: int = 3600  # 1 hour\n\n    # dst rank -> counter\n    send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict)\n    # src rank -> counter\n    recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)\n    broadcast_send_counter: int = 0\n    broadcast_recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)\n\n    # A deque to store the data entries, with key and timestamp.\n    entries: Deque[Tuple[str, float]] = dataclasses.field(default_factory=deque)\n\n    def __post_init__(self):\n        assert self.rank < self.world_size\n        self.send_dst_counter = {i: 0 for i in range(self.world_size)}\n        self.recv_src_counter = {i: 0 for i in range(self.world_size)}\n        self.broadcast_recv_src_counter = {i: 0 for i in range(self.world_size)}\n\n    def send_obj(self, obj: Any, dst: int):\n        \"\"\"Send an object to a destination rank.\"\"\"\n        self.expire_data()\n        key = f\"send_to/{dst}/{self.send_dst_counter[dst]}\"\n        self.store.set(key, pickle.dumps(obj))\n        self.send_dst_counter[dst] += 1\n        self.entries.append((key, time.time()))\n\n    def expire_data(self):\n        \"\"\"Expire data that is older than `data_expiration_seconds` seconds.\"\"\"\n        while self.entries:\n            # check the oldest entry\n            key, timestamp = self.entries[0]\n            if time.time() - timestamp > self.data_expiration_seconds:\n                self.store.delete_key(key)\n                self.entries.popleft()\n            else:\n                break\n\n    def recv_obj(self, src: int) -> Any:\n        \"\"\"Receive an object from a source rank.\"\"\"\n        obj = pickle.loads(\n            
self.store.get(f\"send_to/{self.rank}/{self.recv_src_counter[src]}\")\n        )\n        self.recv_src_counter[src] += 1\n        return obj\n\n    def broadcast_obj(self, obj: Optional[Any], src: int) -> Any:\n        \"\"\"Broadcast an object from a source rank to all other ranks.\n        It does not clean up after all ranks have received the object.\n        Use it for limited times, e.g., for initialization.\n        \"\"\"\n        if self.rank == src:\n            self.expire_data()\n            key = f\"broadcast_from/{src}/\" f\"{self.broadcast_send_counter}\"\n            self.store.set(key, pickle.dumps(obj))\n            self.broadcast_send_counter += 1\n            self.entries.append((key, time.time()))\n            return obj\n        else:\n            key = f\"broadcast_from/{src}/\" f\"{self.broadcast_recv_src_counter[src]}\"\n            recv_obj = pickle.loads(self.store.get(key))\n            self.broadcast_recv_src_counter[src] += 1\n            return recv_obj\n\n    def all_gather_obj(self, obj: Any) -> list[Any]:\n        \"\"\"All gather an object from all ranks.\"\"\"\n        gathered_objs = []\n        for i in range(self.world_size):\n            if i == self.rank:\n                gathered_objs.append(obj)\n                self.broadcast_obj(obj, src=self.rank)\n            else:\n                recv_obj = self.broadcast_obj(None, src=i)\n                gathered_objs.append(recv_obj)\n        return gathered_objs\n\n    def barrier(self):\n        \"\"\"A barrier to synchronize all ranks.\"\"\"\n        for i in range(self.world_size):\n            if i == self.rank:\n                self.broadcast_obj(None, src=self.rank)\n            else:\n                self.broadcast_obj(None, src=i)\n\n    @staticmethod\n    def create(\n        host: str,\n        port: int,\n        rank: int,\n        world_size: int,\n        data_expiration_seconds: int = 3600,\n    ) -> \"StatelessProcessGroup\":\n        \"\"\"A replacement for 
`torch.distributed.init_process_group` that does not\n        pollute the global state.\n\n        If we have process A and process B called `torch.distributed.init_process_group`\n        to form a group, and then we want to form another group with process A, B, C,\n        D, it is not possible in PyTorch, because process A and process B have already\n        formed a group, and process C and process D cannot join that group. This\n        function is a workaround for this issue.\n\n        `torch.distributed.init_process_group` is a global call, while this function\n        is a stateless call. It will return a `StatelessProcessGroup` object that can be\n        used for exchanging metadata. With this function, process A and process B\n        can call `StatelessProcessGroup.create` to form a group, and then process A, B,\n        C, and D can call `StatelessProcessGroup.create` to form another group.\n        \"\"\"  # noqa\n        store = TCPStore(\n            host_name=host,\n            port=port,\n            world_size=world_size,\n            is_master=(rank == 0),\n        )\n\n        return StatelessProcessGroup(\n            rank=rank,\n            world_size=world_size,\n            store=store,\n            data_expiration_seconds=data_expiration_seconds,\n        )\n"
  },
  {
    "path": "kt-sft/ktransformers/server/balance_serve/inference/forward_batch.py",
    "content": "'''\nDate: 2024-11-12 14:15:16\nLastEditors: Xie Weiyu ervinxie@qq.com\nLastEditTime: 2024-11-26 08:12:49\n'''\nimport torch\nfrom ktransformers.server.balance_serve.settings import sched_ext\nfrom ktransformers.server.balance_serve.inference.query_manager import QueryManager, QueryInfo\nimport time\nfrom ktransformers.server.config.config import Config\nclass ForwardBatchInput:\n\n    class ForwardMiniBatch:\n        q_indptr: torch.Tensor\n        kv_indptr: torch.Tensor\n        kv_indices: torch.Tensor\n        kv_last_page_len: torch.Tensor\n        kv_len: torch.Tensor\n        position_ids: torch.Tensor\n        tokens: torch.Tensor\n        batch_indices: torch.Tensor\n        positions: torch.Tensor\n        chunk_size: int\n        decode_batch: int        \n        is_last_prefill_chunk: bool\n        logits_start: list\n\n        temperatures: torch.Tensor\n        top_ps: torch.Tensor\n\n        def __init__(self, prefill_querys_info: list[QueryInfo], decode_querys_info: list[QueryInfo], prefill_s: list[int] = None, prefill_l: list[int] = None, device = torch.device('cuda'), page_size = 256):\n            batch_decode = len(decode_querys_info)\n            batch_prefill = len(prefill_querys_info)\n\n            self.q_indptr = torch.tensor([0], device=device, dtype=torch.int32)\n            self.kv_indptr = torch.tensor([0], device=device, dtype=torch.int32)\n            self.kv_indices = torch.tensor([], device=device, dtype=torch.int32)\n            self.kv_len = torch.tensor([], device=device, dtype=torch.int32)\n            self.kv_last_page_len = torch.tensor([], device=device, dtype=torch.int32)\n            self.position_ids = torch.tensor([], device=device, dtype=torch.int32)\n            self.tokens = torch.tensor([], device=device, dtype=torch.int32)\n\n            self.temperatures = torch.tensor([], device=device, dtype=torch.float32)\n            self.top_ps = torch.tensor([], device=device, dtype=torch.float32)\n\n         
   self.logits_start = []\n            self.decode_batch = batch_decode\n            self.num_tokens = batch_decode + sum(prefill_l)\n            self.batch_size = batch_decode + batch_prefill\n            \n            for i, prefill_query_info in enumerate(prefill_querys_info):\n                if prefill_query_info != None:\n                    prefill_kv_block_len = (prefill_query_info.active_position + prefill_l[i] + page_size - 1) // page_size if prefill_query_info is not None else 0\n                    # print(f\"block_len: {prefill_kv_block_len}, page_size: {page_size}\")\n                    self.q_indptr = torch.concat((self.q_indptr, torch.tensor([prefill_l[i] + self.q_indptr[-1]], device=device, dtype=torch.int32)), dim=0)\n                    self.kv_indptr = torch.concat((self.kv_indptr, torch.tensor([prefill_kv_block_len + self.kv_indptr[-1]], device=device, dtype=torch.int32)), dim=0)\n                    self.kv_indices = torch.concat((self.kv_indices, prefill_query_info.block_index[:prefill_kv_block_len]), dim=0)\n                    self.kv_last_page_len = torch.concat((self.kv_last_page_len, torch.tensor([(prefill_query_info.active_position + prefill_l[i]) % page_size if (prefill_query_info.active_position + prefill_l[i]) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0)\n                    self.kv_len = torch.concat((self.kv_len, torch.tensor([(prefill_query_info.active_position + prefill_l[i])], device=device, dtype=torch.int32)), dim=0)\n                    self.position_ids = torch.concat((self.position_ids, torch.arange(prefill_s[i], prefill_l[i] + prefill_s[i], device=device, dtype=torch.int32)), dim=0)\n                    self.tokens = torch.concat((self.tokens, prefill_query_info.query_tokens[prefill_s[i]:prefill_s[i] + prefill_l[i]]), dim=0)\n                    self.logits_start.append(prefill_l[i] - 1 if len(self.logits_start) == 0 else sum(prefill_l[:i+1])-1)\n\n                    self.temperatures = 
torch.concat((self.temperatures, torch.tensor([prefill_query_info.temperature], device=device, dtype=torch.float32)), dim=0)\n                    self.top_ps = torch.concat((self.top_ps, torch.tensor([prefill_query_info.top_p], device=device, dtype=torch.float32)), dim=0)\n\n            for decode_query_info in decode_querys_info:\n                decode_kv_block_len = (decode_query_info.active_position + 1 + page_size - 1) // page_size\n                self.q_indptr = torch.concat((self.q_indptr, torch.tensor([1 + self.q_indptr[-1]], device=device, dtype=torch.int32)), dim=0)\n                self.kv_indptr = torch.concat((self.kv_indptr, torch.tensor([decode_kv_block_len+self.kv_indptr[-1]], device=device, dtype=torch.int32)), dim=0)\n                self.kv_indices = torch.concat((self.kv_indices, decode_query_info.block_index[:decode_kv_block_len]), dim=0)\n                self.kv_last_page_len = torch.concat((self.kv_last_page_len, torch.tensor([(decode_query_info.active_position + 1) % page_size if (decode_query_info.active_position + 1) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0)\n                self.kv_len = torch.concat((self.kv_len, torch.tensor([(decode_query_info.active_position + 1)], device=device, dtype=torch.int32)), dim=0)\n                self.position_ids = torch.concat((self.position_ids, torch.arange(decode_query_info.active_position, decode_query_info.active_position + 1, device=device, dtype=torch.int32)), dim=0)\n                if decode_query_info.active_position > 0:\n                    self.tokens = torch.concat((self.tokens, decode_query_info.query_tokens[decode_query_info.active_position:decode_query_info.active_position+1]), dim=0)\n                else: \n                    self.tokens = torch.concat((self.tokens, torch.tensor([0], device=device, dtype=torch.int32)), dim=0)\n                self.logits_start.append(0 if len(self.logits_start) == 0 else self.logits_start[-1]+1)\n\n                
self.temperatures = torch.concat((self.temperatures, torch.tensor([decode_query_info.temperature], device=device, dtype=torch.float32)), dim=0)\n                self.top_ps = torch.concat((self.top_ps, torch.tensor([decode_query_info.top_p], device=device, dtype=torch.float32)), dim=0)\n\n            self.q_indptr = self.q_indptr.contiguous()\n            self.kv_indptr = self.kv_indptr.contiguous()\n            self.kv_indices = self.kv_indices.contiguous()\n            self.kv_len = self.kv_len.contiguous()\n            self.kv_last_page_len = self.kv_last_page_len.contiguous()\n            self.position_ids = self.position_ids.contiguous()\n            self.tokens = self.tokens.contiguous()\n\n            self.bsz_tensor = torch.tensor([self.batch_size], device=device, dtype=torch.int32)\n\n        def fill(self, prefill_querys_info: list[QueryInfo], decode_querys_info: list[QueryInfo], prefill_s: list[int] = None, prefill_l: list[int] = None, device = torch.device('cuda'), page_size = 256):\n            batch_decode = len(decode_querys_info)\n            batch_prefill = len(prefill_querys_info)\n\n            self.q_indptr = torch.tensor([0], device=device, dtype=torch.int32)\n            self.kv_indptr = torch.tensor([0], device=device, dtype=torch.int32)\n            self.kv_indices = torch.tensor([], device=device, dtype=torch.int32)\n            self.kv_len = torch.tensor([], device=device, dtype=torch.int32)\n            self.kv_last_page_len = torch.tensor([], device=device, dtype=torch.int32)\n            new_position_ids = torch.tensor([], device=device, dtype=torch.int32)\n            new_tokens = torch.tensor([], device=device, dtype=torch.int32)\n\n            self.temperatures = torch.tensor([], device=device, dtype=torch.float32)\n            self.top_ps = torch.tensor([], device=device, dtype=torch.float32)\n\n            self.logits_start = []\n            self.decode_batch = batch_decode\n            self.num_tokens = batch_decode + 
sum(prefill_l)\n            self.batch_size = batch_decode + batch_prefill\n\n            for i, prefill_query_info in enumerate(prefill_querys_info):\n                prefill_kv_block_len = (prefill_query_info.active_position + prefill_l[i] + page_size - 1) // page_size if prefill_query_info is not None else 0\n            # print(f\"block_len: {prefill_kv_block_len}, page_size: {page_size}\")\n                self.q_indptr = torch.concat((self.q_indptr, torch.tensor([prefill_l[i] + self.q_indptr[-1]], device=device, dtype=torch.int32)), dim=0)\n                self.kv_indptr = torch.concat((self.kv_indptr, torch.tensor([prefill_kv_block_len + self.kv_indptr[-1]], device=device, dtype=torch.int32)), dim=0)\n                self.kv_indices = torch.concat((self.kv_indices, prefill_query_info.block_index[:prefill_kv_block_len]), dim=0)\n                self.kv_last_page_len = torch.concat((self.kv_last_page_len, torch.tensor([(prefill_query_info.active_position + prefill_l[i]) % page_size if (prefill_query_info.active_position + prefill_l[i]) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0)\n                self.kv_len = torch.concat((self.kv_len, torch.tensor([(prefill_query_info.active_position + prefill_l[i])], device=device, dtype=torch.int32)), dim=0)\n                new_position_ids = torch.concat((new_position_ids, torch.arange(prefill_s[i], prefill_l[i] + prefill_s[i], device=device, dtype=torch.int32)), dim=0)\n                new_tokens = torch.concat((new_tokens, prefill_query_info.query_tokens[prefill_s[i]:prefill_s[i] + prefill_l[i]]), dim=0)\n                self.logits_start.append(prefill_l[i] - 1 if len(self.logits_start) == 0 else sum(prefill_l[:i+1])-1)\n\n                self.temperatures = torch.concat((self.temperatures, torch.tensor([prefill_query_info.temperature], device=device, dtype=torch.float32)), dim=0)\n                self.top_ps = torch.concat((self.top_ps, torch.tensor([prefill_query_info.top_p], 
device=device, dtype=torch.float32)), dim=0)\n\n\n            for decode_query_info in decode_querys_info:\n                decode_kv_block_len = (decode_query_info.active_position + 1 + page_size - 1) // page_size\n                self.q_indptr = torch.concat((self.q_indptr, torch.tensor([1 + self.q_indptr[-1]], device=device, dtype=torch.int32)), dim=0)\n                self.kv_indptr = torch.concat((self.kv_indptr, torch.tensor([decode_kv_block_len+self.kv_indptr[-1]], device=device, dtype=torch.int32)), dim=0)\n                self.kv_indices = torch.concat((self.kv_indices, decode_query_info.block_index[:decode_kv_block_len]), dim=0)\n                self.kv_last_page_len = torch.concat((self.kv_last_page_len, torch.tensor([(decode_query_info.active_position + 1) % page_size if (decode_query_info.active_position + 1) % page_size != 0 else page_size], device=device, dtype=torch.int32)), dim=0)\n                self.kv_len = torch.concat((self.kv_len, torch.tensor([(decode_query_info.active_position + 1)], device=device, dtype=torch.int32)), dim=0)\n                new_position_ids = torch.concat((new_position_ids, torch.arange(decode_query_info.active_position, decode_query_info.active_position + 1, device=device, dtype=torch.int32)), dim=0)\n                if decode_query_info.active_position > 0:\n                    new_tokens = torch.concat((new_tokens, decode_query_info.query_tokens[decode_query_info.active_position:decode_query_info.active_position+1]), dim=0)\n                else: \n                    new_tokens = torch.concat((new_tokens, torch.tensor([0], device=device, dtype=torch.int32)), dim=0)\n                self.logits_start.append(0 if len(self.logits_start) == 0 else self.logits_start[-1]+1)\n\n                self.temperatures = torch.concat((self.temperatures, torch.tensor([decode_query_info.temperature], device=device, dtype=torch.float32)), dim=0)\n                self.top_ps = torch.concat((self.top_ps, 
torch.tensor([decode_query_info.top_p], device=device, dtype=torch.float32)), dim=0)\n\n\n            self.q_indptr = self.q_indptr.contiguous()\n            self.kv_indptr = self.kv_indptr.contiguous()\n            self.kv_indices = self.kv_indices.contiguous()\n            self.kv_len = self.kv_len.contiguous()\n            self.kv_last_page_len = self.kv_last_page_len.contiguous()\n\n            self.bsz_tensor = torch.tensor([self.batch_size], device=device, dtype=torch.int32)\n            \n            # copy new_position_ids and new_tokens to self.position_ids and self.tokens\n            # print(\"new_position_ids: \", new_position_ids)\n            # self.print()\n            self.position_ids[:new_position_ids.size(0)].copy_(new_position_ids)\n            self.position_ids[new_position_ids.size(0):].zero_()\n            self.tokens[:new_tokens.size(0)].copy_(new_tokens)\n\n\n    forward_minibatchs: list[ForwardMiniBatch]\n    batch_size: int\n    minibatch: ForwardMiniBatch\n\n\n\n    def __init__(self, batch : sched_ext.BatchQueryTodo = None, query_manager: QueryManager = None, device=None, tokens: torch.Tensor = None):\n        \n        if batch is None:\n            return\n\n\n        prefill_minibatches = batch.prefill_mini_batches\n        decode_mini_batches = [item for sublist in batch.decode_mini_batches for item in sublist]\n        prefill_querys_info = []\n        prefill_s = []\n        prefill_l = []\n        decode_querys_info = []\n        self.batch_size = 1\n        for (id, s, l) in prefill_minibatches:\n            prefill_querys_info.append(query_manager.query_map[id])\n            prefill_s.append(s)\n            prefill_l.append(l)\n        for decode_batch_idx in decode_mini_batches:\n            if query_manager.query_map[decode_batch_idx].decode_start_time is None:\n                query_manager.query_map[decode_batch_idx].decode_start_time =time.time()\n            
decode_querys_info.append(query_manager.query_map[decode_batch_idx])\n\n\n        minibatch = ForwardBatchInput.ForwardMiniBatch(prefill_querys_info, decode_querys_info, prefill_s, prefill_l, device = query_manager.device, page_size = query_manager.page_size)\n \n        self.minibatch = minibatch\n\n    @classmethod\n    def gen_max_forward_batch(\n        cls,\n        device=None,\n        tokens: torch.Tensor = None,\n        num_mini_batches: int = 1,\n        max_seq_length: int = 4096, # TODO: add to yaml\n        prefill_query_length: int = (Config().chunk_size - Config().max_decode_batch_size) // Config().max_prefill_batch_size, # TODO: use config\n        prefill_active_length: int = (Config().chunk_size - Config().max_decode_batch_size) // Config().max_prefill_batch_size,\n        gen_prefill: bool = True,\n        decode_batch_size: int = Config().max_decode_batch_size,\n        decode_active_position: torch.Tensor = None,\n        page_size = 256,\n        cuda_lens = 1\n    ):\n        instance = cls()\n        \n        instance.batch_size = num_mini_batches\n        page_size = page_size\n     \n        prefill_query_info = []\n        offset = 0\n        if gen_prefill and prefill_query_length != 0:\n            for i in range(Config().max_prefill_batch_size):\n                prefill_query_info.append(QueryInfo(i, prefill_query_length, max_seq_length, page_size, device, offset=offset))\n                offset += max_seq_length // page_size\n\n        decode_querys_info = []\n        for i in range(min(decode_batch_size, cuda_lens)):\n            query_info = QueryInfo(i+Config().max_prefill_batch_size, prefill_query_length, 256, page_size, device, is_prefill=False, offset=offset)\n            offset += max_seq_length // page_size\n            if tokens is not None:\n                query_info.query_tokens[prefill_active_length:prefill_active_length + 1].copy_(tokens)            \n            if decode_active_position is None:\n                
query_info.active_position = 255\n            else: \n                query_info.active_position = decode_active_position[i]\n\n            decode_querys_info.append(query_info)\n        \n        if prefill_query_length*Config().max_prefill_batch_size + len(decode_querys_info) < cuda_lens:\n            decode_querys_info.append(query_info)\n\n        instance.minibatch = ForwardBatchInput.ForwardMiniBatch(prefill_query_info, decode_querys_info, [0, 0], [prefill_active_length for _ in range(Config().max_prefill_batch_size)], device, page_size)\n        \n        return instance\n\n    def fill(self, batch : sched_ext.BatchQueryTodo = None, query_manager: QueryManager = None, page_size = 256):\n        if batch is None:\n            return\n        prefill_minibatches = batch.prefill_mini_batches\n        decode_mini_batches = [item for sublist in batch.decode_mini_batches for item in sublist]\n\n        prefill_querys_info = []\n        prefill_s = []\n        prefill_l = []\n        decode_querys_info = []\n        self.batch_size = 1\n        for (id, s, l) in prefill_minibatches:\n            prefill_querys_info.append(query_manager.query_map[id])\n            prefill_s.append(s)\n            prefill_l.append(l)\n        for decode_batch_idx in decode_mini_batches:\n            if query_manager.query_map[decode_batch_idx].decode_start_time is None:\n                query_manager.query_map[decode_batch_idx].decode_start_time =time.time()\n            decode_querys_info.append(query_manager.query_map[decode_batch_idx])\n\n        self.minibatch.fill(prefill_querys_info, decode_querys_info, prefill_s, prefill_l, device=query_manager.device, page_size=page_size)\n\n\n\nclass ForwardBatchOutput:\n    logits: list[torch.Tensor]\n    num_batchs: int\n    batch_sizes: list[int]\n    generated_tokens_num: list[int]\n    lm_start: list[int]\n    \n    temperatures: list[torch.Tensor]\n    top_ps: list[torch.Tensor]\n\n    def __init__(self):\n        self.logits = []\n    
    self.batch_sizes = []\n        self.generated_tokens_num = []\n        self.top_ps = []\n        self.temperatures = []\n        self.num_batchs = 1"
  },
  {
    "path": "kt-sft/ktransformers/server/balance_serve/inference/model_runner.py",
    "content": "\"\"\"\nDate: 2024-11-07 07:02:20\nLastEditors: djw\nLastEditTime: 2024-12-10 08:48:32\n\"\"\"\n\nimport torch\nfrom torch import nn\nimport queue\nimport signal\nfrom typing import AsyncIterable\nfrom fastapi import FastAPI, Request\nfrom fastapi.responses import StreamingResponse\nfrom contextlib import asynccontextmanager\nfrom pydantic import BaseModel, Field\nimport asyncio\nimport multiprocessing\nimport time\nimport torch.multiprocessing as mp\nimport random\nimport torch.distributed as dist\nimport zmq\nimport tempfile\nfrom ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput\n\nfrom ktransformers.server.config.config import Config\nfrom ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM\nfrom ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM\nfrom ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM\nfrom ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM\nfrom ktransformers.server.balance_serve.inference.query_manager import QueryManager\nfrom ktransformers.server.balance_serve.settings import sched_ext\n\n\n\ndef pad_num_tokens(num_tokens):\n    return (num_tokens + 63) // 64 * 64\n\ndef deduplicate_and_sort(lst):\n    return sorted(set(lst))\n\ndef generate_cuda_graphs(chunk_size: int) -> list:\n    assert chunk_size <= 1024 or chunk_size % 1024 == 0, \"chunk_size must be <= 1024 or a multiple of 1024\"\n    base_list = [1, 2, 3, Config().max_batch_size, 64, 256, 512, chunk_size]\n\n    if chunk_size <= 1024:\n        return deduplicate_and_sort(base_list)\n\n    multiples = [i for i in range(1024, chunk_size + 1, 1024)]\n\n    return deduplicate_and_sort(base_list + multiples)\n\nclass ModelRunner:\n    \"\"\"A ModelRunner runs the forward pass of a model with CUDA graph and torch.compile.\"\"\"\n\n    model: KDeepseekV3ForCausalLM  | KQwen2MoeForCausalLM | 
KQwen3MoeForCausalLM \n    input: ForwardBatchInput | list[ForwardBatchInput]\n    output: ForwardBatchOutput\n    \n    def __init__(self, model = None, device = None, use_cuda_graph = False, max_decode_batch_size = 1, max_chunk_size = 4096, num_mini_batches: int = 1, page_size = 256, block_num = 8):\n        \n        self.stream = torch.cuda.Stream(device=device)\n        self.model = model  # Compile and move model to the specified device\n        self.device = device\n        self.input = None\n        self.features_buf = None\n        self.output = None\n        self.graph_memory_pool = None\n        self.cuda_graphs = generate_cuda_graphs(Config().chunk_size)\n        self.use_cuda_graph = use_cuda_graph\n        self.model_time = 0\n        self.page_size = page_size\n        self.block_num = block_num\n        # GPU timing for model execution\n        self.start_model_event = torch.cuda.Event(enable_timing=True)\n        self.end_model_event = torch.cuda.Event(enable_timing=True)\n\n        self.graphs = [torch.cuda.CUDAGraph() for _ in range(len(self.cuda_graphs))]\n        self.page_idx_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device = self.device) for i in range(len(self.cuda_graphs))]\n        self.page_offset_buf = [torch.zeros([self.cuda_graphs[i]], dtype=torch.int32, device = self.device) for i in range(len(self.cuda_graphs))]\n \n        self.num_mini_batches = num_mini_batches\n\n        self.max_chunk_size = max_chunk_size\n\n        self.bsz_tensor_buf = torch.empty((1, ),dtype=torch.int32, device=device)\n        self.num_tokens_tensor_buf = torch.empty((1, ),dtype=torch.int32, device=device)\n\n    def model_attn_plan(self, batch, cuda_graph_idx=0):\n        if isinstance(self.model, KDeepseekV3ForCausalLM):\n            self.model.flash_infer_attn_plan(batch, self.bsz_tensor_buf, self.num_tokens_tensor_buf,\n                                             num_heads=self.model.config.num_attention_heads, 
head_dim_ckv=self.model.config.kv_lora_rank, \n                                             head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.model.cache.page_size, causal=True,\n                                             sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)\n        elif isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM):\n            self.model.flash_infer_attn_plan(batch, self.bsz_tensor_buf, self.num_tokens_tensor_buf,\n                                             num_q_heads=self.model.config.num_attention_heads, num_kv_heads=self.model.config.num_key_value_heads,\n                                             head_dim=self.model.config.head_dim if hasattr(self.model.config, 'head_dim') else self.model.config.hidden_size // self.model.config.num_attention_heads, \n                                             page_size=self.model.cache.page_size, causal=True,\n                                             q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, cuda_graph_idx=cuda_graph_idx)\n        else:\n            assert False, \"model type not supported\"\n\n\n    def warmup(self):\n\n        def capture_graphs(cuda_graph_idx):\n            with torch.cuda.graph(self.graphs[cuda_graph_idx], pool=self.graph_memory_pool, stream=self.stream):\n                self.outputs_buf[cuda_graph_idx] = self.model(self.input[cuda_graph_idx], self.features_buf[cuda_graph_idx], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[cuda_graph_idx], self.page_offset_buf[cuda_graph_idx], cuda_graph_idx=cuda_graph_idx)   \n            self.graph_memory_pool = self.graphs[cuda_graph_idx].pool()\n\n        self.input = []\n        self.features_buf = []\n        self.outputs_buf = []\n        self.bsz_tensor_buf = torch.tensor([0], dtype=torch.int32, device=self.device)\n        self.num_tokens_tensor_buf = torch.tensor([0], 
dtype=torch.int32, device=self.device)\n        for i in range(len(self.cuda_graphs)):\n            prefill_query_length = (self.cuda_graphs[i] - Config().max_decode_batch_size) // Config().max_prefill_batch_size if self.cuda_graphs[i] > Config().max_decode_batch_size else 0  # TODO: only supports 2 prefill batches\n            self.input.append(ForwardBatchInput.gen_max_forward_batch(device=self.device, num_mini_batches = self.num_mini_batches, prefill_query_length=prefill_query_length, prefill_active_length=prefill_query_length, page_size=self.page_size, cuda_lens=self.cuda_graphs[i]))\n\n            self.features_buf.append(self.model.batch_embeddings(self.input[i]))\n            batch_size = self.input[i].minibatch.q_indptr.size(0)-1\n            num_tokens = self.features_buf[i][0].size(0)\n            print(\"capturing cuda graph\", batch_size, num_tokens)\n\n            if isinstance(self.model, KQwen2MoeForCausalLM) or isinstance(self.model, KQwen3MoeForCausalLM):\n                self.model.init_wrapper(self.use_cuda_graph, self.device, num_tokens ,batch_size, self.block_num, i) # TODO: 1024 is a magic number(max_batch_tokens)\n\n            self.bsz_tensor_buf[0] = batch_size\n            self.num_tokens_tensor_buf[0] = num_tokens\n\n            self.model_attn_plan(self.input[i], i)\n        \n            page_idx, page_offset = self.model.cache.get_page_table(self.input[i].minibatch.position_ids, self.input[i].minibatch.q_indptr, self.input[i].minibatch.kv_indptr, self.input[i].minibatch.kv_indices, self.num_tokens_tensor_buf)\n\n            \n            self.page_idx_buf[i][:num_tokens].copy_(page_idx[:num_tokens])\n            self.page_offset_buf[i][:num_tokens].copy_(page_offset[:num_tokens])\n\n            self.page_idx_buf[i][num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size -1) \n        \n            self.outputs_buf.append(None)\n        \n            torch.cuda.synchronize()\n            for warm_up_iters in 
range(11):\n                with torch.cuda.stream(self.stream):\n                    self.outputs_buf[i] = self.model(self.input[i], self.features_buf[i], self.bsz_tensor_buf, self.num_tokens_tensor_buf, self.page_idx_buf[i], self.page_offset_buf[i], cuda_graph_idx=i)\n            torch.cuda.synchronize()\n\n            self.outputs_buf[i].num_batchs = batch_size\n\n            capture_graphs(i)\n\n            with torch.cuda.stream(self.stream):\n                self.graphs[i].replay()\n\n            self.sync(calc_time=False)\n            print(f\"cuda_graph: {i+1}/{len(self.cuda_graphs)}, warmup finished.\")\n        \n    def run(self, batch: sched_ext.BatchQueryTodo = None, query_manager: QueryManager = None):\n        with torch.cuda.stream(self.stream):\n\n            batch_size = len(batch.prefill_mini_batches) # TODO: calc this\n            num_tokens = 0\n            for i in range(len(batch.decode_mini_batches)):\n                batch_size += len(batch.decode_mini_batches[i])\n                num_tokens += len(batch.decode_mini_batches[i])\n                print(f'decode_batch_i: {len(batch.decode_mini_batches[i])},')\n\n            for i in range(len(batch.prefill_mini_batches)):\n                num_tokens += batch.prefill_mini_batches[i][2]\n                print(f'prefill_batch_i: {batch.prefill_mini_batches[i][2]},')\n\n\n\n            # cuda graph idx equal to min idx i in self.cuda_graphs, that self.cuda_graphs[i] > num_tokens\n            cuda_graph_idx = next((i for i, token in enumerate(self.cuda_graphs) if token >= num_tokens), len(self.cuda_graphs))\n            if not self.use_cuda_graph:\n                cuda_graph_idx = 0\n            # if cuda_graph_idx == len(self.cuda_graphs):\n            #     assert False, \"num_tokens is too large\"\n    \n            if self.use_cuda_graph:\n                self.input[cuda_graph_idx].fill(batch, query_manager, self.page_size)\n            else:\n                self.input = 
[ForwardBatchInput(batch=batch, query_manager=query_manager, device=self.device)]\n                \n\n            if self.use_cuda_graph:\n                self.features = self.model.batch_embeddings(self.input[cuda_graph_idx], device=self.device)\n            else:\n                self.features = self.model.batch_embeddings(self.input[cuda_graph_idx], device=self.device)\n\n\n            self.bsz_tensor_buf.copy_(batch_size)\n            self.num_tokens_tensor_buf.copy_(torch.tensor([num_tokens], dtype=torch.int32, device=self.device))\n\n            if self.use_cuda_graph:\n                self.features_buf[cuda_graph_idx][0].copy_(self.features[0], non_blocking=True)\n\n            self.model_attn_plan(self.input[cuda_graph_idx], cuda_graph_idx)\n            self.start_model_event.record(self.stream)\n            page_idx, page_offset = self.model.cache.get_page_table(self.input[cuda_graph_idx].minibatch.position_ids, self.input[cuda_graph_idx].minibatch.q_indptr, self.input[cuda_graph_idx].minibatch.kv_indptr, self.input[cuda_graph_idx].minibatch.kv_indices, self.num_tokens_tensor_buf)\n            if self.use_cuda_graph:\n                self.page_idx_buf[cuda_graph_idx][:num_tokens].copy_(page_idx[:num_tokens])\n                self.page_offset_buf[cuda_graph_idx][:num_tokens].copy_(page_offset[:num_tokens])\n\n                self.page_idx_buf[cuda_graph_idx][num_tokens:].fill_(self.model.cache.max_cache_len // self.model.cache.page_size -1)\n                self.replay(cuda_graph_idx)\n                self.output = ForwardBatchOutput()\n                \n                self.output.top_ps.append(self.input[cuda_graph_idx].minibatch.top_ps)\n                self.output.temperatures.append(self.input[cuda_graph_idx].minibatch.temperatures)\n\n\n                self.output.logits.append(self.outputs_buf[cuda_graph_idx].logits[0][self.input[cuda_graph_idx].minibatch.logits_start].clone())\n            else:\n                self.output = 
self.model(self.input[cuda_graph_idx], self.features, self.bsz_tensor_buf, self.num_tokens_tensor_buf, page_idx, page_offset)\n                self.output.logits[0] = self.output.logits[0][self.input[cuda_graph_idx].minibatch.logits_start]\n                self.output.top_ps.append(self.input[cuda_graph_idx].minibatch.top_ps)\n                self.output.temperatures.append(self.input[cuda_graph_idx].minibatch.temperatures)\n            self.end_model_event.record(self.stream)\n\n\n\n    def replay(self, cuda_graph_idx=-1):\n        with torch.cuda.stream(self.stream):\n            if cuda_graph_idx != -1:\n                self.graphs[cuda_graph_idx].replay()\n            else:\n                # self.graphs is a list of CUDAGraph objects, so an explicit index is required\n                raise ValueError(\"cuda_graph_idx must be specified to replay a captured CUDA graph\")\n\n\n    def sync(self, calc_time = True):\n        self.stream.synchronize()\n        if calc_time:\n            self.model_time = self.start_model_event.elapsed_time(self.end_model_event)  # In ms"
  },
  {
    "path": "kt-sft/ktransformers/server/balance_serve/inference/query_manager.py",
    "content": "'''\nDate: 2024-11-14 12:23:45\nLastEditors: djw\nLastEditTime: 2024-11-20 04:06:23\n'''\nimport torch\nfrom ktransformers.server.balance_serve.settings import sched_ext\nimport random\nimport time\n\nclass QueryInfo:\n    id: int\n    active_position: int\n    query_length: int\n    is_prefill: int\n    block_index: torch.Tensor\n    query_tokens: torch.Tensor\n    stop_criteria: list[torch.Tensor]\n\n    temperature: float\n    top_p: float\n\n    max_length: int \n\n    def __init__(self, id, query_length: int, max_length: int, page_size: int, device: torch.device, is_prefill: bool = True, offset: int = 0, active_position: int = 0, temperature: float = 0.01, top_p: float = 1.0):\n        self.id = id\n        self.is_prefill = is_prefill\n        self.active_position = active_position\n        self.max_length = max_length - 1\n        self.query_tokens = torch.zeros((max_length,), dtype=torch.int, device = device)\n        self.stop_criteria = []\n        self.block_index = torch.arange(offset, offset + (max_length + active_position + page_size - 1) // page_size, dtype=torch.int, device = device)\n        self.query_length = query_length\n        self.enqueue_time = time.time()\n        self.decode_start_time = None\n        self.speculative_token = {} # {position: (accept, token)}\n\n        self.temperature = temperature\n        self.top_p = top_p\n\n    def check_stop(self):\n        if self.active_position >= self.max_length - 2:\n            return True\n\n        for stop_tensor in self.stop_criteria:\n            stop_len = len(stop_tensor)\n            \n            if stop_len >= self.active_position:\n                continue\n            \n            #print(f\"stop_tensor: {stop_tensor}, stop_len: {stop_len}, active_position: {self.active_position}, query_token: {self.query_tokens[self.active_position - stop_len - 1:self.active_position - 1]}\")\n\n            if (torch.equal(self.query_tokens[self.active_position - stop_len - 
1:self.active_position - 1], stop_tensor) and self.active_position) or self.max_length <= self.active_position + 3:\n                self.life_time = time.time() - self.enqueue_time\n                self.decode_duration_time = time.time() - self.decode_start_time\n                self.decode_tps = (self.active_position -  self.query_length) / self.decode_duration_time\n                print(f\"prefill length: {self.query_length}, prefill time: {self.prefill_duration_time}, prefill tps {self.prefill_tps}, decode length: {self.active_position -  self.query_length}, decode time: {self.decode_duration_time}, decode tps {self.decode_tps}\")\n                return True\n                \n        \n        return False\n\n\n    def print(self):\n        print(f\"active_position: {self.active_position}, query_length: {self.query_length}, is_prefill: {self.is_prefill}\")\n        print(f\"block_index_shape: {self.block_index.shape}, query_tokens_shape: {self.query_tokens.shape}\")\n\n\nclass QueryManager:\n\n    page_size: int = 256\n    device: torch.device\n    query_map : dict[int, QueryInfo]\n\n    def __init__(self, page_size = 256, device = torch.device('cuda')):\n        self.page_size = page_size\n        self.device = device\n        self.query_map = {}\n\n    def add_query(self, batch: sched_ext.BatchQueryTodo):\n\n        for i in range(len(batch.query_ids)):\n            id = batch.query_ids[i]\n            if id not in self.query_map:\n                print(f\"add query id: {id}, batch.query_lengths: {batch.query_lengths[i]}, batch_query_tokens: {batch.query_tokens[i].shape}, batch.block_indexes: {batch.block_indexes[i]}\")\n                query_info = QueryInfo(id=id, query_length=batch.query_lengths[i], max_length=batch.query_tokens[i].size(0) + 1, page_size=self.page_size, device=self.device, temperature=batch.sample_options[i].temperature, top_p=batch.sample_options[i].top_p)\n                
query_info.query_tokens[:query_info.query_length].copy_(batch.query_tokens[i][:query_info.query_length].to(self.device))\n                \n                for stop_token_list in batch.stop_criteria[i]:\n                    query_info.stop_criteria.append(torch.tensor(stop_token_list, dtype=torch.int, device = self.device))\n\n                block_num = batch.block_indexes[i].size(0)\n                query_info.block_index[:block_num].copy_(batch.block_indexes[i].to(self.device))\n\n                self.query_map[id] = query_info\n                \n                prefill_mini_batches = batch.prefill_mini_batches\n                for (prefill_id, s, l) in prefill_mini_batches:\n                    if prefill_id == id:\n                        self.query_map[prefill_id].active_position = s\n\n\n    def update(self, batch: sched_ext.BatchQueryTodo) -> list[sched_ext.QueryUpdate]:\n        query_updates = []\n\n        prefill_mini_batches = batch.prefill_mini_batches\n\n        for (id, s, l) in prefill_mini_batches:\n\n            if id not in self.query_map:\n                assert False, f\"query id {id} not found in query_map\"\n\n            # update query_info\n            query_info = self.query_map[id]\n            query_info.active_position += l\n\n            if query_info.active_position >= query_info.query_length and query_info.is_prefill:\n                query_info.is_prefill = False\n                query_info.prefill_duration_time = time.time() - query_info.enqueue_time\n                query_info.prefill_tps = query_info.query_length / query_info.prefill_duration_time\n                \n\n            # generate schedule query_update\n            query_update = sched_ext.QueryUpdate()\n            query_update.id = id\n            query_update.ok = True\n            query_update.is_prefill = query_info.is_prefill\n            query_update.active_position = query_info.active_position\n            # if(not query_info.is_prefill):\n            
query_updates.append(query_update)\n\n\n        decode_mini_batches = batch.decode_mini_batches\n\n        for ids in decode_mini_batches:\n            for id in ids:\n                if id not in self.query_map:\n                    assert False, f\"query id {id} not found in query_map\"\n\n                query_info = self.query_map[id]\n                query_info.active_position += 1\n\n                query_update = sched_ext.QueryUpdate()\n                query_update.id = id\n                query_update.ok = True\n                query_update.is_prefill = query_info.is_prefill\n\n                query_update.decode_done = query_info.check_stop()\n\n                query_update.active_position = query_info.active_position\n                query_updates.append(query_update)\n\n        return query_updates\n"
  },
  {
    "path": "kt-sft/ktransformers/server/balance_serve/inference/sampling/penaltylib/__init__.py",
    "content": "from .orchestrator import BatchedPenalizerOrchestrator\nfrom .penalizers.frequency_penalty import BatchedFrequencyPenalizer\nfrom .penalizers.min_new_tokens import BatchedMinNewTokensPenalizer\nfrom .penalizers.presence_penalty import BatchedPresencePenalizer\nfrom .penalizers.repetition_penalty import BatchedRepetitionPenalizer\n\n__all__ = [\n    \"BatchedFrequencyPenalizer\",\n    \"BatchedMinNewTokensPenalizer\",\n    \"BatchedPresencePenalizer\",\n    \"BatchedRepetitionPenalizer\",\n    \"BatchedPenalizerOrchestrator\",\n]\n"
  },
  {
    "path": "kt-sft/ktransformers/server/balance_serve/inference/sampling/penaltylib/orchestrator.py",
    "content": "import abc\nimport dataclasses\nimport typing\n\nimport torch\n\n\n@dataclasses.dataclass\nclass _ReqLike:\n    origin_input_ids: typing.Union[torch.Tensor, typing.List[int]]\n\n\n@dataclasses.dataclass\nclass _BatchLike:\n    reqs: typing.List[_ReqLike]\n\n    def batch_size(self):\n        return len(self.reqs)\n\n\nclass BatchedPenalizerOrchestrator:\n    batch: _BatchLike\n    device: str\n    vocab_size: int\n    penalizers: typing.Dict[typing.Type[\"_BatchedPenalizer\"], \"_BatchedPenalizer\"]\n\n    def __init__(\n        self,\n        vocab_size: int,\n        batch: _BatchLike,\n        device: str,\n        Penalizers: typing.Set[typing.Type[\"_BatchedPenalizer\"]],\n    ):\n        self.vocab_size = vocab_size\n        self.batch = batch\n        self.device = device\n\n        self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}\n\n        is_required = False\n        for penalizer in self.penalizers.values():\n            pen_is_required = penalizer.prepare_if_required()\n            is_required |= pen_is_required\n        self.is_required = is_required\n\n        if self.is_required:\n            self.cumulate_input_tokens(\n                input_ids=[req.origin_input_ids for req in self.reqs()]\n            )\n\n    def reqs(self):\n        return self.batch.reqs\n\n    def batch_size(self):\n        return self.batch.batch_size()\n\n    def cumulate_input_tokens(\n        self,\n        input_ids: typing.Union[\n            typing.List[torch.Tensor], typing.List[typing.List[int]]\n        ],\n    ):\n        \"\"\"\n        Feed the input tokens to the penalizers.\n\n        Args:\n            input_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The input tokens.\n        \"\"\"\n        token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids)\n\n        for penalizer in self.penalizers.values():\n            penalizer.cumulate_input_tokens(input_ids=token_ids)\n\n    def 
cumulate_output_tokens(\n        self,\n        output_ids: typing.Union[\n            typing.List[torch.Tensor], typing.List[typing.List[int]]\n        ],\n    ):\n        \"\"\"\n        Feed the output tokens to the penalizers.\n\n        Args:\n            output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens.\n        \"\"\"\n        if not self.is_required:\n            return\n\n        token_ids = _TokenIDs(orchestrator=self, token_ids=output_ids)\n\n        for penalizer in self.penalizers.values():\n            penalizer.cumulate_output_tokens(output_ids=token_ids)\n\n    def apply(self, logits: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Apply the penalizers to the logits.\n        Note that it may apply the penalizers in-place.\n\n        Args:\n            logits (torch.Tensor): The logits to apply the penalizers to.\n\n        Returns:\n            torch.Tensor: The logits after applying the penalizers.\n        \"\"\"\n        if not self.is_required:\n            return logits\n\n        for penalizer in self.penalizers.values():\n            logits = penalizer.apply(logits)\n\n        return logits\n\n    def filter(\n        self,\n        indices_to_keep: typing.List[int],\n        indices_tensor_to_keep: torch.Tensor = None,\n    ):\n        \"\"\"\n        Filter the penalizers based on the indices to keep in the batch.\n\n        Args:\n            indices_to_keep (typing.List[int]): List of indices to keep in the batch.\n            indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. 
If not None, it will be used instead of converting indices_to_keep to a tensor.\n        \"\"\"\n        if not self.is_required:\n            return\n\n        empty_indices = len(indices_to_keep) == 0\n\n        is_required = False\n        for penalizer in self.penalizers.values():\n            tmp_is_required = penalizer.is_required()\n            is_required = is_required or tmp_is_required\n            if not tmp_is_required or empty_indices:\n                penalizer.teardown()\n            else:\n                # create tensor index only when it's needed\n                if indices_tensor_to_keep is None:\n                    indices_tensor_to_keep = torch.tensor(\n                        indices_to_keep, dtype=torch.int32, device=self.device\n                    )\n\n                penalizer.filter(\n                    indices_to_keep=indices_to_keep,\n                    indices_tensor_to_keep=indices_tensor_to_keep,\n                )\n        self.is_required = is_required\n\n    def merge(self, their: \"BatchedPenalizerOrchestrator\"):\n        \"\"\"\n        Merge the penalizers of another orchestrator into this one.\n\n        Note that this function **must** be called _before_ self.batch.reqs is updated (filtered).\n        Each unprepared penalizers would have to be prepared (creating tensors, etc.) 
first before merging.\n        This step requires the original batch.reqs, before it gets merged with other batch.reqs.\n\n        Args:\n            their (BatchedPenalizerOrchestrator): The orchestrator to merge into this one.\n        \"\"\"\n        if not self.is_required and not their.is_required:\n            return\n\n        self.is_required |= their.is_required\n        for Penalizer, their_penalizer in their.penalizers.items():\n            if Penalizer not in self.penalizers:\n                raise ValueError(f\"Penalizer {Penalizer} not found in self.penalizers\")\n\n            self.penalizers[Penalizer].merge(their_penalizer)\n\n\nclass _TokenIDs:\n    \"\"\"\n    A class that wraps token IDs to provide additional utility functions to penalizers.\n\n    Attributes:\n        orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to.\n        token_ids (typing.Union[torch.Tensor, typing.List[torch.Tensor]]): The token IDs.\n        cached_counts (torch.Tensor): The cached occurrence count tensor.\n    \"\"\"\n\n    orchestrator: BatchedPenalizerOrchestrator\n    token_ids: typing.Union[torch.Tensor, typing.List[torch.Tensor]]\n    cached_counts: torch.Tensor = None\n\n    def __init__(\n        self,\n        orchestrator: BatchedPenalizerOrchestrator,\n        token_ids: typing.Union[\n            typing.List[torch.Tensor], typing.List[typing.List[int]]\n        ],\n    ):\n        self.orchestrator = orchestrator\n\n        if not isinstance(token_ids[0], torch.Tensor):\n            token_ids = [\n                torch.tensor(\n                    data=ids, dtype=torch.int64, device=self.orchestrator.device\n                )\n                for ids in token_ids\n            ]\n\n        self.token_ids = token_ids\n\n    def occurrence_count(self) -> torch.Tensor:\n        \"\"\"\n        Returns a tensor of shape (batch_size, vocab_size) where each element is the number of times the corresponding token appears in 
the batch.\n\n        Returns:\n            torch.Tensor: The occurrence count tensor.\n        \"\"\"\n        if self.cached_counts is not None:\n            return self.cached_counts\n\n        token_ids = self.token_ids\n\n        if isinstance(token_ids, torch.Tensor):\n            token_ids = token_ids.unsqueeze(1)\n\n            # needs to be long to be used as index in scatter_add\n            if token_ids.dtype != torch.int64:\n                token_ids = token_ids.to(torch.int64)\n\n        padded_token_ids = torch.nn.utils.rnn.pad_sequence(\n            sequences=token_ids,\n            batch_first=True,\n            padding_value=self.orchestrator.vocab_size,\n        )\n\n        self.cached_counts = torch.zeros(\n            size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),\n            dtype=torch.int64,\n            device=self.orchestrator.device,\n        ).scatter_add_(\n            dim=1,\n            index=padded_token_ids,\n            src=torch.ones_like(padded_token_ids),\n        )[\n            :, : self.orchestrator.vocab_size\n        ]\n\n        return self.cached_counts\n\n\nclass _BatchedPenalizer(abc.ABC):\n    \"\"\"\n    An abstract class for a batched penalizer.\n    \"\"\"\n\n    orchestrator: BatchedPenalizerOrchestrator\n    _is_prepared: bool = False\n\n    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):\n        self.orchestrator = orchestrator\n\n    def is_prepared(self) -> bool:\n        return self._is_prepared\n\n    def is_required(self) -> bool:\n        return self._is_required()\n\n    def prepare(self):\n        if not self.is_prepared():\n            self._prepare()\n            self._is_prepared = True\n\n    def prepare_if_required(self):\n        if self.is_required():\n            self.prepare()\n            return True\n        else:\n            return False\n\n    def teardown(self):\n        if self.is_prepared():\n            self._teardown()\n            
self._is_prepared = False\n\n    def cumulate_input_tokens(self, input_ids: _TokenIDs):\n        if not self.is_prepared():\n            return\n\n        self._cumulate_input_tokens(input_ids=input_ids)\n\n    def cumulate_output_tokens(self, output_ids: _TokenIDs):\n        if not self.is_prepared():\n            return\n\n        self._cumulate_output_tokens(output_ids=output_ids)\n\n    def apply(self, logits: torch.Tensor) -> torch.Tensor:\n        if not self.is_prepared():\n            return logits\n\n        return self._apply(logits=logits)\n\n    def filter(\n        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor\n    ):\n        if not self.is_prepared():\n            return\n\n        self._filter(\n            indices_to_keep=indices_to_keep,\n            indices_tensor_to_keep=indices_tensor_to_keep,\n        )\n\n    def merge(self, their: \"_BatchedPenalizer\"):\n        if not self.is_prepared() and not their.is_prepared():\n            return\n\n        self.prepare()\n        their.prepare()\n        self._merge(their)\n\n    @abc.abstractmethod\n    def _is_required(self) -> bool:\n        \"\"\"\n        Check if the penalizer is required to be prepared.\n        \"\"\"\n        pass\n\n    @abc.abstractmethod\n    def _prepare(self):\n        \"\"\"\n        Prepare the penalizer.\n        Usually, this is where the penalizer initializes its tensors.\n        \"\"\"\n        pass\n\n    @abc.abstractmethod\n    def _teardown(self):\n        \"\"\"\n        Tear down the penalizer.\n        Usually, this is where the penalizer frees its tensors.\n        \"\"\"\n        pass\n\n    @abc.abstractmethod\n    def _cumulate_input_tokens(self, input_ids: _TokenIDs):\n        \"\"\"\n        Cumulate the input tokens.\n        Orchestrator will call this function to feed the input tokens to the penalizer.\n        \"\"\"\n        pass\n\n    @abc.abstractmethod\n    def _cumulate_output_tokens(self, output_ids: 
_TokenIDs):\n        \"\"\"\n        Cumulate the output tokens.\n        Orchestrator will call this function to feed the output tokens to the penalizer.\n        \"\"\"\n        pass\n\n    @abc.abstractmethod\n    def _apply(self, logits: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Apply the penalizer to the logits.\n        Penalizers can modify the logits in-place if needed.\n        \"\"\"\n        pass\n\n    @abc.abstractmethod\n    def _filter(\n        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor\n    ):\n        \"\"\"\n        Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch.\n        \"\"\"\n        pass\n\n    @abc.abstractmethod\n    def _merge(self, their: \"_BatchedPenalizer\"):\n        \"\"\"\n        Merge the penalizer with another penalizer.\n        \"\"\"\n        pass\n"
  },
  {
    "path": "kt-sft/ktransformers/server/balance_serve/inference/sampling/penaltylib/penalizers/frequency_penalty.py",
    "content": "import typing\n\nimport torch\n\nfrom ..orchestrator import _BatchedPenalizer, _TokenIDs\n\n\nclass BatchedFrequencyPenalizer(_BatchedPenalizer):\n    \"\"\"\n    Frequency penalizer penalizes tokens based on their frequency in the output.\n    \"\"\"\n\n    frequency_penalties: torch.Tensor = None\n    cumulated_frequency_penalties: torch.Tensor = None\n\n    def _is_required(self) -> bool:\n        return any(\n            req.sampling_params.frequency_penalty != 0.0\n            for req in self.orchestrator.reqs()\n        )\n\n    def _prepare(self):\n        self.cumulated_frequency_penalties = (\n            torch.tensor(\n                data=[0.0 for _ in self.orchestrator.reqs()],\n                dtype=torch.float32,\n                device=self.orchestrator.device,\n            )\n            .unsqueeze_(1)\n            .repeat(1, self.orchestrator.vocab_size)\n        )\n\n        self.frequency_penalties = (\n            torch.tensor(\n                data=[\n                    req.sampling_params.frequency_penalty\n                    for req in self.orchestrator.reqs()\n                ],\n                dtype=torch.float32,\n                device=self.orchestrator.device,\n            )\n            .unsqueeze_(1)\n            .expand_as(self.cumulated_frequency_penalties)\n        )\n\n    def _teardown(self):\n        del self.frequency_penalties\n        del self.cumulated_frequency_penalties\n\n        self.frequency_penalties = None\n        self.cumulated_frequency_penalties = None\n\n    def _cumulate_input_tokens(self, input_ids: _TokenIDs):\n        pass\n\n    def _cumulate_output_tokens(self, output_ids: _TokenIDs):\n        self.cumulated_frequency_penalties += (\n            self.frequency_penalties * output_ids.occurrence_count()\n        )\n\n    def _apply(self, logits: torch.Tensor) -> torch.Tensor:\n        logits -= self.cumulated_frequency_penalties\n        return logits\n\n    def _filter(\n        self, 
indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor\n    ):\n        self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep]\n        self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[\n            indices_tensor_to_keep\n        ]\n\n    def _merge(self, their: \"BatchedFrequencyPenalizer\"):\n        self.frequency_penalties = torch.cat(\n            [self.frequency_penalties, their.frequency_penalties], dim=0\n        )\n        self.cumulated_frequency_penalties = torch.cat(\n            [self.cumulated_frequency_penalties, their.cumulated_frequency_penalties],\n            dim=0,\n        )\n"
  },
  {
    "path": "kt-sft/ktransformers/server/balance_serve/inference/sampling/penaltylib/penalizers/min_new_tokens.py",
    "content": "import typing\n\nimport torch\n\nfrom ..orchestrator import _BatchedPenalizer, _TokenIDs\n\n\nclass BatchedMinNewTokensPenalizer(_BatchedPenalizer):\n    \"\"\"\n    Min new tokens penalizer penalizes tokens based on the length of the output.\n    \"\"\"\n\n    min_new_tokens: torch.Tensor = None\n    stop_token_penalties: torch.Tensor = None\n    len_output_tokens: torch.Tensor = None\n\n    def _is_required(self) -> bool:\n        return any(\n            req.sampling_params.min_new_tokens > 0 for req in self.orchestrator.reqs()\n        )\n\n    def _prepare(self):\n        self.min_new_tokens = torch.tensor(\n            data=[\n                req.sampling_params.min_new_tokens for req in self.orchestrator.reqs()\n            ],\n            dtype=torch.int32,\n            device=self.orchestrator.device,\n        ).unsqueeze_(1)\n\n        padded_stop_token_ids = torch.nn.utils.rnn.pad_sequence(\n            sequences=[\n                torch.tensor(\n                    data=(\n                        list(\n                            (req.sampling_params.stop_token_ids or set())\n                            | (req.tokenizer.additional_stop_token_ids or set())\n                            | {req.tokenizer.eos_token_id}\n                        )\n                    ),\n                    dtype=torch.int64,\n                    device=self.orchestrator.device,\n                )\n                for req in self.orchestrator.reqs()\n            ],\n            batch_first=True,\n            padding_value=self.orchestrator.vocab_size,\n        )\n        self.stop_token_penalties = torch.zeros(\n            size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),\n            dtype=torch.float32,\n            device=self.orchestrator.device,\n        ).scatter_add_(\n            dim=1,\n            index=padded_stop_token_ids,\n            src=torch.full_like(\n                input=padded_stop_token_ids,\n                
dtype=torch.float32,\n                fill_value=float(\"-inf\"),\n                device=self.orchestrator.device,\n            ),\n        )[\n            :, : self.orchestrator.vocab_size\n        ]\n\n        self.len_output_tokens = torch.zeros(\n            size=(self.orchestrator.batch_size(), 1),\n            dtype=torch.int32,\n            device=self.orchestrator.device,\n        )\n\n    def _teardown(self):\n        del self.min_new_tokens\n        del self.stop_token_penalties\n        del self.len_output_tokens\n\n        self.min_new_tokens = None\n        self.stop_token_penalties = None\n        self.len_output_tokens = None\n\n    def _cumulate_input_tokens(self, input_ids: _TokenIDs):\n        pass\n\n    def _cumulate_output_tokens(self, output_ids: _TokenIDs):\n        self.len_output_tokens += 1\n\n    def _apply(self, logits: torch.Tensor) -> torch.Tensor:\n        mask = (self.len_output_tokens < self.min_new_tokens).expand_as(logits)\n        logits[mask] += self.stop_token_penalties[mask]\n        return logits\n\n    def _filter(\n        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor\n    ):\n        self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep]\n        self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep]\n        self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep]\n\n    def _merge(self, their: \"BatchedMinNewTokensPenalizer\"):\n        self.min_new_tokens = torch.cat(\n            [self.min_new_tokens, their.min_new_tokens], dim=0\n        )\n        self.stop_token_penalties = torch.cat(\n            [self.stop_token_penalties, their.stop_token_penalties], dim=0\n        )\n        self.len_output_tokens = torch.cat(\n            [self.len_output_tokens, their.len_output_tokens], dim=0\n        )\n"
  },
  {
    "path": "kt-sft/ktransformers/server/balance_serve/inference/sampling/penaltylib/penalizers/presence_penalty.py",
    "content": "import typing\n\nimport torch\n\nfrom ..orchestrator import _BatchedPenalizer, _TokenIDs\n\n\nclass BatchedPresencePenalizer(_BatchedPenalizer):\n    \"\"\"\n    Presence penalizer penalizes tokens based on their presence in the output.\n    \"\"\"\n\n    presence_penalties: torch.Tensor = None\n    cumulated_presence_penalties: torch.Tensor = None\n\n    def _is_required(self) -> bool:\n        return any(\n            req.sampling_params.presence_penalty != 0.0\n            for req in self.orchestrator.reqs()\n        )\n\n    def _prepare(self):\n        self.cumulated_presence_penalties = (\n            torch.tensor(\n                data=[0.0 for _ in self.orchestrator.reqs()],\n                dtype=torch.float32,\n                device=self.orchestrator.device,\n            )\n            .unsqueeze_(1)\n            .repeat(1, self.orchestrator.vocab_size)\n        )\n\n        self.presence_penalties = (\n            torch.tensor(\n                data=[\n                    req.sampling_params.presence_penalty\n                    for req in self.orchestrator.reqs()\n                ],\n                dtype=torch.float32,\n                device=self.orchestrator.device,\n            )\n            .unsqueeze_(1)\n            .expand_as(self.cumulated_presence_penalties)\n        )\n\n    def _teardown(self):\n        del self.presence_penalties\n        del self.cumulated_presence_penalties\n\n        self.presence_penalties = None\n        self.cumulated_presence_penalties = None\n\n    def _cumulate_input_tokens(self, input_ids: _TokenIDs):\n        pass\n\n    def _cumulate_output_tokens(self, output_ids: _TokenIDs):\n        mask = output_ids.occurrence_count() > 0\n        self.cumulated_presence_penalties[mask] = self.presence_penalties[mask]\n\n    def _apply(self, logits: torch.Tensor) -> torch.Tensor:\n        logits -= self.cumulated_presence_penalties\n        return logits\n\n    def _filter(\n        self, indices_to_keep: 
typing.List[int], indices_tensor_to_keep: torch.Tensor\n    ):\n        self.presence_penalties = self.presence_penalties[indices_tensor_to_keep]\n        self.cumulated_presence_penalties = self.cumulated_presence_penalties[\n            indices_tensor_to_keep\n        ]\n\n    def _merge(self, their: \"BatchedPresencePenalizer\"):\n        self.presence_penalties = torch.cat(\n            [self.presence_penalties, their.presence_penalties], dim=0\n        )\n        self.cumulated_presence_penalties = torch.cat(\n            [self.cumulated_presence_penalties, their.cumulated_presence_penalties],\n            dim=0,\n        )\n"
  },
  {
    "path": "kt-sft/ktransformers/server/balance_serve/inference/sampling/penaltylib/penalizers/repetition_penalty.py",
    "content": "import typing\n\nimport torch\n\nfrom ..orchestrator import _BatchedPenalizer, _TokenIDs\n\n\nclass BatchedRepetitionPenalizer(_BatchedPenalizer):\n    \"\"\"\n    Repetition penalizer penalizes tokens based on their repetition in the input and output.\n    \"\"\"\n\n    repetition_penalties: torch.Tensor = None\n    cumulated_repetition_penalties: torch.Tensor = None\n\n    def _is_required(self) -> bool:\n        return any(\n            req.sampling_params.repetition_penalty != 1.0\n            for req in self.orchestrator.reqs()\n        )\n\n    def _prepare(self):\n        self.cumulated_repetition_penalties = (\n            torch.tensor(\n                data=[1.0 for _ in self.orchestrator.reqs()],\n                dtype=torch.float32,\n                device=self.orchestrator.device,\n            )\n            .unsqueeze_(1)\n            .repeat(1, self.orchestrator.vocab_size)\n        )\n\n        self.repetition_penalties = (\n            torch.tensor(\n                data=[\n                    req.sampling_params.repetition_penalty\n                    for req in self.orchestrator.reqs()\n                ],\n                dtype=torch.float32,\n                device=self.orchestrator.device,\n            )\n            .unsqueeze_(1)\n            .expand_as(self.cumulated_repetition_penalties)\n        )\n\n    def _teardown(self):\n        del self.repetition_penalties\n        del self.cumulated_repetition_penalties\n\n        self.repetition_penalties = None\n        self.cumulated_repetition_penalties = None\n\n    def _cumulate_input_tokens(self, input_ids: _TokenIDs):\n        mask = input_ids.occurrence_count() > 0\n        self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]\n\n    def _cumulate_output_tokens(self, output_ids: _TokenIDs):\n        mask = output_ids.occurrence_count() > 0\n        self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]\n\n    def _apply(self, 
logits: torch.Tensor) -> torch.Tensor:\n        return torch.where(\n            logits > 0,\n            logits / self.cumulated_repetition_penalties,\n            logits * self.cumulated_repetition_penalties,\n        )\n\n    def _filter(\n        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor\n    ):\n        self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]\n        self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[\n            indices_tensor_to_keep\n        ]\n\n    def _merge(self, their: \"BatchedRepetitionPenalizer\"):\n        self.repetition_penalties = torch.cat(\n            [self.repetition_penalties, their.repetition_penalties], dim=0\n        )\n        self.cumulated_repetition_penalties = torch.cat(\n            [self.cumulated_repetition_penalties, their.cumulated_repetition_penalties],\n            dim=0,\n        )\n"
  },
  {
    "path": "kt-sft/ktransformers/server/balance_serve/inference/sampling/sampler.py",
    "content": "'''\nDate: 2024-11-14 12:23:45\nLastEditors: Xie Weiyu ervinxie@qq.com\nLastEditTime: 2024-11-25 08:59:23\n'''\nimport logging\nimport torch\nfrom torch import nn\nfrom transformers import GenerationConfig\n\nfrom flashinfer.sampling import (\n\tmin_p_sampling_from_probs,\n\ttop_k_renorm_probs,\n\ttop_k_top_p_sampling_from_logits,\n\ttop_p_renorm_probs,\n)\n\nlogger = logging.getLogger(__name__)\n\nclass SamplingOptions():\n\t# Batched sampling params\n\ttemperatures: torch.Tensor\n\ttop_ps: torch.Tensor\n\ttop_ks: torch.Tensor\n\tmin_ps: torch.Tensor\n\n\t# All requests use greedy sampling\n\tis_all_greedy: bool\n\n\t# Dispatch in CUDA graph\n\tneed_min_p_sampling: bool\n\t\n\tdef __init__(self, bsz = 1, device = torch.device('cuda'), pretrained_config:GenerationConfig = None, temperatures: torch.Tensor = None, top_ps: torch.Tensor = None):\n\t\tif pretrained_config is None and temperatures is None:\n\t\t\tself.temperatures = torch.full((bsz, 1), 0, device=device, dtype=torch.float32)\n\t\t\tself.top_ps = torch.ones((bsz, 1), device=device, dtype=torch.float32)\n\t\t\tself.top_ks = torch.ones((bsz, 1), device=device, dtype=torch.float32)\n\t\t\tself.need_min_p_sampling = False\n\t\t\tself.is_all_greedy = True\n\t\telse:\n\t\t\tif temperatures is not None:\n\t\t\t\tself.temperatures = temperatures.unsqueeze(-1)\n\t\t\telse:\n\t\t\t\tself.temperatures = torch.full((bsz, 1), pretrained_config.temperature, device=device, dtype=torch.float32)\n\t\t\t\n\t\t\tif top_ps is not None:\n\t\t\t\tself.top_ps = top_ps.unsqueeze(-1)\n\t\t\telse:\t\n\t\t\t\tself.top_ps = torch.full((bsz, 1), pretrained_config.top_p, device=device, dtype=torch.float32)\n\t\t\tself.top_ks = torch.full((bsz, 1), pretrained_config.top_k, device=device, dtype=torch.float32)\n\t\t\tself.need_min_p_sampling = False\n\t\t\tself.is_all_greedy = False\n\nclass Sampler(nn.Module):\n\tdef __init__(self):\n\t\tsuper().__init__()\n\t\n\tdef forward(\n\t\tself,\n\t\tlogits: 
torch.Tensor,\n\t\tsampling_config: SamplingOptions = None,\n\t):\n\t\tif sampling_config is None:\n\t\t\tsampling_config = SamplingOptions()\n\n\t\tlogits = logits.contiguous()\n\t\torigin_logits = logits.clone()\n\t\tif sampling_config.is_all_greedy:\n\t\t\t# Use torch.argmax if all requests use greedy sampling\n\t\t\tprobs = logits\n\t\t\tbatch_next_token_ids = torch.argmax(logits, -1)\n\t\telse:\n\t\t\t# Post process logits\n\t\t\tlogits.div_(sampling_config.temperatures)\n\t\t\tmax_top_k_round, batch_size = 32, logits.shape[0]\n\t\t\tif sampling_config.need_min_p_sampling:\n\t\t\t\tprobs = torch.softmax(logits, dim=-1)\n\t\t\t\tlogits = None\n\t\t\t\tdel logits\n\t\t\t\tprobs = top_k_renorm_probs(probs, sampling_config.top_ks)\n\t\t\t\tprobs = top_p_renorm_probs(probs, sampling_config.top_ps)\n\t\t\t\tbatch_next_token_ids = min_p_sampling_from_probs(\n\t\t\t\t\tprobs, sampling_config.min_ps\n\t\t\t\t)\n\t\t\t\ttemperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]\n\t\t\t\tbatch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)\n\t\t\telse:\n\t\t\t\t# TODO: use different kernel when don't need top_k or top_p\n\t\t\t\t# @TODO get probs\n\t\t\t\tprobs = logits\n\t\t\t\tbatch_next_token_ids = top_k_top_p_sampling_from_logits(\n\t\t\t\t\tlogits,\n\t\t\t\t\tsampling_config.top_ks,\n\t\t\t\t\tsampling_config.top_ps,\n\t\t\t\t\tfilter_apply_order=\"joint\",\n\t\t\t\t)\n\t\t\t\ttemperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]\n\t\t\t\tbatch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)\n\t\t\t\n\t\treturn batch_next_token_ids.to(torch.int32), probs"
  },
  {
    "path": "kt-sft/ktransformers/server/balance_serve/sched_rpc.py",
    "content": "from datetime import datetime\nimport os\nfrom typing import Optional\nimport zmq\nimport pickle\nimport threading\nimport torch.multiprocessing as mp\nimport sys\ncurrent_file_path = os.path.abspath(__file__)\n# sys.path.insert(0, os.path.join(os.path.dirname(__file__), \"..\", \"..\", \"..\"))\nimport pickle\nimport argparse\nfrom ktransformers.server.balance_serve.settings import sched_ext, create_sched_settings, create_sched_settings_qwen2moe, create_sched_settings_qwen3moe\n\n\n\nif mp.get_start_method(allow_none=True) is None:\n    print('set start method')\n    mp.set_start_method('spawn')\nelse:\n    print(f'start method already set to {mp.get_start_method(allow_none=True)}')\n\n\nclass SchedulerServer:\n    def __init__(self, settings, main_args):\n        self.sched = sched_ext.create_scheduler(settings)\n    \n        self.context = zmq.Context()\n        self.frontend = self.context.socket(zmq.ROUTER)\n        print(f\"sched zmq rpc server on port {main_args.sched_port}\")\n        self.frontend.bind(f\"tcp://*:{main_args.sched_port}\") \n\n        self.backend = self.context.socket(zmq.DEALER)\n        self.backend.bind(\"inproc://backend\")\n\n    def run_scheduler(self):\n        self.sched.run()\n\n    def stop_scheduler(self):\n        self.sched.stop()\n\n    def start_proxy(self):\n        zmq.proxy(self.frontend, self.backend)\n\n    def worker_routine(self):\n        worker = self.context.socket(zmq.REP)\n        worker.connect(\"inproc://backend\")\n        while True:\n            try:\n                message = worker.recv()\n                data = pickle.loads(message)\n\n                method = data.get('method')\n                params = data.get('params', {})\n                # print(f\"Received request: {method}\")\n\n                if method == 'add_query':\n                    query_add = params.get('query')\n                    query_id = self.sched.add_query(query_add)\n                    response = {'status': 
'ok', 'query_id': query_id}\n                    worker.send(pickle.dumps(response))\n\n                elif method == 'cancel_query':\n                    query_id = params.get('query_id')\n                    self.sched.cancel(query_id)\n                    response = {'status': 'ok'}\n                    worker.send(pickle.dumps(response))\n\n                elif method == 'update_last_batch':\n                    updates = params.get('updates')\n\n                    batch_todo = self.sched.update_last_batch(updates)\n\n                    response = {'status': 'ok', 'batch_todo': batch_todo}\n                    # print (batch_todo.query_lengths, batch_todo.query_ids)\n                    worker.send(pickle.dumps(response))\n\n                elif method == 'get_inference_context':\n                    inference_context = self.sched.get_inference_context()\n                    data = {\n                        \"k_cache\":inference_context.k_cache,\n                        \"v_cache\":inference_context.v_cache\n                    }\n                    print(f\"Serializing KVCache\")\n                    data[\"k_cache\"] = [mp.reductions.reduce_tensor(t) for t in data['k_cache']]\n                    data[\"v_cache\"] = [mp.reductions.reduce_tensor(t) for t in data['v_cache']]\n                    # print(data)\n                    response = {'status': 'ok', 'inference_context': data}\n\n                    worker.send(pickle.dumps(response))\n                    # response['inference_context'].k_cache[0][0, 0, 0, 0, 0] = 1 \n                    # print(\"k_cache update\")\n\n                else:\n                    response = {'status': 'error', 'message': 'Unknown method'}\n                    worker.send(pickle.dumps(response))\n\n            except Exception as e:\n                response = {'status': 'error', 'message': str(e)}\n                worker.send(pickle.dumps(response))\n\n    def start_rpc_service(self):\n        try:\n            
print(\"Scheduler RPC service is running...\")\n\n            threading.Thread(target=self.run_scheduler, daemon=True).start()\n\n            for _ in range(10):\n                threading.Thread(target=self.worker_routine, daemon=True).start()\n\n            self.start_proxy()\n\n        except KeyboardInterrupt:\n            print(\"Shutting down scheduler RPC service...\")\n            self.stop_rpc_service()\n\n    def stop_rpc_service(self):\n        self.stop_scheduler()\n        self.frontend.close()\n        self.backend.close()\n        self.context.term()\n\ndef start_server(settings, main_args):\n    server = SchedulerServer(settings, main_args)\n    server.start_rpc_service()\n\n\n# Add async client for webserver\nclass SchedulerClient:\n    def __init__(self, sched_port):\n        address=f'tcp://localhost:{sched_port}'\n        self.address = address\n        self.context = zmq.Context()\n        self.socket = self.context.socket(zmq.REQ)\n        self.socket.connect(self.address)\n        print(f\"Connected to server at {self.address}\")\n    \n    def __del__(self):\n        self.socket.close()\n        self.context.term()\n    \n    def send_request(self, method, params=None):\n        if params is None:\n            params = {}\n        request = {\n            'method': method,\n            'params': params\n        }\n        # print(f'send request {request}')\n        self.socket.send(pickle.dumps(request))\n        response = self.socket.recv()\n        # print(response)\n        response = pickle.loads(response)\n        if response.get('status') == 'ok':\n            return response\n        else:\n            raise Exception(f\"Error from server: {response.get('message')}\")\n    \n    def add_query(self, query):\n        response = self.send_request('add_query', {'query': query})\n        return response.get('query_id')\n    \n    def cancel_query(self, query_id):\n        self.send_request('cancel_query', {'query_id': query_id})\n    \n   
 def update_last_batch(self, updates):\n        response = self.send_request('update_last_batch', {'updates': updates})\n        # print(f\"update_last_batch response {response}\")\n        return response.get('batch_todo')\n    \n    def rebuild_inferece_context(self,response):\n        data = response.get('inference_context')\n        inference_context = sched_ext.InferenceContext()\n        print('Rebuilding kvcache')\n        inference_context.k_cache = [fn(*args) for fn,args in data['k_cache']]\n        inference_context.v_cache = [fn(*args) for fn,args in data['v_cache']]\n        return inference_context\n\n    def get_inference_context_raw(self):\n        response = self.send_request('get_inference_context')\n        return response\n       \n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--config\", type=str, required=True)\n    args = parser.parse_args()\n    with open(args.config, \"rb\") as f:\n        main_args = pickle.load(f)\n    if main_args.architectures == \"Qwen2MoeForCausalLM\": \n        settings = create_sched_settings_qwen2moe(main_args)\n    elif main_args.architectures == \"Qwen3MoeForCausalLM\":\n        settings = create_sched_settings_qwen3moe(main_args)\n    else:\n        settings = create_sched_settings(main_args)\n    start_server(settings, main_args)\n"
  },
  {
    "path": "kt-sft/ktransformers/server/balance_serve/settings.py",
    "content": "'''\nDate: 2024-11-13 09:43:39\nLastEditors: djw\nLastEditTime: 2024-11-18 16:41:03\n'''\nimport sys, os\nimport yaml, json\nfrom time import sleep\n\n\nimport sched_ext\nfrom transformers import AutoConfig\n\nfrom ktransformers.models.configuration_qwen3_moe import Qwen3MoeConfig\n\ndef create_sched_settings(args):\n    default_sample_options = sched_ext.SampleOptions()\n    model_name = os.path.basename(os.path.normpath(args.model_dir))\n    input_model_settings = sched_ext.ModelSettings()\n    input_model_settings.model_path = args.model_dir\n    input_model_settings.params_count = int(0)\n    model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)\n    input_model_settings.layer_count = model_config.num_hidden_layers\n    input_model_settings.num_k_heads = 1 # model_config[\"num_key_value_heads\"]\n    input_model_settings.k_head_dim = 576\n    input_model_settings.bytes_per_params = 2\n    input_model_settings.bytes_per_kv_cache_element = 2\n    settings = sched_ext.Settings()\n    settings.model_name = model_name\n    settings.quant_type = \"BF16\"\n    settings.model_settings = input_model_settings\n    settings.page_size = args.page_size\n    settings.gpu_device_count = 1 # tp\n    settings.gpu_device_id = [i for i in range(settings.gpu_device_count)]\n    # settings.gpu_memory_size = args.cache_lens*576*2\n    settings.gpu_memory_size = args.gpu_memory_size\n    settings.memory_utilization_percentage = args.utilization_percentage\n    max_batch_size = args.max_batch_size\n    chunk_size = args.chunk_size\n\n    max_decode_batch_size = max_batch_size - 2\n\n    settings.max_batch_size = max_batch_size\n    settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2\n    settings.sample_options = default_sample_options\n    settings.sched_metrics_port = args.sched_metrics_port\n    settings.gpu_only = args.memory_gpu_only\n    settings.use_self_defined_head_dim = True\n    
settings.self_defined_head_dim = 576\n    settings.full_kv_cache_on_each_gpu = True\n    settings.k_cache_on = True\n    settings.v_cache_on = False\n\n    settings.kvc2_root_path = '/mnt/data/persist-kvc'\n    settings.kvc2_config_path = args.kvc2_config_dir\n    settings.memory_pool_size_GB = args.cpu_memory_size_GB\n    settings.evict_count = 40\n    settings.kvc2_metrics_port = args.kvc2_metrics_port\n    settings.load_from_disk = False\n    settings.save_to_disk = True\n\n\n    settings.strategy_name = args.sched_strategy\n\n    settings.auto_derive()\n    return settings\n\n\ndef create_sched_settings_qwen2moe(args):\n    default_sample_options = sched_ext.SampleOptions()\n    model_name = os.path.basename(os.path.normpath(args.model_dir))\n    input_model_settings = sched_ext.ModelSettings()\n    input_model_settings.model_path = args.model_dir\n    input_model_settings.params_count = int(0)\n    model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)\n    input_model_settings.layer_count = model_config.num_hidden_layers\n    input_model_settings.num_k_heads = model_config.num_key_value_heads # model_config[\"num_key_value_heads\"]\n    input_model_settings.k_head_dim = 128\n    input_model_settings.bytes_per_params = 2\n    input_model_settings.bytes_per_kv_cache_element = 2\n    settings = sched_ext.Settings()\n    settings.model_name = model_name\n    settings.quant_type = \"BF16\"\n    settings.model_settings = input_model_settings\n    settings.page_size = args.page_size\n    settings.gpu_device_count = 1 # tp\n    settings.gpu_device_id = [i for i in range(settings.gpu_device_count)]\n    # settings.gpu_memory_size = args.cache_lens*576*2\n    settings.gpu_memory_size = args.gpu_memory_size\n    settings.memory_utilization_percentage = args.utilization_percentage\n    max_batch_size = args.max_batch_size\n    chunk_size = args.chunk_size\n\n    max_decode_batch_size = max_batch_size - 2\n\n    settings.max_batch_size = 
max_batch_size\n    settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2\n    settings.sample_options = default_sample_options\n    settings.sched_metrics_port = args.sched_metrics_port\n    settings.gpu_only = args.memory_gpu_only\n    settings.use_self_defined_head_dim = False\n    settings.self_defined_head_dim = 576\n    settings.full_kv_cache_on_each_gpu = True\n    settings.k_cache_on = True\n    settings.v_cache_on = True\n\n    settings.kvc2_root_path = '/mnt/data/persist-kvc'\n    settings.kvc2_config_path = args.kvc2_config_dir\n    settings.memory_pool_size_GB = args.cpu_memory_size_GB\n    settings.evict_count = 40\n    settings.kvc2_metrics_port = args.kvc2_metrics_port\n    settings.load_from_disk = False\n    settings.save_to_disk = True\n\n\n    settings.strategy_name = args.sched_strategy\n\n    settings.auto_derive()\n    return settings\n\n\n\ndef create_sched_settings_qwen3moe(args):\n    default_sample_options = sched_ext.SampleOptions()\n    model_name = os.path.basename(os.path.normpath(args.model_dir))\n    input_model_settings = sched_ext.ModelSettings()\n    input_model_settings.model_path = args.model_dir\n    input_model_settings.params_count = int(0)\n    model_config = Qwen3MoeConfig.from_pretrained(args.model_dir, trust_remote_code=True)\n    input_model_settings.layer_count = model_config.num_hidden_layers\n    input_model_settings.num_k_heads = model_config.num_key_value_heads # model_config[\"num_key_value_heads\"]\n    input_model_settings.k_head_dim = 128\n    input_model_settings.bytes_per_params = 2\n    input_model_settings.bytes_per_kv_cache_element = 2\n    settings = sched_ext.Settings()\n    settings.model_name = model_name\n    settings.quant_type = \"BF16\"\n    settings.model_settings = input_model_settings\n    settings.page_size = args.page_size\n    settings.gpu_device_count = 1 # tp\n    settings.gpu_device_id = [i for i in range(settings.gpu_device_count)]\n    # 
settings.gpu_memory_size = args.cache_lens*576*2\n    settings.gpu_memory_size = args.gpu_memory_size\n    settings.memory_utilization_percentage = args.utilization_percentage\n    max_batch_size = args.max_batch_size\n    chunk_size = args.chunk_size\n\n    max_decode_batch_size = max_batch_size - 2\n\n    settings.max_batch_size = max_batch_size\n    settings.recommended_chunk_prefill_token_count = (chunk_size - max_decode_batch_size) // 2\n    settings.sample_options = default_sample_options\n    settings.sched_metrics_port = args.sched_metrics_port\n    settings.gpu_only = args.memory_gpu_only\n    settings.use_self_defined_head_dim = False\n    settings.self_defined_head_dim = 576\n    settings.full_kv_cache_on_each_gpu = True\n    settings.k_cache_on = True\n    settings.v_cache_on = True\n\n    settings.kvc2_root_path = '/mnt/data/persist-kvc'\n    settings.kvc2_config_path = args.kvc2_config_dir\n    settings.memory_pool_size_GB = args.cpu_memory_size_GB\n    settings.evict_count = 40\n    settings.kvc2_metrics_port = args.kvc2_metrics_port\n    settings.load_from_disk = False\n    settings.save_to_disk = True\n\n\n    settings.strategy_name = args.sched_strategy\n\n    settings.auto_derive()\n    return settings\n\n\n\n\n\n\n"
  },
  {
    "path": "kt-sft/ktransformers/server/config/config.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n\"\"\"\nDescription  :\nAuthor       : unicornchan\nDate         : 2024-06-11 16:35:42\nVersion      : 1.0.0\nLastEditors  : WuHao\nLastEditTime : 2024-08-12 06:31:14\n\"\"\"\nimport os\nimport shutil\nimport yaml\nimport psutil\n\nfrom ktransformers.server.config.singleton import Singleton\nfrom typing import Optional\n\n\nclass Config(metaclass=Singleton):\n    \"\"\"Singleton pattern Config class, used to get all configurations.\"\"\"\n\n    CONFIG_FILE_NAME = \"config.yaml\"\n\n    @staticmethod\n    def load() -> dict:\n        \"\"\"load config file\n\n        Returns:\n            dict: all configs\n        \"\"\"\n        base_path: str = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))\n        config_yaml: str = os.path.join(base_path, \"configs\", Config.CONFIG_FILE_NAME)\n\n        user_path: str = os.path.expanduser(\"~\")\n        localstore_path: str = os.path.join(user_path, \".ktransformers\")\n        kvc2_config_dir = os.path.join(localstore_path, \"kvc2\")\n        config_path: str = os.path.join(localstore_path, Config.CONFIG_FILE_NAME)\n        if not os.path.exists(config_yaml):\n            print(f\"Can't find config file, {config_yaml}\")\n            exit(-1)\n        if not os.path.exists(localstore_path):\n            os.mkdir(localstore_path)\n        if not os.path.exists(kvc2_config_dir):\n            os.mkdir(kvc2_config_dir)\n        if not os.path.exists(config_path):\n            shutil.copyfile(config_yaml, config_path)\n        with open(config_path, \"r\", encoding=\"utf-8\") as fp:\n            config = yaml.safe_load(fp)\n        return config\n\n    @staticmethod\n    def to_path(path: str) -> str:\n        \"\"\"\n        process file path\n        \"\"\"\n        base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))\n        real_path = path if os.path.isabs(path) else os.path.join(base_path, path)\n        return real_path\n\n    def 
__init__(self):\n        cfg = Config.load()\n        self.base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))\n        self.user_path: str = os.path.expanduser(\"~\")\n        self.localstore_path: str = os.path.join(self.user_path, \".ktransformers\")\n        # log configs\n        self.log_dir = os.path.join(self.localstore_path, cfg[\"log\"][\"dir\"])\n        if not os.path.exists(self.log_dir):\n            os.mkdir(self.log_dir)\n        self.log_file = cfg[\"log\"][\"file\"]\n        self.log_level = cfg[\"log\"][\"level\"]\n        self.backup_count = cfg[\"log\"][\"backup_count\"]\n\n        self.kvc2_config_dir = os.path.join(self.localstore_path, \"kvc2\")\n        # server configs\n        self.server: dict = cfg.get(\"server\", {})\n        self.server_ip = self.server.get(\"ip\", \"0.0.0.0\")\n        self.server_port = self.server.get(\"port\", 9016)\n        self.api_key = self.server.get(\"api_key\", \"\")\n\n        # db configs\n        self.db_configs: dict = cfg.get(\"db\", {})\n        self.db_type = self.db_configs.get(\"type\", \"\")\n        self.db_host = self.localstore_path\n        self.db_port = self.db_configs.get(\"port\", \"\")\n        self.db_name = self.db_configs.get(\"database\", \"\")\n        self.db_pool_size = self.db_configs.get(\"pool_size\")\n        self.db_database = self.db_configs.get(\"database\", \"\")\n\n        # user config\n        self.user_config: dict = cfg.get(\"user\", {})\n        self.user_secret_key = self.user_config.get(\"secret_key\", \"\")\n        self.user_algorithm = self.user_config.get(\"algorithm\", \"\")\n        self.user_force_think = self.user_config.get(\"force_think\", False)\n\n        # model config\n        self.model: dict = cfg.get(\"model\", {})\n        self.backend_type: str = self.model.get(\"type\", \"transformers\")\n        self.model_dir: str = self.model.get(\"path\", \"\")\n        # to make sure it consistent with previous version\n        
self.model_path: str = self.model_dir\n        self.model_name: str = self.model.get(\"name\", \"\")\n        self.architectures: str = self.model.get(\"name\", \"\")\n        self.model_device: str = self.model.get(\"device\", \"cuda:0\")\n        self.gguf_path: Optional[str] = self.model.get(\"gguf_path\", None)\n        self.use_cuda_graph = self.model.get(\"use_cuda_graph\", True)\n        self.trust_remote_code = self.model.get(\"trust_remote_code\", True)\n        # self.model_cache_lens = self.model.get(\"cache_lens\")\n        self.optimize_config_path: Optional[str] = self.model.get(\n            \"optimize_config_path\", None\n        )\n        \n        self.max_new_tokens = self.model.get(\"max_new_tokens\", 2000)\n        self.json_mode = self.model.get(\"json_mode\", False)\n        self.healing = self.model.get(\"healing\", False)\n        self.ban_strings: Optional[list] = self.model.get(\"ban_strings\", None)\n        self.gpu_split: Optional[str] = self.model.get(\"gpu_split\", None)\n        self.length: Optional[int] = self.model.get(\"length\", None)\n        self.rope_scale: Optional[float] = self.model.get(\"rope_scale\", None)\n        self.rope_alpha: Optional[float] = self.model.get(\"rope_alpha\", None)\n        self.no_flash_attn = self.model.get(\"no_flash_attn\", False)\n        self.low_mem = self.model.get(\"low_mem\", False)\n        self.experts_per_token: Optional[int] = self.model.get(\"experts_per_token\", None)\n        self.load_q4 = self.model.get(\"load_q4\", False)\n        self.fast_safetensors = self.model.get(\"fast_safetensors\", False)\n        self.draft_model_dir: Optional[str] = self.model.get(\"draft_model_dir\", None)\n        self.no_draft_scale = self.model.get(\"no_draft_scale\", False)\n        self.modes = self.model.get(\"modes\", False)\n        self.mode = self.model.get(\"mode\", \"llama\")\n        self.username = self.model.get(\"username\", \"User\")\n        self.botname = 
self.model.get(\"botname\", \"Chatbort\")\n        self.system_prompt: Optional[str] = self.model.get(\"system_prompt\", None)\n        self.temperature = self.model.get(\"temperature\", 0.95)\n        self.smoothing_factor = self.model.get(\"smoothing_factor\", 0.0)\n        self.dynamic_temperature: Optional[str] = self.model.get(\"dynamic_temperature\", None)\n        self.top_k = self.model.get(\"top_k\", 50)\n        self.top_p = self.model.get(\"top_p\", 0.8)\n        self.top_a = self.model.get(\"top_a\", 0.0)\n        self.skew = self.model.get(\"skew\", 0.0)\n        self.typical = self.model.get(\"typical\", 0.0)\n        self.repetition_penalty = self.model.get(\"repetition_penalty\", 1.01)\n        self.frequency_penalty = self.model.get(\"frequency_penalty\", 0.0)\n        self.presence_penalty = self.model.get(\"presence_penalty\", 0.0)\n        self.response_chunk = self.model.get(\"response_chunk\", 250)\n        self.no_code_formatting = self.model.get(\"no_code_formatting\", False)\n        self.cache_8bit = self.model.get(\"cache_8bit\", False)\n        self.cache_q4 = self.model.get(\"cache_q4\", True)\n        self.ngram_decoding = self.model.get(\"ngram_decoding\", False)\n        self.print_timings = self.model.get(\"print_timings\", False)\n        self.amnesia = self.model.get(\"amnesia\", False)\n        self.batch_size = self.model.get(\"batch_size\", 1)\n        self.cache_lens = self.model.get(\"cache_lens\", 4096)\n        self.device = self.model.get(\"device\", \"cuda:2\")\n\n        # web config\n        self.web: dict = cfg.get(\"web\", {})\n        self.web_cross_domain: bool = self.web.get(\"open_cross_domain\", True)\n        self.mount_web: bool = self.web.get(\"mount\", False)\n\n        # ext\n        self.ext: dict = cfg.get(\"ext\", {})\n        self.cpu_infer = psutil.cpu_count(logical=False) - 3\n\n        # file config\n        self.local_store_configs: dict = cfg.get(\"local_store\", {})\n        self.file_upload_dir: 
str = os.path.join(\n            self.localstore_path, self.local_store_configs.get(\"file_upload_dir\", \"\")\n        )\n        self.assistant_store_dir: str = os.path.join(\n            self.localstore_path, self.local_store_configs.get(\"assistant_store_dir\", \"\")\n        )\n\n        # long context config\n        self.long_context_config: dict = cfg.get(\"long_context\", {})\n        self.max_seq_len = self.long_context_config.get(\"max_seq_len\", 32000)\n        self.block_size = self.long_context_config.get(\"block_size\", 128)\n        self.local_windows_len = self.long_context_config.get(\"local_windows_len\", 4096)\n        self.second_select_num = self.long_context_config.get(\"second_select_num\", 32)\n        self.anchor_type = self.long_context_config.get(\"anchor_type\", \"DYNAMIC\")\n        self.kv_type = self.long_context_config.get(\"kv_type\", \"FP16\")\n        self.dense_layer_num = self.long_context_config.get(\"dense_layer_num\", 2)\n        self.anchor_num = self.long_context_config.get(\"anchor_num\", 1)\n        self.preselect_block = self.long_context_config.get(\"preselect_block\", True)\n        self.head_select_mode = self.long_context_config.get(\"head_select_mode\", \"SHARED\")\n        self.preselect_block_count = self.long_context_config.get(\"preselect_block_count\", 32)\n        self.layer_step = self.long_context_config.get(\"layer_step\", 1)\n        self.token_step = self.long_context_config.get(\"token_step\", 100)\n\n        # local chat\n        self.local_chat_config: dict = cfg.get(\"local_chat\", {})\n        self.prompt_file = self.local_chat_config.get(\"prompt_file\", None)\n\n        # asyncserver\n        self.sched_strategy = cfg[\"async_server\"][\"sched_strategy\"]\n        self.sched_port = cfg[\"async_server\"][\"sched_port\"]\n        self.sched_metrics_port = cfg[\"async_server\"][\"sched_metrics_port\"]\n        self.kvc2_metrics_port = cfg[\"async_server\"][\"kvc2_metrics_port\"]\n        
self.max_batch_size = cfg[\"async_server\"][\"max_batch_size\"]\n        self.page_size = cfg[\"attn\"][\"page_size\"]\n        self.chunk_size = cfg[\"attn\"][\"chunk_size\"]\n        self.memory_gpu_only = cfg[\"kvc2\"][\"gpu_only\"]\n        self.cache_lens = ((self.cache_lens + self.page_size - 1) // self.page_size) * self.page_size\n        self.gpu_memory_size = 2*576*61*self.cache_lens\n        self.utilization_percentage = 1.0 #cfg[\"kvc2\"][\"utilization_percentage\"]\n        self.cpu_memory_size_GB = cfg[\"kvc2\"][\"cpu_memory_size_GB\"]\n        # only support 2 prefill task\n        self.max_prefill_batch_size = 2\n        self.max_decode_batch_size = self.max_batch_size - self.max_prefill_batch_size \n\n"
  },
  {
    "path": "kt-sft/ktransformers/server/config/log.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : unicornchan\nDate         : 2024-06-12 02:48:39\nVersion      : 1.0.0\nLastEditors  : chenxl \nLastEditTime : 2024-07-27 01:55:50\n'''\n\nimport codecs\nimport logging\nimport os\nimport re\nimport locale\nfrom pathlib import Path\nfrom logging.handlers import BaseRotatingHandler\nimport time\nimport colorlog\n\nfrom ktransformers.server.config.config import Config\n\n\nclass DailyRotatingFileHandler(BaseRotatingHandler):\n    \"\"\"\n    such as 'logging.TimeRotatingFileHandler', Additional features:\n     - support multiprocess\n     - support rotating daily\n    \"\"\"\n\n    def __init__(self, filename, backupCount=0, encoding=None, delay=False, utc=False, **kwargs): # pylint: disable=unused-argument\n        self.backup_count = backupCount\n        self.utc = utc\n        self.suffix = \"%Y-%m-%d\"\n        self.base_log_path = Path(filename)\n        if not os.path.exists(self.base_log_path.parent):\n            os.makedirs(self.base_log_path.parent)\n        self.base_filename = self.base_log_path.name\n        self.current_filename = self._compute_fn()\n        self.current_log_path = self.base_log_path.with_name(\n            self.current_filename)\n        BaseRotatingHandler.__init__(self, filename, 'a', encoding, delay)\n\n    # pylint: disable=unused-argument, invalid-name\n    def shouldRollover(self, record):\n        \"\"\"\n        Determine whether to rotate the log. 
If the log filename corresponding to the current \n        time is not consistent with the currently opened log filename, then it is necessary\n        to rotate the log\n        Args:\n            record: record is not used, as we are just comparing times, but it is needed so\n        the method signatures are the same\n        \"\"\"\n        if self.current_filename != self._compute_fn():\n            return True\n        return False\n\n    def doRollover(self):\n        \"\"\"\n        roll over\n        \"\"\"\n        # close last log file\n        if self.stream:\n            self.stream.close()\n            self.stream = None  # type: ignore\n\n        # gen new log file name\n        self.current_filename = self._compute_fn()\n        self.current_log_path = self.base_log_path.with_name(\n            self.current_filename)\n\n        if not self.delay:\n            self.stream = self._open() # type: ignore\n\n        self.delete_expired_files()\n\n    def _compute_fn(self):\n        \"\"\"\n        gen log file name\n        \"\"\"\n        return self.base_filename + \".\" + time.strftime(self.suffix, time.localtime())\n\n    def _open(self):\n        \"\"\"\n        open a new log file, create soft link\n        \"\"\"\n        if self.encoding is None:\n            stream = open(str(self.current_log_path), self.mode, encoding=locale.getpreferredencoding())\n        else:\n            stream = codecs.open(str(self.current_log_path), self.mode, self.encoding)\n\n        if self.base_log_path.exists():\n            try:\n                if not self.base_log_path.is_symlink() or os.readlink(self.base_log_path) != self.current_filename:\n                    os.remove(self.base_log_path)\n            except OSError:\n                pass\n\n        try:\n            os.symlink(self.current_filename, str(self.base_log_path))\n        except OSError:\n            pass\n        return stream\n\n    def delete_expired_files(self):\n        \"\"\"\n        delete 
expired files every day\n        \"\"\"\n        if self.backup_count <= 0:\n            return\n\n        file_names = os.listdir(str(self.base_log_path.parent))\n        result = []\n        prefix = self.base_filename + \".\"\n        plen = len(prefix)\n        for file_name in file_names:\n            if file_name[:plen] == prefix:\n                suffix = file_name[plen:]\n                if re.match(r\"^\\d{4}-\\d{2}-\\d{2}(\\.\\w+)?$\", suffix):\n                    result.append(file_name)\n        if len(result) < self.backup_count:\n            result = []\n        else:\n            result.sort()\n            result = result[:len(result) - self.backup_count]\n\n        for file_name in result:\n            os.remove(str(self.base_log_path.with_name(file_name)))\n\n\nclass Logger(object):\n    \"\"\"\n    logger class\n    \"\"\"\n    level_relations = {\n        'debug': logging.DEBUG,\n        'info': logging.INFO,\n        'warn': logging.WARNING,\n        'error': logging.ERROR,\n        'crit': logging.CRITICAL\n    }\n\n    def __init__(self, level: str = 'info'):\n        fmt = '%(asctime)s %(levelname)s %(pathname)s[%(lineno)d] %(funcName)s: %(message)s'\n        cfg: Config = Config()\n        filename: str = os.path.join(cfg.log_dir, cfg.log_file)\n        backup_count: int = cfg.backup_count\n        th = DailyRotatingFileHandler(filename=filename, when='MIDNIGHT', backupCount=backup_count, encoding=\"utf-8\")\n        th.setFormatter(logging.Formatter(fmt))\n\n\n        color_fmt = (\n            '%(log_color)s%(asctime)s %(levelname)s %(pathname)s[%(lineno)d]: %(message)s'\n        )\n        color_formatter = colorlog.ColoredFormatter(\n            color_fmt,\n            log_colors={\n                'DEBUG': 'cyan',\n                'INFO': 'green',\n                'WARNING': 'yellow',\n                'ERROR': 'red',\n                'CRITICAL': 'bold_red'\n            }\n        )\n\n        sh = logging.StreamHandler()\n        
sh.setFormatter(color_formatter)\n\n        self.logger = logging.getLogger(filename)\n        self.logger.setLevel(self.level_relations.get(level)) # type: ignore\n        self.logger.addHandler(th)\n        self.logger.addHandler(sh)\n\n\nlogger = Logger(level=Config().log_level).logger\n"
  },
  {
    "path": "kt-sft/ktransformers/server/config/singleton.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  : Implement singleton\nAuthor       : unicornchan\nDate         : 2024-06-11 17:08:36\nVersion      : 1.0.0\nLastEditors  : chenxl \nLastEditTime : 2024-07-27 01:55:56\n'''\nimport abc\n\nclass Singleton(abc.ABCMeta, type):\n    \"\"\"_summary_\n\n    Args:\n        abc.ABCMeta: Provide a mechanism for defining abstract methods and properties,\n            enforcing subclasses to implement these methods and properties.\n        type: Inherit from 'type' to make 'Singleton' a metaclass,\n            enabling the implementation of the Singleton\n    \"\"\"\n    _instances = {}\n\n    def __call__(cls, *args, **kwds):\n        if cls not in cls._instances:\n            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwds)\n        return cls._instances[cls]\n\nclass AbstractSingleton(abc.ABC, metaclass=Singleton):\n    \"\"\"Provided an abstract Singleton base class, any class inheriting from\n       this base class will automatically become a Singleton class.\n\n    Args:\n        abc.ABC: Abstract base class, it cannot be instantiated, only inherited. \n    \"\"\"\n"
  },
  {
    "path": "kt-sft/ktransformers/server/crud/__init__.py",
    "content": ""
  },
  {
    "path": "kt-sft/ktransformers/server/crud/assistants/__init__.py",
    "content": ""
  },
  {
    "path": "kt-sft/ktransformers/server/crud/assistants/assistants.py",
    "content": "from time import time\nfrom typing import Optional,List\nfrom uuid import uuid4\n\nfrom ktransformers.server.models.assistants.assistants import Assistant\nfrom ktransformers.server.schemas.assistants.assistants import AssistantCreate,AssistantObject,AssistantModify\nfrom ktransformers.server.utils.sql_utils import SQLUtil\nfrom ktransformers.server.config.log import logger\nfrom ktransformers.server.schemas.base import Order\n\n\nclass AssistantDatabaseManager:\n    def __init__(self) -> None:\n        self.sql_util = SQLUtil()\n\n    def create_assistant_object(self, assistant: AssistantCreate) -> AssistantObject:\n        assistant = AssistantObject(\n            **assistant.model_dump(mode='json'),\n            id=str(uuid4()),\n            object='assistant',\n            created_at=int(time()),\n        )\n        return assistant\n\n    def db_count_assistants(self) -> int:\n        with self.sql_util.get_db() as db:\n            return db.query(Assistant).count()\n\n    def db_create_assistant(self, assistant: AssistantCreate):\n        ass_obj = self.create_assistant_object(assistant)\n        ass_obj.sync_db()\n        return ass_obj\n\n    def db_list_assistants(self, limit: Optional[int], order: Order) -> List[AssistantObject]:\n        with self.sql_util.get_db() as db:\n            query = db.query(Assistant).order_by(\n                order.to_sqlalchemy_order()(Assistant.created_at))\n            if limit is not None:\n                db_assistants = query.limit(limit)\n            else:\n                db_assistants = query.all()\n            return [AssistantObject.model_validate(a.__dict__) for a in db_assistants]\n\n    def db_get_assistant_by_id(self, assistant_id: str) -> Optional[AssistantObject]:\n        with self.sql_util.get_db() as db:\n            db_assistant = db.query(Assistant).filter(\n                Assistant.id == assistant_id).first()\n            if db_assistant is None:\n                logger.debug(f\"no 
assistant with id {assistant_id}\")\n                return None\n            return AssistantObject.model_validate(db_assistant.__dict__)\n\n    def db_update_assistant_by_id(self, assistant_id: str, assistant: AssistantModify):\n        with self.sql_util.get_db() as db:\n            db_assistant = db.query(Assistant).filter(\n                Assistant.id == assistant_id).first()\n            self.sql_util.db_update_commit_refresh(db, db_assistant, assistant)\n            return AssistantObject.model_validate(db_assistant.__dict__)\n\n    def db_delete_assistant_by_id(self, assistant_id: str):\n        with self.sql_util.get_db() as db:\n            db_assistant = db.query(Assistant).filter(\n                Assistant.id == assistant_id).first()\n            db.delete(db_assistant)\n            db.commit()\n\n"
  },
  {
    "path": "kt-sft/ktransformers/server/crud/assistants/messages.py",
    "content": "from time import time\nfrom typing import Optional\nfrom uuid import uuid4\n\nfrom ktransformers.server.models.assistants.messages import Message\nfrom ktransformers.server.schemas.assistants.messages import MessageCore, MessageCreate,  MessageObject\nfrom ktransformers.server.schemas.base import Order,ObjectID\nfrom ktransformers.server.utils.sql_utils import SQLUtil\n\nclass MessageDatabaseManager:\n    def __init__(self) -> None:\n        self.sql_util = SQLUtil()\n\n    @staticmethod\n    def create_db_message_by_core(message: MessageCore):\n        message_dict = message.model_dump(mode=\"json\")\n        return Message(**message_dict, id=str(uuid4()), created_at=int(time()))\n\n    def create_db_message(self, message: MessageCreate):\n        return MessageDatabaseManager.create_db_message_by_core(message.to_core())\n\n    def db_add_message(self, message: Message):\n        with self.sql_util.get_db() as db:\n            db.add(message)\n            self.sql_util.db_add_commit_refresh(db, message)\n\n    def db_create_message(self, thread_id: str, message: MessageCreate, status: MessageObject.Status):\n        db_message = self.create_db_message(message)\n        db_message.status = status.value\n        db_message.thread_id = thread_id\n        self.db_add_message(db_message)\n        return MessageObject.model_validate(db_message.__dict__)\n\n    @staticmethod\n    def create_message_object(thread_id: ObjectID, run_id: ObjectID, message: MessageCreate):\n        core = message.to_core()\n        return MessageObject(\n            **core.model_dump(mode='json'),\n            id=str(uuid4()),\n            object='thread.message',\n            created_at=int(time()),\n            thread_id=thread_id,\n            run_id=run_id,\n            status=MessageObject.Status.in_progress,\n        )\n\n    def db_sync_message(self, message: MessageObject):\n        db_message = Message(\n            **message.model_dump(mode=\"json\"),\n        )\n    
    with self.sql_util.get_db() as db:\n            self.sql_util.db_merge_commit(db, db_message)\n\n    def db_list_messages_of_thread(\n            self, thread_id: str, limit: Optional[int] = None, order: Order = Order.DESC):\n\n        # logger.debug(\n        #     f\"list messages of: {thread_id}, limit {limit}, order {order}\")\n        with self.sql_util.get_db() as db:\n            query = (\n                db.query(Message)\n                .filter(Message.thread_id == thread_id)\n                .order_by(order.to_sqlalchemy_order()(Message.created_at))\n            )\n            if limit is not None:\n                messages = query.limit(limit)\n            else:\n                messages = query.all()\n            message_list = [MessageObject.model_validate(m.__dict__) for m in messages]\n        return message_list\n\n    def db_get_message_by_id(self, thread_id: ObjectID, message_id: ObjectID) -> MessageObject:\n        with self.sql_util.get_db() as db:\n            message = db.query(Message).filter(\n                Message.id == message_id).first()\n        assert message.thread_id == thread_id\n        message_info = MessageObject.model_validate(message.__dict__)\n        return message_info\n\n    def db_delete_message_by_id(self, thread_id: ObjectID, message_id: ObjectID):\n        with self.sql_util.get_db() as db:\n            message = db.query(Message).filter(\n                Message.id == message_id).first()\n            assert message.thread_id == thread_id\n            db.delete(message)\n            db.commit()\n"
  },
  {
    "path": "kt-sft/ktransformers/server/crud/assistants/runs.py",
    "content": "from time import time\nfrom uuid import uuid4\n\nfrom ktransformers.server.models.assistants.runs import Run\nfrom ktransformers.server.schemas.assistants.runs import RunCreate,RunObject\nfrom ktransformers.server.schemas.base import ObjectID\nfrom ktransformers.server.utils.sql_utils import SQLUtil\n\n\nclass RunsDatabaseManager:\n    def __init__(self) -> None:\n        self.sql_util = SQLUtil()\n\n    def create_run_object(self, thread_id: ObjectID, run: RunCreate) -> RunObject:\n        run_obj = RunObject(\n            **run.model_dump(mode='json', exclude={\"stream\"}),\n            id=str(uuid4()),\n            object='run',\n            created_at=int(time()),\n            thread_id=thread_id,\n            status=RunObject.Status.queued,\n        )\n        run_obj.set_compute_save(0)\n        return run_obj\n\n    def db_create_run(self, thread_id: str, run: RunCreate):\n        db_run = Run(\n            **run.model_dump(mode=\"json\", exclude={\"stream\"}),\n            id=str(uuid4()),\n            created_at=int(time()),\n            status=\"queued\",\n            thread_id=thread_id,\n        )\n        with self.sql_util.get_db() as db:\n            self.sql_util.db_add_commit_refresh(db, db_run)\n            run_obj = RunObject.model_validate(db_run.__dict__)\n            run_obj.set_compute_save(0)\n        return run_obj\n\n    def db_sync_run(self, run: RunObject) -> None:\n        db_run = Run(\n            **run.model_dump(mode='json'),\n        )\n        with self.sql_util.get_db() as db:\n            self.sql_util.db_merge_commit(db, db_run)\n\n    def db_get_run(self, run_id: ObjectID) -> RunObject:\n        with self.sql_util.get_db() as db:\n            db_run = db.query(Run).filter(Run.id == run_id).first()\n            return RunObject.model_validate(db_run.__dict__)\n"
  },
  {
    "path": "kt-sft/ktransformers/server/crud/assistants/threads.py",
    "content": "from time import time\nfrom typing import Optional,List\nfrom uuid import uuid4\n\nfrom ktransformers.server.models.assistants.messages import Message\nfrom ktransformers.server.models.assistants.threads import Thread\nfrom ktransformers.server.schemas.assistants.threads import ThreadCreate,ThreadObject\nfrom ktransformers.server.schemas.base import ObjectID, Order\nfrom ktransformers.server.schemas.conversation import ThreadPreview\nfrom ktransformers.server.utils.sql_utils import SQLUtil\nfrom ktransformers.server.crud.assistants.messages import MessageDatabaseManager\nfrom ktransformers.server.config.log import logger\nfrom ktransformers.server.crud.assistants.assistants import AssistantDatabaseManager\n\nclass ThreadsDatabaseManager:\n    def __init__(self) -> None:\n        self.sql_util = SQLUtil()\n        self.message_manager = MessageDatabaseManager()\n        self.assistant_maanager = AssistantDatabaseManager()\n\n    def db_create_thread(self, thread: ThreadCreate):\n        thread_id = str(uuid4())\n        db_messages = []\n        with self.sql_util.get_db() as db:\n            if thread.messages is not None:\n                logger.debug(\"Creating messages first for thread\")\n                for message in thread.messages:\n                    db_message: Message = MessageDatabaseManager.create_db_message_by_core(\n                        message)\n                    db_message.role = \"user\"\n                    db_message.thread_id = thread_id\n                    db.add(db_message)\n                    db_messages.append(db_message)\n\n            db_thread = Thread(\n                **thread.model_dump(exclude=\"messages\"),\n                # Reuse the id already assigned to the messages above so both sides agree\n                id=thread_id,\n                created_at=int(time()),\n                messages=db_messages,\n            )\n\n            self.sql_util.db_add_commit_refresh(db, db_thread)\n            thread_obj = ThreadObject.model_validate(db_thread.__dict__)\n\n            if 'assistant_id' in 
thread.meta_data:\n#                assistant = self.assistant_maanager.db_get_assistant_by_id(thread.meta_data['assistant_id'], db)\n                assistant = self.assistant_maanager.db_get_assistant_by_id(thread.meta_data['assistant_id'])\n                logger.info(\n                    f'Append this related thread to assistant {assistant.id}')\n                assistant.append_related_threads([thread_obj.id])\n                assistant.sync_db(db)\n        return thread_obj\n\n    def db_get_thread_by_id(self, thread_id: ObjectID):\n        with self.sql_util.get_db() as db:\n            db_thread = db.query(Thread).filter(Thread.id == thread_id).first()\n            return ThreadObject.model_validate(db_thread.__dict__)\n\n    def db_list_threads(self, limit: Optional[int], order: Order) -> List[ThreadObject]:\n        with self.sql_util.get_db() as db:\n            query = db.query(Thread).order_by(order.to_sqlalchemy_order()(\n                Thread.created_at)).filter(~Thread.meta_data.contains('assistant_id'))\n\n            if limit is not None:\n                db_threads = query.limit(limit)\n            else:\n                db_threads = query.all()\n\n            return [ThreadObject.model_validate(tool.__dict__) for tool in db_threads]\n\n    def db_list_threads_preview(self, limit: Optional[int], order: Order) -> List[ThreadPreview]:\n        threads = self.db_list_threads(limit, order)\n        previews = []\n        for thread in threads:\n            messages = self.message_manager.db_list_messages_of_thread(\n                thread.id, limit=2, order=Order.ASC)\n            if len(messages) == 2:\n                message = messages[0]\n                assistant = self.assistant_maanager.db_get_assistant_by_id(\n                    messages[1].assistant_id)\n            else:\n                message = None\n                assistant = None\n            previews.append(ThreadPreview(\n                assistant=assistant, thread=thread, 
first_message=message))\n        return previews\n\n    def db_delete_thread_by_id(self, thread_id: ObjectID):\n        with self.sql_util.get_db() as db:\n            db_thread = db.query(Thread).filter(Thread.id == thread_id).first()\n            db.delete(db_thread)\n            # TODO delete related messages and runs and other stuff or just gc\n            db.commit()\n"
  },
  {
    "path": "kt-sft/ktransformers/server/exceptions.py",
    "content": "from fastapi import HTTPException, status\n\n\ndef db_exception():\n    return HTTPException(\n        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,\n        detail=\"DB Error\",\n    )\n\n\ndef not_implemented(what):\n    return HTTPException(\n        status_code=status.HTTP_501_NOT_IMPLEMENTED,\n        detail=f\"{what} not implemented\",\n    )\n\n\ndef internal_server_error(what):\n    return HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f\"{what}\")\n\n\ndef request_error(what):\n    return HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f\"{what}\")\n"
  },
  {
    "path": "kt-sft/ktransformers/server/main.py",
    "content": "import os\nimport re\nfrom fastapi import FastAPI\nfrom fastapi.staticfiles import StaticFiles\nimport uvicorn.logging\nimport uvicorn\nimport sys\nproject_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))\nfrom fastapi.middleware.cors import CORSMiddleware\nfrom ktransformers.server.args import ArgumentParser\nfrom ktransformers.server.config.config import Config\nfrom ktransformers.server.utils.create_interface import create_interface, GlobalInterface\nfrom fastapi.openapi.utils import get_openapi\nfrom fastapi import FastAPI\nfrom fastapi.middleware.cors import CORSMiddleware\nfrom ktransformers.server.api import router, post_db_creation_operations\nfrom ktransformers.server.utils.sql_utils import Base, SQLUtil\nfrom ktransformers.server.config.log import logger\n\n\ndef mount_app_routes(mount_app: FastAPI):\n    sql_util = SQLUtil()\n    logger.info(\"Creating SQL tables\")\n    Base.metadata.create_all(bind=sql_util.sqlalchemy_engine)\n    post_db_creation_operations()\n    mount_app.include_router(router)\n\n\ndef create_app():\n    cfg = Config()\n    if(hasattr(GlobalInterface.interface, \"lifespan\")):\n        app = FastAPI(lifespan=GlobalInterface.interface.lifespan)\n    else:\n        app = FastAPI()\n    if Config().web_cross_domain:\n        app.add_middleware(\n            CORSMiddleware,\n            allow_origins=[\"*\"],\n            allow_credentials=True,\n            allow_methods=[\"*\"],\n            allow_headers=[\"*\"],\n        )\n    mount_app_routes(app)\n    if cfg.mount_web:\n        mount_index_routes(app)\n    return app\n\n\ndef update_web_port(config_file: str):\n    ip_port_pattern = (\n        r\"(localhost|((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)):[0-9]{1,5}\"\n    )\n    with open(config_file, \"r\", encoding=\"utf-8\") as f_cfg:\n        web_config = f_cfg.read()\n    ip_port = \"localhost:\" + str(Config().server_port)\n    new_web_config = 
re.sub(ip_port_pattern, ip_port, web_config)\n    with open(config_file, \"w\", encoding=\"utf-8\") as f_cfg:\n        f_cfg.write(new_web_config)\n\n\ndef mount_index_routes(app: FastAPI):\n    project_dir = os.path.dirname(os.path.dirname(__file__))\n    web_dir = os.path.join(project_dir, \"website/dist\")\n    web_config_file = os.path.join(web_dir, \"config.js\")\n    if os.path.exists(web_dir):\n        # Rewrite the port only once the built site is known to exist\n        update_web_port(web_config_file)\n        app.mount(\"/web\", StaticFiles(directory=web_dir), name=\"static\")\n    else:\n        err_str = f\"No website resources in {web_dir}, please compile the website with npm first\"\n        logger.error(err_str)\n        print(err_str)\n        sys.exit(1)\n\n\ndef run_api(app, host, port, **kwargs):\n    if kwargs.get(\"ssl_keyfile\") and kwargs.get(\"ssl_certfile\"):\n        uvicorn.run(\n            app,\n            host=host,\n            port=port,\n            ssl_keyfile=kwargs.get(\"ssl_keyfile\"),\n            ssl_certfile=kwargs.get(\"ssl_certfile\"),\n        )\n    else:\n        uvicorn.run(app, host=host, port=port, log_level=\"debug\")\n\n\ndef custom_openapi(app):\n    if app.openapi_schema:\n        return app.openapi_schema\n    openapi_schema = get_openapi(\n        title=\"ktransformers server\",\n        version=\"1.0.0\",\n        summary=\"This is a server that provides a RESTful API for ktransformers.\",\n        description=\"Provides chat completion and OpenAI Assistants interfaces.\",\n        routes=app.routes,\n    )\n    openapi_schema[\"info\"][\"x-logo\"] = {\"url\": \"https://kvcache.ai/media/icon_1.png\"}\n    app.openapi_schema = openapi_schema\n    return app.openapi_schema\n\n\ndef main():\n    cfg = Config()\n\n    arg_parser = ArgumentParser(cfg)\n\n    args = arg_parser.parse_args()\n    create_interface(config=cfg, default_args=cfg)\n    app = create_app()\n    custom_openapi(app)\n\n    run_api(\n        app=app,\n        host=args.host,\n        port=args.port,\n        
ssl_keyfile=args.ssl_keyfile,\n        ssl_certfile=args.ssl_certfile,\n    )\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "kt-sft/ktransformers/server/models/__init__.py",
    "content": ""
  },
  {
    "path": "kt-sft/ktransformers/server/models/assistants/__init__.py",
    "content": ""
  },
  {
    "path": "kt-sft/ktransformers/server/models/assistants/assistants.py",
    "content": "from sqlalchemy import JSON, Column, Float, Integer, String, Text\nfrom sqlalchemy.orm import relationship\n\nfrom ktransformers.server.utils.sql_utils import Base\n\n\nclass Assistant(Base):\n    __tablename__ = \"assistants\"\n\n    id = Column(String, primary_key=True, index=True)\n    object = Column(String, default=\"assistant\")\n    created_at = Column(Integer)\n\n    name = Column(String, nullable=True)\n    description = Column(String, nullable=True)\n    model = Column(String)\n    instructions = Column(Text, nullable=True)\n    tools = Column(JSON)\n    tool_resources = Column(JSON)\n    temperature = Column(Float, nullable=True)\n    meta_data = Column(JSON, nullable=True)\n    top_p = Column(Float, nullable=True)\n    response_format = Column(JSON, default=\"auto\")\n\n    build_status = Column(JSON, nullable=True)\n\n    runs = relationship(\"Run\", back_populates=\"assistant\")\n\n    messages = relationship(\"Message\", back_populates=\"assistant\")\n"
  },
  {
    "path": "kt-sft/ktransformers/server/models/assistants/messages.py",
    "content": "from sqlalchemy import JSON, Column, ForeignKey, Integer, String\nfrom sqlalchemy.orm import relationship\n\nfrom ktransformers.server.utils.sql_utils import Base\n\n\nclass Message(Base):\n    __tablename__ = \"messages\"\n\n    id = Column(String, primary_key=True, index=True)\n    object = Column(String, default=\"thread.message\")\n    created_at = Column(Integer)\n\n    thread_id = Column(String, ForeignKey(\"threads.id\"))\n    status = Column(String, default=\"in_progress\")\n    incomplete_details = Column(JSON, nullable=True)\n    completed_at = Column(Integer, nullable=True)\n    incomplete_at = Column(Integer, nullable=True)\n    role = Column(JSON)\n    content = Column(JSON)\n    assistant_id = Column(String, ForeignKey(\"assistants.id\"), nullable=True)\n    run_id = Column(String, ForeignKey(\"runs.id\"), nullable=True)\n    attachments = Column(JSON, nullable=True)\n    meta_data = Column(JSON, nullable=True)\n\n    thread = relationship(\"Thread\", back_populates=\"messages\")\n    assistant = relationship(\"Assistant\", back_populates=\"messages\")\n    run = relationship(\"Run\", back_populates=\"message\")\n"
  },
  {
    "path": "kt-sft/ktransformers/server/models/assistants/run_steps.py",
    "content": "from sqlalchemy import JSON, Column, ForeignKey, Integer, String\nfrom sqlalchemy.orm import relationship\n\nfrom ktransformers.server.utils.sql_utils import Base\n\n\nclass RunStep(Base):\n    __tablename__ = \"run_steps\"\n    # todo\n    id = Column(String, primary_key=True, index=True)\n    object = Column(String, default=\"thread.run.step\")\n    created_at = Column(Integer)\n\n    assistant_id = Column(String, ForeignKey(\"assistants.id\"))\n    thread_id = Column(String, ForeignKey(\"threads.id\"))\n    run_id = Column(String, ForeignKey(\"runs.id\"))\n    type = Column(String)\n    status = Column(String)\n    step_details = Column(JSON)\n    last_error = Column(JSON, nullable=True)\n    expires_at = Column(Integer, nullable=True)\n    cancelled_at = Column(Integer, nullable=True)\n    failed_at = Column(Integer, nullable=True)\n    completed_at = Column(Integer, nullable=True)\n\n    meta_data = Column(JSON, nullable=True)\n    usage = Column(JSON, nullable=True)\n\n    # One-directional relationships: Assistant, Thread, and Run do not define\n    # a run_steps attribute, so back_populates would fail at mapper configuration.\n    assistant = relationship(\"Assistant\")\n    thread = relationship(\"Thread\")\n    run = relationship(\"Run\")\n"
  },
  {
    "path": "kt-sft/ktransformers/server/models/assistants/runs.py",
    "content": "from sqlalchemy import JSON, Column, Float, ForeignKey, Integer, String, Text\nfrom sqlalchemy.orm import relationship\n\nfrom ktransformers.server.utils.sql_utils import Base\n\n\nclass Run(Base):\n    __tablename__ = \"runs\"\n\n    id = Column(String, primary_key=True, index=True)\n    object = Column(String, default=\"thread.run\")\n    created_at = Column(Integer)\n    thread_id = Column(String, ForeignKey(\"threads.id\"))\n    assistant_id = Column(String, ForeignKey(\"assistants.id\"))\n    status = Column(String)\n    required_action = Column(JSON, nullable=True)\n    last_error = Column(JSON, nullable=True)\n    expires_at = Column(Integer, nullable=True)\n    started_at = Column(Integer, nullable=True)\n    cancelled_at = Column(Integer, nullable=True)\n    failed_at = Column(Integer, nullable=True)\n    completed_at = Column(Integer, nullable=True)\n    incomplete_details = Column(JSON, nullable=True)\n    # get from assistant\n    model = Column(String)\n    instructions = Column(Text, nullable=True)\n    tools = Column(JSON)\n    meta_data = Column(JSON, nullable=True)\n    usage = Column(JSON, nullable=True)\n    temperature = Column(Float, nullable=True)\n    top_p = Column(Float, nullable=True)\n    max_propmp_tokens = Column(Integer, nullable=True)\n    truncation_strategy = Column(JSON)\n    tool_choice = Column(JSON)\n    response_format = Column(JSON, default=\"auto\")\n\n    thread = relationship(\"Thread\", back_populates=\"runs\")\n    assistant = relationship(\"Assistant\", back_populates=\"runs\")\n    message = relationship(\"Message\", back_populates=\"run\")\n"
  },
  {
    "path": "kt-sft/ktransformers/server/models/assistants/threads.py",
    "content": "from sqlalchemy import JSON, Column, Integer, String\nfrom sqlalchemy.orm import relationship\n\nfrom ktransformers.server.utils.sql_utils import Base\n\n\nclass Thread(Base):\n    __tablename__ = \"threads\"\n\n    id = Column(String, primary_key=True, index=True)\n    object = Column(String, default=\"thread\")\n    created_at = Column(Integer)\n\n    tool_resources = Column(JSON, nullable=True)\n    meta_data = Column(JSON, nullable=True)\n\n    runs = relationship(\"Run\", back_populates=\"thread\")\n    messages = relationship(\"Message\", back_populates=\"thread\")\n"
  },
  {
    "path": "kt-sft/ktransformers/server/schemas/__init__.py",
    "content": ""
  },
  {
    "path": "kt-sft/ktransformers/server/schemas/assistants/__init__.py",
    "content": ""
  },
  {
    "path": "kt-sft/ktransformers/server/schemas/assistants/assistants.py",
    "content": "from enum import Enum\nfrom time import time\nfrom typing import AsyncIterable, Callable, Dict, List, Optional, Union\nfrom asyncio import Lock, Queue\n\nfrom fastapi import logger\nfrom pydantic import BaseModel, Field, PrivateAttr, field_validator, model_validator\nimport torch\n\nfrom ktransformers.server.config.config import Config\nfrom ktransformers.server.models.assistants.assistants import Assistant\nfrom ktransformers.server.models.assistants.threads import Thread\nfrom ktransformers.server.schemas.assistants.messages import Role\nfrom ktransformers.server.schemas.assistants.runs import RunObject,RunStreamResponse,ObjectWithCreatedTime\nfrom ktransformers.server.schemas.assistants.threads import ThreadObject\nfrom ktransformers.server.schemas.base import Metadata,MetadataField,ObjectID\nfrom ktransformers.server.schemas.assistants.tool import Tool,CodeInterpreter,FileSearch,RelatedThreads,FuntionTool,ToolResource,CodeInterpreterResource,FileSearchResource,RelatedThreadsResource,ToolType\nfrom ktransformers.server.utils.sql_utils import SQLUtil\n\n\nclass AssistantBase(BaseModel):\n    name: Optional[str] = Field(None,description='The name of the assistant.') \n    description: Optional[str] = Field(None,description='The description of the assistant.')\n    instructions: Optional[str] = Field(None,description='Instructions which is added in front of the input of LLM') \n    tools: List[Tool] = Field([], max_length=128)\n\n    @field_validator('tools', mode='before')\n    def validate_tools(cls, value):\n        re = []\n        if not isinstance(value, list):\n            raise ValueError('Invalid type for tools')\n\n        for tool in value:\n            if 'type' not in tool:\n                raise ValueError('Invalid type for tools')\n            if tool['type'] == 'code_interpreter':\n                re.append(CodeInterpreter(**tool))\n            elif tool['type'] == 'file_search':\n                re.append(FileSearch(**tool))\n       
     elif tool['type'] == 'related_threads':\n                re.append(RelatedThreads(**tool))\n            elif tool['type'] == 'function':\n                re.append(FuntionTool(**tool))\n            else:\n                raise ValueError('Invalid type for tools')\n        return re\n\n    tool_resources: List[ToolResource] = Field([], max_length=128)\n\n    @field_validator('tool_resources', mode='before')\n    def validate_tool_resources(cls, value):\n        re = []\n        if not isinstance(value, list):\n            raise ValueError('Invalid type for tool resources')\n\n        for tool_re in value:\n            if 'file_ids' in tool_re:\n                re.append(CodeInterpreterResource(**tool_re))\n            elif 'vector_stores' in tool_re:\n                re.append(FileSearchResource(**tool_re))\n            elif 'thread_ids' in tool_re:\n                re.append(RelatedThreadsResource(**tool_re))\n            else:\n                raise ValueError('Invalid type for tool resources')\n        return re\n\n    meta_data: Metadata = MetadataField\n\n    @model_validator(mode='before')\n    def convert_meta_data(cls, values):\n        if 'meta_data' in values:\n            values['metadata'] = values['meta_data']\n        return values\n    temperature: Optional[float] = Field(ge=0.0, le=2.0, default=1)\n    top_p: Optional[float] = Field(ge=0.0, le=1.0, default=1)\n    response_format: Union[str, Dict[str, str]] = \"auto\"\n\n\nclass AssistantCreate(AssistantBase):\n    model: str\n\n\nclass AssistantBuildStatus(BaseModel):\n    class Status(Enum):\n        not_build = \"not_build\"\n        in_queue = \"in_queue\"\n        parsing = \"parsing\"\n        prefilling = \"prefilling\"\n        dumping = \"dumping\"\n        completed = \"completed\"\n        paused = \"paused\"\n\n    _lock: Lock = PrivateAttr(default_factory=Lock)\n    _queue: Optional[Queue] = PrivateAttr(None)\n\n    status: Status = Field(default=Status.not_build)\n    
total_file_count: int = Field(default=0)\n    parsed_file_count: int = Field(default=0)\n\n    prefilling_current: int = Field(default=0)\n    prefilling_total: int = Field(default=0)\n\n    build_started_time: Optional[int] = Field(default=None)\n    build_completed_time: Optional[int] = Field(default=None)\n\n    # in megabytes\n    assistant_usage: int = Field(default=0, description='')\n    assistant_total_usage: int = Field(default=0)\n    disk_free_space: int = Field(default=0)\n    disk_total_space: int = Field(default=0)\n\n    def to_stream_reply(self) -> str:\n        return f\"event: assistant.build.status\\ndata: {self.model_dump_json()}\\n\\n\"\n\n\nclass AssistantObject(AssistantBase, ObjectWithCreatedTime):\n    model: Optional[str] = Field(\n        default=Config().model_name)\n    related_threads_objects: Optional[List] = Field(None, exclude=True)\n    _encoded_instruction: Optional[torch.Tensor] = PrivateAttr(default=None)\n    build_status: AssistantBuildStatus = Field(default=AssistantBuildStatus())\n\n    def as_api_response(self):\n        return self.model_dump(exclude={'build_status'})\n\n    def get_related_threads_ids(self) -> List[ObjectID]:\n        re = []\n        for tool, tool_re in zip(self.tools, self.tool_resources):\n            if tool.type == ToolType.RELATED_THREADS:\n                re += tool_re.thread_ids or []\n        return re\n\n    def get_related_threads_objects(self) -> List:\n        # raise NotImplementedError  # should be replaced\n        sql_utils = SQLUtil()\n        if self.related_threads_objects is None:\n            with sql_utils.get_db() as db:\n                db_threads = db.query(Thread).all()\n            self.related_threads_objects = [tool for tool in [ThreadObject.model_validate(\n                tool.__dict__) for tool in db_threads] if tool.is_related_threads and tool.meta_data['assistant_id'] == self.id]\n            # logger.debug(\n            #     f'Found {len(self.related_threads_objects)} 
related threads')\n        return self.related_threads_objects\n\n    def append_related_threads(self, thread_ids: List[ObjectID]):\n        # logger.debug(f'{self.tools} {self.tool_resources}')\n        for tool, tool_re in zip(self.tools, self.tool_resources):\n            if tool.type == ToolType.RELATED_THREADS:\n                tool_re.thread_ids += thread_ids\n                return\n\n        self.tools.append(RelatedThreads(type=ToolType.RELATED_THREADS))\n        self.tool_resources.append(\n            RelatedThreadsResource(thread_ids=thread_ids))\n\n    async def update_build_status(self, events: AsyncIterable) -> AsyncIterable:\n        async for event in events:\n            # logger.debug(event)\n            if isinstance(event, RunStreamResponse):\n                if event.event == RunObject.Status.completed:\n                    self.build_status.status = AssistantBuildStatus.Status.completed\n                    self.build_status.build_completed_time = int(time())\n                    self.sync_db()\n                    yield self.build_status.model_copy()\n            elif isinstance(event, dict):\n                # logger.debug('dict')\n                if 'stage' in event:\n                    if event['stage'] == 'prefill':\n                        self.build_status.status = AssistantBuildStatus.Status.prefilling\n                        self.build_status.prefilling_current = event['curr_progress']\n                        self.build_status.prefilling_total = event['max_progress']\n                    if event['stage'] == 'parse':\n                        self.build_status.status = AssistantBuildStatus.Status.parsing\n                        self.build_status.parsed_file_count = event['curr_progress']\n                        self.build_status.total_file_count = event['max_progress']\n                    yield self.build_status.model_copy()\n\n    def get_build_status(self) -> AssistantBuildStatus:\n        return self.build_status\n     \n    
\n    def sync_db(self)->None:\n        # raise NotImplementedError # should be replaced\n        sql_utils = SQLUtil()\n        db_assistant = Assistant(\n            **self.model_dump(mode='json'),\n        )\n        with sql_utils.get_db() as db:\n            sql_utils.db_merge_commit(db, db_assistant)\n    \n    def get_encoded_instruction(self,encode_fn:Callable)->torch.Tensor:\n        if self._encoded_instruction is None:\n            logger.info(f'encoding assistant instruction: {self.instructions}')\n            self._encoded_instruction = encode_fn(self.instructions, Role.user)\n        return self._encoded_instruction\n\n\nclass AssistantModify(AssistantBase):\n    model: Optional[str] = None\n\n\n# Non API Backend\n"
  },
  {
    "path": "kt-sft/ktransformers/server/schemas/assistants/messages.py",
    "content": "from enum import Enum\nfrom typing import ForwardRef, List, Optional, Union,Callable\n\nimport torch\nfrom pydantic import BaseModel, PrivateAttr, model_validator\n\nfrom ktransformers.server.exceptions import not_implemented\nfrom ktransformers.server.config.log import logger\nfrom ktransformers.server.models.assistants.messages import Message\nfrom ktransformers.server.schemas.base import Metadata, MetadataField, ObjectWithCreatedTime\nfrom ktransformers.server.schemas.assistants.tool import Field,CodeInterpreter,FileSearch\nfrom ktransformers.server.utils.sql_utils import SQLUtil\n\n\nclass IncompleteDetails(BaseModel):\n    reason: str\n\n\nclass ContentType(Enum):\n    image_file = \"image_file\"\n    image_url = \"image_url\"\n    text = \"text\"\n\n\nclass ContentObject(BaseModel):\n    type: ContentType\n\n\nclass ImageFile(BaseModel):\n    file_id: str\n    detail: str\n\n\nclass ImageFileObject(ContentObject):\n    image_file: ImageFile\n\n\nclass ImageUrl(BaseModel):\n    url: str\n    detail: str\n\n\nclass ImageUrlObject(ContentObject):\n    image_url: ImageUrl\n\n\nclass Annotation(BaseModel):\n    todo: str\n\n\nclass Text(BaseModel):\n    value: str\n    annotations: List[Annotation] = Field(default=[])\n\n\nclass TextObject(ContentObject):\n    text: Text\n    delta_index: int = Field(default=0,exclude=True)\n    special_tokens_on: bool = Field(default=False,exclude=True) \n    last_two: str= Field(default='',exclude=True)  \n\n    def filter_append(self,text:str):     \n        self.text.value+=text\n        self.delta_index+=1\n        return True  \n\n\n\nContent = Union[ImageFileObject, ImageUrlObject, TextObject]\n\n\nclass Attachment(BaseModel):\n    file_id: Optional[str] = Field(default=None)\n    tools: Optional[List[Union[CodeInterpreter, FileSearch]]] = Field(default=None)\n\n\nclass Role(Enum):\n    user = \"user\"\n    assistant = \"assistant\"\n\n    def is_user(self)->bool:\n        return self == Role.user\n\n\nclass 
MessageCore(BaseModel):\n    role: Role\n    content: List[Content]\n    attachments: Optional[List[Attachment]]\n    meta_data: Metadata = MetadataField\n    @model_validator(mode='before')\n    @classmethod\n    def convert_meta_data(cls,values):\n        if 'meta_data' in values:\n            values['metadata'] = values['meta_data']\n        return values\n\n\nclass MessageBase(MessageCore):\n    class Status(Enum):\n        created = \"created\" # only used for stream\n        in_progress = \"in_progress\"\n        incomplete = \"incomplete\"\n        completed = \"completed\"\n    thread_id: str\n    status: Status\n    incomplete_details: Optional[IncompleteDetails] = None\n    completed_at: Optional[int] = None\n    incomplete_at: Optional[int] = None\n\n    assistant_id: Optional[str] = None\n    run_id: Optional[str]\n\n\nMessageStreamResponse = ForwardRef('MessageStreamResponse')\n\nclass MessageObject(MessageBase, ObjectWithCreatedTime):\n    _encoded_content: Optional[torch.Tensor] = PrivateAttr(default=None)\n    \n\n    def get_text_content(self) -> str:\n        text_content = \"\"\n        for content in self.content:\n            if content.type == ContentType.text:\n                text_content += content.text.value\n            else:\n                raise not_implemented(\"Content other than text\")\n        return text_content\n\n    async def get_encoded_content(self,encode_fn:Callable):\n        if self._encoded_content is None:\n            logger.info(f'encoding {self.role.value} message({self.status.value}): {self.get_text_content()}')\n            self._encoded_content = encode_fn(self.get_text_content(),self.role)\n\n            for f in self.get_attached_files():\n                logger.info(f'encoding file: {f.filename}')\n                self._encoded_content = torch.cat([self._encoded_content, encode_fn(await f.get_str(),self.role)],dim=-1)\n                yield None \n\n        yield self._encoded_content\n\n\n    def 
get_attached_files(self):\n        raise NotImplementedError # should be replaced \n\n\n\n    def append_message_delta(self,text:str):\n        raise NotImplementedError # should be replaced \n    \n    def sync_db(self):\n        # raise NotImplementedError # should be replaced\n        sql_utils = SQLUtil()\n        db_message = Message(\n            **self.model_dump(mode=\"json\"),\n        )\n        with sql_utils.get_db() as db:\n            sql_utils.db_merge_commit(db, db_message)\n    \n\n    def stream_response_with_event(self, event: MessageBase.Status) -> MessageStreamResponse:\n        match event:\n            case MessageObject.Status.created:\n                self.status = MessageObject.Status.in_progress\n            case _:\n                self.status = event\n        return MessageStreamResponse(message=self, event=event)\n   \n\nclass MessageStreamResponse(BaseModel):\n    message: MessageObject\n    event: MessageObject.Status\n\n    def to_stream_reply(self):\n        return f\"event: thread.message.{self.event.value}\\ndata: {self.message.model_dump_json()}\\n\\n\"\n\n\nclass MessageCreate(BaseModel):\n    role: Role = Field(default=Role.user)\n    content: Union[str | List[Content]]\n    attachments: Optional[List[Attachment]] = None\n    meta_data: Metadata = MetadataField\n    @model_validator(mode='before')\n    @classmethod\n    def convert_meta_data(cls,values):\n        if 'meta_data' in values:\n            values['metadata'] = values['meta_data']\n        return values\n\n    def to_core(self) -> MessageCore:\n        # logger.debug(f\"Converting message create to core {self.model_dump()}\")\n        core = MessageCore(\n            role=self.role,\n            content=[],\n            attachments=self.attachments,\n            meta_data=self.meta_data,\n        )\n        if isinstance(self.content, str):\n            core.content = [TextObject(type=\"text\", text=Text(value=self.content, annotations=[]))]\n        elif 
isinstance(self.content, list):\n            core.content = self.content\n        else:\n            raise ValueError(\"Invalid content type\")\n        return core\n\n\nclass MessageModify(BaseModel):\n    meta_data: Metadata = MetadataField\n    @model_validator(mode='before')\n    @classmethod\n    def convert_meta_data(cls,values):\n        if 'meta_data' in values:\n            values['metadata'] = values['meta_data']\n        return values\n"
  },
  {
    "path": "kt-sft/ktransformers/server/schemas/assistants/runs.py",
    "content": "from enum import Enum\nfrom typing import Dict, List, Optional, Union, ForwardRef\n\nfrom pydantic import BaseModel, Field, model_validator\n\nfrom ktransformers.server.models.assistants.runs import Run\nfrom ktransformers.server.schemas.base import TODO, Metadata, MetadataField, ObjectWithCreatedTime\nfrom ktransformers.server.schemas.assistants.threads import ThreadCreate\nfrom ktransformers.server.schemas.assistants.tool import Tool, ToolResource\nfrom ktransformers.server.utils.sql_utils import SQLUtil\n\n\nclass ToolCall(BaseModel):\n    id: str\n    type: str\n    function: TODO\n\n\nclass SubmitToolOutputs(BaseModel):\n    tool_calls: List[ToolCall]\n\n\nclass RequiredAction(BaseModel):\n    type: str\n    submit_tool_outputs: TODO\n\n\nclass LastError(BaseModel):\n    code: str\n    message: str\n\n\nclass IncompleteDetails(BaseModel):\n    reason: str\n\n\nclass Usage(BaseModel):\n    completion_tokens: int\n    prompt_tokens: int\n    total_tokens: int\n\n\nclass TruncationStrategy(BaseModel):\n    type: str = \"auto\"\n    last_message: Optional[int]\n\n\nclass ToolChoiceType(Enum):\n    none = \"none\"\n    auto = \"auto\"\n    required = \"required\"\n\n\nclass RunBase(BaseModel):\n    class Status(Enum):\n        created = \"created\" # only stream event will have this created status\n        queued = \"queued\"\n        in_progress = \"in_progress\"\n        requires_action = \"requires_action\"\n        cancelling = \"cancelling\"\n        cancelled = \"cancelled\"\n        failed = \"failed\"\n        completed = \"completed\"\n        expired = \"expired\"\n\n\n    thread_id: str\n    assistant_id: str\n    status: Status = Status.queued\n    required_action: Optional[RequiredAction] = Field(None)\n    last_error: Optional[LastError] = Field(None)\n    expires_at: Optional[int]= Field(None)\n    started_at: Optional[int] = Field(None)\n    cancelled_at: Optional[int] = Field(None)\n    failed_at: Optional[int] = Field(None)\n    
completed_at: Optional[int] = Field(None)\n    incomplete_details: Optional[IncompleteDetails] = Field(None)\n    model: Optional[str] = Field(None)\n    instructions: Optional[str] = Field(None)\n    tools: Optional[List[Tool]] = Field([])\n    meta_data: Metadata = MetadataField\n    @model_validator(mode='before')\n    @classmethod\n    def convert_meta_data(cls,values):\n        if 'meta_data' in values:\n            values['metadata'] = values['meta_data']\n        return values\n    \n    def set_compute_save(self,save:int):\n        self.meta_data['compute_save'] = str(save)\n\n\n    usage: Optional[Usage] = Field(None)\n    temperature: Optional[float] = Field(None)\n    top_p: Optional[float]= Field(None)\n    max_propmp_tokens: Optional[int]= Field(None)\n    truncation_strategy: Optional[TruncationStrategy]= Field(None)\n    tool_choice: Optional[Union[ToolChoiceType, dict]]= Field(None)\n    response_format: Union[str, Dict[str, str]] = \"auto\"\n\n\nRunStreamResponse = ForwardRef('RunStreamResponse')\n\nclass RunObject(RunBase, ObjectWithCreatedTime):\n    def stream_response_with_event(self,event:RunBase.Status)->RunStreamResponse:\n        match event:\n            case RunBase.Status.created:\n                self.status = RunBase.Status.queued\n            case _:\n                self.status = event\n        return RunStreamResponse(run=self, event=event)\n \n    \n    def sync_db(self):\n        # raise NotImplementedError # should be replaced in crud\n        sql_utils = SQLUtil()\n        db_run = Run(\n            **self.model_dump(mode='json'),\n        )\n        with sql_utils.get_db() as db:\n            sql_utils.db_merge_commit(db, db_run)\n    \n    def create_message_creation_step(self):\n        raise NotImplementedError # should be replaced \n        \n\nclass RunStreamResponse(BaseModel):\n    run: RunObject\n    event: RunObject.Status\n    def to_stream_reply(self):\n        return f\"event: thread.run.{self.event.value}\\ndata: 
{self.run.model_dump_json()}\\n\\n\"\n\nclass RunCreate(BaseModel):\n    assistant_id: str\n    model: Optional[str] = Field(default=None)\n    instructions: Optional[str] = Field(default=None)\n    # TODO: Add this\n    # additional_instructions: Optional[str]\n    # additional_messages: Optional[List[MessageCore]]\n    tools: List[Tool] = Field(default=[])\n    meta_data: Metadata = MetadataField\n    @model_validator(mode='before')\n    @classmethod\n    def convert_meta_data(cls,values):\n        if 'meta_data' in values:\n            values['metadata'] = values['meta_data']\n        return values\n    temperature: Optional[float] = Field(default=None)\n    top_p: Optional[float] = Field(default=None)\n    stream: Optional[bool] = Field(default=None)\n    max_propmp_tokens: Optional[int] = Field(default=None)\n    # TODO: Add this\n    # max_completion_tokens: Optional[int]\n    truncation_strategy: Optional[TruncationStrategy] = Field(default=None)\n    tool_choice: Optional[Union[ToolChoiceType, dict]] = Field(default=None)\n    response_format: Union[str, Dict[str, str]] = Field(default=\"auto\")\n\n\nclass RunThreadCreate(BaseModel):\n    assistant_id: str\n    thread: Optional[ThreadCreate]\n    model: Optional[str]\n    instructions: Optional[str]\n    tools: List[Tool]\n    tool_resources: List[ToolResource]\n    meta_data: Metadata = MetadataField\n    @model_validator(mode='before')\n    @classmethod\n    def convert_meta_data(cls,values):\n        if 'meta_data' in values:\n            values['metadata'] = values['meta_data']\n        return values\n    temperature: Optional[float]\n    top_p: Optional[float]\n    stream: Optional[bool]\n    max_propmp_tokens: Optional[int]\n    # TODO: Add this\n    # max_completion_tokens: Optional[int]\n    truncation_strategy: TruncationStrategy\n    tool_choice: Union[ToolChoiceType, dict]\n    response_format: Union[str, Dict[str, str]] = \"auto\"\n\n\nclass RunModify(BaseModel):\n    meta_data: Metadata = 
MetadataField\n    @model_validator(mode='before')\n    @classmethod\n    def convert_meta_data(cls,values):\n        if 'meta_data' in values:\n            values['metadata'] = values['meta_data']\n        return values\n\n\nclass ToolOutput(BaseModel):\n    tool_call_id: Optional[str]\n    output: Optional[str]\n\n\nclass RunSubmit(BaseModel):\n    tool_outputs: List[ToolOutput]\n    stream: Optional[bool]\n"
  },
  {
    "path": "kt-sft/ktransformers/server/schemas/assistants/streaming.py",
    "content": "import asyncio\nfrom typing import AsyncIterable, List, Union\n\nfrom fastapi import Request\nfrom fastapi.responses import StreamingResponse\nfrom pydantic import BaseModel\n\nfrom ktransformers.server.schemas.assistants.runs import RunStreamResponse\nfrom ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk\nfrom ktransformers.server.config.log import logger\nfrom ktransformers.server.schemas.base import Object\nfrom ktransformers.server.schemas.assistants.messages import ContentType, ImageFileObject, ImageUrlObject, MessageObject, Text, TextObject\n\n\nclass TextObjectWithIndex(TextObject):\n    index: int\n\n\nclass ImageFileObjectWithIndex(ImageFileObject):\n    index: int\n\n\nclass ImageUrlObjectWithIndex(ImageUrlObject):\n    index: int\n\n\nContentWithIndex = Union[TextObjectWithIndex,\n                         ImageFileObjectWithIndex, ImageUrlObjectWithIndex]\n\n\nclass MessageDeltaImpl(BaseModel):\n    # role: Optional[str]\n    content: List[ContentWithIndex]\n\n\nclass MessageDelta(Object):\n    delta: MessageDeltaImpl\n\n    def to_stream_reply(self):\n        return f\"event: thread.message.delta\\ndata: {self.model_dump_json()}\\n\\n\"\n\n\ndef text_delta(index: int, text: str):\n    return MessageDeltaImpl(content=[TextObjectWithIndex(index=index, type=ContentType.text, text=Text(value=text))])\n\n\ndef append_message_delta(self: MessageObject, text: str):\n\n    if len(self.content) == 0:\n        self.content.append(TextObject(type=ContentType.text,\n                            text=Text(value=''), delta_index=0))\n\n    text_object: TextObject = self.content[0]\n    if text_object.filter_append(text):\n        return MessageDelta(id=self.id, object=\"thread.message.delta\", delta=text_delta(text_object.delta_index, text))\n    else:\n        return None\n\n\nMessageObject.append_message_delta = append_message_delta\n\n\nclass RunStepDeltaImpl(BaseModel):\n    pass\n\n\nclass RunStepDelta(Object):\n    delta: 
RunStepDeltaImpl\n\n    def to_stream_reply(self):\n        return f\"event: thread.run.step.delta\\ndata: {self.model_dump_json()}\\n\\n\"\n\n\nclass Done():\n    def to_stream_reply(self):\n        return f\"data: [DONE]\\n\\n\"\n\n\nasync def check_client_link(request: Request, async_events: AsyncIterable):\n    async for event in async_events:\n        if await request.is_disconnected():\n            break\n        yield event\n\n\nasync def add_done(async_events: AsyncIterable):\n    async for event in async_events:\n        yield event\n    yield Done()\n\n\nasync def to_stream_reply(async_events: AsyncIterable):\n    async for event in async_events:\n        if isinstance(event, str):\n            yield event\n        else:\n            yield event.to_stream_reply()\n\n\nasync def filter_api_event(async_events: AsyncIterable):\n    async for event in async_events:\n        if isinstance(event, MessageDelta) or isinstance(event, RunStepDelta) or isinstance(event, RunStreamResponse) or isinstance(event, Done):\n            yield event\n\n\nasync def filter_chat_chunk(async_events: AsyncIterable):\n    async for event in async_events:\n        if isinstance(event, ChatCompletionChunk):\n            yield event\n\n\nasync def filter_by_types(async_events: AsyncIterable, types: List):\n    async for event in async_events:\n        for type in types:\n            if isinstance(event, type):\n                yield event\n                continue\n\n\ndef api_stream_response(request: Request, async_events: AsyncIterable):\n    return StreamingResponse(check_client_link(request, to_stream_reply(add_done(filter_api_event(async_events)))), media_type=\"text/event-stream\")\n\n\ndef chat_stream_response(request: Request, async_events: AsyncIterable):\n    return StreamingResponse(check_client_link(request, to_stream_reply(add_done(filter_chat_chunk(async_events)))), media_type=\"text/event-stream\")\n\n\ndef stream_response(request: Request, async_events: 
AsyncIterable):\n    return StreamingResponse(check_client_link(request, to_stream_reply(add_done(async_events))), media_type=\"text/event-stream\")\n\n\ndef check_link_response(request: Request, async_events: AsyncIterable):\n    return StreamingResponse(check_client_link(request, async_events), media_type=\"text/event-stream\")\n\n\ndef wrap_async_generator_into_queue(async_events: AsyncIterable) -> asyncio.Queue:\n    queue = asyncio.Queue()\n\n    async def inner():\n        # logger.debug('run inner')\n        async for event in async_events:\n            # logger.debug(f'put: {event}')\n            await queue.put(event)\n            await asyncio.sleep(0)\n        # logger.debug(f'put: None')\n        await queue.put(None)\n    asyncio.create_task(inner())\n    return queue\n\n\nasync def unwrap_async_queue(queue: asyncio.Queue) -> AsyncIterable:\n    while True:\n        events = [await queue.get()]\n        events.extend([queue.get_nowait() for _ in range(queue.qsize())])\n\n        logger.debug(f'getting {len(events)} events')\n        for event in events:\n            if event is None:\n                # None is the end-of-stream sentinel: stop the generator, not just this batch\n                return\n            yield event\n\n\nasync def unwrap_async_queue_slow(queue: asyncio.Queue) -> AsyncIterable:\n    while True:\n        event = await queue.get()\n        # logger.debug(f'unwrap_async_queue {event}')\n        if event is None:\n            break\n        yield event\n"
  },
  {
    "path": "kt-sft/ktransformers/server/schemas/assistants/threads.py",
    "content": "from enum import Enum\nfrom typing import List\nfrom typing_extensions import Self \n\nfrom pydantic import BaseModel, Field, model_validator\n\nfrom ktransformers.server.schemas.base import Metadata, MetadataField, ObjectWithCreatedTime\nfrom ktransformers.server.schemas.assistants.tool import ToolResource\nfrom ktransformers.server.schemas.assistants.messages import MessageCore\n\n\nclass ThreadBase(BaseModel):\n    meta_data: Metadata = MetadataField\n    @model_validator(mode='before')\n    @classmethod\n    def convert_meta_data(cls,values):\n        if 'meta_data' in values:\n            values['metadata'] = values['meta_data']\n        return values\n\n    tool_resources: List[ToolResource] = Field([], max_length=128)\n\n\nclass ThreadObject(ThreadBase, ObjectWithCreatedTime):\n    is_related_threads:bool = Field(False,exclude=True)\n\n    @model_validator(mode='after')\n    def check_is_related_threads(self)->Self:\n        # logger.debug(f'check thread {self.id} is related thread? by {self}')\n        if 'assistant_id' in self.meta_data:\n            self.is_related_threads = True\n        return self\n\n    class StreamEvent(Enum):\n        created = 'created'\n\n    def to_stream_reply(self,event:StreamEvent):\n        return f\"event: thread.{event.value}\\ndata: {self.model_dump_json()}\\n\\n\"\n    \n\nclass ThreadCreate(ThreadBase):\n    messages: List[MessageCore] = Field(default=[])\n\n\nclass ThreadModify(ThreadBase):\n    pass\n\n\n# other than OpenAI API\n"
  },
  {
    "path": "kt-sft/ktransformers/server/schemas/assistants/tool.py",
    "content": "from enum import Enum\nfrom typing import List, Optional, Union\n\nfrom pydantic import BaseModel, Field\n\nfrom ktransformers.server.schemas.base import ObjectID\n\n\nclass ToolType(str, Enum):\n    CODE_INTERPRETER = \"code_interpreter\"\n    FILE_SEARCH = \"file_search\"\n    RELATED_THREADS = \"related_threads\"\n    FUNCTION = \"function\"\n\n\nclass ToolBase(BaseModel):\n    type: ToolType\n\n\nclass CodeInterpreter(ToolBase):\n    pass\n\n\nclass FileSearch(ToolBase):\n    pass\n\n\nclass RelatedThreads(ToolBase):\n    pass\n\n\nclass FuntionTool(ToolBase):\n    description: str\n    name: str\n    parameters: List[str]\n\n\nTool = Union[CodeInterpreter, FileSearch, RelatedThreads, FuntionTool]\n\n\nclass CodeInterpreterResource(BaseModel):\n    file_ids: Optional[List[str]] = Field(default_factory=list, max_length=20)\n\n\nclass FileSearchResource(BaseModel):\n    vector_store_ids: Optional[List[str]] = Field(default_factory=list, max_length=1)\n    vector_stores: Optional[List[str]] = Field(default_factory=list, max_length=1)\n\n\nclass RelatedThreadsResource(BaseModel):\n    thread_ids: List[ObjectID] = Field(default=[])\n\n\nToolResource = Union[CodeInterpreterResource,FileSearchResource,RelatedThreadsResource] \n"
  },
  {
    "path": "kt-sft/ktransformers/server/schemas/base.py",
    "content": "from enum import Enum\nfrom typing import Dict\n\nimport sqlalchemy\nfrom pydantic import BaseModel, ConfigDict, Field\n\nTODO = BaseModel\n\nObjectID = str\n\n\nclass Object(BaseModel):\n    id: ObjectID\n    object: str\n\n    model_config = ConfigDict(from_attributes=True)\n\n\n# Pydantic Base Models\nclass ObjectWithCreatedTime(Object):\n    created_at: int\n\n\n\nclass Order(str, Enum):\n    ASC = \"asc\"\n    DESC = \"desc\"\n\n    def to_sqlalchemy_order(self):\n        match self:\n            case Order.ASC:\n                return sqlalchemy.asc\n            case Order.DESC:\n                return sqlalchemy.desc\n\n\nMetadata = Dict[str, str]\nMetadataField: Metadata = Field({},max_length=16, alias=\"metadata\")\n\n\nclass DeleteResponse(Object):\n    deleted: bool = True\n\nclass OperationResponse(BaseModel):\n    operation: str\n    status: str\n"
  },
  {
    "path": "kt-sft/ktransformers/server/schemas/conversation.py",
    "content": "from typing import Optional\n\nfrom pydantic import BaseModel\n\nfrom .assistants.assistants import AssistantObject\nfrom .assistants.threads import ThreadObject\nfrom .assistants.messages import MessageObject\n\nclass ThreadPreview(BaseModel):\n    assistant: Optional[AssistantObject] = None\n    thread: ThreadObject\n    first_message: Optional[MessageObject] = None\n"
  },
  {
    "path": "kt-sft/ktransformers/server/schemas/endpoints/chat.py",
    "content": "from typing import List, Optional, Union, Dict, Any\nfrom typing_extensions import Literal\nfrom enum import Enum\nfrom pydantic import BaseModel, Field\nfrom ktransformers.server.config.config import Config\nfrom ktransformers.server.schemas.base import Object\n\n\nfrom openai.types.chat.chat_completion_chunk import Choice\n\nfrom uuid import uuid4\n\nclass CompletionUsage(BaseModel):\n    prompt_tokens: int\n    completion_tokens: int\n    total_tokens: int\n    prompt_tokens_details: Optional[Dict[str, Any]] = None\n    completion_tokens_details: Optional[Dict[str, Any]] = None\n    prefill_time: Optional[float] = None\n    decode_time: Optional[float] = None\n\nclass Role(Enum):\n    system = 'system'\n    user = 'user'\n    assistant = 'assistant'\n    tool = 'tool'\n    function = 'function'\n\nclass Message(BaseModel):\n    content: Optional[str] = None\n    role: Role\n    name: Optional[str] = None\n    tool_calls: Optional[List[Dict[str, Any]]] = {}\n    tool_call_id: Optional[str] = None\n    \n    def to_tokenizer_message(self):\n        message = {'role': self.role.value}\n        if self.content is not None:\n            message['content'] = self.content\n        if self.name is not None:\n            message['name'] = self.name\n        if self.tool_calls is not {}:\n            message['tool_calls'] = self.tool_calls\n        if self.tool_call_id is not None:\n            message['tool_call_id'] = self.tool_call_id\n        return message\n\nclass FunctionParameters(BaseModel):\n    type: str = \"object\"\n    properties: Dict[str, Any] = {}\n    required: Optional[List[str]] = None\n\nclass FunctionDefinition(BaseModel):\n    name: str\n    description: Optional[str] = None\n    parameters: FunctionParameters = Field(default_factory=FunctionParameters)\n\nclass ToolFunction(BaseModel):\n    function: FunctionDefinition\n    \nclass Tool(BaseModel):\n    type: Literal[\"function\"]\n    function: FunctionDefinition\n\nclass 
ChatCompletionCreate(BaseModel):\n    messages: List[Message]\n    model: str\n    stream: bool = False\n    temperature: Optional[float] = Field(default=Config().temperature)\n    top_p: Optional[float] = Field(default=Config().top_p)\n    tools: Optional[List[Tool]] = None\n    tool_choice: Optional[Union[str, Dict[str, Any]]] = None\n    stream_options: Optional[Dict[str, Any]] = None\n    frequency_penalty: float = 0\n    presence_penalty: float = 0\n    max_tokens: Optional[int] = Field(default=None)\n    max_completion_tokens: Optional[int] = Field(default=None)\n    return_speed: Optional[bool] = Field(default=False)\n    def get_tokenizer_messages(self):\n        return [m.to_tokenizer_message() for m in self.messages]\n\nclass ChatCompletionChunk(BaseModel):\n    id: str\n    choices: List[Choice]\n    created: int\n    model: str\n    object: Literal[\"chat.completion.chunk\"]\n    service_tier: Optional[Literal[\"scale\", \"default\"]] = None\n    system_fingerprint: Optional[str] = None\n    usage: Optional[CompletionUsage] = None\n\n    def to_stream_reply(self):\n        return f\"data: {self.model_dump_json()}\\n\\n\"\n\nclass RawUsage(BaseModel):\n    tokenize_time: float\n    prefill_time: float\n    decode_time: float\n    prefill_count: int\n    decode_count: int"
  },
  {
    "path": "kt-sft/ktransformers/server/schemas/legacy/__init__.py",
    "content": ""
  },
  {
    "path": "kt-sft/ktransformers/server/schemas/legacy/completions.py",
    "content": "from typing import List, Optional\nfrom enum import Enum\nfrom pydantic import BaseModel, Field\nfrom ktransformers.server.config.config import Config\nfrom ..base import Object\n\nclass CompletionCreate(BaseModel):\n    model: str\n    prompt: str | List[str]\n    stream: bool = False\n    temperature: Optional[float] = Field(default=Config().temperature)\n    top_p: Optional[float] = Field(default=Config().top_p)\n    max_tokens: Optional[int] = Field(default=None)\n    max_completion_tokens: Optional[int] = Field(default=None)\n    \n    def get_tokenizer_messages(self):\n        if isinstance(self.prompt,List):\n            self.get_tokenizer_messages('\\n'.join(self.prompt))\n        return [{'content':self.prompt,'role':'user'}]\n\n\nclass FinishReason(Enum):\n    stop = 'stop'\n    length = 'length'\n\nclass Choice(BaseModel):\n    index: int\n    text: str\n    logprobs: Optional[str] = None\n    finish_reason: FinishReason = None\n\n\nclass CompletionObject(Object):\n    created:int\n    choices: List[Choice] = []\n    model:str = 'not implmented'\n    system_fingerprint:str = 'not implmented'\n    usage: Optional[str] = None\n\n    def set_token(self,token:str):\n        if len(self.choices)==0:\n            self.choices.append(Choice(index=0,text=''))\n        self.choices[0].text = token    \n\n    def append_token(self,token:str):\n        if len(self.choices)==0:\n            self.choices.append(Choice(index=0,text=''))\n        self.choices[0].text += token\n\n    def to_stream_reply(self):\n        return f\"data:{self.model_dump_json()}\\n\\n\"\n"
  },
  {
    "path": "kt-sft/ktransformers/server/utils/__init__.py",
    "content": ""
  },
  {
    "path": "kt-sft/ktransformers/server/utils/create_interface.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : qiyuxinlin\nDate         : 2024-07-25 11:50:16\nVersion      : 1.0.0\nLastEditors  : qiyuxinlin \nLastEditTime : 2024-07-25 12:54:48\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nfrom ktransformers.server.config.config import Config\nfrom ktransformers.server.backend.args import ConfigArgs\nfrom ktransformers.server.backend.context_manager import ThreadContextManager\nfrom ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface\nfrom ktransformers.server.backend.interfaces.transformers import TransformersInterface\nfrom ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface\n\ndef create_interface(config: Config, default_args: ConfigArgs):\n    if config.backend_type=='transformers':\n        from ktransformers.server.backend.interfaces.transformers import  TransformersInterface as BackendInterface\n    elif config.backend_type == 'exllamav2':\n        from ktransformers.server.backend.interfaces.exllamav2 import  ExllamaInterface as BackendInterface\n    elif config.backend_type == 'ktransformers':\n        from ktransformers.server.backend.interfaces.ktransformers import  KTransformersInterface as BackendInterface\n    elif config.backend_type == 'balance_serve':\n        from ktransformers.server.backend.interfaces.balance_serve import BalanceServeInterface as BackendInterface\n    else:\n        raise NotImplementedError(f'{config.backend_type} not implemented')\n    GlobalInterface.interface = BackendInterface(default_args)\n    GlobalContextManager.context_manager = ThreadContextManager(GlobalInterface.interface)\n\nclass GlobalContextManager:\n    context_manager: ThreadContextManager\nclass GlobalInterface:\n    interface:  TransformersInterface | KTransformersInterface | ExllamaInterface \n    \ndef get_thread_context_manager() -> GlobalContextManager:\n    return 
GlobalContextManager.context_manager\ndef get_interface() -> TransformersInterface | KTransformersInterface | ExllamaInterface:\n    return GlobalInterface.interface"
  },
  {
    "path": "kt-sft/ktransformers/server/utils/multi_timer.py",
    "content": "import time\n\n\ndef format_time(seconds):\n    units = [\n        (\"hours\", 3600),\n        (\"minutes\", 60),\n        (\"seconds\", 1),\n        (\"milliseconds\", 1e-3),\n        (\"microseconds\", 1e-6),\n    ]\n\n    for unit_name, unit_value in units:\n        if seconds >= unit_value:\n            time_value = seconds / unit_value\n            return f\"{time_value:.2f} {unit_name}\"\n    return \"0 seconds\"  # Handle case for 0 seconds\n\n\nclass Profiler:\n    def __init__(self):\n        self.timers = {}\n        self.counters = {}\n\n    def create_timer(self, name):\n        self.timers[name] = {\n            \"start_time\": None,\n            \"elapsed_time\": 0,\n            \"running\": False,\n        }\n\n    def start_timer(self, name):\n        if name not in self.timers:\n            raise ValueError(f\"Timer '{name}' does not exist.\")\n        if self.timers[name][\"running\"]:\n            raise ValueError(f\"Timer '{name}' is already running.\")\n        self.timers[name][\"start_time\"] = time.time()\n        self.timers[name][\"running\"] = True\n\n    def pause_timer(self, name):\n        if name not in self.timers:\n            raise ValueError(f\"Timer '{name}' does not exist.\")\n        if not self.timers[name][\"running\"]:\n            raise ValueError(f\"Timer '{name}' is not running.\")\n        self.timers[name][\"elapsed_time\"] += time.time() - self.timers[name][\"start_time\"]\n        self.timers[name][\"running\"] = False\n\n    def get_timer_sec(self, name):\n        if name not in self.timers:\n            raise ValueError(f\"Timer '{name}' does not exist.\")\n        if self.timers[name][\"running\"]:\n            current_time = self.timers[name][\"elapsed_time\"] + (time.time() - self.timers[name][\"start_time\"])\n        else:\n            current_time = self.timers[name][\"elapsed_time\"]\n        return current_time\n\n    def get_all_timers(self):\n        all_timers = {}\n        for name in 
self.timers:\n            all_timers[name] = self.get_timer_sec(name)\n        return all_timers\n\n    def report_timer_string(self, name):\n        return f\"{name} elapsed time: {format_time(self.get_timer_sec(name))}\"\n\n    def create_and_start_timer(self, name):\n        self.create_timer(name)\n        self.start_timer(name)\n\n\n    # Counter\n    def inc(self,key:str,delta:int=1):\n        self.counters[key] = self.counters.get(key,0) + delta\n\n    def set_counter(self,key:str,to=0):\n        self.counters[key] = to\n\n    def get_counter(self,key:str):\n        return self.counters.get(key,0)\n"
  },
  {
    "path": "kt-sft/ktransformers/server/utils/sql_utils.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : chenxl\nDate         : 2024-06-12 09:12:58\nVersion      : 1.0.0\nLastEditors  : chenxl \nLastEditTime : 2024-07-27 01:56:04\n'''\n\nfrom urllib.parse import urlparse\nimport os\nfrom contextlib import contextmanager\nfrom sqlalchemy import create_engine\nfrom sqlalchemy.orm import Session, sessionmaker, declarative_base\n\nfrom ktransformers.server.config.config import Config\nfrom ktransformers.server.config.singleton import Singleton\nfrom ktransformers.server.config.log import logger\nfrom ktransformers.server.exceptions import db_exception\n\n\nBase = declarative_base()\n\n\nclass SQLUtil(metaclass=Singleton):\n    \"\"\"\n    database connections init and management\n    \"\"\"\n    sqlalchemy_engine = None\n    session_local = None\n\n    def __init__(self) -> None:\n        self.cfg: Config = Config()\n        if not self.sqlalchemy_engine:\n            SQLUtil.init_engine(self.cfg)\n\n    @contextmanager\n    def get_db(self):\n        \"\"\"\n        After you finish using the session, it's crucial to close it.\n        \"\"\"\n        if not SQLUtil.sqlalchemy_engine:\n            SQLUtil.init_engine(self.cfg)\n        session = self.session_local()  # type: ignore pylint: disable=not-callable\n        try:\n            yield session\n        finally:\n            session.close()\n\n    @staticmethod\n    def init_engine(cfg: Config):\n        \"\"\"\n        initial engine and session maker Factory\n        \"\"\"\n        pool_size = cfg.db_pool_size\n        if SQLUtil.sqlalchemy_engine is None:\n            if cfg.db_type == \"sqllite\":\n                db_url = SQLUtil.create_sqllite_url(cfg)\n            else:\n                logger.error(\"Unsupported database type %s\", cfg.db_type)\n                exit(-1)\n            SQLUtil.sqlalchemy_engine = create_engine(\n                db_url, connect_args={\"check_same_thread\": False}, 
pool_size=pool_size)\n            SQLUtil.session_local = sessionmaker(\n                autocommit=False, autoflush=False, bind=SQLUtil.sqlalchemy_engine)\n\n    @staticmethod\n    def create_sqllite_url(cfg):\n        \"\"\"\n        create and validate SQLLite url\n        \"\"\"\n        path: str = cfg.db_host\n        database: str = cfg.db_database\n        absolute_path: str = os.path.join(path, database)\n        url = 'sqlite:///' + absolute_path\n        try:\n            result = urlparse(url)\n            if all([result.scheme, result.path, result.scheme == 'sqlite']):\n                return url\n            else:\n                logger.error(\"invalid sqllite url: %s\", url)\n                exit(-1)\n        except ValueError:\n            logger.error(\"invalid sqllite url: %s\", url)\n            exit(-1)\n\n    def db_add_commit_refresh(self, session: Session, what):\n        \"\"\"\n        add data to database\n        \"\"\"\n        try:\n            session.add(what)\n            session.commit()\n            session.refresh(what)\n        except Exception as e:\n            logger.exception(\"db commit error with data %s\", str(what.__dict__))\n            ex = db_exception()\n            ex.detail = str(e)\n            session.rollback()\n            raise ex from e\n\n    def db_merge_commit(self, session: Session, what):\n        try:\n            session.merge(what)\n            session.commit()\n        except Exception as e:\n            ex = db_exception()\n            ex.detail = str(e)\n            logger.exception(\"db merge commit error with data %s\", str(what.__dict__))\n            session.rollback()\n            raise ex from e\n\n    def db_update_commit_refresh(self, session: Session, existing, what):\n        what = what.model_dump(mode=\"json\")\n        try:\n            for key in what.keys():\n                if what[key] is not None:\n                    setattr(existing, key, what[key])\n            
session.commit()\n            session.refresh(existing)\n        except Exception as e:\n            ex = db_exception()\n            ex.detail = str(e)\n            logger.exception(\"db update commit refresh error with data %s\", str(what))\n            session.rollback()\n            raise ex from e\n"
  },
  {
    "path": "kt-sft/ktransformers/sft/__init__.py",
    "content": ""
  },
  {
    "path": "kt-sft/ktransformers/sft/flops_utils/__init__.py",
    "content": ""
  },
  {
    "path": "kt-sft/ktransformers/sft/flops_utils/custom_profile.py",
    "content": "from distutils.version import LooseVersion\n\nfrom thop.vision.basic_hooks import *\nfrom thop.rnn_hooks import *\nfrom thop.utils import prGreen, prRed, prYellow\nimport sys, os\n\nproject_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))\nsys.path.insert(0, project_dir)\n\nfrom ktransformers.util.utils import prefill_and_generate\n\n# logger = logging.getLogger(__name__)\n# logger.setLevel(logging.INFO)\n\n\nif LooseVersion(torch.__version__) < LooseVersion(\"1.0.0\"):\n    logging.warning(\n        \"You are using an old version PyTorch {version}, which THOP does NOT support.\".format(\n            version=torch.__version__\n        )\n    )\n\ndefault_dtype = torch.float64\n\nregister_hooks = {\n    nn.ZeroPad2d: zero_ops,  # padding does not involve any multiplication.\n    nn.Conv1d: count_convNd,\n    nn.Conv2d: count_convNd,\n    nn.Conv3d: count_convNd,\n    nn.ConvTranspose1d: count_convNd,\n    nn.ConvTranspose2d: count_convNd,\n    nn.ConvTranspose3d: count_convNd,\n    nn.BatchNorm1d: count_normalization,\n    nn.BatchNorm2d: count_normalization,\n    nn.BatchNorm3d: count_normalization,\n    nn.LayerNorm: count_normalization,\n    nn.InstanceNorm1d: count_normalization,\n    nn.InstanceNorm2d: count_normalization,\n    nn.InstanceNorm3d: count_normalization,\n    nn.PReLU: count_prelu,\n    nn.Softmax: count_softmax,\n    nn.ReLU: zero_ops,\n    nn.ReLU6: zero_ops,\n    nn.LeakyReLU: count_relu,\n    nn.MaxPool1d: zero_ops,\n    nn.MaxPool2d: zero_ops,\n    nn.MaxPool3d: zero_ops,\n    nn.AdaptiveMaxPool1d: zero_ops,\n    nn.AdaptiveMaxPool2d: zero_ops,\n    nn.AdaptiveMaxPool3d: zero_ops,\n    nn.AvgPool1d: count_avgpool,\n    nn.AvgPool2d: count_avgpool,\n    nn.AvgPool3d: count_avgpool,\n    nn.AdaptiveAvgPool1d: count_adap_avgpool,\n    nn.AdaptiveAvgPool2d: count_adap_avgpool,\n    nn.AdaptiveAvgPool3d: count_adap_avgpool,\n    nn.Linear: count_linear,\n    nn.Dropout: zero_ops,\n    nn.Upsample: count_upsample,\n  
  nn.UpsamplingBilinear2d: count_upsample,\n    nn.UpsamplingNearest2d: count_upsample,\n    nn.RNNCell: count_rnn_cell,\n    nn.GRUCell: count_gru_cell,\n    nn.LSTMCell: count_lstm_cell,\n    nn.RNN: count_rnn,\n    nn.GRU: count_gru,\n    nn.LSTM: count_lstm,\n    nn.Sequential: zero_ops,\n    nn.PixelShuffle: zero_ops,\n}\n\nif LooseVersion(torch.__version__) >= LooseVersion(\"1.1.0\"):\n    register_hooks.update({nn.SyncBatchNorm: count_normalization})\n\n\ndef profile_origin(model, inputs, custom_ops=None, verbose=True, report_missing=False):\n    handler_collection = []\n    types_collection = set()\n    if custom_ops is None:\n        custom_ops = {}\n    if report_missing:\n        verbose = True\n\n    def add_hooks(m):\n        if len(list(m.children())) > 0:\n            return\n\n        if hasattr(m, \"total_ops\") or hasattr(m, \"total_params\"):\n            logging.warning(\n                \"Either .total_ops or .total_params is already defined in %s. \"\n                \"Be careful, it might change your code's behavior.\" % str(m)\n            )\n\n        m.register_buffer(\"total_ops\", torch.zeros(1, dtype=default_dtype))\n        m.register_buffer(\"total_params\", torch.zeros(1, dtype=default_dtype))\n\n        for p in m.parameters():\n            m.total_params += torch.DoubleTensor([p.numel()])\n\n        m_type = type(m)\n\n        fn = None\n        if (\n            m_type in custom_ops\n        ):  # if defined both op maps, use custom_ops to overwrite.\n            fn = custom_ops[m_type]\n            if m_type not in types_collection and verbose:\n                print(\"[INFO] Customize rule %s() %s.\" % (fn.__qualname__, m_type))\n        elif m_type in register_hooks:\n            fn = register_hooks[m_type]\n            if m_type not in types_collection and verbose:\n                print(\"[INFO] Register %s() for %s.\" % (fn.__qualname__, m_type))\n        else:\n            if m_type not in types_collection and 
report_missing:\n                prRed(\n                    \"[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params.\"\n                    % m_type\n                )\n\n        if fn is not None:\n            handler = m.register_forward_hook(fn)\n            handler_collection.append(handler)\n        types_collection.add(m_type)\n\n    training = model.training\n\n    model.eval()\n    model.apply(add_hooks)\n\n    with torch.no_grad():\n        model(*inputs)\n\n    total_ops = 0\n    total_params = 0\n    for m in model.modules():\n        if len(list(m.children())) > 0:  # skip for non-leaf module\n            continue\n        total_ops += m.total_ops\n        total_params += m.total_params\n\n    total_ops = total_ops.item()\n    total_params = total_params.item()\n\n    # reset model to original status\n    model.train(training)\n    for handler in handler_collection:\n        handler.remove()\n\n    # remove temporal buffers\n    for n, m in model.named_modules():\n        if len(list(m.children())) > 0:\n            continue\n        if \"total_ops\" in m._buffers:\n            m._buffers.pop(\"total_ops\")\n        if \"total_params\" in m._buffers:\n            m._buffers.pop(\"total_params\")\n\n    return total_ops, total_params\n\n\ndef custom_profile(\n    model: nn.Module,\n    inputs,\n    content,\n    tokenizer,\n    custom_ops=None,\n    verbose=True,\n    ret_layer_info=False,\n    report_missing=False,\n):\n    handler_collection = {}\n    types_collection = set()\n    if custom_ops is None:\n        custom_ops = {}\n    if report_missing:\n        # overwrite `verbose` option when enable report_missing\n        verbose = True\n\n    def add_hooks(m: nn.Module):\n        m.register_buffer(\"total_ops\", torch.zeros(1, dtype=torch.float64))\n        m.register_buffer(\"total_params\", torch.zeros(1, dtype=torch.float64))\n\n        # for p in m.parameters():\n        #     m.total_params += 
torch.DoubleTensor([p.numel()])\n\n        m_type = type(m)\n\n        fn = None\n        if m_type in custom_ops:\n            # if defined both op maps, use custom_ops to overwrite.\n            fn = custom_ops[m_type]\n            if m_type not in types_collection and verbose:\n                print(\"[INFO] Customize rule %s() %s.\" % (fn.__qualname__, m_type))\n        elif m_type in register_hooks:\n            fn = register_hooks[m_type]\n            if m_type not in types_collection and verbose:\n                print(\"[INFO] Register %s() for %s.\" % (fn.__qualname__, m_type))\n        else:\n            if m_type not in types_collection and report_missing:\n                prRed(\n                    \"[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params.\"\n                    % m_type\n                )\n\n        if fn is not None:\n            handler_collection[m] = (\n                m.register_forward_hook(fn),\n                m.register_forward_hook(count_parameters),\n            )\n        types_collection.add(m_type)\n\n    prev_training_status = model.training\n\n    model.eval()\n    model.apply(add_hooks)\n    \n    messages = [{\"role\": \"user\", \"content\": content}]\n    input_tensor = tokenizer.apply_chat_template(\n        messages, add_generation_prompt=True, return_tensors=\"pt\"\n    )\n\n    with torch.no_grad():\n        # model(*inputs)\n        # TODO: model.model to deal with the PeftModelForCaualLM temp\n        simple_prefill_and_generate_for_test(\n            model.model, tokenizer, input_tensor.cuda(), max_new_tokens=1000, use_cuda_graph=False, mode = 'normal', force_think = False, chunk_prefill_size = 8192,\n        )\n\n    def dfs_count(module: nn.Module, prefix=\"\\t\") -> (int, int):\n        total_ops, total_params = module.total_ops.item(), 0\n        ret_dict = {}\n        for n, m in module.named_children():\n            # if not hasattr(m, \"total_ops\") and not hasattr(m, \"total_params\"):  
# and len(list(m.children())) > 0:\n            #     m_ops, m_params = dfs_count(m, prefix=prefix + \"\\t\")\n            # else:\n            #     m_ops, m_params = m.total_ops, m.total_params\n            next_dict = {}\n            if m in handler_collection and not isinstance(\n                m, (nn.Sequential, nn.ModuleList)\n            ):\n                m_ops, m_params = m.total_ops.item(), m.total_params.item()\n            else:\n                m_ops, m_params, next_dict = dfs_count(m, prefix=prefix + \"\\t\")\n            ret_dict[n] = (m_ops, m_params, next_dict)\n            total_ops += m_ops\n            total_params += m_params\n        # print(prefix, module._get_name(), (total_ops, total_params))\n        return total_ops, total_params, ret_dict\n\n    total_ops, total_params, ret_dict = dfs_count(model)\n\n    # reset model to original status\n    model.train(prev_training_status)\n    for m, (op_handler, params_handler) in handler_collection.items():\n        op_handler.remove()\n        params_handler.remove()\n        m._buffers.pop(\"total_ops\")\n        m._buffers.pop(\"total_params\")\n\n    if ret_layer_info:\n        return total_ops, total_params, ret_dict\n    return total_ops, total_params\n"
  },
  {
    "path": "kt-sft/ktransformers/sft/flops_utils/lora_test_utils.py",
    "content": "import os\nfrom pathlib import Path\nfrom types import MethodType\n\nimport torch\nimport torch.nn as nn\nfrom torch.profiler import profile, record_function, ProfilerActivity\nfrom transformers import TrainerCallback\n\nclass ProfilerCallback(TrainerCallback):\n    def __init__(self, profiler):\n        self.profiler = profiler\n\n    def on_step_end(self, args, state, control, **kwargs):\n        self.profiler.step()\n\ndef _short(t):\n    return tuple(t.shape) if isinstance(t, torch.Tensor) else type(t)\n\ndef install_shape_probes(model):\n    if os.environ.get(\"KT_DEBUG_MOE\",\"0\") != \"1\":\n        print(\"[KT_DEBUG_MOE] off\"); return\n\n    try:\n        acc = trainer.accelerator\n        cfg = getattr(acc, \"dataloader_config\", None)\n        if cfg is not None:\n            print(\"[ACCEL DL CONFIG]\",\n                  \"split_batches=\", getattr(cfg,\"split_batches\",None),\n                  \"dispatch_batches=\", getattr(cfg,\"dispatch_batches\",None),\n                  \"even_batches=\", getattr(cfg,\"even_batches\",None),\n                  \"use_seedable_sampler=\", getattr(cfg,\"use_seedable_sampler\",None),\n                  \"non_blocking=\", getattr(cfg,\"non_blocking\",None))\n    except Exception as e:\n        print(\"[ACCEL DL CONFIG] <err>\", e)\n\n    try:\n        emb = model.base_model.model.model.embed_tokens\n        def _emb_pre(mod, inp):\n            x = inp[0]\n            if not hasattr(mod, \"_dbg_once\"):\n                print(f\"[DBG] embed input_ids shape = {tuple(x.shape)}  (expect B,S)\")\n                mod._dbg_once = True\n        emb.register_forward_pre_hook(_emb_pre)\n    except Exception as e:\n        print(\"[DBG] embed hook failed:\", e)\n\n    try:\n        first_layer = model.base_model.model.model.layers[0]\n        _orig_fwd = first_layer.forward\n        def _wrap_fwd(self, *args, **kwargs):\n            hs = args[0] if args else kwargs.get(\"hidden_states\")\n            if not hasattr(self, \"_dbg_once_in\"):\n                print(f\"[DBG] L0.in hidden_states = {_short(hs)}  
(expect B,S,H)\")\n                self._dbg_once_in = True\n            out = _orig_fwd(*args, **kwargs)\n            hs_out = out[0] if isinstance(out, (tuple, list)) else out\n            if not hasattr(self, \"_dbg_once_out\"):\n                print(f\"[DBG] L0.out hidden_states = {_short(hs_out)}\")\n                self._dbg_once_out = True\n            return out\n        first_layer.forward = MethodType(_wrap_fwd, first_layer)\n    except Exception as e:\n        print(\"[DBG] L0 wrap failed:\", e)\n\n    try:\n        moe_layer = None\n        for i, lyr in enumerate(model.base_model.model.model.layers):\n            if hasattr(lyr, \"mlp\"):\n                moe_layer = lyr.mlp\n                moe_idx = i\n                break\n        if moe_layer is not None:\n            _moe_orig = moe_layer.forward\n            def _moe_wrap(self, *args, **kwargs):\n                x = args[0] if args else kwargs.get(\"hidden_states\")\n                if not hasattr(self, \"_dbg_once\"):\n                    print(f\"[DBG] MLP(in) @layer{moe_idx} hidden_states = {_short(x)}\")\n                    if isinstance(x, torch.Tensor) and x.dim() == 3:\n                        B,S,H = x.shape\n                        print(f\"[DBG] tokens before flatten = B*S = {B}*{S} = {B*S}\")\n                    self._dbg_once = True\n                return _moe_orig(*args, **kwargs)\n            moe_layer.forward = MethodType(_moe_wrap, moe_layer)\n        else:\n            print(\"[DBG] no moe_layer found\")\n    except Exception as e:\n        print(\"[DBG] moe wrap failed:\", e)\n\n    try:\n        from ktransformers.operators.experts import KTransformersExperts\n        def _experts_pre(mod, args):\n            if hasattr(mod, \"_dbg_once\"): return\n            try:\n                input_tensor, expert_ids, weights = args[:3]\n                print(f\"[DBG] experts.in input_tensor={tuple(input_tensor.shape)} \"\n                      f\"expert_ids={tuple(expert_ids.shape)} 
weights={tuple(weights.shape)}\")\n                if input_tensor.dim()==2:\n                    N = input_tensor.shape[0]\n                    print(f\"[DBG] N(input rows)={N}\")\n                if expert_ids.dim()==2:\n                    T,K = expert_ids.shape\n                    print(f\"[DBG] tokens(T)={T}, K={K}, T*K={T*K}\")\n                mod._dbg_once = True\n            except Exception as e:\n                print(\"[DBG] experts hook parse err:\", e)\n        count=0\n        for name,m in model.named_modules():\n            if isinstance(m, KTransformersExperts):\n                m.register_forward_pre_hook(_experts_pre); count+=1\n        print(f\"[KT_DEBUG_MOE] installed experts hook on {count} modules.\")\n    except Exception as e:\n        print(\"[DBG] experts hook failed:\", e)\n\ndef inspect_device(model, write_file):\n    for name, module in model.named_modules(): \n        with open(write_file, 'a') as file:\n            file.write(f\"Layer: {name}\\n\")\n        for param_name, param in module.named_parameters(recurse=False): \n            with open(write_file, 'a') as file:\n                file.write(f\"  Parameter '{param_name}' device: {param.device}\\n\")\n        for buffer_name, buffer in module.named_buffers(recurse=False): \n            with open(write_file, 'a') as file:\n                file.write(f\"  Buffer '{buffer_name}' device: {buffer.device}\\n\")\n\ndef print_model_params(model):\n    # for layer_idx in range(len(model.model.orig_module.layers)):\n    for layer_idx in range(0, 3):\n        layer = model.model.orig_module.layers[layer_idx]\n        \n        print(f\"\\n================ Layer {layer_idx} Attention ================\")\n        \n        q_proj = layer.self_attn.orig_module.q_proj.orig_module\n        print(f\"\\nq_proj.generate_linear.weight (shape: {q_proj.generate_linear.weight.shape})\")\n        print(q_proj.generate_linear.weight.cpu())\n        \n        # kv_a_proj = 
layer.self_attn.orig_module.kv_a_proj_with_mqa.orig_module\n        # print(f\"\\nkv_a_proj.weight (shape: {kv_a_proj.weight.shape})\")\n        # print(kv_a_proj.weight.data[:3, :5].detach().cpu().numpy())\n        \n        # o_proj = layer.self_attn.orig_module.o_proj.orig_module\n        # print(f\"\\no_proj.weight (shape: {o_proj.weight.shape})\")\n        # print(o_proj.weight.data[:3, :5].detach().cpu().numpy())\n        \n        # print(f\"\\n================ Layer {layer_idx} MLP/MoE ================\")\n        \n        # if layer_idx == 0:\n        #     mlp = layer.mlp\n        #     for proj_type in ['gate_proj', 'up_proj', 'down_proj']:\n        #         module = getattr(mlp, proj_type).orig_module\n        #         print(f\"\\n{proj_type}.weight (shape: {module.weight.shape})\")\n        #         print(module.weight.data[:3, :5].detach().cpu().numpy())\n        # else:\n        #     moe = layer.mlp.orig_module\n        #     print(\"\\n[Shared Experts]\")\n        #     for proj_type in ['gate_proj', 'up_proj', 'down_proj']:\n        #         module = getattr(moe.shared_experts, proj_type).orig_module\n        #         print(f\"\\nshared_{proj_type}.weight (shape: {module.weight.shape})\")\n        #         print(module.weight.data[:3, :5].detach().cpu().numpy())\n            \n        #     print(\"\\n[Experts]\")\n        #     for expert_idx in range(3):\n        #         expert = moe.experts.orig_module[expert_idx]\n        #         print(f\"\\nExpert {expert_idx}:\")\n        #         for proj_type in ['gate_proj', 'up_proj', 'down_proj']:\n        #             module = getattr(expert, proj_type)\n        #             print(f\"{proj_type}.weight (shape: {module.weight.shape})\")\n        #             print(module.weight.data[:3, :5].detach().cpu().numpy())\n\ndef print_lora_params(model):\n    # for layer_idx in range(len(model.model.orig_module.layers)):\n    for layer_idx in range(0, 3):\n        layer = 
model.base_model.model.model.orig_module.layers[layer_idx]\n        # layer = model.model.orig_module.layers[layer_idx]\n        \n        q_proj_module = layer.self_attn.orig_module.q_proj.orig_module\n        \n        linear_weight = q_proj_module.generate_linear.weight\n        lora_A_weight = q_proj_module.lora_A[\"default\"].weight\n        lora_B_weight = q_proj_module.lora_B[\"default\"].weight\n        \n        print(f\"\\n=================== Layer {layer_idx} ===================\")\n        \n        print(\"\\nOriginal Linear (first row slice):\")\n        print(linear_weight.cpu())\n        \n        print(\"\\nLora_A (first row slice):\")\n        print(lora_A_weight.cpu())\n        \n        print(\"\\nLora_B (first row slice):\")\n        print(lora_B_weight.cpu())\n\ndef print_grad_fn(grad_fn, indent=0):\n    \"\"\"Recursively print the nodes of the autograd computation graph.\"\"\"\n    if grad_fn is None:\n        return\n    print(' ' * indent, f\"Node: {str(grad_fn).split('(')[0]}\")\n    print(' ' * indent, f\"  Metadata: {grad_fn.metadata}\")\n    for child in getattr(grad_fn, 'next_functions', []):\n        if child[0] is not None:\n            print_grad_fn(child[0], indent + 2)\n\ndef forward_hook(module, inputs, output):\n    if isinstance(output, (tuple, list)):\n        for i, o in enumerate(output):\n            if o is None:\n                print(f\"{module.__class__.__name__} output index {i} is None\")\n            else:\n                print(f\"{module.__class__.__name__} output index {i}: requires_grad={o.requires_grad}, grad_fn={o.grad_fn}\")\n    elif output is None:\n        print(f\"{module.__class__.__name__} returned None\")\n    else:\n        print(f\"{module.__class__.__name__}: requires_grad={output.requires_grad}, grad_fn={output.grad_fn}\")\n\ndef check_moe_gradients(model):\n    moe_layer = model.base_model.model.model.orig_module.layers[1].mlp.orig_module\n    for name, param in moe_layer.named_parameters():\n        if param.requires_grad and param.grad is not 
None:\n            grad_norm = torch.norm(param.grad)\n            print(f\"MoE parameter {name} gradient norm: {grad_norm}\")\n        else:\n            print(f\"MoE parameter {name} has no gradient\")\n\ndef disable_all_dropout(module):\n    for name, child in module.named_children():\n        if isinstance(child, nn.Dropout):\n            child.p = 0\n            child.inplace = False\n        disable_all_dropout(child)\n\ndef verify_lora_layers(model):\n    for layer_path in target_layers:\n        module = model.get_submodule(layer_path)\n        orig_module = module.orig_module\n        \n        W = orig_module.weight.data  # [576, 2048] -> [2048, 576]\n        lora_A = module.lora_A['default'].weight.data  # [8, 2048]\n        lora_B = module.lora_B['default'].weight.data  # [576, 8]\n        alpha_over_r = 32/8  # alpha=32, r=8\n        \n        input_tensor = layer_data[layer_path]['input']  # [1, 512, 2048]\n        \n        try:\n            original_output = torch.matmul(input_tensor, W)  # [1,512,2048] @ [2048,576] => [1,512,576]\n        except RuntimeError:\n            # shape mismatch: the weight is stored as [out, in], so transpose it\n            original_output = torch.matmul(input_tensor, W.T)  # [1,512,2048] @ [2048,576] => [1,512,576]\n        \n        lora_effect = torch.matmul(\n            torch.matmul(input_tensor, lora_A.T),  # [1,512,2048] @ [2048,8] => [1,512,8]\n            lora_B.T  # [1,512,8] @ [8,576] => [1,512,576]\n        ) * alpha_over_r\n        \n        manual_output = original_output + lora_effect  # [1,512,576]\n        \n        model_output = layer_data[layer_path]['output']\n\n        print(f\"manual_output:{manual_output}\")\n        print(f\"model_output:{model_output}\")\n        \n        if torch.allclose(manual_output, model_output, atol=1e-5):\n            print(f\"{layer_path} verification passed\")\n        else:\n            print(f\"{layer_path} verification failed! Max error: {torch.max(torch.abs(manual_output - model_output))}\")\n\ndef print_moe_stats(moe_layer: \"KExpertsTorch\"):\n    print(f\"Total Params: {moe_layer.total_params/1e6:.2f}M\")\n    
\n    total_time = sum(moe_layer.times)\n    gflops = (moe_layer.total_flops / 1e9) / total_time if total_time !=0 else 0\n    \n    print(f\"Total Calls: {moe_layer.call_count}\")\n    # print(f\"Avg GFLOPS per Call: {gflops/moe_layer.call_count:.2f}\")\n    print(f\"Overall GFLOPS: {gflops:.2f}\")\n    \n    if moe_layer.call_count > 0:\n        last_flops = moe_layer.flops_per_call[-1]\n        last_time = moe_layer.times[-1]\n        print(f\"\\nLast Call - FLOPs: {last_flops/1e9:.2f}G  Time: {last_time*1000:.2f}ms  \"\n              f\"GFLOPS: {(last_flops/1e9)/last_time:.2f}\")\n        \ndef recursive_traverse(model, parent_name=''):\n    \"\"\"\n    Recursively traverse the model, find MoE layers, and call print_moe_stats on each.\n    \"\"\"\n    # imported locally, mirroring install_shape_probes, so the name is defined here\n    from ktransformers.operators.experts import KTransformersExperts\n    for name, module in model.named_children():\n        full_name = f\"{parent_name}.{name}\" if parent_name else name\n        \n        if isinstance(module, KTransformersExperts):\n            print(f\"Found MoE layer: {full_name}\")\n            print_moe_stats(module.generate_experts)\n        \n        recursive_traverse(module, full_name)\n\ndef log_step_state(\n    step: int,\n    inputs: dict,\n    loss: torch.Tensor,\n    model: nn.Module,\n    log_dir: str = \"train_logs\",\n):\n    \"\"\"\n    Save the inputs, loss, gradients, and parameters of the current step to log_dir/step_{step:08d}.pt\n    \"\"\"\n    Path(log_dir).mkdir(parents=True, exist_ok=True)\n\n    logged_inputs = {\n        k: v.detach().cpu()\n        for k, v in inputs.items()\n        if isinstance(v, torch.Tensor)\n    }\n\n    loss_val = loss.detach().cpu()\n\n    params, grads = {}, {}\n    for name, p in model.named_parameters():\n        params[name] = p.detach().cpu()\n        grads[name] = p.grad.detach().cpu() if p.grad is not None else None\n\n    torch.save(\n        {\n            \"step\": step,\n            \"inputs\": logged_inputs,\n            \"loss\": loss_val,\n            \"params\": params,\n            \"grads\": grads,\n        },\n        f\"{log_dir}/step_{step:08d}.pt\",\n    )\n\ndef 
collect_gradients(model, input_ids):\n    torch.manual_seed(42)\n    \n    output = model(input_ids=input_ids)\n    \n    logits = output.logits\n    loss = logits.mean()\n    \n    model.zero_grad()\n    loss.backward()\n    \n    grads = []\n    for name, param in model.named_parameters():\n        if param.requires_grad and param.grad is not None:\n            grads.append(f\"{name}: {param.grad.norm().item():.6f}\")\n    \n    return grads\n\ndef report_meta_tensors(model):\n    import torch, inspect\n    meta_modules = []\n    for mod_name, mod in model.named_modules():\n        metas = []\n        for n, p in list(mod.named_parameters(recurse=False)):\n            if getattr(p, \"is_meta\", False) and p.is_meta:\n                metas.append((\"param\", n, tuple(p.shape)))\n        for n, b in list(mod.named_buffers(recurse=False)):\n            if getattr(b, \"is_meta\", False) and b.is_meta:\n                metas.append((\"buffer\", n, tuple(b.shape)))\n        if metas:\n            print(f\"[META] {mod_name} ({type(mod).__name__}): {metas}\")\n            meta_modules.append((mod_name, type(mod).__name__, metas))\n    return meta_modules\n\n# def lora_and_load_adapter(model, tokenizer, sft_data_path, save_adapter_path, is_profiler=False):\n    # show some lora test\n    \n    '''\n    # multi-gpu dataloader test\n    # _ = report_meta_tensors(model)\n    \n    # print(\"=== SAMPLE INSPECT ===\")\n    # for i in range(2):\n    #     summary = {}\n    #     for k,v in ex.items():\n    #         if isinstance(v, list):\n    #             if len(v)>0 and isinstance(v[0], list):\n    #                 summary[k] = f\"list-of-lists len={len(v)} x len0={len(v[0])}\"\n    #             else:\n    #                 summary[k] = f\"list len={len(v)}\"\n    #         elif torch.is_tensor(v):\n    #             summary[k] = f\"tensor shape={tuple(v.shape)}\"\n    #         else:\n    #             summary[k] = str(type(v))\n    #     print(f\"[SAMPLE {i}]\", 
summary)\n    \n    # trainer.accelerator = Accelerator(device_placement=False)\n    # first_batch = next(iter(trainer.get_train_dataloader()))\n    # print(\"Batch keys:\", list(first_batch.keys()))\n    \n    # acc = KAccelerator(device_placement=False)\n    # acc.state.device_ids = [0]\n    # acc.state.num_processes = 1\n    # acc.state.num_gpus = 1\n    # trainer.accelerator = acc\n\n    # print(\"Accelerator device_ids:\", trainer.accelerator.state.device_ids)\n    # print(f\"type(trainer.model):{type(trainer.model)}\")\n    # print(f\"type(trainer.accelerator):{type(trainer.accelerator)}\")\n    \n    \n    # print(\"-------------------------START TRAINING!!!-------------------------\")\n\n    # cfg = getattr(trainer.accelerator, \"dataloader_config\", None)\n    # print(\n    #     \"[ACCEL DL CONFIG]\",\n    #     \"split_batches=\", getattr(cfg, \"split_batches\", None),\n    #     \"dispatch_batches=\", getattr(cfg, \"dispatch_batches\", None),\n    #     \"even_batches=\", getattr(cfg, \"even_batches\", None),\n    #     \"use_seedable_sampler=\", getattr(cfg, \"use_seedable_sampler\", None),\n    #     \"non_blocking=\", getattr(cfg, \"non_blocking\", None),\n    # )\n    # print(\"--------------------NEW DEBUG--------------------\")\n    # install_shape_probes(trainer.model) # print some debug info about multi-gpu placement.\n\n    # input_ids = torch.randint(0, 1000, (32, 128), device=\"cuda:0\")\n    # gradients = collect_gradients(model, input_ids)\n    '''\n    \n    # with open(f\"/home/lpl/kt-sft/tmp/KSFTExpertsCPU_grads.txt\", \"w\") as f:\n    #     f.write(\"\\n\".join(gradients))\n    # print(xx)\n    \n    # total_length = 0\n    # valid_count = 0\n    # for batch in tqdm(train_dataloader):\n    #     input_ids = batch['input_ids']\n    #     # print(f\"Token count per sample: {[len(ids) for ids in input_ids]}\")\n    #     for ids in input_ids:\n    #         if not torch.equal(ids, torch.tensor([100001])):\n    #             total_length 
+= len(ids)\n    #     valid_count += 1\n    #     # print(f\"Input tensor: {input_ids}\")\n    #     # print(f\"total_length:{total_length}\")\n    #     # break\n\n    # if valid_count > 0:\n    #     average_length = total_length / valid_count\n    # else:\n\n    # print(xx)\n    \n    # from ktransformers.sft.flops_utils.custom_profile import custom_profile\n\n    # for module in model.modules():\n    #     if not hasattr(module, 'total_ops'):\n    #         module.register_buffer('total_ops', torch.zeros(1, dtype=torch.float64))\n    #     if not hasattr(module, 'total_params'):\n    #         module.register_buffer('total_params', torch.zeros(1, dtype=torch.float64))\n            \n    # # print(f\"input:{input}\")\n    # for inputs in tqdm(train_dataloader):\n    #     # input_ids = batch['input_ids']\n    #     # del inputs['instruction']\n    #     # del inputs['input']\n    #     # del inputs['output']\n    #     # output = model(**inputs)\n    #     model.eval()\n    #     content = inputs['instruction'][0] + inputs['input'][0]\n    #     # flops,params = custom_profile(model, inputs=inputs, content=content, tokenizer=tokenizer, custom_ops={YourModule: count_your_model})\n    #     # print('FLOPs = ' + str(flops / 1000 ** 3) + 'G')\n    #     # print('Params = ' + str(params / 1000 ** 2) + 'M')\n\n    #     messages = [{\"role\": \"user\", \"content\": content}]\n    #     input_tensor = tokenizer.apply_chat_template(\n    #         messages, add_generation_prompt=True, return_tensors=\"pt\"\n    #     )\n    #     with torch.no_grad():\n    #         # model(*inputs)\n    #         # model.model to deal with the PeftModelForCaualLM temp\n    #         prefill_and_generate(\n    #             model.model, tokenizer, input_tensor.cuda(), max_new_tokens=1000, use_cuda_graph=False, mode = 'normal', force_think = False, chunk_prefill_size = 8192,\n    #         )\n    #     recursive_traverse(model)\n    \n    # output = 
model(input_ids=torch.tensor([[1,2,3]], dtype=torch.int32, device=\"cuda:0\"))\n    # loss = output.logits.mean()\n        \n    # dot = make_dot(loss, params=dict(model.named_parameters()))\n    # dot.render(\"KT_compute_cpuinfer_moe_model_graph\", format=\"svg\")\n\n    # with open(\"tmp/output_loss_KCPU.txt\", \"w\") as file:\n    #     file.write(\"Output (logits):\\n\")\n    #     file.write(\"\\n\\nLoss:\\n\")\n    \n    # disable_all_dropout(model)\n\n    # def print_dropout_status(module, prefix=\"\"):\n    #     for name, child in module.named_children():\n    #         if isinstance(child, nn.Dropout):\n    #             print(f\"{prefix}{name}: p={child.p}, training={child.training}\")\n    #         print_dropout_status(child, prefix + name + \".\")\n    \n    # print_dropout_status(model)\n\n    # for layer_path in target_layers:\n    #     module = model.get_submodule(layer_path)\n    #     hook = module.register_forward_hook(\n    #         lambda m, i, o, ln=layer_path: record_layer_io(m, i, o, ln)\n    #     )\n    #     hooks.append(hook)\n\n    \n    # if is_profiler:\n    #     profiler = profile(\n    #         activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],\n    #         schedule=torch.profiler.schedule(\n    #         ),\n    #         on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs'),\n    #         record_shapes=False,\n    #         with_stack=False\n    #     )\n\n    #     # profiler_args = {\n    #     #     \"schedule\": torch.profiler.schedule(\n    #     #     )\n    #     # }\n\n    #     trainer = KTrainer(\n    #         model=model,\n    #         train_dataset=train_dataset,\n    #         data_collator=DataCollatorForSeq2Seq(\n    #             tokenizer, pad_to_multiple_of=8, return_tensors=\"pt\", padding=True\n    #         ),\n    #         callbacks=[ProfilerCallback(profiler)]\n    #     )\n\n    #     with profiler:\n    #         trainer.train()\n\n    #     print(\"Training finished. 
Exporting profiler data...\")\n    #     with open(\"profiler_output.txt\", \"w\") as f:\n    #         f.write(profiler.key_averages().table(sort_by=\"cuda_time_total\", row_limit=20))\n    \n    #   profiler.export_chrome_trace(\"trace.json\")\n    \n    \n\n    # verify_lora_layers(model)\n\n    # model.save_pretrained(save_adapter_path)\n\n    '''\n    ----------------------- START: Lora Test -----------------------\n    \n\n    # for name, module in model.named_modules():\n    #     if \"q_proj\" in name or \"kv_a_proj\" in name or \"o_proj\" in name:\n    #         print(name)\n\n    # print_model_params(model)\n\n    # model = KTransformersLinearLora()\n\n    # inspect_device(model, '/home/yj/ktransformers/device1.txt')\n    # with open('/home/yj/ktransformers/device1.txt', 'a') as file:\n    #     file.write(f\"Base model device: {model.base_model.device}\\n\")\n        # file.write(f\"LoRA adapter device: {model.lora_config['target_modules'].device}\\n\")\n    # print(f\"Base model device: {model.base_model.device}\") \n    # print(f\"LoRA adapter device: {model.lora_config['target_modules'].device}\") \n\n\n    # model = model.to('cuda')\n\n    # for name, module in model.named_modules():\n    #     module.register_forward_hook(forward_hook)\n\n    # for name, parms in model.named_parameters():\t\n    #     # parms.requires_grad = True\n    #     print('-->name:', name)\n    #     print('-->para:', parms)\n    #     print('-->grad_requirs:',parms.requires_grad)\n    #     print('-->grad_fn:',parms.grad_fn)\n    #     print('-->grad_value:',parms.grad)\n    #     print(\"===\")\n\n    # output = model(input_ids=torch.tensor([[1,2,3]], dtype=torch.int32, device=\"cuda:0\"))\n    # loss = output.logits.mean()\n\n    # dot = make_dot(loss, params=dict(model.named_parameters()))\n    # dot.render(\"KT_compute_graph\", format=\"svg\")\n\n    # inspect_device(model, '/home/yj/ktransformers/device2.txt')\n    # with open('/home/yj/ktransformers/device2.txt', 'a') 
as file:\n    #     file.write(f\"Base model device: {model.base_model.device}\\n\")\n        # file.write(f\"LoRA adapter device: {model.lora_config['target_modules'].device}\\n\")\n    # print(f\"Base model device: {model.base_model.device}\") \n    # print(f\"LoRA adapter device: {model.lora_config['target_modules'].device}\") \n\n    # print_lora_params(model)\n\n    # trainer = KTrainer(\n    #     model=model,\n    #     train_dataset=train_dataset,\n    #     args=transformers.TrainingArguments(\n    #         output_dir=save_adapter_path,\n    #         per_device_train_batch_size=1,\n    #         gradient_accumulation_steps=16,\n    #         num_train_epochs=10,\n    #         learning_rate=3e-4,\n    #         fp16=False,\n    #         logging_steps=10,\n    #         save_steps=200,\n    #         dataloader_drop_last=True,\n    #         ddp_find_unused_parameters=False \n    #     ),\n    #     data_collator=DataCollatorForSeq2Seq(\n    #         tokenizer, pad_to_multiple_of=8, return_tensors=\"pt\", padding=True\n    #     ),\n    # )\n\n    # model(input_ids=torch.tensor([[1,2,3]], dtype=torch.int32, device=\"cuda:0\"))\n\n    # trainer.train()\n\n    # print_lora_params(model)\n\n    # model = model.merge_and_unload()\n    ----------------------- END: Lora Test -----------------------\n\n    '''"
  },
  {
    "path": "kt-sft/ktransformers/sft/lora.py",
    "content": "from transformers import AutoTokenizer, DataCollatorForLanguageModeling\nfrom transformers import Trainer, TrainingArguments\nfrom transformers import Trainer\nfrom transformers.training_args import OptimizerNames\nfrom transformers.trainer_utils import seed_worker\nfrom transformers.utils import (\n    is_datasets_available,\n    is_sagemaker_mp_enabled,\n    is_torch_xpu_available,\n    is_torch_mlu_available,\n    is_torch_musa_available,\n    is_torch_npu_available,\n    is_torch_mps_available,\n    is_torch_hpu_available,\n    is_accelerate_available,\n    is_apex_available,\n    logging,\n)\nfrom packaging import version\nimport os\nimport inspect\nimport functools\nfrom typing import Union, Any, Dict, List\n\nimport torch\nfrom torch.utils.data import DataLoader\nimport torch.nn as nn\nfrom torch.utils.data import DataLoader, IterableDataset\nfrom torch.utils.data import Dataset as TorchDataset\n\nfrom peft import LoraConfig, TaskType\nfrom datasets import Dataset\nfrom torchviz import make_dot\nfrom tqdm import tqdm\nimport os, json\nfrom pathlib import Path\nfrom accelerate import Accelerator\nif is_accelerate_available(\"0.28.0\"):\n    from accelerate.utils import DataLoaderConfiguration\nfrom accelerate import __version__ as accelerate_version\nif version.parse(accelerate_version) > version.parse(\"1.3.0\"):\n        from accelerate.utils import TorchTensorParallelPlugin\nif is_sagemaker_mp_enabled():\n    from transformers.trainer_utils import smp_forward_backward\n\nfrom ktransformers.sft.peft_utils.mapping import get_peft_model\n\nlogger = logging.get_logger(__name__)\n\nclass KAccelerator(Accelerator):\n    def __init__(self, *args, **kwargs):\n        kwargs.setdefault(\"device_placement\", False)\n        super().__init__(*args, **kwargs)\n        \n    def prepare_model(self, model, *args, **kwargs):\n        return model\n    \n    def prepare(self, *args, **kwargs):\n        prepped = []\n        for obj in args:\n            if 
isinstance(obj, nn.Module):\n                prepped.append(self.prepare_model(obj, **kwargs))\n            else:\n                prepped.append(super().prepare(obj, **kwargs))\n        return tuple(prepped) if len(prepped) > 1 else prepped[0]\n\nclass KTrainer(Trainer):\n    def save_model(self, output_dir=None, _internal_call=False):\n        output_dir = output_dir or self.args.output_dir\n        os.makedirs(output_dir, exist_ok=True)\n        # only save LoRA adapter, including adapter_config.json\n        self.model.save_pretrained(output_dir)\n        \n    def _move_model_to_device(self, model, device):\n        print(\"[KTrainer] Due to the placement feature in KTransformers, skip moving model to\", device)\n        return model\n    \n    def _wrap_model(self, model, training=True, dataloader=None):\n        self.model_wrapped = model\n        return model\n    \n    def create_accelerator_and_postprocess(self):\n        # We explicitly don't rely on the `Accelerator` to do gradient accumulation\n        grad_acc_kwargs = {}\n        if is_accelerate_available(\"0.28.0\") and self.args.accelerator_config.gradient_accumulation_kwargs is not None:\n            grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs\n\n        # check if num_steps is attempted to be passed in gradient_accumulation_kwargs\n        if \"num_steps\" in grad_acc_kwargs:\n            if self.args.gradient_accumulation_steps > 1:\n                # raise because we do not know which setting is intended.\n                raise ValueError(\n                    \"The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`\"\n                    \"If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`.\"\n                )\n            else:\n                self.args.gradient_accumulation_steps = grad_acc_kwargs[\"num_steps\"]\n\n   
     accelerator_config = self.args.accelerator_config.to_dict()\n\n        if is_accelerate_available(\"0.28.0\"):\n            # Extract dataloader config params from accelerator config\n            dataloader_params = [\"split_batches\", \"dispatch_batches\", \"even_batches\", \"use_seedable_sampler\"]\n            dataloader_config_dict = {param: accelerator_config.pop(param) for param in dataloader_params if param in accelerator_config}\n            if DataLoaderConfiguration is None:\n                raise ImportError(\"Your accelerate does not provide DataLoaderConfiguration but Trainer expects it.\")\n            dataloader_config = DataLoaderConfiguration(**dataloader_config_dict)\n            if is_accelerate_available(\"1.1.0\"):\n                dataloader_config.data_seed = self.args.data_seed\n        else:\n            dataloader_config = None\n\n        non_blocking = accelerator_config.pop(\"non_blocking\", False)\n        if not is_accelerate_available(\"0.30.0\"):\n            if non_blocking:\n                raise ImportError(\n                    \"`non_blocking` is only supported in accelerate v0.30.0 and above. Please upgrade accelerate to use this feature.\"\n                )\n        else:\n            if non_blocking and not self.args.dataloader_pin_memory:\n                logger.warning(\"`non_blocking` is enabled but `dataloader_pin_memory` is not. 
For best performance, enable both.\")\n            if dataloader_config is not None:\n                dataloader_config.non_blocking = non_blocking\n\n        accelerator_config.pop(\"gradient_accumulation_kwargs\", None)\n\n        args = {\n            \"deepspeed_plugin\": self.args.deepspeed_plugin,\n            \"device_placement\": False,\n        }\n\n        if is_accelerate_available(\"0.28.0\"):\n            args[\"dataloader_config\"] = dataloader_config\n        else:\n            args.update(accelerator_config)\n\n        if getattr(self.args, \"tp_size\", 1) > 1:\n            self.is_tp_enabled = True\n            if version.parse(accelerate_version) > version.parse(\"1.3.0\") and TorchTensorParallelPlugin is not None:\n                args[\"torch_tp_plugin\"] = TorchTensorParallelPlugin(tp_size=self.args.tp_size)\n            else:\n                raise ValueError(\"Requires accelerate>1.3.0 to use Tensor Parallelism.\")\n\n        self.accelerator = KAccelerator(**args)\n\n        try:\n            self.accelerator.state.device_ids = [0]\n            self.accelerator.state.num_processes = 1\n            self.accelerator.state.num_gpus = 1\n        except Exception:\n            pass\n\n        # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag\n        self.gather_function = self.accelerator.gather_for_metrics\n\n        if \"use_gather_object\" in inspect.signature(self.gather_function).parameters.keys():\n            self.gather_function = functools.partial(\n                self.gather_function, use_gather_object=self.args.eval_use_gather_object\n            )\n\n        # deepspeed and accelerate flags covering both trainer args and accelerate launcher\n        self.is_deepspeed_enabled = getattr(self.accelerator.state, \"deepspeed_plugin\", None) is not None\n        self.is_fsdp_enabled = getattr(self.accelerator.state, \"fsdp_plugin\", None) is not None\n        self.is_tp_enabled = 
getattr(self.accelerator.state, \"torch_tp_plugin\", None) is not None\n        # post accelerator creation setup\n        if self.is_fsdp_enabled:\n            fsdp_plugin = self.accelerator.state.fsdp_plugin\n            for param in [\"limit_all_gathers\", \"activation_checkpointing\"]:\n                setattr(fsdp_plugin, param, self.args.fsdp_config.get(param, getattr(fsdp_plugin, param)))\n            if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing:\n                raise ValueError(\n                    \"The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg \"\n                    \"can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic \"\n                    \"when using FSDP.\"\n                )\n\n        if self.is_deepspeed_enabled and getattr(self.args, \"hf_deepspeed_config\", None) is None:\n            self.propagate_args_to_deepspeed()\n\n        # `save_only_model` can't be used with DeepSpeed/FSDP along with `load_best_model_at_end`\n        if (\n            self.args.save_only_model\n            and (self.is_deepspeed_enabled or self.is_fsdp_enabled)\n            and self.args.load_best_model_at_end\n        ):\n            wrapper = \"DeepSpeed\" if self.is_deepspeed_enabled else \"FSDP\"\n            raise ValueError(f\"{wrapper} can't be used with `save_only_model` along with `load_best_model_at_end`.\")\n\n        # `auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3\n        if (\n            self.is_deepspeed_enabled\n            and self.accelerator.state.deepspeed_plugin.zero_stage == 3\n            and self.args.auto_find_batch_size\n        ):\n            raise ValueError(\n                \"`auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3. 
Please consider using Zero-2, Zero-1, or FSDP\"\n            )\n        if (\n            self.args.save_only_model\n            and self.is_fsdp_enabled\n            and \"SHARDED_STATE_DICT\" in str(self.accelerator.state.fsdp_plugin.state_dict_type)\n        ):\n            raise ValueError(\"save_only_model option is not compatible with FSDP state dict type 'SHARDED_STATE_DICT'\")\n        \n        if dataloader_config is not None:\n            dataloader_config.split_batches = False\n            dataloader_config.dispatch_batches = False\n            dataloader_config.even_batches = False\n            \n    def get_train_dataloader(self) -> DataLoader:\n        \"\"\"\n        Returns the training DataLoader with per_device_train_batch_size\n        (no implicit multipliers by number of visible GPUs).\n        \"\"\"\n        if self.train_dataset is None:\n            raise ValueError(\"Trainer: training requires a train_dataset.\")\n\n        train_dataset = self.train_dataset\n        data_collator = self.data_collator\n\n        if is_datasets_available():\n            try:\n                import datasets\n                if isinstance(train_dataset, datasets.Dataset):\n                    train_dataset = self._remove_unused_columns(train_dataset, description=\"training\")\n                else:\n                    data_collator = self._get_collator_with_removed_columns(data_collator, description=\"training\")\n            except Exception:\n                data_collator = self._get_collator_with_removed_columns(data_collator, description=\"training\")\n        else:\n            data_collator = self._get_collator_with_removed_columns(data_collator, description=\"training\")\n\n        dataloader_params = {\n            \"batch_size\": self.args.per_device_train_batch_size,\n            \"collate_fn\": data_collator,\n            \"num_workers\": self.args.dataloader_num_workers,\n            \"pin_memory\": self.args.dataloader_pin_memory,\n            
\"persistent_workers\": self.args.dataloader_persistent_workers,\n        }\n\n        if not isinstance(train_dataset, IterableDataset):\n            dataloader_params[\"sampler\"] = self._get_train_sampler()\n            dataloader_params[\"drop_last\"] = self.args.dataloader_drop_last\n            dataloader_params[\"worker_init_fn\"] = seed_worker\n            if self.args.dataloader_num_workers > 0 and self.args.dataloader_prefetch_factor is not None:\n                dataloader_params[\"prefetch_factor\"] = self.args.dataloader_prefetch_factor\n\n        dl = DataLoader(train_dataset, **dataloader_params)\n\n        try:\n            prepared = self.accelerator.prepare(dl, device_placement=[False])\n        except TypeError:\n            prepared = self.accelerator.prepare(dl)\n\n        return prepared\n    \n    def training_step(\n        self,\n        model: torch.nn.Module,\n        inputs: dict[str, Union[torch.Tensor, Any]],\n        num_items_in_batch=None\n    ) -> torch.Tensor:\n        model.train()\n        if hasattr(self.optimizer, \"train\") and callable(self.optimizer.train):\n            self.optimizer.train()\n\n        inputs = self._prepare_inputs(inputs)\n\n        if is_sagemaker_mp_enabled():\n            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)\n            return loss_mb.reduce_mean().detach().to(self.args.device)\n\n        with self.compute_loss_context_manager():\n            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)\n\n        del inputs\n\n        if (\n            self.args.torch_empty_cache_steps is not None\n            and self.state.global_step % self.args.torch_empty_cache_steps == 0\n        ):\n            if is_torch_xpu_available():\n                torch.xpu.empty_cache()\n            elif is_torch_mlu_available():\n                torch.mlu.empty_cache()\n            elif is_torch_musa_available():\n                
torch.musa.empty_cache()\n            elif is_torch_npu_available():\n                torch.npu.empty_cache()\n            elif is_torch_mps_available(min_version=\"2.0\"):\n                torch.mps.empty_cache()\n            elif is_torch_hpu_available():\n                logger.warning(\n                    \"`torch_empty_cache_steps` is set but HPU device/backend does not support empty_cache().\"\n                )\n            else:\n                torch.cuda.empty_cache()\n\n        kwargs = {}\n\n        if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:\n            kwargs[\"learning_rate\"] = self._get_learning_rate()\n\n        if self.args.n_gpu > 1:\n            loss = loss.mean()\n\n        if self.use_apex:\n            # `amp` is not imported at module level, so import it lazily on the apex path.\n            from apex import amp\n\n            with amp.scale_loss(loss, self.optimizer) as scaled_loss:  # type: ignore\n                scaled_loss.backward()\n        else:\n            if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:\n                loss = loss / self.args.gradient_accumulation_steps\n\n            if getattr(self.accelerator, \"distributed_type\", None) and \\\n               str(self.accelerator.distributed_type) == \"DistributedType.DEEPSPEED\":\n                kwargs[\"scale_wrt_gas\"] = False\n\n            self.accelerator.backward(loss, **kwargs)\n\n        ret = loss.detach()\n        if ret.device != self.args.device:\n            ret = ret.to(self.args.device, non_blocking=True)\n\n        if os.environ.get(\"KT_DBG_STEP\", \"0\") == \"1\" and not hasattr(self, \"_kt_dbg_once\"):\n            try:\n                print(f\"[KT-DBG] args.device={self.args.device}  loss(before)={loss.device}  loss(return)={ret.device}\")\n            except Exception:\n                pass\n            self._kt_dbg_once = True\n\n        return ret\n\nclass SFTJsonListDataset(TorchDataset):\n    def __init__(self, path: str, tokenizer: AutoTokenizer, max_len: int = 512):\n        super().__init__()\n        with open(path, 
\"r\", encoding=\"utf-8\") as f:\n            self.samples: List[Dict] = json.load(f)\n        self.tok = tokenizer\n        self.max_len = max_len\n\n    @staticmethod\n    def build_example(ins: str, inp: str, out: str) -> Dict[str, str]:\n        ins = (ins or \"\").strip()\n        inp = (inp or \"\").strip()\n        out = (out or \"\").strip()\n        prompt = (ins + inp) if ins else inp\n        return {\"prompt\": prompt, \"response\": out}\n\n    def __len__(self):\n        return len(self.samples)\n\n    def __getitem__(self, idx: int):\n        rec = self.samples[idx]\n        eg = self.build_example(rec.get(\"instruction\", \"\"), rec.get(\"input\", \"\"), rec.get(\"output\", \"\"))\n\n        prompt_ids = self.tok(\n            eg[\"prompt\"],\n            max_length=self.max_len,\n            truncation=True,\n            add_special_tokens=False,\n        )[\"input_ids\"]\n\n        response_ids = self.tok(\n            eg[\"response\"],\n            max_length=self.max_len,\n            truncation=True,\n            add_special_tokens=False,\n        )[\"input_ids\"]\n\n        eos_id = self.tok.eos_token_id\n        input_ids = prompt_ids + response_ids + ([eos_id] if eos_id is not None else [])\n        input_ids = input_ids[: self.max_len]\n\n        labels = [-100] * min(len(prompt_ids), self.max_len)\n        tail = input_ids[len(labels):]\n        labels = labels + tail\n        labels = labels[: self.max_len]\n\n        attention_mask = [1] * len(input_ids)\n\n        return {\n            \"input_ids\": torch.tensor(input_ids, dtype=torch.long),\n            \"labels\": torch.tensor(labels, dtype=torch.long),\n            \"attention_mask\": torch.tensor(attention_mask, dtype=torch.long),\n        }\n\ndef lora_and_load_adapter(model, tokenizer, sft_data_path, save_adapter_path):\n    \n    Path(save_adapter_path).mkdir(parents=True, exist_ok=True)\n    \n    lora_config = LoraConfig(\n        task_type=TaskType.CAUSAL_LM,\n        
target_modules=[\n            \"q_proj\", # FOR DeepSeek-V2-Lite\n            \"q_a_proj\", # FOR DeepSeek-V3&R1\n            \"q_b_proj\",\n            \"kv_a_proj_with_mqa\",\n            \"kv_b_proj\",\n            \"o_proj\",\n            \"mlp.gate_proj\",\n            \"mlp.up_proj\",\n            \"mlp.down_proj\",\n            \"shared_experts.gate_proj\",\n            \"shared_experts.up_proj\",\n            \"shared_experts.down_proj\",\n        ],\n        r=8,\n        lora_alpha=32,\n        lora_dropout=0.1,\n    )\n    model = get_peft_model(model, lora_config)\n    model.print_trainable_parameters()\n    \n    train_dataset = SFTJsonListDataset(sft_data_path, tokenizer, max_len=512)\n    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)\n\n    training_args = TrainingArguments(\n        output_dir=save_adapter_path,\n        per_device_train_batch_size=1,\n        gradient_accumulation_steps=16,\n        num_train_epochs=1,\n        # max_steps=30, # TODO: FOR TEST, will override any value given in num_train_epochs\n        learning_rate=1e-4,\n        fp16=False,\n        logging_steps=10,\n        save_steps=200,\n        dataloader_drop_last=True,\n        ddp_find_unused_parameters=False,\n    )\n    \n    debug_path = os.path.join(save_adapter_path, \"model_infra_debug.json\")\n    with open(debug_path, \"w\", encoding=\"utf-8\") as f:\n        json.dump({\"model\": str(model)}, f, ensure_ascii=False, indent=2)\n    \n    # output = model(input_ids=torch.tensor([[1,2,3]], dtype=torch.int32, device=\"cuda:0\"))\n    # loss = output.logits.mean()\n        \n    # dot = make_dot(loss, params=dict(model.named_parameters()))\n    # dot.render(\"KT_compute_cpuinfer_moe_model_graph\", format=\"svg\")\n    \n    trainer = KTrainer(\n        model=model,\n        tokenizer=tokenizer,\n        args=training_args,\n        train_dataset=train_dataset,\n        data_collator=data_collator,\n    )\n    model.config.use_cache = False\n    
# model.gradient_checkpointing_enable()\n    # if hasattr(model, \"enable_input_require_grads\"):\n    #     model.enable_input_require_grads()\n    \n    trainer.train()\n\ndef inject_lora_layer(model, use_adapter_path):\n\n    cfg_path = os.path.join(use_adapter_path, \"adapter_config.json\")\n    with open(cfg_path, \"r\", encoding=\"utf-8\") as f:\n        data = json.load(f)\n    \n    task_type_str = (data.get(\"task_type\") or \"CAUSAL_LM\").upper()\n    bias = data.get(\"bias\", \"none\")\n    if bias in (None, False):\n        bias = \"none\"\n    if data.get(\"lora_bias\") is True and bias == \"none\":\n        bias = \"lora_only\"\n\n    tmods = data.get(\"target_modules\")\n    if isinstance(tmods, str):\n        tmods = [m.strip() for m in tmods.split(\",\") if m.strip()]\n\n    mts = data.get(\"modules_to_save\", None)\n    if isinstance(mts, str):\n        mts = [m.strip() for m in mts.split(\",\") if m.strip()]\n\n    rank_pattern = data.get(\"rank_pattern\") or None\n    alpha_pattern = data.get(\"alpha_pattern\") or None\n\n    lora_config = LoraConfig(\n        r=data.get(\"r\", 8),\n        lora_alpha=data.get(\"lora_alpha\", 32),\n        lora_dropout=float(data.get(\"lora_dropout\", 0.0)),\n        bias=bias,\n        task_type=TaskType[task_type_str],\n        target_modules=tmods,\n        modules_to_save=mts,\n        init_lora_weights=bool(data.get(\"init_lora_weights\", True)),\n        inference_mode=bool(data.get(\"inference_mode\", True)),\n        use_rslora=bool(data.get(\"use_rslora\", False)),\n        use_dora=bool(data.get(\"use_dora\", False)),\n    )\n    print(f\"lora_config:{lora_config.__dict__}\")\n    \n    # model = inject_adapter_in_model(lora_config, model)\n    model = get_peft_model(model, lora_config)\n    model.config.use_cache = False\n    model.eval()"
  },
  {
    "path": "kt-sft/ktransformers/sft/metrics.py",
    "content": "# Copyright 2025 HuggingFace Inc., THUDM, and the LlamaFactory team.\n#\n# This code is inspired by the HuggingFace's transformers library and the THUDM's ChatGLM implementation.\n# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py\n# https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom dataclasses import dataclass\nfrom typing import TYPE_CHECKING, Optional\n\nimport numpy as np\nimport torch\nfrom transformers.utils import is_jieba_available, is_nltk_available\n\nfrom ktransformers.sft.metrics_utils.constants import IGNORE_INDEX\nfrom ktransformers.sft.metrics_utils.misc import numpify\nfrom ktransformers.sft.metrics_utils.packages import is_rouge_available\n\n\nif TYPE_CHECKING:\n    from transformers import EvalPrediction, PreTrainedTokenizer\n\n\nif is_jieba_available():\n    import jieba  # type: ignore\n\n\nif is_nltk_available():\n    from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu  # type: ignore\n\n\nif is_rouge_available():\n    from rouge_chinese import Rouge  # type: ignore\n\n\ndef eval_logit_processor(logits: \"torch.Tensor\", labels: \"torch.Tensor\") -> \"torch.Tensor\":\n    r\"\"\"Compute the token with the largest likelihood to reduce memory footprint.\"\"\"\n    if isinstance(logits, (list, tuple)):\n        if logits[0].dim() == 3:  # (batch_size, seq_len, vocab_size)\n           
 logits = logits[0]\n        else:  # moe models have aux loss\n            logits = logits[1]\n\n    if logits.dim() != 3:\n        raise ValueError(\"Cannot process the logits.\")\n\n    return torch.argmax(logits, dim=-1)\n\n@dataclass\nclass ComputeSimilarity:\n    r\"\"\"Compute text similarity scores and support `batch_eval_metrics`.\n\n    Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer.\n    \"\"\"\n\n    tokenizer: \"PreTrainedTokenizer\"\n\n    def _dump(self) -> Optional[dict[str, float]]:\n        result = None\n        if hasattr(self, \"score_dict\"):\n            result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}\n\n        self.score_dict = {\"rouge-1\": [], \"rouge-2\": [], \"rouge-l\": [], \n                           \"bleu-1\": [], \"bleu-2\": [], \"bleu-3\": [], \"bleu-4\": []}\n        return result\n\n    def __post_init__(self):\n        self._dump()\n\n    def __call__(self, eval_preds: \"EvalPrediction\", compute_result: bool = True) -> Optional[dict[str, float]]:\n        preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)\n\n        preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)\n        labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)\n\n        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)\n        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)\n\n        for pred, label in zip(decoded_preds, decoded_labels):\n            hypothesis = list(jieba.cut(pred))\n            reference = list(jieba.cut(label))\n\n            if len(\" \".join(hypothesis).split()) == 0 or len(\" \".join(reference).split()) == 0:\n                result = {\"rouge-1\": {\"f\": 0.0}, \"rouge-2\": {\"f\": 0.0}, \"rouge-l\": {\"f\": 0.0}}\n            else:\n                rouge = Rouge()\n                scores = rouge.get_scores(\" \".join(hypothesis), \" 
\".join(reference))\n                result = scores[0]\n\n            # BLEU defaults to 0.0 so the scores are always defined, even when the\n            # hypothesis or reference is empty and the ROUGE fallback above is used.\n            bleu1 = bleu2 = bleu3 = bleu4 = 0.0\n            if len(\" \".join(hypothesis).split()) > 0 and len(\" \".join(reference).split()) > 0:\n                refs = [reference]\n                hyp = hypothesis\n                smooth = SmoothingFunction().method3\n                bleu1 = sentence_bleu(refs, hyp, weights=(1.0, 0.0, 0.0, 0.0), smoothing_function=smooth)\n                bleu2 = sentence_bleu(refs, hyp, weights=(0.5, 0.5, 0.0, 0.0), smoothing_function=smooth)\n                bleu3 = sentence_bleu(refs, hyp, weights=(1/3, 1/3, 1/3, 0.0), smoothing_function=smooth)\n                bleu4 = sentence_bleu(refs, hyp, weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=smooth)\n\n            for k, v in result.items():\n                self.score_dict[k].append(round(v[\"f\"] * 100, 4))\n\n            # bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)\n            # self.score_dict[\"bleu-4\"].append(round(bleu_score * 100, 4))\n            \n            self.score_dict[\"bleu-1\"].append(round(bleu1 * 100, 4))\n            self.score_dict[\"bleu-2\"].append(round(bleu2 * 100, 4))\n            self.score_dict[\"bleu-3\"].append(round(bleu3 * 100, 4))\n            self.score_dict[\"bleu-4\"].append(round(bleu4 * 100, 4))\n\n        if compute_result:\n            return self._dump()"
  },
  {
    "path": "kt-sft/ktransformers/sft/metrics_utils/__init__.py",
    "content": ""
  },
  {
    "path": "kt-sft/ktransformers/sft/metrics_utils/constants.py",
    "content": "# Copyright 2025 the LlamaFactory team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nfrom collections import OrderedDict, defaultdict\nfrom enum import Enum, unique\nfrom typing import Optional\n\nfrom peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME\nfrom peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME\nfrom transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME\n\n\nAUDIO_PLACEHOLDER = os.getenv(\"AUDIO_PLACEHOLDER\", \"<audio>\")\n\nCHECKPOINT_NAMES = {\n    SAFE_ADAPTER_WEIGHTS_NAME,\n    ADAPTER_WEIGHTS_NAME,\n    SAFE_WEIGHTS_INDEX_NAME,\n    SAFE_WEIGHTS_NAME,\n    WEIGHTS_INDEX_NAME,\n    WEIGHTS_NAME,\n}\n\nCHOICES = [\"A\", \"B\", \"C\", \"D\"]\n\nDATA_CONFIG = \"dataset_info.json\"\n\nDEFAULT_TEMPLATE = defaultdict(str)\n\nFILEEXT2TYPE = {\n    \"arrow\": \"arrow\",\n    \"csv\": \"csv\",\n    \"json\": \"json\",\n    \"jsonl\": \"json\",\n    \"parquet\": \"parquet\",\n    \"txt\": \"text\",\n}\n\nIGNORE_INDEX = -100\n\nIMAGE_PLACEHOLDER = os.getenv(\"IMAGE_PLACEHOLDER\", \"<image>\")\n\nLAYERNORM_NAMES = {\"norm\", \"ln\"}\n\nLLAMABOARD_CONFIG = \"llamaboard_config.yaml\"\n\nMETHODS = [\"full\", \"freeze\", \"lora\"]\n\nMOD_SUPPORTED_MODELS = {\"bloom\", \"falcon\", \"gemma\", \"llama\", \"mistral\", \"mixtral\", \"phi\", \"starcoder2\"}\n\nMULTIMODAL_SUPPORTED_MODELS = set()\n\nPEFT_METHODS = {\"lora\"}\n\nRUNNING_LOG = 
\"running_log.txt\"\n\nSUBJECTS = [\"Average\", \"STEM\", \"Social Sciences\", \"Humanities\", \"Other\"]\n\nSUPPORTED_MODELS = OrderedDict()\n\nTRAINER_LOG = \"trainer_log.jsonl\"\n\nTRAINING_ARGS = \"training_args.yaml\"\n\nTRAINING_STAGES = {\n    \"Supervised Fine-Tuning\": \"sft\",\n    \"Reward Modeling\": \"rm\",\n    \"PPO\": \"ppo\",\n    \"DPO\": \"dpo\",\n    \"KTO\": \"kto\",\n    \"Pre-Training\": \"pt\",\n}\n\nSTAGES_USE_PAIR_DATA = {\"rm\", \"dpo\"}\n\nSUPPORTED_CLASS_FOR_S2ATTN = {\"llama\"}\n\nSWANLAB_CONFIG = \"swanlab_public_config.json\"\n\nVIDEO_PLACEHOLDER = os.getenv(\"VIDEO_PLACEHOLDER\", \"<video>\")\n\nV_HEAD_WEIGHTS_NAME = \"value_head.bin\"\n\nV_HEAD_SAFE_WEIGHTS_NAME = \"value_head.safetensors\"\n\n\nclass AttentionFunction(str, Enum):\n    AUTO = \"auto\"\n    DISABLED = \"disabled\"\n    SDPA = \"sdpa\"\n    FA2 = \"fa2\"\n\n\nclass EngineName(str, Enum):\n    HF = \"huggingface\"\n    VLLM = \"vllm\"\n    SGLANG = \"sglang\"\n\n\nclass DownloadSource(str, Enum):\n    DEFAULT = \"hf\"\n    MODELSCOPE = \"ms\"\n    OPENMIND = \"om\"\n\n\n@unique\nclass QuantizationMethod(str, Enum):\n    r\"\"\"Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.\"\"\"\n\n    BNB = \"bnb\"\n    GPTQ = \"gptq\"\n    AWQ = \"awq\"\n    AQLM = \"aqlm\"\n    QUANTO = \"quanto\"\n    EETQ = \"eetq\"\n    HQQ = \"hqq\"\n    MXFP4 = \"mxfp4\"\n\n\nclass RopeScaling(str, Enum):\n    LINEAR = \"linear\"\n    DYNAMIC = \"dynamic\"\n    YARN = \"yarn\"\n    LLAMA3 = \"llama3\"\n\n\ndef register_model_group(\n    models: dict[str, dict[DownloadSource, str]],\n    template: Optional[str] = None,\n    multimodal: bool = False,\n) -> None:\n    for name, path in models.items():\n        SUPPORTED_MODELS[name] = path\n        if template is not None and (\n            any(suffix in name for suffix in (\"-Chat\", \"-Distill\", \"-Instruct\", \"-Thinking\")) or multimodal\n        ):\n            DEFAULT_TEMPLATE[name] = template\n\n        if 
multimodal:\n            MULTIMODAL_SUPPORTED_MODELS.add(name)\n\n\nregister_model_group(\n    models={\n        \"Aya-23-8B-Chat\": {\n            DownloadSource.DEFAULT: \"CohereForAI/aya-23-8B\",\n        },\n        \"Aya-23-35B-Chat\": {\n            DownloadSource.DEFAULT: \"CohereForAI/aya-23-35B\",\n        },\n    },\n    template=\"cohere\",\n)\n\n\nregister_model_group(\n    models={\n        \"Baichuan-7B-Base\": {\n            DownloadSource.DEFAULT: \"baichuan-inc/Baichuan-7B\",\n            DownloadSource.MODELSCOPE: \"baichuan-inc/baichuan-7B\",\n        },\n        \"Baichuan-13B-Base\": {\n            DownloadSource.DEFAULT: \"baichuan-inc/Baichuan-13B-Base\",\n            DownloadSource.MODELSCOPE: \"baichuan-inc/Baichuan-13B-Base\",\n        },\n        \"Baichuan-13B-Chat\": {\n            DownloadSource.DEFAULT: \"baichuan-inc/Baichuan-13B-Chat\",\n            DownloadSource.MODELSCOPE: \"baichuan-inc/Baichuan-13B-Chat\",\n        },\n    },\n    template=\"baichuan\",\n)\n\n\nregister_model_group(\n    models={\n        \"Baichuan2-7B-Base\": {\n            DownloadSource.DEFAULT: \"baichuan-inc/Baichuan2-7B-Base\",\n            DownloadSource.MODELSCOPE: \"baichuan-inc/Baichuan2-7B-Base\",\n        },\n        \"Baichuan2-13B-Base\": {\n            DownloadSource.DEFAULT: \"baichuan-inc/Baichuan2-13B-Base\",\n            DownloadSource.MODELSCOPE: \"baichuan-inc/Baichuan2-13B-Base\",\n            DownloadSource.OPENMIND: \"Baichuan/Baichuan2_13b_base_pt\",\n        },\n        \"Baichuan2-7B-Chat\": {\n            DownloadSource.DEFAULT: \"baichuan-inc/Baichuan2-7B-Chat\",\n            DownloadSource.MODELSCOPE: \"baichuan-inc/Baichuan2-7B-Chat\",\n            DownloadSource.OPENMIND: \"Baichuan/Baichuan2_7b_chat_pt\",\n        },\n        \"Baichuan2-13B-Chat\": {\n            DownloadSource.DEFAULT: \"baichuan-inc/Baichuan2-13B-Chat\",\n            DownloadSource.MODELSCOPE: \"baichuan-inc/Baichuan2-13B-Chat\",\n            
DownloadSource.OPENMIND: \"Baichuan/Baichuan2_13b_chat_pt\",\n        },\n    },\n    template=\"baichuan2\",\n)\n\n\nregister_model_group(\n    models={\n        \"BLOOM-560M\": {\n            DownloadSource.DEFAULT: \"bigscience/bloom-560m\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/bloom-560m\",\n        },\n        \"BLOOM-3B\": {\n            DownloadSource.DEFAULT: \"bigscience/bloom-3b\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/bloom-3b\",\n        },\n        \"BLOOM-7B1\": {\n            DownloadSource.DEFAULT: \"bigscience/bloom-7b1\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/bloom-7b1\",\n        },\n    },\n)\n\n\nregister_model_group(\n    models={\n        \"BLOOMZ-560M\": {\n            DownloadSource.DEFAULT: \"bigscience/bloomz-560m\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/bloomz-560m\",\n        },\n        \"BLOOMZ-3B\": {\n            DownloadSource.DEFAULT: \"bigscience/bloomz-3b\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/bloomz-3b\",\n        },\n        \"BLOOMZ-7B1-mt\": {\n            DownloadSource.DEFAULT: \"bigscience/bloomz-7b1-mt\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/bloomz-7b1-mt\",\n        },\n    },\n)\n\n\nregister_model_group(\n    models={\n        \"BlueLM-7B-Base\": {\n            DownloadSource.DEFAULT: \"vivo-ai/BlueLM-7B-Base\",\n            DownloadSource.MODELSCOPE: \"vivo-ai/BlueLM-7B-Base\",\n        },\n        \"BlueLM-7B-Chat\": {\n            DownloadSource.DEFAULT: \"vivo-ai/BlueLM-7B-Chat\",\n            DownloadSource.MODELSCOPE: \"vivo-ai/BlueLM-7B-Chat\",\n        },\n    },\n    template=\"bluelm\",\n)\n\n\nregister_model_group(\n    models={\n        \"Breeze-7B\": {\n            DownloadSource.DEFAULT: \"MediaTek-Research/Breeze-7B-Base-v1_0\",\n        },\n        \"Breeze-7B-Instruct\": {\n            DownloadSource.DEFAULT: \"MediaTek-Research/Breeze-7B-Instruct-v1_0\",\n        },\n    },\n    
template=\"breeze\",\n)\n\n\nregister_model_group(\n    models={\n        \"ChatGLM2-6B-Chat\": {\n            DownloadSource.DEFAULT: \"zai-org/chatglm2-6b\",\n            DownloadSource.MODELSCOPE: \"ZhipuAI/chatglm2-6b\",\n        }\n    },\n    template=\"chatglm2\",\n)\n\n\nregister_model_group(\n    models={\n        \"ChatGLM3-6B-Base\": {\n            DownloadSource.DEFAULT: \"zai-org/chatglm3-6b-base\",\n            DownloadSource.MODELSCOPE: \"ZhipuAI/chatglm3-6b-base\",\n        },\n        \"ChatGLM3-6B-Chat\": {\n            DownloadSource.DEFAULT: \"zai-org/chatglm3-6b\",\n            DownloadSource.MODELSCOPE: \"ZhipuAI/chatglm3-6b\",\n        },\n    },\n    template=\"chatglm3\",\n)\n\n\nregister_model_group(\n    models={\n        \"Chinese-Llama-2-1.3B\": {\n            DownloadSource.DEFAULT: \"hfl/chinese-llama-2-1.3b\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/chinese-llama-2-1.3b\",\n        },\n        \"Chinese-Llama-2-7B\": {\n            DownloadSource.DEFAULT: \"hfl/chinese-llama-2-7b\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/chinese-llama-2-7b\",\n        },\n        \"Chinese-Llama-2-13B\": {\n            DownloadSource.DEFAULT: \"hfl/chinese-llama-2-13b\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/chinese-llama-2-13b\",\n        },\n        \"Chinese-Alpaca-2-1.3B-Chat\": {\n            DownloadSource.DEFAULT: \"hfl/chinese-alpaca-2-1.3b\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/chinese-alpaca-2-1.3b\",\n        },\n        \"Chinese-Alpaca-2-7B-Chat\": {\n            DownloadSource.DEFAULT: \"hfl/chinese-alpaca-2-7b\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/chinese-alpaca-2-7b\",\n        },\n        \"Chinese-Alpaca-2-13B-Chat\": {\n            DownloadSource.DEFAULT: \"hfl/chinese-alpaca-2-13b\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/chinese-alpaca-2-13b\",\n        },\n    },\n    
template=\"llama2_zh\",\n)\n\n\nregister_model_group(\n    models={\n        \"CodeGeeX4-9B-Chat\": {\n            DownloadSource.DEFAULT: \"zai-org/codegeex4-all-9b\",\n            DownloadSource.MODELSCOPE: \"ZhipuAI/codegeex4-all-9b\",\n        },\n    },\n    template=\"codegeex4\",\n)\n\n\nregister_model_group(\n    models={\n        \"CodeGemma-7B\": {\n            DownloadSource.DEFAULT: \"google/codegemma-7b\",\n        },\n        \"CodeGemma-7B-Instruct\": {\n            DownloadSource.DEFAULT: \"google/codegemma-7b-it\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/codegemma-7b-it\",\n        },\n        \"CodeGemma-1.1-2B\": {\n            DownloadSource.DEFAULT: \"google/codegemma-1.1-2b\",\n        },\n        \"CodeGemma-1.1-7B-Instruct\": {\n            DownloadSource.DEFAULT: \"google/codegemma-1.1-7b-it\",\n        },\n    },\n    template=\"gemma\",\n)\n\n\nregister_model_group(\n    models={\n        \"Codestral-22B-v0.1-Chat\": {\n            DownloadSource.DEFAULT: \"mistralai/Codestral-22B-v0.1\",\n            DownloadSource.MODELSCOPE: \"swift/Codestral-22B-v0.1\",\n        },\n    },\n    template=\"mistral\",\n)\n\n\nregister_model_group(\n    models={\n        \"CommandR-35B-Chat\": {\n            DownloadSource.DEFAULT: \"CohereForAI/c4ai-command-r-v01\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/c4ai-command-r-v01\",\n        },\n        \"CommandR-Plus-104B-Chat\": {\n            DownloadSource.DEFAULT: \"CohereForAI/c4ai-command-r-plus\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/c4ai-command-r-plus\",\n        },\n        \"CommandR-35B-4bit-Chat\": {\n            DownloadSource.DEFAULT: \"CohereForAI/c4ai-command-r-v01-4bit\",\n            DownloadSource.MODELSCOPE: \"mirror013/c4ai-command-r-v01-4bit\",\n        },\n        \"CommandR-Plus-104B-4bit-Chat\": {\n            DownloadSource.DEFAULT: \"CohereForAI/c4ai-command-r-plus-4bit\",\n        },\n    },\n    
template=\"cohere\",\n)\n\n\nregister_model_group(\n    models={\n        \"DBRX-132B-Base\": {\n            DownloadSource.DEFAULT: \"databricks/dbrx-base\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/dbrx-base\",\n        },\n        \"DBRX-132B-Instruct\": {\n            DownloadSource.DEFAULT: \"databricks/dbrx-instruct\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/dbrx-instruct\",\n        },\n    },\n    template=\"dbrx\",\n)\n\n\nregister_model_group(\n    models={\n        \"DeepSeek-LLM-7B-Base\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/deepseek-llm-7b-base\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/deepseek-llm-7b-base\",\n        },\n        \"DeepSeek-LLM-67B-Base\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/deepseek-llm-67b-base\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/deepseek-llm-67b-base\",\n        },\n        \"DeepSeek-LLM-7B-Chat\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/deepseek-llm-7b-chat\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/deepseek-llm-7b-chat\",\n        },\n        \"DeepSeek-LLM-67B-Chat\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/deepseek-llm-67b-chat\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/deepseek-llm-67b-chat\",\n        },\n        \"DeepSeek-Math-7B-Base\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/deepseek-math-7b-base\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/deepseek-math-7b-base\",\n        },\n        \"DeepSeek-Math-7B-Instruct\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/deepseek-math-7b-instruct\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/deepseek-math-7b-instruct\",\n        },\n        \"DeepSeek-MoE-16B-Base\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/deepseek-moe-16b-base\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/deepseek-moe-16b-base\",\n        },\n        \"DeepSeek-MoE-16B-Chat\": {\n            
DownloadSource.DEFAULT: \"deepseek-ai/deepseek-moe-16b-chat\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/deepseek-moe-16b-chat\",\n        },\n        \"DeepSeek-V2-16B-Base\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/DeepSeek-V2-Lite\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/DeepSeek-V2-Lite\",\n        },\n        \"DeepSeek-V2-236B-Base\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/DeepSeek-V2\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/DeepSeek-V2\",\n        },\n        \"DeepSeek-V2-16B-Chat\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/DeepSeek-V2-Lite-Chat\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/DeepSeek-V2-Lite-Chat\",\n        },\n        \"DeepSeek-V2-236B-Chat\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/DeepSeek-V2-Chat\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/DeepSeek-V2-Chat\",\n        },\n        \"DeepSeek-Coder-V2-16B-Base\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/DeepSeek-Coder-V2-Lite-Base\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/DeepSeek-Coder-V2-Lite-Base\",\n        },\n        \"DeepSeek-Coder-V2-236B-Base\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/DeepSeek-Coder-V2-Base\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/DeepSeek-Coder-V2-Base\",\n        },\n        \"DeepSeek-Coder-V2-16B-Instruct\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct\",\n        },\n        \"DeepSeek-Coder-V2-236B-Instruct\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/DeepSeek-Coder-V2-Instruct\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/DeepSeek-Coder-V2-Instruct\",\n        },\n    },\n    template=\"deepseek\",\n)\n\n\nregister_model_group(\n    models={\n        \"DeepSeek-Coder-6.7B-Base\": {\n            DownloadSource.DEFAULT: 
\"deepseek-ai/deepseek-coder-6.7b-base\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/deepseek-coder-6.7b-base\",\n        },\n        \"DeepSeek-Coder-7B-Base\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/deepseek-coder-7b-base-v1.5\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/deepseek-coder-7b-base-v1.5\",\n        },\n        \"DeepSeek-Coder-33B-Base\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/deepseek-coder-33b-base\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/deepseek-coder-33b-base\",\n        },\n        \"DeepSeek-Coder-6.7B-Instruct\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/deepseek-coder-6.7b-instruct\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/deepseek-coder-6.7b-instruct\",\n        },\n        \"DeepSeek-Coder-7B-Instruct\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/deepseek-coder-7b-instruct-v1.5\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/deepseek-coder-7b-instruct-v1.5\",\n        },\n        \"DeepSeek-Coder-33B-Instruct\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/deepseek-coder-33b-instruct\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/deepseek-coder-33b-instruct\",\n        },\n    },\n    template=\"deepseekcoder\",\n)\n\n\nregister_model_group(\n    models={\n        \"DeepSeek-V2-0628-236B-Chat\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/DeepSeek-V2-Chat-0628\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/DeepSeek-V2-Chat-0628\",\n        },\n        \"DeepSeek-V2.5-236B-Chat\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/DeepSeek-V2.5\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/DeepSeek-V2.5\",\n        },\n        \"DeepSeek-V2.5-1210-236B-Chat\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/DeepSeek-V2.5-1210\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/DeepSeek-V2.5-1210\",\n        },\n        \"DeepSeek-V3-671B-Base\": {\n            
DownloadSource.DEFAULT: \"deepseek-ai/DeepSeek-V3-Base\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/DeepSeek-V3-Base\",\n        },\n        \"DeepSeek-V3-671B-Chat\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/DeepSeek-V3\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/DeepSeek-V3\",\n        },\n        \"DeepSeek-V3-0324-671B-Chat\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/DeepSeek-V3-0324\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/DeepSeek-V3-0324\",\n        },\n    },\n    template=\"deepseek3\",\n)\n\n\nregister_model_group(\n    models={\n        \"DeepSeek-R1-1.5B-Distill\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\",\n        },\n        \"DeepSeek-R1-7B-Distill\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\",\n        },\n        \"DeepSeek-R1-8B-Distill\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\",\n        },\n        \"DeepSeek-R1-14B-Distill\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/DeepSeek-R1-Distill-Qwen-14B\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/DeepSeek-R1-Distill-Qwen-14B\",\n        },\n        \"DeepSeek-R1-32B-Distill\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/DeepSeek-R1-Distill-Qwen-32B\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/DeepSeek-R1-Distill-Qwen-32B\",\n        },\n        \"DeepSeek-R1-70B-Distill\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/DeepSeek-R1-Distill-Llama-70B\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/DeepSeek-R1-Distill-Llama-70B\",\n        },\n        \"DeepSeek-R1-671B-Chat-Zero\": {\n            
DownloadSource.DEFAULT: \"deepseek-ai/DeepSeek-R1-Zero\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/DeepSeek-R1-Zero\",\n        },\n        \"DeepSeek-R1-671B-Chat\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/DeepSeek-R1\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/DeepSeek-R1\",\n        },\n        \"DeepSeek-R1-0528-8B-Distill\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/DeepSeek-R1-0528-Qwen3-8B\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/DeepSeek-R1-0528-Qwen3-8B\",\n        },\n        \"DeepSeek-R1-0528-671B-Chat\": {\n            DownloadSource.DEFAULT: \"deepseek-ai/DeepSeek-R1-0528\",\n            DownloadSource.MODELSCOPE: \"deepseek-ai/DeepSeek-R1-0528\",\n        },\n    },\n    template=\"deepseekr1\",\n)\n\n\nregister_model_group(\n    models={\n        \"Devstral-Small-2507-Instruct\": {\n            DownloadSource.DEFAULT: \"mistralai/Devstral-Small-2507\",\n            DownloadSource.MODELSCOPE: \"mistralai/Devstral-Small-2507\",\n        },\n    },\n    template=\"mistral_small\",\n)\n\n\nregister_model_group(\n    models={\n        \"EXAONE-3.0-7.8B-Instruct\": {\n            DownloadSource.DEFAULT: \"LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct\",\n        },\n    },\n    template=\"exaone\",\n)\n\n\nregister_model_group(\n    models={\n        \"Falcon-7B\": {\n            DownloadSource.DEFAULT: \"tiiuae/falcon-7b\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/falcon-7b\",\n        },\n        \"Falcon-11B\": {\n            DownloadSource.DEFAULT: \"tiiuae/falcon-11B\",\n            DownloadSource.MODELSCOPE: \"tiiuae/falcon-11B\",\n        },\n        \"Falcon-40B\": {\n            DownloadSource.DEFAULT: \"tiiuae/falcon-40b\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/falcon-40b\",\n        },\n        \"Falcon-180B\": {\n            DownloadSource.DEFAULT: \"tiiuae/falcon-180b\",\n            DownloadSource.MODELSCOPE: \"modelscope/falcon-180B\",\n        },\n    
    \"Falcon-7B-Instruct\": {\n            DownloadSource.DEFAULT: \"tiiuae/falcon-7b-instruct\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/falcon-7b-instruct\",\n        },\n        \"Falcon-40B-Instruct\": {\n            DownloadSource.DEFAULT: \"tiiuae/falcon-40b-instruct\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/falcon-40b-instruct\",\n        },\n        \"Falcon-180B-Chat\": {\n            DownloadSource.DEFAULT: \"tiiuae/falcon-180b-chat\",\n            DownloadSource.MODELSCOPE: \"modelscope/falcon-180B-chat\",\n        },\n    },\n    template=\"falcon\",\n)\n\n\nregister_model_group(\n    models={\n        \"Falcon-H1-0.5B-Base\": {\n            DownloadSource.DEFAULT: \"tiiuae/Falcon-H1-0.5B-Base\",\n            DownloadSource.MODELSCOPE: \"tiiuae/Falcon-H1-0.5B-Base\",\n        },\n        \"Falcon-H1-1.5B-Base\": {\n            DownloadSource.DEFAULT: \"tiiuae/Falcon-H1-1.5B-Base\",\n            DownloadSource.MODELSCOPE: \"tiiuae/Falcon-H1-1.5B-Base\",\n        },\n        \"Falcon-H1-1.5B-Deep-Base\": {\n            DownloadSource.DEFAULT: \"tiiuae/Falcon-H1-1.5B-Deep-Base\",\n            DownloadSource.MODELSCOPE: \"tiiuae/Falcon-H1-1.5B-Deep-Base\",\n        },\n        \"Falcon-H1-3B-Base\": {\n            DownloadSource.DEFAULT: \"tiiuae/Falcon-H1-3B-Base\",\n            DownloadSource.MODELSCOPE: \"tiiuae/Falcon-H1-3B-Base\",\n        },\n        \"Falcon-H1-7B-Base\": {\n            DownloadSource.DEFAULT: \"tiiuae/Falcon-H1-7B-Base\",\n            DownloadSource.MODELSCOPE: \"tiiuae/Falcon-H1-7B-Base\",\n        },\n        \"Falcon-H1-34B-Base\": {\n            DownloadSource.DEFAULT: \"tiiuae/Falcon-H1-34B-Base\",\n            DownloadSource.MODELSCOPE: \"tiiuae/Falcon-H1-34B-Base\",\n        },\n        \"Falcon-H1-0.5B-Instruct\": {\n            DownloadSource.DEFAULT: \"tiiuae/Falcon-H1-0.5B-Instruct\",\n            DownloadSource.MODELSCOPE: \"tiiuae/Falcon-H1-0.5B-Instruct\",\n        },\n        
\"Falcon-H1-1.5B-Instruct\": {\n            DownloadSource.DEFAULT: \"tiiuae/Falcon-H1-1.5B-Instruct\",\n            DownloadSource.MODELSCOPE: \"tiiuae/Falcon-H1-1.5B-Instruct\",\n        },\n        \"Falcon-H1-1.5B-Deep-Instruct\": {\n            DownloadSource.DEFAULT: \"tiiuae/Falcon-H1-1.5B-Deep-Instruct\",\n            DownloadSource.MODELSCOPE: \"tiiuae/Falcon-H1-1.5B-Deep-Instruct\",\n        },\n        \"Falcon-H1-3B-Instruct\": {\n            DownloadSource.DEFAULT: \"tiiuae/Falcon-H1-3B-Instruct\",\n            DownloadSource.MODELSCOPE: \"tiiuae/Falcon-H1-3B-Instruct\",\n        },\n        \"Falcon-H1-7B-Instruct\": {\n            DownloadSource.DEFAULT: \"tiiuae/Falcon-H1-7B-Instruct\",\n            DownloadSource.MODELSCOPE: \"tiiuae/Falcon-H1-7B-Instruct\",\n        },\n        \"Falcon-H1-34B-Instruct\": {\n            DownloadSource.DEFAULT: \"tiiuae/Falcon-H1-34B-Instruct\",\n            DownloadSource.MODELSCOPE: \"tiiuae/Falcon-H1-34B-Instruct\",\n        },\n    },\n    template=\"falcon_h1\",\n)\n\n\nregister_model_group(\n    models={\n        \"Gemma-2B\": {\n            DownloadSource.DEFAULT: \"google/gemma-2b\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/gemma-2b\",\n        },\n        \"Gemma-7B\": {\n            DownloadSource.DEFAULT: \"google/gemma-7b\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/gemma-7b\",\n        },\n        \"Gemma-2B-Instruct\": {\n            DownloadSource.DEFAULT: \"google/gemma-2b-it\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/gemma-2b-it\",\n        },\n        \"Gemma-7B-Instruct\": {\n            DownloadSource.DEFAULT: \"google/gemma-7b-it\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/gemma-7b-it\",\n        },\n        \"Gemma-1.1-2B-Instruct\": {\n            DownloadSource.DEFAULT: \"google/gemma-1.1-2b-it\",\n        },\n        \"Gemma-1.1-7B-Instruct\": {\n            DownloadSource.DEFAULT: \"google/gemma-1.1-7b-it\",\n        },\n    },\n  
  template=\"gemma\",\n)\n\n\nregister_model_group(\n    models={\n        \"Gemma-2-2B\": {\n            DownloadSource.DEFAULT: \"google/gemma-2-2b\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/gemma-2-2b\",\n        },\n        \"Gemma-2-9B\": {\n            DownloadSource.DEFAULT: \"google/gemma-2-9b\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/gemma-2-9b\",\n        },\n        \"Gemma-2-27B\": {\n            DownloadSource.DEFAULT: \"google/gemma-2-27b\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/gemma-2-27b\",\n        },\n        \"Gemma-2-2B-Instruct\": {\n            DownloadSource.DEFAULT: \"google/gemma-2-2b-it\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/gemma-2-2b-it\",\n            DownloadSource.OPENMIND: \"LlamaFactory/gemma-2-2b-it\",\n        },\n        \"Gemma-2-9B-Instruct\": {\n            DownloadSource.DEFAULT: \"google/gemma-2-9b-it\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/gemma-2-9b-it\",\n            DownloadSource.OPENMIND: \"LlamaFactory/gemma-2-9b-it\",\n        },\n        \"Gemma-2-27B-Instruct\": {\n            DownloadSource.DEFAULT: \"google/gemma-2-27b-it\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/gemma-2-27b-it\",\n        },\n        \"Gemma-3-1B\": {\n            DownloadSource.DEFAULT: \"google/gemma-3-1b-pt\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/gemma-3-1b-pt\",\n        },\n        \"Gemma-3-1B-Instruct\": {\n            DownloadSource.DEFAULT: \"google/gemma-3-1b-it\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/gemma-3-1b-it\",\n        },\n        \"MedGemma-27B-Instruct\": {\n            DownloadSource.DEFAULT: \"google/medgemma-27b-text-it\",\n            DownloadSource.MODELSCOPE: \"google/medgemma-27b-text-it\",\n        },\n    },\n    template=\"gemma2\",\n)\n\n\nregister_model_group(\n    models={\n        \"Gemma-3-4B\": {\n            DownloadSource.DEFAULT: \"google/gemma-3-4b-pt\",\n            
DownloadSource.MODELSCOPE: \"LLM-Research/gemma-3-4b-pt\",\n        },\n        \"Gemma-3-12B\": {\n            DownloadSource.DEFAULT: \"google/gemma-3-12b-pt\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/gemma-3-12b-pt\",\n        },\n        \"Gemma-3-27B\": {\n            DownloadSource.DEFAULT: \"google/gemma-3-27b-pt\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/gemma-3-27b-pt\",\n        },\n        \"Gemma-3-4B-Instruct\": {\n            DownloadSource.DEFAULT: \"google/gemma-3-4b-it\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/gemma-3-4b-it\",\n        },\n        \"Gemma-3-12B-Instruct\": {\n            DownloadSource.DEFAULT: \"google/gemma-3-12b-it\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/gemma-3-12b-it\",\n        },\n        \"Gemma-3-27B-Instruct\": {\n            DownloadSource.DEFAULT: \"google/gemma-3-27b-it\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/gemma-3-27b-it\",\n        },\n        \"MedGemma-4B\": {\n            DownloadSource.DEFAULT: \"google/medgemma-4b-pt\",\n            DownloadSource.MODELSCOPE: \"google/medgemma-4b-pt\",\n        },\n        \"MedGemma-4B-Instruct\": {\n            DownloadSource.DEFAULT: \"google/medgemma-4b-it\",\n            DownloadSource.MODELSCOPE: \"google/medgemma-4b-it\",\n        },\n    },\n    template=\"gemma3\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"Gemma-3n-E2B\": {\n            DownloadSource.DEFAULT: \"google/gemma-3n-E2B\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/gemma-3n-E2B\",\n        },\n        \"Gemma-3n-E4B\": {\n            DownloadSource.DEFAULT: \"google/gemma-3n-E4B\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/gemma-3n-E4B\",\n        },\n        \"Gemma-3n-E2B-Instruct\": {\n            DownloadSource.DEFAULT: \"google/gemma-3n-E2B-it\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/gemma-3n-E2B-it\",\n        },\n        
\"Gemma-3n-E4B-Instruct\": {\n            DownloadSource.DEFAULT: \"google/gemma-3n-E4B-it\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/gemma-3n-E4B-it\",\n        },\n    },\n    template=\"gemma3n\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"GLM-4-9B\": {\n            DownloadSource.DEFAULT: \"zai-org/glm-4-9b\",\n            DownloadSource.MODELSCOPE: \"ZhipuAI/glm-4-9b\",\n        },\n        \"GLM-4-9B-Chat\": {\n            DownloadSource.DEFAULT: \"zai-org/glm-4-9b-chat\",\n            DownloadSource.MODELSCOPE: \"ZhipuAI/glm-4-9b-chat\",\n            DownloadSource.OPENMIND: \"LlamaFactory/glm-4-9b-chat\",\n        },\n        \"GLM-4-9B-1M-Chat\": {\n            DownloadSource.DEFAULT: \"zai-org/glm-4-9b-chat-1m\",\n            DownloadSource.MODELSCOPE: \"ZhipuAI/glm-4-9b-chat-1m\",\n        },\n        \"GLM-4-0414-9B-Chat\": {\n            DownloadSource.DEFAULT: \"zai-org/GLM-4-9B-0414\",\n            DownloadSource.MODELSCOPE: \"ZhipuAI/GLM-4-9B-0414\",\n        },\n        \"GLM-4-0414-32B-Base\": {\n            DownloadSource.DEFAULT: \"zai-org/GLM-4-32B-Base-0414\",\n            DownloadSource.MODELSCOPE: \"ZhipuAI/GLM-4-32B-Base-0414\",\n        },\n        \"GLM-4-0414-32B-Chat\": {\n            DownloadSource.DEFAULT: \"zai-org/GLM-4-32B-0414\",\n            DownloadSource.MODELSCOPE: \"ZhipuAI/GLM-4-32B-0414\",\n        },\n    },\n    template=\"glm4\",\n)\n\n\nregister_model_group(\n    models={\n        \"GLM-4.1V-9B-Base\": {\n            DownloadSource.DEFAULT: \"zai-org/GLM-4.1V-9B-Base\",\n            DownloadSource.MODELSCOPE: \"ZhipuAI/GLM-4.1V-9B-Base\",\n        },\n        \"GLM-4.1V-9B-Thinking\": {\n            DownloadSource.DEFAULT: \"zai-org/GLM-4.1V-9B-Thinking\",\n            DownloadSource.MODELSCOPE: \"ZhipuAI/GLM-4.1V-9B-Thinking\",\n        },\n    },\n    template=\"glm4v\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"GLM-4.5-Air-Base\": {\n 
           DownloadSource.DEFAULT: \"zai-org/GLM-4.5-Air-Base\",\n            DownloadSource.MODELSCOPE: \"ZhipuAI/GLM-4.5-Air-Base\",\n        },\n        \"GLM-4.5-Base\": {\n            DownloadSource.DEFAULT: \"zai-org/GLM-4.5-Base\",\n            DownloadSource.MODELSCOPE: \"ZhipuAI/GLM-4.5-Base\",\n        },\n        \"GLM-4.5-Air-Thinking\": {\n            DownloadSource.DEFAULT: \"zai-org/GLM-4.5-Air\",\n            DownloadSource.MODELSCOPE: \"ZhipuAI/GLM-4.5-Air\",\n        },\n        \"GLM-4.5-Thinking\": {\n            DownloadSource.DEFAULT: \"zai-org/GLM-4.5\",\n            DownloadSource.MODELSCOPE: \"ZhipuAI/GLM-4.5\",\n        },\n    },\n    template=\"glm4_moe\",\n)\n\n\nregister_model_group(\n    models={\n        \"GLM-4.5V-Air-Thinking\": {\n            DownloadSource.DEFAULT: \"zai-org/GLM-4.5V\",\n            DownloadSource.MODELSCOPE: \"ZhipuAI/GLM-4.5V\",\n        }\n    },\n    template=\"glm45v\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"GLM-Z1-0414-9B-Chat\": {\n            DownloadSource.DEFAULT: \"zai-org/GLM-Z1-9B-0414\",\n            DownloadSource.MODELSCOPE: \"ZhipuAI/GLM-Z1-9B-0414\",\n        },\n        \"GLM-Z1-0414-32B-Chat\": {\n            DownloadSource.DEFAULT: \"zai-org/GLM-Z1-32B-0414\",\n            DownloadSource.MODELSCOPE: \"ZhipuAI/GLM-Z1-32B-0414\",\n        },\n    },\n    template=\"glmz1\",\n)\n\n\nregister_model_group(\n    models={\n        \"GPT-2-Small\": {\n            DownloadSource.DEFAULT: \"openai-community/gpt2\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/gpt2\",\n        },\n        \"GPT-2-Medium\": {\n            DownloadSource.DEFAULT: \"openai-community/gpt2-medium\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/gpt2-medium\",\n        },\n        \"GPT-2-Large\": {\n            DownloadSource.DEFAULT: \"openai-community/gpt2-large\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/gpt2-large\",\n        },\n        
\"GPT-2-XL\": {\n            DownloadSource.DEFAULT: \"openai-community/gpt2-xl\",\n            DownloadSource.MODELSCOPE: \"goodbai95/GPT2-xl\",\n        },\n    },\n)\n\n\nregister_model_group(\n    models={\n        \"GPT-OSS-20B-Thinking\": {\n            DownloadSource.DEFAULT: \"openai/gpt-oss-20b\",\n            DownloadSource.MODELSCOPE: \"openai/gpt-oss-20b\",\n        },\n        \"GPT-OSS-120B-Thinking\": {\n            DownloadSource.DEFAULT: \"openai/gpt-oss-120b\",\n            DownloadSource.MODELSCOPE: \"openai/gpt-oss-120b\",\n        },\n    },\n    template=\"gpt\",\n)\n\n\nregister_model_group(\n    models={\n        \"Granite-3.0-1B-A400M-Base\": {\n            DownloadSource.DEFAULT: \"ibm-granite/granite-3.0-1b-a400m-base\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/granite-3.0-1b-a400m-base\",\n        },\n        \"Granite-3.0-3B-A800M-Base\": {\n            DownloadSource.DEFAULT: \"ibm-granite/granite-3.0-3b-a800m-base\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/granite-3.0-3b-a800m-base\",\n        },\n        \"Granite-3.0-2B-Base\": {\n            DownloadSource.DEFAULT: \"ibm-granite/granite-3.0-2b-base\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/granite-3.0-2b-base\",\n        },\n        \"Granite-3.0-8B-Base\": {\n            DownloadSource.DEFAULT: \"ibm-granite/granite-3.0-8b-base\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/granite-3.0-8b-base\",\n        },\n        \"Granite-3.0-1B-A400M-Instruct\": {\n            DownloadSource.DEFAULT: \"ibm-granite/granite-3.0-1b-a400m-instruct\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/granite-3.0-1b-a400m-instruct\",\n        },\n        \"Granite-3.0-3B-A800M-Instruct\": {\n            DownloadSource.DEFAULT: \"ibm-granite/granite-3.0-3b-a800m-instruct\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/granite-3.0-3b-a800m-instruct\",\n        },\n        \"Granite-3.0-2B-Instruct\": {\n            
DownloadSource.DEFAULT: \"ibm-granite/granite-3.0-2b-instruct\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/granite-3.0-2b-instruct\",\n        },\n        \"Granite-3.0-8B-Instruct\": {\n            DownloadSource.DEFAULT: \"ibm-granite/granite-3.0-8b-instruct\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/granite-3.0-8b-instruct\",\n        },\n        \"Granite-3.1-1B-A400M-Base\": {\n            DownloadSource.DEFAULT: \"ibm-granite/granite-3.1-1b-a400m-base\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/granite-3.1-1b-a400m-base\",\n        },\n        \"Granite-3.1-3B-A800M-Base\": {\n            DownloadSource.DEFAULT: \"ibm-granite/granite-3.1-3b-a800m-base\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/granite-3.1-3b-a800m-base\",\n        },\n        \"Granite-3.1-2B-Base\": {\n            DownloadSource.DEFAULT: \"ibm-granite/granite-3.1-2b-base\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/granite-3.1-2b-base\",\n        },\n        \"Granite-3.1-8B-Base\": {\n            DownloadSource.DEFAULT: \"ibm-granite/granite-3.1-8b-base\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/granite-3.1-8b-base\",\n        },\n        \"Granite-3.1-1B-A400M-Instruct\": {\n            DownloadSource.DEFAULT: \"ibm-granite/granite-3.1-1b-a400m-instruct\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/granite-3.1-1b-a400m-instruct\",\n        },\n        \"Granite-3.1-3B-A800M-Instruct\": {\n            DownloadSource.DEFAULT: \"ibm-granite/granite-3.1-3b-a800m-instruct\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/granite-3.1-3b-a800m-instruct\",\n        },\n        \"Granite-3.1-2B-Instruct\": {\n            DownloadSource.DEFAULT: \"ibm-granite/granite-3.1-2b-instruct\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/granite-3.1-2b-instruct\",\n        },\n        \"Granite-3.1-8B-Instruct\": {\n            DownloadSource.DEFAULT: 
\"ibm-granite/granite-3.1-8b-instruct\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/granite-3.1-8b-instruct\",\n        },\n        \"Granite-3.2-2B-Instruct\": {\n            DownloadSource.DEFAULT: \"ibm-granite/granite-3.2-2b-instruct\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/granite-3.2-2b-instruct\",\n        },\n        \"Granite-3.2-8B-Instruct\": {\n            DownloadSource.DEFAULT: \"ibm-granite/granite-3.2-8b-instruct\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/granite-3.2-8b-instruct\",\n        },\n        \"Granite-3.3-2B-Base\": {\n            DownloadSource.DEFAULT: \"ibm-granite/granite-3.3-2b-base\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/granite-3.3-2b-base\",\n        },\n        \"Granite-3.3-8B-Base\": {\n            DownloadSource.DEFAULT: \"ibm-granite/granite-3.3-8b-base\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/granite-3.3-8b-base\",\n        },\n        \"Granite-3.3-2B-Instruct\": {\n            DownloadSource.DEFAULT: \"ibm-granite/granite-3.3-2b-instruct\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/granite-3.3-2b-instruct\",\n        },\n        \"Granite-3.3-8B-Instruct\": {\n            DownloadSource.DEFAULT: \"ibm-granite/granite-3.3-8b-instruct\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/granite-3.3-8b-instruct\",\n        },\n    },\n    template=\"granite3\",\n)\n\n\nregister_model_group(\n    models={\n        \"Granite-Vision-3.2-2B\": {\n            DownloadSource.DEFAULT: \"ibm-granite/granite-vision-3.2-2b\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/granite-vision-3.2-2b\",\n        },\n    },\n    template=\"granite3_vision\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"Granite-4.0-tiny-preview\": {\n            DownloadSource.DEFAULT: \"ibm-granite/granite-4.0-tiny-preview\",\n            DownloadSource.MODELSCOPE: \"ibm-granite/granite-4.0-tiny-preview\",\n        
},\n    },\n    template=\"granite4\",\n)\n\n\nregister_model_group(\n    models={\n        \"Hunyuan-7B-Instruct\": {\n            DownloadSource.DEFAULT: \"tencent/Hunyuan-7B-Instruct\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/Hunyuan-7B-Instruct\",\n        },\n    },\n    template=\"hunyuan\",\n)\n\n\nregister_model_group(\n    models={\n        \"Index-1.9B-Base\": {\n            DownloadSource.DEFAULT: \"IndexTeam/Index-1.9B\",\n            DownloadSource.MODELSCOPE: \"IndexTeam/Index-1.9B\",\n        },\n        \"Index-1.9B-Base-Pure\": {\n            DownloadSource.DEFAULT: \"IndexTeam/Index-1.9B-Pure\",\n            DownloadSource.MODELSCOPE: \"IndexTeam/Index-1.9B-Pure\",\n        },\n        \"Index-1.9B-Chat\": {\n            DownloadSource.DEFAULT: \"IndexTeam/Index-1.9B-Chat\",\n            DownloadSource.MODELSCOPE: \"IndexTeam/Index-1.9B-Chat\",\n        },\n        \"Index-1.9B-Character-Chat\": {\n            DownloadSource.DEFAULT: \"IndexTeam/Index-1.9B-Character\",\n            DownloadSource.MODELSCOPE: \"IndexTeam/Index-1.9B-Character\",\n        },\n        \"Index-1.9B-Chat-32K\": {\n            DownloadSource.DEFAULT: \"IndexTeam/Index-1.9B-32K\",\n            DownloadSource.MODELSCOPE: \"IndexTeam/Index-1.9B-32K\",\n        },\n    },\n    template=\"index\",\n)\n\n\nregister_model_group(\n    models={\n        \"InternLM-7B\": {\n            DownloadSource.DEFAULT: \"internlm/internlm-7b\",\n            DownloadSource.MODELSCOPE: \"Shanghai_AI_Laboratory/internlm-7b\",\n        },\n        \"InternLM-20B\": {\n            DownloadSource.DEFAULT: \"internlm/internlm-20b\",\n            DownloadSource.MODELSCOPE: \"Shanghai_AI_Laboratory/internlm-20b\",\n        },\n        \"InternLM-7B-Chat\": {\n            DownloadSource.DEFAULT: \"internlm/internlm-chat-7b\",\n            DownloadSource.MODELSCOPE: \"Shanghai_AI_Laboratory/internlm-chat-7b\",\n        },\n        \"InternLM-20B-Chat\": {\n            
DownloadSource.DEFAULT: \"internlm/internlm-chat-20b\",\n            DownloadSource.MODELSCOPE: \"Shanghai_AI_Laboratory/internlm-chat-20b\",\n        },\n    },\n    template=\"intern\",\n)\n\n\nregister_model_group(\n    models={\n        \"InternLM2-7B\": {\n            DownloadSource.DEFAULT: \"internlm/internlm2-7b\",\n            DownloadSource.MODELSCOPE: \"Shanghai_AI_Laboratory/internlm2-7b\",\n        },\n        \"InternLM2-20B\": {\n            DownloadSource.DEFAULT: \"internlm/internlm2-20b\",\n            DownloadSource.MODELSCOPE: \"Shanghai_AI_Laboratory/internlm2-20b\",\n        },\n        \"InternLM2-7B-Chat\": {\n            DownloadSource.DEFAULT: \"internlm/internlm2-chat-7b\",\n            DownloadSource.MODELSCOPE: \"Shanghai_AI_Laboratory/internlm2-chat-7b\",\n        },\n        \"InternLM2-20B-Chat\": {\n            DownloadSource.DEFAULT: \"internlm/internlm2-chat-20b\",\n            DownloadSource.MODELSCOPE: \"Shanghai_AI_Laboratory/internlm2-chat-20b\",\n        },\n        \"InternLM2.5-1.8B\": {\n            DownloadSource.DEFAULT: \"internlm/internlm2_5-1_8b\",\n            DownloadSource.MODELSCOPE: \"Shanghai_AI_Laboratory/internlm2_5-1_8b\",\n            DownloadSource.OPENMIND: \"Intern/internlm2_5-1_8b\",\n        },\n        \"InternLM2.5-7B\": {\n            DownloadSource.DEFAULT: \"internlm/internlm2_5-7b\",\n            DownloadSource.MODELSCOPE: \"Shanghai_AI_Laboratory/internlm2_5-7b\",\n        },\n        \"InternLM2.5-20B\": {\n            DownloadSource.DEFAULT: \"internlm/internlm2_5-20b\",\n            DownloadSource.MODELSCOPE: \"Shanghai_AI_Laboratory/internlm2_5-20b\",\n            DownloadSource.OPENMIND: \"Intern/internlm2_5-20b\",\n        },\n        \"InternLM2.5-1.8B-Chat\": {\n            DownloadSource.DEFAULT: \"internlm/internlm2_5-1_8b-chat\",\n            DownloadSource.MODELSCOPE: \"Shanghai_AI_Laboratory/internlm2_5-1_8b-chat\",\n            DownloadSource.OPENMIND: 
\"Intern/internlm2_5-1_8b-chat\",\n        },\n        \"InternLM2.5-7B-Chat\": {\n            DownloadSource.DEFAULT: \"internlm/internlm2_5-7b-chat\",\n            DownloadSource.MODELSCOPE: \"Shanghai_AI_Laboratory/internlm2_5-7b-chat\",\n            DownloadSource.OPENMIND: \"Intern/internlm2_5-7b-chat\",\n        },\n        \"InternLM2.5-7B-1M-Chat\": {\n            DownloadSource.DEFAULT: \"internlm/internlm2_5-7b-chat-1m\",\n            DownloadSource.MODELSCOPE: \"Shanghai_AI_Laboratory/internlm2_5-7b-chat-1m\",\n            DownloadSource.OPENMIND: \"Intern/internlm2_5-7b-chat-1m\",\n        },\n        \"InternLM2.5-20B-Chat\": {\n            DownloadSource.DEFAULT: \"internlm/internlm2_5-20b-chat\",\n            DownloadSource.MODELSCOPE: \"Shanghai_AI_Laboratory/internlm2_5-20b-chat\",\n            DownloadSource.OPENMIND: \"Intern/internlm2_5-20b-chat\",\n        },\n        \"InternLM3-8B-Chat\": {\n            DownloadSource.DEFAULT: \"internlm/internlm3-8b-instruct\",\n            DownloadSource.MODELSCOPE: \"Shanghai_AI_Laboratory/internlm3-8b-instruct\",\n        },\n    },\n    template=\"intern2\",\n)\n\n\nregister_model_group(\n    models={\n        \"InternVL2.5-2B-MPO\": {\n            DownloadSource.DEFAULT: \"OpenGVLab/InternVL2_5-2B-MPO-hf\",\n            DownloadSource.MODELSCOPE: \"OpenGVLab/InternVL2_5-2B-MPO-hf\",\n        },\n        \"InternVL2.5-8B-MPO\": {\n            DownloadSource.DEFAULT: \"OpenGVLab/InternVL2_5-8B-MPO-hf\",\n            DownloadSource.MODELSCOPE: \"OpenGVLab/InternVL2_5-8B-MPO-hf\",\n        },\n        \"InternVL3-1B-hf\": {\n            DownloadSource.DEFAULT: \"OpenGVLab/InternVL3-1B-hf\",\n            DownloadSource.MODELSCOPE: \"OpenGVLab/InternVL3-1B-hf\",\n        },\n        \"InternVL3-2B-hf\": {\n            DownloadSource.DEFAULT: \"OpenGVLab/InternVL3-2B-hf\",\n            DownloadSource.MODELSCOPE: \"OpenGVLab/InternVL3-2B-hf\",\n        },\n        \"InternVL3-8B-hf\": {\n            
DownloadSource.DEFAULT: \"OpenGVLab/InternVL3-8B-hf\",\n            DownloadSource.MODELSCOPE: \"OpenGVLab/InternVL3-8B-hf\",\n        },\n        \"InternVL3-14B-hf\": {\n            DownloadSource.DEFAULT: \"OpenGVLab/InternVL3-14B-hf\",\n            DownloadSource.MODELSCOPE: \"OpenGVLab/InternVL3-14B-hf\",\n        },\n        \"InternVL3-38B-hf\": {\n            DownloadSource.DEFAULT: \"OpenGVLab/InternVL3-38B-hf\",\n            DownloadSource.MODELSCOPE: \"OpenGVLab/InternVL3-38B-hf\",\n        },\n        \"InternVL3-78B-hf\": {\n            DownloadSource.DEFAULT: \"OpenGVLab/InternVL3-78B-hf\",\n            DownloadSource.MODELSCOPE: \"OpenGVLab/InternVL3-78B-hf\",\n        },\n    },\n    template=\"intern_vl\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"Jamba-v0.1\": {\n            DownloadSource.DEFAULT: \"ai21labs/Jamba-v0.1\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/Jamba-v0.1\",\n        }\n    },\n)\n\n\nregister_model_group(\n    models={\n        \"Keye-VL-8B-Chat\": {\n            DownloadSource.DEFAULT: \"Kwai-Keye/Keye-VL-8B-Preview\",\n            DownloadSource.MODELSCOPE: \"Kwai-Keye/Keye-VL-8B-Preview\",\n        },\n    },\n    template=\"keye_vl\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"Kimi-Dev-72B-Instruct\": {\n            DownloadSource.DEFAULT: \"moonshotai/Kimi-Dev-72B\",\n            DownloadSource.MODELSCOPE: \"moonshotai/Kimi-Dev-72B\",\n        },\n    },\n    template=\"qwen\",\n)\n\n\nregister_model_group(\n    models={\n        \"Kimi-VL-A3B-Instruct\": {\n            DownloadSource.DEFAULT: \"moonshotai/Kimi-VL-A3B-Instruct\",\n            DownloadSource.MODELSCOPE: \"moonshotai/Kimi-VL-A3B-Instruct\",\n        },\n        \"Kimi-VL-A3B-Thinking\": {\n            DownloadSource.DEFAULT: \"moonshotai/Kimi-VL-A3B-Thinking\",\n            DownloadSource.MODELSCOPE: \"moonshotai/Kimi-VL-A3B-Thinking\",\n        },\n        
\"Kimi-VL-A3B-Thinking-2506\": {\n            DownloadSource.DEFAULT: \"moonshotai/Kimi-VL-A3B-Thinking-2506\",\n            DownloadSource.MODELSCOPE: \"moonshotai/Kimi-VL-A3B-Thinking-2506\",\n        },\n    },\n    template=\"kimi_vl\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"LingoWhale-8B\": {\n            DownloadSource.DEFAULT: \"deeplang-ai/LingoWhale-8B\",\n            DownloadSource.MODELSCOPE: \"DeepLang/LingoWhale-8B\",\n        }\n    },\n)\n\n\nregister_model_group(\n    models={\n        \"Llama-7B\": {\n            DownloadSource.DEFAULT: \"huggyllama/llama-7b\",\n            DownloadSource.MODELSCOPE: \"skyline2006/llama-7b\",\n        },\n        \"Llama-13B\": {\n            DownloadSource.DEFAULT: \"huggyllama/llama-13b\",\n            DownloadSource.MODELSCOPE: \"skyline2006/llama-13b\",\n        },\n        \"Llama-30B\": {\n            DownloadSource.DEFAULT: \"huggyllama/llama-30b\",\n            DownloadSource.MODELSCOPE: \"skyline2006/llama-30b\",\n        },\n        \"Llama-65B\": {\n            DownloadSource.DEFAULT: \"huggyllama/llama-65b\",\n            DownloadSource.MODELSCOPE: \"skyline2006/llama-65b\",\n        },\n    }\n)\n\n\nregister_model_group(\n    models={\n        \"Llama-2-7B\": {\n            DownloadSource.DEFAULT: \"meta-llama/Llama-2-7b-hf\",\n            DownloadSource.MODELSCOPE: \"modelscope/Llama-2-7b-ms\",\n        },\n        \"Llama-2-13B\": {\n            DownloadSource.DEFAULT: \"meta-llama/Llama-2-13b-hf\",\n            DownloadSource.MODELSCOPE: \"modelscope/Llama-2-13b-ms\",\n        },\n        \"Llama-2-70B\": {\n            DownloadSource.DEFAULT: \"meta-llama/Llama-2-70b-hf\",\n            DownloadSource.MODELSCOPE: \"modelscope/Llama-2-70b-ms\",\n        },\n        \"Llama-2-7B-Chat\": {\n            DownloadSource.DEFAULT: \"meta-llama/Llama-2-7b-chat-hf\",\n            DownloadSource.MODELSCOPE: \"modelscope/Llama-2-7b-chat-ms\",\n        },\n        
\"Llama-2-13B-Chat\": {\n            DownloadSource.DEFAULT: \"meta-llama/Llama-2-13b-chat-hf\",\n            DownloadSource.MODELSCOPE: \"modelscope/Llama-2-13b-chat-ms\",\n        },\n        \"Llama-2-70B-Chat\": {\n            DownloadSource.DEFAULT: \"meta-llama/Llama-2-70b-chat-hf\",\n            DownloadSource.MODELSCOPE: \"modelscope/Llama-2-70b-chat-ms\",\n        },\n    },\n    template=\"llama2\",\n)\n\n\nregister_model_group(\n    models={\n        \"Llama-3-8B\": {\n            DownloadSource.DEFAULT: \"meta-llama/Meta-Llama-3-8B\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Meta-Llama-3-8B\",\n        },\n        \"Llama-3-70B\": {\n            DownloadSource.DEFAULT: \"meta-llama/Meta-Llama-3-70B\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Meta-Llama-3-70B\",\n        },\n        \"Llama-3-8B-Instruct\": {\n            DownloadSource.DEFAULT: \"meta-llama/Meta-Llama-3-8B-Instruct\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Meta-Llama-3-8B-Instruct\",\n        },\n        \"Llama-3-70B-Instruct\": {\n            DownloadSource.DEFAULT: \"meta-llama/Meta-Llama-3-70B-Instruct\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Meta-Llama-3-70B-Instruct\",\n        },\n        \"Llama-3-8B-Chinese-Chat\": {\n            DownloadSource.DEFAULT: \"shenzhi-wang/Llama3-8B-Chinese-Chat\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Llama3-8B-Chinese-Chat\",\n            DownloadSource.OPENMIND: \"LlamaFactory/Llama3-Chinese-8B-Instruct\",\n        },\n        \"Llama-3-70B-Chinese-Chat\": {\n            DownloadSource.DEFAULT: \"shenzhi-wang/Llama3-70B-Chinese-Chat\",\n        },\n        \"Llama-3.1-8B\": {\n            DownloadSource.DEFAULT: \"meta-llama/Meta-Llama-3.1-8B\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Meta-Llama-3.1-8B\",\n        },\n        \"Llama-3.1-70B\": {\n            DownloadSource.DEFAULT: \"meta-llama/Meta-Llama-3.1-70B\",\n            
DownloadSource.MODELSCOPE: \"LLM-Research/Meta-Llama-3.1-70B\",\n        },\n        \"Llama-3.1-405B\": {\n            DownloadSource.DEFAULT: \"meta-llama/Meta-Llama-3.1-405B\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Meta-Llama-3.1-405B\",\n        },\n        \"Llama-3.1-8B-Instruct\": {\n            DownloadSource.DEFAULT: \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Meta-Llama-3.1-8B-Instruct\",\n        },\n        \"Llama-3.1-70B-Instruct\": {\n            DownloadSource.DEFAULT: \"meta-llama/Meta-Llama-3.1-70B-Instruct\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Meta-Llama-3.1-70B-Instruct\",\n        },\n        \"Llama-3.1-405B-Instruct\": {\n            DownloadSource.DEFAULT: \"meta-llama/Meta-Llama-3.1-405B-Instruct\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Meta-Llama-3.1-405B-Instruct\",\n        },\n        \"Llama-3.1-8B-Chinese-Chat\": {\n            DownloadSource.DEFAULT: \"shenzhi-wang/Llama3.1-8B-Chinese-Chat\",\n            DownloadSource.MODELSCOPE: \"XD_AI/Llama3.1-8B-Chinese-Chat\",\n        },\n        \"Llama-3.1-70B-Chinese-Chat\": {\n            DownloadSource.DEFAULT: \"shenzhi-wang/Llama3.1-70B-Chinese-Chat\",\n            DownloadSource.MODELSCOPE: \"XD_AI/Llama3.1-70B-Chinese-Chat\",\n        },\n        \"Llama-3.2-1B\": {\n            DownloadSource.DEFAULT: \"meta-llama/Llama-3.2-1B\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Llama-3.2-1B\",\n        },\n        \"Llama-3.2-3B\": {\n            DownloadSource.DEFAULT: \"meta-llama/Llama-3.2-3B\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Llama-3.2-3B\",\n        },\n        \"Llama-3.2-1B-Instruct\": {\n            DownloadSource.DEFAULT: \"meta-llama/Llama-3.2-1B-Instruct\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Llama-3.2-1B-Instruct\",\n        },\n        \"Llama-3.2-3B-Instruct\": {\n            DownloadSource.DEFAULT: 
\"meta-llama/Llama-3.2-3B-Instruct\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Llama-3.2-3B-Instruct\",\n        },\n        \"Llama-3.3-70B-Instruct\": {\n            DownloadSource.DEFAULT: \"meta-llama/Llama-3.3-70B-Instruct\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Llama-3.3-70B-Instruct\",\n        },\n    },\n    template=\"llama3\",\n)\n\n\nregister_model_group(\n    models={\n        \"Llama-3.2-11B-Vision\": {\n            DownloadSource.DEFAULT: \"meta-llama/Llama-3.2-11B-Vision\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Llama-3.2-11B-Vision\",\n        },\n        \"Llama-3.2-11B-Vision-Instruct\": {\n            DownloadSource.DEFAULT: \"meta-llama/Llama-3.2-11B-Vision-Instruct\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Llama-3.2-11B-Vision-Instruct\",\n        },\n        \"Llama-3.2-90B-Vision\": {\n            DownloadSource.DEFAULT: \"meta-llama/Llama-3.2-90B-Vision\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Llama-3.2-90B-Vision\",\n        },\n        \"Llama-3.2-90B-Vision-Instruct\": {\n            DownloadSource.DEFAULT: \"meta-llama/Llama-3.2-90B-Vision-Instruct\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Llama-3.2-90B-Vision-Instruct\",\n        },\n    },\n    template=\"mllama\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"Llama-4-Scout-17B-16E\": {\n            DownloadSource.DEFAULT: \"meta-llama/Llama-4-Scout-17B-16E\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Llama-4-Scout-17B-16E\",\n        },\n        \"Llama-4-Scout-17B-16E-Instruct\": {\n            DownloadSource.DEFAULT: \"meta-llama/Llama-4-Scout-17B-16E-Instruct\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Llama-4-Scout-17B-16E-Instruct\",\n        },\n        \"Llama-4-Maverick-17B-128E\": {\n            DownloadSource.DEFAULT: \"meta-llama/Llama-4-Maverick-17B-128E\",\n            DownloadSource.MODELSCOPE: 
\"LLM-Research/Llama-4-Maverick-17B-128E\",\n        },\n        \"Llama-4-Maverick-17B-128E-Instruct\": {\n            DownloadSource.DEFAULT: \"meta-llama/Llama-4-Maverick-17B-128E-Instruct\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Llama-4-Maverick-17B-128E-Instruct\",\n        },\n    },\n    template=\"llama4\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"LLaVA-1.5-7B-Chat\": {\n            DownloadSource.DEFAULT: \"llava-hf/llava-1.5-7b-hf\",\n            DownloadSource.MODELSCOPE: \"swift/llava-1.5-7b-hf\",\n        },\n        \"LLaVA-1.5-13B-Chat\": {\n            DownloadSource.DEFAULT: \"llava-hf/llava-1.5-13b-hf\",\n            DownloadSource.MODELSCOPE: \"swift/llava-1.5-13b-hf\",\n        },\n    },\n    template=\"llava\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"LLaVA-NeXT-7B-Chat\": {\n            DownloadSource.DEFAULT: \"llava-hf/llava-v1.6-vicuna-7b-hf\",\n            DownloadSource.MODELSCOPE: \"swift/llava-v1.6-vicuna-7b-hf\",\n        },\n        \"LLaVA-NeXT-13B-Chat\": {\n            DownloadSource.DEFAULT: \"llava-hf/llava-v1.6-vicuna-13b-hf\",\n            DownloadSource.MODELSCOPE: \"swift/llava-v1.6-vicuna-13b-hf\",\n        },\n    },\n    template=\"llava_next\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"LLaVA-NeXT-Mistral-7B-Chat\": {\n            DownloadSource.DEFAULT: \"llava-hf/llava-v1.6-mistral-7b-hf\",\n            DownloadSource.MODELSCOPE: \"swift/llava-v1.6-mistral-7b-hf\",\n        },\n    },\n    template=\"llava_next_mistral\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"LLaVA-NeXT-Llama3-8B-Chat\": {\n            DownloadSource.DEFAULT: \"llava-hf/llama3-llava-next-8b-hf\",\n            DownloadSource.MODELSCOPE: \"swift/llama3-llava-next-8b-hf\",\n        },\n    },\n    template=\"llava_next_llama3\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n  
      \"LLaVA-NeXT-34B-Chat\": {\n            DownloadSource.DEFAULT: \"llava-hf/llava-v1.6-34b-hf\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/llava-v1.6-34b-hf\",\n        },\n    },\n    template=\"llava_next_yi\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"LLaVA-NeXT-72B-Chat\": {\n            DownloadSource.DEFAULT: \"llava-hf/llava-next-72b-hf\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/llava-next-72b-hf\",\n        },\n        \"LLaVA-NeXT-110B-Chat\": {\n            DownloadSource.DEFAULT: \"llava-hf/llava-next-110b-hf\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/llava-next-110b-hf\",\n        },\n    },\n    template=\"llava_next_qwen\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"LLaVA-NeXT-Video-7B-Chat\": {\n            DownloadSource.DEFAULT: \"llava-hf/LLaVA-NeXT-Video-7B-hf\",\n            DownloadSource.MODELSCOPE: \"swift/LLaVA-NeXT-Video-7B-hf\",\n        },\n        \"LLaVA-NeXT-Video-7B-DPO-Chat\": {\n            DownloadSource.DEFAULT: \"llava-hf/LLaVA-NeXT-Video-7B-DPO-hf\",\n            DownloadSource.MODELSCOPE: \"swift/LLaVA-NeXT-Video-7B-DPO-hf\",\n        },\n    },\n    template=\"llava_next_video\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"LLaVA-NeXT-Video-7B-32k-Chat\": {\n            DownloadSource.DEFAULT: \"llava-hf/LLaVA-NeXT-Video-7B-32K-hf\",\n            DownloadSource.MODELSCOPE: \"swift/LLaVA-NeXT-Video-7B-32K-hf\",\n        },\n    },\n    template=\"llava_next_video_mistral\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"LLaVA-NeXT-Video-34B-Chat\": {\n            DownloadSource.DEFAULT: \"llava-hf/LLaVA-NeXT-Video-34B-hf\",\n            DownloadSource.MODELSCOPE: \"swift/LLaVA-NeXT-Video-34B-hf\",\n        },\n        \"LLaVA-NeXT-Video-34B-DPO-Chat\": {\n            DownloadSource.DEFAULT: \"llava-hf/LLaVA-NeXT-Video-34B-DPO-hf\",\n        },\n    
},\n    template=\"llava_next_video_yi\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"Marco-o1-Chat\": {\n            DownloadSource.DEFAULT: \"AIDC-AI/Marco-o1\",\n            DownloadSource.MODELSCOPE: \"AIDC-AI/Marco-o1\",\n        },\n    },\n    template=\"marco\",\n)\n\n\nregister_model_group(\n    models={\n        \"MiMo-7B-Base\": {\n            DownloadSource.DEFAULT: \"XiaomiMiMo/MiMo-7B-Base\",\n            DownloadSource.MODELSCOPE: \"XiaomiMiMo/MiMo-7B-Base\",\n        },\n        \"MiMo-7B-Instruct\": {\n            DownloadSource.DEFAULT: \"XiaomiMiMo/MiMo-7B-SFT\",\n            DownloadSource.MODELSCOPE: \"XiaomiMiMo/MiMo-7B-SFT\",\n        },\n        \"MiMo-7B-Instruct-RL\": {\n            DownloadSource.DEFAULT: \"XiaomiMiMo/MiMo-7B-RL\",\n            DownloadSource.MODELSCOPE: \"XiaomiMiMo/MiMo-7B-RL\",\n        },\n        \"MiMo-7B-RL-ZERO\": {\n            DownloadSource.DEFAULT: \"XiaomiMiMo/MiMo-7B-RL-ZERO\",\n            DownloadSource.MODELSCOPE: \"XiaomiMiMo/MiMo-7B-RL-ZERO\",\n        },\n    },\n    template=\"mimo\",\n)\n\n\nregister_model_group(\n    models={\n        \"MiMo-7B-VL-Instruct\": {\n            DownloadSource.DEFAULT: \"XiaomiMiMo/MiMo-VL-7B-SFT\",\n            DownloadSource.MODELSCOPE: \"XiaomiMiMo/MiMo-VL-7B-SFT\",\n        },\n        \"MiMo-7B-VL-RL\": {\n            DownloadSource.DEFAULT: \"XiaomiMiMo/MiMo-VL-7B-RL\",\n            DownloadSource.MODELSCOPE: \"XiaomiMiMo/MiMo-VL-7B-RL\",\n        },\n    },\n    template=\"mimo_vl\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"MiniCPM-2B-SFT-Chat\": {\n            DownloadSource.DEFAULT: \"openbmb/MiniCPM-2B-sft-bf16\",\n            DownloadSource.MODELSCOPE: \"OpenBMB/miniCPM-bf16\",\n        },\n        \"MiniCPM-2B-DPO-Chat\": {\n            DownloadSource.DEFAULT: \"openbmb/MiniCPM-2B-dpo-bf16\",\n            DownloadSource.MODELSCOPE: \"OpenBMB/MiniCPM-2B-dpo-bf16\",\n        },\n    },\n    
template=\"cpm\",\n)\n\n\nregister_model_group(\n    models={\n        \"MiniCPM3-4B-Chat\": {\n            DownloadSource.DEFAULT: \"openbmb/MiniCPM3-4B\",\n            DownloadSource.MODELSCOPE: \"OpenBMB/MiniCPM3-4B\",\n            DownloadSource.OPENMIND: \"LlamaFactory/MiniCPM3-4B\",\n        },\n    },\n    template=\"cpm3\",\n)\n\n\nregister_model_group(\n    models={\n        \"MiniCPM4-0.5B-Chat\": {\n            DownloadSource.DEFAULT: \"openbmb/MiniCPM4-0.5B\",\n            DownloadSource.MODELSCOPE: \"OpenBMB/MiniCPM4-0.5B\",\n        },\n        \"MiniCPM4-8B-Chat\": {\n            DownloadSource.DEFAULT: \"openbmb/MiniCPM4-8B\",\n            DownloadSource.MODELSCOPE: \"OpenBMB/MiniCPM4-8B\",\n        },\n    },\n    template=\"cpm4\",\n)\n\n\nregister_model_group(\n    models={\n        \"MiniCPM-o-2_6\": {\n            DownloadSource.DEFAULT: \"openbmb/MiniCPM-o-2_6\",\n            DownloadSource.MODELSCOPE: \"OpenBMB/MiniCPM-o-2_6\",\n        },\n    },\n    template=\"minicpm_o\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"MiniCPM-V-2_6\": {\n            DownloadSource.DEFAULT: \"openbmb/MiniCPM-V-2_6\",\n            DownloadSource.MODELSCOPE: \"OpenBMB/MiniCPM-V-2_6\",\n        },\n    },\n    template=\"minicpm_v\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"MiniCPM-V-4\": {\n            DownloadSource.DEFAULT: \"openbmb/MiniCPM-V-4\",\n            DownloadSource.MODELSCOPE: \"OpenBMB/MiniCPM-V-4\",\n        },\n    },\n    template=\"minicpm_v\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"Ministral-8B-Instruct-2410\": {\n            DownloadSource.DEFAULT: \"mistralai/Ministral-8B-Instruct-2410\",\n            DownloadSource.MODELSCOPE: \"mistralai/Ministral-8B-Instruct-2410\",\n        },\n        \"Mistral-Nemo-Base-2407\": {\n            DownloadSource.DEFAULT: \"mistralai/Mistral-Nemo-Base-2407\",\n            DownloadSource.MODELSCOPE: 
\"LLM-Research/Mistral-Nemo-Base-2407\",\n        },\n        \"Mistral-Nemo-Instruct-2407\": {\n            DownloadSource.DEFAULT: \"mistralai/Mistral-Nemo-Instruct-2407\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/Mistral-Nemo-Instruct-2407\",\n        },\n    },\n    template=\"ministral\",\n)\n\n\nregister_model_group(\n    models={\n        \"Mistral-7B-v0.1\": {\n            DownloadSource.DEFAULT: \"mistralai/Mistral-7B-v0.1\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/Mistral-7B-v0.1\",\n        },\n        \"Mistral-7B-v0.2\": {\n            DownloadSource.DEFAULT: \"alpindale/Mistral-7B-v0.2-hf\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/Mistral-7B-v0.2-hf\",\n        },\n        \"Mistral-7B-v0.3\": {\n            DownloadSource.DEFAULT: \"mistralai/Mistral-7B-v0.3\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/mistral-7b-v0.3\",\n        },\n        \"Mistral-7B-Instruct-v0.1\": {\n            DownloadSource.DEFAULT: \"mistralai/Mistral-7B-Instruct-v0.1\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/Mistral-7B-Instruct-v0.1\",\n        },\n        \"Mistral-7B-Instruct-v0.2\": {\n            DownloadSource.DEFAULT: \"mistralai/Mistral-7B-Instruct-v0.2\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/Mistral-7B-Instruct-v0.2\",\n        },\n        \"Mistral-7B-Instruct-v0.3\": {\n            DownloadSource.DEFAULT: \"mistralai/Mistral-7B-Instruct-v0.3\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Mistral-7B-Instruct-v0.3\",\n        },\n    },\n    template=\"mistral\",\n)\n\n\nregister_model_group(\n    models={\n        \"Mistral-Small-24B-Base-2501\": {\n            DownloadSource.DEFAULT: \"mistralai/Mistral-Small-24B-Base-2501\",\n            DownloadSource.MODELSCOPE: \"mistralai/Mistral-Small-24B-Base-2501\",\n        },\n        \"Mistral-Small-24B-Instruct-2501\": {\n            DownloadSource.DEFAULT: \"mistralai/Mistral-Small-24B-Instruct-2501\",\n        
    DownloadSource.MODELSCOPE: \"mistralai/Mistral-Small-24B-Instruct-2501\",\n        },\n    },\n    template=\"mistral_small\",\n)\n\n\nregister_model_group(\n    models={\n        \"Mistral-Small-3.1-24B-Base\": {\n            DownloadSource.DEFAULT: \"mistralai/Mistral-Small-3.1-24B-Base-2503\",\n            DownloadSource.MODELSCOPE: \"mistralai/Mistral-Small-3.1-24B-Base-2503\",\n        },\n        \"Mistral-Small-3.1-24B-Instruct\": {\n            DownloadSource.DEFAULT: \"mistralai/Mistral-Small-3.1-24B-Instruct-2503\",\n            DownloadSource.MODELSCOPE: \"mistralai/Mistral-Small-3.1-24B-Instruct-2503\",\n        },\n        \"Mistral-Small-3.2-24B-Instruct\": {\n            DownloadSource.DEFAULT: \"mistralai/Mistral-Small-3.2-24B-Instruct-2506\",\n            DownloadSource.MODELSCOPE: \"mistralai/Mistral-Small-3.2-24B-Instruct-2506\",\n        },\n    },\n    template=\"mistral_small\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"Mixtral-8x7B-v0.1\": {\n            DownloadSource.DEFAULT: \"mistralai/Mixtral-8x7B-v0.1\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/Mixtral-8x7B-v0.1\",\n        },\n        \"Mixtral-8x22B-v0.1\": {\n            DownloadSource.DEFAULT: \"mistralai/Mixtral-8x22B-v0.1\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/Mixtral-8x22B-v0.1\",\n        },\n        \"Mixtral-8x7B-v0.1-Instruct\": {\n            DownloadSource.DEFAULT: \"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/Mixtral-8x7B-Instruct-v0.1\",\n        },\n        \"Mixtral-8x22B-v0.1-Instruct\": {\n            DownloadSource.DEFAULT: \"mistralai/Mixtral-8x22B-Instruct-v0.1\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/Mixtral-8x22B-Instruct-v0.1\",\n        },\n    },\n    template=\"mistral\",\n)\n\n\nregister_model_group(\n    models={\n        \"Moonlight-16B-A3B\": {\n            DownloadSource.DEFAULT: 
\"moonshotai/Moonlight-16B-A3B\",\n            DownloadSource.MODELSCOPE: \"moonshotai/Moonlight-16B-A3B\",\n        },\n        \"Moonlight-16B-A3B-Instruct\": {\n            DownloadSource.DEFAULT: \"moonshotai/Moonlight-16B-A3B-Instruct\",\n            DownloadSource.MODELSCOPE: \"moonshotai/Moonlight-16B-A3B-Instruct\",\n        },\n    },\n    template=\"moonlight\",\n)\n\n\nregister_model_group(\n    models={\n        \"OLMo-1B\": {\n            DownloadSource.DEFAULT: \"allenai/OLMo-1B-hf\",\n        },\n        \"OLMo-7B\": {\n            DownloadSource.DEFAULT: \"allenai/OLMo-7B-hf\",\n        },\n        \"OLMo-7B-Chat\": {\n            DownloadSource.DEFAULT: \"ssec-uw/OLMo-7B-Instruct-hf\",\n        },\n        \"OLMo-1.7-7B\": {\n            DownloadSource.DEFAULT: \"allenai/OLMo-1.7-7B-hf\",\n        },\n    },\n)\n\n\nregister_model_group(\n    models={\n        \"OpenChat3.5-7B-Chat\": {\n            DownloadSource.DEFAULT: \"openchat/openchat-3.5-0106\",\n            DownloadSource.MODELSCOPE: \"xcwzxcwz/openchat-3.5-0106\",\n        }\n    },\n    template=\"openchat\",\n)\n\n\nregister_model_group(\n    models={\n        \"OpenChat3.6-8B-Chat\": {\n            DownloadSource.DEFAULT: \"openchat/openchat-3.6-8b-20240522\",\n        }\n    },\n    template=\"openchat-3.6\",\n)\n\n\nregister_model_group(\n    models={\n        \"OpenCoder-1.5B-Base\": {\n            DownloadSource.DEFAULT: \"infly/OpenCoder-1.5B-Base\",\n            DownloadSource.MODELSCOPE: \"infly/OpenCoder-1.5B-Base\",\n        },\n        \"OpenCoder-8B-Base\": {\n            DownloadSource.DEFAULT: \"infly/OpenCoder-8B-Base\",\n            DownloadSource.MODELSCOPE: \"infly/OpenCoder-8B-Base\",\n        },\n        \"OpenCoder-1.5B-Instruct\": {\n            DownloadSource.DEFAULT: \"infly/OpenCoder-1.5B-Instruct\",\n            DownloadSource.MODELSCOPE: \"infly/OpenCoder-1.5B-Instruct\",\n        },\n        \"OpenCoder-8B-Instruct\": {\n            DownloadSource.DEFAULT: 
\"infly/OpenCoder-8B-Instruct\",\n            DownloadSource.MODELSCOPE: \"infly/OpenCoder-8B-Instruct\",\n        },\n    },\n    template=\"opencoder\",\n)\n\n\nregister_model_group(\n    models={\n        \"Orion-14B-Base\": {\n            DownloadSource.DEFAULT: \"OrionStarAI/Orion-14B-Base\",\n            DownloadSource.MODELSCOPE: \"OrionStarAI/Orion-14B-Base\",\n        },\n        \"Orion-14B-Chat\": {\n            DownloadSource.DEFAULT: \"OrionStarAI/Orion-14B-Chat\",\n            DownloadSource.MODELSCOPE: \"OrionStarAI/Orion-14B-Chat\",\n        },\n        \"Orion-14B-Long-Chat\": {\n            DownloadSource.DEFAULT: \"OrionStarAI/Orion-14B-LongChat\",\n            DownloadSource.MODELSCOPE: \"OrionStarAI/Orion-14B-LongChat\",\n        },\n        \"Orion-14B-RAG-Chat\": {\n            DownloadSource.DEFAULT: \"OrionStarAI/Orion-14B-Chat-RAG\",\n            DownloadSource.MODELSCOPE: \"OrionStarAI/Orion-14B-Chat-RAG\",\n        },\n        \"Orion-14B-Plugin-Chat\": {\n            DownloadSource.DEFAULT: \"OrionStarAI/Orion-14B-Chat-Plugin\",\n            DownloadSource.MODELSCOPE: \"OrionStarAI/Orion-14B-Chat-Plugin\",\n        },\n    },\n    template=\"orion\",\n)\n\n\nregister_model_group(\n    models={\n        \"PaliGemma-3B-pt-224\": {\n            DownloadSource.DEFAULT: \"google/paligemma-3b-pt-224\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/paligemma-3b-pt-224\",\n        },\n        \"PaliGemma-3B-pt-448\": {\n            DownloadSource.DEFAULT: \"google/paligemma-3b-pt-448\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/paligemma-3b-pt-448\",\n        },\n        \"PaliGemma-3B-pt-896\": {\n            DownloadSource.DEFAULT: \"google/paligemma-3b-pt-896\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/paligemma-3b-pt-896\",\n        },\n        \"PaliGemma-3B-mix-224\": {\n            DownloadSource.DEFAULT: \"google/paligemma-3b-mix-224\",\n            DownloadSource.MODELSCOPE: 
\"AI-ModelScope/paligemma-3b-mix-224\",\n        },\n        \"PaliGemma-3B-mix-448\": {\n            DownloadSource.DEFAULT: \"google/paligemma-3b-mix-448\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/paligemma-3b-mix-448\",\n        },\n    },\n    template=\"paligemma\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"PaliGemma2-3B-pt-224\": {\n            DownloadSource.DEFAULT: \"google/paligemma2-3b-pt-224\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/paligemma2-3b-pt-224\",\n        },\n        \"PaliGemma2-3B-pt-448\": {\n            DownloadSource.DEFAULT: \"google/paligemma2-3b-pt-448\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/paligemma2-3b-pt-448\",\n        },\n        \"PaliGemma2-3B-pt-896\": {\n            DownloadSource.DEFAULT: \"google/paligemma2-3b-pt-896\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/paligemma2-3b-pt-896\",\n        },\n        \"PaliGemma2-10B-pt-224\": {\n            DownloadSource.DEFAULT: \"google/paligemma2-10b-pt-224\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/paligemma2-10b-pt-224\",\n        },\n        \"PaliGemma2-10B-pt-448\": {\n            DownloadSource.DEFAULT: \"google/paligemma2-10b-pt-448\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/paligemma2-10b-pt-448\",\n        },\n        \"PaliGemma2-10B-pt-896\": {\n            DownloadSource.DEFAULT: \"google/paligemma2-10b-pt-896\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/paligemma2-10b-pt-896\",\n        },\n        \"PaliGemma2-28B-pt-224\": {\n            DownloadSource.DEFAULT: \"google/paligemma2-28b-pt-224\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/paligemma2-28b-pt-224\",\n        },\n        \"PaliGemma2-28B-pt-448\": {\n            DownloadSource.DEFAULT: \"google/paligemma2-28b-pt-448\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/paligemma2-28b-pt-448\",\n        },\n        \"PaliGemma2-28B-pt-896\": 
{\n            DownloadSource.DEFAULT: \"google/paligemma2-28b-pt-896\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/paligemma2-28b-pt-896\",\n        },\n        \"PaliGemma2-3B-mix-224\": {\n            DownloadSource.DEFAULT: \"google/paligemma2-3b-mix-224\",\n            DownloadSource.MODELSCOPE: \"mlx-community/paligemma2-3b-mix-224-bf16\",\n        },\n        \"PaliGemma2-3B-mix-448\": {\n            DownloadSource.DEFAULT: \"google/paligemma2-3b-mix-448\",\n            DownloadSource.MODELSCOPE: \"mlx-community/paligemma2-3b-mix-448-bf16\",\n        },\n        \"PaliGemma2-10B-mix-224\": {\n            DownloadSource.DEFAULT: \"google/paligemma2-10b-mix-224\",\n            DownloadSource.MODELSCOPE: \"mlx-community/paligemma2-10b-mix-224-bf16\",\n        },\n        \"PaliGemma2-10B-mix-448\": {\n            DownloadSource.DEFAULT: \"google/paligemma2-10b-mix-448\",\n            DownloadSource.MODELSCOPE: \"mlx-community/paligemma2-10b-mix-448-bf16\",\n        },\n        \"PaliGemma2-28B-mix-224\": {\n            DownloadSource.DEFAULT: \"google/paligemma2-28b-mix-224\",\n            DownloadSource.MODELSCOPE: \"mlx-community/paligemma2-28b-mix-224-bf16\",\n        },\n        \"PaliGemma2-28B-mix-448\": {\n            DownloadSource.DEFAULT: \"google/paligemma2-28b-mix-448\",\n            DownloadSource.MODELSCOPE: \"mlx-community/paligemma2-28b-mix-448-bf16\",\n        },\n    },\n    template=\"paligemma\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"Phi-1.5-1.3B\": {\n            DownloadSource.DEFAULT: \"microsoft/phi-1_5\",\n            DownloadSource.MODELSCOPE: \"allspace/PHI_1-5\",\n        },\n        \"Phi-2-2.7B\": {\n            DownloadSource.DEFAULT: \"microsoft/phi-2\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/phi-2\",\n        },\n    }\n)\n\n\nregister_model_group(\n    models={\n        \"Phi-3-4B-4k-Instruct\": {\n            DownloadSource.DEFAULT: 
\"microsoft/Phi-3-mini-4k-instruct\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Phi-3-mini-4k-instruct\",\n        },\n        \"Phi-3-4B-128k-Instruct\": {\n            DownloadSource.DEFAULT: \"microsoft/Phi-3-mini-128k-instruct\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Phi-3-mini-128k-instruct\",\n        },\n        \"Phi-3-14B-8k-Instruct\": {\n            DownloadSource.DEFAULT: \"microsoft/Phi-3-medium-4k-instruct\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Phi-3-medium-4k-instruct\",\n        },\n        \"Phi-3-14B-128k-Instruct\": {\n            DownloadSource.DEFAULT: \"microsoft/Phi-3-medium-128k-instruct\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Phi-3-medium-128k-instruct\",\n        },\n        \"Phi-3.5-4B-instruct\": {\n            DownloadSource.DEFAULT: \"microsoft/Phi-3.5-mini-instruct\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Phi-3.5-mini-instruct\",\n        },\n        \"Phi-3.5-MoE-42B-A6.6B-instruct\": {\n            DownloadSource.DEFAULT: \"microsoft/Phi-3.5-MoE-instruct\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Phi-3.5-MoE-instruct\",\n        },\n    },\n    template=\"phi\",\n)\n\n\nregister_model_group(\n    models={\n        \"Phi-3-7B-8k-Instruct\": {\n            DownloadSource.DEFAULT: \"microsoft/Phi-3-small-8k-instruct\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Phi-3-small-8k-instruct\",\n        },\n        \"Phi-3-7B-128k-Instruct\": {\n            DownloadSource.DEFAULT: \"microsoft/Phi-3-small-128k-instruct\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/Phi-3-small-128k-instruct\",\n        },\n    },\n    template=\"phi_small\",\n)\n\n\nregister_model_group(\n    models={\n        \"Phi-4-14B-Instruct\": {\n            DownloadSource.DEFAULT: \"microsoft/phi-4\",\n            DownloadSource.MODELSCOPE: \"LLM-Research/phi-4\",\n        },\n    },\n    template=\"phi4\",\n)\n\n\nregister_model_group(\n    
models={\n        \"Pixtral-12B\": {\n            DownloadSource.DEFAULT: \"mistral-community/pixtral-12b\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/pixtral-12b\",\n        }\n    },\n    template=\"pixtral\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"Qwen-1.8B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen-1_8B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen-1_8B\",\n        },\n        \"Qwen-7B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen-7B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen-7B\",\n        },\n        \"Qwen-14B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen-14B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen-14B\",\n        },\n        \"Qwen-72B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen-72B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen-72B\",\n        },\n        \"Qwen-1.8B-Chat\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen-1_8B-Chat\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen-1_8B-Chat\",\n        },\n        \"Qwen-7B-Chat\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen-7B-Chat\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen-7B-Chat\",\n        },\n        \"Qwen-14B-Chat\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen-14B-Chat\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen-14B-Chat\",\n        },\n        \"Qwen-72B-Chat\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen-72B-Chat\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen-72B-Chat\",\n        },\n        \"Qwen-1.8B-Chat-Int8\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen-1_8B-Chat-Int8\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen-1_8B-Chat-Int8\",\n        },\n        \"Qwen-1.8B-Chat-Int4\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen-1_8B-Chat-Int4\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen-1_8B-Chat-Int4\",\n        },\n        \"Qwen-7B-Chat-Int8\": {\n            
DownloadSource.DEFAULT: \"Qwen/Qwen-7B-Chat-Int8\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen-7B-Chat-Int8\",\n        },\n        \"Qwen-7B-Chat-Int4\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen-7B-Chat-Int4\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen-7B-Chat-Int4\",\n        },\n        \"Qwen-14B-Chat-Int8\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen-14B-Chat-Int8\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen-14B-Chat-Int8\",\n        },\n        \"Qwen-14B-Chat-Int4\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen-14B-Chat-Int4\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen-14B-Chat-Int4\",\n        },\n        \"Qwen-72B-Chat-Int8\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen-72B-Chat-Int8\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen-72B-Chat-Int8\",\n        },\n        \"Qwen-72B-Chat-Int4\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen-72B-Chat-Int4\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen-72B-Chat-Int4\",\n        },\n    },\n    template=\"qwen\",\n)\n\n\nregister_model_group(\n    models={\n        \"Qwen1.5-0.5B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-0.5B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-0.5B\",\n        },\n        \"Qwen1.5-1.8B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-1.8B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-1.8B\",\n        },\n        \"Qwen1.5-4B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-4B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-4B\",\n        },\n        \"Qwen1.5-7B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-7B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-7B\",\n        },\n        \"Qwen1.5-14B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-14B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-14B\",\n        },\n        \"Qwen1.5-32B\": {\n            DownloadSource.DEFAULT: 
\"Qwen/Qwen1.5-32B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-32B\",\n        },\n        \"Qwen1.5-72B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-72B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-72B\",\n        },\n        \"Qwen1.5-110B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-110B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-110B\",\n        },\n        \"Qwen1.5-MoE-A2.7B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-MoE-A2.7B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-MoE-A2.7B\",\n        },\n        \"Qwen1.5-0.5B-Chat\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-0.5B-Chat\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-0.5B-Chat\",\n        },\n        \"Qwen1.5-1.8B-Chat\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-1.8B-Chat\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-1.8B-Chat\",\n        },\n        \"Qwen1.5-4B-Chat\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-4B-Chat\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-4B-Chat\",\n        },\n        \"Qwen1.5-7B-Chat\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-7B-Chat\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-7B-Chat\",\n        },\n        \"Qwen1.5-14B-Chat\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-14B-Chat\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-14B-Chat\",\n        },\n        \"Qwen1.5-32B-Chat\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-32B-Chat\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-32B-Chat\",\n        },\n        \"Qwen1.5-72B-Chat\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-72B-Chat\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-72B-Chat\",\n        },\n        \"Qwen1.5-110B-Chat\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-110B-Chat\",\n            DownloadSource.MODELSCOPE: 
\"Qwen/Qwen1.5-110B-Chat\",\n        },\n        \"Qwen1.5-MoE-A2.7B-Chat\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-MoE-A2.7B-Chat\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-MoE-A2.7B-Chat\",\n        },\n        \"Qwen1.5-0.5B-Chat-GPTQ-Int8\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8\",\n        },\n        \"Qwen1.5-0.5B-Chat-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-0.5B-Chat-AWQ\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-0.5B-Chat-AWQ\",\n        },\n        \"Qwen1.5-1.8B-Chat-GPTQ-Int8\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8\",\n        },\n        \"Qwen1.5-1.8B-Chat-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-1.8B-Chat-AWQ\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-1.8B-Chat-AWQ\",\n        },\n        \"Qwen1.5-4B-Chat-GPTQ-Int8\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-4B-Chat-GPTQ-Int8\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-4B-Chat-GPTQ-Int8\",\n        },\n        \"Qwen1.5-4B-Chat-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-4B-Chat-AWQ\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-4B-Chat-AWQ\",\n        },\n        \"Qwen1.5-7B-Chat-GPTQ-Int8\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-7B-Chat-GPTQ-Int8\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-7B-Chat-GPTQ-Int8\",\n        },\n        \"Qwen1.5-7B-Chat-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-7B-Chat-AWQ\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-7B-Chat-AWQ\",\n        },\n        \"Qwen1.5-14B-Chat-GPTQ-Int8\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-14B-Chat-GPTQ-Int8\",\n            DownloadSource.MODELSCOPE: 
\"Qwen/Qwen1.5-14B-Chat-GPTQ-Int8\",\n        },\n        \"Qwen1.5-14B-Chat-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-14B-Chat-AWQ\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-14B-Chat-AWQ\",\n        },\n        \"Qwen1.5-32B-Chat-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-32B-Chat-AWQ\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-32B-Chat-AWQ\",\n        },\n        \"Qwen1.5-72B-Chat-GPTQ-Int8\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-72B-Chat-GPTQ-Int8\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-72B-Chat-GPTQ-Int8\",\n        },\n        \"Qwen1.5-72B-Chat-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-72B-Chat-AWQ\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-72B-Chat-AWQ\",\n        },\n        \"Qwen1.5-110B-Chat-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-110B-Chat-AWQ\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-110B-Chat-AWQ\",\n        },\n        \"Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4\",\n        },\n        \"CodeQwen1.5-7B\": {\n            DownloadSource.DEFAULT: \"Qwen/CodeQwen1.5-7B\",\n            DownloadSource.MODELSCOPE: \"Qwen/CodeQwen1.5-7B\",\n        },\n        \"CodeQwen1.5-7B-Chat\": {\n            DownloadSource.DEFAULT: \"Qwen/CodeQwen1.5-7B-Chat\",\n            DownloadSource.MODELSCOPE: \"Qwen/CodeQwen1.5-7B-Chat\",\n        },\n        \"CodeQwen1.5-7B-Chat-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/CodeQwen1.5-7B-Chat-AWQ\",\n            DownloadSource.MODELSCOPE: \"Qwen/CodeQwen1.5-7B-Chat-AWQ\",\n        },\n    },\n    template=\"qwen\",\n)\n\n\nregister_model_group(\n    models={\n        \"Qwen2-0.5B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-0.5B\",\n            DownloadSource.MODELSCOPE: 
\"Qwen/Qwen2-0.5B\",\n        },\n        \"Qwen2-1.5B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-1.5B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-1.5B\",\n        },\n        \"Qwen2-7B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-7B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-7B\",\n        },\n        \"Qwen2-72B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-72B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-72B\",\n        },\n        \"Qwen2-MoE-57B-A14B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-57B-A14B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-57B-A14B\",\n        },\n        \"Qwen2-0.5B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-0.5B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-0.5B-Instruct\",\n            DownloadSource.OPENMIND: \"LlamaFactory/Qwen2-0.5B-Instruct\",\n        },\n        \"Qwen2-1.5B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-1.5B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-1.5B-Instruct\",\n            DownloadSource.OPENMIND: \"LlamaFactory/Qwen2-1.5B-Instruct\",\n        },\n        \"Qwen2-7B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-7B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-7B-Instruct\",\n            DownloadSource.OPENMIND: \"LlamaFactory/Qwen2-7B-Instruct\",\n        },\n        \"Qwen2-72B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-72B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-72B-Instruct\",\n        },\n        \"Qwen2-MoE-57B-A14B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-57B-A14B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-57B-A14B-Instruct\",\n        },\n        \"Qwen2-0.5B-Instruct-GPTQ-Int8\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8\",\n            DownloadSource.MODELSCOPE: 
\"Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8\",\n        },\n        \"Qwen2-0.5B-Instruct-GPTQ-Int4\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-0.5B-Instruct-GPTQ-Int4\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-0.5B-Instruct-GPTQ-Int4\",\n        },\n        \"Qwen2-0.5B-Instruct-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-0.5B-Instruct-AWQ\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-0.5B-Instruct-AWQ\",\n        },\n        \"Qwen2-1.5B-Instruct-GPTQ-Int8\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-1.5B-Instruct-GPTQ-Int8\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-1.5B-Instruct-GPTQ-Int8\",\n        },\n        \"Qwen2-1.5B-Instruct-GPTQ-Int4\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4\",\n        },\n        \"Qwen2-1.5B-Instruct-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-1.5B-Instruct-AWQ\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-1.5B-Instruct-AWQ\",\n        },\n        \"Qwen2-7B-Instruct-GPTQ-Int8\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-7B-Instruct-GPTQ-Int8\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-7B-Instruct-GPTQ-Int8\",\n        },\n        \"Qwen2-7B-Instruct-GPTQ-Int4\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-7B-Instruct-GPTQ-Int4\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-7B-Instruct-GPTQ-Int4\",\n        },\n        \"Qwen2-7B-Instruct-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-7B-Instruct-AWQ\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-7B-Instruct-AWQ\",\n        },\n        \"Qwen2-72B-Instruct-GPTQ-Int8\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-72B-Instruct-GPTQ-Int8\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-72B-Instruct-GPTQ-Int8\",\n        },\n        \"Qwen2-72B-Instruct-GPTQ-Int4\": {\n            DownloadSource.DEFAULT: 
\"Qwen/Qwen2-72B-Instruct-GPTQ-Int4\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-72B-Instruct-GPTQ-Int4\",\n        },\n        \"Qwen2-72B-Instruct-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-72B-Instruct-AWQ\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-72B-Instruct-AWQ\",\n        },\n        \"Qwen2-57B-A14B-Instruct-GPTQ-Int4\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4\",\n        },\n        \"Qwen2-Math-1.5B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-Math-1.5B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-Math-1.5B\",\n        },\n        \"Qwen2-Math-7B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-Math-7B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-Math-7B\",\n        },\n        \"Qwen2-Math-72B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-Math-72B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-Math-72B\",\n        },\n        \"Qwen2-Math-1.5B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-Math-1.5B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-Math-1.5B-Instruct\",\n        },\n        \"Qwen2-Math-7B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-Math-7B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-Math-7B-Instruct\",\n        },\n        \"Qwen2-Math-72B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-Math-72B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-Math-72B-Instruct\",\n        },\n    },\n    template=\"qwen\",\n)\n\n\nregister_model_group(\n    models={\n        \"Qwen2.5-0.5B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-0.5B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-0.5B\",\n        },\n        \"Qwen2.5-1.5B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-1.5B\",\n            
DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-1.5B\",\n        },\n        \"Qwen2.5-3B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-3B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-3B\",\n        },\n        \"Qwen2.5-7B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-7B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-7B\",\n        },\n        \"Qwen2.5-14B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-14B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-14B\",\n        },\n        \"Qwen2.5-32B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-32B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-32B\",\n        },\n        \"Qwen2.5-72B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-72B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-72B\",\n        },\n        \"Qwen2.5-0.5B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-0.5B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-0.5B-Instruct\",\n        },\n        \"Qwen2.5-1.5B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-1.5B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-1.5B-Instruct\",\n        },\n        \"Qwen2.5-3B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-3B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-3B-Instruct\",\n        },\n        \"Qwen2.5-7B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-7B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-7B-Instruct\",\n        },\n        \"Qwen2.5-14B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-14B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-14B-Instruct\",\n        },\n        \"Qwen2.5-32B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-32B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-32B-Instruct\",\n        },\n        
\"Qwen2.5-72B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-72B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-72B-Instruct\",\n        },\n        \"Qwen2.5-7B-Instruct-1M\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-7B-Instruct-1M\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-7B-Instruct-1M\",\n        },\n        \"Qwen2.5-14B-Instruct-1M\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-14B-Instruct-1M\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-14B-Instruct-1M\",\n        },\n        \"Qwen2.5-0.5B-Instruct-GPTQ-Int8\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int8\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int8\",\n        },\n        \"Qwen2.5-0.5B-Instruct-GPTQ-Int4\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4\",\n        },\n        \"Qwen2.5-0.5B-Instruct-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-0.5B-Instruct-AWQ\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-0.5B-Instruct-AWQ\",\n        },\n        \"Qwen2.5-1.5B-Instruct-GPTQ-Int8\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int8\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int8\",\n        },\n        \"Qwen2.5-1.5B-Instruct-GPTQ-Int4\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int4\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int4\",\n        },\n        \"Qwen2.5-1.5B-Instruct-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-1.5B-Instruct-AWQ\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-1.5B-Instruct-AWQ\",\n        },\n        \"Qwen2.5-3B-Instruct-GPTQ-Int8\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-3B-Instruct-GPTQ-Int8\",\n            
DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-3B-Instruct-GPTQ-Int8\",\n        },\n        \"Qwen2.5-3B-Instruct-GPTQ-Int4\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-3B-Instruct-GPTQ-Int4\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-3B-Instruct-GPTQ-Int4\",\n        },\n        \"Qwen2.5-3B-Instruct-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-3B-Instruct-AWQ\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-3B-Instruct-AWQ\",\n        },\n        \"Qwen2.5-7B-Instruct-GPTQ-Int8\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-7B-Instruct-GPTQ-Int8\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-7B-Instruct-GPTQ-Int8\",\n        },\n        \"Qwen2.5-7B-Instruct-GPTQ-Int4\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4\",\n        },\n        \"Qwen2.5-7B-Instruct-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-7B-Instruct-AWQ\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-7B-Instruct-AWQ\",\n        },\n        \"Qwen2.5-14B-Instruct-GPTQ-Int8\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-14B-Instruct-GPTQ-Int8\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-14B-Instruct-GPTQ-Int8\",\n        },\n        \"Qwen2.5-14B-Instruct-GPTQ-Int4\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-14B-Instruct-GPTQ-Int4\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-14B-Instruct-GPTQ-Int4\",\n        },\n        \"Qwen2.5-14B-Instruct-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-14B-Instruct-AWQ\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-14B-Instruct-AWQ\",\n        },\n        \"Qwen2.5-32B-Instruct-GPTQ-Int8\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-32B-Instruct-GPTQ-Int8\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-32B-Instruct-GPTQ-Int8\",\n        },\n        
\"Qwen2.5-32B-Instruct-GPTQ-Int4\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-32B-Instruct-GPTQ-Int4\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-32B-Instruct-GPTQ-Int4\",\n        },\n        \"Qwen2.5-32B-Instruct-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-32B-Instruct-AWQ\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-32B-Instruct-AWQ\",\n        },\n        \"Qwen2.5-72B-Instruct-GPTQ-Int8\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-72B-Instruct-GPTQ-Int8\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-72B-Instruct-GPTQ-Int8\",\n        },\n        \"Qwen2.5-72B-Instruct-GPTQ-Int4\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-72B-Instruct-GPTQ-Int4\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-72B-Instruct-GPTQ-Int4\",\n        },\n        \"Qwen2.5-72B-Instruct-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-72B-Instruct-AWQ\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-72B-Instruct-AWQ\",\n        },\n        \"Qwen2.5-Coder-0.5B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-Coder-0.5B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-Coder-0.5B\",\n        },\n        \"Qwen2.5-Coder-1.5B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-Coder-1.5B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-Coder-1.5B\",\n        },\n        \"Qwen2.5-Coder-3B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-Coder-3B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-Coder-3B\",\n        },\n        \"Qwen2.5-Coder-7B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-Coder-7B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-Coder-7B\",\n        },\n        \"Qwen2.5-Coder-14B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-Coder-14B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-Coder-14B\",\n        },\n        \"Qwen2.5-Coder-32B\": {\n            
DownloadSource.DEFAULT: \"Qwen/Qwen2.5-Coder-32B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-Coder-32B\",\n        },\n        \"Qwen2.5-Coder-0.5B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-Coder-0.5B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-Coder-0.5B-Instruct\",\n        },\n        \"Qwen2.5-Coder-1.5B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-Coder-1.5B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-Coder-1.5B-Instruct\",\n        },\n        \"Qwen2.5-Coder-3B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-Coder-3B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-Coder-3B-Instruct\",\n        },\n        \"Qwen2.5-Coder-7B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-Coder-7B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-Coder-7B-Instruct\",\n        },\n        \"Qwen2.5-Coder-14B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-Coder-14B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-Coder-14B-Instruct\",\n        },\n        \"Qwen2.5-Coder-32B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-Coder-32B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-Coder-32B-Instruct\",\n        },\n        \"Qwen2.5-Math-1.5B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-Math-1.5B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-Math-1.5B\",\n        },\n        \"Qwen2.5-Math-7B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-Math-7B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-Math-7B\",\n        },\n        \"Qwen2.5-Math-72B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-Math-72B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-Math-72B\",\n        },\n        \"Qwen2.5-Math-1.5B-Instruct\": {\n            DownloadSource.DEFAULT: 
\"Qwen/Qwen2.5-Math-1.5B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-Math-1.5B-Instruct\",\n        },\n        \"Qwen2.5-Math-7B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-Math-7B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-Math-7B-Instruct\",\n        },\n        \"Qwen2.5-Math-72B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-Math-72B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-Math-72B-Instruct\",\n        },\n        \"QwQ-32B-Preview-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/QwQ-32B-Preview\",\n            DownloadSource.MODELSCOPE: \"Qwen/QwQ-32B-Preview\",\n        },\n        \"QwQ-32B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/QwQ-32B\",\n            DownloadSource.MODELSCOPE: \"Qwen/QwQ-32B\",\n        },\n    },\n    template=\"qwen\",\n)\n\n\nregister_model_group(\n    models={\n        \"Qwen3-0.6B-Base\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen3-0.6B-Base\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen3-0.6B-Base\",\n        },\n        \"Qwen3-1.7B-Base\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen3-1.7B-Base\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen3-1.7B-Base\",\n        },\n        \"Qwen3-4B-Base\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen3-4B-Base\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen3-4B-Base\",\n        },\n        \"Qwen3-8B-Base\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen3-8B-Base\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen3-8B-Base\",\n        },\n        \"Qwen3-14B-Base\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen3-14B-Base\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen3-14B-Base\",\n        },\n        \"Qwen3-30B-A3B-Base\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen3-30B-A3B-Base\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen3-30B-A3B-Base\",\n        },\n        
\"Qwen3-0.6B-Thinking\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen3-0.6B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen3-0.6B\",\n        },\n        \"Qwen3-1.7B-Thinking\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen3-1.7B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen3-1.7B\",\n        },\n        \"Qwen3-4B-Thinking\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen3-4B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen3-4B\",\n        },\n        \"Qwen3-4B-Thinking-2507\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen3-4B-Thinking-2507\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen3-4B-Thinking-2507\",\n        },\n        \"Qwen3-8B-Thinking\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen3-8B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen3-8B\",\n        },\n        \"Qwen3-14B-Thinking\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen3-14B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen3-14B\",\n        },\n        \"Qwen3-32B-Thinking\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen3-32B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen3-32B\",\n        },\n        \"Qwen3-30B-A3B-Thinking\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen3-30B-A3B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen3-30B-A3B\",\n        },\n        \"Qwen3-30B-A3B-Thinking-2507\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen3-30B-A3B-Thinking-2507\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen3-30B-A3B-Thinking-2507\",\n        },\n        \"Qwen3-235B-A22B-Thinking\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen3-235B-A22B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen3-235B-A22B\",\n        },\n        \"Qwen3-235B-A22B-Thinking-2507\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen3-235B-A22B-Thinking-2507\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen3-235B-A22B-Thinking-2507\",\n        },\n        \"Qwen3-0.6B-Thinking-GPTQ-Int8\": {\n  
          DownloadSource.DEFAULT: \"Qwen/Qwen3-0.6B-GPTQ-Int8\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen3-0.6B-GPTQ-Int8\",\n        },\n        \"Qwen3-1.7B-Thinking-GPTQ-Int8\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen3-1.7B-GPTQ-Int8\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen3-1.7B-GPTQ-Int8\",\n        },\n        \"Qwen3-4B-Thinking-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen3-4B-AWQ\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen3-4B-AWQ\",\n        },\n        \"Qwen3-8B-Thinking-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen3-8B-AWQ\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen3-8B-AWQ\",\n        },\n        \"Qwen3-14B-Thinking-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen3-14B-AWQ\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen3-14B-AWQ\",\n        },\n        \"Qwen3-32B-Thinking-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen3-32B-AWQ\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen3-32B-AWQ\",\n        },\n        \"Qwen3-30B-A3B-Thinking-GPTQ-Int4\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen3-30B-A3B-GPTQ-Int4\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen3-30B-A3B-GPTQ-Int4\",\n        },\n        \"Qwen3-235B-A22B-Thinking-GPTQ-Int4\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen3-235B-A22B-GPTQ-Int4\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen3-235B-A22B-GPTQ-Int4\",\n        },\n    },\n    template=\"qwen3\",\n)\n\n\nregister_model_group(\n    models={\n        \"Qwen3-4B-Instruct-2507\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen3-4B-Instruct-2507\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen3-4B-Instruct-2507\",\n        },\n        \"Qwen3-30B-A3B-Instruct-2507\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen3-30B-A3B-Instruct-2507\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen3-30B-A3B-Instruct-2507\",\n        },\n        \"Qwen3-235B-A22B-Instruct-2507\": {\n            
DownloadSource.DEFAULT: \"Qwen/Qwen3-235B-A22B-Instruct-2507\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen3-235B-A22B-Instruct-2507\",\n        },\n    },\n    template=\"qwen3_nothink\",\n)\n\n\nregister_model_group(\n    models={\n        \"Qwen2-Audio-7B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-Audio-7B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-Audio-7B\",\n        },\n        \"Qwen2-Audio-7B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-Audio-7B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-Audio-7B-Instruct\",\n        },\n    },\n    template=\"qwen2_audio\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"Qwen2.5-Omni-3B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-Omni-3B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-Omni-3B\",\n        },\n        \"Qwen2.5-Omni-7B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-Omni-7B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-Omni-7B\",\n        },\n        \"Qwen2.5-Omni-7B-GPTQ-Int4\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-Omni-7B-GPTQ-Int4\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-Omni-7B-GPTQ-Int4\",\n        },\n        \"Qwen2.5-Omni-7B-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-Omni-7B-AWQ\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-Omni-7B-AWQ\",\n        },\n    },\n    template=\"qwen2_omni\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"Qwen2-VL-2B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-VL-2B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-VL-2B\",\n        },\n        \"Qwen2-VL-7B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-VL-7B\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-VL-7B\",\n        },\n        \"Qwen2-VL-72B\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-VL-72B\",\n            
DownloadSource.MODELSCOPE: \"Qwen/Qwen2-VL-72B\",\n        },\n        \"Qwen2-VL-2B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-VL-2B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-VL-2B-Instruct\",\n            DownloadSource.OPENMIND: \"LlamaFactory/Qwen2-VL-2B-Instruct\",\n        },\n        \"Qwen2-VL-7B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-VL-7B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-VL-7B-Instruct\",\n            DownloadSource.OPENMIND: \"LlamaFactory/Qwen2-VL-7B-Instruct\",\n        },\n        \"Qwen2-VL-72B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-VL-72B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-VL-72B-Instruct\",\n        },\n        \"Qwen2-VL-2B-Instruct-GPTQ-Int8\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8\",\n        },\n        \"Qwen2-VL-2B-Instruct-GPTQ-Int4\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4\",\n        },\n        \"Qwen2-VL-2B-Instruct-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-VL-2B-Instruct-AWQ\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-VL-2B-Instruct-AWQ\",\n        },\n        \"Qwen2-VL-7B-Instruct-GPTQ-Int8\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8\",\n        },\n        \"Qwen2-VL-7B-Instruct-GPTQ-Int4\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4\",\n        },\n        \"Qwen2-VL-7B-Instruct-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-VL-7B-Instruct-AWQ\",\n            
DownloadSource.MODELSCOPE: \"Qwen/Qwen2-VL-7B-Instruct-AWQ\",\n        },\n        \"Qwen2-VL-72B-Instruct-GPTQ-Int8\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8\",\n        },\n        \"Qwen2-VL-72B-Instruct-GPTQ-Int4\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4\",\n        },\n        \"Qwen2-VL-72B-Instruct-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2-VL-72B-Instruct-AWQ\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2-VL-72B-Instruct-AWQ\",\n        },\n        \"QVQ-72B-Preview\": {\n            DownloadSource.DEFAULT: \"Qwen/QVQ-72B-Preview\",\n            DownloadSource.MODELSCOPE: \"Qwen/QVQ-72B-Preview\",\n        },\n        \"Qwen2.5-VL-3B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-VL-3B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-VL-3B-Instruct\",\n        },\n        \"Qwen2.5-VL-7B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-VL-7B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-VL-7B-Instruct\",\n        },\n        \"Qwen2.5-VL-32B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-VL-32B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-VL-32B-Instruct\",\n        },\n        \"Qwen2.5-VL-72B-Instruct\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-VL-72B-Instruct\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-VL-72B-Instruct\",\n        },\n        \"Qwen2.5-VL-3B-Instruct-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-VL-3B-Instruct-AWQ\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-VL-3B-Instruct-AWQ\",\n        },\n        \"Qwen2.5-VL-7B-Instruct-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-VL-7B-Instruct-AWQ\",\n  
          DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-VL-7B-Instruct-AWQ\",\n        },\n        \"Qwen2.5-VL-72B-Instruct-AWQ\": {\n            DownloadSource.DEFAULT: \"Qwen/Qwen2.5-VL-72B-Instruct-AWQ\",\n            DownloadSource.MODELSCOPE: \"Qwen/Qwen2.5-VL-72B-Instruct-AWQ\",\n        },\n    },\n    template=\"qwen2_vl\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"Seed-Coder-8B-Base\": {\n            DownloadSource.DEFAULT: \"ByteDance-Seed/Seed-Coder-8B-Base\",\n        },\n        \"Seed-Coder-8B-Instruct\": {\n            DownloadSource.DEFAULT: \"ByteDance-Seed/Seed-Coder-8B-Instruct\",\n        },\n        \"Seed-Coder-8B-Instruct-Reasoning\": {\n            DownloadSource.DEFAULT: \"ByteDance-Seed/Seed-Coder-8B-Reasoning-bf16\",\n        },\n    },\n    template=\"seed_coder\",\n)\n\n\nregister_model_group(\n    models={\n        \"Skywork-13B-Base\": {\n            DownloadSource.DEFAULT: \"Skywork/Skywork-13B-base\",\n            DownloadSource.MODELSCOPE: \"skywork/Skywork-13B-base\",\n        }\n    }\n)\n\n\nregister_model_group(\n    models={\n        \"Skywork-o1-Open-Llama-3.1-8B\": {\n            DownloadSource.DEFAULT: \"Skywork/Skywork-o1-Open-Llama-3.1-8B\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/Skywork-o1-Open-Llama-3.1-8B\",\n        }\n    },\n    template=\"skywork_o1\",\n)\n\n\nregister_model_group(\n    models={\n        \"SmolLM-135M\": {\n            DownloadSource.DEFAULT: \"HuggingFaceTB/SmolLM-135M\",\n            DownloadSource.MODELSCOPE: \"HuggingFaceTB/SmolLM-135M\",\n        },\n        \"SmolLM-360M\": {\n            DownloadSource.DEFAULT: \"HuggingFaceTB/SmolLM-360M\",\n            DownloadSource.MODELSCOPE: \"HuggingFaceTB/SmolLM-360M\",\n        },\n        \"SmolLM-1.7B\": {\n            DownloadSource.DEFAULT: \"HuggingFaceTB/SmolLM-1.7B\",\n            DownloadSource.MODELSCOPE: \"HuggingFaceTB/SmolLM-1.7B\",\n        },\n        \"SmolLM-135M-Instruct\": {\n        
    DownloadSource.DEFAULT: \"HuggingFaceTB/SmolLM-135M-Instruct\",\n            DownloadSource.MODELSCOPE: \"HuggingFaceTB/SmolLM-135M-Instruct\",\n        },\n        \"SmolLM-360M-Instruct\": {\n            DownloadSource.DEFAULT: \"HuggingFaceTB/SmolLM-360M-Instruct\",\n            DownloadSource.MODELSCOPE: \"HuggingFaceTB/SmolLM-360M-Instruct\",\n        },\n        \"SmolLM-1.7B-Instruct\": {\n            DownloadSource.DEFAULT: \"HuggingFaceTB/SmolLM-1.7B-Instruct\",\n            DownloadSource.MODELSCOPE: \"HuggingFaceTB/SmolLM-1.7B-Instruct\",\n        },\n    },\n    template=\"smollm\",\n)\n\n\nregister_model_group(\n    models={\n        \"SmolLM2-135M\": {\n            DownloadSource.DEFAULT: \"HuggingFaceTB/SmolLM2-135M\",\n            DownloadSource.MODELSCOPE: \"HuggingFaceTB/SmolLM2-135M\",\n        },\n        \"SmolLM2-360M\": {\n            DownloadSource.DEFAULT: \"HuggingFaceTB/SmolLM2-360M\",\n            DownloadSource.MODELSCOPE: \"HuggingFaceTB/SmolLM2-360M\",\n        },\n        \"SmolLM2-1.7B\": {\n            DownloadSource.DEFAULT: \"HuggingFaceTB/SmolLM2-1.7B\",\n            DownloadSource.MODELSCOPE: \"HuggingFaceTB/SmolLM2-1.7B\",\n        },\n        \"SmolLM2-135M-Instruct\": {\n            DownloadSource.DEFAULT: \"HuggingFaceTB/SmolLM2-135M-Instruct\",\n            DownloadSource.MODELSCOPE: \"HuggingFaceTB/SmolLM2-135M-Instruct\",\n        },\n        \"SmolLM2-360M-Instruct\": {\n            DownloadSource.DEFAULT: \"HuggingFaceTB/SmolLM2-360M-Instruct\",\n            DownloadSource.MODELSCOPE: \"HuggingFaceTB/SmolLM2-360M-Instruct\",\n        },\n        \"SmolLM2-1.7B-Instruct\": {\n            DownloadSource.DEFAULT: \"HuggingFaceTB/SmolLM2-1.7B-Instruct\",\n            DownloadSource.MODELSCOPE: \"HuggingFaceTB/SmolLM2-1.7B-Instruct\",\n        },\n    },\n    template=\"smollm2\",\n)\n\n\nregister_model_group(\n    models={\n        \"SOLAR-10.7B-v1.0\": {\n            DownloadSource.DEFAULT: 
\"upstage/SOLAR-10.7B-v1.0\",\n        },\n        \"SOLAR-10.7B-Instruct-v1.0\": {\n            DownloadSource.DEFAULT: \"upstage/SOLAR-10.7B-Instruct-v1.0\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/SOLAR-10.7B-Instruct-v1.0\",\n        },\n    },\n    template=\"solar\",\n)\n\n\nregister_model_group(\n    models={\n        \"StarCoder2-3B\": {\n            DownloadSource.DEFAULT: \"bigcode/starcoder2-3b\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/starcoder2-3b\",\n        },\n        \"StarCoder2-7B\": {\n            DownloadSource.DEFAULT: \"bigcode/starcoder2-7b\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/starcoder2-7b\",\n        },\n        \"StarCoder2-15B\": {\n            DownloadSource.DEFAULT: \"bigcode/starcoder2-15b\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/starcoder2-15b\",\n        },\n    }\n)\n\n\nregister_model_group(\n    models={\n        \"TeleChat-1B-Chat\": {\n            DownloadSource.DEFAULT: \"Tele-AI/TeleChat-1B\",\n            DownloadSource.MODELSCOPE: \"TeleAI/TeleChat-1B\",\n        },\n        \"TeleChat-7B-Chat\": {\n            DownloadSource.DEFAULT: \"Tele-AI/telechat-7B\",\n            DownloadSource.MODELSCOPE: \"TeleAI/telechat-7B\",\n            DownloadSource.OPENMIND: \"TeleAI/TeleChat-7B-pt\",\n        },\n        \"TeleChat-12B-Chat\": {\n            DownloadSource.DEFAULT: \"Tele-AI/TeleChat-12B-v2\",\n            DownloadSource.MODELSCOPE: \"TeleAI/TeleChat-12B-v2\",\n            DownloadSource.OPENMIND: \"TeleAI/TeleChat-12B-pt\",\n        },\n        \"TeleChat-52B-Chat\": {\n            DownloadSource.DEFAULT: \"Tele-AI/TeleChat-52B\",\n        },\n    },\n    template=\"telechat\",\n)\n\n\nregister_model_group(\n    models={\n        \"TeleChat2-3B-Chat\": {\n            DownloadSource.DEFAULT: \"Tele-AI/TeleChat2-3B\",\n            DownloadSource.MODELSCOPE: \"TeleAI/TeleChat2-3B\",\n        },\n        \"TeleChat2-7B-Chat\": {\n            
DownloadSource.DEFAULT: \"Tele-AI/TeleChat2-7B\",\n            DownloadSource.MODELSCOPE: \"TeleAI/TeleChat2-7B\",\n        },\n        \"TeleChat2-35B-Chat\": {\n            DownloadSource.MODELSCOPE: \"TeleAI/TeleChat2-35B-Nov\",\n        },\n        \"TeleChat2-115B-Chat\": {\n            DownloadSource.DEFAULT: \"Tele-AI/TeleChat2-115B\",\n            DownloadSource.MODELSCOPE: \"TeleAI/TeleChat2-115B\",\n        },\n    },\n    template=\"telechat2\",\n)\n\n\nregister_model_group(\n    models={\n        \"Vicuna-v1.5-7B-Chat\": {\n            DownloadSource.DEFAULT: \"lmsys/vicuna-7b-v1.5\",\n            DownloadSource.MODELSCOPE: \"Xorbits/vicuna-7b-v1.5\",\n        },\n        \"Vicuna-v1.5-13B-Chat\": {\n            DownloadSource.DEFAULT: \"lmsys/vicuna-13b-v1.5\",\n            DownloadSource.MODELSCOPE: \"Xorbits/vicuna-13b-v1.5\",\n        },\n    },\n    template=\"vicuna\",\n)\n\n\nregister_model_group(\n    models={\n        \"Video-LLaVA-7B-Chat\": {\n            DownloadSource.DEFAULT: \"LanguageBind/Video-LLaVA-7B-hf\",\n        },\n    },\n    template=\"video_llava\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"XuanYuan-6B\": {\n            DownloadSource.DEFAULT: \"Duxiaoman-DI/XuanYuan-6B\",\n            DownloadSource.MODELSCOPE: \"Duxiaoman-DI/XuanYuan-6B\",\n        },\n        \"XuanYuan-70B\": {\n            DownloadSource.DEFAULT: \"Duxiaoman-DI/XuanYuan-70B\",\n            DownloadSource.MODELSCOPE: \"Duxiaoman-DI/XuanYuan-70B\",\n        },\n        \"XuanYuan2-70B\": {\n            DownloadSource.DEFAULT: \"Duxiaoman-DI/XuanYuan2-70B\",\n            DownloadSource.MODELSCOPE: \"Duxiaoman-DI/XuanYuan2-70B\",\n        },\n        \"XuanYuan-6B-Chat\": {\n            DownloadSource.DEFAULT: \"Duxiaoman-DI/XuanYuan-6B-Chat\",\n            DownloadSource.MODELSCOPE: \"Duxiaoman-DI/XuanYuan-6B-Chat\",\n        },\n        \"XuanYuan-70B-Chat\": {\n            DownloadSource.DEFAULT: 
\"Duxiaoman-DI/XuanYuan-70B-Chat\",\n            DownloadSource.MODELSCOPE: \"Duxiaoman-DI/XuanYuan-70B-Chat\",\n        },\n        \"XuanYuan2-70B-Chat\": {\n            DownloadSource.DEFAULT: \"Duxiaoman-DI/XuanYuan2-70B-Chat\",\n            DownloadSource.MODELSCOPE: \"Duxiaoman-DI/XuanYuan2-70B-Chat\",\n        },\n        \"XuanYuan-6B-Chat-8bit\": {\n            DownloadSource.DEFAULT: \"Duxiaoman-DI/XuanYuan-6B-Chat-8bit\",\n            DownloadSource.MODELSCOPE: \"Duxiaoman-DI/XuanYuan-6B-Chat-8bit\",\n        },\n        \"XuanYuan-6B-Chat-4bit\": {\n            DownloadSource.DEFAULT: \"Duxiaoman-DI/XuanYuan-6B-Chat-4bit\",\n            DownloadSource.MODELSCOPE: \"Duxiaoman-DI/XuanYuan-6B-Chat-4bit\",\n        },\n        \"XuanYuan-70B-Chat-8bit\": {\n            DownloadSource.DEFAULT: \"Duxiaoman-DI/XuanYuan-70B-Chat-8bit\",\n            DownloadSource.MODELSCOPE: \"Duxiaoman-DI/XuanYuan-70B-Chat-8bit\",\n        },\n        \"XuanYuan-70B-Chat-4bit\": {\n            DownloadSource.DEFAULT: \"Duxiaoman-DI/XuanYuan-70B-Chat-4bit\",\n            DownloadSource.MODELSCOPE: \"Duxiaoman-DI/XuanYuan-70B-Chat-4bit\",\n        },\n        \"XuanYuan2-70B-Chat-8bit\": {\n            DownloadSource.DEFAULT: \"Duxiaoman-DI/XuanYuan2-70B-Chat-8bit\",\n            DownloadSource.MODELSCOPE: \"Duxiaoman-DI/XuanYuan2-70B-Chat-8bit\",\n        },\n        \"XuanYuan2-70B-Chat-4bit\": {\n            DownloadSource.DEFAULT: \"Duxiaoman-DI/XuanYuan2-70B-Chat-4bit\",\n            DownloadSource.MODELSCOPE: \"Duxiaoman-DI/XuanYuan2-70B-Chat-4bit\",\n        },\n    },\n    template=\"xuanyuan\",\n)\n\n\nregister_model_group(\n    models={\n        \"XVERSE-7B\": {\n            DownloadSource.DEFAULT: \"xverse/XVERSE-7B\",\n            DownloadSource.MODELSCOPE: \"xverse/XVERSE-7B\",\n        },\n        \"XVERSE-13B\": {\n            DownloadSource.DEFAULT: \"xverse/XVERSE-13B\",\n            DownloadSource.MODELSCOPE: \"xverse/XVERSE-13B\",\n        },\n        
\"XVERSE-65B\": {\n            DownloadSource.DEFAULT: \"xverse/XVERSE-65B\",\n            DownloadSource.MODELSCOPE: \"xverse/XVERSE-65B\",\n        },\n        \"XVERSE-65B-2\": {\n            DownloadSource.DEFAULT: \"xverse/XVERSE-65B-2\",\n            DownloadSource.MODELSCOPE: \"xverse/XVERSE-65B-2\",\n        },\n        \"XVERSE-7B-Chat\": {\n            DownloadSource.DEFAULT: \"xverse/XVERSE-7B-Chat\",\n            DownloadSource.MODELSCOPE: \"xverse/XVERSE-7B-Chat\",\n        },\n        \"XVERSE-13B-Chat\": {\n            DownloadSource.DEFAULT: \"xverse/XVERSE-13B-Chat\",\n            DownloadSource.MODELSCOPE: \"xverse/XVERSE-13B-Chat\",\n        },\n        \"XVERSE-65B-Chat\": {\n            DownloadSource.DEFAULT: \"xverse/XVERSE-65B-Chat\",\n            DownloadSource.MODELSCOPE: \"xverse/XVERSE-65B-Chat\",\n        },\n        \"XVERSE-MoE-A4.2B\": {\n            DownloadSource.DEFAULT: \"xverse/XVERSE-MoE-A4.2B\",\n            DownloadSource.MODELSCOPE: \"xverse/XVERSE-MoE-A4.2B\",\n        },\n        \"XVERSE-7B-Chat-GPTQ-Int8\": {\n            DownloadSource.DEFAULT: \"xverse/XVERSE-7B-Chat-GPTQ-Int8\",\n            DownloadSource.MODELSCOPE: \"xverse/XVERSE-7B-Chat-GPTQ-Int8\",\n        },\n        \"XVERSE-7B-Chat-GPTQ-Int4\": {\n            DownloadSource.DEFAULT: \"xverse/XVERSE-7B-Chat-GPTQ-Int4\",\n            DownloadSource.MODELSCOPE: \"xverse/XVERSE-7B-Chat-GPTQ-Int4\",\n        },\n        \"XVERSE-13B-Chat-GPTQ-Int8\": {\n            DownloadSource.DEFAULT: \"xverse/XVERSE-13B-Chat-GPTQ-Int8\",\n            DownloadSource.MODELSCOPE: \"xverse/XVERSE-13B-Chat-GPTQ-Int8\",\n        },\n        \"XVERSE-13B-Chat-GPTQ-Int4\": {\n            DownloadSource.DEFAULT: \"xverse/XVERSE-13B-Chat-GPTQ-Int4\",\n            DownloadSource.MODELSCOPE: \"xverse/XVERSE-13B-Chat-GPTQ-Int4\",\n        },\n        \"XVERSE-65B-Chat-GPTQ-Int4\": {\n            DownloadSource.DEFAULT: \"xverse/XVERSE-65B-Chat-GPTQ-Int4\",\n            
DownloadSource.MODELSCOPE: \"xverse/XVERSE-65B-Chat-GPTQ-Int4\",\n        },\n    },\n    template=\"xverse\",\n)\n\n\nregister_model_group(\n    models={\n        \"Yayi-7B\": {\n            DownloadSource.DEFAULT: \"wenge-research/yayi-7b-llama2\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/yayi-7b-llama2\",\n        },\n        \"Yayi-13B\": {\n            DownloadSource.DEFAULT: \"wenge-research/yayi-13b-llama2\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/yayi-13b-llama2\",\n        },\n    },\n    template=\"yayi\",\n)\n\n\nregister_model_group(\n    models={\n        \"Yi-6B\": {\n            DownloadSource.DEFAULT: \"01-ai/Yi-6B\",\n            DownloadSource.MODELSCOPE: \"01ai/Yi-6B\",\n        },\n        \"Yi-9B\": {\n            DownloadSource.DEFAULT: \"01-ai/Yi-9B\",\n            DownloadSource.MODELSCOPE: \"01ai/Yi-9B\",\n        },\n        \"Yi-34B\": {\n            DownloadSource.DEFAULT: \"01-ai/Yi-34B\",\n            DownloadSource.MODELSCOPE: \"01ai/Yi-34B\",\n        },\n        \"Yi-6B-Chat\": {\n            DownloadSource.DEFAULT: \"01-ai/Yi-6B-Chat\",\n            DownloadSource.MODELSCOPE: \"01ai/Yi-6B-Chat\",\n        },\n        \"Yi-34B-Chat\": {\n            DownloadSource.DEFAULT: \"01-ai/Yi-34B-Chat\",\n            DownloadSource.MODELSCOPE: \"01ai/Yi-34B-Chat\",\n        },\n        \"Yi-6B-Chat-8bits\": {\n            DownloadSource.DEFAULT: \"01-ai/Yi-6B-Chat-8bits\",\n            DownloadSource.MODELSCOPE: \"01ai/Yi-6B-Chat-8bits\",\n        },\n        \"Yi-6B-Chat-4bits\": {\n            DownloadSource.DEFAULT: \"01-ai/Yi-6B-Chat-4bits\",\n            DownloadSource.MODELSCOPE: \"01ai/Yi-6B-Chat-4bits\",\n        },\n        \"Yi-34B-Chat-8bits\": {\n            DownloadSource.DEFAULT: \"01-ai/Yi-34B-Chat-8bits\",\n            DownloadSource.MODELSCOPE: \"01ai/Yi-34B-Chat-8bits\",\n        },\n        \"Yi-34B-Chat-4bits\": {\n            DownloadSource.DEFAULT: \"01-ai/Yi-34B-Chat-4bits\",\n           
 DownloadSource.MODELSCOPE: \"01ai/Yi-34B-Chat-4bits\",\n        },\n        \"Yi-1.5-6B\": {\n            DownloadSource.DEFAULT: \"01-ai/Yi-1.5-6B\",\n            DownloadSource.MODELSCOPE: \"01ai/Yi-1.5-6B\",\n        },\n        \"Yi-1.5-9B\": {\n            DownloadSource.DEFAULT: \"01-ai/Yi-1.5-9B\",\n            DownloadSource.MODELSCOPE: \"01ai/Yi-1.5-9B\",\n        },\n        \"Yi-1.5-34B\": {\n            DownloadSource.DEFAULT: \"01-ai/Yi-1.5-34B\",\n            DownloadSource.MODELSCOPE: \"01ai/Yi-1.5-34B\",\n        },\n        \"Yi-1.5-6B-Chat\": {\n            DownloadSource.DEFAULT: \"01-ai/Yi-1.5-6B-Chat\",\n            DownloadSource.MODELSCOPE: \"01ai/Yi-1.5-6B-Chat\",\n            DownloadSource.OPENMIND: \"LlamaFactory/Yi-1.5-6B-Chat\",\n        },\n        \"Yi-1.5-9B-Chat\": {\n            DownloadSource.DEFAULT: \"01-ai/Yi-1.5-9B-Chat\",\n            DownloadSource.MODELSCOPE: \"01ai/Yi-1.5-9B-Chat\",\n        },\n        \"Yi-1.5-34B-Chat\": {\n            DownloadSource.DEFAULT: \"01-ai/Yi-1.5-34B-Chat\",\n            DownloadSource.MODELSCOPE: \"01ai/Yi-1.5-34B-Chat\",\n        },\n        \"Yi-Coder-1.5B\": {\n            DownloadSource.DEFAULT: \"01-ai/Yi-Coder-1.5B\",\n            DownloadSource.MODELSCOPE: \"01ai/Yi-Coder-1.5B\",\n        },\n        \"Yi-Coder-9B\": {\n            DownloadSource.DEFAULT: \"01-ai/Yi-Coder-9B\",\n            DownloadSource.MODELSCOPE: \"01ai/Yi-Coder-9B\",\n        },\n        \"Yi-Coder-1.5B-Chat\": {\n            DownloadSource.DEFAULT: \"01-ai/Yi-Coder-1.5B-Chat\",\n            DownloadSource.MODELSCOPE: \"01ai/Yi-Coder-1.5B-Chat\",\n        },\n        \"Yi-Coder-9B-Chat\": {\n            DownloadSource.DEFAULT: \"01-ai/Yi-Coder-9B-Chat\",\n            DownloadSource.MODELSCOPE: \"01ai/Yi-Coder-9B-Chat\",\n        },\n    },\n    template=\"yi\",\n)\n\n\nregister_model_group(\n    models={\n        \"Yi-VL-6B-Chat\": {\n            DownloadSource.DEFAULT: \"BUAADreamer/Yi-VL-6B-hf\",\n        },\n 
       \"Yi-VL-34B-Chat\": {\n            DownloadSource.DEFAULT: \"BUAADreamer/Yi-VL-34B-hf\",\n        },\n    },\n    template=\"yi_vl\",\n    multimodal=True,\n)\n\n\nregister_model_group(\n    models={\n        \"Yuan2-2B-Chat\": {\n            DownloadSource.DEFAULT: \"IEITYuan/Yuan2-2B-hf\",\n            DownloadSource.MODELSCOPE: \"YuanLLM/Yuan2.0-2B-hf\",\n        },\n        \"Yuan2-51B-Chat\": {\n            DownloadSource.DEFAULT: \"IEITYuan/Yuan2-51B-hf\",\n            DownloadSource.MODELSCOPE: \"YuanLLM/Yuan2.0-51B-hf\",\n        },\n        \"Yuan2-102B-Chat\": {\n            DownloadSource.DEFAULT: \"IEITYuan/Yuan2-102B-hf\",\n            DownloadSource.MODELSCOPE: \"YuanLLM/Yuan2.0-102B-hf\",\n        },\n    },\n    template=\"yuan\",\n)\n\n\nregister_model_group(\n    models={\n        \"Zephyr-7B-Alpha-Chat\": {\n            DownloadSource.DEFAULT: \"HuggingFaceH4/zephyr-7b-alpha\",\n            DownloadSource.MODELSCOPE: \"AI-ModelScope/zephyr-7b-alpha\",\n        },\n        \"Zephyr-7B-Beta-Chat\": {\n            DownloadSource.DEFAULT: \"HuggingFaceH4/zephyr-7b-beta\",\n            DownloadSource.MODELSCOPE: \"modelscope/zephyr-7b-beta\",\n        },\n        \"Zephyr-141B-ORPO-Chat\": {\n            DownloadSource.DEFAULT: \"HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1\",\n        },\n    },\n    template=\"zephyr\",\n)\n"
  },
  {
    "path": "kt-sft/ktransformers/sft/metrics_utils/env.py",
    "content": "# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.\n#\n# This code is inspired by the HuggingFace's transformers library.\n# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport os\nimport platform\n\nimport accelerate\nimport datasets\nimport peft\nimport torch\nimport transformers\n# import trl\nfrom transformers.utils import is_torch_cuda_available, is_torch_npu_available\n\n\nVERSION = \"0.9.4.dev0\"\n\n\ndef print_env() -> None:\n    info = {\n        \"`llamafactory` version\": VERSION,\n        \"Platform\": platform.platform(),\n        \"Python version\": platform.python_version(),\n        \"PyTorch version\": torch.__version__,\n        \"Transformers version\": transformers.__version__,\n        \"Datasets version\": datasets.__version__,\n        \"Accelerate version\": accelerate.__version__,\n        \"PEFT version\": peft.__version__,\n        \"TRL version\": \"0.21.0\",\n    }\n\n    if is_torch_cuda_available():\n        info[\"PyTorch version\"] += \" (GPU)\"\n        info[\"GPU type\"] = torch.cuda.get_device_name()\n        info[\"GPU number\"] = torch.cuda.device_count()\n        info[\"GPU memory\"] = f\"{torch.cuda.mem_get_info()[1] / (1024**3):.2f}GB\"\n\n    if is_torch_npu_available():\n        info[\"PyTorch version\"] += \" (NPU)\"\n        info[\"NPU type\"] = torch.npu.get_device_name()\n        info[\"CANN 
version\"] = torch.version.cann\n\n    try:\n        import deepspeed  # type: ignore\n\n        info[\"DeepSpeed version\"] = deepspeed.__version__\n    except Exception:\n        pass\n\n    try:\n        import bitsandbytes  # type: ignore\n\n        info[\"Bitsandbytes version\"] = bitsandbytes.__version__\n    except Exception:\n        pass\n\n    try:\n        import vllm\n\n        info[\"vLLM version\"] = vllm.__version__\n    except Exception:\n        pass\n\n    try:\n        import subprocess\n\n        commit_info = subprocess.run([\"git\", \"rev-parse\", \"HEAD\"], capture_output=True, text=True, check=True)\n        commit_hash = commit_info.stdout.strip()\n        info[\"Git commit\"] = commit_hash\n    except Exception:\n        pass\n\n    if os.path.exists(\"data\"):\n        info[\"Default data directory\"] = \"detected\"\n    else:\n        info[\"Default data directory\"] = \"not detected\"\n\n    print(\"\\n\" + \"\\n\".join([f\"- {key}: {value}\" for key, value in info.items()]) + \"\\n\")\n"
  },
  {
    "path": "kt-sft/ktransformers/sft/metrics_utils/logging.py",
    "content": "# Copyright 2025 Optuna, HuggingFace Inc. and the LlamaFactory team.\n#\n# This code is inspired by the HuggingFace's transformers library.\n# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/logging.py\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\nimport os\nimport sys\nimport threading\nfrom concurrent.futures import ThreadPoolExecutor\nfrom functools import lru_cache\nfrom typing import Optional\n\nfrom .constants import RUNNING_LOG\n\n\n_thread_lock = threading.RLock()\n_default_handler: Optional[\"logging.Handler\"] = None\n_default_log_level: \"logging._Level\" = logging.INFO\n\n\nclass LoggerHandler(logging.Handler):\n    r\"\"\"Redirect the logging output to the logging file for LLaMA Board.\"\"\"\n\n    def __init__(self, output_dir: str) -> None:\n        super().__init__()\n        self._formatter = logging.Formatter(\n            fmt=\"[%(levelname)s|%(asctime)s] %(filename)s:%(lineno)s >> %(message)s\",\n            datefmt=\"%Y-%m-%d %H:%M:%S\",\n        )\n        self.setLevel(logging.INFO)\n        os.makedirs(output_dir, exist_ok=True)\n        self.running_log = os.path.join(output_dir, RUNNING_LOG)\n        if os.path.exists(self.running_log):\n            os.remove(self.running_log)\n\n        self.thread_pool = ThreadPoolExecutor(max_workers=1)\n\n    def _write_log(self, log_entry: str) -> None:\n        with open(self.running_log, \"a\", encoding=\"utf-8\") as f:\n       
     f.write(log_entry + \"\\n\")\n\n    def emit(self, record) -> None:\n        if record.name == \"httpx\":\n            return\n\n        log_entry = self._formatter.format(record)\n        self.thread_pool.submit(self._write_log, log_entry)\n\n    def close(self) -> None:\n        self.thread_pool.shutdown(wait=True)\n        return super().close()\n\n\nclass _Logger(logging.Logger):\n    r\"\"\"A logger that supports rank0 logging.\"\"\"\n\n    def info_rank0(self, *args, **kwargs) -> None:\n        self.info(*args, **kwargs)\n\n    def warning_rank0(self, *args, **kwargs) -> None:\n        self.warning(*args, **kwargs)\n\n    def warning_rank0_once(self, *args, **kwargs) -> None:\n        self.warning(*args, **kwargs)\n\n\ndef _get_default_logging_level() -> \"logging._Level\":\n    r\"\"\"Return the default logging level.\"\"\"\n    env_level_str = os.getenv(\"LLAMAFACTORY_VERBOSITY\", None)\n    if env_level_str:\n        if env_level_str.upper() in logging._nameToLevel:\n            return logging._nameToLevel[env_level_str.upper()]\n        else:\n            raise ValueError(f\"Unknown logging level: {env_level_str}.\")\n\n    return _default_log_level\n\n\ndef _get_library_name() -> str:\n    return __name__.split(\".\")[0]\n\n\ndef _get_library_root_logger() -> \"_Logger\":\n    return logging.getLogger(_get_library_name())\n\n\ndef _configure_library_root_logger() -> None:\n    r\"\"\"Configure root logger using a stdout stream handler with an explicit format.\"\"\"\n    global _default_handler\n\n    with _thread_lock:\n        if _default_handler:  # already configured\n            return\n\n        formatter = logging.Formatter(\n            fmt=\"[%(levelname)s|%(asctime)s] %(name)s:%(lineno)s >> %(message)s\",\n            datefmt=\"%Y-%m-%d %H:%M:%S\",\n        )\n        _default_handler = logging.StreamHandler(sys.stdout)\n        _default_handler.setFormatter(formatter)\n        library_root_logger = _get_library_root_logger()\n        
library_root_logger.addHandler(_default_handler)\n        library_root_logger.setLevel(_get_default_logging_level())\n        library_root_logger.propagate = False\n\n\ndef get_logger(name: Optional[str] = None) -> \"_Logger\":\n    r\"\"\"Return a logger with the specified name. It is not supposed to be accessed externally.\"\"\"\n    if name is None:\n        name = _get_library_name()\n\n    _configure_library_root_logger()\n    return logging.getLogger(name)\n\n\ndef add_handler(handler: \"logging.Handler\") -> None:\n    r\"\"\"Add a handler to the root logger.\"\"\"\n    _configure_library_root_logger()\n    _get_library_root_logger().addHandler(handler)\n\n\ndef remove_handler(handler: logging.Handler) -> None:\n    r\"\"\"Remove a handler from the root logger.\"\"\"\n    _configure_library_root_logger()\n    _get_library_root_logger().removeHandler(handler)\n\n\ndef info_rank0(self: \"logging.Logger\", *args, **kwargs) -> None:\n    if int(os.getenv(\"LOCAL_RANK\", \"0\")) == 0:\n        self.info(*args, **kwargs)\n\n\ndef warning_rank0(self: \"logging.Logger\", *args, **kwargs) -> None:\n    if int(os.getenv(\"LOCAL_RANK\", \"0\")) == 0:\n        self.warning(*args, **kwargs)\n\n\n@lru_cache(None)\ndef warning_rank0_once(self: \"logging.Logger\", *args, **kwargs) -> None:\n    if int(os.getenv(\"LOCAL_RANK\", \"0\")) == 0:\n        self.warning(*args, **kwargs)\n\n\nlogging.Logger.info_rank0 = info_rank0\nlogging.Logger.warning_rank0 = warning_rank0\nlogging.Logger.warning_rank0_once = warning_rank0_once\n"
  },
  {
    "path": "kt-sft/ktransformers/sft/metrics_utils/misc.py",
    "content": "# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.\n#\n# This code is inspired by the HuggingFace's PEFT library.\n# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/peft_model.py\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport gc\nimport os\nimport socket\nfrom typing import TYPE_CHECKING, Any, Literal, Optional, Union\n\nimport torch\nimport torch.distributed as dist\nimport transformers.dynamic_module_utils\nfrom huggingface_hub.utils import WeakFileLock\nfrom transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList\nfrom transformers.dynamic_module_utils import get_relative_imports\nfrom transformers.utils import (\n    is_torch_bf16_gpu_available,\n    is_torch_cuda_available,\n    is_torch_mps_available,\n    is_torch_npu_available,\n    is_torch_xpu_available,\n)\nfrom transformers.utils.versions import require_version\n\nfrom . 
import logging\n\n\n_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()\ntry:\n    _is_bf16_available = is_torch_bf16_gpu_available() or (is_torch_npu_available() and torch.npu.is_bf16_supported())\nexcept Exception:\n    _is_bf16_available = False\n\n\nif TYPE_CHECKING:\n    from numpy.typing import NDArray\n\n    from ..hparams import ModelArguments\n\n\nlogger = logging.get_logger(__name__)\n\n\nclass AverageMeter:\n    r\"\"\"Compute and store the average and current value.\"\"\"\n\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n\n\ndef check_version(requirement: str, mandatory: bool = False) -> None:\n    r\"\"\"Optionally check the package version.\"\"\"\n    if is_env_enabled(\"DISABLE_VERSION_CHECK\") and not mandatory:\n        logger.warning_rank0_once(\"Version checking has been disabled, may lead to unexpected behaviors.\")\n        return\n\n    if \"gptmodel\" in requirement or \"autoawq\" in requirement:\n        pip_command = f\"pip install {requirement} --no-build-isolation\"\n    else:\n        pip_command = f\"pip install {requirement}\"\n\n    if mandatory:\n        hint = f\"To fix: run `{pip_command}`.\"\n    else:\n        hint = f\"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check.\"\n\n    require_version(requirement, hint)\n\n\ndef check_dependencies() -> None:\n    r\"\"\"Check the version of the required packages.\"\"\"\n    check_version(\"transformers>=4.49.0,<=4.55.0\")\n    check_version(\"datasets>=2.16.0,<=3.6.0\")\n    check_version(\"accelerate>=1.3.0,<=1.7.0\")\n    check_version(\"peft>=0.14.0,<=0.15.2\")\n    check_version(\"trl>=0.8.6,<=0.9.6\")\n\n\ndef calculate_tps(dataset: list[dict[str, Any]], metrics: 
dict[str, float], stage: Literal[\"sft\", \"rm\"]) -> float:\n    r\"\"\"Calculate effective tokens per second.\"\"\"\n    effective_token_num = 0\n    for data in dataset:\n        if stage == \"sft\":\n            effective_token_num += len(data[\"input_ids\"])\n        elif stage == \"rm\":\n            effective_token_num += len(data[\"chosen_input_ids\"]) + len(data[\"rejected_input_ids\"])\n\n    result = effective_token_num * metrics[\"epoch\"] / metrics[\"train_runtime\"]\n    return result / dist.get_world_size() if dist.is_initialized() else result\n\n\ndef count_parameters(model: \"torch.nn.Module\") -> tuple[int, int]:\n    r\"\"\"Return the number of trainable parameters and number of all parameters in the model.\"\"\"\n    trainable_params, all_param = 0, 0\n    for param in model.parameters():\n        num_params = param.numel()\n        # if using DS Zero 3 and the weights are initialized empty\n        if num_params == 0 and hasattr(param, \"ds_numel\"):\n            num_params = param.ds_numel\n\n        # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by itemsize\n        if param.__class__.__name__ == \"Params4bit\":\n            if hasattr(param, \"quant_storage\") and hasattr(param.quant_storage, \"itemsize\"):\n                num_bytes = param.quant_storage.itemsize\n            elif hasattr(param, \"element_size\"):  # for older pytorch version\n                num_bytes = param.element_size()\n            else:\n                num_bytes = 1\n\n            num_params = num_params * 2 * num_bytes\n\n        all_param += num_params\n        if param.requires_grad:\n            trainable_params += num_params\n\n    return trainable_params, all_param\n\n\ndef get_current_device() -> \"torch.device\":\n    r\"\"\"Get the current available device.\"\"\"\n    if is_torch_xpu_available():\n        device = \"xpu:{}\".format(os.getenv(\"LOCAL_RANK\", \"0\"))\n    elif is_torch_npu_available():\n        
device = \"npu:{}\".format(os.getenv(\"LOCAL_RANK\", \"0\"))\n    elif is_torch_mps_available():\n        device = \"mps:{}\".format(os.getenv(\"LOCAL_RANK\", \"0\"))\n    elif is_torch_cuda_available():\n        device = \"cuda:{}\".format(os.getenv(\"LOCAL_RANK\", \"0\"))\n    else:\n        device = \"cpu\"\n\n    return torch.device(device)\n\n\ndef get_device_count() -> int:\n    r\"\"\"Get the number of available devices.\"\"\"\n    if is_torch_xpu_available():\n        return torch.xpu.device_count()\n    elif is_torch_npu_available():\n        return torch.npu.device_count()\n    elif is_torch_mps_available():\n        return torch.mps.device_count()\n    elif is_torch_cuda_available():\n        return torch.cuda.device_count()\n    else:\n        return 0\n\n\ndef get_logits_processor() -> \"LogitsProcessorList\":\n    r\"\"\"Get logits processor that removes NaN and Inf logits.\"\"\"\n    logits_processor = LogitsProcessorList()\n    logits_processor.append(InfNanRemoveLogitsProcessor())\n    return logits_processor\n\n\ndef get_current_memory() -> tuple[int, int]:\n    r\"\"\"Get the available and total memory for the current device (in Bytes).\"\"\"\n    if is_torch_xpu_available():\n        return torch.xpu.mem_get_info()\n    elif is_torch_npu_available():\n        return torch.npu.mem_get_info()\n    elif is_torch_mps_available():\n        return torch.mps.current_allocated_memory(), torch.mps.recommended_max_memory()\n    elif is_torch_cuda_available():\n        return torch.cuda.mem_get_info()\n    else:\n        return 0, -1\n\n\ndef get_peak_memory() -> tuple[int, int]:\n    r\"\"\"Get the peak memory usage (allocated, reserved) for the current device (in Bytes).\"\"\"\n    if is_torch_xpu_available():\n        return torch.xpu.max_memory_allocated(), torch.xpu.max_memory_reserved()\n    elif is_torch_npu_available():\n        return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()\n    elif is_torch_mps_available():\n        
return torch.mps.current_allocated_memory(), -1\n    elif is_torch_cuda_available():\n        return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()\n    else:\n        return 0, -1\n\n\ndef has_tokenized_data(path: \"os.PathLike\") -> bool:\n    r\"\"\"Check if the path has a tokenized dataset.\"\"\"\n    return os.path.isdir(path) and len(os.listdir(path)) > 0\n\n\ndef infer_optim_dtype(model_dtype: Optional[\"torch.dtype\"]) -> \"torch.dtype\":\n    r\"\"\"Infer the optimal dtype according to the model_dtype and device compatibility.\"\"\"\n    if _is_bf16_available and (model_dtype == torch.bfloat16 or model_dtype is None):\n        return torch.bfloat16\n    elif _is_fp16_available:\n        return torch.float16\n    else:\n        return torch.float32\n\n\ndef is_accelerator_available() -> bool:\n    r\"\"\"Check if the accelerator is available.\"\"\"\n    return (\n        is_torch_xpu_available() or is_torch_npu_available() or is_torch_mps_available() or is_torch_cuda_available()\n    )\n\n\ndef is_env_enabled(env_var: str, default: str = \"0\") -> bool:\n    r\"\"\"Check if the environment variable is enabled.\"\"\"\n    return os.getenv(env_var, default).lower() in [\"true\", \"y\", \"1\"]\n\n\ndef numpify(inputs: Union[\"NDArray\", \"torch.Tensor\"]) -> \"NDArray\":\n    r\"\"\"Cast a torch tensor or a numpy array to a numpy array.\"\"\"\n    if isinstance(inputs, torch.Tensor):\n        inputs = inputs.cpu()\n        if inputs.dtype == torch.bfloat16:  # numpy does not support bfloat16 until 1.21.4\n            inputs = inputs.to(torch.float32)\n\n        inputs = inputs.numpy()\n\n    return inputs\n\n\ndef skip_check_imports() -> None:\n    r\"\"\"Avoid flash attention import error in custom model files.\"\"\"\n    if not is_env_enabled(\"FORCE_CHECK_IMPORTS\"):\n        transformers.dynamic_module_utils.check_imports = get_relative_imports\n\n\ndef torch_gc() -> None:\n    r\"\"\"Collect the device memory.\"\"\"\n    
gc.collect()\n    if is_torch_xpu_available():\n        torch.xpu.empty_cache()\n    elif is_torch_npu_available():\n        torch.npu.empty_cache()\n    elif is_torch_mps_available():\n        torch.mps.empty_cache()\n    elif is_torch_cuda_available():\n        torch.cuda.empty_cache()\n\n\ndef try_download_model_from_other_hub(model_args: \"ModelArguments\") -> str:\n    if (not use_modelscope() and not use_openmind()) or os.path.exists(model_args.model_name_or_path):\n        return model_args.model_name_or_path\n\n    if use_modelscope():\n        check_version(\"modelscope>=1.14.0\", mandatory=True)\n        from modelscope import snapshot_download  # type: ignore\n        from modelscope.hub.api import HubApi  # type: ignore\n\n        if model_args.ms_hub_token:\n            api = HubApi()\n            api.login(model_args.ms_hub_token)\n\n        revision = \"master\" if model_args.model_revision == \"main\" else model_args.model_revision\n        with WeakFileLock(os.path.abspath(os.path.expanduser(\"~/.cache/llamafactory/modelscope.lock\"))):\n            model_path = snapshot_download(\n                model_args.model_name_or_path,\n                revision=revision,\n                cache_dir=model_args.cache_dir,\n            )\n\n        return model_path\n\n    if use_openmind():\n        check_version(\"openmind>=0.8.0\", mandatory=True)\n        from openmind.utils.hub import snapshot_download  # type: ignore\n\n        with WeakFileLock(os.path.abspath(os.path.expanduser(\"~/.cache/llamafactory/openmind.lock\"))):\n            model_path = snapshot_download(\n                model_args.model_name_or_path,\n                revision=model_args.model_revision,\n                cache_dir=model_args.cache_dir,\n            )\n\n        return model_path\n\n\ndef use_modelscope() -> bool:\n    return is_env_enabled(\"USE_MODELSCOPE_HUB\")\n\n\ndef use_openmind() -> bool:\n    return is_env_enabled(\"USE_OPENMIND_HUB\")\n\n\ndef use_ray() -> bool:\n    
return is_env_enabled(\"USE_RAY\")\n\n\ndef find_available_port() -> int:\n    r\"\"\"Find an available port on the local machine.\"\"\"\n    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n    sock.bind((\"\", 0))\n    port = sock.getsockname()[1]\n    sock.close()\n    return port\n\n\ndef fix_proxy(ipv6_enabled: bool = False) -> None:\n    r\"\"\"Fix proxy settings for gradio ui.\"\"\"\n    os.environ[\"no_proxy\"] = \"localhost,127.0.0.1,0.0.0.0\"\n    if ipv6_enabled:\n        os.environ.pop(\"http_proxy\", None)\n        os.environ.pop(\"HTTP_PROXY\", None)\n"
  },
  {
    "path": "kt-sft/ktransformers/sft/metrics_utils/packages.py",
    "content": "# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.\n#\n# This code is inspired by the HuggingFace's transformers library.\n# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport importlib.metadata\nimport importlib.util\nfrom functools import lru_cache\nfrom typing import TYPE_CHECKING\n\nfrom packaging import version\n\n\nif TYPE_CHECKING:\n    from packaging.version import Version\n\n\ndef _is_package_available(name: str) -> bool:\n    return importlib.util.find_spec(name) is not None\n\n\ndef _get_package_version(name: str) -> \"Version\":\n    try:\n        return version.parse(importlib.metadata.version(name))\n    except Exception:\n        return version.parse(\"0.0.0\")\n\n\ndef is_pyav_available():\n    return _is_package_available(\"av\")\n\n\ndef is_librosa_available():\n    return _is_package_available(\"librosa\")\n\n\ndef is_fastapi_available():\n    return _is_package_available(\"fastapi\")\n\n\ndef is_galore_available():\n    return _is_package_available(\"galore_torch\")\n\n\ndef is_apollo_available():\n    return _is_package_available(\"apollo_torch\")\n\n\ndef is_gradio_available():\n    return _is_package_available(\"gradio\")\n\n\ndef is_matplotlib_available():\n    return _is_package_available(\"matplotlib\")\n\n\ndef is_pillow_available():\n    return _is_package_available(\"PIL\")\n\n\ndef is_ray_available():\n   
 return _is_package_available(\"ray\")\n\n\ndef is_requests_available():\n    return _is_package_available(\"requests\")\n\n\ndef is_rouge_available():\n    return _is_package_available(\"rouge_chinese\")\n\n\ndef is_starlette_available():\n    return _is_package_available(\"sse_starlette\")\n\n\n@lru_cache\ndef is_transformers_version_greater_than(content: str):\n    return _get_package_version(\"transformers\") >= version.parse(content)\n\n\ndef is_uvicorn_available():\n    return _is_package_available(\"uvicorn\")\n\n\ndef is_vllm_available():\n    return _is_package_available(\"vllm\")\n\n\ndef is_sglang_available():\n    return _is_package_available(\"sglang\")\n"
  },
  {
    "path": "kt-sft/ktransformers/sft/metrics_utils/ploting.py",
    "content": "# Copyright 2025 the LlamaFactory team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport json\nimport math\nimport os\nfrom typing import Any\n\nfrom transformers.trainer import TRAINER_STATE_NAME\n\nfrom . import logging\nfrom .packages import is_matplotlib_available\n\n\nif is_matplotlib_available():\n    import matplotlib.figure\n    import matplotlib.pyplot as plt\n\n\nlogger = logging.get_logger(__name__)\n\n\ndef smooth(scalars: list[float]) -> list[float]:\n    r\"\"\"EMA implementation according to TensorBoard.\"\"\"\n    if len(scalars) == 0:\n        return []\n\n    last = scalars[0]\n    smoothed = []\n    weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5)  # a sigmoid function\n    for next_val in scalars:\n        smoothed_val = last * weight + (1 - weight) * next_val\n        smoothed.append(smoothed_val)\n        last = smoothed_val\n    return smoothed\n\n\ndef gen_loss_plot(trainer_log: list[dict[str, Any]]) -> \"matplotlib.figure.Figure\":\n    r\"\"\"Plot loss curves in LlamaBoard.\"\"\"\n    plt.close(\"all\")\n    plt.switch_backend(\"agg\")\n    fig = plt.figure()\n    ax = fig.add_subplot(111)\n    steps, losses = [], []\n    for log in trainer_log:\n        if log.get(\"loss\", None):\n            steps.append(log[\"current_steps\"])\n            losses.append(log[\"loss\"])\n\n    ax.plot(steps, losses, color=\"#1f77b4\", alpha=0.4, label=\"original\")\n    ax.plot(steps, smooth(losses), 
color=\"#1f77b4\", label=\"smoothed\")\n    ax.legend()\n    ax.set_xlabel(\"step\")\n    ax.set_ylabel(\"loss\")\n    return fig\n\n\ndef plot_loss(save_dictionary: str, keys: list[str] = [\"loss\"]) -> None:\n    r\"\"\"Plot loss curves and save the images.\"\"\"\n    plt.switch_backend(\"agg\")\n    with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding=\"utf-8\") as f:\n        data = json.load(f)\n\n    for key in keys:\n        steps, metrics = [], []\n        for i in range(len(data[\"log_history\"])):\n            if key in data[\"log_history\"][i]:\n                steps.append(data[\"log_history\"][i][\"step\"])\n                metrics.append(data[\"log_history\"][i][key])\n\n        if len(metrics) == 0:\n            logger.warning_rank0(f\"No metric {key} to plot.\")\n            continue\n\n        plt.figure()\n        plt.plot(steps, metrics, color=\"#1f77b4\", alpha=0.4, label=\"original\")\n        plt.plot(steps, smooth(metrics), color=\"#1f77b4\", label=\"smoothed\")\n        plt.title(f\"training {key} of {save_dictionary}\")\n        plt.xlabel(\"step\")\n        plt.ylabel(key)\n        plt.legend()\n        figure_path = os.path.join(save_dictionary, \"training_{}.png\".format(key.replace(\"/\", \"_\")))\n        plt.savefig(figure_path, format=\"png\", dpi=100)\n        print(\"Figure saved at:\", figure_path)\n"
  },
  {
    "path": "kt-sft/ktransformers/sft/monkey_patch_torch_module.py",
    "content": "import torch\nfrom collections import OrderedDict\nfrom torch.nn.modules import Module\n\n_ORIG_MODULE_INIT = Module.__init__\n\ndef _patched_module_init(self, *args, **kwargs):\n    torch._C._log_api_usage_once(\"python.nn_module\")\n\n    if self.call_super_init is False and bool(kwargs):\n        raise TypeError(\n            f\"{type(self).__name__}.__init__() got an unexpected keyword argument '{next(iter(kwargs))}'\"\n        )\n    if self.call_super_init is False and bool(args):\n        raise TypeError(\n            f\"{type(self).__name__}.__init__() takes 1 positional argument but {len(args) + 1} were given\"\n        )\n\n    object.__setattr__(self, \"training\", True)\n    object.__setattr__(self, \"_parameters\", {})\n    object.__setattr__(self, \"_buffers\", {})\n    object.__setattr__(self, \"_non_persistent_buffers_set\", set())\n    object.__setattr__(self, \"_backward_pre_hooks\", OrderedDict())\n    object.__setattr__(self, \"_backward_hooks\", OrderedDict())\n    object.__setattr__(self, \"_is_full_backward_hook\", None)\n    object.__setattr__(self, \"_forward_hooks\", OrderedDict())\n    object.__setattr__(self, \"_forward_hooks_with_kwargs\", OrderedDict())\n    object.__setattr__(self, \"_forward_hooks_always_called\", OrderedDict())\n    object.__setattr__(self, \"_forward_pre_hooks\", OrderedDict())\n    object.__setattr__(self, \"_forward_pre_hooks_with_kwargs\", OrderedDict())\n    object.__setattr__(self, \"_state_dict_hooks\", OrderedDict())\n    object.__setattr__(self, \"_state_dict_pre_hooks\", OrderedDict())\n    object.__setattr__(self, \"_load_state_dict_pre_hooks\", OrderedDict())\n    object.__setattr__(self, \"_load_state_dict_post_hooks\", OrderedDict())\n\n    if not (hasattr(self, \"orig_module\") and isinstance(self.orig_module, torch.nn.modules.linear.Linear)):\n        object.__setattr__(self, \"_modules\", {})\n\n    if self.call_super_init:\n        object.__init__(self)\n\ndef install_patch():\n    
Module.__init__ = _patched_module_init\n\ndef restore_patch():\n    Module.__init__ = _ORIG_MODULE_INIT\n\ninstall_patch()"
  },
  {
    "path": "kt-sft/ktransformers/sft/peft_utils/__init__.py",
    "content": ""
  },
  {
    "path": "kt-sft/ktransformers/sft/peft_utils/lora_layer.py",
    "content": "from abc import ABC\nfrom copy import deepcopy\nimport math\nimport warnings\nfrom typing import Any, Optional, Union\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom accelerate.utils.imports import is_xpu_available\nfrom torch import BufferDict, svd_lowrank, transpose\nfrom transformers.pytorch_utils import Conv1D\n\nfrom peft.tuners.lora.config import LoraConfig\n\nfrom ktransformers.operators.linear import KTransformersLinear, KLinearTorch, KLinearBase\nfrom ktransformers.operators.base_operator import BaseInjectedModule\nfrom ktransformers.util.inference_state import InferenceState\n\ndef dispatch_default(\n    target: torch.nn.Module,\n    adapter_name: str,\n    lora_config: LoraConfig,\n    **kwargs,\n) -> Optional[torch.nn.Module]:\n    new_module = None\n\n    if isinstance(target, BaseTunerLayer):\n        target_orig_module = target.get_orig_module()\n    else:\n        target_orig_module = target\n\n    if isinstance(target_orig_module, torch.nn.Embedding):\n        embedding_kwargs = kwargs.copy()\n        embedding_kwargs.pop(\"fan_in_fan_out\", None)\n        embedding_kwargs.update(lora_config.loftq_config)\n        new_module = Embedding(target, adapter_name, **embedding_kwargs)\n\n    elif isinstance(target_orig_module, torch.nn.Linear):\n        if kwargs[\"fan_in_fan_out\"]:\n            warnings.warn(\n                \"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. 
\"\n                \"Setting fan_in_fan_out to False.\"\n            )\n            kwargs[\"fan_in_fan_out\"] = lora_config.fan_in_fan_out = False\n        kwargs.update(lora_config.loftq_config)\n        new_module = Linear(target, adapter_name, **kwargs)\n\n    elif isinstance(target_orig_module, KTransformersLinear):\n        kwargs.update(lora_config.loftq_config)\n        new_module = KTransformersLinearLora(target, adapter_name, **kwargs)\n\n    return new_module\n\nclass BaseTunerLayer(ABC):\n    r\"\"\"\n    A tuner layer mixin that provides the common methods and attributes for all tuners.\n\n    Args:\n        is_pluggable (`bool`, *optional*):\n            Whether the adapter layer can be plugged to any pytorch module\n        active_adapters (Union[List[`str`], `str`], *optional*):\n            The name of the active adapter.\n    \"\"\"\n\n    # All names of layers that may contain adapter (trainable) weights\n    adapter_layer_names: tuple[str, ...] = ()\n    # All names of other parameters that may contain adapter-related parameters\n    other_param_names: tuple[str, ...] = ()\n\n    # indicates whether all adapters should be disabled\n    _disable_adapters: bool = False\n\n    # the currently active adapter(s)\n    _active_adapter: str | list[str] = \"default\"\n\n    # List all merged adapters\n    merged_adapters: list[str] = []\n\n    def get_orig_module(self) -> nn.Module:\n        \"\"\"\n        (Recursively) get the orig_module.\n\n        This is necessary for the case that the tuner layer wraps another tuner layer.\n\n        \"\"\"\n        orig_module = self\n        while hasattr(orig_module, \"orig_module\"):\n            orig_module = orig_module.orig_module\n        return orig_module\n\n    @property\n    def weight(self) -> torch.Tensor:\n        # This is required for some transformers code, e.g. 
for T5, weight is accessed as:\n        #     self.wo.weight\n        # where \"wo\" is the adapter layer.\n        # https://github.com/huggingface/transformers/blob/78f6ed6c70b29c1560780e3869a7ad4c6b3d2710/src/transformers\n        # /models/t5/modeling_t5.py#L292\n        orig_module = self.get_orig_module()\n        if hasattr(orig_module, \"qweight\"):\n            # QuantLinear\n            weight = orig_module.qweight\n        else:\n            # Other layers\n            weight = orig_module.weight\n        return weight\n\n    @property\n    def bias(self) -> torch.Tensor:\n        orig_module = self.get_orig_module()\n        return orig_module.bias\n\n    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:\n        raise NotImplementedError\n\n    def unmerge(self) -> None:\n        raise NotImplementedError\n\n    @property\n    def merged(self) -> bool:\n        return bool(self.merged_adapters)\n\n    @property\n    def disable_adapters(self) -> bool:\n        # use a property to ensure that disable_adapters is not set directly, instead use the enable_adapters method\n        return self._disable_adapters\n\n    @property\n    def active_adapter(self) -> str | list[str]:\n        # use a property to ensure that active_adapter is not set directly, instead use the set_adapter method\n        return self._active_adapter\n\n    def _get_available_adapters(self) -> set[str]:\n        \"\"\"Return all adapter names that can be found on this module.\"\"\"\n        adapters = set()\n        for layer_name in self.adapter_layer_names:\n            module = getattr(self, layer_name)\n            if not isinstance(module, (nn.ModuleDict, nn.ParameterDict)):\n                continue\n            adapters.update(set(module.keys()))\n        return adapters\n\n    @property\n    def active_adapters(self):\n        if isinstance(self.active_adapter, str):\n            return [self.active_adapter]\n        # is already a 
list of str\n        return self.active_adapter\n\n    def enable_adapters(self, enabled: bool) -> None:\n        \"\"\"Toggle the enabling and disabling of adapters\n\n        Takes care of setting the requires_grad flag for the adapter weights.\n\n        Args:\n            enabled (bool): True to enable adapters, False to disable adapters\n        \"\"\"\n        if enabled:\n            self.set_adapter(self.active_adapters)\n            self._disable_adapters = False\n        else:\n            # disable grads on all adapter layers\n            for layer_name in self.adapter_layer_names:\n                layer = getattr(self, layer_name)\n                layer.requires_grad_(False)\n            self._disable_adapters = True\n\n    def set_adapter(self, adapter_names: str | list[str]) -> None:\n        \"\"\"Set the active adapter(s).\n\n        Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is\n        not desired, use the following code.\n\n        ```py\n        >>> for name, param in model_peft.named_parameters():\n        ...     if ...:  # some check on name (ex. if 'lora' in name)\n        ...         param.requires_grad = False\n        ```\n\n        Args:\n            adapter_names (`str` or `List[str]`): Name of the adapter(s) to be activated.\n        \"\"\"\n        if isinstance(adapter_names, str):\n            adapter_names = [adapter_names]\n\n        # Deactivate grads on the inactive adapter and activate grads on the active adapter\n        for layer_name in self.adapter_layer_names:\n            module_dict = getattr(self, layer_name)\n            for key, layer in module_dict.items():\n                if key in adapter_names:\n                    # Note: It is possible that not a single layer is called with requires_grad_(True) here. 
This may\n                    # happen if a completely different adapter layer is being activated.\n                    layer.requires_grad_(True)\n                else:\n                    layer.requires_grad_(False)\n\n        self._active_adapter = adapter_names\n\n    def _all_available_adapter_names(self) -> list[str]:\n        \"\"\"Return a sorted list of all available adapter names\"\"\"\n        adapter_names = set()\n        for name in self.adapter_layer_names + self.other_param_names:\n            # we check each possible attribute and if it's a dict or ModuleDict, we assume that the keys are the adapter\n            # names\n            attr = getattr(self, name)\n            if hasattr(attr, \"keys\"):\n                adapter_names.update(attr.keys())\n        return sorted(adapter_names)\n\n    def delete_adapter(self, adapter_name: str) -> None:\n        \"\"\"\n        Delete an adapter from the layer\n\n        This should be called on all adapter layers, or else we will get an inconsistent state.\n\n        This method will also set a new active adapter if the deleted adapter was an active adapter. 
It is important\n        that the new adapter is chosen in a deterministic way, so that the same adapter is chosen on all layers.\n\n        Args:\n            adapter_name (`str`): The name of the adapter to delete\n\n        \"\"\"\n        for attr in self.adapter_layer_names + self.other_param_names:\n            if adapter_name in getattr(self, attr):\n                del getattr(self, attr)[adapter_name]\n\n        if adapter_name in self.active_adapters:\n            # choose a new active adapter\n            active_adapters = self.active_adapters[:]\n            active_adapters.remove(adapter_name)\n            if active_adapters:\n                self.set_adapter(active_adapters)\n            else:\n                # no active adapters left, set a new default adapter\n                # here we get the list of all existing adapter names and choose the first one\n                remaining_adapters = self._all_available_adapter_names()\n                if not remaining_adapters:\n                    self.set_adapter([])\n                else:\n                    new_active_adapter = remaining_adapters[0]\n                    warnings.warn(\n                        f\"Adapter {adapter_name} was active which is now deleted. 
Setting active adapter to \"\n                        f\"{new_active_adapter}.\"\n                    )\n                    self.set_adapter(remaining_adapters[0])\n\n    def _move_adapter_to_device_of_orig_module(self, adapter_name: str, device: Optional[torch.device] = None) -> None:\n        \"\"\"\n        Move the adapter of the given name to the device of the base layer.\n        \"\"\"\n        # dtype stays None when the caller passes an explicit device; the adapter is then moved without a dtype cast\n        dtype = None\n        if device is None:\n            # check weight and qweight (for GPTQ)\n            for weight_name in (\"weight\", \"qweight\"):\n                weight = getattr(self.get_orig_module(), weight_name, None)\n                if weight is not None:\n                    device = weight.device\n                    dtype = weight.dtype\n                    break\n            else:\n                # no break encountered: could not determine the device\n                return\n\n        meta = torch.device(\"meta\")\n\n        # loop through all potential adapter layers and move them to the device of the base layer; be careful to only\n        # move this specific adapter to the device, as the other adapters could be on different devices\n        # see #1639\n        for adapter_layer_name in self.adapter_layer_names + self.other_param_names:\n            adapter_layer = getattr(self, adapter_layer_name, None)\n            if not isinstance(adapter_layer, (nn.ModuleDict, nn.ParameterDict, BufferDict)):\n                continue\n            if adapter_name not in adapter_layer:\n                continue\n            if any(p.device == meta for p in adapter_layer.parameters()):\n                continue\n\n            if dtype is not None and (dtype.is_floating_point or dtype.is_complex):\n                adapter_layer[adapter_name] = adapter_layer[adapter_name].to(device, dtype=dtype)\n            else:\n                adapter_layer[adapter_name] = adapter_layer[adapter_name].to(device)\n\n\nclass LoraLayer(BaseTunerLayer):\n    # All names of layers that may contain (trainable) 
adapter weights\n    adapter_layer_names = (\"lora_A\", \"lora_B\", \"lora_embedding_A\", \"lora_embedding_B\")\n    # All names of other parameters that may contain adapter-related parameters\n    other_param_names = (\"r\", \"lora_alpha\", \"scaling\", \"lora_dropout\")\n\n    def __init__(self, orig_module: nn.Module, ephemeral_gpu_offload: bool = False, **kwargs) -> None:\n        self.orig_module = orig_module\n        self.r = {}\n        self.lora_alpha = {}\n        self.scaling = {}\n        self.lora_dropout = nn.ModuleDict({})\n        self.lora_A = nn.ModuleDict({})\n        self.lora_B = nn.ModuleDict({})\n        # For Embedding layer\n        self.lora_embedding_A = nn.ParameterDict({})\n        self.lora_embedding_B = nn.ParameterDict({})\n        # Mark the weight as unmerged\n        self._disable_adapters = False\n        self.merged_adapters = []\n        self.use_dora: dict[str, bool] = {}\n        self.lora_bias: dict[str, bool] = {}\n        self.lora_magnitude_vector = torch.nn.ModuleDict()  # for DoRA\n        self._caches: dict[str, Any] = {}\n        self.ephemeral_gpu_offload: bool = ephemeral_gpu_offload\n        self.kwargs = kwargs\n\n        orig_module = self.get_orig_module()\n        if isinstance(orig_module, nn.Linear):\n            in_features, out_features = orig_module.in_features, orig_module.out_features\n        elif isinstance(orig_module, nn.Embedding):\n            in_features, out_features = orig_module.num_embeddings, orig_module.embedding_dim\n        else:\n            raise TypeError(f\"unknown type of {orig_module}, not in Linear or Embedding.\")\n\n        self.in_features = in_features\n        self.out_features = out_features\n\n    def update_layer(\n        self,\n        adapter_name,\n        r,\n        lora_alpha,\n        lora_dropout,\n        init_lora_weights,\n        use_rslora: bool = False,\n        use_dora: bool = False,\n        lora_bias: bool = False,\n    ):\n        # This code works for 
linear layers, override for other layer types\n        if r <= 0:\n            raise ValueError(f\"`r` should be a positive integer value but the value passed is {r}\")\n\n        self.r[adapter_name] = r\n        self.lora_alpha[adapter_name] = lora_alpha\n        if lora_dropout > 0.0:\n            lora_dropout_layer = nn.Dropout(p=lora_dropout)\n        else:\n            lora_dropout_layer = nn.Identity()\n\n        self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))\n        # Actual trainable parameters\n        self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False)\n        self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=lora_bias)\n        self.lora_bias[adapter_name] = lora_bias\n        # record the DoRA flag so that merge/unmerge can look it up without a KeyError\n        self.use_dora[adapter_name] = use_dora\n\n        if use_rslora:\n            self.scaling[adapter_name] = lora_alpha / math.sqrt(r)\n        else:\n            self.scaling[adapter_name] = lora_alpha / r\n\n        # for inits that require access to the base weight, use gather_param_ctx so that the weight is gathered when using DeepSpeed\n        if init_lora_weights == \"eva\":\n            nn.init.zeros_(self.lora_B[adapter_name].weight)\n        elif init_lora_weights:\n            self.reset_lora_parameters(adapter_name, init_lora_weights)\n        # call this before dora_init\n        self._move_adapter_to_device_of_orig_module(adapter_name)\n\n        self.set_adapter(self.active_adapters)\n\n    def reset_lora_parameters(self, adapter_name, init_lora_weights):\n        if init_lora_weights is False:\n            return\n\n        if adapter_name in self.lora_A.keys():\n            if init_lora_weights is True:\n                # initialize A the same way as the default for nn.Linear and B to zero\n                # https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L124\n                nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5))\n            elif init_lora_weights.lower() == 
\"gaussian\":\n                nn.init.normal_(self.lora_A[adapter_name].weight, std=1 / self.r[adapter_name])\n            else:\n                raise ValueError(f\"Unknown initialization {init_lora_weights=}\")\n            nn.init.zeros_(self.lora_B[adapter_name].weight)\n            if self.lora_bias[adapter_name]:\n                nn.init.zeros_(self.lora_B[adapter_name].bias)\n        if adapter_name in self.lora_embedding_A.keys():\n            # Initialize A to zeros and B the same way as the default for nn.Embedding, see:\n            # https://github.com/microsoft/LoRA/blob/4c0333854cb905966f8cc4e9a74068c1e507c7b7/loralib/layers.py#L59-L60\n            nn.init.zeros_(self.lora_embedding_A[adapter_name])\n            nn.init.normal_(self.lora_embedding_B[adapter_name])\n            if self.lora_bias[adapter_name]:\n                # embeddings are not supported at the moment, but still adding this for consistency\n                nn.init.zeros_(self.lora_embedding_B[adapter_name].bias)\n\n    def olora_init(self, adapter_name):\n        orig_module = self.get_orig_module()\n        orig_weight = orig_module.weight\n        dtype = orig_weight.dtype\n\n        if dtype in [torch.float32, torch.float16, torch.bfloat16]:\n            weight_tensor = orig_weight\n        else:\n            raise TypeError(f\"Unsupported data type for the base layer. 
Got {dtype}.\")\n\n        scale_factor = self.scaling[adapter_name]\n        r = self.r[adapter_name]\n        weight_tensor = weight_tensor.to(torch.float32)\n        Q, R = torch.linalg.qr(weight_tensor.data)\n\n        Qr, Rr = Q[:, :r], R[:r]\n\n        self.lora_A[adapter_name].weight.data = Rr.contiguous()\n        self.lora_B[adapter_name].weight.data = Qr.contiguous()\n\n        weight_tensor.data -= scale_factor * self.lora_B[adapter_name].weight @ self.lora_A[adapter_name].weight\n        weight_tensor = weight_tensor.to(dtype)\n            \n        orig_module.weight.data = weight_tensor\n\n    def pissa_init(self, adapter_name, init_lora_weights):\n        weight = self.get_orig_module().weight\n        dtype = weight.dtype\n        if dtype not in [torch.float32, torch.float16, torch.bfloat16]:\n            raise TypeError(\n                \"Please initialize PiSSA under float32, float16, or bfloat16. \"\n                \"Subsequently, re-quantize the residual model to help minimize quantization errors.\"\n            )\n        weight = transpose(weight.to(torch.float32), self.fan_in_fan_out)\n        if init_lora_weights == \"pissa\":\n            # USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel},\n            V, S, Uh = torch.linalg.svd(weight.data, full_matrices=False)\n            Vr = V[:, : self.r[adapter_name]]\n            Sr = S[: self.r[adapter_name]]\n            Sr /= self.scaling[adapter_name]\n            Uhr = Uh[: self.r[adapter_name]]\n        elif len(init_lora_weights.split(\"_niter_\")) == 2:\n            Vr, Sr, Ur = svd_lowrank(\n                weight.data, self.r[adapter_name], niter=int(init_lora_weights.split(\"_niter_\")[-1])\n            )\n            Sr /= self.scaling[adapter_name]\n            Uhr = Ur.t()\n        else:\n            raise ValueError(\n                f\"init_lora_weights should be 'pissa' or 'pissa_niter_[number of iters]', got {init_lora_weights} instead.\"\n      
      )\n\n        lora_A = torch.diag(torch.sqrt(Sr)) @ Uhr\n        lora_B = Vr @ torch.diag(torch.sqrt(Sr))\n        self.lora_A[adapter_name].weight.data = lora_A\n        self.lora_B[adapter_name].weight.data = lora_B\n        weight = weight.data - self.scaling[adapter_name] * lora_B @ lora_A\n        weight = transpose(weight.to(dtype), self.fan_in_fan_out)\n        self.get_orig_module().weight.data = weight\n\n    def loftq_init(self, adapter_name):\n        from peft.utils.loftq_utils import loftq_init\n\n        weight = self.get_orig_module().weight\n        kwargs = {\n            \"num_bits\": self.kwargs.get(\"loftq_bits\", 4),\n            \"reduced_rank\": self.r[adapter_name],\n            \"num_iter\": self.kwargs.get(\"loftq_iter\", 1),\n        }\n\n        qweight, lora_A, lora_B = loftq_init(weight, **kwargs)\n        if adapter_name in self.lora_A.keys():\n            # linear layers: use the low-rank factors computed by LoftQ\n            self.lora_A[adapter_name].weight.data = lora_A\n            self.lora_B[adapter_name].weight.data = lora_B\n        if adapter_name in self.lora_embedding_A.keys():\n            # embedding layers: use the same LoftQ factors\n            self.lora_embedding_A[adapter_name].weight.data = lora_A\n            self.lora_embedding_B[adapter_name].weight.data = lora_B\n        self.get_orig_module().weight.data = qweight\n\n    def _cache_store(self, key: str, value: Any) -> None:\n        self._caches[key] = value\n\n    def _cache_pop(self, key: str) -> Any:\n        value = self._caches.pop(key)\n        return value\n\n    def set_scale(self, adapter, scale):\n        if adapter not in self.scaling:\n            # Ignore the case where the adapter is not in the layer\n            return\n        self.scaling[adapter] = scale * self.lora_alpha[adapter] / self.r[adapter]\n\n    def scale_layer(self, scale: float) -> None:\n        if scale == 1:\n            return\n\n        
for active_adapter in self.active_adapters:\n            if active_adapter not in self.lora_A.keys():\n                continue\n\n            self.scaling[active_adapter] *= scale\n\n    def unscale_layer(self, scale=None) -> None:\n        for active_adapter in self.active_adapters:\n            if active_adapter not in self.lora_A.keys():\n                continue\n\n            if scale is None:\n                self.scaling[active_adapter] = self.lora_alpha[active_adapter] / self.r[active_adapter]\n            else:\n                self.scaling[active_adapter] /= scale\n\n    def _check_forward_args(self, x, *args, **kwargs):\n        \"\"\"Check if the arguments are compatible with the configs and state of the model\"\"\"\n        adapter_names = kwargs.get(\"adapter_names\", None)\n        if adapter_names is None:\n            return\n\n        if len(x) != len(adapter_names):\n            msg = (\n                \"Length of `adapter_names` should be the same as the number of inputs, but got \"\n                f\"{len(adapter_names)} and {len(x)} respectively.\"\n            )\n            raise ValueError(msg)\n\n        if self.merged:\n            # It is unclear what would be the right thing to do if users pass adapter_names and there are merged\n            # adapters. Therefore, it is better to raise an error in this case.\n            msg = \"Cannot pass `adapter_names` when there are merged adapters, please call `unmerge_adapter` first.\"\n            raise ValueError(msg)\n\n        # DoRA is not supported (yet), check that it's not being used. 
Don't check \"__base__\", as this is the\n        # placeholder for the base model.\n        unique_adapters = {name for name in adapter_names if name != \"__base__\"}\n        for adapter_name in unique_adapters:\n            if self.use_dora.get(adapter_name, False):\n                msg = \"Cannot pass `adapter_names` when DoRA is enabled.\"\n                raise ValueError(msg)\n\n    def _mixed_batch_forward(\n        self, x: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any\n    ) -> torch.Tensor:\n        # This is a special method that handles the case when users pass the argument `adapter_names`. This is an\n        # extra argument that allows mixing different adapters in the same batch at inference time.\n        result = self.orig_module(x, *args, **kwargs)\n        torch_result_dtype = result.dtype\n\n        unique_adapters = set(adapter_names)\n        sub_batch_indices_list = []\n        for adapter in unique_adapters:\n            sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter])\n\n        for i, active_adapter in enumerate(unique_adapters):\n            if active_adapter == \"__base__\":\n                continue\n            if active_adapter not in self.lora_A.keys():\n                continue\n\n            lora_A = self.lora_A[active_adapter]\n            lora_B = self.lora_B[active_adapter]\n            dropout = self.lora_dropout[active_adapter]\n            scaling = self.scaling[active_adapter]\n\n            # getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear\n            # layer output\n            sub_batch = x[sub_batch_indices_list[i]].to(lora_A.weight.dtype)\n            lora_output = lora_B(lora_A(dropout(sub_batch))) * scaling\n            result[sub_batch_indices_list[i]] += lora_output.to(torch_result_dtype)\n\n        return result\n\nclass Linear(nn.Module, LoraLayer):\n    # Lora implemented in a dense 
layer\n    def __init__(\n        self,\n        orig_module,\n        adapter_name: str,\n        r: int = 0,\n        lora_alpha: int = 1,\n        lora_dropout: float = 0.0,\n        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)\n        is_target_conv_1d_layer: bool = False,\n        init_lora_weights: Union[bool, str] = True,\n        use_rslora: bool = False,\n        use_dora: bool = False,\n        lora_bias: bool = False,\n        **kwargs,\n    ) -> None:\n        super().__init__()\n        LoraLayer.__init__(self, orig_module, **kwargs)\n        self.fan_in_fan_out = fan_in_fan_out\n\n        self._active_adapter = adapter_name\n        self.update_layer(\n            adapter_name,\n            r,\n            lora_alpha=lora_alpha,\n            lora_dropout=lora_dropout,\n            init_lora_weights=init_lora_weights,\n            use_rslora=use_rslora,\n            use_dora=use_dora,\n            lora_bias=lora_bias,\n        )\n        self.is_target_conv_1d_layer = is_target_conv_1d_layer\n\n    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:\n        \"\"\"\n        Merge the active adapter weights into the base weights\n\n        Args:\n            safe_merge (`bool`, *optional*):\n                If True, the merge operation will be performed in a copy of the original weights and check for NaNs\n                before merging the weights. This is useful if you want to check if the merge operation will produce\n                NaNs. Defaults to `False`.\n            adapter_names (`list[str]`, *optional*):\n                The list of adapter names that should be merged. If None, all active adapters will be merged. 
Defaults\n                to `None`.\n        \"\"\"\n        if not adapter_names:\n            # no adapter to merge\n            return\n\n        for active_adapter in adapter_names:\n            if active_adapter in self.lora_A.keys():\n                orig_module = self.get_orig_module()\n                if safe_merge:\n                    # Note that safe_merge will be slower than the normal merge\n                    # because of the copy operation.\n                    orig_weights = orig_module.weight.data.clone()\n                    delta_weight = self.get_delta_weight(active_adapter)\n                    if not self.use_dora[active_adapter]:\n                        orig_weights += delta_weight\n                    else:\n                        # handle dora\n                        # since delta_weight already includes scaling, set it to 1 here\n                        weight_norm = (\n                            self.lora_magnitude_vector[active_adapter]\n                            .get_weight_norm(orig_weights, transpose(delta_weight, self.fan_in_fan_out), scaling=1)\n                            .detach()\n                        )\n                        # We need to cache weight_norm because it has to be based on the original weights. We\n                        # cannot calculate it on the fly based on the merged weights when unmerging because its a\n                        # different value\n                        self._cache_store(f\"{active_adapter}-weight_norm\", weight_norm)\n                        dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm\n                        dora_factor = transpose(dora_factor.view(-1, 1), self.fan_in_fan_out)\n                        orig_weights = dora_factor * (orig_weights + delta_weight)\n\n                    if not torch.isfinite(orig_weights).all():\n                        raise ValueError(\n                            f\"NaNs detected in the merged weights. 
The adapter {active_adapter} seems to be broken\"\n                        )\n\n                    orig_module.weight.data = orig_weights\n\n                    if self.lora_bias[active_adapter]:\n                        new_bias = orig_module.bias + self.lora_B[active_adapter].bias\n                        if not torch.isfinite(new_bias).all():\n                            raise ValueError(\n                                f\"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken\"\n                            )\n                        orig_module.bias.data = new_bias\n\n                else:\n                    delta_weight = self.get_delta_weight(active_adapter)\n                    if not self.use_dora[active_adapter]:\n                        orig_module.weight.data += delta_weight\n                    else:\n                        # handle dora\n                        # since delta_weight already includes scaling, set it to 1 here\n                        weight_norm = (\n                            self.lora_magnitude_vector[active_adapter]\n                            .get_weight_norm(\n                                orig_module.weight, transpose(delta_weight, self.fan_in_fan_out), scaling=1\n                            )\n                            .detach()\n                        )\n                        # We need to cache weight_norm because it has to be based on the original weights. 
We\n                        # cannot calculate it on the fly based on the merged weights when unmerging because its a\n                        # different value\n                        self._cache_store(f\"{active_adapter}-weight_norm\", weight_norm)\n                        dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm\n                        dora_factor = transpose(dora_factor.view(-1, 1), self.fan_in_fan_out)\n                        new_weight = dora_factor * (orig_module.weight.data + delta_weight)\n                        orig_module.weight.data = new_weight\n\n                    if self.lora_bias[active_adapter]:\n                        orig_module.bias.data += self.lora_B[active_adapter].bias\n\n                self.merged_adapters.append(active_adapter)\n\n    def unmerge(self) -> None:\n        \"\"\"\n        This method unmerges all merged adapter layers from the base weights.\n        \"\"\"\n        if not self.merged:\n            warnings.warn(\"Already unmerged. 
Nothing to do.\")\n            return\n        while len(self.merged_adapters) > 0:\n            active_adapter = self.merged_adapters.pop()\n            if active_adapter in self.lora_A.keys():\n                weight = self.get_orig_module().weight\n                delta_weight = self.get_delta_weight(active_adapter)\n                if not self.use_dora[active_adapter]:\n                    weight.data -= delta_weight\n                else:\n                    weight_norm = self._cache_pop(f\"{active_adapter}-weight_norm\")\n                    dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm\n                    weight_orig = weight.data / dora_factor.view(-1, 1) - delta_weight\n                    weight.data = weight_orig\n\n                if self.lora_bias[active_adapter]:\n                    self.get_orig_module().bias.data -= self.lora_B[active_adapter].bias\n\n    def get_delta_weight(self, adapter) -> torch.Tensor:\n        \"\"\"\n        Compute the delta weight for the given adapter.\n\n        Args:\n            adapter (str):\n                The name of the adapter for which the delta weight should be computed.\n        \"\"\"\n        device = self.lora_B[adapter].weight.device\n        dtype = self.lora_B[adapter].weight.dtype\n\n        # In case users want to merge adapter weights that are in\n        # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to\n        # (b)float16 because some CPUs have slow bf16/fp16 matmuls.\n        cast_to_fp32 = device.type == \"cpu\" and (dtype == torch.float16 or dtype == torch.bfloat16)\n\n        weight_A = self.lora_A[adapter].weight\n        weight_B = self.lora_B[adapter].weight\n\n        if cast_to_fp32:\n            weight_A = weight_A.float()\n            weight_B = weight_B.float()\n\n        output_tensor = transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter]\n\n        if 
cast_to_fp32:\n            output_tensor = output_tensor.to(dtype=dtype)\n\n            # cast back the weights\n            self.lora_A[adapter].weight.data = weight_A.to(dtype)\n            self.lora_B[adapter].weight.data = weight_B.to(dtype)\n\n        return output_tensor\n\n    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:\n        self._check_forward_args(x, *args, **kwargs)\n        adapter_names = kwargs.pop(\"adapter_names\", None)\n\n        if self.disable_adapters:\n            if self.merged:\n                self.unmerge()\n            result = self.orig_module(x, *args, **kwargs)\n        elif adapter_names is not None:\n            result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)\n        elif self.merged:\n            result = self.orig_module(x, *args, **kwargs)\n        else:\n            result = self.orig_module(x, *args, **kwargs)\n            torch_result_dtype = result.dtype\n            for active_adapter in self.active_adapters:\n                if active_adapter not in self.lora_A.keys():\n                    continue\n                lora_A = self.lora_A[active_adapter]\n                lora_B = self.lora_B[active_adapter]\n                dropout = self.lora_dropout[active_adapter]\n                scaling = self.scaling[active_adapter]\n                x = x.to(lora_A.weight.dtype)\n\n                # TODO: the DoRA branch is removed here for now; only plain LoRA is applied.\n                result = result + lora_B(lora_A(dropout(x))) * scaling\n\n            result = result.to(torch_result_dtype)\n\n        return result\n\n    def __repr__(self) -> str:\n        rep = super().__repr__()\n        return \"lora.\" + rep\n\n\nclass Embedding(nn.Module, LoraLayer):\n    # LoRA implemented in an Embedding layer\n    def __init__(\n        self,\n        orig_module: nn.Module,\n        adapter_name: str,\n        r: int = 0,\n        lora_alpha: int = 1,\n        lora_dropout: float = 0.0,\n  
      init_lora_weights: Union[bool, str] = True,\n        use_rslora: bool = False,\n        use_dora: bool = False,\n        lora_bias: bool = False,\n        **kwargs,\n    ) -> None:\n        if lora_bias:\n            # lora_bias=True is not supported (yet) for embedding layers, as they use nn.Parameter\n            raise ValueError(f\"lora_bias={lora_bias} is not supported for {self.__class__.__name__}.\")\n\n        super().__init__()\n        LoraLayer.__init__(self, orig_module)\n\n        self._active_adapter = adapter_name\n        self.update_layer(\n            adapter_name,\n            r,\n            lora_alpha=lora_alpha,\n            lora_dropout=lora_dropout,\n            init_lora_weights=init_lora_weights,\n            use_rslora=use_rslora,\n            use_dora=use_dora,\n            lora_bias=lora_bias,\n        )\n\n    def update_layer(\n        self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora, lora_bias\n    ):\n        if r <= 0:\n            raise ValueError(f\"`r` should be a positive integer value but the value passed is {r}\")\n\n        self.r[adapter_name] = r\n        self.lora_alpha[adapter_name] = lora_alpha\n        if lora_dropout > 0.0:\n            lora_dropout_layer = nn.Dropout(p=lora_dropout)\n        else:\n            lora_dropout_layer = nn.Identity()\n\n        self.lora_dropout[adapter_name] = lora_dropout_layer\n        # Actual trainable parameters\n        weight_A = torch.randn((r, self.in_features))\n        weight_B = torch.randn((self.out_features, r))\n        self.lora_embedding_A[adapter_name] = nn.Parameter(weight_A)\n        self.lora_embedding_B[adapter_name] = nn.Parameter(weight_B)\n        self.lora_bias[adapter_name] = lora_bias\n\n        if use_rslora:\n            self.scaling[adapter_name] = lora_alpha / math.sqrt(r)\n        else:\n            self.scaling[adapter_name] = lora_alpha / r\n\n        if init_lora_weights == \"loftq\":\n            
self.loftq_init(adapter_name)\n        elif init_lora_weights:\n            self.reset_lora_parameters(adapter_name, init_lora_weights)\n\n        # call this before dora_init\n        self._move_adapter_to_device_of_orig_module(adapter_name)\n\n        self.set_adapter(self.active_adapters)\n\n    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:\n        \"\"\"\n        Merge the active adapter weights into the base weights\n\n        Args:\n            safe_merge (`bool`, *optional*):\n                If True, the merge operation will be performed in a copy of the original weights and check for NaNs\n                before merging the weights. This is useful if you want to check if the merge operation will produce\n                NaNs. Defaults to `False`.\n            adapter_names (`list[str]`, *optional*):\n                The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults\n                to `None`.\n        \"\"\"\n        if not adapter_names:\n            # no adapter to merge\n            return\n\n        for active_adapter in adapter_names:\n            if active_adapter in self.lora_embedding_A.keys():\n                orig_module = self.get_orig_module()\n                if safe_merge:\n                    # Note that safe_merge will be slower than the normal merge\n                    # because of the copy operation.\n                    orig_weights = orig_module.weight.data.clone()\n                    orig_weights += self.get_delta_weight(active_adapter)\n\n                    if not torch.isfinite(orig_weights).all():\n                        raise ValueError(\n                            f\"NaNs detected in the merged weights. 
The adapter {active_adapter} seems to be broken\"\n                        )\n\n                    orig_module.weight.data = orig_weights\n                else:\n                    orig_module.weight.data += self.get_delta_weight(active_adapter)\n                self.merged_adapters.append(active_adapter)\n\n    def unmerge(self) -> None:\n        \"\"\"\n        This method unmerges all merged adapter layers from the base weights.\n        \"\"\"\n        if not self.merged:\n            warnings.warn(\"Already unmerged. Nothing to do.\")\n            return\n        while len(self.merged_adapters) > 0:\n            active_adapter = self.merged_adapters.pop()\n            if active_adapter in self.lora_embedding_A.keys():\n                self.get_orig_module().weight.data -= self.get_delta_weight(active_adapter)\n\n    def get_delta_weight(self, adapter) -> torch.Tensor:\n        \"\"\"\n        Compute the delta weight for the given adapter.\n\n        Args:\n            adapter (str):\n                The name of the adapter for which the delta weight should be computed.\n        \"\"\"\n        device = self.lora_embedding_B[adapter].device\n        dtype = self.lora_embedding_A[adapter].dtype\n\n        # In case users wants to merge the adapter weights that are in\n        # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to\n        # (b)float16 because some CPUs have slow bf16/fp16 matmuls.\n        cast_to_fp32 = device.type == \"cpu\" and (dtype == torch.float16 or dtype == torch.bfloat16)\n\n        weight_A = self.lora_embedding_A[adapter]\n        weight_B = self.lora_embedding_B[adapter]\n\n        if cast_to_fp32:\n            weight_A = weight_A.float()\n            weight_B = weight_B.float()\n\n        output_tensor = transpose(weight_B @ weight_A, True) * self.scaling[adapter]\n\n        if cast_to_fp32:\n            output_tensor = output_tensor.to(dtype=dtype)\n\n            # cast 
back the weights\n            self.lora_embedding_A[adapter] = weight_A.to(dtype)\n            self.lora_embedding_B[adapter] = weight_B.to(dtype)\n\n        return output_tensor\n\n    def _mixed_batch_forward(\n        self, x: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any\n    ) -> torch.Tensor:\n        # This is a special method that handles the case when users pass the argument `adapter_names`. This is an\n        # extra argument that allows mixing different adapters in the same batch at inference time.\n        result = self.orig_module(x, *args, **kwargs)\n\n        unique_adapters = set(adapter_names)\n        sub_batch_indices_list = []\n        for adapter in unique_adapters:\n            sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter])\n\n        for i, active_adapter in enumerate(unique_adapters):\n            if active_adapter == \"__base__\":\n                continue\n            if active_adapter not in self.lora_embedding_A.keys():\n                continue\n\n            embedding_A = self.lora_embedding_A[active_adapter].T\n            embedding_B = self.lora_embedding_B[active_adapter].T\n            scaling = self.scaling[active_adapter]\n\n            # getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear\n            # layer output\n            sub_batch = x[sub_batch_indices_list[i]]\n            after_A = self._embed(sub_batch, embedding_A)\n            result[sub_batch_indices_list[i]] += (after_A @ embedding_B) * scaling\n\n        return result\n\n    def _embed(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:\n        orig_module = self.get_orig_module()\n        return F.embedding(\n            input,\n            weight,\n            padding_idx=orig_module.padding_idx,\n            max_norm=orig_module.max_norm,\n            norm_type=orig_module.norm_type,\n            
scale_grad_by_freq=orig_module.scale_grad_by_freq,\n            sparse=orig_module.sparse,\n        )\n\n    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:\n        # TODO: no dtype conversion here, unlike in Linear, is that correct?\n        self._check_forward_args(x, *args, **kwargs)\n        adapter_names = kwargs.pop(\"adapter_names\", None)\n\n        if self.disable_adapters:\n            if self.merged:\n                self.unmerge()\n            result = self.orig_module(x, *args, **kwargs)\n        elif adapter_names is not None:\n            result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)\n        elif self.merged:\n            result = self.orig_module(x, *args, **kwargs)\n        else:\n            result = self.orig_module(x, *args, **kwargs)\n            torch_result_dtype = result.dtype\n            for active_adapter in self.active_adapters:\n                if active_adapter not in self.lora_embedding_A:\n                    continue\n                embedding_A = self.lora_embedding_A[active_adapter].T\n                embedding_B = self.lora_embedding_B[active_adapter].T\n                scaling = self.scaling[active_adapter]\n\n                if not self.use_dora[active_adapter]:\n                    after_A = self._embed(x, embedding_A)\n                    result = result + (after_A @ embedding_B) * scaling\n                else:\n                    mag_norm_scale, dora_result = self.lora_magnitude_vector[active_adapter](\n                        x,\n                        lora_A=embedding_A,\n                        lora_B=embedding_B,\n                        scaling=scaling,\n                        orig_module=self.get_orig_module(),\n                        embed_fn=self._embed,\n                    )\n                    result = mag_norm_scale * result + dora_result\n            result = result.to(torch_result_dtype)\n\n        return result\n\n    def 
__repr__(self) -> str:\n        rep = super().__repr__()\n        return \"lora.\" + rep\n    \nclass KTransformersLinearLora(KTransformersLinear, LoraLayer):\n    def __init__(\n        self,\n        orig_module: KTransformersLinear,\n        adapter_name: str,\n        r: int = 0,\n        lora_alpha: int = 1,\n        lora_dropout: float = 0.0,\n        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)\n        is_target_conv_1d_layer: bool = False,\n        init_lora_weights: Union[bool, str] = True,\n        use_rslora: bool = False,\n        use_dora: bool = False,\n        lora_bias: bool = False,\n        **kwargs,\n    ):\n        # super().__init__(orig_module, **kwargs)\n        # print(f\"KTransformersLinearLora:{KTransformersLinearLora.__mro__}\")\n        \n        KTransformersLinear.__init__(\n            self,\n            key=orig_module.key,\n            gguf_loader=orig_module.gguf_loader,\n            config=orig_module.config,\n            orig_module=orig_module.orig_module,\n            generate_device=orig_module.generate_device,\n            prefill_device=orig_module.prefill_device,\n            prefill_op=\"KLinearTorch\",\n            generate_op=\"KLinearTorch\",\n            **kwargs\n        )\n\n        LoraLayer.__init__(self, orig_module=orig_module.orig_module, **kwargs)\n\n        # self.load(mode = InferenceState.GENERATE) # for test\n\n        self._active_adapter = adapter_name\n\n        \n        self.update_layer(\n            adapter_name,\n            r,\n            lora_alpha=lora_alpha,\n            lora_dropout=lora_dropout,\n            init_lora_weights=init_lora_weights,\n            use_rslora=use_rslora,\n            use_dora=use_dora,\n            lora_bias=lora_bias,\n        )\n\n        self.is_target_conv_1d_layer = is_target_conv_1d_layer\n\n    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:\n       
 if not adapter_names:\n            return\n\n        for active_adapter in adapter_names:\n            if active_adapter in self.lora_A:\n                orig_module = self.get_orig_module()\n                if safe_merge:\n                    orig_weights = orig_module.weight.data.clone()\n                    delta_weight = self.get_delta_weight(active_adapter)\n                    if not self.use_dora.get(active_adapter, False):\n                        orig_weights += delta_weight\n                    else:\n                        weight_norm = self.lora_magnitude_vector[active_adapter].get_weight_norm(\n                            orig_weights, \n                            transpose(delta_weight, self.fan_in_fan_out), \n                            scaling=1\n                        ).detach()\n                        self._cache_store(f\"{active_adapter}-weight_norm\", weight_norm)\n                        dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm\n                        dora_factor = transpose(dora_factor.view(-1, 1), self.fan_in_fan_out)\n                        orig_weights = dora_factor * (orig_weights + delta_weight)\n\n                    if not torch.isfinite(orig_weights).all():\n                        raise ValueError(f\"NaNs detected when merging adapter {active_adapter}\")\n                    orig_module.weight.data = orig_weights\n\n                    if self.lora_bias.get(active_adapter, False):\n                        new_bias = orig_module.bias.data + self.lora_B[active_adapter].bias\n                        if not torch.isfinite(new_bias).all():\n                            raise ValueError(f\"NaNs detected in bias when merging adapter {active_adapter}\")\n                        orig_module.bias.data = new_bias\n                else:\n                    delta_weight = self.get_delta_weight(active_adapter)\n                    if not self.use_dora.get(active_adapter, False):\n                        
orig_module.weight.data += delta_weight\n                    else:\n                        weight_norm = self.lora_magnitude_vector[active_adapter].get_weight_norm(\n                            orig_module.weight.data,\n                            transpose(delta_weight, self.fan_in_fan_out),\n                            scaling=1\n                        ).detach()\n                        self._cache_store(f\"{active_adapter}-weight_norm\", weight_norm)\n                        dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm\n                        dora_factor = transpose(dora_factor.view(-1, 1), self.fan_in_fan_out)\n                        orig_module.weight.data = dora_factor * (orig_module.weight.data + delta_weight)\n\n                    if self.lora_bias.get(active_adapter, False):\n                        orig_module.bias.data += self.lora_B[active_adapter].bias\n\n                self.merged_adapters.append(active_adapter)\n\n    def unmerge(self) -> None:\n        if not self.merged:\n            warnings.warn(\"Already unmerged. 
Nothing to do.\")\n            return\n        while self.merged_adapters:\n            active_adapter = self.merged_adapters.pop()\n            if active_adapter in self.lora_A:\n                orig_module = self.get_orig_module()\n                delta_weight = self.get_delta_weight(active_adapter)\n                if not self.use_dora.get(active_adapter, False):\n                    orig_module.weight.data -= delta_weight\n                else:\n                    weight_norm = self._cache_pop(f\"{active_adapter}-weight_norm\")\n                    dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm\n                    orig_weights = orig_module.weight.data / dora_factor.view(-1, 1) - delta_weight\n                    orig_module.weight.data = orig_weights\n\n                if self.lora_bias.get(active_adapter, False):\n                    orig_module.bias.data -= self.lora_B[active_adapter].bias\n\n    def get_delta_weight(self, adapter: str) -> torch.Tensor:\n        lora_A = self.lora_A[adapter].weight\n        lora_B = self.lora_B[adapter].weight\n        delta_weight = transpose(lora_B @ lora_A, self.fan_in_fan_out) * self.scaling[adapter]\n        return delta_weight\n\n    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:\n        result = super().forward(x, *args, **kwargs)\n        if self.disable_adapters or self.merged:\n            return result\n\n        for active_adapter in self.active_adapters:\n            if active_adapter not in self.lora_A:\n                continue\n            lora_A = self.lora_A[active_adapter]\n            lora_B = self.lora_B[active_adapter]\n            dropout = self.lora_dropout[active_adapter]\n            scaling = self.scaling[active_adapter]\n            # Use a local tensor so dropout and the dtype cast do not accumulate on `x` across adapters.\n            lora_x = dropout(x).to(lora_A.weight.dtype)\n            lora_output = lora_B(lora_A(lora_x)) * scaling\n            result += lora_output.to(result.dtype)\n\n        return result\n    \n"
  },
  {
    "path": "kt-sft/ktransformers/sft/peft_utils/lora_model.py",
    "content": "# Copyright 2023-present the HuggingFace Inc. team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom __future__ import annotations\n\nfrom abc import ABC\nimport math\nimport operator\nimport warnings\nfrom contextlib import contextmanager\nfrom dataclasses import asdict, replace\nfrom enum import Enum\nfrom functools import partial, reduce\nfrom typing import Literal, Optional, Union\nimport logging\n\nimport torch\nfrom torch import nn\nfrom tqdm import tqdm\n\nfrom peft.utils.other import get_pattern_key\nfrom peft.utils import ModulesToSaveWrapper, _get_submodules\nfrom peft.tuners.tuners_utils import check_target_module_exists\nfrom peft.config import PeftConfig\n\nfrom ktransformers.sft.peft_utils.lora_layer import dispatch_default, LoraLayer, BaseTunerLayer\n\nlogger = logging.getLogger(__name__)\n\nclass LoraModel(nn.Module, ABC):\n    \"\"\"\n    Creates Low Rank Adapter (LoRA) model from a pretrained transformers model.\n\n    The method is described in detail in https://arxiv.org/abs/2106.09685.\n\n    Args:\n        model ([`torch.nn.Module`]): The model to be adapted.\n        config ([`LoraConfig`]): The configuration of the Lora model.\n        adapter_name (`str`): The name of the adapter, defaults to `\"default\"`.\n        low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):\n            Create empty adapter weights on meta device. 
Useful to speed up the loading process.\n\n    Returns:\n        `torch.nn.Module`: The Lora model.\n\n    Example:\n\n        ```py\n        >>> from transformers import AutoModelForSeq2SeqLM\n        >>> from peft import LoraModel, LoraConfig\n\n        >>> config = LoraConfig(\n        ...     task_type=\"SEQ_2_SEQ_LM\",\n        ...     r=8,\n        ...     lora_alpha=32,\n        ...     target_modules=[\"q\", \"v\"],\n        ...     lora_dropout=0.01,\n        ... )\n\n        >>> model = AutoModelForSeq2SeqLM.from_pretrained(\"t5-base\")\n        >>> lora_model = LoraModel(model, config, \"default\")\n        ```\n\n        ```py\n        >>> import torch\n        >>> import transformers\n        >>> from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training\n\n        >>> rank = ...\n        >>> target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"out_proj\", \"fc_in\", \"fc_out\", \"wte\"]\n        >>> config = LoraConfig(\n        ...     r=4, lora_alpha=16, target_modules=target_modules, lora_dropout=0.1, bias=\"none\", task_type=\"CAUSAL_LM\"\n        ... )\n        >>> quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)\n\n        >>> tokenizer = transformers.AutoTokenizer.from_pretrained(\n        ...     \"kakaobrain/kogpt\",\n        ...     revision=\"KoGPT6B-ryan1.5b-float16\",  # or float32 version: revision=KoGPT6B-ryan1.5b\n        ...     bos_token=\"[BOS]\",\n        ...     eos_token=\"[EOS]\",\n        ...     unk_token=\"[UNK]\",\n        ...     pad_token=\"[PAD]\",\n        ...     mask_token=\"[MASK]\",\n        ... )\n        >>> model = transformers.GPTJForCausalLM.from_pretrained(\n        ...     \"kakaobrain/kogpt\",\n        ...     revision=\"KoGPT6B-ryan1.5b-float16\",  # or float32 version: revision=KoGPT6B-ryan1.5b\n        ...     pad_token_id=tokenizer.eos_token_id,\n        ...     use_cache=False,\n        ...     device_map={\"\": rank},\n        ...     
torch_dtype=torch.float16,\n        ...     quantization_config=quantization_config,\n        ... )\n        >>> model = prepare_model_for_kbit_training(model)\n        >>> lora_model = get_peft_model(model, config)\n        ```\n\n    **Attributes**:\n        - **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted.\n        - **peft_config** ([`LoraConfig`]): The configuration of the Lora model.\n    \"\"\"\n\n    prefix: str = \"lora_\"\n\n    def __init__(\n        self,\n        model,\n        peft_config: Union[PeftConfig, dict[str, PeftConfig]],\n        adapter_name: str,\n        low_cpu_mem_usage: bool = False,\n    ) -> None:\n        super().__init__()\n\n        self.model = model\n        self.targeted_module_names: list[str] = []\n\n        # For advanced developers, if you want to attach multiple adapters to your\n        # model, just add a `peft_config` dict attribute to your model.\n        if not hasattr(self, \"peft_config\"):\n            self.peft_config = {adapter_name: peft_config} if isinstance(peft_config, PeftConfig) else peft_config\n        else:\n            logger.info(\n                \"Already found a `peft_config` attribute in the model. This will lead to having multiple adapters\"\n                \" in the model. 
Make sure to know what you are doing!\"\n            )\n            if isinstance(peft_config, PeftConfig):\n                self.peft_config[adapter_name] = peft_config\n            else:\n                # user is adding a dict of PeftConfigs\n                self.peft_config.update(peft_config)\n\n        self.active_adapter: str | list[str] = adapter_name\n        self._pre_injection_hook(self.model, self.peft_config[adapter_name], adapter_name)\n        \n        self.inject_adapter(self.model, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)\n\n        # Copy the peft_config in the injected model.\n        self.model.peft_config = self.peft_config\n\n    def inject_adapter(\n        self, model: nn.Module, adapter_name: str, autocast_adapter_dtype: bool = True, low_cpu_mem_usage: bool = False\n    ) -> None:\n        r\"\"\"\n        Creates adapter layers and replaces the target modules with the adapter layers. This method is called under the\n        hood by `peft.mapping.get_peft_model` if a non-prompt tuning adapter class is passed.\n\n        The corresponding PEFT config is directly retrieved from the `peft_config` attribute of the BaseTuner class.\n\n        Args:\n            model (`nn.Module`):\n                The model to be tuned.\n            adapter_name (`str`):\n                The adapter name.\n            autocast_adapter_dtype (`bool`, *optional*):\n                Whether to autocast the adapter dtype. Defaults to `True`.\n            low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):\n                Create empty adapter weights on meta device. 
Useful to speed up the loading process.\n\n        \"\"\"\n        peft_config = self.peft_config[adapter_name]\n        excluded_modules = []\n        unmatched_modules = []\n        # Note: If possible, all checks should be performed *at the start of this method*.\n        # This way, we can raise early if something goes wrong, without leaving the model\n        # in a bad (half-initialized) state.\n\n        _check_for_modules_to_save = getattr(peft_config, \"modules_to_save\", None) is not None\n        _has_modules_to_save = False\n\n        key_list = [key for key, _ in model.named_modules()]\n\n        for key in key_list:\n            if not key:\n                continue\n            # Check for modules_to_save in case\n            if _check_for_modules_to_save and any(\n                key.endswith(f\"{module_to_save}\") for module_to_save in peft_config.modules_to_save\n            ):\n                # Optionally set the modules to save\n                parent, target, target_name = _get_submodules(model, key)\n\n                if not isinstance(target, ModulesToSaveWrapper):\n                    new_module = ModulesToSaveWrapper(target, adapter_name)\n                    setattr(parent, target_name, new_module)\n                else:\n                    target.update(adapter_name)\n\n                _has_modules_to_save = True\n                continue\n\n            result = check_target_module_exists(peft_config, key)\n            if not result:\n                unmatched_modules.append(key)\n            else:\n                self.targeted_module_names.append(key)\n                parent, target, target_name = _get_submodules(model, key)\n\n                # TODO: not consider the low_cpu_mem_usage up to now\n                self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key)\n\n        # It's important to set the adapter here (again), because otherwise it can happen that if a 2nd adapter is\n        # 
added, and it targets different layer(s) than the first adapter (which is active), then those different\n        # layers will be activated, which we don't want.\n        # TODO: not consider multi-adapter up to now\n        # self.set_adapter(self.active_adapters)\n        self._mark_only_adapters_as_trainable(model)\n\n        if self.peft_config[adapter_name].inference_mode:\n            for n, p in model.named_parameters():\n                if adapter_name in n:\n                    p.requires_grad = False\n\n    def _create_and_replace(\n        self,\n        lora_config,\n        adapter_name,\n        target,\n        target_name,\n        parent,\n        current_key,\n    ):\n\n        # Regexp matching - Find key which matches current target_name in patterns provided\n        r_key = get_pattern_key(lora_config.rank_pattern.keys(), current_key)\n        alpha_key = get_pattern_key(lora_config.alpha_pattern.keys(), current_key)\n        r = lora_config.rank_pattern.get(r_key, lora_config.r)\n        alpha = lora_config.alpha_pattern.get(alpha_key, lora_config.lora_alpha)\n\n        kwargs = {\n            \"r\": r,\n            \"lora_alpha\": alpha,\n            \"lora_dropout\": lora_config.lora_dropout,\n            \"fan_in_fan_out\": lora_config.fan_in_fan_out,\n            \"init_lora_weights\": lora_config.init_lora_weights,\n            \"use_rslora\": lora_config.use_rslora,\n            \"use_dora\": lora_config.use_dora,\n            \"ephemeral_gpu_offload\": lora_config.runtime_config.ephemeral_gpu_offload,\n            \"lora_bias\": lora_config.lora_bias,\n            \"loaded_in_8bit\": getattr(self.model, \"is_loaded_in_8bit\", False),\n            \"loaded_in_4bit\": getattr(self.model, \"is_loaded_in_4bit\", False),\n        }\n\n        new_module = self._create_new_module(lora_config, adapter_name, target, parent, **kwargs)\n        self._replace_module(parent, target_name, new_module, target)\n\n    def _replace_module(self, parent, 
child_name, new_module, child):\n        setattr(parent, child_name, new_module)\n        # It's not necessary to set requires_grad here, as that is handled by\n        # _mark_only_adapters_as_trainable\n\n        # child layer wraps the original module, unpack it\n        if hasattr(child, \"orig_module\"):\n            child = child.orig_module\n\n        if not hasattr(new_module, \"orig_module\"):\n            if hasattr(new_module, \"W_q\"):  # HQQ\n                new_module.W_q = child.W_q\n            else:\n                new_module.weight = child.weight\n            if hasattr(child, \"bias\"):\n                new_module.bias = child.bias\n\n        if getattr(child, \"state\", None) is not None:\n            if hasattr(new_module, \"orig_module\"):\n                new_module.orig_module.state = child.state\n            else:\n                new_module.state = child.state\n            new_module.to(child.weight.device)\n\n        meta = torch.device(\"meta\")\n        # dispatch to correct device\n        for name, module in new_module.named_modules():\n            if (self.prefix in name) or (\"ranknum\" in name):\n                weight = (\n                    child.qweight\n                    if hasattr(child, \"qweight\")\n                    else child.W_q\n                    if hasattr(child, \"W_q\")\n                    else child.weight\n                    if hasattr(child, \"weight\")\n                    else child.generate_linear.weight\n                    if hasattr(child.generate_linear, \"weight\")\n                    else next(child.parameters())\n                )\n                # (orig_module): Lora.Linear(\n                    # (orig_module): Linear(..),\n                    # (Lora_A): Linear(..)...)\n                \n                if not any(p.device == meta for p in module.parameters()):\n                    module.to(weight.device)\n\n    def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:\n        
for n, p in model.named_parameters():\n            if self.prefix not in n:\n                p.requires_grad = False\n\n        for active_adapter in self.active_adapters:\n            bias = self.peft_config[active_adapter].bias\n            if bias == \"none\":\n                continue\n\n            if bias == \"all\":\n                for n, p in model.named_parameters():\n                    if \"bias\" in n:\n                        p.requires_grad = True\n            elif bias == \"lora_only\":\n                for m in model.modules():\n                    if isinstance(m, LoraLayer) and hasattr(m, \"bias\") and m.bias is not None:\n                        m.bias.requires_grad = True\n            else:\n                raise NotImplementedError(f\"Requested bias: {bias}, is not implemented.\")\n\n    @staticmethod\n    def _create_new_module(lora_config, adapter_name, target, parent, **kwargs):\n        # Collect dispatcher functions to decide what backend to use for the replaced LoRA layer. The order matters,\n        # because the first match is always used. Therefore, the default layers should be checked last.\n        dispatchers = []\n\n        dispatchers.extend(\n            [\n                dispatch_default, # TODO\n            ]\n        )\n\n        new_module = None\n        for dispatcher in dispatchers:\n            new_module = dispatcher(target=target, adapter_name=adapter_name, lora_config=lora_config, **kwargs)\n            if new_module is not None:  # first match wins\n                break\n\n        if new_module is None:\n            # no module could be matched\n            raise ValueError(\n                f\"Target module {target} is not supported. 
Currently, only the following modules are supported: \"\n                \"`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `torch.nn.Conv3d`, \"\n                \"`transformers.pytorch_utils.Conv1D`.\"\n            )\n\n        return new_module\n\n    def __getattr__(self, name: str):\n        \"\"\"Forward missing attributes to the wrapped module.\"\"\"\n        try:\n            return super().__getattr__(name)  # defer to nn.Module's logic\n        except AttributeError:\n            if name == \"model\":  # see #1892: prevent infinite recursion if class is not initialized\n                raise\n            return getattr(self.model, name)\n\n    def _pre_injection_hook(self, model: nn.Module, config: PeftConfig, adapter_name: str) -> None:\n        r\"\"\"\n        A hook to be called before the adapter is injected into the model. This method can be overridden by child\n        classes to perform any pre-injection operations.\n\n        Args:\n            model (`nn.Module`):\n                The model to be adapted.\n            config (`PeftConfig`):\n                The adapter config.\n            adapter_name (`str`):\n                The adapter name.\n        \"\"\"\n        pass\n\n    def _set_adapter_layers(self, enabled: bool = True) -> None:\n        for module in self.model.modules():\n            if isinstance(module, BaseTunerLayer):\n                module.enable_adapters(enabled)\n\n    def disable_adapter_layers(self) -> None:\n        \"\"\"\n        Disable all adapters in-place.\n\n        When disabling all adapters, the model output corresponds to the output of the base model.\n        \"\"\"\n        # TODO: deprecate in favor of enable_adapters\n        self._set_adapter_layers(enabled=False)\n\n    def enable_adapter_layers(self) -> None:\n        \"\"\"\n        Enable all adapters in-place\n        \"\"\"\n        # TODO: deprecate in favor of enable_adapters\n        self._set_adapter_layers(enabled=True)\n        \n  
      \n    # def set_adapter(self, adapter_names: str | list[str]) -> None:\n    #     \"\"\"Set the active adapter(s).\n\n    #     Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is\n    #     not desired, use the following code.\n\n    #     ```py\n    #     >>> for name, param in model_peft.named_parameters():\n    #     ...     if ...:  # some check on name (ex. if 'lora' in name)\n    #     ...         param.requires_grad = False\n    #     ```\n\n    #     Args:\n    #         adapter_name (`str` or `List[str]`): Name of the adapter(s) to be activated.\n    #     \"\"\"\n    #     if isinstance(adapter_names, str):\n    #         adapter_names = [adapter_names]\n\n    #     # Deactivate grads on the inactive adapter and activate grads on the active adapter\n    #     for layer_name in self.adapter_layer_names:\n    #         module_dict = getattr(self, layer_name)\n    #         for key, layer in module_dict.items():\n    #             if key in adapter_names:\n    #                 # Note: It is possible that not a single layer is called with requires_grad_(True) here. This may\n    #                 # happen if a completely different adapter layer is being activated.\n    #                 layer.requires_grad_(True)\n    #             else:\n    #                 layer.requires_grad_(False)\n\n    #     self._active_adapter = adapter_names\n    \n    @property\n    def active_adapters(self) -> list[str]:\n        if isinstance(self.active_adapter, str):\n            return [self.active_adapter]\n        # is already a list of str\n        return self.active_adapter"
  },
  {
    "path": "kt-sft/ktransformers/sft/peft_utils/mapping.py",
    "content": "import torch\nfrom transformers import PreTrainedModel\nimport warnings\nfrom typing import TYPE_CHECKING, Any, Optional\n\nfrom peft.config import PeftConfig\n\nfrom ktransformers.sft.peft_utils.lora_model import LoraModel\nfrom ktransformers.sft.peft_utils.peft_model import PeftModel, PeftModelForCausalLM\n\ndef get_peft_model(\n    model: PreTrainedModel,\n    peft_config: PeftConfig,\n    adapter_name: str = \"default\",\n    mixed: bool = False,\n    autocast_adapter_dtype: bool = True,\n    revision: Optional[str] = None,\n    low_cpu_mem_usage: bool = False,\n) -> PeftModel:\n    \"\"\"\n    Returns a Peft model object from a model and a config.\n\n    Args:\n        model ([`transformers.PreTrainedModel`]):\n            Model to be wrapped.\n        peft_config ([`PeftConfig`]):\n            Configuration object containing the parameters of the Peft model.\n        adapter_name (`str`, `optional`, defaults to `\"default\"`):\n            The name of the adapter to be injected, if not provided, the default adapter name is used (\"default\").\n        mixed (`bool`, `optional`, defaults to `False`):\n            Whether to allow mixing different (compatible) adapter types.\n        autocast_adapter_dtype (`bool`, *optional*):\n            Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights\n            using float16 or bfloat16 to float32, as this is typically required for stable training, and only affect\n            select PEFT tuners.\n        revision (`str`, `optional`, defaults to `main`):\n            The revision of the base model. If this isn't set, the saved peft model will load the `main` revision for\n            the base model\n        low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):\n            Create empty adapter weights on meta device. Useful to speed up the loading process. 
Leave this setting as\n            False if you intend on training the model, unless the adapter weights will be replaced by different weights\n            before training starts.\n    \"\"\"\n    new_name = model.__dict__.get(\"name_or_path\", None)\n    peft_config.base_model_name_or_path = new_name\n\n    return PeftModelForCausalLM(\n        model,\n        peft_config,\n        adapter_name=adapter_name,\n        autocast_adapter_dtype=autocast_adapter_dtype,\n        low_cpu_mem_usage=low_cpu_mem_usage,\n    )\n\ndef inject_adapter_in_model(\n    peft_config: PeftConfig, model: torch.nn.Module, adapter_name: str = \"default\", low_cpu_mem_usage: bool = False\n) -> torch.nn.Module:\n    r\"\"\"\n    A simple API to create and inject adapter in-place into a model. Currently the API does not support prompt learning\n    methods and adaption prompt. Make sure to have the correct `target_names` set in the `peft_config` object. The API\n    calls `get_peft_model` under the hood but would be restricted only to non-prompt learning methods.\n\n    Args:\n        peft_config (`PeftConfig`):\n            Configuration object containing the parameters of the Peft model.\n        model (`torch.nn.Module`):\n            The input model where the adapter will be injected.\n        adapter_name (`str`, `optional`, defaults to `\"default\"`):\n            The name of the adapter to be injected, if not provided, the default adapter name is used (\"default\").\n        low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):\n            Create empty adapter weights on meta device. Useful to speed up the loading process.\n    \"\"\"\n    # tuner_cls = PEFT_TYPE_TO_TUNER_MAPPING[\"LORA\"]\n\n    # By instantiating a peft model we are injecting randomly initialized LoRA layers into the model's modules.\n    peft_model = LoraModel(model, peft_config, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)\n\n    return peft_model.model\n"
  },
  {
    "path": "kt-sft/ktransformers/sft/peft_utils/peft_model.py",
    "content": "# Copyright 2023-present the HuggingFace Inc. team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport collections\nimport copy\nimport inspect\nimport os\nimport warnings\nfrom contextlib import contextmanager, nullcontext\nfrom copy import deepcopy\nfrom dataclasses import dataclass\nfrom typing import Any, Literal, Optional, Union\n\nimport packaging.version\nfrom peft import __version__\nimport torch\nimport transformers\nfrom accelerate import dispatch_model, infer_auto_device_map, init_empty_weights\nfrom accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_submodules\nfrom accelerate.utils import get_balanced_memory, named_module_tensors\nfrom huggingface_hub import HfFileSystem, ModelCard, ModelCardData, hf_hub_download\nfrom safetensors import safe_open\nfrom safetensors.torch import save_file as safe_save_file\nfrom torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\nfrom transformers import Cache, DynamicCache, EncoderDecoderCache, PreTrainedModel\nfrom transformers.modeling_outputs import QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput\nfrom transformers.utils import PushToHubMixin\n\nfrom peft.utils.constants import DUMMY_MODEL_CONFIG, PEFT_TYPE_TO_PREFIX_MAPPING\n\nfrom peft.config import PeftConfig\nfrom .lora_layer import BaseTunerLayer\nfrom peft.utils import (\n    SAFETENSORS_WEIGHTS_NAME,\n    
TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,\n    WEIGHTS_NAME,\n    PeftType,\n    TaskType,\n    _get_batch_size,\n    _prepare_prompt_learning_config,\n    _set_adapter,\n    _set_trainable,\n    get_peft_model_state_dict,\n    id_tensor_storage,\n    infer_device,\n    load_peft_weights,\n    map_cache_to_layer_device_map,\n    set_peft_model_state_dict,\n    shift_tokens_right,\n)\n\nfrom ktransformers.sft.peft_utils.lora_model import LoraModel\n\n\nclass PeftModel(PushToHubMixin, torch.nn.Module):\n    \"\"\"\n    Base model encompassing various Peft methods.\n\n    Args:\n        model ([`~transformers.PreTrainedModel`]): The base transformer model used for Peft.\n        peft_config ([`PeftConfig`]): The configuration of the Peft model.\n        adapter_name (`str`,  *optional*): The name of the adapter, defaults to `\"default\"`.\n        autocast_adapter_dtype (`bool`, *optional*):\n            Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights\n            using float16 and bfloat16 to float32, as this is typically required for stable training, and only affect\n            select PEFT tuners.\n        low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):\n            Create empty adapter weights on meta device. 
Useful to speed up the loading process.\n\n            <Tip>\n\n            Don't use `low_cpu_mem_usage=True` when creating a new PEFT adapter for training.\n\n            </Tip>\n\n    **Attributes**:\n        - **base_model** ([`torch.nn.Module`]) -- The base transformer model used for Peft.\n        - **peft_config** ([`PeftConfig`]) -- The configuration of the Peft model.\n        - **modules_to_save** (`list` of `str`) -- The list of sub-module names to save when\n            saving the model.\n        - **prompt_encoder** ([`PromptEncoder`]) -- The prompt encoder used for Peft if\n            using [`PromptLearningConfig`].\n        - **prompt_tokens** (`torch.Tensor`) -- The virtual prompt tokens used for Peft if\n            using [`PromptLearningConfig`].\n        - **transformer_backbone_name** (`str`) -- The name of the transformer\n            backbone in the base model if using [`PromptLearningConfig`].\n        - **word_embeddings** (`torch.nn.Embedding`) -- The word embeddings of the transformer backbone\n            in the base model if using [`PromptLearningConfig`].\n    \"\"\"\n\n    def __init__(\n        self,\n        model: PreTrainedModel,\n        peft_config: PeftConfig,\n        adapter_name: str = \"default\",\n        autocast_adapter_dtype: bool = True,\n        low_cpu_mem_usage: bool = False,\n    ) -> None:\n        super().__init__()\n        self.modules_to_save = None\n        self.active_adapter = adapter_name\n        self.peft_type = peft_config.peft_type\n        # These args are special PEFT arguments that users can pass. 
They need to be removed before passing them to\n        # forward.\n        self.special_peft_forward_args = {\"adapter_names\"}\n\n        self._is_prompt_learning = peft_config.is_prompt_learning\n        if self._is_prompt_learning:\n            self._peft_config = {adapter_name: peft_config}\n            self.base_model = model\n            self.add_adapter(adapter_name, peft_config, low_cpu_mem_usage=low_cpu_mem_usage)\n        else:\n            self._peft_config = None\n            ctx = init_empty_weights if low_cpu_mem_usage else nullcontext\n            with ctx():\n                self.base_model = LoraModel(model, {adapter_name: peft_config}, adapter_name)\n            self.set_additional_trainable_modules(peft_config, adapter_name)\n\n        if hasattr(self.base_model, \"_cast_adapter_dtype\"):\n            self.base_model._cast_adapter_dtype(\n                adapter_name=adapter_name, autocast_adapter_dtype=autocast_adapter_dtype\n            )\n\n        if getattr(model, \"is_gradient_checkpointing\", True):\n            model = self._prepare_model_for_gradient_checkpointing(model)\n\n        # the `pretraining_tp` is set for some models to simulate Tensor Parallelism during inference to avoid\n        # numerical differences, https://github.com/pytorch/pytorch/issues/76232 - to avoid any unexpected\n        # behavior we disable that in this line.\n        if hasattr(self.base_model, \"config\") and hasattr(self.base_model.config, \"pretraining_tp\"):\n            self.base_model.config.pretraining_tp = 1\n\n    @property\n    def peft_config(self) -> dict[str, PeftConfig]:\n        if self._is_prompt_learning:\n            return self._peft_config\n        return self.base_model.peft_config\n\n    @property\n    def active_adapters(self) -> list[str]:\n        try:\n            adapters = self.base_model.active_adapters\n            if not isinstance(adapters, list):\n                # Base model is probably a transformers model, see:\n          
      # https://github.com/huggingface/transformers/pull/30790#issuecomment-2253808249\n                # Unfortunately, transformers models also have an active_adapters method but it's 1) not a property and\n                # 2) calling it fails because the base model (usually) has no loaded adapter. The base model can be a\n                # transformers model for prompt learning, where the base model is not wrapped in a LoraModel or similar.\n                adapters = self.active_adapter\n                if isinstance(adapters, str):\n                    adapters = [adapters]\n        except AttributeError:\n            adapters = self.active_adapter\n            if isinstance(adapters, str):\n                adapters = [adapters]\n        return adapters\n\n    @peft_config.setter\n    def peft_config(self, value: dict[str, PeftConfig]):\n        if self._is_prompt_learning:\n            self._peft_config = value\n        else:\n            self.base_model.peft_config = value\n\n    def save_pretrained(\n        self,\n        save_directory: str,\n        safe_serialization: bool = True,\n        selected_adapters: Optional[list[str]] = None,\n        save_embedding_layers: Union[str, bool] = \"auto\",\n        is_main_process: bool = True,\n        path_initial_model_for_weight_conversion: Optional[str] = None,\n        **kwargs: Any,\n    ) -> None:\n        r\"\"\"\n        This function saves the adapter model and the adapter configuration files to a directory, so that it can be\n        reloaded using the [`PeftModel.from_pretrained`] class method, and also used by the [`PeftModel.push_to_hub`]\n        method.\n\n        Args:\n            save_directory (`str`):\n                Directory where the adapter model and configuration files will be saved (will be created if it does not\n                exist).\n            safe_serialization (`bool`, *optional*):\n                Whether to save the adapter files in safetensors format, defaults to `True`.\n 
           selected_adapters (`List[str]`,  *optional*):\n                A list of adapters to be saved. If `None`, will default to all adapters.\n            save_embedding_layers (`Union[bool, str]`, *optional*, defaults to `\"auto\"`):\n                If `True`, save the embedding layers in addition to adapter weights. If `auto`, checks the common\n                embedding layers `peft.utils.other.EMBEDDING_LAYER_NAMES` in config's `target_modules` when available\n                and automatically sets the boolean flag. This only works for 🤗 transformers models.\n            is_main_process (`bool`, *optional*):\n                Whether the process calling this is the main process or not. Will default to `True`. Will not save the\n                checkpoint if not on the main process, which is important for multi device setups (e.g. DDP).\n            path_initial_model_for_weight_conversion (`str`, *optional*):\n                The path to the initialized adapter, which is obtained after initializing the model with PiSSA or OLoRA\n                and before performing any training. When `path_initial_model_for_weight_conversion` is not None, the\n                difference in adapter before and after fine-tuning is calculated. This difference can be represented as\n                the parameters of a standard LoRA adapter. Using this converted adapter does not require changes to the\n                base model, thus conveniently allowing the use of multiple PiSSA or OLoRA adapters with LoRA adapters,\n                and the activation or deactivation of any adapters. 
Note that this conversion is not supported if\n                `rslora` is used in combination with `rank_pattern` or `alpha_pattern`.\n            kwargs (additional keyword arguments, *optional*):\n                Additional keyword arguments passed along to the `push_to_hub` method.\n\n        \"\"\"\n        if os.path.isfile(save_directory):\n            raise ValueError(f\"Provided path ({save_directory}) should be a directory, not a file\")\n\n        if selected_adapters is None:\n            selected_adapters = list(self.peft_config.keys())\n        else:\n            if any(\n                selected_adapter_name not in list(self.peft_config.keys())\n                for selected_adapter_name in selected_adapters\n            ):\n                raise ValueError(\n                    f\"You passed an invalid `selected_adapters` arguments, current supported adapter names are\"\n                    f\" {list(self.peft_config.keys())} - got {selected_adapters}.\"\n                )\n\n        def save_mutated_as_lora(peft_config, path_initial_model_for_weight_conversion, output_state_dict, kwargs):\n            if peft_config.use_rslora and (peft_config.rank_pattern or peft_config.alpha_pattern):\n                msg = (\n                    \"Passing `path_initial_model_for_weight_conversion` to `save_pretrained` is not supported when \"\n                    \"using `rank_pattern` or `alpha_pattern` at the same time as `use_rslora=True`.\"\n                )\n                raise ValueError(msg)\n\n            if not any(\n                str(peft_config.init_lora_weights).lower().startswith(prefix) for prefix in [\"pissa\", \"olora\", \"true\"]\n            ):\n                warnings.warn(\n                    \"`path_initial_model_for_weight_conversion` only works for converting a PiSSA or OLoRA adapter to \"\n                    \"a LoRA adapter\"\n                )\n            initial_adapter_name = 
os.path.basename(path_initial_model_for_weight_conversion)\n            try:\n                self.load_adapter(\n                    os.path.dirname(path_initial_model_for_weight_conversion),\n                    subfolder=initial_adapter_name,\n                    adapter_name=initial_adapter_name,\n                )\n                is_pissa = str(self.peft_config[initial_adapter_name].init_lora_weights).lower().startswith(\"pissa\")\n                is_olora = str(self.peft_config[initial_adapter_name].init_lora_weights).lower() == \"olora\"\n                if is_pissa or is_olora:\n                    raise ValueError(\n                        \"The `init_lora_weights` parameter of the initial adapter should be set to `True`. \"\n                        \"Otherwise, `self.load_adapter` will subtract the decomposed values again based on the \"\n                        \"residual model.\"\n                    )\n                output_state_dict = self.base_model.subtract_mutated_init(\n                    output_state_dict, initial_adapter_name, kwargs\n                )\n            finally:\n                self.delete_adapter(initial_adapter_name)\n            return output_state_dict\n\n        if is_main_process:\n            os.makedirs(save_directory, exist_ok=True)\n            self.create_or_update_model_card(save_directory)\n\n        for adapter_name in selected_adapters:\n            peft_config = self.peft_config[adapter_name]\n            # save only the trainable weights\n            output_state_dict = get_peft_model_state_dict(\n                self,\n                state_dict=kwargs.get(\"state_dict\", None),\n                adapter_name=adapter_name,\n                save_embedding_layers=save_embedding_layers,\n            )\n            output_dir = os.path.join(save_directory, adapter_name) if adapter_name != \"default\" else save_directory\n            os.makedirs(output_dir, exist_ok=True)\n\n            if is_main_process and 
safe_serialization:\n                # Section copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L2111-L2134\n                # Safetensors does not allow tensor aliasing.\n                # We're going to remove aliases before saving\n                ptrs = collections.defaultdict(list)\n                for name, tensor in output_state_dict.items():\n                    # Sometimes in the state_dict we have non-tensor objects.\n                    # e.g. in bitsandbytes we have some `str` objects in the state_dict\n                    if isinstance(tensor, torch.Tensor):\n                        ptrs[id_tensor_storage(tensor)].append(name)\n                    else:\n                        # In the non-tensor case, fall back to the pointer of the object itself\n                        ptrs[id(tensor)].append(name)\n\n                # These are all the pointers of shared tensors.\n                shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}\n\n                for _, names in shared_ptrs.items():\n                    # Here we just clone the shared tensors to avoid tensor aliasing which is\n                    # not supported in safetensors.\n                    for shared_tensor_name in names[1:]:\n                        output_state_dict[shared_tensor_name] = output_state_dict[shared_tensor_name].clone()\n                if path_initial_model_for_weight_conversion is not None:\n                    peft_config = copy.deepcopy(peft_config)\n                    peft_config.init_lora_weights = True\n                    peft_config.save_pretrained(path_initial_model_for_weight_conversion)\n                    output_state_dict = save_mutated_as_lora(\n                        peft_config, path_initial_model_for_weight_conversion, output_state_dict, kwargs\n                    )\n                safe_save_file(\n                    output_state_dict,\n                    
os.path.join(output_dir, SAFETENSORS_WEIGHTS_NAME),\n                    metadata={\"format\": \"pt\"},\n                )\n            elif is_main_process:\n                if path_initial_model_for_weight_conversion is not None:\n                    peft_config = copy.deepcopy(peft_config)\n                    peft_config.init_lora_weights = True\n                    peft_config.save_pretrained(path_initial_model_for_weight_conversion)\n                    output_state_dict = save_mutated_as_lora(\n                        peft_config, path_initial_model_for_weight_conversion, output_state_dict, kwargs\n                    )\n                torch.save(output_state_dict, os.path.join(output_dir, WEIGHTS_NAME))\n\n            # save the config and change the inference mode to `True`\n            if peft_config.base_model_name_or_path is None:\n                peft_config.base_model_name_or_path = (\n                    self.base_model.__dict__.get(\"name_or_path\", None)\n                    if peft_config.is_prompt_learning\n                    else self.base_model.model.__dict__.get(\"name_or_path\", None)\n                )\n            inference_mode = peft_config.inference_mode\n            peft_config.inference_mode = True\n\n            if peft_config.task_type is None:\n                # deal with auto mapping\n                base_model_class = self._get_base_model_class(\n                    is_prompt_tuning=peft_config.is_prompt_learning,\n                )\n                parent_library = base_model_class.__module__\n\n                auto_mapping_dict = {\n                    \"base_model_class\": base_model_class.__name__,\n                    \"parent_library\": parent_library,\n                }\n            else:\n                auto_mapping_dict = None\n\n            if is_main_process:\n                if path_initial_model_for_weight_conversion is not None:\n                    peft_config.init_lora_weights = True\n                    
peft_config.r *= 2\n                    if not peft_config.use_rslora:\n                        peft_config.lora_alpha *= 2\n                    else:\n                        # with rslora, we have scaling = alpha / sqrt(r), we thus adjust alpha to keep the same scaling\n                        peft_config.lora_alpha *= 2**0.5\n\n                    if peft_config.rank_pattern:\n                        peft_config.rank_pattern = {key: 2 * val for key, val in peft_config.rank_pattern.items()}\n                    if peft_config.alpha_pattern:\n                        peft_config.alpha_pattern = {key: 2 * val for key, val in peft_config.alpha_pattern.items()}\n\n                peft_config.save_pretrained(output_dir, auto_mapping_dict=auto_mapping_dict)\n            peft_config.inference_mode = inference_mode\n\n    @classmethod\n    def from_pretrained(\n        cls,\n        model: torch.nn.Module,\n        model_id: Union[str, os.PathLike],\n        adapter_name: str = \"default\",\n        is_trainable: bool = False,\n        config: Optional[PeftConfig] = None,\n        autocast_adapter_dtype: bool = True,\n        ephemeral_gpu_offload: bool = False,\n        low_cpu_mem_usage: bool = False,\n        **kwargs: Any,\n    ) -> PeftModel:\n        r\"\"\"\n        Instantiate a PEFT model from a pretrained model and loaded PEFT weights.\n\n        Note that the passed `model` may be modified inplace.\n\n        Args:\n            model ([`torch.nn.Module`]):\n                The model to be adapted. For 🤗 Transformers models, the model should be initialized with the\n                [`~transformers.PreTrainedModel.from_pretrained`].\n            model_id (`str` or `os.PathLike`):\n                The name of the PEFT configuration to use. 
Can be either:\n                    - A string, the `model id` of a PEFT configuration hosted inside a model repo on the Hugging Face\n                      Hub.\n                    - A path to a directory containing a PEFT configuration file saved using the `save_pretrained`\n                      method (`./my_peft_config_directory/`).\n            adapter_name (`str`, *optional*, defaults to `\"default\"`):\n                The name of the adapter to be loaded. This is useful for loading multiple adapters.\n            is_trainable (`bool`, *optional*, defaults to `False`):\n                Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and can only be\n                used for inference.\n            config ([`~peft.PeftConfig`], *optional*):\n                The configuration object to use instead of an automatically loaded configuration. This configuration\n                object is mutually exclusive with `model_id` and `kwargs`. This is useful when configuration is already\n                loaded before calling `from_pretrained`.\n            autocast_adapter_dtype (`bool`, *optional*):\n                Whether to autocast the adapter dtype. Defaults to `True`. Only relevant for specific adapter types.\n            ephemeral_gpu_offload (`bool`, *optional*):\n                Whether to use ephemeral GPU offloading for partially loaded modules. Defaults to `False`. This is\n                useful when parts of the model and/or components (such as adapters) are kept in CPU memory until they\n                are needed. Rather than perform expensive operations on small data, the data is transferred to the GPU\n                on-demand, the operation(s) performed, and the results moved back to CPU memory. 
This brings a slight\n                momentary VRAM overhead but gives orders of magnitude speedup in certain cases.\n            low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):\n                Create empty adapter weights on meta device before loading the saved weights. Useful to speed up the\n                process.\n            torch_device (`str`, *optional*, defaults to None):\n                The device to load the adapter on. If `None`, the device will be inferred.\n            kwargs: (`optional`):\n                Additional keyword arguments passed along to the specific PEFT configuration class.\n        \"\"\"\n        from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING\n\n        # load the config\n        if config is None:\n            config = PEFT_TYPE_TO_CONFIG_MAPPING[\n                PeftConfig._get_peft_type(\n                    model_id,\n                    subfolder=kwargs.get(\"subfolder\", None),\n                    revision=kwargs.get(\"revision\", None),\n                    cache_dir=kwargs.get(\"cache_dir\", None),\n                    use_auth_token=kwargs.get(\"use_auth_token\", None),\n                    token=kwargs.get(\"token\", None),\n                )\n            ].from_pretrained(model_id, **kwargs)\n        elif isinstance(config, PeftConfig):\n            config.inference_mode = not is_trainable\n        else:\n            raise ValueError(f\"The input config must be a PeftConfig, got {config.__class__}\")\n\n        # Runtime configuration, if supported\n        if hasattr(config, \"runtime_config\"):\n            config.runtime_config.ephemeral_gpu_offload = ephemeral_gpu_offload\n        else:\n            if ephemeral_gpu_offload:\n                warnings.warn(\"Ephemeral GPU offloading is not supported for this model. 
Ignoring.\")\n\n        if hasattr(model, \"hf_device_map\"):\n            weight_map = dict(named_module_tensors(model, recurse=True))\n\n            # recreate the offload_index for disk-offloaded modules: we need to know the location in storage of each weight\n            # before the offload hook is removed from the model\n            disk_modules = set()\n            index = None\n            for name, module in model.named_modules():\n                if hasattr(module, \"_hf_hook\") and hasattr(module._hf_hook, \"original_devices\"):\n                    if hasattr(module._hf_hook.weights_map, \"dataset\"):\n                        index = module._hf_hook.weights_map.dataset.index\n                    for key in module._hf_hook.original_devices.keys():\n                        if module._hf_hook.original_devices[key] == torch.device(\"meta\"):\n                            disk_modules.add(str(name) + \".\" + str(key))\n\n            if disk_modules and not kwargs.get(\"use_safetensors\", True):\n                raise ValueError(\"Disk offloading currently only supported for safetensors\")\n\n            if index:\n                offload_index = {\n                    p: {\n                        \"safetensors_file\": index[p][\"safetensors_file\"],\n                        \"weight_name\": p,\n                        \"dtype\": str(weight_map[p].dtype).replace(\"torch.\", \"\"),\n                    }\n                    for p in weight_map.keys()\n                    if p in disk_modules\n                }\n                kwargs[\"offload_index\"] = offload_index\n\n        if (getattr(model, \"hf_device_map\", None) is not None) and len(\n            set(model.hf_device_map.values()).intersection({\"cpu\", \"disk\"})\n        ) > 0:\n            remove_hook_from_submodules(model)\n\n        if config.is_prompt_learning and is_trainable:\n            raise ValueError(\"Cannot set a prompt learning adapter to trainable when loading pretrained 
adapter.\")\n        else:\n            config.inference_mode = not is_trainable\n        if isinstance(getattr(model, \"base_model\", None), XLoraModel):\n            if not isinstance(config, XLoraConfig):\n                raise TypeError(f\"Expected 'XLoraConfig', got '{type(config)}' instead.\")\n            if \"adapters\" in kwargs:\n                config.adapters = kwargs[\"adapters\"]\n            else:\n                # If the path is on HF hub, then we get the adapter names to create a subfolders list which tells\n                # `load_adapter` where the adapters are.\n                if not os.path.exists(model_id):\n                    s = HfFileSystem()\n\n                    # The names of the adapters which must be in folders\n                    adapter_names = [\n                        file[\"name\"][len(model_id) + 1 :] for file in s.ls(model_id) if file[\"type\"] == \"directory\"\n                    ]\n                    # Prepare a dict of adapter paths, which really just point to the hf id; we will use the subfolders\n                    adapter_paths = {}\n                    for adapter_name in adapter_names:\n                        adapter_paths[adapter_name] = os.path.join(model_id, model_id)\n                    config.adapters = adapter_paths\n                    config._subfolders = adapter_names\n                else:\n                    if \"adapters\" not in kwargs:\n                        raise ValueError(\"If model_id is a local path, then `adapters` must be passed in kwargs.\")\n\n        if config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys():\n            model = cls(\n                model,\n                config,\n                adapter_name,\n                autocast_adapter_dtype=autocast_adapter_dtype,\n                low_cpu_mem_usage=low_cpu_mem_usage,\n            )\n        else:\n            model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type](\n                model,\n                
config,\n                adapter_name,\n                autocast_adapter_dtype=autocast_adapter_dtype,\n                low_cpu_mem_usage=low_cpu_mem_usage,\n            )\n\n        load_result = model.load_adapter(\n            model_id,\n            adapter_name,\n            is_trainable=is_trainable,\n            autocast_adapter_dtype=autocast_adapter_dtype,\n            low_cpu_mem_usage=low_cpu_mem_usage,\n            **kwargs,\n        )\n\n        # 1. Remove VB-LoRA vector bank, since it's a shared parameter set via the VBLoRAModel\n        # 2. Remove the prompt encoder, as it does not need to be part of the checkpoint\n        missing_keys = [\n            k for k in load_result.missing_keys if \"vblora_vector_bank\" not in k and \"prompt_encoder\" not in k\n        ]\n        if missing_keys:\n            # Let's warn here since (in contrast to load_adapter) we don't return the load result, so it could be quite\n            # difficult for users to even notice that something might have gone wrong here. 
As we filter out non PEFT\n            # keys from the missing keys, this gives no false positives.\n            warnings.warn(f\"Found missing adapter keys while loading the checkpoint: {missing_keys}\")\n\n        return model\n\n    def _setup_prompt_encoder(self, adapter_name: str):\n        config = self.peft_config[adapter_name]\n        if not hasattr(self, \"prompt_encoder\"):\n            self.prompt_encoder = torch.nn.ModuleDict({})\n            self.prompt_tokens = {}\n        transformer_backbone = None\n        for name, module in self.base_model.named_children():\n            for param in module.parameters():\n                param.requires_grad = False\n            if isinstance(module, PreTrainedModel):\n                # Make sure to freeze Transformers model\n                if transformer_backbone is None:\n                    transformer_backbone = module\n                    self.transformer_backbone_name = name\n        if transformer_backbone is None:\n            transformer_backbone = self.base_model\n\n        if config.num_transformer_submodules is None:\n            config.num_transformer_submodules = 2 if config.task_type == TaskType.SEQ_2_SEQ_LM else 1\n\n        # determine the word embeddings\n        word_embeddings = None\n        try:\n            # First try to find the word embeddings based on the module name, this should work for models like Bert,\n            # Roberta, Deberta, etc.\n            word_embeddings = self.base_model.get_submodule(\"embeddings.word_embeddings\")\n        except AttributeError:\n            pass\n\n        if word_embeddings is None:\n            # Word embeddings could not be determined. 
Next try to guess them by checking which parameter has the size\n            # of the vocab.\n            for named_param, value in list(transformer_backbone.named_parameters()):\n                # for ZeRO-3, the tensor is sharded across accelerators and deepspeed modifies it to a tensor with shape\n                # [0] the actual unsharded shape is stored in \"ds_shape\" attribute special handling is needed in case\n                # the model is initialized in deepspeed.zero.Init() context or HfDeepSpeedConfig has been called before\n                # For reference refer to issue: https://github.com/huggingface/peft/issues/996\n                deepspeed_distributed_tensor_shape = getattr(value, \"ds_shape\", None)\n\n                if value.shape[0] == self.base_model.config.vocab_size or (\n                    deepspeed_distributed_tensor_shape is not None\n                    and deepspeed_distributed_tensor_shape[0] == self.base_model.config.vocab_size\n                ):\n                    word_embeddings = transformer_backbone.get_submodule(named_param.replace(\".weight\", \"\"))\n                    break\n\n        self.word_embeddings = word_embeddings\n\n        if config.peft_type == PeftType.PROMPT_TUNING:\n            prompt_encoder = PromptEmbedding(config, self.word_embeddings)\n        elif config.peft_type == PeftType.MULTITASK_PROMPT_TUNING:\n            prompt_encoder = MultitaskPromptEmbedding(config, self.word_embeddings)\n        elif config.peft_type == PeftType.P_TUNING:\n            prompt_encoder = PromptEncoder(config)\n        elif config.peft_type == PeftType.PREFIX_TUNING:\n            # prefix tuning now uses Cache but that won't work with gradient checkpointing\n            if any(getattr(module, \"gradient_checkpointing\", False) for module in self.get_base_model().modules()):\n                raise ValueError(\"Prefix tuning does not work with gradient checkpointing.\")\n            prompt_encoder = PrefixEncoder(config)\n    
    elif config.peft_type == PeftType.CPT:\n            prompt_encoder = CPTEmbedding(config, self.word_embeddings)\n        else:\n            raise ValueError(\"Not supported\")\n\n        prompt_encoder = prompt_encoder.to(self.device)\n        self.prompt_encoder.update(torch.nn.ModuleDict({adapter_name: prompt_encoder}))\n        self.prompt_tokens[adapter_name] = torch.arange(\n            config.num_virtual_tokens * config.num_transformer_submodules\n        ).long()\n\n    def _prepare_model_for_gradient_checkpointing(self, model: PreTrainedModel):\n        r\"\"\"\n        Prepares the model for gradient checkpointing if necessary\n        \"\"\"\n        if not (\n            getattr(model, \"is_loaded_in_8bit\", False)\n            or getattr(model, \"is_loaded_in_4bit\", False)\n            or getattr(model, \"is_quantized\", False)\n        ):\n            if hasattr(model, \"enable_input_require_grads\"):\n                model.enable_input_require_grads()\n            elif hasattr(model, \"get_input_embeddings\"):\n\n                def make_inputs_require_grad(module, input, output):\n                    output.requires_grad_(True)\n\n                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)\n        return model\n\n    def get_prompt_embedding_to_save(self, adapter_name: str) -> torch.Tensor:\n        \"\"\"\n        Returns the prompt embedding to save when saving the model. 
Only applicable when using a prompt learning\n        method.\n        \"\"\"\n        prompt_encoder = self.prompt_encoder[adapter_name]\n        prompt_tokens = (\n            self.prompt_tokens[adapter_name].unsqueeze(0).expand(1, -1).to(prompt_encoder.embedding.weight.device)\n        )\n        if self.peft_config[adapter_name].peft_type == PeftType.PREFIX_TUNING:\n            prompt_tokens = prompt_tokens[:, : self.peft_config[adapter_name].num_virtual_tokens]\n\n        if self.peft_config[adapter_name].peft_type == PeftType.MULTITASK_PROMPT_TUNING:\n            prompt_embeddings = super(MultitaskPromptEmbedding, prompt_encoder).forward(prompt_tokens)\n        else:\n            prompt_embeddings = prompt_encoder(prompt_tokens)\n\n        return prompt_embeddings[0].detach().cpu()\n\n    def get_prompt(self, batch_size: int, task_ids: Optional[torch.Tensor] = None) -> torch.Tensor:\n        \"\"\"\n        Returns the virtual prompts to use for Peft. Only applicable when using a prompt learning method.\n        \"\"\"\n        peft_config = self.active_peft_config\n        prompt_encoder = self.prompt_encoder[self.active_adapter]\n        prompt_tokens = (\n            self.prompt_tokens[self.active_adapter]\n            .unsqueeze(0)\n            .expand(batch_size, -1)\n            .to(prompt_encoder.embedding.weight.device)\n        )\n        if peft_config.peft_type == PeftType.PREFIX_TUNING:\n            prompt_tokens = prompt_tokens[:, : peft_config.num_virtual_tokens]\n            if peft_config.inference_mode:\n                past_key_values = prompt_encoder.embedding.weight.repeat(batch_size, 1, 1)\n            else:\n                past_key_values = prompt_encoder(prompt_tokens)\n            if self.base_model_torch_dtype is not None:\n                past_key_values = past_key_values.to(self.base_model_torch_dtype)\n            past_key_values = past_key_values.view(\n                batch_size,\n                
peft_config.num_virtual_tokens,\n                peft_config.num_layers * 2,\n                peft_config.num_attention_heads,\n                peft_config.token_dim // peft_config.num_attention_heads,\n            )\n            if peft_config.num_transformer_submodules == 2:\n                past_key_values = torch.cat([past_key_values, past_key_values], dim=2)\n            past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(\n                peft_config.num_transformer_submodules * 2\n            )\n            if TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING.get(self.config.model_type, None) is not None:\n                post_process_fn = TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING[self.config.model_type]\n                past_key_values = post_process_fn(past_key_values)\n            elif peft_config.num_transformer_submodules == 1:\n                # Don't apply this to encoder-decoder models and not to models requiring special processing.\n                # local import in case users use a very old transformers version\n                past_key_values = DynamicCache.from_legacy_cache(past_key_values)\n            elif peft_config.num_transformer_submodules == 2 and self.base_model._supports_cache_class:\n                # Don't apply this to encoder-decoder models that don't support the new Cache format yet\n                # If we don't apply this, prefix-tuning fails to update cross-attn cache\n                past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)\n                past_key_values.cross_attention_cache = DynamicCache()\n                past_key_values.is_updated = {\n                    layer_idx: False for layer_idx in range(len(past_key_values.cross_attention_cache.key_cache))\n                }\n            map_cache_to_layer_device_map(self.get_base_model(), past_key_values)  # no-op if not a Cache instance\n            return past_key_values\n        else:\n            if 
peft_config.peft_type == PeftType.MULTITASK_PROMPT_TUNING:\n                prompts = prompt_encoder(prompt_tokens, task_ids)\n            else:\n                if peft_config.inference_mode:\n                    prompts = prompt_encoder.embedding.weight\n                else:\n                    # Take only one prompt token sample and expand the output instead of expanding the input, see:\n                    # https://github.com/huggingface/peft/issues/2043#issuecomment-2321522577\n                    prompt_tokens = prompt_tokens[:1]\n                    prompts = prompt_encoder(prompt_tokens)\n                prompts = prompts.repeat(batch_size, 1, 1)\n            return prompts\n\n    def get_nb_trainable_parameters(self) -> tuple[int, int]:\n        r\"\"\"\n        Returns the number of trainable parameters and the number of all parameters in the model.\n        \"\"\"\n        trainable_params = 0\n        all_param = 0\n        for _, param in self.named_parameters():\n            num_params = param.numel()\n            # if using DS Zero 3 and the weights are initialized empty\n            if num_params == 0 and hasattr(param, \"ds_numel\"):\n                num_params = param.ds_numel\n\n            # Due to the design of 4bit linear layers from bitsandbytes\n            # one needs to multiply the number of parameters by 2 to get\n            # the correct number of parameters\n            if param.__class__.__name__ == \"Params4bit\":\n                if hasattr(param, \"element_size\"):\n                    num_bytes = param.element_size()\n                elif not hasattr(param, \"quant_storage\"):\n                    num_bytes = 1\n                else:\n                    num_bytes = param.quant_storage.itemsize\n                num_params = num_params * 2 * num_bytes\n\n            all_param += num_params\n            if param.requires_grad:\n                trainable_params += num_params\n\n        return trainable_params, all_param\n\n    
def print_trainable_parameters(self) -> None:\n        \"\"\"\n        Prints the number of trainable parameters in the model.\n\n        Note: print_trainable_parameters() uses get_nb_trainable_parameters() which is different from\n        num_parameters(only_trainable=True) from huggingface/transformers. get_nb_trainable_parameters() returns\n        (trainable parameters, all parameters) of the Peft Model which includes modified backbone transformer model.\n        For techniques like LoRA, the backbone transformer model is modified in place with LoRA modules. However, for\n        prompt tuning, the backbone transformer model is unmodified. num_parameters(only_trainable=True) returns number\n        of trainable parameters of the backbone transformer model which can be different.\n        \"\"\"\n        trainable_params, all_param = self.get_nb_trainable_parameters()\n\n        print(\n            f\"trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param:.4f}\"\n        )\n\n    def __getattr__(self, name: str):\n        \"\"\"Forward missing attributes to the wrapped module.\"\"\"\n        try:\n            return super().__getattr__(name)  # defer to nn.Module's logic\n        except AttributeError:\n            if name == \"base_model\":  # see #1892: prevent infinite recursion if class is not initialized\n                raise\n            return getattr(self.base_model, name)\n\n    @contextmanager\n    def _enable_peft_forward_hooks(self, *args, **kwargs):\n        # If the base model has a method called _enable_peft_forward_hooks, it is invoked as a context. 
Otherwise, this\n        # runs without any changes\n        if hasattr(self.base_model, \"_enable_peft_forward_hooks\"):\n            with self.base_model._enable_peft_forward_hooks(*args, **kwargs):\n                yield\n            return\n        else:\n            # nothing to enable\n            yield\n            return\n\n    def forward(self, *args: Any, **kwargs: Any):\n        \"\"\"\n        Forward pass of the model.\n        \"\"\"\n        with self._enable_peft_forward_hooks(*args, **kwargs):\n            kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}\n            return self.get_base_model()(*args, **kwargs)\n\n    def generate(self, *args, **kwargs):\n        with self._enable_peft_forward_hooks(*args, **kwargs):\n            kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}\n            return self.get_base_model().generate(*args, **kwargs)\n\n    def _get_base_model_class(self, is_prompt_tuning=False):\n        \"\"\"\n        Returns the base model class.\n        \"\"\"\n        if not is_prompt_tuning:\n            return self.base_model.model.__class__\n        return self.base_model.__class__\n\n    @contextmanager\n    def disable_adapter(self):\n        \"\"\"\n        Context manager that disables the adapter module. Use this to run inference on the base model.\n\n        Example:\n\n        ```py\n        >>> with model.disable_adapter():\n        ...     
model(inputs)\n        ```\n        \"\"\"\n        if self.peft_config[self.active_adapter].is_prompt_learning:\n            try:\n                # TODO: consider replacing this patching of methods with a more robust mechanism: setting a flag and\n                # letting the underlying methods deal with it, same as how LoRA does it.\n                old_forward = self.forward\n                self.forward = self.base_model.forward\n                old_prepare_inputs_for_generation = self.prepare_inputs_for_generation\n                self.prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation\n                yield\n            finally:\n                self.forward = old_forward\n                self.prepare_inputs_for_generation = old_prepare_inputs_for_generation\n\n        elif self.peft_config[self.active_adapter].is_adaption_prompt:\n            try:\n                self.base_model.disable_adapter_layers()\n                yield\n            finally:\n                self.base_model.enable_adapter_layers()\n\n        else:  # LoRA, LoHa, etc.\n            model_status = self.get_model_status()\n            if model_status.enabled == \"irregular\":\n                warnings.warn(\n                    \"The model contains some adapter layers that are enabled and others that are disabled. \"\n                    \"This is most likely unintentional. 
After exiting the disable_adapter context, all adapters \"\n                    \"will be enabled\"\n                )\n            try:\n                self.base_model.disable_adapter_layers()\n                yield\n            finally:\n                if model_status.enabled is not False:\n                    # model_status.enabled is `True` or `\"irregular\"`\n                    self.base_model.enable_adapter_layers()\n\n    def get_base_model(self) -> torch.nn.Module:\n        \"\"\"\n        Returns the base model.\n        \"\"\"\n        return (\n            self.base_model\n            if (self.active_peft_config.is_prompt_learning or self.peft_type == PeftType.POLY)\n            else self.base_model.model\n        )\n\n    def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:\n        \"\"\"\n        Add an adapter to the model based on the passed configuration.\n\n        This adapter is not trained. To load a trained adapter, check out [`PeftModel.load_adapter`].\n\n        The name for the new adapter should be unique.\n\n        The new adapter is not automatically set as the active adapter. Use [`PeftModel.set_adapter`] to set the active\n        adapter.\n\n        Args:\n            adapter_name (`str`):\n                The name of the adapter to be added.\n            peft_config ([`PeftConfig`]):\n                The configuration of the adapter to be added.\n            low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):\n                Create empty adapter weights on meta device. Useful to speed up the process when loading saved\n                adapters. Don't use this option when creating a new PEFT adapter for training.\n\n        \"\"\"\n        if peft_config.peft_type != self.peft_type:\n            raise ValueError(\n                f\"Cannot combine adapters with different peft types. 
\"\n                f\"Found {self.peft_type} and {peft_config.peft_type}.\"\n            )\n\n        try:\n            if peft_config.is_prompt_learning:\n                self.peft_config[adapter_name] = peft_config\n                if hasattr(self.config, \"to_dict\"):\n                    dict_config = self.config.to_dict()\n                else:\n                    dict_config = self.config\n\n                peft_config = _prepare_prompt_learning_config(peft_config, dict_config)\n                self._setup_prompt_encoder(adapter_name)\n            elif peft_config.is_adaption_prompt:\n                self.base_model.add_adapter(adapter_name, peft_config)\n            else:\n                self.peft_config[adapter_name] = peft_config\n                self.base_model.inject_adapter(\n                    self.base_model.model, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage\n                )\n        except Exception:  # something went wrong, roll back\n            if adapter_name in self.peft_config:\n                del self.peft_config[adapter_name]\n            raise\n\n        self.set_additional_trainable_modules(peft_config, adapter_name)\n\n    def set_additional_trainable_modules(self, peft_config, adapter_name):\n        if getattr(peft_config, \"modules_to_save\", None) is not None:\n            if self.modules_to_save is None:\n                self.modules_to_save = set(peft_config.modules_to_save)\n            else:\n                self.modules_to_save.update(peft_config.modules_to_save)\n            _set_trainable(self, adapter_name)  # this may add a new ModulesToSaveWrapper\n\n    def get_layer_status(self) -> list[TunerLayerStatus]:\n        \"\"\"Get the status of each adapter layer in the model.\n\n        This method returns a list of `TunerLayerStatus` dataclass instances, each of which contains the following\n        attributes:\n\n        - `name` (`str`):\n           The name of the adapter layer, e.g. 
`model.encoder.block.0.layer.0.SelfAttention.q`.\n        - `module_type` (`str`):\n           The type of the adapter layer, e.g. `lora.Linear`.\n        - `enabled` (`bool`):\n           Whether the adapter layer is enabled.\n        - `active_adapters` (`list[str]`):\n           The names of the active adapters, if any, e.g. `[\"default\"]`.\n        - `merged_adapters` (`list[str]`):\n           The names of the merged adapters, if any, e.g. `[\"default\"]`.\n        - `available_adapters` (`list[str]`):\n           The names of the available adapters, e.g. `[\"default\"]`.\n\n        Args:\n            model ([`~PeftModel`]):\n                The model to get the adapter layer status from.\n\n        Returns:\n            list[`peft.peft_model.TunerLayerStatus`]:\n                A list of dataclasses, each containing the status of the corresponding adapter layer.\n\n        \"\"\"\n        return get_layer_status(self)\n\n    def get_model_status(self) -> TunerModelStatus:\n        \"\"\"Get the status of tuners of the model.\n\n        This method returns a `TunerModelStatus` dataclass instance, which contains the following attributes:\n\n        - `base_model_type` (`str`):\n           The type of the base model, e.g. `T5Model`.\n        - `adapter_model_type` (`str`):\n           The type of the adapter model, e.g. `LoraModel`.\n        - `peft_types` (`dict[str, str]`):\n           The mapping of adapter name to adapter type, e.g. `{\"default\": \"LORA\"}`.\n        - `trainable_params` (`int`):\n           The number of trainable parameters in the model.\n        - `total_params` (`int`):\n           The total number of parameters in the model.\n        - `num_adapter_layers` (`int`):\n           The number of adapter layers in the model.\n        - `enabled` (`bool`, `Literal[\"irregular\"]`):\n           Whether all adapter layers are enabled. 
If some are enabled and some are not, this will be `\"irregular\"`.\n           This means that your model is in an inconsistent state and might not work as expected.\n        - `active_adapters` (`list[str]`, `Literal[\"irregular\"]`):\n           The names of the active adapters. If the active adapters are not consistent across all layers, this will be\n           `\"irregular\"`, which means that your model is in an inconsistent state and might not work as expected.\n        - `merged_adapters` (`list[str]`, `Literal[\"irregular\"]`):\n           The names of the merged adapters. If the merged adapters are not consistent across all layers, this will be\n           `\"irregular\"`, which means that your model is in an inconsistent state and might not work as expected.\n        - `available_adapters` (`list[str]`):\n           The names of the available adapters, e.g. `[\"default\"]`.\n\n        Args:\n            model ([`~PeftModel`]):\n                The model to get the adapter layer status from.\n\n        Returns:\n            `peft.peft_model.TunerModelStatus`:\n                A dataclass containing the status of the model.\n\n        \"\"\"\n        return get_model_status(self)\n\n    @classmethod\n    def _split_kwargs(cls, kwargs: dict[str, Any]):\n        _kwargs_not_in_hf_hub_download_signature = (\"use_auth_token\",)\n        hf_hub_download_kwargs = {}\n        other_kwargs = {}\n\n        for key, value in kwargs.items():\n            if key in inspect.signature(hf_hub_download).parameters or key in _kwargs_not_in_hf_hub_download_signature:\n                hf_hub_download_kwargs[key] = value\n            else:\n                other_kwargs[key] = value\n\n        return hf_hub_download_kwargs, other_kwargs\n\n    def _update_offload(self, offload_index: dict[str, dict[str, str]], adapters_weights: dict[str, torch.tensor]):\n        \"\"\"\n        Update the offload_index and safetensors files for loading and merging PeftModels with 
disk-offloaded modules.\n\n        Args:\n            offload_index (Dict[str: str]):\n                Dictionary of disk-offloaded modules with their metadata and safetensors filenames\n            adapters_weights (Dict[str: torch.tensor]):\n                Dictionary of Peft adapter module names and weights\n        \"\"\"\n\n        if not offload_index:\n            return offload_index\n\n        prefix = \"base_model.model.\"\n        # rename offload index weight and model names\n        adapter_names = list(self.peft_config.keys())\n        for adapter_name in adapter_names:\n            keys = list(offload_index.keys())\n            block_id = keys[0].split(\".\")[0] + \".\"  # for writing safetensors key,\n\n            # replace original offload index keys with PeftModel keys\n            for key in keys:\n                suffix_pos = key.rfind(\".\")\n                extended_prefix = prefix + key[:suffix_pos]\n                module = dict(self.named_modules())[extended_prefix]\n                if isinstance(module, BaseTunerLayer):\n                    new_key = prefix + key[:suffix_pos] + \".base_layer\" + key[suffix_pos:]\n                else:\n                    new_key = prefix + key\n                offload_index[key][\"weight_name\"] = new_key\n                offload_index[new_key] = offload_index[key]\n                del offload_index[key]\n\n            files_seen = set()\n            # rename safetensors for dispatch\n            for new_key in list(offload_index.keys()):\n                fname = offload_index[new_key][\"safetensors_file\"]\n\n                # make a new file name\n                new_fname_list = list(fname.split(os.sep))\n                for i, name in enumerate(new_fname_list):\n                    if \"--\" in name:\n                        new_fname_list[i] += \"-peft\"\n                        break\n                new_fname = os.path.join(*new_fname_list)\n\n                if fname in files_seen:\n              
      continue\n                safe_dict = {}\n                with safe_open(fname, framework=\"pt\") as f:\n                    for safe_key in f.keys():\n                        safe_tensor = f.get_tensor(safe_key)\n                        metadata = f.metadata()\n                        suffix_pos = safe_key.rfind(\".\")\n                        extended_prefix = prefix + block_id + safe_key[:suffix_pos]\n                        safe_module = dict(self.named_modules())[extended_prefix]\n                        if isinstance(safe_module, BaseTunerLayer):\n                            final_key = extended_prefix + \".base_layer\" + safe_key[suffix_pos:]\n                            lora_dict = {key: val for key, val in adapters_weights.items() if extended_prefix in key}\n\n                            # add LoRA keys and values to disk offload\n                            for lora_key, lora_val in lora_dict.items():\n                                divide = lora_key.rfind(\".\")\n                                new_key = lora_key[:divide] + f\".{adapter_name}\" + lora_key[divide:]\n                                safe_dict[new_key] = lora_val\n                        else:\n                            final_key = prefix + block_id + safe_key\n                        safe_dict[final_key] = safe_tensor\n                    files_seen.add(new_fname)\n\n                    # avoid overwriting original safetensors\n                    for key in safe_dict.keys():\n                        offload_index[key] = {\"safetensors_file\": new_fname, \"weight_name\": key}\n\n                    base_name = os.path.dirname(new_fname)\n                    if not os.path.exists(base_name):\n                        os.makedirs(base_name)\n                    safe_save_file(safe_dict, new_fname, metadata=metadata)\n\n    def _check_new_adapter_config(self, peft_config: PeftConfig, is_trainable: bool) -> None:\n        \"\"\"Perform checks on newly added PEFT configs to ensure 
integrity.\"\"\"\n        if peft_config.is_prompt_learning and is_trainable:\n            raise ValueError(\"Cannot set a prompt learning adapter to trainable when loading pretrained adapter.\")\n\n        # Since PiSSA/OLoRA modifies the base weights, it should not be combined with other adapters.\n        all_configs = [peft_config] + list(self.peft_config.values())\n        if len(all_configs) > 1:\n            if any(getattr(config, \"init_lora_weights\", None) == \"pissa\" for config in all_configs):\n                msg = (\n                    \"PiSSA changes the base weights of the model and should thus not be used with other adapters. \"\n                    \"Consider converting the PiSSA adapter into a normal LoRA adapter: \"\n                    \"https://github.com/huggingface/peft/tree/main/examples/pissa_finetuning#convert-pissa-to-lora\"\n                )\n                warnings.warn(msg)\n            elif any(getattr(config, \"init_lora_weights\", None) == \"olora\" for config in all_configs):\n                msg = (\n                    \"OLoRA changes the base weights of the model and should thus not be used with other adapters. \"\n                    \"Consider converting the OLoRA adapter into a normal LoRA adapter: \"\n                    \"https://github.com/huggingface/peft/tree/main/examples/olora_finetuning#olora-and-lora\"\n                )\n                warnings.warn(msg)\n\n    def load_adapter(\n        self,\n        model_id: Union[str, os.PathLike],\n        adapter_name: str,\n        is_trainable: bool = False,\n        torch_device: Optional[str] = None,\n        autocast_adapter_dtype: bool = True,\n        ephemeral_gpu_offload: bool = False,\n        low_cpu_mem_usage: bool = False,\n        **kwargs: Any,\n    ):\n        \"\"\"\n        Load a trained adapter into the model.\n\n        The name for the new adapter should be unique.\n\n        The new adapter is not automatically set as the active adapter. 
Use [`PeftModel.set_adapter`] to set the active\n        adapter.\n\n        Args:\n            model_id (`str` or `os.PathLike`):\n                The name of the PEFT configuration to use. Can be either:\n                    - A string, the `model id` of a PEFT configuration hosted inside a model repo on the Hugging Face\n                      Hub.\n                    - A path to a directory containing a PEFT configuration file saved using the `save_pretrained`\n                      method (`./my_peft_config_directory/`).\n            adapter_name (`str`):\n                The name of the adapter to be added.\n            is_trainable (`bool`, *optional*, defaults to `False`):\n                Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and can only be\n                used for inference.\n            torch_device (`str`, *optional*, defaults to None):\n                The device to load the adapter on. If `None`, the device will be inferred.\n            autocast_adapter_dtype (`bool`, *optional*, defaults to `True`):\n                Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter\n                weights using float16 and bfloat16 to float32, as this is typically required for stable training, and\n                only affect select PEFT tuners.\n            ephemeral_gpu_offload (`bool`, *optional*, defaults to `False`):\n                Whether to use ephemeral GPU offloading for partially loaded modules. Defaults to `False`.\n            low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):\n                Create empty adapter weights on meta device before loading the saved weights. Useful to speed up the\n                process.\n            kwargs: (`optional`):\n                Additional arguments to modify the way the adapter is loaded, e.g. 
the token for Hugging Face Hub.\n        \"\"\"\n        from .mapping import PEFT_TYPE_TO_CONFIG_MAPPING\n\n        hf_hub_download_kwargs, kwargs = self._split_kwargs(kwargs)\n        if torch_device is None:\n            torch_device = infer_device()\n\n        if adapter_name not in self.peft_config:\n            # load the config\n            peft_config = PEFT_TYPE_TO_CONFIG_MAPPING[\n                PeftConfig._get_peft_type(\n                    model_id,\n                    **hf_hub_download_kwargs,\n                )\n            ].from_pretrained(\n                model_id,\n                ephemeral_gpu_offload=ephemeral_gpu_offload,\n                **hf_hub_download_kwargs,\n            )\n            self._check_new_adapter_config(peft_config, is_trainable=is_trainable)\n            peft_config.inference_mode = not is_trainable\n            self.add_adapter(adapter_name, peft_config, low_cpu_mem_usage=low_cpu_mem_usage)\n\n        adapters_weights = load_peft_weights(model_id, device=torch_device, **hf_hub_download_kwargs)\n\n        # load the weights into the model\n        ignore_mismatched_sizes = kwargs.get(\"ignore_mismatched_sizes\", False)\n        load_result = set_peft_model_state_dict(\n            self,\n            adapters_weights,\n            adapter_name=adapter_name,\n            ignore_mismatched_sizes=ignore_mismatched_sizes,\n            low_cpu_mem_usage=low_cpu_mem_usage,\n        )\n\n        tuner = self.peft_config[adapter_name].peft_type\n        tuner_prefix = PEFT_TYPE_TO_PREFIX_MAPPING.get(tuner, \"\")\n        adapter_missing_keys = []\n\n        # Filter missing keys specific to the current adapter and tuner prefix.\n        for key in load_result.missing_keys:\n            if tuner_prefix in key and adapter_name in key:\n                adapter_missing_keys.append(key)\n\n        load_result.missing_keys.clear()\n        load_result.missing_keys.extend(adapter_missing_keys)\n\n        if (\n            (getattr(self, 
\"hf_device_map\", None) is not None)\n            and (len(set(self.hf_device_map.values()).intersection({\"cpu\", \"disk\"})) > 0)\n            and len(self.peft_config) == 1\n        ):\n            device_map = kwargs.get(\"device_map\", \"auto\")\n            max_memory = kwargs.get(\"max_memory\", None)\n            offload_dir = kwargs.get(\"offload_folder\", None)\n            offload_index = kwargs.get(\"offload_index\", None)\n\n            dispatch_model_kwargs = {}\n            # Safety checker for previous `accelerate` versions\n            # `offload_index` was introduced in https://github.com/huggingface/accelerate/pull/873/\n            if \"offload_index\" in inspect.signature(dispatch_model).parameters:\n                dispatch_model_kwargs[\"offload_index\"] = offload_index\n\n            no_split_module_classes = self._no_split_modules\n\n            if device_map != \"sequential\":\n                max_memory = get_balanced_memory(\n                    self,\n                    max_memory=max_memory,\n                    no_split_module_classes=no_split_module_classes,\n                    low_zero=(device_map == \"balanced_low_0\"),\n                )\n\n            if isinstance(device_map, str):\n                device_map = infer_auto_device_map(\n                    self, max_memory=max_memory, no_split_module_classes=no_split_module_classes\n                )\n\n            self._update_offload(offload_index, adapters_weights)\n            dispatch_model_kwargs[\"offload_index\"] = offload_index\n\n            dispatch_model(\n                self,\n                device_map=device_map,\n                offload_dir=offload_dir,\n                **dispatch_model_kwargs,\n            )\n\n            hook = AlignDevicesHook(io_same_device=True)\n            if self.peft_config[adapter_name].is_prompt_learning:\n                remove_hook_from_submodules(self.prompt_encoder)\n            add_hook_to_module(self.get_base_model(), 
hook)\n\n        if hasattr(self.base_model, \"_cast_adapter_dtype\"):\n            self.base_model._cast_adapter_dtype(\n                adapter_name=adapter_name, autocast_adapter_dtype=autocast_adapter_dtype\n            )\n\n        # Set model in evaluation mode to deactivate Dropout modules by default\n        if not is_trainable:\n            self.eval()\n        return load_result\n\n    def set_adapter(self, adapter_name: str) -> None:\n        \"\"\"\n        Sets the active adapter.\n\n        Only one adapter can be active at a time.\n\n        Additionally, this function will set the specified adapter to trainable (i.e., requires_grad=True). If this is\n        not desired, use the following code.\n\n        ```py\n        >>> for name, param in model_peft.named_parameters():\n        ...     if ...:  # some check on name (e.g. if 'lora' in name)\n        ...         param.requires_grad = False\n        ```\n\n        Args:\n            adapter_name (`str`):\n                The name of the adapter to be set as active. The adapter must be loaded first.\n        \"\"\"\n        if adapter_name not in self.peft_config:\n            raise ValueError(f\"Adapter {adapter_name} not found.\")\n        self.active_adapter = adapter_name\n        if not self.peft_config[adapter_name].is_prompt_learning:\n            self.base_model.set_adapter(adapter_name)\n        _set_adapter(self, adapter_name)\n\n    @property\n    def base_model_torch_dtype(self):\n        return getattr(self.base_model, \"dtype\", None)\n\n    @property\n    def active_peft_config(self):\n        return self.peft_config[self.active_adapter]\n\n    def create_or_update_model_card(self, output_dir: str):\n        \"\"\"\n        Updates or creates the model card to include information about peft:\n        1. Adds `peft` library tag\n        2. Adds peft version\n        3. Adds base model info\n        4. 
Adds quantization information if it was used\n        \"\"\"\n\n        filename = os.path.join(output_dir, \"README.md\")\n\n        card = ModelCard.load(filename) if os.path.exists(filename) else ModelCard.from_template(ModelCardData())\n\n        card.data[\"library_name\"] = \"peft\"\n\n        model_config = getattr(self, \"config\", DUMMY_MODEL_CONFIG)\n        if hasattr(model_config, \"to_dict\"):\n            model_config = model_config.to_dict()\n            \n        model_config = None if model_config == DUMMY_MODEL_CONFIG else model_config\n        if model_config is not None and \"_name_or_path\" in model_config:\n            card.data[\"base_model\"] = model_config[\"_name_or_path\"]\n\n        lines = card.text.splitlines()\n\n        quantization_config = None\n        if hasattr(model_config, \"quantization_config\"):\n            quantization_config = self.config.quantization_config.to_dict()\n        training_config_text = \"\"\n        quantization_prefix = \"The following `bitsandbytes` quantization config was used during training:\"\n        # Adds quantization information if it was used\n        if quantization_config is not None:\n            training_config_text += f\"\\n{quantization_prefix}\\n\"\n            training_config_text += \"\\n\".join([f\"- {name}: {value}\" for name, value in quantization_config.items()])\n            training_config_text += \"\\n\"\n\n        training_procedure_heading = \"## Training procedure\"\n        if quantization_prefix not in lines and bool(training_config_text):\n            if training_procedure_heading in lines:\n                lines.insert(lines.index(training_procedure_heading) + 2, training_config_text)\n            else:\n                lines.append(f\"{training_procedure_heading}\\n{training_config_text}\")\n\n        # Adds peft version\n        framework_block_heading = \"### Framework versions\"\n        if f\"- PEFT {__version__}\" not in lines:\n            if framework_block_heading 
in lines:\n                lines.insert(lines.index(framework_block_heading) + 2, f\"- PEFT {__version__}\")\n            else:\n                lines.append(f\"{framework_block_heading}\\n\\n- PEFT {__version__}\")\n\n        card.text = \"\\n\".join(lines)\n        card.save(filename)\n\nclass PeftModelForCausalLM(PeftModel):\n    \"\"\"\n    Peft model for causal language modeling.\n\n    Args:\n        model ([`~transformers.PreTrainedModel`]): Base transformer model.\n        peft_config ([`PeftConfig`]): Peft config.\n        adapter_name (`str`,  *optional*): The name of the adapter, defaults to `\"default\"`.\n        autocast_adapter_dtype (`bool`, *optional*):\n            Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter weights\n            using float16 and bfloat16 to float32, as this is typically required for stable training, and only affect\n            select PEFT tuners.\n\n    Example:\n\n        ```py\n        >>> from transformers import AutoModelForCausalLM\n        >>> from peft import PeftModelForCausalLM, get_peft_config\n\n        >>> config = {\n        ...     \"peft_type\": \"PREFIX_TUNING\",\n        ...     \"task_type\": \"CAUSAL_LM\",\n        ...     \"inference_mode\": False,\n        ...     \"num_virtual_tokens\": 20,\n        ...     \"token_dim\": 1280,\n        ...     \"num_transformer_submodules\": 1,\n        ...     \"num_attention_heads\": 20,\n        ...     \"num_layers\": 36,\n        ...     \"encoder_hidden_size\": 1280,\n        ...     \"prefix_projection\": False,\n        ...     \"postprocess_past_key_value_function\": None,\n        ... 
}\n\n        >>> peft_config = get_peft_config(config)\n        >>> model = AutoModelForCausalLM.from_pretrained(\"gpt2-large\")\n        >>> peft_model = PeftModelForCausalLM(model, peft_config)\n        >>> peft_model.print_trainable_parameters()\n        trainable params: 1843200 || all params: 775873280 || trainable%: 0.23756456724479544\n        ```\n    \"\"\"\n\n    def __init__(\n        self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = \"default\", **kwargs\n    ) -> None:\n        super().__init__(model, peft_config, adapter_name, **kwargs)\n        self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation\n\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        inputs_embeds=None,\n        labels=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n        task_ids=None,\n        **kwargs,\n    ):\n        peft_config = self.active_peft_config\n        if not peft_config.is_prompt_learning:\n            if self.base_model.config.model_type == \"mpt\":\n                if inputs_embeds is not None:\n                    raise AssertionError(\"forward in MPTForCausalLM does not support inputs_embeds\")\n                return self.base_model(\n                    input_ids=input_ids,\n                    attention_mask=attention_mask,\n                    labels=labels,\n                    output_attentions=output_attentions,\n                    output_hidden_states=output_hidden_states,\n                    return_dict=return_dict,\n                    **kwargs,\n                )\n\n            if peft_config.peft_type == PeftType.POLY:\n                kwargs[\"task_ids\"] = task_ids\n\n            with self._enable_peft_forward_hooks(**kwargs):\n                kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}\n                kwargs.pop(\"num_items_in_batch\", 
None)\n                if isinstance(self.base_model, LoraModel):\n                    return self.base_model.model(\n                        input_ids=input_ids,\n                        attention_mask=attention_mask,\n                        inputs_embeds=inputs_embeds,\n                        labels=labels,\n                        output_attentions=output_attentions,\n                        output_hidden_states=output_hidden_states,\n                        return_dict=return_dict,\n                        **kwargs,\n                    )\n                return self.base_model(\n                    input_ids=input_ids,\n                    attention_mask=attention_mask,\n                    inputs_embeds=inputs_embeds,\n                    labels=labels,\n                    output_attentions=output_attentions,\n                    output_hidden_states=output_hidden_states,\n                    return_dict=return_dict,\n                    **kwargs,\n                )\n\n        batch_size = _get_batch_size(input_ids, inputs_embeds)\n        if attention_mask is not None:\n            # concat prompt attention mask\n            prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(attention_mask.device)\n            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)\n\n        if kwargs.get(\"position_ids\", None) is not None:\n            warnings.warn(\"Position ids are not supported for parameter efficient tuning. Ignoring position ids.\")\n            kwargs[\"position_ids\"] = None\n        if kwargs.get(\"token_type_ids\", None) is not None:\n            warnings.warn(\"Token type ids are not supported for parameter efficient tuning. 
Ignoring token type ids\")\n            kwargs[\"token_type_ids\"] = None\n        kwargs.update(\n            {\n                \"attention_mask\": attention_mask,\n                \"labels\": labels,\n                \"output_attentions\": output_attentions,\n                \"output_hidden_states\": output_hidden_states,\n                \"return_dict\": return_dict,\n            }\n        )\n\n        if peft_config.peft_type == PeftType.PREFIX_TUNING:\n            # overwrite past_kv in kwargs\n            kwargs[\"past_key_values\"] = self.get_prompt(batch_size)\n            return self.base_model(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)\n        elif peft_config.peft_type == PeftType.CPT:\n            return self._cpt_forward(input_ids, inputs_embeds, peft_config, task_ids, batch_size, **kwargs)\n        else:\n            if inputs_embeds is None:\n                inputs_embeds = self.word_embeddings(input_ids)\n            # concat prompt labels\n            if labels is not None:\n                prefix_labels = torch.full((batch_size, peft_config.num_virtual_tokens), -100).to(labels.device)\n                kwargs[\"labels\"] = torch.cat((prefix_labels, labels), dim=1)\n            prompts = self.get_prompt(batch_size=batch_size, task_ids=task_ids)\n            prompts = prompts.to(inputs_embeds.dtype)\n            inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)\n            return self.base_model(inputs_embeds=inputs_embeds, **kwargs)\n\n    def _cpt_forward(\n        self, input_ids=None, inputs_embeds=None, peft_config=None, task_ids=None, batch_size=None, **kwargs\n    ):\n        # Extract labels from kwargs\n        labels = kwargs.pop(\"labels\")\n        device = [i.device for i in [input_ids, inputs_embeds, labels] if i is not None][0]\n        # Extract input_type_mask from kwargs and move it to the same device as labels\n        if \"input_type_mask\" in kwargs.keys():\n            input_type_mask = 
kwargs.pop(\"input_type_mask\").to(device)\n        else:\n            if input_ids is None:\n                N_tokens = inputs_embeds.shape[1]\n            else:\n                N_tokens = input_ids.shape[1]\n            input_type_mask = torch.ones((batch_size, N_tokens)).to(device) * 4\n\n        cpt_token_ids = peft_config.cpt_token_ids\n        cpt_tokens_type_mask = peft_config.cpt_tokens_type_mask\n\n        # Generate embeddings if not provided\n        if inputs_embeds is None:\n            inputs_embeds = self.word_embeddings(input_ids)\n        # Get prompt and concatenate with input embeddings\n        prompts = self.get_prompt(batch_size=batch_size, task_ids=task_ids)\n        prompts = prompts.to(inputs_embeds.dtype)\n        inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)\n        # If labels are provided, generate prefix labels and type mask\n        cpt_labels = None\n        if labels is not None:\n            # Generate prefix labels and concatenate with the input labels\n            prefix_labels = torch.Tensor(cpt_token_ids).long().view(1, -1)\n            prefix_labels = prefix_labels.repeat(batch_size, 1).to(labels.device)\n            cpt_labels = torch.cat((prefix_labels, labels), dim=1)\n            # Generate prefix type mask and shift input type mask values to avoid conflicts\n            prefix_type_mask = torch.Tensor(cpt_tokens_type_mask).long().view(1, -1)\n            prefix_type_mask = prefix_type_mask.repeat(batch_size, 1).to(labels.device)\n            adjusted_input_type_mask = input_type_mask\n            adjusted_input_type_mask[adjusted_input_type_mask > 0] += prefix_type_mask.max()\n            # Concatenate prefix and shifted input type masks\n            cpt_type_mask = torch.cat((prefix_type_mask, adjusted_input_type_mask), dim=1)\n            # Identify valid label positions and mask invalid ones with -100\n            labels_idx = (cpt_type_mask > 0) & (cpt_type_mask % 4 == 0)\n            
cpt_labels[~labels_idx] = -100\n            # Update kwargs with the modified labels\n\n        kwargs[\"labels\"] = cpt_labels\n        # Pass the modified inputs to the base model\n        base_model_output = self.base_model(inputs_embeds=inputs_embeds, **kwargs)\n        if labels is None:\n            return base_model_output\n        else:\n            # Calculate the loss using the custom CPT loss function\n            base_model_output = CPTEmbedding.calculate_loss(\n                base_model_output, cpt_labels, cpt_type_mask, self.peft_config[\"default\"]\n            )\n            return base_model_output\n\n    def generate(self, *args, **kwargs):\n        peft_config = self.active_peft_config\n        self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation\n        if hasattr(self.base_model, \"model\"):\n            self.base_model.model.generation_config = self.generation_config\n        else:\n            self.base_model.generation_config = self.generation_config\n        try:\n            if not peft_config.is_prompt_learning:\n                with self._enable_peft_forward_hooks(*args, **kwargs):\n                    kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}\n                    outputs = self.base_model.generate(*args, **kwargs)\n            else:\n                outputs = self.base_model.generate(**kwargs)\n        except:\n            self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation\n            raise\n        else:\n            self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation\n            return outputs\n\n    def prepare_inputs_for_generation(self, *args, task_ids: Optional[torch.Tensor] = None, **kwargs):\n        peft_config = self.active_peft_config\n        model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)\n\n        # 
https://github.com/huggingface/transformers/pull/26681/ introduced a new cache format\n        # for some architectures that requires a special fix for prompt tuning etc.\n        # TODO: starting with transformers 4.38, all architectures should support caching.\n        uses_transformers_4_38 = packaging.version.parse(transformers.__version__) >= packaging.version.parse(\"4.38.0\")\n        uses_transformers_4_36 = packaging.version.parse(transformers.__version__) >= packaging.version.parse(\"4.36.0\")\n        transformers_new_cache_archs = [\"llama\", \"mistral\", \"persimmon\", \"phi\"]\n        if packaging.version.parse(transformers.__version__) > packaging.version.parse(\"4.43.3\"):\n            # https://github.com/huggingface/transformers/pull/31445\n            transformers_new_cache_archs.append(\"bloom\")\n\n        uses_cache = uses_transformers_4_38 or (\n            uses_transformers_4_36 and self.base_model.config.model_type in transformers_new_cache_archs\n        )\n\n        if peft_config.peft_type == PeftType.POLY:\n            model_kwargs[\"task_ids\"] = task_ids\n        if peft_config.is_prompt_learning:\n            if uses_cache and (model_kwargs.get(\"past_key_values\", None) is not None):\n                # A change in the logic of `prepare_inputs_for_generation` makes the code below necessary.\n                # In prompt learning methods, past key values are longer than the `input_ids`.\n                # As such, only consider the last input ids in the autoregressive generation phase.\n                past_key_values = model_kwargs[\"past_key_values\"]\n                if isinstance(past_key_values, (tuple, list)):\n                    seq_len = past_key_values[0][0].shape[-2]\n                else:  # using transformers kv cache\n                    seq_len = past_key_values.get_seq_length()\n                if seq_len >= model_kwargs[\"input_ids\"].shape[1]:\n                    model_kwargs[\"input_ids\"] = 
model_kwargs[\"input_ids\"][:, -1:]\n\n            if model_kwargs.get(\"attention_mask\", None) is not None:\n                size = model_kwargs[\"input_ids\"].shape[0], peft_config.num_virtual_tokens\n                prefix_attention_mask = torch.ones(size).to(model_kwargs[\"input_ids\"].device)\n                model_kwargs[\"attention_mask\"] = torch.cat(\n                    (prefix_attention_mask, model_kwargs[\"attention_mask\"]), dim=1\n                )\n\n            if model_kwargs.get(\"position_ids\", None) is not None:\n                warnings.warn(\"Position ids are not supported for parameter efficient tuning. Ignoring position ids.\")\n                model_kwargs[\"position_ids\"] = None\n\n            if kwargs.get(\"token_type_ids\", None) is not None:\n                warnings.warn(\n                    \"Token type ids are not supported for parameter efficient tuning. Ignoring token type ids\"\n                )\n                kwargs[\"token_type_ids\"] = None\n\n            # no past_key_values or past_key_values empty cache\n            requires_prompt_injection = (model_kwargs.get(\"past_key_values\", None) is None) or (\n                isinstance(model_kwargs[\"past_key_values\"], transformers.Cache)\n                and not model_kwargs[\"past_key_values\"].get_seq_length()\n            )\n\n            if requires_prompt_injection and peft_config.peft_type == PeftType.PREFIX_TUNING:\n                new_past_key_values = self.get_prompt(batch_size=model_kwargs[\"input_ids\"].shape[0])\n                model_kwargs[\"past_key_values\"] = new_past_key_values\n            elif requires_prompt_injection:\n                inputs_embeds = self.word_embeddings(model_kwargs[\"input_ids\"])\n                prompts = self.get_prompt(batch_size=model_kwargs[\"input_ids\"].shape[0], task_ids=task_ids)\n                prompts = prompts.to(inputs_embeds.dtype)\n                model_kwargs[\"inputs_embeds\"] = torch.cat((prompts, 
inputs_embeds), dim=1)\n                model_kwargs[\"input_ids\"] = None\n\n        # For transformers>=4.38.0 - for some architectures such as Llama, `cache_position` is\n        # passed in the forward pass to keep track of the position ids of the cache. We have to\n        # pop that from `model_kwargs` as `cache_position` is properly created by the model, using the passed\n        # `inputs_embeds`: https://github.com/huggingface/transformers/blob/593230f0a1150ea9c0477b9d859f25daf73c8c33/src/transformers/models/llama/modeling_llama.py#L956\n        _ = model_kwargs.pop(\"cache_position\", None)\n\n        return model_kwargs\n\n@dataclass\nclass TunerLayerStatus:\n    name: str\n    module_type: str\n    enabled: bool\n    active_adapters: list[str]\n    merged_adapters: list[str]\n    requires_grad: dict[str, bool | Literal[\"irregular\"]]\n    available_adapters: list[str]\n    devices: dict[str, list[str]]\n\n\ndef get_layer_status(model: torch.nn.Module) -> list[TunerLayerStatus]:\n    \"\"\"Get the status of each adapter layer in the model.\n\n    This function returns a list of `TunerLayerStatus` dataclass instances, each of which contains the following\n    attributes:\n\n    - `name` (`str`):\n       The name of the adapter layer, e.g. `model.encoder.block.0.layer.0.SelfAttention.q`.\n    - `module_type` (`str`):\n       The type of the adapter layer, e.g. `lora.Linear`.\n    - `enabled` (`bool`):\n       Whether the adapter layer is enabled.\n    - `active_adapters` (`list[str]`):\n       The names of the active adapters, if any, e.g. `[\"default\"]`.\n    - `merged_adapters` (`list[str]`):\n       The names of the merged adapters, if any, e.g. `[\"default\"]`.\n    - requires_grad : dict[str, bool | Literal[\"irregular\"]]\n       The requires_grad status of the parameters for each adapter module. Ideally, it should be either `True` or\n       `False`. 
If the requires_grad status is not consistent across all parameters, the value will be set to\n       `\"irregular\"`.\n    - `available_adapters` (`list[str]`):\n       The names of the available adapters, e.g. `[\"default\"]`.\n    - `devices` (`dict[str, list[str]]`):\n       The devices where the parameters of the given adapter are stored, e.g. `[\"cuda\"]`.\n\n    Args:\n        model ([Union[`~PeftModel`, `~transformers.PreTrainedModel`, `nn.Module`]]):\n            The model to get the adapter layer status from.\n\n    Returns:\n        list[`peft.peft_model.TunerLayerStatus`]:\n            A list of dataclasses, each containing the status of the corresponding adapter layer.\n\n    \"\"\"\n    if isinstance(model, PeftModel):\n        base_model = model.base_model\n    else:\n        base_model = model\n\n    layer_status: list[TunerLayerStatus] = []\n    for name, module in base_model.named_modules():\n        if not isinstance(module, BaseTunerLayer):\n            continue\n\n        # determine if all submodules/parameters if this module require grad or not\n        mapping_requires_grad_list: dict[str, list[bool]] = collections.defaultdict(list)\n        for adapter_module_name in module.adapter_layer_names:\n            adapter_module = getattr(module, adapter_module_name)\n            if isinstance(adapter_module, torch.nn.ModuleDict):\n                for key, submodule in adapter_module.items():\n                    for param in submodule.parameters():\n                        mapping_requires_grad_list[key].append(param.requires_grad)\n            elif isinstance(adapter_module, torch.nn.ParameterDict):\n                for key, param in adapter_module.items():\n                    mapping_requires_grad_list[key].append(param.requires_grad)\n            else:\n                # strange, we don't know how to handle this, ignore for now\n                pass\n\n        def check_irrgular(vals: list[bool]) -> bool | Literal[\"irregular\"]:\n            
if all(vals):\n                return True\n            if not any(vals):\n                return False\n            return \"irregular\"\n\n        requires_grad = {key: check_irrgular(vals) for key, vals in mapping_requires_grad_list.items()}\n\n        devices_dd = collections.defaultdict(list)\n        for adapter_module_name in module.adapter_layer_names + module.other_param_names:\n            adapter_module = getattr(module, adapter_module_name)\n            if isinstance(adapter_module, torch.nn.ModuleDict):\n                for key, submodule in adapter_module.items():\n                    devices_dd[key].extend([param.device.type for param in submodule.parameters()])\n            elif isinstance(adapter_module, torch.nn.ParameterDict) or (\n                adapter_module.__class__.__name__ == \"BufferDict\"\n            ):  # VeRA\n                for key, param in adapter_module.items():\n                    devices_dd[key].append(param.device.type)\n        devices = {key: sorted(set(val)) for key, val in devices_dd.items()}\n\n        status = TunerLayerStatus(\n            name=name,\n            module_type=repr(module).partition(\"(\")[0],\n            enabled=not module.disable_adapters,\n            active_adapters=module.active_adapters,\n            merged_adapters=module.merged_adapters,\n            requires_grad=requires_grad,\n            available_adapters=sorted(module._get_available_adapters()),\n            devices=devices,\n        )\n        layer_status.append(status)\n\n    if not layer_status:\n        raise ValueError(\n            \"No adapter layers found in the model, please ensure that it's a PEFT model or that you have PEFT adapters \"\n            \"injected in the model.\"\n        )\n\n    return layer_status\n\n\n@dataclass\nclass TunerModelStatus:\n    base_model_type: str\n    adapter_model_type: str\n    peft_types: dict[str, str]\n    trainable_params: int\n    total_params: int\n    num_adapter_layers: int\n    
enabled: bool | Literal[\"irregular\"]\n    active_adapters: list[str] | Literal[\"irregular\"]\n    merged_adapters: list[str] | Literal[\"irregular\"]\n    requires_grad: dict[str, bool | Literal[\"irregular\"]]\n    available_adapters: list[str]\n    devices: dict[str, list[str]]\n\n\ndef get_model_status(model: torch.nn.Module) -> TunerModelStatus:\n    \"\"\"Get the status of tuners of the model.\n\n    This function returns a `TunerModelStatus` dataclass instance, which contains the following attributes:\n\n    - `base_model_type` (`str`):\n       The type of the base model, e.g. `T5Model`.\n    - `adapter_model_type` (`str`):\n       The type of the adapter model, e.g. `LoraModel`.\n    - `peft_types` (`dict[str, str]`):\n       The mapping of adapter name to adapter type, e.g. `{\"default\": \"LORA\"}`.\n    - `trainable_params` (`int`):\n       The number of trainable parameters in the model.\n    - `total_params` (`int`):\n       The total number of parameters in the model.\n    - `num_adapter_layers` (`int`):\n       The number of adapter layers in the model.\n    - `enabled` (`bool`, `Literal[\"irregular\"]`):\n       Whether all adapter layers are enabled. If some are enabled and some are not, this will be `\"irregular\"`. This\n       means that your model is in an inconsistent state and might not work as expected.\n    - `active_adapters` (`list[str]`, `Literal[\"irregular\"]`):\n       The names of the active adapters. If the active adapters are not consistent across all layers, this will be\n       `\"irregular\"`, which means that your model is in an inconsistent state and might not work as expected.\n    - `merged_adapters` (`list[str]`, `Literal[\"irregular\"]`):\n       The names of the merged adapters. 
If the merged adapters are not consistent across all layers, this will be\n       `\"irregular\"`, which means that your model is in an inconsistent state and might not work as expected.\n    - `requires_grad` (`dict[str, bool | Literal[\"irregular\"]]`):\n       Whether for the given adapter, all adapter layers have `requires_grad` set to `True` or `False`. If there is a\n       mix, this will be set to `\"irregular\"`, which means that your model is in an inconsistent state and might not\n       work as expected.\n    - `available_adapters` (`list[str]`):\n       The names of the available adapters, e.g. `[\"default\"]`.\n    - `devices` (`dict[str, list[str]]`):\n       The devices where the parameters of the given adapter are stored, e.g. `[\"cuda\"]`.\n\n    Args:\n        model ([Union[`~PeftModel`, `~transformers.PreTrainedModel`, `nn.Module`]]):\n            The model to get the adapter layer status from.\n\n    Returns:\n        `peft.peft_model.TunerModelStatus`:\n            A dataclass containing the status of the model.\n\n    \"\"\"\n    if isinstance(model, PeftModel):\n        base_model_type = model.get_base_model().__class__.__name__\n        trainable_params, total_params = model.get_nb_trainable_parameters()\n        base_model = model.base_model\n        peft_types = {key: str(config.peft_type).partition(\".\")[-1] for key, config in base_model.peft_config.items()}\n        adapter_model_type = base_model.__class__.__name__\n    elif isinstance(model, PreTrainedModel):\n        base_model_type = model.__class__.__name__\n        trainable_params, total_params = PeftModel.get_nb_trainable_parameters(model)\n        base_model = model\n        peft_types = {}\n        adapter_model_type = \"None\"\n    else:\n        base_model_type = \"other\"\n        trainable_params, total_params = PeftModel.get_nb_trainable_parameters(model)\n        base_model = model\n        peft_types = {}\n        adapter_model_type = \"None\"\n\n    layer_status = 
get_layer_status(model)\n    num_adapter_layers = len(layer_status)\n\n    enabled_set: set[bool] = {status.enabled for status in layer_status}  # must be {True}, {False}, or {True, False}\n    enabled: bool | Literal[\"irregular\"]\n    if len(enabled_set) == 1:\n        enabled = enabled_set.pop()\n    else:\n        enabled = \"irregular\"\n\n    available_adapters: list[str] = sorted(set().union(*(status.available_adapters for status in layer_status)))\n\n    # ideally, active adapters should be consistent across all layers of the model, but we cannot guarantee it\n    all_active_adapters: set[tuple[str, ...]] = {tuple(status.active_adapters) for status in layer_status}\n    active_adapters: list[str] | Literal[\"irregular\"]\n    if not all_active_adapters:\n        active_adapters = []\n    elif len(all_active_adapters) == 1:\n        active_adapters = list(all_active_adapters.pop())\n    else:\n        active_adapters = \"irregular\"\n\n    # Here we determine what adapters are merged. This is not trivial because multiple adapters can be merged or not at\n    # the same time. 
Some layers may only have adapter A, some only adapter B, so it's not as easy as just checking\n    # which adapters are merged on each layer.\n\n    # First, determine all adapters that are merged on at least one module.\n    merged_all: set[str] = set()\n    for status in layer_status:\n        merged_all.update(status.merged_adapters)\n\n    # Next, check if on any layer, one of these adapters is not merged.\n    merged_adapters: list[str] | Literal[\"irregular\"] = sorted(merged_all)\n    for status in layer_status:\n        unmerged = set(status.available_adapters) - set(status.merged_adapters)\n        if unmerged & merged_all:\n            # there is overlap between unmerged adapters and adapters that should be merged\n            merged_adapters = \"irregular\"\n            break\n\n    # check status of requires_grad\n    # first, merge the values for all layers\n    requires_grad_all: dict[str, list[bool | Literal[\"irregular\"]]] = collections.defaultdict(list)\n    for status in layer_status:\n        for key, val in status.requires_grad.items():\n            requires_grad_all[key].append(val)\n\n    # then, check if the values are consistent\n    def check_irregular(vals: list[bool | Literal[\"irregular\"]]) -> bool | Literal[\"irregular\"]:\n        if all(val is True for val in vals):\n            return True\n        if all(val is False for val in vals):\n            return False\n        return \"irregular\"\n\n    requires_grad = {key: check_irregular(vals) for key, vals in requires_grad_all.items()}\n\n    devices_dd = collections.defaultdict(list)\n    for status in layer_status:\n        for key, val in status.devices.items():\n            devices_dd[key].extend(val)\n    devices = {key: sorted(set(val)) for key, val in devices_dd.items()}\n\n    adapter_model_status = TunerModelStatus(\n        base_model_type=base_model_type,\n        adapter_model_type=adapter_model_type,\n        peft_types=peft_types,\n        
trainable_params=trainable_params,\n        total_params=total_params,\n        num_adapter_layers=num_adapter_layers,\n        enabled=enabled,\n        active_adapters=active_adapters,\n        merged_adapters=merged_adapters,\n        requires_grad=requires_grad,\n        available_adapters=available_adapters,\n        devices=devices,\n    )\n    return adapter_model_status\n"
  },
  {
    "path": "kt-sft/ktransformers/sft/torchviz_test.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torchviz import make_dot\n\nclass SimpleNet(nn.Module):\n    def __init__(self):\n        super(SimpleNet, self).__init__()\n        self.fc1 = nn.Linear(10, 20)\n        self.relu = nn.ReLU()\n        self.fc2 = nn.Linear(20, 1)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.relu(x)\n        x = self.fc2(x)\n        return x\n\nmodel = SimpleNet()\n\ninput_tensor = torch.randn(1, 10)\n\noutput = model(input_tensor)\n\ndot = make_dot(output, params=dict(model.named_parameters()))\ndot.render('simple_net', format='svg', cleanup=True)    "
  },
  {
    "path": "kt-sft/ktransformers/tests/.gitignore",
    "content": "results/"
  },
  {
    "path": "kt-sft/ktransformers/tests/AIME_2024/eval_api.py",
    "content": "# adapt from https://github.com/abacaj/code-eval?tab=readme-ov-file\nimport argparse\nimport json\nimport os\nimport time\nimport requests\nimport tqdm\n\nfrom evaluation import filter_answer\nfrom prompts import instruct_prompt\nimport pandas as pd\nfrom datasets import load_dataset\nos.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\n\n\ndef generate_text(api_url,question , model_name, stream=False, auth_token=None):\n    headers = {\n        'accept': 'application/json',\n        'Content-Type': 'application/json',\n        'Authorization' : 'Bearer ' + auth_token if auth_token else ''\n    }\n    question = instruct_prompt(question)\n    data = {\n        \"messages\": [{\"content\": question, \"role\": \"user\"}],\n        \"model\": model_name,\n        \"stream\": stream,\n        \"temperature\": 0.6,\n        \"max_tokens\": 10240,\n    }\n    print(f\"content: {question}\")\n    response = requests.post(api_url, headers=headers, json=data,verify=False)\n    if response.status_code == 200:\n        result = response.json()\n        results = result.get('choices', [{}])[0].get('message', {}).get('content', '')\n        return filter_answer(results)\n    else:\n        print(f\"API Request failed with status code {response.status_code}\")\n        return None\ndef load_data(file_path):\n        \"\"\"\n        Load data from a Parquet file into a list.\n        Each record in the Parquet file should represent an individual record.\n        \"\"\"\n        # dataset = load_dataset('parquet', data_files=file_path)\n        data = []\n        ds = load_dataset(file_path)\n        df = pd.DataFrame(ds['train'])\n        for _, row in df.iterrows():\n            data.append(row.to_dict())\n        return data\n\ndef get_score(pred, answer):\n        \"\"\"\n        Calculate scores between the prediction and the answer.\n        Uses ROUGE scores as the evaluation metric.\n        :param pred: The predicted string.\n        :param answer: The 
reference answer string.\n        :return: 1 if the prediction matches the answer, otherwise 0.\n        \"\"\"\n        if pred == answer:\n            return 1\n        # if we need to compare str with number, convert the str to number\n        try:\n            pred = float(pred)\n            answer = float(answer)\n        except (TypeError, ValueError):\n            pass\n        if pred == answer:\n            return 1\n        return 0\n\ndef run_eval_api(\n    api_url: str,\n    model_name: str,\n    out_path: str,\n    format_tabs: bool = False,\n    auth_token: str = None,\n    problem_file: str = None,\n    append: bool = False,\n    skip: int = 0\n):\n  \n    data = load_data(problem_file)\n    pbar = tqdm.tqdm(total=len(data) * 1)\n    pbar.update(skip)\n    # Start at `skip` so the index never runs past the end of `data`\n    for i in range(skip, len(data)):\n        data_item = data[i]\n        question = data_item['Problem']\n        # Start the timer for this evaluation\n        start_time = time.time()\n        try:\n            completion = generate_text(api_url, question, model_name, auth_token=auth_token)\n            if completion is None:\n                raise Exception(f\"Failed to get prediction for {question}\")\n            answer = data_item['Answer']\n            score = get_score(completion, answer)\n            elapsed_time = time.time() - start_time\n            result = {\n                \"index\": i,\n                \"question_id\": data_item[\"ID\"],\n                \"answer\": answer,\n                \"prediction\": completion,\n                \"score\": score,\n                \"time\": elapsed_time\n            }\n            with open(out_path, \"a\" if append else \"w\") as f:\n                f.write(json.dumps(result) + \"\\n\")\n            \n        except Exception as e:\n            print(f\"Failed to get prediction for {question}\")\n            print(e)\n            continue\n\n        pbar.update(1)\n    \n\ndef main(output_path, api_url, model_name, auth_token, format_tabs,problem_file, append,skip):\n  
  os.makedirs(os.path.dirname(output_path), exist_ok=True)\n    run_eval_api(api_url, model_name, output_path, format_tabs, auth_token, problem_file,append,skip)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"API Generate Tester\")\n    parser.add_argument(\"--api_url\", type=str, default=\"https://api.siliconflow.cn/v1/chat/completions\", help=\"API URL\")\n    parser.add_argument(\"--model_name\", type=str, default=\"Pro/deepseek-ai/DeepSeek-R1\", help=\"Model Name\")\n    parser.add_argument(\"--out_path\", type=str, default=\"results/api/eval_aime.jsonl\", help=\"Output Path\")\n    parser.add_argument(\"--auth_token\", type=str, default=None, help=\"Auth Token\")\n    parser.add_argument(\"--format_tabs\", action=\"store_true\", help=\"Format Tabs\")\n    parser.add_argument(\"--problem_file\", type=str, default=\"Maxwell-Jia/AIME_2024\", help=\"Evalset File\")\n    parser.add_argument(\"--no_append\", action=\"store_false\", help=\"Append to existing file\")\n    parser.add_argument(\"--skip\", type=int, default=0, help=\"Skip some tasks\")\n    args = parser.parse_args()\n    # api_url = \"https://api.siliconflow.cn/v1/chat/completions\"\n    main(args.out_path, args.api_url, args.model_name, args.auth_token, args.format_tabs, args.problem_file, args.no_append, args.skip)"
  },
  {
    "path": "kt-sft/ktransformers/tests/AIME_2024/evaluation.py",
    "content": "# reference: https://github.com/declare-lab/instruct-eval/blob/main/human_eval/main.py#L35\ndef filter_answer(completion: str) -> str:\n    # the answer is the last part of the completion, it's a int64 number\n    # get the last line\n    completion = completion.strip().split(\"\\n\")[-1]\n    # handle the $\\\\boxed{...}$ format\n    if \"$\\\\boxed{\" in completion:\n        return completion.split(\"}\")[0].split(\"{\")[-1]\n    return completion.split()[-1]\n\n"
  },
  {
    "path": "kt-sft/ktransformers/tests/AIME_2024/prompts.py",
    "content": "def instruct_prompt(prompt: str) -> str:\n    return f\"\"\"Below is an instruction that describes a task. Write a response that appropriately completes the request.\\n\\n### Instruction:\\nSolve the following math problem without any tests or explanation only one answer surrounede by '$\\\\boxed{{}}$'\\n{prompt}\\n\\n### Response:\"\"\"\n"
  },
  {
    "path": "kt-sft/ktransformers/tests/dequant_gpu.py",
    "content": "import os \n# os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1,2\"\n# add path\nimport sys\ncurrent_path = os.path.abspath(os.path.dirname(__file__))\nsys.path.append(current_path+\"/../..\")\nimport numpy as np\n# from ktransformers.operators.linear import KTransformersLinear, KLinearMarlin\n# from ktransformers.operators.experts import KTransformersExperts, KExpertsTorch\nfrom ktransformers.util.custom_loader import GGUFLoader\nimport torch\nimport KTransformersOps\ntorch.set_default_dtype(torch.bfloat16)\nimport time\nfrom transformers import (\n    AutoConfig,\n)\nimport os\n# CUDA_LAUNCH_BLOCKING=1\nos.environ[\"CUDA_LAUNCH_BLOCKING\"]=\"1\"\n\ngguf_config = GGUFLoader(\"/data/Qwen2-57B-A14B-Instruct-GGUF/q4_k_m\")\nmodel_name = \"/data/Qwen2-57B-A14B-Instruct\"\n\n# Q4k\nkey = \"blk.1.\"\ntarget = \"attn_q.weight\"\n\nt1 = time.time()\nq_weight_cpu = gguf_config.load_gguf_tensor(key+target, \"cpu\")\n# q_weight_cpu = torch.from_numpy(q_weight_cpu)\n\nt2 = time.time()\nq_weight_gpu = gguf_config.load_gguf_tensor(key+target, \"cuda:0\")\nt3 = time.time()\nprint()\nallclose = torch.allclose(q_weight_cpu, q_weight_gpu.cpu(), atol=1e-6)\nprint(f\"Q4k {key+target}\")\nprint(\"load gguf tensor from cpu cost: \", t2-t1)\nprint(\"load gguf tensor from gpu cost: \", t3-t2)\nprint(\"allclose: \", allclose)\n\n\n# Q6k\nkey = \"blk.0.\"\ntarget = \"ffn_down_exps.weight\"\n\nt1 = time.time()\nq_weight_cpu = gguf_config.load_gguf_tensor(key+target, \"cpu\")\nt2 = time.time()\nq_weight_gpu = gguf_config.load_gguf_tensor(key+target, \"cuda:0\")\nt3 = time.time()\nprint()\nallclose = torch.allclose(q_weight_cpu, q_weight_gpu.cpu().to(torch.float32), atol=1e-6)\nprint(f\"Q6k {key+target}\")\nprint(\"load gguf tensor from cpu cost: \", t2-t1)\nprint(\"load gguf tensor from gpu cost: \", t3-t2)\nprint(\"allclose: \", allclose)\n"
  },
  {
    "path": "kt-sft/ktransformers/tests/dequant_gpu_t.py",
    "content": "import os \nos.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n# add path\nimport sys\nsys.path.append(\"../..\")\nimport pycuda.autoinit\nimport pycuda.driver as cuda\nfrom pycuda.compiler import SourceModule\nimport numpy as np\nfrom ktransformers.operators.linear import KTransformersLinear, KLinearMarlin\nfrom ktransformers.operators.experts import KTransformersExperts, KExpertsTorch\nfrom ktransformers.util.custom_loader import GGUFLoader, dequantize_q4_k_gpu, dequantize_q4_k\nimport torch\nimport KTransformersOps\ntorch.set_default_dtype(torch.bfloat16)\nimport time\nfrom transformers import (\n    AutoConfig,\n)\n\ngguf_config = GGUFLoader(\"/data/Qwen2-57B-A14B-Instruct-GGUF/q4_k_m\")\nmodel_name = \"/data/Qwen2-57B-A14B-Instruct\"\nkey = \"blk.0.\"\ntarget = \"ffn_up_exps.weight\"\n\ndata = gguf_config.get_mmap_tensor(key + target)\n\n_, factors, offsets, qs1, qs2= dequantize_q4_k(data)\nfactors_cpu = torch.from_numpy(factors)\noffsets_cpu = torch.from_numpy(offsets)\nqs1_cpu = torch.from_numpy(qs1)\nqs2_cpu = torch.from_numpy(qs2)\n\n\n_, factors, offsets, qs1, qs2 = dequantize_q4_k_gpu(data)\n\nprint(torch.allclose(factors.cpu(), factors_cpu))\nprint(torch.allclose(offsets.cpu(), offsets_cpu))\nprint(torch.allclose(qs1.cpu(), qs1_cpu))\nprint(torch.allclose(qs2.cpu(), qs2_cpu))"
  },
  {
    "path": "kt-sft/ktransformers/tests/function_call_test.py",
    "content": "from openai import OpenAI\n\ndef send_messages(messages):\n    response = client.chat.completions.create(\n        model=\"deepseek-chat\",\n        messages=messages,\n        tools=tools\n    )\n    return response.choices[0].message\n\nclient = OpenAI(\n    api_key=\"placeholder\",\n    base_url=\"http://0.0.0.0:10002/v1\",\n)\n\ntools = [\n    {\n        \"type\": \"function\",\n        \"function\": {\n            \"name\": \"get_weather\",\n            \"description\": \"Get weather of an location, the user shoud supply a location first\",\n            \"parameters\": {\n                \"type\": \"object\",\n                \"properties\": {\n                    \"location\": {\n                        \"type\": \"string\",\n                        \"description\": \"The city and state, e.g. San Francisco, CA\",\n                    }\n                },\n                \"required\": [\"location\"]\n            },\n        }\n    },\n]\n\nmessages = [{\"role\": \"user\", \"content\": \"How's the weather in Hangzhou?\"}]\nmessage = send_messages(messages)\nprint(f\"User>\\t {messages[0]['content']}\")\nprint(message)\ntool = message.tool_calls[0]\nmessages.append(message)\n\nmessages.append({\"role\": \"tool\", \"tool_call_id\": tool.id, \"content\": \"24℃\"})\nmessage = send_messages(messages)\nprint(f\"Model>\\t {message.content}\")"
  },
  {
    "path": "kt-sft/ktransformers/tests/humaneval/eval_api.py",
    "content": "# adapt from https://github.com/abacaj/code-eval?tab=readme-ov-file\nimport argparse\nimport os\nimport requests\nfrom human_eval.data import write_jsonl, read_problems\nimport tqdm\n\nfrom evaluation import filter_code, fix_indents\nfrom prompts import instruct_prompt\n\ndef generate_text(api_url,question , model_name, stream=False, auth_token=None):\n    headers = {\n        'accept': 'application/json',\n        'Content-Type': 'application/json',\n        'Authorization' : 'Bearer ' + auth_token if auth_token else ''\n    }\n    question = instruct_prompt(question)\n    data = {\n        \"messages\": [{\"content\": question, \"role\": \"user\"}],\n        \"model\": model_name,\n        \"stream\": stream,\n        \"temperature\": 0.6\n    }\n    print(f\"content: {question}\")\n    response = requests.post(api_url, headers=headers, json=data,verify=False)\n    if response.status_code == 200:\n        result = response.json()\n        results = result.get('choices', [{}])[0].get('message', {}).get('content', '')\n        return [filter_code(fix_indents(results))]\n    else:\n        print(f\"API Request failed with status code {response.status_code}\")\n        return None\n\ndef run_eval_api(\n    api_url: str,\n    model_name: str,\n    out_path: str,\n    format_tabs: bool = False,\n    auth_token: str = None,\n    problem_file: str = None,\n    append: bool = False,\n    skip: int = 0\n):\n    if(problem_file is None):\n        problems = read_problems()\n    else:\n        problems = read_problems(problem_file)\n    samples = []\n    pbar = tqdm.tqdm(total=len(problems) * 1)\n    pbar.update(skip)\n    try:\n        for task_id in problems:\n            # skip some tasks\n            if skip > 0:\n                skip -= 1\n                continue\n\n            if format_tabs:\n                prompt = problems[task_id][\"prompt\"].replace(\"    \", \"\\t\")\n            else:\n                prompt = problems[task_id][\"prompt\"]\n    
        completion = generate_text(api_url, prompt, model_name, auth_token=auth_token)\n            if completion is None:\n                # the API request failed; skip this task instead of aborting the run\n                print(f\"Failed to get completion for {task_id}\")\n                pbar.update(1)\n                continue\n            # samples.append({\"task_id\": task_id, \"completion\": completion})\n            for sample in completion:\n                result = dict(\n                    task_id=task_id,\n                    completion=sample,\n                )\n                samples += [result]\n                if append:\n                    write_jsonl(out_path, [result],append=append)\n            pbar.update(1)\n        if not append:\n            write_jsonl(out_path, samples,append=append)\n    except Exception as e:\n        if not append:\n            write_jsonl(out_path, samples,append=append)\n        print(f\"Error: {e}\")\n\ndef main(output_path, api_url, model_name, auth_token, format_tabs,problem_file, append,skip):\n    os.makedirs(os.path.dirname(output_path) or \".\", exist_ok=True)\n    run_eval_api(api_url, model_name, output_path, format_tabs, auth_token, problem_file,append,skip)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"API Generate Tester\")\n    #parser.add_argument(\"--api_url\", type=str, default=\"https://api.siliconflow.cn/v1/chat/completions\", help=\"API URL\")\n    parser.add_argument(\"--api_url\", type=str, default=\"http://localhost:10002/v1/chat/completions\", help=\"API URL\")\n    parser.add_argument(\"--model_name\", type=str, default=\"Pro/deepseek-ai/DeepSeek-V3\", help=\"Model Name\")\n    parser.add_argument(\"--out_path\", type=str, default=\"results/api/eval_b.jsonl\", help=\"Output Path\")\n    parser.add_argument(\"--auth_token\", type=str, default=None, help=\"Auth Token\")\n    parser.add_argument(\"--format_tabs\", action=\"store_true\", help=\"Format Tabs\")\n    parser.add_argument(\"--problem_file\", type=str, default=None, help=\"Evalset File\")\n    parser.add_argument(\"--no_append\", action=\"store_false\", help=\"Overwrite the output file instead of appending to it\")\n    parser.add_argument(\"--skip\", type=int, default=0, 
help=\"Skip first n problems\")\n    args = parser.parse_args()\n    # api_url = \"https://api.siliconflow.cn/v1/chat/completions\"\n    main(args.out_path, args.api_url, args.model_name, args.auth_token, args.format_tabs, args.problem_file, args.no_append,args.skip)"
  },
  {
    "path": "kt-sft/ktransformers/tests/humaneval/evaluation.py",
    "content": "# reference: https://github.com/declare-lab/instruct-eval/blob/main/human_eval/main.py#L35\ndef filter_code(completion: str) -> str:\n    # The program tends to overwrite, we only take the first function\n    completion = completion.lstrip(\"\\n\")\n    # we also remove ```python\\n and ```\n    completion = completion.replace(\"```python\\n\", \"\").replace(\"```\", \"\")\n    if 'if __name__ == \"__main__\":' in completion:\n        completion = completion.split('if __name__ == \"__main__\":')[0]\n    if \"# Example usage\" in completion:\n        completion = completion.split(\"# Example usage\")[0]\n    return completion\n\n\ndef fix_indents(text: str) -> str:\n    return text.replace(\"\\t\", \"    \")\n"
  },
  {
    "path": "kt-sft/ktransformers/tests/humaneval/prompts.py",
    "content": "def instruct_prompt(prompt: str) -> str:\n    return f\"\"\"Below is an instruction that describes a task. Write a response that appropriately completes the request.\\n\\n### Instruction:\\nComplete the following Python code without any tests or explanation\\n{prompt}\\n\\n### Response:\"\"\"\n\n\ndef standard_prompt(prompt: str) -> str:\n    return f\"\"\"Complete the following Python code without any tests or explanation\\n{prompt}\"\"\"\n\n\ndef write_prompt(prompt: str) -> str:\n    return f\"\"\"Write a python program to complete the following code:\\n{prompt}\"\"\"\n\n\ndef replit_glaive_prompt(prompt: str) -> str:\n    return f\"\"\"Below is an instruction that describes a task, paired with an input that provides further context.\\n Write a response that appropriately completes the request.\\n\\n ### Instruction:\\nWrite a program to perform the given task.\\n\\n Input:\\n{prompt}\\n\\n### Response:\"\"\"\n"
  },
  {
    "path": "kt-sft/ktransformers/tests/mmlu_pro_test.py",
    "content": "import argparse\nimport random\nimport time\nimport json\nimport requests\nimport pandas as pd\nfrom datasets import load_dataset\n\nimport os\nos.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\nos.environ['https_proxy'] = ''\nos.environ['http_proxy'] = ''\nhint = 'There is a single choice question. Answer the question by replying A, B, C, D, E, F, G, H, I, J. No other answers are accepted. Just the letter.'\n\n\nclass DataEvaluator:\n    def __init__(self):\n        # self.template_prompt = template_prompt\n        self.data = []\n\n    def load_data(self, file_path):\n        \"\"\"\n        Load data from a Parquet file into a list.\n        Each record in the Parquet file should represent an individual record.\n        \"\"\"\n        # dataset = load_dataset('parquet', data_files=file_path)\n        ds = load_dataset(\"TIGER-Lab/MMLU-Pro\")\n        df = pd.DataFrame(ds['test'])\n        # print(ds)\n        # # ds_1 =  ds['train']\n        # ds_2 =  ds['validation']\n        # ds_3 =  ds['test']\n        # df_test = pd.DataFrame(ds['test'])\n        # df_val = pd.DataFrame(ds['validation'])\n\n        # for _, row in df.iterrows():\n        #     self.data.append(row.to_dict())\n        # df = pd.read_parquet(file_path)\n\n        for _, row in df.iterrows():\n            self.data.append(row.to_dict())\n\n    def get_prompt(self, record):\n        \"\"\"\n        Combine fields from a record with the template prompt to create a full prompt.\n        :param record: Dictionary containing fields to populate the template.\n        :return: A formatted prompt string.\n        \"\"\"\n        options_str = \"\\n\".join([f\"{chr(65+i)}. 
{opt}\" for i, opt in enumerate(record['options'])])\n        prompt = hint + \"\\nQuestion: \" + record['question'] + \"\\n\" + options_str + \"\\nAnswer: '\"\n        return prompt\n        \n    def post_processing(self, text):\n        \"\"\"\n        Perform post-processing on the prediction string.\n        Keeps only the last character of the last line.\n        :param text: The raw prediction string.\n        :return: Processed prediction string.\n        \"\"\"\n        text = text.lstrip('\\n').split('\\n')[-1]\n        return text[-1:]\n\n    def score(self, pred, answers):\n        \"\"\"\n        Score the prediction against the reference answers by exact match.\n        :param pred: The predicted string.\n        :param answers: The reference answer(s), compared item by item.\n        :return: 1 if the prediction equals any reference answer, otherwise 0.\n        \"\"\"\n        for answer in answers:\n            if pred == answer:\n                return 1\n\n        return 0\n\n# Function to generate text using API\ndef generate_text(api_url, question, model_name, stream=False):\n    headers = {\n        'accept': 'application/json',\n        'Content-Type': 'application/json',\n        'Authorization' : 'Bearer '\n    }\n    data = {\n        \"messages\": [{\"content\": question, \"role\": \"user\"}],\n        \"model\": model_name,\n        \"stream\": stream,\n        # \"temperature\": 0.0\n    }\n    \n    print(\"POST data:\", data)\n    response = requests.post(api_url, headers=headers, json=data)\n    \n    if response.status_code == 200:\n        result = response.json()\n        return result.get('choices', [{}])[0].get('message', {}).get('content', '').strip()\n    else:\n        print(f\"API Request failed with status code {response.status_code}\")\n        return None\n\n# Main function to handle multiple evaluations\ndef main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):\n    start_total_time = time.time()\n\n    total_score = 0\n\n    
results = []\n    random.seed(42)\n    random.shuffle(data_evaluator.data)\n    # Evaluate at most `concurrent_requests` questions, one after another\n    num_requests = min(concurrent_requests, len(data_evaluator.data))\n    for i in range(num_requests):\n        # Take the next item from the shuffled data\n        data_item = data_evaluator.data[i]\n        question = data_evaluator.get_prompt(data_item)\n        # print(question)\n\n        # Start the timer for this evaluation\n        start_time = time.time()\n        try:\n            # Generate prediction using the API\n            prediction = generate_text(api_url, question, model_name)\n\n            if prediction is None:\n                raise Exception(f\"Failed to get prediction for {question}\")\n\n            answer = data_item['answer']\n            # Compute score\n            score = data_evaluator.score(data_evaluator.post_processing(prediction), answer)\n\n            # Calculate the time taken\n            elapsed_time = time.time() - start_time\n\n            # Collect the result data\n            result_data = {\n                \"question_id\": data_item['question_id'],\n                \"answer\": answer,\n                \"prediction\": data_evaluator.post_processing(prediction),\n                \"score\": score,\n                \"time\": elapsed_time\n            }\n\n            # Write results to result.json with each field on a new line\n            with open(result_file, 'a', encoding='utf-8') as f:\n                json.dump(result_data, f, ensure_ascii=False, indent=4)\n                f.write(\"\\n\")  # Ensure each JSON object is on a new line\n\n            results.append(result_data)\n\n            # Aggregate scores\n            total_score += score\n\n        except Exception as e:\n            print(f\"Error processing request {i}: {e}\")\n\n    # Calculate total time and throughput over the questions actually attempted\n    total_time = time.time() - start_total_time\n    throughput = num_requests / total_time\n\n    # Log the total time, throughput, and average score\n    with 
open(log_file, 'a', encoding='utf-8') as log_f:\n        log_f.write(f\"Total Time: {total_time:.2f} seconds\\n\")\n        log_f.write(f\"Throughput: {throughput:.2f} requests per second\\n\")\n        log_f.write(f\"Average Scores: {total_score / num_requests if num_requests else 0}\\n\")\n        log_f.write('-' * 40 + '\\n')\n\n    print(f\"Results saved to {result_file}\")\n    print(f\"Log saved to {log_file}\")\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"API Generate Tester\")\n    parser.add_argument(\"--concurrent\", type=int, default=1000, help=\"Number of questions to evaluate (requests are sent one at a time)\")\n    parser.add_argument(\"--file\", type=str, default=\"TIGER-Lab/MMLU-Pro\", help=\"Dataset name (currently unused; TIGER-Lab/MMLU-Pro is always loaded)\")\n    parser.add_argument(\"--result\", type=str, default=\"./mmlu_result_pro.json\", help=\"Path to save the result JSON file\")\n    parser.add_argument(\"--log\", type=str, default=\"./mmlu_result_pro.log\", help=\"Path to save the log file\")\n    parser.add_argument(\"--model\", type=str, default=\"Pro/deepseek-ai/DeepSeek-V3\", help=\"Model name or path\")\n    parser.add_argument(\"--api_url\", type=str, default=\"http://localhost:15488/v1/chat/completions\", help=\"API URL\")\n    # parser.add_argument(\"--api_url\", type=str, default=\"https://api.siliconflow.cn/v1/chat/completions\", help=\"API URL\")\n\n    args = parser.parse_args()\n\n    # Load the data from the provided file\n    # template_prompt = hint + \"\\nQuestion: {question}\\nA. {options}\\nB. {option_b}\\nC. {option_c}\\nD. {option_d}\\nAnswer: '\"\n    # template_prompt_pro = hint + \"\\nQuestion: {question}\\nA. {options[0]}\\nB. {options[1]}\\nC. {options[2]}\\nD. {options[3]}\\nE. {options[4]}\\nF. {options[5]}\\nG. \\\n        # {options[6]}\\nH. {options[7]}\\nI. {options[8]}\\nJ. 
{options[9]}\\nAnswer: '\"\n\n\n    # Load the data from the provided file\n    data_evaluator = DataEvaluator()\n    data_evaluator.load_data(args.file)\n\n    # Run the main function with the specified number of concurrent evaluations\n    main(args.concurrent, data_evaluator, args.result, args.log, args.api_url, args.model)"
  },
  {
    "path": "kt-sft/ktransformers/tests/mmlu_test.py",
    "content": "import argparse\nimport random\nimport time\nimport json\nimport requests\nimport pandas as pd\nfrom datasets import load_dataset\n\nimport os\nos.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\nos.environ['https_proxy'] = ''\nos.environ['http_proxy'] = ''\nhint = 'There is a single choice question. Answer the question by replying A, B, C, D. No other answers are accepted. Just the letter.'\n\n\nclass DataEvaluator:\n    def __init__(self):\n        # self.template_prompt = template_prompt\n        self.data = []\n\n    def load_data(self, file_path):\n        \"\"\"\n        Load data from a Parquet file into a list.\n        Each record in the Parquet file should represent an individual record.\n        \"\"\"\n        # dataset = load_dataset('parquet', data_files=file_path)\n        splits = {'test': 'all/test-00000-of-00001.parquet', 'validation': 'all/validation-00000-of-00001.parquet',\n                  'dev': 'all/dev-00000-of-00001.parquet',\n                  'auxiliary_train': 'all/auxiliary_train-00000-of-00001.parquet'}\n        df = pd.read_parquet(\"hf://datasets/cais/mmlu/\" + splits[\"test\"])\n\n        for _, row in df.iterrows():\n            self.data.append(row.to_dict())\n\n    def get_prompt(self, record):\n        \"\"\"\n        Combine fields from a record with the template prompt to create a full prompt.\n        :param record: Dictionary containing fields to populate the template.\n        :return: A formatted prompt string.\n        \"\"\"\n        options_str = \"\\n\".join([f\"{chr(65 + i)}. 
{opt}\" for i, opt in enumerate(record['choices'])])\n        prompt = hint + \"\\nQuestion: \" + record['question'] + \"\\n\" + options_str + \"\\nAnswer: '\"\n        return prompt\n        \n    def post_processing(self, text):\n        \"\"\"\n        Perform post-processing on the prediction string.\n        Keeps only the last character of the last line.\n        :param text: The raw prediction string.\n        :return: Processed prediction string.\n        \"\"\"\n        text = text.lstrip('\\n').split('\\n')[-1]\n        return text[-1:]\n\n    def score(self, pred, answers):\n        \"\"\"\n        Score the prediction against the reference answers by exact match.\n        :param pred: The predicted string.\n        :param answers: The reference answer(s), compared item by item.\n        :return: 1 if the prediction equals any reference answer, otherwise 0.\n        \"\"\"\n        for answer in answers:\n            if pred == answer:\n                return 1\n\n        return 0\n\n# Function to generate text using API\ndef generate_text(api_url, question, model_name, stream=False):\n    headers = {\n        'accept': 'application/json',\n        'Content-Type': 'application/json',\n        'Authorization' : 'Bearer '\n    }\n    data = {\n        \"messages\": [{\"content\": question, \"role\": \"user\"}],\n        \"model\": model_name,\n        \"stream\": stream,\n        # \"temperature\": 0.0\n    }\n    \n    print(\"POST data:\", data)\n    response = requests.post(api_url, headers=headers, json=data)\n    \n    if response.status_code == 200:\n        result = response.json()\n        return result.get('choices', [{}])[0].get('message', {}).get('content', '').strip()\n    else:\n        print(f\"API Request failed with status code {response.status_code}\")\n        return None\n\n# Main function to handle multiple evaluations\ndef main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):\n    start_total_time = time.time()\n\n    total_score = 0\n\n    
results = []\n    random.seed(42)\n    random.shuffle(data_evaluator.data)\n    # Evaluate at most `concurrent_requests` questions, one after another\n    num_requests = min(concurrent_requests, len(data_evaluator.data))\n    for i in range(num_requests):\n        # Take the next item from the shuffled data\n        data_item = data_evaluator.data[i]\n        question = data_evaluator.get_prompt(data_item)\n        # print(question)\n\n        # Start the timer for this evaluation\n        start_time = time.time()\n        try:\n            # Generate prediction using the API\n            prediction = generate_text(api_url, question, model_name)\n\n            if prediction is None:\n                raise Exception(f\"Failed to get prediction for {question}\")\n\n            answer = chr(data_item['answer'] + 65)\n            # Compute score\n            score = data_evaluator.score(data_evaluator.post_processing(prediction), answer)\n\n            # Calculate the time taken\n            elapsed_time = time.time() - start_time\n\n            # Collect the result data\n            result_data = {\n                \"question_id\": i,\n                \"answer\": answer,\n                \"prediction\": data_evaluator.post_processing(prediction),\n                \"score\": score,\n                \"time\": elapsed_time\n            }\n\n            # Write results to result.json with each field on a new line\n            with open(result_file, 'a', encoding='utf-8') as f:\n                json.dump(result_data, f, ensure_ascii=False, indent=4)\n                f.write(\"\\n\")  # Ensure each JSON object is on a new line\n\n            results.append(result_data)\n\n            # Aggregate scores\n            total_score += score\n\n        except Exception as e:\n            print(f\"Error processing request {i}: {e}\")\n\n    # Calculate total time and throughput over the questions actually attempted\n    total_time = time.time() - start_total_time\n    throughput = num_requests / total_time\n\n    # Log the total time, throughput, and average score\n    with open(log_file, 'a', 
encoding='utf-8') as log_f:\n        log_f.write(f\"Total Time: {total_time:.2f} seconds\\n\")\n        log_f.write(f\"Throughput: {throughput:.2f} requests per second\\n\")\n        log_f.write(f\"Average Scores: {total_score / num_requests if num_requests else 0}\\n\")\n        log_f.write('-' * 40 + '\\n')\n\n    print(f\"Results saved to {result_file}\")\n    print(f\"Log saved to {log_file}\")\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"API Generate Tester\")\n    parser.add_argument(\"--concurrent\", type=int, default=1000, help=\"Number of questions to evaluate (requests are sent one at a time)\")\n    parser.add_argument(\"--file\", type=str, default=\"cais/mmlu\", help=\"Dataset name (currently unused; the cais/mmlu test split is always loaded)\")\n    parser.add_argument(\"--result\", type=str, default=\"./mmlu_result_silicon.json\", help=\"Path to save the result JSON file\")\n    parser.add_argument(\"--log\", type=str, default=\"./mmlu_result_silicon.log\", help=\"Path to save the log file\")\n    parser.add_argument(\"--model\", type=str, default=\"Pro/deepseek-ai/DeepSeek-V3\", help=\"Model name or path\")\n    parser.add_argument(\"--api_url\", type=str, default=\"http://localhost:10003/v1/chat/completions\", help=\"API URL\")\n    # parser.add_argument(\"--api_url\", type=str, default=\"https://api.siliconflow.cn/v1/chat/completions\", help=\"API URL\")\n\n    args = parser.parse_args()\n\n    # Load the data from the provided file\n    # template_prompt = hint + \"\\nQuestion: {question}\\nA. {options}\\nB. {option_b}\\nC. {option_c}\\nD. {option_d}\\nAnswer: '\"\n    # template_prompt_pro = hint + \"\\nQuestion: {question}\\nA. {options[0]}\\nB. {options[1]}\\nC. {options[2]}\\nD. {options[3]}\\nE. {options[4]}\\nF. {options[5]}\\nG. \\\n        # {options[6]}\\nH. {options[7]}\\nI. {options[8]}\\nJ. 
{options[9]}\\nAnswer: '\"\n\n\n    # Load the data from the provided file\n    data_evaluator = DataEvaluator()\n    data_evaluator.load_data(args.file)\n\n    # Run the main function with the specified number of concurrent evaluations\n    main(args.concurrent, data_evaluator, args.result, args.log, args.api_url, args.model)"
  },
  {
    "path": "kt-sft/ktransformers/tests/mmlu_test_multi.py",
    "content": "import argparse\nimport random\nimport time\nimport json\nimport requests\nimport pandas as pd\nfrom datasets import load_dataset\nimport os\nimport concurrent.futures\nimport threading\nimport re\n\nos.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\nos.environ['https_proxy'] = ''\nos.environ['http_proxy'] = ''\nhint = 'There is a single choice question. Answer the question by replying A, B, C, D. No other answers are accepted. Just the letter.'\n\n\ndef extract_final_answer(text):\n    \"\"\"\n    提取模型预测的最终选项（如 A/B/C/D）\n    支持自然语言、多行、markdown、高亮、非末尾结论等格式\n    \"\"\"\n    text = text.strip()\n\n    explicit_patterns = [\n        r'Answer:\\s*([A-D])\\b',\n        r'Correct answer:\\s*([A-D])\\b',\n        r'The correct answer is\\s*\\*?\\*?\\s*([A-D])\\b',\n        r'Answer is\\s*([A-D])\\b',\n        r'Therefore,\\s*answer is\\s*([A-D])\\b',\n        r'Therefore,\\s*the answer should be\\s*(?:Option\\s*)?([A-D])\\b',\n        r'The answer should be\\s*(?:Option\\s*)?([A-D])\\b',\n        r'Option\\s+([A-D])\\s+is correct',\n    ]\n    for pat in explicit_patterns:\n        match = re.search(pat, text, re.IGNORECASE)\n        if match:\n            return match.group(1).upper()\n\n    markdown_match = re.findall(r'\\*\\*\\s*([A-D])[\\.\\s]?', text)\n    if markdown_match:\n        return markdown_match[-1].upper()\n\n    quote_match = re.findall(r\"['\\\"]([A-D])['\\\"]\", text)\n    if quote_match:\n        return quote_match[-1].upper()\n\n    lines = text.splitlines()\n    for line in reversed(lines[-5:]):\n        line = line.strip()\n        match = re.match(r'^([A-D])([.\\s]|$)', line)\n        if match:\n            return match.group(1).upper()\n    \n    return None\nclass DataEvaluator:\n    def __init__(self):\n        self.data = []\n\n    def load_data(self, file_path):\n        \"\"\"\n        从数据文件中加载数据，每条记录对应一个实例\n        \"\"\"\n        splits = {'test': 'all/test-00000-of-00001.parquet', 'validation': 
'all/validation-00000-of-00001.parquet',\n                  'dev': 'all/dev-00000-of-00001.parquet',\n                  'auxiliary_train': 'all/auxiliary_train-00000-of-00001.parquet'}\n        df = pd.read_parquet(\"hf://datasets/cais/mmlu/\" + splits[\"test\"])\n        for _, row in df.iterrows():\n            self.data.append(row.to_dict())\n\n    def get_prompt(self, record):\n        \"\"\"\n        Combine the hint with the record data to build the full question prompt.\n        \"\"\"\n        options_str = \"\\n\".join([f\"{chr(65 + i)}. {opt}\" for i, opt in enumerate(record['choices'])])\n        prompt = hint + \"\\nQuestion: \" + record['question'] + \"\\n\" + options_str + \"\\nAnswer: '\"\n        return prompt\n\n    def post_processing(self, text):\n        \"\"\"\n        Post-process the generated text to extract the final answer (returns only the last character).\n        \"\"\"\n        text = text.lstrip('\\n').split('\\n')[-1]\n        return text[-1:]\n\n    def score(self, pred, answer):\n        \"\"\"\n        Compare the predicted answer with the correct answer and return the score (1 on a match, otherwise 0).\n        \"\"\"\n        if pred == answer:\n            return 1\n        return 0\n\ndef generate_text(api_url, question, model_name, stream=False):\n    headers = {\n        'accept': 'application/json',\n        'Content-Type': 'application/json',\n        'Authorization': 'Bearer '\n    }\n    data = {\n        \"messages\": [{\"content\": question, \"role\": \"user\"}],\n        \"model\": model_name,\n        \"stream\": stream,\n    }\n    print(\"POST data:\", data)\n    response = requests.post(api_url, headers=headers, json=data, timeout=5000000)\n    if response.status_code == 200:\n        result = response.json()\n        return result.get('choices', [{}])[0].get('message', {}).get('content', '').strip()\n    else:\n        print(f\"API Request failed with status code {response.status_code}\")\n        return None\n\ndef main(concurrent_requests, data_evaluator: DataEvaluator, result_file, log_file, api_url, model_name):\n    start_total_time = time.time()\n    total_score = 0\n    total_exact_score = 0\n    results = 
[]\n    file_lock = threading.Lock()\n    \n    random.seed(42)\n    random.shuffle(data_evaluator.data)\n    data_subset = data_evaluator.data[:min(concurrent_requests, len(data_evaluator.data))]\n    \n    batch_size = 10\n\n    def worker(index, data_item):\n        nonlocal total_score\n        nonlocal total_exact_score\n        question = data_evaluator.get_prompt(data_item)\n        start_time = time.time()\n        try:\n            prediction = generate_text(api_url, question, model_name)\n            if prediction is None:\n                raise Exception(f\"Failed to get prediction for question: {question}\")\n            answer = chr(data_item['answer'] + 65)\n            processed_prediction = data_evaluator.post_processing(prediction)\n            score = data_evaluator.score(processed_prediction, answer)\n            exact_score = data_evaluator.score(extract_final_answer(prediction), answer)\n            elapsed_time = time.time() - start_time\n            result_data = {\n                \"question_id\": index,\n                \"answer\": answer,\n                \"prediction\": processed_prediction,\n                \"full_prediction\": prediction,\n                \"score\": score,\n                \"exact_score\": exact_score,\n                \"time\": elapsed_time\n            }\n            with file_lock:\n                with open(result_file, 'a', encoding='utf-8') as f:\n                    json.dump(result_data, f, ensure_ascii=False, indent=4)\n                    f.write(\"\\n\")\n            return result_data\n        except Exception as e:\n            print(f\"Error processing request {index}: {e}\")\n            return None\n\n    for batch_start in range(0, len(data_subset), batch_size):\n        batch = data_subset[batch_start: batch_start + batch_size]\n        with concurrent.futures.ThreadPoolExecutor(max_workers=batch_size) as executor:\n            futures = [executor.submit(worker, batch_start + j, data_item) for j, 
data_item in enumerate(batch)]\n            for future in concurrent.futures.as_completed(futures):\n                res = future.result()\n                if res is not None:\n                    results.append(res)\n                    total_score += res['score']\n                    total_exact_score += res['exact_score']\n    \n    total_time = time.time() - start_total_time\n    throughput = len(data_subset) / total_time if total_time > 0 else 0\n    \n    with open(log_file, 'a', encoding='utf-8') as log_f:\n        log_f.write(f\"Total Time: {total_time:.2f} seconds\\n\")\n        log_f.write(f\"Throughput: {throughput:.2f} requests per second\\n\")\n        average_score = total_score / len(data_subset) if data_subset else 0\n        log_f.write(f\"Average Score: {average_score}\\n\")\n        average_exact_score = total_exact_score / len(data_subset) if data_subset else 0\n        log_f.write(f\"Average Exact Score: {average_exact_score}\\n\")\n        log_f.write('-' * 40 + '\\n')\n    \n    print(f\"Results saved to {result_file}\")\n    print(f\"Log saved to {log_file}\")\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"API Generate Tester\")\n    parser.add_argument(\"--concurrent\", type=int, default=1000, help=\"Total number of instances to test\")\n    parser.add_argument(\"--file\", type=str, default=\"cais/mmlu\", help=\"Path to the data file\")\n    parser.add_argument(\"--result\", type=str, default=\"./mmlu_result_silicon.json\", help=\"Path where the result file is saved\")\n    parser.add_argument(\"--log\", type=str, default=\"./mmlu_result_silicon.log\", help=\"Path where the log file is saved\")\n    parser.add_argument(\"--model\", type=str, default=\"Pro/deepseek-ai/DeepSeek-V3\", help=\"Model name or path\")\n    parser.add_argument(\"--api_url\", type=str, default=\"http://localhost:10006/v1/chat/completions\", help=\"API URL\")\n\n    args = parser.parse_args()\n    \n    data_evaluator = DataEvaluator()\n    data_evaluator.load_data(args.file)\n    \n    main(args.concurrent, data_evaluator, 
args.result, args.log, args.api_url, args.model)"
  },
  {
    "path": "kt-sft/ktransformers/tests/score.py",
    "content": "import subprocess\nimport time\nimport requests\nimport sys\nimport os\n\ndef wait_for_server(base_url: str, timeout: int = None) -> None:\n    start_time = time.time()\n    while True:\n        try:\n            response = requests.get(\n                f\"{base_url}/v1/models\",\n                headers={\"Authorization\": \"Bearer None\"},\n            )\n            if response.status_code == 200:\n                print(\"Server is ready.\")\n                break\n        except requests.exceptions.RequestException:\n            time.sleep(1)\n            if timeout and time.time() - start_time > timeout:\n                raise TimeoutError(\"Server did not become ready within timeout period\")\n\nserver_cmd = [\n    \"numactl\", \"-N\", \"1\", \"-m\", \"1\",\n    \"/home/qujing3/anaconda3/envs/ktransformers-dev/bin/ktransformers\",\n    \"--model_path\", \"/home/qujing3/models/DeepSeek-R1-Q4_K_M/config\",\n    \"--gguf_path\", \"/home/qujing3/models/DeepSeek-V3-GGUF/DeepSeek-V3-Q4_K_M\",\n    \"--port\", \"10002\",\n    \"--cpu_infer\", \"48\",\n    \"--optimize_config_path\", \"ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml\",\n    \"--max_new_tokens\", \"3000\",\n    \"--cache_lens\", \"6000\"\n]\n\nprint(\"Starting ktransformers server...\")\nprint(\" \".join(server_cmd))\nwith open(\"/tmp/server_log.txt\", \"w\") as f:\n    server_process = subprocess.Popen(server_cmd, stdout=f, stderr=f, text=True)\n\ntry:\n    wait_for_server(\"http://localhost:10002\", timeout=600)\n\n    eval_cmd = [\"python\", \"ktransformers/tests/humaneval/eval_api.py\"]\n    print(\"Running eval_api.py...\")\n    print(f\"Command: {' '.join(eval_cmd)}\")\n    \n    env = os.environ.copy()\n    env[\"PYTHONUNBUFFERED\"] = \"1\"\n    \n    eval_process = subprocess.Popen(\n        eval_cmd,\n        stdout=subprocess.PIPE,\n        stderr=subprocess.PIPE,\n        text=True,\n        bufsize=1,\n        env=env,\n        universal_newlines=True\n    )\n  
  \n    import threading\n    import queue\n    \n    def enqueue_output(out, queue):\n        for line in iter(out.readline, ''):\n            queue.put(line)\n        out.close()\n    \n    stdout_queue = queue.Queue()\n    stderr_queue = queue.Queue()\n    \n    stdout_thread = threading.Thread(target=enqueue_output, args=(eval_process.stdout, stdout_queue))\n    stderr_thread = threading.Thread(target=enqueue_output, args=(eval_process.stderr, stderr_queue))\n    \n    stdout_thread.daemon = True\n    stderr_thread.daemon = True\n    stdout_thread.start()\n    stderr_thread.start()\n    \n    while eval_process.poll() is None:\n        try:\n            line = stdout_queue.get_nowait()\n            print(line, end='', flush=True)\n        except queue.Empty:\n            pass\n            \n        try:\n            line = stderr_queue.get_nowait()\n            print(line, end='', file=sys.stderr, flush=True)\n        except queue.Empty:\n            pass\n        \n        time.sleep(1)\n\n    while not stdout_queue.empty():\n        print(stdout_queue.get(), end='', flush=True)\n    while not stderr_queue.empty():\n        print(stderr_queue.get(), end='', file=sys.stderr, flush=True)\n        \n    eval_process.wait()\n    print(f\"eval_api.py completed with exit code: {eval_process.returncode}\")\n\n    evaluate_cmd = [\n        \"evaluate_functional_correctness\",\n        \"ktransformers/tests/humaneval/results/api/eval_b.jsonl\"\n    ]\n    print(\"Running evaluate_functional_correctness...\")\n    print(f\"Command: {' '.join(evaluate_cmd)}\")\n    \n    evaluate_process = subprocess.Popen(\n        evaluate_cmd,\n        stdout=subprocess.PIPE,\n        stderr=subprocess.PIPE,\n        text=True,\n        bufsize=1,\n        universal_newlines=True\n    )\n    \n    for line in evaluate_process.stdout:\n        print(line, end='', flush=True)\n    for line in evaluate_process.stderr:\n        print(line, end='', file=sys.stderr, flush=True)\n        \n  
  evaluate_process.wait()\n    \n    print(f\"evaluate_functional_correctness completed with exit code: {evaluate_process.returncode}\")\n    if evaluate_process.returncode != 0:\n        print(f\"evaluate_functional_correctness exited with code {evaluate_process.returncode}\")\n        sys.exit(evaluate_process.returncode)\n\nfinally:\n    print(\"Stopping ktransformers server...\")\n    server_process.terminate()\n    try:\n        server_process.wait(timeout=30)\n    except subprocess.TimeoutExpired:\n        print(\"Server did not terminate gracefully, forcing...\")\n        server_process.kill()"
  },
  {
    "path": "kt-sft/ktransformers/tests/test_client.py",
    "content": "import asyncio\nimport json\nimport sys\nimport aiohttp\nimport argparse\n\nprompt_list = [\n    'Please elaborate on modern world history.',\n    'Please introduce Harry Potter.',\n    'I want to learn Python. Please give me some advice.',\n    'Please tell me a joke '\n]\n\n\nasync def fetch_event_stream(session, payload, request_id, stream):\n    try:\n        headers = {\n            'accept': 'application/json',\n            'Content-Type': 'application/json'\n        }\n\n        async with session.post(SERVER_URL, json=payload, headers=headers, timeout=50000) as response:\n            print(f\"Request {request_id}: Connected, status {response.status}\")\n\n            if response.status != 200:\n                print(f\"Request {request_id}: Error, status {response.status}\")\n                return\n\n            output_text = \"\"\n\n            if stream:\n                async for line in response.content:\n                    try:\n                        decoded_line = line.decode(\"utf-8\").strip()\n                        if not decoded_line or not decoded_line.startswith(\"data: \"):\n                            continue\n\n                        decoded_line = decoded_line[6:].strip()\n                        if not decoded_line:\n                            continue\n\n                        response_data = json.loads(decoded_line)\n                        choices = response_data.get(\"choices\", [])\n                        if not choices:\n                            continue\n\n                        delta = choices[0].get(\"delta\", {})\n                        token = delta.get(\"content\", \"\")\n\n                        if token:\n                            output_text += token\n                            sys.stdout.write(token)\n                            sys.stdout.flush()\n\n                        finish_reason = choices[0].get(\"finish_reason\", None)\n                        if finish_reason:\n                   
         break\n\n                    except json.JSONDecodeError as e:\n                        print(f\"\\nRequest {request_id}: JSON Decode Error - {e}\")\n                    except IndexError:\n                        print(f\"\\nRequest {request_id}: List Index Error - choices is empty\")\n                    except Exception as e:\n                        print(f\"\\nRequest {request_id}: Error parsing stream - {e}\")\n            else:\n                response_data = await response.json()\n                choices = response_data.get(\"choices\", [])\n                if choices:\n                    content = choices[0].get(\"message\", {}).get(\"content\", \"\")\n                    print(f\"Request {request_id} Output:\\n{content}\")\n                    output_text += content\n\n    except Exception as e:\n        print(f\"\\nRequest {request_id}: Exception - {e}\")\n\nasync def main(prompt_id, model, stream, max_tokens, temperature, top_p):\n    async with aiohttp.ClientSession() as session:\n        payload = {\n            \"messages\": [\n                {\"role\": \"system\", \"content\": \"\"},\n                {\"role\": \"user\", \"content\": prompt_list[prompt_id]}\n            ],\n            \"model\": model,\n            \"stream\": stream,\n            \"max_tokens\": max_tokens,\n            \"temperature\": temperature,\n            \"top_p\": top_p\n        }\n        tasks = [fetch_event_stream(session, payload, prompt_id, stream)]\n        await asyncio.gather(*tasks)\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Event Stream Request Tester\")\n    parser.add_argument(\"--question_id\", type=int, default=0)\n    parser.add_argument(\"--model\", type=str, default=\"DeepSeek-V3\")\n    # type=bool would turn any non-empty string (including \"False\") into True, so parse the text explicitly.\n    parser.add_argument(\"--stream\", type=lambda s: s.lower() in (\"1\", \"true\", \"yes\"), default=True)\n    parser.add_argument(\"--max_tokens\", type=int, default=500)\n    parser.add_argument(\"--temperature\", type=float, default=0.8)\n    
parser.add_argument(\"--top_p\", type=float, default=1)\n    parser.add_argument(\"--api_url\", type=str, default=\"http://localhost:10002/v1/chat/completions\", help=\"API URL\")\n\n    args = parser.parse_args()\n    SERVER_URL = args.api_url\n    asyncio.run(main(args.question_id, args.model, args.stream, args.max_tokens, args.temperature, args.top_p))\n"
  },
  {
    "path": "kt-sft/ktransformers/tests/test_pytorch_q8.py",
    "content": "import torch\n\nclass LinearModel(torch.nn.Module):\n    def __init__(self, in_features, out_features):\n        super().__init__()\n        self.linear = torch.nn.Linear(in_features, out_features)\n    \n    def forward(self, x):\n        return self.linear(x)\n\nin_features = 64\nout_features = 128\nmodel_fp32 = LinearModel(in_features, out_features)\n\nmodel_int8 = torch.ao.quantization.quantize_dynamic(\n    model_fp32,\n    {torch.nn.Linear},\n    dtype=torch.qint8\n)\n\nbatch_size = 32\ninput_fp32 = torch.randn(1, batch_size, in_features)\noutput_int8 = model_int8(input_fp32)\n\nprint(f\"输入形状: {input_fp32.shape}\")\nprint(f\"输出形状: {output_int8.shape}\")\n\nwith torch.no_grad():\n    output_fp32 = model_fp32(input_fp32)\n    \nprint(f\"FP32输出的前几个值: {output_fp32[0, :5]}\")\nprint(f\"INT8输出的前几个值: {output_int8[0, :5]}\")\n\nerror = torch.abs(output_fp32 - output_int8).mean().item()\nprint(f\"平均绝对误差: {error}\")\n\nprint(f\"量化前模型类型: {type(model_fp32.linear)}\")\nprint(f\"量化后模型类型: {type(model_int8.linear)}\")"
  },
  {
    "path": "kt-sft/ktransformers/tests/test_speed.py",
    "content": "import asyncio\nimport json\nimport sys\nimport aiohttp\nimport random\nimport argparse\nimport yaml\nimport os\nimport time\nfrom time import sleep\n\ndecodesz = 128\n# Server URL (replace with your server URL)\ndecodesz_list = [128]\nprefill_speeds = []\ndecode_speeds = []\nktansformer_prompt1024=\"\"\"Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. \nThey were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense.Mr. Dursley was the director of a firm called Grunnings, which made drills. \nHe was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. \nDursley was thin and blonde and had nearly twice the usual amount of neck, which came in very useful as she spent so much of her time craning over garden fences, spying on the neighbors. \nThe Dursleys had a small son called Dudley and in their opinion there was no finer boy anywhere.\nThe Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it. \nThey didn't think they could bear it if anyone found out about the Potters. Mrs. Potter was Mrs. Dursley's sister, but they hadn't met for several years; in fact, Mrs. \nDursley pretended she didn't have a sister, because her sister and her good-for-nothing husband were as unDursleyish as it was possible to be. \nThe Dursleys shuddered to think what the neighbors would say if the Potters arrived in the street. \nThe Dursleys knew that the Potters had a small son, too, but they had never even seen him. \nThis boy was another good reason for keeping the Potters away; they didn't want Dudley mixing with a child like that.When Mr. and Mrs. 
\nDursley woke up on the dull, gray Tuesday our story starts, there was nothing about the cloudy sky outside to suggest that strange and mysterious things would soon be happening all over the country. \nMr. Dursley hummed as he picked out his most boring tie for work, and Mrs. Dursley gossiped away happily as she wrestled a screaming Dudley into his high chair.None of them noticed a large, tawny owl flutter past the window.\nAt half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. Dursley on the cheek, and tried to kiss Dudley good-bye but missed, because Dudley was now having a tantrum and throwing his cereal at the walls.\n“Little tyke,” chortled Mr. Dursley as he left the house. He got into his car and backed out of number four's drive.\nIt was on the corner of the street that he noticed the first sign of something peculiar — a cat reading a map. \nFor a second, Mr. Dursley didn't realize what he had seen — then he jerked his head around to look again. \nThere was a tabby cat standing on the corner of Privet Drive, but there wasn't a map in sight. \nWhat could he have been thinking of? It must have been a trick of the light. \nMr. Dursley blinked and stared at the cat. It stared back. As Mr. Dursley drove around the corner and up the road, he watched the cat in his mirror. \nIt was now reading the sign that said Privet Drive — no, looking at the sign; cats couldn't read maps or signs. \nMr. Dursley gave himself a little shake and put the cat out of his mind. \nAs he drove toward town he thought of nothing except a large order of drills he was hoping to get that day.\nBut on the edge of town, drills were driven out of his mind by something else. \nAs he sat in the usual morning traffic jam, he couldn't help noticing that there seemed to be a lot of strangely dressed people about. \nPeople in cloaks. Mr. Dursley couldn't bear people who dressed in funny clothes — the getups you saw on young people! \nHe supposed this was some stupid new fashion. 
He drummed his fingers on the steering wheel and his eyes fell on a huddle of these weirdos standing quite close by. \nThey were whispering excitedly together. Mr. Dursley was enraged to see that a couple of them weren't young at all; why, that man had to be older than he was, and wearing an emerald-green cloak! \nThe nerve of him! But then it struck Mr. Dursley that this was probably some silly stunt — these people were obviously collecting for something… yes, that would be it. \nThe traffic moved on and a few minutes later, Mr. Dursley arrived in the Grunnings parking lot, his mind back on drills.\nMr. Dursley always sat with his back to the window in his office on the ninth floor.\"\"\"\nasync def fetch_event_stream(session, request_id, prompt, max_tokens, model):\n    try:\n        payload = {\n            \"messages\": [\n                {\"role\": \"system\", \"content\": \"\"},\n                {\"role\": \"user\", \"content\": prompt}\n            ],\n            \"model\": model,\n            \"temperature\": 0.3,\n            \"top_p\": 1.0,\n            \"stream\": True,\n            \"return_speed\": True,\n            \"max_tokens\": max_tokens,\n        }\n\n        headers = {\n            'accept': 'application/json',\n            'Content-Type': 'application/json'\n        }\n\n        async with session.post(SERVER_URL, json=payload, headers=headers, timeout=500000) as response:\n            if response.status != 200:\n                print(f\"[Request {request_id}] Error: Status {response.status}\")\n                return\n\n            buffer = \"\"  \n            total_tokens = 0\n            decode_start_time = None\n            decode_end_time = None\n            usage_info = None  \n\n            async for line in response.content:\n                try:\n                    decoded_line = line.decode(\"utf-8\").strip()\n                    if not decoded_line or not decoded_line.startswith(\"data: \"):\n                        continue\n\n  
                  decoded_line = decoded_line[6:].strip()\n                    if not decoded_line:\n                        continue\n\n                    response_data = json.loads(decoded_line)\n                    \n                    if \"usage\" in response_data:\n                        usage_info = response_data[\"usage\"]\n                    \n                    choices = response_data.get(\"choices\", [])\n                    if not choices:\n                        continue\n\n                    delta = choices[0].get(\"delta\", {})\n                    token = delta.get(\"content\", \"\")\n\n                    if token:\n                        if decode_start_time is None:\n                            decode_start_time = time.time()\n                        buffer += token\n                        total_tokens += 1\n                        decode_end_time = time.time()\n\n                        while \"\\n\" in buffer:\n                            line, buffer = buffer.split(\"\\n\", 1)\n                            print(f\"[Request {request_id}] {line}\")\n\n                    finish_reason = choices[0].get(\"finish_reason\", None)\n                    if finish_reason:\n                        break\n\n                except Exception as e:\n                    print(f\"[Request {request_id}] Stream Error: {e}\")\n\n            if buffer.strip():\n                print(f\"[Request {request_id}] {buffer.strip()}\")\n\n            if usage_info:\n                if \"prefill_time\" in usage_info:\n                    # print(f\"[Request {request_id}] Usage:\")\n                    # for key, value in usage_info.items():\n                    #     print(f\"  {key}: {value}\")\n                    prefill_speed = usage_info[\"prompt_tokens\"] / usage_info[\"prefill_time\"]\n                    decode_speed = usage_info[\"completion_tokens\"] / usage_info[\"decode_time\"]\n                    prefill_speeds.append(prefill_speed)\n                  
  decode_speeds.append(decode_speed)\n                    print(f'[Request {request_id}] prefill speed: {prefill_speed}')\n                    print(f'[Request {request_id}] decode speed: {decode_speed}')\n\n    except Exception as e:\n        print(f\"[Request {request_id}] Exception: {e}\")\n\nasync def main(concurrent_requests , prompt, max_tokens, model):\n    async with aiohttp.ClientSession() as session:\n        tasks = [fetch_event_stream(session, i , prompt, max_tokens, model) for i in range(concurrent_requests)]\n        await asyncio.gather(*tasks)\n    if len(prefill_speeds) != 0:\n        import numpy as np\n        print(f\"concurrency: {len(prefill_speeds)}\")\n        print(f\"total prefill speed: {np.sum(prefill_speeds)}\\n total decode speed: {np.sum(decode_speeds)}\")\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Event Stream Request Tester\")\n    parser.add_argument(\"--concurrent\", type=int, default=1, help=\"Number of concurrent requests\")\n    parser.add_argument(\"--model\", type=str, default=\"DeepSeek-V3\", help=\"Model name\")\n    # Restrict the value to the lengths handled below so that `prompt` is always defined.\n    parser.add_argument(\"--prompt_lens\", type=int, default=1024, choices=[1024, 2048, 4096], help=\"prefill prompt length: 1024, 2048 or 4096\")\n    parser.add_argument(\"--api_url\", type=str, default=\"http://localhost:10002/v1/chat/completions\", help=\"API URL\")\n    parser.add_argument(\"--max_tokens\", type=int, default=50, help=\"max decode tokens\")\n    \n    args = parser.parse_args()\n    SERVER_URL = args.api_url\n    max_tokens = args.max_tokens\n    model = args.model\n    if args.prompt_lens == 1024:\n        prompt = ktansformer_prompt1024\n    elif args.prompt_lens == 2048:\n        prompt = ktansformer_prompt1024 * 2\n    elif args.prompt_lens == 4096:\n        prompt = ktansformer_prompt1024 * 4\n    asyncio.run(main(args.concurrent, prompt, max_tokens, model))\n\n"
  },
  {
    "path": "kt-sft/ktransformers/tests/triton_fp8gemm_test.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom typing import Optional\nimport pytest\nfrom typing import Tuple, Optional, Literal\nimport time\n# use dir path\nimport os\nimport sys\nsys.path.insert(0, \"/home/azure/ktransformers\")\nprint(sys.path)\nfrom ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant\nfrom safetensors import safe_open\n\nworld_size = 1\nrank = 0\nblock_size = 128\ngemm_impl: Literal[\"bf16\", \"fp8\"] = \"bf16\"\n# Assuming `fp8_gemm`, `act_quant`, `weight_dequant` and other relevant functions are already defined\n\ndef test_fp8_gemm_vs_torch_matmul():\n    # Test case 1: Create random matrices of size (M, K) and (K, N)\n    M, K, N = 64, 128, 256  # Matrix dimensions\n    x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')\n    weight = torch.randn(N, K, dtype=torch.bfloat16, device='cuda')\n\n    # Apply act_quant to both matrices\n    x_quantized, scale_x = act_quant(x, block_size)\n    weight_quantized, scale_w = act_quant(weight, block_size)\n    \n    # mk continous\n    x_quantized = x_quantized.contiguous()\n    weight_quantized = weight_quantized.contiguous()\n    scale_x = scale_x.contiguous()\n    scale_w = scale_w.contiguous()\n\n    # Perform fp8_gemm using the quantized tensors\n    result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight_quantized, scale_w)\n\n    # Perform torch.matmul using the original floating point tensors\n    result_torch_matmul = torch.matmul(x, weight.T)\n    print(f'result_torch_matmul: {result_torch_matmul.shape}')\n    print(f'result_fp8_gemm: {result_fp8_gemm.shape}')\n\n    print(f\"result_fp8_gemm:\\n {result_fp8_gemm}\")\n    print(f\"result_torch_matmul:\\n {result_torch_matmul}\")\n    \ndef test_fp8_gemm_vs_torch_matmul_load():\n    file_path = \"/mnt/data/model/DeepSeek-V3/model-00001-of-000163.safetensors\"\n    with safe_open(file_path, framework=\"pt\", device=0) as f:\n        weight = 
f.get_tensor(\"model.layers.0.mlp.down_proj.weight\")\n        scale = f.get_tensor(\"model.layers.0.mlp.down_proj.weight_scale_inv\")\n\n    # weight_dequant\n    weight_dequantized = weight_dequant(weight, scale)\n    print(f\"weight_dequantized: {weight_dequantized.shape}\")\n    N, K = weight_dequantized.shape\n    M = 64\n    x = torch.randn(2 ,M, K, dtype=torch.bfloat16, device='cuda')\n    x_quantized, scale_x = act_quant(x, block_size)\n    \n    # Test case 1: quantized x matmal with undequantized weight\n    result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)\n    print(f\"result_fp8_gemm:\\n {result_fp8_gemm}\")\n    print(f\"dtype {result_fp8_gemm.dtype}\")\n\n    # Perform torch.matmul using the original floating point tensors\n    result_torch_matmul = torch.matmul(x, weight_dequantized.to(torch.bfloat16).T)\n    print(f\"result_torch_matmul:\\n {result_torch_matmul}\")\n\ndef test_fp8_gemm_tplops():\n    file_path = \"/mnt/data/model/DeepSeek-V3/model-00001-of-000163.safetensors\"\n    with safe_open(file_path, framework=\"pt\", device=0) as f:\n        weight = f.get_tensor(\"model.layers.0.mlp.down_proj.weight\")\n        scale = f.get_tensor(\"model.layers.0.mlp.down_proj.weight_scale_inv\")\n\n    # weight_dequant\n    weight_dequantized = weight_dequant(weight, scale)\n    print(f\"weight_dequantized: {weight_dequantized.shape}\")\n    N, K = weight_dequantized.shape\n    M = 6400\n    x = torch.randn(2 ,M, K, dtype=torch.bfloat16, device='cuda')\n    # x_quantized, scale_x = act_quant(x, block_size)\n    \n    # Calculate time for 1000 fp8_gemm\n    i = 10\n    flops_per_gemm = 2 * M * N * K\n    total_flops = i * flops_per_gemm\n    \n    x_quantized, scale_x = act_quant(x, block_size)\n    result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)\n    x_quantized, scale_x = act_quant(x, block_size)\n    result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)\n\n    \n    t0 = time.time()\n    
torch.cuda.synchronize()\n    t0 = time.time()  # start timing only after pending GPU work has finished\n    for _ in range(i):\n        x_quantized, scale_x = act_quant(x, block_size)\n        result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)\n    torch.cuda.synchronize()\n    t1 = time.time()\n    \n    total_time = t1 - t0\n    tflops = total_flops / total_time / 1e12\n    print(f\"total_time: {total_time}\")\n    print(f\"tflops: {tflops}\")\n    \n\n    \n    \nif __name__ == \"__main__\":\n    test_fp8_gemm_vs_torch_matmul()\n    test_fp8_gemm_vs_torch_matmul_load()\n    test_fp8_gemm_tplops()\n    "
  },
  {
    "path": "kt-sft/ktransformers/util/cuda_graph_runner.py",
    "content": "'''\nDescription  :  \nAuthor       : Boxin Zhang\nVersion      : 0.1.0\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport torch\nfrom typing import Dict\n\nclass CUDAGraphRunner:\n\n    def __init__(self):\n        self.graph = None\n        self.input_buffers: Dict[str, torch.Tensor] = {}\n        self.output_buffers: Dict[str, torch.Tensor] = {}\n\n    def capture(\n        self,\n        model,\n        cur_token,\n        position_ids,\n        cache_position,\n        past_key_values,\n        main_device,\n        **kwargs,\n    ) -> None:\n        assert self.graph is None\n        # Capture the graph.\n        torch.cuda.synchronize()\n        self.graph = torch.cuda.CUDAGraph()\n        #self.graph.enable_debug_mode()\n        self.model = model\n        inputs_embeds = model.model.embed_tokens(cur_token.to(\"cpu\")).to(main_device)\n        # torch.cuda.set_device can't set \"cuda\", must have a index\n        if main_device == \"cuda\":\n            main_device = \"cuda:0\"\n        torch.cuda.set_device(main_device)\n        self.main_device = main_device\n        capture_stream = torch.cuda.Stream()\n        with torch.cuda.graph(self.graph, stream = capture_stream):\n            logits=model(inputs_embeds=inputs_embeds, \n                         position_ids=position_ids,\n                         cache_position=cache_position,\n                         past_key_values=past_key_values,\n                         **kwargs)[0]\n            capture_stream.wait_stream(torch.cuda.current_stream())\n            torch.cuda.set_device(main_device)\n            torch.cuda.set_stream(capture_stream)\n        if past_key_values != None:    \n            past_key_values.change_seq_length(-1)\n        torch.cuda.synchronize(self.main_device)\n        #self.graph.debug_dump(\"cuda_graph_hooked.dot\")\n\n        # Save the input and output buffers.\n        self.input_buffers = {\n            \"inputs_embeds\": inputs_embeds,\n    
        \"position_ids\": position_ids,\n            \"cache_position\": cache_position,\n        }\n        self.output_buffers = {\"logits\": logits}\n        return\n\n    def forward(\n        self,\n        cur_token,\n        position_ids,\n        cache_position,\n    ) -> torch.Tensor:\n        # Copy the input tensors to the input buffers.\n        inputs_embeds = self.model.model.embed_tokens(cur_token.to(\"cpu\"))\n        self.input_buffers[\"inputs_embeds\"].copy_(inputs_embeds)\n        self.input_buffers[\"position_ids\"].copy_(position_ids)\n        self.input_buffers[\"cache_position\"].copy_(cache_position)\n\n        # Run the graph.\n        #print(\"begin replay\")\n        #time.sleep(1)\n        self.graph.replay()\n        torch.cuda.synchronize(self.main_device)\n        # Return the output tensor.\n        return self.output_buffers[\"logits\"]\n\n    def __call__(self, *args, **kwargs):\n        return self.forward(*args, **kwargs)\n"
  },
  {
    "path": "kt-sft/ktransformers/util/custom_gguf.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : Azure-Tang, Boxin Zhang, chenht2022\nDate         : 2024-07-26 08:48:54\nVersion      : 1.0.0\nLastEditors  : kkk1nak0\nLastEditTime : 2024-08-14 08:20:45\nAdapted from https://github.com/99991/pygguf/blob/main/gguf.py\nCopyright (c) 2023-2024 The ggml authors\nCopyright (c) 2024 Thomas Germer\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\n# copied from llama.cpp/gguf-py/gguf/constants.py to satisfy dependence of gguf\n# GGUF specification\n# https://github.com/ggerganov/ggml/blob/master/docs/gguf.md\nimport struct\nimport warnings\nimport numpy as np\nimport re\nimport numpy.typing as npt\nfrom typing import Sequence\nimport os\nfrom enum import IntEnum\nimport torch\nif not torch.xpu.is_available():\n    import KTransformersOps\nimport ctypes\nimport math\n\nclass GGMLQuantizationType(IntEnum):\n    F32     = 0\n    F16     = 1\n    Q4_0    = 2\n    Q4_1    = 3\n    Q5_0    = 6\n    Q5_1    = 7\n    Q8_0    = 8\n    Q8_1    = 9\n    Q2_K    = 10\n    Q3_K    = 11\n    Q4_K    = 12\n    Q5_K    = 13\n    Q6_K    = 14\n    Q8_K    = 15\n    IQ2_XXS = 16\n    IQ2_XS  = 17\n    IQ3_XXS = 18\n    IQ1_S   = 19\n    IQ4_NL  = 20\n    IQ3_S   = 21\n    IQ2_S   = 22\n    IQ4_XS  = 23\n    I8      = 24\n    I16     = 25\n    I32     = 26\n    I64     = 27\n    F64     = 28\n    IQ1_M   = 29\n    BF16    = 30\n\nQK_K = 256\nGGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {\n    GGMLQuantizationType.F32:     (1, 4),\n    GGMLQuantizationType.F16:     (1, 2),\n    GGMLQuantizationType.Q4_0:    (32, 2 + 16),\n    GGMLQuantizationType.Q4_1:    (32, 2 + 2 + 16),\n    GGMLQuantizationType.Q5_0:    (32, 2 + 4 + 16),\n    GGMLQuantizationType.Q5_1:    (32, 2 + 2 + 4 + 16),\n    GGMLQuantizationType.Q8_0:    (32, 2 + 32),\n    GGMLQuantizationType.Q8_1:    (32, 4 + 4 + 32),\n    GGMLQuantizationType.Q2_K:    (256, 2 + 2 + QK_K // 16 + QK_K // 4),\n    
GGMLQuantizationType.Q3_K:    (256, 2 + QK_K // 4 + QK_K // 8 + 12),\n    GGMLQuantizationType.Q4_K:    (256, 2 + 2 + QK_K // 2 + 12),\n    GGMLQuantizationType.Q5_K:    (256, 2 + 2 + QK_K // 2 + QK_K // 8 + 12),\n    GGMLQuantizationType.Q6_K:    (256, 2 + QK_K // 2 + QK_K // 4 + QK_K // 16),\n    GGMLQuantizationType.Q8_K:    (256, 4 + QK_K + QK_K // 8),\n    GGMLQuantizationType.IQ2_XXS: (256, 2 + QK_K // 4),\n    GGMLQuantizationType.IQ2_XS:  (256, 2 + QK_K // 4 + QK_K // 32),\n    GGMLQuantizationType.IQ3_XXS: (256, 2 + QK_K // 4 + QK_K // 8),\n    GGMLQuantizationType.IQ1_S:   (256, 2 + QK_K // 8 + QK_K // 16),\n    GGMLQuantizationType.IQ4_NL:  (32, 2 + 16),\n    GGMLQuantizationType.IQ3_S:   (256, 2 + QK_K // 4 + QK_K // 8 + QK_K // 32 + 4),\n    GGMLQuantizationType.IQ2_S:   (256, 2 + QK_K // 4 + QK_K // 16),\n    GGMLQuantizationType.IQ4_XS:  (256, 2 + 2 + QK_K // 2 + QK_K // 64),\n    GGMLQuantizationType.I8:      (1, 1),\n    GGMLQuantizationType.I16:     (1, 2),\n    GGMLQuantizationType.I32:     (1, 4),\n    GGMLQuantizationType.I64:     (1, 8),\n    GGMLQuantizationType.F64:     (1, 8),\n    GGMLQuantizationType.IQ1_M:   (256, QK_K // 8 + QK_K // 16  + QK_K // 32),\n    GGMLQuantizationType.BF16:    (1, 2),\n}\n\n# copied from llama.cpp/gguf-py/gguf/quants.py to avoid dependence of gguf\ndef quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType):\n    block_size, type_size = GGML_QUANT_SIZES[quant_type]\n    if shape[-1] % block_size != 0:\n        raise ValueError(f\"Quantized tensor row size ({shape[-1]}) is not a multiple of {quant_type.name} block size ({block_size})\")\n    return (*shape[:-1], shape[-1] // block_size * type_size)\n\nGGML_TYPES = {\n    \"F32\": 0,\n    \"F16\": 1,\n    \"Q4_0\": 2,\n    \"Q5_0\": 6,\n    \"Q8_0\": 8,\n    \"Q2_K\": 10,\n    \"Q3_K\": 11,\n    \"Q4_K\": 12,\n    \"Q5_K\": 13,\n    \"Q6_K\": 14,\n    \"IQ4_XS\": 23,\n    \"BF16\": 30,\n}\n\nGGML_NAMES = {ggml_type: name for name, 
ggml_type in GGML_TYPES.items()}\n\nGGML_BLOCK_SIZES = {\n    \"F32\": 4,\n    \"F16\": 2,\n    \"BF16\": 2,\n    \"Q4_0\": 2 + 16,\n    \"Q5_0\": 2 + 4 + 16,\n    \"Q8_0\": 2 + 32,\n    \"Q2_K\": 256 // 16 + 256 // 4 + 2 + 2,\n    \"Q3_K\": 256 // 8 + 256 // 4 + 12 + 2,\n    \"Q4_K\": 2 + 2 + 12 + 256 // 2,\n    \"Q5_K\": 2 + 2 + 12 + 256 // 8 + 256 // 2,\n    \"Q6_K\": 256 // 2 + 256 // 4 + 256 // 16 + 2,\n    \"IQ4_XS\": 2 + 2 + 256 // 2 + 256 // 64,\n    \"FP8\": 1,\n}\n\nGGML_ELEMENTS_PER_BLOCK = {\n    \"F32\": 1,\n    \"F16\": 1,\n    \"BF16\": 1,\n    \"Q4_0\": 32,\n    \"Q5_0\": 32,\n    \"Q8_0\": 32,\n    \"Q2_K\": 256,\n    \"Q3_K\": 256,\n    \"Q4_K\": 256,\n    \"Q5_K\": 256,\n    \"Q6_K\": 256,\n    \"IQ4_XS\": 256,\n    \"FP8\": 1,\n}\n\nDATA_TYPES = {\n    \"uint8\": 0,\n    \"int8\": 1,\n    \"uint16\": 2,\n    \"int16\": 3,\n    \"uint32\": 4,\n    \"int32\": 5,\n    \"float32\": 6,\n    \"bool\": 7,\n    \"string\": 8,\n    \"array\": 9,\n    \"uint64\": 10,\n    \"int64\": 11,\n    \"float64\": 12,\n    \"FP8\": 13,\n}\n\n\ndef read_value(f, data_type):\n    if data_type == DATA_TYPES[\"string\"]:\n        length = struct.unpack(\"<Q\", f.read(8))[0]\n        return f.read(length).decode(\"utf-8\")\n\n    elif data_type == DATA_TYPES[\"bool\"]:\n        return bool(struct.unpack(\"<?\", f.read(1))[0])\n\n    elif data_type == DATA_TYPES[\"uint8\"]:\n        return struct.unpack(\"<B\", f.read(1))[0]\n\n    elif data_type == DATA_TYPES[\"int8\"]:\n        return struct.unpack(\"<b\", f.read(1))[0]\n\n    elif data_type == DATA_TYPES[\"uint16\"]:\n        return struct.unpack(\"<H\", f.read(2))[0]\n\n    elif data_type == DATA_TYPES[\"int16\"]:\n        return struct.unpack(\"<h\", f.read(2))[0]\n\n    elif data_type == DATA_TYPES[\"uint32\"]:\n        return struct.unpack(\"<I\", f.read(4))[0]\n\n    elif data_type == DATA_TYPES[\"int32\"]:\n        return struct.unpack(\"<i\", f.read(4))[0]\n\n    elif data_type == DATA_TYPES[\"float32\"]:\n     
   return struct.unpack(\"<f\", f.read(4))[0]\n\n    elif data_type == DATA_TYPES[\"uint64\"]:\n        return struct.unpack(\"<Q\", f.read(8))[0]\n\n    elif data_type == DATA_TYPES[\"int64\"]:\n        return struct.unpack(\"<q\", f.read(8))[0]\n\n    elif data_type == DATA_TYPES[\"float64\"]:\n        return struct.unpack(\"<d\", f.read(8))[0]\n\n    elif data_type == DATA_TYPES[\"array\"]:\n        elem_type, count = struct.unpack(\"<IQ\", f.read(4 + 8))\n        return [read_value(f, elem_type) for _ in range(count)]\n\n    elif data_type == DATA_TYPES[\"FP8\"]:\n        return struct.unpack(\"<B\", f.read(1))[0]\n\n    else:\n        raise NotImplementedError(f\"Data type {data_type} not implemented\")\n\ndef dequantize_q2_k(data):\n    # C implementation\n    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1547\n    # C struct definition\n    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L74\n    block_size = GGML_BLOCK_SIZES[\"Q2_K\"]\n    num_blocks = len(data) // block_size\n\n    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)\n    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)\n\n    dmin = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32)\n    d = data_f16[:, -2].reshape(num_blocks, 1, 1).astype(np.float32)\n    scales = data_u8[:, :16].reshape(num_blocks, 16, 1)\n    qs = data_u8[:, 16:80].reshape(num_blocks, 64)\n\n    tmp = np.stack([\n        qs[:, 00:16] >> 0,\n        qs[:, 16:32] >> 0,\n        qs[:, 00:16] >> 2,\n        qs[:, 16:32] >> 2,\n        qs[:, 00:16] >> 4,\n        qs[:, 16:32] >> 4,\n        qs[:, 00:16] >> 6,\n        qs[:, 16:32] >> 6,\n        qs[:, 32:48] >> 0,\n        qs[:, 48:64] >> 0,\n        qs[:, 32:48] >> 2,\n        qs[:, 48:64] >> 2,\n        qs[:, 32:48] >> 4,\n        qs[:, 48:64] >> 4,\n        qs[:, 32:48] >> 6,\n        qs[:, 
48:64] >> 6,\n    ], axis=1)\n\n    return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4)\n\ndef dequantize_q2_k_gpu(data, device:str =\"cuda\", target_dtype = torch.get_default_dtype()):\n    block_size = GGML_BLOCK_SIZES[\"Q2_K\"]\n    ele_per_blk = GGML_ELEMENTS_PER_BLOCK[\"Q2_K\"]\n    data = np.frombuffer(data, dtype=data.dtype)\n    device = torch.device(device)\n    # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, \n    # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.\n    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)\n    return KTransformersOps.dequantize_q2_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)\n\ndef dequantize_q3_k(data):\n    # C implementation\n    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1723C32-L1723C42\n    # C struct definition\n    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L95\n    block_size = GGML_BLOCK_SIZES[\"Q3_K\"]\n    num_blocks = len(data) // block_size\n\n    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)\n    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)\n\n    d = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32)\n    bits = np.unpackbits(data_u8[:, :32].reshape(num_blocks, 32, 1), axis=-1, bitorder=\"little\")\n    bits = 4 ^ (bits << 2)\n    qs = data_u8[:, 32:32 + 64].astype(np.int16)\n    a, b, c = data_u8[:, 96: 96 + 12].reshape(num_blocks, 3, 4).transpose(1, 0, 2)\n    scales = np.zeros((num_blocks, 4, 4), dtype=np.uint8)\n    scales[:, 0] = (a & 15) | ((c & 3) << 4)\n    scales[:, 1] = (b & 15) | (((c >> 2) & 3) << 4)\n    scales[:, 2] = (a >> 4) | (((c >> 4) & 3) << 4)\n    scales[:, 3] = (b >> 4) | ((c >> 6) << 4)\n    scales = 
scales.reshape(num_blocks, 16, 1).astype(np.int16)\n\n    return d * (scales - 32) * np.stack([\n        (((qs[:, 00:16] >> 0) & 3) - bits[:, :16, 0]),\n        (((qs[:, 16:32] >> 0) & 3) - bits[:, 16:, 0]),\n        (((qs[:, 00:16] >> 2) & 3) - bits[:, :16, 1]),\n        (((qs[:, 16:32] >> 2) & 3) - bits[:, 16:, 1]),\n        (((qs[:, 00:16] >> 4) & 3) - bits[:, :16, 2]),\n        (((qs[:, 16:32] >> 4) & 3) - bits[:, 16:, 2]),\n        (((qs[:, 00:16] >> 6) & 3) - bits[:, :16, 3]),\n        (((qs[:, 16:32] >> 6) & 3) - bits[:, 16:, 3]),\n        (((qs[:, 32:48] >> 0) & 3) - bits[:, :16, 4]),\n        (((qs[:, 48:64] >> 0) & 3) - bits[:, 16:, 4]),\n        (((qs[:, 32:48] >> 2) & 3) - bits[:, :16, 5]),\n        (((qs[:, 48:64] >> 2) & 3) - bits[:, 16:, 5]),\n        (((qs[:, 32:48] >> 4) & 3) - bits[:, :16, 6]),\n        (((qs[:, 48:64] >> 4) & 3) - bits[:, 16:, 6]),\n        (((qs[:, 32:48] >> 6) & 3) - bits[:, :16, 7]),\n        (((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7])\n    ], axis=1)\n\ndef dequantize_q3_k_gpu(data, device:str =\"cuda\", target_dtype = torch.get_default_dtype()):\n    block_size = GGML_BLOCK_SIZES[\"Q3_K\"]\n    ele_per_blk = GGML_ELEMENTS_PER_BLOCK[\"Q3_K\"]\n    data = np.frombuffer(data, dtype=data.dtype)\n    device = torch.device(device)\n    # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, \n    # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.\n    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)\n    return KTransformersOps.dequantize_q3_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)\n\ndef dequantize_q4_k(data):\n    # C implementation\n    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1929\n    # C struct definition\n    # 
https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L116\n    block_size = GGML_BLOCK_SIZES[\"Q4_K\"]\n    num_blocks = len(data) // block_size\n    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)\n    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)\n    # Casting to float32 because float16 is very slow on CPU\n    scale_factors = data_f16[:, 0].reshape(num_blocks, 1, 1).astype(np.float32)\n    scale_offsets = data_f16[:, 1].reshape(num_blocks, 1, 1).astype(np.float32)\n    qs1 = data_u8[:, 4:16].reshape(num_blocks, 12, 1)\n    qs2 = data_u8[:, 16:].reshape(num_blocks, 4, 32)\n    # Dequantize scales and offsets (6 bits and 4 + 2 bits)\n    factors = scale_factors * np.concatenate([qs1[:, 0:4] & 0b111111, (qs1[:, 8:] & 15) | ((qs1[:, 0:4] >> 6) << 4)], axis=1)\n    offsets = scale_offsets * np.concatenate([qs1[:, 4:8] & 0b111111, (qs1[:, 8:] >> 4) | ((qs1[:, 4:8] >> 6) << 4)], axis=1)\n    # Interleave low and high quantized bits\n    qs2 = np.stack([qs2 & 0xf, qs2 >> 4], axis=2).reshape(num_blocks, 8, 32)\n    # Dequantize final weights using scales and offsets\n    return factors * qs2 - offsets\n\ndef dequantize_q4_k_gpu(data, device:str =\"cuda\", target_dtype = torch.get_default_dtype()):\n    block_size = GGML_BLOCK_SIZES[\"Q4_K\"]\n    ele_per_blk = GGML_ELEMENTS_PER_BLOCK[\"Q4_K\"]\n    data = np.frombuffer(data, dtype=data.dtype)\n    device = torch.device(device)\n    # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, \n    # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.\n    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)\n    return KTransformersOps.dequantize_q4_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)\n\ndef dequantize_q5_k(data):\n    # C implementation\n    # 
https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L2129\n    # C struct definition\n    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L138\n    block_size = GGML_BLOCK_SIZES[\"Q5_K\"]\n    num_blocks = len(data) // block_size\n\n    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)\n    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)\n\n    d = data_f16[:, 0].reshape(num_blocks, 1).astype(np.float32)\n    dmin = data_f16[:, 1].reshape(num_blocks, 1).astype(np.float32)\n    scales = data_u8[:, 4:16].reshape(num_blocks, 12, 1)\n    qh = data_u8[:, 16: 16 + 32].reshape(num_blocks, 32, 1)\n    qs = data_u8[:, 48: 48 + 128].reshape(num_blocks, 4, 32)\n\n    bits = np.unpackbits(qh, axis=-1, bitorder=\"little\")\n\n    qs_hi_4 = qs >> 4\n    qs_lo_4 = qs & 15\n\n    scales_lo_6 = scales[:, :8] & 63\n    scales_hi_6 = scales[:, :8] >> 6\n    scales_lo_4 = scales[:, 8:] & 15\n    scales_hi_4 = scales[:, 8:] >> 4\n\n    m1 = dmin * scales_lo_6[:, 4]\n    m2 = dmin * scales_lo_6[:, 5]\n    m3 = dmin * scales_lo_6[:, 6]\n    m4 = dmin * scales_lo_6[:, 7]\n    m5 = dmin * (scales_hi_4[:, 0] | (scales_hi_6[:, 4] << 4))\n    m6 = dmin * (scales_hi_4[:, 1] | (scales_hi_6[:, 5] << 4))\n    m7 = dmin * (scales_hi_4[:, 2] | (scales_hi_6[:, 6] << 4))\n    m8 = dmin * (scales_hi_4[:, 3] | (scales_hi_6[:, 7] << 4))\n\n    d1 = d * scales_lo_6[:, 0]\n    d2 = d * scales_lo_6[:, 1]\n    d3 = d * scales_lo_6[:, 2]\n    d4 = d * scales_lo_6[:, 3]\n    d5 = d * (scales_lo_4[:, 0] | (scales_hi_6[:, 0] << 4))\n    d6 = d * (scales_lo_4[:, 1] | (scales_hi_6[:, 1] << 4))\n    d7 = d * (scales_lo_4[:, 2] | (scales_hi_6[:, 2] << 4))\n    d8 = d * (scales_lo_4[:, 3] | (scales_hi_6[:, 3] << 4))\n\n    return np.concatenate([\n        d1 * (qs_lo_4[:, 0] + (bits[:, :, 0] << 4)) - m1,\n        d2 * (qs_hi_4[:, 0] + (bits[:, :, 1] << 
4)) - m2,\n        d3 * (qs_lo_4[:, 1] + (bits[:, :, 2] << 4)) - m3,\n        d4 * (qs_hi_4[:, 1] + (bits[:, :, 3] << 4)) - m4,\n        d5 * (qs_lo_4[:, 2] + (bits[:, :, 4] << 4)) - m5,\n        d6 * (qs_hi_4[:, 2] + (bits[:, :, 5] << 4)) - m6,\n        d7 * (qs_lo_4[:, 3] + (bits[:, :, 6] << 4)) - m7,\n        d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8,\n    ], axis=1)\n\ndef dequantize_q5_k_gpu(data, device:str =\"cuda\", target_dtype = torch.get_default_dtype()):\n    block_size = GGML_BLOCK_SIZES[\"Q5_K\"]\n    ele_per_blk = GGML_ELEMENTS_PER_BLOCK[\"Q5_K\"]\n    data = np.frombuffer(data, dtype=data.dtype)\n    device = torch.device(device)\n    # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable, \n    # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.\n    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)\n    return KTransformersOps.dequantize_q5_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)\n\ndef dequantize_q6_k(data):\n    # C implementation\n    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L2275\n    # C struct definition\n    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L152\n    block_size = GGML_BLOCK_SIZES[\"Q6_K\"]\n    num_blocks = len(data) // block_size\n\n    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)\n    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)\n    data_i8 = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, block_size)\n\n    scales = data_f16[:, -1].reshape(num_blocks, 1).astype(np.float32)\n    # TODO use uint8 and cast later?\n    ql = data_u8[:, :128].astype(np.int16)\n    qh = data_u8[:, 128:192].astype(np.int16)\n    sc = data_i8[:, 192:208, np.newaxis].astype(np.float32)\n\n    # 
Unpack bits, subtraction requires signed data type\n    q1 = (ql[:,   :32 ] & 0xF) | (((qh[:, :32] >> 0) & 3) << 4) - 32\n    q2 = (ql[:, 32:64 ] & 0xF) | (((qh[:, :32] >> 2) & 3) << 4) - 32\n    q3 = (ql[:,   :32 ] >>  4) | (((qh[:, :32] >> 4) & 3) << 4) - 32\n    q4 = (ql[:, 32:64 ] >>  4) | (((qh[:, :32] >> 6) & 3) << 4) - 32\n    q5 = (ql[:, 64:96 ] & 0xF) | (((qh[:, 32:] >> 0) & 3) << 4) - 32\n    q6 = (ql[:, 96:128] & 0xF) | (((qh[:, 32:] >> 2) & 3) << 4) - 32\n    q7 = (ql[:, 64:96 ] >>  4) | (((qh[:, 32:] >> 4) & 3) << 4) - 32\n    q8 = (ql[:, 96:128] >>  4) | (((qh[:, 32:] >> 6) & 3) << 4) - 32\n\n    # Dequantize\n    return scales * np.concatenate([\n        sc[:,  0] * q1[:, :16],\n        sc[:,  1] * q1[:, 16:],\n        sc[:,  2] * q2[:, :16],\n        sc[:,  3] * q2[:, 16:],\n        sc[:,  4] * q3[:, :16],\n        sc[:,  5] * q3[:, 16:],\n        sc[:,  6] * q4[:, :16],\n        sc[:,  7] * q4[:, 16:],\n        sc[:,  8] * q5[:, :16],\n        sc[:,  9] * q5[:, 16:],\n        sc[:, 10] * q6[:, :16],\n        sc[:, 11] * q6[:, 16:],\n        sc[:, 12] * q7[:, :16],\n        sc[:, 13] * q7[:, 16:],\n        sc[:, 14] * q8[:, :16],\n        sc[:, 15] * q8[:, 16:],\n    ], axis=1) \n\n# @torch.jit.script\ndef dequantize_q6_k_gpu(data: np.ndarray, device:str = \"cuda\", target_dtype = torch.get_default_dtype()):\n    block_size = GGML_BLOCK_SIZES[\"Q6_K\"]\n    ele_per_blk = GGML_ELEMENTS_PER_BLOCK[\"Q6_K\"]\n    device = torch.device(device)\n    num_blocks = len(data) // block_size\n    data = np.frombuffer(data, dtype=data.dtype)\n    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)\n    return KTransformersOps.dequantize_q6_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)\n\nkvalues_iq4nl = np.array([-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], dtype=np.int8)\n\ndef dequantize_iq4_xs(data):\n    # C implementation\n    # 
https://github.com/ggerganov/ggml/blob/21d3a308fcb7f31cb9beceaeebad4fb622f3c337/src/ggml-quants.c#L3568\n    # C struct definition\n    # https://github.com/ggerganov/ggml/blob/21d3a308fcb7f31cb9beceaeebad4fb622f3c337/src/ggml-common.h#L393\n    block_size = GGML_BLOCK_SIZES[\"IQ4_XS\"]\n    num_blocks = len(data) // block_size\n\n    d = np.frombuffer(data, dtype=np.float16)[0::block_size//2].astype(np.float32).reshape(num_blocks, 1)\n    scales_h = np.frombuffer(data, dtype=np.uint16)[1::block_size//2].reshape(num_blocks, 1)\n    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)[:, 4:]\n    scales_l = data_u8[:, :4].reshape(num_blocks, 4)\n    qs = data_u8[:, 4:].reshape(num_blocks, block_size - 8)\n\n    ls = np.zeros((num_blocks, QK_K // 32), dtype=np.int8)\n    for ib in range(QK_K // 32):\n        ls[:, ib] = ((scales_l[:, ib // 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h[:, 0] >> 2 * ib) & 3) << 4)\n\n    dl = (d * (ls - 32)).reshape(num_blocks, -1, 1)\n\n    qs_lo_4 = qs[:, :QK_K // 2].reshape(num_blocks, -1, 16) & 0xf\n    qs_hi_4 = qs[:, :QK_K // 2].reshape(num_blocks, -1, 16) >> 4\n\n    y = np.zeros((num_blocks, QK_K), dtype=np.float32)\n    for ib in range(QK_K // 32):\n        y[:, ib*32:(ib*32)+16] = dl[:, ib] * kvalues_iq4nl[qs_lo_4[:, ib]]\n        y[:, (ib*32)+16:(ib*32)+32] = dl[:, ib] * kvalues_iq4nl[qs_hi_4[:, ib]]\n\n    return y.flatten()\n\ndef dequantize_iq4_xs_gpu(data: np.ndarray, device:str = \"cuda\", target_dtype = torch.get_default_dtype()):\n    block_size = GGML_BLOCK_SIZES[\"IQ4_XS\"]\n    ele_per_blk = GGML_ELEMENTS_PER_BLOCK[\"IQ4_XS\"]\n    device = torch.device(device)\n    num_blocks = len(data) // block_size\n    data = np.frombuffer(data, dtype=data.dtype)\n    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)\n    return KTransformersOps.dequantize_iq4_xs(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)\n\ndef 
dequantize_q4_0(data):\n    # C implementation\n    # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-quants.c#L1515\n    # C struct definition\n    # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-common.h#L141\n    num_blocks = len(data) // GGML_BLOCK_SIZES[\"Q4_0\"]\n\n    scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 8)[:, :1].astype(np.float32)\n    qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 16)[:, 2:]\n\n    return np.concatenate([\n        scales * ((qs & 0xf).astype(np.int8) - 8),\n        scales * ((qs >> 4).astype(np.int8) - 8),\n    ], axis=1)\n\ndef dequantize_q4_0_gpu(data, device:str = \"cuda\", target_dtype = torch.get_default_dtype()):\n    raise NotImplementedError()\n\ndef dequantize_q5_0(data):\n    # C implementation\n    # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-quants.c#L1556\n    # C struct definition\n    # https://github.com/ggerganov/ggml/blob/a3c0188a4b5d3dec052ff87c9f773baa53631d70/src/ggml-common.h#L161\n    num_blocks = len(data) // GGML_BLOCK_SIZES[\"Q5_0\"]\n\n    scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 2 + 8)[:, :1].astype(np.float32)\n    qh = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2:2 + 4]\n    qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2 + 4:]\n\n    bits = np.unpackbits(qh, axis=-1, bitorder=\"little\")\n\n    x0 = ((qs & 0xf).astype(np.int8) | (bits[:, :16] << 4)) - 16\n    x1 = ((qs >> 4).astype(np.int8) | (bits[:, 16:] << 4)) - 16\n\n    return np.concatenate([\n        scales * x0,\n        scales * x1,\n    ], axis=1)\n\ndef dequantize_q5_0_gpu(data, device:str = \"cuda\", target_dtype = torch.get_default_dtype()):\n    raise NotImplementedError()\n\ndef dequantize_q8_0(data):\n    # C struct definition\n    # 
https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L43\n    num_blocks = len(data) // GGML_BLOCK_SIZES[\"Q8_0\"]\n\n    scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 16)[:, :1].astype(np.float32)\n    qs = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, 2 + 32)[:, 2:]\n    return scales * qs\n\ndef dequantize_q8_0_gpu(data, device:str = \"cuda\", target_dtype = torch.get_default_dtype()):\n    # C struct definition\n    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L43\n    \n    block_size = GGML_BLOCK_SIZES[\"Q8_0\"]\n    ele_per_blk = GGML_ELEMENTS_PER_BLOCK[\"Q8_0\"]\n    device = torch.device(device)\n    data = np.frombuffer(data, dtype=data.dtype)\n    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)\n    return KTransformersOps.dequantize_q8_0(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)\n\n\ndef dequantize_f32(data):\n    return np.frombuffer(data, dtype=np.float32)\n\ndef dequantize_f32_gpu(data, device, target_dtype = torch.get_default_dtype()):\n    data = np.frombuffer(data, dtype=np.float32)\n    res = torch.from_numpy(data.copy())\n    res_gpu = torch.empty_like(res, device=device, dtype=target_dtype)\n    res_gpu.copy_(res)\n    return res_gpu\n\ndef dequantize_f16(data):\n    return np.frombuffer(data, dtype=np.float16)\n\ndef dequantize_f16_gpu(data, device, target_dtype = torch.get_default_dtype()):\n    data = np.frombuffer(data, dtype=np.float16)\n    res = torch.from_numpy(data.copy())\n    res_gpu = torch.empty_like(res, device=device, dtype=target_dtype)\n    res_gpu.copy_(res)\n    return res_gpu\n\ndef dequantize_bf16_gpu(data, device, target_dtype = torch.get_default_dtype()):\n    data = np.frombuffer(data, dtype=np.float16)\n    res = torch.from_numpy(data.copy())\n    res_gpu = torch.empty_like(res, device=device)\n    
res_gpu.copy_(res)\n    return res_gpu\n\nGGML_DEQUANTIZE = {\n    \"F32\": dequantize_f32,\n    \"F16\": dequantize_f16,\n    \"BF16\": dequantize_f16,\n    \"Q4_0\": dequantize_q4_0,\n    \"Q5_0\": dequantize_q5_0,\n    \"Q8_0\": dequantize_q8_0,\n    \"Q2_K\": dequantize_q2_k,\n    \"Q3_K\": dequantize_q3_k,\n    \"Q4_K\": dequantize_q4_k,\n    \"Q5_K\": dequantize_q5_k,\n    \"Q6_K\": dequantize_q6_k,\n    \"IQ4_XS\": dequantize_iq4_xs,\n}\n\nGGML_DEQUANTIZE_GPU = {\n    \"F32\": dequantize_f32_gpu,\n    \"F16\": dequantize_f16_gpu,\n    \"BF16\": dequantize_bf16_gpu,\n    \"Q4_0\": dequantize_q4_0_gpu,\n    \"Q5_0\": dequantize_q5_0_gpu,\n    \"Q8_0\": dequantize_q8_0_gpu,\n    \"Q2_K\": dequantize_q2_k_gpu,\n    \"Q3_K\": dequantize_q3_k_gpu,\n    \"Q4_K\": dequantize_q4_k_gpu,\n    \"Q5_K\": dequantize_q5_k_gpu,\n    \"Q6_K\": dequantize_q6_k_gpu,\n    \"IQ4_XS\": dequantize_iq4_xs_gpu,\n}\n\n\ndef translate_name_to_gguf_mixtral(name):\n    \n    replacement_template = {\n        \"w1.weight\": \"ffn_gate\",\n        \"w2.weight\": \"ffn_down\",\n        \"w3.weight\": \"ffn_up\"\n    }  \n\n    pattern = re.compile(r\"model.layers\\.(\\d+)\\.block_sparse_moe\\.experts\\.(\\d+)\\.(w\\d\\.weight)\")\n\n    def replace_match(match):\n        blk_id = match.group(1)\n        expert_id = match.group(2)\n        weight_type = match.group(3)\n        if weight_type in replacement_template:\n            return f\"blk.{blk_id}.{replacement_template[weight_type]}.{expert_id}.weight\"\n        else:\n            return match.group(0)\n\n    new_name = re.sub(pattern, replace_match, name)\n    \n    return new_name\n\ndef translate_name_to_gguf(name):\n\n    name = translate_name_to_gguf_mixtral(name)\n\n    name = name.replace(\"lm_head.\", \"output.\")\n    name = name.replace(\"model.embed_tokens.\", \"token_embd.\")\n    name = name.replace(\"model.norm.\", \"output_norm.\")\n    \n    name = name.replace(\"model.layers.\", \"blk.\")\n    name = 
name.replace(\".input_layernorm\", \".attn_norm\")\n    name = name.replace(\".mlp.down_proj\", \".ffn_down\")\n    name = name.replace(\".mlp.gate_proj\", \".ffn_gate\")\n    name = name.replace(\".mlp.up_proj\", \".ffn_up\")\n    name = name.replace(\".post_attention_layernorm\", \".ffn_norm\")\n    name = name.replace(\".self_attn.q_proj\", \".attn_q\")\n    name = name.replace(\".self_attn.k_proj\", \".attn_k\")\n    name = name.replace(\".self_attn.v_proj\", \".attn_v\")\n    name = name.replace(\".self_attn.o_proj\", \".attn_output\")\n    name = name.replace(\".self_attn.qkv_proj\", \".attn_qkv\")\n    name = name.replace(\".self_attn.kv_a_proj_with_mqa\", \".attn_kv_a_mqa\")\n    name = name.replace(\".self_attn.kv_a_layernorm\", \".attn_kv_a_norm\")\n    name = name.replace(\".self_attn.kv_b_proj\", \".attn_kv_b\")\n    name = name.replace(\".self_attn.q_a_proj\", \".attn_q_a\")\n    name = name.replace(\".self_attn.q_a_layernorm\", \".attn_q_a_norm\")\n    name = name.replace(\".self_attn.q_b_proj\", \".attn_q_b\")\n    \n    name = name.replace(\".shared_expert.\", \".shared_experts.\")\n    name = name.replace(\".shared_expert_\", \".shared_experts_\")\n    name = name.replace(\".gate_up_proj.\", \".up_proj\")\n    \n    name = name.replace(\".mlp.shared_experts.down_proj\", \".ffn_down_shexp\")\n    name = name.replace(\".mlp.gate\", \".ffn_gate_inp\")\n    name = name.replace(\".mlp.shared_experts.gate_proj\", \".ffn_gate_shexp\")\n    name = name.replace(\".mlp.shared_experts.up_proj\", \".ffn_up_shexp\")\n    name = name.replace(\".mlp.shared_experts_gate\", \".ffn_gate_inp_shexp\")\n    name = name.replace(\".mlp.experts\", \"\")\n    name = name.replace(\".mlp.experts.ffn_down_exps\", \".ffn_down_exps\")\n    name = name.replace(\".mlp.experts.ffn_gate_exps\", \".ffn_gate_exps\")\n    name = name.replace(\".mlp.experts.ffn_up_exps\", \".ffn_up_exps\")\n\n    \n    name = name.replace(\".block_sparse_moe.gate.\", \".ffn_gate_inp.\")\n    name = 
name.replace(\".block_sparse_moe.experts\", \"\")\n    \n    return name\n\ndef translate_adapter_name_to_gguf(name):\n\n    # name = translate_name_to_gguf_mixtral(name)\n\n    name = name.replace(\"lora_A.default.weight\", \"lora_A.weight\")\n    name = name.replace(\"lora_B.default.weight\", \"lora_B.weight\")\n    # Do NOT fine-tune the embedding model\n    # name = name.replace(\"base_model.model\", \"token_embd.\")\n    # name = name.replace(\"model.norm.\", \"output_norm.\")\n    \n    name = name.replace(\"blk.\", \"model.layers.\")\n    # name = name.replace(\".input_layernorm\", \".attn_norm\")\n    # name = name.replace(\".mlp.down_proj\", \".ffn_down\")\n    # name = name.replace(\".mlp.gate_proj\", \".ffn_gate\")\n    # name = name.replace(\".mlp.up_proj\", \".ffn_up\")\n    # name = name.replace(\".post_attention_layernorm\", \".ffn_norm\")\n    # name = name.replace(\".self_attn.q_proj\", \".attn_q\")\n    # name = name.replace(\".self_attn.k_proj\", \".attn_k\")\n    # name = name.replace(\".self_attn.v_proj\", \".attn_v\")\n    # name = name.replace(\".self_attn.o_proj\", \".attn_output\")\n    # name = name.replace(\".self_attn.qkv_proj\", \".attn_qkv\")\n    # name = name.replace(\".self_attn.kv_a_proj_with_mqa\", \".attn_kv_a_mqa\")\n    # name = name.replace(\".self_attn.kv_a_layernorm\", \".attn_kv_a_norm\")\n    # name = name.replace(\".self_attn.kv_b_proj\", \".attn_kv_b\")\n    # name = name.replace(\".self_attn.q_a_proj\", \".attn_q_a\")\n    # name = name.replace(\".self_attn.q_a_layernorm\", \".attn_q_a_norm\")\n    # name = name.replace(\".self_attn.q_b_proj\", \".attn_q_b\")\n    \n    # name = name.replace(\".shared_expert.\", \".shared_experts.\")\n    # name = name.replace(\".shared_expert_\", \".shared_experts_\")\n    # name = name.replace(\".gate_up_proj.\", \".up_proj\")\n    \n    # name = name.replace(\".mlp.shared_experts.down_proj\", \".ffn_down_shexp\")\n    # name = name.replace(\".mlp.gate\", \".ffn_gate_inp\")\n    # name = 
name.replace(\".mlp.shared_experts.gate_proj\", \".ffn_gate_shexp\")\n    # name = name.replace(\".mlp.shared_experts.up_proj\", \".ffn_up_shexp\")\n    # name = name.replace(\".mlp.shared_experts_gate\", \".ffn_gate_inp_shexp\")\n    # name = name.replace(\".mlp.experts\", \"\")\n    # name = name.replace(\".mlp.experts.ffn_down_exps\", \".ffn_down_exps\")\n    # name = name.replace(\".mlp.experts.ffn_gate_exps\", \".ffn_gate_exps\")\n    # name = name.replace(\".mlp.experts.ffn_up_exps\", \".ffn_up_exps\")\n\n    \n    # name = name.replace(\".block_sparse_moe.gate.\", \".ffn_gate_inp.\")\n    # name = name.replace(\".block_sparse_moe.experts\", \"\")\n    \n    return name\n\n\nif __name__ == '__main__':\n    gguf_path = '/mnt/data/model/DeepSeek-Coder-V2-GGUF-WJH'\n    loader = GGUFLoader(gguf_path)\n    loader.load_gguf_tensor('token_embd.weight')\n"
  },
  {
    "path": "kt-sft/ktransformers/util/custom_loader.py",
    "content": "import struct\nimport warnings\nimport numpy as np\nimport re\nimport numpy.typing as npt\nfrom typing import Sequence\nimport os\nfrom enum import IntEnum\nimport torch\nif not torch.xpu.is_available():\n    import KTransformersOps\nfrom safetensors import safe_open\nfrom ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant\nfrom ktransformers.util.custom_gguf import *\nfrom safetensors.torch import save_file\nfrom abc import ABC, abstractmethod\nfrom typing import Dict, Any, Optional, Union\n\nclass ModelLoader(ABC):\n    \"\"\"\n    Abstract base class for model loaders.\n    Defines the interface that all model loaders must implement.\n    \"\"\"\n    tensor_file_map = {}\n    @abstractmethod\n    def has_tensor(cls, name: str):\n        \"\"\"\n        Check if the tensor exists in the loader.\n        \n        Args:\n            name: Name of the tensor to check\n            \n        Returns:\n            bool: True if the tensor exists, False otherwise\n        \"\"\"\n        pass\n\nclass SafeTensorLoader(ModelLoader):\n    tensor_file_map: dict\n    tensor_type_map: dict\n    file_handle_map: dict\n    tensor_device_map: dict\n    \n    def __init__(self, file_path: str):\n        self.__load_tensor_file_map(file_path)\n\n    def __load_tensor_file_map(self, file_path: str):\n        if not os.path.exists(file_path):\n            raise FileNotFoundError(f\"Path not found: {file_path}\")\n        if os.path.isfile(file_path):\n            folder_path = os.path.dirname(file_path)\n        else:\n            folder_path = file_path\n        self.file_handle_map = {}\n        self.tensor_file_map = {}\n        self.tensor_type_map = {}\n        self.tensor_device_map = {}\n\n        found_safetensor = False\n        for root, _, files in os.walk(folder_path):\n            files = sorted(files)\n            for file in files:\n                if file.endswith(\".safetensors\"):\n                    
found_safetensor = True\n                    file_path = os.path.join(root, file)\n                    if file not in self.file_handle_map:\n                        try:\n                            handle = safe_open(file_path, framework=\"pt\")\n                            self.file_handle_map[file] = handle\n                        except Exception as e:\n                            print(f\"Error opening Safetensor file {file_path}: {e}\")\n                            continue\n\n                    f = self.file_handle_map.get(file)\n                    if f is None:\n                        continue\n                    try:\n                        for key in f.keys():\n                            self.tensor_file_map[key] = file\n                    except Exception as e:\n                        print(f\"Error reading Safetensor file {file_path}: {e}\")\n\n        # if not found_safetensor:\n        #     raise FileNotFoundError(f\"No Safetensor files found in {folder_path}\")\n\n    def load_tensor(self, key: str, device: str=\"cpu\"):\n        if translate_name_to_gguf(key) in self.tensor_file_map:\n            key = translate_name_to_gguf(key)\n        elif key in self.tensor_file_map:\n            pass\n        else:\n            raise KeyError(f\"Key {key} not found in Safetensor files\")\n        file = self.tensor_file_map[key]\n        f = self.file_handle_map.get(file)\n        if f is None:\n            raise FileNotFoundError(f\"File {file} not found in Safetensor files\")\n        tensor = f.get_tensor(key)\n        return tensor.to(device)\n\n    def load_experts(self, key: str, device: str=\"cpu\"):\n        '''\n        Load experts from safetensor\n        key: the name of the experts\n        device: the device to load the experts to\n        return: dict, \n        {up: tensor, down: tensor, gate: tensor, up_type: int, down_type: int, gate_type: int}\n        {xxx}_type: the type of the up tensor, corresponding to the ggml type\n        
'''\n        if self.has_tensor(translate_name_to_gguf(key)+\".ffn_gate_exps.weight\"):\n            # legacy branch for loading hybrid model\n            base_key = translate_name_to_gguf(key)\n            # Load experts from safetensor\n            gate_key = f\"{base_key}.ffn_gate_exps.weight\"\n            gate_type_key = f\"{base_key}.ffn_gate_exps.ggml_type\"\n            up_key = f\"{base_key}.ffn_up_exps.weight\"\n            up_type_key = f\"{base_key}.ffn_up_exps.ggml_type\"\n            down_key = f\"{base_key}.ffn_down_exps.weight\"\n            down_type_key = f\"{base_key}.ffn_down_exps.ggml_type\"\n            gate_tensor = self.load_tensor(gate_key, device).numpy()\n            up_tensor = self.load_tensor(up_key, device).numpy()\n            down_tensor = self.load_tensor(down_key, device).numpy()\n            gate_type = self.load_tensor(gate_type_key, device).item()\n            up_type = self.load_tensor(up_type_key, device).item()\n            down_type = self.load_tensor(down_type_key, device).item()\n\n            return {\n                \"up\": up_tensor,\n                \"gate\": gate_tensor,\n                \"down\": down_tensor,\n                \"up_type\": up_type,\n                \"gate_type\": gate_type,\n                \"down_type\": down_type\n            }\n\n        else:\n            # Load experts from safetensor\n            base_key = key  # e.g. 
\"model.layers.3.mlp.experts\"\n            experts_count = 0\n            \n            # First, count how many experts we have by checking for expert 0's up_proj\n            while self.has_tensor(f\"{base_key}.{experts_count}.up_proj.weight\"):\n                experts_count += 1\n            \n            if experts_count == 0:\n                raise ValueError(f\"No experts found for key {base_key}\")\n            \n            # Initialize empty lists to store tensors for each projection type\n            up_projs = []\n            gate_projs = []\n            down_projs = []\n            \n            # Load all expert weights\n            for expert_id in range(experts_count):\n                up_key = f\"{base_key}.{expert_id}.up_proj.weight\"\n                gate_key = f\"{base_key}.{expert_id}.gate_proj.weight\"\n                down_key = f\"{base_key}.{expert_id}.down_proj.weight\"\n                \n                up_tensor = self.load_tensor(up_key, device)\n                gate_tensor = self.load_tensor(gate_key, device)\n                down_tensor = self.load_tensor(down_key, device)\n                \n                up_projs.append(up_tensor)\n                gate_projs.append(gate_tensor)\n                down_projs.append(down_tensor)\n            \n            # Stack the tensors along a new dimension\n            up_tensor = torch.stack(up_projs, dim=0)\n            gate_tensor = torch.stack(gate_projs, dim=0)\n            down_tensor = torch.stack(down_projs, dim=0)\n            \n            # Get original dtype for GGML type determination\n            orig_up_dtype = up_tensor.dtype\n            orig_gate_dtype = gate_tensor.dtype\n            orig_down_dtype = down_tensor.dtype\n            \n            # Convert to numpy with proper bfloat16 support\n            up_numpy = up_tensor.view(torch.uint16).numpy()\n            gate_numpy = gate_tensor.view(torch.uint16).numpy()\n            down_numpy = 
down_tensor.view(torch.uint16).numpy()\n            \n            # Determine tensor data types for GGML conversion\n            def get_ggml_type(dtype):\n                if dtype == torch.float32:\n                    return GGMLQuantizationType.F32\n                elif dtype == torch.float16:\n                    return GGMLQuantizationType.F16\n                elif dtype == torch.bfloat16:\n                    return GGMLQuantizationType.BF16\n                else:\n                    raise ValueError(f\"Unsupported tensor dtype: {dtype}\")\n            \n            return {\n                \"up\": up_numpy,\n                \"gate\": gate_numpy,\n                \"down\": down_numpy,\n                \"up_type\": get_ggml_type(orig_up_dtype),\n                \"gate_type\": get_ggml_type(orig_gate_dtype),\n                \"down_type\": get_ggml_type(orig_down_dtype)\n            }\n                \n    def load_gate(self, key: str, device: str=\"cpu\"):\n        '''\n        Load gate from safetensor\n        key: the name of the gate\n        device: the device to load the gate to\n        return: dict, \n        {'weight': tensor, 'e_score_correction_bias': tensor}\n        '''\n        target = [\"weight\", \"e_score_correction_bias\"]\n        res = {'weight': None, 'e_score_correction_bias': None}\n        if self.has_tensor(translate_name_to_gguf(key)+\".ffn_gate_exps.weight\"):\n            # legacy branch for loading hybrid model\n            base_key = key\n            for k in target:\n                translated_key = translate_name_to_gguf(f\"{base_key}.{k}\")\n                if self.has_tensor(translated_key):\n                    tensor = self.load_tensor(translated_key, device)\n                    res[k] = tensor\n        else:\n            # Load gate from safetensor\n            base_key = key\n            for k in target:\n                if self.has_tensor(f\"{base_key}.{k}\"):\n                    tensor = 
self.load_tensor(f\"{base_key}.{k}\", device)\n                    res[k] = tensor\n        return res\n    \n    def close_all_handles(self):\n        for handle in self.file_handle_map.values():\n            handle.close()\n        self.file_handle_map.clear()\n\n    def load_dequantized_tensor(self, key:str, device: str=\"cpu\"):\n        if key in self.tensor_file_map and translate_name_to_gguf(key):\n            pass\n        elif translate_name_to_gguf(key) in self.tensor_file_map:\n            key = translate_name_to_gguf(key)\n        else:\n            raise KeyError(f\"Key {key} not found in Safetensor files\")\n        file = self.tensor_file_map[key]\n        f = self.file_handle_map.get(file)\n        if f is None:\n            raise FileNotFoundError(f\"File {file} not found in Safetensor files\")\n        tensor = f.get_tensor(key).to(device)\n        if key.endswith(\".weight\"):\n            if key[:-7] + \".weight_scale_inv\" in self.tensor_file_map:\n                weight_scale_inv = f.get_tensor(key[:-7] + \".weight_scale_inv\").to(device)\n                tensor = weight_dequant(tensor, weight_scale_inv)\n        return tensor.to(device)\n    \n    def has_tensor(self, name: str):\n        return name in self.tensor_file_map or translate_name_to_gguf(name) in self.tensor_file_map\n\nclass GGUFLoader(ModelLoader):\n    tensor_info: dict\n    gguf_path: str\n    tensor_file_map: dict # {tensor_name: tensor_file_path}\n    gguf_file_meta: dict\n    safetensor_loader: SafeTensorLoader\n    def __init__(self, gguf_path: str):\n        # Check dir exist\n        if not os.path.exists(gguf_path):\n            raise FileNotFoundError(f\"GGUF dir not found: {gguf_path}\")\n        if os.path.isfile(gguf_path):\n            gguf_path = os.path.dirname(gguf_path)\n\n        self.safetensor_loader = None\n        \n        self.tensor_info = {}\n        self.gguf_path = gguf_path\n        self.tensor_file_map = {}\n        self.file_data_map = {}\n        
self.gguf_file_meta = {}\n        self.tensor_device_map = {}\n\n\t\t# I know this is ugly, but I don't want to change the original code too much\n        # TODO: merge gguf load and other loads.\n        safetensor_loader = SafeTensorLoader(gguf_path)\n        if safetensor_loader.tensor_file_map:\n            self.safetensor_loader = safetensor_loader\n            return\n        # Walk through all the .gguf files in the directory\n        found_gguf = False\n        for root, dirs, files in os.walk(gguf_path):\n            for file in files:\n                if file.endswith(\".gguf\"):\n                    found_gguf = True\n                    file_name = os.path.join(root, file)\n                    with open(file_name, \"rb\") as f:\n                        self.load_gguf(f)\n                        if file_name not in self.file_data_map:\n                            self.file_data_map[file_name] = np.memmap(file_name, mode = 'r')\n        if not found_gguf:\n            raise FileNotFoundError(f\"Cannot find any .gguf files in: {gguf_path}\")\n                            \n    def load_gguf(self, f):\n        f.seek(0)\n        assert f.read(4) == b'GGUF'\n        values = struct.unpack(\"<IQQ\", f.read(4+8+8))\n        version, n_tensors, n_kv = values\n        if version != 3:\n            warnings.warn(f\"Version {version} has never been tested, might not work\")\n\n        info = {}\n        for _ in range(n_kv):\n            name = read_value(f, DATA_TYPES[\"string\"])\n\n            data_type = struct.unpack(\"<I\", f.read(4))[0]\n\n            info[name] = read_value(f, data_type)\n\n        tensor_info = {}\n        for _ in range(n_tensors):\n            name = read_value(f, DATA_TYPES[\"string\"])\n            shape_len = read_value(f, DATA_TYPES[\"uint32\"])\n            shape = [read_value(f, DATA_TYPES[\"uint64\"]) for _ in range(shape_len)]\n            ggml_type = read_value(f, DATA_TYPES[\"uint32\"])\n            bad_offset = read_value(f, 
DATA_TYPES[\"uint64\"])\n            n_elems = int(math.prod(shape))\n            block_size, type_size = GGML_QUANT_SIZES[ggml_type]\n            n_bytes = n_elems * type_size // block_size\n            np_dims = tuple(reversed(shape))\n        \n            item_type: npt.DTypeLike\n            if ggml_type == GGMLQuantizationType.F16:\n                item_count = n_elems\n                item_type = np.float16\n            elif ggml_type == GGMLQuantizationType.F32:\n                item_count = n_elems\n                item_type = np.float32\n            elif ggml_type == GGMLQuantizationType.F64:\n                item_count = n_elems\n                item_type = np.float64\n            elif ggml_type == GGMLQuantizationType.I8:\n                item_count = n_elems\n                item_type = np.int8\n            elif ggml_type == GGMLQuantizationType.I16:\n                item_count = n_elems\n                item_type = np.int16\n            elif ggml_type == GGMLQuantizationType.I32:\n                item_count = n_elems\n                item_type = np.int32\n            elif ggml_type == GGMLQuantizationType.I64:\n                item_count = n_elems\n                item_type = np.int64\n            else:\n                item_count = n_bytes\n                item_type = np.uint8\n                np_dims = quant_shape_to_byte_shape(np_dims, ggml_type)\n\n            tensor_info[name] = {\n                \"ggml_type\": ggml_type,\n                \"shape\": shape,\n                \"bad_offset\": bad_offset,\n                \"item_type\": item_type,\n                \"item_count\": item_count,\n                \"np_dims\": np_dims\n            }\n\n        start = f.tell()\n        # Alignment is 32 by default.\n        # https://github.com/ggerganov/ggml/blob/e1daebbf9d38d510ba456c4d50b4500a73ac2b14/docs/gguf.md?plain=1#L253\n        alignment = info.get(\"general.alignment\", 32)\n\n        # Inconveniently, the offset defined in gguf files is 
relative to the\n        # end of the header and is unaligned.\n        # We need to compute the absolute file offset ourselves instead.\n        for t in tensor_info.values():\n            offset = start + t[\"bad_offset\"]\n            offset += (alignment - offset % alignment) % alignment\n            t[\"offset\"] = offset\n            \n        for name in tensor_info:\n            self.tensor_file_map[name] = f.name\n        self.tensor_info.update(tensor_info)\n        self.gguf_file_meta.update(info)\n    \n    def get_mmap_tensor(self, name):\n        name = translate_name_to_gguf(name)\n        t = self.tensor_info[name]\n        mmap_data = self.file_data_map[ self.tensor_file_map[name] ]\n\n        offset = t[\"offset\"]\n        item_type = t[\"item_type\"]\n        item_count = t[\"item_count\"]\n        itemsize = int(np.empty([], dtype = item_type).itemsize)\n        return mmap_data[offset : offset + itemsize * item_count]\n\n    def get_undequanted_tensor_and_ggml_type(self, name):\n        name = translate_name_to_gguf(name)\n        t = self.tensor_info[name]\n        data = self.get_mmap_tensor(name)\n        ggml_type = t[\"ggml_type\"]\n        data = torch.from_numpy(data)\n        return data, ggml_type\n\n    def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = \"cuda\", target_dtype = torch.get_default_dtype())->torch.Tensor:\n        name = translate_name_to_gguf(name)\n        t = self.tensor_info[name]\n        shape = t[\"shape\"]\n        ggml_type = t[\"ggml_type\"]\n        if ggml_type not in GGML_NAMES:\n            raise NotImplementedError(f\"ggml_type {ggml_type} not implemented\")\n        ggml_name = GGML_NAMES[ggml_type]\n\n        # TODO: experts may fused in quant block, split it\n        assert elements_per_expert % GGML_ELEMENTS_PER_BLOCK[ggml_name] == 0, \"experts may fused in quant block, please use CPU dequant\"\n\n        blocks_per_experts = elements_per_expert // 
GGML_ELEMENTS_PER_BLOCK[ggml_name]\n        block_size = GGML_BLOCK_SIZES[ggml_name]\n        offset = expert_id * block_size * blocks_per_experts\n        data = data[offset: offset + block_size * blocks_per_experts]\n\n        if \"cuda\" in device.lower():\n            values = GGML_DEQUANTIZE_GPU[ggml_name](data, device, target_dtype)\n        else:\n            values = GGML_DEQUANTIZE[ggml_name](data)\n            values = torch.from_numpy(values.copy())\n\n        if ggml_name == \"BF16\":\n            values = values.view(torch.bfloat16)\n        values = values.view(shape[-2::-1])\n\n        return values\n\n    def load_gguf_tensor(self, name: str, device:str = \"cpu\", target_dtype = None)->torch.Tensor:\n        name = translate_name_to_gguf(name)\n        t = self.tensor_info[name]\n        if target_dtype is None:\n            target_dtype = torch.get_default_dtype()\n        \n        shape = t[\"shape\"]\n        ggml_type = t[\"ggml_type\"]\n\n        if ggml_type not in GGML_NAMES:\n            raise NotImplementedError(f\"ggml_type {ggml_type} not implemented\")\n\n        ggml_name = GGML_NAMES[ggml_type]\n\n        data = self.get_mmap_tensor(name)\n\n        block_size = GGML_BLOCK_SIZES[ggml_name]\n        elements_per_block = GGML_ELEMENTS_PER_BLOCK[ggml_name]\n        num_elements = int(np.prod(shape))\n        num_blocks = num_elements // elements_per_block\n        \n        blocks_per_iter = 16384\n        if num_blocks > blocks_per_iter: # dequant large tensor\n            values = torch.empty((num_blocks, elements_per_block), dtype=target_dtype, device=device)\n            for i in range( (num_blocks + blocks_per_iter - 1) // blocks_per_iter):\n                blocks_begin = i * blocks_per_iter\n                blocks_end = min(blocks_begin + blocks_per_iter, num_blocks)\n                if \"cuda\" in device.lower():\n                    try:\n                        cur_values = 
GGML_DEQUANTIZE_GPU[ggml_name](data[blocks_begin*block_size : blocks_end*block_size], device, target_dtype)\n                    except Exception:\n                        # Fall back to CPU dequantization if the GPU kernel fails\n                        cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size])\n                        cur_values = torch.from_numpy(cur_values.copy()).to(device)\n                else:\n                    cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size])\n                    cur_values = torch.from_numpy(cur_values.copy())\n                \n                cur_values = cur_values.view(-1, elements_per_block)\n                if ggml_name == \"BF16\":\n                    cur_values = cur_values.view(torch.bfloat16)\n                values[blocks_begin : blocks_end] = cur_values\n        else:\n            if \"cuda\" in device.lower():\n                values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)\n            else:\n                np_values = np.copy(GGML_DEQUANTIZE[ggml_name](data))\n                values = torch.from_numpy(np_values).to(device)\n                del np_values\n\n        if ggml_name == \"BF16\":\n            values = values.view(torch.bfloat16)\n            \n\n        values = values.view(shape[::-1])\n        if \"attn_q\" in name and self.gguf_file_meta['general.architecture'] in [\"llama\"]:\n            n_head = self.gguf_file_meta['llama.attention.head_count']\n            values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, *values.shape[1:])\n            .swapaxes(1, 2)\n            .reshape(values.shape))\n        elif \"attn_k\" in name and self.gguf_file_meta['general.architecture'] in [\"llama\"]:\n            n_head = self.gguf_file_meta['llama.attention.head_count_kv'] \n            values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, *values.shape[1:])\n            .swapaxes(1, 2)\n            .reshape(values.shape))\n        return values\n    def has_tensor(self, 
name: str):\n        name = translate_name_to_gguf(name)\n        return name in self.tensor_info\n\n    def get_ggml_type(self, name: str):\n        name = translate_name_to_gguf(name)\n        if name not in self.tensor_info:\n            raise KeyError(f\"Key {name} not found in GGUF files\")\n        return self.tensor_info[name][\"ggml_type\"]\n    \nclass ModelLoaderFactory:\n    \"\"\"\n    Factory class for creating model loaders.\n    Automatically detects the model format based on file extensions in the directory.\n    \"\"\"\n    \n    @staticmethod\n    def create_loader(path: str):\n        \"\"\"\n        Create a model loader for the given path by detecting the model format.\n        The function checks for the presence of .safetensors or .gguf files\n        in the specified path and creates the appropriate loader.\n        \n        Args:\n            path: Path to the model directory or file\n            \n        Returns:\n            An appropriate ModelLoader instance (SafeTensorLoader or GGUFLoader)\n        \n        Raises:\n            FileNotFoundError: If no supported model files are found in the path\n        \"\"\"\n        if not os.path.exists(path):\n            raise FileNotFoundError(f\"Path not found: {path}\")\n            \n        # Normalize to directory path if a file was provided\n        if os.path.isfile(path):\n            if path.endswith(\".safetensors\"):\n                return SafeTensorLoader(path)\n            elif path.endswith(\".gguf\"):\n                return GGUFLoader(path)\n            else:\n                folder_path = os.path.dirname(path)\n        else:\n            folder_path = path\n            \n        # Check for safetensors files\n        has_safetensors = False\n        has_gguf = False\n        \n        for root, _, files in os.walk(folder_path):\n            for file in files:\n                if file.endswith(\".safetensors\"):\n                    has_safetensors = True\n                   
 break\n                elif file.endswith(\".gguf\"):\n                    has_gguf = True\n                    break\n            if has_safetensors or has_gguf:\n                break\n                \n        # Create the appropriate loader based on detected file types\n        # Prioritize SafeTensor over GGUF if both are present\n        if has_safetensors:\n            try:\n                return SafeTensorLoader(folder_path)\n            except Exception as e:\n                print(f\"Failed to create SafeTensorLoader: {e}\")\n                # Fall through to try GGUF if SafeTensor fails\n                if not has_gguf:\n                    raise\n        \n        if has_gguf:\n            try:\n                return GGUFLoader(folder_path)\n            except Exception as e:\n                print(f\"Failed to create GGUFLoader: {e}\")\n                raise\n        \n        # No supported model files found\n        raise FileNotFoundError(f\"No .safetensors or .gguf files found in: {folder_path}\")"
  },
  {
    "path": "kt-sft/ktransformers/util/globals.py",
    "content": "import os\n\nclass _GlobalConfig:\n    def __init__(self):\n        self._config = {\n            \"mod\": 'infer', # infer or sft\n        }\n\n    def get(self, key, default=None):\n        return self._config.get(key, default)\n\n    def set(self, key, value):\n        self._config[key] = value\n\n    def update(self, **kwargs):\n        self._config.update(kwargs)\n\n    def all(self):\n        return self._config\n\n    def __getitem__(self, key):\n        return self._config[key]\n\n    def __setitem__(self, key, value):\n        self._config[key] = value\n\nGLOBAL_CONFIG = _GlobalConfig()\n"
  },
  {
    "path": "kt-sft/ktransformers/util/grad_wrapper.py",
    "content": "from functools import wraps\nimport torch, yaml, pathlib\n\nimport os, sys\nproject_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))\nsys.path.insert(0, project_dir)\n\nfrom ktransformers.util.globals import GLOBAL_CONFIG\n\n# print(f\"start_sit: {GLOBAL_CONFIG._config['mod']}\")\n\ndef maybe_no_grad(_func=None):\n    \"\"\"Run the wrapped function under torch.no_grad() in 'infer' mode and with gradients enabled in 'sft' mode.\"\"\"\n    # print(f\"maybe_sit: {GLOBAL_CONFIG._config['mod']}\")\n    \n    def decorator(func):\n        # print(f\"decorate_sit: {GLOBAL_CONFIG._config['mod']}\")\n        @wraps(func)  # preserve the wrapped function's name and docstring\n        def wrapper(*args, **kwargs):\n            # print(f\"wrap_sit: {GLOBAL_CONFIG._config['mod']}\")\n            mode = GLOBAL_CONFIG[\"mod\"]\n            if mode == \"sft\":\n                return func(*args, **kwargs)\n            elif mode == \"infer\":\n                with torch.no_grad():\n                    return func(*args, **kwargs)\n            else:\n                # Previously an unknown mode silently returned None\n                raise ValueError(f\"Unknown mode: {mode!r}; expected 'sft' or 'infer'\")\n        return wrapper\n\n    if _func is None:\n        return decorator\n    else:\n        return decorator(_func)\n"
  },
  {
    "path": "kt-sft/ktransformers/util/inference_state.py",
    "content": "\nimport enum\n\n\nclass InferenceState(enum.Enum):\n    UNLOAD = 0\n    PREFILL = 1\n    GENERATE = 2\n    RESTORE = 3\n"
  },
  {
    "path": "kt-sft/ktransformers/util/modeling_rope_utils.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport math\nfrom typing import Optional, Tuple\n\nfrom transformers.configuration_utils import PretrainedConfig\nfrom transformers.utils import is_torch_available, logging\n\n\nlogger = logging.get_logger(__name__)\n\n\nif is_torch_available():\n    import torch\n\n\ndef _compute_default_rope_parameters(\n    config: Optional[PretrainedConfig] = None,\n    device: Optional[\"torch.device\"] = None,\n    seq_len: Optional[int] = None,\n    **rope_kwargs,\n) -> Tuple[\"torch.Tensor\", float]:\n    \"\"\"\n    Computes the inverse frequencies according to the original RoPE implementation\n    Args:\n        config ([`~transformers.PretrainedConfig`]):\n            The model configuration.\n        device (`torch.device`):\n            The device to use for initialization of the inverse frequencies.\n        seq_len (`int`, *optional*):\n            The current sequence length. 
Unused for this type of RoPE.\n        rope_kwargs (`Dict`, *optional*):\n            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.\n    Returns:\n        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the\n        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).\n    \"\"\"\n    if config is not None and len(rope_kwargs) > 0:\n        raise ValueError(\n            \"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in \"\n            f\"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}\"\n        )\n    if len(rope_kwargs) > 0:\n        base = rope_kwargs[\"base\"]\n        dim = rope_kwargs[\"dim\"]\n    elif config is not None:\n        base = config.rope_theta\n        partial_rotary_factor = config.partial_rotary_factor if hasattr(config, \"partial_rotary_factor\") else 1.0\n        head_dim = getattr(config, \"head_dim\", config.hidden_size // config.num_attention_heads)\n        dim = int(head_dim * partial_rotary_factor)\n\n    attention_factor = 1.0  # Unused in this type of RoPE\n\n    # Compute the inverse frequencies\n    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))\n    return inv_freq, attention_factor\n\n\ndef _compute_linear_scaling_rope_parameters(\n    config: Optional[PretrainedConfig] = None,\n    device: Optional[\"torch.device\"] = None,\n    seq_len: Optional[int] = None,\n    **rope_kwargs,\n) -> Tuple[\"torch.Tensor\", float]:\n    \"\"\"\n    Computes the inverse frequencies with linear scaling. 
Credits to the Reddit user /u/kaiokendev\n    Args:\n        config ([`~transformers.PretrainedConfig`]):\n            The model configuration.\n        device (`torch.device`):\n            The device to use for initialization of the inverse frequencies.\n        seq_len (`int`, *optional*):\n            The current sequence length. Unused for this type of RoPE.\n        rope_kwargs (`Dict`, *optional*):\n            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.\n    Returns:\n        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the\n        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).\n    \"\"\"\n    if config is not None and len(rope_kwargs) > 0:\n        raise ValueError(\n            \"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in \"\n            f\"`_compute_linear_scaling_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}\"\n        )\n    if len(rope_kwargs) > 0:\n        factor = rope_kwargs[\"factor\"]\n    elif config is not None:\n        factor = config.rope_scaling[\"factor\"]\n\n    # Gets the default RoPE parameters\n    inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)\n\n    # Then applies linear scaling to the frequencies.\n    # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so\n    # applying scaling to the inverse frequencies is equivalent.\n    inv_freq /= factor\n    return inv_freq, attention_factor\n\n\ndef _compute_dynamic_ntk_parameters(\n    config: Optional[PretrainedConfig] = None,\n    device: Optional[\"torch.device\"] = None,\n    seq_len: Optional[int] = None,\n    **rope_kwargs,\n) -> Tuple[\"torch.Tensor\", float]:\n    \"\"\"\n    Computes the inverse frequencies with NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla\n    Args:\n        config ([`~transformers.PretrainedConfig`]):\n            The model configuration.\n        device (`torch.device`):\n            The device to use for initialization of the inverse frequencies.\n        seq_len (`int`, *optional*):\n            The current sequence length, used to update the dynamic RoPE at inference time.\n        rope_kwargs (`Dict`, *optional*):\n            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.\n    Returns:\n        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the\n        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).\n    \"\"\"\n    # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling\n    if config is not None and len(rope_kwargs) > 0:\n        raise ValueError(\n            \"Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in \"\n            f\"`_compute_dynamic_ntk_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}\"\n        )\n    if len(rope_kwargs) > 0:\n        base = rope_kwargs[\"base\"]\n        dim = rope_kwargs[\"dim\"]\n        max_position_embeddings = rope_kwargs[\"max_position_embeddings\"]\n        factor = rope_kwargs[\"factor\"]\n    elif config is not None:\n        base = config.rope_theta\n        partial_rotary_factor = config.partial_rotary_factor if hasattr(config, \"partial_rotary_factor\") else 1.0\n        head_dim = getattr(config, \"head_dim\", config.hidden_size // config.num_attention_heads)\n        dim = int(head_dim * partial_rotary_factor)\n        max_position_embeddings = config.max_position_embeddings\n        factor = config.rope_scaling[\"factor\"]\n\n    attention_factor = 1.0  # Unused in this type of RoPE\n\n    # seq_len: default to max_position_embeddings, e.g. 
at init time\n    seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings\n\n    # Compute the inverse frequencies\n    base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))\n    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))\n    return inv_freq, attention_factor\n\n\ndef _compute_yarn_parameters(\n    config: PretrainedConfig, device: \"torch.device\", seq_len: Optional[int] = None, **rope_kwargs\n) -> Tuple[\"torch.Tensor\", float]:\n    \"\"\"\n    Computes the inverse frequencies with NTK scaling. Please refer to the\n    [original paper](https://arxiv.org/abs/2309.00071)\n    Args:\n        config ([`~transformers.PretrainedConfig`]):\n            The model configuration.\n        device (`torch.device`):\n            The device to use for initialization of the inverse frequencies.\n        seq_len (`int`, *optional*):\n            The current sequence length. 
Unused for this type of RoPE.\n        rope_kwargs (`Dict`, *optional*):\n            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.\n    Returns:\n        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the\n        post-processing scaling factor applied to the computed cos/sin.\n    \"\"\"\n    # No need to keep BC with yarn, unreleased when this new pattern was created.\n    if len(rope_kwargs) > 0:\n        raise ValueError(\n            f\"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}\"\n        )\n\n    base = config.rope_theta\n    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, \"partial_rotary_factor\") else 1.0\n    head_dim = getattr(config, \"qk_rope_head_dim\", config.hidden_size // config.num_attention_heads)\n    dim = int(head_dim * partial_rotary_factor)\n    factor = config.rope_scaling[\"factor\"]\n    attention_factor = config.rope_scaling.get(\"attention_factor\")\n    mscale = config.rope_scaling.get(\"mscale\")\n    mscale_all_dim = config.rope_scaling.get(\"mscale_all_dim\")\n\n    # NOTE: DeepSeek-V3 (and potentially other models) modify `max_position_embeddings` and have an\n    # `original_max_position_embeddings` field containing the pretrained value. 
They use the ratio between these two\n    # values to compute the default attention scaling factor, instead of using `factor`.\n    if \"original_max_position_embeddings\" in config.rope_scaling:\n        original_max_position_embeddings = config.rope_scaling[\"original_max_position_embeddings\"]\n        factor = config.max_position_embeddings / original_max_position_embeddings\n    else:\n        original_max_position_embeddings = config.max_position_embeddings\n\n    def get_mscale(scale, mscale=1):\n        if scale <= 1:\n            return 1.0\n        return 0.1 * mscale * math.log(scale) + 1.0\n\n    # Sets the attention factor as suggested in the paper\n    if attention_factor is None:\n        if mscale and mscale_all_dim:\n            attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))\n        else:\n            attention_factor = get_mscale(factor)\n\n    # Optional config options\n    # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)\n    beta_fast = config.rope_scaling.get(\"beta_fast\") or 32\n    beta_slow = config.rope_scaling.get(\"beta_slow\") or 1\n\n    # Compute the inverse frequencies\n    def find_correction_dim(num_rotations, dim, base, max_position_embeddings):\n        \"\"\"Inverse dimension formula to find the dimension based on the number of rotations\"\"\"\n        return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))\n\n    def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):\n        \"\"\"Find dimension range bounds based on rotations\"\"\"\n        low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))\n        high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))\n        return max(low, 0), min(high, dim - 1)\n\n    def linear_ramp_factor(min, max, dim):\n        if min == max:\n            max += 0.001  # Prevent 
singularity\n\n        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)\n        ramp_func = torch.clamp(linear_func, 0, 1)\n        return ramp_func\n\n    # Note on variable naming: \"interpolation\" comes from the original technique, where we interpolate the position IDs\n    # to expand the possible context length. In other words, interpolation = apply scaling factor.\n    pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)\n    inv_freq_extrapolation = 1.0 / pos_freqs\n    inv_freq_interpolation = 1.0 / (factor * pos_freqs)\n\n    low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings)\n\n    # Get n-dimensional rotational scaling corrected for extrapolation\n    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)\n    inv_freq = (\n        inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)\n        + inv_freq_extrapolation * inv_freq_extrapolation_factor\n    )\n    return inv_freq, attention_factor\n\n\ndef _compute_longrope_parameters(\n    config: PretrainedConfig, device: \"torch.device\", seq_len: Optional[int] = None, **rope_kwargs\n) -> Tuple[\"torch.Tensor\", float]:\n    \"\"\"\n    Computes the inverse frequencies with LongRoPE scaling. 
Please refer to the\n    [original implementation](https://github.com/microsoft/LongRoPE)\n    Args:\n        config ([`~transformers.PretrainedConfig`]):\n            The model configuration.\n        device (`torch.device`):\n            The device to use for initialization of the inverse frequencies.\n        seq_len (`int`, *optional*):\n            The current sequence length.\n        rope_kwargs (`Dict`, *optional*):\n            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.\n    Returns:\n        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the\n        post-processing scaling factor applied to the computed cos/sin.\n    \"\"\"\n    # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling\n    # No need to keep BC with longrope, unreleased when this new pattern was created.\n    if len(rope_kwargs) > 0:\n        raise ValueError(\n            \"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_longrope_parameters`, got \"\n            f\"{rope_kwargs}\"\n        )\n\n    base = config.rope_theta\n    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, \"partial_rotary_factor\") else 1.0\n    head_dim = getattr(config, \"head_dim\", config.hidden_size // config.num_attention_heads)\n    dim = int(head_dim * partial_rotary_factor)\n    long_factor = config.rope_scaling[\"long_factor\"]\n    short_factor = config.rope_scaling[\"short_factor\"]\n    factor = config.rope_scaling.get(\"factor\")\n    attention_factor = config.rope_scaling.get(\"attention_factor\")\n\n    # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a\n    # `original_max_position_embeddings` field containing the pretrained value. 
They use the ratio between these two\n    # values to compute the default attention scaling factor, instead of using `factor`.\n    if hasattr(config, \"original_max_position_embeddings\"):\n        original_max_position_embeddings = config.original_max_position_embeddings\n        factor = config.max_position_embeddings / config.original_max_position_embeddings\n    else:\n        original_max_position_embeddings = config.max_position_embeddings\n\n    # Sets the attention factor as suggested in the paper\n    if attention_factor is None:\n        if factor <= 1.0:\n            attention_factor = 1.0\n        else:\n            attention_factor = math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings))\n\n    # Compute the inverse frequencies -- scaled based on the target sequence length\n    if seq_len and seq_len > original_max_position_embeddings:\n        ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)\n    else:\n        ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)\n    inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim\n    inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)\n\n    return inv_freq, attention_factor\n\n\ndef _compute_llama3_parameters(\n    config: PretrainedConfig, device: \"torch.device\", seq_len: Optional[int] = None, **rope_kwargs\n) -> Tuple[\"torch.Tensor\", float]:\n    \"\"\"\n    Computes the inverse frequencies for llama 3.1.\n\n    Args:\n        config ([`~transformers.PretrainedConfig`]):\n            The model configuration.\n        device (`torch.device`):\n            The device to use for initialization of the inverse frequencies.\n        seq_len (`int`, *optional*):\n            The current sequence length. 
Unused for this type of RoPE.\n        rope_kwargs (`Dict`, *optional*):\n            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.\n    Returns:\n        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the\n        post-processing scaling factor applied to the computed cos/sin.\n    \"\"\"\n    # Gets the default RoPE parameters\n    inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)\n\n    factor = config.rope_scaling[\"factor\"]  # `8` in the original implementation\n    low_freq_factor = config.rope_scaling[\"low_freq_factor\"]  # `1` in the original implementation\n    high_freq_factor = config.rope_scaling[\"high_freq_factor\"]  # `4` in the original implementation\n    old_context_len = config.rope_scaling[\"original_max_position_embeddings\"]  # `8192` in the original implementation\n\n    low_freq_wavelen = old_context_len / low_freq_factor\n    high_freq_wavelen = old_context_len / high_freq_factor\n\n    wavelen = 2 * math.pi / inv_freq\n    # wavelen < high_freq_wavelen: do nothing\n    # wavelen > low_freq_wavelen: divide by factor\n    inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)\n    # otherwise: interpolate between the two, using a smooth factor\n    smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)\n    smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama\n    is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)\n    inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)\n\n    return inv_freq_llama, attention_factor\n\n\n# This maps the \"rope_type\" string field in rope config to the corresponding function to compute the RoPE parameters\n# from the model config. 
You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE\n# parameterizations, as long as the callable has the same signature.\nROPE_INIT_FUNCTIONS = {\n    \"default\": _compute_default_rope_parameters,\n    \"linear\": _compute_linear_scaling_rope_parameters,\n    \"dynamic\": _compute_dynamic_ntk_parameters,\n    \"yarn\": _compute_yarn_parameters,\n    \"longrope\": _compute_longrope_parameters,\n    \"llama3\": _compute_llama3_parameters,\n}\n\n\ndef _check_received_keys(\n    rope_type: str,\n    received_keys: set,\n    required_keys: set,\n    optional_keys: Optional[set] = None,\n    ignore_keys: Optional[set] = None,\n):\n    \"\"\"Compare the received keys in `config.rope_scaling` against the expected and optional keys\"\"\"\n    # BC: \"rope_type\" was originally \"type\" -- let's check for \"rope_type\" when \"type\" is present\n    if \"type\" in received_keys:\n        received_keys -= {\"type\"}\n        required_keys.add(\"rope_type\")\n\n    # Some models need to store model-specific keys, and we don't want to throw warning at them\n    if ignore_keys is not None:\n        received_keys -= ignore_keys\n\n    missing_keys = required_keys - received_keys\n    if missing_keys:\n        raise KeyError(f\"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}\")\n\n    if optional_keys is not None:\n        unused_keys = received_keys - required_keys - optional_keys\n    else:\n        unused_keys = received_keys - required_keys\n    if unused_keys:\n        logger.warning(f\"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}\")\n\n\ndef _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):\n    rope_scaling = config.rope_scaling\n    rope_type = rope_scaling.get(\"rope_type\", rope_scaling.get(\"type\", None))  # BC: \"rope_type\" was originally \"type\"\n    required_keys = {\"rope_type\"}\n    received_keys = 
set(rope_scaling.keys())\n    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)\n\n\ndef _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):\n    rope_scaling = config.rope_scaling\n    rope_type = rope_scaling.get(\"rope_type\", rope_scaling.get(\"type\", None))  # BC: \"rope_type\" was originally \"type\"\n    required_keys = {\"rope_type\", \"factor\"}\n    received_keys = set(rope_scaling.keys())\n    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)\n\n    factor = rope_scaling[\"factor\"]\n    if factor is None or not isinstance(factor, float) or factor < 1.0:\n        logger.warning(f\"`rope_scaling`'s factor field must be a float >= 1, got {factor}\")\n\n\ndef _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):\n    rope_scaling = config.rope_scaling\n    rope_type = rope_scaling.get(\"rope_type\", rope_scaling.get(\"type\", None))  # BC: \"rope_type\" was originally \"type\"\n    required_keys = {\"rope_type\", \"factor\"}\n    # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`\n    optional_keys = {\"original_max_position_embeddings\"}\n    received_keys = set(rope_scaling.keys())\n    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)\n\n    factor = rope_scaling[\"factor\"]\n    if factor is None or not isinstance(factor, float) or factor < 1.0:\n        logger.warning(f\"`rope_scaling`'s factor field must be a float >= 1, got {factor}\")\n\n\ndef _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):\n    rope_scaling = config.rope_scaling\n    rope_type = rope_scaling.get(\"rope_type\", rope_scaling.get(\"type\", None))  # BC: \"rope_type\" was originally \"type\"\n    required_keys = {\"rope_type\", \"factor\"}\n    optional_keys = {\n        
\"attention_factor\",\n        \"beta_fast\",\n        \"beta_slow\",\n        \"original_max_position_embeddings\",\n        \"mscale\",\n        \"mscale_all_dim\",\n    }\n    received_keys = set(rope_scaling.keys())\n    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)\n\n    factor = rope_scaling[\"factor\"]\n    if factor is None or not isinstance(factor, float) or factor < 1.0:\n        logger.warning(f\"`rope_scaling`'s factor field must be a float >= 1, got {factor}\")\n\n    attention_factor = rope_scaling.get(\"attention_factor\")\n    if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):\n        logger.warning(\n            f\"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}\"\n        )\n    beta_fast = rope_scaling.get(\"beta_fast\")\n    if beta_fast is not None and not isinstance(beta_fast, float):\n        logger.warning(f\"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}\")\n    beta_slow = rope_scaling.get(\"beta_slow\")\n    if beta_slow is not None and not isinstance(beta_slow, float):\n        logger.warning(f\"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}\")\n\n    if (beta_fast or 32) < (beta_slow or 1):\n        logger.warning(\n            f\"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} \"\n            f\"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)\"\n        )\n\n\ndef _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):\n    rope_scaling = config.rope_scaling\n    rope_type = rope_scaling.get(\"rope_type\", rope_scaling.get(\"type\", None))  # BC: \"rope_type\" was originally \"type\"\n    required_keys = {\"rope_type\", \"short_factor\", \"long_factor\"}\n    # TODO (joao): update logic for the inclusion of 
`original_max_position_embeddings`\n    optional_keys = {\"attention_factor\", \"factor\", \"original_max_position_embeddings\"}\n    received_keys = set(rope_scaling.keys())\n    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)\n\n    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, \"partial_rotary_factor\") else 1.0\n    head_dim = getattr(config, \"head_dim\", config.hidden_size // config.num_attention_heads)\n    dim = int(head_dim * partial_rotary_factor)\n\n    short_factor = rope_scaling.get(\"short_factor\")\n    # Warn unless the field is a list AND every element is a number\n    if not (isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor)):\n        logger.warning(f\"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}\")\n    if not len(short_factor) == dim // 2:\n        logger.warning(f\"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}\")\n\n    long_factor = rope_scaling.get(\"long_factor\")\n    # Warn unless the field is a list AND every element is a number\n    if not (isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor)):\n        logger.warning(f\"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}\")\n    if not len(long_factor) == dim // 2:\n        logger.warning(f\"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}\")\n\n    # Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over\n    # `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is\n    # unique to longrope (= undesirable)\n    if hasattr(config, \"original_max_position_embeddings\"):\n        logger.warning_once(\n            \"This model has set a `original_max_position_embeddings` field, to be used together with \"\n            \"`max_position_embeddings` to determine a scaling factor. 
Please set the `factor` field of `rope_scaling` \"\n            \"with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, \"\n            \"as it is compatible with most model architectures.\"\n        )\n    else:\n        factor = rope_scaling.get(\"factor\")\n        if factor is None:\n            logger.warning(\"Missing required keys in `rope_scaling`: 'factor'\")\n        elif not isinstance(factor, float) or factor < 1.0:\n            logger.warning(f\"`rope_scaling`'s factor field must be a float >= 1, got {factor}\")\n\n        attention_factor = rope_scaling.get(\"attention_factor\")\n        if attention_factor is not None:\n            if not isinstance(attention_factor, float) or attention_factor < 0.0:\n                logger.warning(\n                    f\"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}\"\n                )\n\n\ndef _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):\n    rope_scaling = config.rope_scaling\n    rope_type = rope_scaling.get(\"rope_type\", rope_scaling.get(\"type\", None))  # BC: \"rope_type\" was originally \"type\"\n    required_keys = {\"rope_type\", \"factor\", \"original_max_position_embeddings\", \"low_freq_factor\", \"high_freq_factor\"}\n    received_keys = set(rope_scaling.keys())\n    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)\n\n    factor = rope_scaling[\"factor\"]\n    if factor is None or not isinstance(factor, float) or factor < 1.0:\n        logger.warning(f\"`rope_scaling`'s factor field must be a float >= 1, got {factor}\")\n\n    low_freq_factor = rope_scaling[\"low_freq_factor\"]\n    high_freq_factor = rope_scaling[\"high_freq_factor\"]\n    if low_freq_factor is None or not isinstance(low_freq_factor, float):\n        logger.warning(f\"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}\")\n    if 
high_freq_factor is None or not isinstance(high_freq_factor, float):\n        logger.warning(f\"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}\")\n    if high_freq_factor <= low_freq_factor:\n        logger.warning(\n            \"`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor=\"\n            f\"{high_freq_factor} and low_freq_factor={low_freq_factor}\"\n        )\n\n    original_max_position_embeddings = rope_scaling[\"original_max_position_embeddings\"]\n    if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):\n        logger.warning(\n            \"`rope_scaling`'s original_max_position_embeddings field must be an integer, got \"\n            f\"{original_max_position_embeddings}\"\n        )\n    if original_max_position_embeddings >= config.max_position_embeddings:\n        logger.warning(\n            \"`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got \"\n            f\"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}\"\n        )\n\n\n# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types.\nROPE_VALIDATION_FUNCTIONS = {\n    \"default\": _validate_default_rope_parameters,\n    \"linear\": _validate_linear_scaling_rope_parameters,\n    \"dynamic\": _validate_dynamic_scaling_rope_parameters,\n    \"yarn\": _validate_yarn_parameters,\n    \"longrope\": _validate_longrope_parameters,\n    \"llama3\": _validate_llama3_parameters,\n}\n\n\ndef rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None):\n    \"\"\"\n    Validate the RoPE config arguments, given a `PretrainedConfig` object\n    \"\"\"\n    rope_scaling = getattr(config, \"rope_scaling\", None)  # not a default parameter in `PretrainedConfig`\n    if rope_scaling is None:\n        
return\n\n    # BC: \"rope_type\" was originally \"type\"\n    rope_type = rope_scaling.get(\"rope_type\", rope_scaling.get(\"type\", \"default\"))\n    validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)\n    if validation_fn is not None:\n        validation_fn(config, ignore_keys=ignore_keys)\n    else:\n        logger.warning(\n            f\"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'\"\n        )"
  },
  {
    "path": "kt-sft/ktransformers/util/textstream.py",
    "content": "from typing import Any, List, Optional, Set\nclass TextStreamer:\n\n    def __init__(self, tokenizer: \"AutoTokenizer\", skip_prompt: bool = False, **decode_kwargs):\n        self.tokenizer = tokenizer\n        self.skip_prompt = skip_prompt\n        self.decode_kwargs = decode_kwargs\n\n        # variables used in the streaming process\n        self.token_cache = []\n        self.print_len = 0\n        self.next_tokens_are_prompt = True\n\n    def reset(self):\n        self.token_cache = []\n        self.print_len = 0\n\n    def put(self, value)->Optional[str]:\n        \"\"\"\n        Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.\n        \"\"\"        \n        if not isinstance(value,int):\n            raise ValueError(\"TextStreamer only supports batch size 1, and int type input\")\n\n\n        if self.skip_prompt and self.next_tokens_are_prompt:\n            self.next_tokens_are_prompt = False\n            return None\n\n        # Add the new token to the cache and decodes the entire thing.\n        self.token_cache.append(value)\n        text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True,**self.decode_kwargs)\n\n        # After the symbol for a new line, we flush the cache.\n        if text.endswith(\"\\n\"):\n            printable_text = text[self.print_len :]\n            self.reset()\n        # If the last token is a CJK character, we print the characters.\n        elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):\n            printable_text = text[self.print_len :]\n            self.print_len += len(printable_text)\n        # Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,\n        # which may change with the subsequent token -- there are probably smarter ways to do this!)\n        else:\n            printable_text = text[self.print_len : text.rfind(\" \") + 1]\n            self.print_len += len(printable_text)\n 
       return printable_text\n\n    def end(self)->Optional[str]:\n        \"\"\"Flushes any remaining cache and prints a newline to stdout.\"\"\"\n        # Flush the cache, if it exists\n        if len(self.token_cache) > 0:\n            text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True, **self.decode_kwargs)\n            printable_text = text[self.print_len :]\n            self.reset()\n        else:\n            printable_text = \"\"\n\n        self.next_tokens_are_prompt = True\n        return printable_text\n   \n    def _is_chinese_char(self, cp):\n        \"\"\"Checks whether CP is the codepoint of a CJK character.\"\"\"\n        # This defines a \"chinese character\" as anything in the CJK Unicode block:\n        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)\n        #\n        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,\n        # despite its name. The modern Korean Hangul alphabet is a different block,\n        # as is Japanese Hiragana and Katakana. Those alphabets are used to write\n        # space-separated words, so they are not treated specially and handled\n        # like the all of the other languages.\n        if (\n            (cp >= 0x4E00 and cp <= 0x9FFF)\n            or (cp >= 0x3400 and cp <= 0x4DBF)  #\n            or (cp >= 0x20000 and cp <= 0x2A6DF)  #\n            or (cp >= 0x2A700 and cp <= 0x2B73F)  #\n            or (cp >= 0x2B740 and cp <= 0x2B81F)  #\n            or (cp >= 0x2B820 and cp <= 0x2CEAF)  #\n            or (cp >= 0xF900 and cp <= 0xFAFF)\n            or (cp >= 0x2F800 and cp <= 0x2FA1F)  #\n        ):  #\n            return True\n\n        return False"
  },
  {
    "path": "kt-sft/ktransformers/util/utils.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :  \nAuthor       : Boxin Zhang, Azure-Tang\nVersion      : 0.1.0\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved. \n'''\nimport torch\nfrom torch import nn\nimport itertools\nimport time\nimport enum\nfrom typing import Any, List, Optional, Set\nfrom transformers import (\n    LogitsProcessorList,\n    TemperatureLogitsWarper,\n    TopKLogitsWarper,\n    TopPLogitsWarper,\n    MinPLogitsWarper,\n    TypicalLogitsWarper,\n    EpsilonLogitsWarper,\n    EtaLogitsWarper,\n)\nfrom torchviz import make_dot\n# from ktransformers.sft.peft_utils.lora_layer import KTransformersLinearLora\nfrom ktransformers.util.custom_loader import ModelLoaderFactory, ModelLoader, SafeTensorLoader, GGUFLoader, translate_name_to_gguf, translate_adapter_name_to_gguf\nfrom ktransformers.operators import base_operator\nfrom ktransformers.models.custom_cache import StaticCache\nfrom ktransformers.util.cuda_graph_runner import CUDAGraphRunner\nfrom ktransformers.util.textstream import TextStreamer\nfrom ktransformers.util.globals import GLOBAL_CONFIG\nif not torch.xpu.is_available():\n    from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton\nimport socket\n\nfrom transformers.generation.logits_process import LogitsProcessor\n# from transformers import TextStreamer # !!! 
this will override the TextStreamer from ktransformers.util.textstream\n\nclass NoEosUntil(LogitsProcessor):\n    def __init__(self, prompt_len: int, min_gen_len: int, eos_ids):\n        super().__init__()\n        self.start_len = int(prompt_len)\n        self.min_len   = self.start_len + int(min_gen_len)\n        self.eos_ids   = list(eos_ids) if isinstance(eos_ids,(list,tuple)) else [int(eos_ids)]\n\n    def __call__(self, input_ids, scores):\n        if input_ids.shape[-1] < self.min_len:\n            scores[..., self.eos_ids] = -float(\"inf\")\n        return scores\n\nclass SilentCaptureStreamer(TextStreamer):\n    def __init__(self, tokenizer: \"AutoTokenizer\", skip_prompt: bool = False, **decode_kwargs):\n        super().__init__(tokenizer, skip_prompt=skip_prompt, **decode_kwargs)\n        self._buf: List[str] = []\n\n    def _append_piece(self, piece: Optional[str]):\n        if piece:\n            self._buf.append(piece)\n\n    def put(self, value) -> str:\n        tokens: List[int] = []\n        if isinstance(value, int):\n            tokens = [value]\n        else:\n            try:\n                import torch\n                if isinstance(value, torch.Tensor):\n                    tokens = list(map(int, value.view(-1).tolist()))\n                elif isinstance(value, (list, tuple)) and all(isinstance(x, int) for x in value):\n                    tokens = list(value)\n                else:\n                    raise ValueError(\"Unsupported value type for SilentCaptureStreamer.put\")\n            except Exception:\n                if isinstance(value, (list, tuple)) and all(isinstance(x, int) for x in value):\n                    tokens = list(value)\n                else:\n                    raise ValueError(\"Unsupported value type for SilentCaptureStreamer.put\")\n        for t in tokens:\n            piece = super().put(t)\n            self._append_piece(piece)\n        return \"\"\n\n    def end(self) -> str:\n        piece = super().end()\n 
       self._append_piece(piece)\n        return \"\"\n\n    def getvalue(self) -> str:\n        return \"\".join(self._buf)\n\n    def clear(self):\n        self._buf.clear()\n\nwarm_uped = False\n\ndef get_free_ports(n: int, continue_prot: list):\n    sockets = []\n    ports = []\n    for _ in range(n):\n        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n        s.bind((\"\", 0)) \n        port = s.getsockname()[1]\n        if port in continue_prot:\n            s.close()\n            continue\n        ports.append(port)\n        sockets.append(s)\n    for s in sockets:\n        s.close()\n    return ports\n\ndef get_compute_capability(device:torch.device = None):\n    if torch.cuda.is_available():\n        if device is None:\n            num_gpus = torch.cuda.device_count()\n            min_compute_capability_major = 100\n            for gpu_id in range(num_gpus):\n                gpu_props = torch.cuda.get_device_properties(gpu_id)\n                min_compute_capability_major = min(min_compute_capability_major, gpu_props.major)\n            return min_compute_capability_major\n        else:\n            return torch.cuda.get_device_properties(device)\n    else:\n        return 0\n\ndef set_module(model, submodule_key, module):\n    tokens = submodule_key.split('.')\n    sub_tokens = tokens[:-1]\n    cur_mod = model\n    for s in sub_tokens:\n        if hasattr(cur_mod, s):\n            cur_mod = getattr(cur_mod, s)\n        else: # nn.ModuleList or nn.ModuleList\n            cur_mod=cur_mod[int(s)]\n    if hasattr(cur_mod, tokens[-1]):\n        setattr(cur_mod, tokens[-1], module)\n    else: # nn.ModuleList or nn.ModuleList\n        cur_mod[int(tokens[-1])] = module\n\ndef set_param(module: nn.Module, name: str, weights: torch.Tensor):\n    \n    param=nn.parameter.Parameter(weights, requires_grad=True)\n    if isinstance(module, nn.Linear) and len(weights.shape)==1:\n        param.unsqueeze_(0)\n    setattr(module, name, param)\n\ndef 
get_device(gguf_module_key:str, device_map:dict):\n    if gguf_module_key in device_map:\n        return device_map[gguf_module_key][\"generate_device\"]\n    elif gguf_module_key.replace(\"model.layers\", \"blk\") in device_map:\n        # Look up the same translated key that was tested above\n        return device_map[gguf_module_key.replace(\"model.layers\", \"blk\")][\"generate_device\"]\n    else:\n        return \"cuda\"\n\ndef get_all_used_cuda_device(device_map:dict):\n    all_device_list = set()\n    for key in device_map:\n        all_device_list.add(device_map[key][\"generate_device\"]) if \"generate_device\" in device_map[key] else None\n        all_device_list.add(device_map[key][\"prefill_device\"]) if \"prefill_device\" in device_map[key] else None\n    if \"cpu\" in all_device_list:\n        all_device_list.remove(\"cpu\")\n    all_device_list = list(all_device_list)\n    return all_device_list\n\ndef load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = \"\", device=\"cuda\", adapter_gguf: bool = False):\n    if GLOBAL_CONFIG._config[\"mod\"] == 'sft':\n        prefix = prefix.replace(\"orig_module.\", \"\")\n        persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}\n        local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items())\n        local_state = {k: v for k, v in local_name_params if v is not None}\n        for name, param in local_state.items():\n            key = prefix + name\n            translated_key = translate_name_to_gguf(key)\n            if adapter_gguf:\n                translated_adapter_key = translate_adapter_name_to_gguf(key)\n\n            # TODO: Merge all loaders.\n            # I know this is ugly, but let's do it for now.\n            if gguf_loader.safetensor_loader is not None:\n                load_dequantized_tensor = gguf_loader.safetensor_loader.load_dequantized_tensor\n                tensor_file_map = gguf_loader.safetensor_loader.tensor_file_map\n      
      else:\n                load_dequantized_tensor = gguf_loader.load_gguf_tensor\n                tensor_file_map = gguf_loader.tensor_file_map\n            # print(f\"tensor_file_map:{tensor_file_map}\")\n            # We allow some key not be used in GGUF\n            if translated_key in tensor_file_map:\n                target_dtype = torch.get_default_dtype()\n                device = get_device(translated_key[:translated_key.rfind(\".\")], gguf_loader.tensor_device_map)\n                print(f\"loading {translated_key} to {device}\")\n                torch.cuda.empty_cache()\n                weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)\n                set_param(module, name, weights)\n                del weights\n            else:\n                if adapter_gguf == True: # Not all module should be reload in lora adapter\n                    for single_tensor_file_map in tensor_file_map:\n                        if translated_adapter_key in single_tensor_file_map:\n                            target_dtype = torch.get_default_dtype()\n                            device = get_device(single_tensor_file_map[:single_tensor_file_map.rfind(\".\")], gguf_loader.tensor_device_map)\n                            print(f\"loading {single_tensor_file_map} to {device}\")\n                            torch.cuda.empty_cache()\n                            weights = load_dequantized_tensor(single_tensor_file_map, device=device).to(dtype=target_dtype)\n                            set_param(module, name, weights)\n                            del weights\n\n                else:\n                    #print(load_config.tensor_file_map.keys())\n                    raise Exception(f\"can't find {translated_key} in GGUF file!\")\n    elif GLOBAL_CONFIG._config[\"mod\"] == 'infer':\n        prefix = prefix.replace(\"orig_module.\", \"\")\n        persistent_buffers = {k: v for k, v in module._buffers.items() if k not in 
module._non_persistent_buffers_set}\n        local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items())\n        local_state = {k: v for k, v in local_name_params if v is not None}\n        for name, param in local_state.items():\n            key = prefix + name\n            translated_key = key\n            \n            # TODO: Merge all loader.\n            # I know this is ugly but lets do it for now.\n            if isinstance(gguf_loader, SafeTensorLoader):\n                load_dequantized_tensor = gguf_loader.load_dequantized_tensor\n            else:\n                load_dequantized_tensor = gguf_loader.load_gguf_tensor\n                tensor_file_map = gguf_loader.tensor_file_map\n            \n            if gguf_loader.has_tensor(translated_key) or \"kv_b_proj\" in translated_key:\n                target_dtype = torch.get_default_dtype()\n                device = get_device(translated_key[:translated_key.rfind(\".\")], gguf_loader.tensor_device_map)\n                print(f\"loading {translated_key} to {device}\")\n                if torch.cuda.is_available():\n                    torch.cuda.empty_cache()\n                elif torch.xpu.is_available():\n                    torch.xpu.empty_cache()\n                if \"kv_b_proj\" in translated_key and not gguf_loader.has_tensor(translated_key):\n                    attn_k_b = load_dequantized_tensor(translated_key.replace(\"self_attn.kv_b_proj\", \"attn_k_b\"), device=device).to(dtype=target_dtype)\n                    attn_k_b = attn_k_b.transpose(1, 2).contiguous()\n                    attn_v_b = load_dequantized_tensor(translated_key.replace(\"self_attn.kv_b_proj\", \"attn_v_b\"), device=device).to(dtype=target_dtype)\n                    kv_b_proj = torch.cat((attn_k_b, attn_v_b), dim=1)\n                    kv_b_proj = kv_b_proj.contiguous() if kv_b_proj.ndim == 2 else kv_b_proj.flatten(0, 1).contiguous()\n                    set_param(module, name, kv_b_proj)\n   
                 del attn_k_b\n                    del attn_v_b\n                else:\n                    weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)\n                    set_param(module, name, weights)\n                    del weights\n            else:\n                #print(load_config.tensor_file_map.keys())\n                raise Exception(f\"can't find {translated_key} in GGUF file!\")\n        \n\ndef sync_all_device(all_device_list):\n    for device in all_device_list:\n        if \"cuda\" in device.lower():\n            torch.cuda.synchronize(device)\n        elif \"xpu\" in device.lower():\n            torch.xpu.synchronize(device)\n        else:\n            raise RuntimeError(\"The device {} is not available\".format(device))\n\ntorch_device_mapping ={\"cuda\": \"cuda:0\", \"xpu\": \"xpu:0\"}\n\ndef xpu_fp16_model(config):\n    # This function is to check if we run this model on XPU with FP16 dtype\n    if not torch.xpu.is_available():\n        return False\n    if config.architectures[0] == \"DeepseekV3ForCausalLM\":\n        return True\n    if config.architectures[0] == \"Qwen3MoeForCausalLM\" and config.hidden_size == 4096:\n        # Qwen3-30B seems have precision issue with FP16\n        # so we only use FP16 for Qwen3-235B now\n        return True\n    return False\n\ndef load_weights(module:nn.Module, gguf_loader:ModelLoader, prefix='', device=\"cuda\", adapter_gguf=False):\n    #print(f\"recursively loading weights {prefix}\")\n    if not isinstance(module, base_operator.BaseInjectedModule):\n        load_cur_state_dict(module, gguf_loader, prefix, device=device, adapter_gguf=adapter_gguf, )\n        for name, child in module._modules.items():\n            load_weights(child, gguf_loader, prefix+name+\".\", device=device, adapter_gguf=adapter_gguf, )\n    else:\n        if adapter_gguf == True:\n            # TODO: This is not the best choice, because we should change the value of gguf_loader in 
BaseInjectModule, but up to now, it can still work\n            try: # for other class inherit from BaseInjectModule, but not inherit from KTLinear\n                module.load(gguf_loader=gguf_loader, adapter_gguf=adapter_gguf)\n            except: # for only KTLinear up to now\n                module.load()\n        else:\n            module.load()\n\ndef tf_logits_warper(generation_config):\n        \"\"\"\n        This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances\n        used for multinomial sampling.\n        \"\"\"\n\n        # instantiate warpers list\n        warpers = LogitsProcessorList()\n\n        # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a\n        # better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)\n        if generation_config.num_beams > 1:\n            if isinstance(generation_config._eos_token_tensor, list):\n                min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1\n            elif isinstance(generation_config._eos_token_tensor, torch.Tensor):\n                min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1\n            else:\n                min_tokens_to_keep = 2\n        else:\n            min_tokens_to_keep = 1\n\n        # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files\n        # all samplers can be found in `generation_utils_samplers.py`\n        if generation_config.temperature is not None and generation_config.temperature != 1.0:\n            warpers.append(TemperatureLogitsWarper(generation_config.temperature))\n        if generation_config.top_k is not None and generation_config.top_k != 0:\n            warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))\n        if generation_config.top_p is not None and generation_config.top_p < 
1.0:\n            warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))\n        if generation_config.min_p is not None:\n            # Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)\n            warpers.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))\n        if generation_config.typical_p is not None and generation_config.typical_p < 1.0:\n            warpers.append(\n                TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)\n            )\n        if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:\n            warpers.append(\n                EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep)\n            )\n        if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:\n            warpers.append(\n               EtaLogitsWarper(\n                    epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device\n                )\n            )\n        # `LogitNormalization` should always be the last logit processor, when present\n        if generation_config.renormalize_logits is True:\n            warpers.append(LogitNormalization())\n        return warpers\n\ndef prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cuda_graph: bool = True,\n                         mode = 'normal', force_think: bool = False, chunk_size = 16384, use_flashinfer_mla = False,\n                         num_heads = None, head_dim_ckv = None, head_dim_kpe = None, q_head_dim = None):\n    import os\n    os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n    torch._dynamo.config.suppress_errors = True\n    batch_size, seq_length = inputs.shape\n    device_map = 
model.gguf_loader.tensor_device_map\n    torch_device = get_device('model.layers.0.self_attn', device_map)\n    # torch_device = \"cuda:0\" if torch_device == \"cuda\" else torch_device\n    torch_device = torch_device_mapping[torch_device] if torch_device in torch_device_mapping else torch_device\n    inputs = inputs.to(torch_device)\n    all_cuda_device = get_all_used_cuda_device(device_map)\n\n    tokens = []\n    \n    def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph: bool = True):\n        if cuda_graph_runner is None:\n            use_cuda_graph = False\n        if use_cuda_graph:\n            logits = cuda_graph_runner(cur_token, position_ids, cache_position)\n        else:\n            # custom_stream = torch.cuda.Stream()\n            if torch.cuda.is_available():\n                torch.cuda.set_device(torch_device)\n            elif torch.xpu.is_available():\n                torch.xpu.set_device(torch_device)\n            else:\n                raise RuntimeError(f\"The device: {torch_device} is not available\")\n            inputs_embeds = model.model.embed_tokens(cur_token.to(\"cpu\")).to(torch_device)\n            # with torch.cuda.stream(custom_stream):\n            logits=model(inputs_embeds=inputs_embeds,\n                        position_ids=position_ids,\n                        cache_position=cache_position,\n                        past_key_values=past_key_values,\n                        return_dict=False, use_cache=True)[0]\n        if past_key_values != None and isinstance(past_key_values, StaticCache):\n            past_key_values.change_seq_length(1)\n        sync_all_device(all_cuda_device)\n        # print(logits)\n        next_token_scores = logits_warper(inputs, logits[:, -1, :])\n        if generation_config.do_sample:\n            probs = nn.functional.softmax(next_token_scores, dim=-1)\n            next_token = torch.multinomial(probs, 
num_samples=1).squeeze(1)\n        else:\n            next_token = torch.argmax(next_token_scores, dim=-1)\n        return next_token\n    \n    # TODO: use CUDA Graph for chunk prefill, may get small improvement\n    def chunk_prefill(inputs, cache_position, past_key_values):\n        if mode == \"long_context\":\n            inputs_embeds = model.model.embed_tokens(inputs.to(\"cpu\"))\n        else:\n            print(f\"torch_device:{torch_device}\")\n            inputs_embeds = model.model.embed_tokens(inputs.to(\"cpu\")).to(torch_device)\n        if use_flashinfer_mla:\n            MLAWrapperSingleton.update_buffer(past_key_values.max_pages)\n            MLAWrapperSingleton.need_plan_all()\n            \n        logits = model(\n            inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True\n        )[0][:,-1,:].unsqueeze(0).clone().to(torch_device)\n        \n        return logits\n    \n    if torch.cuda.is_available():\n        torch.cuda.set_device(torch_device)\n    elif torch.xpu.is_available():\n        torch.xpu.set_device(torch_device)\n    else:\n        raise RuntimeError(f\"The device: {torch_device} is not available\")\n    with torch.no_grad():\n        \n        stream = TextStreamer(tokenizer)\n        if torch.xpu.is_available():\n            from ipex_llm.transformers.kv import DynamicUnbalancedFp8Cache, DynamicNormalCache\n            if model.config.architectures[0] in [\"DeepseekV3ForCausalLM\", \"DeepseekV2ForCausalLM\"]:\n                past_key_values = DynamicUnbalancedFp8Cache.from_legacy_cache(None)\n            else:\n                past_key_values = DynamicNormalCache.from_legacy_cache(None)\n        elif mode != 'long_context':\n            past_key_values = StaticCache(\n                config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype\n            )\n        else:\n        
    past_key_values = None\n        \n        generation_config, model_kwargs = model._prepare_generation_config(\n            None, do_sample=True\n            # change this to modify generate config\n            #top_k=5, top_p=0.85, temperature=0.1\n        )\n\n        logits_warper = tf_logits_warper(generation_config)\n\n        cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)\n        generated_ids = torch.zeros(\n            batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device\n        )\n        generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)\n        start_time = time.time()\n\n        chunk_start = 0\n        while chunk_start < seq_length:\n            chunk_end = min(chunk_start + chunk_size, seq_length)\n            if past_key_values != None:\n                past_key_values.cur_idx=cache_position[chunk_start:chunk_end]\n            logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)\n            chunk_start += chunk_size\n\n        next_token_scores = logits_warper(inputs, logits[:, -1, :])\n        if generation_config.do_sample:\n            probs = nn.functional.softmax(next_token_scores, dim=-1)\n            next_token = torch.multinomial(probs, num_samples=1).squeeze(1)\n        else:\n            next_token = torch.argmax(next_token_scores, dim=-1)\n            \n        # decoded_first = tokenizer.decode(next_token)\n        # print(f\"\\n[DEBUG] first token id={next_token.item()} decoded='{decoded_first}'\\n\")\n\n        first_token_time = time.time() - start_time\n        \n        if use_flashinfer_mla:\n            MLAWrapperSingleton.reset_buffer()\n\n        prefill_count = seq_length\n        prefill_time = first_token_time\n        if force_think:\n            print(\"<think>\")\n        print(stream.put(next_token.item()), end=\"\", flush=True)\n        # stream.put(next_token.item())\n  
      generated_ids[:, seq_length] = next_token\n        tokens.append(int(next_token))\n        inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)\n        cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.int32)\n        position_ids = cache_position.unsqueeze(0)\n        seq_length += 1\n        \n        cuda_graph_runner = None\n            \n        start_time = time.time()\n        for i in range(1, max_new_tokens):\n            if use_flashinfer_mla:\n                MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,None,\n                                             num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,\n                                             model.model.layers[0].self_attn.softmax_scale, torch.bfloat16, torch.bfloat16)\n            global warm_uped\n            if use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):\n                warm_uped = True\n                cuda_graph_runner = CUDAGraphRunner()\n                cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)\n            next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph).to(torch_device)\n            inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)\n            generated_ids[:, cache_position] = next_token.int()\n            tokens.append(int(next_token))\n            seq_length += 1\n            \n            if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>':\n                # print(stream.end(), end=\"\", flush=True)\n                stream.end()\n                break\n            else:\n                print(stream.put(next_token.item()), end=\"\", flush=True)\n   
             # stream.put(next_token.item())\n            cache_position += 1\n            position_ids = cache_position.unsqueeze(0)\n        \n\n    total_time = time.time() - start_time\n    tokens_generated = len(tokens)\n    tokens_per_second = tokens_generated / total_time\n\n    print(\"\")\n\n    print(f\"prompt eval count:    {prefill_count} token(s)\")\n    print(f\"prompt eval duration: {prefill_time}s\")\n    print(f\"prompt eval rate:     {prefill_count/prefill_time} tokens/s\")\n    print(f\"eval count:           {tokens_generated} token(s)\")\n    print(f\"eval duration:        {total_time}s\")\n    print(f\"eval rate:            {tokens_per_second} tokens/s\")\n\n    return tokens\n\ndef prefill_and_generate_capture(\n    model, tokenizer, inputs,\n    max_new_tokens=10000, use_cuda_graph: bool = True,\n    mode='normal', force_think: bool = False, chunk_size=16384,\n    use_flashinfer_mla=False, num_heads=None,\n    head_dim_ckv=None, head_dim_kpe=None, q_head_dim=None,\n    echo_stream: bool = True,\n):\n    \"\"\"\n    When echo_stream=False, nothing is printed to the terminal; the output is only written to the return value.\n    \"\"\"\n    import os, time, torch, torch.nn as nn\n    os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n    torch._dynamo.config.suppress_errors = True\n    batch_size, seq_length = inputs.shape\n    device_map = model.gguf_loader.tensor_device_map\n    torch_device = get_device('model.layers.0.self_attn', device_map)\n    torch_device = torch_device_mapping.get(torch_device, torch_device)\n    inputs = inputs.to(torch_device)\n    all_cuda_device = get_all_used_cuda_device(device_map)\n    tokens = []\n\n    def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph: bool = True):\n        if cuda_graph_runner is None:\n            use_cuda_graph = False\n        if use_cuda_graph:\n            logits = cuda_graph_runner(cur_token, position_ids, cache_position)\n        else:\n            # 
custom_stream = torch.cuda.Stream()\n            if torch.cuda.is_available():\n                torch.cuda.set_device(torch_device)\n            elif torch.xpu.is_available():\n                torch.xpu.set_device(torch_device)\n            else:\n                raise RuntimeError(f\"The device: {torch_device} is not available\")\n            inputs_embeds = model.model.embed_tokens(cur_token.to(\"cpu\")).to(torch_device)\n            # with torch.cuda.stream(custom_stream):\n            logits=model(inputs_embeds=inputs_embeds,\n                        position_ids=position_ids,\n                        cache_position=cache_position,\n                        past_key_values=past_key_values,\n                        return_dict=False, use_cache=True)[0]\n        if past_key_values != None and isinstance(past_key_values, StaticCache):\n            past_key_values.change_seq_length(1)\n        sync_all_device(all_cuda_device)\n        # print(logits)\n        next_token_scores = logits_warper(inputs, logits[:, -1, :])\n        if generation_config.do_sample:\n            probs = nn.functional.softmax(next_token_scores, dim=-1)\n            next_token = torch.multinomial(probs, num_samples=1).squeeze(1)\n        else:\n            next_token = torch.argmax(next_token_scores, dim=-1)\n        return next_token\n    \n    # TODO: use CUDA Graph for chunk prefill, may get small improvement\n    def chunk_prefill(inputs, cache_position, past_key_values):\n        if mode == \"long_context\":\n            inputs_embeds = model.model.embed_tokens(inputs.to(\"cpu\"))\n        else:\n            inputs_embeds = model.model.embed_tokens(inputs.to(\"cpu\")).to(torch_device)\n        if use_flashinfer_mla:\n            MLAWrapperSingleton.update_buffer(past_key_values.max_pages)\n            MLAWrapperSingleton.need_plan_all()\n            \n        logits = model(\n            inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, 
return_dict=False, use_cache=True\n        )[0][:,-1,:].unsqueeze(0).clone().to(torch_device)\n        \n        return logits\n\n    if torch.cuda.is_available():\n        torch.cuda.set_device(torch_device)\n    elif torch.xpu.is_available():\n        torch.xpu.set_device(torch_device)\n    else:\n        raise RuntimeError(f\"The device: {torch_device} is not available\")\n\n    with torch.no_grad():\n        stream = SilentCaptureStreamer(tokenizer)\n\n        if torch.xpu.is_available():\n            from ipex_llm.transformers.kv import DynamicUnbalancedFp8Cache, DynamicNormalCache\n            if model.config.architectures[0] in [\"DeepseekV3ForCausalLM\", \"DeepseekV2ForCausalLM\"]:\n                past_key_values = DynamicUnbalancedFp8Cache.from_legacy_cache(None)\n            else:\n                past_key_values = DynamicNormalCache.from_legacy_cache(None)\n        elif mode != 'long_context':\n            past_key_values = StaticCache(\n                config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = device_map, dtype = model.dtype\n            )\n        else:\n            past_key_values = None\n        \n        generation_config, model_kwargs = model._prepare_generation_config(\n            None, do_sample=True\n            # change this to modify generate config\n            #top_k=5, top_p=0.85, temperature=0.1\n        )\n\n        logits_warper = tf_logits_warper(generation_config)\n\n        cache_position = torch.arange(seq_length, device=torch_device, dtype=torch.int32)\n        generated_ids = torch.zeros(\n            batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device\n        )\n        generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)\n        start_time = time.time()\n\n        chunk_start = 0\n        while chunk_start < seq_length:\n            chunk_end = min(chunk_start + chunk_size, seq_length)\n            if past_key_values != 
None:\n                past_key_values.cur_idx=cache_position[chunk_start:chunk_end]\n            logits = chunk_prefill(inputs[:, chunk_start:chunk_end], cache_position[chunk_start:chunk_end], past_key_values)\n            chunk_start += chunk_size\n\n        next_token_scores = logits_warper(inputs, logits[:, -1, :])\n        if generation_config.do_sample:\n            probs = nn.functional.softmax(next_token_scores, dim=-1)\n            next_token = torch.multinomial(probs, num_samples=1).squeeze(1)\n        else:\n            next_token = torch.argmax(next_token_scores, dim=-1)\n            \n        # decoded_first = tokenizer.decode(next_token)\n        # print(f\"\\n[DEBUG] first token id={next_token.item()} decoded='{decoded_first}'\\n\")\n\n        first_token_time = time.time() - start_time\n        \n        if use_flashinfer_mla:\n            MLAWrapperSingleton.reset_buffer()\n\n        prefill_count = seq_length\n        prefill_time = first_token_time\n        if force_think:\n            print(\"<think>\")\n        print(stream.put(next_token.item()), end=\"\", flush=True)\n        generated_ids[:, seq_length] = next_token\n        tokens.append(int(next_token))\n        inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)\n        cache_position = torch.tensor([seq_length], device=torch_device, dtype=torch.int32)\n        position_ids = cache_position.unsqueeze(0)\n        seq_length += 1\n        \n        cuda_graph_runner = None\n            \n        start_time = time.time()\n        for i in range(1, max_new_tokens):\n            if use_flashinfer_mla:\n                MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,None,\n                                             num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,\n                                             model.model.layers[0].self_attn.softmax_scale, torch.bfloat16, torch.bfloat16)\n            global warm_uped\n            if 
use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):\n                warm_uped = True\n                cuda_graph_runner = CUDAGraphRunner()\n                cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)\n            next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, logits_warper, generation_config, use_cuda_graph).to(torch_device)\n            inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)\n            generated_ids[:, cache_position] = next_token.int()\n            tokens.append(int(next_token))\n            seq_length += 1\n            \n            if next_token[0].item() == tokenizer.eos_token_id or tokenizer.decode(next_token.tolist()) == '<|im_end|>':\n                print(stream.end(), end=\"\", flush=True)\n                break\n            else:\n                print(stream.put(next_token.item()), end=\"\", flush=True)\n            cache_position += 1\n            position_ids = cache_position.unsqueeze(0)\n\n        stream.end()\n        return stream.getvalue()\n"
  },
  {
    "path": "kt-sft/ktransformers/util/vendors.py",
    "content": "from __future__ import annotations\n\nfrom enum import IntEnum, auto\nfrom typing import Optional, Union, List\nimport torch\n\nclass GPUVendor(IntEnum):\n    NVIDIA = auto()\n    AMD = auto()\n    MooreThreads = auto()\n    MetaX = auto()\n    MUSA = auto()\n    Unknown = auto()\n\nclass DeviceManager:\n    \"\"\"\n    Device manager that provides a unified interface for handling different GPU vendors\n    \"\"\"\n    def __init__(self):\n        self.gpu_vendor = self._detect_gpu_vendor()\n        self.available_devices = self._get_available_devices()\n    \n    def _detect_gpu_vendor(self) -> GPUVendor:\n        \"\"\"Detect GPU vendor type\"\"\"\n        if not torch.cuda.is_available():\n            # Check MUSA availability (assuming a musa module exists)\n            try:\n                import musa\n                if musa.is_available():\n                    return GPUVendor.MUSA\n            except (ImportError, AttributeError):\n                pass\n            \n            return GPUVendor.Unknown\n        \n        device_name = torch.cuda.get_device_name(0).lower()\n        \n        if any(name in device_name for name in [\"nvidia\", \"geforce\", \"quadro\", \"tesla\", \"titan\", \"rtx\", \"gtx\"]):\n            return GPUVendor.NVIDIA\n        elif any(name in device_name for name in [\"amd\", \"radeon\", \"rx\", \"vega\", \"instinct\", \"firepro\", \"mi\"]):\n            return GPUVendor.AMD\n        elif any(name in device_name for name in [\"mthreads\", \"moore\", \"mtt\"]):\n            return GPUVendor.MooreThreads\n        elif any(name in device_name for name in [\"metax\", \"meta\"]):\n            return GPUVendor.MetaX\n        elif \"musa\" in device_name:\n            return GPUVendor.MUSA\n        \n        # Backend check\n        try:\n            if hasattr(torch.version, 'hip') and torch.version.hip is not None:\n                return GPUVendor.AMD\n            elif hasattr(torch.version, 'cuda') and 
torch.version.cuda is not None:\n                return GPUVendor.NVIDIA\n        except:\n            pass\n            \n        return GPUVendor.Unknown\n    \n    def _get_available_devices(self) -> List[int]:\n        \"\"\"Get list of available device indices\"\"\"\n        devices = []\n        \n        if self.gpu_vendor == GPUVendor.NVIDIA or self.gpu_vendor == GPUVendor.AMD:\n            devices = list(range(torch.cuda.device_count()))\n        elif self.gpu_vendor == GPUVendor.MUSA:\n            try:\n                import musa\n                devices = list(range(musa.device_count()))\n            except (ImportError, AttributeError):\n                pass\n            \n        return devices\n    \n    def get_device_str(self, device_id: Union[int, str]) -> str:\n        \"\"\"\n        Get device string for the given device ID\n        \n        Args:\n            device_id: Device index (0, 1, 2, etc.), -1 for CPU, or \"cpu\" string\n            \n        Returns:\n            Device string representation (e.g., \"cuda:0\", \"musa:1\", \"cpu\")\n        \"\"\"\n        if device_id == -1 or device_id == \"cpu\":\n            return \"cpu\"\n            \n        if isinstance(device_id, int):\n            if self.gpu_vendor == GPUVendor.NVIDIA or self.gpu_vendor == GPUVendor.AMD:\n                if device_id < torch.cuda.device_count():\n                    return f\"cuda:{device_id}\"\n            elif self.gpu_vendor == GPUVendor.MUSA:\n                try:\n                    import musa\n                    if device_id < musa.device_count():\n                        return f\"musa:{device_id}\"\n                except (ImportError, AttributeError):\n                    pass\n        \n        return \"cpu\"\n    \n    def to_torch_device(self, device_id: Union[int, str] = 0) -> torch.device:\n        \"\"\"\n        Convert device ID to torch.device object\n        \n        Args:\n            device_id: Device index (0, 1, 2, etc.), -1 
for CPU, or \"cpu\" string\n            \n        Returns:\n            torch.device object\n        \"\"\"\n        device_str = self.get_device_str(device_id)\n        \n        # Handle MUSA device\n        if device_str.startswith(\"musa:\"):\n            try:\n                import musa\n                index = int(device_str.split(\":\")[-1])\n                return musa.device(index)\n            except (ImportError, ValueError, AttributeError):\n                return torch.device(\"cpu\")\n        \n        # Standard PyTorch device\n        return torch.device(device_str)\n    \n    def move_tensor_to_device(self, tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor:\n        \"\"\"\n        Move tensor to specified device\n        \n        Args:\n            tensor: PyTorch tensor to move\n            device_id: Device index (0, 1, 2, etc.), -1 for CPU, or \"cpu\" string\n            \n        Returns:\n            Tensor moved to the specified device\n        \"\"\"\n        device = self.to_torch_device(device_id)\n        return tensor.to(device)\n    \n    def is_available(self, index: int = 0) -> bool:\n        \"\"\"\n        Check if device at specified index is available\n        \n        Args:\n            index: Device index to check\n            \n        Returns:\n            True if the device is available, False otherwise\n        \"\"\"\n        if index < 0:\n            return True  # CPU is always available\n            \n        return index in self.available_devices\n    \n    def get_all_devices(self) -> List[int]:\n        \"\"\"\n        Get all available device indices\n        \n        Returns:\n            List of available device indices (0, 1, 2, etc.)\n        \"\"\"\n        return self.available_devices\n\n# Create global device manager instance\ndevice_manager = DeviceManager()\n\n# Convenience functions\ndef get_device(device_id: Union[int, str] = 0) -> torch.device:\n    \"\"\"\n    Get torch.device 
object for the specified device ID\n    \n    Args:\n        device_id: Device index (0, 1, 2, etc.), -1 for CPU, or \"cpu\" string\n        \n    Returns:\n        torch.device object\n    \"\"\"\n    return device_manager.to_torch_device(device_id)\n\ndef to_device(tensor: torch.Tensor, device_id: Union[int, str] = 0) -> torch.Tensor:\n    \"\"\"\n    Move tensor to specified device\n    \n    Args:\n        tensor: PyTorch tensor to move\n        device_id: Device index (0, 1, 2, etc.), -1 for CPU, or \"cpu\" string\n        \n    Returns:\n        Tensor moved to the specified device\n    \"\"\"\n    return device_manager.move_tensor_to_device(tensor, device_id)\n\n# Get devices\ncpu_device = get_device(-1)        # CPU using index -1\ncpu_device2 = get_device(\"cpu\")    # CPU using string \"cpu\"\ngpu0 = get_device(0)               # First GPU\n\n# Move tensors\nx = torch.randn(3, 3)\nx_gpu = to_device(x, 0)            # Move to first GPU\nx_cpu1 = to_device(x, -1)          # Move to CPU using index -1\nx_cpu2 = to_device(x, \"cpu\")       # Move to CPU using string \"cpu\""
  },
  {
    "path": "kt-sft/ktransformers/util/weight_loader.py",
    "content": "from abc import ABC, abstractmethod\nimport os\nimport struct\nimport warnings\nimport torch\nimport numpy as np\nfrom safetensors import safe_open\nfrom typing import Dict, Any, Optional, Union\n\nclass ModelLoader(ABC):\n    \"\"\"\n    Abstract base class for model loaders.\n    Defines the interface that all model loaders must implement.\n    \"\"\"\n    \n    @abstractmethod\n    def load_tensor(self, name: str, device: str = \"cpu\") -> torch.Tensor:\n        \"\"\"\n        Load a tensor by name.\n        \n        Args:\n            name: Name of the tensor to load\n            device: Device to load the tensor to\n            \n        Returns:\n            The loaded tensor\n        \"\"\"\n        pass\n    \n    @classmethod\n    @abstractmethod\n    def supports_format(cls, path: str) -> bool:\n        \"\"\"\n        Check if this loader supports the given path format.\n        \n        Args:\n            path: Path to check\n            \n        Returns:\n            True if this loader supports the given path, False otherwise\n        \"\"\"\n        pass\n\n\nclass SafeTensorLoader(ModelLoader):\n    \"\"\"\n    Loader for SafeTensor format models.\n    \"\"\"\n    \n    def __init__(self, path: str):\n        \"\"\"\n        Initialize the SafeTensor loader.\n        \n        Args:\n            path: Path to the model directory or file\n        \"\"\"\n        self.tensor_file_map = {}  # Maps tensor names to file paths\n        self.file_handle_map = {}  # Maps file names to file handles\n        self._load_tensor_file_map(path)\n    \n    def _load_tensor_file_map(self, path: str) -> None:\n        \"\"\"\n        Load the tensor file map from the given path.\n        \n        Args:\n            path: Path to the model directory or file\n        \"\"\"\n        # Normalize path to directory\n        if not os.path.exists(path):\n            raise FileNotFoundError(f\"Path not found: {path}\")\n        if os.path.isfile(path):\n            folder_path = 
os.path.dirname(path)\n        else:\n            folder_path = path\n\n        found_safetensor = False\n        for root, _, files in os.walk(folder_path):\n            files = sorted(files)\n            for file in files:\n                if file.endswith(\".safetensors\"):\n                    found_safetensor = True\n                    file_path = os.path.join(root, file)\n                    if file not in self.file_handle_map:\n                        try:\n                            handle = safe_open(file_path, framework=\"pt\")\n                            self.file_handle_map[file] = handle\n                        except Exception as e:\n                            print(f\"Error opening Safetensor file {file_path}: {e}\")\n                            continue\n\n                    f = self.file_handle_map.get(file)\n                    if f is None:\n                        continue\n                    try:\n                        for key in f.keys():\n                            self.tensor_file_map[key] = file\n                    except Exception as e:\n                        print(f\"Error reading Safetensor file {file_path}: {e}\")\n\n        if not found_safetensor:\n            # Not raising an error here allows for the factory to try other loaders\n            print(f\"No Safetensor files found in {folder_path}\")\n    \n    def load_tensor(self, name: str, device: str = \"cpu\") -> torch.Tensor:\n        \"\"\"\n        Load a tensor by name.\n        \n        Args:\n            name: Name of the tensor to load\n            device: Device to load the tensor to\n            \n        Returns:\n            The loaded tensor\n        \"\"\"\n        if name not in self.tensor_file_map:\n            raise KeyError(f\"Key {name} not found in Safetensor files\")\n        file = self.tensor_file_map[name]\n        f = self.file_handle_map.get(file)\n        if f is None:\n            raise FileNotFoundError(f\"File {file} not found in 
Safetensor files\")\n        tensor = f.get_tensor(name)\n        return tensor.to(device)\n    \n    def load_dequantized_tensor(self, name: str, device: str = \"cpu\") -> torch.Tensor:\n        \"\"\"\n        Load and dequantize a tensor.\n        \n        Args:\n            name: Name of the tensor to load\n            device: Device to load the tensor to\n            \n        Returns:\n            The dequantized tensor\n        \"\"\"\n        if name not in self.tensor_file_map:\n            raise KeyError(f\"Key {name} not found in Safetensor files\")\n        file = self.tensor_file_map[name]\n        f = self.file_handle_map.get(file)\n        if f is None:\n            raise FileNotFoundError(f\"File {file} not found in Safetensor files\")\n        tensor = f.get_tensor(name).to(device)\n        if name.endswith(\".weight\"):\n            if name[:-7] + \".weight_scale_inv\" in self.tensor_file_map:\n                weight_scale_inv = f.get_tensor(name[:-7] + \".weight_scale_inv\").to(device)\n                # Assuming weight_dequant function is imported\n                from ktransformers.ktransformers_ext.triton.fp8gemm import weight_dequant\n                tensor = weight_dequant(tensor, weight_scale_inv)\n        return tensor.to(device)\n    \n    def close_all_handles(self) -> None:\n        \"\"\"\n        Close all file handles.\n        \"\"\"\n        for handle in self.file_handle_map.values():\n            handle.close()\n        self.file_handle_map.clear()\n\n    @classmethod\n    def supports_format(cls, path: str) -> bool:\n        \"\"\"\n        Check if this loader supports the given path format.\n        \n        Args:\n            path: Path to check\n            \n        Returns:\n            True if safetensor files are found in the path, False otherwise\n        \"\"\"\n        # Normalize path to directory\n        if not os.path.exists(path):\n            return False\n        if os.path.isfile(path):\n            if 
path.endswith(\".safetensors\"):\n                return True\n            folder_path = os.path.dirname(path)\n        else:\n            folder_path = path\n            \n        # Check if any safetensor files exist in the folder\n        for root, _, files in os.walk(folder_path):\n            for file in files:\n                if file.endswith(\".safetensors\"):\n                    return True\n        return False\n\n\nclass GGUFLoader(ModelLoader):\n    \"\"\"\n    Loader for GGUF format models.\n    \"\"\"\n    \n    def __init__(self, path: str):\n        \"\"\"\n        Initialize the GGUF loader.\n        \n        Args:\n            path: Path to the model directory or file\n        \"\"\"\n        # Check if path exists\n        if not os.path.exists(path):\n            raise FileNotFoundError(f\"GGUF dir not found: {path}\")\n        if os.path.isfile(path):\n            self.gguf_path = os.path.dirname(path)\n        else:\n            self.gguf_path = path\n            \n        self.tensor_info = {}  # Stores tensor metadata\n        self.tensor_file_map = {}  # Maps tensor names to file paths\n        self.file_data_map = {}  # Maps file paths to memory-mapped data\n        self.gguf_file_meta = {}  # Stores GGUF metadata\n        \n        # For compatibility with the factory pattern\n        self.safetensor_loader = None\n        \n        # Scan all GGUF files in the directory\n        found_gguf = False\n        for root, _, files in os.walk(self.gguf_path):\n            for file in files:\n                if file.endswith(\".gguf\"):\n                    found_gguf = True\n                    file_path = os.path.join(root, file)\n                    with open(file_path, \"rb\") as f:\n                        self._load_gguf(f)\n                        if file_path not in self.file_data_map:\n                            self.file_data_map[file_path] = np.memmap(file_path, mode='r')\n        \n        if not found_gguf:\n            raise 
FileNotFoundError(f\"Cannot find any .gguf files in: {self.gguf_path}\")\n    \n    def _load_gguf(self, f) -> None:\n        \"\"\"\n        Load GGUF file metadata and tensor info.\n        \n        Args:\n            f: File handle of the GGUF file\n        \"\"\"\n        # Implementation should follow the original GGUFLoader._load_gguf\n        # This is a simplified version for illustration\n        f.seek(0)\n        assert f.read(4) == b'GGUF'\n        \n        # Read header\n        values = struct.unpack(\"<IQQ\", f.read(4+8+8))\n        version, n_tensors, n_kv = values\n        if version != 3:\n            warnings.warn(f\"Version {version} has never been tested, might not work\")\n\n        # Read key-value pairs\n        info = {}\n        for _ in range(n_kv):\n            name = self._read_value(f, 8)  # DATA_TYPES[\"string\"]\n            data_type = struct.unpack(\"<I\", f.read(4))[0]\n            info[name] = self._read_value(f, data_type)\n\n        # Read tensor info\n        tensor_info = {}\n        for _ in range(n_tensors):\n            name = self._read_value(f, 8)  # DATA_TYPES[\"string\"]\n            shape_len = self._read_value(f, 4)  # DATA_TYPES[\"uint32\"]\n            shape = [self._read_value(f, 10) for _ in range(shape_len)]  # DATA_TYPES[\"uint64\"]\n            ggml_type = self._read_value(f, 4)  # DATA_TYPES[\"uint32\"]\n            offset = self._read_value(f, 10)  # DATA_TYPES[\"uint64\"]\n            \n            # Additional tensor metadata would be calculated here\n            # For brevity, we're omitting the detailed tensor metadata calculation\n            tensor_info[name] = {\n                \"ggml_type\": ggml_type,\n                \"shape\": shape,\n                \"offset\": offset,\n                # ... 
other tensor metadata\n            }\n            \n        start = f.tell()\n        alignment = info.get(\"general.alignment\", 32)\n        \n        # Calculate actual file offsets\n        for t in tensor_info.values():\n            offset = start + t[\"offset\"]\n            offset += (alignment - offset % alignment) % alignment\n            t[\"offset\"] = offset\n            \n        # Update file maps\n        for name in tensor_info:\n            self.tensor_file_map[name] = f.name\n            \n        self.tensor_info.update(tensor_info)\n        self.gguf_file_meta.update(info)\n    \n    def _read_value(self, f, data_type) -> Any:\n        \"\"\"\n        Read a value from the file according to its data type.\n        \n        Args:\n            f: File handle\n            data_type: Type of data to read\n            \n        Returns:\n            The read value\n        \"\"\"\n        # Simplified implementation\n        # In a complete implementation, this would handle all data types\n        if data_type == 8:  # DATA_TYPES[\"string\"]\n            length = struct.unpack(\"<Q\", f.read(8))[0]\n            return f.read(length).decode(\"utf-8\")\n        elif data_type == 4:  # DATA_TYPES[\"uint32\"]\n            return struct.unpack(\"<I\", f.read(4))[0]\n        elif data_type == 10:  # DATA_TYPES[\"uint64\"]\n            return struct.unpack(\"<Q\", f.read(8))[0]\n        # ... 
handling for other data types\n        return None\n    \n    def load_tensor(self, name: str, device: str = \"cpu\") -> torch.Tensor:\n        \"\"\"\n        Load a tensor by name.\n        \n        Args:\n            name: Name of the tensor to load\n            device: Device to load the tensor to\n            \n        Returns:\n            The loaded tensor\n        \"\"\"\n        # This should call load_gguf_tensor with the appropriate parameters\n        return self.load_gguf_tensor(name, device)\n    \n    def load_gguf_tensor(self, name: str, device: str = \"cpu\", target_dtype = None) -> torch.Tensor:\n        \"\"\"\n        Load a GGUF tensor by name.\n        \n        Args:\n            name: Name of the tensor to load\n            device: Device to load the tensor to\n            target_dtype: Target data type for the tensor\n            \n        Returns:\n            The loaded tensor\n        \"\"\"\n        # Implementation would follow the original GGUFLoader.load_gguf_tensor\n        # This is a placeholder for illustration\n        if name not in self.tensor_info:\n            raise KeyError(f\"Tensor {name} not found\")\n            \n        # Actual implementation would dequantize the tensor data\n        # and return a torch.Tensor\n        return torch.zeros(1, device=device)  # Placeholder\n    \n    @classmethod\n    def supports_format(cls, path: str) -> bool:\n        \"\"\"\n        Check if this loader supports the given path format.\n        \n        Args:\n            path: Path to check\n            \n        Returns:\n            True if GGUF files are found in the path, False otherwise\n        \"\"\"\n        # Normalize path to directory\n        if not os.path.exists(path):\n            return False\n        if os.path.isfile(path):\n            return path.endswith(\".gguf\")\n        \n        # Check if any GGUF files exist in the folder\n        for root, _, files in os.walk(path):\n            for file in files:\n   
             if file.endswith(\".gguf\"):\n                    return True\n        return False"
  },
  {
    "path": "kt-sft/ktransformers/website/.browserslistrc",
    "content": "> 1%\nlast 2 versions\nnot dead\nnot ie 11\n"
  },
  {
    "path": "kt-sft/ktransformers/website/.eslintrc.js",
    "content": "module.exports = {\n  root: true,\n  env: {\n    node: true\n  },\n  'extends': [\n    'plugin:vue/vue3-essential',\n    'eslint:recommended',\n    '@vue/typescript/recommended'\n  ],\n  parserOptions: {\n    ecmaVersion: 2020\n  },\n  rules: {\n    'no-console': process.env.NODE_ENV === 'production' ? 'warn' : 'off',\n    'no-debugger': process.env.NODE_ENV === 'production' ? 'warn' : 'off'\n  },\n  overrides: [\n    {\n      files: [\n        '**/__tests__/*.{j,t}s?(x)',\n        '**/tests/unit/**/*.spec.{j,t}s?(x)'\n      ],\n      env: {\n        jest: true\n      }\n    }\n  ]\n}\n"
  },
  {
    "path": "kt-sft/ktransformers/website/.gitignore",
    "content": ".DS_Store\nnode_modules\n/dist\n\n\n# local env files\n.env.local\n.env.*.local\n\n# Log files\nnpm-debug.log*\nyarn-debug.log*\nyarn-error.log*\npnpm-debug.log*\n\n# Editor directories and files\n.idea\n.vscode\n*.suo\n*.ntvs*\n*.njsproj\n*.sln\n*.sw?\n"
  },
  {
    "path": "kt-sft/ktransformers/website/README.md",
    "content": "# \n\n## Project setup\n```\nnpm install\n```\n\n### Compiles and hot-reloads for development\n```\nnpm run serve\n```\n\n### Compiles and minifies for production\n```\nnpm run build\n```\n\n### Run your unit tests\n```\nnpm run test:unit\n```\n\n### Lints and fixes files\n```\nnpm run lint\n```\n\n### Customize configuration\nSee [Configuration Reference](https://cli.vuejs.org/config/).\n"
  },
  {
    "path": "kt-sft/ktransformers/website/config.d.ts",
    "content": "declare module '*.js' {\n    const config: {\n      apiUrl: string;\n      port:number;\n    };\n    export { config };\n  }"
  },
  {
    "path": "kt-sft/ktransformers/website/jest.config.js",
    "content": "module.exports = {\n  preset: '@vue/cli-plugin-unit-jest/presets/typescript'\n}\n"
  },
  {
    "path": "kt-sft/ktransformers/website/package.json",
    "content": "{\n  \"name\": \"\",\n  \"version\": \"\",\n  \"private\": true,\n  \"scripts\": {\n    \"serve\": \"vue-cli-service serve\",\n    \"build\": \"vue-cli-service build\",\n    \"test:unit\": \"vue-cli-service test:unit\",\n    \"lint\": \"vue-cli-service lint\"\n  },\n  \"dependencies\": {\n    \"@types/pdfjs-dist\": \"^2.10.378\",\n    \"@types/websocket\": \"^1.0.10\",\n    \"@vue/cli\": \"^5.0.8\",\n    \"ant-design-vue\": \"^4.2.1\",\n    \"apexcharts\": \"^3.49.1\",\n    \"axios\": \"^1.7.0\",\n    \"axios-extensions\": \"^3.1.6\",\n    \"better-scroll\": \"^2.5.1\",\n    \"element-plus\": \"^2.7.3\",\n    \"marked\": \"^12.0.2\",\n    \"marked-highlight\": \"^2.1.1\",\n    \"pdf-lib\": \"^1.17.1\",\n    \"pdfobject\": \"^2.3.0\",\n    \"v-clipboard\": \"^3.0.0-next.1\",\n    \"vue\": \"^3.4.27\",\n    \"vue-i18n\": \"^9.13.1\",\n    \"vue-pdf\": \"^4.3.0\",\n    \"vue-router\": \"^4.0.3\",\n    \"vue3-apexcharts\": \"^1.5.3\",\n    \"vuex\": \"^4.0.0\",\n    \"webpack\": \"^5.91.0\",\n    \"webpack-cli\": \"^5.1.4\",\n    \"websocket\": \"^1.0.35\"\n  },\n  \"devDependencies\": {\n    \"@types/jest\": \"^27.0.1\",\n    \"@types/pdfobject\": \"^2.2.5\",\n    \"@typescript-eslint/eslint-plugin\": \"^5.4.0\",\n    \"@typescript-eslint/parser\": \"^5.4.0\",\n    \"@vue/cli-plugin-eslint\": \"~5.0.0\",\n    \"@vue/cli-plugin-router\": \"~5.0.0\",\n    \"@vue/cli-plugin-typescript\": \"~5.0.0\",\n    \"@vue/cli-plugin-unit-jest\": \"~5.0.0\",\n    \"@vue/cli-plugin-vuex\": \"~5.0.0\",\n    \"@vue/cli-service\": \"~5.0.0\",\n    \"@vue/eslint-config-typescript\": \"^9.1.0\",\n    \"@vue/test-utils\": \"^2.0.0-0\",\n    \"@vue/vue3-jest\": \"^27.0.0-alpha.1\",\n    \"babel-jest\": \"^27.0.6\",\n    \"eslint\": \"^7.32.0\",\n    \"eslint-plugin-vue\": \"^8.0.3\",\n    \"jest\": \"^27.0.5\",\n    \"stylus\": \"^0.55.0\",\n    \"stylus-loader\": \"^6.1.0\",\n    \"ts-jest\": \"^27.0.4\",\n    \"typescript\": \"~4.5.5\"\n  },\n  \"_id\": \"@\",\n  
\"readme\": \"ERROR: No README data found!\"\n}\n"
  },
  {
    "path": "kt-sft/ktransformers/website/public/config.js",
    "content": "window.configWeb = {\n    apiUrl: 'http://119.255.238.12:15670/v1',\n    port: 8080,\n  };"
  },
  {
    "path": "kt-sft/ktransformers/website/public/css/reset.css",
    "content": "html, body, div, span, applet, object, iframe,\nh1, h2, h3, h4, h5, h6, p, blockquote, pre,\na, abbr, acronym, address, big, cite, code,\ndel, dfn, em, img, ins, kbd, q, s, samp,\nsmall, strike, strong, sub, sup, tt, var,\nb, u, i, center,\ndl, dt, dd, ol, ul, li,\nfieldset, form, label, legend,textarea,\ntable, caption, tbody, tfoot, thead, tr, th, td,\narticle, aside, canvas, details, embed,\nfigure, figcaption, footer, header, hgroup,\nmenu, nav, output, ruby, section, summary,\ntime, mark, audio, video {\n    margin: 0;\n    padding: 0;\n    border: 0;\n    font-size: 100%;\n    *font: inherit;\n    font-family: Arial, Microsoft YaHei, SimHei, Tahoma, sans-serif !important;\n    vertical-align: baseline;\n}\n/* HTML5 display-role reset for older browsers */\narticle, aside, details, figcaption, figure,\nfooter, header, hgroup, menu, nav, section {\n    display: block;\n}\nbody {\n    line-height: 1;\n    -webkit-text-size-adjust: 100%!important;\n    margin: 0;\n}\nhtml,body {\n    height: 100%;\n    width: 100%;\n    overflow: hidden;\n}\nol, ul {\n    list-style: none;\n}\nblockquote, q {\n    quotes: none;\n}\nblockquote:before, blockquote:after,\nq:before, q:after {\n    content: '';\n    content: none;\n}\ntable {\n    border-collapse: collapse;\n    border-spacing: 0;\n}\n\n.clearfix:before,\n.clearfix:after {\n    content:\"\";\n    display:table\n}\n.clearfix:after {\n    clear:both\n}\n\n/* Truncate overflowing text with an ellipsis */\n.ellipsis{\n    overflow: hidden;\n    text-overflow: ellipsis;\n    white-space: nowrap;\n}\n"
  },
  {
    "path": "kt-sft/ktransformers/website/public/index.html",
    "content": "<!DOCTYPE html>\n<html lang=\"\">\n  <head>\n    <meta charset=\"utf-8\">\n    <meta http-equiv=\"X-UA-Compatible\" content=\"IE=edge\">\n    <meta name=\"viewport\" content=\"width=device-width,initial-scale=1.0,maximum-scale=1.0,minimum-scale=1.0,user-scalable=no\">\n    <script src=\"./config.js\"></script>\n    <link rel=\"icon\" href=\"./balck.ico\" />\n    <link type=\"text/css\" rel=\"stylesheet\" href=\"<%= BASE_URL %>/css/reset.css\">\n    <title>KTransformers</title>\n  </head>\n  <body onselectstart='return false' onselect='return false'>\n    <noscript>\n      <strong>We're sorry but <%= htmlWebpackPlugin.options.title %> doesn't work properly without JavaScript enabled. Please enable it to continue.</strong>\n    </noscript>\n    <div id=\"app\"></div>\n    <!-- built files will be auto injected -->\n  </body>\n</html>\n"
  },
  {
    "path": "kt-sft/ktransformers/website/src/App.vue",
    "content": "<template>\n  <div class=\"app-container\" @contextmenu.prevent.stop=\"\">\n    <keep-alive>\n      <router-view/>\n    </keep-alive>\n  </div>\n</template>\n\n<script setup lang=\"ts\">\n</script>\n\n<style lang=\"stylus\">\n  @import \"assets/iconfont/iconfont.css\"\n  #app\n  .app-container\n    width: 100%\n    height: 100%\n    position: relative\n</style>"
  },
  {
    "path": "kt-sft/ktransformers/website/src/api/api-client.ts",
    "content": "import axios, { AxiosInstance } from 'axios';\nimport {baseURL} from '@/conf/config';\nconst apiClient: AxiosInstance = axios.create({\n    baseURL: baseURL,\n    // baseURL: '/api',\n    headers: {\n        'Content-Type': 'application/json',\n    },\n    withCredentials: true,\n});\nexport default apiClient;\n"
  },
  {
    "path": "kt-sft/ktransformers/website/src/api/assistant.ts",
    "content": "import apiClient from './api-client';\nimport { IAssistant,IDeleteResult, IAssistantWithStatus } from '../utils/types';\nfunction filterAndConvert(\n    assistantsWithStatus: IAssistantWithStatus[],\n    statusCondition: string\n  ): IAssistant[] {\n    return assistantsWithStatus\n      .filter((assistant) => assistant.build_status.status === statusCondition)\n      .map(({ build_status, ...rest }) => rest);\n  }\n\ninterface IAssistantData {\n    model: string;\n    prefix_system_prompt?: string;\n    suffix_system_prompt?: string;\n    name?: string;\n    description?: string;\n    tools?: any[];\n    tool_resources?: object;\n    metadata?:{[key:string]:any}\n    top_p?: number;\n    temperature?: number;\n    response_format?: string;\n    instructions?: string;\n}\n\nexport const createAssistant = async (data: IAssistantData): Promise<IAssistant> => {\n    const assistant_data: {\n        model: string;\n        instructions?: string;\n        name?: string;\n        description?: string;\n        tools?: any[];\n        tool_resources?: object;\n        metadata?:{[key:string]:any}\n        top_p?: number;\n        temperature?: number;\n        response_format?: string;\n    } = {\n        model: data.model\n    };\n\n    if (data.prefix_system_prompt) {\n        assistant_data.instructions = data.prefix_system_prompt;\n    }\n    if (data.suffix_system_prompt) {\n        assistant_data.instructions = data.suffix_system_prompt;\n    }\n    if (data.name) {\n        assistant_data.name = data.name;\n    }\n    if (data.description) {\n        assistant_data.description = data.description;\n    }\n    if (data.tools) {\n        assistant_data.tools = data.tools;\n    }\n    if (data.tool_resources) {\n        assistant_data.tool_resources = data.tool_resources;\n    }\n    if (data.metadata) {\n        assistant_data.metadata = data.metadata\n    }\n    if (typeof data.top_p !== 'undefined') {\n        assistant_data.top_p = data.top_p;\n    
}\n    if (typeof data.temperature !== 'undefined') {\n        assistant_data.temperature = data.temperature;\n    }\n    if (data.response_format) {\n        assistant_data.response_format = data.response_format;\n    }\n    if (data.instructions) {\n        assistant_data.instructions = data.instructions;\n    }\n    console.log(assistant_data)\n    const response = await apiClient.post<IAssistant>(\n        '/assistants/',\n        assistant_data\n    );\n    console.log(\"response\", response)\n    return response.data;\n};\n\n\nexport const listAssistants = async (\n    limit?: number,\n    order?: string,\n    after?: string,\n    before?: string,\n    run_id?: string,\n): Promise<IAssistant[]> => {\n    const params: {\n        limit?: number,\n        order?: string,\n        after?: string,\n        before?: string,\n        run_id?: string\n    } = {};\n\n    if (typeof limit !== 'undefined') {\n        params.limit = limit;\n    }\n    if (typeof order !== 'undefined') {\n        params.order = order;\n    }\n    if (typeof after !== 'undefined') {\n        params.after = after;\n    }\n    if (typeof before !== 'undefined') {\n        params.before = before;\n    }\n    if (typeof run_id !== 'undefined') {\n        params.run_id = run_id;\n    }\n    const response = await apiClient.get<IAssistantWithStatus[]>('/assistants/status', {\n        params\n    });\n    let tmp = response.data\n    let result = [] as IAssistant[]\n    const filteredAssistants = filterAndConvert(tmp, 'completed');\n    return filteredAssistants\n};\n\nexport const getAssistant = async (\n    assistant_id: string\n): Promise<IAssistant> => {\n    const response = await apiClient.get<IAssistant>(`/assistants/${assistant_id}`);\n    return response.data;\n}\n\nexport const deleteAssistant = async (\n    assistant_id: string\n): Promise<IDeleteResult> => {\n    const response = await apiClient.delete<IDeleteResult>(`/assistants/${assistant_id}`);\n    return 
response.data;\n}\n\nexport const getRelatedThreadId = async (\n    assistant_id: string\n): Promise<string[]> => {\n    const response = await apiClient.get<string[]>(`/assistants/${assistant_id}/related_thread`);\n    return response.data;\n}\n\nexport const listAssistantsWithStatus = async (\n    limit?: number,\n    order?: string,\n    after?: string,\n    before?: string,\n    run_id?: string,\n): Promise<IAssistantWithStatus[]> => {\n    const params: {\n        limit?: number,\n        order?: string,\n        after?: string,\n        before?: string,\n        run_id?: string\n    } = {};\n\n    if (typeof limit !== 'undefined') {\n        params.limit = limit;\n    }\n    if (typeof order !== 'undefined') {\n        params.order = order;\n    }\n    if (typeof after !== 'undefined') {\n        params.after = after;\n    }\n    if (typeof before !== 'undefined') {\n        params.before = before;\n    }\n    if (typeof run_id !== 'undefined') {\n        params.run_id = run_id;\n    }\n    console.log(params)\n    const response = await apiClient.get<IAssistantWithStatus[]>('/assistants/status', {\n        params\n    });\n\n    return response.data;\n};\n\n\n"
  },
  {
    "path": "kt-sft/ktransformers/website/src/api/message.ts",
    "content": "import apiClient from './api-client';\nimport { IMessage,IDeleteResult } from '../utils/types';\n\nexport const createMessage = async (\n    thread_id: string,\n    content: string,\n    role?: string,\n    attachments?: any[],\n    metadata?:{[key:string]:any}\n): Promise<IMessage> => {\n    const message_data: {\n        content: string;\n        role?: string;\n        attachments?: any[];\n        metadata?:{[key:string]:any}\n    } = {\n        content,\n    };\n\n    if (metadata) {\n        message_data.metadata = metadata;\n    }\n    if (role) {\n        message_data.role = role;\n    }\n    if (attachments) {\n        message_data.attachments = attachments;\n    }\n    const response = await apiClient.post<IMessage>(`/threads/${thread_id}/messages`, message_data);\n    return response.data;\n};\n\n\nexport const listMessages = async (\n    thread_id: string,\n    limit?: number,\n    order?: string,\n    after?: string,\n    before?: string,\n    run_id?: string,\n): Promise<IMessage[]> => {\n    const params: {\n        limit?: number,\n        order?: string,\n        after?: string,\n        before?: string,\n        run_id?: string\n    } = {};\n\n    if (typeof limit !== 'undefined') {\n        params.limit = limit;\n    }\n    if (typeof order !== 'undefined') {\n        params.order = order;\n    }\n    if (typeof after !== 'undefined') {\n        params.after = after;\n    }\n    if (typeof before !== 'undefined') {\n        params.before = before;\n    }\n    if (typeof run_id !== 'undefined') {\n        params.run_id = run_id;\n    }\n\n    const response = await apiClient.get<IMessage[]>(`/threads/${thread_id}/messages`, {\n        params\n    });\n\n    return response.data;\n};\nexport const deleteMessage = async(thread_id:string, message_id:string): Promise<IDeleteResult> => {\n    const response = await apiClient.delete<IDeleteResult>(`/threads/${thread_id}/messages/${message_id}`);\n    return response.data;\n}\n"
  },
  {
    "path": "kt-sft/ktransformers/website/src/api/run.ts",
    "content": "import apiClient from './api-client';\nimport { IRun } from '../utils/types';\nimport {baseURL} from '@/conf/config';\ninterface IRunData {\n    assistant_id: string;\n    model?: string;\n    instructions?: string;\n    additional_instructions?: string;\n    additional_messages?: any[];\n    tools?: any[];\n    metadata?: { [key: string]: any }\n    temperature?: number;\n    top_p?: number;\n    stream?: boolean;\n    max_prompt_tokens?: number;\n    max_completion_tokens?: number;\n    truncation_strategy?: object;\n    tool_choice?: string;\n    response_format?: string | object;\n}\n\n\nexport async function* createRun(\n    data: IRunData,\n    thread_id: string\n): AsyncGenerator<string> {\n    const run_data = {\n        ...data, \n        assistant_id: data.assistant_id, \n    };\n\n    const response = await fetch(`${baseURL}/threads/${thread_id}/runs`, {\n        method: 'POST',\n        headers: {\n            'Content-Type': 'application/json',\n        },\n        body: JSON.stringify(run_data),\n    });\n\n    if (!response.ok) {\n        throw new Error(`HTTP error! 
status: ${response.status}`);\n    }\n\n    if (!response.body) {\n        throw new Error('Response body is missing');\n    }\n    const reader = response.body.getReader();\n    const decoder = new TextDecoder();\n    let buffer = '';\n    try {\n        while (true) {\n            const { done, value } = await reader.read();\n            if (done) return;\n            buffer += decoder.decode(value, { stream: true });\n\n            let eventIndex = buffer.indexOf(\"\\n\\n\");\n            while (eventIndex !== -1) {\n                const event = buffer.slice(0, eventIndex);\n                buffer = buffer.slice(eventIndex + 2);\n                if (event.startsWith(\"event: thread.run.created\")) {\n                    const dataIndex = event.indexOf(\"data: \");\n                    if (dataIndex !== -1) {\n                        const datads = event.slice(39, 75)\n                        yield datads;\n                    }\n                } else if (event.startsWith(\"event: thread.message.delta\")) {\n                    const dataIndex = event.indexOf(\"data: \");\n                    if (dataIndex !== -1) {\n                        const data = JSON.parse(event.slice(dataIndex + 6));\n                        yield data.delta.content[0].text.value || '';\n                    }\n                } else if (event.startsWith(\"event: done\")) {\n                    return;\n                }\n\n                eventIndex = buffer.indexOf(\"\\n\\n\");\n            }\n        }\n    } catch (e) {\n\n        console.error('An error occurred while reading the response stream:', e);\n        // throw e; \n        return e\n    }\n}\n// Define the function that cancels a run\nexport async function cancelRun(threadId: string, runId: string){\n    const run_data = {\n        thread_id:threadId,\n        run_id:runId,\n    };\n    try {\n        const response = await fetch(`${baseURL}/threads/${threadId}/runs/${runId}/cancel`, {\n            method: 'POST',\n        });\n\n        if 
(!response.ok) {\n            throw new Error(`HTTP error! status: ${response.status}`);\n        }\n\n        return response;\n    } catch (error) {\n        console.error('An error occurred while cancelling the run:', error);\n        throw error;\n    }\n}"
  },
  {
    "path": "kt-sft/ktransformers/website/src/api/thread.ts",
    "content": "import apiClient from './api-client';\nimport { IThread, IMessage, IThreadAndMessageAndAssistant, IDeleteResult } from '../utils/types';\nexport const createThread = async (\n    message?: IMessage,\n    tool_resources?: object,\n    metadata?: { [key: string]: any }\n): Promise<IThread> => {\n    const thread_data: { message?: object, metadata?: { [key: string]: any } } = {};\n    if (message) {\n        thread_data.message = message;\n    }\n    if (metadata) {\n        thread_data.metadata = metadata;\n    }\n    const response = await apiClient.post<IThread>(\n        '/threads',\n        thread_data);\n    return response.data;\n};\n\nexport const listThreads = async (\n    limit?: number,\n    order?: string,\n): Promise<IThreadAndMessageAndAssistant[]> => {\n    const params: {\n        limit?: number,\n        order?: string,\n    } = { limit, order };\n    const response = await apiClient.get<IThreadAndMessageAndAssistant[]>('/threads', {\n        params\n    });\n\n    return response.data;\n};\n\nexport const deleteThread = async (\n    thread_id: string\n): Promise<IDeleteResult> => {\n    const response = await apiClient.delete<IDeleteResult>(`/threads/${thread_id}`);\n    return response.data;\n}\n\nexport const getThread = async (\n    thread_id: string\n): Promise<IThread> => {\n    const response = await apiClient.get<IThread>(`/threads/${thread_id}`);\n    return response.data;\n}"
  },
  {
    "path": "kt-sft/ktransformers/website/src/assets/css/mixins.styl",
    "content": "\n/*Define color variables*/\n$bg_gray_light_normal = #F9F9F9\n$bg_gray_light_hover = #E8E8E8\n$bg_gray_light_active = #E8E8E8\n\n$border_gray_light_normal = rgba(0, 0, 0, .15)\n$border_gray_light_hover = #8080FF\n\n$gray_20 = #333333\n$gray_40 = #585858\n$gray_50 = #7F7F7F\n$gray_60 = #9F9F9F\n$gray_70 = #BFBFBF\n$gray_80 = #DFDFDF\n$gray_85 = #F2F2F2\n$gray_90 = #F7F7F7\n\n$gray = #53525B\n$gray_dark = #42414a\n$gray_hover = #121212\n$gray_action = #6C757D\n\n$primary = #409eff\n$primary_hover = #428bca\n$primary_middle = #9DDDF9\n$primary_light = #D4F0FC\n\n$cyan = #66CCCC\n$cyan_hover = #46C2C2\n\n\n/*Define common modules*/\n$input-duration = .25s\ninput-border()\n  -webkit-transition: border-color ease-in-out $input-duration,-webkit-box-shadow ease-in-out $input-duration\n  -o-transition: border-color ease-in-out $input-duration,box-shadow ease-in-out $input-duration\n  transition: border-color ease-in-out $input-duration,box-shadow ease-in-out $input-duration\ninput-focus()\n  border-color: #66afe9\n  outline: 0\n  z-index: 100\n  -webkit-box-shadow: inset 0 1px 1px rgba(0,0,0,.075),0 0 8px rgba(102,175,233,.6)\n  box-shadow: inset 0 1px 1px rgba(0,0,0,.075),0 0 8px rgba(102,175,233,.6)\n\n\n/*Define common class*/\n.flex-column\n  display: -webkit-box\n  display: -webkit-flex\n  display: flex\n  box-sizing: border-box\n  -webkit-box-orient: vertical\n  -webkit-box-direction: normal\n  -webkit-flex-direction: column\n  flex-direction: column\n  height: 100%\n\n.flex-row\n  position: relative\n  display: -webkit-box\n  display: -ms-flexbox\n  display: flex\n  box-sizing: border-box\n  -webkit-box-align: center\n  -ms-flex-align: center\n  align-items: center\n\n.flex-unit\n  -webkit-box-flex: 1\n  -ms-flex: 1\n  flex: 1\n  // overflow: hidden\n\n.clearfix\n  &:after\n    clear: both\n    content: \"\\20\"\n    display: block\n    height: 0\n    visibility: hidden\n\na,a:hover\n  text-decoration:none\n\nbutton:focus\n  outline: none\n\n.btn\n  
display: inline-block\n  margin-bottom: 0\n  padding:0px 15px\n  font-size: 14px\n  height: 34px\n  line-height: 32px\n  float: left /* remove the whitespace between inline-block elements */\n  font-weight: normal\n  text-align: center\n  white-space: nowrap\n  vertical-align: middle\n  cursor: pointer\n  background-image: none\n  border-radius: 3px\n  -webkit-user-select: none\n  -moz-user-select: none\n  -ms-user-select: none\n  -o-user-select: none\n  user-select: none\n  &:hover\n    .dropdown-list\n      display: block\n  i\n    font-size: 16px\n  .text\n    float: right\n    margin-left: 3px\n\n.btn-gray\n  color: $gray_action\n  background-color: #FFFFFF\n  border: 1px solid $gray_action\n  &:not(.is-disabled):hover\n    color: #FFFFFF\n    background-color: $gray_action\n    border: 1px solid $gray_action\n\n.btn-primary\n  color: #FFFFFF\n  background-color: $primary\n  border: 1px solid $primary\n  &:not(.is-disabled):hover\n    color: #FFFFFF\n    background-color: $primary_hover\n    border: 1px solid $primary_hover\n\n.chat-box\n  position: relative\n  .chat-input\n    border: 1px solid $border_gray_light_normal\n    height: 48px\n    line-height: 48px\n    font-size: 16px\n    outline: 0\n    box-sizing: border-box\n    padding:0 30px 0 20px\n    color: #7F7F7F\n    width: 800px\n    border-radius: 12px\n    position: relative\n    &:focus\n      input-focus()\n  i\n    position: absolute\n    font-size: 26px\n    right: 13px\n    bottom:0px\n    color: $border_gray_light_normal\n    z-index: 100\n    cursor: pointer\n    &:hover\n      color: $border_gray_light_hover\n"
  },
  {
    "path": "kt-sft/ktransformers/website/src/assets/iconfont/demo.css",
    "content": "/* Logo 字体 */\n@font-face {\n  font-family: \"iconfont logo\";\n  src: url('https://at.alicdn.com/t/font_985780_km7mi63cihi.eot?t=1545807318834');\n  src: url('https://at.alicdn.com/t/font_985780_km7mi63cihi.eot?t=1545807318834#iefix') format('embedded-opentype'),\n    url('https://at.alicdn.com/t/font_985780_km7mi63cihi.woff?t=1545807318834') format('woff'),\n    url('https://at.alicdn.com/t/font_985780_km7mi63cihi.ttf?t=1545807318834') format('truetype'),\n    url('https://at.alicdn.com/t/font_985780_km7mi63cihi.svg?t=1545807318834#iconfont') format('svg');\n}\n\n.logo {\n  font-family: \"iconfont logo\";\n  font-size: 160px;\n  font-style: normal;\n  -webkit-font-smoothing: antialiased;\n  -moz-osx-font-smoothing: grayscale;\n}\n\n/* tabs */\n.nav-tabs {\n  position: relative;\n}\n\n.nav-tabs .nav-more {\n  position: absolute;\n  right: 0;\n  bottom: 0;\n  height: 42px;\n  line-height: 42px;\n  color: #666;\n}\n\n#tabs {\n  border-bottom: 1px solid #eee;\n}\n\n#tabs li {\n  cursor: pointer;\n  width: 100px;\n  height: 40px;\n  line-height: 40px;\n  text-align: center;\n  font-size: 16px;\n  border-bottom: 2px solid transparent;\n  position: relative;\n  z-index: 1;\n  margin-bottom: -1px;\n  color: #666;\n}\n\n\n#tabs .active {\n  border-bottom-color: #f00;\n  color: #222;\n}\n\n.tab-container .content {\n  display: none;\n}\n\n/* 页面布局 */\n.main {\n  padding: 30px 100px;\n  width: 960px;\n  margin: 0 auto;\n}\n\n.main .logo {\n  color: #333;\n  text-align: left;\n  margin-bottom: 30px;\n  line-height: 1;\n  height: 110px;\n  margin-top: -50px;\n  overflow: hidden;\n  *zoom: 1;\n}\n\n.main .logo a {\n  font-size: 160px;\n  color: #333;\n}\n\n.helps {\n  margin-top: 40px;\n}\n\n.helps pre {\n  padding: 20px;\n  margin: 10px 0;\n  border: solid 1px #e7e1cd;\n  background-color: #fffdef;\n  overflow: auto;\n}\n\n.icon_lists {\n  width: 100% !important;\n  overflow: hidden;\n  *zoom: 1;\n}\n\n.icon_lists li {\n  width: 100px;\n  margin-bottom: 10px;\n 
 margin-right: 20px;\n  text-align: center;\n  list-style: none !important;\n  cursor: default;\n}\n\n.icon_lists li .code-name {\n  line-height: 1.2;\n}\n\n.icon_lists .icon {\n  display: block;\n  height: 100px;\n  line-height: 100px;\n  font-size: 42px;\n  margin: 10px auto;\n  color: #333;\n  -webkit-transition: font-size 0.25s linear, width 0.25s linear;\n  -moz-transition: font-size 0.25s linear, width 0.25s linear;\n  transition: font-size 0.25s linear, width 0.25s linear;\n}\n\n.icon_lists .icon:hover {\n  font-size: 100px;\n}\n\n.icon_lists .svg-icon {\n  /* Change the icon size by setting font-size */\n  width: 1em;\n  /* Vertical alignment when the icon sits next to text */\n  vertical-align: -0.15em;\n  /* Change the SVG color/fill by setting color */\n  fill: currentColor;\n  /* In IE, the parts of path and stroke that overflow the viewBox are displayed;\n      normalize.css also includes this line */\n  overflow: hidden;\n}\n\n.icon_lists li .name,\n.icon_lists li .code-name {\n  color: #666;\n}\n\n/* markdown styles */\n.markdown {\n  color: #666;\n  font-size: 14px;\n  line-height: 1.8;\n}\n\n.highlight {\n  line-height: 1.5;\n}\n\n.markdown img {\n  vertical-align: middle;\n  max-width: 100%;\n}\n\n.markdown h1 {\n  color: #404040;\n  font-weight: 500;\n  line-height: 40px;\n  margin-bottom: 24px;\n}\n\n.markdown h2,\n.markdown h3,\n.markdown h4,\n.markdown h5,\n.markdown h6 {\n  color: #404040;\n  margin: 1.6em 0 0.6em 0;\n  font-weight: 500;\n  clear: both;\n}\n\n.markdown h1 {\n  font-size: 28px;\n}\n\n.markdown h2 {\n  font-size: 22px;\n}\n\n.markdown h3 {\n  font-size: 16px;\n}\n\n.markdown h4 {\n  font-size: 14px;\n}\n\n.markdown h5 {\n  font-size: 12px;\n}\n\n.markdown h6 {\n  font-size: 12px;\n}\n\n.markdown hr {\n  height: 1px;\n  border: 0;\n  background: #e9e9e9;\n  margin: 16px 0;\n  clear: both;\n}\n\n.markdown p {\n  margin: 1em 0;\n}\n\n.markdown>p,\n.markdown>blockquote,\n.markdown>.highlight,\n.markdown>ol,\n.markdown>ul {\n  width: 80%;\n}\n\n.markdown ul>li {\n  list-style: circle;\n}\n\n.markdown>ul li,\n.markdown blockquote ul>li {\n  margin-left: 20px;\n  
padding-left: 4px;\n}\n\n.markdown>ul li p,\n.markdown>ol li p {\n  margin: 0.6em 0;\n}\n\n.markdown ol>li {\n  list-style: decimal;\n}\n\n.markdown>ol li,\n.markdown blockquote ol>li {\n  margin-left: 20px;\n  padding-left: 4px;\n}\n\n.markdown code {\n  margin: 0 3px;\n  padding: 0 5px;\n  background: #eee;\n  border-radius: 3px;\n}\n\n.markdown strong,\n.markdown b {\n  font-weight: 600;\n}\n\n.markdown>table {\n  border-collapse: collapse;\n  border-spacing:0;\n  empty-cells: show;\n  border: 1px solid #e9e9e9;\n  width: 95%;\n  margin-bottom: 24px;\n}\n\n.markdown>table th {\n  white-space: nowrap;\n  color: #333;\n  font-weight: 600;\n}\n\n.markdown>table th,\n.markdown>table td {\n  border: 1px solid #e9e9e9;\n  padding: 8px 16px;\n  text-align: left;\n}\n\n.markdown>table th {\n  background: #F7F7F7;\n}\n\n.markdown blockquote {\n  font-size: 90%;\n  color: #999;\n  border-left: 4px solid #e9e9e9;\n  padding-left: 0.8em;\n  margin: 1em 0;\n}\n\n.markdown blockquote p {\n  margin: 0;\n}\n\n.markdown .anchor {\n  opacity: 0;\n  transition: opacity 0.3s ease;\n  margin-left: 8px;\n}\n\n.markdown .waiting {\n  color: #ccc;\n}\n\n.markdown h1:hover .anchor,\n.markdown h2:hover .anchor,\n.markdown h3:hover .anchor,\n.markdown h4:hover .anchor,\n.markdown h5:hover .anchor,\n.markdown h6:hover .anchor {\n  opacity: 1;\n  display: inline-block;\n}\n\n.markdown>br,\n.markdown>p>br {\n  clear: both;\n}\n\n\n.hljs {\n  display: block;\n  background: white;\n  padding: 0.5em;\n  color: #333333;\n  overflow-x: auto;\n}\n\n.hljs-comment,\n.hljs-meta {\n  color: #969896;\n}\n\n.hljs-string,\n.hljs-variable,\n.hljs-template-variable,\n.hljs-strong,\n.hljs-emphasis,\n.hljs-quote {\n  color: #df5000;\n}\n\n.hljs-keyword,\n.hljs-selector-tag,\n.hljs-type {\n  color: #a71d5d;\n}\n\n.hljs-literal,\n.hljs-symbol,\n.hljs-bullet,\n.hljs-attribute {\n  color: #0086b3;\n}\n\n.hljs-section,\n.hljs-name {\n  color: #63a35c;\n}\n\n.hljs-tag {\n  color: 
#333333;\n}\n\n.hljs-title,\n.hljs-attr,\n.hljs-selector-id,\n.hljs-selector-class,\n.hljs-selector-attr,\n.hljs-selector-pseudo {\n  color: #795da3;\n}\n\n.hljs-addition {\n  color: #55a532;\n  background-color: #eaffea;\n}\n\n.hljs-deletion {\n  color: #bd2c00;\n  background-color: #ffecec;\n}\n\n.hljs-link {\n  text-decoration: underline;\n}\n\n/* Code highlighting */\n/* PrismJS 1.15.0\nhttps://prismjs.com/download.html#themes=prism&languages=markup+css+clike+javascript */\n/**\n * prism.js default theme for JavaScript, CSS and HTML\n * Based on dabblet (http://dabblet.com)\n * @author Lea Verou\n */\ncode[class*=\"language-\"],\npre[class*=\"language-\"] {\n  color: black;\n  background: none;\n  text-shadow: 0 1px white;\n  font-family: Consolas, Monaco, 'Andale Mono', 'Ubuntu Mono', monospace;\n  text-align: left;\n  white-space: pre;\n  word-spacing: normal;\n  word-break: normal;\n  word-wrap: normal;\n  line-height: 1.5;\n\n  -moz-tab-size: 4;\n  -o-tab-size: 4;\n  tab-size: 4;\n\n  -webkit-hyphens: none;\n  -moz-hyphens: none;\n  -ms-hyphens: none;\n  hyphens: none;\n}\n\npre[class*=\"language-\"]::-moz-selection,\npre[class*=\"language-\"] ::-moz-selection,\ncode[class*=\"language-\"]::-moz-selection,\ncode[class*=\"language-\"] ::-moz-selection {\n  text-shadow: none;\n  background: #b3d4fc;\n}\n\npre[class*=\"language-\"]::selection,\npre[class*=\"language-\"] ::selection,\ncode[class*=\"language-\"]::selection,\ncode[class*=\"language-\"] ::selection {\n  text-shadow: none;\n  background: #b3d4fc;\n}\n\n@media print {\n\n  code[class*=\"language-\"],\n  pre[class*=\"language-\"] {\n    text-shadow: none;\n  }\n}\n\n/* Code blocks */\npre[class*=\"language-\"] {\n  padding: 1em;\n  margin: .5em 0;\n  overflow: auto;\n}\n\n:not(pre)>code[class*=\"language-\"],\npre[class*=\"language-\"] {\n  background: #f5f2f0;\n}\n\n/* Inline code */\n:not(pre)>code[class*=\"language-\"] {\n  padding: .1em;\n  border-radius: .3em;\n  white-space: 
normal;\n}\n\n.token.comment,\n.token.prolog,\n.token.doctype,\n.token.cdata {\n  color: slategray;\n}\n\n.token.punctuation {\n  color: #999;\n}\n\n.namespace {\n  opacity: .7;\n}\n\n.token.property,\n.token.tag,\n.token.boolean,\n.token.number,\n.token.constant,\n.token.symbol,\n.token.deleted {\n  color: #905;\n}\n\n.token.selector,\n.token.attr-name,\n.token.string,\n.token.char,\n.token.builtin,\n.token.inserted {\n  color: #690;\n}\n\n.token.operator,\n.token.entity,\n.token.url,\n.language-css .token.string,\n.style .token.string {\n  color: #9a6e3a;\n  background: hsla(0, 0%, 100%, .5);\n}\n\n.token.atrule,\n.token.attr-value,\n.token.keyword {\n  color: #07a;\n}\n\n.token.function,\n.token.class-name {\n  color: #DD4A68;\n}\n\n.token.regex,\n.token.important,\n.token.variable {\n  color: #e90;\n}\n\n.token.important,\n.token.bold {\n  font-weight: bold;\n}\n\n.token.italic {\n  font-style: italic;\n}\n\n.token.entity {\n  cursor: help;\n}\n"
  },
  {
    "path": "kt-sft/ktransformers/website/src/assets/iconfont/demo_index.html",
    "content": "<!DOCTYPE html>\n<html>\n<head>\n  <meta charset=\"utf-8\"/>\n  <title>iconfont Demo</title>\n  <link rel=\"shortcut icon\" href=\"//img.alicdn.com/imgextra/i4/O1CN01Z5paLz1O0zuCC7osS_!!6000000001644-55-tps-83-82.svg\" type=\"image/x-icon\"/>\n  <link rel=\"icon\" type=\"image/svg+xml\" href=\"//img.alicdn.com/imgextra/i4/O1CN01Z5paLz1O0zuCC7osS_!!6000000001644-55-tps-83-82.svg\"/>\n  <link rel=\"stylesheet\" href=\"https://g.alicdn.com/thx/cube/1.3.2/cube.min.css\">\n  <link rel=\"stylesheet\" href=\"demo.css\">\n  <link rel=\"stylesheet\" href=\"iconfont.css\">\n  <script src=\"iconfont.js\"></script>\n  <!-- jQuery -->\n  <script src=\"https://a1.alicdn.com/oss/uploads/2018/12/26/7bfddb60-08e8-11e9-9b04-53e73bb6408b.js\"></script>\n  <!-- 代码高亮 -->\n  <script src=\"https://a1.alicdn.com/oss/uploads/2018/12/26/a3f714d0-08e6-11e9-8a15-ebf944d7534c.js\"></script>\n  <style>\n    .main .logo {\n      margin-top: 0;\n      height: auto;\n    }\n\n    .main .logo a {\n      display: flex;\n      align-items: center;\n    }\n\n    .main .logo .sub-title {\n      margin-left: 0.5em;\n      font-size: 22px;\n      color: #fff;\n      background: linear-gradient(-45deg, #3967FF, #B500FE);\n      -webkit-background-clip: text;\n      -webkit-text-fill-color: transparent;\n    }\n  </style>\n</head>\n<body>\n  <div class=\"main\">\n    <h1 class=\"logo\"><a href=\"https://www.iconfont.cn/\" title=\"iconfont 首页\" target=\"_blank\">\n      <img width=\"200\" src=\"https://img.alicdn.com/imgextra/i3/O1CN01Mn65HV1FfSEzR6DKv_!!6000000000514-55-tps-228-59.svg\">\n      \n    </a></h1>\n    <div class=\"nav-tabs\">\n      <ul id=\"tabs\" class=\"dib-box\">\n        <li class=\"dib active\"><span>Unicode</span></li>\n        <li class=\"dib\"><span>Font class</span></li>\n        <li class=\"dib\"><span>Symbol</span></li>\n      </ul>\n      \n      <a href=\"https://www.iconfont.cn/manage/index?manage_type=myprojects&projectId=4550268\" target=\"_blank\" 
class=\"nav-more\">查看项目</a>\n      \n    </div>\n    <div class=\"tab-container\">\n      <div class=\"content unicode\" style=\"display: block;\">\n          <ul class=\"icon_lists dib-box\">\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe8b0;</span>\n                <div class=\"name\">复制</div>\n                <div class=\"code-name\">&amp;#xe8b0;</div>\n              </li>\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe85e;</span>\n                <div class=\"name\">箭头下</div>\n                <div class=\"code-name\">&amp;#xe85e;</div>\n              </li>\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe651;</span>\n                <div class=\"name\">进度</div>\n                <div class=\"code-name\">&amp;#xe651;</div>\n              </li>\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe617;</span>\n                <div class=\"name\">环形进度条</div>\n                <div class=\"code-name\">&amp;#xe617;</div>\n              </li>\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe779;</span>\n                <div class=\"name\">向左1</div>\n                <div class=\"code-name\">&amp;#xe779;</div>\n              </li>\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe608;</span>\n                <div class=\"name\">点</div>\n                <div class=\"code-name\">&amp;#xe608;</div>\n              </li>\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe7dd;</span>\n                <div class=\"name\">编辑</div>\n                <div class=\"code-name\">&amp;#xe7dd;</div>\n              </li>\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe614;</span>\n                <div class=\"name\">删除</div>\n                
<div class=\"code-name\">&amp;#xe614;</div>\n              </li>\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe618;</span>\n                <div class=\"name\">上传</div>\n                <div class=\"code-name\">&amp;#xe618;</div>\n              </li>\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe621;</span>\n                <div class=\"name\">探索-选中</div>\n                <div class=\"code-name\">&amp;#xe621;</div>\n              </li>\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe657;</span>\n                <div class=\"name\">ellipsis</div>\n                <div class=\"code-name\">&amp;#xe657;</div>\n              </li>\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe60c;</span>\n                <div class=\"name\">发送</div>\n                <div class=\"code-name\">&amp;#xe60c;</div>\n              </li>\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe62d;</span>\n                <div class=\"name\">列表</div>\n                <div class=\"code-name\">&amp;#xe62d;</div>\n              </li>\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe639;</span>\n                <div class=\"name\">列表</div>\n                <div class=\"code-name\">&amp;#xe639;</div>\n              </li>\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe6bd;</span>\n                <div class=\"name\">重试</div>\n                <div class=\"code-name\">&amp;#xe6bd;</div>\n              </li>\n          \n            <li class=\"dib\">\n              <span class=\"icon iconfont\">&#xe826;</span>\n                <div class=\"name\">Fork 记录</div>\n                <div class=\"code-name\">&amp;#xe826;</div>\n              </li>\n          \n          </ul>\n          <div 
class=\"article markdown\">\n          <h2 id=\"unicode-\">Unicode Reference</h2>\n          <hr>\n\n          <p>Unicode is the most basic way to use a font on a web page. Its characteristics are:</p>\n          <ul>\n            <li>Icon size, color, and so on can be adjusted dynamically, just as with a font.</li>\n            <li>Multicolor icons are not supported by default; a multicolor icon added directly is automatically desaturated.</li>\n          </ul>\n          <blockquote>\n            <p>Note: the new version of iconfont supports two ways of referencing multicolor icons: SVG symbol references and the color font icon mode. (To use color font icons, enable the \"Color\" option under \"Edit Project\" and regenerate.)</p>\n          </blockquote>\n          <p>To use Unicode, follow these steps:</p>\n          <h3 id=\"-font-face\">Step 1: Copy the <code>@font-face</code> generated for the project</h3>\n<pre><code class=\"language-css\"\n>@font-face {\n  font-family: 'iconfont';\n  src: url('iconfont.woff2?t=1717950820214') format('woff2'),\n       url('iconfont.woff?t=1717950820214') format('woff'),\n       url('iconfont.ttf?t=1717950820214') format('truetype'),\n       url('iconfont.svg?t=1717950820214#iconfont') format('svg');\n}\n</code></pre>\n          <h3 id=\"-iconfont-\">Step 2: Define a style that uses iconfont</h3>\n<pre><code class=\"language-css\"\n>.iconfont {\n  font-family: \"iconfont\" !important;\n  font-size: 16px;\n  font-style: normal;\n  -webkit-font-smoothing: antialiased;\n  -moz-osx-font-smoothing: grayscale;\n}\n</code></pre>\n          <h3 id=\"-\">Step 3: Pick an icon, get its font code, and apply it in the page</h3>\n<pre>\n<code class=\"language-html\"\n>&lt;span class=\"iconfont\"&gt;&amp;#x33;&lt;/span&gt;\n</code></pre>\n          <blockquote>\n            <p>\"iconfont\" is the font-family of your project. You can check it by editing the project; the default is \"iconfont\".</p>\n          </blockquote>\n          </div>\n      </div>\n      <div class=\"content font-class\">\n        <ul class=\"icon_lists dib-box\">\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-copy\"></span>\n            <div class=\"name\">\n              复制\n            </div>\n            <div class=\"code-name\">.icon-copy\n            </div>\n          </li>\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-arrow-down\"></span>\n            <div class=\"name\">\n        
      箭头下\n            </div>\n            <div class=\"code-name\">.icon-arrow-down\n            </div>\n          </li>\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-usage-progress\"></span>\n            <div class=\"name\">\n              进度\n            </div>\n            <div class=\"code-name\">.icon-usage-progress\n            </div>\n          </li>\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-gen-progress\"></span>\n            <div class=\"name\">\n              环形进度条\n            </div>\n            <div class=\"code-name\">.icon-gen-progress\n            </div>\n          </li>\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-back\"></span>\n            <div class=\"name\">\n              向左1\n            </div>\n            <div class=\"code-name\">.icon-back\n            </div>\n          </li>\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-point\"></span>\n            <div class=\"name\">\n              点\n            </div>\n            <div class=\"code-name\">.icon-point\n            </div>\n          </li>\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-edit\"></span>\n            <div class=\"name\">\n              编辑\n            </div>\n            <div class=\"code-name\">.icon-edit\n            </div>\n          </li>\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-delete\"></span>\n            <div class=\"name\">\n              删除\n            </div>\n            <div class=\"code-name\">.icon-delete\n            </div>\n          </li>\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-upload-1\"></span>\n            <div class=\"name\">\n              上传\n            </div>\n            <div class=\"code-name\">.icon-upload-1\n            </div>\n          
</li>\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-explore\"></span>\n            <div class=\"name\">\n              探索-选中\n            </div>\n            <div class=\"code-name\">.icon-explore\n            </div>\n          </li>\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-ellipsis\"></span>\n            <div class=\"name\">\n              ellipsis\n            </div>\n            <div class=\"code-name\">.icon-ellipsis\n            </div>\n          </li>\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-sent\"></span>\n            <div class=\"name\">\n              发送\n            </div>\n            <div class=\"code-name\">.icon-sent\n            </div>\n          </li>\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-list-list\"></span>\n            <div class=\"name\">\n              列表\n            </div>\n            <div class=\"code-name\">.icon-list-list\n            </div>\n          </li>\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-list-icon\"></span>\n            <div class=\"name\">\n              列表\n            </div>\n            <div class=\"code-name\">.icon-list-icon\n            </div>\n          </li>\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-zhongshi\"></span>\n            <div class=\"name\">\n              重试\n            </div>\n            <div class=\"code-name\">.icon-zhongshi\n            </div>\n          </li>\n          \n          <li class=\"dib\">\n            <span class=\"icon iconfont icon-log\"></span>\n            <div class=\"name\">\n              Fork 记录\n            </div>\n            <div class=\"code-name\">.icon-log\n            </div>\n          </li>\n          \n        </ul>\n        <div class=\"article markdown\">\n        <h2 id=\"font-class-\">font-class 
Reference</h2>\n        <hr>\n\n        <p>font-class is a variant of the Unicode approach. It mainly addresses the problem that Unicode references are unintuitive to write and semantically unclear.</p>\n        <p>Compared with the Unicode approach, it has the following characteristics:</p>\n        <ul>\n          <li>It is semantically clearer and more intuitive to write than Unicode, so it is easy to tell what an icon is.</li>\n          <li>Because icons are defined through a class, replacing an icon only requires changing the Unicode reference inside the class.</li>\n        </ul>\n        <p>The steps are as follows:</p>\n        <h3 id=\"-fontclass-\">Step 1: Include the fontclass code generated for the project:</h3>\n<pre><code class=\"language-html\">&lt;link rel=\"stylesheet\" href=\"./iconfont.css\"&gt;\n</code></pre>\n        <h3 id=\"-\">Step 2: Pick an icon, get its class name, and apply it in the page:</h3>\n<pre><code class=\"language-html\">&lt;span class=\"iconfont icon-xxx\"&gt;&lt;/span&gt;\n</code></pre>\n        <blockquote>\n          <p>\"iconfont\" is the font-family of your project. You can check it by editing the project; the default is \"iconfont\".</p>\n        </blockquote>\n      </div>\n      </div>\n      <div class=\"content symbol\">\n          <ul class=\"icon_lists dib-box\">\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" aria-hidden=\"true\">\n                  <use xlink:href=\"#icon-copy\"></use>\n                </svg>\n                <div class=\"name\">复制</div>\n                <div class=\"code-name\">#icon-copy</div>\n            </li>\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" aria-hidden=\"true\">\n                  <use xlink:href=\"#icon-arrow-down\"></use>\n                </svg>\n                <div class=\"name\">箭头下</div>\n                <div class=\"code-name\">#icon-arrow-down</div>\n            </li>\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" aria-hidden=\"true\">\n                  <use xlink:href=\"#icon-usage-progress\"></use>\n                </svg>\n                <div class=\"name\">进度</div>\n                <div class=\"code-name\">#icon-usage-progress</div>\n            </li>\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" 
aria-hidden=\"true\">\n                  <use xlink:href=\"#icon-gen-progress\"></use>\n                </svg>\n                <div class=\"name\">环形进度条</div>\n                <div class=\"code-name\">#icon-gen-progress</div>\n            </li>\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" aria-hidden=\"true\">\n                  <use xlink:href=\"#icon-back\"></use>\n                </svg>\n                <div class=\"name\">向左1</div>\n                <div class=\"code-name\">#icon-back</div>\n            </li>\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" aria-hidden=\"true\">\n                  <use xlink:href=\"#icon-point\"></use>\n                </svg>\n                <div class=\"name\">点</div>\n                <div class=\"code-name\">#icon-point</div>\n            </li>\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" aria-hidden=\"true\">\n                  <use xlink:href=\"#icon-edit\"></use>\n                </svg>\n                <div class=\"name\">编辑</div>\n                <div class=\"code-name\">#icon-edit</div>\n            </li>\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" aria-hidden=\"true\">\n                  <use xlink:href=\"#icon-delete\"></use>\n                </svg>\n                <div class=\"name\">删除</div>\n                <div class=\"code-name\">#icon-delete</div>\n            </li>\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" aria-hidden=\"true\">\n                  <use xlink:href=\"#icon-upload-1\"></use>\n                </svg>\n                <div class=\"name\">上传</div>\n                <div class=\"code-name\">#icon-upload-1</div>\n            </li>\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" aria-hidden=\"true\">\n                  <use 
xlink:href=\"#icon-explore\"></use>\n                </svg>\n                <div class=\"name\">探索-选中</div>\n                <div class=\"code-name\">#icon-explore</div>\n            </li>\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" aria-hidden=\"true\">\n                  <use xlink:href=\"#icon-ellipsis\"></use>\n                </svg>\n                <div class=\"name\">ellipsis</div>\n                <div class=\"code-name\">#icon-ellipsis</div>\n            </li>\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" aria-hidden=\"true\">\n                  <use xlink:href=\"#icon-sent\"></use>\n                </svg>\n                <div class=\"name\">发送</div>\n                <div class=\"code-name\">#icon-sent</div>\n            </li>\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" aria-hidden=\"true\">\n                  <use xlink:href=\"#icon-list-list\"></use>\n                </svg>\n                <div class=\"name\">列表</div>\n                <div class=\"code-name\">#icon-list-list</div>\n            </li>\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" aria-hidden=\"true\">\n                  <use xlink:href=\"#icon-list-icon\"></use>\n                </svg>\n                <div class=\"name\">列表</div>\n                <div class=\"code-name\">#icon-list-icon</div>\n            </li>\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" aria-hidden=\"true\">\n                  <use xlink:href=\"#icon-zhongshi\"></use>\n                </svg>\n                <div class=\"name\">重试</div>\n                <div class=\"code-name\">#icon-zhongshi</div>\n            </li>\n          \n            <li class=\"dib\">\n                <svg class=\"icon svg-icon\" aria-hidden=\"true\">\n                  <use xlink:href=\"#icon-log\"></use>\n              
  </svg>\n                <div class=\"name\">Fork 记录</div>\n                <div class=\"code-name\">#icon-log</div>\n            </li>\n          \n          </ul>\n          <div class=\"article markdown\">\n          <h2 id=\"symbol-\">Symbol reference</h2>\n          <hr>\n\n          <p>This is a newer way to use the icons; it is arguably the future mainstream and the usage the platform currently recommends. For an introduction, see this <a href=\"\">article</a>.\n            This approach builds a collection of SVGs and, compared with the other two, has the following characteristics:</p>\n          <ul>\n            <li>Supports multi-color icons; no longer limited to a single color.</li>\n            <li>With a few tricks, styles can be adjusted through <code>font-size</code> and <code>color</code>, just like a font.</li>\n            <li>Weaker compatibility: supports IE9+ and modern browsers.</li>\n            <li>Browser performance when rendering SVG is mediocre, worse than PNG.</li>\n          </ul>\n          <p>The steps are as follows:</p>\n          <h3 id=\"-symbol-\">Step 1: Include the symbol code generated for the project:</h3>\n<pre><code class=\"language-html\">&lt;script src=\"./iconfont.js\"&gt;&lt;/script&gt;\n</code></pre>\n          <h3 id=\"-css-\">Step 2: Add the common CSS (it only needs to be included once):</h3>\n<pre><code class=\"language-html\">&lt;style&gt;\n.icon {\n  width: 1em;\n  height: 1em;\n  vertical-align: -0.15em;\n  fill: currentColor;\n  overflow: hidden;\n}\n&lt;/style&gt;\n</code></pre>\n          <h3 id=\"-\">Step 3: Pick an icon, get its class name, and use it in the page:</h3>\n<pre><code class=\"language-html\">&lt;svg class=\"icon\" aria-hidden=\"true\"&gt;\n  &lt;use xlink:href=\"#icon-xxx\"&gt;&lt;/use&gt;\n&lt;/svg&gt;\n</code></pre>\n          </div>\n      </div>\n\n    </div>\n  </div>\n  <script>\n  $(document).ready(function () {\n      $('.tab-container .content:first').show()\n\n      $('#tabs li').click(function (e) {\n        var tabContent = $('.tab-container .content')\n        var index = $(this).index()\n\n        if ($(this).hasClass('active')) {\n          return\n        } else {\n          $('#tabs li').removeClass('active')\n          $(this).addClass('active')\n\n          tabContent.hide().eq(index).fadeIn()\n        }\n      })\n    })\n  </script>\n</body>\n</html>\n"
  },
  {
    "path": "kt-sft/ktransformers/website/src/assets/iconfont/iconfont.css",
    "content": "@font-face {\n  font-family: \"iconfont\"; /* Project id 4550268 */\n  src: url('iconfont.woff2?t=1717950820214') format('woff2'),\n       url('iconfont.woff?t=1717950820214') format('woff'),\n       url('iconfont.ttf?t=1717950820214') format('truetype'),\n       url('iconfont.svg?t=1717950820214#iconfont') format('svg');\n}\n\n.iconfont {\n  font-family: \"iconfont\" !important;\n  font-size: 16px;\n  font-style: normal;\n  -webkit-font-smoothing: antialiased;\n  -moz-osx-font-smoothing: grayscale;\n}\n\n.icon-copy:before {\n  content: \"\\e8b0\";\n}\n\n.icon-arrow-down:before {\n  content: \"\\e85e\";\n}\n\n.icon-usage-progress:before {\n  content: \"\\e651\";\n}\n\n.icon-gen-progress:before {\n  content: \"\\e617\";\n}\n\n.icon-back:before {\n  content: \"\\e779\";\n}\n\n.icon-point:before {\n  content: \"\\e608\";\n}\n\n.icon-edit:before {\n  content: \"\\e7dd\";\n}\n\n.icon-delete:before {\n  content: \"\\e614\";\n}\n\n.icon-upload-1:before {\n  content: \"\\e618\";\n}\n\n.icon-explore:before {\n  content: \"\\e621\";\n}\n\n.icon-ellipsis:before {\n  content: \"\\e657\";\n}\n\n.icon-sent:before {\n  content: \"\\e60c\";\n}\n\n.icon-list-list:before {\n  content: \"\\e62d\";\n}\n\n.icon-list-icon:before {\n  content: \"\\e639\";\n}\n\n.icon-zhongshi:before {\n  content: \"\\e6bd\";\n}\n\n.icon-log:before {\n  content: \"\\e826\";\n}\n\n"
  },
  {
    "path": "kt-sft/ktransformers/website/src/assets/iconfont/iconfont.js",
    "content": "window._iconfont_svg_string_4550268='<svg><symbol id=\"icon-copy\" viewBox=\"0 0 1024 1024\"><path d=\"M394.666667 106.666667h448a74.666667 74.666667 0 0 1 74.666666 74.666666v448a74.666667 74.666667 0 0 1-74.666666 74.666667H394.666667a74.666667 74.666667 0 0 1-74.666667-74.666667V181.333333a74.666667 74.666667 0 0 1 74.666667-74.666666z m0 64a10.666667 10.666667 0 0 0-10.666667 10.666666v448a10.666667 10.666667 0 0 0 10.666667 10.666667h448a10.666667 10.666667 0 0 0 10.666666-10.666667V181.333333a10.666667 10.666667 0 0 0-10.666666-10.666666H394.666667z m245.333333 597.333333a32 32 0 0 1 64 0v74.666667a74.666667 74.666667 0 0 1-74.666667 74.666666H181.333333a74.666667 74.666667 0 0 1-74.666666-74.666666V394.666667a74.666667 74.666667 0 0 1 74.666666-74.666667h74.666667a32 32 0 0 1 0 64h-74.666667a10.666667 10.666667 0 0 0-10.666666 10.666667v448a10.666667 10.666667 0 0 0 10.666666 10.666666h448a10.666667 10.666667 0 0 0 10.666667-10.666666v-74.666667z\" fill=\"#000000\" ></path></symbol><symbol id=\"icon-arrow-down\" viewBox=\"0 0 1024 1024\"><path d=\"M554.666667 690.005333l228.864-228.864 60.330666 60.330667L512 853.333333l-331.861333-331.861333 60.330666-60.330667L469.333333 690.005333V170.666667h85.333334v519.338666z\"  ></path></symbol><symbol id=\"icon-usage-progress\" viewBox=\"0 0 1024 1024\"><path d=\"M512 125.098667A386.901333 386.901333 0 1 1 125.098667 512 386.901333 386.901333 0 0 1 512 125.098667z\" fill=\"#ACE9C5\" ></path><path d=\"M512 318.634667A193.365333 193.365333 0 1 1 318.634667 512 193.365333 193.365333 0 0 1 512 318.634667z\" fill=\"#2BA866\" ></path></symbol><symbol id=\"icon-gen-progress\" viewBox=\"0 0 1024 1024\"><path d=\"M692.004733 714.930578l96.018649 96.017519C715.492309 877.950022 618.525386 918.887417 512 918.887417c-104.225342 0-199.297978-39.187779-271.287664-103.631964l96.127152-96.126023C384.097201 759.135506 445.230905 783.258278 512 783.258278c69.07253 0 132.114084-25.817007 180.004733-68.3277z 
m-202.61185-609.200883L489.395143 241.670781C350.16053 253.157439 240.741722 369.800759 240.741722 512c0 66.767965 24.122773 127.900539 64.127717 175.160512l-96.126022 96.126022C144.299232 711.295717 105.112583 616.225342 105.112583 512c0-217.130949 170.07894-394.539514 384.2803-406.270305z m325.8637 134.984901C879.700768 312.702022 918.887417 407.774658 918.887417 512c0 101.921907-37.474331 195.091214-99.395814 266.479611l-96.270694-96.268432C760.774358 635.667779 783.258278 576.460009 783.258278 512c0-66.767965-24.122773-127.901669-64.128848-175.161642l96.127153-96.124892zM534.608247 105.728565c95.334852 5.221722 181.928406 43.261174 248.678287 103.013722l-96.127152 96.127152c-41.869845-35.444415-94.631841-58.422252-152.553395-63.199788l0.00226-135.941086z\" fill=\"#448AFF\" fill-opacity=\".6\" ></path><path d=\"M489.392883 105.729695L489.395143 241.670781C350.16053 253.157439 240.741722 369.800759 240.741722 512c0 66.767965 24.122773 127.900539 64.127717 175.160512l-96.126022 96.126022C144.299232 711.295717 105.112583 616.225342 105.112583 512c0-217.130949 170.07894-394.539514 384.2803-406.270305z\" fill=\"#448AFF\" ></path></symbol><symbol id=\"icon-back\" viewBox=\"0 0 1024 1024\"><path d=\"M671.968176 911.99957c-12.287381 0-24.576482-4.67206-33.951566-14.047144L286.048434 545.984249c-18.751888-18.719204-18.751888-49.12028 0-67.872168L638.016611 126.111222c18.751888-18.751888 49.12028-18.751888 67.872168 0 18.751888 18.719204 18.751888 49.12028 0 67.872168l-318.016611 318.047574L705.888778 830.047574c18.751888 18.751888 18.751888 49.12028 0 67.872168C696.544658 907.32751 684.255557 911.99957 671.968176 911.99957z\" fill=\"#2c2c2c\" ></path></symbol><symbol id=\"icon-point\" viewBox=\"0 0 1024 1024\"><path d=\"M512 307.2a204.86826667 204.86826667 0 0 1 0 409.6 204.8 204.8 0 0 1 0-409.6z\" fill=\"\" ></path></symbol><symbol id=\"icon-edit\" viewBox=\"0 0 1024 1024\"><path d=\"M899.072 125.44c-28.672-28.672-67.072-44.544-107.52-44.544s-78.848 15.872-107.52 
44.544L251.392 558.08c-34.304 34.304-60.416 74.752-78.336 119.808L88.576 896c-4.608 11.264-1.536 24.064 7.168 32.768 5.632 5.632 13.824 9.216 21.504 9.216 3.584 0 7.68-0.512 11.264-2.048l218.624-84.48c45.056-17.408 85.504-44.032 119.808-78.336l351.744-351.744 80.896-80.896c58.88-59.392 58.88-155.648-0.512-215.04z m-475.648 604.16c-28.16 28.16-61.44 50.176-98.816 64.512l-153.6 59.392 59.392-153.6c14.336-37.376 35.84-70.656 64.512-98.816L625.152 271.36l128.512 128.512-330.24 329.728z m432.64-432.128l-58.88 58.88-128.512-128.512L727.552 168.96c16.896-16.896 39.936-26.624 64.512-26.624s47.104 9.216 64.512 26.624c34.816 35.328 34.816 92.672-0.512 128.512z\" fill=\"#333333\" ></path></symbol><symbol id=\"icon-delete\" viewBox=\"0 0 1024 1024\"><path d=\"M742.4 944H281.6c-49.4 0-89.6-43.1-89.6-96V368h64v480c0 17.3 11.7 32 25.6 32h460.8c13.9 0 25.6-14.7 25.6-32V368h64v480c0 52.9-40.2 96-89.6 96z\"  ></path><path d=\"M384 368h64v416h-64zM592 368h64v416h-64zM64 224h896v64H64z\"  ></path><path d=\"M768 288H256V160c0-52.9 43.1-96 96-96h320c52.9 0 96 43.1 96 96v128z m-448-64h384v-64c0-17.6-14.4-32-32-32H352c-17.6 0-32 14.4-32 32v64z\"  ></path></symbol><symbol id=\"icon-upload-1\" viewBox=\"0 0 1024 1024\"><path d=\"M323.034074 291.934815l383.620741 0c9.481481 0 17.256296-8.533333 17.256296-18.962963 0-10.42963-7.68-18.962963-17.256296-18.962963L323.034074 254.008889c-9.481481 0-17.256296 8.533333-17.256296 18.962963C305.777778 283.496296 313.457778 291.934815 323.034074 291.934815z\" fill=\"#272536\" ></path><path d=\"M522.05037 328.628148c-1.232593-1.232593-2.844444-1.896296-4.740741-1.991111-1.706667-0.094815-3.318519-0.094815-5.025185 0-1.896296 0.094815-3.508148 0.758519-4.740741 1.991111L349.013333 487.253333c-3.887407 3.887407-1.896296 12.325926 4.456296 18.773333 6.447407 6.447407 14.791111 8.438519 18.773333 4.456296l125.060741-125.060741 0 367.122963c0 9.671111 7.86963 17.540741 17.540741 17.540741l0 0c9.671111 0 17.540741-7.86963 17.540741-17.540741L532.385185 
385.327407l125.060741 125.060741c3.887407 3.887407 12.325926 1.896296 18.773333-4.456296 6.447407-6.447407 8.438519-14.791111 4.456296-18.773333L522.05037 328.628148z\" fill=\"#272536\" ></path></symbol><symbol id=\"icon-explore\" viewBox=\"0 0 1024 1024\"><path d=\"M926.352541 89.231277c-0.029676-7.432273-1.212618-13.651928-2.837628-19.264762-31.228235-8.264221-71.898517 1.24127-106.283652 17.927301-7.049556 3.41068-23.762193 13.583366-48.51597 28.643364-10.237155 6.250354-19.264762 11.739369-23.251563 14.002922-0.384763 0.224104-0.608867 0.63752-0.958838 0.861624-67.557652-41.147142-146.571217-65.327868-231.319389-65.327868-246.251474 0-446.569802 200.319351-446.569802 446.564685 0 82.554204 22.904663 159.683862 62.105476 226.062666-46.315862 71.387887-69.2809 122.93182-63.283302 157.863401 1.24127 7.144724 13.555737 8.28878 20.316721 8.28878 137.989771 0 453.393207-302.802444 492.628814-341.399507C751.64859 393.022235 926.449755 184.667883 926.352541 89.231277L926.352541 89.231277zM305.847292 611.014084c-43.956118 0-79.744205-35.757388-79.744205-79.743182 0-43.956118 35.789111-79.744205 79.744205-79.744205 43.956118 0 79.743182 35.789111 79.743182 79.744205C385.591497 575.256696 349.803409 611.014084 305.847292 611.014084L305.847292 611.014084zM446.19783 387.730719c-52.760644 0-95.694479-42.937928-95.694479-95.692433 0-52.760644 42.933835-95.694479 95.694479-95.694479 52.761668 0 95.694479 42.933835 95.694479 95.694479C541.892309 344.79279 498.958474 387.730719 446.19783 387.730719L446.19783 387.730719zM893.595486 279.9469c-66.889433 99.330286-172.055634 218.596623-276.967032 321.751005-28.551266 28.104081-201.624067 195.822944-346.982666 285.198507 0.12689-0.097214 0.223081-0.160659 0.349971-0.224104 70.049403 45.708018 153.491837 72.536037 243.189741 72.536037 246.246357 0 446.565708-200.318328 446.565708-446.570825C959.716416 427.317319 935.282934 347.82587 893.595486 279.9469L893.595486 279.9469zM638.54051 799.720957c-35.180244 
0-63.793932-28.614711-63.793932-63.794955 0-35.184337 28.613688-63.799048 63.793932-63.799048 35.184337 0 63.793932 28.614711 63.793932 63.799048C702.334441 771.106246 673.724847 799.720957 638.54051 799.720957L638.54051 799.720957zM638.54051 799.720957\" fill=\"#615CED\" ></path></symbol><symbol id=\"icon-ellipsis\" viewBox=\"0 0 1024 1024\"><path d=\"M322.292 505.5m-66 0a66 66 0 1 0 132 0 66 66 0 1 0-132 0Z\" fill=\"#272636\" ></path><path d=\"M509.791 505.5m-66 0a66 66 0 1 0 132 0 66 66 0 1 0-132 0Z\" fill=\"#272636\" ></path><path d=\"M701.791 505.5m-66 0a66 66 0 1 0 132 0 66 66 0 1 0-132 0Z\" fill=\"#272636\" ></path></symbol><symbol id=\"icon-sent\" viewBox=\"0 0 1024 1024\"><path d=\"M998.976 554.3232C1031.232 539.6032 1031.328 515.7952 998.976 501.0432L122.88 101.3312C90.624 86.6112 64.448 103.5072 64.384 138.4832L64 426.9952 773.568 527.6672 64 628.3392 64.384 916.8832C64.448 952.1152 90.528 968.7872 122.88 954.0352L998.976 554.3232Z\"  ></path></symbol><symbol id=\"icon-list-list\" viewBox=\"0 0 1024 1024\"><path d=\"M419.037 287.953h413.124c17.673 0 32-14.327 32-32s-14.327-32-32-32H419.037c-17.673 0-32 14.327-32 32s14.327 32 32 32zM419.028 543.17h411.608c17.673 0 32-14.327 32-32s-14.327-32-32-32H419.028c-17.673 0-32 14.327-32 32s14.327 32 32 32zM832.161 735.802H419.037c-17.673 0-32 14.327-32 32s14.327 32 32 32h413.124c17.673 0 32-14.327 32-32s-14.327-32-32-32z\" fill=\"\" ></path><path d=\"M256.037 255.953m-64 0a64 64 0 1 0 128 0 64 64 0 1 0-128 0Z\" fill=\"\" ></path><path d=\"M256.037 510.787m-64 0a64 64 0 1 0 128 0 64 64 0 1 0-128 0Z\" fill=\"\" ></path><path d=\"M256.037 767.621m-64 0a64 64 0 1 0 128 0 64 64 0 1 0-128 0Z\" fill=\"\" ></path></symbol><symbol id=\"icon-list-icon\" viewBox=\"0 0 1024 1024\"><path d=\"M841.6 489.6h-214.4c-48 0-86.4-38.4-86.4-86.4V188.8c0-48 38.4-86.4 86.4-86.4h214.4c48 0 86.4 38.4 86.4 86.4v214.4c0 48-38.4 86.4-86.4 86.4z m-211.2-320c-12.8 0-22.4 9.6-22.4 22.4v214.4c0 12.8 9.6 22.4 22.4 22.4h214.4c12.8 0 22.4-9.6 
22.4-22.4V188.8c0-12.8-9.6-22.4-22.4-22.4h-214.4zM393.6 489.6H182.4c-48 0-86.4-38.4-86.4-86.4V188.8c0-48 38.4-86.4 86.4-86.4h214.4c48 0 86.4 38.4 86.4 86.4v214.4c-3.2 48-41.6 86.4-89.6 86.4z m-211.2-320c-12.8 0-22.4 9.6-22.4 19.2v214.4c0 12.8 9.6 22.4 22.4 22.4h214.4c12.8 0 22.4-9.6 22.4-22.4V188.8c0-12.8-9.6-22.4-22.4-22.4H182.4zM841.6 937.6h-214.4c-48 0-86.4-38.4-86.4-86.4v-214.4c0-48 38.4-86.4 86.4-86.4h214.4c48 0 86.4 38.4 86.4 86.4v214.4c0 48-38.4 86.4-86.4 86.4z m-211.2-323.2c-12.8 0-22.4 9.6-22.4 22.4v214.4c0 12.8 9.6 22.4 22.4 22.4h214.4c12.8 0 22.4-9.6 22.4-22.4v-214.4c0-12.8-9.6-22.4-22.4-22.4h-214.4zM393.6 937.6H182.4c-48 0-86.4-38.4-86.4-86.4v-214.4c0-48 38.4-86.4 86.4-86.4h214.4c48 0 86.4 38.4 86.4 86.4v214.4c-3.2 48-41.6 86.4-89.6 86.4zM182.4 614.4c-12.8 0-22.4 9.6-22.4 22.4v214.4c0 12.8 9.6 22.4 22.4 22.4h214.4c12.8 0 22.4-9.6 22.4-22.4v-214.4c0-12.8-9.6-22.4-22.4-22.4H182.4z\" fill=\"#333333\" ></path></symbol><symbol id=\"icon-zhongshi\" viewBox=\"0 0 1024 1024\"><path d=\"M973.53044 167.133265l-65.609003 50.468463A491.226376 491.226376 0 0 0 522.971405 33.282123C253.074841 33.282123 34.74388 247.370807 34.378166 512.220525c-0.365714 265.142289 218.550389 480.108685 488.593239 480.108686 211.016691 0 390.728306-131.291147 459.189873-315.245039a9.069695 9.069695 0 0 0-5.851416-11.775975l-65.82843-22.308523a9.435408 9.435408 0 0 0-11.775975 5.485702 392.48373 392.48373 0 0 1-92.525516 141.896839 402.650566 402.650566 0 0 1-282.915965 115.12661c-54.125598 0-106.495772-10.386263-155.793952-30.793077a398.627717 398.627717 0 0 1-212.845258-209.188123 383.779749 383.779749 0 0 1-31.451361-152.868244c0-53.1016 10.532549-104.374633 31.451361-152.868243 20.114243-46.738186 49.005609-88.795238 85.723245-124.85459a401.260854 401.260854 0 0 1 282.915965-115.12661c54.052456 0 106.422629 10.459406 155.720809 30.866219a398.627717 398.627717 0 0 1 159.52423 120.100314l-69.997565 53.686742a9.069695 9.069695 0 0 0 3.437707 16.091394l204.287562 49.151895c5.851416 
1.316569 11.556547-2.998851 11.556547-8.777124l0.950855-206.554986a9.508551 9.508551 0 0 0-15.213681-7.167985z\" fill=\"#000000\" ></path></symbol><symbol id=\"icon-log\" viewBox=\"0 0 1024 1024\"><path d=\"M288 64c70.692 0 128 57.308 128 128 0 58.192-38.833 107.315-91.998 122.867L324 571.5h225c48.8 0 84.134-19.864 110.1-62.009 15.655-25.408 27.76-58.805 36.092-100.127C648.71 390.177 616 344.408 616 291c0-70.692 57.308-128 128-128 70.692 0 128 57.308 128 128 0 62.814-45.245 115.06-104.923 125.925-9.94 52.391-25.407 95.81-46.677 130.334-38.644 62.721-96.365 95.58-169.189 96.231l-2.211 0.01H324l0.002 65.633c52.52 15.363 91.052 63.486 91.98 120.75L416 832c0 70.692-57.308 128-128 128-70.692 0-128-57.308-128-128 0-58.193 38.833-107.315 91.999-122.868V314.868C198.833 299.315 160 250.193 160 192c0-70.692 57.308-128 128-128z\" fill=\"#333333\" ></path></symbol></svg>',function(l){var t=(t=document.getElementsByTagName(\"script\"))[t.length-1],c=t.getAttribute(\"data-injectcss\"),t=t.getAttribute(\"data-disable-injectsvg\");if(!t){var i,o,e,a,h,n=function(t,c){c.parentNode.insertBefore(t,c)};if(c&&!l.__iconfont__svg__cssinject__){l.__iconfont__svg__cssinject__=!0;try{document.write(\"<style>.svgfont {display: inline-block;width: 1em;height: 1em;fill: currentColor;vertical-align: -0.1em;font-size:16px;}</style>\")}catch(t){console&&console.log(t)}}i=function(){var 
t,c=document.createElement(\"div\");c.innerHTML=l._iconfont_svg_string_4550268,(c=c.getElementsByTagName(\"svg\")[0])&&(c.setAttribute(\"aria-hidden\",\"true\"),c.style.position=\"absolute\",c.style.width=0,c.style.height=0,c.style.overflow=\"hidden\",c=c,(t=document.body).firstChild?n(c,t.firstChild):t.appendChild(c))},document.addEventListener?~[\"complete\",\"loaded\",\"interactive\"].indexOf(document.readyState)?setTimeout(i,0):(o=function(){document.removeEventListener(\"DOMContentLoaded\",o,!1),i()},document.addEventListener(\"DOMContentLoaded\",o,!1)):document.attachEvent&&(e=i,a=l.document,h=!1,d(),a.onreadystatechange=function(){\"complete\"==a.readyState&&(a.onreadystatechange=null,s())})}function s(){h||(h=!0,e())}function d(){try{a.documentElement.doScroll(\"left\")}catch(t){return void setTimeout(d,50)}s()}}(window);"
  },
  {
    "path": "kt-sft/ktransformers/website/src/assets/iconfont/iconfont.json",
    "content": "{\n  \"id\": \"4550268\",\n  \"name\": \"Lexllama\",\n  \"font_family\": \"iconfont\",\n  \"css_prefix_text\": \"icon-\",\n  \"description\": \"Lexllama开源项目使用\",\n  \"glyphs\": [\n    {\n      \"icon_id\": \"11372665\",\n      \"name\": \"复制\",\n      \"font_class\": \"copy\",\n      \"unicode\": \"e8b0\",\n      \"unicode_decimal\": 59568\n    },\n    {\n      \"icon_id\": \"34202237\",\n      \"name\": \"箭头下\",\n      \"font_class\": \"arrow-down\",\n      \"unicode\": \"e85e\",\n      \"unicode_decimal\": 59486\n    },\n    {\n      \"icon_id\": \"7766233\",\n      \"name\": \"进度\",\n      \"font_class\": \"usage-progress\",\n      \"unicode\": \"e651\",\n      \"unicode_decimal\": 58961\n    },\n    {\n      \"icon_id\": \"38865122\",\n      \"name\": \"环形进度条\",\n      \"font_class\": \"gen-progress\",\n      \"unicode\": \"e617\",\n      \"unicode_decimal\": 58903\n    },\n    {\n      \"icon_id\": \"577406\",\n      \"name\": \"向左1\",\n      \"font_class\": \"back\",\n      \"unicode\": \"e779\",\n      \"unicode_decimal\": 59257\n    },\n    {\n      \"icon_id\": \"1920286\",\n      \"name\": \"点\",\n      \"font_class\": \"point\",\n      \"unicode\": \"e608\",\n      \"unicode_decimal\": 58888\n    },\n    {\n      \"icon_id\": \"8866967\",\n      \"name\": \"编辑\",\n      \"font_class\": \"edit\",\n      \"unicode\": \"e7dd\",\n      \"unicode_decimal\": 59357\n    },\n    {\n      \"icon_id\": \"10199175\",\n      \"name\": \"删除\",\n      \"font_class\": \"delete\",\n      \"unicode\": \"e614\",\n      \"unicode_decimal\": 58900\n    },\n    {\n      \"icon_id\": \"1010111\",\n      \"name\": \"上传\",\n      \"font_class\": \"upload-1\",\n      \"unicode\": \"e618\",\n      \"unicode_decimal\": 58904\n    },\n    {\n      \"icon_id\": \"351773\",\n      \"name\": \"探索-选中\",\n      \"font_class\": \"explore\",\n      \"unicode\": \"e621\",\n      \"unicode_decimal\": 58913\n    },\n    {\n      \"icon_id\": \"564941\",\n      \"name\": 
\"ellipsis\",\n      \"font_class\": \"ellipsis\",\n      \"unicode\": \"e657\",\n      \"unicode_decimal\": 58967\n    },\n    {\n      \"icon_id\": \"1048859\",\n      \"name\": \"发送\",\n      \"font_class\": \"sent\",\n      \"unicode\": \"e60c\",\n      \"unicode_decimal\": 58892\n    },\n    {\n      \"icon_id\": \"1304951\",\n      \"name\": \"列表\",\n      \"font_class\": \"list-list\",\n      \"unicode\": \"e62d\",\n      \"unicode_decimal\": 58925\n    },\n    {\n      \"icon_id\": \"8676284\",\n      \"name\": \"列表\",\n      \"font_class\": \"list-icon\",\n      \"unicode\": \"e639\",\n      \"unicode_decimal\": 58937\n    },\n    {\n      \"icon_id\": \"22290034\",\n      \"name\": \"重试\",\n      \"font_class\": \"zhongshi\",\n      \"unicode\": \"e6bd\",\n      \"unicode_decimal\": 59069\n    },\n    {\n      \"icon_id\": \"22961085\",\n      \"name\": \"Fork 记录\",\n      \"font_class\": \"log\",\n      \"unicode\": \"e826\",\n      \"unicode_decimal\": 59430\n    }\n  ]\n}\n"
  },
  {
    "path": "kt-sft/ktransformers/website/src/components/chat/index.vue",
    "content": "<template>\n  <div class=\"chat-panel\">\n    <!-- <div class=\"chat-model\">{{ activeAssistant?.model }}</div> -->\n    <div class=\"chat-panel-inner flex-column\">\n      <div class=\"chat-init flex-unit flex-column\" v-if=\"isNotChating\">\n        <div class=\"assistant-info flex-column flex-unit\">\n          <div class=\"avatar\">\n            <img src=\"../../../public/images/avatar.png\" />\n          </div>\n          <div class=\"name\">\n            {{ activeAssistant.name }}\n          </div>\n          <div class=\"desc\">\n            {{ activeAssistant.description }}\n          </div>\n        </div>\n      </div>\n      <div class=\"chat-msg flex-unit\" v-else>\n        <ul>\n          <li\n            class=\"chat-msg-item flex-row\"\n            v-for=\"(msg, index) in localMessages\"\n            :key=\"index\"\n          >\n            <div class=\"avatar\" v-if=\"msg.role == 'user'\">\n              <img src=\"../../../public/images/user-filling.png\" />\n            </div>\n            <div class=\"avatar\" v-else>\n              <img src=\"../../../public/images/avatar.png\" />\n            </div>\n            <div class=\"msg flex-unit\">\n              <div class=\"title flex-row\">\n                <div class=\"name\">{{ msg.role }}</div>\n                <div class=\"time flex-row\">\n                  {{ timeFormat(msg.created_at) }}\n                </div>\n              </div>\n              <div\n                class=\"content\"\n                v-html=\"markedText(msg.content)\"\n                ref=\"content_Ref\"\n              ></div>\n              <div class=\"copy-btn flex-row\" v-show=\"msgBttnBoxShow[index]\">\n                <i\n                  class=\"iconfont icon-copy\"\n                  @click=\"copy(createText(msg.content))\"\n                ></i>\n              </div>\n            </div>\n          </li>\n        </ul>\n      </div>\n      <div class=\"scroll-box\" v-show=\"showScrollButton\" 
@click=\"scrollToBottom\">\n        <i class=\"iconfont icon-arrow-down\"></i>\n      </div>\n      <div class=\"chat-send\">\n        <div\n          class=\"chat-box flex-row\"\n          :style=\"{ height: textareaHeight + 'px' }\"\n          ref=\"chatBox_Ref\"\n        >\n          <button @click=\"StopOutput\" class=\"stop-btn\" v-show=\"isRunning\">\n            stop\n          </button>\n          <textarea\n            name=\"chat-input\"\n            class=\"chat-input flex-unit\"\n            :placeholder=\"inputPlaceholder\"\n            v-model=\"inputQuestion\"\n            @keydown=\"keyBoardCommitQuestion\"\n            :disabled=\"inputDisabled\"\n            :style=\"{ height: textareaHeight + 'px' }\"\n            @input=\"handleInput\"\n            ref=\"textarea_ref\"\n            maxlength=\"2000\"\n            cols=\"20\"\n          ></textarea>\n          <i class=\"iconfont icon-sent\" @click=\"clickCommitQuestion\"></i>\n        </div>\n      </div>\n    </div>\n  </div>\n</template>\n\n<script lang=\"ts\">\nimport {\n  defineComponent,\n  nextTick,\n  PropType,\n  ref,\n  watch,\n  computed,\n  onMounted,\n} from \"vue\";\nimport { IThread, IMessageData, IAssistant } from \"@/utils/types\";\nimport { marked } from \"marked\";\nimport { createMessage } from \"@/api/message\";\nimport { createRun, cancelRun } from \"@/api/run\";\nimport { getAssistant } from \"@/api/assistant\";\nimport { createThread } from \"@/api/thread\";\nimport BScroll from \"better-scroll\";\nimport { useRouter, useRoute } from \"vue-router\";\nimport { useI18n } from \"vue-i18n\";\nimport { ElMessage } from \"element-plus\";\nimport { tr } from \"element-plus/es/locale\";\nimport copy from \"@/utils/copy\";\nexport default defineComponent({\n  name: \"ChatChat\",\n  props: {\n    messages: {\n      type: Array as PropType<IMessageData[]>,\n      required: true,\n    },\n    chatInit: {\n      type: Boolean,\n      required: true,\n    },\n    activeAssistant: {\n    
  type: Object as PropType<IAssistant>,\n      required: true,\n    },\n    activeThread: {\n      type: Object as PropType<IThread>,\n      required: true,\n    },\n    inputDisabled: {\n      type: Boolean,\n      default: false,\n    },\n  },\n  setup(props, context) {\n    const { t } = useI18n();\n    const router = useRouter();\n    const route = useRoute();\n    const localMessages = ref<IMessageData[]>([...props.messages]);\n    const showScrollButton = ref(false);\n    const messageScroll = ref<BScroll | null>(null);\n    const inputQuestion = ref<string>(\"\");\n    const inputDisabled = ref(false);\n    const msgBttnBoxShow = ref<boolean[]>([]);\n    const answer = ref(\"\");\n    const activeThread = ref<IThread>({} as IThread);\n    const activeAssistant = ref<IAssistant>({} as IAssistant);\n    const isNotChating = ref(true);\n    const isRunning = ref(false);\n    const stopRunId = ref<string>(\"\");\n    const shouldContinueReceiving = ref(true);\n    const textareaHeight = ref(48);\n    const chatBox_Ref = ref();\n    const textarea_ref = ref();\n    const content_Ref = ref();\n    // Initialize local state from the incoming props\n    isNotChating.value = props.chatInit;\n    activeThread.value = props.activeThread;\n    activeAssistant.value = props.activeAssistant;\n    watch(\n      () => props.messages,\n      (newMessages) => {\n        localMessages.value = [...newMessages];\n        msgBttnBoxShow.value = new Array(newMessages.length).fill(true);\n      }\n    );\n    watch(\n      () => props.inputDisabled,\n      (newValue) => {\n        inputDisabled.value = newValue;\n      }\n    );\n    // Rebuild the BetterScroll instance whenever the message list changes\n    watch(\n      () => localMessages.value,\n      (newMessages) => {\n        if (messageScroll.value) {\n          scrollToTop();\n          messageScroll.value.destroy();\n          messageScroll.value = null;\n        }\n        if (!isNotChating.value) {\n          nextTick(() => {\n            messageScroll.value = new 
BScroll(\".chat-msg\", {\n              click: true,\n              mouseWheel: true,\n              probeType: 3, // probeType 3 dispatches scroll events in real time, including during momentum animation\n            });\n          });\n        }\n      },\n      {\n        immediate: true,\n        deep: true,\n      }\n    );\n    watch(\n      () => messageScroll.value,\n      (newValue) => {\n        if (newValue) {\n          messageScroll.value?.on(\"scroll\", handleScroll);\n          showScrollButton.value = false;\n          scrollToBottom();\n        }\n      }\n    );\n    watch(\n      () => props.chatInit,\n      (newValue) => {\n        isNotChating.value = newValue;\n      }\n    );\n    watch(\n      () => props.activeThread,\n      (newValue) => {\n        activeThread.value = newValue;\n      }\n    );\n    watch(\n      () => props.activeAssistant,\n      (newValue) => {\n        activeAssistant.value = newValue;\n      }\n    );\n\n    const handleInput = (event: any) => {\n      adjustHeight();\n      const maxLength = 2000;\n      if (inputQuestion.value?.length > maxLength) {\n        event.preventDefault();\n        inputQuestion.value = inputQuestion.value.substring(0, maxLength);\n      }\n    };\n    const adjustHeight = () => {\n      const currentScrollTop = textarea_ref.value.scrollTop;\n      textarea_ref.value.style.height = textarea_ref.value.scrollHeight + \"px\";\n      chatBox_Ref.value.style.height = textarea_ref.value.style.height;\n      textarea_ref.value.scrollTop = currentScrollTop;\n    };\n\n    const inputPlaceholder = computed(() => {\n      if (typeof activeAssistant.value.name != \"undefined\") {\n        return replaceAssistant(t(\"chat.inputTip\"), activeAssistant.value.name);\n      } else {\n        return t(\"chat.inputTip\");\n      }\n    });\n    // Stop the streaming output and cancel the current run\n    const StopOutput = async () => {\n      shouldContinueReceiving.value = false;\n      try {\n        const response = await cancelRun(\n          
activeThread.value.id,\n          stopRunId.value\n        );\n        if (!response.ok) {\n          console.error(\"Failed to cancel run\");\n        }\n      } catch (error) {\n        console.error(\"Failed to cancel run:\", error);\n      }\n    };\n    // Submit the question and stream the assistant's reply\n    const commitQuestion: () => void = async () => {\n      const question = inputQuestion.value;\n      // No active thread yet: the user arrived via an assistant (without picking a thread) or via preview, so create one\n      if (Object.keys(activeThread.value).length == 0) {\n        try {\n          let res = {} as IThread;\n          // Threads created from the preview route are marked as hidden\n          if (route.name == \"preview\") {\n            let metadata = {\n              hidden: \"true\",\n            };\n            res = await createThread(undefined, undefined, metadata);\n          } else {\n            res = await createThread();\n          }\n          activeThread.value = res;\n        } catch (err) {\n          console.error(err);\n        }\n      }\n      // A thread is active but no assistant is selected: recover the assistant from the thread's messages\n      else if (Object.keys(activeAssistant.value).length == 0) {\n        try {\n          const messageOfAssistant = props.messages.find(\n            (message) => message.role === \"assistant\"\n          );\n          if (messageOfAssistant && messageOfAssistant.assistant_id) {\n            const res = await getAssistant(messageOfAssistant.assistant_id);\n            activeAssistant.value = res;\n          }\n        } catch (err) {\n          console.error(err);\n        }\n      }\n      if (question) {\n        inputQuestion.value = \"\";\n        textareaHeight.value = 48;\n        // inputDisabled.value = true;\n        isNotChating.value = false;\n        isRunning.value = true;\n        await createMessage(activeThread.value.id, question)\n          .then((res: any) => {})\n          .catch((err: any) => {\n            ElMessage({\n              type: \"warning\",\n              message: \"Request 
error\",\n            });\n            return;\n          });\n        // Append the user's question to the local message list\n        localMessages.value.push({\n          role: \"user\",\n          content: [\n            { type: \"text\", text: { value: question }, annotatons: [] },\n          ],\n          created_at: Date.now() / 1000,\n        });\n        msgBttnBoxShow.value.push(true);\n        // Append an empty assistant message that the streamed answer will fill in\n        localMessages.value.push({\n          role: \"assistant\",\n          content: [{ type: \"text\", text: { value: \"\" }, annotatons: [] }],\n          created_at: Date.now() / 1000,\n        });\n        msgBttnBoxShow.value.push(false);\n        try {\n          const asyncGenerator = createRun(\n            {\n              assistant_id: activeAssistant.value.id,\n              stream: true,\n            },\n            activeThread.value.id\n          );\n          for await (const word of asyncGenerator) {\n            if (!shouldContinueReceiving.value) {\n              break;\n            }\n            // A 36-character chunk is treated as the run ID rather than answer text\n            if (word.length == 36) {\n              stopRunId.value = word;\n              console.log(stopRunId.value);\n            } else {\n              answer.value += word;\n              const index = localMessages.value.length - 1;\n              localMessages.value[index].content[0].text.value += word;\n              if (answer.value.length <= 3) {\n                localMessages.value[index].created_at = Date.now() / 1000;\n              }\n            }\n          }\n        } catch (err) {\n          console.error(err);\n        }\n        shouldContinueReceiving.value = true;\n        answer.value = \"\";\n        inputDisabled.value = false;\n        msgBttnBoxShow.value[msgBttnBoxShow.value.length - 1] = true;\n        scrollToBottom();\n        isRunning.value = false;\n        context.emit(\"updateAssistant\", true);\n        textarea_ref.value.focus();\n      }\n    };\n    // Keyboard handling: Enter submits the question, Ctrl/Cmd+Enter inserts a newline\n    
const keyBoardCommitQuestion = (event: any) => {\n      const question = inputQuestion.value?.trim();\n      if (event.keyCode === 13) {\n        event.preventDefault();\n\n        const cursorPosition = event.target.selectionStart;\n        if ((event.metaKey || event.ctrlKey) && question) {\n          event.target.value =\n            event.target.value.substring(0, cursorPosition) +\n            \"\\n\" +\n            event.target.value.substring(cursorPosition);\n          event.target.selectionStart = event.target.selectionEnd =\n            cursorPosition + 1;\n          adjustHeight();\n          return;\n        }\n        if (!question) {\n          ElMessage({\n            message: \"Please enter the content!\",\n            type: \"warning\",\n            plain: true,\n          });\n          return;\n        }\n        if (!isRunning.value) {\n          commitQuestion();\n          inputQuestion.value = \"\";\n        }\n      }\n    };\n    const clickCommitQuestion = () => {\n      if (!isRunning.value && inputQuestion.value?.trim() != \"\") {\n        commitQuestion();\n        return;\n      }\n      ElMessage({\n        message: \"Please enter the content!\",\n        type: \"warning\",\n        plain: true,\n      });\n    };\n    //Bottom scrolling\n    const scrollToBottom = () => {\n      //If messageScroll. value exists\n      if (messageScroll.value) {\n        //Call the scrollTo method of messageScroll. 
value and scroll to the bottom\n        messageScroll.value.scrollTo(0, messageScroll.value?.maxScrollY, 800);\n      }\n    };\n    // Top scrolling\n    const scrollToTop = () => {\n      if (messageScroll.value) {\n        messageScroll.value.scrollTo(0, messageScroll.value?.minScrollY, 800);\n      }\n    };\n    // Handling rolling events\n    const handleScroll = (pos: any) => {\n      if (messageScroll.value) {\n        const distanceToBottom =\n          messageScroll.value.y - messageScroll.value.maxScrollY;\n        showScrollButton.value = distanceToBottom > 100;\n      }\n    };\n    // Replace characters\n\n    function replaceAssistant(input: string, newString: string) {\n      return input.replace(/assistant/g, newString);\n    }\n    // Extract the markup text to convert the passed in object array into an HTML string parsed by market.js\n    const markedText = (content: object[]) => {\n      let context = \"\";\n      for (const item of content) {\n        if ((item as { type: string }).type === \"text\") {\n          context += ((item as { text: object }).text as { value: string })\n            .value;\n        }\n      }\n      return marked.parse(context);\n    };\n    // Extract text content\n    const createText = (content: object[]) => {\n      let context = \"\";\n      for (const item of content) {\n        if ((item as { type: string }).type === \"text\") {\n          context += ((item as { text: object }).text as { value: string })\n            .value;\n        }\n      }\n      return context;\n    };\n    // Time formatting\n    const timeFormat = (timestamp: number | undefined) => {\n      if (!timestamp) {\n        return \"\";\n      }\n      const date = new Date(timestamp * 1000);\n      // Obtain various time sections\n      const year = date.getFullYear();\n      const month = String(date.getMonth() + 1).padStart(2, \"0\"); // The month starts from 0 and needs to be increased by 1, with zeros added\n      const day = 
String(date.getDate()).padStart(2, \"0\"); // Zero padding\n      const hours = String(date.getHours()).padStart(2, \"0\"); // Zero padding\n      const minutes = String(date.getMinutes()).padStart(2, \"0\"); // Zero padding\n      const seconds = String(date.getSeconds()).padStart(2, \"0\"); // Zero padding\n      // Format as \"YYYY-MM-DD HH: mm: ss\"\n      const formattedDate = `${year}-${month}-${day} ${hours}:${minutes}:${seconds}`;\n      return formattedDate;\n    };\n    onMounted(() => {\n      adjustHeight();\n    });\n    return {\n      inputQuestion,\n      inputDisabled,\n      msgBttnBoxShow,\n      localMessages,\n      textareaHeight,\n      answer,\n      StopOutput,\n      isNotChating,\n      handleInput,\n      chatBox_Ref,\n      adjustHeight,\n      content_Ref,\n      markedText,\n      timeFormat,\n      createText,\n      inputPlaceholder,\n      keyBoardCommitQuestion,\n      clickCommitQuestion,\n      messageScroll,\n      showScrollButton,\n      commitQuestion,\n      scrollToBottom,\n      scrollToTop,\n      isRunning,\n      copy,\n      replaceAssistant,\n      textarea_ref,\n    };\n  },\n});\n</script>\n\n<style scoped lang=\"stylus\">\n@import '@/assets/css/mixins.styl';\n\n.chat-panel {\n  justify-content: center;\n  display: flex;\n  position: relative;\n  height: 100%;\n\n  .chat-model {\n    font-size: 16px;\n    font-weight: bold;\n    position: absolute;\n    top: 20px;\n    left: 30px;\n  }\n\n  .chat-panel-inner {\n    width: 920px;\n    padding-top: 80px;\n  }\n\n  .chat-init {\n    padding: 0 20px;\n\n    .assistant-info {\n      text-align: center;\n      align-items: center;\n      justify-content: center;\n\n      .avatar img {\n        width: 70px;\n        height: 70px;\n      }\n\n      .name {\n        margin: 40px 0;\n        font-size: 20px;\n        font-weight: bold;\n      }\n\n      .desc {\n        color: $gray_40;\n      }\n    }\n\n    .assistant-tips {\n      margin-bottom: 80px;\n\n      .tips-item 
{\n        width: 44%;\n        height: 70px;\n        line-height: 70px;\n        float: left;\n        border: 1px solid $border_gray_light_normal;\n        border-radius: 8px;\n        margin-top: 10px;\n        margin-bottom: 10px;\n        padding: 0 20px;\n        color: $gray_40;\n\n        &:nth-child(odd) {\n          margin-left: 4%;\n          margin-right: 4%;\n        }\n\n        &:nth-child(even) {\n          margin-right: 4%;\n        }\n\n        .tips-ops {\n          display: none;\n          width: 24px;\n          height: 24px;\n          line-height: 24px;\n          border-radius: 4px;\n          text-align: center;\n          border: 1px solid $border_gray_light_normal;\n\n          i {\n            font-size: 20px;\n          }\n        }\n\n        &:hover {\n          cursor: pointer;\n          background-color: $bg_gray_light_hover;\n\n          .tips-ops {\n            display: block;\n            background-color: #FFFFFF;\n          }\n        }\n      }\n    }\n  }\n\n  .chat-msg {\n    overflow-y: hidden;\n\n    ul {\n      li.chat-msg-item {\n        margin-bottom: 40px;\n        align-items: flex-start !important;\n        // border: 1px solid;\n        border-radius: 15px;\n        padding: 20px;\n        margin-right: 20px;\n        background-color: #313344;\n        box-shadow: 12.5px 12.5px 10px rgba(0, 0, 0, 0.035), 10px 10px 8px rgba(0, 0, 0, 0.07);\n\n        .avatar {\n          margin-right: 15px;\n          width: 36px;\n          height: 36px;\n\n          img {\n            width: 100%;\n            height: 100%;\n            border-radius: 25px;\n          }\n        }\n\n        .msg {\n          .title {\n            display: flex;\n            align-items: center;\n            justify-content: space-between;\n            margin-bottom: 12px;\n            height: 36px;\n            line-height: 24px;\n\n            .time {\n              justify-content: center;\n              // margin-bottom: 12px;\n             
 line-height: 20px;\n              font-size: 14px;\n              color: $gray_80;\n            }\n\n            .name {\n              color: #edf2ea;\n              font-size: 16px;\n              font-weight: bold;\n              margin-right: 15px;\n            }\n\n            .tips {\n              font-size: 14px;\n              color: $gray_50;\n            }\n          }\n\n          .content {\n            max-width: 829px;\n            color: #edf2ea;\n            font-size: 14px;\n            line-height: 20px;\n            word-wrap: break-word;\n            margin-bottom: 12px;\n          }\n\n          .copy-btn {\n            margin-top: 10px;\n            justify-content: left;\n\n            i {\n              font-size: 20px;\n              color: $gray_70;\n\n              &:hover {\n                cursor: pointer;\n                color: $gray_50;\n\n                .tips-ops {\n                  display: block;\n                  background-color: #FFFFFF;\n                }\n              }\n            }\n          }\n        }\n      }\n    }\n  }\n\n  .chat-send {\n    width: 900px;\n    padding: 40px 0;\n    position: relative;\n\n    .chat-box {\n      width: 100%;\n      height: auto;\n      min-height: 48px;\n      max-height: 192px !important;\n      border: none;\n      border-radius: 15px;\n      background: white;\n      line-height: 48px;\n\n      // overflow: hidden;\n      .chat-input {\n        height: auto;\n        min-width: 900px;\n        max-height: 192px !important;\n        width: 100%;\n        border: none;\n        overflow-anchor: auto;\n        overflow-x: hidden;\n        overflow-y: auto;\n        resize: none;\n        background: white;\n        display: inline-block;\n      }\n\n      .chat-input::-webkit-scrollbar {\n        width: 10px;\n      }\n\n      .chat-input::-webkit-scrollbar-track {\n        background-color: #f1f1f1;\n      }\n\n      .chat-input::-webkit-scrollbar-thumb {\n        
background-color: #888;\n        border-radius: 5px;\n      }\n\n      .chat-input::-webkit-scrollbar-thumb:hover {\n        background-color: #555;\n      }\n\n      .chat-input::-webkit-resizer {\n        display: none;\n      }\n\n      .stop-btn {\n        border: none;\n        width: 60px;\n        position: absolute;\n        right: 50%;\n        transform: translateX(50%);\n        top: -40px;\n        -webkit-border-radius: 50;\n        -moz-border-radius: 50;\n        border-radius: 50px;\n        font-family: Arial;\n        color: #ffffff;\n        font-size: 16px;\n        background: #cacdd1;\n        padding: 10px 15px 10px 15px;\n        text-decoration: none;\n      }\n\n      .stop-btn:hover {\n        background: #8080e1;\n        text-decoration: none;\n        cursor: pointer;\n      }\n    }\n  }\n}\n\n.scroll-box {\n  position: absolute;\n  bottom: 130px;\n  right: 50%;\n  transform: translateX(50%);\n  margin: 0 auto;\n  width: 32px;\n  height: 32px;\n  border-radius: 16px;\n  border: 1px solid $gray_80;\n  background-color: var(--el-bg-color-overlay);\n  box-shadow: var(--el-box-shadow-lighter);\n  text-align: center;\n  line-height: 32px;\n  color: #1989fa;\n\n  i {\n    font-size: 24px;\n    color: $gray_60;\n  }\n\n  &:hover {\n    cursor: pointer;\n    background-color: $bg_gray_light_hover;\n\n    i {\n      color: $gray_50;\n    }\n  }\n}\n</style>"
  },
  {
    "path": "kt-sft/ktransformers/website/src/conf/config.ts",
    "content": "declare global {\n    interface Window {\n      configWeb: {\n        apiUrl: string;\n        port: string;\n       };\n     }\n  }\n\nexport const baseURL = window.configWeb.apiUrl;\nexport const basePort = window.configWeb.port;\n"
  },
  {
    "path": "kt-sft/ktransformers/website/src/locals/en.js",
    "content": "// en.js\nexport default {\n    home: {\n        explore: 'Explore',\n        language: 'Choose Language',\n        english: 'English',\n        chinese: 'Chinese',\n        today: 'Today',\n        previous:'Previous',\n        withoutAssistantTip:'The KTransformers of this record has been deleted. The user can only view historical conversation information and cannot continue the conversation!',\n        deleteThreadTip:'Deleting records will clear historical information~'\n    },\n    chat:{\n        inputTip:\"Send a message and chat with the KTransformers ~\",\n    },\n    explore:{\n        description: \"Based on Lexllama, let’s create your own KTransformers~\",\n        configuring: \"Configuring\",\n        completed: \"Completed\",\n        assistantName: \"Name\",\n        assistantDescription: \"Description\",\n        assistantStatus: \"Status\",\n        createAssistant: \"Create New KTransformers\",\n        deleteAssistant: \"Are you sure you want to delete this? After deleting the KTransformers, its KVCache will also be cleared simultaneously~\",\n    },\n    config:{\n        title:'Configure your KTransformers',\n        fileTip:\"Only text, .docx, .ppt, and .pdf formats are supported.\",\n        reConfigTip:'Reconfiguring KTransformers requires deleting its KVCache; please choose carefully',\n        secletFile:'Select Files',\n        outOfSize:'File size exceeds 10MB, please reselect',\n        fileExist:'The file already exists, please reselect',\n        createAssistant:'Assistant created successfully, click the build button to start building KVCache',\n    },\n    build:{\n        title:'Building Logs',\n        step1:'Parse uploaded files',\n        parsingFileStep1:'File upload and reception completed',\n        parsingFileStep2:{\n            parse:\"Parsing\",\n            file:\"file(s)\",\n            total:'total',\n        },\n        parsingFileStep3:'Prompt loaded, ready to generate KVCache',\n        step2:'Generate KVCache',\n        
generateStep1:'Generate KVCache calculation plan',\n        generateStep2:{\n            calculate:\"calculating\",\n            token:\"tokens\",\n            total:'total',\n        },\n        generateStep3:'KVCache has been generated successfully',\n        durationTime:'Duration:',\n        remainTime:'Time left:',\n        buildProgress:'Building Progress',\n        storageUsage:'KVCache Storage Usage',\n    }\n}\n"
  },
  {
    "path": "kt-sft/ktransformers/website/src/locals/index.js",
    "content": "// index.js\nimport { createI18n } from 'vue-i18n'\nimport zh from './zh'\nimport en from './en'\n\nconst messages = {\n  en,\n  zh,\n}\nconst language = (navigator.language || 'en').toLocaleLowerCase() // Get the browser language\nconst i18n = createI18n({\n  legacy: false, // must be `false` to use the Composition API\n  locale: localStorage.getItem('lang') || language.split('-')[0] || 'en', // Read the saved language from localStorage first; otherwise fall back to the browser language\n  fallbackLocale: 'en', // Fallback language\n  messages, \n})\n\nexport default i18n"
  },
  {
    "path": "kt-sft/ktransformers/website/src/locals/zh.js",
    "content": "// zh.js\nexport default {\n    home: {\n        explore: '探索',\n        language: '选择语言',\n        english: '英语',\n        chinese: '中文',\n        today: '今天',\n        previous:'历史',\n        withoutAssistantTip:'本记录的KTransformers已被删除，用户只能查看历史对话信息而无法继续对话!',\n        deleteThreadTip:'删除记录会清除历史信息哦～'\n    },\n    chat:{\n        inputTip:\"发送信息和 KTransformers 畅聊吧～\",\n    },\n    explore:{\n        description: \"基于Lexllama，一起来创建你的专属KTransformers吧~\",\n        configuring: \"配置中\",\n        completed: \"完成\",\n        assistantName: \"名称\",\n        assistantDescription: \"描述\",\n        assistantStatus: \"Status\",\n        createAssistant: \"创建新的KTransformers\",\n        deleteAssistant: \"是否确认删除KTransformers，删除KTransformers之后其KVCache也会被同步清理掉哦~\",\n    },\n    config:{\n        title:'配置你的KTransformers',\n        fileTip:\"仅支持上传文件格式为 .text, docx, .ppt, .pdf format.\",\n        secletFile:'选择文件',\n        outOfSize:'文件大小超出10MB，请重新选择',\n        fileExist:'文件已存在，请重新选择',\n        createAssistant:'KTransformers创建成功，点击build按钮开始构建KVCache',\n    },\n    build:{\n        title:'构建日志',\n        step1:'解析上传文件',\n        parsingFileStep1:'文件上传接收完成',\n        parsingFileStep2:{\n            parse:\"正在解析第\",\n            file:\"文件\",\n            total:'共',\n        },\n        parsingFileStep3:'Prompt装载完毕，准备生成KVCache',\n        step2:'生成 KVCache',\n        generateStep1:'生成KVCache计算计划',\n        generateStep2:{\n            calculate:\"正在计算\",\n            token:\"tokens\",\n            total:'共',\n        },\n        generateStep3:'KVCache已生成完成',\n        durationTime:'持续时间：',\n        remainTime:'剩余时间：',\n        buildProgress:'构建进度',\n        storageUsage:'存储使用：',\n        \n    }\n}\n"
  },
  {
    "path": "kt-sft/ktransformers/website/src/main.ts",
    "content": "import { createApp } from 'vue'\nimport App from './App.vue'\nimport router from './router'\nimport store from './store'\nimport ElementPlus from 'element-plus'\nimport 'element-plus/dist/index.css'\nimport VueApexCharts from \"vue3-apexcharts\"\nimport i18n from '@/locals'\n\nconst app = createApp(App)\n\napp.use(ElementPlus)\n\napp.use(i18n)\napp.use(VueApexCharts)\napp.use(store)\napp.use(router)\napp.mount('#app')\n"
  },
  {
    "path": "kt-sft/ktransformers/website/src/router/index.ts",
    "content": "import { createRouter, createWebHashHistory, RouteRecordRaw } from 'vue-router'\nimport HomeView from '@/views/home.vue'\n\nconst routes: Array<RouteRecordRaw> = [\n  {\n    path: '/',\n    name: 'home',\n    component: HomeView,\n    redirect: '/chat',\n    children: [{\n      path: '/chat',\n      name: '',\n      component: () => import(/* webpackChunkName: \"about\" */ '../components/chat/index.vue')\n    },]\n  },\n\n]\n\nconst router = createRouter({\n  history: createWebHashHistory(),\n  routes\n})\n\nexport default router\n"
  },
  {
    "path": "kt-sft/ktransformers/website/src/shims-vue.d.ts",
    "content": "/* eslint-disable */\ndeclare module '*.vue' {\n  import type { DefineComponent } from 'vue'\n  const component: DefineComponent<{}, {}, any>\n  export default component\n  \n}\n\ndeclare module '@/locals'\ndeclare module 'pdfobject';\n"
  },
  {
    "path": "kt-sft/ktransformers/website/src/store/index.ts",
    "content": "import { createStore } from 'vuex'\n\nexport default createStore({\n  state: {\n  },\n  getters: {\n  },\n  mutations: {\n  },\n  actions: {\n  },\n  modules: {\n  }\n})\n"
  },
  {
    "path": "kt-sft/ktransformers/website/src/utils/copy.ts",
    "content": "import { ElMessage } from \"element-plus\";\nconst copy = (value: string) => {\n  //Try using the navigator.clipboard.writeText method\n  if (navigator.clipboard && window.isSecureContext) {\n    navigator.clipboard.writeText(value)\n      .then(() => {\n        //Using ElMessage to Display Success Messages in Windows Systems\n        if (navigator.appVersion.includes(\"Win\")) {\n          ElMessage({\n            message: \"内容复制成功!\",\n            type: \"success\",\n            plain: true,\n          });\n        } else {\n          //Using custom DOM elements to display success messages in macOS system\n          showCopySuccessMessage();\n        }\n      })\n      .catch(() => {\n        //Using ElMessage to Display Failure Messages in Windows Systems\n        if (navigator.appVersion.includes(\"Win\")) {\n          ElMessage({\n            message: \"内容复制失败!\",\n            type: \"error\",\n            plain: true,\n          });\n        } else {\n          //Using custom DOM elements to display failure messages in macOS system\n          showCopyErrorMessage();\n        }\n      });\n  } else {\n    const textarea = document.createElement(\"textarea\");\n    textarea.value = value;\n    document.body.appendChild(textarea);\n    textarea.select();\n    try {\n      const successful = document.execCommand('copy');\n      if (successful) {\n        if (navigator.appVersion.includes(\"Win\")) {\n          ElMessage({\n            message: \"内容复制成功!\",\n            type: \"success\",\n            plain: true,\n          });\n        } else {\n          showCopySuccessMessage();\n        }\n      } else {\n        if (navigator.appVersion.includes(\"Win\")) {\n          ElMessage({\n            message: \"内容复制失败!\",\n            type: \"error\",\n            plain: true,\n          });\n        } else {\n          showCopyErrorMessage();\n        }\n      }\n    } catch (err) {\n      if (navigator.appVersion.includes(\"Win\")) {\n        
ElMessage({\n          message: \"内容复制失败!\",\n          type: \"error\",\n          plain: true,\n        });\n      } else {\n        showCopyErrorMessage();\n      }\n    }\n    document.body.removeChild(textarea);\n  }\n};\n\nfunction showCopySuccessMessage() {\n  const messageElement = document.createElement('div');\n  messageElement.textContent = '内容复制成功!';\n  messageElement.style.position = 'fixed';\n  messageElement.style.bottom = '10px';\n  messageElement.style.left = '50%';\n  messageElement.style.transform = 'translateX(-50%)';\n  messageElement.style.padding = '10px';\n  messageElement.style.backgroundColor = '#4CAF50';\n  messageElement.style.color = 'white';\n  messageElement.style.borderRadius = '15px';\n  messageElement.style.zIndex = '1000';\n  document.body.appendChild(messageElement);\n  setTimeout(() => {\n    document.body.removeChild(messageElement);\n  }, 3000);\n}\n\nfunction showCopyErrorMessage() {\n  const messageElement = document.createElement('div');\n  messageElement.textContent = '内容复制失败!';\n  messageElement.style.position = 'fixed';\n  messageElement.style.bottom = '10px';\n  messageElement.style.left = '50%';\n  messageElement.style.transform = 'translateX(-50%)';\n  messageElement.style.padding = '10px';\n  messageElement.style.backgroundColor = '#F44336';\n  messageElement.style.color = 'white';\n  messageElement.style.borderRadius = '5px';\n  messageElement.style.zIndex = '1000';\n  document.body.appendChild(messageElement);\n  setTimeout(() => {\n    document.body.removeChild(messageElement);\n  }, 3000);\n}\n\nexport default copy;"
  },
  {
    "path": "kt-sft/ktransformers/website/src/utils/types.ts",
    "content": "export interface IAssistant {\n  id: string;\n  object: string;\n  created_at: number;\n  name?: string;\n  description?: string;\n  model: string;\n  instructions?: string;\n  tools: any[];\n  tool_resources?: object;\n  metadata?:{[key:string]:any}\n  top_p?: number;\n  temperature?: number;\n  response_format: string | object;\n}\n\nexport interface IAssistantWithStatus {\n  build_status:{status:string}\n  id: string;\n  object: string;\n  created_at: number;\n  name?: string;\n  description?: string;\n  model: string;\n  instructions?: string;\n  tools: any[];\n  tool_resources?: object;\n  metadata?:{[key:string]:any}\n  top_p?: number;\n  temperature?: number;\n  response_format: string | object;\n}\n\nexport interface IMessage {\n  id: string;\n  object: string;\n  created_at: number;\n  thread_id: string;\n  status: string;\n  incomplete_details?: object;\n  completed_at?: number;\n  incomplete_at?: number;\n  role: string;\n  content: any[];\n  assistant_id?: string;\n  run_id?: string;\n  attachments?: any[];\n  metadata:{[key:string]:any}\n}\n\nexport interface IThread {\n  id: string;\n  object: string;\n  created_at: number;\n  tool_resources?: object;\n  metadata?:{[key:string]:any}\n}\n\nexport interface IRun {\n  id: string;\n  object: string;\n  created_at: number;\n  thread_id: string,\n  assistant_id: string,\n  status: string,\n  required_action?: object,\n  last_error?: object,\n  expires_at?: number,\n  started_at?: number,\n  cancelled_at?: number,\n  failed_at?: number,\n  completed_at?: number,\n  incomplete_details?: object,\n  model: string,\n  instructions: string,\n  tools: any[],\n  metadata: Map<string, string>,\n  usage?: object,\n  temperature?: number,\n  top_p?: number,\n  max_prompt_tokens?: number,\n  max_completion_tokens?: number,\n  truncation_strategy: object,\n  tool_choice: string | object,\n  response_format: string | object,\n}\n\nexport interface IFile {\n  id: string,\n  bytes: number,\n  created_at: 
number,\n  filename: string,\n  object: string,\n  purpose: string,\n}\n\nexport interface IMessageData {\n  role: string;\n  content: any[];\n  created_at?: number;\n  assistant_id?: string,\n}\n\nexport interface IThreadAndMessageAndAssistant {\n\n  thread: IThread;\n  first_message: IMessage;\n  assistant: IAssistantWithStatus\n}\nexport interface IDeleteResult {\n  id: string;\n  object: string;\n  deleted: boolean;\n}\nexport interface IBuildData {\n  parsed_file_count:number;\n  total_file_count:number;\n  prefilling_current:number;\n  prefilling_total:number;\n  build_completed_time:number;\n  build_started_time:number;\n  storage_total:number;\n  storage_usage:number;\n  status:string\n}"
  },
  {
    "path": "kt-sft/ktransformers/website/src/views/home.vue",
    "content": "<template>\n  <div class=\"home flex-row\">\n    <nav class=\"left-panel flex-column\">\n      <div class=\"logo-box\">\n        <div class=\"logo flex-row\">\n          <img class=\"img\" src=\"../../public/images/three.png\" />\n          <span class=\"text\">{{ projectName }}</span>\n        </div>\n        <div class=\"version\">{{ projectVersion }}</div>\n      </div>\n      <div class=\"divider\"></div>\n      <div class=\"assistant-box\">\n        <div class=\"assistant-list\">\n          <ul>\n            <li\n              class=\"assistant-item flex-row\"\n              v-for=\"(item, index) in assistantList\"\n              :key=\"index\"\n              @click=\"setActiveAssistant(item)\"\n            >\n              <img src=\"../../public/images/avatar.png\" />\n              <span class=\"name flex-unit\">{{ item.name }}</span>\n              <i class=\"iconfont icon-edit\"></i>\n            </li>\n          </ul>\n        </div>\n      </div>\n      <div class=\"divider\"></div>\n      <!-- History area -->\n      <div class=\"history-box flex-unit\">\n        <div class=\"\">\n          <div class=\"date\">{{ $t(\"home.today\") }}</div>\n          <ul>\n            <li\n              v-for=\"(item, index) in todayThreads\"\n              :key=\"index\"\n              class=\"chat-item\"\n              :class=\"{ active: activeThreadIndex === index }\"\n              @click=\"setActiveThreadIndex(index)\"\n            >\n              <div class=\"chat-abbr\">\n                {{ firstMessages[index] }}\n              </div>\n              <div class=\"chat-ops flex-row\">\n                <img src=\"../../public/images/avatar.png\" />\n                <div class=\"name flex-unit\">\n                  {{ assistantOfThread[index].name || \"\" }}\n                </div>\n                <i class=\"iconfont icon-delete\" @click=\"delThread(index)\"></i>\n              </div>\n            </li>\n          </ul>\n          <div 
class=\"date\" v-if=\"previousThreads.length > 0\">\n            {{ $t(\"home.previous\") }}\n          </div>\n          <ul>\n            <li\n              v-for=\"(item, index) in previousThreads\"\n              :key=\"index\"\n              class=\"chat-item\"\n              :class=\"{\n                active: activeThreadIndex === index + todayThreads.length,\n              }\"\n              @click=\"setActiveThreadIndex(index + todayThreads.length)\"\n            >\n              <div class=\"chat-abbr\">\n                {{ firstMessages[index + todayThreads.length] }}\n              </div>\n              <div class=\"chat-ops flex-row\">\n                <img src=\"../../public/images/avatar.png\" />\n                <div class=\"name flex-unit\">\n                  {{\n                    assistantOfThread[index + todayThreads.length].name || \"\"\n                  }}\n                </div>\n                <i\n                  class=\"iconfont icon-delete\"\n                  @click=\"delThread(index + todayThreads.length)\"\n                ></i>\n              </div>\n            </li>\n          </ul>\n        </div>\n      </div>\n      <div class=\"icon-box example-2\">\n        <div class=\"iconhub icon-content\" @click=\"navigateToIconHub\">\n          <svg\n            xmlns=\"http://www.w3.org/2000/svg\"\n            width=\"16\"\n            height=\"16\"\n            fill=\"currentColor\"\n            class=\"bi bi-github\"\n            viewBox=\"0 0 16 16\"\n            xml:space=\"preserve\"\n          >\n            <path\n              d=\"M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27s1.36.09 2 .27c1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 
1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.01 8.01 0 0 0 16 8c0-4.42-3.58-8-8-8\"\n              fill=\"currentColor\"\n            ></path>\n          </svg>\n          <div class=\"tooltip\">GitHub</div>\n        </div>\n        <div class=\"iconlanguage\" @click=\"changeLanguage\">\n          <svg\n            v-if=\"!flag\"\n            t=\"1719306572024\"\n            class=\"icon\"\n            viewBox=\"0 0 1024 1024\"\n            version=\"1.1\"\n            xmlns=\"http://www.w3.org/2000/svg\"\n            p-id=\"16849\"\n            data-spm-anchor-id=\"a313x.search_index.0.i21.366e3a81tz0TYS\"\n            width=\"18\"\n            height=\"18\"\n          >\n            <path\n              d=\"M64.064 768V192H448.64v64H127.936v192h320v64h-320v192h320v64H64.064z m511.872 0V192h64l256 447.68V192h64v576h-64l-256-447.168V768h-64z\"\n              p-id=\"16850\"\n              data-spm-anchor-id=\"a313x.search_index.0.i22.366e3a81tz0TYS\"\n              class=\"selected\"\n              fill=\"#000000\"\n            ></path>\n          </svg>\n          <svg\n            v-else\n            t=\"1719306494614\"\n            class=\"icon\"\n            viewBox=\"0 0 1024 1024\"\n            version=\"1.1\"\n            xmlns=\"http://www.w3.org/2000/svg\"\n            p-id=\"12325\"\n            width=\"18\"\n            height=\"18\"\n          >\n            <path\n              d=\"M1023.488 831.552h-96l-265.472-451.904c-8.96-12.8-16-25.344-21.44-37.888H638.08c2.176 12.992 3.2 40.128 3.2 81.408v408.32L576 836.928V256h101.568l257.024 445.632c14.592 20.992 23.232 34.368 25.92 40.128h1.6c-2.688-16.512-4.032-44.8-4.032-84.736v-399.36L1024 256l-0.512 575.552zM435.008 804.224c-42.752 21.76-96.384 32.64-160.896 32.64-83.2 0-149.76-25.6-199.488-76.736C24.896 708.928 0 641.344 0 557.12c0-90.432 27.968-163.2 84.032-218.368C140.032 283.52 211.072 256 297.344 256c55.552 0 101.376 7.616 
137.6 22.848v75.84a284.992 284.992 0 0 0-136.832-33.408c-64.768 0-117.504 20.864-158.208 62.592-40.768 41.728-61.184 98.048-61.184 168.96 0 67.2 19.008 120.576 57.024 160.128 38.016 39.552 87.744 59.328 149.248 59.328 57.536 0 107.52-12.544 150.016-37.76v69.696z\"\n              fill=\"#000000\"\n              p-id=\"12326\"\n              data-spm-anchor-id=\"a313x.search_index.0.i16.366e3a81tz0TYS\"\n              class=\"selected\"\n            ></path>\n          </svg>\n        </div>\n      </div>\n    </nav>\n    <router-view v-slot=\"{ Component }\" class=\"main-panel flex-unit\">\n      <component\n        :is=\"Component\"\n        :chatInit=\"chatInit\"\n        :activeAssistant=\"activeAssistant\"\n        :activeThread=\"activeThread\"\n        :messages=\"allMessageInCurrentThread\"\n        :completedAssistant=\"assistantList\"\n        :inputDisabled=\"inputDisabled\"\n        @updateAssistant=\"handleUpdateAssistant\"\n      />\n    </router-view>\n  </div>\n</template>\n\n<script lang=\"ts\">\nimport { defineComponent, ref, onMounted, computed, nextTick } from \"vue\";\nimport {\n  IThread,\n  IAssistant,\n  IMessageData,\n  IThreadAndMessageAndAssistant,\n  IAssistantWithStatus,\n} from \"@/utils/types\";\nimport { listThreads, deleteThread, getThread } from \"@/api/thread\";\nimport { ElMessage, ElMessageBox } from \"element-plus\";\nimport { listAssistants } from \"@/api/assistant\";\nimport { listMessages } from \"@/api/message\";\nimport { useRouter } from \"vue-router\";\nimport BScroll from \"better-scroll\";\nimport { useI18n } from \"vue-i18n\";\n\nexport default defineComponent({\n  name: \"HomeView\",\n  setup() {\n    const assistantList = ref<IAssistant[]>([]);\n    const threadsList = ref<IThread[]>([]);\n    const firstMessages = ref<string[]>([]);\n    const activeAssistant = ref({} as IAssistant);\n    const assistantOfThread = ref<IAssistantWithStatus[]>([]);\n    const threadAndMessages = 
ref<IThreadAndMessageAndAssistant[]>([]);\n    const assistantScroll = ref<BScroll | null>(null);\n    const historyScroll = ref<BScroll | null>(null);\n    const router = useRouter();\n    const { t, locale } = useI18n();\n    const flag = ref(true);\n    const changeLanguage = () => {\n      if (flag.value) {\n        locale.value = \"zh\";\n        localStorage.setItem(\"lang\", \"zh\");\n        flag.value = false;\n      } else {\n        locale.value = \"en\";\n        flag.value = true;\n        localStorage.setItem(\"lang\", \"en\");\n      }\n    };\n    // Initialize data\n    const initData = async () => {\n      try {\n        threadsList.value = [];\n        firstMessages.value = [];\n        assistantOfThread.value = [];\n\n        const assistantsRes = await listAssistants();\n        if (assistantsRes && assistantsRes.length > 0) {\n          assistantList.value = assistantsRes;\n          activeAssistant.value = assistantsRes[0];\n        }\n\n        const threadsRes = await listThreads(100);\n        if (threadsRes) {\n          threadAndMessages.value = threadsRes;\n          for (let t of threadsRes) {\n            if (t.thread && !t.thread.metadata?.hidden) {\n              threadsList.value.push(t.thread);\n              if (\n                t.first_message &&\n                t.first_message.content &&\n                t.first_message.content.length > 0\n              ) {\n                firstMessages.value.push(t.first_message.content[0].text.value);\n              } else {\n                firstMessages.value.push(\"no message yet\");\n              }\n              assistantOfThread.value.push(\n                t.assistant || ({} as IAssistantWithStatus)\n              );\n            }\n          }\n        }\n\n        assistantScroll.value = new BScroll(\".assistant-list\", {\n          click: true,\n          mouseWheel: true,\n          scrollbar: {\n            fade: true,\n            interactive: true,\n          },\n        
});\n\n        historyScroll.value = new BScroll(\".history-box\", {\n          click: true,\n          mouseWheel: true,\n          scrollbar: {\n            fade: true,\n            interactive: true,\n          },\n        });\n      } catch (err) {\n        console.error(\"Failed to initialize data:\", err);\n      }\n    };\n    const navigateToIconHub = () => {\n      window.open(\"https://github.com/kvcache-ai/Lexllama\");\n    };\n    const isEmptyObject = (obj: object): boolean => {\n      //Determine if the object is empty\n      return Object.keys(obj).length === 0;\n    };\n    //Jump route\n    const navigateToExplore = () => {\n      router.push(\"/explore\");\n    };\n    const navigatorToChat = () => {\n      router.push(\"/chat\");\n    };\n    // Calculate date\n    const todayThreads = computed(() => {\n      const today = Math.floor(Date.now() / 1000);\n      return threadsList.value.filter((thread) => {\n        return today - thread.created_at <= 86400;\n      });\n    });\n    const previousThreads = computed(() => {\n      const today = Math.floor(Date.now() / 1000);\n      return threadsList.value.filter((thread) => {\n        return today - thread.created_at > 86400;\n      });\n    });\n\n    onMounted(async () => {\n      initData();\n    });\n\n    return {\n      t,\n      flag,\n      assistantList,\n      isEmptyObject,\n      activeAssistant,\n      navigateToExplore,\n      navigatorToChat,\n      threadsList,\n      firstMessages,\n      navigateToIconHub,\n      assistantScroll,\n      historyScroll,\n      assistantOfThread,\n      changeLanguage,\n      initData,\n      todayThreads,\n      previousThreads,\n    };\n  },\n  data() {\n    return {\n      projectName: \"KTransformers\",\n      projectVersion: \"v0.01\",\n      activeThreadIndex: -1,\n      chatInit: true,\n      activeThread: {} as IThread,\n      allMessageInCurrentThread: [] as IMessageData[],\n      inputDisabled: false,\n      isSettingActiveThread: false,\n  
    isDeletingThread: false,\n      threadAndMessages: <IThreadAndMessageAndAssistant[]>[],\n    };\n  },\n  methods: {\n    setActiveAssistant(assistant: IAssistant) {\n      this.chatInit = true;\n      this.inputDisabled = false;\n      this.activeThreadIndex = -1;\n      this.activeAssistant = assistant;\n      this.activeThread = {} as IThread;\n      this.allMessageInCurrentThread = [];\n      if (this.$route.path != \"/chat\") {\n        this.navigatorToChat();\n      }\n    },\n    async setActiveThreadIndex(index: number) {\n      //If setting up an active thread, return directly\n      if (this.isSettingActiveThread) {\n        return;\n      }\n      this.isSettingActiveThread = true;\n      this.activeThreadIndex = index;\n      this.chatInit = false;\n      this.inputDisabled = false;\n      this.activeAssistant = {} as IAssistant;\n      this.activeThread = this.threadsList[index];\n      //If the assistant of the current thread is an empty object\n      if (this.isEmptyObject(this.assistantOfThread[index])) {\n        ElMessage({\n          message: this.t(\"home.withoutAssistantTip\"),\n          type: \"warning\",\n        });\n        this.inputDisabled = true;\n      }\n      try {\n        //Call asynchronous function to obtain the message list of the current thread\n        const res = await listMessages(this.activeThread.id, 100, \"asc\");\n        //Convert the obtained message list to the specified format and assign values to all messages of the current thread\n        this.allMessageInCurrentThread = res.map((m) => ({\n          role: m.role,\n          content: m.content,\n          assistant_id: m.assistant_id,\n          created_at: m.created_at,\n        }));\n      } catch (err) {\n        console.log(err);\n      } finally {\n        this.isSettingActiveThread = false;\n      }\n      if (this.$route.path != \"/chat\") {\n        this.navigatorToChat();\n      }\n    },\n\n    async delThread(index: number) {\n      // If the thread 
is currently being deleted, return directly\n      if (this.isDeletingThread) {\n        return;\n      }\n      this.isDeletingThread = true;\n      try {\n        //Pop up a confirmation box and ask the user if they are sure to delete the thread\n        await ElMessageBox.confirm(this.t(\"home.deleteThreadTip\"), \"Warning\", {\n          confirmButtonText: \"OK\",\n          cancelButtonText: \"Cancel\",\n          type: \"warning\",\n        });\n\n        const res = await deleteThread(this.threadsList[index].id);\n        this.threadsList.splice(index, 1);\n        this.firstMessages.splice(index, 1);\n        this.assistantOfThread.splice(index, 1);\n        // Jump to the first assistant or other suitable page\n        this.setActiveAssistant(this.assistantList[0]);\n        ElMessage({\n          type: \"success\",\n          message: \"Delete completed\",\n        });\n      } catch (err) {\n        // Specific error handling, such as logging or displaying specific error messages to users\n        console.error(\"Delete session failed:\", err);\n        ElMessage({\n          type: \"error\",\n          message: `Delete failed`, // Display specific error messages\n        });\n      } finally {\n        this.isDeletingThread = false; //Ensure that the delete thread flag is reset no matter what\n      }\n    },\n    // Handles the update of the assistant asynchronously.\n    async handleUpdateAssistant(value: any) {\n      await this.initData();\n      if (this.activeThreadIndex != -1) {\n        this.setActiveThreadIndex(this.activeThreadIndex);\n      } else if (this.activeAssistant.id) {\n        this.setActiveThreadIndex(0);\n      } else {\n        this.setActiveAssistant(this.assistantList[0]);\n      }\n    },\n  },\n});\n</script>\n\n\n<style lang=\"stylus\" rel=\"stylesheet/stylus\" scoped>\n@import '../assets/css/mixins.styl';\n\n.home {\n  width: 100%;\n  height: 100%;\n  position: relative;\n}\n\n.left-panel {\n  width: 320px;\n  height: 
100%;\n  background-color: #363433;\n  padding: 30px 30px;\n  .logo-box {\n    .logo {\n      .img {\n        width: 36px;\n        height: 36px;\n      }\n\n      .text {\n        font-size: 28px;\n        font-weight: bold;\n        margin-left: 10px;\n        color: #edf2ea;\n      }\n    }\n\n    .version {\n      text-align: right;\n      font-size: 14px;\n      color: #bdbdbd;\n    }\n  }\n\n  .divider {\n    border-bottom: 1px solid #D7D7D7;\n    width: 30%;\n    margin: 30px auto;\n  }\n\n  .lang-box {\n    position: relative;\n    width: 100%;\n    height: 30px;\n    margin: auto;\n    margin-bottom: 10px;\n\n    .el-dropdown {\n      font-size: 14px;\n      position: absolute;\n      top: 50%;\n      left: 50%;\n      transform: translate(-50%, -50%);\n    }\n  }\n\n  .assistant-box {\n    .assistant-list {\n      min-height: 50px;\n      max-height: 300px;\n      overflow: hidden;\n      position: relative;\n\n      ul > li.assistant-item {\n        padding: 8px 15px;\n        color: #edf2ea;\n\n        img {\n          width: 32px;\n          height: 32px;\n        }\n\n        .name {\n          margin-left: 12px;\n          font-size: 14px;\n          color: #edf2ea;\n        }\n\n        i.iconfont {\n          display: none;\n          margin-left: 10px;\n        }\n\n        &:hover {\n          background-color: $bg_gray_light_hover;\n          cursor: pointer;\n          border-radius: 4px;\n\n          .name {\n            color: #313433;\n          }\n\n          i.iconfont {\n            display: block;\n          }\n        }\n      }\n    }\n\n    .explore {\n      position: relative;\n      justify-content: center;\n      display: flex;\n      margin-top: 10px;\n\n      .explore-btn {\n        margin: 0 auto;\n        padding: 0 20px;\n        justify-content: center;\n        height: 32px;\n        line-height: 32px;\n        background-color: #FFFFFF;\n        border: 1px solid RGBA(0, 0, 0, 0.15);\n        border-radius: 16px;\n\n        
i {\n          color: #8080FF;\n        }\n\n        .text {\n          color: #7F7F7F;\n          margin-left: 4px;\n        }\n\n        &:hover {\n          background-color: #FAFAFA;\n          cursor: pointer;\n        }\n      }\n    }\n  }\n\n  .history-box {\n    position: relative;\n\n    .date {\n      font-size: 14px;\n      color: #7F7F7F;\n      margin: 8px 0;\n\n      &:first-child {\n        margin-top: 0;\n      }\n    }\n\n    li.chat-item {\n      padding: 12px 15px;\n      cursor: pointer;\n      background-color: #edf2ea;\n      border-radius: 4px;\n      margin-bottom: 10px;\n      font-size: 16px;\n\n      .chat-abbr {\n        font-size: 14px;\n        color: #313433;\n        white-space: nowrap;\n        overflow: hidden;\n        text-overflow: ellipsis;\n      }\n\n      .chat-ops {\n        display: flex;\n        margin-top: 5px;\n\n        img {\n          width: 16px;\n          height: 16px;\n        }\n\n        .name {\n          font-size: 12px;\n          color: #898989;\n          margin-left: 8px;\n        }\n\n        i.iconfont {\n          color: $gray_60;\n        }\n      }\n\n      &:hover, &.active {\n        transition: 0.3s all;\n        cursor: pointer;\n        background-color: #a2a79f;\n        .chat-abbr {\n          color: black;\n        }\n\n        .name, i.iconfont {\n          color: black;\n        }\n      }\n    }\n  }\n\n  .icon-box {\n    width: 100%;\n    display: flex;\n    flex-direction: row;\n    justify-content: flex-end;\n    align-items: center;\n\n    .iconhub {\n      width: 32px;\n      height: 24px;\n      background: white;\n      font-size: 30px;\n      border: none;\n      overflow: hidden;\n      border-radius: 15%;\n      display: flex;\n      flex-direction: column;\n      justify-content: center;\n      align-items: center;\n      color: #898989;\n      transition: all 0.5s;\n      cursor: pointer;\n    }\n\n    .iconhub:hover {\n      background: #e5e5e5;\n      text-decoration: 
none;\n    }\n\n    .iconlanguage {\n      margin-left: 15px;\n      width: 32px;\n      height: 24px;\n      background: white;\n      font-size: 30px;\n      border: none;\n      overflow: hidden;\n      border-radius: 15%;\n      display: flex;\n      flex-direction: column;\n      justify-content: center;\n      align-items: center;\n      color: #898989;\n      transition: all 0.5s;\n      cursor: pointer;\n    }\n\n    .iconlanguage:hover {\n      background: #e5e5e5;\n      text-decoration: none;\n    }\n  }\n}\n\nul {\n  list-style: none;\n}\n\n.example-2 {\n  display: flex;\n  justify-content: center;\n  align-items: center;\n}\n\n.example-2 .icon-content {\n  margin: 0 10px;\n  position: relative;\n}\n\n.example-2 .icon-content .tooltip {\n  position: absolute;\n  top: -30px;\n  left: 50%;\n  transform: translateX(-50%);\n  color: #fff;\n  padding: 6px 10px;\n  border-radius: 5px;\n  opacity: 0;\n  visibility: hidden;\n  font-size: 14px;\n  transition: all 0.3s ease;\n}\n\n.example-2 .icon-content:hover .tooltip {\n  opacity: 1;\n  visibility: visible;\n  top: -50px;\n}\n\n.main-panel {\n  height: 100%;\n  background-color: #f1f0ed;\n}\n</style>\n"
  },
  {
    "path": "kt-sft/ktransformers/website/tests/unit/example.spec.ts",
    "content": "import { shallowMount } from '@vue/test-utils'\nimport HelloWorld from '@/components/HelloWorld.vue'\n\ndescribe('HelloWorld.vue', () => {\n  it('renders props.msg when passed', () => {\n    const msg = 'new message'\n    const wrapper = shallowMount(HelloWorld, {\n      props: { msg }\n    })\n    expect(wrapper.text()).toMatch(msg)\n  })\n})\n"
  },
  {
    "path": "kt-sft/ktransformers/website/tsconfig.json",
    "content": "{\n  \"compilerOptions\": {\n    \"target\": \"es5\",\n    \"module\": \"esnext\",\n    \"strict\": true,\n    \"jsx\": \"preserve\",\n    \"importHelpers\": true,\n    \"moduleResolution\": \"node\",\n    \"skipLibCheck\": true,\n    \"esModuleInterop\": true,\n    \"allowSyntheticDefaultImports\": true,\n    \"forceConsistentCasingInFileNames\": true,\n    \"useDefineForClassFields\": true,\n    \"sourceMap\": true,\n    \"allowJs\": true,\n    \"baseUrl\": \".\",\n    \"types\": [\n      \"webpack-env\",\n      \"jest\"\n    ],\n    \"paths\": {\n      \"@/*\": [\n        \"src/*\"\n      ]\n    },\n    \"lib\": [\n      \"esnext\",\n      \"dom\",\n      \"dom.iterable\",\n      \"scripthost\"\n    ]\n  },\n  \"include\": [\n    \"src/**/*.ts\",\n    \"src/**/*.tsx\",\n    \"src/**/*.vue\",\n    \"tests/**/*.ts\",\n    \"tests/**/*.tsx\",\n    \"config.d.ts\"\n  ],\n \n  \"exclude\": [\n    \"node_modules\"\n  ]\n}"
  },
  {
    "path": "kt-sft/ktransformers/website/vue.config.js",
    "content": "\nmodule.exports = {\n  // 配置 webpack-dev-server 行为。\n  devServer: {\n    open: false, // 编译后默认打开浏览器\n    host: '0.0.0.0',  // 域名\n    port: 8082,  // 端口\n    https: false,  // 是否https\n    proxy: {\n        '/api': {\n          target: 'http://localhost:9016/v1', // 你的后端服务器地址\n          changeOrigin: true, // 是否允许跨域\n          pathRewrite: {\n            '/api': '' // 将 '/api' 前缀替换为空，如果你的后端不需要这个前缀\n          }\n        }\n      }\n},\npublicPath: '/web/',  // 基本路径\noutputDir: 'dist', // 构建时的输出目录\nassetsDir: 'static', // 放置静态资源的目录\nindexPath: 'index.html', // html 的输出路径\nfilenameHashing: true, // 文件名哈希值\nlintOnSave: false, // 是否在保存的时候使用 `eslint-loader` 进行检查。\n\n// 组件是如何被渲染到页面中的？ （ast：抽象语法树；vDom：虚拟DOM）\n// template ---> ast ---> render ---> vDom ---> 真实的Dom ---> 页面\n// runtime-only：将template在打包的时候，就已经编译为render函数\n// runtime-compiler：在运行的时候才去编译template\nruntimeCompiler: false,\n\ntranspileDependencies: [], // babel-loader 默认会跳过 node_modules 依赖。\nproductionSourceMap: false, // 是否为生产环境构建生成 source map\n\n//调整内部的 webpack 配置\nconfigureWebpack: () => {},\n\nchainWebpack: () => {},\n  \n}"
  },
  {
    "path": "kt-sft/merge_tensors/merge_safetensor_gguf.py",
    "content": "# this script targets to merge the fp8 safe tensor and the gguf quantized tensors.\n\nimport os\n# insert the path of the project\nimport sys\n# sys.path.insert(0, \"/home/azure/ktransformers\")\nimport argparse\nimport torch\nfrom ktransformers.util.custom_loader import GGUFLoader, translate_name_to_gguf\nfrom safetensors import safe_open\nfrom safetensors.torch import save_file\nimport re\nfrom collections import defaultdict\n\ndef read_safetensor_keys_from_folder(folder_path)->dict:\n    \"\"\"    \n    :param folder_path: folder path\n    :return: key_to_file_map\n    \"\"\"\n    # check if the folder path is exist\n    if not os.path.exists(folder_path):\n        raise FileNotFoundError(f\"GGUF dir not found: {folder_path}\")\n    if os.path.isfile(folder_path):\n        folder_path = os.path.dirname(folder_path)\n    \n    key_to_file_map = {}\n\n    found_safetensor = False\n    for root, dirs, files in os.walk(folder_path):\n        # sort files\n        files = sorted(files)\n        for file in files:\n            if file.endswith(\".safetensors\"):\n                found_safetensor = True\n                file_path = os.path.join(root, file)\n                try:\n                    with safe_open(file_path, framework=\"pt\") as f:\n                        for key in f.keys():\n                            if \"model.layers.61\" in key:\n                                # skip MTP layer\n                                continue\n                            # try:\n                            #     if int(key.split('.')[2]) > 4:\n                            #         continue\n                            # except:\n                            #     pass\n                            key_to_file_map[key] = file_path\n                except Exception as e:\n                    print(f\"Error reading Safetensor file {file_path}: {e}\")\n    \n    if not found_safetensor:\n        raise FileNotFoundError(f\"No Safetensor files found in 
{folder_path}\")\n    \n    return key_to_file_map\n\ntensor_from_gguf = [] # todo: add keys in gguf that should be used in the final tensor\n\ndef translate_name(name:str)->str:\n    \"\"\"\n    :param name: name of the tensor\n    :return: translated name\n    \"\"\"\n    name = translate_name_to_gguf(name)\n    name = name.replace(\".up_proj.\", \".ffn_up_exps.\")\n    name = name.replace(\".down_proj.\", \".ffn_down_exps.\")\n    name = name.replace(\".gate_proj.\", \".ffn_gate_exps.\")\n    name = name.replace(\".ffn_gate_inp.e_score_correction_bias\", \".exp_probs_b.bias\") \n    return name\n    \n\ndef combine_tensor_sources(safetensor_path:str, gguf_path:str):\n    gguf_loader = GGUFLoader(gguf_path)\n    gguf_tensor_file_map = gguf_loader.tensor_file_map\n    safetensor_tensor_file_map = read_safetensor_keys_from_folder(safetensor_path)\n    \n    # build a map for the key to the tensor\n    # according to the key, we can get the tensor from the file\n    \n    target_tensor_map = {}\n    for key in safetensor_tensor_file_map.keys():\n        # for all experts, we use the gguf tensor\n        if \".mlp.experts.\" in key:\n            if '.weight_scale_inv' in key:\n                continue\n            key = '.'.join(key.split('.')[:5]+key.split('.')[-2:])\n            translated_key = translate_name(key)\n            target_tensor_map[key] = gguf_tensor_file_map[translated_key]\n            continue\n        \n        if any(target_key in key for target_key in tensor_from_gguf):\n            target_tensor_map[key] = gguf_tensor_file_map[translate_name(key)]\n        else:\n            target_tensor_map[key] = safetensor_tensor_file_map[key]\n    \n    return target_tensor_map, gguf_loader\n\ndef write_combined_tensor(target_tensor_map: dict, output_path: str, gguf_loader: GGUFLoader):\n    # Ensure output directory exists\n    os.makedirs(output_path, exist_ok=True)\n    \n    # Cache for safetensor file handles and GGUF loaders\n    safetensors_cache = 
{}\n    gguf_cache = {}\n    \n    # Group tensors by layer\n    layer_groups = defaultdict(list)\n    non_layer_keys = []\n    layer_pattern = re.compile(r'\\.layers\\.(\\d+)\\.')\n    \n    for key in target_tensor_map:\n        match = layer_pattern.search(key)\n        if match:\n            layer_num = int(match.group(1))\n            layer_groups[layer_num].append(key)\n        else:\n            non_layer_keys.append(key)\n    \n    # Calculate total shards\n    total_shards = len(layer_groups) + (1 if non_layer_keys else 0) - 1\n    if total_shards == 0:\n        raise ValueError(\"No tensors to save\")\n    \n    shard_idx = 0\n    \n    # Save non-layer tensors to the first shard if they exist\n    if non_layer_keys:\n        tensors = {}\n        for key in non_layer_keys:\n            file_path = target_tensor_map[key]\n            tensor = None\n            ggml_type = None\n            if file_path.endswith('.safetensors'):\n                if file_path not in safetensors_cache:\n                    safetensors_cache[file_path] = safe_open(file_path, framework='pt')\n                f = safetensors_cache[file_path]\n                tensor = f.get_tensor(key)\n            elif file_path.endswith('.gguf'):\n                gguf_name = translate_name(key)\n                tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name)\n            else:\n                raise ValueError(f\"Unsupported file format: {file_path}\")\n            tensors[translate_name(key)] = tensor\n            if ggml_type:\n                ggml_type = torch.tensor(ggml_type)\n                ggml_key = translate_name(key)[:-7] + \".ggml_type\" if translate_name(key).endswith(\".weight\") else translate_name(key) + \".ggml_type\"\n                tensors[ggml_key] = ggml_type\n        \n        output_file = os.path.join(output_path, f\"model-{shard_idx:05}-of-{total_shards:05}.safetensors\")\n        print(f\"Saving non-layer tensors to {output_file}\")\n  
      save_file(tensors, output_file)\n        print(tensors.keys())\n\n        shard_idx += 1\n    \n    # Save each layer's tensors to subsequent shards\n    for layer_num in sorted(layer_groups.keys()):\n        layer_keys = layer_groups[layer_num]\n        tensors = {}\n        for key in layer_keys:\n            file_path = target_tensor_map[key]\n            tensor = None\n            ggml_type = None\n            if file_path.endswith('.safetensors'):\n                if file_path not in safetensors_cache:\n                    safetensors_cache[file_path] = safe_open(file_path, framework='pt')\n                f = safetensors_cache[file_path]\n                tensor = f.get_tensor(key)\n                tensor_info = tensor.shape\n            elif file_path.endswith('.gguf'):\n                gguf_name = translate_name(key)\n                tensor, ggml_type = gguf_loader.get_undequanted_tensor_and_ggml_type(gguf_name)\n                # tensor_info = gguf_loader.tensor_info[gguf_name]\n                # ggml_type = gguf_loader.tensor_info[gguf_name]['ggml_type']\n            else:\n                raise ValueError(f\"Unsupported file format: {file_path}\")\n            tensors[translate_name(key)] = tensor\n            if ggml_type:\n                ggml_type = torch.tensor(ggml_type)\n                ggml_key = translate_name(key)[:-7] + \".ggml_type\" if translate_name(key).endswith(\".weight\") else translate_name(key) + \".ggml_type\"\n                tensors[ggml_key] = ggml_type\n        \n        output_file = os.path.join(output_path, f\"model-{shard_idx:05}-of-{total_shards:05}.safetensors\")\n        print(f\"Saving layer {layer_num} to {output_file}\")\n        # print(tensors.keys())\n        save_file(tensors, output_file)\n        shard_idx += 1\n    \n    return\n    \ndef main():\n    # 创建命令行参数解析器\n    parser = argparse.ArgumentParser(description=\"Read parameters from Safetensor and GGUF files\")\n    
parser.add_argument(\"--safetensor_path\", type=str, help=\"Path to the Safetensor file\", default=\"/mnt/data/model/DeepSeek-V3\")\n    parser.add_argument(\"--gguf_path\", type=str, help=\"Path to the GGUF file\", default=\"/mnt/data/model/DeepseekV3-q4km-gguf\")\n    parser.add_argument(\"--output_path\", type=str, help=\"Path to the output directory\", default=\"/mnt/data/model/ktrans-safetensors/DeepSeek-V3-q4km-fp8\")\n    \n    # Parse the command-line arguments once, then print them\n    args = parser.parse_args()\n    print(\"All the arguments:\")\n    print(args)\n\n    safetensor_path = args.safetensor_path\n    gguf_path = args.gguf_path\n    output_path = args.output_path\n    \n    target_tensor_map, gguf_loader = combine_tensor_sources(safetensor_path, gguf_path)\n    write_combined_tensor(target_tensor_map, output_path, gguf_loader)\n    \n    return\n\nif __name__ == \"__main__\":\n    main()"
  },
  {
    "path": "kt-sft/pyproject.toml",
    "content": "[build-system]\nrequires = [\n  \"setuptools\",\n  \"wheel\",\n  \"cmake >= 3.20\",\n  \"torch >= 2.3.0\", \n  \"ninja\",\n  \"packaging\",\n  \"cpufeature\"\n  ]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\n\nname = \"ktransformers\"\n\ndynamic = [\"version\"]\n\ndependencies = [\n  \"torch >= 2.3.0\",\n  \"transformers == 4.51.3\",\n  \"peft == 0.14.0\",\n  \"fastapi >= 0.111.0\",\n  \"uvicorn >= 0.30.1\",\n  \"langchain >= 0.2.0\",\n  \"blessed >= 1.20.0\",\n  \"accelerate >= 0.31.0\",\n  \"sentencepiece >= 0.1.97\",\n  \"setuptools\",\n  \"ninja\",\n  \"wheel\",\n  \"colorlog\",\n  \"build\",\n  \"fire\",\n  \"protobuf\",\n  \"datasets\",\n  \"torchviz\",\n]\n\nrequires-python = \">=3.10\"\n\nauthors = [\n  {name = \"KVCache.AI\", email = \"zhang.mingxing@outlook.com\"}\n]\n\nmaintainers = [\n  {name = \"james0zan\", email = \"zhang.mingxing@outlook.com\"},\n  {name = \"awake\", email = \"awake@approaching.ai\"},\n  {name = \"unicorn chan\", email = \"nl@approaching.ai\"}\n]\n\ndescription = \"KTransformers, pronounced as Quick Transformers, is designed to enhance your Transformers experience with advanced kernel optimizations and placement/parallelism strategies.\"\n\nreadme = \"README.md\"\nlicense = \"Apache-2.0\"\nlicense-files = [\"LICENSE\"]\n\nkeywords = [\"ktransformers\", \"llm\"]\n\nclassifiers = [\n  \"Development Status :: 4 - Beta\",\n  \"Programming Language :: Python :: 3.10\",\n  \"Programming Language :: Python :: 3.11\",\n  \"Programming Language :: Python :: 3.12\"\n]\n\n[project.urls]\nHomepage = \"https://kvcache.ai\"\nRepository = \"https://github.com/kvcache-ai/ktransformers.git\"\nIssues = \"https://github.com/kvcache-ai/ktransformers/issues\"\n\n\n[project.scripts]\nktransformers = \"ktransformers.server.main:main\"\n\n[tool.setuptools.packages.find]\nwhere = [\"./\", ]\ninclude = [\"ktransformers\",\"ktransformers.*\"]\n[tool.black]\nline-length = 120\npreview = true\nunstable = true\n"
  },
  {
    "path": "kt-sft/requirements-sft.txt",
    "content": "absl-py==2.3.1\naiohappyeyeballs==2.6.1\naiohttp==3.11.18\naiosignal==1.3.2\nattrs==25.3.0\ncolorama==0.4.6\nconda-pack==0.8.1\ndatasets==3.6.0\ndill==0.3.8\neinops==0.8.1\nfrozenlist==1.6.0\ngraphviz==0.20.3\njoblib==1.5.1\nmultidict==6.4.4\nmultiprocess==0.70.16\nnltk==3.9.1\nnvidia-cufile-cu12==1.11.1.6\npandas==2.2.3\npeft==0.14.0\npropcache==0.3.1\npyarrow==20.0.0\npython-dateutil==2.9.0.post0\npython_helper==0.3.74\npytz==2025.2\nrouge_score==0.1.2\nsix==1.17.0\ntabulate==0.9.0\nthop==0.1.1.post2209072238\ntorchviz==0.0.3\ntzdata==2025.2\nxxhash==3.5.0\nyarl==1.20.0\ntorchviz"
  },
  {
    "path": "kt-sft/setup.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n'''\nDescription  :\nAuthor       : chenxl\nDate         : 2024-07-27 16:15:27\nVersion      : 1.0.0\nLastEditors  : chenxl\nLastEditTime : 2024-08-14 16:36:19\nAdapted from:\nhttps://github.com/Dao-AILab/flash-attention/blob/v2.6.3/setup.py\nCopyright (c) 2023, Tri Dao.\nCopyright (c) 2024 by KVCache.AI, All Rights Reserved.\n'''\n\nimport os\nimport sys\nimport re\nimport ast\nfrom collections import deque\nimport subprocess\nimport select\nimport time\nimport platform\nimport shutil\nfrom typing import List, Optional, Literal\nimport http.client\nimport urllib.request\nimport urllib.error\nfrom pathlib import Path\nfrom packaging.version import parse\nimport torch\nimport torch.version\nfrom wheel.bdist_wheel import bdist_wheel as _bdist_wheel\nfrom setuptools import setup, Extension\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME\nfrom packaging.requirements import Requirement\ntry:\n    from torch_musa.utils.simple_porting import SimplePorting\n    from torch_musa.utils.musa_extension import BuildExtension, MUSAExtension, MUSA_HOME\nexcept ImportError:\n    MUSA_HOME=None\ntry:\n    import tomllib  # Py3.11+\nexcept Exception:\n    import tomli as tomllib  # 兼容老 Python\n\ndef _load_pyproject_deps():\n    with open(\"pyproject.toml\", \"rb\") as f:\n        data = tomllib.load(f)\n    return list(data.get(\"project\", {}).get(\"dependencies\", []) or [])\n\nKTRANSFORMERS_BUILD_XPU = torch.xpu.is_available()\n\n# 检测 DEV_BACKEND 环境变量\ndev_backend = os.environ.get(\"DEV_BACKEND\", \"\").lower()\nif dev_backend == \"xpu\":\n    triton_dep = [\n        \"pytorch-triton-xpu==3.3.0\"\n    ]\nelse:\n    triton_dep = []\n\nbase_deps = _load_pyproject_deps()\ncombined_deps = base_deps + triton_dep\n\n\ndef _strip_req(reqs, name: str):\n    out = []\n    for r in reqs:\n        try:\n            rn = Requirement(r).name.lower()\n        except Exception:\n            rn = 
r.split()[0].lower()\n        if rn != name.lower():\n            out.append(r)\n    return out\n\n_tver = parse(torch.__version__)\n_tlow = f\"{_tver.major}.{_tver.minor}\"\n_thigh = f\"{_tver.major}.{_tver.minor + 1}\"\nTORCH_RANGE = f\"torch>={_tlow},<{_thigh}\"\ninstall_requires_pinned = _strip_req(combined_deps, \"torch\") + [TORCH_RANGE]\n\nwith_balance = os.environ.get(\"USE_BALANCE_SERVE\", \"0\") == \"1\"\n\nclass CpuInstructInfo:\n    CPU_INSTRUCT = os.getenv(\"CPU_INSTRUCT\", \"NATIVE\")\n    FANCY = \"FANCY\"\n    AVX512 = \"AVX512\"\n    AVX2 = \"AVX2\"\n    CMAKE_NATIVE = \"-DLLAMA_NATIVE=ON\"\n    CMAKE_FANCY = \"-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON -DLLAMA_AVX512=ON -DLLAMA_AVX512_FANCY_SIMD=ON\"\n    CMAKE_AVX512 = \"-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON -DLLAMA_AVX512=ON\"\n    CMAKE_AVX2 = \"-DLLAMA_NATIVE=OFF -DLLAMA_FMA=ON -DLLAMA_F16C=ON -DLLAMA_AVX=ON -DLLAMA_AVX2=ON\"\n\nclass VersionInfo:\n    THIS_DIR = os.path.dirname(os.path.abspath(__file__))\n    PACKAGE_NAME = \"ktransformers\"\n    BASE_WHEEL_URL:str = (\n        \"https://github.com/kvcache-ai/ktransformers/releases/download/{tag_name}/{wheel_filename}\"\n    )\n    FORCE_BUILD = os.getenv(\"KTRANSFORMERS_FORCE_BUILD\", \"FALSE\") == \"TRUE\"\n\n    def get_musa_bare_metal_version(self, musa_dir):\n        raw_output = subprocess.run(\n            [musa_dir + \"/bin/mcc\", \"-v\"], check=True,\n            stdout=subprocess.PIPE, stderr=subprocess.STDOUT).stdout.decode(\"utf-8\")\n        output = raw_output.split()\n        release_idx = output.index(\"version\") + 1\n        bare_metal_version = parse(output[release_idx].split(\",\")[0])\n        musa_version = f\"{bare_metal_version.major}{bare_metal_version.minor}\"\n        return musa_version\n\n    def get_rocm_bare_metal_version(self, rocm_dir):\n        \"\"\"\n        Get the ROCm version from the ROCm installation directory.\n\n        
Args:\n            rocm_dir: Path to the ROCm installation directory\n\n        Returns:\n            A string representation of the ROCm version (e.g., \"63\" for ROCm 6.3)\n        \"\"\"\n        try:\n            # Try using rocm_agent_enumerator to get version info\n            raw_output = subprocess.check_output(\n                [rocm_dir + \"/bin/rocminfo\", \"--version\"],\n                universal_newlines=True,\n                stderr=subprocess.STDOUT)\n            # Extract version number from output\n            match = re.search(r'(\\d+\\.\\d+)', raw_output)\n            if match:\n                version_str = match.group(1)\n                version = parse(version_str)\n                rocm_version = f\"{version.major}{version.minor}\"\n                return rocm_version\n        except (subprocess.CalledProcessError, FileNotFoundError):\n            # If rocminfo --version fails, try alternative methods\n            pass\n\n        try:\n            # Try reading version from release file\n            with open(os.path.join(rocm_dir, \"share/doc/hip/version.txt\"), \"r\") as f:\n                version_str = f.read().strip()\n                version = parse(version_str)\n                rocm_version = f\"{version.major}{version.minor}\"\n                return rocm_version\n        except (FileNotFoundError, IOError):\n            pass\n\n        # If all else fails, try to extract from directory name\n        dir_name = os.path.basename(os.path.normpath(rocm_dir))\n        match = re.search(r'rocm-(\\d+\\.\\d+)', dir_name)\n        if match:\n            version_str = match.group(1)\n            version = parse(version_str)\n            rocm_version = f\"{version.major}{version.minor}\"\n            return rocm_version\n\n        # Fallback to extracting from hipcc version\n        try:\n            raw_output = subprocess.check_output(\n                [rocm_dir + \"/bin/hipcc\", \"--version\"],\n                universal_newlines=True,\n     
           stderr=subprocess.STDOUT)\n            match = re.search(r'HIP version: (\\d+\\.\\d+)', raw_output)\n            if match:\n                version_str = match.group(1)\n                version = parse(version_str)\n                rocm_version = f\"{version.major}{version.minor}\"\n                return rocm_version\n        except (subprocess.CalledProcessError, FileNotFoundError):\n            pass\n\n        # If we still can't determine the version, raise an error\n        raise ValueError(f\"Could not determine ROCm version from directory: {rocm_dir}\")\n\n    def get_cuda_bare_metal_version(self, cuda_dir):\n        raw_output = subprocess.check_output(\n            [cuda_dir + \"/bin/nvcc\", \"-V\"], universal_newlines=True)\n        output = raw_output.split()\n        release_idx = output.index(\"release\") + 1\n        bare_metal_version = parse(output[release_idx].split(\",\")[0])\n        cuda_version = f\"{bare_metal_version.major}{bare_metal_version.minor}\"\n        return cuda_version\n\n    def get_cuda_version_of_torch(self):\n        torch_cuda_version = parse(torch.version.cuda)\n        cuda_version = f\"{torch_cuda_version.major}{torch_cuda_version.minor}\"\n        return cuda_version\n\n    def get_platform(self,):\n        \"\"\"\n        Returns the platform name as used in wheel filenames.\n        \"\"\"\n        if sys.platform.startswith(\"linux\"):\n            return f'linux_{platform.uname().machine}'\n        elif sys.platform == \"win32\":\n            return \"win_amd64\"\n        else:\n            raise ValueError(\"Unsupported platform: {}\".format(sys.platform))\n\n    def get_cpu_instruct(self,):\n        if CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.FANCY:\n            return \"fancy\"\n        elif CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX512:\n            return \"avx512\"\n        elif CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX2:\n            return \"avx2\"\n        else:\n           
 print(\"Using native cpu instruct\")\n        if sys.platform.startswith(\"linux\"):\n            with open('/proc/cpuinfo', 'r', encoding=\"utf-8\") as cpu_f:\n                cpuinfo = cpu_f.read()\n            flags_line = [line for line in cpuinfo.split(\n                '\\n') if line.startswith('flags')][0]\n            flags = flags_line.split(':')[1].strip().split(' ')\n            # fancy with AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI\n            for flag in flags:\n                if 'avx512bw' in flag:\n                    return 'fancy'\n            for flag in flags:\n                if 'avx512' in flag:\n                    return 'avx512'\n            for flag in flags:\n                if 'avx2' in flag:\n                    return 'avx2'\n            raise ValueError(\n                \"Unsupported cpu Instructions: {}\".format(flags_line))\n        elif sys.platform == \"win32\":\n            from cpufeature.extension import CPUFeature\n\n            if CPUFeature.get(\"AVX512bw\", False):\n                return 'fancy'\n            if CPUFeature.get(\"AVX512f\", False):\n                return 'avx512'\n            if CPUFeature.get(\"AVX2\", False):\n                return 'avx2'\n            raise ValueError(\n                \"Unsupported cpu Instructions: {}\".format(str(CPUFeature)))\n        else:\n            raise ValueError(\"Unsupported platform: {}\".format(sys.platform))\n\n    def get_torch_version(self,):\n        torch_version_raw = parse(torch.__version__)\n        torch_version = f\"{torch_version_raw.major}{torch_version_raw.minor}\"\n        return torch_version\n\n    def get_flash_version(self,):\n        version_file = os.path.join(\n            Path(VersionInfo.THIS_DIR), VersionInfo.PACKAGE_NAME, \"__init__.py\")\n        with open(version_file, \"r\", encoding=\"utf-8\") as f:\n            version_match = re.search(\n                r\"^__version__\\s*=\\s*(.*)$\", f.read(), re.MULTILINE)\n        flash_version = 
ast.literal_eval(version_match.group(1))\n        return flash_version\n\n    def get_package_version(self, full_version=False):\n        flash_version = str(self.get_flash_version())\n        torch_version = self.get_torch_version()\n        cpu_instruct = self.get_cpu_instruct()\n        backend_version = \"\"\n        if CUDA_HOME is not None:\n            backend_version = f\"cu{self.get_cuda_version_of_torch()}\"\n        elif MUSA_HOME is not None:\n            backend_version = f\"mu{self.get_musa_bare_metal_version(MUSA_HOME)}\"\n        elif ROCM_HOME is not None:\n            backend_version = f\"rocm{self.get_rocm_bare_metal_version(ROCM_HOME)}\"\n        elif torch.xpu.is_available():\n            backend_version = \"xpu\"\n        else:\n            raise ValueError(\"Unsupported backend: CUDA_HOME, MUSA_HOME, and ROCM_HOME are not set and XPU is not available.\")\n        package_version = f\"{flash_version}+{backend_version}torch{torch_version}{cpu_instruct}\"\n        if full_version:\n            return package_version\n        if not VersionInfo.FORCE_BUILD:\n            return flash_version\n        return package_version\n\n\nclass BuildWheelsCommand(_bdist_wheel):\n    def get_wheel_name(self,):\n        version_info = VersionInfo()\n        package_version = version_info.get_package_version(full_version=True)\n        flash_version = version_info.get_flash_version()\n        python_version = f\"cp{sys.version_info.major}{sys.version_info.minor}\"\n        wheel_filename = f\"{VersionInfo.PACKAGE_NAME}-{package_version}-{python_version}-{python_version}-{version_info.get_platform()}.whl\"\n        wheel_url = VersionInfo.BASE_WHEEL_URL.format(tag_name=f\"v{flash_version}\", wheel_filename=wheel_filename)\n        return wheel_filename, wheel_url\n\n\n    def run(self):\n        if VersionInfo.FORCE_BUILD:\n            super().run()\n            return\n        wheel_filename, wheel_url = self.get_wheel_name()\n        print(\"Guessing wheel URL: \", 
wheel_url)\n        try:\n            urllib.request.urlretrieve(wheel_url, wheel_filename)\n            # Make the archive\n            # Lifted from the root wheel processing command\n            # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85\n            if not os.path.exists(self.dist_dir):\n                os.makedirs(self.dist_dir)\n\n            impl_tag, abi_tag, plat_tag = self.get_tag()\n            archive_basename = f\"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}\"\n\n            wheel_path = os.path.join(self.dist_dir, archive_basename + \".whl\")\n            print(\"Raw wheel path\", wheel_path)\n            shutil.move(wheel_filename, wheel_path)\n        except (urllib.error.HTTPError, urllib.error.URLError, http.client.RemoteDisconnected):\n            print(\"Precompiled wheel not found. Building from source...\")\n            # If the wheel could not be downloaded, build from source\n            super().run()\n\n\nANSI_ESCAPE = re.compile(\n    
r'\\033[@-Z\\\\-_\\[\\]P]|\\033\\[[0-?]*[ -/]*[@-~]|\\033][^\\007\\033]*\\007|[\\000-\\037]'\n)\n\ndef colored(text, color=None, bold=False):\n    fmt = []\n    if color== 'red':\n        fmt.append('31')\n    elif color == 'green':\n        fmt.append('32')\n    if bold:\n        fmt.append('1')\n\n    return f\"\\033[{';'.join(fmt)}m{text}\\033[0m\"\n\n\ndef split_line(text: str) -> List[str]:\n    \"\"\"Split text into lines based on terminal width.\"\"\"\n    term_width = shutil.get_terminal_size().columns or 80\n    if not text.strip():\n        return []\n    # Split by explicit newlines and wrap long lines\n    lines = []\n    for line in text.split('\\n'):\n        while len(line) > term_width:\n            lines.append(line[:term_width])\n            line = line[term_width:]\n        if line:\n            lines.append(line)\n    return lines\n\n\ndef run_command_with_live_tail(ext: str, command: List[str], output_lines: int = 20,\n                               refresh_rate: float = 0.1, cwd: Optional[str] = None):\n    \"\"\"\n    Execute a script-like command with real-time output of the last `output_lines` lines.\n\n    - during execution: displays the last `output_lines` lines of output in real-time.\n    - On success: Clears the displayed output.\n    - On failure: Prints the full command output.\n\n    Args:\n        ext (str): the name of the native extension currently building.\n        command (List[str]): The command to execute, as a list of arguments.\n        output_lines (int, optional): Number of terminal lines to display during live output. Defaults to 20.\n        refresh_rate (float, optional): Time in seconds between output refreshes. Defaults to 0.1.\n        cwd (Optional[str], optional): Working directory to run the command in. 
Defaults to current directory.\n    \"\"\"\n    # Dump all subprocess output without any buffering if stdout is not a terminal\n    if not sys.stdout.isatty():\n        return subprocess.run(command, cwd=cwd, check=True)\n    # Start time for elapsed time calculation\n    start = time.time()\n    # Buffer for all output\n    all_output = []\n    write_buffer = deque(maxlen=output_lines)\n    # Current number of lines from sub process displayed\n    current_lines = 0\n\n    # ANSI escape codes for terminal control\n    CLEAR_LINE = '\\033[K'\n    MOVE_UP = '\\033[1A'\n    SAVE_CURSOR = '\\0337'\n    RESTORE_CURSOR = '\\0338'\n    CLEAR_REMAINING = '\\033[J'\n\n    def write_progress(status: Literal['RUNNING', 'SUCCEED', 'FAILED'] = 'RUNNING',\n                       new_line: Optional[str] = None):\n        \"\"\"Update terminal display with latest output\"\"\"\n        nonlocal current_lines, process\n        sys.stdout.write(SAVE_CURSOR)\n        sys.stdout.write(MOVE_UP * current_lines)\n        banner = f\"ext={ext} pid={process.pid} status={status.upper()} elapsed=({time.time()-start:.2f}S)\\n\"\n        if status != 'FAILED':\n            banner = colored(banner, 'green', bold=True)\n        else:\n            banner = colored(banner, 'red', bold=True)\n        sys.stdout.write(CLEAR_LINE + banner)\n        if new_line is not None:\n            all_output.append(new_line)\n            write_buffer.extend(split_line(ANSI_ESCAPE.sub('', new_line).rstrip()))\n        elif status == 'RUNNING':\n            sys.stdout.write(RESTORE_CURSOR)\n            sys.stdout.flush()\n            return\n\n        sys.stdout.write(CLEAR_REMAINING)\n        if status == 'RUNNING':\n            current_lines = 1 + len(write_buffer)\n            for text in write_buffer:\n                sys.stdout.write(text + '\\n')\n        elif status == 'FAILED':\n            for text in all_output:\n                sys.stdout.write(text)\n        sys.stdout.flush()\n\n    # Start 
subprocess\n    sys.stdout.write(colored(f'ext={ext} command={\" \".join(str(c) for c in command)}\\n', bold=True))\n    sys.stdout.flush()\n    process = subprocess.Popen(\n        command,\n        stdout=subprocess.PIPE,\n        stderr=subprocess.STDOUT,\n        cwd=cwd,\n        text=True,\n        bufsize=1\n    )\n\n    try:\n        write_progress()\n        poll_obj = select.poll()\n        poll_obj.register(process.stdout, select.POLLIN)\n        while process.poll() is None:\n            poll_result = poll_obj.poll(refresh_rate * 1000)\n            if poll_result:\n                write_progress(new_line=process.stdout.readline())\n            else:\n                write_progress()\n\n        # Get any remaining output\n        while True:\n            line = process.stdout.readline()\n            if not line:\n                break\n            write_progress(new_line=line)\n    except BaseException:\n        process.terminate()\n        raise\n    finally:\n        exit_code = process.wait()\n        write_progress(status='SUCCEED' if exit_code == 0 else 'FAILED')\n    # Match the non-TTY path (check=True): fail the build when the command fails\n    if exit_code != 0:\n        raise subprocess.CalledProcessError(exit_code, command)\n\n\n# Convert distutils Windows platform specifiers to CMake -A arguments\nPLAT_TO_CMAKE = {\n    \"win32\": \"Win32\",\n    \"win-amd64\": \"x64\",\n    \"win-arm32\": \"ARM\",\n    \"win-arm64\": \"ARM64\",\n}\n\n\nclass CMakeExtension(Extension):\n    def __init__(self, name: str, sourcedir: str) -> None:\n        super().__init__(name, sources=[])\n        print(name, sourcedir)\n        self.sourcedir = sourcedir\n\ndef get_cmake_abi_args(cmake_args):\n    if torch.compiled_with_cxx11_abi():\n        cmake_args.append(\"-D_GLIBCXX_USE_CXX11_ABI=1\")\n    else:\n        cmake_args.append(\"-D_GLIBCXX_USE_CXX11_ABI=0\")\n    return cmake_args\n\nclass CMakeBuild(BuildExtension):\n\n    def build_extension(self, ext) -> None:\n        if not isinstance(ext, CMakeExtension):\n            super().build_extension(ext)\n            return\n        ext_fullpath = Path.cwd() / 
self.get_ext_fullpath(ext.name)\n        extdir = ext_fullpath.parent.resolve()\n\n        # Using this requires trailing slash for auto-detection & inclusion of\n        # auxiliary \"native\" libs\n\n        debug = int(os.environ.get(\"DEBUG\", 0)\n                    ) if self.debug is None else self.debug\n        cfg = \"Debug\" if debug else \"Release\"\n\n        # CMake lets you override the generator - we need to check this.\n        # Can be set with Conda-Build, for example.\n        cmake_generator = os.environ.get(\"CMAKE_GENERATOR\", \"\")\n\n        # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON\n        # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code\n        # from Python.\n        cmake_args = [\n            f\"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}\",\n            f\"-DPYTHON_EXECUTABLE={sys.executable}\",\n            f\"-DCMAKE_BUILD_TYPE={cfg}\",  # not used on MSVC, but no harm\n        ]\n\n        if CUDA_HOME is not None:\n            cmake_args += [\"-DKTRANSFORMERS_USE_CUDA=ON\"]\n        elif MUSA_HOME is not None:\n            cmake_args += [\"-DKTRANSFORMERS_USE_MUSA=ON\"]\n        elif ROCM_HOME is not None:\n            cmake_args += [\"-DKTRANSFORMERS_USE_ROCM=ON\"]\n        elif KTRANSFORMERS_BUILD_XPU:\n            cmake_args += [\"-DKTRANSFORMERS_USE_XPU=ON\", \"-DKTRANSFORMERS_USE_CUDA=OFF\"]\n        else:\n            raise ValueError(\"Unsupported backend: CUDA_HOME, MUSA_HOME, and ROCM_HOME are not set and XPU is not available.\")\n        \n        cmake_args = get_cmake_abi_args(cmake_args)\n        # log cmake_args\n        print(\"CMake args:\", cmake_args)\n\n        build_args = []\n        if \"CMAKE_ARGS\" in os.environ:\n            cmake_args += [\n                item for item in os.environ[\"CMAKE_ARGS\"].split(\" \") if item]\n\n        if CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.FANCY:\n            cpu_args = CpuInstructInfo.CMAKE_FANCY\n        elif 
CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX512:\n            cpu_args = CpuInstructInfo.CMAKE_AVX512\n        elif CpuInstructInfo.CPU_INSTRUCT == CpuInstructInfo.AVX2:\n            cpu_args = CpuInstructInfo.CMAKE_AVX2\n        else:\n            cpu_args = CpuInstructInfo.CMAKE_NATIVE\n\n        cmake_args += [\n            item for item in cpu_args.split(\" \") if item\n        ]\n        # In this example, we pass in the version to C++. You might not need to.\n        cmake_args += [\n            f\"-DEXAMPLE_VERSION_INFO={self.distribution.get_version()}\"]\n        if self.compiler.compiler_type != \"msvc\":\n            if not cmake_generator or cmake_generator == \"Ninja\":\n                pass\n                # try:\n                #     import ninja\n\n                #     ninja_executable_path = Path(ninja.BIN_DIR) / \"ninja\"\n                #     cmake_args += [\n                #         \"-GNinja\",\n                #         f\"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}\",\n                #     ]\n                # except ImportError:\n                #     pass\n\n        else:\n            # Single config generators are handled \"normally\"\n            single_config = any(\n                x in cmake_generator for x in {\"NMake\", \"Ninja\"})\n\n            # CMake allows an arch-in-generator style for backward compatibility\n            contains_arch = any(x in cmake_generator for x in {\"ARM\", \"Win64\"})\n            if not single_config and not contains_arch and cmake_generator:\n                cmake_args += [\"-A\", PLAT_TO_CMAKE[self.plat_name]]\n\n            # Multi-config generators have a different way to specify configs\n            if not single_config:\n                cmake_args += [\n                    f\"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}\"\n                ]\n                build_args += [\"--config\", cfg]\n\n        if sys.platform.startswith(\"darwin\"):\n            # 
Cross-compile support for macOS - respect ARCHFLAGS if set\n            archs = re.findall(r\"-arch (\\S+)\", os.environ.get(\"ARCHFLAGS\", \"\"))\n            if archs:\n                cmake_args += [\n                    \"-DCMAKE_OSX_ARCHITECTURES={}\".format(\";\".join(archs))]\n\n        if \"CMAKE_BUILD_PARALLEL_LEVEL\" not in os.environ:\n            cpu_count = os.cpu_count()\n            if cpu_count is None:\n                cpu_count = 1\n            if hasattr(self, \"parallel\") and self.parallel:\n                build_args += [f\"--parallel={self.parallel}\"]\n            else:\n                build_args += [f\"--parallel={cpu_count}\"]\n        print(\"CMake args:\", cmake_args)\n        build_temp = Path(ext.sourcedir) / \"build\"\n        print(\"build_temp:\", build_temp)\n\n        if not build_temp.exists():\n            build_temp.mkdir(parents=True)\n        run_command_with_live_tail(ext.name,\n            [\"cmake\", ext.sourcedir, *cmake_args], cwd=build_temp\n        )\n        run_command_with_live_tail(ext.name,\n            [\"cmake\", \"--build\", build_temp, \"--verbose\", *build_args], cwd=build_temp\n        )\n\nif CUDA_HOME is not None or ROCM_HOME is not None:\n    ops_module = CUDAExtension('KTransformersOps', [\n        'csrc/ktransformers_ext/cuda/custom_gguf/dequant.cu',\n        'csrc/ktransformers_ext/cuda/binding.cpp',\n        'csrc/ktransformers_ext/cuda/gptq_marlin/gptq_marlin.cu'\n    ],\n    extra_compile_args={\n            'cxx': ['-O3', '-DKTRANSFORMERS_USE_CUDA'],\n            'nvcc': [\n                '-O3',\n                # '--use_fast_math',\n                '-Xcompiler', '-fPIC',\n                '-DKTRANSFORMERS_USE_CUDA',\n            ]\n        }\n    )\nelif MUSA_HOME is not None:\n    SimplePorting(cuda_dir_path=\"csrc/ktransformers_ext/cuda\", mapping_rule={\n        # Common rules\n        \"at::cuda\": \"at::musa\",\n        \"#include <ATen/cuda/CUDAContext.h>\": \"#include 
\\\"torch_musa/csrc/aten/musa/MUSAContext.h\\\"\",\n        \"#include <c10/cuda/CUDAGuard.h>\": \"#include \\\"torch_musa/csrc/core/MUSAGuard.h\\\"\",\n        \"nv_bfloat16\": \"mt_bfloat16\",\n        }).run()\n    ops_module = MUSAExtension('KTransformersOps', [\n        'csrc/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu',\n        'csrc/ktransformers_ext/cuda_musa/binding.cpp',\n        # TODO: Add Marlin support for MUSA.\n        # 'csrc/ktransformers_ext/cuda_musa/gptq_marlin/gptq_marlin.mu'\n    ],\n    extra_compile_args={\n            'cxx': ['force_mcc'],\n            'mcc': [\n                '-O3',\n                '-DKTRANSFORMERS_USE_MUSA',\n                '-DTHRUST_IGNORE_CUB_VERSION_CHECK',\n            ]\n        }\n    )\nelif torch.xpu.is_available(): #XPUExtension is not available now.\n    ops_module = None\nelse:\n    raise ValueError(\"Unsupported backend: CUDA_HOME ROCM_HOME MUSA_HOME are not set and XPU is not available.\")\n\nif not torch.xpu.is_available():\n    ext_modules = [\n        CMakeExtension(\"cpuinfer_ext\", os.fspath(Path(\"\").resolve() / \"csrc\" / \"ktransformers_ext\")),\n        ops_module,\n        CUDAExtension(\n            'vLLMMarlin', [\n                'csrc/custom_marlin/binding.cpp',\n                'csrc/custom_marlin/gptq_marlin/gptq_marlin.cu',\n                'csrc/custom_marlin/gptq_marlin/gptq_marlin_repack.cu',\n            ],\n            extra_compile_args={\n                'cxx': ['-O3'],\n                'nvcc': ['-O3', '-Xcompiler', '-fPIC'],\n            },\n        )\n    ]\n    if with_balance:\n        print(\"using balance_serve\")\n        ext_modules.append(\n            CMakeExtension(\"balance_serve\", os.fspath(Path(\"\").resolve()/ \"csrc\"/ \"balance_serve\"))\n        )\nelse:\n    ext_modules = [\n        CMakeExtension(\"cpuinfer_ext\", os.fspath(Path(\"\").resolve() / \"csrc\" / \"ktransformers_ext\")),\n    ]\n\nsetup(\n    name=VersionInfo.PACKAGE_NAME,\n    
version=VersionInfo().get_package_version(),\n    install_requires=install_requires_pinned,\n    cmdclass={\"bdist_wheel\":BuildWheelsCommand ,\"build_ext\": CMakeBuild},\n    ext_modules=ext_modules\n)\n"
  },
  {
    "path": "kt-sft/test_adapter/data_transfer.py",
    "content": "import json\n\nconverted_data = []\nwith open('/data/user23202791/lpl/LLaMA-Factory/examples/KT_used/translation.jsonl', 'r', encoding='utf-8') as f:\n    for line in f:\n        data = json.loads(line)\n        converted_data.append({\n            \"instruction\": \"\",\n            \"input\": data[\"问\"],\n            \"output\": data[\"答\"]\n        })\n\nwith open('/data/user23202791/lpl/LLaMA-Factory/examples/KT_used/sft_translation.json', 'w', encoding='utf-8') as f:\n    json.dump(converted_data, f, ensure_ascii=False, indent=4)"
  },
  {
    "path": "kt-sft/test_adapter/infer_with_adapter.py",
    "content": "import torch\nimport os\n\ncheckpoint_dir = \"/home/yj/ktransformers/test_adapter/demo_adapter_KT_target_module/checkpoint-6600\"  # Replace with the actual checkpoint directory\n\nfor filename in os.listdir(checkpoint_dir):\n    file_path = os.path.join(checkpoint_dir, filename)\n    if filename.endswith(('.pt', '.bin', '.pth')):\n        try:\n            loaded_data = torch.load(file_path)\n            print(f\"===== File: {filename} =====\")\n            print(f\"Data type: {type(loaded_data)}\")\n            \n            if isinstance(loaded_data, dict):\n                print(\"Dictionary keys:\", list(loaded_data.keys()))\n                # Example: print part of the optimizer state (if this is an optimizer file)\n                if \"state\" in loaded_data and \"param_groups\" in loaded_data:\n                    print(\"Sample optimizer parameters:\")\n                    print(\"First 2 param_groups:\", loaded_data[\"param_groups\"][:2])\n                    print(\"State of the first 2 parameters:\", list(loaded_data[\"state\"].items())[:2])\n            elif isinstance(loaded_data, torch.nn.Module):\n                print(\"Module parameters:\")\n                for name, param in loaded_data.named_parameters():\n                    print(f\"Name: {name}, shape: {param.shape}\")\n            else:\n                print(\"Data preview:\", loaded_data)\n        except Exception as e:\n            print(f\"Error reading {filename}: {str(e)}\")"
  },
  {
    "path": "kt-sft/test_adapter/inspect_adapter.py",
    "content": "# -*- coding: utf-8 -*-\n\"\"\"\ninspect_adapter.py  ‒  Inspect LoRA / Adapter checkpoint information\n------------------------------------------------------------\nExamples:\n  python inspect_adapter.py ./checkpoint\n  python inspect_adapter.py ./checkpoint --show-params            # print all weight rows\n  python inspect_adapter.py ./checkpoint --param lora_A.weight    # show a single weight\n  python inspect_adapter.py ./checkpoint --dump-all               # dump all tensors\n\"\"\"\nimport argparse\nimport json\nfrom pathlib import Path\n\nimport torch\nfrom safetensors.torch import load_file as safe_load\nfrom tabulate import tabulate\n\n\ndef load_json(p: Path):\n    with open(p, \"r\", encoding=\"utf-8\") as f:\n        return json.load(f)\n\n\ndef human_readable(num: int) -> str:\n    for unit in [\"\", \"K\", \"M\", \"B\"]:\n        if abs(num) < 1000:\n            return f\"{num:,.0f}{unit}\"\n        num /= 1000\n    return f\"{num:.1f}T\"\n\n\ndef inspect_adapter_weights(weight_path: Path):\n    \"\"\"\n    Load adapter_model.safetensors / .bin / .pt\n    Return a (rows, total_params, state) tuple\n    \"\"\"\n    if weight_path.suffix == \".safetensors\":\n        state = safe_load(str(weight_path))\n    else:\n        state = torch.load(str(weight_path), map_location=\"cpu\")\n\n    rows, total = [], 0\n    for name, tensor in state.items():\n        n = tensor.numel()\n        total += n\n        rows.append([\n            name,\n            list(tensor.shape),\n            str(tensor.dtype).replace(\"torch.\", \"\"),\n            human_readable(n)\n        ])\n    rows.sort(key=lambda x: x[0])\n    return rows, total, state\n\n\ndef maybe_print_optimizer(optimizer_pt: Path, max_keys: int = 20):\n    try:\n        opt_state = torch.load(str(optimizer_pt), map_location=\"cpu\")\n    except Exception as e:\n        print(f\"[optimizer.pt] failed to load: {e}\")\n        return\n    print(\"\\n====== optimizer.pt structure (partial) ======\")\n    if isinstance(opt_state, dict):\n        for i, k in 
enumerate(opt_state.keys()):\n            if i >= max_keys:\n                print(\"... (truncated)\")\n                break\n            print(f\"{k}: type={type(opt_state[k])}\")\n    else:\n        print(f\"type={type(opt_state)} is not a typical structure; please inspect it manually.\")\n\n\ndef maybe_print_scheduler(scheduler_pt: Path, max_keys: int = 20):\n    try:\n        sch_state = torch.load(str(scheduler_pt), map_location=\"cpu\")\n    except Exception as e:\n        print(f\"[scheduler.pt] failed to load: {e}\")\n        return\n    print(\"\\n====== scheduler.pt structure (partial) ======\")\n    if isinstance(sch_state, dict):\n        for i, (k, v) in enumerate(sch_state.items()):\n            if i >= max_keys:\n                print(\"... (truncated)\")\n                break\n            print(f\"{k}: type={type(v)}\")\n    else:\n        print(f\"type={type(sch_state)} is not a typical structure; please inspect it manually.\")\n\n\ndef maybe_print_rng(rng_pth: Path):\n    try:\n        rng = torch.load(str(rng_pth), map_location=\"cpu\")\n    except Exception as e:\n        print(f\"[rng_state.pth] failed to load: {e}\")\n        return\n    print(\"\\n====== rng_state.pth keys ======\")\n    if isinstance(rng, dict):\n        for k in rng.keys():\n            print(f\"- {k}\")\n    else:\n        print(f\"type={type(rng)} is not a typical structure; please inspect it manually.\")\n\n\ndef dump_tensors(state: dict, out_dir=\"tensor_dump\"):\n    \"\"\"\n    Write each tensor in state to a txt file (repr), optionally also saving a binary .pt\n    \"\"\"\n    out_dir = Path(out_dir)\n    out_dir.mkdir(exist_ok=True)\n    torch.set_printoptions(sci_mode=False, linewidth=180)\n\n    for name, tensor in state.items():\n        safe_name = name.replace(\"/\", \"_\")\n        txt_path = out_dir / f\"{safe_name}.txt\"\n        with open(txt_path, \"w\") as f:\n            f.write(repr(tensor))\n\n        # Uncomment the next line to also save a binary copy\n        # torch.save(tensor, out_dir / f\"{safe_name}.pt\")\n\n    print(f\"[done] wrote {len(state)} tensors to {out_dir}/\")\n\n\ndef main():\n    parser = argparse.ArgumentParser(\n        description=\"Inspect the contents of a LoRA / Adapter checkpoint\")\n    
parser.add_argument(\"ckpt_dir\", type=str,\n                        help=\"Directory containing adapter_config.json / adapter_model.safetensors\")\n    parser.add_argument(\"--show-params\", action=\"store_true\",\n                        help=\"Print the summary for all weights (only the first 30 rows are shown by default)\")\n    parser.add_argument(\"--param\", type=str,\n                        help=\"Print only the full tensor of the given parameter\")\n    parser.add_argument(\"--dump-all\", action=\"store_true\",\n                        help=\"Write every tensor in full to a folder\")\n    args = parser.parse_args()\n\n    d = Path(args.ckpt_dir).expanduser()\n    if not d.exists():\n        raise FileNotFoundError(d)\n\n    # ========== adapter_config.json ==========\n    cfg_path = d / \"adapter_config.json\"\n    if cfg_path.exists():\n        print(\"====== adapter_config.json ======\")\n        print(json.dumps(load_json(cfg_path), indent=2, ensure_ascii=False))\n    else:\n        print(\"adapter_config.json not found\")\n\n    # ========== trainer_state.json ==========\n    ts_path = d / \"trainer_state.json\"\n    if ts_path.exists():\n        ts = load_json(ts_path)\n        print(\"\\n====== trainer_state.json (excerpt) ======\")\n        sel = {k: ts.get(k, None) for k in\n               [\"global_step\", \"best_metric\", \"best_model_checkpoint\", \"log_history\"]}\n        if isinstance(sel.get(\"log_history\"), list) and len(sel[\"log_history\"]) > 3:\n            sel[\"log_history\"] = sel[\"log_history\"][-3:]\n        print(json.dumps(sel, indent=2, ensure_ascii=False))\n    else:\n        print(\"\\ntrainer_state.json not found\")\n\n    # ========== adapter_model.* ==========\n    st_path = next((d / n for n in\n                   [\"adapter_model.safetensors\", \"adapter_model.bin\", \"adapter_model.pt\"]\n                   if (d / n).exists()), None)\n\n    if st_path is None:\n        print(\"\\nadapter_model.* not found (safetensors/bin/pt)\")\n        state = {}\n    else:\n        rows, total, state = inspect_adapter_weights(st_path)\n\n        # If --param is given, print only that tensor\n        if 
args.param is not None:\n            if args.param not in state:\n                raise KeyError(f\"Parameter {args.param!r} does not exist!\")\n            torch.set_printoptions(sci_mode=False, linewidth=180, profile=\"full\")\n            print(f\"\\n====== Full tensor for {args.param} ======\")\n            print(state[args.param])\n            return  # exit early\n\n        print(f\"\\n====== Trainable parameters in {st_path.name} ({human_readable(total)} elements in total) ======\")\n        if args.show_params:\n            print(tabulate(rows, headers=[\"Name\", \"Shape\", \"dtype\", \"Elements\"], tablefmt=\"github\"))\n        else:\n            head = rows[:30]\n            print(tabulate(head, headers=[\"Name\", \"Shape\", \"dtype\", \"Elements\"], tablefmt=\"github\"))\n            if len(rows) > 30:\n                print(f\"... {len(rows) - 30} more parameters not shown; use --show-params to see all.\")\n\n        # With --dump-all, write every tensor to a file\n        if args.dump_all:\n            dump_tensors(state, out_dir=f\"{st_path.stem}_dump\")\n\n    # ========== Other state_dicts ==========\n    if (d / \"optimizer.pt\").exists():\n        maybe_print_optimizer(d / \"optimizer.pt\")\n    if (d / \"scheduler.pt\").exists():\n        maybe_print_scheduler(d / \"scheduler.pt\")\n    if (d / \"rng_state.pth\").exists():\n        maybe_print_rng(d / \"rng_state.pth\")\n\n    print(\"\\nDone.\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "kt-sft/test_adapter/pred2metrics.py",
    "content": "import json\nimport argparse\nfrom pathlib import Path\nfrom ktransformers.sft.metrics import ComputeSimilarity\nfrom transformers import AutoTokenizer\nfrom transformers.trainer_utils import EvalPrediction\n\ndef load_pred_ref(pred_file: Path):\n    data = json.loads(pred_file.read_text(encoding=\"utf-8\"))\n    preds, refs = [], []\n    for it in data:\n        preds.append(\"\" if it.get(\"prediction\") is None else str(it.get(\"prediction\")))\n        refs.append(\"\" if it.get(\"label\") is None else str(it.get(\"label\")))\n    return preds, refs\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--pred-file\", type=str, required=True)\n    parser.add_argument(\"--output-dir\", type=str, required=True)\n    parser.add_argument(\"--tokenizer\", type=str, required=True)\n    args = parser.parse_args()\n\n    pred_file = Path(args.pred_file)\n    output_dir = Path(args.output_dir)\n    output_dir.mkdir(parents=True, exist_ok=True)\n    metric_file = output_dir / \"metrics.json\"\n\n    preds, refs = load_pred_ref(pred_file)\n    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, use_fast=True)\n    compute_metrics = ComputeSimilarity(tokenizer)\n    enc_pred = tokenizer(preds, add_special_tokens=False, padding=True, return_tensors=\"np\")\n    enc_ref  = tokenizer(refs,  add_special_tokens=False, padding=True, return_tensors=\"np\")\n    ep = EvalPrediction(predictions=enc_pred[\"input_ids\"], label_ids=enc_ref[\"input_ids\"])\n    metrics = compute_metrics(ep, compute_result=True)\n\n    with metric_file.open(\"w\", encoding=\"utf-8\") as f:\n        json.dump(metrics, f, ensure_ascii=False, indent=2)\n\n    print(f\"[OK] sample length: {len(preds)}\")\n    print(f\"[OK] saved to: {metric_file}\")\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "kt-sft/test_adapter/test_grad.py",
    "content": "import torch, glob\n\nrecords = sorted(glob.glob(\"/home/lpl/kt-sft/tmp/train_logs/step_*.pt\"))\nexample = torch.load(records[1])\n\n# print(\"step:\", example[\"step\"])\n# print(\"inputs keys:\", list(example[\"inputs\"].keys()))\n# print(\"loss:\", example[\"loss\"])\n\n\n# print(\"param 'base_model.model.model.orig_module.layers.1.mlp.orig_module.gate.weight' 形状:\",\n#       example[\"params\"][\"base_model.model.model.orig_module.layers.1.mlp.orig_module.gate.weight\"].shape)\n# print(\"grad 'base_model.model.model.orig_module.layers.1.mlp.orig_module.gate.weight':\", example[\"grads\"][\"base_model.model.model.orig_module.layers.1.mlp.orig_module.gate.weight\"])\n\nprint(example)\n"
  },
  {
    "path": "kt-sft/test_adapter/time_test_lora_train.py",
    "content": "import torch\nimport torchvision.models as models\nfrom torch.profiler import profile, record_function, ProfilerActivity\n\nmodel = models.resnet18()\ninputs = torch.randn(5, 3, 224, 224)\n\n\nwith profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:\n    with record_function(\"model_inference\"):\n        model(inputs)\n\nprint(prof.key_averages().table(sort_by=\"cpu_time_total\", row_limit=10))\n# ---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  \n#                              Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  \n# ---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  \n#          aten::mkldnn_convolution        73.87%      37.241ms        74.04%      37.326ms       7.465ms       9.25 Mb           0 b             5  \n#                       aten::addmm        12.98%       6.545ms        13.11%       6.609ms       2.203ms     179.53 Kb     179.53 Kb             3  \n#     aten::max_pool2d_with_indices         6.63%       3.343ms         6.63%       3.343ms       1.114ms       5.05 Mb       5.05 Mb             3  \n#                   aten::clamp_min         2.12%       1.071ms         2.12%       1.071ms     153.000us           0 b           0 b             7  \n#                  aten::bernoulli_         1.20%     607.000us         1.23%     622.000us     311.000us           0 b    -260.00 Kb             2  \n# ---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  \n# Self CPU time total: 50.416ms\n\nprint(prof.key_averages(group_by_input_shape=True).table(sort_by=\"cpu_time_total\", row_limit=10))\n# ---------------------------------  ------------  
-------------------------------------------\n#                              Name     CPU total                                 Input Shapes\n# ---------------------------------  ------------  -------------------------------------------\n#                   model_inference      57.503ms                                           []\n#                      aten::conv2d       8.008ms     [[5,64,56,56], [64,64,3,3], [], ..., []]\n#                 aten::convolution       7.956ms     [[5,64,56,56], [64,64,3,3], [], ..., []]  # convolution stats\n#                aten::_convolution       7.909ms     [[5,64,56,56], [64,64,3,3], [], ..., []]\n#          aten::mkldnn_convolution       7.834ms     [[5,64,56,56], [64,64,3,3], [], ..., []]\n#                      aten::conv2d       6.332ms    [[5,512,7,7], [512,512,3,3], [], ..., []]\n#                 aten::convolution       6.303ms    [[5,512,7,7], [512,512,3,3], [], ..., []]  # convolution stats\n#                aten::_convolution       6.273ms    [[5,512,7,7], [512,512,3,3], [], ..., []]\n#          aten::mkldnn_convolution       6.233ms    [[5,512,7,7], [512,512,3,3], [], ..., []]\n#                      aten::conv2d       4.751ms  [[5,256,14,14], [256,256,3,3], [], ..., []]\n# ---------------------------------  ------------  -------------------------------------------\n# Self CPU time total: 57.549ms\n\nmodel = models.resnet18().cuda()\ninputs = torch.randn(5, 3, 224, 224).cuda()\n\nwith profile(activities=[\n        ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:\n    with record_function(\"model_inference\"):\n        model(inputs)\n\nprint(prof.key_averages().table(sort_by=\"cuda_time_total\", row_limit=10))\n\n# -------------------------------------------------------  ------------  ------------\n#                                                    Name     Self CUDA    CUDA total\n# -------------------------------------------------------  ------------  ------------\n#                                         
model_inference       0.000us      11.666ms\n#                                            aten::conv2d       0.000us      10.484ms\n#                                       aten::convolution       0.000us      10.484ms\n#                                      aten::_convolution       0.000us      10.484ms\n#                              aten::_convolution_nogroup       0.000us      10.484ms\n#                                       aten::thnn_conv2d       0.000us      10.484ms\n#                               aten::thnn_conv2d_forward      10.484ms      10.484ms\n# void at::native::im2col_kernel<float>(long, float co...       3.844ms       3.844ms\n#                                       sgemm_32x32x32_NN       3.206ms       3.206ms\n#                                   sgemm_32x32x32_NN_vec       3.093ms       3.093ms\n# -------------------------------------------------------  ------------  ------------\n# Self CPU time total: 23.015ms\n# Self CUDA time total: 11.666ms\n\nmodel = models.resnet18()\ninputs = torch.randn(5, 3, 224, 224)\n\nwith profile(activities=[ProfilerActivity.CPU],\n        profile_memory=True, record_shapes=True) as prof:\n    model(inputs)\n\nprint(prof.key_averages().table(sort_by=\"self_cpu_memory_usage\", row_limit=10)) # total memory used by the operator itself, excluding child operators\n\nprint(prof.key_averages().table(sort_by=\"cpu_memory_usage\", row_limit=10))\n\n\nmodel = models.resnet18().cuda()\ninputs = torch.randn(5, 3, 224, 224).cuda()\n\nwith profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:\n    model(inputs)\n\nprof.export_chrome_trace(\"trace.json\")\n\n\nwith profile(\n    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],\n    with_stack=True,\n) as prof:\n    model(inputs)\n\n# Print aggregated stats\nprint(prof.key_averages(group_by_stack_n=5).table(sort_by=\"self_cuda_time_total\", row_limit=2)) # enabling stack tracing adds extra overhead\n# -------------------------  -----------------------------------------------------------\n#                      Name  
Source Location\n# -------------------------  -----------------------------------------------------------\n# aten::thnn_conv2d_forward  .../torch/nn/modules/conv.py(439): _conv_forward\n#                            .../torch/nn/modules/conv.py(443): forward\n#                            .../torch/nn/modules/module.py(1051): _call_impl\n#                            .../site-packages/torchvision/models/resnet.py(63): forward\n#                            .../torch/nn/modules/module.py(1051): _call_impl\n# aten::thnn_conv2d_forward  .../torch/nn/modules/conv.py(439): _conv_forward\n#                            .../torch/nn/modules/conv.py(443): forward\n#                            .../torch/nn/modules/module.py(1051): _call_impl\n#                            .../site-packages/torchvision/models/resnet.py(59): forward\n#                            .../torch/nn/modules/module.py(1051): _call_impl\n# -------------------------  -----------------------------------------------------------\n# Self CPU time total: 34.016ms\n# Self CUDA time total: 11.659ms"
  },
  {
    "path": "kt-sft/withoutKT_PEFT.py",
    "content": "from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq\nfrom peft import get_peft_model, LoraConfig, inject_adapter_in_model, TaskType\nfrom transformers import Trainer, TrainingArguments\nfrom datasets import Dataset, load_dataset\nimport transformers\nfrom transformers.trainer import TRAINING_ARGS_NAME\nimport os\nimport torch\nfrom datasets import load_dataset, Dataset, DatasetDict\nfrom torch.utils.data import DataLoader\nfrom torchviz import make_dot\n\n# 加载 tokenizer 和模型\ntokenizer = AutoTokenizer.from_pretrained('/home/yj/ktransformers/DeepSeek-V2-Lite-Chat', trust_remote_code=True)\n# tokenizer = AutoTokenizer.from_pretrained('/data/model/Qwen2.5-7B-Instruct', trust_remote_code=True)\nsave_path = '/home/yj/ktransformers/tmp/Qwen_Lora_model'\ndata_file = '/home/yj/ktransformers/test_adapter/sft_translation.json'\n\ndataset = Dataset.from_json(data_file)\n\ndef preprocess_function(examples):\n    inputs = examples[\"input\"]\n    targets = examples[\"output\"]\n    \n    model_inputs = tokenizer(inputs, padding=\"max_length\", truncation=True, max_length=512)\n    labels = tokenizer(targets, padding=\"max_length\", truncation=True, max_length=512)\n    \n    model_inputs[\"labels\"] = labels[\"input_ids\"]\n    return model_inputs\n\ndef print_model_with_params(model, prefix=\"\", max_layers=3, max_params=5):\n    print(f\"\\n{prefix}模型结构:\")\n    print(model)  # 原始结构打印\n    \n    print(f\"\\n{prefix}参数示例:\")\n    total_params = 0\n    for name, param in model.named_parameters():\n        if total_params >= max_layers:  # 控制打印层数\n            break\n        # 过滤非LoRA相关参数（可根据需要调整）\n        if \"lora\" not in name and \"embed\" not in name and \"proj\" not in name:\n            continue\n        print(f\"层名: {name}\")\n        print(f\"形状: {param.shape}\")\n        print(f\"数据类型: {param.dtype}\")\n        print(f\"参数示例值 (前{max_params}个): {param.data.flatten()[:max_params].cpu().numpy()}\\n\")\n        total_params += 
1\n\nprocessed_dataset = dataset.map(preprocess_function, batched=True)\nsplit_dataset = processed_dataset.train_test_split(test_size=0.1)\n\ntrain_dataset = split_dataset[\"train\"]\nval_dataset = split_dataset[\"test\"]\n\ntrain_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)\nval_dataloader = DataLoader(val_dataset, batch_size=8)\n\nmodel = AutoModelForCausalLM.from_pretrained(\n    '/home/yj/ktransformers/DeepSeek-V2-Lite-Chat', \n    trust_remote_code=True,\n    torch_dtype=torch.float16)\n# model = AutoModelForCausalLM.from_pretrained('/data/model/Qwen2.5-7B-Instruct', trust_remote_code=True)\n\nprint_model_with_params(model, prefix=\"Original model\")\n\n# Configure LoRA\nlora_config = LoraConfig(\n        task_type=TaskType.CAUSAL_LM,\n        target_modules=[\n            # \"q_proj\"\n            \"kv_a_proj_with_mqa\",\n            \"kv_b_proj\",\n            # \"o_proj\"\n        ],\n        r=8,\n        lora_alpha=32,\n        lora_dropout=0.1,\n    )\n\nmodel = get_peft_model(model, lora_config)\n# model = inject_adapter_in_model(lora_config, model)\n\nfor name, parms in model.named_parameters():\t\n        print('-->name:', name)\n        print('-->para:', parms)\n        print('-->requires_grad:',parms.requires_grad)\n        print('-->grad_fn:',parms.grad_fn)\n        print('-->grad_value:',parms.grad)\n        print(\"===\")\n\n# print(model)\n\nmodel.train()\n\n# for name, parms in model.named_parameters():\t\n#         print('-->name:', name)\n#         print('-->para:', parms)\n#         print('-->requires_grad:',parms.requires_grad)\n#         print('-->grad_fn:',parms.grad_fn)\n#         print('-->grad_value:',parms.grad)\n#         print(\"===\")\n\nmodel.to(device='cuda')\nx = torch.tensor([[1,2,3]], dtype=torch.int32).to(\"cuda\")\noutput = model(x)\nloss = output.logits.mean()\nprint(f\"output:{output}\")\nprint(f\"loss:{loss}\")\n\n# output = model(input_ids=torch.tensor([[1,2,3]], dtype=torch.int32))\n# loss = output.logits.mean()\n# # 
print_grad_fn(loss.grad_fn)\n# # 生成计算图\ndot = make_dot(loss, params=dict(model.named_parameters()))\ndot.render(\"PEFT_compute_one_layer_model_graph\", format=\"svg\")  # 保存为SVG格式的文件\n\n# 暂时先不训练\n# model = model.to('cuda')\n# model.config.use_cache = False\n\n# # 定义训练参数\n# training_args = TrainingArguments(\n#     output_dir='./results',         # 模型保存和日志输出的目录路径\n#     num_train_epochs=3,             # 训练的总轮数（epochs）\n#     per_device_train_batch_size=1, # 每个设备（如GPU或CPU）上的训练批次大小，16表示每次输入模型的数据数量\n#     learning_rate=5e-5,             # 学习率\n#     logging_steps=10,               # 每隔多少步（steps）进行一次日志记录\n#     save_steps=100,                 # 每隔多少步保存模型\n#     save_total_limit=2,             # 保留最近的两个模型\n#     fp16=True,                   \n# )\n\nclass KTrainer(Trainer):\n    def save_model(self, output_dir=None, _internal_call=False):\n        # 改写trainer的save_model，在checkpoint的时候只存lora权重\n        os.makedirs(output_dir, exist_ok=True)\n        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))\n        saved_params = {\n            k: v.to(\"cpu\") for k, v in self.model.named_parameters() if v.requires_grad\n        }\n        torch.save(saved_params, os.path.join(output_dir, \"adapter_model.bin\"))\n\ntrainer = KTrainer(\n    model=model,\n    train_dataset=train_dataset,\n    args=transformers.TrainingArguments(\n        per_device_train_batch_size=8,\n        gradient_accumulation_steps=16,\n        num_train_epochs=10,\n        learning_rate=3e-4,\n        fp16=True,\n        logging_steps=10,\n        save_steps=200,\n        output_dir=save_path\n    ),\n    data_collator=DataCollatorForSeq2Seq(\n        tokenizer, pad_to_multiple_of=8, return_tensors=\"pt\", padding=True\n    ),\n)\n\ntrainer.train()\n# model.save_pretrained(save_path)\n\n# print_model_with_params(model, prefix=\"LoRA微调模型\")\n\n# model.print_trainable_parameters() \n\n# model = model.merge_and_unload()\n\n# print_model_with_params(model, prefix=\"合并后模型\")\n\nfor name, parms 
in model.named_parameters():\t\n        print('-->name:', name)\n        print('-->para:', parms)\n        print('-->requires_grad:',parms.requires_grad)\n        print('-->grad_fn:',parms.grad_fn)\n        print('-->grad_value:',parms.grad)\n        print(\"===\")"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools>=61\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"ktransformers\"\ndynamic = [\"version\", \"dependencies\"]\ndescription = \"KTransformers: CPU-GPU heterogeneous inference framework for LLMs\"\nreadme = \"README.md\"\nauthors = [{ name = \"kvcache-ai\" }]\nlicense = \"Apache-2.0\"\nrequires-python = \">=3.8\"\nclassifiers = [\n  \"Programming Language :: Python :: 3\",\n  \"Operating System :: POSIX :: Linux\",\n]\n\n[project.urls]\nHomepage = \"https://github.com/kvcache-ai/ktransformers\"\n\n[tool.setuptools]\n# No actual Python packages — this is a meta-package\npackages = []\n"
  },
  {
    "path": "setup.py",
    "content": "\"\"\"Meta-package: pip install ktransformers → installs kt-kernel + sglang-kt.\"\"\"\nfrom pathlib import Path\nfrom setuptools import setup\n\n_version_file = Path(__file__).resolve().parent / \"version.py\"\n_ns = {}\nexec(_version_file.read_text(), _ns)\n_v = _ns[\"__version__\"]\n\nsetup(\n    version=_v,\n    install_requires=[\n        f\"kt-kernel=={_v}\",\n        f\"sglang-kt=={_v}\",\n    ],\n)\n"
  },
  {
    "path": "third_party/llamafile/README.md",
    "content": "The code in this folder is copied from [Mozilla-Ocho/llamafile](https://github.com/Mozilla-Ocho/llamafile). Special thanks to the Mozilla-Ocho team.\n"
  },
  {
    "path": "third_party/llamafile/bench.h",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/bench.h\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-\n// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi\n#pragma once\n\n#include <stdio.h>\n\n#include \"micros.h\"\n\n#define BENCH(x)                                                                       \\\n    do {                                                                               \\\n        x;                                                                             \\\n        __asm__ volatile(\"\" ::: \"memory\");                                             \\\n        long long start = micros();                                                    \\\n        for (int i = 0; i < ITERATIONS; ++i) {                                         \\\n            __asm__ volatile(\"\" ::: \"memory\");                                         \\\n            x;                                                                         \\\n            __asm__ volatile(\"\" ::: \"memory\");                                         \\\n        }                                                                              \\\n        printf(\"%9lld us %s\\n\", (micros() - start + ITERATIONS - 1) / ITERATIONS, #x); \\\n    } while (0)\n"
  },
  {
    "path": "third_party/llamafile/flags.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/flags.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#include \"flags.h\"\n\nbool FLAG_precise = false;\n"
  },
  {
    "path": "third_party/llamafile/flags.h",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/flags.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#pragma once\n\nextern bool FLAG_precise;\n"
  },
  {
    "path": "third_party/llamafile/iqk_mul_mat.inc",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/iqk_mul_mat.inc\n// Copyrigth 2024 Iwan Kawrakow.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-\n// vi: set et ft=cpp fenc=utf-8 :vi\n//\n// Copyright 2024 Iwan Kawrakow\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include <cstring>\n#include <type_traits>\n#if defined __x86_64__ || defined __aarch64__ || defined(_M_X64)\n\n#include \"llama.cpp/ggml-impl.h\"\n#include \"llama.cpp/ggml-quants.h\"\n#include \"sgemm.h\"\n\n// For i-quants, I had to explicitely specify which\n// functions to inline / not inline (at least for some\n// of the functions), else performance would be significantly\n// lower. 
This is worrisome as things can change with,\n// e.g., a different compiler version or running on a different\n// CPU.\n#ifdef _MSC_VER\n#define IQK_NOINLINE __declspec(noinline)\n#define IQK_ALWAYS_INLINE inline\n#else\n#define IQK_NOINLINE __attribute__((__noinline__))\n#define IQK_ALWAYS_INLINE __attribute__((always_inline))\n#endif\n\n#define GGML_COMMON_IMPL_C\n#include \"llama.cpp/ggml-common.h\"\n\n// clang-format off\n\n// This matrix - vector and matrix - matrix multiplication implementation\n// for legacy quants, k-quants and i-quants makes prompt processing 150-200%\n// (legacy and k-quants) or 250-400% (i-quants) faster\n// compared to mainline llama.cpp (and llamafile).\n// It provides implementations for ARM_NEON (all quants) and AVX2\n// (all quants except sub-4 bit i-quants).\n//\n// Main idea is that unpacking the quants and the block scales to\n// be ready for dot products with the corresponding Q8_Y quants\n// takes time (here 'Y' stands for K, 0, or 1, depending on quantization type).\n// Hence, if we are performing a QX x Q8_Y matrix matrix\n// multiplication (as needed for prompt processing), we can get\n// a significant speedup by reusing the unpacked QX quants and scales\n// for multiplication with several Q8_K columns. 
We also achieve fewer\n// loads from memory, which is the main purpose of tiling in general\n// purpose matrix multiplication packages.\n\n#include <utility>\n#include <array>\n\n#endif\n\nconstexpr ggml_type GGML_TYPE_Q8_0_X4 = static_cast<ggml_type>(98);\nconstexpr ggml_type GGML_TYPE_Q8_1_X4 = static_cast<ggml_type>(99);\n\n\nnamespace {\n\ntypedef struct {\n    int32_t i1;\n    int32_t i2;\n} mmid_row_mapping;\n\nstruct DataInfo {\n    float       * s;\n    const char  * cy;\n    size_t        bs;\n    size_t        by;\n    int           cur_y = 0;\n    int           ne11;\n    const mmid_row_mapping * row_mapping = nullptr;\n    size_t        bs2 = 0;\n\n    inline const char * src1_row(int iy) const {\n        if (!row_mapping) return cy + (cur_y + iy)*by;\n        int i11 = row_mapping[cur_y + iy].i1 % ne11;\n        int i12 = row_mapping[cur_y + iy].i2;\n        return cy + (i11 + i12*ne11)*by;\n    }\n\n    inline void store(int ix, int iy, float result) const {\n        *(dst_row(iy) + ix) = result;\n        //dst_row(iy)[ix] = result;\n    }\n    inline float * dst_row(int iy) const {\n        if (!row_mapping) return s + (cur_y + iy)*bs;\n        int i12 = row_mapping[cur_y + iy].i2;\n        int i1  = row_mapping[cur_y + iy].i1;\n        int i2  = i12;\n        return s + i1*bs + i2*bs2;\n    }\n};\n\n/*\nmoonll \nchange param for set_mul_mat \nadd func16\n*/\n\ntypedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x);\n\nstruct MulMat {\n    std::array<mul_mat_t, 8> funcs = {};\n    mul_mat_t func16 = nullptr;\n    //inline void mul_mat_NxM(int n, const void * vx, size_t bx, DataInfo& info, int nrc_x, int nrc_y) {\n    IQK_NOINLINE void mul_mat_NxM(int n, const void * vx, size_t bx, DataInfo& info, int nrc_x, int nrc_y) {\n        constexpr int k_x_step = 64; // This works best on my Ryzen-7950X and M2 Max CPUs (but differences to other tile size are small)\n\n        if (func16 && nrc_y >= 16) {\n            int 
n_step = (nrc_y - info.cur_y)/16;\n            for (int ix = 0; ix < nrc_x; ix += k_x_step) {\n                auto this_info = info;\n                this_info.s += ix;\n                int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;\n                for (int iy = 0; iy < n_step; ++iy) {\n                    func16(n, (const void *)((const char *)vx + ix*bx), bx, this_info, this_nrc_x);\n                    this_info.cur_y += 16;\n                }\n            }\n            info.cur_y += 16 * n_step;\n            if (info.cur_y == nrc_y) return;\n        }\n\n        int n_step = (nrc_y - info.cur_y)/funcs.size();\n        if (n_step > 0) {\n            for (int ix = 0; ix < nrc_x; ix += k_x_step) {\n                auto this_info = info;\n                this_info.s += ix;\n                int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;\n                for (int iy = 0; iy < n_step; ++iy) {\n                    funcs.back()(n, (const void *)((const char *)vx + ix*bx), bx, this_info, this_nrc_x);\n                    this_info.cur_y += funcs.size();\n                }\n            }\n            info.cur_y += funcs.size() * n_step;\n        }\n        int n_left = nrc_y - info.cur_y;\n        if (n_left > 0) {\n            funcs[n_left-1](n, vx, bx, info, nrc_x);\n        }\n    }\n    static IQK_NOINLINE bool set_mul_mat(int typeA, int typeB,int ne00, MulMat& mm, int Ny);\nprivate:\n    template <typename Dequantizer> static IQK_NOINLINE void set_functions(MulMat& m);\n};\n\ninline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) {\n    const uint16_t * scales = (const uint16_t *)scales8;\n    const uint32_t a0 = scales[0] | (scales[1] << 16);\n    const uint32_t a1 = scales[2] | (scales[3] << 16);\n    const uint32_t a2 = scales[4] | (scales[5] << 16);\n    aux32[3] = ((a2 >> 4) & 0x0f0f0f0f) | ((a1 >> 2) & 0x30303030);\n    aux32[1] = ((a2 >> 0) & 0x0f0f0f0f) | ((a0 >> 2) & 0x30303030);\n    aux32[2] = a1 & 
0x3f3f3f3f;\n    aux32[0] = a0 & 0x3f3f3f3f;\n}\n\n/*\nmoonll\ndecoding tables\n*/\n#ifdef __AVX2__\nstatic const uint64_t iq1s_grid_us[2048] = {\n    0x0000000000000000, 0x0000000000000002, 0x0000000000000101, 0x0000000000000200,\n    0x0000000000000202, 0x0000000000010001, 0x0000000000010101, 0x0000000000020000,\n    0x0000000000020002, 0x0000000000020200, 0x0000000000020202, 0x0000000001000101,\n    0x0000000001010001, 0x0000000001010100, 0x0000000001010102, 0x0000000001020101,\n    0x0000000002000000, 0x0000000002000002, 0x0000000002000200, 0x0000000002000202,\n    0x0000000002010101, 0x0000000002020000, 0x0000000002020002, 0x0000000002020200,\n    0x0000000002020202, 0x0000000100000100, 0x0000000100000101, 0x0000000100010001,\n    0x0000000100010100, 0x0000000100010102, 0x0000000100010201, 0x0000000100010202,\n    0x0000000100020101, 0x0000000101000001, 0x0000000101000102, 0x0000000101000201,\n    0x0000000101010002, 0x0000000101010101, 0x0000000101010202, 0x0000000101020001,\n    0x0000000101020100, 0x0000000101020102, 0x0000000101020200, 0x0000000102000101,\n    0x0000000102010001, 0x0000000102010100, 0x0000000102010102, 0x0000000102020101,\n    0x0000000200000000, 0x0000000200000002, 0x0000000200000200, 0x0000000200000202,\n    0x0000000200010101, 0x0000000200020000, 0x0000000200020002, 0x0000000200020200,\n    0x0000000200020202, 0x0000000201000101, 0x0000000201010001, 0x0000000201010201,\n    0x0000000201020100, 0x0000000201020201, 0x0000000202000000, 0x0000000202000002,\n    0x0000000202000200, 0x0000000202000202, 0x0000000202010001, 0x0000000202010101,\n    0x0000000202010201, 0x0000000202020000, 0x0000000202020002, 0x0000000202020200,\n    0x0000000202020202, 0x0000010000010001, 0x0000010000010100, 0x0000010000010102,\n    0x0000010000020101, 0x0000010001000001, 0x0000010001000201, 0x0000010001010101,\n    0x0000010001010202, 0x0000010001020100, 0x0000010001020101, 0x0000010002010001,\n    0x0000010002010201, 0x0000010002020101, 0x0000010100000001, 
0x0000010100000100,\n    0x0000010100000101, 0x0000010100000102, 0x0000010100010101, 0x0000010100010200,\n    0x0000010100010202, 0x0000010100020201, 0x0000010101000000, 0x0000010101000101,\n    0x0000010101000202, 0x0000010101010000, 0x0000010101010001, 0x0000010101010100,\n    0x0000010101010101, 0x0000010101010102, 0x0000010101010201, 0x0000010101020000,\n    0x0000010101020002, 0x0000010101020101, 0x0000010101020200, 0x0000010101020202,\n    0x0000010102000001, 0x0000010102010001, 0x0000010102010101, 0x0000010102010200,\n    0x0000010102010202, 0x0000010102020001, 0x0000010102020100, 0x0000010102020101,\n    0x0000010102020102, 0x0000010102020201, 0x0000010200010100, 0x0000010200010201,\n    0x0000010201000001, 0x0000010201000100, 0x0000010201010000, 0x0000010201010002,\n    0x0000010201010101, 0x0000010201010200, 0x0000010201020000, 0x0000010201020001,\n    0x0000010201020102, 0x0000010201020201, 0x0000010202000101, 0x0000010202010001,\n    0x0000010202010100, 0x0000010202010201, 0x0000020000000000, 0x0000020000000002,\n    0x0000020000000200, 0x0000020000000202, 0x0000020000010101, 0x0000020000020000,\n    0x0000020000020002, 0x0000020000020200, 0x0000020000020202, 0x0000020001000101,\n    0x0000020001010001, 0x0000020001010102, 0x0000020001020101, 0x0000020002000000,\n    0x0000020002000002, 0x0000020002000200, 0x0000020002000202, 0x0000020002010101,\n    0x0000020002020000, 0x0000020002020002, 0x0000020002020200, 0x0000020002020202,\n    0x0000020100000101, 0x0000020100010001, 0x0000020100010100, 0x0000020100010201,\n    0x0000020100020100, 0x0000020100020101, 0x0000020101000001, 0x0000020101010000,\n    0x0000020101010001, 0x0000020101010101, 0x0000020101020001, 0x0000020101020100,\n    0x0000020101020201, 0x0000020102010001, 0x0000020102010100, 0x0000020102010102,\n    0x0000020102010201, 0x0000020102020101, 0x0000020200000000, 0x0000020200000002,\n    0x0000020200000200, 0x0000020200000202, 0x0000020200010101, 0x0000020200020000,\n    0x0000020200020002, 
0x0000020200020200, 0x0000020200020202, 0x0000020201000101,\n    0x0000020201010001, 0x0000020201010201, 0x0000020201020001, 0x0000020201020101,\n    0x0000020202000000, 0x0000020202000002, 0x0000020202000101, 0x0000020202000200,\n    0x0000020202000202, 0x0000020202010101, 0x0000020202020000, 0x0000020202020002,\n    0x0000020202020200, 0x0000020202020202, 0x0001000000010000, 0x0001000000010001,\n    0x0001000000010100, 0x0001000000010201, 0x0001000000020100, 0x0001000000020101,\n    0x0001000001000001, 0x0001000001000100, 0x0001000001010000, 0x0001000001010101,\n    0x0001000001010200, 0x0001000001020001, 0x0001000001020100, 0x0001000001020101,\n    0x0001000001020201, 0x0001000002010001, 0x0001000002010100, 0x0001000002010102,\n    0x0001000002020001, 0x0001000002020101, 0x0001000100000001, 0x0001000100000100,\n    0x0001000100000102, 0x0001000100000201, 0x0001000100010000, 0x0001000100010002,\n    0x0001000100010101, 0x0001000100010200, 0x0001000100020001, 0x0001000100020100,\n    0x0001000100020201, 0x0001000101000101, 0x0001000101000202, 0x0001000101010000,\n    0x0001000101010001, 0x0001000101010002, 0x0001000101010100, 0x0001000101010101,\n    0x0001000101010102, 0x0001000101010201, 0x0001000101020000, 0x0001000101020101,\n    0x0001000102000100, 0x0001000102010002, 0x0001000102010101, 0x0001000102020001,\n    0x0001000102020100, 0x0001000200010001, 0x0001000200010100, 0x0001000200010102,\n    0x0001000200020101, 0x0001000201000000, 0x0001000201000102, 0x0001000201000201,\n    0x0001000201010002, 0x0001000201010101, 0x0001000201010200, 0x0001000201010202,\n    0x0001000201020100, 0x0001000201020102, 0x0001000202000101, 0x0001000202010001,\n    0x0001000202010100, 0x0001000202010102, 0x0001000202020101, 0x0001010000000001,\n    0x0001010000000102, 0x0001010000000201, 0x0001010000010100, 0x0001010000010101,\n    0x0001010000010200, 0x0001010000010201, 0x0001010000020001, 0x0001010000020102,\n    0x0001010001000001, 0x0001010001000101, 0x0001010001000102, 
0x0001010001000200,\n    0x0001010001000202, 0x0001010001010001, 0x0001010001010100, 0x0001010001010101,\n    0x0001010001010102, 0x0001010001010201, 0x0001010001020002, 0x0001010001020101,\n    0x0001010001020200, 0x0001010002000100, 0x0001010002000201, 0x0001010002010000,\n    0x0001010002010100, 0x0001010002010101, 0x0001010002010200, 0x0001010002010201,\n    0x0001010002010202, 0x0001010002020001, 0x0001010002020100, 0x0001010002020101,\n    0x0001010002020201, 0x0001010100000002, 0x0001010100000101, 0x0001010100000202,\n    0x0001010100010001, 0x0001010100010100, 0x0001010100010101, 0x0001010100010102,\n    0x0001010100010201, 0x0001010100020000, 0x0001010100020002, 0x0001010100020101,\n    0x0001010100020200, 0x0001010100020202, 0x0001010101000001, 0x0001010101000100,\n    0x0001010101000101, 0x0001010101000102, 0x0001010101010001, 0x0001010101010002,\n    0x0001010101010100, 0x0001010101010101, 0x0001010101010102, 0x0001010101010201,\n    0x0001010101010202, 0x0001010101020001, 0x0001010101020100, 0x0001010101020101,\n    0x0001010101020102, 0x0001010101020201, 0x0001010102000000, 0x0001010102000002,\n    0x0001010102000100, 0x0001010102000101, 0x0001010102000200, 0x0001010102000202,\n    0x0001010102010000, 0x0001010102010001, 0x0001010102010100, 0x0001010102010101,\n    0x0001010102010102, 0x0001010102010201, 0x0001010102010202, 0x0001010102020000,\n    0x0001010102020002, 0x0001010102020101, 0x0001010200000001, 0x0001010200000100,\n    0x0001010200000101, 0x0001010200000102, 0x0001010200010101, 0x0001010200010102,\n    0x0001010200010200, 0x0001010200010202, 0x0001010200020001, 0x0001010200020102,\n    0x0001010201000000, 0x0001010201000002, 0x0001010201000100, 0x0001010201000101,\n    0x0001010201000200, 0x0001010201000202, 0x0001010201010001, 0x0001010201010101,\n    0x0001010201010102, 0x0001010201010200, 0x0001010201010201, 0x0001010201020001,\n    0x0001010201020100, 0x0001010201020101, 0x0001010201020200, 0x0001010201020201,\n    0x0001010201020202, 
0x0001010202000102, 0x0001010202000202, 0x0001010202010002,\n    0x0001010202010101, 0x0001010202020100, 0x0001010202020201, 0x0001020000010001,\n    0x0001020000010102, 0x0001020000020101, 0x0001020001000001, 0x0001020001000100,\n    0x0001020001000102, 0x0001020001000201, 0x0001020001010000, 0x0001020001010101,\n    0x0001020001010200, 0x0001020001010202, 0x0001020001020000, 0x0001020001020001,\n    0x0001020001020100, 0x0001020001020102, 0x0001020001020201, 0x0001020002000101,\n    0x0001020002010001, 0x0001020002010100, 0x0001020002020101, 0x0001020100010000,\n    0x0001020100010002, 0x0001020100010101, 0x0001020100010202, 0x0001020100020001,\n    0x0001020100020101, 0x0001020101000002, 0x0001020101000100, 0x0001020101000101,\n    0x0001020101000200, 0x0001020101010001, 0x0001020101010100, 0x0001020101010101,\n    0x0001020101010102, 0x0001020101010201, 0x0001020101010202, 0x0001020101020000,\n    0x0001020101020101, 0x0001020101020202, 0x0001020102000201, 0x0001020102010001,\n    0x0001020102010002, 0x0001020102010101, 0x0001020102010200, 0x0001020102020001,\n    0x0001020102020102, 0x0001020102020201, 0x0001020200000201, 0x0001020200010102,\n    0x0001020200020100, 0x0001020200020102, 0x0001020201000100, 0x0001020201000102,\n    0x0001020201000201, 0x0001020201010000, 0x0001020201010002, 0x0001020201010101,\n    0x0001020201010200, 0x0001020201020001, 0x0001020201020102, 0x0001020201020201,\n    0x0001020202000101, 0x0001020202010001, 0x0001020202010102, 0x0001020202010202,\n    0x0002000000000000, 0x0002000000000002, 0x0002000000000200, 0x0002000000000202,\n    0x0002000000010101, 0x0002000000020000, 0x0002000000020002, 0x0002000000020101,\n    0x0002000000020200, 0x0002000000020202, 0x0002000001000101, 0x0002000001010001,\n    0x0002000001010201, 0x0002000001020001, 0x0002000001020101, 0x0002000002000000,\n    0x0002000002000002, 0x0002000002000200, 0x0002000002000202, 0x0002000002010101,\n    0x0002000002020000, 0x0002000002020002, 0x0002000002020101, 
0x0002000002020200,\n    0x0002000002020202, 0x0002000100000101, 0x0002000100010001, 0x0002000100010100,\n    0x0002000100010201, 0x0002000100020101, 0x0002000101000002, 0x0002000101000100,\n    0x0002000101000201, 0x0002000101010101, 0x0002000101010200, 0x0002000101010202,\n    0x0002000101020001, 0x0002000101020100, 0x0002000101020101, 0x0002000101020102,\n    0x0002000102000101, 0x0002000102010000, 0x0002000102010102, 0x0002000102010201,\n    0x0002000102020101, 0x0002000200000001, 0x0002000200000200, 0x0002000200000202,\n    0x0002000200010001, 0x0002000200010101, 0x0002000200020000, 0x0002000200020002,\n    0x0002000200020200, 0x0002000200020202, 0x0002000201000101, 0x0002000201010001,\n    0x0002000201010102, 0x0002000201010201, 0x0002000201020101, 0x0002000202000001,\n    0x0002000202000200, 0x0002000202000202, 0x0002000202010001, 0x0002000202010101,\n    0x0002000202020000, 0x0002000202020002, 0x0002000202020200, 0x0002000202020202,\n    0x0002010000000101, 0x0002010000010100, 0x0002010000010102, 0x0002010000010201,\n    0x0002010000020101, 0x0002010001000100, 0x0002010001000101, 0x0002010001000102,\n    0x0002010001000201, 0x0002010001010002, 0x0002010001010101, 0x0002010001010200,\n    0x0002010001010202, 0x0002010001020102, 0x0002010002000101, 0x0002010002010001,\n    0x0002010002010100, 0x0002010002010201, 0x0002010002020001, 0x0002010002020101,\n    0x0002010100000201, 0x0002010100010101, 0x0002010100020001, 0x0002010100020201,\n    0x0002010101000000, 0x0002010101000101, 0x0002010101000200, 0x0002010101010001,\n    0x0002010101010100, 0x0002010101010101, 0x0002010101010201, 0x0002010101020002,\n    0x0002010101020101, 0x0002010101020200, 0x0002010102000201, 0x0002010102010000,\n    0x0002010102010100, 0x0002010102010101, 0x0002010102010200, 0x0002010102010202,\n    0x0002010102020001, 0x0002010102020100, 0x0002010102020102, 0x0002010102020201,\n    0x0002010200000101, 0x0002010200010000, 0x0002010200010002, 0x0002010200010201,\n    0x0002010200020101, 
0x0002010201000001, 0x0002010201000201, 0x0002010201010101,\n    0x0002010201020000, 0x0002010201020001, 0x0002010201020201, 0x0002010202000100,\n    0x0002010202000102, 0x0002010202010000, 0x0002010202010202, 0x0002020000000000,\n    0x0002020000000002, 0x0002020000000200, 0x0002020000000202, 0x0002020000010101,\n    0x0002020000020000, 0x0002020000020002, 0x0002020000020200, 0x0002020000020202,\n    0x0002020001000101, 0x0002020001010001, 0x0002020001010100, 0x0002020001020101,\n    0x0002020002000000, 0x0002020002000002, 0x0002020002000200, 0x0002020002000202,\n    0x0002020002020000, 0x0002020002020002, 0x0002020002020200, 0x0002020002020202,\n    0x0002020100000201, 0x0002020100010001, 0x0002020100010100, 0x0002020100010201,\n    0x0002020100020101, 0x0002020101000102, 0x0002020101000201, 0x0002020101010002,\n    0x0002020101010101, 0x0002020101020001, 0x0002020101020100, 0x0002020101020102,\n    0x0002020101020201, 0x0002020102000101, 0x0002020102010000, 0x0002020102010102,\n    0x0002020102010201, 0x0002020102020100, 0x0002020102020101, 0x0002020200000000,\n    0x0002020200000002, 0x0002020200000200, 0x0002020200000202, 0x0002020200020000,\n    0x0002020200020002, 0x0002020200020200, 0x0002020200020202, 0x0002020201000101,\n    0x0002020201010001, 0x0002020201010102, 0x0002020201010201, 0x0002020201020101,\n    0x0002020202000000, 0x0002020202000002, 0x0002020202000200, 0x0002020202000202,\n    0x0002020202010101, 0x0002020202020000, 0x0002020202020002, 0x0002020202020200,\n    0x0002020202020202, 0x0100000000000101, 0x0100000000010001, 0x0100000000010102,\n    0x0100000000020101, 0x0100000001000201, 0x0100000001010002, 0x0100000001010101,\n    0x0100000001010200, 0x0100000001010202, 0x0100000001020001, 0x0100000001020100,\n    0x0100000001020102, 0x0100000002010100, 0x0100000002010201, 0x0100000002020001,\n    0x0100000002020102, 0x0100000100000000, 0x0100000100000001, 0x0100000100000100,\n    0x0100000100000102, 0x0100000100000201, 0x0100000100010002, 
0x0100000100010101,\n    0x0100000100010102, 0x0100000100010200, 0x0100000100010202, 0x0100000100020001,\n    0x0100000100020102, 0x0100000100020201, 0x0100000101000101, 0x0100000101000200,\n    0x0100000101000202, 0x0100000101010001, 0x0100000101010100, 0x0100000101010101,\n    0x0100000101010102, 0x0100000101010201, 0x0100000101010202, 0x0100000101020101,\n    0x0100000101020200, 0x0100000101020202, 0x0100000102000001, 0x0100000102000100,\n    0x0100000102000102, 0x0100000102010000, 0x0100000102010002, 0x0100000102010101,\n    0x0100000102020000, 0x0100000102020001, 0x0100000102020002, 0x0100000200000101,\n    0x0100000200010001, 0x0100000200010100, 0x0100000200010102, 0x0100000200020101,\n    0x0100000201000001, 0x0100000201010002, 0x0100000201010101, 0x0100000201010202,\n    0x0100000201020100, 0x0100000201020201, 0x0100000202000201, 0x0100000202010100,\n    0x0100000202020101, 0x0100010000000001, 0x0100010000010101, 0x0100010000010201,\n    0x0100010000020201, 0x0100010001000101, 0x0100010001000200, 0x0100010001000202,\n    0x0100010001010001, 0x0100010001010100, 0x0100010001010101, 0x0100010001010102,\n    0x0100010001020001, 0x0100010001020002, 0x0100010001020101, 0x0100010001020200,\n    0x0100010001020202, 0x0100010002000001, 0x0100010002000102, 0x0100010002000201,\n    0x0100010002010000, 0x0100010002010002, 0x0100010002010101, 0x0100010002020000,\n    0x0100010002020001, 0x0100010002020201, 0x0100010100000001, 0x0100010100000002,\n    0x0100010100000101, 0x0100010100000202, 0x0100010100010001, 0x0100010100010100,\n    0x0100010100010101, 0x0100010100010102, 0x0100010100010201, 0x0100010100020000,\n    0x0100010100020101, 0x0100010100020202, 0x0100010101000001, 0x0100010101000100,\n    0x0100010101000101, 0x0100010101000102, 0x0100010101000201, 0x0100010101010000,\n    0x0100010101010001, 0x0100010101010100, 0x0100010101010101, 0x0100010101010102,\n    0x0100010101010200, 0x0100010101010201, 0x0100010101020001, 0x0100010101020100,\n    0x0100010101020101, 
0x0100010101020102, 0x0100010101020201, 0x0100010102000002,\n    0x0100010102000100, 0x0100010102000101, 0x0100010102000200, 0x0100010102010001,\n    0x0100010102010100, 0x0100010102010101, 0x0100010102010102, 0x0100010102010201,\n    0x0100010102010202, 0x0100010102020101, 0x0100010102020200, 0x0100010102020202,\n    0x0100010200000001, 0x0100010200000101, 0x0100010200000201, 0x0100010200010100,\n    0x0100010200010101, 0x0100010200010200, 0x0100010200010202, 0x0100010200020001,\n    0x0100010200020100, 0x0100010200020201, 0x0100010201000000, 0x0100010201000002,\n    0x0100010201000101, 0x0100010201000200, 0x0100010201010000, 0x0100010201010001,\n    0x0100010201010002, 0x0100010201010101, 0x0100010201010102, 0x0100010201010201,\n    0x0100010201020002, 0x0100010201020101, 0x0100010201020200, 0x0100010202000001,\n    0x0100010202000101, 0x0100010202000202, 0x0100010202010100, 0x0100010202010101,\n    0x0100010202020001, 0x0100010202020100, 0x0100010202020102, 0x0100020000000101,\n    0x0100020000010001, 0x0100020000010101, 0x0100020000010202, 0x0100020000020101,\n    0x0100020001000002, 0x0100020001000201, 0x0100020001010000, 0x0100020001010101,\n    0x0100020001010200, 0x0100020001020001, 0x0100020001020100, 0x0100020001020102,\n    0x0100020001020201, 0x0100020002000101, 0x0100020002010001, 0x0100020002010100,\n    0x0100020002010102, 0x0100020002010201, 0x0100020002020101, 0x0100020100000001,\n    0x0100020100000101, 0x0100020100000102, 0x0100020100000202, 0x0100020100010000,\n    0x0100020100010100, 0x0100020100010101, 0x0100020100010200, 0x0100020100020001,\n    0x0100020100020100, 0x0100020100020102, 0x0100020101000000, 0x0100020101000101,\n    0x0100020101000202, 0x0100020101010001, 0x0100020101010002, 0x0100020101010100,\n    0x0100020101010101, 0x0100020101010102, 0x0100020101010201, 0x0100020101020000,\n    0x0100020101020002, 0x0100020101020101, 0x0100020101020102, 0x0100020101020202,\n    0x0100020102000102, 0x0100020102000201, 0x0100020102010002, 
0x0100020102010101,\n    0x0100020102010102, 0x0100020102010200, 0x0100020102020001, 0x0100020102020100,\n    0x0100020102020102, 0x0100020102020201, 0x0100020200010102, 0x0100020201000100,\n    0x0100020201000102, 0x0100020201000201, 0x0100020201010101, 0x0100020201010200,\n    0x0100020201010202, 0x0100020201020100, 0x0100020201020201, 0x0100020202010100,\n    0x0100020202020101, 0x0101000000000001, 0x0101000000000100, 0x0101000000000101,\n    0x0101000000000102, 0x0101000000000201, 0x0101000000010002, 0x0101000000010101,\n    0x0101000000010202, 0x0101000000020001, 0x0101000000020100, 0x0101000000020201,\n    0x0101000001000000, 0x0101000001000101, 0x0101000001000200, 0x0101000001010001,\n    0x0101000001010100, 0x0101000001010101, 0x0101000001010102, 0x0101000001010201,\n    0x0101000001020101, 0x0101000001020200, 0x0101000002000102, 0x0101000002000201,\n    0x0101000002010101, 0x0101000002010200, 0x0101000002020000, 0x0101000002020001,\n    0x0101000002020102, 0x0101000002020201, 0x0101000100000101, 0x0101000100000200,\n    0x0101000100000201, 0x0101000100000202, 0x0101000100010001, 0x0101000100010100,\n    0x0101000100010101, 0x0101000100010102, 0x0101000100010200, 0x0101000100010201,\n    0x0101000100020000, 0x0101000100020101, 0x0101000100020102, 0x0101000100020200,\n    0x0101000100020202, 0x0101000101000001, 0x0101000101000100, 0x0101000101000101,\n    0x0101000101000102, 0x0101000101000201, 0x0101000101010000, 0x0101000101010001,\n    0x0101000101010002, 0x0101000101010100, 0x0101000101010101, 0x0101000101010102,\n    0x0101000101010200, 0x0101000101010201, 0x0101000101010202, 0x0101000101020001,\n    0x0101000101020100, 0x0101000101020101, 0x0101000101020102, 0x0101000101020201,\n    0x0101000102000002, 0x0101000102000101, 0x0101000102010001, 0x0101000102010100,\n    0x0101000102010101, 0x0101000102010102, 0x0101000102010201, 0x0101000102020000,\n    0x0101000102020101, 0x0101000102020202, 0x0101000200000001, 0x0101000200000102,\n    0x0101000200010002, 
0x0101000200010101, 0x0101000200010202, 0x0101000200020001,\n    0x0101000200020100, 0x0101000201000002, 0x0101000201000101, 0x0101000201000202,\n    0x0101000201010001, 0x0101000201010100, 0x0101000201010101, 0x0101000201010102,\n    0x0101000201010201, 0x0101000201020002, 0x0101000201020101, 0x0101000202000101,\n    0x0101000202010000, 0x0101000202010002, 0x0101000202010101, 0x0101000202010201,\n    0x0101000202010202, 0x0101000202020100, 0x0101010000000100, 0x0101010000000101,\n    0x0101010000010001, 0x0101010000010100, 0x0101010000010101, 0x0101010000010102,\n    0x0101010000010200, 0x0101010000010201, 0x0101010000020001, 0x0101010000020101,\n    0x0101010000020200, 0x0101010000020202, 0x0101010001000001, 0x0101010001000100,\n    0x0101010001000101, 0x0101010001000102, 0x0101010001000201, 0x0101010001000202,\n    0x0101010001010000, 0x0101010001010001, 0x0101010001010100, 0x0101010001010101,\n    0x0101010001010102, 0x0101010001010200, 0x0101010001010201, 0x0101010001010202,\n    0x0101010001020001, 0x0101010001020002, 0x0101010001020100, 0x0101010001020101,\n    0x0101010001020102, 0x0101010001020201, 0x0101010002000000, 0x0101010002000200,\n    0x0101010002000202, 0x0101010002010001, 0x0101010002010100, 0x0101010002010101,\n    0x0101010002010102, 0x0101010002010201, 0x0101010002020001, 0x0101010002020100,\n    0x0101010002020101, 0x0101010002020202, 0x0101010100000001, 0x0101010100000002,\n    0x0101010100000100, 0x0101010100000101, 0x0101010100000102, 0x0101010100000201,\n    0x0101010100010000, 0x0101010100010001, 0x0101010100010002, 0x0101010100010100,\n    0x0101010100010101, 0x0101010100010102, 0x0101010100010201, 0x0101010100010202,\n    0x0101010100020001, 0x0101010100020100, 0x0101010100020101, 0x0101010100020102,\n    0x0101010100020201, 0x0101010101000000, 0x0101010101000001, 0x0101010101000002,\n    0x0101010101000100, 0x0101010101000101, 0x0101010101000102, 0x0101010101000200,\n    0x0101010101000201, 0x0101010101010000, 0x0101010101010001, 
0x0101010101010002,\n    0x0101010101010100, 0x0101010101010101, 0x0101010101010102, 0x0101010101010200,\n    0x0101010101010201, 0x0101010101010202, 0x0101010101020000, 0x0101010101020001,\n    0x0101010101020100, 0x0101010101020101, 0x0101010101020102, 0x0101010101020200,\n    0x0101010101020201, 0x0101010101020202, 0x0101010102000001, 0x0101010102000100,\n    0x0101010102000101, 0x0101010102000201, 0x0101010102000202, 0x0101010102010000,\n    0x0101010102010001, 0x0101010102010100, 0x0101010102010101, 0x0101010102010102,\n    0x0101010102010200, 0x0101010102010201, 0x0101010102020001, 0x0101010102020100,\n    0x0101010102020101, 0x0101010102020102, 0x0101010102020201, 0x0101010200000000,\n    0x0101010200000001, 0x0101010200000002, 0x0101010200000100, 0x0101010200000102,\n    0x0101010200000200, 0x0101010200000201, 0x0101010200010001, 0x0101010200010100,\n    0x0101010200010101, 0x0101010200010200, 0x0101010200010201, 0x0101010200020000,\n    0x0101010200020001, 0x0101010200020002, 0x0101010200020100, 0x0101010200020101,\n    0x0101010200020102, 0x0101010200020200, 0x0101010200020201, 0x0101010201000001,\n    0x0101010201000101, 0x0101010201000102, 0x0101010201000200, 0x0101010201000201,\n    0x0101010201000202, 0x0101010201010000, 0x0101010201010001, 0x0101010201010002,\n    0x0101010201010100, 0x0101010201010101, 0x0101010201010102, 0x0101010201010200,\n    0x0101010201010201, 0x0101010201010202, 0x0101010201020001, 0x0101010201020100,\n    0x0101010201020101, 0x0101010201020201, 0x0101010202000002, 0x0101010202000101,\n    0x0101010202000102, 0x0101010202000200, 0x0101010202000201, 0x0101010202000202,\n    0x0101010202010001, 0x0101010202010101, 0x0101010202010202, 0x0101010202020002,\n    0x0101010202020101, 0x0101010202020102, 0x0101010202020200, 0x0101010202020201,\n    0x0101020000000100, 0x0101020000000101, 0x0101020000000102, 0x0101020000000201,\n    0x0101020000010000, 0x0101020000010101, 0x0101020000010200, 0x0101020000020001,\n    0x0101020000020202, 
0x0101020001000101, 0x0101020001000200, 0x0101020001000202,\n    0x0101020001010001, 0x0101020001010100, 0x0101020001010101, 0x0101020001010102,\n    0x0101020001010200, 0x0101020001010201, 0x0101020001020000, 0x0101020001020002,\n    0x0101020001020100, 0x0101020001020101, 0x0101020002000002, 0x0101020002000201,\n    0x0101020002010000, 0x0101020002010002, 0x0101020002010101, 0x0101020002010200,\n    0x0101020002020001, 0x0101020002020201, 0x0101020100000001, 0x0101020100000002,\n    0x0101020100000101, 0x0101020100000202, 0x0101020100010001, 0x0101020100010100,\n    0x0101020100010101, 0x0101020100010102, 0x0101020100010201, 0x0101020100020101,\n    0x0101020101000001, 0x0101020101000100, 0x0101020101000101, 0x0101020101000102,\n    0x0101020101000201, 0x0101020101010000, 0x0101020101010001, 0x0101020101010002,\n    0x0101020101010100, 0x0101020101010101, 0x0101020101010102, 0x0101020101010200,\n    0x0101020101010201, 0x0101020101010202, 0x0101020101020001, 0x0101020101020100,\n    0x0101020101020101, 0x0101020101020102, 0x0101020101020201, 0x0101020102000001,\n    0x0101020102000101, 0x0101020102000201, 0x0101020102010001, 0x0101020102010100,\n    0x0101020102010101, 0x0101020102010102, 0x0101020102010200, 0x0101020102010201,\n    0x0101020102020101, 0x0101020200000100, 0x0101020200000200, 0x0101020200010101,\n    0x0101020200010202, 0x0101020200020000, 0x0101020200020101, 0x0101020200020102,\n    0x0101020200020201, 0x0101020201000101, 0x0101020201000200, 0x0101020201000201,\n    0x0101020201010001, 0x0101020201010101, 0x0101020201010102, 0x0101020201010200,\n    0x0101020201010201, 0x0101020201020002, 0x0101020201020101, 0x0101020201020200,\n    0x0101020201020202, 0x0101020202000001, 0x0101020202000202, 0x0101020202010002,\n    0x0101020202010101, 0x0101020202010102, 0x0101020202010200, 0x0101020202010202,\n    0x0101020202020001, 0x0102000000000101, 0x0102000000010100, 0x0102000000010102,\n    0x0102000000010201, 0x0102000000020101, 0x0102000001000100, 
0x0102000001010000,\n    0x0102000001010101, 0x0102000001010102, 0x0102000001010200, 0x0102000001010202,\n    0x0102000001020001, 0x0102000001020100, 0x0102000001020102, 0x0102000001020201,\n    0x0102000002000001, 0x0102000002010102, 0x0102000002020101, 0x0102000100000001,\n    0x0102000100000100, 0x0102000100000102, 0x0102000100000201, 0x0102000100010002,\n    0x0102000100010101, 0x0102000100020001, 0x0102000100020002, 0x0102000100020102,\n    0x0102000100020201, 0x0102000101000101, 0x0102000101000201, 0x0102000101010001,\n    0x0102000101010101, 0x0102000101010102, 0x0102000101010201, 0x0102000101020101,\n    0x0102000101020102, 0x0102000101020202, 0x0102000102000100, 0x0102000102000202,\n    0x0102000102010002, 0x0102000102010101, 0x0102000102020001, 0x0102000102020102,\n    0x0102000102020201, 0x0102000200010001, 0x0102000200010102, 0x0102000200010201,\n    0x0102000201000000, 0x0102000201000001, 0x0102000201000102, 0x0102000201010101,\n    0x0102000201010102, 0x0102000201010200, 0x0102000201020000, 0x0102000202000101,\n    0x0102000202010001, 0x0102000202010102, 0x0102000202020101, 0x0102010000010001,\n    0x0102010000010002, 0x0102010000010101, 0x0102010000010102, 0x0102010000010202,\n    0x0102010000020001, 0x0102010000020102, 0x0102010000020201, 0x0102010001000000,\n    0x0102010001000002, 0x0102010001000101, 0x0102010001000200, 0x0102010001000202,\n    0x0102010001010001, 0x0102010001010100, 0x0102010001010101, 0x0102010001010102,\n    0x0102010001010201, 0x0102010001010202, 0x0102010001020000, 0x0102010001020002,\n    0x0102010001020101, 0x0102010002000100, 0x0102010002000101, 0x0102010002000201,\n    0x0102010002010000, 0x0102010002010002, 0x0102010002010100, 0x0102010002010101,\n    0x0102010002010102, 0x0102010002010200, 0x0102010002010202, 0x0102010002020001,\n    0x0102010002020100, 0x0102010002020201, 0x0102010100000101, 0x0102010100000200,\n    0x0102010100000202, 0x0102010100010001, 0x0102010100010101, 0x0102010100010102,\n    0x0102010100010201, 
0x0102010101000100, 0x0102010101000101, 0x0102010101000102,\n    0x0102010101000201, 0x0102010101010000, 0x0102010101010001, 0x0102010101010100,\n    0x0102010101010101, 0x0102010101010102, 0x0102010101010201, 0x0102010101020001,\n    0x0102010101020100, 0x0102010101020101, 0x0102010101020102, 0x0102010101020201,\n    0x0102010102000102, 0x0102010102000201, 0x0102010102000202, 0x0102010102010001,\n    0x0102010102010101, 0x0102010102010102, 0x0102010102010201, 0x0102010102010202,\n    0x0102010102020002, 0x0102010102020101, 0x0102010102020102, 0x0102010102020200,\n    0x0102010200000002, 0x0102010200000201, 0x0102010200010101, 0x0102010200020000,\n    0x0102010200020102, 0x0102010200020200, 0x0102010200020201, 0x0102010201000000,\n    0x0102010201000101, 0x0102010201000200, 0x0102010201000202, 0x0102010201010001,\n    0x0102010201010100, 0x0102010201010101, 0x0102010201010102, 0x0102010201010200,\n    0x0102010201010202, 0x0102010201020000, 0x0102010201020101, 0x0102010201020200,\n    0x0102010202000000, 0x0102010202000002, 0x0102010202000101, 0x0102010202000202,\n    0x0102010202010100, 0x0102010202010102, 0x0102010202010200, 0x0102010202010201,\n    0x0102010202020000, 0x0102010202020100, 0x0102010202020102, 0x0102010202020202,\n    0x0102020000010102, 0x0102020000010201, 0x0102020000020101, 0x0102020001000001,\n    0x0102020001010002, 0x0102020001010101, 0x0102020001010202, 0x0102020001020001,\n    0x0102020001020201, 0x0102020002000101, 0x0102020002010001, 0x0102020002010200,\n    0x0102020002020102, 0x0102020100000001, 0x0102020100000100, 0x0102020100010000,\n    0x0102020100010101, 0x0102020100020001, 0x0102020100020100, 0x0102020100020102,\n    0x0102020100020201, 0x0102020101000000, 0x0102020101000001, 0x0102020101000101,\n    0x0102020101000102, 0x0102020101000200, 0x0102020101010001, 0x0102020101010100,\n    0x0102020101010101, 0x0102020101010102, 0x0102020101010201, 0x0102020101020000,\n    0x0102020101020101, 0x0102020101020202, 0x0102020102000002, 
0x0102020102000100,\n    0x0102020102000202, 0x0102020102010101, 0x0102020102020001, 0x0102020102020100,\n    0x0102020102020101, 0x0102020102020201, 0x0102020200010001, 0x0102020200010102,\n    0x0102020200010200, 0x0102020201000001, 0x0102020201000100, 0x0102020201000201,\n    0x0102020201010000, 0x0102020201010101, 0x0102020201010200, 0x0102020201010202,\n    0x0102020201020100, 0x0102020201020101, 0x0102020201020201, 0x0102020202000102,\n    0x0102020202010100, 0x0102020202010200, 0x0102020202010202, 0x0102020202020102,\n    0x0200000000000000, 0x0200000000000002, 0x0200000000000200, 0x0200000000000202,\n    0x0200000000020000, 0x0200000000020002, 0x0200000000020200, 0x0200000000020202,\n    0x0200000001000101, 0x0200000001010000, 0x0200000001010001, 0x0200000001010100,\n    0x0200000001010102, 0x0200000001010201, 0x0200000001020101, 0x0200000002000000,\n    0x0200000002000002, 0x0200000002000200, 0x0200000002000202, 0x0200000002010101,\n    0x0200000002020000, 0x0200000002020002, 0x0200000002020200, 0x0200000002020202,\n    0x0200000100000101, 0x0200000100010001, 0x0200000100010100, 0x0200000100010102,\n    0x0200000100010201, 0x0200000100020101, 0x0200000101000001, 0x0200000101000100,\n    0x0200000101000201, 0x0200000101010000, 0x0200000101010002, 0x0200000101010101,\n    0x0200000101010102, 0x0200000101010200, 0x0200000101010201, 0x0200000101020100,\n    0x0200000101020102, 0x0200000101020201, 0x0200000102000101, 0x0200000102000201,\n    0x0200000102010100, 0x0200000102010102, 0x0200000102010201, 0x0200000102020101,\n    0x0200000200000000, 0x0200000200000002, 0x0200000200000200, 0x0200000200000202,\n    0x0200000200010101, 0x0200000200020000, 0x0200000200020002, 0x0200000200020200,\n    0x0200000200020202, 0x0200000201010001, 0x0200000201010100, 0x0200000201010201,\n    0x0200000201020101, 0x0200000202000000, 0x0200000202000002, 0x0200000202000200,\n    0x0200000202000202, 0x0200000202010101, 0x0200000202020000, 0x0200000202020002,\n    0x0200000202020200, 
0x0200000202020202, 0x0200010000010100, 0x0200010000010201,\n    0x0200010001000001, 0x0200010001000100, 0x0200010001010001, 0x0200010001010101,\n    0x0200010001010202, 0x0200010001020001, 0x0200010001020100, 0x0200010001020201,\n    0x0200010002010100, 0x0200010002010201, 0x0200010100000001, 0x0200010100000201,\n    0x0200010100010002, 0x0200010100010101, 0x0200010100010202, 0x0200010100020102,\n    0x0200010100020201, 0x0200010101000000, 0x0200010101000001, 0x0200010101000101,\n    0x0200010101000200, 0x0200010101010001, 0x0200010101010100, 0x0200010101010101,\n    0x0200010101010102, 0x0200010101010201, 0x0200010101010202, 0x0200010101020101,\n    0x0200010101020102, 0x0200010101020200, 0x0200010101020202, 0x0200010102000001,\n    0x0200010102000100, 0x0200010102000102, 0x0200010102000201, 0x0200010102010000,\n    0x0200010102010002, 0x0200010102010101, 0x0200010102010200, 0x0200010102020102,\n    0x0200010200010001, 0x0200010200010102, 0x0200010200010201, 0x0200010200020101,\n    0x0200010201000001, 0x0200010201000100, 0x0200010201000201, 0x0200010201000202,\n    0x0200010201010000, 0x0200010201010101, 0x0200010201010201, 0x0200010201010202,\n    0x0200010201020001, 0x0200010201020102, 0x0200010201020202, 0x0200010202000101,\n    0x0200010202010001, 0x0200010202010202, 0x0200010202020100, 0x0200020000000000,\n    0x0200020000000002, 0x0200020000000200, 0x0200020000000202, 0x0200020000010101,\n    0x0200020000020000, 0x0200020000020002, 0x0200020000020200, 0x0200020000020202,\n    0x0200020001000001, 0x0200020001000101, 0x0200020001010001, 0x0200020001010100,\n    0x0200020001010201, 0x0200020001020101, 0x0200020001020201, 0x0200020002000000,\n    0x0200020002000002, 0x0200020002000200, 0x0200020002000202, 0x0200020002010101,\n    0x0200020002020000, 0x0200020002020002, 0x0200020002020200, 0x0200020002020202,\n    0x0200020100000101, 0x0200020100000102, 0x0200020100010001, 0x0200020100010100,\n    0x0200020100010102, 0x0200020100020101, 0x0200020101000001, 
0x0200020101000100,\n    0x0200020101000102, 0x0200020101000201, 0x0200020101010000, 0x0200020101010002,\n    0x0200020101010101, 0x0200020101010202, 0x0200020101020001, 0x0200020101020100,\n    0x0200020102000101, 0x0200020102010102, 0x0200020102010201, 0x0200020102020101,\n    0x0200020200000000, 0x0200020200000002, 0x0200020200000200, 0x0200020200000202,\n    0x0200020200010101, 0x0200020200020000, 0x0200020200020002, 0x0200020200020200,\n    0x0200020200020202, 0x0200020201000101, 0x0200020201010001, 0x0200020201010100,\n    0x0200020201010102, 0x0200020202000000, 0x0200020202000002, 0x0200020202000200,\n    0x0200020202000202, 0x0200020202010101, 0x0200020202020000, 0x0200020202020002,\n    0x0200020202020200, 0x0200020202020202, 0x0201000000000101, 0x0201000000010001,\n    0x0201000000010102, 0x0201000000010200, 0x0201000000010201, 0x0201000000020101,\n    0x0201000001000001, 0x0201000001000102, 0x0201000001000201, 0x0201000001010101,\n    0x0201000001010200, 0x0201000001010202, 0x0201000001020201, 0x0201000001020202,\n    0x0201000002000101, 0x0201000002010001, 0x0201000002010100, 0x0201000002010102,\n    0x0201000002010201, 0x0201000002020101, 0x0201000100000001, 0x0201000100000100,\n    0x0201000100000102, 0x0201000100000201, 0x0201000100010000, 0x0201000100010101,\n    0x0201000100010200, 0x0201000100010202, 0x0201000100020001, 0x0201000100020100,\n    0x0201000100020102, 0x0201000100020201, 0x0201000101000000, 0x0201000101000101,\n    0x0201000101010000, 0x0201000101010001, 0x0201000101010100, 0x0201000101010101,\n    0x0201000101010102, 0x0201000101010201, 0x0201000101020002, 0x0201000101020101,\n    0x0201000102000100, 0x0201000102000102, 0x0201000102010002, 0x0201000102010101,\n    0x0201000102010200, 0x0201000102020001, 0x0201000102020100, 0x0201000102020102,\n    0x0201000102020201, 0x0201000200000101, 0x0201000200010001, 0x0201000200010100,\n    0x0201000200010201, 0x0201000200020101, 0x0201000201000100, 0x0201000201000102,\n    0x0201000201000201, 
0x0201000201010000, 0x0201000201010002, 0x0201000201010101,\n    0x0201000201010200, 0x0201000201020102, 0x0201000201020201, 0x0201000202000101,\n    0x0201000202010100, 0x0201000202010102, 0x0201000202020201, 0x0201010000000001,\n    0x0201010000000100, 0x0201010000000102, 0x0201010000010000, 0x0201010000010101,\n    0x0201010000010200, 0x0201010000020102, 0x0201010001000000, 0x0201010001000202,\n    0x0201010001010001, 0x0201010001010100, 0x0201010001010101, 0x0201010001010102,\n    0x0201010001010200, 0x0201010001010201, 0x0201010001020000, 0x0201010001020001,\n    0x0201010001020002, 0x0201010001020101, 0x0201010002000100, 0x0201010002000102,\n    0x0201010002010002, 0x0201010002010100, 0x0201010002010101, 0x0201010002010200,\n    0x0201010002020001, 0x0201010002020201, 0x0201010100000000, 0x0201010100000101,\n    0x0201010100000200, 0x0201010100000202, 0x0201010100010000, 0x0201010100010001,\n    0x0201010100010100, 0x0201010100010101, 0x0201010100010102, 0x0201010100010201,\n    0x0201010100020001, 0x0201010100020101, 0x0201010100020201, 0x0201010100020202,\n    0x0201010101000001, 0x0201010101000100, 0x0201010101000101, 0x0201010101000102,\n    0x0201010101000201, 0x0201010101010000, 0x0201010101010001, 0x0201010101010002,\n    0x0201010101010100, 0x0201010101010101, 0x0201010101010102, 0x0201010101010200,\n    0x0201010101010201, 0x0201010101010202, 0x0201010101020001, 0x0201010101020100,\n    0x0201010101020101, 0x0201010101020102, 0x0201010101020201, 0x0201010102000001,\n    0x0201010102000101, 0x0201010102000200, 0x0201010102010001, 0x0201010102010002,\n    0x0201010102010100, 0x0201010102010101, 0x0201010102010102, 0x0201010102010201,\n    0x0201010102010202, 0x0201010102020000, 0x0201010102020002, 0x0201010102020101,\n    0x0201010102020200, 0x0201010102020202, 0x0201010200000001, 0x0201010200000100,\n    0x0201010200010000, 0x0201010200010101, 0x0201010200010201, 0x0201010200020000,\n    0x0201010200020102, 0x0201010200020201, 0x0201010201000101, 
0x0201010201000200,\n    0x0201010201000201, 0x0201010201010001, 0x0201010201010002, 0x0201010201010101,\n    0x0201010201010102, 0x0201010201010201, 0x0201010201020101, 0x0201010201020200,\n    0x0201010202000002, 0x0201010202000100, 0x0201010202000201, 0x0201010202000202,\n    0x0201010202010002, 0x0201010202010100, 0x0201010202010101, 0x0201010202020100,\n    0x0201010202020102, 0x0201010202020201, 0x0201020000000101, 0x0201020000010102,\n    0x0201020000010201, 0x0201020000020101, 0x0201020001000001, 0x0201020001000102,\n    0x0201020001010000, 0x0201020001010002, 0x0201020001010101, 0x0201020001010102,\n    0x0201020001010202, 0x0201020001020100, 0x0201020001020101, 0x0201020002000101,\n    0x0201020002010001, 0x0201020002010102, 0x0201020002010201, 0x0201020002020101,\n    0x0201020100000100, 0x0201020100000102, 0x0201020100000201, 0x0201020100010000,\n    0x0201020100010002, 0x0201020100010101, 0x0201020100010200, 0x0201020100010202,\n    0x0201020100020000, 0x0201020100020001, 0x0201020100020100, 0x0201020100020102,\n    0x0201020101000000, 0x0201020101000002, 0x0201020101000101, 0x0201020101000200,\n    0x0201020101000202, 0x0201020101010001, 0x0201020101010100, 0x0201020101010101,\n    0x0201020101010102, 0x0201020101010201, 0x0201020101020002, 0x0201020101020101,\n    0x0201020101020102, 0x0201020101020202, 0x0201020102000001, 0x0201020102000100,\n    0x0201020102010000, 0x0201020102010002, 0x0201020102010101, 0x0201020102010202,\n    0x0201020102020001, 0x0201020102020102, 0x0201020200000101, 0x0201020200010101,\n    0x0201020200020101, 0x0201020201000100, 0x0201020201000102, 0x0201020201000201,\n    0x0201020201010000, 0x0201020201010101, 0x0201020201010200, 0x0201020201020001,\n    0x0201020202000101, 0x0201020202010001, 0x0201020202010100, 0x0201020202010101,\n    0x0201020202010102, 0x0202000000000000, 0x0202000000000002, 0x0202000000000200,\n    0x0202000000000202, 0x0202000000010101, 0x0202000000020000, 0x0202000000020002,\n    0x0202000000020200, 
0x0202000000020202, 0x0202000001000101, 0x0202000001010001,\n    0x0202000001010100, 0x0202000001010102, 0x0202000001010201, 0x0202000002000000,\n    0x0202000002000002, 0x0202000002000200, 0x0202000002000202, 0x0202000002010101,\n    0x0202000002020000, 0x0202000002020002, 0x0202000002020200, 0x0202000002020202,\n    0x0202000100000101, 0x0202000100000201, 0x0202000100010001, 0x0202000100010100,\n    0x0202000100010102, 0x0202000100010201, 0x0202000100010202, 0x0202000101000102,\n    0x0202000101000201, 0x0202000101010001, 0x0202000101010101, 0x0202000101010200,\n    0x0202000101010202, 0x0202000101020001, 0x0202000101020100, 0x0202000102000101,\n    0x0202000102010000, 0x0202000102010002, 0x0202000102010102, 0x0202000102010201,\n    0x0202000200000002, 0x0202000200000200, 0x0202000200000202, 0x0202000200010000,\n    0x0202000200010201, 0x0202000200020002, 0x0202000200020200, 0x0202000200020202,\n    0x0202000201000101, 0x0202000201010001, 0x0202000201010102, 0x0202000201010201,\n    0x0202000201020101, 0x0202000202000000, 0x0202000202000002, 0x0202000202000200,\n    0x0202000202000202, 0x0202000202010101, 0x0202000202020000, 0x0202000202020002,\n    0x0202000202020200, 0x0202000202020202, 0x0202010000010201, 0x0202010000020101,\n    0x0202010001000001, 0x0202010001000100, 0x0202010001010000, 0x0202010001010100,\n    0x0202010001010101, 0x0202010001010200, 0x0202010001010202, 0x0202010001020001,\n    0x0202010001020101, 0x0202010001020102, 0x0202010001020200, 0x0202010001020201,\n    0x0202010002000101, 0x0202010100000102, 0x0202010100000201, 0x0202010100010000,\n    0x0202010100010002, 0x0202010100010101, 0x0202010100010200, 0x0202010100020102,\n    0x0202010100020201, 0x0202010101000002, 0x0202010101000101, 0x0202010101010001,\n    0x0202010101010100, 0x0202010101010101, 0x0202010101010102, 0x0202010101010201,\n    0x0202010101020101, 0x0202010101020202, 0x0202010102000001, 0x0202010102000100,\n    0x0202010102000101, 0x0202010102000102, 0x0202010102000201, 
0x0202010102010002,\n    0x0202010102010101, 0x0202010102010200, 0x0202010200000101, 0x0202010200010001,\n    0x0202010200010102, 0x0202010200010202, 0x0202010200020001, 0x0202010200020101,\n    0x0202010201000100, 0x0202010201000102, 0x0202010201000202, 0x0202010201010002,\n    0x0202010201010101, 0x0202010201010102, 0x0202010201010200, 0x0202010201020000,\n    0x0202010201020002, 0x0202010202000102, 0x0202010202010000, 0x0202010202010101,\n    0x0202010202010102, 0x0202010202010201, 0x0202010202020001, 0x0202010202020100,\n    0x0202010202020102, 0x0202020000000000, 0x0202020000000002, 0x0202020000000200,\n    0x0202020000000202, 0x0202020000020000, 0x0202020000020002, 0x0202020000020200,\n    0x0202020000020202, 0x0202020001010001, 0x0202020001010100, 0x0202020001010102,\n    0x0202020001010201, 0x0202020002000000, 0x0202020002000002, 0x0202020002000200,\n    0x0202020002000202, 0x0202020002010101, 0x0202020002020000, 0x0202020002020002,\n    0x0202020002020200, 0x0202020002020202, 0x0202020100000101, 0x0202020100010100,\n    0x0202020100010201, 0x0202020100020001, 0x0202020100020101, 0x0202020101000001,\n    0x0202020101010000, 0x0202020101010101, 0x0202020101010202, 0x0202020101020001,\n    0x0202020101020102, 0x0202020101020201, 0x0202020102010000, 0x0202020102010102,\n    0x0202020200000000, 0x0202020200000002, 0x0202020200000200, 0x0202020200000202,\n    0x0202020200020000, 0x0202020200020002, 0x0202020200020200, 0x0202020200020202,\n    0x0202020201010001, 0x0202020201010100, 0x0202020201010102, 0x0202020202000000,\n    0x0202020202000002, 0x0202020202000200, 0x0202020202000202, 0x0202020202010101,\n    0x0202020202020000, 0x0202020202020002, 0x0202020202020200, 0x0202020202020202,\n};\n#else\nstatic const uint32_t iq1s_grid_us[2048] = {\n    0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000,\n    0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101,\n    0x02000000, 
0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200,\n    0x02020202, 0x00000110, 0x00000111, 0x00010011, 0x00010110, 0x00010112, 0x00010211, 0x00010212,\n    0x00020111, 0x01000011, 0x01000112, 0x01000211, 0x01010012, 0x01010111, 0x01010212, 0x01020011,\n    0x01020110, 0x01020112, 0x01020210, 0x02000111, 0x02010011, 0x02010110, 0x02010112, 0x02020111,\n    0x00000020, 0x00000022, 0x00000220, 0x00000222, 0x00010121, 0x00020020, 0x00020022, 0x00020220,\n    0x00020222, 0x01000121, 0x01010021, 0x01010221, 0x01020120, 0x01020221, 0x02000020, 0x02000022,\n    0x02000220, 0x02000222, 0x02010021, 0x02010121, 0x02010221, 0x02020020, 0x02020022, 0x02020220,\n    0x02020222, 0x00011001, 0x00011100, 0x00011102, 0x00021101, 0x01001001, 0x01001201, 0x01011101,\n    0x01011202, 0x01021100, 0x01021101, 0x02011001, 0x02011201, 0x02021101, 0x00001011, 0x00001110,\n    0x00001111, 0x00001112, 0x00011111, 0x00011210, 0x00011212, 0x00021211, 0x01001010, 0x01001111,\n    0x01001212, 0x01011010, 0x01011011, 0x01011110, 0x01011111, 0x01011112, 0x01011211, 0x01021010,\n    0x01021012, 0x01021111, 0x01021210, 0x01021212, 0x02001011, 0x02011011, 0x02011111, 0x02011210,\n    0x02011212, 0x02021011, 0x02021110, 0x02021111, 0x02021112, 0x02021211, 0x00011120, 0x00011221,\n    0x01001021, 0x01001120, 0x01011020, 0x01011022, 0x01011121, 0x01011220, 0x01021020, 0x01021021,\n    0x01021122, 0x01021221, 0x02001121, 0x02011021, 0x02011120, 0x02011221, 0x00002000, 0x00002002,\n    0x00002200, 0x00002202, 0x00012101, 0x00022000, 0x00022002, 0x00022200, 0x00022202, 0x01002101,\n    0x01012001, 0x01012102, 0x01022101, 0x02002000, 0x02002002, 0x02002200, 0x02002202, 0x02012101,\n    0x02022000, 0x02022002, 0x02022200, 0x02022202, 0x00002111, 0x00012011, 0x00012110, 0x00012211,\n    0x00022110, 0x00022111, 0x01002011, 0x01012010, 0x01012011, 0x01012111, 0x01022011, 0x01022110,\n    0x01022211, 0x02012011, 0x02012110, 0x02012112, 0x02012211, 0x02022111, 0x00002020, 
0x00002022,\n    0x00002220, 0x00002222, 0x00012121, 0x00022020, 0x00022022, 0x00022220, 0x00022222, 0x01002121,\n    0x01012021, 0x01012221, 0x01022021, 0x01022121, 0x02002020, 0x02002022, 0x02002121, 0x02002220,\n    0x02002222, 0x02012121, 0x02022020, 0x02022022, 0x02022220, 0x02022222, 0x00110000, 0x00110001,\n    0x00110100, 0x00110201, 0x00120100, 0x00120101, 0x01100001, 0x01100100, 0x01110000, 0x01110101,\n    0x01110200, 0x01120001, 0x01120100, 0x01120101, 0x01120201, 0x02110001, 0x02110100, 0x02110102,\n    0x02120001, 0x02120101, 0x00100011, 0x00100110, 0x00100112, 0x00100211, 0x00110010, 0x00110012,\n    0x00110111, 0x00110210, 0x00120011, 0x00120110, 0x00120211, 0x01100111, 0x01100212, 0x01110010,\n    0x01110011, 0x01110012, 0x01110110, 0x01110111, 0x01110112, 0x01110211, 0x01120010, 0x01120111,\n    0x02100110, 0x02110012, 0x02110111, 0x02120011, 0x02120110, 0x00110021, 0x00110120, 0x00110122,\n    0x00120121, 0x01100020, 0x01100122, 0x01100221, 0x01110022, 0x01110121, 0x01110220, 0x01110222,\n    0x01120120, 0x01120122, 0x02100121, 0x02110021, 0x02110120, 0x02110122, 0x02120121, 0x00101001,\n    0x00101102, 0x00101201, 0x00111100, 0x00111101, 0x00111200, 0x00111201, 0x00121001, 0x00121102,\n    0x01101001, 0x01101101, 0x01101102, 0x01101200, 0x01101202, 0x01111001, 0x01111100, 0x01111101,\n    0x01111102, 0x01111201, 0x01121002, 0x01121101, 0x01121200, 0x02101100, 0x02101201, 0x02111000,\n    0x02111100, 0x02111101, 0x02111200, 0x02111201, 0x02111202, 0x02121001, 0x02121100, 0x02121101,\n    0x02121201, 0x00101012, 0x00101111, 0x00101212, 0x00111011, 0x00111110, 0x00111111, 0x00111112,\n    0x00111211, 0x00121010, 0x00121012, 0x00121111, 0x00121210, 0x00121212, 0x01101011, 0x01101110,\n    0x01101111, 0x01101112, 0x01111011, 0x01111012, 0x01111110, 0x01111111, 0x01111112, 0x01111211,\n    0x01111212, 0x01121011, 0x01121110, 0x01121111, 0x01121112, 0x01121211, 0x02101010, 0x02101012,\n    0x02101110, 0x02101111, 0x02101210, 0x02101212, 0x02111010, 
0x02111011, 0x02111110, 0x02111111,\n    0x02111112, 0x02111211, 0x02111212, 0x02121010, 0x02121012, 0x02121111, 0x00101021, 0x00101120,\n    0x00101121, 0x00101122, 0x00111121, 0x00111122, 0x00111220, 0x00111222, 0x00121021, 0x00121122,\n    0x01101020, 0x01101022, 0x01101120, 0x01101121, 0x01101220, 0x01101222, 0x01111021, 0x01111121,\n    0x01111122, 0x01111220, 0x01111221, 0x01121021, 0x01121120, 0x01121121, 0x01121220, 0x01121221,\n    0x01121222, 0x02101122, 0x02101222, 0x02111022, 0x02111121, 0x02121120, 0x02121221, 0x00112001,\n    0x00112102, 0x00122101, 0x01102001, 0x01102100, 0x01102102, 0x01102201, 0x01112000, 0x01112101,\n    0x01112200, 0x01112202, 0x01122000, 0x01122001, 0x01122100, 0x01122102, 0x01122201, 0x02102101,\n    0x02112001, 0x02112100, 0x02122101, 0x00112010, 0x00112012, 0x00112111, 0x00112212, 0x00122011,\n    0x00122111, 0x01102012, 0x01102110, 0x01102111, 0x01102210, 0x01112011, 0x01112110, 0x01112111,\n    0x01112112, 0x01112211, 0x01112212, 0x01122010, 0x01122111, 0x01122212, 0x02102211, 0x02112011,\n    0x02112012, 0x02112111, 0x02112210, 0x02122011, 0x02122112, 0x02122211, 0x00102221, 0x00112122,\n    0x00122120, 0x00122122, 0x01102120, 0x01102122, 0x01102221, 0x01112020, 0x01112022, 0x01112121,\n    0x01112220, 0x01122021, 0x01122122, 0x01122221, 0x02102121, 0x02112021, 0x02112122, 0x02112222,\n    0x00200000, 0x00200002, 0x00200200, 0x00200202, 0x00210101, 0x00220000, 0x00220002, 0x00220101,\n    0x00220200, 0x00220202, 0x01200101, 0x01210001, 0x01210201, 0x01220001, 0x01220101, 0x02200000,\n    0x02200002, 0x02200200, 0x02200202, 0x02210101, 0x02220000, 0x02220002, 0x02220101, 0x02220200,\n    0x02220202, 0x00200111, 0x00210011, 0x00210110, 0x00210211, 0x00220111, 0x01200012, 0x01200110,\n    0x01200211, 0x01210111, 0x01210210, 0x01210212, 0x01220011, 0x01220110, 0x01220111, 0x01220112,\n    0x02200111, 0x02210010, 0x02210112, 0x02210211, 0x02220111, 0x00200021, 0x00200220, 0x00200222,\n    0x00210021, 0x00210121, 0x00220020, 
0x00220022, 0x00220220, 0x00220222, 0x01200121, 0x01210021,\n    0x01210122, 0x01210221, 0x01220121, 0x02200021, 0x02200220, 0x02200222, 0x02210021, 0x02210121,\n    0x02220020, 0x02220022, 0x02220220, 0x02220222, 0x00201101, 0x00211100, 0x00211102, 0x00211201,\n    0x00221101, 0x01201100, 0x01201101, 0x01201102, 0x01201201, 0x01211002, 0x01211101, 0x01211200,\n    0x01211202, 0x01221102, 0x02201101, 0x02211001, 0x02211100, 0x02211201, 0x02221001, 0x02221101,\n    0x00201211, 0x00211111, 0x00221011, 0x00221211, 0x01201010, 0x01201111, 0x01201210, 0x01211011,\n    0x01211110, 0x01211111, 0x01211211, 0x01221012, 0x01221111, 0x01221210, 0x02201211, 0x02211010,\n    0x02211110, 0x02211111, 0x02211210, 0x02211212, 0x02221011, 0x02221110, 0x02221112, 0x02221211,\n    0x00201121, 0x00211020, 0x00211022, 0x00211221, 0x00221121, 0x01201021, 0x01201221, 0x01211121,\n    0x01221020, 0x01221021, 0x01221221, 0x02201120, 0x02201122, 0x02211020, 0x02211222, 0x00202000,\n    0x00202002, 0x00202200, 0x00202202, 0x00212101, 0x00222000, 0x00222002, 0x00222200, 0x00222202,\n    0x01202101, 0x01212001, 0x01212100, 0x01222101, 0x02202000, 0x02202002, 0x02202200, 0x02202202,\n    0x02222000, 0x02222002, 0x02222200, 0x02222202, 0x00202211, 0x00212011, 0x00212110, 0x00212211,\n    0x00222111, 0x01202112, 0x01202211, 0x01212012, 0x01212111, 0x01222011, 0x01222110, 0x01222112,\n    0x01222211, 0x02202111, 0x02212010, 0x02212112, 0x02212211, 0x02222110, 0x02222111, 0x00202020,\n    0x00202022, 0x00202220, 0x00202222, 0x00222020, 0x00222022, 0x00222220, 0x00222222, 0x01202121,\n    0x01212021, 0x01212122, 0x01212221, 0x01222121, 0x02202020, 0x02202022, 0x02202220, 0x02202222,\n    0x02212121, 0x02222020, 0x02222022, 0x02222220, 0x02222222, 0x10000101, 0x10010001, 0x10010102,\n    0x10020101, 0x11000201, 0x11010002, 0x11010101, 0x11010200, 0x11010202, 0x11020001, 0x11020100,\n    0x11020102, 0x12010100, 0x12010201, 0x12020001, 0x12020102, 0x10000010, 0x10000011, 0x10000110,\n    0x10000112, 
0x10000211, 0x10010012, 0x10010111, 0x10010112, 0x10010210, 0x10010212, 0x10020011,\n    0x10020112, 0x10020211, 0x11000111, 0x11000210, 0x11000212, 0x11010011, 0x11010110, 0x11010111,\n    0x11010112, 0x11010211, 0x11010212, 0x11020111, 0x11020210, 0x11020212, 0x12000011, 0x12000110,\n    0x12000112, 0x12010010, 0x12010012, 0x12010111, 0x12020010, 0x12020011, 0x12020012, 0x10000121,\n    0x10010021, 0x10010120, 0x10010122, 0x10020121, 0x11000021, 0x11010022, 0x11010121, 0x11010222,\n    0x11020120, 0x11020221, 0x12000221, 0x12010120, 0x12020121, 0x10001001, 0x10011101, 0x10011201,\n    0x10021201, 0x11001101, 0x11001200, 0x11001202, 0x11011001, 0x11011100, 0x11011101, 0x11011102,\n    0x11021001, 0x11021002, 0x11021101, 0x11021200, 0x11021202, 0x12001001, 0x12001102, 0x12001201,\n    0x12011000, 0x12011002, 0x12011101, 0x12021000, 0x12021001, 0x12021201, 0x10001011, 0x10001012,\n    0x10001111, 0x10001212, 0x10011011, 0x10011110, 0x10011111, 0x10011112, 0x10011211, 0x10021010,\n    0x10021111, 0x10021212, 0x11001011, 0x11001110, 0x11001111, 0x11001112, 0x11001211, 0x11011010,\n    0x11011011, 0x11011110, 0x11011111, 0x11011112, 0x11011210, 0x11011211, 0x11021011, 0x11021110,\n    0x11021111, 0x11021112, 0x11021211, 0x12001012, 0x12001110, 0x12001111, 0x12001210, 0x12011011,\n    0x12011110, 0x12011111, 0x12011112, 0x12011211, 0x12011212, 0x12021111, 0x12021210, 0x12021212,\n    0x10001021, 0x10001121, 0x10001221, 0x10011120, 0x10011121, 0x10011220, 0x10011222, 0x10021021,\n    0x10021120, 0x10021221, 0x11001020, 0x11001022, 0x11001121, 0x11001220, 0x11011020, 0x11011021,\n    0x11011022, 0x11011121, 0x11011122, 0x11011221, 0x11021022, 0x11021121, 0x11021220, 0x12001021,\n    0x12001121, 0x12001222, 0x12011120, 0x12011121, 0x12021021, 0x12021120, 0x12021122, 0x10002101,\n    0x10012001, 0x10012101, 0x10012202, 0x10022101, 0x11002002, 0x11002201, 0x11012000, 0x11012101,\n    0x11012200, 0x11022001, 0x11022100, 0x11022102, 0x11022201, 0x12002101, 0x12012001, 
0x12012100,\n    0x12012102, 0x12012201, 0x12022101, 0x10002011, 0x10002111, 0x10002112, 0x10002212, 0x10012010,\n    0x10012110, 0x10012111, 0x10012210, 0x10022011, 0x10022110, 0x10022112, 0x11002010, 0x11002111,\n    0x11002212, 0x11012011, 0x11012012, 0x11012110, 0x11012111, 0x11012112, 0x11012211, 0x11022010,\n    0x11022012, 0x11022111, 0x11022112, 0x11022212, 0x12002112, 0x12002211, 0x12012012, 0x12012111,\n    0x12012112, 0x12012210, 0x12022011, 0x12022110, 0x12022112, 0x12022211, 0x10012122, 0x11002120,\n    0x11002122, 0x11002221, 0x11012121, 0x11012220, 0x11012222, 0x11022120, 0x11022221, 0x12012120,\n    0x12022121, 0x10100001, 0x10100100, 0x10100101, 0x10100102, 0x10100201, 0x10110002, 0x10110101,\n    0x10110202, 0x10120001, 0x10120100, 0x10120201, 0x11100000, 0x11100101, 0x11100200, 0x11110001,\n    0x11110100, 0x11110101, 0x11110102, 0x11110201, 0x11120101, 0x11120200, 0x12100102, 0x12100201,\n    0x12110101, 0x12110200, 0x12120000, 0x12120001, 0x12120102, 0x12120201, 0x10100111, 0x10100210,\n    0x10100211, 0x10100212, 0x10110011, 0x10110110, 0x10110111, 0x10110112, 0x10110210, 0x10110211,\n    0x10120010, 0x10120111, 0x10120112, 0x10120210, 0x10120212, 0x11100011, 0x11100110, 0x11100111,\n    0x11100112, 0x11100211, 0x11110010, 0x11110011, 0x11110012, 0x11110110, 0x11110111, 0x11110112,\n    0x11110210, 0x11110211, 0x11110212, 0x11120011, 0x11120110, 0x11120111, 0x11120112, 0x11120211,\n    0x12100012, 0x12100111, 0x12110011, 0x12110110, 0x12110111, 0x12110112, 0x12110211, 0x12120010,\n    0x12120111, 0x12120212, 0x10100021, 0x10100122, 0x10110022, 0x10110121, 0x10110222, 0x10120021,\n    0x10120120, 0x11100022, 0x11100121, 0x11100222, 0x11110021, 0x11110120, 0x11110121, 0x11110122,\n    0x11110221, 0x11120022, 0x11120121, 0x12100121, 0x12110020, 0x12110022, 0x12110121, 0x12110221,\n    0x12110222, 0x12120120, 0x10101100, 0x10101101, 0x10111001, 0x10111100, 0x10111101, 0x10111102,\n    0x10111200, 0x10111201, 0x10121001, 0x10121101, 0x10121200, 
0x10121202, 0x11101001, 0x11101100,\n    0x11101101, 0x11101102, 0x11101201, 0x11101202, 0x11111000, 0x11111001, 0x11111100, 0x11111101,\n    0x11111102, 0x11111200, 0x11111201, 0x11111202, 0x11121001, 0x11121002, 0x11121100, 0x11121101,\n    0x11121102, 0x11121201, 0x12101000, 0x12101200, 0x12101202, 0x12111001, 0x12111100, 0x12111101,\n    0x12111102, 0x12111201, 0x12121001, 0x12121100, 0x12121101, 0x12121202, 0x10101011, 0x10101012,\n    0x10101110, 0x10101111, 0x10101112, 0x10101211, 0x10111010, 0x10111011, 0x10111012, 0x10111110,\n    0x10111111, 0x10111112, 0x10111211, 0x10111212, 0x10121011, 0x10121110, 0x10121111, 0x10121112,\n    0x10121211, 0x11101010, 0x11101011, 0x11101012, 0x11101110, 0x11101111, 0x11101112, 0x11101210,\n    0x11101211, 0x11111010, 0x11111011, 0x11111012, 0x11111110, 0x11111111, 0x11111112, 0x11111210,\n    0x11111211, 0x11111212, 0x11121010, 0x11121011, 0x11121110, 0x11121111, 0x11121112, 0x11121210,\n    0x11121211, 0x11121212, 0x12101011, 0x12101110, 0x12101111, 0x12101211, 0x12101212, 0x12111010,\n    0x12111011, 0x12111110, 0x12111111, 0x12111112, 0x12111210, 0x12111211, 0x12121011, 0x12121110,\n    0x12121111, 0x12121112, 0x12121211, 0x10101020, 0x10101021, 0x10101022, 0x10101120, 0x10101122,\n    0x10101220, 0x10101221, 0x10111021, 0x10111120, 0x10111121, 0x10111220, 0x10111221, 0x10121020,\n    0x10121021, 0x10121022, 0x10121120, 0x10121121, 0x10121122, 0x10121220, 0x10121221, 0x11101021,\n    0x11101121, 0x11101122, 0x11101220, 0x11101221, 0x11101222, 0x11111020, 0x11111021, 0x11111022,\n    0x11111120, 0x11111121, 0x11111122, 0x11111220, 0x11111221, 0x11111222, 0x11121021, 0x11121120,\n    0x11121121, 0x11121221, 0x12101022, 0x12101121, 0x12101122, 0x12101220, 0x12101221, 0x12101222,\n    0x12111021, 0x12111121, 0x12111222, 0x12121022, 0x12121121, 0x12121122, 0x12121220, 0x12121221,\n    0x10102100, 0x10102101, 0x10102102, 0x10102201, 0x10112000, 0x10112101, 0x10112200, 0x10122001,\n    0x10122202, 0x11102101, 0x11102200, 
0x11102202, 0x11112001, 0x11112100, 0x11112101, 0x11112102,\n    0x11112200, 0x11112201, 0x11122000, 0x11122002, 0x11122100, 0x11122101, 0x12102002, 0x12102201,\n    0x12112000, 0x12112002, 0x12112101, 0x12112200, 0x12122001, 0x12122201, 0x10102011, 0x10102012,\n    0x10102111, 0x10102212, 0x10112011, 0x10112110, 0x10112111, 0x10112112, 0x10112211, 0x10122111,\n    0x11102011, 0x11102110, 0x11102111, 0x11102112, 0x11102211, 0x11112010, 0x11112011, 0x11112012,\n    0x11112110, 0x11112111, 0x11112112, 0x11112210, 0x11112211, 0x11112212, 0x11122011, 0x11122110,\n    0x11122111, 0x11122112, 0x11122211, 0x12102011, 0x12102111, 0x12102211, 0x12112011, 0x12112110,\n    0x12112111, 0x12112112, 0x12112210, 0x12112211, 0x12122111, 0x10102120, 0x10102220, 0x10112121,\n    0x10112222, 0x10122020, 0x10122121, 0x10122122, 0x10122221, 0x11102121, 0x11102220, 0x11102221,\n    0x11112021, 0x11112121, 0x11112122, 0x11112220, 0x11112221, 0x11122022, 0x11122121, 0x11122220,\n    0x11122222, 0x12102021, 0x12102222, 0x12112022, 0x12112121, 0x12112122, 0x12112220, 0x12112222,\n    0x12122021, 0x10200101, 0x10210100, 0x10210102, 0x10210201, 0x10220101, 0x11200100, 0x11210000,\n    0x11210101, 0x11210102, 0x11210200, 0x11210202, 0x11220001, 0x11220100, 0x11220102, 0x11220201,\n    0x12200001, 0x12210102, 0x12220101, 0x10200011, 0x10200110, 0x10200112, 0x10200211, 0x10210012,\n    0x10210111, 0x10220011, 0x10220012, 0x10220112, 0x10220211, 0x11200111, 0x11200211, 0x11210011,\n    0x11210111, 0x11210112, 0x11210211, 0x11220111, 0x11220112, 0x11220212, 0x12200110, 0x12200212,\n    0x12210012, 0x12210111, 0x12220011, 0x12220112, 0x12220211, 0x10210021, 0x10210122, 0x10210221,\n    0x11200020, 0x11200021, 0x11200122, 0x11210121, 0x11210122, 0x11210220, 0x11220020, 0x12200121,\n    0x12210021, 0x12210122, 0x12220121, 0x10211001, 0x10211002, 0x10211101, 0x10211102, 0x10211202,\n    0x10221001, 0x10221102, 0x10221201, 0x11201000, 0x11201002, 0x11201101, 0x11201200, 0x11201202,\n    0x11211001, 
0x11211100, 0x11211101, 0x11211102, 0x11211201, 0x11211202, 0x11221000, 0x11221002,\n    0x11221101, 0x12201100, 0x12201101, 0x12201201, 0x12211000, 0x12211002, 0x12211100, 0x12211101,\n    0x12211102, 0x12211200, 0x12211202, 0x12221001, 0x12221100, 0x12221201, 0x10201111, 0x10201210,\n    0x10201212, 0x10211011, 0x10211111, 0x10211112, 0x10211211, 0x11201110, 0x11201111, 0x11201112,\n    0x11201211, 0x11211010, 0x11211011, 0x11211110, 0x11211111, 0x11211112, 0x11211211, 0x11221011,\n    0x11221110, 0x11221111, 0x11221112, 0x11221211, 0x12201112, 0x12201211, 0x12201212, 0x12211011,\n    0x12211111, 0x12211112, 0x12211211, 0x12211212, 0x12221012, 0x12221111, 0x12221112, 0x12221210,\n    0x10201022, 0x10201221, 0x10211121, 0x10221020, 0x10221122, 0x10221220, 0x10221221, 0x11201020,\n    0x11201121, 0x11201220, 0x11201222, 0x11211021, 0x11211120, 0x11211121, 0x11211122, 0x11211220,\n    0x11211222, 0x11221020, 0x11221121, 0x11221220, 0x12201020, 0x12201022, 0x12201121, 0x12201222,\n    0x12211120, 0x12211122, 0x12211220, 0x12211221, 0x12221020, 0x12221120, 0x12221122, 0x12221222,\n    0x10212102, 0x10212201, 0x10222101, 0x11202001, 0x11212002, 0x11212101, 0x11212202, 0x11222001,\n    0x11222201, 0x12202101, 0x12212001, 0x12212200, 0x12222102, 0x10202011, 0x10202110, 0x10212010,\n    0x10212111, 0x10222011, 0x10222110, 0x10222112, 0x10222211, 0x11202010, 0x11202011, 0x11202111,\n    0x11202112, 0x11202210, 0x11212011, 0x11212110, 0x11212111, 0x11212112, 0x11212211, 0x11222010,\n    0x11222111, 0x11222212, 0x12202012, 0x12202110, 0x12202212, 0x12212111, 0x12222011, 0x12222110,\n    0x12222111, 0x12222211, 0x10212021, 0x10212122, 0x10212220, 0x11202021, 0x11202120, 0x11202221,\n    0x11212020, 0x11212121, 0x11212220, 0x11212222, 0x11222120, 0x11222121, 0x11222221, 0x12202122,\n    0x12212120, 0x12212220, 0x12212222, 0x12222122, 0x20000000, 0x20000002, 0x20000200, 0x20000202,\n    0x20020000, 0x20020002, 0x20020200, 0x20020202, 0x21000101, 0x21010000, 0x21010001, 
0x21010100,\n    0x21010102, 0x21010201, 0x21020101, 0x22000000, 0x22000002, 0x22000200, 0x22000202, 0x22010101,\n    0x22020000, 0x22020002, 0x22020200, 0x22020202, 0x20000111, 0x20010011, 0x20010110, 0x20010112,\n    0x20010211, 0x20020111, 0x21000011, 0x21000110, 0x21000211, 0x21010010, 0x21010012, 0x21010111,\n    0x21010112, 0x21010210, 0x21010211, 0x21020110, 0x21020112, 0x21020211, 0x22000111, 0x22000211,\n    0x22010110, 0x22010112, 0x22010211, 0x22020111, 0x20000020, 0x20000022, 0x20000220, 0x20000222,\n    0x20010121, 0x20020020, 0x20020022, 0x20020220, 0x20020222, 0x21010021, 0x21010120, 0x21010221,\n    0x21020121, 0x22000020, 0x22000022, 0x22000220, 0x22000222, 0x22010121, 0x22020020, 0x22020022,\n    0x22020220, 0x22020222, 0x20011100, 0x20011201, 0x21001001, 0x21001100, 0x21011001, 0x21011101,\n    0x21011202, 0x21021001, 0x21021100, 0x21021201, 0x22011100, 0x22011201, 0x20001011, 0x20001211,\n    0x20011012, 0x20011111, 0x20011212, 0x20021112, 0x20021211, 0x21001010, 0x21001011, 0x21001111,\n    0x21001210, 0x21011011, 0x21011110, 0x21011111, 0x21011112, 0x21011211, 0x21011212, 0x21021111,\n    0x21021112, 0x21021210, 0x21021212, 0x22001011, 0x22001110, 0x22001112, 0x22001211, 0x22011010,\n    0x22011012, 0x22011111, 0x22011210, 0x22021112, 0x20011021, 0x20011122, 0x20011221, 0x20021121,\n    0x21001021, 0x21001120, 0x21001221, 0x21001222, 0x21011020, 0x21011121, 0x21011221, 0x21011222,\n    0x21021021, 0x21021122, 0x21021222, 0x22001121, 0x22011021, 0x22011222, 0x22021120, 0x20002000,\n    0x20002002, 0x20002200, 0x20002202, 0x20012101, 0x20022000, 0x20022002, 0x20022200, 0x20022202,\n    0x21002001, 0x21002101, 0x21012001, 0x21012100, 0x21012201, 0x21022101, 0x21022201, 0x22002000,\n    0x22002002, 0x22002200, 0x22002202, 0x22012101, 0x22022000, 0x22022002, 0x22022200, 0x22022202,\n    0x20002111, 0x20002112, 0x20012011, 0x20012110, 0x20012112, 0x20022111, 0x21002011, 0x21002110,\n    0x21002112, 0x21002211, 0x21012010, 0x21012012, 0x21012111, 
0x21012212, 0x21022011, 0x21022110,\n    0x22002111, 0x22012112, 0x22012211, 0x22022111, 0x20002020, 0x20002022, 0x20002220, 0x20002222,\n    0x20012121, 0x20022020, 0x20022022, 0x20022220, 0x20022222, 0x21002121, 0x21012021, 0x21012120,\n    0x21012122, 0x22002020, 0x22002022, 0x22002220, 0x22002222, 0x22012121, 0x22022020, 0x22022022,\n    0x22022220, 0x22022222, 0x20100101, 0x20110001, 0x20110102, 0x20110200, 0x20110201, 0x20120101,\n    0x21100001, 0x21100102, 0x21100201, 0x21110101, 0x21110200, 0x21110202, 0x21120201, 0x21120202,\n    0x22100101, 0x22110001, 0x22110100, 0x22110102, 0x22110201, 0x22120101, 0x20100011, 0x20100110,\n    0x20100112, 0x20100211, 0x20110010, 0x20110111, 0x20110210, 0x20110212, 0x20120011, 0x20120110,\n    0x20120112, 0x20120211, 0x21100010, 0x21100111, 0x21110010, 0x21110011, 0x21110110, 0x21110111,\n    0x21110112, 0x21110211, 0x21120012, 0x21120111, 0x22100110, 0x22100112, 0x22110012, 0x22110111,\n    0x22110210, 0x22120011, 0x22120110, 0x22120112, 0x22120211, 0x20100121, 0x20110021, 0x20110120,\n    0x20110221, 0x20120121, 0x21100120, 0x21100122, 0x21100221, 0x21110020, 0x21110022, 0x21110121,\n    0x21110220, 0x21120122, 0x21120221, 0x22100121, 0x22110120, 0x22110122, 0x22120221, 0x20101001,\n    0x20101100, 0x20101102, 0x20111000, 0x20111101, 0x20111200, 0x20121102, 0x21101000, 0x21101202,\n    0x21111001, 0x21111100, 0x21111101, 0x21111102, 0x21111200, 0x21111201, 0x21121000, 0x21121001,\n    0x21121002, 0x21121101, 0x22101100, 0x22101102, 0x22111002, 0x22111100, 0x22111101, 0x22111200,\n    0x22121001, 0x22121201, 0x20101010, 0x20101111, 0x20101210, 0x20101212, 0x20111010, 0x20111011,\n    0x20111110, 0x20111111, 0x20111112, 0x20111211, 0x20121011, 0x20121111, 0x20121211, 0x20121212,\n    0x21101011, 0x21101110, 0x21101111, 0x21101112, 0x21101211, 0x21111010, 0x21111011, 0x21111012,\n    0x21111110, 0x21111111, 0x21111112, 0x21111210, 0x21111211, 0x21111212, 0x21121011, 0x21121110,\n    0x21121111, 0x21121112, 0x21121211, 
0x22101011, 0x22101111, 0x22101210, 0x22111011, 0x22111012,\n    0x22111110, 0x22111111, 0x22111112, 0x22111211, 0x22111212, 0x22121010, 0x22121012, 0x22121111,\n    0x22121210, 0x22121212, 0x20101021, 0x20101120, 0x20111020, 0x20111121, 0x20111221, 0x20121020,\n    0x20121122, 0x20121221, 0x21101121, 0x21101220, 0x21101221, 0x21111021, 0x21111022, 0x21111121,\n    0x21111122, 0x21111221, 0x21121121, 0x21121220, 0x22101022, 0x22101120, 0x22101221, 0x22101222,\n    0x22111022, 0x22111120, 0x22111121, 0x22121120, 0x22121122, 0x22121221, 0x20102101, 0x20112102,\n    0x20112201, 0x20122101, 0x21102001, 0x21102102, 0x21112000, 0x21112002, 0x21112101, 0x21112102,\n    0x21112202, 0x21122100, 0x21122101, 0x22102101, 0x22112001, 0x22112102, 0x22112201, 0x22122101,\n    0x20102110, 0x20102112, 0x20102211, 0x20112010, 0x20112012, 0x20112111, 0x20112210, 0x20112212,\n    0x20122010, 0x20122011, 0x20122110, 0x20122112, 0x21102010, 0x21102012, 0x21102111, 0x21102210,\n    0x21102212, 0x21112011, 0x21112110, 0x21112111, 0x21112112, 0x21112211, 0x21122012, 0x21122111,\n    0x21122112, 0x21122212, 0x22102011, 0x22102110, 0x22112010, 0x22112012, 0x22112111, 0x22112212,\n    0x22122011, 0x22122112, 0x20102121, 0x20112121, 0x20122121, 0x21102120, 0x21102122, 0x21102221,\n    0x21112020, 0x21112121, 0x21112220, 0x21122021, 0x22102121, 0x22112021, 0x22112120, 0x22112121,\n    0x22112122, 0x20200000, 0x20200002, 0x20200200, 0x20200202, 0x20210101, 0x20220000, 0x20220002,\n    0x20220200, 0x20220202, 0x21200101, 0x21210001, 0x21210100, 0x21210102, 0x21210201, 0x22200000,\n    0x22200002, 0x22200200, 0x22200202, 0x22210101, 0x22220000, 0x22220002, 0x22220200, 0x22220202,\n    0x20200111, 0x20200211, 0x20210011, 0x20210110, 0x20210112, 0x20210211, 0x20210212, 0x21200112,\n    0x21200211, 0x21210011, 0x21210111, 0x21210210, 0x21210212, 0x21220011, 0x21220110, 0x22200111,\n    0x22210010, 0x22210012, 0x22210112, 0x22210211, 0x20200022, 0x20200220, 0x20200222, 0x20210020,\n    0x20210221, 
0x20220022, 0x20220220, 0x20220222, 0x21200121, 0x21210021, 0x21210122, 0x21210221,\n    0x21220121, 0x22200020, 0x22200022, 0x22200220, 0x22200222, 0x22210121, 0x22220020, 0x22220022,\n    0x22220220, 0x22220222, 0x20211201, 0x20221101, 0x21201001, 0x21201100, 0x21211000, 0x21211100,\n    0x21211101, 0x21211200, 0x21211202, 0x21221001, 0x21221101, 0x21221102, 0x21221200, 0x21221201,\n    0x22201101, 0x20201112, 0x20201211, 0x20211010, 0x20211012, 0x20211111, 0x20211210, 0x20221112,\n    0x20221211, 0x21201012, 0x21201111, 0x21211011, 0x21211110, 0x21211111, 0x21211112, 0x21211211,\n    0x21221111, 0x21221212, 0x22201011, 0x22201110, 0x22201111, 0x22201112, 0x22201211, 0x22211012,\n    0x22211111, 0x22211210, 0x20201121, 0x20211021, 0x20211122, 0x20211222, 0x20221021, 0x20221121,\n    0x21201120, 0x21201122, 0x21201222, 0x21211022, 0x21211121, 0x21211122, 0x21211220, 0x21221020,\n    0x21221022, 0x22201122, 0x22211020, 0x22211121, 0x22211122, 0x22211221, 0x22221021, 0x22221120,\n    0x22221122, 0x20202000, 0x20202002, 0x20202200, 0x20202202, 0x20222000, 0x20222002, 0x20222200,\n    0x20222202, 0x21212001, 0x21212100, 0x21212102, 0x21212201, 0x22202000, 0x22202002, 0x22202200,\n    0x22202202, 0x22212101, 0x22222000, 0x22222002, 0x22222200, 0x22222202, 0x20202111, 0x20212110,\n    0x20212211, 0x20222011, 0x20222111, 0x21202011, 0x21212010, 0x21212111, 0x21212212, 0x21222011,\n    0x21222112, 0x21222211, 0x22212010, 0x22212112, 0x20202020, 0x20202022, 0x20202220, 0x20202222,\n    0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020,\n    0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222,\n};\n#endif\n\n#ifndef HAVE_FANCY_SIMD\nconst uint64_t keven_signs[128] = {\n    0x0101010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0x010101010101ffff,\n    0xff01010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0xff01010101ffffff,\n    0xff010101ff010101, 0x01010101ff0101ff, 
0x01010101ff01ff01, 0xff010101ff01ffff,\n    0x01010101ffff0101, 0xff010101ffff01ff, 0xff010101ffffff01, 0x01010101ffffffff,\n    0xff0101ff01010101, 0x010101ff010101ff, 0x010101ff0101ff01, 0xff0101ff0101ffff,\n    0x010101ff01ff0101, 0xff0101ff01ff01ff, 0xff0101ff01ffff01, 0x010101ff01ffffff,\n    0x010101ffff010101, 0xff0101ffff0101ff, 0xff0101ffff01ff01, 0x010101ffff01ffff,\n    0xff0101ffffff0101, 0x010101ffffff01ff, 0x010101ffffffff01, 0xff0101ffffffffff,\n    0xff01ff0101010101, 0x0101ff01010101ff, 0x0101ff010101ff01, 0xff01ff010101ffff,\n    0x0101ff0101ff0101, 0xff01ff0101ff01ff, 0xff01ff0101ffff01, 0x0101ff0101ffffff,\n    0x0101ff01ff010101, 0xff01ff01ff0101ff, 0xff01ff01ff01ff01, 0x0101ff01ff01ffff,\n    0xff01ff01ffff0101, 0x0101ff01ffff01ff, 0x0101ff01ffffff01, 0xff01ff01ffffffff,\n    0x0101ffff01010101, 0xff01ffff010101ff, 0xff01ffff0101ff01, 0x0101ffff0101ffff,\n    0xff01ffff01ff0101, 0x0101ffff01ff01ff, 0x0101ffff01ffff01, 0xff01ffff01ffffff,\n    0xff01ffffff010101, 0x0101ffffff0101ff, 0x0101ffffff01ff01, 0xff01ffffff01ffff,\n    0x0101ffffffff0101, 0xff01ffffffff01ff, 0xff01ffffffffff01, 0x0101ffffffffffff,\n    0xffff010101010101, 0x01ff0101010101ff, 0x01ff01010101ff01, 0xffff01010101ffff,\n    0x01ff010101ff0101, 0xffff010101ff01ff, 0xffff010101ffff01, 0x01ff010101ffffff,\n    0x01ff0101ff010101, 0xffff0101ff0101ff, 0xffff0101ff01ff01, 0x01ff0101ff01ffff,\n    0xffff0101ffff0101, 0x01ff0101ffff01ff, 0x01ff0101ffffff01, 0xffff0101ffffffff,\n    0x01ff01ff01010101, 0xffff01ff010101ff, 0xffff01ff0101ff01, 0x01ff01ff0101ffff,\n    0xffff01ff01ff0101, 0x01ff01ff01ff01ff, 0x01ff01ff01ffff01, 0xffff01ff01ffffff,\n    0xffff01ffff010101, 0x01ff01ffff0101ff, 0x01ff01ffff01ff01, 0xffff01ffff01ffff,\n    0x01ff01ffffff0101, 0xffff01ffffff01ff, 0xffff01ffffffff01, 0x01ff01ffffffffff,\n    0x01ffff0101010101, 0xffffff01010101ff, 0xffffff010101ff01, 0x01ffff010101ffff,\n    0xffffff0101ff0101, 0x01ffff0101ff01ff, 0x01ffff0101ffff01, 0xffffff0101ffffff,\n    
0xffffff01ff010101, 0x01ffff01ff0101ff, 0x01ffff01ff01ff01, 0xffffff01ff01ffff,\n    0x01ffff01ffff0101, 0xffffff01ffff01ff, 0xffffff01ffffff01, 0x01ffff01ffffffff,\n    0xffffffff01010101, 0x01ffffff010101ff, 0x01ffffff0101ff01, 0xffffffff0101ffff,\n    0x01ffffff01ff0101, 0xffffffff01ff01ff, 0xffffffff01ffff01, 0x01ffffff01ffffff,\n    0x01ffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0x01ffffffff01ffff,\n    0xffffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0xffffffffffffffff,\n};\n#endif\n\n}\n\n/* moonll: changed mul_mat to take typeB and strideB */\n\nbool iqk_mul_mat(long Nx, long Ny, long ne00,\n    int typeA, const void * A, long strideA,\n    int typeB, const void * B, long strideB,\n    float * C, long stride_C, int ith, int nth) {\n\n        MulMat mm;\n\n        if (!MulMat::set_mul_mat(typeA, typeB, ne00, mm, Ny)) {\n            return false;\n        }\n\n        size_t row_size_qx = strideA*ggml_type_size(ggml_type(typeA));\n        size_t row_size_qy = strideB*ggml_type_size(ggml_type(typeB));\n\n        auto nrc_x = (Nx + nth - 1)/nth;\n        auto first_x = ith*nrc_x;\n        if (first_x + nrc_x > Nx) nrc_x = Nx - first_x;\n\n        DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, row_size_qy, 0, 1, nullptr, 0};\n\n        mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny);\n\n        return true;\n}\n\n\nbool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int typeA, const void * A, const void * B,\n        float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) {\n    const mmid_row_mapping * row_mapping = (const mmid_row_mapping *)vrow_mapping;\n    assert(row_mapping != nullptr);\n\n    MulMat mm;\n    // NOTE: with the set_mul_mat call below commented out, nothing in this function\n    // configures mm or assigns row_size_q8 before they are used.\n    int row_size_q8;\n    /* moonll: disabled\n\n    if (!MulMat::set_mul_mat(typeA, ne00, mm, row_size_q8, Ny)) {\n        return false;\n    }*/\n    int row_size_qx = ggml_row_size((ggml_type)typeA, ne00);\n    int nrc_x = (Nx + nth - 
1)/nth;\n    int first_x = ith*nrc_x;\n    if (first_x + nrc_x > Nx) nrc_x = Nx - first_x;\n    DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), (size_t)row_size_q8, 0, ne11, row_mapping, nb2/sizeof(float)};\n    mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny);\n    return true;\n}\n\n#if defined __x86_64__ || defined(_M_X64)\n\n#if defined HAVE_FANCY_SIMD\n    #undef HAVE_FANCY_SIMD\n#endif\n#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__)\n    #define HAVE_FANCY_SIMD\n#endif\n//#define HAVE_FANCY_SIMD\n\nnamespace {\n\ninline float hsum_float_4(__m128 x) {\n    x = _mm_add_ps(x, _mm_movehl_ps(x, x));\n    x = _mm_add_ss(x, _mm_movehdup_ps(x));\n    return _mm_cvtss_f32(x);\n}\ninline float hsum_float_8(__m256 x) {\n    return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1)));\n}\n\n#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)\n\n\ntemplate <int nrc, typename block_q8 = block_q8_K> struct Q8 {\n\n    constexpr static int nrc_y = nrc;\n\n    Q8(const DataInfo& info) {\n        for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy);\n    }\n\n#ifdef HAVE_FANCY_SIMD\n    inline __m512i load_quants64(int iy, int i, int j) const { return _mm512_loadu_si512((const __m512i*)y[iy][i].qs + j); }\n#endif\n    inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].qs + j); }\n    inline __m256i load_bsums(int iy, int i) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].bsums); }\n    inline float scale(int iy, int i) const { return y[iy][i].d; }\n\n    const block_q8 * y[nrc_y];\n};\n\n// Handles q4_K and q5_K scales/mins\nstruct Scales8K {\n    template <typename Q8>\n    inline __m256i process_mins_and_scales(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) {\n      
  make_q4_scales(data, utmp);\n        const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));\n        const __m128i mins128 = _mm256_extracti128_si256(mins_and_scales, 1);\n        accum_mins(mins128, q8, i, c, accd);\n        const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);\n        return MM256_SET_M128I(sc128, sc128);\n    }\n#ifdef HAVE_FANCY_SIMD\n    template <typename Q8>\n    inline __m512i process_mins_and_scales_64(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) {\n        auto scales = process_mins_and_scales(data, c, i, q8, accd);\n        return _mm512_inserti32x8(_mm512_castsi256_si512(scales), scales, 1);\n    }\n#endif\n    template <typename Q8>\n    inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const {\n        const __m256i mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, shuffles[1]), _mm_shuffle_epi8(mins128, shuffles[0]));\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            const __m256i q8s = q8.load_bsums(iy, i);\n            const __m256i prod = _mm256_madd_epi16(mins, q8s);\n            accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(c*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]);\n        }\n    }\n#ifdef HAVE_FANCY_SIMD\n    const __m512i shuffles512[2] = {\n        _mm512_set_epi64(0x0706070607060706, 0x0302030203020302, 0x0706070607060706, 0x0302030203020302,\n                         0x0504050405040504, 0x0100010001000100, 0x0504050405040504, 0x0100010001000100),\n        _mm512_set_epi64(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a,\n                         0x0d0c0d0c0d0c0d0c, 0x0908090809080908, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908)\n    };\n#endif\n    const __m128i shuffles[2] = {_mm_set_epi32(0x07060706, 0x05040504, 0x03020302, 0x01000100),\n                                 _mm_set_epi32(0x0f0e0f0e, 0x0d0c0d0c, 0x0b0a0b0a, 0x09080908)};\n\n    
uint32_t utmp[4];\n};\n\ntemplate <typename Q8>\ninline void process_mins_16(const __m256i& all_scales, const Q8& q8, int i, float d, __m256 * accm) {\n    for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n        const __m256i prod  = _mm256_madd_epi16(all_scales, q8.load_bsums(iy, i));\n        accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]);\n    }\n}\ninline void prepare_scales_16(const __m256i& all_scales, __m256i * scales) {\n    const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);\n    const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);\n    scales[0] = MM256_SET_M128I(l_scales, l_scales);\n    scales[1] = MM256_SET_M128I(h_scales, h_scales);\n}\n\nstruct ScaleQ3 {\n    inline __m128i make_scales(const uint16_t * s8) const {\n        const uint16_t * scales16 = (const uint16_t *)s8;\n        uint32_t aux0 = scales16[0] | (scales16[1] << 16);\n        uint32_t aux1 = scales16[2] | (scales16[3] << 16);\n        uint32_t aux2 = scales16[4] | (scales16[5] << 16);\n        __m128i scales128 = _mm_set_epi32(\n            ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030),\n            ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030),\n             (aux1       & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030),\n             (aux0       & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030));\n        return _mm_add_epi8(scales128, m32);\n    }\n    const __m128i m32 = _mm_set1_epi8(-32);\n};\n\nstruct ScaleIQ4XS {\n    inline __m128i make_scales(const uint32_t scales_l, const uint16_t scales_h) {\n        uint32_t tmp32 = scales_h | (scales_h << 14);\n        const __m128i sh = _mm_slli_epi16(_mm_and_si128(_mm_srlv_epi32(_mm_set1_epi32(tmp32), hshift), hmask), 4);\n        const __m128i sl = _mm_and_si128(_mm_srlv_epi32(_mm_set1_epi32(scales_l), lshift), lmask);\n        return _mm_add_epi16(_mm_or_si128(sh, _mm_cvtepi8_epi16(_mm_shuffle_epi8(sl, lshuffle))), m32);\n    }\n    const __m128i hshift = 
_mm_set_epi32(12, 8, 4, 0);\n    const __m128i lshift = _mm_set_epi32(4, 0, 4, 0);\n    const __m128i hmask  = _mm_set1_epi16(0x03);\n    const __m128i lmask  = _mm_set1_epi8(0xf);\n    const __m128i lshuffle = _mm_set_epi32(0x07030602, 0x05010400, 0x07030602, 0x05010400);\n    const __m128i m32 = _mm_set1_epi16(-32);\n};\n\nstruct Scales8KBase {\n    template <typename Q8>\n    inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const {\n        const __m256i mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, shuffles[1]), _mm_shuffle_epi8(mins128, shuffles[0]));\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            const __m256i q8s = q8.load_bsums(iy, i);\n            const __m256i prod = _mm256_madd_epi16(mins, q8s);\n            accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(c*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]);\n        }\n    }\n    inline __m256i shuffle(__m128i mins) const {\n        return MM256_SET_M128I(_mm_shuffle_epi8(mins, shuffles[1]), _mm_shuffle_epi8(mins, shuffles[0]));\n    }\n    const __m128i shuffles[2] = {_mm_set_epi32(0x07060706, 0x05040504, 0x03020302, 0x01000100),\n                                 _mm_set_epi32(0x0f0e0f0e, 0x0d0c0d0c, 0x0b0a0b0a, 0x09080908)};\n};\n\ntemplate <typename Block>\nstruct BaseDequantizer {\n    BaseDequantizer(const void * vx, size_t bx) : vx(vx), bx(bx) {}\n    inline void new_row(int ix) {\n        x = (const Block *)((const char *)vx + bx*ix);\n    }\n\n    const void *  vx;\n    size_t        bx;\n    const Block * x;\n\n    float d;\n};\n\n__m128i inline load_iq4nl_values_128() {\n    static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241};\n    return _mm_loadu_si128((const __m128i *)kvalues_iq4nl);\n}\n\n__m256i inline load_iq4nl_values_256() {\n    auto val128 = load_iq4nl_values_128();\n    return MM256_SET_M128I(val128, val128);\n}\n\n#ifdef 
HAVE_FANCY_SIMD\n//====================================== Zen4 ==================================================\n\nstruct BlockPermuter {\n    const __m512i permute1 = _mm512_set_epi64(11, 10,  9,  8, 3, 2, 1, 0);\n    const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4);\n};\n\nstruct Q4Bits {\n    inline void prepare(const uint8_t * q4) {\n        auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);\n        auto tmp1 = _mm512_and_si512(q4bits, ml);\n        auto tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);\n        values[0] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);\n        values[1] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);\n        q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);\n        tmp1 = _mm512_and_si512(q4bits, ml);\n        tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);\n        values[2] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);\n        values[3] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);\n    }\n    inline void prepare64(const uint8_t * q4) {\n        auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);\n        values[0] = _mm512_and_si512(q4bits, ml);\n        values[1] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);\n        q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);\n        values[2] = _mm512_and_si512(q4bits, ml);\n        values[3] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);\n    }\n    __m512i values[4];\n    const __m512i ml = _mm512_set1_epi8(0xf);\n    BlockPermuter perm;\n};\n\nstruct Q2Bits {\n    inline void prepare(const uint8_t * q2) {\n\n        auto q2bits = _mm512_loadu_si512((const __m512i*)q2);\n        auto tmp = _mm512_srli_epi16(q2bits, 2);\n\n        values[0] = _mm512_permutex2var_epi64(q2bits, perm.permute1, tmp);\n        values[2] = _mm512_permutex2var_epi64(q2bits, perm.permute2, tmp);\n        values[1] = _mm512_and_si512(_mm512_srli_epi16(values[0], 4), ml);\n        values[3] 
= _mm512_and_si512(_mm512_srli_epi16(values[2], 4), ml);\n        values[0] = _mm512_and_si512(values[0], ml);\n        values[2] = _mm512_and_si512(values[2], ml);\n    }\n    __m512i values[4];\n    const __m512i ml = _mm512_set1_epi8(0x03);\n    BlockPermuter perm;\n};\n\nstruct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {\n    DequantizerQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        bits.prepare(x[i].qs);\n        auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);\n        scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);\n        scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);\n    }\n\n    Q4Bits bits;\n    Scales8K s8k;\n};\n\n/*\nmoonll DequantizerIQ4XS\n*/\n\n__m512i inline load_iq4nl_values_512() {\n    auto val256 = load_iq4nl_values_256();\n    return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);\n}\n\nstruct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {\n    DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {}\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        prepare(x[i].qs);\n        auto scales128 = siq4.make_scales(*(const uint32_t *)x[i].scales_l, x[i].scales_h);\n        s8k.accum_mins(scales128, q8, i, -128.f*d, accd);\n        auto scales256 = MM256_SET_M128I(scales128, scales128);\n        auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);\n        scales[0] = _mm512_shuffle_epi8(all_scales, shuffles[0]);\n        scales[1] = _mm512_shuffle_epi8(all_scales, shuffles[1]);\n        scales[2] = _mm512_shuffle_epi8(all_scales, shuffles[2]);\n   
     scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]);\n    }\n    inline void prepare(const uint8_t * q4) {\n        bits.prepare64(q4);\n        // We now have in bits.values[0]: 0...15, 32...47, 64...79, 96...111\n        //                bits.values[1]: 16..31, 48...63, 80...95, 112..127\n        //                etc.\n        auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]);\n        bits.values[1] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]));\n        bits.values[0] = _mm512_shuffle_epi8(values, tmp);\n        tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]);\n        bits.values[3] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]));\n        bits.values[2] = _mm512_shuffle_epi8(values, tmp);\n    }\n\n    Q4Bits bits;\n    Scales8KBase s8k;\n    ScaleIQ4XS siq4;\n    const __m512i values;\n    const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2,  9,  8, 1, 0);\n    const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4);\n    const __m512i shuffles[4] = {\n        _mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1),\n        _mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1),\n        _mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1),\n        _mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1),\n    };\n};\n\nstruct HighBit5 {\n    inline void apply(const uint8_t * h, Q4Bits& bits) {\n        auto hbits256 = _mm256_loadu_si256((const __m256i *)h);\n        auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbits256), _mm256_srli_epi16(hbits256, 1), 1);\n        bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_and_si512(_mm512_slli_epi16(hbits, 4), mh));\n        bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh));\n       
 bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_and_si512(hbits, mh));\n        bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh));\n    }\n    const __m512i mh = _mm512_set1_epi8(0x10);\n};\n\nstruct HighBit3 {\n    inline void apply(const uint8_t * h, Q2Bits& bits) {\n        auto hbits256 = _mm256_loadu_si256((const __m256i *)h);\n        auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbits256), _mm256_srli_epi16(hbits256, 1), 1);\n        bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh));\n        bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_and_si512(hbits, mh));\n        bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh));\n        bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_and_si512(_mm512_srli_epi16(hbits, 4), mh));\n    }\n    const __m512i mh = _mm512_set1_epi8(0x04);\n};\n\nstruct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {\n    DequantizerQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        bits.prepare(x[i].qs);\n        hbits.apply(x[i].qh, bits);\n        auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);\n        scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);\n        scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);\n    }\n\n    Q4Bits bits;\n    HighBit5 hbits;\n    Scales8K s8k;\n};\n\nstruct Scale16 {\n    inline void make_scales(const __m128i& scales8, __m512i * scales) const {\n        auto all_scales8 = MM256_SET_M128I(scales8, scales8);\n        auto scales1 = _mm256_shuffle_epi8(all_scales8, shuffle1);\n        auto scales2 = _mm256_shuffle_epi8(all_scales8, shuffle2);\n        scales[0] 
= _mm512_cvtepi8_epi16(scales1);\n        scales[1] = _mm512_cvtepi8_epi16(scales2);\n    }\n    template <typename Q8>\n    inline void process_mins_and_scales(int i, float c, const __m128i& mins8, const __m128i& scales8,\n        const Q8& q8, __m256 * accm, __m512i * scales) const {\n        process_mins_16(_mm256_cvtepi8_epi16(mins8), q8, i, c, accm);\n        make_scales(scales8, scales);\n    }\n    const __m256i shuffle1 = _mm256_set_epi32(0x07070707, 0x03030303, 0x06060606, 0x02020202,\n                                              0x05050505, 0x01010101, 0x04040404, 0x00000000);\n    const __m256i shuffle2 = _mm256_set_epi32(0x0f0f0f0f, 0x0b0b0b0b, 0x0e0e0e0e, 0x0a0a0a0a,\n                                              0x0d0d0d0d, 0x09090909, 0x0c0c0c0c, 0x08080808);\n};\n\nstruct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {\n    DequantizerQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        bits.prepare(x[i].qs);\n        const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);\n        const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);\n        const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);\n        sc16.process_mins_and_scales(i, -GGML_FP16_TO_FP32(x[i].dmin), mins8, scales8, q8, accm, scales);\n    }\n\n    Q2Bits bits;\n    Scale16 sc16;\n    const __m128i m4 = _mm_set1_epi8(0xf);\n\n};\n\nstruct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {\n    DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        bits.prepare(x[i].qs);\n        hbits.apply(x[i].hmask, bits);\n        auto scales128 = sc3.make_scales((const uint16_t 
*)x[i].scales);\n        sc16.process_mins_and_scales(i, -4.f*d, scales128, scales128, q8, accm, scales);\n    }\n\n    Q2Bits bits;\n    HighBit3 hbits;\n    ScaleQ3 sc3;\n    Scale16 sc16;\n    const __m128i m4  = _mm_set1_epi8(0xf);\n    const __m128i m32 = _mm_set1_epi8(-32);\n};\n\nstruct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {\n    DequantizerQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        bits.prepare64(x[i].ql);\n        add_high_bits(x[i].qh, bits);\n        auto scales128 = _mm_loadu_si128((const __m128i *)x[i].scales);\n        sc16.process_mins_and_scales(i, -32.f*d, scales128, scales128, q8, accm, scales);\n    }\n\n    inline void add_high_bits(const uint8_t * qh, Q4Bits& bits) const {\n        auto hbits = _mm512_loadu_si512((const __m512i *)qh);\n        auto tmp1 = _mm512_and_si512(_mm512_slli_epi16(hbits, 4), mh);\n        auto tmp2 = _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh);\n        bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_permutex2var_epi64(tmp1, bits.perm.permute1, tmp2));\n        bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_permutex2var_epi64(tmp1, bits.perm.permute2, tmp2));\n        tmp1 = _mm512_and_si512(hbits, mh);\n        tmp2 = _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh);\n        bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_permutex2var_epi64(tmp1, bits.perm.permute1, tmp2));\n        bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_permutex2var_epi64(tmp1, bits.perm.permute2, tmp2));\n    }\n\n    Q4Bits bits;\n    HighBit3 hbits;\n    Scale16 sc16;\n\n    const __m512i mh = _mm512_set1_epi8(0x30);\n\n};\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n    
const int nb = n / QK_K;\n\n    Q8<nrc_y> q8(info);\n\n    Dequantizer deq(vx, bx);\n\n    __m256  accm[nrc_y];\n    __m512  accd[nrc_y];\n    __m512i scales[2];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps();\n        for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps();\n\n        deq.new_row(ix);\n\n        for (int i = 0; i < nb; ++i) {\n\n            deq.new_block(i, q8, accm, scales);\n\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants(iy, i, 0));\n                const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants(iy, i, 1));\n                const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants(iy, i, 2));\n                const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants(iy, i, 3));\n                auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));\n                sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));\n                accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);\n            }\n\n        }\n\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));\n            info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));\n        }\n\n    }\n}\ntemplate <typename Q8>\ninline void compute_block(int iy, int i, float d, const Q8& q8, const __m512i * values, const __m512i * scales, __m512 * accd) {\n    const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[0], q8.load_quants64(iy, i, 0));\n    const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[1], 
q8.load_quants64(iy, i, 1));\n    const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[2], q8.load_quants64(iy, i, 2));\n    const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[3], q8.load_quants64(iy, i, 3));\n    auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));\n    sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));\n    accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);\n}\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n    const int nb = n / QK_K;\n\n    Q8<nrc_y> q8(info);\n\n    Dequantizer deq(vx, bx);\n\n    __m256  accm[nrc_y];\n    __m512  accd[nrc_y];\n    __m512i scales[2];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps();\n        for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps();\n\n        deq.new_row(ix);\n\n        for (int i = 0; i < nb; ++i) {\n\n            deq.new_block(i, q8, accm, scales);\n\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants64(iy, i, 0));\n                const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants64(iy, i, 1));\n                const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants64(iy, i, 2));\n                const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants64(iy, i, 3));\n                auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));\n                sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));\n                accd[iy] = 
_mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);\n            }\n\n        }\n\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));\n            info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));\n        }\n\n    }\n}\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n    const int nb = n / QK_K;\n\n    Q8<nrc_y> q8(info);\n\n    Dequantizer deq(vx, bx);\n\n    __m256  accm[nrc_y];\n    __m512  accd[nrc_y];\n    __m512i scales[4];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps();\n        for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps();\n\n        deq.new_row(ix);\n\n        for (int i = 0; i < nb; ++i) {\n\n            deq.new_block(i, q8, accm, scales);\n\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                const __m512i p1 = _mm512_maddubs_epi16(deq.bits.values[0], q8.load_quants64(iy, i, 0));\n                const __m512i p2 = _mm512_maddubs_epi16(deq.bits.values[1], q8.load_quants64(iy, i, 1));\n                const __m512i p3 = _mm512_maddubs_epi16(deq.bits.values[2], q8.load_quants64(iy, i, 2));\n                const __m512i p4 = _mm512_maddubs_epi16(deq.bits.values[3], q8.load_quants64(iy, i, 3));\n                auto sumi = _mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_setzero_si512(),\n                                    p1, scales[0]), p2, scales[1]), p3, scales[2]), p4, scales[3]);\n                accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);\n            }\n\n        }\n\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            auto sum256 = 
_mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));\n            info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));\n        }\n\n    }\n}\n\ntemplate <typename Dequantizer>\nstatic void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n    const int nb = n / QK_K;\n\n    constexpr int k_nx = 2;\n\n    Q8<1> q8(info);\n\n    Dequantizer deq1(vx, bx);\n    Dequantizer deq2(vx, bx);\n\n    Dequantizer * deq[k_nx];\n    deq[0] = &deq1;\n    deq[1] = &deq2;\n\n    __m512i scales[2*k_nx];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        auto accd = _mm512_setzero_ps();\n        auto accm = _mm256_setzero_ps();\n\n        for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_row(ix);\n\n        for (int i = 0; i < nb/k_nx; ++i) {\n\n            for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_block(k_nx*i+kx, q8, &accm, scales+2*kx);\n\n            for (int kx = 0; kx < k_nx; ++kx) {\n                compute_block(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, scales+2*kx, &accd);\n            }\n\n        }\n        if (2*(nb/2) < nb) {\n            int i0 = 2*(nb/2);\n            deq[0]->new_block(i0, q8, &accm, scales);\n            compute_block(0, i0, deq[0]->d, q8, deq[0]->bits.values, scales, &accd);\n        }\n\n        auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1));\n        info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256)));\n    }\n}\n\n#else\n// ===================================== Vanilla AVX2 =====================================\n\nstruct Q4Bits {\n    inline void prepare(const uint8_t * q4, int j) {\n        auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);\n        values[0] = _mm256_and_si256(q4bits, ml);\n        values[1] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);\n        q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);\n        values[2] 
= _mm256_and_si256(q4bits, ml);\n        values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);\n    }\n    inline void prepare64(const uint8_t * q4, int j) {\n        auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);\n        values[0] = _mm256_and_si256(q4bits, ml);\n        values[2] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);\n        q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);\n        values[1] = _mm256_and_si256(q4bits, ml);\n        values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);\n    }\n    inline void prepare16(const uint8_t * q4, int j) {\n        values[0] = dequant16(q4 + 64*j +  0);\n        values[1] = dequant16(q4 + 64*j + 16);\n        values[2] = dequant16(q4 + 64*j + 32);\n        values[3] = dequant16(q4 + 64*j + 48);\n    }\n    inline __m256i dequant16(const uint8_t * qs) const {\n        const __m128i aux128 = _mm_loadu_si128((const __m128i *)qs);\n        const __m256i aux256 = MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128);\n        return _mm256_and_si256(ml, aux256);\n    };\n    __m256i values[4];\n    const __m256i ml = _mm256_set1_epi8(0xf);\n};\n\nstruct Q2Bits {\n    inline void prepare(const uint8_t * q2, int j) {\n        auto q2bits = _mm256_loadu_si256((const __m256i *)q2 + j);\n        values[0] = _mm256_and_si256(q2bits, ml);\n        values[1] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), ml);\n        values[2] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), ml);\n        values[3] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), ml);\n    }\n    __m256i values[4];\n    const __m256i ml = _mm256_set1_epi8(0x03);\n};\n\nstruct HighBit5 {\n    inline void load(const uint8_t * h) { hbits = _mm256_loadu_si256((const __m256i *)h); }\n    inline void apply(Q4Bits& bits, bool do_shift) {\n        bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh));\n        bits.values[1] = _mm256_or_si256(bits.values[1], 
_mm256_and_si256(_mm256_slli_epi16(hbits, 3), mh));\n        bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));\n        bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh));\n        if (do_shift) {\n            hbits = _mm256_srli_epi16(hbits, 4);\n        }\n    }\n    const __m256i mh = _mm256_set1_epi8(0x10);\n    __m256i hbits;\n};\n\nstruct HighBit3 {\n    inline void load(const uint8_t * h) { hbits = _mm256_loadu_si256((const __m256i *)h); }\n    inline void apply(Q2Bits& bits, bool do_shift) {\n        bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));\n        bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh));\n        bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(hbits, mh));\n        bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(hbits, 1), mh));\n        if (do_shift) {\n            hbits = _mm256_srli_epi16(hbits, 4);\n        }\n    }\n    const __m256i mh = _mm256_set1_epi8(0x04);\n    __m256i hbits;\n};\n\n\n/*\ntemplate <typename Q8, typename Bits>\ninline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {\n    if (j == 0) {\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0)));\n            const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1)));\n            const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2)));\n            const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3)));\n            sumi[iy] = _mm256_add_epi32(_mm256_add_epi32(p1, p3), 
_mm256_add_epi32(p2, p4));\n        }\n    } else {\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4)));\n            const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5)));\n            const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6)));\n            const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7)));\n            sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p1, p3));\n            sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p2, p4));\n        }\n    }\n}*/\n\nstruct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {\n    DequantizerQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        return s8k.process_mins_and_scales(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);\n    }\n    inline void prepare(int i, int j) {\n        bits.prepare(x[i].qs, j);\n    }\n\n    Q4Bits bits;\n    Scales8K s8k;\n};\n\nstruct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {\n    DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_values()) {}\n    template <typename Q8>\n    inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        auto scales128 = siq4.make_scales(*(const uint32_t *)x[i].scales_l, x[i].scales_h);\n        s8k.accum_mins(scales128, q8, i, -128.f*d, accd);\n        return MM256_SET_M128I(scales128, scales128);\n    }\n    inline void prepare(int i, int j) {\n        bits.prepare16(x[i].qs, j);\n        bits.values[0] = _mm256_shuffle_epi8(values, bits.values[0]);\n        
bits.values[1] = _mm256_shuffle_epi8(values, bits.values[1]);\n        bits.values[2] = _mm256_shuffle_epi8(values, bits.values[2]);\n        bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]);\n    }\n\n    static __m256i load_values() {\n        static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241};\n        auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq4nl);\n        return MM256_SET_M128I(val128, val128);\n    }\n\n    Q4Bits bits;\n    Scales8K s8k;\n    ScaleIQ4XS siq4;\n    const __m256i values;\n};\n\nstruct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {\n    DequantizerQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        hbits.load(x[i].qh);\n        return s8k.process_mins_and_scales(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);\n    }\n    inline void prepare(int i, int j) {\n        bits.prepare(x[i].qs, j);\n        hbits.apply(bits, j == 0);\n    }\n\n    Q4Bits  bits;\n    HighBit5 hbits;\n    Scales8K s8k;\n};\n\ntemplate <typename Q8>\ninline void process_mins_and_scales_16(const __m128i& scales128, const Q8& q8, int i, float d,\n    __m256 * accm, __m256i * scales) {\n    const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);\n    process_mins_16(all_scales, q8, i, d, accm);\n    prepare_scales_16(all_scales, scales);\n}\n\nstruct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {\n    DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        hbits.load(x[i].hmask);\n        process_mins_and_scales_16(sc3.make_scales((const uint16_t *)x[i].scales), q8, i, -4.f*d, accm, scales);\n    }\n    inline void 
prepare(int i, int j) {\n        bits.prepare(x[i].qs, j);\n        hbits.apply(bits, j == 0);\n    }\n\n    Q2Bits  bits;\n    HighBit3 hbits;\n    ScaleQ3 sc3;\n\n    const __m128i m32 = _mm_set1_epi8(-32);\n};\n\nstruct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {\n    DequantizerQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);\n        const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);\n        const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);\n        process_mins_16(_mm256_cvtepi8_epi16(mins8), q8, i, -GGML_FP16_TO_FP32(x[i].dmin), accm);\n        prepare_scales_16(_mm256_cvtepi8_epi16(scales8), scales);\n    }\n    inline void prepare(int i, int j) {\n        bits.prepare(x[i].qs, j);\n    }\n\n    Q2Bits  bits;\n\n    const __m128i m4 = _mm_set1_epi8(0xf);\n};\n\nstruct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {\n    DequantizerQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        process_mins_and_scales_16(_mm_loadu_si128((const __m128i *)x[i].scales), q8, i, -32.f*d, accm, scales);\n    }\n    inline void prepare(int i, int j) {\n        bits.prepare64(x[i].ql, j);\n        auto hbits = _mm256_loadu_si256((const __m256i *)x[i].qh + j);\n        bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh));\n        bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));\n        bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(hbits, mh));\n        bits.values[3] = 
_mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(hbits, 2), mh));\n    }\n\n    Q4Bits  bits;\n    const __m256i mh = _mm256_set1_epi8(0x30);\n};\n\ninline __m256i get_scale_shuffle_8(int i);\n\ninline void set_scales_8(const __m256i& all_scales, int j, __m256i* scales);\n\ninline __m256i get_scale_shuffle_16(int i);\n\ninline void set_scales_16(const __m256i& all_scales, __m256i* scales);\n\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n%QK_K == 0);\n    const int nb = n/QK_K;\n\n    Q8<nrc_y> q8(info);\n\n    __m256i all_scales[2];\n    __m256i scales[4];\n    __m256  accd[nrc_y];\n\n    Dequantizer deq(vx, bx);\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        deq.new_row(ix);\n\n        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();\n\n        for (int i = 0; i < nb; ++i) {\n\n            deq.new_block(i, q8, accd, all_scales);\n\n            __m256i sumi[nrc_y];\n\n            for (int j = 0; j < QK_K/128; ++j) {\n                deq.prepare(i, j);\n                set_scales_16(all_scales[j], scales);\n                multiply_add(deq.bits, scales, j, i, q8, sumi);\n            }\n\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);\n            }\n\n        }\n\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            info.store(ix, iy, hsum_float_8(accd[iy]));\n        }\n\n    }\n\n}\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n    const int nb = n / QK_K;\n\n    Q8<nrc_y> q8(info);\n\n    Dequantizer deq(vx, bx);\n\n    __m256  accd[nrc_y];\n    __m256i scales[4];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        for (int iy = 0; iy < nrc_y; 
++iy) accd[iy] = _mm256_setzero_ps();\n\n        deq.new_row(ix);\n\n        for (int i = 0; i < nb; ++i) {\n\n            auto all_scales = deq.new_block(i, q8, accd);\n\n            __m256i sumi[nrc_y];\n\n            for (int j = 0; j < QK_K/128; ++j) {\n\n                deq.prepare(i, j);\n\n                set_scales_8(all_scales, j, scales);\n\n                multiply_add(deq.bits, scales, j, i, q8, sumi);\n\n            }\n\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i));\n                accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);\n            }\n\n        }\n\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            info.store(ix, iy, hsum_float_8(accd[iy]));\n        }\n\n    }\n}\n#endif  // Zen4 or vanilla AVX2\n\n\n\n//\n// ============================== Legacy quants\n//\n\nstruct DotHelper {\n    const __m256i m1 = _mm256_set1_epi16(1);\n#if defined(__AVX512VNNI__) && defined(__AVX512VL__)\n    inline __m256i dot(__m256i x, __m256i y) const {\n        return _mm256_dpbusd_epi32(_mm256_setzero_si256(), x, y);\n    }\n#else\n    inline __m256i dot(__m256i x, __m256i y) const {\n        return _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x, y));\n    }\n#endif\n};\n\nstruct SignedDot {\n    DotHelper helper;\n    inline __m256i compute(__m256i x, __m256i y) const {\n        return helper.dot(_mm256_sign_epi8(x, x), _mm256_sign_epi8(y, x));\n    }\n};\nstruct UnsignedDot {\n    DotHelper helper;\n    inline __m256i compute(__m256i x, __m256i y) const {\n        return helper.dot(x, y);\n    }\n};\ntemplate <typename Q8, typename Dot> struct Sum4 {\n    Dot dot;\n    inline __m256i compute(const __m256i * qx, const Q8 * y) const {\n        const __m256i p0 = dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[0].qs));\n        const __m256i p1 = dot.compute(qx[1], _mm256_loadu_si256((const __m256i *)y[1].qs));\n        const __m256i p2 = 
dot.compute(qx[2], _mm256_loadu_si256((const __m256i *)y[2].qs));\n        const __m256i p3 = dot.compute(qx[3], _mm256_loadu_si256((const __m256i *)y[3].qs));\n        const __m256i p01 = _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p0, p1));    // 0,0, 1,1, 0,0, 1,1\n        const __m256i p23 = _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p2, p3));    // 2,2, 3,3, 2,2, 3,3\n        return _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p01, p23)); // 0,1,2,3, 0,1,2,3\n    }\n};\n\nstruct Sum4_Q8 {\n    SignedDot dot;\n    static inline __m256i add1(__m256i a, __m256i b) {\n        return _mm256_add_epi32(_mm256_unpacklo_epi32(a, b), _mm256_unpackhi_epi32(a, b));\n    }\n    static inline __m256i add2(__m256i a, __m256i b) {\n        return _mm256_add_epi32(_mm256_unpacklo_epi64(a, b), _mm256_unpackhi_epi64(a, b));\n    }\n    inline __m256i compute(const __m256i * qx, const block_q8_0 * y) const {\n        const __m256i p0 = dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[0].qs));\n        const __m256i p1 = dot.compute(qx[1], _mm256_loadu_si256((const __m256i *)y[1].qs));\n        const __m256i p2 = dot.compute(qx[2], _mm256_loadu_si256((const __m256i *)y[2].qs));\n        const __m256i p3 = dot.compute(qx[3], _mm256_loadu_si256((const __m256i *)y[3].qs));\n        const __m256i p01 = add1(p0, p1);  // 0,1, 0,1, 0,1, 0,1\n        const __m256i p23 = add1(p2, p3);  // 2,3, 2,3, 2,3, 2,3\n        return add2(p01, p23); // returns 0,1,2,3, 0,1,2,3\n    }\n};\n\nstruct ScaleHelperQ_0 {\n    ggml_half scales8[4];\n    template <typename Q>\n    inline __m128 prepare4(const Q * y) {\n        for (int j = 0; j < 4; ++j) scales8[j] = y[j].d;\n        return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)scales8));\n    }\n    template <typename Q>\n    inline __m128 prepare4(__m128 other_scales, const Q * y) {\n        return _mm_mul_ps(other_scales, prepare4<Q>(y));\n    }\n    template <typename Q> inline float prepare1(const Q * y) const { 
return GGML_FP16_TO_FP32(y->d); }\n    template <typename Q> inline float prepare1(float d, const Q * y) const { return d*prepare1(y); }\n};\ntemplate <int min_value>\nstruct ScaleHelperQ_0_1 {\n    ggml_half scales8[4];\n    template <typename Q>\n    inline __m256 prepare4(const Q * y) {\n        for (int j = 0; j < 4; ++j) scales8[j] = y[j].d;\n        auto s4 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)scales8));\n        return _mm256_set_m128(_mm_mul_ps(s4, min), s4);\n    }\n    template <typename Q>\n    inline __m256 prepare4(__m256 other_scales, const Q * y) {\n        return _mm256_mul_ps(other_scales, prepare4<Q>(y));\n    }\n    template <typename Q> inline std::pair<float, float> prepare1(const Q * y) const {\n        float d = GGML_FP16_TO_FP32(y->d);\n        return std::make_pair(d, -d*float(min_value));\n    }\n    std::pair<float, float> inline prepare1(const std::pair<float, float>& dm, const block_q8_1 * y) const {\n        return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s));\n    }\n    const __m128 min = _mm_set1_ps(float(-min_value));\n};\n\nstruct ScaleHelperQ_1 {\n    uint32_t scales8[4];\n    const __m128i shuffle = _mm_set_epi16(0x0f0e, 0x0b0a, 0x0706, 0x0302, 0x0d0c, 0x0908, 0x0504, 0x0100);\n\n    template <typename Q>\n    inline __m256 prepare4(const Q * y) {\n        for (int j = 0; j < 4; ++j) {\n            // it is slightly faster to directly dereference (const uint32 *)&y[j].d, but some compilers\n            // complain that this breaks strict-aliasing rules.\n            memcpy(scales8 + j, &y[j].d, sizeof(uint32_t));\n        }\n        return _mm256_cvtph_ps(_mm_shuffle_epi8(_mm_loadu_si128((const __m128i *)scales8), shuffle));\n    }\n\n    template <typename Q>\n    inline __m256 prepare4(__m256 other_scales, const Q * y) {\n        return _mm256_mul_ps(other_scales, prepare4<Q>(y));\n    }\n\n    template <typename Q> inline std::pair<float, float> prepare1(const Q * y) const {\n 
       return std::make_pair(GGML_FP16_TO_FP32(y->d), GGML_FP16_TO_FP32(y->m));\n    }\n    template <typename Q> inline std::pair<float, float> prepare1(const std::pair<float, float>& dm, const Q * y) const {\n        return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->m));\n    }\n    std::pair<float, float> inline prepare1(const std::pair<float, float>& dm, const block_q8_1 * y) const {\n        return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s));\n    }\n};\n\nstruct MinusType0 {\n    inline __m256 compute(__m128 d, int) const { return _mm256_set_m128(d, d); }\n    inline float compute(float d, int) const { return d; }\n    inline float result(__m256 acc, int) const { return hsum_float_8(acc); }\n};\n\ntemplate <int nrc_y> struct MinusType1 {\n    __m128 accm[nrc_y];\n    MinusType1() { for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm_setzero_ps(); }\n    inline __m256 compute(__m256 dm, int iy) {\n        const __m128 d = _mm256_castps256_ps128(dm);\n        const __m128 m = _mm256_extractf128_ps(dm, 1);\n        accm[iy] = _mm_add_ps(accm[iy], m);\n        return _mm256_set_m128(d, d);\n    }\n    inline float compute(const std::pair<float, float>& dm, int iy) {\n        accm[iy] = _mm_add_ps(accm[iy], _mm_set1_ps(dm.second*0.25f));\n        return dm.first;\n    }\n    inline float result(__m256 acc, int iy) const {\n        const __m128 sum = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));\n        return hsum_float_4(_mm_add_ps(sum, accm[iy]));\n    }\n};\n\ntemplate <typename Minus, int nrc_y, bool is_multiple_of_4> struct AccumT {\n    __m256 acc[nrc_y];\n    Minus accm;\n    AccumT() {  for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = _mm256_setzero_ps(); }\n    template <typename Unpacker, typename Scales, typename Sum, typename Q8>\n    inline void compute(int nb, Unpacker& unp, Scales& scales, Sum& sum, const Q8 ** y, const DataInfo& info, int ix) {\n       
 auto qx = unp.quants();\n        __m256 dall[nrc_y];\n        for (int i = 0; i < nb/4; ++i) {\n            auto other_scales = unp.set_block_4(i);\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                auto s12 = scales.prepare4(other_scales, y[iy] + 4*i);\n                dall[iy] = accm.compute(s12, iy);\n            }\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                auto pall = sum.compute(qx, y[iy] + 4*i);\n                acc[iy] = _mm256_fmadd_ps(dall[iy], _mm256_cvtepi32_ps(pall), acc[iy]);\n            }\n        }\n        if (!is_multiple_of_4) {\n            for (int i = 4*(nb/4); i < nb; ++i) {\n                auto other_scales = unp.set_block(i);\n                for (int iy = 0; iy < nrc_y; ++iy) {\n                    auto s12 = scales.prepare1(other_scales, y[iy] + i);\n                    auto d = accm.compute(s12, iy);\n                    const __m256i p0 = sum.dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs));\n                    acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(p0), acc[iy]);\n                }\n            }\n        }\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            info.store(ix, iy, accm.result(acc[iy], iy));\n            //s[iy*bs] = accm.result(acc[iy], iy);\n        }\n    }\n};\n\ntemplate <int nrc_y, bool is_multiple_of_4>\nusing AccumType0 = AccumT<MinusType0, nrc_y, is_multiple_of_4>;\n\ntemplate <int nrc_y, bool is_multiple_of_4>\nusing AccumType1 = AccumT<MinusType1<nrc_y>, nrc_y, is_multiple_of_4>;\n\nusing Sum4Type0 = Sum4<block_q8_0, SignedDot>;\nusing Sum4Type1 = Sum4<block_q8_1, UnsignedDot>;\n\ntemplate <typename Unpacker, typename Sum4Type, typename AccumType, typename Scales, typename Q8, int nrc_y>\nvoid mul_mat_qX_q8_Helper(int nb, const void * vx, size_t bx, const DataInfo& info, const Q8 ** y, int nrc_x) {\n    Unpacker unp(vx, bx);\n    Sum4Type sum4;\n    Scales scales;\n    for (int ix = 0; ix < nrc_x; ++ix) {\n        
unp.set_row(ix);\n        AccumType accum;\n        accum.compute(nb, unp, scales, sum4, y, info, ix);\n    }\n}\n\ntemplate <typename Unpacker, int nrc_y>\nvoid mul_mat_qX_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n%Unpacker::block_size() == 0);\n    Q8<nrc_y, block_q8_0> q8(info);\n    int nb = n/Unpacker::block_size();\n    if (nb%4 == 0) {\n        mul_mat_qX_q8_Helper<Unpacker, Sum4Type0, AccumType0<nrc_y, true>, ScaleHelperQ_0, block_q8_0, nrc_y>(\n                nb, vx, bx, info, q8.y, nrc_x\n        );\n    } else {\n        mul_mat_qX_q8_Helper<Unpacker, Sum4Type0, AccumType0<nrc_y, false>, ScaleHelperQ_0, block_q8_0, nrc_y>(\n                nb, vx, bx, info, q8.y, nrc_x\n        );\n    }\n}\n\ntemplate <typename Unpacker, int nrc_y>\nvoid mul_mat_qX_1_q8_1_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n%Unpacker::block_size() == 0);\n    Q8<nrc_y, block_q8_1> q8(info);\n    int nb = n/Unpacker::block_size();\n    if (nb%4 == 0) {\n        mul_mat_qX_q8_Helper<Unpacker, Sum4Type1, AccumType1<nrc_y, true>, ScaleHelperQ_1, block_q8_1, nrc_y>(\n                nb, vx, bx, info, q8.y, nrc_x\n        );\n    } else {\n        mul_mat_qX_q8_Helper<Unpacker, Sum4Type1, AccumType1<nrc_y, false>, ScaleHelperQ_1, block_q8_1, nrc_y>(\n                nb, vx, bx, info, q8.y, nrc_x\n        );\n    }\n}\n\nstruct Dequantizer4bit {\n    const __m256i m4 = _mm256_set1_epi8(0xf);\n    inline __m256i dequant(const uint8_t * qs) const {\n        const __m128i aux128 = _mm_loadu_si128((const __m128i *)qs);\n        return _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128), m4);\n    }\n};\n\nstruct Q8_0_Dequantizer {\n    inline __m256i dequant(const block_q8_0 * x) const {\n        return _mm256_loadu_si256((const __m256i *)x->qs);\n    }\n};\n\nstruct Q8_0_1_Dequantizer {\n    inline __m256i dequant(const block_q8_0 * x) const {\n        return 
_mm256_add_epi8(_mm256_set1_epi8(127), _mm256_loadu_si256((const __m256i *)x->qs));\n    }\n};\n\nstruct Q4_0_Dequantizer {\n    Dequantizer4bit b4;\n    const __m256i m8 = _mm256_set1_epi8(-8);\n    inline __m256i dequant(const block_q4_0 * x) const {\n        return _mm256_add_epi8(b4.dequant(x->qs), m8);\n    }\n};\n\nstruct Q4_1_Dequantizer {\n    Dequantizer4bit b4;\n    inline __m256i dequant(const block_q4_1 * x) const {\n        return b4.dequant(x->qs);\n    }\n};\n\nstruct HBitDequantizer {\n    const __m256i shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);\n    const __m256i mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);\n    const __m256i minus1 = _mm256_set1_epi64x(-1);\n    inline __m256i to_bytes(const uint8_t * bits) const {\n        // Note: Data in all ggml quants is at least 2-byte aligned.\n        // => we can cast to uint16_t and use or on two consecutive entries\n        // which is faster than memcpy\n        const uint16_t * aux16 = (const uint16_t *)bits;\n        const uint32_t aux32 = aux16[0] | (aux16[1] << 16);\n        //uint32_t aux32; memcpy(&aux32, bits, sizeof(uint32_t));\n        __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(aux32), shuffle);\n        bytes = _mm256_or_si256(bytes, mask);\n        return _mm256_cmpeq_epi8(bytes, minus1);\n    }\n};\n\nstruct Q5_0_Dequantizer {\n    Dequantizer4bit b4;\n    HBitDequantizer hbit;\n    const __m256i mh = _mm256_set1_epi8((char)0xF0);\n    inline __m256i dequant(const block_q5_0 * x) const {\n        const __m256i vqh = _mm256_andnot_si256(hbit.to_bytes(x->qh), mh);\n        return _mm256_or_si256(b4.dequant(x->qs), vqh);\n    }\n};\n\nstruct Q5_1_Dequantizer {\n    Dequantizer4bit b4;\n    HBitDequantizer hbit;\n    const __m256i mh = _mm256_set1_epi8(0x10);\n    inline __m256i dequant(const block_q5_1 * x) const {\n        const __m256i vqh = _mm256_and_si256(hbit.to_bytes(x->qh), mh);\n        return 
_mm256_or_si256(b4.dequant(x->qs), vqh);\n    }\n};\n\ntemplate <typename Q, typename Scales, typename Dequantizer>\nstruct Q_Unpacker {\n    Q_Unpacker(const void * vx, size_t bx) : cx_0((const char *)vx), x((const Q*)cx_0), bx(bx) {}\n\n    const char * cx_0;\n    const Q    * x;\n    size_t       bx;\n\n    Scales scales;\n    Dequantizer deq;\n\n    __m256i qx[4];\n\n    inline const __m256i* quants() const { return qx; }\n\n    inline void set_row(int ix) { x = (const Q*)(cx_0 + ix*bx); }\n\n    inline auto set_block_4(int i) {\n        for (int j = 0; j < 4; ++j) {\n            qx[j] = deq.dequant(x + 4*i + j);\n        }\n        return scales.prepare4(x + 4*i);\n    }\n    inline auto set_block(int i) {\n        qx[0] = deq.dequant(x + i);\n        return scales.prepare1(x + i);\n    }\n};\n\nstruct Q8_0_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0, Q8_0_Dequantizer> {\n    Q8_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}\n    inline static int block_size() { return QK4_0; }\n};\nstruct Q8_0_1_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0_1<127>, Q8_0_1_Dequantizer> {\n    Q8_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}\n//    using Sum4T = Sum4TypeQ81;\n    inline static int block_size() { return QK8_0; }\n};\nstruct Q4_0_Unpacker final : public Q_Unpacker<block_q4_0, ScaleHelperQ_0, Q4_0_Dequantizer> {\n    Q4_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}\n    inline static int block_size() { return QK4_0; }\n};\nstruct Q5_0_Unpacker final : public Q_Unpacker<block_q5_0, ScaleHelperQ_0, Q5_0_Dequantizer> {\n    Q5_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}\n    inline static int block_size() { return QK5_0; }\n};\nstruct Q4_1_Unpacker final : public Q_Unpacker<block_q4_1, ScaleHelperQ_1, Q4_1_Dequantizer> {\n    Q4_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}\n    inline static int block_size() { return QK4_1; }\n};\nstruct 
Q5_1_Unpacker final : public Q_Unpacker<block_q5_1, ScaleHelperQ_1, Q5_1_Dequantizer> {\n    Q5_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}\n    inline static int block_size() { return QK4_1; }\n};\n\ntemplate <int nrc_y>\nvoid mul_mat_q8_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n%Q8_0_Unpacker::block_size() == 0);\n    Q8<nrc_y, block_q8_0> q8(info);\n    int nb = n/Q8_0_Unpacker::block_size();\n    if (nb%4 == 0) {\n        mul_mat_qX_q8_Helper<Q8_0_Unpacker, Sum4_Q8, AccumType0<nrc_y, true>, ScaleHelperQ_0, block_q8_0, nrc_y>(\n                nb, vx, bx, info, q8.y, nrc_x\n        );\n    } else {\n        mul_mat_qX_q8_Helper<Q8_0_Unpacker, Sum4_Q8, AccumType0<nrc_y, false>, ScaleHelperQ_0, block_q8_0, nrc_y>(\n                nb, vx, bx, info, q8.y, nrc_x\n        );\n    }\n}\n\n\n\n\n/*\nmoonll\nadd some structs for DequantizerIQ2XXS\nSimpleBits\nEvenSignHelper\n*/\nstruct SimpleBits {\n    __m256i values[4];\n};\n\n// fix for #829: add detection of AVX512VPOPCNTDQ\n#if defined(HAVE_FANCY_SIMD) && defined(__AVX512VPOPCNTDQ__)\n#define HAVE_AVX512_POPCNT 1\n#else\n#define HAVE_AVX512_POPCNT 0\n#endif\n\nstruct EvenSignHelper {\n    #if defined HAVE_FANCY_SIMD\n    // #pragma message(\"Using AVX512VPOPCNTDQ in even sign helper\")\n        union sbits_t {\n            __m128i vec;\n            __mmask32 mask[4];\n        };\n        IQK_ALWAYS_INLINE void sign_2_values(__m256i aux, __m256i * values) const {\n            aux = _mm256_and_si256(_mm256_srlv_epi32(aux, shifts), mask);\n            \n            // fix for #829: support Intel Cascade Lake CPUs; if the AVX512VPOPCNTDQ extension is unavailable, use a fallback implementation\n            #if HAVE_AVX512_POPCNT\n                auto pcnt = _mm256_popcnt_epi32(aux);\n                \n            #else\n                // fallback implementation using a standard bit-count method\n                __m256i pcnt;\n                int* pcnt_ptr = reinterpret_cast<int*>(&pcnt);\n                int* aux_ptr = reinterpret_cast<int*>(&aux); // 
take the address of aux directly to avoid an unnecessary copy\n                \n                #pragma unroll 8  // hint the compiler to unroll the loop to improve throughput\n                for (int i = 0; i < 8; i++) {\n                    pcnt_ptr[i] = __builtin_popcount(aux_ptr[i]); // use the compiler's built-in popcount\n                }\n            #endif\n            \n            sbits_t sbits;\n            sbits.vec = _mm256_cvtepi32_epi8(_mm256_or_si256(aux, _mm256_slli_epi32(_mm256_and_si256(pcnt, mone), 7)));\n            values[0] = _mm256_mask_sub_epi8(values[0], sbits.mask[0], _mm256_setzero_si256(), values[0]);\n            values[1] = _mm256_mask_sub_epi8(values[1], sbits.mask[1], _mm256_setzero_si256(), values[1]);\n            //auto sign_bits = _mm256_cvtepi32_epi8(_mm256_or_si256(aux, _mm256_slli_epi32(_mm256_and_si256(pcnt, mone), 7)));\n            //const __mmask32 * m32 = (const __mmask32 *)&sign_bits;\n            //values[0] = _mm256_mask_sub_epi8(values[0], m32[0], _mm256_setzero_si256(), values[0]);\n            //values[1] = _mm256_mask_sub_epi8(values[1], m32[1], _mm256_setzero_si256(), values[1]);\n        }\n        const __m256i shifts = _mm256_set_epi32(21, 14, 7, 0, 21, 14, 7, 0);\n        const __m256i mask   = _mm256_set1_epi32(127);\n        const __m256i mone   = _mm256_set1_epi32(1);\n    #else\n        inline void sign_value(uint32_t aux32, __m256i& value) const {\n            auto signs = _mm256_set_epi64x(keven_signs[(aux32 >> 21) & 127], keven_signs[(aux32 >> 14) & 127],\n                                           keven_signs[(aux32 >>  7) & 127], keven_signs[(aux32 >>  0) & 127]);\n            value = _mm256_sign_epi8(value, signs);\n        }\n    #endif\n};\n\n/*\nmoonll add multiply_add for mul_mat_qX_K_q8_K_IQ_1\nadd func\nget_scale_shuffle_8\nget_scale_shuffle_16\nset_scales_16\n*/\n\ninline __m256i get_scale_shuffle_8(int i) {\n    return _mm256_set1_epi16((2*i) | ((2*i+1) << 8));\n}\n\ninline void set_scales_8(const __m256i& all_scales, int j, __m256i * scales) {\n    scales[0] = 
_mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+0));\n    scales[1] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+1));\n    scales[2] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+2));\n    scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+3));\n}\n\n\ninline __m256i get_scale_shuffle_16(int i) {\n    static const uint8_t k_shuffle[128] = {\n         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,     2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,\n         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,     6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,\n         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,    10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,\n        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,    14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,\n    };\n    return _mm256_loadu_si256((const __m256i*)k_shuffle + i);\n}\n\ninline void set_scales_16(const __m256i& all_scales, __m256i * scales) {\n    scales[0] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(0));\n    scales[1] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(1));\n    scales[2] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(2));\n    scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(3));\n}\n\n\ntemplate <typename Q8, typename Bits>\ninline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {\n    if (j == 0) {\n#ifdef HAVE_FANCY_SIMD\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            sumi[iy] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0)));\n            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1)));\n            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2)));\n            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], 
scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3)));\n        }\n#else\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0)));\n            const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1)));\n            const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2)));\n            const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3)));\n            sumi[iy] = _mm256_add_epi32(_mm256_add_epi32(p1, p3), _mm256_add_epi32(p2, p4));\n        }\n#endif\n    } else {\n#ifdef HAVE_FANCY_SIMD\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4)));\n            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5)));\n            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6)));\n            sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7)));\n        }\n#else\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4)));\n            const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5)));\n            const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6)));\n            const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7)));\n            sumi[iy] = _mm256_add_epi32(sumi[iy], 
_mm256_add_epi32(p1, p3));\n            sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p2, p4));\n        }\n#endif\n    }\n}\n\n/*\nmoonll add multiply_add_1 for mul_mat_qX_K_q8_K_IQ_1\nadd func\nset_scales_8_iq\nset_scales_16_iq\n\nadd MUL_MAT\nmul_mat_qX_K_q8_K_IQ_1\nmul_mat_qX_K_q8_K_IQ_N\nmul_mat_qX_K_q8_K_IQ\n*/\n\ntemplate <typename Bits>\ninline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, const __m256i * q8, __m256i * sumi) {\n    if (j == 0) {\n#ifdef HAVE_FANCY_SIMD\n        auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]);\n        auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]);\n        auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]);\n        auto p4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[3], q8[3]);\n        sumi[0] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_packs_epi32(p1, p2));\n        sumi[1] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[1], _mm256_packs_epi32(p3, p4));\n#else\n        const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8[0]));\n        const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8[1]));\n        const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8[2]));\n        const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8[3]));\n        sumi[0] = _mm256_add_epi32(p1, p3);\n        sumi[1] = _mm256_add_epi32(p2, p4);\n#endif\n    } else {\n#ifdef HAVE_FANCY_SIMD\n        auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]);\n        auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]);\n        auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]);\n        auto p4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[3], q8[3]);\n        sumi[0] = 
_mm256_dpwssd_epi32(sumi[0], scales[0], _mm256_packs_epi32(p1, p2));\n        sumi[1] = _mm256_dpwssd_epi32(sumi[1], scales[1], _mm256_packs_epi32(p3, p4));\n#else\n        const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8[0]));\n        const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8[1]));\n        const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8[2]));\n        const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8[3]));\n        sumi[0] = _mm256_add_epi32(sumi[0], _mm256_add_epi32(p1, p3));\n        sumi[1] = _mm256_add_epi32(sumi[1], _mm256_add_epi32(p2, p4));\n#endif\n    }\n}\n\n\ninline void set_scales_8_iq(int j, const __m256i& all_scales, __m256i * scales) {\n    //#ifdef HAVE_FANCY_SIMD\n        auto shuffle = j == 0 ? _mm256_set_epi64x(0x0302030203020302, 0x0100010001000100, 0x0302030203020302, 0x0100010001000100)\n                              : _mm256_set_epi64x(0x0b0a0b0a0b0a0b0a, 0x0908090809080908, 0x0b0a0b0a0b0a0b0a, 0x0908090809080908);\n        scales[0] = _mm256_shuffle_epi8(all_scales, shuffle);\n        scales[1] = _mm256_shuffle_epi8(all_scales, _mm256_add_epi8(shuffle, _mm256_set1_epi8(4)));\n    //#else\n    //    set_scales_8(all_scales, j, scales);\n    //#endif\n    }\n    \ninline void set_scales_16_iq(const __m256i& all_scales, __m256i * scales) {\n    #ifdef HAVE_FANCY_SIMD\n        auto shuffle = _mm256_set_epi64x(0x0706070607060706, 0x0302030203020302, 0x0504050405040504, 0x0100010001000100);\n        scales[0] = _mm256_shuffle_epi8(all_scales, shuffle);\n        scales[1] = _mm256_shuffle_epi8(all_scales, _mm256_add_epi8(shuffle, _mm256_set1_epi8(8)));\n    #else\n        set_scales_16(all_scales, scales);\n    #endif\n    }\n    \ntemplate <typename Dequantizer>\nstatic void mul_mat_qX_K_q8_K_IQ_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n        const int 
nb = n / QK_K;\n        Q8<1> q8(info);\n        Dequantizer deq(vx, bx);\n        __m256i scales[2];\n        __m256i q8_quants[4];\n        for (int ix = 0; ix < nrc_x; ++ix) {\n    \n            __m256 accd = _mm256_setzero_ps();\n            deq.new_row(ix);\n    \n            for (int i = 0; i < nb; ++i) {\n    \n                __m256i sumi[2], all_scales[Dequantizer::num_blocks/8];\n                deq.new_block(i, all_scales);\n    \n                for (int j = 0; j < QK_K/128; ++j) {\n                    deq.prepare(i, j, q8, q8_quants);\n                    if constexpr (Dequantizer::num_blocks == 8) {\n                        set_scales_8_iq(j, all_scales[0], scales);\n                    } else {\n                        set_scales_16_iq(all_scales[j], scales);\n                    }\n                    multiply_add_1(j, deq.bits, scales, q8_quants, sumi);\n                }\n                accd = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(0, i)), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi[0], sumi[1])), accd);\n            }\n    \n            info.store(ix, 0, hsum_float_8(accd));\n        }\n    }\n\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    const int nb = n / QK_K;\n    Q8<nrc_y> q8(info);\n    Dequantizer deq(vx, bx);\n    __m256i scales[4];\n    __m256  accd[nrc_y];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();\n\n        deq.new_row(ix);\n\n        for (int i = 0; i < nb; ++i) {\n\n            __m256i sumi[nrc_y], all_scales[Dequantizer::num_blocks/8];\n            //for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = _mm256_setzero_si256();\n            __m256i mins;\n            float dmin = deq.new_block(i, all_scales, mins);\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                auto bsums = q8.load_bsums(iy, i);\n                auto prod  = 
_mm256_madd_epi16(mins, bsums);\n                accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(dmin*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]);\n            }\n\n            for (int j = 0; j < QK_K/128; ++j) {\n                deq.prepare(i, j);\n                if constexpr (Dequantizer::num_blocks == 8) {\n                    set_scales_8(all_scales[0], j, scales);\n                } else {\n                    set_scales_16(all_scales[j], scales);\n                }\n                //multiply_add_iq(deq.bits, scales, j, i, q8, sumi);\n                multiply_add(deq.bits, scales, j, i, q8, sumi);\n            }\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i));\n                accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);\n            }\n        }\n\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            info.store(ix, iy, hsum_float_8(accd[iy]));\n        }\n    }\n}\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void mul_mat_qX_K_q8_K_IQ(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n#ifdef HAVE_FANCY_SIMD\n    if constexpr (nrc_y == 1) {\n        mul_mat_qX_K_q8_K_IQ_1<Dequantizer>(n, vx, bx, info, nrc_x);\n    } else {\n        mul_mat_qX_K_q8_K_IQ_N<Dequantizer, nrc_y>(n, vx, bx, info, nrc_x);\n    }\n#else\n    mul_mat_qX_K_q8_K_IQ_N<Dequantizer, nrc_y>(n, vx, bx, info, nrc_x);\n#endif\n}\n\n/*\nmoonll iq1s\ncore func for iq1s mul_mat_iq1_s_q8_K\n\n*/\n\ntemplate <int nrc_y>\nstatic void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    GGML_ASSERT(n%QK_K == 0);\n    Q8<nrc_y, block_q8_K> q8(info);\n    __m256i qx[8];\n    __m256i scales[4];\n    __m256  acc[nrc_y] = {};\n    auto delta_mask = _mm_set1_epi16(-32768); // to avoid stupid overflow warnings when using 0x8000\n    __m256i shuffle0 = _mm256_set_epi64x(0x0302030203020302, 
0x0100010001000100, 0x0302030203020302, 0x0100010001000100);\n    for (int ix = 0; ix < nrc_x; ++ix) {\n        auto iq1s = (const block_iq1_s *)((const char *)vx + ix*bx);\n        for (int ibl = 0; ibl < n/QK_K; ++ibl) {\n            float d = GGML_FP16_TO_FP32(iq1s[ibl].d);\n            auto qhb = _mm_loadu_si128((const __m128i *)iq1s[ibl].qh);\n            auto scales128 = _mm_and_si128(_mm_srli_epi16(qhb, 12), _mm_set1_epi16(7));\n            scales128 = _mm_add_epi16(_mm_slli_epi16(scales128, 1), _mm_set1_epi16(1));\n#ifdef HAVE_FANCY_SIMD\n            auto mask = _mm_cmpeq_epi16_mask(_mm_and_si128(qhb, delta_mask), delta_mask);\n            auto deltas128 = _mm_mask_blend_epi16(mask, _mm_set1_epi16(-7), _mm_set1_epi16(-9));\n#else\n            auto mask = _mm_cmpeq_epi16(_mm_and_si128(qhb, delta_mask), delta_mask);\n            auto deltas128 = _mm_or_si128(_mm_and_si128(mask, _mm_set1_epi16(-9)), _mm_andnot_si128(mask, _mm_set1_epi16(-7)));\n#endif\n            deltas128 = _mm_mullo_epi16(scales128, deltas128);\n            scales128 = _mm_slli_epi16(scales128, 3);\n            auto deltas_l = _mm_unpacklo_epi16(deltas128, deltas128);\n            auto deltas_h = _mm_unpackhi_epi16(deltas128, deltas128);\n            auto deltas = MM256_SET_M128I(deltas_h, deltas_l); // blocks 0,0, 1,1, 2,2, ..., 7,7\n            auto all_scales = MM256_SET_M128I(scales128, scales128);\n            auto shuffle = shuffle0;\n            for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {\n                scales[ib64] = _mm256_shuffle_epi8(all_scales, shuffle);\n                shuffle = _mm256_add_epi8(shuffle, _mm256_set1_epi8(4));\n            }\n            const uint8_t  * qs = iq1s[ibl].qs;\n            const uint16_t * qh = iq1s[ibl].qh;\n            for (int ib = 0; ib < QK_K/32; ib += 2) {\n                qx[ib+0] = _mm256_set_epi64x(iq1s_grid_us[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid_us[qs[2] | ((qh[ib+0] << 2) & 0x700)],\n                                           
  iq1s_grid_us[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid_us[qs[0] | ((qh[ib+0] << 8) & 0x700)]);\n                qx[ib+1] = _mm256_set_epi64x(iq1s_grid_us[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid_us[qs[6] | ((qh[ib+1] << 2) & 0x700)],\n                                             iq1s_grid_us[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid_us[qs[4] | ((qh[ib+1] << 8) & 0x700)]);\n                qs += 8;\n            }\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                auto bsums = q8.load_bsums(iy, ibl);\n                auto sumi = _mm256_setzero_si256();\n                for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {\n                    auto qy1 = q8.load_quants(iy, ibl, 2*ib64+0);\n                    auto qy2 = q8.load_quants(iy, ibl, 2*ib64+1);\n#ifdef HAVE_FANCY_SIMD\n                    auto dot1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2*ib64+0], qy1);\n                    auto dot2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[2*ib64+1], qy2);\n                    sumi = _mm256_dpwssd_epi32(sumi, scales[ib64], _mm256_packs_epi32(dot1, dot2));\n#else\n                    auto dot1 = _mm256_maddubs_epi16(qx[2*ib64+0], qy1);\n                    auto dot2 = _mm256_maddubs_epi16(qx[2*ib64+1], qy2);\n                    auto dot  = _mm256_add_epi16(_mm256_unpacklo_epi64(dot1, dot2), _mm256_unpackhi_epi64(dot1, dot2));\n                    sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(scales[ib64], dot));\n#endif\n                }\n#ifdef HAVE_FANCY_SIMD\n                sumi = _mm256_dpwssd_epi32(sumi, bsums, deltas);\n#else\n                sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(bsums, deltas));\n#endif\n                acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d*q8.scale(iy, ibl)), _mm256_cvtepi32_ps(sumi), acc[iy]);\n            }\n        }\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            info.store(ix, iy, 0.125f*hsum_float_8(acc[iy]));\n            acc[iy] = _mm256_setzero_ps();\n        }\n    
}\n}\n\n/*\nmoonll iq1s\nDequantizerIQ2XXS\nDequantizerIQ2XXS is important Dequantizer for DequantizerIQ1_S\n*/\n\nstruct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {\n    DequantizerIQ2XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n\n    constexpr static int num_blocks = 8;\n\n    union Data {\n        __m256i vec;\n        uint32_t val[8];\n    };\n\n    inline __m128i load_scales(int i) {\n        d = 0.125f * GGML_FP16_TO_FP32(x[i].d);\n        const uint16_t * a16 = (const uint16_t *)x[i].qs;\n        auto scales = _mm_srli_epi16(_mm_set_epi16(a16[31], a16[27], a16[23], a16[19], a16[15], a16[11], a16[7], a16[3]), 12);\n        return _mm_or_si128(_mm_slli_epi16(scales, 1), _mm_set1_epi16(1));\n    }\n\n    inline void new_block(int i, __m256i * scales) {\n        auto sc16 = load_scales(i);\n        scales[0] = MM256_SET_M128I(sc16, sc16);\n    }\n    inline float new_block(int i, __m256i * scales, __m256i& mins) {\n        auto sc16 = load_scales(i);\n        mins = scb.shuffle(sc16);\n        scales[0] = MM256_SET_M128I(sc16, sc16);\n        return -d*minv;\n    }\n\n    inline static void make4(const uint32_t * aux32, __m256i * values) {\n        const uint8_t * aux8 = (const uint8_t *)aux32;\n        values[0] = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[ 1]], iq2xxs_grid[aux8[ 0]]);\n        values[1] = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[ 9]], iq2xxs_grid[aux8[ 8]]);\n        values[2] = _mm256_set_epi64x(iq2xxs_grid[aux8[19]], iq2xxs_grid[aux8[18]], iq2xxs_grid[aux8[17]], iq2xxs_grid[aux8[16]]);\n        values[3] = _mm256_set_epi64x(iq2xxs_grid[aux8[27]], iq2xxs_grid[aux8[26]], iq2xxs_grid[aux8[25]], iq2xxs_grid[aux8[24]]);\n    }\n\n    IQK_ALWAYS_INLINE void sign_values(const uint32_t * aux32, __m256i * values) const {\n#ifdef HAVE_FANCY_SIMD\n        esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[3]), 
_mm_set1_epi32(aux32[1])), values+0);\n        esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[7]), _mm_set1_epi32(aux32[5])), values+2);\n#else\n        esh.sign_value(aux32[1], values[0]);\n        esh.sign_value(aux32[3], values[1]);\n        esh.sign_value(aux32[5], values[2]);\n        esh.sign_value(aux32[7], values[3]);\n#endif\n    }\n    inline void make4_signed(const uint32_t * aux32, const __m256i& min_value, __m256i * values) const {\n        make4(aux32, values);\n        sign_values(aux32, values);\n        for (int k = 0; k < 4; ++k) values[k] = _mm256_add_epi8(values[k], min_value);\n    }\n    inline void make4(const uint32_t * aux32, __m256i * values, __m256i * q8) const {\n        make4(aux32, values);\n        sign_values(aux32, q8);\n    }\n    inline void prepare(int i, int j) {\n        Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j);\n        make4_signed(data.val, min_value, bits.values);\n    }\n    inline void prepare(int i, int j, const Q8<1>& q8, __m256i * q8_quants) {\n        for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);\n        Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j);\n        make4(data.val, bits.values, q8_quants);\n    }\n\n    constexpr static int minv = 43;\n    SimpleBits bits;\n    Scales8KBase scb;\n    EvenSignHelper esh;\n    const __m256i min_value = _mm256_set1_epi8(minv);\n    const __m256i shuffle = _mm256_set_epi32(7, 5, 3, 1, 7, 5, 3, 1);\n};\n\n/*\nmoonll\nadd Q8_0_Unpacker && DequantizerIQ2XXS support\nadd func mul_mat_qX_K_q8_K_IQ\n*/\n\ntemplate <typename Dequantizer> void MulMat::set_functions(MulMat& m) {\n    if constexpr (std::is_same_v<Dequantizer, Q4_0_Unpacker> || std::is_same_v<Dequantizer, Q5_0_Unpacker> ||\n        std::is_same_v<Dequantizer, Q8_0_Unpacker>) {\n            m.funcs[0] = mul_mat_qX_0_q8_0_T<Dequantizer, 1>;\n            m.funcs[1] = mul_mat_qX_0_q8_0_T<Dequantizer, 2>;\n            m.funcs[2] = 
mul_mat_qX_0_q8_0_T<Dequantizer, 3>;\n            m.funcs[3] = mul_mat_qX_0_q8_0_T<Dequantizer, 4>;\n            m.funcs[4] = mul_mat_qX_0_q8_0_T<Dequantizer, 5>;\n            m.funcs[5] = mul_mat_qX_0_q8_0_T<Dequantizer, 6>;\n            m.funcs[6] = mul_mat_qX_0_q8_0_T<Dequantizer, 7>;\n            m.funcs[7] = mul_mat_qX_0_q8_0_T<Dequantizer, 8>;\n        }\n        else if constexpr (std::is_same_v<Dequantizer, Q4_1_Unpacker> || std::is_same_v<Dequantizer, Q5_1_Unpacker>|| std::is_same_v<Dequantizer, Q8_0_1_Unpacker>) {\n            m.funcs[0] = mul_mat_qX_1_q8_1_T<Dequantizer, 1>;\n            m.funcs[1] = mul_mat_qX_1_q8_1_T<Dequantizer, 2>;\n            m.funcs[2] = mul_mat_qX_1_q8_1_T<Dequantizer, 3>;\n            m.funcs[3] = mul_mat_qX_1_q8_1_T<Dequantizer, 4>;\n            m.funcs[4] = mul_mat_qX_1_q8_1_T<Dequantizer, 5>;\n            m.funcs[5] = mul_mat_qX_1_q8_1_T<Dequantizer, 6>;\n            m.funcs[6] = mul_mat_qX_1_q8_1_T<Dequantizer, 7>;\n            m.funcs[7] = mul_mat_qX_1_q8_1_T<Dequantizer, 8>;\n        }\n        else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2XXS>) {\n            m.funcs[0] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 1>;\n            m.funcs[1] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 2>;\n            m.funcs[2] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 3>;\n            m.funcs[3] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 4>;\n            m.funcs[4] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 5>;\n            m.funcs[5] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 6>;\n            m.funcs[6] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 7>;\n            m.funcs[7] = mul_mat_qX_K_q8_K_IQ<Dequantizer, 8>;\n            }\n            else {\n#ifdef HAVE_FANCY_SIMD\n            if constexpr (std::is_same_v<Dequantizer, DequantizerIQ4XS>) {\n            m.funcs[0] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 1>;\n            m.funcs[1] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 2>;\n            m.funcs[2] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 3>;\n            
m.funcs[3] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 4>;\n            m.funcs[4] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 5>;\n            m.funcs[5] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 6>;\n            m.funcs[6] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 7>;\n            m.funcs[7] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 8>;\n            } else {\n            m.funcs[0] = mul_mat_qX_K_q8_K_AVX512_1<Dequantizer>;\n            m.funcs[1] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 2>;\n            m.funcs[2] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 3>;\n            m.funcs[3] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 4>;\n            m.funcs[4] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 5>;\n            m.funcs[5] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 6>;\n            m.funcs[6] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 7>;\n            m.funcs[7] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 8>;\n            }\n#else\n            if constexpr (std::is_same_v<Dequantizer, DequantizerQ2K> ||\n                          std::is_same_v<Dequantizer, DequantizerQ3K> ||\n                          std::is_same_v<Dequantizer, DequantizerQ6K>) {\n                m.funcs[0] = mul_mat_qY_K_q8_K_T<Dequantizer, 1>;\n                m.funcs[1] = mul_mat_qY_K_q8_K_T<Dequantizer, 2>;\n                m.funcs[2] = mul_mat_qY_K_q8_K_T<Dequantizer, 3>;\n                m.funcs[3] = mul_mat_qY_K_q8_K_T<Dequantizer, 4>;\n                m.funcs[4] = mul_mat_qY_K_q8_K_T<Dequantizer, 5>;\n                m.funcs[5] = mul_mat_qY_K_q8_K_T<Dequantizer, 6>;\n                m.funcs[6] = mul_mat_qY_K_q8_K_T<Dequantizer, 7>;\n                m.funcs[7] = mul_mat_qY_K_q8_K_T<Dequantizer, 8>;\n            } else {\n                m.funcs[0] = mul_mat_qX_K_q8_K_T<Dequantizer, 1>;\n                m.funcs[1] = mul_mat_qX_K_q8_K_T<Dequantizer, 2>;\n                m.funcs[2] = mul_mat_qX_K_q8_K_T<Dequantizer, 3>;\n                m.funcs[3] = mul_mat_qX_K_q8_K_T<Dequantizer, 4>;\n                m.funcs[4] = 
mul_mat_qX_K_q8_K_T<Dequantizer, 5>;\n                m.funcs[5] = mul_mat_qX_K_q8_K_T<Dequantizer, 6>;\n                m.funcs[6] = mul_mat_qX_K_q8_K_T<Dequantizer, 7>;\n                m.funcs[7] = mul_mat_qX_K_q8_K_T<Dequantizer, 8>;\n            }\n#endif\n        }\n}\n\nstruct QFBase {\n    #ifdef __AVX512F__\n        constexpr static int k_step = 16;\n        using Data = __m512;\n        using Acc  = __m512;\n        static inline Data load(const ggml_half * x) { return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)x)); }\n        static inline Data load(const float * x) { return _mm512_loadu_ps(x); }\n        static inline Data load(const ggml_bf16_t * x) {\n            return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i*)x)), 16));\n        }\n        static inline Acc acc(Acc prev, const Data& y, const Data& x) {\n            return _mm512_fmadd_ps(y, x, prev);\n        }\n        static inline Acc acc_first(const Data& y, const Data& x) {\n            return _mm512_mul_ps(y, x);\n        }\n        static inline Acc add(Acc x, Acc y) { return _mm512_add_ps(x, y); }\n        static inline float hsum(Acc acc) {\n            return _mm512_reduce_add_ps(acc);\n        }\n        template <typename Float>\n        static inline Data load4Floats(const Float * x) {\n            return _mm512_insertf32x4(_mm512_setzero_ps(), load128(x), 0);\n        }\n        static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) {\n            acc = _mm512_fmadd_ps(xv[0], _mm512_shuffle_ps(yv, yv, 0x00), acc);\n            acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc);\n            acc = _mm512_fmadd_ps(xv[2], _mm512_shuffle_ps(yv, yv, 0xaa), acc);\n            acc = _mm512_fmadd_ps(xv[3], _mm512_shuffle_ps(yv, yv, 0xff), acc);\n            return acc;\n        }\n        static inline Acc acc_r4_first(const Data * xv, const Data& yv) {\n            auto acc = _mm512_mul_ps(xv[0], 
_mm512_shuffle_ps(yv, yv, 0x00));\n            acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc);\n            acc = _mm512_fmadd_ps(xv[2], _mm512_shuffle_ps(yv, yv, 0xaa), acc);\n            acc = _mm512_fmadd_ps(xv[3], _mm512_shuffle_ps(yv, yv, 0xff), acc);\n            return acc;\n        }\n        static inline __m128 hsum_r4(Acc acc) {\n            auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc, 0), _mm512_extractf32x4_ps(acc, 1));\n            auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc, 2), _mm512_extractf32x4_ps(acc, 3));\n            return _mm_add_ps(sum1, sum2);\n        }\n    #else\n        constexpr static int k_step = 8;\n        using Data = __m256;\n        using Acc  = __m256;\n        static inline Data load(const ggml_half * x) { return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)x)); }\n        static inline Data load(const float * x) { return _mm256_loadu_ps(x); }\n        static inline Data load(const ggml_bf16_t * x) {\n            return _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i*)x)), 16));\n        }\n        static inline Acc acc(Acc prev, const Data& y, const Data& x) {\n            return _mm256_fmadd_ps(y, x, prev);\n        }\n        static inline Acc add(Acc x, Acc y) { return _mm256_add_ps(x, y); }\n        static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) {\n            acc = _mm256_fmadd_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00), acc);\n            acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc);\n            acc = _mm256_fmadd_ps(xv[2], _mm256_shuffle_ps(yv, yv, 0xaa), acc);\n            acc = _mm256_fmadd_ps(xv[3], _mm256_shuffle_ps(yv, yv, 0xff), acc);\n            return acc;\n        }\n        static inline Acc acc_r4_first(const Data * xv, const Data& yv) {\n            auto acc = _mm256_mul_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00));\n            acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 
0x55), acc);\n            acc = _mm256_fmadd_ps(xv[2], _mm256_shuffle_ps(yv, yv, 0xaa), acc);\n            acc = _mm256_fmadd_ps(xv[3], _mm256_shuffle_ps(yv, yv, 0xff), acc);\n            return acc;\n        }\n        static inline Acc acc_first(const Data& y, const Data& x) {\n            return _mm256_mul_ps(y, x);\n        }\n        static inline float hsum(Acc acc) {\n            return hsum_float_8(acc);\n        }\n        static inline __m128 hsum_r4(Acc acc) {\n            return _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));\n        }\n        template <typename Float>\n        static inline Data load4Floats(const Float * x) {\n            return _mm256_insertf128_ps(_mm256_setzero_ps(), load128(x), 0);\n        }\n    #endif\n        static inline __m128 load128(const ggml_half * x) { return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)x)); }\n        static inline __m128 load128(const float * x) { return _mm_loadu_ps(x); }\n        static inline __m128 load128(const ggml_bf16_t * x) {\n            return _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i*)x)), 16));\n        }\n    };\n    template <typename Float, int nrc_in> struct QFT final : public QFBase {\n        constexpr static int nrc = nrc_in;\n        QFT(const DataInfo& info) {\n            for (int iy = 0; iy < nrc; ++iy) y[iy] = (const Float *)info.src1_row(iy);\n        }\n        QFT(const char * cx, size_t bx) {\n            for (int iy = 0; iy < nrc; ++iy) y[iy] = (const Float *)(cx + iy*bx);\n        }\n        IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }\n        IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load4Floats(y[iy] + 4*i); }\n        IQK_ALWAYS_INLINE void load_r4(int ix, int i, Data * xv) const {\n            xv[0] = load1(ix+0, i);\n            xv[1] = load1(ix+1, i);\n            xv[2] = load1(ix+2, i);\n            xv[3] = load1(ix+3, i);\n    #ifdef 
__AVX512F__\n            auto t0 = _mm512_unpacklo_ps(xv[0], xv[1]);\n            auto t1 = _mm512_unpacklo_ps(xv[2], xv[3]);\n            auto t2 = _mm512_unpackhi_ps(xv[0], xv[1]);\n            auto t3 = _mm512_unpackhi_ps(xv[2], xv[3]);\n            xv[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(t0), _mm512_castps_pd(t1)));\n            xv[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(t0), _mm512_castps_pd(t1)));\n            xv[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(t2), _mm512_castps_pd(t3)));\n            xv[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(t2), _mm512_castps_pd(t3)));\n    #else\n            auto t0 = _mm256_unpacklo_ps(xv[0], xv[1]);\n            auto t1 = _mm256_unpacklo_ps(xv[2], xv[3]);\n            auto t2 = _mm256_unpackhi_ps(xv[0], xv[1]);\n            auto t3 = _mm256_unpackhi_ps(xv[2], xv[3]);\n            xv[0] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));\n            xv[1] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));\n            xv[2] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));\n            xv[3] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));\n    #endif\n        }\n        const Float * y[nrc];\n    };\n    \n\n\ntemplate <typename Qy, typename Qx>\nIQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {\n    int nb = n/QFBase::k_step;\n    int nb4 = n/4;\n    Qy y(info);\n    Qx x(cx + ix0*bx, bx);\n    QFBase::Data xv[Qx::nrc];\n    QFBase::Acc  acc[Qx::nrc*Qy::nrc];\n    auto yv = y.load1(0, 0);\n    for (int ix = 0; ix < Qx::nrc; ++ix) {\n        xv[ix] = x.load1(ix, 0);\n        acc[ix] = QFBase::acc_first(yv, xv[ix]);\n    }\n    for (int iy = 1; iy < Qy::nrc; ++iy) {\n        yv = y.load1(iy, 0);\n        for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = 
QFBase::acc_first(yv, xv[ix]);\n    }\n    for (int i = 1; i < nb; ++i) {\n        yv = y.load1(0, i);\n        for (int ix = 0; ix < Qx::nrc; ++ix) {\n            xv[ix] = x.load1(ix, i);\n            acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]);\n        }\n        for (int iy = 1; iy < Qy::nrc; ++iy) {\n            yv = y.load1(iy, i);\n            for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]);\n        }\n    }\n    for (int i = (QFBase::k_step/4)*nb; i < nb4; ++i) {\n        yv = y.load_tail(0, i);\n        for (int ix = 0; ix < Qx::nrc; ++ix) {\n            xv[ix] = x.load_tail(ix, i);\n            acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]);\n        }\n        for (int iy = 1; iy < Qy::nrc; ++iy) {\n            yv = y.load_tail(iy, i);\n            for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]);\n        }\n    }\n    for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix]));\n}\n// This will handle any of f16 x f32, f32 x f16, f16 x f16, f32 x f32, with computations done\n// in f32 (i.e., f16 is first converted to f32). 
It is easy to extend to computations done in\n// f16, but I don't have a CPU capable of f16 vector arithmetic, so not doing it for now.\ntemplate <int nrc_y, typename FloatX, typename FloatY>\nvoid mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    const char * cx = (const char *)vx;\n    // TBD if we want this\n    //if constexpr (nrc_y == 1) {\n    //    constexpr int k_nx = 2;\n    //    for (int ix = 0; ix < nrc_x/k_nx; ++ix) {\n    //        mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, k_nx>>(n, cx, bx, ix*k_nx, info);\n    //    }\n    //    if (int lastx = k_nx*(nrc_x/k_nx); lastx < nrc_x) {\n    //        int nx = nrc_x - lastx;\n    //        switch (nx) {\n    //            case 1: mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, lastx, info); break;\n    //            case 2: mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, lastx, info); break;\n    //            case 3: mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, lastx, info); break;\n    //        }\n    //        //mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, lastx, info);\n    //    }\n    //    return;\n    //}\n#ifdef __AVX512F__\n    constexpr int k_nx = 5;\n#else\n    constexpr int k_nx = nrc_y == 1 ? 
4 : 2;\n#endif\n    for (int ix = 0; ix < nrc_x/k_nx; ++ix) {\n        mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, k_nx>>(n, cx, bx, ix*k_nx, info);\n    }\n    int last_x = k_nx*(nrc_x/k_nx);\n    if (last_x == nrc_x) return;\n    int nx = nrc_x - last_x;\n#ifdef __AVX512F__\n    switch (nx) {\n        case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;\n        case 2: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, last_x, info); break;\n        case 3: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, last_x, info); break;\n        case 4: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 4>>(n, cx, bx, last_x, info); break;\n    }\n#else\n    if constexpr (nrc_y == 1) {\n        switch (nx) {\n            case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;\n            case 2: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, last_x, info); break;\n            case 3: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, last_x, info); break;\n        }\n    } else {\n        switch (nx) {\n            case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;\n        }\n    }\n#endif\n}\n\ntemplate <typename FloatX, typename FloatY>\nvoid set_mul_mat_f(MulMat& mm) {\n    for (auto& f : mm.funcs) f = nullptr;\n    mm.funcs[0] = mul_mat_fX_fY_T<1, FloatX, FloatY>;\n    mm.funcs[1] = mul_mat_fX_fY_T<2, FloatX, FloatY>;\n    mm.funcs[2] = mul_mat_fX_fY_T<3, FloatX, FloatY>;\n    mm.funcs[3] = mul_mat_fX_fY_T<4, FloatX, FloatY>;\n    mm.funcs[4] = mul_mat_fX_fY_T<5, FloatX, FloatY>;\n#ifndef __AVX512F__\n    mm.funcs[5] = mul_mat_fX_fY_T<6, FloatX, FloatY>;\n#endif\n}\n\n\n\n/*\nmoonll\nadd typeB so that set_mul_mat returns false when typeB is not the type expected for the weight matrix\nadd IQ2_XXS\nadd IQ1_S\nadd GGML_TYPE_IQ4_XS\n*/\n\nbool MulMat::set_mul_mat(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {\n   
 (void)Ny;\n\n        auto expected_typeB = GGML_TYPE_Q8_K;\n    switch (typeA) {\n        case GGML_TYPE_Q2_K:\n            assert (ne00 % QK_K == 0);\n            MulMat::set_functions<DequantizerQ2K>(mm);\n            break;\n        case GGML_TYPE_Q3_K:\n            assert (ne00 % QK_K == 0);\n            MulMat::set_functions<DequantizerQ3K>(mm);\n            break;\n        case GGML_TYPE_Q4_K:\n            assert (ne00 % QK_K == 0);\n            MulMat::set_functions<DequantizerQ4K>(mm);\n            break;\n        case GGML_TYPE_Q5_K:\n            assert (ne00 % QK_K == 0);\n            MulMat::set_functions<DequantizerQ5K>(mm);\n            break;\n        case GGML_TYPE_Q6_K:\n            assert (ne00 % QK_K == 0);\n            MulMat::set_functions<DequantizerQ6K>(mm);\n            break;\n        case GGML_TYPE_IQ4_XS:\n            assert (ne00 % QK_K == 0);\n            MulMat::set_functions<DequantizerIQ4XS>(mm);\n            break;\n        case GGML_TYPE_IQ2_XXS:\n            assert (ne00 % QK_K == 0);\n            MulMat::set_functions<DequantizerIQ2XXS>(mm);\n            break;\n        case GGML_TYPE_Q4_0:\n            assert (ne00 % QK4_0 == 0);\n            MulMat::set_functions<Q4_0_Unpacker>(mm);\n            expected_typeB = GGML_TYPE_Q8_0;\n            break;\n        case GGML_TYPE_Q4_1:\n            assert (ne00 % QK4_1 == 0);\n            MulMat::set_functions<Q4_1_Unpacker>(mm);\n            expected_typeB = GGML_TYPE_Q8_1_X4;\n            break;\n        case GGML_TYPE_Q5_0:\n            assert (ne00 % QK5_0 == 0);\n            MulMat::set_functions<Q5_0_Unpacker>(mm);\n            expected_typeB = GGML_TYPE_Q8_0;\n            break;\n        case GGML_TYPE_Q5_1:\n            assert (ne00 % QK5_1 == 0);\n            MulMat::set_functions<Q5_1_Unpacker>(mm);\n            expected_typeB = GGML_TYPE_Q8_1_X4;\n            break;\n        case GGML_TYPE_Q8_0:\n            assert (ne00 % QK8_0 == 0);\n#ifdef HAVE_FANCY_SIMD\n            
MulMat::set_functions<Q8_0_1_Unpacker>(mm);\n            expected_typeB = GGML_TYPE_Q8_1_X4;\n#else\n            MulMat::set_functions<Q8_0_Unpacker>(mm);\n            expected_typeB = GGML_TYPE_Q8_0_X4;\n#endif\n            break;\n        case GGML_TYPE_IQ1_S:\n            mm.funcs[0] = mul_mat_iq1_s_q8_K<1>;\n            mm.funcs[1] = mul_mat_iq1_s_q8_K<2>;\n            mm.funcs[2] = mul_mat_iq1_s_q8_K<3>;\n            mm.funcs[3] = mul_mat_iq1_s_q8_K<4>;\n            mm.funcs[4] = mul_mat_iq1_s_q8_K<5>;\n            mm.funcs[5] = mul_mat_iq1_s_q8_K<6>;\n            mm.funcs[6] = mul_mat_iq1_s_q8_K<7>;\n            mm.funcs[7] = mul_mat_iq1_s_q8_K<8>;\n        #ifdef HAVE_FANCY_SIMD\n             mm.func16 = mul_mat_iq1_s_q8_K<16>;\n        #endif\n       // row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00);\n              expected_typeB = GGML_TYPE_Q8_K;\n            break;\n\n        default:\n        {\n            return false;\n        }\n            \n    }\n\n\n\n    return ggml_type(typeB) == expected_typeB;\n\n}\n\n} // namespace\n\n/*\niq1_s is not supported on Arm\n*/\n#else   // __aarch64__\n\n//[kawrakow] Need these two for performance on Arm\ntypedef struct {\n    ggml_half d[8];\n    int8_t qs[4*QK8_1];\n} block_q8_1_x4;\nstatic_assert(sizeof(block_q8_1_x4) == 4*sizeof(block_q8_1), \"wrong q8_1_x4 block size/padding\");\ntypedef struct {\n    ggml_half d[4];\n    int8_t qs[4*QK8_0];\n} block_q8_0_x4;\nstatic_assert(sizeof(block_q8_0_x4) == 4*sizeof(block_q8_0), \"wrong q8_0_x4 block size/padding\");\n\nnamespace {\n\ntemplate <int nrc, typename block_q8 = block_q8_K> struct Q8 {\n\n    constexpr static int nrc_y = nrc;\n\n    Q8(const DataInfo& info) {\n        for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy);\n    }\n\n    inline int8x16_t load_quants_16(int iy, int i, int j) const { return vld1q_s8(y[iy][i].qs + 16*j); }\n    inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy][i].qs 
+ 32*j); }\n    inline int8x16x4_t load_quants_64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy][i].qs + 64*j); }\n    inline int16x8x2_t load_bsums(int iy, int i) const { return vld1q_s16_x2(y[iy][i].bsums); }\n    inline int16x8_t load_bsums8(int iy, int i) const {\n        auto q8s = vld1q_s16_x2(y[iy][i].bsums);\n        return vpaddq_s16(q8s.val[0], q8s.val[1]);\n    }\n    inline float scale(int iy, int i) const { return y[iy][i].d; }\n\n    const block_q8 * y[nrc_y];\n};\n\ntemplate <int nrc_y, typename Dequantizer>\nIQK_NOINLINE void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n    const int nb = n / QK_K;\n\n    Q8<nrc_y, block_q8_K> q8(info);\n\n    Dequantizer deq(vx, bx, nrc_y);\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        deq.new_row(ix);\n\n        float32x4_t acc[nrc_y];\n        for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);\n\n//#pragma GCC unroll 4\n        for (int i = 0; i < nb; ++i) {\n\n            int32x4_t sumi[nrc_y];\n            for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0);\n\n            if constexpr (nrc_y > 1 && Dequantizer::should_scale_quants()) {\n                deq.process_scales(i, q8, acc);\n                deq.prepare(i, 0);\n                deq.compute(q8, i, 0, sumi);\n                deq.prepare(i, 1);\n                deq.compute(q8, i, 1, sumi);\n            } else {\n                if constexpr (Dequantizer::num_blocks() == 8) {\n                    auto scales = deq.new_block(i, q8, acc);\n                    deq.prepare(i, 0);\n#pragma GCC unroll 8\n                    for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);\n                    deq.prepare(i, 1);\n#pragma GCC unroll 8\n                    for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);\n                }\n          
      else if constexpr (Dequantizer::num_blocks() == 16) {\n                    auto scales = deq.new_block(i, q8, acc);\n                    deq.prepare(i, 0);\n#pragma GCC unroll 8\n                    for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);\n                    deq.prepare(i, 1);\n#pragma GCC unroll 8\n                    for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);\n                }\n                else {\n                    GGML_ASSERT(false);\n                }\n            }\n\n#pragma GCC unroll 8\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i)));\n            }\n        }\n\n#pragma GCC unroll 8\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            info.store(ix, iy, vaddvq_f32(acc[iy]));\n        }\n    }\n}\ntemplate <int nrc_y, typename Dequantizer>\nIQK_NOINLINE void mul_mat_qX_K_q8_K_IQ(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n    const int nb = n / QK_K;\n\n    Q8<nrc_y, block_q8_K> q8(info);\n\n    Dequantizer deq(vx, bx, nrc_y);\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        deq.new_row(ix);\n\n        float32x4_t acc[nrc_y];\n        for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);\n\n        for (int i = 0; i < nb; ++i) {\n\n            int32x4_t sumi[nrc_y];\n            for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0);\n\n            if constexpr (Dequantizer::num_blocks() == 8) {\n                auto scales = deq.new_block(i);\n                deq.prepare(i, 0);\n#pragma GCC unroll 8\n                for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);\n                deq.prepare(i, 1);\n#pragma GCC unroll 8\n                for (int iy = 0; iy 
< nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);\n            }\n            else if constexpr (Dequantizer::num_blocks() == 16) {\n                auto scales = deq.new_block(i);\n                deq.prepare(i, 0);\n#pragma GCC unroll 8\n                for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);\n                deq.prepare(i, 1);\n#pragma GCC unroll 8\n                for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);\n            }\n            else {\n                GGML_ASSERT(false);\n            }\n#pragma GCC unroll 8\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i)));\n            }\n        }\n#pragma GCC unroll 8\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            info.store(ix, iy, vaddvq_f32(acc[iy]));\n        }\n    }\n}\n\ntemplate <typename Q8>\nIQK_ALWAYS_INLINE void compute_8_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8,\n        const int32x4x2_t& scales, int iy, int i, int j, int32x4_t& sumi) {\n    auto mzero = vdupq_n_s32(0);\n    const int8x16_t * qs_1 = (const int8x16_t *)qx_1.val;\n    const int8x16_t * qs_2 = (const int8x16_t *)qx_2.val;\n\n    auto q8b_1 = q8.load_quants(iy, i, 4*j+0);\n    auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qs_1[0], q8b_1.val[0]), qs_1[1], q8b_1.val[1]); // block 1\n    auto q8b_2 = q8.load_quants(iy, i, 4*j+1);\n    auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qs_1[2], q8b_2.val[0]), qs_1[3], q8b_2.val[1]); // block 2\n    auto p12 = vpaddq_s32(p1, p2);\n\n    auto q8b_3 = q8.load_quants(iy, i, 4*j+2);\n    auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qs_2[0], q8b_3.val[0]), qs_2[1], q8b_3.val[1]); // block 3\n    auto q8b_4 = q8.load_quants(iy, i, 4*j+3);\n    auto p4 = 
ggml_vdotq_s32(ggml_vdotq_s32(mzero, qs_2[2], q8b_4.val[0]), qs_2[3], q8b_4.val[1]); // block 4\n    auto p34 = vpaddq_s32(p3, p4);\n\n    auto pall = vpaddq_s32(p12, p34);\n    sumi = vmlaq_s32(sumi, scales.val[j], pall);\n}\ntemplate <typename Q8>\nIQK_ALWAYS_INLINE void compute_8_blocks(const int8x16_t * qx, const Q8& q8,\n        const int32x4_t& scales, int iy, int i, int j, int32x4_t& sumi) {\n    auto mzero = vdupq_n_s32(0);\n\n    auto q8b_1 = q8.load_quants(iy, i, 4*j+0);\n    auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qx[0], q8b_1.val[0]), qx[1], q8b_1.val[1]); // block 1\n    auto q8b_2 = q8.load_quants(iy, i, 4*j+1);\n    auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qx[2], q8b_2.val[0]), qx[3], q8b_2.val[1]); // block 2\n    auto p12 = vpaddq_s32(p1, p2);\n\n    auto q8b_3 = q8.load_quants(iy, i, 4*j+2);\n    auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qx[4], q8b_3.val[0]), qx[5], q8b_3.val[1]); // block 3\n    auto q8b_4 = q8.load_quants(iy, i, 4*j+3);\n    auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qx[6], q8b_4.val[0]), qx[7], q8b_4.val[1]); // block 4\n    auto p34 = vpaddq_s32(p3, p4);\n\n    auto pall = vpaddq_s32(p12, p34);\n    sumi = vmlaq_s32(sumi, scales, pall);\n}\n\ntemplate <typename Q8>\nIQK_ALWAYS_INLINE void compute_16_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8,\n        const int32x4x4_t& scales, int iy, int i, int j, int32x4_t& sumi) {\n\n    auto mzero = vdupq_n_s32(0);\n    auto q8b_1 = q8.load_quants(iy, i, 4*j+0);\n    auto p1 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]),\n                         ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1])); // blocks 0, 0, 1, 1,\n    auto q8b_2 = q8.load_quants(iy, i, 4*j+1);\n    auto p2 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]),\n                         ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1])); // blocks 2, 2, 3, 3,\n   
 auto p12 = vpaddq_s32(p1, p2); // blocks 0, 1, 2, 3\n    sumi = vmlaq_s32(sumi, scales.val[2*j+0], p12);\n\n    auto q8b_3 = q8.load_quants(iy, i, 4*j+2);\n    auto p3 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]),\n                         ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1])); // block 4, 4, 5, 5,\n    auto q8b_4 = q8.load_quants(iy, i, 4*j+3);\n    auto p4 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]),\n                         ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1])); // block 6, 6, 7, 7,\n    auto p34 = vpaddq_s32(p3, p4); // blocks 4, 5, 6, 7\n    sumi = vmlaq_s32(sumi, scales.val[2*j+1], p34);\n}\n\ntemplate <typename Q8>\ninline void accum_mins_8(const int16x8_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) {\n    for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n        auto q8s = q8.load_bsums8(iy, i);\n        int32x4_t b1 = vmull_s16(vget_low_s16(mins), vget_low_s16(q8s));\n        int32x4_t b2 = vmull_s16(vget_high_s16(mins), vget_high_s16(q8s));\n        float32x4_t prod = vcvtq_f32_s32(vaddq_s32(b1, b2));\n        acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i)));\n    }\n}\ntemplate <typename Q8>\ninline void accum_mins_16(const int16x8x2_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) {\n    for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n        auto q8s = q8.load_bsums(iy, i);\n        int32x4_t b1 = vmull_s16(vget_low_s16 (mins.val[0]), vget_low_s16 (q8s.val[0]));\n        int32x4_t b2 = vmull_s16(vget_high_s16(mins.val[0]), vget_high_s16(q8s.val[0]));\n        int32x4_t b3 = vmull_s16(vget_low_s16 (mins.val[1]), vget_low_s16 (q8s.val[1]));\n        int32x4_t b4 = vmull_s16(vget_high_s16(mins.val[1]), vget_high_s16(q8s.val[1]));\n        float32x4_t prod = vcvtq_f32_s32(vaddq_s32(vaddq_s32(b1, b2), vaddq_s32(b3, b4)));\n        acc[iy] = vmlaq_f32(acc[iy], prod, 
vdupq_n_f32(c*q8.scale(iy, i)));\n    }\n}\n\nstruct Scales8 {\n    uint32_t utmp[4];\n    const uint8_t * sc8 = (const uint8_t *)utmp;\n    template <typename Q8, typename Qx>\n    inline int32x4x2_t process_scales_mins(const Qx& x, const Q8& q8, int i, float32x4_t * acc) {\n        make_q4_scales(x.scales, utmp);\n        int16x8_t mins = vmovl_s8(vld1_s8((const int8_t *)sc8 + 8));\n        accum_mins_8(mins, q8, acc, i, -GGML_FP16_TO_FP32(x.dmin));\n\n        uint8x8_t scales8 = vld1_u8(sc8);\n        uint16x8_t scales16 = vmovl_u8(scales8);\n        int32x4x2_t scales = {vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales16))),\n                              vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales16)))};\n        return scales;\n    }\n};\n\nstruct Q4bits {\n    const uint8x16_t m4b = vdupq_n_u8(0xf);\n    uint8x16x4_t b1, b2;\n    inline void prepare4(uint8x16x4_t& b, const uint8x16_t * val) const {\n        b.val[0] = vandq_u8(val[0], m4b);\n        b.val[2] = vshrq_n_u8(val[0], 4);\n        b.val[1] = vandq_u8(val[1], m4b);\n        b.val[3] = vshrq_n_u8(val[1], 4);\n    }\n    inline void prepare4_16(uint8x16x4_t& b, const uint8x16_t * val) const {\n        b.val[0] = vandq_u8(val[0], m4b);\n        b.val[1] = vshrq_n_u8(val[0], 4);\n        b.val[2] = vandq_u8(val[1], m4b);\n        b.val[3] = vshrq_n_u8(val[1], 4);\n    }\n    inline void prepare(const uint8_t * qs) {\n        auto q4bits = vld1q_u8_x2(qs);\n        prepare4(b1, q4bits.val);\n        q4bits = vld1q_u8_x2(qs+32);\n        prepare4(b2, q4bits.val);\n    }\n    inline void prepare_v2(const uint8_t * qs) {\n        auto q4bits = vld1q_u8_x4(qs);\n        prepare4(b1, q4bits.val+0);\n        prepare4(b2, q4bits.val+2);\n    }\n    inline void prepare64(const uint8_t * qs) {\n        auto q4bits = vld1q_u8_x4(qs);\n        b1.val[0] = vandq_u8(q4bits.val[0], m4b);\n        b1.val[1] = vandq_u8(q4bits.val[1], m4b);\n        b1.val[2] = vandq_u8(q4bits.val[2], m4b);\n        
b1.val[3] = vandq_u8(q4bits.val[3], m4b);\n        b2.val[0] = vshrq_n_u8(q4bits.val[0], 4);\n        b2.val[1] = vshrq_n_u8(q4bits.val[1], 4);\n        b2.val[2] = vshrq_n_u8(q4bits.val[2], 4);\n        b2.val[3] = vshrq_n_u8(q4bits.val[3], 4);\n    }\n    inline void prepare16(const uint8_t * qs) {\n        auto q4bits = vld1q_u8_x2(qs);\n        prepare4_16(b1, q4bits.val);\n        q4bits = vld1q_u8_x2(qs+32);\n        prepare4_16(b2, q4bits.val);\n    }\n    inline void prepare16_v2(const uint8_t * qs) {\n        auto q4bits = vld1q_u8_x4(qs);\n        prepare4_16(b1, q4bits.val+0);\n        prepare4_16(b2, q4bits.val+2);\n    }\n};\n\nstruct Q2bits {\n    const uint8x16_t m4b = vdupq_n_u8(0x03);\n    uint8x16x4_t b1, b2;\n    inline void prepare(const uint8_t * qs) {\n        auto q2bits = vld1q_u8_x2(qs);\n        b1.val[0] = vandq_u8(q2bits.val[0], m4b);\n        b1.val[1] = vandq_u8(q2bits.val[1], m4b);\n\n        q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);\n        q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);\n        b1.val[2] = vandq_u8(q2bits.val[0], m4b);\n        b1.val[3] = vandq_u8(q2bits.val[1], m4b);\n\n        q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);\n        q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);\n        b2.val[0] = vandq_u8(q2bits.val[0], m4b);\n        b2.val[1] = vandq_u8(q2bits.val[1], m4b);\n\n        q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);\n        q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);\n        b2.val[2] = vandq_u8(q2bits.val[0], m4b);\n        b2.val[3] = vandq_u8(q2bits.val[1], m4b);\n    }\n};\n\ntemplate <typename block_q>\nstruct BaseDequantizer {\n    BaseDequantizer(const void * vx, size_t bx, int nrc) : vx(vx), x(nullptr), bx(bx), nrc(nrc) {}\n    inline void new_row(int ix) { x = (const block_q *)((const char *)vx + ix*bx); }\n    const void * vx;\n    const block_q * x;\n    const size_t bx;\n    const int nrc;\n};\n\nstruct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {\n    
DequantizerQ4K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 8; }\n    constexpr static bool should_scale_quants() { return false; }\n\n    template <typename Q8>\n    inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        return s8.process_scales_mins(x[i], q8, i, acc);\n    }\n    inline void prepare(int i, int j) {\n        if (nrc == 1) bits.prepare_v2(x[i].qs+64*j);\n        else bits.prepare(x[i].qs+64*j);\n    }\n\n    Q4bits bits;\n    Scales8 s8;\n\n    float d;\n};\n\nstruct HighBit5 {\n    const uint8x16_t mhb = vdupq_n_u8(0x10);\n    uint8x16x2_t bits;\n    inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) {\n        b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 4), mhb));\n        b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 4), mhb));\n        b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 3), mhb));\n        b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 3), mhb));\n\n        b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb));\n        b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb));\n        b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb));\n        b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb));\n\n        if (do_shift) {\n            bits.val[0] = vshrq_n_u8(bits.val[0], 4);\n            bits.val[1] = vshrq_n_u8(bits.val[1], 4);\n        }\n    }\n};\n\nstruct HighBit3 {\n    const uint8x16_t mhb = vdupq_n_u8(0x04);\n    uint8x16x2_t bits;\n    inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) {\n        b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb));\n        b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb));\n        b1.val[2] = 
vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb));\n        b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb));\n\n        b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(bits.val[0], mhb));\n        b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(bits.val[1], mhb));\n        b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshrq_n_u8(bits.val[0], 1), mhb));\n        b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshrq_n_u8(bits.val[1], 1), mhb));\n\n        if (do_shift) {\n            bits.val[0] = vshrq_n_u8(bits.val[0], 4);\n            bits.val[1] = vshrq_n_u8(bits.val[1], 4);\n        }\n    }\n};\n\nstruct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {\n    DequantizerQ5K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 8; }\n    constexpr static bool should_scale_quants() { return false; }\n\n    template <typename Q8>\n    inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        h.bits = vld1q_u8_x2(x[i].qh);\n        return s8.process_scales_mins(x[i], q8, i, acc);\n    }\n    inline void prepare(int i, int j) {\n        bits.prepare(x[i].qs+64*j);\n        h.apply(bits.b1, bits.b2, j == 0);\n    }\n\n    Q4bits bits;\n    HighBit5 h;\n    Scales8 s8;\n\n    uint8x16x2_t hbits;\n\n    float d;\n};\n\ninline int32x4x4_t make_wider(const int16x8x2_t& scales16) {\n    int32x4x4_t scales = {\n        vmovl_s16(vget_low_s16 (scales16.val[0])),\n        vmovl_s16(vget_high_s16(scales16.val[0])),\n        vmovl_s16(vget_low_s16 (scales16.val[1])),\n        vmovl_s16(vget_high_s16(scales16.val[1])),\n    };\n    return scales;\n}\n\ntemplate <typename Q8>\ninline int32x4x4_t process_scales_mins_16(const int8x16_t& scales8, const Q8& q8, float32x4_t * acc, int i, float c) {\n    int16x8x2_t scales16;\n    scales16.val[0] = vmovl_s8(vget_low_s8(scales8));\n    scales16.val[1] = 
vmovl_s8(vget_high_s8(scales8));\n    accum_mins_16(scales16, q8, acc, i, c);\n    return make_wider(scales16);\n}\n\nstruct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {\n    DequantizerQ6K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 16; }\n    constexpr static bool should_scale_quants() { return false; }\n\n    template <typename Q8>\n    inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        return process_scales_mins_16(vld1q_s8(x[i].scales), q8, acc, i, -32.f*d);\n    }\n    inline void prepare(int i, int j) {\n\n        auto hbits = vld1q_u8_x2(x[i].qh + 32*j);\n\n        bits.prepare64(x[i].ql+64*j);\n        bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), mhb));\n        bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), mhb));\n        bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 2), mhb));\n        bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 2), mhb));\n\n        bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], mhb));\n        bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], mhb));\n        bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 2), mhb));\n        bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 2), mhb));\n\n    }\n\n    Q4bits bits;\n\n    const uint8x16_t mhb = vdupq_n_u8(0x30);\n\n    float d;\n};\n\nstruct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {\n    DequantizerQ3K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 16; }\n    constexpr static bool should_scale_quants() { return false; }\n\n    template <typename Q8>\n    inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {\n 
       d = GGML_FP16_TO_FP32(x[i].d);\n        h.bits = vld1q_u8_x2(x[i].hmask);\n        const uint16_t * sc16 = (const uint16_t *)x[i].scales;\n        uint32_t aux0 = sc16[0] | (sc16[1] << 16);\n        uint32_t aux1 = sc16[2] | (sc16[3] << 16);\n        uint32_t aux2 = sc16[4] | (sc16[5] << 16);\n        aux32[0] =  (aux0       & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030);\n        aux32[1] =  (aux1       & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030);\n        aux32[2] = ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030);\n        aux32[3] = ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030);\n        return process_scales_mins_16(vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32)), q8, acc, i, -4.f*d);\n    }\n\n    inline void prepare(int i, int j) {\n        bits.prepare(x[i].qs+32*j);\n        h.apply(bits.b1, bits.b2, j == 0);\n    }\n\n    uint32_t aux32[4];\n\n    Q2bits bits;\n\n    HighBit3 h;\n\n    float d;\n};\n\nstruct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {\n    DequantizerQ2K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 16; }\n    constexpr static bool should_scale_quants() { return true; }\n\n    template <typename Q8>\n    inline void process_scales(int i, const Q8& q8, float32x4_t * acc) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        auto scales_and_mins = vld1q_u8(x[i].scales);\n        auto mins8 = vreinterpretq_s8_u8(vshrq_n_u8(scales_and_mins, 4));\n        int16x8x2_t scales16;\n        scales16.val[0] = vmovl_s8(vget_low_s8(mins8));\n        scales16.val[1] = vmovl_s8(vget_high_s8(mins8));\n        accum_mins_16(scales16, q8, acc, i, -GGML_FP16_TO_FP32(x[i].dmin));\n\n        scales8 = vandq_u8(scales_and_mins, vdupq_n_u8(0xf));\n    }\n\n    template <typename Q8>\n    inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {\n        process_scales(i, q8, acc);\n        int16x8x2_t scales16;\n        
scales16.val[0] = vmovl_s8(vget_low_s8(vreinterpretq_s8_u8(scales8)));\n        scales16.val[1] = vmovl_s8(vget_high_s8(vreinterpretq_s8_u8(scales8)));\n        return make_wider(scales16);\n    }\n\n    template <typename Q8>\n    inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) {\n        auto m1 = vdupq_n_u8(1);\n        auto shuffle = vdupq_n_u8(8*j);\n        bits.b1.val[0] = vmulq_u8(bits.b1.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        bits.b1.val[1] = vmulq_u8(bits.b1.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        bits.b1.val[2] = vmulq_u8(bits.b1.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        bits.b1.val[3] = vmulq_u8(bits.b1.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        bits.b2.val[0] = vmulq_u8(bits.b2.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        bits.b2.val[1] = vmulq_u8(bits.b2.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        bits.b2.val[2] = vmulq_u8(bits.b2.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        bits.b2.val[3] = vmulq_u8(bits.b2.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            auto q8b_1 = q8.load_quants(iy, i, 4*j+0);\n            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]),\n                    vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]);\n\n            auto q8b_2 = q8.load_quants(iy, i, 4*j+1);\n            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]),\n                    vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]);\n\n            auto q8b_3 = q8.load_quants(iy, i, 4*j+2);\n            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), 
q8b_3.val[0]),\n                    vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]);\n\n            auto q8b_4 = q8.load_quants(iy, i, 4*j+3);\n            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]),\n                    vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]);\n        }\n    }\n\n    inline void prepare(int i, int j) {\n        bits.prepare(x[i].qs+32*j);\n    }\n\n    uint32_t aux32[4];\n\n    uint8x16_t scales8;\n\n    Q2bits bits;\n\n    float d;\n};\n\n// ============================= i-quants\n\nstruct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {\n\n    static int8x16_t load_values() {\n        static const int8_t iq4nl_values[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};\n        return vld1q_s8(iq4nl_values);\n    }\n\n    DequantizerIQ4XS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(load_values()) {}\n\n    constexpr static int num_blocks() { return 8; }\n    constexpr static bool should_scale_quants() { return false; }\n\n    inline void new_row(int ix) { x = (const block_iq4_xs *)((const char *)vx + bx*ix); }\n\n    template <typename Q8>\n    inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {\n        (void)q8;\n        (void)acc;\n        d = GGML_FP16_TO_FP32(x[i].d);\n        const uint16_t scales_h = x[i].scales_h;\n        const uint16_t * scales_l = (const uint16_t *)x[i].scales_l;\n        aux32[0] = scales_l[0] | (scales_l[1] << 16);\n        aux32[1] = aux32[0] >> 4;\n        // scl is ordered as 0, 2, 4, 6, 1, 3, 5, 7\n        uint8x8_t scl8 = vand_u8(vld1_u8((const uint8_t *)aux32), vdup_n_u8(0xf));\n        uint16_t * aux16 = (uint16_t *)aux32;\n        aux16[0] = scales_h << 4; aux16[1] = scales_h << 2; aux16[2] = scales_h; aux16[3] = scales_h >> 2;\n        // sch is ordered as 0, 4, 1, 5, 2, 6, 3, 7\n        uint8x8_t sch8 = vand_u8(vld1_u8((const uint8_t 
*)aux16), vdup_n_u8(0x30));\n        int8x8_t scales8 = vadd_s8(vreinterpret_s8_u8(vorr_u8(scl8, vtbl1_u8(sch8, vreinterpret_u8_u32(hshuff)))), vdup_n_s8(-32));\n        // shuffle 0, 2, 4, 6, 1, 3, 5, 7 -> 0, 1, 2, 3, 4, 5, 6, 7\n        scales8 = vtbl1_s8(scales8, vreinterpret_s8_u32(hshuff));\n        int16x8_t scales16 = vmovl_s8(scales8);\n        int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};\n        return scales;\n    }\n    inline void prepare(int i, int j) {\n        bits.prepare16(x[i].qs+64*j);\n        for (int k = 0; k < 4; ++k) {\n            bits.b1.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values, bits.b1.val[k]));\n            bits.b2.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values, bits.b2.val[k]));\n        }\n    }\n\n    Q4bits bits;\n    const int8x16_t values;\n    uint32_t aux32[2];\n\n    constexpr static uint32x2_t hshuff = {0x05010400, 0x07030602};\n\n    float d;\n};\n\nstruct SimpleBits {\n    uint8x16x4_t b1;\n    uint8x16x4_t b2;\n};\n\nIQK_ALWAYS_INLINE int32x4x2_t prepare_scales_8(const uint32x4_t& v1, const uint32x4_t& v2) {\n    int32x4x2_t scales;\n    auto one = vdupq_n_u32(1);\n    scales.val[0] = vreinterpretq_s32_u32(vsliq_n_u32(one, vshrq_n_u32(v1, 28), 1));\n    scales.val[1] = vreinterpretq_s32_u32(vsliq_n_u32(one, vshrq_n_u32(v2, 28), 1));\n    return scales;\n}\n\ninline void apply_signs_2(uint8x16_t * b, const uint64_t * signs, uint32_t sidx) {\n    auto s1 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >> 0) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >> 7) & 127))));\n    auto s2 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >>14) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >>21) & 127))));\n    b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s1));\n    b[1] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[1]), s2));\n}\n\nIQK_ALWAYS_INLINE int32x4_t prepare_scales_8(const uint32x4_t& v1) {\n    return 
vreinterpretq_s32_u32(vsliq_n_u32(vdupq_n_u32(1), vshrq_n_u32(v1, 28), 1));\n}\n\nstruct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {\n    DequantizerIQ2XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    IQK_ALWAYS_INLINE float new_block(int i) const { return 0.125f * GGML_FP16_TO_FP32(x[i].d); }\n\n    inline int32x4_t unpack(int i, int j, uint8x16_t * q) const {\n        auto data = vld1q_u32_x2((const uint32_t *)(x[i].qs + 16*j));\n        prepare_all(data, q);\n        return prepare_scales_8(vuzp2q_u32(data.val[0], data.val[1]));\n    }\n\nprivate:\n\n    static inline void prepare2(uint8x16_t * b, const uint32_t * bits, const uint64_t * signs) {\n        const uint8_t * idx = (const uint8_t *)bits;\n        b[0] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[0]], iq2xxs_grid[idx[1]]});\n        b[1] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[2]], iq2xxs_grid[idx[3]]});\n        apply_signs_2(b, signs, bits[1]);\n    }\n\n    inline static void prepare_all(const uint32x4x2_t& data, uint8x16_t * quants) {\n        const uint32_t * q2 = (const uint32_t *)data.val;\n        prepare2(quants+0, q2+0, keven_signs);\n        prepare2(quants+2, q2+2, keven_signs);\n        prepare2(quants+4, q2+4, keven_signs);\n        prepare2(quants+6, q2+6, keven_signs);\n    }\n};\n\ninline int32x4x4_t prepare_4bit_scales16(const uint8_t * sc) {\n    auto aux = vld1_u8(sc);\n    auto scales_l = vand_u8(aux, vdup_n_u8(0xf));\n    auto scales_h = vshr_n_u8(aux, 4);\n    auto aux1 = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h));\n\n    auto scales8 = vreinterpretq_s8_u8(vorrq_u8(vshlq_n_u8(aux1, 1), vdupq_n_u8(1)));\n    int16x8x2_t scales16 = { vmovl_s8(vget_low_s8(scales8)), vmovl_s8(vget_high_s8(scales8)) };\n    return make_wider(scales16);\n}\n\nstruct DequantizerIQ2XS final : public BaseDequantizer<block_iq2_xs> {\n    DequantizerIQ2XS(const void * vx, size_t bx, int nrc) : 
BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 16; }\n    constexpr static bool should_scale_quants() { return false; }\n\n    SimpleBits bits;\n    float d;\n\n    inline int32x4x4_t new_block(int i) {\n        d = 0.125f * GGML_FP16_TO_FP32(x[i].d);\n        prepare_internal(i, 0);\n        return prepare_4bit_scales16(x[i].scales);\n    }\n\n    inline void prepare(int i, int j) {\n        if (j == 1) prepare_internal(i, 1);\n    }\n\nprivate:\n\n    static void make2(const uint16_t * qs, uint8x16_t * b) {\n        auto v1 = vcombine_s8(vld1_s8((const int8_t *)(iq2xs_grid + (qs[0] & 511))), vld1_s8((const int8_t *)(iq2xs_grid + (qs[1] & 511))));\n        auto v2 = vcombine_s8(vld1_s8((const int8_t *)(iq2xs_grid + (qs[2] & 511))), vld1_s8((const int8_t *)(iq2xs_grid + (qs[3] & 511))));\n        auto s1 = vcombine_s8(vld1_s8((const int8_t *)(keven_signs + (qs[0] >> 9))), vld1_s8((const int8_t *)(keven_signs + (qs[1] >> 9))));\n        auto s2 = vcombine_s8(vld1_s8((const int8_t *)(keven_signs + (qs[2] >> 9))), vld1_s8((const int8_t *)(keven_signs + (qs[3] >> 9))));\n        b[0] = vreinterpretq_u8_s8(vmulq_s8(v1, s1));\n        b[1] = vreinterpretq_u8_s8(vmulq_s8(v2, s2));\n    }\n\n    inline static void make4(const uint16_t * qs, uint8x16_t * b) {\n        make2(qs + 0, b + 0);\n        make2(qs + 4, b + 2);\n    }\n\n    IQK_ALWAYS_INLINE void prepare_internal(int i, int j) {\n        make4(x[i].qs + 16*j + 0, bits.b1.val);\n        make4(x[i].qs + 16*j + 8, bits.b2.val);\n    }\n\n};\n\n// So, I hate to include this table, but with the GCC 12.3 compiler\n// bundled in the Cosmopolitan tools, loading the unpacked sign bytes\n// from this table using the packed 8 sign bits as index is faster than\n// using the standard trick of vceqq_u8(vandq_u8(bits, mask), mask) to\n// expand the bits to bytes.\nstatic const uint64_t kall_signs[256] = {\n    0x0101010101010101, 0x01010101010101ff, 0x010101010101ff01, 0x010101010101ffff,\n    
0x0101010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0x0101010101ffffff,\n    0x01010101ff010101, 0x01010101ff0101ff, 0x01010101ff01ff01, 0x01010101ff01ffff,\n    0x01010101ffff0101, 0x01010101ffff01ff, 0x01010101ffffff01, 0x01010101ffffffff,\n    0x010101ff01010101, 0x010101ff010101ff, 0x010101ff0101ff01, 0x010101ff0101ffff,\n    0x010101ff01ff0101, 0x010101ff01ff01ff, 0x010101ff01ffff01, 0x010101ff01ffffff,\n    0x010101ffff010101, 0x010101ffff0101ff, 0x010101ffff01ff01, 0x010101ffff01ffff,\n    0x010101ffffff0101, 0x010101ffffff01ff, 0x010101ffffffff01, 0x010101ffffffffff,\n    0x0101ff0101010101, 0x0101ff01010101ff, 0x0101ff010101ff01, 0x0101ff010101ffff,\n    0x0101ff0101ff0101, 0x0101ff0101ff01ff, 0x0101ff0101ffff01, 0x0101ff0101ffffff,\n    0x0101ff01ff010101, 0x0101ff01ff0101ff, 0x0101ff01ff01ff01, 0x0101ff01ff01ffff,\n    0x0101ff01ffff0101, 0x0101ff01ffff01ff, 0x0101ff01ffffff01, 0x0101ff01ffffffff,\n    0x0101ffff01010101, 0x0101ffff010101ff, 0x0101ffff0101ff01, 0x0101ffff0101ffff,\n    0x0101ffff01ff0101, 0x0101ffff01ff01ff, 0x0101ffff01ffff01, 0x0101ffff01ffffff,\n    0x0101ffffff010101, 0x0101ffffff0101ff, 0x0101ffffff01ff01, 0x0101ffffff01ffff,\n    0x0101ffffffff0101, 0x0101ffffffff01ff, 0x0101ffffffffff01, 0x0101ffffffffffff,\n    0x01ff010101010101, 0x01ff0101010101ff, 0x01ff01010101ff01, 0x01ff01010101ffff,\n    0x01ff010101ff0101, 0x01ff010101ff01ff, 0x01ff010101ffff01, 0x01ff010101ffffff,\n    0x01ff0101ff010101, 0x01ff0101ff0101ff, 0x01ff0101ff01ff01, 0x01ff0101ff01ffff,\n    0x01ff0101ffff0101, 0x01ff0101ffff01ff, 0x01ff0101ffffff01, 0x01ff0101ffffffff,\n    0x01ff01ff01010101, 0x01ff01ff010101ff, 0x01ff01ff0101ff01, 0x01ff01ff0101ffff,\n    0x01ff01ff01ff0101, 0x01ff01ff01ff01ff, 0x01ff01ff01ffff01, 0x01ff01ff01ffffff,\n    0x01ff01ffff010101, 0x01ff01ffff0101ff, 0x01ff01ffff01ff01, 0x01ff01ffff01ffff,\n    0x01ff01ffffff0101, 0x01ff01ffffff01ff, 0x01ff01ffffffff01, 0x01ff01ffffffffff,\n    0x01ffff0101010101, 0x01ffff01010101ff, 
0x01ffff010101ff01, 0x01ffff010101ffff,\n    0x01ffff0101ff0101, 0x01ffff0101ff01ff, 0x01ffff0101ffff01, 0x01ffff0101ffffff,\n    0x01ffff01ff010101, 0x01ffff01ff0101ff, 0x01ffff01ff01ff01, 0x01ffff01ff01ffff,\n    0x01ffff01ffff0101, 0x01ffff01ffff01ff, 0x01ffff01ffffff01, 0x01ffff01ffffffff,\n    0x01ffffff01010101, 0x01ffffff010101ff, 0x01ffffff0101ff01, 0x01ffffff0101ffff,\n    0x01ffffff01ff0101, 0x01ffffff01ff01ff, 0x01ffffff01ffff01, 0x01ffffff01ffffff,\n    0x01ffffffff010101, 0x01ffffffff0101ff, 0x01ffffffff01ff01, 0x01ffffffff01ffff,\n    0x01ffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0x01ffffffffffffff,\n    0xff01010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0xff0101010101ffff,\n    0xff01010101ff0101, 0xff01010101ff01ff, 0xff01010101ffff01, 0xff01010101ffffff,\n    0xff010101ff010101, 0xff010101ff0101ff, 0xff010101ff01ff01, 0xff010101ff01ffff,\n    0xff010101ffff0101, 0xff010101ffff01ff, 0xff010101ffffff01, 0xff010101ffffffff,\n    0xff0101ff01010101, 0xff0101ff010101ff, 0xff0101ff0101ff01, 0xff0101ff0101ffff,\n    0xff0101ff01ff0101, 0xff0101ff01ff01ff, 0xff0101ff01ffff01, 0xff0101ff01ffffff,\n    0xff0101ffff010101, 0xff0101ffff0101ff, 0xff0101ffff01ff01, 0xff0101ffff01ffff,\n    0xff0101ffffff0101, 0xff0101ffffff01ff, 0xff0101ffffffff01, 0xff0101ffffffffff,\n    0xff01ff0101010101, 0xff01ff01010101ff, 0xff01ff010101ff01, 0xff01ff010101ffff,\n    0xff01ff0101ff0101, 0xff01ff0101ff01ff, 0xff01ff0101ffff01, 0xff01ff0101ffffff,\n    0xff01ff01ff010101, 0xff01ff01ff0101ff, 0xff01ff01ff01ff01, 0xff01ff01ff01ffff,\n    0xff01ff01ffff0101, 0xff01ff01ffff01ff, 0xff01ff01ffffff01, 0xff01ff01ffffffff,\n    0xff01ffff01010101, 0xff01ffff010101ff, 0xff01ffff0101ff01, 0xff01ffff0101ffff,\n    0xff01ffff01ff0101, 0xff01ffff01ff01ff, 0xff01ffff01ffff01, 0xff01ffff01ffffff,\n    0xff01ffffff010101, 0xff01ffffff0101ff, 0xff01ffffff01ff01, 0xff01ffffff01ffff,\n    0xff01ffffffff0101, 0xff01ffffffff01ff, 0xff01ffffffffff01, 0xff01ffffffffffff,\n    
0xffff010101010101, 0xffff0101010101ff, 0xffff01010101ff01, 0xffff01010101ffff,\n    0xffff010101ff0101, 0xffff010101ff01ff, 0xffff010101ffff01, 0xffff010101ffffff,\n    0xffff0101ff010101, 0xffff0101ff0101ff, 0xffff0101ff01ff01, 0xffff0101ff01ffff,\n    0xffff0101ffff0101, 0xffff0101ffff01ff, 0xffff0101ffffff01, 0xffff0101ffffffff,\n    0xffff01ff01010101, 0xffff01ff010101ff, 0xffff01ff0101ff01, 0xffff01ff0101ffff,\n    0xffff01ff01ff0101, 0xffff01ff01ff01ff, 0xffff01ff01ffff01, 0xffff01ff01ffffff,\n    0xffff01ffff010101, 0xffff01ffff0101ff, 0xffff01ffff01ff01, 0xffff01ffff01ffff,\n    0xffff01ffffff0101, 0xffff01ffffff01ff, 0xffff01ffffffff01, 0xffff01ffffffffff,\n    0xffffff0101010101, 0xffffff01010101ff, 0xffffff010101ff01, 0xffffff010101ffff,\n    0xffffff0101ff0101, 0xffffff0101ff01ff, 0xffffff0101ffff01, 0xffffff0101ffffff,\n    0xffffff01ff010101, 0xffffff01ff0101ff, 0xffffff01ff01ff01, 0xffffff01ff01ffff,\n    0xffffff01ffff0101, 0xffffff01ffff01ff, 0xffffff01ffffff01, 0xffffff01ffffffff,\n    0xffffffff01010101, 0xffffffff010101ff, 0xffffffff0101ff01, 0xffffffff0101ffff,\n    0xffffffff01ff0101, 0xffffffff01ff01ff, 0xffffffff01ffff01, 0xffffffff01ffffff,\n    0xffffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0xffffffffff01ffff,\n    0xffffffffffff0101, 0xffffffffffff01ff, 0xffffffffffffff01, 0xffffffffffffffff,\n};\n\nstruct SignHelper {\n\n    IQK_ALWAYS_INLINE void apply_signs_1x(uint8x16_t * b, const uint8_t * sign_bits) const {\n        auto s = vreinterpretq_s8_u64(uint64x2_t{kall_signs[sign_bits[0]], kall_signs[sign_bits[1]]});\n        // Normally we would expect this to be faster, but it isn't.\n        // auto aux = vcombine_u8(vdup_n_u8(sign_bits[0]), vdup_n_u8(sign_bits[1]));\n        // auto s = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux, smask), smask), m1));\n        b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s));\n    }\n\n    // We would need these two if we weren't loading from the unpacked sign 
table.\n    //const uint8x16_t smask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201));\n    //const uint8x16_t m1    = vdupq_n_u8(1);\n};\n\nstruct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {\n    DequantizerIQ2S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 16; }\n    constexpr static bool should_scale_quants() { return false; }\n\n    SimpleBits bits;\n    float d;\n\n    inline int32x4x4_t new_block(int i) {\n        d = 0.125f * GGML_FP16_TO_FP32(x[i].d);\n        prepare_internal(i, 0, bits);\n        return prepare_4bit_scales16(x[i].scales);\n    }\n\n    inline void prepare(int i, int j) {\n        if (j == 1) prepare_internal(i, 1, bits);\n    }\n\nprivate:\n\n    static void make4(const SignHelper& sh, const uint8_t * sign_bits, const uint8_t * qs, const uint8_t * qh, uint8x16_t * b) {\n        uint32_t aux32[2];\n        const uint16_t * aux16 = (const uint16_t *)aux32;\n        for (int k = 0; k < 2; ++k) {\n            aux32[1] = (qh[k] << 4) | (qh[k] << 18);\n            aux32[0] = (aux32[1] << 4) & 0x03000300;\n            aux32[1] &= 0x03000300;\n            b[2*k+0] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+0] | aux16[0]))),\n                                   vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+1] | aux16[1]))));\n            b[2*k+1] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+2] | aux16[2]))),\n                                   vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+3] | aux16[3]))));\n            sh.apply_signs_1x(b+2*k+0, sign_bits); sign_bits += 2;\n            sh.apply_signs_1x(b+2*k+1, sign_bits); sign_bits += 2;\n        }\n    }\n\n    void prepare_internal(int i, int j, SimpleBits& sb) {\n\n        const auto * qs = x[i].qs + 16*j;\n        const auto * qh = x[i].qh + 4*j;\n        const auto * sign_bits = qs + QK_K/8;\n\n        make4(sh, sign_bits+0, qs+0, qh+0, sb.b1.val);\n        
make4(sh, sign_bits+8, qs+8, qh+2, sb.b2.val);\n    }\n\n    SignHelper sh;\n};\n\nstruct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {\n    DequantizerIQ3XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    IQK_ALWAYS_INLINE float new_block(int i) const { return 0.25f * GGML_FP16_TO_FP32(x[i].d); }\n\n    inline int32x4_t unpack(int i, int j, uint8x16_t * q) const {\n        auto q3data = vld1q_u8_x2(x[i].qs + 32*j);\n        auto gas = vld1q_u32((const uint32_t *)(x[i].qs + QK_K/4 + 16*j));\n        prepare_block((const uint8_t *)q3data.val, (const uint32_t *)&gas, q);\n        return prepare_scales_8(gas);\n    }\n\nprivate:\n\n    inline static void make2(const uint8_t * q3, const uint32_t sidx, uint8x16_t * b) {\n        b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[0]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[3]]});\n        b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[4]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[7]]});\n        apply_signs_2(b, keven_signs, sidx);\n    }\n    inline static void prepare_block(const uint8_t * q3, const uint32_t * signs, uint8x16_t * quants) {\n        make2(q3+ 0, signs[0], quants + 0);\n        make2(q3+ 8, signs[1], quants + 2);\n        make2(q3+16, signs[2], quants + 4);\n        make2(q3+24, signs[3], quants + 6);\n    }\n};\n\nstruct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {\n    DequantizerIQ3S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 8; }\n    constexpr static bool should_scale_quants() { return false; }\n\n    SimpleBits bits;\n    float d;\n\n    inline int32x4x2_t new_block(int i) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        uint32_t scales32[2];\n        auto qs = vld1q_u8_x2(x[i].qs);\n        auto signs = vld1q_u8(x[i].signs);\n\n        prepare_block((const uint8_t *)qs.val, x[i].qh, (const uint8_t *)&signs);\n\n  
      std::memcpy(scales32, x[i].scales, 4);\n        scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101;\n        scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101;\n        auto scales8 = vld1_u8((const uint8_t *)scales32); // 0, 2, 4, 6, 1, 3, 5, 7\n        scales8 = vtbl1_u8(scales8, vreinterpret_u8_u64(vdup_n_u64(0x0703060205010400)));\n        auto scales16 = vreinterpretq_s16_u16(vmovl_u8(scales8));\n        int32x4x2_t scales;\n        scales.val[0] = vmovl_s16(vget_low_s16(scales16));\n        scales.val[1] = vmovl_s16(vget_high_s16(scales16));\n        return scales;\n    }\n\n    inline void prepare(int i, int j) {\n        if (j == 1) {\n            auto qs = vld1q_u8_x2(x[i].qs + 32);\n            auto signs = vld1q_u8(x[i].signs + 16);\n            prepare_block((const uint8_t *)qs.val, x[i].qh + 4, (const uint8_t *)&signs);\n        }\n    }\n\nprivate:\n\n    static inline void make2(const SignHelper& sh, const uint8_t * sign_bits, const uint16x8_t& idx_l, uint8_t qh,\n            const int16x8_t& hshift, uint8x16_t * b) {\n        auto vindex = vorrq_u16(idx_l, vandq_u16(vshlq_u16(vdupq_n_u16(qh), hshift), vdupq_n_u16(256)));\n        const uint16_t * idx = (const uint16_t *)&vindex;\n        b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[0]], iq3s_grid[idx[1]], iq3s_grid[idx[2]], iq3s_grid[idx[3]]});\n        sh.apply_signs_1x(b+0, sign_bits+0);\n        b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[4]], iq3s_grid[idx[5]], iq3s_grid[idx[6]], iq3s_grid[idx[7]]});\n        sh.apply_signs_1x(b+1, sign_bits+2);\n    }\n    static inline void make4(const SignHelper& sh, const uint8_t * sign_bits, const uint8_t * qs, const uint8_t * qh,\n            const int16x8_t& hshift, uint8x16_t * b) {\n        auto idx_l = vld1q_u8(qs);\n        make2(sh, sign_bits+0, vmovl_u8(vget_low_u8 (idx_l)), qh[0], hshift, b+0);\n        make2(sh, sign_bits+4, vmovl_u8(vget_high_u8(idx_l)), qh[1], hshift, b+2);\n    }\n\n    static 
int16x8_t load_shift() {\n        static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1};\n        return vld1q_s16(k_shift);\n    }\n\n    inline void prepare_block(const uint8_t * qs, const uint8_t * qh, const uint8_t * sign_bits) {\n        auto signs = vld1q_u8(sign_bits);\n        auto s = (const uint8_t *)&signs;\n        make4(sh, s + 0, qs+ 0, qh+0, hshift, bits.b1.val);\n        make4(sh, s + 8, qs+16, qh+2, hshift, bits.b2.val);\n    }\n\n    SignHelper sh;\n    const int16x8_t hshift = load_shift();\n\n};\n\ntemplate <int nrc_y, typename Dequantizer>\nIQK_NOINLINE void mul_mat_qX_K_q8_K_IQXXS(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n    const int nb = n / QK_K;\n\n    Q8<nrc_y, block_q8_K> q8(info);\n    Dequantizer deq(vx, bx, nrc_y);\n    uint8x16_t  qx[8];\n    int32x4_t   sumi[nrc_y];\n    float32x4_t acc[nrc_y];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        deq.new_row(ix);\n        for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);\n\n        for (int i = 0; i < nb; ++i) {\n            float d = deq.new_block(i);\n            auto scales = deq.unpack(i, 0, qx);\n#pragma GCC unroll 8\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                sumi[iy] = vdupq_n_s32(0);\n                compute_8_blocks((const int8x16_t *)qx, q8, scales, iy, i, 0, sumi[iy]);\n            }\n            scales = deq.unpack(i, 1, qx);\n#pragma GCC unroll 8\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                compute_8_blocks((const int8x16_t *)qx, q8, scales, iy, i, 1, sumi[iy]);\n                acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*q8.scale(iy, i)), vcvtq_f32_s32(sumi[iy]));\n            }\n        }\n#pragma GCC unroll 8\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            info.store(ix, iy, vaddvq_f32(acc[iy]));\n        }\n    }\n}\n\n// =========================================== Legacy quants\n\ntemplate <typename Block>\ninline float16x4_t 
load_scales_q0(const Block * x, ggml_half * aux) {\n    for (int k = 0; k < 4; ++k) aux[k] = x[k].d;\n    return vld1_f16((const float16_t *)aux);\n}\n\ntemplate <typename Block>\ninline float16x8_t load_scales_q1(const Block * x, ggml_half * aux) {\n    if constexpr (std::is_same_v<Block, block_q8_1>) {\n        for (int k = 0; k < 4; ++k) { aux[k] = x[k].d; aux[k+4] = x[k].s; }\n    } else {\n        for (int k = 0; k < 4; ++k) { aux[k] = x[k].d; aux[k+4] = x[k].m; }\n    }\n    return vld1q_f16((const float16_t *)aux);\n}\n\nstruct Q4LegacyBits {\n    template <typename Block>\n    inline void prepare(const Block * x) {\n        for (int i = 0; i < 4; ++i) {\n            auto q4bits = vld1q_u8(x[i].qs);\n            b[2*i+0] = vreinterpretq_s8_u8(vandq_u8(q4bits, m4b));\n            b[2*i+1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits, 4));\n        }\n    }\n    inline void prepare1(const uint8_t * qs, int8x16_t * q) const {\n        auto q4bits = vld1q_u8(qs);\n        q[0] = vreinterpretq_s8_u8(vandq_u8(q4bits, m4b));\n        q[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits, 4));\n    }\n    inline void prepare1(const uint8_t * qs) {\n        prepare1(qs, b);\n    }\n    const uint8x16_t m4b = vdupq_n_u8(0xf);\n    int8x16_t b[8];\n};\n\n// One would think this commented out version would do better than the one below\n// because it offers more opportunities to execute instructions in parallel.\n// Instead, it runs significantly slower. Why? 
If the compiler is running out of vector registers\n// cannot it just do the sequential version below on its own?\n//inline int32x4_t sum_4_blocks(const int8x16_t * b, const int8_t * qs) {\n//    const auto q8b_1 = vld1q_s8_x2(qs + 0);\n//    auto p12 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[0], q8b_1.val[0]), b[1], q8b_1.val[1]);\n//    const auto q8b_2 = vld1q_s8_x2(qs + 32);\n//    auto p34 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[2], q8b_2.val[0]), b[3], q8b_2.val[1]);\n//    auto p1234 = vpaddq_s32(p12, p34);\n//    const auto q8b_3 = vld1q_s8_x2(qs + 64);\n//    auto p56 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[4], q8b_3.val[0]), b[5], q8b_3.val[1]);\n//    const auto q8b_4 = vld1q_s8_x2(qs + 96);\n//    auto p78 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[6], q8b_4.val[0]), b[7], q8b_4.val[1]);\n//    return vpaddq_s32(p1234, vpaddq_s32(p56, p78));\n//}\n\ninline int32x4_t sum_4_blocks(const int8x16_t * b, const int8_t * qs) {\n    auto q8b = vld1q_s8_x2(qs + 0);\n    auto p12 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[0], q8b.val[0]), b[1], q8b.val[1]);\n    q8b = vld1q_s8_x2(qs + 32);\n    auto p34 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[2], q8b.val[0]), b[3], q8b.val[1]);\n    auto p1234 = vpaddq_s32(p12, p34);\n    q8b = vld1q_s8_x2(qs + 64);\n    auto p56 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[4], q8b.val[0]), b[5], q8b.val[1]);\n    q8b = vld1q_s8_x2(qs + 96);\n    auto p78 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[6], q8b.val[0]), b[7], q8b.val[1]);\n    return vpaddq_s32(p1234, vpaddq_s32(p56, p78));\n}\n\ntemplate <int nrc> struct Q80 {\n\n    constexpr static int nrc_y = nrc;\n\n    Q80(const DataInfo& info) {\n        for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_0 *)info.src1_row(iy);\n    }\n\n    inline const int8_t * quant_data(int iy, int i) const {\n        const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y[iy] + i;\n        return y4->qs;\n    }\n\n    inline 
float16x4_t load_scales(int iy, int i) const {\n        const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y[iy] + i;\n        return vld1_f16((const float16_t *)y4->d);\n    }\n\n    template <typename Dequantizer>\n    inline void process_scales(int i, Dequantizer& deq, float16x4_t * sc16, float32x4_t * /*acc*/) const {\n        auto qx_scales = deq.new_block(i);\n        for (int iy = 0; iy < nrc; ++iy) {\n            auto q8_scales = load_scales(iy, i);\n            sc16[iy] = vmul_f16(qx_scales, q8_scales);\n        }\n    }\n\n    template <typename Dequantizer>\n    inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const {\n        deq.prepare1(i);\n        float d = GGML_FP16_TO_FP32(deq.x[i].d);\n        for (int iy = 0; iy < nrc; ++iy) {\n            auto q8b = vld1q_s8_x2(y[iy][i].qs);\n            auto p = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), deq.bits.b[0], q8b.val[0]), deq.bits.b[1], q8b.val[1]);\n            acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*GGML_FP16_TO_FP32(y[iy][i].d)), vcvtq_f32_s32(p));\n        }\n    }\n\n    const block_q8_0 * y[nrc_y];\n};\n\ntemplate <int nrc> struct Q81 {\n\n    constexpr static int nrc_y = nrc;\n\n    Q81(const DataInfo& info) {\n        for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_1 *)info.src1_row(iy);\n    }\n\n    inline const int8_t * quant_data(int iy, int i) const {\n        const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y[iy] + i;\n        return y4->qs;\n    }\n\n    inline float16x8_t load_scales(int iy, int i) const {\n        const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y[iy] + i;\n        return vld1q_f16((const float16_t *)y4->d);\n    }\n\n    template <typename Dequantizer>\n    inline void process_scales(int i, Dequantizer& deq, float16x4_t * sc16, float32x4_t * acc) const {\n        auto qx_scales = deq.new_block(i);\n        for (int iy = 0; iy < nrc; ++iy) {\n            auto q8_scales = load_scales(iy, i);\n            auto m = 
vmul_f16(vget_high_f16(qx_scales), vget_high_f16(q8_scales));\n            acc[iy] = vaddq_f32(acc[iy], vcvt_f32_f16(m));\n            sc16[iy] = vmul_f16(vget_low_f16(qx_scales), vget_low_f16(q8_scales));\n        }\n    }\n\n    template <typename Dequantizer>\n    inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const {\n        deq.prepare1(i);\n        float d = GGML_FP16_TO_FP32(deq.x[i].d), m = 0.25f*GGML_FP16_TO_FP32(deq.x[i].m);\n        for (int iy = 0; iy < nrc; ++iy) {\n            auto q8b = vld1q_s8_x2(y[iy][i].qs);\n            auto p = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), deq.bits.b[0], q8b.val[0]), deq.bits.b[1], q8b.val[1]);\n            acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*GGML_FP16_TO_FP32(y[iy][i].d)), vcvtq_f32_s32(p));\n            acc[iy] = vaddq_f32(acc[iy], vdupq_n_f32(m*GGML_FP16_TO_FP32(y[iy][i].s)));\n        }\n    }\n\n    const block_q8_1 * y[nrc_y];\n};\n\ntemplate <typename block_q>\nstruct BaseLegacyDequantizer {\n\n    BaseLegacyDequantizer(const void * vx, size_t bx) : vx(vx), x(nullptr), bx(bx) {}\n\n    inline void new_row(int ix) { x = (const block_q *)((const char *)vx + bx*ix); }\n\n    Q4LegacyBits bits;\n\n    const void * vx;\n    const block_q * x;\n    size_t bx;\n};\n\nstruct DequantizerQ40 final : public BaseLegacyDequantizer<block_q4_0> {\n\n    DequantizerQ40(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}\n\n    inline void prepare1(int i, int8x16_t * q) const {\n        bits.prepare1(x[i].qs, q);\n        q[0] = vaddq_s8(q[0], m8);\n        q[1] = vaddq_s8(q[1], m8);\n    }\n    inline void prepare1(int i) {\n        prepare1(i, bits.b);\n    }\n\n    inline float16x4_t new_block(int i) {\n        ggml_half aux[4];\n        for (int k = 0; k < 4; ++k) {\n            aux[k] = x[4*i+k].d;\n            prepare1(4*i+k, bits.b + 2*k);\n        }\n        return vld1_f16((const float16_t *)aux);\n    }\n\n    const int8x16_t m8 = vdupq_n_s8(-8);\n    //ggml_half 
aux[4];\n};\n\nstruct DequantizerQ41 : public BaseLegacyDequantizer<block_q4_1> {\n\n    DequantizerQ41(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}\n\n    inline void prepare1(int i) {\n        bits.prepare1(x[i].qs);\n    }\n\n    inline float16x8_t new_block(int i) {\n        uint32_t aux32[4];\n        const uint32_t * s32 = (const uint32_t *)&x[4*i].d;\n        for (int k = 0; k < 4; ++k) {\n            aux32[k] = *s32; s32 += sizeof(block_q4_1)/4;\n            bits.prepare1(x[4*i+k].qs, bits.b + 2*k);\n        }\n        return vreinterpretq_f16_u8(vqtbl1q_u8(vld1q_u8((const uint8_t *)aux32), vreinterpretq_u8_u64(shuffle)));\n    }\n    // Leaving this commented out attempt to be reminded that I already tried this.\n    // It has basically the same performance as the version above.\n    //inline float16x8_t new_block(int i) {\n    //    uint32x4_t scales = {};\n    //    const block_q4_1 * xi = x + 4*i;\n    //    const uint32_t * s32 = (const uint32_t *)&xi->d;\n    //    scales = vsetq_lane_u32(*s32, scales, 0); s32 += sizeof(block_q4_1)/4;\n    //    bits.prepare1(xi[0].qs, bits.b + 0);\n    //    scales = vsetq_lane_u32(*s32, scales, 1); s32 += sizeof(block_q4_1)/4;\n    //    bits.prepare1(xi[1].qs, bits.b + 2);\n    //    scales = vsetq_lane_u32(*s32, scales, 2); s32 += sizeof(block_q4_1)/4;\n    //    bits.prepare1(xi[2].qs, bits.b + 4);\n    //    scales = vsetq_lane_u32(*s32, scales, 3);\n    //    bits.prepare1(xi[3].qs, bits.b + 6);\n    //    return vreinterpretq_f16_u8(vqtbl1q_u8(vreinterpretq_u8_u32(scales), vreinterpretq_u8_u64(shuffle)));\n    //}\n\n    const uint64x2_t shuffle = {0x0d0c090805040100, 0x0f0e0b0a07060302};\n};\n\nstruct HighBit5Legacy {\n    inline uint8x16_t to_bytes(const uint8_t * qh) const {\n        uint8x16_t h = vqtbl1q_u8(vreinterpretq_u8_u16(vdupq_n_u16(*(const uint16_t *)qh)), shuffle);\n        return vceqq_u8(vandq_u8(h, vreinterpretq_u8_u64(mask)), vreinterpretq_u8_u64(mask));\n    }\n    inline 
uint8x16_t to_negated_bytes(const uint8_t * qh) const {\n        uint8x16_t h = vqtbl1q_u8(vreinterpretq_u8_u16(vdupq_n_u16(*(const uint16_t *)qh)), shuffle);\n        return vceqq_u8(vandq_u8(h, vreinterpretq_u8_u64(mask)), vdupq_n_u8(0));\n    }\n    const uint64x2_t mask = vdupq_n_u64(0x8040201008040201);\n    const uint8x16_t shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1));\n};\n\nstruct DequantizerQ50 final : public BaseLegacyDequantizer<block_q5_0> {\n\n    DequantizerQ50(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}\n\n    inline void prepare1(int i, int8x16_t * q) const {\n        bits.prepare1(x[i].qs, q);\n        auto qh = x[i].qh;\n        q[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[0]), vandq_u8(mh, hbits.to_negated_bytes(qh+0))));\n        q[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[1]), vandq_u8(mh, hbits.to_negated_bytes(qh+2))));\n    }\n    inline void prepare1(int i) {\n        prepare1(i, bits.b);\n    }\n\n    inline float16x4_t new_block(int i) {\n        ggml_half aux[4];\n        for (int k = 0; k < 4; ++k) {\n            aux[k] = x[4*i+k].d;\n            prepare1(4*i+k, bits.b + 2*k);\n        }\n        return vld1_f16((const float16_t *)aux);\n    }\n\n    HighBit5Legacy hbits;\n\n    const uint8x16_t mh = vdupq_n_u8(0xf0);\n\n};\n\nstruct DequantizerQ80 final : public BaseLegacyDequantizer<block_q8_0> {\n\n    DequantizerQ80(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}\n\n    inline void prepare1(int i) {\n        bits.b[0] = vld1q_s8(x[i].qs);\n        bits.b[1] = vld1q_s8(x[i].qs+16);\n    }\n\n    inline float16x4_t new_block(int i) {\n        ggml_half aux[4];\n        for (int k = 0; k < 4; ++k) {\n            aux[k] = x[4*i+k].d;\n            bits.b[2*k+0] = vld1q_s8(x[4*i+k].qs);\n            bits.b[2*k+1] = vld1q_s8(x[4*i+k].qs+16);\n        }\n        return vld1_f16((const float16_t *)aux);\n    }\n\n};\n\nstruct DequantizerQ51 final : public 
BaseLegacyDequantizer<block_q5_1> {\n\n    DequantizerQ51(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}\n\n    inline void prepare1(int i, int8x16_t * q) const {\n        bits.prepare1(x[i].qs, q);\n        auto qh = x[i].qh;\n        q[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[0]), vandq_u8(mh, hbits.to_bytes(qh+0))));\n        q[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[1]), vandq_u8(mh, hbits.to_bytes(qh+2))));\n    }\n    inline void prepare1(int i) {\n        bits.prepare1(x[i].qs, bits.b);\n    }\n\n    inline float16x8_t new_block(int i) {\n        uint32_t aux32[4];\n        const uint32_t * s32 = (const uint32_t *)&x[4*i].d;\n        for (int k = 0; k < 4; ++k) {\n            aux32[k] = *s32; s32 += sizeof(block_q5_1)/4;\n            prepare1(4*i+k, bits.b + 2*k);\n        }\n        return vreinterpretq_f16_u8(vqtbl1q_u8(vld1q_u8((const uint8_t *)aux32), vreinterpretq_u8_u64(shuffle)));\n    }\n\n    HighBit5Legacy hbits;\n\n    const uint8x16_t mh = vdupq_n_u8(0x10);\n    const uint64x2_t shuffle = {0x0d0c090805040100, 0x0f0e0b0a07060302};\n\n};\n\ntemplate <typename Dequantizer, typename Q8>\ninline void sum_4(int i, Dequantizer& deq, const Q8& q8, const float16x4_t * sc16, float32x4_t * acc) {\n    for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n        auto pall = sum_4_blocks(deq.bits.b, q8.quant_data(iy, i));\n        auto scale = vcvt_f32_f16(sc16[iy]);\n        acc[iy] = vmlaq_f32(acc[iy], scale, vcvtq_f32_s32(pall));\n    }\n}\n\ntemplate <typename Dequantizer, typename Q8>\ninline void mul_mat_qX_Y_q8_Y(int n, Dequantizer& deq, Q8& q8, const DataInfo& info, int nrc_x) {\n    const int nb = n / QK4_1;\n\n    float16x4_t sc16[Q8::nrc_y];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        deq.new_row(ix);\n\n        float32x4_t acc[Q8::nrc_y];\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);\n\n        for (int i = 0; i < nb/4; ++i) {\n            q8.process_scales(i, deq, sc16, 
acc);\n            sum_4(i, deq, q8, sc16, acc);\n        }\n        for (int i = 4*(nb/4); i < nb; ++i) {\n            q8.process_1_block(i, deq, acc);\n        }\n\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            info.store(ix, iy, vaddvq_f32(acc[iy]));\n        }\n    }\n}\n\ntemplate <typename Dequantizer, typename Q8>\ninline void mul_mat_qX_Y_q8_Y_1(int n, Dequantizer& deq1, Dequantizer& deq2, Q8& q8, const DataInfo& info, int nrc_x) {\n    const int nb = n / QK4_1;\n\n    float16x4_t sc16[2];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        deq1.new_row(ix);\n        deq2.new_row(ix);\n\n        float32x4_t acc[2] = { vdupq_n_f32(0.f), vdupq_n_f32(0.f) };\n\n        for (int i = 0; i < nb/8; ++i) {\n            q8.process_scales(2*i+0, deq1, sc16+0, acc+0);\n            q8.process_scales(2*i+1, deq2, sc16+1, acc+1);\n            sum_4(2*i+0, deq1, q8, sc16+0, acc+0);\n            sum_4(2*i+1, deq2, q8, sc16+1, acc+1);\n        }\n        for (int i = 2*(nb/8); i < nb/4; ++i) {\n            q8.process_scales(i, deq1, sc16, acc);\n            sum_4(i, deq1, q8, sc16, acc);\n        }\n        for (int i = 4*(nb/4); i < nb; ++i) {\n            q8.process_1_block(i, deq1, acc);\n        }\n\n        info.store(ix, 0, vaddvq_f32(vaddq_f32(acc[0], acc[1])));\n    }\n}\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void IQK_NOINLINE mul_mat_qX_1_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    Q81<nrc_y> q8(info);\n    if constexpr (nrc_y == 1) {\n        Dequantizer deq1(vx, bx), deq2(vx, bx);\n        mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);\n    } else {\n        Dequantizer deq(vx, bx);\n        mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x);\n    }\n}\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void IQK_NOINLINE mul_mat_qX_0_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    Q80<nrc_y> q8(info);\n    if constexpr (nrc_y == 1) {\n        Dequantizer deq1(vx, 
bx), deq2(vx, bx);\n        mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);\n    } else {\n        Dequantizer deq(vx, bx);\n        mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x);\n    }\n}\n\ntemplate <typename Dequantizer>\nstatic void IQK_NOINLINE mul_mat_qX_1_q8_1_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    Dequantizer deq1(vx, bx), deq2(vx, bx);\n    Q81<1> q8(info);\n    mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);\n}\n\ntemplate <typename Dequantizer>\nstatic void IQK_NOINLINE mul_mat_qX_0_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    Dequantizer deq1(vx, bx), deq2(vx, bx);\n    Q80<1> q8(info);\n    mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);\n}\n\ntemplate <typename Dequantizer> void MulMat::set_functions(MulMat& m) {\n    if constexpr (std::is_same_v<Dequantizer, DequantizerQ40> || std::is_same_v<Dequantizer, DequantizerQ50> ||\n                  std::is_same_v<Dequantizer, DequantizerQ80>) {\n        m.funcs[0] = mul_mat_qX_0_q8_0<Dequantizer, 1>;\n        m.funcs[1] = mul_mat_qX_0_q8_0<Dequantizer, 2>;\n        m.funcs[2] = mul_mat_qX_0_q8_0<Dequantizer, 3>;\n        m.funcs[3] = mul_mat_qX_0_q8_0<Dequantizer, 4>;\n        m.funcs[4] = mul_mat_qX_0_q8_0<Dequantizer, 5>;\n        m.funcs[5] = mul_mat_qX_0_q8_0<Dequantizer, 6>;\n        m.funcs[6] = mul_mat_qX_0_q8_0<Dequantizer, 7>;\n        m.funcs[7] = mul_mat_qX_0_q8_0<Dequantizer, 8>;\n    }\n    else if constexpr (std::is_same_v<Dequantizer, DequantizerQ41> || std::is_same_v<Dequantizer, DequantizerQ51>) {\n        m.funcs[0] = mul_mat_qX_1_q8_1<Dequantizer, 1>;\n        m.funcs[1] = mul_mat_qX_1_q8_1<Dequantizer, 2>;\n        m.funcs[2] = mul_mat_qX_1_q8_1<Dequantizer, 3>;\n        m.funcs[3] = mul_mat_qX_1_q8_1<Dequantizer, 4>;\n        m.funcs[4] = mul_mat_qX_1_q8_1<Dequantizer, 5>;\n        m.funcs[5] = mul_mat_qX_1_q8_1<Dequantizer, 6>;\n        m.funcs[6] = mul_mat_qX_1_q8_1<Dequantizer, 7>;\n        
m.funcs[7] = mul_mat_qX_1_q8_1<Dequantizer, 8>;\n    }\n    else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2XXS> || std::is_same_v<Dequantizer, DequantizerIQ3XXS>) {\n        m.funcs[0] = mul_mat_qX_K_q8_K_IQXXS<1, Dequantizer>;\n        m.funcs[1] = mul_mat_qX_K_q8_K_IQXXS<2, Dequantizer>;\n        m.funcs[2] = mul_mat_qX_K_q8_K_IQXXS<3, Dequantizer>;\n        m.funcs[3] = mul_mat_qX_K_q8_K_IQXXS<4, Dequantizer>;\n        m.funcs[4] = mul_mat_qX_K_q8_K_IQXXS<5, Dequantizer>;\n        m.funcs[5] = mul_mat_qX_K_q8_K_IQXXS<6, Dequantizer>;\n        m.funcs[6] = mul_mat_qX_K_q8_K_IQXXS<7, Dequantizer>;\n        m.funcs[7] = mul_mat_qX_K_q8_K_IQXXS<8, Dequantizer>;\n    }\n    else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2S> ||\n                       std::is_same_v<Dequantizer, DequantizerIQ3S> ||\n                       std::is_same_v<Dequantizer, DequantizerIQ2XS>) {\n        m.funcs[0] = mul_mat_qX_K_q8_K_IQ<1, Dequantizer>;\n        m.funcs[1] = mul_mat_qX_K_q8_K_IQ<2, Dequantizer>;\n        m.funcs[2] = mul_mat_qX_K_q8_K_IQ<3, Dequantizer>;\n        m.funcs[3] = mul_mat_qX_K_q8_K_IQ<4, Dequantizer>;\n        m.funcs[4] = mul_mat_qX_K_q8_K_IQ<5, Dequantizer>;\n        m.funcs[5] = mul_mat_qX_K_q8_K_IQ<6, Dequantizer>;\n        m.funcs[6] = mul_mat_qX_K_q8_K_IQ<7, Dequantizer>;\n        m.funcs[7] = mul_mat_qX_K_q8_K_IQ<8, Dequantizer>;\n    }\n    else {\n        m.funcs[0] = mul_mat_qX_K_q8_K_T<1, Dequantizer>;\n        m.funcs[1] = mul_mat_qX_K_q8_K_T<2, Dequantizer>;\n        m.funcs[2] = mul_mat_qX_K_q8_K_T<3, Dequantizer>;\n        m.funcs[3] = mul_mat_qX_K_q8_K_T<4, Dequantizer>;\n        m.funcs[4] = mul_mat_qX_K_q8_K_T<5, Dequantizer>;\n        m.funcs[5] = mul_mat_qX_K_q8_K_T<6, Dequantizer>;\n        m.funcs[6] = mul_mat_qX_K_q8_K_T<7, Dequantizer>;\n        m.funcs[7] = mul_mat_qX_K_q8_K_T<8, Dequantizer>;\n    }\n}\n\nbool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int Ny) {\n    row_size_q8 = 
ggml_row_size(GGML_TYPE_Q8_K, ne00);\n\n    (void)Ny;\n    // Uncommenting this would disable iqk_mul_mat for matrix x vector multiplications.\n    //if (Ny == 1 && (typeA == GGML_TYPE_IQ2_XXS || typeA == GGML_TYPE_IQ2_XS || typeA == GGML_TYPE_IQ2_S ||\n    //                typeA == GGML_TYPE_IQ3_XXS || typeA == GGML_TYPE_IQ3_S)) return false;\n\n    switch (typeA) {\n        case GGML_TYPE_Q2_K:\n            MulMat::set_functions<DequantizerQ2K>(m);\n            break;\n        case GGML_TYPE_Q3_K:\n            MulMat::set_functions<DequantizerQ3K>(m);\n            break;\n        case GGML_TYPE_Q4_K:\n            MulMat::set_functions<DequantizerQ4K>(m);\n            break;\n        case GGML_TYPE_Q5_K:\n            MulMat::set_functions<DequantizerQ5K>(m);\n            break;\n        case GGML_TYPE_Q6_K:\n            MulMat::set_functions<DequantizerQ6K>(m);\n            break;\n        case GGML_TYPE_IQ4_XS:\n            MulMat::set_functions<DequantizerIQ4XS>(m);\n            break;\n        case GGML_TYPE_IQ3_S:\n            MulMat::set_functions<DequantizerIQ3S>(m);\n            break;\n        case GGML_TYPE_IQ3_XXS:\n            MulMat::set_functions<DequantizerIQ3XXS>(m);\n            break;\n        case GGML_TYPE_IQ2_S:\n            MulMat::set_functions<DequantizerIQ2S>(m);\n            break;\n        case GGML_TYPE_IQ2_XS:\n            MulMat::set_functions<DequantizerIQ2XS>(m);\n            break;\n        case GGML_TYPE_IQ2_XXS:\n            MulMat::set_functions<DequantizerIQ2XXS>(m);\n            break;\n        case GGML_TYPE_Q4_0:\n            MulMat::set_functions<DequantizerQ40>(m);\n            row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);\n            break;\n        case GGML_TYPE_Q4_1:\n            MulMat::set_functions<DequantizerQ41>(m);\n            row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00);\n            break;\n        case GGML_TYPE_Q5_0:\n            MulMat::set_functions<DequantizerQ50>(m);\n            
row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);\n            break;\n        case GGML_TYPE_Q5_1:\n            MulMat::set_functions<DequantizerQ51>(m);\n            row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00);\n            break;\n        case GGML_TYPE_Q8_0:\n            MulMat::set_functions<DequantizerQ80>(m);\n            row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);\n            break;\n        default:\n            return false;\n    }\n    return true;\n}\n\n}\n\n#endif // __x86_64__ or __aarch64__"
  },
  {
    "path": "third_party/llamafile/iqk_mul_mat_amd_avx2.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/iqk_mul_mat_amd_avx2.cpp\n// Copyright 2024 Iwan Kawrakow.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#if defined(__x86_64__) || defined(_M_X64)\n#include \"iqk_mul_mat.inc\"\n#endif  // __x86_64__\n"
  },
  {
    "path": "third_party/llamafile/iqk_mul_mat_amd_zen4.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/iqk_mul_mat_amd_zen4.cpp\n// Copyright 2024 Iwan Kawrakow.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#if defined(__x86_64__) || defined(_M_X64)\n#define iqk_mul_mat iqk_mul_mat_zen4\n#define iqk_mul_mat_moe iqk_mul_mat_moe_zen4\n#include \"iqk_mul_mat.inc\"\n#endif  // __x86_64__\n"
  },
  {
    "path": "third_party/llamafile/iqk_mul_mat_arm.inc",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/iqk_mul_mat.inc\n// Copyright 2024 Iwan Kawrakow.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-\n// vi: set et ft=cpp fenc=utf-8 :vi\n//\n// Copyright 2024 Iwan Kawrakow\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include <cstring>\n#include <type_traits>\n#if defined __x86_64__ || defined __aarch64__ || defined(_M_X64)\n\n#include \"llama.cpp/ggml-impl.h\"\n#include \"llama.cpp/ggml-quants.h\"\n#include \"sgemm.h\"\n\n// For i-quants, I had to explicitly specify which\n// functions to inline / not inline (at least for some\n// of the functions), else performance would be significantly\n// lower. 
This is worrisome as things can change with,\n// e.g., a different compiler version or running on a different\n// CPU.\n#ifdef _MSC_VER\n#define IQK_NOINLINE __declspec(noinline)\n#define IQK_ALWAYS_INLINE inline\n#else\n#define IQK_NOINLINE __attribute__((__noinline__))\n#define IQK_ALWAYS_INLINE __attribute__((always_inline))\n#endif\n\n#define GGML_COMMON_IMPL_C\n#include \"llama.cpp/ggml-common.h\"\n\n// clang-format off\n\n// This matrix-vector and matrix-matrix multiplication implementation\n// for legacy quants, k-quants and i-quants makes prompt processing 150-200%\n// (legacy and k-quants) or 250-400% (i-quants) faster\n// compared to mainline llama.cpp (and llamafile).\n// It provides implementations for ARM_NEON (all quants) and AVX2\n// (all quants except sub-4 bit i-quants).\n//\n// The main idea is that unpacking the quants and the block scales to\n// be ready for dot products with the corresponding Q8_Y quants\n// takes time (here 'Y' stands for K, 0, or 1, depending on quantization type).\n// Hence, if we are performing a QX x Q8_Y matrix-matrix\n// multiplication (as needed for prompt processing), we can get\n// a significant speedup by reusing the unpacked QX quants and scales\n// for multiplication with several Q8_Y columns. 
We also achieve fewer\n// loads from memory, which is the main purpose of tiling in general\n// purpose matrix multiplication packages.\n\n#include <utility>\n#include <array>\n\n#endif\n\nnamespace {\n\ntypedef struct {\n    int32_t i1;\n    int32_t i2;\n} mmid_row_mapping;\n\nstruct DataInfo {\n    float       * s;\n    const char  * cy;\n    size_t        bs;\n    size_t        by;\n    int           cur_y = 0;\n    int           ne11;\n    const mmid_row_mapping * row_mapping = nullptr;\n    size_t        bs2 = 0;\n\n    inline const char * src1_row(int iy) const {\n        if (!row_mapping) return cy + (cur_y + iy)*by;\n        int i11 = row_mapping[cur_y + iy].i1 % ne11;\n        int i12 = row_mapping[cur_y + iy].i2;\n        return cy + (i11 + i12*ne11)*by;\n    }\n\n    inline void store(int ix, int iy, float result) const {\n        *(dst_row(iy) + ix) = result;\n        //dst_row(iy)[ix] = result;\n    }\n    inline float * dst_row(int iy) const {\n        if (!row_mapping) return s + (cur_y + iy)*bs;\n        int i12 = row_mapping[cur_y + iy].i2;\n        int i1  = row_mapping[cur_y + iy].i1;\n        int i2  = i12;\n        return s + i1*bs + i2*bs2;\n    }\n};\n\ntypedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x);\n\nstruct MulMat {\n    std::array<mul_mat_t, 8> funcs = {};\n    //inline void mul_mat_NxM(int n, const void * vx, size_t bx, DataInfo& info, int nrc_x, int nrc_y) {\n    IQK_NOINLINE void mul_mat_NxM(int n, const void * vx, size_t bx, DataInfo& info, int nrc_x, int nrc_y) {\n        constexpr int k_x_step = 64; // This works best on my Ryzen-7950X and M2 Max CPUs (but differences to other tile size are small)\n        int n_step = (nrc_y - info.cur_y)/funcs.size();\n        if (n_step > 0) {\n            for (int ix = 0; ix < nrc_x; ix += k_x_step) {\n                auto this_info = info;\n                this_info.s += ix;\n                int this_nrc_x = ix + k_x_step <= nrc_x ? 
k_x_step : nrc_x - ix;\n                for (int iy = 0; iy < n_step; ++iy) {\n                    funcs.back()(n, (const void *)((const char *)vx + ix*bx), bx, this_info, this_nrc_x);\n                    this_info.cur_y += funcs.size();\n                }\n            }\n            info.cur_y += funcs.size() * n_step;\n        }\n        int n_left = nrc_y - info.cur_y;\n        if (n_left > 0) {\n            funcs[n_left-1](n, vx, bx, info, nrc_x);\n        }\n    }\n    static IQK_NOINLINE bool set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int Ny);\nprivate:\n    template <typename Dequantizer> static IQK_NOINLINE void set_functions(MulMat& m);\n};\n\ninline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) {\n    const uint16_t * scales = (const uint16_t *)scales8;\n    const uint32_t a0 = scales[0] | (scales[1] << 16);\n    const uint32_t a1 = scales[2] | (scales[3] << 16);\n    const uint32_t a2 = scales[4] | (scales[5] << 16);\n    aux32[3] = ((a2 >> 4) & 0x0f0f0f0f) | ((a1 >> 2) & 0x30303030);\n    aux32[1] = ((a2 >> 0) & 0x0f0f0f0f) | ((a0 >> 2) & 0x30303030);\n    aux32[2] = a1 & 0x3f3f3f3f;\n    aux32[0] = a0 & 0x3f3f3f3f;\n}\n\nconst uint64_t keven_signs[128] = {\n    0x0101010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0x010101010101ffff,\n    0xff01010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0xff01010101ffffff,\n    0xff010101ff010101, 0x01010101ff0101ff, 0x01010101ff01ff01, 0xff010101ff01ffff,\n    0x01010101ffff0101, 0xff010101ffff01ff, 0xff010101ffffff01, 0x01010101ffffffff,\n    0xff0101ff01010101, 0x010101ff010101ff, 0x010101ff0101ff01, 0xff0101ff0101ffff,\n    0x010101ff01ff0101, 0xff0101ff01ff01ff, 0xff0101ff01ffff01, 0x010101ff01ffffff,\n    0x010101ffff010101, 0xff0101ffff0101ff, 0xff0101ffff01ff01, 0x010101ffff01ffff,\n    0xff0101ffffff0101, 0x010101ffffff01ff, 0x010101ffffffff01, 0xff0101ffffffffff,\n    0xff01ff0101010101, 0x0101ff01010101ff, 0x0101ff010101ff01, 0xff01ff010101ffff,\n    
0x0101ff0101ff0101, 0xff01ff0101ff01ff, 0xff01ff0101ffff01, 0x0101ff0101ffffff,\n    0x0101ff01ff010101, 0xff01ff01ff0101ff, 0xff01ff01ff01ff01, 0x0101ff01ff01ffff,\n    0xff01ff01ffff0101, 0x0101ff01ffff01ff, 0x0101ff01ffffff01, 0xff01ff01ffffffff,\n    0x0101ffff01010101, 0xff01ffff010101ff, 0xff01ffff0101ff01, 0x0101ffff0101ffff,\n    0xff01ffff01ff0101, 0x0101ffff01ff01ff, 0x0101ffff01ffff01, 0xff01ffff01ffffff,\n    0xff01ffffff010101, 0x0101ffffff0101ff, 0x0101ffffff01ff01, 0xff01ffffff01ffff,\n    0x0101ffffffff0101, 0xff01ffffffff01ff, 0xff01ffffffffff01, 0x0101ffffffffffff,\n    0xffff010101010101, 0x01ff0101010101ff, 0x01ff01010101ff01, 0xffff01010101ffff,\n    0x01ff010101ff0101, 0xffff010101ff01ff, 0xffff010101ffff01, 0x01ff010101ffffff,\n    0x01ff0101ff010101, 0xffff0101ff0101ff, 0xffff0101ff01ff01, 0x01ff0101ff01ffff,\n    0xffff0101ffff0101, 0x01ff0101ffff01ff, 0x01ff0101ffffff01, 0xffff0101ffffffff,\n    0x01ff01ff01010101, 0xffff01ff010101ff, 0xffff01ff0101ff01, 0x01ff01ff0101ffff,\n    0xffff01ff01ff0101, 0x01ff01ff01ff01ff, 0x01ff01ff01ffff01, 0xffff01ff01ffffff,\n    0xffff01ffff010101, 0x01ff01ffff0101ff, 0x01ff01ffff01ff01, 0xffff01ffff01ffff,\n    0x01ff01ffffff0101, 0xffff01ffffff01ff, 0xffff01ffffffff01, 0x01ff01ffffffffff,\n    0x01ffff0101010101, 0xffffff01010101ff, 0xffffff010101ff01, 0x01ffff010101ffff,\n    0xffffff0101ff0101, 0x01ffff0101ff01ff, 0x01ffff0101ffff01, 0xffffff0101ffffff,\n    0xffffff01ff010101, 0x01ffff01ff0101ff, 0x01ffff01ff01ff01, 0xffffff01ff01ffff,\n    0x01ffff01ffff0101, 0xffffff01ffff01ff, 0xffffff01ffffff01, 0x01ffff01ffffffff,\n    0xffffffff01010101, 0x01ffffff010101ff, 0x01ffffff0101ff01, 0xffffffff0101ffff,\n    0x01ffffff01ff0101, 0xffffffff01ff01ff, 0xffffffff01ffff01, 0x01ffffff01ffffff,\n    0x01ffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0x01ffffffff01ffff,\n    0xffffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0xffffffffffffffff,\n};\n\n}\n\nbool iqk_mul_mat(long Nx, long Ny, 
long ne00, int typeA, const void * A, const void * B,\n        float * C, long stride_C, int ith, int nth) {\n\n    MulMat mm;\n    int row_size_q8;\n    if (!MulMat::set_mul_mat(typeA, ne00, mm, row_size_q8, Ny)) {\n        return false;\n    }\n\n    auto row_size_qx = ggml_row_size((ggml_type)typeA, ne00);\n\n    auto nrc_x = (Nx + nth - 1)/nth;\n    auto first_x = ith*nrc_x;\n    if (first_x + nrc_x > Nx) nrc_x = Nx - first_x;\n\n    DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, (size_t)row_size_q8, 0, 1, nullptr, 0};\n\n    mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny);\n\n    return true;\n}\n\nbool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int typeA, const void * A, const void * B,\n        float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) {\n    const mmid_row_mapping * row_mapping = (const mmid_row_mapping *)vrow_mapping;\n    assert(row_mapping != nullptr);\n\n    MulMat mm;\n    int row_size_q8;\n    if (!MulMat::set_mul_mat(typeA, ne00, mm, row_size_q8, Ny)) {\n        return false;\n    }\n    int row_size_qx = ggml_row_size((ggml_type)typeA, ne00);\n    int nrc_x = (Nx + nth - 1)/nth;\n    int first_x = ith*nrc_x;\n    if (first_x + nrc_x > Nx) nrc_x = Nx - first_x;\n    DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), (size_t)row_size_q8, 0, ne11, row_mapping, nb2/sizeof(float)};\n    mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny);\n    return true;\n}\n\n#if defined __x86_64__ || defined(_M_X64)\n\n#if defined HAVE_FANCY_SIMD\n    #undef HAVE_FANCY_SIMD\n#endif\n#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__)\n    #define HAVE_FANCY_SIMD\n#endif\n\nnamespace {\n\ninline float hsum_float_4(__m128 x) {\n    x = _mm_add_ps(x, _mm_movehl_ps(x, x));\n    x = _mm_add_ss(x, _mm_movehdup_ps(x));\n    return 
_mm_cvtss_f32(x);\n}\ninline float hsum_float_8(__m256 x) {\n    return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1)));\n}\n\n#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)\n\n\ntemplate <int nrc, typename block_q8 = block_q8_K> struct Q8 {\n\n    constexpr static int nrc_y = nrc;\n\n    Q8(const DataInfo& info) {\n        for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy);\n    }\n\n#ifdef HAVE_FANCY_SIMD\n    inline __m512i load_quants(int iy, int i, int j) const { return _mm512_loadu_si512((const __m512i*)y[iy][i].qs + j); }\n#else\n    inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].qs + j); }\n#endif\n    inline __m256i load_bsums(int iy, int i) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].bsums); }\n    inline float scale(int iy, int i) const { return y[iy][i].d; }\n\n    const block_q8 * y[nrc_y];\n};\n\n// Handles q4_K and q5_K scales/mins\nstruct Scales8K {\n    template <typename Q8>\n    inline __m256i process_mins_and_scales(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) {\n        make_q4_scales(data, utmp);\n        const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));\n        const __m128i mins128 = _mm256_extracti128_si256(mins_and_scales, 1);\n        accum_mins(mins128, q8, i, c, accd);\n        const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);\n        return MM256_SET_M128I(sc128, sc128);\n    }\n#ifdef HAVE_FANCY_SIMD\n    template <typename Q8>\n    inline __m512i process_mins_and_scales_64(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) {\n        auto scales = process_mins_and_scales(data, c, i, q8, accd);\n        return _mm512_inserti32x8(_mm512_castsi256_si512(scales), scales, 1);\n    }\n#endif\n    template <typename Q8>\n    inline void accum_mins(const 
__m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const {\n        const __m256i mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, shuffles[1]), _mm_shuffle_epi8(mins128, shuffles[0]));\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            const __m256i q8s = q8.load_bsums(iy, i);\n            const __m256i prod = _mm256_madd_epi16(mins, q8s);\n            accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(c*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]);\n        }\n    }\n#ifdef HAVE_FANCY_SIMD\n    const __m512i shuffles512[2] = {\n        _mm512_set_epi64(0x0706070607060706, 0x0302030203020302, 0x0706070607060706, 0x0302030203020302,\n                         0x0504050405040504, 0x0100010001000100, 0x0504050405040504, 0x0100010001000100),\n        _mm512_set_epi64(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a,\n                         0x0d0c0d0c0d0c0d0c, 0x0908090809080908, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908)\n    };\n#endif\n    const __m128i shuffles[2] = {_mm_set_epi32(0x07060706, 0x05040504, 0x03020302, 0x01000100),\n                                 _mm_set_epi32(0x0f0e0f0e, 0x0d0c0d0c, 0x0b0a0b0a, 0x09080908)};\n\n    uint32_t utmp[4];\n};\n\ntemplate <typename Q8>\ninline void process_mins_16(const __m256i& all_scales, const Q8& q8, int i, float d, __m256 * accm) {\n    for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n        const __m256i prod  = _mm256_madd_epi16(all_scales, q8.load_bsums(iy, i));\n        accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]);\n    }\n}\ninline void prepare_scales_16(const __m256i& all_scales, __m256i * scales) {\n    const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);\n    const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);\n    scales[0] = MM256_SET_M128I(l_scales, l_scales);\n    scales[1] = MM256_SET_M128I(h_scales, h_scales);\n}\n\nstruct ScaleQ3 {\n    inline __m128i make_scales(const 
uint16_t * s8) const {\n        const uint16_t * scales16 = (const uint16_t *)s8;\n        uint32_t aux0 = scales16[0] | (scales16[1] << 16);\n        uint32_t aux1 = scales16[2] | (scales16[3] << 16);\n        uint32_t aux2 = scales16[4] | (scales16[5] << 16);\n        __m128i scales128 = _mm_set_epi32(\n            ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030),\n            ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030),\n             (aux1       & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030),\n             (aux0       & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030));\n        return _mm_add_epi8(scales128, m32);\n    }\n    const __m128i m32 = _mm_set1_epi8(-32);\n};\n\nstruct ScaleIQ4XS {\n    inline __m128i make_scales(const uint32_t scales_l, const uint16_t scales_h) {\n        uint32_t tmp32 = scales_h | (scales_h << 14);\n        const __m128i sh = _mm_slli_epi16(_mm_and_si128(_mm_srlv_epi32(_mm_set1_epi32(tmp32), hshift), hmask), 4);\n        const __m128i sl = _mm_and_si128(_mm_srlv_epi32(_mm_set1_epi32(scales_l), lshift), lmask);\n        return _mm_add_epi16(_mm_or_si128(sh, _mm_cvtepi8_epi16(_mm_shuffle_epi8(sl, lshuffle))), m32);\n    }\n    const __m128i hshift = _mm_set_epi32(12, 8, 4, 0);\n    const __m128i lshift = _mm_set_epi32(4, 0, 4, 0);\n    const __m128i hmask  = _mm_set1_epi16(0x03);\n    const __m128i lmask  = _mm_set1_epi8(0xf);\n    const __m128i lshuffle = _mm_set_epi32(0x07030602, 0x05010400, 0x07030602, 0x05010400);\n    const __m128i m32 = _mm_set1_epi16(-32);\n};\n\ntemplate <typename Block>\nstruct BaseDequantizer {\n    BaseDequantizer(const void * vx, size_t bx) : vx(vx), bx(bx) {}\n    inline void new_row(int ix) {\n        x = (const Block *)((const char *)vx + bx*ix);\n    }\n\n    const void *  vx;\n    size_t        bx;\n    const Block * x;\n\n    float d;\n};\n\n#ifdef HAVE_FANCY_SIMD\n//====================================== Zen4 ==================================================\n\nstruct BlockPermuter {\n    
const __m512i permute1 = _mm512_set_epi64(11, 10,  9,  8, 3, 2, 1, 0);\n    const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4);\n};\n\nstruct Q4Bits {\n    inline void prepare(const uint8_t * q4) {\n        auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);\n        auto tmp1 = _mm512_and_si512(q4bits, ml);\n        auto tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);\n        values[0] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);\n        values[1] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);\n        q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);\n        tmp1 = _mm512_and_si512(q4bits, ml);\n        tmp2 = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);\n        values[2] = _mm512_permutex2var_epi64(tmp1, perm.permute1, tmp2);\n        values[3] = _mm512_permutex2var_epi64(tmp1, perm.permute2, tmp2);\n    }\n    inline void prepare64(const uint8_t * q4) {\n        auto q4bits = _mm512_loadu_si512((const __m512i*)q4 + 0);\n        values[0] = _mm512_and_si512(q4bits, ml);\n        values[1] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);\n        q4bits = _mm512_loadu_si512((const __m512i*)q4 + 1);\n        values[2] = _mm512_and_si512(q4bits, ml);\n        values[3] = _mm512_and_si512(_mm512_srli_epi16(q4bits, 4), ml);\n    }\n    __m512i values[4];\n    const __m512i ml = _mm512_set1_epi8(0xf);\n    BlockPermuter perm;\n};\n\nstruct Q2Bits {\n    inline void prepare(const uint8_t * q2) {\n\n        auto q2bits = _mm512_loadu_si512((const __m512i*)q2);\n        auto tmp = _mm512_srli_epi16(q2bits, 2);\n\n        values[0] = _mm512_permutex2var_epi64(q2bits, perm.permute1, tmp);\n        values[2] = _mm512_permutex2var_epi64(q2bits, perm.permute2, tmp);\n        values[1] = _mm512_and_si512(_mm512_srli_epi16(values[0], 4), ml);\n        values[3] = _mm512_and_si512(_mm512_srli_epi16(values[2], 4), ml);\n        values[0] = _mm512_and_si512(values[0], ml);\n        values[2] = 
_mm512_and_si512(values[2], ml);\n    }\n    __m512i values[4];\n    const __m512i ml = _mm512_set1_epi8(0x03);\n    BlockPermuter perm;\n};\n\nstruct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {\n    DequantizerQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        bits.prepare(x[i].qs);\n        auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);\n        scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);\n        scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);\n    }\n\n    Q4Bits bits;\n    Scales8K s8k;\n};\n\nstruct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {\n    DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_values()) {}\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        prepare(x[i].qs);\n        auto scales128 = siq4.make_scales(*(const uint32_t *)x[i].scales_l, x[i].scales_h);\n        s8k.accum_mins(scales128, q8, i, -128.f*d, accd);\n        auto scales256 = MM256_SET_M128I(scales128, scales128);\n        auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);\n        scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);\n        scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);\n    }\n    static __m512i load_values() {\n        static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241};\n        auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq4nl);\n        auto val256 = MM256_SET_M128I(val128, val128);\n        return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);\n    }\n    inline void 
prepare(const uint8_t * q4) {\n        bits.prepare64(q4);\n        // We now have in bits.values[0]: 0...15, 32...47, 64...79, 96...111\n        //                bits.values[1]: 16..31, 48...63, 80...95, 112..127\n        //                etc.\n        auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]);\n        bits.values[1] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]));\n        bits.values[0] = _mm512_shuffle_epi8(values, tmp);\n        tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]);\n        bits.values[3] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]));\n        bits.values[2] = _mm512_shuffle_epi8(values, tmp);\n    }\n\n    Q4Bits bits;\n    Scales8K s8k;\n    ScaleIQ4XS siq4;\n    const __m512i values;\n    const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2,  9,  8, 1, 0);\n    const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4);\n};\n\nstruct HighBit5 {\n    inline void apply(const uint8_t * h, Q4Bits& bits) {\n        auto hbits256 = _mm256_loadu_si256((const __m256i *)h);\n        auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbits256), _mm256_srli_epi16(hbits256, 1), 1);\n        bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_and_si512(_mm512_slli_epi16(hbits, 4), mh));\n        bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh));\n        bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_and_si512(hbits, mh));\n        bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh));\n    }\n    const __m512i mh = _mm512_set1_epi8(0x10);\n};\n\nstruct HighBit3 {\n    inline void apply(const uint8_t * h, Q2Bits& bits) {\n        auto hbits256 = _mm256_loadu_si256((const __m256i *)h);\n        auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbits256), 
_mm256_srli_epi16(hbits256, 1), 1);\n        bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh));\n        bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_and_si512(hbits, mh));\n        bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh));\n        bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_and_si512(_mm512_srli_epi16(hbits, 4), mh));\n    }\n    const __m512i mh = _mm512_set1_epi8(0x04);\n};\n\nstruct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {\n    DequantizerQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        bits.prepare(x[i].qs);\n        hbits.apply(x[i].qh, bits);\n        auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);\n        scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);\n        scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);\n    }\n\n    Q4Bits bits;\n    HighBit5 hbits;\n    Scales8K s8k;\n};\n\nstruct Scale16 {\n    inline void make_scales(const __m128i& scales8, __m512i * scales) const {\n        auto all_scales8 = MM256_SET_M128I(scales8, scales8);\n        auto scales1 = _mm256_shuffle_epi8(all_scales8, shuffle1);\n        auto scales2 = _mm256_shuffle_epi8(all_scales8, shuffle2);\n        scales[0] = _mm512_cvtepi8_epi16(scales1);\n        scales[1] = _mm512_cvtepi8_epi16(scales2);\n    }\n    template <typename Q8>\n    inline void process_mins_and_scales(int i, float c, const __m128i& mins8, const __m128i& scales8,\n        const Q8& q8, __m256 * accm, __m512i * scales) const {\n        process_mins_16(_mm256_cvtepi8_epi16(mins8), q8, i, c, accm);\n        make_scales(scales8, scales);\n    }\n    const __m256i shuffle1 = _mm256_set_epi32(0x07070707, 
0x03030303, 0x06060606, 0x02020202,\n                                              0x05050505, 0x01010101, 0x04040404, 0x00000000);\n    const __m256i shuffle2 = _mm256_set_epi32(0x0f0f0f0f, 0x0b0b0b0b, 0x0e0e0e0e, 0x0a0a0a0a,\n                                              0x0d0d0d0d, 0x09090909, 0x0c0c0c0c, 0x08080808);\n};\n\nstruct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {\n    DequantizerQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        bits.prepare(x[i].qs);\n        const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);\n        const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);\n        const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);\n        sc16.process_mins_and_scales(i, -GGML_FP16_TO_FP32(x[i].dmin), mins8, scales8, q8, accm, scales);\n    }\n\n    Q2Bits bits;\n    Scale16 sc16;\n    const __m128i m4 = _mm_set1_epi8(0xf);\n\n};\n\nstruct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {\n    DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        bits.prepare(x[i].qs);\n        hbits.apply(x[i].hmask, bits);\n        auto scales128 = sc3.make_scales((const uint16_t *)x[i].scales);\n        sc16.process_mins_and_scales(i, -4.f*d, scales128, scales128, q8, accm, scales);\n    }\n\n    Q2Bits bits;\n    HighBit3 hbits;\n    ScaleQ3 sc3;\n    Scale16 sc16;\n    const __m128i m4  = _mm_set1_epi8(0xf);\n    const __m128i m32 = _mm_set1_epi8(-32);\n};\n\nstruct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {\n    DequantizerQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    
inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        bits.prepare64(x[i].ql);\n        add_high_bits(x[i].qh, bits);\n        auto scales128 = _mm_loadu_si128((const __m128i *)x[i].scales);\n        sc16.process_mins_and_scales(i, -32.f*d, scales128, scales128, q8, accm, scales);\n    }\n\n    inline void add_high_bits(const uint8_t * qh, Q4Bits& bits) const {\n        auto hbits = _mm512_loadu_si512((const __m512i *)qh);\n        auto tmp1 = _mm512_and_si512(_mm512_slli_epi16(hbits, 4), mh);\n        auto tmp2 = _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh);\n        bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_permutex2var_epi64(tmp1, bits.perm.permute1, tmp2));\n        bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_permutex2var_epi64(tmp1, bits.perm.permute2, tmp2));\n        tmp1 = _mm512_and_si512(hbits, mh);\n        tmp2 = _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh);\n        bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_permutex2var_epi64(tmp1, bits.perm.permute1, tmp2));\n        bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_permutex2var_epi64(tmp1, bits.perm.permute2, tmp2));\n    }\n\n    Q4Bits bits;\n    HighBit3 hbits;\n    Scale16 sc16;\n\n    const __m512i mh = _mm512_set1_epi8(0x30);\n\n};\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n    const int nb = n / QK_K;\n\n    Q8<nrc_y> q8(info);\n\n    Dequantizer deq(vx, bx);\n\n    __m256  accm[nrc_y];\n    __m512  accd[nrc_y];\n    __m512i scales[2];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps();\n        for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps();\n\n        deq.new_row(ix);\n\n        for (int i = 0; i < nb; ++i) {\n\n            deq.new_block(i, q8, 
accm, scales);\n\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants(iy, i, 0));\n                const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants(iy, i, 1));\n                const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants(iy, i, 2));\n                const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants(iy, i, 3));\n                auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));\n                sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));\n                accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);\n            }\n\n        }\n\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));\n            info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));\n        }\n\n    }\n}\n\n#else\n// ===================================== Vanilla AVX2 =====================================\n\nstruct Q4Bits {\n    inline void prepare(const uint8_t * q4, int j) {\n        auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);\n        values[0] = _mm256_and_si256(q4bits, ml);\n        values[1] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);\n        q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);\n        values[2] = _mm256_and_si256(q4bits, ml);\n        values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);\n    }\n    inline void prepare64(const uint8_t * q4, int j) {\n        auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);\n        values[0] = _mm256_and_si256(q4bits, ml);\n        values[2] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 
4), ml);\n        q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);\n        values[1] = _mm256_and_si256(q4bits, ml);\n        values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);\n    }\n    inline void prepare16(const uint8_t * q4, int j) {\n        values[0] = dequant16(q4 + 64*j +  0);\n        values[1] = dequant16(q4 + 64*j + 16);\n        values[2] = dequant16(q4 + 64*j + 32);\n        values[3] = dequant16(q4 + 64*j + 48);\n    }\n    inline __m256i dequant16(const uint8_t * qs) const {\n        const __m128i aux128 = _mm_loadu_si128((const __m128i *)qs);\n        const __m256i aux256 = MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128);\n        return _mm256_and_si256(ml, aux256);\n    };\n    __m256i values[4];\n    const __m256i ml = _mm256_set1_epi8(0xf);\n};\n\nstruct Q2Bits {\n    inline void prepare(const uint8_t * q2, int j) {\n        auto q2bits = _mm256_loadu_si256((const __m256i *)q2 + j);\n        values[0] = _mm256_and_si256(q2bits, ml);\n        values[1] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), ml);\n        values[2] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), ml);\n        values[3] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), ml);\n    }\n    __m256i values[4];\n    const __m256i ml = _mm256_set1_epi8(0x03);\n};\n\nstruct HighBit5 {\n    inline void load(const uint8_t * h) { hbits = _mm256_loadu_si256((const __m256i *)h); }\n    inline void apply(Q4Bits& bits, bool do_shift) {\n        bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh));\n        bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 3), mh));\n        bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));\n        bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh));\n        if (do_shift) {\n            hbits = _mm256_srli_epi16(hbits, 4);\n        }\n    }\n  
  const __m256i mh = _mm256_set1_epi8(0x10);\n    __m256i hbits;\n};\n\nstruct HighBit3 {\n    inline void load(const uint8_t * h) { hbits = _mm256_loadu_si256((const __m256i *)h); }\n    inline void apply(Q2Bits& bits, bool do_shift) {\n        bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));\n        bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh));\n        bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(hbits, mh));\n        bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(hbits, 1), mh));\n        if (do_shift) {\n            hbits = _mm256_srli_epi16(hbits, 4);\n        }\n    }\n    const __m256i mh = _mm256_set1_epi8(0x04);\n    __m256i hbits;\n};\n\ninline __m256i get_scale_shuffle_8(int i) {\n    return _mm256_set1_epi16((2*i) | ((2*i+1) << 8));\n}\n\ninline void set_scales_8(const __m256i& all_scales, int j, __m256i * scales) {\n    scales[0] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+0));\n    scales[1] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+1));\n    scales[2] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+2));\n    scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+3));\n}\n\ntemplate <typename Q8, typename Bits>\ninline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {\n    if (j == 0) {\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0)));\n            const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1)));\n            const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2)));\n            const __m256i p4 = _mm256_madd_epi16(scales[3], 
_mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3)));\n            sumi[iy] = _mm256_add_epi32(_mm256_add_epi32(p1, p3), _mm256_add_epi32(p2, p4));\n        }\n    } else {\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4)));\n            const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5)));\n            const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6)));\n            const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7)));\n            sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p1, p3));\n            sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p2, p4));\n        }\n    }\n}\n\nstruct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {\n    DequantizerQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        return s8k.process_mins_and_scales(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);\n    }\n    inline void prepare(int i, int j) {\n        bits.prepare(x[i].qs, j);\n    }\n\n    Q4Bits bits;\n    Scales8K s8k;\n};\n\nstruct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {\n    DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_values()) {}\n    template <typename Q8>\n    inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        auto scales128 = siq4.make_scales(*(const uint32_t *)x[i].scales_l, x[i].scales_h);\n        s8k.accum_mins(scales128, q8, i, -128.f*d, accd);\n        return MM256_SET_M128I(scales128, scales128);\n    }\n    inline void prepare(int i, int j) 
{\n        bits.prepare16(x[i].qs, j);\n        bits.values[0] = _mm256_shuffle_epi8(values, bits.values[0]);\n        bits.values[1] = _mm256_shuffle_epi8(values, bits.values[1]);\n        bits.values[2] = _mm256_shuffle_epi8(values, bits.values[2]);\n        bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]);\n    }\n\n    static __m256i load_values() {\n        static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241};\n        auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq4nl);\n        return MM256_SET_M128I(val128, val128);\n    }\n\n    Q4Bits bits;\n    Scales8K s8k;\n    ScaleIQ4XS siq4;\n    const __m256i values;\n};\n\nstruct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {\n    DequantizerQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        hbits.load(x[i].qh);\n        return s8k.process_mins_and_scales(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);\n    }\n    inline void prepare(int i, int j) {\n        bits.prepare(x[i].qs, j);\n        hbits.apply(bits, j == 0);\n    }\n\n    Q4Bits  bits;\n    HighBit5 hbits;\n    Scales8K s8k;\n};\n\ntemplate <typename Q8>\ninline void process_mins_and_scales_16(const __m128i& scales128, const Q8& q8, int i, float d,\n    __m256 * accm, __m256i * scales) {\n    const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);\n    process_mins_16(all_scales, q8, i, d, accm);\n    prepare_scales_16(all_scales, scales);\n}\n\nstruct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {\n    DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        hbits.load(x[i].hmask);\n        
process_mins_and_scales_16(sc3.make_scales((const uint16_t *)x[i].scales), q8, i, -4.f*d, accm, scales);\n    }\n    inline void prepare(int i, int j) {\n        bits.prepare(x[i].qs, j);\n        hbits.apply(bits, j == 0);\n    }\n\n    Q2Bits  bits;\n    HighBit3 hbits;\n    ScaleQ3 sc3;\n\n    const __m128i m32 = _mm_set1_epi8(-32);\n};\n\nstruct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {\n    DequantizerQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);\n        const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);\n        const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);\n        process_mins_16(_mm256_cvtepi8_epi16(mins8), q8, i, -GGML_FP16_TO_FP32(x[i].dmin), accm);\n        prepare_scales_16(_mm256_cvtepi8_epi16(scales8), scales);\n    }\n    inline void prepare(int i, int j) {\n        bits.prepare(x[i].qs, j);\n    }\n\n    Q2Bits  bits;\n\n    const __m128i m4 = _mm_set1_epi8(0xf);\n};\n\nstruct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {\n    DequantizerQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}\n    template <typename Q8>\n    inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        process_mins_and_scales_16(_mm_loadu_si128((const __m128i *)x[i].scales), q8, i, -32.f*d, accm, scales);\n    }\n    inline void prepare(int i, int j) {\n        bits.prepare64(x[i].ql, j);\n        auto hbits = _mm256_loadu_si256((const __m256i *)x[i].qh + j);\n        bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh));\n        bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), 
mh));\n        bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(hbits, mh));\n        bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(hbits, 2), mh));\n    }\n\n    Q4Bits  bits;\n    const __m256i mh = _mm256_set1_epi8(0x30);\n};\n\ninline __m256i get_scale_shuffle_16(int i) {\n    static const uint8_t k_shuffle[128] = {\n         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,     2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,\n         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,     6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,\n         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,    10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,\n        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,    14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,\n    };\n    return _mm256_loadu_si256((const __m256i*)k_shuffle + i);\n}\n\ninline void set_scales_16(const __m256i& all_scales, __m256i * scales) {\n    scales[0] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(0));\n    scales[1] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(1));\n    scales[2] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(2));\n    scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(3));\n}\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n%QK_K == 0);\n    const int nb = n/QK_K;\n\n    Q8<nrc_y> q8(info);\n\n    __m256i all_scales[2];\n    __m256i scales[4];\n    __m256  accd[nrc_y];\n\n    Dequantizer deq(vx, bx);\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        deq.new_row(ix);\n\n        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();\n\n        for (int i = 0; i < nb; ++i) {\n\n            deq.new_block(i, q8, accd, all_scales);\n\n            __m256i sumi[nrc_y];\n\n            for (int j = 0; j < QK_K/128; ++j) {\n                deq.prepare(i, j);\n                
set_scales_16(all_scales[j], scales);\n                multiply_add(deq.bits, scales, j, i, q8, sumi);\n            }\n\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);\n            }\n\n        }\n\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            info.store(ix, iy, hsum_float_8(accd[iy]));\n        }\n\n    }\n\n}\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n    const int nb = n / QK_K;\n\n    Q8<nrc_y> q8(info);\n\n    Dequantizer deq(vx, bx);\n\n    __m256  accd[nrc_y];\n    __m256i scales[4];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();\n\n        deq.new_row(ix);\n\n        for (int i = 0; i < nb; ++i) {\n\n            auto all_scales = deq.new_block(i, q8, accd);\n\n            __m256i sumi[nrc_y];\n\n            for (int j = 0; j < QK_K/128; ++j) {\n\n                deq.prepare(i, j);\n\n                set_scales_8(all_scales, j, scales);\n\n                multiply_add(deq.bits, scales, j, i, q8, sumi);\n\n            }\n\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i));\n                accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);\n            }\n\n        }\n\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            info.store(ix, iy, hsum_float_8(accd[iy]));\n        }\n\n    }\n}\n#endif  // Zen4 or vanilla AVX2\n\n//\n// ============================== Legacy quants\n//\n\nstruct DotHelper {\n    const __m256i m1 = _mm256_set1_epi16(1);\n#if defined(__AVX512VNNI__) && defined(__AVX512VL__)\n    inline __m256i dot(__m256i x, __m256i y) const {\n        return _mm256_dpbusd_epi32(_mm256_setzero_si256(), x, y);\n    
}\n#else\n    inline __m256i dot(__m256i x, __m256i y) const {\n        return _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x, y));\n    }\n#endif\n};\n\nstruct SignedDot {\n    DotHelper helper;\n    inline __m256i compute(__m256i x, __m256i y) const {\n        return helper.dot(_mm256_sign_epi8(x, x), _mm256_sign_epi8(y, x));\n    }\n};\nstruct UnsignedDot {\n    DotHelper helper;\n    inline __m256i compute(__m256i x, __m256i y) const {\n        return helper.dot(x, y);\n    }\n};\ntemplate <typename Q8, typename Dot> struct Sum4 {\n    Dot dot;\n    inline __m256i compute(const __m256i * qx, const Q8 * y) const {\n        const __m256i p0 = dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[0].qs));\n        const __m256i p1 = dot.compute(qx[1], _mm256_loadu_si256((const __m256i *)y[1].qs));\n        const __m256i p2 = dot.compute(qx[2], _mm256_loadu_si256((const __m256i *)y[2].qs));\n        const __m256i p3 = dot.compute(qx[3], _mm256_loadu_si256((const __m256i *)y[3].qs));\n        const __m256i p01 = _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p0, p1));    // 0,0, 1,1, 0,0, 1,1\n        const __m256i p23 = _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p2, p3));    // 2,2, 3,3, 2,2, 3,3\n        return _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p01, p23)); // 0,1,2,3, 0,1,2,3\n    }\n};\n\nstruct Sum4_Q8 {\n    SignedDot dot;\n    static inline __m256i add1(__m256i a, __m256i b) {\n        return _mm256_add_epi32(_mm256_unpacklo_epi32(a, b), _mm256_unpackhi_epi32(a, b));\n    }\n    static inline __m256i add2(__m256i a, __m256i b) {\n        return _mm256_add_epi32(_mm256_unpacklo_epi64(a, b), _mm256_unpackhi_epi64(a, b));\n    }\n    inline __m256i compute(const __m256i * qx, const block_q8_0 * y) const {\n        const __m256i p0 = dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[0].qs));\n        const __m256i p1 = dot.compute(qx[1], _mm256_loadu_si256((const __m256i *)y[1].qs));\n        const __m256i p2 = 
dot.compute(qx[2], _mm256_loadu_si256((const __m256i *)y[2].qs));\n        const __m256i p3 = dot.compute(qx[3], _mm256_loadu_si256((const __m256i *)y[3].qs));\n        const __m256i p01 = add1(p0, p1);  // 0,1, 0,1, 0,1, 0,1\n        const __m256i p23 = add1(p2, p3);  // 2,3, 2,3, 2,3, 2,3\n        return add2(p01, p23); // returns 0,1,2,3, 0,1,2,3\n    }\n};\n\nstruct ScaleHelperQ_0 {\n    ggml_half scales8[4];\n    template <typename Q>\n    inline __m128 prepare4(const Q * y) {\n        for (int j = 0; j < 4; ++j) scales8[j] = y[j].d;\n        return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)scales8));\n    }\n    template <typename Q>\n    inline __m128 prepare4(__m128 other_scales, const Q * y) {\n        return _mm_mul_ps(other_scales, prepare4<Q>(y));\n    }\n    template <typename Q> inline float prepare1(const Q * y) const { return GGML_FP16_TO_FP32(y->d); }\n    template <typename Q> inline float prepare1(float d, const Q * y) const { return d*prepare1(y); }\n};\n\nstruct ScaleHelperQ_1 {\n    uint32_t scales8[4];\n    const __m128i shuffle = _mm_set_epi16(0x0f0e, 0x0b0a, 0x0706, 0x0302, 0x0d0c, 0x0908, 0x0504, 0x0100);\n\n    template <typename Q>\n    inline __m256 prepare4(const Q * y) {\n        for (int j = 0; j < 4; ++j) {\n            // it is slightly faster to directly dereference (const uint32 *)&y[j].d, but some compilers\n            // complain that this breaks strict-aliasing rules.\n            memcpy(scales8 + j, &y[j].d, sizeof(uint32_t));\n        }\n        return _mm256_cvtph_ps(_mm_shuffle_epi8(_mm_loadu_si128((const __m128i *)scales8), shuffle));\n    }\n\n    template <typename Q>\n    inline __m256 prepare4(__m256 other_scales, const Q * y) {\n        return _mm256_mul_ps(other_scales, prepare4<Q>(y));\n    }\n\n    template <typename Q> inline std::pair<float, float> prepare1(const Q * y) const {\n        return std::make_pair(GGML_FP16_TO_FP32(y->d), GGML_FP16_TO_FP32(y->m));\n    }\n    template <typename Q> inline 
std::pair<float, float> prepare1(const std::pair<float, float>& dm, const Q * y) const {\n        return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->m));\n    }\n    std::pair<float, float> inline prepare1(const std::pair<float, float>& dm, const block_q8_1 * y) const {\n        return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s));\n    }\n};\n\nstruct MinusType0 {\n    inline __m256 compute(__m128 d, int) const { return _mm256_set_m128(d, d); }\n    inline float compute(float d, int) const { return d; }\n    inline float result(__m256 acc, int) const { return hsum_float_8(acc); }\n};\n\ntemplate <int nrc_y> struct MinusType1 {\n    __m128 accm[nrc_y];\n    MinusType1() { for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm_setzero_ps(); }\n    inline __m256 compute(__m256 dm, int iy) {\n        const __m128 d = _mm256_castps256_ps128(dm);\n        const __m128 m = _mm256_extractf128_ps(dm, 1);\n        accm[iy] = _mm_add_ps(accm[iy], m);\n        return _mm256_set_m128(d, d);\n    }\n    inline float compute(const std::pair<float, float>& dm, int iy) {\n        accm[iy] = _mm_add_ps(accm[iy], _mm_set1_ps(dm.second*0.25f));\n        return dm.first;\n    }\n    inline float result(__m256 acc, int iy) const {\n        const __m128 sum = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));\n        return hsum_float_4(_mm_add_ps(sum, accm[iy]));\n    }\n};\n\ntemplate <typename Minus, int nrc_y, bool is_multiple_of_4> struct AccumT {\n    __m256 acc[nrc_y];\n    Minus accm;\n    AccumT() {  for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = _mm256_setzero_ps(); }\n    template <typename Unpacker, typename Scales, typename Sum, typename Q8>\n    inline void compute(int nb, Unpacker& unp, Scales& scales, Sum& sum, const Q8 ** y, const DataInfo& info, int ix) {\n        auto qx = unp.quants();\n        __m256 dall[nrc_y];\n        for (int i = 0; i < nb/4; ++i) {\n            auto 
other_scales = unp.set_block_4(i);\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                auto s12 = scales.prepare4(other_scales, y[iy] + 4*i);\n                dall[iy] = accm.compute(s12, iy);\n            }\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                auto pall = sum.compute(qx, y[iy] + 4*i);\n                acc[iy] = _mm256_fmadd_ps(dall[iy], _mm256_cvtepi32_ps(pall), acc[iy]);\n            }\n        }\n        if (!is_multiple_of_4) {\n            for (int i = 4*(nb/4); i < nb; ++i) {\n                auto other_scales = unp.set_block(i);\n                for (int iy = 0; iy < nrc_y; ++iy) {\n                    auto s12 = scales.prepare1(other_scales, y[iy] + i);\n                    auto d = accm.compute(s12, iy);\n                    const __m256i p0 = sum.dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs));\n                    acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(p0), acc[iy]);\n                }\n            }\n        }\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            info.store(ix, iy, accm.result(acc[iy], iy));\n            //s[iy*bs] = accm.result(acc[iy], iy);\n        }\n    }\n};\n\ntemplate <int nrc_y, bool is_multiple_of_4>\nusing AccumType0 = AccumT<MinusType0, nrc_y, is_multiple_of_4>;\n\ntemplate <int nrc_y, bool is_multiple_of_4>\nusing AccumType1 = AccumT<MinusType1<nrc_y>, nrc_y, is_multiple_of_4>;\n\nusing Sum4Type0 = Sum4<block_q8_0, SignedDot>;\nusing Sum4Type1 = Sum4<block_q8_1, UnsignedDot>;\n\ntemplate <typename Unpacker, typename Sum4Type, typename AccumType, typename Scales, typename Q8, int nrc_y>\nvoid mul_mat_qX_q8_Helper(int nb, const void * vx, size_t bx, const DataInfo& info, const Q8 ** y, int nrc_x) {\n    Unpacker unp(vx, bx);\n    Sum4Type sum4;\n    Scales scales;\n    for (int ix = 0; ix < nrc_x; ++ix) {\n        unp.set_row(ix);\n        AccumType accum;\n        accum.compute(nb, unp, scales, sum4, y, info, ix);\n    
}\n}\n\ntemplate <typename Unpacker, int nrc_y>\nvoid mul_mat_qX_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n%Unpacker::block_size() == 0);\n    Q8<nrc_y, block_q8_0> q8(info);\n    int nb = n/Unpacker::block_size();\n    if (nb%4 == 0) {\n        mul_mat_qX_q8_Helper<Unpacker, Sum4Type0, AccumType0<nrc_y, true>, ScaleHelperQ_0, block_q8_0, nrc_y>(\n                nb, vx, bx, info, q8.y, nrc_x\n        );\n    } else {\n        mul_mat_qX_q8_Helper<Unpacker, Sum4Type0, AccumType0<nrc_y, false>, ScaleHelperQ_0, block_q8_0, nrc_y>(\n                nb, vx, bx, info, q8.y, nrc_x\n        );\n    }\n}\n\ntemplate <typename Unpacker, int nrc_y>\nvoid mul_mat_qX_1_q8_1_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n%Unpacker::block_size() == 0);\n    Q8<nrc_y, block_q8_1> q8(info);\n    int nb = n/Unpacker::block_size();\n    if (nb%4 == 0) {\n        mul_mat_qX_q8_Helper<Unpacker, Sum4Type1, AccumType1<nrc_y, true>, ScaleHelperQ_1, block_q8_1, nrc_y>(\n                nb, vx, bx, info, q8.y, nrc_x\n        );\n    } else {\n        mul_mat_qX_q8_Helper<Unpacker, Sum4Type1, AccumType1<nrc_y, false>, ScaleHelperQ_1, block_q8_1, nrc_y>(\n                nb, vx, bx, info, q8.y, nrc_x\n        );\n    }\n}\n\nstruct Dequantizer4bit {\n    const __m256i m4 = _mm256_set1_epi8(0xf);\n    inline __m256i dequant(const uint8_t * qs) const {\n        const __m128i aux128 = _mm_loadu_si128((const __m128i *)qs);\n        return _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128), m4);\n    }\n};\n\nstruct Q8_0_Dequantizer {\n    inline __m256i dequant(const block_q8_0 * x) const {\n        return _mm256_loadu_si256((const __m256i *)x->qs);\n    }\n};\n\nstruct Q4_0_Dequantizer {\n    Dequantizer4bit b4;\n    const __m256i m8 = _mm256_set1_epi8(-8);\n    inline __m256i dequant(const block_q4_0 * x) const {\n        return _mm256_add_epi8(b4.dequant(x->qs), m8);\n    }\n};\n\nstruct 
Q4_1_Dequantizer {\n    Dequantizer4bit b4;\n    inline __m256i dequant(const block_q4_1 * x) const {\n        return b4.dequant(x->qs);\n    }\n};\n\nstruct HBitDequantizer {\n    const __m256i shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);\n    const __m256i mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);\n    const __m256i minus1 = _mm256_set1_epi64x(-1);\n    inline __m256i to_bytes(const uint8_t * bits) const {\n        // Note: Data in all ggml quants is at least 2-byte aligned.\n        // => we can cast to uint16_t and use or on two consecutive entries\n        // which is faster than memcpy\n        const uint16_t * aux16 = (const uint16_t *)bits;\n        const uint32_t aux32 = aux16[0] | (aux16[1] << 16);\n        //uint32_t aux32; memcpy(&aux32, bits, sizeof(uint32_t));\n        __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(aux32), shuffle);\n        bytes = _mm256_or_si256(bytes, mask);\n        return _mm256_cmpeq_epi8(bytes, minus1);\n    }\n};\n\nstruct Q5_0_Dequantizer {\n    Dequantizer4bit b4;\n    HBitDequantizer hbit;\n    const __m256i mh = _mm256_set1_epi8((char)0xF0);\n    inline __m256i dequant(const block_q5_0 * x) const {\n        const __m256i vqh = _mm256_andnot_si256(hbit.to_bytes(x->qh), mh);\n        return _mm256_or_si256(b4.dequant(x->qs), vqh);\n    }\n};\n\nstruct Q5_1_Dequantizer {\n    Dequantizer4bit b4;\n    HBitDequantizer hbit;\n    const __m256i mh = _mm256_set1_epi8(0x10);\n    inline __m256i dequant(const block_q5_1 * x) const {\n        const __m256i vqh = _mm256_and_si256(hbit.to_bytes(x->qh), mh);\n        return _mm256_or_si256(b4.dequant(x->qs), vqh);\n    }\n};\n\ntemplate <typename Q, typename Scales, typename Dequantizer>\nstruct Q_Unpacker {\n    Q_Unpacker(const void * vx, size_t bx) : cx_0((const char *)vx), x((const Q*)cx_0), bx(bx) {}\n\n    const char * cx_0;\n    const Q    * x;\n    size_t       bx;\n\n    Scales scales;\n    Dequantizer 
deq;\n\n    __m256i qx[4];\n\n    inline const __m256i* quants() const { return qx; }\n\n    inline void set_row(int ix) { x = (const Q*)(cx_0 + ix*bx); }\n\n    inline auto set_block_4(int i) {\n        for (int j = 0; j < 4; ++j) {\n            qx[j] = deq.dequant(x + 4*i + j);\n        }\n        return scales.prepare4(x + 4*i);\n    }\n    inline auto set_block(int i) {\n        qx[0] = deq.dequant(x + i);\n        return scales.prepare1(x + i);\n    }\n};\n\nstruct Q8_0_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0, Q8_0_Dequantizer> {\n    Q8_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}\n    inline static int block_size() { return QK8_0; }\n};\nstruct Q4_0_Unpacker final : public Q_Unpacker<block_q4_0, ScaleHelperQ_0, Q4_0_Dequantizer> {\n    Q4_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}\n    inline static int block_size() { return QK4_0; }\n};\nstruct Q5_0_Unpacker final : public Q_Unpacker<block_q5_0, ScaleHelperQ_0, Q5_0_Dequantizer> {\n    Q5_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}\n    inline static int block_size() { return QK5_0; }\n};\nstruct Q4_1_Unpacker final : public Q_Unpacker<block_q4_1, ScaleHelperQ_1, Q4_1_Dequantizer> {\n    Q4_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}\n    inline static int block_size() { return QK4_1; }\n};\nstruct Q5_1_Unpacker final : public Q_Unpacker<block_q5_1, ScaleHelperQ_1, Q5_1_Dequantizer> {\n    Q5_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}\n    inline static int block_size() { return QK5_1; }\n};\n\ntemplate <int nrc_y>\nvoid mul_mat_q8_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n%Q8_0_Unpacker::block_size() == 0);\n    Q8<nrc_y, block_q8_0> q8(info);\n    int nb = n/Q8_0_Unpacker::block_size();\n    if (nb%4 == 0) {\n        mul_mat_qX_q8_Helper<Q8_0_Unpacker, Sum4_Q8, AccumType0<nrc_y, true>, ScaleHelperQ_0, block_q8_0, nrc_y>(\n                
nb, vx, bx, info, q8.y, nrc_x\n        );\n    } else {\n        mul_mat_qX_q8_Helper<Q8_0_Unpacker, Sum4_Q8, AccumType0<nrc_y, false>, ScaleHelperQ_0, block_q8_0, nrc_y>(\n                nb, vx, bx, info, q8.y, nrc_x\n        );\n    }\n}\n\ntemplate <typename Dequantizer> void MulMat::set_functions(MulMat& m) {\n        if constexpr (std::is_same_v<Dequantizer, Q4_0_Unpacker> || std::is_same_v<Dequantizer, Q5_0_Unpacker>) {\n            m.funcs[0] = mul_mat_qX_0_q8_0_T<Dequantizer, 1>;\n            m.funcs[1] = mul_mat_qX_0_q8_0_T<Dequantizer, 2>;\n            m.funcs[2] = mul_mat_qX_0_q8_0_T<Dequantizer, 3>;\n            m.funcs[3] = mul_mat_qX_0_q8_0_T<Dequantizer, 4>;\n            m.funcs[4] = mul_mat_qX_0_q8_0_T<Dequantizer, 5>;\n            m.funcs[5] = mul_mat_qX_0_q8_0_T<Dequantizer, 6>;\n            m.funcs[6] = mul_mat_qX_0_q8_0_T<Dequantizer, 7>;\n            m.funcs[7] = mul_mat_qX_0_q8_0_T<Dequantizer, 8>;\n        }\n        else if constexpr (std::is_same_v<Dequantizer, Q4_1_Unpacker> || std::is_same_v<Dequantizer, Q5_1_Unpacker>) {\n            m.funcs[0] = mul_mat_qX_1_q8_1_T<Dequantizer, 1>;\n            m.funcs[1] = mul_mat_qX_1_q8_1_T<Dequantizer, 2>;\n            m.funcs[2] = mul_mat_qX_1_q8_1_T<Dequantizer, 3>;\n            m.funcs[3] = mul_mat_qX_1_q8_1_T<Dequantizer, 4>;\n            m.funcs[4] = mul_mat_qX_1_q8_1_T<Dequantizer, 5>;\n            m.funcs[5] = mul_mat_qX_1_q8_1_T<Dequantizer, 6>;\n            m.funcs[6] = mul_mat_qX_1_q8_1_T<Dequantizer, 7>;\n            m.funcs[7] = mul_mat_qX_1_q8_1_T<Dequantizer, 8>;\n        }\n        else {\n#ifdef HAVE_FANCY_SIMD\n            m.funcs[0] = mul_mat_qX_K_q8_K_T<Dequantizer, 1>;\n            m.funcs[1] = mul_mat_qX_K_q8_K_T<Dequantizer, 2>;\n            m.funcs[2] = mul_mat_qX_K_q8_K_T<Dequantizer, 3>;\n            m.funcs[3] = mul_mat_qX_K_q8_K_T<Dequantizer, 4>;\n            m.funcs[4] = mul_mat_qX_K_q8_K_T<Dequantizer, 5>;\n            m.funcs[5] = mul_mat_qX_K_q8_K_T<Dequantizer, 
6>;\n            m.funcs[6] = mul_mat_qX_K_q8_K_T<Dequantizer, 7>;\n            m.funcs[7] = mul_mat_qX_K_q8_K_T<Dequantizer, 8>;\n#else\n            if constexpr (std::is_same_v<Dequantizer, DequantizerQ2K> ||\n                          std::is_same_v<Dequantizer, DequantizerQ3K> ||\n                          std::is_same_v<Dequantizer, DequantizerQ6K>) {\n                m.funcs[0] = mul_mat_qY_K_q8_K_T<Dequantizer, 1>;\n                m.funcs[1] = mul_mat_qY_K_q8_K_T<Dequantizer, 2>;\n                m.funcs[2] = mul_mat_qY_K_q8_K_T<Dequantizer, 3>;\n                m.funcs[3] = mul_mat_qY_K_q8_K_T<Dequantizer, 4>;\n                m.funcs[4] = mul_mat_qY_K_q8_K_T<Dequantizer, 5>;\n                m.funcs[5] = mul_mat_qY_K_q8_K_T<Dequantizer, 6>;\n                m.funcs[6] = mul_mat_qY_K_q8_K_T<Dequantizer, 7>;\n                m.funcs[7] = mul_mat_qY_K_q8_K_T<Dequantizer, 8>;\n            } else {\n                m.funcs[0] = mul_mat_qX_K_q8_K_T<Dequantizer, 1>;\n                m.funcs[1] = mul_mat_qX_K_q8_K_T<Dequantizer, 2>;\n                m.funcs[2] = mul_mat_qX_K_q8_K_T<Dequantizer, 3>;\n                m.funcs[3] = mul_mat_qX_K_q8_K_T<Dequantizer, 4>;\n                m.funcs[4] = mul_mat_qX_K_q8_K_T<Dequantizer, 5>;\n                m.funcs[5] = mul_mat_qX_K_q8_K_T<Dequantizer, 6>;\n                m.funcs[6] = mul_mat_qX_K_q8_K_T<Dequantizer, 7>;\n                m.funcs[7] = mul_mat_qX_K_q8_K_T<Dequantizer, 8>;\n            }\n#endif\n        }\n}\n\nbool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int) {\n\n    if (ne00 % ggml_blck_size(GGML_TYPE_Q8_K) == 0)\n        row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00);\n\n    switch (typeA) {\n        case GGML_TYPE_Q2_K:\n            assert (ne00 % QK_K == 0);\n            MulMat::set_functions<DequantizerQ2K>(mm);\n            break;\n        case GGML_TYPE_Q3_K:\n            assert (ne00 % QK_K == 0);\n            MulMat::set_functions<DequantizerQ3K>(mm);\n            
break;\n        case GGML_TYPE_Q4_K:\n            assert (ne00 % QK_K == 0);\n            MulMat::set_functions<DequantizerQ4K>(mm);\n            break;\n        case GGML_TYPE_Q5_K:\n            assert (ne00 % QK_K == 0);\n            MulMat::set_functions<DequantizerQ5K>(mm);\n            break;\n        case GGML_TYPE_Q6_K:\n            assert (ne00 % QK_K == 0);\n            MulMat::set_functions<DequantizerQ6K>(mm);\n            break;\n        case GGML_TYPE_IQ4_XS:\n            assert (ne00 % QK_K == 0);\n            MulMat::set_functions<DequantizerIQ4XS>(mm);\n            break;\n        case GGML_TYPE_Q4_0:\n            assert (ne00 % QK4_0 == 0);\n            MulMat::set_functions<Q4_0_Unpacker>(mm);\n            row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);\n            break;\n        case GGML_TYPE_Q4_1:\n            assert (ne00 % QK4_1 == 0);\n            MulMat::set_functions<Q4_1_Unpacker>(mm);\n            row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00);\n            break;\n        case GGML_TYPE_Q5_0:\n            assert (ne00 % QK5_0 == 0);\n            MulMat::set_functions<Q5_0_Unpacker>(mm);\n            row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);\n            break;\n        case GGML_TYPE_Q5_1:\n            assert (ne00 % QK5_1 == 0);\n            MulMat::set_functions<Q5_1_Unpacker>(mm);\n            row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00);\n            break;\n\n        default:\n            return false;\n    }\n\n    return true;\n}\n\n} // namespace\n\n\n#else   // __aarch64__\n\n//[kawrakow] Need these two for performance on Arm\ntypedef struct {\n    ggml_half d[8];\n    int8_t qs[4*QK8_1];\n} block_q8_1_x4;\nstatic_assert(sizeof(block_q8_1_x4) == 4*sizeof(block_q8_1), \"wrong q8_1_x4 block size/padding\");\ntypedef struct {\n    ggml_half d[4];\n    int8_t qs[4*QK8_0];\n} block_q8_0_x4;\nstatic_assert(sizeof(block_q8_0_x4) == 4*sizeof(block_q8_0), \"wrong q8_0_x4 block size/padding\");\n\nnamespace 
{\n\ntemplate <int nrc, typename block_q8 = block_q8_K> struct Q8 {\n\n    constexpr static int nrc_y = nrc;\n\n    Q8(const DataInfo& info) {\n        for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy);\n    }\n\n    inline int8x16_t load_quants_16(int iy, int i, int j) const { return vld1q_s8(y[iy][i].qs + 16*j); }\n    inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy][i].qs + 32*j); }\n    inline int8x16x4_t load_quants_64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy][i].qs + 64*j); }\n    inline int16x8x2_t load_bsums(int iy, int i) const { return vld1q_s16_x2(y[iy][i].bsums); }\n    inline int16x8_t load_bsums8(int iy, int i) const {\n        auto q8s = vld1q_s16_x2(y[iy][i].bsums);\n        return vpaddq_s16(q8s.val[0], q8s.val[1]);\n    }\n    inline float scale(int iy, int i) const { return y[iy][i].d; }\n\n    const block_q8 * y[nrc_y];\n};\n\ntemplate <int nrc_y, typename Dequantizer>\nIQK_NOINLINE void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n    const int nb = n / QK_K;\n\n    Q8<nrc_y, block_q8_K> q8(info);\n\n    Dequantizer deq(vx, bx, nrc_y);\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        deq.new_row(ix);\n\n        float32x4_t acc[nrc_y];\n        for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);\n\n//#pragma GCC unroll 4\n        for (int i = 0; i < nb; ++i) {\n\n            int32x4_t sumi[nrc_y];\n            for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0);\n\n            if constexpr (nrc_y > 1 && Dequantizer::should_scale_quants()) {\n                deq.process_scales(i, q8, acc);\n                deq.prepare(i, 0);\n                deq.compute(q8, i, 0, sumi);\n                deq.prepare(i, 1);\n                deq.compute(q8, i, 1, sumi);\n            } else {\n                if constexpr (Dequantizer::num_blocks() == 8) {\n                    auto scales = 
deq.new_block(i, q8, acc);\n                    deq.prepare(i, 0);\n#pragma GCC unroll 8\n                    for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);\n                    deq.prepare(i, 1);\n#pragma GCC unroll 8\n                    for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);\n                }\n                else if constexpr (Dequantizer::num_blocks() == 16) {\n                    auto scales = deq.new_block(i, q8, acc);\n                    deq.prepare(i, 0);\n#pragma GCC unroll 8\n                    for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);\n                    deq.prepare(i, 1);\n#pragma GCC unroll 8\n                    for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);\n                }\n                else {\n                    GGML_ASSERT(false);\n                }\n            }\n\n#pragma GCC unroll 8\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i)));\n            }\n        }\n\n#pragma GCC unroll 8\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            info.store(ix, iy, vaddvq_f32(acc[iy]));\n        }\n    }\n}\ntemplate <int nrc_y, typename Dequantizer>\nIQK_NOINLINE void mul_mat_qX_K_q8_K_IQ(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n    const int nb = n / QK_K;\n\n    Q8<nrc_y, block_q8_K> q8(info);\n\n    Dequantizer deq(vx, bx, nrc_y);\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        deq.new_row(ix);\n\n        float32x4_t acc[nrc_y];\n        for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);\n\n        for (int i = 0; i < nb; ++i) {\n\n            int32x4_t sumi[nrc_y];\n            for (int 
iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0);\n\n            if constexpr (Dequantizer::num_blocks() == 8) {\n                auto scales = deq.new_block(i);\n                deq.prepare(i, 0);\n#pragma GCC unroll 8\n                for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);\n                deq.prepare(i, 1);\n#pragma GCC unroll 8\n                for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);\n            }\n            else if constexpr (Dequantizer::num_blocks() == 16) {\n                auto scales = deq.new_block(i);\n                deq.prepare(i, 0);\n#pragma GCC unroll 8\n                for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);\n                deq.prepare(i, 1);\n#pragma GCC unroll 8\n                for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);\n            }\n            else {\n                GGML_ASSERT(false);\n            }\n#pragma GCC unroll 8\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i)));\n            }\n        }\n#pragma GCC unroll 8\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            info.store(ix, iy, vaddvq_f32(acc[iy]));\n        }\n    }\n}\n\ntemplate <typename Q8>\nIQK_ALWAYS_INLINE void compute_8_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8,\n        const int32x4x2_t& scales, int iy, int i, int j, int32x4_t& sumi) {\n    auto mzero = vdupq_n_s32(0);\n    const int8x16_t * qs_1 = (const int8x16_t *)qx_1.val;\n    const int8x16_t * qs_2 = (const int8x16_t *)qx_2.val;\n\n    auto q8b_1 = q8.load_quants(iy, i, 4*j+0);\n    auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qs_1[0], q8b_1.val[0]), qs_1[1], q8b_1.val[1]); // block 
1\n    auto q8b_2 = q8.load_quants(iy, i, 4*j+1);\n    auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qs_1[2], q8b_2.val[0]), qs_1[3], q8b_2.val[1]); // block 2\n    auto p12 = vpaddq_s32(p1, p2);\n\n    auto q8b_3 = q8.load_quants(iy, i, 4*j+2);\n    auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qs_2[0], q8b_3.val[0]), qs_2[1], q8b_3.val[1]); // block 3\n    auto q8b_4 = q8.load_quants(iy, i, 4*j+3);\n    auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qs_2[2], q8b_4.val[0]), qs_2[3], q8b_4.val[1]); // block 4\n    auto p34 = vpaddq_s32(p3, p4);\n\n    auto pall = vpaddq_s32(p12, p34);\n    sumi = vmlaq_s32(sumi, scales.val[j], pall);\n}\ntemplate <typename Q8>\nIQK_ALWAYS_INLINE void compute_8_blocks(const int8x16_t * qx, const Q8& q8,\n        const int32x4_t& scales, int iy, int i, int j, int32x4_t& sumi) {\n    auto mzero = vdupq_n_s32(0);\n\n    auto q8b_1 = q8.load_quants(iy, i, 4*j+0);\n    auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qx[0], q8b_1.val[0]), qx[1], q8b_1.val[1]); // block 1\n    auto q8b_2 = q8.load_quants(iy, i, 4*j+1);\n    auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qx[2], q8b_2.val[0]), qx[3], q8b_2.val[1]); // block 2\n    auto p12 = vpaddq_s32(p1, p2);\n\n    auto q8b_3 = q8.load_quants(iy, i, 4*j+2);\n    auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qx[4], q8b_3.val[0]), qx[5], q8b_3.val[1]); // block 3\n    auto q8b_4 = q8.load_quants(iy, i, 4*j+3);\n    auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, qx[6], q8b_4.val[0]), qx[7], q8b_4.val[1]); // block 4\n    auto p34 = vpaddq_s32(p3, p4);\n\n    auto pall = vpaddq_s32(p12, p34);\n    sumi = vmlaq_s32(sumi, scales, pall);\n}\n\ntemplate <typename Q8>\nIQK_ALWAYS_INLINE void compute_16_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8,\n        const int32x4x4_t& scales, int iy, int i, int j, int32x4_t& sumi) {\n\n    auto mzero = vdupq_n_s32(0);\n    auto q8b_1 = q8.load_quants(iy, i, 4*j+0);\n    auto p1 = vpaddq_s32(ggml_vdotq_s32(mzero, 
vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]),\n                         ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1])); // blocks 0, 0, 1, 1,\n    auto q8b_2 = q8.load_quants(iy, i, 4*j+1);\n    auto p2 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]),\n                         ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1])); // blocks 2, 2, 3, 3,\n    auto p12 = vpaddq_s32(p1, p2); // blocks 0, 1, 2, 3\n    sumi = vmlaq_s32(sumi, scales.val[2*j+0], p12);\n\n    auto q8b_3 = q8.load_quants(iy, i, 4*j+2);\n    auto p3 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]),\n                         ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1])); // blocks 4, 4, 5, 5,\n    auto q8b_4 = q8.load_quants(iy, i, 4*j+3);\n    auto p4 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]),\n                         ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1])); // blocks 6, 6, 7, 7,\n    auto p34 = vpaddq_s32(p3, p4); // blocks 4, 5, 6, 7\n    sumi = vmlaq_s32(sumi, scales.val[2*j+1], p34);\n}\n\ntemplate <typename Q8>\ninline void accum_mins_8(const int16x8_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) {\n    for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n        auto q8s = q8.load_bsums8(iy, i);\n        int32x4_t b1 = vmull_s16(vget_low_s16(mins), vget_low_s16(q8s));\n        int32x4_t b2 = vmull_s16(vget_high_s16(mins), vget_high_s16(q8s));\n        float32x4_t prod = vcvtq_f32_s32(vaddq_s32(b1, b2));\n        acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i)));\n    }\n}\ntemplate <typename Q8>\ninline void accum_mins_16(const int16x8x2_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) {\n    for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n        auto q8s = q8.load_bsums(iy, i);\n        int32x4_t b1 = vmull_s16(vget_low_s16 (mins.val[0]), vget_low_s16 
(q8s.val[0]));\n        int32x4_t b2 = vmull_s16(vget_high_s16(mins.val[0]), vget_high_s16(q8s.val[0]));\n        int32x4_t b3 = vmull_s16(vget_low_s16 (mins.val[1]), vget_low_s16 (q8s.val[1]));\n        int32x4_t b4 = vmull_s16(vget_high_s16(mins.val[1]), vget_high_s16(q8s.val[1]));\n        float32x4_t prod = vcvtq_f32_s32(vaddq_s32(vaddq_s32(b1, b2), vaddq_s32(b3, b4)));\n        acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i)));\n    }\n}\n\nstruct Scales8 {\n    uint32_t utmp[4];\n    const uint8_t * sc8 = (const uint8_t *)utmp;\n    template <typename Q8, typename Qx>\n    inline int32x4x2_t process_scales_mins(const Qx& x, const Q8& q8, int i, float32x4_t * acc) {\n        make_q4_scales(x.scales, utmp);\n        int16x8_t mins = vmovl_s8(vld1_s8((const int8_t *)sc8 + 8));\n        accum_mins_8(mins, q8, acc, i, -GGML_FP16_TO_FP32(x.dmin));\n\n        uint8x8_t scales8 = vld1_u8(sc8);\n        uint16x8_t scales16 = vmovl_u8(scales8);\n        int32x4x2_t scales = {vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales16))),\n                              vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales16)))};\n        return scales;\n    }\n};\n\nstruct Q4bits {\n    const uint8x16_t m4b = vdupq_n_u8(0xf);\n    uint8x16x4_t b1, b2;\n    inline void prepare4(uint8x16x4_t& b, const uint8x16_t * val) const {\n        b.val[0] = vandq_u8(val[0], m4b);\n        b.val[2] = vshrq_n_u8(val[0], 4);\n        b.val[1] = vandq_u8(val[1], m4b);\n        b.val[3] = vshrq_n_u8(val[1], 4);\n    }\n    inline void prepare4_16(uint8x16x4_t& b, const uint8x16_t * val) const {\n        b.val[0] = vandq_u8(val[0], m4b);\n        b.val[1] = vshrq_n_u8(val[0], 4);\n        b.val[2] = vandq_u8(val[1], m4b);\n        b.val[3] = vshrq_n_u8(val[1], 4);\n    }\n    inline void prepare(const uint8_t * qs) {\n        auto q4bits = vld1q_u8_x2(qs);\n        prepare4(b1, q4bits.val);\n        q4bits = vld1q_u8_x2(qs+32);\n        prepare4(b2, q4bits.val);\n    }\n    
inline void prepare_v2(const uint8_t * qs) {\n        auto q4bits = vld1q_u8_x4(qs);\n        prepare4(b1, q4bits.val+0);\n        prepare4(b2, q4bits.val+2);\n    }\n    inline void prepare64(const uint8_t * qs) {\n        auto q4bits = vld1q_u8_x4(qs);\n        b1.val[0] = vandq_u8(q4bits.val[0], m4b);\n        b1.val[1] = vandq_u8(q4bits.val[1], m4b);\n        b1.val[2] = vandq_u8(q4bits.val[2], m4b);\n        b1.val[3] = vandq_u8(q4bits.val[3], m4b);\n        b2.val[0] = vshrq_n_u8(q4bits.val[0], 4);\n        b2.val[1] = vshrq_n_u8(q4bits.val[1], 4);\n        b2.val[2] = vshrq_n_u8(q4bits.val[2], 4);\n        b2.val[3] = vshrq_n_u8(q4bits.val[3], 4);\n    }\n    inline void prepare16(const uint8_t * qs) {\n        auto q4bits = vld1q_u8_x2(qs);\n        prepare4_16(b1, q4bits.val);\n        q4bits = vld1q_u8_x2(qs+32);\n        prepare4_16(b2, q4bits.val);\n    }\n    inline void prepare16_v2(const uint8_t * qs) {\n        auto q4bits = vld1q_u8_x4(qs);\n        prepare4_16(b1, q4bits.val+0);\n        prepare4_16(b2, q4bits.val+2);\n    }\n};\n\nstruct Q2bits {\n    const uint8x16_t m4b = vdupq_n_u8(0x03);\n    uint8x16x4_t b1, b2;\n    inline void prepare(const uint8_t * qs) {\n        auto q2bits = vld1q_u8_x2(qs);\n        b1.val[0] = vandq_u8(q2bits.val[0], m4b);\n        b1.val[1] = vandq_u8(q2bits.val[1], m4b);\n\n        q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);\n        q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);\n        b1.val[2] = vandq_u8(q2bits.val[0], m4b);\n        b1.val[3] = vandq_u8(q2bits.val[1], m4b);\n\n        q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);\n        q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);\n        b2.val[0] = vandq_u8(q2bits.val[0], m4b);\n        b2.val[1] = vandq_u8(q2bits.val[1], m4b);\n\n        q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);\n        q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);\n        b2.val[2] = vandq_u8(q2bits.val[0], m4b);\n        b2.val[3] = vandq_u8(q2bits.val[1], m4b);\n    
}\n};\n\ntemplate <typename block_q>\nstruct BaseDequantizer {\n    BaseDequantizer(const void * vx, size_t bx, int nrc) : vx(vx), x(nullptr), bx(bx), nrc(nrc) {}\n    inline void new_row(int ix) { x = (const block_q *)((const char *)vx + ix*bx); }\n    const void * vx;\n    const block_q * x;\n    const size_t bx;\n    const int nrc;\n};\n\nstruct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {\n    DequantizerQ4K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 8; }\n    constexpr static bool should_scale_quants() { return false; }\n\n    template <typename Q8>\n    inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        return s8.process_scales_mins(x[i], q8, i, acc);\n    }\n    inline void prepare(int i, int j) {\n        if (nrc == 1) bits.prepare_v2(x[i].qs+64*j);\n        else bits.prepare(x[i].qs+64*j);\n    }\n\n    Q4bits bits;\n    Scales8 s8;\n\n    float d;\n};\n\nstruct HighBit5 {\n    const uint8x16_t mhb = vdupq_n_u8(0x10);\n    uint8x16x2_t bits;\n    inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) {\n        b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 4), mhb));\n        b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 4), mhb));\n        b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 3), mhb));\n        b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 3), mhb));\n\n        b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb));\n        b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb));\n        b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb));\n        b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb));\n\n        if (do_shift) {\n            bits.val[0] = vshrq_n_u8(bits.val[0], 4);\n            bits.val[1] = 
vshrq_n_u8(bits.val[1], 4);\n        }\n    }\n};\n\nstruct HighBit3 {\n    const uint8x16_t mhb = vdupq_n_u8(0x04);\n    uint8x16x2_t bits;\n    inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) {\n        b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb));\n        b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb));\n        b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb));\n        b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb));\n\n        b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(bits.val[0], mhb));\n        b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(bits.val[1], mhb));\n        b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshrq_n_u8(bits.val[0], 1), mhb));\n        b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshrq_n_u8(bits.val[1], 1), mhb));\n\n        if (do_shift) {\n            bits.val[0] = vshrq_n_u8(bits.val[0], 4);\n            bits.val[1] = vshrq_n_u8(bits.val[1], 4);\n        }\n    }\n};\n\nstruct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {\n    DequantizerQ5K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 8; }\n    constexpr static bool should_scale_quants() { return false; }\n\n    template <typename Q8>\n    inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        h.bits = vld1q_u8_x2(x[i].qh);\n        return s8.process_scales_mins(x[i], q8, i, acc);\n    }\n    inline void prepare(int i, int j) {\n        bits.prepare(x[i].qs+64*j);\n        h.apply(bits.b1, bits.b2, j == 0);\n    }\n\n    Q4bits bits;\n    HighBit5 h;\n    Scales8 s8;\n\n    uint8x16x2_t hbits;\n\n    float d;\n};\n\ninline int32x4x4_t make_wider(const int16x8x2_t& scales16) {\n    int32x4x4_t scales = {\n        vmovl_s16(vget_low_s16 (scales16.val[0])),\n        vmovl_s16(vget_high_s16(scales16.val[0])),\n        
vmovl_s16(vget_low_s16 (scales16.val[1])),\n        vmovl_s16(vget_high_s16(scales16.val[1])),\n    };\n    return scales;\n}\n\ntemplate <typename Q8>\ninline int32x4x4_t process_scales_mins_16(const int8x16_t& scales8, const Q8& q8, float32x4_t * acc, int i, float c) {\n    int16x8x2_t scales16;\n    scales16.val[0] = vmovl_s8(vget_low_s8(scales8));\n    scales16.val[1] = vmovl_s8(vget_high_s8(scales8));\n    accum_mins_16(scales16, q8, acc, i, c);\n    return make_wider(scales16);\n}\n\nstruct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {\n    DequantizerQ6K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 16; }\n    constexpr static bool should_scale_quants() { return false; }\n\n    template <typename Q8>\n    inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        return process_scales_mins_16(vld1q_s8(x[i].scales), q8, acc, i, -32.f*d);\n    }\n    inline void prepare(int i, int j) {\n\n        auto hbits = vld1q_u8_x2(x[i].qh + 32*j);\n\n        bits.prepare64(x[i].ql+64*j);\n        bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), mhb));\n        bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), mhb));\n        bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 2), mhb));\n        bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 2), mhb));\n\n        bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], mhb));\n        bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], mhb));\n        bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 2), mhb));\n        bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 2), mhb));\n\n    }\n\n    Q4bits bits;\n\n    const uint8x16_t mhb = vdupq_n_u8(0x30);\n\n    float d;\n};\n\nstruct 
DequantizerQ3K final : public BaseDequantizer<block_q3_K> {\n    DequantizerQ3K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 16; }\n    constexpr static bool should_scale_quants() { return false; }\n\n    template <typename Q8>\n    inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        h.bits = vld1q_u8_x2(x[i].hmask);\n        const uint16_t * sc16 = (const uint16_t *)x[i].scales;\n        uint32_t aux0 = sc16[0] | (sc16[1] << 16);\n        uint32_t aux1 = sc16[2] | (sc16[3] << 16);\n        uint32_t aux2 = sc16[4] | (sc16[5] << 16);\n        aux32[0] =  (aux0       & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030);\n        aux32[1] =  (aux1       & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030);\n        aux32[2] = ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030);\n        aux32[3] = ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030);\n        return process_scales_mins_16(vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32)), q8, acc, i, -4.f*d);\n    }\n\n    inline void prepare(int i, int j) {\n        bits.prepare(x[i].qs+32*j);\n        h.apply(bits.b1, bits.b2, j == 0);\n    }\n\n    uint32_t aux32[4];\n\n    Q2bits bits;\n\n    HighBit3 h;\n\n    float d;\n};\n\nstruct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {\n    DequantizerQ2K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 16; }\n    constexpr static bool should_scale_quants() { return true; }\n\n    template <typename Q8>\n    inline void process_scales(int i, const Q8& q8, float32x4_t * acc) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        auto scales_and_mins = vld1q_u8(x[i].scales);\n        auto mins8 = vreinterpretq_s8_u8(vshrq_n_u8(scales_and_mins, 4));\n        int16x8x2_t scales16;\n        scales16.val[0] = vmovl_s8(vget_low_s8(mins8));\n        
scales16.val[1] = vmovl_s8(vget_high_s8(mins8));\n        accum_mins_16(scales16, q8, acc, i, -GGML_FP16_TO_FP32(x[i].dmin));\n\n        scales8 = vandq_u8(scales_and_mins, vdupq_n_u8(0xf));\n    }\n\n    template <typename Q8>\n    inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {\n        process_scales(i, q8, acc);\n        int16x8x2_t scales16;\n        scales16.val[0] = vmovl_s8(vget_low_s8(vreinterpretq_s8_u8(scales8)));\n        scales16.val[1] = vmovl_s8(vget_high_s8(vreinterpretq_s8_u8(scales8)));\n        return make_wider(scales16);\n    }\n\n    template <typename Q8>\n    inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) {\n        auto m1 = vdupq_n_u8(1);\n        auto shuffle = vdupq_n_u8(8*j);\n        bits.b1.val[0] = vmulq_u8(bits.b1.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        bits.b1.val[1] = vmulq_u8(bits.b1.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        bits.b1.val[2] = vmulq_u8(bits.b1.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        bits.b1.val[3] = vmulq_u8(bits.b1.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        bits.b2.val[0] = vmulq_u8(bits.b2.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        bits.b2.val[1] = vmulq_u8(bits.b2.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        bits.b2.val[2] = vmulq_u8(bits.b2.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        bits.b2.val[3] = vmulq_u8(bits.b2.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            auto q8b_1 = q8.load_quants(iy, i, 4*j+0);\n            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]),\n                    vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]);\n\n            auto q8b_2 = 
q8.load_quants(iy, i, 4*j+1);\n            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]),\n                    vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]);\n\n            auto q8b_3 = q8.load_quants(iy, i, 4*j+2);\n            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]),\n                    vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]);\n\n            auto q8b_4 = q8.load_quants(iy, i, 4*j+3);\n            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]),\n                    vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]);\n        }\n    }\n\n    inline void prepare(int i, int j) {\n        bits.prepare(x[i].qs+32*j);\n    }\n\n    uint32_t aux32[4];\n\n    uint8x16_t scales8;\n\n    Q2bits bits;\n\n    float d;\n};\n\n// ============================= i-quants\n\nstruct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {\n\n    static int8x16_t load_values() {\n        static const int8_t iq4nl_values[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};\n        return vld1q_s8(iq4nl_values);\n    }\n\n    DequantizerIQ4XS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(load_values()) {}\n\n    constexpr static int num_blocks() { return 8; }\n    constexpr static bool should_scale_quants() { return false; }\n\n    inline void new_row(int ix) { x = (const block_iq4_xs *)((const char *)vx + bx*ix); }\n\n    template <typename Q8>\n    inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {\n        (void)q8;\n        (void)acc;\n        d = GGML_FP16_TO_FP32(x[i].d);\n        const uint16_t scales_h = x[i].scales_h;\n        const uint16_t * scales_l = (const uint16_t *)x[i].scales_l;\n        aux32[0] = scales_l[0] | (scales_l[1] << 16);\n        aux32[1] = aux32[0] >> 4;\n        // scl is ordered as 
0, 2, 4, 6, 1, 3, 5, 7\n        uint8x8_t scl8 = vand_u8(vld1_u8((const uint8_t *)aux32), vdup_n_u8(0xf));\n        uint16_t * aux16 = (uint16_t *)aux32;\n        aux16[0] = scales_h << 4; aux16[1] = scales_h << 2; aux16[2] = scales_h; aux16[3] = scales_h >> 2;\n        // sch is ordered as 0, 4, 1, 5, 2, 6, 3, 7\n        uint8x8_t sch8 = vand_u8(vld1_u8((const uint8_t *)aux16), vdup_n_u8(0x30));\n        int8x8_t scales8 = vadd_s8(vreinterpret_s8_u8(vorr_u8(scl8, vtbl1_u8(sch8, vreinterpret_u8_u32(hshuff)))), vdup_n_s8(-32));\n        // shuffle 0, 2, 4, 6, 1, 3, 5, 7 -> 0, 1, 2, 3, 4, 5, 6, 7\n        scales8 = vtbl1_s8(scales8, vreinterpret_s8_u32(hshuff));\n        int16x8_t scales16 = vmovl_s8(scales8);\n        int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};\n        return scales;\n    }\n    inline void prepare(int i, int j) {\n        bits.prepare16(x[i].qs+64*j);\n        for (int k = 0; k < 4; ++k) {\n            bits.b1.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values, bits.b1.val[k]));\n            bits.b2.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values, bits.b2.val[k]));\n        }\n    }\n\n    Q4bits bits;\n    const int8x16_t values;\n    uint32_t aux32[2];\n\n    constexpr static uint32x2_t hshuff = {0x05010400, 0x07030602};\n\n    float d;\n};\n\nstruct SimpleBits {\n    uint8x16x4_t b1;\n    uint8x16x4_t b2;\n};\n\nIQK_ALWAYS_INLINE int32x4x2_t prepare_scales_8(const uint32x4_t& v1, const uint32x4_t& v2) {\n    int32x4x2_t scales;\n    auto one = vdupq_n_u32(1);\n    scales.val[0] = vreinterpretq_s32_u32(vsliq_n_u32(one, vshrq_n_u32(v1, 28), 1));\n    scales.val[1] = vreinterpretq_s32_u32(vsliq_n_u32(one, vshrq_n_u32(v2, 28), 1));\n    return scales;\n}\n\ninline void apply_signs_2(uint8x16_t * b, const uint64_t * signs, uint32_t sidx) {\n    auto s1 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >> 0) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >> 7) & 127))));\n    auto s2 = 
vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >>14) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >>21) & 127))));\n    b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s1));\n    b[1] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[1]), s2));\n}\n\nIQK_ALWAYS_INLINE int32x4_t prepare_scales_8(const uint32x4_t& v1) {\n    return vreinterpretq_s32_u32(vsliq_n_u32(vdupq_n_u32(1), vshrq_n_u32(v1, 28), 1));\n}\n\nstruct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {\n    DequantizerIQ2XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    IQK_ALWAYS_INLINE float new_block(int i) const { return 0.125f * GGML_FP16_TO_FP32(x[i].d); }\n\n    inline int32x4_t unpack(int i, int j, uint8x16_t * q) const {\n        auto data = vld1q_u32_x2((const uint32_t *)(x[i].qs + 16*j));\n        prepare_all(data, q);\n        return prepare_scales_8(vuzp2q_u32(data.val[0], data.val[1]));\n    }\n\nprivate:\n\n    static inline void prepare2(uint8x16_t * b, const uint32_t * bits, const uint64_t * signs) {\n        const uint8_t * idx = (const uint8_t *)bits;\n        b[0] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[0]], iq2xxs_grid[idx[1]]});\n        b[1] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[2]], iq2xxs_grid[idx[3]]});\n        apply_signs_2(b, signs, bits[1]);\n    }\n\n    inline static void prepare_all(const uint32x4x2_t& data, uint8x16_t * quants) {\n        const uint32_t * q2 = (const uint32_t *)data.val;\n        prepare2(quants+0, q2+0, keven_signs);\n        prepare2(quants+2, q2+2, keven_signs);\n        prepare2(quants+4, q2+4, keven_signs);\n        prepare2(quants+6, q2+6, keven_signs);\n    }\n};\n\ninline int32x4x4_t prepare_4bit_scales16(const uint8_t * sc) {\n    auto aux = vld1_u8(sc);\n    auto scales_l = vand_u8(aux, vdup_n_u8(0xf));\n    auto scales_h = vshr_n_u8(aux, 4);\n    auto aux1 = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h));\n\n    auto 
scales8 = vreinterpretq_s8_u8(vorrq_u8(vshlq_n_u8(aux1, 1), vdupq_n_u8(1)));\n    int16x8x2_t scales16 = { vmovl_s8(vget_low_s8(scales8)), vmovl_s8(vget_high_s8(scales8)) };\n    return make_wider(scales16);\n}\n\nstruct DequantizerIQ2XS final : public BaseDequantizer<block_iq2_xs> {\n    DequantizerIQ2XS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 16; }\n    constexpr static bool should_scale_quants() { return false; }\n\n    SimpleBits bits;\n    float d;\n\n    inline int32x4x4_t new_block(int i) {\n        d = 0.125f * GGML_FP16_TO_FP32(x[i].d);\n        prepare_internal(i, 0);\n        return prepare_4bit_scales16(x[i].scales);\n    }\n\n    inline void prepare(int i, int j) {\n        if (j == 1) prepare_internal(i, 1);\n    }\n\nprivate:\n\n    static void make2(const uint16_t * qs, uint8x16_t * b) {\n        auto v1 = vcombine_s8(vld1_s8((const int8_t *)(iq2xs_grid + (qs[0] & 511))), vld1_s8((const int8_t *)(iq2xs_grid + (qs[1] & 511))));\n        auto v2 = vcombine_s8(vld1_s8((const int8_t *)(iq2xs_grid + (qs[2] & 511))), vld1_s8((const int8_t *)(iq2xs_grid + (qs[3] & 511))));\n        auto s1 = vcombine_s8(vld1_s8((const int8_t *)(keven_signs + (qs[0] >> 9))), vld1_s8((const int8_t *)(keven_signs + (qs[1] >> 9))));\n        auto s2 = vcombine_s8(vld1_s8((const int8_t *)(keven_signs + (qs[2] >> 9))), vld1_s8((const int8_t *)(keven_signs + (qs[3] >> 9))));\n        b[0] = vreinterpretq_u8_s8(vmulq_s8(v1, s1));\n        b[1] = vreinterpretq_u8_s8(vmulq_s8(v2, s2));\n    }\n\n    inline static void make4(const uint16_t * qs, uint8x16_t * b) {\n        make2(qs + 0, b + 0);\n        make2(qs + 4, b + 2);\n    }\n\n    IQK_ALWAYS_INLINE void prepare_internal(int i, int j) {\n        make4(x[i].qs + 16*j + 0, bits.b1.val);\n        make4(x[i].qs + 16*j + 8, bits.b2.val);\n    }\n\n};\n\n// So, I hate to include this table, but with the GCC 12.3 compiler\n// bundled in the Cosmopolitan 
tools, loading the unpacked sign bytes\n// from this table using the packed 8 sign bits as index is faster than\n// using the standard trick of vceqq_u8(vandq_u8(bits, mask), mask) to\n// expand the bits to bytes.\nstatic const uint64_t kall_signs[256] = {\n    0x0101010101010101, 0x01010101010101ff, 0x010101010101ff01, 0x010101010101ffff,\n    0x0101010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0x0101010101ffffff,\n    0x01010101ff010101, 0x01010101ff0101ff, 0x01010101ff01ff01, 0x01010101ff01ffff,\n    0x01010101ffff0101, 0x01010101ffff01ff, 0x01010101ffffff01, 0x01010101ffffffff,\n    0x010101ff01010101, 0x010101ff010101ff, 0x010101ff0101ff01, 0x010101ff0101ffff,\n    0x010101ff01ff0101, 0x010101ff01ff01ff, 0x010101ff01ffff01, 0x010101ff01ffffff,\n    0x010101ffff010101, 0x010101ffff0101ff, 0x010101ffff01ff01, 0x010101ffff01ffff,\n    0x010101ffffff0101, 0x010101ffffff01ff, 0x010101ffffffff01, 0x010101ffffffffff,\n    0x0101ff0101010101, 0x0101ff01010101ff, 0x0101ff010101ff01, 0x0101ff010101ffff,\n    0x0101ff0101ff0101, 0x0101ff0101ff01ff, 0x0101ff0101ffff01, 0x0101ff0101ffffff,\n    0x0101ff01ff010101, 0x0101ff01ff0101ff, 0x0101ff01ff01ff01, 0x0101ff01ff01ffff,\n    0x0101ff01ffff0101, 0x0101ff01ffff01ff, 0x0101ff01ffffff01, 0x0101ff01ffffffff,\n    0x0101ffff01010101, 0x0101ffff010101ff, 0x0101ffff0101ff01, 0x0101ffff0101ffff,\n    0x0101ffff01ff0101, 0x0101ffff01ff01ff, 0x0101ffff01ffff01, 0x0101ffff01ffffff,\n    0x0101ffffff010101, 0x0101ffffff0101ff, 0x0101ffffff01ff01, 0x0101ffffff01ffff,\n    0x0101ffffffff0101, 0x0101ffffffff01ff, 0x0101ffffffffff01, 0x0101ffffffffffff,\n    0x01ff010101010101, 0x01ff0101010101ff, 0x01ff01010101ff01, 0x01ff01010101ffff,\n    0x01ff010101ff0101, 0x01ff010101ff01ff, 0x01ff010101ffff01, 0x01ff010101ffffff,\n    0x01ff0101ff010101, 0x01ff0101ff0101ff, 0x01ff0101ff01ff01, 0x01ff0101ff01ffff,\n    0x01ff0101ffff0101, 0x01ff0101ffff01ff, 0x01ff0101ffffff01, 0x01ff0101ffffffff,\n    0x01ff01ff01010101, 
0x01ff01ff010101ff, 0x01ff01ff0101ff01, 0x01ff01ff0101ffff,\n    0x01ff01ff01ff0101, 0x01ff01ff01ff01ff, 0x01ff01ff01ffff01, 0x01ff01ff01ffffff,\n    0x01ff01ffff010101, 0x01ff01ffff0101ff, 0x01ff01ffff01ff01, 0x01ff01ffff01ffff,\n    0x01ff01ffffff0101, 0x01ff01ffffff01ff, 0x01ff01ffffffff01, 0x01ff01ffffffffff,\n    0x01ffff0101010101, 0x01ffff01010101ff, 0x01ffff010101ff01, 0x01ffff010101ffff,\n    0x01ffff0101ff0101, 0x01ffff0101ff01ff, 0x01ffff0101ffff01, 0x01ffff0101ffffff,\n    0x01ffff01ff010101, 0x01ffff01ff0101ff, 0x01ffff01ff01ff01, 0x01ffff01ff01ffff,\n    0x01ffff01ffff0101, 0x01ffff01ffff01ff, 0x01ffff01ffffff01, 0x01ffff01ffffffff,\n    0x01ffffff01010101, 0x01ffffff010101ff, 0x01ffffff0101ff01, 0x01ffffff0101ffff,\n    0x01ffffff01ff0101, 0x01ffffff01ff01ff, 0x01ffffff01ffff01, 0x01ffffff01ffffff,\n    0x01ffffffff010101, 0x01ffffffff0101ff, 0x01ffffffff01ff01, 0x01ffffffff01ffff,\n    0x01ffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0x01ffffffffffffff,\n    0xff01010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0xff0101010101ffff,\n    0xff01010101ff0101, 0xff01010101ff01ff, 0xff01010101ffff01, 0xff01010101ffffff,\n    0xff010101ff010101, 0xff010101ff0101ff, 0xff010101ff01ff01, 0xff010101ff01ffff,\n    0xff010101ffff0101, 0xff010101ffff01ff, 0xff010101ffffff01, 0xff010101ffffffff,\n    0xff0101ff01010101, 0xff0101ff010101ff, 0xff0101ff0101ff01, 0xff0101ff0101ffff,\n    0xff0101ff01ff0101, 0xff0101ff01ff01ff, 0xff0101ff01ffff01, 0xff0101ff01ffffff,\n    0xff0101ffff010101, 0xff0101ffff0101ff, 0xff0101ffff01ff01, 0xff0101ffff01ffff,\n    0xff0101ffffff0101, 0xff0101ffffff01ff, 0xff0101ffffffff01, 0xff0101ffffffffff,\n    0xff01ff0101010101, 0xff01ff01010101ff, 0xff01ff010101ff01, 0xff01ff010101ffff,\n    0xff01ff0101ff0101, 0xff01ff0101ff01ff, 0xff01ff0101ffff01, 0xff01ff0101ffffff,\n    0xff01ff01ff010101, 0xff01ff01ff0101ff, 0xff01ff01ff01ff01, 0xff01ff01ff01ffff,\n    0xff01ff01ffff0101, 0xff01ff01ffff01ff, 0xff01ff01ffffff01, 
0xff01ff01ffffffff,\n    0xff01ffff01010101, 0xff01ffff010101ff, 0xff01ffff0101ff01, 0xff01ffff0101ffff,\n    0xff01ffff01ff0101, 0xff01ffff01ff01ff, 0xff01ffff01ffff01, 0xff01ffff01ffffff,\n    0xff01ffffff010101, 0xff01ffffff0101ff, 0xff01ffffff01ff01, 0xff01ffffff01ffff,\n    0xff01ffffffff0101, 0xff01ffffffff01ff, 0xff01ffffffffff01, 0xff01ffffffffffff,\n    0xffff010101010101, 0xffff0101010101ff, 0xffff01010101ff01, 0xffff01010101ffff,\n    0xffff010101ff0101, 0xffff010101ff01ff, 0xffff010101ffff01, 0xffff010101ffffff,\n    0xffff0101ff010101, 0xffff0101ff0101ff, 0xffff0101ff01ff01, 0xffff0101ff01ffff,\n    0xffff0101ffff0101, 0xffff0101ffff01ff, 0xffff0101ffffff01, 0xffff0101ffffffff,\n    0xffff01ff01010101, 0xffff01ff010101ff, 0xffff01ff0101ff01, 0xffff01ff0101ffff,\n    0xffff01ff01ff0101, 0xffff01ff01ff01ff, 0xffff01ff01ffff01, 0xffff01ff01ffffff,\n    0xffff01ffff010101, 0xffff01ffff0101ff, 0xffff01ffff01ff01, 0xffff01ffff01ffff,\n    0xffff01ffffff0101, 0xffff01ffffff01ff, 0xffff01ffffffff01, 0xffff01ffffffffff,\n    0xffffff0101010101, 0xffffff01010101ff, 0xffffff010101ff01, 0xffffff010101ffff,\n    0xffffff0101ff0101, 0xffffff0101ff01ff, 0xffffff0101ffff01, 0xffffff0101ffffff,\n    0xffffff01ff010101, 0xffffff01ff0101ff, 0xffffff01ff01ff01, 0xffffff01ff01ffff,\n    0xffffff01ffff0101, 0xffffff01ffff01ff, 0xffffff01ffffff01, 0xffffff01ffffffff,\n    0xffffffff01010101, 0xffffffff010101ff, 0xffffffff0101ff01, 0xffffffff0101ffff,\n    0xffffffff01ff0101, 0xffffffff01ff01ff, 0xffffffff01ffff01, 0xffffffff01ffffff,\n    0xffffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0xffffffffff01ffff,\n    0xffffffffffff0101, 0xffffffffffff01ff, 0xffffffffffffff01, 0xffffffffffffffff,\n};\n\nstruct SignHelper {\n\n    IQK_ALWAYS_INLINE void apply_signs_1x(uint8x16_t * b, const uint8_t * sign_bits) const {\n        auto s = vreinterpretq_s8_u64(uint64x2_t{kall_signs[sign_bits[0]], kall_signs[sign_bits[1]]});\n        // Normally we would expect this to be 
faster, but it isn't.\n        // auto aux = vcombine_u8(vdup_n_u8(sign_bits[0]), vdup_n_u8(sign_bits[1]));\n        // auto s = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux, smask), smask), m1));\n        b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s));\n    }\n\n    // We would need these two if we weren't loading from the unpacked sign table.\n    //const uint8x16_t smask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201));\n    //const uint8x16_t m1    = vdupq_n_u8(1);\n};\n\nstruct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {\n    DequantizerIQ2S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 16; }\n    constexpr static bool should_scale_quants() { return false; }\n\n    SimpleBits bits;\n    float d;\n\n    inline int32x4x4_t new_block(int i) {\n        d = 0.125f * GGML_FP16_TO_FP32(x[i].d);\n        prepare_internal(i, 0, bits);\n        return prepare_4bit_scales16(x[i].scales);\n    }\n\n    inline void prepare(int i, int j) {\n        if (j == 1) prepare_internal(i, 1, bits);\n    }\n\nprivate:\n\n    static void make4(const SignHelper& sh, const uint8_t * sign_bits, const uint8_t * qs, const uint8_t * qh, uint8x16_t * b) {\n        uint32_t aux32[2];\n        const uint16_t * aux16 = (const uint16_t *)aux32;\n        for (int k = 0; k < 2; ++k) {\n            aux32[1] = (qh[k] << 4) | (qh[k] << 18);\n            aux32[0] = (aux32[1] << 4) & 0x03000300;\n            aux32[1] &= 0x03000300;\n            b[2*k+0] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+0] | aux16[0]))),\n                                   vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+1] | aux16[1]))));\n            b[2*k+1] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+2] | aux16[2]))),\n                                   vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+3] | aux16[3]))));\n            sh.apply_signs_1x(b+2*k+0, sign_bits); 
sign_bits += 2;\n            sh.apply_signs_1x(b+2*k+1, sign_bits); sign_bits += 2;\n        }\n    }\n\n    void prepare_internal(int i, int j, SimpleBits& sb) {\n\n        const auto * qs = x[i].qs + 16*j;\n        const auto * qh = x[i].qh + 4*j;\n        const auto * sign_bits = qs + QK_K/8;\n\n        make4(sh, sign_bits+0, qs+0, qh+0, sb.b1.val);\n        make4(sh, sign_bits+8, qs+8, qh+2, sb.b2.val);\n    }\n\n    SignHelper sh;\n};\n\nstruct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {\n    DequantizerIQ3XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    IQK_ALWAYS_INLINE float new_block(int i) const { return 0.25f * GGML_FP16_TO_FP32(x[i].d); }\n\n    inline int32x4_t unpack(int i, int j, uint8x16_t * q) const {\n        auto q3data = vld1q_u8_x2(x[i].qs + 32*j);\n        auto gas = vld1q_u32((const uint32_t *)(x[i].qs + QK_K/4 + 16*j));\n        prepare_block((const uint8_t *)q3data.val, (const uint32_t *)&gas, q);\n        return prepare_scales_8(gas);\n    }\n\nprivate:\n\n    inline static void make2(const uint8_t * q3, const uint32_t sidx, uint8x16_t * b) {\n        b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[0]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[3]]});\n        b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[4]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[7]]});\n        apply_signs_2(b, keven_signs, sidx);\n    }\n    inline static void prepare_block(const uint8_t * q3, const uint32_t * signs, uint8x16_t * quants) {\n        make2(q3+ 0, signs[0], quants + 0);\n        make2(q3+ 8, signs[1], quants + 2);\n        make2(q3+16, signs[2], quants + 4);\n        make2(q3+24, signs[3], quants + 6);\n    }\n};\n\nstruct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {\n    DequantizerIQ3S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}\n\n    constexpr static int num_blocks() { return 8; }\n    constexpr static bool 
should_scale_quants() { return false; }\n\n    SimpleBits bits;\n    float d;\n\n    inline int32x4x2_t new_block(int i) {\n        d = GGML_FP16_TO_FP32(x[i].d);\n        uint32_t scales32[2];\n        auto qs = vld1q_u8_x2(x[i].qs);\n        auto signs = vld1q_u8(x[i].signs);\n\n        prepare_block((const uint8_t *)qs.val, x[i].qh, (const uint8_t *)&signs);\n\n        std::memcpy(scales32, x[i].scales, 4);\n        scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101;\n        scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101;\n        auto scales8 = vld1_u8((const uint8_t *)scales32); // 0, 2, 4, 6, 1, 3, 5, 7\n        scales8 = vtbl1_u8(scales8, vreinterpret_u8_u64(vdup_n_u64(0x0703060205010400)));\n        auto scales16 = vreinterpretq_s16_u16(vmovl_u8(scales8));\n        int32x4x2_t scales;\n        scales.val[0] = vmovl_s16(vget_low_s16(scales16));\n        scales.val[1] = vmovl_s16(vget_high_s16(scales16));\n        return scales;\n    }\n\n    inline void prepare(int i, int j) {\n        if (j == 1) {\n            auto qs = vld1q_u8_x2(x[i].qs + 32);\n            auto signs = vld1q_u8(x[i].signs + 16);\n            prepare_block((const uint8_t *)qs.val, x[i].qh + 4, (const uint8_t *)&signs);\n        }\n    }\n\nprivate:\n\n    static inline void make2(const SignHelper& sh, const uint8_t * sign_bits, const uint16x8_t& idx_l, uint8_t qh,\n            const int16x8_t& hshift, uint8x16_t * b) {\n        auto vindex = vorrq_u16(idx_l, vandq_u16(vshlq_u16(vdupq_n_u16(qh), hshift), vdupq_n_u16(256)));\n        const uint16_t * idx = (const uint16_t *)&vindex;\n        b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[0]], iq3s_grid[idx[1]], iq3s_grid[idx[2]], iq3s_grid[idx[3]]});\n        sh.apply_signs_1x(b+0, sign_bits+0);\n        b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[4]], iq3s_grid[idx[5]], iq3s_grid[idx[6]], iq3s_grid[idx[7]]});\n        sh.apply_signs_1x(b+1, sign_bits+2);\n    }\n    static inline void 
make4(const SignHelper& sh, const uint8_t * sign_bits, const uint8_t * qs, const uint8_t * qh,\n            const int16x8_t& hshift, uint8x16_t * b) {\n        auto idx_l = vld1q_u8(qs);\n        make2(sh, sign_bits+0, vmovl_u8(vget_low_u8 (idx_l)), qh[0], hshift, b+0);\n        make2(sh, sign_bits+4, vmovl_u8(vget_high_u8(idx_l)), qh[1], hshift, b+2);\n    }\n\n    static int16x8_t load_shift() {\n        static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1};\n        return vld1q_s16(k_shift);\n    }\n\n    inline void prepare_block(const uint8_t * qs, const uint8_t * qh, const uint8_t * sign_bits) {\n        auto signs = vld1q_u8(sign_bits);\n        auto s = (const uint8_t *)&signs;\n        make4(sh, s + 0, qs+ 0, qh+0, hshift, bits.b1.val);\n        make4(sh, s + 8, qs+16, qh+2, hshift, bits.b2.val);\n    }\n\n    SignHelper sh;\n    const int16x8_t hshift = load_shift();\n\n};\n\ntemplate <int nrc_y, typename Dequantizer>\nIQK_NOINLINE void mul_mat_qX_K_q8_K_IQXXS(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    assert(n % QK_K == 0);\n    const int nb = n / QK_K;\n\n    Q8<nrc_y, block_q8_K> q8(info);\n    Dequantizer deq(vx, bx, nrc_y);\n    uint8x16_t  qx[8];\n    int32x4_t   sumi[nrc_y];\n    float32x4_t acc[nrc_y];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        deq.new_row(ix);\n        for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);\n\n        for (int i = 0; i < nb; ++i) {\n            float d = deq.new_block(i);\n            auto scales = deq.unpack(i, 0, qx);\n#pragma GCC unroll 8\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                sumi[iy] = vdupq_n_s32(0);\n                compute_8_blocks((const int8x16_t *)qx, q8, scales, iy, i, 0, sumi[iy]);\n            }\n            scales = deq.unpack(i, 1, qx);\n#pragma GCC unroll 8\n            for (int iy = 0; iy < nrc_y; ++iy) {\n                compute_8_blocks((const int8x16_t *)qx, q8, scales, iy, i, 1, sumi[iy]);\n                
acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*q8.scale(iy, i)), vcvtq_f32_s32(sumi[iy]));\n            }\n        }\n#pragma GCC unroll 8\n        for (int iy = 0; iy < nrc_y; ++iy) {\n            info.store(ix, iy, vaddvq_f32(acc[iy]));\n        }\n    }\n}\n\n// =========================================== Legacy quants\n\ntemplate <typename Block>\ninline float16x4_t load_scales_q0(const Block * x, ggml_half * aux) {\n    for (int k = 0; k < 4; ++k) aux[k] = x[k].d;\n    return vld1_f16((const float16_t *)aux);\n}\n\ntemplate <typename Block>\ninline float16x8_t load_scales_q1(const Block * x, ggml_half * aux) {\n    if constexpr (std::is_same_v<Block, block_q8_1>) {\n        for (int k = 0; k < 4; ++k) { aux[k] = x[k].d; aux[k+4] = x[k].s; }\n    } else {\n        for (int k = 0; k < 4; ++k) { aux[k] = x[k].d; aux[k+4] = x[k].m; }\n    }\n    return vld1q_f16((const float16_t *)aux);\n}\n\nstruct Q4LegacyBits {\n    template <typename Block>\n    inline void prepare(const Block * x) {\n        for (int i = 0; i < 4; ++i) {\n            auto q4bits = vld1q_u8(x[i].qs);\n            b[2*i+0] = vreinterpretq_s8_u8(vandq_u8(q4bits, m4b));\n            b[2*i+1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits, 4));\n        }\n    }\n    inline void prepare1(const uint8_t * qs, int8x16_t * q) const {\n        auto q4bits = vld1q_u8(qs);\n        q[0] = vreinterpretq_s8_u8(vandq_u8(q4bits, m4b));\n        q[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits, 4));\n    }\n    inline void prepare1(const uint8_t * qs) {\n        prepare1(qs, b);\n    }\n    const uint8x16_t m4b = vdupq_n_u8(0xf);\n    int8x16_t b[8];\n};\n\n// One would think this commented out version would do better than the one below\n// because it offers more opportunities to execute instructions in parallel.\n// Instead, it runs significantly slower. Why? 
If the compiler is running out of vector registers\n// cannot it just do the sequential version below on its own?\n//inline int32x4_t sum_4_blocks(const int8x16_t * b, const int8_t * qs) {\n//    const auto q8b_1 = vld1q_s8_x2(qs + 0);\n//    auto p12 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[0], q8b_1.val[0]), b[1], q8b_1.val[1]);\n//    const auto q8b_2 = vld1q_s8_x2(qs + 32);\n//    auto p34 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[2], q8b_2.val[0]), b[3], q8b_2.val[1]);\n//    auto p1234 = vpaddq_s32(p12, p34);\n//    const auto q8b_3 = vld1q_s8_x2(qs + 64);\n//    auto p56 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[4], q8b_3.val[0]), b[5], q8b_3.val[1]);\n//    const auto q8b_4 = vld1q_s8_x2(qs + 96);\n//    auto p78 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[6], q8b_4.val[0]), b[7], q8b_4.val[1]);\n//    return vpaddq_s32(p1234, vpaddq_s32(p56, p78));\n//}\n\ninline int32x4_t sum_4_blocks(const int8x16_t * b, const int8_t * qs) {\n    auto q8b = vld1q_s8_x2(qs + 0);\n    auto p12 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[0], q8b.val[0]), b[1], q8b.val[1]);\n    q8b = vld1q_s8_x2(qs + 32);\n    auto p34 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[2], q8b.val[0]), b[3], q8b.val[1]);\n    auto p1234 = vpaddq_s32(p12, p34);\n    q8b = vld1q_s8_x2(qs + 64);\n    auto p56 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[4], q8b.val[0]), b[5], q8b.val[1]);\n    q8b = vld1q_s8_x2(qs + 96);\n    auto p78 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[6], q8b.val[0]), b[7], q8b.val[1]);\n    return vpaddq_s32(p1234, vpaddq_s32(p56, p78));\n}\n\ntemplate <int nrc> struct Q80 {\n\n    constexpr static int nrc_y = nrc;\n\n    Q80(const DataInfo& info) {\n        for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_0 *)info.src1_row(iy);\n    }\n\n    inline const int8_t * quant_data(int iy, int i) const {\n        const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y[iy] + i;\n        return y4->qs;\n    }\n\n    inline 
float16x4_t load_scales(int iy, int i) const {\n        const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y[iy] + i;\n        return vld1_f16((const float16_t *)y4->d);\n    }\n\n    template <typename Dequantizer>\n    inline void process_scales(int i, Dequantizer& deq, float16x4_t * sc16, float32x4_t * /*acc*/) const {\n        auto qx_scales = deq.new_block(i);\n        for (int iy = 0; iy < nrc; ++iy) {\n            auto q8_scales = load_scales(iy, i);\n            sc16[iy] = vmul_f16(qx_scales, q8_scales);\n        }\n    }\n\n    template <typename Dequantizer>\n    inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const {\n        deq.prepare1(i);\n        float d = GGML_FP16_TO_FP32(deq.x[i].d);\n        for (int iy = 0; iy < nrc; ++iy) {\n            auto q8b = vld1q_s8_x2(y[iy][i].qs);\n            auto p = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), deq.bits.b[0], q8b.val[0]), deq.bits.b[1], q8b.val[1]);\n            acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*GGML_FP16_TO_FP32(y[iy][i].d)), vcvtq_f32_s32(p));\n        }\n    }\n\n    const block_q8_0 * y[nrc_y];\n};\n\ntemplate <int nrc> struct Q81 {\n\n    constexpr static int nrc_y = nrc;\n\n    Q81(const DataInfo& info) {\n        for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_1 *)info.src1_row(iy);\n    }\n\n    inline const int8_t * quant_data(int iy, int i) const {\n        const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y[iy] + i;\n        return y4->qs;\n    }\n\n    inline float16x8_t load_scales(int iy, int i) const {\n        const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y[iy] + i;\n        return vld1q_f16((const float16_t *)y4->d);\n    }\n\n    template <typename Dequantizer>\n    inline void process_scales(int i, Dequantizer& deq, float16x4_t * sc16, float32x4_t * acc) const {\n        auto qx_scales = deq.new_block(i);\n        for (int iy = 0; iy < nrc; ++iy) {\n            auto q8_scales = load_scales(iy, i);\n            auto m = 
vmul_f16(vget_high_f16(qx_scales), vget_high_f16(q8_scales));\n            acc[iy] = vaddq_f32(acc[iy], vcvt_f32_f16(m));\n            sc16[iy] = vmul_f16(vget_low_f16(qx_scales), vget_low_f16(q8_scales));\n        }\n    }\n\n    template <typename Dequantizer>\n    inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const {\n        deq.prepare1(i);\n        float d = GGML_FP16_TO_FP32(deq.x[i].d), m = 0.25f*GGML_FP16_TO_FP32(deq.x[i].m);\n        for (int iy = 0; iy < nrc; ++iy) {\n            auto q8b = vld1q_s8_x2(y[iy][i].qs);\n            auto p = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), deq.bits.b[0], q8b.val[0]), deq.bits.b[1], q8b.val[1]);\n            acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*GGML_FP16_TO_FP32(y[iy][i].d)), vcvtq_f32_s32(p));\n            acc[iy] = vaddq_f32(acc[iy], vdupq_n_f32(m*GGML_FP16_TO_FP32(y[iy][i].s)));\n        }\n    }\n\n    const block_q8_1 * y[nrc_y];\n};\n\ntemplate <typename block_q>\nstruct BaseLegacyDequantizer {\n\n    BaseLegacyDequantizer(const void * vx, size_t bx) : vx(vx), x(nullptr), bx(bx) {}\n\n    inline void new_row(int ix) { x = (const block_q *)((const char *)vx + bx*ix); }\n\n    Q4LegacyBits bits;\n\n    const void * vx;\n    const block_q * x;\n    size_t bx;\n};\n\nstruct DequantizerQ40 final : public BaseLegacyDequantizer<block_q4_0> {\n\n    DequantizerQ40(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}\n\n    inline void prepare1(int i, int8x16_t * q) const {\n        bits.prepare1(x[i].qs, q);\n        q[0] = vaddq_s8(q[0], m8);\n        q[1] = vaddq_s8(q[1], m8);\n    }\n    inline void prepare1(int i) {\n        prepare1(i, bits.b);\n    }\n\n    inline float16x4_t new_block(int i) {\n        ggml_half aux[4];\n        for (int k = 0; k < 4; ++k) {\n            aux[k] = x[4*i+k].d;\n            prepare1(4*i+k, bits.b + 2*k);\n        }\n        return vld1_f16((const float16_t *)aux);\n    }\n\n    const int8x16_t m8 = vdupq_n_s8(-8);\n    //ggml_half 
aux[4];\n};\n\nstruct DequantizerQ41 : public BaseLegacyDequantizer<block_q4_1> {\n\n    DequantizerQ41(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}\n\n    inline void prepare1(int i) {\n        bits.prepare1(x[i].qs);\n    }\n\n    inline float16x8_t new_block(int i) {\n        uint32_t aux32[4];\n        const uint32_t * s32 = (const uint32_t *)&x[4*i].d;\n        for (int k = 0; k < 4; ++k) {\n            aux32[k] = *s32; s32 += sizeof(block_q4_1)/4;\n            bits.prepare1(x[4*i+k].qs, bits.b + 2*k);\n        }\n        return vreinterpretq_f16_u8(vqtbl1q_u8(vld1q_u8((const uint8_t *)aux32), vreinterpretq_u8_u64(shuffle)));\n    }\n    // Leaving this commented out attempt to be reminded that I already tried this.\n    // It has basically the same performance as the version above.\n    //inline float16x8_t new_block(int i) {\n    //    uint32x4_t scales = {};\n    //    const block_q4_1 * xi = x + 4*i;\n    //    const uint32_t * s32 = (const uint32_t *)&xi->d;\n    //    scales = vsetq_lane_u32(*s32, scales, 0); s32 += sizeof(block_q4_1)/4;\n    //    bits.prepare1(xi[0].qs, bits.b + 0);\n    //    scales = vsetq_lane_u32(*s32, scales, 1); s32 += sizeof(block_q4_1)/4;\n    //    bits.prepare1(xi[1].qs, bits.b + 2);\n    //    scales = vsetq_lane_u32(*s32, scales, 2); s32 += sizeof(block_q4_1)/4;\n    //    bits.prepare1(xi[2].qs, bits.b + 4);\n    //    scales = vsetq_lane_u32(*s32, scales, 3);\n    //    bits.prepare1(xi[3].qs, bits.b + 6);\n    //    return vreinterpretq_f16_u8(vqtbl1q_u8(vreinterpretq_u8_u32(scales), vreinterpretq_u8_u64(shuffle)));\n    //}\n\n    const uint64x2_t shuffle = {0x0d0c090805040100, 0x0f0e0b0a07060302};\n};\n\nstruct HighBit5Legacy {\n    inline uint8x16_t to_bytes(const uint8_t * qh) const {\n        uint8x16_t h = vqtbl1q_u8(vreinterpretq_u8_u16(vdupq_n_u16(*(const uint16_t *)qh)), shuffle);\n        return vceqq_u8(vandq_u8(h, vreinterpretq_u8_u64(mask)), vreinterpretq_u8_u64(mask));\n    }\n    inline 
uint8x16_t to_negated_bytes(const uint8_t * qh) const {\n        uint8x16_t h = vqtbl1q_u8(vreinterpretq_u8_u16(vdupq_n_u16(*(const uint16_t *)qh)), shuffle);\n        return vceqq_u8(vandq_u8(h, vreinterpretq_u8_u64(mask)), vdupq_n_u8(0));\n    }\n    const uint64x2_t mask = vdupq_n_u64(0x8040201008040201);\n    const uint8x16_t shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1));\n};\n\nstruct DequantizerQ50 final : public BaseLegacyDequantizer<block_q5_0> {\n\n    DequantizerQ50(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}\n\n    inline void prepare1(int i, int8x16_t * q) const {\n        bits.prepare1(x[i].qs, q);\n        auto qh = x[i].qh;\n        q[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[0]), vandq_u8(mh, hbits.to_negated_bytes(qh+0))));\n        q[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[1]), vandq_u8(mh, hbits.to_negated_bytes(qh+2))));\n    }\n    inline void prepare1(int i) {\n        prepare1(i, bits.b);\n    }\n\n    inline float16x4_t new_block(int i) {\n        ggml_half aux[4];\n        for (int k = 0; k < 4; ++k) {\n            aux[k] = x[4*i+k].d;\n            prepare1(4*i+k, bits.b + 2*k);\n        }\n        return vld1_f16((const float16_t *)aux);\n    }\n\n    HighBit5Legacy hbits;\n\n    const uint8x16_t mh = vdupq_n_u8(0xf0);\n\n};\n\nstruct DequantizerQ80 final : public BaseLegacyDequantizer<block_q8_0> {\n\n    DequantizerQ80(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}\n\n    inline void prepare1(int i) {\n        bits.b[0] = vld1q_s8(x[i].qs);\n        bits.b[1] = vld1q_s8(x[i].qs+16);\n    }\n\n    inline float16x4_t new_block(int i) {\n        ggml_half aux[4];\n        for (int k = 0; k < 4; ++k) {\n            aux[k] = x[4*i+k].d;\n            bits.b[2*k+0] = vld1q_s8(x[4*i+k].qs);\n            bits.b[2*k+1] = vld1q_s8(x[4*i+k].qs+16);\n        }\n        return vld1_f16((const float16_t *)aux);\n    }\n\n};\n\nstruct DequantizerQ51 final : public 
BaseLegacyDequantizer<block_q5_1> {\n\n    DequantizerQ51(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}\n\n    inline void prepare1(int i, int8x16_t * q) const {\n        bits.prepare1(x[i].qs, q);\n        auto qh = x[i].qh;\n        q[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[0]), vandq_u8(mh, hbits.to_bytes(qh+0))));\n        q[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[1]), vandq_u8(mh, hbits.to_bytes(qh+2))));\n    }\n    inline void prepare1(int i) {\n        bits.prepare1(x[i].qs, bits.b);\n    }\n\n    inline float16x8_t new_block(int i) {\n        uint32_t aux32[4];\n        const uint32_t * s32 = (const uint32_t *)&x[4*i].d;\n        for (int k = 0; k < 4; ++k) {\n            aux32[k] = *s32; s32 += sizeof(block_q5_1)/4;\n            prepare1(4*i+k, bits.b + 2*k);\n        }\n        return vreinterpretq_f16_u8(vqtbl1q_u8(vld1q_u8((const uint8_t *)aux32), vreinterpretq_u8_u64(shuffle)));\n    }\n\n    HighBit5Legacy hbits;\n\n    const uint8x16_t mh = vdupq_n_u8(0x10);\n    const uint64x2_t shuffle = {0x0d0c090805040100, 0x0f0e0b0a07060302};\n\n};\n\ntemplate <typename Dequantizer, typename Q8>\ninline void sum_4(int i, Dequantizer& deq, const Q8& q8, const float16x4_t * sc16, float32x4_t * acc) {\n    for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n        auto pall = sum_4_blocks(deq.bits.b, q8.quant_data(iy, i));\n        auto scale = vcvt_f32_f16(sc16[iy]);\n        acc[iy] = vmlaq_f32(acc[iy], scale, vcvtq_f32_s32(pall));\n    }\n}\n\ntemplate <typename Dequantizer, typename Q8>\ninline void mul_mat_qX_Y_q8_Y(int n, Dequantizer& deq, Q8& q8, const DataInfo& info, int nrc_x) {\n    const int nb = n / QK4_1;\n\n    float16x4_t sc16[Q8::nrc_y];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        deq.new_row(ix);\n\n        float32x4_t acc[Q8::nrc_y];\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);\n\n        for (int i = 0; i < nb/4; ++i) {\n            q8.process_scales(i, deq, sc16, 
acc);\n            sum_4(i, deq, q8, sc16, acc);\n        }\n        for (int i = 4*(nb/4); i < nb; ++i) {\n            q8.process_1_block(i, deq, acc);\n        }\n\n        for (int iy = 0; iy < Q8::nrc_y; ++iy) {\n            info.store(ix, iy, vaddvq_f32(acc[iy]));\n        }\n    }\n}\n\ntemplate <typename Dequantizer, typename Q8>\ninline void mul_mat_qX_Y_q8_Y_1(int n, Dequantizer& deq1, Dequantizer& deq2, Q8& q8, const DataInfo& info, int nrc_x) {\n    const int nb = n / QK4_1;\n\n    float16x4_t sc16[2];\n\n    for (int ix = 0; ix < nrc_x; ++ix) {\n\n        deq1.new_row(ix);\n        deq2.new_row(ix);\n\n        float32x4_t acc[2] = { vdupq_n_f32(0.f), vdupq_n_f32(0.f) };\n\n        for (int i = 0; i < nb/8; ++i) {\n            q8.process_scales(2*i+0, deq1, sc16+0, acc+0);\n            q8.process_scales(2*i+1, deq2, sc16+1, acc+1);\n            sum_4(2*i+0, deq1, q8, sc16+0, acc+0);\n            sum_4(2*i+1, deq2, q8, sc16+1, acc+1);\n        }\n        for (int i = 2*(nb/8); i < nb/4; ++i) {\n            q8.process_scales(i, deq1, sc16, acc);\n            sum_4(i, deq1, q8, sc16, acc);\n        }\n        for (int i = 4*(nb/4); i < nb; ++i) {\n            q8.process_1_block(i, deq1, acc);\n        }\n\n        info.store(ix, 0, vaddvq_f32(vaddq_f32(acc[0], acc[1])));\n    }\n}\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void IQK_NOINLINE mul_mat_qX_1_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    Q81<nrc_y> q8(info);\n    if constexpr (nrc_y == 1) {\n        Dequantizer deq1(vx, bx), deq2(vx, bx);\n        mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);\n    } else {\n        Dequantizer deq(vx, bx);\n        mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x);\n    }\n}\n\ntemplate <typename Dequantizer, int nrc_y>\nstatic void IQK_NOINLINE mul_mat_qX_0_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    Q80<nrc_y> q8(info);\n    if constexpr (nrc_y == 1) {\n        Dequantizer deq1(vx, 
bx), deq2(vx, bx);\n        mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);\n    } else {\n        Dequantizer deq(vx, bx);\n        mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x);\n    }\n}\n\ntemplate <typename Dequantizer>\nstatic void IQK_NOINLINE mul_mat_qX_1_q8_1_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    Dequantizer deq1(vx, bx), deq2(vx, bx);\n    Q81<1> q8(info);\n    mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);\n}\n\ntemplate <typename Dequantizer>\nstatic void IQK_NOINLINE mul_mat_qX_0_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {\n    Dequantizer deq1(vx, bx), deq2(vx, bx);\n    Q80<1> q8(info);\n    // Two dequantizers are passed, so this must call the two-dequantizer (_1) variant.\n    mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);\n}\n\ntemplate <typename Dequantizer> void MulMat::set_functions(MulMat& m) {\n    if constexpr (std::is_same_v<Dequantizer, DequantizerQ40> || std::is_same_v<Dequantizer, DequantizerQ50> ||\n                  std::is_same_v<Dequantizer, DequantizerQ80>) {\n        m.funcs[0] = mul_mat_qX_0_q8_0<Dequantizer, 1>;\n        m.funcs[1] = mul_mat_qX_0_q8_0<Dequantizer, 2>;\n        m.funcs[2] = mul_mat_qX_0_q8_0<Dequantizer, 3>;\n        m.funcs[3] = mul_mat_qX_0_q8_0<Dequantizer, 4>;\n        m.funcs[4] = mul_mat_qX_0_q8_0<Dequantizer, 5>;\n        m.funcs[5] = mul_mat_qX_0_q8_0<Dequantizer, 6>;\n        m.funcs[6] = mul_mat_qX_0_q8_0<Dequantizer, 7>;\n        m.funcs[7] = mul_mat_qX_0_q8_0<Dequantizer, 8>;\n    }\n    else if constexpr (std::is_same_v<Dequantizer, DequantizerQ41> || std::is_same_v<Dequantizer, DequantizerQ51>) {\n        m.funcs[0] = mul_mat_qX_1_q8_1<Dequantizer, 1>;\n        m.funcs[1] = mul_mat_qX_1_q8_1<Dequantizer, 2>;\n        m.funcs[2] = mul_mat_qX_1_q8_1<Dequantizer, 3>;\n        m.funcs[3] = mul_mat_qX_1_q8_1<Dequantizer, 4>;\n        m.funcs[4] = mul_mat_qX_1_q8_1<Dequantizer, 5>;\n        m.funcs[5] = mul_mat_qX_1_q8_1<Dequantizer, 6>;\n        m.funcs[6] = mul_mat_qX_1_q8_1<Dequantizer, 7>;\n        
m.funcs[7] = mul_mat_qX_1_q8_1<Dequantizer, 8>;\n    }\n    else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2XXS> || std::is_same_v<Dequantizer, DequantizerIQ3XXS>) {\n        m.funcs[0] = mul_mat_qX_K_q8_K_IQXXS<1, Dequantizer>;\n        m.funcs[1] = mul_mat_qX_K_q8_K_IQXXS<2, Dequantizer>;\n        m.funcs[2] = mul_mat_qX_K_q8_K_IQXXS<3, Dequantizer>;\n        m.funcs[3] = mul_mat_qX_K_q8_K_IQXXS<4, Dequantizer>;\n        m.funcs[4] = mul_mat_qX_K_q8_K_IQXXS<5, Dequantizer>;\n        m.funcs[5] = mul_mat_qX_K_q8_K_IQXXS<6, Dequantizer>;\n        m.funcs[6] = mul_mat_qX_K_q8_K_IQXXS<7, Dequantizer>;\n        m.funcs[7] = mul_mat_qX_K_q8_K_IQXXS<8, Dequantizer>;\n    }\n    else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2S> ||\n                       std::is_same_v<Dequantizer, DequantizerIQ3S> ||\n                       std::is_same_v<Dequantizer, DequantizerIQ2XS>) {\n        m.funcs[0] = mul_mat_qX_K_q8_K_IQ<1, Dequantizer>;\n        m.funcs[1] = mul_mat_qX_K_q8_K_IQ<2, Dequantizer>;\n        m.funcs[2] = mul_mat_qX_K_q8_K_IQ<3, Dequantizer>;\n        m.funcs[3] = mul_mat_qX_K_q8_K_IQ<4, Dequantizer>;\n        m.funcs[4] = mul_mat_qX_K_q8_K_IQ<5, Dequantizer>;\n        m.funcs[5] = mul_mat_qX_K_q8_K_IQ<6, Dequantizer>;\n        m.funcs[6] = mul_mat_qX_K_q8_K_IQ<7, Dequantizer>;\n        m.funcs[7] = mul_mat_qX_K_q8_K_IQ<8, Dequantizer>;\n    }\n    else {\n        m.funcs[0] = mul_mat_qX_K_q8_K_T<1, Dequantizer>;\n        m.funcs[1] = mul_mat_qX_K_q8_K_T<2, Dequantizer>;\n        m.funcs[2] = mul_mat_qX_K_q8_K_T<3, Dequantizer>;\n        m.funcs[3] = mul_mat_qX_K_q8_K_T<4, Dequantizer>;\n        m.funcs[4] = mul_mat_qX_K_q8_K_T<5, Dequantizer>;\n        m.funcs[5] = mul_mat_qX_K_q8_K_T<6, Dequantizer>;\n        m.funcs[6] = mul_mat_qX_K_q8_K_T<7, Dequantizer>;\n        m.funcs[7] = mul_mat_qX_K_q8_K_T<8, Dequantizer>;\n    }\n}\n\nbool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int Ny) {\n    row_size_q8 = 
ggml_row_size(GGML_TYPE_Q8_K, ne00);\n\n    (void)Ny;\n    // Uncommenting this would disable iqk_mul_mat for matrix x vector multiplications.\n    //if (Ny == 1 && (typeA == GGML_TYPE_IQ2_XXS || typeA == GGML_TYPE_IQ2_XS || typeA == GGML_TYPE_IQ2_S ||\n    //                typeA == GGML_TYPE_IQ3_XXS || typeA == GGML_TYPE_IQ3_S)) return false;\n\n    switch (typeA) {\n        case GGML_TYPE_Q2_K:\n            MulMat::set_functions<DequantizerQ2K>(m);\n            break;\n        case GGML_TYPE_Q3_K:\n            MulMat::set_functions<DequantizerQ3K>(m);\n            break;\n        case GGML_TYPE_Q4_K:\n            MulMat::set_functions<DequantizerQ4K>(m);\n            break;\n        case GGML_TYPE_Q5_K:\n            MulMat::set_functions<DequantizerQ5K>(m);\n            break;\n        case GGML_TYPE_Q6_K:\n            MulMat::set_functions<DequantizerQ6K>(m);\n            break;\n        case GGML_TYPE_IQ4_XS:\n            MulMat::set_functions<DequantizerIQ4XS>(m);\n            break;\n        case GGML_TYPE_IQ3_S:\n            MulMat::set_functions<DequantizerIQ3S>(m);\n            break;\n        case GGML_TYPE_IQ3_XXS:\n            MulMat::set_functions<DequantizerIQ3XXS>(m);\n            break;\n        case GGML_TYPE_IQ2_S:\n            MulMat::set_functions<DequantizerIQ2S>(m);\n            break;\n        case GGML_TYPE_IQ2_XS:\n            MulMat::set_functions<DequantizerIQ2XS>(m);\n            break;\n        case GGML_TYPE_IQ2_XXS:\n            MulMat::set_functions<DequantizerIQ2XXS>(m);\n            break;\n        case GGML_TYPE_Q4_0:\n            MulMat::set_functions<DequantizerQ40>(m);\n            row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);\n            break;\n        case GGML_TYPE_Q4_1:\n            MulMat::set_functions<DequantizerQ41>(m);\n            row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00);\n            break;\n        case GGML_TYPE_Q5_0:\n            MulMat::set_functions<DequantizerQ50>(m);\n            
row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);\n            break;\n        case GGML_TYPE_Q5_1:\n            MulMat::set_functions<DequantizerQ51>(m);\n            row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00);\n            break;\n        case GGML_TYPE_Q8_0:\n            MulMat::set_functions<DequantizerQ80>(m);\n            row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00);\n            break;\n        default:\n            return false;\n    }\n    return true;\n}\n\n}\n\n#endif // __x86_64__ or __aarch64__\n"
  },
  {
    "path": "third_party/llamafile/iqk_mul_mat_arm82.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/iqk_mul_mat_arm82.cpp\n// Copyrigth 2024 Iwan Kawrakow.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#ifdef __aarch64__\n#define iqk_mul_mat iqk_mul_mat_arm82\n#define iqk_mul_mat_moe iqk_mul_mat_moe_arm82\n#include \"iqk_mul_mat_arm.inc\"\n#endif  // __aarch64__\n"
  },
  {
    "path": "third_party/llamafile/macros.h",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/macros.h\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-\n// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi\n#pragma once\n\n#define MIN(X, Y) ((Y) > (X) ? (X) : (Y))\n#define MAX(X, Y) ((Y) < (X) ? (X) : (Y))\n#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))\n#define ROUNDUP(X, K) (((X) + (K) - 1) & -(K))\n#define ARRAYLEN(A) ((sizeof(A) / sizeof(*(A))) / ((unsigned)!(sizeof(A) % sizeof(*(A)))))\n"
  },
  {
    "path": "third_party/llamafile/micros.h",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/micros.h\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-\n// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi\n#pragma once\n\n#include <ctime>\n\n#ifndef _WIN32\n#include <unistd.h>\n#else\n#include <windows.h>\n#endif\n\n#ifdef _WIN32\nstatic long long GetQueryPerformanceFrequency() {\n    LARGE_INTEGER t;\n    QueryPerformanceFrequency(&t);\n    return t.QuadPart;\n}\nstatic long long GetQueryPerformanceCounter() {\n    LARGE_INTEGER t;\n    QueryPerformanceCounter(&t);\n    return t.QuadPart;\n}\n#endif\n\nstatic long long micros(void) {\n#ifndef _WIN32\n    struct timespec ts;\n    clock_gettime(CLOCK_REALTIME, &ts);\n    return ts.tv_sec * 1000000 + (ts.tv_nsec + 999) / 1000;\n#else\n    static long long timer_freq = GetQueryPerformanceFrequency();\n    static long long timer_start = GetQueryPerformanceCounter();\n    return ((GetQueryPerformanceCounter() - timer_start) * 1000000) / timer_freq;\n#endif\n}\n"
  },
  {
    "path": "third_party/llamafile/numba.h",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/numba.h\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#pragma once\n\ninline int rand32(void) {\n    static unsigned long long lcg = 1;\n    lcg *= 6364136223846793005;\n    lcg += 1442695040888963407;\n    return lcg >> 32;\n}\n\ninline int popcount(unsigned x) {\n    x = x - ((x >> 1) & 0x55555555);\n    x = ((x >> 2) & 0x33333333) + (x & 0x33333333);\n    x = (x + (x >> 4)) & 0x0F0F0F0F;\n    x = (x + (x >> 16));\n    return (x + (x >> 8)) & 0x0000003F;\n}\n\ninline int hamming(int x, int y) {\n    return popcount(x ^ y);\n}\n\ninline float float01(unsigned x) {  // (0,1)\n    return 1.f / 8388608 * ((x >> 9) + .5f);\n}\n\ninline float numba(void) {  // (-10,10)\n    return float01(rand32()) * 2.f - 1.f;\n}\n\ntemplate <typename T>\nvoid randomize(T* A, int n) {\n    for (int i = 0; i < n; ++i)\n        A[i] = numba();\n}\n\ntemplate <typename T>\nvoid randomize(int m, int n, T* A, int lda) {\n    for (int j = 0; j < n; ++j)\n        for (int i = 0; i < m; ++i)\n            A[lda * j + i] = numba();\n}\n\ntemplate <typename T, typename U>\nvoid broadcast(T* A, int n, U x) {\n    for (int i = 0; i < n; ++i)\n        A[i] = x;\n}\n\ntemplate <typename T, typename U>\nvoid broadcast(int m, int n, T* A, int lda, U x) {\n    for (int j = 0; j < n; ++j)\n        for (int i = 0; i < m; ++i)\n            A[lda * j + i] = x;\n}\n"
  },
  {
    "path": "third_party/llamafile/sgemm.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/sgemm.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-\n// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi\n//\n// Copyright 2024 Mozilla Foundation\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include \"sgemm.h\"\n// #include <cosmo.h>\n// #include <cpuid.h>\n// #include <libc/sysv/consts/hwcap.h>\n#include <stdio.h>\n// #include <sys/auxv.h>\n#include <cassert>\n// #include \"llamafile.h\"\n\nstatic const struct GemmFuncs {\n    bool (*sgemm)(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\n    bool (*mixmul)(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\n    bool (*iqk_mixmul)(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);\n    // typeof(llamafile_sgemm)* sgemm;\n    // typeof(llamafile_mixmul)* mixmul;\n    // typeof(llamafile_mixmul_iqk)* iqk_mixmul = iqk_mul_mat_moe_unsupported;\n    GemmFuncs() {\n#if defined(__x86_64__) || defined(_M_X64)\n        // if (X86_HAVE(AVX)) {\n        //     if (X86_HAVE(FMA)) {\n        //         if (X86_HAVE(AVX2)) {\n        //             if (X86_HAVE(AVX512F)) {\n        //          
       if (X86_HAVE(AVX512VL) &&     //\n        //                     X86_HAVE(AVX512BW) &&     //\n        //                     X86_HAVE(AVX512DQ) &&     //\n        //                     X86_HAVE(AVX512_VNNI) &&  //\n        //                     X86_HAVE(AVX512_BF16)) {\n        //                     // AMD Zen4+ (2023-)\n        //                     sgemm = llamafile_sgemm_amd_zen4;\n        //                     mixmul = llamafile_mixmul_amd_zen4;\n        //                     iqk_mixmul = iqk_mul_mat_moe_zen4;\n        //                 } else {\n        //                     // Intel Xeon Skylake+ (2015-)\n        //                     sgemm = llamafile_sgemm_amd_avx512f;\n        //                     mixmul = llamafile_mixmul_amd_avx512f;\n        //                     iqk_mixmul = iqk_mul_mat_moe;\n        //                 }\n        //             } else if (X86_HAVE(AVXVNNI)) {\n        //                 // Intel Alderlake (2021-)\n        //                 sgemm = llamafile_sgemm_amd_avxvnni;\n        //                 mixmul = llamafile_mixmul_amd_avxvnni;\n        //                 iqk_mixmul = iqk_mul_mat_moe;\n        //             } else {\n        //                 // Intel Haswell/Broadwell/Skylake (2013-2020)\n        //                 // AMD Excavator (2015-2022)\n        //                 sgemm = llamafile_sgemm_amd_avx2;\n        //                 mixmul = llamafile_mixmul_amd_avx2;\n        //                 if (X86_HAVE(F16C))\n        //                     iqk_mixmul = iqk_mul_mat_moe;\n        //             }\n        //         } else {\n        //             // AMD Piledriver (2011-2014)\n        //             sgemm = llamafile_sgemm_amd_fma;\n        //             mixmul = llamafile_mixmul_amd_fma;\n        //             if (X86_HAVE(F16C))\n        //                 iqk_mixmul = iqk_mul_mat_moe;\n        //         }\n        //     } else {\n        //         // Intel Sandybridge/Ivybridge 
(2010-2012)\n        //         // AMD Bulldozer (2011)\n        //         sgemm = llamafile_sgemm_amd_avx;\n        //         mixmul = llamafile_mixmul_amd_avx;\n        //     }\n        // } else {\n        //     // AMD K8/Barcelona (2003-2010)\n        //     // Intel Core/Nehalem (2006-2009)\n        //     sgemm = llamafile_sgemm_unsupported;\n        //     mixmul = llamafile_mixmul_unsupported;\n        // }\n\n#if defined(__AVX__)\n#if defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)))\n#if defined(__AVX2__)\n#if defined(__AVX512F__)\n#if defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) && defined(__AVX512VNNI__) && defined(__AVX512BF16__)\n        // AMD Zen4+ (2023-)\n        sgemm = llamafile_sgemm_amd_zen4;\n        mixmul = llamafile_mixmul_amd_zen4;\n        iqk_mixmul = iqk_mul_mat_moe_zen4;\n#else\n        // Intel Xeon Skylake+ (2015-)\n        sgemm = llamafile_sgemm_amd_avx512f;\n        mixmul = llamafile_mixmul_amd_avx512f;\n        iqk_mixmul = iqk_mul_mat_moe;\n#endif\n#elif defined(__AVXVNNI__)\n        // Intel Alderlake (2021-)\n        sgemm = llamafile_sgemm_amd_avxvnni;\n        mixmul = llamafile_mixmul_amd_avxvnni;\n        iqk_mixmul = iqk_mul_mat_moe;\n#else\n        // Intel Haswell/Broadwell/Skylake (2013-2020)\n        // AMD Excavator (2015-2022)\n        sgemm = llamafile_sgemm_amd_avx2;\n        mixmul = llamafile_mixmul_amd_avx2;\n#if defined(__F16C__)\n        iqk_mixmul = iqk_mul_mat_moe;\n#endif\n#endif\n#else\n        // AMD Piledriver (2011-2014)\n        sgemm = llamafile_sgemm_amd_fma;\n        mixmul = llamafile_mixmul_amd_fma;\n#if defined(__F16C__)\n        iqk_mixmul = iqk_mul_mat_moe;\n#endif\n#endif\n#else\n        // Intel Sandybridge/Ivybridge (2010-2012)\n        // AMD Bulldozer (2011)\n        sgemm = llamafile_sgemm_amd_avx;\n        mixmul = llamafile_mixmul_amd_avx;\n#endif\n#else\n        // AMD K8/Barcelona (2003-2010)\n        // Intel 
Core/Nehalem (2006-2009)\n        sgemm = llamafile_sgemm_unsupported;\n        mixmul = llamafile_mixmul_unsupported;\n#endif\n\n#elif defined(__aarch64__)\n        // long hwcap = getauxval(AT_HWCAP);\n        // if ((hwcap & HWCAP_FPHP) &&     // fp16 scalar isa (ID_AA64PFR0_EL1.FP == 1)\n        //     (hwcap & HWCAP_ASIMDHP) &&  // fp16 vector isa (ID_AA64PFR0_EL1.AdvSIMD == 1)\n        //     (hwcap & HWCAP_ASIMDDP)) {  // dotprod isa (ID_AA64ISAR0_EL1.DP == 1)\n        //     // e.g. Apple M1, Raspberry Pi 5\n            sgemm = llamafile_sgemm_arm82;\n            mixmul = llamafile_mixmul_arm82;\n            iqk_mixmul = iqk_mul_mat_moe_arm82;\n        // } else {\n            // ARM64 baseline ISA\n        //     sgemm = llamafile_sgemm_arm80;\n        //     mixmul = llamafile_mixmul_arm80;\n        // }\n#else\n        sgemm = llamafile_sgemm_unsupported;\n        mixmul = llamafile_mixmul_unsupported;\n#endif\n    }\n} funcs;\n\n/**\n * Performs optimized matrix multiplication on CPU.\n *\n * This subroutine may compute C = Aᵀ * B with column major ordering.\n * Despite its name, this isn't a generalized implementation. 
Work is\n * only performed when a handwritten kernel is written and available.\n * Otherwise the caller should fall back to a general matmul routine.\n *\n * @param m is rows in `A` and `C`\n * @param n is cols in `B` and `C`\n * @param k is cols in `A` and rows in `B`\n * @param A is first input matrix (always transposed)\n * @param lda is row stride of `A`\n * @param B is second input matrix (never transposed)\n * @param ldb is row stride of `B`\n * @param C is input/output array of output matrices\n * @param ldc is row stride of `C`\n * @param ith is thread id (must be less than `nth`)\n * @param nth is number of threads (must be greater than zero)\n * @param task is GGML task type\n * @param Atype is GGML data type of `A`\n * @param Btype is GGML data type of `B`\n * @param Ctype is GGML data type of `C`\n * @param precision may be used to control the internal compute type\n * @return true if this function was able to service the matmul request\n */\nbool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {\n    return funcs.sgemm(m, n, k, A, lda, B, ldb, C, ldc, ith, nth, task, Atype, Btype, Ctype,\n                       precision);\n}\n\n/**\n * Performs \"mixture of experts\" tensor multiplication on CPU.\n */\nbool llamafile_mixmul(const ggml_compute_params* params, const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan, ggml_tensor* result) {\n    return funcs.mixmul(params, weights, thought, plan, result);\n}\n\nbool llamafile_mixmul_iqk(long Nx, long Ny, long ne00, int ne11, int typeA, const void* A, const void* B, float* C, long nb1, long nb2, const void* vrow_mapping, int ith, int nth) {\n    return funcs.iqk_mixmul(Nx, Ny, ne00, ne11, typeA, A, B, C, nb1, nb2, vrow_mapping, ith, nth);\n}\n"
  },
  {
    "path": "third_party/llamafile/sgemm.h",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/sgemm.h\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#pragma once\n#include <stdbool.h>\n#include <cstddef>\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\nstruct ggml_tensor;\nstruct ggml_compute_params;\n#ifdef __aarch64__\n\nbool iqk_mul_mat(long, long, long, int, const void*, const void*, float*, long, int, int);\nbool iqk_mul_mat_zen4(long, long, long, int, const void*, const void*, float*, long, int, int);\nbool iqk_mul_mat_arm82(long, long, long, int, const void*, const void*, float*, long, int, int);\n\nbool iqk_mul_mat_moe(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);\nbool iqk_mul_mat_moe_zen4(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);\nbool iqk_mul_mat_moe_arm82(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);\nbool iqk_mul_mat_moe_unsupported(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);\n\nbool llamafile_sgemm(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\nbool llamafile_mixmul(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\nsize_t llamafile_mixmul_needs(const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*);\n\nbool llamafile_sgemm_unsupported(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\nbool llamafile_sgemm_amd_avx(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\nbool llamafile_sgemm_amd_fma(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\nbool 
llamafile_sgemm_amd_avx2(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\nbool llamafile_sgemm_amd_avxvnni(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\nbool llamafile_sgemm_amd_avx512f(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\nbool llamafile_sgemm_amd_zen4(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\nbool llamafile_sgemm_arm80(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\nbool llamafile_sgemm_arm82(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\n\nbool llamafile_mixmul_unsupported(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\nbool llamafile_mixmul_amd_avx(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\nbool llamafile_mixmul_amd_fma(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\nbool llamafile_mixmul_amd_avx2(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\nbool llamafile_mixmul_amd_avxvnni(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\nbool llamafile_mixmul_amd_avx512f(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\nbool llamafile_mixmul_amd_zen4(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct 
ggml_tensor*);\nbool llamafile_mixmul_arm80(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\nbool llamafile_mixmul_arm82(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\nbool llamafile_mixmul_iqk(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);\n\n#else\n\nbool iqk_mul_mat(long, long, long,int, const void*, long, int, const void*, long,float*, long, int, int);\nbool iqk_mul_mat_zen4(long, long, long,int, const void*, long, int, const void*, long,float*, long, int, int);\nbool iqk_mul_mat_arm82(long, long, long,int, const void*, long, int, const void*, long,float*, long, int, int);\n\n\nbool iqk_mul_mat_moe(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);\nbool iqk_mul_mat_moe_zen4(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);\nbool iqk_mul_mat_moe_arm82(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);\nbool iqk_mul_mat_moe_unsupported(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);\n\nbool llamafile_sgemm(long m, long n, long k, const void* a, long lda, const void* b, long ldb, void* c, long ldc, int ith, int nth, int task_type, int a_type, int b_type, int c_type, int precision);\nbool llamafile_mixmul(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\nsize_t llamafile_mixmul_needs(const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*);\n\nbool llamafile_sgemm_unsupported(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\nbool llamafile_sgemm_amd_avx(long, long, long, const 
void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\nbool llamafile_sgemm_amd_fma(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\nbool llamafile_sgemm_amd_avx2(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\nbool llamafile_sgemm_amd_avxvnni(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\nbool llamafile_sgemm_amd_avx512f(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\nbool llamafile_sgemm_amd_zen4(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\nbool llamafile_sgemm_arm80(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\nbool llamafile_sgemm_arm82(long, long, long, const void*, long, const void*, long, void*, long, int, int, int, int, int, int, int);\n\nbool llamafile_mixmul_unsupported(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\nbool llamafile_mixmul_amd_avx(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\nbool llamafile_mixmul_amd_fma(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\nbool llamafile_mixmul_amd_avx2(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\nbool llamafile_mixmul_amd_avxvnni(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\nbool llamafile_mixmul_amd_avx512f(const struct ggml_compute_params*, const struct ggml_tensor*, const struct 
ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\nbool llamafile_mixmul_amd_zen4(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\nbool llamafile_mixmul_arm80(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\nbool llamafile_mixmul_arm82(const struct ggml_compute_params*, const struct ggml_tensor*, const struct ggml_tensor*, const struct ggml_tensor*, struct ggml_tensor*);\nbool llamafile_mixmul_iqk(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int);\n\n#endif\n\n#ifdef __cplusplus\n}\n#endif\n"
  },
  {
    "path": "third_party/llamafile/tinyblas_cpu.h",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu.h\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-\n// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi\n//\n// Copyright 2024 Mozilla Foundation\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n//\n//\n//                                ██████╗ ██╗   █████╗ ██████╗\n//         ██████╗██╗██╗ ██╗██═██╗██╔══██╗██║  ██╔══██╗██╔═══╝\n//         ╚═██╔═╝██║███▄██║██ ██║██████╔╝██║  ███████║██████╗\n//           ██║  ██║██▀███║╚███╔╝██╔══██╗██║  ██╔══██║╔═══██║\n//           ██║  ██║██║ ██║ ███║ ██████╔╝████╗██║  ██║██████║\n//           ╚═╝  ╚═╝╚═╝ ╚═╝ ╚══╝ ╚═════╝ ╚═══╝╚═╝  ╚═╝╚═════╝\n//\n//                   BASIC LINEAR ALGEBRA SUBPROGRAMS\n//\n//\n// This file implements multithreaded CPU matrix multiplication for the\n// common contiguous use case C = Aᵀ * B. These kernels are designed to\n// have excellent performance[1] for matrices that fit in the CPU cache\n// without imposing any overhead such as cache filling or malloc calls.\n//\n// This implementation does not guarantee any upper bound with rounding\n// errors, which grow along with k. Our goal's to maximally exploit the\n// hardware for performance, and then use whatever resources remain for\n// improving numerical accuracy.\n//\n// [1] J. 
Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].\n//     Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].\n\n#pragma once\n\n#include \"llama.cpp/ggml-impl.h\"\n#include \"llama.cpp/ggml-quants.h\"\n// #include \"log.h\"\n#include \"flags.h\"\n#include \"sgemm.h\"\n// #include <cosmo.h>\n\n#pragma GCC diagnostic ignored \"-Wpedantic\"\n#pragma GCC diagnostic ignored \"-Wignored-attributes\"\n\n#define ROW_ALIGN 64\n#define MATRIX_ALIGN 4096\n#define MAX_ALIGN 4096\n\n#ifdef _MSC_VER\n#define NOINLINE __declspec(noinline)\n#else\n#define NOINLINE __attribute__((__noinline__))\n#endif\n\n#if defined(__ARM_NEON) || defined(__AVX512F__)\n#define VECTOR_REGISTERS 32\n#else\n#define VECTOR_REGISTERS 16\n#endif\n\n#if 0\n#define NOT_SUPPORTED tinyBLAS_not_supported(__FILE__, __LINE__)\n#else\n#define NOT_SUPPORTED false\n#endif\n#define WANT_QUANTIZATION false\n\nnamespace {\n\nbool tinyBLAS_not_supported(const char* file, int line) {\n    // tinylogf(\"%s:%d: tinyBLAS not supported\\n\", file, line);\n    return false;\n}\n\ninline float unhalf(ggml_fp16_t d) {\n    return GGML_FP16_TO_FP32(d);\n}\ninline float unhalf(ggml_bf16_t d) {\n    return GGML_BF16_TO_FP32(d);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n// MATRIX MEMORY INDEXING\n\n#define NCA 1\n#define NCB 2\n#define NCC 4\n\n#define INDEX(A, lda, j, i) (CONFIG & NC##A ? 
((T##A**)A)[j] + i : A + lda * (j) + i)\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n// GGML TYPE TRAITS\n\ntemplate <typename T>\nstruct ggml_type_trait;\ntemplate <>\nstruct ggml_type_trait<float> {\n    static constexpr ggml_type id = GGML_TYPE_F32;\n};\ntemplate <>\nstruct ggml_type_trait<ggml_bf16_t> {\n    static constexpr ggml_type id = GGML_TYPE_BF16;\n};\ntemplate <>\nstruct ggml_type_trait<ggml_fp16_t> {\n    static constexpr ggml_type id = GGML_TYPE_F16;\n};\ntemplate <>\nstruct ggml_type_trait<block_q8_0> {\n    static constexpr ggml_type id = GGML_TYPE_Q8_0;\n};\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n// VECTORIZED ARITHMETIC OPERATIONS\n\n#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)\ninline __m128 add(__m128 x, __m128 y) {\n    return _mm_add_ps(x, y);\n}\ninline __m128 sub(__m128 x, __m128 y) {\n    return _mm_sub_ps(x, y);\n}\ninline __m128 mul(__m128 x, __m128 y) {\n    return _mm_mul_ps(x, y);\n}\n#endif  // __SSE__\n\n#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)\ninline __m256 add(__m256 x, __m256 y) {\n    return _mm256_add_ps(x, y);\n}\ninline __m256 sub(__m256 x, __m256 y) {\n    return _mm256_sub_ps(x, y);\n}\ninline __m256 mul(__m256 x, __m256 y) {\n    return _mm256_mul_ps(x, y);\n}\n#endif  // __AVX__\n\n#if defined(__AVX512F__)\ninline __m512 add(__m512 x, __m512 y) {\n    return _mm512_add_ps(x, y);\n}\ninline __m512 sub(__m512 x, __m512 y) {\n    return _mm512_sub_ps(x, y);\n}\ninline __m512 mul(__m512 x, __m512 y) {\n    return _mm512_mul_ps(x, y);\n}\n#endif  // __AVX512F__\n\n#if defined(__ARM_NEON)\ninline float32x4_t add(float32x4_t x, float32x4_t y) {\n    return vaddq_f32(x, y);\n}\ninline float32x4_t sub(float32x4_t x, float32x4_t y) {\n    return vsubq_f32(x, y);\n}\ninline float32x4_t mul(float32x4_t x, float32x4_t y) {\n    return 
vmulq_f32(x, y);\n}\n#endif  // __ARM_NEON\n\n#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)\ninline float16x8_t add(float16x8_t x, float16x8_t y) {\n    return vaddq_f16(x, y);\n}\ninline float16x8_t sub(float16x8_t x, float16x8_t y) {\n    return vsubq_f16(x, y);\n}\ninline float16x8_t mul(float16x8_t x, float16x8_t y) {\n    return vmulq_f16(x, y);\n}\n#endif  // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n// VECTORIZED FUSED MULTIPLY ADD\n\n/**\n * Computes a * b + c.\n */\ntemplate <typename T, typename U>\ninline U madd(T a, T b, U c) {\n    return add(mul(a, b), c);\n}\n\n/**\n * Computes a * b + c with error correction.\n *\n * @see W. Kahan, \"Further remarks on reducing truncation errors,\"\n *    Communications of the ACM, vol. 8, no. 1, p. 40, Jan. 1965,\n *    doi: 10.1145/363707.363723.\n */\ntemplate <typename T, typename U>\ninline U madder(T a, T b, U c, U* e) {\n    U y = sub(mul(a, b), *e);\n    U t = add(c, y);\n    *e = sub(sub(t, c), y);\n    return t;\n}\n\n#ifdef __ARM_NEON\ninline float32x4_t badder(float32x4_t a, float b, float32x4_t c, float32x4_t* e) {\n    float32x4_t y = sub(vmulq_n_f32(a, b), *e);\n    float32x4_t t = add(c, y);\n    *e = sub(sub(t, c), y);\n    return t;\n}\n#endif\n\n#if defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)))\n#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)\ntemplate <>\ninline __m256 madd(__m256 a, __m256 b, __m256 c) {\n    return _mm256_fmadd_ps(a, b, c);\n}\n#endif\n#if defined(__AVX512F__)\ntemplate <>\ninline __m512 madd(__m512 a, __m512 b, __m512 c) {\n    return _mm512_fmadd_ps(a, b, c);\n}\n#endif\n#endif\n\n#if defined(__ARM_FEATURE_FMA)\ntemplate <>\ninline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {\n    return vfmaq_f32(c, a, b);\n}\n#if 0  // todo: this specialization chops gcc 12.3 performance in half\n#if 
defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) && 0\ntemplate <>\ninline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {\n    return vfmaq_f16(c, b, a);\n}\n#endif\n#endif\n#endif\n\n#if defined(__AVX512BF16__)\ntemplate <>\ninline __m512 madd(__m512bh x, __m512bh y, __m512 z) {\n    return _mm512_dpbf16_ps(z, x, y);\n}\ntemplate <>\ninline __m512 madder(__m512bh x, __m512bh y, __m512 z, __m512* _) {\n    return _mm512_dpbf16_ps(z, x, y);\n}\n#endif\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n// VECTORIZED HORIZONTAL SUM\n\n#if defined(__ARM_NEON)\ninline float hsum(float32x4_t x) {\n    return vaddvq_f32(x);\n}\n#endif  // __ARM_NEON\n\n#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)\ninline float hsum(float16x8_t x) {\n    // todo: this works great on clang but it produces terrible code on gcc 12.3\n    return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)), vcvt_f32_f16(vget_high_f16(x))));\n}\n#endif  // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC\n\n#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)\ninline float hsum(__m128 x) {\n#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)\n    x = _mm_add_ps(x, _mm_movehl_ps(x, x));\n    x = _mm_add_ss(x, _mm_movehdup_ps(x));\n#else\n    __m128 t;\n    t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));\n    x = _mm_add_ps(x, t);\n    t = _mm_movehl_ps(t, x);\n    x = _mm_add_ss(x, t);\n#endif\n    return _mm_cvtss_f32(x);\n}\n#endif\n\n#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)\ninline float hsum(__m256 x) {\n    return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x)));\n}\n#endif  // __AVX__\n\n#if defined(__AVX512F__)\ninline float hsum(__m512 x) {\n    return _mm512_reduce_add_ps(x);\n}\n#endif  // 
__AVX512F__\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n// VECTORIZED MEMORY LOADING\n\ntemplate <typename T, typename U>\nT load(const U*);\n\ntemplate <>\ninline float load(const float* p) {\n    return *p;\n}\ntemplate <>\ninline float load(const ggml_fp16_t* p) {\n    return unhalf(*p);\n}\ntemplate <>\ninline float load(const ggml_bf16_t* p) {\n    return unhalf(*p);\n}\n\n#if defined(__ARM_NEON)\ntemplate <>\ninline float32x4_t load(const float* p) {\n    return vld1q_f32(p);\n}\ntemplate <>\ninline float32x4_t load(const ggml_bf16_t* p) {\n    return vreinterpretq_f32_u32(vshll_n_u16(vld1_u16((const unsigned short*)p), 16));\n}\n#if !defined(_MSC_VER)\ntemplate <>\ninline float16x8_t load(const ggml_fp16_t* p) {\n    return vld1q_f16((const float16_t*)p);\n}\ntemplate <>\ninline float32x4_t load(const ggml_fp16_t* p) {\n    return vcvt_f32_f16(vld1_f16((const float16_t*)p));\n}\n#endif  // _MSC_VER\n#endif  // __ARM_NEON\n\n#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)\ntemplate <>\ninline __m128 load(const float* p) {\n    return _mm_loadu_ps(p);\n}\n#endif  // __SSE__\n\n#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)\ntemplate <>\ninline __m256 load(const float* p) {\n    return _mm256_loadu_ps(p);\n}\n#endif  // __AVX__\n\n#if defined(__AVX2__) || defined(__AVX512F__)\ntemplate <>\ninline __m256 load(const ggml_bf16_t* p) {\n    return _mm256_castsi256_ps(\n        _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i*)p)), 16));\n}\n#endif  // __AVX2__\n\n#if defined(__F16C__)\ntemplate <>\ninline __m256 load(const ggml_fp16_t* p) {\n    return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)p));\n}\n#endif  // __F16C__\n\n#if defined(__AVX512F__)\ntemplate <>\ninline __m512 load(const float* p) {\n    return _mm512_loadu_ps(p);\n}\ntemplate <>\ninline __m512 load(const ggml_fp16_t* p) {\n    return 
_mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)p));\n}\ntemplate <>\ninline __m512 load(const ggml_bf16_t* p) {\n    return _mm512_castsi512_ps(\n        _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i*)p)), 16));\n}\n#endif  // __AVX512F__\n\n#if defined(__AVX512BF16__)\ntemplate <>\ninline __m512bh load(const ggml_bf16_t* p) {\n    return (__m512bh)_mm512_loadu_ps((const float*)p);\n}\ntemplate <>\ninline __m512bh load(const float* p) {\n    return _mm512_cvtne2ps_pbh(_mm512_loadu_ps(p + 16), _mm512_loadu_ps(p));\n}\n#endif  // __AVX512BF16__\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n// FLOATING POINT OUTPUT STREAMING\n\ninline void store(float* p, float f) {\n    *p = f;\n}\n\ninline void store(ggml_fp16_t* p, float f) {\n    *p = GGML_FP32_TO_FP16(f);\n}\n\ninline void store(ggml_bf16_t* p, float f) {\n    *p = GGML_FP32_TO_BF16(f);\n}\n\n////////////////////////////////////////////////////////////////////////////////////////////////////\n// FLOATING POINT MATRIX MULTIPLICATION\n\ntemplate <int CONFIG, int KN, typename D, typename V, typename TA, typename TB, typename TC>\nclass tinyBLAS {\n   public:\n    tinyBLAS(long k, const TA* A, long lda, const TB* B, long ldb, TC* C, long ldc, int ith, int nth)\n        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {\n    }\n\n    void matmul(long m, long n, int task) {\n        if (task == GGML_TASK_TYPE_COMPUTE)\n            mnpack(0, m, 0, n);\n    }\n\n   private:\n    NOINLINE void mnpack(long m0, long m, long n0, long n) {\n        long mc, nc, mp, np;\n\n#if VECTOR_REGISTERS == 32\n        if (!FLAG_precise) {\n            switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {\n                case 0x55:\n                    mc = 5;\n                    nc = 5;\n                    gemm<5, 5, false>(m0, m, n0, n);\n                    break;\n                case 0x54:\n                case 0x53:\n 
               case 0x52:\n                case 0x45:\n                case 0x44:\n                case 0x43:\n                case 0x42:\n                case 0x35:\n                case 0x34:\n                case 0x33:\n                case 0x32:\n                case 0x25:\n                case 0x24:\n                case 0x23:\n                case 0x22:\n                    mc = 2;\n                    nc = 2;\n                    gemm<2, 2, false>(m0, m, n0, n);\n                    break;\n                case 0x51:\n                case 0x41:\n                case 0x31:\n                case 0x21:\n                    mc = 2;\n                    nc = 1;\n                    gemm<2, 1, false>(m0, m, n0, n);\n                    break;\n                case 0x15:\n                case 0x14:\n                case 0x13:\n                case 0x12:\n                    mc = 1;\n                    nc = 2;\n                    gemm<1, 2, false>(m0, m, n0, n);\n                    break;\n                case 0x11:\n                    mc = 1;\n                    nc = 1;\n                    gemm<1, 1, false>(m0, m, n0, n);\n                    break;\n                default:\n                    return;\n            }\n        } else {\n            switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 3)) {\n                case 0x43:\n                    mc = 4;\n                    nc = 3;\n                    gemm<4, 3, true>(m0, m, n0, n);\n                    break;\n                case 0x42:\n                case 0x33:\n                case 0x32:\n                case 0x23:\n                case 0x22:\n                    mc = 2;\n                    nc = 2;\n                    gemm<2, 2, true>(m0, m, n0, n);\n                    break;\n                case 0x41:\n                case 0x31:\n                case 0x21:\n                    mc = 2;\n                    nc = 1;\n                    gemm<2, 1, true>(m0, m, n0, n);\n                    break;\n     
           case 0x13:\n                case 0x12:\n                    mc = 1;\n                    nc = 2;\n                    gemm<1, 2, true>(m0, m, n0, n);\n                    break;\n                case 0x11:\n                    mc = 1;\n                    nc = 1;\n                    gemm<1, 1, true>(m0, m, n0, n);\n                    break;\n                default:\n                    return;\n            }\n        }\n#endif\n\n#if VECTOR_REGISTERS == 16\n        if (!FLAG_precise) {\n            switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 3)) {\n                case 0x43:\n                    mc = 4;\n                    nc = 3;\n                    gemm<4, 3, false>(m0, m, n0, n);\n                    break;\n                case 0x42:\n                case 0x33:\n                case 0x32:\n                case 0x23:\n                case 0x22:\n                    mc = 2;\n                    nc = 2;\n                    gemm<2, 2, false>(m0, m, n0, n);\n                    break;\n                case 0x41:\n                case 0x31:\n                case 0x21:\n                    mc = 2;\n                    nc = 1;\n                    gemm<2, 1, false>(m0, m, n0, n);\n                    break;\n                case 0x13:\n                case 0x12:\n                    mc = 1;\n                    nc = 2;\n                    gemm<1, 2, false>(m0, m, n0, n);\n                    break;\n                case 0x11:\n                    mc = 1;\n                    nc = 1;\n                    gemm<1, 1, false>(m0, m, n0, n);\n                    break;\n                default:\n                    return;\n            }\n        } else {\n            switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 2)) {\n                case 0x32:\n                    mc = 3;\n                    nc = 2;\n                    gemm<3, 2, true>(m0, m, n0, n);\n                    break;\n                case 0x23:\n                    mc = 2;\n                    
nc = 3;\n                    gemm<2, 3, true>(m0, m, n0, n);\n                    break;\n                case 0x22:\n                    mc = 2;\n                    nc = 2;\n                    gemm<2, 2, true>(m0, m, n0, n);\n                    break;\n                case 0x31:\n                case 0x21:\n                    mc = 2;\n                    nc = 1;\n                    gemm<2, 1, true>(m0, m, n0, n);\n                    break;\n                case 0x12:\n                    mc = 1;\n                    nc = 2;\n                    gemm<1, 2, true>(m0, m, n0, n);\n                    break;\n                case 0x11:\n                    mc = 1;\n                    nc = 1;\n                    gemm<1, 1, true>(m0, m, n0, n);\n                    break;\n                default:\n                    return;\n            }\n        }\n#endif\n\n        mp = m0 + (m - m0) / mc * mc;\n        np = n0 + (n - n0) / nc * nc;\n        mnpack(mp, m, n0, np);\n        mnpack(m0, m, np, n);\n    }\n\n    template <int RM, int RN, int PRECISE>\n    NOINLINE void gemm(long m0, long m, long n0, long n) {\n        long ytiles = RM > 1 ? (m - m0) / RM : 1;\n        long xtiles = RN > 1 ? 
(n - n0) / RN : 1;\n        long tiles = xtiles * ytiles;\n        long duty = (tiles + nth - 1) / nth;\n        long start = duty * ith;\n        long end = start + duty;\n        if (end > tiles)\n            end = tiles;\n        for (long job = start; job < end; ++job) {\n            long ii = m0 + job / xtiles * RM;\n            long jj = n0 + job % xtiles * RN;\n            D Cv[RN][RM] = {};\n            D Ce[RN][RM] = {};\n            for (long l = 0; l < k; l += KN)\n#pragma GCC unroll 100\n                for (int j = 0; j < RN; ++j)\n#pragma GCC unroll 100\n                    for (int i = 0; i < RM; ++i)\n                        if (PRECISE)\n                            Cv[j][i] = madder(load<V>(INDEX(A, lda, ii + i, l)),  //\n                                              load<V>(INDEX(B, ldb, jj + j, l)),  //\n                                              Cv[j][i], &Ce[j][i]);\n                        else\n                            Cv[j][i] = madd(load<V>(INDEX(A, lda, ii + i, l)),  //\n                                            load<V>(INDEX(B, ldb, jj + j, l)),  //\n                                            Cv[j][i]);\n#pragma GCC unroll 100\n            for (int j = 0; j < RN; ++j)\n#pragma GCC unroll 100\n                for (int i = 0; i < RM; ++i)\n                    store(INDEX(C, ldc, jj + j, ii + i), hsum(Cv[j][i]));\n        }\n    }\n\n    const TA* const A;\n    const TB* const B;\n    TC* const C;\n    const long k;\n    const long lda;\n    const long ldb;\n    const long ldc;\n    const int ith;\n    const int nth;\n};\n\n//////////////////////////////////////////////////////////////////////////////////////////\n// QUANT ZERO MATRIX MULTIPLICATION\n\n#if defined(__ARM_FEATURE_DOTPROD)\ntemplate <int CONFIG, typename TA, typename TB, typename TC>\nclass tinyBLAS_Q0_ARM {\n   public:\n    tinyBLAS_Q0_ARM(long k, const TA* A, long lda, const TB* B, long ldb, TC* C, long ldc, int ith, int nth)\n        : A(A), B(B), C(C), k(k), 
lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {\n    }\n\n    void matmul(long m, long n, int task) {\n        if (task == GGML_TASK_TYPE_COMPUTE)\n            mnpack(0, m, 0, n);\n    }\n\n   private:\n    NOINLINE void mnpack(long m0, long m, long n0, long n) {\n        long mc, nc, mp, np;\n\n        if (!FLAG_precise) {\n            switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3)) {\n                case 0x33:\n                    mc = 3;\n                    nc = 3;\n                    gemm<3, 3, false>(m0, m, n0, n);\n                    break;\n                case 0x32:\n                case 0x23:\n                case 0x22:\n                    mc = 2;\n                    nc = 2;\n                    gemm<2, 2, false>(m0, m, n0, n);\n                    break;\n                case 0x31:\n                case 0x21:\n                    mc = 2;\n                    nc = 1;\n                    gemm<2, 1, false>(m0, m, n0, n);\n                    break;\n                case 0x13:\n                case 0x12:\n                    mc = 1;\n                    nc = 2;\n                    gemm<1, 2, false>(m0, m, n0, n);\n                    break;\n                case 0x11:\n                    mc = 1;\n                    nc = 1;\n                    gemm<1, 1, false>(m0, m, n0, n);\n                    break;\n                default:\n                    return;\n            }\n        } else {\n            switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3)) {\n                case 0x33:\n                    mc = 3;\n                    nc = 3;\n                    gemm<3, 3, true>(m0, m, n0, n);\n                    break;\n                case 0x32:\n                case 0x23:\n                case 0x22:\n                    mc = 2;\n                    nc = 2;\n                    gemm<2, 2, true>(m0, m, n0, n);\n                    break;\n                case 0x31:\n                case 0x21:\n                    mc = 2;\n                    nc = 
1;\n                    gemm<2, 1, true>(m0, m, n0, n);\n                    break;\n                case 0x13:\n                case 0x12:\n                    mc = 1;\n                    nc = 2;\n                    gemm<1, 2, true>(m0, m, n0, n);\n                    break;\n                case 0x11:\n                    mc = 1;\n                    nc = 1;\n                    gemm<1, 1, true>(m0, m, n0, n);\n                    break;\n                default:\n                    return;\n            }\n        }\n\n        mp = m0 + (m - m0) / mc * mc;\n        np = n0 + (n - n0) / nc * nc;\n        mnpack(mp, m, n0, np);\n        mnpack(m0, m, np, n);\n    }\n\n    template <int RM, int RN, int PRECISE>\n    NOINLINE void gemm(long m0, long m, long n0, long n) {\n        long ytiles = RM > 1 ? (m - m0) / RM : 1;\n        long xtiles = RN > 1 ? (n - n0) / RN : 1;\n        long tiles = xtiles * ytiles;\n        long duty = (tiles + nth - 1) / nth;\n        long start = duty * ith;\n        long end = start + duty;\n        if (end > tiles)\n            end = tiles;\n        for (long job = start; job < end; ++job) {\n            long ii = m0 + job / xtiles * RM;\n            long jj = n0 + job % xtiles * RN;\n            float32x4_t Cv[RN][RM] = {};\n            float32x4_t Ce[RN][RM] = {};\n            for (int l = 0; l < k; ++l)\n#pragma GCC unroll 100\n                for (int j = 0; j < RN; ++j)\n#pragma GCC unroll 100\n                    for (int i = 0; i < RM; ++i) {\n                        float32x4_t a = vcvtq_f32_s32(vdotq_s32(\n                            vdotq_s32(vdupq_n_s32(0), load_lo(INDEX(A, lda, ii + i, l)),\n                                      load_lo(INDEX(B, ldb, jj + j, l))),\n                            load_hi(INDEX(A, lda, ii + i, l)), load_hi(INDEX(B, ldb, jj + j, l))));\n                        float b = unhalf(INDEX(A, lda, ii + i, l)->d) *\n                                  unhalf(INDEX(B, ldb, jj + j, l)->d);\n               
         if (PRECISE)\n                            Cv[j][i] = badder(a, b, Cv[j][i], &Ce[j][i]);\n                        else\n                            Cv[j][i] = vmlaq_n_f32(Cv[j][i], a, b);\n                    }\n#pragma GCC unroll 100\n            for (int j = 0; j < RN; ++j)\n#pragma GCC unroll 100\n                for (int i = 0; i < RM; ++i)\n                    store(INDEX(C, ldc, jj + j, ii + i), hsum(Cv[j][i]));\n        }\n    }\n\n    inline int8x16_t load_lo(const block_q8_0* b) {\n        return vld1q_s8(b->qs);\n    }\n\n    inline int8x16_t load_hi(const block_q8_0* b) {\n        return vld1q_s8(b->qs + 16);\n    }\n\n    inline int8x16_t load_lo(const block_q4_0* b) {\n        return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs), vdupq_n_u8(0x0f))),\n                        vdupq_n_s8(0x8));\n    }\n\n    inline int8x16_t load_hi(const block_q4_0* b) {\n        return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)), vdupq_n_s8(0x8));\n    }\n\n    const TA* const A;\n    const TB* const B;\n    TC* const C;\n    const long k;\n    const long lda;\n    const long ldb;\n    const long ldc;\n    const int ith;\n    const int nth;\n};\n#endif  // __ARM_FEATURE_DOTPROD\n\n#if defined(__AVX2__) || defined(__AVX512F__)\ntemplate <int CONFIG, typename TA, typename TB, typename TC>\nclass tinyBLAS_Q0_AVX2 {\n   public:\n    tinyBLAS_Q0_AVX2(long k, const TA* A, long lda, const TB* B, long ldb, TC* C, long ldc, int ith, int nth)\n        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {\n    }\n\n    void matmul(long m, long n, int task) {\n        if (task == GGML_TASK_TYPE_COMPUTE)\n            mnpack(0, m, 0, n);\n    }\n\n   private:\n    void mnpack(long m0, long m, long n0, long n) {\n        long mc, nc, mp, np;\n\n#if VECTOR_REGISTERS == 32\n        if (!FLAG_precise) {\n            switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3)) {\n                case 0x33:\n                    mc = 3;\n           
         nc = 3;\n                    gemm<3, 3, false>(m0, m, n0, n);\n                    break;\n                case 0x32:\n                case 0x23:\n                case 0x22:\n                    mc = 2;\n                    nc = 2;\n                    gemm<2, 2, false>(m0, m, n0, n);\n                    break;\n                case 0x31:\n                case 0x21:\n                    mc = 2;\n                    nc = 1;\n                    gemm<2, 1, false>(m0, m, n0, n);\n                    break;\n                case 0x13:\n                case 0x12:\n                    mc = 1;\n                    nc = 2;\n                    gemm<1, 2, false>(m0, m, n0, n);\n                    break;\n                case 0x11:\n                    mc = 1;\n                    nc = 1;\n                    gemm<1, 1, false>(m0, m, n0, n);\n                    break;\n                default:\n                    return;\n            }\n        } else {\n            switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3)) {\n                case 0x33:\n                    mc = 3;\n                    nc = 3;\n                    gemm<3, 3, true>(m0, m, n0, n);\n                    break;\n                case 0x32:\n                case 0x23:\n                case 0x22:\n                    mc = 2;\n                    nc = 2;\n                    gemm<2, 2, true>(m0, m, n0, n);\n                    break;\n                case 0x31:\n                case 0x21:\n                    mc = 2;\n                    nc = 1;\n                    gemm<2, 1, true>(m0, m, n0, n);\n                    break;\n                case 0x13:\n                case 0x12:\n                    mc = 1;\n                    nc = 2;\n                    gemm<1, 2, true>(m0, m, n0, n);\n                    break;\n                case 0x11:\n                    mc = 1;\n                    nc = 1;\n                    gemm<1, 1, true>(m0, m, n0, n);\n                    break;\n                
default:\n                    return;\n            }\n        }\n#endif\n\n#if VECTOR_REGISTERS == 16\n        if (!FLAG_precise) {\n            switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 2)) {\n                case 0x32:\n                    mc = 3;\n                    nc = 2;\n                    gemm<3, 2, false>(m0, m, n0, n);\n                    break;\n                case 0x23:\n                    mc = 2;\n                    nc = 3;\n                    gemm<2, 3, false>(m0, m, n0, n);\n                    break;\n                case 0x22:\n                    mc = 2;\n                    nc = 2;\n                    gemm<2, 2, false>(m0, m, n0, n);\n                    break;\n                case 0x31:\n                case 0x21:\n                    mc = 2;\n                    nc = 1;\n                    gemm<2, 1, false>(m0, m, n0, n);\n                    break;\n                case 0x12:\n                    mc = 1;\n                    nc = 2;\n                    gemm<1, 2, false>(m0, m, n0, n);\n                    break;\n                case 0x11:\n                    mc = 1;\n                    nc = 1;\n                    gemm<1, 1, false>(m0, m, n0, n);\n                    break;\n                default:\n                    return;\n            }\n        } else {\n            switch ((MIN(m - m0, 2) << 4) | MIN(n - n0, 1)) {\n                case 0x21:\n                    mc = 2;\n                    nc = 1;\n                    gemm<2, 1, true>(m0, m, n0, n);\n                    break;\n                case 0x12:\n                    mc = 1;\n                    nc = 2;\n                    gemm<1, 2, true>(m0, m, n0, n);\n                    break;\n                case 0x11:\n                    mc = 1;\n                    nc = 1;\n                    gemm<1, 1, true>(m0, m, n0, n);\n                    break;\n                default:\n                    return;\n            }\n        }\n#endif\n\n        mp = m0 + (m - 
m0) / mc * mc;\n        np = n0 + (n - n0) / nc * nc;\n        mnpack(mp, m, n0, np);\n        mnpack(m0, m, np, n);\n    }\n\n    template <int RM, int RN, int PRECISE>\n    NOINLINE void gemm(long m0, long m, long n0, long n) {\n        long ytiles = RM > 1 ? (m - m0) / RM : 1;\n        long xtiles = RN > 1 ? (n - n0) / RN : 1;\n        long tiles = xtiles * ytiles;\n        long duty = (tiles + nth - 1) / nth;\n        long start = duty * ith;\n        long end = start + duty;\n        if (end > tiles)\n            end = tiles;\n        for (long job = start; job < end; ++job) {\n            long ii = m0 + job / xtiles * RM;\n            long jj = n0 + job % xtiles * RN;\n            __m256 Cv[RN][RM] = {};\n            __m256 Ce[RN][RM] = {};\n            for (long l = 0; l < k; ++l)\n#pragma GCC unroll 100\n                for (int j = 0; j < RN; ++j)\n#pragma GCC unroll 100\n                    for (int i = 0; i < RM; ++i) {\n                        __m256 a = _mm256_set1_ps(unhalf(INDEX(A, lda, ii + i, l)->d) *\n                                                  unhalf(INDEX(B, ldb, jj + j, l)->d));\n                        __m256 b = updot(_mm256_sign_epi8(load(INDEX(A, lda, ii + i, l)),\n                                                          load(INDEX(A, lda, ii + i, l))),\n                                         _mm256_sign_epi8(load(INDEX(B, ldb, jj + j, l)),\n                                                          load(INDEX(A, lda, ii + i, l))));\n                        if (PRECISE)\n                            Cv[j][i] = madder(a, b, Cv[j][i], &Ce[j][i]);\n                        else\n                            Cv[j][i] = madd(a, b, Cv[j][i]);\n                    }\n#pragma GCC unroll 100\n            for (int j = 0; j < RN; ++j)\n#pragma GCC unroll 100\n                for (int i = 0; i < RM; ++i)\n                    store(INDEX(C, ldc, jj + j, ii + i), hsum(Cv[j][i]));\n        }\n    }\n\n    inline __m256i load(const block_q8_0* b) {\n  
      return _mm256_loadu_si256((const __m256i*)b->qs);\n    }\n\n    inline __m256i load(const block_q4_0* b) {\n        __m128i x = _mm_loadu_si128((const __m128i*)b->qs);\n        return _mm256_sub_epi8(_mm256_and_si256(_mm256_set1_epi8(15),\n                                                _mm256_insertf128_si256(_mm256_castsi128_si256(x),\n                                                                        _mm_srli_epi16(x, 4), 1)),\n                               _mm256_set1_epi8(8));\n    }\n\n    inline __m256 updot(__m256i u, __m256i s) {\n        __m256i res;\n#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))\n        res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);\n#else\n        res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));\n#endif\n        return _mm256_cvtepi32_ps(res);\n    }\n\n    const TA* const A;\n    const TB* const B;\n    TC* const C;\n    const long k;\n    const long lda;\n    const long ldb;\n    const long ldc;\n    const int ith;\n    const int nth;\n};\n#endif  // __AVX2__\n\n}  // namespace\n"
  },
  {
    "path": "third_party/llamafile/tinyblas_cpu_mixmul.inc",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul.inc\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-\n// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi\n//\n// Copyright 2024 Mozilla Foundation\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include \"tinyblas_cpu.h\"\n\n//\n//\n//                                ██████╗ ██╗   █████╗ ██████╗\n//         ██████╗██╗██╗ ██╗██═██╗██╔══██╗██║  ██╔══██╗██╔═══╝\n//         ╚═██╔═╝██║███▄██║██ ██║██████╔╝██║  ███████║██████╗\n//           ██║  ██║██▀███║╚███╔╝██╔══██╗██║  ██╔══██║╔═══██║\n//           ██║  ██║██║ ██║ ███║ ██████╔╝████╗██║  ██║██████║\n//           ╚═╝  ╚═╝╚═╝ ╚═╝ ╚══╝ ╚═════╝ ╚═══╝╚═╝  ╚═╝╚═════╝\n//\n//               MIXTURE OF EXPERTS TENSOR MULTIPLICATION\n//\n//\n// SHAPES\n//\n//   - weights [cols, rows, experts]\n//   - thought [cols, tasks, tokens] w/ tasks ≤ thinkers\n//   - result  [rows, thinkers, tokens] w/ thinkers ≤ experts\n//   - plan    [thinkers, tokens] w/ i32 < experts\n//\n// DEFINITION\n//\n//   for thinker in range(thinkers):\n//     for token in range(tokens):\n//       for row in range(rows):\n//         c = 0\n//         for col in range(cols):\n//           expert = plan[token][thinker]\n//           a = weights[expert][row][col]\n//           b = thought[token][thinker 
% tasks][col]\n//           c += a * b\n//         result[token][thinker][row] = c\n//\n// REGULARITIES\n//\n//   - tokens can be odd\n//   - thinkers is usually 2\n//   - tasks is usually 1 or 2\n//   - cols should be a multiple of 64\n//   - rows should be a multiple of 64\n//   - experts is usually 8 but could be 60\n//   - tokens is always 1 for token generation\n//   - tokens can be huge for prompt processing\n//\n// EXAMPLE\n//\n//   mixtral 8x7b w/ 217 token prompt\n//\n//           |  ne*0 ne*1 ne*2 ne*3 | nb*0    nb*1      nb*2       nb*3 | type\n//   =========================================================================\n//   weights | 16384 6144    8    1 |   18  0x2400 0x3600000 0x1b000000 | q4_0\n//   thought | 16384    2  217    1 |    4 0x10000   0x20000  0x1b20000 | f32\n//   result  |  6144    2  217    1 |    4  0x6000    0xc000   0xa2c000 | f32\n//   plan    |     2  217    1    1 |    4    0x20    0x1b20     0x1b20 | i32\n//\n\nnamespace {\n\nclass MixMul {\n   public:\n    MixMul(const ggml_compute_params* params, const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan, ggml_tensor* result)\n        : params(params),\n          weights(weights),\n          thought(thought),\n          plan(plan),\n          result(result),\n          rows(weights->ne[1]),\n          cols(weights->ne[0]),\n          experts(weights->ne[2]),\n          thinkers(plan->ne[0]),\n          tasks(thought->ne[1]),\n          tokens(thought->ne[2]),\n          ldq((cols * 2 + ROW_ALIGN - 1) & -ROW_ALIGN),\n          wdata_((char*)(((uintptr_t)params->wdata + MAX_ALIGN - 1) & -MAX_ALIGN)),\n          allocated_(0) {\n    }\n\n    bool allocate_shared_memory() {\n        if (!(quantized_thought_ = allocate<char>(MATRIX_ALIGN, tokens * tasks * ldq)))\n            return false;\n        if (!(rowptr_result_ = allocate<uintptr_t>(ROW_ALIGN, experts * tokens * thinkers)))\n            return false;\n        if (!(rowptr_thought_ = 
allocate<uintptr_t>(ROW_ALIGN, experts * tokens * thinkers)))\n            return false;\n        if (!(rowptr_count_ = allocate<long>(sizeof(long), experts)))\n            return false;\n        return true;\n    }\n\n    size_t get_allocated_bytes() {\n        return (wdata_ - (char*)params->wdata) + allocated_;\n    }\n\n    bool mixmul() {\n        // invariants\n        assert(tasks <= thinkers);\n        assert(thinkers <= experts);\n        assert(tokens == plan->ne[1]);\n        assert(rows == result->ne[0]);\n        assert(cols == thought->ne[0]);\n        assert(tokens == result->ne[2]);\n        assert(thinkers == result->ne[1]);\n\n        // dimensionality\n        assert(plan->ne[2] == 1);\n        assert(plan->ne[3] == 1);\n        assert(result->ne[3] == 1);\n        assert(weights->ne[3] == 1);\n        assert(thought->ne[3] == 1);\n\n        // miscellaneous\n        assert(params->nth > 0);\n        assert(params->ith < params->nth);\n        assert(plan->type == GGML_TYPE_I32);\n\n        // check nb01 is convertible to lda\n        if (weights->nb[1] % ggml_type_size(weights->type))\n            return false;\n\n        // no support for column strides\n        if (result->nb[0] != ggml_type_size(result->type))\n            return false;\n        if (thought->nb[0] != ggml_type_size(thought->type))\n            return false;\n        if (weights->nb[0] != ggml_type_size(weights->type))\n            return false;\n\n        // supported output types\n        switch (result->type) {\n            case GGML_TYPE_F32:\n                return mixmuler<float>();\n            default:\n                return false;\n        }\n    }\n\n   private:\n    template <typename TC>\n    bool mixmuler() {\n        switch (weights->type) {\n            case GGML_TYPE_F32:\n                if (thought->type != GGML_TYPE_F32)\n                    return false;\n#if defined(__AVX512F__)\n                return mixmat<16, 1, tinyBLAS<NCB | NCC, 16, __m512, __m512, 
float, float, TC>, float,\n                              float, TC>();\n#elif defined(__AVX__) || defined(__AVX2__)\n                return mixmat<8, 1, tinyBLAS<NCB | NCC, 8, __m256, __m256, float, float, TC>, float,\n                              float, TC>();\n#elif defined(__SSE__)\n                return mixmat<4, 1, tinyBLAS<NCB | NCC, 4, __m128, __m128, float, float, TC>, float,\n                              float, TC>();\n#elif defined(__ARM_NEON)\n                return mixmat<4, 1, tinyBLAS<NCB | NCC, 4, float32x4_t, float32x4_t, float, float, TC>,\n                              float, float, TC>();\n#else\n                return false;\n#endif\n\n            case GGML_TYPE_BF16:\n                if (thought->type != GGML_TYPE_F32 && thought->type != GGML_TYPE_BF16)\n                    return false;\n#if defined(__AVX512BF16__)\n                if (!FLAG_precise) {\n                    return mixmat<\n                        32, 1, tinyBLAS<NCB | NCC, 32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, TC>,\n                        ggml_bf16_t, ggml_bf16_t, TC>();\n                } else {\n                    return mixmat<16, 1,\n                                  tinyBLAS<NCB | NCC, 16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, TC>,\n                                  ggml_bf16_t, ggml_bf16_t, TC>();\n                }\n#elif defined(__AVX512F__)\n                return mixmat<16, 1,\n                              tinyBLAS<NCB | NCC, 16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, TC>,\n                              ggml_bf16_t, ggml_bf16_t, TC>();\n#elif defined(__AVX2__)\n                return mixmat<8, 1,\n                              tinyBLAS<NCB | NCC, 8, __m256, __m256, ggml_bf16_t, ggml_bf16_t, TC>,\n                              ggml_bf16_t, ggml_bf16_t, TC>();\n#elif defined(__ARM_NEON) && !defined(_MSC_VER)\n                return mixmat<\n                    4, 1,\n                    tinyBLAS<NCB | NCC, 4, float32x4_t, float32x4_t, ggml_bf16_t, 
ggml_bf16_t, TC>,\n                    ggml_bf16_t, ggml_bf16_t, TC>();\n#else\n                return false;\n#endif\n\n            case GGML_TYPE_F16:\n                if (thought->type != GGML_TYPE_F32 && thought->type != GGML_TYPE_F16)\n                    return false;\n#if defined(__AVX512F__)\n                return mixmat<16, 1,\n                              tinyBLAS<NCB | NCC, 16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, TC>,\n                              ggml_fp16_t, ggml_fp16_t, TC>();\n#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)\n                // if (X86_CHECK(F16C)) {\n                return mixmat<8, 1,\n                              tinyBLAS<NCB | NCC, 8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, TC>,\n                              ggml_fp16_t, ggml_fp16_t, TC>();\n                // } else {\n                //     return false;\n                // }\n#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)\n                if (result->op_params[0] == GGML_PREC_F32) {\n                    return mixmat<\n                        4, 1,\n                        tinyBLAS<NCB | NCC, 4, float32x4_t, float32x4_t, ggml_fp16_t, ggml_fp16_t, TC>,\n                        ggml_fp16_t, ggml_fp16_t, TC>();\n                } else {\n                    return mixmat<\n                        8, 1,\n                        tinyBLAS<NCB | NCC, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, TC>,\n                        ggml_fp16_t, ggml_fp16_t, TC>();\n                }\n#elif defined(__ARM_NEON) && !defined(_MSC_VER)\n                return mixmat<\n                    4, 1,\n                    tinyBLAS<NCB | NCC, 4, float32x4_t, float32x4_t, ggml_fp16_t, ggml_fp16_t, TC>,\n                    ggml_fp16_t, ggml_fp16_t, TC>();\n#else\n                return false;\n#endif\n\n            case GGML_TYPE_Q4_0:\n                if (thought->type != GGML_TYPE_F32 && thought->type != GGML_TYPE_Q8_0)\n                 
   return false;\n#if defined(__AVX2__) || defined(__AVX512F__)\n                return mixmat<32, 32, tinyBLAS_Q0_AVX2<NCB | NCC, block_q4_0, block_q8_0, TC>,\n                              block_q4_0, block_q8_0, TC>();\n#elif defined(__ARM_FEATURE_DOTPROD)\n                return mixmat<32, 32, tinyBLAS_Q0_ARM<NCB | NCC, block_q4_0, block_q8_0, TC>,\n                              block_q4_0, block_q8_0, TC>();\n#else\n                return false;\n#endif\n\n            case GGML_TYPE_Q8_0:\n                if (thought->type != GGML_TYPE_F32 && thought->type != GGML_TYPE_Q8_0)\n                    return false;\n#if defined(__AVX2__) || defined(__AVX512F__)\n                return mixmat<32, 32, tinyBLAS_Q0_AVX2<NCB | NCC, block_q8_0, block_q8_0, TC>,\n                              block_q8_0, block_q8_0, TC>();\n#elif defined(__ARM_FEATURE_DOTPROD)\n                return mixmat<32, 32, tinyBLAS_Q0_ARM<NCB | NCC, block_q8_0, block_q8_0, TC>,\n                              block_q8_0, block_q8_0, TC>();\n#else\n                return false;\n#endif\n\n            default:\n                return false;\n        }\n    }\n\n    template <int KN, int BS, typename BLAS, typename TA, typename TB, typename TC>\n    bool mixmat() {\n        if (cols % KN)\n            return false;\n        switch (params->type) {\n            case GGML_TASK_TYPE_INIT:\n                if (thought->type != ggml_type_trait<TB>::id)\n                    quantize_thought(ggml_type_trait<TB>::id);\n                build_row_pointers(ggml_type_trait<TB>::id);\n                return true;\n            case GGML_TASK_TYPE_COMPUTE:\n                assert(!(cols % BS));\n                assert(!(weights->nb[1] % sizeof(TA)));\n                for (int expert = 0; expert < experts; ++expert) {\n                    BLAS tb{cols / BS,\n                            (const TA*)((const char*)weights->data + expert * weights->nb[2]),\n                            (long)(weights->nb[1] / 
sizeof(TA)),\n                            (const TB*)(rowptr_thought_ + expert * tokens * thinkers),\n                            0,\n                            (TC*)(rowptr_result_ + expert * tokens * thinkers),\n                            0,\n                            params->ith,\n                            params->nth};\n                    tb.matmul(rows, rowptr_count_[expert], GGML_TASK_TYPE_COMPUTE);\n                }\n                return true;\n            default:\n                return true;\n        }\n    }\n\n    void build_row_pointers(ggml_type vec_dot_type) {\n        for (int expert = params->ith; expert < experts; expert += params->nth) {\n            long count = 0;\n            for (long token = 0; token < tokens; ++token)\n                for (int thinker = 0; thinker < thinkers; ++thinker)\n                    if (expert == *(const int32_t*)((const char*)plan->data +\n                                                    token * plan->nb[1] + thinker * plan->nb[0])) {\n                        long row = count++;\n                        long idx = expert * thinkers * tokens + row;\n                        rowptr_result_[idx] =\n                            (uintptr_t)((char*)result->data + token * result->nb[2] +\n                                        thinker * result->nb[1]);\n                        if (thought->type == vec_dot_type)\n                            rowptr_thought_[idx] =\n                                (uintptr_t)((char*)thought->data + token * thought->nb[2] +\n                                            thinker % tasks * thought->nb[1]);\n                        else\n                            rowptr_thought_[idx] =\n                                (uintptr_t)((char*)quantized_thought_ + token * tasks * ldq +\n                                            thinker % tasks * ldq);\n                    }\n            rowptr_count_[expert] = count;\n        }\n    }\n\n    void quantize_thought(ggml_type vec_dot_type) 
{\n        long chore = 0;\n        for (long token = 0; token < tokens; ++token)\n            for (int task = 0; task < tasks; ++task)\n                if (chore++ % params->nth == params->ith)\n                    quantize_row(quantized_thought_ + token * tasks * ldq + task * ldq,\n                                 (const float*)((const char*)thought->data +\n                                                token * thought->nb[2] + task * thought->nb[1]),\n                                 vec_dot_type);\n    }\n\n    void quantize_row(void* dst, const float* src, ggml_type type) {\n        assert((long)ggml_row_size(type, cols) <= ldq);\n        switch (type) {\n            case GGML_TYPE_F16:\n                ggml_fp32_to_fp16_row(src, (ggml_fp16_t*)dst, cols);\n                break;\n            case GGML_TYPE_BF16:\n                ggml_fp32_to_bf16_row(src, (ggml_bf16_t*)dst, cols);\n                break;\n            case GGML_TYPE_Q8_0:\n                quantize_row_q8_0((const float*)src, (block_q8_0*)dst, cols);\n                break;\n            default:\n                GGML_UNREACHABLE();\n        }\n    }\n\n    template <typename T>\n    T* allocate(size_t align, size_t elems) {\n        T* res = nullptr;\n        size_t need = sizeof(T) * elems;\n        size_t base = allocated_;\n        base += align - 1;\n        base &= -align;\n        size_t toto = base + need;\n        if (toto >= allocated_ && toto <= params->wsize) {\n            res = (T*)(wdata_ + base);\n            allocated_ = toto;\n        }\n        return res;\n    }\n\n    const ggml_compute_params* const params;\n    const ggml_tensor* const weights;\n    const ggml_tensor* const thought;\n    const ggml_tensor* const plan;\n    ggml_tensor* const result;\n    const long rows;\n    const long cols;\n    const int experts;\n    const int thinkers;\n    const int tasks;\n    const long tokens;\n    const long ldq;\n\n    // variables\n    char* const wdata_;\n    size_t 
allocated_;\n\n    // shared memory\n    long* rowptr_count_ /*[experts]*/;\n    char* quantized_thought_ /*[tokens][tasks][cols][2]*/;\n    uintptr_t* rowptr_result_ /*[experts][tokens*thinkers]*/;\n    uintptr_t* rowptr_thought_ /*[experts][tokens*thinkers]*/;\n};\n\n}  // namespace\n\n/**\n * Performs \"mixture of experts\" tensor multiplication on CPU.\n */\nbool llamafile_mixmul(const ggml_compute_params* params, const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan, ggml_tensor* result) {\n    MixMul mm{params, weights, thought, plan, result};\n    return mm.allocate_shared_memory() && mm.mixmul();\n}\n"
  },
  {
    "path": "third_party/llamafile/tinyblas_cpu_mixmul_amd_avx.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_amd_avx.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#if defined(__x86_64__) || defined(_M_X64)\n#define llamafile_mixmul llamafile_mixmul_amd_avx\n#include \"tinyblas_cpu_mixmul.inc\"\n\n/**\n * Returns number of shared memory bytes llamafile_mixmul() needs.\n */\nsize_t llamafile_mixmul_needs(const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan) {\n    ggml_compute_params params{};\n    params.wsize = 0x7ffff000;\n    params.wdata = (void*)0x1000;\n    MixMul mm{&params, weights, thought, plan, 0};\n    if (mm.allocate_shared_memory())\n        return mm.get_allocated_bytes();\n    else\n        return 0;\n}\n\n#endif  // __x86_64__\n"
  },
  {
    "path": "third_party/llamafile/tinyblas_cpu_mixmul_amd_avx2.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_amd_avx2.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#if defined(__x86_64__) || defined(_M_X64)\n#define llamafile_mixmul llamafile_mixmul_amd_avx2\n#include \"tinyblas_cpu_mixmul.inc\"\n#endif  // __x86_64__\n"
  },
  {
    "path": "third_party/llamafile/tinyblas_cpu_mixmul_amd_avx512f.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_amd_avx512f.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#if defined(__x86_64__) || defined(_M_X64)\n#define llamafile_mixmul llamafile_mixmul_amd_avx512f\n#include \"tinyblas_cpu_mixmul.inc\"\n#endif  // __x86_64__\n"
  },
  {
    "path": "third_party/llamafile/tinyblas_cpu_mixmul_amd_avxvnni.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_amd_avxvnni.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#if defined(__x86_64__) || defined(_M_X64)\n#define llamafile_mixmul llamafile_mixmul_amd_avxvnni\n#include \"tinyblas_cpu_mixmul.inc\"\n#endif  // __x86_64__\n"
  },
  {
    "path": "third_party/llamafile/tinyblas_cpu_mixmul_amd_fma.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_amd_fma.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#if defined(__x86_64__) || defined(_M_X64)\n#define llamafile_mixmul llamafile_mixmul_amd_fma\n#include \"tinyblas_cpu_mixmul.inc\"\n#endif  // __x86_64__\n"
  },
  {
    "path": "third_party/llamafile/tinyblas_cpu_mixmul_amd_zen4.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_amd_zen4.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#if defined(__x86_64__) || defined(_M_X64)\n#define llamafile_mixmul llamafile_mixmul_amd_zen4\n#include \"tinyblas_cpu_mixmul.inc\"\n#endif  // __x86_64__\n"
  },
  {
    "path": "third_party/llamafile/tinyblas_cpu_mixmul_arm80.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_arm80.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#ifdef __aarch64__\n#define llamafile_mixmul llamafile_mixmul_arm80\n#include \"tinyblas_cpu_mixmul.inc\"\n\n/**\n * Returns number of shared memory bytes llamafile_mixmul() needs.\n */\nsize_t llamafile_mixmul_needs(const ggml_tensor* weights, const ggml_tensor* thought, const ggml_tensor* plan) {\n    ggml_compute_params params{};\n    params.wsize = 0x7ffff000;\n    params.wdata = (void*)0x1000;\n    MixMul mm{&params, weights, thought, plan, 0};\n    if (mm.allocate_shared_memory())\n        return mm.get_allocated_bytes();\n    else\n        return 0;\n}\n\n#endif  // __aarch64__\n"
  },
  {
    "path": "third_party/llamafile/tinyblas_cpu_mixmul_arm82.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_mixmul_arm82.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#ifdef __aarch64__\n#define llamafile_mixmul llamafile_mixmul_arm82\n#include \"tinyblas_cpu_mixmul.inc\"\n#endif  // __aarch64__\n"
  },
  {
    "path": "third_party/llamafile/tinyblas_cpu_sgemm.inc",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm.inc\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-\n// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi\n//\n// Copyright 2024 Mozilla Foundation\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include \"tinyblas_cpu.h\"\n\n//\n//\n//                                ██████╗ ██╗   █████╗ ██████╗\n//         ██████╗██╗██╗ ██╗██═██╗██╔══██╗██║  ██╔══██╗██╔═══╝\n//         ╚═██╔═╝██║███▄██║██ ██║██████╔╝██║  ███████║██████╗\n//           ██║  ██║██▀███║╚███╔╝██╔══██╗██║  ██╔══██║╔═══██║\n//           ██║  ██║██║ ██║ ███║ ██████╔╝████╗██║  ██║██████║\n//           ╚═╝  ╚═╝╚═╝ ╚═╝ ╚══╝ ╚═════╝ ╚═══╝╚═╝  ╚═╝╚═════╝\n//\n//                   BASIC LINEAR ALGEBRA SUBPROGRAMS\n//\n//\n// This file implements multithreaded CPU matrix multiplication for the\n// common contiguous use case C = Aᵀ * B. These kernels are designed to\n// have excellent performance[1] for matrices that fit in the CPU cache\n// without imposing any overhead such as cache filling or malloc calls.\n//\n// This implementation does not guarantee any upper bound with rounding\n// errors, which grow along with k. 
Our goal's to maximally exploit the\n// hardware for performance, and then use whatever resources remain for\n// improving numerical accuracy.\n//\n// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].\n//     Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].\n\nnamespace {\n\ntemplate <typename TC>\nbool llamafile_sgemm_impl(long m, long n, long k, const void* A, long lda, const void* B, long ldb, TC* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {\n    switch (Atype) {\n        case GGML_TYPE_F32: {\n            if (Btype != GGML_TYPE_F32)\n                return NOT_SUPPORTED;\n#if defined(__AVX512F__)\n            if (k % 16)\n                return NOT_SUPPORTED;\n            tinyBLAS<0, 16, __m512, __m512, float, float, TC> tb{\n                k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#elif defined(__AVX__) || defined(__AVX2__)\n            if (k % 8)\n                return NOT_SUPPORTED;\n            tinyBLAS<0, 8, __m256, __m256, float, float, TC> tb{\n                k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#elif defined(__ARM_NEON)\n            if (k % 4)\n                return NOT_SUPPORTED;\n            tinyBLAS<0, 4, float32x4_t, float32x4_t, float, float, TC> tb{\n                k, (const float*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#else\n            return NOT_SUPPORTED;\n#endif\n        }\n\n        case GGML_TYPE_BF16: {\n#if defined(__AVX512BF16__)\n            if (k % 32)\n                return NOT_SUPPORTED;\n            if (Btype == GGML_TYPE_F32 && n < 2) {\n                tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{\n                    k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, 
nth};\n                tb.matmul(m, n, task);\n                return true;\n            }\n            if (Btype == GGML_TYPE_F32)\n                return WANT_QUANTIZATION;\n            if (Btype != GGML_TYPE_BF16)\n                return NOT_SUPPORTED;\n            if (!FLAG_precise) {\n                tinyBLAS<0, 32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, TC> tb{\n                    k, (const ggml_bf16_t*)A, lda, (const ggml_bf16_t*)B, ldb, C, ldc, ith, nth};\n                tb.matmul(m, n, task);\n                return true;\n            } else {\n                tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, TC> tb{\n                    k, (const ggml_bf16_t*)A, lda, (const ggml_bf16_t*)B, ldb, C, ldc, ith, nth};\n                tb.matmul(m, n, task);\n                return true;\n            }\n#elif defined(__AVX512F__)\n            if (k % 16)\n                return NOT_SUPPORTED;\n            tinyBLAS<0, 16, __m512, __m512, ggml_bf16_t, float, TC> tb{\n                k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#elif defined(__AVX2__)\n            if (k % 8)\n                return NOT_SUPPORTED;\n            if (Btype != GGML_TYPE_F32)\n                return NOT_SUPPORTED;\n            tinyBLAS<0, 8, __m256, __m256, ggml_bf16_t, float, TC> tb{\n                k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#elif defined(__ARM_NEON) && !defined(_MSC_VER)\n            if (k % 4)\n                return NOT_SUPPORTED;\n            if (Btype != GGML_TYPE_F32)\n                return NOT_SUPPORTED;\n            tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_bf16_t, float, TC> tb{\n                k, (const ggml_bf16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#else\n            return 
NOT_SUPPORTED;\n#endif\n        }\n\n        case GGML_TYPE_F16: {\n#if defined(__AVX512F__)\n            if (k % 16)\n                return NOT_SUPPORTED;\n            if (Btype == GGML_TYPE_F32 && n < 2) {\n                tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, float, TC> tb{\n                    k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};\n                tb.matmul(m, n, task);\n                return true;\n            }\n            if (Btype == GGML_TYPE_F32)\n                return WANT_QUANTIZATION;\n            if (Btype != GGML_TYPE_F16)\n                return NOT_SUPPORTED;\n            tinyBLAS<0, 16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, TC> tb{\n                k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)\n            // if (X86_CHECK(F16C)) {\n            if (k % 8)\n                return NOT_SUPPORTED;\n            if (Btype == GGML_TYPE_F32 && n < 2) {\n                tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, float, TC> tb{\n                    k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};\n                tb.matmul(m, n, task);\n                return true;\n            }\n            if (Btype == GGML_TYPE_F32)\n                return WANT_QUANTIZATION;\n            if (Btype != GGML_TYPE_F16)\n                return NOT_SUPPORTED;\n            tinyBLAS<0, 8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, TC> tb{\n                k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n            // } else {\n            //     return NOT_SUPPORTED;\n            // }\n#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)\n            if (n < 2 && !FLAG_precise)\n                // TODO(jart): Why is 
ggml_vec_dot_f16_unroll() so fast at matvec?\n                return NOT_SUPPORTED;\n            if (precision == GGML_PREC_F32) {\n                if (k % 4)\n                    return NOT_SUPPORTED;\n                if (Btype != GGML_TYPE_F32)\n                    return NOT_SUPPORTED;\n                tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{\n                    k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};\n                tb.matmul(m, n, task);\n                return true;\n            } else {\n                if (k % 8)\n                    return NOT_SUPPORTED;\n                if (Btype == GGML_TYPE_F32)\n                    return WANT_QUANTIZATION;\n                if (Btype != GGML_TYPE_F16)\n                    return NOT_SUPPORTED;\n                tinyBLAS<0, 8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, TC> tb{\n                    k, (const ggml_fp16_t*)A, lda, (const ggml_fp16_t*)B, ldb, C, ldc, ith, nth};\n                tb.matmul(m, n, task);\n                return true;\n            }\n#elif defined(__ARM_NEON) && !defined(_MSC_VER)\n            if (n < 2 && !FLAG_precise)\n                // TODO(jart): Why is ggml_vec_dot_f16_unroll() so fast at matvec?\n                return NOT_SUPPORTED;\n            if (k % 4)\n                return NOT_SUPPORTED;\n            if (Btype != GGML_TYPE_F32)\n                return NOT_SUPPORTED;\n            tinyBLAS<0, 4, float32x4_t, float32x4_t, ggml_fp16_t, float, TC> tb{\n                k, (const ggml_fp16_t*)A, lda, (const float*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#else\n            return NOT_SUPPORTED;\n#endif\n        }\n\n        case GGML_TYPE_Q8_0: {\n            if (Btype == GGML_TYPE_F32)\n                return WANT_QUANTIZATION;\n            if (Btype != GGML_TYPE_Q8_0)\n                return NOT_SUPPORTED;\n#if defined(__AVX2__) || defined(__AVX512F__)\n            
tinyBLAS_Q0_AVX2<0, block_q8_0, block_q8_0, TC> tb{\n                k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#elif defined(__ARM_FEATURE_DOTPROD)\n            tinyBLAS_Q0_ARM<0, block_q8_0, block_q8_0, TC> tb{\n                k, (const block_q8_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#else\n            return NOT_SUPPORTED;\n#endif\n        }\n\n        case GGML_TYPE_Q4_0: {\n            if (Btype == GGML_TYPE_F32)\n                return WANT_QUANTIZATION;\n            if (Btype != GGML_TYPE_Q8_0)\n                return NOT_SUPPORTED;\n#if defined(__AVX2__) || defined(__AVX512F__)\n            tinyBLAS_Q0_AVX2<0, block_q4_0, block_q8_0, TC> tb{\n                k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#elif defined(__ARM_FEATURE_DOTPROD)\n            tinyBLAS_Q0_ARM<0, block_q4_0, block_q8_0, TC> tb{\n                k, (const block_q4_0*)A, lda, (const block_q8_0*)B, ldb, C, ldc, ith, nth};\n            tb.matmul(m, n, task);\n            return true;\n#else\n            return NOT_SUPPORTED;\n#endif\n        }\n\n        default:\n            return NOT_SUPPORTED;\n    }\n\n    (void)m;\n    (void)n;\n    (void)k;\n    (void)A;\n    (void)lda;\n    (void)B;\n    (void)ldb;\n    (void)C;\n    (void)ldc;\n    (void)ith;\n    (void)nth;\n    (void)Atype;\n    (void)Btype;\n    (void)precision;\n}\n\n}  // namespace\n\n/**\n * Performs optimized matrix multiplication on CPU.\n *\n * This subroutine may compute C = Aᵀ * B with column major ordering.\n * Despite its name, this isn't a generalized implementation. 
Work is\n * only performed when a handwritten kernel is available.\n * Otherwise the caller should fall back to a general matmul routine.\n *\n * For example, for single-threaded single-precision GEMM you can say\n *\n *     llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, 0, 1,\n *                     GGML_TASK_TYPE_COMPUTE,\n *                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32,\n *                     GGML_PREC_DEFAULT);\n *\n * @param m is rows in `A` and `C`\n * @param n is cols in `B` and `C`\n * @param k is cols in `A` and rows in `B`\n * @param A is first input matrix (always transposed)\n * @param lda is row stride of `A`\n * @param B is second input matrix (never transposed)\n * @param ldb is row stride of `B`\n * @param C is input/output array of output matrices\n * @param ldc is row stride of `C`\n * @param ith is thread id (must be less than `nth`)\n * @param nth is number of threads (must be greater than zero)\n * @param task is GGML task type\n * @param Atype is GGML data type of `A`\n * @param Btype is GGML data type of `B`\n * @param Ctype is GGML data type of `C`\n * @param precision may be used to control the internal compute type\n * @return true if this function was able to service the matmul request\n */\nbool llamafile_sgemm(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {\n    assert(m >= 0);\n    assert(n >= 0);\n    assert(k >= 0);\n    assert(lda >= k);\n    assert(ldb >= k);\n    assert(ldc >= m);\n    assert(nth > 0);\n    assert(ith < nth);\n\n#if QK_K == 256\n#if defined(__x86_64__) || defined(_M_X64)\n#if defined(__AVX2__) && (defined(__FMA__) || (defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))))\n    /* \n    moonll\n    more Btype accept\n    }*/\n\n    if (Ctype == GGML_TYPE_F32){\n        if (iqk_mul_mat(m, n, k * ggml_blck_size(ggml_type(Atype)), Atype, A,lda,Btype, B,ldb, (float*)C, ldc, ith, nth)) {\n            return true;\n        }\n    
}\n\n#endif\n#elif defined __aarch64__ && defined __ARM_FEATURE_DOTPROD && !defined _MSC_VER\n    if (Btype == GGML_TYPE_Q8_K && Ctype == GGML_TYPE_F32) {\n        if (iqk_mul_mat(m, n, k * QK_K, Atype, A, B, (float*)C, ldc, ith, nth)) {\n            return true;\n        }\n    }\n    if ((Btype == GGML_TYPE_Q8_0 || Btype == GGML_TYPE_Q8_1) && Ctype == GGML_TYPE_F32) {\n        // assert(QK8_0 == QK8_1 == QK4_0 == QK4_1 == QK5_0 == QK5_1 == 32);\n        assert((QK8_0 == 32) && (QK8_1 == 32) && (QK4_0 == 32) && (QK4_1 == 32) && (QK5_0 == 32) && (QK5_1 == 32));\n        if (iqk_mul_mat(m, n, k * QK8_0, Atype, A, B, (float*)C, ldc, ith, nth)) {\n            return true;\n        }\n    }\n#endif\n#endif\n\n    switch (Ctype) {\n        case GGML_TYPE_F32:\n            return llamafile_sgemm_impl(m, n, k, A, lda, B, ldb, (float*)C, ldc, ith, nth, task, Atype,\n                                        Btype, Ctype, precision);\n        default:\n            return NOT_SUPPORTED;\n    }\n}"
  },
  {
    "path": "third_party/llamafile/tinyblas_cpu_sgemm_amd_avx.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm_amd_avx.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#if defined(__x86_64__) || defined(_M_X64)\n#define llamafile_sgemm llamafile_sgemm_amd_avx\n#include \"tinyblas_cpu_sgemm.inc\"\n#endif  // __x86_64__\n"
  },
  {
    "path": "third_party/llamafile/tinyblas_cpu_sgemm_amd_avx2.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm_amd_avx2.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#if defined(__x86_64__) || defined(_M_X64)\n#define llamafile_sgemm llamafile_sgemm_amd_avx2\n#include \"tinyblas_cpu_sgemm.inc\"\n#endif  // __x86_64__\n"
  },
  {
    "path": "third_party/llamafile/tinyblas_cpu_sgemm_amd_avx512f.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm_amd_avx512f.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#if defined(__x86_64__) || defined(_M_X64)\n#define llamafile_sgemm llamafile_sgemm_amd_avx512f\n#include \"tinyblas_cpu_sgemm.inc\"\n#endif  // __x86_64__\n"
  },
  {
    "path": "third_party/llamafile/tinyblas_cpu_sgemm_amd_avxvnni.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm_amd_avxvnni.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#if defined(__x86_64__) || defined(_M_X64)\n#define llamafile_sgemm llamafile_sgemm_amd_avxvnni\n#include \"tinyblas_cpu_sgemm.inc\"\n#endif  // __x86_64__\n"
  },
  {
    "path": "third_party/llamafile/tinyblas_cpu_sgemm_amd_fma.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm_amd_fma.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#if defined(__x86_64__) || defined(_M_X64)\n#define llamafile_sgemm llamafile_sgemm_amd_fma\n#include \"tinyblas_cpu_sgemm.inc\"\n#endif  // __x86_64__\n"
  },
  {
    "path": "third_party/llamafile/tinyblas_cpu_sgemm_amd_zen4.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm_amd_zen4.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#if defined(__x86_64__) || defined(_M_X64)\n#define llamafile_sgemm llamafile_sgemm_amd_zen4\n#define iqk_mul_mat iqk_mul_mat_zen4\n#include \"tinyblas_cpu_sgemm.inc\"\n#endif  // __x86_64__\n"
  },
  {
    "path": "third_party/llamafile/tinyblas_cpu_sgemm_arm80.cpp",
    "content": "// // Adapted from\n// // https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm_arm80.cpp\n// // Copyrigth 2024 Mozilla Foundation.\n// // Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n// #ifdef __aarch64__\n// #define llamafile_sgemm llamafile_sgemm_arm80\n// #include \"tinyblas_cpu_sgemm.inc\"\n// #endif  // __aarch64__\n"
  },
  {
    "path": "third_party/llamafile/tinyblas_cpu_sgemm_arm82.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_sgemm_arm82.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n#ifdef __aarch64__\n#define llamafile_sgemm llamafile_sgemm_arm82\n#define iqk_mul_mat iqk_mul_mat_arm82\n#include \"tinyblas_cpu_sgemm.inc\"\n#endif  // __aarch64__\n"
  },
  {
    "path": "third_party/llamafile/tinyblas_cpu_unsupported.cpp",
    "content": "// Adapted from\n// https://github.com/Mozilla-Ocho/llamafile/blob/0.8.8/llamafile/tinyblas_cpu_unsupported.cpp\n// Copyrigth 2024 Mozilla Foundation.\n// Copyright(c) 2024 by KVCache.AI, All Rights Reserved.\n\n// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-\n// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi\n//\n// Copyright 2024 Mozilla Foundation\n//\n// Licensed under the Apache License, Version 2.0 (the \"License\");\n// you may not use this file except in compliance with the License.\n// You may obtain a copy of the License at\n//\n//     http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing, software\n// distributed under the License is distributed on an \"AS IS\" BASIS,\n// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n// See the License for the specific language governing permissions and\n// limitations under the License.\n\n#include \"sgemm.h\"\n\nbool llamafile_sgemm_unsupported(long m, long n, long k, const void* A, long lda, const void* B, long ldb, void* C, long ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype, int precision) {\n    return false;\n}\n\nbool llamafile_mixmul_unsupported(const struct ggml_compute_params* params,\n                                  const struct ggml_tensor* weights,\n                                  const struct ggml_tensor* thought,\n                                  const struct ggml_tensor* plan,\n                                  struct ggml_tensor* result) {\n    return false;\n}\n\nbool iqk_mul_mat_moe_unsupported(long, long, long, int, int, const void*, const void*, float*, long, long, const void*, int, int) {\n    return false;\n}\n"
  },
  {
    "path": "version.py",
    "content": "\"\"\"\nKTransformers version information.\nShared across kt-kernel and kt-sft modules.\n\"\"\"\n\n__version__ = \"0.5.2.post1\"\n"
  }
]